Mini-GPT
End-to-end GPT-style language model: embeddings -> transformer blocks -> logits -> generation
Token IDs
-> Token Embedding + Position Embedding
-> TransformerBlock x N
-> Final LayerNorm
-> Linear -> Vocab logits
Shapes: input (T,), logits (T, V), where T is the sequence length and V the vocabulary size.
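A minimal PyTorch sketch of this pipeline, not the pack's exact code: MiniGPT, d_model, and n_layers are illustrative names, and TransformerBlock is the single-head block sketched at the end of this card.

import torch
import torch.nn as nn

class MiniGPT(nn.Module):
    def __init__(self, vocab_size, d_model, n_layers, max_seq_len):
        super().__init__()
        self.max_seq_len = max_seq_len
        self.tok_emb = nn.Embedding(vocab_size, d_model)   # token embedding
        self.pos_emb = nn.Embedding(max_seq_len, d_model)  # learned position embedding
        self.blocks = nn.ModuleList(
            [TransformerBlock(d_model) for _ in range(n_layers)]  # TransformerBlock x N
        )
        self.ln_f = nn.LayerNorm(d_model)                  # final LayerNorm
        self.head = nn.Linear(d_model, vocab_size)         # Linear -> vocab logits

    def forward(self, ids):                                # ids: (T,) token IDs
        pos = torch.arange(ids.shape[0])                   # position indices 0..T-1
        x = self.tok_emb(ids) + self.pos_emb(pos)          # (T, d_model)
        for block in self.blocks:
            x = block(x)
        return self.head(self.ln_f(x))                     # (T, V) logits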
Targets are the input shifted one position to the left, so the model predicts the next token at every position. Loss is mean cross-entropy over positions.
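A quick illustration of the shift, using a toy sequence (the token IDs are arbitrary):

tokens = torch.tensor([5, 2, 9, 7, 1])   # toy token IDs
x, y = tokens[:-1], tokens[1:]           # y[t] is the token that follows x[t]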
logits = model(x)                          # (T, V)
loss = F.cross_entropy(logits, y)          # mean over the T positions
optimizer.zero_grad()
loss.backward()
optimizer.step()
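And a minimal sampling loop for the generation step, assuming the MiniGPT sketch above; the 50-step default and the temperature parameter are illustrative choices, not the pack's.

@torch.no_grad()
def generate(model, ids, steps=50, temperature=1.0):
    for _ in range(steps):
        ctx = ids[-model.max_seq_len:]                # crop to the position-embedding budget
        logits = model(ctx)                           # (T, V)
        probs = torch.softmax(logits[-1] / temperature, dim=-1)  # last position only
        next_id = torch.multinomial(probs, num_samples=1)        # sample one token ID
        ids = torch.cat([ids, next_id])               # append and continue
    return ids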
This pack uses single-head attention for clarity.
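A sketch of one such block with single-head causal attention; the pre-norm residual layout and the 4x feed-forward width are common choices assumed here, not confirmed by the pack.

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class TransformerBlock(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.qkv = nn.Linear(d_model, 3 * d_model)    # one projection for Q, K, V
        self.proj = nn.Linear(d_model, d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, 4 * d_model), nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )

    def forward(self, x):                             # x: (T, d_model)
        h = self.ln1(x)
        q, k, v = self.qkv(h).chunk(3, dim=-1)        # each (T, d_model)
        att = q @ k.T / math.sqrt(k.shape[-1])        # (T, T) scaled dot-product scores
        mask = torch.triu(torch.ones_like(att, dtype=torch.bool), diagonal=1)
        att = att.masked_fill(mask, float('-inf'))    # causal: no attending to the future
        x = x + self.proj(F.softmax(att, dim=-1) @ v) # residual + attention output
        x = x + self.mlp(self.ln2(x))                 # residual + feed-forward
        return x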