From 5270c646f6cc7a118dfe02767a075c9ace5882b0 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 20 Sep 2024 10:17:33 -0700 Subject: [PATCH] [BugFix] Fix parsing integer batch size in AOT ghstack-source-id: 18a5798c5377d3e5b65e7b6c87d59917c474fd64 Pull Request resolved: https://github.com/pytorch/tensordict/pull/1004 --- tensordict/_td.py | 10 +++++----- test/test_compile.py | 47 ++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 48 insertions(+), 9 deletions(-) diff --git a/tensordict/_td.py b/tensordict/_td.py index c7bda5ad7..92d0ae488 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -2061,7 +2061,7 @@ def _parse_batch_size( source: T | dict | None, batch_size: Sequence[int] | torch.Size | int | None = None, ) -> torch.Size: - ERR = "batch size was not specified when creating the TensorDict instance and it could not be retrieved from source." + ERR = "batch size {} was not specified when creating the TensorDict instance and it could not be retrieved from source." 
if is_dynamo_compiling(): if isinstance(batch_size, torch.Size): @@ -2072,22 +2072,22 @@ def _parse_batch_size( return torch.Size(tuple(batch_size)) if batch_size is None: return torch.Size([]) - elif isinstance(batch_size, Number): + elif isinstance(batch_size, (Number, torch.SymInt)): return torch.Size([batch_size]) elif isinstance(source, TensorDictBase): return source.batch_size - raise ValueError() + raise ValueError(ERR.format(batch_size)) try: return torch.Size(batch_size) except Exception: if batch_size is None: return torch.Size([]) - elif isinstance(batch_size, Number): + elif isinstance(batch_size, (Number, torch.SymInt)): return torch.Size([batch_size]) elif isinstance(source, TensorDictBase): return source.batch_size - raise ValueError(ERR) + raise ValueError(ERR.format(batch_size)) @property def batch_dims(self) -> int: diff --git a/test/test_compile.py b/test/test_compile.py index de9220cf6..bd843b074 100644 --- a/test/test_compile.py +++ b/test/test_compile.py @@ -774,16 +774,17 @@ def call(x, td): @pytest.mark.skipif(not _v2_5, reason="Requires PT>=2.5") +@pytest.mark.parametrize("strict", [True, False]) class TestExport: - def test_export_module(self): + def test_export_module(self, strict): torch._dynamo.reset_code_caches() tdm = Mod(lambda x, y: x * y, in_keys=["x", "y"], out_keys=["z"]) x = torch.randn(3) y = torch.randn(3) - out = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y}) + out = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y}, strict=strict) assert (out.module()(x=x, y=y) == tdm(x=x, y=y)).all() - def test_export_seq(self): + def test_export_seq(self, strict): torch._dynamo.reset_code_caches() tdm = Seq( Mod(lambda x, y: x * y, in_keys=["x", "y"], out_keys=["z"]), @@ -791,9 +792,47 @@ def test_export_seq(self): ) x = torch.randn(3) y = torch.randn(3) - out = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y}) + out = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y}, strict=strict) 
torch.testing.assert_close(out.module()(x=x, y=y), tdm(x=x, y=y)) + @pytest.mark.parametrize( + "same_shape,dynamic_shape", [[True, True], [True, False], [False, True]] + ) + def test_td_output(self, strict, same_shape, dynamic_shape): + # This will only work when the tensordict is pytree-able + class Test(torch.nn.Module): + def forward(self, x: torch.Tensor, y: torch.Tensor): + return TensorDict( + { + "x": x, + "y": y, + }, + batch_size=x.shape[0], + ) + + test = Test() + if same_shape: + x, y = torch.zeros(5, 100), torch.zeros(5, 100) + else: + x, y = torch.zeros(2, 100), torch.zeros(2, 100) + if dynamic_shape: + kwargs = { + "dynamic_shapes": { + "x": {0: torch.export.Dim("batch"), 1: torch.export.Dim("time")}, + "y": {0: torch.export.Dim("batch"), 1: torch.export.Dim("time")}, + } + } + else: + kwargs = {} + + result = torch.export.export(test, args=(x, y), strict=False, **kwargs) + export_mod = result.module() + x_new, y_new = torch.zeros(5, 100), torch.zeros(5, 100) + export_test = export_mod(x_new, y_new) + eager_test = test(x_new, y_new) + assert eager_test.batch_size == export_test.batch_size + assert (export_test == eager_test).all() + @pytest.mark.skipif(not _has_onnx, reason="ONNX is not available") class TestONNXExport: