|
3 | 3 | # ------------------------------------------------------------------ |
4 | 4 |
|
5 | 5 | """ |
6 | | - Learn(train, model, features => targets) |
| 6 | + Learn(table; [model]) |
7 | 7 |
|
8 | | -Fits the statistical learning `model` to `train` table, |
9 | | -using the selectors of `features` and `targets`. |
| 8 | +Perform supervised learning with labeled `table` and |
| 9 | +statistical learning `model`. |
| 10 | +
|
| 11 | +Uses `KNNClassifier(1)` or `KNNRegressor(1)` model by |
| 12 | +default depending on the scientific type of the labels |
| 13 | +stored in the table. |
10 | 14 |
|
11 | 15 | ## Examples |
12 | 16 |
|
13 | 17 | ```julia |
14 | | -Learn(train, model, [1, 2, 3] => "d") |
15 | | -Learn(train, model, [:a, :b, :c] => :d) |
16 | | -Learn(train, model, ["a", "b", "c"] => 4) |
17 | | -Learn(train, model, [1, 2, 3] => [:d, :e]) |
18 | | -Learn(train, model, r"[abc]" => ["d", "e"]) |
| 18 | +Learn(label(table, "y")) |
| 19 | +Learn(label(table, ["y1", "y2"])) |
| 20 | +Learn(label(table, 3), model=KNNClassifier(5)) |
19 | 21 | ``` |
| 22 | +
|
| 23 | +See also [`label`](@ref). |
20 | 24 | """ |
21 | | -struct Learn{M<:FittedStatsLearnModel} <: StatelessFeatureTransform |
| 25 | +struct Learn{T<:LabeledTable,M} <: StatelessFeatureTransform |
| 26 | + table::T |
22 | 27 | model::M |
23 | | - feats::Vector{Symbol} |
24 | 28 | end |
25 | 29 |
|
26 | | -Learn(train, model, (feats, targs)::Pair) = Learn(train, StatsLearnModel(model, feats, targs)) |
27 | | - |
28 | | -function Learn(train, lmodel::StatsLearnModel) |
29 | | - if !Tables.istable(train) |
30 | | - throw(ArgumentError("training data must be a table")) |
31 | | - end |
32 | | - |
33 | | - cols = Tables.columns(train) |
34 | | - names = Tables.columnnames(cols) |
35 | | - feats = lmodel.feats(names) |
36 | | - targs = lmodel.targs(names) |
| 30 | +Learn(table::LabeledTable; model=_defaultmodel(table)) = Learn(table, model) |
37 | 31 |
|
38 | | - input = (; (var => Tables.getcolumn(cols, var) for var in feats)...) |
39 | | - output = (; (var => Tables.getcolumn(cols, var) for var in targs)...) |
| 32 | +function applyfeat(transform::Learn, feat, prep) |
| 33 | + # labeled table and model |
| 34 | + table = transform.table |
| 35 | + model = transform.model |
40 | 36 |
|
41 | | - fmodel = fit(lmodel.model, input, output) |
| 37 | + # variables in labeled table |
| 38 | + cols = Tables.columns(table) |
| 39 | + vars = Tables.columnnames(cols) |
42 | 40 |
|
43 | | - Learn(fmodel, feats) |
44 | | -end |
| 41 | + # split targets and predictors |
| 42 | + targs = table.labels |
| 43 | + preds = setdiff(vars, targs) |
45 | 44 |
|
46 | | -isrevertible(::Type{<:Learn}) = false |
| 45 | + # learn function with statistical model |
| 46 | + input = (; (pred => Tables.getcolumn(cols, pred) for pred in preds)...) |
| 47 | + output = (; (targ => Tables.getcolumn(cols, targ) for targ in targs)...) |
| 48 | + fmodel = fit(model, input, output) |
47 | 49 |
|
48 | | -function applyfeat(transform::Learn, feat, prep) |
49 | | - model = transform.model |
50 | | - vars = transform.feats |
| 50 | + # predict labels with new predictors |
| 51 | + fcols = Tables.columns(feat) |
| 52 | + fvars = Tables.columnnames(fcols) |
| 53 | + preds ⊆ fvars || throw(ArgumentError("predictors $preds not found in input table")) |
| 54 | + finput = (; (pred => Tables.getcolumn(fcols, pred) for pred in preds)...) |
| 55 | + foutput = predict(fmodel, finput) |> Tables.materializer(feat) |
51 | 56 |
|
52 | | - cols = Tables.columns(feat) |
53 | | - pairs = (var => Tables.getcolumn(cols, var) for var in vars) |
54 | | - test = (; pairs...) |> Tables.materializer(feat) |
| 57 | + foutput, nothing |
| 58 | +end |
55 | 59 |
|
56 | | - predict(model, test), nothing |
| 60 | +function _defaultmodel(table::LabeledTable) |
| 61 | + cols = Tables.columns(table) |
| 62 | + vals = Tables.getcolumn(cols, only(table.labels)) |
| 63 | + type = elscitype(vals) |
| 64 | + if type <: Categorical |
| 65 | + KNNClassifier(1) |
| 66 | + elseif type <: Continuous |
| 67 | + KNNRegressor(1) |
| 68 | + else |
| 69 | + throw(ErrorException("no default learning model for $type labels")) |
| 70 | + end |
57 | 71 | end |
0 commit comments