🐛 [Bug] Unable to compile RoBERTa #3335

Closed
umarbutler opened this issue Dec 20, 2024 · 6 comments
Labels: bug (Something isn't working)

Comments


umarbutler commented Dec 20, 2024

Bug Description

When I try compiling roberta-base, I get this error:

---------------------------------------------------------------------------
TorchRuntimeError                         Traceback (most recent call last)
Cell In[2], line 17
     14 input_ids = torch.stack([torch.tensor(input) for input in input_ids])
     15 attention_mask = torch.ones_like(input_ids)
---> 17 model = torch_tensorrt.compile(model, inputs = (input_ids, attention_mask))

File ~/dev/.venv/lib/python3.12/site-packages/torch_tensorrt/_compile.py:266, in compile(module, ir, inputs, arg_inputs, kwarg_inputs, enabled_precisions, **kwargs)
    263 torchtrt_arg_inputs = prepare_inputs(arg_inputs)
    264 torchtrt_kwarg_inputs = prepare_inputs(kwarg_inputs)
--> 266 exp_program = dynamo_trace(
    267     module, torchtrt_arg_inputs, kwarg_inputs=torchtrt_kwarg_inputs, **kwargs
    268 )
    269 trt_graph_module = dynamo_compile(
    270     exp_program,
    271     arg_inputs=torchtrt_arg_inputs,
    272     enabled_precisions=enabled_precisions_set,
    273     **kwargs,
    274 )
    275 return trt_graph_module

File ~/dev/.venv/lib/python3.12/site-packages/torch_tensorrt/dynamo/_tracer.py:83, in trace(mod, inputs, arg_inputs, kwarg_inputs, **kwargs)
     81 dynamic_shapes = get_dynamic_shapes_args(mod, arg_inputs)
     82 dynamic_shapes.update(get_dynamic_shapes_kwargs(kwarg_inputs))
---> 83 exp_program = export(
     84     mod,
     85     tuple(torch_arg_inputs),
     86     kwargs=torch_kwarg_inputs,
     87     dynamic_shapes=dynamic_shapes,
     88 )
     90 return exp_program

File ~/dev/.venv/lib/python3.12/site-packages/torch/export/__init__.py:270, in export(mod, args, kwargs, dynamic_shapes, strict, preserve_module_call_signature)
    264 if isinstance(mod, torch.jit.ScriptModule):
    265     raise ValueError(
    266         "Exporting a ScriptModule is not supported. "
    267         "Maybe try converting your ScriptModule to an ExportedProgram "
    268         "using `TS2EPConverter(mod, args, kwargs).convert()` instead."
    269     )
--> 270 return _export(
    271     mod,
    272     args,
    273     kwargs,
    274     dynamic_shapes,
    275     strict=strict,
    276     preserve_module_call_signature=preserve_module_call_signature,
    277     pre_dispatch=True,
    278 )

File ~/dev/.venv/lib/python3.12/site-packages/torch/export/_trace.py:1017, in _log_export_wrapper.<locals>.wrapper(*args, **kwargs)
   1010     else:
   1011         log_export_usage(
   1012             event="export.error.unclassified",
   1013             type=error_type,
   1014             message=str(e),
   1015             flags=_EXPORT_FLAGS,
   1016         )
-> 1017     raise e
   1018 finally:
   1019     _EXPORT_FLAGS = None

File ~/dev/.venv/lib/python3.12/site-packages/torch/export/_trace.py:990, in _log_export_wrapper.<locals>.wrapper(*args, **kwargs)
    988 try:
    989     start = time.time()
--> 990     ep = fn(*args, **kwargs)
    991     end = time.time()
    992     log_export_usage(
    993         event="export.time",
    994         metrics=end - start,
    995         flags=_EXPORT_FLAGS,
    996         **get_ep_stats(ep),
    997     )

File ~/dev/.venv/lib/python3.12/site-packages/torch/export/exported_program.py:114, in _disable_prexisiting_fake_mode.<locals>.wrapper(*args, **kwargs)
    111 @functools.wraps(fn)
    112 def wrapper(*args, **kwargs):
    113     with unset_fake_temporarily():
--> 114         return fn(*args, **kwargs)

File ~/dev/.venv/lib/python3.12/site-packages/torch/export/_trace.py:1880, in _export(mod, args, kwargs, dynamic_shapes, strict, preserve_module_call_signature, pre_dispatch, allow_complex_guards_as_runtime_asserts, _is_torch_jit_trace)
   1877 # Call the appropriate export function based on the strictness of tracing.
   1878 export_func = _strict_export if strict else _non_strict_export
-> 1880 export_artifact = export_func(  # type: ignore[operator]
   1881     mod,
   1882     args,
   1883     kwargs,
   1884     dynamic_shapes,
   1885     preserve_module_call_signature,
   1886     pre_dispatch,
   1887     original_state_dict,
   1888     original_in_spec,
   1889     allow_complex_guards_as_runtime_asserts,
   1890     _is_torch_jit_trace,
   1891 )
   1892 export_graph_signature: ExportGraphSignature = export_artifact.aten.sig
   1894 forward_arg_names = (
   1895     _get_forward_arg_names(mod, args, kwargs) if not _is_torch_jit_trace else None
   1896 )

File ~/dev/.venv/lib/python3.12/site-packages/torch/export/_trace.py:1224, in _strict_export(mod, args, kwargs, dynamic_shapes, preserve_module_call_signature, pre_dispatch, original_state_dict, orig_in_spec, allow_complex_guards_as_runtime_asserts, _is_torch_jit_trace)
   1211 def _strict_export(
   1212     mod: torch.nn.Module,
   1213     args: Tuple[Any, ...],
   (...)
   1221     _is_torch_jit_trace: bool,
   1222 ) -> ExportArtifact:
   1223     lower_to_aten = functools.partial(_export_to_aten_ir, pre_dispatch=pre_dispatch)
-> 1224     return _strict_export_lower_to_aten_ir(
   1225         mod=mod,
   1226         args=args,
   1227         kwargs=kwargs,
   1228         dynamic_shapes=dynamic_shapes,
   1229         preserve_module_call_signature=preserve_module_call_signature,
   1230         pre_dispatch=pre_dispatch,
   1231         original_state_dict=original_state_dict,
   1232         orig_in_spec=orig_in_spec,
   1233         allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
   1234         _is_torch_jit_trace=_is_torch_jit_trace,
   1235         lower_to_aten_callback=lower_to_aten,
   1236     )

File ~/dev/.venv/lib/python3.12/site-packages/torch/export/_trace.py:1252, in _strict_export_lower_to_aten_ir(mod, args, kwargs, dynamic_shapes, preserve_module_call_signature, pre_dispatch, original_state_dict, orig_in_spec, allow_complex_guards_as_runtime_asserts, _is_torch_jit_trace, lower_to_aten_callback)
   1239 def _strict_export_lower_to_aten_ir(
   1240     mod: torch.nn.Module,
   1241     args: Tuple[Any, ...],
   (...)
   1250     lower_to_aten_callback: Callable,
   1251 ) -> ExportArtifact:
-> 1252     gm_torch_level = _export_to_torch_ir(
   1253         mod,
   1254         args,
   1255         kwargs,
   1256         dynamic_shapes,
   1257         preserve_module_call_signature=preserve_module_call_signature,
   1258         restore_fqn=False,  # don't need to restore because we will do it later
   1259         allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
   1260         _log_export_usage=False,
   1261     )
   1263     # We detect the fake_mode by looking at gm_torch_level's placeholders, this is the fake_mode created in dynamo.
   1264     (
   1265         fake_args,
   1266         fake_kwargs,
   1267         dynamo_fake_mode,
   1268     ) = _extract_fake_inputs(gm_torch_level, args, kwargs)

File ~/dev/.venv/lib/python3.12/site-packages/torch/export/_trace.py:560, in _export_to_torch_ir(f, args, kwargs, dynamic_shapes, preserve_module_call_signature, disable_constraint_solver, allow_complex_guards_as_runtime_asserts, restore_fqn, _log_export_usage, same_signature)
    556     module_call_specs: Dict[str, Dict[str, pytree.TreeSpec]] = {}
    557     with _wrap_submodules(
    558         f, preserve_module_call_signature, module_call_specs
    559     ), _ignore_backend_decomps():
--> 560         gm_torch_level, _ = torch._dynamo.export(
    561             f,
    562             dynamic_shapes=transformed_dynamic_shapes,  # type: ignore[arg-type]
    563             tracing_mode="symbolic",
    564             disable_constraint_solver=disable_constraint_solver,
    565             # currently the following 2 flags are tied together for export purposes,
    566             # but untangle for sake of dynamo export api
    567             prefer_deferred_runtime_asserts_over_guards=True,
    568             allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
    569             _log_export_usage=_log_export_usage,
    570             same_signature=same_signature,
    571         )(
    572             *args,
    573             **kwargs,
    574         )
    575 except (ConstraintViolationError, ValueRangeError) as e:
    576     raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e))  # noqa: B904

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py:1432, in export.<locals>.inner(*args, **kwargs)
   1430 # TODO(voz): We may have instances of `f` that mutate inputs, we should track sideeffects and reject.
   1431 try:
-> 1432     result_traced = opt_f(*args, **kwargs)
   1433 except ConstraintViolationError as e:
   1434     constraint_violation_error = e

File ~/dev/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/dev/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py:465, in _TorchDynamoContext.__call__.<locals>._fn(*args, **kwargs)
    460 saved_dynamic_layer_stack_depth = (
    461     torch._C._functorch.get_dynamic_layer_stack_depth()
    462 )
    464 try:
--> 465     return fn(*args, **kwargs)
    466 finally:
    467     # Restore the dynamic layer stack depth if necessary.
    468     torch._C._functorch.pop_dynamic_layer_stack_and_undo_to_depth(
    469         saved_dynamic_layer_stack_depth
    470     )

File ~/dev/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/dev/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:1269, in CatchErrorsWrapper.__call__(self, frame, cache_entry, frame_state)
   1263             return hijacked_callback(
   1264                 frame, cache_entry, self.hooks, frame_state
   1265             )
   1267 with compile_lock, _disable_current_modes():
   1268     # skip=1: skip this frame
-> 1269     return self._torchdynamo_orig_callable(
   1270         frame, cache_entry, self.hooks, frame_state, skip=1
   1271     )

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:526, in ConvertFrameAssert.__call__(self, frame, cache_entry, hooks, frame_state, skip)
    510 compile_id = CompileId(frame_id, frame_compile_id)
    512 signpost_event(
    513     "dynamo",
    514     "_convert_frame_assert._compile",
   (...)
    523     },
    524 )
--> 526 return _compile(
    527     frame.f_code,
    528     frame.f_globals,
    529     frame.f_locals,
    530     frame.f_builtins,
    531     self._torchdynamo_orig_callable,
    532     self._one_graph,
    533     self._export,
    534     self._export_constraints,
    535     hooks,
    536     cache_entry,
    537     cache_size,
    538     frame,
    539     frame_state=frame_state,
    540     compile_id=compile_id,
    541     skip=skip + 1,
    542 )

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:924, in _compile(code, globals, locals, builtins, compiler_fn, one_graph, export, export_constraints, hooks, cache_entry, cache_size, frame, frame_state, compile_id, skip)
    922 guarded_code = None
    923 try:
--> 924     guarded_code = compile_inner(code, one_graph, hooks, transform)
    925     return guarded_code
    926 except Exception as e:

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:666, in _compile.<locals>.compile_inner(code, one_graph, hooks, transform)
    664 with dynamo_timed("_compile.compile_inner", phase_name="entire_frame_compile"):
    665     with CompileTimeInstructionCounter.record():
--> 666         return _compile_inner(code, one_graph, hooks, transform)

File ~/dev/.venv/lib/python3.12/site-packages/torch/_utils_internal.py:87, in compile_time_strobelight_meta.<locals>.compile_time_strobelight_meta_inner.<locals>.wrapper_function(*args, **kwargs)
     84     kwargs["skip"] = kwargs["skip"] + 1
     86 if not StrobelightCompileTimeProfiler.enabled:
---> 87     return function(*args, **kwargs)
     89 return StrobelightCompileTimeProfiler.profile_compile_time(
     90     function, phase_name, *args, **kwargs
     91 )

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:699, in _compile.<locals>._compile_inner(code, one_graph, hooks, transform)
    697 CompileContext.get().attempt = attempt
    698 try:
--> 699     out_code = transform_code_object(code, transform)
    700     break
    701 except exc.RestartAnalysis as e:

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/bytecode_transformation.py:1322, in transform_code_object(code, transformations, safe)
   1319 instructions = cleaned_instructions(code, safe)
   1320 propagate_line_nums(instructions)
-> 1322 transformations(instructions, code_options)
   1323 return clean_and_assemble_instructions(instructions, keys, code_options)[1]

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:219, in preserve_global_state.<locals>._fn(*args, **kwargs)
    215 exit_stack.enter_context(
    216     torch.fx._symbolic_trace._maybe_revert_all_patches()
    217 )
    218 try:
--> 219     return fn(*args, **kwargs)
    220 finally:
    221     cleanup.close()

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:634, in _compile.<locals>.transform(instructions, code_options)
    632 try:
    633     with tracing(tracer.output.tracing_context), tracer.set_current_tx():
--> 634         tracer.run()
    635 except exc.UnspecializeRestartAnalysis:
    636     speculation_log.clear()

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:2796, in InstructionTranslator.run(self)
   2795 def run(self):
-> 2796     super().run()

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:983, in InstructionTranslatorBase.run(self)
    981 try:
    982     self.output.push_tx(self)
--> 983     while self.step():
    984         pass
    985 except BackendCompilerFailed:

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:895, in InstructionTranslatorBase.step(self)
    892 self.update_block_stack(inst)
    894 try:
--> 895     self.dispatch_table[inst.opcode](self, inst)
    896     return not self.output.should_exit
    897 except exc.ObservedException as e:

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:582, in break_graph_if_unsupported.<locals>.decorator.<locals>.wrapper(self, inst)
    580     return handle_graph_break(self, inst, speculation.reason)
    581 try:
--> 582     return inner_fn(self, inst)
    583 except Unsupported as excp:
    584     if self.generic_context_manager_depth > 0:
    585         # We don't support graph break under GenericContextWrappingVariable,
    586         # If there is, we roll back to the checkpoint and fall back.

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:2279, in InstructionTranslatorBase.CALL(self, inst)
   2277 @break_graph_if_unsupported(push=1)
   2278 def CALL(self, inst):
-> 2279     self._call(inst)

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:2273, in InstructionTranslatorBase._call(self, inst, call_kw)
   2268     kwargs = {}
   2270 try:
   2271     # if call_function fails, need to set kw_names to None, otherwise
   2272     # a subsequent call may have self.kw_names set to an old value
-> 2273     self.call_function(fn, args, kwargs)
   2274 finally:
   2275     self.kw_names = None

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:830, in InstructionTranslatorBase.call_function(self, fn, args, kwargs)
    828 if inner_fn and callable(inner_fn) and is_forbidden(inner_fn):
    829     raise AssertionError(f"Attempt to trace forbidden callable {inner_fn}")
--> 830 self.push(fn.call_function(self, args, kwargs))

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/nn_module.py:442, in NNModuleVariable.call_function(self, tx, args, kwargs)
    440 else:
    441     assert istype(fn, types.FunctionType)
--> 442 return tx.inline_user_function_return(
    443     variables.UserFunctionVariable(fn, source=fn_source),
    444     args,
    445     kwargs,
    446 )

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:836, in InstructionTranslatorBase.inline_user_function_return(self, fn, args, kwargs)
    832 def inline_user_function_return(self, fn, args, kwargs):
    833     """
    834     A call to some user defined function by inlining it.
    835     """
--> 836     return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:3011, in InliningInstructionTranslator.inline_call(cls, parent, func, args, kwargs)
   3008 @classmethod
   3009 def inline_call(cls, parent, func, args, kwargs):
   3010     with patch.dict(counters, {"unimplemented": counters["inline_call"]}):
-> 3011         return cls.inline_call_(parent, func, args, kwargs)

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:3139, in InliningInstructionTranslator.inline_call_(parent, func, args, kwargs)
   3137 try:
   3138     with strict_ctx:
-> 3139         tracer.run()
   3140 except exc.ObservedException as e:
   3141     msg = f"Observed exception DURING INLING {code} : {e}"

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:983, in InstructionTranslatorBase.run(self)
    981 try:
    982     self.output.push_tx(self)
--> 983     while self.step():
    984         pass
    985 except BackendCompilerFailed:

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:895, in InstructionTranslatorBase.step(self)
    892 self.update_block_stack(inst)
    894 try:
--> 895     self.dispatch_table[inst.opcode](self, inst)
    896     return not self.output.should_exit
    897 except exc.ObservedException as e:

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:582, in break_graph_if_unsupported.<locals>.decorator.<locals>.wrapper(self, inst)
    580     return handle_graph_break(self, inst, speculation.reason)
    581 try:
--> 582     return inner_fn(self, inst)
    583 except Unsupported as excp:
    584     if self.generic_context_manager_depth > 0:
    585         # We don't support graph break under GenericContextWrappingVariable,
    586         # If there is, we roll back to the checkpoint and fall back.

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:1680, in InstructionTranslatorBase.CALL_FUNCTION_EX(self, inst)
   1678 # Map to a dictionary of str -> VariableTracker
   1679 kwargsvars = kwargsvars.keys_as_python_constant()
-> 1680 self.call_function(fn, argsvars.items, kwargsvars)

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:830, in InstructionTranslatorBase.call_function(self, fn, args, kwargs)
    828 if inner_fn and callable(inner_fn) and is_forbidden(inner_fn):
    829     raise AssertionError(f"Attempt to trace forbidden callable {inner_fn}")
--> 830 self.push(fn.call_function(self, args, kwargs))

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py:385, in UserMethodVariable.call_function(self, tx, args, kwargs)
    383     fn = getattr(self.obj.value, self.fn.__name__)
    384     return invoke_and_store_as_constant(tx, fn, self.get_name(), args, kwargs)
--> 385 return super().call_function(tx, args, kwargs)

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py:324, in UserFunctionVariable.call_function(self, tx, args, kwargs)
    319 if self.is_constant:
    320     return invoke_and_store_as_constant(
    321         tx, self.fn, self.get_name(), args, kwargs
    322     )
--> 324 return super().call_function(tx, args, kwargs)

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py:111, in BaseUserFunctionVariable.call_function(self, tx, args, kwargs)
    105 def call_function(
    106     self,
    107     tx: "InstructionTranslator",
    108     args: "List[VariableTracker]",
    109     kwargs: "Dict[str, VariableTracker]",
    110 ) -> "VariableTracker":
--> 111     return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:836, in InstructionTranslatorBase.inline_user_function_return(self, fn, args, kwargs)
    832 def inline_user_function_return(self, fn, args, kwargs):
    833     """
    834     A call to some user defined function by inlining it.
    835     """
--> 836     return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:3011, in InliningInstructionTranslator.inline_call(cls, parent, func, args, kwargs)
   3008 @classmethod
   3009 def inline_call(cls, parent, func, args, kwargs):
   3010     with patch.dict(counters, {"unimplemented": counters["inline_call"]}):
-> 3011         return cls.inline_call_(parent, func, args, kwargs)

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:3139, in InliningInstructionTranslator.inline_call_(parent, func, args, kwargs)
   3137 try:
   3138     with strict_ctx:
-> 3139         tracer.run()
   3140 except exc.ObservedException as e:
   3141     msg = f"Observed exception DURING INLING {code} : {e}"

    [... skipping similar frames: InstructionTranslatorBase.run at line 983 (1 times), InstructionTranslatorBase.step at line 895 (1 times), break_graph_if_unsupported.<locals>.decorator.<locals>.wrapper at line 582 (1 times)]

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:2279, in InstructionTranslatorBase.CALL(self, inst)
   2277 @break_graph_if_unsupported(push=1)
   2278 def CALL(self, inst):
-> 2279     self._call(inst)

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:2273, in InstructionTranslatorBase._call(self, inst, call_kw)
   2268     kwargs = {}
   2270 try:
   2271     # if call_function fails, need to set kw_names to None, otherwise
   2272     # a subsequent call may have self.kw_names set to an old value
-> 2273     self.call_function(fn, args, kwargs)
   2274 finally:
   2275     self.kw_names = None

    [... skipping similar frames: InstructionTranslatorBase.call_function at line 830 (1 times)]

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/nn_module.py:442, in NNModuleVariable.call_function(self, tx, args, kwargs)
    440 else:
    441     assert istype(fn, types.FunctionType)
--> 442 return tx.inline_user_function_return(
    443     variables.UserFunctionVariable(fn, source=fn_source),
    444     args,
    445     kwargs,
    446 )

    [... skipping similar frames: InliningInstructionTranslator.inline_call at line 3011 (1 times), InliningInstructionTranslator.inline_call_ at line 3139 (1 times), InstructionTranslatorBase.inline_user_function_return at line 836 (1 times), InstructionTranslatorBase.run at line 983 (1 times), InstructionTranslatorBase.step at line 895 (1 times), break_graph_if_unsupported.<locals>.decorator.<locals>.wrapper at line 582 (1 times)]

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:1680, in InstructionTranslatorBase.CALL_FUNCTION_EX(self, inst)
   1678 # Map to a dictionary of str -> VariableTracker
   1679 kwargsvars = kwargsvars.keys_as_python_constant()
-> 1680 self.call_function(fn, argsvars.items, kwargsvars)

    [... skipping similar frames: InstructionTranslatorBase.call_function at line 830 (1 times)]

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py:385, in UserMethodVariable.call_function(self, tx, args, kwargs)
    383     fn = getattr(self.obj.value, self.fn.__name__)
    384     return invoke_and_store_as_constant(tx, fn, self.get_name(), args, kwargs)
--> 385 return super().call_function(tx, args, kwargs)

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py:324, in UserFunctionVariable.call_function(self, tx, args, kwargs)
    319 if self.is_constant:
    320     return invoke_and_store_as_constant(
    321         tx, self.fn, self.get_name(), args, kwargs
    322     )
--> 324 return super().call_function(tx, args, kwargs)

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py:111, in BaseUserFunctionVariable.call_function(self, tx, args, kwargs)
    105 def call_function(
    106     self,
    107     tx: "InstructionTranslator",
    108     args: "List[VariableTracker]",
    109     kwargs: "Dict[str, VariableTracker]",
    110 ) -> "VariableTracker":
--> 111     return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)

    [... skipping similar frames: InstructionTranslatorBase.call_function at line 830 (5 times), InliningInstructionTranslator.inline_call at line 3011 (5 times), InliningInstructionTranslator.inline_call_ at line 3139 (5 times), InstructionTranslatorBase.inline_user_function_return at line 836 (5 times), InstructionTranslatorBase.run at line 983 (5 times), InstructionTranslatorBase.step at line 895 (5 times), break_graph_if_unsupported.<locals>.decorator.<locals>.wrapper at line 582 (5 times), InstructionTranslatorBase.CALL at line 2279 (3 times), InstructionTranslatorBase._call at line 2273 (3 times), InstructionTranslatorBase.CALL_FUNCTION_EX at line 1680 (2 times), NNModuleVariable.call_function at line 442 (2 times), UserMethodVariable.call_function at line 385 (2 times), UserFunctionVariable.call_function at line 324 (2 times), BaseUserFunctionVariable.call_function at line 111 (2 times)]

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/nn_module.py:442, in NNModuleVariable.call_function(self, tx, args, kwargs)
    440 else:
    441     assert istype(fn, types.FunctionType)
--> 442 return tx.inline_user_function_return(
    443     variables.UserFunctionVariable(fn, source=fn_source),
    444     args,
    445     kwargs,
    446 )

    [... skipping similar frames: InliningInstructionTranslator.inline_call at line 3011 (1 times), InliningInstructionTranslator.inline_call_ at line 3139 (1 times), InstructionTranslatorBase.inline_user_function_return at line 836 (1 times), InstructionTranslatorBase.run at line 983 (1 times), InstructionTranslatorBase.step at line 895 (1 times), break_graph_if_unsupported.<locals>.decorator.<locals>.wrapper at line 582 (1 times)]

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:1680, in InstructionTranslatorBase.CALL_FUNCTION_EX(self, inst)
   1678 # Map to a dictionary of str -> VariableTracker
   1679 kwargsvars = kwargsvars.keys_as_python_constant()
-> 1680 self.call_function(fn, argsvars.items, kwargsvars)

    [... skipping similar frames: InstructionTranslatorBase.call_function at line 830 (1 times)]

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py:385, in UserMethodVariable.call_function(self, tx, args, kwargs)
    383     fn = getattr(self.obj.value, self.fn.__name__)
    384     return invoke_and_store_as_constant(tx, fn, self.get_name(), args, kwargs)
--> 385 return super().call_function(tx, args, kwargs)

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py:324, in UserFunctionVariable.call_function(self, tx, args, kwargs)
    319 if self.is_constant:
    320     return invoke_and_store_as_constant(
    321         tx, self.fn, self.get_name(), args, kwargs
    322     )
--> 324 return super().call_function(tx, args, kwargs)

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py:111, in BaseUserFunctionVariable.call_function(self, tx, args, kwargs)
    105 def call_function(
    106     self,
    107     tx: "InstructionTranslator",
    108     args: "List[VariableTracker]",
    109     kwargs: "Dict[str, VariableTracker]",
    110 ) -> "VariableTracker":
--> 111     return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:836, in InstructionTranslatorBase.inline_user_function_return(self, fn, args, kwargs)
    832 def inline_user_function_return(self, fn, args, kwargs):
    833     """
    834     A call to some user defined function by inlining it.
    835     """
--> 836     return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:3011, in InliningInstructionTranslator.inline_call(cls, parent, func, args, kwargs)
   3008 @classmethod
   3009 def inline_call(cls, parent, func, args, kwargs):
   3010     with patch.dict(counters, {"unimplemented": counters["inline_call"]}):
-> 3011         return cls.inline_call_(parent, func, args, kwargs)

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:3139, in InliningInstructionTranslator.inline_call_(parent, func, args, kwargs)
   3137 try:
   3138     with strict_ctx:
-> 3139         tracer.run()
   3140 except exc.ObservedException as e:
   3141     msg = f"Observed exception DURING INLING {code} : {e}"

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:983, in InstructionTranslatorBase.run(self)
    981 try:
    982     self.output.push_tx(self)
--> 983     while self.step():
    984         pass
    985 except BackendCompilerFailed:

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:895, in InstructionTranslatorBase.step(self)
    892 self.update_block_stack(inst)
    894 try:
--> 895     self.dispatch_table[inst.opcode](self, inst)
    896     return not self.output.should_exit
    897 except exc.ObservedException as e:

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:582, in break_graph_if_unsupported.<locals>.decorator.<locals>.wrapper(self, inst)
    580     return handle_graph_break(self, inst, speculation.reason)
    581 try:
--> 582     return inner_fn(self, inst)
    583 except Unsupported as excp:
    584     if self.generic_context_manager_depth > 0:
    585         # We don't support graph break under GenericContextWrappingVariable,
    586         # If there is, we roll back to the checkpoint and fall back.

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:2279, in InstructionTranslatorBase.CALL(self, inst)
   2277 @break_graph_if_unsupported(push=1)
   2278 def CALL(self, inst):
-> 2279     self._call(inst)

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:2273, in InstructionTranslatorBase._call(self, inst, call_kw)
   2268     kwargs = {}
   2270 try:
   2271     # if call_function fails, need to set kw_names to None, otherwise
   2272     # a subsequent call may have self.kw_names set to an old value
-> 2273     self.call_function(fn, args, kwargs)
   2274 finally:
   2275     self.kw_names = None

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:830, in InstructionTranslatorBase.call_function(self, fn, args, kwargs)
    828 if inner_fn and callable(inner_fn) and is_forbidden(inner_fn):
    829     raise AssertionError(f"Attempt to trace forbidden callable {inner_fn}")
--> 830 self.push(fn.call_function(self, args, kwargs))

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/torch.py:897, in TorchInGraphFunctionVariable.call_function(self, tx, args, kwargs)
    888             if "out" in kwargs and isinstance(kwargs["out"], variables.TensorVariable):
    889                 # Calling fake tensor propagation can mutate the out= tensor in
    890                 # tx.output.tracked_fakes. tracked_fakes are used to apply
   (...)
    893                 # guards. So save the shape now, and check later if it has
    894                 # changed. If it has, graph break.
    895                 fake_out_shape = kwargs["out"].proxy.node.meta["example_value"].shape
--> 897             tensor_variable = wrap_fx_proxy(
    898                 tx=tx,
    899                 proxy=tx.output.create_proxy(
    900                     "call_function",
    901                     fn_,
    902                     *proxy_args_kwargs(args, kwargs),
    903                 ),
    904             )
    906             if (
    907                 isinstance(tensor_variable, TensorVariable)
    908                 and "requires_grad" in kwargs
    909                 and kwargs["requires_grad"].as_python_constant()
    910             ):
    911                 unimplemented(
    912                     """factory functions that return tensors that require grad are not supported.
    913 Either create the tensor outside the compiled region, or do not set the tensor to require_grad"""
    914                 )

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/builder.py:2037, in wrap_fx_proxy(tx, proxy, example_value, subclass_type, **options)
   2029 kwargs = {
   2030     "tx": tx,
   2031     "proxy": proxy,
   (...)
   2034     **options,
   2035 }
   2036 if subclass_type is None:
-> 2037     return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
   2038 else:
   2039     result = wrap_fx_proxy_cls(target_cls=TensorWithTFOverrideVariable, **kwargs)

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/builder.py:2124, in wrap_fx_proxy_cls(target_cls, tx, proxy, example_value, subclass_type, **options)
   2119 with torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing():
   2120     # with preserve_rng_state():
   2121     if example_value is None:
   2122         # only allow_non_graph_fake in this instance because we handle the non-fake
   2123         # cases properly below.
-> 2124         example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
   2126     # Handle recursive calls here
   2127     elif maybe_get_fake_mode(example_value) is tx.fake_mode:

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/utils.py:2082, in get_fake_value(node, tx, allow_non_graph_fake)
   2079     elif isinstance(cause, TypeError) and "argument" in str(cause):
   2080         unimplemented(f"TypeError {node.target}: {cause}")
-> 2082     raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
   2084 if not allow_non_graph_fake:
   2085     _ = pytree.tree_map_only(
   2086         torch.Tensor, functools.partial(ensure_graph_fake, tx=tx), ret_val
   2087     )

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/utils.py:2017, in get_fake_value(node, tx, allow_non_graph_fake)
   2015 try:
   2016     with tx.fake_mode, enable_python_dispatcher():
-> 2017         ret_val = wrap_fake_exception(
   2018             lambda: run_node(tx.output, node, args, kwargs, nnmodule)
   2019         )
   2020 except Unsupported:
   2021     raise

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/utils.py:1574, in wrap_fake_exception(fn)
   1572 def wrap_fake_exception(fn):
   1573     try:
-> 1574         return fn()
   1575     except UnsupportedFakeTensorException as e:
   1576         from .exc import unimplemented

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/utils.py:2018, in get_fake_value.<locals>.<lambda>()
   2015 try:
   2016     with tx.fake_mode, enable_python_dispatcher():
   2017         ret_val = wrap_fake_exception(
-> 2018             lambda: run_node(tx.output, node, args, kwargs, nnmodule)
   2019         )
   2020 except Unsupported:
   2021     raise

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/utils.py:2150, in run_node(tracer, node, args, kwargs, nnmodule)
   2148         unimplemented(make_error_message(e), from_exc=e)
   2149     except Exception as e:
-> 2150         raise RuntimeError(make_error_message(e)).with_traceback(
   2151             e.__traceback__
   2152         ) from e
   2154 raise AssertionError(op)

File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/utils.py:2132, in run_node(tracer, node, args, kwargs, nnmodule)
   2130 try:
   2131     if op == "call_function":
-> 2132         return node.target(*args, **kwargs)
   2133     elif op == "call_method":
   2134         return getattr(args[0], node.target)(*args[1:], **kwargs)

File ~/dev/.venv/lib/python3.12/site-packages/torch/utils/_stats.py:21, in count.<locals>.wrapper(*args, **kwargs)
     19     simple_call_counter[fn.__qualname__] = 0
     20 simple_call_counter[fn.__qualname__] = simple_call_counter[fn.__qualname__] + 1
---> 21 return fn(*args, **kwargs)

File ~/dev/.venv/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py:1238, in FakeTensorMode.__torch_dispatch__(self, func, types, args, kwargs)
   1234 assert (
   1235     torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE) is None
   1236 ), func
   1237 try:
-> 1238     return self.dispatch(func, types, args, kwargs)
   1239 except TypeError:
   1240     log.exception("fake tensor raised TypeError")

File ~/dev/.venv/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py:1692, in FakeTensorMode.dispatch(self, func, types, args, kwargs)
   1689         return func(*args, **kwargs)
   1691 if self.cache_enabled:
-> 1692     return self._cached_dispatch_impl(func, types, args, kwargs)
   1693 else:
   1694     return self._dispatch_impl(func, types, args, kwargs)

File ~/dev/.venv/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py:1339, in FakeTensorMode._cached_dispatch_impl(self, func, types, args, kwargs)
   1337 else:
   1338     self._validate_cache_key(func, args, kwargs)
-> 1339     output = self._dispatch_impl(func, types, args, kwargs)
   1340     entry = self._make_cache_entry(state, key, func, args, kwargs, output)
   1341     key.strip_shape_env()

File ~/dev/.venv/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py:2021, in FakeTensorMode._dispatch_impl(self, func, types, args, kwargs)
   2017     log.exception("failed while attempting to run meta for %s", func)
   2018     raise
   2020 return maybe_propagate_real_tensors(
-> 2021     self.wrap_meta_outputs_with_default_device_logic(
   2022         r, func, flat_args, device=kwargs.get("device")
   2023     )
   2024 )

File ~/dev/.venv/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py:2143, in FakeTensorMode.wrap_meta_outputs_with_default_device_logic(self, r, func, flat_args, device)
   2140     else:
   2141         return e
-> 2143 return tree_map(wrap, r)

File ~/dev/.venv/lib/python3.12/site-packages/torch/utils/_pytree.py:964, in tree_map(func, tree, is_leaf, *rests)
    962 leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
    963 flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
--> 964 return treespec.unflatten(map(func, *flat_args))

File ~/dev/.venv/lib/python3.12/site-packages/torch/utils/_pytree.py:803, in TreeSpec.unflatten(self, leaves)
    801 def unflatten(self, leaves: Iterable[Any]) -> PyTree:
    802     if not isinstance(leaves, (list, tuple)):
--> 803         leaves = list(leaves)
    804     if len(leaves) != self.num_leaves:
    805         raise ValueError(
    806             f"treespec.unflatten(leaves): `leaves` has length {len(leaves)} "
    807             f"but the spec refers to a pytree that holds {self.num_leaves} "
    808             f"items ({self}).",
    809         )

File ~/dev/.venv/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py:2121, in FakeTensorMode.wrap_meta_outputs_with_default_device_logic.<locals>.wrap(e)
   2115     return e
   2117 if common_device is None:
   2118     (
   2119         common_device,
   2120         has_scalar_only_inputs,
-> 2121     ) = FakeTensor._find_common_device(func, flat_args)
   2123 is_our_fake = self.is_our_fake(e)
   2124 if is_our_fake:

File ~/dev/.venv/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py:872, in FakeTensor._find_common_device(func, flat_args)
    867     raise RuntimeError(
    868         f"Unhandled FakeTensor Device Propagation for {func}, found two different devices {common_device}, {t.device}"
    869     )
    871 for arg in flat_args:
--> 872     merge_devices(arg)
    874 # some functions that allow Python numbers to bind to Tensors
    875 # if we have failed to find a device, and we're running one of these operators,
    876 # we must have scalar only inputs
    877 if should_allow_numbers_as_tensors(func) and common_device is None:
    878     # ops with scalar only inputs always have result on cpu

File ~/dev/.venv/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py:867, in FakeTensor._find_common_device.<locals>.merge_devices(t)
    863     return
    865 # mismatching devices of non-zero dim tensors, throw
    866 # This might be valid behavior and need to be explicitly modeled, e.g. reshape_as
--> 867 raise RuntimeError(
    868     f"Unhandled FakeTensor Device Propagation for {func}, found two different devices {common_device}, {t.device}"
    869 )

TorchRuntimeError: Failed running call_function <built-in function scaled_dot_product_attention>(*(FakeTensor(..., size=(128, 12, 4, 64), grad_fn=<PermuteBackward0>), FakeTensor(..., size=(128, 12, 4, 64), grad_fn=<PermuteBackward0>), FakeTensor(..., size=(128, 12, 4, 64), grad_fn=<PermuteBackward0>)), **{'attn_mask': FakeTensor(..., device='cuda:0', size=(128, 1, 4, 4)), 'dropout_p': 0.0, 'is_causal': False}):
Unhandled FakeTensor Device Propagation for aten._scaled_dot_product_flash_attention_for_cpu.default, found two different devices cpu, cuda:0

from user code:
   File "/home/dev/.venv/lib/python3.12/site-packages/transformers/models/roberta/modeling_roberta.py", line 1318, in forward
    outputs = self.roberta(
  File "/home/dev/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/dev/.venv/lib/python3.12/site-packages/transformers/models/roberta/modeling_roberta.py", line 976, in forward
    encoder_outputs = self.encoder(
  File "/home/dev/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/dev/.venv/lib/python3.12/site-packages/transformers/models/roberta/modeling_roberta.py", line 631, in forward
    layer_outputs = layer_module(
  File "/home/dev/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/dev/.venv/lib/python3.12/site-packages/transformers/models/roberta/modeling_roberta.py", line 520, in forward
    self_attention_outputs = self.attention(
  File "/home/dev/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/dev/.venv/lib/python3.12/site-packages/transformers/models/roberta/modeling_roberta.py", line 447, in forward
    self_outputs = self.self(
  File "/home/dev/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/dev/.venv/lib/python3.12/site-packages/transformers/models/roberta/modeling_roberta.py", line 370, in forward
    attn_output = torch.nn.functional.scaled_dot_product_attention(

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

To Reproduce

Run:

import torch
import torch_tensorrt

from transformers import AutoTokenizer, AutoModelForSequenceClassification

# BEGIN CONFIG #
MODEL_DIR = f'roberta-base'
# END CONFIG #

model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR, attn_implementation = 'sdpa')
model = model.eval()
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
input_ids = [tokenizer.encode('Hello world')] * 128
input_ids = torch.stack([torch.tensor(input) for input in input_ids])
attention_mask = torch.ones_like(input_ids)

model = torch_tensorrt.compile(model, inputs = (input_ids, attention_mask))

Expected behavior

The compilation works.

Environment

WSL 2, Torch-TensorRT version 2.5.0, PyTorch version 2.5.1, CUDA 12.4, Python 3.12.5

umarbutler added the bug label on Dec 20, 2024
HolyWu (Contributor) commented Dec 23, 2024

You didn't move the model parameters and input tensors to the CUDA device, hence the compilation failure.
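
For reference, a minimal sketch of that fix (my illustration rather than code from the thread), reusing the variables from the repro script above:

# Put the model and the example inputs on the same CUDA device before compiling.
model = model.to('cuda')
input_ids = input_ids.to('cuda')
attention_mask = attention_mask.to('cuda')
model = torch_tensorrt.compile(model, inputs=(input_ids, attention_mask))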

umarbutler (Author) commented:

I get the same error with:

import torch
import torch_tensorrt

from transformers import AutoTokenizer, AutoModelForSequenceClassification

# BEGIN CONFIG #
MODEL_DIR = f'roberta-base'
# END CONFIG #

model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR, attn_implementation = 'sdpa')
model = model.to('cuda')
model = model.eval()
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
input_ids = [tokenizer.encode('Hello world')] * 128
input_ids = torch.stack([torch.tensor(input) for input in input_ids]).to('cuda')
attention_mask = torch.ones_like(input_ids).to('cuda')

model = torch_tensorrt.compile(model, inputs = (input_ids, attention_mask))

Here is the traceback:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[12], line 18
     15 input_ids = torch.stack([torch.tensor(input) for input in input_ids]).to('cuda')
     16 attention_mask = torch.ones_like(input_ids).to('cuda')
---> 18 model = torch_tensorrt.compile(model, inputs = (input_ids, attention_mask))

File ~/dev/.venv/lib/python3.12/site-packages/torch_tensorrt/_compile.py:269, in compile(module, ir, inputs, arg_inputs, kwarg_inputs, enabled_precisions, **kwargs)
    264     torchtrt_kwarg_inputs = prepare_inputs(kwarg_inputs)
    266     exp_program = dynamo_trace(
    267         module, torchtrt_arg_inputs, kwarg_inputs=torchtrt_kwarg_inputs, **kwargs
    268     )
--> 269     trt_graph_module = dynamo_compile(
    270         exp_program,
    271         arg_inputs=torchtrt_arg_inputs,
    272         enabled_precisions=enabled_precisions_set,
    273         **kwargs,
    274     )
    275     return trt_graph_module
    276 elif target_ir == _IRType.torch_compile:

File ~/dev/.venv/lib/python3.12/site-packages/torch_tensorrt/dynamo/_compiler.py:288, in compile(exported_program, inputs, arg_inputs, kwarg_inputs, device, disable_tf32, assume_dynamic_shape_support, sparse_weights, enabled_precisions, engine_capability, make_refittable, debug, num_avg_timing_iters, workspace_size, dla_sram_size, dla_local_dram_size, dla_global_dram_size, truncate_double, require_full_compilation, min_block_size, torch_executed_ops, torch_executed_modules, pass_through_build_failures, max_aux_streams, version_compatible, optimization_level, use_python_runtime, use_fast_partitioner, enable_experimental_decompositions, dryrun, hardware_compatible, timing_cache_path, lazy_engine_init, cache_built_engines, reuse_cached_engines, engine_cache_dir, engine_cache_size, custom_engine_cache, **kwargs)
    286 settings = CompilationSettings(**compilation_options)
    287 logger.info("Compilation Settings: %s\n", settings)
--> 288 trt_gm = compile_module(
    289     gm, trt_arg_inputs, trt_kwarg_inputs, settings, engine_cache
    290 )
    291 return trt_gm

File ~/dev/.venv/lib/python3.12/site-packages/torch_tensorrt/dynamo/_compiler.py:464, in compile_module(gm, sample_arg_inputs, sample_kwarg_inputs, settings, engine_cache)
    462     # Create TRT engines from submodule
    463     if not settings.dryrun:
--> 464         trt_module = convert_module(
    465             submodule,
    466             submodule_inputs,
    467             settings=settings,
    468             name=name,
    469             engine_cache=engine_cache,
    470         )
    472         trt_modules[name] = trt_module
    474 # Parse the graph I/O and store it in dryrun tracker

File ~/dev/.venv/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/_conversion.py:142, in convert_module(module, inputs, settings, name, engine_cache)
    125 def convert_module(
    126     module: torch.fx.GraphModule,
    127     inputs: Sequence[Input],
   (...)
    130     engine_cache: Optional[BaseEngineCache] = None,
    131 ) -> PythonTorchTensorRTModule | TorchTensorRTModule:
    132     """Convert an FX module to a TRT module
    133     Args:
    134         module: FX GraphModule to convert
   (...)
    140         PythonTorchTensorRTModule or TorchTensorRTModule
    141     """
--> 142     interpreter_result = interpret_module_to_result(
    143         module, inputs, settings, engine_cache=engine_cache
    144     )
    146     rt_cls = PythonTorchTensorRTModule
    148     if ENABLED_FEATURES.torch_tensorrt_runtime and not settings.use_python_runtime:

File ~/dev/.venv/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/_conversion.py:121, in interpret_module_to_result(module, inputs, settings, arg_inputs, kwarg_inputs, engine_cache)
    105     output_dtypes = infer_module_output_dtypes(
    106         module,
    107         inputs,
    108         settings.device,
    109         truncate_double=settings.truncate_double,
    110     )
    112 interpreter = TRTInterpreter(
    113     module,
    114     inputs,
   (...)
    118     engine_cache=engine_cache,
    119 )
--> 121 interpreter_result = interpreter.run()
    122 return interpreter_result

File ~/dev/.venv/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py:616, in TRTInterpreter.run(self, strict_type_constraints, algorithm_selector, tactic_sources)
    607                 engine_str = engine_bytes.getvalue()
    609             return TRTInterpreterResult(
    610                 engine_str,
    611                 self._input_names,
    612                 self._output_names,
    613                 self.weight_name_map,
    614             )
--> 616 self._construct_trt_network_def()
    618 if self.compilation_settings.make_refittable:
    619     self._save_weight_mapping()

File ~/dev/.venv/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py:347, in TRTInterpreter._construct_trt_network_def(self)
    345 self.input_specs_iter = 0
    346 run_module_start_time = datetime.now()
--> 347 super().run()
    348 _LOGGER.info(
    349     f"TRT INetwork construction elapsed time: {datetime.now() - run_module_start_time}"
    350 )

File ~/dev/.venv/lib/python3.12/site-packages/torch/fx/interpreter.py:146, in Interpreter.run(self, initial_env, enable_io_processing, *args)
    143     continue
    145 try:
--> 146     self.env[node] = self.run_node(node)
    147 except Exception as e:
    148     if self.extra_traceback:

File ~/dev/.venv/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py:682, in TRTInterpreter.run_node(self, n)
    677 if _LOGGER.isEnabledFor(logging.DEBUG):
    678     _LOGGER.debug(
    679         f"Converting node {self._cur_node_name} (kind: {n.target}, args: {TRTInterpreter._args_str(n.args)})"
    680     )
--> 682 trt_node: torch.fx.Node = super().run_node(n)
    684 if n.op == "get_attr":
    685     self.const_mapping[str(n)] = (tuple(trt_node.shape), str(trt_node.dtype))

File ~/dev/.venv/lib/python3.12/site-packages/torch/fx/interpreter.py:203, in Interpreter.run_node(self, n)
    201 assert isinstance(args, tuple)
    202 assert isinstance(kwargs, dict)
--> 203 return getattr(self, n.op)(n.target, args, kwargs)

File ~/dev/.venv/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py:791, in TRTInterpreter.call_function(self, target, args, kwargs)
    789     return converter(self.ctx.net, target, args, kwargs, self._cur_node_name)
    790 else:
--> 791     return converter(self.ctx, target, args, kwargs, self._cur_node_name)

File ~/dev/.venv/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/converter_utils.py:526, in enforce_tensor_types.<locals>.wrapper.<locals>.convert_with_type_enforcement(ctx, target, args, kwargs, name)
    523     elif isinstance(index, str):
    524         new_kwargs[index] = new_value
--> 526 return func(ctx, target, new_args, new_kwargs, name)

File ~/dev/.venv/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/aten_ops_converters.py:956, in aten_ops_cumsum(ctx, target, args, kwargs, name)
    939 @dynamo_tensorrt_converter(
    940     torch.ops.aten.cumsum.default,
    941     capability_validator=refit_validator,
   (...)
    954     name: str,
    955 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
--> 956     return impl.slice.cumsum(
    957         ctx,
    958         target,
    959         SourceIR.ATEN,
    960         name,
    961         args[0],
    962         args[1],
    963     )

File ~/dev/.venv/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/impl/slice/ops.py:374, in cumsum(ctx, target, source_ir, name, input, dim)
    372     new_dims = tuple(data.shape)
    373     zeros = np.zeros(new_dims)
--> 374     zero_trttensor = get_trt_tensor(ctx, zeros, f"{name}_initial_value")
    376 running_sum = loop.add_recurrence(zero_trttensor)
    377 set_layer_name(running_sum, target, f"{name}_running_sum", source_ir)

File ~/dev/.venv/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/converter_utils.py:385, in get_trt_tensor(ctx, input_val, name, dtype, min_rank)
    382         input_val = input_val.astype(np.float32)
    384 if isinstance(input_val, (torch.Tensor, np.ndarray, int, float, bool)):
--> 385     return create_constant(ctx, input_val, name, dtype, min_rank)
    386 elif isinstance(input_val, TRTTensor):
    387     return input_val

File ~/dev/.venv/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/converter_utils.py:346, in create_constant(ctx, value, name, dtype, min_rank)
    344     shape = trt.Dims()
    345 numpy_value = to_numpy(value, dtype)
--> 346 constant = ctx.net.add_constant(
    347     shape if isinstance(value, (int, float, bool)) else value.shape,
    348     numpy_value.copy() if isinstance(numpy_value, np.ndarray) else numpy_value,
    349 )
    350 constant.name = name
    351 return constant.get_output(0)

TypeError: add_constant(): incompatible function arguments. The following argument types are supported:
    1. (self: tensorrt_bindings.tensorrt.INetworkDefinition, shape: tensorrt_bindings.tensorrt.Dims, weights: tensorrt_bindings.tensorrt.Weights) -> tensorrt_bindings.tensorrt.IConstantLayer

Invoked with: <tensorrt_bindings.tensorrt.INetworkDefinition object at 0x7ff446691fb0>, (128,), array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0.])

While executing %cumsum : [num_users=1] = call_function[target=torch.ops.aten.cumsum.default](args = (%_to_copy, 1), kwargs = {_itensor_to_tensor_meta: {<tensorrt_bindings.tensorrt.ITensor object at 0x7ff4398dc730>: ((128, 4), torch.int64, False, (4, 1), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7ff43940d270>: ((128, 4), torch.bool, False, (4, 1), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x7ff4398dcc70>: ((128, 4), torch.int32, False, (4, 1), torch.contiguous_format, False, {})}})
Original traceback:
  File "/home/umar/dev/.venv/lib/python3.12/site-packages/transformers/models/roberta/modeling_roberta.py", line 1318, in forward
    outputs = self.roberta(
  File "/home/umar/dev/.venv/lib/python3.12/site-packages/transformers/models/roberta/modeling_roberta.py", line 912, in forward
    embedding_output = self.embeddings(
  File "/home/umar/dev/.venv/lib/python3.12/site-packages/transformers/models/roberta/modeling_roberta.py", line 99, in forward
    position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
  File "/home/umar/dev/.venv/lib/python3.12/site-packages/transformers/models/roberta/modeling_roberta.py", line 1681, in create_position_ids_from_input_ids
    incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask

HolyWu (Contributor) commented Dec 23, 2024

The error is actually different now, and that bug has been fixed in #3258. You'll need to install the nightly versions of torch/torchvision/torch_tensorrt for the moment.

umarbutler (Author) commented:

@HolyWu Upon installing the nightly version of torch_tensorrt, I now get:

---------------------------------------------------------------------------
OSError                                   Traceback (most recent call last)
Cell In[1], line 11
      9 import safetensors
     10 import transformers
---> 11 import torch_tensorrt
     12 import safetensors.torch
     14 from tqdm import tqdm

File ~/dev/.venv/lib/python3.12/site-packages/torch_tensorrt/__init__.py:114
    110         assert ENABLED_FEATURES.torch_tensorrt_runtime
    111         torch.ops.load_library(linked_file_runtime_full_path)
--> 114 _register_with_torch()
    116 from torch_tensorrt._Device import Device  # noqa: F401
    117 from torch_tensorrt._enums import (  # noqa: F401
    118     DeviceType,
    119     EngineCapability,
   (...)
    122     memory_format,
    123 )

File ~/dev/.venv/lib/python3.12/site-packages/torch_tensorrt/__init__.py:107, in _register_with_torch()
    105     assert ENABLED_FEATURES.torchscript_frontend
    106     assert ENABLED_FEATURES.torch_tensorrt_runtime
--> 107     torch.ops.load_library(linked_file_full_path)
    109 elif os.path.isfile(linked_file_runtime_full_path):
    110     assert ENABLED_FEATURES.torch_tensorrt_runtime

File ~/dev/.venv/lib/python3.12/site-packages/torch/_ops.py:1356, in _Ops.load_library(self, path)
   1351 path = _utils_internal.resolve_library_path(path)
   1352 with dl_open_guard():
   1353     # Import the shared library into the process, thus running its
   1354     # static (global) initialization code in order to register custom
   1355     # operators with the JIT.
-> 1356     ctypes.CDLL(path)
   1357 self.loaded_libraries.add(path)

File ~/.pyenv/versions/3.12.5/lib/python3.12/ctypes/__init__.py:379, in CDLL.__init__(self, name, mode, handle, use_errno, use_last_error, winmode)
    376 self._FuncPtr = _FuncPtr
    378 if handle is None:
--> 379     self._handle = _dlopen(self._name, mode)
    380 else:
    381     self._handle = handle

OSError: /home/umar/dev/.venv/lib/python3.12/site-packages/torch_tensorrt/lib/libtorchtrt.so: undefined symbol: _ZN3c106detail23torchInternalAssertFailEPKcS2_jS2_RKSs

umarbutler (Author) commented:

It looks like the nightly version also broke my transformers, flash-attention and probably some other packages. I fixed it by force-reinstalling torch and torch-tensorrt (pip install --force-reinstall).
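
A quick sanity check after reinstalling (my own suggestion, not from the thread): the undefined-symbol error above typically means the torch and torch_tensorrt wheels were built against different libtorch binaries, so it is worth confirming that both report builds from the same release channel.

import torch
import torch_tensorrt  # this import itself fails while the binaries are still mismatched

print(torch.__version__, torch.version.cuda)  # the torch build and the CUDA toolkit it was built for
print(torch_tensorrt.__version__)             # should come from the matching stable or nightly channel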

umarbutler (Author) commented Dec 24, 2024

I solved it by manually editing the two files that were modified by the patch. Much easier than installing the nightly 😆

Surprisingly, the model at first appeared slower when compiled, but it is actually much faster (it seems %timeit was not giving me accurate timings).

EDIT: TensorRT does still seem slower than the uncompiled model under bfloat16 AMP.
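
A minimal timing sketch for such comparisons (my own illustration with a hypothetical benchmark helper; assumes model and inputs are already on CUDA): CUDA kernels launch asynchronously, so a bare %timeit of the forward call can largely measure launch overhead, whereas synchronizing around CUDA events gives wall-clock numbers that are comparable between the compiled and uncompiled models.

import torch

def benchmark(model, input_ids, attention_mask, iters=50):
    with torch.inference_mode():
        # Warm-up: the first calls pay for lazy initialization and engine setup.
        for _ in range(5):
            model(input_ids, attention_mask)
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(iters):
            model(input_ids, attention_mask)
        end.record()
        torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # average milliseconds per forward pass

# Example: compare the uncompiled and compiled models; for the bfloat16 AMP case,
# wrap the calls in torch.autocast('cuda', dtype=torch.bfloat16).
# print(benchmark(model, input_ids, attention_mask))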
