Skip to content

Commit 6a203ef

Browse files
committed
Implement accessor methods for LabeledTable
1 parent 324a5c5 commit 6a203ef

File tree

3 files changed

+58
-10
lines changed

3 files changed

+58
-10
lines changed

src/StatsLearnModels.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ include("learn.jl")
3434
export
3535
# labeled table
3636
LabeledTable,
37+
predictors,
38+
targets,
3739
label,
3840

3941
# NearestNeighbors.jl

src/labeledtable.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ function LabeledTable(table, names)
2323
LabeledTable{typeof(table)}(table, labs)
2424
end
2525

26+
# -----------------
27+
# TABLES INTERFACE
28+
# -----------------
29+
2630
Tables.istable(::Type{<:LabeledTable}) = true
2731

2832
Tables.rowaccess(::Type{<:LabeledTable{T}}) where {T} = Tables.rowaccess(T)
@@ -35,6 +39,20 @@ Tables.columns(t::LabeledTable) = Tables.columns(t.table)
3539

3640
Tables.columnnames(t::LabeledTable) = Tables.columnnames(t.table)
3741

42+
# -----------------
43+
# ACCESSOR METHODS
44+
# -----------------
45+
46+
Base.parent(t::LabeledTable) = t.table
47+
48+
function predictors(t::LabeledTable)
49+
cols = Tables.columns(t.table)
50+
vars = Tables.columnnames(cols)
51+
setdiff(vars, t.labels)
52+
end
53+
54+
targets(t::LabeledTable) = t.labels
55+
3856
# -----------
3957
# IO METHODS
4058
# -----------

test/runtests.jl

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,34 @@ using Distributions: Binomial
1010
const SLM = StatsLearnModels
1111

1212
@testset "StatsLearnModels.jl" begin
13-
@testset "Basic" begin
14-
# show method
15-
x1 = rand(1:0.1:10, 100)
16-
x2 = rand(1:0.1:10, 100)
17-
y = 2x1 + x2
18-
input = DataFrame(; x1, x2)
19-
output = DataFrame(; y)
20-
model = DecisionTreeClassifier()
21-
fmodel = SLM.fit(model, input, output)
22-
@test sprint(show, fmodel) == "FittedStatsLearnModel{DecisionTreeClassifier}"
13+
@testset "LabeledTable" begin
14+
# labels as symbols
15+
t = (x1=rand(3), x2=rand(3), y=rand(Int, 3))
16+
l = label(t, :y)
17+
@test parent(l) == t
18+
@test predictors(l) == [:x1, :x2]
19+
@test targets(l) == [:y]
20+
21+
# labels as strings
22+
t = (x1=rand(3), x2=rand(3), y=rand(Int, 3))
23+
l = label(t, "y")
24+
@test parent(l) == t
25+
@test predictors(l) == [:x1, :x2]
26+
@test targets(l) == [:y]
27+
28+
# multiple labels
29+
t = (x1=rand(3), x2=rand(3), y1=rand(Int, 3), y2=rand(Int, 3))
30+
l = label(t, ["y1", "y2"])
31+
@test parent(l) == t
32+
@test predictors(l) == [:x1, :x2]
33+
@test targets(l) == [:y1, :y2]
34+
35+
# labels as regex
36+
t = (x1=rand(3), x2=rand(3), y1=rand(Int, 3), y2=rand(Int, 3))
37+
l = label(t, r"y")
38+
@test parent(l) == t
39+
@test predictors(l) == [:x1, :x2]
40+
@test targets(l) == [:y1, :y2]
2341
end
2442

2543
@testset "Models" begin
@@ -77,6 +95,16 @@ const SLM = StatsLearnModels
7795
accuracy = count(foutput.y .== output.y) / length(output.y)
7896
@test accuracy > 0
7997
end
98+
99+
# show method
100+
x1 = rand(1:0.1:10, 100)
101+
x2 = rand(1:0.1:10, 100)
102+
y = 2x1 + x2
103+
input = DataFrame(; x1, x2)
104+
output = DataFrame(; y)
105+
model = DecisionTreeClassifier()
106+
fmodel = SLM.fit(model, input, output)
107+
@test sprint(show, fmodel) == "FittedStatsLearnModel{DecisionTreeClassifier}"
80108
end
81109

82110
@testset "Learn" begin

0 commit comments

Comments
 (0)