From e4c1bd17b3208b6c9be6629ef1d3c006356ca840 Mon Sep 17 00:00:00 2001
From: Victor Milewski
Date: Mon, 9 Dec 2019 13:00:23 +0100
Subject: [PATCH] Allow for multiple example inputs when creating summary

---
 pytorch_lightning/core/memory.py | 21 ++++++++++++++++-----
 1 file changed, 16 insertions(+), 5 deletions(-)

diff --git a/pytorch_lightning/core/memory.py b/pytorch_lightning/core/memory.py
index b01451f8cf71e..1abc349e1ef21 100644
--- a/pytorch_lightning/core/memory.py
+++ b/pytorch_lightning/core/memory.py
@@ -50,20 +50,31 @@ def get_variable_sizes(self):
         input_ = self.model.example_input_array
 
         if self.model.on_gpu:
-            input_ = input_.cuda(0)
+            device = next(self.model.parameters()).get_device()
+            # test if input is a list or a tuple
+            if isinstance(input_, (list, tuple)):
+                input_ = [input_i.cuda(device) if torch.is_tensor(input_i) else input_i
+                          for input_i in input_]
+            else:
+                input_ = input_.cuda(device)
 
         if self.model.trainer.use_amp:
-            input_ = input_.half()
+            # test if it is not a list or a tuple
+            if isinstance(input_, (list, tuple)):
+                input_ = [input_i.half() if torch.is_tensor(input_i) else input_i
+                          for input_i in input_]
+            else:
+                input_ = input_.half()
 
         with torch.no_grad():
 
             for _, m in mods:
-                if type(input_) is list or type(input_) is tuple:  # pragma: no cover
+                if isinstance(input_, (list, tuple)):  # pragma: no cover
                     out = m(*input_)
                 else:
                     out = m(input_)
 
-                if type(input_) is tuple or type(input_) is list:  # pragma: no cover
+                if isinstance(input_, (list, tuple)):  # pragma: no cover
                     in_size = []
                     for x in input_:
                         if type(x) is list:
@@ -75,7 +86,7 @@ def get_variable_sizes(self):
 
                     in_sizes.append(in_size)
 
-                if type(out) is tuple or type(out) is list:  # pragma: no cover
+                if isinstance(out, (list, tuple)):  # pragma: no cover
                     out_size = np.asarray([x.size() for x in out])
                 else:
                     out_size = np.array(out.size())
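
Not part of the patch, for illustration only: a minimal user-side sketch of what this change enables. It assumes the usual `import pytorch_lightning as pl` entry point; the module name TwoInputModel, its layers, and the tensor shapes are invented, and the training hooks (training_step, configure_optimizers, dataloaders) are omitted for brevity. With a tuple-valued example_input_array, the summary code above moves each tensor to the model's device (and to half precision under AMP) and unpacks the tuple into forward(*example_input_array).

    import torch
    import torch.nn as nn
    import pytorch_lightning as pl


    class TwoInputModel(pl.LightningModule):
        """Toy model whose forward takes two tensors (hypothetical example)."""

        def __init__(self):
            super().__init__()
            self.encoder_a = nn.Linear(16, 8)
            self.encoder_b = nn.Linear(32, 8)
            # A tuple of example tensors: with this patch the summary accepts
            # a list/tuple here instead of a single tensor.
            self.example_input_array = (torch.rand(2, 16), torch.rand(2, 32))

        def forward(self, x_a, x_b):
            return self.encoder_a(x_a) + self.encoder_b(x_b)


    model = TwoInputModel()
    out = model(*model.example_input_array)  # same unpacking the summary performs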