Commit b7afac3 · Add onnx export (#2596)
Authored by lezwon and Borda
1 parent 06e8910

Squashed commit message:

* export model to onnx
* prepare data before exporting
* support for dataloaders and tensors
* added tests
* use example_input_array; add to changelog
* updated docstring
* added onnx inference tests
* temp commit
* removed schema valid test
* add onnxruntime to environment.yml
* moved onnxruntime to environment.yml pip
* add example in doc
* add lines between code block
* added PR to changelog
* is file check
* remove
* infer example outputs
* added doctest for onnx
* fix windows tests
* moved eval within condition block
* self.forward to self
* added docs
* fixed docs error
* added to toctree
* Update CHANGELOG.md

Co-authored-by: Jirka Borovec <[email protected]>

File tree: 8 files changed (+188 −3 lines)

CHANGELOG.md (+2)

@@ -12,6 +12,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added SSIM metrics ([#2671](https://github.com/PyTorchLightning/pytorch-lightning/pull/2671))
 - Added BLEU metrics ([#2535](https://github.com/PyTorchLightning/pytorch-lightning/pull/2535))
 
+- Added support to export a model to ONNX format ([#2596](https://github.com/PyTorchLightning/pytorch-lightning/pull/2596))
+
 - Added support for `Trainer(num_sanity_val_steps=-1)` to check all validation data before training ([#2246](https://github.com/PyTorchLightning/pytorch-lightning/pull/2246))
 
 - Added support for PyTorch 1.6 ([#2745](https://github.com/PyTorchLightning/pytorch-lightning/pull/2745))

docs/source/index.rst (+1)

@@ -99,6 +99,7 @@ PyTorch Lightning Documentation
    transfer_learning
    tpu
    test_set
+   production_inference
 
 .. toctree::
    :maxdepth: 1

docs/source/production_inference.rst (+28, new file)

@@ -0,0 +1,28 @@
+Inference in Production
+=======================
+PyTorch Lightning eases the process of deploying models into production.
+
+
+Exporting to ONNX
+-----------------
+PyTorch Lightning provides a handy function to quickly export your model to the ONNX format, which makes the model independent of PyTorch so it can run on ONNX Runtime.
+
+To export your model to ONNX format, call the ``to_onnx`` function on your LightningModule with the ``filepath`` and ``input_sample``:
+
+.. code-block:: python
+
+    filepath = 'model.onnx'
+    model = SimpleModel()
+    input_sample = torch.randn((1, 64))
+    model.to_onnx(filepath, input_sample, export_params=True)
+
+You can also skip passing the input sample if the ``example_input_array`` attribute is set on your LightningModule.
+
+Once you have the exported model, you can run it in ONNX Runtime as follows:
+
+.. code-block:: python
+
+    ort_session = onnxruntime.InferenceSession(filepath)
+    input_name = ort_session.get_inputs()[0].name
+    ort_inputs = {input_name: np.random.randn(1, 64).astype(np.float32)}
+    ort_outs = ort_session.run(None, ort_inputs)
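
For the ``example_input_array`` route mentioned in these docs, here is a minimal sketch; defining the attribute in ``__init__`` and the ``SimpleModel`` shapes are illustrative assumptions carried over from the snippet above, not part of this commit:

.. code-block:: python

    import torch
    from pytorch_lightning import LightningModule

    class SimpleModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.l1 = torch.nn.Linear(in_features=64, out_features=4)
            # with this attribute set, to_onnx() can trace the model
            # without an explicit input_sample argument
            self.example_input_array = torch.randn((1, 64))

        def forward(self, x):
            return torch.relu(self.l1(x.view(x.size(0), -1)))

    model = SimpleModel()
    model.to_onnx('model.onnx', export_params=True)  # no input_sample needed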

environment.yml (+1)

@@ -48,3 +48,4 @@ dependencies:
   - wandb>=0.8.21
   - neptune-client>=0.4.109
   - horovod>=0.19.1
+  - onnxruntime>=1.3.0

pytorch_lightning/core/lightning.py (+39)

@@ -2,6 +2,7 @@
 import inspect
 import os
 import re
+import tempfile
 from abc import ABC, abstractmethod
 from argparse import Namespace
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

@@ -1723,6 +1724,44 @@ def _set_hparams(self, hp: Union[dict, Namespace, str]) -> None:
         else:
             self._hparams = hp
 
+    def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwargs):
+        """Saves the model in ONNX format.
+
+        Args:
+            file_path: The path of the file the model should be saved to.
+            input_sample: A sample of an input tensor for tracing.
+            **kwargs: Will be passed to the ``torch.onnx.export`` function.
+
+        Example:
+            >>> class SimpleModel(LightningModule):
+            ...     def __init__(self):
+            ...         super().__init__()
+            ...         self.l1 = torch.nn.Linear(in_features=64, out_features=4)
+            ...
+            ...     def forward(self, x):
+            ...         return torch.relu(self.l1(x.view(x.size(0), -1)))
+
+            >>> with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as tmpfile:
+            ...     model = SimpleModel()
+            ...     input_sample = torch.randn((1, 64))
+            ...     model.to_onnx(tmpfile.name, input_sample, export_params=True)
+            ...     os.path.isfile(tmpfile.name)
+            True
+        """
+        if isinstance(input_sample, Tensor):
+            input_data = input_sample
+        elif self.example_input_array is not None:
+            input_data = self.example_input_array
+        else:
+            raise ValueError('input_sample and example_input_array tensors are both missing.')
+
+        if 'example_outputs' not in kwargs:
+            self.eval()
+            kwargs['example_outputs'] = self(input_data)
+
+        torch.onnx.export(self, input_data, file_path, **kwargs)
+
     @property
     def hparams(self) -> Union[AttributeDict, str]:
         if not hasattr(self, '_hparams'):
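
Because ``to_onnx`` only runs its own forward pass when ``example_outputs`` is absent from ``kwargs``, the outputs can be supplied up front. A minimal sketch, assuming the ``SimpleModel`` class from the docstring above is defined:

.. code-block:: python

    import torch

    model = SimpleModel().eval()
    input_sample = torch.randn((1, 64))
    with torch.no_grad():
        example_outputs = model(input_sample)
    # 'example_outputs' is in kwargs, so to_onnx skips its own forward pass
    model.to_onnx('model.onnx', input_sample, example_outputs=example_outputs)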

requirements/extra.txt (+2)

@@ -12,3 +12,5 @@ omegaconf>=2.0.0
 # scipy>=0.13.3
 scikit-learn>=0.20.0
 torchtext>=0.3.1, <0.7  # TODO: temporary fix for compatibility
+onnx>=1.7.0
+onnxruntime>=1.3.0

tests/base/model_template.py (+1 −3)

@@ -73,9 +73,7 @@ def __init__(
         self.test_step_end_called = False
         self.test_epoch_end_called = False
 
-        # if you specify an example input, the summary will show input/output for each layer
-        # TODO: to be fixed in #1773
-        # self.example_input_array = torch.rand(5, 28 * 28)
+        self.example_input_array = torch.rand(5, 28 * 28)
 
         # build model
         self.__build_model()
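
The comment deleted above noted a side effect worth keeping in mind: with ``example_input_array`` set, the model summary can report input/output sizes for each layer. A rough sketch of that, assuming the ``summarize`` method LightningModule exposed at the time (check your version):

.. code-block:: python

    from tests.base import EvalModelTemplate

    model = EvalModelTemplate()
    # with example_input_array set, the summary includes in/out shapes per layer
    model.summarize(mode='full')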

tests/models/test_onnx_save.py (+114, new file)

@@ -0,0 +1,114 @@
+import os
+
+import numpy as np
+import onnxruntime
+import pytest
+import torch
+
+import tests.base.develop_pipelines as tpipes
+import tests.base.develop_utils as tutils
+from pytorch_lightning import Trainer
+from tests.base import EvalModelTemplate
+
+
+def test_model_saves_with_input_sample(tmpdir):
+    """Test that ONNX model saves with input sample and size is greater than 3 MB"""
+    model = EvalModelTemplate()
+    trainer = Trainer(max_epochs=1)
+    trainer.fit(model)
+
+    file_path = os.path.join(tmpdir, "model.onnx")
+    input_sample = torch.randn((1, 28 * 28))
+    model.to_onnx(file_path, input_sample)
+    assert os.path.isfile(file_path)
+    assert os.path.getsize(file_path) > 3e+06
+
+
+def test_model_saves_with_example_output(tmpdir):
+    """Test that ONNX model saves when provided with example output"""
+    model = EvalModelTemplate()
+    trainer = Trainer(max_epochs=1)
+    trainer.fit(model)
+
+    file_path = os.path.join(tmpdir, "model.onnx")
+    input_sample = torch.randn((1, 28 * 28))
+    model.eval()
+    example_outputs = model.forward(input_sample)
+    model.to_onnx(file_path, input_sample, example_outputs=example_outputs)
+    assert os.path.exists(file_path)
+
+
+def test_model_saves_with_example_input_array(tmpdir):
+    """Test that ONNX model saves with example_input_array and size is greater than 3 MB"""
+    model = EvalModelTemplate()
+    file_path = os.path.join(tmpdir, "model.onnx")
+    model.to_onnx(file_path)
+    assert os.path.exists(file_path)
+    assert os.path.getsize(file_path) > 3e+06
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+def test_model_saves_on_multi_gpu(tmpdir):
+    """Test that ONNX model saves on a distributed backend"""
+    tutils.set_random_master_port()
+
+    trainer_options = dict(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        limit_train_batches=10,
+        limit_val_batches=10,
+        gpus=[0, 1],
+        distributed_backend='ddp_spawn',
+        progress_bar_refresh_rate=0,
+    )
+
+    model = EvalModelTemplate()
+
+    tpipes.run_model_test(trainer_options, model)
+
+    file_path = os.path.join(tmpdir, "model.onnx")
+    model.to_onnx(file_path)
+    assert os.path.exists(file_path)
+
+
+def test_verbose_param(tmpdir, capsys):
+    """Test that output is present when the verbose parameter is set"""
+    model = EvalModelTemplate()
+    file_path = os.path.join(tmpdir, "model.onnx")
+    model.to_onnx(file_path, verbose=True)
+    captured = capsys.readouterr()
+    assert "graph(%" in captured.out
+
+
+def test_error_if_no_input(tmpdir):
+    """Test that an exception is raised when there is no input tensor"""
+    model = EvalModelTemplate()
+    model.example_input_array = None
+    file_path = os.path.join(tmpdir, "model.onnx")
+    with pytest.raises(ValueError, match=r'input_sample and example_input_array tensors are both missing'):
+        model.to_onnx(file_path)
+
+
+def test_if_inference_output_is_valid(tmpdir):
+    """Test that the output inferred from the ONNX model matches the PyTorch output"""
+    model = EvalModelTemplate()
+    trainer = Trainer(max_epochs=5)
+    trainer.fit(model)
+
+    model.eval()
+    with torch.no_grad():
+        torch_out = model(model.example_input_array)
+
+    file_path = os.path.join(tmpdir, "model.onnx")
+    model.to_onnx(file_path, model.example_input_array, export_params=True)
+
+    ort_session = onnxruntime.InferenceSession(file_path)
+
+    def to_numpy(tensor):
+        return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
+
+    # compute ONNX Runtime output prediction
+    ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(model.example_input_array)}
+    ort_outs = ort_session.run(None, ort_inputs)
+
+    # compare ONNX Runtime and PyTorch results
+    assert np.allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)
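
The commit message above mentions a removed "schema valid" test, while requirements/extra.txt still pins onnx>=1.7.0, so a structural check remains easy to add. A minimal sketch of validating an exported file with the ONNX checker (an illustration, not part of this commit):

.. code-block:: python

    import onnx

    onnx_model = onnx.load("model.onnx")  # path produced by to_onnx above
    onnx.checker.check_model(onnx_model)  # raises if the graph violates the ONNX schema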
