Commit

Merge branch 'main' into update_pg_tutorial
H-Huang authored Jun 27, 2023
2 parents 1de0a51 + d68dfce commit c500c29
Showing 1 changed file with 51 additions and 54 deletions.
105 changes: 51 additions & 54 deletions intermediate_source/reinforcement_ppo.py
@@ -16,7 +16,7 @@
Key learnings:
- How to create an environment in TorchRL, transform its outputs, and collect data from this environment;
- How to make your classes talk to each other using :class:`tensordict.TensorDict`;
- How to make your classes talk to each other using :class:`~tensordict.TensorDict`;
- The basics of building your training loop with TorchRL:
- How to compute the advantage signal for policy gradient methods;
@@ -56,7 +56,7 @@
# problem rather than re-inventing the wheel every time you want to train a policy.
#
# For completeness, here is a brief overview of what the loss computes, even though
# this is taken care of by our :class:`ClipPPOLoss` module—the algorithm works as follows:
# this is taken care of by our :class:`~torchrl.objectives.ClipPPOLoss` module—the algorithm works as follows:
# 1. we will sample a batch of data by playing the
# policy in the environment for a given number of steps.
# 2. Then, we will perform a given number of optimization steps with random sub-samples of this batch using
@@ -99,7 +99,7 @@
# 5. Finally, we will run our training loop and analyze the results.
#
# Throughout this tutorial, we'll be using the :mod:`tensordict` library.
# :class:`tensordict.TensorDict` is the lingua franca of TorchRL: it helps us abstract
# :class:`~tensordict.TensorDict` is the lingua franca of TorchRL: it helps us abstract
# what a module reads and writes and care less about the specific data
# description and more about the algorithm itself.
#
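######################################################################
# As a quick, self-contained illustration of what this means (a minimal
# sketch; the ``example_td`` object below is hypothetical and not used
# elsewhere in this tutorial), a ``TensorDict`` is indexed by string keys
# like a dictionary and along its batch dimensions like a tensor:
#
from tensordict import TensorDict

example_td = TensorDict(
    {"observation": torch.randn(3, 4), "reward": torch.zeros(3, 1)},
    batch_size=[3],
)
print(example_td["observation"].shape)  # lookup by key: torch.Size([3, 4])
print(example_td[0])  # indexing the batch dimension returns a smaller TensorDict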
@@ -115,13 +115,8 @@
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.envs import (
Compose,
DoubleToFloat,
ObservationNorm,
StepCounter,
TransformedEnv,
)
from torchrl.envs import (Compose, DoubleToFloat, ObservationNorm, StepCounter,
TransformedEnv)
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.utils import check_env_specs, set_exploration_mode
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
@@ -143,7 +138,7 @@
#

device = "cpu" if not torch.has_cuda else "cuda:0"
num_cells = 256 # number of cells in each layer
num_cells = 256 # number of cells in each layer i.e. output dim.
lr = 3e-4
max_grad_norm = 1.0

@@ -231,8 +226,8 @@
# We will append some transforms to our environments to prepare the data for
# the policy. In Gym, this is usually achieved via wrappers. TorchRL takes a different
# approach, more similar to other pytorch domain libraries, through the use of transforms.
# To add transforms to an environment, one should simply wrap it in a :class:`TransformedEnv`
# instance, and append the sequence of transforms to it. The transformed environment will inherit
# To add transforms to an environment, one should simply wrap it in a :class:`~torchrl.envs.transforms.TransformedEnv`
# instance and append the sequence of transforms to it. The transformed environment will inherit
# the device and meta-data of the wrapped environment, and transform these depending on the sequence
# of transforms it contains.
#
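######################################################################
# In other words, the general pattern looks roughly like the sketch below
# (illustrative only; the environment name and transform arguments here are
# placeholders, not necessarily the ones used later in this tutorial):
#
sketch_env = TransformedEnv(
    GymEnv("Pendulum-v1", device=device),
    Compose(
        ObservationNorm(in_keys=["observation"]),  # normalize observations
        DoubleToFloat(in_keys=["observation"]),  # cast float64 -> float32
        StepCounter(),  # count the steps taken in each trajectory
    ),
)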
@@ -245,13 +240,13 @@
# run a certain number of random steps in the environment and compute
# the summary statistics of these observations.
#
# We'll append two other transforms: the :class:`DoubleToFloat` transform will
# We'll append two other transforms: the :class:`~torchrl.envs.transforms.DoubleToFloat` transform will
# convert double entries to single-precision numbers, ready to be read by the
# policy. The :class:`StepCounter` transform will be used to count the steps before
# policy. The :class:`~torchrl.envs.transforms.StepCounter` transform will be used to count the steps before
# the environment is terminated. We will use this count as a supplementary measure
# of performance.
#
# As we will see later, many of TorchRL's classes rely on :class:`tensordict.TensorDict`
# As we will see later, many of TorchRL's classes rely on :class:`~tensordict.TensorDict`
# to communicate. You could think of it as a python dictionary with some extra
# tensor features. In practice, this means that many modules we will be working
# with need to be told what key to read (``in_keys``) and what key to write
@@ -274,13 +269,13 @@

######################################################################
# As you may have noticed, we have created a normalization layer but we did not
# set its normalization parameters. To do this, :class:`ObservationNorm` can
# set its normalization parameters. To do this, :class:`~torchrl.envs.transforms.ObservationNorm` can
# automatically gather the summary statistics of our environment:
#
env.transform[0].init_stats(num_iter=1000, reduce_dim=0, cat_dim=0)

######################################################################
# The :class:`ObservationNorm` transform has now been populated with a
# The :class:`~torchrl.envs.transforms.ObservationNorm` transform has now been populated with a
# location and a scale that will be used to normalize the data.
#
# Let us do a little sanity check for the shape of our summary stats:
@@ -294,7 +289,8 @@
# For efficiency purposes, TorchRL is quite stringent when it comes to
# environment specs, but you can easily check that your environment specs are
# adequate.
# In our example, the :class:`GymWrapper` and :class:`GymEnv` that inherits
# In our example, the :class:`~torchrl.envs.libs.gym.GymWrapper` and
# :class:`~torchrl.envs.libs.gym.GymEnv` that inherits
# from it already take care of setting the proper specs for your environment so
# you should not have to care about this.
#
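######################################################################
# If you want to verify this yourself, the ``check_env_specs`` helper imported
# above offers a quick sanity check (shown here purely as an illustration):
#
check_env_specs(env)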
@@ -327,9 +323,9 @@
# action as input, and outputs an observation, a reward and a done state. The
# observation may be composite, meaning that it could be composed of more than one
# tensor. This is not a problem for TorchRL, since the whole set of observations
# is automatically packed in the output :class:`tensordict.TensorDict`. After executing a rollout
# is automatically packed in the output :class:`~tensordict.TensorDict`. After executing a rollout
# (for example, a sequence of environment steps and random action generations) over a given
# number of steps, we will retrieve a :class:`tensordict.TensorDict` instance with a shape
# number of steps, we will retrieve a :class:`~tensordict.TensorDict` instance with a shape
# that matches this trajectory length:
#
rollout = env.rollout(3)
@@ -339,7 +335,7 @@
######################################################################
# Our rollout data has a shape of ``torch.Size([3])``, which matches the number of steps
# we ran it for. The ``"next"`` entry points to the data coming after the current step.
# In most cases, the ``"next""`` data at time `t` matches the data at ``t+1``, but this
# In most cases, the ``"next"`` data at time `t` matches the data at ``t+1``, but this
# may not be the case if we are using some specific transformations (for example, multi-step).
#
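######################################################################
# One illustrative way to see this relationship in the rollout above
# (assuming no multi-step transform is in play):
#
matches = (
    rollout["next", "observation"][:-1] == rollout["observation"][1:]
).all()
print("'next' observation at t equals observation at t+1:", matches)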
# Policy
@@ -364,12 +360,11 @@
#
# We design the policy in three steps:
#
# 1. Define a neural network ``D_obs`` -> ``2 * D_action``. Indeed, our ``loc`` (mu) and ``scale`` (sigma) both have dimension ``D_action``;
# 1. Define a neural network ``D_obs`` -> ``2 * D_action``. Indeed, our ``loc`` (mu) and ``scale`` (sigma) both have dimension ``D_action``.
#
# 2. Append a :class:`NormalParamExtractor` to extract a location and a scale (for example, splits the input in two equal parts
# and applies a positive transformation to the scale parameter);
# 2. Append a :class:`~tensordict.nn.distributions.NormalParamExtractor` to extract a location and a scale (for example, splits the input in two equal parts and applies a positive transformation to the scale parameter).
#
# 3. Create a probabilistic :class:`TensorDictModule` that can create this distribution and sample from it.
# 3. Create a probabilistic :class:`~tensordict.nn.TensorDictModule` that can generate this distribution and sample from it.
#

actor_net = nn.Sequential(
@@ -385,7 +380,7 @@

######################################################################
# To enable the policy to "talk" with the environment through the ``tensordict``
# data carrier, we wrap the ``nn.Module`` in a :class:`TensorDictModule`. This
# data carrier, we wrap the ``nn.Module`` in a :class:`~tensordict.nn.TensorDictModule`. This
# class will simply read the ``in_keys`` it is provided with and write the
# outputs in-place at the registered ``out_keys``.
#
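######################################################################
# A sketch of what this wrapping can look like, following the ``in_keys`` /
# ``out_keys`` convention discussed above (the import is repeated here only
# for completeness; the key names are the ones expected by the distribution
# built next):
#
from tensordict.nn import TensorDictModule

policy_module = TensorDictModule(
    actor_net, in_keys=["observation"], out_keys=["loc", "scale"]
)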
@@ -395,18 +390,19 @@

######################################################################
# We now need to build a distribution out of the location and scale of our
# normal distribution. To do so, we instruct the :class:`ProbabilisticActor`
# class to build a :class:`TanhNormal` out of the location and scale
# normal distribution. To do so, we instruct the
# :class:`~torchrl.modules.tensordict_module.ProbabilisticActor`
# class to build a :class:`~torchrl.modules.TanhNormal` out of the location and scale
# parameters. We also provide the minimum and maximum values of this
# distribution, which we gather from the environment specs.
#
# The name of the ``in_keys`` (and hence the name of the ``out_keys`` from
# the :class:`TensorDictModule` above) cannot be set to any value one may
# like, as the :class:`TanhNormal` distribution constructor will expect the
# the :class:`~tensordict.nn.TensorDictModule` above) cannot be set to any value one may
# like, as the :class:`~torchrl.modules.TanhNormal` distribution constructor will expect the
# ``loc`` and ``scale`` keyword arguments. That being said,
# :class:`ProbabilisticActor` also accepts ``Dict[str, str]`` typed ``in_keys``
# where the key-value pair indicates what ``in_key`` string should be used for
# every keyword argument that is to be used.
# :class:`~torchrl.modules.tensordict_module.ProbabilisticActor` also accepts
# ``Dict[str, str]`` typed ``in_keys`` where the key-value pair indicates
# what ``in_key`` string should be used for every keyword argument that is to be used.
#
policy_module = ProbabilisticActor(
module=policy_module,
@@ -450,7 +446,7 @@

######################################################################
# Let's try our policy and value modules. As we said earlier, the usage of
# :class:`TensorDictModule` makes it possible to directly read the output
# :class:`~tensordict.nn.TensorDictModule` makes it possible to directly read the output
# of the environment to run these modules, as they know what information to read
# and where to write it:
#
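######################################################################
# For example (illustrative only; ``value_module`` is used here as a stand-in
# name for the value-network wrapper defined in the section above):
#
print("Running policy:", policy_module(env.reset()))
print("Running value:", value_module(env.reset()))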
@@ -461,29 +457,30 @@
# Data collector
# --------------
#
# TorchRL provides a set of :class:`DataCollector` classes. Briefly, these
# classes execute three operations: reset an environment, compute an action
# given the latest observation, execute a step in the environment, and repeat
# the last two steps until the environment reaches a stop signal (or ``"done"``
# state).
# TorchRL provides a set of `DataCollector classes <https://pytorch.org/rl/reference/collectors.html>`__.
# Briefly, these classes execute three operations: reset an environment,
# compute an action given the latest observation, execute a step in the environment,
# and repeat the last two steps until the environment signals a stop (or reaches
# a done state).
#
# They allow you to control how many frames to collect at each iteration
# (through the ``frames_per_batch`` parameter),
# when to reset the environment (through the ``max_frames_per_traj`` argument),
# on which ``device`` the policy should be executed, etc. They are also
# designed to work efficiently with batched and multiprocessed environments.
#
# The simplest data collector is the :class:`SyncDataCollector`: it is an
# iterator that you can use to get batches of data of a given length, and
# The simplest data collector is the :class:`~torchrl.collectors.collectors.SyncDataCollector`:
# it is an iterator that you can use to get batches of data of a given length, and
# that will stop once a total number of frames (``total_frames``) have been
# collected.
# Other data collectors (``MultiSyncDataCollector`` and
# ``MultiaSyncDataCollector``) will execute the same operations in synchronous
# and asynchronous manner over a set of multiprocessed workers.
# Other data collectors (:class:`~torchrl.collectors.collectors.MultiSyncDataCollector` and
# :class:`~torchrl.collectors.collectors.MultiaSyncDataCollector`) will execute
# the same operations in synchronous and asynchronous manner over a
# set of multiprocessed workers.
#
# As for the policy and environment before, the data collector will return
# :class:`tensordict.TensorDict` instances with a total number of elements that will
# match ``frames_per_batch``. Using :class:`tensordict.TensorDict` to pass data to the
# :class:`~tensordict.TensorDict` instances with a total number of elements that will
# match ``frames_per_batch``. Using :class:`~tensordict.TensorDict` to pass data to the
# training loop allows you to write data loading pipelines
# that are 100% oblivious to the actual specificities of the rollout content.
#
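######################################################################
# A rough sketch of the construction described above (``frames_per_batch``
# and ``total_frames`` are assumed to be defined among the hyperparameters;
# the import path is shown for completeness):
#
from torchrl.collectors import SyncDataCollector

collector = SyncDataCollector(
    env,
    policy_module,
    frames_per_batch=frames_per_batch,
    total_frames=total_frames,
    device=device,
)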
@@ -506,10 +503,10 @@
# of epochs.
#
# TorchRL's replay buffers are built using a common container
# :class:`ReplayBuffer` which takes as argument the components of the buffer:
# a storage, a writer, a sampler and possibly some transforms. Only the
# storage (which indicates the replay buffer capacity) is mandatory. We
# also specify a sampler without repetition to avoid sampling multiple times
# :class:`~torchrl.data.ReplayBuffer` which takes as argument the components
# of the buffer: a storage, a writer, a sampler and possibly some transforms.
# Only the storage (which indicates the replay buffer capacity) is mandatory.
# We also specify a sampler without repetition to avoid sampling multiple times
# the same item in one epoch.
# Using a replay buffer for PPO is not mandatory and we could simply
# sample the sub-batches from the collected batch, but using these classes
Expand All @@ -526,7 +523,7 @@
# -------------
#
# The PPO loss can be directly imported from TorchRL for convenience using the
# :class:`ClipPPOLoss` class. This is the easiest way of utilizing PPO:
# :class:`~torchrl.objectives.ClipPPOLoss` class. This is the easiest way of utilizing PPO:
# it hides away the mathematical operations of PPO and the control flow that
# goes with it.
#
@@ -540,7 +537,7 @@
# ``"value_target"`` entries.
# The ``"value_target"`` is a gradient-free tensor that represents the empirical
# value that the value network should represent with the input observation.
# Both of these will be used by :class:`ClipPPOLoss` to
# Both of these will be used by :class:`~torchrl.objectives.ClipPPOLoss` to
# return the policy and value losses.
#
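######################################################################
# A rough sketch of how these pieces are typically put together (the numeric
# values and the ``value_module`` name are illustrative rather than the
# tutorial's actual hyperparameters; import paths are shown for completeness):
#
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE

advantage_module = GAE(
    gamma=0.99, lmbda=0.95, value_network=value_module, average_gae=True
)

loss_module = ClipPPOLoss(
    actor=policy_module,
    critic=value_module,
    clip_epsilon=0.2,  # the PPO clipping parameter
    entropy_coef=1e-4,  # weight of the entropy bonus
)
# Calling advantage_module on a batch of collected data writes the
# "advantage" and "value_target" entries that loss_module then consumes.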

@@ -693,7 +690,7 @@
#
# * From an efficiency perspective,
# we could run several simulations in parallel to speed up data collection.
# Check :class:`torchrl.envs.ParallelEnv` for further information.
# Check :class:`~torchrl.envs.ParallelEnv` for further information.
#
# * From a logging perspective, one could add a :class:`torchrl.record.VideoRecorder` transform to
# the environment after asking for rendering to get a visual rendering of the
