Spaces:
Running
Running
| import csv | |
| import json | |
| import logging | |
| import uuid | |
| from typing import Dict, List, Union | |
| from .agent import Player | |
| from .backends import Human | |
| from .config import ArenaConfig | |
| from .environments import Environment, TimeStep, load_environment | |
class TooManyInvalidActions(Exception):
    """Raised when a player exhausts its retry budget of invalid actions in one step."""
class Arena:
    """Utility class that manages the game environment and players."""

    # Column order for CSV export; also the keys of each JSON record.
    _HISTORY_FIELDS = (
        "agent_name",
        "content",
        "turn",
        "timestamp",
        "visible_to",
        "msg_type",
    )

    def __init__(
        self,
        players: List[Player],
        environment: Environment,
        args=None,
        global_prompt: str = None,
    ):
        """Store the players and environment, then reset the environment.

        Args:
            players: The players taking part in the game.
            environment: The environment the players act in; it is reset here.
            args: Opaque run arguments, stored as-is on ``self.args``.
                Defaults to ``None`` so that ``from_config`` can build an
                arena without them.
            global_prompt: Optional prompt stored on the arena.
        """
        self.players = players
        self.environment = environment
        self.global_prompt = global_prompt
        self.current_timestep = environment.reset()
        self.uuid = uuid.uuid4()  # unique id for this game
        self.invalid_actions_retry = 5  # attempts a player gets per step
        self.args = args

    def num_players(self):
        """Return the number of players reported by the environment.

        NOTE(review): this is a plain method, while ``name_to_player`` is
        accessed attribute-style; confirm which form external callers expect.
        """
        return self.environment.num_players

    @property
    def name_to_player(self) -> Dict[str, Player]:
        """Map each player's name to its ``Player`` object.

        Exposed as a property because ``step`` and ``next_is_human`` subscript
        it directly (``self.name_to_player[name]``).
        """
        return {player.name: player for player in self.players}

    def reset(self) -> TimeStep:
        """Reset the environment and every player, and assign a fresh game id.

        Returns:
            The timestep produced by resetting the environment.
        """
        self.current_timestep = self.environment.reset()
        for player in self.players:
            player.reset()
        self.uuid = uuid.uuid4()
        return self.current_timestep

    def step(self) -> TimeStep:
        """Take a step in the game: one player acts and the environment updates.

        The next player is asked for an action up to
        ``self.invalid_actions_retry`` times; the first action accepted by
        ``environment.check_action`` is applied.

        Returns:
            The timestep returned by ``environment.step``.

        Raises:
            TooManyInvalidActions: If every attempt produced an invalid action.
        """
        player_name = self.environment.get_next_player()
        player = self.name_to_player[player_name]
        observation = self.environment.get_observation(player_name)

        for _ in range(self.invalid_actions_retry):
            action = player(observation)
            if self.environment.check_action(action, player_name):
                return self.environment.step(player_name, action)
            logging.warning(f"{player_name} made an invalid action {action}")

        # Retry budget exhausted without a valid action: terminate the game.
        warning_msg = f"{player_name} has made invalid actions for {self.invalid_actions_retry} times. Terminating the game."
        logging.warning(warning_msg)
        raise TooManyInvalidActions(warning_msg)

    def next_is_human(self):
        """Check if the next player is human (its backend is a ``Human``)."""
        player_name = self.environment.get_next_player()
        player = self.name_to_player[player_name]
        return isinstance(player.backend, Human)

    def run(self, num_steps: int = 1):
        """Run the game for at most ``num_steps`` steps.

        Stops early as soon as a step returns a terminal timestep.
        """
        for _ in range(num_steps):
            timestep = self.step()
            if timestep.terminal:
                break

    @classmethod
    def from_config(cls, config: Union[str, ArenaConfig]):
        """Create an arena from a config.

        Args:
            config: An ``ArenaConfig``, or a string that is passed to
                ``ArenaConfig.load`` to obtain one.

        Returns:
            A new arena built from the configured players and environment
            (``args`` is left as ``None``).

        Note:
            The config is updated in place: the global prompt is written into
            each player config and the player names into the environment
            config.
        """
        if isinstance(config, str):
            config = ArenaConfig.load(config)

        global_prompt = config.get("global_prompt", None)

        players = []
        for player_config in config.players:
            # Propagate the shared prompt into each player's config.
            if global_prompt is not None:
                player_config["global_prompt"] = global_prompt
            players.append(Player.from_config(player_config))

        player_names = [player.name for player in players]
        assert len(player_names) == len(
            set(player_names)
        ), "Player names must be unique"

        # The environment config needs to know who is playing.
        config.environment["player_names"] = player_names
        env = load_environment(config.environment)

        return cls(players, env, global_prompt=global_prompt)

    def to_config(self) -> ArenaConfig:
        """Convert the arena to a config."""
        return ArenaConfig(
            players=[player.to_config() for player in self.players],
            environment=self.environment.to_config(),
            global_prompt=self.global_prompt,
        )

    def launch_cli(self, max_steps: int = None, interactive: bool = True):
        """Launch the command line interface."""
        # Imported here rather than at module level, as in the original code.
        from agentreview.ui.cli import ArenaCLI

        cli = ArenaCLI(self)
        cli.launch(max_steps=max_steps, interactive=interactive)

    def save_config(self, path: str):
        """Save the config to a file."""
        self.to_config().save(path)

    @staticmethod
    def _message_record(message) -> dict:
        """Flatten one message into a dict keyed by the history field names."""
        return {
            "agent_name": message.agent_name,
            "content": message.content,
            "turn": message.turn,
            "timestamp": str(message.timestamp),
            "visible_to": message.visible_to,
            "msg_type": message.msg_type,
        }

    def save_history(self, path: str):
        """Save the history of the game to a file.

        Supports csv and json formats, selected by the suffix of ``path``.

        Raises:
            ValueError: If ``path`` ends with neither ``.csv`` nor ``.json``.
        """
        messages = self.environment.get_observation()
        records = [self._message_record(message) for message in messages]

        if path.endswith(".csv"):
            fields = self._HISTORY_FIELDS
            # newline="" lets the csv module control line endings itself;
            # without it extra blank rows appear on Windows.
            with open(path, "w", newline="", encoding="utf-8") as f:
                writer = csv.writer(f)
                writer.writerow(fields)
                writer.writerows(
                    [record[field] for field in fields] for record in records
                )
        elif path.endswith(".json"):
            with open(path, "w", encoding="utf-8") as f:
                json.dump(records, f, indent=2)
        else:
            raise ValueError("Invalid file format")