Skip to content

Commit 00968d0

Browse files
committed
Fix likelihood-informed bugs
1 parent 143ce90 commit 00968d0

File tree

1 file changed

+11
-8
lines changed

1 file changed

+11
-8
lines changed

src/Utilities/likelihood_informed.jl

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ function initialize_processor!(
3535
output_structure_vectors::Dict{Symbol, <:StructureVector},
3636
apply_to::AbstractString,
3737
) where {MM <: AbstractMatrix}
38+
input_dim = size(in_data, 1)
3839
output_dim = size(out_data, 1)
3940

4041
if isnothing(get_encoder_mat(li))
@@ -69,9 +70,10 @@ function initialize_processor!(
6970
# This can be a scalar or a matrix; in the latter case, we can even use the covariance
7071
# of the samples (or the prior covariance).
7172
weights = exp.(-1/2 * norm.(eachcol(u .- samples_in)).^2)
73+
weights ./= sum(weights)
7274
D = Diagonal(sqrt.(weights))
73-
uw = (samples_in .- mean(samples_in * Diagonal(weights); dims = 2)) * D
74-
gw = (samples_out .- mean(samples_out * Diagonal(weights); dims = 2)) * D
75+
uw = (samples_in .- sum(samples_in * Diagonal(weights); dims = 2)) * D
76+
gw = (samples_out .- sum(samples_out * Diagonal(weights); dims = 2)) * D
7577
gw / uw
7678
end
7779
end
@@ -84,16 +86,17 @@ function initialize_processor!(
8486
eigen(hermitianpart(mean(grad * grad' for grad in grads)), obs_noise_cov, sortby = (-))
8587
end
8688

87-
if li.dim_criterion[1] == :retain_KL
89+
sv_cumsum = cumsum(decomp.values) / sum(decomp.values)
90+
trunc_val = if li.dim_criterion[1] == :retain_KL
8891
retain_KL = li.dim_criterion[2]
89-
sv_cumsum = cumsum(decomp.values) / sum(decomp.values)
9092
trunc_val = findfirst(x -> (x retain_KL), sv_cumsum)
93+
isnothing(trunc_val) ? (apply_to == "in" ? input_dim : output_dim) : trunc_val
9194
else
9295
@assert li.dim_criterion[1] == :dimension
93-
trunc_val = li.dim_criterion[2]
96+
li.dim_criterion[2]
9497
end
9598
@info " truncating at $trunc_val/$(length(sv_cumsum)) retaining $(100.0*sv_cumsum[trunc_val])% of the KL divergence reduction"
96-
li.encoder_mat = decomp.vectors[:, 1:trunc_val]'
99+
decomp.vectors[:, 1:trunc_val]'
97100
else
98101
@assert apply_to == "out"
99102
@warn "Using LikelihoodInformed on output data with α≠0 triggers a manifold optimization process that may take some time."
@@ -111,7 +114,7 @@ function initialize_processor!(
111114
f = (_, Vs) -> begin
112115
prec = noise_cov_inv - Vs * inv(Vs' * obs_noise_cov * Vs) * Vs'
113116
tr(mean(
114-
grad' * prec * ((1-α)I + α^2 * (y - g)*(y - g)') * prec * grad
117+
grad' * prec * ((1-α)I + α^2 * (y - g)*(y - g)') * prec * grad # TODO: Γ instead of I?
115118
for (g, grad) in zip(eachcol(out_data), grads)
116119
))
117120
end
@@ -140,7 +143,7 @@ function initialize_processor!(
140143
@info " truncating at $k/$output_dim retaining $(100.0*(1-val/ref))% of the KL divergence reduction"
141144
break # TODO: Start bisecting?
142145
else
143-
k *= 2
146+
k = min(2k, output_dim)
144147
end
145148
else
146149
@assert li.dim_criterion[1] == :dimension

0 commit comments

Comments
 (0)