diff --git a/src/rydstate/__init__.py b/src/rydstate/__init__.py index 4088e9f..4db6a45 100644 --- a/src/rydstate/__init__.py +++ b/src/rydstate/__init__.py @@ -1,4 +1,5 @@ from rydstate import angular, radial, species +from rydstate.basis import BasisSQDTAlkali, BasisSQDTAlkalineLS from rydstate.rydberg import ( RydbergStateSQDT, RydbergStateSQDTAlkali, @@ -8,6 +9,8 @@ from rydstate.units import ureg __all__ = [ + "BasisSQDTAlkali", + "BasisSQDTAlkalineLS", "RydbergStateSQDT", "RydbergStateSQDTAlkali", "RydbergStateSQDTAlkalineJJ", diff --git a/src/rydstate/angular/angular_ket.py b/src/rydstate/angular/angular_ket.py index 62ceeba..58b92c1 100644 --- a/src/rydstate/angular/angular_ket.py +++ b/src/rydstate/angular/angular_ket.py @@ -2,15 +2,15 @@ import logging from abc import ABC -from typing import TYPE_CHECKING, Any, ClassVar, Literal, get_args, overload +from typing import TYPE_CHECKING, Any, ClassVar, Literal, overload from rydstate.angular.angular_matrix_element import ( - AngularMomentumQuantumNumbers, - AngularOperatorType, calc_prefactor_of_operator_in_coupled_scheme, calc_reduced_identity_matrix_element, calc_reduced_spherical_matrix_element, calc_reduced_spin_matrix_element, + is_angular_momentum_quantum_number, + is_angular_operator_type, ) from rydstate.angular.utils import ( calc_wigner_3j, @@ -26,6 +26,7 @@ if TYPE_CHECKING: from typing_extensions import Self + from rydstate.angular.angular_matrix_element import AngularMomentumQuantumNumbers, AngularOperatorType from rydstate.angular.angular_state import AngularState logger = logging.getLogger(__name__) @@ -186,6 +187,34 @@ def get_qn(self, qn: AngularMomentumQuantumNumbers) -> float: raise ValueError(f"Quantum number {qn} not found in {self!r}.") return getattr(self, qn) # type: ignore [no-any-return] + def calc_exp_qn(self, qn: AngularMomentumQuantumNumbers) -> float: + """Calculate the expectation value of a quantum number qn. + + If the quantum number is a good quantum number simply return it, + otherwise calculate it, see also AngularState.calc_exp_qn for more details. + + Args: + qn: The quantum number to calculate the expectation value for. + + """ + if qn in self.quantum_number_names: + return self.get_qn(qn) + return self.to_state().calc_exp_qn(qn) + + def calc_std_qn(self, qn: AngularMomentumQuantumNumbers) -> float: + """Calculate the standard deviation of a quantum number qn. + + If the quantum number is a good quantum number return 0, + otherwise calculate the std, see also AngularState.calc_std_qn for more details. + + Args: + qn: The quantum number to calculate the standard deviation for. + + """ + if qn in self.quantum_number_names: + return 0 + return self.to_state().calc_std_qn(qn) + @overload def to_state(self, coupling_scheme: Literal["LS"]) -> AngularState[AngularKetLS]: ... @@ -382,12 +411,12 @@ def calc_reduced_matrix_element( # noqa: C901 \left\langle self || \hat{O}^{(\kappa)} || other \right\rangle """ - if operator not in get_args(AngularOperatorType): + if not is_angular_operator_type(operator): raise NotImplementedError(f"calc_reduced_matrix_element is not implemented for operator {operator}.") if type(self) is not type(other): return self.to_state().calc_reduced_matrix_element(other.to_state(), operator, kappa) - if operator in get_args(AngularMomentumQuantumNumbers) and operator not in self.quantum_number_names: + if is_angular_momentum_quantum_number(operator) and operator not in self.quantum_number_names: return self.to_state().calc_reduced_matrix_element(other.to_state(), operator, kappa) qn_name: AngularMomentumQuantumNumbers diff --git a/src/rydstate/angular/angular_matrix_element.py b/src/rydstate/angular/angular_matrix_element.py index 1b69dc3..e07ac23 100644 --- a/src/rydstate/angular/angular_matrix_element.py +++ b/src/rydstate/angular/angular_matrix_element.py @@ -2,9 +2,10 @@ import math from functools import lru_cache -from typing import TYPE_CHECKING, Callable, Literal, TypeVar +from typing import TYPE_CHECKING, Callable, Literal, TypeVar, get_args import numpy as np +from typing_extensions import TypeGuard from rydstate.angular.utils import calc_wigner_3j, calc_wigner_6j, minus_one_pow @@ -41,6 +42,16 @@ def lru_cache(maxsize: int) -> Callable[[Callable[P, R]], Callable[P, R]]: ... ] +def is_angular_momentum_quantum_number(qn: str) -> TypeGuard[AngularMomentumQuantumNumbers]: + """Check if the given string is an AngularMomentumQuantumNumbers.""" + return qn in get_args(AngularMomentumQuantumNumbers) + + +def is_angular_operator_type(qn: str) -> TypeGuard[AngularOperatorType]: + """Check if the given string is an AngularOperatorType.""" + return qn in get_args(AngularOperatorType) + + @lru_cache(maxsize=10_000) def calc_reduced_spherical_matrix_element(l_r_final: int, l_r_initial: int, kappa: int) -> float: r"""Calculate the reduced spherical matrix element (l_r_final || \hat{Y}_{k} || l_r_initial). diff --git a/src/rydstate/angular/angular_state.py b/src/rydstate/angular/angular_state.py index dd05493..e3978fb 100644 --- a/src/rydstate/angular/angular_state.py +++ b/src/rydstate/angular/angular_state.py @@ -2,7 +2,7 @@ import logging import math -from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, get_args, overload +from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, overload import numpy as np @@ -12,15 +12,15 @@ AngularKetJJ, AngularKetLS, ) -from rydstate.angular.angular_matrix_element import AngularMomentumQuantumNumbers +from rydstate.angular.angular_matrix_element import is_angular_momentum_quantum_number if TYPE_CHECKING: - from collections.abc import Iterator + from collections.abc import Iterator, Sequence from typing_extensions import Self from rydstate.angular.angular_ket import CouplingScheme - from rydstate.angular.angular_matrix_element import AngularOperatorType + from rydstate.angular.angular_matrix_element import AngularMomentumQuantumNumbers, AngularOperatorType logger = logging.getLogger(__name__) @@ -30,13 +30,15 @@ class AngularState(Generic[_AngularKet]): def __init__( - self, coefficients: list[float], kets: list[_AngularKet], *, warn_if_not_normalized: bool = True + self, coefficients: Sequence[float], kets: Sequence[_AngularKet], *, warn_if_not_normalized: bool = True ) -> None: self.coefficients = np.array(coefficients) self.kets = kets if len(coefficients) != len(kets): raise ValueError("Length of coefficients and kets must be the same.") + if len(kets) == 0: + raise ValueError("At least one ket must be provided.") if not all(type(ket) is type(kets[0]) for ket in kets): raise ValueError("All kets must be of the same type.") if len(set(kets)) != len(kets): @@ -164,7 +166,7 @@ def calc_reduced_matrix_element( """ if isinstance(other, AngularKetBase): other = other.to_state() - if operator in get_args(AngularMomentumQuantumNumbers) and operator not in self.kets[0].quantum_number_names: + if is_angular_momentum_quantum_number(operator) and operator not in self.kets[0].quantum_number_names: for ket_class in [AngularKetLS, AngularKetJJ, AngularKetFJ]: if operator in ket_class.quantum_number_names: return self.to(ket_class.coupling_scheme).calc_reduced_matrix_element(other, operator, kappa) diff --git a/src/rydstate/basis/__init__.py b/src/rydstate/basis/__init__.py new file mode 100644 index 0000000..e2c87e3 --- /dev/null +++ b/src/rydstate/basis/__init__.py @@ -0,0 +1,3 @@ +from rydstate.basis.basis_sqdt import BasisSQDTAlkali, BasisSQDTAlkalineLS + +__all__ = ["BasisSQDTAlkali", "BasisSQDTAlkalineLS"] diff --git a/src/rydstate/basis/basis_base.py b/src/rydstate/basis/basis_base.py new file mode 100644 index 0000000..e6fe5d2 --- /dev/null +++ b/src/rydstate/basis/basis_base.py @@ -0,0 +1,143 @@ +from __future__ import annotations + +from abc import ABC +from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload + +import numpy as np +from typing_extensions import Self + +from rydstate.angular.angular_matrix_element import is_angular_momentum_quantum_number +from rydstate.rydberg.rydberg_base import RydbergStateBase +from rydstate.species.species_object import SpeciesObject +from rydstate.units import ureg + +if TYPE_CHECKING: + from rydstate.units import MatrixElementOperator, NDArray, PintArray, PintFloat + +_RydbergState = TypeVar("_RydbergState", bound=RydbergStateBase) + + +class BasisBase(ABC, Generic[_RydbergState]): + states: list[_RydbergState] + + def __init__(self, species: str | SpeciesObject) -> None: + if isinstance(species, str): + species = SpeciesObject.from_name(species) + self.species = species + + def __len__(self) -> int: + return len(self.states) + + def copy(self) -> Self: + new_basis = self.__class__.__new__(self.__class__) + new_basis.species = self.species + new_basis.states = list(self.states) + return new_basis + + @overload + def filter_states(self, qn: str, value: tuple[float, float], *, delta: float = 1e-10) -> Self: ... + + @overload + def filter_states(self, qn: str, value: float, *, delta: float = 1e-10) -> Self: ... + + def filter_states(self, qn: str, value: float | tuple[float, float], *, delta: float = 1e-10) -> Self: + if isinstance(value, tuple): + qn_min = value[0] - delta + qn_max = value[1] + delta + else: + qn_min = value - delta + qn_max = value + delta + + if is_angular_momentum_quantum_number(qn): + self.states = [state for state in self.states if qn_min <= state.angular.calc_exp_qn(qn) <= qn_max] + elif qn in ["n", "nu", "nu_energy"]: + self.states = [state for state in self.states if qn_min <= getattr(state, qn) <= qn_max] + else: + raise ValueError(f"Unknown quantum number {qn}") + + return self + + def sort_states(self, *qns: str) -> Self: + """Sort the basis states according to the given quantum numbers. + + The first quantum number given is the primary sorting key, the second quantum number + is the secondary sorting key, and so on. + """ + values = np.array([self.calc_exp_qn(qn) for qn in qns]) + sorted_indices = np.lexsort(values[::-1]) + self.states = [self.states[i] for i in sorted_indices] + return self + + def calc_exp_qn(self, qn: str) -> list[float]: + if is_angular_momentum_quantum_number(qn): + return [state.angular.calc_exp_qn(qn) for state in self.states] + if qn in ["n", "nu", "nu_energy"]: + return [getattr(state, qn) for state in self.states] + raise ValueError(f"Unknown quantum number {qn}") + + def calc_std_qn(self, qn: str) -> list[float]: + if is_angular_momentum_quantum_number(qn): + return [state.angular.calc_std_qn(qn) for state in self.states] + if qn in ["n", "nu", "nu_energy"]: + return [0 for state in self.states] + raise ValueError(f"Unknown quantum number {qn}") + + def calc_reduced_overlap(self, other: RydbergStateBase) -> NDArray: + """Calculate the reduced overlap (ignoring the magnetic quantum number m).""" + return np.array([bra.calc_reduced_overlap(other) for bra in self.states]) + + def calc_reduced_overlaps(self, other: BasisBase[Any]) -> NDArray: + """Calculate the reduced overlap for all states in the bases self and other. + + Returns a numpy array overlaps, where overlaps[i,j] corresponds to the overlap of the + i-th state of self and the j-th state of other. + """ + return np.array([[bra.calc_reduced_overlap(ket) for ket in other.states] for bra in self.states]) + + @overload + def calc_reduced_matrix_element( + self, other: RydbergStateBase, operator: MatrixElementOperator, unit: None = None + ) -> PintArray: ... + + @overload + def calc_reduced_matrix_element( + self, other: RydbergStateBase, operator: MatrixElementOperator, unit: str + ) -> NDArray: ... + + def calc_reduced_matrix_element( + self, other: RydbergStateBase, operator: MatrixElementOperator, unit: str | None = None + ) -> PintArray | NDArray: + r"""Calculate the reduced matrix element.""" + values_list = [bra.calc_reduced_matrix_element(other, operator, unit=unit) for bra in self.states] + if unit is not None: + return np.array(values_list) + + values: list[PintFloat] = values_list # type: ignore[assignment] + _unit = values[0].units + _values = np.array([v.magnitude for v in values]) + return ureg.Quantity(_values, _unit) + + @overload + def calc_reduced_matrix_elements( + self, other: BasisBase[Any], operator: MatrixElementOperator, unit: None = None + ) -> PintArray: ... + + @overload + def calc_reduced_matrix_elements( + self, other: BasisBase[Any], operator: MatrixElementOperator, unit: str + ) -> NDArray: ... + + def calc_reduced_matrix_elements( + self, other: BasisBase[Any], operator: MatrixElementOperator, unit: str | None = None + ) -> PintArray | NDArray: + r"""Calculate the reduced matrix element.""" + values_list = [ + [bra.calc_reduced_matrix_element(ket, operator, unit=unit) for ket in other.states] for bra in self.states + ] + if unit is not None: + return np.array(values_list) + + values: list[list[PintFloat]] = values_list # type: ignore[assignment] + _unit = values[0][0].units + _values = np.array([[v.magnitude for v in vs] for vs in values]) + return ureg.Quantity(_values, _unit) diff --git a/src/rydstate/basis/basis_sqdt.py b/src/rydstate/basis/basis_sqdt.py new file mode 100644 index 0000000..d936b58 --- /dev/null +++ b/src/rydstate/basis/basis_sqdt.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +import numpy as np + +from rydstate.basis.basis_base import BasisBase +from rydstate.rydberg import RydbergStateSQDTAlkali, RydbergStateSQDTAlkalineLS + + +class BasisSQDTAlkali(BasisBase[RydbergStateSQDTAlkali]): + def __init__(self, species: str, n_min: int = 1, n_max: int | None = None) -> None: + super().__init__(species) + + if n_max is None: + raise ValueError("n_max must be given") + + s = 1 / 2 + i_c = self.species.i_c if self.species.i_c is not None else 0 + + self.states = [] + for n in range(n_min, n_max + 1): + for l in range(n): + if not self.species.is_allowed_shell(n, l, s): + continue + for j in np.arange(abs(l - s), l + s + 1): + for f in np.arange(abs(j - i_c), j + i_c + 1): + state = RydbergStateSQDTAlkali(species, n=n, l=l, j=float(j), f=float(f)) + self.states.append(state) + + +class BasisSQDTAlkalineLS(BasisBase[RydbergStateSQDTAlkalineLS]): + def __init__(self, species: str, n_min: int = 1, n_max: int | None = None) -> None: + super().__init__(species) + + if n_max is None: + raise ValueError("n_max must be given") + + i_c = self.species.i_c if self.species.i_c is not None else 0 + + self.states = [] + for n in range(n_min, n_max + 1): + for l in range(n): + for s_tot in [0, 1]: + if not self.species.is_allowed_shell(n, l, s_tot): + continue + for j_tot in range(abs(l - s_tot), l + s_tot + 1): + for f_tot in np.arange(abs(j_tot - i_c), j_tot + i_c + 1): + state = RydbergStateSQDTAlkalineLS( + species, n=n, l=l, s_tot=s_tot, j_tot=j_tot, f_tot=float(f_tot) + ) + self.states.append(state) diff --git a/src/rydstate/rydberg/rydberg_sqdt.py b/src/rydstate/rydberg/rydberg_sqdt.py index 0ca1a9e..538db0e 100644 --- a/src/rydstate/rydberg/rydberg_sqdt.py +++ b/src/rydstate/rydberg/rydberg_sqdt.py @@ -319,7 +319,8 @@ def __init__( def __repr__(self) -> str: species, n, l, j, f, m = self.species, self.n, self.l, self.j, self.f, self.m - return f"{self.__class__.__name__}({species.name}, {n=}, {l=}, {j=}, {f=}, {m=})" + f_string = f", {f=}" if self.species.i_c not in (None, 0) else "" + return f"{self.__class__.__name__}({species.name}, {n=}, {l=}, {j=}{f_string}, {m=})" class RydbergStateSQDTAlkalineLS(RydbergStateSQDT): diff --git a/tests/test_basis.py b/tests/test_basis.py new file mode 100644 index 0000000..d8ae608 --- /dev/null +++ b/tests/test_basis.py @@ -0,0 +1,64 @@ +import numpy as np +import pytest +from rydstate import BasisSQDTAlkali +from rydstate.basis.basis_sqdt import BasisSQDTAlkalineLS + + +@pytest.mark.parametrize("species_name", ["Rb", "Na", "H"]) +def test_alkali_basis(species_name: str) -> None: + """Test alkali basis creation.""" + basis = BasisSQDTAlkali(species_name, n_min=1, n_max=20) + basis.sort_states("n", "l_r") + lowest_n_state = {"Rb": (4, 2), "Na": (3, 0), "H": (1, 0)}[species_name] + assert (basis.states[0].n, basis.states[0].l) == lowest_n_state + assert (basis.states[-1].n, basis.states[-1].l) == (20, 19) + assert len(basis.states) == {"Rb": 388, "Na": 396, "H": 400}[species_name] + + state0 = basis.states[0] + ov = basis.calc_reduced_overlap(state0) + compare_ov = np.zeros(len(basis.states)) + compare_ov[0] = 1.0 + assert np.allclose(ov, compare_ov, atol=1e-3) + + me = basis.calc_reduced_matrix_element(state0, "electric_dipole", unit="e a0") + assert np.shape(me) == (len(basis.states),) + assert np.count_nonzero(me) > 0 + + basis.filter_states("n", (1, 7)) + ov_matrix = basis.calc_reduced_overlaps(basis) + assert np.allclose(ov_matrix, np.eye(len(basis.states)), atol=1e-3) + + me_matrix = basis.calc_reduced_matrix_elements(basis, "electric_dipole", unit="e a0") + assert np.shape(me_matrix) == (len(basis.states), len(basis.states)) + assert np.count_nonzero(me_matrix) > 0 + + +@pytest.mark.parametrize("species_name", ["Sr88", "Sr87", "Yb174", "Yb171"]) +def test_alkaline_basis(species_name: str) -> None: + """Test alkaline basis creation.""" + basis = BasisSQDTAlkalineLS(species_name, n_min=30, n_max=35) + basis.sort_states("n", "l_r") + assert (basis.states[0].n, basis.states[0].l) == (30, 0) + assert (basis.states[-1].n, basis.states[-1].l) == (35, 34) + assert len(basis.states) == {"Sr88": 768, "Sr87": 7188, "Yb174": 768, "Yb171": 1524}[species_name] + + if species_name in ["Sr87", "Yb171"]: + pytest.skip("Quantum defects for Sr87 and Yb171 not implemented yet.") + + state0 = basis.states[0] + ov = basis.calc_reduced_overlap(state0) + compare_ov = np.zeros(len(basis.states)) + compare_ov[0] = 1.0 + assert np.allclose(ov, compare_ov, atol=1e-3) + + me = basis.calc_reduced_matrix_element(state0, "electric_dipole", unit="e a0") + assert np.shape(me) == (len(basis.states),) + assert np.count_nonzero(me) > 0 + + basis.filter_states("l_r", (0, 2)) + ov_matrix = basis.calc_reduced_overlaps(basis) + assert np.allclose(ov_matrix, np.eye(len(basis.states)), atol=1e-2) + + me_matrix = basis.calc_reduced_matrix_elements(basis, "electric_dipole", unit="e a0") + assert np.shape(me_matrix) == (len(basis.states), len(basis.states)) + assert np.count_nonzero(me_matrix) > 0