
Commit 963f87a

Adrian Wälchli, Borda, and williamFalcon authored and committed
Type Hints for Lightning Core (Lightning-AI#946)
* first pass for LightningModule typehints
* fix return types
* add missing types
* add type annotations to grads.py
* add type annotations to hooks.py
* add type annotation to memory.py
* proper docstring quotation marks
* add type annotations to saving.py
* fix cyclic import problem
* fix cyclic import problem
* add missing whitespace
* finish type hints for load_from_ methods
* docs: prepare_data does not return anything
* fix auto types in docs
* revert typehint for trainer in hook
* remove unnecessary return docs
* some fixes for memory docs
* revert typing for args kwargs
* added all missing None return types
* remove unused import
* add more details to dict/list return types
* fix line too long
* optimize imports
* linted
* Revert "linted"
  This reverts commit 8555961.
* remove whitespace
* update
* update
* update
* update
* update
* changelog

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: William Falcon <[email protected]>
1 parent 1985c41 commit 963f87a
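Annotations like these are only exercised when something consumes them. As a minimal sketch of how the newly hinted package could be checked statically (this assumes mypy is installed and the snippet is run from the repository root; mypy itself is not part of this commit):

```python
# Minimal sketch: run mypy over the package that gained type hints.
# Assumes mypy is installed and this is executed from the repository root.
from mypy import api

stdout, stderr, exit_status = api.run(["pytorch_lightning/core"])
print(stdout)
if exit_status != 0:
    print("mypy reported type errors")
```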

File tree

6 files changed, +143 -120 lines changed


CHANGELOG.md

+1
```diff
@@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Added type hints to `pytorch_lightning.core` ([#946](https://github.com/PyTorchLightning/pytorch-lightning/pull/946))
 - Added support for IterableDataset in validation and testing ([#1104](https://github.com/PyTorchLightning/pytorch-lightning/pull/1104))
 
 ### Changed
```

pytorch_lightning/core/grads.py

+2 -1

```diff
@@ -1,13 +1,14 @@
 """
 Module to describe gradients
 """
+from typing import Dict
 
 from torch import nn
 
 
 class GradInformation(nn.Module):
 
-    def grad_norm(self, norm_type):
+    def grad_norm(self, norm_type: float) -> Dict[str, int]:
         results = {}
         total_norm = 0
         for name, p in self.named_parameters():
```
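The hunk above shows only the annotated signature and the first lines of the body. Purely as an illustration of the pattern, here is a standalone, fully typed gradient-norm helper; it is not the `GradInformation.grad_norm` implementation, the key format is an assumption, and it annotates the return as `Dict[str, float]` for the rounded values rather than the `Dict[str, int]` used in the hunk:

```python
from typing import Dict

import torch
from torch import nn


def grad_norm(module: nn.Module, norm_type: float) -> Dict[str, float]:
    """Compute per-parameter and total gradient norms for a module (illustrative sketch)."""
    results: Dict[str, float] = {}
    total_norm = torch.zeros(())
    for name, p in module.named_parameters():
        if p.grad is None:
            continue
        param_norm = p.grad.detach().norm(norm_type)
        results[f'grad_{norm_type}_norm_{name}'] = round(param_norm.item(), 3)
        total_norm += param_norm ** norm_type
    results[f'grad_{norm_type}_norm_total'] = round(total_norm.item() ** (1. / norm_type), 3)
    return results
```

A caller could then log `grad_norm(model, 2.0)` directly as a dictionary of scalar values.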

pytorch_lightning/core/hooks.py

+16 -23

```diff
@@ -14,10 +14,11 @@
 3. Add the correct place in the :py:mod:`pytorch_lightning.models.trainer` where it should be called.
 
 """
-
+from typing import Any
 
 import torch
-
+from torch import Tensor
+from torch.optim.optimizer import Optimizer
 
 try:
     from apex import amp
@@ -36,48 +37,45 @@ def on_sanity_check_start(self):
         :return:
         """
 
-    def on_train_start(self):
+    def on_train_start(self) -> None:
         """Called at the beginning of training before sanity check
-        :return:
         """
         # do something at the start of training
 
-    def on_train_end(self):
+    def on_train_end(self) -> None:
         """
         Called at the end of training before logger experiment is closed
-        :return:
         """
         # do something at the end of training
 
-    def on_batch_start(self, batch):
+    def on_batch_start(self, batch: Any) -> None:
         """Called in the training loop before anything happens for that batch.
 
         :param batch:
-        :return:
         """
         # do something when the batch starts
 
-    def on_batch_end(self):
+    def on_batch_end(self) -> None:
         """Called in the training loop after the batch."""
         # do something when the batch ends
 
-    def on_epoch_start(self):
+    def on_epoch_start(self) -> None:
         """Called in the training loop at the very beginning of the epoch."""
         # do something when the epoch starts
 
-    def on_epoch_end(self):
+    def on_epoch_end(self) -> None:
         """Called in the training loop at the very end of the epoch."""
         # do something when the epoch ends
 
-    def on_pre_performance_check(self):
+    def on_pre_performance_check(self) -> None:
         """Called at the very beginning of the validation loop."""
         # do something before validation starts
 
-    def on_post_performance_check(self):
+    def on_post_performance_check(self) -> None:
         """Called at the very end of the validation loop."""
         # do something before validation end
 
-    def on_before_zero_grad(self, optimizer):
+    def on_before_zero_grad(self, optimizer: Optimizer) -> None:
         """Called after optimizer.step() and before optimizer.zero_grad()
 
         Called in the training loop after taking an optimizer step and before zeroing grads.
@@ -89,17 +87,13 @@ def on_before_zero_grad(self, optimizer):
                 model.on_before_zero_grad(optimizer) # < ---- called here
                 optimizer.zero_grad
 
-        :param optimizer:
-        :return:
+        :param optimizer: The optimizer for which grads should be zeroed.
         """
         # do something with the optimizer or inspect it.
 
-    def on_after_backward(self):
-        """Called after loss.backward() and before optimizers do anything.
-
-        :return:
+    def on_after_backward(self) -> None:
+        """Called in the training loop after loss.backward() and before optimizers do anything.
 
-        Called in the training loop after model.backward()
         This is the ideal place to inspect or log gradient information
 
         .. code-block:: python
@@ -116,14 +110,13 @@ def on_after_backward(self):
 
         """
 
-    def backward(self, trainer, loss, optimizer, optimizer_idx):
+    def backward(self, trainer, loss: Tensor, optimizer: Optimizer, optimizer_idx: int) -> None:
         """Override backward with your own implementation if you need to
 
         :param trainer: Pointer to the trainer
         :param loss: Loss is already scaled by accumulated grads
         :param optimizer: Current optimizer being used
        :param optimizer_idx: Index of the current optimizer being used
-        :return:
 
         Called to perform backward step.
         Feel free to override as needed.
```
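With the hook signatures now carrying explicit types, user overrides can mirror them directly. A minimal sketch of what that might look like; the `LitModel` class and its print-based bodies are made up for illustration, only the hook names and signatures come from the diff above:

```python
from typing import Any

from torch.optim.optimizer import Optimizer

import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    # Hypothetical module; only the typed hook overrides matter here.

    def on_batch_start(self, batch: Any) -> None:
        # inspect the incoming batch before the training step runs
        print(type(batch))

    def on_before_zero_grad(self, optimizer: Optimizer) -> None:
        # e.g. read the current learning rate before grads are zeroed
        print(optimizer.param_groups[0]['lr'])

    def on_after_backward(self) -> None:
        # gradients are populated at this point, so they can be inspected or logged
        for name, p in self.named_parameters():
            if p.grad is not None:
                print(name, p.grad.norm().item())
```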
