介绍:

在大模型计算工程中,Attention计算占据了模型推理阶段的绝大部分资源,本文依据https://zhuanlan.zhihu.com/p/666452391 这篇文章对大模型推理计算加速过程进行总结

计算瓶颈:

Attention的计算复杂度为 $O(N^2)$ 且该计算 无法并发

  1. 大模型需要处理很长的输入输出,在decoder-only的架构中,当前token需要和前面所有的token进行attention计算,随着seq_len和n的增加,模型计算的矩阵尺寸会越来越大。
  2. 同时在生成阶段,每个token的生成都依赖前面所有的计算,只能一个一个token生成,无法并发计算

优化目标

  1. 更低的延迟:延迟是指单个请求返回的时间
  2. 更高的吞吐量:吞吐量是一定时间内处理请求的总量

计算加速方式

因为计算本质上是CPU/GPU执行一系列的指令,在不改变模型结构的前提下,能进行的优化加速主要有以下三种:

  1. 减少非必要/重复计算:减少需要执行的指令数量
  2. 充分利用硬件并发:让单条指令可以一次处理多条数据,或者利用CPU和GPU的多核心机制同时处理多条指令
  3. 加速内存IO速度:利用缓存局部性加速内存读取指令的执行速度,或者减少不必要的内存读写操作

前两种为计算侧优化,后一种为内存IO优化

计算侧优化

KVCache

因为在deocding过程中,每一步都需要计算当前token和之前所有已生成token的attention,因此需要计算所有token的k和v向量,此时前面token的kv值在每一步decoding的时候都被重复计算了。这时候业界一般采用一种空间换时间的方法,将这些kv值存下来,作为两个[seq_len - 1, inner_dim]的向量,在每一轮计算中只需要计算当前token的kv值即可

KVCache是最简单直接有效的优化方法,因此一般模型的默认实现都会自带KVCache,并不需要额外实现,只需指定use_cache=True即可

KVCache 配置开启后,推理过程的两个阶段:

  1. 预填充阶段:发生在计算第一个输出token过程时,此时Cache是空的,计算时需要为每个tranformer layer计算并保存key cache和value cache,在输出token时cache完成填充。此时的FLOPs与KV Cache关闭时一致,存在大量的通用矩阵乘法(Gemm操作),推理速度很慢
  2. 使用KV Cache阶段:发生在计算第二个输出token至最后一个token过程中,此时Cache是有值的,每轮推理之需要读取Cache,同时将当前轮计算出的新的key、value追加写入到Cache。此时的FLOPs降低,Gemm操作变为矩阵向量乘法(Gemv操作),推理速度相对第一阶段变快,此时属于memory-bound类型计算。

Kernel优化和算子融合

对于Nvidia gpu环境而言,我们通过CUDA kernel来执行大部分运算操作,这里的操作包括两种:

  1. 计算受限操作(compute-bound):大的矩阵乘法,多channel的卷积
  2. 存储受限操作(memory-bound):

算子融合是一种常用于深度学习和其他计算密集型任务中的优化技术。其基本操作就是将多个连续的操作或算子合并成一个单一的算子,以减少计算和内存开销。使用算子融合可以提高执行效率,同时减少数据传输和临时存储,通过“少动数据多计算”的方式,更多的利用资源。实际上要做算子融合,还需要考虑具体的硬件架构、编程框架、编译器和具体任务。一些在深度学习中可以考虑融合的算子包括:

  1. 卷积和激活函数:将卷积层和激活函数融合起来,避免额外的内存访问
output = relu(conv2d(input, weights))
  1. 卷积核批量归一化:卷积后的batch normalization层可以和卷积融合从而减少对内存的访问
output = batch_norm(conv2d(input, weights))
  1. 全连接层和激活函数:全连接层后的激活函数可以和全连接层融合从而减少对内存的访问
output = batch_norm(conv2d(input, weights))

对于算子融合而言,需要考虑的一个关键因素是数值的稳定性。在某些情况下,算子融合可能会导致数值精度的损失。此外,还要考虑硬件和软件的兼容性,因为并非所有的算子组合都能从硬件加速中受益。

分布式推理

在大模型的训练和推理过程中,有以下几种主流的分布式并行方式:

内存IO优化

在做GPU计算性能优化时,除了考虑单纯计算的速度,也要考虑内存访问的速度。

  1. GPU和CPU缓存相同,也存在分级缓存(SRAM和HBM),级数越低的缓存大小越小,访问速度越快,因此我们在优化模型推理时也要考虑内存访问的局部性。
  2. KVCache随着batch size和seq_len的增加而扩张,在推理过程中会占据超过30%的内存,可能会出现因为内存不够用而限制最大并发度的问题。即:在并发度较高或者输入输出较大时,内存访问反而可能成为计算的瓶颈,而非CPU/GPU的计算量。

实际的GPU存储可以参考下面的GPU内存硬件分类(根据是否在芯片上可以分为片上内存和片下内存)

GPU的分级缓存
推理时的显存占用

Flash Attention

Attention计算的步骤:

$$S = Q K^T \in \mathbb{R^{N \times N}} $$

$$P = softmax(S) \in \mathbb{R^{N \times N}}$$

$$O = PV \in \mathbb{R^{N \times d}}$$

具体而言,这样的attention计算会分为三步:

在实际计算当中,QKV都是很大的矩阵,直接进行矩阵计算非常缓存不友好。因此通常考虑将矩阵进行分块乘法的操作,即每次只计算一个小的可以放进SRAM大小的block。Flash attention采用了这种分片(tiling)手段将QKV分割,具体单分片大小参考SRAM size,然后将切割完后的小矩阵都加载到SRAM中进行计算。

与此同时,从attention计算的公式可以看到,attention计算当中存在大量的HBM读写操作,这个过程是很耗时的。Flash attention的思想则是减少HBM的读写,将整个attention的计算过程在一次读写中完成。但是attention过程中存在softmax这个规约操作,需要计算矩阵中每一行元素的最大值,所以必须要等待所有分块计算完才能进入下一步。那么Flash attention是怎么做的呢?

Flash attention的效果

Image

  1. Flash attention的实现思想
    Flash attention采用了一种类似动态规划的技巧,实现了online softmax,可以在一个循环中计算出一个分块的最终结果。在这个过程中,Flash attention算法需要不断地对中间结果进行重新计算,但是得益于整个过程不需要读HBM,虽然增加了计算量但是整体的速度仍然获得提高(计算上冗余了,但是对HBM的IO访问大幅下降了)。
  2. Flash attention的具体实现方式

Image

(可以观看B站视频https://www.bilibili.com/video/BV1UT421k7rA/?spm_id_from=333.337.search-card.all.click&vd_source=40b4bab5322c77df2b3a60d699755a2e)

Flash Decoding

Flash attention优化的主要是大矩阵乘法,矩阵越大优化效果越好,但是在在线推理场景时,输入的batch size为1,此时的Q矩阵实际上是一个向量而非矩阵。这意味着如果batch size小于GPU上的SM数量(例如A100上有108个SMs),那么整个计算过程只使用了GPU的一小部分。此时Flash attention无法利用GPU的并发能力,因此研发了Flash decoding方法,以seq_len为维度进行并发,将K和V分成多个部分,并发地与Q相乘进行加速。

Image

具体而言,Flash decoding方法的流程为:

  1. 将keys和values分成较小的block
  2. 使用FlashAttention并行计算query与每个block的注意力(这是和FlashAttention最大的区别)。对于每个block的每行(每一行是一个特征维度),Flash Decoding会额外记录attention values的log-sum-exp(标量值,用于第3步进行rescale)
  3. 对所有output blocks进行reduction得到最终的output,需要用log-sum-exp值来重新调整每个块的贡献

因此,不同于Flash attention适合离线训练和批量推理,Flash decoding在单次推理且上下文长度较长时效果更好。可以通过FlashAttention库或者xFormers的attention kernel来使用Flash Decoding

Continuous Batching

在我们实际进行批量推理时,我们通常使用固定的batch size,将多个请求作为一个batch进行推理,此时我们在设计KVCache时,就会分配两个shape为[batch_size, max_seq_len, inner_dim]的向量,从而确保所有请求在Cache中都放得下。但是这个分配策略会带来两个问题:

  1. 不是所有的请求都可以达到max_seq_len的长度,因此KVCache中很多的内存都被浪费掉了。
  2. 即使一些请求输出长度很短,这些请求也需要等待输出较长的请求输出结束后才能一起返回。

显然,这种固定的Batching(被称为Static Batching)的策略存在问题。因此,Orca提出了一种持续Batching(Continuous Batching)的策略,这个策略允许输出较短的请求提前结束,并由新请求来占用已结束请求的KVCache空间。

Image

实际情况比上图这个简化的模型要复杂一些:由于预填充阶段需要计算,并且计算模式与后续生成时不同,因此无法轻松地将其与token生成进行批处理。连续批处理框架目前通过超参数来管理这一点:waiting_served_ratio,即等待预填充的请求与等待序列结束token的请求的比例.

在批量推理场景中,Continuous Batching可以将模型的吞吐量提升两到三倍。当前主流的推理框架比如Huggingface TGI, Ray serve, vllm, TensorRT-LLM等都支持Continuous Batching策略。

Paged Attention

vLLM团队分析了推理时的内存浪费问题,认为推理中存在三种内存浪费

vLLM团队认为,Continuous Batching可以部分解决Internal Fragmentation的问题,但是Reservation和External Fragmentation的浪费仍然存在。由此他们提出了Paged Attention,Paged Attention借鉴了操作系统中通过Page管理虚拟内存的思想:将KVCache分割为固定大小的Block,这些Block不需要存储在连续内存中,而是由一个统一的内存分配器管理。计算请求按需申请内存,不需要预先留好max_seq_len大小的内存,解决了reservation的浪费,请求结束后释放掉自己的blocks,解决了internal Fragmentation,而系统只需要分配小的block,解决了External Fragmentation的问题。

具体来说,不同于传统的注意力算法,Paged Attention允许在非连续的内存空间中存储连续的key和value。Paged Attention将每个序列的KV cache划分为块,每个块包含固定数量token的键和值。在注意力计算期间,Paged Attention内核可以有效地识别和获取这些块。

因为块在内存中不需要连续,因而可以像操作系统中的虚拟内存一样:

同时,将块视为页面,将token视为字节。请求序列的连续逻辑块通过块表映射到非连续物理块中。物理块在生成新token时按需分配。

Image

因为这种分配方式,内存浪费只会发生在序列的最后一个块中,在实际实践中,这种方式允许系统将更多序列进行批处理,提高GPU利用率,显著提升吞吐量。

除此之外,Paged Attention还采取了高效的内存共享机制。在并行采样中,多个输出序列是由同一个prompt生成的,此时prompt的计算和内存是可以在输出序列中共享的。这种内存共享是由其块表格实现的,与进程共享物理页面的方式类似,即通过将它们的逻辑块映射到同一个物理块的方式来共享块。为了确保安全共享,Paged Attention会对物理块的引用计数进行跟踪,并实现写时复制(Copy-on-write)机制。

Image

用户可以通过官方的vLLM库使用Paged Attention,英伟达的TensorRT-LLM库和微软的Deepspeed-MII库也对部分模型提供了支持。

SplitFuse

微软的Deepspeed团队认为:如果我们在F个前向推理中处理P个token,最高效的分配策略是将它们均分,即每个前向推理处理P/F个token。但是在实际的模型推理过程中,我们需要先在prefill阶段一次性处理整个prompt,然后在generation阶段一个个生成token,这样的方式是不符合最优策略的,因此deepspeed提出了SplitFuse:

总结

本篇笔记针对大模型推理加速的方法进行了一个基本的总结,在实际应用中,大多数的推理框架都是混合使用多个加速技术的,例如vllm同时使用了Continuous Batching和Paged Attention技术,而LMDeploy同时支持以上所有的加速技术。

参考文献&网页

  1. B站视频 Flash Attention 为什么那么快?原理讲解:https://www.bilibili.com/video/BV1UT421k7rA/?spm_id_from=333.337.search-card.all.click&vd_source=40b4bab5322c77df2b3a60d699755a2e
  2. 知乎文章 大模型推理性能优化之KV Cache解读:https://zhuanlan.zhihu.com/p/630832593
  3. Flash attention论文:https://arxiv.org/pdf/2205.14135
  4. Flash attention解读:https://courses.cs.washington.edu/courses/cse599m/23sp/notes/flashattn.pdf
  5. AnyScale博客:https://www.anyscale.com/blog/continuous-batching-llm-inference