Introduction
At the heart of every deep learning framework lies an elegant concept that makes training neural networks not just possible, but efficient: the computational graph. Whether you’re using PyTorch, TensorFlow, or any other modern deep learning framework, understanding how computational graphs work is essential to mastering deep learning.
In this tutorial, we’ll explore how deep learning frameworks use computational graphs to compute derivatives efficiently, enabling us to train models with millions — or even billions — of parameters. We’ll start with simple examples and gradually build up to understanding complex neural networks.
The Problem: Computing Derivatives Efficiently
When training neural networks, we need to compute gradients of a loss function with respect to model parameters. For a network with thousands or millions of parameters, computing these derivatives manually would be impractical. Deep learning frameworks solve this problem using computational graphs combined with automatic differentiation.
What is a Computational Graph?
A computational graph is a directed graph where:
- Nodes represent variables or intermediate values
- Edges represent operations that transform these values
The beauty of computational graphs is that they allow us to break down complex functions into simple operations, making it straightforward to compute derivatives using the chain rule.
A Simple Example: f(x, y, z) = x × y × z
Let’s start with a concrete example. Consider the function:
f(x, y, z) = x × y × z
Figure 1: The Computational Graph — Forward Pass
In the forward pass, we construct the graph by traversing from input to output:
- Start with three input nodes: x, y, and z
- Create an intermediate node by computing x × y
- Create the final output node by computing (x × y) × z
The forward pass computes the function value by moving through the graph from inputs to output. This is represented by the blue arrow in Figure 1, showing the direction of computation.
# Forward pass example
x = 2
y = 3
z = 4
intermediate = x * y  # First node: 6
output = intermediate * z  # Final node: 24
Computing Derivatives: The Backward Pass
Now comes the powerful part: computing derivatives. Once we have the computational graph, we can efficiently compute the derivative of the output with respect to any input by traversing the graph backward.
Step 1: Derivative with Respect to z
Figure 2: Computing ∂f/∂z
Let’s compute ∂f/∂z. Looking at our function f(x, y, z) = x × y × z, if we take the derivative with respect to z, we simply remove z from the product:
∂f/∂z = x × y
In Figure 2, we see that the node z is highlighted in green with the computed gradient. This gradient is stored at the z node for later use.
Step 2: Derivative with Respect to the Intermediate Node
Figure 3: Computing ∂f/∂(x × y)
Next, let’s compute the derivative with respect to the intermediate node (x × y). Again, for a simple product, we just remove that variable:
∂f/∂(x × y) = z
Figure 3 shows this gradient stored at the intermediate node x × y.
Step 3: Applying the Chain Rule for x
Figure 4: Computing ∂f/∂x Using Chain Rule
Now here’s where it gets interesting. To find ∂f/∂x, we need to apply the chain rule:
∂f/∂x = ∂f/∂(x×y) × ∂(x×y)/∂x
Breaking this down:
- We already computed ∂f/∂(x×y) = z (from Step 2)
- ∂(x×y)/∂x = y (derivative of x×y with respect to x)
- Therefore: ∂f/∂x = z × y
The computational graph helps us visualize this chain rule application. We “pull” the gradient from the next node in the graph and multiply it by the local derivative.
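As a quick numeric sanity check (using the values from Figure 1, x = 2, y = 3, z = 4), this "pull and multiply" step is just the upstream gradient times the local derivative:
# Upstream gradient times local derivative, with x = 2, y = 3, z = 4
upstream = 4.0   # ∂f/∂(x×y) = z, computed in Step 2
local = 3.0      # ∂(x×y)/∂x = y
df_dx = upstream * local
print(df_dx)     # 12.0, i.e. z × y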
Figure 4 illustrates this beautifully, showing how the gradient flows backward through the multiplication operation, with the node x now highlighted in green containing the final gradient.
Step 4: Derivative with Respect to y
Figure 5: Complete Backward Pass
Similarly, for ∂f/∂y, we apply the chain rule:
∂f/∂y = ∂f/∂(x×y) × ∂(x×y)/∂y
This gives us:
- ∂f/∂(x×y) = z (already computed)
- ∂(x×y)/∂y = x
- Therefore: ∂f/∂y = z × x
The Complete Backward Pass
Figure 6: Full Backward Pass Visualization
Figure 6 shows the complete backward pass with all gradients computed and stored at their respective nodes:
- ∂f/∂x = z × y
- ∂f/∂y = z × x
- ∂f/∂z = x × y
The backward pass moves in the opposite direction of the forward pass (indicated by the cyan arrow), computing and storing gradients at each node as it goes.
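To make the bookkeeping explicit, here is a minimal plain-Python sketch of the whole procedure for f = x × y × z, storing a gradient at every node exactly as in Figure 6:
# Forward pass: build the graph node by node (x = 2, y = 3, z = 4)
x, y, z = 2.0, 3.0, 4.0
xy = x * y            # intermediate node: 6
f = xy * z            # output node: 24

# Backward pass: traverse the graph from output to inputs
grad_f = 1.0          # ∂f/∂f
grad_xy = grad_f * z  # ∂f/∂(x×y) = z
grad_z = grad_f * xy  # ∂f/∂z = x × y
grad_x = grad_xy * y  # ∂f/∂x = z × y
grad_y = grad_xy * x  # ∂f/∂y = z × x

print(grad_x, grad_y, grad_z)  # 12.0 8.0 6.0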
Implementing This in PyTorch
Let’s see how PyTorch handles all of this automatically.
import torch

# Create tensors with gradient tracking enabled
x = torch.tensor([2.0], requires_grad=True)
y = torch.tensor([3.0], requires_grad=True)
z = torch.tensor([4.0], requires_grad=True)

# Perform the operation: f(x, y, z) = x * y * z
f = x * y * z  # f(x, y, z) = x × y × z

# Compute gradients
f.backward()

# Print the gradients
# Gradients: df/dx = y * z, df/dy = x * z, df/dz = x * y
print(f"Gradient df/dx: {x.grad}")  # Output should be 3.0 * 4.0 = 12
print(f"Gradient df/dy: {y.grad}")  # Output should be 2.0 * 4.0 = 8
print(f"Gradient df/dz: {z.grad}")  # Output should be 2.0 * 3.0 = 6
Output:
> Gradient df/dx: tensor([12.])
> Gradient df/dy: tensor([8.])
> Gradient df/dz: tensor([6.])
Let’s verify these results:
- ∂f/∂x = y × z = 3 × 4 = 12 ✓
- ∂f/∂y = x × z = 2 × 4 = 8 ✓
- ∂f/∂z = x × y = 2 × 3 = 6 ✓
The key takeaway: PyTorch automatically builds the computational graph during the forward pass and computes all gradients during the backward pass when we call .backward().
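If you are curious, you can peek at the graph PyTorch records during the forward pass through the grad_fn attribute of each result tensor (the exact node names printed below, such as MulBackward0, can vary between PyTorch versions):
import torch

x = torch.tensor([2.0], requires_grad=True)
y = torch.tensor([3.0], requires_grad=True)
z = torch.tensor([4.0], requires_grad=True)

f = x * y * z
# Each non-leaf tensor remembers the operation that created it
print(f.grad_fn)                 # e.g. <MulBackward0 object at ...>
print(f.grad_fn.next_functions)  # the parent nodes in the backward graph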
Scaling to Complex Networks: Computational Blocks
Real neural networks are much more complex than our simple multiplication example. Let’s see how the same principles apply to complex computational blocks.
Two Computational Blocks
Figure 7: Computational Blocks
Consider a network with two computational blocks:
- Input x → Computational block h(x) → Intermediate output y
- Intermediate y → Computational block f(y) → Final output z
This represents the forward pass: x → h(x) → y → f(y) → z
Computing Gradients Through Blocks
Figure 8: Gradient at Intermediate Node y
To compute the gradient at the intermediate node y, we use backpropagation through the computational block f(y):
∇_y z
This gradient is stored at node y, as shown in Figure 8 with the green highlighting.
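In PyTorch, gradients of intermediate (non-leaf) tensors are not kept by default; if you want to see ∇_y z stored at the intermediate node, you can request it with retain_grad(). A small sketch, using the same h(x) = x² and f(y) = y + 1 blocks as the concrete example further below:
import torch

x = torch.tensor(2.0, requires_grad=True)
y = x ** 2        # intermediate node produced by block h
y.retain_grad()   # keep the gradient at this non-leaf node
z = y + 1         # output produced by block f
z.backward()

print(y.grad)     # ∇_y z = 1
print(x.grad)     # ∇_x z = ∇_y z × dy/dx = 1 × 2x = 4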
Backpropagating to the Input
Figure 9: Complete Backward Pass Through Blocks
To compute the gradient at the input x, we apply the chain rule:
∇_x z = ∇_y z × ∇_x y
Breaking this down:
- We already computed ∇_y z (gradient of z with respect to y)
- We compute ∇_x y by backpropagating through block h(x)
- We multiply these together to get ∇_x z
Figure 9 illustrates the complete backward pass, showing how gradients flow from the output back to the input through both computational blocks.
A Concrete PyTorch Example with Functions
Let’s implement this with actual functions to make it concrete.
import torch

# Define the computational blocks
def h(x):
    return x**2  # h(x) = x²

def f(y):
    return y + 1  # f(y) = y + 1

# Define tensor with gradient tracking
x = torch.tensor(2.0, requires_grad=True)

# Apply the computational blocks
y = h(x)  # Some function h
z = f(y)  # Some function f

# Compute gradients
z.backward()

# x.grad will contain ∇_x z
print("Gradient of z with respect to x:", x.grad)
Output:
> Gradient of z with respect to x: tensor(4.)
Let’s verify this manually:
- The complete function is: z = f(h(x)) = f(x²) = x² + 1
- Taking the derivative: dz/dx = 2x
- At x = 2: dz/dx = 2(2) = 4 ✓
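If you want PyTorch itself to double-check the analytic gradient against a numerical estimate, torch.autograd.gradcheck does exactly that (it expects double-precision inputs):
import torch

def composite(x):
    return x ** 2 + 1   # z = f(h(x))

x = torch.tensor([2.0], dtype=torch.double, requires_grad=True)
# Compares autograd's gradient with a finite-difference approximation
print(torch.autograd.gradcheck(composite, (x,)))   # True if they agree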
Scaling to Deep Networks
Figure 10: Multi-Layer Computational Graph
The real power of computational graphs becomes apparent when we scale to deep networks with many layers:
x → Block 1 → x₁ → Block 2 → x₂ → … → Block n → xₙ → Final Block → z
Forward Pass
During the forward pass, we compute the output by sequentially applying each computational block (see the sketch after this list):
- Start with input x
- Apply Block 1 to get x₁
- Apply Block 2 to get x₂
- Continue through all blocks
- Get final output z
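Here is a minimal sketch of this sequential forward pass, using three toy functions as stand-ins for Block 1 through Block n (the specific blocks are just placeholders for illustration):
import torch

# Hypothetical computational blocks
blocks = [lambda t: t ** 2, lambda t: t + 1, lambda t: 3 * t]

x = torch.tensor(2.0, requires_grad=True)
out = x
for block in blocks:   # apply Block 1, Block 2, ..., Block n in order
    out = block(out)
z = out                # z = 3 * (x² + 1) = 15 at x = 2

z.backward()
print(z.item(), x.grad.item())   # 15.0 and dz/dx = 6x = 12.0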
Backward Pass
For the backward pass, we compute gradients iteratively using the chain rule:
∇_x z = ∇_{xₙ} z × ∇_{xₙ₋₁} xₙ × … × ∇_{x₁} x₂ × ∇_x x₁
As shown in Figure 10, the key insight is that computing derivatives through a complex network always reduces to computing simple products due to the chain rule. Each computational block only needs to know how to compute its local gradient, and the framework handles the rest.
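To see the "simple products" claim concretely, here is the same toy chain of blocks from the sketch above, differentiated by hand as a product of local derivatives, with no autograd involved:
# Forward through the toy chain at x = 2
x = 2.0
x1 = x ** 2        # Block 1
x2 = x1 + 1        # Block 2
z = 3 * x2         # final block

# Backward: multiply the local derivatives from output back to input
grad = 1.0         # ∂z/∂z
grad *= 3.0        # local derivative of the final block: ∂z/∂x2
grad *= 1.0        # local derivative of Block 2: ∂x2/∂x1
grad *= 2 * x      # local derivative of Block 1: ∂x1/∂x
print(grad)        # 12.0, matching autograd's result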
Key Principles of Automatic Differentiation
From our exploration, we can extract several key principles:
1. Computational Graphs Enable Automatic Differentiation
By representing computations as graphs, frameworks can automatically compute derivatives without manual intervention. You simply define the forward pass, and the framework builds the computational graph automatically.
2. The Chain Rule is Applied Recursively
No matter how complex the network, gradient computation always breaks down into applications of the chain rule. Each operation only needs to know its local derivative, making the system modular and extensible.
3. Gradients are Computed Efficiently
The backward pass reuses intermediate computations from the forward pass, making gradient computation efficient. This is far more efficient than computing derivatives symbolically or using finite differences.
4. Any Differentiable Function Can Be Used
As long as you can define the derivative of basic operations (multiplication, addition, exponential, log, etc.), you can compute gradients through arbitrary combinations of these operations. This is why deep learning frameworks support such a wide variety of operations.
Practical Implications
Understanding computational graphs has several practical implications:
Memory Management
Computational graphs store intermediate values during the forward pass for use in the backward pass. For very deep networks or large batch sizes, this can consume significant memory. This is why techniques like gradient checkpointing exist — they trade computation for memory by recomputing some forward pass values during backpropagation.
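As a rough sketch of what checkpointing looks like in PyTorch (the block and tensor sizes here are made up for illustration, and the use_reentrant flag may behave differently across versions):
import torch
from torch.utils.checkpoint import checkpoint

# A hypothetical expensive block whose activations we don't want to store
def expensive_block(t):
    return torch.relu(t @ t.T).sum()

x = torch.randn(512, 512, requires_grad=True)
# Activations inside the block are recomputed during the backward pass
loss = checkpoint(expensive_block, x, use_reentrant=False)
loss.backward()
print(x.grad.shape)   # torch.Size([512, 512])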
Custom Operations
If you need to implement a custom operation, you need to define both:
- The forward pass (how to compute the output)
- The backward pass (how to compute the gradient)
The framework will handle everything else.
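In PyTorch this is done by subclassing torch.autograd.Function; a minimal sketch with a custom square operation:
import torch

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)    # stash what the backward pass needs
        return x ** 2

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * 2 * x  # chain rule: upstream × local derivative

x = torch.tensor(3.0, requires_grad=True)
y = Square.apply(x)
y.backward()
print(x.grad)   # tensor(6.)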
Debugging
When debugging gradient-related issues, thinking in terms of computational graphs can help you understand where gradients might be vanishing, exploding, or getting blocked.
Dynamic vs Static Graphs
PyTorch uses dynamic computational graphs (built on-the-fly during execution), while older versions of TensorFlow used static graphs (defined before execution). Dynamic graphs are more flexible and intuitive, which is one reason for PyTorch’s popularity.
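A small illustration of why dynamic graphs feel natural: ordinary Python control flow can change the graph from one call to the next, and autograd simply records whichever operations actually ran:
import torch

def forward(x):
    # The branch taken here decides which graph gets built on this call
    if x.sum() > 0:
        return (x ** 2).sum()
    return (x ** 3).sum()

x = torch.tensor([1.0, 2.0], requires_grad=True)
y = forward(x)   # the graph is recorded on-the-fly during execution
y.backward()
print(x.grad)    # tensor([2., 4.]) because the x ** 2 branch was taken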