@@ -41,9 +41,9 @@ function _convert(
41
41
) where {S, T}
42
42
43
43
if node. is_leaf
44
- featfreq = Tuple (sum (labels[node. region] .== l) for l in list)
44
+ classfreq = Tuple (sum (labels[node. region] .== l) for l in list)
45
45
return Leaf {T, length(list)} (
46
- Tuple (list), argmax (featfreq ), featfreq , length (node. region))
46
+ Tuple (list), argmax (classfreq ), classfreq , length (node. region))
47
47
else
48
48
left = _convert (node. l, list, labels)
49
49
right = _convert (node. r, list, labels)
@@ -117,6 +117,7 @@ function build_stump(
117
117
labels :: AbstractVector{T} ,
118
118
features :: AbstractMatrix{S} ,
119
119
weights = nothing ;
120
+ n_classes :: Int = length (unique (labels)),
120
121
rng = Random. GLOBAL_RNG,
121
122
impurity_importance :: Bool = true ) where {S, T}
122
123
@@ -133,7 +134,7 @@ function build_stump(
133
134
min_purity_increase = 0.0 ;
134
135
rng = rng)
135
136
136
- return _build_tree (t, labels, size (features, 2 ), size (features, 1 ), impurity_importance)
137
+ return _build_tree (t, labels, n_classes, size (features, 2 ), size (features, 1 ), impurity_importance)
137
138
end
138
139
139
140
function build_tree (
@@ -144,6 +145,7 @@ function build_tree(
144
145
min_samples_leaf = 1 ,
145
146
min_samples_split = 2 ,
146
147
min_purity_increase = 0.0 ;
148
+ n_classes :: Int = length (unique (labels)),
147
149
loss = util. entropy :: Function ,
148
150
rng = Random. GLOBAL_RNG,
149
151
impurity_importance :: Bool = true ) where {S, T}
@@ -168,18 +170,18 @@ function build_tree(
168
170
min_purity_increase = Float64 (min_purity_increase),
169
171
rng = rng)
170
172
171
- return _build_tree (t, labels, size (features, 2 ), size (features, 1 ), impurity_importance)
173
+ return _build_tree (t, labels, n_classes, size (features, 2 ), size (features, 1 ), impurity_importance)
172
174
end
173
175
174
176
function _build_tree (
175
177
tree:: treeclassifier.Tree{S, T} ,
176
178
labels:: AbstractVector{T} ,
179
+ n_classes:: Int ,
177
180
n_features,
178
181
n_samples,
179
182
impurity_importance:: Bool
180
183
) where {S, T}
181
184
node = _convert (tree. root, tree. list, labels[tree. labels])
182
- n_classes = unique (labels) |> length
183
185
if ! impurity_importance
184
186
return Root {S, T, n_classes} (node, n_features, Float64[])
185
187
else
@@ -237,15 +239,15 @@ function prune_tree(
237
239
if ! isempty (fi)
238
240
update_pruned_impurity! (tree, fi, ntt, loss)
239
241
end
240
- return Leaf {T, N} (tree. left. features , majority, combined, total)
242
+ return Leaf {T, N} (tree. left. classes , majority, combined, total)
241
243
else
242
244
return tree
243
245
end
244
246
end
245
247
function _prune_run (tree:: Root{S, T, N} , purity_thresh:: Real ) where {S, T, N}
246
248
fi = deepcopy (tree. featim) # # recalculate feature importances
247
249
node = _prune_run (tree. node, purity_thresh, fi)
248
- return Root {S, T, N} (node, fi)
250
+ return Root {S, T, N} (node, tree . n_feat, fi)
249
251
end
250
252
function _prune_run (
251
253
tree:: LeafOrNode{S, T, N} ,
@@ -273,7 +275,7 @@ function prune_tree(
273
275
end
274
276
275
277
276
- apply_tree (leaf:: Leaf , feature:: AbstractVector ) = leaf. features [leaf. majority]
278
+ apply_tree (leaf:: Leaf , feature:: AbstractVector ) = leaf. classes [leaf. majority]
277
279
apply_tree (
278
280
tree:: Root{S, T} ,
279
281
features:: AbstractVector{S}
@@ -369,10 +371,11 @@ function build_forest(
369
371
370
372
t_samples = length (labels)
371
373
n_samples = floor (Int, partial_sampling * t_samples)
374
+ n_classes = length (unique (labels))
372
375
373
376
forest = impurity_importance ?
374
- Vector {Root{S, T}} (undef, n_trees) :
375
- Vector {LeafOrNode{S, T}} (undef, n_trees)
377
+ Vector {Root{S, T, n_classes }} (undef, n_trees) :
378
+ Vector {LeafOrNode{S, T, n_classes }} (undef, n_trees)
376
379
377
380
entropy_terms = util. compute_entropy_terms (n_samples)
378
381
loss = (ns, n) -> util. entropy (ns, n, entropy_terms)
@@ -392,7 +395,8 @@ function build_forest(
392
395
max_depth,
393
396
min_samples_leaf,
394
397
min_samples_split,
395
- min_purity_increase,
398
+ min_purity_increase;
399
+ n_classes,
396
400
loss = loss,
397
401
rng = _rng,
398
402
impurity_importance = impurity_importance)
@@ -408,7 +412,8 @@ function build_forest(
408
412
max_depth,
409
413
min_samples_leaf,
410
414
min_samples_split,
411
- min_purity_increase,
415
+ min_purity_increase;
416
+ n_classes,
412
417
loss = loss,
413
418
impurity_importance = impurity_importance)
414
419
end
@@ -418,13 +423,13 @@ function build_forest(
418
423
end
419
424
420
425
function _build_forest (
421
- forest :: Vector{<: Union{Root{S, T}, LeafOrNode{S, T}}} ,
426
+ forest :: Vector{<: Union{Root{S, T, N }, LeafOrNode{S, T, N }}} ,
422
427
n_features ,
423
428
n_trees ,
424
- impurity_importance :: Bool ) where {S, T}
429
+ impurity_importance :: Bool ) where {S, T, N }
425
430
426
431
if ! impurity_importance
427
- return Ensemble {S, T} (forest, n_features, Float64[])
432
+ return Ensemble {S, T, N } (forest, n_features, Float64[])
428
433
else
429
434
fi = zeros (Float64, n_features)
430
435
for tree in forest
@@ -434,12 +439,12 @@ function _build_forest(
434
439
end
435
440
end
436
441
437
- forest_new = Vector {LeafOrNode{S, T}} (undef, n_trees)
442
+ forest_new = Vector {LeafOrNode{S, T, N }} (undef, n_trees)
438
443
Threads. @threads for i in 1 : n_trees
439
444
forest_new[i] = forest[i]. node
440
445
end
441
446
442
- return Ensemble {S, T} (forest_new, n_features, fi ./ n_trees)
447
+ return Ensemble {S, T, N } (forest_new, n_features, fi ./ n_trees)
443
448
end
444
449
end
445
450
@@ -516,11 +521,13 @@ function build_adaboost_stumps(
516
521
stumps = Node{S, T}[]
517
522
coeffs = Float64[]
518
523
n_features = size (features, 2 )
524
+ n_classes = length (unique (labels))
519
525
for i in 1 : n_iterations
520
526
new_stump = build_stump (
521
527
labels,
522
528
features,
523
529
weights;
530
+ n_classes,
524
531
rng= mk_rng (rng),
525
532
impurity_importance= false
526
533
)
@@ -540,7 +547,7 @@ function build_adaboost_stumps(
540
547
break
541
548
end
542
549
end
543
- return (Ensemble {S, T} (stumps, n_features, Float64[]), coeffs)
550
+ return (Ensemble {S, T, n_classes } (stumps, n_features, Float64[]), coeffs)
544
551
end
545
552
546
553
apply_adaboost_stumps (
0 commit comments