@A Brief Analysis of the GRPO Algorithm for LLM Reinforcement Learning

Link: 大模型强化学习之GRPO算法原理浅析 - 知乎

An intuitive example of the GRPO method #card

  • Model responses


  • Computing the advantage (done within each group)

    • the advantage is computed as:

      • $\hat{A}_{i, t}=\frac{r_i-\operatorname{mean}(\mathbf{r})}{\operatorname{std}(\mathbf{r})}$


Rewards are standardized within the group; the advantage measures how good a given output (every token of it shares the same value) is relative to the group average, as the sketch after this list shows: #card

  • if an output's reward is above the group mean, its advantage is positive,

  • if it is below the mean, its advantage is negative,

  • so the policy model is steered toward generating the outputs that earned higher rewards
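
A minimal sketch of this group-relative normalization in plain Python (not tied to any particular library; `stdev` here is the sample standard deviation):

```python
from statistics import mean, stdev

def group_advantages(rewards: list[float]) -> list[float]:
    """(r_i - mean(r)) / std(r); every token of output i shares A_hat[i]."""
    mu, sigma = mean(rewards), stdev(rewards)
    return [(r - mu) / sigma for r in rewards]

# rewards for a group of 4 sampled answers to one prompt:
print(group_advantages([2.0, 0.5, 0.5, 0.0]))
# -> approximately [1.44, -0.29, -0.29, -0.87]
```

Only the answer that beat the group mean gets a positive advantage; the policy update then raises the probability of its tokens and lowers the others'.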

The rough flow of the GRPO algorithm: #incremental #card

  • the user sends a query, and the model samples a group of responses

  • each response is scored by the reward functions (e.g. accuracy, format)

  • the responses are compared within the group to compute their relative advantage (the formula above)

  • the policy is updated to favor the responses with higher advantage

  • the update is regularized with a KL-divergence penalty so the model does not drift off course

[[GRPO]] formula breakdown
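
For reference, the full objective from the DeepSeekMath paper that introduced GRPO, writing $r_{i,t}(\theta)=\frac{\pi_\theta(o_{i,t}\mid q,o_{i,<t})}{\pi_{\theta_{\text{old}}}(o_{i,t}\mid q,o_{i,<t})}$ for the token-level probability ratio over a group of $G$ sampled outputs $o_1,\dots,o_G$:

$$\mathcal{J}_{\text{GRPO}}(\theta)=\mathbb{E}\left[\frac{1}{G}\sum_{i=1}^{G}\frac{1}{|o_i|}\sum_{t=1}^{|o_i|}\Big(\min\big(r_{i,t}(\theta)\hat{A}_{i,t},\ \operatorname{clip}\big(r_{i,t}(\theta),1-\varepsilon,1+\varepsilon\big)\hat{A}_{i,t}\big)-\beta\,\mathbb{D}_{\text{KL}}\big[\pi_\theta\,\|\,\pi_{\text{ref}}\big]\Big)\right]$$

The clipping term is inherited from PPO; the $\beta$-weighted KL term is the regularizer mentioned in the flow above, keeping the policy close to the reference model $\pi_{\text{ref}}$.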

Rewards are obtained from several rule-based reward functions:

  • xmlcount_reward_func (format) #card
    • is the output laid out in the expected XML format?
```python
# Compute the format score: 0.125 for each correctly placed tag,
# minus a small penalty for any stray text after the closing </answer>
def count_xml(text) -> float:
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1]) * 0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001
    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]
```
  • soft_format_reward_func #card
    • does the output contain the expected format tags?
```python
import re

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.fullmatch(pattern, r, re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]
```
  • strict_format_reward_func #card
    • strict match against the full template (newlines included)?
```python
def strict_format_reward_func(completions, **kwargs) -> list[float]:
    pattern = r"<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>"
    responses = [completion[0]["content"] for completion in completions]
    # debug logging
    matches = []
    for idx, r in enumerate(responses):
        print(f"\n--- Processing response {idx} ---")
        print("Raw content:", repr(r))  # repr() makes the escape characters visible
        match = re.fullmatch(pattern, r, re.DOTALL)
        print("Match result:", bool(match))
        matches.append(match)

    return [0.5 if match else 0.0 for match in matches]
```
  • int_reward_func #card
    • is the extracted answer an integer? (extract_xml_answer is sketched after this list)
```python
def int_reward_func(completions, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]
```
  • correctness_reward_func #card
    • does the extracted answer match the gold answer?
```python
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    print('-' * 20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}",
          f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
```
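
Both int_reward_func and correctness_reward_func call an extract_xml_answer helper that the note does not show. A minimal sketch, assuming it just returns the stripped text between the `<answer>` tags:

```python
def extract_xml_answer(text: str) -> str:
    # assumed behavior: grab whatever sits between <answer> and </answer>
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()
```

These signatures follow the reward-function interface of trl's GRPOTrainer (each function receives the sampled completions plus dataset columns such as `answer` as keyword arguments and returns one score per completion), so wiring everything together could look roughly like the sketch below; the model name, dataset, and hyperparameters are placeholders, not from the original note:

```python
from trl import GRPOConfig, GRPOTrainer

training_args = GRPOConfig(
    output_dir="grpo-demo",
    num_generations=8,  # group size G: responses sampled per prompt
    beta=0.04,          # weight of the KL penalty against the reference model
)

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-1.5B-Instruct",  # placeholder model
    reward_funcs=[
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func,
    ],
    args=training_args,
    train_dataset=dataset,  # placeholder: prompts paired with gold answers
)
trainer.train()
```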
Ryen Xiang · Published 2025-04-13 · Updated 2025-04-13
