
Commit bb9949a

agunapal, svekars, and angelayi authored and committed
Tutorial for AOTI Python runtime (#2997)
* Tutorial for AOTI Python runtime

---------

Co-authored-by: Svetlana Karslioglu <[email protected]>
Co-authored-by: Angela Yi <[email protected]>
1 parent 83298bb commit bb9949a

File tree

7 files changed: +237 −6 lines changed

.ci/docker/build.sh

+2 −1

@@ -11,8 +11,9 @@ IMAGE_NAME="$1"
 shift
 
 export UBUNTU_VERSION="20.04"
+export CUDA_VERSION="12.4.1"
 
-export BASE_IMAGE="ubuntu:${UBUNTU_VERSION}"
+export BASE_IMAGE="nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION}"
 echo "Building ${IMAGE_NAME} Docker image"
 
 docker build \

.ci/docker/common/common_utils.sh

+1 −1

@@ -22,5 +22,5 @@ conda_run() {
 }
 
 pip_install() {
-  as_ci_user conda run -n py_$ANACONDA_PYTHON_VERSION pip install --progress-bar off $*
+  as_ci_user conda run -n py_$ANACONDA_PYTHON_VERSION pip3 install --progress-bar off $*
 }

.ci/docker/requirements.txt

+3 −3

@@ -30,8 +30,8 @@ pytorch-lightning
 torchx
 torchrl==0.5.0
 tensordict==0.5.0
-ax-platform>==0.4.0
-nbformat>==5.9.2
+ax-platform>=0.4.0
+nbformat>=5.9.2
 datasets
 transformers
 torchmultimodal-nightly # needs to be updated to stable as soon as it's avaialable
@@ -68,4 +68,4 @@ pygame==2.1.2
 pycocotools
 semilearn==0.3.2
 torchao==0.0.3
-segment_anything==1.0
+segment_anything==1.0

.jenkins/metadata.json

+3

@@ -28,6 +28,9 @@
     "intermediate_source/model_parallel_tutorial.py": {
         "needs": "linux.16xlarge.nvidia.gpu"
     },
+    "recipes_source/torch_export_aoti_python.py": {
+        "needs": "linux.g5.4xlarge.nvidia.gpu"
+    },
     "advanced_source/pendulum.py": {
         "needs": "linux.g5.4xlarge.nvidia.gpu",
         "_comment": "need to be here for the compiling_optimizer_lr_scheduler.py to run."

en-wordlist.txt

+2 −1

@@ -2,6 +2,7 @@
 ACL
 ADI
 AOT
+AOTInductor
 APIs
 ATen
 AVX
@@ -617,4 +618,4 @@ warmstarting
 warmup
 webp
 wsi
-wsis
+wsis

recipes_source/recipes_index.rst

+6

@@ -150,6 +150,12 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu
    :link: ../recipes/recipes/swap_tensors.html
    :tags: Basics
 
+.. customcarditem::
+   :header: torch.export AOTInductor Tutorial for Python runtime
+   :card_description: Learn an end-to-end example of how to use AOTInductor for Python runtime.
+   :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
+   :link: ../recipes/torch_export_aoti_python.html
+   :tags: Basics
 
 .. Interpretability
recipes_source/torch_export_aoti_python.py (new file)

+220

@@ -0,0 +1,220 @@
# -*- coding: utf-8 -*-

"""
(Beta) ``torch.export`` AOTInductor Tutorial for Python runtime
===============================================================
**Authors:** Ankith Gunapal, Bin Bao, Angela Yi
"""

######################################################################
#
# .. warning::
#
#     ``torch._inductor.aot_compile`` and ``torch._export.aot_load`` are in Beta status and are subject to
#     backward-compatibility-breaking changes. This tutorial provides an example of how to use these APIs
#     for model deployment using the Python runtime.
#
# It has been shown `previously <https://pytorch.org/docs/stable/torch.compiler_aot_inductor.html#>`__ how AOTInductor can be used
# to do Ahead-of-Time compilation of PyTorch exported models by creating
# a shared library that can be run in a non-Python environment.
#
# In this tutorial, you will work through an end-to-end example of using AOTInductor for Python runtime.
# We will look at how to use :func:`torch._inductor.aot_compile` along with :func:`torch.export.export` to generate a
# shared library, and how to execute that shared library in the Python runtime using :func:`torch._export.aot_load`.
# You will also learn about the speedup in first inference time that AOTInductor provides, especially when using
# ``max-autotune`` mode, which can take some time to execute.
#
# **Contents**
#
# .. contents::
#     :local:

######################################################################
# Prerequisites
# -------------
# * PyTorch 2.4 or later
# * Basic understanding of ``torch.export`` and AOTInductor
# * Complete the `AOTInductor: Ahead-Of-Time Compilation for Torch.Export-ed Models <https://pytorch.org/docs/stable/torch.compiler_aot_inductor.html#>`_ tutorial

######################################################################
# What you will learn
# -------------------
# * How to use AOTInductor for Python runtime
# * How to use :func:`torch._inductor.aot_compile` along with :func:`torch.export.export` to generate a shared library
# * How to run a shared library in the Python runtime using :func:`torch._export.aot_load`
# * When to use AOTInductor for Python runtime

######################################################################
# Model Compilation
# -----------------
#
# We will use the TorchVision pretrained ``ResNet18`` model and run TorchInductor on the
# exported PyTorch program using :func:`torch._inductor.aot_compile`.
#
# .. note::
#
#       This API also supports :func:`torch.compile` options, such as ``mode``.
#       For example, if used on a CUDA-enabled device, you can set ``"max_autotune": True``,
#       which leverages Triton-based matrix multiplications and convolutions, and enables CUDA graphs by default.
#
# We also specify ``dynamic_shapes`` for the batch dimension. In this example, ``min=2`` is not a bug; it is
# explained in `The 0/1 Specialization Problem <https://docs.google.com/document/d/16VPOa3d-Liikf48teAOmxLc92rgvJdfosIy-yoT38Io/edit?fbclid=IwAR3HNwmmexcitV0pbZm_x1a4ykdXZ9th_eJWK-3hBtVgKnrkmemz6Pm5jRQ#heading=h.ez923tomjvyk>`__.

import os
import torch
from torchvision.models import ResNet18_Weights, resnet18

model = resnet18(weights=ResNet18_Weights.DEFAULT)
model.eval()

with torch.inference_mode():

    # Specify the generated shared library path
    aot_compile_options = {
        "aot_inductor.output_path": os.path.join(os.getcwd(), "resnet18_pt2.so"),
    }
    if torch.cuda.is_available():
        device = "cuda"
        aot_compile_options.update({"max_autotune": True})
    else:
        device = "cpu"

    model = model.to(device=device)
    example_inputs = (torch.randn(2, 3, 224, 224, device=device),)

    # min=2 is not a bug and is explained in the 0/1 Specialization Problem
    batch_dim = torch.export.Dim("batch", min=2, max=32)
    exported_program = torch.export.export(
        model,
        example_inputs,
        # Specify the first dimension of the input x as dynamic
        dynamic_shapes={"x": {0: batch_dim}},
    )
    so_path = torch._inductor.aot_compile(
        exported_program.module(),
        example_inputs,
        # Pass the compilation options, including the output path, to AOTInductor
        options=aot_compile_options,
    )
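
######################################################################
# :func:`torch._inductor.aot_compile` returns the path of the generated shared
# library, which should match the ``aot_inductor.output_path`` option set above.
# As a quick sanity check, we can confirm that the artifact was written there:

print(f"Compiled shared library: {so_path}")
assert os.path.exists(so_path)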

######################################################################
# Model Inference in Python
# -------------------------
#
# Typically, the shared object generated above is used in a non-Python environment. In PyTorch 2.3,
# we added a new API called :func:`torch._export.aot_load` to load the shared library in the Python runtime.
# The API follows a structure similar to the :func:`torch.jit.load` API. You need to specify the path
# of the shared library and the device where it should be loaded.
#
# .. note::
#     In the example below, we specify ``batch_size=1`` for inference, and it still functions correctly even though we specified ``min=2`` in
#     :func:`torch.export.export`.


import os
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model_so_path = os.path.join(os.getcwd(), "resnet18_pt2.so")

model = torch._export.aot_load(model_so_path, device)
example_inputs = (torch.randn(1, 3, 224, 224, device=device),)

with torch.inference_mode():
    output = model(example_inputs)
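
######################################################################
# Since the batch dimension was exported as dynamic (``min=2``, ``max=32``),
# the same shared library should also serve other batch sizes in that range.
# A quick sketch to confirm, following the same calling convention as above:

with torch.inference_mode():
    for batch_size in (2, 16, 32):
        batch = (torch.randn(batch_size, 3, 224, 224, device=device),)
        print(f"batch_size={batch_size} -> output shape {model(batch).shape}")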

######################################################################
# When to use AOTInductor for Python Runtime
# ------------------------------------------
#
# One of the requirements for using AOTInductor is that the model shouldn't have any graph breaks.
# Once this requirement is met, the primary use case for AOTInductor Python Runtime is
# model deployment using Python.
# There are mainly two reasons why you would use AOTInductor Python Runtime:
#
# - ``torch._inductor.aot_compile`` generates a shared library. This is useful for model
#   versioning in deployments and for tracking model performance over time.
# - With :func:`torch.compile` being a JIT compiler, there is a warmup
#   cost associated with the first compilation. Your deployment needs to account for the
#   compilation time taken for the first inference. With AOTInductor, the compilation is
#   done offline using ``torch.export.export`` and ``torch._inductor.aot_compile``. At deployment
#   time, only the shared library is loaded with ``torch._export.aot_load``, followed directly by
#   inference; a minimal versioning sketch follows below.
#
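
######################################################################
# As a minimal sketch of the versioning idea (the directory layout below is
# an assumption for illustration, not a convention required by AOTInductor):
# since the artifact is a single shared library, a deployment can stamp it
# with a version tag, and the serving process only ever loads the tagged ``.so``.

import shutil

versioned_dir = os.path.join(os.getcwd(), "resnet18", "v1")  # assumed layout
os.makedirs(versioned_dir, exist_ok=True)
versioned_so_path = shutil.copy(model_so_path, versioned_dir)

# Serving side: load the versioned shared library and run inference
model = torch._export.aot_load(versioned_so_path, device)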

######################################################################
# The section below shows the speedup achieved with AOTInductor for first inference.
#
# We define a utility function ``timed`` to measure the time taken for inference.
#

import time


def timed(fn):
    # Returns the result of running `fn()` and the time it took for `fn()` to run,
    # in milliseconds. We use CUDA events and synchronization for accurate
    # measurement on CUDA-enabled devices.
    if torch.cuda.is_available():
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
    else:
        start = time.time()

    result = fn()
    if torch.cuda.is_available():
        end.record()
        torch.cuda.synchronize()
    else:
        end = time.time()

    # Measure time taken to execute the function in milliseconds
    if torch.cuda.is_available():
        duration = start.elapsed_time(end)
    else:
        duration = (end - start) * 1000

    return result, duration


######################################################################
# Let's measure the time for first inference using AOTInductor.

torch._dynamo.reset()

model = torch._export.aot_load(model_so_path, device)
example_inputs = (torch.randn(1, 3, 224, 224, device=device),)

with torch.inference_mode():
    _, time_taken = timed(lambda: model(example_inputs))
    print(f"Time taken for first inference for AOTInductor is {time_taken:.2f} ms")

######################################################################
# Let's measure the time for first inference using ``torch.compile``.

torch._dynamo.reset()

model = resnet18(weights=ResNet18_Weights.DEFAULT).to(device)
model.eval()

model = torch.compile(model)
example_inputs = torch.randn(1, 3, 224, 224, device=device)

with torch.inference_mode():
    _, time_taken = timed(lambda: model(example_inputs))
    print(f"Time taken for first inference for torch.compile is {time_taken:.2f} ms")
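
######################################################################
# To confirm that this gap is specifically a warmup cost, we can time a second
# call: once the JIT compilation has happened, subsequent ``torch.compile``
# calls should be fast.

with torch.inference_mode():
    _, time_taken = timed(lambda: model(example_inputs))
    print(f"Time taken for second inference for torch.compile is {time_taken:.2f} ms")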

######################################################################
# We see that there is a drastic speedup in first inference time using AOTInductor compared
# to ``torch.compile``.

######################################################################
# Conclusion
# ----------
#
# In this recipe, we learned how to effectively use AOTInductor for Python runtime by
# compiling and loading a pretrained ``ResNet18`` model using the ``torch._inductor.aot_compile``
# and ``torch._export.aot_load`` APIs. This process demonstrates the practical application of
# generating a shared library and running it within a Python environment, even with dynamic-shape
# considerations and device-specific optimizations. We also looked at the advantage of using
# AOTInductor in model deployments with regard to the speedup in first inference time.
