Scaling up self-attention inference
Motivation
Imagine a transformer that could remember everything and evolve continuously!
Current transformer models do not keep improving by interacting with users or the world; they cannot retrieve their past experiences or learn from them directly. One can introduce a separate memory module to keep track of past experiences, similar to what OpenAI is doing, but such textual memories cannot capture the full detail of those experiences.
We aim for a model that can retrieve all of its past experiences when needed. This model can keep improving itself, learning from its mistakes, much like a human does. It can perform reasoning and planning to solve problems that require thousands of steps, over many days and months. I believe long context window inference is one of the cornerstones of AGI.
With a long context window, we can elegantly solve the RAG (Retrieval-Augmented Generation) problem: there is no need for a separate pipeline to retrieve relevant content for the model, since all relevant information can be encoded in a single input sequence.
Google DeepMind is heavily investing in this direction: their Gemini 1.5 models support a 2M-token context window. Also, magic.dev, a startup company, recently announced their work on 100M-token context windows.
OK, enough of grand visions. Let’s focus on the real question: how can we scale up the context window?
In this post, I will focus on scaling up inference. For training, the reader can take a look at the ring attention line of work. For example, see Ring Attention with Blockwise Transformers for Near-Infinite Context.
Self-attention is indeed the bottleneck in scaling up the context window. During inference, self-attention is the only component in the transformer network that needs state from all previous tokens (their cached key and value vectors) to compute its output. This is why per-token self-attention compute is said to grow linearly with the sequence length during inference.
However, there is a trick: by adding more parallel compute, it is possible to perform self-attention in \(\mathcal{O}(\log(n))\) time steps, where \(n\) is the length of the prefix sequence. This is good news: we can increase the context length simply by adding compute resources.
Algorithm
It is well known that self-attention computation on a sequence can be chunked into smaller subsequences. We only need to weight the outputs of these subsequences carefully to get the attention output on the whole sequence. Doing this recursively, with hardware that processes all subsequences in parallel, enables us to perform attention in \(\mathcal{O}(\log(n))\) time steps (the depth of the recursion tree).
Let’s denote the input sequence as \(t_1, t_2, \dots, t_n\). When processing \(t_n\) to predict the next token (\(t_{n+1}\)), self-attention layers perform the following computation:
- Generate 3 vectors (using linear projections): a query vector \(Q_n\), a key vector \(K_n\), and a value vector \(V_n\).
- Compute the dot product of \(Q_n\) with all previous key vectors \(K_i\) for \(1 \leq i \leq n\). The result is normalized using a softmax function:
\[ \begin{align} s_{i} &= Q_n \cdot K_i \\ a_{i} &= \frac{\exp(s_{i})}{\sum_{1 \leq j \leq n} \exp(s_{j})} \end{align} \]
- Output the weighted sum of value vectors:
\[ O(1, n) = \sum_{1 \leq i \leq n} a_{i} V_i \]
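As a concrete reference point, here is a minimal NumPy sketch of this per-token attention step (the function name and shapes are illustrative, and the usual \(1/\sqrt{d}\) scaling is omitted to match the formulas above):

```python
import numpy as np

def attention_single_query(q, K, V):
    """Naive attention of one query q (shape (d,)) over the cached
    keys K (shape (n, d)) and values V (shape (n, d_v))."""
    s = K @ q                    # scores s_i = Q_n . K_i
    a = np.exp(s - s.max())      # subtract the max for numerical stability
    a = a / a.sum()              # softmax weights a_i
    return a @ V                 # output O(1, n) = sum_i a_i V_i
```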
Let’s define the log-sum-exp (lse) value of a subsequence \(s_i, s_{i+1}, \dots, s_j\) as follows:
\[ \text{lse}(i, j) = \log\left[ \sum_{i \leq \alpha \leq j} \exp\left( s_\alpha \right) \right] \]
We can show that the output \(O(1, n)\) can be computed from the outputs of two subsequences \(s_1, s_2, \dots, s_m\) and \(s_{m+1}, s_{m+2}, \dots, s_n\) as follows:
\[ \begin{align} O(1, n) &= w_1 \cdot O(1, m) + w_2 \cdot O(m+1, n) \end{align} \]
where
\[ \begin{align} w_1 &= \exp\left[\text{lse}(1, m) - \text{lse}(1, n) \right] \\ w_2 &= \exp\left[\text{lse}(m+1, n) - \text{lse}(1, n) \right] \\ \text{lse}(1, n) &= \text{log-sum-exp}( \left\{\text{lse}(1, m), \text{lse}(m+1, n) \right\} ) \end{align} \]
In short, we can compute the attention output of a sequence from the attention outputs of two subsequences. We only need to keep track of the log-sum-exp values of the subsequences to weight their outputs correctly.
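The sketch below implements this merge rule with a numerically stable log-sum-exp (the helper names chunk_attention and merge are my own, not from any library):

```python
import numpy as np

def chunk_attention(q, K, V):
    """Attention of q over one chunk of the sequence.
    Returns the chunk output plus its lse value, which is all that is
    needed to combine this chunk with others later."""
    s = K @ q                                  # chunk scores
    m = s.max()
    lse = m + np.log(np.exp(s - m).sum())      # lse of the chunk scores
    a = np.exp(s - lse)                        # softmax weights local to the chunk
    return a @ V, lse

def merge(o1, lse1, o2, lse2):
    """Combine the outputs of two disjoint chunks into the output of their union."""
    lse = np.logaddexp(lse1, lse2)             # lse(1, n) from lse(1, m) and lse(m+1, n)
    w1 = np.exp(lse1 - lse)                    # weight of the first chunk
    w2 = np.exp(lse2 - lse)                    # weight of the second chunk
    return w1 * o1 + w2 * o2, lse
```

Splitting a sequence at any point \(m\) and merging the two halves this way gives the same result as attending over the whole sequence at once, up to floating-point error.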
By picking \(m = \lfloor \frac{n}{2} \rfloor\), and computing \(O(1, n)\) recursively and in parallel, we can compute self-attention in \(\mathcal{O}(\log(n))\) time steps.
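Continuing the sketch above, the recursion might look as follows; the Python runs sequentially, but the two halves at every level are independent, so with enough hardware each level of the tree takes a single time step:

```python
def tree_attention(q, K, V, leaf_size=1024):
    """Divide-and-conquer attention: split the sequence in half, solve the
    halves (conceptually in parallel), and combine them with merge().
    The recursion depth, and hence the number of time steps, is O(log n)."""
    n = K.shape[0]
    if n <= leaf_size:
        return chunk_attention(q, K, V)
    m = n // 2
    o1, lse1 = tree_attention(q, K[:m], V[:m], leaf_size)   # left half
    o2, lse2 = tree_attention(q, K[m:], V[m:], leaf_size)   # right half
    return merge(o1, lse1, o2, lse2)
```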
In-memory Processing and Networking
If we implement the above parallel algorithm on a single CPU or GPU, there will be no speed-up over the naive algorithm. The reason is simple: the algorithm is memory-bound. We are limited by the bandwidth between memory and the CPU (or memory and the GPU); the compute units sit idle waiting for the cached key and value vectors (\(K_i\) and \(V_i\)) to arrive from memory.
One important observation is that we only need a single vector and a single scalar to fold a subsequence's result into the overall output. If we can compute that vector and scalar locally, we effectively collapse the whole subsequence into one vector and one scalar, drastically reducing the memory-bandwidth requirement.
Can our DRAM hardware compute self-attention locally in parallel?
Memory manufacturers have started to think about this exact same problem. There are proposals to put compute units next to memory units. For more information, see this Hot Chips 2023 session about Processing In Memory: https://www.youtube.com/watch?v=07JjXXd-0ao.
If such hardware existed, we could perform self-attention on the memory side: only the query vector would need to be transferred to memory at the beginning, and only the output vector (plus its lse scalar) would come back at the end of the computation.
Sadly, there is no such DRAM hardware on the market right now. However, we can apply the same principle to a computer network. A network node can be a DRAM memory connected to a CPU or an HBM memory connected to a GPU. We then connect multiple nodes together to perform self-attention.
Even though network bandwidth is lower than local memory bandwidth, we only need to transfer a query vector at the beginning and an output vector at the end of the computation; most of the work happens locally. By having many machines process their shards in parallel, we massively increase the aggregate memory bandwidth.
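A conceptual sketch of this protocol, reusing chunk_attention and merge from the Algorithm section (the plain Python loop stands in for network communication; in practice the partial results would be tree-reduced across nodes):

```python
def distributed_attention(q, shards):
    """Each node holds one (K, V) shard of the KV cache.
    Only the query q is broadcast; each node sends back a single output
    vector and an lse scalar, which the coordinator merges."""
    partials = [chunk_attention(q, K, V) for K, V in shards]  # computed on the nodes
    o, lse = partials[0]
    for o_i, lse_i in partials[1:]:
        o, lse = merge(o, lse, o_i, lse_i)
    return o
```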
Conclusion
That’s it! That is how you can scale the context window as large as you want: all you need is enough GPUs (or TPUs) connected by a low-latency network.
It is almost certain that the context window will keep growing significantly in the near future. Soon we will have models with 10M, 100M, or even 1B token context windows.
How to train the model to take advantage of the long context window for self-improving, self-correcting, planning, reasoning, etc., is still an open problem!