
Commit 4ad5a78

awaelchli, Borda, justusschock, and ananthsub authored
to_torchscript method for LightningModule (#3258)
* script
* docs
* simple test
* move test
* fix doctest
* no grad context
* extend tests test test
* datamodule test
* clean up test
* docs
* name
* fix import
* update changelog
* fix import
* skip pytorch 1.3 in test
* update codeblock
* skip bugged 1.4
* typehints
* doctest not working on all pytorch versions
* rename TestGAN to prevent pytest interference
* add note about pytorch version
* fix torchscript version inconsistency in tests
* reset training state + tests
* update docstring
* Apply suggestions from code review

  Co-authored-by: Justus Schock <[email protected]>

* update docstring, dict return
* add docs to index
* add link
* doc eval mode
* forward
* optional save to file path
* optional
* test torchscript device
* test save load with file path
* pep
* str
* Commit typing suggestion

  Co-authored-by: ananthsub <[email protected]>

* skip test if cuda not available

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Justus Schock <[email protected]>
Co-authored-by: ananthsub <[email protected]>
1 parent 4a22fca commit 4ad5a78

File tree

7 files changed: +171 -6 lines changed

CHANGELOG.md (+2)

@@ -11,6 +11,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Added hooks to metric module interface ([#2528](https://github.com/PyTorchLightning/pytorch-lightning/pull/2528))

+- Added `LightningModule.to_torchscript` to support exporting as `ScriptModule` ([#3258](https://github.com/PyTorchLightning/pytorch-lightning/pull/3258/))
+
 ### Changed

 - Changed `LearningRateLogger` to `LearningRateMonitor` ([#3251](https://github.com/PyTorchLightning/pytorch-lightning/pull/3251))

docs/source/lightning-module.rst (+6)

@@ -770,6 +770,12 @@ to_onnx

 .. autofunction:: pytorch_lightning.core.lightning.LightningModule.to_onnx
    :noindex:

+to_torchscript
+~~~~~~~~~~~~~~
+
+.. autofunction:: pytorch_lightning.core.lightning.LightningModule.to_torchscript
+   :noindex:
+
 unfreeze
 ~~~~~~~~

docs/source/production_inference.rst (+18)

@@ -28,3 +28,21 @@ Once you have the exported model, you can run it on your ONNX runtime in the fol

     input_name = ort_session.get_inputs()[0].name
     ort_inputs = {input_name: np.random.randn(1, 64).astype(np.float32)}
     ort_outs = ort_session.run(None, ort_inputs)
+
+
+Exporting to TorchScript
+------------------------
+
+TorchScript allows you to serialize your models in a way that they can be loaded in non-Python environments.
+The LightningModule has a handy method :meth:`~pytorch_lightning.core.lightning.LightningModule.to_torchscript`
+that returns a scripted module which you can save or use directly.
+
+.. code-block:: python
+
+    model = SimpleModel()
+    script = model.to_torchscript()
+
+    # save for use in a production environment
+    torch.jit.save(script, "model.pt")
+
+It is recommended that you install the latest supported version of PyTorch to use this feature without limitations.
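The new docs section above saves a ScriptModule but does not show the load side. For reference, the exported file can be loaded back without the original Python class — a minimal sketch in plain PyTorch, assuming the "model.pt" produced above and the 64-feature input that SimpleModel expects:

    import torch

    # load the serialized ScriptModule; the original model class is not required
    script = torch.jit.load("model.pt")
    script.eval()

    # run inference; shape (1, 64) matching SimpleModel is an assumption here
    with torch.no_grad():
        output = script(torch.randn(1, 64))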

pytorch_lightning/core/lightning.py (+50 -1)

@@ -23,7 +23,7 @@

 import torch
 import torch.distributed as torch_distrib
-from torch import Tensor
+from torch import Tensor, ScriptModule
 from torch.nn import Module
 from torch.nn.parallel import DistributedDataParallel
 from torch.optim.optimizer import Optimizer

@@ -184,6 +184,7 @@ def forward(self, batch):
             return logits

         """
+        return super().forward(*args, **kwargs)

     def training_step(self, *args, **kwargs):
         r"""

@@ -1729,6 +1730,54 @@ def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwarg

         torch.onnx.export(self, input_data, file_path, **kwargs)

+    def to_torchscript(self, file_path: Optional[str] = None, **kwargs) -> Union[ScriptModule, Dict[str, ScriptModule]]:
+        """
+        By default compiles the whole model to a :class:`~torch.jit.ScriptModule`.
+        If you would like to customize the modules that are scripted or you want to use tracing,
+        you should override this method. In case you want to return multiple modules, we
+        recommend using a dictionary.
+
+        Args:
+            file_path: Path where to save the torchscript. Default: None (no file saved).
+            **kwargs: Additional arguments that will be passed to the :func:`torch.jit.script` function.
+
+        Note:
+            - Requires the implementation of the
+              :meth:`~pytorch_lightning.core.lightning.LightningModule.forward` method.
+            - The exported script will be set to evaluation mode.
+            - It is recommended that you install the latest supported version of PyTorch
+              to use this feature without limitations. See also the :mod:`torch.jit`
+              documentation for supported features.
+
+        Example:
+            >>> class SimpleModel(LightningModule):
+            ...     def __init__(self):
+            ...         super().__init__()
+            ...         self.l1 = torch.nn.Linear(in_features=64, out_features=4)
+            ...
+            ...     def forward(self, x):
+            ...         return torch.relu(self.l1(x.view(x.size(0), -1)))
+            ...
+            >>> model = SimpleModel()
+            >>> torch.jit.save(model.to_torchscript(), "model.pt")  # doctest: +SKIP
+            >>> os.path.isfile("model.pt")  # doctest: +SKIP
+            True
+
+        Return:
+            This LightningModule as a torchscript, regardless of whether file_path is
+            defined or not.
+        """
+
+        mode = self.training
+        with torch.no_grad():
+            scripted_module = torch.jit.script(self.eval(), **kwargs)
+        self.train(mode)
+
+        if file_path is not None:
+            torch.jit.save(scripted_module, file_path)
+
+        return scripted_module
+
     @property
     def hparams(self) -> Union[AttributeDict, str]:
         if not hasattr(self, '_hparams'):
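The docstring above recommends overriding to_torchscript to customize scripting or to use tracing instead. A minimal sketch of such an override — TracedModel and its shapes are hypothetical, and the eval/no_grad/train-restore pattern mirrors the default implementation shown in the diff:

    import torch
    from pytorch_lightning import LightningModule


    class TracedModel(LightningModule):  # hypothetical module, for illustration only
        def __init__(self):
            super().__init__()
            self.l1 = torch.nn.Linear(64, 4)
            self.example_input_array = torch.rand(2, 64)

        def forward(self, x):
            return torch.relu(self.l1(x))

        def to_torchscript(self, file_path=None, **kwargs):
            # use tracing (torch.jit.trace) instead of the default torch.jit.script
            mode = self.training
            with torch.no_grad():
                traced = torch.jit.trace(self.eval(), self.example_input_array, **kwargs)
            self.train(mode)
            if file_path is not None:
                torch.jit.save(traced, file_path)
            return traced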

tests/base/models.py (+5 -3)

@@ -18,7 +18,7 @@


 class Generator(nn.Module):
-    def __init__(self, latent_dim: tuple, img_shape: tuple):
+    def __init__(self, latent_dim: int, img_shape: tuple):
         super().__init__()
         self.img_shape = img_shape

@@ -64,10 +64,10 @@ def forward(self, img):
         return validity


-class TestGAN(LightningModule):
+class BasicGAN(LightningModule):
     """Implements a basic GAN for the purpose of illustrating multiple optimizers."""

-    def __init__(self, hidden_dim, learning_rate, b1, b2, **kwargs):
+    def __init__(self, hidden_dim: int = 128, learning_rate: float = 0.001, b1: float = 0.5, b2: float = 0.999, **kwargs):
         super().__init__()
         self.hidden_dim = hidden_dim
         self.learning_rate = learning_rate

@@ -163,6 +163,7 @@ def __init__(self):
         super().__init__()
         self.rnn = nn.LSTM(10, 20, batch_first=True)
         self.linear_out = nn.Linear(in_features=20, out_features=5)
+        self.example_input_array = torch.rand(2, 3, 10)

     def forward(self, x):
         seq, last = self.rnn(x)

@@ -189,6 +190,7 @@ def __init__(self):
         self.c_d1_bn = nn.BatchNorm1d(128)
         self.c_d1_drop = nn.Dropout(0.3)
         self.c_d2 = nn.Linear(in_features=128, out_features=10)
+        self.example_input_array = torch.rand(2, 1, 28, 28)

     def forward(self, x):
         x = x.view(x.size(0), -1)
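The example_input_array attributes added above give the new torchscript tests a ready-made batch for exercising forward. The pattern in isolation — a minimal sketch with a hypothetical TinyModel:

    import torch
    from pytorch_lightning import LightningModule


    class TinyModel(LightningModule):  # hypothetical, mirrors the pattern above
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(10, 5)
            # a batch of example inputs lets tests and export utilities
            # call forward without constructing a dataloader
            self.example_input_array = torch.rand(2, 10)

        def forward(self, x):
            return self.layer(x)


    model = TinyModel()
    out = model(model.example_input_array)
    print(out.shape)  # torch.Size([2, 5])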

tests/models/test_horovod.py (+2 -2)

@@ -13,7 +13,7 @@

 import tests.base.develop_utils as tutils
 from pytorch_lightning import Trainer
 from tests.base import EvalModelTemplate
-from tests.base.models import TestGAN
+from tests.base.models import BasicGAN

 try:
     from horovod.common.util import nccl_built

@@ -145,7 +145,7 @@ def validation_step(self, batch, *args, **kwargs):

 @pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
 def test_horovod_multi_optimizer(tmpdir):
-    model = TestGAN(**EvalModelTemplate.get_default_hparams())
+    model = BasicGAN(**EvalModelTemplate.get_default_hparams())

     # fit model
     trainer = Trainer(

tests/models/test_torchscript.py (new file, +88)

@@ -0,0 +1,88 @@
+from distutils.version import LooseVersion
+
+import pytest
+import torch
+
+from tests.base import EvalModelTemplate
+from tests.base.datamodules import TrialMNISTDataModule
+from tests.base.models import ParityModuleRNN, BasicGAN
+
+
+@pytest.mark.parametrize("modelclass", [
+    EvalModelTemplate,
+    ParityModuleRNN,
+    BasicGAN,
+])
+def test_torchscript_input_output(modelclass):
+    """ Test that scripted LightningModule forward works. """
+    model = modelclass()
+    script = model.to_torchscript()
+    assert isinstance(script, torch.jit.ScriptModule)
+    model.eval()
+    model_output = model(model.example_input_array)
+    script_output = script(model.example_input_array)
+    assert torch.allclose(script_output, model_output)
+
+
+@pytest.mark.parametrize("device", [
+    torch.device("cpu"),
+    torch.device("cuda", 0),
+])
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
+def test_torchscript_device(device):
+    """ Test that the scripted module is on the correct device. """
+    model = EvalModelTemplate().to(device)
+    script = model.to_torchscript()
+    assert next(script.parameters()).device == device
+    script_output = script(model.example_input_array.to(device))
+    assert script_output.device == device
+
+
+def test_torchscript_retain_training_state():
+    """ Test that torchscript export does not alter the training mode of the original model. """
+    model = EvalModelTemplate()
+    model.train(True)
+    script = model.to_torchscript()
+    assert model.training
+    assert not script.training
+    model.train(False)
+    _ = model.to_torchscript()
+    assert not model.training
+    assert not script.training
+
+
+@pytest.mark.parametrize("modelclass", [
+    EvalModelTemplate,
+    ParityModuleRNN,
+    BasicGAN,
+])
+def test_torchscript_properties(modelclass):
+    """ Test that the scripted LightningModule has unnecessary methods removed. """
+    model = modelclass()
+    model.datamodule = TrialMNISTDataModule()
+    script = model.to_torchscript()
+    assert not hasattr(script, "datamodule")
+    assert not hasattr(model, "batch_size") or hasattr(script, "batch_size")
+    assert not hasattr(model, "learning_rate") or hasattr(script, "learning_rate")
+
+    if LooseVersion(torch.__version__) >= LooseVersion("1.4.0"):
+        # only on torch >= 1.4 do these unused methods get removed
+        assert not callable(getattr(script, "training_step", None))
+
+
+@pytest.mark.parametrize("modelclass", [
+    EvalModelTemplate,
+    ParityModuleRNN,
+    BasicGAN,
+])
+@pytest.mark.skipif(
+    LooseVersion(torch.__version__) < LooseVersion("1.5.0"),
+    reason="torch.save/load has a bug loading script modules on torch <= 1.4",
+)
+def test_torchscript_save_load(tmpdir, modelclass):
+    """ Test that a scripted LightningModule is correctly saved and can be loaded. """
+    model = modelclass()
+    output_file = str(tmpdir / "model.pt")
+    script = model.to_torchscript(file_path=output_file)
+    loaded_script = torch.jit.load(output_file)
+    assert torch.allclose(next(script.parameters()), next(loaded_script.parameters()))
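The save/load test above depends on the repository's test helpers; the same round trip can be checked standalone — a minimal sketch, reusing the SimpleModel from the docstring example in this commit:

    import torch
    from pytorch_lightning import LightningModule


    class SimpleModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.l1 = torch.nn.Linear(in_features=64, out_features=4)

        def forward(self, x):
            return torch.relu(self.l1(x.view(x.size(0), -1)))


    model = SimpleModel()
    script = model.to_torchscript(file_path="model.pt")  # scripts and saves in one call
    loaded = torch.jit.load("model.pt")

    # parameters and outputs should match after the round trip
    x = torch.randn(1, 64)
    assert torch.allclose(script(x), loaded(x))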
