Skip to content

Commit 23289a8

Browse files
cxy-1993kcloudy0717
authored andcommitted
[MLIR][Presburger] optimize bound computation by pruning orthogonal constraints (llvm#164199)
IntegerRelation uses Fourier-Motzkin elimination and Gaussian elimination to simplify constraints. These methods may repeatedly perform calculations and elimination on irrelevant variables. Preemptively eliminating irrelevant variables and their associated constraints can speed up up the calculation process.
1 parent e6662ff commit 23289a8

File tree

2 files changed

+94
-4
lines changed

2 files changed

+94
-4
lines changed

mlir/include/mlir/Analysis/Presburger/IntegerRelation.h

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -210,12 +210,12 @@ class IntegerRelation {
210210
return getNumInequalities() + getNumEqualities();
211211
}
212212

213-
// Unified indexing into the constraints. Index into the inequalities
214-
// if i < getNumInequalities() and into the equalities otherwise.
215-
inline DynamicAPInt atConstraint(unsigned i, unsigned j) const {
213+
/// Unified indexing into the constraints. Index into the inequalities
214+
/// if i < getNumInequalities() and into the equalities otherwise.
215+
inline int64_t atConstraint64(unsigned i, unsigned j) const {
216216
assert(i < getNumConstraints());
217217
unsigned numIneqs = getNumInequalities();
218-
return i < numIneqs ? atIneq(i, j) : atEq(i - numIneqs, j);
218+
return i < numIneqs ? atIneq64(i, j) : atEq64(i - numIneqs, j);
219219
}
220220
inline DynamicAPInt &atConstraint(unsigned i, unsigned j) {
221221
assert(i < getNumConstraints());
@@ -365,6 +365,7 @@ class IntegerRelation {
365365

366366
void removeEquality(unsigned pos);
367367
void removeInequality(unsigned pos);
368+
void removeConstraint(unsigned pos);
368369

369370
/// Remove the (in)equalities at positions [start, end).
370371
void removeEqualityRange(unsigned start, unsigned end);
@@ -525,6 +526,34 @@ class IntegerRelation {
525526
void projectOut(unsigned pos, unsigned num);
526527
inline void projectOut(unsigned pos) { return projectOut(pos, 1); }
527528

529+
/// The function removes some constraints that do not impose any bound on the
530+
/// specified variable.
531+
///
532+
/// The set of constraints (equations/inequalities) can be modeled as an
533+
/// undirected graph where:
534+
/// 1. Variables are the nodes.
535+
/// 2. Constraints are the edges connecting those nodes.
536+
///
537+
/// Variables and constraints belonging to different connected components
538+
/// are irrelevant to each other. This property allows for safe pruning of
539+
/// constraints.
540+
///
541+
/// For example, given the following constraints:
542+
/// - Inequalities: (1) d0 + d1 > 0, (2) d1 >= 2, (3) d4 > 5
543+
/// - Equalities: (4) d3 + d4 = 1, (5) d0 - d2 = 3
544+
///
545+
/// These form two connected components:
546+
/// - Component 1: {d0, d1, d2} (related by constraints 1, 2, 5)
547+
/// - Component 2: {d3, d4} (related by constraint 4)
548+
///
549+
/// If we are querying the bound of variable `d0`, constraints related to
550+
/// Component 2 (e.g., constraints 3 and 4) can be safely pruned as they
551+
/// have no impact on the solution space of Component 1.
552+
/// This function prunes irrelevant constraints by identifying all variables
553+
/// and constraints that belong to the same connected component as the
554+
/// target variable.
555+
void pruneOrthogonalConstraints(unsigned pos);
556+
528557
/// Tries to fold the specified variable to a constant using a trivial
529558
/// equality detection; if successful, the constant is substituted for the
530559
/// variable everywhere in the constraint system and then removed from the

mlir/lib/Analysis/Presburger/IntegerRelation.cpp

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "mlir/Analysis/Presburger/Simplex.h"
2222
#include "mlir/Analysis/Presburger/Utils.h"
2323
#include "llvm/ADT/DenseMap.h"
24+
#include "llvm/ADT/DenseSet.h"
2425
#include "llvm/ADT/STLExtras.h"
2526
#include "llvm/ADT/Sequence.h"
2627
#include "llvm/ADT/SmallBitVector.h"
@@ -442,6 +443,14 @@ void IntegerRelation::removeInequality(unsigned pos) {
442443
inequalities.removeRow(pos);
443444
}
444445

446+
void IntegerRelation::removeConstraint(unsigned pos) {
447+
if (pos >= getNumInequalities()) {
448+
removeEquality(pos - getNumInequalities());
449+
} else {
450+
removeInequality(pos);
451+
}
452+
}
453+
445454
void IntegerRelation::removeEqualityRange(unsigned start, unsigned end) {
446455
if (start >= end)
447456
return;
@@ -1742,12 +1751,64 @@ std::optional<DynamicAPInt> IntegerRelation::getConstantBoundOnDimSize(
17421751
return minDiff;
17431752
}
17441753

1754+
void IntegerRelation::pruneOrthogonalConstraints(unsigned pos) {
1755+
llvm::DenseSet<unsigned> relatedCols({pos}), relatedRows;
1756+
1757+
// Early exit if constraints is empty.
1758+
unsigned numConstraints = getNumConstraints();
1759+
if (numConstraints == 0)
1760+
return;
1761+
1762+
llvm::SmallVector<unsigned> rowStack, colStack({pos});
1763+
// The following code performs a graph traversal, starting from the target
1764+
// variable, to identify all variables(recorded in relatedCols) and
1765+
// constraints (recorded in relatedRows) belonging to the same connected
1766+
// component.
1767+
while (!rowStack.empty() || !colStack.empty()) {
1768+
if (!rowStack.empty()) {
1769+
unsigned currentRow = rowStack.pop_back_val();
1770+
// Push all variable that accociated to this constraints to relatedCols
1771+
// and colStack.
1772+
for (unsigned colIndex = 0; colIndex < getNumVars(); ++colIndex) {
1773+
if (atConstraint(currentRow, colIndex) != 0 &&
1774+
relatedCols.insert(colIndex).second) {
1775+
colStack.push_back(colIndex);
1776+
}
1777+
}
1778+
} else {
1779+
unsigned currentCol = colStack.pop_back_val();
1780+
// Push all constraints that are associated with this variable to related
1781+
// rows and the row stack.
1782+
for (unsigned rowIndex = 0; rowIndex < numConstraints; ++rowIndex) {
1783+
if (atConstraint(rowIndex, currentCol) != 0 &&
1784+
relatedRows.insert(rowIndex).second) {
1785+
rowStack.push_back(rowIndex);
1786+
}
1787+
}
1788+
}
1789+
}
1790+
1791+
// Prune all constraints not related to target variable.
1792+
for (int constraintId = numConstraints - 1; constraintId >= 0;
1793+
--constraintId) {
1794+
if (!relatedRows.contains(constraintId))
1795+
removeConstraint((unsigned)constraintId);
1796+
}
1797+
}
1798+
17451799
template <bool isLower>
17461800
std::optional<DynamicAPInt>
17471801
IntegerRelation::computeConstantLowerOrUpperBound(unsigned pos) {
17481802
assert(pos < getNumVars() && "invalid position");
17491803
// Project to 'pos'.
1804+
// Prune orthogonal constraints to reduce unnecessary computations and
1805+
// accelerate the bound computation.
1806+
pruneOrthogonalConstraints(pos);
17501807
projectOut(0, pos);
1808+
1809+
// After projecting out values, more orthogonal constraints may be exposed.
1810+
// Prune these orthogonal constraints again.
1811+
pruneOrthogonalConstraints(0);
17511812
projectOut(1, getNumVars() - 1);
17521813
// Check if there's an equality equating the '0'^th variable to a constant.
17531814
int eqRowIdx = findEqualityToConstant(/*pos=*/0, /*symbolic=*/false);

0 commit comments

Comments
 (0)