USTC975's picture
create the app
43c34cc
import logging
import traceback
from typing import List, Optional

from agentreview.environments import Conversation

from .base import TimeStep
from ..message import Message, MessagePool
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
class PaperDecision(Conversation):
    """Environment in which area chairs (ACs) make paper decisions from the meta-reviews.

    This environment has a single phase (Phase 5, "ac_make_decisions") whose only
    speaker is the "AC" player. The AC's free-text response is parsed by
    :meth:`parse_ac_decisions` into a ``{paper_id: rating}`` dict, which is
    stored in :attr:`ac_decisions`.
    """

    type_name = "paper_decision"

    def __init__(self,
                 player_names: List[str],
                 experiment_setting: dict,
                 paper_ids: Optional[List[int]] = None,
                 metareviews: Optional[List[str]] = None,
                 parallel: bool = False,
                 **kwargs):
        """
        Args:
            player_names: names of the players in this environment.
            experiment_setting: experiment configuration; stored as-is on ``self``.
            paper_ids: ids of the papers to decide on, such as ``[917]``.
            metareviews: meta-reviews of the papers; presumably aligned
                index-by-index with ``paper_ids`` (TODO confirm against caller).
            parallel: forwarded to the base environment and stored on ``self``.
            **kwargs: forwarded to the base environment. ``ac_scoring_method``
                (``"ranking"`` or ``"recommendation"``) is also read from here
                and selects how :meth:`parse_ac_decisions` reads the AC response.
        """
        # Deliberately bypass ``Conversation.__init__``: ``super(Conversation, self)``
        # calls the initializer of the class that follows Conversation in the MRO.
        # The message pool and turn counters are set up explicitly below.
        super(Conversation, self).__init__(player_names=player_names, parallel=parallel, **kwargs)

        self.paper_ids = paper_ids
        self.metareviews = metareviews
        self.parallel = parallel
        self.experiment_setting = experiment_setting
        self.ac_scoring_method = kwargs.get("ac_scoring_method")

        # The "state" of the environment is maintained by the message pool.
        self.message_pool = MessagePool()

        # Filled by check_action() once an AC response parses successfully.
        self.ac_decisions = None

        self._current_turn = 0
        self._next_player_index = 0
        # "ACs make decision based on meta review" is the last phase (Phase 5).
        self.phase_index = 5
        self._phases = None

    @property
    def phases(self):
        """Phase table (built lazily): phase index -> name and speaking order."""
        if self._phases is None:
            self._phases = {
                5: {
                    "name": "ac_make_decisions",
                    "speaking_order": ["AC"],
                },
            }
        return self._phases

    def step(self, player_name: str, action: str) -> TimeStep:
        """Record ``action`` from ``player_name`` and advance the speaking order.

        Called by the arena.

        Args:
            player_name: the name of the player that takes the action.
            action: the action that the agent wants to take.

        Returns:
            A ``TimeStep`` built from ``get_observation()``, ``get_zero_rewards()``
            and ``is_terminal()``.
        """
        message = Message(agent_name=player_name, content=action, turn=self._current_turn)
        self.message_pool.append_message(message)

        phase = self.phases[self.phase_index]
        speaking_order = phase["speaking_order"]

        # Use the module logger (not the root ``logging`` module) with lazy args.
        logger.info("Phase %s: %s | Player %s: %s",
                    self.phase_index, phase["name"],
                    self._next_player_index, speaking_order[self._next_player_index])

        if self._next_player_index == len(speaking_order) - 1:
            # Reached the end of the speaking order: move to the next phase.
            self._next_player_index = 0
            logger.info("Phase %s: end of the speaking order. Move to Phase %s.",
                        self.phase_index, self.phase_index + 1)
            self.phase_index += 1
            self._current_turn += 1
        else:
            self._next_player_index += 1

        # get_observation() is called without a player name; intended to return
        # all the messages.
        return TimeStep(
            observation=self.get_observation(),
            reward=self.get_zero_rewards(),
            terminal=self.is_terminal(),
        )

    def check_action(self, action: str, player_name: str) -> bool:
        """Check whether ``action`` from ``player_name`` is valid.

        For players whose name starts with ``"AC"``, the action must be parseable
        by :meth:`parse_ac_decisions`. On success, the parsed decisions are
        stored in :attr:`ac_decisions`. Actions from other players are always
        accepted.

        Returns:
            ``False`` if an AC action cannot be parsed into a dict, else ``True``.
        """
        if player_name.startswith("AC"):
            try:
                self.ac_decisions = self.parse_ac_decisions(action)
            except Exception:
                # Malformed model output: report it and mark the action invalid.
                # Catching Exception (not a bare except) lets KeyboardInterrupt
                # and SystemExit propagate.
                logger.exception("Failed to parse the AC decisions from the action.")
                return False
            if not isinstance(self.ac_decisions, dict):
                return False
        return True

    @property
    def ac_decisions(self):
        """Decisions parsed from the last successfully checked AC action (or None)."""
        return self._ac_decisions

    @ac_decisions.setter
    def ac_decisions(self, value):
        self._ac_decisions = value

    def parse_ac_decisions(self, action: str):
        """Parse the decisions made by the AC from the free-text ``action``.

        The response is scanned line by line; each line is stripped and matched
        case-insensitively.

        * ``Paper ID: <id>`` sets the current paper. Anything from the first
          ``(`` after the id onward is ignored.
        * If ``self.ac_scoring_method == "ranking"``: a
          ``Willingness to accept: <int>`` line sets an int rating.
        * If ``self.ac_scoring_method == "recommendation"``: a line starting
          with ``decision`` sets the rating to the text between its first and
          second colon, stripped. Such lines without a colon are skipped.

        Once both a paper id and a rating have been seen, the pair is recorded
        and both are reset.

        Returns:
            dict mapping paper id (int) to its rating (int or str).

        Raises:
            ValueError: if a paper id or ranking value is not an integer, or if
                the same paper id appears again after it was already rated.
        """
        paper2rating = {}
        paper_id, rank = None, None

        for raw_line in action.split("\n"):
            line = raw_line.strip()
            lowered = line.lower()

            if lowered.startswith("paper id:"):
                paper_id = int(line.split(":")[1].split("(")[0].strip())
            elif self.ac_scoring_method == "ranking" and lowered.startswith("willingness to accept:"):
                rank = int(line.split(":")[1].strip())
            elif (self.ac_scoring_method == "recommendation"
                  and lowered.startswith("decision")
                  and ":" in line):
                # Guarding on ":" avoids an IndexError, which would reject the whole
                # response, for prose lines such as "Decisions are listed below".
                rank = line.split(":")[1].strip()

            if paper_id in paper2rating:
                raise ValueError(f"Paper {paper_id} is assigned a rank twice.")

            if paper_id is not None and rank is not None:
                paper2rating[paper_id] = rank
                paper_id, rank = None, None

        return paper2rating