Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

EvalResult support for val loop (PR 3/5) #2651

Merged
merged 116 commits into from
Jul 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
116 commits
Select commit Hold shift + click to select a range
094f8f7
orange berries
williamFalcon Jul 20, 2020
d81ee35
orange berries
williamFalcon Jul 20, 2020
c2edadb
logging metrics
williamFalcon Jul 21, 2020
675bb8c
logging metrics
williamFalcon Jul 21, 2020
62a1b91
logging metrics
williamFalcon Jul 21, 2020
28b101a
logging metrics
williamFalcon Jul 21, 2020
f8e48a8
logging metrics
williamFalcon Jul 21, 2020
0895daf
logging metrics
williamFalcon Jul 21, 2020
72fb07a
logging metrics
williamFalcon Jul 21, 2020
3b5d542
logging metrics
williamFalcon Jul 21, 2020
bdbb0a3
logging metrics
williamFalcon Jul 21, 2020
2756fe1
logging metrics
williamFalcon Jul 21, 2020
98bb6c7
logging metrics
williamFalcon Jul 21, 2020
a2a3559
logging metrics
williamFalcon Jul 21, 2020
296420e
logging metrics
williamFalcon Jul 21, 2020
ae9bc3f
logging metrics
williamFalcon Jul 21, 2020
ea680e5
logging metrics
williamFalcon Jul 21, 2020
27c9ce2
logging metrics
williamFalcon Jul 21, 2020
98cff7d
logging metrics
williamFalcon Jul 21, 2020
3e9d06a
logging metrics
williamFalcon Jul 21, 2020
24e3cf9
logging metrics
williamFalcon Jul 21, 2020
a0eeaff
logging metrics
williamFalcon Jul 21, 2020
9932eb1
logging metrics
williamFalcon Jul 21, 2020
a4843ff
logging metrics
williamFalcon Jul 21, 2020
05f7a21
logging metrics
williamFalcon Jul 21, 2020
9428044
logging metrics
williamFalcon Jul 21, 2020
27275fd
logging metrics
williamFalcon Jul 21, 2020
6450d7e
logging metrics
williamFalcon Jul 21, 2020
1e0b6c5
logging metrics
williamFalcon Jul 21, 2020
cd89c88
logging metrics
williamFalcon Jul 21, 2020
366aba6
logging metrics
williamFalcon Jul 21, 2020
617d110
logging metrics
williamFalcon Jul 21, 2020
8e9ea97
logging metrics
williamFalcon Jul 21, 2020
ad6c5cc
logging metrics
williamFalcon Jul 21, 2020
f7e7e8f
logging metrics
williamFalcon Jul 21, 2020
3f8f79d
logging metrics
williamFalcon Jul 21, 2020
abe2533
logging metrics
williamFalcon Jul 21, 2020
135008d
logging metrics
williamFalcon Jul 21, 2020
987df55
logging metrics
williamFalcon Jul 21, 2020
6c01070
logging metrics
williamFalcon Jul 21, 2020
eec349b
logging metrics
williamFalcon Jul 21, 2020
9c98ca8
logging metrics
williamFalcon Jul 21, 2020
0e7f820
logging metrics
williamFalcon Jul 21, 2020
b86d4c8
logging metrics
williamFalcon Jul 21, 2020
aff9513
logging metrics
williamFalcon Jul 21, 2020
0cda5ce
logging metrics
williamFalcon Jul 22, 2020
731de49
logging metrics
williamFalcon Jul 22, 2020
911b2d3
logging metrics
williamFalcon Jul 22, 2020
7093b41
logging metrics
williamFalcon Jul 22, 2020
350091b
logging metrics
williamFalcon Jul 22, 2020
a5e7c4f
logging metrics
williamFalcon Jul 22, 2020
508d32c
logging metrics
williamFalcon Jul 22, 2020
a3aa4e4
logging metrics
williamFalcon Jul 22, 2020
2d8eb84
logging metrics
williamFalcon Jul 22, 2020
19e35e3
logging metrics
williamFalcon Jul 22, 2020
049bcba
logging metrics
williamFalcon Jul 22, 2020
457044c
logging metrics
williamFalcon Jul 22, 2020
b7c2a71
logging metrics
williamFalcon Jul 22, 2020
21ee13b
logging metrics
williamFalcon Jul 22, 2020
5dc3bd2
logging metrics
williamFalcon Jul 22, 2020
6be625b
logging metrics
williamFalcon Jul 22, 2020
d1b35ab
logging metrics
williamFalcon Jul 22, 2020
eb7b55d
logging metrics
williamFalcon Jul 22, 2020
f392b6d
logging metrics
williamFalcon Jul 22, 2020
fefb3ff
logging metrics
williamFalcon Jul 22, 2020
58444f9
logging metrics
williamFalcon Jul 22, 2020
5786fb1
logging metrics
williamFalcon Jul 22, 2020
b10f867
logging metrics
williamFalcon Jul 22, 2020
ffda9f6
logging metrics
williamFalcon Jul 22, 2020
b27f3e3
logging metrics
williamFalcon Jul 22, 2020
ba1c8cb
logging metrics
williamFalcon Jul 22, 2020
8a35cb9
logging metrics
williamFalcon Jul 22, 2020
fd447f8
logging metrics
williamFalcon Jul 22, 2020
95a9177
logging metrics
williamFalcon Jul 22, 2020
e1ff092
logging metrics
williamFalcon Jul 22, 2020
8675f17
logging metrics
williamFalcon Jul 22, 2020
3cadf41
logging metrics
williamFalcon Jul 22, 2020
1c65951
logging metrics
williamFalcon Jul 22, 2020
2c42427
logging metrics
williamFalcon Jul 22, 2020
0a61dc9
logging metrics
williamFalcon Jul 22, 2020
71ade04
logging metrics
williamFalcon Jul 22, 2020
79bda13
logging metrics
williamFalcon Jul 22, 2020
0646a72
logging metrics
williamFalcon Jul 22, 2020
42f32f1
logging metrics
williamFalcon Jul 22, 2020
16fa4f6
logging metrics
williamFalcon Jul 22, 2020
e4ed1ee
logging metrics
williamFalcon Jul 22, 2020
4a9f8f5
logging metrics
williamFalcon Jul 22, 2020
695c27e
logging metrics
williamFalcon Jul 22, 2020
2d49484
logging metrics
williamFalcon Jul 22, 2020
fb5dafb
logging metrics
williamFalcon Jul 22, 2020
e64e65b
logging metrics
williamFalcon Jul 22, 2020
6971a93
logging metrics
williamFalcon Jul 22, 2020
83f4ffd
logging metrics
williamFalcon Jul 22, 2020
2d43ee1
logging metrics
williamFalcon Jul 22, 2020
040e271
logging metrics
williamFalcon Jul 22, 2020
e257d24
Update pytorch_lightning/callbacks/early_stopping.py
williamFalcon Jul 22, 2020
21bcec6
Update pytorch_lightning/callbacks/model_checkpoint.py
williamFalcon Jul 22, 2020
d0a81f3
Update pytorch_lightning/trainer/evaluation_loop.py
williamFalcon Jul 22, 2020
f25261d
logging metrics
williamFalcon Jul 22, 2020
b0a937a
logging metrics
williamFalcon Jul 22, 2020
c6f1ce3
passing model
Borda Jul 22, 2020
26464ec
Merge branch 'val_st' of https://github.com/PyTorchLightning/pytorch-…
Borda Jul 22, 2020
bd2f876
Update pytorch_lightning/callbacks/model_checkpoint.py
williamFalcon Jul 22, 2020
7667d5c
Update pytorch_lightning/core/step_result.py
williamFalcon Jul 22, 2020
240bac1
logging metrics
williamFalcon Jul 22, 2020
5391275
Merge branch 'val_st' of https://github.com/PyTorchLightning/pytorch-…
williamFalcon Jul 22, 2020
3847dd0
logging metrics
williamFalcon Jul 22, 2020
30ea127
logging metrics
williamFalcon Jul 22, 2020
b3d6373
logging metrics
williamFalcon Jul 22, 2020
78e9da2
logging metrics
williamFalcon Jul 22, 2020
9366c3a
logging metrics
williamFalcon Jul 22, 2020
254ed19
logging metrics
williamFalcon Jul 22, 2020
e14f6d4
logging metrics
williamFalcon Jul 22, 2020
028dd80
logging metrics
williamFalcon Jul 22, 2020
59539d4
logging metrics
williamFalcon Jul 22, 2020
052bd48
logging metrics
williamFalcon Jul 22, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 15 additions & 9 deletions pytorch_lightning/callbacks/early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,26 +134,32 @@ def load_state_dict(self, state_dict):
self.best_score = state_dict['best_score']
self.patience = state_dict['patience']

def on_sanity_check_end(self, trainer, pl_module):
logs = trainer.callback_metrics
self._validate_condition_metric(logs)

def on_validation_end(self, trainer, pl_module):
self._run_early_stopping_check(trainer, pl_module)

def on_validation_epoch_end(self, trainer, pl_module):
Comment on lines 137 to +140
Copy link
Contributor

@awaelchli awaelchli Jul 22, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since when do we have two of these callback methods? what is the difference?
I suspect on_validation_epoch_end is never actually called; I can only find calls to the other one.
also there are no tests in test_callbacks.py for this.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would use just one and if we assume that on_validation_epoch_end is a better name, deprecate the other one

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

well validation_end is deprecated so I assume on_validation_end shall be too @williamFalcon

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes exactly.

we can deprecate in the new PR. (these were the old names which we deprecated)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's add a deprecation warning as soon as possible so users get enough time...

val_es_key = 'val_early_stop_on'
if trainer.callback_metrics.get(val_es_key) is not None:
self.monitor = val_es_key

# disable strict checking when using structured results
if val_es_key in trainer.callback_metrics:
self.strict = False

self._validate_condition_metric(trainer.callback_metrics)

def on_train_epoch_end(self, trainer, pl_module):
# disable early stopping in train loop when there's a val loop
if self.monitor == 'val_early_stop_on':
return

# early stopping can also work in the train loop when there is no val loop and when using structured results
should_check_early_stop = False
train_es_key = 'early_stop_on'
if trainer.callback_metrics.get(train_es_key, None) is not None:
self.monitor = train_es_key
should_check_early_stop = True

val_es_key = 'val_early_stop_on'
if trainer.callback_metrics.get(val_es_key, None) is not None:
self.monitor = val_es_key
should_check_early_stop = True

if should_check_early_stop:
self._run_early_stopping_check(trainer, pl_module)

Expand Down
4 changes: 4 additions & 0 deletions pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,10 @@ def on_validation_end(self, trainer, pl_module):
if metrics.get('checkpoint_on') is not None:
self.monitor = 'checkpoint_on'

# conditioned val metrics override conditioned train loop metrics
if metrics.get('val_checkpoint_on') is not None:
self.monitor = 'val_checkpoint_on'

if self.save_top_k == 0:
# no models are saved
return
Expand Down
43 changes: 29 additions & 14 deletions pytorch_lightning/core/step_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,23 @@ def log(
if 'meta' not in self:
self.__setitem__('meta', {})

self.__set_meta(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx)

# set the value
self.__setitem__(name, value)
# if user requests both step and epoch, then we split the metric in two automatically
# one will be logged per step. the other per epoch
if on_step and on_epoch:
# set step version
step_name = f'step_{name}'
self.__set_meta(step_name, value, prog_bar, logger, on_step=True, on_epoch=False, reduce_fx=reduce_fx)
self.__setitem__(step_name, value)

# set epoch version
epoch_name = f'epoch_{name}'
self.__set_meta(epoch_name, value, prog_bar, logger, on_step=False, on_epoch=True, reduce_fx=reduce_fx)
self.__setitem__(epoch_name, value)
else:
self.__set_meta(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx)

# set the value
self.__setitem__(name, value)

def __set_meta(
self,
Expand All @@ -111,7 +124,7 @@ def __set_meta(
on_step: bool,
on_epoch: bool,
reduce_fx: Callable,
):
):
# set the meta for the item
meta_value = value
meta = dict(
Expand All @@ -122,6 +135,7 @@ def __set_meta(
reduce_fx=reduce_fx,
value=meta_value
)

self['meta'][name] = meta

# track whether any input requires reduction on epoch end
Expand Down Expand Up @@ -219,11 +233,13 @@ def __copy__(self):

@classmethod
def gather(cls, outputs):
meta = outputs[0]['meta']
meta = outputs[0].get('meta')
result = cls()
result = recursive_gather(outputs, result)
recursive_stack(result)
result['meta'] = meta

if meta:
result['meta'] = meta
return result

@classmethod
Expand Down Expand Up @@ -326,11 +342,10 @@ def log(
):
super().log(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph)

def get_callback_metrics(self) -> dict:
result = {
'val_early_stop_on': self.early_stop_on,
'val_checkpoint_on': self.checkpoint_on
}

# if __name__ == '__main__':
# import torch
# result = TrainResult()
# result.hiddens = torch.tensor(1)
# result.log('some', 123)
# print(result)
# result.minimize = torch.tensor(1)
return result
Loading