diff --git a/anti_poaching/anti_poaching_v0.py b/anti_poaching/anti_poaching_v0.py index db0790ef79e3e359fbdbc5a7e3fd55aaa9c0ae22..169b92d88857ae28d6cf39383ef32a49a55aca22 100644 --- a/anti_poaching/anti_poaching_v0.py +++ b/anti_poaching/anti_poaching_v0.py @@ -1,7 +1,5 @@ -""" -Module to expose the Conservation Game -Environment from within the env folder. -""" +"""Module to expose the Anti-Poaching PettingZoo env. This also registers +compatible versions for use with RLlib.""" from gymnasium.spaces import Tuple, Dict, Box from pettingzoo.utils.env import ParallelEnv @@ -12,86 +10,49 @@ from ray.tune.registry import register_env from ray.rllib.models import ModelCatalog from ray.rllib.examples.models.action_mask_model import TorchActionMaskModel from ray.rllib.env.wrappers.group_agents_wrapper import GroupAgentsWrapper +from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv from .env import anti_poaching from .env.utils.game_utils import GridStateConstProb, Trap -from .env.utils.wrappers import ( - NonCategoricalFlatten, - StackerWrapper, - QMIXCompatibilityLayer, -) -from .env.utils.flatten_utils import box_flatten, box_flatten_obs, parse_obs - +from .env.utils.wrappers import QMIXCompatibilityLayer -def rllib_cons_game( - env: ParallelEnv = None, config: dict = None -) -> MultiAgentEnv: - """ - Wraps the Conservation Game to modify the - observation spaces and stack all spaces into - their respective joint spaces. - """ - if config is None: - config = {} - if env is None: - env = anti_poaching.parallel_env(**config) - env = NonCategoricalFlatten(env) # Dict(....) -> Box - env = StackerWrapper(env) # Wrapped(ParallelEnv) -> MultiAgentEnv - return env +# Register environment to use for RLlib +register_env( + anti_poaching.metadata["name"], + lambda config: ParallelPettingZooEnv(anti_poaching.parallel_env(**config)), +) def grouped_rllib_cons_game( env: ParallelEnv = None, config: dict = None, ) -> GroupAgentsWrapper: - """ - This is for QMix and other algorithms that may need - grouped agents. - """ - env = rllib_cons_game(env, config) - env = QMIXCompatibilityLayer(env) - + """This is for QMix and other algorithms that use agent groups""" + env = anti_poaching.parallel_env(**config) if not env else env + qenv = QMIXCompatibilityLayer(env) grouped_policies = ["rangers", "poachers"] # Declare the observation policy of the group to train, # and add attr to bypass the QMixPolicy validation step obs_space = Tuple( - Tuple( - [ - env.observation_space[agent] - for agent in getattr(env.unwrapped, group) - ] - ) + Tuple([qenv.observation_space[agent] for agent in getattr(env, group)]) for group in grouped_policies ) # Declare the grouped action space as well.
act_space = Tuple( - Tuple( - [ - env.action_space[agent] - for agent in getattr(env.unwrapped, group) - ] - ) + Tuple([qenv.action_space(agent) for agent in getattr(env, group)]) for group in grouped_policies ) # return the grouped env using .with_agent_groups - return env.with_agent_groups( - groups={ - group: getattr(env.unwrapped, group) for group in grouped_policies - }, + return qenv.with_agent_groups( + groups={group: getattr(env, group) for group in grouped_policies}, obs_space=obs_space, act_space=act_space, ) -# Register environment to use for RLlib -register_env( - anti_poaching.metadata["name"], - lambda config: rllib_cons_game(config=config), -) - # Register grouped environment for QMix # Note that QMix assumes homogenous agents (ray2.8.0) register_env( diff --git a/anti_poaching/env/anti_poaching.py b/anti_poaching/env/anti_poaching.py index ecc2babfecc333704be6a80e54f9deefbde399f3..75ef3bf0740f0c7d11562f02e605c375692824e5 100644 --- a/anti_poaching/env/anti_poaching.py +++ b/anti_poaching/env/anti_poaching.py @@ -1,21 +1,16 @@ -""" -Module to implement the Anti-Poaching game environment -""" +"""Module to implement the Anti-Poaching game environment""" -import logging import functools -from collections import OrderedDict +from copy import deepcopy import numpy as np -from gymnasium.spaces import Box, Dict, Discrete, MultiDiscrete, MultiBinary +import gymnasium as gym from pettingzoo.utils.env import ParallelEnv from .utils.game_utils import BaseGridState, GridStateConstProb, Trap - -# Agents are strings -AgentID = str +from .utils.typing import * # Game metadata as global metadata = { - "name": "anti_poaching_v0.2.2", + "name": "anti_poaching_v0.3", "render_modes": BaseGridState.RENDER_MODES, "is_parallelizable": True, } @@ -59,7 +54,6 @@ def parallel_env( poachers, ntraps_per_poacher, max_time, - render_mode, seed, ) @@ -76,12 +70,13 @@ class raw_env(ParallelEnv): - Observation space: - Each observation space is a composite product of smaller spaces. - This reflects the formal definition of the model. + The observations are vectors with different lower and upper bounds. + The older implementation used dictionaries, which favored readability + over usability. Assumption: - We also assume that a poacher can detect the number of their own traps in - the current cell with probability 1. + We also assume that a poacher can detect the number of their own traps + in the current cell with probability 1. - Rewards: @@ -92,6 +87,8 @@ class raw_env(ParallelEnv): Version History --------------- + v0.3 - Reimplementation to remove a lot of complex code. Functionally + equivalent to v0.2.2, modulo any bugfixes. v0.2.2 - Reward poachers when capturing a prey only, minor bugfixes, observation space now contains remaining time.
v0.2.1 - OrderedDicts for the observation spaces (see CHANGELOG) @@ -108,10 +105,10 @@ class raw_env(ParallelEnv): poachers: str, ntraps_per_poacher: int, max_time: int, - render_mode: str = "ansi", seed: int = None, ): self.rng = np.random.default_rng(seed=seed) + self.seed = seed # time properties self.max_time = max_time @@ -120,12 +117,11 @@ class raw_env(ParallelEnv): # agent parameters self.ntraps_per_poacher = ntraps_per_poacher - # agents-related properties + # agent-related properties self.poachers = poachers self.rangers = rangers self.agents = self.rangers + self.poachers self.possible_agents = self.agents[:] - self.poacher_traps = { poacher: [ Trap(name=f"trap_{i}_{poacher}") @@ -136,118 +132,101 @@ class raw_env(ParallelEnv): # Arena-related properties self.grid = grid - grid_size, nrangers, npoachers = grid.N, len(rangers), len(poachers) + nrangers, npoachers = len(rangers), len(poachers) + + # Convenience attributes + self._ranger_obs_size = 8 + nrangers + self._poacher_obs_size = 7 + # Spaces parameters. self.action_spaces = { - **{ranger: Discrete(5, seed=seed) for ranger in self.rangers}, - **{poacher: Discrete(6, seed=seed) for poacher in self.poachers}, + **{ + ranger: gym.spaces.Discrete(5, seed=seed) + for ranger in self.rangers + }, + **{ + poacher: gym.spaces.Discrete(6, seed=seed) + for poacher in self.poachers + }, } + self.observation_spaces = { **{ - poacher: Dict( + ranger: gym.spaces.Dict( { - "observations": Dict( - OrderedDict( - { - "remaining_time": Discrete(max_time + 1), - "state": Box( - low=np.array([-1, -1, 0, 0]), - high=np.array( - [ - grid_size - 1, - grid_size - 1, - ntraps_per_poacher, - np.iinfo(np.int32).max, - ] - ), - dtype=np.int32, - ), - "ground_traps": MultiDiscrete( - [ - ntraps_per_poacher + 1, - ntraps_per_poacher + 1, - ] - ), # traps found in current cell - "rangers": Discrete( - nrangers + 1, start=0 - ), # num. of rangers detected in current cell - "poachers": Discrete( - npoachers + 1, start=0 - ), # num. 
of poachers detected in current cell - } - ) + "observations": gym.spaces.Box( + np.zeros(self._ranger_obs_size), + np.array( # high + [ + max_time, # max time + *[self.grid.N] * 2, # location + *[1] * nrangers, # partner rangers + npoachers, # #captured-poachers + *[ + ntraps_per_poacher * npoachers, + np.iinfo(INTEGER).max, + ], # poacher-captured traps + *[ + ntraps_per_poacher * npoachers, + np.iinfo(INTEGER).max, + ], # grid-captured traps + ] + ), + seed=seed, + dtype=INTEGER, ), - "action_mask": MultiBinary(6), - }, - seed=seed, + "action_mask": gym.spaces.MultiBinary(5), + } ) - for poacher in self.poachers + for ranger in self.rangers }, **{ - ranger: Dict( + poacher: gym.spaces.Dict( { - "observations": Dict( - OrderedDict( - { - "remaining_time": Discrete(max_time + 1), - "state": MultiDiscrete( - [grid_size, grid_size] - ), # ranger state is location - "partners": MultiBinary( - nrangers, - ), # for the list of rangers sharing the cell - "poacher_traps": MultiDiscrete( - [ - npoachers * ntraps_per_poacher + 1, - max_time, # Max number of prey - ] - ), # For the traps and prey recovered from poacher capture - "poachers_captured": Discrete( - npoachers + 1 - ), - "ground_traps": MultiDiscrete( - [ - npoachers * ntraps_per_poacher + 1, - ] - * 2 - ), # For traps recovered from the grid - } - ) + "observations": gym.spaces.Box( + np.array([0, *[-1] * 2, *[0] * 2, 0, 0]), + np.array( + [ + max_time, # max time + *[self.grid.N] * 2, # location + ntraps_per_poacher, # #traps + np.iinfo(INTEGER).max, # #prey + nrangers, # #rangers detected + npoachers, # #poachers detected + ] + ), + seed=seed, + dtype=INTEGER, ), - "action_mask": MultiBinary(5), - }, - seed=seed, + "action_mask": gym.spaces.MultiBinary(6), + } ) - for ranger in self.rangers + for poacher in self.poachers }, } - self.total_rewards = dict.fromkeys( - self.agents, 0 - ) # stores every agent's rewards in the current episode. - self.killed_agents = [] def reset( self, seed: int = None, return_info: bool = False, options: dict = None ) -> tuple: - """Resets the environment for the next episode. If new configurations are to be set, - then we use the options dictionary as follows. + """Resets the environment for the next episode. If new configurations + are to be set, then we use the options dictionary as follows. >>> env = AntiPoachingGame( ... ) - >>> env.reset(seed=None, options={"rangers": _, "poachers": _, "ntraps_per_poacher": _}) + >>> env.reset(seed=123) - Here, a None `seed` means that the internal RNG of our GridState object is randomly - reset, and thus will generate a new position. The `options` dictionary specifies - the other parameters that can be changed.""" + Here, a None `seed` means that the internal RNG of our GridState object + is randomly reset, and thus will generate a new starting position.""" self.rng = np.random.default_rng(seed=seed) - # Reset the game parameters + # Reset the game parameters: Only override if supplied. self.curr_time = 0 + seed = seed if seed else self.seed + options = options if options is not None else {} # pass options to grid reset: - options = options if isinstance(options, dict) else {} - seed = seed if options.get("renew_config", True) else None self.grid.reset(seed=seed, **options) + # Regenerate agents. 
self.agents = self.possible_agents[:] self.poacher_traps = { poacher: [ @@ -257,100 +236,37 @@ class raw_env(ParallelEnv): for poacher in self.poachers } - # Re-seed/reinitialise all spaces - for agent in self.agents: - self.observation_spaces[agent].seed(seed) - self.action_spaces[agent].seed(seed) + obs = dict.fromkeys(self.agents) # Returning this - observations = { - **{ - poacher: { - "observations": { - "remaining_time": self.max_time, - "state": self.grid.state[poacher], - "ground_traps": np.zeros(2, dtype=np.int32), - "rangers": 0, - "poachers": 0, - }, - "action_mask": self.grid.permitted_movements(poacher), - } - for poacher in self.poachers - }, - **{ - ranger: { - "observations": { - "remaining_time": self.max_time, - "state": self.grid.state[ranger], - "partners": np.zeros( - shape=(len(self.rangers),), dtype=np.int32 - ), - "poacher_traps": np.zeros(2, dtype=np.int32), - "poachers_captured": np.zeros(1, dtype=np.int32), - "ground_traps": np.zeros(2, dtype=np.int32), - }, - "action_mask": self.grid.permitted_movements(ranger), - } - for ranger in self.rangers - }, - } + # Creating a default object to copy for both agents + _def_ranger_obs = self.observation_space("ranger_0").sample() + _def_poacher_obs = self.observation_space("poacher_0").sample() - # returning infos dictionary with reset - return observations, dict.fromkeys(self.agents, {self.curr_time}) - - def observe(self, agent: AgentID, record: dict) -> dict: - """Each agent receives observations from their current cell. These - are calculated deterministically based on their record, which is - a record of all relevant information during the transition to - their new state.""" - assert ( - "poacher" in agent or "ranger" in agent - ), f"Unknown agent passed as argument! {agent}" - - # get the complete observation first - action_mask = self.grid.permitted_movements(agent) - - if "poacher" in agent: - action_mask[5] = len(self.poacher_traps[agent]) > 0 and all( - action_mask[:2] >= 0 - ) # authorise place_trap - obs = { - "remaining_time": self.max_time - self.curr_time, - "state": self.grid.state[agent], - "ground_traps": np.array( - [record["trap_empty"], record["trap_full"]] - ), - "rangers": record["num_rangers"], - "poachers": record["num_poachers"], - } - elif "ranger" in agent: - # TODO: Update the rangers observation space - partners = np.zeros(shape=(len(self.rangers),), dtype=np.int32) - partners[record["partners"]] = 1 - - obs = { - "remaining_time": self.max_time - self.curr_time, - "state": self.grid.state[agent], - "partners": partners, - "poacher_traps": np.sum( - ( - np.array([nempty, nfull]) - for poacher, (nempty, nfull) in record[ - "poachers_found" - ].items() - ), - axis=1, - ), - "poachers_captured": len(record["poachers_found"].keys()), - "ground_traps": np.array( - [record["trap_empty"], record["trap_full"]] - ), - } + # and zero-ing them out since most init obs are zero. 
+ _def_ranger_obs["observations"] = np.zeros_like( + _def_ranger_obs["observations"], dtype=INTEGER + ) + _def_poacher_obs["observations"] = np.zeros_like( + _def_poacher_obs["observations"], dtype=INTEGER + ) + for agent in self.agents: + # Re-seed/reinitialise all spaces + self.observation_space(agent).seed(seed) + self.action_space(agent).seed(seed) + + # create appropriate observations for t=0 + _copy_obj, _size = ( + (_def_ranger_obs, 3) + if "ranger" in agent + else (_def_poacher_obs, 5) + ) + obs[agent] = deepcopy(_copy_obj) + obs[agent]["observations"][0] = self.max_time + obs[agent]["observations"][1:_size] = self.grid.state[agent] + obs[agent]["action_mask"] = self.grid.permitted_movements(agent) - # return parsed observations - return { - "observations": obs, - "action_mask": action_mask, - } + # returning infos dictionary with reset + return obs, dict.fromkeys(self.agents, {self.curr_time}) def step(self, actions: dict) -> tuple: """Receives a joint action, and sends @@ -363,77 +279,65 @@ class raw_env(ParallelEnv): truncations = dict.fromkeys(self.agents, False) infos = dict.fromkeys(self.agents, {}) - # records tracks transition information that - # is used to create the observations of s^{t+1}. - records = dict.fromkeys(self.agents) + # ... and the obs dict, with dummy action masks for now. + obs = dict.fromkeys(self.agents) for agent in self.agents: - if "ranger" in agent: - records[agent] = { - "trap_empty": 0, - "trap_full": 0, - "poachers_found": { - poacher: [0, 0] - for poacher in self.poachers # nempty, nfull - }, - "partners": [], - } - elif "poacher" in agent: - records[agent] = dict.fromkeys( - ["trap_empty", "trap_full", "num_rangers", "num_poachers"], - 0, - ) + _size = ( + self._ranger_obs_size + if "ranger" in agent + else self._poacher_obs_size + ) + obs[agent] = { + "observations": np.zeros(_size, dtype=INTEGER), + "action_mask": None, + } - # ... and now we run through the transitions ! + # Now we run through the transitions ! The obs + # dictionary is populated by each helper function. # Step 1: Rangers move first - self._rangers_move(actions, records) + self._rangers_move(actions, obs) # Step 2: Poachers move and remove their traps - self._poachers_move_and_get_traps(actions, rewards, records) + self._poachers_move_and_get_traps(actions, rewards, obs) # Step 3: Rangers remove traps and remaining traps capture animals - self._rangers_remove_traps(rewards, records) + self._rangers_remove_traps(rewards, obs) self._traps_catch_animals() # Step 4: Rangers remove poachers and remaining poachers place traps - self._rangers_remove_poachers(rewards, terminations, records) - self._poachers_place_traps(actions, terminations) + self._rangers_remove_poachers(rewards, obs) + self._poachers_place_traps(actions) - # update terminations: dict, rewards: dict, truncations and infos for the next step. + # update the obs for the agents with time, new state and action masks time_status = self.curr_time >= self.max_time for agent in self.agents: + obs[agent]["observations"][0] = self.max_time - self.curr_time + if "ranger" in agent: + obs[agent]["observations"][1:3] = self.grid.state[agent] + obs[agent]["action_mask"] = self.grid.permitted_movements( + agent + ) + elif "poacher" in agent: + _action_mask = self.grid.permitted_movements(agent) + _action_mask[5] = int(len(self.poacher_traps[agent]) > 0) + obs[agent]["observations"][1:5] = self.grid.state[agent] + obs[agent]["action_mask"] = _action_mask + + # update terminations for next step. 
Note that all agents are + # technically alive until max_time: captured poachers are just + # in a captured state. terminations[agent] |= time_status - self.total_rewards[agent] += rewards[agent] - - # update the list of killed agents this round - self.killed_agents = [ - agent - for agent in self.agents - if terminations[agent] or truncations[agent] - ] - - # calculate info for next step. - observations = { - agent: self.observe(agent, records[agent]) for agent in self.agents - } - # update the arena - self._cleanup() - - return ( - observations, - rewards, - terminations, - truncations, - infos, - ) + # Agents are terminated on the last step. + self.agents = [] if time_status else self.agents + return obs, rewards, terminations, truncations, infos def render(self): """Rendering the grid using the GridState object""" self.grid.render() def state(self) -> dict: - """Grid State is stored in the GridState class as - the `grid` attribute.""" + """State is stored as the `grid: GridState` attribute.""" return self.grid.state def _assign_reward( @@ -442,49 +346,36 @@ class raw_env(ParallelEnv): """Assigns the reward to poacher, and splits the reward among the cooperative rangers. A positive reward adds to poacher and removes proportionally from all rangers.""" - assert rewards is not None, "No reward dict to update!" - - # If assigning a 'reward' to an active player. - # This arises because rangers can recover a trap of a - # captured agent. Since he is captured, he is not - # in rewards, but will always be in total_rewards. - if poacher in rewards: - rewards[poacher] += reward - else: - logging.warn( - f"AntiPoachingGame: Assigning inactive {poacher} a reward of {reward} !!!" - ) - self.total_rewards[poacher] += reward - + rewards[poacher] += reward for ranger in self.rangers: rewards[ranger] -= reward / len(self.rangers) - def _rangers_move(self, actions: dict, records: dict) -> None: - """Helper function to move the rangers and update their records + def _rangers_move(self, actions: dict, obs: dict) -> None: + """Helper function to move the rangers and update their obs with detected partners in the same cell.""" for ranger in [r for r in self.rangers if 1 <= actions[r] <= 4]: self.grid.update_position(ranger, actions[ranger]) for nbor_ranger in [ r - for r in self.rangers - if r != ranger - and all(self.grid.state[r] == self.grid.state[ranger]) + for r in self.grid.get_neighbours(ranger) + if r != ranger and "ranger" in r ]: - # add the ranger number to the records - records[ranger]["partners"].append( - int(nbor_ranger.split("_")[-1]) - ) + # add the ranger number to the obs + obs[ranger]["observations"][ + 2 + int(nbor_ranger.split("_")[-1]) + ] = 1 def _poachers_move_and_get_traps( - self, actions: dict, rewards: dict, records: dict + self, actions: dict, rewards: dict, obs: dict ) -> None: """Helper function: Moves the poachers according to actions and removes their traps (if found) on the next step.""" for poacher in [ p for p in self.poachers - if p in self.agents and 0 <= actions[p] <= 4 + if 0 <= actions[p] <= 4 and self.grid.state[p][0] >= 0 ]: + # Skipping over captured poachers ... self.grid.update_position(poacher, actions[poacher]) for _trap in [ _t @@ -502,20 +393,18 @@ class raw_env(ParallelEnv): poacher, self.grid.remove_trap(_trap), rewards ) - # update trap records, and + # update trap obs, and # reset trap value when recovered. 
- key = "trap_empty" if _trap.value == 0 else "trap_full" - records[poacher][key] += 1 + key = 3 if _trap.value == 0 else 4 + obs[poacher]["observations"][key] += 1 _trap.value = 0 - # Update records for positions as well + # Update obs for positions as well for nbor in self.grid.get_neighbours(poacher): - if "poacher" in nbor: - records[poacher]["num_poachers"] += 1 - else: - records[poacher]["num_rangers"] += 1 + key = 5 if "ranger" in nbor else 6 + obs[poacher]["observations"][key] += 1 - def _rangers_remove_traps(self, rewards: dict, records: dict): + def _rangers_remove_traps(self, rewards: dict, obs: dict): """Helper function where rangers detect and remove traps in their current cells. Note that detection depends on self.prob_detect_trap.""" @@ -523,8 +412,8 @@ class raw_env(ParallelEnv): # First all agents detect traps. traps_detected = set() # Multiple agents can detect same trap for _trap in [_t for _t in self.grid.state if isinstance(_t, Trap)]: + _loc = self.grid.state[_trap] # also the nbor.rangers locations for ranger in self.grid.get_neighbours(_trap): - _loc = self.grid.state[ranger] if self.rng.random() < self.grid.prob_detect_trap(_loc): traps_detected.add(_trap) @@ -532,67 +421,67 @@ class raw_env(ParallelEnv): # Extract the owning poacher name _poacher = "_".join(_trap.name.split("_")[-2:]) - # updating the records for all implicated rangers - key = "trap_empty" if _trap.value == 0 else "trap_full" + # updating the obs for all implicated rangers + # If trap value is zero(before reset/capture, it was empty) + key = -2 if _trap.value == 0 else -1 for ranger in self.grid.get_neighbours(_trap): - records[ranger][key] += 1 + obs[ranger]["observations"][key] += 1 - # Assign rewards to Rangers + # Assign rewards to Rangers. Includes trap removal logic. self._assign_reward( _poacher, -self.grid.remove_trap(_trap), rewards ) - def _rangers_remove_poachers( - self, rewards: dict, terminations: dict, records: dict - ): - """Helper function where rangers detect and - remove poachers in their current cell. - Note that detection depends on self.prob_detect_cell""" + def _rangers_remove_poachers(self, rewards: dict, obs: dict): + """Helper function where rangers detect and remove poachers in their + current cell. Detection depends on self.prob_detect_cell""" # First mark all captured poachers - caught_poachers = [] + caught_poachers = set() for _poacher in self.poachers: - if _poacher not in self.agents or terminations[_poacher]: + if self.grid.state[_poacher][1] < 0: continue # Agent is already caught, skip - for _ranger in [ # Otherwise, try detecting + _loc = self.grid.state[_poacher] + for _ranger in [ _r for _r in self.grid.get_neighbours(_poacher) if "ranger" in _r + and self.rng.random() < self.grid.prob_detect_cell(_loc) ]: - _loc = self.grid.state[_ranger] - if self.rng.random() < self.grid.prob_detect_cell(_loc): - caught_poachers.append(_poacher) + caught_poachers.add(_poacher) # Poacher detected. 
- # update their status, and all ranger records + # update their status, and all ranger obs for _poacher in caught_poachers: - terminations[_poacher] = True - # EXPERIMENTAL: penalty for Poachers for each prey recovered - penalty = self.grid.state[_poacher][-1] * self.grid.REWARD_MAP[ - "PREY_FOUND" - ] + self.grid.remove_poacher(_poacher) + # NOTE: read the prey/trap counts before remove_poacher resets the state. + penalty = ( + self.grid.state[_poacher][-1] # C_prey + * self.grid.REWARD_MAP["PREY_FOUND"] + + self.grid.state[_poacher][-2] # C_trap + * self.grid.REWARD_MAP["TRAP_FOUND"] + + self.grid.remove_poacher(_poacher) # C_capture + ) self._assign_reward(_poacher, -penalty, rewards) - # updating records for implicated rangers + # Update the trap and prey status for implicated rangers: + # First increment #of caught poachers, then #traps/prey captured. for _ranger in [ _r for _r in self.grid.get_neighbours(_poacher) if "ranger" in _r ]: - # Update the trap and prey status - records[_ranger]["poachers_found"][ + obs[_ranger]["observations"][-5] += 1 + obs[_ranger]["observations"][-4:-2] += self.grid.state[ _poacher - ] += self.grid.state[_poacher][2:] + ][2:] - def _poachers_place_traps(self, actions: dict, terminations: dict): - """Helper function where poachers place traps. - Note that this will not succeed if poacher has - no traps to place.""" + def _poachers_place_traps(self, actions: dict): + """Helper function where poachers place traps. Note that this will not + succeed if poacher has no traps to place.""" for poacher in [ p for p in self.poachers - if (p in self.agents) and (not terminations[p]) and actions[p] == 5 + if self.grid.state[p][1] >= 0 and actions[p] == 5 ]: - trap = self.poacher_traps[poacher].pop() + trap = self.poacher_traps[poacher].pop() # Will throw if empty. self.grid.add_trap(trap, self.grid.state[poacher][0:2]) self.grid.state[poacher][2] -= 1 assert self.grid.state[poacher][2] == len( @@ -605,19 +494,10 @@ class raw_env(ParallelEnv): for trap in [t for t in self.grid.state if isinstance(t, Trap)]: _loc = self.grid.state[trap] if ( - self.rng.random() - < self.grid.prob_animal_appear(_loc) * trap.efficiency + self.rng.random() < self.grid.prob_animal_appear(_loc) and trap.value == 0 ): - trap.value += 1 - - def _cleanup(self) -> None: - """Update the set of live agents, and their position - on the grid.""" - for agent in self.killed_agents: - # clean GridState lists - del self.grid.state[agent] - self.agents.remove(agent) + trap.value = 1 @functools.lru_cache(maxsize=None) def observation_space(self, agent): diff --git a/anti_poaching/env/utils/flatten_utils.py b/anti_poaching/env/utils/flatten_utils.py deleted file mode 100644 index 500f182ff4513ef3882cfc1cc9cf58015d32574a..0000000000000000000000000000000000000000 --- a/anti_poaching/env/utils/flatten_utils.py +++ /dev/null @@ -1,132 +0,0 @@ -"""Module to implement useful functions like box_flatten -and box_flatten_obs. These will be used in the -NonCategoricalFlatten wrapper, as defined in `wrappers.py`. -""" - -from functools import singledispatch -import numpy as np -import gymnasium -from gymnasium.spaces import ( - Box, - Dict, - Discrete, - MultiDiscrete, - MultiBinary, -) - - -# Custom Flattening logic -@singledispatch -def box_flatten(space: gymnasium.Space, dtype=None) -> Box: - """ - SingleDispatch function to recursively fold a - space into a Box: Currently supports only Box, - Discrete, MultiDiscrete and Dict spaces.
- """ - raise NotImplementedError(f"Unknown/Unsupported space: {space=}") - - -@box_flatten.register(Dict) -def flatten_dict(space: Dict, dtype=None) -> Box: - """ - Function to fold a Dict space into a Box. - """ - # Recursively flatten all the spaces in the keys - list_boxed_spaces = [box_flatten(subsp) for subsp in space.values()] - - # return the Box with the concatenated shapes - return Box( - low=np.concatenate([subsp.low for subsp in list_boxed_spaces]), - high=np.concatenate([subsp.high for subsp in list_boxed_spaces]), - dtype=np.int32, - ) - - -@box_flatten.register(Discrete) -def flatten_discrete(space: Discrete, dtype=None) -> Box: - """Flatten a discrete box, but not as a categorical - space. We consider Discrete as a Box of ints. - """ - return Box( - low=0, high=space.n - 1, dtype=space.dtype if not dtype else dtype - ) - - -@box_flatten.register(MultiDiscrete) -def flatten_multidiscrete(space: MultiDiscrete, dtype=None) -> Box: - """Flatten a MultiDiscrete box, but not as a categorical - space. We consider MultiDiscrete as a multi-dimensional - Box of ints. - """ - return Box( - low=np.zeros_like(space.nvec).flatten(), - high=np.array(space.nvec).flatten() - 1, - dtype=space.dtype if not dtype else dtype, - ) - - -@box_flatten.register(Box) -def flatten_box(space: Box, dtype=None) -> Box: - """Flatten of a box should just return the box if it is 1D, - else the equivalent of running np.flatten on its samples.""" - return Box( - low=np.array(space.low).flatten(), - high=np.array(space.high).flatten(), - dtype=space.dtype if not dtype else dtype, - ) - - -@box_flatten.register(MultiBinary) -def flatten_multibinary(space: MultiBinary, dtype=None) -> Box: - """Convert a Binary to a Box""" - return Box( - low=0, - high=1, - shape=np.array(space.shape).flatten(), - dtype=space.dtype if not dtype else dtype, - ) - - -def box_flatten_obs(obs, dtype=None) -> np.array: - """ - Box flattens an observation recursively. - """ - if isinstance(obs, dict): - # Recursively fold. 
- return np.concatenate( - [box_flatten_obs(val).flatten() for val in obs.values()] - ) - - dtype = np.int32 if not dtype else dtype # Default to integer type - if isinstance(obs, (np.ndarray, list)) or np.isscalar(obs): - # Return the flattened version - return np.array(obs, dtype=dtype).flatten() - - raise NotImplementedError(f"Unknown observation sent: {obs}") - - -# Custom unflattening logic -def sizeof_space(space: gymnasium.Space): - """gets the number of elements in the space""" - if isinstance(space, Dict): - return np.sum([sizeof_space(subsp) for subsp in space.values()]) - return np.prod(np.array(space.shape).flatten().astype(np.int64)) - - -def parse_obs(obs: np.array, space: gymnasium.Space): - """Parse the observation into the original space.""" - - if isinstance(space, (Box, Discrete, MultiDiscrete, MultiBinary)): - return obs.reshape(space.shape).astype(space.dtype) - if not isinstance(space, Dict): - raise NotImplementedError - - # Now to implement parsing for dictionaries - begin, end = 0, 0 - new_obs = dict.fromkeys(space.keys()) - for key, subsp in space.items(): - end = begin + sizeof_space(subsp) # Define the chunk to parse - new_obs[key] = parse_obs(obs[begin:end], subsp) - begin = end # update the begin position for next subspace - - return new_obs diff --git a/anti_poaching/env/utils/game_utils.py b/anti_poaching/env/utils/game_utils.py index 0d25452cee00f4e60543f8f3208689975933b0fd..dc0a04dfc157fd591adb45fd709d5403f78ac3cf 100644 --- a/anti_poaching/env/utils/game_utils.py +++ b/anti_poaching/env/utils/game_utils.py @@ -9,6 +9,8 @@ from dataclasses import dataclass import numpy as np import pygame +from .typing import * + FPS = 10 BLK_SIZE = 30 @@ -27,7 +29,6 @@ class Trap: an animal has been caught in it.""" name: str = "Trap" - efficiency: float = 0.40 value: float = 0 def __hash__(self): @@ -40,10 +41,10 @@ class Trap: class BaseGridState(ABC): - """Abstract Class to represent the game state for the Anti-Poaching Game.""" + """AbstractClass to represent game state of an Anti-Poaching Game.""" REWARD_MAP: dict = {"TRAP_FOUND": 2, "PREY_FOUND": 2, "POACHER_FOUND": 100} - NULL_POS: np.array = np.array([-1, -1, 0, 0], dtype=np.int16) + NULL_POS: np.array = np.array([-1, -1, 0, 0], dtype=INTEGER) RENDER_MODES = ["ansi", "rgb"] def __init__( @@ -97,10 +98,9 @@ class BaseGridState(ABC): raise NotImplementedError def get_neighbours(self, agent: str) -> tuple: - """Given an agent, fetch agents in the same cell. - If supplied a Trap, this will fetch all non-Trap - objects (i.e. agents) which are in the same cell. - """ + """Given an agent, fetch agents in the same cell. If supplied a Trap, + this will fetch all non-Trap objects (i.e. agents) which are in the + same cell.""" agent_pos = self.state[agent][0:2] return [ _a @@ -117,13 +117,12 @@ class BaseGridState(ABC): """Tries to remove an agent from the supplied position. Returns a positive reward if successful, else throws an exception.""" assert poacher in self.state, f"Unauthorised remove !!!\n{self}" - self.state[poacher] = self.NULL_POS + self.state[poacher] = deepcopy(self.NULL_POS) return self.REWARD_MAP["POACHER_FOUND"] def remove_trap(self, trap: Trap) -> float: - """Tries to remove a Trap from the supplied position. 
- Returns the trap's value if successful, else throws an exception.""" - assert trap in self.state, f"Unauthorised remove !!!\n{self}" + """Removes a Trap from the grid.""" + assert trap in self.state, f"Unauthorised trap remove !!!\n{self}" del self.state[trap] return trap.value @@ -132,7 +131,8 @@ seed: int = None, **options, ) -> None: - """Resets the grid to initial position""" + """Resets the grid to initial position. This should use the seed + supplied to the parent anti_poaching.raw_env instance.""" self.rng = np.random.default_rng(seed) # Get options from the kwargs, and reset @@ -149,15 +149,11 @@ poachers: list, ntraps_per_poacher: int, ) -> None: - """Helper function to reinitialise the state - using the RNG engine. This replaces the use of self.init_state, - since the same seed generates the same initial state.""" - - # Regenerate state using the RNG + """Helper function to reinitialise the state using the RNG engine""" self.state = { **{ ranger: self.rng.integers( - low=0, high=self.N, size=2, dtype=np.int16 + low=0, high=self.N, size=2, dtype=INTEGER ) for ranger in rangers }, @@ -165,10 +161,10 @@ poacher: np.concatenate( ( self.rng.integers( - low=0, high=self.N, size=2, dtype=np.int16 + low=0, high=self.N, size=2, dtype=INTEGER ), # random location np.array( - [ntraps_per_poacher, 0], dtype=np.int16 + [ntraps_per_poacher, 0], dtype=INTEGER ), # (..., ntraps, npreys) ) ) @@ -198,6 +194,8 @@ pygame.draw.rect(self.screen, WHITE, rect, 1) for obj, pos in self.state.items(): + if pos[0] < 0: + continue width = 0 if isinstance(obj, Trap): color = GREY @@ -214,9 +212,15 @@ """Rendering as text""" print("\nCurrent state:") print(f"\tGrid size: {self.N}") + printgrid = {(i, j): [] for i in range(self.N) for j in range(self.N)} - for thing, loc in self.state.items(): - printgrid[tuple(loc[0:2].tolist())].append(thing) + for thing, pos in self.state.items(): + if pos[0] < 0: + continue + printgrid[tuple(pos[0:2].tolist())].append(thing) for i in range(self.N): for j in range(self.N): @@ -230,10 +234,7 @@ class GridStateConstProb(BaseGridState): State is represented as a dictionary of player-(location-misc) pairs. - v3. Promoted to GridStateConstProb - this is an abstract class. - Also takes over the responsibility of probabilties in - the grid. This allows us to define custom GridState objects, - especially when using custom probabilities. + v3. Promoted to GridStateConstProb - this implements BaseGridState. v2. Deprecates the grid dictionary for simpler code. The GridEnv class is now upgraded to the GridState class, @@ -293,26 +294,23 @@ """For an agent, return the permitted actions she can take in her current cell. Returns an numpy.array that represents an action mask.""" - if "ranger" in agent: - size_action_mask = 5 - elif "poacher" in agent: - size_action_mask = 6 - action_mask = np.zeros(size_action_mask, dtype=np.int8) + size_action_mask = 5 if "ranger" in agent else 6 + action_mask = np.zeros(size_action_mask, dtype=SHORT) action_mask[0] = 1 # always validate null-action # Find the agent position x, y = self.state[agent][0:2] if x == -1 or y == -1: - return action_mask # If agent is dead + return action_mask # If agent is captured, nothing further.
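# Illustrative sketch of how the SHORT-typed mask returned by
# permitted_movements() is consumed: gymnasium's Discrete.sample accepts an
# np.int8 mask, which is how the heuristic policies in examples/ draw only
# legal actions. The mask values below are hand-written for illustration.
import numpy as np
import gymnasium as gym

poacher_space = gym.spaces.Discrete(6, seed=0)      # null, up, left, down, right, place_trap
mask = np.array([1, 0, 1, 1, 0, 1], dtype=np.int8)  # e.g. from permitted_movements(...)
action = poacher_space.sample(mask=mask)            # only indices with mask == 1 can be drawn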
- # calculate permitted actions + # calculate permitted actions for "live" agents UP = 0 if x <= 0 else 1 DOWN = 0 if x >= self.N - 1 else 1 LEFT = 0 if y <= 0 else 1 RIGHT = 0 if y >= self.N - 1 else 1 # Only setting movement flags. Further actions - # are set by the AntiPoachingGame.observe fn. + # are set by the anti_poaching.raw_env instance. action_mask[1:5] = [UP, LEFT, DOWN, RIGHT] return action_mask @@ -334,10 +332,12 @@ class GridStateConstProb(BaseGridState): # Throw error if invalid position generated # since it needed an invalid action - assert self.valid_pos( - new_pos - ), f"update_position({agent}): Invalid action. {(x,y)=} -> {action=} -> {new_pos=}" - + assert self.valid_pos( + new_pos + ), f"update_position({agent}): Invalid action. {(x,y)=} -> {action=} -> {new_pos=}, grid size = {self.N}" # update grid and state list self.state[agent][0:2] = new_pos @@ -427,7 +427,7 @@ class GridStateVaryingProb(BaseGridState): size_action_mask = 5 elif "poacher" in agent: size_action_mask = 6 - action_mask = np.zeros(size_action_mask, dtype=np.int8) + action_mask = np.zeros(size_action_mask, dtype=SHORT) action_mask[0] = 1 # always validate null-action # Find the agent position diff --git a/anti_poaching/env/utils/typing.py b/anti_poaching/env/utils/typing.py new file mode 100644 index 0000000000000000000000000000000000000000..a1bdb3c59e93cd2a7b56171a524b5fc07277764c --- /dev/null +++ b/anti_poaching/env/utils/typing.py @@ -0,0 +1,7 @@ +"""Module to standardise custom types used for the env and other scripts""" + +import numpy as np + +INTEGER = np.int32 # Core integer type +SHORT = np.int8 # Action-mask types +AgentID = str # Agents are strings diff --git a/anti_poaching/env/utils/wrappers.py b/anti_poaching/env/utils/wrappers.py index 419962f8f416e9043573e417974046bef58bf2da..90761d00c45aeb242be63befd88de9b73d305486 100644 --- a/anti_poaching/env/utils/wrappers.py +++ b/anti_poaching/env/utils/wrappers.py @@ -1,6 +1,5 @@ -"""Module that implements the NonCategoricalFlatten and -Stacker Wrappers that ConservationGame needs. -""" +"""Module that implements the wrappers to make the Anti-Poaching game +compatible with RLlib.""" import numpy as np @@ -12,241 +11,56 @@ from pettingzoo.utils.env import ParallelEnv # Ray imports from ray.rllib.env.multi_agent_env import MultiAgentEnv -# local imports -from .flatten_utils import box_flatten, box_flatten_obs - - -class NonCategoricalFlatten(ObservationWrapper): - """ - Class that takes an environment and flattens the - observation without using Categorical Encoding. - This is currently implemented for Discrete, - MultiDiscrete, Box and Dict environments only for - ParallelEnv environments. - """ - - def __init__(self, env): - ObservationWrapper.__init__(self, env) - - # Store new space in self._observation_spaces - self.observation_spaces = { - agent: Dict( - { - "action_mask": box_flatten( - env.observation_spaces[agent]["action_mask"], - ), - "observations": box_flatten( - env.observation_spaces[agent]["observations"] - ), - } - ) - for agent in self.env.agents - } - - def observation(self, observation): - """ - Applies the modification to all observations - sent out by step and reset i.e. - env.step -> observation(obs), rewards, ...
- env.reset-> observation(obs), info - """ - return { - agent: { - "action_mask": box_flatten_obs( - observation[agent]["action_mask"], - dtype=np.int8, - ), - "observations": box_flatten_obs( - observation[agent]["observations"] - ), - } - for agent in observation - } - - def reset(self, *, seed=None, options=None) -> tuple: - """ - Modifies the :attr:`env` after calling :meth:`reset`, - returning a modified observation using :meth:`self.observation`. - """ - reset_obs, infos = self.env.reset(seed=seed, options=options) - return self.observation(reset_obs), infos - - def observation_space(self, agent): - """ - Overrides the observation_space method of - the environment and sends the modified spaces - """ - return self.observation_spaces[agent] - - @property - def unwrapped(self): - """Returns the unwrapped environment""" - _unwrapped = getattr(self.env, "unwrapped", self.env) - if callable(_unwrapped): - return _unwrapped() - return _unwrapped - - -class StackerWrapper(MultiAgentEnv): - """A wrapper for PettingZoo ParallelEnv's where each agent has distinct - action and observation spaces. This wrapper stacks all these spaces into - single dictionaries keyed by the agent IDs. For example, consider - - agent0: obs=(10,), act=Discrete(2) - agent1: obs=(20,), act=Discrete(3) - - Then, we construct an action_space as the dictionary - - action_space = Dict( {agent0: ..., agent1: ... } ) - - Similarly, we construct the observation_space variable. This allows us to - not lose semantic meaning when we define custom wrappers, and hopefully - allow RLlib to optimise happily :) - """ - - def __init__(self, env: ParallelEnv): - super().__init__() - self.env = env - self.agents = env.agents - self._agent_ids = set(self.agents) - - # Provide full (preferred format) observation- and action-spaces as Dicts - # mapping agent IDs to the individual agents' spaces. - self._obs_space_in_preferred_format = True - self.observation_space = Dict( - {agent: self.env.observation_space(agent) for agent in self.agents} - ) - - self._action_space_in_preferred_format = True - self.action_space = Dict( - {agent: self.env.action_space(agent) for agent in self.agents} - ) - - def reset(self, *, seed=None, options=None): - """Resets the base environment""" - return self.env.reset(seed=seed, options=options) - - def step(self, action_dict): - """ - Completely delegate responsibility to ParallelEnv, since it returns the - same format of obs, rewards, terminations, truncations, infos. - """ - # Get environment to step - obs, rew, terminateds, truncateds, info = self.env.step(action_dict) - - # RLlib-specific keys - terminateds["__all__"] = all(terminateds.values()) - truncateds["__all__"] = all(truncateds.values()) - - # return processed dictionaries - return obs, rew, terminateds, truncateds, info - - @property - def unwrapped(self): - """Returns the unwrapped environment""" - _unwrapped = getattr(self.env, "unwrapped", self.env) - if callable(_unwrapped): - return _unwrapped() - return _unwrapped - class QMIXCompatibilityLayer(MultiAgentEnv): - """ - Class that renames the "observations" to "obs", - and pads with observations of terminated agents. - This is for compatibility with QMIX. - """ + """Class that renames the "observations" to "obs". 
This is for + compatibility with QMIX.""" def __init__(self, env): super().__init__() self.env = env - # Store new space in self._observation_spaces self._obs_space_in_preferred_format = True self.observation_space = { agent: Dict( { - "action_mask": env.observation_space[agent]["action_mask"], - "obs": env.observation_space[agent]["observations"], + "action_mask": env.observation_space(agent)["action_mask"], + "obs": env.observation_space(agent)["observations"], } ) - for agent in self.env.agents + for agent in env.agents } self._action_space_in_preferred_format = True self.action_space = env.action_space - def observation(self, obs, missing_keys: set = None): - """ - Applies the modification to all observations - sent out by step and reset i.e. + def observation(self, obs: dict): + """Applies the modification to all observations sent out by step + and reset i.e. env.step -> observation(obs), rewards, ... env.reset-> observation(obs), info """ - # QMIX cannot handle disappearing agents :(( - if missing_keys is None: - missing_keys = set() - for agent in self.observation_space.keys(): - if agent in missing_keys: - # Need to pad for this agent + rename - obs[agent] = self.env.observation_space[agent].sample() - obs[agent]["action_mask"] = np.array( - [1, 0, 0, 0, 0, 0], dtype=np.int8 - ) # First action is always valid + only poachers die - - # Copy time and set invalid state(location). Zero everything else. - obs[agent]["obs"] = np.zeros_like(obs[agent]["observations"]) - obs[agent]["obs"][0:3] = [ - self.env.unwrapped.max_time - self.env.unwrapped.curr_time, - -1, - -1, - ] - del obs[agent]["observations"] # Delete the invalid key - else: - # Only rename for this agent, if not killed yet - obs[agent] = { - "action_mask": obs[agent]["action_mask"], - "obs": obs[agent]["observations"], - } - - # Return parsed obs + obs[agent]["obs"] = obs[agent].pop("observations") return obs def step(self, action_dict: dict) -> tuple: """Pads each returned dictionary to contain dead agents as well""" - # Terms and truncs will be overwritten anyway - obs, rewards, _, _, infos = self.env.step(action_dict) - - missing_keys = set( - key for key in self.observation_space.keys() if key not in rewards - ) - obs = self.observation(obs, missing_keys) - rewards.update(dict.fromkeys(missing_keys, 0)) - infos.update(dict.fromkeys(missing_keys, {})) - - # Only final step must be True for EpisodeV2 compatibility - terms = dict.fromkeys( - self.observation_space.keys(), - self.env.unwrapped.max_time == self.env.unwrapped.curr_time, - ) - truncs = dict.fromkeys(terms.keys(), False) + obs, rewards, terms, truncs, infos = self.env.step(action_dict) # RLlib-specific keys terms["__all__"] = all(terms.values()) truncs["__all__"] = all(truncs.values()) - return obs, rewards, terms, truncs, infos + return self.observation(obs), rewards, terms, truncs, infos def reset(self, *, seed=None, options=None) -> tuple: - """ - Modifies the :attr:`env` after calling :meth:`reset`, - returning a modified observation using :meth:`self.observation`. 
- """ + """Modifies the :attr:`env` after calling :meth:`reset`, + returning a modified observation using :meth:`self.observation`.""" reset_obs, infos = self.env.reset(seed=seed, options=options) return self.observation(reset_obs), infos @property - def unwrapped(self): + def get_sub_environments(self): """Returns the unwrapped environment""" _unwrapped = getattr(self.env, "unwrapped", self.env) if callable(_unwrapped): diff --git a/examples/manual_policies/fixed_policy.py b/examples/manual_policies/fixed_policy.py index 4ed8befb5e7ae59535bb2354211b3d38c920ab43..412a8b66e805850bf6a4f4a2b61a5b26e6ffac34 100644 --- a/examples/manual_policies/fixed_policy.py +++ b/examples/manual_policies/fixed_policy.py @@ -111,6 +111,3 @@ if __name__ == "__main__": print("\n GAME OVER !\n") print(cg.curr_time, " is the current time") print(cg.max_time, " is the maximum time") - print("Rewards: ") - for agent in cg.possible_agents: - print(agent, " has ", cg.total_rewards[agent]) diff --git a/examples/manual_policies/random_policy.py b/examples/manual_policies/random_policy.py index 75bf33bab721619cd82beb19709503580f2bd91a..d252ea619d860152263ef48e6f10a88e53693a7c 100644 --- a/examples/manual_policies/random_policy.py +++ b/examples/manual_policies/random_policy.py @@ -59,6 +59,3 @@ if __name__ == "__main__": print("\n GAME OVER !\n") print(cg.curr_time, " is the current time") print(cg.max_time, " is the maximum time") - print("Rewards: ") - for agent in cg.possible_agents: - print(agent, " has ", cg.total_rewards[agent]) diff --git a/examples/rllib_examples/callbacks.py b/examples/rllib_examples/callbacks.py index 6c7f9e8bbfecbd8ce229e124f2be987812dd02dc..f94b461f0d82c48a2863e60fcea518c50f72f697 100644 --- a/examples/rllib_examples/callbacks.py +++ b/examples/rllib_examples/callbacks.py @@ -26,6 +26,20 @@ from ray.rllib.evaluation import Episode, RolloutWorker from ray.rllib.algorithms.callbacks import DefaultCallbacks +class SumRangerRewardMetric(DefaultCallbacks): + """Callback to compute a custom metric: the sum of all ranger + rewards. This is stored as a custom metric: rangers_mean_reward_sum""" + + def on_train_result(self, *, algorithm, result: dict, **kwargs): + result["custom_metrics"]["rangers_mean_reward_sum"] = np.sum( + [ + v + for k, v in result["policy_reward_mean"].items() + if "ranger" in k + ] + ) + + class RestoreNonTrainingAgents(DefaultCallbacks): """ Restores the weights of an algorithm on init @@ -121,7 +135,7 @@ class EpisodeMetricsCallbacks(DefaultCallbacks): # Get base environment env = base_env._unwrapped_env - # cjjollect data + # collect data # First collecting number of traps held by each poacher. 
for poacher in env.poachers: episode.custom_metrics["num_traps_" + poacher].append( diff --git a/examples/rllib_examples/configs.py b/examples/rllib_examples/configs.py index 25a0aa9c2855694c7f890331b78ac27978faf910..59eb47a44b766b50631925416fab1d1743297fca 100644 --- a/examples/rllib_examples/configs.py +++ b/examples/rllib_examples/configs.py @@ -16,8 +16,7 @@ from ray.rllib.env.wrappers.group_agents_wrapper import GroupAgentsWrapper # Importing my enviroment from anti_poaching.anti_poaching_v0 import ( anti_poaching, - rllib_cons_game, - StackerWrapper, + grouped_rllib_cons_game, ) # Constant batch size for all algorithms @@ -27,7 +26,7 @@ _TRAIN_BATCH_SIZE = _NUM_EPISODES_PER_ITER * _LEN_EPISODE def generate_training_config( - env: StackerWrapper, + env: "AntiPoaching", env_config: dict, policies_to_train: list = None, hyperparams: dict = None, @@ -57,12 +56,7 @@ def generate_training_config( # not desired. !!! disable_env_checking=True, ) - .experimental( - # Otherwise the observation dict will - # be flattened to a numpy array. Plus, - # we're already doing our own preprocessing. - _disable_preprocessor_api=True - ) + # .experimental(_disable_preprocessor_api=True) .framework("torch") .training( model={ @@ -71,15 +65,15 @@ def generate_training_config( "no_masking": False }, # Required by the example }, - **hyperparams, # Only used when set + **hyperparams, # Only non-empty when set ) .resources(num_gpus=num_gpus) .multi_agent( policies={ agent: PolicySpec( policy_class_selector(agent), - observation_space=env.observation_space[agent], - action_space=env.action_space[agent], + observation_space=env.observation_space(agent), + action_space=env.action_space(agent), ) for agent in env.agents }, @@ -90,7 +84,7 @@ def generate_training_config( def ppo_config( - env: StackerWrapper, + env: "AntiPoaching", env_config: dict = None, policies_to_train=None, num_gpus: int = 0, @@ -136,7 +130,7 @@ def ppo_config( def pg_config( - env: StackerWrapper, + env: "AntiPoaching", env_config: dict = None, policies_to_train=None, num_gpus: int = 0, @@ -179,6 +173,7 @@ def qmix_config( # None PolicyClass -> RLlib automatically infers policy_class_selector = lambda agent: None + env = grouped_rllib_cons_game(env, env_config) # Group the env return ( QMixConfig() .framework("torch") diff --git a/examples/rllib_examples/example_utils.py b/examples/rllib_examples/example_utils.py index 9b196608d3b1c498411f2979ea326f8de92bceb4..2ec88a4e185bc4a6bbdd6ea552c157ebb7e86ead 100644 --- a/examples/rllib_examples/example_utils.py +++ b/examples/rllib_examples/example_utils.py @@ -1,15 +1,10 @@ -""""Defines utils for the examples, like argument parsers, -and a method to define an RLlib ready instance using args. 
-""" +""""Defines utils for the examples, like argument parsers, and a method to +define an RLlib ready instance using args.""" import pathlib import argparse import json -from anti_poaching.anti_poaching_v0 import ( - anti_poaching, - rllib_cons_game, - grouped_rllib_cons_game, -) +from anti_poaching.anti_poaching_v0 import anti_poaching def get_args() -> argparse.ArgumentParser: @@ -38,7 +33,4 @@ def define_game_from_args(args: argparse.Namespace) -> tuple: "prob_animal_appear": args.prob_anim, "max_time": args.max_time, } - env_generator = ( - grouped_rllib_cons_game if args.algo == "QMIX" else rllib_cons_game - ) - return env_generator(None, config=cg_config), cg_config + return anti_poaching.parallel_env(**cg_config), cg_config diff --git a/examples/rllib_examples/from_checkpoint.py b/examples/rllib_examples/from_checkpoint.py index 9039ff77b6cee7b0ff9c124d6a8480cde1c2b9d4..5eb8124bb93c8b40944fecb8cc6da14ce19baaa3 100644 --- a/examples/rllib_examples/from_checkpoint.py +++ b/examples/rllib_examples/from_checkpoint.py @@ -9,7 +9,6 @@ For example, to launch from the checkpoint we can simply use python from_checkpoint.py <CHPOINT_DIR> - """ import json @@ -21,11 +20,7 @@ from ray.tune.logger import pretty_print from ray.rllib.algorithms.algorithm import Algorithm from ray.rllib.algorithms.qmix import QMixConfig, QMix -# Importing my enviroment -from anti_poaching.anti_poaching_v0 import ( - anti_poaching, - rllib_cons_game, -) +from anti_poaching.anti_poaching_v0 import anti_poaching def get_actions(algo: Algorithm, obs: dict) -> list: @@ -47,7 +42,7 @@ def get_actions(algo: Algorithm, obs: dict) -> list: observation=obs[agent], state=state[idx], policy_id=agent, - timestep=env.env.unwrapped.curr_time, + timestep=obs[agent][0]["obs"][0], ) # This is required since compute_single_actions can return @@ -77,7 +72,11 @@ if __name__ == "__main__": env = algo.env_creator(env_config) # Recall that QMix uses the GroupAgentsWrapper for any env. - _base_env = env.env.unwrapped if isinstance(algo, QMix) else env.unwrapped + _base_env = ( + env.env.get_sub_environments + if isinstance(algo, QMix) + else env.get_sub_environments + ) # The calm before the storm ... obs, info = env.reset() diff --git a/examples/rllib_examples/heuristic_policies.py b/examples/rllib_examples/heuristic_policies.py index 07b8a6a5473854595fea2bed5e71b21ef234c7db..8af196fe78c900d6e3a75e0cd32b1003ae78086a 100644 --- a/examples/rllib_examples/heuristic_policies.py +++ b/examples/rllib_examples/heuristic_policies.py @@ -5,6 +5,7 @@ for the Anti Poaching game. 
""" from typing import Generator +from collections import OrderedDict import pathlib import numpy as np import torch @@ -43,9 +44,12 @@ class ActionMaskedRandomPolicy(Policy): episodes=None, **kwargs, ): - action_mask = np.array( - obs_batch["action_mask"], dtype=np.int8 - ).flatten() + if isinstance(obs_batch, OrderedDict): + action_mask = np.array( + obs_batch["action_mask"], dtype=np.int8 + ).flatten() + else: + action_mask = obs_batch[0][: self.action_space.n].astype(np.int8) return ( [self.action_space.sample(action_mask)], state_batches, @@ -186,9 +190,13 @@ class ActionMaskedPlanningPoacherPolicy(Policy): episodes=None, **kwargs, ): + if isinstance(obs_batch, OrderedDict): + flat_obs = obs_batch["observations"].flatten() + flat_mask = obs_batch["action_mask"].flatten() + else: + flat_obs = obs_batch.flatten() + flat_mask, flat_obs = flat_obs[:6], flat_obs[6:] - flat_obs = obs_batch["observations"].flatten() - flat_mask = obs_batch["action_mask"].flatten() rem_time, curr_pos = flat_obs[0], flat_obs[1:3] if all(curr_pos == [-1, -1]): diff --git a/examples/rllib_examples/main.py b/examples/rllib_examples/main.py index 94301cbf670c99f61629ac1cfbcba84b60398b1e..5d57f6b5fb5f83db54d5be487263dd11b9b87d35 100644 --- a/examples/rllib_examples/main.py +++ b/examples/rllib_examples/main.py @@ -33,10 +33,7 @@ from ray.rllib.algorithms.qmix import QMixConfig, QMix from ray.rllib.algorithms.maddpg import MADDPGConfig, MADDPG # Importing custom code -from anti_poaching.anti_poaching_v0 import ( - anti_poaching, - StackerWrapper, -) +from anti_poaching.anti_poaching_v0 import anti_poaching from heuristic_policies import ( ActionMaskedRandomPolicy, ActionMaskedPlanningPoacherPolicy, @@ -82,19 +79,14 @@ if __name__ == "__main__": # call AntiPoachingGame with parser parameters # and store a reference to the created env in the callbacks. cg, cg_config = define_game_from_args(args) - _base_cg_env = ( - cg.unwrapped if isinstance(cg, StackerWrapper) else cg.env.unwrapped - ) # Choose the agents that learn policies_to_train = [] if "r" in args.policies_train: - policies_to_train += ( - ["rangers"] if args.algo == "QMIX" else _base_cg_env.rangers - ) + policies_to_train += ["rangers"] if args.algo == "QMIX" else cg.rangers if "p" in args.policies_train: policies_to_train += ( - ["poachers"] if args.algo == "QMIX" else _base_cg_env.poachers + ["poachers"] if args.algo == "QMIX" else cg.poachers ) # Choose the algo that runs @@ -104,8 +96,6 @@ if __name__ == "__main__": algo_config = pg_config elif args.algo == "QMIX": algo_config = qmix_config - elif args.algo == "MADDPG": - algo_config = generic_conf(MADDPGConfig) # Generic first else: raise RuntimeError(f"Unknown algorithm specified: {args.algo}") @@ -177,7 +167,6 @@ if __name__ == "__main__": ).fit() # Get the path to this experiment and save env_config there. 
- path_exp = results[0].path env_config_file_path = pathlib.Path(results[0].path) / "env_config.json" with open(env_config_file_path, "w") as fd: print(json.dumps(cg_config), file=fd) diff --git a/examples/rllib_examples/tune_antipoaching.py b/examples/rllib_examples/tune_antipoaching.py index 85c24a80bc12294ad2ab80d53fd88eaa13cd39d9..7cec120d15fe95b243e42e8b9f78838efe1c8351 100644 --- a/examples/rllib_examples/tune_antipoaching.py +++ b/examples/rllib_examples/tune_antipoaching.py @@ -7,10 +7,7 @@ from ray.rllib.algorithms.algorithm_config import AlgorithmConfig from ray.rllib.policy.policy import PolicySpec from ray.rllib.algorithms.callbacks import DefaultCallbacks -from anti_poaching.anti_poaching_v0 import ( - anti_poaching, - StackerWrapper, -) +from anti_poaching.anti_poaching_v0 import anti_poaching from heuristic_policies import ( ActionMaskedRandomPolicy, ActionMaskedPlanningPoacherPolicy, @@ -30,14 +27,6 @@ from main import get_policies_from_arg import callbacks as cb -class ComputeCustomMetric(DefaultCallbacks): - def on_train_result(self, *, algorithm, result: dict, **kwargs): - # Compute sum of the absolute value of multiple policies - result["custom_metrics"]["policies_absolute_reward_mean"] = sum( - np.abs(list(result["policy_reward_mean"].values())) - ) - - COMMON_HYPERPARAMS = { "train_batch_size": tune.choice([128, 256, 512, 1024, 2048, 4096]), "gamma": tune.choice([0.9, 0.95, 0.97, 0.99, 0.999]), @@ -91,9 +80,6 @@ if __name__ == "__main__": # Then parsing the above args for the tune search cg_args, _unknown = get_args() cg, cg_config = define_game_from_args(cg_args) - _base_cg_env = ( - cg.unwrapped if isinstance(cg, StackerWrapper) else cg.env.unwrapped - ) args, _ = parser.parse_known_args() # Parse the stopping conditions @@ -109,15 +95,14 @@ if __name__ == "__main__": policies_to_train = [] if "r" in cg_args.policies_train: policies_to_train += ( - ["rangers"] if cg_args.algo == "QMIX" else _base_cg_env.rangers + ["rangers"] if cg_args.algo == "QMIX" else cg.rangers ) if "p" in cg_args.policies_train: policies_to_train += ( - ["poachers"] if cg_args.algo == "QMIX" else _base_cg_env.poachers + ["poachers"] if cg_args.algo == "QMIX" else cg.poachers ) # Choose the algo that runs - print(cg_args.algo) if cg_args.algo == "PPO": algo_config = ppo_config hyperparams = PPO_HYPERPARAMS @@ -146,7 +131,7 @@ if __name__ == "__main__": num_gpus=cg_args.num_gpus, # AlgoConfig to use GPUs. 
) .reporting(metrics_num_episodes_for_smoothing=1) - .callbacks(ComputeCustomMetric) + .callbacks(cb.SumRangerRewardMetric) .rollouts( rollout_fragment_length="auto", batch_mode=tune.sample_from( @@ -173,8 +158,8 @@ if __name__ == "__main__": cg_args.algo, param_space=config, tune_config=tune.TuneConfig( - metric="custom_metrics/policies_absolute_reward_mean", - mode="min", + metric="custom_metrics/rangers_mean_reward_sum", + mode="max", num_samples=args.num_samples, ), run_config=air.RunConfig(stop=stop), diff --git a/tests/test_flatten_space.py b/tests/test_flatten_space.py deleted file mode 100644 index 2c9df8d12816a0bf3b59259a6f0c97a0cd25490e..0000000000000000000000000000000000000000 --- a/tests/test_flatten_space.py +++ /dev/null @@ -1,162 +0,0 @@ -from collections import OrderedDict -import numpy as np -import pytest -from gymnasium.spaces import ( - Box, - Dict, - Discrete, - MultiDiscrete, - MultiBinary, - Text, -) -from anti_poaching.anti_poaching_v0 import ( - box_flatten, - box_flatten_obs, - parse_obs, -) - - -def test_box(): - # Testing on a flat box - b = Box(low=1, high=10, dtype=np.int8) - assert box_flatten(b) == b, "Flatten changed the box space!" - - # and then on a 2D box - b = Box(low=1, high=10, dtype=np.int8, shape=(2, 3)) - assert box_flatten(b) == Box( - low=1, high=10, shape=(6,), dtype=np.int8 - ), "Flatten did not correctly work on 2D box!" - - -def test_discrete(): - d = Discrete(10) - b = box_flatten(d) - assert b == Box( - low=0, high=9, dtype=d.dtype - ), f"Discrete is not translated literally to Box !" - for _ in range(1000): - assert b.contains( - np.array([d.sample()]) - ), f"{b=} does not contain {d=}" - assert d.contains(b.sample()[0]), f"{d=} does not contain {b=}" - - -def test_multidiscrete(): - md = MultiDiscrete([11, 11]) - b = box_flatten(md) - assert b == Box( - low=np.array([0, 0]), high=np.array([10, 10]), dtype=md.dtype - ), f"MultiDiscrete is not translated literally to Box !" - for _ in range(1000): - assert b.contains(md.sample()), f"{b=} does not contain {md=}" - assert md.contains(b.sample()), f"{md=} does not contain {b=}" - - -def test_multidim_discrete(): - md = MultiDiscrete([[2, 3], [4, 5]]) - b = box_flatten(md) - assert b == Box( - low=np.array([0, 0, 0, 0]), high=np.array([1, 2, 3, 4]), dtype=md.dtype - ), "MultiDiscrete is not translated literally to Box !" - - -def test_dict(): - dico = Dict( - { - "box": Box(low=1, high=10, dtype=np.float32), - "disc": Discrete(10), - } - ) - # Note that dico will parse through each dictionary using sorted keys. - assert box_flatten(dico) == Box( - low=np.array([1, 0]), high=np.array([10, 9]), dtype=np.float32 - ), "Dictionary flattening to box is not well defined" - - -def test_dict_nested(): - nested_dico = Dict( - { - "box": Box(low=1, high=10, dtype=np.int32), - "dico": Dict( - OrderedDict( - { - "inner_disc": Discrete(10), - "inner_bin": MultiBinary(2), - "inner_box": Box(low=0, high=10), - } - ) - ), - } - ) - # Note that dico will parse through each dictionary using sorted keys. 
- box = box_flatten(nested_dico) - assert box == Box( - low=np.array([1, 0, 0, 0, 0]), - high=np.array([10, 9, 1, 1, 10]), - ), "Dictionary flattening to box is not well defined" - - for itr in range(1000): - obs = nested_dico.sample() - fl_obs = box_flatten_obs(obs) - assert box.contains( - fl_obs - ), f"{itr=}| {box=} does not contain {fl_obs=}, taken from {obs=}" - - -def test_multibinary(): - mb = MultiBinary([11, 11]) - b = box_flatten(mb) - assert b == Box( - low=0, high=1, shape=mb.shape, dtype=mb.dtype - ), f"MultiBinary is not translated literally to Box !" - for _ in range(1000): - assert b.contains(mb.sample()), f"{b=} does not contain {mb=}" - assert mb.contains(b.sample()), f"{mb=} does not contain {b=}" - - -def test_unsupported_type(): - with pytest.raises(NotImplementedError): - text_space = Text(5) - box_flatten(text_space) - - -def test_parse_obs(): - """Checks if parse_obs parses samples of flattened spaces to the - correct size and type of the original space - """ - spaces = [ - Box(low=1, high=10, dtype=np.int8, shape=[2, 3]), - Discrete(10), - MultiBinary([11, 11]), - MultiDiscrete([[2, 3], [4, 5]]), - Dict( - { - "box": Box(low=1, high=10, dtype=np.float32), - "disc": Discrete(10), - } - ), - Dict( - { - "box": Box(low=1, high=10, dtype=np.int32), - "dico": Dict( - OrderedDict( - { - "inner_disc": Discrete(10), - "inner_bin": MultiBinary(2), - "inner_box": Box(low=0, high=10), - } - ) - ), - } - ), - ] - for space in spaces: - # Test for each space - flat_box = box_flatten(space) - # Test random samples (with casting) - for _ in range(1000): - flat_sample = flat_box.sample() - unflat_sample = parse_obs(flat_sample, space) - assert space.contains( - unflat_sample - ), f"{unflat_sample=} not in {space=}" diff --git a/tests/test_game_env.py b/tests/test_game_env.py index f34c4faaf8bc1201da9c435c303ebca378298088..1aae3126b1ea17c6e87fc29ef0affd82a24dfdf9 100644 --- a/tests/test_game_env.py +++ b/tests/test_game_env.py @@ -35,7 +35,7 @@ def lantipoach_game(): prob_detect_cell=0.5, prob_animal_appear=0.5, prob_detect_trap=0.5, - max_time=500, + max_time=200, ) print("End of fixture ...") @@ -63,9 +63,7 @@ def test_parallel_test_api_suite(lantipoach_game): def test_parallel_seed_test(): - """ - Runs the seed test on the Conservation game. - """ + """Runs the seed test on the Conservation game.""" parallel_seed_test(anti_poaching.parallel_env) @@ -93,21 +91,24 @@ def test_environment(antipoach_game, every: int = 20): for agent in antipoach_game.agents } # step through the environment - observations, _, terminations, truncations, _ = antipoach_game.step( - actions - ) - - # Verify that each agent receives a valid observation !!! + _step_objs = antipoach_game.step(actions) + observations, rewards, terminations, truncations, infos = _step_objs + + # Verify that all agents are alive !!! Recall that poachers are alive, + # but in a caught state. + time_stat = antipoach_game.curr_time < antipoach_game.max_time + for obj in _step_objs: + assert set(obj.keys()) == set( + antipoach_game.possible_agents + ), f"All agents should recieve things until end-of-game (current time = {antipoach_game.curr_time})" + + # Verify that each agent receives valid observations and actions !!! 
for agent, space in antipoach_game.observation_spaces.items(): - if agent not in antipoach_game.agents: - continue assert space.contains( observations[agent] ), f"{agent} received incompatible observations\n {observations[agent]}" for agent, space in antipoach_game.action_spaces.items(): - if agent not in antipoach_game.agents: - continue assert space.contains( actions[agent] ), f"{agent} received incompatible action\n {actions[agent]}" @@ -118,22 +119,15 @@ def test_environment(antipoach_game, every: int = 20): } # post-processing - done = all( - [ - x or y - for x, y in zip(terminations.values(), truncations.values()) - ] - ) + done = all(terminations.values()) or all(truncations.values()) if antipoach_game.curr_time % every == 0 or done: - antipoach_game.render() + antipoach_game.render() # This should not bug out print("-" * 80) print(antipoach_game.curr_time, " is the current time") print(antipoach_game.max_time, " is the maximum time") print("Rewards: ") - for agent in antipoach_game.possible_agents: - print(agent, " has ", antipoach_game.total_rewards[agent]) - assert sum(antipoach_game.total_rewards.values()) == 0.0 + assert sum(rewards.values()) == 0.0, "Game is not zero-sum !!!" def test_init_antipoach_game(lantipoach_game): @@ -175,50 +169,22 @@ def test_init_antipoach_game(lantipoach_game): assert set(lantipoach_game.rangers + lantipoach_game.poachers) == set( lantipoach_game.observation_spaces.keys() ), "Not all agents have observations, or there are too many agents." - assert all( - ( - set( - [ - "remaining_time", - "state", - "ground_traps", - "rangers", - "poachers", - ] - ) - == set( - lantipoach_game.observation_spaces[poacher][ - "observations" - ].keys() - ) - for poacher in lantipoach_game.poachers - ) - ), "Unknown poacher observation structure" - assert all( - ( - set( - [ - "remaining_time", - "state", - "partners", - "poacher_traps", - "poachers_captured", - "ground_traps", - ] - ) - == set( - lantipoach_game.observation_spaces[ranger][ - "observations" - ].keys() - ) - for ranger in lantipoach_game.rangers - ) - ), "Unknown ranger observation structure" - # Check rewards - assert set(lantipoach_game.total_rewards.keys()) == set( - lantipoach_game.agents - ), "Unknown rewards structure" + obs_keys = {"observations", "action_mask"} + for agent in lantipoach_game.agents: + _space = lantipoach_game.observation_space(agent) + _sample = _space.sample() + assert set(_space.keys()) == obs_keys + if "ranger" in agent: + assert ( + len(_sample["observations"]) + == lantipoach_game._ranger_obs_size + ), "Ranger obs space of incorrect size." + else: + assert ( + len(_sample["observations"]) + == lantipoach_game._poacher_obs_size + ), "Poacher obs space of incorrect size." def test_reset(lantipoach_game): @@ -260,24 +226,6 @@ def test_reset(lantipoach_game): ), "Reset does not reset poachers' traps." -def test_observe(lantipoach_game): - """ - Tests the received observations - """ - with pytest.raises(AssertionError): - # Agent does not exist, therefore should throw. - lantipoach_game.observe("agent", {"invalid": "data"}) - with pytest.raises(KeyError): - # Record is empty, therefore should throw. - lantipoach_game.observe("poacher_0", {}) - with pytest.raises(KeyError): - # Record is invalid, therefore should throw. - lantipoach_game.observe("poacher_0", {"invalid": "data"}) - with pytest.raises(KeyError): - # Record is invalid, therefore should throw. 
- lantipoach_game.observe("ranger_0", {"invalid": "data"}) - - def test_assign_reward(antipoach_game): """ Tests all the branches of the _assign_reward helper @@ -285,26 +233,20 @@ def test_assign_reward(antipoach_game): """ for poacher in antipoach_game.poachers.copy(): rewards = dict.fromkeys(antipoach_game.agents, 0) - antipoach_game._assign_reward(poacher, 1, rewards) + # First if we assign a reward to a poacher + antipoach_game._assign_reward(poacher, 1, rewards) assert rewards[poacher] == 1 - assert ( - antipoach_game.total_rewards[poacher] == 0 - ), "We did not assign anything here!!" assert ( sum((rewards[ranger] for ranger in antipoach_game.rangers)) == -1 ) # Force assigning a reward after the poacher is `caught` - antipoach_game.agents.remove(poacher) - del rewards[poacher] - antipoach_game._assign_reward(poacher, 1, rewards) - - assert ( - antipoach_game.total_rewards[poacher] == 1 - ), "total_rewards not updated for retired agent" + antipoach_game.grid.state[poacher] = antipoach_game.grid.NULL_POS + antipoach_game._assign_reward(poacher, -1, rewards) + assert rewards[poacher] == 0 assert ( - sum((rewards[ranger] for ranger in antipoach_game.rangers)) == -2 + sum((rewards[ranger] for ranger in antipoach_game.rangers)) == 0 ), "ranger rewards not updated when assigning reward to retired agent" @@ -314,8 +256,7 @@ def test_observation_time(antipoach_game): obs, infos = antipoach_game.reset() for agent in antipoach_game.agents: assert ( - antipoach_game.max_time - == obs[agent]["observations"]["remaining_time"] + antipoach_game.max_time == obs[agent]["observations"][0] ), f"{agent} is missing remaining_time attr in obs" for t in range(1, antipoach_game.max_time + 1): @@ -339,5 +280,5 @@ def test_observation_time(antipoach_game): for agent in antipoach_game.agents: # and that each agent receives this as well. assert (antipoach_game.max_time - t) == obs[agent]["observations"][ - "remaining_time" + 0 ], f"{agent} has bad updates for remaining_time attr in obs" diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py deleted file mode 100644 index 41237047068c64384f1c807c38c3d1445eba7efa..0000000000000000000000000000000000000000 --- a/tests/test_wrappers.py +++ /dev/null @@ -1,44 +0,0 @@ -""" -Module to test the wrapped AntiPoachingGame, -i.e. after NonCategoricalFlatten and StackerWrapper -""" - -from anti_poaching.anti_poaching_v0 import ( - anti_poaching, - NonCategoricalFlatten, - StackerWrapper, -) -from pettingzoo.test import parallel_api_test - - -def test_boxper(): - """ - Test to verify that the NonCategoricalFlatten - per returns a ParallelEnv after modification. - """ - env = anti_poaching.parallel_env() - env = NonCategoricalFlatten(env) - parallel_api_test(env) - - -def test_pipeline(): - """ - Test to verify that the two transformations composed - work well together - """ - env = anti_poaching.parallel_env() - env = NonCategoricalFlatten(env) - - # Run a few iterations - obs = env.reset() - print("After NonCategoricalFlatten: ", obs) - - env = StackerWrapper(env) - - # Run a few iterations - # Note that MultiAgentEnv's always send two objects !!! - obs, _ = env.reset() - print("After StackerWrapper: ", obs) - assert all( - ({"observations", "action_mask"} == obs[agent].keys() for agent in obs) - ), f"Structure of obs changed, {obs=}"
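
The heuristic policies above now accept either the original OrderedDict batches (with separate "observations" and "action_mask" entries) or an already-flattened vector, in which case the action mask is packed in front of the observation payload: ActionMaskedRandomPolicy slices off the first action_space.n entries, and the planning poacher splits flat_obs[:6] off as its mask before reading the remaining time and its current grid position, with (-1, -1) marking a caught poacher. The snippet below only illustrates that layout; it is not project code, and the sizes and values are invented.

import numpy as np
from gymnasium.spaces import Discrete

action_space = Discrete(6)  # six actions, matching the flat_obs[:6] slice above
action_mask = np.array([1, 0, 1, 1, 0, 1], dtype=np.int8)  # hypothetical mask
payload = np.array([200.0, 3.0, 4.0, 0.0, 1.0])  # hypothetical observation vector
flat_obs = np.concatenate([action_mask, payload]).astype(np.float32)

# Same slicing as ActionMaskedRandomPolicy.compute_actions on a flat batch row:
# mask first, observation payload after.
mask = flat_obs[: action_space.n].astype(np.int8)
obs = flat_obs[action_space.n:]

# gymnasium's Discrete.sample accepts an int8 mask, so only legal actions
# (mask == 1) can be drawn.
action = action_space.sample(mask)
assert mask[action] == 1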
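
tune_antipoaching.py previously minimised the sum of absolute per-policy mean rewards through the inline ComputeCustomMetric callback; it now maximises custom_metrics/rangers_mean_reward_sum reported by cb.SumRangerRewardMetric. Since the game is zero-sum (test_game_env.py asserts sum(rewards.values()) == 0.0 on every step), maximising the rangers' summed mean reward is a more direct tuning objective than minimising every policy's absolute reward. The callback itself is not part of this patch, so the sketch below is only an assumption about its shape, modelled on the removed class; the real examples/rllib_examples/callbacks.py may differ.

from ray.rllib.algorithms.callbacks import DefaultCallbacks


class SumRangerRewardMetric(DefaultCallbacks):
    def on_train_result(self, *, algorithm, result: dict, **kwargs):
        # policy_reward_mean maps policy ids to their mean episode reward;
        # we assume ranger policies can be recognised from their id.
        result["custom_metrics"]["rangers_mean_reward_sum"] = sum(
            reward
            for policy_id, reward in result.get("policy_reward_mean", {}).items()
            if "ranger" in policy_id
        )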
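
The reworked tests treat each agent's observation as a Dict holding a flat "observations" vector and an "action_mask", with the remaining time stored at index 0 of the vector. A quick sanity check along the lines of test_init_antipoach_game and test_observation_time, assuming the default parallel_env() arguments build a valid game and that the returned environment exposes max_time the way the test fixtures do:

from anti_poaching.anti_poaching_v0 import anti_poaching

env = anti_poaching.parallel_env()
obs, infos = env.reset()
for agent, agent_obs in obs.items():
    # Every agent sees a flat "observations" vector plus an "action_mask";
    # right after reset, the remaining time at index 0 equals max_time.
    assert set(agent_obs) == {"observations", "action_mask"}
    assert agent_obs["observations"][0] == env.max_time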