When you shouldn't trust einsum
When building neural networks, you’re often working with a bunch of operations that are basically matrix multiplication but with inputs that aren’t arranged or oriented in exactly the right way. One way to deal with this is to pull up all the various tensor shape manipulation operations available in a library like PyTorch, and fiddle with all the dimensions of your tensors until the operation you want is a matrix multiplication. This is a pain, and no one likes writing or reading code like
```python
I, J, K, L = x.shape
x = x.permute([1, 2, 3, 0]).flatten(0, 2)
M, _, P = y.shape
y = y.transpose(0, 1).flatten(1, 2)
z = x @ y
z = z.reshape(J, K, L, M, P)
```
If the author is thoughtful, these operations will at least be commented to annotate the expected tensor shapes at each point along the way. And there are libraries like jaxtyping that let you record that information in a somewhat more formal and verifiable way. But it’s still messy and annoying.
So a lot of people prefer to use torch.einsum(), or even more powerful libraries like einops, to write these sorts of transformations. With einsum(), the whole mess above is just
```python
z = torch.einsum("ijkl, mip -> jklmp", x, y)
```
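As a quick sanity check, here is a small sketch showing that the einsum() one-liner and the manual permute/flatten route agree. The sizes are arbitrary values picked for illustration, not anything from a real model.

```python
import torch

# Arbitrary small sizes, chosen only for this check.
I, J, K, L, M, P = 3, 4, 5, 6, 7, 2
x = torch.randn(I, J, K, L)
y = torch.randn(M, I, P)

# One-liner: contract over i, keep every other index.
z_einsum = torch.einsum("ijkl, mip -> jklmp", x, y)

# Manual route: move i to the end of x and to the front of y, flatten, matmul, unflatten.
x2 = x.permute(1, 2, 3, 0).flatten(0, 2)   # (J*K*L, I)
y2 = y.transpose(0, 1).flatten(1, 2)       # (I, M*P)
z_manual = (x2 @ y2).reshape(J, K, L, M, P)

print(z_einsum.shape)                                   # torch.Size([4, 5, 6, 7, 2])
print(torch.allclose(z_einsum, z_manual, atol=1e-5))    # True
```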
This is much nicer and (so I thought) probably better optimized than whatever index shuffling code I might happen to write.
But there are situations where you might try a different approach to computing a matrix multiplication. Recently I was computing a product of the form
```python
z = torch.einsum("ijk, ikl -> jl", x, y)
```
In my case, x might have a shape like (8, 16384, 2048) and y a shape like (8, 2048, 1024), giving a z of shape (16384, 1024). One way to implement this is to convert it into a single plain matmul by flattening i and k together:
```python
x = x.transpose(0, 1).flatten(1, 2)  # (j, i*k)
y = y.flatten(0, 1)                  # (i*k, l)
z = x @ y                            # (j, l)
```
But another way is to treat it as a batched matmul and then reduce over the first index:
```python
z = (x @ y).sum(dim=0)  # (i, j, l) summed over i -> (j, l)
```
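A small sketch checking that both routes match einsum(), using scaled-down stand-ins for the real shapes (an assumption, purely to keep it fast), followed by the arithmetic for the extra buffer each route needs at the sizes quoted above. At those sizes the copy of x made by the single-GEMM route is twice as large as the batched route's intermediate.

```python
import torch

# Scaled-down stand-ins for the shapes in the text (an assumption, to keep this fast).
I, J, K, L = 8, 64, 32, 16
x = torch.randn(I, J, K, dtype=torch.float64)
y = torch.randn(I, K, L, dtype=torch.float64)

reference = torch.einsum("ijk, ikl -> jl", x, y)

# Route 1: one big GEMM over the flattened (i, k) pair.
z_gemm = x.transpose(0, 1).flatten(1, 2) @ y.flatten(0, 1)

# Route 2: a batched GEMM over i, then a reduction over i.
z_bmm = (x @ y).sum(dim=0)

print(torch.allclose(reference, z_gemm), torch.allclose(reference, z_bmm))  # True True

# Extra buffer each route needs at the real sizes, in float32 bytes:
# route 1 copies x, which is (8, 16384, 2048); route 2 builds an (8, 16384, 1024) intermediate.
copy_of_x_bytes = 8 * 16384 * 2048 * 4
intermediate_bytes = 8 * 16384 * 1024 * 4
print(copy_of_x_bytes / 2**30, intermediate_bytes / 2**30)  # 1.0 0.5  (GiB)
```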
One might think that einsum() would consider these (and other) options and select whichever is the most efficient. This is not the case. PyTorch’s einsum() implementation tries to push as much work onto the GEMM kernel as possible, so for a two-operand contraction like this one it works roughly as follows:
- Permute the dimensions of the two tensors so that the contracting indices are at the end for the first tensor and at the beginning for the second.
- Flatten both tensors into two dimensions (plus a leading batch dimension if some index appears in both inputs and in the output).
- Run a GEMM.
- Reshape and permute the result to the expected output shape.
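The four steps can be sketched as a small function. This is a simplified re-implementation for illustration, not PyTorch's actual einsum() source: contract_via_gemm is a hypothetical name, and it only handles pure contractions (no batch indices, no repeated indices within one operand).

```python
import math
import torch

def contract_via_gemm(x, x_idx, y, y_idx, out_idx):
    """Hypothetical helper: einsum(f"{x_idx},{y_idx}->{out_idx}", x, y) as one GEMM.

    Simplified sketch: every index must either be contracted (in both inputs,
    not in the output) or free (in exactly one input and in the output).
    Batch indices and repeated indices within one operand are not handled.
    """
    contracted = [c for c in x_idx if c in y_idx and c not in out_idx]
    free_x = [c for c in x_idx if c not in contracted]
    free_y = [c for c in y_idx if c not in contracted]
    if sorted(free_x + free_y) != sorted(out_idx):
        raise ValueError("only pure contractions are supported in this sketch")

    sizes = dict(zip(x_idx, x.shape))
    sizes.update(zip(y_idx, y.shape))

    # 1. Permute: contracted indices go last in x and first in y (same order in both).
    xp = x.permute([x_idx.index(c) for c in free_x + contracted])
    yp = y.permute([y_idx.index(c) for c in contracted + free_y])

    # 2. Flatten to 2-D. reshape() returns a view when the memory layout allows it
    #    and silently copies otherwise; this is where a hidden copy can come from.
    m = math.prod(sizes[c] for c in free_x)
    k = math.prod(sizes[c] for c in contracted)
    n = math.prod(sizes[c] for c in free_y)
    x2 = xp.reshape(m, k)
    y2 = yp.reshape(k, n)

    # 3. A single GEMM does all the arithmetic.
    z2 = x2 @ y2

    # 4. Unflatten, then permute into the requested output order.
    order = free_x + free_y
    z = z2.reshape([sizes[c] for c in order])
    return z.permute([order.index(c) for c in out_idx])

x = torch.randn(2, 3, 4)  # indices "ijk"
y = torch.randn(2, 4, 5)  # indices "ikl"
z = contract_via_gemm(x, "ijk", y, "ikl", "jl")
print(z.shape, torch.allclose(z, torch.einsum("ijk,ikl->jl", x, y), atol=1e-5))  # torch.Size([3, 5]) True
```

The same helper also reproduces the earlier "ijkl, mip -> jklmp" example, since i is the only contracted index there.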
 
There’s a reasonable argument that this is the most efficient way to compute a tensor contraction in general. It avoids allocating any intermediate buffers that would otherwise need to be constructed in a multi-step reduction, and GEMM kernels are very well optimized for basically any backend.
There’s a problem, though: if the dimensions that need to be merged aren’t laid out contiguously in memory after the permute, flattening them requires a copy. For some input shapes, that copy can be larger than the intermediate buffer that a multi-step approach would need. (For instance, if x had shape (2, 128, 1024) and y had shape (2, 1024, 128), the intermediate tensor would be (2, 128, 128), much smaller than a copy of x.)
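A small sketch of that copy at the example sizes above. shares_memory is a hypothetical helper that compares untyped_storage() pointers (this assumes PyTorch 2.x); it shows which flatten returns a view and which allocates a fresh buffer, and compares the size of that buffer with the batched intermediate.

```python
import torch

# The example sizes from the text: x is (i, j, k) = (2, 128, 1024), y is (i, k, l) = (2, 1024, 128).
x = torch.randn(2, 128, 1024)
y = torch.randn(2, 1024, 128)

def shares_memory(a, b):
    """True if a and b are backed by the same storage (one is a view of the other)."""
    return a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()

# GEMM route: x must become (j, i*k). After the transpose, i and k are not adjacent
# in memory, so flatten() cannot return a view and quietly copies instead.
x_flat = x.transpose(0, 1).flatten(1, 2)
y_flat = y.flatten(0, 1)  # i and k are already adjacent in y, so this is a view

print(shares_memory(x_flat, x), shares_memory(y_flat, y))  # False True

# Batched route: the only new buffer is the (i, j, l) intermediate.
intermediate = x @ y
print(x_flat.numel(), intermediate.numel())  # 262144 32768
```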
In my case, x was actually a slice X[:8] along the i dimension of a larger tensor X, so the GEMM route had to materialize a permuted copy of it. On its own that might have been tolerable, since with these shapes the copy is the same order of magnitude as the intermediate tensor and would be freed right after the GEMM. But this was happening in an autograd context. Autograd saved the copy of x for the backward pass (it is needed to compute the gradient with respect to y), so the copy stuck around until the backward pass had finished. This doesn’t happen when you use the second method: the sum() call doesn’t need to save any tensors for the backward pass, and the batched matmul only saves a view of X, consuming no extra memory. As the cherry on top, this einsum() operation was actually happening for a bunch of different overlapping slices of X, making a separate copy of each. Overall, this was adding tens of gigabytes of memory consumption to a forwards + backwards pass of my model. Switching to the two-step reduction drastically reduced this memory consumption.
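One way to see this difference directly is to hook into what autograd stores for the backward pass with torch.autograd.graph.saved_tensors_hooks and count the bytes that aren't backed by X or y. This is a sketch that assumes PyTorch 2.x (for untyped_storage()) and uses small made-up shapes; copied_bytes_saved_for_backward, via_einsum and via_bmm_then_sum are hypothetical names for this illustration.

```python
import torch
from torch.autograd.graph import saved_tensors_hooks

def copied_bytes_saved_for_backward(fn, X, y):
    """Hypothetical helper: run fn(X, y) under autograd and total the bytes that
    autograd saves for backward in tensors NOT backed by X's or y's storage."""
    owned = {X.untyped_storage().data_ptr(), y.untyped_storage().data_ptr()}
    copied = []

    def pack(t):
        # Called once for every tensor autograd stashes for the backward pass.
        if t.untyped_storage().data_ptr() not in owned:
            copied.append(t.numel() * t.element_size())
        return t

    def unpack(t):
        return t

    with saved_tensors_hooks(pack, unpack):
        out = fn(X, y)
    out.sum().backward()  # the saved tensors are only released once backward runs
    return sum(copied)

def via_einsum(X, y):
    x = X[:8]  # a view of X: no new memory
    return torch.einsum("ijk, ikl -> jl", x, y)

def via_bmm_then_sum(X, y):
    x = X[:8]
    return (x @ y).sum(dim=0)

X = torch.randn(16, 64, 32, requires_grad=True)
y = torch.randn(8, 32, 16, requires_grad=True)

print(copied_bytes_saved_for_backward(via_einsum, X, y))        # > 0: a permuted copy of x is kept
print(copied_bytes_saved_for_backward(via_bmm_then_sum, X, y))  # 0: only views of X and y are kept
```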
There are two morals to this story. One is that every abstraction is leaky, and you never know when you’ll need to understand what’s beneath it. In 95% of cases einsum() is smarter than me, and I never bothered to look under the hood. That meant I had no idea what might be going on. The second is that if you care about performance, you need to look at performance profiles. The path to my realizing what was going on here started with wondering why every einsum() call in my execution trace had a child aten::clone(). PyTorch has pretty good memory and execution tracing tools; if you need your GPU to go fast, use them.
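A minimal sketch of that kind of trace inspection, assuming CPU tensors with small arbitrary sizes: profile one einsum() call with torch.profiler and look for aten::clone events nested (directly or via aten::reshape) under aten::einsum. has_ancestor is a hypothetical helper; cpu_parent is the profiler's link from an event to its enclosing op.

```python
import torch
from torch.profiler import profile, ProfilerActivity

# Arbitrary small sizes with the same "ijk, ikl -> jl" structure as above.
x = torch.randn(8, 256, 128)
y = torch.randn(8, 128, 64)

with profile(activities=[ProfilerActivity.CPU], profile_memory=True) as prof:
    z = torch.einsum("ijk, ikl -> jl", x, y)

def has_ancestor(event, name):
    """Walk up the recorded call tree looking for an enclosing op called `name`."""
    parent = event.cpu_parent
    while parent is not None:
        if parent.name == name:
            return True
        parent = parent.cpu_parent
    return False

clones = [e for e in prof.events() if e.name == "aten::clone" and has_ancestor(e, "aten::einsum")]
print(f"aten::clone calls inside aten::einsum: {len(clones)}")

# Per-op summary; sorting by memory puts the ops that allocate (like the clone) near the top.
print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10))
```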