-
Notifications
You must be signed in to change notification settings - Fork 2
WIP: nuclear gradients #101
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
1d6a620
345d115
d0e0121
5ee86e3
f41070e
38bcc4c
1b2f2d5
4845ab1
dec9cd2
cfddd28
c231e2a
c049a0a
288d894
64deee4
58e6383
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,162 @@ | ||
| # Copyright (c) 2023 Graphcore Ltd. All rights reserved. | ||
| from functools import partial | ||
| from typing import Callable | ||
|
|
||
| import jax.numpy as jnp | ||
| from jax import tree_map, vmap | ||
| from jax.ops import segment_sum | ||
|
|
||
| from .basis import Basis | ||
| from .integrals import _kinetic_primitives, _nuclear_primitives, _overlap_primitives | ||
| from .orbital import batch_orbitals | ||
| from .primitive import Primitive | ||
| from .types import Float3, Float3xNxN | ||
|
|
||
|
|
||
def grad_primitive_integral(
    primitive_op: Callable, a: Primitive, b: Primitive
) -> Float3:
    """Gradient of a one-electron primitive integral with respect to a.center.

    The quantity evaluated is

        < grad_a a | p | b >

    which, for Gaussian primitives, reduces per cartesian axis to

        2 * alpha < a(l+1) | p | b > - l * < a(l-1) | p | b >

    where a(l+1) / a(l-1) is ``a`` with the lmn component of that axis raised or
    lowered by one.

    Args:
        primitive_op (Callable): integral operation over two primitives
            (a, b) -> float
        a (Primitive): left hand side of the integral; the gradient is taken with
            respect to its center.
        b (Primitive): right hand side of the integral.

    Returns:
        Float3: one gradient component per cartesian axis.
    """

    def shifted_integrals(step: int):
        # One integral per cartesian axis with that axis' lmn entry moved by `step`.
        return jnp.stack(
            [primitive_op(a.offset_lmn(axis, step), b) for axis in range(3)]
        )

    raised = shifted_integrals(1)
    # The lowered integrals are evaluated even for axes whose lmn entry is zero;
    # those terms are then scaled by that zero entry of a.lmn below.
    lowered = shifted_integrals(-1)
    return 2 * a.alpha * raised - a.lmn * lowered
|
|
||
|
|
||
def grad_overlap_primitives(a: Primitive, b: Primitive) -> Float3:
    """Gradient of the overlap integral between two primitives.

    The derivative is taken with respect to a.center and is obtained by passing
    the overlap primitive integral to ``grad_primitive_integral``.

    Args:
        a (Primitive): left hand side of the overlap integral.
        b (Primitive): right hand side of the overlap integral.

    Returns:
        Float3: Gradient of the overlap integral with respect to cartesian axes.
    """
    return grad_primitive_integral(primitive_op=_overlap_primitives, a=a, b=b)
|
|
||
|
|
||
def grad_kinetic_primitives(a: Primitive, b: Primitive) -> Float3:
    """Gradient of the kinetic energy integral between two primitives.

    The derivative is taken with respect to a.center and is obtained by passing
    the kinetic primitive integral to ``grad_primitive_integral``.

    Args:
        a (Primitive): left hand side of the kinetic energy integral.
        b (Primitive): right hand side of the kinetic energy integral.

    Returns:
        Float3: Gradient of the kinetic energy integral with respect to cartesian
            axes.
    """
    return grad_primitive_integral(primitive_op=_kinetic_primitives, a=a, b=b)
|
|
||
|
|
||
def grad_nuclear_primitives(a: Primitive, b: Primitive, c: Float3) -> Float3:
    """Gradient of the nuclear attraction integral between two primitives.

    The nuclear potential is centered on ``c``; the derivative is taken with
    respect to a.center only.

    Args:
        a (Primitive): left hand side of the nuclear attraction integral.
        b (Primitive): right hand side of the nuclear attraction integral.
        c (Float3): center for the nuclear attraction potential.

    Returns:
        Float3: Gradient of the nuclear attraction integral with respect to
            cartesian axes.
    """

    def attraction_about_c(lhs, rhs):
        # Two-primitive integral with the potential center pinned to `c`.
        return _nuclear_primitives(lhs, rhs, c=c)

    return grad_primitive_integral(attraction_about_c, a, b)
|
|
||
|
|
||
def grad_integrate(basis: Basis, primitive_op: Callable) -> Float3xNxN:
    """Gradient of a one-electron integral over the basis set of atomic orbitals.

    ``primitive_op`` is evaluated for every ordered pair of primitives in the
    basis, each result is weighted by the two contraction coefficients, and the
    primitive axes are then summed into their parent orbitals.

    Args:
        basis (Basis): basis set of N atomic orbitals
        primitive_op (Callable): gradient operation over two primitives; its
            batched result is reshaped to three components per primitive pair.

    Returns:
        Float3xNxN: Gradient of the integral with respect to cartesian axes
            evaluated over the NxN combinations of atomic orbitals.
    """
    primitives, coefficients, orbital_index = batch_orbitals(basis.orbitals)
    num_prim = basis.num_primitives
    num_orb = basis.num_orbitals

    def gather(indices):
        # Select the same rows from every leaf of the batched primitives and
        # from the coefficient vector.
        picked = tree_map(lambda leaf: jnp.take(leaf, indices, axis=0), primitives)
        return picked, jnp.take(coefficients, indices)

    def contract_leading_axis(values):
        # Sum primitives belonging to the same orbital along the leading axis.
        return segment_sum(values, orbital_index, num_segments=num_orb)

    # Every (lhs, rhs) primitive pair, flattened row-major over the "ij" grid.
    prim_ids = jnp.arange(num_prim)
    row_ids, col_ids = jnp.meshgrid(prim_ids, prim_ids, indexing="ij")
    lhs, lhs_coef = gather(row_ids.reshape(-1))
    rhs, rhs_coef = gather(col_ids.reshape(-1))

    pair_grads = vmap(primitive_op)(lhs, rhs)
    weighted = lhs_coef * rhs_coef * pair_grads.T
    weighted = weighted.reshape(3, num_prim, num_prim)

    # Bring one primitive axis to the front at a time, contract it, then restore
    # the cartesian component as the leading axis.
    # NOTE(review): following these axis moves, the result appears to be laid out
    # as [component, rhs orbital, lhs orbital] -- confirm this is the intended
    # order of the two orbital axes.
    partial_sum = contract_leading_axis(jnp.moveaxis(weighted, 1, 0))
    full_sum = contract_leading_axis(jnp.moveaxis(partial_sum, -1, 0))
    return jnp.moveaxis(full_sum, -1, 0)
|
|
||
|
|
||
def grad_overlap_basis(basis: Basis) -> Float3xNxN:
    """Gradient of the overlap integral over the basis set of atomic orbitals.

    Args:
        basis (Basis): basis set of N atomic orbitals

    Returns:
        Float3xNxN: Gradient of the overlap integral with respect to cartesian axes
            evaluated over the NxN combinations of atomic orbitals.
    """
    return grad_integrate(basis=basis, primitive_op=grad_overlap_primitives)
|
|
||
|
|
||
def grad_kinetic_basis(basis: Basis) -> Float3xNxN:
    """Gradient of the kinetic energy integral over the basis set of atomic orbitals.

    Args:
        basis (Basis): basis set of N atomic orbitals

    Returns:
        Float3xNxN: Gradient of the kinetic energy integral with respect to
            cartesian axes evaluated over the NxN combinations of atomic orbitals.
    """
    return grad_integrate(basis=basis, primitive_op=grad_kinetic_primitives)
|
|
||
|
|
||
def grad_nuclear_basis(basis: Basis) -> Float3xNxN:
    """Gradient of the nuclear attraction integral over the basis set of atomic
    orbitals.

    One gradient is computed per nucleus in ``basis.structure`` (potential
    centered on that nucleus' position), scaled by its atomic number, and the
    per-nucleus results are summed.

    Args:
        basis (Basis): basis set of N atomic orbitals

    Returns:
        Float3xNxN: Gradient of the nuclear attraction integral with respect to
            cartesian axes evaluated over the NxN combinations of atomic orbitals.
    """
    structure = basis.structure

    def single_nucleus(center, charge):
        # Gradient for the potential of one nucleus, weighted by its charge.
        centered_op = partial(grad_nuclear_primitives, c=center)
        return charge * grad_integrate(basis, centered_op)

    per_nucleus = vmap(single_nucleus)(structure.position, structure.atomic_number)
    # Leading axis of the batched result indexes the nuclei.
    return jnp.sum(per_nucleus, axis=0)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,5 @@ | ||
| # Copyright (c) 2023 Graphcore Ltd. All rights reserved. | ||
| from copy import deepcopy | ||
| from typing import Optional | ||
|
|
||
| import chex | ||
|
|
@@ -14,6 +15,7 @@ class Primitive: | |
| alpha: float = 1.0 | ||
| lmn: Int3 = jnp.zeros(3, dtype=jnp.int32) | ||
| norm: Optional[float] = None | ||
| atom_index: Optional[int] = None | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is a back-pointer into the list of which this is a member?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Indeed, the problem I haven't solved is not redundantly storing the center on both the primitives and in the |
||
|
|
||
| def __post_init__(self): | ||
| if self.norm is None: | ||
|
|
@@ -26,6 +28,11 @@ def angular_momentum(self) -> int: | |
def __call__(self, pos: FloatNx3) -> FloatN:
    """Evaluate this primitive at the given positions via ``eval_primitive``.

    Args:
        pos (FloatNx3): positions at which to evaluate the primitive.

    Returns:
        FloatN: the result of ``eval_primitive`` for this primitive and ``pos``.
    """
    values = eval_primitive(self, pos)
    return values
|
|
||
def offset_lmn(self, axis: int, offset: int) -> "Primitive":
    """Return a copy of this primitive with one lmn component shifted.

    The original primitive is left untouched. All other fields are carried over
    by the deep copy as they are.

    Args:
        axis (int): index into lmn of the component to change.
        offset (int): amount added to that component (may be negative).

    Returns:
        Primitive: deep copy of ``self`` with ``lmn[axis]`` increased by ``offset``.
    """
    shifted_lmn = self.lmn.at[axis].add(offset)
    shifted = deepcopy(self)
    shifted.lmn = shifted_lmn
    return shifted
|
|
||
|
|
||
| def normalize(lmn: Int3, alpha: float) -> float: | ||
| L = jnp.sum(lmn) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,33 @@ | ||
| # Copyright (c) 2023 Graphcore Ltd. All rights reserved. | ||
| import numpy as np | ||
| from numpy.testing import assert_allclose | ||
|
|
||
| from pyscf_ipu.experimental.basis import basisset | ||
| from pyscf_ipu.experimental.interop import to_pyscf | ||
| from pyscf_ipu.experimental.nuclear_gradients import ( | ||
| grad_kinetic_basis, | ||
| grad_nuclear_basis, | ||
| grad_overlap_basis, | ||
| ) | ||
| from pyscf_ipu.experimental.structure import molecule | ||
|
|
||
|
|
||
def test_nuclear_gradients():
    """Compare the basis-level gradient integrals for H2 / STO-3G with PySCF."""
    basis_name = "sto-3g"
    h2 = molecule("h2")
    scfmol = to_pyscf(h2, basis_name)
    basis = basisset(h2, basis_name)

    # Overlap and kinetic gradients are compared directly with the reference.
    direct_cases = (
        (grad_overlap_basis, "int1e_ipovlp_cart"),
        (grad_kinetic_basis, "int1e_ipkin_cart"),
    )
    for grad_fn, intor_name in direct_cases:
        actual = grad_fn(basis)
        expect = scfmol.intor(intor_name, comp=3)
        assert_allclose(actual, expect, atol=1e-6)

    # TODO: investigate possible inconsistency in libcint outputs?
    # The nuclear reference is negated and has its two orbital axes swapped
    # before the comparison, unlike the overlap and kinetic cases above.
    actual = grad_nuclear_basis(basis)
    reference = scfmol.intor("int1e_ipnuc_cart", comp=3)
    expect = -np.moveaxis(reference, 1, 2)
    assert_allclose(actual, expect, atol=1e-6)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like to define these local lambdas after the variables to which they refer have been defined. It puts the computation closer to the point of use.
(i.e. move it after line 37)
OTOH,
`batch_orbitals` is already doing a lot of Python list comprehensions, so it might be clearer and the same complexity to treat `b` as a list of lists of primitives, and just inline the list comprehensions here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll have a think, this lambda is actually used in a few places now -> maybe it should be promoted from a function-local to a definition in a module?