Skip to content

Commit 16ce881

Browse files
committed
add some tests
1 parent 64e646e commit 16ce881

File tree

2 files changed

+17
-4
lines changed

2 files changed

+17
-4
lines changed

tests/base/model_train_steps.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,22 @@ def training_step(self, batch, batch_idx, optimizer_idx=None):
2121

2222
# calculate loss
2323
loss_val = self.loss(y, y_hat)
24+
loss_scalar = loss_val.item()
2425

2526
# alternate possible outputs to test
27+
if batch_idx % 2 == 0:
28+
output = OrderedDict({
29+
'loss': loss_val,
30+
'progress_bar': {'some_val': loss_val * loss_val},
31+
'log': {'train_some_val': loss_val * loss_val},
32+
})
33+
34+
# return scalars for "log" and "progress_bar"
2635
output = OrderedDict({
2736
'loss': loss_val,
28-
'progress_bar': {'some_val': loss_val * loss_val},
29-
'log': {'train_some_val': loss_val * loss_val},
30-
})
37+
'progress_bar': {'some_val': loss_scalar * loss_scalar},
38+
'log': {'train_some_val': loss_scalar * loss_scalar},
39+
})
3140
return output
3241

3342
def training_step__inf_loss(self, batch, batch_idx, optimizer_idx=None):

tests/base/model_valid_epoch_ends.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,11 @@ def _mean(res, key):
2323
val_loss_mean = _mean(outputs, 'val_loss')
2424
val_acc_mean = _mean(outputs, 'val_acc')
2525

26-
metrics_dict = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}
26+
# alternate between tensor and scalar
27+
if self.current_epoch % 2:
28+
metrics_dict = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}
29+
else:
30+
metrics_dict = {'val_loss': val_loss_mean, 'val_acc': val_acc_mean}
2731
results = {'progress_bar': metrics_dict, 'log': metrics_dict}
2832
return results
2933

0 commit comments

Comments
 (0)