@@ -9,11 +9,12 @@ using Random
99using Statistics
1010import AbstractTrees
1111
12- export Leaf, Node, Ensemble, print_tree, depth, build_stump, build_tree,
12+ export Leaf, Node, Root, Ensemble, print_tree, depth, build_stump, build_tree,
1313 prune_tree, apply_tree, apply_tree_proba, nfoldCV_tree, build_forest,
1414 apply_forest, apply_forest_proba, nfoldCV_forest, build_adaboost_stumps,
1515 apply_adaboost_stumps, apply_adaboost_stumps_proba, nfoldCV_stumps,
16- majority_vote, ConfusionMatrix, confusion_matrix, mean_squared_error, R2, load_data
16+ majority_vote, ConfusionMatrix, confusion_matrix, mean_squared_error, R2, load_data,
17+ impurity_importance, split_importance, permutation_importance, accuracy
1718
1819# ScikitLearn API
1920export DecisionTreeClassifier, DecisionTreeRegressor, RandomForestClassifier,
4243
4344const LeafOrNode{S, T} = Union{Leaf{T}, Node{S, T}}
4445
46+ struct Root{S, T}
47+ node :: LeafOrNode{S, T}
48+ n_feat :: Int
49+ featim :: Vector{Float64} # impurity importance
50+ end
51+
4552struct Ensemble{S, T}
46- trees :: Vector{LeafOrNode{S, T}}
53+ trees :: Vector{LeafOrNode{S, T}}
54+ n_feat :: Int
55+ featim :: Vector{Float64}
4756end
4857
58+
4959is_leaf (l:: Leaf ) = true
5060is_leaf (n:: Node ) = false
5161
52- zero (String) = " "
62+ zero (:: Type{ String} ) = " "
5363convert (:: Type{Node{S, T}} , lf:: Leaf{T} ) where {S, T} = Node (0 , zero (S), lf, Leaf (zero (T), [zero (T)]))
64+ convert (:: Type{Root{S, T}} , node:: LeafOrNode{S, T} ) where {S, T} = Root {S, T} (node, 0 , Float64[])
65+ convert (:: Type{LeafOrNode{S, T}} , tree:: Root{S, T} ) where {S, T} = tree. node
5466promote_rule (:: Type{Node{S, T}} , :: Type{Leaf{T}} ) where {S, T} = Node{S, T}
5567promote_rule (:: Type{Leaf{T}} , :: Type{Node{S, T}} ) where {S, T} = Node{S, T}
68+ promote_rule (:: Type{Root{S, T}} , :: Type{Leaf{T}} ) where {S, T} = Root{S, T}
69+ promote_rule (:: Type{Leaf{T}} , :: Type{Root{S, T}} ) where {S, T} = Root{S, T}
70+ promote_rule (:: Type{Root{S, T}} , :: Type{Node{S, T}} ) where {S, T} = Root{S, T}
71+ promote_rule (:: Type{Node{S, T}} , :: Type{Root{S, T}} ) where {S, T} = Root{S, T}
5672
5773# make a Random Number Generator object
5874mk_rng (rng:: Random.AbstractRNG ) = rng
@@ -75,10 +91,12 @@ include("abstract_trees.jl")
7591
7692length (leaf:: Leaf ) = 1
7793length (tree:: Node ) = length (tree. left) + length (tree. right)
94+ length (tree:: Root ) = length (tree. node)
7895length (ensemble:: Ensemble ) = length (ensemble. trees)
7996
8097depth (leaf:: Leaf ) = 0
8198depth (tree:: Node ) = 1 + max (depth (tree. left), depth (tree. right))
99+ depth (tree:: Root ) = depth (tree. node)
82100
83101function print_tree (io:: IO , leaf:: Leaf , depth= - 1 , indent= 0 ; feature_names= nothing )
84102 n_matches = count (leaf. values .== leaf. majority)
@@ -90,6 +108,13 @@ function print_tree(leaf::Leaf, depth=-1, indent=0; feature_names=nothing)
90108end
91109
92110
111+ function print_tree (io:: IO , tree:: Root , depth= - 1 , indent= 0 ; sigdigits= 2 , feature_names= nothing )
112+ return print_tree (io, tree. node, depth, indent; sigdigits= sigdigits, feature_names= feature_names)
113+ end
114+ function print_tree (tree:: Root , depth= - 1 , indent= 0 ; sigdigits= 2 , feature_names= nothing )
115+ return print_tree (stdout , tree, depth, indent; sigdigits= sigdigits, feature_names= feature_names)
116+ end
117+
93118"""
94119 print_tree([io::IO,] tree::Node, depth=-1, indent=0; sigdigits=4, feature_names=nothing)
95120
@@ -113,9 +138,9 @@ Feature 3 < -28.15 ?
113138 └─ 8 : 1227/3508
114139```
115140
116- To facilitate visualisation of trees using third party packages, a `DecisionTree.Leaf` object or
117- `DecisionTree.Node` object can be wrapped to obtain a tree structure implementing the
118- AbstractTrees.jl interface. See [`wrap`](@ref)` for details.
141+ To facilitate visualisation of trees using third party packages, a `DecisionTree.Leaf` object,
142+ `DecisionTree.Node` object or `DecisionTree.Root` object can be wrapped to obtain a tree structure implementing the
143+ AbstractTrees.jl interface. See [`wrap`](@ref)` for details.
119144"""
120145function print_tree (io:: IO , tree:: Node , depth= - 1 , indent= 0 ; sigdigits= 2 , feature_names= nothing )
121146 if depth == indent
@@ -149,6 +174,12 @@ function show(io::IO, tree::Node)
149174 print (io, " Depth: $(depth (tree)) " )
150175end
151176
177+ function show (io:: IO , tree:: Root )
178+ println (io, " Decision Tree" )
179+ println (io, " Leaves: $(length (tree)) " )
180+ print (io, " Depth: $(depth (tree)) " )
181+ end
182+
152183function show (io:: IO , ensemble:: Ensemble )
153184 println (io, " Ensemble of Decision Trees" )
154185 println (io, " Trees: $(length (ensemble)) " )
0 commit comments