Skip to content

Commit ba68668

Browse files
committed
test: add unit test to catch caching behavior
1 parent 6217bf9 commit ba68668

File tree

3 files changed

+114
-10
lines changed

3 files changed

+114
-10
lines changed

marimo/_ast/app.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -336,15 +336,15 @@ def cell(
336336
337337
```
338338
@app.cell
339-
def __(mo):
339+
def _(mo):
340340
# ...
341341
342342
@app.cell()
343-
def __(mo):
343+
def _(mo):
344344
# ...
345345
346346
@app.cell(disabled=True)
347-
def __(mo):
347+
def _(mo):
348348
# ...
349349
```
350350
@@ -759,7 +759,6 @@ async def _function_call(
759759
async def embed(
760760
self,
761761
defs: dict[str, Any] | None = None,
762-
**kwargs: Any,
763762
) -> AppEmbedResult:
764763
"""Embed a notebook into another notebook.
765764
@@ -823,16 +822,13 @@ async def embed(
823822
arguments. marimo will use these values instead of executing
824823
the cells that would normally define them. Cells that depend
825824
on these variables will use your provided values.
826-
**kwargs (Any):
827-
For forward-compatibility with future arguments.
828825
829826
Returns:
830827
An object `result` with two attributes: `result.output` (visual
831828
output of the notebook) and `result.defs` (a dictionary mapping
832829
variable names defined by the notebook to their values).
833830
834831
"""
835-
del kwargs
836832
from marimo._plugins.stateless.flex import vstack
837833
from marimo._runtime.context.utils import running_in_notebook
838834

tests/_ast/test_app.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1150,6 +1150,108 @@ def __(x: int, y: int) -> None:
11501150
assert result.defs["y"] == 10 # y cell still ran
11511151
assert "x=100, y=10" in result.output.text
11521152

1153+
async def test_app_embed_with_defs_stale_outputs(self) -> None:
1154+
"""Test that embed() doesn't return stale cached outputs with different defs."""
1155+
app = App()
1156+
1157+
@app.cell
1158+
def __() -> tuple[int]:
1159+
x = 10
1160+
return (x,)
1161+
1162+
@app.cell
1163+
def __(x: int) -> None:
1164+
"x is small" if x == 10 else "x is large"
1165+
1166+
# First call - no override
1167+
result_initial = await app.embed()
1168+
assert result_initial.defs["x"] == 10
1169+
assert "x is small" in result_initial.output.text
1170+
1171+
# Second call - with first override
1172+
result_override = await app.embed(defs={"x": 100})
1173+
assert result_override.defs["x"] == 100
1174+
assert "x is large" in result_override.output.text
1175+
assert "x is small" not in result_override.output.text
1176+
1177+
# Third call - with second override
1178+
result_override2 = await app.embed(defs={"x": 200})
1179+
assert result_override2.defs["x"] == 200
1180+
assert "x is large" in result_override2.output.text
1181+
assert "x is small" not in result_override2.output.text
1182+
1183+
# Check that initial result wasn't mutated by subsequent calls
1184+
assert result_initial.defs["x"] == 10
1185+
assert "x is small" in result_initial.output.text
1186+
assert "x is large" not in result_initial.output.text
1187+
1188+
async def test_app_embed_with_defs_stale_outputs_kernel(
1189+
self, k: Kernel, exec_req: ExecReqProvider
1190+
) -> None:
1191+
"""Test embed() with different defs through kernel (tests caching code path)."""
1192+
await k.run(
1193+
[
1194+
exec_req.get(
1195+
"""
1196+
from marimo import App
1197+
1198+
app = App()
1199+
1200+
@app.cell
1201+
def __() -> tuple[int]:
1202+
x = 10
1203+
return (x,)
1204+
1205+
@app.cell
1206+
def __(x: int) -> None:
1207+
"x is small" if x == 10 else "x is large"
1208+
"""
1209+
),
1210+
exec_req.get(
1211+
"""
1212+
# First call - no override
1213+
result_initial = await app.embed()
1214+
"""
1215+
),
1216+
exec_req.get(
1217+
"""
1218+
# Second call - with first override
1219+
result_override = await app.embed(defs={"x": 100})
1220+
"""
1221+
),
1222+
exec_req.get(
1223+
"""
1224+
# Third call - with second override
1225+
result_override2 = await app.embed(defs={"x": 200})
1226+
"""
1227+
),
1228+
]
1229+
)
1230+
assert not k.errors
1231+
1232+
result_initial = k.globals["result_initial"]
1233+
result_override = k.globals["result_override"]
1234+
result_override2 = k.globals["result_override2"]
1235+
1236+
# Check first result - output then defs
1237+
assert "x is small" in result_initial.output.text
1238+
assert result_initial.defs["x"] == 10
1239+
1240+
# Check second result with first override - output then defs
1241+
assert "x is large" in result_override.output.text
1242+
assert "x is small" not in result_override.output.text
1243+
assert result_override.defs["x"] == 100
1244+
1245+
# Check third result with second override - output then defs
1246+
assert "x is large" in result_override2.output.text
1247+
assert "x is small" not in result_override2.output.text
1248+
assert result_override2.defs["x"] == 200
1249+
1250+
# Check that initial result wasn't mutated by subsequent calls
1251+
assert "x is small" in result_initial.output.text
1252+
assert "x is large" not in result_initial.output.text
1253+
assert result_initial.defs["x"] == 10
1254+
11531255
@pytest.mark.xfail(
11541256
True, reason="Flaky in CI, can't repro locally", strict=False
11551257
)

tests/_runtime/test_dataflow.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1445,7 +1445,9 @@ def test_prune_cells_for_overrides_single_cell() -> None:
14451445
execution_order = ["0", "1", "2"]
14461446

14471447
# Override x - should prune cell 0
1448-
result = dataflow.prune_cells_for_overrides(graph, execution_order, {"x": 100})
1448+
result = dataflow.prune_cells_for_overrides(
1449+
graph, execution_order, {"x": 100}
1450+
)
14491451
assert result == ["1", "2"]
14501452

14511453

@@ -1537,7 +1539,9 @@ def test_prune_cells_for_overrides_partial_override() -> None:
15371539
execution_order = ["0", "1", "2"]
15381540

15391541
# Override only x - should prune only cell 0
1540-
result = dataflow.prune_cells_for_overrides(graph, execution_order, {"x": 100})
1542+
result = dataflow.prune_cells_for_overrides(
1543+
graph, execution_order, {"x": 100}
1544+
)
15411545
assert result == ["1", "2"]
15421546

15431547

@@ -1581,5 +1585,7 @@ def test_prune_cells_for_overrides_preserves_order() -> None:
15811585
execution_order = ["0", "1", "2", "3"]
15821586

15831587
# Override b - should prune only cell 1, preserving order
1584-
result = dataflow.prune_cells_for_overrides(graph, execution_order, {"b": 100})
1588+
result = dataflow.prune_cells_for_overrides(
1589+
graph, execution_order, {"b": 100}
1590+
)
15851591
assert result == ["0", "2", "3"]

0 commit comments

Comments
 (0)