Skip to content

Commit e244185

Browse files
committed
add some tests
1 parent 64e646e commit e244185

File tree

2 files changed

+17
-3
lines changed

2 files changed

+17
-3
lines changed

tests/base/model_train_steps.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,21 @@ def training_step(self, batch, batch_idx, optimizer_idx=None):
2121

2222
# calculate loss
2323
loss_val = self.loss(y, y_hat)
24+
loss_scalar = loss_val.item()
2425

2526
# alternate possible outputs to test
27+
if batch_idx % 2 == 0:
28+
output = OrderedDict({
29+
'loss': loss_val,
30+
'progress_bar': {'some_val': loss_val * loss_val},
31+
'log': {'train_some_val': loss_val * loss_val},
32+
})
33+
34+
# return scalars for "log" and "progress_bar"
2635
output = OrderedDict({
2736
'loss': loss_val,
28-
'progress_bar': {'some_val': loss_val * loss_val},
29-
'log': {'train_some_val': loss_val * loss_val},
37+
'progress_bar': {'some_val': loss_scalar * loss_scalar},
38+
'log': {'train_some_val': loss_scalar * loss_scalar},
3039
})
3140
return output
3241

tests/base/model_valid_epoch_ends.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ class ValidationEpochEndVariations(ABC):
77
"""
88
Houses all variations of validation_epoch_end steps
99
"""
10+
1011
def validation_epoch_end(self, outputs):
1112
"""
1213
Called at the end of validation to aggregate outputs
@@ -23,7 +24,11 @@ def _mean(res, key):
2324
val_loss_mean = _mean(outputs, 'val_loss')
2425
val_acc_mean = _mean(outputs, 'val_acc')
2526

26-
metrics_dict = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}
27+
# alternate between tensor and scalar
28+
if self.current_epoch % 2:
29+
metrics_dict = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}
30+
else:
31+
metrics_dict = {'val_loss': val_loss_mean, 'val_acc': val_acc_mean}
2732
results = {'progress_bar': metrics_dict, 'log': metrics_dict}
2833
return results
2934

0 commit comments

Comments
 (0)