From ea2dfc67b42c24d999edff0235702a314aab466a Mon Sep 17 00:00:00 2001
From: Ankith Gunapal
Date: Fri, 23 Aug 2024 10:58:18 -0700
Subject: [PATCH] Tutorial for AOTI Python runtime (#2997)

* Tutorial for AOTI Python runtime

---------

Co-authored-by: Svetlana Karslioglu
Co-authored-by: Angela Yi
---
 .ci/docker/build.sh                        |   3 +-
 .ci/docker/common/common_utils.sh          |   2 +-
 .ci/docker/requirements.txt                |   6 +-
 .jenkins/metadata.json                     |   3 +
 en-wordlist.txt                            |   3 +-
 recipes_source/recipes_index.rst           |   6 +
 recipes_source/torch_export_aoti_python.py | 220 +++++++++++++++++++++
 7 files changed, 237 insertions(+), 6 deletions(-)
 create mode 100644 recipes_source/torch_export_aoti_python.py

diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh
index 31f42fdbd8..c646b8f9a8 100755
--- a/.ci/docker/build.sh
+++ b/.ci/docker/build.sh
@@ -11,8 +11,9 @@ IMAGE_NAME="$1"
 shift
 
 export UBUNTU_VERSION="20.04"
+export CUDA_VERSION="12.4.1"
 
-export BASE_IMAGE="ubuntu:${UBUNTU_VERSION}"
+export BASE_IMAGE="nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION}"
 
 echo "Building ${IMAGE_NAME} Docker image"
 
 docker build \
diff --git a/.ci/docker/common/common_utils.sh b/.ci/docker/common/common_utils.sh
index b20286a409..c7eabda555 100644
--- a/.ci/docker/common/common_utils.sh
+++ b/.ci/docker/common/common_utils.sh
@@ -22,5 +22,5 @@ conda_run() {
 }
 
 pip_install() {
-  as_ci_user conda run -n py_$ANACONDA_PYTHON_VERSION pip install --progress-bar off $*
+  as_ci_user conda run -n py_$ANACONDA_PYTHON_VERSION pip3 install --progress-bar off $*
 }
diff --git a/.ci/docker/requirements.txt b/.ci/docker/requirements.txt
index 00cf2f2103..9668b17fc3 100644
--- a/.ci/docker/requirements.txt
+++ b/.ci/docker/requirements.txt
@@ -30,8 +30,8 @@ pytorch-lightning
 torchx
 torchrl==0.5.0
 tensordict==0.5.0
-ax-platform>==0.4.0
-nbformat>==5.9.2
+ax-platform>=0.4.0
+nbformat>=5.9.2
 datasets
 transformers
 torchmultimodal-nightly # needs to be updated to stable as soon as it's avaialable
@@ -68,4 +68,4 @@ pygame==2.1.2
 pycocotools
 semilearn==0.3.2
 torchao==0.0.3
-segment_anything==1.0
\ No newline at end of file
+segment_anything==1.0
diff --git a/.jenkins/metadata.json b/.jenkins/metadata.json
index 4814f9a7d2..2f1a9933aa 100644
--- a/.jenkins/metadata.json
+++ b/.jenkins/metadata.json
@@ -28,6 +28,9 @@
     "intermediate_source/model_parallel_tutorial.py": {
         "needs": "linux.16xlarge.nvidia.gpu"
     },
+    "recipes_source/torch_export_aoti_python.py": {
+        "needs": "linux.g5.4xlarge.nvidia.gpu"
+    },
     "advanced_source/pendulum.py": {
         "needs": "linux.g5.4xlarge.nvidia.gpu",
         "_comment": "need to be here for the compiling_optimizer_lr_scheduler.py to run."
diff --git a/en-wordlist.txt b/en-wordlist.txt
index 62762ab69c..e69cbaa1a5 100644
--- a/en-wordlist.txt
+++ b/en-wordlist.txt
@@ -2,6 +2,7 @@
 ACL
 ADI
 AOT
+AOTInductor
 APIs
 ATen
 AVX
@@ -617,4 +618,4 @@ warmstarting
 warmup
 webp
 wsi
-wsis
\ No newline at end of file
+wsis
diff --git a/recipes_source/recipes_index.rst b/recipes_source/recipes_index.rst
index d94d7d5c22..caccdcc28f 100644
--- a/recipes_source/recipes_index.rst
+++ b/recipes_source/recipes_index.rst
@@ -150,6 +150,12 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu
    :link: ../recipes/recipes/swap_tensors.html
    :tags: Basics
 
+.. customcarditem::
+   :header: torch.export AOTInductor Tutorial for Python runtime
+   :card_description: Learn an end-to-end example of how to use AOTInductor for the Python runtime.
+   :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
+   :link: ../recipes/torch_export_aoti_python.html
+   :tags: Basics
 
 .. Interpretability
 
diff --git a/recipes_source/torch_export_aoti_python.py b/recipes_source/torch_export_aoti_python.py
new file mode 100644
index 0000000000..136862078c
--- /dev/null
+++ b/recipes_source/torch_export_aoti_python.py
@@ -0,0 +1,220 @@
# -*- coding: utf-8 -*-

"""
(Beta) ``torch.export`` AOTInductor Tutorial for Python runtime
===============================================================
**Authors:** Ankith Gunapal, Bin Bao, Angela Yi
"""

######################################################################
#
# .. warning::
#
#     ``torch._inductor.aot_compile`` and ``torch._export.aot_load`` are in Beta status and are subject to
#     backwards-compatibility-breaking changes. This tutorial provides an example of how to use these APIs
#     for model deployment using the Python runtime.
#
# It has been shown `previously `__ how AOTInductor can be used
# to do Ahead-of-Time compilation of PyTorch exported models by creating
# a shared library that can be run in a non-Python environment.
#
# In this tutorial, you will learn an end-to-end example of how to use AOTInductor for the Python runtime.
# We will look at how to use :func:`torch._inductor.aot_compile` along with :func:`torch.export.export` to generate a
# shared library. Additionally, we will examine how to execute the shared library in the Python runtime using
# :func:`torch._export.aot_load`. You will learn about the speedup in first inference time that AOTInductor provides,
# especially when ``max-autotune`` mode is used, since that mode can take a long time to compile.
#
# **Contents**
#
# .. contents::
#     :local:

######################################################################
# Prerequisites
# -------------
# * PyTorch 2.4 or later
# * Basic understanding of ``torch.export`` and AOTInductor
# * Complete the `AOTInductor: Ahead-Of-Time Compilation for Torch.Export-ed Models `_ tutorial

######################################################################
# What you will learn
# -------------------
# * How to use AOTInductor for the Python runtime.
# * How to use :func:`torch._inductor.aot_compile` along with :func:`torch.export.export` to generate a shared library.
# * How to run a shared library in the Python runtime using :func:`torch._export.aot_load`.
# * When to use AOTInductor for the Python runtime.

######################################################################
# Model Compilation
# -----------------
#
# We will use the TorchVision pretrained ``ResNet18`` model and run TorchInductor on the
# exported PyTorch program using :func:`torch._inductor.aot_compile`.
#
# .. note::
#
#       This API also supports :func:`torch.compile` options such as ``mode``.
#       This means that if used on a CUDA-enabled device, you can, for example, set ``"max_autotune": True``,
#       which leverages Triton-based matrix multiplications and convolutions, and enables CUDA graphs by default.
#
# We also specify ``dynamic_shapes`` for the batch dimension.
# In this example, ``min=2`` is not a bug; it is explained in
# `The 0/1 Specialization Problem `__.


import os
import torch
from torchvision.models import ResNet18_Weights, resnet18

model = resnet18(weights=ResNet18_Weights.DEFAULT)
model.eval()

with torch.inference_mode():

    # Specify the generated shared library path
    aot_compile_options = {
        "aot_inductor.output_path": os.path.join(os.getcwd(), "resnet18_pt2.so"),
    }
    if torch.cuda.is_available():
        device = "cuda"
        aot_compile_options.update({"max_autotune": True})
    else:
        device = "cpu"

    model = model.to(device=device)
    example_inputs = (torch.randn(2, 3, 224, 224, device=device),)

    # min=2 is not a bug and is explained in the 0/1 Specialization Problem
    batch_dim = torch.export.Dim("batch", min=2, max=32)
    exported_program = torch.export.export(
        model,
        example_inputs,
        # Specify the first dimension of the input x as dynamic
        dynamic_shapes={"x": {0: batch_dim}},
    )
    so_path = torch._inductor.aot_compile(
        exported_program.module(),
        example_inputs,
        # Specify the generated shared library path
        options=aot_compile_options
    )


######################################################################
# Model Inference in Python
# -------------------------
#
# Typically, the shared object generated above is used in a non-Python environment. In PyTorch 2.3,
# we added a new API called :func:`torch._export.aot_load` to load the shared library in the Python runtime.
# The API follows a structure similar to :func:`torch.jit.load`. You need to specify the path
# of the shared library and the device where it should be loaded.
#
# .. note::
#      In the example below, we run inference with ``batch_size=1``, and it still functions correctly
#      even though we specified ``min=2`` in :func:`torch.export.export`.


import os
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model_so_path = os.path.join(os.getcwd(), "resnet18_pt2.so")

model = torch._export.aot_load(model_so_path, device)
example_inputs = (torch.randn(1, 3, 224, 224, device=device),)

with torch.inference_mode():
    output = model(example_inputs)

######################################################################
# When to use AOTInductor for Python Runtime
# ------------------------------------------
#
# One of the requirements for using AOTInductor is that the model shouldn't have any graph breaks.
# Once this requirement is met, the primary use case of the AOTInductor Python runtime is
# model deployment using Python.
# There are two main reasons to use the AOTInductor Python runtime:
#
# - ``torch._inductor.aot_compile`` generates a shared library. This is useful for model
#   versioning in deployments and for tracking model performance over time.
# - With :func:`torch.compile` being a JIT compiler, there is a warmup
#   cost associated with the first compilation. Your deployment needs to account for the
#   compilation time taken for the first inference. With AOTInductor, the compilation is
#   done offline using ``torch.export.export`` and ``torch._inductor.aot_compile``. The deployment
#   would only load the shared library using ``torch._export.aot_load`` and run inference, as sketched below.
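
######################################################################
# The following is a minimal sketch of that deployment flow, reusing the shared
# library compiled earlier in this recipe. The batch size of ``4`` and the names
# ``deployed_model`` and ``deployment_inputs`` are illustrative choices for this
# sketch, not a required part of the API.

# In a deployment process, this load step replaces any compilation work.
deployed_model = torch._export.aot_load(model_so_path, device)

# Any batch size within the exported dynamic range (2 to 32 in this recipe) can be used.
deployment_inputs = (torch.randn(4, 3, 224, 224, device=device),)

with torch.inference_mode():
    predictions = deployed_model(deployment_inputs)
    print(f"Output shape from the AOTInductor-compiled model: {predictions.shape}")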

######################################################################
# The section below shows the speedup achieved with AOTInductor for the first inference.
#
# We define a utility function ``timed`` to measure the time taken for inference.
#

import time

def timed(fn):
    # Returns the result of running `fn()` and the time it took for `fn()` to run,
    # in milliseconds. We use CUDA events and synchronization for accurate
    # measurement on CUDA-enabled devices.
    if torch.cuda.is_available():
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
    else:
        start = time.time()

    result = fn()
    if torch.cuda.is_available():
        end.record()
        torch.cuda.synchronize()
    else:
        end = time.time()

    # Measure the time taken to execute the function in milliseconds
    if torch.cuda.is_available():
        duration = start.elapsed_time(end)
    else:
        duration = (end - start) * 1000

    return result, duration


######################################################################
# Let's measure the time for the first inference using AOTInductor.

torch._dynamo.reset()

model = torch._export.aot_load(model_so_path, device)
example_inputs = (torch.randn(1, 3, 224, 224, device=device),)

with torch.inference_mode():
    _, time_taken = timed(lambda: model(example_inputs))
    print(f"Time taken for first inference for AOTInductor is {time_taken:.2f} ms")


######################################################################
# Let's measure the time for the first inference using ``torch.compile``.

torch._dynamo.reset()

model = resnet18(weights=ResNet18_Weights.DEFAULT).to(device)
model.eval()

model = torch.compile(model)
example_inputs = torch.randn(1, 3, 224, 224, device=device)

with torch.inference_mode():
    _, time_taken = timed(lambda: model(example_inputs))
    print(f"Time taken for first inference for torch.compile is {time_taken:.2f} ms")

######################################################################
# We see that there is a drastic speedup in the first inference time when using AOTInductor
# compared to ``torch.compile``.

######################################################################
# Conclusion
# ----------
#
# In this recipe, we learned how to use AOTInductor effectively for the Python runtime by
# compiling and loading a pretrained ``ResNet18`` model with the ``torch._inductor.aot_compile``
# and ``torch._export.aot_load`` APIs. We demonstrated how to generate a shared library and run it
# within a Python environment, including dynamic-shape and device-specific handling. We also looked
# at the advantage that AOTInductor brings to model deployments through the speedup in first
# inference time.
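
######################################################################
# As an optional extension, the same ``timed`` helper can also be used to look at steady-state
# latency once both runtimes are warmed up, complementing the first-inference comparison above.
# The number of warmup iterations and the variable names below are illustrative choices for this
# sketch, not part of the measurements in the sections above.

aot_model = torch._export.aot_load(model_so_path, device)
compiled_model = model  # the ``torch.compile``-d model from the previous section
warm_inputs = (torch.randn(1, 3, 224, 224, device=device),)

with torch.inference_mode():
    # Warm up both models so that we measure steady-state latency rather than compilation.
    for _ in range(5):
        aot_model(warm_inputs)
        compiled_model(warm_inputs[0])

    _, aot_ms = timed(lambda: aot_model(warm_inputs))
    _, compile_ms = timed(lambda: compiled_model(warm_inputs[0]))
    print(f"Steady-state latency: AOTInductor {aot_ms:.2f} ms, torch.compile {compile_ms:.2f} ms")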