Commit 9796e9d

Merge pull request #161 from JuliaAI/dev
For a 0.10.12 release
2 parents: 0a6097e + 9165f12

10 files changed: 292 additions & 68 deletions

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -2,3 +2,4 @@
 styleguide.txt
 makefile
 .DS_Store
+Manifest.toml

Manifest.toml

Lines changed: 0 additions & 61 deletions
This file was deleted.

Project.toml

Lines changed: 2 additions & 1 deletion
@@ -2,9 +2,10 @@ name = "DecisionTree"
 uuid = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
 license = "MIT"
 desc = "Julia implementation of Decision Tree (CART) and Random Forest algorithms"
-version = "0.10.11"
+version = "0.10.12"
 
 [deps]
+AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
 DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
 Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

README.md

Lines changed: 38 additions & 2 deletions
@@ -1,4 +1,4 @@
-# DecisionTree.jl
+# DecisionTree.jl
 
 [![CI](https://github.com/JuliaAI/DecisionTree.jl/workflows/CI/badge.svg)](https://github.com/JuliaAI/DecisionTree.jl/actions?query=workflow%3ACI)
 [![Codecov](https://codecov.io/gh/JuliaAI/DecisionTree.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/JuliaAI/DecisionTree.jl)
@@ -12,7 +12,7 @@ the [JuliaAI](https://github.com/JuliaAI) organization.
 Available via:
 * [AutoMLPipeline.jl](https://github.com/IBM/AutoMLPipeline.jl) - create complex ML pipeline structures using simple expressions
 * [CombineML.jl](https://github.com/ppalmes/CombineML.jl) - a heterogeneous ensemble learning package
-* [MLJ.jl](https://github.com/alan-turing-institute/MLJ.jl) - a machine learning framework for Julia
+* [MLJ.jl](https://alan-turing-institute.github.io/MLJ.jl/dev/) - a machine learning framework for Julia
 * [ScikitLearn.jl](https://github.com/cstjean/ScikitLearn.jl) - Julia implementation of the scikit-learn API
 
 ## Classification
@@ -285,10 +285,46 @@ r2 = nfoldCV_forest(labels, features,
                     rng = seed)
 ```
 
+## MLJ.jl API
+
+To use DecisionTree.jl models in
+[MLJ](https://alan-turing-institute.github.io/MLJ.jl/dev/), first
+ensure MLJ.jl and MLJDecisionTreeInterface.jl are both in your Julia
+environment. For example, to install in a fresh environment:
+
+```julia
+using Pkg
+Pkg.activate("my_fresh_mlj_environment", shared=true)
+Pkg.add("MLJ")
+Pkg.add("MLJDecisionTreeInterface")
+```
+
+Detailed usage instructions are available for each model using the
+`doc` method. For example:
+
+```julia
+using MLJ
+doc("DecisionTreeClassifier", pkg="DecisionTree")
+```
+
+Available models are: `AdaBoostStumpClassifier`,
+`DecisionTreeClassifier`, `DecisionTreeRegressor`,
+`RandomForestClassifier`, `RandomForestRegressor`.
+
+
 ## Saving Models
 Models can be saved to disk and loaded back with the use of the [JLD2.jl](https://github.com/JuliaIO/JLD2.jl) package.
 ```julia
 using JLD2
 @save "model_file.jld2" model
 ```
 Note that even though features and labels of type `Array{Any}` are supported, it is highly recommended that data be cast to explicit types (ie with `float.(), string.()`, etc). This significantly improves model training and prediction execution times, and also drastically reduces the size of saved models.
+
+## Visualization
+A `DecisionTree` model can be visualized using the `print_tree` function of its native interface
+(for an example see above in section 'Classification Example').
+
+In addition, an abstraction layer using `AbstractTrees.jl` has been implemented with the intention to facilitate visualizations which don't rely on any implementation details of `DecisionTree`. For more information have a look at the docs in `src/abstract_trees.jl` and the [`wrap`](@ref) function, which creates this layer for a `DecisionTree` model.
+
+Apart from this, `AbstractTrees.jl` brings its own implementation of `print_tree`.
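To complement the `MLJ.jl API` section added in the README above, here is a minimal, hedged sketch of loading and fitting one of the listed models through MLJ. The `@load_iris` data call and the `max_depth` hyperparameter are illustrative assumptions, not part of this commit:

```julia
using MLJ

# Load the model type registered by MLJDecisionTreeInterface (assumed to be installed).
Tree = @load DecisionTreeClassifier pkg=DecisionTree

# Illustrative data; any Tables.jl-compatible table X and categorical vector y would do.
X, y = @load_iris

# Wrap model and data in a machine, train, and obtain point predictions.
mach = machine(Tree(max_depth=3), X, y)
fit!(mach)
yhat = predict_mode(mach, X)
```

Since `DecisionTreeClassifier` is a probabilistic model in MLJ, calling `predict(mach, X)` instead should return class distributions rather than point predictions.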

src/DecisionTree.jl

Lines changed: 32 additions & 0 deletions
@@ -7,6 +7,7 @@ using DelimitedFiles
 using LinearAlgebra
 using Random
 using Statistics
+import AbstractTrees
 
 export Leaf, Node, Ensemble, print_tree, depth, build_stump, build_tree,
     prune_tree, apply_tree, apply_tree_proba, nfoldCV_tree, build_forest,
@@ -22,6 +23,7 @@ export DecisionTreeClassifier, DecisionTreeRegressor, RandomForestClassifier,
     # `using ScikitLearnBase`.
     predict, predict_proba, fit!, get_classes
 
+export InfoNode, InfoLeaf, wrap
 
 ###########################
 ########## Types ##########
@@ -65,6 +67,7 @@ include("util.jl")
 include("classification/main.jl")
 include("regression/main.jl")
 include("scikitlearnAPI.jl")
+include("abstract_trees.jl")
 
 
 #############################
@@ -83,6 +86,35 @@ function print_tree(leaf::Leaf, depth=-1, indent=0; feature_names=nothing)
     println("$(leaf.majority) : $(ratio)")
 end
 
+"""
+    print_tree(tree::Node, depth=-1, indent=0; feature_names=nothing)
+
+Print a textual visualization of the given decision tree `tree`.
+In the example output below, the top node considers whether
+"Feature 3" is above or below the threshold -28.156052806422238.
+If the value of "Feature 3" is strictly below the threshold for some input to be classified,
+we move to the `L->` part underneath, which is a node
+looking at whether "Feature 2" is above or below -161.04351901384842.
+If the value of "Feature 2" is strictly below the threshold for some input to be classified,
+we end up at `L-> 5 : 842/3650`. This is to be read as: "In the left split,
+the tree will classify the input as class 5, as 842 of the 3650 datapoints
+in the training data that ended up here were of class 5."
+
+# Example output:
+```
+Feature 3, Threshold -28.156052806422238
+L-> Feature 2, Threshold -161.04351901384842
+    L-> 5 : 842/3650
+    R-> 7 : 2493/10555
+R-> Feature 7, Threshold 108.1408338577021
+    L-> 2 : 2434/15287
+    R-> 8 : 1227/3508
+```
+
+To facilitate visualisation of trees using third party packages, a `DecisionTree.Leaf` object or
+`DecisionTree.Node` object can be wrapped to obtain a tree structure implementing the
+AbstractTrees.jl interface. See [`wrap`](@ref) for details.
+"""
 function print_tree(tree::Node, depth=-1, indent=0; feature_names=nothing)
     if depth == indent
         println()
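As a quick companion to the docstring added above, the following sketch builds a tree with the package's native `build_tree` and prints it via `print_tree` with the new `feature_names` keyword. The toy data, purity threshold, and feature names are assumptions made purely for illustration:

```julia
using DecisionTree

# Toy training data (illustrative assumption): 4 numeric features, 3 string classes.
features = rand(100, 4)
labels = rand(["setosa", "versicolor", "virginica"], 100)

# Build and prune a tree, then print it, limiting the displayed depth to 5.
model = build_tree(labels, features)
model = prune_tree(model, 0.9)   # merge leaves whose combined purity is at least 90%
print_tree(model, 5; feature_names = ["sepal length", "sepal width",
                                      "petal length", "petal width"])
```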

src/abstract_trees.jl

Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
+"""
+Implementation of the `AbstractTrees.jl`-interface
+(see: [AbstractTrees.jl](https://github.com/JuliaCollections/AbstractTrees.jl)).
+
+The functions `children` and `printnode` make up the interface traits of `AbstractTrees.jl`
+(see below for details).
+
+The goal of this implementation is to wrap a `DecisionTree` in this abstract layer,
+so that a plot recipe for visualization of the tree can be created that doesn't rely
+on any implementation details of `DecisionTree.jl`. That opens the possibility to create
+a plot recipe which can be used by a variety of tree-like models.
+
+For a more detailed explanation of this concept have a look at the following article
+in "Towards Data Science":
+["If things are not ready to use"](https://towardsdatascience.com/part-iii-if-things-are-not-ready-to-use-59d2db378bec)
+"""
+
+
+"""
+    InfoNode{S, T}
+    InfoLeaf{T}
+
+These types are introduced so that additional information currently not present in
+a `DecisionTree`-structure -- namely the feature names and the class labels --
+can be used for visualization. This additional information is stored in the attribute `info` of
+these types. It is a `NamedTuple`, so it can be used to store arbitrary information
+apart from the two points mentioned.
+
+In analogy to the type definitions of `DecisionTree`, the generic type `S` is
+the type of the feature values used within a node as a threshold for the splits
+between its children and `T` is the type of the classes given (these might be ids or labels).
+"""
+struct InfoNode{S, T}
+    node :: DecisionTree.Node{S, T}
+    info :: NamedTuple
+end
+
+struct InfoLeaf{T}
+    leaf :: DecisionTree.Leaf{T}
+    info :: NamedTuple
+end
+
+"""
+    wrap(node::DecisionTree.Node, info = NamedTuple())
+    wrap(leaf::DecisionTree.Leaf, info = NamedTuple())
+
+Add to each `node` (or `leaf`) the additional information `info`
+and wrap both in an `InfoNode`/`InfoLeaf`.
+
+Typically a `node` or a `leaf` is obtained by creating a decision tree using either
+the native interface of `DecisionTree.jl` or via other interfaces which are available
+for this package (like `MLJ` or `ScikitLearn`; see their docs for further details).
+Using the function `build_tree` of the native interface returns such an object.
+
+To use a `DecisionTree` `dc` (obtained this way) with the abstraction layer
+provided by the `AbstractTrees`-interface implemented here,
+and optionally add feature names `feature_names` and/or `class_labels`
+(both: arrays of strings), use the following syntax:
+
+1. `wdc = wrap(dc)`
+2. `wdc = wrap(dc, (featurenames = feature_names, classlabels = class_labels))`
+3. `wdc = wrap(dc, (featurenames = feature_names, ))`
+4. `wdc = wrap(dc, (classlabels = class_labels, ))`
+
+In the first case `dc` just gets wrapped and no information is added. No. 2 adds feature names
+as well as class labels. In the last two cases only one of these pieces of information is added
+(note the trailing comma; it is needed to make the argument a named tuple).
+"""
+wrap(node::DecisionTree.Node, info::NamedTuple = NamedTuple()) = InfoNode(node, info)
+wrap(leaf::DecisionTree.Leaf, info::NamedTuple = NamedTuple()) = InfoLeaf(leaf, info)
+
+"""
+    children(node::InfoNode)
+
+Return the children of the given `node`.
+
+In case of a `DecisionTree` there are always exactly two children, because
+the model produces binary trees where all nodes have exactly one left and
+one right child. `children` is used for tree traversal.
+
+The additional information `info` is carried over from `node` to its children.
+"""
+AbstractTrees.children(node::InfoNode) = (
+    wrap(node.node.left, node.info),
+    wrap(node.node.right, node.info)
+)
+AbstractTrees.children(node::InfoLeaf) = ()
+
+"""
+    printnode(io::IO, node::InfoNode)
+    printnode(io::IO, leaf::InfoLeaf)
+
+Write a printable representation of `node` or `leaf` to output-stream `io`.
+
+If `node.info`/`leaf.info` have a field called
+- `featurenames` it is expected to have an array of feature names corresponding
+  to the feature ids used in the `DecisionTree`s nodes.
+  They will be used for printing instead of the ids.
+- `classlabels` it is expected to have an array of class labels corresponding
+  to the class ids used in the `DecisionTree`s leaves.
+  They will be used for printing instead of the ids.
+  (Note: `DecisionTree`s created using MLJ use ids in their leaves;
+  otherwise class labels are present.)
+
+For the condition of the form `feature < value` which gets printed in the `printnode`
+variant for `InfoNode`, the left subtree is the 'yes-branch' and the right subtree
+accordingly the 'no-branch'. `AbstractTrees.print_tree` outputs the left subtree first
+and then below it the right subtree.
+"""
+function AbstractTrees.printnode(io::IO, node::InfoNode)
+    if :featurenames ∈ keys(node.info)
+        print(io, node.info.featurenames[node.node.featid], " < ", node.node.featval)
+    else
+        print(io, "Feature: ", node.node.featid, " < ", node.node.featval)
+    end
+end
+
+function AbstractTrees.printnode(io::IO, leaf::InfoLeaf)
+    dt_leaf = leaf.leaf
+    matches     = findall(dt_leaf.values .== dt_leaf.majority)
+    match_count = length(matches)
+    val_count   = length(dt_leaf.values)
+    if :classlabels ∈ keys(leaf.info)
+        print(io, leaf.info.classlabels[dt_leaf.majority], " ($match_count/$val_count)")
+    else
+        print(io, "Class: ", dt_leaf.majority, " ($match_count/$val_count)")
+    end
+end
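To round off the new file, here is a hedged sketch of how the `wrap` layer might be used together with `AbstractTrees.jl`. The toy data and feature names are assumptions for illustration only; class labels are omitted because a tree built natively on string labels already stores them in its leaves:

```julia
using DecisionTree
import AbstractTrees

# Toy model (illustrative assumption): 2 numeric features, 2 string classes.
features = rand(50, 2)
labels = rand(["yes", "no"], 50)
model = build_tree(labels, features)

# Wrap the native tree and attach feature names for nicer display.
wrapped = wrap(model, (featurenames = ["height", "weight"],))

# AbstractTrees.jl's generic print_tree traverses the wrapped structure,
# calling the `children` and `printnode` methods defined in this file.
AbstractTrees.print_tree(wrapped)
```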

test/classification/iris.jl

Lines changed: 1 addition & 1 deletion
@@ -98,6 +98,6 @@ println("\n##### nfoldCV Classification Adaboosted Stumps #####")
 n_iterations = 15
 nfolds = 3
 accuracy = nfoldCV_stumps(labels, features, nfolds, n_iterations)
-@test mean(accuracy) > 0.9
+@test mean(accuracy) > 0.85
 
 end # @testset
