Skip to content

Commit 6d1e87e

Browse files
neilSchroeder and neilSchroeder authored
Visualize plans that exceed memory (#846)
* Refactor memory check method to find operations exceeding memory limits
* Refactor memory check to identify operations exceeding allowed memory limits
* Add operation memory check to finalize plan creation
* Add optional parameter for operations exceeding memory in FinalizedPlan constructor
* Add handling for operations exceeding memory in FinalizedPlan
* Add memory validation and reporting in FinalizedPlan
* Add validation call in FinalizedPlan constructor
* Add warning for operations exceeding memory in FinalizedPlan visualization
* Add warnings import for enhanced memory management in FinalizedPlan
* Refactor imports in plan.py for improved organization
* Add HTML warning for memory exceeded in FinalizedPlan visualization
* Refactor FinalizedPlan graph label to use a predefined variable
* Add missing line break for improved readability in FinalizedPlan class
* Highlight operations exceeding memory in red within FinalizedPlan visualization
* Add test for plan exceeding memory in default spec
* Add visualization test for memory exceeded warning in default spec
* lint
* Update memory exceeded warning text format in FinalizedPlan class

---------

Co-authored-by: neilSchroeder <[email protected]>
1 parent dd0ada4 commit 6d1e87e

File tree

2 files changed

+110
-29
lines changed

2 files changed

+110
-29
lines changed

cubed/core/plan.py

Lines changed: 91 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -4,10 +4,11 @@
44
import shutil
55
import tempfile
66
import uuid
7+
import warnings
78
from datetime import datetime
89
from enum import Enum
910
from functools import lru_cache
10-
from typing import Any, Callable, Dict, Optional
11+
from typing import Any, Callable, Dict, List, Optional, Tuple
1112

1213
import networkx as nx
1314

@@ -271,25 +272,21 @@ def _compile_blockwise(self, dag, compile_function: Decorator) -> nx.MultiDiGrap
271272

272273
return dag
273274

274-
def _check_projected_mem(self, dag) -> None:
275-
op_name = None
276-
max_projected_mem_op = None
275+
def _find_ops_exceeding_memory(self, dag) -> List[Tuple[str, "PrimitiveOperation"]]:
276+
"""Find all operations where projected memory exceeds allowed memory.
277+
278+
Returns a list of (op_name, primitive_op) tuples for operations that
279+
exceed memory limits, sorted by projected memory (highest first).
280+
"""
281+
ops_exceeding = []
277282
for n, d in dag.nodes(data=True):
278283
if "primitive_op" in d:
279284
op = d["primitive_op"]
280-
if (
281-
max_projected_mem_op is None
282-
or op.projected_mem > max_projected_mem_op.projected_mem
283-
):
284-
op_name = n
285-
max_projected_mem_op = op
286-
if max_projected_mem_op is not None:
287-
op = max_projected_mem_op
288-
if op.projected_mem > op.allowed_mem:
289-
raise ValueError(
290-
f"Projected blockwise memory ({memory_repr(op.projected_mem)}) exceeds allowed_mem ({memory_repr(op.allowed_mem)}), "
291-
f"including reserved_mem ({memory_repr(op.reserved_mem)}) for {op_name}"
292-
)
285+
if op.projected_mem > op.allowed_mem:
286+
ops_exceeding.append((n, op))
287+
# Sort by projected_mem descending so worst offenders are first
288+
ops_exceeding.sort(key=lambda x: x[1].projected_mem, reverse=True)
289+
return ops_exceeding
293290

294291
@lru_cache # noqa: B019
295292
def _finalize(
@@ -304,8 +301,10 @@ def _finalize(
304301
if callable(compile_function):
305302
dag = self._compile_blockwise(dag, compile_function)
306303
dag = self._create_lazy_zarr_arrays(dag)
307-
self._check_projected_mem(dag)
308-
return FinalizedPlan(nx.freeze(dag), self.array_names, optimize_graph)
304+
ops_exceeding_memory = self._find_ops_exceeding_memory(dag)
305+
return FinalizedPlan(
306+
nx.freeze(dag), self.array_names, optimize_graph, ops_exceeding_memory
307+
)
309308

310309

311310
class ArrayRole(Enum):
@@ -324,10 +323,11 @@ class FinalizedPlan:
324323
4. freezing the final DAG so it can't be changed
325324
"""
326325

327-
def __init__(self, dag, array_names, optimized):
326+
def __init__(self, dag, array_names, optimized, ops_exceeding_memory=None):
328327
self.dag = dag
329328
self.array_names = array_names
330329
self.optimized = optimized
330+
self._ops_exceeding_memory = ops_exceeding_memory or []
331331
self._calculate_stats()
332332

333333
self.input_array_names = []
@@ -540,6 +540,34 @@ def total_nchunks(self) -> int:
540540
"""The total number of chunks for all materialized arrays in this plan."""
541541
return self._total_nchunks
542542

543+
@property
544+
def exceeds_memory(self) -> bool:
545+
"""True if any operation in this plan exceeds the allowed memory."""
546+
return len(self._ops_exceeding_memory) > 0
547+
548+
@property
549+
def ops_exceeding_memory(self) -> List[Tuple[str, "PrimitiveOperation"]]:
550+
"""List of (op_name, primitive_op) tuples for operations exceeding memory.
551+
552+
Sorted by projected memory (highest first).
553+
"""
554+
return self._ops_exceeding_memory
555+
556+
def validate(self) -> None:
557+
"""Validate that this plan can be executed.
558+
559+
Raises
560+
------
561+
ValueError
562+
If any operation's projected memory exceeds the allowed memory.
563+
"""
564+
if self._ops_exceeding_memory:
565+
op_name, op = self._ops_exceeding_memory[0] # Report worst offender
566+
raise ValueError(
567+
f"Projected blockwise memory ({memory_repr(op.projected_mem)}) exceeds allowed_mem ({memory_repr(op.allowed_mem)}), "
568+
f"including reserved_mem ({memory_repr(op.reserved_mem)}) for {op_name}"
569+
)
570+
543571
def execute(
544572
self,
545573
executor=None,
@@ -548,6 +576,8 @@ def execute(
548576
spec=None,
549577
**kwargs,
550578
):
579+
self.validate()
580+
551581
dag = self.dag
552582

553583
if resume:
@@ -580,6 +610,15 @@ def visualize(
580610
rankdir="TB",
581611
show_hidden=False,
582612
):
613+
if self._ops_exceeding_memory:
614+
op_names = [name for name, _ in self._ops_exceeding_memory]
615+
warnings.warn(
616+
f"Plan has {len(self._ops_exceeding_memory)} operation(s) that exceed allowed memory: {op_names}. "
617+
"These are shown in red in the visualization.",
618+
stacklevel=2,
619+
)
620+
ops_exceeding_names = {name for name, _ in self._ops_exceeding_memory}
621+
583622
dag = self.dag.copy() # make a copy since we mutate the DAG below
584623

585624
# remove edges from create-arrays output node to avoid cluttering the diagram
@@ -590,19 +629,39 @@ def visualize(
590629
list(n for n, d in dag.nodes(data=True) if d.get("hidden", False))
591630
)
592631

632+
# Build the graph label - use HTML-like label for mixed colors if memory exceeded
633+
stats_text = (
634+
f"num tasks: {self.num_tasks}<BR ALIGN='LEFT'/>"
635+
f"max projected memory: {memory_repr(self.max_projected_mem)}<BR ALIGN='LEFT'/>"
636+
f"total nbytes written: {memory_repr(self.total_nbytes_written)}<BR ALIGN='LEFT'/>"
637+
f"optimized: {self.optimized}<BR ALIGN='LEFT'/>"
638+
)
639+
640+
if self._ops_exceeding_memory:
641+
# Build warning text in red
642+
warning_lines = [
643+
"<BR ALIGN='LEFT'/>!!! MEMORY EXCEEDED !!!<BR ALIGN='LEFT'/>"
644+
]
645+
for op_name, op in self._ops_exceeding_memory:
646+
warning_lines.append(
647+
f"{op_name}: requires {memory_repr(op.projected_mem)}, "
648+
f"allowed {memory_repr(op.allowed_mem)}<BR ALIGN='LEFT'/>"
649+
)
650+
warning_text = "".join(warning_lines)
651+
# HTML-like label with mixed colors
652+
label = f"<<FONT>{stats_text}</FONT><FONT COLOR='#cc0000'>{warning_text}</FONT>>"
653+
else:
654+
# Simple HTML label (no warning)
655+
label = f"<{stats_text}>"
656+
593657
dag.graph["graph"] = {
594658
"rankdir": rankdir,
595-
"label": (
596-
# note that \l is used to left-justify each line (see https://www.graphviz.org/docs/attrs/nojustify/)
597-
rf"num tasks: {self.num_tasks}\l"
598-
rf"max projected memory: {memory_repr(self.max_projected_mem)}\l"
599-
rf"total nbytes written: {memory_repr(self.total_nbytes_written)}\l"
600-
rf"optimized: {self.optimized}\l"
601-
),
659+
"label": label,
602660
"labelloc": "bottom",
603661
"labeljust": "left",
604662
"fontsize": "10",
605663
}
664+
606665
dag.graph["node"] = {"fontname": "helvetica", "shape": "box", "fontsize": "10"}
607666

608667
# do an initial pass to extract array variable names from stack summaries
@@ -627,7 +686,11 @@ def visualize(
627686
func_name = d["func_name"]
628687
label = f"{n}\n{func_name}".strip()
629688
op_name = d["op_name"]
630-
if op_name == "blockwise":
689+
if n in ops_exceeding_names:
690+
# operation exceeds memory - show in red
691+
d["style"] = '"rounded,filled"'
692+
d["fillcolor"] = "#ff6b6b"
693+
elif op_name == "blockwise":
631694
d["style"] = '"rounded,filled"'
632695
d["fillcolor"] = "#dcbeff"
633696
elif op_name == "rechunk":

cubed/tests/test_core.py

Lines changed: 19 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -493,11 +493,29 @@ def test_default_spec_allowed_mem_exceeded():
493493
# default spec fails for large computations
494494
a = xp.ones((20000, 10000), chunks=(10000, 10000))
495495
b = xp.negative(a)
496+
# plan() succeeds but marks plan as exceeding memory
497+
plan = b.plan()
498+
assert plan.exceeds_memory
499+
assert len(plan.ops_exceeding_memory) == 1
500+
# compute() raises the error
496501
with pytest.raises(
497502
ValueError,
498503
match=r"Projected blockwise memory \(.+\) exceeds allowed_mem \(.+\), including reserved_mem \(.+\) for op-\d+",
499504
):
500-
b.plan()
505+
b.compute()
506+
507+
508+
def test_default_spec_allowed_mem_exceeded_visualize(tmp_path):
509+
# visualize works but warns when memory is exceeded
510+
import warnings
511+
512+
a = xp.ones((20000, 10000), chunks=(10000, 10000))
513+
b = xp.negative(a)
514+
with warnings.catch_warnings(record=True) as w:
515+
warnings.simplefilter("always")
516+
b.visualize(filename=str(tmp_path / "cubed"))
517+
assert len(w) == 1
518+
assert "exceed allowed memory" in str(w[0].message)
501519

502520

503521
def test_default_spec_config_override():

0 commit comments

Comments
 (0)