Flash retention – scaling attention and retention mechanisms

The original attention mechanism and its scaling cost

Self-attention, popularized by the Transformer architecture in 2017, lets each token in a sequence interact with every other token. The mechanism builds a dense attention matrix whose size grows quadratically with the sequence length, so both the time and the memory needed to compute self-attention are O(n²), where n is the sequence length. For moderate sequences of a few hundred tokens this cost is acceptable, but as transformers are applied to long documents, code, images and video, it becomes a bottleneck. A standard implementation writes the large intermediate matrices (the scores and the softmax output) to high-bandwidth memory (HBM) and reads them back, because they do not fit in the GPU’s much smaller on-chip SRAM. This memory traffic, rather than the arithmetic, dominates runtime, and the stored matrices dominate memory consumption; together they limit the maximum sequence length that fits on a single GPU and slow down training and inference. Approximate attention methods (linear attention, kernelized attention, etc.) try to reduce the complexity by trading away some accuracy, but they often fail to deliver wall-clock speed-ups and may degrade model quality.
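To make the quadratic term concrete, here is a minimal NumPy sketch of standard attention for a single head. The function names (`naive_attention`, `score_matrix_bytes`) are illustrative and do not come from any library, and the 2-byte element size assumes fp16 storage.

```python
import numpy as np

def naive_attention(q, k, v):
    """Standard scaled dot-product attention; materializes the full (n, n) matrix."""
    d = q.shape[-1]
    scores = q @ k.T / np.sqrt(d)                          # (n, n) score matrix
    scores = scores - scores.max(axis=-1, keepdims=True)   # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)  # row-wise softmax
    return weights @ v, weights

def score_matrix_bytes(n, bytes_per_element=2):
    """Bytes needed for one (n, n) matrix: one head, one sequence, fp16 assumed."""
    return n * n * bytes_per_element

for n in (1024, 8192, 65536):
    print(n, score_matrix_bytes(n) / 2**20, "MiB")
```

Doubling the sequence length quadruples the size of this matrix, and a real model holds one such matrix per head and per sequence in the batch.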

FlashAttention 1 — IO‑aware exact attention

FlashAttention (often referred to as FlashAttention 1) was proposed by Tri Dao and collaborators in 2022 as a response to the scaling challenges of self-attention. The key insight is that modern GPUs have a hierarchical memory system: a small, fast on-chip SRAM and a much larger but slower HBM. By tiling the attention computation so that the blocks of the query, key and value matrices being processed fit into on-chip SRAM, FlashAttention reduces the number of reads and writes to HBM. It performs the softmax normalization “on the fly”: for each row it keeps a running maximum and a running sum, rescales the partial output whenever a new block changes them, and writes the result directly to the output, so the full n×n attention matrix is never materialized. Because the algorithm is IO-aware rather than approximate, it computes exact attention and preserves model quality.
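The tiling and on-the-fly softmax described above can be sketched in NumPy. This is an illustration of the idea only, not the actual CUDA kernel: the block sizes are arbitrary assumptions, the Python loops stand in for work a GPU would do in SRAM, and the function names are hypothetical.

```python
import numpy as np

def reference_attention(q, k, v):
    """Plain attention that builds the full score matrix, used only for comparison."""
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v

def tiled_attention(q, k, v, block_q=16, block_k=16):
    """Exact attention computed tile by tile with a running (online) softmax."""
    n, d = q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.empty((n, v.shape[-1]))
    for qs in range(0, n, block_q):
        q_blk = q[qs:qs + block_q]
        rows = q_blk.shape[0]
        m = np.full(rows, -np.inf)             # running row maximum
        l = np.zeros(rows)                     # running softmax denominator
        acc = np.zeros((rows, v.shape[-1]))    # unnormalized partial output
        for ks in range(0, k.shape[0], block_k):
            # Only a (rows, block_k) tile of scores exists at any time.
            s = (q_blk @ k[ks:ks + block_k].T) * scale
            m_new = np.maximum(m, s.max(axis=-1))
            correction = np.exp(m - m_new)     # rescale what was accumulated so far
            p = np.exp(s - m_new[:, None])
            l = l * correction + p.sum(axis=-1)
            acc = acc * correction[:, None] + p @ v[ks:ks + block_k]
            m = m_new
        out[qs:qs + rows] = acc / l[:, None]   # normalize once per query block
    return out
```

On random inputs, `tiled_attention` should agree with `reference_attention` up to floating-point rounding, which is the sense in which the method is exact rather than approximate.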

FlashAttention offers significant efficiency improvements:

  • Memory footprint: FlashAttention reduces the extra memory required for attention from quadratic to linear in the sequence length, because only the output and a few per-row softmax statistics are stored instead of the full attention matrix. This allows models to handle much longer sequences without running out of GPU memory.
  • Compute efficiency: By eliminating redundant memory traffic and fusing the attention steps into a single kernel, FlashAttention achieves substantial speed-ups. The original paper reports up to a 3× speed-up on GPT-2 with a sequence length of 1K and a 2.4× speed-up on the Long Range Arena benchmark (sequence lengths of 1K to 4K).
  • Better utilization: The algorithm enables transformer training at higher FLOP utilization because GPU threads spend more time on arithmetic and less time waiting on memory transfers.
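A back-of-the-envelope comparison shows the scale of these savings. The HBM access formulas below follow the asymptotic counts from the FlashAttention paper's IO analysis (n·d + n² for standard attention, n²·d²/M for FlashAttention, where M is the SRAM size in elements) with all constant factors dropped, so the ratios are indicative only. The SRAM capacity and the two statistics per row are assumptions chosen for illustration.

```python
def standard_extra_elements(n):
    """Intermediate elements stored by standard attention: the n x n matrix."""
    return n * n

def flash_extra_elements(n, stats_per_row=2):
    """Intermediate elements kept by a tiled kernel (assumed: 2 statistics per row)."""
    return stats_per_row * n

def standard_hbm_accesses(n, d):
    """Asymptotic HBM accesses for standard attention, constants dropped."""
    return n * d + n * n

def flash_hbm_accesses(n, d, sram_elements):
    """Asymptotic HBM accesses for FlashAttention, constants dropped."""
    return n * n * d * d / sram_elements

# Hypothetical setting: 4096 tokens, head dimension 64, room for 65536 elements on chip.
n, d, sram = 4096, 64, 65536
print("memory ratio :", standard_extra_elements(n) / flash_extra_elements(n))
print("traffic ratio:", standard_hbm_accesses(n, d) / flash_hbm_accesses(n, d, sram))
```

The memory ratio grows linearly with n, while the traffic ratio depends mainly on how the on-chip capacity compares with d².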

These improvements make FlashAttention an important milestone: it removes the quadratic memory bottleneck while preserving exactness, enabling researchers to scale transformers to tens of thousands of tokens.

FlashAttention 2 — better parallelism and work partitioning

While FlashAttention 1 delivered large gains, it still left performance on the table because the work was not optimally distributed across the GPU’s thread blocks and warps. FlashAttention 2, released in July 2023, refines the algorithm to exploit GPU parallelism more fully. The authors observed that the original implementation achieved only 25–40% of the theoretical maximum FLOP/s on A100 GPUs because of sub-optimal work partitioning. FlashAttention 2 introduces several improvements:

  1. Reducing non‑matrix operations: The algorithm reduces the number of non‑matrix‑multiplication FLOPs so that more of the GPU’s compute units are used for high‑throughput matrix operations.
  2. Parallelizing across thread blocks: FlashAttention 1 parallelizes over the batch and the attention heads, so each head of each sequence is handled by a single thread block. FlashAttention 2 also parallelizes along the sequence length, distributing the computation of a single head across multiple thread blocks. This increases occupancy and makes better use of GPU cores, especially when long sequences force small batch sizes.
  3. Efficient warp scheduling: Within each thread block, work is split across warps in a way that minimizes synchronization and shared-memory traffic. For instance, splitting the query matrix across four warps while keeping the keys and values accessible to all of them lets each warp finish its own rows of the output, which removes the need for communication between warps.
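The difference between the two ways of dividing work can be simulated in NumPy. In this sketch, loop iterations stand in for warps; it is not GPU code, and the function names are hypothetical. Splitting the queries gives each worker complete output rows, whereas splitting the keys and values leaves each worker with partial results that must be combined in an extra step.

```python
import numpy as np

def partial_state(q, k, v):
    """Unnormalized attention of q against one K/V slice: (row max, denominator, accumulator)."""
    s = q @ k.T / np.sqrt(q.shape[-1])
    m = s.max(axis=-1)
    p = np.exp(s - m[:, None])
    return m, p.sum(axis=-1), p @ v

def split_q(q, k, v, workers=4):
    """Each worker owns a slice of the queries and sees all keys and values."""
    outs = []
    for q_part in np.array_split(q, workers):
        m, l, acc = partial_state(q_part, k, v)
        outs.append(acc / l[:, None])      # rows are final; nothing to exchange
    return np.concatenate(outs)

def split_k(q, k, v, workers=4):
    """Each worker owns a slice of keys/values; partial states must be merged."""
    states = [partial_state(q, k_part, v_part)
              for k_part, v_part in zip(np.array_split(k, workers),
                                        np.array_split(v, workers))]
    # Merge step: this is the cross-worker communication that split_q avoids.
    m_all = np.max([m for m, _, _ in states], axis=0)
    l_all = sum(l * np.exp(m - m_all) for m, l, _ in states)
    acc_all = sum(acc * np.exp(m - m_all)[:, None] for m, _, acc in states)
    return acc_all / l_all[:, None]
```

Both functions return the same result; the point is that `split_k` needs the merge step, which on a GPU would mean writing intermediate results to shared memory and synchronizing.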

These modifications reduce shared-memory reads and writes and improve parallel efficiency. FlashAttention 2 achieves roughly a 2× speed-up over FlashAttention 1 and attains 50–73% of the theoretical maximum FLOP/s on A100 GPUs. When used end to end to train GPT-style models, it reaches training speeds of up to 225 TFLOP/s per GPU. As a result, FlashAttention 2 has become the default choice for many large-scale transformer implementations.
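To put the 225 TFLOP/s figure in context, it can be divided by the hardware peak. The sketch below assumes the A100's dense FP16/BF16 tensor-core peak of 312 TFLOP/s from NVIDIA's published specifications; the function name is illustrative.

```python
A100_PEAK_TFLOPS = 312.0  # assumed: dense FP16/BF16 tensor-core peak for an A100

def utilization(achieved_tflops, peak_tflops=A100_PEAK_TFLOPS):
    """Fraction of peak throughput actually achieved."""
    return achieved_tflops / peak_tflops

print(f"{utilization(225.0):.0%}")  # about 72% of peak
```

Under that assumption, end-to-end training sits close to the top of the 50–73% range quoted for the attention kernel itself.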

Beyond attention: retention layers and the “stage retention” concept

Recent research argues that attention alone is not enough for models that need to adapt over time or operate across sessions. In January 2025, researchers proposed a Retention Layer that augments transformers with a persistent memory module. Whereas attention provides a mechanism for weighting information within the current context, the retention layer stores patterns from past interactions and enables them to be recalled later. This design parallels the stages of social learning: attention, retention, reproduction and motivation. The retention module populates a memory buffer in real time, retrieves relevant templates for new inputs and guides output generation. Because the memory can persist across sessions, the model can learn incrementally without retraining and can personalize its behaviour based on past user interactions. Applications envisioned for retention layers include adaptive personal assistants, fraud detection and autonomous systems.
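The store-and-recall behaviour described here can be illustrated with a toy memory buffer. This is a hypothetical sketch and not the architecture from the retention-layer proposal: the class name, the cosine-similarity lookup and the oldest-first eviction are all assumptions made for illustration, and keys are assumed to be non-zero vectors. Persisting the buffer across sessions would additionally require saving it to disk, which is omitted.

```python
import numpy as np

class RetentionBuffer:
    """Hypothetical memory: stores (embedding, payload) pairs, recalls by cosine similarity."""

    def __init__(self, dim, capacity=1024):
        self.dim = dim
        self.capacity = capacity
        self.keys = np.empty((0, dim))
        self.payloads = []

    def store(self, key, payload):
        """Add one pattern; the oldest entries are dropped once capacity is exceeded."""
        key = np.asarray(key, dtype=float).reshape(1, self.dim)
        key = key / np.linalg.norm(key)            # unit length, so dot product = cosine
        self.keys = np.vstack([self.keys, key])[-self.capacity:]
        self.payloads = (self.payloads + [payload])[-self.capacity:]

    def retrieve(self, query, top_k=1):
        """Return up to top_k (payload, similarity) pairs, most similar first."""
        if not self.payloads:
            return []
        query = np.asarray(query, dtype=float)
        sims = self.keys @ (query / np.linalg.norm(query))
        best = np.argsort(-sims)[:top_k]
        return [(self.payloads[i], float(sims[i])) for i in best]

buffer = RetentionBuffer(dim=3)
buffer.store([1.0, 0.0, 0.0], "pattern A")
buffer.store([0.0, 1.0, 0.0], "pattern B")
print(buffer.retrieve([0.9, 0.1, 0.0]))
```

In a real model the keys would be learned representations of past interactions, and the retrieved payloads would condition generation rather than simply being returned.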

Although retention layers are still early research, they hint at a future where transformers combine efficient attention (via FlashAttention 2) with persistent memory. This “stage retention” approach blends fast context-aware computation with long-term knowledge retention, bridging the gap between static pretraining and dynamic adaptation. For technical practitioners, understanding how these mechanisms interact is crucial: efficient attention reduces the cost of processing long sequences, while retention mechanisms provide continuity across time. Together they open the door to large-scale models that can handle long documents today and adapt to new information tomorrow.

Conclusion

The evolution from the original attention mechanism to FlashAttention 1 and FlashAttention 2 demonstrates how hardware-aware algorithms can dramatically improve the efficiency of transformer models. The original self-attention design suffers from quadratic time and memory complexity, limiting the sequence lengths that can be processed. FlashAttention 1 introduces tiling and online softmax to make attention IO-aware, achieving linear memory and substantial speed-ups without changing the result. FlashAttention 2 further optimizes work partitioning and parallelism, roughly doubling throughput and narrowing the gap to the hardware’s peak. Looking ahead, retention layers suggest that combining efficient attention with persistent memory may enable models that not only process long contexts efficiently but also remember and adapt over multiple sessions. For a technical audience, these developments highlight the importance of considering both algorithmic efficiency and memory dynamics when designing and deploying large-scale neural networks.