Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions jax/_src/pallas/mosaic_gpu/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -2441,6 +2441,19 @@ def _squeeze_lowering_rule(ctx: LoweringRuleContext, x, dimensions):
return _ensure_fa(x, x_aval.dtype).reshape(y_aval.shape)


@register_lowering_rule(lax.squeeze_p, mgpu.LoweringSemantics.Warpgroup)
def _squeeze_lowering_rule_wg(ctx: LoweringRuleContext, x, dimensions):
[x_aval] = ctx.avals_in
[y_aval] = ctx.avals_out
x = _ensure_ir_value(x, x_aval.dtype)
if y_aval.ndim == 0: # scalar
# TODO(allanrenucci): Lower to `vector.extract` once we support scalar
# results in MGPU dialect lowering.
raise NotImplementedError("Squeeze to scalar is not supported.")
res_ty = ir.VectorType.get(y_aval.shape, ir.VectorType(x.type).element_type)
return vector_dialect.shape_cast(res_ty, x)


def _reduce_lowering_rule(op, ctx: LoweringRuleContext, x, *, axes, **kwargs):
[x_aval] = ctx.avals_in
match x.layout:
Expand Down
21 changes: 21 additions & 0 deletions jax/experimental/mosaic/gpu/dialect_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,27 @@ def _vector_extract_strided_slice_op_lowering_rule(
return [fragmented_array_to_ir(result, out_vec_ty)]


@_register_lowering(vector.ExtractOp)
def _vector_extract_op_lowering_rule(
ctx: LoweringContext, op: vector.ExtractOp
) -> Sequence[ir.Value]:
del ctx
if not ir.VectorType.isinstance(op.result.type):
raise NotImplementedError("Scalar element extraction is not supported.")
if op.dynamic_position:
raise NotImplementedError("Only slicing with static indices allowed.")
[in_layout] = inference_utils.in_layouts(op)
[out_layout] = inference_utils.out_layouts(op)
assert in_layout == out_layout
a = _fragmented_array_from_ir(op.source, in_layout)
result_type = ir.VectorType(op.result.type)
slices = tuple(slice(i, i + 1) for i in op.static_position)
# TODO(allanrenucci): Add direct support for indexing to FragmentedArray.
result = a[slices].reshape(tuple(result_type.shape))
assert result.layout == layouts.from_layout_attr(out_layout)
return [fragmented_array_to_ir(result, result_type)]


@_register_lowering(vector.ReductionOp)
def _vector_reduction_op_lowering_rule(
ctx: LoweringContext, op: vector.ReductionOp
Expand Down
25 changes: 25 additions & 0 deletions jax/experimental/mosaic/gpu/layout_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -1103,6 +1103,31 @@ def _extract_strided_slice_constraint_system(
)


@_add_constraint_system_derivation_rule(vector.ExtractOp)
def _vector_extract_constraint_system(
ctx: DerivationContext, op: vector.ExtractOp
) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]:
del ctx
if not ir.VectorType.isinstance(op.result.type):
raise NotImplementedError("Scalar element extraction is not supported.")
if op.dynamic_position:
raise NotImplementedError("Only slicing with static indices allowed.")
operand = ValueSite(op, VariableType.OPERAND, 0)
result = ValueSite(op, VariableType.RESULT, 0)
variable = cs.Variable(operand)
constraints = [
cs.Divides(variable, tuple(op.result.type.shape)),
# TODO(allanrenucci): Remove once vectors with splat and strided layouts
# can be sliced.
cs.NotOfType(variable, fa.WGSplatFragLayout),
cs.NotOfType(variable, fa.WGStridedFragLayout),
]
return (
cs.ConstraintSystem(constraints=constraints),
{variable: [operand, result]},
)


@_add_constraint_system_derivation_rule(mgpu.CustomPrimitiveOp)
def _custom_primitive_constraint_system(
ctx: DerivationContext,
Expand Down
25 changes: 25 additions & 0 deletions tests/mosaic/gpu_layout_inference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1979,6 +1979,31 @@ def test_infer_layout_for_vector_extract_strided_slice_fails(
):
mgpu.infer_layout(self.module)

def test_infer_layout_for_vector_extract(self):
layout = layouts.to_layout_attr(fa.WGMMA_LAYOUT)
with ir.InsertionPoint(self.module.body):
i16 = ir.IntegerType.get_signless(16)
src_ty = ir.VectorType.get([2, 3, 64, 8], i16)
[src] = undefs(src_ty)
src = mgpu.dialect.layout_cast(src, layout)
op = vector.ExtractOp(src, dynamic_position=[], static_position=[1, 1])
mgpu.infer_layout(self.module)
self.checkInLayouts(op, [layout])
self.checkOutLayouts(op, [layout])

def test_infer_layout_for_vector_extract_fails_if_not_dividing_result_shape(self):
layout = layouts.to_layout_attr(fa.WGMMA_LAYOUT)
with ir.InsertionPoint(self.module.body):
i16 = ir.IntegerType.get_signless(16)
src_ty = ir.VectorType.get([64, 64], i16)
[src] = undefs(src_ty)
src = mgpu.dialect.layout_cast(src, layout)
vector.extract(src, dynamic_position=[], static_position=[0])
with self.assertRaisesRegex(
ValueError, "Failed to infer a possible set of layouts."
):
mgpu.infer_layout(self.module)

def test_infer_tmem_layout_for_slice_tmem_op(self):
# in and out layouts can be different.
in_layout = layouts.to_layout_attr(tcgen05.tmem_default_layout(packing=1))
Expand Down
34 changes: 30 additions & 4 deletions tests/pallas/mosaic_gpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,22 @@ def fn(_):
return fn()


def _array_splat(value, shape: tuple[int, ...]):
"""Same as `jnp.full(shape, value, jnp.float32)` but implemented using `inline_mgpu`.

This is useful to prevent the result from being optimized away.
"""
@plgpu.inline_mgpu(
return_type=plgpu.ShapeDtypeStruct(
shape, jnp.float32, layout=plgpu.Layout.WG_SPLAT(shape)
),
)
def fn(_):
ir_value = mgpu.c(value, ir.F32Type.get())
return mgpu.FragmentedArray.splat(ir_value, shape)
return fn()


class PallasTestMetaclass(parameterized.TestGeneratorMetaclass):

def __new__(mcs, *args, lowering_semantics=plgpu.LoweringSemantics.Lane):
Expand Down Expand Up @@ -352,20 +368,31 @@ def kernel(out_ref):
)

def test_slice_untiled_dim(self):
self.skip_if_wg_semantics()
shape = (2, 3, 64, 8)

@functools.partial(
self.kernel,
out_shape=jax.ShapeDtypeStruct(shape[2:], jnp.float32),
)
def kernel(x_ref, out_ref):
y = plgpu.load(x_ref, (), layout=plgpu.Layout.WGMMA, optimized=False)[1, 1]
out_ref[...] = y
x = plgpu.load(x_ref, (), layout=plgpu.Layout.WGMMA, optimized=False)
out_ref[...] = x[1, 1]

x = jnp.arange(math.prod(shape)).reshape(shape).astype(jnp.float32)
np.testing.assert_array_equal(kernel(x), x[1, 1])

def test_squeeze_to_scalar(self):
self.skip_if_wg_semantics() # Scalar element extraction is not supported for `vector.extract`.
@functools.partial(
self.kernel,
out_shape=jax.ShapeDtypeStruct((), jnp.float32),
)
def kernel(out_ref):
x = _array_splat(42, (1, 1, 1))
out_ref[...] = lax.squeeze(x, dimensions=(0, 1, 2))

np.testing.assert_array_equal(kernel(), jnp.array(42, dtype=jnp.float32))

def test_add_xy_indexed(self):
@functools.partial(
self.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.float32)
Expand Down Expand Up @@ -2863,7 +2890,6 @@ def test_missing_primitive_lowerings_are_tracked(self):
pallas_primitives.semaphore_read_p,
pallas_primitives.delay_p,
checkify.check_p,
lax.squeeze_p,
}

self.assertSetEqual(actual_missing_primitives, expected_missing_primitives)
Expand Down
Loading