# 1 Introduction

The logit lens (nostalgebraist 2020) is a simple yet powerful tool for understanding how transformer models (Vaswani et al. 2017; Brown et al. 2020) make decisions. In this work, we extend the logit lens approach in a mathematically rigorous and effective way. By treating certain parts of the network activations as constants, we can leverage the linear properties within the network to break down the logit output into individual component contributions. Using this principle, we introduce simple “prisms” for the residual stream, attention layers, and MLP layers. These prisms allow us to calculate how much each component contributes to the final logit output.

Our approach can be thought of as applying a series of prisms to the transformer network. Each prism in the sequence splits the logits from the previous prism into separate components. This enables us to see how different parts of the model—such as attention heads, MLP neurons, or input embeddings—influence the final output.

To showcase the power of our method, we present two illustrative examples:

In the first example, we examine how the gemma-2b model performs a simple factual retrieval task: recalling a capital city from a country name. Our findings suggest that the model learns to encode information about country names and their capital cities in a way that allows the network to easily convert country embeddings into capital city unembeddings through a linear projection.

The second example explores how the gemma-2b model adds two small numbers (ranging from 1 to 9). We uncover interesting insights into the workings of MLP layers. The network predicts output numbers using interpretable templates learned by MLP neurons. When multiple neurons are activated simultaneously, their predictions interfere with each other, ultimately producing a final prediction that peaks at the correct number.

# 2 Method

We introduce simple “prisms” that allow us to break down the output of transformer networks into individual components. The key idea is to treat nonlinear activations as constants, which enables us to calculate the contribution of any component in the network using a series of linear transformations. In the following subsections, we explore this approach in more detail.

## 2.1 Residual Stream Decomposition

In a typical decoder-only transformer architecture (see Figure 2), the output logits can be expressed as: \[ \text{logits} = W_\text{unembed} \cdot \text{diag}\left(w_{\text{norm}}\right) \cdot \frac{w_{\text{embed}} + \sum_{i=1}^N (a_i + m_i)}{s} \] Here, \(W_\text{unembed} \in \mathbb{R}^{V\times d}\) is the unembedding matrix, \(w_\text{embed}\) is the token embedding vector, \(a_i\) and \(m_i\) denote the attention and MLP outputs at the \(i\)-th layer respectively, \(s\) is a normalization factor (usually the root mean square of the numerator), and \(w_{\text{norm}}\) is the scaling vector of the last normalization layer.

By treating \(s\) as a constant, we can break down the \(\text{logits}\) into individual terms for the embedding and for each attention and MLP layer: \[ \text{logits} = P \cdot w_{\text{embed}} + P \cdot a_1 + P \cdot m_1 + \cdots + P \cdot a_N + P \cdot m_N \] where \(P\) is the projection matrix: \[ P = W_{\text{unembed}} \cdot \text{diag}\left(w_{\text{norm}}\right) \cdot \left(s^{-1} I\right) \] Each term, like \(P \cdot a_1\), represents the individual contribution of the corresponding residual vector (\(a_1\)) to the final logit score. The cumulative sum \(P \cdot w_{\text{embed}} + \sum_{i=1}^\ell P \cdot \left( a_i + m_i \right)\) gives the model’s logit score up to layer \(\ell\).
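The residual stream decomposition can be checked numerically on a toy example. The sketch below is only an illustration under stated assumptions: the sizes, the random weights, and the helper names `rand_vec` and `matvec` are hypothetical and not taken from `gemma-2b`, and the small epsilon used by real normalization layers is ignored. It computes the logits directly, then freezes \(s\) to build \(P\) and confirms that the per-component contributions sum back to the same logits.

```python
import math
import random

random.seed(0)
d, vocab, n_layers = 4, 5, 2  # toy sizes (assumption), not gemma-2b's


def rand_vec(n):
    """Random stand-in for a weight or activation vector."""
    return [random.gauss(0.0, 1.0) for _ in range(n)]


def matvec(M, v):
    """Matrix-vector product for a matrix stored as a list of rows."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]


W_unembed = [rand_vec(d) for _ in range(vocab)]  # V x d
w_norm = rand_vec(d)                             # scale of the final norm
# Residual-stream terms in order: w_embed, a_1, m_1, ..., a_N, m_N.
names = ["embed"] + [f"{k}{i}" for i in range(1, n_layers + 1) for k in "am"]
terms = {name: rand_vec(d) for name in names}

# Direct computation: final normalization followed by unembedding.
resid = [sum(col) for col in zip(*terms.values())]
s = math.sqrt(sum(x * x for x in resid) / d)     # RMS of the residual sum
logits = matvec(W_unembed, [w * x / s for w, x in zip(w_norm, resid)])

# Residual prism: with s frozen, P = W_unembed . diag(w_norm) . (1/s) I.
P = [[u * w / s for u, w in zip(row, w_norm)] for row in W_unembed]
contrib = {name: matvec(P, vec) for name, vec in terms.items()}

# The per-component contributions add back up to the full logits.
recombined = [sum(col) for col in zip(*contrib.values())]
max_err = max(abs(x - y) for x, y in zip(recombined, logits))
print(f"max reconstruction error: {max_err:.2e}")
```

Because \(s\) is computed once from the full residual sum and then held fixed, the decomposition is exact up to floating-point error; each entry of `contrib` is the logit contribution of one component for every vocabulary item.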

## 2.2 Attention Decomposition

We can naturally break down the attention output into a sum over its attention heads: \[
a_\ell = \sum_{i=1}^{\text{\#heads}} H_\ell^i
\] Here, \(H_\ell^i\) is the \(i\)-th attention head’s output at layer \(\ell\). Additionally, an attention head’s output is a weighted sum of the \(OV\) circuit outputs for each token in the sequence.^{1} The output of head \(i\) at layer \(\ell\) is: \[
H^i_\ell = \sum_{p} \alpha_p \cdot OV_\ell^{i} \cdot \text{diag}\left(w_{\text{norm}}\right) \cdot \frac{h_{\ell - 1}^p}{s^p_{\ell-1}}
\] In this formula, \(\alpha_p\) is the attention weight from the current token to token \(p\) (at or before the current position), \(OV_\ell^i\) is the OV matrix for the \(i\)-th attention head at layer \(\ell\), \(h_{\ell-1}^p\) is the hidden state of token \(p\) from the previous layer \(\ell-1\), and \(s^p_{\ell-1}\) is the normalization factor for that hidden state.

By treating the attention weight \(\alpha_p\) and normalization factor \(s^p_{\ell-1}\) as constants, we can express the attention output as a sum of linear projections of the previous layer’s hidden state \(h_{\ell-1}^p\). To find the contribution of token \(p\) via attention head \(i\), we use the following projection matrix: \[ P^{a_i}_\ell = P \cdot \left(\alpha_p I\right) \cdot OV_\ell^{i} \cdot \text{diag}\left(w_{\text{norm}}\right) \cdot \left({s^p_{\ell-1}}^{-1} I \right) \] Note that we can further decompose \(h_{\ell-1}^p\) into the sum of the token embedding and all previous layers’ residual outputs using residual stream decomposition.
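As a rough illustration of the attention prism, the sketch below builds one projection matrix per source token for a single head and checks that the per-token logit contributions sum to the head's total contribution. Everything here is a toy assumption: the sizes, the random matrices standing in for \(P\) and \(OV_\ell^i\), the fixed attention weights, and the helper names are hypothetical and not taken from `gemma-2b`.

```python
import math
import random

random.seed(0)
d, vocab, seq = 4, 5, 3  # toy sizes (assumption), not gemma-2b's


def rand_vec(n):
    return [random.gauss(0.0, 1.0) for _ in range(n)]


def matvec(M, v):
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def rms(v):
    return math.sqrt(sum(x * x for x in v) / len(v))


P = [rand_vec(d) for _ in range(vocab)]  # residual prism projection, V x d
OV = [rand_vec(d) for _ in range(d)]     # OV matrix of one head, d x d
w_norm = rand_vec(d)                     # scale of the norm before attention
h = [rand_vec(d) for _ in range(seq)]    # hidden states h^p, one per token
s = [rms(h_p) for h_p in h]              # frozen normalization factors s^p
alpha = [0.2, 0.5, 0.3]                  # frozen attention weights (made up)

# Direct computation: head output H, then its logit contribution P . H.
H = [0.0] * d
for a_p, h_p, s_p in zip(alpha, h, s):
    x = [w * v / s_p for w, v in zip(w_norm, h_p)]
    H = [acc + a_p * o for acc, o in zip(H, matvec(OV, x))]
direct = matvec(P, H)

# Attention prism: one V x d projection matrix per source token p.
per_token = []
for a_p, h_p, s_p in zip(alpha, h, s):
    # (alpha_p I) . OV . diag(w_norm) . (1/s^p) I, folded into one d x d matrix
    inner = [[a_p * o * w / s_p for o, w in zip(row, w_norm)] for row in OV]
    P_p = matmul(P, inner)
    per_token.append(matvec(P_p, h_p))   # logit contribution of token p

recombined = [sum(col) for col in zip(*per_token)]
max_err = max(abs(x - y) for x, y in zip(recombined, direct))
print(f"max reconstruction error: {max_err:.2e}")
```

Since each `P_p` is an ordinary matrix acting on \(h_{\ell-1}^p\), it can be applied to any single term of the residual stream decomposition of that hidden state, which is how the contribution of an earlier layer through one head is isolated.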

## 2.3 MLP Decomposition

In transformer networks, the MLP layers consist of two linear transformations, \(W_{\text{up}}\) and \(W_{\text{down}}\), with a non-linear function \(g\) applied between them. The output of an MLP layer \(\ell\) can be expressed as: \[ m_\ell = W_{\text{down}} \cdot g \cdot W_{\text{up}} \cdot \text{diag}\left(w_{\text{norm}}\right) \cdot \frac{h^a_\ell}{s^a_\ell} \] Here, \(h^a_\ell\) is the hidden state after the attention sublayer of layer \(\ell\), \(w_{\text{norm}}\) is the scaling vector of the normalization applied before the MLP, and \(s^a_\ell\) is the corresponding normalization factor.

The nonlinear point-wise function \(g\) allows neural networks to learn complex transformations that cannot be represented by linear transformations. Here, we refer to the input dimensions of \(g\) as neurons. We can break down the MLP output \(m_\ell\) into a sum over individual neuron contributions: \[ m_\ell = \sum_{i=1}^{\text{\#neurons}} W^i_{\text{down}} \cdot g^i \cdot W^i_{\text{up}} \cdot \text{diag}\left(w_{\text{norm}}\right) \cdot \frac{h^a_\ell}{s^a_\ell} \] where \(W^i_{\text{down}}\) is the \(i\)-th column of \(W_{\text{down}}\), \(W^i_{\text{up}}\) is the \(i\)-th row of \(W_{\text{up}}\), and \(g^i\) is the scalar factor that \(g\) applies at neuron \(i\). To compute the contribution of the \(i\)-th neuron in MLP layer \(\ell\) to the logit output, we can treat \(g^i\) and \(s_\ell^a\) as constants and use the following projection matrix: \[ P^{m_i}_\ell = P \cdot W^i_{\text{down}} \cdot g^i \cdot W^i_{\text{up}} \cdot \text{diag} \left(w_{\text{norm}}\right) \cdot \left(\left(s^a_\ell\right)^{-1} I \right) \] Using this projection matrix, we can pinpoint how much a single MLP neuron contributes to the model’s final output logits. Additionally, the same projection matrix allows us to trace how residual vectors from earlier layers influence the final output logits through that specific neuron via the residual stream decomposition.
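The per-neuron decomposition can be sketched in the same toy setting. The example below is a minimal illustration, not the author's implementation: it takes the \(i\)-th neuron's contribution to be the \(i\)-th column of \(W_{\text{down}}\) scaled by the frozen factor \(g^i\) and by the \(i\)-th row of \(W_{\text{up}}\) applied to the normalized input. The sizes, random weights, stand-in values for \(g^i\), the `target` token index, and the helper names are all hypothetical.

```python
import math
import random

random.seed(0)
d, vocab, n_neurons = 4, 5, 6  # toy sizes (assumption), not gemma-2b's


def rand_vec(n):
    return [random.gauss(0.0, 1.0) for _ in range(n)]


def matvec(M, v):
    return [sum(m * x for m, x in zip(row, v)) for row in M]


P = [rand_vec(d) for _ in range(vocab)]           # residual prism projection, V x d
W_up = [rand_vec(d) for _ in range(n_neurons)]    # n_neurons x d
W_down = [rand_vec(n_neurons) for _ in range(d)]  # d x n_neurons
w_norm = rand_vec(d)                              # scale of the norm before the MLP
h = rand_vec(d)                                   # hidden state h^a entering the MLP
s = math.sqrt(sum(x * x for x in h) / d)          # frozen normalization factor
# Frozen per-neuron factors g^i; random stand-ins for values that would be
# recorded during a real forward pass.
g = [abs(random.gauss(0.0, 1.0)) for _ in range(n_neurons)]

x = [w * v / s for w, v in zip(w_norm, h)]        # normalized MLP input
pre = matvec(W_up, x)                             # row i of W_up applied to x

# Direct computation: full MLP output, then its logit contribution.
m = matvec(W_down, [g_i * p_i for g_i, p_i in zip(g, pre)])
direct = matvec(P, m)

# MLP prism: contribution of neuron i is P . (column i of W_down) . g^i . pre_i.
down_cols = list(zip(*W_down))
per_neuron = [
    matvec(P, [w * g[i] * pre[i] for w in down_cols[i]])
    for i in range(n_neurons)
]

recombined = [sum(col) for col in zip(*per_neuron)]
max_err = max(abs(a - b) for a, b in zip(recombined, direct))

# Rank neurons by the size of their contribution to one target token.
target = 0  # hypothetical token index
top_neurons = sorted(
    range(n_neurons), key=lambda i: abs(per_neuron[i][target]), reverse=True
)[:3]
print(f"max reconstruction error: {max_err:.2e}, top neurons: {top_neurons}")
```

Ranking by absolute contribution to a chosen target token is one simple way to pick out the few neurons that dominate an MLP layer's effect on the logits; other selection criteria are possible.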

# 3 Examples

In this section, we apply the *prisms* proposed earlier to explore how the `gemma-2b` model works internally in two examples. We use the `gemma-2b` model because it’s small enough to run on a standard PC without a dedicated GPU.

**Retrieving capital city.** First, let’s see how the model retrieves factual information to answer this simple question:

`The capital city of France is ___`

The model correctly predicts `▁Paris` as the most likely next token. To understand how it arrives at this prediction, we’ll use the prisms from the previous section.

We start by using the residual prism to plot how much each layer contributes to the logit output for several candidate tokens (different capital cities). Comparing the prediction logit of the right answer to reasonable alternatives can reveal important information about the network’s decision process.

Figure 3 shows each layer’s logit contribution for multiple candidate tokens. Some strong signals stand out, with large positive and negative contributions in the first and last layers, which suggests that these layers play key roles in the model’s predictions. Interestingly, there’s a strong positive contribution at the start, followed by an equally strong negative contribution in the next layer (the first attention output). This might be because the `gemma-2b` model’s embedding and unembedding vectors are the same, so the input token strongly predicts itself as the output (due to the nature of the dot product operation). The network has to balance this out with a strong negative contribution in the next layer.

Figure 3B zooms in to compare logit contributions at each layer for different targets. The \(a_{15}\) contribution clearly separates `▁Paris` from the other candidates. At this layer, the attention output aligns much more with the unembedding vector of `▁Paris` than with those of the other candidates.

We think \(a_{15}\) reads the correct output value from somewhere else via the attention mechanism, so we use the attention prism to decompose \(a_{15}\) into smaller pieces. Figure 4 shows how much each input token influences the output logits via attention layer 15. The `▁France` token heavily affects the output through attention head 6 of this layer, which makes sense, as `▁France` should somehow inform the network to output the correct capital city.

Next, we again use the residual prism to decompose the attention head 6 logits into smaller pieces. Figure 5 shows how the residual outputs from all previous layers at the `▁France` token contribute to the output logit via attention head 6. Interestingly, the `▁France` embedding vector contributes the most to the output logit. This indicates that the embedding vector of `▁France` somehow already includes the information about its capital city, and this information can be read easily by attention head 6.

One direct result of using prisms is that we have a linear projection that maps the `▁France` embedding vector to capital city candidates’ unembedding vectors. We think this linear projection is meaningful not just for `▁France`, but has similar effects on other country tokens too. To check this hypothesis, we apply the same projection matrix to other countries’ embedding vectors. Figure 6 shows the same matrix does indeed project other country names to their respective capitals.

This suggests that the network learns to represent country names and capital city names in such a way that it can easily transform a country embedding to the capital city unembedding using a linear projection. We hypothesize that this observation can be generalized to other relations encoded by the network as well.

**Digit addition.** Let’s explore how `gemma-2b` performs arithmetic by asking it to complete the following:

`7+2=_`

The model correctly predicts `9` as the next token. To understand how it achieves this, we employ our prisms toolbox. First, using the residual prism, we decompose the residual stream and examine the contributions of different layers for target tokens ranging from 0 to 9 (Figure 7). The MLP layer at layer 16 (\(m_{16}\)) stands out, predicting `9` with a significantly higher logit value than other candidates. This substantial gap is unique to \(m_{16}\), indicating its crucial role in the model’s prediction of `9`.

Next, we use the MLP prism to identify which neurons in \(m_{16}\) drive this behavior. Decomposing \(m_{16}\) into contributions from its 16,384 neurons, we find that most are inactive. Extracting the top active neurons, we observe that they account for the majority of \(m_{16}\)’s activity. Figure 8 shows these top neurons’ contributions to candidates from 0 to 9, revealing distinct patterns for each neuron. For example, neuron 10029 selectively differentiates odd and even numbers. Neuron 11042 selectively predicts `7`, while neuron 12552 selectively avoids predicting `7`. Neurons 15156 and 2363 show sine-wave patterns. While no single neuron dominantly predicts `9`, the combined effect of these neurons’ predictions peaks at `9`.

Note that the neurons’ contributions to the target logits are simply linear projections onto different target token unembedding vectors. The neuron activity patterns in Figure 8 are likely encoded in the target token unembeddings; as such, these patterns can be easily extracted using a linear projection. When we visualize the digit unembedding space in 2D (Figure 9), we discover that the numbers form a heart-like shape with reflectional symmetry around the `0`-`5` axis.

Our hypothesis is that transformer networks encode templates for outputs in the unembedding space. The MLP layer then selectively reads these templates based on their linear projection \(W_\text{down}\). By triggering a specific combination of neurons, each representing a template, the network ensures the logits reach their maximum value for the tokens with the highest probability.

# 5 Conclusion

This paper introduces logit prisms, a simple but effective way to break down transformer outputs, making them easier to interpret. With logit prisms, we can closely examine how the input embeddings, attention heads, and MLP neurons each contribute to the final output. Applying logit prisms to the `gemma-2b` model reveals valuable insights into how it works internally.

## References

*Transformer Circuits Thread*.

## Footnotes

The \(OV\) circuit has two matrices: \(V\) reads from the residual stream, and \(O\) writes to it.

## Citation

```
@misc{nguyen2024prisms,
author = {Nguyễn, Thông},
title = {Logit {Prisms:} {Decomposing} {Transformer} {Outputs} for
{Mechanistic} {Interpretability}},
date = {2024-06-17},
langid = {en},
abstract = {We introduce a straightforward yet effective method to
break down transformer outputs into individual components. By
treating the model’s non-linear activations as constants, we can
decompose the output in a linear fashion, expressing it as a sum of
contributions. These contributions can be easily calculated using
linear projections. We call this approach “logit prisms” and apply
it to analyze the residual streams, attention layers, and MLP layers
within transformer models. Through two illustrative examples, we
demonstrate how these prisms provide valuable insights into the
inner workings of the `gemma-2b` model.}
}
```