@@ -71,21 +71,23 @@ class EarlyStopping(Callback):
71
71
Stop training when a monitored quantity has stopped improving.
72
72
73
73
Args:
74
- monitor (str): quantity to be monitored.
74
+ monitor (str): quantity to be monitored. Default: ``'val_loss'``.
75
75
min_delta (float): minimum change in the monitored quantity
76
76
to qualify as an improvement, i.e. an absolute
77
- change of less than min_delta, will count as no
78
- improvement.
77
+ change of less than ` min_delta` , will count as no
78
+ improvement. Default: ``0``.
79
79
patience (int): number of epochs with no improvement
80
- after which training will be stopped.
81
- verbose (bool): verbosity mode.
80
+ after which training will be stopped. Default: ``0``.
81
+ verbose (bool): verbosity mode. Default: ``0``.
82
82
mode (str): one of {auto, min, max}. In `min` mode,
83
83
training will stop when the quantity
84
84
monitored has stopped decreasing; in `max`
85
85
mode it will stop when the quantity
86
86
monitored has stopped increasing; in `auto`
87
87
mode, the direction is automatically inferred
88
- from the name of the monitored quantity.
88
+ from the name of the monitored quantity. Default: ``'auto'``.
89
+ strict (bool): whether to crash the training if `monitor` is
90
+ not found in the metrics. Default: ``True``.
89
91
90
92
Example::
91
93
@@ -97,18 +99,20 @@ class EarlyStopping(Callback):
97
99
"""
98
100
99
101
def __init__ (self , monitor = 'val_loss' ,
100
- min_delta = 0.0 , patience = 0 , verbose = 0 , mode = 'auto' ):
102
+ min_delta = 0.0 , patience = 0 , verbose = 0 , mode = 'auto' , strict = True ):
101
103
super (EarlyStopping , self ).__init__ ()
102
104
103
105
self .monitor = monitor
104
106
self .patience = patience
105
107
self .verbose = verbose
108
+ self .strict = strict
106
109
self .min_delta = min_delta
107
110
self .wait = 0
108
111
self .stopped_epoch = 0
109
112
110
113
if mode not in ['auto' , 'min' , 'max' ]:
111
- logging .info (f'EarlyStopping mode { mode } is unknown, fallback to auto mode.' )
114
+ if self .verbose > 0 :
115
+ logging .info (f'EarlyStopping mode { mode } is unknown, fallback to auto mode.' )
112
116
mode = 'auto'
113
117
114
118
if mode == 'min' :
@@ -128,23 +132,34 @@ def __init__(self, monitor='val_loss',
128
132
129
133
self .on_train_begin ()
130
134
135
+ def check_metrics (self , logs ):
136
+ monitor_val = logs .get (self .monitor )
137
+ error_msg = (f'Early stopping conditioned on metric `{ self .monitor } `'
138
+ f' which is not available. Available metrics are:'
139
+ f' `{ "`, `" .join (list (logs .keys ()))} `' )
140
+
141
+ if monitor_val is None :
142
+ if self .strict :
143
+ raise RuntimeError (error_msg )
144
+ elif self .verbose > 0 :
145
+ warnings .warn (error_msg , RuntimeWarning )
146
+
147
+ return False
148
+
149
+ return True
150
+
131
151
def on_train_begin (self , logs = None ):
132
152
# Allow instances to be re-used
133
153
self .wait = 0
134
154
self .stopped_epoch = 0
135
155
self .best = np .Inf if self .monitor_op == np .less else - np .Inf
136
156
137
157
def on_epoch_end (self , epoch , logs = None ):
138
- current = logs .get (self .monitor )
139
158
stop_training = False
140
- if current is None :
141
- warnings .warn (
142
- f'Early stopping conditioned on metric `{ self .monitor } `'
143
- f' which is not available. Available metrics are: { "," .join (list (logs .keys ()))} ' ,
144
- RuntimeWarning )
145
- stop_training = True
159
+ if not self .check_metrics (logs ):
146
160
return stop_training
147
161
162
+ current = logs .get (self .monitor )
148
163
if self .monitor_op (current - self .min_delta , self .best ):
149
164
self .best = current
150
165
self .wait = 0
0 commit comments