Useful links
- YouTube lecture by Andrej Karpathy
- Attention is All You Need paper
- Repo with content from the video lecture
- gpt.py file with code mapping to slides
- Apologies if the slides’ font is too small. I wanted to pack as much info in there as possible without compromising on the visuals.
The what and the why
If you haven’t yet watched Karpathy’s lecture “Let’s build GPT: from scratch, in code, spelled out,” you are missing out on an absolute masterclass. From scratch, in two hours, Andrej builds and trains a character-level Generatively Pretrained Transformer (GPT) on the entire Shakespeare corpus, following the paper “Attention is All You Need”.
Following along is probably enough to learn a ton already. I am a very visual guy though, and (in most cases) I can’t really say I have fully digested something until I grab a piece of paper and sketch it out. The purpose of this blog post is to do exactly that with Karpathy’s lecture. I wanted to put together a set of slides I could go back to regularly, skim through fairly quickly, and have an aha moment about the Transformer’s architecture without wasting hours reading endless papers. With this in mind, I tried to keep the text to the bare minimum and illustrate the inner workings of the model, mapping diagrams to code.
What follows is a breakdown of the GPTLanguageModel class, the core of the model Andrej trains. We start from the top and peel the onion one layer after the other, exposing its components, until we get to the bottom: the self-attention mechanism. Each code block is accompanied by one or two slides showing how it works.
GPTLanguageModel class
class GPTLanguageModel(nn.Module):

    def __init__(self):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd) # final layer norm
        self.lm_head = nn.Linear(n_embd, vocab_size)
        # more code...

    def forward(self, idx, targets=None):
        B, T = idx.shape

        # idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
        x = tok_emb + pos_emb # (B,T,C)
        x = self.blocks(x) # (B,T,C)
        x = self.ln_f(x) # (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)
        # more code...
        return logits, loss

    # more code...
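If it helps to see the shapes concretely, here is a minimal, self-contained sketch of the forward pass above, skipping the Transformer blocks (which preserve the (B,T,C) shape). The sizes are small values of my choosing, not the lecture’s actual hyperparameters.

# Toy shape walk-through of the forward pass (sizes are illustrative only)
import torch
import torch.nn as nn

B, T = 4, 8                       # batch of 4 sequences, 8 tokens each
vocab_size, n_embd = 65, 32       # made-up small values

token_embedding_table = nn.Embedding(vocab_size, n_embd)
position_embedding_table = nn.Embedding(T, n_embd)
lm_head = nn.Linear(n_embd, vocab_size)

idx = torch.randint(0, vocab_size, (B, T))              # (B,T) integer token ids
tok_emb = token_embedding_table(idx)                    # (B,T,C)
pos_emb = position_embedding_table(torch.arange(T))     # (T,C), broadcast over the batch
x = tok_emb + pos_emb                                   # (B,T,C)
logits = lm_head(x)                                     # (B,T,vocab_size)
print(x.shape, logits.shape)      # torch.Size([4, 8, 32]) torch.Size([4, 8, 65])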
Block class
class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, n_embd, n_head):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x
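Two things are worth noticing in the forward pass: each sub-layer’s output is added back onto its input (the residual, or skip, connections), and LayerNorm is applied before each sub-layer rather than after (the pre-norm formulation), so the (B,T,C) shape is preserved throughout. A minimal sketch of that pattern, with nn.Identity standing in for the two real sub-layers and everything else illustrative:

# Minimal sketch of the pre-norm residual pattern in Block
# (nn.Identity stands in for MultiHeadAttention / FeedFoward)
import torch
import torch.nn as nn

n_embd = 32
ln1, ln2 = nn.LayerNorm(n_embd), nn.LayerNorm(n_embd)
sa, ffwd = nn.Identity(), nn.Identity()

x = torch.randn(4, 8, n_embd)       # (B,T,C)
x = x + sa(ln1(x))                  # "communication": tokens exchange information
x = x + ffwd(ln2(x))                # "computation": each token is processed on its own
print(x.shape)                      # torch.Size([4, 8, 32]) -- shape preserved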
FeedForward class
class FeedFoward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)
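This is just a per-token MLP: the inner layer is four times wider than the embedding, as in the “Attention is All You Need” paper, and is then projected back down. A quick sketch with made-up sizes:

# Toy shapes through the feed-forward net (illustrative sizes, dropout off)
import torch
import torch.nn as nn

n_embd, dropout = 32, 0.0
net = nn.Sequential(
    nn.Linear(n_embd, 4 * n_embd),   # expand: (B,T,C) -> (B,T,4C)
    nn.ReLU(),
    nn.Linear(4 * n_embd, n_embd),   # project back: (B,T,4C) -> (B,T,C)
    nn.Dropout(dropout),
)
x = torch.randn(4, 8, n_embd)
print(net(x).shape)                  # torch.Size([4, 8, 32])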
MultiHeadAttention class
class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(head_size * num_heads, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out
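Each head works on its own head_size-dimensional slice; the outputs are concatenated along the channel dimension and projected back to n_embd. A small sketch of just the shape bookkeeping, with random tensors standing in for the heads’ outputs:

# Shape bookkeeping for multi-head attention (random stand-ins for Head outputs)
import torch
import torch.nn as nn

B, T, n_embd, n_head = 4, 8, 32, 4
head_size = n_embd // n_head                                   # 8

head_outputs = [torch.randn(B, T, head_size) for _ in range(n_head)]
out = torch.cat(head_outputs, dim=-1)                          # (B, T, n_head * head_size) = (B,T,C)
proj = nn.Linear(head_size * n_head, n_embd)
print(proj(out).shape)                                         # torch.Size([4, 8, 32])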
Head class, aka self-attention
class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # input of size (batch, time-step, channels)
        # output of size (batch, time-step, head size)
        B,T,C = x.shape
        k = self.key(x)   # (B,T,hs)
        q = self.query(x) # (B,T,hs)
        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x) # (B,T,hs)
        out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return out
Queries @ Keys: the masking and scaling trick
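If the slides are hard to read, here is the same trick on a tiny 4×4 example (the numbers are random): future positions are set to -inf before the softmax, so they get exactly zero weight while each row still sums to one. The second half of the sketch shows why the 1/sqrt(head_size) scaling matters: it keeps the pre-softmax scores at roughly unit variance, so softmax doesn’t saturate towards one-hot vectors.

# The masking trick on a 4x4 toy example
import torch
import torch.nn.functional as F

T = 4
tril = torch.tril(torch.ones(T, T))
wei = torch.randn(T, T)                              # pretend these are q @ k.T scores
wei = wei.masked_fill(tril == 0, float('-inf'))      # a token cannot see the future
wei = F.softmax(wei, dim=-1)
print(wei)                                           # lower triangular, each row sums to 1

# Why scale by 1/sqrt(head_size): with unit-variance q and k, the raw dot
# products have variance ~head_size; scaling brings it back down to ~1
head_size = 64
q, k = torch.randn(10_000, head_size), torch.randn(10_000, head_size)
scores = (q * k).sum(-1)
print(scores.var())                                  # ~64
print((scores * head_size**-0.5).var())              # ~1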
Let’s go through self-attention with an example
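And here is one head run end to end on a toy input, mirroring Head.forward above, so you can inspect the attention matrix yourself (all sizes and values are made up):

# One self-attention head on a toy input (illustrative sizes, random weights)
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, T, n_embd, head_size = 1, 4, 8, 4

x = torch.randn(B, T, n_embd)                        # pretend token + position embeddings
key = nn.Linear(n_embd, head_size, bias=False)
query = nn.Linear(n_embd, head_size, bias=False)
value = nn.Linear(n_embd, head_size, bias=False)

k, q, v = key(x), query(x), value(x)                 # each (B, T, hs)
wei = q @ k.transpose(-2, -1) * head_size**-0.5      # (B, T, T) affinities
tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
out = wei @ v                                        # (B, T, hs)

print(wei[0])      # row t = how token t spreads its attention over tokens 0..t
print(out.shape)   # torch.Size([1, 4, 4])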
Self-attention in plain English
This is just a note to self with bits copy-pasted from the lecture’s transcript, to put in plain English what self-attention is actually doing.
“Now, we don’t actually want this [scaled row-wise wei] to be all uniform because different tokens will find different other tokens more or less interesting and we want that [interaction] to be data-dependent. So for example, if I’m a vowel then I might be looking for consonants in my past, and maybe I’d like to know what those consonants are and I’d like that information to flow to me. […] This is the problem self-attention solves. The way it does it is the following: every single node or every single token at each position will emit two vectors. A query and a key. The query vector is what am I looking for, and the key vector is what I contain. The way we get affinities between these tokens in a sequence is by doing a dot product between the keys and the queries. My query dot-products with all the keys of all the other tokens and if the key and the query are sort of aligned they will interact to a very high amount and I will get to learn more about that specific token as opposed to any other token in the sequence […].
[…] This was the eighth token, and the eighth token knows what content it has and it knows at what position it is in, and based on that it creates a query: “hey I’m a vowel, I’m on the eighth position, I’m looking for any consonants at positions up to four”. And then all the nodes get to emit keys, and maybe one of the channels [embeddings values] could be “I am a consonant and I am in a position up to four”. That key would have a high number in that specific channel and that’s how the query and the key [when they dot-product] can find each other and create a high affinity. When this happens, softmax will end up aggregating a lot of information into the eighth position which will get to learn about the token it is paying attention to.”
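A tiny numerical illustration of the quote above (the channel meanings here are entirely made up): a query that is large in some channel has a high dot product with a key that is large in the same channel, and a low one with a key that isn’t. That dot product is all an “affinity” is.

# "What am I looking for" meets "what do I contain" (made-up channel meanings)
import torch

query     = torch.tensor([0.1, 4.0, 0.2, 0.1])   # e.g. "looking for a consonant before me"
key_match = torch.tensor([0.0, 3.5, 0.1, 0.3])   # e.g. "I am a consonant" -> same channel lights up
key_other = torch.tensor([2.5, 0.1, 0.0, 0.2])   # contains something else entirely

print(query @ key_match)   # large dot product -> high affinity, information flows
print(query @ key_other)   # small dot product -> low affinity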
That’s it. I hope reading this post was as useful for you as writing it was for me! Thanks for reading this far!