Skip to content

Commit 3d0f314

Browse files
committed
Feature importance
1 parent be7a715 commit 3d0f314

File tree

18 files changed

+840
-334
lines changed

18 files changed

+840
-334
lines changed

src/DecisionTree.jl

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,12 @@ using Random
99
using Statistics
1010
import AbstractTrees
1111

12-
export Leaf, Node, Ensemble, print_tree, depth, build_stump, build_tree,
12+
export Leaf, Node, Root, Ensemble, print_tree, depth, build_stump, build_tree,
1313
prune_tree, apply_tree, apply_tree_proba, nfoldCV_tree, build_forest,
1414
apply_forest, apply_forest_proba, nfoldCV_forest, build_adaboost_stumps,
1515
apply_adaboost_stumps, apply_adaboost_stumps_proba, nfoldCV_stumps,
16-
majority_vote, ConfusionMatrix, confusion_matrix, mean_squared_error, R2, load_data
16+
majority_vote, ConfusionMatrix, confusion_matrix, mean_squared_error, R2, load_data,
17+
impurity_importance, split_importance, permutation_importance, accuracy
1718

1819
# ScikitLearn API
1920
export DecisionTreeClassifier, DecisionTreeRegressor, RandomForestClassifier,
@@ -42,17 +43,32 @@ end
4243

4344
const LeafOrNode{S, T} = Union{Leaf{T}, Node{S, T}}
4445

46+
struct Root{S, T}
47+
node :: LeafOrNode{S, T}
48+
n_feat :: Int
49+
featim :: Vector{Float64} # impurity importance
50+
end
51+
4552
struct Ensemble{S, T}
46-
trees :: Vector{LeafOrNode{S, T}}
53+
trees :: Vector{LeafOrNode{S, T}}
54+
n_feat :: Int
55+
featim :: Vector{Float64}
4756
end
4857

58+
4959
is_leaf(l::Leaf) = true
5060
is_leaf(n::Node) = false
5161

52-
zero(String) = ""
62+
zero(::Type{String}) = ""
5363
convert(::Type{Node{S, T}}, lf::Leaf{T}) where {S, T} = Node(0, zero(S), lf, Leaf(zero(T), [zero(T)]))
64+
convert(::Type{Root{S, T}}, node::LeafOrNode{S, T}) where {S, T} = Root{S, T}(node, 0, Float64[])
65+
convert(::Type{LeafOrNode{S, T}}, tree::Root{S, T}) where {S, T} = tree.node
5466
promote_rule(::Type{Node{S, T}}, ::Type{Leaf{T}}) where {S, T} = Node{S, T}
5567
promote_rule(::Type{Leaf{T}}, ::Type{Node{S, T}}) where {S, T} = Node{S, T}
68+
promote_rule(::Type{Root{S, T}}, ::Type{Leaf{T}}) where {S, T} = Root{S, T}
69+
promote_rule(::Type{Leaf{T}}, ::Type{Root{S, T}}) where {S, T} = Root{S, T}
70+
promote_rule(::Type{Root{S, T}}, ::Type{Node{S, T}}) where {S, T} = Root{S, T}
71+
promote_rule(::Type{Node{S, T}}, ::Type{Root{S, T}}) where {S, T} = Root{S, T}
5672

5773
# make a Random Number Generator object
5874
mk_rng(rng::Random.AbstractRNG) = rng
@@ -75,10 +91,12 @@ include("abstract_trees.jl")
7591

7692
length(leaf::Leaf) = 1
7793
length(tree::Node) = length(tree.left) + length(tree.right)
94+
length(tree::Root) = length(tree.node)
7895
length(ensemble::Ensemble) = length(ensemble.trees)
7996

8097
depth(leaf::Leaf) = 0
8198
depth(tree::Node) = 1 + max(depth(tree.left), depth(tree.right))
99+
depth(tree::Root) = depth(tree.node)
82100

83101
function print_tree(io::IO, leaf::Leaf, depth=-1, indent=0; feature_names=nothing)
84102
n_matches = count(leaf.values .== leaf.majority)
@@ -90,6 +108,13 @@ function print_tree(leaf::Leaf, depth=-1, indent=0; feature_names=nothing)
90108
end
91109

92110

111+
function print_tree(io::IO, tree::Root, depth=-1, indent=0; sigdigits=2, feature_names=nothing)
112+
return print_tree(io, tree.node, depth, indent; sigdigits=sigdigits, feature_names=feature_names)
113+
end
114+
function print_tree(tree::Root, depth=-1, indent=0; sigdigits=2, feature_names=nothing)
115+
return print_tree(stdout, tree, depth, indent; sigdigits=sigdigits, feature_names=feature_names)
116+
end
117+
93118
"""
94119
print_tree([io::IO,] tree::Node, depth=-1, indent=0; sigdigits=4, feature_names=nothing)
95120
@@ -113,9 +138,9 @@ Feature 3 < -28.15 ?
113138
└─ 8 : 1227/3508
114139
```
115140
116-
To facilitate visualisation of trees using third party packages, a `DecisionTree.Leaf` object or
117-
`DecisionTree.Node` object can be wrapped to obtain a tree structure implementing the
118-
AbstractTrees.jl interface. See [`wrap`](@ref) for details.
141+
To facilitate visualisation of trees using third party packages, a `DecisionTree.Leaf` object,
142+
`DecisionTree.Node` object or `DecisionTree.Root` object can be wrapped to obtain a tree structure implementing the
143+
AbstractTrees.jl interface. See [`wrap`](@ref) for details.
119144
"""
120145
function print_tree(io::IO, tree::Node, depth=-1, indent=0; sigdigits=2, feature_names=nothing)
121146
if depth == indent
@@ -149,6 +174,12 @@ function show(io::IO, tree::Node)
149174
print(io, "Depth: $(depth(tree))")
150175
end
151176

177+
function show(io::IO, tree::Root)
178+
println(io, "Decision Tree")
179+
println(io, "Leaves: $(length(tree))")
180+
print(io, "Depth: $(depth(tree))")
181+
end
182+
152183
function show(io::IO, ensemble::Ensemble)
153184
println(io, "Ensemble of Decision Trees")
154185
println(io, "Trees: $(length(ensemble))")

src/abstract_trees.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ In the first case `dc` gets just wrapped, no information is added. No. 2 adds fe
6666
as well as class labels. In the last two cases either of these pieces of information is added (Note the
6767
trailing comma; it's needed to make it a tuple).
6868
"""
69+
wrap(tree::DecisionTree.Root, info::NamedTuple = NamedTuple()) = wrap(tree.node, info)
6970
wrap(node::DecisionTree.Node, info::NamedTuple = NamedTuple()) = InfoNode(node, info)
7071
wrap(leaf::DecisionTree.Leaf, info::NamedTuple = NamedTuple()) = InfoLeaf(leaf, info)
7172

0 commit comments

Comments
 (0)