Skip to main content

投机解码简述

· 14 min read
ayanami

动笔的时候会有一种感觉,自己对这个方向了解的还是太少了... 所以大概不会讲得很学术,主打一个轻松愉快,让不了解的人也简单知道一下投机解码speculative decoding

投机的提出

当前,大型语言模型(LLM)在推理阶段普遍采用自回归解码策略,其核心特性是逐步串行生成 token,每一步都依赖前一步的输出。这一计算模式导致推理过程在系统层面面临严重的内存带宽瓶颈:每一步前向计算都需要将完整的模型参数从高带宽内存(HBM)加载到加速器缓存,但仅生成一个 token。由于每次只生成一个 token,导致大量的计算资源被闲置,无法充分发挥加速器的算力潜力,最终造成整体推理效率低下。 为解决这一问题,一种加速大型语言模型推理的思路是提高解码过程的算术强度(即总浮点运算次数 FLOPs 与数据传输量之间的比值),同时减少解码步骤。基于这一理念,研究者们提出了推测解码/投机解码(Speculative Decoding) 技术。Speculative Decoding 的核心思路如下图所示,首先以低成本的方式(一般来说是用小模型)快速生成多个候选 token,然后通过一次并行验证阶段快速验证多个 token,进而减少大模型的 decode 次数,从而达到加速的目的。

上面讲得比较学术,我尝试给一个自己的通俗些的解释:

llm的推理分成两个阶段,prefill 和 decode,prefill处理两个事情,计算输入(prompt)部分的attention和kvcache,输出第一个Token;而decode处理自回归的生成token的后续部分,即输出

为什么这样分呢?实际上是因为他们的计算负载不同,而更本质的原因是现有LLM的主流架构是CasualLM,即三角因果掩码,计算当前token时是无法看到未来token的。这带来了一个结果是,对于输入部分,我们可以并行的计算所有的输入token,但对于输出阶段,由于下一个token依赖于前一个token,所以我们只能串行的计算。

在LLM推理加速方面针对这两种计算的统一和调度有很多很多的研究,例如chunked prefill到新的pd分离、af/am分离等,但直接对这一传统范式发起挑战的大致就是几种:一种尝试换其他架构的模型,比如stable diffusion的dLLM,一种尝试采用多个输出头在训练时就学会“一次预测几个词”(deepseek MTP),剩下的就是投机解码

投机解码的核心思想就是,既然我们的decode阶段是内存密集型的(后面的token依赖于前面的token导致计算不能打满),那我可以把多余的算力利用起来,我用某种机制一次性猜测多个token,然后LLM从生成变为验证,就完成了并行化

Q1: 为什么说生成变为验证是并行化? A1: 因为验证这里有一个关键的地方是,在验证后一个token的时候,直接假设前面猜测的token都是对的,以猜测“千早爱音唐得没边”为例子,模型并不是串行的验证“千”对不对,“早”对不对,而是并行地验证这8个字,在验证“唐”的时候直接假设前面的输出“千早爱音”是对的。带来的效果是,如果“唐”被验证是错的,后续的所有token“唐的没边”都会被舍弃。

Q2:如何验证呢? A2:LLM生成token的最后一步是概率采样,如果猜测的概率是p1, LLM正常推理输出是p2, 如果p1 < p2(这里已经进行了猜测的采样),则选择猜测是对的;如果p1 > p2,则对的概率是 P(p2|p1)=p2/p1, 这样从直觉上就可以理解如何“验证”了,具体输出期望的一致性证明可以参考相关论文

Q3:投机在精度上是不是无损的? A: 看你如何定义。投机的核心是验证中的拒绝采样,学过rl的同学应该对这个概念很熟悉,拒绝采样带来的后果是,输出的期望是一样的,方差会变大。所以llm的期望是一样的,输出方差会变大,可能类似于调大温度。当然投机概率乘的多了还有一些数值精度上的问题。

Q4: 那并行的其他head空算不是更浪费算力和空间吗?一次all-layer的forward时间应该还挺长的 A:传统投机是不接受,但其他head的结果可以加入候选池,就是候选池改进的方法, 实际上不一定会这样验吧。medusa的tree attention就是,我不是序列地验head1,head2,而是尝试在树上直接找到综合接受期望最长的序列,也就是head1并非贪婪采样,不过现在推理引擎不是完全支持这个,据我所知sglang默认是有的,但vllm确实是这种序列的验法。浪费计算你说得对,所以投机work的前提是mem bound,但接受率越高浪费的不就越少吗,本质上还是接受率不够

如何生成猜测

主流是这几种方法:

  • 启发式,如n-gram,在很多任务中,输出会抄写prompt种已经给出的上文,比如总结任务,所以直接在给出的prompt中统计n-gram词频,取以现在输出末尾token开头的最佳选项作为猜测,优势是引入非常简单,劣势是吃任务类型(工作负载),对于很多任务没太大效果
  • 小模型,例如用qwen3-0.6b的输出作为qwen3-8b的输出的猜测
  • 自猜测,如medusa和eagle这种,给模型训练一个额外的附加结构,让其具备类似MTP的推理时猜多个token的能力。这个附加结构早期是放在模型的最后一层,即多个输出头,后来eagle提出最后一层(logits)不如倒数第二层(特征层),并且改造输出头的输入,再加上先前的token(即输入为,之前所有token的倒数第二层+当前token的倒数第二层+之前所有token的实际采样结果),效果非常好,能够达到7~8倍加速的疯狂数字

没有免费的午餐

那么古尔丹,代价是什么呢?

注意在开始我们就讲了,投机是一个利用空闲计算的方法,但实际上,利用空闲计算的方法不止投机一个,例如你有100张卡,你完全可以把不同的计算任务调度到不同的卡上,尽可能打满所有卡的计算

实际上这也是投机的痛点,或者说到底什么时候,投机才是有用的。

magicdec一文中已经指出,投机的适用场景常见于两种工作负载模式:

  1. 端侧推理,你只有一张卡,只能加载一个模型,这个模型还把你的显存占满了,那显然你无法通过加大batch来缓解memory bound,这时候投机是真有收益,eagle论文里面的7x加速也是batch=1的时候跑出来的
  2. 长上下文,你有大集群可以做不同任务的调度,但你的上下文实在太长,kvcache大小是随上下文线性增长的,上下文过长之后,你的大集群也硬生生被整成memory bound了(热知识:显存不是80G都是平等的,显然显卡也有SRAM/DRAM这样的高速低速区,更不提上下文太长之后有些kvcache直接就被offload到内存了)

端侧推理很好理解,那长上下文具体是多长呢?

magicdec给了一个指标是:对于接受率为0.8的投机,在实际的大batch size下(256),大概在3.2k token上下文开始投机能够取得收益(对于GQA这种模型而言,由于其在mem上较优,sd能加速的临界prefill长度会更高,对于非GQA模型是大概1.3k)

(关于端侧推理,我在我自己的一个项目上也试验过投机解码,平均输入长度大概是2k tokens,n-grams投机大概能加速30%,eagle由于我的训练数据等问题,也差不多)

当然以上只是一个最最简单的认识,实际上投机的很多算法相当复杂:

  • 能否快速剪枝某些置信度低的序列,不然预测k个token,可能的组合数指数增长吃不消?——medusa等

  • 剪枝之后如何高效计算?—— tree attention

  • 投机算法中,当出现拒绝验证时,后续的猜测token全部被丢弃,这些猜测token有没有可能被重用?——一系列维护候选池的方法

  • 小模型猜大模型很美好,但不是所有大模型都有对应的小模型,能否支持异构(大小模型词表不同)?—— huggingface uag tli等方法

  • 投机的超参数(例如一次猜几个token等)难以确认,能否用RL等方法优化超参选择? —— banditspec等

  • 能否通过LLM的置信度或者外部的一些规则等来动态开关投机,避免额外浪费的计算量?

  • 能否把投机也用到prefill过程中(选取kv)?—— specprefill

  • 在多模态场景中,如何使用投机,如果能的话,又该怎么做?——vllm roadmap(雾)

  • eagle还是太吃训练了,training方法如何做数据集选择?

  • 除了从prompt中选取候选,能否从参考资料等其他文本中选取猜测?—— snowflakes suffix decoding

  • 投机如何和现有大规模并行融合?(在vllm的投机集成中,投机的模型的并行都是简单的1,即投机模型不做tp来降低实现复杂度)—— 最新的 字节swiftspec

  • ...

展望?

最近投机是真的很火,aaai26中好像就有30篇投机的文章

如果从一个应用者的视角来说的话,现有推理框架(比如vllm&sglang)基本都有投机的集成了,只是集成多少的问题

而训练投机的话,sglang的specForge项目把它变得相当傻瓜化了,现在正在快速发展中

Loading Comments...