diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index ac0a10199943..795a3c3062c4 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -616,10 +616,14 @@ class WGSplatFragLayout: def can_broadcast_to(self, shape) -> bool: """Check that the shape can be broadcast. - Only dimensions of size 1 can be broadcast. All other dimensions - must be the same as the argument shape. + All source dimensions must match the target's trailing dimensions by + equality or being set to 1 (i.e. we can broadcast 1-sized dimensions or + create new leading dimensions). """ - return all(dim1 == dim2 or dim1 == 1 for dim1, dim2 in zip(self.shape[::-1], shape[::-1])) + return len(self.shape) <= len(shape) and all( + dim1 == dim2 or dim1 == 1 + for dim1, dim2 in zip(self.shape[::-1], shape[::-1]) + ) def registers_element_type(self, t: ir.Type) -> ir.Type: return t