Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 26 additions & 6 deletions effectful/handlers/llm/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
compile_restricted,
safe_globals,
)
from RestrictedPython.PrintCollector import PrintCollector

from effectful.internals.unification import nested_type
from effectful.ops.syntax import ObjectInterpretation, defop, implements
Expand Down Expand Up @@ -724,6 +725,16 @@ def exec(
builtins.exec(bytecode, env, env)


class _StdoutPrintCollector(PrintCollector):
"""`_print_` factory whose `print(...)` writes to the real `sys.stdout`
(so output-capturing callers see it) rather than accumulating into the
collector's discarded `printed` buffer."""

def _call_print(self, *objects, **kwargs):
kwargs.setdefault("file", sys.stdout)
builtins.print(*objects, **kwargs)


class RestrictedEvalProvider(ObjectInterpretation):
"""
Safer provider using RestrictedPython.
Expand Down Expand Up @@ -800,12 +811,21 @@ def exec(
rglobals["setattr"] = Guards.guarded_setattr
rglobals["_write_"] = lambda x: x

# Track keys before execution to identify new definitions
keys_before = set(rglobals.keys())
# RestrictedPython rewrites `print(...)` into its `_print_` collector
# protocol; route it to the real stdout so output-capturing callers
# (e.g. redirect_stdout) see it instead of a discarded collector.
rglobals["_print_"] = _StdoutPrintCollector

# Snapshot value identities before execution so we can copy back every
# *binding effect* — both new names and rebindings of seeded names.
before = dict(rglobals)
builtins.exec(bytecode, rglobals, rglobals)

# Copy newly defined items back to env so caller can access them
for key in rglobals:
if key not in keys_before:
env[key] = rglobals[key]
sentinel = object()
env.update(
{
key: value
for key, value in rglobals.items()
if key != "__builtins__" and before.get(key, sentinel) is not value
}
)
57 changes: 57 additions & 0 deletions tests/test_handlers_llm_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import ast
import builtins
import contextlib
import inspect
import io
import sys
import textwrap
import types
Expand All @@ -25,6 +27,9 @@
mypy_type_check,
type_to_ast,
)
from effectful.handlers.llm.evaluation import compile as compile_op
from effectful.handlers.llm.evaluation import exec as exec_op
from effectful.handlers.llm.evaluation import parse as parse_op
from effectful.internals.unification import nested_type
from effectful.ops.semantics import handler
from effectful.ops.syntax import defop
Expand Down Expand Up @@ -1537,3 +1542,55 @@ def test_builtins_in_env_does_not_bypass_security():
source_private, context=dangerous_ctx
)
fn("test")


# ============================================================================
# RestrictedEvalProvider state-retention and print (#685)
# ============================================================================


def _restricted_run(source: str, ns: dict, capture: bool = False) -> str | None:
"""Run one snippet through the parse/compile/exec ops under
RestrictedEvalProvider, optionally capturing stdout."""
with handler(RestrictedEvalProvider()):
code = compile_op(parse_op(source, "<f>"), "<f>")
if capture:
buf = io.StringIO()
with contextlib.redirect_stdout(buf):
exec_op(code, ns)
return buf.getvalue()
exec_op(code, ns)
return None


def test_restricted_exec_copies_back_rebound_seed():
"""#685: rebinding a name already present in the namespace writes the new
value back, not just never-before-seen names."""
ns = {"x": 1}
_restricted_run("x = 99", ns)
assert ns["x"] == 99


def test_restricted_exec_copies_back_rebound_new_key():
"""#685: a key that becomes a 'seed' after its first definition is still
rebindable on subsequent calls."""
ns: dict = {}
_restricted_run("y = 1", ns)
_restricted_run("y = 2", ns)
assert ns["y"] == 2


def test_restricted_exec_persists_and_rebinds_across_calls():
"""#685: the namespace is a real REPL session — a binding from one call is
usable in the next, and rebinding it using its prior value works."""
ns: dict = {}
_restricted_run("x = 10", ns)
_restricted_run("x = x + 1", ns)
assert ns["x"] == 11


def test_restricted_exec_print_captured_to_stdout():
"""#685: RestrictedPython's `print` is routed to the real stdout so
output-capturing callers see it (rather than NameError on `_print_`)."""
out = _restricted_run("print('hi')", {}, capture=True)
assert out == "hi\n"
Loading