
KV cache – The how not to waste your FLOPS starter

You’ve probably heard of Transformers by now. They’re everywhere, so much so that newborn babies are gonna start saying “Transformers” as their first word. This blog explores an important component that makes their inference cost much more manageable: the KV cache.


Attention refresher

The core idea behind self-attention is simple. Given a sequence of $ N $ tokens, we generate three vector representations from the embedding of each token $ i $ in the sequence: $ q_i $ (query), $ k_i $ (key), and $ v_i $ (value). Stacked over the sequence, these form the matrices $ Q $, $ K $, and $ V $.

To get the attention scores you’ll then need to:

  1. Compute $ Q \cdot K^T $ (essentially, how similar each token is to every other token).
  2. Apply a causal mask if needed (this ensures predictions don’t cheat by looking ahead).
  3. Rescale the scores and apply a softmax row-wise.
  4. Use these probabilities to create weighted combinations of the values in $ V $.

Row $ i $ of the resulting output matrix is a contextualized representation of tokens $ 0 $ through $ i $. Now that we have this context, we apply layer normalization and a linear projection to get back to a $ \texttt{seq_len} \times \texttt{vocab_size} $ matrix of logits. Taking the argmax along the vocab dimension then gives a prediction for every position in the sequence.
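For concreteness, here is a minimal NumPy sketch of the four steps above (layer normalization and the vocabulary projection are left out, and the function names are mine, not from any particular library):

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)            # subtract the row max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def causal_self_attention(Q, K, V):
    """Q, K: (seq_len, d_k); V: (seq_len, d_v). Returns one contextualized vector per token."""
    seq_len, d_k = Q.shape
    scores = Q @ K.T                                    # 1. how similar each token is to every other token
    mask = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
    scores = np.where(mask, -np.inf, scores)            # 2. causal mask: no peeking at future tokens
    weights = softmax(scores / np.sqrt(d_k), axis=-1)   # 3. rescale, then softmax row-wise
    return weights @ V                                  # 4. weighted combinations of the values
```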

Because of this, training for autoregressive generation is massively parallel: in a single forward pass we predict the next token for every position in the sequence. Let’s clarify this with an example. Say we have the sentence: “Marseille is great”.

[Figure: full-sequence attention for “Marseille is great”, producing a prediction at every position in one pass.]

If you still can’t see it, I’d really encourage you to compute a 2 by 2 output matrix by hand.
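If it helps, here is what that 2-by-2 case looks like symbolically, with $s_{ij}$ standing for the (already rescaled) score between tokens $i$ and $j$, and $a$ for the softmax weight that falls out of the second row:

\[
\text{softmax}\!\left(\begin{bmatrix} s_{00} & -\infty \\ s_{10} & s_{11} \end{bmatrix}\right)
\begin{bmatrix} v_0 \\ v_1 \end{bmatrix}
=
\begin{bmatrix} 1 & 0 \\ a & 1-a \end{bmatrix}
\begin{bmatrix} v_0 \\ v_1 \end{bmatrix}
=
\begin{bmatrix} v_0 \\ a\,v_0 + (1-a)\,v_1 \end{bmatrix},
\qquad a = \frac{e^{s_{10}}}{e^{s_{10}} + e^{s_{11}}}.
\]

Row 0 only sees token 0, row 1 mixes tokens 0 and 1, and both rows (hence both predictions) come out of a single matrix product.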

Training efficiency’s side effects

If you look at how the contextualized representations are being used, you can perhaps already see where things go wrong during inference: when trying to guess what comes after “great”, we only need the latest vector, and yet during the calculations we generated all of them…

[Figure: at inference time, only the latest contextualized vector is needed to predict the next token, yet all of them are computed.]

If you’ve read some of my previous blogs, you know what’s coming up: reverse engineering :)). Let’s go back to how these embeddings were generated and see where we can shave off some floating-point operations (FLOPS).
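To make the waste concrete, here is a deliberately naive greedy-decoding loop building on the `causal_self_attention` sketch from earlier (the embedding and projection matrices are illustrative placeholders, and layer normalization is skipped): every step rebuilds $Q$, $K$, and $V$ for the whole sequence and computes the full attention matrix, even though only the last row feeds the next-token prediction.

```python
def naive_generate(token_ids, embed, W_q, W_k, W_v, unembed, n_steps):
    """Greedy decoding that redoes the full attention computation at every step."""
    token_ids = list(token_ids)
    for _ in range(n_steps):
        X = embed[token_ids]                    # (seq_len, d_model) embeddings of *all* tokens so far
        Q, K, V = X @ W_q, X @ W_k, X @ W_v     # recomputed from scratch every single step
        H = causal_self_attention(Q, K, V)      # full (seq_len, seq_len) attention matrix
        logits = H[-1] @ unembed                # ...yet only the last row is ever used
        token_ids.append(int(logits.argmax()))  # greedily pick the next token
    return token_ids
```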

Useless Queries, Useful Keys, and Values

Let’s begin by identifying inefficiencies in the calculations of the attention mechanism during autoregressive inference. Specifically, we’ll analyze the computation of:

\[\text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}} + M\right)\]

and pinpoint the unnecessary operations.

Step 1: Calculating the Attention Matrix

Take a look at the process for computing the last row of the attention matrix:

[Figure: computing the last row of the attention score matrix, using only the latest query and all the keys.]

From this visualization, it is clear that to calculate the last row of the attention scores, we only need:

  • The latest query (the one corresponding to “great”),
  • All the keys (from all tokens in the sequence).

However, most implementations recompute the full matrix $QK^\top$, which is unnecessary when only the last row is used during inference.
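In code, that last row is a single vector-matrix product; reusing `Q`, `K`, and `d_k` from the sketch earlier, something like:

```python
q_last = Q[-1]                              # (d_k,) query of the newest token, "great"
scores_last = q_last @ K.T / np.sqrt(d_k)   # (seq_len,) one row instead of a full seq_len x seq_len matrix
```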

Step 2: Calculating the Contextualized Output

Next, consider the computation of the output matrix $O \cdot V$, where $O$ represents the attention matrix we just computed:

[Figure: computing the last row of the output matrix $O \cdot V$, using only the latest row of $O$ and the full values matrix $V$.]

Here, to generate the final contextualized representation for the prediction, we only need:

  • The latest row of $O$,
  • The entire values matrix $V$.

Despite this, the naive implementation processes the entire matrix, again… wasted computations.
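Continuing the snippet above, the single contextualized vector we actually care about is one more vector-matrix product:

```python
weights_last = softmax(scores_last)         # (seq_len,) attention weights for the newest token
out_last = weights_last @ V                 # (d_v,) the only contextualized vector we need
```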


Summary of Findings

When generating the next token, such as predicting what follows “great” in the sentence “Marseille is great”, we found that:

  • The model needs all keys corresponding to past tokens.
  • The model also requires all values from previous tokens.
  • Only the latest query (from “great”) is needed.

Attention Mechanism Computational Overhead

Consider a transformer processing a sequence with total tokens $\texttt{seq_len}$, where query and key embedding dimensions are $d_k$ and value embedding dimensions are $d_v$. The primary inefficiencies manifest in two critical computational stages: attention score computation and contextualized output generation.

In the standard transformer implementation, the attention computation $\text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}} + M\right)$ exhibits significant inefficiencies. The naive approach generates a complete $\texttt{seq_len} \times \texttt{seq_len}$ attention matrix at each inference step, resulting in a computational complexity of $O(\texttt{seq_len}^2 \cdot d_k)$.

By contrast, an optimized implementation requires only the last query vector, leveraging previously computed key matrices. This approach reduces the computational complexity to $O(\texttt{seq_len} \cdot d_k)$. The difference represents a computational waste of $O((\texttt{seq_len}-1) \cdot \texttt{seq_len} \cdot d_k)$ FLOPS, which becomes increasingly significant for longer sequences.

For an $m \times n$ matrix multiplied by an $n \times p$ matrix, the computational cost is $mnp$ FLOPS (counting each multiply-accumulate as one operation).

Similar inefficiencies plague the output computation $O \cdot V$: computing only the last row instead of the full product eliminates another $O((\texttt{seq_len}-1) \cdot \texttt{seq_len} \cdot d_v)$ unnecessary FLOPS.
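Applying the $mnp$ rule to both products, and writing $n$ for $\texttt{seq_len}$:

\[
\underbrace{n \cdot n \cdot d_k}_{\text{naive } QK^\top} \;\longrightarrow\; \underbrace{1 \cdot n \cdot d_k}_{\text{last row only}},
\qquad
\underbrace{n \cdot n \cdot d_v}_{\text{naive } O \cdot V} \;\longrightarrow\; \underbrace{1 \cdot n \cdot d_v}_{\text{last row only}},
\]

for a total per-step saving of $(n-1) \cdot n \cdot (d_k + d_v)$ FLOPS.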

Optimization Pathway

The computational complexity of $O(\texttt{seq_len}^2 \cdot (d_k + d_v))$ can be significantly reduced to $O(\texttt{seq_len} \cdot (d_k + d_v))$ through several strategic optimizations:

  • Incremental Attention Computation: Calculating only the required components (e.g., the latest row) instead of recomputing the entire attention matrix.
  • Caching Intermediate Results: Reusing previously computed keys and values to avoid redundant calculations.
  • Lazy Evaluation: Deferring computations to execute only what is strictly necessary for generating the next token.
  • Leveraging Hardware Optimizations: Exploiting faster memory resources, such as GPU high-bandwidth memory, to store cached keys and values efficiently—hence the term “cache.”

By combining these strategies, transformer inference can achieve substantial gains in both speed and efficiency.
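Putting the first three bullets together, a cached decoding step might look roughly like the sketch below. The `embed`, `W_q`, `W_k`, `W_v`, and `unembed` matrices are illustrative placeholders, the `softmax` helper is the one from the earlier sketch, and `cache` is assumed to already hold the keys and values of the tokens processed so far:

```python
def cached_decode_step(new_token_id, cache, embed, W_q, W_k, W_v, unembed):
    """One generation step: compute q, k, v for the new token only and reuse the cached K and V."""
    x = embed[new_token_id]                    # (d_model,) embedding of the newest token
    q, k, v = x @ W_q, x @ W_k, x @ W_v        # one row each instead of seq_len rows

    cache["K"] = np.vstack([cache["K"], k])    # append the new key...
    cache["V"] = np.vstack([cache["V"], v])    # ...and the new value; past rows are never recomputed

    d_k = q.shape[-1]
    scores = q @ cache["K"].T / np.sqrt(d_k)   # (seq_len,) only the last row of the attention matrix
    weights = softmax(scores)                  # no causal mask needed: the cache only holds past tokens
    context = weights @ cache["V"]             # (d_v,) contextualized vector for the new token

    logits = context @ unembed                 # project back to vocab_size
    return int(logits.argmax()), cache         # greedy next token + updated cache
```

A real implementation would keep one such cache per layer and per attention head, typically preallocated up to the maximum sequence length in fast GPU memory.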


Let’s illustrate this with an example. Consider a transformer with the following configuration:

  • Sequence Length $\texttt{seq_len} = 512$,
  • Query/Key Dimension $d_k = 64$,
  • Value Dimension $d_v = 128$.

| Computation Stage | Naive FLOPS | Optimized FLOPS | Savings |
| --- | --- | --- | --- |
| Attention Score Computation | $1.68 \times 10^7$ | $3.28 \times 10^4$ | $1.67 \times 10^7$ |
| Contextual Output Computation | $3.36 \times 10^7$ | $6.55 \times 10^4$ | $3.35 \times 10^7$ |
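These numbers are just the $mnp$ rule applied to the shapes above; a few lines of Python reproduce them:

```python
seq_len, d_k, d_v = 512, 64, 128

naive_scores  = seq_len * seq_len * d_k   # (512 x 64)(64 x 512)   -> 16,777,216 ~ 1.68e7
cached_scores = 1 * d_k * seq_len         # (1 x 64)(64 x 512)     ->     32,768 ~ 3.28e4

naive_output  = seq_len * seq_len * d_v   # (512 x 512)(512 x 128) -> 33,554,432 ~ 3.36e7
cached_output = 1 * seq_len * d_v         # (1 x 512)(512 x 128)   ->     65,536 ~ 6.55e4
```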

Key-value caching achieves substantial computational savings, reducing complexity from $O(\texttt{seq_len}^2)$ to $O(\texttt{seq_len})$ for both attention and output computations.

Conclusion

By recognizing and eliminating redundant computations, we were able to dramatically enhance the computational efficiency of transformer models without compromising their fundamental representational capabilities.
