File tree Expand file tree Collapse file tree 1 file changed +7
-3
lines changed
jax/experimental/mosaic/gpu Expand file tree Collapse file tree 1 file changed +7
-3
lines changed Original file line number Diff line number Diff line change @@ -616,10 +616,14 @@ class WGSplatFragLayout:
616616 def can_broadcast_to (self , shape ) -> bool :
617617 """Check that the shape can be broadcast.
618618
619- Only dimensions of size 1 can be broadcast. All other dimensions
620- must be the same as the argument shape.
619+ Source rank must be not larger than the target rank and all the trailing
620+ dimensions of the target shape must equal the dimensions of the source shape
621+ (or source's one must be equal to 1).
621622 """
622- return all (dim1 == dim2 or dim1 == 1 for dim1 , dim2 in zip (self .shape [::- 1 ], shape [::- 1 ]))
623+ return len (self .shape ) <= len (shape ) and all (
624+ dim1 == dim2 or dim1 == 1
625+ for dim1 , dim2 in zip (self .shape [::- 1 ], shape [::- 1 ])
626+ )
623627
624628 def registers_element_type (self , t : ir .Type ) -> ir .Type :
625629 return t
You can’t perform that action at this time.
0 commit comments