-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathtest_save.py
52 lines (38 loc) · 1.13 KB
/
test_save.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import os
import unittest
import yaml
from pytorch_ner.save import save_model
from pytorch_ner.utils import rmdir
from tests.test_train import label2idx, model, token2idx
# Root folder for every artifact this test module writes; TestSave removes
# it after the tests run.
path_to_save_folder = "models/test"
path_to_no_onnx_folder = os.path.join(path_to_save_folder, "no_onnx")
path_to_onnx_folder = os.path.join(path_to_save_folder, "onnx")

# NOTE(review): "config.yaml" is resolved relative to the current working
# directory, so this module presumably has to be run from the repo root.
with open("config.yaml", "r", encoding="utf-8") as fp:
    config = yaml.safe_load(fp)

# Save the model imported from tests.test_train twice at import time so the
# tests below can inspect the produced files.

# without onnx
save_model(
    path_to_folder=path_to_no_onnx_folder,
    model=model,
    token2idx=token2idx,
    label2idx=label2idx,
    config=config,
    export_onnx=False,
)

# with onnx
save_model(
    path_to_folder=path_to_onnx_folder,
    model=model,
    token2idx=token2idx,
    label2idx=label2idx,
    config=config,
    export_onnx=True,
)
class TestSave(unittest.TestCase):
    """Check the number of files written by the module-level ``save_model`` calls."""

    @staticmethod
    def _count_saved_files(path_to_folder):
        """Return the number of entries in the first sub-folder of *path_to_folder*.

        ``os.listdir`` returns bare entry names, not paths, so the sub-folder
        name has to be joined back onto *path_to_folder*; listing the bare
        name would resolve it against the current working directory instead.
        """
        # NOTE(review): assumes save_model creates exactly one sub-folder per
        # call inside path_to_folder — TODO confirm against save_model.
        subfolder = os.listdir(path_to_folder)[0]
        return len(os.listdir(os.path.join(path_to_folder, subfolder)))

    def test_num_files(self):
        """Saving without ONNX export produces 4 files."""
        self.assertEqual(self._count_saved_files(path_to_no_onnx_folder), 4)

    def test_num_files_with_onnx(self):
        """Saving with ONNX export produces 5 files (one more than without)."""
        self.assertEqual(self._count_saved_files(path_to_onnx_folder), 5)

    @classmethod
    def tearDownClass(cls):
        """Remove every artifact written under ``path_to_save_folder``."""
        # Must be a classmethod: unittest invokes it as TestSave.tearDownClass()
        # with no arguments, so an undecorated version raises TypeError and
        # the saved models are never cleaned up.
        rmdir(path_to_save_folder)
        super().tearDownClass()
# Allow running this file directly (python test_save.py) as well as via a
# test runner.
if __name__ == "__main__":
    unittest.main()