Input: (T, D). Q, K, V: each (T, D).
Attention: softmax(Q K^T / sqrt(D)) @ V, softmax taken row-wise over the (T, T) scores. Output: (T, D).
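A minimal NumPy sketch of the formula above, for one sequence and one head, with no masking. The function name `attention` and the random test inputs are illustrative assumptions, not part of these notes.

```python
import numpy as np

def attention(Q, K, V):
    """Scaled dot-product attention for one sequence (no mask)."""
    D = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(D)                    # (T, T): query i vs key j
    scores = scores - scores.max(axis=-1, keepdims=True)  # stabilise exp
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)  # rows sum to 1
    return weights @ V                               # (T, D)

# Hypothetical inputs, just to check shapes.
rng = np.random.default_rng(0)
T, D = 4, 8
Q, K, V = (rng.standard_normal((T, D)) for _ in range(3))
out = attention(Q, K, V)
print(out.shape)  # (4, 8)
```

Each output row is a weighted average of the rows of V, so the output keeps the (T, D) shape of the input.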
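Causal mask: block future tokens by masking the scores at positions j > i (key after query) with -1e9 before the softmax, so their weights come out effectively zero.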
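One way the mask could be applied, sketched in NumPy. This version overwrites the future scores with -1e9 (adding -1e9 to them has the same effect); the helper name `causal_attention_weights` and its `neg` parameter are assumptions made for illustration.

```python
import numpy as np

def causal_attention_weights(Q, K, neg=-1e9):
    """Attention weights where query i may only see keys j <= i."""
    T, D = Q.shape
    scores = Q @ K.T / np.sqrt(D)                        # (T, T)
    future = np.triu(np.ones((T, T), dtype=bool), k=1)   # True where j > i
    scores = np.where(future, neg, scores)               # block future tokens
    scores = scores - scores.max(axis=-1, keepdims=True)
    w = np.exp(scores)
    return w / w.sum(axis=-1, keepdims=True)

# Hypothetical inputs.
rng = np.random.default_rng(0)
Q = rng.standard_normal((4, 8))
K = rng.standard_normal((4, 8))
w = causal_attention_weights(Q, K)
print(np.round(w, 3))  # lower-triangular: zeros above the diagonal
```

The diagonal is never masked, so every row keeps at least one finite score and the softmax stays well defined.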
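Numerical stability: subtract each row's max from the scores before exp. The softmax output is unchanged, but exp can no longer overflow.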
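A small NumPy sketch of why the shift matters. Subtracting a constant c from every entry multiplies numerator and denominator by exp(-c), so the result is identical, while the largest exponent becomes exp(0) = 1. The function name `softmax` and the example values are illustrative assumptions.

```python
import numpy as np

def softmax(x, axis=-1):
    """Softmax with the max subtracted first, so exp never sees a large input."""
    shifted = x - np.max(x, axis=axis, keepdims=True)  # largest entry becomes 0
    e = np.exp(shifted)
    return e / np.sum(e, axis=axis, keepdims=True)

# Scores this large overflow a naive exp(x) / sum(exp(x)) in float64.
big = np.array([1000.0, 1001.0, 1002.0])
small = np.array([0.0, 1.0, 2.0])
print(np.allclose(softmax(big), softmax(small)))  # True
```

The two inputs differ only by a constant offset of 1000, which is why they give the same probabilities.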