Skip to content

Commit 1dd1fef

Browse files
committed
Add LabeledTable
1 parent 5a7e7ab commit 1dd1fef

File tree

2 files changed

+36
-0
lines changed

2 files changed

+36
-0
lines changed

src/StatsLearnModels.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,18 @@ import GLM
2222
import DecisionTree as DT
2323
import NearestNeighbors as NN
2424

25+
include("labeledtable.jl")
2526
include("interface.jl")
2627
include("models/nn.jl")
2728
include("models/glm.jl")
2829
include("models/tree.jl")
2930
include("learn.jl")
3031

3132
export
33+
# labeled table
34+
LabeledTable,
35+
label,
36+
3237
# NearestNeighbors.jl
3338
KNNClassifier,
3439
KNNRegressor,

src/labeledtable.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# ------------------------------------------------------------------
2+
# Licensed under the MIT License. See LICENSE in the project root.
3+
# ------------------------------------------------------------------
4+
5+
"""
6+
LabeledTable(table, names)
7+
8+
Stores a Tables.jl `table` along with column `names` that
9+
identify which columns are labels for supervised learning.
10+
"""
11+
struct LabeledTable{T}
12+
table::T
13+
labels::Vector{Symbol}
14+
end
15+
16+
function LabeledTable(table, names)
17+
Tables.istable(table) || throw(ArgumentError("please provide a valid Tables.jl table"))
18+
cols = Tables.columns(table)
19+
vars = Tables.columnnames(cols)
20+
labs = selector(names)(vars)
21+
labs vars || throw(ArgumentError("all labels must be column names in the table"))
22+
vars labs && throw(ArgumentError("there must be at least one feature column in the table"))
23+
LabeledTable{typeof(table)}(table, labs)
24+
end
25+
26+
"""
27+
label(table, names)
28+
29+
Creates a `LabeledTable` from `table` using `names` as label columns.
30+
"""
31+
label(table, names) = LabeledTable(table, names)

0 commit comments

Comments
 (0)