Skip to content

Commit d6df159

Browse files
tlongeriGoogle-ML-Automation
authored andcommitted
[Mosaic:TPU] Clean up tpu.memref_slice verifier
PiperOrigin-RevId: 843809401
1 parent 94ae97f commit d6df159

File tree

1 file changed

+10
-24
lines changed

1 file changed

+10
-24
lines changed

jaxlib/mosaic/dialect/tpu/tpu_ops.cc

Lines changed: 10 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ OpFoldResult BitcastVregOp::fold(FoldAdaptor adaptor) {
165165
}
166166

167167
LogicalResult MemRefSliceOp::verify() {
168-
auto source_type = getMemRefType(getMemRef());
168+
auto source_type = getMemRef().getType();
169169
auto target_type = getType();
170170
auto source_layout = source_type.getLayout();
171171
auto target_layout = target_type.getLayout();
@@ -176,6 +176,11 @@ LogicalResult MemRefSliceOp::verify() {
176176
return emitOpError(
177177
"Only slicing of memrefs with static shapes is supported.");
178178
}
179+
if (getDynamicSizes().size() != target_type.getNumDynamicDims()) {
180+
return emitOpError(
181+
"Number of provided dynamic dimensions sizes must match the number of "
182+
"dynamic dimensions in the target type.");
183+
}
179184
auto source_shape = source_type.getShape();
180185
bool is_semaphore =
181186
HasMemorySpace(source_type, tpu::MemorySpace::kSemaphoreMem);
@@ -191,21 +196,11 @@ LogicalResult MemRefSliceOp::verify() {
191196
}
192197
// TODO(apaszke): Check that the result has a smaller shape.
193198
// TODO(apaszke): Check that strides are equivalent.
194-
// Source and target attributes may be different before propagation is done by
195-
// the canonicalizer, so we allow this when attributes are "unset" in the
196-
// target type. Note that MemRefType does not allow a null layout so we treat
197-
// the default identity affine map as an "unset" value instead.
198-
bool is_target_memory_space_provided = target_memory_space != nullptr;
199-
if (is_target_memory_space_provided &&
200-
target_memory_space != source_type.getMemorySpace()) {
199+
if (target_memory_space != source_type.getMemorySpace()) {
201200
return emitOpError(
202201
"Memory spaces must match if the target memory space is provided.");
203202
}
204-
if (isa<TiledLayoutAttr>(source_layout) &&
205-
!isa<TiledLayoutAttr>(target_layout)) {
206-
// TODO(slebedev): Remove this special-case once we move layout propagation
207-
// to the infer-memref-layout pass.
208-
} else if (isa<StridedLayoutAttr>(target_layout)) {
203+
if (isa<StridedLayoutAttr>(target_layout)) {
209204
SmallVector<int64_t> source_strides;
210205
int64_t source_offset;
211206
if (failed(
@@ -230,18 +225,9 @@ LogicalResult MemRefSliceOp::verify() {
230225
return emitOpError("Layout mismatch: got ")
231226
<< target_layout << ", expected " << expected_layout << ".";
232227
}
233-
} else {
234-
bool is_target_layout_identity_map =
235-
isa<AffineMapAttr>(target_layout) && target_layout.isIdentity();
236-
if (!is_target_layout_identity_map && target_layout != source_layout) {
237-
return emitOpError(
238-
"Layouts must match if the target layout is not an identity map.");
239-
}
240-
}
241-
if (getDynamicSizes().size() != target_type.getNumDynamicDims()) {
228+
} else if (target_layout != source_layout) {
242229
return emitOpError(
243-
"Number of provided dynamic dimensions sizes must match the number of "
244-
"dynamic dimensions in the target type.");
230+
"Layouts must match if the target layout is not strided.");
245231
}
246232
return success();
247233
}

0 commit comments

Comments
 (0)