@@ -115,10 +115,11 @@ def test_step_result_preds(self, batch, batch_idx, optimizer_idx=None):
115
115
result .log ('test_loss' , loss_test )
116
116
result .log ('test_acc' , test_acc )
117
117
118
- #lst_of_str = [random.choice(['dog', 'cat']) for i in range(batch_size)]
119
- # int_outputs = [random.randint(500, 1000) for i in range(batch_size)]
120
- #nested_lst = [[x] for x in int_outputs]
121
- #lst_of_dicts = [{k: v} for k, v in zip(lst_of_str, int_outputs)]
118
+ batch_size = x .size (0 )
119
+ lst_of_str = [random .choice (['dog' , 'cat' ]) for i in range (batch_size )]
120
+ lst_of_int = [random .randint (500 , 1000 ) for i in range (batch_size )]
121
+ lst_of_lst = [[x ] for x in lst_of_int ]
122
+ lst_of_dict = [{k : v } for k , v in zip (lst_of_str , lst_of_int )]
122
123
123
124
# This is passed in from pytest via parameterization
124
125
option = getattr (self , 'test_option' , 0 )
@@ -135,5 +136,31 @@ def test_step_result_preds(self, batch, batch_idx, optimizer_idx=None):
135
136
elif option == 1 :
136
137
result .write ('idxs' , torch .cat ((lazy_ids , lazy_ids )), prediction_file )
137
138
result .write ('preds' , labels_hat , prediction_file )
139
+
140
+ # write multi-dimension
141
+ elif option == 2 :
142
+ result .write ('idxs' , lazy_ids , prediction_file )
143
+ result .write ('preds' , labels_hat , prediction_file )
144
+ result .write ('x' , x , prediction_file )
145
+
146
+ # write str list
147
+ elif option == 3 :
148
+ result .write ('idxs' , lazy_ids , prediction_file )
149
+ result .write ('vals' , lst_of_str , prediction_file )
150
+
151
+ # write int list
152
+ elif option == 4 :
153
+ result .write ('idxs' , lazy_ids , prediction_file )
154
+ result .write ('vals' , lst_of_str , prediction_file )
155
+
156
+ # write nested list
157
+ elif option == 5 :
158
+ result .write ('idxs' , lazy_ids , prediction_file )
159
+ result .write ('vals' , lst_of_str , prediction_file )
160
+
161
+ # write dict list
162
+ elif option == 6 :
163
+ result .write ('idxs' , lazy_ids , prediction_file )
164
+ result .write ('vals' , lst_of_str , prediction_file )
138
165
139
166
return result
0 commit comments