Skip to content

Commit e7b0c14

Browse files
committed
Refactor tests
1 parent e453eda commit e7b0c14

File tree

1 file changed

+23
-31
lines changed

1 file changed

+23
-31
lines changed

test/runtests.jl

Lines changed: 23 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -17,24 +17,21 @@ const SLM = StatsLearnModels
1717
y = 2x1 + x2
1818
input = DataFrame(; x1, x2)
1919
output = DataFrame(; y)
20-
train, test = 1:70, 71:100
2120
model = DecisionTreeClassifier()
22-
fmodel = SLM.fit(model, input[train, :], output[train, :])
21+
fmodel = SLM.fit(model, input, output)
2322
@test sprint(show, fmodel) == "FittedStatsLearnModel{DecisionTreeClassifier}"
2423
end
2524

2625
@testset "Models" begin
2726
@testset "NearestNeighbors" begin
2827
Random.seed!(123)
29-
iris = DataFrame(MLJ.load_iris())
30-
input = iris[:, Not(:target)]
31-
output = iris[:, [:target]]
32-
train, test = MLJ.partition(1:nrow(input), 0.7, rng=123)
28+
input = (x1=rand(100), x2=rand(100))
29+
output = (y=rand(1:3, 100),)
3330
model = KNNClassifier(5)
34-
fmodel = SLM.fit(model, input[train, :], output[train, :])
35-
pred = SLM.predict(fmodel, input[test, :])
36-
accuracy = count(pred.target .== output.target[test]) / length(test)
37-
@test accuracy > 0.9
31+
fmodel = SLM.fit(model, input, output)
32+
foutput = SLM.predict(fmodel, input)
33+
accuracy = count(foutput.y .== output.y) / length(output.y)
34+
@test accuracy > 0
3835

3936
Random.seed!(123)
4037
x1 = rand(1:0.1:10, 100)
@@ -58,43 +55,38 @@ const SLM = StatsLearnModels
5855
output = DataFrame(; y)
5956
model = LinearRegressor()
6057
fmodel = SLM.fit(model, input, output)
61-
pred = SLM.predict(fmodel, input)
62-
@test all(isapprox.(pred.y, output.y, atol=0.5))
58+
foutput = SLM.predict(fmodel, input)
59+
@test all(isapprox.(foutput.y, output.y, atol=0.5))
6360
x = [1, 2, 2]
6461
y = [1, 0, 1]
6562
input = DataFrame(; ones=ones(length(x)), x)
6663
output = DataFrame(; y)
6764
model = GeneralizedLinearRegressor(Binomial(), ProbitLink())
6865
fmodel = SLM.fit(model, input, output)
69-
pred = SLM.predict(fmodel, input)
70-
@test all(isapprox.(pred.y, output.y, atol=0.5))
66+
foutput = SLM.predict(fmodel, input)
67+
@test all(isapprox.(foutput.y, output.y, atol=0.5))
7168
end
7269

7370
@testset "DecisionTree" begin
7471
Random.seed!(123)
75-
iris = DataFrame(MLJ.load_iris())
76-
input = iris[:, Not(:target)]
77-
output = iris[:, [:target]]
78-
train, test = MLJ.partition(1:nrow(input), 0.7, rng=123)
72+
input = (x1=rand(100), x2=rand(100))
73+
output = (y=rand(1:3, 100),)
7974
model = DecisionTreeClassifier()
80-
fmodel = SLM.fit(model, input[train, :], output[train, :])
81-
pred = SLM.predict(fmodel, input[test, :])
82-
accuracy = count(pred.target .== output.target[test]) / length(test)
83-
@test accuracy > 0.9
75+
fmodel = SLM.fit(model, input, output)
76+
foutput = SLM.predict(fmodel, input)
77+
accuracy = count(foutput.y .== output.y) / length(output.y)
78+
@test accuracy > 0
8479
end
8580
end
8681

8782
@testset "Learn" begin
8883
Random.seed!(123)
89-
iris = DataFrame(MLJ.load_iris())
90-
input = iris[:, Not(:target)]
91-
output = iris[:, [:target]]
92-
train, test = MLJ.partition(1:nrow(input), 0.7, rng=123)
84+
train = (x1=rand(100), x2=rand(100), y=rand(1:3, 100))
9385
model = DecisionTreeClassifier()
94-
transform = Learn(label(iris[train, :], :target); model)
95-
@test !isrevertible(transform)
96-
pred = transform(iris[test, :])
97-
accuracy = count(pred.target .== iris.target[test]) / length(test)
98-
@test accuracy > 0.9
86+
learn = Learn(label(train, :y); model)
87+
@test !isrevertible(learn)
88+
preds = learn(train)
89+
accuracy = count(preds.y .== train.y) / length(train.y)
90+
@test accuracy > 0
9991
end
10092
end

0 commit comments

Comments
 (0)