-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdga_train.py
76 lines (67 loc) · 2.3 KB
/
dga_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#!/usr/bin/env python
# encoding: utf-8
'''
@author: caopeng
@license: (C) Copyright 2016-2020, Big Bird Corporation Limited.
@contact: [email protected]
@software: garner
@file: dga_train.py
@time: 2019/7/10 21:39
@desc: Train a character-level LSTM (optionally with attention) classifier on labelled DGA training data.
'''
from model.LstmModel import LstmModel
from model.LstmWithAttentionModel import LstmWithAttentionModel
import codecs
import numpy as np
import datetime
def _clean(line):
    """Strip surrounding newline, carriage-return and space characters.

    Mirrors the original chained ``strip`` calls exactly, so other
    whitespace (e.g. tabs) is deliberately left in place.
    """
    return line.strip('\n').strip('\r').strip(' ')


def _read_lines(path):
    """Return every line of a UTF-8 text file, line endings included.

    Undecodable bytes are dropped (``errors='ignore'``). The handle is closed
    before returning; previously both input files were left open.
    """
    with codecs.open(filename=path, mode='r', encoding='utf-8', errors='ignore') as handle:
        return handle.readlines()


def _build_char_index(lines):
    """Assign consecutive integer ids, starting at 1, to each non-blank line.

    Id 0 is reserved for the padding character (and is also used for
    characters missing from the index), which is why numbering starts at 1.

    Returns:
        ``(char_index, max_features)``: ``char_index`` maps each cleaned line
        to its id; ``max_features`` is one more than the largest id assigned,
        i.e. the vocabulary size including id 0.
    """
    char_index = {}
    next_id = 1
    for line in lines:
        token = _clean(line)
        if token:
            char_index[token] = next_id
            next_id += 1
    return char_index, next_id


def _encode(text, char_index):
    """Map each character of *text* to its id; unknown characters become 0."""
    encoded = []
    for char in text:
        char_id = char_index.get(char)
        if char_id is None:
            print('unexpected char' + ' : ' + char)
            char_id = 0
        encoded.append(char_id)
    return encoded


def _parse_samples(lines, char_index):
    """Parse ``"<text> <int label>"`` lines into encoded rows and labels.

    Blank lines are skipped. Fields are split on any run of whitespace, so
    repeated spaces or tabs between the fields are tolerated.
    Presumably ``<text>`` is a domain name — confirm against the data.

    Returns:
        ``(rows, labels)``: a list of per-sample id lists and a list of ints.

    Raises:
        ValueError: if a non-blank line has fewer than two fields or its
            label is not an integer.
    """
    rows = []
    labels = []
    for line_no, line in enumerate(lines, 1):
        fields = line.split()
        if not fields:
            continue
        if len(fields) < 2:
            raise ValueError(
                'line {}: expected "<text> <label>", got {!r}'.format(line_no, line))
        label = int(fields[1])
        rows.append(_encode(fields[0], char_index))
        labels.append(label)
    return rows, labels


def _to_array(rows):
    """Convert per-sample id lists to a NumPy array.

    Samples usually differ in length. NumPy >= 1.24 refuses to build a ragged
    array implicitly, so ragged input is stored as a 1-D ``object`` array of
    lists (what older NumPy produced with a deprecation warning). Rectangular
    input keeps NumPy's default dtype.
    """
    if len({len(row) for row in rows}) > 1:
        return np.array(rows, dtype=object)
    return np.array(rows)


if __name__ == '__main__':
    starttime = datetime.datetime.now()
    batch_size = 100  # mini-batch size
    epochs = 30  # number of training epochs
    trainDataPath = './train_data/train_data_01.txt'  # raw training data file
    modelPath = 'dga_by_lstm_model-attend0712.h5'  # where the model file is saved / loaded
    confFilePath = './conf/charList.txt'  # one vocabulary entry per line
    use_attention = True  # False trains the plain LstmModel instead

    # Build the character vocabulary from the config file (ids start at 1; 0 = padding).
    charList, max_features = _build_char_index(_read_lines(confFilePath))

    # Encode the training samples.
    x_rows, y_labels = _parse_samples(_read_lines(trainDataPath), charList)
    x_data_sum = _to_array(x_rows)
    y_data_sum = np.array(y_labels)

    # Both model classes expose the same trainModel(...) signature.
    model = LstmWithAttentionModel() if use_attention else LstmModel()
    model.trainModel(max_features, x_data_sum, y_data_sum, batch_size, epochs, modelPath)

    endtime = datetime.datetime.now()
    print('=== starttime : ', starttime)
    print('=== endtime : ', endtime)