Skip to content

Commit a8e4f60

Browse files
Merge pull request #9 from flatironinstitute/fix/more-typing
Correct type annotations in all but rhat/ess
2 parents b821fcd + 9177a0b commit a8e4f60

File tree

10 files changed

+37
-27
lines changed

10 files changed

+37
-27
lines changed

bayes_kit/hmc.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1-
from typing import Optional, Tuple, Union
1+
from typing import Iterator, Optional, Union
22
from numpy.typing import NDArray
33
import numpy as np
44

55
from .model_types import GradModel
66

7+
Sample = tuple[NDArray[np.float64], float]
8+
79

810
class HMCDiag:
911
def __init__(
@@ -13,8 +15,7 @@ def __init__(
1315
steps: int,
1416
metric_diag: Optional[NDArray[np.float64]] = None,
1517
init: Optional[NDArray[np.float64]] = None,
16-
seed: Union[None, int, np.random.BitGenerator, np.random.Generator] = None
17-
18+
seed: Union[None, int, np.random.BitGenerator, np.random.Generator] = None,
1819
):
1920
self._model = model
2021
self._dim = self._model.dims()
@@ -24,20 +25,19 @@ def __init__(
2425
self._rand = np.random.default_rng(seed)
2526
self._theta = init or self._rand.normal(size=self._dim)
2627

27-
def __iter__(self):
28+
def __iter__(self) -> Iterator[Sample]:
2829
return self
2930

30-
def __next__(self):
31+
def __next__(self) -> Sample:
3132
return self.sample()
3233

3334
def joint_logp(self, theta: NDArray[np.float64], rho: NDArray[np.float64]) -> float:
34-
return self._model.log_density(theta) - 0.5 * np.dot(
35-
rho, np.multiply(self._metric, rho)
36-
)
35+
adj: float = 0.5 * np.dot(rho, self._metric * rho)
36+
return self._model.log_density(theta) - adj
3737

3838
def leapfrog(
3939
self, theta: NDArray[np.float64], rho: NDArray[np.float64]
40-
) -> Tuple[NDArray[np.float64], NDArray[np.float64]]:
40+
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
4141
# TODO(bob-carpenter): refactor to share non-initial and non-final updates
4242
for n in range(self._steps):
4343
lp, grad = self._model.log_density_gradient(theta)
@@ -47,7 +47,7 @@ def leapfrog(
4747
rho = rho_mid + 0.5 * self._stepsize * np.multiply(self._metric, grad)
4848
return (theta, rho)
4949

50-
def sample(self) -> Tuple[NDArray[np.float64], float]:
50+
def sample(self) -> Sample:
5151
rho = self._rand.normal(size=self._dim)
5252
logp = self.joint_logp(self._theta, rho)
5353
theta_prop, rho_prop = self.leapfrog(self._theta, rho)

bayes_kit/mala.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def __init__(
1313
model: GradModel,
1414
epsilon: float,
1515
init: Optional[NDArray[np.float64]] = None,
16-
seed: Union[None, int, np.random.BitGenerator, np.random.Generator] = None
16+
seed: Union[None, int, np.random.BitGenerator, np.random.Generator] = None,
1717
):
1818
self._model = model
1919
self._epsilon = epsilon
@@ -50,6 +50,12 @@ def sample(self) -> Sample:
5050

5151
return self._theta, self._log_p_theta
5252

53-
def correction(self, theta_prime: NDArray[np.float64], theta: NDArray[np.float64], grad_theta: NDArray[np.float64]) -> float:
53+
def correction(
54+
self,
55+
theta_prime: NDArray[np.float64],
56+
theta: NDArray[np.float64],
57+
grad_theta: NDArray[np.float64],
58+
) -> float:
5459
x = theta_prime - theta - self._epsilon * grad_theta
55-
return (-0.25 / self._epsilon) * x.dot(x)
60+
dot_self: float = x.dot(x)
61+
return (-0.25 / self._epsilon) * dot_self

bayes_kit/rhat.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def rhat(chains: list[SeqType]) -> FloatType:
1616
R-hat was introduced in this paper.
1717
1818
Gelman, A. and Rubin, D. B., 1992. Inference from iterative simulation using
19-
multiple sequences. Statistical Science, 457--472.
19+
multiple sequences. Statistical Science, 457--472.
2020
2121
Parameters:
2222
chains: list of univariate Markov chains
@@ -33,7 +33,7 @@ def rhat(chains: list[SeqType]) -> FloatType:
3333
mean_chain_length = np.mean(chain_lengths)
3434
means = [np.mean(chain) for chain in chains]
3535
vars = [np.var(chain, ddof=1) for chain in chains]
36-
r_hat = np.sqrt(
36+
r_hat: np.float64 = np.sqrt(
3737
(mean_chain_length - 1) / mean_chain_length + np.var(means, ddof=1) / np.mean(vars)
3838
)
3939
return r_hat

bayes_kit/rwm.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1-
from typing import Callable, Optional, Tuple, Union
1+
from typing import Callable, Iterator, Optional, Union
22
from numpy.typing import NDArray, ArrayLike
33
import numpy as np
44

55
from .model_types import LogDensityModel
66

7+
Sample = tuple[NDArray[np.float64], float]
8+
79

810
class RandomWalkMetropolis:
911
def __init__(
@@ -20,13 +22,13 @@ def __init__(
2022
self._theta = init or self._rand.normal(size=self._dim)
2123
self._log_p_theta = self._model.log_density(self._theta)
2224

23-
def __iter__(self):
25+
def __iter__(self) -> Iterator[Sample]:
2426
return self
2527

26-
def __next__(self):
28+
def __next__(self) -> Sample:
2729
return self.sample()
2830

29-
def sample(self) -> Tuple[NDArray[np.float64], float]:
31+
def sample(self) -> Sample:
3032
# does not include initial value as first draw
3133
theta_star = np.asanyarray(self._proposal_rng(self._theta))
3234
log_p_theta_star = self._model.log_density(theta_star)

bayes_kit/smc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,10 @@ def time(self, n: int) -> float:
3939
return n / self.N
4040

4141
def transition(self, n: int) -> None:
42-
def lpminus1(theta):
42+
def lpminus1(theta: Vector) -> float:
4343
return self.log_likelihood(theta) * self.time(n - 1) + self.log_prior(theta)
4444

45-
def lp(theta):
45+
def lp(theta: Vector) -> float:
4646
return self.log_likelihood(theta) * self.time(n) + self.log_prior(theta)
4747

4848
# note: try to do this in parallel

test/models/beta_binomial.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@ def log_density(self, theta: npt.NDArray[np.float64]) -> float:
2626
return self.log_likelihood(theta) + self.log_prior(theta)
2727

2828
def log_prior(self, theta: npt.NDArray[np.float64]) -> float:
29-
return stats.beta.logpdf(theta[0], self.alpha, self.beta)
29+
return stats.beta.logpdf(theta[0], self.alpha, self.beta) # type: ignore # scipy is not typed
3030

3131
def log_likelihood(self, theta: npt.NDArray[np.float64]) -> float:
32-
return stats.binom.logpmf(self.x, self.N, theta[0])
32+
return stats.binom.logpmf(self.x, self.N, theta[0]) # type: ignore # scipy is not typed
3333

3434
def initial_state(self, _: int) -> npt.NDArray[np.float64]:
3535
return self._rand.beta(self.alpha, self.beta, size=1)

test/models/std_normal.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ def log_density_gradient(
1515
) -> tuple[float, npt.NDArray[np.float64]]:
1616
return -0.5 * params_unc[0] * params_unc[0], -params_unc
1717

18-
def posterior_mean(self):
18+
def posterior_mean(self) -> float:
1919
return 0
2020

21-
def posterior_variance(self):
21+
def posterior_variance(self) -> float:
2222
return 1

test/test_hmc.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def test_hmc_diag_std_normal() -> None:
1818
np.testing.assert_allclose(mean, model.posterior_mean(), atol=0.1)
1919
np.testing.assert_allclose(var, model.posterior_variance(), atol=0.1)
2020

21+
2122
def test_hmc_diag_repr() -> None:
2223
init = np.random.normal(loc=0, scale=1, size=[1])
2324
model = StdNormal()

test/test_rwm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from bayes_kit.rwm import RandomWalkMetropolis
33
import numpy as np
44

5+
56
def test_rwm_std_normal() -> None:
67
# init with draw from posterior
78
init = np.random.normal(loc=0, scale=1, size=[1])
@@ -15,7 +16,7 @@ def test_rwm_std_normal() -> None:
1516
np.testing.assert_allclose(mean, model.posterior_mean(), atol=0.1)
1617
np.testing.assert_allclose(var, model.posterior_variance(), atol=0.1)
1718

18-
accept = M - (draws[:M-1] == draws[1:]).sum()
19+
accept = M - (draws[: M - 1] == draws[1:]).sum()
1920
print(f"{accept=}")
2021
print(f"{draws[1:10]=}")
2122
print(f"{mean=} {var=}")

test/test_tempered_smc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44

55

6-
def test_rwm_smc_beta_binom():
6+
def test_rwm_smc_beta_binom() -> None:
77
model = BetaBinom()
88
M = 75
99
N = 10

0 commit comments

Comments
 (0)