From d68dfce297b03e240f2d0b02dc26613665668830 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Matteo=20G=C3=A4tzner?= <50359250+MatteoGaetzner@users.noreply.github.com>
Date: Tue, 27 Jun 2023 08:32:18 +0200
Subject: [PATCH] [Doc] Fixed typos and improved cross-referencing in ppo tutorial (#2490)

* [Doc] Fixed typos and improved cross-referencing in ppo tutorial

* improved cross-referencing in ppo tutorial
---
 intermediate_source/reinforcement_ppo.py | 105 +++++++++++------------
 1 file changed, 51 insertions(+), 54 deletions(-)

diff --git a/intermediate_source/reinforcement_ppo.py b/intermediate_source/reinforcement_ppo.py
index b7baba8c1f3..6501e98971e 100644
--- a/intermediate_source/reinforcement_ppo.py
+++ b/intermediate_source/reinforcement_ppo.py
@@ -16,7 +16,7 @@
 Key learnings:
 
 - How to create an environment in TorchRL, transform its outputs, and collect
   data from this environment;
-- How to make your classes talk to each other using :class:`tensordict.TensorDict`;
+- How to make your classes talk to each other using :class:`~tensordict.TensorDict`;
 - The basics of building your training loop with TorchRL:
 
   - How to compute the advantage signal for policy gradient methods;
@@ -56,7 +56,7 @@
 # problem rather than re-inventing the wheel every time you want to train a policy.
 #
 # For completeness, here is a brief overview of what the loss computes, even though
-# this is taken care of by our :class:`ClipPPOLoss` module—the algorithm works as follows:
+# this is taken care of by our :class:`~torchrl.objectives.ClipPPOLoss` module—the algorithm works as follows:
 # 1. we will sample a batch of data by playing the
 # policy in the environment for a given number of steps.
 # 2. Then, we will perform a given number of optimization steps with random sub-samples of this batch using
@@ -99,7 +99,7 @@
 # 5. Finally, we will run our training loop and analyze the results.
 #
 # Throughout this tutorial, we'll be using the :mod:`tensordict` library.
-# :class:`tensordict.TensorDict` is the lingua franca of TorchRL: it helps us abstract
+# :class:`~tensordict.TensorDict` is the lingua franca of TorchRL: it helps us abstract
 # what a module reads and writes and care less about the specific data
 # description and more about the algorithm itself.
 #
@@ -115,13 +115,8 @@
 from torchrl.data.replay_buffers import ReplayBuffer
 from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
 from torchrl.data.replay_buffers.storages import LazyTensorStorage
-from torchrl.envs import (
-    Compose,
-    DoubleToFloat,
-    ObservationNorm,
-    StepCounter,
-    TransformedEnv,
-)
+from torchrl.envs import (Compose, DoubleToFloat, ObservationNorm, StepCounter,
+                          TransformedEnv)
 from torchrl.envs.libs.gym import GymEnv
 from torchrl.envs.utils import check_env_specs, set_exploration_mode
 from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
@@ -143,7 +138,7 @@
 #
 
 device = "cpu" if not torch.has_cuda else "cuda:0"
-num_cells = 256  # number of cells in each layer
+num_cells = 256  # number of cells in each layer, i.e., output dim.
 lr = 3e-4
 max_grad_norm = 1.0
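# For readers skimming this patch, here is a minimal sketch of the clipped
# surrogate objective that ``ClipPPOLoss`` (mentioned above) computes. This is
# illustrative only and not part of the patch; the tensors below are random
# stand-ins for log-probabilities and advantage estimates.
import torch

clip_epsilon = 0.2
log_prob = torch.randn(64)       # log-prob of actions under the current policy
prev_log_prob = torch.randn(64)  # log-prob under the policy that collected the data
advantage = torch.randn(64)      # advantage estimates (for example, from GAE)

ratio = (log_prob - prev_log_prob).exp()
clipped_ratio = ratio.clamp(1.0 - clip_epsilon, 1.0 + clip_epsilon)
# PPO maximizes the clipped surrogate, i.e., minimizes its negation.
policy_loss = -torch.min(ratio * advantage, clipped_ratio * advantage).mean()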
@@ -231,8 +226,8 @@
 # We will append some transforms to our environments to prepare the data for
 # the policy. In Gym, this is usually achieved via wrappers. TorchRL takes a different
 # approach, more similar to other pytorch domain libraries, through the use of transforms.
-# To add transforms to an environment, one should simply wrap it in a :class:`TransformedEnv`
-# instance, and append the sequence of transforms to it. The transformed environment will inherit
+# To add transforms to an environment, one should simply wrap it in a :class:`~torchrl.envs.transforms.TransformedEnv`
+# instance and append the sequence of transforms to it. The transformed environment will inherit
 # the device and meta-data of the wrapped environment, and transform these depending on the sequence
 # of transforms it contains.
 #
@@ -245,13 +240,13 @@
 # run a certain number of random steps in the environment and compute
 # the summary statistics of these observations.
 #
-# We'll append two other transforms: the :class:`DoubleToFloat` transform will
+# We'll append two other transforms: the :class:`~torchrl.envs.transforms.DoubleToFloat` transform will
 # convert double entries to single-precision numbers, ready to be read by the
-# policy. The :class:`StepCounter` transform will be used to count the steps before
+# policy. The :class:`~torchrl.envs.transforms.StepCounter` transform will be used to count the steps before
 # the environment is terminated. We will use this measure as a supplementary measure
 # of performance.
 #
-# As we will see later, many of the TorchRL's classes rely on :class:`tensordict.TensorDict`
+# As we will see later, many of TorchRL's classes rely on :class:`~tensordict.TensorDict`
 # to communicate. You could think of it as a python dictionary with some extra
 # tensor features. In practice, this means that many modules we will be working
 # with need to be told what key to read (``in_keys``) and what key to write
@@ -274,13 +269,13 @@
 ######################################################################
 # As you may have noticed, we have created a normalization layer but we did not
-# set its normalization parameters. To do this, :class:`ObservationNorm` can
+# set its normalization parameters. To do this, :class:`~torchrl.envs.transforms.ObservationNorm` can
 # automatically gather the summary statistics of our environment:
 #
 env.transform[0].init_stats(num_iter=1000, reduce_dim=0, cat_dim=0)
 
 ######################################################################
-# The :class:`ObservationNorm` transform has now been populated with a
+# The :class:`~torchrl.envs.transforms.ObservationNorm` transform has now been populated with a
 # location and a scale that will be used to normalize the data.
 #
 # Let us do a little sanity check for the shape of our summary stats:
@@ -294,7 +289,8 @@
 # For efficiency purposes, TorchRL is quite stringent when it comes to
 # environment specs, but you can easily check that your environment specs are
 # adequate.
-# In our example, the :class:`GymWrapper` and :class:`GymEnv` that inherits
+# In our example, the :class:`~torchrl.envs.libs.gym.GymWrapper` and
+# :class:`~torchrl.envs.libs.gym.GymEnv` that inherits
 # from it already take care of setting the proper specs for your environment so
 # you should not have to care about this.
 #
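# For context, the transformed environment that these hunks document is assembled
# roughly as follows in the tutorial. This is a sketch, not part of the patch, and
# it assumes a MuJoCo-enabled Gym installation for the InvertedDoublePendulum task.
from torchrl.envs import Compose, DoubleToFloat, ObservationNorm, StepCounter, TransformedEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.utils import check_env_specs

base_env = GymEnv("InvertedDoublePendulum-v4", device="cpu")
env = TransformedEnv(
    base_env,
    Compose(
        ObservationNorm(in_keys=["observation"]),  # normalize observations
        DoubleToFloat(in_keys=["observation"]),    # cast float64 observations to float32
        StepCounter(),                             # count steps until termination
    ),
)
env.transform[0].init_stats(num_iter=1000, reduce_dim=0, cat_dim=0)
check_env_specs(env)  # sanity-check the specs of the transformed environment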
@@ -327,9 +323,9 @@
 # action as input, and outputs an observation, a reward and a done state. The
 # observation may be composite, meaning that it could be composed of more than one
 # tensor. This is not a problem for TorchRL, since the whole set of observations
-# is automatically packed in the output :class:`tensordict.TensorDict`. After executing a rollout
+# is automatically packed in the output :class:`~tensordict.TensorDict`. After executing a rollout
 # (for example, a sequence of environment steps and random action generations) over a given
-# number of steps, we will retrieve a :class:`tensordict.TensorDict` instance with a shape
+# number of steps, we will retrieve a :class:`~tensordict.TensorDict` instance with a shape
 # that matches this trajectory length:
 #
 rollout = env.rollout(3)
@@ -339,7 +335,7 @@
 ######################################################################
 # Our rollout data has a shape of ``torch.Size([3])``, which matches the number of steps
 # we ran it for. The ``"next"`` entry points to the data coming after the current step.
-# In most cases, the ``"next""`` data at time `t` matches the data at ``t+1``, but this
+# In most cases, the ``"next"`` data at time ``t`` matches the data at ``t+1``, but this
 # may not be the case if we are using some specific transformations (for example, multi-step).
 #
 # Policy
@@ -364,12 +360,11 @@
 #
 # We design the policy in three steps:
 #
-# 1. Define a neural network ``D_obs`` -> ``2 * D_action``. Indeed, our ``loc`` (mu) and ``scale`` (sigma) both have dimension ``D_action``;
+# 1. Define a neural network ``D_obs`` -> ``2 * D_action``. Indeed, our ``loc`` (mu) and ``scale`` (sigma) both have dimension ``D_action``.
 #
-# 2. Append a :class:`NormalParamExtractor` to extract a location and a scale (for example, splits the input in two equal parts
-# and applies a positive transformation to the scale parameter);
+# 2. Append a :class:`~tensordict.nn.distributions.NormalParamExtractor` to extract a location and a scale (for example, it splits the input in two equal parts and applies a positive transformation to the scale parameter).
 #
-# 3. Create a probabilistic :class:`TensorDictModule` that can create this distribution and sample from it.
+# 3. Create a probabilistic :class:`~tensordict.nn.TensorDictModule` that can generate this distribution and sample from it.
 #
 actor_net = nn.Sequential(
@@ -385,7 +380,7 @@
 ######################################################################
 # To enable the policy to "talk" with the environment through the ``tensordict``
-# data carrier, we wrap the ``nn.Module`` in a :class:`TensorDictModule`. This
+# data carrier, we wrap the ``nn.Module`` in a :class:`~tensordict.nn.TensorDictModule`. This
 # class will simply ready the ``in_keys`` it is provided with and write the
 # outputs in-place at the registered ``out_keys``.
 #
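# A condensed sketch of the network and :class:`TensorDictModule` wrapper described
# above (illustrative only, not part of the patch; ``num_cells`` and ``d_action``
# stand in for the tutorial's values).
import torch.nn as nn
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor

num_cells, d_action = 256, 1
actor_net = nn.Sequential(
    nn.LazyLinear(num_cells),
    nn.Tanh(),
    nn.LazyLinear(2 * d_action),  # one half for ``loc``, one half for ``scale``
    NormalParamExtractor(),       # splits the output into a loc and a positive scale
)
policy_module = TensorDictModule(
    actor_net, in_keys=["observation"], out_keys=["loc", "scale"]
)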
@@ -395,18 +390,19 @@
 ######################################################################
 # We now need to build a distribution out of the location and scale of our
-# normal distribution. To do so, we instruct the :class:`ProbabilisticActor`
-# class to build a :class:`TanhNormal` out of the location and scale
+# normal distribution. To do so, we instruct the
+# :class:`~torchrl.modules.tensordict_module.ProbabilisticActor`
+# class to build a :class:`~torchrl.modules.TanhNormal` out of the location and scale
 # parameters. We also provide the minimum and maximum values of this
 # distribution, which we gather from the environment specs.
 #
 # The name of the ``in_keys`` (and hence the name of the ``out_keys`` from
-# the :class:`TensorDictModule` above) cannot be set to any value one may
-# like, as the :class:`TanhNormal` distribution constructor will expect the
+# the :class:`~tensordict.nn.TensorDictModule` above) cannot be set to any value one may
+# like, as the :class:`~torchrl.modules.TanhNormal` distribution constructor will expect the
 # ``loc`` and ``scale`` keyword arguments. That being said,
-# :class:`ProbabilisticActor` also accepts ``Dict[str, str]`` typed ``in_keys``
-# where the key-value pair indicates what ``in_key`` string should be used for
-# every keyword argument that is to be used.
+# :class:`~torchrl.modules.tensordict_module.ProbabilisticActor` also accepts
+# ``Dict[str, str]`` typed ``in_keys`` where the key-value pair indicates
+# what ``in_key`` string should be used for every keyword argument that is to be used.
 #
 policy_module = ProbabilisticActor(
     module=policy_module,
@@ -450,7 +446,7 @@
 ######################################################################
 # let's try our policy and value modules. As we said earlier, the usage of
-# :class:`TensorDictModule` makes it possible to directly read the output
+# :class:`~tensordict.nn.TensorDictModule` makes it possible to directly read the output
 # of the environment to run these modules, as they know what information to read
 # and where to write it:
 #
@@ -461,11 +457,11 @@
 # Data collector
 # --------------
 #
-# TorchRL provides a set of :class:`DataCollector` classes. Briefly, these
-# classes execute three operations: reset an environment, compute an action
-# given the latest observation, execute a step in the environment, and repeat
-# the last two steps until the environment reaches a stop signal (or ``"done"``
-# state).
+# TorchRL provides a set of `DataCollector classes <https://pytorch.org/rl/reference/collectors.html>`__.
+# Briefly, these classes execute three operations: reset an environment,
+# compute an action given the latest observation, execute a step in the environment,
+# and repeat the last two steps until the environment signals a stop (or reaches
+# a done state).
 #
 # They allow you to control how many frames to collect at each iteration
 # (through the ``frames_per_batch`` parameter),
@@ -473,17 +469,18 @@
 # on which ``device`` the policy should be executed, etc. They are also
 # designed to work efficiently with batched and multiprocessed environments.
 #
-# The simplest data collector is the :class:`SyncDataCollector`: it is an
-# iterator that you can use to get batches of data of a given length, and
+# The simplest data collector is the :class:`~torchrl.collectors.collectors.SyncDataCollector`:
+# it is an iterator that you can use to get batches of data of a given length, and
 # that will stop once a total number of frames (``total_frames``) have been
 # collected.
-# Other data collectors (``MultiSyncDataCollector`` and
-# ``MultiaSyncDataCollector``) will execute the same operations in synchronous
-# and asynchronous manner over a set of multiprocessed workers.
+# Other data collectors (:class:`~torchrl.collectors.collectors.MultiSyncDataCollector` and
+# :class:`~torchrl.collectors.collectors.MultiaSyncDataCollector`) will execute
+# the same operations in a synchronous and asynchronous manner over a
+# set of multiprocessed workers.
 #
 # As for the policy and environment before, the data collector will return
-# :class:`tensordict.TensorDict` instances with a total number of elements that will
-# match ``frames_per_batch``. Using :class:`tensordict.TensorDict` to pass data to the
+# :class:`~tensordict.TensorDict` instances with a total number of elements that will
+# match ``frames_per_batch``. Using :class:`~tensordict.TensorDict` to pass data to the
 # training loop allows you to write data loading pipelines
 # that are 100% oblivious to the actual specificities of the rollout content.
 #
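# A minimal sketch of how the collector described above is typically created and
# consumed (not part of the patch; parameter values are illustrative, and ``env``
# and ``policy_module`` are the objects built earlier in the tutorial).
from torchrl.collectors import SyncDataCollector

collector = SyncDataCollector(
    env,
    policy_module,
    frames_per_batch=1000,   # frames returned at each iteration
    total_frames=50_000,     # stop after this many frames in total
    split_trajs=False,
    device="cpu",
)
for tensordict_data in collector:
    # each item is a TensorDict holding ``frames_per_batch`` transitions
    pass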
@@ -506,10 +503,10 @@
 # of epochs.
 #
 # TorchRL's replay buffers are built using a common container
-# :class:`ReplayBuffer` which takes as argument the components of the buffer:
-# a storage, a writer, a sampler and possibly some transforms. Only the
-# storage (which indicates the replay buffer capacity) is mandatory. We
-# also specify a sampler without repetition to avoid sampling multiple times
+# :class:`~torchrl.data.ReplayBuffer` which takes as argument the components
+# of the buffer: a storage, a writer, a sampler and possibly some transforms.
+# Only the storage (which indicates the replay buffer capacity) is mandatory.
+# We also specify a sampler without repetition to avoid sampling multiple times
 # the same item in one epoch.
 # Using a replay buffer for PPO is not mandatory and we could simply
 # sample the sub-batches from the collected batch, but using these classes
@@ -526,7 +523,7 @@
 # -------------
 #
 # The PPO loss can be directly imported from TorchRL for convenience using the
-# :class:`ClipPPOLoss` class. This is the easiest way of utilizing PPO:
+# :class:`~torchrl.objectives.ClipPPOLoss` class. This is the easiest way of utilizing PPO:
 # it hides away the mathematical operations of PPO and the control flow that
 # goes with it.
 #
@@ -540,7 +537,7 @@
 # ``"value_target"`` entries.
 # The ``"value_target"`` is a gradient-free tensor that represents the empirical
 # value that the value network should represent with the input observation.
-# Both of these will be used by :class:`ClipPPOLoss` to
+# Both of these will be used by :class:`~torchrl.objectives.ClipPPOLoss` to
 # return the policy and value losses.
 #
@@ -693,7 +690,7 @@
 #
 # * From an efficiency perspective,
 #   we could run several simulations in parallel to speed up data collection.
-#   Check :class:`torchrl.envs.ParallelEnv` for further information.
+#   Check :class:`~torchrl.envs.ParallelEnv` for further information.
 #
 # * From a logging perspective, one could add a :class:`torchrl.record.VideoRecorder` transform to
 #   the environment after asking for rendering to get a visual rendering of the