diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 4ee4d1c5ba..4fe298aba1 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -413,7 +413,9 @@ def create_function_from_signature(sig, kparams): } func_namespace['mangle_type'] = mangle_type - func_namespace['compute_spec_key'] = compute_spec_key + + from triton.runtime.driver import spec_func + func_namespace['compute_spec_key'] = spec_func("compute_spec_key") or compute_spec_key # Execute the function string in func_namespace to create the function exec(func_body, func_namespace) diff --git a/third_party/iluvatar/backend/spec/__init__.py b/third_party/iluvatar/backend/spec/__init__.py index f942b6f5b8..90830e316f 100644 --- a/third_party/iluvatar/backend/spec/__init__.py +++ b/third_party/iluvatar/backend/spec/__init__.py @@ -61,4 +61,5 @@ "bmm", "language_modify_all", "corex_sme", + "compute_spec_key", ] diff --git a/third_party/iluvatar/backend/spec/triton/runtime/jit.py b/third_party/iluvatar/backend/spec/triton/runtime/jit.py index cbca622bfb..da8194212f 100644 --- a/third_party/iluvatar/backend/spec/triton/runtime/jit.py +++ b/third_party/iluvatar/backend/spec/triton/runtime/jit.py @@ -56,6 +56,17 @@ def ext_JITFunction_spec_of(arg): return (arg % 4 == 0, arg % JITFunction.divisibility_8 == 0, arg == 1) +def compute_spec_key(v): + if hasattr(v, "data_ptr") and (v.data_ptr() % 4 == 0): + return "D" + elif isinstance(v, int): + if v % 4 == 0: + return "D" + elif v == 1: + return "1" + return "N" + + def is_corex_param(x, enable_sme): if enable_sme: if hasattr(x, "data_ptr"):