diff --git a/docsrc/index.rst b/docsrc/index.rst
index 757acc2011..0bef2f0664 100644
--- a/docsrc/index.rst
+++ b/docsrc/index.rst
@@ -118,6 +118,8 @@ Tutorials
    tutorials/_rendered_examples/distributed_inference/data_parallel_gpt2
    tutorials/_rendered_examples/distributed_inference/data_parallel_stable_diffusion
    tutorials/_rendered_examples/dynamo/mutable_torchtrt_module_example
+   tutorials/_rendered_examples/dynamo/torch_export_gpt2
+   tutorials/_rendered_examples/dynamo/torch_export_llama2

 Python API Documentation
 ------------------------
diff --git a/examples/dynamo/README.rst b/examples/dynamo/README.rst
index ff3563cffe..6be2aa6515 100644
--- a/examples/dynamo/README.rst
+++ b/examples/dynamo/README.rst
@@ -1,15 +1,24 @@
 .. _torch_compile:

-Dynamo / ``torch.compile``
-----------------------------
+Torch-TensorRT Examples
+====================================

-Torch-TensorRT provides a backend for the new ``torch.compile`` API released in PyTorch 2.0. In the following examples we describe
-a number of ways you can leverage this backend to accelerate inference.
+Please refer to the following examples, which demonstrate the usage of different features of Torch-TensorRT. We also provide
+examples of Torch-TensorRT compilation of select computer vision and language models.

-* :ref:`torch_compile_resnet`: Compiling a ResNet model using the Torch Compile Frontend for ``torch_tensorrt.compile``
-* :ref:`torch_compile_transformer`: Compiling a Transformer model using ``torch.compile``
+Dependencies
+------------------------------------
+
+Please install the following external dependencies (assuming you already have ``torch_tensorrt`` installed):
+
+.. code-block:: bash
+
+   pip install -r requirements.txt
+
+
+Compiler Features
+------------------------------------
 * :ref:`torch_compile_advanced_usage`: Advanced usage including making a custom backend to use directly with the ``torch.compile`` API
-* :ref:`torch_compile_stable_diffusion`: Compiling a Stable Diffusion model using ``torch.compile``
 * :ref:`torch_export_cudagraphs`: Using the Cudagraphs integration with `ir="dynamo"`
 * :ref:`custom_kernel_plugins`: Creating a plugin to use a custom kernel inside TensorRT engines
 * :ref:`refit_engine_example`: Refitting a compiled TensorRT Graph Module with updated weights
@@ -17,3 +26,11 @@ a number of ways you can leverage this backend to accelerate inference.
 * :ref:`vgg16_fp8_ptq`: Compiling a VGG16 model with FP8 and PTQ using ``torch.compile``
 * :ref:`engine_caching_example`: Utilizing engine caching to speed up compilation times
 * :ref:`engine_caching_bert_example`: Demonstrating engine caching on BERT
+
+Model Zoo
+------------------------------------
+* :ref:`torch_compile_resnet`: Compiling a ResNet model using the Torch Compile Frontend for ``torch_tensorrt.compile``
+* :ref:`torch_compile_transformer`: Compiling a Transformer model using ``torch.compile``
+* :ref:`torch_compile_stable_diffusion`: Compiling a Stable Diffusion model using ``torch.compile``
+* :ref:`torch_export_gpt2`: Compiling a GPT2 model using the AOT workflow (`ir=dynamo`)
+* :ref:`torch_export_llama2`: Compiling a Llama2 model using the AOT workflow (`ir=dynamo`)
\ No newline at end of file
diff --git a/examples/dynamo/requirements.txt b/examples/dynamo/requirements.txt
index 6e53935186..41fe29f09c 100644
--- a/examples/dynamo/requirements.txt
+++ b/examples/dynamo/requirements.txt
@@ -1,4 +1,4 @@
 cupy==13.1.0
-torch>=2.4.0.dev20240503+cu121
-torch-tensorrt>=2.4.0.dev20240503+cu121
 triton==2.3.0
+diffusers==0.30.3
+transformers==4.44.2
diff --git a/examples/dynamo/torch_compile_gpt2.py b/examples/dynamo/torch_compile_gpt2.py
new file mode 100644
index 0000000000..6c6e1b03a2
--- /dev/null
+++ b/examples/dynamo/torch_compile_gpt2.py
@@ -0,0 +1,100 @@
+"""
+.. _torch_compile_gpt2:
+
+Compiling GPT2 using the Torch-TensorRT `torch.compile` Backend
+=================================================================
+
+This interactive script is intended as a sample of the Torch-TensorRT workflow with `torch.compile` on a GPT2 model."""
+
+# %%
+# Imports and Model Definition
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+import torch
+import torch_tensorrt
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+# %%
+
+# Define the parameters
+MAX_TOKENS = 32
+DEVICE = torch.device("cuda:0")
+
+# Define the GPT2 model from Hugging Face.
+# kv_cache is not supported in Torch-TRT currently, so use_cache=False is set.
+# The model is kept in eval mode and placed on the GPU.
+with torch.no_grad():
+    tokenizer = AutoTokenizer.from_pretrained("gpt2")
+    model = (
+        AutoModelForCausalLM.from_pretrained(
+            "gpt2",
+            pad_token_id=tokenizer.eos_token_id,
+            use_cache=False,
+            attn_implementation="eager",
+        )
+        .eval()
+        .cuda()
+    )
+
+# %%
+# Tokenize a sample input prompt and get pytorch model outputs
+prompt = "I enjoy walking with my cute dog"
+model_inputs = tokenizer(prompt, return_tensors="pt")
+input_ids = model_inputs["input_ids"].cuda()
+
+# Auto-regressive generation loop for greedy search using the PyTorch model.
+pyt_gen_tokens = model.generate(
+    input_ids,
+    max_length=MAX_TOKENS,
+    use_cache=False,
+    pad_token_id=tokenizer.eos_token_id,
+)
+
+# %%
+# Compilation with `torch.compile` using the tensorrt backend and generation of TensorRT outputs
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+# Compile the model and mark the input sequence length to be dynamic
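+# Marking dimension 1 (the sequence length) of ``input_ids`` as dynamic lets the
+# generated TensorRT engine use a dynamic shape profile, so the growing sequences
+# produced by the generation loop do not each force a fresh engine build. The
+# bounds below are an illustrative choice that stays within GPT2's 1024-token
+# context window.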
+torch._dynamo.mark_dynamic(input_ids, 1, min=2, max=1023)
+model.forward = torch.compile(
+    model.forward,
+    backend="tensorrt",
+    dynamic=None,
+    options={
+        "enabled_precisions": {torch.float32},
+        "disable_tf32": True,
+        "min_block_size": 1,
+        "debug": True,
+    },
+)
+
+# Auto-regressive generation loop for greedy decoding using the TensorRT model.
+# The first token generation compiles the model using TensorRT and the second token
+# encounters recompilation.
+trt_gen_tokens = model.generate(
+    inputs=input_ids,
+    max_length=MAX_TOKENS,
+    use_cache=False,
+    pad_token_id=tokenizer.eos_token_id,
+)
+
+# %%
+# Decode the output sentences of PyTorch and TensorRT
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+print("=============================")
+print(
+    "Pytorch model generated text: ",
+    tokenizer.decode(pyt_gen_tokens[0], skip_special_tokens=True),
+)
+print("=============================")
+print(
+    "TensorRT model generated text: ",
+    tokenizer.decode(trt_gen_tokens[0], skip_special_tokens=True),
+)
+
+# %%
+# The output sentences should look like
+
+# Pytorch model generated text: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll
+# =============================
+# TensorRT model generated text: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll
diff --git a/examples/dynamo/torch_compile_llama2.py b/examples/dynamo/torch_compile_llama2.py
new file mode 100644
index 0000000000..40ddc97d2c
--- /dev/null
+++ b/examples/dynamo/torch_compile_llama2.py
@@ -0,0 +1,89 @@
+"""
+.. _torch_compile_llama2:
+
+Compiling Llama2 using the Torch-TensorRT `torch.compile` Backend
+====================================================================
+
+This interactive script is intended as a sample of the Torch-TensorRT workflow with `torch.compile` on a Llama2 model."""
+
+# %%
+# Imports and Model Definition
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+import torch
+import torch_tensorrt
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from utils import generate
+
+# %%
+
+# Define the parameters
+MAX_TOKENS = 32
+DEVICE = torch.device("cuda:0")
+
+# Define the Llama2 model from Hugging Face.
+# kv_cache is not supported in Torch-TRT currently, so use_cache=False is set.
+# The model is loaded on the CPU so that GPU memory is reserved for TRT compilation.
+llama_path = "meta-llama/Llama-2-7b-chat-hf"
+with torch.no_grad():
+    model = AutoModelForCausalLM.from_pretrained(
+        llama_path, use_cache=False, attn_implementation="eager"
+    ).eval()
+
+tokenizer = AutoTokenizer.from_pretrained(llama_path)
+
+# %%
+# Tokenize a sample input prompt and get pytorch model outputs
+prompt = "I enjoy walking with my cute dog"
+model_inputs = tokenizer(prompt, return_tensors="pt")
+input_ids = model_inputs["input_ids"]
+
+# Auto-regressive generation loop for greedy search using the PyTorch model (on CPU).
+# We use a custom generate function which is very similar to the huggingface one.
+pyt_gen_tokens = generate(model, input_ids, MAX_TOKENS, tokenizer.eos_token_id)
+
+# %%
+# Compilation with `torch.compile` using the tensorrt backend and generation of TensorRT outputs
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+# Compile the model and mark the input sequence length to be dynamic
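+# Note: the min/max bounds passed to ``mark_dynamic`` below are assumed to bracket
+# every sequence length the compiled model will actually see; the tokenized prompt
+# above is longer than ``min=7``, and with ``MAX_TOKENS = 32`` generation stops far
+# below ``max=1023``.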
+torch._dynamo.mark_dynamic(input_ids, 1, min=7, max=1023)
+model.forward = torch.compile(
+    model.forward,
+    backend="tensorrt",
+    dynamic=None,
+    options={
+        "enabled_precisions": {torch.float32},
+        "disable_tf32": True,
+        "debug": True,
+    },
+)
+
+# Auto-regressive generation loop for greedy decoding using the TensorRT model.
+# We use a custom generate function which is very similar to the huggingface one.
+# Move inputs to GPU
+input_ids = input_ids.to(DEVICE)
+trt_gen_tokens = generate(model, input_ids, MAX_TOKENS, tokenizer.eos_token_id)
+
+# %%
+# Decode the output sentences of PyTorch and TensorRT
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+print("=============================")
+print(
+    "Pytorch model generated text: ",
+    tokenizer.decode(pyt_gen_tokens[0], skip_special_tokens=True),
+)
+print("=============================")
+print(
+    "TensorRT model generated text: ",
+    tokenizer.decode(trt_gen_tokens[0], skip_special_tokens=True),
+)
+
+# %%
+# The output sentences should look like
+#
+#
diff --git a/examples/dynamo/utils.py b/examples/dynamo/utils.py
index 25ad99c12d..90f1f3b72c 100644
--- a/examples/dynamo/utils.py
+++ b/examples/dynamo/utils.py
@@ -51,7 +51,14 @@ def generate(model, input_seq, max_tokens, eos_token_id):
     )

     while True:
-        outputs = model(input_seq)
+        outputs = model(
+            input_seq,
+            past_key_values=None,
+            position_ids=None,
+            attention_mask=None,
+            use_cache=False,
+            token_type_ids=None,
+        )
         logits = outputs.logits
         next_token_logits = logits[:, -1, :]
         next_tokens = torch.argmax(next_token_logits, dim=-1)
diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index 97aa2ec443..ecdb8821ba 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -288,6 +288,7 @@ def compile(
     trt_gm = compile_module(
         gm, trt_arg_inputs, trt_kwarg_inputs, settings, engine_cache
     )
+
     return trt_gm
diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py
index 605d963a50..02cc6242aa 100644
--- a/py/torch_tensorrt/dynamo/backend/backends.py
+++ b/py/torch_tensorrt/dynamo/backend/backends.py
@@ -80,7 +80,8 @@ def _pretraced_backend(
             repair_input_aliasing(gm)

             # Remove sym_int placeholders and inputs
-            remove_sym_nodes(gm)
+            remove_sym_nodes(gm, sample_inputs)
+
             torch_inputs = [
                 input for input in sample_inputs if isinstance(input, torch.Tensor)
             ]
@@ -91,7 +92,7 @@ def _pretraced_backend(
             # Invoke AOTAutograd to translate operators to aten
             gm = aot_export_joint_simple(
                 gm,
-                torch_inputs,
+                sample_inputs,
                 trace_joint=False,
                 decompositions=get_decompositions(
                     settings.enable_experimental_decompositions
diff --git a/py/torch_tensorrt/dynamo/lowering/_remove_sym_nodes.py b/py/torch_tensorrt/dynamo/lowering/_remove_sym_nodes.py
index 8adebc87f8..0042012761 100644
--- a/py/torch_tensorrt/dynamo/lowering/_remove_sym_nodes.py
+++ b/py/torch_tensorrt/dynamo/lowering/_remove_sym_nodes.py
@@ -1,18 +1,21 @@
 import logging
+from typing import Any, Sequence

 import torch

 logger = logging.getLogger(__name__)


-def remove_sym_nodes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+def remove_sym_nodes(
+    gm: torch.fx.GraphModule, sample_inputs: Sequence[Any]
+) -> torch.fx.GraphModule:
     """Remove sym_int placeholders which get inserted due to torch.compile's
     dynamic=True behavior
     """
     # Extract SymInt placeholder Tensors
-    placeholder_sym_ints = [
-        node
-        for node in gm.graph.nodes
+    placeholder_idx_sym_ints = [
+        (idx, node)
+        for idx, node in enumerate(gm.graph.nodes)
         if (
             node.op == "placeholder"
             and isinstance(node.type, type)
@@ -21,8 +24,9 @@ def remove_sym_nodes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
         )
     ]

-    for node in placeholder_sym_ints:
+    for idx, node in placeholder_idx_sym_ints:
         gm.graph.erase_node(node)
+        sample_inputs.pop(idx)

     gm.graph.lint()
     gm.recompile()