Doing MORE To consume LESS – Flash Attention V1

Flash Attention played a major role in making LLMs more accessible to consumers. This algorithm embodies how a set of what one might consider "trivial ideas" can come together and form a powerful solution. It highlights how, even in ML, software and hardware should not be thought of as disjoint paradigms.

Nothing is Free: Why Memory Matters

You click to open a large file, and for a moment, nothing happens. Maybe you tap impatiently, wondering why it’s taking so long. But beneath the surface, your computer is doing exactly what it must: fetching data from storage, loading it into memory, and preparing it for use. That delay? It’s the cost of moving data.

Now, imagine a game loading a massive open world. If every texture, model, and sound had to be fetched directly from your SSD every time it was needed, gameplay would be unbearable. That’s why modern systems use a memory hierarchy: frequently accessed data is kept closer to the processor (in RAM, caches, and registers), while larger, less-used data stays in slower storage.


This principle is even more critical in GPUs. While CPUs handle a few complex tasks at a time, GPUs execute thousands of operations in parallel. Every millisecond spent waiting for memory to deliver data slows everything down. That’s why GPUs have multiple levels of memory—registers, shared memory, global memory—each balancing speed, size, and accessibility. Efficiently managing this hierarchy is crucial for high-performance computing, making it especially relevant in deep learning.

Faster memory is more expensive because it requires more sophisticated materials, tighter manufacturing tolerances, and complex architectures to reduce latency and increase bandwidth. Additionally, higher-speed memory often consumes more power and generates more heat, adding to design and cooling costs.

Why Deep Learning Needed GPUs

Deep learning is fundamentally a game of matrix multiplications and tensor operations, tasks that involve performing millions (or billions) of simple calculations in parallel. Traditional CPUs, designed for sequential processing, struggle with this kind of workload. Even with multiple cores, they simply don’t have the architecture to handle the massive parallelism required for deep learning efficiently.

GPUs, on the other hand, were built for exactly this kind of problem. Originally designed for rendering graphics, they excel at performing many small, independent computations at once. Their thousands of cores, combined with high-memory bandwidth, make them ideal for accelerating neural network training, where massive amounts of data must be processed simultaneously.

The shift to GPU computing wasn’t immediate, but once researchers realized that deep learning workloads closely resembled the kinds of parallel tasks GPUs were already optimized for, the adoption became inevitable.

Nvidia stock prices

Memory I/O Optimization

A major bottleneck in deep learning computations is memory access. While GPUs excel at performing massive amounts of parallel computation, they are often limited by memory bandwidth. Even with high-bandwidth memory (HBM) delivering on the order of a few terabytes per second, this is still a constraint for chips that can sustain petaflops of computation. The key challenge is that matrix multiplications, which are fundamental to deep learning, require frequent memory access, and fetching data repeatedly is costly.
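
To get a sense of the imbalance, consider a rough back-of-the-envelope estimate with illustrative (assumed) numbers: an accelerator sustaining $10^{15}$ FLOP/s fed by 2 TB/s of HBM must perform on the order of

\[\frac{10^{15}\ \text{FLOP/s}}{2 \times 10^{12}\ \text{B/s}} = 500\ \text{FLOP per byte}\]

just to keep its compute units busy; any kernel that does less arithmetic per byte fetched is bound by memory, not by compute.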

Consider multiplying two matrices:

\[\begin{equation} A \in \mathbb{R}^{M \times N}, \quad B \in \mathbb{R}^{N \times P}, \quad C = A \cdot B, \quad C \in \mathbb{R}^{M \times P} \end{equation}\]

Computing a single element $( C_{0,0} )$ requires loading an entire row of $( A )$ ($( N )$ elements) and an entire column of $( B )$ ($( N )$ elements), totaling $( 2N )$ memory accesses. For the full matrix multiplication, this scales to about:

\[\begin{equation} 2 \times N \times M \times P \end{equation}\]

memory operations. Given the memory bandwidth limits, this frequent data movement can slow down computation significantly.
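
As a toy illustration (a deliberately naive Python sketch, not how a GPU actually schedules a matmul; the function name and load counter are just for this example), the triple loop below counts one load per operand element it touches and ends up at exactly $2 \times N \times M \times P$:

import numpy as np

def naive_matmul_counting_loads(A, B):
    """Naive matmul that counts every element fetched from A and B."""
    M, N = A.shape
    _, P = B.shape
    C = np.zeros((M, P))
    loads = 0
    for i in range(M):
        for j in range(P):
            for k in range(N):
                C[i, j] += A[i, k] * B[k, j]  # one load from A, one from B
                loads += 2
    return C, loads

A = np.random.rand(8, 4)
B = np.random.rand(4, 6)
C, loads = naive_matmul_counting_loads(A, B)
assert loads == 2 * 4 * 8 * 6  # 2 * N * M * P
np.testing.assert_allclose(C, A @ B)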

Specifically, in the case of the Transformer’s attention mechanism, this means loading a row of attention scores from memory each time it is used to compute a single value in the contextualized representations.


Tiling: Reducing Redundant Memory Access

To mitigate this, we use tiling (or block matrix multiplication), which reduces the number of times data must be fetched from memory. Instead of computing each element independently and reloading data every time, we divide the matrices into smaller tiles (blocks) that fit in faster cache memory (more precisely, SRAM in the case of GPUs).

For simplicity, consider square matrices $( A, B \in \mathbb{R}^{N \times N} )$, where $( N )$ is split into blocks of size $( K \times K )$ (i.e., $( N = K \cdot J )$ for some integer $( J )$). Instead of multiplying entire rows and columns at once, we:

  1. Load a $( K \times K )$ tile from $( A )$ and a $( K \times K )$ tile from $( B )$ into cache.
  2. Perform a partial matrix multiplication on these blocks and accumulate the results in cache.
  3. Repeat the process for the next set of blocks until all contributions to $( C )$ are computed.

Since each block remains in cache and is reused $( K )$ times before being evicted, the total memory load is reduced from $( 2 \times N \times M \times P )$ to approximately:

\[\begin{equation} \frac{2 \times N \times M \times P}{K} \end{equation}\]

This means we make far fewer slow memory accesses while keeping GPU cores busy with computation, significantly improving efficiency.
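
Here is a minimal Python sketch of the idea (the "fast memory" is only simulated by small NumPy arrays, so this models the access pattern rather than real GPU SRAM):

import numpy as np

def tiled_matmul(A, B, K=2):
    """Block (tiled) matmul: each K x K tile is loaded once and reused."""
    N = A.shape[0]  # assume square N x N matrices with N divisible by K
    C = np.zeros((N, N))
    for i0 in range(0, N, K):            # tile row of C
        for j0 in range(0, N, K):        # tile column of C
            acc = np.zeros((K, K))       # accumulator held in "fast memory"
            for k0 in range(0, N, K):    # march along the shared dimension
                a_tile = A[i0:i0+K, k0:k0+K]  # load one tile of A
                b_tile = B[k0:k0+K, j0:j0+K]  # load one tile of B
                acc += a_tile @ b_tile        # every loaded element is reused K times
            C[i0:i0+K, j0:j0+K] = acc
    return C

A = np.random.rand(6, 6)
B = np.random.rand(6, 6)
np.testing.assert_allclose(tiled_matmul(A, B, K=2), A @ B)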

But what about attention mechanisms? When computing contextualized representations, we need to store intermediate values such as attention scores somewhere in memory. Ideally, we would retain a block of these scores to minimize redundant memory operations. For simplicity, we assume entire rows are kept in memory, but in practice, this might not be feasible. As we’ve noted before, faster memory comes at the cost of capacity, meaning we must carefully manage what gets stored and when.


The Quest For A Numerically Stable Softmax

Returning to the computation of the scaled attention scores: the first step is to compute the dot product between the queries and the keys, scale it, and add an optional mask $M$ that introduces causality (used when training decoder models):

\[\frac{QK^\top}{\sqrt{d_k}} + M\]

Next, we apply the softmax function to obtain the attention weights. Formally, the softmax operation is defined as:

\[\begin{equation} \text{softmax}(x_k) = \frac{e^{x_k}}{\sum_{i=1}^N e^{x_i}} \end{equation}\]

The issue is that directly computing softmax can introduce numerical instability due to the exponential function’s rapid growth. Specifically, in low-precision formats like 16-bit floating point (fp16), overflow can occur when exponentiating large values.

For example, floating point 16-bit (fp16) representation uses 5 bits for the exponent and 10 bits for the mantissa, so its largest representable finite value is 65504. If we have an \(x\) where \(x > \ln(65504) \approx 11\), the result of \(e^x\) will exceed the representable range and overflow to infinity in fp16.
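
A two-line PyTorch check (a minimal sketch) makes the overflow concrete:

import torch

x = torch.tensor([11.0, 12.0])
print(torch.exp(x).to(torch.float16))  # roughly [59872., inf]: e^12 > 65504 overflows fp16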


To mitigate this, we normalize by subtracting the maximum value in \(X = [x_1, \dots, x_N]\) from each element, both in the numerator and the denominator. Letting \(\text{max}_x = \max(X)\), we compute:

\[\begin{equation} \text{softmax}(x_k) = \frac{e^{x_k - \text{max}_x}}{\sum_{i=1}^N e^{x_i - \text{max}_x}} \end{equation}\]

This adjustment preserves the proportions while shifting every exponent to be at most zero: each term \(e^{x_i - \text{max}_x}\) lies in \((0, 1]\), so both the numerator and the denominator (which is at most \(N\)) stay well within fp16's range.

While this stabilization enhances numerical robustness, it introduces additional computational overhead in the form of three separate loops over the data:

  • Finding the maximum value.
  • Summing over exponentials.
  • Computing the final probabilities.

Each of these steps requires its own pass over the data in memory which, as we have already established, can become a limiting factor in performance.

In Python, this looks something like the following:

import torch
import numpy as np
import math

def softmax_stable(x):
    """Calculates softmax using the numerically stable algorithm."""
    N = len(x)
    m = [0] * (N + 1)
    d = [0] * (N + 1)
    a = [0] * (N + 1)
    
    # And here go the stinky loops
    m[0] = float('-inf')
    # find the max
    for i in range(1, N + 1):
        m[i] = max(m[i-1], x[i-1])

    d[0] = 0
    # calculate the sum (denominator)
    for i in range(1, N + 1):
        d[i] = d[i-1] + math.exp(x[i-1] - m[N])

    # calculate the attention scores
    for i in range(1, N + 1):
        a[i] = math.exp(x[i-1] - m[N]) / d[N]
    return a[1:]  # Return a[1:] to exclude the dummy a[0]

def torch_softmax(x):
    """Calculates softmax using torch.softmax."""
    x_tensor = torch.tensor(x, dtype=torch.float64)  # Convert to tensor
    return torch.softmax(x_tensor, dim=0).numpy()  # Calculate softmax and convert back to numpy


# Example usage and comparison (the value 13 would make e^x overflow in fp16 without the stability trick):
x = [2, 5, 1, 13, 3]

my_softmax = softmax_stable(x)
torch_softmax_result = torch_softmax(x)

print("Numerically Stable Softmax:", my_softmax)
print("Torch Softmax:", torch_softmax_result)

np.testing.assert_allclose(my_softmax, torch_softmax_result, rtol=1e-9) 

print("Assertion passed!")

Interestingly, an alternative approach known as online softmax offers a different trade-off. While it actually performs about 25% more arithmetic operations than the standard approach, it provides significant benefits in terms of memory efficiency by reducing redundant memory accesses and enabling streaming computation.

Online softmax

The principle behind online softmax is to restructure the denominator computation so that intermediate results can be reused iteratively. Instead of recomputing the full summation at each step, we reformulate it in terms of previously computed values.

Let $(X = [x_1, x_2, \dots, x_N])$. Instead of computing the denominator as:

\[d = \sum_{j=1}^{N} e^{x_j - \max(X)}\]

we introduce an incremental computation using a running maximum $( m_i )$ and a partial sum $( d_i )$:

  • $( m_i )$ is the maximum of the first $( i )$ elements:

    \[m_i = \max(x_1, x_2, \dots, x_i)\]
  • $( d_i )$ is the sum of exponentiated values up to index $( i )$, normalized by $( m_i )$:

    \[d_i = \sum_{j=1}^{i} e^{x_j - m_i}\]

Rather than recomputing the entire denominator at each step, we update it incrementally:

\[d_{i+1} = d_i \cdot e^{m_i - m_{i+1}} + e^{x_{i+1} - m_{i+1}}\]

where $( m_{i+1} = \max(m_i, x_{i+1}) )$. This recurrence relation enables us to iteratively refine the denominator while ensuring numerical stability.
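
To see why the recurrence holds, split the sum defining $( d_{i+1} )$ and re-normalize the first part from $( m_i )$ to $( m_{i+1} )$:

\[d_{i+1} = \sum_{j=1}^{i+1} e^{x_j - m_{i+1}} = e^{m_i - m_{i+1}} \sum_{j=1}^{i} e^{x_j - m_i} + e^{x_{i+1} - m_{i+1}} = d_i \cdot e^{m_i - m_{i+1}} + e^{x_{i+1} - m_{i+1}}\]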

Example Computation

Consider $( X = [\ln(4), \ln(8), \ln(16)] )$:

  1. Step 1 ($( i=1 )$):
    • $( m_1 = x_1 = \ln(4) )$
    • $( d_1 = e^{x_1 - m_1} = e^0 = 1 )$
  2. Step 2 ($( i=2 )$):
    • $( m_2 = \max(m_1, x_2) = \max(\ln(4), \ln(8)) = \ln(8) )$
    • $( d_2 = d_1 \cdot e^{m_1 - m_2} + e^0 )$
    • $( d_2 = 1 \cdot e^{\ln(4) - \ln(8)} + 1 = \frac{1}{2} + 1 = 1.5 )$
  3. Step 3 ($( i=3 )$):
    • $( m_3 = \max(m_2, x_3) = \max(\ln(8), \ln(16)) = \ln(16) )$
    • $( d_3 = d_2 \cdot e^{m_2 - m_3} + e^0 )$
    • $( d_3 = 1.5 \cdot e^{\ln(8) - \ln(16)} + 1 = 1.5 \cdot \frac{1}{2} + 1 = 1.75 )$

By iterating in this fashion, we incrementally build up the denominator while avoiding large exponentiated values. This formulation not only enhances numerical stability by preventing overflows but also improves memory efficiency by reducing redundant operations in large-scale computations.
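
A quick check (a small standalone sketch) confirms that the running recurrence lands on the same value, 1.75, as the direct two-pass computation:

import math

x = [math.log(4), math.log(8), math.log(16)]

# Online recurrence: single pass, running max and running denominator
m, d = float('-inf'), 0.0
for xi in x:
    m_new = max(m, xi)
    d = d * math.exp(m - m_new) + math.exp(xi - m_new)
    m = m_new

# Direct two-pass computation of the denominator
d_direct = sum(math.exp(xi - max(x)) for xi in x)

print(d, d_direct)  # both print 1.75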

This saves us one of the loops by fusing the max and sum computations, as the code below makes clear:

import torch
import numpy as np
import math

def softmax_online(x):
    """Calculates softmax using the online algorithm."""
    N = len(x)
    m = [0] * (N + 1)
    d = [0] * (N + 1)
    a = [0] * (N + 1)

    m[0] = float('-inf')
    d[0] = 0
    
    # See how we fused the first and second loops?
    for i in range(1, N + 1):
        m[i] = max(m[i-1], x[i-1])
        if i == 1:
            d[i] = math.exp(x[i-1] - m[i])
        else:
            d[i] = d[i-1] * math.exp(m[i-1] - m[i]) + math.exp(x[i-1] - m[i])

    for i in range(1, N + 1):
        a[i] = math.exp(x[i-1] - m[N]) / d[N]
    return a[1:]

def torch_softmax(x):
    """Calculates softmax using torch.softmax."""
    x_tensor = torch.tensor(x, dtype=torch.float64)
    return torch.softmax(x_tensor, dim=0).numpy()

# Example usage and comparison:
x = [2, 5, 1, 13, 3]

online_softmax_result = softmax_online(x)
torch_softmax_result = torch_softmax(x)

np.testing.assert_allclose(online_softmax_result, torch_softmax_result, rtol=1e-9)

print("Assertions passed!")

Fully Online Attention Computation

Now that we have established the online softmax formulation, let us examine how it integrates into the full attention mechanism computation. The first key step in attention involves computing the raw attention scores.

Instead of explicitly materializing the attention scores $( S = \frac{QK^\top}{\sqrt{d_k}} )$ (thus needing less memory!), we compute them on the fly while updating the softmax denominator in a single pass.

For each query $( Q_i )$:

  1. Initialize $( m = -\infty )$ and $( d = 0 )$ (running max and denominator).
  2. Iterate over keys $( K_j )$:
    • Compute $( S_{ij} )$ on the fly:
    \[S_{ij} = \frac{Q_i \cdot K_j^\top}{\sqrt{d_k}}\]
    • Update $( m )$ and $( d )$ using the online softmax recurrence.
  3. Compute the final softmax probabilities as:
    \[A_{ij} = \frac{e^{S_{ij} - m}}{d}\]
  4. Apply the attention weights directly to the values:
    \[O_i = \sum_j A_{ij}V_j\]

The following Python code describes what is going on:

import torch
import math

def attention_online(Q, K, V):
    """Computes attention scores and applies them to V in a fully (almost) online manner."""
    N, d_k = Q.shape
    _, d_v = V.shape
    O = torch.zeros(N, d_v)

    for i in range(N):
        m = float('-inf')  # Running max
        d = 0.0  # Running denominator

        # Compute softmax online while computing S_ij on the fly
        for j in range(N):
            S_ij = (Q[i] @ K[j]) / math.sqrt(d_k)  # Compute S_ij (dot product of query i and key j)
            m_prev = m  # Remember the previous running max
            m = max(m, S_ij)  # Update the running max
            d = d * math.exp(m_prev - m) + math.exp(S_ij - m)  # Online denominator update

        # Compute final probabilities and apply to values
        for j in range(N):
            S_ij = (Q[i] @ K[j]) / math.sqrt(d_k)  # Recompute S_ij
            A_ij = math.exp(S_ij - m) / d  # Softmax probability
            O[i] += A_ij * V[j]  # Weighted sum of values

    return O

The current implementation has an issue: duplicate computation. We’re calculating $S_{ij}$ twice for each (i,j) pair, once during the max/denominator calculation and once during the final probability computation. This doubles the computational cost unnecessarily, and we still have two loops.

This is because we are not streaming the computation of the output matrix; instead, we wait for the denominator (the softmax sum) to be finalized before producing any output. This can be addressed in the same spirit as the online denominator update: rescale the partial output whenever the running statistics change:

\[O_{i+1} = O_{i} \cdot \frac{d_{i}}{d_{i+1}} \, e^{m_{i} - m_{i+1}} + \frac{e^{S_{i+1} - m_{i+1}}}{d_{i+1}} V_{i+1}\]

This implies that only at the final iteration will we have the correct value of $O_{i+1}$; every intermediate value is just a partially normalized accumulator.

And finally, we have shrunk the complexity back to using only one inner for loop, saving us the memory cost of materializing the full scores and/or having to recompute them:

import torch
import math

def attention_online(Q, K, V):
    """Computes attention using fully online softmax with final denominator correction."""
    N, d_k = Q.shape
    _, d_v = V.shape
    O = torch.zeros(N, d_v, dtype=Q.dtype)

    for i in range(N):
        m = float('-inf')  # Running max
        d = 0.0  # Final denominator
        O_i = torch.zeros(d_v, dtype=Q.dtype)  # Output accumulator

        # Fully fused loop: compute S_ij, update m & d, apply A_ij to V inline
        for j in range(N):
            S_ij = (Q[i] @ K[j]) / math.sqrt(d_k)  # Compute S_ij (dot product of query i and key j)
            m_prev, d_prev = m, d  # Store previous values
            m = max(m, S_ij)  # Update max
            d = d * math.exp(m_prev - m) + math.exp(S_ij - m)  # Update denominator

            # Compute the probability using the current running denominator
            A_ij = math.exp(S_ij - m) / d  
            
            # Rescale the previous accumulator to the new (m, d) and add the new contribution
            O_i = O_i * (d_prev / d) * math.exp(m_prev - m) + A_ij * V[j]  

        O[i] = O_i  # Store computed output

    return O
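
To sanity-check the result, we can compare attention_online against a plain implementation that materializes the full score matrix (the reference function and test values below are just for illustration and not part of the original algorithm):

import torch
import math

def attention_reference(Q, K, V):
    """Standard attention that materializes the full score matrix."""
    S = (Q @ K.T) / math.sqrt(Q.shape[1])
    return torch.softmax(S, dim=-1) @ V

torch.manual_seed(0)
Q, K, V = torch.randn(4, 8), torch.randn(4, 8), torch.randn(4, 16)

assert torch.allclose(attention_online(Q, K, V), attention_reference(Q, K, V), atol=1e-5)
print("Online attention matches the reference implementation.")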

This post is licensed under CC BY 4.0 by the author.