10
10
11
11
from torch .utils .data import DataLoader
12
12
13
- from ignite ._utils import _to_hours_mins_secs
14
13
from ignite .base import Serializable
15
14
from ignite .engine .events import CallableEventWithFilter , EventEnum , Events , EventsList , RemovableEventHandle , State
16
- from ignite .engine .utils import _check_signature
15
+ from ignite .engine .utils import _check_signature , _to_hours_mins_secs
17
16
18
17
__all__ = ["Engine" ]
19
18
@@ -242,6 +241,11 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
242
241
setattr (wrapper , "_parent" , weakref .ref (handler ))
243
242
return wrapper
244
243
244
+ def _assert_allowed_event (self , event_name : Any ) -> None :
245
+ if event_name not in self ._allowed_events :
246
+ self .logger .error (f"attempt to add event handler to an invalid event { event_name } " )
247
+ raise ValueError (f"Event { event_name } is not a valid event for this { self .__class__ .__name__ } ." )
248
+
245
249
def add_event_handler (self , event_name : Any , handler : Callable , * args : Any , ** kwargs : Any ) -> RemovableEventHandle :
246
250
"""Add an event handler to be executed when the specified event is fired.
247
251
@@ -297,9 +301,7 @@ def execute_something():
297
301
event_filter = event_name .filter
298
302
handler = self ._handler_wrapper (handler , event_name , event_filter )
299
303
300
- if event_name not in self ._allowed_events :
301
- self .logger .error ("attempt to add event handler to an invalid event %s." , event_name )
302
- raise ValueError (f"Event { event_name } is not a valid event for this Engine." )
304
+ self ._assert_allowed_event (event_name )
303
305
304
306
event_args = (Exception (),) if event_name == Events .EXCEPTION_RAISED else ()
305
307
try :
@@ -308,7 +310,7 @@ def execute_something():
308
310
except ValueError :
309
311
_check_signature (handler , "handler" , * (event_args + args ), ** kwargs )
310
312
self ._event_handlers [event_name ].append ((handler , args , kwargs ))
311
- self .logger .debug ("added handler for event %s." , event_name )
313
+ self .logger .debug (f "added handler for event { event_name } " )
312
314
313
315
return RemovableEventHandle (event_name , handler , self )
314
316
@@ -415,13 +417,12 @@ def _fire_event(self, event_name: Any, *event_args: Any, **event_kwargs: Any) ->
415
417
**event_kwargs: optional keyword args to be passed to all handlers.
416
418
417
419
"""
418
- if event_name in self ._allowed_events :
419
- self .logger .debug ("firing handlers for event %s " , event_name )
420
- self .last_event_name = event_name
421
- for func , args , kwargs in self ._event_handlers [event_name ]:
422
- kwargs .update (event_kwargs )
423
- first , others = ((args [0 ],), args [1 :]) if (args and args [0 ] == self ) else ((), args )
424
- func (* first , * (event_args + others ), ** kwargs )
420
+ self .logger .debug (f"firing handlers for event { event_name } " )
421
+ self .last_event_name = event_name
422
+ for func , args , kwargs in self ._event_handlers [event_name ]:
423
+ kwargs .update (event_kwargs )
424
+ first , others = ((args [0 ],), args [1 :]) if (args and args [0 ] == self ) else ((), args )
425
+ func (* first , * (event_args + others ), ** kwargs )
425
426
426
427
def fire_event (self , event_name : Any ) -> None :
427
428
"""Execute all the handlers associated with given event.
@@ -444,6 +445,7 @@ def fire_event(self, event_name: Any) -> None:
444
445
:meth:`~ignite.engine.engine.Engine.register_events`.
445
446
446
447
"""
448
+ self ._assert_allowed_event (event_name )
447
449
return self ._fire_event (event_name )
448
450
449
451
def terminate (self ) -> None :
@@ -765,9 +767,7 @@ def _internal_run(self) -> State:
765
767
# update time wrt handlers
766
768
self .state .times [Events .EPOCH_COMPLETED .name ] = time_taken
767
769
hours , mins , secs = _to_hours_mins_secs (time_taken )
768
- self .logger .info (
769
- "Epoch[%s] Complete. Time taken: %02d:%02d:%02d" % (self .state .epoch , hours , mins , secs )
770
- )
770
+ self .logger .info (f"Epoch[{ self .state .epoch } ] Complete. Time taken: { hours :02d} :{ mins :02d} :{ secs :02d} " )
771
771
if self .should_terminate :
772
772
break
773
773
@@ -780,11 +780,11 @@ def _internal_run(self) -> State:
780
780
# update time wrt handlers
781
781
self .state .times [Events .COMPLETED .name ] = time_taken
782
782
hours , mins , secs = _to_hours_mins_secs (time_taken )
783
- self .logger .info ("Engine run complete. Time taken: % 02d:% 02d:% 02d" % ( hours , mins , secs ) )
783
+ self .logger .info (f "Engine run complete. Time taken: { hours : 02d} : { mins : 02d} : { secs : 02d} " )
784
784
785
785
except BaseException as e :
786
786
self ._dataloader_iter = None
787
- self .logger .error ("Engine run is terminating due to exception: %s." , str ( e ) )
787
+ self .logger .error (f "Engine run is terminating due to exception: { e } " )
788
788
self ._handle_exception (e )
789
789
790
790
self ._dataloader_iter = None
@@ -869,7 +869,7 @@ def _run_once_on_dataset(self) -> float:
869
869
break
870
870
871
871
except Exception as e :
872
- self .logger .error ("Current run is terminating due to exception: %s." , str ( e ) )
872
+ self .logger .error (f "Current run is terminating due to exception: { e } " )
873
873
self ._handle_exception (e )
874
874
875
875
return time .time () - start_time
0 commit comments