什么是Online Softmax and Flash Attention?
dev.to·1d·
Discuss: DEV

Softmax是Transformer模型架构中非常重要的一环。它所在的Attention模块虽然所需要的计算量不大,但也是不容忽视的一环。同时由于它本身的数学特性所造成的数据依赖,如果按照其原始方法来进行运算,会耗费大量的计算时间,因为它需要三次完整读取数据。

Online normalizer calculation for softmax 提出了online softmax,通过牺牲计算来节省数据读取次数,将三遍完整读取(3 passes)降低到两遍完整读取(2 passes)。

FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness 则应用了类似的思想,更进一步,利用NVIDIA GPU的本地存储,将读取次数减少为一遍。

那么就让我们详细了解一下其中奥秘吧。本文会出现数学公式,但不要慌张,仅仅是简单的数组知识而已。同时也会辅以可以运行的Python代码,以便理解。

Softmax的数学表达

首先,softmax作用在一维vector上,而非多维tensor。

softmax(X⃗[1:N])=exi∑j−1Nexj(1) softmax( \vec{X} [1:N] ) = \frac {e^{x_i}} {\sum_{j-1}{N} e{x_j} }(1)

我们需要(1)先计算每个xe^x并进行加总得到总和,(2)然后计算除法。

Safe softmax

但是由于e^x的特性,当x相对较大时,e^x就容易溢出,尤其是使用fl…

Similar Posts

Loading similar posts...