Steps to reproduce the behavior:
1. Go to ./pl_examples/basic_examples
2. Run python gpu_template.py --gpus 2 --distributed_backend dp
3. See error
Code sample
Error is below.
Validation sanity check: 0it [00:00, ?it/s]
Traceback (most recent call last):
  File "/root/workdir/pytorch-lightning/pl_examples/basic_examples/gpu_template.py", line 80, in <module>
    main(hyperparams)
  File "/root/workdir/pytorch-lightning/pl_examples/basic_examples/gpu_template.py", line 41, in main
    trainer.fit(model)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 853, in fit
    self.dp_train(model)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/distrib_parts.py", line 578, in dp_train
    self.run_pretrain_routine(model)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1001, in run_pretrain_routine
    False)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/evaluation_loop.py", line 277, in _evaluate
    output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/evaluation_loop.py", line 424, in evaluation_forward
    output = model(*args)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/overrides/data_parallel.py", line 66, in forward
    return self.gather(outputs, self.output_device)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 165, in gather
    return gather(outputs, output_device, dim=self.dim)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/parallel/scatter_gather.py", line 68, in gather
    res = gather_map(outputs)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/parallel/scatter_gather.py", line 62, in gather_map
    for k in out))
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/parallel/scatter_gather.py", line 62, in <genexpr>
    for k in out))
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/parallel/scatter_gather.py", line 63, in gather_map
    return type(out)(map(gather_map, zip(*outputs)))
TypeError: zip argument #1 must support iteration
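To see why the traceback ends in zip(...), here is a simplified, pure-Python model of the recursive gather_map in torch/nn/parallel/scatter_gather.py. It is a sketch, not PyTorch's actual code: the Tensor class is a hypothetical stand-in for torch.Tensor (real PyTorch merges tensors with a CUDA Gather op), so the example runs without a GPU.

```python
class Tensor:
    """Hypothetical stand-in for a tensor produced on one GPU replica."""

    def __init__(self, values):
        self.values = list(values)


def gather_map(outputs):
    """Merge one output per replica into a single output (simplified)."""
    out = outputs[0]
    if isinstance(out, Tensor):
        # Real PyTorch concatenates the per-replica tensors on the output device.
        return Tensor(v for t in outputs for v in t.values)
    if out is None:
        return None
    if isinstance(out, dict):
        # Dicts are merged key by key, recursing into each value.
        return type(out)((k, gather_map([d[k] for d in outputs])) for k in out)
    # Anything else is assumed to be a sequence and zipped element-wise.
    # For a plain int such as 32 this calls zip(32, 32), which raises TypeError.
    return type(out)(map(gather_map, zip(*outputs)))


# One dict per GPU, shaped like the original validation_step output.
replica_outputs = [
    {"val_loss": Tensor([0.5]), "n_pred": 32},
    {"val_loss": Tensor([0.7]), "n_pred": 32},
]

try:
    gather_map(replica_outputs)
except TypeError as err:
    # The exact message wording depends on the Python version.
    print("TypeError:", err)
```

The val_loss entry merges fine; the failure comes from the plain int stored under n_pred (and, in the original code, under n_correct_pred after .item()).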
def validation_step(self, batch, batch_idx):
    """
    Lightning calls this inside the validation loop with the data from the validation dataloader
    passed in as `batch`.
    """
    x, y = batch
    y_hat = self(x)
    val_loss = F.cross_entropy(y_hat, y)
    labels_hat = torch.argmax(y_hat, dim=1)
    # .item() and len(x) return Python ints, which DataParallel's gather cannot merge
    n_correct_pred = torch.sum(y == labels_hat).item()
    return {'val_loss': val_loss, "n_correct_pred": n_correct_pred, "n_pred": len(x)}

def validation_epoch_end(self, outputs):
    """
    Called at the end of validation to aggregate outputs.
    :param outputs: list of individual outputs of each validation step.
    """
    avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
    val_acc = sum([x['n_correct_pred'] for x in outputs]) / sum(x['n_pred'] for x in outputs)
    tensorboard_logs = {'val_loss': avg_loss, 'val_acc': val_acc}
    return {'val_loss': avg_loss, 'log': tensorboard_logs}
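This error can be worked around by changing these methods in ./pl_examples/models/lightning_template.py to: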
def validation_step(self, batch, batch_idx):
    """
    Lightning calls this inside the validation loop with the data from the validation dataloader
    passed in as `batch`.
    """
    x, y = batch
    y_hat = self(x)
    val_loss = F.cross_entropy(y_hat, y)
    labels_hat = torch.argmax(y_hat, dim=1)
    # keep the count as a tensor instead of calling .item()
    n_correct_pred = torch.sum(y == labels_hat)
    return {
        "val_loss": val_loss,
        "n_correct_pred": n_correct_pred,
        # wrap len(x) in a tensor on the same device as the loss
        "n_pred": torch.tensor(len(x)).to(val_loss.device),
    }

def validation_epoch_end(self, outputs):
    """
    Called at the end of validation to aggregate outputs.
    :param outputs: list of individual outputs of each validation step.
    """
    # assumes numpy is imported as np at module level
    avg_loss = (
        torch.stack([x["val_loss"].detach().cpu() for x in outputs]).mean().item()
    )
    val_acc = np.sum(
        [x["n_correct_pred"].detach().cpu().numpy() for x in outputs]
    ) / np.sum([x["n_pred"].detach().cpu().numpy() for x in outputs])
    tensorboard_logs = {"val_loss": avg_loss, "val_acc": val_acc}
    print({"val_loss": avg_loss, "log": tensorboard_logs})
    return {"val_loss": avg_loss, "log": tensorboard_logs}
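But this approach is not elegant: every value returned from validation_step has to be converted to a tensor on the same device as the loss before DataParallel can gather it.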
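For reference, validation_epoch_end in the workaround pools counts over the whole epoch: accuracy is total correct predictions divided by total predictions, not an average of per-batch accuracies. The sketch below assumes that under dp each scalar returned from validation_step reaches validation_epoch_end as one value per GPU (torch's Gather turns gathered 0-dim tensors into a vector); plain lists stand in for those tensors, and the counts are illustrative values only.

```python
def pooled_accuracy(outputs):
    """Total correct predictions divided by total predictions over all steps.

    Each entry of `outputs` is one validation step; each value holds one
    count per GPU replica (plain lists stand in for gathered tensors).
    """
    total_correct = sum(sum(step["n_correct_pred"]) for step in outputs)
    total_pred = sum(sum(step["n_pred"]) for step in outputs)
    return total_correct / total_pred


# Two validation steps on two GPUs (illustrative values).
outputs = [
    {"n_correct_pred": [30, 28], "n_pred": [32, 32]},
    {"n_correct_pred": [14, 12], "n_pred": [16, 16]},
]
print(pooled_accuracy(outputs))  # prints 0.875
```

When every step gathers the same number of replicas, this matches what the two np.sum calls compute, since np.sum adds up every element of the stacked per-step arrays.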
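Expected behavior
Return values other than torch.Tensor (such as the Python ints produced by .item() and len(x)) should be allowed in the dict returned by validation_step.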
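The expected behavior can be illustrated with a simplified model of DataParallel's gather. The sketch below is hypothetical and is not how PyTorch or Lightning actually behave: Tensor is a stand-in for torch.Tensor, and instead of zipping plain Python numbers the gather keeps one number per replica in a list, so validation_step could return ints such as len(x). Other non-tensor types, such as strings, are not handled by this sketch.

```python
from numbers import Number


class Tensor:
    """Hypothetical stand-in for a tensor produced on one GPU replica."""

    def __init__(self, values):
        self.values = list(values)


def tolerant_gather_map(outputs):
    """Simplified gather that also accepts plain Python numbers."""
    out = outputs[0]
    if isinstance(out, Tensor):
        # Concatenate the per-replica values, like a real tensor gather.
        return Tensor(v for t in outputs for v in t.values)
    if out is None:
        return None
    if isinstance(out, Number):
        # Hypothetical behaviour: keep the per-replica numbers as a list.
        return list(outputs)
    if isinstance(out, dict):
        # Merge dicts key by key.
        return type(out)(
            (k, tolerant_gather_map([d[k] for d in outputs])) for k in out
        )
    # Sequences (lists, tuples) are merged element-wise.
    return type(out)(map(tolerant_gather_map, zip(*outputs)))


merged = tolerant_gather_map([
    {"val_loss": Tensor([0.5]), "n_correct_pred": 30, "n_pred": 32},
    {"val_loss": Tensor([0.7]), "n_correct_pred": 28, "n_pred": 32},
])
print(merged["val_loss"].values, merged["n_correct_pred"], merged["n_pred"])
# prints [0.5, 0.7] [30, 28] [32, 32]
```

Keeping numbers as per-replica lists leaves the reduction (sum, mean) to validation_epoch_end, which is where the original code already aggregates.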
Environment
PyTorch version: 1.5
OS: Ubuntu 18.04
How you installed PyTorch: conda
Python version: 3.7.7
CUDA/cuDNN version: 10.2
Additional context
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
🐛 Bug
To Reproduce
Steps to reproduce the behavior:
1. Go to ./pl_examples/basic_examples
2. Run python gpu_template.py --gpus 2 --distributed_backend dp
This error is related to this code in PyTorch's torch/nn/parallel/scatter_gather.py (https://github.com/pytorch/pytorch/blob/f4f0dd470c7eb51511194a52e87f0ceec5d4e05e/torch/nn/parallel/scatter_gather.py#L47).
It can be worked around by changing validation_step and validation_epoch_end in ./pl_examples/models/lightning_template.py as in the code sample above, but this approach is not elegant ...