diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 88b45698..8a2c501f 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -56,101 +56,130 @@ jobs: - name: Build and install DLCompiler run: | set -ex - source /home/dlc_ci/.bashrc - which conda - echo "which conda? $(which conda)" - conda activate dlcompiler - source /usr/local/Ascend/cann-8.5.0/set_env.sh - cd ${{ env.CI_PATH }} - export JSON_PATH34=${{ vars.CI_BASE_PATH }}/data/v34/include.zip - export GOOGLETEST_DIR34=${{ vars.CI_BASE_PATH }}/data/v34/googletest - export LLVM_TGZ_PATH34=${{ vars.CI_BASE_PATH }}/data/v34/llvm-064f02da-ubuntu-arm64.tar.gz - - rm -rf ./third_party/* - git clone --no-hardlinks ${{ env.THIRD_PARTY_PATH }}/ascendnpu-ir ./third_party/ascendnpu-ir - git clone --no-hardlinks ${{ env.THIRD_PARTY_PATH }}/json ./third_party/json - git clone --no-hardlinks ${{ env.THIRD_PARTY_PATH }}/triton_shared ./third_party/triton_shared - git clone --no-hardlinks ${{ env.THIRD_PARTY_PATH }}/triton ./third_party/triton - echo "whoami? $(whoami)" - echo "which python? $(which python)" - git submodule update --init - echo "git submodule update done." - pip install nanobind -i https://mirrors.huaweicloud.com/repository/pypi/simple - bash compile_shared.sh apply_patch=true + echo "=== 切换前用户信息 ===" + echo "whoami: $(whoami)" + echo "user id: $(id)" + echo "PATH: $PATH" + echo "======================" + + # 切换到 root 用户执行,确保有权限访问 NPU 设备 + sudo -E bash -c ' + source /home/dlc_ci/.bashrc + which conda + echo "which conda? $(which conda)" + conda activate dlcompiler + source /usr/local/Ascend/cann-8.5.0/set_env.sh + cd ${{ env.CI_PATH }} + export JSON_PATH34=${{ vars.CI_BASE_PATH }}/data/v34/include.zip + export GOOGLETEST_DIR34=${{ vars.CI_BASE_PATH }}/data/v34/googletest + export LLVM_TGZ_PATH34=${{ vars.CI_BASE_PATH }}/data/v34/llvm-064f02da-ubuntu-arm64.tar.gz + + rm -rf ./third_party/* + git clone --no-hardlinks ${{ env.THIRD_PARTY_PATH }}/ascendnpu-ir ./third_party/ascendnpu-ir + git clone --no-hardlinks ${{ env.THIRD_PARTY_PATH }}/json ./third_party/json + git clone --no-hardlinks ${{ env.THIRD_PARTY_PATH }}/triton_shared ./third_party/triton_shared + git clone --no-hardlinks ${{ env.THIRD_PARTY_PATH }}/triton ./third_party/triton + echo "whoami before compile: $(whoami)" + echo "which python? $(which python)" + echo "npu-smi path: $(which npu-smi 2>/dev/null || echo "not found")" + git submodule update --init + echo "git submodule update done." + pip install nanobind -i https://mirrors.huaweicloud.com/repository/pypi/simple + + bash compile_shared.sh apply_patch=true + ' - name: Build and install tilelang-dlc run: | set -ex - source /home/dlc_ci/.bashrc - conda activate dlcompiler - source /usr/local/Ascend/cann-8.5.0/set_env.sh - cd ${{ env.CI_PATH }} - export TILELANG_DLC_PATH=${{ vars.CI_BASE_PATH }}/data/tilelang-dlc - export DLCOMPILER_SOURCE=${{ env.CI_PATH }} - export TILELANG_USE_DLCOMPILER=1 - echo "whoami? $(whoami)" - echo "which python? $(which python)" - bash scripts/install_tilelang-dlc.sh + # 切换到 root 用户执行 + sudo -E bash -c ' + source /home/dlc_ci/.bashrc + conda activate dlcompiler + source /usr/local/Ascend/cann-8.5.0/set_env.sh + cd ${{ env.CI_PATH }} + export TILELANG_DLC_PATH=${{ vars.CI_BASE_PATH }}/data/tilelang-dlc + export DLCOMPILER_SOURCE=${{ env.CI_PATH }} + export TILELANG_USE_DLCOMPILER=1 + echo "whoami? $(whoami)" + echo "which python? $(which python)" + bash scripts/install_tilelang-dlc.sh + ' - name: Run tilelang-dlc tests on ascend run: | set -ex - source /home/dlc_ci/.bashrc - conda activate dlcompiler - source /usr/local/Ascend/cann-8.5.0/set_env.sh - cd ${{env.CI_PATH }} - export PATH=${{ vars.CI_BASE_PATH }}/data/bishengir_latest/:$PATH - export ASCEND_RT_VISIBLE_DEVICES=7 - export TILELANG_USE_DLCOMPILER=1 - bash test/commonir/run_tests.sh + # 切换到 root 用户执行 + sudo -E bash -c ' + source /home/dlc_ci/.bashrc + conda activate dlcompiler + source /usr/local/Ascend/cann-8.5.0/set_env.sh + cd ${{env.CI_PATH }} + export PATH=${{ vars.CI_BASE_PATH }}/data/bishengir_latest/:$PATH + export ASCEND_RT_VISIBLE_DEVICES=7 + export TILELANG_USE_DLCOMPILER=1 + bash test/commonir/run_tests.sh + ' - name: Run triton tests on ascend run: | set -ex - source /home/dlc_ci/.bashrc - conda activate dlcompiler - source /usr/local/Ascend/cann-8.5.0/set_env.sh - cd ${{env.CI_PATH }} - echo "whoami? $(whoami)" - echo "which python? $(which python)" - export PATH=${{ vars.CI_BASE_PATH }}/data/bishengir_latest/:$PATH - export ASCEND_RT_VISIBLE_DEVICES=7 - bash test/ascend/run_tests.sh + # 切换到 root 用户执行 + sudo -E bash -c ' + source /home/dlc_ci/.bashrc + conda activate dlcompiler + source /usr/local/Ascend/cann-8.5.0/set_env.sh + cd ${{env.CI_PATH }} + echo "whoami? $(whoami)" + echo "which python? $(which python)" + export PATH=${{ vars.CI_BASE_PATH }}/data/bishengir_latest/:$PATH + export ASCEND_RT_VISIBLE_DEVICES=7 + bash test/ascend/run_tests.sh + ' - name: Run MLIR tests run: | set -ex - source /home/dlc_ci/.bashrc - conda activate dlcompiler - source /usr/local/Ascend/cann-8.5.0/set_env.sh - cd ${{env.CI_PATH }} - echo "whoami? $(whoami)" - echo "which python? $(which python)" - export PATH=${{ vars.CI_BASE_PATH }}/data/bishengir_latest/:$PATH - export ASCEND_RT_VISIBLE_DEVICES=7 - bash test/ascend/test_mlir.sh + # 切换到 root 用户执行 + sudo -E bash -c ' + source /home/dlc_ci/.bashrc + conda activate dlcompiler + source /usr/local/Ascend/cann-8.5.0/set_env.sh + cd ${{env.CI_PATH }} + echo "whoami? $(whoami)" + echo "which python? $(which python)" + export PATH=${{ vars.CI_BASE_PATH }}/data/bishengir_latest/:$PATH + export ASCEND_RT_VISIBLE_DEVICES=7 + bash test/ascend/test_mlir.sh + ' - name: Run DSL tests run: | set -ex - source /home/dlc_ci/.bashrc - conda activate dlcompiler - source /usr/local/Ascend/cann-8.5.0/set_env.sh - cd ${{env.CI_PATH }} - echo "whoami? $(whoami)" - echo "which python? $(which python)" - export PATH=${{ vars.CI_BASE_PATH }}/data/bishengir_latest/:$PATH - export ASCEND_RT_VISIBLE_DEVICES=7 - bash test/dsl/run_tests.sh + # 切换到 root 用户执行 + sudo -E bash -c ' + source /home/dlc_ci/.bashrc + conda activate dlcompiler + source /usr/local/Ascend/cann-8.5.0/set_env.sh + cd ${{env.CI_PATH }} + echo "whoami? $(whoami)" + echo "which python? $(which python)" + export PATH=${{ vars.CI_BASE_PATH }}/data/bishengir_latest/:$PATH + export ASCEND_RT_VISIBLE_DEVICES=7 + bash test/dsl/run_tests.sh + ' - name: Clear workfile if: always() run: | - export workdir=$(pwd) - cd .. - rm -rf $workdir - mkdir $workdir - chmod -R 777 $workdir - if [ -d "${{ env.CI_PATH }}" ]; then - rm -rf ${{ env.CI_PATH }} - fi + # 使用 root 权限清理,确保能删除之前以 root 创建的文件 + sudo bash -c ' + export workdir=$(pwd) + cd .. + rm -rf $workdir + mkdir $workdir + chmod -R 777 $workdir + if [ -d "${{ env.CI_PATH }}" ]; then + rm -rf ${{ env.CI_PATH }} + fi + ' diff --git a/backend/npu.py b/backend/npu.py index 38de5735..610cefae 100644 --- a/backend/npu.py +++ b/backend/npu.py @@ -521,6 +521,15 @@ def ttsharedir_to_linkedir(mod, metadata, opt, *, named_ops=False, cpu_verify=Fa pattern = r"(memref\<.*?\>)\s+to\s+(tensor\<.*?\>)" # 使用正则替换,保留memref和tensor类型,中间插入注释 content = re.sub(pattern, r"\1 // to \2", content) + # 处理customop的attr + if len(re.findall("hivm\.hir\.custom", content)) > 0: + content = re.sub(r'"#hivm\.pipe<([A-Za-z0-9_]*)>"', r"#hivm.pipe<\1>", content) + content = re.sub( + r'"#hivm\.tcore_type<([A-Za-z0-9_]*)>"', r"#hivm.tcore_type<\1>", content + ) + content = re.sub( + r'"#hivm\.vf_mode<([A-Za-z0-9_]*)>"', r"#hivm.vf_mode<\1>", content + ) if opt.debug or dump_ir: cmd_list = [ @@ -681,6 +690,10 @@ def _parse_linalg_metadata(linalg: str, metadata: dict): TENSOR_KIND_REGEX = ( r"%arg(\d+):[^,)]*?\{[^}]*?tt\.tensor_kind\s*=\s*([^:\s}]+)\s*:[^}]*?\}" ) + + # Example: bitcode = "a.bc" + BITCODES_REGEX = r'bitcode\s*=\s*(?:"([^"]+)"|\'([^\']+)\'|(\w+))' + # Example removal: ', mix_mode = "aiv"' → '' REMOVE_MIX_MODE_REGEX = r', mix_mode\s*=\s*"[^"]*"' # Note: Compiled Kernel requires to estimate size of shared memory to occupy @@ -696,6 +709,11 @@ def _parse_linalg_metadata(linalg: str, metadata: dict): metadata["tensor_kinds"] = [ int(kind) for _, kind in re.findall(TENSOR_KIND_REGEX, linalg) ] + + # Parse all bitcode paths + bitcodes = re.findall(BITCODES_REGEX, linalg) + metadata["bitcodes"] = [val for group in bitcodes for val in group if val] + # remove the mix_mode attribute linalg = re.sub(REMOVE_MIX_MODE_REGEX, "", linalg) return linalg, metadata @@ -729,6 +747,12 @@ def linalg_to_bin_enable_npu_compile(linalg: str, metadata, opt): _compile_option_list += ["--enable-sanitizer=true"] if _is_auto_map_parallel_blocks_enabled(): _compile_option_list += ["--enable-auto-blockify-loop"] + + bitcodes = metadata["bitcodes"] + if bitcodes is not None: + for bitcodes in bitcodes: + _compile_option_list += [f"--link-aicore-bitcode={bitcodes}"] + npu_compiler_path = _get_npucompiler_path() # support bishengir-compile more version diff --git a/compiler/lib/Conversion/LinkedToHIVM/LinkedToHIVM.cpp b/compiler/lib/Conversion/LinkedToHIVM/LinkedToHIVM.cpp index bf8bf01b..dbbb7da4 100644 --- a/compiler/lib/Conversion/LinkedToHIVM/LinkedToHIVM.cpp +++ b/compiler/lib/Conversion/LinkedToHIVM/LinkedToHIVM.cpp @@ -123,6 +123,24 @@ struct TritonCustomSyncOpToHIVMSyncOpConversion } }; +// Convert CustomOp after operand type changed, +// for example tt.ptr changed to memref. +class TritonCustomOpToHIVMCustomOpConversion + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::CustomOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto res_types = adaptor.getOutputs().getTypes(); + auto new_op = rewriter.create( + op->getLoc(), res_types, adaptor.getOperands(), op->getAttrs()); + rewriter.replaceOp(op, new_op); + return success(); + } +}; + void LinkedToHIVMPass::runOnOperation() { auto module = getOperation(); ConversionTarget target(getContext()); @@ -130,6 +148,7 @@ void LinkedToHIVMPass::runOnOperation() { RewritePatternSet patterns(&getContext()); patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); if (failed(applyPartialConversion(module, target, std::move(patterns)))) { signalPassFailure(); } diff --git a/language/deeplink/__init__.py b/language/deeplink/__init__.py index 044f8557..b4f5ddca 100644 --- a/language/deeplink/__init__.py +++ b/language/deeplink/__init__.py @@ -21,6 +21,14 @@ L0C, SyncFlag, ) +from .custom_op import ( + custom, + custom_semantic, + register_custom_op, + CORE, + PIPE, + MODE, +) __all__ = [ "libdevice", @@ -43,6 +51,12 @@ "L0C", "SyncFlag", "async_task", + "custom", + "custom_semantic", + "register_custom_op", + "CORE", + "PIPE", + "MODE", ] init_dicp_driver() diff --git a/language/deeplink/custom_op.py b/language/deeplink/custom_op.py new file mode 100644 index 00000000..09073a1e --- /dev/null +++ b/language/deeplink/custom_op.py @@ -0,0 +1,302 @@ +import inspect +import types +import typing +import itertools +import enum +from triton.language import core, semantic +import triton.language as tl + +__all__ = ["custom", "custom_semantic", "register_custom_op", "CORE", "PIPE", "MODE"] + +_custom_op_registry = {} + + +class CORE(enum.Enum): + CUBE = "CUBE" + VECTOR = "VECTOR" + CUBE_OR_VECTOR = "CUBE_OR_VECTOR" + CUBE_AND_VECTOR = "CUBE_AND_VECTOR" + + +class PIPE(enum.Enum): + PIPE_S = "PIPE_S" + PIPE_V = "PIPE_V" + PIPE_M = "PIPE_M" + PIPE_MTE1 = "PIPE_MTE1" + PIPE_MTE2 = "PIPE_MTE2" + PIPE_MTE3 = "PIPE_MTE3" + PIPE_ALL = "PIPE_ALL" + PIPE_FIX = "PIPE_FIX" + + +class MODE(enum.Enum): + SIMD = "SIMD" + SIMT = "SIMT" + MIX = "MIX" + + +def _get_op_class(name): + # Try to get op class in _custom_op_registry. + op_class = _custom_op_registry.get(name) + if op_class is None: + # Allow bulitin custom ops used without registry. + assert name.startswith("__builtin_"), f"Custom Op '{name}' not registered." + # Return a dummy op class for builtin custom op. + op_class = type( + "_builtin_custom_op", + (object,), + { + "name": name, + "core": core.CORE.VECTOR, + "pipe": core.PIPE.PIPE_V, + "mode": core.MODE.SIMT, + "signature": inspect.signature(object), + }, + ) + return op_class + + +def _unwrap_constexpr(arg): + if isinstance(arg, tl.constexpr): + return arg.value + if isinstance(arg, tuple): + return tuple(_unwrap_constexpr(x) for x in arg) + if isinstance(arg, list): + return [_unwrap_constexpr(x) for x in arg] + if isinstance(arg, dict): + return {k: _unwrap_constexpr(v) for k, v in arg.items()} + return arg + + +def _to_value(value, builder, ty=None): + # Try to use 'type' attribute if ty not set. + ty = getattr(value, "type", ty) if ty is None else ty + if isinstance(value, tl.tensor): + if not value.type.is_block() and isinstance(ty, tl.dtype) and value.type != ty: + # For a scalar variable, if its type is not the expected one + # that specified by type hint 'ty', insert a cast for it. + return tl.semantic.cast(value, ty, builder).handle + return value.handle + if isinstance(value, bool): + return builder.get_int1(value) + if isinstance(value, int): + if isinstance(ty, tl.dtype): + if ty.is_int64(): + return builder.get_int64(value) + if ty.is_uint64(): + return builder.get_uint64(value) + if ty.is_int32(): + return builder.get_int32(value) + if ty.is_uint32(): + return builder.get_uint32(value) + if ty.is_int16(): + return builder.get_int16(value) + if ty.is_uint16(): + return builder.get_uint16(value) + if ty.is_int8(): + return builder.get_int8(value) + if ty.is_uint8(): + return builder.get_uint8(value) + # default int32 + return builder.get_int32(value) + if isinstance(value, float): + if isinstance(ty, tl.dtype): + if ty.is_fp64(): + return builder.get_fp64(value) + if ty.is_fp32(): + return builder.get_fp32(value) + if ty.is_fp16(): + return builder.get_fp16(value) + if ty.is_bf16(): + return builder.get_bf16(value) + # default float32 + return builder.get_fp32(value) + if isinstance(value, tl.constexpr): + return _to_value(value.value, builder) + raise TypeError(f"Unsupported argument type {value} : {type(value)}") + + +def _to_operands(args, builder): + operands = [] + for value in args: + if value is None: + continue + if isinstance(value, (list, tuple)): + for item in value: + operands.append(_to_value(item, builder)) + else: + operands.append(_to_value(value, builder)) + return operands + + +def _get_element_type(ty): + if isinstance(ty, types.GenericAlias): + return typing.get_args(ty)[0] + return ty + + +def _args_to_operands(op, builder, args, kwargs): + if not op.signature.parameters: + # Without parameters in signature, use the actual parameter order. + return _to_operands(itertools.chain(args, kwargs.values()), builder) + + # Convert arguments to operands according the signature. + operands = [] + bind = op.signature.bind(*args, **kwargs) + for param in op.signature.parameters.values(): + value = bind.arguments.get(param.name, None) + if value is None: + continue + ty = op.arg_type.get(param.name, param.annotation) + if isinstance(value, (list, tuple)): + ty = _get_element_type(ty) + for item in value: + operands.append(_to_value(item, builder, ty)) + else: + operands.append(_to_value(value, builder, ty)) + return operands + + +def _add_optional_attr(op, name, builder, attrs): + if hasattr(op, name): + attrs[name] = getattr(op, name) + + +def _add_bitcode_attr(op, builder, attrs): + if not hasattr(op, "bitcode"): + return + from pathlib import Path + + bitcode = Path(getattr(op, "bitcode")) + assert bitcode.exists(), f"Provided bitcode ({bitcode}) not exist" + attrs["bitcode"] = str(bitcode.absolute()) + + +def _make_attrs(op, builder): + attrs = { + "hivm.tcore_type": f"#hivm.tcore_type<{op.core.value}>", + "hivm.pipe": f"#hivm.pipe<{op.pipe.value}>", + "hivm.vf_mode": f"#hivm.vf_mode<{op.mode.value}>", + } + + if not op.name.startswith("__builtin_"): + assert hasattr(op, "symbol"), f"Non builtin custom op, symbol is required." + assert hasattr( + op, "bitcode" + ), f"Non builtin custom op, bitcode path is required." + + # Add bit code path attribute, formalize to abosulte path. + _add_bitcode_attr(op, builder, attrs) + + _add_optional_attr(op, "symbol", builder, attrs) + _add_optional_attr(op, "source", builder, attrs) + _add_optional_attr(op, "compile", builder, attrs) + # Extra attributes can be added here, such as op.extra_attr="attr_a=xx" + _add_optional_attr(op, "extra_attr", builder, attrs) + _add_optional_attr(op, "iterator_types", builder, attrs) + return attrs + + +def _to_result(res, res_types): + assert len(res) == len(res_types) + n_res = len(res) + if n_res == 0: + return None + if n_res == 1: + return tl.tensor(res[0], res_types[0]) + return tuple(tl.tensor(res[i], res_types[i]) for i in range(n_res)) + + +def _init_op(op_class, *args, **kwargs): + op = op_class.__new__(op_class) + # Add arg_type dict to support dynamic argument type specifying. + setattr(op, "arg_type", {}) + if op_class.signature.parameters: + # Init with arguments validate. + op_class.__init__(op, *args, **kwargs) + return op + + +def custom_semantic(name: str, *args, _semantic=None, **kwargs): + _builder = _semantic.builder + name = _unwrap_constexpr(name) + assert name in _custom_op_registry, f"Custom op '{name}' not found." + # Get op class according the name. + op_class = _get_op_class(name) + # Convert constexpr to value in arguments. + args = _unwrap_constexpr(args) + kwargs = _unwrap_constexpr(kwargs) + # Create op instance from op class with the arguments. + op = _init_op(op_class, *args, **kwargs) + # Prepare inputs and outputs operands. + out = kwargs.pop("out", []) + outs = out if isinstance(out, (list, tuple)) else [out] + outputs = _to_operands(outs, _builder) + inputs = _args_to_operands(op, _builder, args, kwargs) + # Setup attributes. + attrs = _make_attrs(op, _builder) + # Build IR for the custom op. + res = _builder.create_custom_op(name, attrs, inputs, outputs) + # Results with same types as outputs. + res_types = [out.type for out in outs] + return _to_result(res, res_types) + + +@core.builtin +def custom(name: str, *args, _semantic=None, **kwargs): + """Invoke a custom operation with the given name and arguments.""" + return custom_semantic(name, *args, _semantic=_semantic, **kwargs) + + +def register_custom_op(op): + """Register a custom operation so that we can invoke it using al.custom().""" + assert inspect.isclass(op), "@register_custom_op should decorate on a class." + # Use class name if name not set. + if not hasattr(op, "name"): + setattr(op, "name", op.__name__) + # The op name should not be used. + assert ( + op.name not in _custom_op_registry + ), f"Custom op name '{op.name}' already used." + # Check required core, pipe, mode fields. + assert hasattr(op, "core"), "'core' field is required." + assert hasattr(op, "pipe"), "'pipe' field is required." + assert hasattr(op, "mode"), "'mode' field is required." + assert isinstance(op.core, CORE), "Invalid 'core' field, CORE type is required." + assert isinstance(op.pipe, PIPE), "Invalid 'pipe' field, PIPE type is required." + assert isinstance(op.mode, MODE), "Invalid 'mode' field, MODE type is required." + # Retrieve arguments signature from __init__ method and save it. + signature = inspect.signature(op) + setattr(op, "signature", signature) + # Register the custom op configuration. + _custom_op_registry[op.name] = op + return op + + +# ----------------------- +# SPMD Programming Model +# ----------------------- +def _constexpr_to_value(v): + if isinstance(v, tl.constexpr): + return v.value + return v + + +def _is_int_like_elem(x) -> bool: + """Accept int / tl.constexpr(int) / tl.tensor(int*).""" + if isinstance(x, int): + return True + if isinstance(x, tl.constexpr): + # constexpr value should be python int + return isinstance(x.value, int) + if isinstance(x, tl.tensor): + # Offsets/strides must be integer typed (i32/i64 etc.) + return x.dtype.is_int() + return False + + +def _assert_int_like_tuple(name: str, xs): + assert isinstance( + xs, (tuple, list) + ), f"{name} should be a tuple/list, but got {type(xs)}" + assert all(_is_int_like_elem(x) for x in xs), f"{name} should be integer" diff --git a/patch/ascendnpu-ir.patch b/patch/ascendnpu-ir.patch index e94b98b6..2b428979 100644 --- a/patch/ascendnpu-ir.patch +++ b/patch/ascendnpu-ir.patch @@ -67,6 +67,34 @@ index ce55e7f..089f75e 100644 getRegionBuilder(); // Used for AggregateOpInterface to decompose into legal operations +diff --git a/bishengir/include/bishengir/Dialect/HIVM/IR/HIVMAttrs.td b/bishengir/include/bishengir/Dialect/HIVM/IR/HIVMAttrs.td +index f245841..47db011 100644 +--- a/bishengir/include/bishengir/Dialect/HIVM/IR/HIVMAttrs.td ++++ b/bishengir/include/bishengir/Dialect/HIVM/IR/HIVMAttrs.td +@@ -783,6 +783,23 @@ def HIVM_StorageAligned : HIVM_Attr<"StorageAligned", "storage_aligned"> { + // Misc. + //===----------------------------------------------------------------------===// + ++//SH ++def HIVM_VF_SIMD : I32EnumAttrCase<"SIMD", 0>; ++def HIVM_VF_SIMT : I32EnumAttrCase<"SIMT", 1>; ++def HIVM_VF_MIX : I32EnumAttrCase<"MIX", 2>; ++ ++def HIVM_VFModeEnum ++ : HIVM_I32Enum<"VFMode", ++ "HIVM VF Mode", [HIVM_VF_SIMD, HIVM_VF_SIMT, HIVM_VF_MIX]>; ++ ++def HIVM_VFModeAttr : HIVM_I32EnumAttr<"vf_mode", HIVM_VFModeEnum> { ++ let description = [{ ++ HIVM VF mode attribute. ++ }]; ++} ++//SH ++ ++ + def HIVM_MultiBufferAttr : HIVM_Attr<"MultiBuffer", "multi_buffer"> { + let description = [{ + HIVM multi-buffer attribute. diff --git a/bishengir/include/bishengir/Dialect/HIVM/IR/HIVMImpl.h b/bishengir/include/bishengir/Dialect/HIVM/IR/HIVMImpl.h index cfd9f28..945c6b8 100644 --- a/bishengir/include/bishengir/Dialect/HIVM/IR/HIVMImpl.h @@ -98,6 +126,112 @@ index f039870..5271c3c 100644 }] >, InterfaceMethod< +diff --git a/bishengir/include/bishengir/Dialect/HIVM/IR/HIVMOps.td b/bishengir/include/bishengir/Dialect/HIVM/IR/HIVMOps.td +index 9c16c24..aea8554 100644 +--- a/bishengir/include/bishengir/Dialect/HIVM/IR/HIVMOps.td ++++ b/bishengir/include/bishengir/Dialect/HIVM/IR/HIVMOps.td +@@ -219,4 +219,101 @@ def FinishDebugOp : HIVM_Op<"finish_debug", [CubeVectorCoreTypeTrait]> { + }]; + } + ++//===----------------------------------------------------------------------===// ++// CustomOp ++//===----------------------------------------------------------------------===// ++ ++def CustomOp ++ : HIVM_StructuredOp< ++ "custom", [AttrSizedOperandSegments, ++ MemoryEffects<[MemRead, MemWrite]>, SinglePipeOpTrait, ++ DeclareOpInterfaceMethods< ++ HIVMInferCoreTypeInterface, ["inferCoreType"]>]> { ++ let summary = [{ ++ Custom operation is a generic op interface for users to write their own custom implementation. ++ ++ Scenarios: ++ 1. Existing operations could not fulfill the desired functionality. ++ 2. Existing operations could fulfill the functionality, but overall performance is not optimal. ++ 3. Desire for private operation. ++ }]; ++ ++ let description = [{ ++ General interface for custom op, where: ++ - name : unique op name. ++ ++ Note : there are names reserved for builtins, usually starts with "__builtin". ++ Compiler will link these builtins to self-contained template library, ++ which comes together within bishengir-compile. ++ ++ For normal names/cases, user needs to specify implementation location/compilation commands (TODO), ++ and all ther necessary informations. ++ ++ Available builtin names: ++ "__builtin_gather_load" ++ ++ - inputs : input parameters. ++ - outputs : output results, designated "init" operands, which act as initial values for the results ++ of the operation or the init locations to which the results of the op will be written. ++ ++ In order to adapt to future enhancements quickly and dynamically, custom op relies on attributes ++ to retreive necessary information, required informations are: ++ - CoreType : which core type to execute on, refer to TCoreTypeAttr. ++ - Pipe : which pipe to execute on, refer to PipeAttr. ++ - VFMode : which mode to run on vector units, refer to VFModeAttr. ++ this attribute is ignored when core type is cube. ++ ++ Note : for builtins, user could specify these informations or not, ++ compiler will help to check the correctness and canonicalize. ++ ++ TODO: ++ - Impl : user provided implementation. ++ - Multi Pipe : custom op wants to use multiple pipes, which is a MacroOp in HIVM's context. ++ }]; ++ ++ let arguments = (ins StrAttr:$name, Variadic:$inputs, ++ Variadic:$outputs); ++ ++ let results = (outs Variadic:$results); ++ ++ let extraClassDeclaration = [{ ++ // TODO: Customize ++ static int getOpLibraryMaxRankImpl() { return 5; } ++ ++ PIPE getPipe(); ++ ++ ::mlir::MutableOperandRange getDpsInitsMutable() { ++ return getOutputsMutable(); ++ } ++ ++ // Helper functions ++ void setPipe(PIPE); ++ ++ std::optional getCoreType(); ++ void setCoreType(TCoreType); ++ ++ std::optional getVFMode(); ++ void setVFMode(VFMode); ++ ++ bool isBuiltin(); ++ ++ // Builtins helpers ++ struct BuiltinInfo { ++ TCoreType coreType; ++ PIPE pipe; ++ VFMode vfMode; ++ }; ++ ++ // Map <-> {CORE_TYPE, PIPE, VF_MODE} ++ static const DenseMap kBuiltins; ++ }]; ++ ++ let hasVerifier = 1; ++ let hasCustomAssemblyFormat = 1; ++ let hasCanonicalizer = 1; ++} ++ ++// TODO: Add CustomMacroOp ++//SH ++ + #endif // BISHENGIR_DIALECT_HIVM_IR_HIVMOPS_TD diff --git a/bishengir/lib/Dialect/HFusion/IR/HFusionOps.cpp b/bishengir/lib/Dialect/HFusion/IR/HFusionOps.cpp index 9309262..cdd81b9 100644 --- a/bishengir/lib/Dialect/HFusion/IR/HFusionOps.cpp @@ -210,6 +344,54 @@ index d3ad5f2..f190405 100644 [](auto) -> OpPattern { return OpPattern::kReduceScatter; }) .Case( [](auto) -> OpPattern { return OpPattern::kInterleave; }) +diff --git a/bishengir/lib/Dialect/HIVM/IR/HIVMCanonicalizations.cpp b/bishengir/lib/Dialect/HIVM/IR/HIVMCanonicalizations.cpp +index 968b75d..9dce8dd 100644 +--- a/bishengir/lib/Dialect/HIVM/IR/HIVMCanonicalizations.cpp ++++ b/bishengir/lib/Dialect/HIVM/IR/HIVMCanonicalizations.cpp +@@ -558,4 +558,41 @@ LogicalResult StoreOp::fold(hivm::StoreOp::FoldAdaptor adaptor, + void mlir::hivm::HIVMDialect::getCanonicalizationPatterns( + ::mlir::RewritePatternSet &results) const { + results.add(getContext()); +-} +\ No newline at end of file ++} ++ ++//SH ++struct CustomOpCanonicalizer : public OpRewritePattern { ++ using OpRewritePattern::OpRewritePattern; ++ ++ LogicalResult matchAndRewrite(CustomOp customOp, ++ PatternRewriter &rewriter) const final { ++ if (!customOp.isBuiltin()) ++ return failure(); ++ ++ const auto &builtinInfo = CustomOp::kBuiltins.at(customOp.getName()); ++ const auto &coreType = customOp.getCoreType(); ++ if (!coreType || *coreType != builtinInfo.coreType) { ++ customOp.setCoreType(builtinInfo.coreType); ++ return success(); ++ } ++ ++ if (customOp.getPipe() != builtinInfo.pipe) { ++ customOp.setPipe(builtinInfo.pipe); ++ return success(); ++ } ++ ++ const auto &vfMode = customOp.getVFMode(); ++ if (!vfMode || *vfMode != builtinInfo.vfMode) { ++ customOp.setVFMode(builtinInfo.vfMode); ++ return success(); ++ } ++ ++ return failure(); ++ } ++}; ++ ++void CustomOp::getCanonicalizationPatterns(::mlir::RewritePatternSet &results, ++ ::mlir::MLIRContext *context) { ++ results.add(context); ++} ++//SH diff --git a/bishengir/lib/Dialect/HIVM/IR/HIVMImpl.cpp b/bishengir/lib/Dialect/HIVM/IR/HIVMImpl.cpp index b3d7807..2e8b56c 100644 --- a/bishengir/lib/Dialect/HIVM/IR/HIVMImpl.cpp @@ -223,6 +405,215 @@ index b3d7807..2e8b56c 100644 llvm::SmallVector alignedStrides(rank, 1); for (int64_t i = 0; i < rank; i++) { if (strideAlignElems[i] == 1) { +diff --git a/bishengir/lib/Dialect/HIVM/IR/HIVMOps.cpp b/bishengir/lib/Dialect/HIVM/IR/HIVMOps.cpp +index e663a2c..827ac4f 100644 +--- a/bishengir/lib/Dialect/HIVM/IR/HIVMOps.cpp ++++ b/bishengir/lib/Dialect/HIVM/IR/HIVMOps.cpp +@@ -412,3 +412,204 @@ std::string hivm::detail::getTypeName(Location loc, Type type) { + return unknown; + } + ++ ++//SH===----------------------------------------------------------------------===// ++// CustomOp ++//===----------------------------------------------------------------------===// ++ ++// Helper functions ++void CustomOp::setPipe(PIPE pipe) { ++ getOperation()->setAttr(PipeAttr::name, PipeAttr::get(getContext(), pipe)); ++} ++ ++std::optional CustomOp::getCoreType() { ++ if (const auto coreTypeAttr = ++ getOperation()->template getAttrOfType( ++ TCoreTypeAttr::name)) { ++ return coreTypeAttr.getTcoretype(); ++ } ++ ++ return {}; ++} ++ ++void CustomOp::setCoreType(TCoreType coreType) { ++ getOperation()->setAttr(TCoreTypeAttr::name, ++ TCoreTypeAttr::get(getContext(), coreType)); ++} ++ ++std::optional CustomOp::getVFMode() { ++ if (const auto vfModeAttr = ++ getOperation()->template getAttrOfType( ++ VFModeAttr::name)) { ++ return vfModeAttr.getValue(); ++ } ++ ++ return {}; ++} ++ ++void CustomOp::setVFMode(VFMode vfMode) { ++ getOperation()->setAttr(VFModeAttr::name, ++ VFModeAttr::get(getContext(), vfMode)); ++} ++ ++bool CustomOp::isBuiltin() { return kBuiltins.contains(getName()); } ++ ++ParseResult CustomOp::parse(OpAsmParser &parser, OperationState &result) { ++ if (succeeded(parser.parseOptionalLess())) { ++ if (parser.parseAttribute(result.propertiesAttr) || parser.parseGreater()) ++ return failure(); ++ } ++ ++ // Parse attributes ++ SMLoc attrsLoc = parser.getCurrentLocation(); ++ if (parser.parseOptionalAttrDict(result.attributes)) ++ return failure(); ++ ++ { // Parse name ++ std::string name{}; ++ if (parser.parseString(&name)) ++ return failure(); ++ ++ result.addAttribute("name", parser.getBuilder().getStringAttr(name)); ++ } ++ ++ { // Parse variadic args ++ SmallVector variadicArgsSizes; ++ auto parseVariadicArgs = [&](const std::string &nameHint) { ++ SMLoc loc; ++ SmallVector types; ++ SmallVector operands; ++ ++ if (succeeded(parser.parseOptionalKeyword(nameHint))) { ++ loc = parser.getCurrentLocation(); ++ if (parser.parseLParen() || parser.parseOperandList(operands) || ++ parser.parseColonTypeList(types) || parser.parseRParen()) ++ return failure(); ++ } ++ ++ if (parser.resolveOperands(operands, types, loc, result.operands)) { ++ return failure(); ++ } ++ ++ variadicArgsSizes.push_back(static_cast(operands.size())); ++ return success(); ++ }; ++ ++ if (failed(parseVariadicArgs("ins")) || failed(parseVariadicArgs("outs"))) { ++ return failure(); ++ } ++ ++ // Update operandSegmentSizes attribute ++ const auto operandSegmentSizesAttr = ++ parser.getBuilder().getDenseI32ArrayAttr(variadicArgsSizes); ++ // This is a bit complex because we're trying to be backward compatible with ++ // operation syntax that mix the inherent attributes and the discardable ++ // ones in the same dictionary. If the properties are used, we append the ++ // operandSegmentSizes there directly. Otherwise we append it to the ++ // discardable attributes dictionary where it is handled by the generic ++ // Operation::create(...) method. ++ if (result.propertiesAttr) { ++ NamedAttrList attrs = llvm::cast(result.propertiesAttr); ++ attrs.append("operandSegmentSizes", operandSegmentSizesAttr); ++ result.propertiesAttr = attrs.getDictionary(parser.getContext()); ++ } else { ++ result.addAttribute("operandSegmentSizes", operandSegmentSizesAttr); ++ std::optional info = ++ result.name.getRegisteredInfo(); ++ if (info) { ++ if (failed(info->verifyInherentAttrs(result.attributes, [&]() { ++ return parser.emitError(attrsLoc) ++ << "'" << result.name.getStringRef() << "' op "; ++ }))) ++ return failure(); ++ } ++ } ++ } ++ ++ { // Parse result types ++ SmallVector resultTypes; ++ if (parser.parseOptionalArrowTypeList(resultTypes)) { ++ return failure(); ++ } ++ result.addTypes(resultTypes); ++ } ++ ++ return success(); ++} ++ ++void CustomOp::print(OpAsmPrinter &p) { ++ p.printOptionalAttrDict(getOperation()->getAttrs(), ++ /*elidedAttrs=*/{"operandSegmentSizes", "name"}); ++ ++ p << " "; ++ p.printString(getName()); ++ ++ auto printVariadicArgs = [&](const auto &args, const std::string &nameHint) { ++ if (!args.empty()) ++ p << " " << nameHint << "(" << args << " : " << args.getTypes() << ")"; ++ }; ++ ++ printVariadicArgs(getInputs(), "ins"); ++ printVariadicArgs(getOutputs(), "outs"); ++ ++ if (!getResults().empty()) ++ p.printOptionalArrowTypeList(getResultTypes()); ++} ++ ++static LogicalResult verifyBuiltins(CustomOp op) { ++ const auto &builtinInfo = CustomOp::kBuiltins.at(op.getName()); ++ ++ const auto &coreType = op.getCoreType(); ++ if (coreType && *coreType != builtinInfo.coreType) ++ return op.emitOpError() << "Specified core type conflict with " ++ << op.getName() << "'s core type."; ++ ++ const auto &pipe = op.getPipe(); ++ if (pipe != PIPE::PIPE_UNASSIGNED && pipe != builtinInfo.pipe) ++ return op.emitOpError() ++ << "Specified pipe conflict with " << op.getName() << "'s pipe."; ++ ++ const auto &vfMode = op.getVFMode(); ++ if (vfMode && *vfMode != builtinInfo.vfMode) ++ return op.emitOpError() << "Specified vf mode conflict with " ++ << op.getName() << "'s vf mode."; ++ ++ return success(); ++} ++ ++LogicalResult CustomOp::verify() { ++ // Check builtins ++ // if (isBuiltin()) ++ // return verifyBuiltins(*this); ++ ++ // // Check core type attribute ++ // const auto coreType = getCoreType(); ++ // if (!coreType) ++ // return emitOpError() << "Missing core type information"; ++ ++ // // Check pipe attribute ++ // if (getPipe() == PIPE::PIPE_UNASSIGNED) ++ // return emitOpError() << "Missing pipe information"; ++ ++ // // Check VF mode attribute ++ // if (*coreType != TCoreType::CUBE) { ++ // if (!getVFMode()) ++ // return emitOpError() << "Missing vf mode information"; ++ // } else { // Pure cube ++ // // Cube function ignores vf mode information ++ // } ++ ++ return success(); ++} ++ ++PIPE CustomOp::getPipe() { ++ if (auto pipAttr = ++ getOperation()->template getAttrOfType(PipeAttr::name)) ++ return pipAttr.getPipe(); ++ ++ return PIPE::PIPE_UNASSIGNED; ++} ++ ++const DenseMap CustomOp::kBuiltins{ ++ {"__builtin_gather_load", {TCoreType::VECTOR, PIPE::PIPE_V, VFMode::SIMT}}}; ++//SH diff --git a/bishengir/lib/Dialect/HIVM/IR/HIVMSynchronizationOps.cpp b/bishengir/lib/Dialect/HIVM/IR/HIVMSynchronizationOps.cpp index 8c61377..0d2d7d8 100644 --- a/bishengir/lib/Dialect/HIVM/IR/HIVMSynchronizationOps.cpp @@ -281,6 +672,29 @@ index 7f9c39f..9775dca 100644 return strides; } +diff --git a/bishengir/lib/Dialect/HIVM/IR/InferCoreTypeInterface/InferCoreType.cpp b/bishengir/lib/Dialect/HIVM/IR/InferCoreTypeInterface/InferCoreType.cpp +index 39a1231..36a95e6 100644 +--- a/bishengir/lib/Dialect/HIVM/IR/InferCoreTypeInterface/InferCoreType.cpp ++++ b/bishengir/lib/Dialect/HIVM/IR/InferCoreTypeInterface/InferCoreType.cpp +@@ -111,6 +111,18 @@ inferCoreTypeForGlobalMixMatmulOps(GlobalMixMatmulTy *mixMatmulOp) { + // HIVM Ops + //===----------------------------------------------------------------------===// + ++//SH ++std::optional CustomOp::inferCoreType() { ++ if (auto coreTypeAttr = getOperation()->template getAttrOfType( ++ TCoreTypeAttr::name)) { ++ return coreTypeAttr.getTcoretype(); ++ } ++ ++ return {}; ++ } ++//SH ++ ++ + std::optional ConvertLayoutOp::inferCoreType() { + BaseMemRefType srcMemRefTy = getSource().getType(); + hivm::AddressSpace addrSpace = diff --git a/bishengir/lib/Dialect/Utils/Util.cpp b/bishengir/lib/Dialect/Utils/Util.cpp index cd4ccbe..2a2f5ac 100644 --- a/bishengir/lib/Dialect/Utils/Util.cpp diff --git a/patch/triton/include_triton_Dialect_Triton_IR_TritonOps_td.patch b/patch/triton/include_triton_Dialect_Triton_IR_TritonOps_td.patch index e2729d0c..050e5818 100644 --- a/patch/triton/include_triton_Dialect_Triton_IR_TritonOps_td.patch +++ b/patch/triton/include_triton_Dialect_Triton_IR_TritonOps_td.patch @@ -2,7 +2,7 @@ diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dial index e9f892da0..07b83751f 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td -@@ -1410,5 +1410,30 @@ def TT_DescriptorScatterOp : TT_Op<"descriptor_scatter", [TT_DescriptorStoreLike +@@ -1410,5 +1410,83 @@ def TT_DescriptorScatterOp : TT_Op<"descriptor_scatter", [TT_DescriptorStoreLike let hasVerifier = 1; } @@ -31,5 +31,58 @@ index e9f892da0..07b83751f 100644 + let arguments = (ins StrAttr:$op_name, StrAttr:$mode_or_sender, I32Attr:$id); + let assemblyFormat = "$op_name attr-dict"; +} - ++ ++//===----------------------------------------------------------------------===// ++// CustomOp ++//===----------------------------------------------------------------------===// ++def TT_CustomOp : TT_Op<"custom_op", [AttrSizedOperandSegments, MemoryEffects<[MemRead, MemWrite]>]> { ++ let summary = [{ ++ Custom operation is a generic op interface for users to write their own custom implementation. ++ ++ Scenarios: ++ 1. Existing operations could not fulfill the desired functionality. ++ 2. Existing operations could fulfill the functionality, but overall performance is not optimal. ++ 3. Desire for private operation. ++ }]; ++ ++ let description = [{ ++ General interface for custom op, where: ++ - name : unique op name. ++ ++ Note : there are names reserved for builtins, usually starts with "__builtin". ++ Compiler will link these builtins to self-contained template library, ++ which comes together within bishengir-compile. ++ ++ For normal names/cases, user needs to specify implementation location/compilation commands (TODO), ++ and all ther necessary informations. ++ ++ Available builtin names: ++ "__builtin_gather_load" ++ ++ - inputs : input parameters. ++ - outputs : output results, designated "init" operands, which act as initial values for the results ++ of the operation or the init locations to which the results of the op will be written. ++ ++ In order to adapt to future enhancements quickly and dynamically, custom op relies on attributes ++ to retreive necessary information, required informations are: ++ - CoreType : which core type to execute on, refer to TCoreTypeAttr. ++ - Pipe : which pipe to execute on, refer to PipeAttr. ++ - VFMode : which mode to run on vector units, refer to VFModeAttr. ++ this attribute is ignored when core type is cube. ++ ++ Note : for builtins, user could specify these informations or not, ++ compiler will help to check the correctness and canonicalize. ++ ++ TODO: ++ - Impl : user provided implementation. ++ - Multi Pipe : custom op wants to use multiple pipes, which is a MacroOp in HIVM's context. ++ }]; ++ ++ let arguments = (ins StrAttr:$name, Variadic:$inputs, ++ Variadic:$outputs); ++ ++ let results = (outs Variadic:$results); ++ ++} + #endif // Triton_OPS diff --git a/patch/triton/python_src_ir_cc.patch b/patch/triton/python_src_ir_cc.patch index becaaf98..557435a0 100644 --- a/patch/triton/python_src_ir_cc.patch +++ b/patch/triton/python_src_ir_cc.patch @@ -1,5 +1,5 @@ diff --git a/python/src/ir.cc b/python/src/ir.cc -index 4c8a4233b..2932bad97 100644 +index 4c8a4233b..b88942c30 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -339,6 +339,10 @@ void init_triton_ir(py::module &&m) { @@ -21,7 +21,7 @@ index 4c8a4233b..2932bad97 100644 // Ops py::class_(m, "OpState", py::module_local()) -@@ -1740,6 +1745,90 @@ void init_triton_ir(py::module &&m) { +@@ -1740,6 +1745,110 @@ void init_triton_ir(py::module &&m) { .def("create_gather", [](TritonOpBuilder &self, Value src, Value indices, int axis) -> Value { return self.create(src, indices, axis); }) @@ -109,6 +109,26 @@ index 4c8a4233b..2932bad97 100644 + self.getBuilder().getStringAttr(mode_or_sender), + self.getBuilder().getI32IntegerAttr(id)); + }) ++ .def("create_custom_op", ++ [](TritonOpBuilder &self, ++ const std::string &name, ++ const py::dict &attrs, ++ const std::vector &ins, ++ const std::vector &outs) -> std::vector { ++ ValueRange inputs{ins}; ++ ValueRange outputs{outs}; ++ TypeRange res_types{outputs}; ++ auto op = self.create(res_types, name, inputs, outputs); ++ for (auto &attr : attrs) { ++ std::string attr_name = py::cast(attr.first); ++ std::string attr_value = py::cast(attr.second); ++ op->setAttr(attr_name, self.getBuilder().getStringAttr(attr_value)); ++ // Attribute attr_value = py::cast(attr.second); ++ // op->setAttr(attr_name, attr_value); ++ } ++ auto results = op->getResults(); ++ return std::vector(results.begin(), results.end()); ++ }) // Force GPU barrier .def("create_barrier", [](TritonOpBuilder &self) { self.create(); }) diff --git a/test/ascend/test_custom_op.py b/test/ascend/test_custom_op.py new file mode 100644 index 00000000..3c3ee22b --- /dev/null +++ b/test/ascend/test_custom_op.py @@ -0,0 +1,239 @@ +#!/usr/bin/env python3 +import subprocess +import os +import triton +import triton.language as tl +import triton.language.extra.deeplink as dl + +from triton.compiler.compiler import ASTSource +from triton.compiler.code_generator import ast_to_ttir +from triton._C.libtriton import ir +import hashlib +from triton.backends.dicp_triton.npu import NPUUtils +from triton.backends.compiler import GPUTarget +from triton.compiler.compiler import make_backend, IRSource, filter_traceback +from triton import __version__, knobs +from triton.runtime.cache import ( + get_cache_manager, + get_dump_manager, + get_override_manager, + get_cache_key, +) + +from triton.backends.dicp_triton.npu import ( + make_ttir, + ttir_to_linalg, + ttir_to_ttsharedir_ascend, + ttsharedir_to_linkedir, + linalg_to_bin_enable_npu_compile, + NPUOptions, +) + + +@dl.register_custom_op +class add: + core = dl.CORE.VECTOR + pipe = dl.PIPE.PIPE_V + mode = dl.MODE.SIMD + + def __init__(self, a, b, out=None): + assert out, "out is required" + self.symbol = "custom_add_" + str(a.dtype) + # self.bitcode defaults to the Ascend installation directory + # Typically it would be a specific bitcode file like /path/to/kernel.aiv.bc + self.bitcode = "/usr/local/Ascend/" + + +@triton.jit +def triton_custom_add(output_ptr, a_ptr, b_ptr, L: tl.constexpr): + idx = tl.arange(0, L) + + a = tl.load(a_ptr + idx) + b = tl.load(b_ptr + idx) + + buf = tl.full([L], 0, a.dtype) + res = dl.custom("add", a, b, out=buf) + + tl.store(output_ptr + idx, res) + + +def compile(src, target=None, options=None, _env_vars=None): + compilation_listener = knobs.compilation.listener + if compilation_listener: + timer = CompileTimer() + + if target is None: + target = driver.active.get_current_target() + assert isinstance(target, GPUTarget), "target must be of GPUTarget type" + backend = make_backend(target) + ir_source = not isinstance(src, ASTSource) + # create backend + if ir_source: + assert isinstance(src, str), "source must be either AST or a filepath" + context = ir.context() + src = IRSource(src, context, backend) + + extra_options = src.parse_options() + options = backend.parse_options(dict(options or dict(), **extra_options)) + # create cache manager + env_vars = get_cache_invalidating_env_vars() if _env_vars is None else _env_vars + key = get_cache_key(src, backend, options, env_vars=env_vars) + hash = hashlib.sha256(key.encode("utf-8")).hexdigest() + fn_cache_manager = get_cache_manager(hash) + # For dumping/overriding only hash the source as we want it to be independent of triton + # core changes to make it easier to track kernels by hash. + enable_override = knobs.compilation.override + enable_ir_dump = knobs.compilation.dump_ir + store_only_binary = knobs.compilation.store_binary_only + fn_override_manager = get_override_manager(src.hash()) if enable_override else None + fn_dump_manager = get_dump_manager(src.hash()) if enable_ir_dump else None + # Pre-truncate the file name here to avoid hitting the 255 character limit on common platforms. + # The final file name in the cache will have a format of f"{filename}.{ext}.tmp.pid_{pid}_{uuid}". + # A PID string can be 5-character long. A UUID string has typically 36 characters. Let's truncate + # the file name to 150 characters to be safe. + file_name = src.name[:150] + metadata_filename = f"{file_name}.json" + metadata_group = fn_cache_manager.get_group(metadata_filename) or {} + metadata_path = metadata_group.get(metadata_filename) + always_compile = knobs.compilation.always_compile + if not always_compile and metadata_path is not None: + # cache hit! + res = CompiledKernel(src, metadata_group, hash) + if compilation_listener: + compilation_listener( + src=src, + metadata=res.metadata._asdict(), + metadata_group=metadata_group, + times=timer.end(), + cache_hit=True, + ) + return res + + # initialize metadata + metadata = { + "hash": hash, + "target": target, + **options.__dict__, + **env_vars, + } + metadata["triton_version"] = __version__ + # run compilation pipeline and populate metadata + stages = dict() + backend.add_stages(stages, options, src.language) + first_stage = list(stages.keys()).index(src.ext) + # when the source is an IR file, don't apply the passes related to this stage. This makes it easier to write IR level tests. + if ir_source: + first_stage += 1 + + # For IRSource, we have already grabbed the context + called both + # ir.load_dialects and backend.load_dialects. + if not isinstance(src, IRSource): + context = ir.context() + ir.load_dialects(context) + backend.load_dialects(context) + + codegen_fns = backend.get_codegen_implementation(options) + module_map = backend.get_module_map() + try: + module = src.make_ir(target, options, codegen_fns, module_map, context) + except Exception as e: + filter_traceback(e) + raise + + if ir_source: + ir_filename = f"{file_name}.{src.ext}" + metadata_group[ir_filename] = fn_cache_manager.put(module, ir_filename) + else: + ir_filename = f"{file_name}.source" + metadata_group[ir_filename] = fn_cache_manager.put(module, ir_filename) + + use_ir_loc = knobs.compilation.use_ir_loc + if ir_source and use_ir_loc: + module.create_location_snapshot(src.path) + print(f"Creating new locations for {src.path}") + + if compilation_listener: + timer.finished_ir_initialization() + + if "npubin" in stages.keys(): + del stages["npubin"] + + for ext, compile_ir in list(stages.items())[first_stage:]: + next_module = compile_ir(module, metadata) + ir_filename = f"{file_name}.{ext}" + if fn_override_manager is None: + # Users can override kernels at scale by setting `ir_override` in autotune config + # without TRITON_KERNEL_OVERRIDE + if ( + ir_override := metadata.get("ir_override", None) + ) and ir_override.endswith(f".{ext}"): + next_module = parse(ir_override, ext, context) + elif full_name := fn_override_manager.get_file(ir_filename): + print(f"\nOverriding kernel with file {full_name}") + next_module = parse(full_name, ext, context) + # If TRITON_STORE_BINARY_ONLY is 1, only store cubin/hsaco/json + if (not store_only_binary) or (ext in ("cubin", "hsaco", "json")): + metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename) + if fn_dump_manager is not None: + fn_dump_manager.put(next_module, ir_filename) + if ext == "cubin": + sass = get_sass(next_module) + fn_dump_manager.put(sass, file_name + ".sass") + # use an env variable to parse ir from file + if use_ir_loc == ext: + ir_full_name = fn_cache_manager.get_file(ir_filename) + next_module.create_location_snapshot(ir_full_name) + print(f"Creating new locations for {ir_full_name}") + module = next_module + if compilation_listener: + timer.stage_finished(ext) + return module + + +if __name__ == "__main__": + npuutiles = NPUUtils() + src = ASTSource( + triton_custom_add, + {"output_ptr": "*i32", "a_ptr": "*i32", "b_ptr": "*i32"}, + {"L": 32}, + ) + target = GPUTarget(backend="ascend", arch=npuutiles.get_arch(), warp_size=0) + options = { + "debug": False, + "sanitize_overflow": False, + "llvm_version": 15, + "kernel_name": "triton_", + "cluster_dims": (1, 1, 1), + "num_warps": -1, + "num_ctas": -1, + "num_stages": 2, + "num_buffers_warp_spec": 0, + "num_consumer_groups": 0, + "reg_dec_producer": 0, + "reg_inc_consumer": 0, + "enable_warp_specialization": False, + "enable_nd2nz_on_vector": False, + "enable_persistent": False, + "optimize_epilogue": False, + "enable_fp_fusion": True, + "allow_fp8e4nv": False, + "allowed_dot_input_precisions": ("ieee", "hf32"), + "enable_npu_compile": True, + "max_num_imprecise_acc_default": None, + "extern_libs": None, + "multibuffer": True, + "inject_barrier_all": False, + "disable_auto_inject_block_sync": False, + "unit_flag": False, + "disable_auto_cv_work_space_manage": False, + "enable_auto_bind_sub_block": True, + "tile_mix_vector_loop": None, + "tile_mix_cube_loop": None, + "limit_auto_multi_buffer_only_for_local_buffer": None, + "set_workspace_multibuffer": None, + "stream": None, + } + linkedir = compile(src, target, options, {}) + + print("=== MLIR (linkedir) ===") + print(linkedir)