@@ -307,7 +307,7 @@ def tf_train(
307
307
val_loss , val_acc , n_iter = 0 , 0 , 0
308
308
for X_batch , y_batch in test_dataset :
309
309
_logits = network (X_batch ) # is_train=False, disable dropout
310
- val_loss += loss_fn (_logits , y_batch , name = 'eval_loss' )
310
+ val_loss += loss_fn (_logits , y_batch )
311
311
if metrics :
312
312
metrics .update (_logits , y_batch )
313
313
val_acc += metrics .result ()
@@ -360,7 +360,7 @@ def ms_train(
360
360
val_loss , val_acc , n_iter = 0 , 0 , 0
361
361
for X_batch , y_batch in test_dataset :
362
362
_logits = network (X_batch )
363
- val_loss += loss_fn (_logits , y_batch , name = 'eval_loss' )
363
+ val_loss += loss_fn (_logits , y_batch )
364
364
if metrics :
365
365
metrics .update (_logits , y_batch )
366
366
val_acc += metrics .result ()
@@ -414,7 +414,7 @@ def pd_train(
414
414
val_loss , val_acc , n_iter = 0 , 0 , 0
415
415
for X_batch , y_batch in test_dataset :
416
416
_logits = network (X_batch ) # is_train=False, disable dropout
417
- val_loss += loss_fn (_logits , y_batch , name = 'eval_loss' )
417
+ val_loss += loss_fn (_logits , y_batch )
418
418
if metrics :
419
419
metrics .update (_logits , y_batch )
420
420
val_acc += metrics .result ()
@@ -468,7 +468,7 @@ def th_train(
468
468
val_loss , val_acc , n_iter = 0 , 0 , 0
469
469
for X_batch , y_batch in test_dataset :
470
470
_logits = network (X_batch ) # is_train=False, disable dropout
471
- val_loss += loss_fn (_logits , y_batch , name = 'eval_loss' )
471
+ val_loss += loss_fn (_logits , y_batch )
472
472
if metrics :
473
473
metrics .update (_logits , y_batch )
474
474
val_acc += metrics .result ()
0 commit comments