Skip to content
3 changes: 3 additions & 0 deletions src/rydstate/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from rydstate import angular, radial, species
from rydstate.basis import BasisSQDTAlkali, BasisSQDTAlkalineLS
from rydstate.rydberg import (
RydbergStateSQDT,
RydbergStateSQDTAlkali,
Expand All @@ -8,6 +9,8 @@
from rydstate.units import ureg

__all__ = [
"BasisSQDTAlkali",
"BasisSQDTAlkalineLS",
"RydbergStateSQDT",
"RydbergStateSQDTAlkali",
"RydbergStateSQDTAlkalineJJ",
Expand Down
39 changes: 34 additions & 5 deletions src/rydstate/angular/angular_ket.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@

import logging
from abc import ABC
from typing import TYPE_CHECKING, Any, ClassVar, Literal, get_args, overload
from typing import TYPE_CHECKING, Any, ClassVar, Literal, overload

from rydstate.angular.angular_matrix_element import (
AngularMomentumQuantumNumbers,
AngularOperatorType,
calc_prefactor_of_operator_in_coupled_scheme,
calc_reduced_identity_matrix_element,
calc_reduced_spherical_matrix_element,
calc_reduced_spin_matrix_element,
is_angular_momentum_quantum_number,
is_angular_operator_type,
)
from rydstate.angular.utils import (
calc_wigner_3j,
Expand All @@ -26,6 +26,7 @@
if TYPE_CHECKING:
from typing_extensions import Self

from rydstate.angular.angular_matrix_element import AngularMomentumQuantumNumbers, AngularOperatorType
from rydstate.angular.angular_state import AngularState

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -186,6 +187,34 @@ def get_qn(self, qn: AngularMomentumQuantumNumbers) -> float:
raise ValueError(f"Quantum number {qn} not found in {self!r}.")
return getattr(self, qn) # type: ignore [no-any-return]

def calc_exp_qn(self, qn: AngularMomentumQuantumNumbers) -> float:
"""Calculate the expectation value of a quantum number qn.

If the quantum number is a good quantum number simply return it,
otherwise calculate it, see also AngularState.calc_exp_qn for more details.

Args:
qn: The quantum number to calculate the expectation value for.

"""
if qn in self.quantum_number_names:
return self.get_qn(qn)
return self.to_state().calc_exp_qn(qn)

def calc_std_qn(self, qn: AngularMomentumQuantumNumbers) -> float:
"""Calculate the standard deviation of a quantum number qn.

If the quantum number is a good quantum number return 0,
otherwise calculate the std, see also AngularState.calc_std_qn for more details.

Args:
qn: The quantum number to calculate the standard deviation for.

"""
if qn in self.quantum_number_names:
return 0
return self.to_state().calc_std_qn(qn)

@overload
def to_state(self, coupling_scheme: Literal["LS"]) -> AngularState[AngularKetLS]: ...

Expand Down Expand Up @@ -382,12 +411,12 @@ def calc_reduced_matrix_element( # noqa: C901
\left\langle self || \hat{O}^{(\kappa)} || other \right\rangle

"""
if operator not in get_args(AngularOperatorType):
if not is_angular_operator_type(operator):
raise NotImplementedError(f"calc_reduced_matrix_element is not implemented for operator {operator}.")

if type(self) is not type(other):
return self.to_state().calc_reduced_matrix_element(other.to_state(), operator, kappa)
if operator in get_args(AngularMomentumQuantumNumbers) and operator not in self.quantum_number_names:
if is_angular_momentum_quantum_number(operator) and operator not in self.quantum_number_names:
return self.to_state().calc_reduced_matrix_element(other.to_state(), operator, kappa)

qn_name: AngularMomentumQuantumNumbers
Expand Down
13 changes: 12 additions & 1 deletion src/rydstate/angular/angular_matrix_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

import math
from functools import lru_cache
from typing import TYPE_CHECKING, Callable, Literal, TypeVar
from typing import TYPE_CHECKING, Callable, Literal, TypeVar, get_args

import numpy as np
from typing_extensions import TypeGuard

from rydstate.angular.utils import calc_wigner_3j, calc_wigner_6j, minus_one_pow

Expand Down Expand Up @@ -41,6 +42,16 @@ def lru_cache(maxsize: int) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
]


def is_angular_momentum_quantum_number(qn: str) -> TypeGuard[AngularMomentumQuantumNumbers]:
"""Check if the given string is an AngularMomentumQuantumNumbers."""
return qn in get_args(AngularMomentumQuantumNumbers)


def is_angular_operator_type(qn: str) -> TypeGuard[AngularOperatorType]:
"""Check if the given string is an AngularOperatorType."""
return qn in get_args(AngularOperatorType)


@lru_cache(maxsize=10_000)
def calc_reduced_spherical_matrix_element(l_r_final: int, l_r_initial: int, kappa: int) -> float:
r"""Calculate the reduced spherical matrix element (l_r_final || \hat{Y}_{k} || l_r_initial).
Expand Down
14 changes: 8 additions & 6 deletions src/rydstate/angular/angular_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
import math
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, get_args, overload
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, overload

import numpy as np

Expand All @@ -12,15 +12,15 @@
AngularKetJJ,
AngularKetLS,
)
from rydstate.angular.angular_matrix_element import AngularMomentumQuantumNumbers
from rydstate.angular.angular_matrix_element import is_angular_momentum_quantum_number

if TYPE_CHECKING:
from collections.abc import Iterator
from collections.abc import Iterator, Sequence

from typing_extensions import Self

from rydstate.angular.angular_ket import CouplingScheme
from rydstate.angular.angular_matrix_element import AngularOperatorType
from rydstate.angular.angular_matrix_element import AngularMomentumQuantumNumbers, AngularOperatorType

logger = logging.getLogger(__name__)

Expand All @@ -30,13 +30,15 @@

class AngularState(Generic[_AngularKet]):
def __init__(
self, coefficients: list[float], kets: list[_AngularKet], *, warn_if_not_normalized: bool = True
self, coefficients: Sequence[float], kets: Sequence[_AngularKet], *, warn_if_not_normalized: bool = True
) -> None:
self.coefficients = np.array(coefficients)
self.kets = kets

if len(coefficients) != len(kets):
raise ValueError("Length of coefficients and kets must be the same.")
if len(kets) == 0:
raise ValueError("At least one ket must be provided.")
if not all(type(ket) is type(kets[0]) for ket in kets):
raise ValueError("All kets must be of the same type.")
if len(set(kets)) != len(kets):
Expand Down Expand Up @@ -164,7 +166,7 @@ def calc_reduced_matrix_element(
"""
if isinstance(other, AngularKetBase):
other = other.to_state()
if operator in get_args(AngularMomentumQuantumNumbers) and operator not in self.kets[0].quantum_number_names:
if is_angular_momentum_quantum_number(operator) and operator not in self.kets[0].quantum_number_names:
for ket_class in [AngularKetLS, AngularKetJJ, AngularKetFJ]:
if operator in ket_class.quantum_number_names:
return self.to(ket_class.coupling_scheme).calc_reduced_matrix_element(other, operator, kappa)
Expand Down
3 changes: 3 additions & 0 deletions src/rydstate/basis/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from rydstate.basis.basis_sqdt import BasisSQDTAlkali, BasisSQDTAlkalineLS

__all__ = ["BasisSQDTAlkali", "BasisSQDTAlkalineLS"]
143 changes: 143 additions & 0 deletions src/rydstate/basis/basis_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
from __future__ import annotations

from abc import ABC
from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload

import numpy as np
from typing_extensions import Self

from rydstate.angular.angular_matrix_element import is_angular_momentum_quantum_number
from rydstate.rydberg.rydberg_base import RydbergStateBase
from rydstate.species.species_object import SpeciesObject
from rydstate.units import ureg

if TYPE_CHECKING:
from rydstate.units import MatrixElementOperator, NDArray, PintArray, PintFloat

_RydbergState = TypeVar("_RydbergState", bound=RydbergStateBase)


class BasisBase(ABC, Generic[_RydbergState]):
states: list[_RydbergState]

def __init__(self, species: str | SpeciesObject) -> None:
if isinstance(species, str):
species = SpeciesObject.from_name(species)
self.species = species

def __len__(self) -> int:
return len(self.states)

def copy(self) -> Self:
new_basis = self.__class__.__new__(self.__class__)
new_basis.species = self.species
new_basis.states = list(self.states)
return new_basis

@overload
def filter_states(self, qn: str, value: tuple[float, float], *, delta: float = 1e-10) -> Self: ...

@overload
def filter_states(self, qn: str, value: float, *, delta: float = 1e-10) -> Self: ...

def filter_states(self, qn: str, value: float | tuple[float, float], *, delta: float = 1e-10) -> Self:
if isinstance(value, tuple):
qn_min = value[0] - delta
qn_max = value[1] + delta
else:
qn_min = value - delta
qn_max = value + delta

if is_angular_momentum_quantum_number(qn):
self.states = [state for state in self.states if qn_min <= state.angular.calc_exp_qn(qn) <= qn_max]
elif qn in ["n", "nu", "nu_energy"]:
self.states = [state for state in self.states if qn_min <= getattr(state, qn) <= qn_max]
else:
raise ValueError(f"Unknown quantum number {qn}")

return self

def sort_states(self, *qns: str) -> Self:
"""Sort the basis states according to the given quantum numbers.

The first quantum number given is the primary sorting key, the second quantum number
is the secondary sorting key, and so on.
"""
values = np.array([self.calc_exp_qn(qn) for qn in qns])
sorted_indices = np.lexsort(values[::-1])
self.states = [self.states[i] for i in sorted_indices]
return self

def calc_exp_qn(self, qn: str) -> list[float]:
if is_angular_momentum_quantum_number(qn):
return [state.angular.calc_exp_qn(qn) for state in self.states]
if qn in ["n", "nu", "nu_energy"]:
return [getattr(state, qn) for state in self.states]
raise ValueError(f"Unknown quantum number {qn}")

def calc_std_qn(self, qn: str) -> list[float]:
if is_angular_momentum_quantum_number(qn):
return [state.angular.calc_std_qn(qn) for state in self.states]
if qn in ["n", "nu", "nu_energy"]:
return [0 for state in self.states]
raise ValueError(f"Unknown quantum number {qn}")

def calc_reduced_overlap(self, other: RydbergStateBase) -> NDArray:
"""Calculate the reduced overlap <self|other> (ignoring the magnetic quantum number m)."""
return np.array([bra.calc_reduced_overlap(other) for bra in self.states])

def calc_reduced_overlaps(self, other: BasisBase[Any]) -> NDArray:
"""Calculate the reduced overlap <bra|ket> for all states in the bases self and other.

Returns a numpy array overlaps, where overlaps[i,j] corresponds to the overlap of the
i-th state of self and the j-th state of other.
"""
return np.array([[bra.calc_reduced_overlap(ket) for ket in other.states] for bra in self.states])

@overload
def calc_reduced_matrix_element(
self, other: RydbergStateBase, operator: MatrixElementOperator, unit: None = None
) -> PintArray: ...

@overload
def calc_reduced_matrix_element(
self, other: RydbergStateBase, operator: MatrixElementOperator, unit: str
) -> NDArray: ...

def calc_reduced_matrix_element(
self, other: RydbergStateBase, operator: MatrixElementOperator, unit: str | None = None
) -> PintArray | NDArray:
r"""Calculate the reduced matrix element."""
values_list = [bra.calc_reduced_matrix_element(other, operator, unit=unit) for bra in self.states]
if unit is not None:
return np.array(values_list)

values: list[PintFloat] = values_list # type: ignore[assignment]
_unit = values[0].units
_values = np.array([v.magnitude for v in values])
return ureg.Quantity(_values, _unit)

@overload
def calc_reduced_matrix_elements(
self, other: BasisBase[Any], operator: MatrixElementOperator, unit: None = None
) -> PintArray: ...

@overload
def calc_reduced_matrix_elements(
self, other: BasisBase[Any], operator: MatrixElementOperator, unit: str
) -> NDArray: ...

def calc_reduced_matrix_elements(
self, other: BasisBase[Any], operator: MatrixElementOperator, unit: str | None = None
) -> PintArray | NDArray:
r"""Calculate the reduced matrix element."""
values_list = [
[bra.calc_reduced_matrix_element(ket, operator, unit=unit) for ket in other.states] for bra in self.states
]
if unit is not None:
return np.array(values_list)

values: list[list[PintFloat]] = values_list # type: ignore[assignment]
_unit = values[0][0].units
_values = np.array([[v.magnitude for v in vs] for vs in values])
return ureg.Quantity(_values, _unit)
50 changes: 50 additions & 0 deletions src/rydstate/basis/basis_sqdt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from __future__ import annotations

import numpy as np

from rydstate.basis.basis_base import BasisBase
from rydstate.rydberg import RydbergStateSQDTAlkali, RydbergStateSQDTAlkalineLS


class BasisSQDTAlkali(BasisBase[RydbergStateSQDTAlkali]):
def __init__(self, species: str, n_min: int = 1, n_max: int | None = None) -> None:
super().__init__(species)

if n_max is None:
raise ValueError("n_max must be given")

s = 1 / 2
i_c = self.species.i_c if self.species.i_c is not None else 0

self.states = []
for n in range(n_min, n_max + 1):
for l in range(n):
if not self.species.is_allowed_shell(n, l, s):
continue
for j in np.arange(abs(l - s), l + s + 1):
for f in np.arange(abs(j - i_c), j + i_c + 1):
state = RydbergStateSQDTAlkali(species, n=n, l=l, j=float(j), f=float(f))
self.states.append(state)


class BasisSQDTAlkalineLS(BasisBase[RydbergStateSQDTAlkalineLS]):
def __init__(self, species: str, n_min: int = 1, n_max: int | None = None) -> None:
super().__init__(species)

if n_max is None:
raise ValueError("n_max must be given")

i_c = self.species.i_c if self.species.i_c is not None else 0

self.states = []
for n in range(n_min, n_max + 1):
for l in range(n):
for s_tot in [0, 1]:
if not self.species.is_allowed_shell(n, l, s_tot):
continue
for j_tot in range(abs(l - s_tot), l + s_tot + 1):
for f_tot in np.arange(abs(j_tot - i_c), j_tot + i_c + 1):
state = RydbergStateSQDTAlkalineLS(
species, n=n, l=l, s_tot=s_tot, j_tot=j_tot, f_tot=float(f_tot)
)
self.states.append(state)
3 changes: 2 additions & 1 deletion src/rydstate/rydberg/rydberg_sqdt.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,8 @@ def __init__(

def __repr__(self) -> str:
species, n, l, j, f, m = self.species, self.n, self.l, self.j, self.f, self.m
return f"{self.__class__.__name__}({species.name}, {n=}, {l=}, {j=}, {f=}, {m=})"
f_string = f", {f=}" if self.species.i_c not in (None, 0) else ""
return f"{self.__class__.__name__}({species.name}, {n=}, {l=}, {j=}{f_string}, {m=})"


class RydbergStateSQDTAlkalineLS(RydbergStateSQDT):
Expand Down
Loading