Skip to content

Commit 458d3e2

Browse files
ethanwharrisBorda
andauthored
Add missing methods to logger collection (#2723)
* Add missing methods to logger collection * Update CHANGELOG.md * Fix errors after merge * Fix codefactor issues * Update CHANGELOG.md Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: Jirka Borovec <[email protected]>
1 parent 1767822 commit 458d3e2

File tree

3 files changed

+39
-5
lines changed

3 files changed

+39
-5
lines changed

CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
4141

4242
- Fixed TPU multi-core and Float16 ([#2632](https://github.com/PyTorchLightning/pytorch-lightning/pull/2632))
4343

44+
- Fixed test metrics not being logged with `LoggerCollection` ([#2723](https://github.com/PyTorchLightning/pytorch-lightning/pull/2723))
45+
4446
## [0.8.5] - 2020-07-09
4547

4648
### Added

pytorch_lightning/loggers/base.py

+27-5
Original file line numberDiff line numberDiff line change
@@ -272,24 +272,46 @@ def __init__(self, logger_iterable: Iterable[LightningLoggerBase]):
272272
def __getitem__(self, index: int) -> LightningLoggerBase:
273273
return [logger for logger in self._logger_iterable][index]
274274

275+
def update_agg_funcs(
276+
self,
277+
agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
278+
agg_default_func: Callable[[Sequence[float]], float] = np.mean
279+
):
280+
for logger in self._logger_iterable:
281+
logger.update_agg_funcs(agg_key_funcs, agg_default_func)
282+
275283
@property
276284
def experiment(self) -> List[Any]:
277285
return [logger.experiment for logger in self._logger_iterable]
278286

287+
def agg_and_log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
288+
for logger in self._logger_iterable:
289+
logger.agg_and_log_metrics(metrics, step)
290+
279291
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
280-
[logger.log_metrics(metrics, step) for logger in self._logger_iterable]
292+
for logger in self._logger_iterable:
293+
logger.log_metrics(metrics, step)
281294

282295
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
283-
[logger.log_hyperparams(params) for logger in self._logger_iterable]
296+
for logger in self._logger_iterable:
297+
logger.log_hyperparams(params)
284298

285299
def save(self) -> None:
286-
[logger.save() for logger in self._logger_iterable]
300+
for logger in self._logger_iterable:
301+
logger.save()
287302

288303
def finalize(self, status: str) -> None:
289-
[logger.finalize(status) for logger in self._logger_iterable]
304+
for logger in self._logger_iterable:
305+
logger.finalize(status)
290306

291307
def close(self) -> None:
292-
[logger.close() for logger in self._logger_iterable]
308+
for logger in self._logger_iterable:
309+
logger.close()
310+
311+
@property
312+
def save_dir(self) -> Optional[str]:
313+
# Checkpoints should be saved to default / chosen location when using multiple loggers
314+
return None
293315

294316
@property
295317
def name(self) -> str:

tests/loggers/test_base.py

+10
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,16 @@ def test_logger_collection():
2222
assert logger.experiment[0] == mock1.experiment
2323
assert logger.experiment[1] == mock2.experiment
2424

25+
assert logger.save_dir is None
26+
27+
logger.update_agg_funcs({'test': np.mean}, np.sum)
28+
mock1.update_agg_funcs.assert_called_once_with({'test': np.mean}, np.sum)
29+
mock2.update_agg_funcs.assert_called_once_with({'test': np.mean}, np.sum)
30+
31+
logger.agg_and_log_metrics({'test': 2.0}, 4)
32+
mock1.agg_and_log_metrics.assert_called_once_with({'test': 2.0}, 4)
33+
mock2.agg_and_log_metrics.assert_called_once_with({'test': 2.0}, 4)
34+
2535
logger.close()
2636
mock1.close.assert_called_once()
2737
mock2.close.assert_called_once()

0 commit comments

Comments
 (0)