Skip to content

Commit

Permalink
mark autograd function traceable (pytorch#3143)
Browse files Browse the repository at this point in the history
Summary:
X-link: facebookresearch/FBGEMM#236

Pull Request resolved: pytorch#3143

# context
* encountered the following error:
```
Failures:

  1) torchrec.distributed.train_pipeline.tests.test_train_pipelines.TrainPipelineSparseDistCompAutogradTest: test_equal_to_non_pipelined
    1) RuntimeError: Attempting to trace a potentially unsafe C++ autograd function: torch::autograd::CppNode<fbgemm_gpu::PermuteMultiEmbeddingOp>. It may be possible to trace it safely, please refer to the instructions in: https://docs.google.com/document/d/11VucFBEewzqgkABIjebZIzMvrXr3BtcY1aGKpX61pJY/.
      File "torchrec/distributed/train_pipeline/tests/test_train_pipelines.py", line 461, in test_equal_to_non_pipelined
        not torch.cuda.is_available(),
      File "hypothesis/core.py", line 1602, in wrapped_test
        raise the_error_hypothesis_found
      File "torchrec/distributed/train_pipeline/tests/test_train_pipelines.py", line 531, in test_equal_to_non_pipelined
        pred_pipeline = pipeline.progress(dataloader)
      File "torchrec/distributed/train_pipeline/train_pipelines.py", line 1603, in progress
        return super().progress(dataloader_iter)
      File "torchrec/distributed/train_pipeline/train_pipelines.py", line 499, in progress
        torch.sum(losses, dim=0).backward()
      File "torch/_tensor.py", line 581, in backward
        torch.autograd.backward(
      File "torch/autograd/__init__.py", line 347, in backward
        _engine_run_backward(
      File "torch/autograd/graph.py", line 825, in _engine_run_backward
        return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
```
* it turned out that the autograd function is not marked as traceable: [doc](https://fburl.com/gdoc/vh88gnk8)
* the test passed after the diff

Reviewed By: IvanKobzarev

Differential Revision: D62889399

fbshipit-source-id: f6cc8b3e2d5c2d407111ddb13124d202000a5d65
  • Loading branch information
TroyGarden authored and facebook-github-bot committed Sep 24, 2024
1 parent 012a658 commit 03537c6
Showing 1 changed file with 1 addition and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ using torch::autograd::variable_list;
class PermuteMultiEmbeddingOp
: public torch::autograd::Function<PermuteMultiEmbeddingOp> {
public:
static constexpr bool is_traceable = true;
static variable_list forward(
AutogradContext* ctx,
const at::TensorList& pooled_embs,
Expand Down

0 comments on commit 03537c6

Please sign in to comment.