Skip to content

Commit 1aba234

Browse files
[MGPU] Doc fix and shape size fix for broadcast in WGSplatFragLayout.
We enforce len(source_shape) <= len(target_shape) PiperOrigin-RevId: 845281808
1 parent b8d3757 commit 1aba234

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

jax/experimental/mosaic/gpu/fragmented_array.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -616,10 +616,14 @@ class WGSplatFragLayout:
616616
def can_broadcast_to(self, shape) -> bool:
617617
"""Check that the shape can be broadcast.
618618
619-
Only dimensions of size 1 can be broadcast. All other dimensions
620-
must be the same as the argument shape.
619+
Source rank must be not larger than the target rank and all the trailing
620+
dimensions of the target shape must equal the dimensions of the source shape
621+
(or source's one must be equal to 1).
621622
"""
622-
return all(dim1 == dim2 or dim1 == 1 for dim1, dim2 in zip(self.shape[::-1], shape[::-1]))
623+
return len(self.shape) <= len(shape) and all(
624+
dim1 == dim2 or dim1 == 1
625+
for dim1, dim2 in zip(self.shape[::-1], shape[::-1])
626+
)
623627

624628
def registers_element_type(self, t: ir.Type) -> ir.Type:
625629
return t

0 commit comments

Comments
 (0)