Commit d4a00ef
Add Math-Verify Support (#545)
# Description

Addresses #287, #295. This PR introduces support for [Math-Verify](https://github.com/huggingface/Math-Verify) as a new rule-based reward scorer, significantly improving evaluation accuracy.

# Key changes

- Added `math-verify` to the installation dependencies.
- Introduced `reward_score/math_verify.py` and updated `reward_score/__init__.py`.

# Test

Comparison between the existing scorer in `math.py` and the newly added `math_verify.py`, using Qwen2.5-Math-7B-Instruct:

```
# Use scorer in math.py (original)
{'val/test_score/DigitalLearningGmbH/MATH-lighteval': 0.803}

# Use scorer in math_verify.py (newly added)
{'val/test_score/DigitalLearningGmbH/MATH-lighteval': 0.8338}
```

Test script:

```bash
set -x

# Data Process
python examples/data_preprocess/math_dataset.py --local_dir /workspace/datasets/math

# Evaluation
export CUDA_VISIBLE_DEVICES=4,5,6,7
export VLLM_ATTENTION_BACKEND=XFORMERS

math_train_path=/workspace/datasets/math/train.parquet
math_test_path=/workspace/datasets/math/test.parquet

python3 -m verl.trainer.main_ppo \
    data.train_files="$math_train_path" \
    data.val_files="$math_test_path" \
    data.max_prompt_length=2048 \
    data.max_response_length=2048 \
    actor_rollout_ref.model.path=Qwen/Qwen2.5-Math-7B-Instruct \
    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
    actor_rollout_ref.rollout.n=1 \
    actor_rollout_ref.rollout.temperature=0 \
    trainer.logger=['console'] \
    trainer.project_name='test-math-verify' \
    trainer.experiment_name='test-math-verify' \
    +trainer.val_before_train=True \
    trainer.n_gpus_per_node=4 \
    trainer.nnodes=1 \
    trainer.total_epochs=0 \
    data.train_batch_size=1024 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \
    algorithm.adv_estimator=grpo $@
```
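The accuracy gain comes from Math-Verify grading mathematical equivalence rather than exact string matches, so notational variants of the same answer count as correct. Below is a minimal sketch of that behavior, assuming Math-Verify's documented top-level `parse`/`verify` helpers; the answer strings are made up:

```python
# Minimal sketch, assuming math-verify's documented parse/verify API.
from math_verify import parse, verify

gold = parse("$\\frac{1}{2}$")  # extract the gold answer from LaTeX
pred = parse("$0.5$")           # extract the model's answer

# Expected True: 0.5 and 1/2 are mathematically equivalent even though
# the raw strings differ, which is where a plain string match would fail.
print(verify(gold, pred))
```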
1 parent 1d12fe3 commit d4a00ef

File tree

5 files changed: +42 −2 lines

pyproject.toml (+1)

```diff
@@ -35,6 +35,7 @@ dependencies = [
   "datasets",
   "dill",
   "hydra-core",
+  "math-verify",
   "numpy",
   "pandas",
   "peft",
```

requirements.txt (+1)

```diff
@@ -6,6 +6,7 @@ dill
 flash-attn
 hydra-core
 liger-kernel
+math-verify[antlr4_9_3]
 numpy
 pandas
 peft
```

setup.py (+1)

```diff
@@ -27,6 +27,7 @@
     'datasets',
     'dill',
     'hydra-core',
+    'math-verify',
     'numpy',
     'pandas',
     'peft',
```

verl/utils/reward_score/__init__.py (+6 −2)

```diff
@@ -19,8 +19,12 @@ def _default_compute_score(data_source, solution_str, ground_truth, extra_info=None):
         from . import gsm8k
         res = gsm8k.compute_score(solution_str, ground_truth)
     elif data_source in ['lighteval/MATH', 'DigitalLearningGmbH/MATH-lighteval']:
-        from . import math
-        res = math.compute_score(solution_str, ground_truth)
+        # from . import math
+        # res = math.compute_score(solution_str, ground_truth)
+
+        # Use Math-Verify (https://github.com/huggingface/Math-Verify) for better evaluation accuracy
+        from . import math_verify
+        res = math_verify.compute_score(solution_str, ground_truth)
     elif data_source in [
         'numina_aops_forum', 'numina_synthetic_math', 'numina_amc_aime', 'numina_synthetic_amc', 'numina_cn_k12',
         'numina_olympiads'
```
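With this change, MATH-style data sources route through the new scorer. A quick sketch of calling the dispatcher above; the model output is a hypothetical example:

```python
# Hypothetical usage of the dispatcher updated above.
from verl.utils.reward_score import _default_compute_score

score = _default_compute_score(
    data_source='DigitalLearningGmbH/MATH-lighteval',
    solution_str='... so the final answer is \\boxed{42}.',  # made-up model output
    ground_truth='42',
)
print(score)  # expected 1.0 when Math-Verify confirms the boxed answer, else 0.0
```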
verl/utils/reward_score/math_verify.py (+33)

```diff
@@ -0,0 +1,33 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from math_verify.metric import math_metric
+from math_verify.parser import LatexExtractionConfig, ExprExtractionConfig
+
+
+def compute_score(model_output: str, ground_truth: str) -> bool:
+    verify_func = math_metric(
+        gold_extraction_target=(LatexExtractionConfig(),),
+        pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()),
+    )
+    ret_score = 0.
+
+    # Wrap the ground truth in \boxed{} format for verification
+    ground_truth_boxed = "\\boxed{" + ground_truth + "}"
+    try:
+        ret_score, _ = verify_func([ground_truth_boxed], [model_output])
+    except Exception as e:
+        print(e)
+
+    return ret_score
```
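For reference, a usage sketch of the new scorer; the example strings are made up. Note that `compute_score` wraps the bare ground truth in `\boxed{}` itself so the `LatexExtractionConfig` gold extractor can find it, so callers pass the answer without delimiters:

```python
from verl.utils.reward_score.math_verify import compute_score

# Equivalent answers in different notation should score 1.0 ...
print(compute_score("The answer is \\boxed{0.5}.", "\\frac{1}{2}"))

# ... while a mismatched answer should score 0.0.
print(compute_score("The answer is \\boxed{3}.", "\\frac{1}{2}"))
```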
