Skip to content

Commit 5ceacbb

Browse files
sdesrozisDesroziers
and
Desroziers
authored
Use events list for logger handlers (pytorch#1544)
* use events list for loggers * autopep8 fix * add test * fix mypy * add test to catch error * improve docstring Co-authored-by: Desroziers <[email protected]> Co-authored-by: sdesrozis <[email protected]>
1 parent f3fc875 commit 5ceacbb

File tree

4 files changed

+36
-13
lines changed

4 files changed

+36
-13
lines changed

ignite/contrib/handlers/base_logger.py

+16-8
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import torch.nn as nn
88
from torch.optim import Optimizer
99

10-
from ignite.engine import Engine, Events, State
10+
from ignite.engine import Engine, Events, EventsList, State
1111
from ignite.engine.events import CallableEventWithFilter, RemovableEventHandle
1212

1313

@@ -147,26 +147,34 @@ class BaseLogger(metaclass=ABCMeta):
147147
"""
148148

149149
def attach(
150-
self, engine: Engine, log_handler: Callable, event_name: Union[str, Events, CallableEventWithFilter]
150+
self, engine: Engine, log_handler: Callable, event_name: Union[str, Events, CallableEventWithFilter, EventsList]
151151
) -> RemovableEventHandle:
152152
"""Attach the logger to the engine and execute `log_handler` function at `event_name` events.
153153
154154
Args:
155155
engine (Engine): engine object.
156156
log_handler (callable): a logging handler to execute
157157
event_name: event to attach the logging handler to. Valid events are from
158-
:class:`~ignite.engine.events.Events` or any `event_name` added by
159-
:meth:`~ignite.engine.engine.Engine.register_events`.
158+
:class:`~ignite.engine.events.Events` or class:`~ignite.engine.events.EventsList` or any `event_name`
159+
added by :meth:`~ignite.engine.engine.Engine.register_events`.
160160
161161
Returns:
162162
:class:`~ignite.engine.RemovableEventHandle`, which can be used to remove the handler.
163163
"""
164-
name = event_name
164+
if isinstance(event_name, EventsList):
165+
for name in event_name:
166+
if name not in State.event_to_attr:
167+
raise RuntimeError(f"Unknown event name '{name}'")
168+
engine.add_event_handler(name, log_handler, self, name)
169+
170+
return RemovableEventHandle(event_name, log_handler, engine)
171+
172+
else:
165173

166-
if name not in State.event_to_attr:
167-
raise RuntimeError(f"Unknown event name '{name}'")
174+
if event_name not in State.event_to_attr:
175+
raise RuntimeError(f"Unknown event name '{event_name}'")
168176

169-
return engine.add_event_handler(event_name, log_handler, self, name)
177+
return engine.add_event_handler(event_name, log_handler, self, event_name)
170178

171179
def attach_output_handler(self, engine: Engine, event_name: Any, *args: Any, **kwargs: Any) -> RemovableEventHandle:
172180
"""Shortcut method to attach `OutputHandler` to the logger.

ignite/engine/__init__.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import ignite.distributed as idist
77
from ignite.engine.deterministic import DeterministicEngine
88
from ignite.engine.engine import Engine
9-
from ignite.engine.events import CallableEventWithFilter, EventEnum, Events, State
9+
from ignite.engine.events import CallableEventWithFilter, EventEnum, Events, EventsList, RemovableEventHandle, State
1010
from ignite.metrics import Metric
1111
from ignite.utils import convert_tensor
1212

@@ -21,8 +21,10 @@
2121
"Engine",
2222
"DeterministicEngine",
2323
"Events",
24+
"EventsList",
2425
"EventEnum",
2526
"CallableEventWithFilter",
27+
"RemovableEventHandle",
2628
]
2729

2830

ignite/engine/events.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
if TYPE_CHECKING:
1313
from ignite.engine.engine import Engine
1414

15-
__all__ = ["CallableEventWithFilter", "EventEnum", "Events", "State"]
15+
__all__ = ["CallableEventWithFilter", "EventEnum", "Events", "State", "EventsList", "RemovableEventHandle"]
1616

1717

1818
class CallableEventWithFilter:

tests/ignite/contrib/handlers/test_base_logger.py

+16-3
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import math
2-
from unittest.mock import MagicMock
2+
from unittest.mock import MagicMock, call
33

44
import pytest
55
import torch
66

77
from ignite.contrib.handlers.base_logger import BaseLogger, BaseOptimizerParamsHandler, BaseOutputHandler
8-
from ignite.engine import Engine, Events, State
8+
from ignite.engine import Engine, Events, EventsList, State
99
from tests.ignite.contrib.handlers import MockFP16DeepSpeedZeroOptimizer
1010

1111

@@ -122,7 +122,12 @@ def update_fn(engine, batch):
122122

123123
trainer.run(data, max_epochs=n_epochs)
124124

125-
mock_log_handler.assert_called_with(trainer, logger, event)
125+
if isinstance(event, EventsList):
126+
events = [e for e in event]
127+
else:
128+
events = [event]
129+
calls = [call(trainer, logger, e) for e in events]
130+
mock_log_handler.assert_has_calls(calls)
126131
assert mock_log_handler.call_count == n_calls
127132

128133
_test(Events.ITERATION_STARTED, len(data) * n_epochs)
@@ -134,6 +139,8 @@ def update_fn(engine, batch):
134139

135140
_test(Events.ITERATION_STARTED(every=10), len(data) // 10 * n_epochs)
136141

142+
_test(Events.STARTED | Events.COMPLETED, 2)
143+
137144

138145
def test_attach_wrong_event_name():
139146

@@ -144,6 +151,12 @@ def test_attach_wrong_event_name():
144151
with pytest.raises(RuntimeError, match="Unknown event name"):
145152
logger.attach(trainer, log_handler=mock_log_handler, event_name="unknown")
146153

154+
events_list = EventsList()
155+
events_list._events = ["unknown"]
156+
157+
with pytest.raises(RuntimeError, match="Unknown event name"):
158+
logger.attach(trainer, log_handler=mock_log_handler, event_name=events_list)
159+
147160

148161
def test_attach_on_custom_event():
149162
n_epochs = 10

0 commit comments

Comments
 (0)