1- from typing import Optional , Tuple , Union
1+ from typing import Iterator , Optional , Union
22from numpy .typing import NDArray
33import numpy as np
44
55from .model_types import GradModel
66
7+ Sample = tuple [NDArray [np .float64 ], float ]
8+
79
810class HMCDiag :
911 def __init__ (
@@ -13,8 +15,7 @@ def __init__(
1315 steps : int ,
1416 metric_diag : Optional [NDArray [np .float64 ]] = None ,
1517 init : Optional [NDArray [np .float64 ]] = None ,
16- seed : Union [None , int , np .random .BitGenerator , np .random .Generator ] = None
17-
18+ seed : Union [None , int , np .random .BitGenerator , np .random .Generator ] = None ,
1819 ):
1920 self ._model = model
2021 self ._dim = self ._model .dims ()
@@ -24,20 +25,19 @@ def __init__(
2425 self ._rand = np .random .default_rng (seed )
2526 self ._theta = init or self ._rand .normal (size = self ._dim )
2627
27- def __iter__ (self ):
28+ def __iter__ (self ) -> Iterator [ Sample ] :
2829 return self
2930
30- def __next__ (self ):
31+ def __next__ (self ) -> Sample :
3132 return self .sample ()
3233
3334 def joint_logp (self , theta : NDArray [np .float64 ], rho : NDArray [np .float64 ]) -> float :
34- return self ._model .log_density (theta ) - 0.5 * np .dot (
35- rho , np .multiply (self ._metric , rho )
36- )
35+ adj : float = 0.5 * np .dot (rho , self ._metric * rho )
36+ return self ._model .log_density (theta ) - adj
3737
3838 def leapfrog (
3939 self , theta : NDArray [np .float64 ], rho : NDArray [np .float64 ]
40- ) -> Tuple [NDArray [np .float64 ], NDArray [np .float64 ]]:
40+ ) -> tuple [NDArray [np .float64 ], NDArray [np .float64 ]]:
4141 # TODO(bob-carpenter): refactor to share non-initial and non-final updates
4242 for n in range (self ._steps ):
4343 lp , grad = self ._model .log_density_gradient (theta )
@@ -47,7 +47,7 @@ def leapfrog(
4747 rho = rho_mid + 0.5 * self ._stepsize * np .multiply (self ._metric , grad )
4848 return (theta , rho )
4949
50- def sample (self ) -> Tuple [ NDArray [ np . float64 ], float ] :
50+ def sample (self ) -> Sample :
5151 rho = self ._rand .normal (size = self ._dim )
5252 logp = self .joint_logp (self ._theta , rho )
5353 theta_prop , rho_prop = self .leapfrog (self ._theta , rho )
0 commit comments