Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 119 additions & 4 deletions effectful/handlers/jax/_handlers.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,29 @@
import functools
import itertools
import typing
from collections.abc import Callable, Mapping, Sequence
from collections.abc import Callable, Iterable, Mapping, Sequence
from types import EllipsisType
from typing import Annotated

from opt_einsum import get_symbol
from opt_einsum.parser import parse_einsum_input

try:
import jax
import jax.numpy as jnp
except ImportError:
raise ImportError("JAX is required to use effectful.handlers.jax")

from effectful.internals.runtime import interpreter
from effectful.ops.semantics import apply, evaluate, fvsof, typeof
from effectful.ops.semantics import apply, evaluate, fvsof, fwd, typeof
from effectful.ops.syntax import (
Scoped,
_CustomSingleDispatchCallable,
defdata,
deffn,
defop,
syntactic_eq,
syntactic_hash,
)
from effectful.ops.types import Expr, NotHandled, Operation, Term

Expand Down Expand Up @@ -65,9 +70,13 @@ def update_sizes(sizes, op, size):

def _getitem_sizeof(x: jax.Array, key: tuple[Expr[IndexElement], ...]):
    """Record the size of each bare named-dimension term used to index ``x``.

    For an eager array ``x``, walks ``key`` and, for every element that is a
    ``Term`` with no args and no kwargs, calls ``update_sizes`` with that
    term's op and the extent of the axis of ``x`` the element lines up with.

    Args:
        x: The array being indexed.
        key: The index tuple applied to ``x``.

    Returns:
        ``defdata(jax_getitem, x, key)``, i.e. the indexing expression itself.
    """
    if is_eager_array(x):
        # ``axis`` is the axis of ``x`` that the current key element lines up
        # with. It is tracked by hand instead of with ``enumerate`` because
        # ``None`` entries must not advance it (presumably because ``None``
        # inserts a new axis rather than consuming one -- confirm against the
        # jax_getitem indexing rules).
        axis = 0
        for k in key:
            if isinstance(k, Term) and not k.args and not k.kwargs:
                update_sizes(sizes, k.op, x.shape[axis])
            if k is not None:
                axis += 1

    return defdata(jax_getitem, x, key)

def _apply(op, *args, **kwargs):
Expand All @@ -86,6 +95,13 @@ def _partial_eval(t: Expr[jax.Array]) -> Expr[jax.Array]:
if not sized_fvs:
return t

# if any dimension is zero sized, the result is empty
if any(size == 0 for size in sized_fvs.values()):
ops = tuple(sized_fvs.keys())
key = tuple(k() for k in ops)
shape = tuple(sized_fvs[k] for k in ops)
return jax_getitem(jnp.empty(shape), key)

def _is_eager(t):
return not isinstance(t, Term) or t.op in sized_fvs or is_eager_array(t)

Expand Down Expand Up @@ -135,7 +151,11 @@ def _jax_op(*args, **kwargs) -> jax.Array:
and not isinstance(args[0], Term)
and sized_fvs
and args[1]
and all(isinstance(k, Term) and k.op in sized_fvs for k in args[1])
and all(
(isinstance(k, Term) and k.op in sized_fvs)
or (isinstance(k, slice) and k == slice(None))
for k in args[1]
)
):
raise NotHandled
elif sized_fvs and set(sized_fvs.keys()) == fvsof(tm) - {jax_getitem, _jax_op}:
Expand Down Expand Up @@ -175,6 +195,93 @@ def _jax_op(*args, **kwargs) -> jax.Array:
return _jax_op


def _named_dims(term: Expr[jax.Array]) -> tuple[Operation, ...]:
    """Return the distinct named-dimension ops that index ``term``.

    Args:
        term: Any array expression.

    Returns:
        ``()`` unless ``term`` is a ``jax_getitem`` term. Otherwise the ops of
        the bare index terms (no args and no kwargs) found in the index, in
        first-appearance order and without duplicates.
    """
    if not (isinstance(term, Term) and term.op == jax_getitem):
        return ()
    index = term.args[1]
    assert isinstance(index, Iterable)
    # A named dimension is a *bare* term: no positional and no keyword
    # arguments. A term carrying kwargs is a compound expression, not a name.
    bare_ops = (
        i.op for i in index if isinstance(i, Term) and not i.args and not i.kwargs
    )
    # An op may index several axes (e.g. ``arr[i(), i()]``). Callers pass the
    # result straight to ``bind_dims`` and use its length as an axis offset, so
    # each name must appear once; ``dict.fromkeys`` dedupes and keeps order.
    return tuple(dict.fromkeys(bare_ops))


def _reduce_named(array, axis=None, **kwargs) -> jax.Array:
    """Run an axis reduction on an array that may carry named dimensions.

    When ``axis`` is ``None`` or ``array`` has no named dimensions, the call
    is delegated to ``fwd()`` untouched. Otherwise the named dimensions are
    bound to leading positional axes, the requested axes are remapped to
    account for them, the reduction is forwarded, and the names are restored
    with ``unbind_dims``.

    Args:
        array: The array expression to reduce.
        axis: ``None``, a single ``int``, or an iterable of ints.
        **kwargs: Passed through to the forwarded reduction.

    Returns:
        The result of the forwarded reduction, with named dimensions restored
        when they were bound here.
    """
    # Nothing to remap for a full reduction or a purely positional array.
    if axis is None or not (names := _named_dims(array)):
        return fwd()

    lifted = bind_dims(array, *names)

    # Non-negative axes move right by the number of bound leading axes;
    # negative axes count from the end and are passed through unchanged.
    offset = len(names)
    requested = (axis,) if isinstance(axis, int) else axis
    remapped = tuple(a if a < 0 else a + offset for a in requested)

    return unbind_dims(fwd(lifted, axis=remapped, **kwargs), *names)


def _einsum_named(subscripts, *operands, **kwargs) -> jax.Array:
    """Evaluate ``einsum`` over operands that may carry named dimensions.

    Each named dimension is bound to a leading positional axis of its operand
    and given a fresh subscript symbol. A name shared by several operands gets
    the same symbol, and every named symbol is also placed at the front of the
    output, so named dimensions pass through instead of being contracted. The
    leading axes of the result are then re-indexed by their names.

    Args:
        subscripts: The einsum subscripts string. A non-string value selects
            the interleaved form, which is passed to ``jax.numpy.einsum`` only
            when no argument is a ``Term``.
        *operands: The einsum operands.
        **kwargs: Passed through to ``jax.numpy.einsum``.

    Returns:
        The einsum result, indexed by the named dimensions of the operands.

    Raises:
        ValueError: Interleaved form used with a ``Term`` argument.
        NotHandled: Some operand's ``shape`` is itself a ``Term``.
    """
    if not isinstance(subscripts, str):
        # Interleaved form: only plain (non-Term) arguments are supported.
        for candidate in (subscripts, *operands):
            if isinstance(candidate, Term):
                raise ValueError(
                    "Interleaved einsum is not implemented with named tensors"
                )
        return jax.numpy.einsum(subscripts, *operands, **kwargs)

    # A symbolic shape cannot be parsed below; let another handler take it.
    for operand in operands:
        if isinstance(operand.shape, Term):
            raise NotHandled

    dims_per_operand = [_named_dims(operand) for operand in operands]

    # Normalize the subscripts (ellipses expanded, output explicit) from the
    # positional shapes alone; ``shapes=True`` means shapes are passed in place
    # of the operands themselves.
    positional_shapes = [operand.shape for operand in operands]
    lhs, rhs, _ = parse_einsum_input([subscripts, *positional_shapes], shapes=True)
    operand_specs = lhs.split(",")
    assert len(operand_specs) == len(operands)

    # Symbols already present in the normalized subscripts are off limits.
    taken = set(lhs + rhs) - set(",->")
    candidates = map(get_symbol, itertools.count())

    # One symbol per distinct named dim, in first-appearance order (the dict's
    # insertion order doubles as the order of the output's leading axes).
    symbol_for = {}
    for dims in dims_per_operand:
        for dim in dims:
            if dim in symbol_for:
                continue
            symbol = next(s for s in candidates if s not in taken)
            taken.add(symbol)
            symbol_for[dim] = symbol

    # Bind named dims to leading axes and prefix each spec with their symbols.
    lifted, prefixed_specs = [], []
    for operand, dims, spec in zip(operands, dims_per_operand, operand_specs):
        lifted.append(bind_dims(operand, *dims) if dims else operand)
        prefixed_specs.append("".join(symbol_for[d] for d in dims) + spec)

    # Every named symbol leads the output so it is carried through.
    named_order = tuple(symbol_for)
    full_subscripts = (
        ",".join(prefixed_specs) + "->" + "".join(symbol_for.values()) + rhs
    )

    contracted = jax.numpy.einsum(full_subscripts, *lifted, **kwargs)

    # Leading axes follow ``named_order``; index them by name again and keep
    # the remaining output axes as full slices.
    index = tuple(dim() for dim in named_order) + (slice(None),) * len(rhs)
    return jax_getitem(contracted, index)


@_register_jax_op
def jax_getitem(x: jax.Array, key: tuple[IndexElement, ...]) -> jax.Array:
"""Operation for indexing an array. Unlike the standard __getitem__ method,
Expand Down Expand Up @@ -208,6 +315,8 @@ def bind_dims[T, A, B](
>>> bind_dims(t, b, a).shape
(3, 2)
"""
if isinstance(value, Term) and value.op == bind_dims:
return bind_dims(value.args[0], *(names + tuple(value.args[1:])))
if jax.tree_util.treedef_is_leaf(jax.tree.structure(value)):
return __dispatch(typeof(value))(value, *names)
return jax.tree.map(lambda v: bind_dims(v, *names), value)
Expand Down Expand Up @@ -277,3 +386,9 @@ def _(x: jax.Array, other) -> bool:
and x.shape == other.shape
and bool((jnp.asarray(x) == jnp.asarray(other)).all())
)


@syntactic_hash.register(jax.Array)
def _(x: jax.Array) -> int:
    """Hash a concrete array by its shape, dtype and raw element bytes.

    Concrete arrays are not hashable themselves, so the hash is derived from a
    tuple of a type tag, ``x.shape``, ``str(x.dtype)`` and the array's data.

    Args:
        x: The array to hash.

    Returns:
        An integer hash of the array's shape, dtype and contents.
    """
    import numpy as np  # function-scope import: only needed for the byte dump

    # ``bytes(array)`` is NOT a byte dump: ``bytes`` first tries ``__index__``,
    # so a 0-d integer array with value ``n`` yields ``n`` zero bytes, raises
    # ``ValueError`` for negative values, and can allocate a huge buffer for
    # large ones. ``tobytes`` always returns the raw element data.
    data = np.asarray(x).tobytes()
    return hash(("jax.Array", x.shape, str(x.dtype), data))
98 changes: 63 additions & 35 deletions effectful/handlers/jax/_terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,23 @@
import effectful.handlers.jax.numpy as jnp
from effectful.handlers.jax._handlers import (
IndexElement,
_partial_eval,
_register_jax_op,
bind_dims,
jax_getitem,
unbind_dims,
)
from effectful.internals.tensor_utils import _desugar_tensor_index
from effectful.internals.unification import Box, nested_type
from effectful.ops.syntax import defdata
from effectful.ops.types import Expr, NotHandled, Operation, Term


@nested_type.register(jax.Array)
@nested_type.register(jax._src.core.Tracer)
def _(value):
    """Report ``jax.Array`` as the nested type of any array or tracer.

    Registered for both ``jax.Array`` and ``jax._src.core.Tracer``; the
    argument itself is not inspected.
    """
    array_type = Box(jax.Array)
    return array_type


class _IndexUpdateHelper:
"""Helper class to implement array-style .at[index].set() updates for effectful arrays."""

Expand Down Expand Up @@ -451,54 +457,76 @@ def _bind_dims_array(t: jax.Array, *args: Operation[[], jax.Array]) -> jax.Array
>>> bind_dims(t, b, a).shape
(3, 2)
"""

def _evaluate(expr):
if isinstance(expr, Term):
(args, kwargs) = jax.tree.map(_evaluate, (expr.args, expr.kwargs))
return _partial_eval(expr)
if not jax.tree_util.treedef_is_leaf(jax.tree.structure(expr)):
return jax.tree.map(_evaluate, expr)
return expr

if not isinstance(t, Term):
return t

result = _evaluate(t)
if not isinstance(result, Term) or not args:
return result

# ensure that the result is a jax_getitem with an array as the first argument
if not (result.op is jax_getitem and isinstance(result.args[0], jax.Array)):
if not (t.op is jax_getitem and isinstance(t.args[0], jax.Array)):
raise NotHandled

array = result.args[0]
dims = result.args[1]
array = t.args[0]
dims = t.args[1]
assert isinstance(dims, Sequence)
ndim = len(array.shape)

# ensure that the order is a subset of the named dimensions
order_set = set(args)
if not order_set <= set(a.op for a in dims if isinstance(a, Term)):
raise NotHandled

# permute the inner array so that the leading dimensions are in the order
# specified and the trailing dimensions are the remaining named dimensions
# (or slices)
reindex_dims = [
i
for i, o in enumerate(dims)
if not isinstance(o, Term) or o.op not in order_set
]
dim_ops = [a.op if isinstance(a, Term) else None for a in dims]
perm = (
[dim_ops.index(o) for o in args]
+ reindex_dims
+ list(range(len(dims), len(array.shape)))
)
array = jnp.transpose(array, perm)
reindexed = jax_getitem(
array, (slice(None),) * len(args) + tuple(dims[i] for i in reindex_dims)
def axis_op(ax: int) -> Operation | None:
"""The named op of a bare index term at axis ``ax``, else ``None``."""
if ax < len(dims):
d = dims[ax]
if isinstance(d, Term) and not d.args and not d.kwargs:
return d.op
return None

# Assign an einsum id to every axis of ``array``. Axes that share a named op
# get the *same* id — a repeated op (e.g. ``arr[i(), i()]``) ties its axes
# together, which einsum reads as a diagonal. Every other axis (slices, ints,
# fancy indices, compound terms, and trailing positional axes) gets a unique
# id, so einsum simply carries it through to be reindexed below.
op_ids: dict[Operation, int] = {}
in_ids: list[int] = []
next_id = 0
for ax in range(ndim):
op = axis_op(ax)
if op is not None:
if op not in op_ids:
op_ids[op] = next_id
next_id += 1
in_ids.append(op_ids[op])
else:
in_ids.append(next_id)
next_id += 1

# Output order: bound args that actually appear (in the requested order,
# deduplicated by the diagonal merge), then every remaining axis in
# first-appearance order. einsum does the permutation and the diagonals; with
# all distinct ids retained in the output it performs no reduction.
present_arg_ids = [op_ids[o] for o in args if o in op_ids]
seen = set(present_arg_ids)
rest_ids: list[int] = []
for i in in_ids:
if i not in seen:
seen.add(i)
rest_ids.append(i)

array = jnp.einsum(array, in_ids, present_arg_ids + rest_ids)

# Re-apply the original index for each carried axis and re-name unbound op
# axes. Trailing positional axes (first appearance beyond ``dims``) are left
# for jax_getitem to carry implicitly.
first_pos: dict[int, int] = {}
for ax, i in enumerate(in_ids):
first_pos.setdefault(i, ax)

index_expr = (slice(None),) * len(present_arg_ids) + tuple(
dims[first_pos[i]] if first_pos[i] < len(dims) else slice(None)
for i in rest_ids
)
return reindexed
return jax_getitem(array, index_expr)


@unbind_dims.register(jax.Array) # type: ignore
Expand Down
Loading