-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdga_test.py
65 lines (56 loc) · 1.84 KB
/
dga_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#!/usr/bin/env python
# encoding: utf-8
'''
@author: caopeng
@license: (C) Copyright 2016-2020, Big Bird Corporation Limited.
@contact: [email protected]
@software: garner
@file: dga_test.py
@time: 2019/7/11 22:38
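@desc: Encodes the lines of a domain test file as character-index sequences
       and runs LstmModel.predict on them with a saved model file.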
'''
from model.LstmModel import LstmModel
import codecs
import numpy as np
import datetime
def clean_line(raw_line):
    '''Return raw_line with surrounding newlines, carriage returns and spaces removed.

    The three strips are applied in sequence (newline, then carriage return,
    then space), exactly as the script has always done, so the result for
    unusual mixes of trailing whitespace stays the same.
    '''
    return raw_line.strip('\n').strip('\r').strip(' ')


def build_char_index(lines):
    '''Map every non-blank entry of lines to a 1-based index.

    Indices start at 1 because 0 is reserved for the padding / unknown
    character. A blank line does not consume an index. If an entry occurs
    more than once, the index of its last occurrence wins (each occurrence
    still consumes an index).

    :param lines: iterable of raw text lines, one character entry per line.
    :return: dict mapping the cleaned entry to its integer index.
    '''
    char_index = {}
    next_index = 1
    for raw_line in lines:
        entry = clean_line(raw_line)
        if entry != '':
            char_index[entry] = next_index
            next_index += 1
    return char_index


def encode_domain(domain, char_index):
    '''Translate one cleaned line into a list of character indices.

    Characters missing from char_index are reported on stdout and encoded
    as 0.

    :param domain: the text to encode, already stripped of line endings.
    :param char_index: dict mapping single characters to integer indices.
    :return: list of integers, one per character of domain.
    '''
    encoded = []
    for char in domain:
        code = char_index.get(char)
        if code is None:
            print('unexpected char' + ' : ' + char)
            code = 0
        encoded.append(code)
    return encoded


def encode_domains(lines, char_index):
    '''Encode every non-blank line of lines with encode_domain.

    :param lines: iterable of raw text lines, one item per line.
    :param char_index: dict mapping single characters to integer indices.
    :return: list of encoded lines (each a list of integers); blank lines
             are skipped.
    '''
    encoded_rows = []
    for raw_line in lines:
        domain = clean_line(raw_line)
        if domain == '':
            continue
        encoded_rows.append(encode_domain(domain, char_index))
    return encoded_rows


def read_lines(path):
    '''Read all lines of a UTF-8 text file, ignoring undecodable bytes.

    The file is opened in a with-block so the handle is always closed.

    :param path: path of the file to read.
    :return: list of lines including their line endings.
    '''
    with codecs.open(filename=path, mode='r', encoding='utf-8', errors='ignore') as handle:
        return handle.readlines()


if __name__ == '__main__':
    batch_size = 100  # batch size
    trainDataPath = './test_data/test_data_domain.txt'  # path of the raw input data file
    modelPath = 'dga_by_lstm_model-attend0712.h5'  # path the model file is saved to / loaded from
    resultPath = './test_data/test_data_domainAttention_result-0712.txt'

    # Read the character table; indices start at 1, 0 is the padding character.
    confFilePath = './conf/charList.txt'
    charList = build_char_index(read_lines(confFilePath))

    # Convert the input lines into sequences of character indices.
    # NOTE(review): lines of different length produce a ragged nested list;
    # recent NumPy releases refuse to build an array from that without
    # dtype=object. The call is kept as it was -- confirm what
    # LstmModel.predict expects before changing it.
    x_data_sum = np.array(encode_domains(read_lines(trainDataPath), charList))

    # LstmModel
    lstmModel = LstmModel()
    lstmModel.predict(x_data_sum, batch_size, modelPath, resultPath)