Spaces:
Sleeping
Sleeping
File size: 5,517 Bytes
43c34cc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 |
from abc import abstractmethod
from dataclasses import dataclass
from typing import Dict, List
from ..config import Configurable, EnvironmentConfig
from ..message import Message
from ..utils import AttributedDict
@dataclass
class TimeStep(AttributedDict):
    """
    Result of advancing the simulation by one step.

    Bundles what the players observe, the reward each of them receives and
    whether the episode has ended.

    Attributes:
        observation (List[Message]): Messages observed at this timestep.
        reward (Dict[str, float]): Reward per player, keyed by player name.
        terminal (bool): True when the current state ends the episode.
    """

    # Field order defines the positional order of the generated __init__.
    observation: List[Message]
    reward: Dict[str, float]
    terminal: bool
class Environment(Configurable):
    """
    Base class describing the interface every environment exposes.

    Concrete environments derive from this class and provide the methods
    marked with ``@abstractmethod`` below.

    Inherits from:
        Configurable: Project class that handles configuration settings.

    Attributes:
        type_name (str): Identifier of the environment type. When a subclass
            leaves it as None it is filled in with the lower-cased class name
            (see ``__init_subclass__``).
        player_names (List[str]): Names of the players, set in ``__init__``.

    Note:
        NOTE(review): ``@abstractmethod`` is only enforced at instantiation
        time when the class uses an ABC metaclass; whether Configurable
        provides one is not visible here -- confirm.
    """

    type_name = None
    # phase_index and task are not read anywhere in this class; presumably
    # they are defaults consumed by subclasses -- TODO confirm.
    phase_index = 0
    task = None

    @abstractmethod
    def __init__(self, player_names: List[str], **kwargs):
        """
        Set up the shared environment state.

        Parameters:
            player_names (List[str]): Names of the players in the environment.
            **kwargs: Extra settings, forwarded unchanged to Configurable.
        """
        # Forwarding every argument registers it with Configurable.
        super().__init__(player_names=player_names, **kwargs)
        self.player_names = player_names

    def __init_subclass__(cls, **kwargs):
        """
        Give each new subclass a default ``type_name``.

        Runs automatically when a subclass is defined. If ``type_name``
        resolves to None on the new class, it is set to the lower-cased class
        name. The lookup follows normal attribute inheritance, so a class
        whose parent already carries a non-None ``type_name`` keeps the
        parent's value unless it defines its own.
        """
        if cls.type_name is None:
            cls.type_name = cls.__name__.lower()
        return super().__init_subclass__(**kwargs)

    @abstractmethod
    def reset(self):
        """
        Put the environment back into its initial state.

        Note:
            Subclasses must implement this method.
        """

    def to_config(self) -> EnvironmentConfig:
        """
        Build an EnvironmentConfig describing this environment.

        Stores ``type_name`` under the ``"env_type"`` key of
        ``self._config_dict`` (modifying it in place) and passes the dictionary
        entries as keyword arguments to EnvironmentConfig.

        Returns:
            EnvironmentConfig: The configuration object for this environment.
        """
        # _config_dict is defined outside this class (presumably populated by
        # Configurable.__init__ -- TODO confirm).
        self._config_dict["env_type"] = self.type_name
        return EnvironmentConfig(**self._config_dict)

    @property
    def num_players(self) -> int:
        """Number of entries in ``player_names``."""
        return len(self.player_names)

    @abstractmethod
    def get_next_player(self) -> str:
        """
        Tell which player acts next.

        Note:
            Subclasses must implement this method.

        Returns:
            str: The name of the next player.
        """

    @abstractmethod
    def get_observation(self, player_name=None) -> List[Message]:
        """
        Produce the observation for a player.

        Note:
            Subclasses must implement this method.

        Parameters:
            player_name (str, optional): Name of the player whose observation
                is requested. Defaults to None.

        Returns:
            List[Message]: The observation as a list of messages.
        """

    @abstractmethod
    def print(self):
        """Print the current state of the environment."""

    @abstractmethod
    def step(self, player_name: str, action: str) -> TimeStep:
        """
        Advance the environment by applying one player's action.

        Note:
            Subclasses must implement this method.

        Parameters:
            player_name (str): Name of the acting player.
            action (str): The action that player wants to take.

        Returns:
            TimeStep: The resulting observation, reward and terminal flag.
        """

    @abstractmethod
    def check_action(self, action: str, player_name: str) -> bool:
        """
        Decide whether an action is valid for a player.

        Note:
            Subclasses must implement this method. The base body accepts
            every action, which is what a subclass gets when it delegates
            here via ``super()``.

        Parameters:
            action (str): The action to validate.
            player_name (str): Name of the player proposing the action.

        Returns:
            bool: True if the action is valid, False otherwise.
        """
        return True

    @abstractmethod
    def is_terminal(self) -> bool:
        """
        Report whether the episode has ended.

        Note:
            Subclasses must implement this method.

        Returns:
            bool: True if the environment is in a terminal state, False
            otherwise.
        """

    def get_zero_rewards(self) -> Dict[str, float]:
        """
        Build a reward mapping that gives every player 0.0.

        Returns:
            Dict[str, float]: Player names mapped to a reward of zero.
        """
        return dict.fromkeys(self.player_names, 0.0)

    def get_one_rewards(self) -> Dict[str, float]:
        """
        Build a reward mapping that gives every player 1.0.

        Returns:
            Dict[str, float]: Player names mapped to a reward of one.
        """
        return dict.fromkeys(self.player_names, 1.0)
|