@@ -72,41 +72,27 @@ def choose_chunk_shape(
   Returns:
     List of length `len(write_shape)` specifying the chosen chunk shape.
   """
-  # TensorStore Zarr metadata doesn't support 0-sized dimensions.
-  write_shape = tuple(max(1, d) for d in write_shape)
-
-  if target_byte_size is None:
-    return write_shape
-
-  # TODO: b/354139177 - This check is too generous; the minimally viable chunk
-  # size should be set to something within the range of [4 KiB; 1 MiB] (from
-  # TensorStore and storage performance considerations).
-  if target_byte_size < dtype.itemsize:
-    raise ValueError(
-        f'target_byte_size={target_byte_size} must be >= {dtype.itemsize}'
-    )
-
   if len(global_shape) != len(write_shape):
     raise ValueError(
         f'global_shape={global_shape} and write_shape={write_shape} must have'
         ' the same length.'
     )
-  if target_byte_size < 1 * _MIB:
-    logging.warning(
-        'Setting the target_byte_size too small could reduce performance.'
-    )
+
+  # TensorStore Zarr metadata doesn't support 0-sized dimensions.
+  write_shape = tuple(max(1, d) for d in write_shape)
+
+  if target_byte_size is None and not shard_axes:
+    # No restrictions on chunk size or shape; return the write shape as-is.
+    return write_shape
 
   sharded_dimensions = np.array(global_shape) != np.array(write_shape)
   dtype_size = dtype.itemsize
-  target_elements = target_byte_size // dtype_size
-
   rank = len(write_shape)
 
   # `dim_factors[i]` is the list of divisors of `write_shape[i]`
-  dim_factors = [_find_divisors(size) for size in write_shape]
-
   # The current chunk shape is:
-  #   [dim_factors[i][-1] for i in range(rank)]
+  #   [dim_factors[i][-1] for i in range(rank)]
+  dim_factors = [_find_divisors(size) for size in write_shape]
 
   total_elements = math.prod(write_shape)
 
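# For reference: the hunk above relies on `_find_divisors` (defined elsewhere in
# this module) returning the divisors of `size` in ascending order, so that
# `dim_factors[i][-1]` starts out equal to `write_shape[i]` and
# `dim_factors[i][-2]` is the next smaller candidate extent. A minimal sketch
# with that property, not the module's actual implementation:
def _find_divisors(size: int) -> list[int]:
  """Returns all positive divisors of `size`, in ascending order."""
  return [d for d in range(1, size + 1) if size % d == 0]

# Example: _find_divisors(12) == [1, 2, 3, 4, 6, 12], so a dimension of size 12
# starts with a chunk extent of 12 and its first reduction shrinks it to 6
# (a divisor ratio of 2, which is what the `min(...)` key below compares).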
@@ -118,57 +104,72 @@ def reduce_dim(dim_to_reduce: int) -> None:
     total_elements = (total_elements // current_dim) * new_dim
     sharded_dimensions[dim_to_reduce] = True
 
-  # First, try to reduce the size of the chunk shape on the `shard_axes`.
-  # If some of these specified axes are already sharded, we will skip them on
-  # the first iteration which ensures that we shard at least once on each of the
-  # `shard_axes`. It might also be the case that the given target_byte_size is
-  # too big to shard on all of the requested axes, in which case we will
-  # maximize the number of the number of axes that are sharded.
-  (shard_axes := list(shard_axes)).sort()
-  could_shard = bool(shard_axes)
-  first_sharding_iteration = True
-  while could_shard and total_elements > target_elements:
-    could_shard = False
-    # For the first pass, exclude dimensions that are already sharded.
-    # We do our best to shard at least once of each of the `shard_axes`.
-    if first_sharding_iteration:
-      must_shard_dims = list(i for i in shard_axes if not sharded_dimensions[i])
-      if not must_shard_dims:
-        # In case all priority axes are already sharded, use all of them.
-        must_shard_dims = shard_axes
-      first_sharding_iteration = False
-    else:
-      must_shard_dims = shard_axes
-    # Exclude dimensions that can no longer be sharded.
-    must_shard_dims = list(
-        i for i in must_shard_dims if len(dim_factors[i]) > 1
+  if any(axis < 0 or axis >= rank for axis in shard_axes):
+    raise ValueError(
+        f'All shard_axes={shard_axes} must be non-negative and less than'
+        f' rank={rank}.'
     )
-    # Shard once on each of the remaining dimensions in a round-robin fashion,
-    # while we can.
-    while must_shard_dims and total_elements > target_elements:
-      could_shard = True
-      # Find the minimum available divisor among the remaining dimensions.
-      dim_idx = min(
-          must_shard_dims,
-          key=lambda i: dim_factors[i][-1] // dim_factors[i][-2],
-      )
-      reduce_dim(dim_idx)
-      must_shard_dims.remove(dim_idx)
 
-  if shard_axes:
+  # Reduce all explicitly requested shard axes.
+  for shard_axis in shard_axes:
+    while len(dim_factors[shard_axis]) > 1:
+      reduce_dim(shard_axis)
+
+  if target_byte_size is None:
     current_shape = tuple(dim_factors[i][-1] for i in range(rank))
     if current_shape != write_shape:
       logging.vlog(
           1,
-          'Reduced write shape using shard_axes=%s: global_shape=%s,'
-          ' write_shape=%s, dtype=%s, target_byte_size=%d; reduced shape: %s',
+          'Reduced write shape using shard_axes=%s only (no target_byte_size):'
+          ' global_shape=%s, write_shape=%s, dtype=%s; reduced shape: %s',
           shard_axes,
          global_shape,
          write_shape,
          dtype,
-          target_byte_size,
          current_shape,
       )
+    return current_shape
+
+  # A target byte size is also specified; reduce the chunk shape further until
+  # it fits within the target byte size.
+
+  # TODO: b/354139177 - This check is too generous; the minimally viable chunk
+  # size should be set to something within the range of [4 KiB; 1 MiB] (from
+  # TensorStore and storage performance considerations).
+  if target_byte_size < dtype.itemsize:
+    raise ValueError(
+        f'target_byte_size={target_byte_size} must be >= {dtype.itemsize}'
+    )
+
+  if target_byte_size < 1 * _MIB:
+    logging.warning(
+        'Setting the target_byte_size too small could reduce performance.'
+    )
+
+  target_elements = target_byte_size // dtype_size
+
+  # First, try to reduce the size of the chunk shape on axes that are already
+  # sharded.
+  could_shard = True
+  while could_shard and total_elements > target_elements:
+    could_shard = False
+    # Find all dimensions that are sharded and can be sharded further.
+    candidate_shard_dims = list(
+        i
+        for i in shard_axes
+        if sharded_dimensions[i] and len(dim_factors[i]) > 1
+    )
+    # Shard once on each of the remaining dimensions in a round-robin fashion,
+    # while we can.
+    while candidate_shard_dims and total_elements > target_elements:
+      could_shard = True
+      # Find the minimum available divisor among the remaining dimensions.
+      dim_idx = min(
+          candidate_shard_dims,
+          key=lambda i: dim_factors[i][-1] // dim_factors[i][-2],
+      )
+      reduce_dim(dim_idx)
+      candidate_shard_dims.remove(dim_idx)
 
   # If we are not within target_byte_size yet, continue to reduce the current
   # chunk shape until the desired number of elements is reached.
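# For reference: a self-contained sketch of the reduction step that both the
# old and new loops drive through `reduce_dim`. Only the two bookkeeping lines
# of that closure are visible in the hunk context above; threading the state
# through parameters and popping the largest remaining divisor are assumptions
# about the surrounding function, not part of this change.
def _reduce_dim_sketch(
    dim_factors: list[list[int]],
    sharded_dimensions: list[bool],
    total_elements: int,
    dim_to_reduce: int,
) -> int:
  """Shrinks one dimension to its next smaller divisor.

  Returns the updated element count of the candidate chunk.
  """
  current_dim = dim_factors[dim_to_reduce][-1]
  dim_factors[dim_to_reduce].pop()  # next smaller divisor becomes the extent
  new_dim = dim_factors[dim_to_reduce][-1]
  total_elements = (total_elements // current_dim) * new_dim
  sharded_dimensions[dim_to_reduce] = True
  return total_elements

# Example: with dim_factors == [[1, 2, 3, 6], [1, 5]] and total_elements == 30,
# reducing dimension 0 pops the divisor 6, leaving a chunk extent of 3 and 15
# remaining elements; a second reduction of dimension 0 would leave 2 and 10.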