Transformers
This page explains how transformers work. Readers are expected to be familiar with ordinary neural networks and with gradient descent.
Information Lookup
Suppose you have a table of facts.
key | value |
Who was the President in 2002? | George W. Bush |
When was Charlemagne crowned? | October 9, 768 |
... | ... |
How large is the Indian Ocean? | 73.4 million square kilometers |
We want to use this table to help our neural network answer questions. A natural approach is to compute embeddings of the keys and values. Then our network can send us the embedding of a query; we compute which key's embedding is closest to that query and return the corresponding value.
import torch
from torch import nn

class NeuralNetworkFactTable(nn.Module):
    def __init__(self, num_rows, query_dimension, value_dimension):
        super(NeuralNetworkFactTable, self).__init__()
        self.key_embeddings = nn.Parameter(torch.randn((num_rows, query_dimension)))
        self.value_embeddings = nn.Parameter(torch.randn((num_rows, value_dimension)))

    def forward(self, query_embedding):
        # Dot product of the query with every key: a (num_rows,) vector
        similarities = self.key_embeddings @ query_embedding
        # Return the value whose key best matches the query
        return self.value_embeddings[similarities.argmax()]
This seems like an intuitive way to help the neural network encode its knowledge so it can be used later. Unfortunately this doesn't work: the gradients for all of your keys will be zero, because argmax is not differentiable.
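To see the problem concretely, here is a small self-contained sketch (the sizes are arbitrary): the loss depends on the keys only through argmax, so no gradient ever reaches them.

keys = torch.randn((5, 4), requires_grad=True)
values = torch.randn((5, 3), requires_grad=True)
query = torch.randn(4)

similarities = keys @ query
loss = values[similarities.argmax()].sum()
loss.backward()

print(keys.grad)    # None: argmax returns a plain integer index, so the keys
                    # never enter the part of the graph that receives gradients
print(values.grad)  # only the selected row has a nonzero gradient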
The traditional solution to "argmax has no gradients" is to use softmax, and that is exactly what we can do here.
class NeuralNetworkFactTable(nn.Module):
    def __init__(self, num_rows, query_dimension, value_dimension):
        super(NeuralNetworkFactTable, self).__init__()
        self.key_embeddings = nn.Parameter(torch.randn((num_rows, query_dimension)))
        self.value_embeddings = nn.Parameter(torch.randn((num_rows, value_dimension)))

    def forward(self, query_embedding):
        similarities = self.key_embeddings @ query_embedding
        # Turn the similarities into positive weights that sum to one
        similarities = nn.functional.softmax(similarities, 0)
        # Weighted average of all values, weighted by the similarities
        return (self.value_embeddings * similarities.unsqueeze(1)).sum(0)
Now when we send a query to the table, instead of getting a single value, we get a weighted average of all the values. The more similar a value's key is to our query, the more influence it will have on the returned value.
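As a quick sanity check, here's how the table might be used (the dimensions below are arbitrary):

table = NeuralNetworkFactTable(num_rows=100, query_dimension=16, value_dimension=32)
query_embedding = torch.randn(16)

answer = table(query_embedding)
print(answer.shape)  # torch.Size([32]): a weighted average of the 100 value embeddings

answer.sum().backward()
print(table.key_embeddings.grad is not None)  # True: gradients now reach the keys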
Attention
What is attention? When people in machine learning talk about attention, they are usually appealing to intuitions about how human attention works:
- Attention is limited. You cannot focus on 100 things at once.
- Attention cannot be negative. You can focus on something, but you cannot antifocus on it.
The use of softmax seems to intuitively bring both these properties to our table-retrieval scheme. The query causes the model to pay "attention" to certain values more than others.
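Both properties are easy to check directly (a tiny sketch with made-up scores):

scores = torch.tensor([2.0, -1.0, 0.5])
weights = nn.functional.softmax(scores, 0)

print(weights)        # roughly tensor([0.786, 0.039, 0.175]): no weight is negative
print(weights.sum())  # tensor(1.): the total amount of "attention" is limited to one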
Long Term Dependencies
Let's leave aside self attention for now and talk about sequences. Particularly long sequences. Long before transformers, machine learning researchers had experimented with 1D convolutional neural networks and Recurrent Neural Networks. Both had a similar problem: they struggled to remember important context that had happened far earlier in a text. For example:
Alice was carrying a grape. Alice walked to school. Alice was happy that the color of the fruit she was carrying was ____.
To answer that question the model needs to remember the fruit from the first sentence. For various reasons, this is difficult for RNNs or CNNs, and while you can try to modify their architectures to mitigate the problem, these architectures are fundamentally biased towards preferring to look at recent words, rather than words that occurred a long time ago.
Transformers avoid this bias altogether, by simply having every word be directly connected to every other word. In the age of classical neural networks (e.g. 2015), a preference towards nearby words seemed like a reasonable way to regularize networks, but in the age of incredibly large datasets, we care less about hard-coding seemingly-sensible regularization than we used to.
Self Attention
Self Attention tries to extend the concept of "attention". Rather than paying attention to a static fact in a table, we have multiple inputs, which want to "pay attention" to each other.
We use each input to compute a key and a value. Now we have a table (just like above), but, rather than representing static facts about the Universe, it represents information about each input (the value) and a way to retrieve it (the key).
Finally, we compute a query for each input. Now each input can retrieve information about each other input by comparing its query to every other input's key.
class SelfAttention(nn.Module):
    def __init__(self):
        super(SelfAttention, self).__init__()
        self.query_computer = nn.Sequential(...)
        self.key_computer = nn.Sequential(...)
        self.value_computer = nn.Sequential(...)

    def forward(self, x):
        queries = self.query_computer(x)  # (N, D) matrix
        keys = self.key_computer(x)       # (N, D) matrix
        values = self.value_computer(x)   # (N, D) matrix
        similarities = queries @ keys.T   # (N, N) matrix
        # Make each entry positive, and make each row sum up to one
        similarities = nn.functional.softmax(similarities, 1)
        # For every input, compute a weighted average of all values
        return similarities @ values      # (N, D)
This approach is very common in neural networks that are modeling language, where different words in a sentence give "attention" to other words in the sentence.
(Side note: we use "words" in this article for simplicity, but in theory there's no reason you can't tokenize your sentence any way you want, and, in practice, tokenizing with BPE ("byte pair encoding") tokens is more common than tokenizing with words.)

In practice there are usually constraints imposed on the attention. The most common example is that words are often not allowed to pay attention to words that come after them. When you see a "mask" in transformer code, this is what is going on: certain inputs are being hidden from other inputs.
We can force inputs to ignore other inputs by subtracting infinity from their similarities (i.e. setting them to negative infinity) before computing the softmax. This forces the corresponding weights after the softmax to be zero.
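Here is a tiny illustration with made-up similarity scores (not part of any real model):

scores = torch.tensor([1.0, 2.0, 3.0])
scores[2] -= float('inf')  # hide the third input from whoever issued this query

weights = nn.functional.softmax(scores, 0)
print(weights)  # roughly tensor([0.2689, 0.7311, 0.0000]): the masked entry gets zero weight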

While there are lots of ways we can convert inputs into queries, keys, and values, the simplest practical approach is to just use a matrix multiplication, resulting in this code:
class SelfAttention(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(SelfAttention, self).__init__()
        self.query_computer = nn.Linear(input_dim, hidden_dim)
        self.key_computer = nn.Linear(input_dim, hidden_dim)
        self.value_computer = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        queries = self.query_computer(x)  # (N, hidden_dim) matrix
        keys = self.key_computer(x)       # (N, hidden_dim) matrix
        values = self.value_computer(x)   # (N, output_dim) matrix
        similarities = queries @ keys.T   # (N, N) matrix
        # Example of masking: input 4 cannot pay attention to input 7
        similarities[4, 7] -= float('inf')
        similarities = nn.functional.softmax(similarities, 1)
        return similarities @ values      # (N, output_dim)
Note: the code above works for a single sequence that is paying attention to itself, but in practice we're usually training on an entire batch, which complicates the above code. See the Self-Attention Code section for a complete implementation.
Heads
Which brings us to the transformer. The two core ideas of the transformer are given above, but let's be explicit:
- Self Attention: using the attention/query/key/value paradigm described above
- Full connectivity: all words are connected to all other words
(Note that, while Transformers use both of these features, it's totally possible (sensible even!) to have a model that uses only one of them — for example, you could replace the softmax in a Transformer with a ReLU and you'd have a totally sensible model that fully connects all inputs... but you wouldn't be using "self attention").
Transformers introduce one last concept: multiple "heads". All this means is that we perform this self attention task multiple times (once for each "head") and then concatenate the results together.
class Transformer(nn.Module):
    def __init__(self):
        super(Transformer, self).__init__()
        self.heads = nn.ModuleList([
            SelfAttention(),
            SelfAttention(),
            # ...
            SelfAttention(),
        ])

    def forward(self, x):
        results = []
        for head in self.heads:
            results.append(head(x))
        return torch.cat(results, 1)
We wrote this as a for loop because it's easier to understand, but for performance reasons you will almost never find an explicit loop over heads in other people's code. Instead they use tensor voodoo (clever reshaping, transposing, and matrix multiplication) to avoid explicitly looping over each head.
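Here is a small sketch of that idea with made-up sizes (not a full transformer): a Python loop over heads gives the same result as a single batched matrix multiplication.

num_heads, N, D = 4, 10, 8
queries = torch.randn(num_heads, N, D)
keys = torch.randn(num_heads, N, D)

# The easy-to-read loop...
looped = torch.stack([queries[h] @ keys[h].T for h in range(num_heads)])
# ...and the vectorized version you'll actually see in the wild.
batched = torch.bmm(queries, keys.transpose(1, 2))

print(torch.allclose(looped, batched))  # True (up to floating point noise)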
Positional Encodings
One problem with Transformers is that they cannot determine the relative location of two words. This is a pretty big flaw! "The white lady's hat" and "The lady's white hat" mean different things!
The solution to this is that, in addition to mapping words to embeddings, we tack on a few extra dimensions that tell the model where the word is. The simplest possible positional embedding is just the index of the word.
For example, the simplified code to convert a sentence into embeddings might look like this:
def sentence2vector(words: list[str]):
    result = []
    for word in words:
        emb = get_embedding(word)
        result.append(emb)
    return torch.stack(result)
But we can add our simple positional embedding like this:
def sentence2vector(words: list[str]):
    result = []
    for i, word in enumerate(words):
        emb = get_embedding(word)
        # Tack the word's position onto the end of its embedding
        emb = torch.cat([emb, torch.tensor([float(i)])])
        result.append(emb)
    return torch.stack(result)
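To see the shapes, here is a quick usage sketch; get_embedding is assumed to return a 4-dimensional vector (a stand-in for a real learned embedding table):

def get_embedding(word: str) -> torch.Tensor:
    # Toy stand-in for a learned embedding table
    return torch.randn(4)

x = sentence2vector(["the", "lady's", "white", "hat"])
print(x.shape)   # torch.Size([4, 5]): 4 words, each with 4 embedding dims + 1 position dim
print(x[:, -1])  # tensor([0., 1., 2., 3.]): the appended positions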
(Note: real positional encodings are more complicated than this, in ways that make the position information easier for the model to use.)
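For reference, one common choice is the sinusoidal encoding from the original Transformer paper ("Attention Is All You Need"). The sketch below assumes an even dimension, and real models typically add the encoding to the word embeddings rather than concatenating it:

import math

def sinusoidal_positional_encoding(num_positions: int, dim: int) -> torch.Tensor:
    # Each position gets a vector of sines and cosines at geometrically
    # spaced frequencies, so nearby positions get similar encodings.
    positions = torch.arange(num_positions).unsqueeze(1).float()
    frequencies = torch.exp(
        -math.log(10000.0) * torch.arange(0, dim, 2).float() / dim)
    encoding = torch.zeros((num_positions, dim))
    encoding[:, 0::2] = torch.sin(positions * frequencies)
    encoding[:, 1::2] = torch.cos(positions * frequencies)
    return encoding

print(sinusoidal_positional_encoding(num_positions=10, dim=16).shape)  # torch.Size([10, 16])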
Miscellaneous
Care About Your Magnitudes
Since NLP folks don't like BatchNorm, they have to actually think about the magnitudes of their activations. The dot product of two roughly-unit-variance D-dimensional vectors has a standard deviation of roughly sqrt(D), which would push the softmax into a regime with near-zero gradients. As a result, you need to divide by the square root of your query vectors' dimension before taking the softmax:
similarities = (queries @ keys.T) / math.sqrt(queries.shape[-1])
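A throwaway check of that claim (arbitrary sizes, nothing from a real model):

import math

for D in [16, 256, 4096]:
    dots = (torch.randn(2000, D) * torch.randn(2000, D)).sum(1)
    print(D, round(dots.std().item(), 1), math.sqrt(D))
# Prints standard deviations close to 4, 16, and 64; dividing by sqrt(D)
# keeps the softmax inputs at a similar scale no matter how big D is.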
Self-Attention Code
import math

import torch
from torch import nn

def matrix_multiply_last_two_axes(A, B):
    # Suppose our inputs have these shapes:
    #   A.shape = (a, b, c, d)
    #   B.shape = (a, b, d, e)
    # We want to run this code:
    #
    #   result = torch.zeros((a, b, c, e))
    #   for i in range(a):
    #       for j in range(b):
    #           result[i, j, :, :] = A[i, j] @ B[i, j]
    #
    # This function is an efficient implementation of that.
    assert A.shape[:-2] == B.shape[:-2]
    batch_shape = A.shape[:-2]
    A = A.reshape((-1,) + A.shape[-2:])
    B = B.reshape((-1,) + B.shape[-2:])
    return torch.bmm(A, B).reshape(batch_shape + (A.shape[-2], B.shape[-1]))
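A quick way to convince yourself the reshaping is right (arbitrary shapes):

A = torch.randn(2, 3, 4, 5)
B = torch.randn(2, 3, 5, 6)

fast = matrix_multiply_last_two_axes(A, B)
slow = torch.zeros((2, 3, 4, 6))
for i in range(2):
    for j in range(3):
        slow[i, j] = A[i, j] @ B[i, j]

print(fast.shape)                  # torch.Size([2, 3, 4, 6])
print(torch.allclose(fast, slow))  # True (up to floating point noise)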
class SelfAttention(nn.Module):
    def __init__(self, num_tokens,
                 input_dim, hidden_dim, output_dim, num_heads=1):
        super(SelfAttention, self).__init__()
        assert hidden_dim % num_heads == 0
        assert output_dim % num_heads == 0
        self.query_computer = nn.Linear(input_dim, hidden_dim)
        self.key_computer = nn.Linear(input_dim, hidden_dim)
        self.value_computer = nn.Linear(input_dim, output_dim)
        self.num_heads = num_heads
        # This mask stops tokens from paying attention to tokens
        # that are in the future. The mask looks like
        #   1 0 0 0
        #   1 1 0 0
        #   1 1 1 0
        #   1 1 1 1
        # We add two extra dimensions so it broadcasts across every
        # element of the batch and every head.
        self.register_buffer("mask",
            torch.tril(
                torch.ones((num_tokens, num_tokens))
            ).reshape((1, 1, num_tokens, num_tokens)))

    def forward(self, x):
        # nn.Linear always dot-products with the last dimension,
        # so no changes are needed here
        batch_size, seq_length, _ = x.shape
        shape = (batch_size, seq_length, self.num_heads, -1)
        queries = self.query_computer(x).reshape(shape)
        keys = self.key_computer(x).reshape(shape)
        values = self.value_computer(x).reshape(shape)
        # Now we swap the num_heads and seq_length axes, so they
        # look like (batch_size, num_heads, seq_length, -1). This
        # is convenient because it means we can replace the matrix
        # multiplications in the non-batched and single-headed
        # code earlier in the article with the
        # "matrix_multiply_last_two_axes" we wrote above.
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
        # Q/K/V shapes: (batch_size, num_heads, seq_length, -1)
        similarities = matrix_multiply_last_two_axes(
            queries,
            # We're computing the *inner* product of each query and
            # each key, so we need another transpose here.
            keys.transpose(2, 3),
        )
        # similarities' shape is
        # (batch_size, num_heads, seq_length, seq_length)
        # Make the attention magnitudes sensible before the
        # softmax.
        similarities = similarities / math.sqrt(queries.shape[-1])
        # Stop tokens from paying attention to tokens
        # in the future by setting their similarities to -inf.
        mask = self.mask[:, :, :seq_length, :seq_length]
        similarities = similarities.masked_fill(mask == 0, float('-inf'))
        similarities = nn.functional.softmax(similarities, 3)
        result = matrix_multiply_last_two_axes(
            similarities,
            values,
        )
        # The only problem is that we transposed the heads and
        # the seq_length axes at the beginning of this function,
        # so result's shape is
        # (batch_size, num_heads, seq_length, output_dim / num_heads)
        # when we want it to be
        # (batch_size, seq_length, output_dim)
        return result.transpose(1, 2).reshape(
            (batch_size, seq_length, -1))
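And a quick usage sketch (the sizes below are arbitrary):

attention = SelfAttention(num_tokens=128, input_dim=64,
                          hidden_dim=64, output_dim=64, num_heads=8)
x = torch.randn(32, 128, 64)  # (batch_size, seq_length, input_dim)

out = attention(x)
print(out.shape)  # torch.Size([32, 128, 64]): (batch_size, seq_length, output_dim)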