Skip to content

Commit 0715d4b

Browse files
committed
Refactor Learn transform
1 parent 0ef2255 commit 0715d4b

File tree

4 files changed

+58
-93
lines changed

4 files changed

+58
-93
lines changed

src/interface.jl

Lines changed: 9 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2,47 +2,21 @@
22
# Licensed under the MIT License. See LICENSE in the project root.
33
# ------------------------------------------------------------------
44

5-
"""
6-
StatsLearnModel(model, features, targets)
7-
8-
Wrap a statistical learning `model` with selectors
9-
of `features` and `targets`.
10-
11-
## Examples
12-
13-
```julia
14-
StatsLearnModel(DecisionTreeClassifier(), ["x1","x2"], "y")
15-
StatsLearnModel(DecisionTreeClassifier(), 1:3, "target")
16-
```
17-
"""
18-
struct StatsLearnModel{M,F<:ColumnSelector,T<:ColumnSelector}
19-
model::M
20-
feats::F
21-
targs::T
22-
end
23-
24-
StatsLearnModel(model, feats, targs) = StatsLearnModel(model, selector(feats), selector(targs))
25-
265
"""
276
fit(model, input, output)
287
29-
Fit statistical learning `model` using features in `input` table
30-
and targets in `output` table. Returns a fitted model with all
31-
the necessary information for prediction with the `predict` function.
8+
Fit statistical learning `model` using predictors
9+
in `input` table and targets in `output` table.
10+
Returns a fitted model with all the necessary
11+
information for prediction with [`predict`](@ref).
3212
"""
3313
function fit end
3414

35-
function Base.show(io::IO, model::StatsLearnModel{M}) where {M}
36-
println(io, "StatsLearnModel{$(nameof(M))}")
37-
println(io, "├─ features: $(model.feats)")
38-
print(io, "└─ targets: $(model.targs)")
39-
end
40-
4115
"""
4216
FittedStatsLearnModel(model, cache)
4317
44-
Wrap the statistical learning `model` with the `cache`
45-
produced during the [`fit`](@ref) stage.
18+
Wrap the statistical learning `model` with the
19+
`cache` produced during the [`fit`](@ref) stage.
4620
"""
4721
struct FittedStatsLearnModel{M,C}
4822
model::M
@@ -53,7 +27,9 @@ end
5327
predict(model::FittedStatsLearnModel, table)
5428
5529
Predict targets using the fitted statistical
56-
learning `model` and a new `table` of features.
30+
learning `model` and a new `table` containing
31+
the same predictors used during the [`fit`](@ref)
32+
stage.
5733
"""
5834
function predict end
5935

src/learn.jl

Lines changed: 48 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3,55 +3,69 @@
33
# ------------------------------------------------------------------
44

55
"""
6-
Learn(train, model, features => targets)
6+
Learn(table; [model])
77
8-
Fits the statistical learning `model` to `train` table,
9-
using the selectors of `features` and `targets`.
8+
Perform supervised learning with labeled `table` and
9+
statistical learning `model`.
10+
11+
Uses `KNNClassifier(1)` or `KNNRegressor(1)` model by
12+
default depending on the scientific type of the labels
13+
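stored in the table. The default is only available when the table has a single label column; otherwise pass `model` explicitly.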
1014
1115
## Examples
1216
1317
```julia
14-
Learn(train, model, [1, 2, 3] => "d")
15-
Learn(train, model, [:a, :b, :c] => :d)
16-
Learn(train, model, ["a", "b", "c"] => 4)
17-
Learn(train, model, [1, 2, 3] => [:d, :e])
18-
Learn(train, model, r"[abc]" => ["d", "e"])
18+
Learn(label(table, "y"))
19+
Learn(label(table, ["y1", "y2"]), model=model)
20+
Learn(label(table, 3), model=KNNClassifier(5))
1921
```
22+
23+
See also [`label`](@ref).
2024
"""
21-
struct Learn{M<:FittedStatsLearnModel} <: StatelessFeatureTransform
25+
struct Learn{T<:LabeledTable,M} <: StatelessFeatureTransform
26+
table::T
2227
model::M
23-
feats::Vector{Symbol}
2428
end
2529

26-
Learn(train, model, (feats, targs)::Pair) = Learn(train, StatsLearnModel(model, feats, targs))
27-
28-
function Learn(train, lmodel::StatsLearnModel)
29-
if !Tables.istable(train)
30-
throw(ArgumentError("training data must be a table"))
31-
end
32-
33-
cols = Tables.columns(train)
34-
names = Tables.columnnames(cols)
35-
feats = lmodel.feats(names)
36-
targs = lmodel.targs(names)
30+
Learn(table::LabeledTable; model=_defaultmodel(table)) = Learn(table, model)
3731

38-
input = (; (var => Tables.getcolumn(cols, var) for var in feats)...)
39-
output = (; (var => Tables.getcolumn(cols, var) for var in targs)...)
32+
function applyfeat(transform::Learn, feat, prep)
33+
# labeled table and model
34+
table = transform.table
35+
model = transform.model
4036

41-
fmodel = fit(lmodel.model, input, output)
37+
# variables in labeled table
38+
cols = Tables.columns(table)
39+
vars = Tables.columnnames(cols)
4240

43-
Learn(fmodel, feats)
44-
end
41+
# split targets and predictors
42+
targs = table.labels
43+
preds = setdiff(vars, targs)
4544

46-
isrevertible(::Type{<:Learn}) = false
45+
# learn function with statistical model
46+
input = (; (pred => Tables.getcolumn(cols, pred) for pred in preds)...)
47+
output = (; (targ => Tables.getcolumn(cols, targ) for targ in targs)...)
48+
fmodel = fit(model, input, output)
4749

48-
function applyfeat(transform::Learn, feat, prep)
49-
model = transform.model
50-
vars = transform.feats
50+
# predict labels with new predictors
51+
fcols = Tables.columns(feat)
52+
fvars = Tables.columnnames(fcols)
53+
preds ⊆ fvars || throw(ArgumentError("predictors $preds not found in input table"))
54+
finput = (; (pred => Tables.getcolumn(fcols, pred) for pred in preds)...)
55+
foutput = predict(fmodel, finput) |> Tables.materializer(feat)
5156

52-
cols = Tables.columns(feat)
53-
pairs = (var => Tables.getcolumn(cols, var) for var in vars)
54-
test = (; pairs...) |> Tables.materializer(feat)
57+
foutput, nothing
58+
end
5559

56-
predict(model, test), nothing
60+
function _defaultmodel(table::LabeledTable)
61+
cols = Tables.columns(table)
62+
vals = Tables.getcolumn(cols, only(table.labels))
63+
type = elscitype(vals)
64+
if type <: Categorical
65+
KNNClassifier(1)
66+
elseif type <: Continuous
67+
KNNRegressor(1)
68+
else
69+
throw(ErrorException("no default learning model for $type labels"))
70+
end
5771
end

test/Project.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
[deps]
2-
ColumnSelectors = "9cc86067-7e36-4c61-b350-1ac9833d277f"
32
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
43
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
54
GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a"

test/runtests.jl

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ using Test
66

77
using GLM: ProbitLink
88
using Distributions: Binomial
9-
using ColumnSelectors: selector
109

1110
import MLJ, MLJDecisionTreeInterface
1211

@@ -24,22 +23,6 @@ const SLM = StatsLearnModels
2423
model = DecisionTreeClassifier()
2524
fmodel = SLM.fit(model, input[train, :], output[train, :])
2625
@test sprint(show, fmodel) == "FittedStatsLearnModel{DecisionTreeClassifier}"
27-
28-
# show method
29-
lmodel = SLM.StatsLearnModel(DecisionTreeClassifier(), [:a, :b], :c)
30-
@test sprint(show, lmodel) == """
31-
StatsLearnModel{DecisionTreeClassifier}
32-
├─ features: [:a, :b]
33-
└─ targets: :c"""
34-
35-
# accessor functions
36-
model = DecisionTreeClassifier()
37-
feats = selector([:a, :b])
38-
targs = selector(:c)
39-
lmodel = SLM.StatsLearnModel(model, feats, targs)
40-
@test lmodel.model === model
41-
@test lmodel.feats === feats
42-
@test lmodel.targs === targs
4326
end
4427

4528
@testset "Models" begin
@@ -109,19 +92,12 @@ const SLM = StatsLearnModels
10992
input = iris[:, Not(:target)]
11093
output = iris[:, [:target]]
11194
train, test = MLJ.partition(1:nrow(input), 0.7, rng=123)
112-
outvar = :target
113-
feats = setdiff(propertynames(iris), [outvar])
114-
targs = outvar
11595
model = DecisionTreeClassifier()
116-
transform = Learn(iris[train, :], model, feats => targs)
96+
transform = Learn(label(iris[train, :], :target); model)
11797
@test !isrevertible(transform)
11898
pred = transform(iris[test, :])
11999
accuracy = count(pred.target .== iris.target[test]) / length(test)
120100
@test accuracy > 0.9
121-
122-
# throws
123-
# training data is not a table
124-
@test_throws ArgumentError Learn(nothing, model, feats => targs)
125101
end
126102

127103
@testset "MLJ" begin

0 commit comments

Comments
 (0)