Commit ade51b4

chore: example fixes (#3176)

peri044 authored Dec 16, 2024
1 parent a66684c commit ade51b4

Showing 5 changed files with 132 additions and 7 deletions.
2 changes: 2 additions & 0 deletions docsrc/index.rst
@@ -135,6 +135,7 @@ Model Zoo
 * :ref:`torch_compile_resnet`
 * :ref:`torch_compile_transformer`
 * :ref:`torch_compile_stable_diffusion`
+* :ref:`torch_compile_gpt2`
 * :ref:`torch_export_gpt2`
 * :ref:`torch_export_llama2`
 * :ref:`torch_export_sam2`
@@ -150,6 +151,7 @@ Model Zoo
 tutorials/_rendered_examples/dynamo/torch_compile_stable_diffusion
 tutorials/_rendered_examples/distributed_inference/data_parallel_gpt2
 tutorials/_rendered_examples/distributed_inference/data_parallel_stable_diffusion
+tutorials/_rendered_examples/dynamo/torch_compile_gpt2
 tutorials/_rendered_examples/dynamo/torch_export_gpt2
 tutorials/_rendered_examples/dynamo/torch_export_llama2
 tutorials/_rendered_examples/dynamo/torch_export_sam2
1 change: 1 addition & 0 deletions examples/dynamo/README.rst
@@ -17,6 +17,7 @@ Model Zoo
 * :ref:`torch_compile_resnet`: Compiling a ResNet model using the Torch Compile Frontend for ``torch_tensorrt.compile``
 * :ref:`torch_compile_transformer`: Compiling a Transformer model using ``torch.compile``
 * :ref:`torch_compile_stable_diffusion`: Compiling a Stable Diffusion model using ``torch.compile``
+* :ref:`_torch_compile_gpt2`: Compiling a GPT2 model using ``torch.compile``
 * :ref:`_torch_export_gpt2`: Compiling a GPT2 model using AOT workflow (`ir=dynamo`)
 * :ref:`_torch_export_llama2`: Compiling a Llama2 model using AOT workflow (`ir=dynamo`)
 * :ref:`_torch_export_sam2`: Compiling SAM2 model using AOT workflow (`ir=dynamo`)
117 changes: 117 additions & 0 deletions examples/dynamo/torch_compile_gpt2.py
@@ -0,0 +1,117 @@
"""
.. _torch_compile_gpt2:
Compiling GPT2 using the Torch-TensorRT ``torch.compile`` frontend
==========================================================
This example illustrates the state of the art model `GPT2 <https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf>`_ optimized using
``torch.compile`` frontend of Torch-TensorRT. Install the following dependencies before compilation
.. code-block:: python
pip install -r requirements.txt
GPT2 is a causal (unidirectional) transformer pretrained using language modeling on a very large corpus of text data. In this example, we use the GPT2 model available at `HuggingFace <https://huggingface.co/docs/transformers/en/model_doc/gpt2>`_ and apply torch.compile on it to
get the graph module representation of the graph. Torch-TensorRT converts this graph into an optimized TensorRT engine.
"""

# %%
# Import necessary libraries
# -----------------------------
import torch
import torch_tensorrt
from transformers import AutoModelForCausalLM, AutoTokenizer

# %%
# Define the necessary parameters
# --------------------------------
# Torch-TensorRT requires a GPU for successful compilation of the model.
# ``MAX_LENGTH`` is the maximum length the generated sequence can have. This corresponds to the length of the input prompt plus
# the number of new tokens generated.
MAX_LENGTH = 32
DEVICE = torch.device("cuda:0")

# %%
# Model definition
# -----------------------------
# We use the ``AutoModelForCausalLM`` class to load the pretrained GPT2 model from Hugging Face. ``kv_cache`` is not currently supported in Torch-TRT, so we set ``use_cache=False``.
with torch.no_grad():
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = (
        AutoModelForCausalLM.from_pretrained(
            "gpt2",
            pad_token_id=tokenizer.eos_token_id,
            use_cache=False,
            attn_implementation="eager",
        )
        .eval()
        .cuda()
    )

# %%
# PyTorch inference
# -----------------------------
# Tokenize a sample input prompt and get the PyTorch model outputs
prompt = "I enjoy walking with my cute dog"
model_inputs = tokenizer(prompt, return_tensors="pt")
input_ids = model_inputs["input_ids"].cuda()

# %%
# The ``generate()`` API of the ``AutoModelForCausalLM`` class is used for auto-regressive generation with greedy decoding.
pyt_gen_tokens = model.generate(
    input_ids,
    max_length=MAX_LENGTH,
    use_cache=False,
    pad_token_id=tokenizer.eos_token_id,
)

# %%
# Torch-TensorRT compilation and inference
# -----------------------------------------
# The input sequence length is dynamic, so we mark it as such using the ``torch._dynamo.mark_dynamic`` API.
# We provide a (min, max) range for this dimension so that TensorRT knows in advance what values to optimize for.
# Usually, this would be the context length of the model. We start with ``min=2`` due to the `0/1 specialization <https://docs.google.com/document/d/16VPOa3d-Liikf48teAOmxLc92rgvJdfosIy-yoT38Io/edit?fbclid=IwAR3HNwmmexcitV0pbZm_x1a4ykdXZ9th_eJWK-3hBtVgKnrkmemz6Pm5jRQ&tab=t.0#heading=h.ez923tomjvyk>`_.
torch._dynamo.mark_dynamic(input_ids, 1, min=2, max=1023)
model.forward = torch.compile(
    model.forward,
    backend="tensorrt",
    dynamic=None,
    options={
        "enabled_precisions": {torch.float32},
        "disable_tf32": True,
        "min_block_size": 1,
    },
)

# %%
# Auto-regressive generation loop for greedy decoding using the TensorRT model.
# The first token generation compiles the model using TensorRT, and generating the second token
# currently triggers a recompilation (a known issue that will be resolved in the future).
trt_gen_tokens = model.generate(
    inputs=input_ids,
    max_length=MAX_LENGTH,
    use_cache=False,
    pad_token_id=tokenizer.eos_token_id,
)
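
# %%
# For reference, below is a minimal sketch (not part of the original example) of the greedy
# decoding loop that ``generate()`` performs under the hood. It reuses the compiled
# ``model.forward`` and the objects defined above; note that, as described earlier, new
# sequence lengths may still trigger recompilation.
with torch.no_grad():
    tokens = input_ids
    while tokens.shape[1] < MAX_LENGTH:
        logits = model(tokens).logits  # forward pass through the compiled model
        next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)  # greedy pick
        tokens = torch.cat([tokens, next_token], dim=1)
        if (next_token == tokenizer.eos_token_id).all():
            break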

# %%
# Decode the output sentences of PyTorch and TensorRT
# -----------------------------------------------------
print(
    "Pytorch model generated text: ",
    tokenizer.decode(pyt_gen_tokens[0], skip_special_tokens=True),
)
print("=============================")
print(
    "TensorRT model generated text: ",
    tokenizer.decode(trt_gen_tokens[0], skip_special_tokens=True),
)

# %%
# The output sentences should look like the following:

"""
Pytorch model generated text: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll
=============================
TensorRT model generated text: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll
"""
5 changes: 3 additions & 2 deletions py/torch_tensorrt/dynamo/backend/backends.py
@@ -80,7 +80,8 @@ def _pretraced_backend(
     repair_input_aliasing(gm, settings)
 
     # Remove sym_int placeholders and inputs
-    remove_sym_nodes(gm, settings)
+    remove_sym_nodes(gm, sample_inputs, settings)
+
     torch_inputs = [
         input for input in sample_inputs if isinstance(input, torch.Tensor)
     ]
@@ -91,7 +92,7 @@ def _pretraced_backend(
     # Invoke AOTAutograd to translate operators to aten
     gm = aot_export_joint_simple(
         gm,
-        torch_inputs,
+        sample_inputs,
         trace_joint=False,
         decompositions=get_decompositions(
             settings.enable_experimental_decompositions
14 changes: 9 additions & 5 deletions py/torch_tensorrt/dynamo/lowering/passes/remove_sym_nodes.py
@@ -1,4 +1,5 @@
 import logging
+from typing import Any, Sequence
 
 import torch
 from torch_tensorrt.dynamo._settings import CompilationSettings
@@ -7,15 +8,17 @@
 
 
 def remove_sym_nodes(
-    gm: torch.fx.GraphModule, settings: CompilationSettings
+    gm: torch.fx.GraphModule,
+    sample_inputs: Sequence[Any],
+    settings: CompilationSettings,
 ) -> torch.fx.GraphModule:
     """Remove sym_int placeholders which get inserted due to torch.compile's
     dynamic=True behavior
     """
     # Extract SymInt placeholder Tensors
-    placeholder_sym_ints = [
-        node
-        for node in gm.graph.nodes
+    placeholder_idx_sym_ints = [
+        (idx, node)
+        for idx, node in enumerate(gm.graph.nodes)
         if (
             node.op == "placeholder"
             and isinstance(node.type, type)
@@ -24,8 +27,9 @@ def remove_sym_nodes(
         )
     ]
 
-    for node in placeholder_sym_ints:
+    for idx, node in placeholder_idx_sym_ints:
         gm.graph.erase_node(node)
+        sample_inputs.pop(idx)
 
     gm.graph.lint()
     gm.recompile()
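To see why this pass now needs the sample inputs as well, here is a minimal sketch (not part of this commit) of what a Dynamo backend receives when dynamic shapes are enabled: placeholders typed as ``torch.SymInt`` (sizes) appear alongside the tensor placeholders, and the matching values sit at the same positions in the flat input list. Erasing a SymInt placeholder from the graph without popping the corresponding entry from ``sample_inputs`` would misalign the two sequences that ``aot_export_joint_simple`` receives. The ``inspect_backend`` helper and ``toy`` function below are illustrative only.

import torch


def inspect_backend(gm, example_inputs):
    # Print what torch.compile hands to a custom backend. With dynamic shapes,
    # SymInt size placeholders show up next to the tensor placeholders,
    # aligned 1:1 with example_inputs.
    placeholders = [n for n in gm.graph.nodes if n.op == "placeholder"]
    for node, inp in zip(placeholders, example_inputs):
        print(node.name, node.type, type(inp).__name__)
    return gm.forward  # run the graph unmodified


def toy(x):
    return x + 1


compiled = torch.compile(toy, backend=inspect_backend, dynamic=True)
compiled(torch.randn(4))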
