
Commit f2562d4

apaszke authored and Google-ML-Automation committed
[Mosaic TPU][NFC] Move the load->reshape optimization to the pre canonicalization optimization pass
PiperOrigin-RevId: 829408887
1 parent df5d0a7 commit f2562d4

File tree

2 files changed: +197 -182 lines changed


jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc

Lines changed: 2 additions & 182 deletions
@@ -1485,188 +1485,8 @@ FailureOr<Value> canonicalize_shape_cast(const CanonicalizeContext& ctx,
 
 FailureOr<Value> canonicalize_reshape(const CanonicalizeContext &ctx,
                                       Operation &raw_op) {
-  // Below, we try to look for reshapes that flatten multiple dims into the
-  // lane dimension. If the source of the reshape originates from a load of a
-  // ref with 128 minor dimension (effectively untiled), we can replace the
-  // load/reshape sequence with an efficient strided load. In essence, the
-  // strided load creates vregs with a narrow slice along the target minor
-  // dimension, but with the 2nd minor dim after the reshape already in
-  // sublanes. The results of strided load can be concatenated to form the
-  // final vector result.
-  //
-  // A little extra care needs to be applied to packed types, which we handle by
-  // briefly extending to 32-bit and repacking them after concatenation.
-  auto op = cast<tpu::ReshapeOp>(raw_op);
-  TypedValue<VectorType> src = op.getSource();
-  VectorType src_ty = src.getType();
-  VectorType tgt_ty = op.getResult().getType();
-  if (src_ty.getRank() < 2 || tgt_ty.getRank() < 1) {
-    return raw_op.getResult(0);
-  }
-  const int bitwidth = src_ty.getElementTypeBitWidth();
-  const int packing = 32 / bitwidth;
-  if (ctx.hardware_generation < 4 && packing > 1) {
-    return raw_op.getResult(0);
-  }
-
-  auto load_op = dyn_cast_if_present<vector::LoadOp>(src.getDefiningOp());
-  // This rewrite might not be profitable if the load has other users.
-  if (!load_op || !load_op.getBase().hasOneUse()) {
-    return raw_op.getResult(0);
-  }
-
-  TypedValue<MemRefType> ref = load_op.getBase();
-  MemRefType ref_ty = getMemRefType(ref);
-  // The reshape below might be invalid if the memref is not contiguous, but it
-  // is an overly conservative check (we don't need all dims to be contiguous).
-  if (!isContiguousMemref(ref)) {
-    return raw_op.getResult(0);
-  }
-
-  const int64_t lane = ctx.target_shape[1];
-  auto src_shape = src_ty.getShape();
-  auto tgt_shape = tgt_ty.getShape();
-  // Only handle the cases where the minor dim starts out as the number of lanes
-  // and we fold at least the second minor dim into it, in a way that changes
-  // its shape.
-  if (src_shape.back() != lane ||
-      tgt_shape.back() % (packing * lane) != 0 ||
-      tgt_shape.back() == src_shape.back() ||
-      tgt_shape.back() < llvm::product_of(src_shape.take_back(2))) {
-    return raw_op.getResult(0);
-  }
-
-  // We don't handle memrefs with padding.
-  auto tiled_layout = dyn_cast<tpu::TiledLayoutAttr>(ref_ty.getLayout());
-  if (!tiled_layout || tiled_layout.getTiles().empty()) {
-    return raw_op.getResult(0);
-  }
-  ArrayRef<int64_t> front_tile = tiled_layout.getTiles().front().dimensions();
-  ArrayRef<int64_t> ref_tiled_shape =
-      ref_ty.getShape().take_back(front_tile.size());
-  for (int i = 0; i < front_tile.size(); ++i) {
-    if (ref_tiled_shape[i] % front_tile[i]) {
-      return raw_op.getResult(0);
-    }
-  }
-
-  // NOTE: We could generalize this to allow only flattening part of a dimension
-  int folded_dims = 0;
-  {
-    int suffix_size = 1;
-    auto sizes_it = src_shape.rbegin();
-    while (suffix_size < tgt_shape.back()) {
-      suffix_size *= *(sizes_it++);
-    }
-    // Make sure that the minor dim is folded only from entire major dims, not
-    // from a part of some minor dim.
-    if (suffix_size != tgt_shape.back()) {
-      return raw_op.getResult(0);
-    }
-    folded_dims = sizes_it - src_shape.rbegin();
-  }
-  DCHECK_GE(folded_dims, 2);  // Should fold at least 2nd minor into minor.
-
-  // We don't handle slicing in the folded dims at the moment.
-  if (ref_ty.getShape().take_back(folded_dims) !=
-      src_ty.getShape().take_back(folded_dims)) {
-    return raw_op.getResult(0);
-  }
-
-  // NOTE: Source vector shape might be different from ref shape when slicing.
-  SmallVector<int64_t> mem_shape(ref_ty.getShape().drop_back(folded_dims));
-  if (mem_shape.empty()) {
-    mem_shape.push_back(1);
-  }
-
-  CanonicalBuilder b(ctx, op->getLoc(), op.getOperation());
-  Location loc = op.getLoc();
-
-  // Flatten the untiled dims into second minor and bitcast to i32.
-  mem_shape.back() *= tgt_shape.back() / lane;
-  mem_shape.push_back(lane);
-  Value reshaped_ref = b.create<tpu::MemRefReshapeOp>(
-      MemRefType::get(mem_shape, ref_ty.getElementType()), ref);
-  *(mem_shape.end() - 2) /= packing;
-  Value i32_view = b.create<tpu::MemRefBitcastOp>(
-      MemRefType::get(mem_shape, b.getI32Type()), reshaped_ref);
-
-  // Define the shape of the small i32 chunk we will load in each iteration.
-  // TODO(b/458291444): The loads we emit here might use suboptimal shapes and
-  // we could do better by folding some dims (as much as slicing allows).
-  SmallVector<int64_t> chunk_shape(src_shape.drop_back(folded_dims));
-  if (chunk_shape.empty()) {
-    chunk_shape.push_back(1);
-  }
-  chunk_shape.push_back(lane);
-  VectorType chunk_ty = VectorType::get(chunk_shape, b.getI32Type());
-
-  SmallVector<int32_t> strides(mem_shape.size(), 1);
-  const int64_t sublane_prod = tgt_shape.back() / lane;
-  const int64_t stride = sublane_prod / packing;
-  *(strides.end() - 2) = stride;
-
-  // Reuse indices from the original load for the prefix.
-  auto indices = load_op.getIndices();
-  SmallVector<Value> idxs(indices.drop_back(folded_dims));
-  if (idxs.empty()) {
-    idxs.push_back(IdxConst(0, b, loc));
-  }
-  Value split_base_idx =
-      b.create<arith::MulIOp>(idxs.back(), IdxConst(stride, b, loc));
-  idxs.push_back(IdxConst(0, b, loc));
-
-  SmallVector<Value> unpacked_chunks;
-  unpacked_chunks.reserve(stride * packing);
-  for (int i = 0; i < stride; ++i) {
-    *(idxs.end() - 2) =
-        b.create<arith::AddIOp>(split_base_idx, IdxConst(i, b, loc));
-    Value chunk =
-        b.create<tpu::StridedLoadOp>(chunk_ty, i32_view, idxs, strides);
-    // Unpack elements from i32 if necessary.
-    for (int p = 0; p < packing; ++p) {
-      unpacked_chunks.push_back(b.create<arith::ShRUIOp>(
-          chunk.getType(), chunk, I32Const(p * bitwidth, chunk_shape, b, loc)));
-    }
-  }
-
-  Value unpacked_flat;
-  if (unpacked_chunks.size() == 1) {
-    unpacked_flat = unpacked_chunks.front();
-  } else {
-    SmallVector<int64_t> concat_shape(src_shape.drop_back(folded_dims));
-    if (concat_shape.empty()) {
-      concat_shape.push_back(1);
-    }
-    concat_shape.push_back(tgt_shape.back());
-    unpacked_flat = b.create<tpu::ConcatenateOp>(
-        VectorType::get(concat_shape, b.getI32Type()), unpacked_chunks,
-        concat_shape.size() - 1);
-  }
-
-  Value result = unpacked_flat;
-  if (packing > 1) {  // Pack back, if needed.
-    result = b.create<arith::TruncIOp>(
-        VectorType::get(cast<VectorType>(result.getType()).getShape(),
-                        b.getIntegerType(bitwidth)),
-        result);
-  }
-  // Bitcast to the target type, if needed.
-  if (cast<VectorType>(result.getType()) != tgt_ty.getElementType()) {
-    result = b.create<arith::BitcastOp>(
-        VectorType::get(cast<VectorType>(result.getType()).getShape(),
-                        tgt_ty.getElementType()),
-        result);
-  }
-  // Apply the reshape to major dims, if needed.
-  if (cast<VectorType>(result.getType()).getShape() != tgt_ty.getShape()) {
-    result = b.create<tpu::ReshapeOp>(tgt_ty, result);
-  }
-  DCHECK_EQ(result.getType(), tgt_ty);
-
-  op.replaceAllUsesWith(result);
-  op.erase();
-  return result;
+  // TODO(b/456092935): Better implementation for reshapes that (un)fold minor.
+  return raw_op.getResult(0);
 }
 
 FailureOr<Value> canonicalize_transpose(const CanonicalizeContext &ctx,
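
For intuition, here is a minimal standalone sketch of the quantities the relocated rewrite derives. The shapes are hypothetical, chosen only to satisfy the checks in the code above: a vector.load of shape 8x4x128 with i16 elements, reshaped to 8x512.

#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical example: load of vector<8x4x128xi16> reshaped to vector<8x512xi16>.
  const int64_t lane = 128;            // target_shape[1]
  const int64_t tgt_minor = 512;       // minormost dim after the reshape (4 * 128)
  const int bitwidth = 16;
  const int packing = 32 / bitwidth;   // 2 x i16 packed into each i32
  // Both trailing source dims (4 and 128) fold into the minor dim, so folded_dims == 2.
  const int64_t sublane_prod = tgt_minor / lane;  // 4
  const int64_t stride = sublane_prod / packing;  // 2 strided loads of an 8x128 i32 chunk
  const int64_t chunks = stride * packing;        // 4 unpacked chunks concatenated to 8x512 (i32)
  std::printf("packing=%d stride=%lld chunks=%lld\n", packing,
              static_cast<long long>(stride), static_cast<long long>(chunks));
  // The concatenated i32 vector is then truncated back to 16 bits and, if
  // needed, bitcast and reshaped to the target type.
  return 0;
}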

jaxlib/mosaic/dialect/tpu/transforms/pre_canonicalization_optimization.cc

Lines changed: 195 additions & 0 deletions
@@ -50,6 +50,199 @@ namespace mlir::tpu {
 
 namespace {
 
+void optimizeLoadReshape(int hardware_generation,
+                         std::array<int64_t, 2> target_shape,
+                         Operation& raw_op) {
+  // Below, we try to look for reshapes that flatten multiple dims into the
+  // lane dimension. If the source of the reshape originates from a load of a
+  // ref with 128 minor dimension (effectively untiled), we can replace the
+  // load/reshape sequence with an efficient strided load. In essence, the
+  // strided load creates vregs with a narrow slice along the target minor
+  // dimension, but with the 2nd minor dim after the reshape already in
+  // sublanes. The results of strided load can be concatenated to form the
+  // final vector result.
+  //
+  // A little extra care needs to be applied to packed types, which we handle by
+  // briefly extending to 32-bit and repacking them after concatenation.
+  TypedValue<VectorType> src;
+  VectorType tgt_ty;
+  if (auto op = dyn_cast<tpu::ReshapeOp>(&raw_op)) {
+    src = op.getSource();
+    tgt_ty = op.getResult().getType();
+  } else if (auto op = dyn_cast<vector::ShapeCastOp>(&raw_op)) {
+    src = op.getSource();
+    tgt_ty = op.getResult().getType();
+  } else {
+    return;
+  }
+  VectorType src_ty = src.getType();
+  if (src_ty.getRank() < 2 || tgt_ty.getRank() < 1) {
+    return;
+  }
+  const int bitwidth = src_ty.getElementTypeBitWidth();
+  const int packing = 32 / bitwidth;
+  if (hardware_generation < 4 && packing > 1) {
+    return;
+  }
+
+  auto load_op = dyn_cast_if_present<vector::LoadOp>(src.getDefiningOp());
+  // This rewrite might not be profitable if the load has other users.
+  if (!load_op || !load_op.getBase().hasOneUse()) {
+    return;
+  }
+
+  TypedValue<MemRefType> ref = load_op.getBase();
+  MemRefType ref_ty = getMemRefType(ref);
+  // The reshape below might be invalid if the memref is not contiguous, but it
+  // is an overly conservative check (we don't need all dims to be contiguous).
+  if (!isContiguousMemref(ref)) {
+    return;
+  }
+
+  const int64_t lane = target_shape[1];
+  auto src_shape = src_ty.getShape();
+  auto tgt_shape = tgt_ty.getShape();
+  // Only handle the cases where the minor dim starts out as the number of lanes
+  // and we fold at least the second minor dim into it, in a way that changes
+  // its shape.
+  if (src_shape.back() != lane ||
+      tgt_shape.back() % (packing * lane) != 0 ||
+      tgt_shape.back() == src_shape.back() ||
+      tgt_shape.back() < llvm::product_of(src_shape.take_back(2))) {
+    return;
+  }
+
+  // We don't handle memrefs with padding.
+  auto tiled_layout = dyn_cast<tpu::TiledLayoutAttr>(ref_ty.getLayout());
+  if (!tiled_layout || tiled_layout.getTiles().empty()) {
+    return;
+  }
+  ArrayRef<int64_t> front_tile = tiled_layout.getTiles().front().dimensions();
+  ArrayRef<int64_t> ref_tiled_shape =
+      ref_ty.getShape().take_back(front_tile.size());
+  for (int i = 0; i < front_tile.size(); ++i) {
+    if (ref_tiled_shape[i] % front_tile[i]) {
+      return;
+    }
+  }
+
+  // NOTE: We could generalize this to allow only flattening part of a dimension
+  int folded_dims = 0;
+  {
+    int suffix_size = 1;
+    auto sizes_it = src_shape.rbegin();
+    while (suffix_size < tgt_shape.back()) {
+      suffix_size *= *(sizes_it++);
+    }
+    // Make sure that the minor dim is folded only from entire major dims, not
+    // from a part of some minor dim.
+    if (suffix_size != tgt_shape.back()) {
+      return;
+    }
+    folded_dims = sizes_it - src_shape.rbegin();
+  }
+  DCHECK_GE(folded_dims, 2);  // Should fold at least 2nd minor into minor.
+
+  // We don't handle slicing in the folded dims at the moment.
+  if (ref_ty.getShape().take_back(folded_dims) !=
+      src_ty.getShape().take_back(folded_dims)) {
+    return;
+  }
+
+  Location loc = raw_op.getLoc();
+  ImplicitLocOpBuilder b(loc, &raw_op);
+
+  // Flatten the untiled dims into second minor and bitcast to i32.
+  // NOTE: Source vector shape might be different from ref shape when slicing.
+  SmallVector<int64_t> mem_shape(ref_ty.getShape().drop_back(folded_dims));
+  if (mem_shape.empty()) {
+    mem_shape.push_back(1);
+  }
+  mem_shape.back() *= tgt_shape.back() / lane;
+  mem_shape.push_back(lane);
+  Value reshaped_ref = b.create<tpu::MemRefReshapeOp>(
+      MemRefType::get(mem_shape, ref_ty.getElementType()), ref);
+  *(mem_shape.end() - 2) /= packing;
+  Value i32_view = b.create<tpu::MemRefBitcastOp>(
+      MemRefType::get(mem_shape, b.getI32Type()), reshaped_ref);
+
+  // Define the shape of the small i32 chunk we will load in each iteration.
+  // TODO(b/458291444): The loads we emit here might use suboptimal shapes and
+  // we could do better by folding some dims (as much as slicing allows).
+  SmallVector<int64_t> chunk_shape(src_shape.drop_back(folded_dims));
+  if (chunk_shape.empty()) {
+    chunk_shape.push_back(1);
+  }
+  chunk_shape.push_back(lane);
+  VectorType chunk_ty = VectorType::get(chunk_shape, b.getI32Type());
+
+  SmallVector<int32_t> strides(mem_shape.size(), 1);
+  const int64_t sublane_prod = tgt_shape.back() / lane;
+  const int64_t stride = sublane_prod / packing;
+  *(strides.end() - 2) = stride;
+
+  // Reuse indices from the original load for the prefix.
+  auto indices = load_op.getIndices();
+  SmallVector<Value> idxs(indices.drop_back(folded_dims));
+  if (idxs.empty()) {
+    idxs.push_back(IdxConst(0, b, loc));
+  }
+  Value split_base_idx =
+      b.create<arith::MulIOp>(idxs.back(), IdxConst(stride, b, loc));
+  idxs.push_back(IdxConst(0, b, loc));
+
+  SmallVector<Value> unpacked_chunks;
+  unpacked_chunks.reserve(stride * packing);
+  for (int i = 0; i < stride; ++i) {
+    *(idxs.end() - 2) =
+        b.create<arith::AddIOp>(split_base_idx, IdxConst(i, b, loc));
+    Value chunk =
+        b.create<tpu::StridedLoadOp>(chunk_ty, i32_view, idxs, strides);
+    // Unpack elements from i32 if necessary.
+    for (int p = 0; p < packing; ++p) {
+      unpacked_chunks.push_back(b.create<arith::ShRUIOp>(
+          chunk.getType(), chunk, I32Const(p * bitwidth, chunk_shape, b, loc)));
+    }
+  }
+
+  Value unpacked_flat;
+  if (unpacked_chunks.size() == 1) {
+    unpacked_flat = unpacked_chunks.front();
+  } else {
+    SmallVector<int64_t> concat_shape(src_shape.drop_back(folded_dims));
+    if (concat_shape.empty()) {
+      concat_shape.push_back(1);
+    }
+    concat_shape.push_back(tgt_shape.back());
+    unpacked_flat = b.create<tpu::ConcatenateOp>(
+        VectorType::get(concat_shape, b.getI32Type()), unpacked_chunks,
+        concat_shape.size() - 1);
+  }
+
+  Value result = unpacked_flat;
+  if (packing > 1) {  // Pack back, if needed.
+    result = b.create<arith::TruncIOp>(
+        VectorType::get(cast<VectorType>(result.getType()).getShape(),
+                        b.getIntegerType(bitwidth)),
+        result);
+  }
+  // Bitcast to the target type, if needed.
+  if (cast<VectorType>(result.getType()) != tgt_ty.getElementType()) {
+    result = b.create<arith::BitcastOp>(
+        VectorType::get(cast<VectorType>(result.getType()).getShape(),
+                        tgt_ty.getElementType()),
+        result);
+  }
+  // Apply the reshape to major dims, if needed.
+  if (cast<VectorType>(result.getType()).getShape() != tgt_ty.getShape()) {
+    result = b.create<tpu::ReshapeOp>(tgt_ty, result);
+  }
+  DCHECK_EQ(result.getType(), tgt_ty);
+
+  raw_op.replaceAllUsesWith(ValueRange{result});
+  raw_op.erase();
+}
+
 void optimizeStore(int hardware_generation, std::array<int64_t, 2> target_shape,
                    Operation& raw_op) {
   // Fuses a vector.shape_cast (that expands dimensions) into a subsequent
@@ -417,6 +610,8 @@ struct PreCanonicalizationOptimizationPass
         }
       } else if (isa<vector::StoreOp, tpu::VectorStoreOp>(op)) {
        optimizeStore(hardware_generation_, target_shape_, *op);
+      } else if (isa<vector::ShapeCastOp, tpu::ReshapeOp>(op)) {
+        optimizeLoadReshape(hardware_generation_, target_shape_, *op);
      }
    });
  }
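
The i32 unpack step above (an arith::ShRUIOp by p * bitwidth followed by an arith::TruncIOp back to the element width) has a simple scalar analogue. A minimal sketch, assuming two 16-bit elements packed into one 32-bit word with element 0 in the low bits, as the shift amounts in the code suggest:

#include <cstdint>
#include <cstdio>

int main() {
  // One 32-bit lane holding two packed 16-bit elements (packing = 32 / 16 = 2).
  const uint32_t lane = 0xBBBBAAAAu;  // element 0 = 0xAAAA, element 1 = 0xBBBB
  const int bitwidth = 16;
  const int packing = 32 / bitwidth;
  for (int p = 0; p < packing; ++p) {
    // Shift right by p * bitwidth, then truncate to the original element width.
    const uint16_t elem = static_cast<uint16_t>(lane >> (p * bitwidth));
    std::printf("element %d = 0x%04X\n", p, elem);
  }
  return 0;
}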

0 commit comments
