From 1d6a62012ce6cd3c44aacdaad7fc88c2c63f5146 Mon Sep 17 00:00:00 2001 From: Hatem Helal Date: Wed, 20 Sep 2023 11:30:56 +0000 Subject: [PATCH 01/15] WIP: nuclear gradients --- pyscf_ipu/experimental/basis.py | 1 + pyscf_ipu/experimental/integrals.py | 9 ++------- pyscf_ipu/experimental/orbital.py | 7 +++++-- pyscf_ipu/experimental/primitive.py | 7 +++++++ 4 files changed, 15 insertions(+), 9 deletions(-) diff --git a/pyscf_ipu/experimental/basis.py b/pyscf_ipu/experimental/basis.py index eb7ebad..070b2f0 100644 --- a/pyscf_ipu/experimental/basis.py +++ b/pyscf_ipu/experimental/basis.py @@ -64,6 +64,7 @@ def basisset(structure: Structure, basis_name: str = "sto-3g"): alphas=jnp.array(s["exponents"], dtype=float), lmn=jnp.array(lmn, dtype=jnp.int32), coefficients=jnp.array(s["coefficients"], dtype=float), + atom_index=a, ) orbitals.append(ao) diff --git a/pyscf_ipu/experimental/integrals.py b/pyscf_ipu/experimental/integrals.py index 32ddac8..3e346f5 100644 --- a/pyscf_ipu/experimental/integrals.py +++ b/pyscf_ipu/experimental/integrals.py @@ -1,5 +1,4 @@ # Copyright (c) 2023 Graphcore Ltd. All rights reserved. -from dataclasses import asdict from functools import partial from itertools import product as cartesian_product from typing import Callable @@ -84,15 +83,11 @@ def _overlap_primitives(a: Primitive, b: Primitive) -> float: def _kinetic_primitives(a: Primitive, b: Primitive) -> float: t0 = b.alpha * (2 * jnp.sum(b.lmn) + 3) * _overlap_primitives(a, b) - def offset_qn(ax: int, offset: int): - lmn = b.lmn.at[ax].add(offset) - return Primitive(**{**asdict(b), "lmn": lmn}) - axes = jnp.arange(3) - b1 = vmap(offset_qn, (0, None))(axes, 2) + b1 = vmap(b.offset_lmn, (0, None))(axes, 2) t1 = jnp.sum(vmap(_overlap_primitives, (None, 0))(a, b1)) - b2 = vmap(offset_qn, (0, None))(axes, -2) + b2 = vmap(b.offset_lmn, (0, None))(axes, -2) t2 = jnp.sum(b.lmn * (b.lmn - 1) * vmap(_overlap_primitives, (None, 0))(a, b2)) return t0 - 2.0 * b.alpha**2 * t1 - 0.5 * t2 diff --git a/pyscf_ipu/experimental/orbital.py b/pyscf_ipu/experimental/orbital.py index 285fea9..c85c245 100644 --- a/pyscf_ipu/experimental/orbital.py +++ b/pyscf_ipu/experimental/orbital.py @@ -32,10 +32,13 @@ def eval_orbital(p: Primitive, coef: float, pos: FloatNx3): return out @staticmethod - def from_bse(center, alphas, lmn, coefficients): + def from_bse(center, alphas, lmn, coefficients, atom_index): coefficients = coefficients.reshape(-1) assert len(coefficients) == len(alphas), "Expecting same size vectors!" - p = [Primitive(center=center, alpha=a, lmn=lmn) for a in alphas] + p = [ + Primitive(center=center, alpha=a, lmn=lmn, atom_index=atom_index) + for a in alphas + ] return Orbital(primitives=p, coefficients=coefficients) diff --git a/pyscf_ipu/experimental/primitive.py b/pyscf_ipu/experimental/primitive.py index 283ac9e..41f98cc 100644 --- a/pyscf_ipu/experimental/primitive.py +++ b/pyscf_ipu/experimental/primitive.py @@ -1,4 +1,5 @@ # Copyright (c) 2023 Graphcore Ltd. All rights reserved. +from copy import deepcopy from typing import Optional import chex @@ -14,6 +15,7 @@ class Primitive: alpha: float = 1.0 lmn: Int3 = jnp.zeros(3, dtype=jnp.int32) norm: Optional[float] = None + atom_index: Optional[int] = None def __post_init__(self): if self.norm is None: @@ -26,6 +28,11 @@ def angular_momentum(self) -> int: def __call__(self, pos: FloatNx3) -> FloatN: return eval_primitive(self, pos) + def offset_lmn(self, axis: int, offset: int) -> "Primitive": + out = deepcopy(self) + out.lmn = self.lmn.at[axis].add(offset) + return out + def normalize(lmn: Int3, alpha: float) -> float: L = jnp.sum(lmn) From 345d1151b550a3e0fb3960bee4fe3d630b2ffcc2 Mon Sep 17 00:00:00 2001 From: Hatem Helal Date: Wed, 20 Sep 2023 20:56:03 +0000 Subject: [PATCH 02/15] add grad overlap evaluation --- pyscf_ipu/experimental/nuclear_gradients.py | 52 ++++++++++ test/test_integrals.py | 103 ++++++++++++++++++++ 2 files changed, 155 insertions(+) create mode 100644 pyscf_ipu/experimental/nuclear_gradients.py diff --git a/pyscf_ipu/experimental/nuclear_gradients.py b/pyscf_ipu/experimental/nuclear_gradients.py new file mode 100644 index 0000000..957c79b --- /dev/null +++ b/pyscf_ipu/experimental/nuclear_gradients.py @@ -0,0 +1,52 @@ +# Copyright (c) 2023 Graphcore Ltd. All rights reserved. + +import jax.numpy as jnp +from jax import jit, tree_map, vmap +from jax.ops import segment_sum + +from .basis import Basis +from .integrals import _overlap_primitives +from .orbital import batch_orbitals +from .primitive import Primitive +from .types import Float3 + + +def grad_overlap_primitives(i: int, a: Primitive, b: Primitive) -> Float3: + """Analytic gradient of overlap integral with respect to atom i center""" + axes = jnp.arange(3) + lhs_p1 = vmap(a.offset_lmn, (0, None))(axes, 1) + t1 = 2 * a.alpha * vmap(_overlap_primitives, (0, None))(lhs_p1, b) + + lhs_m1 = vmap(a.offset_lmn, (0, None))(axes, -1) + t2 = jnp.where(a.lmn > 0, a.lmn, jnp.zeros_like(a.lmn)) + t2 *= vmap(_overlap_primitives, (0, None))(lhs_m1, b) + grad_out = t1 - t2 + return jnp.where(a.atom_index == i, grad_out, jnp.zeros_like(grad_out)) + + +# output is [3, N, N] +def grad_overlap_basis(b: Basis): + def take_primitives(indices): + p = tree_map(lambda x: jnp.take(x, indices, axis=0), primitives) + c = jnp.take(coefficients, indices) + return p, c + + # atom_indices = jnp.arange(b.structure.num_atoms) + primitives, coefficients, orbital_index = batch_orbitals(b.orbitals) + ii, jj = jnp.meshgrid(*[jnp.arange(b.num_primitives)] * 2, indexing="ij") + lhs, cl = take_primitives(ii.reshape(-1)) + rhs, cr = take_primitives(jj.reshape(-1)) + + op = jit(vmap(grad_overlap_primitives, (None, 0, 0))) + + out = op(0, lhs, rhs) + + for i in range(1, b.structure.num_atoms): + out += op(i, lhs, rhs) + + out = cl * cr * out.T + out = out.reshape(3, b.num_primitives, b.num_primitives) + out = segment_sum(jnp.rollaxis(out, 1), orbital_index) + out = segment_sum(jnp.rollaxis(out, -1), orbital_index) + + return jnp.rollaxis(out, -1) diff --git a/test/test_integrals.py b/test/test_integrals.py index 00b359a..5571f6d 100644 --- a/test/test_integrals.py +++ b/test/test_integrals.py @@ -2,9 +2,11 @@ import jax.numpy as jnp import numpy as np import pytest +from jax import tree_map, vmap from numpy.testing import assert_allclose from pyscf_ipu.experimental.basis import basisset +from pyscf_ipu.experimental.device import has_ipu, ipu_func from pyscf_ipu.experimental.integrals import ( eri_basis, eri_basis_sparse, @@ -17,10 +19,47 @@ overlap_primitives, ) from pyscf_ipu.experimental.interop import to_pyscf +from pyscf_ipu.experimental.mesh import electron_density, uniform_mesh from pyscf_ipu.experimental.primitive import Primitive from pyscf_ipu.experimental.structure import molecule +@pytest.mark.parametrize("basis_name", ["sto-3g", "6-31g**"]) +def test_to_pyscf(basis_name): + mol = molecule("water") + basis = basisset(mol, basis_name) + pyscf_mol = to_pyscf(mol, basis_name) + assert basis.num_orbitals == pyscf_mol.nao + + +@pytest.mark.parametrize("basis_name", ["sto-3g", "6-31+g"]) +def test_gto(basis_name): + from pyscf.dft.numint import eval_rho + + # Atomic orbitals + structure = molecule("water") + basis = basisset(structure, basis_name) + mesh, _ = uniform_mesh() + actual = basis(mesh) + + mol = to_pyscf(structure, basis_name) + expect_ao = mol.eval_gto("GTOval_cart", np.asarray(mesh)) + assert_allclose(actual, expect_ao, atol=1e-6) + + # Molecular orbitals + mf = mol.KS() + mf.kernel() + C = jnp.array(mf.mo_coeff, dtype=jnp.float32) + actual = basis.occupancy * C @ C.T + expect = jnp.array(mf.make_rdm1(), dtype=jnp.float32) + assert_allclose(actual, expect, atol=1e-6) + + # Electron density + actual = electron_density(basis, mesh, C) + expect = eval_rho(mol, expect_ao, mf.make_rdm1(), "lda") + assert_allclose(actual, expect, atol=1e-6) + + def test_overlap(): # Exercise 3.21 of "Modern quantum chemistry: introduction to advanced # electronic structure theory."" by Szabo and Ostlund @@ -108,6 +147,19 @@ def test_water_nuclear(): assert_allclose(actual, expect, atol=1e-4) +def eri_orbitals(orbitals): + def take(orbital, index): + p = tree_map(lambda *xs: jnp.stack(xs), *orbital.primitives) + p = tree_map(lambda x: jnp.take(x, index, axis=0), p) + c = jnp.take(orbital.coefficients, index) + return p, c + + indices = [jnp.arange(o.num_primitives) for o in orbitals] + indices = [i.reshape(-1) for i in jnp.meshgrid(*indices)] + prim, coef = zip(*[take(o, i) for o, i in zip(orbitals, indices)]) + return jnp.sum(jnp.prod(jnp.stack(coef), axis=0) * vmap(eri_primitives)(*prim)) + + def test_eri(): # PyQuante test cases for ERI a, b, c, d = [Primitive()] * 4 @@ -116,6 +168,18 @@ def test_eri(): c, d = [Primitive(lmn=jnp.array([1, 0, 0]))] * 2 assert_allclose(eri_primitives(a, b, c, d), 0.940316, atol=1e-5) + # H2 molecule in sto-3g: See equation 3.235 of Szabo and Ostlund + h2 = molecule("h2") + basis = basisset(h2, "sto-3g") + indices = [(0, 0, 0, 0), (0, 0, 1, 1), (1, 0, 0, 0), (1, 0, 1, 0)] + expected = [0.7746, 0.5697, 0.4441, 0.2970] + + for ijkl, expect in zip(indices, expected): + actual = eri_orbitals([basis.orbitals[aoid] for aoid in ijkl]) + assert_allclose(actual, expect, atol=1e-4) + + +def test_eri_basis(): # H2 molecule in sto-3g: See equation 3.235 of Szabo and Ostlund h2 = molecule("h2") basis = basisset(h2, "sto-3g") @@ -151,3 +215,42 @@ def test_water_eri(sparse): aosym = "s8" if sparse else "s1" expect = to_pyscf(h2o, basis_name=basis_name).intor("int2e_cart", aosym=aosym) assert_allclose(actual, expect, atol=1e-4) + + +@pytest.mark.skipif(not has_ipu(), reason="Skipping ipu test!") +def test_ipu_overlap(): + from pyscf_ipu.experimental.integrals import _overlap_primitives + + a, b = [Primitive()] * 2 + actual = ipu_func(_overlap_primitives)(a, b) + assert_allclose(actual, overlap_primitives(a, b)) + + +@pytest.mark.skipif(not has_ipu(), reason="Skipping ipu test!") +def test_ipu_kinetic(): + from pyscf_ipu.experimental.integrals import _kinetic_primitives + + a, b = [Primitive()] * 2 + actual = ipu_func(_kinetic_primitives)(a, b) + assert_allclose(actual, kinetic_primitives(a, b)) + + +@pytest.mark.skipif(not has_ipu(), reason="Skipping ipu test!") +def test_ipu_nuclear(): + from pyscf_ipu.experimental.integrals import _nuclear_primitives + + # PyQuante test case for nuclear attraction integral + a, b = [Primitive()] * 2 + c = jnp.zeros(3) + actual = ipu_func(_nuclear_primitives)(a, b, c) + assert_allclose(actual, -1.595769, atol=1e-5) + + +@pytest.mark.skipif(not has_ipu(), reason="Skipping ipu test!") +def test_ipu_eri(): + from pyscf_ipu.experimental.integrals import _eri_primitives + + # PyQuante test cases for ERI + a, b, c, d = [Primitive()] * 4 + actual = ipu_func(_eri_primitives)(a, b, c, d) + assert_allclose(actual, 1.128379, atol=1e-5) From d0e012103190a0ef735b7e788f53377e842f9944 Mon Sep 17 00:00:00 2001 From: Hatem Helal Date: Thu, 21 Sep 2023 08:06:29 +0000 Subject: [PATCH 03/15] vmap over atoms --- pyscf_ipu/experimental/nuclear_gradients.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/pyscf_ipu/experimental/nuclear_gradients.py b/pyscf_ipu/experimental/nuclear_gradients.py index 957c79b..d9ca332 100644 --- a/pyscf_ipu/experimental/nuclear_gradients.py +++ b/pyscf_ipu/experimental/nuclear_gradients.py @@ -31,18 +31,16 @@ def take_primitives(indices): c = jnp.take(coefficients, indices) return p, c - # atom_indices = jnp.arange(b.structure.num_atoms) primitives, coefficients, orbital_index = batch_orbitals(b.orbitals) ii, jj = jnp.meshgrid(*[jnp.arange(b.num_primitives)] * 2, indexing="ij") lhs, cl = take_primitives(ii.reshape(-1)) rhs, cr = take_primitives(jj.reshape(-1)) - op = jit(vmap(grad_overlap_primitives, (None, 0, 0))) - - out = op(0, lhs, rhs) - - for i in range(1, b.structure.num_atoms): - out += op(i, lhs, rhs) + op = vmap(grad_overlap_primitives, (None, 0, 0)) + op = jit(vmap(op, (0, None, None))) + atom_indices = jnp.arange(b.structure.num_atoms) + out = op(atom_indices, lhs, rhs) + out = jnp.sum(out, axis=0) out = cl * cr * out.T out = out.reshape(3, b.num_primitives, b.num_primitives) From 5ee86e3906071df5edcc10b2f082160f0e1de5b5 Mon Sep 17 00:00:00 2001 From: Hatem Helal Date: Thu, 21 Sep 2023 08:11:59 +0000 Subject: [PATCH 04/15] loop over basis --- test/test_integrals.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/test/test_integrals.py b/test/test_integrals.py index 5571f6d..a88d769 100644 --- a/test/test_integrals.py +++ b/test/test_integrals.py @@ -254,3 +254,13 @@ def test_ipu_eri(): a, b, c, d = [Primitive()] * 4 actual = ipu_func(_eri_primitives)(a, b, c, d) assert_allclose(actual, 1.128379, atol=1e-5) + + +@pytest.mark.parametrize("basis_name", ["sto-3g", "6-31+g"]) +def test_nuclear_gradients(basis_name): + h2 = molecule("h2") + expect = to_pyscf(h2, basis_name).intor("int1e_ipovlp_cart", comp=3) + basis = basisset(h2, basis_name) + actual = grad_overlap_basis(basis) + + assert_allclose(actual, expect, atol=1e-6) From f41070eacf5331dbe7c8a050da6bb8da19ead55d Mon Sep 17 00:00:00 2001 From: Hatem Helal Date: Thu, 21 Sep 2023 08:12:17 +0000 Subject: [PATCH 05/15] update type annotations --- pyscf_ipu/experimental/nuclear_gradients.py | 5 ++--- pyscf_ipu/experimental/types.py | 1 + 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyscf_ipu/experimental/nuclear_gradients.py b/pyscf_ipu/experimental/nuclear_gradients.py index d9ca332..3fad88b 100644 --- a/pyscf_ipu/experimental/nuclear_gradients.py +++ b/pyscf_ipu/experimental/nuclear_gradients.py @@ -8,7 +8,7 @@ from .integrals import _overlap_primitives from .orbital import batch_orbitals from .primitive import Primitive -from .types import Float3 +from .types import Float3, Float3xNxN def grad_overlap_primitives(i: int, a: Primitive, b: Primitive) -> Float3: @@ -24,8 +24,7 @@ def grad_overlap_primitives(i: int, a: Primitive, b: Primitive) -> Float3: return jnp.where(a.atom_index == i, grad_out, jnp.zeros_like(grad_out)) -# output is [3, N, N] -def grad_overlap_basis(b: Basis): +def grad_overlap_basis(b: Basis) -> Float3xNxN: def take_primitives(indices): p = tree_map(lambda x: jnp.take(x, indices, axis=0), primitives) c = jnp.take(coefficients, indices) diff --git a/pyscf_ipu/experimental/types.py b/pyscf_ipu/experimental/types.py index 63077ca..642d3e8 100644 --- a/pyscf_ipu/experimental/types.py +++ b/pyscf_ipu/experimental/types.py @@ -4,6 +4,7 @@ from jaxtyping import Array, Float, Int Float3 = Float[Array, "3"] +Float3xNxN = Float[Array, "3 N N"] FloatNx3 = Float[Array, "N 3"] FloatN = Float[Array, "N"] FloatNxN = Float[Array, "N N"] From 38bcc4c4918650dc6dfaa10ba12840dd119ef7e3 Mon Sep 17 00:00:00 2001 From: Hatem Helal Date: Thu, 21 Sep 2023 09:39:23 +0000 Subject: [PATCH 06/15] add gradient of kinetic energy integrals --- pyscf_ipu/experimental/nuclear_gradients.py | 59 ++++++++++++++++++--- test/test_integrals.py | 12 ++++- 2 files changed, 61 insertions(+), 10 deletions(-) diff --git a/pyscf_ipu/experimental/nuclear_gradients.py b/pyscf_ipu/experimental/nuclear_gradients.py index 3fad88b..a5bcd46 100644 --- a/pyscf_ipu/experimental/nuclear_gradients.py +++ b/pyscf_ipu/experimental/nuclear_gradients.py @@ -1,30 +1,73 @@ # Copyright (c) 2023 Graphcore Ltd. All rights reserved. - +from typing import Callable import jax.numpy as jnp from jax import jit, tree_map, vmap from jax.ops import segment_sum from .basis import Basis -from .integrals import _overlap_primitives +from .integrals import _overlap_primitives, _kinetic_primitives from .orbital import batch_orbitals from .primitive import Primitive from .types import Float3, Float3xNxN -def grad_overlap_primitives(i: int, a: Primitive, b: Primitive) -> Float3: - """Analytic gradient of overlap integral with respect to atom i center""" +def grad_primitive_integral( + primitive_op: Callable, atom_index: int, a: Primitive, b: Primitive +) -> Float3: + """Generic gradient of a one-electron integral with respect the atom_index center""" axes = jnp.arange(3) lhs_p1 = vmap(a.offset_lmn, (0, None))(axes, 1) - t1 = 2 * a.alpha * vmap(_overlap_primitives, (0, None))(lhs_p1, b) + t1 = 2 * a.alpha * vmap(primitive_op, (0, None))(lhs_p1, b) lhs_m1 = vmap(a.offset_lmn, (0, None))(axes, -1) t2 = jnp.where(a.lmn > 0, a.lmn, jnp.zeros_like(a.lmn)) - t2 *= vmap(_overlap_primitives, (0, None))(lhs_m1, b) + t2 *= vmap(primitive_op, (0, None))(lhs_m1, b) grad_out = t1 - t2 - return jnp.where(a.atom_index == i, grad_out, jnp.zeros_like(grad_out)) + return jnp.where(a.atom_index == atom_index, grad_out, jnp.zeros_like(grad_out)) + + +def grad_integrate(b: Basis, primitive_op: Callable) -> Float3xNxN: + """Generic gradient of one-electron integrals over the basis set""" + + def take_primitives(indices): + p = tree_map(lambda x: jnp.take(x, indices, axis=0), primitives) + c = jnp.take(coefficients, indices) + return p, c + + primitives, coefficients, orbital_index = batch_orbitals(b.orbitals) + ii, jj = jnp.meshgrid(*[jnp.arange(b.num_primitives)] * 2, indexing="ij") + lhs, cl = take_primitives(ii.reshape(-1)) + rhs, cr = take_primitives(jj.reshape(-1)) + + op = vmap(primitive_op, (None, 0, 0)) + op = jit(vmap(op, (0, None, None))) + atom_indices = jnp.arange(b.structure.num_atoms) + out = op(atom_indices, lhs, rhs) + out = jnp.sum(out, axis=0) + + out = cl * cr * out.T + out = out.reshape(3, b.num_primitives, b.num_primitives) + out = segment_sum(jnp.rollaxis(out, 1), orbital_index) + out = segment_sum(jnp.rollaxis(out, -1), orbital_index) + + return jnp.rollaxis(out, -1) + + +def grad_overlap_primitives(i: int, a: Primitive, b: Primitive) -> Float3: + return grad_primitive_integral(_overlap_primitives, i, a, b) + + +def grad_kinetic_primitives(i: int, a: Primitive, b: Primitive) -> Float3: + return grad_primitive_integral(_kinetic_primitives, i, a, b) def grad_overlap_basis(b: Basis) -> Float3xNxN: + return grad_integrate(b, grad_overlap_primitives) + + +def grad_kinetic_basis(b: Basis) -> Float3xNxN: + return grad_integrate(b, grad_kinetic_primitives) + def take_primitives(indices): p = tree_map(lambda x: jnp.take(x, indices, axis=0), primitives) c = jnp.take(coefficients, indices) @@ -35,7 +78,7 @@ def take_primitives(indices): lhs, cl = take_primitives(ii.reshape(-1)) rhs, cr = take_primitives(jj.reshape(-1)) - op = vmap(grad_overlap_primitives, (None, 0, 0)) + op = vmap(grad_kinetic_primitives, (None, 0, 0)) op = jit(vmap(op, (0, None, None))) atom_indices = jnp.arange(b.structure.num_atoms) out = op(atom_indices, lhs, rhs) diff --git a/test/test_integrals.py b/test/test_integrals.py index a88d769..b89d94f 100644 --- a/test/test_integrals.py +++ b/test/test_integrals.py @@ -19,7 +19,10 @@ overlap_primitives, ) from pyscf_ipu.experimental.interop import to_pyscf -from pyscf_ipu.experimental.mesh import electron_density, uniform_mesh +from pyscf_ipu.experimental.nuclear_gradients import ( + grad_kinetic_basis, + grad_overlap_basis, +) from pyscf_ipu.experimental.primitive import Primitive from pyscf_ipu.experimental.structure import molecule @@ -259,8 +262,13 @@ def test_ipu_eri(): @pytest.mark.parametrize("basis_name", ["sto-3g", "6-31+g"]) def test_nuclear_gradients(basis_name): h2 = molecule("h2") - expect = to_pyscf(h2, basis_name).intor("int1e_ipovlp_cart", comp=3) + scfmol = to_pyscf(h2, basis_name) basis = basisset(h2, basis_name) + actual = grad_overlap_basis(basis) + expect = scfmol.intor("int1e_ipovlp_cart", comp=3) + assert_allclose(actual, expect, atol=1e-6) + actual = grad_kinetic_basis(basis) + expect = scfmol.intor("int1e_ipkin_cart", comp=3) assert_allclose(actual, expect, atol=1e-6) From 1b2f2d53c05b0da0087c772dc37246c4302ddbf3 Mon Sep 17 00:00:00 2001 From: Hatem Helal Date: Thu, 21 Sep 2023 09:48:32 +0000 Subject: [PATCH 07/15] cleanup --- pyscf_ipu/experimental/nuclear_gradients.py | 40 +++------------------ 1 file changed, 5 insertions(+), 35 deletions(-) diff --git a/pyscf_ipu/experimental/nuclear_gradients.py b/pyscf_ipu/experimental/nuclear_gradients.py index a5bcd46..30a01c2 100644 --- a/pyscf_ipu/experimental/nuclear_gradients.py +++ b/pyscf_ipu/experimental/nuclear_gradients.py @@ -1,5 +1,6 @@ # Copyright (c) 2023 Graphcore Ltd. All rights reserved. from typing import Callable +from functools import partial import jax.numpy as jnp from jax import jit, tree_map, vmap from jax.ops import segment_sum @@ -53,40 +54,9 @@ def take_primitives(indices): return jnp.rollaxis(out, -1) -def grad_overlap_primitives(i: int, a: Primitive, b: Primitive) -> Float3: - return grad_primitive_integral(_overlap_primitives, i, a, b) +grad_overlap_primitives = partial(grad_primitive_integral, _overlap_primitives) +grad_kinetic_primitives = partial(grad_primitive_integral, _kinetic_primitives) -def grad_kinetic_primitives(i: int, a: Primitive, b: Primitive) -> Float3: - return grad_primitive_integral(_kinetic_primitives, i, a, b) - - -def grad_overlap_basis(b: Basis) -> Float3xNxN: - return grad_integrate(b, grad_overlap_primitives) - - -def grad_kinetic_basis(b: Basis) -> Float3xNxN: - return grad_integrate(b, grad_kinetic_primitives) - - def take_primitives(indices): - p = tree_map(lambda x: jnp.take(x, indices, axis=0), primitives) - c = jnp.take(coefficients, indices) - return p, c - - primitives, coefficients, orbital_index = batch_orbitals(b.orbitals) - ii, jj = jnp.meshgrid(*[jnp.arange(b.num_primitives)] * 2, indexing="ij") - lhs, cl = take_primitives(ii.reshape(-1)) - rhs, cr = take_primitives(jj.reshape(-1)) - - op = vmap(grad_kinetic_primitives, (None, 0, 0)) - op = jit(vmap(op, (0, None, None))) - atom_indices = jnp.arange(b.structure.num_atoms) - out = op(atom_indices, lhs, rhs) - out = jnp.sum(out, axis=0) - - out = cl * cr * out.T - out = out.reshape(3, b.num_primitives, b.num_primitives) - out = segment_sum(jnp.rollaxis(out, 1), orbital_index) - out = segment_sum(jnp.rollaxis(out, -1), orbital_index) - - return jnp.rollaxis(out, -1) +grad_overlap_basis = partial(grad_integrate, primitive_op=grad_overlap_primitives) +grad_kinetic_basis = partial(grad_integrate, primitive_op=grad_kinetic_primitives) From 4845ab153f55b6423de2d2e55c3a6b33e71a3258 Mon Sep 17 00:00:00 2001 From: Hatem Helal Date: Fri, 22 Sep 2023 09:33:15 +0000 Subject: [PATCH 08/15] fix pc --- pyscf_ipu/experimental/nuclear_gradients.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pyscf_ipu/experimental/nuclear_gradients.py b/pyscf_ipu/experimental/nuclear_gradients.py index 30a01c2..7b708a2 100644 --- a/pyscf_ipu/experimental/nuclear_gradients.py +++ b/pyscf_ipu/experimental/nuclear_gradients.py @@ -1,12 +1,13 @@ # Copyright (c) 2023 Graphcore Ltd. All rights reserved. -from typing import Callable from functools import partial +from typing import Callable + import jax.numpy as jnp from jax import jit, tree_map, vmap from jax.ops import segment_sum from .basis import Basis -from .integrals import _overlap_primitives, _kinetic_primitives +from .integrals import _kinetic_primitives, _overlap_primitives from .orbital import batch_orbitals from .primitive import Primitive from .types import Float3, Float3xNxN From dec9cd27453533cd37178900fcfd8dc061af0f75 Mon Sep 17 00:00:00 2001 From: Hatem Helal Date: Fri, 22 Sep 2023 12:28:01 +0000 Subject: [PATCH 09/15] adding doc strings --- pyscf_ipu/experimental/nuclear_gradients.py | 72 +++++++++++++++------ test/test_integrals.py | 6 ++ 2 files changed, 59 insertions(+), 19 deletions(-) diff --git a/pyscf_ipu/experimental/nuclear_gradients.py b/pyscf_ipu/experimental/nuclear_gradients.py index 7b708a2..913713a 100644 --- a/pyscf_ipu/experimental/nuclear_gradients.py +++ b/pyscf_ipu/experimental/nuclear_gradients.py @@ -3,52 +3,74 @@ from typing import Callable import jax.numpy as jnp -from jax import jit, tree_map, vmap +from jax import tree_map, vmap from jax.ops import segment_sum from .basis import Basis -from .integrals import _kinetic_primitives, _overlap_primitives +from .integrals import _kinetic_primitives, _nuclear_primitives, _overlap_primitives from .orbital import batch_orbitals from .primitive import Primitive from .types import Float3, Float3xNxN def grad_primitive_integral( - primitive_op: Callable, atom_index: int, a: Primitive, b: Primitive + primitive_op: Callable, a: Primitive, b: Primitive ) -> Float3: - """Generic gradient of a one-electron integral with respect the atom_index center""" + """gradient of a one-electron integral over primitive functions defined as: + + < grad_a a | p | b > + + where the cartesian gradient is evaluated with respect to a.center. For Gaussian + primitives this gradient simplifies to: + + 2 * alpha < a(l+1) | p | b > - l * < a(l-1) | p | b > + + where a(l+/-1) -> offset the lmn component for the corresponding gradient component. + + Args: + primitive_op (Callable): integral operation over two primitives (a, b) -> float + a (Primitive): left hand side of the integral + b (Primitive): right hand side of the integral + + Returns: + Float3: Gradient of the integral with respect to cartesian axes. + """ + axes = jnp.arange(3) lhs_p1 = vmap(a.offset_lmn, (0, None))(axes, 1) t1 = 2 * a.alpha * vmap(primitive_op, (0, None))(lhs_p1, b) lhs_m1 = vmap(a.offset_lmn, (0, None))(axes, -1) - t2 = jnp.where(a.lmn > 0, a.lmn, jnp.zeros_like(a.lmn)) - t2 *= vmap(primitive_op, (0, None))(lhs_m1, b) + t2 = jnp.maximum(a.lmn, 0) * vmap(primitive_op, (0, None))(lhs_m1, b) grad_out = t1 - t2 - return jnp.where(a.atom_index == atom_index, grad_out, jnp.zeros_like(grad_out)) + return grad_out + +def grad_integrate(basis: Basis, primitive_op: Callable) -> Float3xNxN: + """gradient of a one-electron integral over the basis set of atomic orbitals. -def grad_integrate(b: Basis, primitive_op: Callable) -> Float3xNxN: - """Generic gradient of one-electron integrals over the basis set""" + Args: + basis (Basis): basis set of N atomic orbitals + primitive_op (Callable): integral operation over two primitives (a, b) -> float + + Returns: + Float3xNxN: Gradient of the integral with respect to cartesian axes evaluated + over the NxN combinations of atomic orbitals. + """ def take_primitives(indices): p = tree_map(lambda x: jnp.take(x, indices, axis=0), primitives) c = jnp.take(coefficients, indices) return p, c - primitives, coefficients, orbital_index = batch_orbitals(b.orbitals) - ii, jj = jnp.meshgrid(*[jnp.arange(b.num_primitives)] * 2, indexing="ij") + primitives, coefficients, orbital_index = batch_orbitals(basis.orbitals) + ii, jj = jnp.meshgrid(*[jnp.arange(basis.num_primitives)] * 2, indexing="ij") lhs, cl = take_primitives(ii.reshape(-1)) rhs, cr = take_primitives(jj.reshape(-1)) - - op = vmap(primitive_op, (None, 0, 0)) - op = jit(vmap(op, (0, None, None))) - atom_indices = jnp.arange(b.structure.num_atoms) - out = op(atom_indices, lhs, rhs) - out = jnp.sum(out, axis=0) + out = vmap(primitive_op)(lhs, rhs) out = cl * cr * out.T - out = out.reshape(3, b.num_primitives, b.num_primitives) + out = out.reshape(3, basis.num_primitives, basis.num_primitives) out = segment_sum(jnp.rollaxis(out, 1), orbital_index) out = segment_sum(jnp.rollaxis(out, -1), orbital_index) @@ -58,6 +80,18 @@ def take_primitives(indices): grad_overlap_primitives = partial(grad_primitive_integral, _overlap_primitives) grad_kinetic_primitives = partial(grad_primitive_integral, _kinetic_primitives) - grad_overlap_basis = partial(grad_integrate, primitive_op=grad_overlap_primitives) grad_kinetic_basis = partial(grad_integrate, primitive_op=grad_kinetic_primitives) + + +def grad_nuclear_primitives(a: Primitive, b: Primitive, c: Float3) -> Float3: + return grad_primitive_integral(partial(_nuclear_primitives, c=c), a, b) + + +def grad_nuclear_basis(basis: Basis): + def nuclear(c, z): + op = partial(grad_nuclear_primitives, c=c) + return z * grad_integrate(basis, op) + + out = vmap(nuclear)(basis.structure.position, basis.structure.atomic_number) + return jnp.sum(out, axis=0) diff --git a/test/test_integrals.py b/test/test_integrals.py index b89d94f..bb09d8a 100644 --- a/test/test_integrals.py +++ b/test/test_integrals.py @@ -21,6 +21,7 @@ from pyscf_ipu.experimental.interop import to_pyscf from pyscf_ipu.experimental.nuclear_gradients import ( grad_kinetic_basis, + grad_nuclear_basis, grad_overlap_basis, ) from pyscf_ipu.experimental.primitive import Primitive @@ -217,6 +218,7 @@ def test_water_eri(sparse): actual = eri_basis_sparse(basis) if sparse else eri_basis(basis) aosym = "s8" if sparse else "s1" expect = to_pyscf(h2o, basis_name=basis_name).intor("int2e_cart", aosym=aosym) + print("max |actual - expect| ={}", np.max(np.abs(actual - expect))) assert_allclose(actual, expect, atol=1e-4) @@ -272,3 +274,7 @@ def test_nuclear_gradients(basis_name): actual = grad_kinetic_basis(basis) expect = scfmol.intor("int1e_ipkin_cart", comp=3) assert_allclose(actual, expect, atol=1e-6) + + actual = grad_nuclear_basis(basis) + expect = scfmol.intor("int1e_ipnuc_cart", comp=3) + assert_allclose(actual, expect, atol=1e-6) From cfddd2876b172ff857c2dd44650c396b238578f1 Mon Sep 17 00:00:00 2001 From: Hatem Helal Date: Sun, 24 Sep 2023 12:31:23 +0000 Subject: [PATCH 10/15] passing nuclear test case --- pyscf_ipu/experimental/nuclear_gradients.py | 3 +-- test/test_integrals.py | 2 ++ 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/pyscf_ipu/experimental/nuclear_gradients.py b/pyscf_ipu/experimental/nuclear_gradients.py index 913713a..0cab770 100644 --- a/pyscf_ipu/experimental/nuclear_gradients.py +++ b/pyscf_ipu/experimental/nuclear_gradients.py @@ -41,7 +41,7 @@ def grad_primitive_integral( t1 = 2 * a.alpha * vmap(primitive_op, (0, None))(lhs_p1, b) lhs_m1 = vmap(a.offset_lmn, (0, None))(axes, -1) - t2 = jnp.maximum(a.lmn, 0) * vmap(primitive_op, (0, None))(lhs_m1, b) + t2 = a.lmn * vmap(primitive_op, (0, None))(lhs_m1, b) grad_out = t1 - t2 return grad_out @@ -73,7 +73,6 @@ def take_primitives(indices): out = out.reshape(3, basis.num_primitives, basis.num_primitives) out = segment_sum(jnp.rollaxis(out, 1), orbital_index) out = segment_sum(jnp.rollaxis(out, -1), orbital_index) - return jnp.rollaxis(out, -1) diff --git a/test/test_integrals.py b/test/test_integrals.py index bb09d8a..22346be 100644 --- a/test/test_integrals.py +++ b/test/test_integrals.py @@ -275,6 +275,8 @@ def test_nuclear_gradients(basis_name): expect = scfmol.intor("int1e_ipkin_cart", comp=3) assert_allclose(actual, expect, atol=1e-6) + # TODO: investigate possible inconsistency in libcint outputs? actual = grad_nuclear_basis(basis) expect = scfmol.intor("int1e_ipnuc_cart", comp=3) + expect = -np.moveaxis(expect, 1, 2) assert_allclose(actual, expect, atol=1e-6) From c231e2a6453e5db93a63b55fe624979e2e12178c Mon Sep 17 00:00:00 2001 From: Hatem Helal Date: Sun, 24 Sep 2023 14:23:36 +0000 Subject: [PATCH 11/15] adding doc strings --- pyscf_ipu/experimental/nuclear_gradients.py | 85 +++++++++++++++++++-- 1 file changed, 77 insertions(+), 8 deletions(-) diff --git a/pyscf_ipu/experimental/nuclear_gradients.py b/pyscf_ipu/experimental/nuclear_gradients.py index 0cab770..e01976a 100644 --- a/pyscf_ipu/experimental/nuclear_gradients.py +++ b/pyscf_ipu/experimental/nuclear_gradients.py @@ -46,6 +46,50 @@ def grad_primitive_integral( return grad_out +def grad_overlap_primitives(a: Primitive, b: Primitive) -> Float3: + """Evaluate the gradient of the overlap integral between primitives a and b. The + gradient is with respect to a.center. + + Args: + a (Primitive): left hand side of the overlap integral. + b (Primitive): right hand side of the overlap integral. + + Returns: + Float3: Gradient of the overlap integral with respect to cartesian axes. + """ + return grad_primitive_integral(_overlap_primitives, a, b) + + +def grad_kinetic_primitives(a: Primitive, b: Primitive) -> Float3: + """Evaluate the gradient of the kinetic energy integral between primitives a and b. + The gradient is with respect to a.center. + + Args: + a (Primitive): left hand side of the kinetic energy integral. + b (Primitive): right hand side of the kinetic energy integral. + + Returns: + Float3: Gradient of the kinetic energy integral with respect to cartesian axes. + """ + return grad_primitive_integral(_kinetic_primitives, a, b) + + +def grad_nuclear_primitives(a: Primitive, b: Primitive, c: Float3) -> Float3: + """Evaluate the gradient of the nuclear attraction integral between primitives a and + b, and the nuclear potential centered on c. Gradient is with respect to a.center. + + Args: + a (Primitive): left hand side of the nuclear attraction integral. + b (Primitive): right hand side of the nuclear attraction integral. + c (Float3): center for the nuclear attraction potential 1/(r - c) + + Returns: + Float3: Gradient of the nuclear attraction integral with respect to cartesian + axes + """ + return grad_primitive_integral(partial(_nuclear_primitives, c=c), a, b) + + def grad_integrate(basis: Basis, primitive_op: Callable) -> Float3xNxN: """gradient of a one-electron integral over the basis set of atomic orbitals. @@ -55,7 +99,7 @@ def grad_integrate(basis: Basis, primitive_op: Callable) -> Float3xNxN: Returns: Float3xNxN: Gradient of the integral with respect to cartesian axes evaluated - over the NxN combinations of atomic orbitals. + over the NxN combinations of atomic orbitals. """ def take_primitives(indices): @@ -76,18 +120,43 @@ def take_primitives(indices): return jnp.rollaxis(out, -1) -grad_overlap_primitives = partial(grad_primitive_integral, _overlap_primitives) -grad_kinetic_primitives = partial(grad_primitive_integral, _kinetic_primitives) +def grad_overlap_basis(basis: Basis) -> Float3xNxN: + """gradient of the overlap integral over the basis set of atomic orbitals -grad_overlap_basis = partial(grad_integrate, primitive_op=grad_overlap_primitives) -grad_kinetic_basis = partial(grad_integrate, primitive_op=grad_kinetic_primitives) + Args: + basis (Basis): basis set of N atomic orbitals + Returns: + Float3xNxN: Gradient of the overlap integral with respect to cartesian axes + evaluated over the NxN combinations of atomic orbitals. + """ + return grad_integrate(basis, grad_overlap_primitives) -def grad_nuclear_primitives(a: Primitive, b: Primitive, c: Float3) -> Float3: - return grad_primitive_integral(partial(_nuclear_primitives, c=c), a, b) +def grad_kinetic_basis(basis: Basis) -> Float3xNxN: + """gradient of the kinetic energy integral over the basis set of atomic orbitals + + Args: + basis (Basis): basis set of N atomic orbitals + + Returns: + Float3xNxN: Gradient of the kinetic energy integral with respect to cartesian + axes evaluated over the NxN combinations of atomic orbitals. + """ + return grad_integrate(basis, grad_kinetic_primitives) + + +def grad_nuclear_basis(basis: Basis) -> Float3xNxN: + """gradient of the nuclear attraction integral over the basis set of atomic orbitals + + Args: + basis (Basis): basis set of N atomic orbitals + + Returns: + Float3xNxN: Gradient of the nuclear attraction integral with respect to + cartesian axes evaluated over the NxN combinations of atomic orbitals. + """ -def grad_nuclear_basis(basis: Basis): def nuclear(c, z): op = partial(grad_nuclear_primitives, c=c) return z * grad_integrate(basis, op) From c049a0af5876b98db3de2ea912be35c4eca66a34 Mon Sep 17 00:00:00 2001 From: Hatem Helal Date: Sun, 24 Sep 2023 14:34:22 +0000 Subject: [PATCH 12/15] rm debug --- test/test_integrals.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_integrals.py b/test/test_integrals.py index 22346be..8066b5f 100644 --- a/test/test_integrals.py +++ b/test/test_integrals.py @@ -218,7 +218,6 @@ def test_water_eri(sparse): actual = eri_basis_sparse(basis) if sparse else eri_basis(basis) aosym = "s8" if sparse else "s1" expect = to_pyscf(h2o, basis_name=basis_name).intor("int2e_cart", aosym=aosym) - print("max |actual - expect| ={}", np.max(np.abs(actual - expect))) assert_allclose(actual, expect, atol=1e-4) From 288d89448da4016477e0eb212c75b2a87cac2b8a Mon Sep 17 00:00:00 2001 From: Hatem Helal Date: Mon, 25 Sep 2023 09:00:57 +0000 Subject: [PATCH 13/15] replace vmap with list comprehension --- pyscf_ipu/experimental/nuclear_gradients.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/pyscf_ipu/experimental/nuclear_gradients.py b/pyscf_ipu/experimental/nuclear_gradients.py index e01976a..e72b881 100644 --- a/pyscf_ipu/experimental/nuclear_gradients.py +++ b/pyscf_ipu/experimental/nuclear_gradients.py @@ -36,13 +36,9 @@ def grad_primitive_integral( Float3: Gradient of the integral with respect to cartesian axes. """ - axes = jnp.arange(3) - lhs_p1 = vmap(a.offset_lmn, (0, None))(axes, 1) - t1 = 2 * a.alpha * vmap(primitive_op, (0, None))(lhs_p1, b) - - lhs_m1 = vmap(a.offset_lmn, (0, None))(axes, -1) - t2 = a.lmn * vmap(primitive_op, (0, None))(lhs_m1, b) - grad_out = t1 - t2 + t1 = [primitive_op(a.offset_lmn(ax, 1), b) for ax in range(3)] + t2 = [primitive_op(a.offset_lmn(ax, -1), b) for ax in range(3)] + grad_out = 2 * a.alpha * jnp.stack(t1) - a.lmn * jnp.stack(t2) return grad_out From 64deee4cbd6e9f81800301b86bfbabe47e076961 Mon Sep 17 00:00:00 2001 From: Hatem Helal Date: Tue, 26 Sep 2023 08:08:25 +0000 Subject: [PATCH 14/15] add num_segments to support jit compilation --- pyscf_ipu/experimental/nuclear_gradients.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pyscf_ipu/experimental/nuclear_gradients.py b/pyscf_ipu/experimental/nuclear_gradients.py index e72b881..f008e1c 100644 --- a/pyscf_ipu/experimental/nuclear_gradients.py +++ b/pyscf_ipu/experimental/nuclear_gradients.py @@ -111,8 +111,10 @@ def take_primitives(indices): out = cl * cr * out.T out = out.reshape(3, basis.num_primitives, basis.num_primitives) - out = segment_sum(jnp.rollaxis(out, 1), orbital_index) - out = segment_sum(jnp.rollaxis(out, -1), orbital_index) + out = jnp.rollaxis(out, 1) + out = segment_sum(out, orbital_index, num_segments=basis.num_orbitals) + out = jnp.rollaxis(out, -1) + out = segment_sum(out, orbital_index, num_segments=basis.num_orbitals) return jnp.rollaxis(out, -1) From 58e6383a853fb4680f7131cad31733b4e0d931eb Mon Sep 17 00:00:00 2001 From: Hatem Helal Date: Tue, 26 Sep 2023 18:59:05 +0000 Subject: [PATCH 15/15] separate test file --- pyscf_ipu/experimental/interop.py | 6 +- pyscf_ipu/experimental/nuclear_gradients.py | 1 - pyscf_ipu/experimental/types.py | 1 + test/test_integrals.py | 128 -------------------- test/test_nuclear_gradients.py | 33 +++++ 5 files changed, 36 insertions(+), 133 deletions(-) create mode 100644 test/test_nuclear_gradients.py diff --git a/pyscf_ipu/experimental/interop.py b/pyscf_ipu/experimental/interop.py index 0b57574..dc724e9 100644 --- a/pyscf_ipu/experimental/interop.py +++ b/pyscf_ipu/experimental/interop.py @@ -12,11 +12,9 @@ def to_pyscf( structure: Structure, basis_name: str = "sto-3g", unit: str = "Bohr" ) -> "gto.Mole": + position = np.asarray(structure.position) mol = gto.Mole(unit=unit, spin=structure.num_electrons % 2, cart=True) - mol.atom = [ - (symbol, pos) - for symbol, pos in zip(structure.atomic_symbol, structure.position) - ] + mol.atom = [(symbol, pos) for symbol, pos in zip(structure.atomic_symbol, position)] mol.basis = basis_name mol.build(unit=unit) return mol diff --git a/pyscf_ipu/experimental/nuclear_gradients.py b/pyscf_ipu/experimental/nuclear_gradients.py index f008e1c..3189251 100644 --- a/pyscf_ipu/experimental/nuclear_gradients.py +++ b/pyscf_ipu/experimental/nuclear_gradients.py @@ -35,7 +35,6 @@ def grad_primitive_integral( Returns: Float3: Gradient of the integral with respect to cartesian axes. """ - t1 = [primitive_op(a.offset_lmn(ax, 1), b) for ax in range(3)] t2 = [primitive_op(a.offset_lmn(ax, -1), b) for ax in range(3)] grad_out = 2 * a.alpha * jnp.stack(t1) - a.lmn * jnp.stack(t2) diff --git a/pyscf_ipu/experimental/types.py b/pyscf_ipu/experimental/types.py index 642d3e8..07a8848 100644 --- a/pyscf_ipu/experimental/types.py +++ b/pyscf_ipu/experimental/types.py @@ -5,6 +5,7 @@ Float3 = Float[Array, "3"] Float3xNxN = Float[Array, "3 N N"] +Float3xNxNxNxN = Float[Array, "3 N N N N"] FloatNx3 = Float[Array, "N 3"] FloatN = Float[Array, "N"] FloatNxN = Float[Array, "N N"] diff --git a/test/test_integrals.py b/test/test_integrals.py index 8066b5f..00b359a 100644 --- a/test/test_integrals.py +++ b/test/test_integrals.py @@ -2,11 +2,9 @@ import jax.numpy as jnp import numpy as np import pytest -from jax import tree_map, vmap from numpy.testing import assert_allclose from pyscf_ipu.experimental.basis import basisset -from pyscf_ipu.experimental.device import has_ipu, ipu_func from pyscf_ipu.experimental.integrals import ( eri_basis, eri_basis_sparse, @@ -19,51 +17,10 @@ overlap_primitives, ) from pyscf_ipu.experimental.interop import to_pyscf -from pyscf_ipu.experimental.nuclear_gradients import ( - grad_kinetic_basis, - grad_nuclear_basis, - grad_overlap_basis, -) from pyscf_ipu.experimental.primitive import Primitive from pyscf_ipu.experimental.structure import molecule -@pytest.mark.parametrize("basis_name", ["sto-3g", "6-31g**"]) -def test_to_pyscf(basis_name): - mol = molecule("water") - basis = basisset(mol, basis_name) - pyscf_mol = to_pyscf(mol, basis_name) - assert basis.num_orbitals == pyscf_mol.nao - - -@pytest.mark.parametrize("basis_name", ["sto-3g", "6-31+g"]) -def test_gto(basis_name): - from pyscf.dft.numint import eval_rho - - # Atomic orbitals - structure = molecule("water") - basis = basisset(structure, basis_name) - mesh, _ = uniform_mesh() - actual = basis(mesh) - - mol = to_pyscf(structure, basis_name) - expect_ao = mol.eval_gto("GTOval_cart", np.asarray(mesh)) - assert_allclose(actual, expect_ao, atol=1e-6) - - # Molecular orbitals - mf = mol.KS() - mf.kernel() - C = jnp.array(mf.mo_coeff, dtype=jnp.float32) - actual = basis.occupancy * C @ C.T - expect = jnp.array(mf.make_rdm1(), dtype=jnp.float32) - assert_allclose(actual, expect, atol=1e-6) - - # Electron density - actual = electron_density(basis, mesh, C) - expect = eval_rho(mol, expect_ao, mf.make_rdm1(), "lda") - assert_allclose(actual, expect, atol=1e-6) - - def test_overlap(): # Exercise 3.21 of "Modern quantum chemistry: introduction to advanced # electronic structure theory."" by Szabo and Ostlund @@ -151,19 +108,6 @@ def test_water_nuclear(): assert_allclose(actual, expect, atol=1e-4) -def eri_orbitals(orbitals): - def take(orbital, index): - p = tree_map(lambda *xs: jnp.stack(xs), *orbital.primitives) - p = tree_map(lambda x: jnp.take(x, index, axis=0), p) - c = jnp.take(orbital.coefficients, index) - return p, c - - indices = [jnp.arange(o.num_primitives) for o in orbitals] - indices = [i.reshape(-1) for i in jnp.meshgrid(*indices)] - prim, coef = zip(*[take(o, i) for o, i in zip(orbitals, indices)]) - return jnp.sum(jnp.prod(jnp.stack(coef), axis=0) * vmap(eri_primitives)(*prim)) - - def test_eri(): # PyQuante test cases for ERI a, b, c, d = [Primitive()] * 4 @@ -172,18 +116,6 @@ def test_eri(): c, d = [Primitive(lmn=jnp.array([1, 0, 0]))] * 2 assert_allclose(eri_primitives(a, b, c, d), 0.940316, atol=1e-5) - # H2 molecule in sto-3g: See equation 3.235 of Szabo and Ostlund - h2 = molecule("h2") - basis = basisset(h2, "sto-3g") - indices = [(0, 0, 0, 0), (0, 0, 1, 1), (1, 0, 0, 0), (1, 0, 1, 0)] - expected = [0.7746, 0.5697, 0.4441, 0.2970] - - for ijkl, expect in zip(indices, expected): - actual = eri_orbitals([basis.orbitals[aoid] for aoid in ijkl]) - assert_allclose(actual, expect, atol=1e-4) - - -def test_eri_basis(): # H2 molecule in sto-3g: See equation 3.235 of Szabo and Ostlund h2 = molecule("h2") basis = basisset(h2, "sto-3g") @@ -219,63 +151,3 @@ def test_water_eri(sparse): aosym = "s8" if sparse else "s1" expect = to_pyscf(h2o, basis_name=basis_name).intor("int2e_cart", aosym=aosym) assert_allclose(actual, expect, atol=1e-4) - - -@pytest.mark.skipif(not has_ipu(), reason="Skipping ipu test!") -def test_ipu_overlap(): - from pyscf_ipu.experimental.integrals import _overlap_primitives - - a, b = [Primitive()] * 2 - actual = ipu_func(_overlap_primitives)(a, b) - assert_allclose(actual, overlap_primitives(a, b)) - - -@pytest.mark.skipif(not has_ipu(), reason="Skipping ipu test!") -def test_ipu_kinetic(): - from pyscf_ipu.experimental.integrals import _kinetic_primitives - - a, b = [Primitive()] * 2 - actual = ipu_func(_kinetic_primitives)(a, b) - assert_allclose(actual, kinetic_primitives(a, b)) - - -@pytest.mark.skipif(not has_ipu(), reason="Skipping ipu test!") -def test_ipu_nuclear(): - from pyscf_ipu.experimental.integrals import _nuclear_primitives - - # PyQuante test case for nuclear attraction integral - a, b = [Primitive()] * 2 - c = jnp.zeros(3) - actual = ipu_func(_nuclear_primitives)(a, b, c) - assert_allclose(actual, -1.595769, atol=1e-5) - - -@pytest.mark.skipif(not has_ipu(), reason="Skipping ipu test!") -def test_ipu_eri(): - from pyscf_ipu.experimental.integrals import _eri_primitives - - # PyQuante test cases for ERI - a, b, c, d = [Primitive()] * 4 - actual = ipu_func(_eri_primitives)(a, b, c, d) - assert_allclose(actual, 1.128379, atol=1e-5) - - -@pytest.mark.parametrize("basis_name", ["sto-3g", "6-31+g"]) -def test_nuclear_gradients(basis_name): - h2 = molecule("h2") - scfmol = to_pyscf(h2, basis_name) - basis = basisset(h2, basis_name) - - actual = grad_overlap_basis(basis) - expect = scfmol.intor("int1e_ipovlp_cart", comp=3) - assert_allclose(actual, expect, atol=1e-6) - - actual = grad_kinetic_basis(basis) - expect = scfmol.intor("int1e_ipkin_cart", comp=3) - assert_allclose(actual, expect, atol=1e-6) - - # TODO: investigate possible inconsistency in libcint outputs? - actual = grad_nuclear_basis(basis) - expect = scfmol.intor("int1e_ipnuc_cart", comp=3) - expect = -np.moveaxis(expect, 1, 2) - assert_allclose(actual, expect, atol=1e-6) diff --git a/test/test_nuclear_gradients.py b/test/test_nuclear_gradients.py new file mode 100644 index 0000000..2688acd --- /dev/null +++ b/test/test_nuclear_gradients.py @@ -0,0 +1,33 @@ +# Copyright (c) 2023 Graphcore Ltd. All rights reserved. +import numpy as np +from numpy.testing import assert_allclose + +from pyscf_ipu.experimental.basis import basisset +from pyscf_ipu.experimental.interop import to_pyscf +from pyscf_ipu.experimental.nuclear_gradients import ( + grad_kinetic_basis, + grad_nuclear_basis, + grad_overlap_basis, +) +from pyscf_ipu.experimental.structure import molecule + + +def test_nuclear_gradients(): + basis_name = "sto-3g" + h2 = molecule("h2") + scfmol = to_pyscf(h2, basis_name) + basis = basisset(h2, basis_name) + + actual = grad_overlap_basis(basis) + expect = scfmol.intor("int1e_ipovlp_cart", comp=3) + assert_allclose(actual, expect, atol=1e-6) + + actual = grad_kinetic_basis(basis) + expect = scfmol.intor("int1e_ipkin_cart", comp=3) + assert_allclose(actual, expect, atol=1e-6) + + # TODO: investigate possible inconsistency in libcint outputs? + actual = grad_nuclear_basis(basis) + expect = scfmol.intor("int1e_ipnuc_cart", comp=3) + expect = -np.moveaxis(expect, 1, 2) + assert_allclose(actual, expect, atol=1e-6)