Commit bb5659d

More tests and bug fixes

bpiwowar committed Nov 10, 2023
1 parent 1cb51d4 commit bb5659d

Showing 4 changed files with 114 additions and 57 deletions.
88 changes: 49 additions & 39 deletions src/pystk2_gymnasium/envs.py
@@ -2,7 +2,6 @@
from itertools import repeat
import logging
import functools
-import sys
from typing import Any, ClassVar, Dict, List, Optional, Tuple, TypedDict

import gymnasium as gym
@@ -84,12 +83,11 @@ def kart_observation_space(use_ai: bool):
            ),
        }
    )

    if use_ai:
-        space["action"] = kart_action_space()["action"]
-
-    return space
+        space["action"] = kart_action_space()

    return space


class STKAction(TypedDict):
@@ -153,8 +151,6 @@ def __init__(
        render_mode=None,
        track=None,
        num_kart=3,
-        rank_start=None,
-        use_ai=False,
        max_paths=None,
        laps: int = 1,
        difficulty: int = 2,
@@ -194,7 +190,7 @@ def reset_race(
        options: Optional[Dict[str, Any]] = None,
    ) -> Tuple[pystk2.WorldState, Dict[str, Any]]:
        if self.race:
-            del self.race
+            self.race = None

        # Setup the race configuration
        self.current_track = self.default_track
@@ -238,11 +234,13 @@ def close(self):
        super().close()
        if self.race is not None:
            self.race.stop()
-            del self.race
+            self.race = None

    def world_update(self):
        """Update world state, but keep some information to compute reward"""
-        self.last_overall_distances = [max(kart.overall_distance, 0) for kart in self.world.karts]
+        self.last_overall_distances = [
+            max(kart.overall_distance, 0) for kart in self.world.karts
+        ]
        self.world.update()

    def get_state(self, kart_ix: int, use_ai: bool):
@@ -269,6 +267,7 @@ def get_state(self, kart_ix: int, use_ai: bool):
                "distance": d_t,
            },
        )
+
    def get_observation(self, kart_ix, use_ai):
        kart = self.world.karts[kart_ix]

@@ -377,7 +376,7 @@ def sort_closest(positions, *lists):
                kartview(x[1]) for x in iterate_from(self.track.path_nodes, path_ix)
            ),
        }

    def render(self):
        # Just do nothing... rendering is done directly
        pass
@@ -450,24 +449,19 @@ def step(

        obs, reward, terminated, info = self.get_state(self.kart_ix, self.use_ai)

-        return (
-            obs,
-            reward,
-            terminated,
-            False,
-            info
-        )
+        return (obs, reward, terminated, False, info)


@dataclass
class AgentSpec:
    rank_start: Optional[int] = None
    use_ai: bool = False


class STKRaceMultiEnv(BaseSTKRaceEnv):
    """Multi-agent race environment"""

-    def __init__(self, *, agents: List[AgentSpec]=None, **kwargs):
+    def __init__(self, *, agents: List[AgentSpec] = None, **kwargs):
        """Creates a new race
        :param rank_start: The position of the controlled kart, defaults to None
@@ -484,15 +478,22 @@ def __init__(self, *, agents: List[AgentSpec]=None, **kwargs):
        self.kart_indices = None

        ranked_agents = [agent for agent in agents if agent.rank_start is not None]

-        assert all(agent.rank_start < self.num_kart for agent in ranked_agents), "Karts must have all have a valid position"
-        assert len(set(ranked_agents)) == len(ranked_agents), "Some agents have the same starting position"
-
-        self.free_positions = [ix for ix in range(self.num_kart) if ix not in ranked_agents]
-
-        self.action_space = spaces.Tuple(repeat(kart_action_space(), len(self.agents)))
-        self.observation_space = spaces.Tuple(kart_observation_space(agent.use_ai) for agent in self.agents)
+        assert all(
+            agent.rank_start < self.num_kart for agent in ranked_agents
+        ), "Karts must have all have a valid position"
+        assert len(set(ranked_agents)) == len(
+            ranked_agents
+        ), "Some agents have the same starting position"
+
+        self.free_positions = [
+            ix for ix in range(self.num_kart) if ix not in ranked_agents
+        ]
+
+        self.action_space = spaces.Tuple(repeat(kart_action_space(), len(self.agents)))
+        self.observation_space = spaces.Tuple(
+            kart_observation_space(agent.use_ai) for agent in self.agents
+        )

    def reset(
        self,
@@ -512,27 +513,37 @@ def reset(
        for agent in self.agents:
            kart_ix = agent.rank_start or next(pos_iter)
            self.kart_indices.append(kart_ix)
-            self.config.players[
-                kart_ix
-            ].camera_mode = pystk2.PlayerConfig.CameraMode.ON
+            self.config.players[kart_ix].camera_mode = pystk2.PlayerConfig.CameraMode.ON
            if not agent.use_ai:
                self.config.players[
                    kart_ix
                ].controller = pystk2.PlayerConfig.Controller.PLAYER_CONTROL

        logging.debug("Observed kart indices %s", self.kart_indices)

        self.warmup_race()
        self.world.update()

-        return tuple(self.get_observation(ix, agent.use_ai) for agent, ix in zip(self.agents, self.kart_indices)), {}
+        return (
+            tuple(
+                self.get_observation(ix, agent.use_ai)
+                for agent, ix in zip(self.agents, self.kart_indices)
+            ),
+            {},
+        )

    def step(
        self, actions: Tuple[STKAction]
    ) -> Tuple[pystk2.WorldState, float, bool, bool, Dict[str, Any]]:
        # Performs the action
        assert len(actions) == len(self.agents)
-        self.race.step([get_action(action) for action in actions])
+        self.race.step(
+            [
+                get_action(action)
+                for agent, action in zip(self.agents, actions)
+                if not agent.use_ai
+            ]
+        )

        # Update the world state
        self.world_update()
@@ -549,7 +560,6 @@ def step(
                terminated_count += 1
            infos.append(info)

-
        return (
            tuple(observations),
            # Only scalar rewards can be given
@@ -559,6 +569,6 @@ def step(
            {
                "infos": infos,
                # We put back individual rewards
-                "rewards": rewards
-            }
-        )
+                "rewards": rewards,
+            },
+        )
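For orientation, a sketch of driving the multi-agent environment after this commit. The environment id "supertuxkart/multi-full-v0" is an assumption (ids are not shown in this diff); the AgentSpec usage and the tuple action contract follow the code above.

import gymnasium as gym
import pystk2_gymnasium  # noqa: F401  (importing registers the environments)
from pystk2_gymnasium.envs import AgentSpec

# Hypothetical id; inspect gym.envs.registry for the ones actually registered
env = gym.make(
    "supertuxkart/multi-full-v0",
    render_mode=None,
    agents=[AgentSpec(use_ai=True), AgentSpec(use_ai=False)],
)
obs, info = env.reset()
# One action per agent; per the fix above, entries for use_ai agents are
# dropped before reaching race.step
actions = env.action_space.sample()
obs, reward, terminated, truncated, info = env.step(actions)
env.close()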
60 changes: 53 additions & 7 deletions src/pystk2_gymnasium/stk_wrappers.py
@@ -3,17 +3,56 @@
"""

import copy
-from typing import Any, Dict, Tuple
+from typing import Any, Dict, Optional, Tuple

import gymnasium as gym
import numpy as np
import pystk2
from gymnasium import spaces
+from gymnasium.core import (
+    Wrapper,
+    WrapperActType,
+    WrapperObsType,
+    ObsType,
+    ActType,
+    SupportsFloat,
+)

from .envs import STKAction
from pystk2_gymnasium.utils import Discretizer, max_enum_value


+class ActionObservationWrapper(Wrapper[ObsType, WrapperActType, ObsType, ActType]):
+    """Combines action and observation wrapper"""
+
+    def action(self, action: WrapperActType) -> ActType:
+        raise NotImplementedError
+
+    def observation(self, observation: ObsType) -> WrapperObsType:
+        raise NotImplementedError
+
+    def __init__(self, env: gym.Env[ObsType, ActType]):
+        """Constructor for the action wrapper."""
+        Wrapper.__init__(self, env)
+
+    def reset(
+        self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None
+    ) -> Tuple[WrapperObsType, dict[str, Any]]:
+        """Modifies the :attr:`env` after calling :meth:`reset`, returning a
+        modified observation using :meth:`self.observation`."""
+        obs, info = self.env.reset(seed=seed, options=options)
+        return self.observation(obs), info
+
+    def step(
+        self, action: ActType
+    ) -> Tuple[WrapperObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
+        """Modifies the :attr:`env` after calling :meth:`step` using
+        :meth:`self.observation` on the returned observations."""
+        action = self.action(action)
+        observation, reward, terminated, truncated, info = self.env.step(action)
+        return self.observation(observation), reward, terminated, truncated, info


class PolarObservations(gym.ObservationWrapper):
    """Modifies position to polar positions
@@ -121,7 +160,7 @@ class STKDiscreteAction(STKAction):
    steering: int


-class DiscreteActionsWrapper(gym.ActionWrapper):
+class DiscreteActionsWrapper(ActionObservationWrapper):
    # Wraps the actions
    def __init__(self, env: gym.Env, *, acceleration_steps=5, steer_steps=5, **kwargs):
        super().__init__(env, **kwargs)
@@ -155,20 +194,19 @@ def to_discrete(self, action):
        action["steer"] = self.d_steer.discretize(action["steer"])
        return action

-    def step(self, action) -> Tuple[Any, float, bool, bool, Dict[str, Any]]:
-        # Transforms the action when part of the environment
-        obs, reward, terminated, truncated, info = self.env.step(action)
+    def observation(self, obs):
        if "action" in obs:
            obs = {**obs}
            obs["action"] = self.to_discrete(obs["action"])
-        return obs, reward, terminated, truncated, info
+        return obs

    def action(
        self, action: STKDiscreteAction
    ) -> Tuple[Any, float, bool, bool, Dict[str, Any]]:
        return self.from_discrete(action)

-class OnlyContinuousActionsWrapper(gym.ActionWrapper):
+class OnlyContinuousActionsWrapper(ActionObservationWrapper):
    """Removes the discrete actions"""

    def __init__(self, env: gym.Env, **kwargs):
@@ -190,5 +228,13 @@ def __init__(self, env: gym.Env, **kwargs):
            }
        )

+    def observation(self, obs):
+        if "action" in obs:
+            obs = {**obs}
+            obs["action"] = {
+                key: obs["action"][key] for key in self.action_space.keys()
+            }
+        return obs
+
    def action(self, action: Dict) -> Tuple[Any, float, bool, bool, Dict[str, Any]]:
        return {**action, **{key: 0 for key, _ in self.discrete_actions.items()}}
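As a usage note, here is a minimal subclass of the new ActionObservationWrapper; the steering rescaling is purely illustrative and not part of the library.

from pystk2_gymnasium.stk_wrappers import ActionObservationWrapper


class ScaleSteering(ActionObservationWrapper):
    """Illustrative only: halve inbound steering, and undo the scaling on
    the 'action' echoed back in observations so both views stay consistent."""

    def action(self, action):
        # Called on the way in (agent -> env)
        return {**action, "steer": action["steer"] * 0.5}

    def observation(self, obs):
        # Called on the way out (env -> agent)
        if "action" in obs:
            obs = {**obs}
            obs["action"] = {**obs["action"], "steer": obs["action"]["steer"] * 2.0}
        return obs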
5 changes: 4 additions & 1 deletion src/pystk2_gymnasium/wrappers.py
@@ -93,6 +93,7 @@ def observation(self, observation):
        }

        if self.has_action:
+            # Transforms from nested action to a flattened
            obs_action = observation["action"]
            discrete = np.array(
                [obs_action[key] for key in self.action_flattener.discrete_keys]
@@ -102,7 +103,9 @@
            else:
                continuous = np.concatenate(
                    [
-                        obs_action[key].flatten()
+                        np.array([obs_action[key]])
+                        if isinstance(obs_action[key], float)
+                        else obs_action[key].flatten()
                        for key in self.action_flattener.continuous_keys
                    ]
                )
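A standalone illustration of the flattening fix above, with made-up values: some observed action entries can arrive as plain Python floats, which have no .flatten(), so they must be wrapped into arrays before np.concatenate.

import numpy as np

obs_action = {"acceleration": 0.25, "steer": np.array([0.5])}  # made-up values
continuous_keys = ["acceleration", "steer"]

continuous = np.concatenate(
    [
        np.array([obs_action[key]])
        if isinstance(obs_action[key], float)
        else obs_action[key].flatten()
        for key in continuous_keys
    ]
)
print(continuous)  # concatenated to [0.25, 0.5]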
18 changes: 8 additions & 10 deletions tests/test_envs.py
@@ -1,21 +1,19 @@
import gymnasium as gym
import pytest
-import pystk2_gymnasium
+import pystk2_gymnasium  # noqa: F401
from pystk2_gymnasium.envs import AgentSpec

envs = [key for key in gym.envs.registry.keys() if key.startswith("supertuxkart/")]


@pytest.mark.parametrize("name", envs)
-def test_env(name):
+@pytest.mark.parametrize("use_ai", [True, False])
+def test_env(name, use_ai):
    env = None
    if name.startswith("supertuxkart/multi-"):
-        kwargs = {
-            "agents": [AgentSpec(), AgentSpec()]
-        }
+        kwargs = {"agents": [AgentSpec(use_ai=use_ai), AgentSpec(use_ai=use_ai)]}
    else:
-        kwargs = {
-            "use_ai": False
-        }
+        kwargs = {"use_ai": use_ai}
    try:
        env = gym.make(name, render_mode=None, **kwargs)

@@ -28,7 +26,7 @@ def test_env(name):
            action = env.action_space.sample()
            # print(action)
            state, reward, terminated, truncated, _ = env.step(action)
            done = truncated or terminated
    finally:
        if env is not None:
            env.close()
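The new use_ai=True parametrization also exercises the observation change at the top of this commit, where an AI-controlled kart's observation carries the action the built-in AI took. A sketch of reading it back in a single-agent environment; the id "supertuxkart/full-v0" and the assumption that the submitted action is ignored when use_ai=True are not confirmed by this diff.

import gymnasium as gym
import pystk2_gymnasium  # noqa: F401

env = gym.make("supertuxkart/full-v0", render_mode=None, use_ai=True)  # id assumed
obs, info = env.reset()
done = False
while not done:
    # Presumably ignored when use_ai=True; any valid sample keeps the API happy
    obs, reward, terminated, truncated, _ = env.step(env.action_space.sample())
    ai_action = obs["action"]  # the AI's action, per kart_observation_space(use_ai=True)
    done = terminated or truncated
env.close()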
