Skip to content

Commit 5dc6421

Browse files
authored
Add score output for OneClassSVM model (#101)
* add 2 outs * update
1 parent 3df5bb6 commit 5dc6421

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

sqlflow_models/one_class_svm.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -103,4 +103,5 @@ def sqlflow_train_loop(self, dataset):
103103
def sqlflow_predict_one(self, features):
104104
features = self.concat_features(features)
105105
pred = self.svm.predict(features)
106-
return [pred]
106+
score = self.svm.decision_function(features)
107+
return pred, score

tests/test_one_class_svm.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,10 @@ def test_main(self):
5353

5454
predict_dataset = self.create_dataset()
5555
for features in dataset_reader(predict_dataset):
56-
pred = svm.sqlflow_predict_one(features)
56+
pred = svm.sqlflow_predict_one(features)[0]
5757
pred = np.array(pred)
58-
self.assertEqual(pred.shape, (1, 1))
59-
self.assertTrue(pred[0][0] == 1 or pred[0][0] == -1)
58+
self.assertEqual(pred.shape, (1,))
59+
self.assertTrue(pred[0] == 1 or pred[0] == -1)
6060

6161

6262
if __name__ == '__main__':

0 commit comments

Comments
 (0)