How do transformers work? No, I do not mean electrical transformers, even though that's an interesting subject in its own right. I mean transformer neural networks, the neural network architecture that led to the "AI explosion" we witnessed in the past decade.

There are, of course, countless papers and numerous Internet resources providing details and in-depth explanations. Still, I thought I'd try something different. Given my familiarity with the abstract index notation that is used especially in general relativity and Riemannian geometry, it seemed natural for me to try and document what transformers actually do using this mathematical language.

Inference

We start with data, which is in the form of a sequence of tokens $t_{i=1..C}$ where $C$ is the size of the context window.

The first step is {\em embedding}. Each of these tokens is used as a lookup index to find a corresponding embedding vector. In addition, the token position is also encoded in the form of positional embedding. We end up with a set of vectors,
\begin{align}
Z_{i=1..C}^{j=1..D}=E_{t_i}^j + P_i^j,\tag{1}
\end{align}
where $E_{k=1..V}^{j=1..D}$ is the set of embedding vectors ($V$ is the size of the token vocabulary, $D$ is the model dimension), and $P_i^j$ ($i=1..C$) is the set of positional embedding vectors. Before proceeding, these values are normalized: $Z_i^j\to (Z_i^j - \langle Z^j\rangle_i)/\sqrt{\sigma(Z^j)_i}$ where $\langle Z^j\rangle_i$ and $\sigma(Z^j)_i$ are the mean and variance of the components of the $i$-th vector, respectively.
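For readers who prefer code to indices, here is a minimal NumPy sketch of equation (1) and the subsequent normalization; the dimensions, random weights and token sequence are placeholders of my choosing, and a practical implementation would also add a small $\epsilon$ inside the square root for numerical stability:
\begin{verbatim}
import numpy as np

V, D, C = 300, 32, 96          # vocabulary size, model dimension, context length
rng = np.random.default_rng(0)

E = rng.normal(size=(V, D))    # token embedding vectors E_k^j
P = rng.normal(size=(C, D))    # positional embedding vectors P_i^j
t = rng.integers(0, V, size=C) # the token sequence t_i

# Equation (1): Z_i^j = E_{t_i}^j + P_i^j
Z = E[t] + P                   # shape (C, D)

# Normalize each position: subtract the mean and divide by the standard
# deviation of that position's D components.
Z = (Z - Z.mean(axis=1, keepdims=True)) / np.sqrt(Z.var(axis=1, keepdims=True))
\end{verbatim}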

Next, we have sets of "key", "query" and "value" vectors (never mind the popular explanations, let's just focus on the math):
\begin{align}
K_{j=1..D}^{k=1..K},\qquad Q_{j=1..D}^{k=1..K},\qquad V_{j=1..D}^{k=1..K},\tag{2}
\end{align}
where $K=D/H$ is the per-head dimension and $H$ is the number of heads. As an example, we may have a model dimension of $D=512$ divided into $H=8$ heads, so each of these heads will have a set of 64 512-dimensional $K$-vectors; same for $Q$ and $V$.

We next form the product $(Z\cdot K)\cdot(Q\cdot Z)$, that is, the $C\times C$-dimensional "similarity matrix" $S$ (Einstein summation convention implied):

\begin{align}
S_{im} = \frac{1}{\sqrt{K}}\delta_{kl}(Z_i^j K_j^l)(Z_m^n Q_n^k).\tag{3}
\end{align}
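As a sketch (with random stand-in weights and toy dimensions), the similarity matrix of equation (3) amounts to a pair of matrix products followed by scaling:
\begin{verbatim}
import numpy as np

C, D, H = 6, 512, 8
Kdim = D // H                      # per-head dimension K = 64
rng = np.random.default_rng(0)

Z  = rng.normal(size=(C, D))       # normalized embedding vectors Z_i^j
Kw = rng.normal(size=(D, Kdim))    # "key" weights   K_j^k
Qw = rng.normal(size=(D, Kdim))    # "query" weights Q_j^k

# Equation (3): S_im = (1/sqrt(K)) sum_k (Z_i . K)^k (Z_m . Q)^k
keys    = Z @ Kw                   # (C, K)
queries = Z @ Qw                   # (C, K)
S = keys @ queries.T / np.sqrt(Kdim)   # the C x C similarity matrix
\end{verbatim}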

If the goal is to create an autoregressive model (e.g., a conversational chatbot), we then mask the upper triangle of the similarity matrix:
\begin{align}
S'_{im}=S_{im}+M_{im},\tag{4}
\end{align}
where
\begin{align}
M_{im}=\begin{cases}0,&\text{if }m\le i,\\-\infty&\text{otherwise}
\end{cases}.\tag{5}
\end{align}


This is then subjected to the softmax function $\sigma$ (applied row-wise, over the index $m$) before being multiplied into the value vectors $Z\cdot V$ to form the "self-attention":
\begin{align}
Y_{i=1..C}^{k=1..K} = \sigma(S'_{im})Z_m^j V_j^k.\tag{6}
\end{align}
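Continuing the sketch, the mask of equation (5), the row-wise softmax and equation (6) look like this in NumPy (again with random stand-in values):
\begin{verbatim}
import numpy as np

C, D, Kdim = 6, 512, 64
rng = np.random.default_rng(0)
Z  = rng.normal(size=(C, D))        # normalized embedding vectors
S  = rng.normal(size=(C, C))        # stands in for the similarity matrix of eq. (3)
Vw = rng.normal(size=(D, Kdim))     # "value" weights V_j^k

# Equations (4)-(5): additive causal mask, -inf wherever m > i
M = np.where(np.arange(C)[None, :] > np.arange(C)[:, None], -np.inf, 0.0)
S_masked = S + M

# Row-wise softmax over the position index m (shifted for numerical stability)
w = np.exp(S_masked - S_masked.max(axis=1, keepdims=True))
w /= w.sum(axis=1, keepdims=True)

# Equation (6): attention output, using the value vectors Z . V
Y = w @ (Z @ Vw)                    # shape (C, K)
\end{verbatim}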
If we have multiple heads, ${}^{(h)}Y_i^k$ with $h=1..H$, we concatenate them into a multihead attention matrix,
\begin{align}
\hat{Y}_{i=1..C}^{j=1..D}={}^{(1)}Y_{i}^k \oplus ... \oplus {}^{(H)}Y_{i}^k.\tag{7}
\end{align}

The (multihead) attention (output projection) is then formed using the weights $W_{l=1..D}^{j=1..D}$:
\begin{align}
A_i^j=W_l^j \hat{Y}_i^l.\tag{8}
\end{align}
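Equations (7) and (8), concatenation of the per-head outputs followed by the output projection, might be sketched as follows (random stand-ins once more):
\begin{verbatim}
import numpy as np

C, D, H = 6, 512, 8
Kdim = D // H
rng = np.random.default_rng(0)

# Per-head attention outputs (h)Y_i^k, here just random placeholders
Y_heads = [rng.normal(size=(C, Kdim)) for _ in range(H)]

# Equation (7): concatenate along the feature axis to obtain a (C, D) matrix
Y_hat = np.concatenate(Y_heads, axis=1)

# Equation (8): output projection A_i^j = W_l^j Yhat_i^l
W = rng.normal(size=(D, D))         # indexed as W[l, j]
A = Y_hat @ W                       # shape (C, D)
\end{verbatim}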

The attention is then treated as a perturbation of the embedding:
\begin{align}
\hat{Z}^j_i = Z^j_i + A^j_i.\tag{9}
\end{align}
Once again, we normalize: $\bar{Z}_i^j = (\hat{Z}_i^j - \langle \hat{Z}^j\rangle_i)/\sqrt{\sigma(\hat{Z}^j)_i}$.

Finally, the feed-forward step involves weights $F_k^l$, $G_l^j$ and biases $b^l$, $c^j$, with $l=1..F$ representing the feed-forward dimensionality, and with ReLU (rectified linear unit) activation in-between:

\begin{align}
Z'^j_i = \hat{Z}^j_i + 
G^j_l{\rm max}(0,F^l_k\bar{Z}^k_i + b^l_1)+c^j_1,\tag{10}
\end{align}
where, in a slight abuse of notation, we used $b^l_1$ to indicate that the same vector $b^l$ is used for all lower index positions $i=1..C$.
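A sketch of the renormalization and the feed-forward step of equation (10), with random stand-in weights and biases:
\begin{verbatim}
import numpy as np

C, D, F = 6, 512, 2048
rng = np.random.default_rng(0)

Z_hat = rng.normal(size=(C, D))     # stands in for embedding plus attention, eq. (9)

# Renormalize each position
Z_bar = (Z_hat - Z_hat.mean(axis=1, keepdims=True)) \
        / np.sqrt(Z_hat.var(axis=1, keepdims=True))

Fw = rng.normal(size=(D, F))        # F_k^l
Gw = rng.normal(size=(F, D))        # G_l^j
b  = rng.normal(size=(F,))          # b^l (broadcast over all positions)
c  = rng.normal(size=(D,))          # c^j (broadcast over all positions)

# Equation (10): ReLU in the middle, residual connection around the block
Z_prime = Z_hat + np.maximum(0.0, Z_bar @ Fw + b) @ Gw + c
\end{verbatim}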

With this, our multihead transformer block is complete. An actual implementation may have several such transformer layers before producing the final output token sequence: $Z^j_i \to Z'^j_i \to Z''^j_i ...$

After the final step (dropping multiple primes in the notation now) we have $Z'^j_i$, which we now project to "logits", a set of unnormalized scores across the token vocabulary:
\begin{align}
L^{m=1..V}_{i=1..C}=Z'^j_i O^m_j,\tag{11}
\end{align}
where $O^{m=1..V}_{j=1..D}$ are the output weights. On output, a softmax turns the logits at the position of interest into a probability distribution, and we select a specific token based on that distribution.
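The output projection of equation (11), together with the softmax and the selection of the next token, can be sketched as follows (dimensions and weights again being placeholders):
\begin{verbatim}
import numpy as np

C, D, Vocab = 6, 512, 300
rng = np.random.default_rng(0)
Zp = rng.normal(size=(C, D))        # stands in for the output of the final layer
O  = rng.normal(size=(D, Vocab))    # output weights O_j^m

# Equation (11): logits at every position
L = Zp @ O                          # shape (C, Vocab)

# Softmax at the last position yields a probability distribution over the
# vocabulary; sample the next token from it (or take the argmax for greedy decoding).
p = np.exp(L[-1] - L[-1].max())
p /= p.sum()
next_token = rng.choice(Vocab, p=p)
\end{verbatim}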

Backpropagation

Training is accomplished by gradient descent, which requires computing the gradients of a loss function with respect to all learned parameters. First, therefore, a loss function must be defined.

For each token position $i$, the logits determine (via softmax) a probability distribution. We then form
\begin{align}
{\cal L}_i=-\log \left(\frac{\exp L_i^t}{\sum_j\exp L_i^j}\right),\tag{12}
\end{align}
where $t$ is the index of the correct token at the $i$-th position.

Conceptually, this amounts to calculating the cross-entropy loss
\begin{align}
H(p,q)=-\sum_j p^j\log q^j,\tag{13}
\end{align}
where
\begin{align}
q^j=\frac{\exp L_i^j}{\sum_k\exp L_i^k}\tag{14}
\end{align}
represents the predicted probability distribution at the $i$-th position in the context window, whereas
\begin{align}
p^j=\begin{cases}1\qquad\text{if }j=t,\\0\qquad\text{otherwise}
\end{cases}\tag{15}
\end{align}
is the probability "distribution" of the expected result (target), with $t$ representing the index of the correct token at the $i$-th position.

The final loss function is then taken as the average loss across all counted positions in the context window (typically, it's best not to include padding in this calculation):
\begin{align}
{\cal L}=\langle{\cal L}_i\rangle.\tag{16}
\end{align}
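In code, equations (12) through (16) amount to a numerically stable log-softmax followed by averaging; a sketch with random logits and targets:
\begin{verbatim}
import numpy as np

C, Vocab = 6, 300
rng = np.random.default_rng(0)
L = rng.normal(size=(C, Vocab))          # logits at each position
t = rng.integers(0, Vocab, size=C)       # correct token index at each position

# Equation (12): per-position cross-entropy of softmax(logits) against the target
shifted = L - L.max(axis=1, keepdims=True)
log_q = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
loss_i = -log_q[np.arange(C), t]

# Equation (16): average over the (non-padding) positions
loss = loss_i.mean()
\end{verbatim}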

Backpropagation begins with the quantity ${\cal L}$ and its partial derivatives. Without going into excess detail, we simply follow the forward inference mechanism backward, going through chains of Jacobian matrices to form the following quantities:
\begin{align}
&\frac{\partial{\cal L}}{\partial L_i^m},\qquad
\frac{\partial{\cal L}}{\partial O_j^m},\qquad
\frac{\partial{\cal L}}{\partial G^j_l},\qquad
\frac{\partial{\cal L}}{\partial c^j},\qquad
\frac{\partial{\cal L}}{\partial F^l_k},\qquad
\frac{\partial{\cal L}}{\partial b^l},\qquad\nonumber\\
&\frac{\partial{\cal L}}{\partial W_l^j},\qquad
\frac{\partial{\cal L}}{\partial V_j^k},\qquad
\frac{\partial{\cal L}}{\partial Q_j^k},\qquad
\frac{\partial{\cal L}}{\partial K_j^k},\qquad
\frac{\partial{\cal L}}{\partial E^j_{t_i}}.\qquad\tag{17}
\end{align}
In the actual implementation, these quantities are formed using cached interim results from the forward pass rather than computationally costly recalculation.

Finally, in each iteration, we take these quantities, scale them by the learning rate $\rho$, and use them to adjust the corresponding parameters:
\begin{align}
X^i_j\to X^i_j - \rho\frac{\partial{\cal L}}{\partial X^i_j}.\tag{18}
\end{align}
The learning rate itself may be subject to scaling, e.g., for per-layer, per-head, or per-transformer quantities.
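The update of equation (18) is just plain gradient descent; practical implementations often use more elaborate optimizers (e.g., Adam) and learning-rate schedules, but the bare-bones version is a one-liner:
\begin{verbatim}
import numpy as np

rng = np.random.default_rng(0)
X      = rng.normal(size=(64, 32))   # some learned parameter matrix
grad_X = rng.normal(size=(64, 32))   # dL/dX obtained from backpropagation
rho    = 1e-3                        # learning rate

# Equation (18)
X -= rho * grad_X
\end{verbatim}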

Model parameters

The number of model parameters can be calculated from the model architecture:

\begin{align}
N=[(4D+1)D+(2D+1)F]L + (C+2V)D.\tag{19}
\end{align}

For instance, a "tiny" model with $D=32$ model dimensions, $F=64$ feedforward dimensions, $L=2$ layers, a vocabulary size of $D=300$ and a context length of $C=96$ tokens yields 38,848 parameters.