From da527e84b66f34b3e3fe6a0de72c4c71de3be776 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 29 Aug 2020 19:49:17 +0200 Subject: [PATCH 01/38] script --- pytorch_lightning/core/lightning.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index f7723e6945e23..2f99fa1e8c7cb 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -184,6 +184,7 @@ def forward(self, batch): return logits """ + return super().forward(*args, **kwargs) def training_step(self, *args, **kwargs): r""" @@ -1729,6 +1730,28 @@ def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwarg torch.onnx.export(self, input_data, file_path, **kwargs) + def to_torchscript(self): + """Saves the model as a JIT module. + This can be overridden to support custom TorchScript module export + Example: + >>> class SimpleModel(LightningModule): + ... def __init__(self): + ... super().__init__() + ... self.l1 = torch.nn.Linear(in_features=64, out_features=4) + ... + ... def forward(self, x): + ... return torch.relu(self.l1(x.view(x.size(0), -1))) + >>> with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as tmpfile: + ... model = SimpleModel() + ... torch.jit.save(model.to_torchscript(), tmpfile.name) + ... os.path.isfile(tmpfile.name) + True + """ + mode = self.training + scripted_module = torch.jit.script(self.eval()) + self.training = mode + return scripted_module + @property def hparams(self) -> Union[AttributeDict, str]: if not hasattr(self, '_hparams'): From ea4b4e103285ea777a016245219f30e9f6c0faea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 29 Aug 2020 21:41:53 +0200 Subject: [PATCH 02/38] docs --- pytorch_lightning/core/lightning.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 2f99fa1e8c7cb..20a9c4ac84ee3 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1730,9 +1730,14 @@ def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwarg torch.onnx.export(self, input_data, file_path, **kwargs) - def to_torchscript(self): - """Saves the model as a JIT module. - This can be overridden to support custom TorchScript module export + def to_torchscript(self) -> torch.jit.ScriptModule: + """ + Compiles the model to a :class:`~torch.jit.ScriptModule`. + This can be overridden to support custom TorchScript module export. + + Note: + Requires the implementation of the :meth:`LightningModule.forward` method. + Example: >>> class SimpleModel(LightningModule): ... def __init__(self): From 0ad4af0f53fcbb866167da89e3b5ba00057d5311 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 29 Aug 2020 21:41:59 +0200 Subject: [PATCH 03/38] simple test --- tests/core/test_torchscript.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 tests/core/test_torchscript.py diff --git a/tests/core/test_torchscript.py b/tests/core/test_torchscript.py new file mode 100644 index 0000000000000..f8228a2d3035c --- /dev/null +++ b/tests/core/test_torchscript.py @@ -0,0 +1,27 @@ +import torch + +from pytorch_lightning import LightningModule + + +class SimpleModel(LightningModule): + + def __init__(self): + super().__init__() + self.l1 = torch.nn.Linear(in_features=64, out_features=4) + + def forward(self, x): + return torch.relu(self.l1(x.view(x.size(0), -1))) + + +def test_torchscript_save_load(tmpdir): + """ Test that scripted LightningModule behaves like the original. """ + model = SimpleModel() + example_input = torch.rand(5, 64) + script = model.to_torchscript() + assert isinstance(script, torch.jit.ScriptModule) + output_file = str(tmpdir / "model.jit") + torch.jit.save(script, output_file) + script = torch.jit.load(output_file) + model_output = model(example_input) + script_output = script(example_input) + assert torch.allclose(script_output, model_output) From 246e875d40d94dff9dd3d76053668dafe39383b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 29 Aug 2020 22:02:45 +0200 Subject: [PATCH 04/38] move test --- tests/{core => models}/test_torchscript.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) rename tests/{core => models}/test_torchscript.py (64%) diff --git a/tests/core/test_torchscript.py b/tests/models/test_torchscript.py similarity index 64% rename from tests/core/test_torchscript.py rename to tests/models/test_torchscript.py index f8228a2d3035c..32bbf422a84a8 100644 --- a/tests/core/test_torchscript.py +++ b/tests/models/test_torchscript.py @@ -1,6 +1,7 @@ import torch from pytorch_lightning import LightningModule +from tests.base import EvalModelTemplate class SimpleModel(LightningModule): @@ -15,13 +16,17 @@ def forward(self, x): def test_torchscript_save_load(tmpdir): """ Test that scripted LightningModule behaves like the original. """ - model = SimpleModel() - example_input = torch.rand(5, 64) + model = EvalModelTemplate() script = model.to_torchscript() assert isinstance(script, torch.jit.ScriptModule) output_file = str(tmpdir / "model.jit") torch.jit.save(script, output_file) script = torch.jit.load(output_file) - model_output = model(example_input) - script_output = script(example_input) + # properties + assert script.batch_size == model.batch_size + assert script.learning_rate == model.learning_rate + assert not callable(getattr(script, "training_step", None)) + # output matches + model_output = model(model.example_input_array) + script_output = script(model.example_input_array) assert torch.allclose(script_output, model_output) From cee541605c51dda561db231d6ee9ec8b1031a3ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 29 Aug 2020 22:02:53 +0200 Subject: [PATCH 05/38] fix doctest --- pytorch_lightning/core/lightning.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 20a9c4ac84ee3..9fea56832cd45 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1746,10 +1746,10 @@ def to_torchscript(self) -> torch.jit.ScriptModule: ... ... def forward(self, x): ... return torch.relu(self.l1(x.view(x.size(0), -1))) - >>> with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as tmpfile: - ... model = SimpleModel() - ... torch.jit.save(model.to_torchscript(), tmpfile.name) - ... os.path.isfile(tmpfile.name) + ... + >>> model = SimpleModel() + >>> torch.jit.save(model.to_torchscript(), "model.jit") + >>> os.path.isfile("model.jit") True """ mode = self.training From 192f10fada26d57204820452aef8c2b828fcfcfd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 30 Aug 2020 00:19:47 +0200 Subject: [PATCH 06/38] no grad context --- pytorch_lightning/core/lightning.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 9fea56832cd45..e74799ad317b8 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1753,7 +1753,8 @@ def to_torchscript(self) -> torch.jit.ScriptModule: True """ mode = self.training - scripted_module = torch.jit.script(self.eval()) + with torch.no_grad(): + scripted_module = torch.jit.script(self.eval()) self.training = mode return scripted_module From c12aea2b37da861345933f07602ba58c921802cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 30 Aug 2020 00:36:38 +0200 Subject: [PATCH 07/38] extend tests test test --- tests/base/models.py | 6 ++-- tests/models/test_torchscript.py | 51 +++++++++++++++++++++++++------- 2 files changed, 44 insertions(+), 13 deletions(-) diff --git a/tests/base/models.py b/tests/base/models.py index 7f295da59fd65..b4405dfef02e9 100644 --- a/tests/base/models.py +++ b/tests/base/models.py @@ -18,7 +18,7 @@ class Generator(nn.Module): - def __init__(self, latent_dim: tuple, img_shape: tuple): + def __init__(self, latent_dim: int, img_shape: tuple): super().__init__() self.img_shape = img_shape @@ -67,7 +67,7 @@ def forward(self, img): class TestGAN(LightningModule): """Implements a basic GAN for the purpose of illustrating multiple optimizers.""" - def __init__(self, hidden_dim, learning_rate, b1, b2, **kwargs): + def __init__(self, hidden_dim=128, learning_rate=0.001, b1=0.5, b2=0.999, **kwargs): super().__init__() self.hidden_dim = hidden_dim self.learning_rate = learning_rate @@ -163,6 +163,7 @@ def __init__(self): super().__init__() self.rnn = nn.LSTM(10, 20, batch_first=True) self.linear_out = nn.Linear(in_features=20, out_features=5) + self.example_input_array = torch.rand(2, 3, 10) def forward(self, x): seq, last = self.rnn(x) @@ -189,6 +190,7 @@ def __init__(self): self.c_d1_bn = nn.BatchNorm1d(128) self.c_d1_drop = nn.Dropout(0.3) self.c_d2 = nn.Linear(in_features=128, out_features=10) + self.example_input_array = torch.rand(2, 1, 28, 28) def forward(self, x): x = x.view(x.size(0), -1) diff --git a/tests/models/test_torchscript.py b/tests/models/test_torchscript.py index 32bbf422a84a8..e5bccd9eb2d8d 100644 --- a/tests/models/test_torchscript.py +++ b/tests/models/test_torchscript.py @@ -1,7 +1,9 @@ +import pytest import torch from pytorch_lightning import LightningModule from tests.base import EvalModelTemplate +from tests.base.models import ParityModuleRNN, TestGAN class SimpleModel(LightningModule): @@ -14,19 +16,46 @@ def forward(self, x): return torch.relu(self.l1(x.view(x.size(0), -1))) -def test_torchscript_save_load(tmpdir): - """ Test that scripted LightningModule behaves like the original. """ - model = EvalModelTemplate() +@pytest.mark.parametrize("modelclass", [ + EvalModelTemplate, + ParityModuleRNN, + TestGAN, +]) +def test_torchscript_input_output(modelclass): + """ Test that scripted LightningModule forward works. """ + model = modelclass() script = model.to_torchscript() assert isinstance(script, torch.jit.ScriptModule) - output_file = str(tmpdir / "model.jit") - torch.jit.save(script, output_file) - script = torch.jit.load(output_file) - # properties - assert script.batch_size == model.batch_size - assert script.learning_rate == model.learning_rate - assert not callable(getattr(script, "training_step", None)) - # output matches model_output = model(model.example_input_array) script_output = script(model.example_input_array) assert torch.allclose(script_output, model_output) + + +@pytest.mark.parametrize("modelclass", [ + EvalModelTemplate, + ParityModuleRNN, + TestGAN, +]) +def test_torchscript_properties(modelclass): + """ Test that scripted LightningModule has unnecessary methods removed. """ + model = modelclass() + script = model.to_torchscript() + assert not hasattr(model, "batch_size") or hasattr(script, "batch_size") + assert not hasattr(model, "learning_rate") or hasattr(script, "learning_rate") + assert not callable(getattr(script, "training_step", None)) + + +@pytest.mark.parametrize("modelclass", [ + EvalModelTemplate, + ParityModuleRNN, + TestGAN, +]) +def test_torchscript_save_load(tmpdir, modelclass): + """ Test that scripted LightningModules can be saved and loaded. """ + model = modelclass() + script = model.to_torchscript() + assert isinstance(script, torch.jit.ScriptModule) + output_file = str(tmpdir / "model.jit") + torch.jit.save(script, output_file) + torch.jit.load(output_file) + From d6f6437078795f2ef87b6bd404383852f35957a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 30 Aug 2020 00:40:59 +0200 Subject: [PATCH 08/38] datamodule test --- tests/models/test_torchscript.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/models/test_torchscript.py b/tests/models/test_torchscript.py index e5bccd9eb2d8d..92b742af8b099 100644 --- a/tests/models/test_torchscript.py +++ b/tests/models/test_torchscript.py @@ -39,7 +39,9 @@ def test_torchscript_input_output(modelclass): def test_torchscript_properties(modelclass): """ Test that scripted LightningModule has unnecessary methods removed. """ model = modelclass() + model.datamodule = TrialMNISTDataModule() script = model.to_torchscript() + assert not hasattr(script, "datamodule") assert not hasattr(model, "batch_size") or hasattr(script, "batch_size") assert not hasattr(model, "learning_rate") or hasattr(script, "learning_rate") assert not callable(getattr(script, "training_step", None)) From 49d2166f4d045942d2aa1ea1b4d0da2c482a6200 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 30 Aug 2020 00:41:06 +0200 Subject: [PATCH 09/38] clean up test --- tests/models/test_torchscript.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/tests/models/test_torchscript.py b/tests/models/test_torchscript.py index 92b742af8b099..a62ec614097ec 100644 --- a/tests/models/test_torchscript.py +++ b/tests/models/test_torchscript.py @@ -1,21 +1,11 @@ import pytest import torch -from pytorch_lightning import LightningModule from tests.base import EvalModelTemplate +from tests.base.datamodules import TrialMNISTDataModule from tests.base.models import ParityModuleRNN, TestGAN -class SimpleModel(LightningModule): - - def __init__(self): - super().__init__() - self.l1 = torch.nn.Linear(in_features=64, out_features=4) - - def forward(self, x): - return torch.relu(self.l1(x.view(x.size(0), -1))) - - @pytest.mark.parametrize("modelclass", [ EvalModelTemplate, ParityModuleRNN, From e16749025aa36004a7430e8d843240f541ce8d12 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 30 Aug 2020 00:58:39 +0200 Subject: [PATCH 10/38] docs --- docs/source/production_inference.rst | 30 ++++++++++++++++++++++++++++ tests/models/test_torchscript.py | 1 - 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/docs/source/production_inference.rst b/docs/source/production_inference.rst index d3ab26b93e419..d59eedd99248c 100644 --- a/docs/source/production_inference.rst +++ b/docs/source/production_inference.rst @@ -1,3 +1,17 @@ +.. testsetup:: * + + import torch + from pytorch_lightning import LightningModule + + class SimpleModel(LightningModule): + + def __init__(self): + super().__init__() + self.l1 = torch.nn.Linear(in_features=64, out_features=4) + + def forward(self, x): + return torch.relu(self.l1(x.view(x.size(0), -1))) + .. _production-inference: Inference in Production @@ -28,3 +42,19 @@ Once you have the exported model, you can run it on your ONNX runtime in the fol input_name = ort_session.get_inputs()[0].name ort_inputs = {input_name: np.random.randn(1, 64).astype(np.float32)} ort_outs = ort_session.run(None, ort_inputs) + + +Exporting to TorchScript +------------------------ + +TorchScript allows you to serialize your models in a way that it can be loaded in non-Python environments. +The LightningModule has a handy method :meth:`~pytorch_lightning.core.lightning.LightningModule.to_torchscript` +that returns a scripted module which you can save or directly use. + +.. testcode:: + + model = SimpleModel() + script = model.to_torchscript() + + # save for use in production environment + torch.jit.save(script, "model.pt") diff --git a/tests/models/test_torchscript.py b/tests/models/test_torchscript.py index a62ec614097ec..af1fc3eef725e 100644 --- a/tests/models/test_torchscript.py +++ b/tests/models/test_torchscript.py @@ -50,4 +50,3 @@ def test_torchscript_save_load(tmpdir, modelclass): output_file = str(tmpdir / "model.jit") torch.jit.save(script, output_file) torch.jit.load(output_file) - From 7223e9870cef46430a149cd7f3f795c0370b1a4a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 30 Aug 2020 00:59:12 +0200 Subject: [PATCH 11/38] name --- pytorch_lightning/core/lightning.py | 4 ++-- tests/models/test_torchscript.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index e74799ad317b8..96bfaaeb4be22 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1748,8 +1748,8 @@ def to_torchscript(self) -> torch.jit.ScriptModule: ... return torch.relu(self.l1(x.view(x.size(0), -1))) ... >>> model = SimpleModel() - >>> torch.jit.save(model.to_torchscript(), "model.jit") - >>> os.path.isfile("model.jit") + >>> torch.jit.save(model.to_torchscript(), "model.pt") + >>> os.path.isfile("model.pt") True """ mode = self.training diff --git a/tests/models/test_torchscript.py b/tests/models/test_torchscript.py index af1fc3eef725e..481ec2fcf5eb3 100644 --- a/tests/models/test_torchscript.py +++ b/tests/models/test_torchscript.py @@ -47,6 +47,6 @@ def test_torchscript_save_load(tmpdir, modelclass): model = modelclass() script = model.to_torchscript() assert isinstance(script, torch.jit.ScriptModule) - output_file = str(tmpdir / "model.jit") + output_file = str(tmpdir / "model.pt") torch.jit.save(script, output_file) torch.jit.load(output_file) From 4f0dbbcd13f4472a3ef97286fd6109b73711b934 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 30 Aug 2020 01:03:46 +0200 Subject: [PATCH 12/38] fix import --- docs/source/production_inference.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/production_inference.rst b/docs/source/production_inference.rst index d59eedd99248c..41cf4d0e5c4bb 100644 --- a/docs/source/production_inference.rst +++ b/docs/source/production_inference.rst @@ -1,7 +1,7 @@ .. testsetup:: * import torch - from pytorch_lightning import LightningModule + pytorch_lightning.core.lightning import LightningModule class SimpleModel(LightningModule): From e5da609b2c21280b69517750afe15528756e557c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 30 Aug 2020 01:08:15 +0200 Subject: [PATCH 13/38] update changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index d18fac2668279..67111fa624d92 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added hooks to metric module interface ([#2528](https://github.com/PyTorchLightning/pytorch-lightning/pull/2528/)) +- Added `LightningModule.to_torchscript` to support exporting as `ScriptModule` ([#3258](https://github.com/PyTorchLightning/pytorch-lightning/pull/3258/)) + ### Changed From 92d2c5ac1c49ddc43f69b876f0215e62fd225338 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 30 Aug 2020 01:30:15 +0200 Subject: [PATCH 14/38] fix import --- docs/source/production_inference.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/production_inference.rst b/docs/source/production_inference.rst index 41cf4d0e5c4bb..6139943e2bea8 100644 --- a/docs/source/production_inference.rst +++ b/docs/source/production_inference.rst @@ -1,7 +1,7 @@ .. testsetup:: * import torch - pytorch_lightning.core.lightning import LightningModule + from pytorch_lightning.core.lightning import LightningModule class SimpleModel(LightningModule): From 26ad18582a34daa08b2cf5c23b8ba2af2c58e7bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 30 Aug 2020 01:55:19 +0200 Subject: [PATCH 15/38] skip pytorch 1.3 in test --- tests/models/test_torchscript.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/models/test_torchscript.py b/tests/models/test_torchscript.py index 481ec2fcf5eb3..83f1681b821bf 100644 --- a/tests/models/test_torchscript.py +++ b/tests/models/test_torchscript.py @@ -1,3 +1,5 @@ +from distutils.version import LooseVersion + import pytest import torch @@ -11,6 +13,10 @@ ParityModuleRNN, TestGAN, ]) +@pytest.mark.skipif( + LooseVersion(torch.__version__) < LooseVersion("1.4.0"), + reason="requires torch >= 1.4", +) def test_torchscript_input_output(modelclass): """ Test that scripted LightningModule forward works. """ model = modelclass() @@ -26,6 +32,10 @@ def test_torchscript_input_output(modelclass): ParityModuleRNN, TestGAN, ]) +@pytest.mark.skipif( + LooseVersion(torch.__version__) < LooseVersion("1.4.0"), + reason="requires torch >= 1.4", + ) def test_torchscript_properties(modelclass): """ Test that scripted LightningModule has unnecessary methods removed. """ model = modelclass() @@ -42,6 +52,10 @@ def test_torchscript_properties(modelclass): ParityModuleRNN, TestGAN, ]) +@pytest.mark.skipif( + LooseVersion(torch.__version__) < LooseVersion("1.4.0"), + reason="requires torch >= 1.4", + ) def test_torchscript_save_load(tmpdir, modelclass): """ Test that scripted LightningModules can be saved and loaded. """ model = modelclass() From ace4f4fdec5a79c6bf399c1392d9041f9a0790cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 30 Aug 2020 02:27:20 +0200 Subject: [PATCH 16/38] update codeblock --- docs/source/production_inference.rst | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/docs/source/production_inference.rst b/docs/source/production_inference.rst index 6139943e2bea8..7b240b255428e 100644 --- a/docs/source/production_inference.rst +++ b/docs/source/production_inference.rst @@ -1,17 +1,3 @@ -.. testsetup:: * - - import torch - from pytorch_lightning.core.lightning import LightningModule - - class SimpleModel(LightningModule): - - def __init__(self): - super().__init__() - self.l1 = torch.nn.Linear(in_features=64, out_features=4) - - def forward(self, x): - return torch.relu(self.l1(x.view(x.size(0), -1))) - .. _production-inference: Inference in Production @@ -51,7 +37,7 @@ TorchScript allows you to serialize your models in a way that it can be loaded i The LightningModule has a handy method :meth:`~pytorch_lightning.core.lightning.LightningModule.to_torchscript` that returns a scripted module which you can save or directly use. -.. testcode:: +.. code-block:: python model = SimpleModel() script = model.to_torchscript() From b22dfc6fd89d95472ccc5c11c9898492a412dd1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 30 Aug 2020 02:41:30 +0200 Subject: [PATCH 17/38] skip bugged 1.4 --- tests/models/test_torchscript.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/tests/models/test_torchscript.py b/tests/models/test_torchscript.py index 83f1681b821bf..9580e2d117f3d 100644 --- a/tests/models/test_torchscript.py +++ b/tests/models/test_torchscript.py @@ -13,10 +13,6 @@ ParityModuleRNN, TestGAN, ]) -@pytest.mark.skipif( - LooseVersion(torch.__version__) < LooseVersion("1.4.0"), - reason="requires torch >= 1.4", -) def test_torchscript_input_output(modelclass): """ Test that scripted LightningModule forward works. """ model = modelclass() @@ -32,10 +28,6 @@ def test_torchscript_input_output(modelclass): ParityModuleRNN, TestGAN, ]) -@pytest.mark.skipif( - LooseVersion(torch.__version__) < LooseVersion("1.4.0"), - reason="requires torch >= 1.4", - ) def test_torchscript_properties(modelclass): """ Test that scripted LightningModule has unnecessary methods removed. """ model = modelclass() @@ -53,9 +45,9 @@ def test_torchscript_properties(modelclass): TestGAN, ]) @pytest.mark.skipif( - LooseVersion(torch.__version__) < LooseVersion("1.4.0"), - reason="requires torch >= 1.4", - ) + LooseVersion(torch.__version__) < LooseVersion("1.5.0"), + reason="torch.save/load has bug loading script modules on torch <= 1.4", +) def test_torchscript_save_load(tmpdir, modelclass): """ Test that scripted LightningModules can be saved and loaded. """ model = modelclass() From 81bca94433732421f3fd9d22cdcc79056d1ec9b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 30 Aug 2020 02:41:36 +0200 Subject: [PATCH 18/38] typehints --- tests/base/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/base/models.py b/tests/base/models.py index b4405dfef02e9..161c4da81d54a 100644 --- a/tests/base/models.py +++ b/tests/base/models.py @@ -67,7 +67,7 @@ def forward(self, img): class TestGAN(LightningModule): """Implements a basic GAN for the purpose of illustrating multiple optimizers.""" - def __init__(self, hidden_dim=128, learning_rate=0.001, b1=0.5, b2=0.999, **kwargs): + def __init__(self, hidden_dim: int = 128, learning_rate: float = 0.001, b1: float = 0.5, b2: float = 0.999, **kwargs): super().__init__() self.hidden_dim = hidden_dim self.learning_rate = learning_rate From b883e9710489b7b67c8f736a25adf74ad43e46a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 30 Aug 2020 02:54:03 +0200 Subject: [PATCH 19/38] doctest not working on all pytorch versions --- pytorch_lightning/core/lightning.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 96bfaaeb4be22..957e66759de4c 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1748,8 +1748,8 @@ def to_torchscript(self) -> torch.jit.ScriptModule: ... return torch.relu(self.l1(x.view(x.size(0), -1))) ... >>> model = SimpleModel() - >>> torch.jit.save(model.to_torchscript(), "model.pt") - >>> os.path.isfile("model.pt") + >>> torch.jit.save(model.to_torchscript(), "model.pt") # doctest: +SKIP + >>> os.path.isfile("model.pt") # doctest: +SKIP True """ mode = self.training From b7be25431ee01eaa77a3482d12febf5c29f2341d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 30 Aug 2020 03:01:55 +0200 Subject: [PATCH 20/38] rename TestGAN to prevent pytest interference --- tests/base/models.py | 2 +- tests/models/test_horovod.py | 4 ++-- tests/models/test_torchscript.py | 8 ++++---- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/base/models.py b/tests/base/models.py index 161c4da81d54a..9c319add4aca1 100644 --- a/tests/base/models.py +++ b/tests/base/models.py @@ -64,7 +64,7 @@ def forward(self, img): return validity -class TestGAN(LightningModule): +class BasicGAN(LightningModule): """Implements a basic GAN for the purpose of illustrating multiple optimizers.""" def __init__(self, hidden_dim: int = 128, learning_rate: float = 0.001, b1: float = 0.5, b2: float = 0.999, **kwargs): diff --git a/tests/models/test_horovod.py b/tests/models/test_horovod.py index f48db196c104a..7c6dc3b7417c5 100644 --- a/tests/models/test_horovod.py +++ b/tests/models/test_horovod.py @@ -13,7 +13,7 @@ import tests.base.develop_utils as tutils from pytorch_lightning import Trainer from tests.base import EvalModelTemplate -from tests.base.models import TestGAN +from tests.base.models import BasicGAN try: from horovod.common.util import nccl_built @@ -145,7 +145,7 @@ def validation_step(self, batch, *args, **kwargs): @pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows") def test_horovod_multi_optimizer(tmpdir): - model = TestGAN(**EvalModelTemplate.get_default_hparams()) + model = BasicGAN(**EvalModelTemplate.get_default_hparams()) # fit model trainer = Trainer( diff --git a/tests/models/test_torchscript.py b/tests/models/test_torchscript.py index 9580e2d117f3d..3de0f012e7eca 100644 --- a/tests/models/test_torchscript.py +++ b/tests/models/test_torchscript.py @@ -5,13 +5,13 @@ from tests.base import EvalModelTemplate from tests.base.datamodules import TrialMNISTDataModule -from tests.base.models import ParityModuleRNN, TestGAN +from tests.base.models import ParityModuleRNN, BasicGAN @pytest.mark.parametrize("modelclass", [ EvalModelTemplate, ParityModuleRNN, - TestGAN, + BasicGAN, ]) def test_torchscript_input_output(modelclass): """ Test that scripted LightningModule forward works. """ @@ -26,7 +26,7 @@ def test_torchscript_input_output(modelclass): @pytest.mark.parametrize("modelclass", [ EvalModelTemplate, ParityModuleRNN, - TestGAN, + BasicGAN, ]) def test_torchscript_properties(modelclass): """ Test that scripted LightningModule has unnecessary methods removed. """ @@ -42,7 +42,7 @@ def test_torchscript_properties(modelclass): @pytest.mark.parametrize("modelclass", [ EvalModelTemplate, ParityModuleRNN, - TestGAN, + BasicGAN, ]) @pytest.mark.skipif( LooseVersion(torch.__version__) < LooseVersion("1.5.0"), From 2af314b1edd28fd8bb1f8023b96e4ae84267aa5f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 30 Aug 2020 03:19:53 +0200 Subject: [PATCH 21/38] add note about pytorch version --- docs/source/production_inference.rst | 2 ++ pytorch_lightning/core/lightning.py | 4 +++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/source/production_inference.rst b/docs/source/production_inference.rst index 7b240b255428e..0d909d65e5522 100644 --- a/docs/source/production_inference.rst +++ b/docs/source/production_inference.rst @@ -44,3 +44,5 @@ that returns a scripted module which you can save or directly use. # save for use in production environment torch.jit.save(script, "model.pt") + +It is recommended that you install the latest version of PyTorch to use this feature without limitations. diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 957e66759de4c..25b22f0e83ad2 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1736,7 +1736,9 @@ def to_torchscript(self) -> torch.jit.ScriptModule: This can be overridden to support custom TorchScript module export. Note: - Requires the implementation of the :meth:`LightningModule.forward` method. + - Requires the implementation of the :meth:`LightningModule.forward` method. + - It is recommended that you install the latest version of PyTorch to use this feature without limitations. + Example: >>> class SimpleModel(LightningModule): From 5d76d550251114ba53104a36f0e01a6e0d4311a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 30 Aug 2020 03:27:19 +0200 Subject: [PATCH 22/38] fix torchscript version inconsistency in tests --- tests/models/test_torchscript.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/models/test_torchscript.py b/tests/models/test_torchscript.py index 3de0f012e7eca..512e55af9b5fe 100644 --- a/tests/models/test_torchscript.py +++ b/tests/models/test_torchscript.py @@ -36,7 +36,10 @@ def test_torchscript_properties(modelclass): assert not hasattr(script, "datamodule") assert not hasattr(model, "batch_size") or hasattr(script, "batch_size") assert not hasattr(model, "learning_rate") or hasattr(script, "learning_rate") - assert not callable(getattr(script, "training_step", None)) + + if LooseVersion(torch.__version__) >= LooseVersion("1.4.0"): + # only on torch >= 1.4 do these unused methods get removed + assert not callable(getattr(script, "training_step", None)) @pytest.mark.parametrize("modelclass", [ From 4fd6ceefc2b9bf8abf31e16a5c4145e8c6b44f95 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 30 Aug 2020 03:49:03 +0200 Subject: [PATCH 23/38] reset training state + tests --- pytorch_lightning/core/lightning.py | 2 +- tests/models/test_torchscript.py | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 25b22f0e83ad2..30d987ebaf735 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1757,7 +1757,7 @@ def to_torchscript(self) -> torch.jit.ScriptModule: mode = self.training with torch.no_grad(): scripted_module = torch.jit.script(self.eval()) - self.training = mode + self.train(mode) return scripted_module @property diff --git a/tests/models/test_torchscript.py b/tests/models/test_torchscript.py index 512e55af9b5fe..42b8ac5624442 100644 --- a/tests/models/test_torchscript.py +++ b/tests/models/test_torchscript.py @@ -18,11 +18,25 @@ def test_torchscript_input_output(modelclass): model = modelclass() script = model.to_torchscript() assert isinstance(script, torch.jit.ScriptModule) + model.eval() model_output = model(model.example_input_array) script_output = script(model.example_input_array) assert torch.allclose(script_output, model_output) +def test_torchscript_retain_training_state(): + """ Test that torchscript export does not alter the training mode of original model. """ + model = EvalModelTemplate() + model.train(True) + script = model.to_torchscript() + assert model.training + assert not script.training + model.train(False) + _ = model.to_torchscript() + assert not model.training + assert not script.training + + @pytest.mark.parametrize("modelclass", [ EvalModelTemplate, ParityModuleRNN, From 26c28f98cabf722c69021ccc88d3883244b31216 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 30 Aug 2020 21:54:42 +0200 Subject: [PATCH 24/38] update docstring --- pytorch_lightning/core/lightning.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 30d987ebaf735..733249b6b4e64 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1732,8 +1732,9 @@ def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwarg def to_torchscript(self) -> torch.jit.ScriptModule: """ - Compiles the model to a :class:`~torch.jit.ScriptModule`. - This can be overridden to support custom TorchScript module export. + By default compiles the whole model to a :class:`~torch.jit.ScriptModule`. + If you would like to customize the modules that are scripted + or you want to use tracing you should override this method. Note: - Requires the implementation of the :meth:`LightningModule.forward` method. From 9366029766cad725062eb320d6deaef15f2d51e2 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 1 Sep 2020 10:37:18 +0200 Subject: [PATCH 25/38] Apply suggestions from code review Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> --- docs/source/production_inference.rst | 2 +- pytorch_lightning/core/lightning.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/production_inference.rst b/docs/source/production_inference.rst index 0d909d65e5522..f81ffe5c8b979 100644 --- a/docs/source/production_inference.rst +++ b/docs/source/production_inference.rst @@ -45,4 +45,4 @@ that returns a scripted module which you can save or directly use. # save for use in production environment torch.jit.save(script, "model.pt") -It is recommended that you install the latest version of PyTorch to use this feature without limitations. +It is recommended that you install the latest supported version of PyTorch to use this feature without limitations. diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 30d987ebaf735..b5830e853e7f2 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1737,7 +1737,7 @@ def to_torchscript(self) -> torch.jit.ScriptModule: Note: - Requires the implementation of the :meth:`LightningModule.forward` method. - - It is recommended that you install the latest version of PyTorch to use this feature without limitations. + - It is recommended that you install the latest supported version of PyTorch to use this feature without limitations. Example: From 721cb5ec8e62e91cfdccf499fac74f2cf439df23 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 1 Sep 2020 19:38:35 +0200 Subject: [PATCH 26/38] update docstring, dict return --- pytorch_lightning/core/lightning.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 733249b6b4e64..fb72d31c4dfc6 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -23,7 +23,7 @@ import torch import torch.distributed as torch_distrib -from torch import Tensor +from torch import Tensor, ScriptModule from torch.nn import Module from torch.nn.parallel import DistributedDataParallel from torch.optim.optimizer import Optimizer @@ -1730,11 +1730,12 @@ def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwarg torch.onnx.export(self, input_data, file_path, **kwargs) - def to_torchscript(self) -> torch.jit.ScriptModule: + def to_torchscript(self) -> Union[ScriptModule, Dict[str, ScriptModule]]: """ By default compiles the whole model to a :class:`~torch.jit.ScriptModule`. - If you would like to customize the modules that are scripted - or you want to use tracing you should override this method. + If you would like to customize the modules that are scripted or you want to use tracing + you should override this method. In case you want to return multiple modules, we + recommend using dictionary. Note: - Requires the implementation of the :meth:`LightningModule.forward` method. From 3598f21fb2f914bff5047fcef40154499f90f74f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 1 Sep 2020 20:38:55 +0200 Subject: [PATCH 27/38] add docs to index --- docs/source/lightning-module.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/source/lightning-module.rst b/docs/source/lightning-module.rst index 6be28af55523d..98312ccc84820 100644 --- a/docs/source/lightning-module.rst +++ b/docs/source/lightning-module.rst @@ -770,6 +770,12 @@ to_onnx .. autofunction:: pytorch_lightning.core.lightning.LightningModule.to_onnx :noindex: +to_torchscript +~~~~~~~~~~~~~~ + +.. autofunction:: pytorch_lightning.core.lightning.LightningModule.to_torchscript + :noindex: + unfreeze ~~~~~~~~ From c1da6bd70d65b46849612e53ec1ab35e3e15190d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 1 Sep 2020 20:39:08 +0200 Subject: [PATCH 28/38] add link --- pytorch_lightning/core/lightning.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index c983bbef3dc7d..c82fbc33f0523 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1735,12 +1735,13 @@ def to_torchscript(self) -> Union[ScriptModule, Dict[str, ScriptModule]]: By default compiles the whole model to a :class:`~torch.jit.ScriptModule`. If you would like to customize the modules that are scripted or you want to use tracing you should override this method. In case you want to return multiple modules, we - recommend using dictionary. + recommend using a dictionary. Note: - Requires the implementation of the :meth:`LightningModule.forward` method. - - It is recommended that you install the latest supported version of PyTorch to use this feature without limitations. - + - It is recommended that you install the latest supported version of PyTorch + to use this feature without limitations. See also the :mod:`torch.jit` + documentation for supported features. Example: >>> class SimpleModel(LightningModule): From 7d1124a66a443740df594d26080a9243d76037f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 1 Sep 2020 20:42:33 +0200 Subject: [PATCH 29/38] doc eval mode --- pytorch_lightning/core/lightning.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index c82fbc33f0523..6a998f8067e5d 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1739,6 +1739,7 @@ def to_torchscript(self) -> Union[ScriptModule, Dict[str, ScriptModule]]: Note: - Requires the implementation of the :meth:`LightningModule.forward` method. + - The exported script will be set to evaluation mode. - It is recommended that you install the latest supported version of PyTorch to use this feature without limitations. See also the :mod:`torch.jit` documentation for supported features. From 7f180c021677555a25119cf8776d679df80509db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 1 Sep 2020 20:47:26 +0200 Subject: [PATCH 30/38] forward --- pytorch_lightning/core/lightning.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 6a998f8067e5d..6e55390100611 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1738,7 +1738,8 @@ def to_torchscript(self) -> Union[ScriptModule, Dict[str, ScriptModule]]: recommend using a dictionary. Note: - - Requires the implementation of the :meth:`LightningModule.forward` method. + - Requires the implementation of the + :meth:`~pytorch_lightning.core.lightning.LightningModule.forward` method. - The exported script will be set to evaluation mode. - It is recommended that you install the latest supported version of PyTorch to use this feature without limitations. See also the :mod:`torch.jit` From 9687a29cc19d7f7d685351520b20a5fbb4397996 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 3 Sep 2020 04:22:06 +0200 Subject: [PATCH 31/38] optional save to file path --- pytorch_lightning/core/lightning.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 6e55390100611..55374abfba162 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1730,13 +1730,17 @@ def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwarg torch.onnx.export(self, input_data, file_path, **kwargs) - def to_torchscript(self) -> Union[ScriptModule, Dict[str, ScriptModule]]: + def to_torchscript(self, file_path: str, **kwargs) -> Union[ScriptModule, Dict[str, ScriptModule]]: """ By default compiles the whole model to a :class:`~torch.jit.ScriptModule`. If you would like to customize the modules that are scripted or you want to use tracing you should override this method. In case you want to return multiple modules, we recommend using a dictionary. + Args: + file_path: Path where to save the torchscript. Default: None (no file saved). + **kwargs: Additional arguments that will be passed to the :func:`torch.jit.save` function. + Note: - Requires the implementation of the :meth:`~pytorch_lightning.core.lightning.LightningModule.forward` method. @@ -1758,11 +1762,20 @@ def to_torchscript(self) -> Union[ScriptModule, Dict[str, ScriptModule]]: >>> torch.jit.save(model.to_torchscript(), "model.pt") # doctest: +SKIP >>> os.path.isfile("model.pt") # doctest: +SKIP True + + Return: + This LightningModule as a torchscript, regardless of whether file_path is + defined or not. """ + mode = self.training with torch.no_grad(): - scripted_module = torch.jit.script(self.eval()) + scripted_module = torch.jit.script(self.eval(), **kwargs) self.train(mode) + + if file_path is not None: + torch.jit.save(scripted_module, file_path) + return scripted_module @property From 868f8b47a57e617cd936d8402fea002a6a418fb9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 3 Sep 2020 04:23:28 +0200 Subject: [PATCH 32/38] optional --- pytorch_lightning/core/lightning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 55374abfba162..031e7d72bf775 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1730,7 +1730,7 @@ def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwarg torch.onnx.export(self, input_data, file_path, **kwargs) - def to_torchscript(self, file_path: str, **kwargs) -> Union[ScriptModule, Dict[str, ScriptModule]]: + def to_torchscript(self, file_path: str = None, **kwargs) -> Union[ScriptModule, Dict[str, ScriptModule]]: """ By default compiles the whole model to a :class:`~torch.jit.ScriptModule`. If you would like to customize the modules that are scripted or you want to use tracing From 713e477e2e52d9ca95ff3945a0cc2424343bdf0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 3 Sep 2020 04:32:33 +0200 Subject: [PATCH 33/38] test torchscript device --- tests/models/test_torchscript.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/models/test_torchscript.py b/tests/models/test_torchscript.py index 42b8ac5624442..c72892d541152 100644 --- a/tests/models/test_torchscript.py +++ b/tests/models/test_torchscript.py @@ -24,6 +24,19 @@ def test_torchscript_input_output(modelclass): assert torch.allclose(script_output, model_output) +@pytest.mark.parametrize("device", [ + torch.device('cpu'), + torch.device('cuda', 0) +]) +def test_torchscript_device(device): + """ Test that scripted module is on the correct device. """ + model = EvalModelTemplate().to(device) + script = model.to_torchscript() + assert next(script.parameters()).device == device + script_output = script(model.example_input_array.to(device)) + assert script_output.device == device + + def test_torchscript_retain_training_state(): """ Test that torchscript export does not alter the training mode of original model. """ model = EvalModelTemplate() From c1fc408c4343abb4e4df6b03dbc1b746b62cd32e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 3 Sep 2020 04:35:08 +0200 Subject: [PATCH 34/38] test save load with file path --- tests/models/test_torchscript.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/models/test_torchscript.py b/tests/models/test_torchscript.py index c72892d541152..8583f2eac27a9 100644 --- a/tests/models/test_torchscript.py +++ b/tests/models/test_torchscript.py @@ -79,10 +79,9 @@ def test_torchscript_properties(modelclass): reason="torch.save/load has bug loading script modules on torch <= 1.4", ) def test_torchscript_save_load(tmpdir, modelclass): - """ Test that scripted LightningModules can be saved and loaded. """ + """ Test that scripted LightningModules is correctly saved and can be loaded. """ model = modelclass() - script = model.to_torchscript() - assert isinstance(script, torch.jit.ScriptModule) output_file = str(tmpdir / "model.pt") - torch.jit.save(script, output_file) - torch.jit.load(output_file) + script = model.to_torchscript(file_path=output_file) + loaded_script = torch.jit.load(output_file) + assert torch.allclose(next(script.parameters()), next(loaded_script.parameters())) From f3289591b2b5721c6a35d0df07ed3be2485e32ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 3 Sep 2020 04:37:51 +0200 Subject: [PATCH 35/38] pep --- tests/models/test_torchscript.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/test_torchscript.py b/tests/models/test_torchscript.py index 8583f2eac27a9..c7ebb02dc94a5 100644 --- a/tests/models/test_torchscript.py +++ b/tests/models/test_torchscript.py @@ -35,7 +35,7 @@ def test_torchscript_device(device): assert next(script.parameters()).device == device script_output = script(model.example_input_array.to(device)) assert script_output.device == device - + def test_torchscript_retain_training_state(): """ Test that torchscript export does not alter the training mode of original model. """ From 1edd4e42524c097628b84d5fd12fb6b3e3a83d5a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 3 Sep 2020 04:42:27 +0200 Subject: [PATCH 36/38] str --- tests/models/test_torchscript.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/test_torchscript.py b/tests/models/test_torchscript.py index c7ebb02dc94a5..4a819bf350d22 100644 --- a/tests/models/test_torchscript.py +++ b/tests/models/test_torchscript.py @@ -25,8 +25,8 @@ def test_torchscript_input_output(modelclass): @pytest.mark.parametrize("device", [ - torch.device('cpu'), - torch.device('cuda', 0) + torch.device("cpu"), + torch.device("cuda", 0) ]) def test_torchscript_device(device): """ Test that scripted module is on the correct device. """ From 1263b573ef03fb71d3e27c73e4f599daa2eaafb6 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Thu, 3 Sep 2020 09:01:10 +0200 Subject: [PATCH 37/38] Commit typing suggestion Co-authored-by: ananthsub --- pytorch_lightning/core/lightning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 031e7d72bf775..9b41daa1cd470 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1730,7 +1730,7 @@ def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwarg torch.onnx.export(self, input_data, file_path, **kwargs) - def to_torchscript(self, file_path: str = None, **kwargs) -> Union[ScriptModule, Dict[str, ScriptModule]]: + def to_torchscript(self, file_path: Optional[str] = None, **kwargs) -> Union[ScriptModule, Dict[str, ScriptModule]]: """ By default compiles the whole model to a :class:`~torch.jit.ScriptModule`. If you would like to customize the modules that are scripted or you want to use tracing From 78e2d606471d65165bdfcc94bd8b2c9053149445 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 3 Sep 2020 18:35:59 +0200 Subject: [PATCH 38/38] skip test if cuda not available --- tests/models/test_torchscript.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/models/test_torchscript.py b/tests/models/test_torchscript.py index 4a819bf350d22..a57a931820c55 100644 --- a/tests/models/test_torchscript.py +++ b/tests/models/test_torchscript.py @@ -28,6 +28,7 @@ def test_torchscript_input_output(modelclass): torch.device("cpu"), torch.device("cuda", 0) ]) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") def test_torchscript_device(device): """ Test that scripted module is on the correct device. """ model = EvalModelTemplate().to(device)