Merge branch 'main' into master
woongjoonchoi authored Sep 25, 2024
2 parents f11d798 + 4126761 commit 61f65fd
Showing 3 changed files with 37 additions and 11 deletions.
advanced_source/extend_dispatcher.rst: 10 changes (5 additions, 5 deletions)
@@ -17,7 +17,7 @@ to `register a dispatched operator in C++ <dispatcher>`_ and how to write a
What's a new backend?
---------------------

Adding a new backend to PyTorch requires a lot of developement and maintainence from backend extenders.
Adding a new backend to PyTorch requires a lot of development and maintenance from backend extenders.
Before adding a new backend, let's first consider a few common use cases and recommended solutions for them:

* If you have new algorithms for an existing PyTorch operator, send a PR to PyTorch.
@@ -30,7 +30,7 @@ Before adding a new backend, let's first consider a few common use cases and rec

In this tutorial we'll mainly focus on adding a new out-of-tree device below. Adding out-of-tree support
for a different tensor layout might share many common steps with devices, but we haven't seen an example of
such integrations yet so it might require addtional work from PyTorch to support it.
such integrations yet so it might require additional work from PyTorch to support it.

Get a dispatch key for your backend
-----------------------------------
@@ -67,12 +67,12 @@ To create a Tensor on ``PrivateUse1`` backend, you need to set dispatch key in `
Note that the ``TensorImpl`` class above assumes your Tensor is backed by a storage like CPU/CUDA. We also
provide ``OpaqueTensorImpl`` for backends without a storage. And you might need to tweak/override certain
methods to fit your customized hardware.
One example in pytorch repo is `Vulkan TensorImpl <https://github.com/pytorch/pytorch/blob/1.7/aten/src/ATen/native/vulkan/VulkanOpaqueTensorImpl.h>`_.
One example in pytorch repo is `Vulkan TensorImpl <https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/vulkan/VulkanOpaqueTensorImpl.h>`_.


.. note::
Once the prototype is done and you plan to do regular releases for your backend extension, please feel free to
submit a PR to ``pytorch/pytorch`` to reserve a dedicated dispath key for your backend.
submit a PR to ``pytorch/pytorch`` to reserve a dedicated dispatch key for your backend.
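
As a rough sketch of the Python-side wiring this section implies (assuming an out-of-tree C++ extension has already
registered its kernels and device guard under the ``PrivateUse1`` key; the helpers ``torch.utils.rename_privateuse1_backend``
and ``torch.utils.generate_methods_for_privateuse1_backend`` exist in recent PyTorch releases, and the name ``"my_device"``
is a placeholder):

.. code:: python

   import torch

   # Sketch only: presumes a C++ extension has already registered kernels
   # under the PrivateUse1 dispatch key for this backend.
   torch.utils.rename_privateuse1_backend("my_device")

   # Optionally generate helpers such as Tensor.is_my_device / Tensor.my_device().
   torch.utils.generate_methods_for_privateuse1_backend()

   # Tensors created with device="my_device" now dispatch to the PrivateUse1 kernels.
   x = torch.empty(2, 2, device="my_device")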


Get the full list of PyTorch operators
@@ -361,7 +361,7 @@ actively working on might improve the experience in the future:

* Improve test coverage of generic testing framework.
* Improve ``Math`` kernel coverage and more comprehensive tests to make sure ``Math``
kernel bahavior matches other backends like ``CPU/CUDA``.
kernel behavior matches other backends like ``CPU/CUDA``.
* Refactor ``RegistrationDeclarations.h`` to carry the minimal information and reuse
PyTorch's codegen as much as possible.
* Support a backend fallback kernel to automatically convert inputs to CPU and convert the
beginner_source/dist_overview.rst: 2 changes (1 addition, 1 deletion)
@@ -70,7 +70,7 @@ When deciding what parallelism techniques to choose for your model, use these co
#. Use `DistributedDataParallel (DDP) <https://pytorch.org/docs/stable/notes/ddp.html>`__,
if your model fits in a single GPU but you want to easily scale up training using multiple GPUs.

* Use `torchrun <https://pytorch.org/docs/stable/elastic/run.html>`__, to launch multiple pytorch processes if you are you using more than one node.
* Use `torchrun <https://pytorch.org/docs/stable/elastic/run.html>`__ to launch multiple PyTorch processes if you are using more than one node.

* See also: `Getting Started with Distributed Data Parallel <../intermediate/ddp_tutorial.html>`__
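
A minimal sketch of the DDP-plus-``torchrun`` pattern described above (the model, hyperparameters, and the ``train.py``
file name are illustrative placeholders, not part of the tutorial):

.. code:: python

   import os

   import torch
   import torch.distributed as dist
   from torch.nn.parallel import DistributedDataParallel as DDP

   def main():
       # torchrun sets RANK, LOCAL_RANK, and WORLD_SIZE for each process it launches.
       dist.init_process_group(backend="nccl")
       local_rank = int(os.environ["LOCAL_RANK"])
       torch.cuda.set_device(local_rank)

       model = torch.nn.Linear(10, 10).to(local_rank)
       ddp_model = DDP(model, device_ids=[local_rank])
       optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.01)

       # One illustrative step; DDP all-reduces gradients across ranks during backward().
       loss = ddp_model(torch.randn(8, 10, device=local_rank)).sum()
       loss.backward()
       optimizer.step()

       dist.destroy_process_group()

   if __name__ == "__main__":
       main()

Launched on every node with, for example, ``torchrun --nnodes=2 --nproc_per_node=8 train.py``.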

prototype_source/flight_recorder_tutorial.rst: 36 changes (31 additions, 5 deletions)
@@ -46,13 +46,15 @@ Flight Recorder consists of two core parts:

Enabling Flight Recorder
------------------------
There are two required environment variables to get the initial version of Flight Recorder working.
There are three required environment variables to get the initial version of Flight Recorder working.

- ``TORCH_NCCL_TRACE_BUFFER_SIZE = (0, N)``: Setting ``N`` to a positive number enables collection.
``N`` represents the number of entries that will be kept internally in a circular buffer.
We recommended to set this value at *2000*.
We recommend setting this value to *2000*. The default value is ``2000``.
- ``TORCH_NCCL_DUMP_ON_TIMEOUT = (true, false)``: Setting this to ``true`` will write out diagnostic files to disk on job timeout.
If enabled, there will be one file per rank output in the job's running directory.
If enabled, there will be one file per rank output in the job's running directory. The default value is ``false``.
- ``TORCH_NCCL_DEBUG_INFO_TEMP_FILE``: Sets the path, including a file prefix, where the Flight Recorder data will be dumped. One file
  is written per rank. The default value is ``/tmp/nccl_trace_rank_``.
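
These variables are usually exported by the job launcher, but as a hedged sketch they can also be set from Python before
the NCCL process group is created (the values mirror the recommendations and defaults listed above):

.. code:: python

   import os

   # Set these before the NCCL process group is created (for example, before
   # torch.distributed.init_process_group) so Flight Recorder can pick them up.
   os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "2000"                       # enable, keep 2000 entries
   os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = "true"                         # dump diagnostics on timeout
   os.environ["TORCH_NCCL_DEBUG_INFO_TEMP_FILE"] = "/tmp/nccl_trace_rank_"   # per-rank output file prefix

   import torch.distributed as dist

   dist.init_process_group(backend="nccl")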

**Optional settings:**

@@ -71,6 +73,10 @@ Additional Settings

``fast`` is a new experimental mode that is shown to be much faster than the traditional ``addr2line``.
Use this setting in conjunction with ``TORCH_NCCL_TRACE_CPP_STACK`` to collect C++ traces in the Flight Recorder data.
- If you prefer not to have the Flight Recorder data dumped to local disk but written to your own storage instead, you can define your
  own writer class. This class should inherit from ``::c10d::DebugInfoWriter`` `(code) <https://github.com/pytorch/pytorch/blob/release/2.5/torch/csrc/distributed/c10d/NCCLUtils.hpp#L237>`__
  and then register the new writer using ``::c10d::DebugInfoWriter::registerWriter`` `(code) <https://github.com/pytorch/pytorch/blob/release/2.5/torch/csrc/distributed/c10d/NCCLUtils.hpp#L242>`__
  before PyTorch distributed is initialized.

Retrieving Flight Recorder Data via an API
------------------------------------------
@@ -169,9 +175,29 @@ To run the convenience script, follow these steps:

2. To run the script, use this command:

   .. code:: python
   .. code:: shell

      python fr_trace.py <dump dir containing trace files> [-o <output file>]

If you install the PyTorch nightly build or build from scratch with ``USE_DISTRIBUTED=1``, you can use the following
command directly:

.. code:: shell

   torchfrtrace <dump dir containing trace files> [-o <output file>]

Currently, we support two modes for the analyzer script. The first mode allows the script to apply some heuristics to the parsed
flight recorder dumps to generate a report identifying potential culprits for the timeout. The second mode simply outputs the raw
dumps. By default, the script prints flight recorder dumps for all ranks and all ``ProcessGroups`` (PGs). This can be narrowed down
to certain ranks and PGs using the *--selected-ranks* argument. An example command is:

Caveat: the ``tabulate`` module is needed, so you might need to ``pip install`` it first.

.. code:: shell

   python fr_trace.py -d <dump dir containing trace files> [-o <output file>]
   python fr_trace.py <dump dir containing trace files> -j [--selected-ranks i j k ...]
   torchfrtrace <dump dir containing trace files> -j [--selected-ranks i j k ...]

Conclusion
----------
