1- __precompile__ ()
2-
31module DecisionTree
42
53import Base: length, show, convert, promote_rule, zero
@@ -9,11 +7,11 @@ using Random
97using Statistics
108import AbstractTrees
119
12- export Leaf, Node, Ensemble, print_tree, depth, build_stump, build_tree,
10+ export Leaf, Node, Root, Ensemble, print_tree, depth, build_stump, build_tree,
1311 prune_tree, apply_tree, apply_tree_proba, nfoldCV_tree, build_forest,
1412 apply_forest, apply_forest_proba, nfoldCV_forest, build_adaboost_stumps,
1513 apply_adaboost_stumps, apply_adaboost_stumps_proba, nfoldCV_stumps,
16- majority_vote, ConfusionMatrix, confusion_matrix, mean_squared_error, R2, load_data
14+ load_data, impurity_importance, split_importance, permutation_importance
1715
1816# ScikitLearn API
1917export DecisionTreeClassifier, DecisionTreeRegressor, RandomForestClassifier,
4240
4341const LeafOrNode{S, T} = Union{Leaf{T}, Node{S, T}}
4442
43+ struct Root{S, T}
44+ node :: LeafOrNode{S, T}
45+ n_feat :: Int
46+ featim :: Vector{Float64} # impurity importance
47+ end
48+
4549struct Ensemble{S, T}
46- trees :: Vector{LeafOrNode{S, T}}
50+ trees :: Vector{LeafOrNode{S, T}}
51+ n_feat :: Int
52+ featim :: Vector{Float64}
4753end
4854
55+
4956is_leaf (l:: Leaf ) = true
5057is_leaf (n:: Node ) = false
5158
52- zero (String) = " "
59+ zero (:: Type{ String} ) = " "
5360convert (:: Type{Node{S, T}} , lf:: Leaf{T} ) where {S, T} = Node (0 , zero (S), lf, Leaf (zero (T), [zero (T)]))
61+ convert (:: Type{Root{S, T}} , node:: LeafOrNode{S, T} ) where {S, T} = Root {S, T} (node, 0 , Float64[])
62+ convert (:: Type{LeafOrNode{S, T}} , tree:: Root{S, T} ) where {S, T} = tree. node
5463promote_rule (:: Type{Node{S, T}} , :: Type{Leaf{T}} ) where {S, T} = Node{S, T}
5564promote_rule (:: Type{Leaf{T}} , :: Type{Node{S, T}} ) where {S, T} = Node{S, T}
65+ promote_rule (:: Type{Root{S, T}} , :: Type{Leaf{T}} ) where {S, T} = Root{S, T}
66+ promote_rule (:: Type{Leaf{T}} , :: Type{Root{S, T}} ) where {S, T} = Root{S, T}
67+ promote_rule (:: Type{Root{S, T}} , :: Type{Node{S, T}} ) where {S, T} = Root{S, T}
68+ promote_rule (:: Type{Node{S, T}} , :: Type{Root{S, T}} ) where {S, T} = Root{S, T}
5669
5770# make a Random Number Generator object
5871mk_rng (rng:: Random.AbstractRNG ) = rng
@@ -75,10 +88,12 @@ include("abstract_trees.jl")
7588
7689length (leaf:: Leaf ) = 1
7790length (tree:: Node ) = length (tree. left) + length (tree. right)
91+ length (tree:: Root ) = length (tree. node)
7892length (ensemble:: Ensemble ) = length (ensemble. trees)
7993
8094depth (leaf:: Leaf ) = 0
8195depth (tree:: Node ) = 1 + max (depth (tree. left), depth (tree. right))
96+ depth (tree:: Root ) = depth (tree. node)
8297
8398function print_tree (io:: IO , leaf:: Leaf , depth= - 1 , indent= 0 ; feature_names= nothing )
8499 n_matches = count (leaf. values .== leaf. majority)
@@ -90,6 +105,13 @@ function print_tree(leaf::Leaf, depth=-1, indent=0; feature_names=nothing)
90105end
91106
92107
108+ function print_tree (io:: IO , tree:: Root , depth= - 1 , indent= 0 ; sigdigits= 2 , feature_names= nothing )
109+ return print_tree (io, tree. node, depth, indent; sigdigits= sigdigits, feature_names= feature_names)
110+ end
111+ function print_tree (tree:: Root , depth= - 1 , indent= 0 ; sigdigits= 2 , feature_names= nothing )
112+ return print_tree (stdout , tree, depth, indent; sigdigits= sigdigits, feature_names= feature_names)
113+ end
114+
93115"""
94116 print_tree([io::IO,] tree::Node, depth=-1, indent=0; sigdigits=4, feature_names=nothing)
95117
@@ -113,9 +135,9 @@ Feature 3 < -28.15 ?
113135 └─ 8 : 1227/3508
114136```
115137
116- To facilitate visualisation of trees using third party packages, a `DecisionTree.Leaf` object or
117- `DecisionTree.Node` object can be wrapped to obtain a tree structure implementing the
118- AbstractTrees.jl interface. See [`wrap`](@ref)` for details.
138+ To facilitate visualisation of trees using third party packages, a `DecisionTree.Leaf` object,
139+ `DecisionTree.Node` object or `DecisionTree.Root` object can be wrapped to obtain a tree structure implementing the
140+ AbstractTrees.jl interface. See [`wrap`](@ref)` for details.
119141"""
120142function print_tree (io:: IO , tree:: Node , depth= - 1 , indent= 0 ; sigdigits= 2 , feature_names= nothing )
121143 if depth == indent
@@ -149,6 +171,12 @@ function show(io::IO, tree::Node)
149171 print (io, " Depth: $(depth (tree)) " )
150172end
151173
174+ function show (io:: IO , tree:: Root )
175+ println (io, " Decision Tree" )
176+ println (io, " Leaves: $(length (tree)) " )
177+ print (io, " Depth: $(depth (tree)) " )
178+ end
179+
152180function show (io:: IO , ensemble:: Ensemble )
153181 println (io, " Ensemble of Decision Trees" )
154182 println (io, " Trees: $(length (ensemble)) " )
0 commit comments