Skip to content

Commit 307ea66

Browse files
authored
feat: replay an execution (#1211)
Signed-off-by: Louis Mandel <[email protected]>
1 parent b703453 commit 307ea66

File tree

4 files changed

+98
-18
lines changed

4 files changed

+98
-18
lines changed

src/pdl/pdl.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ class InterpreterConfig(TypedDict, total=False):
4747
"""
4848
cwd: Path
4949
"""Path considered as the current working directory for file reading."""
50+
replay: dict[str, Any]
51+
"""Execute the program reusing some already computed values.
52+
"""
5053

5154

5255
def exec_program(
@@ -66,9 +69,10 @@ def exec_program(
6669
output: Configure the output of the returned value of this function. Defaults to `"result"`
6770
6871
Returns:
69-
Return the final result if `output` is set to `"result"`. If set of `all`, it returns a dictionary containing, `result`, `scope`, and `trace`.
72+
Return the final result if `output` is set to `"result"`. If set of `all`, it returns a dictionary containing, `result`, `scope`, `trace`, and `replay`.
7073
"""
71-
config = config or {}
74+
config = config or InterpreterConfig()
75+
config["replay"] = dict(config.get("replay", {}))
7276
state = InterpreterState(**config)
7377
if not isinstance(scope, PdlDict):
7478
scope = PdlDict(scope or {})
@@ -83,7 +87,12 @@ def exec_program(
8387
return result
8488
case "all":
8589
scope = future_scope.result()
86-
return {"result": result, "scope": scope, "trace": trace}
90+
return {
91+
"result": result,
92+
"scope": scope,
93+
"trace": trace,
94+
"replay": state.replay,
95+
}
8796
case _:
8897
assert False, 'The `output` variable should be "result" or "all"'
8998

src/pdl/pdl_interpreter.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ class InterpreterState(BaseModel):
188188
"""Event loop to schedule LLM calls."""
189189
current_pdl_context: Ref[LazyMessages] = Ref(DependentContext([]))
190190
"""Current value of the context set at the beginning of the execution of the block."""
191+
replay: dict[str, Any] = {}
191192

192193
def with_yield_result(self: "InterpreterState", b: bool) -> "InterpreterState":
193194
return self.model_copy(update={"yield_result": b})
@@ -305,7 +306,7 @@ def process_prog(
305306
stdlib_file = Path(__file__).parent / "pdl_stdlib.pdl"
306307
stdlib, _ = parse_file(stdlib_file)
307308
_, _, stdlib_dict, _ = process_block(
308-
state.with_yield_background(False).with_yield_result(False),
309+
state.with_yield_background(False).with_yield_result(False).with_id("stdlib"),
309310
empty_scope,
310311
stdlib.root,
311312
loc,
@@ -505,7 +506,7 @@ def process_advance_block_retry( # noqa: C901
505506
trial_total = max_retry + 1
506507
for trial_idx in range(trial_total): # pylint: disable=too-many-nested-blocks
507508
try:
508-
result, background, new_scope, trace = process_block_body(
509+
result, background, new_scope, trace = process_block_body_with_replay(
509510
state, scope, block, loc
510511
)
511512

@@ -640,6 +641,42 @@ def result_with_type_checking(
640641
return result
641642

642643

644+
def process_block_body_with_replay(
645+
state: InterpreterState,
646+
scope: ScopeType,
647+
block: AdvancedBlockType,
648+
loc: PdlLocationType,
649+
) -> tuple[PdlLazy[Any], LazyMessages, ScopeType, AdvancedBlockType]:
650+
if isinstance(block, LeafBlock):
651+
block_id = block.pdl__id
652+
assert isinstance(block_id, str)
653+
try:
654+
result = state.replay[block_id]
655+
background: LazyMessages = SingletonContext(
656+
PdlDict({"role": state.role, "content": result})
657+
)
658+
if state.yield_result:
659+
yield_result(result.result(), block.kind)
660+
if state.yield_background:
661+
yield_background(background)
662+
trace = block
663+
# Special case
664+
match block:
665+
case ModelBlock():
666+
if block.modelResponse is not None:
667+
assert block.pdl__id is not None
668+
raw_result = state.replay[block.pdl__id + ".modelResponse"]
669+
scope = scope | {block.modelResponse: raw_result}
670+
except KeyError:
671+
result, background, scope, trace = process_block_body(
672+
state, scope, block, loc
673+
)
674+
state.replay[block_id] = result
675+
else:
676+
result, background, scope, trace = process_block_body(state, scope, block, loc)
677+
return result, background, scope, trace
678+
679+
643680
def process_block_body(
644681
state: InterpreterState,
645682
scope: ScopeType,
@@ -1815,6 +1852,8 @@ def get_transformed_inputs(kwargs):
18151852
)
18161853
if block.modelResponse is not None:
18171854
scope = scope | {block.modelResponse: raw_result}
1855+
assert block.pdl__id is not None
1856+
state.replay[block.pdl__id + ".modelResponse"] = raw_result
18181857
trace: BlockTypeTVarProcessCallModel = concrete_block.model_copy(
18191858
update={"pdl__result": result}
18201859
) # pyright: ignore

tests/data/function.pdl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,3 @@ defs:
1313
${ notes }
1414

1515
### Answer:
16-
17-
18-

tests/test_examples_run.py

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import random
55
from dataclasses import dataclass, field
66
from enum import Enum
7-
from typing import Dict, List, Optional
7+
from typing import Dict, List, Optional, Tuple
88

99
import yaml
1010
from pytest import CaptureFixture, MonkeyPatch
@@ -104,6 +104,7 @@ class FailedResults:
104104
wrong_results: Dict[str, str] = field(default_factory=lambda: {})
105105
unexpected_parse_error: Dict[str, str] = field(default_factory=lambda: {})
106106
unexpected_runtime_error: Dict[str, str] = field(default_factory=lambda: {})
107+
wrong_replay_results: Dict[str, str] = field(default_factory=lambda: {})
107108

108109

109110
# pylint: disable=too-many-instance-attributes
@@ -161,7 +162,9 @@ def __init__(self, monkeypatch: MonkeyPatch) -> None:
161162
self.__collect_expected_results()
162163

163164
# Inits execution results for each PDL file
164-
self.execution_results: Dict[str, ExecutionResult] = {}
165+
self.execution_results: Dict[
166+
str, Tuple[ExecutionResult, ExecutionResult | None]
167+
] = {}
165168

166169
# Init failed results
167170
self.failed_results = FailedResults()
@@ -199,13 +202,11 @@ def __collect_expected_results(self) -> None:
199202

200203
self.expected_results[file] = expected_result
201204

202-
def __execute_file(self, pdl_file_name: str) -> None:
205+
def __execute_and_replay_file(self, pdl_file_name: str) -> None:
203206
"""
204207
Tests the result of a single file and returns the result output and the error code
205208
"""
206209

207-
exec_result = ExecutionResult()
208-
209210
pdl_file_path = pathlib.Path(pdl_file_name)
210211
scope: ScopeType = PdlDict({})
211212

@@ -217,13 +218,27 @@ def __execute_file(self, pdl_file_name: str) -> None:
217218
if inputs.scope is not None:
218219
scope = inputs.scope
219220

221+
exec_result, output = self.__execute_file(pdl_file_path, scope, replay={})
222+
223+
if output is not None:
224+
replay_result, _ = self.__execute_file(
225+
pdl_file_path, scope, replay=output["replay"]
226+
)
227+
else:
228+
replay_result = None
229+
230+
self.execution_results[pdl_file_name] = exec_result, replay_result
231+
232+
def __execute_file(self, pdl_file_path, scope, replay):
233+
exec_result = ExecutionResult()
234+
output = None
220235
try:
221236
# Execute file
222237
output = pdl.exec_file(
223238
pdl_file_path,
224239
scope=scope,
225240
output="all",
226-
config=pdl.InterpreterConfig(batch=1),
241+
config=pdl.InterpreterConfig(batch=1, replay=replay),
227242
)
228243

229244
exec_result.result = str(output["result"])
@@ -235,8 +250,7 @@ def __execute_file(self, pdl_file_name: str) -> None:
235250
except Exception as exc:
236251
exec_result.result = str(exc)
237252
exec_result.error_code = ExecutionErrorCode.RUNTIME_ERROR
238-
239-
self.execution_results[pdl_file_name] = exec_result
253+
return exec_result, output
240254

241255
def populate_exec_result_for_checks(self) -> None:
242256
"""
@@ -245,7 +259,7 @@ def populate_exec_result_for_checks(self) -> None:
245259

246260
for file in self.check:
247261
if file not in self.skip:
248-
self.__execute_file(file)
262+
self.__execute_and_replay_file(file)
249263

250264
def validate_expected_and_actual(self) -> None:
251265
"""
@@ -256,11 +270,12 @@ def validate_expected_and_actual(self) -> None:
256270
wrong_result: Dict[str, str] = {}
257271
unexpected_parse_error: Dict[str, str] = {}
258272
unexpected_runtime_error: Dict[str, str] = {}
273+
wrong_replay_result: Dict[str, str] = {}
259274

260275
for file in self.check:
261276
if file not in self.skip:
262277
expected_result = self.expected_results[file]
263-
actual_result = self.execution_results[file]
278+
actual_result, replay_result = self.execution_results[file]
264279
match = expected_result.compare_to_execution(actual_result)
265280

266281
if not match:
@@ -274,7 +289,14 @@ def validate_expected_and_actual(self) -> None:
274289
if actual_result.result is not None:
275290
wrong_result[file] = actual_result.result
276291

292+
if replay_result is not None:
293+
match_replay = expected_result.compare_to_execution(replay_result)
294+
if not match_replay:
295+
if replay_result.result is not None:
296+
wrong_replay_result[file] = replay_result.result
297+
277298
self.failed_results.wrong_results = wrong_result
299+
self.failed_results.wrong_replay_results = wrong_replay_result
278300
self.failed_results.unexpected_parse_error = unexpected_parse_error
279301
self.failed_results.unexpected_runtime_error = unexpected_runtime_error
280302

@@ -347,6 +369,16 @@ def test_example_runs(capsys: CaptureFixture[str], monkeypatch: MonkeyPatch) ->
347369
f"Actual result (copy everything below this line):\n✂️ ------------------------------------------------------------\n{actual}\n-------------------------------------------------------------"
348370
)
349371

372+
# Print the actual results for wrong replay results
373+
for file, actual in background.failed_results.wrong_replay_results.items():
374+
print(
375+
"\n============================================================================"
376+
)
377+
print(f"File that produced wrong REPLAY result: {file}")
378+
print(
379+
f"Replay result:\n ------------------------------------------------------------\n{actual}\n-------------------------------------------------------------"
380+
)
381+
350382
assert (
351383
len(background.failed_results.unexpected_parse_error) == 0
352384
), f"Unexpected parse error: {background.failed_results.unexpected_parse_error}"
@@ -356,3 +388,6 @@ def test_example_runs(capsys: CaptureFixture[str], monkeypatch: MonkeyPatch) ->
356388
assert (
357389
len(background.failed_results.wrong_results) == 0
358390
), f"Wrong results: {background.failed_results.wrong_results}"
391+
assert (
392+
len(background.failed_results.wrong_replay_results) == 0
393+
), f"Wrong replay results: {background.failed_results.wrong_results}"

0 commit comments

Comments
 (0)