125 changes: 63 additions & 62 deletions checkpoint/orbax/checkpoint/_src/arrays/subchunking.py
@@ -72,41 +72,27 @@ def choose_chunk_shape(
Returns:
List of length `len(write_shape)` specifying the chosen chunk shape.
"""
# TensorStore Zarr metadata doesn't support 0-sized dimensions.
write_shape = tuple(max(1, d) for d in write_shape)

if target_byte_size is None:
return write_shape

# TODO: b/354139177 - This check is too generous; the minimally viable chunk
# size should be set to something within the range of [4 KiB; 1 MiB] (from
# TensorStore and storage performance considerations).
if target_byte_size < dtype.itemsize:
raise ValueError(
f'target_byte_size={target_byte_size} must be >= {dtype.itemsize}'
)

if len(global_shape) != len(write_shape):
raise ValueError(
f'global_shape={global_shape} and write_shape={write_shape} must have'
' the same length.'
)
if target_byte_size < 1 * _MIB:
logging.warning(
'Setting the target_byte_size too small could reduce performance.'
)

# TensorStore Zarr metadata doesn't support 0-sized dimensions.
write_shape = tuple(max(1, d) for d in write_shape)

if target_byte_size is None and not shard_axes:
# No restrictions on chunk size or shape; return the write shape as-is.
return write_shape

sharded_dimensions = np.array(global_shape) != np.array(write_shape)
dtype_size = dtype.itemsize
target_elements = target_byte_size // dtype_size

rank = len(write_shape)

# `dim_factors[i]` is the list of divisors of `write_shape[i]`
dim_factors = [_find_divisors(size) for size in write_shape]

# The current chunk shape is:
#   [dim_factors[i][-1] for i in range(rank)]
dim_factors = [_find_divisors(size) for size in write_shape]

total_elements = math.prod(write_shape)
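# Note (not part of this diff): `_find_divisors` is assumed to return the
# divisors of its argument in ascending order, so that `dim_factors[i][-1]` is
# the current chunk extent along axis `i` and `dim_factors[i][-2]` is the next
# smaller candidate. A minimal sketch consistent with that usage:
def _find_divisors(size: int) -> list[int]:
  """Returns all positive divisors of `size`, in ascending order."""
  return [d for d in range(1, size + 1) if size % d == 0]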

@@ -118,57 +104,72 @@ def reduce_dim(dim_to_reduce: int) -> None:
total_elements = (total_elements // current_dim) * new_dim
sharded_dimensions[dim_to_reduce] = True

# First, try to reduce the size of the chunk shape on the `shard_axes`.
# If some of these specified axes are already sharded, we skip them on the
# first iteration, which ensures that we shard at least once on each of the
# `shard_axes`. The given target_byte_size may also be too big to require
# sharding on all of the requested axes, in which case we maximize the
# number of axes that are sharded.
(shard_axes := list(shard_axes)).sort()
could_shard = bool(shard_axes)
first_sharding_iteration = True
while could_shard and total_elements > target_elements:
could_shard = False
# For the first pass, exclude dimensions that are already sharded.
# We do our best to shard at least once on each of the `shard_axes`.
if first_sharding_iteration:
must_shard_dims = list(i for i in shard_axes if not sharded_dimensions[i])
if not must_shard_dims:
# In case all priority axes are already sharded, use all of them.
must_shard_dims = shard_axes
first_sharding_iteration = False
else:
must_shard_dims = shard_axes
# Exclude dimensions that can no longer be sharded.
must_shard_dims = list(
i for i in must_shard_dims if len(dim_factors[i]) > 1
if any(axis < 0 or axis >= rank for axis in shard_axes):
raise ValueError(
f'All shard_axes={shard_axes} must be non-negative and less than'
f' rank={rank}.'
)
# Shard once on each of the remaining dimensions in a round-robin fashion,
# while we can.
while must_shard_dims and total_elements > target_elements:
could_shard = True
# Find the dimension with the smallest next reduction factor among the
# remaining ones.
dim_idx = min(
must_shard_dims,
key=lambda i: dim_factors[i][-1] // dim_factors[i][-2],
)
reduce_dim(dim_idx)
must_shard_dims.remove(dim_idx)

if shard_axes:
# Reduce all explicitly requested shard axes.
for shard_axis in shard_axes:
while len(dim_factors[shard_axis]) > 1:
reduce_dim(shard_axis)

if target_byte_size is None:
current_shape = tuple(dim_factors[i][-1] for i in range(rank))
if current_shape != write_shape:
logging.vlog(
1,
'Reduced write shape using shard_axes=%s: global_shape=%s,'
' write_shape=%s, dtype=%s, target_byte_size=%d; reduced shape: %s',
'Reduced write shape using shard_axes=%s only (no target_byte_size):'
' global_shape=%s, write_shape=%s, dtype=%s; reduced shape: %s',
shard_axes,
global_shape,
write_shape,
dtype,
target_byte_size,
current_shape,
)
return current_shape
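# Illustration (hypothetical values, assuming `reduce_dim` pops the largest
# remaining divisor, as its uses here suggest): with write_shape=(8, 6),
# shard_axes=(0,) and target_byte_size=None, axis 0 is reduced to its smallest
# divisor, so this branch returns a chunk shape of (1, 6).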

# A target byte size is also specified. We now continue to reduce the chunk
# shape until it fits within the target byte size.

# TODO: b/354139177 - This check is too generous; the minimally viable chunk
# size should be set to something within the range of [4 KiB; 1 MiB] (from
# TensorStore and storage performance considerations).
if target_byte_size < dtype.itemsize:
raise ValueError(
f'target_byte_size={target_byte_size} must be >= {dtype.itemsize}'
)

if target_byte_size < 1 * _MIB:
logging.warning(
'Setting the target_byte_size too small could reduce performance.'
)

target_elements = target_byte_size // dtype_size
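# For example (hypothetical values): a 64 MiB target with a 4-byte dtype gives
# target_elements = (64 * 1024**2) // 4 = 16_777_216 elements per chunk.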

# First, try to reduce the size of the chunk shape on axes that are already
# sharded.
could_shard = True
while could_shard and total_elements > target_elements:
could_shard = False
# Find all dimensions that are sharded and can be sharded further.
candidate_shard_dims = list(
i
for i in shard_axes
if sharded_dimensions[i] and len(dim_factors[i]) > 1
)
# Shard once on each of the remaining dimensions in a round-robin fashion,
# while we can.
while candidate_shard_dims and total_elements > target_elements:
could_shard = True
# Find the dimension with the smallest next reduction factor among the
# remaining ones.
dim_idx = min(
candidate_shard_dims,
key=lambda i: dim_factors[i][-1] // dim_factors[i][-2],
)
reduce_dim(dim_idx)
candidate_shard_dims.remove(dim_idx)
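# Illustration of the selection rule above (hypothetical values): a dimension
# with dim_factors=[1, 2, 4, 8] would next shrink by 8 // 4 = 2x, while one
# with dim_factors=[1, 3, 9] would shrink by 9 // 3 = 3x; taking the minimum
# ratio reduces the total chunk size as gently as possible per step.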

# If we are not within target_byte_size yet, continue to reduce the current
# chunk shape until the desired number of elements is reached.
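For context, a minimal usage sketch of `choose_chunk_shape` after this change; the shapes, dtype, and byte budget below are illustrative, and the import path is inferred from the file touched by this PR rather than confirmed against its docs.

import numpy as np

from orbax.checkpoint._src.arrays import subchunking

# Hypothetical values: a float32 array written in per-host shards of
# (1024, 1024, 1024), with chunks capped at roughly 64 MiB and axis 0
# always subchunked.
chunk_shape = subchunking.choose_chunk_shape(
    global_shape=(4096, 1024, 1024),
    write_shape=(1024, 1024, 1024),
    dtype=np.dtype(np.float32),
    target_byte_size=64 * 1024**2,
    shard_axes=(0,),
)
# The result has one entry per axis of write_shape, each dividing that axis,
# and should fit within the requested byte budget.
print(chunk_shape)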