Skip to content

Conversation

@copybara-service
Copy link

[Pallas:MGPU] Add support for indexing untiled dimensions under WG semantic.

Indexing is lowered to lax.slice_p + lax.squeeze_p. We add the missing rule for lax.squeeze_p under WG semantics. We lower lax.squeeze_p to vector.shape_cast.

So x[1, 1] where x is a JAX array will be lowered to:

%1 = vector.extract_strided_slice %0 {offsets = [1, 1, ...], sizes = [1, 1, ...]} : vector<NxMx...> to vector<1x1x...>
%2 = vector.shape_cast %2 : vector<1x1x...> to vector<...>

This gets simplified to:

%2 = vector.extract %0[1, 1] : vector<...> from vector<NxMx...>

So we need to add support for vector.extract to the Mosaic GPU dialect. We add the respective lowering and layout inference rules.

…mantic.

Indexing is lowered to `lax.slice_p` + `lax.squeeze_p`. We add the missing rule for `lax.squeeze_p` under WG semantics. We lower `lax.squeeze_p` to `vector.shape_cast`.

So `x[1, 1]` where `x` is a JAX array will be lowered to:
```
%1 = vector.extract_strided_slice %0 {offsets = [1, 1, ...], sizes = [1, 1, ...]} : vector<NxMx...> to vector<1x1x...>
%2 = vector.shape_cast %2 : vector<1x1x...> to vector<...>
```

This gets simplified to:
```
%2 = vector.extract %0[1, 1] : vector<...> from vector<NxMx...>
```

So we need to add support for `vector.extract` to the Mosaic GPU dialect. We add the respective lowering and layout inference rules.

PiperOrigin-RevId: 844759978
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant