[Target audience: Me from a month ago, people who want a sense of what a transformer is doing on a non-technical level, and people who want to chunk their understanding of transformers.]
[AI Safety relevance rating: QI]
Imagine a special art gallery, the Gallery for Painting Transformations (GPT). It takes in a sentence and makes a simple set of paintings to represent it, at first hardly more than stock images. Each morning, a team of artists comes in and adds to the paintings, and each evening, a snapchat-style filter is applied to each painting individually. Thanks to these efforts, over time the paintings in the gallery grow in meaning and detail, ultimately containing enough information to predict a painting which would be a good addition to the gallery.
This is how GPT-3 works, or at least (I claim) a reasonable analogy for it. In this post I want to flesh out this analogy, with the end result that even non-ML people can get a sense of what large language models are doing. I have tried to be as accurate as possible in how I analogize things, with the notable exception that my examples involving semantic meaning are far more human-comprehensible than what GPT-3 is doing; in reality, everything it does looks like “perform a seemingly-random vector operation”. Any mistakes are intentional creative flourishes of my own1.
The Paintings
The gallery holds 2048(=n_ctx)2 paintings, which represent the model’s context window (the prompt you feed into GPT-3). In GPT-3, each word is represented by a 12288(=d_model)-dimensional vector. Fortuitously, that’s exactly enough dimensions to make a 64-by-64 pixel RGB image!
The gallery first opens when a user feeds in a prompt, and each token (tokens≈words) in the prompt is turned into a painting which looks like a generic stock image, the “token embedding”. If you feed in more or fewer than 2048 tokens, the prompt is trimmed or padded with a special token.
The gallery itself has no set order; its paintings can simply be viewed in any order. Since word order is sometimes important in sentences, the gallery also paints a little indicator onto each painting to denote its position in the sentence, the “position embedding”. In the previous image I showed this as drawing a little number in the corner of each painting, though in reality it can be far more complicated. Since the position embedding is part of the painting, it can also be transformed over time.
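For readers who like to see the bookkeeping, here is a minimal NumPy sketch of this opening step. The function and array names are my own, and I use tiny toy sizes so it runs instantly (GPT-3’s real values are n_ctx=2048, d_model=12288, and a vocabulary of roughly 50,000 tokens); the point is just that each starting painting is a learned stock image for its token plus a learned position indicator.

```python
import numpy as np

# Toy sizes for illustration; in GPT-3 these are n_ctx=2048, d_model=12288,
# and a vocabulary of roughly 50,000 tokens.
n_ctx, d_model, vocab_size = 8, 16, 100

# Learned lookup tables (random placeholders here).
token_embedding = np.random.randn(vocab_size, d_model) * 0.02     # one "stock image" per token
position_embedding = np.random.randn(n_ctx, d_model) * 0.02       # one "corner number" per gallery slot

def initial_paintings(token_ids):
    """Turn a prompt's token IDs into the gallery's starting paintings."""
    token_ids = np.asarray(token_ids)
    paintings = token_embedding[token_ids]                         # a generic stock image per token
    paintings = paintings + position_embedding[:len(token_ids)]    # position indicator painted on top
    return paintings                                               # shape: (n_tokens, d_model)

gallery = initial_paintings([4, 17, 42])   # three made-up token IDs -> three starting paintings
```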
At this point the gallery has paintings, but they’re somewhat simplistic. Over the next 96(=n_layers) days, the gallery transforms the paintings until they are rich in meaning, both individually and as a collection. Each day represents a layer of the network, where a new team of artists (attention heads) works during the morning and a new filter (feed-forward network) is applied each evening.
The Artists
The artists in this analogy are the transformer’s attention heads. Each day, a new team of 96(=n_heads) artists comes in, and each artist adds a few brush strokes to the canvases depending on what they see. Each artist has two rules that decide what they paint:
A rule to determine what brush strokes painting X adds to the paintings whose attention it earns.
This value rule is encoded in the “value” and “output” matrices V and O, each 128-by-12288 (d_head-by-d_model). Thinking of painting X as a 1-by-12288 vector, X tries to add the vector XV^TO to other paintings. If one stacks all the painting vectors into a 2048-by-12288 matrix M, you can compute all outputs at once in the 2048-by-12288 matrix B=MV^TO.
A rule to determine how much attention painting X should give to painting Y.
This attention rule is encoded in the “key” and “query” matrices K and Q, each 128-by-12288 (d_head-by-d_model). Thinking of paintings X and Y as 1-by-12288 vectors, the “pre-attention” that X pays to Y is X(Q^T)K(Y^T).
If one stacks all the painting vectors into a 2048-by-12288 matrix M, you can compute all “pre-attentions” at once in the 2048-by-2048 matrix M(Q^T)K(M^T)=(MQ^T)(MK^T)^T.
“Pre-attentions” can be any number (positive or negative), so we pass each row through a softmax function; the resulting attention values lie between 0 and 1, and the attention on each painting adds up to 1 across all the paintings it could come from. Writing σ for the row-wise softmax, the attention matrix is A=σ[(MQ^T)(MK^T)^T]
The artist draws all the brush strokes designated by the first rule, with an opacity in proportion to the attention designated by the second rule.
The overall output for the attention head is AB=σ[(MQ^T)(MK^T)^T](MV^TO), which is added to the original matrix M.
Keep in mind that since these are vectors, you should think of this paint as “additive” with itself and the original canvas (instead of literally “painting over” and rewriting the original values). All artists work simultaneously on all canvases, and the gallery as a whole has been trained so that the artists work in harmony rather than interfering with each other.
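As a minimal sketch (not GPT-3’s actual implementation), here is one artist’s morning in NumPy, following the matrix shapes above. The function and variable names are mine, the artist’s rules are random placeholders rather than trained weights, and I leave out the 1/sqrt(d_head) scaling mentioned in the wrinkles at the end, so the code matches the formulas as written.

```python
import numpy as np

d_model, d_head = 12288, 128   # painting size and one artist's "working size"

def softmax_rows(x):
    """Row-wise softmax: each row becomes positive numbers that sum to 1."""
    x = x - x.max(axis=-1, keepdims=True)   # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def attention_head(M, Q, K, V, O):
    """One artist's morning of work.

    M is (n_paintings, d_model); Q, K, V, O are each (d_head, d_model).
    Returns the brush strokes to layer on top of M (the residual update).
    """
    pre_attention = (M @ Q.T) @ (M @ K.T).T   # (n_paintings, n_paintings): X Q^T K Y^T for every pair
    A = softmax_rows(pre_attention)           # how the attention on each painting is split up
    B = M @ V.T @ O                           # (n_paintings, d_model): what each painting offers to draw
    return A @ B                              # strokes weighted by attention

# Toy usage: four paintings and one randomly-initialized artist.
M = np.random.randn(4, d_model)
Q, K, V, O = (np.random.randn(d_head, d_model) * 0.02 for _ in range(4))
M = M + attention_head(M, Q, K, V, O)         # strokes are layered onto the existing paintings
```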
Let’s see how this works in practice by following one hypothetical artist. For simplicity, let’s pretend the gallery only contains two paintings (the others could be padding tokens which the artists are told to ignore).
In painting X, a person is climbing a tree, and in painting Y a person is eating an apple.
Our artist is a “consistency-improving” artist who tries to make all the paintings tie into each other, so the first rule tells them that X should draw a similar-looking tree in the background, while Y should draw more apples.
Next, the artist thinks about where attention should be directed. Using their second rule (including the position embeddings suggesting X happens before Y), they decide the attention on X should come .5/.5 from X and Y, respectively, and the attention on Y should come .99/.01 from X and Y, respectively.
Drawing the first-rule images to the strengths determined by the second rule, painting X gets a background tree and some apples in the branches of the foreground tree, while painting Y gets a background tree (and some very faint outlines of ghost apples).
Thanks to this artist, the paintings are now more narratively consistent: instead of an unrelated tree-climbing and apple-eating, the paintings show that someone picked apples from an apple orchard, and then ate the apple while still near the tree.
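To see the arithmetic behind those ghost apples, here is the same mixing written out with toy numbers; the three-dimensional “paint space” and the particular vectors are invented purely for illustration.

```python
import numpy as np

# What each painting offers to draw (the rows of B above), in a made-up 3-D "paint space".
strokes_from_X = np.array([1.0, 0.0, 0.0])   # e.g. "draw a similar-looking tree"
strokes_from_Y = np.array([0.0, 1.0, 0.0])   # e.g. "draw more apples"
B = np.stack([strokes_from_X, strokes_from_Y])

# The artist's attention pattern A: how the attention on each painting is split across X and Y.
A = np.array([[0.50, 0.50],    # attention on X: half from X, half from Y
              [0.99, 0.01]])   # attention on Y: almost all from X

print(A @ B)
# [[0.5  0.5  0. ]   -> X gets a half-strength tree and half-strength apples
#  [0.99 0.01 0. ]]  -> Y gets a strong tree and only faint "ghost apples"
```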
Other artists might perform other tasks: maintaining a good balance of colors, transferring details, art-styles, or themes between canvases, showing cause and effect, or [incomprehensible and seemingly-arbitrary complicated matrix calculation that is somehow essential to the whole network]. Overall, the purpose of the artists is to carry information between canvases.
The Filters
After the artists are finished each day, a filter is applied to each painting. On any particular night the same filter is applied to each painting, but a new filter is used each night.
In more technical terms, the filter is a feedforward network consisting of the input painting (12288 width), a single hidden layer (49152=4*12288 width), and an output layer (12288 width). Writing W_1 and W_2 for the weight matrices, b_1 and b_2 for the bias vectors, and α for the activation function (GPT uses GELU), the output of the filter sublayer is given by F=α(XW_1+b_1)W_2+b_2. X is size 1-by-12288, W_1 is size 12288-by-49152, b_1 is size 1-by-49152, W_2 is size 49152-by-12288, and b_2 is size 1-by-12288.
As with the attention heads, filters use a “residual connection”, meaning that the filter calculates F=α(XW_1+b_1)W_2+b_2 for each painting X, and returns X+F (not F alone). Two intuitive arguments for why we might want residual connections:
During training if the network is bad, F will be some noise centered on 0, and therefore the filter will return X+noise, which is much better than returning just noise.
The filter gets to focus on “how do I improve this painting?” rather than “how do I make a new good painting from scratch?”
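Here is a minimal NumPy sketch of one night’s filter, using the formula and residual connection above. The function names and weights are my own placeholders, the sizes are shrunk so it runs instantly (GPT-3’s real widths are 12288 and 49152), and the GELU is the common tanh approximation.

```python
import numpy as np

def gelu(x):
    # Common tanh approximation of the GELU activation used by GPT.
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def filter_layer(X, W1, b1, W2, b2):
    """One night's filter, applied to every painting independently.

    X: (n_paintings, d_model);  W1: (d_model, 4*d_model);  W2: (4*d_model, d_model).
    Computes F = gelu(X W1 + b1) W2 + b2 and returns X + F (the residual connection).
    """
    F = gelu(X @ W1 + b1) @ W2 + b2
    return X + F                       # "improve the painting", don't repaint it from scratch

# Toy usage; in GPT-3 the real widths are d_model = 12288 and 4*d_model = 49152.
d_model = 16
X = np.random.randn(4, d_model)
W1, b1 = np.random.randn(d_model, 4 * d_model) * 0.02, np.zeros(4 * d_model)
W2, b2 = np.random.randn(4 * d_model, d_model) * 0.02, np.zeros(d_model)
X = filter_layer(X, W1, b1, W2, b2)
```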
What can these filters do? In principle, a great deal - by the universal approximation theorem, they could learn any function on the input given enough width in their hidden layer, and this hidden layer is pretty wide. Their main limitations are that the same filter is applied to every painting, and that each painting is processed independently of the others. I imagine them as generic “improvements” to the paintings - upscaling, integrating the additions of the previous day’s artists, and of course [incomprehensible and seemingly-arbitrary complicated matrix calculation that is somehow essential to the whole network]. Overall, the purpose of the filters is to refine and evolve the paintings in isolation from each other.
Conclusion
And that’s what GPT does! To summarize:
Each token in the context is embedded as a painting in the gallery.
The paintings start out crude but over time accumulate more “meaning” and “detail”, due to alternating teams of artists (attention heads) and filters (feed-forward networks).
The artists transfer information between paintings based on the attention that each painting gives to the others.
The filters evolve paintings in isolation from each other.
Both artists and filters use “residual connections”, meaning that their outputs are changes layered on top of the existing paintings rather than new paintings made from scratch.
The artists and filters work in highly-trained coordination to build on top of each other’s work. Together they form an assembly line, where each team counts on the prior teams and is counted on in turn.
There are a few wrinkles I left out, which I’ll touch on here:
Training the artists and filters is very expensive (there are a huge number of them) and uses various tricks, but is not fundamentally different than training any other neural network.
I omitted that just before applying the softmax function in the attention step, you multiply the pre-attention matrix by 1/sqrt(d_head) (that is, divide it by sqrt(d_head)). This doesn’t change the qualitative behavior of this step, but it keeps the softmax inputs from growing too large and helps training.
There are also layer norms applied around the artists’ and filters’ work (in GPT-3 they come just before each sublayer, plus one at the very end). I don’t have a great grasp of layer norms yet, but I believe they’re something like another, simpler filter.
In other transformers, such as the original Attention is All You Need transformer, there is also encoder-decoder attention. To the extent I understand it, this would be like making a copy of the gallery after day 48. On days 49-96, in addition to the artists and filters, teams of art historians would use the day-48 state of the gallery to determine attention and what to draw on the current paintings.
The final state of the gallery is used to predict which token should come next.
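The post doesn’t walk through this last step, but as a rough sketch: assuming the token-embedding table is reused, transposed, to score every possible next token (as in GPT-2, whose architecture GPT-3 follows), the painting for the most recent token is turned into a probability distribution over the vocabulary. The names and toy sizes below are mine, and I omit the final layer norm that sits just before this step.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

# Hypothetical inputs: `final_gallery` is the (n_tokens, d_model) state after the last day,
# and `token_embedding` is the same lookup table the gallery opened with.
d_model, vocab_size = 16, 100
final_gallery = np.random.randn(5, d_model)
token_embedding = np.random.randn(vocab_size, d_model) * 0.02

last_painting = final_gallery[-1]            # the painting for the most recent token
logits = last_painting @ token_embedding.T   # one score per token in the vocabulary
probs = softmax(logits)                      # "which painting would be a good addition?"
predicted_next_token = int(np.argmax(probs))
```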
I believe everything I say here follows directly from these sources:
Where possible, I will include the parameter name from the GPT-3 paper in addition to the number.