Easy-to-Hard Generalization Models
Collection
6 items • Updated • 1
YAML Metadata Warning: empty or missing YAML metadata in repo card
Check out the documentation for more information.
PRMs are trained to predict the correctness of each step at the positions of "\n\n" and "<eos>".
Usage:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Checkpoint of the process reward model (PRM) used for scoring.
model_name = "ScalableMath/llemma-7b-prm-metamath-level-1to3-hf"

# The tokenizer is loaded from the "EleutherAI/llemma_7b" repo rather than
# from model_name.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/llemma_7b")

# Load the weights in bfloat16; device_map="auto" leaves device placement to
# the library.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
# Worked example in the "# Question / # Solution / # Answer" layout that the
# scoring code in this snippet searches for.
#
# The literal is a raw string so the LaTeX backslashes stay literal. In a
# normal string, "\theta", "\arctan", "\frac", "\right" and "\boxed" would be
# read as the control characters TAB, BEL, FF, CR and BS, and "\s", "\l", "\p"
# are invalid escape sequences.
#
# NOTE(review): the scoring logic looks for "\n\n" separators (before
# "# Solution" and between steps); confirm that the blank lines of the
# original example were not lost when this snippet was copied.
# NOTE(review): the sample output printed at the end of this snippet may have
# been produced with the non-raw literal; re-run to confirm the numbers.
qa_example = r"""# Question
Convert the point $(0,3)$ in rectangular coordinates to polar coordinates. Enter your answer in the form $(r,\theta),$ where $r > 0$ and $0 \le \theta < 2 \pi.$
# Solution
To convert from rectangular coordinates to polar coordinates, we use the formulas $r = \sqrt{x^2 + y^2}$ and $\theta = \arctan\left(\frac{y}{x}\right)$.
In this case, $x = 0$ and $y = 3$, so $r = \sqrt{0^2 + 3^2} = 3$ and $\theta = \arctan\left(\frac{3}{0}\right)$.
Since $\frac{3}{0}$ is undefined, we can say that $\theta$ is undefined.
However, we know that $\theta$ is an angle, and since $r > 0$, we can say that $\theta$ is any angle that satisfies $0 \le \theta < 2 \pi$.
Therefore, the polar coordinates of the point $(0,3)$ are $\boxed{(3,\theta)}$, where $0 \le \theta < 2 \pi$.
# Answer
(3,\theta)"""
# Token-id patterns to search for inside the encoded example. Each marker is
# encoded without special tokens and its first id is dropped.
# NOTE(review): the dropped id is presumably the tokenizer's leading
# word-boundary piece - confirm for the tokenizer in use.
begin_solution_tokens = tokenizer.encode("\n\n# Solution", add_special_tokens=False)[1:]
scoring_tokens = tokenizer.encode("\n\n", add_special_tokens=False)[1:]
eos_token = tokenizer.eos_token_id
input_ids = tokenizer.encode(qa_example)

# Convert the patterns once instead of on every loop iteration.
begin_solution_pattern = tuple(begin_solution_tokens)
scoring_pattern = tuple(scoring_tokens)

# Walk the sequence once. After the "# Solution" marker has been seen, record
# the start index of every "\n\n" match; an EOS token is recorded as well and
# ends the scan.
begin_solution_flag = False
candidate_positions = []
for start_idx, token_id in enumerate(input_ids):
    begin_window = tuple(input_ids[start_idx:start_idx + len(begin_solution_pattern)])
    if begin_window == begin_solution_pattern:
        begin_solution_flag = True
    scoring_window = tuple(input_ids[start_idx:start_idx + len(scoring_pattern)])
    if begin_solution_flag and scoring_window == scoring_pattern:
        candidate_positions.append(start_idx)
    if token_id == eos_token:
        candidate_positions.append(start_idx)
        break

# The two deletions below need at least three recorded positions; fail with an
# explanatory message instead of a bare "list index out of range".
if len(candidate_positions) < 3:
    raise ValueError(
        f"expected at least 3 candidate positions, found {len(candidate_positions)}; "
        "check that qa_example contains '\\n\\n# Solution' and that its steps "
        "are separated by '\\n\\n'"
    )

# maybe delete the first and the second to last candidate_positions
# because they are "\n\n" after "# Solution" and after "# Answer"
del candidate_positions[0]
del candidate_positions[-2]
# Feed the whole example through the model as a batch of one sequence.
input_tensor = torch.tensor(input_ids).unsqueeze(0)
candidate_positions = torch.tensor(candidate_positions)

# Inference only, so gradient tracking is switched off for the forward pass.
with torch.no_grad():
    logits = model(input_tensor).logits

# Average the logits over their last dimension to get one raw score per
# sequence position (presumably the vocabulary axis - confirm against the
# model's output shape).
scores = logits.mean(dim=-1)

# Keep only the scores at the selected positions of the single batch item and
# squash them to the (0, 1) range.
step_scores = scores[0, candidate_positions]
step_probs = step_scores.sigmoid()
print(step_probs)
# Output shown with the original snippet:
# tensor([0.7264, 0.8152, 0.7827, 0.4709, 0.5181])