diff --git a/effectful/handlers/jax/_handlers.py b/effectful/handlers/jax/_handlers.py index 308cdb76e..7516adbff 100644 --- a/effectful/handlers/jax/_handlers.py +++ b/effectful/handlers/jax/_handlers.py @@ -1,9 +1,13 @@ import functools +import itertools import typing -from collections.abc import Callable, Mapping, Sequence +from collections.abc import Callable, Iterable, Mapping, Sequence from types import EllipsisType from typing import Annotated +from opt_einsum import get_symbol +from opt_einsum.parser import parse_einsum_input + try: import jax import jax.numpy as jnp @@ -11,7 +15,7 @@ raise ImportError("JAX is required to use effectful.handlers.jax") from effectful.internals.runtime import interpreter -from effectful.ops.semantics import apply, evaluate, fvsof, typeof +from effectful.ops.semantics import apply, evaluate, fvsof, fwd, typeof from effectful.ops.syntax import ( Scoped, _CustomSingleDispatchCallable, @@ -19,6 +23,7 @@ deffn, defop, syntactic_eq, + syntactic_hash, ) from effectful.ops.types import Expr, NotHandled, Operation, Term @@ -65,9 +70,13 @@ def update_sizes(sizes, op, size): def _getitem_sizeof(x: jax.Array, key: tuple[Expr[IndexElement], ...]): if is_eager_array(x): - for i, k in enumerate(key): + i = 0 + for k in key: if isinstance(k, Term) and len(k.args) == 0 and len(k.kwargs) == 0: update_sizes(sizes, k.op, x.shape[i]) + if k is not None: + i += 1 + return defdata(jax_getitem, x, key) def _apply(op, *args, **kwargs): @@ -86,6 +95,13 @@ def _partial_eval(t: Expr[jax.Array]) -> Expr[jax.Array]: if not sized_fvs: return t + # if any dimension is zero sized, the result is empty + if any(size == 0 for size in sized_fvs.values()): + ops = tuple(sized_fvs.keys()) + key = tuple(k() for k in ops) + shape = tuple(sized_fvs[k] for k in ops) + return jax_getitem(jnp.empty(shape), key) + def _is_eager(t): return not isinstance(t, Term) or t.op in sized_fvs or is_eager_array(t) @@ -135,7 +151,11 @@ def _jax_op(*args, **kwargs) -> jax.Array: and not isinstance(args[0], Term) and sized_fvs and args[1] - and all(isinstance(k, Term) and k.op in sized_fvs for k in args[1]) + and all( + (isinstance(k, Term) and k.op in sized_fvs) + or (isinstance(k, slice) and k == slice(None)) + for k in args[1] + ) ): raise NotHandled elif sized_fvs and set(sized_fvs.keys()) == fvsof(tm) - {jax_getitem, _jax_op}: @@ -175,6 +195,93 @@ def _jax_op(*args, **kwargs) -> jax.Array: return _jax_op +def _named_dims(term: Expr[jax.Array]) -> tuple[Operation, ...]: + if not (isinstance(term, Term) and term.op == jax_getitem): + return () + index = term.args[1] + assert isinstance(index, Iterable) + return tuple(i.op for i in index if isinstance(i, Term) and not i.args) + + +def _reduce_named(array, axis=None, **kwargs) -> jax.Array: + if axis is None: + return fwd() + + named_dims = _named_dims(array) + if not named_dims: + return fwd() + + bound_arr = bind_dims(array, *named_dims) + + if isinstance(axis, int): + axis = (axis,) + shifted_axis = tuple(a + len(named_dims) if a >= 0 else a for a in axis) + + reduced = fwd(bound_arr, axis=shifted_axis, **kwargs) + return unbind_dims(reduced, *named_dims) + + +def _einsum_named(subscripts, *operands, **kwargs) -> jax.Array: + # only the string-subscripts form is handled; forward the interleaved form + if not isinstance(subscripts, str): + if any(isinstance(x, Term) for x in (subscripts, *operands)): + raise ValueError("Interleaved einsum is not implemented with named tensors") + return jax.numpy.einsum(subscripts, *operands, **kwargs) + + # forward if any operand has a symbolic (Term) shape + if any(isinstance(arr.shape, Term) for arr in operands): + raise NotHandled + + named = [_named_dims(op) for op in operands] + + # normalize: expand ellipses and make the output explicit, using the + # positional shapes (shapes=True avoids materializing the operands) + shapes = [op.shape for op in operands] + in_part, out_part, _ = parse_einsum_input([subscripts, *shapes], shapes=True) + in_specs = in_part.split(",") + assert len(in_specs) == len(operands) + + # fresh symbols for named dims, avoiding every symbol already in use; + # get_symbol gives an effectively unlimited supply (spills into unicode) + used = {c for c in (in_part + out_part) if c not in ",->"} + counter = itertools.count() + + def next_symbol(): + while True: + s = get_symbol(next(counter)) + if s not in used: + used.add(s) + return s + + # assign a letter per unique named dim; shared names reuse the same letter + # so einsum aligns them as batch dims rather than contracting + letter_of, order = {}, [] + for dims in named: + for d in dims: + if d not in letter_of: + letter_of[d] = next_symbol() + order.append(d) + + # bind named dims to leading positional axes and prepend their letters + bound, new_in_specs = [], [] + for op, dims, spec in zip(operands, named, in_specs): + bound.append(bind_dims(op, *dims) if dims else op) + new_in_specs.append("".join(letter_of[d] for d in dims) + spec) + + # add every named dim to the front of the output as passthrough + out_prefix = "".join(letter_of[d] for d in order) + new_subscripts = ",".join(new_in_specs) + "->" + out_prefix + out_part + + result = jax.numpy.einsum(new_subscripts, *bound, **kwargs) + + # unbind: leading axes correspond to `order`, reindex them back to named + reindexed = jax_getitem( + result, + tuple(d() for d in order) + tuple(slice(None) for _ in range(len(out_part))), + ) + return reindexed + + @_register_jax_op def jax_getitem(x: jax.Array, key: tuple[IndexElement, ...]) -> jax.Array: """Operation for indexing an array. Unlike the standard __getitem__ method, @@ -208,6 +315,8 @@ def bind_dims[T, A, B]( >>> bind_dims(t, b, a).shape (3, 2) """ + if isinstance(value, Term) and value.op == bind_dims: + return bind_dims(value.args[0], *(names + tuple(value.args[1:]))) if jax.tree_util.treedef_is_leaf(jax.tree.structure(value)): return __dispatch(typeof(value))(value, *names) return jax.tree.map(lambda v: bind_dims(v, *names), value) @@ -277,3 +386,9 @@ def _(x: jax.Array, other) -> bool: and x.shape == other.shape and bool((jnp.asarray(x) == jnp.asarray(other)).all()) ) + + +@syntactic_hash.register(jax.Array) +def _(x: jax.Array) -> int: + # Concrete arrays aren't hashable; hash by shape, dtype, and bytes. + return hash(("jax.Array", x.shape, str(x.dtype), bytes(jax.numpy.asarray(x)))) diff --git a/effectful/handlers/jax/_terms.py b/effectful/handlers/jax/_terms.py index 812062931..53c4c094f 100644 --- a/effectful/handlers/jax/_terms.py +++ b/effectful/handlers/jax/_terms.py @@ -8,17 +8,23 @@ import effectful.handlers.jax.numpy as jnp from effectful.handlers.jax._handlers import ( IndexElement, - _partial_eval, _register_jax_op, bind_dims, jax_getitem, unbind_dims, ) from effectful.internals.tensor_utils import _desugar_tensor_index +from effectful.internals.unification import Box, nested_type from effectful.ops.syntax import defdata from effectful.ops.types import Expr, NotHandled, Operation, Term +@nested_type.register(jax.Array) +@nested_type.register(jax._src.core.Tracer) +def _(value): + return Box(jax.Array) + + class _IndexUpdateHelper: """Helper class to implement array-style .at[index].set() updates for effectful arrays.""" @@ -451,54 +457,76 @@ def _bind_dims_array(t: jax.Array, *args: Operation[[], jax.Array]) -> jax.Array >>> bind_dims(t, b, a).shape (3, 2) """ - - def _evaluate(expr): - if isinstance(expr, Term): - (args, kwargs) = jax.tree.map(_evaluate, (expr.args, expr.kwargs)) - return _partial_eval(expr) - if not jax.tree_util.treedef_is_leaf(jax.tree.structure(expr)): - return jax.tree.map(_evaluate, expr) - return expr - if not isinstance(t, Term): return t - result = _evaluate(t) - if not isinstance(result, Term) or not args: - return result - # ensure that the result is a jax_getitem with an array as the first argument - if not (result.op is jax_getitem and isinstance(result.args[0], jax.Array)): + if not (t.op is jax_getitem and isinstance(t.args[0], jax.Array)): raise NotHandled - array = result.args[0] - dims = result.args[1] + array = t.args[0] + dims = t.args[1] assert isinstance(dims, Sequence) + ndim = len(array.shape) # ensure that the order is a subset of the named dimensions order_set = set(args) if not order_set <= set(a.op for a in dims if isinstance(a, Term)): raise NotHandled - # permute the inner array so that the leading dimensions are in the order - # specified and the trailing dimensions are the remaining named dimensions - # (or slices) - reindex_dims = [ - i - for i, o in enumerate(dims) - if not isinstance(o, Term) or o.op not in order_set - ] - dim_ops = [a.op if isinstance(a, Term) else None for a in dims] - perm = ( - [dim_ops.index(o) for o in args] - + reindex_dims - + list(range(len(dims), len(array.shape))) - ) - array = jnp.transpose(array, perm) - reindexed = jax_getitem( - array, (slice(None),) * len(args) + tuple(dims[i] for i in reindex_dims) + def axis_op(ax: int) -> Operation | None: + """The named op of a bare index term at axis ``ax``, else ``None``.""" + if ax < len(dims): + d = dims[ax] + if isinstance(d, Term) and not d.args and not d.kwargs: + return d.op + return None + + # Assign an einsum id to every axis of ``array``. Axes that share a named op + # get the *same* id — a repeated op (e.g. ``arr[i(), i()]``) ties its axes + # together, which einsum reads as a diagonal. Every other axis (slices, ints, + # fancy indices, compound terms, and trailing positional axes) gets a unique + # id, so einsum simply carries it through to be reindexed below. + op_ids: dict[Operation, int] = {} + in_ids: list[int] = [] + next_id = 0 + for ax in range(ndim): + op = axis_op(ax) + if op is not None: + if op not in op_ids: + op_ids[op] = next_id + next_id += 1 + in_ids.append(op_ids[op]) + else: + in_ids.append(next_id) + next_id += 1 + + # Output order: bound args that actually appear (in the requested order, + # deduplicated by the diagonal merge), then every remaining axis in + # first-appearance order. einsum does the permutation and the diagonals; with + # all distinct ids retained in the output it performs no reduction. + present_arg_ids = [op_ids[o] for o in args if o in op_ids] + seen = set(present_arg_ids) + rest_ids: list[int] = [] + for i in in_ids: + if i not in seen: + seen.add(i) + rest_ids.append(i) + + array = jnp.einsum(array, in_ids, present_arg_ids + rest_ids) + + # Re-apply the original index for each carried axis and re-name unbound op + # axes. Trailing positional axes (first appearance beyond ``dims``) are left + # for jax_getitem to carry implicitly. + first_pos: dict[int, int] = {} + for ax, i in enumerate(in_ids): + first_pos.setdefault(i, ax) + + index_expr = (slice(None),) * len(present_arg_ids) + tuple( + dims[first_pos[i]] if first_pos[i] < len(dims) else slice(None) + for i in rest_ids ) - return reindexed + return jax_getitem(array, index_expr) @unbind_dims.register(jax.Array) # type: ignore diff --git a/effectful/handlers/jax/monoid.py b/effectful/handlers/jax/monoid.py new file mode 100644 index 000000000..96b708e89 --- /dev/null +++ b/effectful/handlers/jax/monoid.py @@ -0,0 +1,589 @@ +import functools +import logging +import typing +from typing import Protocol + +import jax +import jax.core +import opt_einsum +from opt_einsum import get_symbol + +import effectful.handlers.jax.numpy as jnp +from effectful.handlers.jax import bind_dims, jax_getitem, unbind_dims +from effectful.handlers.jax._handlers import is_eager_array +from effectful.handlers.jax.scipy.special import logsumexp +from effectful.ops.monoid import ( + CartesianProduct, + Max, + Min, + Monoid, + NormalizeIntp, + Product, + Streams, + Sum, + _is_monoid_plus, + choose_contraction, + distributes_over, +) +from effectful.ops.semantics import evaluate, fvsof, fwd, handler, typeof +from effectful.ops.syntax import ObjectInterpretation, deffn, implements +from effectful.ops.types import Interpretation, NotHandled, Operation, Term + +logger = logging.getLogger(__name__) + + +def cartesian_prod(x, y): + if x.ndim == 1: + x = x[:, None] + if y.ndim == 1: + y = y[:, None] + nx, dx = x.shape + ny, dy = y.shape + # Broadcast into (nx, ny, dx+dy), then flatten the first two axes + x_b = jnp.broadcast_to(x[:, None, :], (nx, ny, dx)) + y_b = jnp.broadcast_to(y[None, :, :], (nx, ny, dy)) + return jnp.concatenate([x_b, y_b], axis=-1).reshape(nx * ny, dx + dy) + + +LogSumExp = Monoid(name="LogSumExp", identity=jnp.asarray(float("-inf"))) + +# ``Sum`` in log space is multiplication, which distributes over ``LogSumExp``: +# a + logsumexp(b, c) = logsumexp(a + b, a + c) +distributes_over.register(Sum, LogSumExp) + + +def _jax_args(args): + """True iff ``args`` is non-empty and every arg is a concrete + :class:`jax.typing.ArrayLike` or named tensor. At least one argument must be + a jax-related type. + + """ + return ( + bool(args) + and all(is_eager_array(a) or isinstance(a, jax.typing.ArrayLike) for a in args) + and any(is_eager_array(a) or isinstance(a, jax.Array) for a in args) + ) + + +class PlusJaxUpcast(ObjectInterpretation): + @implements(Monoid.plus) + def plus(self, monoid, *args): + arg_types = [typeof(a) for a in args] + + def _is_jax(t): + return issubclass(t, jax.Array | jax.core.Tracer) + + # exists array valued and non-array-valued args + if any(_is_jax(t) for t in arg_types) and any( + not _is_jax(t) for t in arg_types + ): + return monoid.plus( + *( + a if _is_jax(t) else jnp.asarray(a) + for (a, t) in zip(args, arg_types, strict=True) + ) + ) + + return fwd() + + +class SumPlusJax(ObjectInterpretation): + @implements(Sum.plus) + def plus(self, *args): + if not _jax_args(args): + return fwd() + return functools.reduce(jnp.add, args) + + +class ProductPlusJax(ObjectInterpretation): + @implements(Product.plus) + def plus(self, *args): + if not _jax_args(args): + return fwd() + return functools.reduce(jnp.multiply, args) + + +class MinPlusJax(ObjectInterpretation): + @implements(Min.plus) + def plus(self, *args): + if not _jax_args(args): + return fwd() + return functools.reduce(jnp.minimum, args) + + +class MaxPlusJax(ObjectInterpretation): + @implements(Max.plus) + def plus(self, *args): + if not _jax_args(args): + return fwd() + return functools.reduce(jnp.maximum, args) + + +class LogSumExpPlusJax(ObjectInterpretation): + @implements(LogSumExp.plus) + def plus(self, *args): + if not _jax_args(args): + return fwd() + return functools.reduce(jnp.logaddexp, args) + + +class CartesianProductPlusJax(ObjectInterpretation): + @implements(CartesianProduct.plus) + def plus(self, *args): + # Skip identity ``[()]`` args; short-circuit on zero ``[]``. Both + # sentinels arrive as Python lists alongside jax-array factors, so + # check for them explicitly before composing. + if not any(isinstance(a, jax.Array) for a in args): + return fwd() + result = None + for a in args: + if a is CartesianProduct.zero: + return CartesianProduct.zero + if a is CartesianProduct.identity: + continue + if not isinstance(a, jax.Array): + return fwd() + result = a if result is None else cartesian_prod(result, a) + if result is None: + return CartesianProduct.identity + # CartesianProduct values are streams of rows. ``cartesian_prod`` + # already lifts 1D inputs to 2D, but a single-array call seeds + # ``result = a`` unchanged — promote so the rank invariant holds for + # every array-path return. + if result.ndim == 1: + result = result[:, None] + return result + + +class ReduceArrayGather(ObjectInterpretation): + """Split an array-valued stream into an index range and a length-1 stream: + + M.reduce(body, {k: a} ∪ S) ≡ M.reduce(body, {i: range(a.shape[0]), k: (a[i()],)} ∪ S) + + where ``i`` is fresh and ``a[i()] = unbind_dims(a, i)``. The length-1 stream + ``{k: (a[i()],)}`` is then eliminated by + :class:`~effectful.ops.monoid.EliminateSingletonStreams`, which substitutes + ``k := a[i()]`` into the body and the remaining streams. Together the two + steps perform the gather + ``M.reduce(body[k := a[i()]], {i: range(a.shape[0])} ∪ S)``. + """ + + @implements(Monoid.reduce) + def reduce(self, monoid, body, streams): + if typeof(body) is not jax.Array: + return fwd() + + if isinstance(body, Term) and body.op is delta: + return fwd() + + body_fvs = fvsof(body) + stream_keys = set(streams) + + new_streams: dict = {} + progress = False + for k, v in streams.items(): + if is_eager_array(v) and k in body_fvs and not (fvsof(v) & stream_keys): + index = Operation.define(k) + new_streams[index] = range(v.shape[0]) + new_streams[k] = (unbind_dims(v, index),) + progress = True + else: + new_streams[k] = v + + if not progress: + return fwd() + + return monoid.reduce(body, new_streams) + + +class Reductor(Protocol): + def __call__( + self, arr: jax.Array, axis: int | tuple[int, ...] | None = None + ) -> jax.Array: ... + + +ARRAY_REDUCTORS: dict[Monoid, Reductor] = {} +for monoid, func in [ + (Sum, jnp.sum), + (Product, jnp.prod), + (Min, jnp.min), + (Max, jnp.max), +]: + assert isinstance(monoid, Monoid) + assert callable(func) + ARRAY_REDUCTORS[monoid] = functools.partial(func, initial=monoid.identity) + +ARRAY_REDUCTORS[LogSumExp] = logsumexp + + +class ReduceArray(ObjectInterpretation): + """Reduce an array body over range streams.""" + + @implements(Monoid.reduce) + def reduce(self, monoid, body, streams): + reductor = ARRAY_REDUCTORS.get(monoid, None) + if reductor is None: + return fwd() + + if typeof(body) is not jax.Array: + return fwd() + + pos_dims = {} + if isinstance(body, Term): + if body.op == delta: + pos_dims = { + d.op + for d in body.args[0] + if isinstance(d, Term) and d.op in streams + } + elif _is_monoid_plus(body.op) and distributes_over( + body.op.__self__, monoid + ): + # delegate to factorization + return fwd() + + body_fvs = fvsof(body) + used = { + k + for k, v in streams.items() + if k in body_fvs and k not in pos_dims and isinstance(v, range) + } + if not used: + return fwd() + + delta_key = tuple(k() for k in streams if k in used) + arr = monoid.reduce(delta(delta_key, body), streams) + reduced_body = reductor(arr, axis=tuple(range(len(used)))) + return reduced_body + + +@Operation.define +def delta(_index: tuple[int, ...], _weight: jax.Array) -> jax.Array: + raise NotHandled + + +def _range_stop(term: Term): + assert term.op == jnp.arange + if "stop" in term.kwargs: + return term.kwargs["stop"] + if len(term.args) < 2: + return term.args[0] + return term.args[1] + + +class DeltaEmpty(ObjectInterpretation): + """delta((), weight) ≡ weight""" + + @implements(delta) + def _(self, index, weight): + if not index: + return weight + return fwd() + + +class DeltaFusion(ObjectInterpretation): + """delta(i1, delta(i2, weight)) ≡ delta(i1 ++ i2, weight)""" + + @implements(delta) + def _(self, index, weight): + if isinstance(weight, Term) and weight.op == delta: + return delta(index + weight.args[0], weight.args[1]) + return fwd() + + +class ReduceDeltaSimpleRange(ObjectInterpretation): + """Eliminate a Delta that has independent, dense index arguments. + + + reduce(M, streams ∪ {v: range(N)}, delta((v(),) ++ idx', body)) + ═══════════════════════════════════════════════════════════════════════════ + bind_dims(reduce(M, streams, delta(idx', body[v() := unbind_dims(streams[v], fv)])), fv) + """ + + @implements(Monoid.reduce) + def _(self, monoid: Monoid, body, streams: Streams): + if not (isinstance(body, Term) and body.op == delta): + return fwd() + + index, weight = body.args + assert isinstance(index, tuple) + + if not index: + return fwd() + + head_index, tail_index = index[0], index[1:] + if not (isinstance(head_index, Term) and head_index.op in streams): + return fwd() + + head_op: Operation = head_index.op + head_stream = streams[head_op] + if not ( + isinstance(head_stream, range) + and head_stream.start == 0 + and head_stream.step == 1 + ): + return fwd() + + tail_streams = {k: v for (k, v) in streams.items() if k != head_op} + + # peel the head index: substitute it into the weight (slicing direct + # uses, materializing the rest) along a fresh named dim, but bind that + # dim only *after* the surrounding reduce -- see the class docstring. + + fresh_op = Operation.define(head_op) + + def _jax_getitem(arr, index): + inner_index, outer_index = [], [] + progress = False + for i in index: + if isinstance(i, Term) and i.op == head_op: + inner_index.append( + slice(head_stream.start, head_stream.stop, head_stream.step) + ) + outer_index.append(fresh_op()) + progress = True + else: + inner_index.append(slice(None)) + outer_index.append(i) + if progress: + return jax_getitem(jax_getitem(arr, inner_index), outer_index) + return fwd(arr, index) + + slice_subst = typing.cast(Interpretation, {jax_getitem: _jax_getitem}) + sliced_weight = handler(slice_subst)(evaluate)(weight) + sliced_streams = handler(slice_subst)(evaluate)(tail_streams) + + gather_subst = typing.cast( + Interpretation, + { + head_op: deffn( + unbind_dims( + jnp.arange( + head_stream.start, head_stream.stop, head_stream.step + ), + fresh_op, + ) + ) + }, + ) + gathered_weight = handler(gather_subst)(evaluate)(sliced_weight) + gathered_streams = handler(gather_subst)(evaluate)(sliced_streams) + + inner = ( + monoid.reduce(delta(tail_index, gathered_weight), gathered_streams) + if gathered_streams + else gathered_weight + ) + return bind_dims(inner, fresh_op) + + +class ReduceDependentRangeMask(ObjectInterpretation): + """Eliminate a dependent range by masking. + + reduce(M, streams ∪ {u: range(N), v: range(u())}, body) + ═══════════════════════════════════════════════════════════════════════════ + reduce(M, streams ∪ {u: range(N), v: range(N)}, where(v() < u(), body, M.identity)) + + Currently recognises only the lower-triangular form ``v: range(u())``: + constant start of 0, dependent stop equal to a bare call of another + stream var. + + Not yet supported: + + - **Upper-triangular** (``v: range(u(), N)`` — constant stop, dependent + start): bbox becomes ``range(0, N)`` (or ``range(0, bbox_N)``), guard + becomes ``v() >= u()``. Same shape of rewrite as lower-tri; differs + only in which side of the range carries the stream-var reference and + in the predicate direction. + - **Banded** (``v: range(u() - k, u() + k + 1)`` — two-sided dependent + bounds with constant width): bbox is ``range(0, N + k)`` (or similar + bounded by both endpoints' extents), guard is + ``(v() >= u() - k) & (v() < u() + k + 1)``. Needs both-sides + affine-bound recognition. + - **Strided dependent** (``v: range(0, u(), k)`` for ``k != 1``): bbox + stays ``range(0, N)`` and guard becomes + ``(v() < u()) & (v() % k == 0)`` (or equivalent), or alternatively + embed in a smaller bbox ``range(0, ceil(N/k))`` and remap the index. + - **Affine bounds** (``v: range(a*u() + b, c*u() + d)`` for affine + coefficients): bbox computed from ``ub(c*u() + d)`` over ``u``'s + range; guard is the conjunction of the two affine constraints. This + subsumes the upper/banded/strided cases under one affine recogniser. + - **Multi-stream-var dependent** (``v: range(u() + w())`` referencing + more than one outer stream var): bbox is the affine combination over + both referents' ranges; guard threads through all dependencies. + - **Reverse-order dependent ranges**: e.g. ``v: range(u(), 0, -1)``; + needs to handle negative step and the corresponding reverse + enumeration. + """ + + @implements(Monoid.reduce) + def _(self, monoid: Monoid, body, streams: Streams): + stream_vars = set(streams.keys()) + + # streams of the form k: range(X) + simple_ranges = { + k: v + for (k, v) in streams.items() + if isinstance(v, range) and v.start == 0 and v.step == 1 + } + for u, u_stream in simple_ranges.items(): + if fvsof(u_stream) & stream_vars: + continue + + for v, v_stream in streams.items(): + if ( + isinstance(v_stream, Term) + and v_stream.op == jnp.arange + and isinstance(_range_stop(v_stream), Term) + and _range_stop(v_stream).op == u + ): + fresh_streams = { + a: (u_stream if a == v else b) for (a, b) in streams.items() + } + + # there are other commuting rules for delta that we do not + # currently include + if isinstance(body, Term) and body.op == delta: + fresh_body = delta( + body.args[0], + jnp.where(v() < u(), body.args[1], monoid.identity), # type: ignore[arg-type] + ) + else: + fresh_body = jnp.where(v() < u(), body, monoid.identity) + + return monoid.reduce(fresh_body, fresh_streams) + + return fwd() + + +# Cross-cutting delta rules not yet implemented: +# +# - **Delta-commuting** (DC-hoist): for any pure op ``f`` (no Scoped binders +# that intersect a delta's index ops), push delta outward: +# f(args..., delta(idx, body), args...) +# ≡ delta(idx, f(args..., body, args...)) +# This normalizes delta to the outermost position so the reduce rules can +# pattern-match ``isinstance(body, Term) and body.op == delta`` cleanly. +# The soundness condition is mechanical via ``op.__fvs_rule__``: refuse to +# commute when a non-delta arg's scope binds any op in the delta's idx. +# +# - **Delta-merging** (DC-merge): under a pure binary op ``f`` (or +# generalized n-ary), merge multiple deltas when their index tuples are +# subsequence-compatible: +# f(delta(idx_a, v), delta(idx_b, w)) ≡ delta(idx_max, f(v, w)) +# where ``idx_max`` is the longer of ``idx_a``, ``idx_b`` and ``idx_a`` is +# a subsequence of ``idx_b`` (or vice versa). Refuse to fire when neither +# is a subsequence of the other, since that would silently insert an +# outer-product broadcast. + + +class ContractLongestArrayStream(ObjectInterpretation): + @implements(choose_contraction) + def _(self, factors, streams): + lengths = { + k: v.shape[0] if isinstance(v, jax.Array) and v.shape else 0 + for (k, v) in streams.items() + } + longest = max(lengths.values()) + return fwd( + factors, {k: v for (k, v) in streams.items() if lengths[k] == longest} + ) + + +class ReduceSumProductContraction(ObjectInterpretation): + """Fast-path a sum-of-products contraction.""" + + @implements(Sum.reduce) + def _(self, body, streams: Streams): + if not ( + isinstance(body, Term) + and _is_monoid_plus(body.op) + and body.op.__self__ is Product + ): + return fwd() + + factors = body.args + if len(factors) != 2 or not all( + issubclass(typeof(f), jax.Array) for f in factors + ): + return fwd() + + (lhs, rhs) = factors + stream_vars = set(streams.keys()) + + # a fully factored reduce only has streams that are used by all factors + shared = fvsof(lhs) & fvsof(rhs) & stream_vars + if shared != stream_vars: + return fwd() + + if not all(isinstance(v, range) for v in streams.values()): + return fwd() + + # create leading reduction dimensions + delta_key = tuple(k() for k in streams) + pos_lhs = Sum.reduce(delta(delta_key, lhs), streams) + pos_rhs = Sum.reduce(delta(delta_key, rhs), streams) + + dims = "".join(get_symbol(i) for i in range(len(streams))) + contraction = jnp.einsum(f"{dims}...,{dims}...->...", pos_lhs, pos_rhs) + return contraction + + +@jax.jit(static_argnums=(0,)) +def einsum(subscripts: str, /, *operands: jax.Array) -> jax.Array: + """Evaluate an einsum expression using monoid reductions.""" + if not operands: + raise ValueError("einsum requires at least one operand") + + in_spec, out_spec, _ = opt_einsum.parser.parse_einsum_input( + [subscripts, *(op.shape for op in operands)], shapes=True + ) + in_specs = in_spec.split(",") + + all_letters = set(out_spec) | {c for s in in_specs for c in s} + ops = {c: Operation.define(jax.Array, name=c) for c in all_letters} + + sizes: dict[str, int] = {} + for spec, op in zip(in_specs, operands, strict=True): + for l, s in zip(spec, op.shape, strict=True): + if l in sizes and sizes[l] != s: + raise ValueError(f"Dimension {l} given sizes {s} and {sizes[l]}") + else: + sizes[l] = s + for c in out_spec: + if c not in sizes: + raise ValueError(f"einsum: output index {c!r} not present in any input") + + arrays = [Operation.define(jax.Array) for _ in operands] + factors = [ + unbind_dims(arr(), *(ops[c] for c in spec)) + for arr, spec in zip(arrays, in_specs, strict=True) + ] + body = Product.plus(*factors) + + out_tuple = tuple(ops[c]() for c in out_spec) + streams = {op: range(sizes[c]) for c, op in ops.items()} + with handler(NormalizeIntp): + norm = deffn(Sum.reduce(delta(out_tuple, body), streams), *arrays) + result = norm(*operands) + assert isinstance(result, jax.Array) + return result + + +NormalizeIntp.extend( + ReduceArray(), + ReduceSumProductContraction(), + ReduceArrayGather(), + ReduceDeltaSimpleRange(), + ReduceDependentRangeMask(), + DeltaEmpty(), + DeltaFusion(), + SumPlusJax(), + ProductPlusJax(), + MinPlusJax(), + MaxPlusJax(), + LogSumExpPlusJax(), + CartesianProductPlusJax(), + ContractLongestArrayStream(), + PlusJaxUpcast(), +) diff --git a/effectful/handlers/jax/numpy/__init__.py b/effectful/handlers/jax/numpy/__init__.py index 990830d27..f2d2affa4 100644 --- a/effectful/handlers/jax/numpy/__init__.py +++ b/effectful/handlers/jax/numpy/__init__.py @@ -1,24 +1,43 @@ +from types import NoneType from typing import TYPE_CHECKING import jax.numpy -from .._handlers import _register_jax_op, _register_jax_op_no_partial_eval +from effectful.handlers.jax._handlers import ( + _einsum_named, + _reduce_named, + _register_jax_op, + _register_jax_op_no_partial_eval, +) +from effectful.ops.semantics import handler +from effectful.ops.types import Operation -_no_overload = ["array", "asarray"] +_NO_OVERLOAD = ["array", "asarray"] +_REDUCTION = ["sum", "prod", "min", "max", "any", "all", "mean", "argmax"] for name, op in jax.numpy.__dict__.items(): - if not callable(op): + wrapped_value = None + if type(op) in (float, NoneType): + wrapped_value = op + elif name in _NO_OVERLOAD: + wrapped_value = _register_jax_op_no_partial_eval(op) + elif callable(op): + wrapped_value = _register_jax_op(op) + else: continue - jax_op = ( - _register_jax_op_no_partial_eval(op) - if name in _no_overload - else _register_jax_op(op) - ) - globals()[name] = jax_op + globals()[name] = wrapped_value -pi = jax.numpy.pi +for name in _REDUCTION: + op = globals()[name] + globals()[name] = handler({op: _reduce_named})(op) + +# einsum = effectful.handlers.jax._handlers.einsum +# tensordot = handler({tensordot: _tensordot_named})(tensordot) + + +einsum = Operation.define(_einsum_named) # Tell mypy about our wrapped functions. if TYPE_CHECKING: - from jax.numpy import * # noqa: F403 + from jax.numpy import * # type: ignore[assignment] # noqa: F403 diff --git a/effectful/handlers/jax/scipy/special.py b/effectful/handlers/jax/scipy/special.py index afe1334b8..67b99621a 100644 --- a/effectful/handlers/jax/scipy/special.py +++ b/effectful/handlers/jax/scipy/special.py @@ -2,9 +2,11 @@ import jax.scipy.special -from effectful.handlers.jax._handlers import _register_jax_op +from effectful.handlers.jax._handlers import _reduce_named, _register_jax_op +from effectful.ops.semantics import handler logsumexp = _register_jax_op(jax.scipy.special.logsumexp) +logsumexp = handler({logsumexp: _reduce_named})(logsumexp) # Tell mypy about our wrapped functions. if TYPE_CHECKING: diff --git a/effectful/internals/product_n.py b/effectful/internals/product_n.py index 4b8bd2a81..87a9c6a42 100644 --- a/effectful/internals/product_n.py +++ b/effectful/internals/product_n.py @@ -69,7 +69,7 @@ def map_structure(func, expr): else: return type(expr)(map_structure(func, tuple(expr.items()))) elif isinstance(expr, collections.abc.Sequence): - if isinstance(expr, str | bytes): + if isinstance(expr, str | bytes | range): return expr elif ( isinstance(expr, tuple) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index e425bba6c..2eadaeab6 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -556,6 +556,17 @@ def _unify_generic(typ, subtyp, subs: Substitutions) -> Substitutions: and issubclass(subtyp, typing.get_origin(typ)) ): return subs # implicit expansion to subtyp[Any] + elif isinstance(typ, GenericAlias): + # Special case for treating arrays as iterables of arrays + try: + import jax + + if typing.get_origin(typ) is collections.abc.Iterable and issubclass( + subtyp, jax.Array + ): + return unify(typing.get_args(typ)[0], jax.Array, subs) + except ImportError: + pass raise TypeError(f"Cannot unify generic type {typ} with {subtyp} given {subs}.") diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py new file mode 100644 index 000000000..067f29e55 --- /dev/null +++ b/effectful/ops/monoid.py @@ -0,0 +1,963 @@ +import collections.abc +import functools +import itertools +import operator +import typing +from collections import Counter, UserDict, defaultdict +from collections.abc import Callable, Generator, Iterable, Mapping, Sequence +from dataclasses import dataclass +from graphlib import TopologicalSorter +from typing import Annotated, Any + +from effectful.ops.semantics import coproduct, evaluate, fvsof, fwd, handler, typeof +from effectful.ops.syntax import ( + ObjectInterpretation, + Scoped, + defdata, + deffn, + implements, + syntactic_eq, + syntactic_hash, +) +from effectful.ops.types import Expr, Interpretation, NotHandled, Operation, Term + +type Stream[T] = Iterable[T] + +type Streams = Mapping[Operation[[], Any], Stream[Any]] + +type Body[T] = ( + Iterable[T] + | Callable[..., Body[T]] + | T + | Mapping[Any, Body[T]] + | Interpretation[T, Body[T]] +) + + +def outer_stream(streams: Streams) -> Iterable[tuple[Operation, Stream, Streams]]: + """Returns the streams that can be ordered outermost in the loop nest as + well as the remaining streams in the nest. + + """ + stream_vars = set(streams.keys()) + pred = {k: fvsof(v) & stream_vars for k, v in streams.items()} + topo = TopologicalSorter(pred) + topo.prepare() + return ( + (op, streams[op], {k: v for (k, v) in streams.items() if k != op}) + for op in topo.get_ready() + ) + + +def inner_stream( + streams: dict[Operation, Expr], +) -> Iterable[tuple[dict[Operation, Expr], Operation, Expr]]: + """Returns the streams that can be ordered innermost in the loop nest as + well as the remaining streams in the nest. + + """ + stream_vars = set(streams.keys()) + + no_dependents = set() + succ = defaultdict(set) + for k, v in streams.items(): + preds = fvsof(v) & stream_vars + if preds: + for pred in preds: + succ[pred].add(k) + else: + no_dependents.add(k) + + topo = TopologicalSorter(succ) + topo.prepare() + return ( + ({k: v for (k, v) in streams.items() if k != op}, op, streams[op]) + for op in set(topo.get_ready()) | no_dependents + ) + + +def inner_streams_first(streams: dict[Operation, Expr]) -> Iterable[Operation]: + """Iterable over streams where dependent streams precede their dependencies.""" + stream_vars = set(streams.keys()) + + no_dependents = set() + succ = defaultdict(set) + for k, v in streams.items(): + preds = fvsof(v) & stream_vars + if preds: + for pred in preds: + succ[pred].add(k) + else: + no_dependents.add(k) + + topo = TopologicalSorter(succ) + return topo.static_order() + + +class Monoid[W]: + """A monoid with ``plus`` and ``reduce`` :class:`Operation` s.""" + + _name: str + identity: W + + def __init__(self, identity: W, name: str): + self._name = name + self.identity = identity + + def __repr__(self): + return f"Monoid({self._name!r})" + + def __eq__(self, other): + return id(self) == id(other) + + def __hash__(self): + return hash(id(self)) + + @Operation.define + def plus(self, *args: W) -> W: + """Monoid addition. Handlers supply per-monoid and broadcasting + behavior; the default rule only handles identity and zero cases (for + monoids that have a zero). + + """ + if hasattr(self, "zero") and any(a is self.zero for a in args): + return self.zero + + nonident_args = [a for a in args if a is not self.identity] + if len(nonident_args) != len(args): + return self.plus(*nonident_args) + + return defdata(self.plus, *nonident_args) # type: ignore[return-value] + + @Operation.define + def reduce[A, B, U: Body]( + self, + body: Annotated[U, Scoped[A | B]], + streams: Annotated[Streams, Scoped[A]], + ) -> Annotated[U, Scoped[B]]: + """Reduce ``body`` over ``streams``. Handlers supply per-monoid and + broadcasting behavior; the default rule only handles the empty-stream + case. + """ + raise NotHandled + + @Operation.define + def weighted[T]( + self, stream: Stream[T], weight: Callable[[T], W] | Operation[[T], W] + ) -> Stream[T]: + """A stream paired with a per-element weight. ``var`` is an + :class:`Operation` standing for "an element of ``stream``"; ``weight`` + is an expression that uses ``var`` and evaluates to the weight of that + element. + + """ + raise NotHandled + + +class MonoidWithZero[T](Monoid[T]): + zero: T + + def __init__(self, name: str, identity: T, zero: T): + super().__init__(name=name, identity=identity) + self.zero = zero + + +Min = Monoid(name="Min", identity=float("inf")) +Max = Monoid(name="Max", identity=-float("inf")) +ArgMin = Monoid(name="ArgMin", identity=(Min.identity, None)) +ArgMax = Monoid(name="ArgMax", identity=(Max.identity, None)) +Sum = Monoid(name="Sum", identity=0) +Product = MonoidWithZero(name="Product", identity=1, zero=0) +# CartesianProduct values are "two-level indexable" (rows × positions). The +# identity ``[()]`` is one row of zero positions (composing with it preserves +# shape); the zero ``[]`` is no rows (absorbs under product). +CartesianProduct = MonoidWithZero(name="CartesianProduct", identity=[()], zero=[]) + + +@dataclass +class _ExtensiblePredicate[T]: + elems: set[T] + + def register(self, t: T) -> None: + self.elems.add(t) + + def __call__(self, t: T) -> bool: + return t in self.elems + + +is_commutative = _ExtensiblePredicate({Max, Min, Sum, Product}) +is_idempotent = _ExtensiblePredicate({Max, Min}) + + +@dataclass +class _ExtensibleBinaryRelation[S, T]: + tuples: set[tuple[S, T]] + + def register(self, s: S, t: T) -> None: + self.tuples.add((s, t)) + + def __call__(self, s: S, t: T) -> bool: + return (s, t) in self.tuples + + +distributes_over = _ExtensibleBinaryRelation( + {(Max, Min), (Min, Max), (Sum, Min), (Sum, Max), (Product, Sum)} +) + + +def _is_monoid_plus(op: Operation) -> bool: + """True if ``op`` is the ``plus`` operation of some :class:`Monoid`.""" + owner = getattr(op, "__self__", None) + return isinstance(owner, Monoid) and op is owner.plus + + +def _is_monoid_reduce(op: Operation) -> bool: + """True if ``op`` is the ``reduce`` operation of some :class:`Monoid`.""" + owner = getattr(op, "__self__", None) + return isinstance(owner, Monoid) and op is owner.reduce + + +def _is_monoid_weighted(op: Operation) -> bool: + """True if ``op`` is the ``weighted`` operation of some :class:`Monoid`.""" + owner = getattr(op, "__self__", None) + return isinstance(owner, Monoid) and op is owner.weighted + + +class PlusEmpty(ObjectInterpretation): + """plus() = 0""" + + @implements(Monoid.plus) + def plus(self, monoid, *args): + if not args: + return monoid.identity + return fwd() + + +class PlusSingle(ObjectInterpretation): + """plus(x) = x""" + + @implements(Monoid.plus) + def plus(self, _, *args): + if len(args) == 1: + return args[0] + return fwd() + + +class PlusAssoc(ObjectInterpretation): + """x + (y + z) = (x + y) + z = x + y + z""" + + @implements(Monoid.plus) + def plus(self, monoid, *args): + def is_nested_plus(x): + return isinstance(x, Term) and x.op is monoid.plus + + if any(is_nested_plus(x) for x in args): + flat_args = itertools.chain.from_iterable( + t.args if is_nested_plus(t) else (t,) for t in args + ) + assert len(args) > 0 + return monoid.plus(*flat_args) + return fwd() + + +class PlusDistr(ObjectInterpretation): + """x + (y * z) = x * y + x * z""" + + @implements(Monoid.plus) + def plus(self, monoid: Monoid, *args): + if any( + isinstance(x, Term) + and _is_monoid_plus(x.op) + and distributes_over(monoid, x.op.__self__) + for x in args + ): + non_terms = [] + + # group terms by their monoid + by_monoid: dict[Monoid, list[Term]] = defaultdict(list) + for t in args: + if isinstance(t, Term) and _is_monoid_plus(t.op): + by_monoid[t.op.__self__].append(t) + else: + non_terms.append(t) + + # distribute over each group + progress = False + final_sum = [] + for m, terms in by_monoid.items(): + if ( + len(terms) > 1 + and distributes_over(monoid, m) + and not distributes_over(m, monoid) + ): + progress = True + term_args = (t.args for t in terms) + dist_terms = ( + monoid.plus(*args) for args in itertools.product(*term_args) + ) + final_sum.append(m.plus(*dist_terms)) + else: + final_sum += terms + if progress: + return monoid.plus(*non_terms, *final_sum) + return fwd() + + +class PlusConsecutiveDups(ObjectInterpretation): + """x ⊕ x ⊕ y = x ⊕ y""" + + @implements(Monoid.plus) + def plus(self, monoid, *args): + if not is_idempotent(monoid): + return fwd() + + dedup_args = ( + args[i] + for i in range(len(args)) + if i == 0 or not syntactic_eq(args[i - 1], args[i]) + ) + return fwd(monoid, *dedup_args) + + +class PlusDups(ObjectInterpretation): + """x ⊕ y ⊕ x = x ⊕ y""" + + @dataclass + class _HashableTerm: + term: Term + + def __eq__(self, other): + return syntactic_eq(self, other) + + def __hash__(self): + return syntactic_hash(self) + + @implements(Monoid.plus) + def plus(self, monoid, *args): + if not (is_idempotent(monoid) and is_commutative(monoid)): + return fwd() + + # elim dups + args_count = Counter(self._HashableTerm(t) for t in args) + if len(args_count) < len(args): + dedup_args = [] + for t in args: + ht = self._HashableTerm(t) + if ht in args_count: + dedup_args.append(t) + del args_count[ht] + return fwd(monoid, *dedup_args) + return fwd() + + +class ReducePartial(ObjectInterpretation): + @implements(Monoid.reduce) + def _(self, monoid, body, streams): + if not streams: + return monoid.identity + + for stream_key, stream_body, streams_tail in outer_stream(streams): + if isinstance(stream_body, Term): + continue + stream_values_iter = iter(stream_body) + + # if we iterate and get a term instead of a real iterator, skip + if isinstance(stream_values_iter, Term): + continue + + new_reduces = [] + for stream_val in stream_values_iter: + with handler({stream_key: deffn(stream_val)}): + eval_args = evaluate((body, streams_tail)) + assert isinstance(eval_args, tuple) + new_reduces.append( + monoid.reduce(*eval_args) if streams_tail else eval_args[0] + ) + return monoid.plus(*new_reduces) + return fwd() + + +class ReduceFusion(ObjectInterpretation): + """Implements the identity + reduce(R, S1, reduce(R, S2, body)) = reduce(R, S1 ∪ S2, body) + """ + + @implements(Monoid.reduce) + def reduce(self, monoid, body, streams): + if isinstance(body, Term) and body.op is monoid.reduce: + return monoid.reduce(body.args[0], streams | body.args[1]) + return fwd() + + +class ReduceSplit(ObjectInterpretation): + """Implements the identity + reduce(R, S, b1 + ... + bn) = reduce(R, S, b1) + ... + reduce(R, S, bn) + """ + + @implements(Monoid.reduce) + def reduce(self, monoid, body, streams): + if not is_commutative(monoid): + return fwd() + if isinstance(body, Term) and body.op is monoid.plus: + return monoid.plus(*(monoid.reduce(x, streams) for x in body.args)) + return fwd() + + +@Operation.define +def choose_contraction(factors: Sequence[Any], streams: Streams) -> Operation: + """Used by `ReduceFactorization` to choose a contraction when there is + ambiguity. Takes the factors and streams that are eligible for contraction + (innermost and non-universal). + + The default behavior is to return the first support-minimal stream in the + streams dictionary. + + """ + assert len(streams) > 0 + + factors = [(a, fvsof(a)) for a in factors] + support: dict = { + k: frozenset(i for i, (_, fvs) in enumerate(factors) if k in fvs) + for k in streams + } + for v, f_v in support.items(): + if any(u_sup < f_v for u, u_sup in support.items() if u is not v): + continue + return v + assert False, "expected at least one subset-minimal stream" + + +class ReduceFactorization(ObjectInterpretation): + """reduce(⊗(F_v ∪ F_rest), {v} ∪ S) = reduce(⊗F_rest ⊗ reduce(⊗F_v, {v}), S) + + where F_v = factors mentioning v, F_rest = the others. Fires only when + v has no dependents among the remaining streams (so it can be innermost) + and F_rest is nonempty (universal variables stay in the outer core). + """ + + @implements(Monoid.reduce) + def reduce(self, monoid, body, streams): + if not ( + is_commutative(monoid) + and isinstance(body, Term) + and _is_monoid_plus(body.op) + and distributes_over(body.op.__self__, monoid) + ): + return fwd() + + inner = body.op.__self__ + stream_keys = set(streams) + factors = [(a, fvsof(a)) for a in body.args] + + # candidates: innermost-eligible (no remaining stream depends on v), + # non-universal (some factor doesn't mention v) + eligible = {} + for k, v in streams.items(): + if any(k in fvsof(vv) for kk, vv in streams.items() if k is not kk): + continue + if len({i for i, (_, fvs) in enumerate(factors) if k in fvs}) == len( + factors + ): + continue # v is universal: leave it in the outer core + eligible[k] = v + + if not eligible: + return fwd() + if len(eligible) == 1: + inner_stream = next(iter(eligible)) + else: + inner_stream = choose_contraction(body.args, eligible) + + inner_factor_ids = frozenset( + i for i, (_, fvs) in enumerate(factors) if inner_stream in fvs + ) + + inner_factors = [factors[i][0] for i in sorted(inner_factor_ids)] + inner_stream_keys = {inner_stream} + inner_deps = set().union( + *(factors[i][1] for i in inner_factor_ids), + fvsof(streams[inner_stream]) & stream_keys, + ) + + outer_factors = [ + a for i, (a, _) in enumerate(factors) if i not in inner_factor_ids + ] + outer_stream_keys = stream_keys - inner_stream_keys + outer_factor_deps = set().union( + *(vars for i, (_, vars) in enumerate(factors) if i not in inner_factor_ids) + ) + + # find all streams that are used in the inner factors/streams and are + # not used by the outer factors/streams + # this has to be done iteratively, because moving a stream inward + # reduces the outer dependency set + # ensures that no future factorization application creates a reduce that + # fuses with with the inner reduce + for s in inner_streams_first(streams): + outer_stream_deps = ( + set().union(*(fvsof(streams[k]) for k in outer_stream_keys)) + & stream_keys + ) + outer_deps = outer_factor_deps | outer_stream_deps + if s in inner_deps and s not in outer_deps: + inner_stream_keys |= {s} + inner_deps |= stream_keys & fvsof(streams[s]) + outer_stream_keys -= {s} + + inner_streams = {k: v for (k, v) in streams.items() if k in inner_stream_keys} + inner_red = monoid.reduce(inner.plus(*inner_factors), inner_streams) + + rest_streams = {k: s for k, s in streams.items() if k in outer_stream_keys} + new_body = inner.plus(*outer_factors, inner_red) + return monoid.reduce(new_body, rest_streams) if rest_streams else new_body + + +class ReduceDistributeCartesianProduct(ObjectInterpretation): + """Eliminates a reduce over a cartesian product. + ∑_x₁ ∑_x₂ ... ∑_xₙ ∏_i f(xᵢ) = ∏_i ∑_xᵢ f(xᵢ) + This transform is also called inversion in the lifting + literature (e.g. [1]). + + More specifically, this transform implements the identity + reduce(⨁, reduce(⨂, body2, {vv: v()}), {v: reduce(×, body1, S1)} ∪ S2) + = reduce(⨁, reduce(⨂, reduce(⨁, body2, {vv: body1}), S1), S2) + where × is the cartesian product and ⨂ distributes over ⨁. + + Note: This could be generalized to grouped inversion [2]. + + [1] Braz, Rd, Eyal Amir, and Dan Roth. "Lifted first-order + probabilistic inference." IJCAI. 2005. + [2] Taghipour, Nima, et al. "Completeness results for lifted + variable elimination." AISTATS. 2013. + """ + + @implements(Monoid.reduce) + def reduce(self, sum_monoid: Monoid, sum_body, sum_streams): + if not (is_commutative(sum_monoid) and isinstance(sum_body, Term)): + return fwd() + + # body is a product or multiplication of products + if _is_monoid_plus(sum_body.op) and distributes_over( + sum_body.op.__self__, sum_monoid + ): + prod_reduces = sum_body.args + else: + prod_reduces = [sum_body] + + products: list[tuple[Monoid, Callable, Operation, Term]] = [] + for prod_reduce in prod_reduces: + if not ( + isinstance(prod_reduce, Term) and _is_monoid_reduce(prod_reduce.op) + ): + return fwd() + prod_monoid: Monoid = prod_reduce.op.__self__ + prod_body = prod_reduce.args[0] + prod_streams = typing.cast(Mapping, prod_reduce.args[1]) + if not ( + distributes_over(prod_monoid, sum_monoid) + and (len(products) == 0 or products[-1][0] == prod_monoid) + ): + return fwd() + + if len(prod_streams) > 1 or len(prod_streams) == 0: + return fwd() + (prod_op, prod_stream) = next(iter(prod_streams.items())) + products.append( + (prod_monoid, deffn(prod_body, prod_op), prod_op, prod_stream) + ) + + assert len(products) > 0 + + for outer_sum_streams, cprod_op, cprod_term in inner_stream(sum_streams): + if not ( + isinstance(cprod_term, Term) + and cprod_term.op is CartesianProduct.reduce + ): + continue + (cprod_body, cprod_streams) = cprod_term.args + + if not all( + prod_stream.op == cprod_op for (_, _, _, prod_stream) in products + ): + continue + + prod_op = Operation.define(products[0][2]) + prod_monoid = products[0][0] + inner_sum = sum_monoid.reduce( + prod_monoid.plus( + *(prod_body(prod_op()) for (_, prod_body, _, _) in products) + ), + {prod_op: cprod_body}, + ) + prod = prod_monoid.reduce(inner_sum, cprod_streams) + outer_sum = ( + sum_monoid.reduce(prod, outer_sum_streams) + if outer_sum_streams + else prod + ) + return outer_sum + + return fwd() + + +class ReduceWeightedStream(ObjectInterpretation): + """reduce(M, body, {x: WM.weighted(s, v, w), ...}) = reduce(M, WM.plus(w[v:=x()], body), {x: s, ...}) + + requires distributes_over(WM, M). + + The substitution ``v -> x`` is done by beta-reducing ``deffn(w, v)`` on + ``x()`` — symbolic, no Python dispatch on the weight expression. + """ + + @implements(Monoid.reduce) + def reduce(self, monoid, body, streams): + for k, v in streams.items(): + if isinstance(v, Term) and _is_monoid_weighted(v.op): + v_stream, v_weight = v.args + v_monoid = v.op.__self__ + if not distributes_over(v_monoid, monoid): + continue + w_at_k = v_weight(k()) + weighted_body = v_monoid.plus(w_at_k, body) + new_streams = {**streams, k: v_stream} + return monoid.reduce(weighted_body, new_streams) + return fwd() + + +class ReduceCartesianWeightedStream(ObjectInterpretation): + """``CartesianProduct.reduce`` over a :func:`weighted` body whose + ``weight`` is independent of the plate (product-index) streams:: + + CartesianProduct.reduce(M.weighted(s, w), plates) + = M.weighted( + CartesianProduct.reduce(s, plates), + deffn(M.reduce(w, {e: row()}), row), + ) + + Reuses ``body``'s element binder ``e`` (already typed by construction); + introduces a fresh ``row`` binder typed as ``Iterable[elem_type]``. + + Only fires when ``w`` is independent of the plate vars. + """ + + @Operation.define + @staticmethod + def _iterable_elem[T](iter: Iterable[T]) -> T: + raise NotHandled + + @implements(Monoid.reduce) + def reduce(self, monoid, body, streams): + if monoid is not CartesianProduct: + return fwd() + if not (isinstance(body, Term) and _is_monoid_weighted(body.op)): + return fwd() + + s, w = body.args + if not isinstance(s, Term) and len(s) == 0: + return CartesianProduct.reduce([], streams) + + if set(streams.keys()) & fvsof(w): + return fwd() + + elem_typ = typeof(self._iterable_elem(s)) + elem_op = Operation.define(elem_typ, name="elem") + row_op = Operation.define(Iterable[elem_typ], name="row") + + weight_monoid = body.op.__self__ + joint_weight = deffn( + weight_monoid.reduce(w(elem_op()), {elem_op: row_op()}), row_op + ) + joint_stream = CartesianProduct.reduce(s, streams) + + return weight_monoid.weighted(joint_stream, joint_weight) + + +class MonoidOverCallable(ObjectInterpretation): + """``monoid.reduce(f, streams) = lambda *a: monoid.reduce(f(*a), streams)``.""" + + @implements(Monoid.reduce) + def reduce(self, monoid, body, streams): + if isinstance(body, Term) or not isinstance(body, Callable): + return fwd() + return lambda *a, **k: monoid.reduce(body(*a, **k), streams) + + @implements(Monoid.plus) + def plus(self, monoid, *args): + if not args or any( + isinstance(arg, Term) or not isinstance(arg, Callable) for arg in args + ): + return fwd() + return lambda *a, **k: monoid.plus(*(arg(*a, **k) for arg in args)) + + +class MonoidOverMapping(ObjectInterpretation): + """``monoid.reduce({k: v_k}, streams) = {k: monoid.reduce(v_k, streams)}``.""" + + @implements(Monoid.reduce) + def reduce(self, monoid, body, streams): + if isinstance(body, Term) or not isinstance(body, Mapping): + return fwd() + return {k: monoid.reduce(v, streams) for (k, v) in body.items()} + + @implements(Monoid.plus) + def plus(self, monoid, *args): + if not args or not isinstance(args[0], Mapping): + return fwd() + + if isinstance(args[0], Interpretation): + keys = args[0].keys() + for b in args[1:]: + if not isinstance(b, Interpretation): + raise TypeError(f"Expected interpretation but got {b}") + if not keys == b.keys(): + raise ValueError( + f"Expected interpretation of {keys} but got {b.keys()}" + ) + return {k: monoid.plus(*(handler(b)(b[k]) for b in args)) for k in keys} + + for b in args[1:]: + if not isinstance(b, Mapping): + raise TypeError(f"Expected mapping but got {b}") + all_values = collections.defaultdict(list) + for d in args: + for k, v in d.items(): + all_values[k].append(v) + return {k: monoid.plus(*vs) for (k, vs) in all_values.items()} + + +def _scalar_args(args): + """True iff ``args`` is non-empty and every arg is a concrete int/float.""" + return ( + bool(args) + and not any(isinstance(x, Term) for x in args) + and all(isinstance(x, int | float) for x in args) + ) + + +class SumPlus(ObjectInterpretation): + """Scalar implementation of :data:`Sum`.""" + + @implements(Sum.plus) + def plus(self, *args): + if not _scalar_args(args): + return fwd() + return sum(args) + + +class MinPlus(ObjectInterpretation): + """Scalar implementation of :data:`Min`.""" + + @implements(Min.plus) + def plus(self, *args): + if not _scalar_args(args): + return fwd() + return min(args) + + +class MaxPlus(ObjectInterpretation): + """Scalar implementation of :data:`Max`.""" + + @implements(Max.plus) + def plus(self, *args): + if not _scalar_args(args): + return fwd() + return max(args) + + +class ProductPlus(ObjectInterpretation): + """Scalar implementation of :data:`Product`.""" + + @implements(Product.plus) + def plus(self, *args): + if not _scalar_args(args): + return fwd() + return functools.reduce(operator.mul, args) + + +class ArgMinPlus(ObjectInterpretation): + """Scalar score implementation of :data:`ArgMin`.""" + + @implements(ArgMin.plus) + def plus(self, *args): + if not args or not all(isinstance(a, tuple) for a in args): + return fwd() + if any(isinstance(a[0], Term) for a in args): + return fwd() + if not all(isinstance(a[0], int | float) for a in args): + return fwd() + return min(args, key=lambda a: a[0]) + + +class ArgMaxPlus(ObjectInterpretation): + """Scalar score implementation of :data:`ArgMax`.""" + + @implements(ArgMax.plus) + def plus(self, *args): + if not args or not all(isinstance(a, tuple) for a in args): + return fwd() + if any(isinstance(a[0], Term) for a in args): + return fwd() + if not all(isinstance(a[0], int | float) for a in args): + return fwd() + return max(args, key=lambda a: a[0]) + + +class CartesianProductPlus(ObjectInterpretation): + """Pure-Python implementation of :data:`CartesianProduct`.""" + + @implements(CartesianProduct.plus) + def plus(self, *args): + if not args: + return fwd() + if any(isinstance(x, Term) for x in args): + return fwd() + if not all(isinstance(x, Iterable) for x in args): + return fwd() + + def to_tuple(x): + return x if isinstance(x, tuple) else (x,) + + return [ + sum((to_tuple(v) for v in vals), ()) for vals in itertools.product(*args) + ] + + +is_scalar = _ExtensiblePredicate({Min, Max, Sum, Product}) + + +class MonoidOverSequence(ObjectInterpretation): + @implements(Monoid.plus) + def plus(self, monoid, *args): + if ( + not is_scalar(monoid) + or not args + or not isinstance(args[0], tuple | list | Generator) + ): + return fwd() + zipped = zip(*args, strict=True) + result = (monoid.plus(*vs) for vs in zipped) + if isinstance(args[0], tuple | list): + return type(args[0])(result) + return result + + @implements(Monoid.reduce) + def reduce(self, monoid, body, streams): + if not is_scalar(monoid) or not isinstance(body, tuple | list | Generator): + return fwd() + result = (monoid.reduce(x, streams) for x in body) + if isinstance(body, tuple | list): + return type(body)(result) + return result + + +@Operation.define +def as_float(x: int) -> float: + if isinstance(x, Term): + raise NotHandled + return float(x) + + +class PlusCastFloat(ObjectInterpretation): + @implements(Monoid.plus) + def plus(self, monoid, *args): + typs = [typeof(a) for a in args] + if any(issubclass(t, float) for t in typs) and any( + issubclass(t, int) for t in typs + ): + args = [ + as_float(a) if issubclass(t, int) else a + for (a, t) in zip(args, typs, strict=True) + ] + return monoid.plus(*args) + return fwd() + + +class EliminateSingletonStreams(ObjectInterpretation): + """Eliminate a length-1 stream by substituting its sole element. + + reduce(M, body, {k: (v,)} ∪ S) = reduce(M, body[k := v], S[k := v]) + + Fires only when the sole element ``v`` is a :class:`Term`, i.e. a *symbolic* + singleton. This is exactly the form ``ReduceArrayGather`` produces (a gather + ``(a[i()],)``) and, more generally, every dependent singleton that + :class:`ReducePartial` cannot peel -- a non-outermost stream whose element + references another stream var. Concrete enumerated streams (``[0]``, + ``range(1)``) and monoid sentinels (``CartesianProduct.identity == [()]``) + have non-``Term`` elements and are left to ``ReducePartial`` / the + per-monoid rules. + + Unlike ``ReducePartial``, this peels the stream wherever it sits in the loop + nest and substitutes symbolically rather than unrolling, leaving a + vectorized index range (e.g. the gather's range) intact instead of + materializing it. + """ + + @implements(Monoid.reduce) + def reduce(self, monoid, body, streams): + # Eliminate *all* symbolic length-1 streams in one pass via a + # simultaneous substitution. Doing them together (rather than one per + # invocation) keeps an interleaving reduction rule -- e.g. + # ``ReduceArray`` consuming a now-live index range -- from firing + # between eliminations, so sibling index ranges stay together and fuse + # into a single reduction. + singletons = { + k: vs[0] + for k, vs in streams.items() + if not isinstance(vs, Term) + and isinstance(vs, collections.abc.Sequence) + and len(vs) == 1 + and isinstance(vs[0], Term) + } + if not singletons: + return fwd() + + subs = {k: deffn(v) for k, v in singletons.items()} + new_body = handler(subs)(evaluate)(body) + new_streams = { + kk: handler(subs)(evaluate)(vv) + for kk, vv in streams.items() + if kk not in singletons + } + # reduce over no streams is a single (empty) assignment, i.e. the body + # itself -- not the monoid identity. + return monoid.reduce(new_body, new_streams) if new_streams else new_body + + +class _ExtensibleInterpretation(UserDict, Interpretation): + def extend(self, *intps: Interpretation) -> typing.Self: + for intp in intps: + self.data = coproduct(self.data, intp) # type: ignore[assignment] + return self + + +NormalizeIntp = _ExtensibleInterpretation().extend( + ReducePartial(), + EliminateSingletonStreams(), + MonoidOverSequence(), + MonoidOverMapping(), + MonoidOverCallable(), + ReduceFusion(), + ReduceSplit(), + ReduceFactorization(), + ReduceDistributeCartesianProduct(), + ReduceWeightedStream(), + ReduceCartesianWeightedStream(), + PlusEmpty(), + PlusSingle(), + PlusAssoc(), + PlusDistr(), + PlusConsecutiveDups(), + PlusDups(), + SumPlus(), + MinPlus(), + MaxPlus(), + ProductPlus(), + ArgMinPlus(), + ArgMaxPlus(), + CartesianProductPlus(), + PlusCastFloat(), +) +"""``NormalizeIntp``applies pure-Term rewrites (associativity, distributivity, +identity elimination, fusion, factorization, etc.). + +""" diff --git a/effectful/ops/semantics.py b/effectful/ops/semantics.py index f7678fd24..acfdf9fdb 100644 --- a/effectful/ops/semantics.py +++ b/effectful/ops/semantics.py @@ -209,6 +209,7 @@ def evaluate[T]( @evaluate.register(object) @evaluate.register(str) @evaluate.register(bytes) +@evaluate.register(range) def _evaluate_object[T](expr: T, **kwargs) -> T: if dataclasses.is_dataclass(expr) and not isinstance(expr, type): return typing.cast( @@ -286,6 +287,13 @@ def _evaluate_list_view(expr, **kwargs): def _simple_type(tp: type) -> type: """Convert a type object into a type that can be dispatched on.""" + + def _resolve_aliases(tp: type) -> type: + tp = typing.get_origin(tp) or tp + if isinstance(tp, typing.TypeAliasType): + return _resolve_aliases(tp.__value__) + return tp + if isinstance(tp, typing.TypeVar): tp = ( tp.__bound__ @@ -303,7 +311,7 @@ def _simple_type(tp: type) -> type: tp = functools.reduce(operator.or_, (type(arg) for arg in args)) if isinstance(tp, types.UnionType): raise TypeError(f"Union types are not supported: {tp}") - return typing.get_origin(tp) or tp + return _resolve_aliases(tp) def typeof[T](term: Expr[T]) -> type[T]: diff --git a/effectful/ops/syntax.py b/effectful/ops/syntax.py index 764016752..2e198bf36 100644 --- a/effectful/ops/syntax.py +++ b/effectful/ops/syntax.py @@ -849,9 +849,97 @@ def _(x: collections.abc.Sequence, other) -> bool: @syntactic_eq.register(object) @syntactic_eq.register(str | bytes) def _(x: object, other) -> bool: + if isinstance(other, Term): # Terms often override __eq__ + return False + return x == other + + +@syntactic_eq.register(int | float) +def _(x: int | float, other) -> bool: + # Terms often override __eq__ + if isinstance(other, Term) or not isinstance(other, int | float): + return False return x == other +@_CustomSingleDispatchCallable +def syntactic_hash(__dispatch: Callable[[type], Callable[[Any], int]], x) -> int: + """Structural hash compatible with :func:`syntactic_eq`. + + Guarantees that ``syntactic_eq(x, y)`` implies + ``syntactic_hash(x) == syntactic_hash(y)``. + + :param x: A term. + :returns: An integer hash. + """ + if dataclasses.is_dataclass(x) and not isinstance(x, type): + return hash( + ( + "dataclass", + type(x), + syntactic_hash( + { + field.name: getattr(x, field.name) + for field in dataclasses.fields(x) + } + ), + ) + ) + else: + return __dispatch(type(x))(x) + + +@syntactic_hash.register +def _(x: Term) -> int: + return hash( + ( + "term", + x.op, + len(x.args), + tuple(syntactic_hash(a) for a in x.args), + # sort kwargs so order doesn't affect the hash + tuple((k, syntactic_hash(x.kwargs[k])) for k in sorted(x.kwargs)), + ) + ) + + +@syntactic_hash.register +def _(x: collections.abc.Mapping) -> int: + # XOR over (key_hash, value_hash) pairs — order-independent, + # matching the set-based comparison in syntactic_eq's Mapping branch. + acc = 0 + for k in x: + acc ^= hash((hash(k), syntactic_hash(x[k]))) + return hash(("mapping", acc)) + + +@syntactic_hash.register +def _(x: collections.abc.Sequence) -> int: + if ( + isinstance(x, tuple) + and hasattr(x, "_fields") + and all(hasattr(x, f) for f in x._fields) + ): + return hash( + ( + "namedtuple", + type(x), + tuple(syntactic_hash(getattr(x, f)) for f in x._fields), + ) + ) + else: + # Use the abstract Sequence tag (not type(x)) because syntactic_eq + # treats any two Sequences of equal length and elementwise-equal + # contents as equal — e.g. [1,2] and (1,2) compare equal. + return hash(("sequence", len(x), tuple(syntactic_hash(a) for a in x))) + + +@syntactic_hash.register(object) +@syntactic_hash.register(str | bytes) +def _(x: object) -> int: + return hash(x) + + class ObjectInterpretation[T, V](collections.abc.Mapping): """A helper superclass for defining an ``Interpretation`` of many :class:`~effectful.ops.types.Operation` instances with shared state or behavior. diff --git a/effectful/ops/types.py b/effectful/ops/types.py index 40c1f4af5..e855469bc 100644 --- a/effectful/ops/types.py +++ b/effectful/ops/types.py @@ -311,6 +311,15 @@ def func(*args, **kwargs): return typing.cast(Operation[P, T], cls.define(func, **kwargs)) + @define.register(types.MethodType) + @classmethod + def _define_methodtype[**P, T]( + cls, t: Callable[P, T], *, name: str | None = None + ) -> "Operation[P, T]": + op = cls._define_callable(t, name=name) + op.__self__ = t.__self__ # type: ignore[attr-defined] + return typing.cast("Operation[P, T]", op) + @define.register(staticmethod) @classmethod def _define_staticmethod[**P, T](cls, t: "staticmethod[P, T]", **kwargs): @@ -488,7 +497,10 @@ def _instance_op(instance, *args, **kwargs): else: return default_result - instance_op = self.define(types.MethodType(_instance_op, instance)) + name = ("" if owner is None else f"{owner.__name__}_") + self.__name__ + instance_op = self.define( + types.MethodType(_instance_op, instance), name=name + ) instance.__dict__[self._name_on_instance] = instance_op return instance_op elif instance is not None: diff --git a/pyproject.toml b/pyproject.toml index d565403f2..e3ede7856 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,10 @@ Source = "https://github.com/BasisResearch/effectful" [project.optional-dependencies] torch = ["torch"] pyro = ["pyro-ppl>=1.9.1"] -jax = ["jax"] +jax = [ + "jax", + "opt_einsum" +] numpyro = [ "numpyro>=0.19", "jax<0.10" @@ -71,6 +74,7 @@ test = [ "pytest-cov", "pytest-xdist", "pytest-benchmark", + "hypothesis", "mypy", "ruff", "nbval", diff --git a/tests/_monoid_helpers.py b/tests/_monoid_helpers.py new file mode 100644 index 000000000..72787558a --- /dev/null +++ b/tests/_monoid_helpers.py @@ -0,0 +1,423 @@ +import builtins +import itertools +import typing +from abc import ABC, abstractmethod +from collections.abc import Callable, Mapping +from typing import Any, Literal, overload + +import jax +from hypothesis import given, settings +from hypothesis import strategies as st +from hypothesis.strategies import SearchStrategy + +import effectful.handlers.jax.numpy as _jnp +from effectful.internals.runtime import interpreter +from effectful.ops.monoid import NormalizeIntp, Stream, _is_monoid_weighted +from effectful.ops.semantics import apply, evaluate, fvsof, handler +from effectful.ops.syntax import _BaseTerm, defdata, deffn, syntactic_eq +from effectful.ops.types import NotHandled, Operation, Term + + +def syntactic_eq_alpha(x, y) -> bool: + """Alpha-equivalence-respecting variant of ``syntactic_eq``. + + Walks each expression bottom-up with :func:`evaluate` and renames + every bound variable to a deterministic canonical Operation. The + canonical names are assigned by a counter that increments in + ``evaluate``'s natural traversal order, so two alpha-equivalent + expressions canonicalize to syntactically identical results. + """ + + _op_cache: dict[int, Operation] = {} + + def _canonical_op(idx: int, op: Operation) -> Operation: + """Cached canonical Operation, keyed by encounter index. + + Cached so that two independent canonicalize runs return the same + Operation object for the same index — letting ``syntactic_eq`` + compare canonical forms by Operation identity. + """ + if idx in _op_cache: + return _op_cache[idx] + + op = Operation.define(op, name=f"__cv_{idx}") + _op_cache[idx] = op + return op + + cx = _canonicalize(x, _canonical_op) + cy = _canonicalize(y, _canonical_op) + return syntactic_eq(cx, cy) + + +def _canonicalize(expr, _canonical_op): + counter = itertools.count() + + def _substitute(arg, renaming): + """Apply a bound-variable renaming using ``evaluate`` for traversal.""" + if not renaming: + return arg + with interpreter({apply: _BaseTerm, **renaming}): + return evaluate(arg) + + def _bound_var_order(args, kwargs, bound_set: set[Operation]) -> list[Operation]: + """Return bound variables in deterministic encounter order.""" + seen: list[Operation] = [] + seen_set: set[Operation] = set() + + def _capture(op, *a, **kw): + if op in bound_set and op not in seen_set: + seen.append(op) + seen_set.add(op) + return defdata(op, *a, **kw) + + # ``evaluate`` walks Terms, lists, tuples, mappings, dataclasses, + # etc. for free; the apply handler captures bound vars used as + # ``x()`` anywhere in the body. + with interpreter({apply: _capture}): + evaluate((args, kwargs)) + + # Binders bypass the apply handler. Pick them up with a small structural + # walk that visits dict keys too. + def _walk_bare(obj): + if isinstance(obj, Operation): + if obj in bound_set and obj not in seen_set: + seen.append(obj) + seen_set.add(obj) + elif isinstance(obj, dict): + for k, v in obj.items(): + _walk_bare(k) + _walk_bare(v) + elif isinstance(obj, list | set | frozenset | tuple): + for v in obj: + _walk_bare(v) + + _walk_bare((args, kwargs)) + return seen + + def _apply_canonical(op, *args, **kwargs) -> Term: + bindings = op.__fvs_rule__(*args, **kwargs) + all_bound: set[Operation] = set().union( + *bindings.args, *bindings.kwargs.values() + ) + if not all_bound: + return _BaseTerm(op, *args, **kwargs) + + order = _bound_var_order(args, kwargs, all_bound) + canonical = {var: _canonical_op(next(counter), var) for var in order} + assert all_bound <= set(order) + + new_args = tuple( + _substitute( + arg, {v: canonical[v] for v in bindings.args[i] if v in canonical} + ) + for i, arg in enumerate(args) + ) + new_kwargs = { + k: _substitute( + v, + {var: canonical[var] for var in bindings.kwargs[k] if var in canonical}, + ) + for k, v in kwargs.items() + } + + # avoid the renaming from defdata + return _BaseTerm(op, *new_args, **new_kwargs) + + with interpreter({apply: _apply_canonical}): + return evaluate(expr) + + +class Backend(ABC): + """A value-domain spec used to share monoid tests across int and jax.Array + backends. Provides the concrete value type, the hypothesis strategy for + drawing scalars in property tests, and an equality predicate that works + for that domain. + """ + + name: str + scalar_typ: Any + stream_typ: Any + strategy_for_op: dict[Operation, st.SearchStrategy[Callable[..., Any]]] + + def __init__(self): + self.strategy_for_op = {} + + @abstractmethod + def eq(self, a: Any, b: Any) -> bool: + raise NotImplementedError + + @abstractmethod + def strategy( + self, + arg_types: tuple[type, ...] = (), + ret: Literal["scalar", "stream"] = "scalar", + ) -> SearchStrategy: + raise NotImplementedError + + def _fresh_op( + self, + name: str, + arg_types: tuple[type, ...] = (), + ret: Literal["scalar", "stream"] = "scalar", + ) -> Operation: + """Build a fresh, unhandled Operation whose parameter and return + annotations are derived from this backend. + + ``ret`` is ``"scalar"`` for a scalar return or ``"stream"`` for a + stream-of-scalar return. The operation has ``n_args`` parameters, + each of type ``scalar_typ``. + """ + scalar = self.scalar_typ + out = self.stream_typ if ret == "stream" else scalar + params = ", ".join(f"_a{i}" for i in range(len(arg_types))) + ns: dict[str, Any] = {"NotHandled": NotHandled} + exec(f"def _fn({params}):\n raise NotHandled\n", ns) + fn = ns["_fn"] + fn.__annotations__ = { + **{f"_a{i}": t for i, t in enumerate(arg_types)}, + "return": out, + } + op = Operation.define(fn, name=name) + self.strategy_for_op[op] = self.strategy(arg_types, ret) + return op + + @overload + def define_vars(self, name: str, /, **kwargs) -> Operation: ... + + @overload + def define_vars( + self, n1: str, n2: str, /, *names: str, **kwargs + ) -> tuple[Operation, ...]: ... + + def define_vars(self, *names: str, **kwargs) -> Operation | tuple[Operation, ...]: # type: ignore[misc] + if len(names) == 1: + return self._fresh_op(names[0], **kwargs) + return tuple(self._fresh_op(n, **kwargs) for n in names) + + def check_rewrite( + self, + lhs, + rhs, + rule, + *, + max_examples: int = 25, + deadline=None, + normalize=NormalizeIntp, + ) -> None: + with handler(rule): + norm = evaluate(lhs) + assert syntactic_eq_alpha(norm, rhs) + + fvs = fvsof(lhs) | fvsof(rhs) + + @st.composite + def random_interpretation( + draw: st.DrawFn, + ) -> Mapping[Operation, Callable[..., Any]]: + """Draw an Interpretation binding every Operation in `free_vars` to + a randomly chosen value/callable. Keys are Operation identities. + """ + intp: dict[Operation, Callable[..., Any]] = {} + for op, strategy in self.strategy_for_op.items(): + if op in fvs: + intp[op] = draw(strategy) + return intp + + @given(intp=random_interpretation()) + @settings( + max_examples=max_examples, deadline=deadline, report_multiple_bugs=False + ) + def _check_semantics(intp): + with handler(normalize), handler(intp): + lhs_val = evaluate(lhs) + rhs_val = evaluate(rhs) + assert self.eq(lhs_val, rhs_val) + + _check_semantics() + + +def _is_weighted(x: Any) -> bool: + return isinstance(x, Term) and _is_monoid_weighted(x.op) + + +def _weight_pairs(x: Any, monoid: Any) -> list[tuple[Any, Any]] | None: + """Return ``(element, weight)`` pairs for a stream. + + A weighted-monoid Term yields each element paired with its weight. A plain + (unweighted) stream yields each element paired with ``monoid.identity`` -- + the no-op weight -- so an unweighted stream compares equal to a weighted one + exactly when every weight reduces to the identity (e.g. ``[()]`` vs a + weighted ``[()]`` whose single empty row reduces to the identity, and, more + generally, whenever both streams are empty). Returns ``None`` for a + non-stream Term, which never compares equal to a weighted stream. + """ + if isinstance(x, Term): + if not _is_monoid_weighted(x.op): + return None + stream, weight = x.args + assert not isinstance(stream, Term) + return [(e, typing.cast(Callable, weight)(e)) for e in stream] + return [(e, monoid.identity) for e in x] + + +def _weighted_stream_eq(a, b, leaf_eq: Callable[[Any, Any], bool]) -> bool: + monoids = {x.op.__self__ for x in (a, b) if _is_weighted(x)} + # distinct weight monoids can never be equal + if len(monoids) != 1: + return False + monoid = next(iter(monoids)) + + a_pairs = _weight_pairs(a, monoid) + b_pairs = _weight_pairs(b, monoid) + if a_pairs is None or b_pairs is None or len(a_pairs) != len(b_pairs): + return False + for (ea, wa), (eb, wb) in zip(a_pairs, b_pairs): + if not leaf_eq(ea, eb) or not leaf_eq(wa, wb): + return False + return True + + +class IntBackend(Backend): + name = "int" + scalar_typ = int + stream_typ = Stream[int] + + _unary_num_fns: list[Callable[[int], int]] = [ + lambda x: x, + lambda x: x + 1, + lambda x: x - 1, + lambda x: -x, + lambda x: 2 * x, + lambda x: 3 * x + 1, + ] + + _binary_num_fns: list[Callable[[int, int], int]] = [ + lambda x, y: x + y, + lambda x, y: x - y, + lambda x, y: x * y, + lambda x, y: x + 2 * y, + lambda x, y: 2 * x - y, + ] + + _unary_list_fns: list[Callable[[int], list[int]]] = [ + lambda _x: [], + lambda x: [x], + lambda x: [x, x + 1], + lambda x: [x, -x], + lambda x: [0, x, x + 1], + ] + + def strategy( + self, + arg_types: tuple[type, ...] = (), + ret: Literal["scalar", "stream"] = "scalar", + ) -> SearchStrategy: + match arg_types, ret: + case (), "scalar": + return st.integers(min_value=-100, max_value=100).map(deffn) + case (), "stream": + scalars = st.integers(min_value=-100, max_value=100) + return st.lists(scalars, max_size=2).map(deffn) + case (builtins.int,), "scalar": + return st.sampled_from(self._unary_num_fns) + case (builtins.int, builtins.int), "scalar": + return st.sampled_from(self._binary_num_fns) + case (builtins.int, builtins.int, builtins.int), "scalar": + return st.tuples( + st.sampled_from(self._binary_num_fns), + st.sampled_from(self._binary_num_fns), + ).map(lambda fg: lambda a, b, c: fg[0](a, fg[1](b, c))) + case (builtins.int,), "stream": + return st.sampled_from(self._unary_list_fns) + raise NotImplementedError( + f"No int strategy for op with return {ret!r} and {arg_types} args" + ) + + def eq(self, a: Any, b: Any) -> bool: + if _is_weighted(a) or _is_weighted(b): + return _weighted_stream_eq(a, b, self.eq) + return not isinstance(a, Term) and not isinstance(b, Term) and a == b + + +class JaxBackend(Backend): + name = "jax" + scalar_typ = jax.Array + stream_typ = jax.Array + + _unary_jax_scalar_fns: list[Callable[[jax.Array], jax.Array]] = [ + lambda a: a, + lambda a: a + 1, + lambda a: a - 1, + lambda a: -a, + lambda a: 2 * a, + ] + + _unary_jax_stream_fns: list[Callable[[jax.Array], Stream[jax.Array]]] = [ + lambda a: _jnp.stack([a, a + 1]), + lambda a: _jnp.stack([a, -a]), + lambda a: _jnp.stack([a, a + 1, 2 * a]), + ] + + _binary_jax_scalar_fns: list[Callable[[jax.Array, jax.Array], jax.Array]] = [ + lambda a, b: a + b, + lambda a, b: a - b, + lambda a, b: a * b, + ] + + def strategy( + self, + arg_types: tuple[type, ...] = (), + ret: Literal["scalar", "stream"] = "scalar", + ) -> st.SearchStrategy[Callable]: + match arg_types, ret: + case (), "scalar": + return ( + st.lists( + st.integers(min_value=-5, max_value=5), + min_size=2, + max_size=2, + ) + .map(lambda xs: jax.numpy.asarray(xs, dtype=jax.numpy.float32)) + .map(deffn) + ) + case (), "stream": + return ( + st.lists( + st.integers(min_value=-5, max_value=5), + min_size=1, + max_size=2, + ) + .map(lambda xs: jax.numpy.asarray(xs, dtype=jax.numpy.float32)) + .map(deffn) + ) + case (jax.Array,), "scalar": + return st.sampled_from(self._unary_jax_scalar_fns) + case (jax.Array, jax.Array), "scalar": + return st.sampled_from(self._binary_jax_scalar_fns) + case (jax.Array, jax.Array, jax.Array), "scalar": + return st.tuples( + st.sampled_from(self._binary_jax_scalar_fns), + st.sampled_from(self._binary_jax_scalar_fns), + ).map(lambda fg: lambda a, b, c: fg[0](a, fg[1](b, c))) + case (jax.Array,), "stream": + return st.sampled_from(self._unary_jax_stream_fns) + + raise NotImplementedError( + f"No jax strategy for op with return {ret!r} and {arg_types} args" + ) + + def eq(self, a: Any, b: Any) -> bool: + if _is_weighted(a) or _is_weighted(b): + return _weighted_stream_eq(a, b, self.eq) + + def _leaf_eq(x: Any, y: Any) -> bool: + return bool(jax.numpy.all(jax.numpy.isclose(x, y, equal_nan=True))) + + try: + leaves = jax.tree.leaves(jax.tree.map(_leaf_eq, a, b)) + except (ValueError, TypeError): + return False + return all(leaves) + + +__all__ = ["Backend", "IntBackend", "JaxBackend", "syntactic_eq_alpha"] diff --git a/tests/test_handlers_jax_monoid.py b/tests/test_handlers_jax_monoid.py new file mode 100644 index 000000000..e410a3e69 --- /dev/null +++ b/tests/test_handlers_jax_monoid.py @@ -0,0 +1,558 @@ +import functools + +import jax +import pytest +from jax import random as random + +import effectful.handlers.jax.numpy as jnp +from effectful.handlers.jax import bind_dims, jax_getitem, unbind_dims +from effectful.handlers.jax.monoid import ( + ARRAY_REDUCTORS, + DeltaEmpty, + ReduceArray, + ReduceArrayGather, + ReduceDeltaSimpleRange, + ReduceDependentRangeMask, + ReduceSumProductContraction, + delta, + einsum, +) +from effectful.ops.monoid import ( + EliminateSingletonStreams, + NormalizeIntp, + Product, + Sum, +) +from effectful.ops.semantics import coproduct, handler +from tests._monoid_helpers import JaxBackend + +MONOIDS = [ + pytest.param(monoid, reductor, id=monoid._name) + for (monoid, reductor) in ARRAY_REDUCTORS.items() +] + + +@pytest.fixture(scope="module") +def rng_key(): + return random.PRNGKey(0) + + +@pytest.fixture +def backend() -> JaxBackend: + return JaxBackend() + + +@pytest.mark.parametrize("monoid,reductor", MONOIDS) +def test_reduce_array_gather(monoid, reductor, backend: JaxBackend): + (x, k) = backend.define_vars("x", "k", ret="scalar") + X = jnp.arange(3) + + lhs = monoid.reduce(x(), {x: X}) + rhs = monoid.reduce(unbind_dims(X, k), {k: range(X.shape[0])}) + backend.check_rewrite( + lhs=lhs, + rhs=rhs, + rule=coproduct(ReduceArrayGather(), EliminateSingletonStreams()), + ) + + +@pytest.mark.parametrize("monoid,reductor", MONOIDS) +def test_reduce_array_gather_step1(monoid, reductor, backend: JaxBackend): + """Step 1 alone: an array stream becomes an index range plus a length-1 + stream holding the gathered element. ``ReduceArrayGather`` does not perform + the gather substitution itself -- that is ``EliminateSingletonStreams``.""" + (x, k) = backend.define_vars("x", "k", ret="scalar") + X = jnp.arange(3) + + lhs = monoid.reduce(x(), {x: X}) + rhs = monoid.reduce(x(), {k: range(X.shape[0]), x: (unbind_dims(X, k),)}) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceArrayGather()) + + +@pytest.mark.parametrize("monoid,reductor", MONOIDS) +def test_reduce_array_gather_dep(monoid, reductor, backend: JaxBackend): + (x, y) = backend.define_vars("x", "y", ret="scalar") + f = backend.define_vars("f", arg_types=(backend.scalar_typ,), ret="stream") + g = backend.define_vars( + "g", arg_types=(backend.scalar_typ, backend.scalar_typ), ret="scalar" + ) + X = jnp.arange(3) + + # The dependent stream ``y: f(x())`` gets the *gathered element* X[x] + # substituted for x -- i.e. ``f(X[x])`` -- not the bare index. + lhs = monoid.reduce(g(x(), y()), {y: f(x()), x: X}) + rhs = monoid.reduce( + g(unbind_dims(X, x), y()), + {y: f(unbind_dims(X, x)), x: range(X.shape[0])}, + ) + backend.check_rewrite( + lhs=lhs, + rhs=rhs, + rule=coproduct(ReduceArrayGather(), EliminateSingletonStreams()), + ) + + +@pytest.mark.parametrize("monoid,reductor", MONOIDS) +def test_reduce_array_1(monoid, reductor, backend: JaxBackend): + (x, k) = backend.define_vars("x", "k", ret="scalar") + X = jnp.arange(5) + + lhs = monoid.reduce(x(), {x: X}) + rhs = reductor(bind_dims(unbind_dims(X, k), k), axis=(0,)) + backend.check_rewrite( + lhs=lhs, + rhs=rhs, + rule=functools.reduce( + coproduct, # type: ignore[arg-type] + [ + ReduceArrayGather(), + EliminateSingletonStreams(), + ReduceArray(), + ReduceDeltaSimpleRange(), + ], + ), + ) + + +@pytest.mark.parametrize("monoid,reductor", MONOIDS) +def test_reduce_array_2(monoid, reductor, backend: JaxBackend): + (x, y, k1, k2) = backend.define_vars("x", "y", "k1", "k2", ret="scalar") + X = jnp.arange(5) + Y = jnp.arange(7) + f = backend.define_vars( + "f", arg_types=(backend.scalar_typ, backend.scalar_typ), ret="scalar" + ) + + lhs = monoid.reduce(f(x(), y()), {x: X, y: Y}) + rhs = reductor( + bind_dims(f(unbind_dims(X, k1), unbind_dims(Y, k2)), k1, k2), axis=(0, 1) + ) + backend.check_rewrite( + lhs=lhs, + rhs=rhs, + rule=functools.reduce( + coproduct, # type: ignore[arg-type] + [ + ReduceArrayGather(), + EliminateSingletonStreams(), + ReduceArray(), + ReduceDeltaSimpleRange(), + ], + ), + ) + + +@pytest.mark.parametrize("monoid,reductor", MONOIDS) +def test_reduce_array_3(monoid, reductor, backend: JaxBackend): + """Stream `y` is `g(x())` — depends on the bound element of X. The reducer + must inline ``g`` along the same named dim used to unbind `x`.""" + (x, y, k1, k2) = backend.define_vars("x", "y", "k1", "k2", ret="scalar") + X = jnp.arange(5) + + f = backend.define_vars( + "f", arg_types=[backend.scalar_typ, backend.scalar_typ], ret="scalar" + ) + g = backend.define_vars("g", arg_types=[backend.scalar_typ], ret="stream") + + lhs = monoid.reduce(f(x(), y()), {x: X, y: g(x())}) + rhs = reductor( + bind_dims( + monoid.reduce(f(unbind_dims(X, x), y()), {y: g(unbind_dims(X, x))}), x + ), + axis=(0,), + ) + backend.check_rewrite( + lhs=lhs, + rhs=rhs, + rule=functools.reduce( + coproduct, # type: ignore[arg-type] + [ + ReduceArrayGather(), + EliminateSingletonStreams(), + ReduceArray(), + ReduceDeltaSimpleRange(), + DeltaEmpty(), + ], + ), + ) + + +@pytest.mark.parametrize("monoid,reductor", MONOIDS) +def test_arange_reduce_direct_full(monoid, reductor, backend: JaxBackend): + """A full-range direct index ``A[v()]`` over ``v: arange(N)`` slices the + whole axis (``A[0:N:1]``) and reduces it -- no materialized-arange gather. + """ + (v, k) = backend.define_vars("v", "k", ret="scalar") + A = backend.define_vars("A", ret="stream") + + lhs = monoid.reduce(jax_getitem(A(), [v()]), {v: range(7)}) + rhs = reductor( + bind_dims(jax_getitem(jax_getitem(A(), [slice(0, 7, 1)]), [k()]), k), + axis=(0,), + ) + backend.check_rewrite( + lhs=lhs, rhs=rhs, rule=coproduct(ReduceArray(), ReduceDeltaSimpleRange()) + ) + + +@pytest.mark.parametrize("monoid,reductor", MONOIDS) +def test_arange_reduce_indirect(monoid, reductor, backend: JaxBackend): + """When the range var is used both as a direct index and as a value + (``A[v()] + v()``), the direct use slices and the indirect use materializes + the range, both aligned on the same fresh dim.""" + (v, k) = backend.define_vars("v", "k", ret="scalar") + A = jnp.arange(10) + + lhs = monoid.reduce(jax_getitem(A, [v()]) + v(), {v: range(5)}) + rhs = reductor( + bind_dims( + jax_getitem(jax_getitem(A, [slice(0, 5, 1)]), [k()]) + + unbind_dims(jnp.arange(5), k), + k, + ), + axis=(0,), + ) + backend.check_rewrite( + lhs=lhs, rhs=rhs, rule=coproduct(ReduceArray(), ReduceDeltaSimpleRange()) + ) + + +@pytest.mark.parametrize("monoid,reductor", MONOIDS) +def test_arange_reduce_two_streams(monoid, reductor, backend: JaxBackend): + """Two arange streams indexing a 2-D array slice both axes and reduce over + both at once.""" + (u, w, k1, k2) = backend.define_vars("u", "w", "k1", "k2", ret="scalar") + A = jnp.arange(8 * 9).reshape((8, 9)) + + lhs = monoid.reduce(jax_getitem(A, [u(), w()]), {u: range(4), w: range(5)}) + rhs = reductor( + bind_dims( + jax_getitem(jax_getitem(A, [slice(0, 4, 1), slice(0, 5, 1)]), [k1(), k2()]), + k1, + k2, + ), + axis=(0, 1), + ) + backend.check_rewrite( + lhs=lhs, rhs=rhs, rule=coproduct(ReduceArray(), ReduceDeltaSimpleRange()) + ) + + +# --------------------------------------------------------------------------- +# Delta rules. All tests use the operation form ``delta(idx, body)`` rather +# than the ``Delta`` dataclass; the delta op is the user-facing surface. +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("monoid,reductor", MONOIDS) +def test_reduce_delta_empty(monoid, reductor, backend: JaxBackend): + """An empty-index delta unwraps to its body. + + reduce(M, streams, delta((), body)) ≡ reduce(M, streams, body) + """ + x = backend.define_vars("x", ret="scalar") + X = backend.define_vars("X", ret="stream") + + lhs = monoid.reduce(delta((), x()), {x: X()}) + rhs = monoid.reduce(x(), {x: X()}) + backend.check_rewrite( + lhs=lhs, rhs=rhs, rule=coproduct(ReduceDeltaSimpleRange(), DeltaEmpty()) + ) + + +@pytest.mark.parametrize("monoid,reductor", MONOIDS) +def test_reduce_delta_empty_arange(monoid, reductor, backend: JaxBackend): + x = backend.define_vars("x", ret="scalar") + f = backend.define_vars("f", arg_types=[backend.scalar_typ], ret="scalar") + + lhs = monoid.reduce(delta((x(),), f(x())), {x: range(0)}) + rhs = bind_dims(f(unbind_dims(jnp.array([]), x)), x) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceDeltaSimpleRange()) + + +@pytest.mark.parametrize("monoid,reductor", MONOIDS) +def test_reduce_delta_independent_one(monoid, reductor, backend: JaxBackend): + """One R1 step: peel the final preserved index off a delta. + + reduce(M, {y: Y()}, delta((y(),), f(y()))) ≡ bind_dims(f(unbind_dims(Y(), k)), k) + """ + (y, k) = backend.define_vars("y", "k", ret="scalar") + f = backend.define_vars("f", arg_types=[backend.scalar_typ], ret="scalar") + + # We use a concrete range here instead of an abstract one, because + # unbind_dims is undefined on empty arrays (and the rewrite produces a + # different rhs in this case) + lhs = monoid.reduce(delta((y(),), f(y())), {y: range(3)}) + rhs = bind_dims(f(unbind_dims(jnp.arange(3), k)), k) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceDeltaSimpleRange()) + + +@pytest.mark.parametrize("monoid,reductor", MONOIDS) +def test_reduce_delta_independent_preserves_others( + monoid, reductor, backend: JaxBackend +): + """R1 peels only the final index. Streams not matching the peeled index op + stay untouched, as do earlier entries in the index tuple. + + reduce(M, {x: X(), y: Y()}, delta((x(), y()), f(x(), y()))) + ≡ reduce(M, {x: X()}, delta((x(),), bind_dims(f(x(), unbind_dims(Y(), k)), k))) + """ + (x, y, k) = backend.define_vars("x", "y", "k", ret="scalar") + f = backend.define_vars( + "f", arg_types=[backend.scalar_typ, backend.scalar_typ], ret="scalar" + ) + + lhs = monoid.reduce(delta((x(), y()), f(x(), y())), {x: range(2), y: range(3)}) + rhs = bind_dims( + bind_dims(f(unbind_dims(jnp.arange(2), x), unbind_dims(jnp.arange(3), k)), k), x + ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceDeltaSimpleRange()) + + +@pytest.mark.parametrize("monoid,reductor", MONOIDS) +def test_reduce_delta_simple_dep(monoid, reductor, backend: JaxBackend): + (x, y) = backend.define_vars("x", "y", ret="scalar") + X = jnp.arange(3) + + lhs = monoid.reduce( + delta((x(),), unbind_dims(X, x) + y()), + {x: range(3), y: jnp.stack([x(), x() + 1])}, + ) + rhs = bind_dims( + monoid.reduce( + delta((), unbind_dims(X, x) + y()), + { + y: jnp.stack( + [unbind_dims(jnp.arange(3), x), unbind_dims(jnp.arange(3), x) + 1] + ) + }, + ), + x, + ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceDeltaSimpleRange()) + + +@pytest.mark.parametrize("monoid,reductor", MONOIDS) +def test_reduce_dependent_range_mask(monoid, reductor, backend: JaxBackend): + """A dependent range stream gets rewritten to the referent's bbox stream, + with the original constraint folded into the body as a where-guard. + + reduce(M, {u: range(0, N, 1), v: range(0, u(), 1)}, body) + ≡ reduce(M, {u: range(0, N, 1), v: range(0, N, 1)}, where(v() < u(), body, M.identity)) + """ + (u, v) = backend.define_vars("u", "v", ret="scalar") + N = 5 + f = backend.define_vars( + "f", arg_types=[backend.scalar_typ, backend.scalar_typ], ret="scalar" + ) + + body = f(u(), v()) + + lhs = monoid.reduce(body, {u: range(N), v: jnp.arange(u())}) + rhs = monoid.reduce( + jnp.where(v() < u(), body, monoid.identity), {u: range(N), v: range(N)} + ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceDependentRangeMask()) + + +@pytest.mark.parametrize("monoid,reductor", MONOIDS) +def test_reduce_dependent_range_mask_delta_body(monoid, reductor, backend: JaxBackend): + """When the body is a delta term, R4 folds the constraint into the delta's + weight while leaving its index tuple untouched. + + reduce(M, {u: range(N), v: range(u())}, delta((u(), v()), w)) + ≡ reduce(M, {u: range(N), v: range(N)}, + delta((u(), v()), where(v() < u(), w, M.identity))) + """ + (u, v) = backend.define_vars("u", "v", ret="scalar") + N = 5 + f = backend.define_vars( + "f", arg_types=[backend.scalar_typ, backend.scalar_typ], ret="scalar" + ) + + weight = f(u(), v()) + idx = (u(), v()) + + lhs = monoid.reduce(delta(idx, weight), {u: range(N), v: jnp.arange(u())}) + rhs = monoid.reduce( + delta(idx, jnp.where(v() < u(), weight, monoid.identity)), + {u: range(N), v: range(N)}, + ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceDependentRangeMask()) + + +def test_reduce_contraction_single(backend: JaxBackend): + i = backend.define_vars("i", ret="scalar") + (A, B) = backend.define_vars( + "A", "B", arg_types=(backend.scalar_typ,), ret="scalar" + ) + + lhs = Sum.reduce(Product.plus(A(i()), B(i())), {i: range(5)}) + rhs = jnp.einsum( + "a...,a...->...", + Sum.reduce(delta((i(),), A(i())), {i: range(5)}), + Sum.reduce(delta((i(),), B(i())), {i: range(5)}), + ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceSumProductContraction()) + + +def test_reduce_contraction_double(backend: JaxBackend): + i, j = backend.define_vars("i", "j", ret="scalar") + (A, B) = backend.define_vars( + "A", "B", arg_types=(backend.scalar_typ, backend.scalar_typ), ret="scalar" + ) + + lhs = Sum.reduce(Product.plus(A(i(), j()), B(i(), j())), {i: range(5), j: range(7)}) + rhs = jnp.einsum( + "ab...,ab...->...", + Sum.reduce(delta((i(), j()), A(i(), j())), {i: range(5), j: range(7)}), + Sum.reduce(delta((i(), j()), B(i(), j())), {i: range(5), j: range(7)}), + ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceSumProductContraction()) + + +def test_reduce_matmul(backend: JaxBackend): + key = jax.random.PRNGKey(0) + # Define dimensions + B, I, J, K = 2, 3, 4, 5 + + # Create sample matrices + X = random.normal(key, (B, I, J)) + Y = random.normal(key, (B, J, K)) + (b, i, j, k) = backend.define_vars("b", "i", "j", "k", ret="scalar") + + with handler(NormalizeIntp): + actual = Sum.reduce( + delta((b(), i(), k()), unbind_dims(X, b, i, j) * unbind_dims(Y, b, j, k)), + {b: range(B), i: range(I), j: range(J), k: range(K)}, + ) + + expected = jnp.einsum("bij,bjk->bik", X, Y) + assert jnp.allclose(actual, expected) + + +EINSUM_CASES = [ + pytest.param("ij,jk->ik", {"i": 64, "j": 64, "k": 64}, id="matmul"), + pytest.param( + "bij,bjk->bik", + {"b": 16, "i": 32, "j": 32, "k": 32}, + id="batched_matmul", + ), + pytest.param( + "a,abi,bcij,cdij->ij", + {"a": 4, "b": 4, "c": 4, "d": 4, "i": 8, "j": 8}, + id="mixed_rank", + ), + # ───────────────────────── single-operand reshuffles ───────────────────── + # No contraction across operands — these stress the diagonal/transpose/sum + # rewrites rather than any pairwise product ordering. + pytest.param("ij->ji", {"i": 256, "j": 256}, id="transpose"), + pytest.param("ijk->", {"i": 96, "j": 96, "k": 96}, id="full_reduce"), + pytest.param("ijk->k", {"i": 96, "j": 96, "k": 96}, id="partial_reduce"), + # Repeated index *within* one operand — exercises the implicit-diagonal path + # in ReduceDeltaSimpleRange (no explicit jnp.diagonal step). + pytest.param("ii->", {"i": 1024}, id="trace"), + pytest.param("ii->i", {"i": 1024}, id="diagonal"), + pytest.param("bii->b", {"b": 256, "i": 128}, id="batched_trace"), + pytest.param("iij->ij", {"i": 128, "j": 128}, id="diagonal_keep"), + # ───────────────────────── no-shared-index blowups ─────────────────────── + # Output is the full outer product — nothing contracts, so the result tensor + # is as large as the dense intermediate. Pure broadcast cost. + pytest.param("i,j->ij", {"i": 1024, "j": 1024}, id="outer_product"), + pytest.param("ij,kl->ijkl", {"i": 32, "j": 32, "k": 32, "l": 32}, id="outer_4d"), + # Element-wise: every index shared, none contracted. + pytest.param("ij,ij->ij", {"i": 512, "j": 512}, id="hadamard"), + # ───────────────────────── ordering-sensitive products ─────────────────── + # Skewed matrix chain: contracting middle-first (b,d small) is orders of + # magnitude cheaper than the left-to-right order, which materializes a big + # a×c intermediate. The classic "matrix chain order matters" case. + pytest.param( + "ab,bc,cd->ad", {"a": 256, "b": 2, "c": 256, "d": 2}, id="skewed_chain" + ), + pytest.param( + "ab,bc,cd,de->ae", + {"a": 50, "b": 40, "c": 30, "d": 20, "e": 10}, + id="chain_4", + ), + pytest.param( + "ab,bc,cd,de,ef->af", + {"a": 12, "b": 11, "c": 10, "d": 9, "e": 8, "f": 7}, + id="chain_5", + ), + # ───────────────────────── tensor-network shapes ───────────────────────── + # Cyclic / hyperedge contractions with no tree decomposition into matmuls; + # every operand shares indices with two others. + pytest.param("ij,jk,ki->", {"i": 64, "j": 64, "k": 64}, id="trace_of_product"), + pytest.param("ij,jk,ik->", {"i": 48, "j": 48, "k": 48}, id="triangle"), + pytest.param("ijk,jl,kl->il", {"i": 24, "j": 24, "k": 24, "l": 24}, id="hyperedge"), + # Star: many operands share one contracted index, fanning into a large + # outer-product output. + pytest.param( + "ai,bi,ci,di->abcd", + {"a": 8, "b": 8, "c": 8, "d": 8, "i": 32}, + id="star_contraction", + ), + # Bilinear / quadratic form over a batch (attention-score flavored). + pytest.param("bi,ij,bj->b", {"b": 128, "i": 64, "j": 64}, id="bilinear"), + # Batched matrix chain — batch axis rides through three contractions. + pytest.param( + "bij,bjk,bkl->bil", + {"b": 16, "i": 24, "j": 24, "k": 24, "l": 24}, + id="batched_chain", + ), + # Multi-index contraction surface: a whole axis-group (c) contracts at once. + pytest.param( + "abc,cde->abde", + {"a": 12, "b": 12, "c": 12, "d": 12, "e": 12}, + id="tensor_contraction", + ), + # Leading scalar factor plus an element-wise reduce — checks that the + # rank-0 operand threads through without spawning a degenerate axis. + pytest.param(",ij,ij->", {"i": 256, "j": 256}, id="scalar_scaled_reduce"), +] + + +def _make_operands(spec: str, sizes: dict[str, int], key: jax.Array) -> list[jax.Array]: + in_part = spec.split("->")[0] + in_specs = in_part.split(",") + keys = random.split(key, len(in_specs)) + return [ + random.normal(k, tuple(sizes[c] for c in s) if s else ()) + for k, s in zip(keys, in_specs, strict=True) + ] + + +@pytest.mark.parametrize( + "impl", [pytest.param(jnp.einsum, id="jax"), pytest.param(einsum, id="effectful")] +) +@pytest.mark.parametrize("spec,sizes", EINSUM_CASES) +@pytest.mark.benchmark(warmup=True, warmup_iterations=1) +def test_einsum_bench(benchmark, impl, spec, sizes, rng_key): + """Time one ``(spec, impl)`` pair. Group by ``spec`` to compare ``jnp`` + against ``effectful`` for the same subscript pattern (see module docstring). + """ + operands = _make_operands(spec, sizes, rng_key) + + @jax.jit + def f(*operands): + return impl(spec, *operands) + + @benchmark + def _run(): + return f(*operands).block_until_ready() + + +@pytest.mark.parametrize("spec,sizes", EINSUM_CASES) +def test_einsum_matches_jnp(spec: str, sizes, rng_key): + """``einsum`` returns the same result as ``jnp.einsum`` for every spec + in ``EINSUM_EXAMPLES``. + """ + operands = _make_operands(spec, sizes, rng_key) + actual = einsum(spec, *operands) + expected = jnp.einsum(spec, *operands) + assert actual.shape == expected.shape, ( + f"shape mismatch for {spec!r}: got {actual.shape}, expected {expected.shape}" + ) + assert jnp.allclose(actual, expected, atol=1e-4, rtol=1e-4), ( + f"value mismatch for {spec!r}" + ) diff --git a/tests/test_handlers_llm_provider.py b/tests/test_handlers_llm_provider.py index b56fd7bbd..9a2983901 100644 --- a/tests/test_handlers_llm_provider.py +++ b/tests/test_handlers_llm_provider.py @@ -240,7 +240,7 @@ def test_agent_tool_names_are_valid_integration(): agent = _ToolNameAgent() template = agent.ask tools = template.tools - expected_helper_tool_name = f"self__{agent.helper.__name__}" + expected_helper_tool_name = "self__helper" assert tools assert expected_helper_tool_name in tools assert all(re.fullmatch(r"[a-zA-Z0-9_-]+", name) for name in tools) diff --git a/tests/test_internals_unification.py b/tests/test_internals_unification.py index 8b93976fc..8abfdc5ac 100644 --- a/tests/test_internals_unification.py +++ b/tests/test_internals_unification.py @@ -1900,3 +1900,10 @@ class Info(typing.TypedDict): subs = unify(collections.abc.Mapping, Info) assert subs == {} + + +def test_unify_jax_array_iterable(): + import jax + + subs = unify(collections.abc.Iterable[T], jax.Array) + assert subs == {T: jax.Array} diff --git a/tests/test_ops_monoid.py b/tests/test_ops_monoid.py new file mode 100644 index 000000000..9976fd6dc --- /dev/null +++ b/tests/test_ops_monoid.py @@ -0,0 +1,838 @@ +import math +import typing +from collections.abc import Iterable + +import pytest +from hypothesis import HealthCheck, given, settings +from hypothesis import strategies as st + +import effectful.handlers.jax.monoid # noqa: F401 +import effectful.handlers.jax.numpy as jnp +from effectful.ops.monoid import ( + CartesianProduct, + EliminateSingletonStreams, + Max, + Min, + Monoid, + MonoidOverMapping, + MonoidOverSequence, + NormalizeIntp, + PlusAssoc, + PlusConsecutiveDups, + PlusDistr, + PlusDups, + PlusEmpty, + PlusSingle, + Product, + ReduceCartesianWeightedStream, + ReduceDistributeCartesianProduct, + ReduceFactorization, + ReduceFusion, + ReducePartial, + ReduceSplit, + ReduceWeightedStream, + Sum, + distributes_over, +) +from effectful.ops.semantics import coproduct, evaluate, fvsof, handler +from effectful.ops.syntax import deffn +from effectful.ops.types import NotHandled, Operation, Term +from tests._monoid_helpers import Backend, IntBackend, JaxBackend, syntactic_eq_alpha + + +@pytest.fixture(params=[IntBackend, JaxBackend], ids=["int", "jax"]) +def backend(request) -> Backend: + return request.param() + + +ALL_MONOIDS = [ + pytest.param(Sum, id="Sum"), + pytest.param(Product, id="Product"), + pytest.param(Min, id="Min"), + pytest.param(Max, id="Max"), +] + +COMMUTATIVE = [ + pytest.param(Sum, id="Sum"), + pytest.param(Product, id="Product"), + pytest.param(Min, id="Min"), + pytest.param(Max, id="Max"), +] + +IDEMPOTENT = [ + pytest.param(Min, id="Min"), + pytest.param(Max, id="Max"), +] + +WITH_ZERO = [ + pytest.param(Product, id="Product"), +] + +# Pairs (outer, inner) such that inner distributes over outer — i.e. the lifting +# identity ``outer(inner(body, A), CartesianProduct...) == inner(outer(body, D), ...)`` +# is valid for that semiring pair. +MONOID_PAIRS = [ + pytest.param(o.values[0], i.values[0], id=f"{o.id}-{i.id}") + for o in ALL_MONOIDS + for i in ALL_MONOIDS + if distributes_over( + typing.cast(Monoid, i.values[0]), typing.cast(Monoid, o.values[0]) + ) +] + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +@given(data=st.data()) +@settings( + max_examples=50, + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], +) +def test_associativity(monoid, backend: Backend, data): + a = data.draw(backend.strategy(ret="scalar"))() + b = data.draw(backend.strategy(ret="scalar"))() + c = data.draw(backend.strategy(ret="scalar"))() + with handler(NormalizeIntp): + left = monoid.plus(monoid.plus(a, b), c) + right = monoid.plus(a, monoid.plus(b, c)) + assert backend.eq(left, right) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +@given(data=st.data()) +@settings( + max_examples=50, + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], +) +def test_identity(monoid, backend: Backend, data): + a = data.draw(backend.strategy(ret="scalar"))() + with handler(NormalizeIntp): + assert backend.eq(monoid.plus(monoid.identity, a), a) + assert backend.eq(monoid.plus(a, monoid.identity), a) + + +@pytest.mark.parametrize("monoid", COMMUTATIVE) +@given(data=st.data()) +@settings( + max_examples=50, + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], +) +def test_commutativity(monoid, backend: Backend, data): + a = data.draw(backend.strategy(ret="scalar"))() + b = data.draw(backend.strategy(ret="scalar"))() + with handler(NormalizeIntp): + assert backend.eq(monoid.plus(a, b), monoid.plus(b, a)) + + +@pytest.mark.parametrize("monoid", IDEMPOTENT) +@given(data=st.data()) +@settings( + max_examples=50, + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], +) +def test_idempotence(monoid, backend: Backend, data): + a = data.draw(backend.strategy(ret="scalar"))() + with handler(NormalizeIntp): + assert backend.eq(monoid.plus(a, a), a) + + +@pytest.mark.parametrize("monoid", WITH_ZERO) +@given(data=st.data()) +@settings( + max_examples=50, + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], +) +def test_zero_absorbs(monoid, backend: Backend, data): + a = data.draw(backend.strategy(ret="scalar"))() + with handler(NormalizeIntp): + assert backend.eq(monoid.plus(monoid.zero, a), monoid.zero) + assert backend.eq(monoid.plus(a, monoid.zero), monoid.zero) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_plus_empty(monoid, backend: Backend): + backend.check_rewrite(lhs=monoid.plus(), rhs=monoid.identity, rule=PlusEmpty()) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_plus_single(monoid, backend: Backend): + x = backend.define_vars("x", ret="scalar") + backend.check_rewrite(lhs=monoid.plus(x()), rhs=x(), rule=PlusSingle()) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_plus_identity_right(monoid, backend: Backend): + x = backend.define_vars("x", ret="scalar") + + lhs = monoid.plus(x(), monoid.identity) + rhs = monoid.plus(x()) + + backend.check_rewrite(lhs=lhs, rhs=rhs, rule={}) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_plus_identity_left(monoid, backend: Backend): + x = backend.define_vars("x", ret="scalar") + + lhs = monoid.plus(monoid.identity, x()) + rhs = monoid.plus(x()) + + backend.check_rewrite(lhs=lhs, rhs=rhs, rule={}) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_plus_assoc_right(monoid, backend: Backend): + x, y, z = backend.define_vars("x", "y", "z", ret="scalar") + backend.check_rewrite( + lhs=monoid.plus(x(), monoid.plus(y(), z())), + rhs=monoid.plus(x(), y(), z()), + rule=PlusAssoc(), + ) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_plus_assoc_left(monoid, backend: Backend): + x, y, z = backend.define_vars("x", "y", "z", ret="scalar") + backend.check_rewrite( + lhs=monoid.plus(monoid.plus(x(), y()), z()), + rhs=monoid.plus(x(), y(), z()), + rule=PlusAssoc(), + ) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_plus_sequence(monoid, backend: Backend): + a, b, c, d = backend.define_vars("a", "b", "c", "d", ret="scalar") + backend.check_rewrite( + lhs=monoid.plus((a(), b()), (c(), d())), + rhs=(monoid.plus(a(), c()), monoid.plus(b(), d())), + rule=MonoidOverSequence(), + ) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_plus_mapping(monoid, backend: Backend): + a, b, c, d = backend.define_vars("a", "b", "c", "d", ret="scalar") + + lhs = monoid.plus({0: a(), 1: b()}, {0: c(), 2: d()}) + rhs = {0: monoid.plus(a(), c()), 1: monoid.plus(b()), 2: monoid.plus(d())} + + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=MonoidOverMapping()) + + +def test_plus_distributes(backend: Backend): + a, b, c, d = backend.define_vars("a", "b", "c", "d", ret="scalar") + lhs = Product.plus(Sum.plus(a(), b()), Sum.plus(c(), d())) + rhs = Product.plus( + Sum.plus( + Product.plus(a(), c()), + Product.plus(a(), d()), + Product.plus(b(), c()), + Product.plus(b(), d()), + ) + ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=PlusDistr()) + + +def test_plus_distributes_constant(backend: Backend): + a, b, c, d, e = backend.define_vars("a", "b", "c", "d", "e", ret="scalar") + lhs = Product.plus(Sum.plus(a(), b()), Sum.plus(c(), d()), e()) + rhs = Product.plus( + e(), + Sum.plus( + Product.plus(a(), c()), + Product.plus(a(), d()), + Product.plus(b(), c()), + Product.plus(b(), d()), + ), + ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=PlusDistr()) + + +def test_plus_distributes_multiple(backend: Backend): + a, b, c, d = backend.define_vars("a", "b", "c", "d", ret="scalar") + lhs = Sum.plus( + Min.plus(a(), b()), + Min.plus(c(), d()), + Max.plus(a(), b()), + Max.plus(c(), d()), + ) + rhs = Sum.plus( + Min.plus( + Sum.plus(a(), c()), + Sum.plus(a(), d()), + Sum.plus(b(), c()), + Sum.plus(b(), d()), + ), + Max.plus( + Sum.plus(a(), c()), + Sum.plus(a(), d()), + Sum.plus(b(), c()), + Sum.plus(b(), d()), + ), + ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=PlusDistr()) + + +@pytest.mark.parametrize("monoid", IDEMPOTENT) +def test_plus_idempotent_consecutive(monoid, backend: Backend): + """``a, a, b → a, b`` — only consecutive duplicates collapse.""" + a, b = backend.define_vars("a", "b", ret="scalar") + lhs = monoid.plus(a(), a(), b()) + return backend.check_rewrite( + lhs=lhs, rhs=monoid.plus(a(), b()), rule=PlusConsecutiveDups() + ) + + +@pytest.mark.parametrize("monoid", IDEMPOTENT) +def test_plus_idempotent_non_consecutive(monoid, backend: Backend): + """``a, b, a`` — Semilattice (Min/Max) collapses via commutative + PlusDups.""" + a, b = backend.define_vars("a", "b", ret="scalar") + lhs = monoid.plus(a(), b(), a()) + rhs = monoid.plus(a(), b()) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=PlusDups()) + + +@pytest.mark.parametrize("monoid", [Min, Max]) +def test_plus_commutative_idempotent_long(monoid, backend: Backend): + """Long alternation collapses via commutative dedup (Min/Max only).""" + a, b = backend.define_vars("a", "b", ret="scalar") + lhs = monoid.plus(a(), b(), a(), b(), b(), a(), a()) + rhs = monoid.plus(a(), b()) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=PlusDups()) + + +@pytest.mark.parametrize("monoid", WITH_ZERO) +def test_plus_zero(monoid, backend: Backend): + a = backend.define_vars("a", ret="scalar") + lhs_right = monoid.plus(a(), monoid.zero) + lhs_left = monoid.plus(monoid.zero, a()) + rhs = monoid.zero + backend.check_rewrite(lhs=lhs_right, rhs=rhs, rule={}) + backend.check_rewrite(lhs=lhs_left, rhs=rhs, rule={}) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_partial_1(monoid, backend: Backend): + x = backend.define_vars("x", ret="scalar") + lhs = monoid.reduce(x(), {x: []}) + rhs = monoid.plus() + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReducePartial()) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_partial_2(monoid, backend: Backend): + x, y = backend.define_vars("x", "y", ret="scalar") + Y = backend.define_vars("Y", ret="stream") + + lhs = monoid.reduce(x(), {y: Y(), x: []}) + rhs = monoid.plus() + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReducePartial()) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_partial_3(monoid, backend: Backend): + x, y, a, b = backend.define_vars("x", "y", "a", "b", ret="scalar") + Y = backend.define_vars("Y", ret="stream") + + lhs = monoid.reduce(x(), {y: Y(), x: [a(), b()]}) + rhs = monoid.plus(monoid.reduce(a(), {y: Y()}), monoid.reduce(b(), {y: Y()})) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReducePartial()) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_partial_4(monoid, backend: Backend): + x, y, a, b = backend.define_vars("x", "y", "a", "b", ret="scalar") + f = backend.define_vars("f", arg_types=(backend.scalar_typ,), ret="stream") + + lhs = monoid.reduce(x(), {y: f(x()), x: [a(), b()]}) + rhs = monoid.plus(monoid.reduce(a(), {y: f(a())}), monoid.reduce(b(), {y: f(b())})) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReducePartial()) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_eliminate_singleton_into_sibling(monoid, backend: Backend): + """A length-1 stream substitutes its element into the body *and* into a + sibling stream's definition, then drops out of the nest.""" + x, y, a = backend.define_vars("x", "y", "a", ret="scalar") + f = backend.define_vars("f", arg_types=(backend.scalar_typ,), ret="stream") + g = backend.define_vars( + "g", arg_types=(backend.scalar_typ, backend.scalar_typ), ret="scalar" + ) + + lhs = monoid.reduce(g(x(), y()), {x: (a(),), y: f(x())}) + rhs = monoid.reduce(g(a(), y()), {y: f(a())}) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=EliminateSingletonStreams()) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_eliminate_singleton_only_stream(monoid, backend: Backend): + """When the length-1 stream is the only stream, reducing over the now-empty + nest yields the substituted body itself (not the monoid identity).""" + x, a = backend.define_vars("x", "a", ret="scalar") + f = backend.define_vars("f", arg_types=(backend.scalar_typ,), ret="scalar") + + lhs = monoid.reduce(f(x()), {x: (a(),)}) + rhs = f(a()) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=EliminateSingletonStreams()) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_reduce_body_sequence(monoid, backend: Backend): + x = backend.define_vars("x", ret="scalar") + X = backend.define_vars("X", ret="stream") + f, g = backend.define_vars("f", "g", arg_types=(backend.scalar_typ,), ret="scalar") + + lhs = monoid.reduce((f(x()), g(x())), {x: X()}) + rhs = (monoid.reduce(f(x()), {x: X()}), monoid.reduce(g(x()), {x: X()})) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=MonoidOverSequence()) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_reduce_body_sequence_2(monoid, backend: Backend): + x, y = backend.define_vars("x", "y", ret="scalar") + X, Y = backend.define_vars("X", "Y", ret="stream") + f, g = backend.define_vars("f", "g", arg_types=(backend.scalar_typ,), ret="scalar") + + lhs = monoid.reduce((f(x()), g(y())), {x: X(), y: Y()}) + rhs = ( + monoid.reduce(f(x()), {x: X(), y: Y()}), + monoid.reduce(g(y()), {x: X(), y: Y()}), + ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=MonoidOverSequence()) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_reduce_body_mapping(monoid, backend: Backend): + x = backend.define_vars("x", ret="scalar") + X = backend.define_vars("X", ret="stream") + f, g = backend.define_vars("f", "g", arg_types=(backend.scalar_typ,), ret="scalar") + + lhs = monoid.reduce({0: f(x()), 1: g(x())}, {x: X()}) + rhs = { + 0: monoid.reduce(f(x()), {x: X()}), + 1: monoid.reduce(g(x()), {x: X()}), + } + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=MonoidOverMapping()) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_reduce_no_streams(monoid, backend: Backend): + a = backend.define_vars("a", ret="scalar") + + lhs = monoid.reduce(a(), {}) + rhs = monoid.identity + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReducePartial()) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_reduce_reduce(monoid, backend: Backend): + a, b = backend.define_vars("a", "b", ret="scalar") + A, B = backend.define_vars("A", "B", ret="stream") + f = backend.define_vars( + "f", arg_types=(backend.scalar_typ, backend.scalar_typ), ret="scalar" + ) + + lhs = monoid.reduce(monoid.reduce(f(a(), b()), {a: A()}), {b: B()}) + rhs = monoid.reduce(f(a(), b()), {a: A(), b: B()}) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceFusion()) + + +@pytest.mark.parametrize("monoid", COMMUTATIVE) +def test_reduce_plus(monoid, backend: Backend): + a, b = backend.define_vars("a", "b", ret="scalar") + A, B = backend.define_vars("A", "B", ret="stream") + + lhs = monoid.reduce(monoid.plus(a(), b()), {a: A(), b: B()}) + rhs = monoid.plus( + monoid.reduce(a(), {a: A(), b: B()}), + monoid.reduce(b(), {a: A(), b: B()}), + ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceSplit()) + + +def test_reduce_independent_1(backend: Backend): + a, b = backend.define_vars("a", "b", ret="scalar") + A, B = backend.define_vars("A", "B", ret="stream") + + lhs = Sum.reduce(Product.plus(a(), b()), {a: A(), b: B()}) + rhs = Product.plus( + Sum.reduce(Product.plus(a()), {a: A()}), Sum.reduce(Product.plus(b()), {b: B()}) + ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceFactorization()) + + +def test_reduce_independent_2(backend: Backend): + a, b, c = backend.define_vars("a", "b", "c", ret="scalar") + A, B, C = backend.define_vars("A", "B", "C", ret="stream") + f = backend.define_vars( + "f", arg_types=(backend.scalar_typ, backend.scalar_typ), ret="scalar" + ) + + lhs = Sum.reduce(Product.plus(a(), b(), f(b(), c())), {a: A(), b: B(), c: C()}) + rhs = Product.plus( + Sum.reduce(Product.plus(a()), {a: A()}), + Sum.reduce( + Product.plus(b(), Sum.reduce(Product.plus(f(b(), c())), {c: C()})), + {b: B()}, + ), + ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceFactorization()) + + +def test_reduce_independent_3_negative(backend: Backend): + """Stream `b` depends on `a` (b: g(a())), so the proposed factorization + is unsound — the normalizer must NOT apply it.""" + a, b, c = backend.define_vars("a", "b", "c", ret="scalar") + A, C = backend.define_vars("A", "C", ret="stream") + f = backend.define_vars( + "f", arg_types=(backend.scalar_typ, backend.scalar_typ), ret="scalar" + ) + g = backend.define_vars("g", arg_types=(backend.scalar_typ,), ret="stream") + + with handler(ReduceFactorization()): # ty:ignore[invalid-argument-type] + lhs = Sum.reduce( + Product.plus(a(), b(), f(b(), c())), {a: A(), b: g(a()), c: C()} + ) + bogus_rhs = Product.plus( + Sum.reduce(a(), {a: A()}), + Sum.reduce(Product.plus(b(), f(b(), c())), {b: g(a()), c: C()}), + ) + assert fvsof(bogus_rhs) != fvsof(lhs) + assert not syntactic_eq_alpha(lhs, bogus_rhs) + + +def test_reduce_independent_4(backend: Backend): + a, b, c, d = backend.define_vars("a", "b", "c", "d", ret="scalar") + A, B, C = backend.define_vars("A", "B", "C", ret="stream") + f = backend.define_vars( + "f", arg_types=(backend.scalar_typ, backend.scalar_typ), ret="scalar" + ) + + lhs = Sum.reduce(Product.plus(a(), b(), f(b(), c()), d()), {a: A(), b: B(), c: C()}) + rhs = Product.plus( + d(), + Sum.reduce(Product.plus(a()), {a: A()}), + Sum.reduce( + Product.plus(b(), Sum.reduce(Product.plus(f(b(), c())), {c: C()})), + {b: B()}, + ), + ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceFactorization()) + + +def test_reduce_chain(backend: Backend): + x, y = backend.define_vars("x", "y", ret="scalar") + X, Y = backend.define_vars("X", "Y", ret="stream") + f, h = backend.define_vars("f", "h", arg_types=(backend.scalar_typ,), ret="scalar") + g = backend.define_vars( + "g", arg_types=(backend.scalar_typ, backend.scalar_typ), ret="scalar" + ) + + lhs = Sum.reduce(Product.plus(f(x()), g(x(), y()), h(y())), {x: X(), y: Y()}) + rhs = Sum.reduce( + Product.plus(h(y()), Sum.reduce(Product.plus(f(x()), g(x(), y())), {x: X()})), + {y: Y()}, + ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceFactorization()) + + +@pytest.mark.parametrize("outer,inner", MONOID_PAIRS) +def test_reduce_lift_shared(outer, inner, backend: Backend): + """A stream free in every factor is hoisted into an outer reduce: + Sum.reduce(f(a, c) * g(b, c), {a: A, b: B, c: C}) + = Sum.reduce(Sum.reduce(f(a, c), {a: A}) * Sum.reduce(g(b, c), {b: B}), {c: C}) + """ + a, b, c = backend.define_vars("a", "b", "c", ret="scalar") + A, B, C = backend.define_vars("A", "B", "C", ret="stream") + f, g = backend.define_vars( + "f", "g", arg_types=(backend.scalar_typ, backend.scalar_typ), ret="scalar" + ) + + lhs = outer.reduce(inner.plus(f(a(), c()), g(b(), c())), {a: A(), b: B(), c: C()}) + rhs = outer.reduce( + inner.plus( + outer.reduce(inner.plus(f(a(), c())), {a: A()}), + outer.reduce(inner.plus(g(b(), c())), {b: B()}), + ), + {c: C()}, + ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceFactorization()) + + +@pytest.mark.parametrize("outer,inner", MONOID_PAIRS) +def test_reduce_lift_shared_deps(outer, inner, backend: Backend): + """A shared stream is lifted together with its dependencies: both ``c`` + and ``d = h(c)`` appear in every factor, so both are hoisted.""" + a, b, c, d = backend.define_vars("a", "b", "c", "d", ret="scalar") + A, B, C = backend.define_vars("A", "B", "C", ret="stream") + h = backend.define_vars("h", arg_types=(backend.scalar_typ,), ret="stream") + f, g = backend.define_vars( + "f", + "g", + arg_types=(backend.scalar_typ, backend.scalar_typ, backend.scalar_typ), + ret="scalar", + ) + + lhs = outer.reduce( + inner.plus(f(a(), c(), d()), g(b(), c(), d())), + {a: A(), b: B(), c: C(), d: h(c())}, + ) + rhs = outer.reduce( + inner.plus( + outer.reduce(inner.plus(f(a(), c(), d())), {a: A()}), + outer.reduce(inner.plus(g(b(), c(), d())), {b: B()}), + ), + {c: C(), d: h(c())}, + ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceFactorization()) + + +def test_reduce_cartesian_3(): + backend = JaxBackend() + i = backend.define_vars("i", ret="scalar") + + with handler(NormalizeIntp): + value = CartesianProduct.reduce(jnp.zeros(2), {i: jnp.arange(3)}) + assert value.shape == (2**3, 3) + + with handler(NormalizeIntp): + value = CartesianProduct.reduce(jnp.zeros(2), {i: jnp.arange(1)}) + assert value.shape == (2**1, 1) + + with handler(NormalizeIntp): + value = CartesianProduct.reduce(jnp.zeros(1), {i: jnp.arange(3)}) + assert value.shape == (1**3, 3) + + +@pytest.mark.parametrize("outer,inner", MONOID_PAIRS) +def test_reduce_lifted_1(outer, inner, backend: Backend): + a, i = backend.define_vars("a", "i", ret="scalar") + A, N, A_domain = backend.define_vars("A", "N", "A_domain", ret="stream") + f = backend.define_vars("f", arg_types=(backend.scalar_typ,), ret="scalar") + + lhs = outer.reduce( + inner.reduce(f(a()), {a: A()}), + {A: CartesianProduct.reduce(A_domain(), {i: N()})}, + ) + rhs = inner.reduce(outer.reduce(inner.plus(f(a())), {a: A_domain()}), {i: N()}) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceDistributeCartesianProduct()) + + +def test_reduce_cartesian_1(): + backend = IntBackend() + a, i = backend.define_vars("a", "i", ret="scalar") + A = backend.define_vars("A", ret="stream") + + with handler(NormalizeIntp): + term1 = Sum.reduce( + Product.reduce(a(), {a: []}), + {A: CartesianProduct.reduce([], {i: []})}, + ) + term2 = Product.reduce(Sum.reduce(a(), {a: []}), {i: []}) + assert term1 == term2 + + +def test_reduce_cartesian_2(): + backend = IntBackend() + a, i = backend.define_vars("a", "i", ret="scalar") + A = backend.define_vars("A", ret="stream") + + with handler(NormalizeIntp): + term1 = Sum.reduce( + Product.reduce(a(), {a: A()}), + {A: CartesianProduct.reduce([(0,)], {i: [0]})}, + ) + term2 = Product.reduce(Sum.reduce(a(), {a: [0]}), {i: [0]}) + assert term1 == term2 + + +@pytest.mark.parametrize("outer,inner", MONOID_PAIRS) +def test_reduce_lifted_multi_index(outer, inner, backend: Backend): + a, i, j = backend.define_vars("a", "i", "j", ret="scalar") + A, N, M, A_domain = backend.define_vars("A", "N", "M", "A_domain", ret="stream") + f = backend.define_vars("f", arg_types=(backend.scalar_typ,), ret="scalar") + + lhs = outer.reduce( + inner.reduce(f(a()), {a: A()}), + {A: CartesianProduct.reduce(A_domain(), {i: N(), j: M()})}, + ) + rhs = inner.reduce( + outer.reduce(inner.plus(f(a())), {a: A_domain()}), {i: N(), j: M()} + ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceDistributeCartesianProduct()) + + +@pytest.mark.parametrize("outer,inner", MONOID_PAIRS) +def test_reduce_lifted_2(outer, inner, backend: Backend): + """The worked example on page 396 of 'Lifted Variable Elimination: + Decoupling the Operators from the Constraint Language'. + + """ + a, i, s, t = backend.define_vars("a", "i", "s", "t", ret="scalar") + A, N, T = backend.define_vars("A", "N", "T", ret="stream") + A_domain = backend.define_vars( + "A_domain", arg_types=(backend.scalar_typ,), ret="stream" + ) + f1, f2 = backend.define_vars( + "f1", "f2", arg_types=(backend.scalar_typ, backend.scalar_typ), ret="scalar" + ) + + lhs = outer.reduce( + inner.reduce(inner.plus(f1(a(), s()), f2(t(), a())), {a: A()}), + {A: CartesianProduct.reduce(A_domain(i()), {i: N()}), t: T()}, + ) + rhs = outer.reduce( + inner.reduce( + outer.reduce( + inner.plus(inner.plus(f1(a(), s()), f2(t(), a()))), {a: A_domain(i())} + ), + {i: N()}, + ), + {t: T()}, + ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceDistributeCartesianProduct()) + + +# --------------------------------------------------------------------------- +# Weighted streams +# --------------------------------------------------------------------------- + + +def test_reduce_single_weighted_stream(backend: Backend): + """Single weighted stream desugars: + Sum.reduce(body, {a: WS(A, w, Product)}) + = Sum.reduce(Product.plus(w(a), body), {a: A}) + """ + a = backend.define_vars("a", ret="scalar") + A = backend.define_vars("A", ret="stream") + body, w = backend.define_vars( + "body", "w", arg_types=(backend.scalar_typ,), ret="scalar" + ) + + lhs = Sum.reduce(body(a()), {a: Product.weighted(A(), w)}) + rhs = Sum.reduce(Product.plus(w(a()), body(a())), {a: A()}) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceWeightedStream()) + + +def test_reduce_weighted_factorization(backend: Backend): + """Two independent weighted streams under Sum with Product weights factor: + Sum.reduce(f(a)*g(b), {a: Product.weighted(A, a, w_a), b: Product.weighted(B, b, w_b)}) + = (Sum.reduce(w_a(a)*f(a), {a: A})) * (Sum.reduce(w_b(b)*g(b), {b: B})) + + Exercises chaining of ``ReduceWeightedStream`` with ``ReduceFactorization`` + inside ``NormalizeIntp``. + """ + a, b = backend.define_vars("a", "b", ret="scalar") + A, B = backend.define_vars("A", "B", ret="stream") + f, g, w_a, w_b = backend.define_vars( + "f", "g", "w_a", "w_b", arg_types=(backend.scalar_typ,), ret="scalar" + ) + + lhs = Sum.reduce( + Product.plus(f(a()), g(b())), + {a: Product.weighted(A(), w_a), b: Product.weighted(B(), w_b)}, + ) + rhs = Product.plus( + Sum.reduce(Product.plus(w_a(a()), Product.plus(f(a()))), {a: A()}), + Sum.reduce(Product.plus(w_b(b()), Product.plus(g(b()))), {b: B()}), + ) + backend.check_rewrite( + lhs=lhs, rhs=rhs, rule=coproduct(ReduceWeightedStream(), ReduceFactorization()) + ) + + +def test_reduce_cartesian_weighted_stream(backend: Backend): + """``CartesianProduct.reduce`` over a ``WeightedStream`` body whose weight + is independent of the plate var rewrites to a single joint + ``WeightedStream``: + + CartesianProduct.reduce(M.weighted(s, e, w(e)), {p: P}) + = M.weighted(CartesianProduct.reduce(s, {p: P}), row, M.reduce(w(e), {e: row()})) + """ + p, e_var = backend.define_vars("p", "e_var", ret="scalar") + S, P = backend.define_vars("S", "P", ret="stream") + w = backend.define_vars("w", arg_types=(backend.scalar_typ,), ret="scalar") + + lhs = CartesianProduct.reduce(Product.weighted(S(), w), {p: P()}) + row_var = Operation.define(Iterable[backend.scalar_typ], name="row") # type: ignore[name-defined] + rhs = Product.weighted( + CartesianProduct.reduce(S(), {p: P()}), + deffn(Product.reduce(w(e_var()), {e_var: row_var()}), row_var), + ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceCartesianWeightedStream()) + + +def test_lift_weighted_cartesian(backend: Backend): + """Compose ``ReduceCartesianWeightedStream`` + ``ReduceWeightedStream`` + + ``ReduceDistributeCartesianProduct`` on a Sum-of-Product-of-weighted shape: + + Sum.reduce( + Product.reduce(body(a()), {a: A()}), + {A: CartesianProduct.reduce(Product.weighted(S, e, w(e)), {p: P})}, + ) + + The inner ``weighted`` becomes a joint ``weighted`` (rule 1), lifts its + per-element weight into the outer Sum body (rule 2), and the lifted form + matches the inversion pattern (rule 3), yielding:: + + Product.reduce( + Sum.reduce(Product.plus(w(a()), body(a())), {a: S}), + {p: P}, + ) + """ + a, p = backend.define_vars("a", "p", ret="scalar") + A, S, P = backend.define_vars("A", "S", "P", ret="stream") + body, w = backend.define_vars( + "body", "w", arg_types=(backend.scalar_typ,), ret="scalar" + ) + + lhs = Sum.reduce( + Product.reduce(body(a()), {a: A()}), + {A: CartesianProduct.reduce(Product.weighted(S(), w), {p: P()})}, + ) + rhs = Product.reduce( + Sum.reduce(Product.plus(w(a()), body(a())), {a: S()}), {p: P()} + ) + backend.check_rewrite( + lhs=lhs, + rhs=rhs, + rule=coproduct( + coproduct(ReduceWeightedStream(), ReduceCartesianWeightedStream()), + ReduceDistributeCartesianProduct(), + ), + ) + + +def test_weighted_expectation_demo(): + """Demo: compute E[f(X)] = Σ_x w(x)·f(x) via a weighted reduce. + + X ranges over [1, 2, 3, 4] with weights w(x) = x/10 (a valid distribution + since the weights sum to 1) and f(x) = x*x. Expected value: + 0.1·1 + 0.2·4 + 0.3·9 + 0.4·16 = 10.0 + """ + weights = {1: 0.1, 2: 0.2, 3: 0.3, 4: 0.4} + + def _w(v: int) -> float: + if isinstance(v, Term): + raise NotHandled + return weights[v] + + def _f(v: int) -> float: + if isinstance(v, Term): + raise NotHandled + return float(v * v) + + a = Operation.define(int, name="a") + w = Operation.define(_w, name="w") + f = Operation.define(_f, name="f") + + with handler(NormalizeIntp): + result = evaluate(Sum.reduce(f(a()), {a: Product.weighted([1, 2, 3, 4], w)})) + + assert math.isclose(result, 10.0) diff --git a/tests/test_ops_syntax.py b/tests/test_ops_syntax.py index 185b6132e..1f5c47763 100644 --- a/tests/test_ops_syntax.py +++ b/tests/test_ops_syntax.py @@ -489,7 +489,6 @@ def _(self, x: bool) -> bool: ) assert isinstance(term_float, Term) - assert term_float.op.__name__ == "my_singledispatch" assert term_float.args == (1.5,) assert term_float.kwargs == {}