from typing import Dict, List, Optional, Set, Tuple, Union

from src.generation_utils import (
    extract_alternative_paths,
    extract_context,
    extract_equivalent_classes,
    self_complete,
    verify_correctness_pairwise,
)
from src.global_edit_utils import clean_up_text
from src.text_poa_graph import TextPOAGraph


def _is_boundary(text_poa_graph: TextPOAGraph, node_id) -> bool:
    """Return True if ``node_id`` is the graph's start or end sentinel node."""
    return node_id == text_poa_graph.start_id or node_id == text_poa_graph.end_id


def decode_consensus(
    text_poa_graph: TextPOAGraph,
    selection_threshold: Optional[float] = 0.5,
    task: str = "bio",
    verbose: bool = False,
    **kwargs,
) -> Union[str, Tuple[str, str]]:
    """Decode a TextPOAGraph to a string by walking its consensus path.

    Every consensus node (except the start/end sentinels) is kept. Each
    non-consensus successor of a consensus node is also kept when the fraction
    of sequences passing through it (``len(labels) / num_sequences``) is at
    least ``selection_threshold``; such a node is kept at most once even if it
    succeeds several consensus nodes. For nodes with variations, the longest of
    the variations and the node's own text is used. The joined text is then
    edited with ``clean_up_text`` (e.g. to remove incoherencies, disfluencies,
    and redundancies).

    Args:
        text_poa_graph: The TextPOAGraph object to decode. ``toposort()`` is
            called on it before reading ``consensus_node_ids``.
        selection_threshold: Minimum support fraction a non-consensus successor
            node needs in order to be selected.
        task: Forwarded to ``clean_up_text``.
        verbose: If True, return the unedited text alongside the edited text.
        **kwargs: Forwarded to ``clean_up_text``; ``api`` defaults to "openai".

    Returns:
        "Abstain" if the graph is marked as failed; otherwise the edited text,
        or ``(text, edited_text)`` when ``verbose`` is True.
    """
    if text_poa_graph.failed:
        return "Abstain"

    text_poa_graph.toposort()
    consensus_node_ids = text_poa_graph.consensus_node_ids
    consensus_set = set(consensus_node_ids)
    nodes = text_poa_graph.nodedict

    selected_node_ids = []
    selected_extra = set()  # non-consensus nodes already selected (avoid duplicated text)
    for node_id in consensus_node_ids:
        if _is_boundary(text_poa_graph, node_id):
            continue
        selected_node_ids.append(node_id)
        for neighbor_id in nodes[node_id].outEdges:
            if neighbor_id in consensus_set or neighbor_id in selected_extra:
                continue
            support = len(nodes[neighbor_id].labels) / text_poa_graph.num_sequences
            if support >= selection_threshold:
                selected_node_ids.append(neighbor_id)
                selected_extra.add(neighbor_id)

    texts = []
    for node_id in selected_node_ids:
        node = nodes[node_id]
        if not node.variations:
            texts.append(node.text)
        else:
            # Select the longest among all variations and the node's own text;
            # on ties, variations win because they come first.
            all_texts = list(node.variations.values())
            all_texts.append(node.text)
            texts.append(max(all_texts, key=len))

    text = " ".join(texts)
    # Allow callers to override the API instead of colliding with a fixed kwarg.
    kwargs.setdefault("api", "openai")
    edited_text = clean_up_text(text=text, task=task, **kwargs)
    if verbose:
        return text, edited_text
    return edited_text


def _find_high_uncertainty_nodes(
    text_poa_graph: TextPOAGraph, uncertainty_threshold: float
) -> List:
    """Return consensus nodes (in consensus order) whose out-degree divided by
    ``num_sequences`` exceeds ``uncertainty_threshold``; sentinels are skipped."""
    return [
        node_id
        for node_id in text_poa_graph.consensus_node_ids
        if not _is_boundary(text_poa_graph, node_id)
        and len(text_poa_graph.nodedict[node_id].outEdges) / text_poa_graph.num_sequences
        > uncertainty_threshold
    ]


def _build_masked_candidate(
    text_poa_graph: TextPOAGraph, label, high_uncertainty_set: Set
) -> str:
    """Render the path of sequence ``label`` as text with uncertainty markers.

    The node that immediately follows a high-uncertainty node on this path is
    wrapped in `` *START_SEPARATOR*_<id> `` / `` *END_SEPARATOR*_<id> ``, where
    ``<id>`` is the id of that preceding high-uncertainty node, so that the
    verification step for that node can locate its own region.
    """
    text = ""
    opener = None  # id of the high-uncertainty node right before the current node
    for node_id in text_poa_graph._seq_paths[label]:
        node = text_poa_graph.nodedict[node_id]
        if opener is not None:
            text += f" *START_SEPARATOR*_{opener} "
        if len(node.variations) > 0:
            text += node.variations[label]
        else:
            text += node.text
        text += " "
        if opener is not None:
            text += f" *END_SEPARATOR*_{opener} "
        opener = node_id if node_id in high_uncertainty_set else None
    return text


def _mark_possible_error(candidate: str, node_id) -> str:
    """Turn the separators tagged with ``node_id`` into possible-error markers."""
    return candidate.replace(
        f" *START_SEPARATOR*_{node_id} ", "*START_POSSIBLE_ERROR*"
    ).replace(f" *END_SEPARATOR*_{node_id} ", "*END_POSSIBLE_ERROR*")


def _register_strike(prev_step: Dict, label, grace_period: bool) -> bool:
    """Record an "incorrect" verdict for ``label`` and return whether to prune it.

    Without a grace period every incorrect verdict prunes. With a grace period,
    the first incorrect verdict only sets a strike; a second incorrect verdict
    before any "correct" verdict (which clears the strike) prunes.
    """
    already_struck = bool(prev_step[label])
    prev_step[label] = True
    return already_struck or not grace_period


def _default_prompt(problem: str, masked_approaches: str) -> str:
    """Prompt used when at least one approach survived verification."""
    return f"""
Solve the following math problem with mathematical precision and clarity.

Problem: {problem}

Below are potential solution approaches with sections marked as uncertain (between *START_UNCERTAIN_REGION* and *END_UNCERTAIN_REGION*). These sections may contain conceptual or computational errors.
There are also sections marked as *START_POSSIBLE_ERROR* and *END_POSSIBLE_ERROR*. A verification step indicated that these steps are highly likely to contain errors.

Potential Approaches:
{masked_approaches}

Your task:
1. Analyze all potential approaches critically, identifying their mathematical strengths and weaknesses
   If the approaches contain different answers, think carefully about why they are different, and use this to identify potential errors.
2. Using the sections with special markers, identify potential errors.
3. Develop a rigorous, step-by-step solution based on sound mathematical principles
4. For uncertain regions:
   - Verify each step using algebraic or numerical validation
   - If correct, incorporate these steps with appropriate justification
   - If incorrect, provide clear corrections with mathematical reasoning for your changes
5. Follow a comparative approach, using the differences between approaches to identify potential errors.
6. Do not blindly follow the approaches, but rather use them to identify potential errors.

Guidelines for your solution:
- Begin with a strategic overview of your chosen approach
- Present each mathematical step with clear notation and justification
- Pay special attention to areas that were previously marked uncertain

Conclude your solution with: Therefore, the final answer is: $\\boxed{{answer}}$.

Solution:
"""


def _patch_prompt(problem: str, all_approaches: str) -> str:
    """Prompt used when every approach was pruned (or only one sequence exists)."""
    return f"""
Solve the following mathematical problem with precision and clarity.

Problem: {problem}

You have been provided with several partial solution approaches that attempted to solve this problem. None of these approaches are correct, but may contain valuable insights.
Sections marked between *START_POSSIBLE_ERROR* and *END_POSSIBLE_ERROR* indicate steps where previous solutions showed uncertainty. A verification step indicated that these steps are likely to contain errors.

INSTRUCTIONS:
1. Synthesize a correct solution using insights from the previous approaches
2. Pay special attention to fixing the problematic areas marked by separators
3. Develop your solution step-by-step, showing clear mathematical reasoning
4. Focus especially on mathematical correctness in areas where previous solutions diverged
5. Present your work in a logical, sequential manner suitable for an advanced reader

GUIDELINES FOR MATHEMATICAL RIGOR:
1. MAINTAIN MATHEMATICAL RIGOR
   - Verify that all mathematical operations follow from established principles and definitions
   - Ensure dimensional consistency throughout calculations
   - Check that algebraic manipulations preserve equality and do not introduce errors

2. CONSIDER ALTERNATIVE PERSPECTIVES
   - Even when approaches reach the same conclusion, examine their reasoning independently
   - Look for more elegant or insightful connections that may be missed across all approaches
   - Consider whether fundamental mathematical principles suggest a different path

3. CRITICAL VALIDATION
   - Test conclusions using known mathematical properties and relationships
   - When possible, verify results using alternative methods
   - Be especially cautious when all approaches agree on a result but use similar reasoning

4. USE PRECISION IN CORRECTIONS
   - When correcting uncertain regions, specify exactly what was incorrect and why
   - Provide clear mathematical justification for any changes
   - Ensure corrections align with standard mathematical principles and notations

Previous Approaches (for reference only):
{all_approaches}

Your Solution:
[Begin with a clear statement of your approach]
[Provide detailed mathematical steps]
[Ensure correct handling of complex mathematical operations]
[Verify your work at key points, especially in previously problematic areas]

Always conclude with: Therefore, the final answer is: $\\boxed{{answer}}$
"""


def decode_self_verified(
    text_poa_graph: TextPOAGraph,
    problem: str,
    uncertainty_threshold: float = 0.6,
    verification_api: str = "openai",
    verification_model: str = "gpt-4o-mini",
    grace_period: bool = True,
):
    """Prune sequences via pairwise self-verification, then solve with a model.

    1. Consensus nodes whose out-degree / ``num_sequences`` exceeds
       ``uncertainty_threshold`` are treated as high-uncertainty branch points.
    2. Each sequence is rendered with separators around the node that follows
       each branch point (see ``_build_masked_candidate``).
    3. At each branch point with more than one equivalence class of branches,
       representatives of classes (0, 1), (2, 3), ... are compared with
       ``verify_correctness_pairwise``. A score below 1.0 marks that class's
       region as a possible error and, subject to ``grace_period``, prunes the
       whole class. With an odd number of classes the last one is not verified.
    4. If every sequence gets pruned (or there is only one sequence), a "patch"
       prompt over all sequences is used; otherwise a prompt over the surviving
       sequences is used. The prompt is completed with ``self_complete``.

    Args:
        text_poa_graph: Graph whose ``_seq_paths`` hold one node path per label.
        problem: Problem statement passed to verification and the final prompt.
        uncertainty_threshold: Branching-factor threshold for branch points.
        verification_api: API name passed to verification and completion.
        verification_model: Model name passed to verification and completion.
        grace_period: If True, a label is pruned only after two incorrect
            verdicts with no intervening correct verdict.

    Returns:
        ``(completion, masked_candidates)`` where ``completion`` is whatever
        ``self_complete`` returns and ``masked_candidates`` maps every label to
        its marked-up text.
    """
    high_uncertainty_nodes = _find_high_uncertainty_nodes(text_poa_graph, uncertainty_threshold)
    high_uncertainty_set = set(high_uncertainty_nodes)

    selected_labels = list(text_poa_graph._seq_paths.keys())
    masked_candidates = {
        label: _build_masked_candidate(text_poa_graph, label, high_uncertainty_set)
        for label in selected_labels
    }

    patch_start_node = None
    # Grace-period strike flags: True while a label's latest verdict was "incorrect".
    prev_step = {label: None for label in selected_labels}
    for node_id in high_uncertainty_nodes:
        context_before = extract_context(text_poa_graph, node_id)
        alternative_paths = extract_alternative_paths(text_poa_graph, node_id)
        equivalent_classes = extract_equivalent_classes(text_poa_graph, node_id, selected_labels)
        # Only self-verify labels from different semantically equivalent branches.
        if len(equivalent_classes) <= 1:
            continue

        new_labels = selected_labels.copy()
        for i in range(0, len(equivalent_classes) - 1, 2):
            class_a, class_b = equivalent_classes[i], equivalent_classes[i + 1]
            label_a, label_b = class_a[0], class_b[0]
            score = verify_correctness_pairwise(
                full_text_1=context_before[label_a] + alternative_paths[label_a],
                full_text_2=context_before[label_b] + alternative_paths[label_b],
                verification_model=verification_model,
                problem=problem,
                api=verification_api,
            )
            for label, members, raw_verdict in (
                (label_a, class_a, score[0]),
                (label_b, class_b, score[1]),
            ):
                verdict = float(raw_verdict)
                if verdict < 1.0:
                    print(f"Label {label} is incorrect at node {node_id}")
                    masked_candidates[label] = _mark_possible_error(
                        masked_candidates[label], node_id
                    )
                    if _register_strike(prev_step, label, grace_period):
                        for member in members:
                            new_labels.remove(member)
                            print(f"\nSequence {member} pruned at node {node_id} (pairwise)")
                elif verdict == 1.0:
                    prev_step[label] = False

        if not new_labels:
            patch_start_node = node_id
            break
        selected_labels = new_labels

    # These are the pruned approaches with masking
    print(masked_candidates)
    masked_approaches = "\n".join(
        f"Approach {label}: "
        + masked_candidates[label]
        .replace("START_SEPARATOR", "START_UNCERTAIN_REGION")
        .replace("END_SEPARATOR", "END_UNCERTAIN_REGION")
        for label in selected_labels
    )
    # These are all approaches with masking
    all_approaches = "\n".join(
        f"Approach {label}: {candidate}" for label, candidate in masked_candidates.items()
    )

    if patch_start_node is not None or len(masked_candidates) == 1:
        print("None correct, patching")
        prompt = _patch_prompt(problem, all_approaches)
    else:
        prompt = _default_prompt(problem, masked_approaches)

    completion = self_complete(
        verification_prompt=prompt, verification_model=verification_model, api=verification_api
    )
    return completion, masked_candidates