Skip to main content

从RLHF到DPO

RM过程

在先前的RLHF之中, 我们有一个表示偏好的Loss

L=E(x,yw,yl)Dlog(σ(rϕ(x,yw))σ(rϕ(x,yl)))L=E_{(x,y_w,y_l)\sim D}log(\sigma(r_{\phi}(x,y_w))-\sigma(r_{\phi}(x,y_l)))

这个loss我们从直观意义上解释了一下这样设计的效果, 但要从里面发掘更多东西, 乃至进一步推广到DPO, 我们还要走一遍这玩意的“推导”过程

如何评价人类对不同答案的“偏好”? 我们不妨先考虑一个简单的, 两个答案之中选择一个更喜欢的的情况(多个答案时, 可以加上排名的权重因子1Ck2\frac{1}{C_k^2}, 也可以拆成多个两个答案比较的pair)

也就是说我们两个答案, 一个是更喜欢的,叫win; 一个是不那么喜欢的, 叫lose. 我们希望模型能够把win的得分高于lose, 也就是说rϕ(x,yw)>rϕ(x,yl)r_{\phi}(x,y_w)>r_{\phi}(x,y_l)

对于人类而言, 选择win的概率是1, 即P(rϕ(x,yw)>rϕ(x,yl)x)=1P(r_{\phi}(x,y_w)>r_{\phi}(x,y_l)|x)=1,而对于模型而言,我们希望它尽可能去拟合人类的选择, 也就是模型的分布去接近人类的分布

从两个分布得到loss, 一个常见的想法就是交叉熵H(p,q)=j=1np(xj)logq(xj)=EplogqH(p,q)=-\sum_{j=1}^np(x_j)logq(x_j)=-E_{p}logq, 其中p是真实分布, q是模型分布

对于模型而言, 常见的对事物的比较关系的建模是BT model(Bradley–Terry model), 也就是说rϕ(x,yw)>rϕ(x,yl)r_{\phi}(x,y_w)>r_{\phi}(x,y_l)的概率是er(x,yw)er(x,yw)+er(x,yl)=σ(rϕ(x,yw)rϕ(x,yl))\frac{e^{r(x,y_w)}}{e^{r(x,y_w)} + e^{r(x,y_l)}} =\sigma(r_{\phi}(x,y_w)-r_{\phi}(x,y_l)) (上下同除一个分子)

那模型分布和人类分布的二分布交叉熵BCE= Ep(x,yw,yl)logσ(rϕ(x,yw)rϕ(x,yl))=E(x,yw,yl)Dlogσ(rϕ(x,yw)rϕ(x,yl))-E_{p(x,y_w,y_l)}log\sigma(r_{\phi}(x,y_w)-r_{\phi}(x,y_l))=-E_{(x,y_w,y_l)\sim D}log\sigma(r_{\phi}(x,y_w)-r_{\phi}(x,y_l)) (1*log + 0*... )

所以我们偏好的Loss就是这个BCE

RL过程

有了reward model, 在RL过程之中我们的训练目标是这样的(ref是原始模型)

maxπθExD,yπθ(yx)[rϕ(x,y)]βDKL[πθ(yx)πref(yx)]max_{\pi_{\theta}}E_{x\sim D,y\sim \pi_{\theta}(y|x)}[r_{\phi}(x,y)]-\beta D_{KL}[\pi_{\theta}(y|x)||\pi_{ref}(y|x)]

πθ\pi_{\theta}就是我们的LLM

DPO的作者发现了什么呢, 这个式子是有显式解的!并且这个显式解表明了我们可以把reward model合进RL过程, 我们的模型自己就是一个隐藏的reward model!

具体的推导过程是这样的

展开KL散度, 我们有

DKL[πθ(yx)πref(yx)]=πθ(yx)logπref(yx)πθ(yx)=Eyπθ(yx)[logπθ(yx)πref(yx)]D_{KL}[\pi_{\theta}(y|x)||\pi_{ref}(y|x)]=-\pi_{\theta}(y|x)log\frac{\pi_{ref}(y|x)}{\pi_{\theta}(y|x)}=E_{y\sim \pi_{\theta}(y|x)}[log\frac{\pi_{\theta}(y|x)}{\pi_{ref}(y|x)}]

并入前一项, 我们有

maxπθExD,yπθ(yx)[rϕ(x,y)βlogπθ(yx)πref(yx)]max_{\pi_{\theta}}E_{x\sim D,y\sim \pi_{\theta}(y|x)}[r_{\phi}(x,y)-\beta log\frac{\pi_{\theta}(y|x)}{\pi_{ref}(y|x)}]

提出一个β-\beta

我们有minπθExD,yπθ(yx)[1βrϕ(x,y)+logπθ(yx)πref(yx)]min_{\pi_{\theta}}E_{x\sim D,y\sim \pi_{\theta}(y|x)}[-\frac{1}{\beta}r_{\phi}(x,y)+log\frac{\pi_{\theta}(y|x)}{\pi_{ref}(y|x)}]

括号里面的是变形为logπθ(yx)πref(yx)e1βrϕ(x,y)log\frac{\pi_{\theta}(y|x)}{\pi_{ref}(y|x)e^{\frac{1}{\beta}r_{\phi}(x,y)}}

已经趋近于两个策略的比例了, 我们把分母归一化

Z(x)=yπref(yx)e1βrϕ(x,y)Z(x)=\sum_{y}\pi_{ref}(y|x)e^{\frac{1}{\beta}r_{\phi}(x,y)}

logπθ(yx)πref(yx)e1βrϕ(x,y)=logπθ(yx)1Z(x)πref(yx)e1βrϕ(x,y)logZ(x)log\frac{\pi_{\theta}(y|x)}{\pi_{ref}(y|x)e^{\frac{1}{\beta}r_{\phi}(x,y)}}=log\frac{\pi_{\theta}(y|x)}{\frac{1}{Z(x)}\pi_{ref}(y|x)e^{\frac{1}{\beta}r_{\phi}(x,y)}}-logZ(x)

这个东西就是最优策略π=1Z(x)πref(yx)e1βrϕ(x,y)\pi^{*}=\frac{1}{Z(x)}\pi_{ref}(y|x)e^{\frac{1}{\beta}r_{\phi}(x,y)}, 后面的logZ(x)不影响梯度对于θ\theta的导数

所以上式等效于minπθExD,yπθ(yx)logπθπ=minπθExDDKL(πθ(yx)π(yx))min_{\pi_{\theta}}E_{x\sim D,y\sim \pi_{\theta}(y|x)}log\frac{\pi_{\theta}}{\pi^{*}}=min_{\pi_{\theta}}E_{x\sim D}D_{KL}(\pi_{\theta}(y|x)|\pi^{*}(y|x))

KL散度的最小值当然在两个分布相等的时候取到, 所以我们的最优策略就是π\pi^{*}

π\pi^{*}式子做一点变形, 得到奖励的关系式rϕ(x,y)=βlogπ(yx)πref(yx)+βlogZ(x)r_{\phi}(x,y)=\beta log\frac{\pi^{*}(y|x)}{\pi_{ref}(y|x)}+\beta logZ(x)

也就是说, 对于不同的奖励rr,其在RL过程之中对应的最优策略π\pi^{*}如上式,π=f(r)\pi^{*}=f(r)

从另一个角度来说, 我们就可以用π\pi^{*}去替换偏好之中的r, 然后去优化这个π\pi^{*}, 就等效于在maxπθExD,yπθ(yx)[rϕ(x,y)]βDKL[πθ(yx)πref(yx)]max_{\pi_{\theta}}E_{x\sim D,y\sim \pi_{\theta}(y|x)}[r_{\phi}(x,y)]-\beta D_{KL}[\pi_{\theta}(y|x)||\pi_{ref}(y|x)]的约束下, 去优化π\pi

我们不知道Z(x)Z(x),但没有关系, 我们上面计算偏好的损失 L=E(x,yw,yl)Dlog(σ(rϕ(x,yw))σ(rϕ(x,yl)))L=E_{(x,y_w,y_l)\sim D}log(\sigma(r_{\phi}(x,y_w))-\sigma(r_{\phi}(x,y_l)))

两者之差消去Z(x)Z(x)项: Lpref=E(x,yw,yl)Dlog(σ(βlogπ(ywx)πref(ywx)βlogπ(ylx)πref(ylx)))L_{pref}=E_{(x,y_w,y_l)\sim D}log(\sigma(\beta log\frac{\pi^{*}(y_w|x)}{\pi_{ref}(y_w|x)}-\beta log\frac{\pi^{*}(y_l|x)}{\pi_{ref}(y_l|x)})), 优化这个π\pi^{*}

诶,我们就发现了,RL的目标自然地包括在了reward model里面, 两者统一了, 这也就是“模型自己就是一个隐藏的reward model”的意思

LDPO(πθ,πref)=E(x,yw,yl)Dlog(σ(βlogπθ(ywx)πref(ywx)βlogπθ(ylx)πref(ylx)))L_{DPO}(\pi_{\theta},\pi_{ref})=E_{(x,y_w,y_l)\sim D}log(\sigma(\beta log\frac{\pi_{\theta}(y_w|x)}{\pi_{ref}(y_w|x)}-\beta log\frac{\pi_{\theta}(y_l|x)}{\pi_{ref}(y_l|x)})),对这个θ\theta求导求出最优的πθ\pi_{\theta}就可以得到在RL约束下的pref Loss最小值

DPO 通过以上的公式转换把 RLHF 无损地转化为了 SFT,在训练的时候不再需要同时跑 4 个模型(reward model, ref model, critic, actor),而是只用跑 actor 和 ref 2 个模型,甚至由于不再在线采数据,ref model 的输出可以预先存下来,训练的时候重复使用

一个参考代码

作者:技术微佬
链接:https://zhuanlan.zhihu.com/p/714131454
来源:知乎
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。

import torch
import torch.nn.functional as F
from transformers import LlamaForCausalLM, LlamaConfig
from copy import deepcopy

torch.manual_seed(0)
if __name__ == "__main__":
# 超参数
beta = 0.1
# 加载模型
policy_model = LlamaForCausalLM(config=LlamaConfig(vocab_size=1000, num_hidden_layers=1, hidden_size=128))
reference_model = deepcopy(policy_model)

# data
prompt_ids = [1, 2, 3, 4, 5, 6]
good_response_ids = [7, 8, 9, 10]
# 对loss稍加修改可以应对一个good和多个bad的情况
bad_response_ids_list = [[1, 2, 3, 0], [4, 5, 6, 0]]

# 转换成模型输入
input_ids = torch.LongTensor(
[prompt_ids + good_response_ids, *[prompt_ids + bad_response_ids for bad_response_ids in bad_response_ids_list]]
)
# labels 提前做个shift
labels = torch.LongTensor(
[
[-100] * len(prompt_ids) + good_response_ids,
*[[-100] * len(prompt_ids) + bad_response_ids for bad_response_ids in bad_response_ids_list]
]
)[:, 1:]
loss_mask = (labels != -100)
labels[labels == -100] = 0
# 计算 policy model的log prob
logits = policy_model(input_ids)["logits"][:, :-1, :]
per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
all_logps = (per_token_logps * loss_mask).sum(-1)
# 暂时写死第一个是good response的概率
policy_good_logps, policy_bad_logps = all_logps[:1], all_logps[1:]

# 计算 reference model的log prob
with torch.no_grad():
logits = reference_model(input_ids)["logits"][:, :-1, :]
per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
all_logps = (per_token_logps * loss_mask).sum(-1)
# 暂时写死第一个是good response的概率
reference_good_logps, reference_bad_logps = all_logps[:1], all_logps[1:]

# 计算loss,会自动进行广播
logits = (policy_good_logps - reference_good_logps) - (policy_bad_logps - reference_bad_logps)
loss = -F.logsigmoid(beta * logits).mean()
print(loss)

理论上看似非常美好无懈可击,但为什么没有全部转投DPO而是还是在用PPO呢?

我觉得知乎上一个答案说得很有道理, 关键在于DPO对于训练数据之外的数据(分布外数据),其泛化能力是要比PPO弱的,训练数据不够和训练数据分布有偏,ref LLM本身的偏移和未显式建模等都会影响实际有限样本下的收敛性。偏好数据不涵盖πref\pi_{ref} LLM的(训练时)输入分布, 而模型结构导致在处理这样的分布外数据的时候鲁棒性比PPO弱

DPO vs PPO:深度解读谁是LLM Alignment的未来【不定期更新】 - Whisper的文章 - 知乎 https://zhuanlan.zhihu.com/p/11913305485