Skip to content

Commit 36a2b2f

Browse files
authored
[Refactor] Simplify index sign state handling in LegalizeNegativeIndex (tile-ai#1354)
This commit refines the logic for determining the sign state of indices in the LegalizeNegativeIndex transformation. It prioritizes vector patterns, specifically Ramp and Broadcast nodes, to avoid compile-time lane queries. The handling of scalar indices is also streamlined, ensuring clearer diagnostics when non-negativity cannot be proven. These changes enhance the robustness and clarity of index handling in the transformation pass.
1 parent 1e92d11 commit 36a2b2f

File tree

1 file changed

+30
-40
lines changed

1 file changed

+30
-40
lines changed

src/transform/legalize_negative_index.cc

Lines changed: 30 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -44,52 +44,28 @@ class NegativeIndexAnalyzer : public IRVisitorWithAnalyzer {
4444
PrimExpr simplified = analyzer_.Simplify(indices[i]);
4545
IndexSignState state = IndexSignState::kUnknown;
4646

47-
// Handle scalar indices with the standard analyzer
48-
if (simplified.dtype().lanes() == 1) {
49-
if (analyzer_.CanProve(simplified >= 0))
47+
// Handle vector patterns first to avoid querying lanes() on
48+
// scalable vectors (which is not allowed at compile-time).
49+
if (const auto *ramp = simplified.as<RampNode>()) {
50+
// For scalable vectors, we cannot rely on a constant lane count.
51+
// Use sufficient (but not necessary) conditions:
52+
// - If base >= 0 and stride >= 0, all lanes are non-negative.
53+
// - If base < 0 and stride <= 0, all lanes are negative.
54+
bool base_nonneg = analyzer_.CanProve(ramp->base >= 0);
55+
bool base_neg = analyzer_.CanProve(ramp->base < 0);
56+
bool stride_nonneg = analyzer_.CanProve(ramp->stride >= 0);
57+
bool stride_nonpos = analyzer_.CanProve(ramp->stride <= 0);
58+
59+
if (base_nonneg && stride_nonneg) {
5060
state = IndexSignState::kNonNegative;
51-
else if (analyzer_.CanProve(simplified < 0))
61+
} else if (base_neg && stride_nonpos) {
5262
state = IndexSignState::kNegative;
53-
else
54-
DLOG(WARNING)
55-
<< "LegalizeNegativeIndex: cannot prove non-negative index "
56-
<< simplified << " for buffer " << buffer_name << " (axis " << i
57-
<< ", index " + indices[i]->Script() + ").";
58-
}
59-
// Vector indices: try to reason about non-negativity/negativity
60-
// Common patterns are Ramp(base, stride, lanes) and Broadcast(value,
61-
// lanes).
62-
else if (const auto *ramp = simplified.as<RampNode>()) {
63-
// Compute a safe lower/upper bound for the vector lanes
64-
// lower_bound = base_min + min(0, stride_min) * (lanes - 1)
65-
// upper_bound = base_max + max(0, stride_max) * (lanes - 1)
66-
auto base_bound = analyzer_.const_int_bound(ramp->base);
67-
auto stride_bound = analyzer_.const_int_bound(ramp->stride);
68-
int lanes = *as_const_int(ramp->lanes);
69-
70-
int64_t base_min = base_bound->min_value;
71-
int64_t base_max = base_bound->max_value;
72-
int64_t s_min = stride_bound->min_value;
73-
int64_t s_max = stride_bound->max_value;
74-
75-
// Guard against overflow is not strictly necessary here because
76-
// bounds may be +/-inf represented by sentinel values.
77-
int64_t lower = base_min;
78-
if (s_min < 0)
79-
lower += s_min * (lanes - 1);
80-
int64_t upper = base_max;
81-
if (s_max > 0)
82-
upper += s_max * (lanes - 1);
83-
84-
if (lower >= 0)
85-
state = IndexSignState::kNonNegative;
86-
else if (upper < 0)
87-
state = IndexSignState::kNegative;
88-
else
63+
} else {
8964
DLOG(WARNING)
9065
<< "LegalizeNegativeIndex: cannot prove non-negative index "
9166
<< simplified << " for buffer " << buffer_name << " (axis " << i
9267
<< ", index " + indices[i]->Script() + ").";
68+
}
9369
} else if (const auto *broadcast = simplified.as<BroadcastNode>()) {
9470
auto v = analyzer_.Simplify(broadcast->value);
9571
if (analyzer_.CanProve(v >= 0))
@@ -109,6 +85,20 @@ class NegativeIndexAnalyzer : public IRVisitorWithAnalyzer {
10985
<< simplified << " for buffer " << buffer_name << " (axis " << i
11086
<< ", index " + indices[i]->Script() + ").";
11187
}
88+
} else {
89+
// Assume scalar (or non-Ramp/Broadcast) index; avoid querying lanes().
90+
// Fall back to scalar reasoning. If this expression is actually a
91+
// vector-but-not-Ramp/Broadcast, treat as unknown to be safe.
92+
// Try to prove scalar first; if proof fails, leave as unknown.
93+
if (analyzer_.CanProve(simplified >= 0))
94+
state = IndexSignState::kNonNegative;
95+
else if (analyzer_.CanProve(simplified < 0))
96+
state = IndexSignState::kNegative;
97+
else
98+
DLOG(WARNING)
99+
<< "LegalizeNegativeIndex: cannot prove non-negative index "
100+
<< simplified << " for buffer " << buffer_name << " (axis " << i
101+
<< ", index " + indices[i]->Script() + ").";
112102
}
113103
states.push_back(state);
114104
}

0 commit comments

Comments
 (0)