2
2
"""
3
3
`Introduction to ONNX <intro_onnx.html>`_ ||
4
4
**Exporting a PyTorch model to ONNX** ||
5
- `Extending the ONNX Registry <onnx_registry_tutorial.html>`_
5
+ `Extending the ONNX exporter operator support <onnx_registry_tutorial.html>`_ ||
6
+ `Export a model with control flow to ONNX <export_control_flow_model_to_onnx_tutorial.html>`_
6
7
7
8
Export a PyTorch model to ONNX
8
9
==============================
9
10
10
- **Author**: `Ti-Tai Wang <https://github.com/titaiwangms>`_ and `Xavier Dupré <https://github.com/xadupre >`_
11
+ **Author**: `Ti-Tai Wang <https://github.com/titaiwangms>`_, `Justin Chu <[email protected] >`_, `Thiago Crepaldi <https://github.com/thiagocrepaldi >`_.
11
12
12
13
.. note::
13
- As of PyTorch 2.1 , there are two versions of ONNX Exporter.
14
+ As of PyTorch 2.5 , there are two versions of ONNX Exporter.
14
15
15
- * ``torch.onnx.dynamo_export `` is the newest (still in beta) exporter based on the TorchDynamo technology released with PyTorch 2.0
16
- * ``torch.onnx.export`` is based on TorchScript backend and has been available since PyTorch 1.2.0
16
+ * ``torch.onnx.export(..., dynamo=True) `` is the newest (still in beta) exporter using ``torch.export`` and Torch FX to capture the graph. It was released with PyTorch 2.5
17
+ * ``torch.onnx.export`` uses TorchScript and has been available since PyTorch 1.2.0
17
18
18
19
"""
19
20
20
21
###############################################################################
21
22
# In the `60 Minute Blitz <https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html>`_,
22
23
# we had the opportunity to learn about PyTorch at a high level and train a small neural network to classify images.
23
24
# In this tutorial, we are going to expand this to describe how to convert a model defined in PyTorch into the
24
- # ONNX format using TorchDynamo and the ``torch.onnx.dynamo_export `` ONNX exporter.
25
+ # ONNX format using the ``torch.onnx.export(..., dynamo=True) `` ONNX exporter.
25
26
#
26
27
# While PyTorch is great for iterating on the development of models, the model can be deployed to production
27
28
# using different formats, including `ONNX <https://onnx.ai/>`_ (Open Neural Network Exchange)!
47
48
#
48
49
# .. code-block:: bash
49
50
#
50
- # pip install onnx
51
- # pip install onnxscript
51
+ # pip install --upgrade onnx onnxscript
52
52
#
53
53
# 2. Author a simple image classifier model
54
54
# -----------------------------------------
62
62
import torch .nn .functional as F
63
63
64
64
65
- class MyModel (nn .Module ):
66
-
65
+ class ImageClassifierModel (nn .Module ):
67
66
def __init__ (self ):
68
- super (MyModel , self ).__init__ ()
67
+ super ().__init__ ()
69
68
self .conv1 = nn .Conv2d (1 , 6 , 5 )
70
69
self .conv2 = nn .Conv2d (6 , 16 , 5 )
71
70
self .fc1 = nn .Linear (16 * 5 * 5 , 120 )
72
71
self .fc2 = nn .Linear (120 , 84 )
73
72
self .fc3 = nn .Linear (84 , 10 )
74
73
75
- def forward (self , x ):
74
+ def forward (self , x : torch . Tensor ):
76
75
x = F .max_pool2d (F .relu (self .conv1 (x )), (2 , 2 ))
77
76
x = F .max_pool2d (F .relu (self .conv2 (x )), 2 )
78
77
x = torch .flatten (x , 1 )
@@ -81,16 +80,27 @@ def forward(self, x):
81
80
x = self .fc3 (x )
82
81
return x
83
82
83
+
84
84
######################################################################
85
85
# 3. Export the model to ONNX format
86
86
# ----------------------------------
87
87
#
88
88
# Now that we have our model defined, we need to instantiate it and create a random 32x32 input.
89
89
# Next, we can export the model to ONNX format.
90
90
91
- torch_model = MyModel ()
92
- torch_input = torch .randn (1 , 1 , 32 , 32 )
93
- onnx_program = torch .onnx .dynamo_export (torch_model , torch_input )
91
+ torch_model = ImageClassifierModel ()
92
+ # Create example inputs for exporting the model. The inputs should be a tuple of tensors.
93
+ example_inputs = (torch .randn (1 , 1 , 32 , 32 ),)
94
+ onnx_program = torch .onnx .export (torch_model , example_inputs , dynamo = True )
95
+
96
+ ######################################################################
97
+ # 3.5. (Optional) Optimize the ONNX model
98
+ # ---------------------------------------
99
+ #
100
+ # The ONNX model can be optimized with constant folding, and elimination of redundant nodes.
101
+ # The optimization is done in-place, so the original ONNX model is modified.
102
+
103
+ onnx_program .optimize ()
94
104
95
105
######################################################################
96
106
# As we can see, we didn't need any code change to the model.
@@ -102,13 +112,14 @@ def forward(self, x):
102
112
# Although having the exported model loaded in memory is useful in many applications,
103
113
# we can save it to disk with the following code:
104
114
105
- onnx_program .save ("my_image_classifier .onnx" )
115
+ onnx_program .save ("image_classifier_model .onnx" )
106
116
107
117
######################################################################
108
118
# You can load the ONNX file back into memory and check if it is well formed with the following code:
109
119
110
120
import onnx
111
- onnx_model = onnx .load ("my_image_classifier.onnx" )
121
+
122
+ onnx_model = onnx .load ("image_classifier_model.onnx" )
112
123
onnx .checker .check_model (onnx_model )
113
124
114
125
######################################################################
@@ -124,7 +135,7 @@ def forward(self, x):
124
135
# :align: center
125
136
#
126
137
#
127
- # Once Netron is open, we can drag and drop our ``my_image_classifier .onnx`` file into the browser or select it after
138
+ # Once Netron is open, we can drag and drop our ``image_classifier_model .onnx`` file into the browser or select it after
128
139
# clicking the **Open model** button.
129
140
#
130
141
# .. image:: ../../_static/img/onnx/image_classifier_onnx_model_on_netron_web_ui.png
@@ -155,18 +166,17 @@ def forward(self, x):
155
166
156
167
import onnxruntime
157
168
158
- onnx_input = [torch_input ]
159
- print (f"Input length: { len (onnx_input )} " )
160
- print (f"Sample input: { onnx_input } " )
169
+ onnx_inputs = [tensor . numpy ( force = True ) for tensor in example_inputs ]
170
+ print (f"Input length: { len (onnx_inputs )} " )
171
+ print (f"Sample input: { onnx_inputs } " )
161
172
162
- ort_session = onnxruntime .InferenceSession ("./my_image_classifier.onnx" , providers = ['CPUExecutionProvider' ])
173
+ ort_session = onnxruntime .InferenceSession (
174
+ "./image_classifier_model.onnx" , providers = ["CPUExecutionProvider" ]
175
+ )
163
176
164
- def to_numpy (tensor ):
165
- return tensor .detach ().cpu ().numpy () if tensor .requires_grad else tensor .cpu ().numpy ()
177
+ onnxruntime_input = {input_arg .name : input_value for input_arg , input_value in zip (ort_session .get_inputs (), onnx_inputs )}
166
178
167
- onnxruntime_input = {k .name : to_numpy (v ) for k , v in zip (ort_session .get_inputs (), onnx_input )}
168
-
169
- # onnxruntime returns a list of outputs
179
+ # ONNX Runtime returns a list of outputs
170
180
onnxruntime_outputs = ort_session .run (None , onnxruntime_input )[0 ]
171
181
172
182
####################################################################
@@ -179,7 +189,7 @@ def to_numpy(tensor):
179
189
# For that, we need to execute the PyTorch model with the same input and compare the results with ONNX Runtime's.
180
190
# Before comparing the results, we need to convert the PyTorch's output to match ONNX's format.
181
191
182
- torch_outputs = torch_model (torch_input )
192
+ torch_outputs = torch_model (* example_inputs )
183
193
184
194
assert len (torch_outputs ) == len (onnxruntime_outputs )
185
195
for torch_output , onnxruntime_output in zip (torch_outputs , onnxruntime_outputs ):
@@ -209,4 +219,4 @@ def to_numpy(tensor):
209
219
#
210
220
# .. toctree::
211
221
# :hidden:
212
- #
222
+ #
0 commit comments