Skip to content

Commit 18dbfc3

Browse files
authored
Merge branch 'main' into master
2 parents e8311bd + b8acfc8 commit 18dbfc3

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

advanced_source/cpp_frontend.rst

+3-3
Original file line numberDiff line numberDiff line change
@@ -969,15 +969,15 @@ the data loader every epoch and then write the GAN training code:
969969
discriminator->zero_grad();
970970
torch::Tensor real_images = batch.data;
971971
torch::Tensor real_labels = torch::empty(batch.data.size(0)).uniform_(0.8, 1.0);
972-
torch::Tensor real_output = discriminator->forward(real_images);
972+
torch::Tensor real_output = discriminator->forward(real_images).reshape(real_labels.sizes());
973973
torch::Tensor d_loss_real = torch::binary_cross_entropy(real_output, real_labels);
974974
d_loss_real.backward();
975975
976976
// Train discriminator with fake images.
977977
torch::Tensor noise = torch::randn({batch.data.size(0), kNoiseSize, 1, 1});
978978
torch::Tensor fake_images = generator->forward(noise);
979979
torch::Tensor fake_labels = torch::zeros(batch.data.size(0));
980-
torch::Tensor fake_output = discriminator->forward(fake_images.detach());
980+
torch::Tensor fake_output = discriminator->forward(fake_images.detach()).reshape(fake_labels.sizes());
981981
torch::Tensor d_loss_fake = torch::binary_cross_entropy(fake_output, fake_labels);
982982
d_loss_fake.backward();
983983
@@ -987,7 +987,7 @@ the data loader every epoch and then write the GAN training code:
987987
// Train generator.
988988
generator->zero_grad();
989989
fake_labels.fill_(1);
990-
fake_output = discriminator->forward(fake_images);
990+
fake_output = discriminator->forward(fake_images).reshape(fake_labels.sizes());
991991
torch::Tensor g_loss = torch::binary_cross_entropy(fake_output, fake_labels);
992992
g_loss.backward();
993993
generator_optimizer.step();

0 commit comments

Comments
 (0)