Logit Prisms: Decomposing Transformer Outputs for Mechanistic Interpretability

Author

Thông T. Nguyễn

Published

June 17, 2024

Abstract
We introduce a straightforward yet effective method to break down transformer outputs into individual components. By treating the model’s non-linear activations as constants, we can decompose the output in a linear fashion, expressing it as a sum of contributions. These contributions can be easily calculated using linear projections. We call this approach “logit prisms” and apply it to analyze the residual streams, attention layers, and MLP layers within transformer models. Through two illustrative examples, we demonstrate how these prisms provide valuable insights into the inner workings of the gemma-2b model.

1 Introduction

Figure 1: An illustration of a “logit” prism decomposing logits into different components (generated by DALL-E)

The logit lens (nostalgebraist 2020) is a simple yet powerful tool for understanding how transformer models (Vaswani et al. 2017; Brown et al. 2020) make decisions. In this work, we extend the logit lens approach in a mathematically rigorous and effective way. By treating certain parts of the network activations as constants, we can leverage the linear properties within the network to break down the logit output into individual component contributions. Using this principle, we introduce simple “prisms” for the residual stream, attention layers, and MLP layers. These prisms allow us to calculate how much each component contributes to the final logit output.

Our approach can be thought of as applying a series of prisms to the transformer network. Each prism in the sequence splits the logits from the previous prism into separate components. This enables us to see how different parts of the model—such as attention heads, MLP neurons, or input embeddings—influence the final output.

To showcase the power of our method, we present two illustrative examples:

  • In the first example, we examine how the gemma-2b model performs the simple factual retrieval task of retrieving a capital city from a country name. Our findings suggest that the model learns to encode information about country names and their capital cities in a way that allows the network to easily convert country embeddings into capital city unembeddings through a linear projection.

  • The second example explores how the gemma-2b model adds two small numbers (ranging from 1 to 9). We uncover interesting insights into the workings of MLP layers. The network predicts output numbers using interpretable templates learned by MLP neurons. When multiple neurons are activated simultaneously, their predictions interfere with each other, ultimately producing a final prediction that peaks at the correct number.

2 Method

flowchart LR
  A[Embed] --> I(( ))
  subgraph repeated N times
  I --> AttnNorm[Norm]
  AttnNorm --> B[Attention]
  B --> C((+))
  I --> C
  C --> MlpNorm[Norm]
  MlpNorm --> D[MLP]
  D --> E((+))
  C --> E
  end
  E --> L[Norm]
  L --> Unembed
Figure 2: A typical decoder-only transformer network where the residual stream is iteratively refined by a sequence of attention and MLP layers. There are five main components: the input embedding (Embed), the normalization layers (Norm), the Attention layers, the MLP layers, and the unembedding layer (Unembed).
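To make this structure concrete, here is a minimal sketch of the pre-norm residual stream in PyTorch, assuming RMSNorm (as used by gemma-2b) and hypothetical `attn` and `mlp` callables that each apply their own input normalization:

```python
import torch

def rms_norm(x, w_norm, eps=1e-6):
    # RMSNorm: divide by the root mean square of x, then scale by w_norm.
    s = torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return (x / s) * w_norm

def forward(w_embed, blocks, w_norm_final, W_unembed):
    # The residual stream h starts at the token embedding and is refined
    # additively by each attention and MLP block (Figure 2).
    h = w_embed
    for attn, mlp in blocks:          # hypothetical callables with their own pre-norm
        h = h + attn(h)
        h = h + mlp(h)
    return W_unembed @ rms_norm(h, w_norm_final)   # final norm, then unembedding
```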

We introduce simple “prisms” that allow us to break down the output of transformer networks into individual components. The key idea is to treat nonlinear activations as constants, which enables us to calculate the contribution of any component in the network using a series of linear transformations. In the following subsections, we explore this approach in more detail.

2.1 Residual Stream Decomposition

In a typical decoder-only transformer architecture (see Figure 2), the output logits can be expressed as: \[ \text{logits} = W_\text{unembed} \cdot \text{diag}\left(w_{\text{norm}}\right) \cdot \frac{w_{\text{embed}} + \sum_{i=1}^N (a_i + m_i)}{s} \] Here, \(W_\text{unembed} \in \mathbb{R}^{V\times d}\) is the unembedding matrix, \(w_\text{embed}\) is the token embedding vector, \(a_i\) and \(m_i\) denote the attention and MLP outputs at the \(i\)-th layer respectively, \(s\) is a normalization factor (usually the root mean square of the numerator), and \(w_{\text{norm}}\) is the scaling vector of the last normalization layer.

By treating \(s\) as a constant, we can break down the \(\text{logits}\) into individual terms for each embed, attention and MLP layer: \[ \text{logits} = P \cdot w_{\text{embed}} + P \cdot a_1 + P \cdot m_1 + \cdots + P \cdot a_N + P \cdot m_N \] where \(P\) is the projection matrix: \[ P = W_{\text{unembed}} \cdot \text{diag}\left(w_{\text{norm}}\right) \cdot \left(s^{-1} I\right) \] Each term, like \(P \cdot a_1\), represents the individual contribution of the corresponding residual vector (\(a_1\)) to the final logit score. The cumulative sum \(P \cdot w_{\text{embed}} + \sum_{i=1}^\ell P \cdot \left( a_i + m_i \right)\) gives the model’s logit score up to layer \(\ell\).
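The decomposition translates directly into code. Below is a minimal sketch of the residual prism, assuming the per-layer residual components have been cached (e.g., with forward hooks) as a list `residual_parts = [w_embed, a_1, m_1, ..., a_N, m_N]` and that the final normalization is a bias-free RMSNorm as in gemma-2b; all tensor names are illustrative:

```python
import torch

def residual_prism(W_unembed, w_norm, residual_parts, eps=1e-6):
    """Decompose the final logits into one term per residual-stream component.

    residual_parts: list of d-dimensional vectors [w_embed, a_1, m_1, ..., a_N, m_N]
    whose sum is the final hidden state.
    """
    h = torch.stack(residual_parts).sum(0)            # final hidden state
    s = torch.sqrt(h.pow(2).mean() + eps)              # RMSNorm factor, treated as a constant
    P = W_unembed * w_norm / s                         # P = W_unembed · diag(w_norm) · (1/s) I
    contributions = [P @ r for r in residual_parts]    # one logit vector per component
    # Sanity check: the per-component terms sum back to the full logits.
    assert torch.allclose(sum(contributions), P @ h, atol=1e-4)
    return contributions
```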

2.2 Attention Decomposition

We can naturally break down the attention output into a sum over its attention heads: \[ a_\ell = \sum_{i=1}^{\text{\#heads}} H_\ell^i \] Here, \(H_\ell^i\) is the \(i\)-th attention head’s output at layer \(\ell\). Additionally, an attention head’s output is a weighted sum of the \(OV\) circuit outputs for each token in the sequence.1 The output of the \(i\)-th head at layer \(\ell\) is: \[ H^i_\ell = \sum_{p} \alpha_p \cdot OV_\ell^{i} \cdot \text{diag}\left(w_{\text{norm}}\right) \cdot \frac{h_{\ell - 1}^p}{s^p_{\ell-1}} \] In this formula, \(\alpha_p\) is the attention weight from the current token to the earlier token \(p\), \(OV_\ell^i\) is the OV matrix of the \(i\)-th attention head at layer \(\ell\), \(h_{\ell-1}^p\) is the hidden state of token \(p\) from the previous layer \(\ell-1\), and \(s^p_{\ell-1}\) is the corresponding normalization factor.

By treating the attention weight \(\alpha_p\) and normalization factor \(s^p_{\ell-1}\) as constants, we can express the attention output as a sum of linear projections of the previous layer’s hidden state \(h_{\ell-1}^p\). To find the contribution of token \(p\) via attention head \(i\), we use the following projection matrix: \[ P^{a_i}_\ell = P \cdot \left(\alpha_p I\right) \cdot OV_\ell^{i} \cdot \text{diag}\left(w_{\text{norm}}\right) \cdot \left({s^p_{\ell-1}}^{-1} I \right) \] Note that we can further decompose \(h_{\ell-1}^p\) into the sum of the token embedding and all previous layers’ residual outputs using residual stream decomposition.
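A minimal sketch of this head-level decomposition, assuming the head’s \(W_V\) and \(W_O\) slices, the current query’s attention weights, and the hidden states entering the layer have been cached; names are illustrative:

```python
import torch

def attention_prism(P, W_O_head, W_V_head, attn_weights, w_norm, h_prev, eps=1e-6):
    """Logit contribution of every source token through one attention head.

    P            : (V, d)      projection matrix from Section 2.1
    W_O_head     : (d, d_head) and W_V_head: (d_head, d) -- the head's OV circuit
    attn_weights : (T,)        attention weights from the current query to each source token
    h_prev       : (T, d)      hidden states entering the attention layer
    """
    OV = W_O_head @ W_V_head                                   # (d, d) OV matrix
    s = torch.sqrt(h_prev.pow(2).mean(-1, keepdim=True) + eps)  # per-token norm factor
    normed = (h_prev / s) * w_norm                              # pre-norm input to the head
    per_token = attn_weights[:, None] * (normed @ OV.T)         # (T, d) value written per source token
    return per_token @ P.T                                      # (T, V) logit contribution per token
```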

2.3 MLP Decomposition

In transformer networks, the MLP layers consist of two linear transformations, \(W_{\text{up}}\) and \(W_{\text{down}}\), with a non-linear function \(g\) applied between them. The output of an MLP layer \(\ell\) can be expressed as: \[ m_\ell = W_{\text{down}} \cdot g\!\left(W_{\text{up}} \cdot \text{diag}\left(w_{\text{norm}}\right) \cdot \frac{h^a_\ell}{s^a_\ell}\right) \] Here, \(h^a_\ell\) is the hidden state after the preceding attention sub-layer, \(w_{\text{norm}}\) is the scaling vector of the MLP’s normalization layer, and \(s^a_\ell\) is its normalization factor.

The nonlinear point-wise function \(g\) allows neural networks to learn complex transformations that cannot be represented by a single linear map. We refer to the input dimensions of \(g\) as neurons. Writing \(W^i_{\text{up}}\) for the \(i\)-th row of \(W_{\text{up}}\) and \(W^i_{\text{down}}\) for the \(i\)-th column of \(W_{\text{down}}\), we can break down the MLP output \(m_\ell\) into a sum over individual neuron contributions: \[ m_\ell = \sum_{i=1}^{\text{\#neurons}} W^i_{\text{down}} \cdot g\!\left(W^i_{\text{up}} \cdot \text{diag}\left(w_{\text{norm}}\right) \cdot \frac{h^a_\ell}{s^a_\ell}\right) \] To compute the contribution of the \(i\)-th neuron in MLP layer \(\ell\) to the logit output, we treat \(s_\ell^a\) as a constant and freeze the nonlinearity at neuron \(i\) to a constant gain \(g^i\), which gives the following projection matrix: \[ P^{m_i}_\ell = P \cdot W^i_{\text{down}} \cdot \left(g^i\right) \cdot W^i_{\text{up}} \cdot \text{diag} \left(w_{\text{norm}}\right) \cdot \left(\left(s^a_\ell\right)^{-1} I \right) \] Using this projection matrix, we can pinpoint how much a single MLP neuron contributes to the model’s final output logits. The same projection matrix also lets us trace, via the residual stream decomposition, how residual vectors from earlier layers influence the final output logits through that specific neuron.
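A minimal sketch of the MLP prism, assuming a plain up/down MLP with a point-wise activation (gemma-2b actually uses a gated MLP with an extra up projection, but the per-neuron decomposition is analogous); tensor names are illustrative:

```python
import torch
import torch.nn.functional as F

def mlp_prism(P, W_up, W_down, w_norm, h_attn, act=F.gelu, eps=1e-6):
    """Logit contribution of every MLP neuron at one layer.

    P      : (V, d)            projection matrix from Section 2.1
    W_up   : (n_neurons, d)    and W_down: (d, n_neurons)
    h_attn : (d,)              hidden state entering the MLP block
    """
    s = torch.sqrt(h_attn.pow(2).mean() + eps)
    x = (h_attn / s) * w_norm              # pre-norm MLP input
    g = act(W_up @ x)                      # (n_neurons,) neuron activations, treated as constants
    per_neuron = (P @ W_down) * g          # (V, n_neurons): column i = g_i · P · W_down[:, i]
    # Sanity check: neuron contributions sum to the full MLP logit contribution.
    assert torch.allclose(per_neuron.sum(-1), P @ (W_down @ g), atol=1e-4)
    return per_neuron
```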

3 Examples

In this section, we apply the prisms proposed earlier to explore how the gemma-2b model works internally in two examples. We use the gemma-2b model because it’s small enough to run on a standard PC without a dedicated GPU.

Retrieving capital city. First, let’s see how the model retrieves factual information to answer this simple question:

The capital city of France is ___

The model correctly predicts ▁Paris as the most likely next token. To understand how it arrives at this prediction, we’ll use the prisms from the previous section.

We start by using the residual prism to plot how much each layer contributes to the logit output for several candidate tokens (different capital cities). Comparing the prediction logit of the right answer to reasonable alternatives can reveal important information about the network’s decision process.
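As a usage sketch building on the residual prism above (with a hypothetical Hugging Face-style tokenizer), the per-layer contributions for a handful of candidate tokens can be gathered into a small matrix for plotting:

```python
import torch

def per_layer_candidate_logits(contributions, tokenizer, candidates):
    """Collect each residual component's logit contribution for a few candidate tokens.

    contributions : list of (V,) logit vectors, one per component (see residual prism sketch)
    candidates    : e.g. ["▁Paris", "▁London", "▁Berlin", "▁Rome"]
    Returns a (num_components, num_candidates) tensor, ready to plot as in Figure 3.
    """
    ids = [tokenizer.convert_tokens_to_ids(c) for c in candidates]
    return torch.stack([c[ids] for c in contributions])
```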

Figure 3 shows each layer’s logit contribution for multiple candidate tokens. Some strong signals stand out, with large positive and negative contributions in the first and last layers, which suggests these layers play key roles in the model’s prediction. Interestingly, there is a strong positive contribution at the start, followed by an equally strong negative contribution in the next layer (the first attention output). This is likely because gemma-2b ties its embedding and unembedding matrices, so the input token strongly predicts itself as the output (a vector’s dot product with itself is large), and the network has to cancel this out with a strong negative contribution in the next layer.

Figure 3B zooms in to compare logit contributions at each layer for different targets. The \(a_{15}\) contribution stands out, separating ▁Paris from the other candidates: at this layer, the attention output aligns much more closely with the unembedding vector of ▁Paris than with those of the other candidates.

Figure 3: Logit contribution of each layer for different target tokens. Figure A shows the contributions of all layers, while Figure B zooms in on the contribution of the last layers.

We think \(a_{15}\) reads the correct output value from somewhere else via the attention mechanism, so we use the attention prism to decompose \(a_{15}\) into smaller pieces. Figure 4 shows how much each input token influences the output logits via attention layer 15. The ▁France token heavily affects the output through attention head 6 of that layer, which makes sense: ▁France is the token that should tell the network which capital city to output.

Figure 4: Logit contribution of each input token through attention heads at layer 15.

Next, we again use the residual prism to decompose the attention head 6 logits into smaller pieces. Figure 5 shows how the residual outputs from all previous layers at the ▁France token contribute to the output logits via attention head 6. Interestingly, the ▁France embedding vector contributes the most. This indicates that the embedding vector of ▁France already encodes information about its capital city, and that attention head 6 can read this information out easily.

Figure 5: Logit contribution of all residual outputs through attention head 6 at layer 15.

One direct result of using prisms is that we obtain a linear projection that maps the ▁France embedding vector to the capital city candidates’ unembedding vectors. We suspect this linear projection is meaningful not just for ▁France, but has a similar effect on other country tokens. To check this hypothesis, we apply the same projection matrix to other countries’ embedding vectors (see the sketch below). Figure 6 shows that the same matrix does indeed project other country names to their respective capitals.
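As a sketch of this check, suppose the relevant prisms have been composed into a single matrix `M` that maps an input embedding directly to output logits (for instance, \(P\) composed with head 6’s frozen attention weight and OV circuit); the helper below, with hypothetical names, applies `M` to other country embeddings and reports where each capital ranks in the projected logits:

```python
import torch

def check_country_capital_map(M, embed, tokenizer, pairs):
    """Apply the composed projection M (embedding -> logits) to several country tokens
    and report the rank of each capital city among all vocabulary logits.

    M     : (V, d) composed matrix, e.g. P @ (alpha * OV) @ diag(w_norm) / s
    embed : (V, d) input embedding table (tied with the unembedding in gemma-2b)
    pairs : list of (country_token, capital_token) strings, e.g. ("▁Japan", "▁Tokyo")
    """
    for country, capital in pairs:
        c_id = tokenizer.convert_tokens_to_ids(country)
        k_id = tokenizer.convert_tokens_to_ids(capital)
        logits = M @ embed[c_id]                     # project country embedding to logit space
        rank = (logits > logits[k_id]).sum().item()  # number of tokens scoring above the capital
        print(f"{country} -> {capital}: rank {rank}, logit {logits[k_id].item():.2f}")
```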

This suggests that the network learns to represent country names and capital city names in such a way that it can easily transform a country embedding to the capital city unembedding using a linear projection. We hypothesize that this observation can be generalized to other relations encoded by the network as well.

Figure 6: Linear projection from country embedding to capital city’s logit.

Digit addition. Let’s explore how gemma-2b performs arithmetic by asking it to complete the following:

7+2=_

The model correctly predicts 9 as the next token. To understand how it achieves this, we employ our prisms toolbox. First, using the residual prism, we decompose the residual stream and examine the contributions of different layers for target tokens ranging from 0 to 9 (Figure 7). The MLP layer at layer 16 (\(m_{16}\)) stands out, predicting 9 with a significantly higher logit value than other candidates. This substantial gap is unique to \(m_{16}\), indicating its crucial role in the model’s prediction of 9.

Figure 7: Contributions of different layers to the logit outputs of different candidates (from 0 to 9) using the residual prism.

Next, we use the MLP prism to identify which neurons in \(m_{16}\) drive this behavior. Decomposing \(m_{16}\) into contributions from its 16,384 neurons, we find that most are inactive. Extracting the top active neurons, we observe that they account for the majority of \(m_{16}\)’s activity. Figure 8 shows these top neurons’ contributions to candidates from 0 to 9, revealing distinct patterns for each neuron. For example, neuron 10029 selectively differentiates odd and even numbers. Neuron 11042 selectively predicts 7, while neuron 12552 selectively avoids predicting 7. Neurons 15156 and 2363 show sine-wave patterns. While no single neuron dominantly predicts 9, the combined effect of these neurons’ predictions peaks at 9.
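A sketch of how the top neurons can be extracted, assuming `per_neuron_logits` is the (vocabulary × neurons) contribution matrix produced by the MLP prism at layer 16 and `digit_token_ids` holds the token ids of 0 through 9; names are illustrative:

```python
import torch

def top_digit_neurons(per_neuron_logits, digit_token_ids, k=8):
    """Rank neurons of one MLP layer by their total logit contribution to the digit tokens.

    per_neuron_logits : (V, n_neurons) output of the MLP prism at layer 16
    digit_token_ids   : token ids of "0" ... "9"
    Returns the indices of the top-k neurons and their (k, 10) digit-logit patterns.
    """
    digit_logits = per_neuron_logits[digit_token_ids]   # (10, n_neurons)
    strength = digit_logits.abs().sum(0)                # overall activity per neuron
    top = torch.topk(strength, k).indices
    return top, digit_logits[:, top].T                  # (k, 10) pattern per top neuron
```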

Figure 8: Top neuron contributions for different targets ranging from 0 to 9.

Note that the neurons’ contributions to the target logits are simply linear projections onto different target token unembedding vectors. The neuron activity patterns in Figure 8 are likely encoded in the target token unembeddings; as such, these patterns can be easily extracted using a linear projection. When we visualize the digit unembedding space in 2D (Figure 9), we discover that the numbers form a heart-like shape with reflectional symmetry around the 0-5 axis.

Figure 9: 2D projection of digit unembedding vectors. The unembedding vectors are projected to 2D space using PCA. Each point represents a digit, and the points are connected in numerical order.
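For reference, a minimal sketch of the projection behind Figure 9, assuming access to `W_unembed` and the digit token ids (PCA computed via an SVD of the centered unembedding vectors):

```python
import torch

def digit_unembedding_pca(W_unembed, digit_token_ids):
    """Project the ten digit unembedding vectors onto their first two principal components."""
    X = W_unembed[digit_token_ids]        # (10, d) unembedding vectors of "0" .. "9"
    X = X - X.mean(0)                     # center before PCA
    U, S, Vh = torch.linalg.svd(X, full_matrices=False)
    return X @ Vh[:2].T                   # (10, 2) coordinates for plotting
```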

Our hypothesis is that transformer networks encode templates for outputs in the unembedding space. The MLP layer then selectively reads these templates based on their linear projection \(W_\text{down}\). By triggering a specific combination of neurons, each representing a template, the network ensures the logits reach their maximum value for the tokens with the highest probability.

5 Conclusion

This paper introduces logit prisms, a simple but effective way to break down transformer outputs, making them easier to interpret. With logit prisms, we can closely examine how the input embeddings, attention heads, and MLP neurons each contribute to the final output. Applying logit prisms to the gemma-2b model reveals valuable insights into how it works internally.

References

Belrose, Nora, Zach Furman, Logan Smith, Danny Halawi, Igor Ostrovsky, Lev McKinney, Stella Biderman, and Jacob Steinhardt. 2023. “Eliciting Latent Predictions from Transformers with the Tuned Lens.” https://arxiv.org/abs/2303.08112.
Brown, Tom B., Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared Kaplan, Prafulla Dhariwal, Arvind Neelakantan, et al. 2020. “Language Models Are Few-Shot Learners.” https://arxiv.org/abs/2005.14165.
Elhage, Nelson, Neel Nanda, Catherine Olsson, Tom Henighan, Nicholas Joseph, Ben Mann, Amanda Askell, et al. 2021. “A Mathematical Framework for Transformer Circuits.” Transformer Circuits Thread.
Geva, Mor, Roei Schuster, Jonathan Berant, and Omer Levy. 2021. “Transformer Feed-Forward Layers Are Key-Value Memories.” https://arxiv.org/abs/2012.14913.
Mikolov, Tomas, Kai Chen, Greg Corrado, and Jeffrey Dean. 2013. “Efficient Estimation of Word Representations in Vector Space.” https://arxiv.org/abs/1301.3781.
Mirzadeh, Iman, Keivan Alizadeh, Sachin Mehta, Carlo C Del Mundo, Oncel Tuzel, Golnoosh Samei, Mohammad Rastegari, and Mehrdad Farajtabar. 2023. “ReLU Strikes Back: Exploiting Activation Sparsity in Large Language Models.” https://arxiv.org/abs/2310.04564.
Nanda, Neel, Lawrence Chan, Tom Lieberum, Jess Smith, and Jacob Steinhardt. 2023. “Progress Measures for Grokking via Mechanistic Interpretability.” https://arxiv.org/abs/2301.05217.
nostalgebraist. 2020. “Interpreting GPT: The Logit Lens.” https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens.
Vaswani, Ashish, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, and Illia Polosukhin. 2017. “Attention Is All You Need.” https://arxiv.org/abs/1706.03762.
Wang, Kevin, Alexandre Variengien, Arthur Conmy, Buck Shlegeris, and Jacob Steinhardt. 2022. “Interpretability in the Wild: A Circuit for Indirect Object Identification in GPT-2 Small.” https://arxiv.org/abs/2211.00593.
Zhang, Zhengyan, Yankai Lin, Zhiyuan Liu, Peng Li, Maosong Sun, and Jie Zhou. 2021. “MoEfication: Transformer Feed-Forward Layers Are Mixtures of Experts.” https://arxiv.org/abs/2110.01786.
Zhang, Zhengyan, Yixin Song, Guanghui Yu, Xu Han, Yankai Lin, Chaojun Xiao, Chenyang Song, Zhiyuan Liu, Zeyu Mi, and Maosong Sun. 2024. “ReLU\(^2\) Wins: Discovering Efficient Activation Functions for Sparse LLMs.” https://arxiv.org/abs/2402.03804.

Footnotes

  1. The \(OV\) circuit has two matrices: \(V\) reads from the residual stream, and \(O\) writes to it.↩︎

Citation

BibTeX citation:
@misc{nguyen2024prisms,
  author = {Nguyễn, Thông},
  title = {Logit {Prisms:} {Decomposing} {Transformer} {Outputs} for
    {Mechanistic} {Interpretability}},
  date = {2024-06-17},
  langid = {en},
  abstract = {We introduce a straightforward yet effective method to
    break down transformer outputs into individual components. By
    treating the model’s non-linear activations as constants, we can
    decompose the output in a linear fashion, expressing it as a sum of
    contributions. These contributions can be easily calculated using
    linear projections. We call this approach “logit prisms” and apply
    it to analyze the residual streams, attention layers, and MLP layers
    within transformer models. Through two illustrative examples, we
    demonstrate how these prisms provide valuable insights into the
    inner workings of the `gemma-2b` model.}
}
For attribution, please cite this work as:
Nguyễn, Thông. 2024. “Logit Prisms: Decomposing Transformer Outputs for Mechanistic Interpretability.”