
Commit 155d85e

Fix more test results
1 parent d178d40 commit 155d85e

5 files changed: +54 -38 lines changed

src/DecisionTree.jl

+10 -5

@@ -27,7 +27,7 @@ export InfoNode, InfoLeaf, wrap
 ########## Types ##########
 
 struct Leaf{T, N}
-    features :: NTuple{N, T}
+    classes  :: NTuple{N, T}
     majority :: Int
     values   :: NTuple{N, Int}
     total    :: Int
@@ -54,15 +54,20 @@ struct Ensemble{S, T, N}
     featim   :: Vector{Float64}
 end
 
-Leaf(features::NTuple{T, N}) where {T, N} =
+Leaf(features::NTuple{N, T}) where {T, N} =
     Leaf(features, 0, Tuple(zeros(T, N)), 0)
+Leaf(features::NTuple{N, T}, frequencies::NTuple{N, Int}) where {T, N} =
+    Leaf(features, argmax(frequencies), frequencies, sum(frequencies))
+Leaf(features::Union{<:AbstractVector, <:Tuple},
+     frequencies::Union{<:AbstractVector{Int}, <:Tuple}) =
+    Leaf(Tuple(features), Tuple(frequencies))
 
 is_leaf(l::Leaf) = true
 is_leaf(n::Node) = false
 
 _zero(::Type{String}) = ""
 _zero(x::Any) = zero(x)
-convert(::Type{Node{S, T, N}}, lf::Leaf{T, N}) where {S, T, N} = Node(0, _zero(S), lf, Leaf(lf.features))
+convert(::Type{Node{S, T, N}}, lf::Leaf{T, N}) where {S, T, N} = Node(0, _zero(S), lf, Leaf(lf.classes))
 convert(::Type{Root{S, T, N}}, node::LeafOrNode{S, T, N}) where {S, T, N} = Root{S, T, N}(node, 0, Float64[])
 convert(::Type{LeafOrNode{S, T, N}}, tree::Root{S, T, N}) where {S, T, N} = tree.node
 promote_rule(::Type{Node{S, T, N}}, ::Type{Leaf{T, N}}) where {S, T, N} = Node{S, T, N}
@@ -101,7 +106,7 @@ depth(tree::Node) = 1 + max(depth(tree.left), depth(tree.right))
 depth(tree::Root) = depth(tree.node)
 
 function print_tree(io::IO, leaf::Leaf, depth=-1, indent=0; sigdigits=4, feature_names=nothing)
-    println(io, leaf.features[leaf.majority], " : ",
+    println(io, leaf.classes[leaf.majority], " : ",
             leaf.values[leaf.majority], '/', leaf.total)
 end
 function print_tree(leaf::Leaf, depth=-1, indent=0; sigdigits=4, feature_names=nothing)
@@ -165,7 +170,7 @@ end
 
 function show(io::IO, leaf::Leaf)
     println(io, "Decision Leaf")
-    println(io, "Majority: ", leaf.features[leaf.majority])
+    println(io, "Majority: ", leaf.classes[leaf.majority])
     print(io, "Samples:  ", leaf.total)
 end
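A minimal usage sketch of the constructors added above (illustrative only, not part of the commit; the resulting field values follow directly from the definitions):

    using DecisionTree: Leaf

    Leaf((1, 2))            # empty leaf over classes (1, 2): majority = 0, values = (0, 0), total = 0
    Leaf((1, 2), (3, 7))    # counts given: majority = 2 (class 2), values = (3, 7), total = 10
    Leaf([1, 2], [3, 7])    # vectors are accepted and converted to tuples, same leaf as above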

src/classification/main.jl

+25 -18

@@ -41,9 +41,9 @@ function _convert(
 ) where {S, T}
 
     if node.is_leaf
-        featfreq = Tuple(sum(labels[node.region] .== l) for l in list)
+        classfreq = Tuple(sum(labels[node.region] .== l) for l in list)
         return Leaf{T, length(list)}(
-            Tuple(list), argmax(featfreq), featfreq, length(node.region))
+            Tuple(list), argmax(classfreq), classfreq, length(node.region))
     else
         left = _convert(node.l, list, labels)
         right = _convert(node.r, list, labels)
@@ -117,6 +117,7 @@ function build_stump(
         labels      :: AbstractVector{T},
         features    :: AbstractMatrix{S},
         weights      = nothing;
+        n_classes   :: Int = length(unique(labels)),
         rng          = Random.GLOBAL_RNG,
         impurity_importance :: Bool = true) where {S, T}
 
@@ -133,7 +134,7 @@
         min_purity_increase = 0.0;
         rng                 = rng)
 
-    return _build_tree(t, labels, size(features, 2), size(features, 1), impurity_importance)
+    return _build_tree(t, labels, n_classes, size(features, 2), size(features, 1), impurity_importance)
 end
 
 function build_tree(
@@ -144,6 +145,7 @@
         min_samples_leaf    = 1,
         min_samples_split   = 2,
         min_purity_increase = 0.0;
+        n_classes           :: Int = length(unique(labels)),
         loss                = util.entropy :: Function,
         rng                 = Random.GLOBAL_RNG,
         impurity_importance :: Bool = true) where {S, T}
@@ -168,18 +170,18 @@
         min_purity_increase = Float64(min_purity_increase),
         rng                 = rng)
 
-    return _build_tree(t, labels, size(features, 2), size(features, 1), impurity_importance)
+    return _build_tree(t, labels, n_classes, size(features, 2), size(features, 1), impurity_importance)
 end
 
 function _build_tree(
     tree::treeclassifier.Tree{S, T},
     labels::AbstractVector{T},
+    n_classes::Int,
     n_features,
     n_samples,
     impurity_importance::Bool
 ) where {S, T}
     node = _convert(tree.root, tree.list, labels[tree.labels])
-    n_classes = unique(labels) |> length
     if !impurity_importance
         return Root{S, T, n_classes}(node, n_features, Float64[])
     else
@@ -237,15 +239,15 @@ function prune_tree(
            if !isempty(fi)
                update_pruned_impurity!(tree, fi, ntt, loss)
            end
-            return Leaf{T, N}(tree.left.features, majority, combined, total)
+            return Leaf{T, N}(tree.left.classes, majority, combined, total)
        else
            return tree
        end
    end
    function _prune_run(tree::Root{S, T, N}, purity_thresh::Real) where {S, T, N}
        fi = deepcopy(tree.featim) ## recalculate feature importances
        node = _prune_run(tree.node, purity_thresh, fi)
-        return Root{S, T, N}(node, fi)
+        return Root{S, T, N}(node, tree.n_feat, fi)
    end
    function _prune_run(
        tree::LeafOrNode{S, T, N},
@@ -273,7 +275,7 @@ function prune_tree(
 end
 
 
-apply_tree(leaf::Leaf, feature::AbstractVector) = leaf.features[leaf.majority]
+apply_tree(leaf::Leaf, feature::AbstractVector) = leaf.classes[leaf.majority]
 apply_tree(
     tree::Root{S, T},
     features::AbstractVector{S}
@@ -369,10 +371,11 @@ function build_forest(
 
     t_samples = length(labels)
     n_samples = floor(Int, partial_sampling * t_samples)
+    n_classes = length(unique(labels))
 
     forest = impurity_importance ?
-        Vector{Root{S, T}}(undef, n_trees) :
-        Vector{LeafOrNode{S, T}}(undef, n_trees)
+        Vector{Root{S, T, n_classes}}(undef, n_trees) :
+        Vector{LeafOrNode{S, T, n_classes}}(undef, n_trees)
 
     entropy_terms = util.compute_entropy_terms(n_samples)
     loss = (ns, n) -> util.entropy(ns, n, entropy_terms)
@@ -392,7 +395,8 @@
                 max_depth,
                 min_samples_leaf,
                 min_samples_split,
-                min_purity_increase,
+                min_purity_increase;
+                n_classes,
                 loss = loss,
                 rng = _rng,
                 impurity_importance = impurity_importance)
@@ -408,7 +412,8 @@
                 max_depth,
                 min_samples_leaf,
                 min_samples_split,
-                min_purity_increase,
+                min_purity_increase;
+                n_classes,
                 loss = loss,
                 impurity_importance = impurity_importance)
         end
@@ -418,13 +423,13 @@
 end
 
 function _build_forest(
-        forest              :: Vector{<: Union{Root{S, T}, LeafOrNode{S, T}}},
+        forest              :: Vector{<: Union{Root{S, T, N}, LeafOrNode{S, T, N}}},
         n_features          ,
         n_trees             ,
-        impurity_importance :: Bool) where {S, T}
+        impurity_importance :: Bool) where {S, T, N}
 
     if !impurity_importance
-        return Ensemble{S, T}(forest, n_features, Float64[])
+        return Ensemble{S, T, N}(forest, n_features, Float64[])
     else
         fi = zeros(Float64, n_features)
         for tree in forest
@@ -434,12 +439,12 @@
             end
         end
 
-        forest_new = Vector{LeafOrNode{S, T}}(undef, n_trees)
+        forest_new = Vector{LeafOrNode{S, T, N}}(undef, n_trees)
         Threads.@threads for i in 1:n_trees
             forest_new[i] = forest[i].node
         end
 
-        return Ensemble{S, T}(forest_new, n_features, fi ./ n_trees)
+        return Ensemble{S, T, N}(forest_new, n_features, fi ./ n_trees)
     end
 end
 
@@ -516,11 +521,13 @@ function build_adaboost_stumps(
     stumps = Node{S, T}[]
     coeffs = Float64[]
     n_features = size(features, 2)
+    n_classes = length(unique(labels))
     for i in 1:n_iterations
         new_stump = build_stump(
             labels,
             features,
             weights;
+            n_classes,
             rng=mk_rng(rng),
             impurity_importance=false
         )
@@ -540,7 +547,7 @@
             break
         end
     end
-    return (Ensemble{S, T}(stumps, n_features, Float64[]), coeffs)
+    return (Ensemble{S, T, n_classes}(stumps, n_features, Float64[]), coeffs)
 end
 
 apply_adaboost_stumps(
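A rough usage sketch of the classification API after this change (made-up data, not from the commit): the class count is now carried in the tree and ensemble types.

    using DecisionTree

    features = rand(100, 3)                    # S = Float64
    labels   = rand(["a", "b", "c"], 100)      # T = String
    tree     = build_tree(labels, features)    # Root{Float64, String, 3}, assuming all three labels occur
    forest   = build_forest(labels, features)  # Ensemble{Float64, String, 3}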

src/regression/main.jl

+13 -9

@@ -1,15 +1,15 @@
 include("tree.jl")
 
 function _convert(node::treeregressor.NodeMeta{S}, labels::Array{T}) where {S, T <: Float64}
+    classes = Tuple(unique(labels))
     if node.is_leaf
-        features = Tuple(unique(labels))
-        featfreq = Tuple(sum(labels[node.region] .== f) for f in features)
-        return Leaf{T, length(features)}(
-            features, argmax(featfreq), featfreq, length(node.region))
+        classfreq = Tuple(sum(labels[node.region] .== f) for f in classes)
+        return Leaf{T, length(classes)}(
+            classes, argmax(classfreq), classfreq, length(node.region))
     else
         left = _convert(node.l, labels)
         right = _convert(node.r, labels)
-        return Node{S, T}(node.feature, node.threshold, left, right)
+        return Node{S, T, length(classes)}(node.feature, node.threshold, left, right)
     end
 end
 
@@ -34,6 +34,7 @@ function build_tree(
        min_samples_leaf    = 5,
        min_samples_split   = 2,
        min_purity_increase = 0.0;
+        n_classes           :: Int = length(unique(labels)),
        rng                 = Random.GLOBAL_RNG,
        impurity_importance :: Bool = true) where {S, T <: Float64}
 
@@ -59,11 +60,11 @@
    node = _convert(t.root, labels[t.labels])
    n_features = size(features, 2)
    if !impurity_importance
-        return Root{S, T}(node, n_features, Float64[])
+        return Root{S, T, n_classes}(node, n_features, Float64[])
    else
        fi = zeros(Float64, n_features)
        update_using_impurity!(fi, t.root)
-        return Root{S, T}(node, n_features, fi ./ size(features, 1))
+        return Root{S, T, n_classes}(node, n_features, fi ./ size(features, 1))
    end
 end
 
@@ -77,6 +78,7 @@ function build_forest(
        min_samples_leaf    = 5,
        min_samples_split   = 2,
        min_purity_increase = 0.0;
+        n_classes           :: Int = length(unique(labels)),
        rng::Union{Integer,AbstractRNG} = Random.GLOBAL_RNG,
        impurity_importance :: Bool = true) where {S, T <: Float64}
 
@@ -112,7 +114,8 @@
                max_depth,
                min_samples_leaf,
                min_samples_split,
-                min_purity_increase,
+                min_purity_increase;
+                n_classes,
                rng = _rng,
                impurity_importance = impurity_importance)
        end
@@ -127,7 +130,8 @@
                max_depth,
                min_samples_leaf,
                min_samples_split,
-                min_purity_increase,
+                min_purity_increase;
+                n_classes,
                impurity_importance = impurity_importance)
        end
    end
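Similarly for regression (illustrative sketch, not from the commit): here N counts the distinct target values observed in the labels, since n_classes defaults to length(unique(labels)).

    using DecisionTree

    features = rand(50, 2)
    targets  = Float64.(rand(1:5, 50))        # Float64 targets with at most 5 distinct values
    model    = build_tree(targets, features)  # Root{Float64, Float64, N} with N = length(unique(targets))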

test/miscellaneous/abstract_trees_test.jl

+4 -4

@@ -17,9 +17,9 @@ clabel_pattern(clabel) = "─ " * clabel * " (" # class labels are embedde
 check_occurence(str_tree, pool, pattern) = count(map(elem -> occursin(pattern(elem), str_tree), pool)) == length(pool)
 
 @info("Test base functionality")
-l1 = Leaf(1, [1,1,2])
-l2 = Leaf(2, [1,2,2])
-l3 = Leaf(3, [3,3,1])
+l1 = Leaf((1,2,3), 1, (2, 1, 0), 3)
+l2 = Leaf((1,2,3), 2, (1, 2, 0), 3)
+l3 = Leaf((1,2,3), 3, (1, 0, 2), 3)
 n2 = Node(2, 0.5, l2, l3)
 n1 = Node(1, 0.7, l1, n2)
 feature_names = ["firstFt", "secondFt"]
@@ -81,4 +81,4 @@ end
 traverse_tree(leaf::InfoLeaf) = nothing
 
 traverse_tree(wrapped_tree)
-end
+end
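For orientation (not part of the test file): with the print_tree method updated in src/DecisionTree.jl above, a leaf renders as its majority class followed by that class's count and the sample total.

    print_tree(l1)   # expected to print "1 : 2/3" (class 1, holding 2 of the leaf's 3 samples)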

test/miscellaneous/convert.jl

+2 -2

@@ -2,7 +2,7 @@
 
 @testset "convert.jl" begin
 
-lf = Leaf(1, [1])
+lf = Leaf([1], [1])
 nv = Node{Int, Int}[]
 rv = Root{Int, Int}[]
 push!(nv, lf)
@@ -22,7 +22,7 @@ push!(rv, nv[1])
 @test apply_tree(rv[1], [0]) == 1.0
 @test apply_tree(rv[2], [0]) == 1.0
 
-lf = Leaf("A", ["B", "A"])
+lf = Leaf(["A", "B"], [2, 1])
 nv = Node{Int, String}[]
 rv = Root{Int, String}[]
 push!(nv, lf)
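For reference (not part of the commit), the two-argument constructor used in this test infers the remaining Leaf fields from the class frequencies:

    Leaf([1], [1])            # Leaf((1,), 1, (1,), 1): one class, one sample
    Leaf(["A", "B"], [2, 1])  # Leaf(("A", "B"), 1, (2, 1), 3): majority "A", three samples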
