Hi, the above complains because the copy validator in the converter does not accept uint8 as a valid input data type. That is because TRT does not support uint8 in its operations. For example, if you run the code below (note that some API names have changed), which exports the model to ONNX and then loads it in TRT:
import torch
import torch_tensorrt
import tensorrt as trt
from torch import nn
class dummy_t(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor):
        y = x.clamp_(0, 1).mul_(255).to(dtype=torch.uint8)
        return torch.mul(y, 1)
xs = [torch.randn((1,3,5,7)).cuda()]
exported = torch.export.export(
    dummy_t().cuda(),
    args=tuple(xs)
)
from tempfile import NamedTemporaryFile
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule
import io
import onnx
output_names = ['output0']
input_names = ["x"]
with NamedTemporaryFile() as f:
    onnx_program = torch.onnx.export(
        dummy_t().cuda(),
        tuple(xs),
        f.name,
        verbose=False,
        opset_version=20,
        do_constant_folding=True,  # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False
        # https://github.com/pytorch/pytorch/issues/73843
        input_names=input_names,
        output_names=output_names,
        dynamo=False,
        training=torch.onnx.TrainingMode.EVAL  # we can export trainable model!
    )
    model_onnx: onnx.ModelProto
    model_onnx = onnx.load(f.name)
workspace = 10 * 1024**2
trt_logger = trt.Logger(trt.Logger.INFO)
trt_logger.min_severity = trt.Logger.Severity.VERBOSE
builder = trt.Builder(trt_logger)
config = builder.create_builder_config()
config.set_memory_pool_limit(pool=trt.MemoryPoolType.WORKSPACE, pool_size=workspace)
flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
network = builder.create_network(flag)
parser = trt.OnnxParser(network, trt_logger)
if not parser.parse(model_onnx.SerializeToString()):
    raise RuntimeError(f'failed to load ONNX model')
inputs = [network.get_input(i) for i in range(network.num_inputs)]
outputs = [network.get_output(i) for i in range(network.num_outputs)]
with builder.build_serialized_network(network, config) as engine, io.BytesIO() as engine_bytes:  # type: ignore
    engine_bytes.write(engine)
    engine_bytes.seek(0)
    serialized_trt_engine = engine_bytes.read()
pt_trt = PythonTorchTensorRTModule(
    serialized_trt_engine,
    input_binding_names=input_names,
    output_binding_names=["output0"]
)
print(pt_trt(*xs).dtype)
The above will fail with
[11/14/2024-20:39:08] [TRT] [V] Static check for parsing node: /Mul_1 [Mul]
Traceback (most recent call last):
  File "/code/torchTRT/TensorRT/issue_3247.py", line 62, in <module>
    raise RuntimeError(f'failed to load ONNX model')
RuntimeError: failed to load ONNX model
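The traceback only carries the generic message raised by the script, so the parser's own reason for rejecting the node is not visible. Below is a minimal sketch of how those details could be surfaced, assuming the parser object exposes `num_errors` and `get_error(index)` and that each error has a `desc()` method, as `trt.OnnxParser` does. The helper names `collect_parser_errors` and `parse_or_raise` are hypothetical and not part of TensorRT or Torch-TensorRT.

```python
def collect_parser_errors(parser):
    """Return the description of each error recorded by an ONNX parser.

    `parser` is expected to behave like `trt.OnnxParser`: it exposes
    `num_errors` and `get_error(index)`, and each error has `desc()`.
    """
    return [parser.get_error(i).desc() for i in range(parser.num_errors)]


def parse_or_raise(parser, model_bytes):
    """Parse serialized ONNX bytes; on failure, raise with the parser's messages."""
    if not parser.parse(model_bytes):
        # Fall back to a placeholder if the parser recorded no errors.
        details = "\n".join(collect_parser_errors(parser)) or "no details reported"
        raise RuntimeError(f"failed to load ONNX model:\n{details}")
```

In the script above, `parse_or_raise(parser, model_onnx.SerializeToString())` could stand in for the `if not parser.parse(...)` check, so the exception text includes whatever the parser reported for the failing node.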
Bug Description
The converter reports that _to_copy is not supported when we use something like
x.to(torch.uint8)
However, TensorRT does appear to support this operation if we export the model to ONNX and then load the ONNX model in TensorRT.
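For context on how a cast can end up reported as unsupported, here is a rough sketch of a dtype-based validator. It is illustrative only: dtypes are modelled as plain strings, the allowed set is an assumption rather than the list the real converter uses, and `ALLOWED_COPY_DTYPES`, `validate_to_copy`, and `partition` are hypothetical names.

```python
# Illustrative only: this allowed set is an assumption, not the list used by
# the real converter, and dtypes are plain strings rather than torch dtypes.
ALLOWED_COPY_DTYPES = {"float32", "float16", "int32", "bool"}


def validate_to_copy(node_kwargs):
    """Return True if a `_to_copy`-style node casts to an accepted dtype."""
    target = node_kwargs.get("dtype")
    # A copy without a dtype change needs no cast, so this sketch accepts it.
    return target is None or target in ALLOWED_COPY_DTYPES


def partition(nodes):
    """Split (name, kwargs) pairs into accepted nodes and rejected nodes."""
    supported, fallback = [], []
    for name, kwargs in nodes:
        (supported if validate_to_copy(kwargs) else fallback).append(name)
    return supported, fallback
```

Under these assumptions, a node casting to `"uint8"` fails validation and lands in the rejected list, which is the kind of situation in which an operator is reported as not supported.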
To Reproduce
Steps to reproduce the behavior:
Expected behavior
Environment
Install method (conda, pip, libtorch, source): pip
Additional context