Lightning supports many different experiment loggers. These loggers allow you to monitor losses, images, text, etc. as training progresses. They usually provide a GUI for visualization and can sometimes even snapshot the hyperparameters used in each experiment.
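For example, a minimal sketch of creating one of the built-in loggers and passing it to the Trainer might look like this (the save_dir and name below are just placeholders):

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger

# save_dir and name are placeholders; pick any location and experiment name
logger = TensorBoardLogger(save_dir='tb_logs', name='my_model')
trainer = Trainer(logger=logger)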
Logging every single batch may slow training down, so the Trainer has an option to log every k batches instead.
k = 10
Trainer(row_log_interval=k)
Writing to a logger can be expensive, so in Lightning you can set how often logs are written to disk using this Trainer flag.
Note: See :ref:`trainer`
k = 100
Trainer(log_save_interval=k)
To plot metrics into whatever logger you passed in (TensorBoard, Comet, Neptune, etc.):
1. training_end, validation_end and test_end will all log anything in the "log" key of the return dict.
def training_end(self, outputs):
    loss = some_loss()
    ...
    logs = {'train_loss': loss}
    results = {'log': logs}
    return results

def validation_end(self, outputs):
    loss = some_loss()
    ...
    logs = {'val_loss': loss}
    results = {'log': logs}
    return results

def test_end(self, outputs):
    loss = some_loss()
    ...
    logs = {'test_loss': loss}
    results = {'log': logs}
    return results
2. Most of the time, you only need training_step and not training_end. You can also return logs from here:
def training_step(self, batch, batch_idx):
    loss = some_loss()
    ...
    logs = {'train_loss': loss}
    results = {'log': logs}
    return results
3. In addition, you can use any arbitrary functionality of a particular logger from within your LightningModule. For instance, here we log images using TensorBoard.
def training_step(self, batch, batch_idx):
    self.generated_imgs = self.decoder.generate()

    sample_imgs = self.generated_imgs[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('generated_images', grid, 0)

    ...
    return results
Each return dict from training_end, validation_end, test_end and training_step can also contain a key called "progress_bar".
Here we show the validation loss in the progress bar:
def validation_end(self, outputs):
    loss = some_loss()
    ...
    logs = {'val_loss': loss}
    results = {'progress_bar': logs}
    return results
When training a model, it's useful to know what hyperparams went into that model. When Lightning creates a checkpoint, it stores a key "hparams" with the hyperparams.
# map_location loads the checkpoint onto CPU so no GPU is required
lightning_checkpoint = torch.load(filepath, map_location=lambda storage, loc: storage)
hyperparams = lightning_checkpoint['hparams']
Some loggers also allow logging the hyperparams used in the experiment. For instance, when using the TestTubeLogger or the TensorBoardLogger, all hyperparams will show in the hparams tab.
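As a minimal sketch of how those hyperparams typically get there: if your LightningModule is constructed with an hparams argument and stores it as self.hparams, Lightning saves it in the checkpoint and passes it to the logger (the module and hyperparameter names below are only illustrative):

from argparse import Namespace

import pytorch_lightning as pl

class LitModel(pl.LightningModule):
    def __init__(self, hparams):
        super().__init__()
        # keeping hparams on the module lets Lightning store them in checkpoints
        # and hand them to the logger (e.g. the hparams tab in TensorBoard)
        self.hparams = hparams

# illustrative hyperparameters
model = LitModel(Namespace(learning_rate=0.001, batch_size=32))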
Loggers also allow you to snapshot a copy of the code used in this experiment. For example, TestTubeLogger does this with a flag:
from pytorch_lightning.loggers import TestTubeLogger

# save_dir controls where the experiment (and code snapshot) is written
logger = TestTubeLogger(save_dir='.', create_git_tag=True)