Skip to content

Commit 62ce00f

Browse files
EvalResult support for val loop (PR 3/5) (#2651)
* add EvalResult support to val/test loops
1 parent a3934ad commit 62ce00f

21 files changed

+991
-178
lines changed

pytorch_lightning/callbacks/early_stopping.py

+15-9
Original file line numberDiff line numberDiff line change
@@ -134,26 +134,32 @@ def load_state_dict(self, state_dict):
134134
self.best_score = state_dict['best_score']
135135
self.patience = state_dict['patience']
136136

137-
def on_sanity_check_end(self, trainer, pl_module):
138-
logs = trainer.callback_metrics
139-
self._validate_condition_metric(logs)
140-
141137
def on_validation_end(self, trainer, pl_module):
142138
self._run_early_stopping_check(trainer, pl_module)
143139

140+
def on_validation_epoch_end(self, trainer, pl_module):
141+
val_es_key = 'val_early_stop_on'
142+
if trainer.callback_metrics.get(val_es_key) is not None:
143+
self.monitor = val_es_key
144+
145+
# disable strict checking when using structured results
146+
if val_es_key in trainer.callback_metrics:
147+
self.strict = False
148+
149+
self._validate_condition_metric(trainer.callback_metrics)
150+
144151
def on_train_epoch_end(self, trainer, pl_module):
152+
# disable early stopping in train loop when there's a val loop
153+
if self.monitor == 'val_early_stop_on':
154+
return
155+
145156
# early stopping can also work in the train loop when there is no val loop and when using structured results
146157
should_check_early_stop = False
147158
train_es_key = 'early_stop_on'
148159
if trainer.callback_metrics.get(train_es_key, None) is not None:
149160
self.monitor = train_es_key
150161
should_check_early_stop = True
151162

152-
val_es_key = 'val_early_stop_on'
153-
if trainer.callback_metrics.get(val_es_key, None) is not None:
154-
self.monitor = val_es_key
155-
should_check_early_stop = True
156-
157163
if should_check_early_stop:
158164
self._run_early_stopping_check(trainer, pl_module)
159165

pytorch_lightning/callbacks/model_checkpoint.py

+4
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,10 @@ def on_validation_end(self, trainer, pl_module):
279279
if metrics.get('checkpoint_on') is not None:
280280
self.monitor = 'checkpoint_on'
281281

282+
# conditioned val metrics override conditioned train loop metrics
283+
if metrics.get('val_checkpoint_on') is not None:
284+
self.monitor = 'val_checkpoint_on'
285+
282286
if self.save_top_k == 0:
283287
# no models are saved
284288
return

pytorch_lightning/core/step_result.py

+29-14
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,23 @@ def log(
9797
if 'meta' not in self:
9898
self.__setitem__('meta', {})
9999

100-
self.__set_meta(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx)
101-
102-
# set the value
103-
self.__setitem__(name, value)
100+
# if user requests both step and epoch, then we split the metric in two automatically
101+
# one will be logged per step. the other per epoch
102+
if on_step and on_epoch:
103+
# set step version
104+
step_name = f'step_{name}'
105+
self.__set_meta(step_name, value, prog_bar, logger, on_step=True, on_epoch=False, reduce_fx=reduce_fx)
106+
self.__setitem__(step_name, value)
107+
108+
# set epoch version
109+
epoch_name = f'epoch_{name}'
110+
self.__set_meta(epoch_name, value, prog_bar, logger, on_step=False, on_epoch=True, reduce_fx=reduce_fx)
111+
self.__setitem__(epoch_name, value)
112+
else:
113+
self.__set_meta(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx)
114+
115+
# set the value
116+
self.__setitem__(name, value)
104117

105118
def __set_meta(
106119
self,
@@ -111,7 +124,7 @@ def __set_meta(
111124
on_step: bool,
112125
on_epoch: bool,
113126
reduce_fx: Callable,
114-
):
127+
):
115128
# set the meta for the item
116129
meta_value = value
117130
meta = dict(
@@ -122,6 +135,7 @@ def __set_meta(
122135
reduce_fx=reduce_fx,
123136
value=meta_value
124137
)
138+
125139
self['meta'][name] = meta
126140

127141
# track whether any input requires reduction on epoch end
@@ -219,11 +233,13 @@ def __copy__(self):
219233

220234
@classmethod
221235
def gather(cls, outputs):
222-
meta = outputs[0]['meta']
236+
meta = outputs[0].get('meta')
223237
result = cls()
224238
result = recursive_gather(outputs, result)
225239
recursive_stack(result)
226-
result['meta'] = meta
240+
241+
if meta:
242+
result['meta'] = meta
227243
return result
228244

229245
@classmethod
@@ -326,11 +342,10 @@ def log(
326342
):
327343
super().log(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph)
328344

345+
def get_callback_metrics(self) -> dict:
346+
result = {
347+
'val_early_stop_on': self.early_stop_on,
348+
'val_checkpoint_on': self.checkpoint_on
349+
}
329350

330-
# if __name__ == '__main__':
331-
# import torch
332-
# result = TrainResult()
333-
# result.hiddens = torch.tensor(1)
334-
# result.log('some', 123)
335-
# print(result)
336-
# result.minimize = torch.tensor(1)
351+
return result

0 commit comments

Comments
 (0)