Skip to content

Commit

Permalink
Working simple environnements
Browse files Browse the repository at this point in the history
  • Loading branch information
bpiwowar committed Oct 25, 2023
1 parent 6434982 commit 039befe
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 83 deletions.
14 changes: 2 additions & 12 deletions src/pystk2_gymnasium/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,7 @@
entry_point="pystk2_gymnasium.envs:SimpleSTKRaceEnv",
max_episode_steps=1500,
additional_wrappers=(
WrapperSpec(
"obs-flattener", "pystk2_gymnasium.wrappers:ObsFlattenerWrapper", {}
),
WrapperSpec(
"action-flattener", "pystk2_gymnasium.wrappers:ActionFlattenerWrapper", {}
),
WrapperSpec("obs-flattener", "pystk2_gymnasium.wrappers:FlattenerWrapper", {}),
),
)

Expand All @@ -32,11 +27,6 @@
entry_point="pystk2_gymnasium.envs:DiscreteActionSTKRaceEnv",
max_episode_steps=1500,
additional_wrappers=(
WrapperSpec(
"obs-flattener", "pystk2_gymnasium.wrappers:ObsFlattenerWrapper", {}
),
WrapperSpec(
"action-flattener", "pystk2_gymnasium.wrappers:ActionFlattenerWrapper", {}
),
WrapperSpec("obs-flattener", "pystk2_gymnasium.wrappers:FlattenerWrapper", {}),
),
)
51 changes: 38 additions & 13 deletions src/pystk2_gymnasium/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,9 @@ def sort_closest(positions, *lists):
"acceleration": action.acceleration,
}
}
if isinstance(self, DiscreteActionSTKRaceEnv):
obs["action"] = self.to_discrete(obs["action"])

return {
**obs,
# Kart properties
Expand Down Expand Up @@ -452,32 +455,54 @@ class STKDiscreteAction(STKAction):
steering: int


class Discretizer:
def __init__(self, space: spaces.Box, values: int):
self.max_value = float(space.high)
self.min_value = float(space.low)
self.values = values
self.space = spaces.Discrete(values)

def discretize(self, value: float):
v = int(
(value - self.min_value)
* (self.values - 1)
/ (self.max_value - self.min_value)
)
if v >= self.values:
return v - 1
return v

def continuous(self, value: int):
return (self.max_value - self.min_value) * value / (
self.values - 1
) + self.min_value


class DiscreteActionSTKRaceEnv(SimpleSTKRaceEnv):
# Wraps the actions
def __init__(self, acceleration_steps=10, steering_steps=10, **kwargs):
def __init__(self, acceleration_steps=10, steer_steps=10, **kwargs):
super().__init__(**kwargs)
self.acceleration_steps = acceleration_steps
self.steering_steps = steering_steps

self.action_space["acceleration"] = spaces.Discrete(acceleration_steps)
self.action_space["steer"] = spaces.Discrete(steering_steps)
self.d_acceleration = Discretizer(
self.action_space["acceleration"], acceleration_steps
)
self.action_space["acceleration"] = self.d_acceleration.space

self.d_steer = Discretizer(self.action_space["steer"], steer_steps)
self.action_space["steer"] = self.d_steer.space

def from_discrete(self, action):
action = {**action}
action["acceleration"] = action["acceleration"] / self.acceleration_steps
action["acceleration"] = self.d_acceleration.continuous(action["acceleration"])
max_steer_angle = self.world.karts[self.kart_ix].max_steer_angle
action["steer"] = action["steer"] / self.steering_steps * max_steer_angle
action["steer"] = self.d_steer.continuous(action["steer"]) * max_steer_angle
return action

def to_discrete(self, action):
action = {**action}
action["acceleration"] = np.array(
[int(action["acceleration"] * self.acceleration_steps)]
)
action["acceleration"] = self.d_acceleration.discretize(action["acceleration"])
max_steer_angle = self.world.karts[self.kart_ix].max_steer_angle
action["steer"] = np.array(
[int(action["steer"] * self.steering_steps / max_steer_angle)]
)
action["steer"] = self.d_steer.discretize(action["steer"] / max_steer_angle)
return action

def step(
Expand Down
136 changes: 78 additions & 58 deletions src/pystk2_gymnasium/wrappers.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from typing import Any
from gymnasium import spaces
import gymnasium as gym
import numpy as np


class SpaceFlattener:
def flatten_space(self, space: gym.Space):
def __init__(self, space: gym.Space):
# Flatten the observation space
self.continuous_keys = []
self.shapes = []
Expand Down Expand Up @@ -37,82 +38,101 @@ def flatten_space(self, space: gym.Space):
self.only_discrete = len(lows) == 0
discrete_space = spaces.MultiDiscrete(counts, dtype=np.int64)
if self.only_discrete:
return discrete_space

return spaces.Dict(
{
"env_obs/discrete": discrete_space,
"env_obs/continuous": spaces.Box(
low=np.concatenate(lows),
high=np.concatenate(highs),
shape=(continuous_size,),
dtype=np.float32,
),
}
)
self.space = discrete_space
else:
self.space = spaces.Dict(
{
"discrete": discrete_space,
"continuous": spaces.Box(
low=np.concatenate(lows),
high=np.concatenate(highs),
shape=(continuous_size,),
dtype=np.float32,
),
}
)


class ObsFlattenerWrapper(gym.ObservationWrapper, SpaceFlattener):
class FlattenerWrapper(gym.ObservationWrapper):
def __init__(self, env: gym.Env):
super().__init__(env)

self.observation_space = self.flatten_space(env.observation_space)
self.observation_flattener = SpaceFlattener(env.observation_space)
self.observation_space = self.observation_flattener.space

self.action_flattener = SpaceFlattener(env.action_space)
self.action_space = self.action_flattener.space

# Adds action in the space
self.has_action = env.observation_space.get("action", None) is not None
if self.has_action:
self.observation_space["action"] = env.action_space
self.observation_space["action"] = self.action_flattener.space

def observation(self, observation):
action = {}
if self.has_action:
action_dict = {
key: np.array([value]) for key, value in observation["action"].items()
}
from .envs import DiscreteActionSTKRaceEnv

if isinstance(self.unwrapped, DiscreteActionSTKRaceEnv):
action_dict = self.unwrapped.to_discrete(action_dict)
action = {f"action/{key}": value for key, value in action_dict.items()}

return {
**action,
"env_obs/discrete": np.array(
[observation[key] for key in self.discrete_keys]
new_obs = {
"discrete": np.array(
[observation[key] for key in self.observation_flattener.discrete_keys]
),
"env_obs/continuous": np.concatenate(
[observation[key].flatten() for key in self.continuous_keys]
"continuous": np.concatenate(
[
observation[key].flatten()
for key in self.observation_flattener.continuous_keys
]
),
}

if self.has_action:
obs_action = observation["action"]
discrete = np.array(
[obs_action[key] for key in self.action_flattener.discrete_keys]
)
if self.action_flattener.only_discrete:
new_obs["action"] = discrete
else:
continuous = np.concatenate(
[
obs_action[key].flatten()
for key in self.action_flattener.continuous_keys
]
)
new_obs["action"] = {"discrete": discrete, "continuous": continuous}

class ActionFlattenerWrapper(gym.ActionWrapper, SpaceFlattener):
def __init__(self, env: gym.Env):
super().__init__(env)
self.action_space = self.flatten_space(env.action_space)
return new_obs

def step(self, action) -> tuple[Any, float, bool, bool, dict[str, Any]]:
return super().step(self.action(action))

def action(self, action):
if self.only_discrete:
assert len(self.discrete_keys) == len(action), (
if self.action_flattener.only_discrete:
assert len(self.action_flattener.discrete_keys) == len(action), (
"Not enough discrete values: "
f"""expected {len(self.discrete_keys)}, got {len(action)}"""
f"""expected {len(self.action_flattener.discrete_keys)}, """
f"""got {len(action)}"""
)
return {
key: key_action for key, key_action in zip(self.discrete_keys, action)
action = {
key: key_action
for key, key_action in zip(self.action_flattener.discrete_keys, action)
}

assert len(self.discrete_keys) == len(
action["env_obs/discrete"]
), "Not enough discrete values: "
f"""expected {len(self.discrete_keys)}, got {len(action["env_obs/discrete"])}"""
discrete = {
key: key_action
for key, key_action in zip(self.discrete_keys, action["env_obs/discrete"])
}
continuous = {
key: action["env_obs/continuous"][
self.indices[ix] : self.indices[ix + 1]
].reshape(shape)
for ix, (key, shape) in enumerate(zip(self.continuous_keys, self.shapes))
}
return {**discrete, **continuous}
else:
assert len(self.action_flattener.discrete_keys) == len(
action["discrete"]
), "Not enough discrete values: "
f"""expected {len(self.discrete_keys)}, got {len(action["discrete"])}"""
discrete = {
key: key_action
for key, key_action in zip(
self.action_flattener.discrete_keys, action["discrete"]
)
}
continuous = {
key: action["continuous"][
self.indices[ix] : self.indices[ix + 1]
].reshape(shape)
for ix, (key, shape) in enumerate(
zip(self.action_flattener.continuous_keys, self.shapes)
)
}
action = {**discrete, **continuous}

return action

0 comments on commit 039befe

Please sign in to comment.