From bb5659dced5bb0f85fc9ccff358994ea92105716 Mon Sep 17 00:00:00 2001 From: Benjamin Piwowarski Date: Fri, 10 Nov 2023 10:05:24 +0100 Subject: [PATCH] More tests and bug fixes --- src/pystk2_gymnasium/envs.py | 88 ++++++++++++++++------------ src/pystk2_gymnasium/stk_wrappers.py | 60 ++++++++++++++++--- src/pystk2_gymnasium/wrappers.py | 5 +- tests/test_envs.py | 18 +++--- 4 files changed, 114 insertions(+), 57 deletions(-) diff --git a/src/pystk2_gymnasium/envs.py b/src/pystk2_gymnasium/envs.py index dddccd9..6122dd1 100644 --- a/src/pystk2_gymnasium/envs.py +++ b/src/pystk2_gymnasium/envs.py @@ -2,7 +2,6 @@ from itertools import repeat import logging import functools -import sys from typing import Any, ClassVar, Dict, List, Optional, Tuple, TypedDict import gymnasium as gym @@ -84,12 +83,11 @@ def kart_observation_space(use_ai: bool): ), } ) - + if use_ai: - space["action"] = kart_action_space()["action"] - - return space + space["action"] = kart_action_space() + return space class STKAction(TypedDict): @@ -153,8 +151,6 @@ def __init__( render_mode=None, track=None, num_kart=3, - rank_start=None, - use_ai=False, max_paths=None, laps: int = 1, difficulty: int = 2, @@ -194,7 +190,7 @@ def reset_race( options: Optional[Dict[str, Any]] = None, ) -> Tuple[pystk2.WorldState, Dict[str, Any]]: if self.race: - del self.race + self.race = None # Setup the race configuration self.current_track = self.default_track @@ -238,11 +234,13 @@ def close(self): super().close() if self.race is not None: self.race.stop() - del self.race - + self.race = None + def world_update(self): """Update world state, but keep some information to compute reward""" - self.last_overall_distances = [max(kart.overall_distance, 0) for kart in self.world.karts] + self.last_overall_distances = [ + max(kart.overall_distance, 0) for kart in self.world.karts + ] self.world.update() def get_state(self, kart_ix: int, use_ai: bool): @@ -269,6 +267,7 @@ def get_state(self, kart_ix: int, use_ai: bool): "distance": d_t, }, ) + def get_observation(self, kart_ix, use_ai): kart = self.world.karts[kart_ix] @@ -377,7 +376,7 @@ def sort_closest(positions, *lists): kartview(x[1]) for x in iterate_from(self.track.path_nodes, path_ix) ), } - + def render(self): # Just do nothing... rendering is done directly pass @@ -450,24 +449,19 @@ def step( obs, reward, terminated, info = self.get_state(self.kart_ix, self.use_ai) - return ( - obs, - reward, - terminated, - False, - info - ) + return (obs, reward, terminated, False, info) + @dataclass class AgentSpec: rank_start: Optional[int] = None use_ai: bool = False - + class STKRaceMultiEnv(BaseSTKRaceEnv): """Multi-agent race environment""" - def __init__(self, *, agents: List[AgentSpec]=None, **kwargs): + def __init__(self, *, agents: List[AgentSpec] = None, **kwargs): """Creates a new race :param rank_start: The position of the controlled kart, defaults to None @@ -484,15 +478,22 @@ def __init__(self, *, agents: List[AgentSpec]=None, **kwargs): self.kart_indices = None ranked_agents = [agent for agent in agents if agent.rank_start is not None] - - assert all(agent.rank_start < self.num_kart for agent in ranked_agents), "Karts must have all have a valid position" - assert len(set(ranked_agents)) == len(ranked_agents), "Some agents have the same starting position" - - self.free_positions = [ix for ix in range(self.num_kart) if ix not in ranked_agents] - - self.action_space = spaces.Tuple(repeat(kart_action_space(), len(self.agents))) - self.observation_space = spaces.Tuple(kart_observation_space(agent.use_ai) for agent in self.agents) + assert all( + agent.rank_start < self.num_kart for agent in ranked_agents + ), "Karts must have all have a valid position" + assert len(set(ranked_agents)) == len( + ranked_agents + ), "Some agents have the same starting position" + + self.free_positions = [ + ix for ix in range(self.num_kart) if ix not in ranked_agents + ] + + self.action_space = spaces.Tuple(repeat(kart_action_space(), len(self.agents))) + self.observation_space = spaces.Tuple( + kart_observation_space(agent.use_ai) for agent in self.agents + ) def reset( self, @@ -512,27 +513,37 @@ def reset( for agent in self.agents: kart_ix = agent.rank_start or next(pos_iter) self.kart_indices.append(kart_ix) - self.config.players[ - kart_ix - ].camera_mode = pystk2.PlayerConfig.CameraMode.ON + self.config.players[kart_ix].camera_mode = pystk2.PlayerConfig.CameraMode.ON if not agent.use_ai: self.config.players[ kart_ix ].controller = pystk2.PlayerConfig.Controller.PLAYER_CONTROL - + logging.debug("Observed kart indices %s", self.kart_indices) self.warmup_race() self.world.update() - return tuple(self.get_observation(ix, agent.use_ai) for agent, ix in zip(self.agents, self.kart_indices)), {} + return ( + tuple( + self.get_observation(ix, agent.use_ai) + for agent, ix in zip(self.agents, self.kart_indices) + ), + {}, + ) def step( self, actions: Tuple[STKAction] ) -> Tuple[pystk2.WorldState, float, bool, bool, Dict[str, Any]]: # Performs the action assert len(actions) == len(self.agents) - self.race.step([get_action(action) for action in actions]) + self.race.step( + [ + get_action(action) + for agent, action in zip(self.agents, actions) + if not agent.use_ai + ] + ) # Update the world state self.world_update() @@ -549,7 +560,6 @@ def step( terminated_count += 1 infos.append(info) - return ( tuple(observations), # Only scalar rewards can be given @@ -559,6 +569,6 @@ def step( { "infos": infos, # We put back individual rewards - "rewards": rewards - } - ) \ No newline at end of file + "rewards": rewards, + }, + ) diff --git a/src/pystk2_gymnasium/stk_wrappers.py b/src/pystk2_gymnasium/stk_wrappers.py index d3edd82..6ba9012 100644 --- a/src/pystk2_gymnasium/stk_wrappers.py +++ b/src/pystk2_gymnasium/stk_wrappers.py @@ -3,17 +3,56 @@ """ import copy -from typing import Any, Dict, Tuple +from typing import Any, Dict, Optional, Tuple import gymnasium as gym import numpy as np import pystk2 from gymnasium import spaces +from gymnasium.core import ( + Wrapper, + WrapperActType, + WrapperObsType, + ObsType, + ActType, + SupportsFloat, +) from .envs import STKAction from pystk2_gymnasium.utils import Discretizer, max_enum_value +class ActionObservationWrapper(Wrapper[ObsType, WrapperActType, ObsType, ActType]): + """Combines action and observation wrapper""" + + def action(self, action: WrapperActType) -> ActType: + raise NotImplementedError + + def observation(self, observation: ObsType) -> WrapperObsType: + raise NotImplementedError + + def __init__(self, env: gym.Env[ObsType, ActType]): + """Constructor for the action wrapper.""" + Wrapper.__init__(self, env) + + def reset( + self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None + ) -> Tuple[WrapperObsType, dict[str, Any]]: + """Modifies the :attr:`env` after calling :meth:`reset`, returning a + modified observation using :meth:`self.observation`.""" + obs, info = self.env.reset(seed=seed, options=options) + return self.observation(obs), info + + def step( + self, action: ActType + ) -> Tuple[WrapperObsType, SupportsFloat, bool, bool, Dict[str, Any]]: + """Modifies the :attr:`env` after calling :meth:`step` using + :meth:`self.observation` on the returned observations.""" + action = self.action(action) + observation, reward, terminated, truncated, info = self.env.step(action) + return self.observation(observation), reward, terminated, truncated, info + + class PolarObservations(gym.ObservationWrapper): """Modifies position to polar positions @@ -121,7 +160,7 @@ class STKDiscreteAction(STKAction): steering: int -class DiscreteActionsWrapper(gym.ActionWrapper): +class DiscreteActionsWrapper(ActionObservationWrapper): # Wraps the actions def __init__(self, env: gym.Env, *, acceleration_steps=5, steer_steps=5, **kwargs): super().__init__(env, **kwargs) @@ -155,12 +194,11 @@ def to_discrete(self, action): action["steer"] = self.d_steer.discretize(action["steer"]) return action - def step(self, action) -> Tuple[Any, float, bool, bool, Dict[str, Any]]: - # Transforms the action when part of the environment - obs, reward, terminated, truncated, info = self.env.step(action) + def observation(self, obs): if "action" in obs: + obs = {**obs} obs["action"] = self.to_discrete(obs["action"]) - return obs, reward, terminated, truncated, info + return obs def action( self, action: STKDiscreteAction @@ -168,7 +206,7 @@ def action( return self.from_discrete(action) -class OnlyContinuousActionsWrapper(gym.ActionWrapper): +class OnlyContinuousActionsWrapper(ActionObservationWrapper): """Removes the discrete actions""" def __init__(self, env: gym.Env, **kwargs): @@ -190,5 +228,13 @@ def __init__(self, env: gym.Env, **kwargs): } ) + def observation(self, obs): + if "action" in obs: + obs = {**obs} + obs["action"] = { + key: obs["action"][key] for key in self.action_space.keys() + } + return obs + def action(self, action: Dict) -> Tuple[Any, float, bool, bool, Dict[str, Any]]: return {**action, **{key: 0 for key, _ in self.discrete_actions.items()}} diff --git a/src/pystk2_gymnasium/wrappers.py b/src/pystk2_gymnasium/wrappers.py index e211e46..05bb66f 100644 --- a/src/pystk2_gymnasium/wrappers.py +++ b/src/pystk2_gymnasium/wrappers.py @@ -93,6 +93,7 @@ def observation(self, observation): } if self.has_action: + # Transforms from nested action to a flattened obs_action = observation["action"] discrete = np.array( [obs_action[key] for key in self.action_flattener.discrete_keys] @@ -102,7 +103,9 @@ def observation(self, observation): else: continuous = np.concatenate( [ - obs_action[key].flatten() + np.array([obs_action[key]]) + if isinstance(obs_action[key], float) + else obs_action[key].flatten() for key in self.action_flattener.continuous_keys ] ) diff --git a/tests/test_envs.py b/tests/test_envs.py index 6e08157..1c155ff 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -1,21 +1,19 @@ import gymnasium as gym import pytest -import pystk2_gymnasium +import pystk2_gymnasium # noqa: F401 from pystk2_gymnasium.envs import AgentSpec envs = [key for key in gym.envs.registry.keys() if key.startswith("supertuxkart/")] + @pytest.mark.parametrize("name", envs) -def test_env(name): +@pytest.mark.parametrize("use_ai", [True, False]) +def test_env(name, use_ai): env = None if name.startswith("supertuxkart/multi-"): - kwargs = { - "agents": [AgentSpec(), AgentSpec()] - } + kwargs = {"agents": [AgentSpec(use_ai=use_ai), AgentSpec(use_ai=use_ai)]} else: - kwargs = { - "use_ai": False - } + kwargs = {"use_ai": use_ai} try: env = gym.make(name, render_mode=None, **kwargs) @@ -28,7 +26,7 @@ def test_env(name): action = env.action_space.sample() # print(action) state, reward, terminated, truncated, _ = env.step(action) - done = truncated or terminated - finally: + done = truncated or terminated + finally: if env is not None: env.close()