Skip to content

Commit 208a7e0

Browse files
committed
Update tests
1 parent 00968d0 commit 208a7e0

File tree

3 files changed

+84
-44
lines changed

3 files changed

+84
-44
lines changed

examples/DimensionReduction/calibrate_linlinexp.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ end
1818
function EnsembleKalmanProcesses.calculate_timestep!(ekp, g, Δt_new, scheduler::CheckpointScheduler)
1919
EnsembleKalmanProcesses.calculate_timestep!(ekp, g, Δt_new, scheduler.scheduler)
2020
if scheduler.current_index <= length(scheduler.αs) && get_algorithm_time(ekp)[end] > scheduler.αs[scheduler.current_index]
21-
get_algorithm_time(ekp)[end] = scheduler.αs[scheduler.current_index]
21+
ekp.Δt[end] -= get_algorithm_time(ekp)[end] - scheduler.αs[scheduler.current_index]
2222
scheduler.current_index += 1
2323
end
2424

@@ -28,9 +28,9 @@ function get_algorithm_time(ekp::EnsembleKalmanProcess) # This is not defined in
2828
return accumulate(+, ekp.Δt)
2929
end
3030

31-
rng = Random.MersenneTwister(123)
31+
rng = Random.MersenneTwister(41)
3232
input_dim = 100
33-
output_dim = 50
33+
output_dim = 100
3434
αs = [0.0, 0.25, 0.5, 0.75, 1.0]
3535

3636
num_trials = 1
@@ -42,7 +42,7 @@ for trial in 1:num_trials
4242
Parameterized(MvNormal(zeros(size(prior_cov, 1)), prior_cov)), fill(no_constraint(), size(prior_cov, 1)), "linlinexp_prior",
4343
)
4444

45-
n_ensemble = 200
45+
n_ensemble = 400
4646
initial_ensemble = construct_initial_ensemble(rng, prior_obj, n_ensemble)
4747
ekp = EnsembleKalmanProcess(
4848
initial_ensemble,
@@ -64,6 +64,8 @@ for trial in 1:num_trials
6464
end
6565
@info "EKP iterations: $n_iters"
6666
@info "Loss over iterations: $(get_error(ekp))"
67+
@info "Timesteps: $(ekp.Δt)"
68+
@info "Checkpoints: $(cumsum(ekp.Δt))"
6769

6870
ekp_samples = Dict()
6971
for α in αs

examples/DimensionReduction/emulate_sample_linlinexp.jl

Lines changed: 77 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ using EnsembleKalmanProcesses
33
using JLD2
44
using LinearAlgebra
55
using Random
6+
using MCMCChains
7+
using AdvancedMH
68

79
# CES
810
using CalibrateEmulateSample.Emulators
@@ -11,9 +13,29 @@ using CalibrateEmulateSample.Utilities
1113

1214
include("./models.jl")
1315

14-
rng = Random.MersenneTwister(123)
16+
17+
mutable struct NoEmulation <: Emulators.MachineLearningTool
18+
f
19+
output_structure_mats
20+
21+
NoEmulation(f) = new(f, nothing)
22+
end
23+
function Emulators.build_models!(ne::NoEmulation, iopairs, input_structure_mats, output_structure_mats; mlt_kwargs...)
24+
ne.output_structure_mats = output_structure_mats
25+
end
26+
function Emulators.predict(ne::NoEmulation, new_inputs; mlt_kwargs...)
27+
encoder_schedule = mlt_kwargs[:encoder_schedule]
28+
29+
decoded_inputs = decode_with_schedule(encoder_schedule, EnsembleKalmanProcesses.DataContainer(new_inputs), "in")
30+
decoded_outputs = hcat(map(ne.f, eachcol(decoded_inputs.data))...)
31+
encoded_outputs = encode_with_schedule(encoder_schedule, EnsembleKalmanProcesses.DataContainer(decoded_outputs), "out").data
32+
encoded_outputs, cat([Utilities.get_structure_mat(ne.output_structure_mats) for _ in eachcol(new_inputs)]...; dims = 3)
33+
end
34+
35+
36+
rng = Random.MersenneTwister(41)
1537
input_dim = 100
16-
output_dim = 50
38+
output_dim = 100
1739
αs = [0.0, 0.25, 0.5, 0.75, 1.0]
1840

1941
num_trials = 1
@@ -25,53 +47,69 @@ for trial in 1:num_trials
2547
obs_noise_cov = loaded["obs_noise_cov"]
2648
y = loaded["y"]
2749
model = loaded["model"]
50+
true_parameter = loaded["true_parameter"]
2851

2952

3053
min_iter = 1
31-
max_iter = 7 # number of EKP iterations to use data from is at most this
32-
33-
encoder_schedule_decorrelate = [(decorrelate_structure_mat(; retain_var = 0.7), "in_and_out")]
34-
encoder_schedules_li = [
35-
[(likelihood_informed(; retain_KL = 0.9, alpha = α, use_data_as_samples = false), "in_and_out")] for α in αs
36-
]
37-
38-
em_decorrelate = Emulator(
39-
GaussianProcess(Emulators.GPJL(); kernel = nothing, prediction_type = Emulators.YType(), noise_learn = false),
40-
Utilities.get_training_points(ekpobj, min_iter:max_iter);
41-
encoder_schedule = encoder_schedule_decorrelate,
42-
encoder_kwargs = (; prior_cov = cov(prior_obj), obs_noise_cov = obs_noise_cov),
43-
)
44-
45-
ems_li = [
46-
Emulator(
47-
GaussianProcess(Emulators.GPJL(); kernel = nothing, prediction_type = Emulators.YType(), noise_learn = false),
54+
max_iter = 10 # number of EKP iterations to use data from is at most this
55+
56+
encoder_schedule_ref = [(decorrelate_structure_mat(; retain_var = 1.0), "in_and_out")]
57+
58+
dims = [2, 4, 6, 8, 10]
59+
all_errs = zeros(length(dims), 1 + length(αs))
60+
for (dim_i, dim) in enumerate(dims)
61+
encoder_schedule_decorrelate = [(Decorrelator([], [], [], (:dimension, dim), "structure_mat", nothing), "in"), (decorrelate_structure_mat(; retain_var = 1.0), "out")]
62+
encoder_schedules_li = [
63+
[(decorrelate_structure_mat(; retain_var = 1.0), "in_and_out"), (LikelihoodInformed(nothing, nothing, nothing, (:dimension, dim), α, :linreg, false), "in")] for α in αs
64+
]
65+
66+
em_ref = Emulator(
67+
NoEmulation(param -> forward_map(param, model)),
68+
Utilities.get_training_points(ekpobj, min_iter:max_iter);
69+
encoder_schedule = encoder_schedule_ref,
70+
encoder_kwargs = (; prior_cov = cov(prior_obj), obs_noise_cov = obs_noise_cov),
71+
)
72+
73+
em_decorrelate = Emulator(
74+
NoEmulation(param -> forward_map(param, model)),
4875
Utilities.get_training_points(ekpobj, min_iter:max_iter);
49-
encoder_schedule = encoder_schedule,
50-
encoder_kwargs = (; prior_cov = cov(prior_obj), obs_noise_cov = obs_noise_cov, samples_in = ekp_samp[α][1], samples_out = ekp_samp[α][2], observation = y),
76+
encoder_schedule = encoder_schedule_decorrelate,
77+
encoder_kwargs = (; prior_cov = cov(prior_obj), obs_noise_cov = obs_noise_cov),
5178
)
52-
for (encoder_schedule, α) in zip(encoder_schedules_li, αs)
53-
]
5479

55-
post_means = zeros(input_dim, 0)
56-
for em in vcat(em_decorrelate, ems_li...)
57-
u0 = rand(MvNormal(mean(prior_obj), cov(prior_obj)))
58-
mcmc = MCMCWrapper(RWMHSampling(), y, prior_obj, em; init_params = u0)
59-
new_step = optimize_stepsize(mcmc; init_stepsize = 0.1, N = 2000, discard_initial = 0)
80+
ems_li = [
81+
Emulator(
82+
NoEmulation(param -> forward_map(param, model)),
83+
Utilities.get_training_points(ekpobj, min_iter:max_iter);
84+
encoder_schedule = encoder_schedule,
85+
encoder_kwargs = (; prior_cov = cov(prior_obj), obs_noise_cov = obs_noise_cov, samples_in = ekp_samp[α][1], samples_out = ekp_samp[α][2], observation = y),
86+
)
87+
for (encoder_schedule, α) in zip(encoder_schedules_li, αs)
88+
]
89+
90+
post_means = reshape(true_parameter, input_dim, 1)
91+
post_covs = []
92+
for em in vcat(em_ref, em_decorrelate, ems_li...)
93+
u0 = rand(MvNormal(mean(prior_obj), cov(prior_obj)))
94+
mcmc = MCMCWrapper(RWMHSampling(), y, prior_obj, em; init_params = u0)
95+
new_step = optimize_stepsize(mcmc; init_stepsize = 0.05, N = 2000, discard_initial = 0)
96+
97+
println("Begin MCMC - with step size ", new_step)
98+
num_chains = 16
99+
mcmc = MCMCWrapper(RWMHSampling(), y, prior_obj, em; init_params = [u0 for _ in 1:num_chains])
100+
chain = MarkovChainMonteCarlo.sample(mcmc, MCMCThreads(), 40_000, num_chains; chain_type = Chains, stepsize = new_step, discard_initial = 5_000)
60101

61-
println("Begin MCMC - with step size ", new_step)
62-
chain = MarkovChainMonteCarlo.sample(mcmc, 10_000; stepsize = new_step, discard_initial = 2_000)
102+
posterior = MarkovChainMonteCarlo.get_posterior(mcmc, chain)
63103

64-
posterior = MarkovChainMonteCarlo.get_posterior(mcmc, chain)
104+
post_mean = mean(posterior)
105+
post_cov = cov(posterior)
65106

66-
post_mean = mean(posterior)
67-
post_cov = cov(posterior)
68-
println("post_mean")
69-
println(post_mean)
70-
println("post_cov")
71-
println(post_cov)
107+
post_means = hcat(post_means, reshape(post_mean, input_dim, 1))
108+
push!(post_covs, post_cov)
109+
end
72110

73-
post_means = hcat(post_means, reshape(post_mean, input_dim, 1))
111+
all_errs[dim_i, :] = [norm(post_means[:,2] - v)/norm(post_means[:,2]) for v in eachcol(post_means[:,3:end])]'
74112
end
75113

76-
println(post_means[1:10,:])
114+
display(all_errs)
77115
end

examples/DimensionReduction/models.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ function linlinexp(input_dim, output_dim, rng)
2626
obs_noise_cov = Diagonal([Float64(j)^(-1 / 2) for j in 1:output_dim])
2727
noise = rand(rng, MvNormal(zeros(output_dim), obs_noise_cov))
2828
# true_parameter = reshape(ones(input_dim), :, 1)
29-
true_parameter = rand(MvNormal(zeros(input_dim), Γ))
29+
true_parameter = rand(rng, MvNormal(zeros(input_dim), Γ))
3030
y = vec(forward_map(true_parameter, model) + noise)
3131
return Γ, y, obs_noise_cov, model, true_parameter
3232
end

0 commit comments

Comments
 (0)