@@ -3,6 +3,8 @@ using EnsembleKalmanProcesses
33using JLD2
44using LinearAlgebra
55using Random
6+ using MCMCChains
7+ using AdvancedMH
68
79# CES
810using CalibrateEmulateSample. Emulators
@@ -11,9 +13,29 @@ using CalibrateEmulateSample.Utilities
1113
1214include (" ./models.jl" )
1315
14- rng = Random. MersenneTwister (123 )
16+
# Pass-through "emulator": instead of fitting an ML model to the training
# pairs, it stores the forward map `f` and evaluates it exactly at predict
# time. Serves as an exact reference to compare real emulators against.
mutable struct NoEmulation <: Emulators.MachineLearningTool
    f                     # forward map: takes one (decoded) parameter column, returns one output column
    output_structure_mats # recorded by `Emulators.build_models!`; `nothing` until then

    NoEmulation(f) = new(f, nothing)
end
"""
    Emulators.build_models!(ne::NoEmulation, iopairs, input_structure_mats, output_structure_mats; mlt_kwargs...)

No model is trained. Only the output structure matrices are recorded on `ne`
so that `Emulators.predict` can attach an output covariance structure later.
The training pairs and input structure matrices are intentionally ignored.
"""
function Emulators.build_models!(
    ne::NoEmulation,
    iopairs,
    input_structure_mats,
    output_structure_mats;
    mlt_kwargs...,
)
    # Field assignment written via `setproperty!`; identical to
    # `ne.output_structure_mats = output_structure_mats` (same return value).
    return setproperty!(ne, :output_structure_mats, output_structure_mats)
end
"""
    Emulators.predict(ne::NoEmulation, new_inputs; mlt_kwargs...)

Evaluate the stored forward map exactly at the (encoded) input columns of
`new_inputs` and return the tuple `(encoded_outputs, output_structure)`:
the re-encoded model outputs, and one copy of the encoded output structure
matrix per input column, stacked along dimension 3.

Requires `mlt_kwargs[:encoder_schedule]` to be supplied by the caller.
"""
function Emulators.predict(ne::NoEmulation, new_inputs; mlt_kwargs...)
    # Required keyword: throws KeyError if the Emulator did not forward it.
    encoder_schedule = mlt_kwargs[:encoder_schedule]

    # Decode inputs back to parameter space, run the true model column by
    # column, then re-encode the outputs with the same schedule.
    decoded_inputs = decode_with_schedule(encoder_schedule, EnsembleKalmanProcesses.DataContainer(new_inputs), "in")
    # `reduce(hcat, ...)` uses the specialized concatenation path and avoids
    # splatting a potentially long vector of columns into varargs.
    decoded_outputs = reduce(hcat, map(ne.f, eachcol(decoded_inputs.data)))
    encoded_outputs = encode_with_schedule(encoder_schedule, EnsembleKalmanProcesses.DataContainer(decoded_outputs), "out").data

    # One structure matrix per prediction point, stacked along dims = 3.
    structure = cat([Utilities.get_structure_mat(ne.output_structure_mats) for _ in eachcol(new_inputs)]...; dims = 3)
    return encoded_outputs, structure
end
34+
35+
# NOTE(review): `rng` is constructed here but never passed to the `rand`/
# sampling calls further down, so those draws use the global RNG — the seed
# does not make the experiment reproducible. TODO confirm intent.
rng = Random.MersenneTwister(41)
input_dim = 100  # parameter-space dimension (posterior means are reshaped to this)
output_dim = 100 # observation/output-space dimension
# Interpolation weights α swept for the likelihood-informed encoder.
αs = [0.0, 0.25, 0.5, 0.75, 1.0]

num_trials = 1 # one saved calibration dataset is processed per trial
@@ -25,53 +47,69 @@ for trial in 1:num_trials
    # --- per-trial inputs; `loaded`, `ekpobj`, `prior_obj` and `ekp_samp`
    #     are defined earlier in this loop body (not visible in this chunk) ---
    obs_noise_cov = loaded["obs_noise_cov"]
    y = loaded["y"]
    model = loaded["model"]
    true_parameter = loaded["true_parameter"]


    min_iter = 1
    max_iter = 10 # number of EKP iterations to use data from is at most this

    # Reference schedule: full decorrelation (retain all variance), no
    # dimension reduction, applied to both inputs and outputs.
    encoder_schedule_ref = [(decorrelate_structure_mat(; retain_var = 1.0), "in_and_out")]

    # Sweep over the number of retained input dimensions.
    dims = [2, 4, 6, 8, 10]
    # Rows: one per retained dimension; columns: decorrelate baseline + one per α.
    all_errs = zeros(length(dims), 1 + length(αs))
    for (dim_i, dim) in enumerate(dims)
        # Input reduced to `dim` components via the structure matrix;
        # output fully decorrelated.
        encoder_schedule_decorrelate = [(Decorrelator([], [], [], (:dimension, dim), "structure_mat", nothing), "in"), (decorrelate_structure_mat(; retain_var = 1.0), "out")]
        # Likelihood-informed input reduction to `dim`, one schedule per α.
        encoder_schedules_li = [
            [(decorrelate_structure_mat(; retain_var = 1.0), "in_and_out"), (LikelihoodInformed(nothing, nothing, nothing, (:dimension, dim), α, :linreg, false), "in")] for α in αs
        ]

        # Reference "emulator": evaluates the true forward map (NoEmulation)
        # with no dimension reduction — posterior from this is the baseline.
        em_ref = Emulator(
            NoEmulation(param -> forward_map(param, model)),
            Utilities.get_training_points(ekpobj, min_iter:max_iter);
            encoder_schedule = encoder_schedule_ref,
            encoder_kwargs = (; prior_cov = cov(prior_obj), obs_noise_cov = obs_noise_cov),
        )

        # Same exact forward map, but with structure-matrix input reduction.
        em_decorrelate = Emulator(
            NoEmulation(param -> forward_map(param, model)),
            Utilities.get_training_points(ekpobj, min_iter:max_iter);
            encoder_schedule = encoder_schedule_decorrelate,
            encoder_kwargs = (; prior_cov = cov(prior_obj), obs_noise_cov = obs_noise_cov),
        )

        # One emulator per α, each with its likelihood-informed schedule and
        # the matching EKP samples passed to the encoder.
        ems_li = [
            Emulator(
                NoEmulation(param -> forward_map(param, model)),
                Utilities.get_training_points(ekpobj, min_iter:max_iter);
                encoder_schedule = encoder_schedule,
                encoder_kwargs = (; prior_cov = cov(prior_obj), obs_noise_cov = obs_noise_cov, samples_in = ekp_samp[α][1], samples_out = ekp_samp[α][2], observation = y),
            )
            for (encoder_schedule, α) in zip(encoder_schedules_li, αs)
        ]

        # Column 1: true parameter; column 2: reference posterior mean;
        # columns 3:end: one posterior mean per candidate emulator.
        post_means = reshape(true_parameter, input_dim, 1)
        post_covs = [] # NOTE(review): collected below but never read afterwards
        for em in vcat(em_ref, em_decorrelate, ems_li...)
            # NOTE(review): this draw uses the global RNG, not the `rng`
            # seeded at the top of the file — confirm reproducibility intent.
            u0 = rand(MvNormal(mean(prior_obj), cov(prior_obj)))
            # Tune the RWMH step size on a short single-chain run first.
            mcmc = MCMCWrapper(RWMHSampling(), y, prior_obj, em; init_params = u0)
            new_step = optimize_stepsize(mcmc; init_stepsize = 0.05, N = 2000, discard_initial = 0)

            println("Begin MCMC - with step size ", new_step)
            # Production run: 16 threaded chains, all started from the same u0,
            # 40k samples each with the first 5k discarded.
            num_chains = 16
            mcmc = MCMCWrapper(RWMHSampling(), y, prior_obj, em; init_params = [u0 for _ in 1:num_chains])
            chain = MarkovChainMonteCarlo.sample(mcmc, MCMCThreads(), 40_000, num_chains; chain_type = Chains, stepsize = new_step, discard_initial = 5_000)

            posterior = MarkovChainMonteCarlo.get_posterior(mcmc, chain)

            post_mean = mean(posterior)
            post_cov = cov(posterior)

            post_means = hcat(post_means, reshape(post_mean, input_dim, 1))
            push!(post_covs, post_cov)
        end

        # Relative L2 error of each candidate posterior mean (columns 3:end)
        # against the reference posterior mean (column 2); the resulting row
        # has 1 + length(αs) entries, matching the width of `all_errs`.
        all_errs[dim_i, :] = [norm(post_means[:, 2] - v) / norm(post_means[:, 2]) for v in eachcol(post_means[:, 3:end])]'
    end

    display(all_errs)
end
0 commit comments