Skip to content

Commit e8b35f1

Browse files
committed
🚧 wip
1 parent 031729e commit e8b35f1

File tree

2 files changed

+49
-5
lines changed

2 files changed

+49
-5
lines changed

tests/base/model_test_steps.py

+31-4
Original file line numberDiff line numberDiff line change
@@ -115,10 +115,11 @@ def test_step_result_preds(self, batch, batch_idx, optimizer_idx=None):
115115
result.log('test_loss', loss_test)
116116
result.log('test_acc', test_acc)
117117

118-
#lst_of_str = [random.choice(['dog', 'cat']) for i in range(batch_size)]
119-
# int_outputs = [random.randint(500, 1000) for i in range(batch_size)]
120-
#nested_lst = [[x] for x in int_outputs]
121-
#lst_of_dicts = [{k: v} for k, v in zip(lst_of_str, int_outputs)]
118+
batch_size = x.size(0)
119+
lst_of_str = [random.choice(['dog', 'cat']) for i in range(batch_size)]
120+
lst_of_int = [random.randint(500, 1000) for i in range(batch_size)]
121+
lst_of_lst = [[x] for x in lst_of_int]
122+
lst_of_dict = [{k: v} for k, v in zip(lst_of_str, lst_of_int)]
122123

123124
# This is passed in from pytest via parameterization
124125
option = getattr(self, 'test_option', 0)
@@ -135,5 +136,31 @@ def test_step_result_preds(self, batch, batch_idx, optimizer_idx=None):
135136
elif option == 1:
136137
result.write('idxs', torch.cat((lazy_ids, lazy_ids)), prediction_file)
137138
result.write('preds', labels_hat, prediction_file)
139+
140+
# write multi-dimension
141+
elif option == 2:
142+
result.write('idxs', lazy_ids, prediction_file)
143+
result.write('preds', labels_hat, prediction_file)
144+
result.write('x', x, prediction_file)
145+
146+
# write str list
147+
elif option == 3:
148+
result.write('idxs', lazy_ids, prediction_file)
149+
result.write('vals', lst_of_str, prediction_file)
150+
151+
# write int list
152+
elif option == 4:
153+
result.write('idxs', lazy_ids, prediction_file)
154+
result.write('vals', lst_of_str, prediction_file)
155+
156+
# write nested list
157+
elif option == 5:
158+
result.write('idxs', lazy_ids, prediction_file)
159+
result.write('vals', lst_of_str, prediction_file)
160+
161+
# write dict list
162+
elif option == 6:
163+
result.write('idxs', lazy_ids, prediction_file)
164+
result.write('vals', lst_of_str, prediction_file)
138165

139166
return result

tests/core/test_results.py

+18-1
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,21 @@ def test_result_reduce_ddp(result_cls):
5555
pytest.param(
5656
1, False, 0, id='test_only_mismatching_tensor', marks=pytest.mark.xfail(raises=ValueError, match="Mism.*")
5757
),
58+
pytest.param(
59+
2, False, 0, id='mix_of_tensor_dims'
60+
),
61+
pytest.param(
62+
3, False, 0, id='string_list_predictions'
63+
),
64+
pytest.param(
65+
4, False, 0, id='int_list_predictions'
66+
),
67+
pytest.param(
68+
5, False, 0, id='nested_list_predictions'
69+
),
70+
pytest.param(
71+
6, False, 0, id='dict_list_predictions'
72+
),
5873
pytest.param(
5974
0, True, 1, id='full_loop_single_gpu', marks=pytest.mark.skipif(torch.cuda.device_count() < 1, reason="test requires single-GPU machine")
6075
)
@@ -153,7 +168,9 @@ def test_result_obj_predictions_ddp_spawn(tmpdir):
153168
dm.setup('test')
154169

155170
# check prediction file now exists and is of expected length
171+
size = 0
156172
for prediction_file in prediction_files:
157173
assert prediction_file.exists()
158174
predictions = torch.load(prediction_file)
159-
assert len(predictions) == len(dm.mnist_test) // 2
175+
size += len(predictions)
176+
assert size == len(dm.mnist_test)

0 commit comments

Comments
 (0)