@@ -50,6 +50,199 @@ namespace mlir::tpu {

 namespace {

+void optimizeLoadReshape(int hardware_generation,
+                         std::array<int64_t, 2> target_shape,
+                         Operation& raw_op) {
+  // Below, we try to look for reshapes that flatten multiple dims into the
+  // lane dimension. If the source of the reshape originates from a load of a
+  // ref with a 128 minor dimension (effectively untiled), we can replace the
+  // load/reshape sequence with an efficient strided load. In essence, the
+  // strided load creates vregs with a narrow slice along the target minor
+  // dimension, but with the 2nd minor dim after the reshape already in
+  // sublanes. The results of the strided loads are concatenated to form the
+  // final vector result.
+  //
+  // A little extra care needs to be applied to packed types, which we handle
+  // by briefly extending to 32-bit and repacking them after concatenation.
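+  //
+  // For illustration only (the shapes here are hypothetical): with f32 data,
+  //   %v = vector.load %ref[...] : memref<4x8x128xf32>, vector<4x8x128xf32>
+  //   %r = vector.shape_cast %v : vector<4x8x128xf32> to vector<4x1024xf32>
+  // can instead be assembled from 8 strided loads of vector<4x128xi32> taken
+  // from an i32 view of the ref flattened to (32, 128), each with a sublane
+  // stride of 8, concatenated along the minor dimension.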
+  TypedValue<VectorType> src;
+  VectorType tgt_ty;
+  if (auto op = dyn_cast<tpu::ReshapeOp>(&raw_op)) {
+    src = op.getSource();
+    tgt_ty = op.getResult().getType();
+  } else if (auto op = dyn_cast<vector::ShapeCastOp>(&raw_op)) {
+    src = op.getSource();
+    tgt_ty = op.getResult().getType();
+  } else {
+    return;
+  }
+  VectorType src_ty = src.getType();
+  if (src_ty.getRank() < 2 || tgt_ty.getRank() < 1) {
+    return;
+  }
+  const int bitwidth = src_ty.getElementTypeBitWidth();
+  const int packing = 32 / bitwidth;
+  if (hardware_generation < 4 && packing > 1) {
+    return;
+  }
+
+  auto load_op = dyn_cast_if_present<vector::LoadOp>(src.getDefiningOp());
+  // This rewrite might not be profitable if the load has other users.
+  if (!load_op || !load_op.getBase().hasOneUse()) {
+    return;
+  }
+
+  TypedValue<MemRefType> ref = load_op.getBase();
+  MemRefType ref_ty = getMemRefType(ref);
+  // The reshape below might be invalid if the memref is not contiguous, but it
+  // is an overly conservative check (we don't need all dims to be contiguous).
+  if (!isContiguousMemref(ref)) {
+    return;
+  }
+
+  const int64_t lane = target_shape[1];
+  auto src_shape = src_ty.getShape();
+  auto tgt_shape = tgt_ty.getShape();
+  // Only handle the cases where the minor dim starts out as the number of
+  // lanes and we fold at least the second minor dim into it, in a way that
+  // changes its shape.
+  if (src_shape.back() != lane ||
+      tgt_shape.back() % (packing * lane) != 0 ||
+      tgt_shape.back() == src_shape.back() ||
+      tgt_shape.back() < llvm::product_of(src_shape.take_back(2))) {
+    return;
+  }
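+  // For example (hypothetical shapes, f32 so packing == 1): folding
+  // (4, 8, 128) into (4, 1024) passes these checks, while (4, 8, 128) into
+  // (32, 128) (minor dim unchanged) or into (4, 2, 512) (only part of the
+  // second minor dim folded) bails out here.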
+
+  // We don't handle memrefs with padding.
+  auto tiled_layout = dyn_cast<tpu::TiledLayoutAttr>(ref_ty.getLayout());
+  if (!tiled_layout || tiled_layout.getTiles().empty()) {
+    return;
+  }
+  ArrayRef<int64_t> front_tile = tiled_layout.getTiles().front().dimensions();
+  ArrayRef<int64_t> ref_tiled_shape =
+      ref_ty.getShape().take_back(front_tile.size());
+  for (int i = 0; i < front_tile.size(); ++i) {
+    if (ref_tiled_shape[i] % front_tile[i]) {
+      return;
+    }
+  }
+
+  // NOTE: We could generalize this to allow flattening only part of a
+  // dimension.
+  int folded_dims = 0;
+  {
+    int suffix_size = 1;
+    auto sizes_it = src_shape.rbegin();
+    while (suffix_size < tgt_shape.back()) {
+      suffix_size *= *(sizes_it++);
+    }
+    // Make sure that the minor dim is folded only from entire major dims, not
+    // from a part of some minor dim.
+    if (suffix_size != tgt_shape.back()) {
+      return;
+    }
+    folded_dims = sizes_it - src_shape.rbegin();
+  }
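+  // E.g. (hypothetical): for a (2, 4, 8, 128) source, a target minor dim of
+  // 1024 grows the suffix as 1 -> 128 -> 1024, so folded_dims == 2; a target
+  // minor dim of 4096 would give folded_dims == 3, while 512 would fail the
+  // suffix check above because it splits the second minor dim.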
+  DCHECK_GE(folded_dims, 2);  // Should fold at least 2nd minor into minor.
+
+  // We don't handle slicing in the folded dims at the moment.
+  if (ref_ty.getShape().take_back(folded_dims) !=
+      src_ty.getShape().take_back(folded_dims)) {
+    return;
+  }
+
+  Location loc = raw_op.getLoc();
+  ImplicitLocOpBuilder b(loc, &raw_op);
+
+  // Flatten the untiled dims into second minor and bitcast to i32.
+  // NOTE: Source vector shape might be different from ref shape when slicing.
+  SmallVector<int64_t> mem_shape(ref_ty.getShape().drop_back(folded_dims));
+  if (mem_shape.empty()) {
+    mem_shape.push_back(1);
+  }
+  mem_shape.back() *= tgt_shape.back() / lane;
+  mem_shape.push_back(lane);
+  Value reshaped_ref = b.create<tpu::MemRefReshapeOp>(
+      MemRefType::get(mem_shape, ref_ty.getElementType()), ref);
+  *(mem_shape.end() - 2) /= packing;
+  Value i32_view = b.create<tpu::MemRefBitcastOp>(
+      MemRefType::get(mem_shape, b.getI32Type()), reshaped_ref);
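+  // E.g. (hypothetical): a (4, 8, 128) bf16 ref reshaped to (4, 1024) is
+  // first viewed as a (32, 128) bf16 memref and then bitcast to a (16, 128)
+  // i32 memref (packing == 2 folds pairs of rows into one i32 row).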
+
+  // Define the shape of the small i32 chunk we will load in each iteration.
+  // TODO(b/458291444): The loads we emit here might use suboptimal shapes and
+  // we could do better by folding some dims (as much as slicing allows).
+  SmallVector<int64_t> chunk_shape(src_shape.drop_back(folded_dims));
+  if (chunk_shape.empty()) {
+    chunk_shape.push_back(1);
+  }
+  chunk_shape.push_back(lane);
+  VectorType chunk_ty = VectorType::get(chunk_shape, b.getI32Type());
+
+  SmallVector<int32_t> strides(mem_shape.size(), 1);
+  const int64_t sublane_prod = tgt_shape.back() / lane;
+  const int64_t stride = sublane_prod / packing;
+  *(strides.end() - 2) = stride;
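+  // E.g. (hypothetical): for a target minor dim of 1024 and 128 lanes,
+  // sublane_prod == 8; with bf16 (packing == 2) the stride is 4, so we emit
+  // 4 strided i32 loads and recover 8 logical chunks after unpacking below.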
+
+  // Reuse indices from the original load for the prefix.
+  auto indices = load_op.getIndices();
+  SmallVector<Value> idxs(indices.drop_back(folded_dims));
+  if (idxs.empty()) {
+    idxs.push_back(IdxConst(0, b, loc));
+  }
+  Value split_base_idx =
+      b.create<arith::MulIOp>(idxs.back(), IdxConst(stride, b, loc));
+  idxs.push_back(IdxConst(0, b, loc));
+
+  SmallVector<Value> unpacked_chunks;
+  unpacked_chunks.reserve(stride * packing);
+  for (int i = 0; i < stride; ++i) {
+    *(idxs.end() - 2) =
+        b.create<arith::AddIOp>(split_base_idx, IdxConst(i, b, loc));
+    Value chunk =
+        b.create<tpu::StridedLoadOp>(chunk_ty, i32_view, idxs, strides);
+    // Unpack elements from i32 if necessary.
+    for (int p = 0; p < packing; ++p) {
+      unpacked_chunks.push_back(b.create<arith::ShRUIOp>(
+          chunk.getType(), chunk, I32Const(p * bitwidth, chunk_shape, b, loc)));
+    }
+  }
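+  // For packed types (hypothetically bf16, packing == 2), each i32 element of
+  // a chunk holds two 16-bit values; shifting right by 0 and by 16 moves each
+  // of them into the low bits, and the arith.trunci below keeps only those
+  // low bits, so every i32 chunk yields `packing` logical chunks.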
+
+  Value unpacked_flat;
+  if (unpacked_chunks.size() == 1) {
+    unpacked_flat = unpacked_chunks.front();
+  } else {
+    SmallVector<int64_t> concat_shape(src_shape.drop_back(folded_dims));
+    if (concat_shape.empty()) {
+      concat_shape.push_back(1);
+    }
+    concat_shape.push_back(tgt_shape.back());
+    unpacked_flat = b.create<tpu::ConcatenateOp>(
+        VectorType::get(concat_shape, b.getI32Type()), unpacked_chunks,
+        concat_shape.size() - 1);
+  }
+
+  Value result = unpacked_flat;
+  if (packing > 1) {  // Pack back, if needed.
+    result = b.create<arith::TruncIOp>(
+        VectorType::get(cast<VectorType>(result.getType()).getShape(),
+                        b.getIntegerType(bitwidth)),
+        result);
+  }
+  // Bitcast to the target element type, if needed.
+  if (cast<VectorType>(result.getType()).getElementType() !=
+      tgt_ty.getElementType()) {
+    result = b.create<arith::BitcastOp>(
+        VectorType::get(cast<VectorType>(result.getType()).getShape(),
+                        tgt_ty.getElementType()),
+        result);
+  }
+  // Apply the reshape to major dims, if needed.
+  if (cast<VectorType>(result.getType()).getShape() != tgt_ty.getShape()) {
+    result = b.create<tpu::ReshapeOp>(tgt_ty, result);
+  }
+  DCHECK_EQ(result.getType(), tgt_ty);
+
+  raw_op.replaceAllUsesWith(ValueRange{result});
+  raw_op.erase();
+}
+
 void optimizeStore(int hardware_generation, std::array<int64_t, 2> target_shape,
                     Operation& raw_op) {
   // Fuses a vector.shape_cast (that expands dimensions) into a subsequent
@@ -417,6 +610,8 @@ struct PreCanonicalizationOptimizationPass
         }
       } else if (isa<vector::StoreOp, tpu::VectorStoreOp>(op)) {
         optimizeStore(hardware_generation_, target_shape_, *op);
+      } else if (isa<vector::ShapeCastOp, tpu::ReshapeOp>(op)) {
+        optimizeLoadReshape(hardware_generation_, target_shape_, *op);
       }
     });
   }