# -*- coding: utf-8 -*-

"""
(Beta) ``torch.export`` AOTInductor Tutorial for Python runtime
===============================================================
**Authors:** Ankith Gunapal, Bin Bao, Angela Yi
"""

######################################################################
#
# .. warning::
#
#     ``torch._inductor.aot_compile`` and ``torch._export.aot_load`` are in Beta status and are subject to
#     backwards-compatibility breaking changes. This tutorial provides an example of how to use these APIs
#     for model deployment using the Python runtime.
#
# It has been shown `previously <https://pytorch.org/docs/stable/torch.compiler_aot_inductor.html#>`__ how AOTInductor can be used
# to do Ahead-of-Time compilation of PyTorch exported models by creating
# a shared library that can be run in a non-Python environment.
#
# In this tutorial, you will work through an end-to-end example of using AOTInductor with the Python runtime.
# We will look at how to use :func:`torch._inductor.aot_compile` along with :func:`torch.export.export` to generate a
# shared library. We will also examine how to execute that shared library in the Python runtime using :func:`torch._export.aot_load`.
# Finally, you will see the speedup in first inference time that AOTInductor provides, especially when using
# ``max-autotune`` mode, which can take some time to compile.
#
# **Contents**
#
# .. contents::
#     :local:
######################################################################
# Prerequisites
# -------------
# * PyTorch 2.4 or later
# * Basic understanding of ``torch.export`` and AOTInductor
# * Complete the `AOTInductor: Ahead-Of-Time Compilation for Torch.Export-ed Models <https://pytorch.org/docs/stable/torch.compiler_aot_inductor.html#>`_ tutorial

######################################################################
# What you will learn
# -------------------
# * How to use AOTInductor for the Python runtime
# * How to use :func:`torch._inductor.aot_compile` along with :func:`torch.export.export` to generate a shared library
# * How to run a shared library in the Python runtime using :func:`torch._export.aot_load`
# * When to use AOTInductor with the Python runtime

######################################################################
# Model Compilation
# -----------------
#
# We will use the TorchVision pretrained ``ResNet18`` model and run TorchInductor ahead-of-time
# compilation on the exported PyTorch program using :func:`torch._inductor.aot_compile`.
#
# .. note::
#
#       This API also supports :func:`torch.compile` options such as ``mode``. For example, on a
#       CUDA-enabled device you can set ``"max_autotune": True``, which leverages Triton-based matrix
#       multiplications and convolutions, and enables CUDA graphs by default.
#
# We also specify ``dynamic_shapes`` for the batch dimension. In this example, ``min=2`` is not a bug and is
# explained in `The 0/1 Specialization Problem <https://docs.google.com/document/d/16VPOa3d-Liikf48teAOmxLc92rgvJdfosIy-yoT38Io/edit?fbclid=IwAR3HNwmmexcitV0pbZm_x1a4ykdXZ9th_eJWK-3hBtVgKnrkmemz6Pm5jRQ#heading=h.ez923tomjvyk>`__.


import os
import torch
from torchvision.models import ResNet18_Weights, resnet18

model = resnet18(weights=ResNet18_Weights.DEFAULT)
model.eval()

with torch.inference_mode():

    # Specify the path of the generated shared library
    aot_compile_options = {
        "aot_inductor.output_path": os.path.join(os.getcwd(), "resnet18_pt2.so"),
    }
    if torch.cuda.is_available():
        device = "cuda"
        aot_compile_options.update({"max_autotune": True})
    else:
        device = "cpu"

    model = model.to(device=device)
    example_inputs = (torch.randn(2, 3, 224, 224, device=device),)

    # min=2 is not a bug and is explained in the 0/1 Specialization Problem
    batch_dim = torch.export.Dim("batch", min=2, max=32)
    exported_program = torch.export.export(
        model,
        example_inputs,
        # Specify the first dimension of the input x as dynamic
        dynamic_shapes={"x": {0: batch_dim}},
    )
    so_path = torch._inductor.aot_compile(
        exported_program.module(),
        example_inputs,
        # Pass the compilation options defined above, including the output path
        options=aot_compile_options,
    )

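######################################################################
# ``torch._inductor.aot_compile`` returns the path of the generated shared library, which here is
# the same path we set through ``aot_inductor.output_path``. As a quick, optional sanity check
# (an illustrative addition, not a required step), we confirm that the artifact was written to disk.

print(f"Shared library written to: {so_path}")
assert os.path.exists(so_path), "AOTInductor did not produce the expected shared library"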
######################################################################
# Model Inference in Python
# -------------------------
#
# Typically, the shared object generated above is used in a non-Python environment. In PyTorch 2.3,
# we added a new API called :func:`torch._export.aot_load` to load the shared library in the Python runtime.
# The API follows a structure similar to the :func:`torch.jit.load` API. You need to specify the path
# of the shared library and the device where it should be loaded.
#
# .. note::
#       In the example below, we specify ``batch_size=1`` for inference, and it still functions correctly
#       even though we specified ``min=2`` in :func:`torch.export.export`.

import os
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model_so_path = os.path.join(os.getcwd(), "resnet18_pt2.so")

model = torch._export.aot_load(model_so_path, device)
example_inputs = (torch.randn(1, 3, 224, 224, device=device),)

with torch.inference_mode():
    output = model(example_inputs)

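######################################################################
# Because the model was exported with a dynamic batch dimension (``min=2``, ``max=32``), the same
# shared library can serve other batch sizes as well. The sketch below is an illustrative addition
# that reuses the ``model`` loaded above and runs inference with a batch size of 8; for ResNet18,
# the output should hold 1000 class scores per image, that is, a shape of ``[8, 1000]``.

with torch.inference_mode():
    batched_inputs = (torch.randn(8, 3, 224, 224, device=device),)
    batched_output = model(batched_inputs)
    print(batched_output.shape)  # expected: torch.Size([8, 1000])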
######################################################################
# When to use AOTInductor for Python Runtime
# ------------------------------------------
#
# One of the requirements for using AOTInductor is that the model shouldn't have any graph breaks.
# Once this requirement is met, the primary use case for the AOTInductor Python runtime is
# model deployment using Python.
# There are mainly two reasons why you would use the AOTInductor Python runtime:
#
# - ``torch._inductor.aot_compile`` generates a shared library. This is useful for model
#   versioning in deployments and for tracking model performance over time (see the illustrative
#   sketch after this list).
# - With :func:`torch.compile` being a JIT compiler, there is a warmup
#   cost associated with the first compilation. Your deployment needs to account for the
#   compilation time taken for the first inference. With AOTInductor, the compilation is
#   done offline using ``torch.export.export`` and ``torch._inductor.aot_compile``. At deployment
#   time, you only load the shared library using ``torch._export.aot_load`` and run inference.
#
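# As a concrete (and purely illustrative) example of the versioning point above, one possible
# convention, which is our own assumption and not an AOTInductor requirement, is to encode a
# version tag in the shared library path passed through ``aot_inductor.output_path``:

model_version = "v1"  # hypothetical version tag maintained by your build pipeline
versioned_so_path = os.path.join(os.getcwd(), f"resnet18_{model_version}.so")
versioned_options = {"aot_inductor.output_path": versioned_so_path}
# Build time:  torch._inductor.aot_compile(exported_program.module(), example_inputs, options=versioned_options)
# Deploy time: torch._export.aot_load(versioned_so_path, device)

######################################################################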
#
# The section below shows the speedup achieved with AOTInductor for first inference.
#
# We define a utility function ``timed`` to measure the time taken for inference.
#

import time

def timed(fn):
    # Returns the result of running `fn()` and the time it took for `fn()` to run,
    # in milliseconds. We use CUDA events and synchronization for accurate
    # measurement on CUDA-enabled devices.
    if torch.cuda.is_available():
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
    else:
        start = time.time()

    result = fn()
    if torch.cuda.is_available():
        end.record()
        torch.cuda.synchronize()
    else:
        end = time.time()

    # Measure the time taken to execute the function in milliseconds
    if torch.cuda.is_available():
        duration = start.elapsed_time(end)
    else:
        duration = (end - start) * 1000

    return result, duration

######################################################################
# Let's measure the time for first inference using AOTInductor.

torch._dynamo.reset()

model = torch._export.aot_load(model_so_path, device)
example_inputs = (torch.randn(1, 3, 224, 224, device=device),)

with torch.inference_mode():
    _, time_taken = timed(lambda: model(example_inputs))
    print(f"Time taken for first inference for AOTInductor is {time_taken:.2f} ms")

######################################################################
# Let's measure the time for first inference using ``torch.compile``.

torch._dynamo.reset()

model = resnet18(weights=ResNet18_Weights.DEFAULT).to(device)
model.eval()

model = torch.compile(model)
example_inputs = torch.randn(1, 3, 224, 224, device=device)

with torch.inference_mode():
    _, time_taken = timed(lambda: model(example_inputs))
    print(f"Time taken for first inference for torch.compile is {time_taken:.2f} ms")

######################################################################
# We see a significant speedup in first inference time with AOTInductor compared
# to ``torch.compile``.

######################################################################
# Conclusion
# ----------
#
# In this recipe, we have learned how to effectively use AOTInductor for the Python runtime by
# compiling and loading a pretrained ``ResNet18`` model with the ``torch._inductor.aot_compile``
# and ``torch._export.aot_load`` APIs. This process demonstrates the practical application of
# generating a shared library and running it within a Python environment, even with dynamic shape
# considerations and device-specific optimizations. We also looked at the advantage of using
# AOTInductor in model deployments, namely the speedup in first inference time.