Skip to content

Commit 6c6e233

Browse files
author
lezwon
committed
use example_input_array
add to changelog
1 parent b3d3e7a commit 6c6e233

File tree

5 files changed

+24
-35
lines changed

5 files changed

+24
-35
lines changed

CHANGELOG.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
88
## [unreleased] - YYYY-MM-DD
99

1010
### Added
11-
11+
- Added exporting model to ONNX format.
1212

1313
### Changed
1414

pytorch_lightning/core/lightning.py

+8-11
Original file line numberDiff line numberDiff line change
@@ -1718,26 +1718,23 @@ def _set_hparams(self, hp: Union[dict, Namespace, str]) -> None:
17181718
else:
17191719
self._hparams = hp
17201720

1721-
def to_onnx(self, file_path: str, input: Optional[Union[DataLoader, Tensor]] = None, verbose: Optional[bool] = False):
1721+
def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwargs):
17221722
"""Saves the model in ONNX format
17231723
17241724
Args:
17251725
file_path: The path of the file the model should be saved to.
1726-
input: Either a PyTorch DataLoader with training samples or an input tensor for tracing.
1726+
input_sample: A sample of an input tensor for tracing.
17271727
verbose: Boolean value to indicate if the ONNX output should be printed
17281728
"""
17291729

1730-
if isinstance(input, DataLoader):
1731-
batch = next(iter(input))
1732-
input_data = batch[0]
1733-
elif isinstance(input, Tensor):
1734-
input_data = input
1730+
if isinstance(input_sample, Tensor):
1731+
input_data = input_sample
1732+
elif self.example_input_array is not None:
1733+
input_data = self.example_input_array
17351734
else:
1736-
self.prepare_data()
1737-
batch = next(iter(self.train_dataloader()))
1738-
input_data = batch[0]
1735+
raise ValueError(f'input_sample and example_input_array tensors are both missing.')
17391736

1740-
torch.onnx.export(self, input_data, file_path, verbose=verbose)
1737+
torch.onnx.export(self, input_data, file_path, **kwargs)
17411738

17421739
@property
17431740
def hparams(self) -> Union[AttributeDict, str]:

tests/base/model_template.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def __init__(
6666

6767
# if you specify an example input, the summary will show input/output for each layer
6868
# TODO: to be fixed in #1773
69-
# self.example_input_array = torch.rand(5, 28 * 28)
69+
self.example_input_array = torch.rand(5, 28 * 28)
7070

7171
# build model
7272
self.__build_model()
@@ -89,7 +89,6 @@ def __build_model(self):
8989
)
9090

9191
def forward(self, x):
92-
x = x.view(x.size(0), -1)
9392
x = self.c_d1(x)
9493
x = torch.tanh(x)
9594
x = self.c_d1_bn(x)

tests/base/model_train_steps.py

+1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ def training_step(self, batch, batch_idx, optimizer_idx=None):
1515
"""Lightning calls this inside the training loop"""
1616
# forward pass
1717
x, y = batch
18+
x = x.view(x.size(0), -1)
1819

1920
y_hat = self(x)
2021

tests/models/test_onxx_save.py

+13-21
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,22 @@
88
from tests.base import EvalModelTemplate
99

1010

11-
def test_model_saves_on_cpu(tmpdir):
12-
"""Test that ONNX model saves on CPU and size is greater than 3 MB"""
11+
def test_model_saves_with_input_sample(tmpdir):
12+
"""Test that ONNX model saves with input sample and size is greater than 3 MB"""
1313
model = EvalModelTemplate()
1414
trainer = Trainer(max_epochs=1)
1515
trainer.fit(model)
1616

17+
file_path = os.path.join(tmpdir, "model.onxx")
18+
input_sample = torch.randn((1, 28 * 28))
19+
model.to_onnx(file_path, input_sample)
20+
assert os.path.exists(file_path) is True
21+
assert os.path.getsize(file_path) > 3e+06
22+
23+
24+
def test_model_saves_with_example_input_array(tmpdir):
25+
"""Test that ONNX model saves with_example_input_array and size is greater than 3 MB"""
26+
model = EvalModelTemplate()
1727
file_path = os.path.join(tmpdir, "model.onxx")
1828
model.to_onnx(file_path)
1929
assert os.path.exists(file_path) is True
@@ -50,22 +60,4 @@ def test_verbose_param(tmpdir, capsys):
5060
file_path = os.path.join(tmpdir, "model.onxx")
5161
model.to_onnx(file_path, verbose=True)
5262
captured = capsys.readouterr()
53-
assert "graph(%0" in captured.out
54-
55-
56-
def test_input_param_with_dataloader(tmpdir):
57-
"""Test that ONXX model is saved when a dataloader is passed in as input"""
58-
model = EvalModelTemplate()
59-
dataloader = model.dataloader(train=True)
60-
file_path = os.path.join(tmpdir, "model.onxx")
61-
model.to_onnx(file_path, input=dataloader)
62-
assert os.path.exists(file_path) is True
63-
64-
65-
def test_input_param_with_tensor(tmpdir):
66-
"""Test that ONXX model is saved when a a tensor is passed in as input"""
67-
model = EvalModelTemplate()
68-
tensor = torch.randn((1, 28, 28))
69-
file_path = os.path.join(tmpdir, "model.onxx")
70-
model.to_onnx(file_path, input=tensor)
71-
assert os.path.exists(file_path) is True
63+
assert "graph(%" in captured.out

0 commit comments

Comments
 (0)