Skip to content

Commit 3c8b781

Browse files
authored
Refactored assert allowed events (pytorch#1549)
- moved _utils._to_hours_mins_secs -> engine.utils._to_hours_mins_secs
1 parent 5ceacbb commit 3c8b781

File tree

3 files changed

+27
-29
lines changed

3 files changed

+27
-29
lines changed

ignite/_utils.py

-9
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,2 @@
1-
from typing import Tuple, Union
2-
31
# For compatibilty
42
from ignite.utils import apply_to_tensor, apply_to_type, convert_tensor, to_onehot
5-
6-
7-
def _to_hours_mins_secs(time_taken: Union[float, int]) -> Tuple[int, int, int]:
8-
"""Convert seconds to hours, mins, and seconds."""
9-
mins, secs = divmod(time_taken, 60)
10-
hours, mins = divmod(mins, 60)
11-
return round(hours), round(mins), round(secs)

ignite/engine/engine.py

+19-19
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,9 @@
1010

1111
from torch.utils.data import DataLoader
1212

13-
from ignite._utils import _to_hours_mins_secs
1413
from ignite.base import Serializable
1514
from ignite.engine.events import CallableEventWithFilter, EventEnum, Events, EventsList, RemovableEventHandle, State
16-
from ignite.engine.utils import _check_signature
15+
from ignite.engine.utils import _check_signature, _to_hours_mins_secs
1716

1817
__all__ = ["Engine"]
1918

@@ -242,6 +241,11 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
242241
setattr(wrapper, "_parent", weakref.ref(handler))
243242
return wrapper
244243

244+
def _assert_allowed_event(self, event_name: Any) -> None:
245+
if event_name not in self._allowed_events:
246+
self.logger.error(f"attempt to add event handler to an invalid event {event_name}")
247+
raise ValueError(f"Event {event_name} is not a valid event for this {self.__class__.__name__}.")
248+
245249
def add_event_handler(self, event_name: Any, handler: Callable, *args: Any, **kwargs: Any) -> RemovableEventHandle:
246250
"""Add an event handler to be executed when the specified event is fired.
247251
@@ -297,9 +301,7 @@ def execute_something():
297301
event_filter = event_name.filter
298302
handler = self._handler_wrapper(handler, event_name, event_filter)
299303

300-
if event_name not in self._allowed_events:
301-
self.logger.error("attempt to add event handler to an invalid event %s.", event_name)
302-
raise ValueError(f"Event {event_name} is not a valid event for this Engine.")
304+
self._assert_allowed_event(event_name)
303305

304306
event_args = (Exception(),) if event_name == Events.EXCEPTION_RAISED else ()
305307
try:
@@ -308,7 +310,7 @@ def execute_something():
308310
except ValueError:
309311
_check_signature(handler, "handler", *(event_args + args), **kwargs)
310312
self._event_handlers[event_name].append((handler, args, kwargs))
311-
self.logger.debug("added handler for event %s.", event_name)
313+
self.logger.debug(f"added handler for event {event_name}")
312314

313315
return RemovableEventHandle(event_name, handler, self)
314316

@@ -415,13 +417,12 @@ def _fire_event(self, event_name: Any, *event_args: Any, **event_kwargs: Any) ->
415417
**event_kwargs: optional keyword args to be passed to all handlers.
416418
417419
"""
418-
if event_name in self._allowed_events:
419-
self.logger.debug("firing handlers for event %s ", event_name)
420-
self.last_event_name = event_name
421-
for func, args, kwargs in self._event_handlers[event_name]:
422-
kwargs.update(event_kwargs)
423-
first, others = ((args[0],), args[1:]) if (args and args[0] == self) else ((), args)
424-
func(*first, *(event_args + others), **kwargs)
420+
self.logger.debug(f"firing handlers for event {event_name}")
421+
self.last_event_name = event_name
422+
for func, args, kwargs in self._event_handlers[event_name]:
423+
kwargs.update(event_kwargs)
424+
first, others = ((args[0],), args[1:]) if (args and args[0] == self) else ((), args)
425+
func(*first, *(event_args + others), **kwargs)
425426

426427
def fire_event(self, event_name: Any) -> None:
427428
"""Execute all the handlers associated with given event.
@@ -444,6 +445,7 @@ def fire_event(self, event_name: Any) -> None:
444445
:meth:`~ignite.engine.engine.Engine.register_events`.
445446
446447
"""
448+
self._assert_allowed_event(event_name)
447449
return self._fire_event(event_name)
448450

449451
def terminate(self) -> None:
@@ -765,9 +767,7 @@ def _internal_run(self) -> State:
765767
# update time wrt handlers
766768
self.state.times[Events.EPOCH_COMPLETED.name] = time_taken
767769
hours, mins, secs = _to_hours_mins_secs(time_taken)
768-
self.logger.info(
769-
"Epoch[%s] Complete. Time taken: %02d:%02d:%02d" % (self.state.epoch, hours, mins, secs)
770-
)
770+
self.logger.info(f"Epoch[{self.state.epoch}] Complete. Time taken: {hours:02d}:{mins:02d}:{secs:02d}")
771771
if self.should_terminate:
772772
break
773773

@@ -780,11 +780,11 @@ def _internal_run(self) -> State:
780780
# update time wrt handlers
781781
self.state.times[Events.COMPLETED.name] = time_taken
782782
hours, mins, secs = _to_hours_mins_secs(time_taken)
783-
self.logger.info("Engine run complete. Time taken: %02d:%02d:%02d" % (hours, mins, secs))
783+
self.logger.info(f"Engine run complete. Time taken: {hours:02d}:{mins:02d}:{secs:02d}")
784784

785785
except BaseException as e:
786786
self._dataloader_iter = None
787-
self.logger.error("Engine run is terminating due to exception: %s.", str(e))
787+
self.logger.error(f"Engine run is terminating due to exception: {e}")
788788
self._handle_exception(e)
789789

790790
self._dataloader_iter = None
@@ -869,7 +869,7 @@ def _run_once_on_dataset(self) -> float:
869869
break
870870

871871
except Exception as e:
872-
self.logger.error("Current run is terminating due to exception: %s.", str(e))
872+
self.logger.error(f"Current run is terminating due to exception: {e}")
873873
self._handle_exception(e)
874874

875875
return time.time() - start_time

ignite/engine/utils.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import inspect
2-
from typing import Any, Callable
2+
from typing import Any, Callable, Tuple, Union
33

44

55
def _check_signature(fn: Callable, fn_description: str, *args: Any, **kwargs: Any) -> None:
@@ -19,3 +19,10 @@ def _check_signature(fn: Callable, fn_description: str, *args: Any, **kwargs: An
1919
f"takes parameters {fn_params} but will be called with {passed_params}"
2020
f"({exception_msg})."
2121
)
22+
23+
24+
def _to_hours_mins_secs(time_taken: Union[float, int]) -> Tuple[int, int, int]:
25+
"""Convert seconds to hours, mins, and seconds."""
26+
mins, secs = divmod(time_taken, 60)
27+
hours, mins = divmod(mins, 60)
28+
return round(hours), round(mins), round(secs)

0 commit comments

Comments
 (0)