Spaces:
Running
Running
| import json | |
| import logging | |
| import os.path as osp | |
| from typing import List | |
| from agentreview.environments import Conversation | |
| from agentreview.utility.utils import get_rebuttal_dir | |
| from .base import TimeStep | |
| from ..message import Message | |
| from ..paper_review_message import PaperReviewMessagePool | |
| logger = logging.getLogger(__name__) | |
class PaperReview(Conversation):
    """
    Discussion between reviewers, an author and an area chair (AC) about one paper.

    The process is split into phases, each with a fixed speaking order
    (see the `phases` property for the exact table):

        0. paper_extraction: the "Paper Extractor" provides the paper content.
        1. reviewer_write_reviews: each reviewer writes a review.
        2. author_reviewer_discussion: the author responds once per reviewer.
        3. reviewer_ac_discussion: the AC and the reviewers discuss the paper.
        4. ac_write_metareviews: the AC writes the meta-review.
        5. ac_makes_decisions: the AC makes the final decision.

    `step` ends the simulation once phase 4 is complete; phase 5 is run by a
    separate script (see the log message emitted in `step`).
    """

    type_name = "paper_review"

    def __init__(self, player_names: List[str], paper_id: int, paper_decision: str, experiment_setting: dict, args,
                 parallel: bool = False,
                 **kwargs):
        """
        Args:
            player_names (List[str]): names of all participating players. Names starting with
                "Reviewer" are counted as reviewers when building the phase table.
            paper_id (int): the id of the paper, such as 917
            paper_decision (str): the decision of the paper, such as "Accept: notable-top-25%"
            experiment_setting (dict): experiment configuration; forwarded to the message pool.
                The optional key 'player_to_test' is read from it.
            args: run configuration object. `experiment_name` is read here; `model_name`,
                `conference` and `num_reviewers_per_paper` are read in
                `load_message_history_from_cache`.
            parallel (bool): forwarded to the parent initializer.
            **kwargs: forwarded to the parent initializer; the optional key "task" is stored.
        """
        # Deliberately skip `Conversation.__init__` and call the initializer of
        # the class that comes after `Conversation` in the MRO.
        super(Conversation, self).__init__(player_names=player_names, parallel=parallel, **kwargs)

        self.args = args
        self.paper_id = paper_id
        self.paper_decision = paper_decision
        self.parallel = parallel
        self.experiment_setting = experiment_setting
        self.player_to_test = experiment_setting.get('player_to_test', None)
        self.task = kwargs.get("task")
        self.experiment_name = args.experiment_name

        # The "state" of the environment is maintained by the message pool
        self.message_pool = PaperReviewMessagePool(experiment_setting)

        self.phase_index = 0
        self._phases = None

    @property
    def phase_index(self) -> int:
        """Index of the current phase (a key of `phases`)."""
        return self._phase_index

    @phase_index.setter
    def phase_index(self, value: int):
        # Single backing field so that `phase_index` and `_phase_index` can never diverge.
        self._phase_index = value

    @property
    def phases(self) -> dict:
        """
        Phase table: maps phase index -> {"name": str, "speaking_order": List[str]}.

        Built lazily on first access and cached in `self._phases`.
        """
        if self._phases is None:
            # Only the *number* of reviewers is taken from `player_names`; the
            # names used in the speaking orders are always "Reviewer 1..N".
            num_reviewers = sum(1 for name in self.player_names if name.startswith("Reviewer"))
            reviewer_names = [f"Reviewer {i}" for i in range(1, num_reviewers + 1)]

            self._phases = {
                # In phase 0, no LLM-based agents are called.
                0: {
                    "name": "paper_extraction",
                    "speaking_order": ["Paper Extractor"],
                },
                1: {
                    "name": "reviewer_write_reviews",
                    "speaking_order": reviewer_names,
                },
                # The author responds to each reviewer's review
                2: {
                    "name": "author_reviewer_discussion",
                    "speaking_order": ["Author" for _ in reviewer_names],
                },
                3: {
                    "name": "reviewer_ac_discussion",
                    "speaking_order": ["AC"] + reviewer_names,
                },
                4: {
                    "name": "ac_write_metareviews",
                    "speaking_order": ["AC"],
                },
                5: {
                    "name": "ac_makes_decisions",
                    "speaking_order": ["AC"],
                },
            }

        # Return the cached table directly instead of re-entering the property.
        return self._phases

    @phases.setter
    def phases(self, value):
        self._phases = value

    def reset(self):
        """Go back to phase 0 and delegate the rest of the reset to the parent class."""
        self._current_phase = "review"
        self.phase_index = 0
        return super().reset()

    def load_message_history_from_cache(self):
        """
        Pre-fill the message pool from the cached BASELINE run and jump to phase 4.

        Does nothing unless the environment is still at phase 0. Messages from
        the "Paper Extractor" are skipped, and loading stops at the second
        message from the AC, so only the first AC message is replayed.

        Raises:
            AssertionError: if the number of distinct reviewers in the cached
                messages differs from `args.num_reviewers_per_paper`.
        """
        if self.phase_index != 0:
            return

        print("Loading message history from BASELINE experiment")

        full_paper_discussion_path = get_rebuttal_dir(paper_id=self.paper_id,
                                                      experiment_name="BASELINE",
                                                      model_name=self.args.model_name,
                                                      conference=self.args.conference)

        history_path = osp.join(full_paper_discussion_path, f"{self.paper_id}.json")
        # Context manager so the file handle is closed even if parsing fails.
        with open(history_path, 'r', encoding='utf-8') as f:
            messages = json.load(f)['messages']

        num_messages_from_AC = 0

        for msg in messages:
            # We have already extracted contents from the paper.
            if msg['agent_name'] == "Paper Extractor":
                continue

            if msg['agent_name'] == "AC":
                # Encountering the 2nd message from the AC. Stop loading messages.
                if num_messages_from_AC == 1:
                    break
                num_messages_from_AC += 1

            self.message_pool.append_message(Message(**msg))

        # Counted over the whole cached history, not only the replayed part.
        num_unique_reviewers = len(
            {msg['agent_name'] for msg in messages if msg['agent_name'].startswith("Reviewer")})

        assert num_unique_reviewers == self.args.num_reviewers_per_paper, (
            f"Cached history has {num_unique_reviewers} reviewers, "
            f"expected {self.args.num_reviewers_per_paper}")

        self.phase_index = 4

    def step(self, player_name: str, action: str) -> TimeStep:
        """
        Step function that is called by the arena.

        Records the action as a message, then advances to the next speaker, or
        to the next phase when the current speaking order is exhausted.

        Args:
            player_name: the name of the player that takes the action
            action: the action that the agents wants to take

        Returns:
            A TimeStep with all messages as observation, zero rewards, and
            `terminal` set once the last speaker of phase 4 has acted (or when
            `is_terminal()` already reports True).
        """
        message = Message(
            agent_name=player_name, content=action, turn=self._current_turn
        )
        self.message_pool.append_message(message)

        speaking_order = self.phases[self.phase_index]["speaking_order"]

        logger.info(f"Phase {self.phase_index}: {self.phases[self.phase_index]['name']} "
                    f"| Player {self._next_player_index}: {speaking_order[self._next_player_index]}")

        # Evaluated before the indices below are advanced.
        terminal = self.is_terminal()

        # Reached the end of the speaking order. Move to the next phase.
        if self._next_player_index == len(speaking_order) - 1:
            self._next_player_index = 0

            if self.phase_index == 4:
                terminal = True
                logger.info(
                    "Finishing the simulation for Phase I - IV. Please run `python run_paper_decision_cli.py ` for "
                    "Phase V. (AC makes decisions).")
            else:
                print(f"Phase {self.phase_index}: end of the speaking order. Move to Phase ({self.phase_index + 1}).")
                self.phase_index += 1

            self._current_turn += 1
        else:
            self._next_player_index += 1

        timestep = TimeStep(
            observation=self.get_observation(),
            reward=self.get_zero_rewards(),
            terminal=terminal,
        )  # Return all the messages
        return timestep

    def get_next_player(self) -> str:
        """Get the next player in the current phase."""
        speaking_order = self.phases[self.phase_index]["speaking_order"]
        return speaking_order[self._next_player_index]

    def get_observation(self, player_name=None) -> List[Message]:
        """
        Get observation for the player.

        With no `player_name`, every message in the pool is returned; otherwise
        the pool decides which messages are visible to that player given the
        current phase and next-speaker index.
        """
        if player_name is None:
            return self.message_pool.get_all_messages()

        return self.message_pool.get_visible_messages_for_paper_review(
            player_name, phase_index=self.phase_index, next_player_idx=self._next_player_index,
            player_names=self.player_names
        )

    def get_messages_from_player(self, player_name: str) -> List[str]:
        """Return what the message pool reports as the messages sent by `player_name`."""
        return self.message_pool.get_messages_from_player(player_name)