Skip to content

Commit cb2e4ed

Browse files
titaiwangmssvekars
andauthored
[ONNX] Update API to torch.onnx.export(..., dynamo=True) (#3223)
--------- Co-authored-by: Svetlana Karslioglu <[email protected]>
1 parent 469d95b commit cb2e4ed

8 files changed

+417
-265
lines changed
-38.1 KB
Binary file not shown.
-25.4 KB
Binary file not shown.

beginner_source/onnx/README.txt

+5-1
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,9 @@ ONNX
1010
https://pytorch.org/tutorials/beginner/onnx/export_simple_model_to_onnx_tutorial.html
1111

1212
3. onnx_registry_tutorial.py
13-
Extending the ONNX Registry
13+
Extending the ONNX exporter operator support
1414
https://pytorch.org/tutorials/beginner/onnx/onnx_registry_tutorial.html
15+
16+
4. export_control_flow_model_to_onnx_tutorial.py
17+
Export a model with control flow to ONNX
18+
https://pytorch.org/tutorials/beginner/onnx/export_control_flow_model_to_onnx_tutorial.html
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
`Introduction to ONNX <intro_onnx.html>`_ ||
4+
`Exporting a PyTorch model to ONNX <export_simple_model_to_onnx_tutorial.html>`_ ||
5+
`Extending the ONNX exporter operator support <onnx_registry_tutorial.html>`_ ||
6+
**`Export a model with control flow to ONNX**
7+
8+
Export a model with control flow to ONNX
9+
========================================
10+
11+
**Author**: `Xavier Dupré <https://github.com/xadupre>`_
12+
"""
13+
14+
15+
###############################################################################
16+
# Overview
17+
# --------
18+
#
19+
# This tutorial demonstrates how to handle control flow logic while exporting
20+
# a PyTorch model to ONNX. It highlights the challenges of exporting
21+
# conditional statements directly and provides solutions to circumvent them.
22+
#
23+
# Conditional logic cannot be exported into ONNX unless they refactored
24+
# to use :func:`torch.cond`. Let's start with a simple model
25+
# implementing a test.
26+
#
27+
# What you will learn:
28+
#
29+
# - How to refactor the model to use :func:`torch.cond` for exporting.
30+
# - How to export a model with control flow logic to ONNX.
31+
# - How to optimize the exported model using the ONNX optimizer.
32+
#
33+
# Prerequisites
34+
# ~~~~~~~~~~~~~
35+
#
36+
# * ``torch >= 2.6``
37+
38+
39+
import torch
40+
41+
###############################################################################
42+
# Define the Models
43+
# -----------------
44+
#
45+
# Two models are defined:
46+
#
47+
# ``ForwardWithControlFlowTest``: A model with a forward method containing an
48+
# if-else conditional.
49+
#
50+
# ``ModelWithControlFlowTest``: A model that incorporates ``ForwardWithControlFlowTest``
51+
# as part of a simple MLP. The models are tested with
52+
# a random input tensor to confirm they execute as expected.
53+
54+
class ForwardWithControlFlowTest(torch.nn.Module):
55+
def forward(self, x):
56+
if x.sum():
57+
return x * 2
58+
return -x
59+
60+
61+
class ModelWithControlFlowTest(torch.nn.Module):
62+
def __init__(self):
63+
super().__init__()
64+
self.mlp = torch.nn.Sequential(
65+
torch.nn.Linear(3, 2),
66+
torch.nn.Linear(2, 1),
67+
ForwardWithControlFlowTest(),
68+
)
69+
70+
def forward(self, x):
71+
out = self.mlp(x)
72+
return out
73+
74+
75+
model = ModelWithControlFlowTest()
76+
77+
78+
###############################################################################
79+
# Exporting the Model: First Attempt
80+
# ----------------------------------
81+
#
82+
# Exporting this model using torch.export.export fails because the control
83+
# flow logic in the forward pass creates a graph break that the exporter cannot
84+
# handle. This behavior is expected, as conditional logic not written using
85+
# :func:`torch.cond` is unsupported.
86+
#
87+
# A try-except block is used to capture the expected failure during the export
88+
# process. If the export unexpectedly succeeds, an ``AssertionError`` is raised.
89+
90+
x = torch.randn(3)
91+
model(x)
92+
93+
try:
94+
torch.export.export(model, (x,), strict=False)
95+
raise AssertionError("This export should failed unless PyTorch now supports this model.")
96+
except Exception as e:
97+
print(e)
98+
99+
###############################################################################
100+
# Using :func:`torch.onnx.export` with JIT Tracing
101+
# ----------------------------------------
102+
#
103+
# When exporting the model using :func:`torch.onnx.export` with the dynamo=True
104+
# argument, the exporter defaults to using JIT tracing. This fallback allows
105+
# the model to export, but the resulting ONNX graph may not faithfully represent
106+
# the original model logic due to the limitations of tracing.
107+
108+
109+
onnx_program = torch.onnx.export(model, (x,), dynamo=True)
110+
print(onnx_program.model)
111+
112+
113+
###############################################################################
114+
# Suggested Patch: Refactoring with :func:`torch.cond`
115+
# --------------------------------------------
116+
#
117+
# To make the control flow exportable, the tutorial demonstrates replacing the
118+
# forward method in ``ForwardWithControlFlowTest`` with a refactored version that
119+
# uses :func:`torch.cond``.
120+
#
121+
# Details of the Refactoring:
122+
#
123+
# Two helper functions (identity2 and neg) represent the branches of the conditional logic:
124+
# * :func:`torch.cond`` is used to specify the condition and the two branches along with the input arguments.
125+
# * The updated forward method is then dynamically assigned to the ``ForwardWithControlFlowTest`` instance within the model. A list of submodules is printed to confirm the replacement.
126+
127+
def new_forward(x):
128+
def identity2(x):
129+
return x * 2
130+
131+
def neg(x):
132+
return -x
133+
134+
return torch.cond(x.sum() > 0, identity2, neg, (x,))
135+
136+
137+
print("the list of submodules")
138+
for name, mod in model.named_modules():
139+
print(name, type(mod))
140+
if isinstance(mod, ForwardWithControlFlowTest):
141+
mod.forward = new_forward
142+
143+
###############################################################################
144+
# Let's see what the FX graph looks like.
145+
146+
print(torch.export.export(model, (x,), strict=False))
147+
148+
###############################################################################
149+
# Let's export again.
150+
151+
onnx_program = torch.onnx.export(model, (x,), dynamo=True)
152+
print(onnx_program.model)
153+
154+
155+
###############################################################################
156+
# We can optimize the model and get rid of the model local functions created to capture the control flow branches.
157+
158+
onnx_program.optimize()
159+
print(onnx_program.model)
160+
161+
###############################################################################
162+
# Conclusion
163+
# ----------
164+
#
165+
# This tutorial demonstrates the challenges of exporting models with conditional
166+
# logic to ONNX and presents a practical solution using :func:`torch.cond`.
167+
# While the default exporters may fail or produce imperfect graphs, refactoring the
168+
# model's logic ensures compatibility and generates a faithful ONNX representation.
169+
#
170+
# By understanding these techniques, we can overcome common pitfalls when
171+
# working with control flow in PyTorch models and ensure smooth integration with ONNX workflows.
172+
#
173+
# Further reading
174+
# ---------------
175+
#
176+
# The list below refers to tutorials that ranges from basic examples to advanced scenarios,
177+
# not necessarily in the order they are listed.
178+
# Feel free to jump directly to specific topics of your interest or
179+
# sit tight and have fun going through all of them to learn all there is about the ONNX exporter.
180+
#
181+
# .. include:: /beginner_source/onnx/onnx_toc.txt
182+
#
183+
# .. toctree::
184+
# :hidden:
185+
#

beginner_source/onnx/export_simple_model_to_onnx_tutorial.py

+39-29
Original file line numberDiff line numberDiff line change
@@ -2,26 +2,27 @@
22
"""
33
`Introduction to ONNX <intro_onnx.html>`_ ||
44
**Exporting a PyTorch model to ONNX** ||
5-
`Extending the ONNX Registry <onnx_registry_tutorial.html>`_
5+
`Extending the ONNX exporter operator support <onnx_registry_tutorial.html>`_ ||
6+
`Export a model with control flow to ONNX <export_control_flow_model_to_onnx_tutorial.html>`_
67
78
Export a PyTorch model to ONNX
89
==============================
910
10-
**Author**: `Ti-Tai Wang <https://github.com/titaiwangms>`_ and `Xavier Dupré <https://github.com/xadupre>`_
11+
**Author**: `Ti-Tai Wang <https://github.com/titaiwangms>`_, `Justin Chu <[email protected]>`_, `Thiago Crepaldi <https://github.com/thiagocrepaldi>`_.
1112
1213
.. note::
13-
As of PyTorch 2.1, there are two versions of ONNX Exporter.
14+
As of PyTorch 2.5, there are two versions of ONNX Exporter.
1415
15-
* ``torch.onnx.dynamo_export`` is the newest (still in beta) exporter based on the TorchDynamo technology released with PyTorch 2.0
16-
* ``torch.onnx.export`` is based on TorchScript backend and has been available since PyTorch 1.2.0
16+
* ``torch.onnx.export(..., dynamo=True)`` is the newest (still in beta) exporter using ``torch.export`` and Torch FX to capture the graph. It was released with PyTorch 2.5
17+
* ``torch.onnx.export`` uses TorchScript and has been available since PyTorch 1.2.0
1718
1819
"""
1920

2021
###############################################################################
2122
# In the `60 Minute Blitz <https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html>`_,
2223
# we had the opportunity to learn about PyTorch at a high level and train a small neural network to classify images.
2324
# In this tutorial, we are going to expand this to describe how to convert a model defined in PyTorch into the
24-
# ONNX format using TorchDynamo and the ``torch.onnx.dynamo_export`` ONNX exporter.
25+
# ONNX format using the ``torch.onnx.export(..., dynamo=True)`` ONNX exporter.
2526
#
2627
# While PyTorch is great for iterating on the development of models, the model can be deployed to production
2728
# using different formats, including `ONNX <https://onnx.ai/>`_ (Open Neural Network Exchange)!
@@ -47,8 +48,7 @@
4748
#
4849
# .. code-block:: bash
4950
#
50-
# pip install onnx
51-
# pip install onnxscript
51+
# pip install --upgrade onnx onnxscript
5252
#
5353
# 2. Author a simple image classifier model
5454
# -----------------------------------------
@@ -62,17 +62,16 @@
6262
import torch.nn.functional as F
6363

6464

65-
class MyModel(nn.Module):
66-
65+
class ImageClassifierModel(nn.Module):
6766
def __init__(self):
68-
super(MyModel, self).__init__()
67+
super().__init__()
6968
self.conv1 = nn.Conv2d(1, 6, 5)
7069
self.conv2 = nn.Conv2d(6, 16, 5)
7170
self.fc1 = nn.Linear(16 * 5 * 5, 120)
7271
self.fc2 = nn.Linear(120, 84)
7372
self.fc3 = nn.Linear(84, 10)
7473

75-
def forward(self, x):
74+
def forward(self, x: torch.Tensor):
7675
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
7776
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
7877
x = torch.flatten(x, 1)
@@ -81,16 +80,27 @@ def forward(self, x):
8180
x = self.fc3(x)
8281
return x
8382

83+
8484
######################################################################
8585
# 3. Export the model to ONNX format
8686
# ----------------------------------
8787
#
8888
# Now that we have our model defined, we need to instantiate it and create a random 32x32 input.
8989
# Next, we can export the model to ONNX format.
9090

91-
torch_model = MyModel()
92-
torch_input = torch.randn(1, 1, 32, 32)
93-
onnx_program = torch.onnx.dynamo_export(torch_model, torch_input)
91+
torch_model = ImageClassifierModel()
92+
# Create example inputs for exporting the model. The inputs should be a tuple of tensors.
93+
example_inputs = (torch.randn(1, 1, 32, 32),)
94+
onnx_program = torch.onnx.export(torch_model, example_inputs, dynamo=True)
95+
96+
######################################################################
97+
# 3.5. (Optional) Optimize the ONNX model
98+
# ---------------------------------------
99+
#
100+
# The ONNX model can be optimized with constant folding, and elimination of redundant nodes.
101+
# The optimization is done in-place, so the original ONNX model is modified.
102+
103+
onnx_program.optimize()
94104

95105
######################################################################
96106
# As we can see, we didn't need any code change to the model.
@@ -102,13 +112,14 @@ def forward(self, x):
102112
# Although having the exported model loaded in memory is useful in many applications,
103113
# we can save it to disk with the following code:
104114

105-
onnx_program.save("my_image_classifier.onnx")
115+
onnx_program.save("image_classifier_model.onnx")
106116

107117
######################################################################
108118
# You can load the ONNX file back into memory and check if it is well formed with the following code:
109119

110120
import onnx
111-
onnx_model = onnx.load("my_image_classifier.onnx")
121+
122+
onnx_model = onnx.load("image_classifier_model.onnx")
112123
onnx.checker.check_model(onnx_model)
113124

114125
######################################################################
@@ -124,7 +135,7 @@ def forward(self, x):
124135
# :align: center
125136
#
126137
#
127-
# Once Netron is open, we can drag and drop our ``my_image_classifier.onnx`` file into the browser or select it after
138+
# Once Netron is open, we can drag and drop our ``image_classifier_model.onnx`` file into the browser or select it after
128139
# clicking the **Open model** button.
129140
#
130141
# .. image:: ../../_static/img/onnx/image_classifier_onnx_model_on_netron_web_ui.png
@@ -155,18 +166,17 @@ def forward(self, x):
155166

156167
import onnxruntime
157168

158-
onnx_input = [torch_input]
159-
print(f"Input length: {len(onnx_input)}")
160-
print(f"Sample input: {onnx_input}")
169+
onnx_inputs = [tensor.numpy(force=True) for tensor in example_inputs]
170+
print(f"Input length: {len(onnx_inputs)}")
171+
print(f"Sample input: {onnx_inputs}")
161172

162-
ort_session = onnxruntime.InferenceSession("./my_image_classifier.onnx", providers=['CPUExecutionProvider'])
173+
ort_session = onnxruntime.InferenceSession(
174+
"./image_classifier_model.onnx", providers=["CPUExecutionProvider"]
175+
)
163176

164-
def to_numpy(tensor):
165-
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
177+
onnxruntime_input = {input_arg.name: input_value for input_arg, input_value in zip(ort_session.get_inputs(), onnx_inputs)}
166178

167-
onnxruntime_input = {k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)}
168-
169-
# onnxruntime returns a list of outputs
179+
# ONNX Runtime returns a list of outputs
170180
onnxruntime_outputs = ort_session.run(None, onnxruntime_input)[0]
171181

172182
####################################################################
@@ -179,7 +189,7 @@ def to_numpy(tensor):
179189
# For that, we need to execute the PyTorch model with the same input and compare the results with ONNX Runtime's.
180190
# Before comparing the results, we need to convert the PyTorch's output to match ONNX's format.
181191

182-
torch_outputs = torch_model(torch_input)
192+
torch_outputs = torch_model(*example_inputs)
183193

184194
assert len(torch_outputs) == len(onnxruntime_outputs)
185195
for torch_output, onnxruntime_output in zip(torch_outputs, onnxruntime_outputs):
@@ -209,4 +219,4 @@ def to_numpy(tensor):
209219
#
210220
# .. toctree::
211221
# :hidden:
212-
#
222+
#

0 commit comments

Comments
 (0)