Skip to content

Commit e3de7bd

Browse files
committed
Revert "Issue #1247 (#1252)"
This reverts commit b829473.
1 parent 945766b commit e3de7bd

File tree

5 files changed

+300
-1
lines changed

5 files changed

+300
-1
lines changed

ignite/contrib/handlers/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from ignite.contrib.handlers.clearml_logger import ClearMLLogger
2+
from ignite.contrib.handlers.custom_events import CustomPeriodicEvent
23
from ignite.contrib.handlers.lr_finder import FastaiLRFinder
34
from ignite.contrib.handlers.mlflow_logger import MLflowLogger
45
from ignite.contrib.handlers.neptune_logger import NeptuneLogger
+125
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
import warnings
2+
3+
from ignite.engine import EventEnum, Events, State
4+
5+
6+
class CustomPeriodicEvent:
7+
"""DEPRECATED. Use filtered events instead.
8+
Handler to define a custom periodic events as a number of elapsed iterations/epochs
9+
for an engine.
10+
11+
When custom periodic event is created and attached to an engine, the following events are fired:
12+
1) K iterations is specified:
13+
- `Events.ITERATIONS_<K>_STARTED`
14+
- `Events.ITERATIONS_<K>_COMPLETED`
15+
16+
1) K epochs is specified:
17+
- `Events.EPOCHS_<K>_STARTED`
18+
- `Events.EPOCHS_<K>_COMPLETED`
19+
20+
21+
Examples:
22+
23+
.. code-block:: python
24+
25+
from ignite.engine import Engine, Events
26+
from ignite.contrib.handlers import CustomPeriodicEvent
27+
28+
# Let's define an event every 1000 iterations
29+
cpe1 = CustomPeriodicEvent(n_iterations=1000)
30+
cpe1.attach(trainer)
31+
32+
# Let's define an event every 10 epochs
33+
cpe2 = CustomPeriodicEvent(n_epochs=10)
34+
cpe2.attach(trainer)
35+
36+
@trainer.on(cpe1.Events.ITERATIONS_1000_COMPLETED)
37+
def on_every_1000_iterations(engine):
38+
# run a computation after 1000 iterations
39+
# ...
40+
print(engine.state.iterations_1000)
41+
42+
@trainer.on(cpe2.Events.EPOCHS_10_STARTED)
43+
def on_every_10_epochs(engine):
44+
# run a computation every 10 epochs
45+
# ...
46+
print(engine.state.epochs_10)
47+
48+
49+
Args:
50+
n_iterations (int, optional): number iterations of the custom periodic event
51+
n_epochs (int, optional): number iterations of the custom periodic event. Argument is optional, but only one,
52+
either n_iterations or n_epochs should defined.
53+
54+
"""
55+
56+
def __init__(self, n_iterations=None, n_epochs=None):
57+
58+
warnings.warn(
59+
"CustomPeriodicEvent is deprecated since 0.4.0 and will be removed in 0.5.0. Use filtered events instead.",
60+
DeprecationWarning,
61+
)
62+
63+
if n_iterations is not None:
64+
if not isinstance(n_iterations, int):
65+
raise TypeError("Argument n_iterations should be an integer")
66+
if n_iterations < 1:
67+
raise ValueError("Argument n_iterations should be positive")
68+
69+
if n_epochs is not None:
70+
if not isinstance(n_epochs, int):
71+
raise TypeError("Argument n_epochs should be an integer")
72+
if n_epochs < 1:
73+
raise ValueError("Argument n_epochs should be positive")
74+
75+
if (n_iterations is None and n_epochs is None) or (n_iterations and n_epochs):
76+
raise ValueError("Either n_iterations or n_epochs should be defined")
77+
78+
if n_iterations:
79+
prefix = "iterations"
80+
self.state_attr = "iteration"
81+
self.period = n_iterations
82+
83+
if n_epochs:
84+
prefix = "epochs"
85+
self.state_attr = "epoch"
86+
self.period = n_epochs
87+
88+
self.custom_state_attr = "{}_{}".format(prefix, self.period)
89+
event_name = "{}_{}".format(prefix.upper(), self.period)
90+
setattr(
91+
self,
92+
"Events",
93+
EventEnum("Events", " ".join(["{}_STARTED".format(event_name), "{}_COMPLETED".format(event_name)])),
94+
)
95+
96+
# Update State.event_to_attr
97+
for e in self.Events:
98+
State.event_to_attr[e] = self.custom_state_attr
99+
100+
# Create aliases
101+
self._periodic_event_started = getattr(self.Events, "{}_STARTED".format(event_name))
102+
self._periodic_event_completed = getattr(self.Events, "{}_COMPLETED".format(event_name))
103+
104+
def _on_started(self, engine):
105+
setattr(engine.state, self.custom_state_attr, 0)
106+
107+
def _on_periodic_event_started(self, engine):
108+
if getattr(engine.state, self.state_attr) % self.period == 1:
109+
setattr(engine.state, self.custom_state_attr, getattr(engine.state, self.custom_state_attr) + 1)
110+
engine.fire_event(self._periodic_event_started)
111+
112+
def _on_periodic_event_completed(self, engine):
113+
if getattr(engine.state, self.state_attr) % self.period == 0:
114+
engine.fire_event(self._periodic_event_completed)
115+
116+
def attach(self, engine):
117+
engine.register_events(*self.Events)
118+
119+
engine.add_event_handler(Events.STARTED, self._on_started)
120+
engine.add_event_handler(
121+
getattr(Events, "{}_STARTED".format(self.state_attr.upper())), self._on_periodic_event_started
122+
)
123+
engine.add_event_handler(
124+
getattr(Events, "{}_COMPLETED".format(self.state_attr.upper())), self._on_periodic_event_completed
125+
)

tests/ignite/contrib/handlers/test_base_logger.py

+29
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from unittest.mock import MagicMock, call
22

3+
import math
34
import pytest
45
import torch
56

7+
from ignite.contrib.handlers import CustomPeriodicEvent
68
from ignite.contrib.handlers.base_logger import BaseLogger, BaseOptimizerParamsHandler, BaseOutputHandler
79
from ignite.engine import Engine, Events, EventsList, State
810
from tests.ignite.contrib.handlers import MockFP16DeepSpeedZeroOptimizer
@@ -183,6 +185,33 @@ def update_fn(engine, batch):
183185
mock_log_handler.assert_called_with(trainer, logger, event)
184186
assert mock_log_handler.call_count == n_calls
185187

188+
with pytest.warns(DeprecationWarning, match="CustomPeriodicEvent is deprecated"):
189+
n_iterations = 10
190+
cpe1 = CustomPeriodicEvent(n_iterations=n_iterations)
191+
n = len(data) * n_epochs / n_iterations
192+
nf = math.floor(n)
193+
ns = nf + 1 if nf < n else nf
194+
_test(cpe1.Events.ITERATIONS_10_STARTED, ns, cpe1)
195+
_test(cpe1.Events.ITERATIONS_10_COMPLETED, nf, cpe1)
196+
197+
with pytest.warns(DeprecationWarning, match="CustomPeriodicEvent is deprecated"):
198+
n_iterations = 15
199+
cpe2 = CustomPeriodicEvent(n_iterations=n_iterations)
200+
n = len(data) * n_epochs / n_iterations
201+
nf = math.floor(n)
202+
ns = nf + 1 if nf < n else nf
203+
_test(cpe2.Events.ITERATIONS_15_STARTED, ns, cpe2)
204+
_test(cpe2.Events.ITERATIONS_15_COMPLETED, nf, cpe2)
205+
206+
with pytest.warns(DeprecationWarning, match="CustomPeriodicEvent is deprecated"):
207+
n_custom_epochs = 2
208+
cpe3 = CustomPeriodicEvent(n_epochs=n_custom_epochs)
209+
n = n_epochs / n_custom_epochs
210+
nf = math.floor(n)
211+
ns = nf + 1 if nf < n else nf
212+
_test(cpe3.Events.EPOCHS_2_STARTED, ns, cpe3)
213+
_test(cpe3.Events.EPOCHS_2_COMPLETED, nf, cpe3)
214+
186215

187216
def test_as_context_manager():
188217

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
import math
2+
3+
import pytest
4+
5+
from ignite.contrib.handlers.custom_events import CustomPeriodicEvent
6+
from ignite.engine import Engine
7+
8+
9+
def test_bad_input():
10+
11+
with pytest.warns(DeprecationWarning, match=r"CustomPeriodicEvent is deprecated"):
12+
with pytest.raises(TypeError, match="Argument n_iterations should be an integer"):
13+
CustomPeriodicEvent(n_iterations="a")
14+
with pytest.raises(ValueError, match="Argument n_iterations should be positive"):
15+
CustomPeriodicEvent(n_iterations=0)
16+
with pytest.raises(TypeError, match="Argument n_iterations should be an integer"):
17+
CustomPeriodicEvent(n_iterations=10.0)
18+
with pytest.raises(TypeError, match="Argument n_epochs should be an integer"):
19+
CustomPeriodicEvent(n_epochs="a")
20+
with pytest.raises(ValueError, match="Argument n_epochs should be positive"):
21+
CustomPeriodicEvent(n_epochs=0)
22+
with pytest.raises(TypeError, match="Argument n_epochs should be an integer"):
23+
CustomPeriodicEvent(n_epochs=10.0)
24+
with pytest.raises(ValueError, match="Either n_iterations or n_epochs should be defined"):
25+
CustomPeriodicEvent()
26+
with pytest.raises(ValueError, match="Either n_iterations or n_epochs should be defined"):
27+
CustomPeriodicEvent(n_iterations=1, n_epochs=2)
28+
29+
30+
def test_new_events():
31+
def update(*args, **kwargs):
32+
pass
33+
34+
with pytest.warns(DeprecationWarning, match="CustomPeriodicEvent is deprecated"):
35+
engine = Engine(update)
36+
cpe = CustomPeriodicEvent(n_iterations=5)
37+
cpe.attach(engine)
38+
39+
assert hasattr(cpe, "Events")
40+
assert hasattr(cpe.Events, "ITERATIONS_5_STARTED")
41+
assert hasattr(cpe.Events, "ITERATIONS_5_COMPLETED")
42+
43+
assert engine._allowed_events[-2] == getattr(cpe.Events, "ITERATIONS_5_STARTED")
44+
assert engine._allowed_events[-1] == getattr(cpe.Events, "ITERATIONS_5_COMPLETED")
45+
46+
with pytest.warns(DeprecationWarning, match="CustomPeriodicEvent is deprecated"):
47+
cpe = CustomPeriodicEvent(n_epochs=5)
48+
cpe.attach(engine)
49+
50+
assert hasattr(cpe, "Events")
51+
assert hasattr(cpe.Events, "EPOCHS_5_STARTED")
52+
assert hasattr(cpe.Events, "EPOCHS_5_COMPLETED")
53+
54+
assert engine._allowed_events[-2] == getattr(cpe.Events, "EPOCHS_5_STARTED")
55+
assert engine._allowed_events[-1] == getattr(cpe.Events, "EPOCHS_5_COMPLETED")
56+
57+
58+
def test_integration_iterations():
59+
def _test(n_iterations, max_epochs, n_iters_per_epoch):
60+
def update(*args, **kwargs):
61+
pass
62+
63+
engine = Engine(update)
64+
with pytest.warns(DeprecationWarning, match="CustomPeriodicEvent is deprecated"):
65+
cpe = CustomPeriodicEvent(n_iterations=n_iterations)
66+
cpe.attach(engine)
67+
data = list(range(n_iters_per_epoch))
68+
69+
custom_period = [0]
70+
n_calls_iter_started = [0]
71+
n_calls_iter_completed = [0]
72+
73+
event_started = getattr(cpe.Events, "ITERATIONS_{}_STARTED".format(n_iterations))
74+
75+
@engine.on(event_started)
76+
def on_my_event_started(engine):
77+
assert (engine.state.iteration - 1) % n_iterations == 0
78+
custom_period[0] += 1
79+
custom_iter = getattr(engine.state, "iterations_{}".format(n_iterations))
80+
assert custom_iter == custom_period[0]
81+
n_calls_iter_started[0] += 1
82+
83+
event_completed = getattr(cpe.Events, "ITERATIONS_{}_COMPLETED".format(n_iterations))
84+
85+
@engine.on(event_completed)
86+
def on_my_event_ended(engine):
87+
assert engine.state.iteration % n_iterations == 0
88+
custom_iter = getattr(engine.state, "iterations_{}".format(n_iterations))
89+
assert custom_iter == custom_period[0]
90+
n_calls_iter_completed[0] += 1
91+
92+
engine.run(data, max_epochs=max_epochs)
93+
94+
n = len(data) * max_epochs / n_iterations
95+
nf = math.floor(n)
96+
assert custom_period[0] == n_calls_iter_started[0]
97+
assert n_calls_iter_started[0] == nf + 1 if nf < n else nf
98+
assert n_calls_iter_completed[0] == nf
99+
100+
_test(3, 5, 16)
101+
_test(4, 5, 16)
102+
_test(5, 5, 16)
103+
_test(300, 50, 1000)
104+
105+
106+
def test_integration_epochs():
107+
def update(*args, **kwargs):
108+
pass
109+
110+
engine = Engine(update)
111+
112+
n_epochs = 3
113+
with pytest.warns(DeprecationWarning, match="CustomPeriodicEvent is deprecated"):
114+
cpe = CustomPeriodicEvent(n_epochs=n_epochs)
115+
cpe.attach(engine)
116+
data = list(range(16))
117+
118+
custom_period = [1]
119+
120+
@engine.on(cpe.Events.EPOCHS_3_STARTED)
121+
def on_my_epoch_started(engine):
122+
assert (engine.state.epoch - 1) % n_epochs == 0
123+
assert engine.state.epochs_3 == custom_period[0]
124+
125+
@engine.on(cpe.Events.EPOCHS_3_COMPLETED)
126+
def on_my_epoch_ended(engine):
127+
assert engine.state.epoch % n_epochs == 0
128+
assert engine.state.epochs_3 == custom_period[0]
129+
custom_period[0] += 1
130+
131+
engine.run(data, max_epochs=10)
132+
133+
assert custom_period[0] == 4

tests/ignite/contrib/handlers/test_tqdm_logger.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import pytest
1010
import torch
1111

12-
from ignite.contrib.handlers import ProgressBar
12+
from ignite.contrib.handlers import CustomPeriodicEvent, ProgressBar
1313
from ignite.engine import Engine, Events
1414
from ignite.handlers import TerminateOnNan
1515
from ignite.metrics import RunningAverage
@@ -439,6 +439,17 @@ def test_pbar_wrong_events_order():
439439
pbar.attach(engine, event_name=Events.ITERATION_STARTED, closing_event_name=Events.EPOCH_COMPLETED(every=10))
440440

441441

442+
def test_pbar_on_custom_events(capsys):
443+
444+
engine = Engine(update_fn)
445+
pbar = ProgressBar()
446+
with pytest.warns(DeprecationWarning, match="CustomPeriodicEvent is deprecated"):
447+
cpe = CustomPeriodicEvent(n_iterations=15)
448+
449+
with pytest.raises(ValueError, match=r"not in allowed events for this engine"):
450+
pbar.attach(engine, event_name=cpe.Events.ITERATIONS_15_COMPLETED, closing_event_name=Events.EPOCH_COMPLETED)
451+
452+
442453
def test_pbar_with_nan_input():
443454
def update(engine, batch):
444455
x = batch

0 commit comments

Comments
 (0)