Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions effectful/internals/unification.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,17 @@ def _unify_generic(typ, subtyp, subs: Substitutions) -> Substitutions:
typing.get_origin(typ), collections.abc.Generator
):
return unify(typing.get_args(typ)[0], typing.get_args(subtyp)[0], subs)
elif typing.get_origin(subtyp) is effectful.ops.types.Operation and not (
isinstance(typing.get_origin(typ), type)
and issubclass(typing.get_origin(typ), effectful.ops.types.Operation)
):
# An Operation[P, R] is a Callable[P, R] (gh #669): unify the pattern
# against the operation's parameter/return signature. ``Operation``'s
# args are (params, return) just like ``Callable``'s, except params is
# a tuple (or ``...``) rather than a list.
op_params, op_ret = typing.get_args(subtyp)
callable_params = op_params if op_params is ... else list(op_params)
return unify(typ, collections.abc.Callable[callable_params, op_ret], subs) # type: ignore
elif typing.get_origin(typ) == typing.get_origin(subtyp):
return unify(typing.get_args(typ), typing.get_args(subtyp), subs)
elif types.get_original_bases(typing.get_origin(subtyp)):
Expand All @@ -556,6 +567,17 @@ def _unify_generic(typ, subtyp, subs: Substitutions) -> Substitutions:
and issubclass(subtyp, typing.get_origin(typ))
):
return subs # implicit expansion to subtyp[Any]
elif isinstance(typ, GenericAlias):

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is from @jfeser in #656

# Special case for treating arrays as iterables of arrays
try:
import jax

if typing.get_origin(typ) is collections.abc.Iterable and issubclass(
subtyp, jax.Array
):
return unify(typing.get_args(typ)[0], jax.Array, subs)
except ImportError:
pass
raise TypeError(f"Cannot unify generic type {typ} with {subtyp} given {subs}.")


Expand Down
66 changes: 66 additions & 0 deletions tests/test_internals_unification.py
Original file line number Diff line number Diff line change
Expand Up @@ -1967,3 +1967,69 @@ class Info(typing.TypedDict):

subs = unify(collections.abc.Mapping, Info)
assert subs == {}


def test_unify_jax_array_iterable():
import jax

subs = unify(collections.abc.Iterable[T], jax.Array)
assert subs == {T: jax.Array}


def test_unify_operation_callable():
"""An ``Operation[P, R]`` unifies as a ``Callable[P, R]`` (gh #669)."""
from effectful.ops.types import Operation

# TypeVar params bind to the operation's parameter/return types
assert unify(collections.abc.Callable[[T], V], Operation[[int], int]) == {
T: int,
V: int,
}
# a repeated TypeVar binds consistently
assert unify(collections.abc.Callable[[T], T], Operation[[int], int]) == {T: int}
# multiple parameters
assert unify(collections.abc.Callable[[T, U], V], Operation[[int, str], bool]) == {
T: int,
U: str,
V: bool,
}
# ``...`` parameters in the pattern ignore the operation's parameter types
assert unify(collections.abc.Callable[..., V], Operation[[int], int]) == {V: int}
# fully concrete: nothing to bind
assert unify(collections.abc.Callable[[int], int], Operation[[int], int]) == {}
# nested: an operation-valued argument
assert unify(
collections.abc.Callable[[T], list[V]], Operation[[int], list[str]]
) == {T: int, V: str}


def test_unify_operation_callable_failure():
"""An arity mismatch between the Callable pattern and the Operation fails."""
from effectful.ops.types import Operation

with pytest.raises(TypeError):
unify(collections.abc.Callable[[T, U], V], Operation[[int], int])
with pytest.raises(TypeError):
unify(collections.abc.Callable[[T], V], Operation[[int, str], bool])


def test_operation_unifies_with_callable_param_gh669():
"""An Operation passed where a ``Callable`` is expected infers correctly.

Regression test for gh #669: calling an operation whose parameter is typed
``Callable[[S], T]`` with another operation should unify and infer the return
type, rather than raising ``Cannot unify generic type ...``.
"""
from effectful.ops.semantics import typeof
from effectful.ops.types import NotHandled, Operation

@Operation.define
def f(x: int) -> int:
raise NotHandled

@Operation.define
def g[S, R](x: collections.abc.Callable[[S], R]) -> R:
raise NotHandled

term = g(f)
assert typeof(term) is int
Loading