diff --git a/anti_poaching/anti_poaching_v0.py b/anti_poaching/anti_poaching_v0.py index db0790ef79e3e359fbdbc5a7e3fd55aaa9c0ae22..169b92d88857ae28d6cf39383ef32a49a55aca22 100644 --- a/anti_poaching/anti_poaching_v0.py +++ b/anti_poaching/anti_poaching_v0.py @@ -1,7 +1,5 @@ -""" -Module to expose the Conservation Game -Environment from within the env folder. -""" +"""Module to expose the Anti-Poaching PettingZoo env. This also registers +compatible versions for use with RLlib.""" from gymnasium.spaces import Tuple, Dict, Box from pettingzoo.utils.env import ParallelEnv @@ -12,86 +10,49 @@ from ray.tune.registry import register_env from ray.rllib.models import ModelCatalog from ray.rllib.examples.models.action_mask_model import TorchActionMaskModel from ray.rllib.env.wrappers.group_agents_wrapper import GroupAgentsWrapper +from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv from .env import anti_poaching from .env.utils.game_utils import GridStateConstProb, Trap -from .env.utils.wrappers import ( - NonCategoricalFlatten, - StackerWrapper, - QMIXCompatibilityLayer, -) -from .env.utils.flatten_utils import box_flatten, box_flatten_obs, parse_obs - +from .env.utils.wrappers import QMIXCompatibilityLayer -def rllib_cons_game( - env: ParallelEnv = None, config: dict = None -) -> MultiAgentEnv: - """ - Wraps the Conservation Game to modify the - observation spaces and stack all spaces into - their respective joint spaces. - """ - if config is None: - config = {} - if env is None: - env = anti_poaching.parallel_env(**config) - env = NonCategoricalFlatten(env) # Dict(....) -> Box - env = StackerWrapper(env) # Wrapped(ParallelEnv) -> MultiAgentEnv - return env +# Register environment to use for RLlib +register_env( + anti_poaching.metadata["name"], + lambda config: ParallelPettingZooEnv(anti_poaching.parallel_env(**config)), +) def grouped_rllib_cons_game( env: ParallelEnv = None, config: dict = None, ) -> GroupAgentsWrapper: - """ - This is for QMix and other algorithms that may need - grouped agents. - """ - env = rllib_cons_game(env, config) - env = QMIXCompatibilityLayer(env) - + """This is for QMix and other algorithms that use agent groups""" + env = anti_poaching.parallel_env(**config) if not env else env + qenv = QMIXCompatibilityLayer(env) grouped_policies = ["rangers", "poachers"] # Declare the observation policy of the group to train, # and add attr to bypass the QMixPolicy validation step obs_space = Tuple( - Tuple( - [ - env.observation_space[agent] - for agent in getattr(env.unwrapped, group) - ] - ) + Tuple([qenv.observation_space[agent] for agent in getattr(env, group)]) for group in grouped_policies ) # Declare the grouped action space as well.
act_space = Tuple( - Tuple( - [ - env.action_space[agent] - for agent in getattr(env.unwrapped, group) - ] - ) + Tuple([qenv.action_space(agent) for agent in getattr(env, group)]) for group in grouped_policies ) # return the grouped env using .with_agent_groups - return env.with_agent_groups( - groups={ - group: getattr(env.unwrapped, group) for group in grouped_policies - }, + return qenv.with_agent_groups( + groups={group: getattr(env, group) for group in grouped_policies}, obs_space=obs_space, act_space=act_space, ) -# Register environment to use for RLlib -register_env( - anti_poaching.metadata["name"], - lambda config: rllib_cons_game(config=config), -) - # Register grouped environment for QMix # Note that QMix assumes homogenous agents (ray2.8.0) register_env( diff --git a/anti_poaching/env/anti_poaching.py b/anti_poaching/env/anti_poaching.py index ecc2babfecc333704be6a80e54f9deefbde399f3..75ef3bf0740f0c7d11562f02e605c375692824e5 100644 --- a/anti_poaching/env/anti_poaching.py +++ b/anti_poaching/env/anti_poaching.py @@ -1,21 +1,16 @@ -""" -Module to implement the Anti-Poaching game environment -""" +"""Module to implement the Anti-Poaching game environment""" -import logging import functools -from collections import OrderedDict +from copy import deepcopy import numpy as np -from gymnasium.spaces import Box, Dict, Discrete, MultiDiscrete, MultiBinary +import gymnasium as gym from pettingzoo.utils.env import ParallelEnv from .utils.game_utils import BaseGridState, GridStateConstProb, Trap - -# Agents are strings -AgentID = str +from .utils.typing import * # Game metadata as global metadata = { - "name": "anti_poaching_v0.2.2", + "name": "anti_poaching_v0.3", "render_modes": BaseGridState.RENDER_MODES, "is_parallelizable": True, } @@ -59,7 +54,6 @@ def parallel_env( poachers, ntraps_per_poacher, max_time, - render_mode, seed, ) @@ -76,12 +70,13 @@ class raw_env(ParallelEnv): - Observation space: - Each observation space is a composite product of smaller spaces. - This reflects the formal definition of the model. + The observations are vectors with different lower and upper bounds. + The older implementation used dictionaries, which favored readability + over usability. Assumption: - We also assume that a poacher can detect the number of their own traps in - the current cell with probability 1. + We also assume that a poacher can detect the number of their own traps + in the current cell with probability 1. - Rewards: @@ -92,6 +87,8 @@ class raw_env(ParallelEnv): Version History --------------- + v0.3 - Reimplementation to remove a lot of complex code. Functionally + equivalent to v0.2.2, modulo any bugfixes. v0.2.2 - Reward poachers when capturing a prey only, minor bugfixes, observation space now contains remaining time.
v0.2.1 - OrderedDicts for the observation spaces (see CHANGELOG) @@ -108,10 +105,10 @@ class raw_env(ParallelEnv): poachers: str, ntraps_per_poacher: int, max_time: int, - render_mode: str = "ansi", seed: int = None, ): self.rng = np.random.default_rng(seed=seed) + self.seed = seed # time properties self.max_time = max_time @@ -120,12 +117,11 @@ class raw_env(ParallelEnv): # agent parameters self.ntraps_per_poacher = ntraps_per_poacher - # agents-related properties + # agent-related properties self.poachers = poachers self.rangers = rangers self.agents = self.rangers + self.poachers self.possible_agents = self.agents[:] - self.poacher_traps = { poacher: [ Trap(name=f"trap_{i}_{poacher}") @@ -136,118 +132,101 @@ class raw_env(ParallelEnv): # Arena-related properties self.grid = grid - grid_size, nrangers, npoachers = grid.N, len(rangers), len(poachers) + nrangers, npoachers = len(rangers), len(poachers) + + # Convenience attributes + self._ranger_obs_size = 8 + nrangers + self._poacher_obs_size = 7 + # Spaces parameters. self.action_spaces = { - **{ranger: Discrete(5, seed=seed) for ranger in self.rangers}, - **{poacher: Discrete(6, seed=seed) for poacher in self.poachers}, + **{ + ranger: gym.spaces.Discrete(5, seed=seed) + for ranger in self.rangers + }, + **{ + poacher: gym.spaces.Discrete(6, seed=seed) + for poacher in self.poachers + }, } + self.observation_spaces = { **{ - poacher: Dict( + ranger: gym.spaces.Dict( { - "observations": Dict( - OrderedDict( - { - "remaining_time": Discrete(max_time + 1), - "state": Box( - low=np.array([-1, -1, 0, 0]), - high=np.array( - [ - grid_size - 1, - grid_size - 1, - ntraps_per_poacher, - np.iinfo(np.int32).max, - ] - ), - dtype=np.int32, - ), - "ground_traps": MultiDiscrete( - [ - ntraps_per_poacher + 1, - ntraps_per_poacher + 1, - ] - ), # traps found in current cell - "rangers": Discrete( - nrangers + 1, start=0 - ), # num. of rangers detected in current cell - "poachers": Discrete( - npoachers + 1, start=0 - ), # num. 
of poachers detected in current cell - } - ) + "observations": gym.spaces.Box( + np.zeros(self._ranger_obs_size), + np.array( # high + [ + max_time, # max time + *[self.grid.N] * 2, # location + *[1] * nrangers, # partner rangers + npoachers, # #captured-poachers + *[ + ntraps_per_poacher * npoachers, + np.iinfo(INTEGER).max, + ], # poacher-captured traps + *[ + ntraps_per_poacher * npoachers, + np.iinfo(INTEGER).max, + ], # grid-captured traps + ] + ), + seed=seed, + dtype=INTEGER, ), - "action_mask": MultiBinary(6), - }, - seed=seed, + "action_mask": gym.spaces.MultiBinary(5), + } ) - for poacher in self.poachers + for ranger in self.rangers }, **{ - ranger: Dict( + poacher: gym.spaces.Dict( { - "observations": Dict( - OrderedDict( - { - "remaining_time": Discrete(max_time + 1), - "state": MultiDiscrete( - [grid_size, grid_size] - ), # ranger state is location - "partners": MultiBinary( - nrangers, - ), # for the list of rangers sharing the cell - "poacher_traps": MultiDiscrete( - [ - npoachers * ntraps_per_poacher + 1, - max_time, # Max number of prey - ] - ), # For the traps and prey recovered from poacher capture - "poachers_captured": Discrete( - npoachers + 1 - ), - "ground_traps": MultiDiscrete( - [ - npoachers * ntraps_per_poacher + 1, - ] - * 2 - ), # For traps recovered from the grid - } - ) + "observations": gym.spaces.Box( + np.array([0, *[-1] * 2, *[0] * 2, 0, 0]), + np.array( + [ + max_time, # max time + *[self.grid.N] * 2, # location + ntraps_per_poacher, # #traps + np.iinfo(INTEGER).max, # #prey + nrangers, # #rangers detected + npoachers, # #poachers detected + ] + ), + seed=seed, + dtype=INTEGER, ), - "action_mask": MultiBinary(5), - }, - seed=seed, + "action_mask": gym.spaces.MultiBinary(6), + } ) - for ranger in self.rangers + for poacher in self.poachers }, } - self.total_rewards = dict.fromkeys( - self.agents, 0 - ) # stores every agent's rewards in the current episode. - self.killed_agents = [] def reset( self, seed: int = None, return_info: bool = False, options: dict = None ) -> tuple: - """Resets the environment for the next episode. If new configurations are to be set, - then we use the options dictionary as follows. + """Resets the environment for the next episode. If new configurations + are to be set, then we use the options dictionary as follows. >>> env = AntiPoachingGame( ... ) - >>> env.reset(seed=None, options={"rangers": _, "poachers": _, "ntraps_per_poacher": _}) + >>> env.reset(seed=123) - Here, a None `seed` means that the internal RNG of our GridState object is randomly - reset, and thus will generate a new position. The `options` dictionary specifies - the other parameters that can be changed.""" + Here, a None `seed` means that the internal RNG of our GridState object + is randomly reset, and thus will generate a new starting position.""" self.rng = np.random.default_rng(seed=seed) - # Reset the game parameters + # Reset the game parameters: Only override if supplied. self.curr_time = 0 + seed = seed if seed else self.seed + options = options if options is not None else {} # pass options to grid reset: - options = options if isinstance(options, dict) else {} - seed = seed if options.get("renew_config", True) else None self.grid.reset(seed=seed, **options) + # Regenerate agents. 
self.agents = self.possible_agents[:] self.poacher_traps = { poacher: [ @@ -257,100 +236,37 @@ class raw_env(ParallelEnv): for poacher in self.poachers } - # Re-seed/reinitialise all spaces - for agent in self.agents: - self.observation_spaces[agent].seed(seed) - self.action_spaces[agent].seed(seed) + obs = dict.fromkeys(self.agents) # Returning this - observations = { - **{ - poacher: { - "observations": { - "remaining_time": self.max_time, - "state": self.grid.state[poacher], - "ground_traps": np.zeros(2, dtype=np.int32), - "rangers": 0, - "poachers": 0, - }, - "action_mask": self.grid.permitted_movements(poacher), - } - for poacher in self.poachers - }, - **{ - ranger: { - "observations": { - "remaining_time": self.max_time, - "state": self.grid.state[ranger], - "partners": np.zeros( - shape=(len(self.rangers),), dtype=np.int32 - ), - "poacher_traps": np.zeros(2, dtype=np.int32), - "poachers_captured": np.zeros(1, dtype=np.int32), - "ground_traps": np.zeros(2, dtype=np.int32), - }, - "action_mask": self.grid.permitted_movements(ranger), - } - for ranger in self.rangers - }, - } + # Creating a default object to copy for both agents + _def_ranger_obs = self.observation_space("ranger_0").sample() + _def_poacher_obs = self.observation_space("poacher_0").sample() - # returning infos dictionary with reset - return observations, dict.fromkeys(self.agents, {self.curr_time}) - - def observe(self, agent: AgentID, record: dict) -> dict: - """Each agent receives observations from their current cell. These - are calculated deterministically based on their record, which is - a record of all relevant information during the transition to - their new state.""" - assert ( - "poacher" in agent or "ranger" in agent - ), f"Unknown agent passed as argument! {agent}" - - # get the complete observation first - action_mask = self.grid.permitted_movements(agent) - - if "poacher" in agent: - action_mask[5] = len(self.poacher_traps[agent]) > 0 and all( - action_mask[:2] >= 0 - ) # authorise place_trap - obs = { - "remaining_time": self.max_time - self.curr_time, - "state": self.grid.state[agent], - "ground_traps": np.array( - [record["trap_empty"], record["trap_full"]] - ), - "rangers": record["num_rangers"], - "poachers": record["num_poachers"], - } - elif "ranger" in agent: - # TODO: Update the rangers observation space - partners = np.zeros(shape=(len(self.rangers),), dtype=np.int32) - partners[record["partners"]] = 1 - - obs = { - "remaining_time": self.max_time - self.curr_time, - "state": self.grid.state[agent], - "partners": partners, - "poacher_traps": np.sum( - ( - np.array([nempty, nfull]) - for poacher, (nempty, nfull) in record[ - "poachers_found" - ].items() - ), - axis=1, - ), - "poachers_captured": len(record["poachers_found"].keys()), - "ground_traps": np.array( - [record["trap_empty"], record["trap_full"]] - ), - } + # and zero-ing them out since most init obs are zero. 
+ _def_ranger_obs["observations"] = np.zeros_like( + _def_ranger_obs["observations"], dtype=INTEGER + ) + _def_poacher_obs["observations"] = np.zeros_like( + _def_poacher_obs["observations"], dtype=INTEGER + ) + for agent in self.agents: + # Re-seed/reinitialise all spaces + self.observation_space(agent).seed(seed) + self.action_space(agent).seed(seed) + + # create appropriate observations for t=0 + _copy_obj, _size = ( + (_def_ranger_obs, 3) + if "ranger" in agent + else (_def_poacher_obs, 5) + ) + obs[agent] = deepcopy(_copy_obj) + obs[agent]["observations"][0] = self.max_time + obs[agent]["observations"][1:_size] = self.grid.state[agent] + obs[agent]["action_mask"] = self.grid.permitted_movements(agent) - # return parsed observations - return { - "observations": obs, - "action_mask": action_mask, - } + # returning infos dictionary with reset + return obs, dict.fromkeys(self.agents, {self.curr_time}) def step(self, actions: dict) -> tuple: """Receives a joint action, and sends @@ -363,77 +279,65 @@ class raw_env(ParallelEnv): truncations = dict.fromkeys(self.agents, False) infos = dict.fromkeys(self.agents, {}) - # records tracks transition information that - # is used to create the observations of s^{t+1}. - records = dict.fromkeys(self.agents) + # ... and the obs dict, with dummy action masks for now. + obs = dict.fromkeys(self.agents) for agent in self.agents: - if "ranger" in agent: - records[agent] = { - "trap_empty": 0, - "trap_full": 0, - "poachers_found": { - poacher: [0, 0] - for poacher in self.poachers # nempty, nfull - }, - "partners": [], - } - elif "poacher" in agent: - records[agent] = dict.fromkeys( - ["trap_empty", "trap_full", "num_rangers", "num_poachers"], - 0, - ) + _size = ( + self._ranger_obs_size + if "ranger" in agent + else self._poacher_obs_size + ) + obs[agent] = { + "observations": np.zeros(_size, dtype=INTEGER), + "action_mask": None, + } - # ... and now we run through the transitions ! + # Now we run through the transitions ! The obs + # dictionary is populated by each helper function. # Step 1: Rangers move first - self._rangers_move(actions, records) + self._rangers_move(actions, obs) # Step 2: Poachers move and remove their traps - self._poachers_move_and_get_traps(actions, rewards, records) + self._poachers_move_and_get_traps(actions, rewards, obs) # Step 3: Rangers remove traps and remaining traps capture animals - self._rangers_remove_traps(rewards, records) + self._rangers_remove_traps(rewards, obs) self._traps_catch_animals() # Step 4: Rangers remove poachers and remaining poachers place traps - self._rangers_remove_poachers(rewards, terminations, records) - self._poachers_place_traps(actions, terminations) + self._rangers_remove_poachers(rewards, obs) + self._poachers_place_traps(actions) - # update terminations: dict, rewards: dict, truncations and infos for the next step. + # update the obs for the agents with time, new state and action masks time_status = self.curr_time >= self.max_time for agent in self.agents: + obs[agent]["observations"][0] = self.max_time - self.curr_time + if "ranger" in agent: + obs[agent]["observations"][1:3] = self.grid.state[agent] + obs[agent]["action_mask"] = self.grid.permitted_movements( + agent + ) + elif "poacher" in agent: + _action_mask = self.grid.permitted_movements(agent) + _action_mask[5] = int(len(self.poacher_traps[agent]) > 0) + obs[agent]["observations"][1:5] = self.grid.state[agent] + obs[agent]["action_mask"] = _action_mask + + # update terminations for next step. 
Note that all agents are + # technically alive until max_time: captured poachers are just + # in a captured state. terminations[agent] |= time_status - self.total_rewards[agent] += rewards[agent] - - # update the list of killed agents this round - self.killed_agents = [ - agent - for agent in self.agents - if terminations[agent] or truncations[agent] - ] - - # calculate info for next step. - observations = { - agent: self.observe(agent, records[agent]) for agent in self.agents - } - # update the arena - self._cleanup() - - return ( - observations, - rewards, - terminations, - truncations, - infos, - ) + # Agents are terminated on the last step. + self.agents = [] if time_status else self.agents + return obs, rewards, terminations, truncations, infos def render(self): """Rendering the grid using the GridState object""" self.grid.render() def state(self) -> dict: - """Grid State is stored in the GridState class as - the `grid` attribute.""" + """State is stored as the `grid: GridState` attribute.""" return self.grid.state def _assign_reward( @@ -442,49 +346,36 @@ class raw_env(ParallelEnv): """Assigns the reward to poacher, and splits the reward among the cooperative rangers. A positive reward adds to poacher and removes proportionally from all rangers.""" - assert rewards is not None, "No reward dict to update!" - - # If assigning a 'reward' to an active player. - # This arises because rangers can recover a trap of a - # captured agent. Since he is captured, he is not - # in rewards, but will always be in total_rewards. - if poacher in rewards: - rewards[poacher] += reward - else: - logging.warn( - f"AntiPoachingGame: Assigning inactive {poacher} a reward of {reward} !!!" - ) - self.total_rewards[poacher] += reward - + rewards[poacher] += reward for ranger in self.rangers: rewards[ranger] -= reward / len(self.rangers) - def _rangers_move(self, actions: dict, records: dict) -> None: - """Helper function to move the rangers and update their records + def _rangers_move(self, actions: dict, obs: dict) -> None: + """Helper function to move the rangers and update their obs with detected partners in the same cell.""" for ranger in [r for r in self.rangers if 1 <= actions[r] <= 4]: self.grid.update_position(ranger, actions[ranger]) for nbor_ranger in [ r - for r in self.rangers - if r != ranger - and all(self.grid.state[r] == self.grid.state[ranger]) + for r in self.grid.get_neighbours(ranger) + if r != ranger and "ranger" in r ]: - # add the ranger number to the records - records[ranger]["partners"].append( - int(nbor_ranger.split("_")[-1]) - ) + # add the ranger number to the obs + obs[ranger]["observations"][ + 2 + int(nbor_ranger.split("_")[-1]) + ] = 1 def _poachers_move_and_get_traps( - self, actions: dict, rewards: dict, records: dict + self, actions: dict, rewards: dict, obs: dict ) -> None: """Helper function: Moves the poachers according to actions and removes their traps (if found) on the next step.""" for poacher in [ p for p in self.poachers - if p in self.agents and 0 <= actions[p] <= 4 + if 0 <= actions[p] <= 4 and self.grid.state[p][0] >= 0 ]: + # Skipping over captured poachers ... self.grid.update_position(poacher, actions[poacher]) for _trap in [ _t @@ -502,20 +393,18 @@ class raw_env(ParallelEnv): poacher, self.grid.remove_trap(_trap), rewards ) - # update trap records, and + # update trap obs, and # reset trap value when recovered. 
- key = "trap_empty" if _trap.value == 0 else "trap_full" - records[poacher][key] += 1 + key = 3 if _trap.value == 0 else 4 + obs[poacher]["observations"][key] += 1 _trap.value = 0 - # Update records for positions as well + # Update obs for positions as well for nbor in self.grid.get_neighbours(poacher): - if "poacher" in nbor: - records[poacher]["num_poachers"] += 1 - else: - records[poacher]["num_rangers"] += 1 + key = 5 if "ranger" in nbor else 6 + obs[poacher]["observations"][key] += 1 - def _rangers_remove_traps(self, rewards: dict, records: dict): + def _rangers_remove_traps(self, rewards: dict, obs: dict): """Helper function where rangers detect and remove traps in their current cells. Note that detection depends on self.prob_detect_trap.""" @@ -523,8 +412,8 @@ class raw_env(ParallelEnv): # First all agents detect traps. traps_detected = set() # Multiple agents can detect same trap for _trap in [_t for _t in self.grid.state if isinstance(_t, Trap)]: + _loc = self.grid.state[_trap] # also the nbor.rangers locations for ranger in self.grid.get_neighbours(_trap): - _loc = self.grid.state[ranger] if self.rng.random() < self.grid.prob_detect_trap(_loc): traps_detected.add(_trap) @@ -532,67 +421,67 @@ class raw_env(ParallelEnv): # Extract the owning poacher name _poacher = "_".join(_trap.name.split("_")[-2:]) - # updating the records for all implicated rangers - key = "trap_empty" if _trap.value == 0 else "trap_full" + # updating the obs for all implicated rangers + # If trap value is zero(before reset/capture, it was empty) + key = -2 if _trap.value == 0 else -1 for ranger in self.grid.get_neighbours(_trap): - records[ranger][key] += 1 + obs[ranger]["observations"][key] += 1 - # Assign rewards to Rangers + # Assign rewards to Rangers. Includes trap removal logic. self._assign_reward( _poacher, -self.grid.remove_trap(_trap), rewards ) - def _rangers_remove_poachers( - self, rewards: dict, terminations: dict, records: dict - ): - """Helper function where rangers detect and - remove poachers in their current cell. - Note that detection depends on self.prob_detect_cell""" + def _rangers_remove_poachers(self, rewards: dict, obs: dict): + """Helper function where rangers detect and remove poachers in their + current cell. Detection depends on self.prob_detect_cell""" # First mark all captured poachers - caught_poachers = [] + caught_poachers = set() for _poacher in self.poachers: - if _poacher not in self.agents or terminations[_poacher]: + if self.grid.state[_poacher][1] < 0: continue # Agent is already caught, skip - for _ranger in [ # Otherwise, try detecting + _loc = self.grid.state[_poacher] + for _ranger in [ _r for _r in self.grid.get_neighbours(_poacher) if "ranger" in _r + and self.rng.random() < self.grid.prob_detect_cell(_loc) ]: - _loc = self.grid.state[_ranger] - if self.rng.random() < self.grid.prob_detect_cell(_loc): - caught_poachers.append(_poacher) + caught_poachers.add(_poacher) # Poacher detected. 
- # update their status, and all ranger records + # update their status, and all ranger obs for _poacher in caught_poachers: - terminations[_poacher] = True - # EXPERIMENTAL: penalty for Poachers for each prey recovered - penalty = self.grid.state[_poacher][-1] * self.grid.REWARD_MAP[ - "PREY_FOUND" - ] + self.grid.remove_poacher(_poacher) + # NOTE: read the prey/trap counts before remove_poacher resets the state. + penalty = ( + self.grid.state[_poacher][-1] # C_prey + * self.grid.REWARD_MAP["PREY_FOUND"] + + self.grid.state[_poacher][-2] # C_trap + * self.grid.REWARD_MAP["TRAP_FOUND"] + + self.grid.remove_poacher(_poacher) # C_capture + ) self._assign_reward(_poacher, -penalty, rewards) - # updating records for implicated rangers + # Update the trap and prey status for implicated rangers: + # First increment #of caught poachers, then #traps/prey captured. for _ranger in [ _r for _r in self.grid.get_neighbours(_poacher) if "ranger" in _r ]: - # Update the trap and prey status - records[_ranger]["poachers_found"][ + obs[_ranger]["observations"][-5] += 1 + obs[_ranger]["observations"][-4:-2] += self.grid.state[ _poacher - ] += self.grid.state[_poacher][2:] + ][2:] - def _poachers_place_traps(self, actions: dict, terminations: dict): - """Helper function where poachers place traps. - Note that this will not succeed if poacher has - no traps to place.""" + def _poachers_place_traps(self, actions: dict): + """Helper function where poachers place traps. Note that this will not + succeed if poacher has no traps to place.""" for poacher in [ p for p in self.poachers - if (p in self.agents) and (not terminations[p]) and actions[p] == 5 + if self.grid.state[p][1] >= 0 and actions[p] == 5 ]: - trap = self.poacher_traps[poacher].pop() + trap = self.poacher_traps[poacher].pop() # Will throw if empty. self.grid.add_trap(trap, self.grid.state[poacher][0:2]) self.grid.state[poacher][2] -= 1 assert self.grid.state[poacher][2] == len( @@ -605,19 +494,10 @@ class raw_env(ParallelEnv): for trap in [t for t in self.grid.state if isinstance(t, Trap)]: _loc = self.grid.state[trap] if ( - self.rng.random() - < self.grid.prob_animal_appear(_loc) * trap.efficiency + self.rng.random() < self.grid.prob_animal_appear(_loc) and trap.value == 0 ): - trap.value += 1 - - def _cleanup(self) -> None: - """Update the set of live agents, and their position - on the grid.""" - for agent in self.killed_agents: - # clean GridState lists - del self.grid.state[agent] - self.agents.remove(agent) + trap.value = 1 @functools.lru_cache(maxsize=None) def observation_space(self, agent): diff --git a/anti_poaching/env/utils/flatten_utils.py b/anti_poaching/env/utils/flatten_utils.py deleted file mode 100644 index 500f182ff4513ef3882cfc1cc9cf58015d32574a..0000000000000000000000000000000000000000 --- a/anti_poaching/env/utils/flatten_utils.py +++ /dev/null @@ -1,132 +0,0 @@ -"""Module to implement useful functions like box_flatten -and box_flatten_obs. These will be used in the -NonCategoricalFlatten wrapper, as defined in `wrappers.py`. -""" - -from functools import singledispatch -import numpy as np -import gymnasium -from gymnasium.spaces import ( - Box, - Dict, - Discrete, - MultiDiscrete, - MultiBinary, -) - - -# Custom Flattening logic -@singledispatch -def box_flatten(space: gymnasium.Space, dtype=None) -> Box: - """ - SingleDispatch function to recursively fold a - space into a Box: Currently supports only Box, - Discrete, MultiDiscrete and Dict spaces.
- """ - raise NotImplementedError(f"Unknown/Unsupported space: {space=}") - - -@box_flatten.register(Dict) -def flatten_dict(space: Dict, dtype=None) -> Box: - """ - Function to fold a Dict space into a Box. - """ - # Recursively flatten all the spaces in the keys - list_boxed_spaces = [box_flatten(subsp) for subsp in space.values()] - - # return the Box with the concatenated shapes - return Box( - low=np.concatenate([subsp.low for subsp in list_boxed_spaces]), - high=np.concatenate([subsp.high for subsp in list_boxed_spaces]), - dtype=np.int32, - ) - - -@box_flatten.register(Discrete) -def flatten_discrete(space: Discrete, dtype=None) -> Box: - """Flatten a discrete box, but not as a categorical - space. We consider Discrete as a Box of ints. - """ - return Box( - low=0, high=space.n - 1, dtype=space.dtype if not dtype else dtype - ) - - -@box_flatten.register(MultiDiscrete) -def flatten_multidiscrete(space: MultiDiscrete, dtype=None) -> Box: - """Flatten a MultiDiscrete box, but not as a categorical - space. We consider MultiDiscrete as a multi-dimensional - Box of ints. - """ - return Box( - low=np.zeros_like(space.nvec).flatten(), - high=np.array(space.nvec).flatten() - 1, - dtype=space.dtype if not dtype else dtype, - ) - - -@box_flatten.register(Box) -def flatten_box(space: Box, dtype=None) -> Box: - """Flatten of a box should just return the box if it is 1D, - else the equivalent of running np.flatten on its samples.""" - return Box( - low=np.array(space.low).flatten(), - high=np.array(space.high).flatten(), - dtype=space.dtype if not dtype else dtype, - ) - - -@box_flatten.register(MultiBinary) -def flatten_multibinary(space: MultiBinary, dtype=None) -> Box: - """Convert a Binary to a Box""" - return Box( - low=0, - high=1, - shape=np.array(space.shape).flatten(), - dtype=space.dtype if not dtype else dtype, - ) - - -def box_flatten_obs(obs, dtype=None) -> np.array: - """ - Box flattens an observation recursively. - """ - if isinstance(obs, dict): - # Recursively fold. 
- return np.concatenate( - [box_flatten_obs(val).flatten() for val in obs.values()] - ) - - dtype = np.int32 if not dtype else dtype # Default to integer type - if isinstance(obs, (np.ndarray, list)) or np.isscalar(obs): - # Return the flattened version - return np.array(obs, dtype=dtype).flatten() - - raise NotImplementedError(f"Unknown observation sent: {obs}") - - -# Custom unflattening logic -def sizeof_space(space: gymnasium.Space): - """gets the number of elements in the space""" - if isinstance(space, Dict): - return np.sum([sizeof_space(subsp) for subsp in space.values()]) - return np.prod(np.array(space.shape).flatten().astype(np.int64)) - - -def parse_obs(obs: np.array, space: gymnasium.Space): - """Parse the observation into the original space.""" - - if isinstance(space, (Box, Discrete, MultiDiscrete, MultiBinary)): - return obs.reshape(space.shape).astype(space.dtype) - if not isinstance(space, Dict): - raise NotImplementedError - - # Now to implement parsing for dictionaries - begin, end = 0, 0 - new_obs = dict.fromkeys(space.keys()) - for key, subsp in space.items(): - end = begin + sizeof_space(subsp) # Define the chunk to parse - new_obs[key] = parse_obs(obs[begin:end], subsp) - begin = end # update the begin position for next subspace - - return new_obs diff --git a/anti_poaching/env/utils/game_utils.py b/anti_poaching/env/utils/game_utils.py index 0d25452cee00f4e60543f8f3208689975933b0fd..dc0a04dfc157fd591adb45fd709d5403f78ac3cf 100644 --- a/anti_poaching/env/utils/game_utils.py +++ b/anti_poaching/env/utils/game_utils.py @@ -9,6 +9,8 @@ from dataclasses import dataclass import numpy as np import pygame +from .typing import * + FPS = 10 BLK_SIZE = 30 @@ -27,7 +29,6 @@ class Trap: an animal has been caught in it.""" name: str = "Trap" - efficiency: float = 0.40 value: float = 0 def __hash__(self): @@ -40,10 +41,10 @@ class Trap: class BaseGridState(ABC): - """Abstract Class to represent the game state for the Anti-Poaching Game.""" + """AbstractClass to represent game state of an Anti-Poaching Game.""" REWARD_MAP: dict = {"TRAP_FOUND": 2, "PREY_FOUND": 2, "POACHER_FOUND": 100} - NULL_POS: np.array = np.array([-1, -1, 0, 0], dtype=np.int16) + NULL_POS: np.array = np.array([-1, -1, 0, 0], dtype=INTEGER) RENDER_MODES = ["ansi", "rgb"] def __init__( @@ -97,10 +98,9 @@ class BaseGridState(ABC): raise NotImplementedError def get_neighbours(self, agent: str) -> tuple: - """Given an agent, fetch agents in the same cell. - If supplied a Trap, this will fetch all non-Trap - objects (i.e. agents) which are in the same cell. - """ + """Given an agent, fetch agents in the same cell. If supplied a Trap, + this will fetch all non-Trap objects (i.e. agents) which are in the + same cell.""" agent_pos = self.state[agent][0:2] return [ _a @@ -117,13 +117,12 @@ class BaseGridState(ABC): """Tries to remove an agent from the supplied position. Returns a positive reward if successful, else throws an exception.""" assert poacher in self.state, f"Unauthorised remove !!!\n{self}" - self.state[poacher] = self.NULL_POS + self.state[poacher] = deepcopy(self.NULL_POS) return self.REWARD_MAP["POACHER_FOUND"] def remove_trap(self, trap: Trap) -> float: - """Tries to remove a Trap from the supplied position. 
- Returns the trap's value if successful, else throws an exception.""" - assert trap in self.state, f"Unauthorised remove !!!\n{self}" + """Removes a Trap from the grid.""" + assert trap in self.state, f"Unauthorised trap remove !!!\n{self}" del self.state[trap] return trap.value @@ -132,7 +131,8 @@ seed: int = None, **options, ) -> None: - """Resets the grid to initial position""" + """Resets the grid to initial position. This should use the seed + supplied to the parent anti_poaching.raw_env instance.""" self.rng = np.random.default_rng(seed) # Get options from the kwargs, and reset @@ -149,15 +149,11 @@ poachers: list, ntraps_per_poacher: int, ) -> None: - """Helper function to reinitialise the state - using the RNG engine. This replaces the use of self.init_state, - since the same seed generates the same initial state.""" - - # Regenerate state using the RNG + """Helper function to reinitialise the state using the RNG engine""" self.state = { **{ ranger: self.rng.integers( - low=0, high=self.N, size=2, dtype=np.int16 + low=0, high=self.N, size=2, dtype=INTEGER ) for ranger in rangers }, @@ -165,10 +161,10 @@ poacher: np.concatenate( ( self.rng.integers( - low=0, high=self.N, size=2, dtype=np.int16 + low=0, high=self.N, size=2, dtype=INTEGER ), # random location np.array( - [ntraps_per_poacher, 0], dtype=np.int16 + [ntraps_per_poacher, 0], dtype=INTEGER ), # (..., ntraps, npreys) ) ) @@ -198,6 +194,8 @@ pygame.draw.rect(self.screen, WHITE, rect, 1) for obj, pos in self.state.items(): + if pos[0] < 0: + continue width = 0 if isinstance(obj, Trap): color = GREY @@ -214,9 +212,15 @@ """Rendering as text""" print("\nCurrent state:") print(f"\tGrid size: {self.N}") + printgrid = {(i, j): [] for i in range(self.N) for j in range(self.N)} - for thing, loc in self.state.items(): - printgrid[tuple(loc[0:2].tolist())].append(thing) + for thing, pos in self.state.items(): + if pos[0] < 0: + continue + printgrid[tuple(pos[0:2].tolist())].append(thing) for i in range(self.N): for j in range(self.N): @@ -230,10 +234,7 @@ class GridStateConstProb(BaseGridState): State is represented as a dictionary of player-(location-misc) pairs. - v3. Promoted to GridStateConstProb - this is an abstract class. - Also takes over the responsibility of probabilties in - the grid. This allows us to define custom GridState objects, - especially when using custom probabilities. + v3. Promoted to GridStateConstProb - this implements BaseGridState. v2. Deprecates the grid dictionary for simpler code. The GridEnv class is now upgraded to the GridState class, @@ -293,26 +294,23 @@ """For an agent, return the permitted actions she can take in her current cell. Returns an numpy.array that represents an action mask.""" - if "ranger" in agent: - size_action_mask = 5 - elif "poacher" in agent: - size_action_mask = 6 - action_mask = np.zeros(size_action_mask, dtype=np.int8) + size_action_mask = 5 if "ranger" in agent else 6 + action_mask = np.zeros(size_action_mask, dtype=SHORT) action_mask[0] = 1 # always validate null-action # Find the agent position x, y = self.state[agent][0:2] if x == -1 or y == -1: - return action_mask # If agent is dead + return action_mask # If agent is captured, nothing further.
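# Illustrative sketch of how the SHORT-typed mask returned by
# permitted_movements() is consumed: gymnasium's Discrete.sample accepts an
# np.int8 mask, which is how the heuristic policies in examples/ draw only
# legal actions. The mask values below are hand-written for illustration.
import numpy as np
import gymnasium as gym

poacher_space = gym.spaces.Discrete(6, seed=0)      # null, up, left, down, right, place_trap
mask = np.array([1, 0, 1, 1, 0, 1], dtype=np.int8)  # e.g. from permitted_movements(...)
action = poacher_space.sample(mask=mask)            # only indices with mask == 1 can be drawn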
- # calculate permitted actions + # calculate permitted actions for "live" agents UP = 0 if x <= 0 else 1 DOWN = 0 if x >= self.N - 1 else 1 LEFT = 0 if y <= 0 else 1 RIGHT = 0 if y >= self.N - 1 else 1 # Only setting movement flags. Further actions - # are set by the AntiPoachingGame.observe fn. + # are set by the anti_poaching.raw_env instance. action_mask[1:5] = [UP, LEFT, DOWN, RIGHT] return action_mask @@ -334,10 +332,12 @@ class GridStateConstProb(BaseGridState): # Throw error if invalid position generated # since it needed an invalid action - assert self.valid_pos( - new_pos - ), f"update_position({agent}): Invalid action. {(x,y)=} -> {action=} -> {new_pos=}" - + assert self.valid_pos( + new_pos + ), f"update_position({agent}): Invalid action. {(x,y)=} -> {action=} -> {new_pos=}, grid size = {self.N}" # update grid and state list self.state[agent][0:2] = new_pos @@ -427,7 +427,7 @@ class GridStateVaryingProb(BaseGridState): size_action_mask = 5 elif "poacher" in agent: size_action_mask = 6 - action_mask = np.zeros(size_action_mask, dtype=np.int8) + action_mask = np.zeros(size_action_mask, dtype=SHORT) action_mask[0] = 1 # always validate null-action # Find the agent position diff --git a/anti_poaching/env/utils/typing.py b/anti_poaching/env/utils/typing.py new file mode 100644 index 0000000000000000000000000000000000000000..a1bdb3c59e93cd2a7b56171a524b5fc07277764c --- /dev/null +++ b/anti_poaching/env/utils/typing.py @@ -0,0 +1,7 @@ +"""Module to standardise custom types used for the env and other scripts""" + +import numpy as np + +INTEGER = np.int32 # Core integer type +SHORT = np.int8 # Action-mask types +AgentID = str # Agents are strings diff --git a/anti_poaching/env/utils/wrappers.py b/anti_poaching/env/utils/wrappers.py index 419962f8f416e9043573e417974046bef58bf2da..90761d00c45aeb242be63befd88de9b73d305486 100644 --- a/anti_poaching/env/utils/wrappers.py +++ b/anti_poaching/env/utils/wrappers.py @@ -1,6 +1,5 @@ -"""Module that implements the NonCategoricalFlatten and -Stacker Wrappers that ConservationGame needs. -""" +"""Module that implements the wrappers to make the Anti-Poaching game +compatible with RLlib.""" import numpy as np @@ -12,241 +11,56 @@ from pettingzoo.utils.env import ParallelEnv # Ray imports from ray.rllib.env.multi_agent_env import MultiAgentEnv -# local imports -from .flatten_utils import box_flatten, box_flatten_obs - - -class NonCategoricalFlatten(ObservationWrapper): - """ - Class that takes an environment and flattens the - observation without using Categorical Encoding. - This is currently implemented for Discrete, - MultiDiscrete, Box and Dict environments only for - ParallelEnv environments. - """ - - def __init__(self, env): - ObservationWrapper.__init__(self, env) - - # Store new space in self._observation_spaces - self.observation_spaces = { - agent: Dict( - { - "action_mask": box_flatten( - env.observation_spaces[agent]["action_mask"], - ), - "observations": box_flatten( - env.observation_spaces[agent]["observations"] - ), - } - ) - for agent in self.env.agents - } - - def observation(self, observation): - """ - Applies the modification to all observations - sent out by step and reset i.e. - env.step -> observation(obs), rewards, ...
- env.reset-> observation(obs), info - """ - return { - agent: { - "action_mask": box_flatten_obs( - observation[agent]["action_mask"], - dtype=np.int8, - ), - "observations": box_flatten_obs( - observation[agent]["observations"] - ), - } - for agent in observation - } - - def reset(self, *, seed=None, options=None) -> tuple: - """ - Modifies the :attr:`env` after calling :meth:`reset`, - returning a modified observation using :meth:`self.observation`. - """ - reset_obs, infos = self.env.reset(seed=seed, options=options) - return self.observation(reset_obs), infos - - def observation_space(self, agent): - """ - Overrides the observation_space method of - the environment and sends the modified spaces - """ - return self.observation_spaces[agent] - - @property - def unwrapped(self): - """Returns the unwrapped environment""" - _unwrapped = getattr(self.env, "unwrapped", self.env) - if callable(_unwrapped): - return _unwrapped() - return _unwrapped - - -class StackerWrapper(MultiAgentEnv): - """A wrapper for PettingZoo ParallelEnv's where each agent has distinct - action and observation spaces. This wrapper stacks all these spaces into - single dictionaries keyed by the agent IDs. For example, consider - - agent0: obs=(10,), act=Discrete(2) - agent1: obs=(20,), act=Discrete(3) - - Then, we construct an action_space as the dictionary - - action_space = Dict( {agent0: ..., agent1: ... } ) - - Similarly, we construct the observation_space variable. This allows us to - not lose semantic meaning when we define custom wrappers, and hopefully - allow RLlib to optimise happily :) - """ - - def __init__(self, env: ParallelEnv): - super().__init__() - self.env = env - self.agents = env.agents - self._agent_ids = set(self.agents) - - # Provide full (preferred format) observation- and action-spaces as Dicts - # mapping agent IDs to the individual agents' spaces. - self._obs_space_in_preferred_format = True - self.observation_space = Dict( - {agent: self.env.observation_space(agent) for agent in self.agents} - ) - - self._action_space_in_preferred_format = True - self.action_space = Dict( - {agent: self.env.action_space(agent) for agent in self.agents} - ) - - def reset(self, *, seed=None, options=None): - """Resets the base environment""" - return self.env.reset(seed=seed, options=options) - - def step(self, action_dict): - """ - Completely delegate responsibility to ParallelEnv, since it returns the - same format of obs, rewards, terminations, truncations, infos. - """ - # Get environment to step - obs, rew, terminateds, truncateds, info = self.env.step(action_dict) - - # RLlib-specific keys - terminateds["__all__"] = all(terminateds.values()) - truncateds["__all__"] = all(truncateds.values()) - - # return processed dictionaries - return obs, rew, terminateds, truncateds, info - - @property - def unwrapped(self): - """Returns the unwrapped environment""" - _unwrapped = getattr(self.env, "unwrapped", self.env) - if callable(_unwrapped): - return _unwrapped() - return _unwrapped - class QMIXCompatibilityLayer(MultiAgentEnv): - """ - Class that renames the "observations" to "obs", - and pads with observations of terminated agents. - This is for compatibility with QMIX. - """ + """Class that renames the "observations" to "obs". 
This is for + compatibility with QMIX.""" def __init__(self, env): super().__init__() self.env = env - # Store new space in self._observation_spaces self._obs_space_in_preferred_format = True self.observation_space = { agent: Dict( { - "action_mask": env.observation_space[agent]["action_mask"], - "obs": env.observation_space[agent]["observations"], + "action_mask": env.observation_space(agent)["action_mask"], + "obs": env.observation_space(agent)["observations"], } ) - for agent in self.env.agents + for agent in env.agents } self._action_space_in_preferred_format = True self.action_space = env.action_space - def observation(self, obs, missing_keys: set = None): - """ - Applies the modification to all observations - sent out by step and reset i.e. + def observation(self, obs: dict): + """Applies the modification to all observations sent out by step + and reset i.e. env.step -> observation(obs), rewards, ... env.reset-> observation(obs), info """ - # QMIX cannot handle disappearing agents :(( - if missing_keys is None: - missing_keys = set() - for agent in self.observation_space.keys(): - if agent in missing_keys: - # Need to pad for this agent + rename - obs[agent] = self.env.observation_space[agent].sample() - obs[agent]["action_mask"] = np.array( - [1, 0, 0, 0, 0, 0], dtype=np.int8 - ) # First action is always valid + only poachers die - - # Copy time and set invalid state(location). Zero everything else. - obs[agent]["obs"] = np.zeros_like(obs[agent]["observations"]) - obs[agent]["obs"][0:3] = [ - self.env.unwrapped.max_time - self.env.unwrapped.curr_time, - -1, - -1, - ] - del obs[agent]["observations"] # Delete the invalid key - else: - # Only rename for this agent, if not killed yet - obs[agent] = { - "action_mask": obs[agent]["action_mask"], - "obs": obs[agent]["observations"], - } - - # Return parsed obs + obs[agent]["obs"] = obs[agent].pop("observations") return obs def step(self, action_dict: dict) -> tuple: """Pads each returned dictionary to contain dead agents as well""" - # Terms and truncs will be overwritten anyway - obs, rewards, _, _, infos = self.env.step(action_dict) - - missing_keys = set( - key for key in self.observation_space.keys() if key not in rewards - ) - obs = self.observation(obs, missing_keys) - rewards.update(dict.fromkeys(missing_keys, 0)) - infos.update(dict.fromkeys(missing_keys, {})) - - # Only final step must be True for EpisodeV2 compatibility - terms = dict.fromkeys( - self.observation_space.keys(), - self.env.unwrapped.max_time == self.env.unwrapped.curr_time, - ) - truncs = dict.fromkeys(terms.keys(), False) + obs, rewards, terms, truncs, infos = self.env.step(action_dict) # RLlib-specific keys terms["__all__"] = all(terms.values()) truncs["__all__"] = all(truncs.values()) - return obs, rewards, terms, truncs, infos + return self.observation(obs), rewards, terms, truncs, infos def reset(self, *, seed=None, options=None) -> tuple: - """ - Modifies the :attr:`env` after calling :meth:`reset`, - returning a modified observation using :meth:`self.observation`. 
- """ + """Modifies the :attr:`env` after calling :meth:`reset`, + returning a modified observation using :meth:`self.observation`.""" reset_obs, infos = self.env.reset(seed=seed, options=options) return self.observation(reset_obs), infos @property - def unwrapped(self): + def get_sub_environments(self): """Returns the unwrapped environment""" _unwrapped = getattr(self.env, "unwrapped", self.env) if callable(_unwrapped): diff --git a/examples/manual_policies/fixed_policy.py b/examples/manual_policies/fixed_policy.py index 4ed8befb5e7ae59535bb2354211b3d38c920ab43..412a8b66e805850bf6a4f4a2b61a5b26e6ffac34 100644 --- a/examples/manual_policies/fixed_policy.py +++ b/examples/manual_policies/fixed_policy.py @@ -111,6 +111,3 @@ if __name__ == "__main__": print("\n GAME OVER !\n") print(cg.curr_time, " is the current time") print(cg.max_time, " is the maximum time") - print("Rewards: ") - for agent in cg.possible_agents: - print(agent, " has ", cg.total_rewards[agent]) diff --git a/examples/manual_policies/random_policy.py b/examples/manual_policies/random_policy.py index 75bf33bab721619cd82beb19709503580f2bd91a..d252ea619d860152263ef48e6f10a88e53693a7c 100644 --- a/examples/manual_policies/random_policy.py +++ b/examples/manual_policies/random_policy.py @@ -59,6 +59,3 @@ if __name__ == "__main__": print("\n GAME OVER !\n") print(cg.curr_time, " is the current time") print(cg.max_time, " is the maximum time") - print("Rewards: ") - for agent in cg.possible_agents: - print(agent, " has ", cg.total_rewards[agent]) diff --git a/examples/rllib_examples/callbacks.py b/examples/rllib_examples/callbacks.py index 6c7f9e8bbfecbd8ce229e124f2be987812dd02dc..f94b461f0d82c48a2863e60fcea518c50f72f697 100644 --- a/examples/rllib_examples/callbacks.py +++ b/examples/rllib_examples/callbacks.py @@ -26,6 +26,20 @@ from ray.rllib.evaluation import Episode, RolloutWorker from ray.rllib.algorithms.callbacks import DefaultCallbacks +class SumRangerRewardMetric(DefaultCallbacks): + """Callback to compute a custom metric: the sum of all ranger + rewards. This is stored as a custom metric: rangers_mean_reward_sum""" + + def on_train_result(self, *, algorithm, result: dict, **kwargs): + result["custom_metrics"]["rangers_mean_reward_sum"] = np.sum( + [ + v + for k, v in result["policy_reward_mean"].items() + if "ranger" in k + ] + ) + + class RestoreNonTrainingAgents(DefaultCallbacks): """ Restores the weights of an algorithm on init @@ -121,7 +135,7 @@ class EpisodeMetricsCallbacks(DefaultCallbacks): # Get base environment env = base_env._unwrapped_env - # cjjollect data + # collect data # First collecting number of traps held by each poacher. 
for poacher in env.poachers: episode.custom_metrics["num_traps_" + poacher].append( diff --git a/examples/rllib_examples/configs.py b/examples/rllib_examples/configs.py index 25a0aa9c2855694c7f890331b78ac27978faf910..59eb47a44b766b50631925416fab1d1743297fca 100644 --- a/examples/rllib_examples/configs.py +++ b/examples/rllib_examples/configs.py @@ -16,8 +16,7 @@ from ray.rllib.env.wrappers.group_agents_wrapper import GroupAgentsWrapper # Importing my enviroment from anti_poaching.anti_poaching_v0 import ( anti_poaching, - rllib_cons_game, - StackerWrapper, + grouped_rllib_cons_game, ) # Constant batch size for all algorithms @@ -27,7 +26,7 @@ _TRAIN_BATCH_SIZE = _NUM_EPISODES_PER_ITER * _LEN_EPISODE def generate_training_config( - env: StackerWrapper, + env: "AntiPoaching", env_config: dict, policies_to_train: list = None, hyperparams: dict = None, @@ -57,12 +56,7 @@ def generate_training_config( # not desired. !!! disable_env_checking=True, ) - .experimental( - # Otherwise the observation dict will - # be flattened to a numpy array. Plus, - # we're already doing our own preprocessing. - _disable_preprocessor_api=True - ) + # .experimental(_disable_preprocessor_api=True) .framework("torch") .training( model={ @@ -71,15 +65,15 @@ def generate_training_config( "no_masking": False }, # Required by the example }, - **hyperparams, # Only used when set + **hyperparams, # Only non-empty when set ) .resources(num_gpus=num_gpus) .multi_agent( policies={ agent: PolicySpec( policy_class_selector(agent), - observation_space=env.observation_space[agent], - action_space=env.action_space[agent], + observation_space=env.observation_space(agent), + action_space=env.action_space(agent), ) for agent in env.agents }, @@ -90,7 +84,7 @@ def generate_training_config( def ppo_config( - env: StackerWrapper, + env: "AntiPoaching", env_config: dict = None, policies_to_train=None, num_gpus: int = 0, @@ -136,7 +130,7 @@ def ppo_config( def pg_config( - env: StackerWrapper, + env: "AntiPoaching", env_config: dict = None, policies_to_train=None, num_gpus: int = 0, @@ -179,6 +173,7 @@ def qmix_config( # None PolicyClass -> RLlib automatically infers policy_class_selector = lambda agent: None + env = grouped_rllib_cons_game(env, env_config) # Group the env return ( QMixConfig() .framework("torch") diff --git a/examples/rllib_examples/example_utils.py b/examples/rllib_examples/example_utils.py index 9b196608d3b1c498411f2979ea326f8de92bceb4..2ec88a4e185bc4a6bbdd6ea552c157ebb7e86ead 100644 --- a/examples/rllib_examples/example_utils.py +++ b/examples/rllib_examples/example_utils.py @@ -1,15 +1,10 @@ -""""Defines utils for the examples, like argument parsers, -and a method to define an RLlib ready instance using args. 
-""" +""""Defines utils for the examples, like argument parsers, and a method to +define an RLlib ready instance using args.""" import pathlib import argparse import json -from anti_poaching.anti_poaching_v0 import ( - anti_poaching, - rllib_cons_game, - grouped_rllib_cons_game, -) +from anti_poaching.anti_poaching_v0 import anti_poaching def get_args() -> argparse.ArgumentParser: @@ -38,7 +33,4 @@ def define_game_from_args(args: argparse.Namespace) -> tuple: "prob_animal_appear": args.prob_anim, "max_time": args.max_time, } - env_generator = ( - grouped_rllib_cons_game if args.algo == "QMIX" else rllib_cons_game - ) - return env_generator(None, config=cg_config), cg_config + return anti_poaching.parallel_env(**cg_config), cg_config diff --git a/examples/rllib_examples/from_checkpoint.py b/examples/rllib_examples/from_checkpoint.py index 9039ff77b6cee7b0ff9c124d6a8480cde1c2b9d4..5eb8124bb93c8b40944fecb8cc6da14ce19baaa3 100644 --- a/examples/rllib_examples/from_checkpoint.py +++ b/examples/rllib_examples/from_checkpoint.py @@ -9,7 +9,6 @@ For example, to launch from the checkpoint we can simply use python from_checkpoint.py <CHPOINT_DIR> - """ import json @@ -21,11 +20,7 @@ from ray.tune.logger import pretty_print from ray.rllib.algorithms.algorithm import Algorithm from ray.rllib.algorithms.qmix import QMixConfig, QMix -# Importing my enviroment -from anti_poaching.anti_poaching_v0 import ( - anti_poaching, - rllib_cons_game, -) +from anti_poaching.anti_poaching_v0 import anti_poaching def get_actions(algo: Algorithm, obs: dict) -> list: @@ -47,7 +42,7 @@ def get_actions(algo: Algorithm, obs: dict) -> list: observation=obs[agent], state=state[idx], policy_id=agent, - timestep=env.env.unwrapped.curr_time, + timestep=obs[agent][0]["obs"][0], ) # This is required since compute_single_actions can return @@ -77,7 +72,11 @@ if __name__ == "__main__": env = algo.env_creator(env_config) # Recall that QMix uses the GroupAgentsWrapper for any env. - _base_env = env.env.unwrapped if isinstance(algo, QMix) else env.unwrapped + _base_env = ( + env.env.get_sub_environments + if isinstance(algo, QMix) + else env.get_sub_environments + ) # The calm before the storm ... obs, info = env.reset() diff --git a/examples/rllib_examples/heuristic_policies.py b/examples/rllib_examples/heuristic_policies.py index 07b8a6a5473854595fea2bed5e71b21ef234c7db..8af196fe78c900d6e3a75e0cd32b1003ae78086a 100644 --- a/examples/rllib_examples/heuristic_policies.py +++ b/examples/rllib_examples/heuristic_policies.py @@ -5,6 +5,7 @@ for the Anti Poaching game. 
""" from typing import Generator +from collections import OrderedDict import pathlib import numpy as np import torch @@ -43,9 +44,12 @@ class ActionMaskedRandomPolicy(Policy): episodes=None, **kwargs, ): - action_mask = np.array( - obs_batch["action_mask"], dtype=np.int8 - ).flatten() + if isinstance(obs_batch, OrderedDict): + action_mask = np.array( + obs_batch["action_mask"], dtype=np.int8 + ).flatten() + else: + action_mask = obs_batch[0][: self.action_space.n].astype(np.int8) return ( [self.action_space.sample(action_mask)], state_batches, @@ -186,9 +190,13 @@ class ActionMaskedPlanningPoacherPolicy(Policy): episodes=None, **kwargs, ): + if isinstance(obs_batch, OrderedDict): + flat_obs = obs_batch["observations"].flatten() + flat_mask = obs_batch["action_mask"].flatten() + else: + flat_obs = obs_batch.flatten() + flat_mask, flat_obs = flat_obs[:6], flat_obs[6:] - flat_obs = obs_batch["observations"].flatten() - flat_mask = obs_batch["action_mask"].flatten() rem_time, curr_pos = flat_obs[0], flat_obs[1:3] if all(curr_pos == [-1, -1]): diff --git a/examples/rllib_examples/main.py b/examples/rllib_examples/main.py index 94301cbf670c99f61629ac1cfbcba84b60398b1e..5d57f6b5fb5f83db54d5be487263dd11b9b87d35 100644 --- a/examples/rllib_examples/main.py +++ b/examples/rllib_examples/main.py @@ -33,10 +33,7 @@ from ray.rllib.algorithms.qmix import QMixConfig, QMix from ray.rllib.algorithms.maddpg import MADDPGConfig, MADDPG # Importing custom code -from anti_poaching.anti_poaching_v0 import ( - anti_poaching, - StackerWrapper, -) +from anti_poaching.anti_poaching_v0 import anti_poaching from heuristic_policies import ( ActionMaskedRandomPolicy, ActionMaskedPlanningPoacherPolicy, @@ -82,19 +79,14 @@ if __name__ == "__main__": # call AntiPoachingGame with parser parameters # and store a reference to the created env in the callbacks. cg, cg_config = define_game_from_args(args) - _base_cg_env = ( - cg.unwrapped if isinstance(cg, StackerWrapper) else cg.env.unwrapped - ) # Choose the agents that learn policies_to_train = [] if "r" in args.policies_train: - policies_to_train += ( - ["rangers"] if args.algo == "QMIX" else _base_cg_env.rangers - ) + policies_to_train += ["rangers"] if args.algo == "QMIX" else cg.rangers if "p" in args.policies_train: policies_to_train += ( - ["poachers"] if args.algo == "QMIX" else _base_cg_env.poachers + ["poachers"] if args.algo == "QMIX" else cg.poachers ) # Choose the algo that runs @@ -104,8 +96,6 @@ if __name__ == "__main__": algo_config = pg_config elif args.algo == "QMIX": algo_config = qmix_config - elif args.algo == "MADDPG": - algo_config = generic_conf(MADDPGConfig) # Generic first else: raise RuntimeError(f"Unknown algorithm specified: {args.algo}") @@ -177,7 +167,6 @@ if __name__ == "__main__": ).fit() # Get the path to this experiment and save env_config there. 
- path_exp = results[0].path env_config_file_path = pathlib.Path(results[0].path) / "env_config.json" with open(env_config_file_path, "w") as fd: print(json.dumps(cg_config), file=fd) diff --git a/examples/rllib_examples/tune_antipoaching.py b/examples/rllib_examples/tune_antipoaching.py index 85c24a80bc12294ad2ab80d53fd88eaa13cd39d9..7cec120d15fe95b243e42e8b9f78838efe1c8351 100644 --- a/examples/rllib_examples/tune_antipoaching.py +++ b/examples/rllib_examples/tune_antipoaching.py @@ -7,10 +7,7 @@ from ray.rllib.algorithms.algorithm_config import AlgorithmConfig from ray.rllib.policy.policy import PolicySpec from ray.rllib.algorithms.callbacks import DefaultCallbacks -from anti_poaching.anti_poaching_v0 import ( - anti_poaching, - StackerWrapper, -) +from anti_poaching.anti_poaching_v0 import anti_poaching from heuristic_policies import ( ActionMaskedRandomPolicy, ActionMaskedPlanningPoacherPolicy, @@ -30,14 +27,6 @@ from main import get_policies_from_arg import callbacks as cb -class ComputeCustomMetric(DefaultCallbacks): - def on_train_result(self, *, algorithm, result: dict, **kwargs): - # Compute sum of the absolute value of multiple policies - result["custom_metrics"]["policies_absolute_reward_mean"] = sum( - np.abs(list(result["policy_reward_mean"].values())) - ) - - COMMON_HYPERPARAMS = { "train_batch_size": tune.choice([128, 256, 512, 1024, 2048, 4096]), "gamma": tune.choice([0.9, 0.95, 0.97, 0.99, 0.999]), @@ -91,9 +80,6 @@ if __name__ == "__main__": # Then parsing the above args for the tune search cg_args, _unknown = get_args() cg, cg_config = define_game_from_args(cg_args) - _base_cg_env = ( - cg.unwrapped if isinstance(cg, StackerWrapper) else cg.env.unwrapped - ) args, _ = parser.parse_known_args() # Parse the stopping conditions @@ -109,15 +95,14 @@ if __name__ == "__main__": policies_to_train = [] if "r" in cg_args.policies_train: policies_to_train += ( - ["rangers"] if cg_args.algo == "QMIX" else _base_cg_env.rangers + ["rangers"] if cg_args.algo == "QMIX" else cg.rangers ) if "p" in cg_args.policies_train: policies_to_train += ( - ["poachers"] if cg_args.algo == "QMIX" else _base_cg_env.poachers + ["poachers"] if cg_args.algo == "QMIX" else cg.poachers ) # Choose the algo that runs - print(cg_args.algo) if cg_args.algo == "PPO": algo_config = ppo_config hyperparams = PPO_HYPERPARAMS @@ -146,7 +131,7 @@ if __name__ == "__main__": num_gpus=cg_args.num_gpus, # AlgoConfig to use GPUs. 
) .reporting(metrics_num_episodes_for_smoothing=1) - .callbacks(ComputeCustomMetric) + .callbacks(cb.SumRangerRewardMetric) .rollouts( rollout_fragment_length="auto", batch_mode=tune.sample_from( @@ -173,8 +158,8 @@ if __name__ == "__main__": cg_args.algo, param_space=config, tune_config=tune.TuneConfig( - metric="custom_metrics/policies_absolute_reward_mean", - mode="min", + metric="custom_metrics/rangers_mean_reward_sum", + mode="max", num_samples=args.num_samples, ), run_config=air.RunConfig(stop=stop), diff --git a/tests/test_flatten_space.py b/tests/test_flatten_space.py deleted file mode 100644 index 2c9df8d12816a0bf3b59259a6f0c97a0cd25490e..0000000000000000000000000000000000000000 --- a/tests/test_flatten_space.py +++ /dev/null @@ -1,162 +0,0 @@ -from collections import OrderedDict -import numpy as np -import pytest -from gymnasium.spaces import ( - Box, - Dict, - Discrete, - MultiDiscrete, - MultiBinary, - Text, -) -from anti_poaching.anti_poaching_v0 import ( - box_flatten, - box_flatten_obs, - parse_obs, -) - - -def test_box(): - # Testing on a flat box - b = Box(low=1, high=10, dtype=np.int8) - assert box_flatten(b) == b, "Flatten changed the box space!" - - # and then on a 2D box - b = Box(low=1, high=10, dtype=np.int8, shape=(2, 3)) - assert box_flatten(b) == Box( - low=1, high=10, shape=(6,), dtype=np.int8 - ), "Flatten did not correctly work on 2D box!" - - -def test_discrete(): - d = Discrete(10) - b = box_flatten(d) - assert b == Box( - low=0, high=9, dtype=d.dtype - ), f"Discrete is not translated literally to Box !" - for _ in range(1000): - assert b.contains( - np.array([d.sample()]) - ), f"{b=} does not contain {d=}" - assert d.contains(b.sample()[0]), f"{d=} does not contain {b=}" - - -def test_multidiscrete(): - md = MultiDiscrete([11, 11]) - b = box_flatten(md) - assert b == Box( - low=np.array([0, 0]), high=np.array([10, 10]), dtype=md.dtype - ), f"MultiDiscrete is not translated literally to Box !" - for _ in range(1000): - assert b.contains(md.sample()), f"{b=} does not contain {md=}" - assert md.contains(b.sample()), f"{md=} does not contain {b=}" - - -def test_multidim_discrete(): - md = MultiDiscrete([[2, 3], [4, 5]]) - b = box_flatten(md) - assert b == Box( - low=np.array([0, 0, 0, 0]), high=np.array([1, 2, 3, 4]), dtype=md.dtype - ), "MultiDiscrete is not translated literally to Box !" - - -def test_dict(): - dico = Dict( - { - "box": Box(low=1, high=10, dtype=np.float32), - "disc": Discrete(10), - } - ) - # Note that dico will parse through each dictionary using sorted keys. - assert box_flatten(dico) == Box( - low=np.array([1, 0]), high=np.array([10, 9]), dtype=np.float32 - ), "Dictionary flattening to box is not well defined" - - -def test_dict_nested(): - nested_dico = Dict( - { - "box": Box(low=1, high=10, dtype=np.int32), - "dico": Dict( - OrderedDict( - { - "inner_disc": Discrete(10), - "inner_bin": MultiBinary(2), - "inner_box": Box(low=0, high=10), - } - ) - ), - } - ) - # Note that dico will parse through each dictionary using sorted keys. 
- box = box_flatten(nested_dico) - assert box == Box( - low=np.array([1, 0, 0, 0, 0]), - high=np.array([10, 9, 1, 1, 10]), - ), "Dictionary flattening to box is not well defined" - - for itr in range(1000): - obs = nested_dico.sample() - fl_obs = box_flatten_obs(obs) - assert box.contains( - fl_obs - ), f"{itr=}| {box=} does not contain {fl_obs=}, taken from {obs=}" - - -def test_multibinary(): - mb = MultiBinary([11, 11]) - b = box_flatten(mb) - assert b == Box( - low=0, high=1, shape=mb.shape, dtype=mb.dtype - ), f"MultiBinary is not translated literally to Box !" - for _ in range(1000): - assert b.contains(mb.sample()), f"{b=} does not contain {mb=}" - assert mb.contains(b.sample()), f"{mb=} does not contain {b=}" - - -def test_unsupported_type(): - with pytest.raises(NotImplementedError): - text_space = Text(5) - box_flatten(text_space) - - -def test_parse_obs(): - """Checks if parse_obs parses samples of flattened spaces to the - correct size and type of the original space - """ - spaces = [ - Box(low=1, high=10, dtype=np.int8, shape=[2, 3]), - Discrete(10), - MultiBinary([11, 11]), - MultiDiscrete([[2, 3], [4, 5]]), - Dict( - { - "box": Box(low=1, high=10, dtype=np.float32), - "disc": Discrete(10), - } - ), - Dict( - { - "box": Box(low=1, high=10, dtype=np.int32), - "dico": Dict( - OrderedDict( - { - "inner_disc": Discrete(10), - "inner_bin": MultiBinary(2), - "inner_box": Box(low=0, high=10), - } - ) - ), - } - ), - ] - for space in spaces: - # Test for each space - flat_box = box_flatten(space) - # Test random samples (with casting) - for _ in range(1000): - flat_sample = flat_box.sample() - unflat_sample = parse_obs(flat_sample, space) - assert space.contains( - unflat_sample - ), f"{unflat_sample=} not in {space=}" diff --git a/tests/test_game_env.py b/tests/test_game_env.py index f34c4faaf8bc1201da9c435c303ebca378298088..1aae3126b1ea17c6e87fc29ef0affd82a24dfdf9 100644 --- a/tests/test_game_env.py +++ b/tests/test_game_env.py @@ -35,7 +35,7 @@ def lantipoach_game(): prob_detect_cell=0.5, prob_animal_appear=0.5, prob_detect_trap=0.5, - max_time=500, + max_time=200, ) print("End of fixture ...") @@ -63,9 +63,7 @@ def test_parallel_test_api_suite(lantipoach_game): def test_parallel_seed_test(): - """ - Runs the seed test on the Conservation game. - """ + """Runs the seed test on the Conservation game.""" parallel_seed_test(anti_poaching.parallel_env) @@ -93,21 +91,24 @@ def test_environment(antipoach_game, every: int = 20): for agent in antipoach_game.agents } # step through the environment - observations, _, terminations, truncations, _ = antipoach_game.step( - actions - ) - - # Verify that each agent receives a valid observation !!! + _step_objs = antipoach_game.step(actions) + observations, rewards, terminations, truncations, infos = _step_objs + + # Verify that all agents are alive !!! Recall that poachers are alive, + # but in a caught state. + time_stat = antipoach_game.curr_time < antipoach_game.max_time + for obj in _step_objs: + assert set(obj.keys()) == set( + antipoach_game.possible_agents + ), f"All agents should recieve things until end-of-game (current time = {antipoach_game.curr_time})" + + # Verify that each agent receives valid observations and actions !!! 
for agent, space in antipoach_game.observation_spaces.items(): - if agent not in antipoach_game.agents: - continue assert space.contains( observations[agent] ), f"{agent} received incompatible observations\n {observations[agent]}" for agent, space in antipoach_game.action_spaces.items(): - if agent not in antipoach_game.agents: - continue assert space.contains( actions[agent] ), f"{agent} received incompatible action\n {actions[agent]}" @@ -118,22 +119,15 @@ def test_environment(antipoach_game, every: int = 20): } # post-processing - done = all( - [ - x or y - for x, y in zip(terminations.values(), truncations.values()) - ] - ) + done = all(terminations.values()) or all(truncations.values()) if antipoach_game.curr_time % every == 0 or done: - antipoach_game.render() + antipoach_game.render() # This should not bug out print("-" * 80) print(antipoach_game.curr_time, " is the current time") print(antipoach_game.max_time, " is the maximum time") print("Rewards: ") - for agent in antipoach_game.possible_agents: - print(agent, " has ", antipoach_game.total_rewards[agent]) - assert sum(antipoach_game.total_rewards.values()) == 0.0 + assert sum(rewards.values()) == 0.0, "Game is not zero-sum !!!" def test_init_antipoach_game(lantipoach_game): @@ -175,50 +169,22 @@ def test_init_antipoach_game(lantipoach_game): assert set(lantipoach_game.rangers + lantipoach_game.poachers) == set( lantipoach_game.observation_spaces.keys() ), "Not all agents have observations, or there are too many agents." - assert all( - ( - set( - [ - "remaining_time", - "state", - "ground_traps", - "rangers", - "poachers", - ] - ) - == set( - lantipoach_game.observation_spaces[poacher][ - "observations" - ].keys() - ) - for poacher in lantipoach_game.poachers - ) - ), "Unknown poacher observation structure" - assert all( - ( - set( - [ - "remaining_time", - "state", - "partners", - "poacher_traps", - "poachers_captured", - "ground_traps", - ] - ) - == set( - lantipoach_game.observation_spaces[ranger][ - "observations" - ].keys() - ) - for ranger in lantipoach_game.rangers - ) - ), "Unknown ranger observation structure" - # Check rewards - assert set(lantipoach_game.total_rewards.keys()) == set( - lantipoach_game.agents - ), "Unknown rewards structure" + obs_keys = {"observations", "action_mask"} + for agent in lantipoach_game.agents: + _space = lantipoach_game.observation_space(agent) + _sample = _space.sample() + assert set(_space.keys()) == obs_keys + if "ranger" in agent: + assert ( + len(_sample["observations"]) + == lantipoach_game._ranger_obs_size + ), "Ranger obs space of incorrect size." + else: + assert ( + len(_sample["observations"]) + == lantipoach_game._poacher_obs_size + ), "Poacher obs space of incorrect size." def test_reset(lantipoach_game): @@ -260,24 +226,6 @@ def test_reset(lantipoach_game): ), "Reset does not reset poachers' traps." -def test_observe(lantipoach_game): - """ - Tests the received observations - """ - with pytest.raises(AssertionError): - # Agent does not exist, therefore should throw. - lantipoach_game.observe("agent", {"invalid": "data"}) - with pytest.raises(KeyError): - # Record is empty, therefore should throw. - lantipoach_game.observe("poacher_0", {}) - with pytest.raises(KeyError): - # Record is invalid, therefore should throw. - lantipoach_game.observe("poacher_0", {"invalid": "data"}) - with pytest.raises(KeyError): - # Record is invalid, therefore should throw. 
- lantipoach_game.observe("ranger_0", {"invalid": "data"}) - - def test_assign_reward(antipoach_game): """ Tests all the branches of the _assign_reward helper @@ -285,26 +233,20 @@ def test_assign_reward(antipoach_game): """ for poacher in antipoach_game.poachers.copy(): rewards = dict.fromkeys(antipoach_game.agents, 0) - antipoach_game._assign_reward(poacher, 1, rewards) + # First if we assign a reward to a poacher + antipoach_game._assign_reward(poacher, 1, rewards) assert rewards[poacher] == 1 - assert ( - antipoach_game.total_rewards[poacher] == 0 - ), "We did not assign anything here!!" assert ( sum((rewards[ranger] for ranger in antipoach_game.rangers)) == -1 ) # Force assigning a reward after the poacher is `caught` - antipoach_game.agents.remove(poacher) - del rewards[poacher] - antipoach_game._assign_reward(poacher, 1, rewards) - - assert ( - antipoach_game.total_rewards[poacher] == 1 - ), "total_rewards not updated for retired agent" + antipoach_game.grid.state[poacher] = antipoach_game.grid.NULL_POS + antipoach_game._assign_reward(poacher, -1, rewards) + assert rewards[poacher] == 0 assert ( - sum((rewards[ranger] for ranger in antipoach_game.rangers)) == -2 + sum((rewards[ranger] for ranger in antipoach_game.rangers)) == 0 ), "ranger rewards not updated when assigning reward to retired agent" @@ -314,8 +256,7 @@ def test_observation_time(antipoach_game): obs, infos = antipoach_game.reset() for agent in antipoach_game.agents: assert ( - antipoach_game.max_time - == obs[agent]["observations"]["remaining_time"] + antipoach_game.max_time == obs[agent]["observations"][0] ), f"{agent} is missing remaining_time attr in obs" for t in range(1, antipoach_game.max_time + 1): @@ -339,5 +280,5 @@ def test_observation_time(antipoach_game): for agent in antipoach_game.agents: # and that each agent receives this as well. assert (antipoach_game.max_time - t) == obs[agent]["observations"][ - "remaining_time" + 0 ], f"{agent} has bad updates for remaining_time attr in obs" diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py deleted file mode 100644 index 41237047068c64384f1c807c38c3d1445eba7efa..0000000000000000000000000000000000000000 --- a/tests/test_wrappers.py +++ /dev/null @@ -1,44 +0,0 @@ -""" -Module to test the wrapped AntiPoachingGame, -i.e. after NonCategoricalFlatten and StackerWrapper -""" - -from anti_poaching.anti_poaching_v0 import ( - anti_poaching, - NonCategoricalFlatten, - StackerWrapper, -) -from pettingzoo.test import parallel_api_test - - -def test_boxper(): - """ - Test to verify that the NonCategoricalFlatten - per returns a ParallelEnv after modification. - """ - env = anti_poaching.parallel_env() - env = NonCategoricalFlatten(env) - parallel_api_test(env) - - -def test_pipeline(): - """ - Test to verify that the two transformations composed - work well together - """ - env = anti_poaching.parallel_env() - env = NonCategoricalFlatten(env) - - # Run a few iterations - obs = env.reset() - print("After NonCategoricalFlatten: ", obs) - - env = StackerWrapper(env) - - # Run a few iterations - # Note that MultiAgentEnv's always send two objects !!! - obs, _ = env.reset() - print("After StackerWrapper: ", obs) - assert all( - ({"observations", "action_mask"} == obs[agent].keys() for agent in obs) - ), f"Structure of obs changed, {obs=}"
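
The heuristic policies above now accept either the original OrderedDict batches (with separate "observations" and "action_mask" entries) or an already-flattened vector, in which case the action mask is packed in front of the observation payload: ActionMaskedRandomPolicy slices off the first action_space.n entries, and the planning poacher splits flat_obs[:6] off as its mask before reading the remaining time and its current grid position, with (-1, -1) marking a caught poacher. The snippet below only illustrates that layout; it is not project code, and the sizes and values are invented.

import numpy as np
from gymnasium.spaces import Discrete

action_space = Discrete(6)  # six actions, matching the flat_obs[:6] slice above
action_mask = np.array([1, 0, 1, 1, 0, 1], dtype=np.int8)  # hypothetical mask
payload = np.array([200.0, 3.0, 4.0, 0.0, 1.0])  # hypothetical observation vector
flat_obs = np.concatenate([action_mask, payload]).astype(np.float32)

# Same slicing as ActionMaskedRandomPolicy.compute_actions on a flat batch row:
# mask first, observation payload after.
mask = flat_obs[: action_space.n].astype(np.int8)
obs = flat_obs[action_space.n:]

# gymnasium's Discrete.sample accepts an int8 mask, so only legal actions
# (mask == 1) can be drawn.
action = action_space.sample(mask)
assert mask[action] == 1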
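
tune_antipoaching.py previously minimised the sum of absolute per-policy mean rewards through the inline ComputeCustomMetric callback; it now maximises custom_metrics/rangers_mean_reward_sum reported by cb.SumRangerRewardMetric. Since the game is zero-sum (test_game_env.py asserts sum(rewards.values()) == 0.0 on every step), maximising the rangers' summed mean reward is a more direct tuning objective than minimising every policy's absolute reward. The callback itself is not part of this patch, so the sketch below is only an assumption about its shape, modelled on the removed class; the real examples/rllib_examples/callbacks.py may differ.

from ray.rllib.algorithms.callbacks import DefaultCallbacks


class SumRangerRewardMetric(DefaultCallbacks):
    def on_train_result(self, *, algorithm, result: dict, **kwargs):
        # policy_reward_mean maps policy ids to their mean episode reward;
        # we assume ranger policies can be recognised from their id.
        result["custom_metrics"]["rangers_mean_reward_sum"] = sum(
            reward
            for policy_id, reward in result.get("policy_reward_mean", {}).items()
            if "ranger" in policy_id
        )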
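
The reworked tests treat each agent's observation as a Dict holding a flat "observations" vector and an "action_mask", with the remaining time stored at index 0 of the vector. A quick sanity check along the lines of test_init_antipoach_game and test_observation_time, assuming the default parallel_env() arguments build a valid game and that the returned environment exposes max_time the way the test fixtures do:

from anti_poaching.anti_poaching_v0 import anti_poaching

env = anti_poaching.parallel_env()
obs, infos = env.reset()
for agent, agent_obs in obs.items():
    # Every agent sees a flat "observations" vector plus an "action_mask";
    # right after reset, the remaining time at index 0 equals max_time.
    assert set(agent_obs) == {"observations", "action_mask"}
    assert agent_obs["observations"][0] == env.max_time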