Beyond Backpropagation - Higher Order, Forward and Reverse-mode Automatic Differentiation for Tensorken
Tensors from Scratch, Part 3
This post describes how I added automatic differentiation to Tensorken. Tensorken is my attempt to build a fully featured yet easy-to-understand and hackable implementation of a deep learning library in Rust. It takes inspiration from the likes of PyTorch, Tinygrad, and JAX.
Tensorken's approach to automatic differentiation (or AD) is heavily inspired by JAX. Like JAX, Tensorken supports higher-order derivatives - besides the first derivative, it can calculate the second, third, and so on. Tensorken supports both forward and reverse-mode AD, and can arbitrarily compose the two. Finally, thanks to good fundamentals explained in the previous two posts (part 1 and part 2), Tensorken can compute derivatives on the CPU or GPU.
All code for this post is in the Tensorken repository, tagged v0.3.
Previously in Tensors From Scratch: neural networks, matrix multiplication, and GPU acceleration
Modern neural networks, for example, large language models (LLMs) like OpenAI's ChatGPT and GPT-4, Microsoft's Bing, Google's Bard, and Anthropic's Claude, are powered by tensors. Tensors are multi-dimensional arrays augmented with operations that execute efficiently on modern hardware, most notably GPUs.
To understand all that, I am building a neural network library like PyTorch or JAX, from the ground up in Rust. These libraries consist of:
A tensor library, to provide efficient operations to slice, dice, and apply bulk operations to tensors.
Accelerators, to accelerate tensor operations on the GPU.
Automatic differentiation, to train neural networks via gradient descent.
Neural network building blocks, to simplify using common activation functions and layers.
In the first post, I focussed on the tensor library. I described almost twenty fundamental tensor operations and abstracted them in a Rust trait called RawTensor. RawTensor had a single implementation, CpuRawTensor, which executes tensor computations on the CPU. In the second post, I implemented RawTensor again in WgpuRawTensor to execute on the GPU using wgpu, Rust's implementation of WebGPU. We dove into the nitty-gritty of GPU programming in general and wgpu in particular.
This third part of the series describes how to add automatic differentiation to Tensorken. Automatic differentiation (AD) is a technique to compute derivatives of tensor computations, without programmer intervention. AD is crucial because neural networks are trained via gradient descent, which relies on the efficient calculation of derivatives.
How to train your neural network
Let's sketch how to train a neural network to emphasize how important AD is for deep learning.
First, gather training data, and lots of it. Training data are lots of input-expected output pairs. The input examples are encoded as numbers and aggregated in a tensor 𝚇. The outputs go in a tensor 𝚈. Think of 𝚈 as the correct predictions for the inputs 𝚇. For a language model, each example in 𝚇 could be a sequence of words, and 𝚈 the next word, encoded as numbers. (How to encode text as numbers is an interesting problem that's not relevant to this story.)
Second, decide on the architecture of your neural network. A neural network consists of tensors 𝚆ᵢ that contain the parameters of the network. That's what you download when you get a model's weights. The architecture determines how many parameters we have and how we combine the input 𝚇 with parameters Wᵢ to obtain an output 𝚈'. Whatever the architecture is, we can execute it to predict 𝚈':
𝚈' = 𝚏(𝚇, 𝚆ᵢ).
I'm simplifying - researchers distinguish weights 𝚆 and biases b, but in the end, they're both part of the trainable parameters so I'm just lumping them together in the 𝚆s.
Third, using the expected 𝚈 and the prediction 𝚈', calculate the loss 𝙻. The loss is a single number that is high when the prediction is bad, and low when it is good. The loss is calculated by comparing the network's prediction 𝚈' with the expected output 𝚈:
𝙻 = 𝚕(𝚈, 𝚈').
Fourth, calculate the gradient 𝙶 of the loss. The scalar loss value 𝙻 is a function of 𝚇, 𝚆ᵢ, and 𝚈. Imagine the loss function as describing a (highly dimensional!) landscape. Training the network to improve its predictions means changing the parameters Wᵢ to make the loss small. We'd like to know how we should change the parameters to achieve that.
Now is when the gradient comes in. Going back to the landscape analogy, to make the loss smaller we'd like to know the best direction to "move" in to go "down" - that is, from the current value of the parameters, find the direction with the highest slope. If you remember some calculus, the derivative of a function at a point is that slope. So, to calculate the gradient, we calculate the loss function's derivative with respect to each parameter 𝚆ᵢ. In other words, we'll have a number for each parameter that tells us how to change that parameter to make the loss smaller.
Fifth, update the parameters using the gradient. There are many ways of doing this. The simplest is to multiply the gradient with a small number ϵ and subtract it from the parameters:
Wᵢ <- Wᵢ − ϵ𝙶ᵢ.
That's one training step done! Your neural network just got a tiny bit better. Now repeat from step 2 until you've had enough. You can stop when the loss becomes small enough, when it stops changing for some number of iterations, or when your AWS bill exceeds the budget.
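To make the five steps concrete, here's a toy version in plain Rust: a single training example, a one-parameter model f(x, w) = w * x, a squared-error loss, and a gradient derived by hand. The numbers are made up for illustration - removing exactly that manual derivative work is what the rest of this post is about.
let (x, y) = (2.0_f32, 6.0_f32); // one training example: we want the model to map 2 to 6
let mut w = 0.5_f32; // step 2: a (not very) randomly initialized parameter
let eps = 0.1_f32; // the small number ϵ from step 5
for step in 0..20 {
    let y_pred = w * x; // run the network
    let loss = (y_pred - y).powi(2); // step 3: squared-error loss
    let grad = 2.0 * (y_pred - y) * x; // step 4: dL/dw, derived by hand for this toy model
    w -= eps * grad; // step 5: update the parameter
    println!("step {step}: loss {loss:.4}, w {w:.4}");
}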
Tensorken can already do almost all of those steps. Running a network, calculating a loss, and updating the parameters amounts to applying tensor operations. What's missing is calculating the gradient via the loss function's derivative. In the olden days, people would calculate the derivative of the network by hand, symbolically, and then implement it manually. Clearly tedious and error-prone, not to mention limiting the complexity and size of the networks. Modern neural network libraries calculate a function's output and its derivative without programmer intervention using the miracle of automatic differentiation.
AD is a vast and intricate topic. For a (much) longer primer on the basics, see my earlier post. If you are unfamiliar with AD I encourage you to read it or any of the AD primers in the links.
The following section demonstrates Tensorken's AD capabilities and interface via small examples. Then we'll dive into implementation details, but I'll stay away from the detailed mechanics of AD since that's already covered elsewhere. Instead, I'll focus on how Tensorken implements higher-order, mixed-mode, JAX-style AD as an elegant and minimal Rust library.
The Autodiff Cookbook in Tensorken
To demonstrate Tensorken's AD capabilities, I translated a significant part of JAX's Autodiff Cookbook to Tensorken. I reproduced and edited part of the original text here. JAX's license is Apache 2.0, so I hope this does not incur the wrath of Google. The titles in this section are similar to the ones in JAX's cookbook, in case you want to compare. The full example code is in jax_autodiff_cookbook.rs.
Before we begin - Tensorken runs on the CPU if you create tensors via the Cpu32 type alias and on the GPU via Wgpu32. (The 32 is because they work with 32-bit floating point numbers.) To make it easy to switch, I'll use the Tr type alias throughout:
type Tr = Cpu32; // or Wgpu32
Gradients
You can differentiate a function using grad1. The 1 indicates the number of arguments of the function - a poor man's variadic arguments. In the text, I'll sometimes refer to the family of grad1, grad2 functions as grad. In the code, I'll use the function with the correct number of arguments.
To start with, we'll use a simple scalar function - a function that takes a single number and returns a single number:
let p = Tr::scalar(2.0);
let df = grad1(|x| x.tanh(), &p);
> df: [ 0.07065082]
In Tensorken, all arguments must be a tensor Tr - it doesn't support mixing tensors and scalar numbers. To turn a number into a tensor, we first use Tr::scalar. It makes a tensor with shape [1].
grad1 takes a function of one argument 𝚏 and evaluates ∇𝚏(𝚙), the derivative of 𝚏 at a given point 𝚙. You can think of ∇ as a higher-order function that takes a differentiable function 𝚏 and produces a function that evaluates the derivative.
Pronouncing ∇: I say "grad", I've heard people say "del", and the symbol's Unicode name is "nabla".
Similarly, if you have a Rust function f that evaluates the mathematical function 𝚏, then grad(f, p) computes the value ∇𝚏(𝚙).
Unlike JAX, Tensorken does not directly expose ∇𝚏 as a first-class function, mostly because I had a hard time accomplishing that in Rust and staying sane! It required returning a closure from grad(f) so you can write grad(f)(p), but satisfying the compiler proved difficult. So far this hasn't been a constraint in practice.
Like JAX, Tensorken does support applying grad to functions that themselves call grad, to calculate higher-order derivatives:
let ddf = grad1(|x| grad1(|x| x.tanh(), x), &p);
let dddf = grad1(|x| grad1(|x| grad1(|x| x.tanh(), x), x), &p);
> ddf: [ -0.13621868]
> dddf: [ 0.25265408]
Let’s try computing gradients with grad in a linear logistic regression model. In other words, a simple neural network with one neuron. First, the setup:
// Outputs probability of a label being true.
fn predict<'t, T>(w: &'t T, b: &'t T, inputs: &T) -> T
where
T: TensorLike<'t>,
{
(inputs.dot(w) + b).sigmoid()
}
The function predict encodes the architecture of our toy model. Its parameters are a vector w and a scalar b, for weights and bias. As you can see, we're multiplying the weights with the inputs and adding the bias. Then we use sigmoid to squish the output values into the [0, 1] interval. This model predicts the probability of an outcome based on some input measurements.
Why is this equivalent to a neural network with a single neuron? Say the vector w has three elements - three weights. We thus have three inputs as well, in inputs. The function dot multiplies each input with its corresponding weight, and then adds them up. The bias b in the neuron analogy is typically a negative number, which represents a threshold that inputs.dot(w) must exceed to "activate" the neuron.
All arguments are tensors, but are represented by a generic argument T. The type needs to be generic so automatic differentiation can work. We'll see later why. T: TensorLike is a handy constraint to make tensor operations like dot, +, and sigmoid available on T. You'll see the TensorLike constraint often when using Tensorken's AD: to make functions differentiable, replace concrete Tensor types with T: TensorLike.
Let's run the model.
// Build a toy dataset.
// These are four measurements of some unspecified variable, one in each row.
let inputs = Tr::new(
&[4, 3],
&[
0.52, 1.12, 0.77, //
0.88, -1.08, 0.15, //
0.52, 0.06, -1.30, //
0.74, -2.49, 1.39,
],
);
// These are four observed outcomes, one for each row in the input.
let targets = Tr::new(&[4], &[1.0, 1.0, 0.0, 1.0]);
// Initialize the parameters w and b randomly
let key = 0;
let mut rng = StdRng::seed_from_u64(key);
let w = Tr::randn(&[3], &mut rng);
let b = Tr::randn(&[1], &mut rng);
let prediction = predict(&w, &b, &inputs);
> prediction: [ 0.4059896 0.37711427 0.9770815 0.007901279]
The inputs could be "changes in temperature observed on three consecutive days" and the targets could be "temperature went up or down on the next day". We're then training a model that predicts the probability of the temperature going up given three days' changes in temperature.
Since we initialized the model randomly, its prediction is random. We got unlucky: if you compare targets (what we want) with prediction (what we have) there is a big difference. The 3rd and 4th predictions are especially bad, almost the exact opposite of the training data.
To improve our model, we first need to quantify how crap the model is via a loss function.
// Training loss is the negative log-likelihood of the training examples.
fn loss<'t, T>(w: &'t T, b: &'t T, inputs: &T, targets: &'t T) -> T
where
T: TensorLike<'t>,
for<'s> &'s T: TensorLikeRef<T>,
{
let prediction = predict(w, b, inputs);
// ones_like makes a tensor of the same shape with all values equal to 1.
let label_probs = &prediction * targets
+ (&prediction.ones_like() - &prediction) * (targets.ones_like() - targets);
-label_probs.log().sum(&[0])
}
let l = loss(&w, &b, &inputs, &targets);
> loss: [ 10.4931755]
This loss function is the negative log-likelihood. You can intuit why it works: prediction is "compared" with targets in label_probs. It contains a high value for predictions that are close to the target. We then take the log of each, which exaggerates its value: the logarithm is -infinity when label_probs is zero. Since the logarithm is negative, we negate it to get a positive number. Then we take the sum of the vector so we have a single positive loss number that is high when the model is doing badly, and low when it's making good predictions.
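A quick worked example with made-up numbers: for a target of 1.0 and a prediction of 0.4, label_probs contains 0.4 × 1.0 + 0.6 × 0.0 = 0.4, contributing -log(0.4) ≈ 0.92 to the loss; for a target of 0.0 and a prediction of 0.98, it contains 0.02, contributing -log(0.02) ≈ 3.9. Confidently wrong predictions are punished much harder than mildly wrong ones.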
Now we can improve the model by adjusting its weights and biases. We use grad to differentiate the loss function with respect to the parameters w and b:
// Differentiate loss wrt weights
let w_grad = grad1(
|w| {
loss(
w,
&Reverse::lift(&b),
&Reverse::lift(&inputs),
&Reverse::lift(&targets),
)
},
&w,
);
print!("w_grad: {w_grad}");
// Differentiate loss wrt bias
let b_grad = grad1(
|b| {
loss(
&Reverse::lift(&w),
b,
&Reverse::lift(&inputs),
&Reverse::lift(&targets),
)
},
&b,
);
> w_grad: [ -1.0830948 2.5363755 -3.2000453]
> b_grad: [ -1.2319121]
To make the types work out, we need to Reverse::lift all the arguments to loss we do NOT want to differentiate. They are treated as constants. The type is called Reverse because Tensorken uses reverse mode AD in this case. The Reverse type reveals why we need to make the arguments to loss and predict generic: the grad function, while taking a plain Tr type as the second argument, passes Reverse<Tr> to the closure. So the function f can be called with Tr, Reverse<Tr>, or other types we'll see later.
Here's the simplified signature for grad1. We'll get to the full signature later:
pub fn grad1<F>(f: F, at: &Tr) -> Tr where F: Fn(&Reverse<Tr>) -> Reverse<Tr>
Briefly, Reverse is a wrapper to interpret tensor operations so they calculate the derivative along with the main result. In this example, it'll run a different dot, +, and sigmoid compared to calling loss with plain tensors of type Tr.
Calling the loss function twice is not ideal - we're doing twice the work. We can also calculate the gradients with respect to both w and b at the same time, using grad2.
let (w_grad, b_grad) = grad2(
|w, b| loss(w, b, &Reverse::lift(&inputs), &Reverse::lift(&targets)),
&w,
&b,
);
> w_grad: [ -1.0830948 2.5363755 -3.2000453]
> b_grad: [ -1.2319121]
Finally, let's do a single training iteration and check if that improves our model.
// Update parameters
let new_w = &w - &w_grad;
let new_b = &b - &b_grad;
// Predict
let new_prediction = predict(&new_w, &new_b, &inputs);
let new_loss = loss(&new_w, &new_b, &inputs, &targets);
> new_prediction: [ 0.7384342 0.99262685 0.7747804 0.9996524]
> new_loss: [ 1.8016509]
A massive improvement - we're now only 1.8 crap, down from 10.5!
Evaluate a function and its gradient using value_and_grad
In a real training run, we'd do the above in a loop while keeping an eye on the loss to see when to stop. Again loss is called twice: once inside grad and once outside. Luckily, we don't have to. Another convenient family of functions is value_and_grad, to efficiently compute a function and its gradient.
let (loss_value, (w_grad, b_grad)) = value_and_grad2(
|w, b| loss(w, b, &Reverse::lift(&inputs), &Reverse::lift(&targets)),
&w,
&b,
);
> loss: [ 10.4931755]
> w_grad: [ -1.0830948 2.5363755 -3.2000453]
> b_grad: [ -1.2319121]
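Putting that together, a minimal training loop could look like the sketch below. It's only illustrative - the learning rate, iteration count, and the exact borrow/ownership plumbing are made up here, not taken from Tensorken's examples.
let mut w = Tr::randn(&[3], &mut rng);
let mut b = Tr::randn(&[1], &mut rng);
let lr = Tr::scalar(0.1);
for epoch in 0..100 {
    // one forward pass for the loss, one backward pass for both gradients
    let (l, (w_grad, b_grad)) = value_and_grad2(
        |w, b| loss(w, b, &Reverse::lift(&inputs), &Reverse::lift(&targets)),
        &w,
        &b,
    );
    // gradient descent step: parameters move against the gradient
    w = &w - &(&w_grad * &lr);
    b = &b - &(&b_grad * &lr);
    println!("epoch {epoch}: loss {l}");
}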
Checking against numerical differences
Our loss improved, which is a good indication that things work. To gain confidence we can compare Tensorken's derivatives with finite differences.
// step size for finite difference
let eps = Tr::scalar(1e-4);
let half_eps = &eps / Tr::scalar(2.);
let b_grad_numerical = (loss(&w, &(&b + &half_eps), &inputs, &targets)
- loss(&w, &(&b - &half_eps), &inputs, &targets))
/ &eps;
> b_grad_numerical [ -1.2207031]
> b_grad_autodiff [ -1.2319121]
Close enough.
Jacobians using jacfwd and jacrev
Ignoring bias b for now, the loss function is a function of three parameters, represented as a single tensor w with three elements. It has a single scalar output, represented as a tensor with a single element. Taking the gradient of this function results in a vector of three elements, the sensitivity of the loss to each parameter. This picture becomes more complicated if there is more than one output. grad still gives an answer, but what does it mean?
let deriv = grad1(
|w| predict(w, &Reverse::lift(&b), &Reverse::lift(&inputs)),
&w,
);
> deriv: [ 0.34956074 -0.0017646346 0.20271438]
Remember that predict returns a vector with four elements, and the input w is a vector with three elements. We get a vector with three sensitivities - one for each input. But the sensitivity of which output? There are four. As we'll check below, grad returns the sum of the sensitivity of all outputs. That's typically not what we want: we'd like to disaggregate the sensitivities.
The usual approach is to represent the sensitivity of each output with respect to each input as a matrix, called the Jacobian. In this case, a 4 by 3 matrix - number of outputs by number of inputs. Tensorken can compute Jacobians in forward and reverse mode, using jacfwd and jacrev:
let J = jacfwd(
|w| predict(w, &Forward::lift(&b), &Forward::lift(&inputs)),
&w,
);
> jacfwd result, with shape [4, 3]
┌ ┐
│ 0.12540425 0.2701015 0.18569478 │
│ 0.20671119 -0.25369102 0.03523486 │
│ 0.01164451 0.0013435973 -0.029111274 │
│ 0.0058007482 -0.019518733 0.010895999 │
└ ┘
let J = jacrev(
|w| predict(w, &Reverse::lift(&b), &Reverse::lift(&inputs)),
&w,
);
> jacrev result, with shape [4, 3]
┌ ┐
│ 0.12540427 0.27010152 0.18569478 │
│ 0.20671119 -0.25369102 0.03523486 │
│ 0.01164451 0.0013435973 -0.029111274 │
│ 0.005800748 -0.019518731 0.010895998 │
└ ┘
These two functions compute the same values (up to machine precision), but differ in their implementation: jacfwd uses forward-mode automatic differentiation, which is more efficient for "tall" Jacobian matrices, while jacrev uses reverse mode, which is more efficient for "wide" Jacobian matrices. For matrices that are near-square, jacfwd probably has an edge over jacrev.
We can now check that grad computed the sum of the sensitivity of all four outputs:
&J.sum(&[0])
> [ 0.34956074 -0.0017646346 0.20271438]
Using a composition of jacfwd and jacrev gives us a way to compute dense Hessian matrices. Hessian matrices contain all the second derivatives.
let hessian = jacfwd(
|w| {
jacrev(
|w| {
predict(
w,
&Reverse::lift(&Forward::lift(&b)),
&Reverse::lift(&Forward::lift(&inputs)),
)
},
w,
)
},
&w,
);
println!("hessian with shape {:?}", hessian.shape());
> hessian shape [4, 3, 3]
Why this shape? We start with a function f:𝙽→𝙼. Traditionally, we'd write 𝚏:ℝⁿ→ℝᵐ, but that there are 𝙽 inputs and 𝙼 outputs is more important than that we're talking about real numbers, so I'll omit the ℝ from now on.
At a point 𝚡 ∈ 𝙽 we expect to get the shapes
𝚏(𝚡) ∈ 𝙼, the value of 𝚏 at 𝚡,
∂𝚏(𝚡) ∈ 𝙼 × 𝙽, the Jacobian matrix at 𝚡,
∂²𝚏(𝚡) ∈ 𝙼 × 𝙽 × 𝙽, the Hessian at 𝚡,
and so on.
To implement hessian we could have used jacfwd(jacrev(f)) or jacrev(jacfwd(f)) or any other composition of the two. But forward-over-reverse is typically the most efficient. That’s because, in the inner Jacobian computation, we’re often differentiating a function with a wide Jacobian (maybe like a loss function 𝚏: 𝙽→1), while in the outer Jacobian computation, we’re differentiating a function with a square Jacobian (since ∇𝚏: 𝙽→𝙽), which is where forward-mode wins out.
Note we now need to lift the inputs twice to make Rust's type checker happy.
That concludes the tour of Tensorken's AD capabilities. It packs a lot of punch - now let's see how to fit it in a small package.
A tale of two functions
All AD functions like jacfwd, jacrev, and grad are implemented in terms of two function-type pairs: jvp with Forward, and vjp with Reverse. JVP stands for Jacobian-vector product, and VJP stands for Vector-Jacobian product. These functions are directly inspired by JAX. To explain their names, we need some math background that deserves a standalone post. If you can't wait, refer to this section in JAX's Autodiff Cookbook.
I'll now introduce jvp and vjp, and the beginnings of how AD works in Tensorken. I assume some background knowledge about AD, in particular AD for scalar functions. See my earlier post for a primer.
From scalars to tensors
Forward AD on scalar functions works by replacing operators and functions on numbers with versions that operate on a dual number - a (f32, f32) tuple. The first element is the primal, which the function computes without AD. The second is the derivative, or tangent. Operations on dual numbers are straightforward:
apply the operation to the primal(s), and
apply differentiation rules to the tangent(s).
For example, multiplication on dual numbers is:
(p₁, t₁) · (p₂, t₂) = (p₁·p₂, p₁·t₂ + p₂·t₁)
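As a minimal illustration in plain Rust (a sketch of the idea, not Tensorken code), here's a scalar dual number with exactly that multiplication rule:
#[derive(Clone, Copy, Debug)]
struct Dual {
    primal: f32,
    tangent: f32,
}

impl Dual {
    fn mul(self, other: Dual) -> Dual {
        Dual {
            primal: self.primal * other.primal,
            // the product rule: d(a·b) = a·db + b·da
            tangent: self.primal * other.tangent + other.primal * self.tangent,
        }
    }
}

// d/dx (x·x) at x = 3: seed the tangent with 1.0 and read off 2x = 6.
let x = Dual { primal: 3.0, tangent: 1.0 };
let y = x.mul(x);
assert_eq!(y.primal, 9.0);
assert_eq!(y.tangent, 6.0);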
Reverse mode is more involved. The primal computation is identical, but instead of calculating the tangent alongside the primal, we collect a trace - essentially a stack of operations. A reverse pass through the trace calculates the derivatives.
Exactly how these operations are replaced is a concern for the implementation. Common methods are code transformation in the compiler, code generation, and operator overloading. Tensorken uses trait-based overloading.
Forward mode adds little memory overhead beyond bringing the tangent along for the ride, while reverse mode needs to keep a trace that's as long as the computation is deep. As a result, for scalar-to-scalar functions, forward mode is more efficient.
That situation changes if we consider functions from many scalars to one, or vice versa. One extreme is a function that takes a single input and computes n outputs. That's great for forward mode: in one execution of the function on dual numbers, we'll have both the primal result and the derivative - or in other words the sensitivity of each output to a small change in the single input.
However, a function that takes many inputs and has a single output is efficient only in reverse mode. In forward mode, we'd need as many executions of the function as there are inputs - we'd have to pass 1 as tangent for each input separately. In reverse mode, we still need the extra memory for the trace, but one forward pass for the primal and one backward pass for the partial derivatives is all we need.
The good news is that if you understand this, nothing much changes if we allow tensors instead of scalars. After all, a tensor is a container of scalars, and operations on tensors can be broken down into operations on scalars. That's not how we want to implement them though! Bulk operations are where the performance is at.
For forward mode, we'll overload tensor operations to propagate a "dual tensor", a tuple of a primal and a tangent tensor. For reverse mode, we'll build up a trace of tensor operations in the forward pass and get the tangent tensors from a backward pass.
One difference with scalar AD is that we need to take the shape of tensors into account. Besides arithmetic operations like addition and multiplication, we also need to figure out differentiation rules for sum, reshape, and others, which affect the shape of both primal and derivative tensors.
JVP for forward-mode AD
Here's the signature of jvp:
pub fn jvp1<T: Diffable + Clone, F>(f: F, at: &T, tangent: &T) -> (T, T)
where
for<'a> F: Fn(&'a Forward<T>) -> Forward<T>,
As the 1 suffix indicates, this version takes a single primal tensor at and a single tangent tensor tangent. It evaluates the primal and tangent of the function f and returns them as a tuple.
To understand why this signature makes sense, think of AD as a program transformation. Without AD we'd write programs that boil down to:
let p1 = f1(x);
let p2 = f2(p1);
...
With forward AD and jvp we can rewrite them as:
let (p1,t1) = jvp(f1, x, x.ones_like());
let (p2,t2) = jvp(f2, p1, t1);
...
That illustrates how programs that compose functions can be transformed into programs that compose jvp-wrapped functions.
Importantly, at and tangent must have the same shape, and the two tensors in the output tuple have the same shape. f computes out.0 from at, and jvp additionally computes out.1 from at and tangent.
jvp works for any tensor-like type that implements Diffable. Diffable is the foundational trait that defines Tensorken's primitive, differentiable tensor operations. Higher-level operations like matmul are built on these. Keeping the tensor generic in jvp allows it to work with any Diffable implementation - something we'll use when doing higher-order AD. We'll come back to the details soon.
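To connect jvp back to the earlier jacfwd example: conceptually, a Jacobian can be built one column at a time by seeding jvp with one-hot tangents. The sketch below only illustrates that idea, reusing predict, w, b, and inputs from the cookbook section; Tensorken's actual jacfwd may organize this differently.
let n = w.shape()[0];
let jacobian_columns: Vec<Tr> = (0..n)
    .map(|i| {
        // a one-hot tangent selects the sensitivity to input element i
        let mut one_hot = vec![0.0; n];
        one_hot[i] = 1.0;
        let tangent = Tr::new(&[n], &one_hot);
        let (_primal, column) = jvp1(
            |w| predict(w, &Forward::lift(&b), &Forward::lift(&inputs)),
            &w,
            &tangent,
        );
        column // the sensitivities of all four outputs to input i
    })
    .collect();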
VJP for backward AD
Here is the signature for vjp:
pub fn vjp1<'b, 't, T: Diffable + Clone + 't, F>(f: F, at: &T) -> (T, PullBack<'t, T>)
where
for<'a> F: Fn(&'a Reverse<'a, 't, T>) -> Reverse<'a, 't, T>,
It looks a bit different because reverse mode has a backward pass. What's the same are the differentiable function f: F and the primal input at. vjp calls f with Reverse wrapping the input T. Since reverse mode needs two passes, vjp only returns the primal directly.
PullBack (a term from differential geometry) is a named struct that executes the backward pass. It takes a cotangent, a tensor in the shape of the output of f, and calculates the tangent, a tensor in the shape of the input of f.
impl<T: Diffable + Clone> PullBack<'_, T> {
pub fn call(&self, cotangent: &T) -> T
}
It's all backward! But that's why reverse mode AD is more efficient if you have the right tensor shape.
A short note on why jvp and vjp have different signatures. On the one hand, we could re-write jvp to return a PushForward struct with a call function that works similarly to vjp's PullBack. However, that would require keeping a trace of the operations around so users can call multiple times with different tangents. That jeopardizes the memory efficiency of forward mode. The ability to re-execute the differentiating pass with different tangents does not offset the added memory usage.
We could also write vjp with a signature like jvp, by making the PullBack internal and calling it at the end. In reverse mode, we have to expend the memory anyway, so we might as well make it available to the user for potential reuse.
Interpreters all the way down
We're now at the point where we can dive into the code, and it's interpreters all the way down.
Before AD, Tensorken's core was the RawTensor trait, with implementations for the CPU and the GPU. It's useful to think of this trait as the definition of a language for primitive tensor operations, and implementations of the trait as interpreters of that language. Interpreters don't necessarily have to produce a tensor - for debugging and testing, a pretty-printing interpreter for RawTensor is useful:
impl RawTensor for String {
type Elem = f32;
fn exp(&self) -> Self {
format!("{self}.exp()")
}
fn add(&self, other: &Self) -> Self {
format!("({self} + {other})")
}
// etc
}
We can use it as follows:
let t1: String = RawTensor::new(&[2, 2], &[1., 2., 3., 4.]);
let t2: String = RawTensor::new(&[2, 2], &[5., 6., 7., 8.]);
let r = t1.exp().add(&t2.log());
> r: "(new([2, 2], [1.0, 2.0, 3.0, 4.0]).exp() + new([2, 2], [5.0, 6.0, 7.0, 8.0]).log())"
Or even:
let t1: String = "A".to_string();
let t2: String = "B".to_string();
let r = t1.exp().add(&t2.log());
> r: "(A.exp() + B.log())"
We could generate source code or an abstract syntax tree this way, turning the interpreter into a compiler of sorts. That is the essence of the final tagless approach I described in depth in an earlier post. It has many extensibility advantages, which we'll take advantage of soon.
What does this have to do with automatic differentiation? AD is achieved by hard-coding how to differentiate primitive operations like addition and multiplication, and composing those primitive rules via the chain rule. The primitive operations define a language of tensor operations which we can interpret in a few ways - in particular, as straightforward tensor operations without differentiation via Tensor, as a forward mode differentiated program via Forward, or as a reverse mode differentiated program via Reverse. As for RawTensor, we represent the primitive operations of the differentiable language as a trait, Diffable, and then implement this trait for each interpreter.
Let's start with the trait definition:
pub trait Diffable {
type Elem: Num;
fn log(&self) -> Self;
fn exp(&self) -> Self;
fn elementwise_add(&self, other: &Self) -> Self;
fn elementwise_sub(&self, other: &Self) -> Self;
fn elementwise_mul(&self, other: &Self) -> Self;
fn elementwise_div(&self, other: &Self) -> Self;
fn elementwise_pow(&self, exp: &Self) -> Self;
fn elementwise_eq(&self, other: &Self) -> Self;
fn sum(&self, axes: &[usize]) -> Self;
fn max(&self, axes: &[usize]) -> Self;
fn reshape(&self, shape: &[usize]) -> Self;
fn permute(&self, dims: &[usize]) -> Self;
fn expand(&self, shape: &[usize]) -> Self;
fn pad(&self, padding: &[(usize, usize)]) -> Self;
fn crop(&self, limits: &[(usize, usize)]) -> Self;
fn new(shape: &[usize], data: &[Self::Elem]) -> Self;
fn shape(&self) -> &[usize];
}
Diffable's operations are similar to RawTensor's, and we can categorize them in much the same way - unary operations, binary operations, reduce-like operations, and shape-changing operations. Missing is the optimized fused multiply-add in RawTensor, which illustrates the difference in intent between RawTensor and Diffable. While we could make RawTensor differentiable, I'll now try to convince you we don't want to.
Fused multiply-add is an optimized operation that we need on the lowest level to have some hope of efficiency. It is likely that to make Tensorken more efficient, we'll need to add more special-purpose operations to better exploit hardware primitives, reduce memory usage, and so on.
We don't (necessarily) want to figure out how to differentiate those special-purpose operations - we'd like a small set of primitive operations, define their derivatives, and then compose those into higher-level operations. We then get derivatives of those higher-level operations for free, because differentiation is so beautifully composable. Separating Diffable from RawTensor allows us to add efficient, special-purpose operations to RawTensor without figuring out their derivatives. Vice versa, we can add operations to Diffable without having to change RawTensor and its implementations.
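As a small illustration, a higher-level operation composed purely from Diffable primitives could look like this sketch (squared_error is a made-up example, not an operation from Tensorken); its derivative then comes for free in every AD mode:
fn squared_error<T: Diffable>(prediction: &T, target: &T) -> T {
    let diff = prediction.elementwise_sub(target);
    // sum of squared differences over the first axis
    diff.elementwise_mul(&diff).sum(&[0])
}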
Before Diffable, we translated user-facing operations like matrix multiplication to RawTensor operations, which were interpreted by a concrete RawTensor like CpuRawTensor. Now we add another interpreter, Diffable, between the user-facing operations and RawTensor, which not only calculates the primal results but also derivatives. Diffable interpreters execute both primal and derivative calculations as RawTensor operations. That means we can combine all implementations of Diffable with all implementations of RawTensor. So we can do forward AD on the GPU, reverse AD on the CPU, or any other combination.
Let's make our way down the interpreter layers to see how this works in practice. We'll start with matrix multiplication and end up at CpuRawTensor.
Each of the sections that follow is one layer of the interpreter lasagne:
1. High-level tensor operations like matmul are translated to Diffable operations like sum and elementwise_mul.
2. A Diffable interpreter like Forward and Reverse translates primitive operations like sum and elementwise_mul to RawTensor operations, adding calculation of derivatives.
3. A RawTensor interpreter like CpuRawTensor executes the operations on a particular device.
User-facing layer: matrix multiplication in terms of Diffable
Here is a sketch of matmul, omitting everything that is not an operation on Diffable:
pub trait DiffableExt: Diffable
{
fn matmul(&self, other: &Self) -> Self {
// preconditions, shape manipulation omitted
// special cases omitted
let l = self.reshape(&l_shape);
// shape manipulation omitted
let r = other
.reshape(&r_shape)
.transpose(r_shape.ndims() - 1, r_shape.ndims() - 2);
// after multiply: [..., m, o, n]
let prod = l.mul(&r);
// after sum: [..., m, o, 1]
let sum = prod.sum(&[prod.shape().ndims() - 1]);
// after reshape: [..., m, o]
let s = sum.shape();
sum.reshape(&s[..s.ndims() - 1])
}
}
Tensorken has three implementations of Diffable: Tensor, Forward, and Reverse. Tensor doesn't do any differentiation at all - it translates Diffable to RawTensor operations. Forward and Reverse augment the operations with their respective mode of AD. We'll come back to these later - first, we need to find a Rust vehicle for the user-facing operations that are not in Diffable. We could re-implement them on each implementation of Diffable, but that is redundant. Instead, I've defined DiffableExt, a sub-trait of Diffable with a blanket implementation:
pub trait DiffableExt: Diffable
{
// all the fns we want, like matmul, go here.
// They'll need to be defined in terms of Diffable,
// because that's all that's available.
fn matmul(&self, other: &Self) -> Self { ... }
}
impl<T: Diffable> DiffableExt for T {}
The advantage is we only have to implement Diffable on a concrete type; then anything defined on DiffableExt is available too (as long as DiffableExt is in scope).
The first Diffable implementation: Tensor
We now need a concrete type to present to users. Tensor is that type. Its definition is mysteriously simple:
pub struct Tensor<T>(T);
The idea is that the generic type argument T is a Diffable. Why not add the type constraint here? Because it's unnecessary - for all interesting implementations, T is Diffable. Constraining T here adds nothing new.
We can now make Tensor<T> implement Diffable for any T that's Diffable:
impl<T: Diffable> Diffable for Tensor<T> {
type Elem = T::Elem;
fn log(&self) -> Self {
Tensor(self.0.log())
}
// etc
}
All operations delegate to T. Full implementation here.
From Diffable to RawTensor
That gets us nowhere - we can have a differentiable Tensor<T> if we have a differentiable T. To execute tensor operations we need to get to a RawTensor. We can do that by interpreting Diffable operations as RawTensor operations. In Rust, this means creating a blanket implementation of Diffable for any RawTensor:
impl<T: Num, TTensor: RawTensor<Elem = T>> Diffable for TTensor {
type Elem = T;
fn log(&self) -> Self {
self.log()
}
// etc
}
Since Diffable is a subset of RawTensor, the implementation is again straightforward. A type like Tensor<CpuRawTensor> now works, and we can apply all operations in Diffable and DiffableExt to it.
It seems like we went around in a big circle. After Tensorken parts 1 and 2, we had a Tensor<T: RawTensor> with high-level operations like matmul and primitive operations on RawTensor. Now we have Tensor<T: Diffable> with high-level operations like matmul moved to DiffableExt, differentiable primitive operations on Diffable, and primitive executable operations still on RawTensor.
What we gained is the ability to have other Diffable implementations. We're going to use that ability now.
Forward-mode AD with Forward
The Forward type wraps T with extra stuff so we can transform and trace the computation to calculate the derivative. In interpreter terms, Forward is an interpreter for the Diffable language that calculates the derivative alongside the primal result using forward-mode AD. It does that by applying all tensor operations on a dual tensor.
The Forward type:
pub enum Forward<T> {
Lift(T),
Forward(T, T),
}
Like for Tensor<T>, the T here is a Diffable tensor. The Forward case should make sense - it's the primal and the tangent tensors. We use the Lift case if we're not interested in computing the derivative of a tensor. Lifted tensors are treated as constants for the derivative computation. Another way of saying this is that their derivative is zero. We avoid many multiplications with zero by having a dedicated case instead of using the functionally equivalent Forward(t, zero).
We can understand jvp1's implementation now:
pub fn jvp1<T: Diffable + Clone, F>(f: F, at: &T, tangent: &T) -> (T, T)
where
for<'a> F: Fn(&'a Forward<T>) -> Forward<T>,
{
let forward = Forward::Forward(at.clone(), tangent.clone());
let result = f(&forward);
match result {
Forward::Lift(p) => (p.clone(), p.zeros_like()),
Forward::Forward(p, t) => (p, t),
}
}
We wrap the at and tangent arguments in Forward, then call f with them and unwrap the Forward from the result.
Forward must implement Diffable for this to work. Finally, we come to the implementation of differentiation rules for the primitive operations:
impl<T: Clone + Diffable> Diffable for Forward<T> {
type Elem = T::Elem;
fn elementwise_mul(&self, rhs: &Self) -> Self {
self.binary::<MulOp<T>>(rhs)
}
fn sum(&self, axes: &[usize]) -> Self {
self.unary::<SumOp, _>(axes)
}
// etc
}
Full implementation here.
Calculating the primal and derivatives is encapsulated in Op structs. The unary and binary functions deal with handling the Lift or Forward enum cases in one place, and delegate to a given Op struct for the calculation:
impl<T: Diffable> Forward<T> {
fn unary<Op: UnaryOp<T, Args = TArgs> + UnaryDiffOp<T>, TArgs: ?Sized>(
&self,
args: &TArgs,
) -> Self {
let (primal, op) = Op::f(self.primal(), args);
match self {
Forward::Lift(_) => Forward::Lift(primal),
Forward::Forward(_, tan) => Self::Forward(primal, op.dfda(tan)),
}
}
}
binary is similar but more involved because it has 4 combinations of Lift and Forward.
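For the curious, here's a hedged sketch of what binary could look like - not necessarily the exact code in the repository - combining the tangent contributions of both arguments via dfda and dfdb:
impl<T: Diffable + Clone> Forward<T> {
    fn binary<Op: BinaryOp<T> + BinaryDiffOp<T>>(&self, rhs: &Self) -> Self {
        let (primal, op) = Op::f(self.primal(), rhs.primal());
        match (self, rhs) {
            // both arguments are constants: no tangent to propagate
            (Forward::Lift(_), Forward::Lift(_)) => Forward::Lift(primal),
            // only one side carries a tangent
            (Forward::Forward(_, tan), Forward::Lift(_)) => Self::Forward(primal, op.dfda(tan)),
            (Forward::Lift(_), Forward::Forward(_, tan)) => Self::Forward(primal, op.dfdb(tan)),
            // both sides carry a tangent: sum the two contributions
            (Forward::Forward(_, tl), Forward::Forward(_, tr)) => {
                Self::Forward(primal, op.dfda(tl).elementwise_add(&op.dfdb(tr)))
            }
        }
    }
}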
Here's MulOp:
pub(crate) struct MulOp<TTensor>(TTensor, TTensor);
impl<TTensor: Clone + Diffable> BinaryOp<TTensor> for MulOp<TTensor> {
fn f(a: &TTensor, b: &TTensor) -> (TTensor, Self) {
(a.elementwise_mul(b), MulOp(a.clone(), b.clone()))
}
}
impl<TTensor: Diffable> BinaryDiffOp<TTensor> for MulOp<TTensor> {
fn dfda(&self, d: &TTensor) -> TTensor {
d.elementwise_mul(&self.1) // da * b
}
fn dfdb(&self, d: &TTensor) -> TTensor {
d.elementwise_mul(&self.0) // db * a
}
}
Differentiation rules often capture intermediate results or arguments of the primal computation. So f returns not only the result of the primal computation but also a struct to store whatever data is needed for the derivative computation. For MulOp, it captures the input tensors a and b.
dfda and dfdb define how to compute the derivative with respect to the first and second argument, given d, the derivative of downstream functions. The differentiation rule for elementwise tensor multiplication is essentially the same as for scalar multiplication.
Unary operations are similar but don't define dfdb:
pub(crate) struct SumOp(Vec<usize>);
impl<TTensor: Diffable> UnaryOp<TTensor> for SumOp {
type Args = [usize];
fn f(a: &TTensor, axes: &Self::Args) -> (TTensor, Self) {
let r = a.sum(axes);
(r, SumOp(axes.to_vec()))
}
}
impl<TTensor: Diffable> UnaryDiffOp<TTensor> for SumOp {
fn dfda(&self, d: &TTensor) -> TTensor {
d.sum(&self.0)
}
}
SumOp only needs the reduced axes from the primal computation to calculate dfda. The derivative of the sum is the sum of the derivatives, so we can apply the same sum to primal and tangent.
You can find all the ops here and here.
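As one more example of the pattern, an op for exp could look like the sketch below (illustrative only - the actual op lives in the files linked above, and may be structured differently). Note how it captures the primal result, since d(exp(a)) = exp(a)·da:
pub(crate) struct ExpOp<TTensor>(TTensor);

impl<TTensor: Clone + Diffable> UnaryOp<TTensor> for ExpOp<TTensor> {
    type Args = ();
    fn f(a: &TTensor, _: &Self::Args) -> (TTensor, Self) {
        let r = a.exp();
        // keep the primal result around for the derivative
        (r.clone(), ExpOp(r))
    }
}

impl<TTensor: Diffable> UnaryDiffOp<TTensor> for ExpOp<TTensor> {
    fn dfda(&self, d: &TTensor) -> TTensor {
        d.elementwise_mul(&self.0)
    }
}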
Forward<Forward<T>> for higher order derivatives
Reiterating this signature:
pub fn jvp1<T: Diffable + Clone, F>(f: F, at: &T, tangent: &T) -> (T, T)
where
for<'a> F: Fn(&'a Forward<T>) -> Forward<T>
Since the only requirement on T is that it's Diffable, and Forward is Diffable, besides a Tensor<T> we can pass a Forward<T> to jvp1 to calculate higher-order derivatives.
let p: Tensor<CpuRawTensor<f32>> = Tr::scalar(2.0);
let ddf = diff1(|x: &Forward<Tensor<_>>|
diff1(|x: &Forward<Forward<Tensor<_>>>| x.tanh(), x),
&p
);
As we'll see soon, this same design will allow us to combine forward and reverse modes up to arbitrary depth, by building up types like Forward<Reverse<Tensor<..>>>.
This scheme works because the Diffable operations are implemented in terms of Diffable. That looks circular, but it's not: it's a stack of interpreters with a Tensor at the bottom, which translates Diffable operations to RawTensor operations:
Forward<Forward<...>>: Diffable -> ... -> Tensor: Diffable -> RawTensor
Somewhat surprisingly, differentiating a differentiated program gets us the second derivative. One way to make sense of that is that if you compute second or third derivatives symbolically, that's exactly what you do: you apply the differentiation rules multiple times. If you want to do a "fun" exercise, you can work out that stacking Forward types amounts to operating on duals-of-duals up to the desired order. If you work out the differentiation rules by hand, you'll find that it yields the correct higher-order derivative.
Reverse-mode AD with Reverse
The implementation of Diffable for Reverse follows the same pattern as Forward but is more involved. Because reverse mode accumulates the derivative in a separate backward pass, we can no longer compute everything on the fly when we compute the primal. Instead, Reverse builds a trace of operations in a forward pass while calculating the primal result, then accumulates derivatives in the backward pass.
The difference with forward mode is visible in the signature of vjp:
pub fn vjp1<'b, 't, T: Diffable + Clone + 't, F>(f: F, at: &T) -> (T, PullBack<'t, T>)
where
for<'a> F: Fn(&'a Reverse<'a, 't, T>) -> Reverse<'a, 't, T>
Like jvp, it returns the primal result. Unlike jvp, it doesn't return the tangent, but instead a PullBack struct. The only available operation on that is call:
pub fn call(&self, cotangent: &T) -> Vec<T>
where
T: Diffable + Clone,
This takes a cotangent tensor - in other words, a tensor with the same shape as the result of f - and returns the tangents of all the arguments of f. Here's how vjp is used:
pub fn value_and_gradn<'t, T: Diffable + Clone + 't, F>(f: F, at: &[&T]) -> (T, Vec<T>)
where
for<'a> F: Fn(&'a [Reverse<'a, 't, T>]) -> Reverse<'a, 't, T>,
{
// one forward pass, tracing
let (primal, pullback) = vjpn(f, at);
// one backward pass, accumulating derivatives
let tangents = pullback.call(&primal.ones_like());
// but we get multiple tangents in one go
(primal, tangents)
}
Other implementations for grad functions follow a similar pattern.
The details of how this is implemented (via a Trace type) are explained in my post on AD, so I won't repeat them here. It is not substantially different from the scalar case. Briefly, here is the Reverse type:
pub enum Reverse<'a, 't, T> {
Lift(T),
Reverse(&'a Trace<'t, T>, T, usize),
}
Like Forward, it has a Lift case for tensors we don't want to differentiate. The Reverse case contains the primal T, and some administrative data to record the trace and do the backward pass.
The implementation of Diffable looks similar to Forward:
impl<T: Clone + Diffable> Diffable for Reverse<'_, '_, T> {
type Elem = T::Elem;
fn elementwise_mul(&self, rhs: &Self) -> Self {
self.binary::<MulOp<T>>(rhs)
}
fn sum(&self, axes: &[usize]) -> Self {
self.unary::<SumOp, _>(axes)
}
// other omitted
}
Again we have unary and binary helper methods to deal with Lift and call the appropriate functions on the Op structs.
Interestingly, even though reverse mode calculates derivatives backward, from the output to the input, MulOp is identical for forward and reverse mode. This is true for all elementwise operations.
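To see why, consider what an elementwise add op might look like in this style (a sketch, not necessarily the repository's code): the derivative rule just passes d through unchanged, which reads the same whether d is a tangent flowing forward or a cotangent flowing backward.
pub(crate) struct AddOp;

impl<TTensor: Clone + Diffable> BinaryOp<TTensor> for AddOp {
    fn f(a: &TTensor, b: &TTensor) -> (TTensor, Self) {
        // nothing needs to be captured: the derivative of a + b doesn't depend on a or b
        (a.elementwise_add(b), AddOp)
    }
}

impl<TTensor: Diffable + Clone> BinaryDiffOp<TTensor> for AddOp {
    fn dfda(&self, d: &TTensor) -> TTensor {
        d.clone()
    }
    fn dfdb(&self, d: &TTensor) -> TTensor {
        d.clone()
    }
}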
sum however is different from forward mode. In the backward pass, we get a d in the shape of the result of the sum (i.e. with fewer elements) and we need to produce a tensor in the shape of the input of sum. To do that, we need expand:
pub(crate) struct SumOp(Vec<usize>);
impl<TTensor: Diffable> UnaryOp<TTensor> for SumOp {
type Args = [usize];
fn f(a: &TTensor, axes: &Self::Args) -> (TTensor, Self) {
let r = a.sum(axes);
(r, SumOp(a.shape().to_vec()))
}
}
impl<TTensor: Diffable> UnaryDiffOp<TTensor> for SumOp {
fn dfda(&self, d: &TTensor) -> TTensor {
d.expand(&self.0)
}
}
Full implementation for reverse mode is in ad_reverse.rs and the reverse operations are in ad_ops_reverse.rs.
After all that, we can run all the examples in the demo section. However, there is one remaining issue.
Un-blowing up matmul, again
The problem is serious. Repeating the (pseudo-code) implementation of matmul in DiffableExt:
pub trait DiffableExt: Diffable
{
fn matmul(&self, other: &Self) -> Self {
// preconditions, shape manipulation omitted
// special cases omitted
let l = self.reshape(&l_shape);
// shape manipulation omitted
let r = other
.reshape(&r_shape)
.transpose(r_shape.ndims() - 1, r_shape.ndims() - 2);
// TROUBLE BEGINS HERE
// after multiply: [..., m, o, n]
let prod = l.mul(&r);
// after sum: [..., m, o, 1]
let sum = prod.sum(&[prod.shape().ndims() - 1]);
// after reshape: [..., m, o]
let s = sum.shape();
sum.reshape(&s[..s.ndims() - 1])
}
}
See that mul followed by sum? In the second part of Tensors from Scratch, I explained that this blows up memory, to the point where this approach is utterly unscalable. The fused multiply-add function in RawTensor came to the rescue - we rewrote the separate sum and mul calls into one l.fused_multiply_add(&r, dims), which made it efficient. Now we've regressed to the previous bad situation. What gives?
First, Diffable doesn't have fused_multiply_add, so we can't write the optimized version directly. We could add fused_multiply_add to Diffable as a primitive operation, but then we have to define a differentiation rule for it in the various modes. One of the main reasons for Diffable's existence is to avoid that.
Second, while manually fusing mul and sum worked for this particular case, users may inadvertently write a mul followed by a sum, and fall into this trap themselves. Worse, while we're calculating derivatives by composing operations in forward or reverse mode, Tensorken itself may introduce a mul followed by a sum. Manually fusing all cases is not going to work. We need a better solution.
If we were writing a compiler, it'd be straightforward to go through the abstract syntax tree of tensor operations and transform any l.mul(r).sum(axes) into l.fused_multiply_add(r, axes). Can we do a similar optimization here?
Let's think about what's happening from the perspective of interpreters. We have defined a language for writing differentiable programs using the trait Diffable. Everything we do with tensors - matmul, crop, max, sigmoid, as well as getting derivatives - is eventually a program in terms of the operations on Diffable. We have three interpreters for Diffable - one that translates the differentiable program to RawTensor operations, and two that augment the differentiable program with forward or reverse mode AD. No matter how many times we stack Diffable on top of Diffable, eventually the program gets run via a RawTensor interpreter.
We only have concrete RawTensor interpreters so far - interpreters that calculate the results on CPU or GPU, or that print a string representing the result. But we can also write an interpreter that spits out a new, optimized RawTensor interpreter, with all mul + sum fused into fused_multiply_add.
This technique - which I didn't invent at all, to be clear - is introduced more gradually and gracefully in my post on typed tagless final interpreters. Here I'll give a whirlwind tour of the implementation.
We'll use a type called Fuse<T>. T is the target optimized RawTensor. Whenever mul followed by sum is detected in the unoptimized, original RawTensor, Fuse rewrites the two operations to a fused equivalent.
enum FuseCtx {
Sum(Vec<usize>),
NotSum,
}
pub struct Fuse<T>(Rc<dyn Fn(&FuseCtx) -> T>);
The function from FuseCtx to the fused T: RawTensor is a factory function we'll build up while interpreting the original RawTensor as Fuse<T>. In other words, Fuse<T> interprets RawTensor as a function that, given a FuseCtx, produces an optimized RawTensor. It works in two passes: a first pass builds up the factory function, then a second pass runs the function to get a new RawTensor.
Since Fuse only needs to fuse multiply and sum operations, it delays the application of sum, and instead passes Sum(axes) to the continuation via FuseCtx. The continuation calls the delayed sum if it can't fuse, or fused_multiply_add if it can. Here's the implementation of mul, where fusing happens:
impl<TRaw: RawTensor + Clone + 'static> RawTensor for Fuse<TRaw> {
type Elem = TRaw::Elem;
fn mul(&self, other: &Self) -> Self {
let f_lhs = Rc::clone(&self.0);
let f_rhs = Rc::clone(&other.0);
let nextctx = FuseCtx::NotSum;
Fuse::new(move |ctx| match ctx {
FuseCtx::Sum(axes) => f_lhs(&nextctx).fused_multiply_add(&f_rhs(&nextctx), axes),
FuseCtx::NotSum => f_lhs(&nextctx).mul(&f_rhs(&nextctx)),
})
}
}
The context passed in the closure represents what the next operation is, from the perspective of the current operation. If it's sum, the Sum enum case, we fuse. If it's anything else, represented by NotSum, we know the operation has already been applied and we can't fuse. Since mul is not a sum, we pass NotSum as the next context.
Here is the implementation of sum:
fn sum(&self, axes: &[usize]) -> Self {
let f = Rc::clone(&self.0);
let my_axes = axes.to_vec();
Fuse::new(move |ctx| match ctx {
FuseCtx::Sum(sum_axes) => f(&FuseCtx::Sum(combine_axes(&my_axes, sum_axes))),
FuseCtx::NotSum => f(&FuseCtx::Sum(my_axes.clone())),
})
}
We do not apply sum straight away to the resulting interpreter. Instead, we pass Sum through to the next operation, so it gets a chance to fuse it. Any operations that don't fuse need to apply the delayed sum if they get the Sum enum. We might as well fuse consecutive sum calls into one by combining axes, hence the first match arm.
Fusing happens in two passes: the first pass builds the FuseCtx -> RawTensor function. The second pass creates the optimized RawTensor by calling the function:
impl<T> Fuse<T> {
fn run(&self) -> T {
(self.0)(&FuseCtx::NotSum)
}
}
Link to full implementation of fusing.
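One way to see fusing in action - a sketch, assuming the pretty-printing String interpreter from earlier also implements mul, sum, and fused_multiply_add - is to stack Fuse on top of it:
let t1: Fuse<String> = RawTensor::new(&[2, 2], &[1., 2., 3., 4.]);
let t2: Fuse<String> = RawTensor::new(&[2, 2], &[5., 6., 7., 8.]);
let fused = t1.mul(&t2).sum(&[1]).run();
// fused should print as a single fused_multiply_add call,
// rather than a separate mul followed by a sum.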
Now I can finally reveal the full Cpu32 and Wgpu32 types:
pub type Cpu32 = Tensor<ShapeTracker<Fuse<CpuRawTensor<f32>>>>;
pub type Wgpu32<'d> = Tensor<ShapeTracker<Fuse<WgpuRawTensor<'d, f32>>>>;
The remaining unknown there is ShapeTracker. ShapeTracker is a RawTensor implementation that abstractly interprets the operations by only tracking tensor shapes. It delegates all operations to the RawTensor it wraps, except shape:
pub struct ShapeTracker<T>(ShapeStrider, T);
/// This implementation passes every operation through
/// to self.1, except for shape.
impl<TRaw: RawTensor> RawTensor for ShapeTracker<TRaw> {
type Elem = TRaw::Elem;
fn exp(&self) -> Self {
Self(self.0.clone(), self.1.exp())
}
// etc
fn shape(&self) -> &[usize] {
self.0.shape()
}
}
Because Fuse does not track shapes but does need to implement RawTensor::shape, it can only return its shape by running the delayed computation. We don't want that - some derivative operations require access to the shape of tensors, and it would be bad if we had to run the tensor computation at that point.
ShapeTracker solves this for us - it can answer shape queries without executing tensor operations. The order is important here: ShapeTracker needs to wrap Fuse, which needs to wrap the concrete CpuRawTensor or WgpuRawTensor.
I love it when a plan comes together
Thanks to the power of interpreters aka final tagless encoding, Tensorken gained a capable yet small and extensible AD implementation. So far, I'm really happy with how Tensorken turned out. I started seriously researching deep learning from an implementation perspective at the beginning of 2023 with only some prior exposure to automatic differentiation. I randomly ran into the typed tagless final interpreters paper while I was studying TinyGrad, and figured that TinyGrad's style would lend itself well to the tagless final style. I could not have hoped for a better outcome.
After that, I saw a post on Reddit praising JAX and immediately preferred the functional style over PyTorch's imperative AD interface. It was much more challenging to implement in Rust though! Those signatures look straightforward now, but it took a lot of struggling with closures, lifetimes and lifetimes and closures and then lifetimes some more before everything came together. All this to say - I got lucky trying an implementation style I hadn't ever used and struggled for a long time. When AD finally worked, it felt almost magical. Persistence is worth some IQ points.
Now that Tensorken has all the pieces of a full-fledged deep learning library, it's time to put it to the test. I intend to follow along with Andrej Karpathy's Zero to Hero neural networks course, translating it from PyTorch to Tensorken. At the end of that, we should have a home-grown, walking and talking nanoGPT. Without a doubt, there'll be many interesting problems in Tensorken itself to solve along the way.
Many thanks for reading!
References
JAX: The Autodiff Cookbook. I adapted some parts for this post.
JAX: Autodidax - JAX core from scratch. Great insight into how JAX works under the hood, using a small implementation of JAX.
Video: Automatic Differentiation by Matthew Johnson. Matthew Johnson is one of the authors of JAX. Here he talks about the principles underlying Autograd, another Python-based AD library.
TinyGrad. A small but powerful tensor library with reverse-mode AD in Python.
Swift's Differentiable Programming Manifesto. Swift has a powerful differentiable programming component, integrated with the compiler.
Swift For Tensorflow (Google Drive). A great overview of Swift's approach to AD.
DiffSharp. A tensor library with support for differentiable programming for .NET.