From 723a1baeb7230337477a2c1c5485b5a9f7a65618 Mon Sep 17 00:00:00 2001 From: johannes-moegerle Date: Wed, 3 Dec 2025 08:52:06 +0100 Subject: [PATCH 01/11] AngularKet add calc_exp_qn --- src/rydstate/angular/angular_ket.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/src/rydstate/angular/angular_ket.py b/src/rydstate/angular/angular_ket.py index 62ceeba..2269010 100644 --- a/src/rydstate/angular/angular_ket.py +++ b/src/rydstate/angular/angular_ket.py @@ -186,6 +186,34 @@ def get_qn(self, qn: AngularMomentumQuantumNumbers) -> float: raise ValueError(f"Quantum number {qn} not found in {self!r}.") return getattr(self, qn) # type: ignore [no-any-return] + def calc_exp_qn(self, qn: AngularMomentumQuantumNumbers) -> float: + """Calculate the expectation value of a quantum number qn. + + If the quantum number is a good quantum number simply return it, + otherwise calculate it, see also AngularState.calc_exp_qn for more details. + + Args: + qn: The quantum number to calculate the expectation value for. + + """ + if qn in self.quantum_number_names: + return self.get_qn(qn) + return self.to_state().calc_exp_qn(qn) + + def calc_std_qn(self, qn: AngularMomentumQuantumNumbers) -> float: + """Calculate the standard deviation of a quantum number qn. + + If the quantum number is a good quantum number return 0, + otherwise calculate the std, see also AngularState.calc_std_qn for more details. + + Args: + qn: The quantum number to calculate the standard deviation for. + + """ + if qn in self.quantum_number_names: + return 0 + return self.to_state().calc_std_qn(qn) + @overload def to_state(self, coupling_scheme: Literal["LS"]) -> AngularState[AngularKetLS]: ... From 111075f0a0bc2f0511824d00dc588d152878618e Mon Sep 17 00:00:00 2001 From: johannes-moegerle Date: Thu, 4 Dec 2025 16:06:50 +0100 Subject: [PATCH 02/11] add is_angular_momentum_quantum_number and is_angular_operator_type --- src/rydstate/angular/angular_ket.py | 11 ++++++----- src/rydstate/angular/angular_matrix_element.py | 13 ++++++++++++- src/rydstate/angular/angular_state.py | 8 ++++---- 3 files changed, 22 insertions(+), 10 deletions(-) diff --git a/src/rydstate/angular/angular_ket.py b/src/rydstate/angular/angular_ket.py index 2269010..58b92c1 100644 --- a/src/rydstate/angular/angular_ket.py +++ b/src/rydstate/angular/angular_ket.py @@ -2,15 +2,15 @@ import logging from abc import ABC -from typing import TYPE_CHECKING, Any, ClassVar, Literal, get_args, overload +from typing import TYPE_CHECKING, Any, ClassVar, Literal, overload from rydstate.angular.angular_matrix_element import ( - AngularMomentumQuantumNumbers, - AngularOperatorType, calc_prefactor_of_operator_in_coupled_scheme, calc_reduced_identity_matrix_element, calc_reduced_spherical_matrix_element, calc_reduced_spin_matrix_element, + is_angular_momentum_quantum_number, + is_angular_operator_type, ) from rydstate.angular.utils import ( calc_wigner_3j, @@ -26,6 +26,7 @@ if TYPE_CHECKING: from typing_extensions import Self + from rydstate.angular.angular_matrix_element import AngularMomentumQuantumNumbers, AngularOperatorType from rydstate.angular.angular_state import AngularState logger = logging.getLogger(__name__) @@ -410,12 +411,12 @@ def calc_reduced_matrix_element( # noqa: C901 \left\langle self || \hat{O}^{(\kappa)} || other \right\rangle """ - if operator not in get_args(AngularOperatorType): + if not is_angular_operator_type(operator): raise NotImplementedError(f"calc_reduced_matrix_element is not implemented for operator {operator}.") if type(self) is not type(other): return self.to_state().calc_reduced_matrix_element(other.to_state(), operator, kappa) - if operator in get_args(AngularMomentumQuantumNumbers) and operator not in self.quantum_number_names: + if is_angular_momentum_quantum_number(operator) and operator not in self.quantum_number_names: return self.to_state().calc_reduced_matrix_element(other.to_state(), operator, kappa) qn_name: AngularMomentumQuantumNumbers diff --git a/src/rydstate/angular/angular_matrix_element.py b/src/rydstate/angular/angular_matrix_element.py index 1b69dc3..e07ac23 100644 --- a/src/rydstate/angular/angular_matrix_element.py +++ b/src/rydstate/angular/angular_matrix_element.py @@ -2,9 +2,10 @@ import math from functools import lru_cache -from typing import TYPE_CHECKING, Callable, Literal, TypeVar +from typing import TYPE_CHECKING, Callable, Literal, TypeVar, get_args import numpy as np +from typing_extensions import TypeGuard from rydstate.angular.utils import calc_wigner_3j, calc_wigner_6j, minus_one_pow @@ -41,6 +42,16 @@ def lru_cache(maxsize: int) -> Callable[[Callable[P, R]], Callable[P, R]]: ... ] +def is_angular_momentum_quantum_number(qn: str) -> TypeGuard[AngularMomentumQuantumNumbers]: + """Check if the given string is an AngularMomentumQuantumNumbers.""" + return qn in get_args(AngularMomentumQuantumNumbers) + + +def is_angular_operator_type(qn: str) -> TypeGuard[AngularOperatorType]: + """Check if the given string is an AngularOperatorType.""" + return qn in get_args(AngularOperatorType) + + @lru_cache(maxsize=10_000) def calc_reduced_spherical_matrix_element(l_r_final: int, l_r_initial: int, kappa: int) -> float: r"""Calculate the reduced spherical matrix element (l_r_final || \hat{Y}_{k} || l_r_initial). diff --git a/src/rydstate/angular/angular_state.py b/src/rydstate/angular/angular_state.py index dd05493..a6f85da 100644 --- a/src/rydstate/angular/angular_state.py +++ b/src/rydstate/angular/angular_state.py @@ -2,7 +2,7 @@ import logging import math -from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, get_args, overload +from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, overload import numpy as np @@ -12,7 +12,7 @@ AngularKetJJ, AngularKetLS, ) -from rydstate.angular.angular_matrix_element import AngularMomentumQuantumNumbers +from rydstate.angular.angular_matrix_element import is_angular_momentum_quantum_number if TYPE_CHECKING: from collections.abc import Iterator @@ -20,7 +20,7 @@ from typing_extensions import Self from rydstate.angular.angular_ket import CouplingScheme - from rydstate.angular.angular_matrix_element import AngularOperatorType + from rydstate.angular.angular_matrix_element import AngularMomentumQuantumNumbers, AngularOperatorType logger = logging.getLogger(__name__) @@ -164,7 +164,7 @@ def calc_reduced_matrix_element( """ if isinstance(other, AngularKetBase): other = other.to_state() - if operator in get_args(AngularMomentumQuantumNumbers) and operator not in self.kets[0].quantum_number_names: + if is_angular_momentum_quantum_number(operator) and operator not in self.kets[0].quantum_number_names: for ket_class in [AngularKetLS, AngularKetJJ, AngularKetFJ]: if operator in ket_class.quantum_number_names: return self.to(ket_class.coupling_scheme).calc_reduced_matrix_element(other, operator, kappa) From df9aea6676e17c05e73a4dc097b467a36c763a0b Mon Sep 17 00:00:00 2001 From: johannes-moegerle Date: Fri, 19 Dec 2025 16:31:09 +0100 Subject: [PATCH 03/11] small fix angular state --- src/rydstate/angular/angular_state.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/rydstate/angular/angular_state.py b/src/rydstate/angular/angular_state.py index a6f85da..e3978fb 100644 --- a/src/rydstate/angular/angular_state.py +++ b/src/rydstate/angular/angular_state.py @@ -15,7 +15,7 @@ from rydstate.angular.angular_matrix_element import is_angular_momentum_quantum_number if TYPE_CHECKING: - from collections.abc import Iterator + from collections.abc import Iterator, Sequence from typing_extensions import Self @@ -30,13 +30,15 @@ class AngularState(Generic[_AngularKet]): def __init__( - self, coefficients: list[float], kets: list[_AngularKet], *, warn_if_not_normalized: bool = True + self, coefficients: Sequence[float], kets: Sequence[_AngularKet], *, warn_if_not_normalized: bool = True ) -> None: self.coefficients = np.array(coefficients) self.kets = kets if len(coefficients) != len(kets): raise ValueError("Length of coefficients and kets must be the same.") + if len(kets) == 0: + raise ValueError("At least one ket must be provided.") if not all(type(ket) is type(kets[0]) for ket in kets): raise ValueError("All kets must be of the same type.") if len(set(kets)) != len(kets): From cf19032404a76574278f6cd6e748f5c14f8c6120 Mon Sep 17 00:00:00 2001 From: johannes-moegerle Date: Wed, 3 Dec 2025 08:52:30 +0100 Subject: [PATCH 04/11] start adding sqdt basis --- src/rydstate/__init__.py | 3 ++ src/rydstate/basis/__init__.py | 3 ++ src/rydstate/basis/basis_sqdt.py | 73 ++++++++++++++++++++++++++++++++ 3 files changed, 79 insertions(+) create mode 100644 src/rydstate/basis/__init__.py create mode 100644 src/rydstate/basis/basis_sqdt.py diff --git a/src/rydstate/__init__.py b/src/rydstate/__init__.py index 4088e9f..4db6a45 100644 --- a/src/rydstate/__init__.py +++ b/src/rydstate/__init__.py @@ -1,4 +1,5 @@ from rydstate import angular, radial, species +from rydstate.basis import BasisSQDTAlkali, BasisSQDTAlkalineLS from rydstate.rydberg import ( RydbergStateSQDT, RydbergStateSQDTAlkali, @@ -8,6 +9,8 @@ from rydstate.units import ureg __all__ = [ + "BasisSQDTAlkali", + "BasisSQDTAlkalineLS", "RydbergStateSQDT", "RydbergStateSQDTAlkali", "RydbergStateSQDTAlkalineJJ", diff --git a/src/rydstate/basis/__init__.py b/src/rydstate/basis/__init__.py new file mode 100644 index 0000000..e2c87e3 --- /dev/null +++ b/src/rydstate/basis/__init__.py @@ -0,0 +1,3 @@ +from rydstate.basis.basis_sqdt import BasisSQDTAlkali, BasisSQDTAlkalineLS + +__all__ = ["BasisSQDTAlkali", "BasisSQDTAlkalineLS"] diff --git a/src/rydstate/basis/basis_sqdt.py b/src/rydstate/basis/basis_sqdt.py new file mode 100644 index 0000000..00ee8db --- /dev/null +++ b/src/rydstate/basis/basis_sqdt.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +from abc import ABC +from typing import get_args + +import numpy as np +from typing_extensions import Self + +from rydstate.angular.angular_matrix_element import AngularMomentumQuantumNumbers +from rydstate.rydberg import RydbergStateMQDT, RydbergStateSQDT, RydbergStateSQDTAlkali, RydbergStateSQDTAlkalineLS +from rydstate.species.species_object import SpeciesObject + + +class BasisBase(ABC): + states: list[RydbergStateSQDT | RydbergStateMQDT] + + def __init__(self, species: str | SpeciesObject) -> None: + if isinstance(species, str): + species = SpeciesObject.from_name(species) + self.species = species + + def __len__(self) -> int: + return len(self.states) + + def filter_states(self, qn: str, qn_min: float, qn_max: float) -> Self: + if qn in get_args(AngularMomentumQuantumNumbers): + self.states = [state for state in self.states if qn_min <= state.angular.calc_exp_qn(qn) <= qn_max] + elif qn in ["n", "nu"]: + self.states = [state for state in self.states if qn_min <= getattr(state, qn) <= qn_max] + else: + raise ValueError(f"Unknown quantum number {qn}") + + return self + + +class BasisSQDTAlkali(BasisBase): + def __init__(self, species: str, n_min: int = 1, n_max: int | None = None): + super().__init__(species) + + if n_max is None: + raise ValueError("n_max must be given") + + s = 1 / 2 + i_c = self.species.i_c if self.species.i_c is not None else 0 + states: list[RydbergStateSQDTAlkali] = [] + for n in range(n_min, n_max + 1): + for l in range(n): + for j in np.arange(abs(l - s), l + s + 1): + for f in np.arange(abs(j - i_c), j + i_c + 1): + state = RydbergStateSQDTAlkali(species, n=n, l=l, j=j, f=f) + states.append(state) + + self.states = states + + +class BasisSQDTAlkalineLS(BasisBase): + def __init__(self, species: str, n_min: int = 1, n_max: int | None = None): + super().__init__(species) + + if n_max is None: + raise ValueError("n_max must be given") + + i_c = self.species.i_c if self.species.i_c is not None else 0 + states: list[RydbergStateSQDTAlkalineLS] = [] + for s_tot in [0, 1]: + for n in range(n_min, n_max + 1): + for l in range(n): + for j_tot in range(abs(l - s_tot), l + s_tot + 1): + for f_tot in np.arange(abs(j_tot - i_c), j_tot + i_c + 1): + state = RydbergStateSQDTAlkalineLS(species, n=n, l=l, s_tot=s_tot, j_tot=j_tot, f_tot=f_tot) + states.append(state) + + self.states = states From c24a9234d5b3ad4f7fd9957c917414c764d14ef3 Mon Sep 17 00:00:00 2001 From: johannes-moegerle Date: Mon, 22 Dec 2025 09:39:00 +0100 Subject: [PATCH 05/11] split up basis base --- src/rydstate/basis/basis_base.py | 48 ++++++++++++++++++++++++++++++++ src/rydstate/basis/basis_sqdt.py | 31 ++------------------- 2 files changed, 50 insertions(+), 29 deletions(-) create mode 100644 src/rydstate/basis/basis_base.py diff --git a/src/rydstate/basis/basis_base.py b/src/rydstate/basis/basis_base.py new file mode 100644 index 0000000..c9b7a4f --- /dev/null +++ b/src/rydstate/basis/basis_base.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +from abc import ABC +from typing import TYPE_CHECKING, get_args + +from typing_extensions import Self + +from rydstate.angular.angular_matrix_element import AngularMomentumQuantumNumbers +from rydstate.species.species_object import SpeciesObject + +if TYPE_CHECKING: + from rydstate.rydberg import RydbergStateSQDT + + +class BasisBase(ABC): + states: list[RydbergStateSQDT] + + def __init__(self, species: str | SpeciesObject) -> None: + if isinstance(species, str): + species = SpeciesObject.from_name(species) + self.species = species + + def __len__(self) -> int: + return len(self.states) + + def filter_states(self, qn: str, qn_min: float, qn_max: float) -> Self: + if qn in get_args(AngularMomentumQuantumNumbers): + self.states = [state for state in self.states if qn_min <= state.angular.calc_exp_qn(qn) <= qn_max] + elif qn in ["n", "nu", "nu_energy"]: + self.states = [state for state in self.states if qn_min <= getattr(state, qn) <= qn_max] + else: + raise ValueError(f"Unknown quantum number {qn}") + + return self + + def calc_exp_qn(self, qn: str) -> list[float]: + if qn in get_args(AngularMomentumQuantumNumbers): + return [state.angular.calc_exp_qn(qn) for state in self.states] + if qn in ["n", "nu", "nu_energy"]: + return [getattr(state, qn) for state in self.states] + raise ValueError(f"Unknown quantum number {qn}") + + def calc_std_qn(self, qn: str) -> list[float]: + if qn in get_args(AngularMomentumQuantumNumbers): + return [state.angular.calc_std_qn(qn) for state in self.states] + if qn in ["n", "nu", "nu_energy"]: + return [0 for state in self.states] + raise ValueError(f"Unknown quantum number {qn}") diff --git a/src/rydstate/basis/basis_sqdt.py b/src/rydstate/basis/basis_sqdt.py index 00ee8db..5237740 100644 --- a/src/rydstate/basis/basis_sqdt.py +++ b/src/rydstate/basis/basis_sqdt.py @@ -1,36 +1,9 @@ from __future__ import annotations -from abc import ABC -from typing import get_args - import numpy as np -from typing_extensions import Self - -from rydstate.angular.angular_matrix_element import AngularMomentumQuantumNumbers -from rydstate.rydberg import RydbergStateMQDT, RydbergStateSQDT, RydbergStateSQDTAlkali, RydbergStateSQDTAlkalineLS -from rydstate.species.species_object import SpeciesObject - - -class BasisBase(ABC): - states: list[RydbergStateSQDT | RydbergStateMQDT] - - def __init__(self, species: str | SpeciesObject) -> None: - if isinstance(species, str): - species = SpeciesObject.from_name(species) - self.species = species - - def __len__(self) -> int: - return len(self.states) - - def filter_states(self, qn: str, qn_min: float, qn_max: float) -> Self: - if qn in get_args(AngularMomentumQuantumNumbers): - self.states = [state for state in self.states if qn_min <= state.angular.calc_exp_qn(qn) <= qn_max] - elif qn in ["n", "nu"]: - self.states = [state for state in self.states if qn_min <= getattr(state, qn) <= qn_max] - else: - raise ValueError(f"Unknown quantum number {qn}") - return self +from rydstate.basis.basis_base import BasisBase +from rydstate.rydberg import RydbergStateSQDTAlkali, RydbergStateSQDTAlkalineLS class BasisSQDTAlkali(BasisBase): From 434c4a4ecd0f52b558eed4ab8b191f5dd305fadf Mon Sep 17 00:00:00 2001 From: johannes-moegerle Date: Mon, 22 Dec 2025 09:39:38 +0100 Subject: [PATCH 06/11] BasisBase start adding calc_reduced_... --- src/rydstate/basis/basis_base.py | 81 ++++++++++++++++++++++++++++---- src/rydstate/basis/basis_sqdt.py | 28 +++++------ 2 files changed, 87 insertions(+), 22 deletions(-) diff --git a/src/rydstate/basis/basis_base.py b/src/rydstate/basis/basis_base.py index c9b7a4f..ba9c8ce 100644 --- a/src/rydstate/basis/basis_base.py +++ b/src/rydstate/basis/basis_base.py @@ -1,19 +1,24 @@ from __future__ import annotations from abc import ABC -from typing import TYPE_CHECKING, get_args +from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload +import numpy as np from typing_extensions import Self -from rydstate.angular.angular_matrix_element import AngularMomentumQuantumNumbers +from rydstate.angular.angular_matrix_element import is_angular_momentum_quantum_number +from rydstate.rydberg.rydberg_base import RydbergStateBase from rydstate.species.species_object import SpeciesObject +from rydstate.units import ureg if TYPE_CHECKING: - from rydstate.rydberg import RydbergStateSQDT + from rydstate.units import MatrixElementOperator, NDArray, PintArray, PintFloat +_RydbergState = TypeVar("_RydbergState", bound=RydbergStateBase) -class BasisBase(ABC): - states: list[RydbergStateSQDT] + +class BasisBase(ABC, Generic[_RydbergState]): + states: list[_RydbergState] def __init__(self, species: str | SpeciesObject) -> None: if isinstance(species, str): @@ -24,7 +29,7 @@ def __len__(self) -> int: return len(self.states) def filter_states(self, qn: str, qn_min: float, qn_max: float) -> Self: - if qn in get_args(AngularMomentumQuantumNumbers): + if is_angular_momentum_quantum_number(qn): self.states = [state for state in self.states if qn_min <= state.angular.calc_exp_qn(qn) <= qn_max] elif qn in ["n", "nu", "nu_energy"]: self.states = [state for state in self.states if qn_min <= getattr(state, qn) <= qn_max] @@ -34,15 +39,75 @@ def filter_states(self, qn: str, qn_min: float, qn_max: float) -> Self: return self def calc_exp_qn(self, qn: str) -> list[float]: - if qn in get_args(AngularMomentumQuantumNumbers): + if is_angular_momentum_quantum_number(qn): return [state.angular.calc_exp_qn(qn) for state in self.states] if qn in ["n", "nu", "nu_energy"]: return [getattr(state, qn) for state in self.states] raise ValueError(f"Unknown quantum number {qn}") def calc_std_qn(self, qn: str) -> list[float]: - if qn in get_args(AngularMomentumQuantumNumbers): + if is_angular_momentum_quantum_number(qn): return [state.angular.calc_std_qn(qn) for state in self.states] if qn in ["n", "nu", "nu_energy"]: return [0 for state in self.states] raise ValueError(f"Unknown quantum number {qn}") + + def calc_reduced_overlap(self, other: RydbergStateBase) -> NDArray: + """Calculate the reduced overlap (ignoring the magnetic quantum number m).""" + return np.array([bra.calc_reduced_overlap(other) for bra in self.states]) + + def calc_reduced_overlaps(self, other: BasisBase[Any]) -> NDArray: + """Calculate the reduced overlap for all states in the bases self and other. + + Returns a numpy array overlaps, where overlaps[i,j] corresponds to the overlap of the + i-th state of self and the j-th state of other. + """ + return np.array([[bra.calc_reduced_overlap(ket) for ket in other.states] for bra in self.states]) + + @overload + def calc_reduced_matrix_element( + self, other: RydbergStateBase, operator: MatrixElementOperator, unit: None = None + ) -> PintArray: ... + + @overload + def calc_reduced_matrix_element( + self, other: RydbergStateBase, operator: MatrixElementOperator, unit: str + ) -> NDArray: ... + + def calc_reduced_matrix_element( + self, other: RydbergStateBase, operator: MatrixElementOperator, unit: str | None = None + ) -> PintArray | NDArray: + r"""Calculate the reduced matrix element.""" + values_list = [bra.calc_reduced_matrix_element(other, operator, unit=unit) for bra in self.states] + if unit is not None: + return np.array(values_list) + + values: list[PintFloat] = values_list # type: ignore[assignment] + _unit = values[0].units + _values = np.array([v.magnitude for v in values]) + return ureg.Quantity(_values, _unit) + + @overload + def calc_reduced_matrix_elements( + self, other: BasisBase[Any], operator: MatrixElementOperator, unit: None = None + ) -> PintArray: ... + + @overload + def calc_reduced_matrix_elements( + self, other: BasisBase[Any], operator: MatrixElementOperator, unit: str + ) -> NDArray: ... + + def calc_reduced_matrix_elements( + self, other: BasisBase[Any], operator: MatrixElementOperator, unit: str | None = None + ) -> PintArray | NDArray: + r"""Calculate the reduced matrix element.""" + values_list = [ + [bra.calc_reduced_matrix_element(ket, operator, unit=unit) for ket in other.states] for bra in self.states + ] + if unit is not None: + return np.array(values_list) + + values: list[list[PintFloat]] = values_list # type: ignore[assignment] + _unit = values[0][0].units + _values = np.array([[v.magnitude for v in vs] for vs in values]) + return ureg.Quantity(_values, _unit) diff --git a/src/rydstate/basis/basis_sqdt.py b/src/rydstate/basis/basis_sqdt.py index 5237740..5d7e050 100644 --- a/src/rydstate/basis/basis_sqdt.py +++ b/src/rydstate/basis/basis_sqdt.py @@ -6,8 +6,8 @@ from rydstate.rydberg import RydbergStateSQDTAlkali, RydbergStateSQDTAlkalineLS -class BasisSQDTAlkali(BasisBase): - def __init__(self, species: str, n_min: int = 1, n_max: int | None = None): +class BasisSQDTAlkali(BasisBase[RydbergStateSQDTAlkali]): + def __init__(self, species: str, n_min: int = 1, n_max: int | None = None) -> None: super().__init__(species) if n_max is None: @@ -15,32 +15,32 @@ def __init__(self, species: str, n_min: int = 1, n_max: int | None = None): s = 1 / 2 i_c = self.species.i_c if self.species.i_c is not None else 0 - states: list[RydbergStateSQDTAlkali] = [] + + self.states = [] for n in range(n_min, n_max + 1): for l in range(n): for j in np.arange(abs(l - s), l + s + 1): for f in np.arange(abs(j - i_c), j + i_c + 1): - state = RydbergStateSQDTAlkali(species, n=n, l=l, j=j, f=f) - states.append(state) - - self.states = states + state = RydbergStateSQDTAlkali(species, n=n, l=l, j=float(j), f=float(f)) + self.states.append(state) -class BasisSQDTAlkalineLS(BasisBase): - def __init__(self, species: str, n_min: int = 1, n_max: int | None = None): +class BasisSQDTAlkalineLS(BasisBase[RydbergStateSQDTAlkalineLS]): + def __init__(self, species: str, n_min: int = 1, n_max: int | None = None) -> None: super().__init__(species) if n_max is None: raise ValueError("n_max must be given") i_c = self.species.i_c if self.species.i_c is not None else 0 - states: list[RydbergStateSQDTAlkalineLS] = [] + + self.states = [] for s_tot in [0, 1]: for n in range(n_min, n_max + 1): for l in range(n): for j_tot in range(abs(l - s_tot), l + s_tot + 1): for f_tot in np.arange(abs(j_tot - i_c), j_tot + i_c + 1): - state = RydbergStateSQDTAlkalineLS(species, n=n, l=l, s_tot=s_tot, j_tot=j_tot, f_tot=f_tot) - states.append(state) - - self.states = states + state = RydbergStateSQDTAlkalineLS( + species, n=n, l=l, s_tot=s_tot, j_tot=j_tot, f_tot=float(f_tot) + ) + self.states.append(state) From 2994ac04da93faea92a09e135a051953b88bce90 Mon Sep 17 00:00:00 2001 From: johannes-moegerle Date: Mon, 22 Dec 2025 09:40:06 +0100 Subject: [PATCH 07/11] Basis update filter_states and add copy --- src/rydstate/basis/basis_base.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/src/rydstate/basis/basis_base.py b/src/rydstate/basis/basis_base.py index ba9c8ce..789da0e 100644 --- a/src/rydstate/basis/basis_base.py +++ b/src/rydstate/basis/basis_base.py @@ -28,7 +28,26 @@ def __init__(self, species: str | SpeciesObject) -> None: def __len__(self) -> int: return len(self.states) - def filter_states(self, qn: str, qn_min: float, qn_max: float) -> Self: + def copy(self) -> Self: + new_basis = self.__class__.__new__(self.__class__) + new_basis.species = self.species + new_basis.states = list(self.states) + return new_basis + + @overload + def filter_states(self, qn: str, value: tuple[float, float], *, delta: float = 1e-10) -> Self: ... + + @overload + def filter_states(self, qn: str, value: float, *, delta: float = 1e-10) -> Self: ... + + def filter_states(self, qn: str, value: float | tuple[float, float], *, delta: float = 1e-10) -> Self: + if isinstance(value, tuple): + qn_min = value[0] - delta + qn_max = value[1] + delta + else: + qn_min = value - delta + qn_max = value + delta + if is_angular_momentum_quantum_number(qn): self.states = [state for state in self.states if qn_min <= state.angular.calc_exp_qn(qn) <= qn_max] elif qn in ["n", "nu", "nu_energy"]: From da4923af07fcc03e6cf9fa9bd560653d9674685b Mon Sep 17 00:00:00 2001 From: johannes-moegerle Date: Tue, 9 Dec 2025 15:36:13 +0100 Subject: [PATCH 08/11] basis add sort_states --- src/rydstate/basis/basis_base.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/rydstate/basis/basis_base.py b/src/rydstate/basis/basis_base.py index 789da0e..e6fe5d2 100644 --- a/src/rydstate/basis/basis_base.py +++ b/src/rydstate/basis/basis_base.py @@ -57,6 +57,17 @@ def filter_states(self, qn: str, value: float | tuple[float, float], *, delta: f return self + def sort_states(self, *qns: str) -> Self: + """Sort the basis states according to the given quantum numbers. + + The first quantum number given is the primary sorting key, the second quantum number + is the secondary sorting key, and so on. + """ + values = np.array([self.calc_exp_qn(qn) for qn in qns]) + sorted_indices = np.lexsort(values[::-1]) + self.states = [self.states[i] for i in sorted_indices] + return self + def calc_exp_qn(self, qn: str) -> list[float]: if is_angular_momentum_quantum_number(qn): return [state.angular.calc_exp_qn(qn) for state in self.states] From 08204637cb140b3e7ebef08ac6d85d7157391f69 Mon Sep 17 00:00:00 2001 From: johannes-moegerle Date: Mon, 22 Dec 2025 13:56:54 +0100 Subject: [PATCH 09/11] fixup basis sqdt --- src/rydstate/basis/basis_sqdt.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/rydstate/basis/basis_sqdt.py b/src/rydstate/basis/basis_sqdt.py index 5d7e050..d936b58 100644 --- a/src/rydstate/basis/basis_sqdt.py +++ b/src/rydstate/basis/basis_sqdt.py @@ -19,6 +19,8 @@ def __init__(self, species: str, n_min: int = 1, n_max: int | None = None) -> No self.states = [] for n in range(n_min, n_max + 1): for l in range(n): + if not self.species.is_allowed_shell(n, l, s): + continue for j in np.arange(abs(l - s), l + s + 1): for f in np.arange(abs(j - i_c), j + i_c + 1): state = RydbergStateSQDTAlkali(species, n=n, l=l, j=float(j), f=float(f)) @@ -35,9 +37,11 @@ def __init__(self, species: str, n_min: int = 1, n_max: int | None = None) -> No i_c = self.species.i_c if self.species.i_c is not None else 0 self.states = [] - for s_tot in [0, 1]: - for n in range(n_min, n_max + 1): - for l in range(n): + for n in range(n_min, n_max + 1): + for l in range(n): + for s_tot in [0, 1]: + if not self.species.is_allowed_shell(n, l, s_tot): + continue for j_tot in range(abs(l - s_tot), l + s_tot + 1): for f_tot in np.arange(abs(j_tot - i_c), j_tot + i_c + 1): state = RydbergStateSQDTAlkalineLS( From 8cdcebfb59213cdf73e38463ded8eaa04daaa704 Mon Sep 17 00:00:00 2001 From: johannes-moegerle Date: Mon, 22 Dec 2025 13:54:37 +0100 Subject: [PATCH 10/11] improve rydberg sqdt alkali repr --- src/rydstate/rydberg/rydberg_sqdt.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/rydstate/rydberg/rydberg_sqdt.py b/src/rydstate/rydberg/rydberg_sqdt.py index 0ca1a9e..538db0e 100644 --- a/src/rydstate/rydberg/rydberg_sqdt.py +++ b/src/rydstate/rydberg/rydberg_sqdt.py @@ -319,7 +319,8 @@ def __init__( def __repr__(self) -> str: species, n, l, j, f, m = self.species, self.n, self.l, self.j, self.f, self.m - return f"{self.__class__.__name__}({species.name}, {n=}, {l=}, {j=}, {f=}, {m=})" + f_string = f", {f=}" if self.species.i_c not in (None, 0) else "" + return f"{self.__class__.__name__}({species.name}, {n=}, {l=}, {j=}{f_string}, {m=})" class RydbergStateSQDTAlkalineLS(RydbergStateSQDT): From 7753ed79aefadf79ceef3a371b3a68aca7ae829c Mon Sep 17 00:00:00 2001 From: johannes-moegerle Date: Mon, 22 Dec 2025 13:54:49 +0100 Subject: [PATCH 11/11] [tests] start adding tests for basis --- tests/test_basis.py | 64 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) create mode 100644 tests/test_basis.py diff --git a/tests/test_basis.py b/tests/test_basis.py new file mode 100644 index 0000000..d8ae608 --- /dev/null +++ b/tests/test_basis.py @@ -0,0 +1,64 @@ +import numpy as np +import pytest +from rydstate import BasisSQDTAlkali +from rydstate.basis.basis_sqdt import BasisSQDTAlkalineLS + + +@pytest.mark.parametrize("species_name", ["Rb", "Na", "H"]) +def test_alkali_basis(species_name: str) -> None: + """Test alkali basis creation.""" + basis = BasisSQDTAlkali(species_name, n_min=1, n_max=20) + basis.sort_states("n", "l_r") + lowest_n_state = {"Rb": (4, 2), "Na": (3, 0), "H": (1, 0)}[species_name] + assert (basis.states[0].n, basis.states[0].l) == lowest_n_state + assert (basis.states[-1].n, basis.states[-1].l) == (20, 19) + assert len(basis.states) == {"Rb": 388, "Na": 396, "H": 400}[species_name] + + state0 = basis.states[0] + ov = basis.calc_reduced_overlap(state0) + compare_ov = np.zeros(len(basis.states)) + compare_ov[0] = 1.0 + assert np.allclose(ov, compare_ov, atol=1e-3) + + me = basis.calc_reduced_matrix_element(state0, "electric_dipole", unit="e a0") + assert np.shape(me) == (len(basis.states),) + assert np.count_nonzero(me) > 0 + + basis.filter_states("n", (1, 7)) + ov_matrix = basis.calc_reduced_overlaps(basis) + assert np.allclose(ov_matrix, np.eye(len(basis.states)), atol=1e-3) + + me_matrix = basis.calc_reduced_matrix_elements(basis, "electric_dipole", unit="e a0") + assert np.shape(me_matrix) == (len(basis.states), len(basis.states)) + assert np.count_nonzero(me_matrix) > 0 + + +@pytest.mark.parametrize("species_name", ["Sr88", "Sr87", "Yb174", "Yb171"]) +def test_alkaline_basis(species_name: str) -> None: + """Test alkaline basis creation.""" + basis = BasisSQDTAlkalineLS(species_name, n_min=30, n_max=35) + basis.sort_states("n", "l_r") + assert (basis.states[0].n, basis.states[0].l) == (30, 0) + assert (basis.states[-1].n, basis.states[-1].l) == (35, 34) + assert len(basis.states) == {"Sr88": 768, "Sr87": 7188, "Yb174": 768, "Yb171": 1524}[species_name] + + if species_name in ["Sr87", "Yb171"]: + pytest.skip("Quantum defects for Sr87 and Yb171 not implemented yet.") + + state0 = basis.states[0] + ov = basis.calc_reduced_overlap(state0) + compare_ov = np.zeros(len(basis.states)) + compare_ov[0] = 1.0 + assert np.allclose(ov, compare_ov, atol=1e-3) + + me = basis.calc_reduced_matrix_element(state0, "electric_dipole", unit="e a0") + assert np.shape(me) == (len(basis.states),) + assert np.count_nonzero(me) > 0 + + basis.filter_states("l_r", (0, 2)) + ov_matrix = basis.calc_reduced_overlaps(basis) + assert np.allclose(ov_matrix, np.eye(len(basis.states)), atol=1e-2) + + me_matrix = basis.calc_reduced_matrix_elements(basis, "electric_dipole", unit="e a0") + assert np.shape(me_matrix) == (len(basis.states), len(basis.states)) + assert np.count_nonzero(me_matrix) > 0