diff --git a/examples/07_gemm_all_scatter/matmul_wrapper.py b/examples/07_gemm_all_scatter/matmul_wrapper.py index 5d8adb58..0e710435 100644 --- a/examples/07_gemm_all_scatter/matmul_wrapper.py +++ b/examples/07_gemm_all_scatter/matmul_wrapper.py @@ -6,37 +6,15 @@ # from streamk_kernel import streamk_gemm from gemm_all_scatter import persistent_gemm_all_scatter -from examples.common.utils import is_triton_interpret_set +from examples.common.matmul_helpers import MatmulDebugMixin import iris gemm_kernel = persistent_gemm_all_scatter -class matmul(torch.autograd.Function): - _debug = False - _registers = None - _spills = None - +class matmul(MatmulDebugMixin, torch.autograd.Function): _num_xcds = iris.hip.get_num_xcc() - @staticmethod - def set_debug(debug: bool): - matmul._debug = debug - - @staticmethod - def get_matmul_registers(): - if matmul._debug: - return matmul._registers - else: - raise RuntimeError("Debug mode is not enabled. Call set_debug(True) first.") - - @staticmethod - def get_matmul_spills(): - if matmul._debug: - return matmul._spills - else: - raise RuntimeError("Debug mode is not enabled. 
Call set_debug(True) first.") - @staticmethod def _call( a: torch.Tensor, @@ -119,9 +97,7 @@ def _call( mm_end_timestamp_ptr=mm_end_timestamp, ) - if matmul._debug and not is_triton_interpret_set(): - matmul._registers = kk.n_regs - matmul._spills = kk.n_spills + matmul._track_debug_info(kk) return c diff --git a/examples/08_gemm_all_reduce_atomics/matmul_wrapper.py b/examples/08_gemm_all_reduce_atomics/matmul_wrapper.py index 8b64759e..c02a9dd4 100644 --- a/examples/08_gemm_all_reduce_atomics/matmul_wrapper.py +++ b/examples/08_gemm_all_reduce_atomics/matmul_wrapper.py @@ -11,20 +11,12 @@ # from streamk_kernel_atomic import streamk_gemm from gemm_all_reduce_atomics import persistent_gemm_all_reduce -from examples.common.utils import is_triton_interpret_set +from examples.common.matmul_helpers import MatmulDebugMixin import iris gemm_kernel = persistent_gemm_all_reduce - -class matmul(torch.autograd.Function): - _debug = True - - @staticmethod - def set_debug(debug: bool): - matmul._debug = debug - matmul.streamk_registers = 0 - matmul.streamk_spills = 0 +class matmul(MatmulDebugMixin, torch.autograd.Function): @staticmethod def _call( @@ -109,9 +101,7 @@ def _call( mm_end_timestamp_ptr=mm_end_timestamp, ) - if matmul._debug and not is_triton_interpret_set(): - matmul.streamk_registers = kk.n_regs - matmul.streamk_spills = kk.n_spills + matmul._track_debug_info(kk) return c diff --git a/examples/09_gemm_one_shot_all_reduce/matmul_wrapper.py b/examples/09_gemm_one_shot_all_reduce/matmul_wrapper.py index 49e53c0d..b46388b4 100644 --- a/examples/09_gemm_one_shot_all_reduce/matmul_wrapper.py +++ b/examples/09_gemm_one_shot_all_reduce/matmul_wrapper.py @@ -11,23 +11,15 @@ # from streamk_kernel_atomic import streamk_gemm from gemm_one_shot_all_reduce import persistent_gemm_all_reduce -from examples.common.utils import is_triton_interpret_set +from examples.common.matmul_helpers import MatmulDebugMixin import iris gemm_kernel = persistent_gemm_all_reduce - -class 
matmul(torch.autograd.Function): - _debug = True +class matmul(MatmulDebugMixin, torch.autograd.Function): _num_xcds = iris.hip.get_num_xcc() - @staticmethod - def set_debug(debug: bool): - matmul._debug = debug - matmul.streamk_registers = 0 - matmul.streamk_spills = 0 - @staticmethod def _call( a: torch.Tensor, @@ -150,12 +142,7 @@ def _call( mm_end_timestamp_ptr=mm_end_timestamp, ) - if matmul._debug and not is_triton_interpret_set(): - matmul.streamk_registers = kk.n_regs - matmul.streamk_spills = kk.n_spills - print(f"{kk.n_regs} registers used, {kk.n_spills} spills") - # print(kk.asm['ttgir']) - # print(kk.asm['amdgcn']) + matmul._track_debug_info(kk) return c diff --git a/examples/10_gemm_all_scatter_wg_specialization/matmul_wrapper.py b/examples/10_gemm_all_scatter_wg_specialization/matmul_wrapper.py index 1d46297a..56c2df86 100644 --- a/examples/10_gemm_all_scatter_wg_specialization/matmul_wrapper.py +++ b/examples/10_gemm_all_scatter_wg_specialization/matmul_wrapper.py @@ -8,37 +8,15 @@ from gemm_all_scatter_wg_specialization import ( persistent_gemm_all_scatter_wg_specialization, ) -from examples.common.utils import is_triton_interpret_set +from examples.common.matmul_helpers import MatmulDebugMixin import iris gemm_kernel = persistent_gemm_all_scatter_wg_specialization - -class matmul(torch.autograd.Function): - _debug = False - _registers = None - _spills = None +class matmul(MatmulDebugMixin, torch.autograd.Function): _num_xcds = iris.hip.get_num_xcc() - @staticmethod - def set_debug(debug: bool): - matmul._debug = debug - - @staticmethod - def get_matmul_registers(): - if matmul._debug: - return matmul._registers - else: - raise RuntimeError("Debug mode is not enabled. Call set_debug(True) first.") - - @staticmethod - def get_matmul_spills(): - if matmul._debug: - return matmul._spills - else: - raise RuntimeError("Debug mode is not enabled. 
Call set_debug(True) first.") - @staticmethod def _call( a: torch.Tensor, @@ -125,9 +103,7 @@ def _call( mm_end_timestamp_ptr=mm_end_timestamp, ) - if matmul._debug and not is_triton_interpret_set(): - matmul._registers = kk.n_regs - matmul._spills = kk.n_spills + matmul._track_debug_info(kk) return c diff --git a/examples/11_gemm_all_scatter_producer_consumer/matmul_wrapper.py b/examples/11_gemm_all_scatter_producer_consumer/matmul_wrapper.py index 02dd22e1..326bb7f6 100644 --- a/examples/11_gemm_all_scatter_producer_consumer/matmul_wrapper.py +++ b/examples/11_gemm_all_scatter_producer_consumer/matmul_wrapper.py @@ -6,37 +6,15 @@ # from streamk_kernel import streamk_gemm from gemm_all_scatter_producer_consumer import persistent_gemm -from examples.common.utils import is_triton_interpret_set +from examples.common.matmul_helpers import MatmulDebugMixin import iris gemm_kernel = persistent_gemm - -class matmul(torch.autograd.Function): - _debug = False - _registers = None - _spills = None +class matmul(MatmulDebugMixin, torch.autograd.Function): _num_xcds = iris.hip.get_num_xcc() - @staticmethod - def set_debug(debug: bool): - matmul._debug = debug - - @staticmethod - def get_matmul_registers(): - if matmul._debug: - return matmul._registers - else: - raise RuntimeError("Debug mode is not enabled. Call set_debug(True) first.") - - @staticmethod - def get_matmul_spills(): - if matmul._debug: - return matmul._spills - else: - raise RuntimeError("Debug mode is not enabled. 
Call set_debug(True) first.") - @staticmethod def _call( a: torch.Tensor, @@ -118,9 +96,7 @@ def _call( mm_end_timestamp_ptr=mm_end_timestamp, ) - if matmul._debug and not is_triton_interpret_set(): - matmul._registers = kk.n_regs - matmul._spills = kk.n_spills + matmul._track_debug_info(kk) return c diff --git a/examples/12_gemm_all_scatter_bulk_synchronous/matmul_wrapper.py b/examples/12_gemm_all_scatter_bulk_synchronous/matmul_wrapper.py index d8b1ab7b..bd7f55bb 100644 --- a/examples/12_gemm_all_scatter_bulk_synchronous/matmul_wrapper.py +++ b/examples/12_gemm_all_scatter_bulk_synchronous/matmul_wrapper.py @@ -6,36 +6,14 @@ # from streamk_kernel import streamk_gemm from gemm_all_scatter_bulk_synchronous import persistent_gemm -from examples.common.utils import is_triton_interpret_set +from examples.common.matmul_helpers import MatmulDebugMixin import iris gemm_kernel = persistent_gemm - -class matmul(torch.autograd.Function): - _debug = False - _registers = None - _spills = None +class matmul(MatmulDebugMixin, torch.autograd.Function): _num_xcds = iris.hip.get_num_xcc() - @staticmethod - def set_debug(debug: bool): - matmul._debug = debug - - @staticmethod - def get_matmul_registers(): - if matmul._debug: - return matmul._registers - else: - raise RuntimeError("Debug mode is not enabled. Call set_debug(True) first.") - - @staticmethod - def get_matmul_spills(): - if matmul._debug: - return matmul._spills - else: - raise RuntimeError("Debug mode is not enabled. 
Call set_debug(True) first.") - @staticmethod def _call( a: torch.Tensor, @@ -114,9 +92,7 @@ def _call( mm_end_timestamp_ptr=mm_end_timestamp, ) - if matmul._debug and not is_triton_interpret_set(): - matmul._registers = kk.n_regs - matmul._spills = kk.n_spills + matmul._track_debug_info(kk) return c diff --git a/examples/15_gemm_all_reduce_ring_based/matmul_wrapper.py b/examples/15_gemm_all_reduce_ring_based/matmul_wrapper.py index b539d070..7e5c557f 100644 --- a/examples/15_gemm_all_reduce_ring_based/matmul_wrapper.py +++ b/examples/15_gemm_all_reduce_ring_based/matmul_wrapper.py @@ -11,37 +11,15 @@ # from streamk_kernel_atomic import streamk_gemm from gemm_all_reduce_ring_based import persistent_gemm -from examples.common.utils import is_triton_interpret_set +from examples.common.matmul_helpers import MatmulDebugMixin import iris gemm_kernel = persistent_gemm - -class matmul(torch.autograd.Function): - _debug = True - _registers = None - _spills = None +class matmul(MatmulDebugMixin, torch.autograd.Function): _num_xcds = iris.hip.get_num_xcc() - @staticmethod - def set_debug(debug: bool): - matmul._debug = debug - - @staticmethod - def get_matmul_registers(): - if matmul._debug: - return matmul._registers - else: - raise RuntimeError("Debug mode is not enabled. Call set_debug(True) first.") - - @staticmethod - def get_matmul_spills(): - if matmul._debug: - return matmul._spills - else: - raise RuntimeError("Debug mode is not enabled. 
Call set_debug(True) first.") - @staticmethod def _call( a: torch.Tensor, @@ -123,9 +101,7 @@ def _call( mm_end_timestamp_ptr=mm_end_timestamp, ) - # if matmul._debug and not is_triton_interpret_set(): - matmul._registers = kk.n_regs - matmul._spills = kk.n_spills + matmul._track_debug_info(kk) return ring_buffer diff --git a/examples/20_gemm_all_scatter_independent/matmul_wrapper.py b/examples/20_gemm_all_scatter_independent/matmul_wrapper.py index d8b1ab7b..bd7f55bb 100644 --- a/examples/20_gemm_all_scatter_independent/matmul_wrapper.py +++ b/examples/20_gemm_all_scatter_independent/matmul_wrapper.py @@ -6,36 +6,14 @@ # from streamk_kernel import streamk_gemm from gemm_all_scatter_bulk_synchronous import persistent_gemm -from examples.common.utils import is_triton_interpret_set +from examples.common.matmul_helpers import MatmulDebugMixin import iris gemm_kernel = persistent_gemm - -class matmul(torch.autograd.Function): - _debug = False - _registers = None - _spills = None +class matmul(MatmulDebugMixin, torch.autograd.Function): _num_xcds = iris.hip.get_num_xcc() - @staticmethod - def set_debug(debug: bool): - matmul._debug = debug - - @staticmethod - def get_matmul_registers(): - if matmul._debug: - return matmul._registers - else: - raise RuntimeError("Debug mode is not enabled. Call set_debug(True) first.") - - @staticmethod - def get_matmul_spills(): - if matmul._debug: - return matmul._spills - else: - raise RuntimeError("Debug mode is not enabled. 
Call set_debug(True) first.") - @staticmethod def _call( a: torch.Tensor, @@ -114,9 +92,7 @@ def _call( mm_end_timestamp_ptr=mm_end_timestamp, ) - if matmul._debug and not is_triton_interpret_set(): - matmul._registers = kk.n_regs - matmul._spills = kk.n_spills + matmul._track_debug_info(kk) return c diff --git a/examples/21_gemm_one_shot_all_reduce_independent/matmul_wrapper.py b/examples/21_gemm_one_shot_all_reduce_independent/matmul_wrapper.py index b2184a01..10452cc7 100644 --- a/examples/21_gemm_one_shot_all_reduce_independent/matmul_wrapper.py +++ b/examples/21_gemm_one_shot_all_reduce_independent/matmul_wrapper.py @@ -9,37 +9,15 @@ from gemm_one_shot_all_reduce_independent import persistent_gemm -from examples.common.utils import is_triton_interpret_set +from examples.common.matmul_helpers import MatmulDebugMixin import iris gemm_kernel = persistent_gemm - -class matmul(torch.autograd.Function): - _debug = True - _registers = None - _spills = None +class matmul(MatmulDebugMixin, torch.autograd.Function): _num_xcds = iris.hip.get_num_xcc() - @staticmethod - def set_debug(debug: bool): - matmul._debug = debug - - @staticmethod - def get_matmul_registers(): - if matmul._debug: - return matmul._registers - else: - raise RuntimeError("Debug mode is not enabled. Call set_debug(True) first.") - - @staticmethod - def get_matmul_spills(): - if matmul._debug: - return matmul._spills - else: - raise RuntimeError("Debug mode is not enabled. Call set_debug(True) first.") - @staticmethod def _call( a: torch.Tensor, @@ -117,8 +95,7 @@ def _call( mm_end_timestamp_ptr=mm_end_timestamp, ) - matmul._registers = kk.n_regs - matmul._spills = kk.n_spills + matmul._track_debug_info(kk) return C diff --git a/examples/common/matmul_helpers.py b/examples/common/matmul_helpers.py new file mode 100644 index 00000000..2efd085e --- /dev/null +++ b/examples/common/matmul_helpers.py @@ -0,0 +1,83 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. 
# All rights reserved.

"""
Common utilities for matmul wrappers in Iris GEMM examples.

This module provides a shared mixin class used to reduce code duplication
across the matmul wrapper implementations in the GEMM examples.
"""

import torch  # NOTE(review): unused in this module; kept to avoid changing the file's imports


class MatmulDebugMixin:
    """
    Mixin class providing debug functionality for matmul wrappers.

    Mix this into ``torch.autograd.Function`` subclasses to add standardized
    debug-flag management and kernel register/spill tracking.

    Usage:
        class matmul(MatmulDebugMixin, torch.autograd.Function):
            # ...your implementation...
            pass
    """

    _debug = False      # when True, _track_debug_info() records kernel stats
    _registers = None   # n_regs from the most recently tracked kernel launch
    _spills = None      # n_spills from the most recently tracked kernel launch

    @classmethod
    def set_debug(cls, debug: bool):
        """Enable or disable debug mode for register/spill tracking."""
        cls._debug = debug
        # Legacy attribute names used by some examples. Reset on every call to
        # match the pre-refactor set_debug() behavior (which zeroed them each
        # time) rather than initializing them only once and letting stale
        # values survive repeated set_debug() calls.
        cls.streamk_registers = 0
        cls.streamk_spills = 0

    @classmethod
    def get_matmul_registers(cls):
        """Return the number of registers used by the kernel (debug mode only).

        Raises:
            RuntimeError: If debug mode has not been enabled via set_debug(True).
        """
        if not cls._debug:
            raise RuntimeError("Debug mode is not enabled. Call set_debug(True) first.")
        # Prefer the canonical attribute; fall back to the legacy name, then 0.
        if cls._registers is not None:
            return cls._registers
        return getattr(cls, "streamk_registers", 0)

    @classmethod
    def get_matmul_spills(cls):
        """Return the number of register spills in the kernel (debug mode only).

        Raises:
            RuntimeError: If debug mode has not been enabled via set_debug(True).
        """
        if not cls._debug:
            raise RuntimeError("Debug mode is not enabled. Call set_debug(True) first.")
        # Prefer the canonical attribute; fall back to the legacy name, then 0.
        if cls._spills is not None:
            return cls._spills
        return getattr(cls, "streamk_spills", 0)

    @classmethod
    def _track_debug_info(cls, kernel_result):
        """
        Track register and spill information from kernel execution.

        Call this after kernel invocation; it is a no-op unless debug mode is
        enabled, and is skipped under the Triton interpreter (which exposes no
        compiled-kernel stats).

        Args:
            kernel_result: The kernel object returned from the Triton kernel
                invocation (exposes ``n_regs`` and ``n_spills``).
        """
        if not cls._debug:
            return
        # Imported lazily by the only consumer so that the module (and its
        # debug-flag API) stays importable even when the examples package
        # layout is not on the path.
        from examples.common.utils import is_triton_interpret_set

        if is_triton_interpret_set():
            return
        cls._registers = kernel_result.n_regs
        cls._spills = kernel_result.n_spills
        # Keep the legacy attribute names in sync for older callers.
        cls.streamk_registers = kernel_result.n_regs
        cls.streamk_spills = kernel_result.n_spills