Skip to content

Commit cf0ce35

Browse files
committed
Add lower bound support for range constraints
1 parent 506a0bb commit cf0ce35

File tree

6 files changed

+70
-25
lines changed

6 files changed

+70
-25
lines changed

include/tvm/relax/transform.h

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -125,18 +125,19 @@ TVM_DLL Pass RewriteDataflowReshape();
125125
* The pass will reuse allocated memory to its best effort, in order to
126126
* reduce the total amount of allocated memory size.
127127
*
128-
* The pass "supports" dynamic shape in the way of TIR variable upper bound
129-
* annotation. We can optionally annotate the attribute "tir_var_upper_bound"
130-
* to Relax functions. The attribute value is a dict from strings to integers,
131-
* denoting the name of TIR variables to the upper bound values of the TIR vars.
132-
* Note: The annotated upper bound attribute only applies to TIR vars in the
128+
* The pass "supports" dynamic shape in the way of TIR variable bound
129+
* annotations. We can optionally annotate the attributes "tir_var_upper_bound"
130+
* and "tir_var_lower_bound" to Relax functions. The attribute values are dicts
131+
* from strings to integers, mapping the names of TIR variables to the bound
132+
* values of the TIR vars.
133+
* Note: The annotated bound attributes only apply to TIR vars in the
133134
* function signature for clarity.
134135
*
135136
* For example, we can annotate a Relax function with
136-
* `R.func_attr({"tir_var_upper_bound": {"n": 1024}})`.
137-
* It means the maximum value of variable that names "n" in the function
138-
* signature will have upper bound 1024. And we will use 1024 as its value
139-
* during memory planning.
137+
* `R.func_attr({"tir_var_lower_bound": {"n": 1}, "tir_var_upper_bound": {"n": 1024}})`.
138+
* It means the variable named "n" in the function signature will have
139+
* range [1, 1024]. And we will use these bounds during memory planning.
140+
* If the lower bound is not specified, it defaults to 0.
140141
*
141142
* \return The pass.
142143
*/

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1179,10 +1179,12 @@ def from_exported_program(
11791179
if range_constraints:
11801180
if func_attrs is None:
11811181
func_attrs = {}
1182-
tir_var_upper_bound = {
1182+
func_attrs["tir_var_lower_bound"] = {
1183+
var_name: lower for var_name, (lower, _) in range_constraints.items()
1184+
}
1185+
func_attrs["tir_var_upper_bound"] = {
11831186
var_name: upper for var_name, (_, upper) in range_constraints.items()
11841187
}
1185-
func_attrs["tir_var_upper_bound"] = tir_var_upper_bound
11861188

11871189
nodes: List[fx.Node] = exported_program.graph.nodes
11881190

src/relax/transform/adjust_matmul_order.cc

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -73,19 +73,37 @@ std::tuple<DFPattern, ffi::TypedFunction<Expr(Expr, ffi::Map<DFPattern, Expr>)>>
7373
pat_permuted_matmul_on_rhs;
7474

7575
PrimExpr symbolic_var_constraints = Bool(true);
76-
if (auto upper_bounds = func->GetAttr<ffi::Map<ffi::String, Any>>("tir_var_upper_bound")) {
76+
auto upper_bounds = func->GetAttr<ffi::Map<ffi::String, Any>>("tir_var_upper_bound");
77+
auto lower_bounds = func->GetAttr<ffi::Map<ffi::String, Any>>("tir_var_lower_bound");
78+
79+
if (upper_bounds || lower_bounds) {
7780
ffi::Map<ffi::String, tir::Var> name_lookup;
7881
for (const auto& tir_var : TIRVarsInStructInfo(GetStructInfo(func))) {
7982
name_lookup.Set(tir_var->name_hint, tir_var);
8083
symbolic_var_constraints = symbolic_var_constraints && (0 <= tir_var);
8184
}
8285

83-
for (const auto& [key, obj_bound] : upper_bounds.value()) {
84-
auto tir_var_name = Downcast<ffi::String>(key);
85-
if (auto opt_var = name_lookup.Get(tir_var_name)) {
86-
auto var = opt_var.value();
87-
auto expr_bound = Downcast<PrimExpr>(obj_bound);
88-
symbolic_var_constraints = symbolic_var_constraints && (var < expr_bound);
86+
// Add lower bound constraints
87+
if (lower_bounds) {
88+
for (const auto& [key, obj_bound] : lower_bounds.value()) {
89+
auto tir_var_name = Downcast<ffi::String>(key);
90+
if (auto opt_var = name_lookup.Get(tir_var_name)) {
91+
auto var = opt_var.value();
92+
auto expr_bound = Downcast<PrimExpr>(obj_bound);
93+
symbolic_var_constraints = symbolic_var_constraints && (expr_bound <= var);
94+
}
95+
}
96+
}
97+
98+
// Add upper bound constraints
99+
if (upper_bounds) {
100+
for (const auto& [key, obj_bound] : upper_bounds.value()) {
101+
auto tir_var_name = Downcast<ffi::String>(key);
102+
if (auto opt_var = name_lookup.Get(tir_var_name)) {
103+
auto var = opt_var.value();
104+
auto expr_bound = Downcast<PrimExpr>(obj_bound);
105+
symbolic_var_constraints = symbolic_var_constraints && (var < expr_bound);
106+
}
89107
}
90108
}
91109
}

src/relax/transform/static_plan_block_memory.cc

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -373,32 +373,44 @@ class StorageAllocatorBaseVisitor : public ExprVisitor {
373373
*/
374374
void SetTIRVarUpperBound(Function func, arith::Analyzer* ana,
375375
ffi::Map<tir::Var, arith::IntSet>* dom_map) {
376-
// Use the attribute-annotated TIR var upper bounds as the TIR var values for
376+
// Use the attribute-annotated TIR var bounds as the TIR var values for
377377
// memory planning.
378-
// NOTE: we only apply the annotated upper bounds to the TIR variables that
378+
// NOTE: we only apply the annotated bounds to the TIR variables that
379379
// appear in the **function signature**.
380380
ffi::Map<ffi::String, IntImm> var_upper_bound_attr_raw =
381381
func->GetAttr<ffi::Map<ffi::String, IntImm>>("tir_var_upper_bound")
382382
.value_or(ffi::Map<ffi::String, IntImm>());
383+
ffi::Map<ffi::String, IntImm> var_lower_bound_attr_raw =
384+
func->GetAttr<ffi::Map<ffi::String, IntImm>>("tir_var_lower_bound")
385+
.value_or(ffi::Map<ffi::String, IntImm>());
383386
ffi::Array<ffi::String> non_negative_var_attr_raw =
384387
func->GetAttr<ffi::Array<ffi::String>>("tir_non_negative_var")
385388
.value_or(ffi::Array<ffi::String>());
386389
std::unordered_map<ffi::String, IntImm> var_upper_bound_attr;
390+
std::unordered_map<ffi::String, IntImm> var_lower_bound_attr;
387391
std::unordered_set<ffi::String> non_negative_var_attr;
388392
// We manually check the value type to ensure the values are all positive IntImm.
389393
for (auto [key, value] : var_upper_bound_attr_raw) {
390394
var_upper_bound_attr[key] = value;
391395
}
396+
for (auto [key, value] : var_lower_bound_attr_raw) {
397+
var_lower_bound_attr[key] = value;
398+
}
392399
for (const ffi::String& var_name : non_negative_var_attr_raw) {
393400
non_negative_var_attr.insert(var_name);
394401
}
395402
ffi::Array<tir::Var> var_in_signature = TIRVarsInStructInfo(GetStructInfo(func));
396403
for (const tir::Var& tir_var : var_in_signature) {
397-
auto it = var_upper_bound_attr.find(tir_var->name_hint);
398-
if (it != var_upper_bound_attr.end()) {
404+
auto it_upper = var_upper_bound_attr.find(tir_var->name_hint);
405+
auto it_lower = var_lower_bound_attr.find(tir_var->name_hint);
406+
407+
if (it_upper != var_upper_bound_attr.end() || it_lower != var_lower_bound_attr.end()) {
408+
int64_t lower = (it_lower != var_lower_bound_attr.end()) ? it_lower->second->value : 0;
409+
int64_t upper = (it_upper != var_upper_bound_attr.end()) ? it_upper->second->value
410+
: std::numeric_limits<int64_t>::max();
399411
tvm::Range range =
400-
tvm::Range::FromMinExtent(tvm::IntImm(DataType::Int(64), 0),
401-
tvm::IntImm(DataType::Int(64), (*it).second->value + 1));
412+
tvm::Range::FromMinExtent(tvm::IntImm(DataType::Int(64), lower),
413+
tvm::IntImm(DataType::Int(64), upper - lower + 1));
402414
ana->Bind(tir_var, range);
403415
dom_map->Set(tir_var, arith::IntSet::FromRange(range));
404416
} else if (non_negative_var_attr.count(tir_var->name_hint)) {

tests/python/relax/test_frontend_from_exported_program.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6689,7 +6689,7 @@ def main(
66896689
x1: R.Tensor(("s0", 4), dtype="float32"), x2: R.Tensor(("s0", 4), dtype="float32")
66906690
) -> R.Tuple(R.Tensor(("s0", 4), dtype="float32")):
66916691
s0 = T.int64(is_size_var=True)
6692-
R.func_attr({"tir_var_upper_bound": {"s0": 64}})
6692+
R.func_attr({"tir_var_lower_bound": {"s0": 1}, "tir_var_upper_bound": {"s0": 64}})
66936693
with R.dataflow():
66946694
lv: R.Tensor((s0, 4), dtype="float32") = R.add(x1, x2)
66956695
gv: R.Tuple(R.Tensor((s0, 4), dtype="float32")) = (lv,)

tests/python/relax/test_transform_static_plan_block_memory.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1347,6 +1347,18 @@ def main(x: R.Tensor((2, "n"), dtype="float32")):
13471347
relax.transform.StaticPlanBlockMemory()(Module)
13481348

13491349

1350+
def test_invalid_tir_var_lower_bound():
1351+
@tvm.script.ir_module
1352+
class Module:
1353+
@R.function
1354+
def main(x: R.Tensor((2, "n"), dtype="float32")):
1355+
R.func_attr({"tir_var_lower_bound": {"n": [4]}, "relax.force_pure": True})
1356+
return x
1357+
1358+
with pytest.raises((TVMError, TypeError)):
1359+
relax.transform.StaticPlanBlockMemory()(Module)
1360+
1361+
13501362
def test_add():
13511363
@I.ir_module
13521364
class Module:

0 commit comments

Comments
 (0)