Hyperparameters are not shown in TensorBoard when using the TensorBoardLogger().
Bug identification
I am trying to use the add_hparams() function to tune my network. Using the sample code provided in PyTorch, I have no problem viewing the hyperparameters and the associated accuracy / loss in TensorBoard. The code is shown below:
import os
from torch.utils.tensorboard import SummaryWriter

dir_path = "./run2/test_hp"
os.makedirs(dir_path, exist_ok=True)

with SummaryWriter(dir_path) as w:
    for i in range(5):
        w.add_hparams({'lr': 0.1 * i, 'bsize': i},
                      {'hparam/accuracy': 10 * i, 'hparam/loss': 10 * i})
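A quick way to double-check (sketch only) that the event files were actually written to disk before opening TensorBoard is to list them:

# Sanity check: list the event files created under dir_path.
# add_hparams typically writes its metrics into a timestamped sub-run,
# so subdirectories may show up here as well.
for root, dirs, files in os.walk(dir_path):
    for f in files:
        print(os.path.join(root, f))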
But the same code does not work (roughly 50% of the time) when I call it inside the Lightning module; I will explain the other 50% below. The code is shown below. Running this code on its own, I reliably get the following output.
import os
import torch
from torch.nn import functional as F
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision import transforms
from pytorch_lightning import LightningModule
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.tensorboard import SummaryWriter


class MyNet(LightningModule):
    def __init__(self):
        super(MyNet, self).__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

    def train_dataloader(self):
        dataset = MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
        loader = DataLoader(dataset, batch_size=32, num_workers=4, shuffle=True)
        return loader

    def test_dataloader(self):
        dataset = MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
        loader = DataLoader(dataset, batch_size=32, num_workers=4, shuffle=False)
        return loader

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        tensorboard_logs = {'test_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def test_epoch_end(self, output):
        with SummaryWriter(self.logger.log_dir) as w:
            for i in range(5):
                w.add_hparams({'lr': 0.1 * i, 'bsize': i},
                              {'hparam/accuracy': 10 * i, 'hparam/loss': 10 * i})
        return {}


dir_path = "."
tb_logger = TensorBoardLogger(dir_path, name='run2')
model = MyNet()
trainer = Trainer(gpus=1, max_epochs=1, logger=tb_logger)
trainer.fit(model)
trainer.test()
However, I discovered that if I run my first script and then the second one, i.e. create the "test_hp" folder first and then the "version_0" folder, I can see the hyperparameters in TensorBoard. Very weird. And if I now delete the "test_hp" folder, the hyperparameters disappear again (back to the previous image). So it seems that the hyperparameters logged inside the Lightning module are only displayed when the data generated by the plain PyTorch script is also present, even though both use the same code. Is this a bug?
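One thing I am considering as a workaround (only a sketch; it assumes self.logger.experiment exposes the TensorBoardLogger's underlying SummaryWriter) is to reuse that writer instead of opening a second SummaryWriter on the same directory, so the hparams summary ends up in the event file TensorBoard is already reading:

    def test_epoch_end(self, output):
        # Sketch: reuse the logger's own SummaryWriter rather than creating
        # a new one pointed at the same log_dir.
        w = self.logger.experiment
        for i in range(5):
            w.add_hparams({'lr': 0.1 * i, 'bsize': i},
                          {'hparam/accuracy': 10 * i, 'hparam/loss': 10 * i})
        return {}

I have not verified whether this avoids the problem, but it removes the second writer from the picture.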
Environment