# congr-visualizer / src/generation_utils.py
# Hosting-page residue: uploader avatar "Shahzaib98's picture"
# Commit message: "Upload 11 files"
# Commit: d2ff6a7 (verified)
import re
from huggingface_hub import InferenceClient
from openai import OpenAI
from together import Together
from src.text_poa_graph import TextPOAGraph
def extract_context(text_poa_graph, node_id):
    """Build, for every sequence, the text leading up to and including ``node_id``.

    Each path stored in ``text_poa_graph._seq_paths`` is truncated right after
    ``node_id`` and the remaining nodes are rendered to text. A node contributes
    its label-specific entry from ``variations`` when one exists, otherwise its
    ``text``.

    Returns:
        A dict mapping each sequence label to the space-joined prefix text.

    Raises:
        ValueError: if ``node_id`` is absent from any path (from ``list.index``).
    """
    prefix_by_label = {}
    for seq_label, node_path in text_poa_graph._seq_paths.items():
        cutoff = node_path.index(node_id) + 1
        pieces = []
        for nid in node_path[:cutoff]:
            node = text_poa_graph.nodedict[nid]
            # Prefer the wording this particular sequence used for the node.
            pieces.append(node.variations.get(seq_label, node.text))
        prefix_by_label[seq_label] = " ".join(pieces)
    return prefix_by_label
def extract_alternative_paths(text_poa_graph: TextPOAGraph, node_id):
    """Extract, per sequence, the text between ``node_id`` and the next consensus node.

    For every path in ``text_poa_graph._seq_paths`` the segment starts right
    after ``node_id`` (``node_id`` itself is excluded) and runs up to and
    including the first following node found in
    ``text_poa_graph.consensus_node_ids``. A node contributes its label-specific
    entry from ``variations`` when one exists, otherwise its ``text``.

    Returns:
        A dict mapping each sequence label to the space-joined segment text.
        The text is an empty string when no consensus node follows ``node_id``
        in that sequence.

    Raises:
        ValueError: if ``node_id`` is absent from any path (from ``list.index``).
    """
    consensus_ids = text_poa_graph.consensus_node_ids
    alternative_paths = {}
    for label, path in text_poa_graph._seq_paths.items():
        start = path.index(node_id) + 1
        # Track the *position* of the next consensus node rather than its id:
        # an id can be falsy (e.g. 0), so testing the id for truthiness would
        # wrongly report "no consensus node found".
        next_cn_pos = next(
            (pos for pos in range(start, len(path)) if path[pos] in consensus_ids),
            None,
        )
        if next_cn_pos is None:
            alternative_segment = []
        else:
            alternative_segment = path[start : next_cn_pos + 1]
        alternative_paths[label] = " ".join(
            text_poa_graph.nodedict[nid].variations.get(label, text_poa_graph.nodedict[nid].text)
            for nid in alternative_segment
        )
    return alternative_paths
def is_same_branch(text_poa_graph: TextPOAGraph, node_id, lable_1, label_2) -> bool:
    """Check whether two sequences continue to the same node right after ``node_id``.

    Args:
        text_poa_graph: graph whose ``_seq_paths`` maps sequence labels to node-id paths.
        node_id: node that both sequences pass through.
        lable_1: label of the first sequence (spelling kept for keyword callers).
        label_2: label of the second sequence.

    Returns:
        True when the node following ``node_id`` is the same in both paths.
        If ``node_id`` is the last node of a path there is no successor: the
        result is True only when both paths end at ``node_id``.

    Raises:
        KeyError: if a label is not in ``_seq_paths``.
        ValueError: if ``node_id`` is absent from either path (from ``list.index``).
    """
    path_1 = text_poa_graph._seq_paths[lable_1]
    path_2 = text_poa_graph._seq_paths[label_2]
    next_1 = path_1.index(node_id) + 1
    next_2 = path_2.index(node_id) + 1
    ended_1 = next_1 >= len(path_1)
    ended_2 = next_2 >= len(path_2)
    if ended_1 or ended_2:
        # Indexing past the end would raise IndexError; a sequence that stops
        # here only matches another sequence that also stops here.
        return ended_1 and ended_2
    return path_1[next_1] == path_2[next_2]
def extract_equivalent_classes(text_poa_graph: TextPOAGraph, node_id, selected_labels):
    """Partition ``selected_labels`` into groups that take the same branch after ``node_id``.

    Labels are processed in the order given. Each label is compared, via
    ``is_same_branch``, against the first member of every existing group and
    joins the first group it matches; otherwise it starts a new group.

    Returns:
        A list of groups (lists of labels) in order of first appearance, or an
        empty list when ``selected_labels`` is empty or otherwise falsy.
    """
    groups = []
    if not selected_labels:
        return groups
    for candidate in selected_labels:
        for group in groups:
            representative = group[0]
            if is_same_branch(text_poa_graph, node_id, representative, candidate):
                group.append(candidate)
                break
        else:
            # No existing group shares this label's branch: open a new one.
            groups.append([candidate])
    return groups
def verify_correctness_pairwise(
    full_text_1: str, full_text_2: str, verification_model: str, problem: str, api: str = "openai"
):
    """Ask a chat model to judge two partial solutions to ``problem`` side by side.

    A client is created for the chosen backend, a fixed verification prompt is
    sent at temperature 0, and the last ``(a, b)`` score tuple found in the
    reply is returned. Both solution texts and the raw reply are printed.

    Args:
        full_text_1: first partial solution.
        full_text_2: second partial solution.
        verification_model: model name passed to the chat-completions call.
        problem: problem statement inserted into the prompt.
        api: backend selector, one of ``"openai"``, ``"hf"`` or ``"together"``.

    Returns:
        The last score tuple matched in the reply, as the matched strings
        (e.g. ``("1", "0")`` or ``("1.0", "0")``). When the reply contains no
        such tuple, or is empty, the integer pair ``(0, 0)`` is returned.
        NOTE(review): the two cases differ in element type (str vs int); this
        is kept as-is because callers may rely on it.

    Raises:
        ValueError: if ``api`` is not one of the supported backends.
    """
    if api == "openai":
        client = OpenAI()
    elif api == "hf":
        client = InferenceClient()
    elif api == "together":
        client = Together()
    else:
        raise ValueError(f"Invalid API: {api}")
    # The prompt body is kept flush-left on purpose: any indentation added
    # inside the triple-quoted string would become part of the text sent to
    # the model.
    prompt = f"""
You will be given a problem and 2 partial solutions.
Your task is to use comparison as an EFFICIENCY TOOL to quickly identify potential errors.
You will be given guidelines to follow, and you will be penalized if you do not follow them.
Problem: {problem}
Partial Solution 1: {full_text_1}
Partial Solution 2: {full_text_2}
CRITICAL GUIDELINES:
- DO NOT penalize a solution for being incomplete or having missing steps
- DO NOT make a comparison of which solution is better
- DO NOT consider steps incorrect just because they differ between solutions
- DO NOT prematurely evaluate based on final answers or future steps
- DO NOT expect both solutions to be at the same stage of completion
- DO NOT consider a step incorrect just because it lacks sufficient detail or justification
KEY EFFICIENCY PRINCIPLE:
- Use agreement between solutions as evidence of correctness
- Use disagreement as a signal to investigate more deeply
- Only label a step as an error if it contains a specific mathematical mistake
- Incompleteness is not a mathematical error.
Here are the instructions for how to complete your task:
EFFICIENT VERIFICATION APPROACH:
1. QUICK COMPARISON (Use this to focus your attention):
- Immediately identify where the solutions differ in approach or results
- Use these differences as "error hotspots" to prioritize your verification
- When solutions agree, you can generally assume that part is correct
- When solutions disagree, investigate those specific points deeply
2. TARGETED VERIFICATION (Only where needed):
- Most important: Do not consider any incomplete steps as errors
- Focus your mathematical verification on the "hotspots" identified above
- Check mathematical validity only at points of difference or uncertainty
- Avoid line-by-line checking of steps where solutions agree
- For each potential error spot, verify if the mathematical reasoning is valid
- If an intermediate step is later corrected, do not penalize the solution for having the incorrect intermediate step
After your targeted verification, propose a score tuple (score_1, score_2):
- Score (1,1) if both partial solutions are valid
- Score (1,0) if only the first solution is valid
- Score (0,1) if only the second solution is valid
- Score (0,0) if both solutions are invalid
In case you score a solution as 0, you must give an explanation for each check below:
3. FINAL CHECKS:
- If you score a solution as 0, you MUST identify the specific mathematical error.
- You must also double check the problem statement. Reconsider your score and determine if you have misinterpreted the problem statement.
- You must also check whether you have penalized a solution for being incomplete or having missing steps.
Before outputting your final score, you must answer these questions:
STOP! Did you give a score of 0 to a solution that was incomplete?
STOP! Did you penalize a solution for being incomplete or having missing steps?
STOP! Did you make a comparison of which solution is better?
STOP! Did you consider steps incorrect just because they differ between solutions?
STOP! Did you prematurely evaluate based on final answers?
STOP! Did you consider a step incorrect just because it lacks sufficient detail or justification?
Now give your final score:
Final score:
"""
    completion = client.chat.completions.create(
        model=verification_model,
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt},
        ],
        temperature=0.0,
    )
    # The message content can be None; treat that as an empty reply so the
    # function falls back to (0, 0) instead of failing on ``None.strip()``.
    content = completion.choices[0].message.content
    response = (content or "").strip()
    print(full_text_1)
    print(full_text_2)
    print(f"Correctness score: {response} \n")
    # Accept "(1,0)", "( 1.0 , 0 )", etc.; the model may mention several tuples
    # while reasoning, so the last one is taken as its final answer.
    score_match = re.findall(r"\(\s*([01](?:\.0)?)\s*,\s*([01](?:\.0)?)\s*\)", response)
    score = score_match[-1] if score_match else (0, 0)
    return score
def self_complete(verification_prompt: str, verification_model: str, api: str = "openai") -> str:
    """Send ``verification_prompt`` to a chat model and return its reply text.

    The prompt is printed first, then a client for the chosen backend is
    created and a single chat completion is requested at temperature 0.

    Args:
        verification_prompt: user message sent to the model.
        verification_model: model name passed to the chat-completions call.
        api: backend selector, one of ``"openai"``, ``"hf"`` or ``"together"``.

    Returns:
        The reply text with surrounding whitespace stripped, or an empty string
        when the reply carries no text content.

    Raises:
        ValueError: if ``api`` is not one of the supported backends.
    """
    print(verification_prompt)
    if api == "openai":
        client = OpenAI()
    elif api == "hf":
        client = InferenceClient()
    elif api == "together":
        client = Together()
    else:
        raise ValueError(f"Invalid API: {api}")
    completion = client.chat.completions.create(
        model=verification_model,
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": verification_prompt},
        ],
        temperature=0.0,
    )
    # The message content can be None; return "" rather than failing on
    # ``None.strip()``.
    content = completion.choices[0].message.content
    return (content or "").strip()