"""
Worker agent, used for both Agent A and Agent B. The two agents share the
same frozen model.

By default the agent returns templated outputs so that the environment
behaves deterministically. Supplying both a model and a tokenizer switches
it to the LLM-backed implementation.
"""
| import json |
|
|
|
|
class WorkerAgent:
    """Worker that produces a task output and revises it on request.

    Two back ends exist and the choice is made on every call:

    * Template (default): ``generate`` serialises ``spec["ground_truth"]``
      as indented JSON and ``correct`` appends a ``# corrected: ...`` line.
      Both are pure functions of their arguments.
    * LLM: used only when BOTH ``model`` and ``tokenizer`` are set. Prompts
      come from ``training.prompt_templates`` and the completion is sampled
      (``do_sample=True``), so output may differ between calls.

    If only one of ``model`` / ``tokenizer`` is supplied, the template back
    end is used.
    """

    # Sampling defaults used by the LLM back end.
    DEFAULT_MAX_NEW_TOKENS = 512
    DEFAULT_TEMPERATURE = 0.9

    def __init__(
        self,
        model=None,
        tokenizer=None,
        max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
        temperature: float = DEFAULT_TEMPERATURE,
    ):
        """Store the optional LLM back end and its sampling settings.

        Args:
            model: Object exposing ``generate(**inputs, ...)`` and a
                ``device`` attribute, or ``None`` for the template back end.
            tokenizer: Callable object exposing ``apply_chat_template``,
                ``decode`` and ``eos_token_id``, or ``None``.
            max_new_tokens: Forwarded to ``model.generate`` as the cap on
                generated tokens.
            temperature: Forwarded to ``model.generate``. Sampling is always
                enabled, so this is expected to be a positive value.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature

    def generate(self, spec: dict, role: str = "agent_a") -> str:
        """Return the worker's output for ``spec``.

        Args:
            spec: Task specification. The template back end reads the
                optional ``"ground_truth"`` key; the LLM back end requires
                ``"task_description"``.
            role: Role label inserted into the LLM prompts.

        Returns:
            The generated output text.
        """
        if self._uses_llm():
            return self._generate_llm(spec, role)
        return self._generate_template(spec, role)

    def correct(self, output: str, correction_request: str) -> str:
        """Return ``output`` revised according to ``correction_request``.

        Args:
            output: The previously produced output.
            correction_request: Free-text description of what to fix.

        Returns:
            The revised output text.
        """
        if self._uses_llm():
            return self._correct_llm(output, correction_request)
        return self._correct_template(output, correction_request)

    def _uses_llm(self) -> bool:
        """Tell whether both halves of the LLM back end are configured."""
        return self.model is not None and self.tokenizer is not None

    def _generate_template(self, spec: dict, role: str) -> str:
        """Serialise ``spec["ground_truth"]`` (default ``{}``) as JSON.

        ``role`` is accepted for signature parity with ``_generate_llm`` but
        is not used here.
        """
        ground_truth = spec.get("ground_truth", {})
        return json.dumps(ground_truth, indent=2)

    def _correct_template(self, output: str, correction_request: str) -> str:
        """Append a ``# corrected: <request>`` line to ``output``."""
        return output + f"\n# corrected: {correction_request}"

    def _generate_llm(self, spec: dict, role: str) -> str:
        """Build the system and user prompts for ``spec`` and sample a reply.

        Raises:
            KeyError: If ``spec`` has no ``"task_description"`` entry.
        """
        # Imported lazily so the template back end works without the
        # training package being importable.
        from training.prompt_templates import (
            WORKER_SYSTEM_PROMPT,
            WORKER_USER_PROMPT,
        )

        system_text = WORKER_SYSTEM_PROMPT.format(role=role)
        user_text = WORKER_USER_PROMPT.format(
            task_description=spec["task_description"],
            role_description=role,
        )
        return self._complete_chat(
            [
                {"role": "system", "content": system_text},
                {"role": "user", "content": user_text},
            ]
        )

    def _correct_llm(self, output: str, correction_request: str) -> str:
        """Ask the model to revise ``output`` per ``correction_request``.

        The conversation is a single user turn: the formatted correction
        prompt followed by the original output.
        """
        from training.prompt_templates import WORKER_CORRECTION_PROMPT

        request_text = WORKER_CORRECTION_PROMPT.format(
            correction_request=correction_request
        )
        user_text = request_text + "\n\nORIGINAL OUTPUT:\n" + output
        return self._complete_chat([{"role": "user", "content": user_text}])

    def _complete_chat(self, messages: list) -> str:
        """Run ``messages`` through the model and return only the new text.

        Shared by ``_generate_llm`` and ``_correct_llm`` so that both use the
        same tokenisation, sampling settings and decoding.

        Args:
            messages: Chat turns as ``{"role": ..., "content": ...}`` dicts.

        Returns:
            The decoded completion, with the prompt and special tokens
            removed.
        """
        prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        # Imported lazily so the template back end does not need torch.
        import torch

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=self.max_new_tokens,
                temperature=self.temperature,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )
        # The returned sequence starts with the prompt tokens; drop them so
        # only the newly generated continuation is decoded.
        prompt_length = inputs["input_ids"].shape[1]
        return self.tokenizer.decode(
            outputs[0][prompt_length:],
            skip_special_tokens=True,
        )
|
|