What is slow attention? In attention, the Query and Key matrices are multiplied together. Here is how that plays out in memory: the Query, Key, and Value matrices live in High Bandwidth Memory (HBM), outside the GPU cores, so every calculation means frequent travel back and forth -> global memory access.

Multi-Step Computation: The process involves several sequential steps: computing the dot product between the Query and Key matrices, applying the softmax function to get attention weights, and then computing the weighted average of the Value matrix.

Frequent Global Memory Access: On a GPU, the large matrices (Query, Key, Value, intermediate dot-product results, and the final attention matrix) are often stored in High Bandwidth Memory (HBM), which is relatively slow compared to the GPU's on-chip SRAM. Each step of the attention computation requires loading data from HBM to the GPU cores, performing the calculation, and then writing the results back to HBM. This constant movement of large amounts of data between HBM and the processing cores is known as global memory access and is a significant bottleneck that adds latency.

Materializing Large Matrices: For long input sequences, the attention matrix (which is sequence length by sequence length) can become extremely large. Storing and repeatedly accessing this large matrix in HBM consumes significant memory bandwidth and time.

Inefficient Softmax Implementation: The standard way to compute softmax can require multiple passes over the data to ensure numerical stability, further adding to the computation time and memory access.

-> SRAM has the highest bandwidth but is very small, so the full attention matrix cannot fit in SRAM. This is where tiling comes in.

Tiling -> In the running C = A * B example, the naive computation needs 32 memory accesses, while the tiled version needs only 16. More generally, with N x N tiles the number of global memory accesses drops by a factor of N -> each tile can now live in SRAM. The remaining issue is that softmax is not numerically stable as written: before exponentiating we subtract the maximum value, which is known as safe softmax -> this is very I/O inefficient, because we iterate over the same data three times (once for the maximum, once for the normalizer, once for the outputs).

Flash Attention is an efficient algorithm designed to compute the attention mechanism used in Transformer models. The traditional self-attention mechanism can be slow and memory-intensive, especially with long sequences, because it involves computing and storing a large attention matrix. Flash Attention addresses these limitations by being a hardware-aware algorithm that is fast, memory-efficient, and, importantly, exact, meaning it does not use approximations that sacrifice accuracy. It achieves this by leveraging techniques like tiling and fusing computations.

Specifically, Flash Attention avoids materializing the full N x N attention matrix in high-bandwidth memory (HBM). Instead, it partitions the Query, Key, and Value matrices into smaller tiles, loads these tiles into the faster on-chip SRAM, performs the attention computation for those tiles, and updates the partial results directly in SRAM or HBM. This process is repeated for all tiles, fusing the computations and drastically reducing the amount of data transferred between the GPU cores and the slower HBM. The algorithm builds on "online softmax," a trick that computes the softmax function in fewer passes over the input, extended here to the full self-attention computation.
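To make the online softmax idea concrete, here is a minimal sketch in PyTorch (the function name `online_softmax` is my own illustrative choice, not a library API). It maintains a running maximum and a running normalizer in a single pass, cutting the three passes of safe softmax down to two; Flash Attention goes further and folds the remaining normalization into the output accumulation.

```python
import torch

def online_softmax(x: torch.Tensor) -> torch.Tensor:
    """Numerically stable softmax over a 1-D tensor using the online trick.

    The running maximum m and the running normalizer s are updated together
    in one pass; whenever a larger maximum appears, the old sum is rescaled.
    (Educational sketch -- a real kernel does this per tile, on-chip.)
    """
    m = torch.tensor(float("-inf"))   # running maximum
    s = torch.tensor(0.0)             # running sum of exp(x_i - m)
    for xi in x:
        m_new = torch.maximum(m, xi)
        s = s * torch.exp(m - m_new) + torch.exp(xi - m_new)  # rescale old sum, add new term
        m = m_new
    return torch.exp(x - m) / s       # second (and final) pass: normalize

x = torch.randn(8)
print(torch.allclose(online_softmax(x), torch.softmax(x, dim=0), atol=1e-6))  # True
```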
This approach allows Flash Attention to compute the attention mechanism exactly while drastically improving speed and memory usage compared to standard implementations and previous approximate methods. Its effectiveness has led to further developments such as Flash Attention 2 and 3. In short, Flash Attention avoids materializing the large attention matrix by fusing all of the computation together; a sketch of this tiled, fused forward pass follows.
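Below is a minimal, CPU-side sketch of that fused, tiled forward pass in PyTorch for a single attention head. The function name `flash_attention_forward` and the tile size `block` are my own illustrative choices (the real implementation is a fused CUDA kernel); the point is that the N x N score matrix is never materialized: scores are computed one Key/Value tile at a time and folded into running statistics with the online softmax update.

```python
import math
import torch

def flash_attention_forward(Q, K, V, block=64):
    """Tiled attention forward pass for one head, never forming the N x N score matrix.

    Q, K, V: (N, d) tensors. `block` plays the role of the tile that would sit in SRAM.
    Educational sketch of the algorithm, not the real fused CUDA kernel.
    """
    N, d = Q.shape
    scale = 1.0 / math.sqrt(d)
    O = torch.zeros_like(Q)                # running, still-unnormalized output
    m = torch.full((N, 1), float("-inf"))  # running row-wise maximum of the scores
    l = torch.zeros(N, 1)                  # running row-wise sum of exp(score - m)

    for start in range(0, N, block):       # iterate over Key/Value tiles
        Kb = K[start:start + block]        # (B, d) tile, the part "loaded into SRAM"
        Vb = V[start:start + block]
        S = (Q @ Kb.T) * scale             # (N, B) scores for this tile only

        m_new = torch.maximum(m, S.max(dim=1, keepdim=True).values)
        alpha = torch.exp(m - m_new)       # rescales the statistics of previous tiles
        P = torch.exp(S - m_new)           # softmax numerators for this tile
        l = l * alpha + P.sum(dim=1, keepdim=True)
        O = O * alpha + P @ Vb             # fusion: weights hit V immediately, no big matrix
        m = m_new

    return O / l                           # single normalization at the end

# The tiled, fused result matches standard (materialized) attention.
N, d = 256, 32
Q, K, V = torch.randn(N, d), torch.randn(N, d), torch.randn(N, d)
reference = torch.softmax((Q @ K.T) / math.sqrt(d), dim=1) @ V
print(torch.allclose(flash_attention_forward(Q, K, V), reference, atol=1e-4))  # True
```

Keeping only the per-row running maximum `m`, normalizer `l`, and unnormalized output `O` is what lets each tile be processed once and then discarded, which is exactly what makes the on-chip SRAM version of the computation possible.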