Skip to content

Commit edb8d7a

Browse files
Nested metrics dictionaries now can be passed to the loggers (#1582)
* now func merge_dicts works with nested dictionaries * CHANGELOG.md upd
1 parent 94e5344 commit edb8d7a

File tree

2 files changed

+20
-9
lines changed

2 files changed

+20
-9
lines changed

CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
5050

5151
### Fixed
5252

53+
- Added the possibility to pass nested metrics dictionaries to loggers ([#1582](https://github.com/PyTorchLightning/pytorch-lightning/pull/1582))
54+
5355
- Fixed memory leak from opt return ([#1528](https://github.com/PyTorchLightning/pytorch-lightning/pull/1528))
5456

5557
- Fixed saving checkpoint before deleting old ones ([#1453](https://github.com/PyTorchLightning/pytorch-lightning/pull/1453))

pytorch_lightning/loggers/base.py

+18-9
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,7 @@ class LoggerCollection(LightningLoggerBase):
280280
Args:
281281
logger_iterable: An iterable collection of loggers
282282
"""
283+
283284
def __init__(self, logger_iterable: Iterable[LightningLoggerBase]):
284285
super().__init__()
285286
self._logger_iterable = logger_iterable
@@ -347,20 +348,28 @@ def merge_dicts(
347348
348349
Examples:
349350
>>> import pprint
350-
>>> d1 = {'a': 1.7, 'b': 2.0, 'c': 1}
351-
>>> d2 = {'a': 1.1, 'b': 2.2, 'v': 1}
352-
>>> d3 = {'a': 1.1, 'v': 2.3}
351+
>>> d1 = {'a': 1.7, 'b': 2.0, 'c': 1, 'd': {'d1': 1, 'd3': 3}}
352+
>>> d2 = {'a': 1.1, 'b': 2.2, 'v': 1, 'd': {'d1': 2, 'd2': 3}}
353+
>>> d3 = {'a': 1.1, 'v': 2.3, 'd': {'d3': 3, 'd4': {'d5': 1}}}
353354
>>> dflt_func = min
354-
>>> agg_funcs = {'a': np.mean, 'v': max}
355+
>>> agg_funcs = {'a': np.mean, 'v': max, 'd': {'d1': sum}}
355356
>>> pprint.pprint(merge_dicts([d1, d2, d3], agg_funcs, dflt_func))
356-
{'a': 1.3, 'b': 2.0, 'c': 1, 'v': 2.3}
357+
{'a': 1.3,
358+
'b': 2.0,
359+
'c': 1,
360+
'd': {'d1': 3, 'd2': 3, 'd3': 3, 'd4': {'d5': 1}},
361+
'v': 2.3}
357362
"""
358-
363+
agg_key_funcs = agg_key_funcs or dict()
359364
keys = list(functools.reduce(operator.or_, [set(d.keys()) for d in dicts]))
360365
d_out = {}
361366
for k in keys:
362-
fn = agg_key_funcs.get(k, default_func) if agg_key_funcs else default_func
363-
agg_val = fn([v for v in [d_in.get(k) for d_in in dicts] if v is not None])
364-
d_out[k] = agg_val
367+
fn = agg_key_funcs.get(k)
368+
values_to_agg = [v for v in [d_in.get(k) for d_in in dicts] if v is not None]
369+
370+
if isinstance(values_to_agg[0], dict):
371+
d_out[k] = merge_dicts(values_to_agg, fn, default_func)
372+
else:
373+
d_out[k] = (fn or default_func)(values_to_agg)
365374

366375
return d_out

0 commit comments

Comments
 (0)