File tree Expand file tree Collapse file tree 1 file changed +7
-3
lines changed
jax/experimental/mosaic/gpu Expand file tree Collapse file tree 1 file changed +7
-3
lines changed Original file line number Diff line number Diff line change @@ -616,10 +616,14 @@ class WGSplatFragLayout:
616616 def can_broadcast_to (self , shape ) -> bool :
617617 """Check that the shape can be broadcast.
618618
619- Only dimensions of size 1 can be broadcast. All other dimensions
620- must be the same as the argument shape.
619+ All source dimensions must match the target's trailing dimensions by
620+ equality or being set to 1 (i.e. we can broadcast 1-sized dimensions or
621+ create new leading dimensions).
621622 """
622- return all (dim1 == dim2 or dim1 == 1 for dim1 , dim2 in zip (self .shape [::- 1 ], shape [::- 1 ]))
623+ return len (self .shape ) <= len (shape ) and all (
624+ dim1 == dim2 or dim1 == 1
625+ for dim1 , dim2 in zip (self .shape [::- 1 ], shape [::- 1 ])
626+ )
623627
624628 def registers_element_type (self , t : ir .Type ) -> ir .Type :
625629 return t
You can’t perform that action at this time.
0 commit comments