Skip to content

Commit 899cd74

Browse files
authored
flatten Wandb hyperparameters dict (#2459)
* wandb logging fix * Changelog fix * change test
1 parent 7ef73f2 commit 899cd74

File tree

3 files changed

+9
-3
lines changed

3 files changed

+9
-3
lines changed

CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2121

2222
### Fixed
2323

24+
- Flattening Wandb Hyperparameters ([#2459](https://github.com/PyTorchLightning/pytorch-lightning/pull/2459))
25+
2426
- Fixed using the same DDP python interpreter and actually running ([#2482](https://github.com/PyTorchLightning/pytorch-lightning/pull/2482))
2527

2628
- Fixed model summary input type conversion for models that have input dtype different from model parameters ([#2510](https://github.com/PyTorchLightning/pytorch-lightning/pull/2510))

pytorch_lightning/loggers/wandb.py

+1
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ def watch(self, model: nn.Module, log: str = 'gradients', log_freq: int = 100):
125125
@rank_zero_only
126126
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
127127
params = self._convert_params(params)
128+
params = self._flatten_dict(params)
128129
self.experiment.config.update(params, allow_val_change=True)
129130

130131
@rank_zero_only

tests/loggers/test_wandb.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,12 @@ def test_wandb_logger(wandb):
1919
logger.log_metrics({'acc': 1.0}, step=3)
2020
wandb.init().log.assert_called_once_with({'global_step': 3, 'acc': 1.0})
2121

22-
logger.log_hyperparams({'test': None})
23-
wandb.init().config.update.assert_called_once_with({'test': None}, allow_val_change=True)
24-
22+
logger.log_hyperparams({'test': None, 'nested': {'a': 1}, 'b': [2, 3, 4]})
23+
wandb.init().config.update.assert_called_once_with(
24+
{'test': 'None', 'nested/a': 1, 'b': [2, 3, 4]},
25+
allow_val_change=True,
26+
)
27+
2528
logger.watch('model', 'log', 10)
2629
wandb.init().watch.assert_called_once_with('model', log='log', log_freq=10)
2730

0 commit comments

Comments
 (0)