Skip to content

Commit 2c6ed9d

Browse files
Merge pull request #16 from flatironinstitute/fix-ess-imse-update
Fix ess imse update
2 parents f845628 + ea6ca6e commit 2c6ed9d

File tree

8 files changed

+382
-137
lines changed

8 files changed

+382
-137
lines changed

bayes_kit/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
# functions
2+
from .autocorr import autocorr
3+
from .ess import ess, ess_imse, ess_ipse
4+
from .iat import iat, iat_imse, iat_ipse
5+
from .rhat import rhat
6+
7+
# classes
8+
from .ensemble import Stretcher
19
from .hmc import HMCDiag
210
from .mala import MALA
311
from .metropolis import Metropolis, MetropolisHastings

bayes_kit/autocorr.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import numpy as np
2+
import numpy.typing as npt
3+
4+
FloatType = np.float64
5+
VectorType = npt.NDArray[FloatType]
6+
7+
def autocorr(chain: VectorType) -> VectorType:
    """Estimate the autocorrelation of a sequence at every available lag.

    The result holds one entry per lag from 0 through ``len(chain) - 1``,
    so it has the same length as the input.

    The estimate is computed with NumPy's FFT: the centered sequence is
    zero-padded to a power-of-two length of at least ``2 * len(chain) - 1``,
    its squared spectral magnitude is transformed back, and the result is
    scaled by the sample variance and the number of draws.

    Parameters:
        chain: sequence whose autocorrelation is returned

    Returns:
        autocorrelation estimates at lags 0, 1, ..., ``len(chain) - 1``

    Raises:
        ValueError: if the size of the chain is less than 2
    """
    num_draws = len(chain)
    if num_draws < 2:
        raise ValueError(f"autocorr requires len(chain) >= 2, but {len(chain)=}")
    centered = chain - np.mean(chain)
    # A constant chain has zero variance, so the division below yields NaN.
    variance = np.var(chain)
    # Smallest power of two that is >= 2 * num_draws - 1.
    padded_size = 1 << (2 * num_draws - 2).bit_length()
    spectrum = np.fft.fft(centered, padded_size)
    power = np.square(np.abs(spectrum))
    scaled = np.fft.ifft(power).real / variance / num_draws
    # Drop the entries that only exist because of the zero padding.
    return scaled[:num_draws]

bayes_kit/ess.py

Lines changed: 26 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -1,100 +1,31 @@
11
import numpy as np
22
import numpy.typing as npt
3+
import bayes_kit.autocorr as autocorr
4+
from bayes_kit.iat import iat, iat_ipse, iat_imse
35

46
FloatType = np.float64
57
IntType = np.int64
68
VectorType = npt.NDArray[FloatType]
79

8-
def autocorr_fft(chain: VectorType) -> VectorType:
9-
"""
10-
Return sample autocorrelations at all lags for the specified sequence.
11-
Algorithmically, this function calls a fast Fourier transform (FFT).
12-
13-
Parameters:
14-
chain: sequence whose autocorrelation is returned
15-
16-
Returns:
17-
autocorrelation estimates at all lags for the specified sequence
18-
"""
19-
size = 2 ** np.ceil(np.log2(2 * len(chain) - 1)).astype("int")
20-
var = np.var(chain)
21-
ndata = chain - np.mean(chain)
22-
fft = np.fft.fft(ndata, size)
23-
pwr = np.abs(fft) ** 2
24-
N = len(ndata)
25-
acorr = np.fft.ifft(pwr).real / var / N
26-
return acorr
27-
28-
def autocorr_np(chain: VectorType) -> VectorType:
29-
"""
30-
Return sample autocorrelations at all lags for the specified sequence.
31-
Algorithmically, this function delegates to the Numpy `correlation()` function.
32-
33-
Parameters:
34-
chain: sequence whose autocorrelation is returned
35-
36-
Returns:
37-
autocorrelation estimates at all lags for the specified sequence
38-
"""
39-
chain_ctr = chain - np.mean(chain)
40-
N = len(chain_ctr)
41-
acorrN = np.correlate(chain_ctr, chain_ctr, "full")[N - 1 :]
42-
return acorrN / N
43-
44-
def autocorr(chain: VectorType) -> VectorType:
45-
"""
46-
Return sample autocorrelations at all lags for the specified sequence.
47-
Algorithmically, this function delegates to `autocorr_fft`.
48-
49-
Parameters:
50-
chain: sequence whose autocorrelation is returned
51-
52-
Returns:
53-
autocorrelation estimates at all lags for the specified sequence
54-
"""
55-
return autocorr_fft(chain)
56-
57-
def first_neg_pair_start(chain: VectorType) -> IntType:
58-
"""
59-
Return the index of first element of the sequence whose sum with the following
60-
element is negative, or the length of the sequence if there is no such element.
61-
62-
Parameters:
63-
chain: input sequence
64-
65-
Return:
66-
index of first element whose sum with following element is negative, or
67-
the number of elements if there is no such element
68-
"""
69-
N = len(chain)
70-
n = 0
71-
while n + 1 < N:
72-
if chain[n] + chain[n + 1] < 0:
73-
return n
74-
n = n + 2
75-
return N
7610

7711
def ess_ipse(chain: VectorType) -> FloatType:
7812
"""
7913
Return an estimate of the effective sample size (ESS) of the specified Markov chain
8014
using the initial positive sequence estimator (IPSE).
8115
8216
Parameters:
83-
chain: Markov chain whose ESS is returned
17+
chain: Markov chain whose ESS is returned
8418
8519
Return:
86-
estimated effective sample size for the specified Markov chain
20+
estimated effective sample size for the specified Markov chain
8721
88-
Throws:
89-
ValueError: if there are fewer than 4 elements in the chain
22+
Raises:
23+
ValueError: if there are fewer than 4 elements in the chain
9024
"""
9125
if len(chain) < 4:
92-
raise ValueError(f"ess requires len(chains) >=4, but {len(chain) = }")
93-
acor = autocorr(chain)
94-
n = first_neg_pair_start(acor)
95-
sigma_sq_hat = acor[0] + 2 * sum(acor[1:n])
96-
ess = len(chain) / sigma_sq_hat
97-
return ess
26+
raise ValueError(f"ess_ipse(chain) requires len(chain) >= 4, but {len(chain)=}")
27+
return len(chain) / iat_ipse(chain)
28+
9829

9930
def ess_imse(chain: VectorType) -> FloatType:
10031
"""
@@ -106,33 +37,23 @@ def ess_imse(chain: VectorType) -> FloatType:
10637
This estimator was introduced in the following paper.
10738
10839
Geyer, C.J., 1992. Practical Markov chain Monte Carlo. Statistical Science
109-
7(4):473--483.
110-
40+
7(4):473--483.
41+
11142
Parameters:
112-
chain: Markov chain whose ESS is returned
43+
chain: Markov chain whose ESS is returned
11344
11445
Return:
115-
estimated effective sample size for the specified Markov chain
46+
estimated effective sample size for the specified Markov chain
11647
11748
Throws:
118-
ValueError: if there are fewer than 4 elements in the chain
49+
ValueError: if there are fewer than 4 elements in the chain
11950
"""
12051
if len(chain) < 4:
121-
raise ValueError(f"ess requires len(chains) >=4, but {len(chain) = }")
122-
acor = autocorr(chain)
123-
n = first_neg_pair_start(acor)
124-
prev_min = acor[1] + acor[2]
125-
# convex minorization uses slow loop
126-
accum = prev_min
127-
i = 3
128-
while i + 1 < n:
129-
minprev = min(prev_min, acor[i] + acor[i + 1])
130-
accum = accum + minprev
131-
i = i + 2
132-
# end diff code
133-
sigma_sq_hat = acor[0] + 2 * accum
134-
ess = len(chain) / sigma_sq_hat
135-
return ess
52+
raise ValueError(
53+
f"ess_imse(chain) requires len(chain) >=4, but {len(chain) = }"
54+
)
55+
return len(chain) / iat_imse(chain)
56+
13657

13758
def ess(chain: VectorType) -> FloatType:
13859
"""
@@ -141,14 +62,14 @@ def ess(chain: VectorType) -> FloatType:
14162
to `ess_imse()`.
14263
14364
Parameters:
144-
chain: Markov chains whose ESS is returned
65+
chain: Markov chains whose ESS is returned
14566
14667
Return:
147-
estimated effective sample size for the specified Markov chain
68+
estimated effective sample size for the specified Markov chain
14869
14970
Throws:
150-
ValueError: if there are fewer than 4 elements in the chain
151-
"""
152-
return ess_imse(chain)
153-
154-
71+
ValueError: if there are fewer than 4 elements in the chain
72+
"""
73+
if len(chain) < 4:
74+
raise ValueError(f"ess(chain) requires len(chain) >=4, but {len(chain) = }")
75+
return len(chain) / iat(chain)

bayes_kit/iat.py

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
import numpy as np
2+
import numpy.typing as npt
3+
import bayes_kit.autocorr as autocorr
4+
5+
FloatType = np.float64
6+
IntType = np.int64
7+
VectorType = npt.NDArray[FloatType]
8+
9+
10+
def _end_pos_pairs(acor: VectorType) -> IntType:
11+
"""
12+
Return the index 1 past the last positive pair of autocorrelations
13+
starting on an even index. The sequence `acor` should contain
14+
autocorrelations from a Markov chain with values at the lag given by
15+
the index (i.e., `acor[0]` is autocorrelation at lag 0 and `acor[5]`
16+
is autocorrelation at lag 5).
17+
18+
The even index pairs are (0, 1), (2, 3), (4, 5), ... This function
19+
scans the pairs in order, and returns 1 plus the second index of the
20+
last such pair that has a positive sum.
21+
22+
Examples:
23+
```python
24+
_end_pos_pairs([]) = 0
25+
_end_pos_pairs([1]) = 0
26+
_end_pos_pairs([1, 0.4]) = 2
27+
_end_pos_pairs([1, -0.4]) = 2
28+
_end_pos_pairs([1, -0.5, 0.25, -0.3]) == 2
29+
_end_pos_pairs([1, -0.5, 0.25, -0.1]) == 4
30+
_end_pos_pairs([1, -0.5, 0.25, -0.3, 0.05]) == 2
31+
_end_pos_pairs([1, -0.5, 0.25, -0.1, 0.05]) == 4
32+
```
33+
34+
Parameters:
35+
acor (VectorType): Input sequence of autocorrelations at lag given by index.
36+
37+
Returns:
38+
The index 1 past the last positive pair of values starting on an even index.
39+
"""
40+
N = len(acor)
41+
n = 0
42+
while n + 1 < N:
43+
if acor[n] + acor[n + 1] < 0:
44+
return n
45+
n += 2
46+
return n
47+
48+
49+
def iat_ipse(chain: VectorType) -> FloatType:
    """
    Return an estimate of the integrated autocorrelation time (IAT)
    of the specified Markov chain using the initial positive sequence
    estimator (IPSE).

    The integrated autocorrelation time of a chain is defined to be
    the sum of the autocorrelations at every lag (positive and negative).
    If `autocorr[n]` is the autocorrelation at lag `n`, then

    ```
    IAT = SUM_{n in Z} autocorr[n],
    ```

    where `Z = {..., -2, -1, 0, 1, 2, ...}` is the set of integers.

    Because the autocorrelations are symmetric, `autocorr[n] == autocorr[-n]`,
    and `autocorr[0] == 1`, doubling the sum over the non-negative lags
    counts `autocorr[0]` twice, so 1 is subtracted to get

    ```
    IAT = -1 + 2 * SUM_{n in Nat} autocorr[n],
    ```

    where `Nat = {0, 1, 2, ...}` is the set of natural numbers.

    The estimator truncates the sum at the index returned by
    `_end_pos_pairs`, i.e. just before the first even-aligned pair of
    autocorrelations whose sum is negative.

    References:
        Geyer, Charles J. 2011. "Introduction to Markov Chain Monte Carlo."
        In Handbook of Markov Chain Monte Carlo, edited by Steve Brooks,
        Andrew Gelman, Galin L. Jones, and Xiao-Li Meng, 3-48. Chapman;
        Hall/CRC.

    Parameters:
        chain: A Markov chain.

    Returns:
        An estimate of the integrated autocorrelation time (IAT) for the
        specified chain.

    Raises:
        ValueError: if there are fewer than 4 elements in the chain
    """
    if len(chain) < 4:
        # The message names this function and its actual parameter.
        raise ValueError(
            f"iat_ipse(chain) requires len(chain) >= 4, but {len(chain)=}"
        )
    # NOTE(review): this file binds `autocorr` with
    # `import bayes_kit.autocorr as autocorr`; the call presumably resolves to
    # the function only because the package __init__ re-exports it -- confirm.
    acor = autocorr(chain)
    n = _end_pos_pairs(acor)
    return 2 * acor[0:n].sum() - 1
95+
96+
97+
def iat_imse(chain: VectorType) -> FloatType:
    """
    Return an estimate of the integrated autocorrelation time (IAT)
    of the specified Markov chain using the initial monotone sequence
    estimator (IMSE).

    The IMSE imposes a monotonic downward condition on the sums of
    even-aligned pairs of autocorrelations: each pair sum is replaced by
    the minimum of itself and all earlier pair sums.  Pairs are used up to
    the index returned by `_end_pos_pairs`; the first pair (lags 0 and 1)
    is always included.

    References:
        Geyer, C.J., 1992. Practical Markov chain Monte Carlo. Statistical
        Science 7(4):473--483.

        Geyer, Charles J. 2011. "Introduction to Markov Chain Monte Carlo."
        In Handbook of Markov Chain Monte Carlo, edited by Steve Brooks,
        Andrew Gelman, Galin L. Jones, and Xiao-Li Meng, 3-48. Chapman;
        Hall/CRC.

    Parameters:
        chain: A Markov chain.

    Returns:
        An estimate of integrated autocorrelation time (IAT) for the
        specified chain.

    Raises:
        ValueError: If there are fewer than 4 elements in the chain.
    """
    if len(chain) < 4:
        # The message names this function and its actual parameter.
        raise ValueError(
            f"iat_imse(chain) requires len(chain) >= 4, but {len(chain)=}"
        )
    # NOTE(review): this file binds `autocorr` with
    # `import bayes_kit.autocorr as autocorr`; the call presumably resolves to
    # the function only because the package __init__ re-exports it -- confirm.
    acor = autocorr(chain)
    # The pair at lags (0, 1) is always summed, even if the scan stops at 0.
    end = max(_end_pos_pairs(acor), 2)
    pair_sums = acor[0:end:2] + acor[1:end:2]
    # Running minimum enforces the monotone-downward condition without a
    # Python-level loop.
    acor_sum = np.minimum.accumulate(pair_sums).sum()
    return 2 * acor_sum - 1
138+
139+
140+
def iat(chain: VectorType) -> FloatType:
    """
    Estimate the integrated autocorrelation time (IAT) of a Markov chain.

    This is the default estimator; it forwards the chain unchanged to the
    initial monotone sequence estimator, `iat_imse(chain)`, and returns
    that result.

    The IAT can be less than one in cases where the Markov chain is
    anti-correlated.

    Parameters:
        chain: A Markov chain.

    Returns:
        The integrated autocorrelation time (IAT) estimate for the
        specified chain.

    Raises:
        ValueError: If there are fewer than 4 elements in the chain.
    """
    estimate = iat_imse(chain)
    return estimate

0 commit comments

Comments
 (0)