diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index beba35c..2986e6f 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -10,7 +10,7 @@ jobs: pytest-with-nbmake: strategy: matrix: - python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + python-version: ["3.10", "3.11", "3.12", "3.13"] runs-on: ubuntu-latest timeout-minutes: 30 steps: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 709d417..ea5120d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,20 +9,20 @@ repos: - id: check-merge-conflict - repo: https://github.com/Lucas-C/pre-commit-hooks - rev: v1.5.5 + rev: v1.5.6 hooks: - id: remove-tabs exclude: 'Makefile|nist_energy_levels/.*\.txt$' - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.14.0 + rev: v0.15.5 hooks: - id: ruff args: ["--fix", "--exit-non-zero-on-fix"] - id: ruff-format - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.18.2 + rev: v1.19.1 hooks: - id: mypy additional_dependencies: ["numpy >= 2.0", "pint >= 0.25.1"] diff --git a/README.md b/README.md index 15a6aee..309c807 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ The *RydState* software calculates properties of Rydberg states. We especially focus on the calculation of the radial wavefunction of Rydberg states via the Numerov method. -The software can be installed via pip (requires Python >= 3.9): +The software can be installed via pip (requires Python >= 3.10): ```bash pip install rydstate diff --git a/pyproject.toml b/pyproject.toml index c567b4b..53eb18e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,6 @@ classifiers = [ "Programming Language :: Python", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", @@ -33,7 +32,7 @@ classifiers = [ "Topic :: Scientific/Engineering :: Physics", "Typing :: Typed", ] -requires-python = ">= 3.9" +requires-python = ">= 3.10" dependencies = [ "numpy >= 2.0", "numba >= 0.60", @@ -47,7 +46,6 @@ dependencies = [ tests = [ "pytest >= 8.0", "nbmake >= 1.3", - "rydstate[comparison]", ] docs = [ "sphinx >= 7", @@ -72,7 +70,7 @@ mypy = [ [dependency-groups] dev = [ - "rydstate[docs,tests,comparison,jupyter,mypy]", + "rydstate[docs,tests,jupyter,mypy]", "check-wheel-contents >= 0.6", ] @@ -108,7 +106,7 @@ addopts = [ [tool.ruff] line-length = 120 -target-version = "py39" +target-version = "py310" extend-include = ["*.ipynb"] [tool.ruff.lint] diff --git a/src/rydstate/angular/angular_ket.py b/src/rydstate/angular/angular_ket.py index a21b2ce..45a1c62 100644 --- a/src/rydstate/angular/angular_ket.py +++ b/src/rydstate/angular/angular_ket.py @@ -1,6 +1,5 @@ from __future__ import annotations -import contextlib import logging from abc import ABC from typing import TYPE_CHECKING, Any, ClassVar, Literal, overload @@ -14,32 +13,26 @@ is_angular_operator_type, ) from rydstate.angular.utils import ( + InvalidQuantumNumbersError, check_spin_addition_rule, get_possible_quantum_number_values, minus_one_pow, try_trivial_spin_addition, ) from rydstate.angular.wigner_symbols import calc_wigner_3j, clebsch_gordan_6j, clebsch_gordan_9j -from rydstate.species import SpeciesObject if TYPE_CHECKING: + from collections.abc import Sequence + from typing_extensions import Self from rydstate.angular.angular_matrix_element import AngularMomentumQuantumNumbers, AngularOperatorType from rydstate.angular.angular_state import AngularState + from rydstate.angular.utils import CouplingScheme + from rydstate.species import SpeciesObject logger = logging.getLogger(__name__) -CouplingScheme = Literal["LS", "JJ", "FJ"] - - -class InvalidQuantumNumbersError(ValueError): - def __init__(self, ket: AngularKetBase, msg: str = "") -> None: - _msg = f"Invalid quantum numbers for {ket!r}" - if len(msg) > 0: - _msg += f"\n {msg}" - super().__init__(_msg) - class AngularKetBase(ABC): """Base class for a angular ket (i.e. a simple canonical spin ketstate).""" @@ -92,12 +85,13 @@ def __init__( ) -> None: """Initialize the Spin ket. - species: - Atomic species, e.g. 'Rb87'. - Not used for calculation, only for convenience to infer the core electron spin and nuclear spin quantum numbers. + Atomic species, e.g. 'Rb87', will not be used for calculation, + only for convenience to infer the core electron spin and nuclear spin quantum numbers. """ if species is not None: if isinstance(species, str): + from rydstate.species import SpeciesObject # noqa: PLC0415 + species = SpeciesObject.from_name(species) # use i_c = 0 for species without defined nuclear spin (-> ignore hyperfine) species_i_c = species.i_c if species.i_c is not None else 0 @@ -154,7 +148,7 @@ def __setattr__(self, key: str, value: object) -> None: super().__setattr__(key, value) def __repr__(self) -> str: - args = ", ".join(f"{qn}={val}" for qn, val in zip(self.quantum_number_names, self.quantum_numbers)) + args = ", ".join(f"{qn}={val}" for qn, val in zip(self.quantum_number_names, self.quantum_numbers, strict=True)) if self.m is not None: args += f", m={self.m}" return f"{self.__class__.__name__}({args})" @@ -237,10 +231,8 @@ def to_state(self, coupling_scheme: CouplingScheme | None = None) -> AngularStat The angular state in the specified coupling scheme. """ - from rydstate.angular.angular_state import AngularState # noqa: PLC0415 - if coupling_scheme is None or coupling_scheme == self.coupling_scheme: - return AngularState([1], [self]) + return self._create_angular_state([1], [self]) if coupling_scheme == "LS": return self._to_state_ls() if coupling_scheme == "JJ": @@ -280,9 +272,7 @@ def _to_state_ls(self) -> AngularState[AngularKetLS]: kets.append(ls_ket) coefficients.append(coeff) - from rydstate.angular.angular_state import AngularState # noqa: PLC0415 - - return AngularState(coefficients, kets) + return self._create_angular_state(coefficients, kets) def _to_state_jj(self) -> AngularState[AngularKetJJ]: """Convert a single ket to state in JJ coupling.""" @@ -315,9 +305,7 @@ def _to_state_jj(self) -> AngularState[AngularKetJJ]: kets.append(jj_ket) coefficients.append(coeff) - from rydstate.angular.angular_state import AngularState # noqa: PLC0415 - - return AngularState(coefficients, kets) + return self._create_angular_state(coefficients, kets) def _to_state_fj(self) -> AngularState[AngularKetFJ]: """Convert a single ket to state in FJ coupling.""" @@ -350,6 +338,10 @@ def _to_state_fj(self) -> AngularState[AngularKetFJ]: kets.append(fj_ket) coefficients.append(coeff) + return self._create_angular_state(coefficients, kets) + + def _create_angular_state(self, coefficients: Sequence[float], kets: Sequence[AngularKetBase]) -> AngularState[Any]: + """Create an AngularState from coefficients and kets.""" from rydstate.angular.angular_state import AngularState # noqa: PLC0415 return AngularState(coefficients, kets) @@ -736,55 +728,3 @@ def sanity_check(self, msgs: list[str] | None = None) -> None: msgs.append(f"{self.f_c=}, {self.j_r=}, {self.f_tot=} don't satisfy spin addition rule.") super().sanity_check(msgs) - - -def quantum_numbers_to_angular_ket( - species: str | SpeciesObject, - s_c: float | None = None, - l_c: int = 0, - j_c: float | None = None, - f_c: float | None = None, - s_r: float = 0.5, - l_r: int | None = None, - j_r: float | None = None, - s_tot: float | None = None, - l_tot: int | None = None, - j_tot: float | None = None, - f_tot: float | None = None, - m: float | None = None, -) -> AngularKetBase: - r"""Return an AngularKet object in the corresponding coupling scheme from the given quantum numbers. - - Args: - species: Atomic species. - s_c: Spin quantum number of the core electron (0 for Alkali, 0.5 for divalent atoms). - l_c: Orbital angular momentum quantum number of the core electron. - j_c: Total angular momentum quantum number of the core electron. - f_c: Total angular momentum quantum number of the core (core electron + nucleus). - s_r: Spin quantum number of the rydberg electron (always 0.5). - l_r: Orbital angular momentum quantum number of the rydberg electron. - j_r: Total angular momentum quantum number of the rydberg electron. - s_tot: Total spin quantum number of all electrons. - l_tot: Total orbital angular momentum quantum number of all electrons. - j_tot: Total angular momentum quantum number of all electrons. - f_tot: Total angular momentum quantum number of the atom (rydberg electron + core). - m: Total magnetic quantum number. - Optional, only needed for concrete angular matrix elements. - - """ - with contextlib.suppress(InvalidQuantumNumbersError, ValueError): - return AngularKetLS( - s_c=s_c, l_c=l_c, s_r=s_r, l_r=l_r, s_tot=s_tot, l_tot=l_tot, j_tot=j_tot, f_tot=f_tot, m=m, species=species - ) - - with contextlib.suppress(InvalidQuantumNumbersError, ValueError): - return AngularKetJJ( - s_c=s_c, l_c=l_c, j_c=j_c, s_r=s_r, l_r=l_r, j_r=j_r, j_tot=j_tot, f_tot=f_tot, m=m, species=species - ) - - with contextlib.suppress(InvalidQuantumNumbersError, ValueError): - return AngularKetFJ( - s_c=s_c, l_c=l_c, j_c=j_c, f_c=f_c, s_r=s_r, l_r=l_r, j_r=j_r, f_tot=f_tot, m=m, species=species - ) - - raise ValueError("Invalid combination of angular quantum numbers provided.") diff --git a/src/rydstate/angular/angular_matrix_element.py b/src/rydstate/angular/angular_matrix_element.py index e50839b..36f1fc0 100644 --- a/src/rydstate/angular/angular_matrix_element.py +++ b/src/rydstate/angular/angular_matrix_element.py @@ -2,15 +2,16 @@ import math from functools import lru_cache -from typing import TYPE_CHECKING, Callable, Literal, TypeVar, get_args +from typing import TYPE_CHECKING, Literal, TypeGuard, TypeVar, get_args import numpy as np -from typing_extensions import TypeGuard from rydstate.angular.utils import minus_one_pow from rydstate.angular.wigner_symbols import calc_wigner_3j, calc_wigner_6j if TYPE_CHECKING: + from collections.abc import Callable + from typing_extensions import ParamSpec P = ParamSpec("P") diff --git a/src/rydstate/angular/angular_state.py b/src/rydstate/angular/angular_state.py index e3978fb..5ec8cfc 100644 --- a/src/rydstate/angular/angular_state.py +++ b/src/rydstate/angular/angular_state.py @@ -19,8 +19,9 @@ from typing_extensions import Self - from rydstate.angular.angular_ket import CouplingScheme from rydstate.angular.angular_matrix_element import AngularMomentumQuantumNumbers, AngularOperatorType + from rydstate.angular.utils import CouplingScheme + from rydstate.units import NDArray logger = logging.getLogger(__name__) @@ -30,7 +31,11 @@ class AngularState(Generic[_AngularKet]): def __init__( - self, coefficients: Sequence[float], kets: Sequence[_AngularKet], *, warn_if_not_normalized: bool = True + self, + coefficients: Sequence[float] | NDArray, + kets: Sequence[_AngularKet], + *, + warn_if_not_normalized: bool = True, ) -> None: self.coefficients = np.array(coefficients) self.kets = kets @@ -49,7 +54,7 @@ def __init__( self.coefficients /= self.norm def __iter__(self) -> Iterator[tuple[float, _AngularKet]]: - return zip(self.coefficients, self.kets).__iter__() + return zip(self.coefficients, self.kets, strict=True).__iter__() def __repr__(self) -> str: terms = [f"{coeff}*{ket!r}" for coeff, ket in self] diff --git a/src/rydstate/angular/utils.py b/src/rydstate/angular/utils.py index ecea008..fa41889 100644 --- a/src/rydstate/angular/utils.py +++ b/src/rydstate/angular/utils.py @@ -1,7 +1,24 @@ from __future__ import annotations +import contextlib +from typing import TYPE_CHECKING, Literal + import numpy as np +if TYPE_CHECKING: + from rydstate.angular.angular_ket import AngularKetBase + from rydstate.species.species_object import SpeciesObject + +CouplingScheme = Literal["LS", "JJ", "FJ"] + + +class InvalidQuantumNumbersError(ValueError): + def __init__(self, ket: AngularKetBase, msg: str = "") -> None: + _msg = f"Invalid quantum numbers for {ket!r}" + if len(msg) > 0: + _msg += f"\n {msg}" + super().__init__(_msg) + def minus_one_pow(n: float) -> int: """Calculate (-1)^n for an integer n and raise an error if n is not an integer.""" @@ -42,3 +59,57 @@ def get_possible_quantum_number_values(s_1: float, s_2: float, s_tot: float | No if s_tot is not None: return [s_tot] return [float(s) for s in np.arange(abs(s_1 - s_2), s_1 + s_2 + 1, 1)] + + +def quantum_numbers_to_angular_ket( + species: str | SpeciesObject, + s_c: float | None = None, + l_c: int = 0, + j_c: float | None = None, + f_c: float | None = None, + s_r: float = 0.5, + l_r: int | None = None, + j_r: float | None = None, + s_tot: float | None = None, + l_tot: int | None = None, + j_tot: float | None = None, + f_tot: float | None = None, + m: float | None = None, +) -> AngularKetBase: + r"""Return an AngularKet object in the corresponding coupling scheme from the given quantum numbers. + + Args: + species: Atomic species. + s_c: Spin quantum number of the core electron (0 for Alkali, 0.5 for divalent atoms). + l_c: Orbital angular momentum quantum number of the core electron. + j_c: Total angular momentum quantum number of the core electron. + f_c: Total angular momentum quantum number of the core (core electron + nucleus). + s_r: Spin quantum number of the rydberg electron (always 0.5). + l_r: Orbital angular momentum quantum number of the rydberg electron. + j_r: Total angular momentum quantum number of the rydberg electron. + s_tot: Total spin quantum number of all electrons. + l_tot: Total orbital angular momentum quantum number of all electrons. + j_tot: Total angular momentum quantum number of all electrons. + f_tot: Total angular momentum quantum number of the atom (rydberg electron + core). + m: Total magnetic quantum number. + Optional, only needed for concrete angular matrix elements. + + """ + from rydstate.angular.angular_ket import AngularKetFJ, AngularKetJJ, AngularKetLS # noqa: PLC0415 + + with contextlib.suppress(InvalidQuantumNumbersError, ValueError): + return AngularKetLS( + s_c=s_c, l_c=l_c, s_r=s_r, l_r=l_r, s_tot=s_tot, l_tot=l_tot, j_tot=j_tot, f_tot=f_tot, m=m, species=species + ) + + with contextlib.suppress(InvalidQuantumNumbersError, ValueError): + return AngularKetJJ( + s_c=s_c, l_c=l_c, j_c=j_c, s_r=s_r, l_r=l_r, j_r=j_r, j_tot=j_tot, f_tot=f_tot, m=m, species=species + ) + + with contextlib.suppress(InvalidQuantumNumbersError, ValueError): + return AngularKetFJ( + s_c=s_c, l_c=l_c, j_c=j_c, f_c=f_c, s_r=s_r, l_r=l_r, j_r=j_r, f_tot=f_tot, m=m, species=species + ) + + raise ValueError("Invalid combination of angular quantum numbers provided.") diff --git a/src/rydstate/angular/wigner_symbols.py b/src/rydstate/angular/wigner_symbols.py index c1ef37c..11bf67c 100644 --- a/src/rydstate/angular/wigner_symbols.py +++ b/src/rydstate/angular/wigner_symbols.py @@ -2,7 +2,7 @@ import math from functools import lru_cache, wraps -from typing import TYPE_CHECKING, Callable, TypeVar +from typing import TYPE_CHECKING, TypeVar from sympy import Integer from sympy.physics.wigner import ( @@ -14,6 +14,8 @@ from rydstate.angular.utils import minus_one_pow if TYPE_CHECKING: + from collections.abc import Callable + from typing_extensions import ParamSpec P = ParamSpec("P") diff --git a/src/rydstate/basis/basis_base.py b/src/rydstate/basis/basis_base.py index e6fe5d2..3618571 100644 --- a/src/rydstate/basis/basis_base.py +++ b/src/rydstate/basis/basis_base.py @@ -50,7 +50,7 @@ def filter_states(self, qn: str, value: float | tuple[float, float], *, delta: f if is_angular_momentum_quantum_number(qn): self.states = [state for state in self.states if qn_min <= state.angular.calc_exp_qn(qn) <= qn_max] - elif qn in ["n", "nu", "nu_energy"]: + elif qn in ["n", "nu", "nu_ref"]: self.states = [state for state in self.states if qn_min <= getattr(state, qn) <= qn_max] else: raise ValueError(f"Unknown quantum number {qn}") @@ -68,18 +68,18 @@ def sort_states(self, *qns: str) -> Self: self.states = [self.states[i] for i in sorted_indices] return self - def calc_exp_qn(self, qn: str) -> list[float]: + def calc_exp_qn(self, qn: str) -> NDArray: if is_angular_momentum_quantum_number(qn): - return [state.angular.calc_exp_qn(qn) for state in self.states] - if qn in ["n", "nu", "nu_energy"]: - return [getattr(state, qn) for state in self.states] + return np.array([state.angular.calc_exp_qn(qn) for state in self.states]) + if qn in ["n", "nu", "nu_ref"]: + return np.array([getattr(state, qn) for state in self.states]) raise ValueError(f"Unknown quantum number {qn}") - def calc_std_qn(self, qn: str) -> list[float]: + def calc_std_qn(self, qn: str) -> NDArray: if is_angular_momentum_quantum_number(qn): - return [state.angular.calc_std_qn(qn) for state in self.states] - if qn in ["n", "nu", "nu_energy"]: - return [0 for state in self.states] + return np.array([state.angular.calc_std_qn(qn) for state in self.states]) + if qn in ["n", "nu", "nu_ref"]: + return np.zeros(len(self.states)) raise ValueError(f"Unknown quantum number {qn}") def calc_reduced_overlap(self, other: RydbergStateBase) -> NDArray: diff --git a/src/rydstate/basis/basis_sqdt.py b/src/rydstate/basis/basis_sqdt.py index ced0b2d..fac839f 100644 --- a/src/rydstate/basis/basis_sqdt.py +++ b/src/rydstate/basis/basis_sqdt.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +from typing import TYPE_CHECKING import numpy as np @@ -12,11 +13,14 @@ RydbergStateSQDTAlkalineLS, ) +if TYPE_CHECKING: + from rydstate.species.species_object import SpeciesObject + logger = logging.getLogger(__name__) class BasisSQDTAlkali(BasisBase[RydbergStateSQDTAlkali]): - def __init__(self, species: str, n_min: int = 1, n_max: int | None = None) -> None: + def __init__(self, species: str | SpeciesObject, n_min: int = 1, n_max: int | None = None) -> None: super().__init__(species) if n_max is None: @@ -37,7 +41,7 @@ def __init__(self, species: str, n_min: int = 1, n_max: int | None = None) -> No class BasisSQDTAlkalineLS(BasisBase[RydbergStateSQDTAlkalineLS]): - def __init__(self, species: str, n_min: int = 1, n_max: int | None = None) -> None: + def __init__(self, species: str | SpeciesObject, n_min: int = 1, n_max: int | None = None) -> None: super().__init__(species) if n_max is None: @@ -60,7 +64,7 @@ def __init__(self, species: str, n_min: int = 1, n_max: int | None = None) -> No class BasisSQDTAlkalineJJ(BasisBase[RydbergStateSQDTAlkalineJJ]): - def __init__(self, species: str, n_min: int = 0, n_max: int | None = None) -> None: + def __init__(self, species: str | SpeciesObject, n_min: int = 0, n_max: int | None = None) -> None: super().__init__(species) if n_max is None: @@ -90,7 +94,7 @@ def __init__(self, species: str, n_min: int = 0, n_max: int | None = None) -> No class BasisSQDTAlkalineFJ(BasisBase[RydbergStateSQDTAlkalineFJ]): - def __init__(self, species: str, n_min: int = 0, n_max: int | None = None) -> None: + def __init__(self, species: str | SpeciesObject, n_min: int = 0, n_max: int | None = None) -> None: super().__init__(species) if n_max is None: diff --git a/src/rydstate/radial/numerov.py b/src/rydstate/radial/numerov.py index 5cfb8a6..f5ff2f3 100644 --- a/src/rydstate/radial/numerov.py +++ b/src/rydstate/radial/numerov.py @@ -1,11 +1,11 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING from numba import njit if TYPE_CHECKING: - from collections.abc import Sequence + from collections.abc import Callable, Sequence from rydstate.units import NDArray diff --git a/src/rydstate/radial/wavefunction.py b/src/rydstate/radial/wavefunction.py index 36db1ca..43e3ee1 100644 --- a/src/rydstate/radial/wavefunction.py +++ b/src/rydstate/radial/wavefunction.py @@ -3,7 +3,7 @@ import logging import math from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Literal, Optional +from typing import TYPE_CHECKING, Literal import numpy as np from mpmath import whitw @@ -19,7 +19,7 @@ logger = logging.getLogger(__name__) -WavefunctionSignConvention = Optional[Literal["positive_at_outer_bound", "n_l_1"]] +WavefunctionSignConvention = Literal["positive_at_outer_bound", "n_l_1"] | None class Wavefunction(ABC): diff --git a/src/rydstate/rydberg/rydberg_base.py b/src/rydstate/rydberg/rydberg_base.py index 0eeec28..6aeab66 100644 --- a/src/rydstate/rydberg/rydberg_base.py +++ b/src/rydstate/rydberg/rydberg_base.py @@ -22,3 +22,8 @@ def calc_reduced_overlap(self, other: RydbergStateBase) -> float: ... def calc_reduced_matrix_element( self, other: RydbergStateBase, operator: MatrixElementOperator, unit: str | None = None ) -> PintFloat | float: ... + + @property + @abstractmethod + def nu_ref(self) -> float: + """The reference effective principal quantum number nu_ref.""" diff --git a/src/rydstate/rydberg/rydberg_sqdt.py b/src/rydstate/rydberg/rydberg_sqdt.py index 744b582..5fe65ec 100644 --- a/src/rydstate/rydberg/rydberg_sqdt.py +++ b/src/rydstate/rydberg/rydberg_sqdt.py @@ -7,7 +7,7 @@ import numpy as np -from rydstate.angular.angular_ket import quantum_numbers_to_angular_ket +from rydstate.angular.utils import quantum_numbers_to_angular_ket from rydstate.radial import RadialKet from rydstate.rydberg.rydberg_base import RydbergStateBase from rydstate.species import SpeciesObject @@ -15,6 +15,8 @@ from rydstate.units import BaseQuantities, MatrixElementOperatorRanks, ureg if TYPE_CHECKING: + from typing_extensions import Self + from rydstate.angular.angular_ket import AngularKetBase, AngularKetFJ, AngularKetJJ, AngularKetLS from rydstate.units import MatrixElementOperator, PintFloat @@ -94,14 +96,19 @@ def __init__( if nu is None and n is None: raise ValueError("Either n or nu must be given to initialize the Rydberg state.") + self._set_qn_as_attributes() + + def _set_qn_as_attributes(self) -> None: + pass + @classmethod def from_angular_ket( - cls, + cls: type[Self], species: str | SpeciesObject, angular_ket: AngularKetBase, n: int | None = None, nu: float | None = None, - ) -> RydbergStateSQDT: + ) -> Self: """Initialize the Rydberg state from an angular ket.""" obj = cls.__new__(cls) @@ -115,13 +122,14 @@ def from_angular_ket( raise ValueError("Either n or nu must be given to initialize the Rydberg state.") obj.angular = angular_ket + obj._set_qn_as_attributes() # noqa: SLF001 return obj def __repr__(self) -> str: species, n, nu = self.species.name, self.n, self.nu n_str = f", {n=}" if n is not None else "" - return f"{self.__class__.__name__}({species=}{n_str}, {nu=}, {self.angular})" + return f"{self.__class__.__name__}({species}{n_str}, {nu=}, {self.angular})" def __str__(self) -> str: return self.__repr__() @@ -151,13 +159,19 @@ def nu(self) -> float: """The effective principal quantum number nu (for alkali atoms also known as n*) for the Rydberg state.""" if self._nu is not None: return self._nu - assert self.n is not None - if any(qn not in self.angular.quantum_number_names for qn in ["j_tot", "s_tot"]): - raise ValueError("j_tot and s_tot must be defined to calculate nu from n.") + assert isinstance(self.species, SpeciesObject), "nu must be given if not sqdt" + assert self.n is not None, "either nu or n must be given" + + if "j_tot" not in self.angular.quantum_number_names or "s_tot" not in self.angular.quantum_number_names: + raise RuntimeError("j_tot and s_tot must be defined in the angular ket to calculate nu from n.") return self.species.calc_nu( self.n, self.angular.l_r, self.angular.get_qn("j_tot"), s_tot=self.angular.get_qn("s_tot") ) + @property + def nu_ref(self) -> float: + return self.nu + @overload def get_energy(self, unit: None = None) -> PintFloat: ... @@ -342,15 +356,18 @@ def __init__( """ super().__init__(species=species, n=n, nu=nu, l_r=l, j_tot=j, f_tot=f, m=m) - self.l = l + def _set_qn_as_attributes(self) -> None: + self.l = self.angular.l_r self.j = self.angular.j_tot self.f = self.angular.f_tot - self.m = m + self.m = self.angular.m def __repr__(self) -> str: - species, n, l, j, f, m = self.species, self.n, self.l, self.j, self.f, self.m + species, n, nu = self.species.name, self.n, self.nu + l, j, f, m = self.l, self.j, self.f, self.m + n_str = f", {n=}" if n is not None else "" f_string = f", {f=}" if self.species.i_c not in (None, 0) else "" - return f"{self.__class__.__name__}({species.name}, {n=}, {l=}, {j=}{f_string}, {m=})" + return f"{self.__class__.__name__}({species}{n_str}, {nu=}, {l=}, {j=}{f_string}, {m=})" class RydbergStateSQDTAlkalineLS(RydbergStateSQDT): @@ -388,15 +405,18 @@ def __init__( """ super().__init__(species=species, n=n, nu=nu, l_r=l, s_tot=s_tot, j_tot=j_tot, f_tot=f_tot, m=m) - self.l = l + def _set_qn_as_attributes(self) -> None: + self.l = self.angular.l_r self.s_tot = self.angular.s_tot self.j_tot = self.angular.j_tot self.f_tot = self.angular.f_tot - self.m = m + self.m = self.angular.m def __repr__(self) -> str: - species, n, l, s_tot, j_tot, f_tot, m = self.species, self.n, self.l, self.s_tot, self.j_tot, self.f_tot, self.m - return f"{self.__class__.__name__}({species.name}, {n=}, {l=}, {s_tot=}, {j_tot=}, {f_tot=}, {m=})" + species, n, nu = self.species.name, self.n, self.nu + l, s_tot, j_tot, f_tot, m = self.l, self.s_tot, self.j_tot, self.f_tot, self.m + n_str = f", {n=}" if n is not None else "" + return f"{self.__class__.__name__}({species}{n_str}, {nu=}, {l=}, {s_tot=}, {j_tot=}, {f_tot=}, {m=})" class RydbergStateSQDTAlkalineJJ(RydbergStateSQDT): @@ -434,6 +454,7 @@ def __init__( """ super().__init__(species=species, n=n, nu=nu, l_r=l, j_r=j_r, j_tot=j_tot, f_tot=f_tot, m=m) + def _set_qn_as_attributes(self) -> None: self.l = self.angular.l_r self.j_r = self.angular.j_r self.j_tot = self.angular.j_tot @@ -441,22 +462,10 @@ def __init__( self.m = self.angular.m def __repr__(self) -> str: - species, n, l, j_r, j_tot, f_tot, m = self.species, self.n, self.l, self.j_r, self.j_tot, self.f_tot, self.m - return f"{self.__class__.__name__}({species.name}, {n=}, {l=}, {j_r=}, {j_tot=}, {f_tot=}, {m=})" - - @cached_property - def nu(self) -> float: - if self._nu is not None: - return self._nu - assert self.n is not None - nus = [self.species.calc_nu(self.n, self.l, self.j_tot, s_tot=s_tot) for s_tot in [0, 1]] - - if any(abs(nu - nus[0]) > 1e-10 for nu in nus[1:]): - raise ValueError( - "RydbergStateSQDTAlkalineJJ is intended for high-l states only, " - "where the quantum defects are the same for singlet and triplet states." - ) - return nus[0] + species, n, nu = self.species.name, self.n, self.nu + l, j_r, j_tot, f_tot, m = self.l, self.j_r, self.j_tot, self.f_tot, self.m + n_str = f", {n=}" if n is not None else "" + return f"{self.__class__.__name__}({species}{n_str}, {nu=}, {l=}, {j_r=}, {j_tot=}, {f_tot=}, {m=})" class RydbergStateSQDTAlkalineFJ(RydbergStateSQDT): @@ -494,6 +503,7 @@ def __init__( """ super().__init__(species=species, n=n, nu=nu, l_r=l, j_r=j_r, f_c=f_c, f_tot=f_tot, m=m) + def _set_qn_as_attributes(self) -> None: self.l = self.angular.l_r self.j_r = self.angular.j_r self.f_c = self.angular.f_c @@ -501,23 +511,9 @@ def __init__( self.m = self.angular.m def __repr__(self) -> str: - species, n, l, j_r, f_c, f_tot, m = self.species, self.n, self.l, self.j_r, self.f_c, self.f_tot, self.m - return f"{self.__class__.__name__}({species.name}, {n=}, {l=}, {j_r=}, {f_c=}, {f_tot=}, {m=})" - - @cached_property - def nu(self) -> float: - if self._nu is not None: - return self._nu - assert self.n is not None - nus = [ - self.species.calc_nu(self.n, self.l, float(j_tot), s_tot=s_tot) - for s_tot in [0, 1] - for j_tot in np.arange(abs(self.j_r - 1 / 2), self.j_r + 1 / 2 + 1) - ] - - if any(abs(nu - nus[0]) > 1e-10 for nu in nus[1:]): - raise ValueError( - "RydbergStateSQDTAlkalineFJ is intended for high-l states only, " - "where the quantum defects are the same for singlet and triplet states." - ) - return nus[0] + species, n, nu = self.species.name, self.n, self.nu + l, j_r, f_c, f_tot, m = self.l, self.j_r, self.f_c, self.f_tot, self.m + l_c, j_c = self.angular.l_c, self.angular.j_c + core_string = f", {l_c=}, {j_c=}" if l_c != 0 else "" + n_str = f", {n=}" if n is not None else "" + return f"{self.__class__.__name__}({species}{n_str}, {nu=}{core_string}, {l=}, {j_r=}, {f_c=}, {f_tot=}, {m=})" diff --git a/src/rydstate/units.py b/src/rydstate/units.py index 7748865..1477922 100644 --- a/src/rydstate/units.py +++ b/src/rydstate/units.py @@ -1,13 +1,14 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Literal, Union +from typing import TYPE_CHECKING, Any, Literal from pint import UnitRegistry if TYPE_CHECKING: + from typing import TypeAlias + import numpy.typing as npt from pint.facets.plain import PlainQuantity, PlainUnit - from typing_extensions import TypeAlias NDArray: TypeAlias = npt.NDArray[Any] PintFloat: TypeAlias = PlainQuantity[float] @@ -48,7 +49,7 @@ "arbitrary", "zero", ] -DimensionLike = Union[Dimension, tuple[Dimension, Dimension]] +DimensionLike = Dimension | tuple[Dimension, Dimension] # some abbreviations: au_time: atomic_unit_of_time; au_current: atomic_unit_of_current; m_e: electron_mass _CommonUnits: dict[Dimension, str] = { diff --git a/tests/test_angular_matrix_elements.py b/tests/test_angular_matrix_elements.py index ea6e2e8..07d31ff 100644 --- a/tests/test_angular_matrix_elements.py +++ b/tests/test_angular_matrix_elements.py @@ -8,8 +8,9 @@ from rydstate.angular.angular_matrix_element import AngularMomentumQuantumNumbers if TYPE_CHECKING: - from rydstate.angular.angular_ket import AngularKetBase, CouplingScheme + from rydstate.angular.angular_ket import AngularKetBase from rydstate.angular.angular_matrix_element import AngularOperatorType + from rydstate.angular.utils import CouplingScheme TEST_KET_PAIRS = [ ( diff --git a/tests/test_basis.py b/tests/test_basis.py index bbef57a..dd49e48 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -1,11 +1,6 @@ -from typing import TYPE_CHECKING, Any - import numpy as np import pytest -from rydstate import BasisSQDTAlkali, BasisSQDTAlkalineFJ, BasisSQDTAlkalineJJ, BasisSQDTAlkalineLS - -if TYPE_CHECKING: - from rydstate.basis.basis_base import BasisBase +from rydstate import BasisSQDTAlkali, BasisSQDTAlkalineLS @pytest.mark.parametrize("species_name", ["Rb", "Na", "H"]) @@ -69,10 +64,3 @@ def test_alkaline_basis(species_name: str) -> None: basis = BasisSQDTAlkalineLS(species_name, n_min=30, n_max=35) basis.filter_states("l_r", (6, 10)) - for basis_class in [BasisSQDTAlkalineJJ, BasisSQDTAlkalineFJ]: - basis2: BasisBase[Any] = basis_class(species_name, n_min=30, n_max=35) # type: ignore [assignment] - basis2.filter_states("l_r", (6, 10)) - assert len(basis2.states) == len(basis.states) - trafo = basis.calc_reduced_overlaps(basis2) - trafo_inv = basis2.calc_reduced_overlaps(basis) - assert np.allclose(trafo @ trafo_inv, np.eye(len(basis.states)), atol=1e-3)