Skip to content

Commit 957da30

Browse files
committed
[mlir][presburger] optimize bound computation by pruning orthogonal constraints
IntegerRelation uses Fourier-Motzkin elimination and Gaussian elimination to simplify constraints. These methods may repeatedly perform calculations and elimination on irrelevant variables. Preemptively eliminating irrelevant variables and their associated constraints can speed up up the calculation process.
1 parent 45495b5 commit 957da30

File tree

2 files changed

+104
-0
lines changed

2 files changed

+104
-0
lines changed

mlir/include/mlir/Analysis/Presburger/IntegerRelation.h

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,19 @@ class IntegerRelation {
205205
return inequalities(i, j);
206206
}
207207

208+
/// Unified indexing into the constraints. Index into the inequalities
209+
/// if i < getNumInequalities() and into the equalities otherwise.
210+
inline int64_t atConstraint64(unsigned i, unsigned j) const {
211+
assert(i < getNumConstraints());
212+
unsigned numIneqs = getNumInequalities();
213+
return i < numIneqs ? atIneq64(i, j) : atEq64(i - numIneqs, j);
214+
}
215+
inline DynamicAPInt &atConstraint(unsigned i, unsigned j) {
216+
assert(i < getNumConstraints());
217+
unsigned numIneqs = getNumInequalities();
218+
return i < numIneqs ? atIneq(i, j) : atEq(i - numIneqs, j);
219+
}
220+
208221
unsigned getNumConstraints() const {
209222
return getNumInequalities() + getNumEqualities();
210223
}
@@ -351,6 +364,7 @@ class IntegerRelation {
351364

352365
void removeEquality(unsigned pos);
353366
void removeInequality(unsigned pos);
367+
void removeConstraint(unsigned pos);
354368

355369
/// Remove the (in)equalities at positions [start, end).
356370
void removeEqualityRange(unsigned start, unsigned end);
@@ -511,6 +525,34 @@ class IntegerRelation {
511525
void projectOut(unsigned pos, unsigned num);
512526
inline void projectOut(unsigned pos) { return projectOut(pos, 1); }
513527

528+
/// The function removes some constraints that do not impose any bound on the
529+
/// specified variable.
530+
///
531+
/// The set of constraints (equations/inequalities) can be modeled as an
532+
/// undirected graph where:
533+
/// 1. Variables are the nodes.
534+
/// 2. Constraints are the edges connecting those nodes.
535+
///
536+
/// Variables and constraints belonging to different connected components
537+
/// are irrelevant to each other. This property allows for safe pruning of
538+
/// constraints.
539+
///
540+
/// For example, given the following constraints:
541+
/// - Inequalities: (1) d0 + d1 > 0, (2) d1 >= 2, (3) d4 > 5
542+
/// - Equalities: (4) d3 + d4 = 1, (5) d0 - d2 = 3
543+
///
544+
/// These form two connected components:
545+
/// - Component 1: {d0, d1, d2} (related by constraints 1, 2, 5)
546+
/// - Component 2: {d3, d4} (related by constraint 4)
547+
///
548+
/// If we are querying the bound of variable `d0`, constraints related to
549+
/// Component 2 (e.g., constraints 3 and 4) can be safely pruned as they
550+
/// have no impact on the solution space of Component 1.
551+
/// This function prunes irrelevant constraints by identifying all variables
552+
/// and constraints that belong to the same connected component as the
553+
/// target variable.
554+
void pruneOrthogonalConstraints(unsigned pos);
555+
514556
/// Tries to fold the specified variable to a constant using a trivial
515557
/// equality detection; if successful, the constant is substituted for the
516558
/// variable everywhere in the constraint system and then removed from the

mlir/lib/Analysis/Presburger/IntegerRelation.cpp

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "mlir/Analysis/Presburger/Simplex.h"
2222
#include "mlir/Analysis/Presburger/Utils.h"
2323
#include "llvm/ADT/DenseMap.h"
24+
#include "llvm/ADT/DenseSet.h"
2425
#include "llvm/ADT/STLExtras.h"
2526
#include "llvm/ADT/Sequence.h"
2627
#include "llvm/ADT/SmallBitVector.h"
@@ -441,6 +442,14 @@ void IntegerRelation::removeInequality(unsigned pos) {
441442
inequalities.removeRow(pos);
442443
}
443444

445+
void IntegerRelation::removeConstraint(unsigned pos) {
446+
if (pos >= (int)getNumInequalities()) {
447+
removeEquality(pos - getNumInequalities());
448+
} else {
449+
removeInequality(pos);
450+
}
451+
}
452+
444453
void IntegerRelation::removeEqualityRange(unsigned start, unsigned end) {
445454
if (start >= end)
446455
return;
@@ -1723,12 +1732,65 @@ std::optional<DynamicAPInt> IntegerRelation::getConstantBoundOnDimSize(
17231732
return minDiff;
17241733
}
17251734

1735+
void IntegerRelation::pruneOrthogonalConstraints(unsigned pos) {
1736+
llvm::DenseSet<unsigned> relatedCols({pos}), relatedRows;
1737+
1738+
// Early exit if constraints is empty.
1739+
unsigned numConstraints = getNumConstraints();
1740+
if (numConstraints == 0)
1741+
return;
1742+
1743+
llvm::SmallVector<unsigned> rowStack, colStack({pos});
1744+
// The following code performs a graph traversal, starting from the target
1745+
// variable, to identify all variables(recorded in relatedCols) and
1746+
// constraints (recorded in relatedRows) belonging to the same connected
1747+
// component.
1748+
while (!rowStack.empty() || !colStack.empty()) {
1749+
if (!rowStack.empty()) {
1750+
unsigned currentRow = rowStack.pop_back_val();
1751+
// Push all variable that accociated to this constraints to relatedCols
1752+
// and colStack.
1753+
for (unsigned colIndex = 0; colIndex < getNumVars(); ++colIndex) {
1754+
if (atConstraint(currentRow, colIndex) != 0 &&
1755+
relatedCols.insert(colIndex).second) {
1756+
colStack.push_back(colIndex);
1757+
}
1758+
}
1759+
} else {
1760+
unsigned currentCol = colStack.pop_back_val();
1761+
// Push all constraints that are associated with this variable to related
1762+
// rows and the row stack.
1763+
for (unsigned rowIndex = 0; rowIndex < numConstraints; ++rowIndex) {
1764+
if (atConstraint(rowIndex, currentCol) != 0 &&
1765+
relatedRows.insert(rowIndex).second) {
1766+
rowStack.push_back(rowIndex);
1767+
}
1768+
}
1769+
}
1770+
}
1771+
1772+
// Prune all constraints not related to target variable.
1773+
for (int constraintId = numConstraints - 1; constraintId >= 0;
1774+
--constraintId) {
1775+
if (!relatedRows.contains(constraintId)) {
1776+
removeConstraint(constraintId);
1777+
}
1778+
}
1779+
}
1780+
17261781
template <bool isLower>
17271782
std::optional<DynamicAPInt>
17281783
IntegerRelation::computeConstantLowerOrUpperBound(unsigned pos) {
17291784
assert(pos < getNumVars() && "invalid position");
17301785
// Project to 'pos'.
1786+
// Prune orthogonal constraints to reduce unnecessary computations and
1787+
// accelerate the bound computation.
1788+
pruneOrthogonalConstraints(pos);
17311789
projectOut(0, pos);
1790+
1791+
// After projecting out values, more orthogonal constraints may be exposed.
1792+
// Prune these orthogonal constraints again.
1793+
pruneOrthogonalConstraints(0);
17321794
projectOut(1, getNumVars() - 1);
17331795
// Check if there's an equality equating the '0'^th variable to a constant.
17341796
int eqRowIdx = findEqualityToConstant(/*pos=*/0, /*symbolic=*/false);

0 commit comments

Comments
 (0)