
Commit 198c6ee

dicentra13 authored and Orbax Authors committed
Fully reduce explicitly requested axes in subchunking.
PiperOrigin-RevId: 829377963
1 parent 1f97396 commit 198c6ee

File tree: 3 files changed, +162 -170 lines changed

checkpoint/orbax/checkpoint/_src/arrays/subchunking.py

Lines changed: 63 additions & 62 deletions
@@ -72,41 +72,27 @@ def choose_chunk_shape(
   Returns:
     List of length `len(write_shape)` specifying the chosen chunk shape.
   """
-  # TensorStore Zarr metadata doesn't support 0-sized dimensions.
-  write_shape = tuple(max(1, d) for d in write_shape)
-
-  if target_byte_size is None:
-    return write_shape
-
-  # TODO: b/354139177 - This check is too generous; the minimally viable chunk
-  # size should be set to something within the range of [4 KiB; 1 MiB] (from
-  # TensorStore and storage performance considerations).
-  if target_byte_size < dtype.itemsize:
-    raise ValueError(
-        f'target_byte_size={target_byte_size} must be >= {dtype.itemsize}'
-    )
-
   if len(global_shape) != len(write_shape):
     raise ValueError(
         f'global_shape={global_shape} and write_shape={write_shape} must have'
         ' the same length.'
     )
-  if target_byte_size < 1 * _MIB:
-    logging.warning(
-        'Setting the target_byte_size too small could reduce performance.'
-    )
+
+  # TensorStore Zarr metadata doesn't support 0-sized dimensions.
+  write_shape = tuple(max(1, d) for d in write_shape)
+
+  if target_byte_size is None and not shard_axes:
+    # No restrictions on chunk size or shape; return the write shape as-is.
+    return write_shape

   sharded_dimensions = np.array(global_shape) != np.array(write_shape)
   dtype_size = dtype.itemsize
-  target_elements = target_byte_size // dtype_size
-
   rank = len(write_shape)

   # `dim_factors[i]` is the list of divisors of `write_shape[i]`
-  dim_factors = [_find_divisors(size) for size in write_shape]
-
   # The current chunk shape is:
-  # [dim_factors[i][-1] for i in range(rank)]
+  # [dim_factors[i][-1] for i in range(rank)]
+  dim_factors = [_find_divisors(size) for size in write_shape]

   total_elements = math.prod(write_shape)

@@ -118,57 +104,72 @@ def reduce_dim(dim_to_reduce: int) -> None:
     total_elements = (total_elements // current_dim) * new_dim
     sharded_dimensions[dim_to_reduce] = True

-  # First, try to reduce the size of the chunk shape on the `shard_axes`.
-  # If some of these specified axes are already sharded, we will skip them on
-  # the first iteration which ensures that we shard at least once on each of the
-  # `shard_axes`. It might also be the case that the given target_byte_size is
-  # too big to shard on all of the requested axes, in which case we will
-  # maximize the number of the number of axes that are sharded.
-  (shard_axes := list(shard_axes)).sort()
-  could_shard = bool(shard_axes)
-  first_sharding_iteration = True
-  while could_shard and total_elements > target_elements:
-    could_shard = False
-    # For the first pass, exclude dimensions that are already sharded.
-    # We do our best to shard at least once of each of the `shard_axes`.
-    if first_sharding_iteration:
-      must_shard_dims = list(i for i in shard_axes if not sharded_dimensions[i])
-      if not must_shard_dims:
-        # In case all priority axes are already sharded, use all of them.
-        must_shard_dims = shard_axes
-      first_sharding_iteration = False
-    else:
-      must_shard_dims = shard_axes
-    # Exclude dimensions that can no longer be sharded.
-    must_shard_dims = list(
-        i for i in must_shard_dims if len(dim_factors[i]) > 1
+  if any(axis < 0 or axis >= rank for axis in shard_axes):
+    raise ValueError(
+        f'All shard_axes={shard_axes} must be non-negative and less than'
+        f' rank={rank}.'
     )
-    # Shard once on each of the remaining dimensions in a round-robin fashion,
-    # while we can.
-    while must_shard_dims and total_elements > target_elements:
-      could_shard = True
-      # Find the minimum available divisor among the remaining dimensions.
-      dim_idx = min(
-          must_shard_dims,
-          key=lambda i: dim_factors[i][-1] // dim_factors[i][-2],
-      )
-      reduce_dim(dim_idx)
-      must_shard_dims.remove(dim_idx)

-  if shard_axes:
+  # Reduce all explicitly requested shard axes.
+  for shard_axis in shard_axes:
+    while len(dim_factors[shard_axis]) > 1:
+      reduce_dim(shard_axis)
+
+  if target_byte_size is None:
     current_shape = tuple(dim_factors[i][-1] for i in range(rank))
     if current_shape != write_shape:
       logging.vlog(
           1,
-          'Reduced write shape using shard_axes=%s: global_shape=%s,'
-          ' write_shape=%s, dtype=%s, target_byte_size=%d; reduced shape: %s',
+          'Reduced write shape using shard_axes=%s only (no target_byte_size):'
+          ' global_shape=%s, write_shape=%s, dtype=%s; reduced shape: %s',
           shard_axes,
           global_shape,
           write_shape,
           dtype,
-          target_byte_size,
           current_shape,
       )
+    return current_shape
+
+  # A target byte size is also specified. We will now try to find the smallest
+  # chunk shape that satisfies the target byte size.
+
+  # TODO: b/354139177 - This check is too generous; the minimally viable chunk
+  # size should be set to something within the range of [4 KiB; 1 MiB] (from
+  # TensorStore and storage performance considerations).
+  if target_byte_size < dtype.itemsize:
+    raise ValueError(
+        f'target_byte_size={target_byte_size} must be >= {dtype.itemsize}'
+    )
+
+  if target_byte_size < 1 * _MIB:
+    logging.warning(
+        'Setting the target_byte_size too small could reduce performance.'
+    )
+
+  target_elements = target_byte_size // dtype_size
+
+  # First, try to reduce the size of the chunk shape on axes that are already
+  # sharded.
+  could_shard = True
+  while could_shard and total_elements > target_elements:
+    could_shard = False
+    # Find all dimensions that are sharded and can be sharded further.
+    candidate_shard_dims = list(
+        i
+        for i in shard_axes
+        if sharded_dimensions[i] and len(dim_factors[i]) > 1
+    )
+    # Shard once on each of the remaining dimensions in a round-robin fashion,
+    # while we can.
+    while candidate_shard_dims and total_elements > target_elements:
+      could_shard = True
+      # Find the minimum available divisor among the remaining dimensions.
+      dim_idx = min(
+          candidate_shard_dims,
+          key=lambda i: dim_factors[i][-1] // dim_factors[i][-2],
+      )
+      reduce_dim(dim_idx)
+      candidate_shard_dims.remove(dim_idx)

   # If we are not within target_byte_size yet, continue to reduce the current
   # chunk shape until the desired number of elements is reached.
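
For context, a minimal usage sketch of the behavior this commit introduces. The module path is taken from the file header above and the keyword argument names from the parameters visible in the diff; both are assumptions rather than a confirmed public API, and the expected values are inferred only from the new logic shown in this hunk.

# Sketch only: import path and keyword names are assumed from the diff above
# and may not match the released Orbax API exactly.
import numpy as np

from orbax.checkpoint._src.arrays import subchunking

# With shard_axes given and target_byte_size=None, the new code fully reduces
# each explicitly requested axis: axis 1 (size 64) is divided all the way down
# to a chunk size of 1, so the expected chunk shape is (8, 1, 32).
chunk_shape = subchunking.choose_chunk_shape(
    global_shape=(8, 64, 32),
    write_shape=(8, 64, 32),
    dtype=np.dtype(np.float32),
    target_byte_size=None,
    shard_axes=(1,),
)

# When a target_byte_size is also given, the requested axes are still fully
# reduced first; any further reduction needed to fit the byte budget is then
# applied by the remaining logic at the end of the function (not shown here).
chunk_shape_with_budget = subchunking.choose_chunk_shape(
    global_shape=(8, 64, 32),
    write_shape=(8, 64, 32),
    dtype=np.dtype(np.float32),
    target_byte_size=512,  # 128 float32 elements.
    shard_axes=(1,),
)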
