diff --git a/README.md b/README.md index 7f396d1..15a6aee 100644 --- a/README.md +++ b/README.md @@ -83,7 +83,7 @@ This package relies on quantum defects provided by the community. Consider citin ## Using custom quantum defects To use custom quantum defects (or quantum defects for a new species), you can simply create a subclass of `rydstate.species.species_object.SpeciesObject` (e.g. `class CustomRubidium(SpeciesObject):`) with a custom species name (e.g. `name = "Custom_Rb"`). Then, similarly to `rydstate.species.rubidium.py` you can define the quantum defects (and model potential parameters, ...) for your species. -Finally, you can use the custom species by simply calling `rydstate.RydbergStateAlkali("Custom_Rb", n=50, l=0, j=1/2, m=1/2)` (the code will look for all subclasses of `SpeciesObject` until it finds one with the species name "Custom_Rb"). +Finally, you can use the custom species by simply calling `rydstate.RydbergStateSQDTAlkali("Custom_Rb", n=50, l=0, j=1/2, m=1/2)` (the code will look for all subclasses of `SpeciesObject` until it finds one with the species name "Custom_Rb"). ## License diff --git a/docs/examples.rst b/docs/examples.rst index 5ba8c86..aff1b31 100644 --- a/docs/examples.rst +++ b/docs/examples.rst @@ -5,7 +5,7 @@ Examples RadialState ----------- -Some examples demonstrating the usage of the RadialState class, which uses the Numerov method for solving the radial Schrödinger equation. +Some examples demonstrating the usage of the RadialKet class, which uses the Numerov method for solving the radial Schrödinger equation. .. nbgallery:: examples/radial/hydrogen_wavefunction diff --git a/docs/examples/benchmark/benchmark_njit.ipynb b/docs/examples/benchmark/benchmark_njit.ipynb index 77bc3bc..c042006 100644 --- a/docs/examples/benchmark/benchmark_njit.ipynb +++ b/docs/examples/benchmark/benchmark_njit.ipynb @@ -17,7 +17,7 @@ "\n", "import numpy as np\n", "\n", - "from rydstate.rydberg_state import RydbergStateAlkali\n", + "from rydstate import RydbergStateSQDTAlkali\n", "\n", "test_cases: list[tuple[str, int, int, bool]] = [\n", " # species, n, l, use_njit\n", @@ -28,7 +28,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -41,21 +41,21 @@ " \"\"\"\n", " # run the integration once to compile the numba function\n", " species, n, l, use_njit = test_cases[0]\n", - " state = RydbergStateAlkali(species, n, l, j=l + 0.5)\n", + " state = RydbergStateSQDTAlkali(species, n, l, j=l + 0.5)\n", " state.radial.create_wavefunction(_use_njit=True)\n", "\n", " results = []\n", " for species, n, l, use_njit in test_cases:\n", " # Setup the test function\n", " stmt = (\n", - " \"state = RydbergStateAlkali(species, n, l, j=l+0.5)\\n\"\n", + " \"state = RydbergStateSQDTAlkali(species, n, l, j=l+0.5)\\n\"\n", " \"state.radial.create_grid(dz=1e-3)\\n\"\n", " \"state.radial.create_wavefunction(_use_njit=use_njit)\"\n", " )\n", "\n", " # Time the integration multiple times and take average/std\n", " globals_dict = {\n", - " \"RydbergStateAlkali\": RydbergStateAlkali,\n", + " \"RydbergStateSQDTAlkali\": RydbergStateSQDTAlkali,\n", " \"species\": species,\n", " \"n\": n,\n", " \"l\": l,\n", diff --git a/docs/examples/comparisons/compare_dipole_matrix_element.ipynb b/docs/examples/comparisons/compare_dipole_matrix_element.ipynb index bfbfce5..2691e3b 100644 --- a/docs/examples/comparisons/compare_dipole_matrix_element.ipynb +++ b/docs/examples/comparisons/compare_dipole_matrix_element.ipynb @@ -9,7 +9,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -18,8 +18,7 @@ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", - "from rydstate.rydberg_state import RydbergStateAlkali\n", - "from rydstate.units import ureg" + "from rydstate import RydbergStateSQDTAlkali, ureg" ] }, { @@ -92,8 +91,8 @@ "for qn1, qn2 in zip(qn1_list, qn2_list):\n", " print(f\"n={qn1[0]}\", end=\"\\r\")\n", " q = round(qn2[-1] - qn1[-1])\n", - " state_i = RydbergStateAlkali(\"Rb\", n=qn1[0], l=qn1[1], j=qn1[2], m=qn1[3])\n", - " state_f = RydbergStateAlkali(\"Rb\", n=qn2[0], l=qn2[1], j=qn2[2], m=qn2[3])\n", + " state_i = RydbergStateSQDTAlkali(\"Rb\", n=qn1[0], l=qn1[1], j=qn1[2], m=qn1[3])\n", + " state_f = RydbergStateSQDTAlkali(\"Rb\", n=qn2[0], l=qn2[1], j=qn2[2], m=qn2[3])\n", " dipole_me = state_i.calc_matrix_element(state_f, \"electric_dipole\", q, unit=\"a.u.\")\n", " matrixelements.append(dipole_me)\n", "\n", diff --git a/docs/examples/comparisons/compare_model_potentials.ipynb b/docs/examples/comparisons/compare_model_potentials.ipynb index 5c0a6d4..3b9a7df 100644 --- a/docs/examples/comparisons/compare_model_potentials.ipynb +++ b/docs/examples/comparisons/compare_model_potentials.ipynb @@ -9,13 +9,13 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "\n", - "from rydstate.rydberg_state import RydbergStateAlkali, RydbergStateAlkalineLS" + "from rydstate import RydbergStateSQDTAlkali, RydbergStateSQDTAlkalineLS" ] }, { @@ -41,15 +41,15 @@ } ], "source": [ - "state = RydbergStateAlkali(\"Rb\", n=40, l=0, j=0.5)\n", + "state = RydbergStateSQDTAlkali(\"Rb\", n=40, l=0, j=0.5)\n", "\n", - "states: dict[str, RydbergStateAlkali] = {}\n", + "states: dict[str, RydbergStateSQDTAlkali] = {}\n", "\n", "\n", - "states[\"model_potential_marinescu_1993\"] = RydbergStateAlkali(state.species, n=state.n, l=state.l, j=state.j)\n", + "states[\"model_potential_marinescu_1993\"] = RydbergStateSQDTAlkali(state.species, n=state.n, l=state.l, j=state.j)\n", "states[\"model_potential_marinescu_1993\"].radial.create_model(potential_type=\"model_potential_marinescu_1993\")\n", "\n", - "states[\"model_potential_fei_2009\"] = RydbergStateAlkali(state.species, n=state.n, l=state.l, j=state.j)\n", + "states[\"model_potential_fei_2009\"] = RydbergStateSQDTAlkali(state.species, n=state.n, l=state.l, j=state.j)\n", "states[\"model_potential_fei_2009\"].radial.create_model(potential_type=\"model_potential_fei_2009\")\n", "\n", "for label, state in states.items():\n", @@ -122,17 +122,17 @@ } ], "source": [ - "state = RydbergStateAlkalineLS(\"Sr88\", n=8, l=0, j_tot=0, s_tot=0)\n", + "state = RydbergStateSQDTAlkalineLS(\"Sr88\", n=8, l=0, j_tot=0, s_tot=0)\n", "\n", - "states: dict[str, RydbergStateAlkalineLS] = {}\n", + "states: dict[str, RydbergStateSQDTAlkalineLS] = {}\n", "\n", "\n", - "states[\"model_potential_marinescu_1993\"] = RydbergStateAlkalineLS(\n", + "states[\"model_potential_marinescu_1993\"] = RydbergStateSQDTAlkalineLS(\n", " state.species, n=state.n, l=state.l, j_tot=state.j_tot, s_tot=state.s_tot\n", ")\n", "states[\"model_potential_marinescu_1993\"].radial.create_model(potential_type=\"model_potential_marinescu_1993\")\n", "\n", - "states[\"model_potential_fei_2009\"] = RydbergStateAlkalineLS(\n", + "states[\"model_potential_fei_2009\"] = RydbergStateSQDTAlkalineLS(\n", " state.species, n=state.n, l=state.l, j_tot=state.j_tot, s_tot=state.s_tot\n", ")\n", "states[\"model_potential_fei_2009\"].radial.create_model(potential_type=\"model_potential_fei_2009\")\n", diff --git a/docs/examples/comparisons/compare_radial_matrix_element.ipynb b/docs/examples/comparisons/compare_radial_matrix_element.ipynb index b8f960c..095ba37 100644 --- a/docs/examples/comparisons/compare_radial_matrix_element.ipynb +++ b/docs/examples/comparisons/compare_radial_matrix_element.ipynb @@ -9,7 +9,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -18,8 +18,7 @@ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", - "from rydstate.rydberg_state import RydbergStateAlkali\n", - "from rydstate.units import ureg" + "from rydstate import RydbergStateSQDTAlkali, ureg" ] }, { @@ -64,8 +63,8 @@ " results[key] = []\n", " for qn1, qn2 in zip(qn1_list, qn2_list):\n", " print(f\"n={qn1[0]}\", end=\"\\r\")\n", - " state_i = RydbergStateAlkali(species, qn1[0], qn1[1], j=qn1[2])\n", - " state_f = RydbergStateAlkali(species, qn2[0], qn2[1], j=qn2[2])\n", + " state_i = RydbergStateSQDTAlkali(species, qn1[0], qn1[1], j=qn1[2])\n", + " state_f = RydbergStateSQDTAlkali(species, qn2[0], qn2[1], j=qn2[2])\n", " radial_me = state_i.radial.calc_matrix_element(state_f.radial, 1, unit=\"a.u.\")\n", " results[key].append(radial_me)\n", "\n", diff --git a/docs/examples/comparisons/compare_wavefunctions.ipynb b/docs/examples/comparisons/compare_wavefunctions.ipynb index af490fa..699739c 100644 --- a/docs/examples/comparisons/compare_wavefunctions.ipynb +++ b/docs/examples/comparisons/compare_wavefunctions.ipynb @@ -9,14 +9,14 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", - "from rydstate.rydberg_state import RydbergStateAlkali" + "from rydstate import RydbergStateSQDTAlkali" ] }, { @@ -44,7 +44,7 @@ "source": [ "results[\"rydstate\"] = []\n", "for qn in qns:\n", - " state = RydbergStateAlkali(\"Rb\", n=qn[0], l=qn[1], j=qn[2])\n", + " state = RydbergStateSQDTAlkali(\"Rb\", n=qn[0], l=qn[1], j=qn[2])\n", " state.radial.create_grid()\n", " state.radial.create_wavefunction()\n", " results[\"rydstate\"].append(\n", @@ -125,7 +125,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -142,8 +142,8 @@ "# the small difference of the wavefunctions explains the difference in the radial matrix element of circular states\n", "# (see also the compare_radial_matrix_element and compare_dipole_matrix_element notebooks)\n", "\n", + "from rydstate import ureg\n", "from rydstate.radial.radial_matrix_element import calc_radial_matrix_element_from_w_z\n", - "from rydstate.units import ureg\n", "\n", "to_mum = ureg.Quantity(1, \"bohr_radius\").to(\"micrometer\").magnitude\n", "\n", diff --git a/docs/examples/comparisons/compare_whittaker.ipynb b/docs/examples/comparisons/compare_whittaker.ipynb index ddd36c7..2bab6ab 100644 --- a/docs/examples/comparisons/compare_whittaker.ipynb +++ b/docs/examples/comparisons/compare_whittaker.ipynb @@ -21,7 +21,7 @@ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", - "from rydstate import RydbergStateAlkali\n", + "from rydstate import RydbergStateSQDTAlkali\n", "\n", "logging.basicConfig(level=logging.INFO, format=\"%(levelname)s %(filename)s: %(message)s\")" ] @@ -41,7 +41,7 @@ "metadata": {}, "outputs": [], "source": [ - "states: dict[str, RydbergStateAlkali] = {}" + "states: dict[str, RydbergStateSQDTAlkali] = {}" ] }, { @@ -62,7 +62,7 @@ } ], "source": [ - "state = RydbergStateAlkali(\"Rb\", n=21, l=0, j=0.5)\n", + "state = RydbergStateSQDTAlkali(\"Rb\", n=21, l=0, j=0.5)\n", "\n", "state.radial.create_model(potential_type=\"model_potential_marinescu_1993\")\n", "state.radial.create_wavefunction(\"numerov\")\n", @@ -70,7 +70,7 @@ "\n", "# Using Numerov without model potentials will lead to some warnings,\n", "# since the resulting wavefunction does not pass all heuristic checks\n", - "state_without_mp = RydbergStateAlkali(state.species, state.n, state.l, state.j)\n", + "state_without_mp = RydbergStateSQDTAlkali(state.species, state.n, state.l, state.j)\n", "state_without_mp.radial.create_model(potential_type=\"coulomb\")\n", "state_without_mp.radial.create_wavefunction(\"numerov\")\n", "states[\"Numerov without Model Potentials\"] = state_without_mp" @@ -91,7 +91,7 @@ } ], "source": [ - "state_whittaker = RydbergStateAlkali(state.species, state.n, state.l, state.j)\n", + "state_whittaker = RydbergStateSQDTAlkali(state.species, state.n, state.l, state.j)\n", "state_whittaker.radial.create_grid(x_min=state.radial.grid.x_min, x_max=state.radial.grid.x_max)\n", "state_whittaker.radial.create_wavefunction(\"whittaker\")\n", "states[\"Whittaker\"] = state_whittaker" @@ -188,15 +188,15 @@ } ], "source": [ - "state1 = RydbergStateAlkali(\"Rb\", n=10, l=0, j=0.5)\n", - "state2 = RydbergStateAlkali(\"Rb\", n=9, l=1, j=1.5)\n", + "state1 = RydbergStateSQDTAlkali(\"Rb\", n=10, l=0, j=0.5)\n", + "state2 = RydbergStateSQDTAlkali(\"Rb\", n=9, l=1, j=1.5)\n", "\n", "dipole_me = state1.radial.calc_matrix_element(state2.radial, 1)\n", "print(f\"Numerov with model potentials: {dipole_me}\", flush=True)\n", "\n", - "_state1 = RydbergStateAlkali(state1.species, state1.n, state1.l, state1.j)\n", + "_state1 = RydbergStateSQDTAlkali(state1.species, state1.n, state1.l, state1.j)\n", "_state1.radial.create_model(potential_type=\"coulomb\")\n", - "_state2 = RydbergStateAlkali(state2.species, state2.n, state2.l, state2.j)\n", + "_state2 = RydbergStateSQDTAlkali(state2.species, state2.n, state2.l, state2.j)\n", "_state2.radial.create_model(potential_type=\"coulomb\")\n", "\n", "dipole_me = _state1.radial.calc_matrix_element(_state2.radial, 1)\n", @@ -206,10 +206,10 @@ "# to avoid integrating over the diverging peak at the origin (see plots above)\n", "xmin1, xmax1 = _state1.radial.grid.x_min, _state1.radial.grid.x_max\n", "xmin2, xmax2 = _state2.radial.grid.x_min, _state2.radial.grid.x_max\n", - "_state1 = RydbergStateAlkali(state1.species, state1.n, state1.l, state1.j)\n", + "_state1 = RydbergStateSQDTAlkali(state1.species, state1.n, state1.l, state1.j)\n", "_state1.radial.create_grid(x_min=xmin1, x_max=xmax1)\n", "_state1.radial.create_wavefunction(\"whittaker\")\n", - "_state2 = RydbergStateAlkali(state2.species, state2.n, state2.l, state2.j)\n", + "_state2 = RydbergStateSQDTAlkali(state2.species, state2.n, state2.l, state2.j)\n", "_state2.radial.create_grid(x_min=xmin2, x_max=xmax2)\n", "_state2.radial.create_wavefunction(\"whittaker\")\n", "\n", diff --git a/docs/examples/comparisons/compare_z_min_cutoff.ipynb b/docs/examples/comparisons/compare_z_min_cutoff.ipynb index 1fff6bf..cdcaa1c 100644 --- a/docs/examples/comparisons/compare_z_min_cutoff.ipynb +++ b/docs/examples/comparisons/compare_z_min_cutoff.ipynb @@ -9,14 +9,14 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", - "from rydstate.rydberg_state import RydbergStateAlkali" + "from rydstate import RydbergStateSQDTAlkali" ] }, { @@ -57,7 +57,7 @@ "z_i_dict = {\"hydrogen\": [], \"classical\": [], \"rydstate cutoff\": []}\n", "for qn in qn_list:\n", " print(f\"n={qn[0]}\", end=\"\\r\")\n", - " state = RydbergStateAlkali(\"Rb\", n=qn[0], l=qn[1], j=qn[2])\n", + " state = RydbergStateSQDTAlkali(\"Rb\", n=qn[0], l=qn[1], j=qn[2])\n", "\n", " hydrogen_z_i = state.radial.model.calc_hydrogen_turning_point_z(state.n, state.l)\n", " z_i_dict[\"hydrogen\"].append(hydrogen_z_i)\n", diff --git a/docs/examples/dipole_matrix_elements.ipynb b/docs/examples/dipole_matrix_elements.ipynb index c4bc544..9ce7391 100644 --- a/docs/examples/dipole_matrix_elements.ipynb +++ b/docs/examples/dipole_matrix_elements.ipynb @@ -9,13 +9,13 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", - "from rydstate.rydberg_state import RydbergStateAlkali" + "from rydstate import RydbergStateSQDTAlkali" ] }, { @@ -34,8 +34,8 @@ } ], "source": [ - "state_i = RydbergStateAlkali(\"Rb\", 60, 2, j=3 / 2, m=1 / 2)\n", - "state_f = RydbergStateAlkali(\"Rb\", 60, 3, j=5 / 2, m=1 / 2)\n", + "state_i = RydbergStateSQDTAlkali(\"Rb\", 60, 2, j=3 / 2, m=1 / 2)\n", + "state_f = RydbergStateSQDTAlkali(\"Rb\", 60, 3, j=5 / 2, m=1 / 2)\n", "\n", "kappa = 1\n", "radial = state_i.radial.calc_matrix_element(state_f.radial, k_radial=1)\n", diff --git a/docs/examples/radial/hydrogen_wavefunction.ipynb b/docs/examples/radial/hydrogen_wavefunction.ipynb index 64c99cd..4f41872 100644 --- a/docs/examples/radial/hydrogen_wavefunction.ipynb +++ b/docs/examples/radial/hydrogen_wavefunction.ipynb @@ -9,7 +9,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -18,7 +18,7 @@ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", - "from rydstate.radial import RadialState" + "from rydstate.radial import RadialKet" ] }, { @@ -27,7 +27,7 @@ "metadata": {}, "outputs": [], "source": [ - "state = RadialState(\"H_textbook\", nu=10, l_r=5)\n", + "state = RadialKet(\"H_textbook\", nu=10, l_r=5)\n", "state.set_n_for_sanity_check(10)\n", "state.create_model()\n", "state.create_grid(dz=1e-2)\n", diff --git a/docs/examples/radial/rubidium_wavefunction.ipynb b/docs/examples/radial/rubidium_wavefunction.ipynb index de64f17..b8cfad1 100644 --- a/docs/examples/radial/rubidium_wavefunction.ipynb +++ b/docs/examples/radial/rubidium_wavefunction.ipynb @@ -18,7 +18,7 @@ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", - "from rydstate.rydberg_state import RydbergStateAlkali\n", + "from rydstate import RydbergStateSQDTAlkali\n", "\n", "logging.basicConfig(level=logging.INFO, format=\"%(levelname)s %(filename)s: %(message)s\")\n", "logging.getLogger(\"rydstate\").setLevel(logging.DEBUG)" @@ -30,7 +30,7 @@ "metadata": {}, "outputs": [], "source": [ - "state = RydbergStateAlkali(\"Rb\", n=130, l=129, j=129.5)\n", + "state = RydbergStateSQDTAlkali(\"Rb\", n=130, l=129, j=129.5)\n", "state.radial.create_wavefunction()\n", "\n", "turning_points = {\n", @@ -45,7 +45,7 @@ "metadata": {}, "outputs": [], "source": [ - "hydrogen = RydbergStateAlkali(\"H_textbook\", n=state.n, l=state.l, j=state.j)\n", + "hydrogen = RydbergStateSQDTAlkali(\"H_textbook\", n=state.n, l=state.l, j=state.j)\n", "hydrogen.radial.create_model()\n", "hydrogen.radial.create_wavefunction()" ] diff --git a/pyproject.toml b/pyproject.toml index 70d9914..5a24292 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,9 @@ dependencies = [ ] [project.optional-dependencies] +mqdt = [ + "juliacall >= 0.9.24", +] tests = [ "pytest >= 8.0", "nbmake >= 1.3", diff --git a/src/rydstate/__init__.py b/src/rydstate/__init__.py index 00b275b..e24bdbb 100644 --- a/src/rydstate/__init__.py +++ b/src/rydstate/__init__.py @@ -1,11 +1,35 @@ from rydstate import angular, radial, species -from rydstate.rydberg_state import RydbergStateAlkali, RydbergStateAlkalineJJ, RydbergStateAlkalineLS +from rydstate.basis import ( + BasisMQDT, + BasisSQDTAlkali, + BasisSQDTAlkalineFJ, + BasisSQDTAlkalineJJ, + BasisSQDTAlkalineKS, + BasisSQDTAlkalineLS, +) +from rydstate.rydberg import ( + RydbergStateMQDT, + RydbergStateSQDT, + RydbergStateSQDTAlkali, + RydbergStateSQDTAlkalineFJ, + RydbergStateSQDTAlkalineJJ, + RydbergStateSQDTAlkalineLS, +) from rydstate.units import ureg __all__ = [ - "RydbergStateAlkali", - "RydbergStateAlkalineJJ", - "RydbergStateAlkalineLS", + "BasisMQDT", + "BasisSQDTAlkali", + "BasisSQDTAlkalineFJ", + "BasisSQDTAlkalineJJ", + "BasisSQDTAlkalineKS", + "BasisSQDTAlkalineLS", + "RydbergStateMQDT", + "RydbergStateSQDT", + "RydbergStateSQDTAlkali", + "RydbergStateSQDTAlkalineFJ", + "RydbergStateSQDTAlkalineJJ", + "RydbergStateSQDTAlkalineLS", "angular", "radial", "species", diff --git a/src/rydstate/angular/__init__.py b/src/rydstate/angular/__init__.py index de867fa..e4a0faf 100644 --- a/src/rydstate/angular/__init__.py +++ b/src/rydstate/angular/__init__.py @@ -1,9 +1,4 @@ -from rydstate.angular.angular_ket import AngularKetFJ, AngularKetJJ, AngularKetLS +from rydstate.angular.angular_ket import AngularKetFJ, AngularKetJJ, AngularKetKS, AngularKetLS from rydstate.angular.angular_state import AngularState -__all__ = [ - "AngularKetFJ", - "AngularKetJJ", - "AngularKetLS", - "AngularState", -] +__all__ = ["AngularKetFJ", "AngularKetJJ", "AngularKetKS", "AngularKetLS", "AngularState"] diff --git a/src/rydstate/angular/angular_ket.py b/src/rydstate/angular/angular_ket.py index bb01fd3..5a1d015 100644 --- a/src/rydstate/angular/angular_ket.py +++ b/src/rydstate/angular/angular_ket.py @@ -2,15 +2,15 @@ import logging from abc import ABC -from typing import TYPE_CHECKING, Any, ClassVar, Literal, get_args, overload +from typing import TYPE_CHECKING, Any, ClassVar, Literal, overload from rydstate.angular.angular_matrix_element import ( - AngularMomentumQuantumNumbers, - AngularOperatorType, calc_prefactor_of_operator_in_coupled_scheme, calc_reduced_identity_matrix_element, calc_reduced_spherical_matrix_element, calc_reduced_spin_matrix_element, + is_angular_momentum_quantum_number, + is_angular_operator_type, ) from rydstate.angular.utils import ( calc_wigner_3j, @@ -24,13 +24,15 @@ from rydstate.species import SpeciesObject if TYPE_CHECKING: + import juliacall from typing_extensions import Self + from rydstate.angular.angular_matrix_element import AngularMomentumQuantumNumbers, AngularOperatorType from rydstate.angular.angular_state import AngularState logger = logging.getLogger(__name__) -CouplingScheme = Literal["LS", "JJ", "FJ"] +CouplingScheme = Literal["LS", "JJ", "FJ", "KS"] class InvalidQuantumNumbersError(ValueError): @@ -186,6 +188,34 @@ def get_qn(self, qn: AngularMomentumQuantumNumbers) -> float: raise ValueError(f"Quantum number {qn} not found in {self!r}.") return getattr(self, qn) # type: ignore [no-any-return] + def calc_exp_qn(self, qn: AngularMomentumQuantumNumbers) -> float: + """Calculate the expectation value of a quantum number qn. + + If the quantum number is a good quantum number simply return it, + otherwise calculate it, see also AngularState.calc_exp_qn for more details. + + Args: + qn: The quantum number to calculate the expectation value for. + + """ + if qn in self.quantum_number_names: + return self.get_qn(qn) + return self.to_state().calc_exp_qn(qn) + + def calc_std_qn(self, qn: AngularMomentumQuantumNumbers) -> float: + """Calculate the standard deviation of a quantum number qn. + + If the quantum number is a good quantum number return 0, + otherwise calculate the std, see also AngularState.calc_std_qn for more details. + + Args: + qn: The quantum number to calculate the standard deviation for. + + """ + if qn in self.quantum_number_names: + return 0 + return self.to_state().calc_std_qn(qn) + @overload def to_state(self, coupling_scheme: Literal["LS"]) -> AngularState[AngularKetLS]: ... @@ -195,6 +225,9 @@ def to_state(self, coupling_scheme: Literal["JJ"]) -> AngularState[AngularKetJJ] @overload def to_state(self, coupling_scheme: Literal["FJ"]) -> AngularState[AngularKetFJ]: ... + @overload + def to_state(self, coupling_scheme: Literal["KS"]) -> AngularState[AngularKetKS]: ... + @overload def to_state(self: Self) -> AngularState[Self]: ... @@ -219,6 +252,8 @@ def to_state(self, coupling_scheme: CouplingScheme | None = None) -> AngularStat return self._to_state_jj() if coupling_scheme == "FJ": return self._to_state_fj() + if coupling_scheme == "KS": + return self._to_state_ks() raise ValueError(f"Unknown coupling scheme {coupling_scheme!r}.") def _to_state_ls(self) -> AngularState[AngularKetLS]: @@ -247,7 +282,7 @@ def _to_state_ls(self) -> AngularState[AngularKetLS]: ) except InvalidQuantumNumbersError: continue - coeff = self.calc_reduced_overlap(ls_ket) + coeff = ls_ket.calc_reduced_overlap(self) if coeff != 0: kets.append(ls_ket) coefficients.append(coeff) @@ -282,7 +317,7 @@ def _to_state_jj(self) -> AngularState[AngularKetJJ]: ) except InvalidQuantumNumbersError: continue - coeff = self.calc_reduced_overlap(jj_ket) + coeff = jj_ket.calc_reduced_overlap(self) if coeff != 0: kets.append(jj_ket) coefficients.append(coeff) @@ -317,7 +352,7 @@ def _to_state_fj(self) -> AngularState[AngularKetFJ]: ) except InvalidQuantumNumbersError: continue - coeff = self.calc_reduced_overlap(fj_ket) + coeff = fj_ket.calc_reduced_overlap(self) if coeff != 0: kets.append(fj_ket) coefficients.append(coeff) @@ -326,7 +361,42 @@ def _to_state_fj(self) -> AngularState[AngularKetFJ]: return AngularState(coefficients, kets) - def calc_reduced_overlap(self, other: AngularKetBase) -> float: + def _to_state_ks(self) -> AngularState[AngularKetKS]: + """Convert a single ket to state in KS coupling.""" + kets: list[AngularKetKS] = [] + coefficients: list[float] = [] + + j_c_list = get_possible_quantum_number_values(self.s_c, self.l_c, getattr(self, "j_c", None)) + for j_c in j_c_list: + k_list = get_possible_quantum_number_values(j_c, self.l_r, getattr(self, "k", None)) + for k in k_list: + j_tot_list = get_possible_quantum_number_values(k, self.s_r, getattr(self, "j_tot", None)) + for j_tot in j_tot_list: + try: + ks_ket = AngularKetKS( + i_c=self.i_c, + s_c=self.s_c, + l_c=self.l_c, + s_r=self.s_r, + l_r=self.l_r, + j_c=j_c, + k=k, + j_tot=j_tot, + f_tot=self.f_tot, + m=self.m, + ) + except InvalidQuantumNumbersError: + continue + coeff = ks_ket.calc_reduced_overlap(self) + if coeff != 0: + kets.append(ks_ket) + coefficients.append(coeff) + + from rydstate.angular.angular_state import AngularState # noqa: PLC0415 + + return AngularState(coefficients, kets) + + def calc_reduced_overlap(self, other: AngularKetBase | AngularState[Any]) -> float: # noqa: PLR0911 """Calculate the reduced overlap (ignoring the magnetic quantum number m). If both kets are of the same type (=same coupling scheme), this is just a delta function @@ -334,44 +404,50 @@ def calc_reduced_overlap(self, other: AngularKetBase) -> float: If the kets are of different types, the overlap is calculated using the corresponding Clebsch-Gordan coefficients (/ Wigner-j symbols). """ + from rydstate.angular.angular_state import AngularState # noqa: PLC0415 + + if isinstance(other, AngularState): + return other.calc_reduced_overlap(self) + + if type(self) is type(other): + return 1 if self.quantum_numbers == other.quantum_numbers else 0 + for q in set(self.quantum_number_names) & set(other.quantum_number_names): if self.get_qn(q) != other.get_qn(q): return 0 - if type(self) is type(other): - return 1 - kets = [self, other] - # JJ - FJ overlaps - if any(isinstance(s, AngularKetJJ) for s in kets) and any(isinstance(s, AngularKetFJ) for s in kets): - jj = next(s for s in kets if isinstance(s, AngularKetJJ)) - fj = next(s for s in kets if isinstance(s, AngularKetFJ)) - return clebsch_gordan_6j(fj.j_r, fj.j_c, fj.i_c, jj.j_tot, fj.f_c, fj.f_tot) - - # JJ - LS overlaps - if any(isinstance(s, AngularKetJJ) for s in kets) and any(isinstance(s, AngularKetLS) for s in kets): + # JJ overlaps + if any(isinstance(s, AngularKetJJ) for s in kets): jj = next(s for s in kets if isinstance(s, AngularKetJJ)) - ls = next(s for s in kets if isinstance(s, AngularKetLS)) - # NOTE: it matters, whether you first put all 3 l's and then all 3 s's or the other way round - # (see symmetry properties of 9j symbol) - # this convention is used, such that all matrix elements work out correctly, no matter in which - # coupling scheme they are calculated - return clebsch_gordan_9j(ls.l_r, ls.l_c, ls.l_tot, ls.s_r, ls.s_c, ls.s_tot, jj.j_r, jj.j_c, jj.j_tot) - - # FJ - LS overlaps - if any(isinstance(s, AngularKetFJ) for s in kets) and any(isinstance(s, AngularKetLS) for s in kets): - fj = next(s for s in kets if isinstance(s, AngularKetFJ)) - ls = next(s for s in kets if isinstance(s, AngularKetLS)) - ov: float = 0 - for coeff, jj_ket in fj.to_state("JJ"): - ov += coeff * ls.calc_reduced_overlap(jj_ket) - return float(ov) - - raise NotImplementedError(f"This method is not yet implemented for {self!r} and {other!r}.") + # - FJ + if any(isinstance(s, AngularKetFJ) for s in kets): + fj = next(s for s in kets if isinstance(s, AngularKetFJ)) + return clebsch_gordan_6j(fj.j_r, fj.j_c, fj.i_c, jj.j_tot, fj.f_c, fj.f_tot) + + # - LS + if any(isinstance(s, AngularKetLS) for s in kets): + ls = next(s for s in kets if isinstance(s, AngularKetLS)) + # NOTE: it matters, whether you first put all 3 l's and then all 3 s's or the other way round + # (see symmetry properties of 9j symbol) + # this convention is used, such that all matrix elements work out correctly, no matter in which + # coupling scheme they are calculated + return clebsch_gordan_9j(ls.l_r, ls.l_c, ls.l_tot, ls.s_r, ls.s_c, ls.s_tot, jj.j_r, jj.j_c, jj.j_tot) + + # - KS overlaps + if any(isinstance(s, AngularKetKS) for s in kets): + ks = next(s for s in kets if isinstance(s, AngularKetKS)) + # we have some gauge degree of freedom, which one must use to get consistent matrix elements + prefactor = -1 if (jj.j_r + jj.j_c) % 2 == 0 else 1 # TODO not quite correct yet + return prefactor * clebsch_gordan_6j(ks.s_r, ks.l_r, ks.j_c, jj.j_r, ks.k, ks.j_tot) + + raise NotImplementedError(f"calc_reduced_overlap not implemented for {kets!r}.") + + return self.to_state("JJ").calc_reduced_overlap(other) def calc_reduced_matrix_element( # noqa: C901 - self: Self, other: AngularKetBase, operator: AngularOperatorType, kappa: int + self: Self, other: AngularKetBase | AngularState[Any], operator: AngularOperatorType, kappa: int ) -> float: r"""Calculate the reduced angular matrix element. @@ -382,12 +458,17 @@ def calc_reduced_matrix_element( # noqa: C901 \left\langle self || \hat{O}^{(\kappa)} || other \right\rangle """ - if operator not in get_args(AngularOperatorType): + if not is_angular_operator_type(operator): raise NotImplementedError(f"calc_reduced_matrix_element is not implemented for operator {operator}.") + from rydstate.angular.angular_state import AngularState # noqa: PLC0415 + + if isinstance(other, AngularState): + return other.calc_reduced_matrix_element(self, operator, kappa) + if type(self) is not type(other): return self.to_state().calc_reduced_matrix_element(other.to_state(), operator, kappa) - if operator in get_args(AngularMomentumQuantumNumbers) and operator not in self.quantum_number_names: + if is_angular_momentum_quantum_number(operator) and operator not in self.quantum_number_names: return self.to_state().calc_reduced_matrix_element(other.to_state(), operator, kappa) qn_name: AngularMomentumQuantumNumbers @@ -515,8 +596,12 @@ def _calc_prefactor_of_operator_in_coupled_scheme( f1, f2, f_tot = (self.get_qn(qn1), self.get_qn(qn2), self.get_qn(qn_combined)) i1, i2, i_tot = (other.get_qn(qn1), other.get_qn(qn2), other.get_qn(qn_combined)) + # this should already been taken care of by _kronecker_delta_non_involved_spins + # TODO alternatively, remove _kronecker_delta_non_involved_spins, + # and check here a descending qns from qn1 or qn2 if (operator_acts_on == "first" and f2 != i2) or (operator_acts_on == "second" and f1 != i1): return 0 + prefactor = calc_prefactor_of_operator_in_coupled_scheme(f1, f2, f_tot, i1, i2, i_tot, kappa, operator_acts_on) return prefactor * self._calc_prefactor_of_operator_in_coupled_scheme(other, qn_combined, kappa) @@ -708,3 +793,133 @@ def sanity_check(self, msgs: list[str] | None = None) -> None: msgs.append(f"{self.f_c=}, {self.j_r=}, {self.f_tot=} don't satisfy spin addition rule.") super().sanity_check(msgs) + + +class AngularKetKS(AngularKetBase): + """Spin ket in KS coupling.""" + + __slots__ = ("j_c", "k", "j_tot") + quantum_number_names: ClassVar = ("i_c", "s_c", "l_c", "s_r", "l_r", "j_c", "k", "j_tot", "f_tot") + coupled_quantum_numbers: ClassVar = { + "j_c": ("s_c", "l_c"), + "k": ("j_c", "l_r"), + "j_tot": ("k", "s_r"), + "f_tot": ("i_c", "j_tot"), + } + coupling_scheme = "KS" + + j_c: float + """Total core electron angular quantum number (s_c + l_c).""" + k: float + """Intermediate angular momentum (j_c + l_r).""" + j_tot: float + """Total electron angular momentum quantum number (k + s_r).""" + + def __init__( + self, + i_c: float | None = None, + s_c: float | None = None, + l_c: int = 0, + s_r: float = 0.5, + l_r: int | None = None, + j_c: float | None = None, + k: float | None = None, + j_tot: float | None = None, + f_tot: float | None = None, + m: float | None = None, + species: str | SpeciesObject | None = None, + ) -> None: + """Initialize the Spin ket.""" + super().__init__(i_c, s_c, l_c, s_r, l_r, f_tot, m, species) + + self.j_c = try_trivial_spin_addition(self.l_c, self.s_c, j_c, "j_c") + self.k = try_trivial_spin_addition(self.j_c, self.l_r, k, "k") + self.j_tot = try_trivial_spin_addition(self.k, self.s_r, j_tot, "j_tot") + self.f_tot = try_trivial_spin_addition(self.i_c, self.j_tot, f_tot, "f_tot") + + super()._post_init() + + def sanity_check(self, msgs: list[str] | None = None) -> None: + """Check that the quantum numbers are valid.""" + msgs = msgs if msgs is not None else [] + + if not check_spin_addition_rule(self.l_c, self.s_c, self.j_c): + msgs.append(f"{self.l_c=}, {self.s_c=}, {self.j_c=} don't satisfy spin addition rule.") + + if not check_spin_addition_rule(self.j_c, self.l_r, self.k): + msgs.append(f"{self.j_c=}, {self.l_r=}, {self.k=} don't satisfy spin addition rule.") + + if not check_spin_addition_rule(self.k, self.s_r, self.j_tot): + msgs.append(f"{self.k=}, {self.s_r=}, {self.j_tot=} don't satisfy spin addition rule.") + + if not check_spin_addition_rule(self.i_c, self.j_tot, self.f_tot): + msgs.append(f"{self.i_c=}, {self.j_tot=}, {self.f_tot=} don't satisfy spin addition rule.") + + super().sanity_check(msgs) + + +def julia_qn_to_dict(qn: juliacall.AnyValue) -> dict[str, float]: + """Convert MQDT Julia quantum numbers to dict object.""" + if "fjQuantumNumbers" in str(qn): + return dict(s_c=qn.sc, l_c=qn.lc, j_c=qn.Jc, f_c=qn.Fc, l_r=qn.lr, j_r=qn.Jr, f_tot=qn.F) # noqa: C408 + if "jjQuantumNumbers" in str(qn): + return dict(s_c=qn.sc, l_c=qn.lc, j_c=qn.Jc, l_r=qn.lr, j_r=qn.Jr, j_tot=qn.J, f_tot=qn.F) # noqa: C408 + if "lsQuantumNumbers" in str(qn): + return dict(s_c=qn.sc, s_tot=qn.S, l_c=qn.lc, l_r=qn.lr, l_tot=qn.L, j_tot=qn.J, f_tot=qn.F) # noqa: C408 + raise ValueError(f"Unknown MQDT Julia quantum numbers {qn!s}.") + + +def quantum_numbers_to_angular_ket( + species: str | SpeciesObject, + s_c: float | None = None, + l_c: int = 0, + j_c: float | None = None, + f_c: float | None = None, + s_r: float = 0.5, + l_r: int | None = None, + j_r: float | None = None, + k: float | None = None, + s_tot: float | None = None, + l_tot: int | None = None, + j_tot: float | None = None, + f_tot: float | None = None, + m: float | None = None, +) -> AngularKetBase: + r"""Return an AngularKet object in the corresponding coupling scheme from the given quantum numbers. + + Args: + species: Atomic species. + s_c: Spin quantum number of the core electron (0 for Alkali, 0.5 for divalent atoms). + l_c: Orbital angular momentum quantum number of the core electron. + j_c: Total angular momentum quantum number of the core electron. + f_c: Total angular momentum quantum number of the core (core electron + nucleus). + s_r: Spin quantum number of the rydberg electron always 0.5) + l_r: Orbital angular momentum quantum number of the rydberg electron. + j_r: Total angular momentum quantum number of the rydberg electron. + k: Intermediate angular momentum (j_c + l_r). + s_tot: Total spin quantum number of all electrons. + l_tot: Total orbital angular momentum quantum number of all electrons. + j_tot: Total angular momentum quantum number of all electrons. + f_tot: Total angular momentum quantum number of the atom (rydberg electron + core) + m: Total magnetic quantum number. + Optional, only needed for concrete angular matrix elements. + + """ + if all(qn is None for qn in [j_c, f_c, j_r, k]): + return AngularKetLS( + s_c=s_c, l_c=l_c, s_r=s_r, l_r=l_r, s_tot=s_tot, l_tot=l_tot, j_tot=j_tot, f_tot=f_tot, m=m, species=species + ) + if all(qn is None for qn in [s_tot, l_tot, f_c, k]): + return AngularKetJJ( + s_c=s_c, l_c=l_c, j_c=j_c, s_r=s_r, l_r=l_r, j_r=j_r, j_tot=j_tot, f_tot=f_tot, m=m, species=species + ) + if all(qn is None for qn in [s_tot, l_tot, j_tot, k]): + return AngularKetFJ( + s_c=s_c, l_c=l_c, j_c=j_c, f_c=f_c, s_r=s_r, l_r=l_r, j_r=j_r, f_tot=f_tot, m=m, species=species + ) + if all(qn is None for qn in [s_tot, l_tot, j_r, f_c]): + return AngularKetKS( + s_c=s_c, l_c=l_c, j_c=j_c, s_r=s_r, l_r=l_r, k=k, j_tot=j_tot, f_tot=f_tot, m=m, species=species + ) + + raise ValueError("Invalid combination of angular quantum numbers provided.") diff --git a/src/rydstate/angular/angular_matrix_element.py b/src/rydstate/angular/angular_matrix_element.py index 1b69dc3..823b35e 100644 --- a/src/rydstate/angular/angular_matrix_element.py +++ b/src/rydstate/angular/angular_matrix_element.py @@ -2,9 +2,10 @@ import math from functools import lru_cache -from typing import TYPE_CHECKING, Callable, Literal, TypeVar +from typing import TYPE_CHECKING, Callable, Literal, TypeVar, get_args import numpy as np +from typing_extensions import TypeGuard from rydstate.angular.utils import calc_wigner_3j, calc_wigner_6j, minus_one_pow @@ -18,7 +19,7 @@ def lru_cache(maxsize: int) -> Callable[[Callable[P, R]], Callable[P, R]]: ... AngularMomentumQuantumNumbers = Literal[ - "i_c", "s_c", "l_c", "s_r", "l_r", "s_tot", "l_tot", "j_c", "j_r", "j_tot", "f_c", "f_tot" + "i_c", "s_c", "l_c", "s_r", "l_r", "s_tot", "l_tot", "j_c", "j_r", "k", "j_tot", "f_c", "f_tot" ] IdentityOperators = Literal[ "identity_i_c", @@ -33,6 +34,7 @@ def lru_cache(maxsize: int) -> Callable[[Callable[P, R]], Callable[P, R]]: ... "identity_j_tot", "identity_f_c", "identity_f_tot", + "identity_k", ] AngularOperatorType = Literal[ "spherical", @@ -41,6 +43,16 @@ def lru_cache(maxsize: int) -> Callable[[Callable[P, R]], Callable[P, R]]: ... ] +def is_angular_momentum_quantum_number(qn: str) -> TypeGuard[AngularMomentumQuantumNumbers]: + """Check if the given string is an AngularMomentumQuantumNumbers.""" + return qn in get_args(AngularMomentumQuantumNumbers) + + +def is_angular_operator_type(qn: str) -> TypeGuard[AngularOperatorType]: + """Check if the given string is an AngularOperatorType.""" + return qn in get_args(AngularOperatorType) + + @lru_cache(maxsize=10_000) def calc_reduced_spherical_matrix_element(l_r_final: int, l_r_initial: int, kappa: int) -> float: r"""Calculate the reduced spherical matrix element (l_r_final || \hat{Y}_{k} || l_r_initial). diff --git a/src/rydstate/angular/angular_state.py b/src/rydstate/angular/angular_state.py index dd05493..cc58621 100644 --- a/src/rydstate/angular/angular_state.py +++ b/src/rydstate/angular/angular_state.py @@ -2,7 +2,7 @@ import logging import math -from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, get_args, overload +from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, overload import numpy as np @@ -10,17 +10,18 @@ AngularKetBase, AngularKetFJ, AngularKetJJ, + AngularKetKS, AngularKetLS, ) -from rydstate.angular.angular_matrix_element import AngularMomentumQuantumNumbers +from rydstate.angular.angular_matrix_element import is_angular_momentum_quantum_number if TYPE_CHECKING: - from collections.abc import Iterator + from collections.abc import Iterator, Sequence from typing_extensions import Self from rydstate.angular.angular_ket import CouplingScheme - from rydstate.angular.angular_matrix_element import AngularOperatorType + from rydstate.angular.angular_matrix_element import AngularMomentumQuantumNumbers, AngularOperatorType logger = logging.getLogger(__name__) @@ -30,7 +31,7 @@ class AngularState(Generic[_AngularKet]): def __init__( - self, coefficients: list[float], kets: list[_AngularKet], *, warn_if_not_normalized: bool = True + self, coefficients: Sequence[float], kets: Sequence[_AngularKet], *, warn_if_not_normalized: bool = False ) -> None: self.coefficients = np.array(coefficients) self.kets = kets @@ -76,6 +77,9 @@ def to(self, coupling_scheme: Literal["JJ"]) -> AngularState[AngularKetJJ]: ... @overload def to(self, coupling_scheme: Literal["FJ"]) -> AngularState[AngularKetFJ]: ... + @overload + def to(self, coupling_scheme: Literal["KS"]) -> AngularState[AngularKetKS]: ... + def to(self, coupling_scheme: CouplingScheme) -> AngularState[Any]: """Convert to specified coupling scheme. @@ -96,7 +100,7 @@ def to(self, coupling_scheme: CouplingScheme) -> AngularState[Any]: else: kets.append(scheme_ket) coefficients.append(coeff * scheme_coeff) - return AngularState(coefficients, kets, warn_if_not_normalized=abs(self.norm - 1) < 1e-10) + return AngularState(coefficients, kets, warn_if_not_normalized=False) def calc_exp_qn(self, q: AngularMomentumQuantumNumbers) -> float: """Calculate the expectation value of a quantum number q. @@ -106,7 +110,7 @@ def calc_exp_qn(self, q: AngularMomentumQuantumNumbers) -> float: """ if q not in self.kets[0].quantum_number_names: - for ket_class in [AngularKetLS, AngularKetJJ, AngularKetFJ]: + for ket_class in [AngularKetLS, AngularKetJJ, AngularKetFJ, AngularKetKS]: if q in ket_class.quantum_number_names: return self.to(ket_class.coupling_scheme).calc_exp_qn(q) @@ -124,7 +128,7 @@ def calc_std_qn(self, q: AngularMomentumQuantumNumbers) -> float: """ if q not in self.kets[0].quantum_number_names: - for ket_class in [AngularKetLS, AngularKetJJ, AngularKetFJ]: + for ket_class in [AngularKetLS, AngularKetJJ, AngularKetFJ, AngularKetKS]: if q in ket_class.quantum_number_names: return self.to(ket_class.coupling_scheme).calc_std_qn(q) @@ -164,8 +168,8 @@ def calc_reduced_matrix_element( """ if isinstance(other, AngularKetBase): other = other.to_state() - if operator in get_args(AngularMomentumQuantumNumbers) and operator not in self.kets[0].quantum_number_names: - for ket_class in [AngularKetLS, AngularKetJJ, AngularKetFJ]: + if is_angular_momentum_quantum_number(operator) and operator not in self.kets[0].quantum_number_names: + for ket_class in [AngularKetLS, AngularKetJJ, AngularKetFJ, AngularKetKS]: if operator in ket_class.quantum_number_names: return self.to(ket_class.coupling_scheme).calc_reduced_matrix_element(other, operator, kappa) diff --git a/src/rydstate/basis/__init__.py b/src/rydstate/basis/__init__.py new file mode 100644 index 0000000..982ddcd --- /dev/null +++ b/src/rydstate/basis/__init__.py @@ -0,0 +1,17 @@ +from rydstate.basis.basis_mqdt import BasisMQDT +from rydstate.basis.basis_sqdt import ( + BasisSQDTAlkali, + BasisSQDTAlkalineFJ, + BasisSQDTAlkalineJJ, + BasisSQDTAlkalineKS, + BasisSQDTAlkalineLS, +) + +__all__ = [ + "BasisMQDT", + "BasisSQDTAlkali", + "BasisSQDTAlkalineFJ", + "BasisSQDTAlkalineJJ", + "BasisSQDTAlkalineKS", + "BasisSQDTAlkalineLS", +] diff --git a/src/rydstate/basis/basis_base.py b/src/rydstate/basis/basis_base.py new file mode 100644 index 0000000..ba0e14d --- /dev/null +++ b/src/rydstate/basis/basis_base.py @@ -0,0 +1,138 @@ +from __future__ import annotations + +from abc import ABC +from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload + +import numpy as np +from typing_extensions import Self + +from rydstate.angular.angular_matrix_element import is_angular_momentum_quantum_number +from rydstate.rydberg.rydberg_base import RydbergStateBase +from rydstate.species.species_object import SpeciesObject +from rydstate.units import ureg + +if TYPE_CHECKING: + from rydstate.units import MatrixElementOperator, NDArray, PintArray, PintFloat + +_RydbergState = TypeVar("_RydbergState", bound=RydbergStateBase) + + +class BasisBase(ABC, Generic[_RydbergState]): + states: list[_RydbergState] + + def __init__(self, species: str | SpeciesObject) -> None: + if isinstance(species, str): + species = SpeciesObject.from_name(species) + self.species = species + + def __len__(self) -> int: + return len(self.states) + + def copy(self) -> Self: + new_basis = self.__class__.__new__(self.__class__) + new_basis.species = self.species + new_basis.states = list(self.states) + return new_basis + + @overload + def filter_states(self, qn: str, value: tuple[float, float], *, delta: float = 1e-10) -> Self: ... + + @overload + def filter_states(self, qn: str, value: float, *, delta: float = 1e-10) -> Self: ... + + def filter_states(self, qn: str, value: float | tuple[float, float], *, delta: float = 1e-10) -> Self: + if isinstance(value, tuple): + qn_min = value[0] - delta + qn_max = value[1] + delta + else: + qn_min = value - delta + qn_max = value + delta + + if is_angular_momentum_quantum_number(qn): + self.states = [state for state in self.states if qn_min <= state.angular.calc_exp_qn(qn) <= qn_max] + elif qn in ["n", "nu", "nu_energy"]: + self.states = [state for state in self.states if qn_min <= getattr(state, qn) <= qn_max] + else: + raise ValueError(f"Unknown quantum number {qn}") + + return self + + def sort_states(self, qn: str) -> Self: + values = self.calc_exp_qn(qn) + sorted_indices = np.argsort(values) + self.states = [self.states[i] for i in sorted_indices] + return self + + def calc_exp_qn(self, qn: str) -> list[float]: + if is_angular_momentum_quantum_number(qn): + return [state.angular.calc_exp_qn(qn) for state in self.states] + if qn in ["n", "nu", "nu_energy"]: + return [getattr(state, qn) for state in self.states] + raise ValueError(f"Unknown quantum number {qn}") + + def calc_std_qn(self, qn: str) -> list[float]: + if is_angular_momentum_quantum_number(qn): + return [state.angular.calc_std_qn(qn) for state in self.states] + if qn in ["n", "nu", "nu_energy"]: + return [0 for state in self.states] + raise ValueError(f"Unknown quantum number {qn}") + + def calc_reduced_overlap(self, other: RydbergStateBase) -> NDArray: + """Calculate the reduced overlap (ignoring the magnetic quantum number m).""" + return np.array([bra.calc_reduced_overlap(other) for bra in self.states]) + + def calc_reduced_overlaps(self, other: BasisBase[Any]) -> NDArray: + """Calculate the reduced overlap for all states in the bases self and other. + + Returns a numpy array overlaps, where overlaps[i,j] corresponds to the overlap of the + i-th state of self and the j-th state of other. + """ + return np.array([[bra.calc_reduced_overlap(ket) for ket in other.states] for bra in self.states]) + + @overload + def calc_reduced_matrix_element( + self, other: RydbergStateBase, operator: MatrixElementOperator, unit: None = None + ) -> PintArray: ... + + @overload + def calc_reduced_matrix_element( + self, other: RydbergStateBase, operator: MatrixElementOperator, unit: str + ) -> NDArray: ... + + def calc_reduced_matrix_element( + self, other: RydbergStateBase, operator: MatrixElementOperator, unit: str | None = None + ) -> PintArray | NDArray: + r"""Calculate the reduced matrix element.""" + values_list = [bra.calc_reduced_matrix_element(other, operator, unit=unit) for bra in self.states] + if unit is not None: + return np.array(values_list) + + values: list[PintFloat] = values_list # type: ignore[assignment] + _unit = values[0].units + _values = np.array([v.magnitude for v in values]) + return ureg.Quantity(_values, _unit) + + @overload + def calc_reduced_matrix_elements( + self, other: BasisBase[Any], operator: MatrixElementOperator, unit: None = None + ) -> PintArray: ... + + @overload + def calc_reduced_matrix_elements( + self, other: BasisBase[Any], operator: MatrixElementOperator, unit: str + ) -> NDArray: ... + + def calc_reduced_matrix_elements( + self, other: BasisBase[Any], operator: MatrixElementOperator, unit: str | None = None + ) -> PintArray | NDArray: + r"""Calculate the reduced matrix element.""" + values_list = [ + [bra.calc_reduced_matrix_element(ket, operator, unit=unit) for ket in other.states] for bra in self.states + ] + if unit is not None: + return np.array(values_list) + + values: list[PintFloat] = values_list # type: ignore[assignment] + _unit = values[0].units + _values = np.array([v.magnitude for v in values]) + return ureg.Quantity(_values, _unit) diff --git a/src/rydstate/basis/basis_mqdt.py b/src/rydstate/basis/basis_mqdt.py new file mode 100644 index 0000000..b73b8db --- /dev/null +++ b/src/rydstate/basis/basis_mqdt.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any + +from rydstate.angular.angular_ket import julia_qn_to_dict +from rydstate.basis.basis_base import BasisBase +from rydstate.rydberg.rydberg_mqdt import RydbergStateMQDT +from rydstate.rydberg.rydberg_sqdt import RydbergStateSQDT + +if TYPE_CHECKING: + from rydstate.species import SpeciesObject + +logger = logging.getLogger(__name__) + +try: + USE_JULIACALL = True + from juliacall import ( + JuliaError, + Main as jl, # noqa: N813 + convert, + ) +except ImportError: + USE_JULIACALL = False + + +if USE_JULIACALL: + try: + jl.seval("using MQDT") + jl.seval("using CGcoefficient") + except JuliaError: + logger.exception("Failed to load Julia MQDT or CGcoefficient package") + USE_JULIACALL = False + +FMODEL_MAX_L = {"Sr87": 2, "Sr88": 2, "Yb171": 4, "Yb173": 1, "Yb174": 4} + + +class BasisMQDT(BasisBase[RydbergStateMQDT[Any]]): + def __init__( + self, + species: str | SpeciesObject, + n_min: int = 0, + n_max: int | None = None, + *, + skip_high_l: bool = True, + model_names: list[str] | None = None, + ) -> None: + super().__init__(species) + + if not USE_JULIACALL: + raise ImportError("JuliaCall or the MQDT Julia package is not available.") + + try: + self.jl_species = getattr(jl.MQDT, self.species.name) + parameters = self.jl_species.PARA + except AttributeError as e: + raise ValueError(f"Species '{species}' is not supported in the MQDT Julia package.") from e + + # TODO use n_min and n_max of the different models + + if n_max is None: + raise ValueError("n_max must be given") + + # initialize Wigner symbol calculation + if skip_high_l: + jl.CGcoefficient.wigner_init_float(5, "Jmax", 9) + else: + jl.CGcoefficient.wigner_init_float(n_max - 1, "Jmax", 9) + + logger.debug("Calculating low l MQDT states...") + + jl_species_attr_names = [str(name) for name in jl.seval(f"names(MQDT.{self.species.name}, all=true)")] + self.models = {name: getattr(self.jl_species, name) for name in jl_species_attr_names} + self.models = {k: v for k, v in self.models.items() if str(v).startswith("fModel")} + if model_names is not None: + self.models = {k: v for k, v in self.models.items() if k in model_names} + + if skip_high_l: + logger.debug("Skipping high l states.") + else: + logger.debug("Calculating high l SQDT states...") + l_start = FMODEL_MAX_L[self.species.name] + 1 + high_l_models = { + f"high_l_{l_ryd}": jl.single_channel_models(species, l_ryd, parameters) + for l_ryd in range(l_start, n_max) + } + self.models.update(high_l_models) + + model_names = list(self.models.keys()) + jl_states = {name: jl.eigenstates(n_min, n_max, model, parameters) for name, model in self.models.items()} + _models_vector = convert(jl.Vector, [self.models[name] for name in model_names]) + _jl_states_vector = convert(jl.Vector, [jl_states[name] for name in model_names]) + jl_basis = jl.basisarray(_jl_states_vector, _models_vector) + + logger.debug("Generated state table with %d states", len(jl_basis.states)) + + self.states = [] + for jl_state in jl_basis.states: + coeffs = jl_state.coeff + nus = jl_state.nu + nu_energy = jl_state.energy + qns = jl_state.channels.i + qns = [julia_qn_to_dict(qn) for qn in qns] + + sqdt_states = [RydbergStateSQDT(species, nu=nu, **qn) for nu, qn in zip(nus, qns)] + # check angular and radial are created correctly + [(s.angular, s.radial) for s in sqdt_states] + + mqdt_state = RydbergStateMQDT(coeffs, sqdt_states, nu_energy=nu_energy, warn_if_not_normalized=False) + self.states.append(mqdt_state) diff --git a/src/rydstate/basis/basis_sqdt.py b/src/rydstate/basis/basis_sqdt.py new file mode 100644 index 0000000..97bead4 --- /dev/null +++ b/src/rydstate/basis/basis_sqdt.py @@ -0,0 +1,118 @@ +from __future__ import annotations + +import numpy as np + +from rydstate.basis.basis_base import BasisBase +from rydstate.rydberg import ( + RydbergStateSQDT, + RydbergStateSQDTAlkali, + RydbergStateSQDTAlkalineFJ, + RydbergStateSQDTAlkalineJJ, + RydbergStateSQDTAlkalineLS, +) + + +class BasisSQDTAlkali(BasisBase[RydbergStateSQDTAlkali]): + def __init__(self, species: str, n_min: int = 0, n_max: int | None = None) -> None: + super().__init__(species) + + if n_max is None: + raise ValueError("n_max must be given") + + s = 1 / 2 + i_c = self.species.i_c if self.species.i_c is not None else 0 + + self.states = [] + for n in range(n_min, n_max + 1): + for l in range(n): + for j in np.arange(abs(l - s), l + s + 1): + for f in np.arange(abs(j - i_c), j + i_c + 1): + state = RydbergStateSQDTAlkali(species, n=n, l=l, j=float(j), f=float(f)) + self.states.append(state) + + +class BasisSQDTAlkalineLS(BasisBase[RydbergStateSQDTAlkalineLS]): + def __init__(self, species: str, n_min: int = 0, n_max: int | None = None) -> None: + super().__init__(species) + + if n_max is None: + raise ValueError("n_max must be given") + + i_c = self.species.i_c if self.species.i_c is not None else 0 + + self.states = [] + for n in range(n_min, n_max + 1): + for l in range(n): + for s_tot in [0, 1]: + for j_tot in range(abs(l - s_tot), l + s_tot + 1): + for f_tot in np.arange(abs(j_tot - i_c), j_tot + i_c + 1): + state = RydbergStateSQDTAlkalineLS( + species, n=n, l=l, s_tot=s_tot, j_tot=j_tot, f_tot=float(f_tot) + ) + self.states.append(state) + + +class BasisSQDTAlkalineJJ(BasisBase[RydbergStateSQDTAlkalineJJ]): + def __init__(self, species: str, n_min: int = 0, n_max: int | None = None) -> None: + super().__init__(species) + + if n_max is None: + raise ValueError("n_max must be given") + + i_c = self.species.i_c if self.species.i_c is not None else 0 + j_c = 0.5 + s_r = 0.5 + self.states = [] + for n in range(n_min, n_max + 1): + for l_r in range(n): + for j_r in np.arange(abs(l_r - s_r), l_r + s_r + 1): + for j_tot in range(int(abs(j_r - j_c)), int(j_r + j_c + 1)): + for f_tot in np.arange(abs(j_tot - i_c), j_tot + i_c + 1): + state = RydbergStateSQDTAlkalineJJ( + species, n=n, l=l_r, j_r=float(j_r), j_tot=j_tot, f_tot=float(f_tot) + ) + self.states.append(state) + + +class BasisSQDTAlkalineFJ(BasisBase[RydbergStateSQDTAlkalineFJ]): + def __init__(self, species: str, n_min: int = 0, n_max: int | None = None) -> None: + super().__init__(species) + + if n_max is None: + raise ValueError("n_max must be given") + + i_c = self.species.i_c if self.species.i_c is not None else 0 + j_c = 0.5 + s_r = 0.5 + self.states = [] + for n in range(n_min, n_max + 1): + for l_r in range(n): + for j_r in np.arange(abs(l_r - s_r), l_r + s_r + 1): + for f_c in np.arange(abs(j_c - i_c), j_c + i_c + 1): + for f_tot in np.arange(abs(f_c - j_r), f_c + j_r + 1): + state = RydbergStateSQDTAlkalineFJ( + species, n=n, l=l_r, j_r=float(j_r), f_c=float(f_c), f_tot=float(f_tot) + ) + self.states.append(state) + + +class BasisSQDTAlkalineKS(BasisBase[RydbergStateSQDT]): + def __init__(self, species: str, n_min: int = 0, n_max: int | None = None) -> None: + super().__init__(species) + + if n_max is None: + raise ValueError("n_max must be given") + + i_c = self.species.i_c if self.species.i_c is not None else 0 + j_c = 0.5 + s_r = 0.5 + self.states = [] + for n in range(n_min, n_max + 1): + for l_r in range(n): + for k in np.arange(abs(j_c - l_r), j_c + l_r + 1): + for j_tot in np.arange(abs(k - s_r), k + s_r + 1): + for f_tot in np.arange(abs(j_tot - i_c), j_tot + i_c + 1): + state = RydbergStateSQDT( + species, n=n, l_r=l_r, k=float(k), j_tot=float(j_tot), f_tot=float(f_tot) + ) + self.states.append(state) diff --git a/src/rydstate/radial/__init__.py b/src/rydstate/radial/__init__.py index 4e436b5..3b21898 100644 --- a/src/rydstate/radial/__init__.py +++ b/src/rydstate/radial/__init__.py @@ -1,15 +1,15 @@ from rydstate.radial.grid import Grid from rydstate.radial.model import Model, PotentialType from rydstate.radial.numerov import run_numerov_integration +from rydstate.radial.radial_ket import RadialKet from rydstate.radial.radial_matrix_element import calc_radial_matrix_element_from_w_z -from rydstate.radial.radial_state import RadialState from rydstate.radial.wavefunction import Wavefunction, WavefunctionNumerov, WavefunctionWhittaker __all__ = [ "Grid", "Model", "PotentialType", - "RadialState", + "RadialKet", "Wavefunction", "WavefunctionNumerov", "WavefunctionWhittaker", diff --git a/src/rydstate/radial/radial_state.py b/src/rydstate/radial/radial_ket.py similarity index 94% rename from src/rydstate/radial/radial_state.py rename to src/rydstate/radial/radial_ket.py index ccac5f3..782a137 100644 --- a/src/rydstate/radial/radial_state.py +++ b/src/rydstate/radial/radial_ket.py @@ -21,7 +21,7 @@ logger = logging.getLogger(__name__) -class RadialState: +class RadialKet: r"""Class representing a radial Rydberg state.""" def __init__( @@ -30,7 +30,7 @@ def __init__( nu: float, l_r: int, ) -> None: - r"""Initialize the radial state. + r"""Initialize the radial ket. Args: species: Atomic species. @@ -197,11 +197,11 @@ def create_wavefunction( self._wavefunction.apply_sign_convention(sign_convention) self._grid = self._wavefunction.grid - def calc_overlap(self, other: RadialState, *, integration_method: INTEGRATION_METHODS = "sum") -> float: - r"""Calculate the overlap of two radial states. + def calc_overlap(self, other: RadialKet, *, integration_method: INTEGRATION_METHODS = "sum") -> float: + r"""Calculate the overlap of two radial kets. Args: - other: Other radial state + other: Other radial ket integration_method: Integration method to use Returns: @@ -212,17 +212,17 @@ def calc_overlap(self, other: RadialState, *, integration_method: INTEGRATION_ME @overload def calc_matrix_element( - self, other: RadialState, k_radial: int, *, integration_method: INTEGRATION_METHODS = "sum" + self, other: RadialKet, k_radial: int, *, integration_method: INTEGRATION_METHODS = "sum" ) -> PintFloat: ... @overload def calc_matrix_element( - self, other: RadialState, k_radial: int, unit: str, *, integration_method: INTEGRATION_METHODS = "sum" + self, other: RadialKet, k_radial: int, unit: str, *, integration_method: INTEGRATION_METHODS = "sum" ) -> float: ... def calc_matrix_element( self, - other: RadialState, + other: RadialKet, k_radial: int, unit: str | None = None, *, @@ -241,7 +241,7 @@ def calc_matrix_element( and w(z) = z^{-1/2} \tilde{u}(z^2) = (r/_a_0)^{1/4} \sqrt{a_0} r R(r). Args: - other: Other radial state + other: Other radial ket k_radial: Power of r in the matrix element (default=0, this corresponds to the overlap integral \int dr r^2 R_1(r) R_2(r)) unit: Unit of the returned matrix element, default None returns a Pint quantity. diff --git a/src/rydstate/radial/wavefunction.py b/src/rydstate/radial/wavefunction.py index 2ec3b7b..d7b8109 100644 --- a/src/rydstate/radial/wavefunction.py +++ b/src/rydstate/radial/wavefunction.py @@ -14,7 +14,7 @@ if TYPE_CHECKING: from rydstate.radial import Grid, Model - from rydstate.radial.radial_state import RadialState + from rydstate.radial.radial_ket import RadialKet from rydstate.units import NDArray logger = logging.getLogger(__name__) @@ -27,17 +27,17 @@ class Wavefunction(ABC): def __init__( self, - radial_state: RadialState, + radial_ket: RadialKet, grid: Grid, ) -> None: """Create a Wavefunction object. Args: - radial_state: The RadialState object. + radial_ket: The RadialKet object. grid: The grid object. """ - self.radial_state = radial_state + self.radial_ket = radial_ket self.grid = grid self._w_list: NDArray | None = None @@ -92,8 +92,8 @@ def apply_sign_convention(self, sign_convention: WavefunctionSignConvention) -> break if sign_convention == "n_l_1": - assert self.radial_state.n is not None, "n must be given to apply the n_l_1 sign convention." - if current_outer_sign != (-1) ** (self.radial_state.n - self.radial_state.l_r - 1): + assert self.radial_ket.n is not None, "n must be given to apply the n_l_1 sign convention." + if current_outer_sign != (-1) ** (self.radial_ket.n - self.radial_ket.l_r - 1): self._w_list = -self._w_list elif sign_convention == "positive_at_outer_bound": if current_outer_sign != 1: @@ -115,19 +115,19 @@ class WavefunctionNumerov(Wavefunction): def __init__( self, - radial_state: RadialState, + radial_ket: RadialKet, grid: Grid, model: Model, ) -> None: """Create a Wavefunction object. Args: - radial_state: The RadialState object. + radial_ket: The RadialKet object. grid: The grid object. model: The model object. """ - super().__init__(radial_state, grid) + super().__init__(radial_ket, grid) self.model = model def integrate(self, run_backward: bool = True, w0: float = 1e-10, *, _use_njit: bool = True) -> None: @@ -172,8 +172,8 @@ def integrate(self, run_backward: bool = True, w0: float = 1e-10, *, _use_njit: # and not like in the rest of this class, i.e. y = w(z) and x = z grid = self.grid - species = self.radial_state.species - energy_au = calc_energy_from_nu(species.reduced_mass_au, self.radial_state.nu) + species = self.radial_ket.species + energy_au = calc_energy_from_nu(species.reduced_mass_au, self.radial_ket.nu) v_eff = self.model.calc_total_effective_potential(grid.x_list) glist = 8 * species.reduced_mass_au * grid.z_list * grid.z_list * (energy_au - v_eff) @@ -201,7 +201,7 @@ def integrate(self, run_backward: bool = True, w0: float = 1e-10, *, _use_njit: y0, y1 = 0, w0 x_start, x_stop, dx = grid.z_min, grid.z_max, grid.dz g_list_directed = glist - n = self.radial_state.n if self.radial_state.n is not None else self.radial_state.nu + n = self.radial_ket.n if self.radial_ket.n is not None else self.radial_ket.nu x_min = math.sqrt(n * (n + 15)) if _use_njit: @@ -238,7 +238,7 @@ def sanity_check(self, z_stop: float, run_backward: bool) -> bool: # noqa: C901 warning_msgs: list[str] = [] grid = self.grid - state = self.radial_state + state = self.radial_ket # Check and Correct if divergence of the wavefunction w_list_abs = np.abs(self.w_list) @@ -283,7 +283,7 @@ def sanity_check(self, z_stop: float, run_backward: bool) -> bool: # noqa: C901 elif n <= 16: tol = 2e-3 - species = self.radial_state.species + species = self.radial_ket.species if species.number_valence_electrons == 2: # For divalent atoms the inner boundary is less well behaved ... tol = 2e-2 @@ -338,8 +338,8 @@ def sanity_check(self, z_stop: float, run_backward: bool) -> bool: # noqa: C901 class WavefunctionWhittaker(Wavefunction): def integrate(self) -> None: logger.warning("Using Whittaker to get the wavefunction is not recommended! Use this only for comparison.") - l = self.radial_state.l_r - nu = self.radial_state.nu + l = self.radial_ket.l_r + nu = self.radial_ket.nu whitw_vectorized = np.vectorize(whitw, otypes=[float]) whitw_list = whitw_vectorized(nu, l + 0.5, 2 * self.grid.x_list / nu) diff --git a/src/rydstate/rydberg/__init__.py b/src/rydstate/rydberg/__init__.py new file mode 100644 index 0000000..4d8de57 --- /dev/null +++ b/src/rydstate/rydberg/__init__.py @@ -0,0 +1,17 @@ +from rydstate.rydberg.rydberg_mqdt import RydbergStateMQDT +from rydstate.rydberg.rydberg_sqdt import ( + RydbergStateSQDT, + RydbergStateSQDTAlkali, + RydbergStateSQDTAlkalineFJ, + RydbergStateSQDTAlkalineJJ, + RydbergStateSQDTAlkalineLS, +) + +__all__ = [ + "RydbergStateMQDT", + "RydbergStateSQDT", + "RydbergStateSQDTAlkali", + "RydbergStateSQDTAlkalineFJ", + "RydbergStateSQDTAlkalineJJ", + "RydbergStateSQDTAlkalineLS", +] diff --git a/src/rydstate/rydberg/rydberg_base.py b/src/rydstate/rydberg/rydberg_base.py new file mode 100644 index 0000000..dd243de --- /dev/null +++ b/src/rydstate/rydberg/rydberg_base.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +import logging +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from rydstate.angular import AngularState + from rydstate.angular.angular_ket import AngularKetBase + from rydstate.units import MatrixElementOperator, PintFloat + + +logger = logging.getLogger(__name__) + + +class RydbergStateBase(ABC): + @property + @abstractmethod + def angular(self) -> AngularState[Any] | AngularKetBase: ... + + @abstractmethod + def calc_reduced_overlap(self, other: RydbergStateBase) -> float: ... + + @abstractmethod + def calc_reduced_matrix_element( + self, other: RydbergStateBase, operator: MatrixElementOperator, unit: str | None = None + ) -> PintFloat | float: ... diff --git a/src/rydstate/rydberg/rydberg_mqdt.py b/src/rydstate/rydberg/rydberg_mqdt.py new file mode 100644 index 0000000..4b8202c --- /dev/null +++ b/src/rydstate/rydberg/rydberg_mqdt.py @@ -0,0 +1,124 @@ +from __future__ import annotations + +import logging +from functools import cached_property +from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload + +import numpy as np + +from rydstate.angular import AngularState +from rydstate.rydberg.rydberg_base import RydbergStateBase +from rydstate.rydberg.rydberg_sqdt import RydbergStateSQDT + +if TYPE_CHECKING: + from collections.abc import Iterator, Sequence + + from rydstate.units import MatrixElementOperator, PintFloat + + +logger = logging.getLogger(__name__) + + +_RydbergState = TypeVar("_RydbergState", bound=RydbergStateSQDT) + + +class RydbergStateMQDT(RydbergStateBase, Generic[_RydbergState]): + def __init__( + self, + coefficients: Sequence[float], + sqdt_states: Sequence[_RydbergState], + *, + nu_energy: float | None = None, + warn_if_not_normalized: bool = True, + ) -> None: + self.coefficients = np.array(coefficients) + self.sqdt_states = sqdt_states + self.nu_energy = nu_energy + + if len(coefficients) != len(sqdt_states): + raise ValueError("Length of coefficients and sqdt_states must be the same.") + if not all(type(sqdt_state) is type(sqdt_states[0]) for sqdt_state in sqdt_states): + raise ValueError("All sqdt_states must be of the same type.") + if len(set(sqdt_states)) != len(sqdt_states): + raise ValueError("RydbergStateMQDT initialized with duplicate sqdt_states.") + if abs(self.norm - 1) > 1e-10 and warn_if_not_normalized: + logger.warning( + "RydbergStateMQDT initialized with non-normalized coefficients " + "(norm=%s, coefficients=%s, sqdt_states=%s)", + self.norm, + coefficients, + sqdt_states, + ) + if self.norm > 1: + self.coefficients /= self.norm + + def __iter__(self) -> Iterator[tuple[float, _RydbergState]]: + return zip(self.coefficients, self.sqdt_states).__iter__() + + def __repr__(self) -> str: + terms = [f"{coeff}*{sqdt_state!r}" for coeff, sqdt_state in self] + return f"{self.__class__.__name__}({', '.join(terms)})" + + def __str__(self) -> str: + terms = [f"{coeff}*{sqdt_state!s}" for coeff, sqdt_state in self] + return f"{', '.join(terms)}" + + @property + def norm(self) -> float: + """Return the norm of the state (should be 1).""" + return np.linalg.norm(self.coefficients) # type: ignore [return-value] + + @cached_property + def angular(self) -> AngularState[Any]: + """Return the angular part of the MQDT state as an AngularState.""" + angular_kets = [ket.angular for ket in self.sqdt_states] + return AngularState(self.coefficients.tolist(), angular_kets) + + def calc_reduced_overlap(self, other: RydbergStateBase) -> float: + """Calculate the reduced overlap (ignoring the magnetic quantum number m).""" + if isinstance(other, RydbergStateSQDT): + other = other.to_mqdt() + + if isinstance(other, RydbergStateMQDT): + ov = 0 + for coeff1, sqdt1 in self: + for coeff2, sqdt2 in other: + ov += np.conjugate(coeff1) * coeff2 * sqdt1.calc_reduced_overlap(sqdt2) + return ov + + raise NotImplementedError(f"calc_reduced_overlap not implemented for {type(self)=}, {type(other)=}") + + @overload # type: ignore [override] + def calc_reduced_matrix_element( + self, other: RydbergStateBase, operator: MatrixElementOperator, unit: None = None + ) -> PintFloat: ... + + @overload + def calc_reduced_matrix_element( + self, other: RydbergStateBase, operator: MatrixElementOperator, unit: str + ) -> float: ... + + def calc_reduced_matrix_element( + self, other: RydbergStateBase, operator: MatrixElementOperator, unit: str | None = None + ) -> PintFloat | float: + r"""Calculate the reduced angular matrix element. + + This means, calculate the following matrix element: + + .. math:: + \left\langle self || \hat{O}^{(\kappa)} || other \right\rangle + + """ + if isinstance(other, RydbergStateSQDT): + other = other.to_mqdt() + + if isinstance(other, RydbergStateMQDT): + value = 0 + for coeff1, sqdt1 in self: + for coeff2, sqdt2 in other: + value += ( + np.conjugate(coeff1) * coeff2 * sqdt1.calc_reduced_matrix_element(sqdt2, operator, unit=unit) + ) + return value + + raise NotImplementedError(f"calc_reduced_overlap not implemented for {type(self)=}, {type(other)=}") diff --git a/src/rydstate/rydberg_state.py b/src/rydstate/rydberg/rydberg_sqdt.py similarity index 50% rename from src/rydstate/rydberg_state.py rename to src/rydstate/rydberg/rydberg_sqdt.py index c3eab6e..4d1e2df 100644 --- a/src/rydstate/rydberg_state.py +++ b/src/rydstate/rydberg/rydberg_sqdt.py @@ -2,46 +2,142 @@ import logging import math -from abc import ABC, abstractmethod from functools import cached_property -from typing import TYPE_CHECKING, overload +from typing import TYPE_CHECKING, Any, overload import numpy as np +from typing_extensions import deprecated -from rydstate.angular import AngularKetJJ, AngularKetLS -from rydstate.angular.utils import try_trivial_spin_addition -from rydstate.radial import RadialState -from rydstate.species.species_object import SpeciesObject +from rydstate.angular.angular_ket import quantum_numbers_to_angular_ket +from rydstate.radial import RadialKet +from rydstate.rydberg.rydberg_base import RydbergStateBase +from rydstate.species import SpeciesObject from rydstate.species.utils import calc_energy_from_nu from rydstate.units import BaseQuantities, MatrixElementOperatorRanks, ureg if TYPE_CHECKING: - from typing_extensions import Self - - from rydstate.angular.angular_ket import AngularKetBase + from rydstate import RydbergStateMQDT + from rydstate.angular.angular_ket import AngularKetBase, AngularKetFJ, AngularKetJJ, AngularKetLS from rydstate.units import MatrixElementOperator, PintFloat logger = logging.getLogger(__name__) -class RydbergStateBase(ABC): +class RydbergStateSQDT(RydbergStateBase): species: SpeciesObject + def __init__( + self, + species: str | SpeciesObject, + n: int | None = None, + nu: float | None = None, + s_c: float | None = None, + l_c: int = 0, + j_c: float | None = None, + f_c: float | None = None, + s_r: float = 0.5, + l_r: int | None = None, + j_r: float | None = None, + k: float | None = None, + s_tot: float | None = None, + l_tot: int | None = None, + j_tot: float | None = None, + f_tot: float | None = None, + m: float | None = None, + ) -> None: + r"""Initialize the Rydberg state. + + Args: + species: Atomic species. + n: Principal quantum number of the rydberg electron. + nu: Effective principal quantum number of the rydberg electron. + Optional, if not given it will be calculated from n, l, j_tot, s_tot. + s_c: Spin quantum number of the core electron (0 for Alkali, 0.5 for divalent atoms). + l_c: Orbital angular momentum quantum number of the core electron. + j_c: Total angular momentum quantum number of the core electron. + f_c: Total angular momentum quantum number of the core (core electron + nucleus). + s_r: Spin quantum number of the rydberg electron always 0.5) + l_r: Orbital angular momentum quantum number of the rydberg electron. + j_r: Total angular momentum quantum number of the rydberg electron. + k: Intermediate angular momentum (j_c + l_r). + s_tot: Total spin quantum number of all electrons. + l_tot: Total orbital angular momentum quantum number of all electrons. + j_tot: Total angular momentum quantum number of all electrons. + f_tot: Total angular momentum quantum number of the atom (rydberg electron + core) + m: Total magnetic quantum number. + Optional, only needed for concrete angular matrix elements. + + """ + if isinstance(species, str): + species = SpeciesObject.from_name(species) + self.species = species + + self._qns = dict( # noqa: C408 + s_c=s_c, + l_c=l_c, + j_c=j_c, + f_c=f_c, + s_r=s_r, + l_r=l_r, + j_r=j_r, + k=k, + s_tot=s_tot, + l_tot=l_tot, + j_tot=j_tot, + f_tot=f_tot, + m=m, + ) + + self.n = n + self._nu = nu + if nu is None and n is None: + raise ValueError("Either n or nu must be given to initialize the Rydberg state.") + + def __repr__(self) -> str: + species, n, nu = self.species.name, self.n, self.nu + n_str = f", {n=}" if n is not None else "" + return f"{self.__class__.__name__}({species=}{n_str}, {nu=}, {self.angular})" + def __str__(self) -> str: return self.__repr__() - @property - @abstractmethod - def radial(self) -> RadialState: ... + @cached_property + def radial(self) -> RadialKet: + """The radial part of the Rydberg electron.""" + radial_ket = RadialKet(self.species, nu=self.nu, l_r=self.angular.l_r) + if self.n is not None: + radial_ket.set_n_for_sanity_check(self.n) + s_tot_list = [self.angular.get_qn("s_tot")] if "s_tot" in self.angular.quantum_number_names else [0, 1] + for s_tot in s_tot_list: + if not self.species.is_allowed_shell(self.n, self.angular.l_r, s_tot=s_tot): + raise ValueError( + f"The shell (n={self.n}, l_r={self.angular.l_r}, s_tot={s_tot})" + f" is not allowed for the species {self.species}." + ) + return radial_ket + + @cached_property + def angular(self) -> AngularKetBase: + """The angular/spin part of the Rydberg electron.""" + return quantum_numbers_to_angular_ket(species=self.species, **self._qns) # type: ignore [arg-type] - @property - @abstractmethod - def angular(self) -> AngularKetBase: ... + @cached_property + def nu(self) -> float: + """The effective principal quantum number nu (for alkali atoms also known as n*) for the Rydberg state.""" + if self._nu is not None: + return self._nu + assert self.n is not None + if any(qn not in self.angular.quantum_number_names for qn in ["j_tot", "s_tot"]): + raise ValueError("j_tot and s_tot must be defined to calculate nu from n.") + return self.species.calc_nu( + self.n, self.angular.l_r, self.angular.get_qn("j_tot"), s_tot=self.angular.get_qn("s_tot") + ) - @abstractmethod + @deprecated("Use the property nu instead.") def get_nu(self) -> float: """Get the effective principal quantum number nu (for alkali atoms also known as n*) for the Rydberg state.""" + return self.nu @overload def get_energy(self, unit: None = None) -> PintFloat: ... @@ -59,8 +155,7 @@ def get_energy(self, unit: str | None = None) -> PintFloat | float: where `\mu = R_M/R_\infty` is the reduced mass and `\nu` the effective principal quantum number. """ - nu = self.get_nu() - energy_au = calc_energy_from_nu(self.species.reduced_mass_au, nu) + energy_au = calc_energy_from_nu(self.species.reduced_mass_au, self.nu) if unit == "a.u.": return energy_au energy: PintFloat = energy_au * BaseQuantities["energy"] @@ -68,22 +163,33 @@ def get_energy(self, unit: str | None = None) -> PintFloat | float: return energy return energy.to(unit, "spectroscopy").magnitude + def to_mqdt(self) -> RydbergStateMQDT[Any]: + """Convert to a trivial RydbergMQDT state with only one contribution with coefficient 1.""" + from rydstate import RydbergStateMQDT # noqa: PLC0415 + + return RydbergStateMQDT([1], [self]) + def calc_reduced_overlap(self, other: RydbergStateBase) -> float: """Calculate the reduced overlap (ignoring the magnetic quantum number m).""" + if not isinstance(other, RydbergStateSQDT): + return self.to_mqdt().calc_reduced_overlap(other) + radial_overlap = self.radial.calc_overlap(other.radial) angular_overlap = self.angular.calc_reduced_overlap(other.angular) return radial_overlap * angular_overlap - @overload + @overload # type: ignore [override] def calc_reduced_matrix_element( - self, other: Self, operator: MatrixElementOperator, unit: None = None + self, other: RydbergStateBase, operator: MatrixElementOperator, unit: None = None ) -> PintFloat: ... @overload - def calc_reduced_matrix_element(self, other: Self, operator: MatrixElementOperator, unit: str) -> float: ... + def calc_reduced_matrix_element( + self, other: RydbergStateBase, operator: MatrixElementOperator, unit: str + ) -> float: ... def calc_reduced_matrix_element( - self, other: Self, operator: MatrixElementOperator, unit: str | None = None + self, other: RydbergStateBase, operator: MatrixElementOperator, unit: str | None = None ) -> PintFloat | float: r"""Calculate the reduced matrix element. @@ -106,6 +212,9 @@ def calc_reduced_matrix_element( The reduced matrix element for the given operator. """ + if not isinstance(other, RydbergStateSQDT): + return self.to_mqdt().calc_reduced_matrix_element(other, operator, unit=unit) + if operator not in MatrixElementOperatorRanks: raise ValueError( f"Operator {operator} not supported, must be one of {list(MatrixElementOperatorRanks.keys())}." @@ -148,13 +257,15 @@ def calc_reduced_matrix_element( return matrix_element.to(unit).magnitude @overload - def calc_matrix_element(self, other: Self, operator: MatrixElementOperator, q: int) -> PintFloat: ... + def calc_matrix_element(self, other: RydbergStateSQDT, operator: MatrixElementOperator, q: int) -> PintFloat: ... @overload - def calc_matrix_element(self, other: Self, operator: MatrixElementOperator, q: int, unit: str) -> float: ... + def calc_matrix_element( + self, other: RydbergStateSQDT, operator: MatrixElementOperator, q: int, unit: str + ) -> float: ... def calc_matrix_element( - self, other: Self, operator: MatrixElementOperator, q: int, unit: str | None = None + self, other: RydbergStateSQDT, operator: MatrixElementOperator, q: int, unit: str | None = None ) -> PintFloat | float: r"""Calculate the matrix element. @@ -185,9 +296,11 @@ def calc_matrix_element( return prefactor * reduced_matrix_element -class RydbergStateAlkali(RydbergStateBase): +class RydbergStateSQDTAlkali(RydbergStateSQDT): """Create an Alkali Rydberg state, including the radial and angular states.""" + angular: AngularKetLS + def __init__( self, species: str | SpeciesObject, @@ -196,6 +309,7 @@ def __init__( j: float | None = None, f: float | None = None, m: float | None = None, + nu: float | None = None, ) -> None: r"""Initialize the Rydberg state. @@ -208,46 +322,27 @@ def __init__( Optional, only needed if the species supports hyperfine structure (i.e. species.i_c is not None or 0). m: Total magnetic quantum number. Optional, only needed for concrete angular matrix elements. + nu: Effective principal quantum number of the rydberg electron. + Optional, if not given it will be calculated from n, l, j. """ - if isinstance(species, str): - species = SpeciesObject.from_name(species) - self.species = species - i_c = species.i_c if species.i_c is not None else 0 - self.n = n + super().__init__(species=species, n=n, nu=nu, l_r=l, j_tot=j, f_tot=f, m=m) + self.l = l - self.j = try_trivial_spin_addition(l, 0.5, j, "j") - self.f = try_trivial_spin_addition(self.j, i_c, f, "f") + self.j = self.angular.j_tot + self.f = self.angular.f_tot self.m = m - if species.number_valence_electrons != 1: - raise ValueError(f"The species {species.name} is not an alkali atom.") - if not species.is_allowed_shell(n, l): - raise ValueError(f"The shell ({n=}, {l=}) is not allowed for the species {self.species}.") - - @cached_property - def angular(self) -> AngularKetLS: - """The angular/spin state of the Rydberg electron.""" - return AngularKetLS(l_r=self.l, j_tot=self.j, m=self.m, f_tot=self.f, species=self.species) - - @cached_property - def radial(self) -> RadialState: - """The radial state of the Rydberg electron.""" - radial_state = RadialState(self.species, nu=self.get_nu(), l_r=self.l) - radial_state.set_n_for_sanity_check(self.n) - return radial_state - def __repr__(self) -> str: species, n, l, j, f, m = self.species, self.n, self.l, self.j, self.f, self.m return f"{self.__class__.__name__}({species.name}, {n=}, {l=}, {j=}, {f=}, {m=})" - def get_nu(self) -> float: - return self.species.calc_nu(self.n, self.l, self.j, s_tot=1 / 2) - -class RydbergStateAlkalineLS(RydbergStateBase): +class RydbergStateSQDTAlkalineLS(RydbergStateSQDT): """Create an Alkaline Rydberg state, including the radial and angular states.""" + angular: AngularKetLS + def __init__( self, species: str | SpeciesObject, @@ -257,6 +352,7 @@ def __init__( j_tot: int | None = None, f_tot: float | None = None, m: float | None = None, + nu: float | None = None, ) -> None: r"""Initialize the Rydberg state. @@ -270,49 +366,28 @@ def __init__( Optional, only needed if the species supports hyperfine structure (i.e. species.i_c is not None or 0). m: Total magnetic quantum number. Optional, only needed for concrete angular matrix elements. + nu: Effective principal quantum number of the rydberg electron. + Optional, if not given it will be calculated from n, l, j_tot, s_tot. """ - if isinstance(species, str): - species = SpeciesObject.from_name(species) - self.species = species - i_c = species.i_c if species.i_c is not None else 0 - self.n = n + super().__init__(species=species, n=n, nu=nu, l_r=l, s_tot=s_tot, j_tot=j_tot, f_tot=f_tot, m=m) + self.l = l - self.s_tot = s_tot - self.j_tot = try_trivial_spin_addition(l, s_tot, j_tot, "j_tot") - self.f_tot = try_trivial_spin_addition(self.j_tot, i_c, f_tot, "f_tot") + self.s_tot = self.angular.s_tot + self.j_tot = self.angular.j_tot + self.f_tot = self.angular.f_tot self.m = m - if species.number_valence_electrons != 2: - raise ValueError(f"The species {species.name} is not an alkaline atom.") - if not species.is_allowed_shell(n, l, s_tot=s_tot): - raise ValueError(f"The shell ({n=}, {l=}) is not allowed for the species {self.species}.") - - @cached_property - def angular(self) -> AngularKetLS: - """The angular/spin state of the Rydberg electron.""" - return AngularKetLS( - l_r=self.l, s_tot=self.s_tot, j_tot=self.j_tot, f_tot=self.f_tot, m=self.m, species=self.species - ) - - @cached_property - def radial(self) -> RadialState: - """The radial state of the Rydberg electron.""" - radial_state = RadialState(self.species, nu=self.get_nu(), l_r=self.l) - radial_state.set_n_for_sanity_check(self.n) - return radial_state - def __repr__(self) -> str: species, n, l, s_tot, j_tot, f_tot, m = self.species, self.n, self.l, self.s_tot, self.j_tot, self.f_tot, self.m return f"{self.__class__.__name__}({species.name}, {n=}, {l=}, {s_tot=}, {j_tot=}, {f_tot=}, {m=})" - def get_nu(self) -> float: - return self.species.calc_nu(self.n, self.l, self.j_tot, s_tot=self.s_tot) - -class RydbergStateAlkalineJJ(RydbergStateBase): +class RydbergStateSQDTAlkalineJJ(RydbergStateSQDT): """Create an Alkaline Rydberg state, including the radial and angular states.""" + angular: AngularKetJJ + def __init__( self, species: str | SpeciesObject, @@ -322,6 +397,7 @@ def __init__( j_tot: int | None = None, f_tot: float | None = None, m: float | None = None, + nu: float | None = None, ) -> None: r"""Initialize the Rydberg state. @@ -336,50 +412,95 @@ def __init__( (i.e. species.i_c is not None and species.i_c != 0). m: Total magnetic quantum number. Optional, only needed for concrete angular matrix elements. + nu: Effective principal quantum number of the rydberg electron. + Optional, if not given it will be calculated from n, l, j_tot. """ - if isinstance(species, str): - species = SpeciesObject.from_name(species) - self.species = species - s_r, s_c = 1 / 2, 1 / 2 - i_c = species.i_c if species.i_c is not None else 0 - self.n = n - self.l = l - self.j_r = try_trivial_spin_addition(l, s_r, j_r, "j_r") - self.j_tot = try_trivial_spin_addition(self.j_r, s_c, j_tot, "j_tot") - self.f_tot = try_trivial_spin_addition(self.j_tot, i_c, f_tot, "f_tot") - self.m = m + super().__init__(species=species, n=n, nu=nu, l_r=l, j_r=j_r, j_tot=j_tot, f_tot=f_tot, m=m) - if species.number_valence_electrons != 2: - raise ValueError(f"The species {species.name} is not an alkaline atom.") - for s_tot in [0, 1]: - if not species.is_allowed_shell(n, l, s_tot=s_tot): - raise ValueError(f"The shell ({n=}, {l=}) is not allowed for the species {self.species}.") + self.l = self.angular.l_r + self.j_r = self.angular.j_r + self.j_tot = self.angular.j_tot + self.f_tot = self.angular.f_tot + self.m = self.angular.m - @cached_property - def angular(self) -> AngularKetJJ: - """The angular/spin state of the Rydberg electron.""" - return AngularKetJJ( - l_r=self.l, j_r=self.j_r, j_tot=self.j_tot, f_tot=self.f_tot, m=self.m, species=self.species - ) + def __repr__(self) -> str: + species, n, l, j_r, j_tot, f_tot, m = self.species, self.n, self.l, self.j_r, self.j_tot, self.f_tot, self.m + return f"{self.__class__.__name__}({species.name}, {n=}, {l=}, {j_r=}, {j_tot=}, {f_tot=}, {m=})" @cached_property - def radial(self) -> RadialState: - """The radial state of the Rydberg electron.""" - radial_state = RadialState(self.species, nu=self.get_nu(), l_r=self.l) - radial_state.set_n_for_sanity_check(self.n) - return radial_state + def nu(self) -> float: + if self._nu is not None: + return self._nu + assert self.n is not None + nus = [self.species.calc_nu(self.n, self.l, self.j_tot, s_tot=s_tot) for s_tot in [0, 1]] + + if any(abs(nu - nus[0]) > 1e-10 for nu in nus[1:]): + raise ValueError( + "RydbergStateSQDTAlkalineJJ is intended for high-l states only, " + "where the quantum defects are the same for singlet and triplet states." + ) + return nus[0] + + +class RydbergStateSQDTAlkalineFJ(RydbergStateSQDT): + """Create an Alkaline Rydberg state, including the radial and angular states.""" + + angular: AngularKetFJ + + def __init__( + self, + species: str | SpeciesObject, + n: int, + l: int, + j_r: float, + f_c: float | None = None, + f_tot: float | None = None, + m: float | None = None, + nu: float | None = None, + ) -> None: + r"""Initialize the Rydberg state. + + Args: + species: Atomic species. + n: Principal quantum number of the rydberg electron. + l: Orbital angular momentum quantum number of the rydberg electron. + j_r: Total angular momentum quantum number of the Rydberg electron. + f_c: Total angular momentum quantum number of the core (core electron + nucleus). + f_tot: Total angular momentum quantum number of the atom (rydberg electron + core) + Optional, only needed if the species supports hyperfine structure (i.e. species.i_c is not None or 0). + m: Total magnetic quantum number. + Optional, only needed for concrete angular matrix elements. + nu: Effective principal quantum number of the rydberg electron. + Optional, if not given it will be calculated from n, l. + + """ + super().__init__(species=species, n=n, nu=nu, l_r=l, j_r=j_r, f_c=f_c, f_tot=f_tot, m=m) + + self.l = self.angular.l_r + self.j_r = self.angular.j_r + self.f_c = self.angular.f_c + self.f_tot = self.angular.f_tot + self.m = self.angular.m def __repr__(self) -> str: - species, n, l, j_r, j_tot, f_tot, m = self.species, self.n, self.l, self.j_r, self.j_tot, self.f_tot, self.m - return f"{self.__class__.__name__}({species.name}, {n=}, {l=}, {j_r=}, {j_tot=}, {f_tot=}, {m=})" + species, n, l, j_r, f_c, f_tot, m = self.species, self.n, self.l, self.j_r, self.f_c, self.f_tot, self.m + return f"{self.__class__.__name__}({species.name}, {n=}, {l=}, {j_r=}, {f_c=}, {f_tot=}, {m=})" - def get_nu(self) -> float: - nu_singlet = self.species.calc_nu(self.n, self.l, self.j_tot, s_tot=0) - nu_triplet = self.species.calc_nu(self.n, self.l, self.j_tot, s_tot=1) - if abs(nu_singlet - nu_triplet) > 1e-10: + @cached_property + def nu(self) -> float: + if self._nu is not None: + return self._nu + assert self.n is not None + nus = [ + self.species.calc_nu(self.n, self.l, float(j_tot), s_tot=s_tot) + for s_tot in [0, 1] + for j_tot in np.arange(abs(self.j_r - 1 / 2), self.j_r + 1 / 2 + 1) + ] + + if any(abs(nu - nus[0]) > 1e-10 for nu in nus[1:]): raise ValueError( - "RydbergStateAlkalineJJ is intended for high-l states only, " + "RydbergStateSQDTAlkalineFJ is intended for high-l states only, " "where the quantum defects are the same for singlet and triplet states." ) - return nu_singlet + return nus[0] diff --git a/src/rydstate/species/species_object.py b/src/rydstate/species/species_object.py index 5ac1e01..71b9928 100644 --- a/src/rydstate/species/species_object.py +++ b/src/rydstate/species/species_object.py @@ -178,7 +178,7 @@ def from_name(cls, name: str) -> SpeciesObject: This approach allows for easy extension of the library with new species. A user can even subclass SpeciesObject in his code (without modifying the rydstate library), e.g. `class CustomRubidium(SpeciesObject): name = "Custom_Rb" ...` - and then use the new species by calling RydbergStateAlkali("Custom_Rb", ...) + and then use the new species by calling RydbergStateSQDTAlkali("Custom_Rb", ...) Args: name: The species name (e.g. "Rb"). diff --git a/src/rydstate/species/strontium.py b/src/rydstate/species/strontium.py index c20a43f..9fed907 100644 --- a/src/rydstate/species/strontium.py +++ b/src/rydstate/species/strontium.py @@ -34,31 +34,7 @@ class _StrontiumAbstract(SpeciesObject): # https://iopscience.iop.org/article/10.1088/1674-1056/18/10/025 model_potential_parameter_fei_2009 = (0.9959, 16.9567, 0.2648, 0.1439) - -class Strontium87(_StrontiumAbstract): - name = "Sr87" - i_c = 9 / 2 - - # https://physics.nist.gov/PhysRefData/Handbook/Tables/strontiumtable1.htm - _isotope_mass_u = 86.908884 # u - _corrected_rydberg_constant = ( - rydberg_constant.m / (1 + electron_mass.to("u").m / _isotope_mass_u), - None, - str(rydberg_constant.u), - ) - - -class Strontium88(_StrontiumAbstract): - name = "Sr88" - i_c = 0 - - # https://physics.nist.gov/PhysRefData/Handbook/Tables/strontiumtable1.htm - _isotope_mass = 87.905619 # u - _corrected_rydberg_constant = ( - rydberg_constant.m / (1 + electron_mass.to("u").m / _isotope_mass), - None, - str(rydberg_constant.u), - ) + # TODO add isotope specific quantum defects # -- [1] Brienza 2023, Phys. Rev. A 108, 022815 # Microwave spectroscopy of low-l singlet strontium Rydberg states at intermediate n @@ -94,3 +70,29 @@ class Strontium88(_StrontiumAbstract): (3, 3.0, 1): (0.119, -2.0, 100, 0.0, 0.0), # [3] (3, 4.0, 1): (0.120, -2.4, 120, 0.0, 0.0), # [3] } + + +class Strontium87(_StrontiumAbstract): + name = "Sr87" + i_c = 9 / 2 + + # https://physics.nist.gov/PhysRefData/Handbook/Tables/strontiumtable1.htm + _isotope_mass_u = 86.908884 # u + _corrected_rydberg_constant = ( + rydberg_constant.m / (1 + electron_mass.to("u").m / _isotope_mass_u), + None, + str(rydberg_constant.u), + ) + + +class Strontium88(_StrontiumAbstract): + name = "Sr88" + i_c = 0 + + # https://physics.nist.gov/PhysRefData/Handbook/Tables/strontiumtable1.htm + _isotope_mass = 87.905619 # u + _corrected_rydberg_constant = ( + rydberg_constant.m / (1 + electron_mass.to("u").m / _isotope_mass), + None, + str(rydberg_constant.u), + ) diff --git a/src/rydstate/species/ytterbium.py b/src/rydstate/species/ytterbium.py index 6c0a37a..0460b05 100644 --- a/src/rydstate/species/ytterbium.py +++ b/src/rydstate/species/ytterbium.py @@ -24,44 +24,7 @@ class _YtterbiumAbstract(SpeciesObject): # https://iopscience.iop.org/article/10.1088/1674-1056/18/10/025 model_potential_parameter_fei_2009 = (0.8704, 22.0040, 0.1513, 0.3306) - -class Ytterbium171(_YtterbiumAbstract): - name = "Yb171" - i_c = 1 / 2 - - # https://physics.nist.gov/PhysRefData/Handbook/Tables/ytterbiumtable1.htm - _isotope_mass = 170.936323 # u - _corrected_rydberg_constant = ( - rydberg_constant.m / (1 + electron_mass.to("u").m / _isotope_mass), - None, - str(rydberg_constant.u), - ) - - -class Ytterbium173(_YtterbiumAbstract): - name = "Yb173" - i_c = 5 / 2 - - # https://physics.nist.gov/PhysRefData/Handbook/Tables/ytterbiumtable1.htm - _isotope_mass = 172.938208 # u - _corrected_rydberg_constant = ( - rydberg_constant.m / (1 + electron_mass.to("u").m / _isotope_mass), - None, - str(rydberg_constant.u), - ) - - -class Ytterbium174(_YtterbiumAbstract): - name = "Yb174" - i_c = 0 - - # https://physics.nist.gov/PhysRefData/Handbook/Tables/ytterbiumtable1.htm - _isotope_mass = 173.938859 # u - _corrected_rydberg_constant = ( - rydberg_constant.m / (1 + electron_mass.to("u").m / _isotope_mass), - None, - str(rydberg_constant.u), - ) + # TODO add isotope specific quantum defects # -- [1] Peper 2024, http://arxiv.org/abs/2406.01482 # Spectroscopy and modeling of 171Yb Rydberg states for high-fidelity two-qubit gates @@ -100,3 +63,42 @@ class Ytterbium174(_YtterbiumAbstract): # (4, 4.0, "+"): (0.0262659964, 0.0254568575, 0.0, 0.0, 0.0), # [3] S8 # (4, 4.0, "-"): (-0.148808463, -0.134219071, 0.0, 0.0, 0.0), # [3] S8 } + + +class Ytterbium171(_YtterbiumAbstract): + name = "Yb171" + i_c = 1 / 2 + + # https://physics.nist.gov/PhysRefData/Handbook/Tables/ytterbiumtable1.htm + _isotope_mass = 170.936323 # u + _corrected_rydberg_constant = ( + rydberg_constant.m / (1 + electron_mass.to("u").m / _isotope_mass), + None, + str(rydberg_constant.u), + ) + + +class Ytterbium173(_YtterbiumAbstract): + name = "Yb173" + i_c = 5 / 2 + + # https://physics.nist.gov/PhysRefData/Handbook/Tables/ytterbiumtable1.htm + _isotope_mass = 172.938208 # u + _corrected_rydberg_constant = ( + rydberg_constant.m / (1 + electron_mass.to("u").m / _isotope_mass), + None, + str(rydberg_constant.u), + ) + + +class Ytterbium174(_YtterbiumAbstract): + name = "Yb174" + i_c = 0 + + # https://physics.nist.gov/PhysRefData/Handbook/Tables/ytterbiumtable1.htm + _isotope_mass = 173.938859 # u + _corrected_rydberg_constant = ( + rydberg_constant.m / (1 + electron_mass.to("u").m / _isotope_mass), + None, + str(rydberg_constant.u), + ) diff --git a/tests/test_all_elements.py b/tests/test_all_elements.py index 8fb56b2..6321b05 100644 --- a/tests/test_all_elements.py +++ b/tests/test_all_elements.py @@ -1,28 +1,26 @@ from typing import TYPE_CHECKING import pytest -from rydstate.rydberg_state import RydbergStateAlkali, RydbergStateAlkalineLS +from rydstate import RydbergStateSQDTAlkali, RydbergStateSQDTAlkalineLS from rydstate.species import SpeciesObject if TYPE_CHECKING: - from rydstate.rydberg_state import RydbergStateBase + from rydstate import RydbergStateSQDT @pytest.mark.parametrize("species_name", SpeciesObject.get_available_species()) def test_magnetic(species_name: str) -> None: """Test magnetic units.""" species = SpeciesObject.from_name(species_name) + i_c = species.i_c if species.i_c is not None else 0 - state: RydbergStateBase + state: RydbergStateSQDT if species.number_valence_electrons == 1: - if species.i_c is None: - state = RydbergStateAlkali(species, n=50, l=0) - else: - state = RydbergStateAlkali(species, n=50, l=0, f=species.i_c + 0.5) + state = RydbergStateSQDTAlkali(species, n=50, l=0, f=i_c + 0.5) state.radial.create_wavefunction() - with pytest.raises(ValueError, match="j must be set"): - RydbergStateAlkali(species, n=50, l=1) + with pytest.raises(ValueError, match="j_tot must be set"): + RydbergStateSQDTAlkali(species, n=50, l=1) elif species.number_valence_electrons == 2 and species._quantum_defects is not None: # noqa: SLF001 for s_tot in [0, 1]: - state = RydbergStateAlkalineLS(species, n=50, l=1, s_tot=s_tot, j_tot=1 + s_tot) + state = RydbergStateSQDTAlkalineLS(species, n=50, l=1, s_tot=s_tot, j_tot=1 + s_tot, f_tot=s_tot + 1 + i_c) state.radial.create_wavefunction() diff --git a/tests/test_angular_matrix_elements.py b/tests/test_angular_matrix_elements.py index ea6e2e8..35a7976 100644 --- a/tests/test_angular_matrix_elements.py +++ b/tests/test_angular_matrix_elements.py @@ -4,59 +4,54 @@ import numpy as np import pytest -from rydstate.angular import AngularKetFJ, AngularKetJJ, AngularKetLS +from rydstate.angular import AngularKetFJ, AngularKetJJ, AngularKetKS, AngularKetLS +from rydstate.angular.angular_ket import CouplingScheme from rydstate.angular.angular_matrix_element import AngularMomentumQuantumNumbers if TYPE_CHECKING: - from rydstate.angular.angular_ket import AngularKetBase, CouplingScheme + from rydstate.angular.angular_ket import AngularKetBase from rydstate.angular.angular_matrix_element import AngularOperatorType -TEST_KET_PAIRS = [ - ( - AngularKetLS(s_tot=1, l_r=1, j_tot=1, f_tot=1.5, species="Yb173"), - AngularKetLS(s_tot=1, l_r=1, j_tot=1, f_tot=1.5, species="Yb173"), - ), +TEST_KET_PAIRS: list[tuple[AngularKetBase, AngularKetBase]] = [ ( AngularKetLS(s_tot=1, l_r=1, j_tot=1, f_tot=1.5, species="Yb173"), AngularKetFJ(f_c=2, l_r=1, j_r=1.5, f_tot=2.5, species="Yb173"), ), - ( - AngularKetFJ(f_c=2, l_r=1, j_r=1.5, f_tot=2.5, species="Yb173"), - AngularKetFJ(f_c=2, l_r=1, j_r=1.5, f_tot=2.5, species="Yb173"), - ), ( AngularKetFJ(i_c=2.5, s_c=0.5, l_c=0, s_r=0.5, l_r=1, j_c=0.5, f_c=2.0, j_r=1.5, f_tot=2.5), AngularKetFJ(i_c=2.5, s_c=0.5, l_c=0, s_r=0.5, l_r=2, j_c=0.5, f_c=2.0, j_r=1.5, f_tot=2.5), ), ] -TEST_KETS = [ - AngularKetLS(s_tot=1, l_r=1, j_tot=1, f_tot=1.5, species="Yb173"), +TEST_KETS: list[AngularKetBase] = [ AngularKetLS(s_tot=1, l_r=1, j_tot=1, f_tot=1.5, species="Yb173"), AngularKetJJ(l_r=1, j_r=1.5, j_tot=2, f_tot=2.5, species="Yb173"), AngularKetFJ(f_c=2, l_r=1, j_r=1.5, f_tot=2.5, species="Yb173"), - AngularKetFJ(f_c=2, l_r=1, j_r=1.5, f_tot=2.5, species="Yb173"), + AngularKetLS(s_tot=0, l_r=1, j_tot=1, f_tot=1.5, species="Yb173"), + AngularKetJJ(l_r=1, j_r=1.5, j_tot=1, f_tot=1.5, species="Yb173"), + AngularKetKS(l_r=1, k=1.5, j_tot=1, f_tot=1.5, species="Yb173"), ] +TEST_KET_PAIRS += [(ket, ket) for ket in TEST_KETS] + @pytest.mark.parametrize("ket", TEST_KETS) def test_exp_q_different_coupling_schemes(ket: AngularKetBase) -> None: all_qns: tuple[AngularMomentumQuantumNumbers, ...] = get_args(AngularMomentumQuantumNumbers) + coupling_schemes: list[CouplingScheme] = get_args(CouplingScheme) # type: ignore [assignment] for q in all_qns: exp_q = ket.to_state("LS").calc_exp_qn(q) - assert np.isclose(exp_q, ket.to_state("JJ").calc_exp_qn(q)) - assert np.isclose(exp_q, ket.to_state("FJ").calc_exp_qn(q)) - std_q = ket.to_state("LS").calc_std_qn(q) - assert np.isclose(std_q, ket.to_state("JJ").calc_std_qn(q)) - assert np.isclose(std_q, ket.to_state("FJ").calc_std_qn(q)) + for scheme in coupling_schemes: + assert np.isclose(exp_q, ket.to_state(scheme).calc_exp_qn(q)) + assert np.isclose(std_q, ket.to_state(scheme).calc_std_qn(q)) @pytest.mark.parametrize(("ket1", "ket2"), TEST_KET_PAIRS) def test_overlap_different_coupling_schemes(ket1: AngularKetBase, ket2: AngularKetBase) -> None: ov = ket1.calc_reduced_overlap(ket2) - coupling_schemes: list[CouplingScheme] = ["LS", "JJ", "FJ"] + coupling_schemes: list[CouplingScheme] = get_args(CouplingScheme) # type: ignore [assignment] for scheme in coupling_schemes: assert np.isclose(ov, ket1.to_state().calc_reduced_overlap(ket2.to_state(scheme))) assert np.isclose(ov, ket1.to_state(scheme).calc_reduced_overlap(ket2)) @@ -69,7 +64,7 @@ def test_reduced_identity(ket: AngularKetBase) -> None: reduced_identity = np.sqrt(2 * ket.f_tot + 1) op: AngularMomentumQuantumNumbers - coupling_schemes: list[CouplingScheme] = ["LS", "JJ", "FJ"] + coupling_schemes: list[CouplingScheme] = get_args(CouplingScheme) # type: ignore [assignment] for scheme in coupling_schemes: state = ket.to_state(scheme) for op in state.kets[0].quantum_number_names: @@ -88,13 +83,14 @@ def test_matrix_elements_in_different_coupling_schemes(ket1: AngularKetBase, ket ("i_c", 1), ("f_tot", 1), ("j_tot", 1), + ("k", 1), ] - coupling_schemes: list[CouplingScheme] = ["LS", "JJ", "FJ"] + coupling_schemes: list[CouplingScheme] = get_args(CouplingScheme) # type: ignore [assignment] - for scheme in coupling_schemes: - for operator, kappa in example_list: + for operator, kappa in example_list: + val = ket1.calc_reduced_matrix_element(ket2, operator, kappa) + for scheme in coupling_schemes: msg = f"{operator=}, {kappa=}, {ket1=}, {ket2=}, {scheme=}" - val = ket1.calc_reduced_matrix_element(ket2, operator, kappa) assert np.isclose( val, ket1.to_state().calc_reduced_matrix_element(ket2.to_state(scheme), operator, kappa) diff --git a/tests/test_hydrogen.py b/tests/test_hydrogen.py index a28ee13..9ea6863 100644 --- a/tests/test_hydrogen.py +++ b/tests/test_hydrogen.py @@ -1,6 +1,6 @@ import numpy as np import pytest -from rydstate.rydberg_state import RydbergStateAlkali +from rydstate import RydbergStateSQDTAlkali from sympy.abc import r as sympy_r from sympy.physics import hydrogen as sympy_hydrogen from sympy.utilities.lambdify import lambdify @@ -28,7 +28,7 @@ def test_hydrogen_wavefunctions(species: str, n: int, l: int, run_backward: bool) -> None: """Test that numerov integration matches sympy's analytical hydrogen wavefunctions.""" # Setup atom - state = RydbergStateAlkali(species, n=n, l=l, j=l + 0.5) + state = RydbergStateSQDTAlkali(species, n=n, l=l, j=l + 0.5) # Run the numerov integration state.radial.create_wavefunction("numerov", run_backward=run_backward, sign_convention="n_l_1") diff --git a/tests/test_matrix_elements.py b/tests/test_matrix_elements.py index ed3c37d..2f98c57 100644 --- a/tests/test_matrix_elements.py +++ b/tests/test_matrix_elements.py @@ -1,6 +1,6 @@ import numpy as np import pytest -from rydstate.rydberg_state import RydbergStateAlkali +from rydstate import RydbergStateSQDTAlkali from rydstate.units import BaseUnits, ureg @@ -10,7 +10,7 @@ def test_magnetic(l: int) -> None: g_s = 2.002319304363 g_l = 1 - state = RydbergStateAlkali("Rb", n=max(l + 1, 10), l=l, j=l + 0.5, m=l + 0.5) + state = RydbergStateSQDTAlkali("Rb", n=max(l + 1, 10), l=l, j=l + 0.5, m=l + 0.5) # Check that for m = j_tot = l + s_tot the magnetic matrix element is - mu_B * (g_l * l + g_s * s_tot) mu = state.calc_matrix_element(state, "magnetic_dipole", q=0) diff --git a/tests/test_radial_matrix_element.py b/tests/test_radial_matrix_element.py index 709f0aa..74d9b70 100644 --- a/tests/test_radial_matrix_element.py +++ b/tests/test_radial_matrix_element.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from rydstate.radial import RadialState -from rydstate.rydberg_state import RydbergStateAlkali +from rydstate import RydbergStateSQDTAlkali +from rydstate.radial import RadialKet from rydstate.species import SpeciesObject @@ -28,8 +28,8 @@ def test_circular_matrix_element(species: str, n: int, dn: int, dl: int) -> None matrix_element = {} for _species in [species, "H_textbook"]: - state_i = RydbergStateAlkali(_species, n=n, l=l1, j=l1 + 0.5) - state_f = RydbergStateAlkali(_species, n=n + dn, l=l2, j=l2 + 0.5) + state_i = RydbergStateSQDTAlkali(_species, n=n, l=l1, j=l1 + 0.5) + state_f = RydbergStateSQDTAlkali(_species, n=n + dn, l=l2, j=l2 + 0.5) matrix_element[_species] = state_i.radial.calc_matrix_element(state_f.radial, 1, unit="bohr") assert np.isclose(matrix_element[species], matrix_element["H_textbook"], rtol=1e-4) @@ -61,7 +61,7 @@ def test_circular_expectation_value(species_name: str, n: int, l: int, j_tot: fl species = SpeciesObject.from_name(species_name) nu = species.calc_nu(n, l, j_tot) - state = RadialState(species, nu=nu, l_r=l) + state = RadialKet(species, nu=nu, l_r=l) state.set_n_for_sanity_check(n) state.create_wavefunction()