@PPO与GRPO中的KL散度近似计算
标准 [[KL Divergence]] 计算公式
- 在实际计算中,直接计算 KL 散度可能非常困难,主要原因如下:#card
-
需要对所有 x 进行求和或积分,计算成本高。
-
计算过程中可能涉及大规模概率分布,导致内存消耗过大。
-
[[K1 估计器]] 计算公式 :->
-
为什么是 KL 的无偏估计 #card
-
K1 估计器的方差较大,因为 #card
-
若 可能为负值,导致估计值波动大。
-
若 和 差异大, 可能出现较大的数值变化。
-
[[K2 估计器]] 公式 :->
- K2 估计器是 有偏估计,但其方差较低,且 偏差在实证中较小,因为:#card
- K2 确保所有样本的值都是正数,使其在计算中更加稳定。
[[K3 估计器]] 公式 :->
-
K3 估计器也是 的无偏估计。证明如下:#card
-
已知:
-
我们只需证明:
-
即:
-
展开计算:
-
-
K3 估计器恒大于等于 0。证明如下 #card
-
设函数:
-
求导数:
-
当 时, ;当 时, 。即, 在 处取最小值,且:
-
所以,对于任意 ,均有:
-
PPO 算法中的 KL 近似计算#card
在 GRPO(Generalized Reinforcement Policy Optimization) 算法中,KL散度是显式融入到损失函数中,近似计算采用的是 K3 估计器,如下所示:#card
@PPO与GRPO中的KL散度近似计算
https://blog.xiang578.com/post/logseq/@PPO与GRPO中的KL散度近似计算.html