Bigram Language Model
Goal
Build the simplest trainable language model: predict the next token using only the current token.
1) Model definition
A bigram LM stores logits for every pair of tokens:
- Parameters: W of shape (V, V)
- Given token t, the logits for the next token are the row W[t]
In code, this is just an embedding table where embedding_dim = vocab_size.
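A minimal sketch of this lookup-table model, assuming PyTorch (the class name BigramLM is illustrative, not from the lesson):

    import torch.nn as nn

    class BigramLM(nn.Module):
        def __init__(self, vocab_size):
            super().__init__()
            # W has shape (V, V): one row of next-token logits per current token
            self.W = nn.Embedding(vocab_size, vocab_size)

        def forward(self, idx):
            # idx: (T,) token ids -> logits: (T, V)
            return self.W(idx)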
2) Cross-entropy loss
For each position:
- logits -> softmax -> probability of target
- loss is -log(p_target)
- total loss is the mean over positions
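Concretely, PyTorch's F.cross_entropy fuses the softmax and the negative log likelihood; the shapes T = 8 and V = 27 below are just for illustration:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(8, 27)           # (T, V): one row of logits per position
    targets = torch.randint(0, 27, (8,))  # (T,): next-token ids

    # Built-in: softmax -> pick p_target -> -log -> mean over positions
    loss = F.cross_entropy(logits, targets)

    # Manual version for comparison
    probs = F.softmax(logits, dim=-1)
    manual = -probs[torch.arange(8), targets].log().mean()
    assert torch.allclose(loss, manual)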
3) Training loop (SGD)
    for step in range(steps):
        logits = model(x)                 # (T, V)
        loss = cross_entropy(logits, y)
        for p in model.parameters():      # zero gradients from the previous step
            p.grad = 0.0
        loss.backward()
        for p in model.parameters():      # SGD update: step against the gradient
            p.data -= lr * p.grad
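Assembled into a runnable script, assuming PyTorch (the toy corpus, learning rate, and step count are placeholders; here gradients are cleared with zero_grad() and the update is done under torch.no_grad()):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    text = "hello world"                  # placeholder corpus
    chars = sorted(set(text))
    stoi = {c: i for i, c in enumerate(chars)}
    V = len(chars)

    ids = torch.tensor([stoi[c] for c in text])
    x, y = ids[:-1], ids[1:]              # current token -> next token

    model = nn.Embedding(V, V)            # the bigram table W
    lr, steps = 1.0, 200

    for step in range(steps):
        logits = model(x)                 # (T, V)
        loss = F.cross_entropy(logits, y)

        model.zero_grad()
        loss.backward()
        with torch.no_grad():
            for p in model.parameters():
                p -= lr * p.grad

    print(f"final loss: {loss.item():.3f}")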
4) Sampling
Autoregressive generation:
- start with a token id
- sample the next token from the softmax of the current token's logits
- append it and repeat
Dividing the logits by a temperature before the softmax controls randomness: values below 1 sharpen the distribution, values above 1 flatten it.
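A minimal generation loop, assuming the PyTorch model above (the helper name generate and the temperature default are illustrative):

    import torch
    import torch.nn.functional as F

    @torch.no_grad()
    def generate(model, start_id, n_tokens, temperature=1.0):
        ids = [start_id]
        for _ in range(n_tokens):
            logits = model(torch.tensor([ids[-1]]))      # (1, V): only the current token matters
            probs = F.softmax(logits / temperature, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1).item()
            ids.append(next_id)
        return ids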
5) Sanity checks
- Overfit a tiny dataset (loss should decrease quickly)
- Generated text should mirror bigram statistics of the corpus
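One way to check the second point, assuming the model and ids from the training sketch above:

    import torch
    import torch.nn.functional as F

    # Empirical next-token probabilities from corpus bigram counts
    counts = torch.zeros(V, V)
    for a, b in zip(ids[:-1], ids[1:]):
        counts[a, b] += 1
    empirical = counts / counts.sum(dim=1, keepdim=True).clamp(min=1)

    # Learned next-token probabilities: softmax over each row of W
    learned = F.softmax(model.weight, dim=-1)

    # Rows for tokens that appear as a current token should move toward the empirical rows
    print((learned - empirical).abs().max())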
Key takeaways
- A bigram LM is a learned lookup table of next-token logits.
- It is the smallest end-to-end language model.
- It sets up the training + generation pipeline used by GPT.
Next: expand the context window with an MLP language model.