
Commit 3d50fe6

beauty code
1 parent 75d31f7 commit 3d50fe6

File tree: 1 file changed, +34 −65 lines


decision_tree/decision_tree.py

+34 −65
@@ -2,22 +2,37 @@
 
 import cv2
 import time
-import math
+import logging
 import numpy as np
 import pandas as pd
 
 
 from sklearn.cross_validation import train_test_split
 from sklearn.metrics import accuracy_score
 
+
 total_class = 10
 
+def log(func):
+    def wrapper(*args, **kwargs):
+        start_time = time.time()
+        logging.debug('start %s()' % func.__name__)
+        ret = func(*args, **kwargs)
+
+        end_time = time.time()
+        logging.debug('end %s(), cost %s seconds' % (func.__name__,end_time-start_time))
+
+        return ret
+    return wrapper
+
+
 # Binarization
 def binaryzation(img):
     cv_img = img.astype(np.uint8)
     cv2.threshold(cv_img,50,1,cv2.cv.CV_THRESH_BINARY_INV,cv_img)
     return cv_img
 
+@log
 def binaryzation_features(trainset):
     features = []
 
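The new log decorator wraps a call in DEBUG-level start/end messages plus the elapsed wall-clock time. A minimal usage sketch; the decorator is the one defined above, while basicConfig and slow_add are illustrative additions and the exact timing in the output will vary:

import logging
import time

logging.basicConfig(level=logging.DEBUG)

def log(func):
    def wrapper(*args, **kwargs):
        start_time = time.time()
        logging.debug('start %s()' % func.__name__)
        ret = func(*args, **kwargs)

        end_time = time.time()
        logging.debug('end %s(), cost %s seconds' % (func.__name__,end_time-start_time))

        return ret
    return wrapper

@log
def slow_add(a, b):
    # stand-in for a step worth timing
    time.sleep(0.1)
    return a + b

slow_add(1, 2)
# DEBUG:root:start slow_add()
# DEBUG:root:end slow_add(), cost 0.100... seconds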
@@ -49,8 +64,6 @@ def predict(self,features):
         if self.node_type == 'leaf':
             return self.Class
 
-        print 'in'
-
         tree = self.dict[features[self.feature]]
         return tree.predict(features)
 
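Tree.predict descends recursively: a leaf returns its stored class, an internal node follows the child keyed by the value of its split feature. A minimal sketch of the node class implied by the attributes used here (node_type, Class, feature, dict, add_tree); the actual class in the file may differ in detail:

class Tree(object):
    def __init__(self, node_type, Class=None, feature=None):
        self.node_type = node_type   # 'leaf' or 'internal'
        self.dict = {}               # feature value -> child Tree
        self.Class = Class           # class stored at a leaf
        self.feature = feature       # index of the feature split on

    def add_tree(self, key, tree):
        self.dict[key] = tree

    def predict(self, features):
        if self.node_type == 'leaf':
            return self.Class
        # raises KeyError for a feature value never seen during training
        tree = self.dict[features[self.feature]]
        return tree.predict(features)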
@@ -94,34 +107,22 @@ def calc_ent_grap(x,y):
 
     return ent_grap
 
-def train(train_set,train_label,features,epsilon):
+def recurse_train(train_set,train_label,features,epsilon):
     global total_class
 
     LEAF = 'leaf'
     INTERNAL = 'internal'
 
-
     # Step 1: if all instances in train_set belong to the same class Ck
-    label_dict = [0 for i in xrange(total_class)]
-    for label in train_label:
-        label_dict[label] += 1
-
-    for label, label_count in enumerate(label_dict):
-        if label_count == len(train_label):
-            tree = Tree(LEAF,Class = label)
-            return tree
+    label_set = set(train_label)
+    if len(label_set) == 1:
+        return Tree(LEAF,Class = label_set.pop())
 
     # Step 2: if features is empty
-    max_len,max_class = 0,0
-    for i in xrange(total_class):
-        class_i = filter(lambda x:x==i,train_label)
-        if len(class_i) > max_len:
-            max_class = i
-            max_len = len(class_i)
+    (max_class,max_len) = max([(i,len(filter(lambda x:x==i,train_label))) for i in xrange(total_class)],key = lambda x:x[1])
 
     if len(features) == 0:
-        tree = Tree(LEAF,Class = max_class)
-        return tree
+        return Tree(LEAF,Class = max_class)
 
     # Step 3: compute the information gain
     max_feature = 0
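Step 1 now declares a leaf as soon as the label set has a single element, and step 2 picks the majority class with one max() over per-class counts. A small check of that logic on toy labels (Python 2, matching the file; train_label and the collections.Counter comparison are illustrative, not part of the file):

from collections import Counter

total_class = 10
train_label = [3, 3, 7, 3, 7, 7, 7]

# step 1: leaf test
print len(set(train_label)) == 1        # False, the labels are mixed

# step 2: majority class, written as in the hunk above
(max_class, max_len) = max([(i, len(filter(lambda x: x == i, train_label)))
                            for i in xrange(total_class)], key=lambda x: x[1])
print max_class, max_len                # 7 4

# same result with collections.Counter
print Counter(train_label).most_common(1)[0]   # (7, 4)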
@@ -138,8 +139,7 @@ def train(train_set,train_label,features,epsilon):
 
     # Step 4: the best gain is below the threshold
     if max_gda < epsilon:
-        tree = Tree(LEAF,Class = max_class)
-        return tree
+        return Tree(LEAF,Class = max_class)
 
     # Step 5: build the non-empty subsets
     sub_features = filter(lambda x:x!=max_feature,features)
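Step 4 pre-prunes: when the best score max_gda falls below epsilon, the node becomes a leaf labelled with the majority class. A standalone sketch of the information-gain quantity such a threshold is compared against, on toy data (Python 2; the entropy helper and variable names are illustrative, this is not the file's calc_ent_grap):

import math

def entropy(y):
    # empirical entropy H(y) in bits
    ent = 0.0
    for c in set(y):
        p = float(y.count(c)) / len(y)
        ent -= p * math.log(p, 2)
    return ent

labels  = [0, 0, 1, 1, 1, 1]
feature = [0, 0, 0, 1, 1, 1]     # one candidate feature A

h_d = entropy(labels)            # H(D)
h_d_a = 0.0                      # conditional entropy H(D|A)
for v in set(feature):
    sub = [labels[i] for i in xrange(len(labels)) if feature[i] == v]
    h_d_a += float(len(sub)) / len(labels) * entropy(sub)

gain = h_d - h_d_a               # information gain g(D,A) = H(D) - H(D|A)
print h_d, h_d_a, gain           # roughly 0.918 0.459 0.459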
@@ -157,11 +157,16 @@ def train(train_set,train_label,features,epsilon):
         sub_train_set = train_set[index]
         sub_train_label = train_label[index]
 
-        sub_tree = train(sub_train_set,sub_train_label,sub_features,epsilon)
+        sub_tree = recurse_train(sub_train_set,sub_train_label,sub_features,epsilon)
         tree.add_tree(feature_value,sub_tree)
 
     return tree
 
+@log
+def train(train_set,train_label,features,epsilon):
+    return recurse_train(train_set,train_label,features,epsilon)
+
+@log
 def predict(test_set,tree):
 
     result = []
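One plausible reason for keeping the recursion in recurse_train and exposing a thin @log-decorated train wrapper (an assumption, not stated in the commit): decorating the recursive function itself would log a start/end pair for every recursive call, while the wrapper logs a single pair per training run. A sketch with hypothetical functions, reusing the log decorator from above:

@log
def logged_recursion(n):
    # the decorated name is what the recursive call resolves to,
    # so every level of the recursion is logged
    return 0 if n == 0 else 1 + logged_recursion(n - 1)

def _recursion(n):
    return 0 if n == 0 else 1 + _recursion(n - 1)

@log
def logged_once(n):
    # only the outermost call goes through the decorator
    return _recursion(n)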
@@ -173,61 +178,25 @@ def predict(test_set,tree):
 
 
 if __name__ == '__main__':
-    # classes = [0,0,1,1,0,0,0,1,1,1,1,1,1,1,0]
-    #
-    # age = [0,0,0,0,0,1,1,1,1,1,2,2,2,2,2]
-    # occupation = [0,0,1,1,0,0,0,1,0,0,0,0,1,1,0]
-    # house = [0,0,0,1,0,0,0,1,1,1,1,1,0,0,0]
-    # loan = [0,1,1,0,0,0,1,1,2,2,2,1,1,2,0]
-    #
-    # features = []
-    #
-    # for i in range(15):
-    #     feature = [age[i],occupation[i],house[i],loan[i]]
-    #     features.append(feature)
-    #
-    # trainset = np.array(features)
-    #
-    # tree = train(trainset,np.array(classes),[0,1,2,3],0.1)
-    #
-    # print type(tree)
-    # features = [0,0,0,1]
-    # print tree.predict(np.array(features))
-
-
-    print 'Start read data'
-
-    time_1 = time.time()
+    logger = logging.getLogger()
+    logger.setLevel(logging.DEBUG)
 
     raw_data = pd.read_csv('../data/train.csv',header=0)
     data = raw_data.values
 
     imgs = data[0::,1::]
     labels = data[::,0]
 
+    # Binarize the images
     features = binaryzation_features(imgs)
 
     # Use 2/3 of the data as the training set and 1/3 as the test set
     train_features, test_features, train_labels, test_labels = train_test_split(features, labels, test_size=0.33, random_state=23323)
-    # print train_features.shape
-    # print train_features.shape
-
-    time_2 = time.time()
-    print 'read data cost ',time_2 - time_1,' second','\n'
 
-    print 'Start training'
-    tree = train(train_features,train_labels,[i for i in range(784)],0.2)
-    print type(tree)
-    print 'knn do not need to train'
-    time_3 = time.time()
-    print 'training cost ',time_3 - time_2,' second','\n'
-
-    print 'Start predicting'
+    tree = train(train_features,train_labels,[i for i in range(784)],0.1)
    test_predict = predict(test_features,tree)
-    time_4 = time.time()
-    print 'predicting cost ',time_4 - time_3,' second','\n'
-
     score = accuracy_score(test_labels,test_predict)
+
     print "The accruacy socre is ", score
 
 
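The __main__ block reads ../data/train.csv, presumably in the Kaggle digit-recognizer layout (an assumption suggested by the 784 = 28x28 feature indices): a label column followed by 784 pixel columns, so data[::,0] is the label vector and data[0::,1::] the image matrix. A toy illustration with a shortened pixel block:

import numpy as np

# label, pixel0, pixel1, pixel2, pixel3   (the real file has 784 pixel columns)
data = np.array([[5,   0,  13, 254,   0],
                 [0, 200,   0,   0,   9]])

labels = data[::, 0]     # array([5, 0])
imgs   = data[0::, 1::]  # shape (2, 4), one row per image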