
Commit 1b1a35b
some Blue style fixes
1 parent 3d0f314

File tree

3 files changed (+155 -58 lines)

src/classification/main.jl

Lines changed: 134 additions & 41 deletions
@@ -49,13 +49,17 @@ function _convert(
     end
 end
 
-function update_using_impurity!(feature_importance::Vector{Float64}, node::treeclassifier.NodeMeta{S}) where S
+function update_using_impurity!(
+    feature_importance::Vector{Float64},
+    node::treeclassifier.NodeMeta{S}
+) where S
     if !node.is_leaf
         update_using_impurity!(feature_importance, node.l)
         update_using_impurity!(feature_importance, node.r)
-        feature_importance[node.feature] += node.node_impurity - node.l.node_impurity - node.r.node_impurity
+        feature_importance[node.feature] +=
+            node.node_impurity - node.l.node_impurity - node.r.node_impurity
     end
-    return
+    return
 end
 
 nsample(leaf::Leaf) = length(leaf.values)
@@ -72,24 +76,35 @@ function votes_distribution(labels)
     votes
 end
 
-function update_pruned_impurity!(tree::LeafOrNode{S, T}, feature_importance::Vector{Float64}, ntt::Int, loss::Function = util.entropy) where {S, T}
+function update_pruned_impurity!(
+    tree::LeafOrNode{S, T},
+    feature_importance::Vector{Float64},
+    ntt::Int,
+    loss::Function = util.entropy
+) where {S, T}
     all_labels = [tree.left.values; tree.right.values]
     nc = votes_distribution(all_labels)
     nt = length(all_labels)
     ncl = votes_distribution(tree.left.values)
     nl = length(tree.left.values)
     ncr = votes_distribution(tree.right.values)
     nr = nt - nl
-    feature_importance[tree.featid] -= (nt * loss(nc, nt) - nl * loss(ncl, nl) - nr * loss(ncr, nr)) / ntt
+    feature_importance[tree.featid] -=
+        (nt * loss(nc, nt) - nl * loss(ncl, nl) - nr * loss(ncr, nr)) / ntt
 end
 
-function update_pruned_impurity!(tree::LeafOrNode{S, T}, feature_importance::Vector{Float64}, ntt::Int, loss::Function = mean_squared_error) where {S, T <: Float64}
+function update_pruned_impurity!(
+    tree::LeafOrNode{S, T},
+    feature_importance::Vector{Float64},
+    ntt::Int,
+    loss::Function = mean_squared_error
+) where {S, T <: Float64}
     μl = mean(tree.left.values)
     nl = length(tree.left.values)
     μr = mean(tree.right.values)
     nr = length(tree.right.values)
     nt = nl + nr
-    μt = (nl * μl + nr * μr) / nt
+    μt = (nl * μl + nr * μr) / nt
     feature_importance[tree.featid] -= (nt * loss([tree.left.values; tree.right.values], repeat([μt], nt)) - nl * loss(tree.left.values, repeat([μl], nl)) - nr * loss(tree.right.values, repeat([μr], nr))) / ntt
 end
 
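
For context, a standalone sketch of the weighted impurity decrease that `update_pruned_impurity!` subtracts in the classification case (the `entropy` helper below is a local stand-in for illustration, not `DecisionTree.util.entropy` itself):

    # Weighted impurity decrease of a stump: parent impurity minus the
    # impurities of its two children, all weighted by sample counts.
    entropy(counts, n) = n == 0 ? 0.0 : -sum(c / n * log(c / n) for c in counts if c > 0)

    ncl, nl = [8, 2], 10            # left child: vote counts and size
    ncr, nr = [1, 9], 10            # right child: vote counts and size
    nc, nt = ncl .+ ncr, nl + nr    # parent: vote counts and size

    decrease = nt * entropy(nc, nt) - nl * entropy(ncl, nl) - nr * entropy(ncr, nr)
    # update_pruned_impurity! divides this by ntt (the total number of training
    # observations) and subtracts it from feature_importance[tree.featid].
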
@@ -153,7 +168,13 @@ function build_tree(
     return _build_tree(t, labels, size(features, 2), size(features, 1), impurity_importance)
 end
 
-function _build_tree(tree::treeclassifier.Tree{S, T}, labels::AbstractVector{T}, n_features, n_samples, impurity_importance::Bool) where {S, T}
+function _build_tree(
+    tree::treeclassifier.Tree{S, T},
+    labels::AbstractVector{T},
+    n_features,
+    n_samples,
+    impurity_importance::Bool
+) where {S, T}
     node = _convert(tree.root, tree.list, labels[tree.labels])
     if !impurity_importance
         return Root{S, T}(node, n_features, Float64[])
@@ -168,22 +189,42 @@ end
     prune_tree(tree::Union{Root, LeafOrNode}, purity_thresh=1.0, loss::Function)
 
 Prune tree based on prediction accuracy of each node.
-* `purity_thresh`: If the prediction accuracy of a stump is larger than this value, the node will be pruned and become a leaf.
-* `loss`: The loss function for computing node impurity. Available function include `DecisionTree.util.entropy`, `DecisionTree.util.gini` and `mean_squared_error`. Defaults are `DecisionTree.util.entropy` and `mean_squared_error` for classification tree and regression tree, respectively. If the tree is not a `Root`, this argument does not affect the result.
 
-For a tree of type `Root`, when any of its nodes is pruned, the `featim` field will be updated by recomputing the impurity decrease of that node divided by the total number of training observations and subtracting the value.
-The computation of impurity decrease is based on node impurity calculated with the loss function provided as the argument `loss`. The algorithm is as same as that described in the `impurity_importance` documentation.
+* `purity_thresh`: If the prediction accuracy of a stump is larger than this value, the node
+  will be pruned and become a leaf.
+
+* `loss`: The loss function for computing node impurity. Available functions include
+  `DecisionTree.util.entropy`, `DecisionTree.util.gini` and `mean_squared_error`. Defaults
+  are `DecisionTree.util.entropy` and `mean_squared_error` for classification trees and
+  regression trees, respectively. If the tree is not a `Root`, this argument does not affect
+  the result.
+
+For a tree of type `Root`, when any of its nodes is pruned, the `featim` field will be
+updated by recomputing the impurity decrease of that node divided by the total number of
+training observations and subtracting the value. The computation of impurity decrease is
+based on node impurity calculated with the loss function provided as the argument
+`loss`. The algorithm is the same as that described in the `impurity_importance`
+documentation.
+
 This function will recurse until no stumps can be pruned.
 
 Warn:
 For regression trees, pruning trees based on accuracy may not be an appropriate method.
 """
-function prune_tree(tree::Union{Root{S, T}, LeafOrNode{S, T}}, purity_thresh=1.0, loss::Function = T <: Float64 ? mean_squared_error : util.entropy) where {S, T}
+function prune_tree(
+    tree::Union{Root{S, T}, LeafOrNode{S, T}},
+    purity_thresh=1.0,
+    loss::Function = T <: Float64 ? mean_squared_error : util.entropy
+) where {S, T}
     if purity_thresh >= 1.0
         return tree
     end
     ntt = nsample(tree)
-    function _prune_run_stump(tree::LeafOrNode{S, T}, purity_thresh::Real, fi::Vector{Float64} = Float64[]) where {S, T}
+    function _prune_run_stump(
+        tree::LeafOrNode{S, T},
+        purity_thresh::Real,
+        fi::Vector{Float64} = Float64[]
+    ) where {S, T}
         all_labels = [tree.left.values; tree.right.values]
         majority = majority_vote(all_labels)
         matches = findall(all_labels .== majority)
@@ -202,7 +243,11 @@ function prune_tree(tree::Union{Root{S, T}, LeafOrNode{S, T}}, purity_thresh=1.0
         node = _prune_run(tree.node, purity_thresh, fi)
         return Root{S, T}(node, tree.n_feat, fi)
     end
-    function _prune_run(tree::LeafOrNode{S, T}, purity_thresh::Real, fi::Vector{Float64} = Float64[]) where {S, T}
+    function _prune_run(
+        tree::LeafOrNode{S, T},
+        purity_thresh::Real,
+        fi::Vector{Float64} = Float64[]
+    ) where {S, T}
        N = length(tree)
        if N == 1 ## a Leaf
            return tree
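
For context, a minimal usage sketch of `prune_tree` as reformatted above (not part of this commit; the data setup uses the package's bundled `load_data` helper, and the 0.9 threshold is arbitrary):

    using DecisionTree

    features, labels = load_data("iris")
    features = float.(features)      # load_data returns Matrix{Any}
    labels = string.(labels)

    model = build_tree(labels, features)
    pruned = prune_tree(model, 0.9)  # collapse stumps whose majority accuracy exceeds 0.9
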
@@ -224,7 +269,10 @@ end
 
 
 apply_tree(leaf::Leaf, feature::AbstractVector) = leaf.majority
-apply_tree(tree::Root{S, T}, features::AbstractVector{S}) where {S, T} = apply_tree(tree.node, features)
+apply_tree(
+    tree::Root{S, T},
+    features::AbstractVector{S}
+) where {S, T} = apply_tree(tree.node, features)
 
 function apply_tree(tree::Node{S, T}, features::AbstractVector{S}) where {S, T}
     if tree.featid == 0
@@ -236,7 +284,10 @@ function apply_tree(tree::Node{S, T}, features::AbstractVector{S}) where {S, T}
     end
 end
 
-apply_tree(tree::Root{S, T}, features::AbstractMatrix{S}) where {S, T} = apply_tree(tree.node, features)
+apply_tree(
+    tree::Root{S, T},
+    features::AbstractMatrix{S}
+) where {S, T} = apply_tree(tree.node, features)
 function apply_tree(tree::LeafOrNode{S, T}, features::AbstractMatrix{S}) where {S, T}
     N = size(features,1)
     predictions = Array{T}(undef, N)
@@ -250,22 +301,26 @@ function apply_tree(tree::LeafOrNode{S, T}, features::AbstractMatrix{S}) where {
     end
 end
 
-"""
+"""
     apply_tree_proba(::Root, features, col_labels::AbstractVector)
 
 computes P(L=label|X) for each row in `features`. It returns a `N_row x
 n_labels` matrix of probabilities, each row summing up to 1.
 
 `col_labels` is a vector containing the distinct labels
 (eg. ["versicolor", "virginica", "setosa"]). It specifies the column ordering
-of the output matrix.
+of the output matrix.
 """
-apply_tree_proba(tree::Root{S, T}, features::AbstractVector{S}, labels) where {S, T} =
+apply_tree_proba(tree::Root{S, T}, features::AbstractVector{S}, labels) where {S, T} =
     apply_tree_proba(tree.node, features, labels)
 apply_tree_proba(leaf::Leaf{T}, features::AbstractVector{S}, labels) where {S, T} =
     compute_probabilities(labels, leaf.values)
 
-function apply_tree_proba(tree::Node{S, T}, features::AbstractVector{S}, labels) where {S, T}
+function apply_tree_proba(
+    tree::Node{S, T},
+    features::AbstractVector{S},
+    labels
+) where {S, T}
     if tree.featval === nothing
         return apply_tree_proba(tree.left, features, labels)
     elseif features[tree.featid] < tree.featval
@@ -274,7 +329,7 @@ function apply_tree_proba(tree::Node{S, T}, features::AbstractVector{S}, labels)
         return apply_tree_proba(tree.right, features, labels)
     end
 end
-apply_tree_proba(tree::Root{S, T}, features::AbstractMatrix{S}, labels) where {S, T} =
+apply_tree_proba(tree::Root{S, T}, features::AbstractMatrix{S}, labels) where {S, T} =
     apply_tree_proba(tree.node, features, labels)
 apply_tree_proba(tree::LeafOrNode{S, T}, features::AbstractMatrix{S}, labels) where {S, T} =
     stack_function_results(row->apply_tree_proba(tree, row, labels), features)
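
A usage sketch of the probability API documented above (not part of this commit; column order follows the label vector passed in):

    using DecisionTree

    features, labels = load_data("iris")
    features, labels = float.(features), string.(labels)
    model = build_tree(labels, features)

    probs = apply_tree_proba(model, features, unique(labels))
    # probs is an N_row x n_labels matrix; each row sums to 1
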
@@ -307,7 +362,9 @@ function build_forest(
     t_samples = length(labels)
     n_samples = floor(Int, partial_sampling * t_samples)
 
-    forest = impurity_importance ? Vector{Root{S, T}}(undef, n_trees) : Vector{LeafOrNode{S, T}}(undef, n_trees)
+    forest = impurity_importance ?
+        Vector{Root{S, T}}(undef, n_trees) :
+        Vector{LeafOrNode{S, T}}(undef, n_trees)
 
     entropy_terms = util.compute_entropy_terms(n_samples)
     loss = (ns, n) -> util.entropy(ns, n, entropy_terms)
@@ -316,8 +373,8 @@ function build_forest(
     Threads.@threads for i in 1:n_trees
         # The Mersenne Twister (Julia's default) is not thread-safe.
         _rng = copy(rng)
-        # Take some elements from the ring to have different states for each tree.
-        # This is the only way given that only a `copy` can be expected to exist for RNGs.
+        # Take some elements from the ring to have different states for each tree. This
+        # is the only way given that only a `copy` can be expected to exist for RNGs.
         rand(_rng, i)
         inds = rand(_rng, 1:t_samples, n_samples)
         forest[i] = build_tree(
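
The two comments in the hunk above describe the per-tree seeding trick: each task copies the shared RNG and advances the copy by a tree-specific number of draws. A standalone sketch of the same idea (illustrative only, four trees):

    using Random

    rng = MersenneTwister(42)
    streams = map(1:4) do i
        _rng = copy(rng)    # every task gets its own copy of the shared RNG
        rand(_rng, i)       # advance the copy by i draws so the states differ
        _rng
    end
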
@@ -353,7 +410,7 @@ function build_forest(
 end
 
 function _build_forest(
-    forest :: Vector{<: Union{Root{S, T}, LeafOrNode{S, T}}},
+    forest :: Vector{<: Union{Root{S, T}, LeafOrNode{S, T}}},
     n_features ,
     n_trees ,
     impurity_importance :: Bool) where {S, T}
@@ -401,22 +458,30 @@ function apply_forest(forest::Ensemble{S, T}, features::AbstractMatrix{S}) where
     return predictions
 end
 
-"""
+"""
     apply_forest_proba(forest::Ensemble, features, col_labels::AbstractVector)
 
 computes P(L=label|X) for each row in `features`. It returns a `N_row x
 n_labels` matrix of probabilities, each row summing up to 1.
 
 `col_labels` is a vector containing the distinct labels
 (eg. ["versicolor", "virginica", "setosa"]). It specifies the column ordering
-of the output matrix.
+of the output matrix.
 """
-function apply_forest_proba(forest::Ensemble{S, T}, features::AbstractVector{S}, labels) where {S, T}
+function apply_forest_proba(
+    forest::Ensemble{S, T},
+    features::AbstractVector{S},
+    labels
+) where {S, T}
     votes = [apply_tree(tree, features) for tree in forest.trees]
     return compute_probabilities(labels, votes)
 end
 
-apply_forest_proba(forest::Ensemble{S, T}, features::AbstractMatrix{S}, labels) where {S, T} =
+apply_forest_proba(
+    forest::Ensemble{S, T},
+    features::AbstractMatrix{S},
+    labels
+) where {S, T} =
     stack_function_results(row->apply_forest_proba(forest, row, labels),
                            features)
 
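
A usage sketch of the forest probability API above (not part of this commit; the hyperparameters, 2 subfeatures and 10 trees, are arbitrary):

    using DecisionTree

    features, labels = load_data("iris")
    features, labels = float.(features), string.(labels)

    forest = build_forest(labels, features, 2, 10)
    probs = apply_forest_proba(forest, features, unique(labels))
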
@@ -434,7 +499,13 @@ function build_adaboost_stumps(
     coeffs = Float64[]
     n_features = size(features, 2)
     for i in 1:n_iterations
-        new_stump = build_stump(labels, features, weights; rng=mk_rng(rng), impurity_importance=false)
+        new_stump = build_stump(
+            labels,
+            features,
+            weights;
+            rng=mk_rng(rng),
+            impurity_importance=false
+        )
         predictions = apply_tree(new_stump, features)
         err = _weighted_error(labels, predictions, weights)
         if err >= thresh # should be better than random guess
@@ -454,9 +525,16 @@ function build_adaboost_stumps(
     return (Ensemble{S, T}(stumps, n_features, Float64[]), coeffs)
 end
 
-apply_adaboost_stumps(trees::Tuple{<: Ensemble{S, T}, AbstractVector{Float64}}, features::AbstractVecOrMat{S}) where {S, T} = apply_adaboost_stumps(trees..., features)
+apply_adaboost_stumps(
+    trees::Tuple{<: Ensemble{S, T}, AbstractVector{Float64}},
+    features::AbstractVecOrMat{S}
+) where {S, T} = apply_adaboost_stumps(trees..., features)
 
-function apply_adaboost_stumps(stumps::Ensemble{S, T}, coeffs::AbstractVector{Float64}, features::AbstractVector{S}) where {S, T}
+function apply_adaboost_stumps(
+    stumps::Ensemble{S, T},
+    coeffs::AbstractVector{Float64},
+    features::AbstractVector{S}
+) where {S, T}
     n_stumps = length(stumps)
     counts = Dict()
     for i in 1:n_stumps
@@ -474,7 +552,11 @@ function apply_adaboost_stumps(stumps::Ensemble{S, T}, coeffs::AbstractVector{Fl
     return top_prediction
 end
 
-function apply_adaboost_stumps(stumps::Ensemble{S, T}, coeffs::AbstractVector{Float64}, features::AbstractMatrix{S}) where {S, T}
+function apply_adaboost_stumps(
+    stumps::Ensemble{S, T},
+    coeffs::AbstractVector{Float64},
+    features::AbstractMatrix{S}
+) where {S, T}
     n_samples = size(features, 1)
     predictions = Array{T}(undef, n_samples)
     for i in 1:n_samples
@@ -483,23 +565,34 @@ function apply_adaboost_stumps(stumps::Ensemble{S, T}, coeffs::AbstractVector{Fl
     return predictions
 end
 
-"""
+"""
     apply_adaboost_stumps_proba(stumps::Ensemble, coeffs, features, labels::AbstractVector)
 
 computes P(L=label|X) for each row in `features`. It returns a `N_row x
 n_labels` matrix of probabilities, each row summing up to 1.
 
 `col_labels` is a vector containing the distinct labels
 (eg. ["versicolor", "virginica", "setosa"]). It specifies the column ordering
-of the output matrix.
+of the output matrix.
 """
-function apply_adaboost_stumps_proba(stumps::Ensemble{S, T}, coeffs::AbstractVector{Float64},
-        features::AbstractVector{S}, labels::AbstractVector{T}) where {S, T}
+function apply_adaboost_stumps_proba(
+    stumps::Ensemble{S, T},
+    coeffs::AbstractVector{Float64},
+    features::AbstractVector{S},
+    labels::AbstractVector{T}
+) where {S, T}
     votes = [apply_tree(stump, features) for stump in stumps.trees]
     compute_probabilities(labels, votes, coeffs)
 end
 
-function apply_adaboost_stumps_proba(stumps::Ensemble{S, T}, coeffs::AbstractVector{Float64},
-        features::AbstractMatrix{S}, labels::AbstractVector{T}) where {S, T}
-    stack_function_results(row->apply_adaboost_stumps_proba(stumps, coeffs, row, labels), features)
+function apply_adaboost_stumps_proba(
+    stumps::Ensemble{S, T},
+    coeffs::AbstractVector{Float64},
+    features::AbstractMatrix{S},
+    labels::AbstractVector{T}
+) where {S, T}
+    stack_function_results(
+        row->apply_adaboost_stumps_proba(stumps, coeffs, row, labels),
+        features
+    )
 end
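
Finally, a usage sketch of the AdaBoost API reformatted above (not part of this commit; 10 boosting iterations is arbitrary):

    using DecisionTree

    features, labels = load_data("iris")
    features, labels = float.(features), string.(labels)

    stumps, coeffs = build_adaboost_stumps(labels, features, 10)
    preds = apply_adaboost_stumps(stumps, coeffs, features)
    probs = apply_adaboost_stumps_proba(stumps, coeffs, features, unique(labels))
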
