type
status
date
slug
summary
tags
category
icon
password
update_time
Sep 19, 2023 03:27 AM
create_time
Jul 18, 2023 02:39 AM

Flash AttentionFlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

2022年的很有意思的一篇文章,相比于别的各类魔改Attention,Flash Attention并未改变原有的计算公式,因此整体计算复杂度并未降低。但是该论文从内存访问,IO开销等角度出发,通过减少GPU内存读写次数,将Attention的计算速度提升2到4倍,同时内存也减少5到20倍。
主要贡献如下2点:
  1. 在不访问整个输入的情况下计算softmax,即拆成分段进行计算。
  1. 在反向传播中不能存储中间注意力矩阵(N2N^2的矩阵),只存储softmax归一化的系数。
先看下标准Attention的计算过程:
S=QKTRN×NP=softmax(S)RN×NO=PVRN×NS=QK^T\in{\mathbb{R}^{N\times{N}}}\\ P=softmax(S)\in{\mathbb{R}^{N\times{N}}}\\ O=PV\in{\mathbb{R}^{N\times{N}}}
在Flash Attention之前,SSPP作为中间变量需要存下来,因此需要N2N^2的内存开销。
Flash Attention采用了softmax tilling:
  • 正常的softmax计算:
    • m(x):=maxi xif(x):=[ex1m(x) ... exBm(x)](x):=if(x)isoftmax(x):=f(x)(x)m(x):=\mathop{max}\limits_{i} \ xi\\ f(x):=[e^{x_1-m(x)}\ ...\ e^{x_B-m(x)}]\\ \ell(x):=\sum_{i}{f(x)_i}\\ softmax(x):=\frac{f(x)}{\ell(x)}
  • 分块的softmax计算(以分2块并行计算为):
    • m(x)=m([x(1) x(2)])=max(m(x(1)),m(x(2)))f(x)=[em(x(1))m(x)f(x(1))    em(x(2))m(x)f(x(2))](x)=([x(1)  x(2)])=em(x(1))m(x)(x(1))+em(x(2))m(x)(x(2))softmax(x)=f(x)(x)m(x)=m([x^{(1)}\ x^{(2)}])=max(m(x^{(1)}),m(x^{(2)}))\\ f(x)=\left[e^{m(x^{(1)})-m(x)}f(x^{(1)})\ \ \ \ e^{m(x^{(2)})-m(x)}f(x^{(2)}) \right]\\ \ell(x)=\ell([x^{(1)}\ \ x^{(2)}])=e^{m(x^{(1)})-m(x)}\ell(x^{(1)})+e^{m(x^{(2)})-m(x)}\ell(x^{(2)})\\ softmax(x)=\frac{f(x)}{\ell(x)}
伪代码实现如下:
notion image
中间变量:OiO_i(最终乘积)、i\ell_i(softmax的分母,即累加和)、mim_i(遍历到当前块为止的最大值),再也不用保存全部的SSPP了。
 
在撰写本篇时,Flash AttentionV2版本刚刚发布,目前看来主要是在从头开始编写,采用了更底层的CUTLASS库,然后还支持了Multi-query Attention和grouped-query Attention。这2个Attention主要是通过共享一些Key和Value,减少缓存,提高运行效率。 从某种程度来说,Flash Attention是单纯的程序优化,包含了类似刷Leetcode的思想,活学活用,这才是程序员撰写代码首先应想到的解法,即在不动算法数学逻辑的基础上,采用诸如动态规划、贪心等算法优化程序性能。
notion image

Sparse Attention(来源于Sparse Transformer) = Strided Attention + Fixed Attention

  1. Strided attention: the i-th position can attend to the j-th position if (i+s) > j > ( i-s) or (i-j) mod s = 0
  1. Fixed attention: the i-th position can attend to the j-th position if floor(j/s) = floor(i/s) or (j mod s) ≥ (s-c)
    1. 其中s为stride length,c为超参数。
notion image
计算复杂度:O(NNd)O(N\sqrt{N}d),其中NN为序列长度,dd为隐藏层的维度(head_dim=hidden_size/num_heads\text{head\_dim}=\text{hidden\_size}/\text{num\_heads})。
计算复杂度:为什么是N\sqrt{N}?是因为论文中ss(stride length)设置为N\sqrt{N}

BigBird: Transformers for Longer Sequences

2021年的文章,主要是降低Vanilla Transformer中Attention的计算复杂度,提出了3种Attention的融合:
notion image
  • sliding attention:本质上就是window attention,滑动指定大小的窗口进行attention计算。计算复杂度为O(3×N)=O(N)O(3\times{N})=O(N)
    • 为什么乘以3,是因为原论文实现时:将key向量复制2次,一次左移一个单位,一次右移一个单位,然后算上原有的,用query向量与这3个向量相乘,就实现了所有sliding tokens的计算(本质上窗口大小设置为3)。 可以这样理解下:之前像query和key都是(batch_size, max_seq_len, num_heads * head_dim)的向量,那么这本质上还是刨除batch_size的二维向量,QKTQK^T需要进行max_seq_len的平方次运算,用N代表max_seq_len,即O(N2)O(N^2)。现在变成滑窗后,很多运算都可以减少了,那么其实就是O(wN)O(wN)ww为窗口大小(每个token只与ww个token计算点积)。
  • global attention:本质上就是预定义几个token或者指定位置的token,只有这些token能参与到与其他token的attention score计算。论文中考虑的是仅使用第一个和最后一个位置的token来作为全局token进行计算,计算复杂度为O(2N)O(2N)
    • 说实话,这个全局attention直接只使用首尾token,并不一定是全局使用最优的token,效果能生效,有点类似CLS、SEP这类特殊token的代表意思。
  • random attention: 与global attention类似,只不过这里是随机挑选的token(需要注意不要挑选首尾的token,否则与global attention重复了)。论文实现时,只随机挑除了首尾的3个位置的token,计算复杂度为O(3N)O(3N)
BigBird使用的是3者的融合版本,实际上计算复杂度基本上是降到了O(3N+2N+3N)=O(8N)=O(N)O(3N+2N+3N)=O(8N)=O(N),基本上达到了线性(计算量随文本长度增加而线性增长)。
其实BigBird有ITC(internal transformer construction)和ETC(external transformer construction)2种训练方式,前面介绍的都是ITC,ETC相比ITC用了更多的token,差异如下:
notion image

LongNet: Scaling Transformers to 1,000,000,000 Tokens

LongNet是近期(2023年7月)的文章,号称是能够将Transformer支持的长度提升到1B,即10亿。在此之前最长的是今年初的Scaling Transformer to 1M tokens and beyond with RMT这篇论文(OpenAI前成员创立的Anthropic的类ChatGPT产品——Claude),其采用的是2022年发表的RMT(Recurrent Memory Transformer)技术,但也只声称达到10K,即100万的长度支持。
notion image
引入了dilated attention机制,主要对QQKKVV进行稀疏处理(有点类似Atrous/Dilation Convolution)。
notion image
LongNet涉及到2个超参数:
  • segment length:分段的长度
  • dilation rate:膨胀率。当dilation_rate=n\text{dilation\_rate}=n时,表示在计算attention score时,每个segment抽取1n\frac{1}{n}的行组成小矩阵进行计算(其实就是segment中2个token间的间距控制)。
在位置编码设计上,综合考虑了2个:
  • Relative Position Bias:相对位置偏置,实际上是计算Attention时候公式为softmax(QKT+B)Vsoftmax(QK^T+B)V,其中BB即为偏置。
  • Rotary Position Embedding (xPos):这是旋转位置编码(RoPE)的一种改进算法。它加入指数衰减校正,以及blockwise causal attention,让模型忽略相距较远的语义关联。
计算复杂度降到了O(N)O(N),计算逻辑与BigBird很类似(BigBird选用了3+2+3):segment length设置了{2048,4096,8192,16384,32768},dilation rate设置了{1,2,4,6,12},因此总体还是O(5N)=O(N)O(5N)=O(N)
此外,在Attention计算时,支持分布式计算的支持。常规的MultiHeadAttention是统一计算的,这里采用了ModuleList去计算。(看源码的README举例中使用了torch.distributed.all_gather功能,实际上源码中未放出完整的分布式训练代码)。
看到有人吐槽这个LongNet跟BigBird的Attention组合方式看着差不多,计算复杂度都是降到了O(N)O(N),实际上还是有些区别的。 首先,LongNet这个提出的Dilated Attention,更为通用。相对而言,BigBird使用的3种Attention都是Dilated Attention的不同参数设置。 其次,LongNet提出了并行程度更高的Attention计算拆分方式,并论证了可行性。主要是指将上图的这些区域分开计算,包括softmax和value的相乘,这在以往的实现是没有的,大多数还是在softmax前就合并了。
 

RetNet:Retentive Network: A Successor to Transformer for Large Language Models

notion image
爆改Transformer,使其既能并行训练,又能串行解码。论文宣称,解码速度是带键值缓存的Transformers的8.4倍,内存节省70%。宣称是Transformer的有力继承者。
计算复杂度降到了O(N)O(N),与RWKV这种基于RNN的结构很类似。主要通过并行、递归、分块递归来提升(这点很类似LongNet)。
  • 并行公式如下:
Q=(XWQ)Θ, K=(XWK)Θ, V=XWVΘn=einθ,Dnm={γ(nm),nm0,n<mRetention(X)=(QKTD)VQ=(XW_Q)\odot\Theta,\ K=(XW_K)\odot\overline{\Theta}, \ V=XW_V\\ \Theta_n=e^{in\theta},D_{nm}=\left\{ \begin{array}{rr}\gamma^{(n-m)}, &n\geq{m} \\ 0, &n<m\end{array}\right. \\ \text{Retention}(X)=(QK^T\odot{D})V
可以看到本质上还是QQKKVV的复合运算,DD其实是xPos和causal masking的组合。只不过,这里的运算去掉了softmax。
notion image
  • 转换成RNN形式的递归公式:
    • Sn=γSn1+KnTVnRetention(Xn)=QnSn, n=1,...,xS_n=\gamma{S_{n-1}}+K_n^TV_n\\ \text{Retention}(X_n)=Q_nS_n, \ n=1,...,|x|
      这里简单验证下并行公式到递归公式的转化,用以论证结果是一致的:
      假设我们有一个token长度为3的输入(x1,x2,x3)(x_1,x_2,x_3),那么输出的Attention(这里为Retention) 输出维度应该为(3,vec_dim)。 先来考虑并行的方式的结果:
      Retention(X)=[q1k1Tv1γq2k1Tv1+q2k2Tv2γ2q3k1Tv1+γq3k2Tv2+q3k3Tv3]\text{Retention}(X)=\begin{bmatrix} q_1*k_1^T*v_1\\ \gamma*q_2*k_1^T*v_1+q_2*k_2^T*v_2 \\ \gamma^2*q_3*k_1^T*v_1+\gamma*q_3*k_2^T*v_2+q_3*k_3^T*v_3 \end{bmatrix}
      其中qiq_ikik_iviv_i表示QQKKVV这3个全连接的输出向量,其维度均为(1, vec_dim)。每个xix_i对应的最终attention输出维度为(1,vec_dim)。3行对应3个token——(x1,x2,x3)(x_1,x_2,x_3)。 转化为递归公式的理解:
      S1=k1Tv1Retention(X1)=q1S1=q1k1Tv1S2=γS1+k2Tv2=γk1Tv1+k2Tv2Retention(X2)=q2S2=q2(γk1Tv1+k2Tv2)=γq2k1Tv1+q2k2Tv2S3=γS2+k3Tv3=γ(γk1Tv1+k2Tv2)+k3Tv3=γ2k1Tv1+γk2Tv2+k3Tv3Retention(X3)=q3S3=q3(γ2k1Tv1+γk2Tv2+k3Tv3)=γ2q3k1Tv1+γq3k2Tv2+q3k3Tv3\begin{align*} S_1&=k_1^T*v_1\\ \text{Retention}(X_1)&=q_1*S_1=q_1*k_1^T*v_1\\ S_2&=\gamma*S_1+k_2^T*v_2\\&=\gamma*k_1^T*v_1+k_2^T*v_2\\ \text{Retention}(X_2)&=q_2*S_2\\&=q_2*(\gamma*k_1^T*v_1+k_2^T*v_2)\\&=\gamma*q_2*k_1^T*v_1+q_2*k_2^T*v_2\\ S_3&=\gamma*S_2+k_3^T*v_3\\&=\gamma*(\gamma*k_1^T*v_1+k_2^T*v_2)+k_3^T*v_3\\&=\gamma^2*k_1^T*v_1+\gamma*k_2^T*v_2+k_3^T*v_3\\ \text{Retention}(X_3)&=q_3*S_3\\&=q_3*(\gamma^2*k_1^T*v_1+\gamma*k_2^T*v_2+k_3^T*v_3)\\&=\gamma^2*q_3*k_1^T*v_1+\gamma*q_3*k_2^T*v_2+q_3*k_3^T*v_3 \end{align*}
      合并后能看到结果是一样的。
  • 还可以转换成分块计算的公式形式:
    • Qi=QBi:B(i+1), Ki=KBi:B(i+1), Vi=VBi:B(i+1)Ri=KiTVi+γBRi1Retention(Xi)=(QiKiTD)Vi)Inner-Chunk+(QiRi)ξCross-Chunk, ξij=γi+1Q_{|i|}=Q_{{Bi}:{B(i+1)}},\ K_{|i|}=K_{{Bi}:{B(i+1)}}, \ V_{|i|}=V_{{Bi}:{B(i+1)}} \\ R_i=K_{|i|}^TV_{|i|}+\gamma^BR_{i-1}\\ \text{Retention}(X_{|i|})=\underbrace{(Q_{|i|}K_{|i|}^T\odot{D})V_{|i|})}_{\text{Inner-Chunk}}+\underbrace{(Q_{|i|}R_i)\odot\xi}_{\text{Cross-Chunk}},\ \xi_{ij}=\gamma^{i+1}
上面三种形式的伪代码:
notion image
论文使用了一种Gated Multi-Scale Retention的方式来计算每层的输出,本质上有点类似Transformer的MultiHeadAttention。计算方式如下:
γ=125arange(0,h)Rhheadi=Retention(X,γi)Y=GroupNormh(Concat(head1,...,headn))MSR(X)=(swish(XWG))Y)WO\begin{align*} \gamma&=1-2^{-5-\text{arange}(0,h)} \in \mathbb{R}^h\\ \text{head}_i&=\text{Retention}(X,\gamma_i)\\ Y&=\text{GroupNorm}_h(\text{Concat}(head_1,...,head_n))\\ \text{MSR}(X)&=(\text{swish}(XW_G))\odot{Y})W_O \end{align*}
上面又引入新的参数WOW_OWGW_Gh=dmodel/dh=d_{\text{model}}/d,其中dd为head dim。正常Transformer计算score时,会有一个softmax用来缩放,这里采用了GroupNorm(正常在batch size较小时使用,如小于16)。
最后看整体的网络结构:
Yl=MSR(LN(Xl))+XlXl+1=FFN(LN(Yl))+Yl\begin{align*} Y^l&=\text{MSR}(\text{LN}(X^l))+X^l\\ X^{l+1}&=\text{FFN}(\text{LN}(Y^l))+Y^l \end{align*}
可以看到还是Transformer那一套,用的是Pre-Norm的方式。
论文中有个用法很神奇,既然我们得到上面三种形式的计算方式,那么怎么使用呢?都一起使用吗? 论文中采用的方式: 训练时:采用并行公式和分块计算的公式 预测时:采用递归公式进行解码 因此训练和预测的解码逻辑是不一样的,但是详细推理下,这三个公式,本质上是一样的。 训练时使用的分块计算来加速,其中间每个块的计算其实仍然是并行的公式。在推理时,递归公式使用可以让整个过程不依赖于句子长度,理论上可以无限长,当然γ\gamma可视为衰减因子,仍然会像RWKV那样,对Prompt比较敏感,需要尽量先输入Task Prompt,再输入Context Prompt,即带着问题找答案,否则会因为信息衰减导致找不到重点而效果大打折扣。 这篇文章的创新点并行和分块在诸如LongNet之类的算法中已有体现,主要是把公式顺便推导到递归形式,这点就牛逼。
 

其它比较好的总结

  1. Demystifying efficient self-attention:列了一些比较典型的Attention进行了剖析
 
论文速读——2023.7.21突破大语言模型长度外推方式调研
  • Giscus
浅色模式