Skip to content
1 change: 1 addition & 0 deletions pyscf_ipu/experimental/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def basisset(structure: Structure, basis_name: str = "sto-3g"):
alphas=jnp.array(s["exponents"], dtype=float),
lmn=jnp.array(lmn, dtype=jnp.int32),
coefficients=jnp.array(s["coefficients"], dtype=float),
atom_index=a,
)
orbitals.append(ao)

Expand Down
9 changes: 2 additions & 7 deletions pyscf_ipu/experimental/integrals.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from dataclasses import asdict
from functools import partial
from itertools import product as cartesian_product
from typing import Callable
Expand Down Expand Up @@ -84,15 +83,11 @@ def _overlap_primitives(a: Primitive, b: Primitive) -> float:
def _kinetic_primitives(a: Primitive, b: Primitive) -> float:
t0 = b.alpha * (2 * jnp.sum(b.lmn) + 3) * _overlap_primitives(a, b)

def offset_qn(ax: int, offset: int):
lmn = b.lmn.at[ax].add(offset)
return Primitive(**{**asdict(b), "lmn": lmn})

axes = jnp.arange(3)
b1 = vmap(offset_qn, (0, None))(axes, 2)
b1 = vmap(b.offset_lmn, (0, None))(axes, 2)
t1 = jnp.sum(vmap(_overlap_primitives, (None, 0))(a, b1))

b2 = vmap(offset_qn, (0, None))(axes, -2)
b2 = vmap(b.offset_lmn, (0, None))(axes, -2)
t2 = jnp.sum(b.lmn * (b.lmn - 1) * vmap(_overlap_primitives, (None, 0))(a, b2))
return t0 - 2.0 * b.alpha**2 * t1 - 0.5 * t2

Expand Down
6 changes: 2 additions & 4 deletions pyscf_ipu/experimental/interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,9 @@
def to_pyscf(
structure: Structure, basis_name: str = "sto-3g", unit: str = "Bohr"
) -> "gto.Mole":
position = np.asarray(structure.position)
mol = gto.Mole(unit=unit, spin=structure.num_electrons % 2, cart=True)
mol.atom = [
(symbol, pos)
for symbol, pos in zip(structure.atomic_symbol, structure.position)
]
mol.atom = [(symbol, pos) for symbol, pos in zip(structure.atomic_symbol, position)]
mol.basis = basis_name
mol.build(unit=unit)
return mol
Expand Down
162 changes: 162 additions & 0 deletions pyscf_ipu/experimental/nuclear_gradients.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from functools import partial
from typing import Callable

import jax.numpy as jnp
from jax import tree_map, vmap
from jax.ops import segment_sum

from .basis import Basis
from .integrals import _kinetic_primitives, _nuclear_primitives, _overlap_primitives
from .orbital import batch_orbitals
from .primitive import Primitive
from .types import Float3, Float3xNxN


def grad_primitive_integral(
primitive_op: Callable, a: Primitive, b: Primitive
) -> Float3:
"""gradient of a one-electron integral over primitive functions defined as:

< grad_a a | p | b >

where the cartesian gradient is evaluated with respect to a.center. For Gaussian
primitives this gradient simplifies to:

2 * alpha < a(l+1) | p | b > - l * < a(l-1) | p | b >

where a(l+/-1) -> offset the lmn component for the corresponding gradient component.

Args:
primitive_op (Callable): integral operation over two primitives (a, b) -> float
a (Primitive): left hand side of the integral
b (Primitive): right hand side of the integral

Returns:
Float3: Gradient of the integral with respect to cartesian axes.
"""
t1 = [primitive_op(a.offset_lmn(ax, 1), b) for ax in range(3)]
t2 = [primitive_op(a.offset_lmn(ax, -1), b) for ax in range(3)]
grad_out = 2 * a.alpha * jnp.stack(t1) - a.lmn * jnp.stack(t2)
return grad_out


def grad_overlap_primitives(a: Primitive, b: Primitive) -> Float3:
"""Evaluate the gradient of the overlap integral between primitives a and b. The
gradient is with respect to a.center.

Args:
a (Primitive): left hand side of the overlap integral.
b (Primitive): right hand side of the overlap integral.

Returns:
Float3: Gradient of the overlap integral with respect to cartesian axes.
"""
return grad_primitive_integral(_overlap_primitives, a, b)


def grad_kinetic_primitives(a: Primitive, b: Primitive) -> Float3:
"""Evaluate the gradient of the kinetic energy integral between primitives a and b.
The gradient is with respect to a.center.

Args:
a (Primitive): left hand side of the kinetic energy integral.
b (Primitive): right hand side of the kinetic energy integral.

Returns:
Float3: Gradient of the kinetic energy integral with respect to cartesian axes.
"""
return grad_primitive_integral(_kinetic_primitives, a, b)


def grad_nuclear_primitives(a: Primitive, b: Primitive, c: Float3) -> Float3:
"""Evaluate the gradient of the nuclear attraction integral between primitives a and
b, and the nuclear potential centered on c. Gradient is with respect to a.center.

Args:
a (Primitive): left hand side of the nuclear attraction integral.
b (Primitive): right hand side of the nuclear attraction integral.
c (Float3): center for the nuclear attraction potential 1/(r - c)

Returns:
Float3: Gradient of the nuclear attraction integral with respect to cartesian
axes
"""
return grad_primitive_integral(partial(_nuclear_primitives, c=c), a, b)


def grad_integrate(basis: Basis, primitive_op: Callable) -> Float3xNxN:
"""gradient of a one-electron integral over the basis set of atomic orbitals.

Args:
basis (Basis): basis set of N atomic orbitals
primitive_op (Callable): integral operation over two primitives (a, b) -> float

Returns:
Float3xNxN: Gradient of the integral with respect to cartesian axes evaluated
over the NxN combinations of atomic orbitals.
"""

def take_primitives(indices):
p = tree_map(lambda x: jnp.take(x, indices, axis=0), primitives)
c = jnp.take(coefficients, indices)
return p, c
Comment on lines +100 to +103
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like to define these local lambdas after the variables to which they refer have been defined. It puts the computation closer to the point of use.

(i.e. move it after line 37)

OTOH, batch_orbitals is already doing a lot of Python list comprehensions, so it might be clearer and the same complexity to treat b as a list of lists of primitives, and just inline the list comprehensions here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll have a think, this lambda is actually used in a few places now -> maybe it should be promoted from a function-local to a definition in a module?


primitives, coefficients, orbital_index = batch_orbitals(basis.orbitals)
ii, jj = jnp.meshgrid(*[jnp.arange(basis.num_primitives)] * 2, indexing="ij")
lhs, cl = take_primitives(ii.reshape(-1))
rhs, cr = take_primitives(jj.reshape(-1))
out = vmap(primitive_op)(lhs, rhs)

out = cl * cr * out.T
out = out.reshape(3, basis.num_primitives, basis.num_primitives)
out = jnp.rollaxis(out, 1)
out = segment_sum(out, orbital_index, num_segments=basis.num_orbitals)
out = jnp.rollaxis(out, -1)
out = segment_sum(out, orbital_index, num_segments=basis.num_orbitals)
return jnp.rollaxis(out, -1)


def grad_overlap_basis(basis: Basis) -> Float3xNxN:
"""gradient of the overlap integral over the basis set of atomic orbitals

Args:
basis (Basis): basis set of N atomic orbitals

Returns:
Float3xNxN: Gradient of the overlap integral with respect to cartesian axes
evaluated over the NxN combinations of atomic orbitals.
"""
return grad_integrate(basis, grad_overlap_primitives)


def grad_kinetic_basis(basis: Basis) -> Float3xNxN:
"""gradient of the kinetic energy integral over the basis set of atomic orbitals

Args:
basis (Basis): basis set of N atomic orbitals

Returns:
Float3xNxN: Gradient of the kinetic energy integral with respect to cartesian
axes evaluated over the NxN combinations of atomic orbitals.
"""
return grad_integrate(basis, grad_kinetic_primitives)


def grad_nuclear_basis(basis: Basis) -> Float3xNxN:
"""gradient of the nuclear attraction integral over the basis set of atomic orbitals

Args:
basis (Basis): basis set of N atomic orbitals

Returns:
Float3xNxN: Gradient of the nuclear attraction integral with respect to
cartesian axes evaluated over the NxN combinations of atomic orbitals.
"""

def nuclear(c, z):
op = partial(grad_nuclear_primitives, c=c)
return z * grad_integrate(basis, op)

out = vmap(nuclear)(basis.structure.position, basis.structure.atomic_number)
return jnp.sum(out, axis=0)
7 changes: 5 additions & 2 deletions pyscf_ipu/experimental/orbital.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,13 @@ def eval_orbital(p: Primitive, coef: float, pos: FloatNx3):
return out

@staticmethod
def from_bse(center, alphas, lmn, coefficients):
def from_bse(center, alphas, lmn, coefficients, atom_index):
coefficients = coefficients.reshape(-1)
assert len(coefficients) == len(alphas), "Expecting same size vectors!"
p = [Primitive(center=center, alpha=a, lmn=lmn) for a in alphas]
p = [
Primitive(center=center, alpha=a, lmn=lmn, atom_index=atom_index)
for a in alphas
]
return Orbital(primitives=p, coefficients=coefficients)


Expand Down
7 changes: 7 additions & 0 deletions pyscf_ipu/experimental/primitive.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from copy import deepcopy
from typing import Optional

import chex
Expand All @@ -14,6 +15,7 @@ class Primitive:
alpha: float = 1.0
lmn: Int3 = jnp.zeros(3, dtype=jnp.int32)
norm: Optional[float] = None
atom_index: Optional[int] = None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a back-pointer into the list of which this is a member?
That's also a hint that the where(a.atom_index == atom_index) could be lifted.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

indeed, the problem I haven't solved is not redundantly storing the center on both the primitives and in the Structure


def __post_init__(self):
if self.norm is None:
Expand All @@ -26,6 +28,11 @@ def angular_momentum(self) -> int:
def __call__(self, pos: FloatNx3) -> FloatN:
return eval_primitive(self, pos)

def offset_lmn(self, axis: int, offset: int) -> "Primitive":
out = deepcopy(self)
out.lmn = self.lmn.at[axis].add(offset)
return out


def normalize(lmn: Int3, alpha: float) -> float:
L = jnp.sum(lmn)
Expand Down
2 changes: 2 additions & 0 deletions pyscf_ipu/experimental/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from jaxtyping import Array, Float, Int

Float3 = Float[Array, "3"]
Float3xNxN = Float[Array, "3 N N"]
Float3xNxNxNxN = Float[Array, "3 N N N N"]
FloatNx3 = Float[Array, "N 3"]
FloatN = Float[Array, "N"]
FloatNxN = Float[Array, "N N"]
Expand Down
33 changes: 33 additions & 0 deletions test/test_nuclear_gradients.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
import numpy as np
from numpy.testing import assert_allclose

from pyscf_ipu.experimental.basis import basisset
from pyscf_ipu.experimental.interop import to_pyscf
from pyscf_ipu.experimental.nuclear_gradients import (
grad_kinetic_basis,
grad_nuclear_basis,
grad_overlap_basis,
)
from pyscf_ipu.experimental.structure import molecule


def test_nuclear_gradients():
basis_name = "sto-3g"
h2 = molecule("h2")
scfmol = to_pyscf(h2, basis_name)
basis = basisset(h2, basis_name)

actual = grad_overlap_basis(basis)
expect = scfmol.intor("int1e_ipovlp_cart", comp=3)
assert_allclose(actual, expect, atol=1e-6)

actual = grad_kinetic_basis(basis)
expect = scfmol.intor("int1e_ipkin_cart", comp=3)
assert_allclose(actual, expect, atol=1e-6)

# TODO: investigate possible inconsistency in libcint outputs?
actual = grad_nuclear_basis(basis)
expect = scfmol.intor("int1e_ipnuc_cart", comp=3)
expect = -np.moveaxis(expect, 1, 2)
assert_allclose(actual, expect, atol=1e-6)