diff --git a/arm/arm/launch_utils.py b/arm/arm/launch_utils.py index cf947ca..03eddfc 100644 --- a/arm/arm/launch_utils.py +++ b/arm/arm/launch_utils.py @@ -45,7 +45,7 @@ def create_replay(batch_size: int, timesteps: int, prioritisation: bool, reward_dtype=np.float32, update_horizon=1, observation_elements=observation_elements, - extra_replay_elements=[ReplayElement('demo', (), np.bool)] + extra_replay_elements=[ReplayElement('demo', (), bool)] ) return replay_buffer