Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions jax/experimental/mosaic/gpu/fragmented_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,10 +616,14 @@ class WGSplatFragLayout:
def can_broadcast_to(self, shape) -> bool:
"""Check that the shape can be broadcast.

Only dimensions of size 1 can be broadcast. All other dimensions
must be the same as the argument shape.
All source dimensions must match the target's trailing dimensions by
equality or being set to 1 (i.e. we can broadcast 1-sized dimensions or
create new leading dimensions).
"""
return all(dim1 == dim2 or dim1 == 1 for dim1, dim2 in zip(self.shape[::-1], shape[::-1]))
return len(self.shape) <= len(shape) and all(
dim1 == dim2 or dim1 == 1
for dim1, dim2 in zip(self.shape[::-1], shape[::-1])
)

def registers_element_type(self, t: ir.Type) -> ir.Type:
return t
Expand Down
Loading