Remove linear lowering pass and converter (#3323)
HolyWu authored Dec 17, 2024
1 parent 283a983 commit d7071ba
Showing 7 changed files with 0 additions and 362 deletions.
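Note: with the dedicated aten.linear converter and the lower_linear pass removed, models containing nn.Linear are presumably handled by the standard decomposition path (the removed TestLowerLinear below treated aten.permute and aten.addmm as the decomposed form) together with the remaining matmul/elementwise converters. A minimal sketch of checking that a linear model still compiles after this change, assuming a CUDA-capable setup with torch_tensorrt installed; the TinyLinear module and its shapes are illustrative and mirror the removed tests:

import torch
import torch_tensorrt


class TinyLinear(torch.nn.Module):
    """Illustrative module; shapes mirror those used in the removed tests."""

    def __init__(self) -> None:
        super().__init__()
        self.fc = torch.nn.Linear(32, 64)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)


model = TinyLinear().eval().cuda()
inputs = [torch.rand((3, 32)).cuda()]

# Same "torch_compile" front end and min_block_size as in the removed tests.
optimized_model = torch_tensorrt.compile(
    model,
    "torch_compile",
    inputs,
    min_block_size=1,
)

# Sanity-check numerical agreement against eager PyTorch.
with torch.no_grad():
    max_diff = torch.max(torch.abs(optimized_model(*inputs) - model(*inputs)))
print(f"max abs diff vs eager: {max_diff.item():.3e}")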
20 changes: 0 additions & 20 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2509,26 +2509,6 @@ def aten_ops_convolution(
    )


@dynamo_tensorrt_converter(torch.ops.aten.linear.default, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.linear, supports_dynamic_shapes=True)
def aten_ops_linear(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return impl.linear.linear(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        input=args[0],
        weight=args[1],
        bias=args_bounds_check(args, 2, None),
    )


@dynamo_tensorrt_converter(torch.ops.aten._cdist_forward.default)
def aten_ops_cdist_forward(
    ctx: ConversionContext,
1 change: 0 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/__init__.py
@@ -13,7 +13,6 @@
    embedding,
    full,
    grid,
    linear,
    matmul,
    normalization,
    pad,
54 changes: 0 additions & 54 deletions py/torch_tensorrt/dynamo/conversion/impl/linear.py

This file was deleted.

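The contents of the deleted impl/linear.py are not reproduced on this page. For reference, aten.linear computes out = input @ weight.T + bias, which is the behavior any replacement lowering path has to reproduce; a minimal illustration in plain PyTorch, with shapes chosen to match the removed tests (this is not the deleted TensorRT implementation):

import torch

inp = torch.rand(3, 32)
weight = torch.rand(64, 32)  # (out_features, in_features)
bias = torch.rand(64)

# aten.linear semantics: matmul against the transposed weight, then bias add.
manual = inp @ weight.t() + bias
builtin = torch.nn.functional.linear(inp, weight, bias)
assert torch.allclose(manual, builtin)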
2 changes: 0 additions & 2 deletions py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
@@ -7,7 +7,6 @@
from .accumulate_fp32_matmul import accumulate_fp32_matmul
from .constant_folding import constant_fold
from .fuse_prims_broadcast import fuse_prims_broadcast
from .lower_linear import lower_linear
from .pass_manager import DynamoPassManager
from .remove_assert_scalar import remove_assert_scalar
from .remove_detach import remove_detach
@@ -22,7 +21,6 @@
        remove_input_alias_fixing_clones,
        constant_fold,
        repair_input_as_output,
        lower_linear,
        fuse_prims_broadcast,
        replace_max_pool_with_indices,
        replace_full_like_with_full,
42 changes: 0 additions & 42 deletions py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py

This file was deleted.

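The deleted lower_linear pass is likewise not shown here. Judging from the removed TestLowerLinear below (which expected aten.linear and treated aten.permute/aten.addmm as unexpected), it rewrote the decomposed permute + addmm pattern back into a single aten.linear call. A rough sketch of that kind of torch.fx pattern rewrite, assuming 2D inputs; lower_linear_sketch and the exact pattern are illustrative assumptions, not the deleted implementation:

import torch
from torch.fx import GraphModule
from torch.fx.subgraph_rewriter import replace_pattern


def lower_linear_sketch(gm: GraphModule) -> GraphModule:
    # Decomposed form of a linear layer: addmm over the permuted weight.
    def pattern(inp, weight, bias):
        return torch.ops.aten.addmm.default(
            bias, inp, torch.ops.aten.permute.default(weight, [1, 0])
        )

    # Rewrite it back into a single aten.linear call.
    def replacement(inp, weight, bias):
        return torch.ops.aten.linear.default(inp, weight, bias)

    replace_pattern(gm, pattern, replacement)
    gm.graph.lint()
    gm.recompile()
    return gm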
131 changes: 0 additions & 131 deletions tests/py/dynamo/conversion/test_linear_aten.py

This file was deleted.

112 changes: 0 additions & 112 deletions tests/py/dynamo/lowering/test_aten_lowering_passes.py
@@ -158,118 +158,6 @@ def forward(self, x):
        torch._dynamo.reset()


class TestLowerLinear(TestCase):
    @unittest.skip(
        "This test has threshold failures. This is tracked at https://github.com/pytorch/TensorRT/issues/2715",
    )
    def test_lower_linear(self):
        class Linear(torch.nn.Module):
            def forward(self, input, weight, bias):
                out = torch.ops.aten.linear.default(input, weight, bias)
                return out

        inputs = [
            torch.rand((3, 32)).cuda(),
            torch.rand((64, 32)).cuda(),
            torch.rand((64,)).cuda(),
        ]

        fx_graph = torch.fx.symbolic_trace(Linear())
        expected_ops = {torch.ops.aten.linear.default}
        unexpected_ops = {
            torch.ops.aten.permute.default,
            torch.ops.aten.addmm.default,
        }

        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
            fx_graph,
            inputs,
            expected_ops=expected_ops,
            unexpected_ops=unexpected_ops,
            min_block_size=1,
        )

        self.assertEqual(
            len(unexpected_ops_seen),
            0,
            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
        )

        self.assertEqual(
            len(expected_ops_unseen),
            0,
            f"The following expected ops were not encountered: {expected_ops_unseen}",
        )
        torch._dynamo.reset()

        # Validate that the results between Torch and Torch-TRT are similar
        optimized_model = torch_tensorrt.compile(
            fx_graph,
            "torch_compile",
            inputs,
            min_block_size=1,
            pass_through_build_failures=True,
        )
        optimized_model_results = torch.cat(
            [tensor.detach().cpu() for tensor in optimized_model(*inputs)]
        )
        torch_model_results = torch.cat(
            [tensor.detach().cpu() for tensor in fx_graph(*inputs)]
        )

        max_diff = float(
            torch.max(torch.abs(optimized_model_results - torch_model_results))
        )

        self.assertAlmostEqual(
            max_diff,
            0,
            DECIMALS_OF_AGREEMENT,
            msg=f"Linear TRT outputs don't match with the original model.",
        )
        torch._dynamo.reset()

    def test_lower_linear_batch(self):
        class Linear(torch.nn.Module):
            def forward(self, input, weight, bias):
                out = torch.ops.aten.linear.default(input, weight, bias)
                return out

        inputs = [
            torch.rand((2, 2, 32)).cuda(),
            torch.rand((64, 32)).cuda(),
            torch.rand((64,)).cuda(),
        ]

        fx_graph = torch.fx.symbolic_trace(Linear())

        # Validate that the results between Torch and Torch-TRT are similar
        optimized_model = torch_tensorrt.compile(
            fx_graph,
            "torch_compile",
            inputs,
            min_block_size=1,
            pass_through_build_failures=True,
        )
        optimized_model_results = torch.cat(
            [tensor.detach().cpu() for tensor in optimized_model(*inputs)]
        )
        torch_model_results = torch.cat(
            [tensor.detach().cpu() for tensor in fx_graph(*inputs)]
        )

        max_diff = float(
            torch.max(torch.abs(optimized_model_results - torch_model_results))
        )
        self.assertAlmostEqual(
            max_diff,
            0,
            DECIMALS_OF_AGREEMENT,
            msg=f"Linear TRT outputs don't match with the original model.",
        )
        torch._dynamo.reset()


class TestLowerViewToReshape(TestCase):
    def test_view_to_reshape(self):
        class ViewToReshape(torch.nn.Module):
