Tutorial for AOTI Python runtime #2997

Merged
merged 26 commits into main from tutorial/aoti_python on Aug 23, 2024
Changes from 8 commits

Commits (26)
1dea278
Tutorial for AOTI Python runtime
agunapal Aug 12, 2024
cd09129
Apply suggestions from code review
agunapal Aug 13, 2024
3fa9b20
Addressed review comments and added a section on why AOTI Python
agunapal Aug 13, 2024
7c9edb7
Addressed review comments and added a section on why AOTI Python
agunapal Aug 13, 2024
9cba6fb
fixed spelling
agunapal Aug 13, 2024
a6f6cd9
fixed spelling
agunapal Aug 13, 2024
1375373
Apply suggestions from code review
agunapal Aug 16, 2024
7158985
Addressed review comment
agunapal Aug 16, 2024
53f5965
Changing to use g5.4xlarge machine
agunapal Aug 19, 2024
849c8e3
Merge branch 'main' into tutorial/aoti_python
agunapal Aug 19, 2024
4aa8399
Moved tutorial to recipe
agunapal Aug 19, 2024
39b3942
Merge branch 'tutorial/aoti_python' of https://github.com/agunapal/tu…
agunapal Aug 19, 2024
35c5dc8
addressed review comments
agunapal Aug 19, 2024
71acd96
Moved tutorial to recipe
agunapal Aug 19, 2024
7f5fde9
Change base image to nvidia devel image
agunapal Aug 20, 2024
790f762
Change base image to nvidia devel image
agunapal Aug 20, 2024
45df5d0
Update requirements
agunapal Aug 20, 2024
b268a3c
fixed formatting
agunapal Aug 20, 2024
b6c3a01
Merge branch 'main' into tutorial/aoti_python
agunapal Aug 20, 2024
6578d82
update to CUDA 12.4
agunapal Aug 20, 2024
9ee64d9
Merge branch 'tutorial/aoti_python' of https://github.com/agunapal/tu…
agunapal Aug 20, 2024
67bc080
Apply suggestions from code review
agunapal Aug 21, 2024
fc0ff5e
addressed review comments for formatting
agunapal Aug 21, 2024
85f2870
Update recipes_source/torch_export_aoti_python.py
svekars Aug 22, 2024
cb8ea23
Update recipes_source/torch_export_aoti_python.py
svekars Aug 22, 2024
194388e
Merge branch 'main' into tutorial/aoti_python
svekars Aug 23, 2024
3 changes: 3 additions & 0 deletions .jenkins/metadata.json
@@ -28,6 +28,9 @@
"intermediate_source/model_parallel_tutorial.py": {
"needs": "linux.16xlarge.nvidia.gpu"
},
"intermediate_source/torch_export_aoti_python.py": {
"needs": "linux.16xlarge.nvidia.gpu"
},
"advanced_source/pendulum.py": {
"needs": "linux.g5.4xlarge.nvidia.gpu",
"_comment": "need to be here for the compiling_optimizer_lr_scheduler.py to run."
3 changes: 2 additions & 1 deletion en-wordlist.txt
@@ -2,6 +2,7 @@
ACL
ADI
AOT
AOTInductor
APIs
ATen
AVX
@@ -617,4 +618,4 @@ warmstarting
warmup
webp
wsi
wsis
wsis
225 changes: 225 additions & 0 deletions intermediate_source/torch_export_aoti_python.py
@@ -0,0 +1,225 @@
# -*- coding: utf-8 -*-

"""
(Beta) ``torch.export`` AOTInductor Tutorial for Python runtime
=================================================================
**Author:** Ankith Gunapal, Bin Bao
"""

######################################################################
#
# .. warning::
#
#     ``torch._export.aot_compile`` and ``torch._export.aot_load`` are in Beta status and are subject to
#     backward compatibility-breaking changes. This tutorial provides an example of how to use these APIs
#     for model deployment using Python runtime.
#
# It has been shown `previously <https://pytorch.org/docs/stable/torch.compiler_aot_inductor.html#>`__ how AOTInductor can be used
# to perform Ahead-of-Time compilation of PyTorch exported models by creating
# a shared library that can be run in a non-Python environment.
#
# In this tutorial, you will work through an end-to-end example of how to use AOTInductor for Python runtime.
(Review comment, Contributor): It will make the story more complete by explaining the "why" part here, e.g. eliminating recompilation at run time, max-autotune ahead of time, etc.

(Reply, Contributor Author agunapal): Done. Haven't mentioned eliminating recompilation, since the tutorial doesn't show that.
# We will look at how to use :func:`torch._export.aot_compile` to generate a shared library.
# Additionally, we will examine how to execute the shared library in Python runtime using :func:`torch._export.aot_load`.
# You will learn about the speedup in first inference time when using AOTInductor, especially when using
# ``max-autotune`` mode, which can take some time to compile.
#
# **Contents**
#
# .. contents::
# :local:

######################################################################
# Prerequisites
# -------------
# * PyTorch 2.4 or later
# * Basic understanding of ``torch._export`` and AOTInductor
# * Complete the `AOTInductor: Ahead-Of-Time Compilation for Torch.Export-ed Models <https://pytorch.org/docs/stable/torch.compiler_aot_inductor.html#>`_ tutorial
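
######################################################################
# As a quick, optional sanity check for the PyTorch 2.4+ prerequisite, you can verify the
# installed version before running the rest of the recipe. This is a minimal sketch; adjust
# or skip it as needed.

import torch

# ``torch.__version__`` is a string such as "2.4.0" or "2.4.0+cu124"; compare major/minor only.
major, minor = [int(part) for part in torch.__version__.split(".")[:2]]
assert (major, minor) >= (2, 4), f"This recipe expects PyTorch 2.4+, found {torch.__version__}"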

######################################################################
# What you will learn
# ----------------------
# * How to use AOTInductor for Python runtime.
# * How to use :func:`torch._export.aot_compile` to generate a shared library.
# * How to run a shared library in Python runtime using :func:`torch._export.aot_load`.
# * When to use AOTInductor for Python runtime.

######################################################################
# Model Compilation
# -------------------
#
# We will use the TorchVision pretrained ``ResNet18`` model and run TorchInductor on the
# exported PyTorch program using :func:`torch._export.aot_compile`.
#
# .. note::
#
#       This API also supports :func:`torch.compile` options like ``mode``.
#       This means that if used on a CUDA enabled device, you can, for example, set ``"max_autotune": True``,
#       which leverages Triton based matrix multiplications & convolutions, and enables CUDA graphs by default.
#
# We also specify ``dynamic_shapes`` for the batch dimension. In this example, ``min=2`` is not a bug and is
# explained in `The 0/1 Specialization Problem <https://docs.google.com/document/d/16VPOa3d-Liikf48teAOmxLc92rgvJdfosIy-yoT38Io/edit?fbclid=IwAR3HNwmmexcitV0pbZm_x1a4ykdXZ9th_eJWK-3hBtVgKnrkmemz6Pm5jRQ#heading=h.ez923tomjvyk>`__.


import os
import torch
from torchvision.models import ResNet18_Weights, resnet18

model = resnet18(weights=ResNet18_Weights.DEFAULT)
model.eval()

with torch.inference_mode():

    # Specify the generated shared library path
    aot_compile_options = {
        "aot_inductor.output_path": os.path.join(os.getcwd(), "resnet18_pt2.so"),
    }
    if torch.cuda.is_available():
        device = "cuda"
        aot_compile_options.update({"max_autotune": True})
    else:
        device = "cpu"
        # We need to turn off the below optimizations to support batch_size = 16,
        # which is treated like a special case
        # https://github.com/pytorch/pytorch/pull/116152
        torch.backends.mkldnn.set_flags(False)
        torch.backends.nnpack.set_flags(False)

    model = model.to(device=device)
    example_inputs = (torch.randn(2, 3, 224, 224, device=device),)

    # min=2 is not a bug and is explained in the 0/1 Specialization Problem
    batch_dim = torch.export.Dim("batch", min=2, max=32)
(Review comment, Contributor): I believe it is ok to use min=1 here, but we can't feed in an example input with batch size 1.

(Reply, Contributor Author agunapal, Aug 16, 2024): An example with batch_size 1 is often tried, hence I set min=2.
    exported_program = torch.export.export(
        model,
        example_inputs,
        # Specify the first dimension of the input x as dynamic
        dynamic_shapes={"x": {0: batch_dim}},
    )
    so_path = torch._inductor.aot_compile(
        exported_program.module(),
        example_inputs,
        # Specify the generated shared library path
        options=aot_compile_options
    )


######################################################################
# Model Inference in Python
# ---------------------------
#
# Typically, the shared object generated above is used in a non-Python environment. In PyTorch 2.3,
# we added a new API called :func:`torch._export.aot_load` to load the shared library in the Python runtime.
# The API follows a structure similar to the :func:`torch.jit.load` API. You need to specify the path
# of the shared library and the device where it should be loaded.
#
# .. note::
#
#       In the example below, we specify ``batch_size=1`` for inference and it still functions correctly even though we specified ``min=2`` in
#       :func:`torch._export.aot_compile`.


import os
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model_so_path = os.path.join(os.getcwd(), "resnet18_pt2.so")

model = torch._export.aot_load(model_so_path, device)
example_inputs = (torch.randn(1, 3, 224, 224, device=device),)

with torch.inference_mode():
    output = model(example_inputs)

######################################################################
# When to use AOTInductor for Python Runtime
# --------------------------------------------
#
# One of the requirements for using AOTInductor is that the model shouldn't have any graph breaks.
# Once this requirement is met, the primary use case for AOTInductor Python Runtime is
# model deployment using Python.
# There are mainly two reasons why you would use AOTInductor Python Runtime:
#
# - ``torch._export.aot_compile`` generates a shared library. This is useful for model
#   versioning for deployments and tracking model performance over time (see the sketch after
#   this list).
# - With :func:`torch.compile` being a JIT compiler, there is a warmup
#   cost associated with the first compilation. Your deployment needs to account for the
#   compilation time taken for the first inference. With AOTInductor, the compilation is
#   done offline using ``torch._export.aot_compile``. The deployment would only load the
#   shared library using ``torch._export.aot_load`` and run inference.
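#
# The sketch below illustrates the versioning point: a deployment script can keep several
# compiled shared libraries side by side and load a specific one with ``torch._export.aot_load``.
# This is a minimal sketch; the naming scheme and the ``MODEL_VERSION`` variable are hypothetical
# and not part of this recipe.

import os
import torch

MODEL_VERSION = "v1"  # hypothetical version tag chosen by your deployment process
versioned_so_path = os.path.join(os.getcwd(), f"resnet18_{MODEL_VERSION}_pt2.so")
device = "cuda" if torch.cuda.is_available() else "cpu"

if os.path.exists(versioned_so_path):
    # Deployment only loads the pre-compiled, versioned shared library and runs inference;
    # no compilation happens at this point.
    versioned_model = torch._export.aot_load(versioned_so_path, device)
    with torch.inference_mode():
        versioned_output = versioned_model((torch.randn(2, 3, 224, 224, device=device),))

######################################################################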
#
#
# The section below shows the speedup achieved with AOTInductor for the first inference.
#
# We define a utility function ``timed`` to measure the time taken for inference.
#

import time

def timed(fn):
    # Returns the result of running `fn()` and the time it took for `fn()` to run,
    # in milliseconds. We use CUDA events and synchronization for accurate
    # measurement on CUDA enabled devices.
    if torch.cuda.is_available():
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
    else:
        start = time.time()

    result = fn()
    if torch.cuda.is_available():
        end.record()
        torch.cuda.synchronize()
    else:
        end = time.time()

    # Measure time taken to execute the function in milliseconds
    if torch.cuda.is_available():
        duration = start.elapsed_time(end)
    else:
        duration = (end - start) * 1000

    return result, duration


######################################################################
# Let's measure the time for first inference using AOTInductor.

torch._dynamo.reset()

model = torch._export.aot_load(model_so_path, device)
example_inputs = (torch.randn(1, 3, 224, 224, device=device),)

with torch.inference_mode():
    _, time_taken = timed(lambda: model(example_inputs))
    print(f"Time taken for first inference for AOTInductor is {time_taken:.2f} ms")


######################################################################
# Let's measure the time for first inference using ``torch.compile``.

torch._dynamo.reset()

model = resnet18(weights=ResNet18_Weights.DEFAULT).to(device)
model.eval()

model = torch.compile(model)
example_inputs = torch.randn(1, 3, 224, 224, device=device)

with torch.inference_mode():
    _, time_taken = timed(lambda: model(example_inputs))
    print(f"Time taken for first inference for torch.compile is {time_taken:.2f} ms")

######################################################################
# We see that there is a drastic speedup in first inference time using AOTInductor compared
# to ``torch.compile``.
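#
# Note that this comparison concerns the first (cold) inference only. As an optional follow-up,
# you can time a subsequent inference on the already compiled ``torch.compile`` model; once the
# initial compilation has happened, the warmup cost is gone. This is a small sketch added for
# illustration that reuses the ``timed`` helper defined above.

with torch.inference_mode():
    _, warm_time_taken = timed(lambda: model(example_inputs))
    print(f"Time taken for a subsequent (warm) inference with torch.compile is {warm_time_taken:.2f} ms")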

######################################################################
# Conclusion
# ----------
#
# In this recipe, we have learned how to effectively use AOTInductor for Python runtime by
# compiling and loading a pretrained ``ResNet18`` model using the ``torch._export.aot_compile``
# and ``torch._export.aot_load`` APIs. This process demonstrates the practical application of
# generating a shared library and running it within a Python environment, even with dynamic shape
# considerations and device-specific optimizations. We also looked at the advantage of using
# AOTInductor in model deployments, with regard to the speedup in first inference time.