Log training metrics for each epoch #914

Closed
jbschiratti opened this issue Feb 22, 2020 · 21 comments · Fixed by #1357
Labels: priority: 0 (High priority task), question (Further information is requested)

Comments

@jbschiratti (Contributor)

Currently, I am able to log training metrics to Tensorboard using:

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger

logger = TensorBoardLogger(save_dir=save_dir, name="my_model")

[...]

trainer = pl.Trainer(logger=logger)

This logs training metrics (loss, for instance) after each batch. I would like to be able to average these metrics across all batches and log them to TensorBoard only once, at the end of each epoch. This is what the validation_end method does in your example: https://github.com/PyTorchLightning/pytorch-lightning/blob/446a1e23d7fe3b2e07f1a5887fe819d0dfa7d4e0/pl_examples/basic_examples/lightning_module_template.py#L145.

I first thought about writing my own training_end method. But this method is called after each batch instead of being called at the end of an epoch (as I would have thought). The method on_epoch_end seems interesting but does not receive an outputs argument as training_end does. Basically, in my model, I would like to write something like: self.logger.experiment.add_scalar('training_loss', train_loss_mean, global_step=self.current_epoch), but I do not know where to put this line.

  • OS: Debian GNU/Linux 9.11 (stretch)
  • Packaging: PIP
  • Version 0.6.1.dev0
jbschiratti added the question (Further information is requested) label on Feb 22, 2020
@github-actions (Contributor)

Hey, thanks for your contribution! Great first issue!

@awaelchli (Contributor)

How about this:
In __init__:
self.training_losses = []
In training_step method:
self.training_losses.append(loss.item())
In epoch_end method:

train_loss_mean = np.mean(self.training_losses)
self.logger.experiment.add_scalar('training_loss', train_loss_mean, global_step=self.current_epoch)
self.training_losses = []  # reset for next epoch
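
Put together, the workaround might look roughly like the sketch below (assuming the on_epoch_end hook of the LightningModule and a TensorBoard logger; the linear model, the optimizer and the class name LitClassifier are placeholders, not part of the original suggestion):

import numpy as np
import torch
import torch.nn.functional as F
import pytorch_lightning as pl

class LitClassifier(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(28 * 28, 10)
        self.training_losses = []  # one entry per training batch of the current epoch

    def forward(self, x):
        return self.layer(x.view(x.size(0), -1))

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        self.training_losses.append(loss.item())  # collect for the epoch-level average
        return {'loss': loss}

    def on_epoch_end(self):
        # average over all batches seen this epoch and log a single scalar to TensorBoard
        train_loss_mean = np.mean(self.training_losses)
        self.logger.experiment.add_scalar('training_loss', train_loss_mean,
                                          global_step=self.current_epoch)
        self.training_losses = []  # reset for the next epoch

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)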

@polars05 commented Feb 23, 2020

I'm also trying to log training metrics at the end of each epoch, and tried it as follows (based on the MNIST example in https://colab.research.google.com/drive/1F_RNcHzTfFuQf-LeKvSlud6x7jXYkG31#scrollTo=x-34xKCI40yW):

class CoolSystem(pl.LightningModule):

    def __init__(self, hparams):
        super(CoolSystem, self).__init__()

        self.l1 = torch.nn.Linear(28 * 28, 10)
        
    # ...

    def training_step(self, batch, batch_idx):
        # REQUIRED
        x, y = batch
        y_hat = self.forward(x)
        
        loss = F.cross_entropy(y_hat, y)
        
        _, preds = torch.max(y_hat, dim=1)
        acc = torch.sum(preds == y.data) / (y.shape[0] * 1.0)

        return {'train_loss': loss, 'train_acc': acc}
    
    def training_end(self, outputs):
        # OPTIONAL
        avg_loss = torch.stack([x['train_loss'] for x in outputs]).mean() # -------- Error is thrown here
        avg_acc = torch.stack([x['train_acc'].float() for x in outputs]).mean()

        logs = {'train_loss': avg_loss, 'train_acc': avg_acc}
        return {'avg_train_loss': avg_loss, 'avg_train_acc': avg_acc, 'log': logs, 'progress_bar': logs}

    def validation_step(self, batch, batch_idx):
        # OPTIONAL
        x, y = batch
        y_hat = self.forward(x)

        _, preds = torch.max(y_hat, 1)
        acc = torch.sum(preds == y.data) / (y.shape[0] * 1.0)

        return {'val_loss': F.cross_entropy(y_hat, y), 'val_acc': acc}

    def validation_end(self, outputs):
        # OPTIONAL
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        avg_acc = torch.stack([x['val_acc'].float() for x in outputs]).mean()

        print (avg_acc)

        logs = {'val_loss': avg_loss, 'val_acc': avg_acc}
        return {'avg_val_loss': avg_loss, 'avg_val_acc': avg_acc, 'log': logs, 'progress_bar': logs}

However, the following error is thrown at training_end:

TypeError: string indices must be integers

The weird thing is that it works for validation_end (the validation check passes, and no error is thrown when I remove training_end from my system to debug); could you please advise? Thanks!

@jbschiratti (Contributor, Author) commented Feb 23, 2020

@awaelchli That works for me, thanks!

@polars05 From my understanding, training_step and validation_step compute the metrics for the current step (batch) and return them in the output dict. In the validation loop, the metrics from all steps (all batches in an epoch) are aggregated in outputs. The method validation_end allows you to average those metrics epoch-wise (for instance, compute a single loss score for the whole validation dataset). What was confusing for me is that training_end does not work like validation_end (see my post). @awaelchli feel free to correct me if I'm wrong!

@williamFalcon (Contributor)

yeah, training_end is a bit confusing. we discussed this and decided to change the name/modify it.

training_end aggregates outputs of a batch on dp. We need a true training_end and something for the dp aggregation.

@jeremyjordan @ethanwharris i think you guys had opinions about this? should we try to sneak this into this next release?

@polars05

@jbschiratti, thanks for the clarifications! I've modified my system as per what @awaelchli suggested to you and that works for me as well.

However, in the process, I observed that under training_step, if we did return {'train_loss': loss} instead of return {'loss': loss}, then the following error is thrown:

RuntimeError: No `loss` value in the dictionary returned from `model.training_step()`.

For reference, here's my system:

class CoolSystem(pl.LightningModule):

    def __init__(self, hparams):
        super(CoolSystem, self).__init__()

        self.hparams = hparams
        self.data_dir = self.hparams.data_dir
        # self.classes = self.hparams.classes
        self.num_classes = self.hparams.num_classes # len(self.hparams.classes)

        self.model = torchvision.models.resnet18(pretrained=True) # final layer is of size [bs, 1000]
        num_ftrs = self.model.fc.in_features
        self.model.fc = torch.nn.Linear(num_ftrs, self.num_classes) # change final layer to be of size [bs, 2]
        
        self.training_acc_across_batches_at_curr_epoch = []

    # ...

    def training_step(self, batch, batch_idx):
        # REQUIRED
        x, y = batch
        y_hat = self.forward(x)
        
        loss = F.cross_entropy(y_hat, y)
        
        _, preds = torch.max(y_hat, dim=1)
        acc = torch.sum(preds == y.data) / (y.shape[0] * 1.0)
        # torch.tensor(acc)
        self.training_acc_across_batches_at_curr_epoch.append(acc.item())
        
        # return {'loss': loss}         # -------> this works...
        return {'train_loss': loss}   # --------> ... but this doesn't
    
    def on_epoch_end(self):
        train_acc_mean = np.mean(self.training_acc_across_batches_at_curr_epoch)
        
        self.logger.experiment.add_scalar('train_acc', train_acc_mean, global_step=self.current_epoch)
        self.training_acc_across_batches_at_curr_epoch = []  # reset for next epoch

My guess is that we do not have the flexibility to name the dict key in training_step?

In addition, if we did the following instead in training_step:

logs = {'train_loss': loss, 'train_acc': acc}
return {'loss': loss, 'train_acc': acc, 'log': logs, 'progress_bar': logs}

then both loss and train_loss show up on the progress bar.

I could simply do:

dict_for_progress_bar = {'train_acc': acc}
dict_for_log = {'train_loss': loss, 'train_acc': acc}
return {'loss': loss, 'train_acc': acc, 'log': dict_for_log, 'progress_bar': dict_for_progress_bar}

but it might be cleaner for the user to not have to create two separate dicts?

@jeremyjordan (Contributor)

@williamFalcon I agree that the behavior for training_end should be symmetric with validation_end. Unfortunately, this does require changing the API - should we include a warning message in the next release to alert users of an upcoming change in the subsequent release?

The current training_end could be renamed to something more explicit, such as collect_dp_batches or something along those lines.

@jeremyjordan (Contributor)

@polars05 that is correct! you can see how the values returned are processed here
https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/trainer/logging.py#L90

on line 153 you can see that there's an explicit check for a loss key in the output dict.
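
In other words, with this version of the API the dict returned from training_step has to contain a literal 'loss' key; extra metrics can be routed through the 'log' and 'progress_bar' dicts rather than by renaming the loss. A minimal sketch (the metric names are only examples, and the same imports as in the snippets above are assumed):

def training_step(self, batch, batch_idx):
    x, y = batch
    loss = F.cross_entropy(self(x), y)
    # the trainer looks up the 'loss' key; 'train_loss' is only a logging/display name
    return {'loss': loss,
            'log': {'train_loss': loss},
            'progress_bar': {'train_loss': loss}}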

@versatran01 commented Feb 24, 2020

Just wondering what is the best way to do something at the end of each training epoch, before validation starts?
From this thread, I assume training_end is not the right place to do that?
Currently, I have to do it in validation_step, where I check whether the index == 0, but this seems very ugly.

@awaelchli (Contributor)

@versatran01 Maybe for your use case the on_epoch_end() hook is enough? It doesn't get the metrics from training though.

If you need something that is conceptually similar to validation_end for training, then this does not exist yet. And this is also what @jbschiratti is asking for in this issue.
If you need a workaround until the changes are implemented, see my answer above.

@versatran01

@awaelchli Thanks for the suggestion. I just saw some huge updates to the callback system, will give it a try.
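
For reference, a callback-based variant of the earlier workaround might look like the sketch below (assuming a Callback with the on_epoch_end(trainer, pl_module) hook and a Trainer that accepts a callbacks list; training_losses is a hypothetical attribute the LightningModule would fill in training_step):

import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback

class EpochMetricsLogger(Callback):
    def on_epoch_end(self, trainer, pl_module):
        # average the per-batch losses collected by the module and log them once per epoch
        losses = getattr(pl_module, 'training_losses', [])
        if losses:
            pl_module.logger.experiment.add_scalar(
                'training_loss', np.mean(losses), global_step=trainer.current_epoch)
            pl_module.training_losses = []  # reset for the next epoch

# hypothetical usage: trainer = pl.Trainer(callbacks=[EpochMetricsLogger()])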

@jbschiratti (Contributor, Author)

@williamFalcon @jeremyjordan Changing the current training_end method into something like collect_dp_batches and making training_end similar to the current validation_end does not seem to be too much work. Shall I submit a PR for this?

@williamFalcon (Contributor)

training_end makes sense.

maybe a better name for collect_dp_batches is: training_step_reduce or something like that

@williamFalcon (Contributor)

updated in 0.7.1

@jbschiratti (Contributor, Author)

@williamFalcon I had a look at the latest release and I think that my issue still stands. Unless I am mistaken, there is still no equivalent of validation_end (now called validation_epoch_end) in training_loop.py.

At each epoch, one can average validation loss across all batches by doing:

def validation_epoch_end(self, outputs):
    val_loss_mean = 0
    for output in outputs:
        val_loss = output['val_loss']
        if self.trainer.use_dp or self.trainer.use_ddp2:
            val_loss = torch.mean(val_loss)
        val_loss_mean += val_loss
    val_loss_mean /= len(outputs)
    tqdm_dict = {'val_loss': val_loss_mean}
    return OrderedDict({'val_loss': val_loss_mean, 'progress_bar': tqdm_dict, 'log': tqdm_dict})

But as of today, there is still no way to do the same in the training loop! I thought about the function training_epoch_end, but it is not used; it is only mentioned in a comment on line 719.
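
For symmetry, the training-side counterpart being asked for would look something like this (hypothetical at this point, since training_epoch_end is not wired into the training loop; it assumes the same OrderedDict and torch imports as the validation snippet above):

def training_epoch_end(self, outputs):
    train_loss_mean = 0
    for output in outputs:
        train_loss = output['loss']
        if self.trainer.use_dp or self.trainer.use_ddp2:
            train_loss = torch.mean(train_loss)
        train_loss_mean += train_loss
    train_loss_mean /= len(outputs)
    tqdm_dict = {'train_loss': train_loss_mean}
    return OrderedDict({'train_loss': train_loss_mean, 'progress_bar': tqdm_dict, 'log': tqdm_dict})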

@jbschiratti (Contributor, Author)

Shall I open a new issue?

jeremyjordan reopened this on Mar 31, 2020
@jeremyjordan (Contributor)

@jbschiratti you're right, i don't see where this is implemented. reopened the issue.

@jbschiratti (Contributor, Author) commented Apr 1, 2020

@jeremyjordan I pushed a "proof of concept"; see this commit. This is certainly not perfect, but the tests (the ones that were not skipped) pass, and it is what I'm expecting from pytorch-lightning's API: the behavior of training_epoch_end should mirror that of validation_epoch_end. If you want, I can start a PR for this issue and we can iterate on it.

@jbschiratti (Contributor, Author)

@williamFalcon Shall I start a PR with what I already did?

@williamFalcon (Contributor)

@jbschiratti this is awesome. yes please, submit the PR!

@tbenst commented Sep 3, 2020

Hi @jbschiratti, I just noticed a difference between training_epoch_end and validation_epoch_end:

def training_epoch_end(self, outputs):
    # works
    return outputs

def validation_epoch_end(self, outputs):
    # fails, unless next two lines are uncommented
    # loss_mean = outputs['val_loss'].mean().item()
    # outputs["val_loss"] = loss_mean
    return outputs

Is this as intended? Or a bug? Thanks!
