Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions include/tvm/relax/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,18 +125,19 @@ TVM_DLL Pass RewriteDataflowReshape();
* The pass will reuse allocated memory to its best effort, in order to
* reduce the total amount of allocated memory size.
*
* The pass "supports" dynamic shape in the way of TIR variable upper bound
* annotation. We can optionally annotate the attribute "tir_var_upper_bound"
* to Relax functions. The attribute value is a dict from strings to integers,
* denoting the name of TIR variables to the upper bound values of the TIR vars.
* Note: The annotated upper bound attribute only applies to TIR vars in the
* The pass "supports" dynamic shapes by way of TIR variable bound
* annotations. We can optionally annotate Relax functions with the attributes
* "tir_var_upper_bound" and "tir_var_lower_bound". Each attribute value is a
* dict from strings to integers, mapping the names of TIR variables to the
* corresponding bound values.
* Note: For clarity, the annotated bound attributes only apply to TIR vars
* that appear in the function signature.
*
* For example, we can annotate a Relax function with
* `R.func_attr({"tir_var_upper_bound": {"n": 1024}})`.
* It means the maximum value of variable that names "n" in the function
* signature will have upper bound 1024. And we will use 1024 as its value
* during memory planning.
* `R.func_attr({"tir_var_lower_bound": {"n": 1}, "tir_var_upper_bound": {"n": 1024}})`.
* It means the variable named "n" in the function signature will have the
* range [1, 1024], and these bounds will be used during memory planning.
* If the lower bound is not specified, it defaults to 0.
*
* \return The pass.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1181,10 +1181,12 @@ def from_exported_program(
if range_constraints:
if func_attrs is None:
func_attrs = {}
tir_var_upper_bound = {
func_attrs["tir_var_lower_bound"] = {
var_name: lower for var_name, (lower, _) in range_constraints.items()
}
func_attrs["tir_var_upper_bound"] = {
var_name: upper for var_name, (_, upper) in range_constraints.items()
}
func_attrs["tir_var_upper_bound"] = tir_var_upper_bound

nodes: List[fx.Node] = exported_program.graph.nodes

Expand Down
32 changes: 25 additions & 7 deletions src/relax/transform/adjust_matmul_order.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,19 +73,37 @@ std::tuple<DFPattern, ffi::TypedFunction<Expr(Expr, ffi::Map<DFPattern, Expr>)>>
pat_permuted_matmul_on_rhs;

PrimExpr symbolic_var_constraints = Bool(true);
if (auto upper_bounds = func->GetAttr<ffi::Map<ffi::String, Any>>("tir_var_upper_bound")) {
auto upper_bounds = func->GetAttr<ffi::Map<ffi::String, Any>>("tir_var_upper_bound");
auto lower_bounds = func->GetAttr<ffi::Map<ffi::String, Any>>("tir_var_lower_bound");

if (upper_bounds || lower_bounds) {
ffi::Map<ffi::String, tir::Var> name_lookup;
for (const auto& tir_var : TIRVarsInStructInfo(GetStructInfo(func))) {
name_lookup.Set(tir_var->name_hint, tir_var);
symbolic_var_constraints = symbolic_var_constraints && (0 <= tir_var);
}

for (const auto& [key, obj_bound] : upper_bounds.value()) {
auto tir_var_name = Downcast<ffi::String>(key);
if (auto opt_var = name_lookup.Get(tir_var_name)) {
auto var = opt_var.value();
auto expr_bound = Downcast<PrimExpr>(obj_bound);
symbolic_var_constraints = symbolic_var_constraints && (var < expr_bound);
// Add lower bound constraints
if (lower_bounds) {
for (const auto& [key, obj_bound] : lower_bounds.value()) {
auto tir_var_name = Downcast<ffi::String>(key);
if (auto opt_var = name_lookup.Get(tir_var_name)) {
auto var = opt_var.value();
auto expr_bound = Downcast<PrimExpr>(obj_bound);
symbolic_var_constraints = symbolic_var_constraints && (expr_bound <= var);
}
}
}

// Add upper bound constraints
if (upper_bounds) {
for (const auto& [key, obj_bound] : upper_bounds.value()) {
auto tir_var_name = Downcast<ffi::String>(key);
if (auto opt_var = name_lookup.Get(tir_var_name)) {
auto var = opt_var.value();
auto expr_bound = Downcast<PrimExpr>(obj_bound);
symbolic_var_constraints = symbolic_var_constraints && (var < expr_bound);
}
}
}
}
Expand Down
38 changes: 25 additions & 13 deletions src/relax/transform/static_plan_block_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -365,40 +365,52 @@ class StorageAllocatorBaseVisitor : public ExprVisitor {
};

/*!
* \brief Set the upper bound of the TIR variables that appear in
* \brief Set the range constraints of the TIR variables that appear in
* the input function signature in the analyzer.
* \param func The function to be analyzed.
* \param ana The analyzer in which the TIR var ranges are bound.
* \param dom_map The domain map of the TIR variables.
*/
void SetTIRVarUpperBound(Function func, arith::Analyzer* ana,
ffi::Map<tir::Var, arith::IntSet>* dom_map) {
// Use the attribute-annotated TIR var upper bounds as the TIR var values for
void SetTIRVarRangeConstraints(Function func, arith::Analyzer* ana,
ffi::Map<tir::Var, arith::IntSet>* dom_map) {
// Use the attribute-annotated TIR var bounds as the TIR var values for
// memory planning.
// NOTE: we only apply the annotated upper bounds to the TIR variables that
// NOTE: we only apply the annotated bounds to the TIR variables that
// appear in the **function signature**.
ffi::Map<ffi::String, IntImm> var_upper_bound_attr_raw =
func->GetAttr<ffi::Map<ffi::String, IntImm>>("tir_var_upper_bound")
.value_or(ffi::Map<ffi::String, IntImm>());
ffi::Map<ffi::String, IntImm> var_lower_bound_attr_raw =
func->GetAttr<ffi::Map<ffi::String, IntImm>>("tir_var_lower_bound")
.value_or(ffi::Map<ffi::String, IntImm>());
ffi::Array<ffi::String> non_negative_var_attr_raw =
func->GetAttr<ffi::Array<ffi::String>>("tir_non_negative_var")
.value_or(ffi::Array<ffi::String>());
std::unordered_map<ffi::String, IntImm> var_upper_bound_attr;
std::unordered_map<ffi::String, IntImm> var_lower_bound_attr;
std::unordered_set<ffi::String> non_negative_var_attr;
// We manually check the value type to ensure the values are all positive IntImm.
for (auto [key, value] : var_upper_bound_attr_raw) {
var_upper_bound_attr[key] = value;
}
for (auto [key, value] : var_lower_bound_attr_raw) {
var_lower_bound_attr[key] = value;
}
for (const ffi::String& var_name : non_negative_var_attr_raw) {
non_negative_var_attr.insert(var_name);
}
ffi::Array<tir::Var> var_in_signature = TIRVarsInStructInfo(GetStructInfo(func));
for (const tir::Var& tir_var : var_in_signature) {
auto it = var_upper_bound_attr.find(tir_var->name_hint);
if (it != var_upper_bound_attr.end()) {
tvm::Range range =
tvm::Range::FromMinExtent(tvm::IntImm(DataType::Int(64), 0),
tvm::IntImm(DataType::Int(64), (*it).second->value + 1));
auto it_upper = var_upper_bound_attr.find(tir_var->name_hint);
auto it_lower = var_lower_bound_attr.find(tir_var->name_hint);

if (it_upper != var_upper_bound_attr.end() || it_lower != var_lower_bound_attr.end()) {
int64_t lower = (it_lower != var_lower_bound_attr.end()) ? it_lower->second->value : 0;
int64_t upper = (it_upper != var_upper_bound_attr.end())
? it_upper->second->value
: std::numeric_limits<int64_t>::max();
tvm::Range range = tvm::Range::FromMinExtent(
tvm::IntImm(DataType::Int(64), lower), tvm::IntImm(DataType::Int(64), upper - lower + 1));
ana->Bind(tir_var, range);
dom_map->Set(tir_var, arith::IntSet::FromRange(range));
} else if (non_negative_var_attr.count(tir_var->name_hint)) {
Expand Down Expand Up @@ -485,8 +497,8 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
: ctx_mod_(ctx_mod), analyzer_(analyzer) {}

void VisitExpr_(const FunctionNode* func) final {
// Set the upper bound of TIR variables in the analyzer.
SetTIRVarUpperBound(ffi::GetRef<Function>(func), analyzer_, &dom_map_);
// Set the range constraints of TIR variables in the analyzer.
SetTIRVarRangeConstraints(ffi::GetRef<Function>(func), analyzer_, &dom_map_);
// Recurse into the function to get its tokens.
Tokens body_tokens = GetTokens(func->body);
// Discard the tokens used by the function return value, as they are external referenced.
Expand Down Expand Up @@ -843,7 +855,7 @@ class StorageAllocationRewriter : public ExprMutator {
plan_dynamic_output_ = static_cast<bool>(
func_->GetAttr<IntImm>(plan_dyn_attr_).value_or(IntImm(DataType::Int(32), 0))->value);
if (plan_dynamic_output_) {
SetTIRVarUpperBound(ffi::GetRef<Function>(func_), &ana_, &dom_map_);
SetTIRVarRangeConstraints(ffi::GetRef<Function>(func_), &ana_, &dom_map_);
}
token2storage_var_.clear();
Function func = Downcast<Function>(this->VisitExpr_(func_));
Expand Down
2 changes: 1 addition & 1 deletion tests/python/relax/test_frontend_from_exported_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -6747,7 +6747,7 @@ def main(
x1: R.Tensor(("s0", 4), dtype="float32"), x2: R.Tensor(("s0", 4), dtype="float32")
) -> R.Tuple(R.Tensor(("s0", 4), dtype="float32")):
s0 = T.int64(is_size_var=True)
R.func_attr({"tir_var_upper_bound": {"s0": 64}})
R.func_attr({"tir_var_lower_bound": {"s0": 1}, "tir_var_upper_bound": {"s0": 64}})
with R.dataflow():
lv: R.Tensor((s0, 4), dtype="float32") = R.add(x1, x2)
gv: R.Tuple(R.Tensor((s0, 4), dtype="float32")) = (lv,)
Expand Down
12 changes: 12 additions & 0 deletions tests/python/relax/test_transform_static_plan_block_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -1347,6 +1347,18 @@ def main(x: R.Tensor((2, "n"), dtype="float32")):
relax.transform.StaticPlanBlockMemory()(Module)


# Negative test: "tir_var_lower_bound" is expected to map TIR variable names to
# integers. Here "n" is mapped to a list ([4]) instead, so running
# StaticPlanBlockMemory on the module is expected to raise.
def test_invalid_tir_var_lower_bound():
@tvm.script.ir_module
class Module:
@R.function
def main(x: R.Tensor((2, "n"), dtype="float32")):
R.func_attr({"tir_var_lower_bound": {"n": [4]}, "relax.force_pure": True})
return x

# Either exception type is accepted for the malformed attribute value.
with pytest.raises((TVMError, TypeError)):
relax.transform.StaticPlanBlockMemory()(Module)


def test_add():
@I.ir_module
class Module:
Expand Down
Loading