def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    """Score every completion by passing its text to ``count_xml``.

    Args:
        completions: Sequence of completions. The text of each one is read
            from ``completion[0]["content"]``.
        **kwargs: Extra keyword arguments; unused here.

    Returns:
        One entry per completion, in input order, holding whatever
        ``count_xml`` returns for that completion's text.
    """
    # NOTE(review): count_xml is defined outside this block; presumably it
    # returns a float score for the XML tags in the text -- confirm.
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]
soft_format_reward_func #card
检查回答是否包含特定的格式标签：`<reasoning>…</reasoning>` 之后紧跟 `<answer>…</answer>`（两组标签之间只允许空白字符）。
1 2 3 4 5
def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward completions that follow the reasoning/answer tag layout loosely.

    A completion earns the reward when its entire text is a
    ``<reasoning>...</reasoning>`` section followed, after optional
    whitespace, by an ``<answer>...</answer>`` section.

    Args:
        completions: Sequence of completions. The text of each one is read
            from ``completion[0]["content"]``.
        **kwargs: Extra keyword arguments; unused here.

    Returns:
        One float per completion, in input order: ``0.5`` when the text
        matches the pattern, ``0.0`` otherwise.
    """
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    # fullmatch: the whole response must fit the pattern (no leading or
    # trailing text). DOTALL lets ".*?" span line breaks inside the tags.
    matches = [re.fullmatch(pattern, r, re.DOTALL) for r in responses]
    return [0.5 if m else 0.0 for m in matches]
strict_format_reward_func #card
严格匹配：回答必须完全符合 `<reasoning>\n…\n</reasoning>\n<answer>\n…\n</answer>` 的格式——每个标签独占一行，且首尾不能有任何多余字符。
1 2 3 4 5 6 7 8 9 10 11 12 13
def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward completions that follow the exact line-by-line tag layout.

    The whole text must be ``<reasoning>``, a newline, the reasoning body, a
    newline, ``</reasoning>``, a newline, ``<answer>``, a newline, the answer
    body, a newline and ``</answer>`` -- with nothing before or after.

    Debug output (the raw text and the match result of every response) is
    printed to stdout.

    Args:
        completions: Sequence of completions. The text of each one is read
            from ``completion[0]["content"]``.
        **kwargs: Extra keyword arguments; unused here.

    Returns:
        One float per completion, in input order: ``0.5`` when the text
        matches the pattern, ``0.0`` otherwise.
    """
    pattern = r"<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>"
    responses = [completion[0]["content"] for completion in completions]

    # Added debug logging.
    matches = []
    for idx, r in enumerate(responses):
        print(f"\n--- Processing response {idx} ---")
        # repr() makes escape characters (e.g. "\n") visible.
        print("Raw content:", repr(r))
        match = re.fullmatch(pattern, r, re.DOTALL)
        print("Match result:", bool(match))
        matches.append(match)

    return [0.5 if m else 0.0 for m in matches]
int_reward_func #card
提取出的答案是否为纯数字字符串（用 `str.isdigit()` 判断，因此带负号或小数点的答案不会得分）。
1 2 3 4
def int_reward_func(completions, **kwargs) -> list[float]:
    """Reward completions whose extracted answer consists only of digits.

    Args:
        completions: Sequence of completions. The text of each one is read
            from ``completion[0]['content']``.
        **kwargs: Extra keyword arguments; unused here.

    Returns:
        One float per completion, in input order: ``0.5`` when
        ``str.isdigit()`` is true for the extracted answer, ``0.0`` otherwise.
    """
    responses = [completion[0]['content'] for completion in completions]
    # NOTE(review): extract_xml_answer is defined outside this block;
    # presumably it returns the text inside the <answer> tags as a str --
    # confirm.
    extracted_responses = [extract_xml_answer(r) for r in responses]
    # isdigit() is false for strings with a sign or a decimal point, so
    # answers such as "-5" or "3.0" score 0.0.
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]
correctness_reward_func #card
答案是否正确
1 2 3 4 5 6 7
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Reward completions whose extracted answer equals the reference answer.

    Also prints the question, the first reference answer, the first raw
    response and its extracted answer to stdout.

    Args:
        prompts: Sequence of prompts. The question text is read from
            ``prompts[0][-1]['content']``.
        completions: Sequence of completions. The text of each one is read
            from ``completion[0]['content']``.
        answer: Sequence of reference answers, paired positionally with
            ``completions``.
        **kwargs: Extra keyword arguments; unused here.

    Returns:
        One float per (extracted answer, reference answer) pair: ``2.0`` when
        the two compare equal with ``==``, ``0.0`` otherwise. Because the
        pairs come from ``zip``, the result is as long as the shorter of the
        two sequences.
    """
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    # NOTE(review): extract_xml_answer is defined outside this block;
    # presumably it returns the text inside the <answer> tags -- confirm.
    extracted_responses = [extract_xml_answer(r) for r in responses]
    # Only the first sample of the batch is printed.
    print(
        '-' * 20,
        f"Question:\n{q}",
        f"\nAnswer:\n{answer[0]}",
        f"\nResponse:\n{responses[0]}",
        f"\nExtracted:\n{extracted_responses[0]}",
    )
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]