@@ -165,7 +165,7 @@ OpFoldResult BitcastVregOp::fold(FoldAdaptor adaptor) {
165165}
166166
167167LogicalResult MemRefSliceOp::verify () {
168- auto source_type = getMemRefType ( getMemRef ());
168+ auto source_type = getMemRef (). getType ( );
169169 auto target_type = getType ();
170170 auto source_layout = source_type.getLayout ();
171171 auto target_layout = target_type.getLayout ();
@@ -176,6 +176,11 @@ LogicalResult MemRefSliceOp::verify() {
176176 return emitOpError (
177177 " Only slicing of memrefs with static shapes is supported." );
178178 }
179+ if (getDynamicSizes ().size () != target_type.getNumDynamicDims ()) {
180+ return emitOpError (
181+ " Number of provided dynamic dimensions sizes must match the number of "
182+ " dynamic dimensions in the target type." );
183+ }
179184 auto source_shape = source_type.getShape ();
180185 bool is_semaphore =
181186 HasMemorySpace (source_type, tpu::MemorySpace::kSemaphoreMem );
@@ -191,21 +196,11 @@ LogicalResult MemRefSliceOp::verify() {
191196 }
192197 // TODO(apaszke): Check that the result has a smaller shape.
193198 // TODO(apaszke): Check that strides are equivalent.
194- // Source and target attributes may be different before propagation is done by
195- // the canonicalizer, so we allow this when attributes are "unset" in the
196- // target type. Note that MemRefType does not allow a null layout so we treat
197- // the default identity affine map as an "unset" value instead.
198- bool is_target_memory_space_provided = target_memory_space != nullptr ;
199- if (is_target_memory_space_provided &&
200- target_memory_space != source_type.getMemorySpace ()) {
199+ if (target_memory_space != source_type.getMemorySpace ()) {
201200 return emitOpError (
202201 " Memory spaces must match if the target memory space is provided." );
203202 }
204- if (isa<TiledLayoutAttr>(source_layout) &&
205- !isa<TiledLayoutAttr>(target_layout)) {
206- // TODO(slebedev): Remove this special-case once we move layout propagation
207- // to the infer-memref-layout pass.
208- } else if (isa<StridedLayoutAttr>(target_layout)) {
203+ if (isa<StridedLayoutAttr>(target_layout)) {
209204 SmallVector<int64_t > source_strides;
210205 int64_t source_offset;
211206 if (failed (
@@ -230,18 +225,9 @@ LogicalResult MemRefSliceOp::verify() {
230225 return emitOpError (" Layout mismatch: got " )
231226 << target_layout << " , expected " << expected_layout << " ." ;
232227 }
233- } else {
234- bool is_target_layout_identity_map =
235- isa<AffineMapAttr>(target_layout) && target_layout.isIdentity ();
236- if (!is_target_layout_identity_map && target_layout != source_layout) {
237- return emitOpError (
238- " Layouts must match if the target layout is not an identity map." );
239- }
240- }
241- if (getDynamicSizes ().size () != target_type.getNumDynamicDims ()) {
228+ } else if (target_layout != source_layout) {
242229 return emitOpError (
243- " Number of provided dynamic dimensions sizes must match the number of "
244- " dynamic dimensions in the target type." );
230+ " Layouts must match if the target layout is not strided." );
245231 }
246232 return success ();
247233}
0 commit comments