@@ -969,15 +969,15 @@ the data loader every epoch and then write the GAN training code:
969
969
discriminator->zero_grad();
970
970
torch::Tensor real_images = batch.data;
971
971
torch::Tensor real_labels = torch::empty(batch.data.size(0)).uniform_(0.8, 1.0);
972
- torch::Tensor real_output = discriminator->forward(real_images);
972
+ torch::Tensor real_output = discriminator->forward(real_images).reshape(real_labels.sizes()) ;
973
973
torch::Tensor d_loss_real = torch::binary_cross_entropy(real_output, real_labels);
974
974
d_loss_real.backward();
975
975
976
976
// Train discriminator with fake images.
977
977
torch::Tensor noise = torch::randn({batch.data.size(0), kNoiseSize, 1, 1});
978
978
torch::Tensor fake_images = generator->forward(noise);
979
979
torch::Tensor fake_labels = torch::zeros(batch.data.size(0));
980
- torch::Tensor fake_output = discriminator->forward(fake_images.detach());
980
+ torch::Tensor fake_output = discriminator->forward(fake_images.detach()).reshape(fake_labels.sizes()) ;
981
981
torch::Tensor d_loss_fake = torch::binary_cross_entropy(fake_output, fake_labels);
982
982
d_loss_fake.backward();
983
983
@@ -987,7 +987,7 @@ the data loader every epoch and then write the GAN training code:
987
987
// Train generator.
988
988
generator->zero_grad();
989
989
fake_labels.fill_(1);
990
- fake_output = discriminator->forward(fake_images);
990
+ fake_output = discriminator->forward(fake_images).reshape(fake_labels.sizes()) ;
991
991
torch::Tensor g_loss = torch::binary_cross_entropy(fake_output, fake_labels);
992
992
g_loss.backward();
993
993
generator_optimizer.step();
0 commit comments