FlashAttention-3:快速、准确的异步和低精度注意力
FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-Precision

原始链接: https://www.together.ai/blog/flashattention-3

在最近的进展中,引入了名为 FlashAttention-3 的新版 FlashAttention 算法,以提高图形处理单元 (GPU) 上大型语言模型 (LLM) 中注意力机制的速度。 该版本利用了现代硬件(特别是 Hopper GPU)的新功能,利用重叠计算和数据移动、交错分块矩阵乘法和 softmax 运算以及支持较低精度计算的不相干处理等技术。 FlashAttention-3 以 H100 GPU 理论最大能力的 75% 运行,展示了 GPU 利用率的提高,而 FlashAttention-2 之前达到的仅为 35%。 这使得法学硕士的训练和执行速度更快,速度提高了 1.5-2 倍。 此外,FlashAttention-3 在处理较低精度数字 (FP8) 时可提供更好的性能,从而实现更快的处理并可能减少内存使用量,从而有助于节省成本并提高大规模 AI 操作的效率。 最后,新方法使人工智能模型能够更有效地处理法学硕士中的较长上下文,使应用程序能够理解和生成更长、更复杂的文本,而不会损失性能。 FlashAttention-3 的源代码可在 GitHub 上访问,更多信息可在相关研究论文中找到。 关键点: - 高效的 GPU 利用率:使用 H100 GPU 最大功能的 75%,为训练和执行 LLM 提供更快的速度。 - 精度较低,性能更好:允许使用精度较低的数字 (FP8),同时仍确保高精度,从而加快处理速度,并可能减少内存使用量。 - 更长的上下文支持:使人工智能模型能够有效地处理更长的文本片段,促进应用程序能够理解和生成更长、更复杂的内容,而不会减慢速度。 - 速度提高:与 FlashAttention-2 相比,FLM 处理速度提高了 1.5-2 倍。 - 可用性:可在 GitHub 上访问,相关研究论文中提供了更多信息。

这里讨论的问题是计算机程序的优化,特别是关于数据放置和最大化 CPU 使用率的问题。 演讲者指出,由于内存布局的变化,优化中的微小变化可能会产生最小的改进。 他们提到了一个专注于这个问题的演示,并指出由于缓存行为、访问模式以及不同的 CPU 和内存架构等因素造成的复杂性。 在考虑图形处理单元 (GPU) 时,发言人表示这可能会导致未知领域。 人们提出了人工智能(AI)解决这些挑战的可能性,尽管这取决于“足够”的定义。 假设场景涉及 AI 模型使用 Micrograd 来提高 Torch 的性能。 然而,目前实现这一目标的实用性仍不确定。 演讲者强调了当前的历史意义,因为技术的快速进步带来了竞争优势。 他们对自己在人工智能领域的努力充满信心,并得到了与戴尔、Advizex 等公司以及即将推出的主要数据中心组织的合作伙伴关系的支持。 在人工智能领域,Flash Attention 被引入作为在语言模型内的注意力机制计算期间执行某些操作的方法。 它似乎提供了生成注意力计算所需矩阵的替代方法。 还提到了分组查询注意力(GQA)和滑动窗口注意力,它们似乎专注于改变注意力掩模而不是实际的矩阵生成过程。 这些方法似乎存在多种实现方式,可能需要根据不同的策略进行定制。 最后,通过代码库实现 Flash Attention 似乎最近取得了进展,但没有提供详细信息。 此外,Flash Attention 和 Triton 等系统之间的关系存在一些混乱——Triton 是一个可能的抽象层,允许跨平台灵活部署。 可能需要进一步的研究来阐明这些关系。
相关文章

原文

Attention, as a core layer of the ubiquitous Transformer architecture, is a bottleneck for large language models and long-context applications. FlashAttention (and FlashAttention-2) pioneered an approach to speed up attention on GPUs by minimizing memory reads/writes, and is now used by most libraries to accelerate Transformer training and inference. This has contributed to a massive increase in LLM context length in the last two years, from 2-4K (GPT-3, OPT) to 128K (GPT-4), or even 1M (Llama 3). However, despite its success, FlashAttention has yet to take advantage of new capabilities in modern hardware, with FlashAttention-2 achieving only 35% utilization of theoretical max FLOPs on the H100 GPU. In this blogpost, we describe three main techniques to speed up attention on Hopper GPUs: exploiting asynchrony of the Tensor Cores and TMA to (1) overlap overall computation and data movement via warp-specialization and (2) interleave block-wise matmul and softmax operations, and (3) incoherent processing that leverages hardware support for FP8 low-precision.

We’re excited to release FlashAttention-3 that incorporates these techniques. It’s 1.5-2.0x faster than FlashAttention-2 with FP16, up to 740 TFLOPS, i.e., 75% utilization of H100 theoretical max FLOPS. With FP8, FlashAttention-3 reaches close to 1.2 PFLOPS, with 2.6x smaller error than baseline FP8 attention.

The improvements from FlashAttention-3 will result in:

  1. More efficient GPU Utilization: The new technique uses up to 75% of an H100 GPU's maximum capabilities, up from just 35% before. This results in significantly (1.5-2x) faster than previous versions for training and running of large language models (LLMs).
  1. Better performance with lower precision: FlashAttention-3 can work with lower precision numbers (FP8) while maintaining accuracy. This allows for even faster processing and potentially lower memory usage, which could lead to cost savings and improved efficiency for customers running large-scale AI operations.
  1. Ability to use longer context in LLMs: By speeding up the attention mechanism, FlashAttention-3 enables AI models to work with much longer pieces of text more efficiently. This could allow for applications that can understand and generate longer, more complex content without slowing down.

FlashAttention-3 is available on Github here.

Read the paper here.

FlashAttention Recap

FlashAttention is an algorithm that reorders the attention computation and leverages tiling and recomputation to significantly speed it up and reduce memory usage from quadratic to linear in sequence length. We use tiling to load blocks of inputs from HBM (GPU memory) to SRAM (fast cache), perform attention with respect to that block, and update the output in HBM. By not writing the large intermediate attention matrices to HBM, we reduce the amount of memory reads/writes, which brings 2-4x wallclock time speedup.

Here we show a diagram of FlashAttention forward pass: with tiling and softmax rescaling, we operate by blocks and avoid having to read/write from HBM, while obtaining the correct output with no approximation.

New hardware features on Hopper GPUs - WGMMA, TMA, FP8

While FlashAttention-2 can achieve up to 70% theoretical max FLOPS on Ampere (A100) GPUs, it does not yet take advantage of new features on Hopper GPUs to maximize performance. We describe some of the new Hopper-specific features here, and why they are important.

1. WGMMA (Warpgroup Matrix Multiply-Accumulate). This new feature makes use of the new Tensor Cores on Hopper, with much higher throughput 1 than the older mma.sync instruction in Ampere (image from the H100 white paper).

2. TMA (Tensor Memory Accelerator). This is a special hardware unit that accelerates the transfer of data between global memory and shared memory, taking care of all index calculation and out-of-bound predication. This frees up registers, which is a valuable resource to increase tile size and efficiency.

3. Low-precision with FP8. This doubles the Tensor Core throughput (e.g. 989 TFLOPS with FP16 and 1978 TFLOPS with FP8), but trades off accuracy by using fewer bits to represent floating point numbers.

FlashAttention-3 makes use of all of these new features of Hopper, using powerful abstractions from NVIDIA’s CUTLASS library.

By rewriting FlashAttention to use these new features, we can already significantly speed it up (e.g., from 350 TFLOPS in FlashAttention-2 FP16 forward pass to around 540-570 TFLOPS). However, the asynchronous nature of the new instructions on Hopper (WGMMA and TMA) opens up additional algorithmic opportunities to overlap operations and thereby extract even greater performance. For this blogpost, we’ll explain two such techniques specific to attention. The generic technique of warp specialization, with separate producer and consumer warps doing TMA and WGMMA, is well-covered elsewhere in the context of GEMM and works the same here.

Asynchrony: Overlapping GEMM and Softmax

Why overlap?

Attention has GEMMs (those matmuls between Q and K and between attention probability P and V) and softmax as its two main operations. Why do we need to overlap them? Isn’t most of the FLOPS in the GEMMs anyway? As long as the GEMMs are fast (e.g., computed using WGMMA instructions), shouldn’t the GPU be going brrrr?

The problem is that non-matmul operations are much slower than matmul operations on modern accelerators. Special functions such as exponential (for the softmax) have even lower throughput than floating point multiply-add; they are evaluated by the multi-function unit, a unit separate from floating point multiply-add or matrix multiply-add. As an example, the H100 GPU SXM5 has 989 TFLOPS of FP16 matrix multiply, but only 3.9 TFLOPS (256x less throughput) for special functions 2 ! For head dimension 128, there are 512x more matmul FLOPS than exponential, which means that exponential can take 50% of the time compared to matmul. The situation is even worse for FP8, where the matmul FLOPS are twice as fast yet exponential FLOPS stay the same speed. Ideally we want matmul and softmax to operate in parallel. While the Tensor Cores are busy with matmul, the multi-function units should be calculating exponential!

Inter-warpgroup overlapping with pingpong scheduling

The first and easiest way to overlap GEMM and softmax is to do nothing at all! The warp schedulers already try to schedule warps so that if some warps are blocked (e.g., waiting for GEMM results), other warps can run. That is, the warp schedulers do some of this overlapping for us, for free.

However, we can improve on this by doing some of the scheduling manually. As an example, if we have 2 warpgroups (labeled 1 and 2 – each warpgroup is a group of 4 warps), we can use synchronization barriers (bar.sync) so that warpgroup 1 first does its GEMMs (e.g., GEMM1 of one iteration and GEMM0 of the next iteration), and then warpgroup 2 does its GEMMs while warpgroup 1 does its softmax, and so on. This “pingpong” schedule is illustrated in the figure below, where the same color denotes the same iteration.

This would allow us to perform the softmax in the shadow of the GEMMs of the other warpgroup. Of course, this figure is just a caricature; in practice the scheduling is not really this clean. Nevertheless, pingpong scheduling can improve FP16 attention forward pass from around 570 TFLOPS to 620 TFLOPS (head dim 128, seqlen 8K).

Intra-warpgroup overlapping of GEMM and Softmax

Even within one warpgroup, we can have some part of softmax running while the GEMMs of that warpgroup is running. This is illustrated in this figure, where the same color denotes the same iteration.

This pipelining increases throughput from around 620 TFLOPS to around 640-660 TFLOPS for FP16 attention forward, at the cost of higher register pressure. We need more registers to hold both accumulators of the GEMMs, and the input/output of softmax. Overall, we find this technique to offer a favorable tradeoff.

Low-precision: reduce quantization error with incoherent processing

LLM activation can have outliers with much larger magnitude than the rest of the features. These outliers make it difficult to quantize, producing much larger quantization errors. We leverage incoherent processing, a technique used in the quantization literature (e.g. from QuIP) that multiplies the query and key with a random orthogonal matrix to “spread out” the outliers and reduce quantization error. In particular, we use the Hadamard transform (with random signs), which can be done per attention head in O(d log d) instead of O(d^2) time, where d is the head dimension. Since the Hadamard transform is memory-bandwidth bound, it can be fused with previous operations such as rotary embedding (also memory-bandwidth bound) “for free”.

In our experiment where Q, K, V are generated from a standard normal distribution but 0.1% of the entries have large magnitudes (to simulate outliers), we found that incoherent processing can reduce the quantization error by 2.6x. We show numerical error comparison in the table below. Please see the paper for details.

Attention benchmark

We show some results with FlashAttention-3, and compare it to FlashAttention-2, as well as the implementation in Triton and cuDNN (both of which already use new hardware features of Hopper GPUs).

For FP16, we see about 1.6x-1.8x speedup over FlashAttention-2

For FP8, we can reach close to 1.2 PFLOPS!

Discussion

This blogpost highlights some of the optimizations for FlashAttention available on Hopper GPUs. Other optimizations (e.g., variable length sequences, persistent kernel, and in-kernel transpose for FP8) are covered in the paper.
We have seen that designing algorithms that take advantage of the hardware they run on can bring significant efficiency gains and unlock new model capabilities such as long context. We look forward to future work on optimization for LLM inference, as well as generalizing our techniques to other hardware architectures. 

We also look forward to FlashAttention-3 being integrated in a future release of PyTorch.

Footnotes:

1 Without the wgmma instruction, the older mma.sync instruction can only reach about ⅔ the peak throughput of Hopper Tensor Cores: https://arxiv.org/abs/2402.13499v1

2 The CUDA programming guide specifies that the throughput for special functions is 16 operations per streaming multiprocessor (SM) per clock cycle. We multiply 16 by 132 SMs and 1830 Mhz (clock speed used to calculate 989 TFLOPS of FP16 matmul) to get 3.9 TFLOPS

联系我们 contact @ memedata.com