type
status
date
slug
summary
tags
category
icon
password
update_time
Sep 19, 2023 03:27 AM
create_time
Jul 18, 2023 02:39 AM
Flash Attention:FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
2022年的很有意思的一篇文章,相比于别的各类魔改Attention,Flash Attention并未改变原有的计算公式,因此整体计算复杂度并未降低。但是该论文从内存访问,IO开销等角度出发,通过减少GPU内存读写次数,将Attention的计算速度提升2到4倍,同时内存也减少5到20倍。
主要贡献如下2点:
- 在不访问整个输入的情况下计算softmax,即拆成分段进行计算。
- 在反向传播中不能存储中间注意力矩阵(的矩阵),只存储softmax归一化的系数。
先看下标准Attention的计算过程:
在Flash Attention之前,、作为中间变量需要存下来,因此需要的内存开销。
Flash Attention采用了softmax tilling:
- 正常的softmax计算:
- 分块的softmax计算(以分2块并行计算为):
伪代码实现如下:

中间变量:(最终乘积)、(softmax的分母,即累加和)、(遍历到当前块为止的最大值),再也不用保存全部的和了。
在撰写本篇时,Flash AttentionV2版本刚刚发布,目前看来主要是在从头开始编写,采用了更底层的CUTLASS库,然后还支持了Multi-query Attention和grouped-query Attention。这2个Attention主要是通过共享一些Key和Value,减少缓存,提高运行效率。 从某种程度来说,Flash Attention是单纯的程序优化,包含了类似刷Leetcode的思想,活学活用,这才是程序员撰写代码首先应想到的解法,即在不动算法数学逻辑的基础上,采用诸如动态规划、贪心等算法优化程序性能。
Sparse Attention(来源于Sparse Transformer) = Strided Attention + Fixed Attention
- Strided attention: the i-th position can attend to the j-th position if (i+s) > j > ( i-s) or (i-j) mod s = 0
- Fixed attention: the i-th position can attend to the j-th position if floor(j/s) = floor(i/s) or (j mod s) ≥ (s-c)
其中s为stride length,c为超参数。

计算复杂度:,其中为序列长度,为隐藏层的维度()。
计算复杂度:为什么是?是因为论文中(stride length)设置为。
BigBird: Transformers for Longer Sequences
2021年的文章,主要是降低Vanilla Transformer中Attention的计算复杂度,提出了3种Attention的融合:

- sliding attention:本质上就是window attention,滑动指定大小的窗口进行attention计算。计算复杂度为
为什么乘以3,是因为原论文实现时:将key向量复制2次,一次左移一个单位,一次右移一个单位,然后算上原有的,用query向量与这3个向量相乘,就实现了所有sliding tokens的计算(本质上窗口大小设置为3)。 可以这样理解下:之前像query和key都是(batch_size, max_seq_len, num_heads * head_dim)的向量,那么这本质上还是刨除batch_size的二维向量,需要进行max_seq_len的平方次运算,用N代表max_seq_len,即。现在变成滑窗后,很多运算都可以减少了,那么其实就是,为窗口大小(每个token只与个token计算点积)。
- global attention:本质上就是预定义几个token或者指定位置的token,只有这些token能参与到与其他token的attention score计算。论文中考虑的是仅使用第一个和最后一个位置的token来作为全局token进行计算,计算复杂度为。
说实话,这个全局attention直接只使用首尾token,并不一定是全局使用最优的token,效果能生效,有点类似CLS、SEP这类特殊token的代表意思。
- random attention: 与global attention类似,只不过这里是随机挑选的token(需要注意不要挑选首尾的token,否则与global attention重复了)。论文实现时,只随机挑除了首尾的3个位置的token,计算复杂度为。
BigBird使用的是3者的融合版本,实际上计算复杂度基本上是降到了,基本上达到了线性(计算量随文本长度增加而线性增长)。
其实BigBird有ITC(internal transformer construction)和ETC(external transformer construction)2种训练方式,前面介绍的都是ITC,ETC相比ITC用了更多的token,差异如下:

LongNet: Scaling Transformers to 1,000,000,000 Tokens
LongNet是近期(2023年7月)的文章,号称是能够将Transformer支持的长度提升到1B,即10亿。在此之前最长的是今年初的Scaling Transformer to 1M tokens and beyond with RMT这篇论文(OpenAI前成员创立的Anthropic的类ChatGPT产品——Claude),其采用的是2022年发表的RMT(Recurrent Memory Transformer)技术,但也只声称达到10K,即100万的长度支持。

引入了dilated attention机制,主要对、、进行稀疏处理(有点类似Atrous/Dilation Convolution)。

LongNet涉及到2个超参数:
- segment length:分段的长度
- dilation rate:膨胀率。当时,表示在计算attention score时,每个segment抽取的行组成小矩阵进行计算(其实就是segment中2个token间的间距控制)。
在位置编码设计上,综合考虑了2个:
- Relative Position Bias:相对位置偏置,实际上是计算Attention时候公式为,其中即为偏置。
- Rotary Position Embedding (xPos):这是旋转位置编码(RoPE)的一种改进算法。它加入指数衰减校正,以及blockwise causal attention,让模型忽略相距较远的语义关联。
计算复杂度降到了,计算逻辑与BigBird很类似(BigBird选用了3+2+3):segment length设置了{2048,4096,8192,16384,32768},dilation rate设置了{1,2,4,6,12},因此总体还是。
此外,在Attention计算时,支持分布式计算的支持。常规的MultiHeadAttention是统一计算的,这里采用了ModuleList去计算。(看源码的README举例中使用了torch.distributed.all_gather功能,实际上源码中未放出完整的分布式训练代码)。
看到有人吐槽这个LongNet跟BigBird的Attention组合方式看着差不多,计算复杂度都是降到了,实际上还是有些区别的。 首先,LongNet这个提出的Dilated Attention,更为通用。相对而言,BigBird使用的3种Attention都是Dilated Attention的不同参数设置。 其次,LongNet提出了并行程度更高的Attention计算拆分方式,并论证了可行性。主要是指将上图的这些区域分开计算,包括softmax和value的相乘,这在以往的实现是没有的,大多数还是在softmax前就合并了。
RetNet:Retentive Network: A Successor to Transformer for Large Language Models

爆改Transformer,使其既能并行训练,又能串行解码。论文宣称,解码速度是带键值缓存的Transformers的8.4倍,内存节省70%。宣称是Transformer的有力继承者。
计算复杂度降到了,与RWKV这种基于RNN的结构很类似。主要通过并行、递归、分块递归来提升(这点很类似LongNet)。
- 其并行公式如下:
可以看到本质上还是、、的复合运算,其实是xPos和causal masking的组合。只不过,这里的运算去掉了softmax。

- 转换成RNN形式的递归公式:
这里简单验证下并行公式到递归公式的转化,用以论证结果是一致的:
假设我们有一个token长度为3的输入,那么输出的Attention(这里为Retention) 输出维度应该为(3,vec_dim)。
先来考虑并行的方式的结果:
其中、、表示、、这3个全连接的输出向量,其维度均为(1, vec_dim)。每个对应的最终attention输出维度为(1,vec_dim)。3行对应3个token——。
转化为递归公式的理解:
合并后能看到结果是一样的。
- 还可以转换成分块计算的公式形式:
上面三种形式的伪代码:

论文使用了一种Gated Multi-Scale Retention的方式来计算每层的输出,本质上有点类似Transformer的MultiHeadAttention。计算方式如下:
上面又引入新的参数和,,其中为head dim。正常Transformer计算score时,会有一个softmax用来缩放,这里采用了GroupNorm(正常在batch size较小时使用,如小于16)。
最后看整体的网络结构:
可以看到还是Transformer那一套,用的是Pre-Norm的方式。
论文中有个用法很神奇,既然我们得到上面三种形式的计算方式,那么怎么使用呢?都一起使用吗? 论文中采用的方式: 训练时:采用并行公式和分块计算的公式 预测时:采用递归公式进行解码 因此训练和预测的解码逻辑是不一样的,但是详细推理下,这三个公式,本质上是一样的。 训练时使用的分块计算来加速,其中间每个块的计算其实仍然是并行的公式。在推理时,递归公式使用可以让整个过程不依赖于句子长度,理论上可以无限长,当然可视为衰减因子,仍然会像RWKV那样,对Prompt比较敏感,需要尽量先输入Task Prompt,再输入Context Prompt,即带着问题找答案,否则会因为信息衰减导致找不到重点而效果大打折扣。 这篇文章的创新点并行和分块在诸如LongNet之类的算法中已有体现,主要是把公式顺便推导到递归形式,这点就牛逼。
其它比较好的总结
- Demystifying efficient self-attention:列了一些比较典型的Attention进行了剖析
- 作者:哒鸽
- 链接:https://dage110.com/article/attention
- 声明:本文采用 CC BY-NC-SA 4.0 许可协议,转载请注明出处。