You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
feature_importance[tree.featid] -= (nt *loss(nc, nt) - nl *loss(ncl, nl) - nr *loss(ncr, nr)) / ntt
92
+
feature_importance[tree.featid] -=
93
+
(nt *loss(nc, nt) - nl *loss(ncl, nl) - nr *loss(ncr, nr)) / ntt
84
94
end
85
95
86
-
functionupdate_pruned_impurity!(tree::LeafOrNode{S, T}, feature_importance::Vector{Float64}, ntt::Int, loss::Function= mean_squared_error) where {S, T <:Float64}
96
+
functionupdate_pruned_impurity!(
97
+
tree::LeafOrNode{S, T},
98
+
feature_importance::Vector{Float64},
99
+
ntt::Int,
100
+
loss::Function= mean_squared_error
101
+
) where {S, T <:Float64}
87
102
μl =mean(tree.left.values)
88
103
nl =length(tree.left.values)
89
104
μr =mean(tree.right.values)
90
105
nr =length(tree.right.values)
91
106
nt = nl + nr
92
-
μt = (nl * μl + nr * μr) / nt
107
+
μt = (nl * μl + nr * μr) / nt
93
108
feature_importance[tree.featid] -= (nt *loss([tree.left.values; tree.right.values], repeat([μt], nt)) - nl *loss(tree.left.values, repeat([μl], nl)) - nr *loss(tree.right.values, repeat([μr], nr))) / ntt
Prune tree based on prediction accuracy of each node.
171
-
* `purity_thresh`: If the prediction accuracy of a stump is larger than this value, the node will be pruned and become a leaf.
172
-
* `loss`: The loss function for computing node impurity. Available function include `DecisionTree.util.entropy`, `DecisionTree.util.gini` and `mean_squared_error`. Defaults are `DecisionTree.util.entropy` and `mean_squared_error` for classification tree and regression tree, respectively. If the tree is not a `Root`, this argument does not affect the result.
173
192
174
-
For a tree of type `Root`, when any of its nodes is pruned, the `featim` field will be updated by recomputing the impurity decrease of that node divided by the total number of training observations and subtracting the value.
175
-
The computation of impurity decrease is based on node impurity calculated with the loss function provided as the argument `loss`. The algorithm is as same as that described in the `impurity_importance` documentation.
193
+
* `purity_thresh`: If the prediction accuracy of a stump is larger than this value, the node
194
+
will be pruned and become a leaf.
195
+
196
+
* `loss`: The loss function for computing node impurity. Available function include
197
+
`DecisionTree.util.entropy`, `DecisionTree.util.gini` and `mean_squared_error`. Defaults
198
+
are `DecisionTree.util.entropy` and `mean_squared_error` for classification tree and
199
+
regression tree, respectively. If the tree is not a `Root`, this argument does not affect
200
+
the result.
201
+
202
+
For a tree of type `Root`, when any of its nodes is pruned, the `featim` field will be
203
+
updated by recomputing the impurity decrease of that node divided by the total number of
204
+
training observations and subtracting the value. The computation of impurity decrease is
205
+
based on node impurity calculated with the loss function provided as the argument
206
+
`loss`. The algorithm is as same as that described in the `impurity_importance`
207
+
documentation.
208
+
176
209
This function will recurse until no stumps can be pruned.
177
210
178
211
Warn:
179
212
For regression trees, pruning trees based on accuracy may not be an appropriate method.
180
213
"""
181
-
functionprune_tree(tree::Union{Root{S, T}, LeafOrNode{S, T}}, purity_thresh=1.0, loss::Function= T <:Float64? mean_squared_error : util.entropy) where {S, T}
214
+
functionprune_tree(
215
+
tree::Union{Root{S, T}, LeafOrNode{S, T}},
216
+
purity_thresh=1.0,
217
+
loss::Function= T <:Float64? mean_squared_error : util.entropy
218
+
) where {S, T}
182
219
if purity_thresh >=1.0
183
220
return tree
184
221
end
185
222
ntt =nsample(tree)
186
-
function_prune_run_stump(tree::LeafOrNode{S, T}, purity_thresh::Real, fi::Vector{Float64}= Float64[]) where {S, T}
0 commit comments