Softmax是Transformer模型架构中非常重要的一环。它所在的Attention模块虽然所需要的计算量不大,但也是不容忽视的一环。同时由于它本身的数学特性所造成的数据依赖,如果按照其原始方法来进行运算,会耗费大量的计算时间,因为它需要三次完整读取数据。
Online normalizer calculation for softmax 提出了online softmax,通过牺牲计算来节省数据读取次数,将三遍完整读取(3 passes)降低到两遍完整读取(2 passes)。
FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness 则应用了类似的思想,更进一步,利用NVIDIA GPU的本地存储,将读取次数减少为一遍。
那么就让我们详细了解一下其中奥秘吧。本文会出现数学公式,但不要慌张,仅仅是简单的数组知识而已。同时也会辅以可以运行的Python代码,以便理解。
Softmax的数学表达
首先,softmax作用在一维vector上,而非多维tensor。
softmax(X⃗[1:N])=exi∑j−1Nexj(1) softmax( \vec{X} [1:N] ) = \frac {e^{x_i}} {\sum_{j-1}{N} e{x_j} }(1)
我们需要(1)先计算每个x的e^x并进行加总得到总和,(2)然后计算除法。
Safe softmax
但是由于e^x的特性,当x相对较大时,e^x就容易溢出,尤其是使用fl…
Softmax是Transformer模型架构中非常重要的一环。它所在的Attention模块虽然所需要的计算量不大,但也是不容忽视的一环。同时由于它本身的数学特性所造成的数据依赖,如果按照其原始方法来进行运算,会耗费大量的计算时间,因为它需要三次完整读取数据。
Online normalizer calculation for softmax 提出了online softmax,通过牺牲计算来节省数据读取次数,将三遍完整读取(3 passes)降低到两遍完整读取(2 passes)。
FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness 则应用了类似的思想,更进一步,利用NVIDIA GPU的本地存储,将读取次数减少为一遍。
那么就让我们详细了解一下其中奥秘吧。本文会出现数学公式,但不要慌张,仅仅是简单的数组知识而已。同时也会辅以可以运行的Python代码,以便理解。
Softmax的数学表达
首先,softmax作用在一维vector上,而非多维tensor。
softmax(X⃗[1:N])=exi∑j−1Nexj(1) softmax( \vec{X} [1:N] ) = \frac {e^{x_i}} {\sum_{j-1}{N} e{x_j} }(1)
我们需要(1)先计算每个x的e^x并进行加总得到总和,(2)然后计算除法。
Safe softmax
但是由于e^x的特性,当x相对较大时,e^x就容易溢出,尤其是使用float16甚至更低精度的浮点表达时。因此我们需要对X进行normalization(归一化),并称之为saft softmax。
softmax(X⃗[1:N])=exi−max(X⃗)∑j−1Nexj−max(X⃗)(2) softmax( \vec{X} [1:N] ) = \frac {e^{x_i - \max(\vec{X})}} {\sum_{j-1}{N} e{x_j - \max(\vec{X})} } (2)
因此,softmax的三遍读取为(1)统计最大值(2)加总(3)除法
如果用Python来表示的话:
vec_x = [random.random() for _ in range(N)]
# 1st pass
max_x = max(vec_x)
# 2nd pass
sum_ex = sum([math.exp(x - max_x) for x in vec_x])
# 3rd pass
softmax_x = [math.exp(x - max_x) / sum_ex for x in vec_x]
Online softmax
但是,我们其实可以将(1)和(2)合并,在寻找最大值的同时进行加总操作。
如果我们能够构造两个数组M和S,它们的单元数量和X一样:
mi=max(x⃗[1:i])si=∑j=1iexj−mi(3) m_i = \max(\vec{x}[1:i]) \newline s_i = \sum_{j=1}{i}{e{x_j - m_i}} (3)
即,数组M保存的是数组X到当前index为止的最大值,而数组S保存的是类似softmax的加总(除数)。数组S的所有数值都和softmax的第二步加总的出来的结果的不一样,除了最后一个。也就是说,如果我们一个个读取数组X,并计算数组M和数组S,在计算到最后一个的时候,数组S的d_N正好等于softmax的除数。这样我们只需要一遍读取就能完成步骤(1)和步骤(2)。我们选择符号m来表示max最大值;选择符号s来表示sum加总。
这种奇思妙想不得不令我们想起dynamic programming。
如果我们展开数组S的计算过程:
s1=1(4) s_1 = 1 (4)
si=si−1×emi−1−mi+exi−mi(5) s_i = s_{i-1} \times e^{m_{i-1} - m_i} + e^{x_i - m_i} (5)
sN=∑j=1Nexj−mN(6) s_N = \sum_{j=1}{N} e{x_j - m_N} (6)
我们可以证明公式6,即当i=N的时候,s_i即为softmax中加总之后的结果。
sN=sN−1×emN−1−mN+exN−mN=∑j=1N−1(exj−mN−1×emN−1−mN)+exN−mN=∑j=1N−1exj−mN+exN−mN=∑j=1Nexj−mN(7) s_N = s_{N-1} \times e^{m_{N-1} - m_N} + e^{x_N - m_N} \newline = \sum_{j=1}{N-1}{(e{x_j - m_{N-1}} \times e^{m_{N-1} - m_N})} + e^{x_N - m_N} \newline = \sum_{j=1}{N-1}{e{x_j - m_{N}} } + e^{x_N - m_N} \newline = \sum_{j=1}{N} e{x_j - m_N} (7)
如果用Python来表示的话:
vec_x = [random.random() for _ in range(N)]
s_prev = 1 # previous s_{i-1}
m_prev = vec_x[0] # previous m_{i-1}
for i in range(1, N): # starting from the 2nd element
x = vec_x[i] # read x
m = max(m_prev, x) # calculate m_i
s = s_prev * math.exp(m_prev - m) + math.exp(x - m) # calculate s_i
s_prev = s # save as s_{i-1}
Attention的数学表达
A=softmax(mask(Q⋅KTK))⋅V(8) \mathbb{A} = softmax(mask(\frac{\mathbb{Q} \cdot \mathbb{K}^T}{\sqrt{K}})) \cdot \mathbb{V} (8)
其中A, Q, K, V都是二维matrix,A, Q, K, V的shape分别依次是[N,K]
, [N,K]
, [M,K]
, [M,K]
。但对于Transformer架构而言,N == M
。mask()
函数旨在加载self-attention或者其他种类的masking,但对masking的探讨超越了本文范围,且因为食element-wise操作,所以不改变本文讨论的内容。
为什么要除以常数√K?** 同样是为了归一化,这样在进行 Q·Kᵀ 时(1)不会溢出(2)分布过于极端导致softmax溢出(3)在训练过程中造成不稳定。
简化1:1/√K scaling通常是在进行点乘运算前对Q或者K进行,因此在后面的讨论中不再赘述。
简化2:以后的讨论我们还会简化masking的步骤。
经过简化,同时寻找X和Y符号来表示中间步骤的变量,attention可以表达为:
softmax(Q⋅KT)⋅V=softmax(X)⋅V=Y⋅V=A(9) softmax(\mathbb{Q} \cdot \mathbb{K}^T) \cdot \mathbb{V} = softmax(\mathbb{X}) \cdot \mathbb{V} = \mathbb{Y} \cdot \mathbb{V} = \mathbb{A} (9)
选择符号X和x来表示softmax的自变量,与之前对于softmax的讨论统一;选择y来表示softmax的结果(可惜s已经被用了);选择A来表示attention的结果。
注意:同时我们需要了解,在LLM中,N和M代表的是 batch_size * num_head * seq_length,而K代表的是head_dim。因此N和M是可以运行时动态变量(runtime dynamic variable),而K是编译时静态变量(compile-time static variable)。这个理解对于进行性能优化非常重要。
应用online softmax
如果将online softmax应用在attention上,同时将有如下两遍循环操作:
循环(1)
完全套用online softmax的操作:
xn,m=X[n,m]=Q[n,:]⋅KT[:,m](10) x_{n,m} = \mathbb{X}[n,m] = \mathbb{Q}[n,:] \cdot \mathbb{K}^T[:,m] (10)
注意:此处的符号数量限制,为避免误导,将 n
, m
, k
分别定义为在N
, M
, K
三个dimension上的index。
可以看出,虽然matrix S是的shape为[N,K]
,但是softmax只作用于其inner dimension。也就是说其outer dimension是只是重复的循环操作。因此为了简便,我们只看内循环操作,即矩阵X的一行操作:
xm=Q[n,:]⋅KT[:,m](11) x_m = \mathbb{Q}[n,:] \cdot \mathbb{K}^T[:,m] (11)
其中的符号 n
指代外循环变量 for n in range(N)
,同时符号 m
指代内循环变量 for m in range(M)
gm=max(gm−1,xm)(12) g_m = max(g_{m-1}, x_m) (12)
sm=sm−1×egm−1−gm+exm−gm(13) s_{m} = s_{m-1} \times e^{g_{m-1} - g_m} + e^{x_m - g_m} (13)
此处由于符号m
被使用了,我们将max()
的结果从符号m
换成符号g
(greatest,因为符号x
, a
, i
都被使用了)
循环(2)
将类似思想套用在 Y•V 这个步骤上,则可以得到:
ym=exm−gMsM(14) y_m = \frac{e^{x_m - g_M}}{s_M} (14)
A0⃗=0Am⃗=Am−1⃗+ym×V[m,:](15) \vec{A_0} = 0 \newline \vec{A_m} = \vec{A_{m-1}} + y_m \times \mathbb{V}[m,:] (15)
如果我们调换循环K和循环M,结果一致,只是计算顺序变了。我们从一个一个计算A的element,变成了先计算A的一行的中间值(partial sum)然后加总。
matrix_a[:,:] = 0 # initialize matrix A to be zeros
for n in range(N):
for m in range(M):
for k in range(K):
matrix_a[n, k] += matrix_y[n, m] * matrix_v[m, k]
那么数学表示则变成:
A[n,:]=∑m=1M(Y[n,m]×V[m,:])(17) \mathbb{A}[n,:] = \sum_{m=1}^M( \mathbb{Y}[n,m] \times \mathbb{V}[m,:]) (17)
而(15)中的 y_m
则是当前第n行的Y[n,m]
,因此:
A[n,:]=AM⃗=∑m=1M(ym×V[m,:])=∑m=1M(exm−gMsM×Vm⃗)(18) \mathbb{A}[n,:] = \vec{A_M} = \sum_{m=1}^M(y_m \times \mathbb{V}[m,:]) = \sum_{m=1}M(\frac{e{x_m - g_M}}{s_M} \times \vec{V_m}) (18)
所以循环(2)变成了:
Am⃗=∑m=1M(exm−gMsM×Vm⃗)(19) \vec{A_m} = \sum_{m=1}M(\frac{e{x_m - g_M}}{s_M} \times \vec{V_m}) (19)
优化循环(2)
我们利用oneline softmax同样的思想,将计算 A_M
的过程转变成一个循环即A_m = func(A_{m-1})
,然后构造一个新的A'_m
来去掉数据依赖,因为我们只需要保证A'_M == A_M
即可,而不需要中间过程相等。
A′m⃗=∑i=1m(exi−gmsm×Vi⃗)(20) \vec{A’m} = \sum{i=1}m(\frac{e{x_i - g_m}}{s_m} \times \vec{V_i}) (20)
当 m=M
时,则有:
A′M⃗=∑i=1M(exi−gMsM×Vi⃗)=AM⃗(21) \vec{A’M} = \sum{i=1}M(\frac{e{x_i - g_M}}{s_M} \times \vec{V_i}) = \vec{A_M} (21)
如果比较(20)和(19),仅仅是将m
替换成了i
,将M
替换成了m
,但是本质原理是用循环计算的方式来避免写入内存(avoid materialize matrix)。
如果我们将(20)展开,则可以得到:
A′1⃗=ex1−g1s1×V1⃗Am′⃗=A′m−1⃗×sm−1×egm−1−gmsm+exm−gmsm×Vm⃗(22) \vec{A’1} = \frac{e^{x_1-g_1}}{s_1} \times \vec{V_1} \newline \vec{A’_m} = \vec{A’{m-1}} \times \frac{s_{m-1} \times e^{g_{m-1} - g_m}}{s_m} + \frac{e^{x_m-g_m}}{s_m} \times \vec{V_m} (22)+eX[n,m]−G[n,m](25)
- Attention A1[n,:]=eX[n,1]−G[n,1]S[n,1]×V[1,:]Am[n,:]=Am−1[n,:]×S[n,m−1]×eG[n,m−1]−G[n,m]S[n,m]+eX[n,m]−G[n,m]S[n,m]×V[m,:](26) \mathbb{A}1[n,:] = \frac{e^{\mathbb{X}[n,1] - \mathbb{G}[n,1]}}{\mathbb{S}[n,1]} \times \mathbb{V}[1,:] \newline \mathbb{A}_m[n,:] = \mathbb{A}{m-1}[n,:] \times \frac{\mathbb{S}[n,m-1] \times e^{\mathbb{G}[n,m-1] - \mathbb{G}[n,m]}}{\mathbb{S}[n,m]} + \frac{e^{\mathbb{X}[n,m] - \mathbb{G}[n,m]}}{\mathbb{S}[n,m]} \times \mathbb{V}[m,:] (26)
用Python表示:
matrix_q = [[random.random() for _ in range(K)] for _ in range(N)]
matrix_k = [[random.random() for _ in range(K)] for _ in range(M)]
matrix_v = [[random.random() for _ in range(K)] for _ in range(M)]
matrix_a = [[0.0 for _ in range(K)] for _ in range(M)]
for n in range(N):
for m in range(M):
x = 0.0
for k in range(K):
x += matrix_q[n,k] * matrix_k[m,k]
g = max(g_prev, x) if m > 0 else x
g_prev = g
e = math.exp(x - g) if m > 0 else 1.0
s = s_prev * math.exp(g_prev - g) + e if m > 0 else 1.0
s_prev = s
for k in range(K):
if m > 0:
matrix_a[n,k] = (matrix_a[n,k] * s_prev * math.exp(g_prev - g) + e * matrix_v[m,k] ) / s
else:
matrix_a[n,k] = e * matrix_v[0,k] / s