@@ -44,52 +44,28 @@ class NegativeIndexAnalyzer : public IRVisitorWithAnalyzer {
4444 PrimExpr simplified = analyzer_.Simplify (indices[i]);
4545 IndexSignState state = IndexSignState::kUnknown ;
4646
47- // Handle scalar indices with the standard analyzer
48- if (simplified.dtype ().lanes () == 1 ) {
49- if (analyzer_.CanProve (simplified >= 0 ))
47+ // Handle vector patterns first to avoid querying lanes() on
48+ // scalable vectors (which is not allowed at compile-time).
49+ if (const auto *ramp = simplified.as <RampNode>()) {
50+ // For scalable vectors, we cannot rely on a constant lane count.
51+ // Use sufficient (but not necessary) conditions:
52+ // - If base >= 0 and stride >= 0, all lanes are non-negative.
53+ // - If base < 0 and stride <= 0, all lanes are negative.
54+ bool base_nonneg = analyzer_.CanProve (ramp->base >= 0 );
55+ bool base_neg = analyzer_.CanProve (ramp->base < 0 );
56+ bool stride_nonneg = analyzer_.CanProve (ramp->stride >= 0 );
57+ bool stride_nonpos = analyzer_.CanProve (ramp->stride <= 0 );
58+
59+ if (base_nonneg && stride_nonneg) {
5060 state = IndexSignState::kNonNegative ;
51- else if (analyzer_. CanProve (simplified < 0 ))
61+ } else if (base_neg && stride_nonpos) {
5262 state = IndexSignState::kNegative ;
53- else
54- DLOG (WARNING)
55- << " LegalizeNegativeIndex: cannot prove non-negative index "
56- << simplified << " for buffer " << buffer_name << " (axis " << i
57- << " , index " + indices[i]->Script () + " )." ;
58- }
59- // Vector indices: try to reason about non-negativity/negativity
60- // Common patterns are Ramp(base, stride, lanes) and Broadcast(value,
61- // lanes).
62- else if (const auto *ramp = simplified.as <RampNode>()) {
63- // Compute a safe lower/upper bound for the vector lanes
64- // lower_bound = base_min + min(0, stride_min) * (lanes - 1)
65- // upper_bound = base_max + max(0, stride_max) * (lanes - 1)
66- auto base_bound = analyzer_.const_int_bound (ramp->base );
67- auto stride_bound = analyzer_.const_int_bound (ramp->stride );
68- int lanes = *as_const_int (ramp->lanes );
69-
70- int64_t base_min = base_bound->min_value ;
71- int64_t base_max = base_bound->max_value ;
72- int64_t s_min = stride_bound->min_value ;
73- int64_t s_max = stride_bound->max_value ;
74-
75- // Guard against overflow is not strictly necessary here because
76- // bounds may be +/-inf represented by sentinel values.
77- int64_t lower = base_min;
78- if (s_min < 0 )
79- lower += s_min * (lanes - 1 );
80- int64_t upper = base_max;
81- if (s_max > 0 )
82- upper += s_max * (lanes - 1 );
83-
84- if (lower >= 0 )
85- state = IndexSignState::kNonNegative ;
86- else if (upper < 0 )
87- state = IndexSignState::kNegative ;
88- else
63+ } else {
8964 DLOG (WARNING)
9065 << " LegalizeNegativeIndex: cannot prove non-negative index "
9166 << simplified << " for buffer " << buffer_name << " (axis " << i
9267 << " , index " + indices[i]->Script () + " )." ;
68+ }
9369 } else if (const auto *broadcast = simplified.as <BroadcastNode>()) {
9470 auto v = analyzer_.Simplify (broadcast->value );
9571 if (analyzer_.CanProve (v >= 0 ))
@@ -109,6 +85,20 @@ class NegativeIndexAnalyzer : public IRVisitorWithAnalyzer {
10985 << simplified << " for buffer " << buffer_name << " (axis " << i
11086 << " , index " + indices[i]->Script () + " )." ;
11187 }
88+ } else {
89+ // Assume scalar (or non-Ramp/Broadcast) index; avoid querying lanes().
90+ // Fall back to scalar reasoning. If this expression is actually a
91+ // vector-but-not-Ramp/Broadcast, treat as unknown to be safe.
92+ // Try to prove scalar first; if proof fails, leave as unknown.
93+ if (analyzer_.CanProve (simplified >= 0 ))
94+ state = IndexSignState::kNonNegative ;
95+ else if (analyzer_.CanProve (simplified < 0 ))
96+ state = IndexSignState::kNegative ;
97+ else
98+ DLOG (WARNING)
99+ << " LegalizeNegativeIndex: cannot prove non-negative index "
100+ << simplified << " for buffer " << buffer_name << " (axis " << i
101+ << " , index " + indices[i]->Script () + " )." ;
112102 }
113103 states.push_back (state);
114104 }
0 commit comments