Skip to content

Commit facc4ac

Browse files
committed
[mlir][presburger] Optimize the compilation time for calculating bounds of an Integer Relation
IntegerRelation uses Fourier-Motzkin elimination and Gaussian elimination to simplify constraints. These methods may repeatedly perform calculations and elimination on irrelevant variables. Preemptively eliminating irrelevant variables and their associated constraints can speed up up the calculation process.
1 parent 45495b5 commit facc4ac

File tree

2 files changed

+101
-0
lines changed

2 files changed

+101
-0
lines changed

mlir/include/mlir/Analysis/Presburger/IntegerRelation.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,34 @@ class IntegerRelation {
511511
void projectOut(unsigned pos, unsigned num);
512512
inline void projectOut(unsigned pos) { return projectOut(pos, 1); }
513513

514+
/// The function removes some constraints that do not impose any bound on the
515+
/// specified variable.
516+
///
517+
/// The set of constraints (equations/inequalities) can be modeled as an
518+
/// undirected graph where:
519+
/// 1. Variables are the nodes.
520+
/// 2. Constraints are the edges connecting those nodes.
521+
///
522+
/// Variables and constraints belonging to different connected components
523+
/// are irrelevant to each other. This property allows for safe pruning of
524+
/// constraints.
525+
///
526+
/// For example, given the following constraints:
527+
/// - Inequalities: (1) d0 + d1 > 0, (2) d1 >= 2, (3) d4 > 5
528+
/// - Equalities: (4) d3 + d4 = 1, (5) d0 - d2 = 3
529+
///
530+
/// These form two connected components:
531+
/// - Component 1: {d0, d1, d2} (related by constraints 1, 2, 5)
532+
/// - Component 2: {d3, d4} (related by constraint 4)
533+
///
534+
/// If we are querying the bound of variable `d0`, constraints related to
535+
/// Component 2 (e.g., constraints 3 and 4) can be safely pruned as they
536+
/// have no impact on the solution space of Component 1.
537+
/// This function prunes irrelevant constraints by identifying all variables
538+
/// and constraints that belong to the same connected component as the
539+
/// target variable.
540+
void pruneOrthogonalConstraints(unsigned pos);
541+
514542
/// Tries to fold the specified variable to a constant using a trivial
515543
/// equality detection; if successful, the constant is substituted for the
516544
/// variable everywhere in the constraint system and then removed from the

mlir/lib/Analysis/Presburger/IntegerRelation.cpp

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "mlir/Analysis/Presburger/Simplex.h"
2222
#include "mlir/Analysis/Presburger/Utils.h"
2323
#include "llvm/ADT/DenseMap.h"
24+
#include "llvm/ADT/DenseSet.h"
2425
#include "llvm/ADT/STLExtras.h"
2526
#include "llvm/ADT/Sequence.h"
2627
#include "llvm/ADT/SmallBitVector.h"
@@ -1723,12 +1724,84 @@ std::optional<DynamicAPInt> IntegerRelation::getConstantBoundOnDimSize(
17231724
return minDiff;
17241725
}
17251726

1727+
void IntegerRelation::pruneOrthogonalConstraints(unsigned pos) {
1728+
llvm::DenseSet<unsigned> relatedCols({pos}), relatedRows;
1729+
1730+
// Early exit if constraints is empty.
1731+
unsigned numConstraints = getNumConstraints();
1732+
if (numConstraints == 0)
1733+
return;
1734+
1735+
llvm::SmallVector<unsigned> rowStack, colStack({pos});
1736+
// The following code performs a graph traversal, starting from the target
1737+
// variable, to identify all variables(recorded in relatedCols) and
1738+
// constraints (recorded in relatedRows) belonging to the same connected
1739+
// component.
1740+
while (!rowStack.empty() || !colStack.empty()) {
1741+
if (!rowStack.empty()) {
1742+
unsigned currentRow = rowStack.pop_back_val();
1743+
// Push all variable that accociated to this constraints to relatedCols
1744+
// and colStack.
1745+
for (unsigned colIndex = 0; colIndex < getNumVars(); ++colIndex) {
1746+
if (currentRow < getNumInequalities()) {
1747+
if (atIneq(currentRow, colIndex) != 0 &&
1748+
relatedCols.insert(colIndex).second) {
1749+
colStack.push_back(colIndex);
1750+
}
1751+
} else {
1752+
if (atEq(currentRow - getNumInequalities(), colIndex) != 0 &&
1753+
relatedCols.insert(colIndex).second) {
1754+
colStack.push_back(colIndex);
1755+
}
1756+
}
1757+
}
1758+
} else {
1759+
unsigned currentCol = colStack.pop_back_val();
1760+
// Push all constraints that are associated with this variable to related
1761+
// rows and the row stack.
1762+
for (unsigned rowIndex = 0; rowIndex < numConstraints; ++rowIndex) {
1763+
if (rowIndex < getNumInequalities()) {
1764+
if (atIneq(rowIndex, currentCol) != 0 &&
1765+
relatedRows.insert(rowIndex).second) {
1766+
rowStack.push_back(rowIndex);
1767+
}
1768+
} else {
1769+
if (atEq(rowIndex - getNumInequalities(), currentCol) != 0 &&
1770+
relatedRows.insert(rowIndex).second) {
1771+
rowStack.push_back(rowIndex);
1772+
}
1773+
}
1774+
}
1775+
}
1776+
}
1777+
1778+
// Prune all constraints not related to target variable.
1779+
for (int constraintId = numConstraints - 1; constraintId >= 0;
1780+
--constraintId) {
1781+
if (relatedRows.contains(constraintId)) {
1782+
continue;
1783+
}
1784+
if (constraintId >= getNumInequalities()) {
1785+
removeEquality(constraintId - getNumInequalities());
1786+
} else {
1787+
removeInequality(constraintId);
1788+
}
1789+
}
1790+
}
1791+
17261792
template <bool isLower>
17271793
std::optional<DynamicAPInt>
17281794
IntegerRelation::computeConstantLowerOrUpperBound(unsigned pos) {
17291795
assert(pos < getNumVars() && "invalid position");
17301796
// Project to 'pos'.
1797+
// Prune orthogonal constraints to reduce unnecessary computations and
1798+
// accelerate the bound computation.
1799+
pruneOrthogonalConstraints(pos);
17311800
projectOut(0, pos);
1801+
1802+
// After projecting out values, more orthogonal constraints may be exposed.
1803+
// Prune these orthogonal constraints again.
1804+
pruneOrthogonalConstraints(0);
17321805
projectOut(1, getNumVars() - 1);
17331806
// Check if there's an equality equating the '0'^th variable to a constant.
17341807
int eqRowIdx = findEqualityToConstant(/*pos=*/0, /*symbolic=*/false);

0 commit comments

Comments
 (0)