PPO与GRPO中的KL散度近似计算 - 知乎
标准 [[KL Divergence]] 计算公式 KL(q∥p)=∑xq(x)logp(x)q(x)=Ex∼q[logp(x)q(x)]
- 在实际计算中,直接计算 KL 散度可能非常困难,主要原因如下:#card
- 需要对所有 x 进行求和或积分,计算成本高。
- 计算过程中可能涉及大规模概率分布,导致内存消耗过大。
[[K1 估计器]] 计算公式 → k1=logp(x)q(x)=−logr
- 为什么是 KL 的无偏估计 #card
- E[k1]=Ex∼q[logp(x)q(x)]=KL(q∥p)
- K1 估计器的方差较大,因为 #card
- 若 q(x)<p(x),k1 可能为负值,导致估计值波动大。
- 若 p(x) 和 q(x) 差异大,k1 可能出现较大的数值变化。
[[K2 估计器]] 公式 → k2=21(logq(x)p(x))2=21(logr)2
- K2 估计器是 有偏估计,但其方差较低,且 偏差在实证中较小,因为:#card
- K2 确保所有样本的值都是正数,使其在计算中更加稳定。
[[K3 估计器]] 公式 → k3=p(x)q(x)−1−logp(x)q(x)=(r−1)−logr=(r−1)+k1
- K3 估计器也是 KL(q∥p) 的无偏估计。证明如下:#card
- 已知:E[k1]=KL(q∥p)
- 我们只需证明:E[r−1]=0
- 即:E[r]=1
- 展开计算:E[r]=Ex∼q[q(x)p(x)]=∫q(x)⋅q(x)p(x)dx=∫p(x)dx=1
- K3 估计器恒大于等于 0。证明如下 #card
- 设函数:f(x)=x−1−logx,(x>0)
- 求导数:f′(x)=1−x1
- 当 0<x<1 时,f′(x)<0 ;当 x>1 时,f′(x)>0 。即,f(x) 在 x=1 处取最小值,且:f(1)=1−1−log1=0
- 所以,对于任意 x>0 ,均有:f(x)≥0
PPO 算法中的 KL 近似计算 #card

在 GRPO(Generalized Reinforcement Policy Optimization) 算法中,KL散度是显式融入到损失函数中,近似计算采用的是 K3 估计器,如下所示:#card
