44import shutil
55import tempfile
66import uuid
7+ import warnings
78from datetime import datetime
89from enum import Enum
910from functools import lru_cache
10- from typing import Any , Callable , Dict , Optional
11+ from typing import Any , Callable , Dict , List , Optional , Tuple
1112
1213import networkx as nx
1314
@@ -271,25 +272,21 @@ def _compile_blockwise(self, dag, compile_function: Decorator) -> nx.MultiDiGrap
271272
272273 return dag
273274
def _find_ops_exceeding_memory(self, dag) -> List[Tuple[str, "PrimitiveOperation"]]:
    """Find all operations where projected memory exceeds allowed memory.

    Parameters
    ----------
    dag
        Graph whose ``nodes(data=True)`` yields ``(name, attrs)`` pairs. Only
        nodes whose attribute dict contains a ``"primitive_op"`` entry are
        inspected; all other nodes are ignored.

    Returns
    -------
    list of (op_name, primitive_op) tuples
        One entry per operation with ``projected_mem > allowed_mem``, sorted
        by ``projected_mem`` (highest first). Empty if nothing exceeds.
    """
    primitive_ops = (
        (name, attrs["primitive_op"])
        for name, attrs in dag.nodes(data=True)
        if "primitive_op" in attrs
    )
    ops_exceeding = [
        (name, op) for name, op in primitive_ops if op.projected_mem > op.allowed_mem
    ]
    # Worst offenders first. sorted() is stable even with reverse=True, so
    # operations with equal projected_mem keep their DAG iteration order.
    return sorted(ops_exceeding, key=lambda item: item[1].projected_mem, reverse=True)
293290
294291 @lru_cache # noqa: B019
295292 def _finalize (
@@ -304,8 +301,10 @@ def _finalize(
304301 if callable (compile_function ):
305302 dag = self ._compile_blockwise (dag , compile_function )
306303 dag = self ._create_lazy_zarr_arrays (dag )
307- self ._check_projected_mem (dag )
308- return FinalizedPlan (nx .freeze (dag ), self .array_names , optimize_graph )
304+ ops_exceeding_memory = self ._find_ops_exceeding_memory (dag )
305+ return FinalizedPlan (
306+ nx .freeze (dag ), self .array_names , optimize_graph , ops_exceeding_memory
307+ )
309308
310309
311310class ArrayRole (Enum ):
@@ -324,10 +323,11 @@ class FinalizedPlan:
324323 4. freezing the final DAG so it can't be changed
325324 """
326325
327- def __init__ (self , dag , array_names , optimized ):
326+ def __init__ (self , dag , array_names , optimized , ops_exceeding_memory = None ):
328327 self .dag = dag
329328 self .array_names = array_names
330329 self .optimized = optimized
330+ self ._ops_exceeding_memory = ops_exceeding_memory or []
331331 self ._calculate_stats ()
332332
333333 self .input_array_names = []
@@ -540,6 +540,34 @@ def total_nchunks(self) -> int:
540540 """The total number of chunks for all materialized arrays in this plan."""
541541 return self ._total_nchunks
542542
@property
def exceeds_memory(self) -> bool:
    """True if any operation in this plan exceeds the allowed memory."""
    # An empty list of offenders is falsy, so truthiness is the whole check.
    return bool(self._ops_exceeding_memory)
547+
@property
def ops_exceeding_memory(self) -> List[Tuple[str, "PrimitiveOperation"]]:
    """List of (op_name, primitive_op) tuples for operations exceeding memory.

    Sorted by projected memory (highest first).

    A new list is returned on every access, so mutating the result does not
    change what this plan reports about itself.
    """
    # Shallow copy: the tuples and operations are shared, the list is not.
    return list(self._ops_exceeding_memory)
555+
def validate(self) -> None:
    """Validate that this plan can be executed.

    Raises
    ------
    ValueError
        If any operation's projected memory exceeds the allowed memory. The
        message describes the first entry of the recorded offenders (the
        worst one) and, when there are more, names the remaining offending
        operations so they can all be addressed in one pass.
    """
    if not self._ops_exceeding_memory:
        return

    (op_name, op), *others = self._ops_exceeding_memory  # Report worst offender
    message = (
        f"Projected blockwise memory ({memory_repr(op.projected_mem)}) exceeds allowed_mem ({memory_repr(op.allowed_mem)}), "
        f"including reserved_mem ({memory_repr(op.reserved_mem)}) for {op_name}"
    )
    if others:
        # Appended after the original sentence so existing message matching
        # on the leading text keeps working.
        other_names = ", ".join(str(name) for name, _ in others)
        message += (
            f" ({len(others)} other operation(s) also exceed allowed_mem: {other_names})"
        )
    raise ValueError(message)
570+
543571 def execute (
544572 self ,
545573 executor = None ,
@@ -548,6 +576,8 @@ def execute(
548576 spec = None ,
549577 ** kwargs ,
550578 ):
579+ self .validate ()
580+
551581 dag = self .dag
552582
553583 if resume :
@@ -580,6 +610,15 @@ def visualize(
580610 rankdir = "TB" ,
581611 show_hidden = False ,
582612 ):
613+ if self ._ops_exceeding_memory :
614+ op_names = [name for name , _ in self ._ops_exceeding_memory ]
615+ warnings .warn (
616+ f"Plan has { len (self ._ops_exceeding_memory )} operation(s) that exceed allowed memory: { op_names } . "
617+ "These are shown in red in the visualization." ,
618+ stacklevel = 2 ,
619+ )
620+ ops_exceeding_names = {name for name , _ in self ._ops_exceeding_memory }
621+
583622 dag = self .dag .copy () # make a copy since we mutate the DAG below
584623
585624 # remove edges from create-arrays output node to avoid cluttering the diagram
@@ -590,19 +629,39 @@ def visualize(
590629 list (n for n , d in dag .nodes (data = True ) if d .get ("hidden" , False ))
591630 )
592631
632+ # Build the graph label - use HTML-like label for mixed colors if memory exceeded
633+ stats_text = (
634+ f"num tasks: { self .num_tasks } <BR ALIGN='LEFT'/>"
635+ f"max projected memory: { memory_repr (self .max_projected_mem )} <BR ALIGN='LEFT'/>"
636+ f"total nbytes written: { memory_repr (self .total_nbytes_written )} <BR ALIGN='LEFT'/>"
637+ f"optimized: { self .optimized } <BR ALIGN='LEFT'/>"
638+ )
639+
640+ if self ._ops_exceeding_memory :
641+ # Build warning text in red
642+ warning_lines = [
643+ "<BR ALIGN='LEFT'/>!!! MEMORY EXCEEDED !!!<BR ALIGN='LEFT'/>"
644+ ]
645+ for op_name , op in self ._ops_exceeding_memory :
646+ warning_lines .append (
647+ f"{ op_name } : requires { memory_repr (op .projected_mem )} , "
648+ f"allowed { memory_repr (op .allowed_mem )} <BR ALIGN='LEFT'/>"
649+ )
650+ warning_text = "" .join (warning_lines )
651+ # HTML-like label with mixed colors
652+ label = f"<<FONT>{ stats_text } </FONT><FONT COLOR='#cc0000'>{ warning_text } </FONT>>"
653+ else :
654+ # Simple HTML label (no warning)
655+ label = f"<{ stats_text } >"
656+
593657 dag .graph ["graph" ] = {
594658 "rankdir" : rankdir ,
595- "label" : (
596- # note that \l is used to left-justify each line (see https://www.graphviz.org/docs/attrs/nojustify/)
597- rf"num tasks: { self .num_tasks } \l"
598- rf"max projected memory: { memory_repr (self .max_projected_mem )} \l"
599- rf"total nbytes written: { memory_repr (self .total_nbytes_written )} \l"
600- rf"optimized: { self .optimized } \l"
601- ),
659+ "label" : label ,
602660 "labelloc" : "bottom" ,
603661 "labeljust" : "left" ,
604662 "fontsize" : "10" ,
605663 }
664+
606665 dag .graph ["node" ] = {"fontname" : "helvetica" , "shape" : "box" , "fontsize" : "10" }
607666
608667 # do an initial pass to extract array variable names from stack summaries
@@ -627,7 +686,11 @@ def visualize(
627686 func_name = d ["func_name" ]
628687 label = f"{ n } \n { func_name } " .strip ()
629688 op_name = d ["op_name" ]
630- if op_name == "blockwise" :
689+ if n in ops_exceeding_names :
690+ # operation exceeds memory - show in red
691+ d ["style" ] = '"rounded,filled"'
692+ d ["fillcolor" ] = "#ff6b6b"
693+ elif op_name == "blockwise" :
631694 d ["style" ] = '"rounded,filled"'
632695 d ["fillcolor" ] = "#dcbeff"
633696 elif op_name == "rechunk" :
0 commit comments