You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
In real-world scenarios the set of user features changes frequently, so the forward function must take a list of tensors as its input.
However, when I pass a list input, `torch_tensorrt.dynamo.compile(exp_program, [inputs], min_block_size=2)` raises an error.
To Reproduce
import torch
import torch_tensorrt
from typing import Optional, Sequence,Dict,List
from torch.nn import functional as F
from tzrec.modules.mlp import MLP
from torch import nn
@torch.fx.wrap
def _get_dict(
    grouped_features_keys: List[str], args: List[torch.Tensor]
) -> Dict[str, torch.Tensor]:
    """Pair each feature key with the tensor at the same position.

    Args:
        grouped_features_keys: Names to assign, in positional order.
        args: Tensors, one per key, in the same order as the keys.

    Returns:
        A dict mapping each key to its positionally matching tensor.

    Raises:
        ValueError: If the number of keys differs from the number of tensors.
    """
    if len(grouped_features_keys) != len(args):
        raise ValueError(
            "The number of grouped_features_keys must match "
            "the number of arguments."
        )
    # Lengths were checked above, so zip() cannot silently drop entries.
    return dict(zip(grouped_features_keys, args))
@torch.fx.wrap
def _arange(end: int, device: torch.device) -> torch.Tensor:
    """Build the index tensor ``[0, 1, ..., end - 1]`` on ``device``.

    Kept as a standalone wrapped helper so FX symbolic tracing records a
    single call to it rather than tracing into ``torch.arange``.
    """
    positions = torch.arange(0, end, device=device)
    return positions
class MatMul2(torch.nn.Module):
    """Minimal repro module whose ``forward`` takes a list of tensors.

    The list is unpacked by position into ``query``, ``sequence`` and
    ``sequence_length`` (the order given by ``self.keys``).
    """

    def __init__(self):
        super().__init__()
        # Names assigned, in order, to the tensors passed to forward().
        self.keys = ["query", "sequence", "sequence_length"]
        mlp_cfg = {
            'hidden_units': [256, 64],
            'dropout_ratio': [],
            'activation': 'nn.ReLU',
            'use_bn': False,
        }
        # NOTE(review): in_features = 41 * 4 presumably matches the disabled
        # four-way concat in predict() with 41-wide features -- TODO confirm.
        self.mlp = MLP(in_features=41 * 4, **mlp_cfg)
        self.linear = nn.Linear(self.mlp.hidden_units[-1], 1)
        # NOTE(review): self.mlp and self.linear are not used by predict().

    def forward(self, args1: List[torch.Tensor]):
        """Run the module by delegating to ``predict``."""
        # use predict to avoid trace error in self._output_to_prediction(y)
        return self.predict(args1)

    def predict(self, args: List[torch.Tensor]):
        """Broadcast ``query`` across the sequence dimension and return it."""
        features = _get_dict(self.keys, args)
        query = features["query"]
        sequence = features["sequence"]
        sequence_length = features["sequence_length"]

        seq_len = sequence.size(1)

        # Row of position indices, shape (1, seq_len).
        positions = _arange(seq_len, device=sequence_length.device)
        positions = positions.unsqueeze(0)
        # NOTE(review): the mask is computed but never used or returned.
        # Assumes sequence_length is 1-D so unsqueeze(1) yields a column.
        sequence_mask = positions < sequence_length.unsqueeze(1)

        # 2-D query (N, D) -> (N, seq_len, D) by repeating along a new dim 1.
        queries = query.unsqueeze(1).expand(-1, seq_len, -1)

        # Disabled in this repro:
        # torch.cat([queries, sequence, queries - sequence, queries * sequence], dim=-1)
        return queries
# Build the repro module in eval mode on the GPU.
model = MatMul2().eval().cuda()
# Example inputs, in the order MatMul2.keys expects: query, sequence, sequence_length.
a1=torch.randn(2, 41).cuda()
b1=torch.randn(2, 50,41).cuda()
# NOTE(review): randn yields random (possibly negative) floats rather than integer
# lengths; presumably only the shapes matter for reproducing the compile error.
c1=torch.randn(2).cuda()
inputs=[a1,b1,c1]
# forward() takes ONE positional argument (the whole list), hence the args tuple (inputs,).
exp_program = torch.export.export(model, (inputs,))
# # ERROR
# Here `[inputs]` wraps the tensor list in a second list. Compilation fails with
# "Expected input at *args[0][0] to be a tensor, but got <class 'torch_tensorrt._Input.Input'>".
# NOTE(review): the reporter later avoided this by exporting with dynamic_shapes=...
# and passing `inputs` (not `[inputs]`) to compile. A maintainer believed the
# list-input case is resolved on the latest Torch-TensorRT codebase.
trt_gm = torch_tensorrt.dynamo.compile(exp_program, [inputs],min_block_size=2)
# # Run inference
# print(trt_gm(*inputs))
ERROR
Traceback (most recent call last):
File "/larec/tzrec/tests/test_2.py", line 64, in <module>
trt_gm = torch_tensorrt.dynamo.compile(exp_program, [inputs],min_block_size=2)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/dynamo/_compiler.py", line 230, in compile
trt_gm = compile_module(gm, inputs, settings)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch_tensorrt/dynamo/_compiler.py", line 427, in compile_module
sample_outputs = gm(
^^^
File "/opt/conda/lib/python3.11/site-packages/torch/fx/graph_module.py", line 738, in call_wrapped
return self._wrapped_call(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/fx/graph_module.py", line 316, in __call__
raise e
File "/opt/conda/lib/python3.11/site-packages/torch/fx/graph_module.py", line 303, in __call__
return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1582, in _call_impl
args_kwargs_result = hook(self, args, kwargs) # type: ignore[misc]
^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 600, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/export/_unlift.py", line 33, in _check_input_constraints_pre_hook
return _check_input_constraints_for_graph(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/_export/utils.py", line 86, in _check_input_constraints_for_graph
raise RuntimeError(
RuntimeError: Expected input at *args[0][0] to be a tensor, but got <class 'torch_tensorrt._Input.Input'>
yjjinjie
changed the title
🐛 [Bug] Encountered bug when using Torch-TensorRT
🐛 [Bug] Encountered bug when using Torch-TensorRT--list inputs
Sep 4, 2024
@yjjinjie Can you try with our latest codebase? I believe this issue should be resolved. I have tried the following script, which works. I was unable to run your script because of its dependency on `from tzrec.modules.mlp import MLP`.
Yes, I have solved it by exporting with `exp_program = torch.export.export(model, (inputs,), dynamic_shapes=dynamic_shapes)` and then running:
print(exp_program.graph)
trt_gm = torch_tensorrt.dynamo.compile(exp_program, inputs,min_block_size=1,allow_shape_tensors=True,assume_dynamic_shape_support=True)
Bug Description
In real-world scenarios the set of user features changes frequently, so the forward function must take a list of tensors as its input.
However, when I pass a list input, `torch_tensorrt.dynamo.compile(exp_program, [inputs], min_block_size=2)` raises an error.
To Reproduce
ERROR
Environment
The text was updated successfully, but these errors were encountered: