22import jax .numpy as jnp
33import numpy as np
44import pytest
5- from jax import tree_map , vmap
65from numpy .testing import assert_allclose
76
87from pyscf_ipu .experimental .basis import basisset
9- from pyscf_ipu .experimental .device import has_ipu , ipu_func
108from pyscf_ipu .experimental .integrals import (
119 eri_basis ,
1210 eri_basis_sparse ,
1917 overlap_primitives ,
2018)
2119from pyscf_ipu .experimental .interop import to_pyscf
22- from pyscf_ipu .experimental .nuclear_gradients import (
23- grad_kinetic_basis ,
24- grad_nuclear_basis ,
25- grad_overlap_basis ,
26- )
2720from pyscf_ipu .experimental .primitive import Primitive
2821from pyscf_ipu .experimental .structure import molecule
2922
3023
31- @pytest .mark .parametrize ("basis_name" , ["sto-3g" , "6-31g**" ])
32- def test_to_pyscf (basis_name ):
33- mol = molecule ("water" )
34- basis = basisset (mol , basis_name )
35- pyscf_mol = to_pyscf (mol , basis_name )
36- assert basis .num_orbitals == pyscf_mol .nao
37-
38-
39- @pytest .mark .parametrize ("basis_name" , ["sto-3g" , "6-31+g" ])
40- def test_gto (basis_name ):
41- from pyscf .dft .numint import eval_rho
42-
43- # Atomic orbitals
44- structure = molecule ("water" )
45- basis = basisset (structure , basis_name )
46- mesh , _ = uniform_mesh ()
47- actual = basis (mesh )
48-
49- mol = to_pyscf (structure , basis_name )
50- expect_ao = mol .eval_gto ("GTOval_cart" , np .asarray (mesh ))
51- assert_allclose (actual , expect_ao , atol = 1e-6 )
52-
53- # Molecular orbitals
54- mf = mol .KS ()
55- mf .kernel ()
56- C = jnp .array (mf .mo_coeff , dtype = jnp .float32 )
57- actual = basis .occupancy * C @ C .T
58- expect = jnp .array (mf .make_rdm1 (), dtype = jnp .float32 )
59- assert_allclose (actual , expect , atol = 1e-6 )
60-
61- # Electron density
62- actual = electron_density (basis , mesh , C )
63- expect = eval_rho (mol , expect_ao , mf .make_rdm1 (), "lda" )
64- assert_allclose (actual , expect , atol = 1e-6 )
65-
66-
6724def test_overlap ():
6825 # Exercise 3.21 of "Modern quantum chemistry: introduction to advanced
6926 # electronic structure theory."" by Szabo and Ostlund
@@ -151,19 +108,6 @@ def test_water_nuclear():
151108 assert_allclose (actual , expect , atol = 1e-4 )
152109
153110
154- def eri_orbitals (orbitals ):
155- def take (orbital , index ):
156- p = tree_map (lambda * xs : jnp .stack (xs ), * orbital .primitives )
157- p = tree_map (lambda x : jnp .take (x , index , axis = 0 ), p )
158- c = jnp .take (orbital .coefficients , index )
159- return p , c
160-
161- indices = [jnp .arange (o .num_primitives ) for o in orbitals ]
162- indices = [i .reshape (- 1 ) for i in jnp .meshgrid (* indices )]
163- prim , coef = zip (* [take (o , i ) for o , i in zip (orbitals , indices )])
164- return jnp .sum (jnp .prod (jnp .stack (coef ), axis = 0 ) * vmap (eri_primitives )(* prim ))
165-
166-
167111def test_eri ():
168112 # PyQuante test cases for ERI
169113 a , b , c , d = [Primitive ()] * 4
@@ -172,18 +116,6 @@ def test_eri():
172116 c , d = [Primitive (lmn = jnp .array ([1 , 0 , 0 ]))] * 2
173117 assert_allclose (eri_primitives (a , b , c , d ), 0.940316 , atol = 1e-5 )
174118
175- # H2 molecule in sto-3g: See equation 3.235 of Szabo and Ostlund
176- h2 = molecule ("h2" )
177- basis = basisset (h2 , "sto-3g" )
178- indices = [(0 , 0 , 0 , 0 ), (0 , 0 , 1 , 1 ), (1 , 0 , 0 , 0 ), (1 , 0 , 1 , 0 )]
179- expected = [0.7746 , 0.5697 , 0.4441 , 0.2970 ]
180-
181- for ijkl , expect in zip (indices , expected ):
182- actual = eri_orbitals ([basis .orbitals [aoid ] for aoid in ijkl ])
183- assert_allclose (actual , expect , atol = 1e-4 )
184-
185-
186- def test_eri_basis ():
187119 # H2 molecule in sto-3g: See equation 3.235 of Szabo and Ostlund
188120 h2 = molecule ("h2" )
189121 basis = basisset (h2 , "sto-3g" )
@@ -219,63 +151,3 @@ def test_water_eri(sparse):
219151 aosym = "s8" if sparse else "s1"
220152 expect = to_pyscf (h2o , basis_name = basis_name ).intor ("int2e_cart" , aosym = aosym )
221153 assert_allclose (actual , expect , atol = 1e-4 )
222-
223-
224- @pytest .mark .skipif (not has_ipu (), reason = "Skipping ipu test!" )
225- def test_ipu_overlap ():
226- from pyscf_ipu .experimental .integrals import _overlap_primitives
227-
228- a , b = [Primitive ()] * 2
229- actual = ipu_func (_overlap_primitives )(a , b )
230- assert_allclose (actual , overlap_primitives (a , b ))
231-
232-
233- @pytest .mark .skipif (not has_ipu (), reason = "Skipping ipu test!" )
234- def test_ipu_kinetic ():
235- from pyscf_ipu .experimental .integrals import _kinetic_primitives
236-
237- a , b = [Primitive ()] * 2
238- actual = ipu_func (_kinetic_primitives )(a , b )
239- assert_allclose (actual , kinetic_primitives (a , b ))
240-
241-
242- @pytest .mark .skipif (not has_ipu (), reason = "Skipping ipu test!" )
243- def test_ipu_nuclear ():
244- from pyscf_ipu .experimental .integrals import _nuclear_primitives
245-
246- # PyQuante test case for nuclear attraction integral
247- a , b = [Primitive ()] * 2
248- c = jnp .zeros (3 )
249- actual = ipu_func (_nuclear_primitives )(a , b , c )
250- assert_allclose (actual , - 1.595769 , atol = 1e-5 )
251-
252-
253- @pytest .mark .skipif (not has_ipu (), reason = "Skipping ipu test!" )
254- def test_ipu_eri ():
255- from pyscf_ipu .experimental .integrals import _eri_primitives
256-
257- # PyQuante test cases for ERI
258- a , b , c , d = [Primitive ()] * 4
259- actual = ipu_func (_eri_primitives )(a , b , c , d )
260- assert_allclose (actual , 1.128379 , atol = 1e-5 )
261-
262-
263- @pytest .mark .parametrize ("basis_name" , ["sto-3g" , "6-31+g" ])
264- def test_nuclear_gradients (basis_name ):
265- h2 = molecule ("h2" )
266- scfmol = to_pyscf (h2 , basis_name )
267- basis = basisset (h2 , basis_name )
268-
269- actual = grad_overlap_basis (basis )
270- expect = scfmol .intor ("int1e_ipovlp_cart" , comp = 3 )
271- assert_allclose (actual , expect , atol = 1e-6 )
272-
273- actual = grad_kinetic_basis (basis )
274- expect = scfmol .intor ("int1e_ipkin_cart" , comp = 3 )
275- assert_allclose (actual , expect , atol = 1e-6 )
276-
277- # TODO: investigate possible inconsistency in libcint outputs?
278- actual = grad_nuclear_basis (basis )
279- expect = scfmol .intor ("int1e_ipnuc_cart" , comp = 3 )
280- expect = - np .moveaxis (expect , 1 , 2 )
281- assert_allclose (actual , expect , atol = 1e-6 )
0 commit comments