From ae3eaa44b014645d68e42ae4e7eaf8298c0282d4 Mon Sep 17 00:00:00 2001
From: Alessandro Sangiorgi
Date: Fri, 21 Nov 2025 12:26:05 -0600
Subject: [PATCH 1/3] Add support for loading kernel mappings dynamically

---
 benchmarks/run.py | 330 ++++++++++++++++++++++++++++++++++++++++++----
 1 file changed, 305 insertions(+), 25 deletions(-)

diff --git a/benchmarks/run.py b/benchmarks/run.py
index 5ccb5cdd4..a0a423e74 100644
--- a/benchmarks/run.py
+++ b/benchmarks/run.py
@@ -38,10 +38,12 @@
 from typing import Any
 from typing import Callable
 from typing import cast
+import warnings
 
 import torch
 from torch.utils._pytree import tree_leaves
 from torch.utils._pytree import tree_map
+import yaml
 
 from helion._testing import get_nvidia_gpu_model
 from helion._utils import counters
@@ -654,6 +656,180 @@ class RunResult:
 }
 
 
+def load_kernel_config(
+    config_path: str,
+) -> tuple[dict[str, tuple[str, ...]], dict[str, dict[str, str]]]:
+    """Load kernel configuration from YAML or JSON file.
+
+    Args:
+        config_path: Path to configuration file (YAML or JSON)
+
+    Returns:
+        Tuple of (kernel_mappings, kernel_metric_mappings)
+
+    Raises:
+        ValueError: If configuration is invalid
+        FileNotFoundError: If config file doesn't exist
+    """
+    config_file = Path(config_path)
+    if not config_file.exists():
+        raise FileNotFoundError(f"Configuration file not found: {config_path}")
+
+    # Load configuration
+    with open(config_file, "r") as f:
+        if config_file.suffix in [".yaml", ".yml"]:
+            config = yaml.safe_load(f)
+        elif config_file.suffix == ".json":
+            config = json.load(f)
+        else:
+            raise ValueError(
+                f"Unsupported config file format: {config_file.suffix}. Use .yaml, .yml, or .json"
+            )
+
+    if not isinstance(config, dict):
+        raise ValueError("Configuration must be a dictionary")
+
+    kernel_mappings = {}
+    kernel_metric_mappings = {}
+
+    # Process kernel mappings - now reusing process_single_kernel_mapping
+    if "kernel_mappings" in config:
+        raw_mappings = config["kernel_mappings"]
+        if not isinstance(raw_mappings, dict):
+            raise ValueError("kernel_mappings must be a dictionary")
+
+        for kernel_name, mapping in raw_mappings.items():
+            # Reuse the single source of truth for processing
+            kernel_mappings[kernel_name] = process_single_kernel_mapping(
+                kernel_name, mapping
+            )
+
+    # Process kernel metric mappings
+    if "kernel_metric_mappings" in config:
+        raw_metrics = config["kernel_metric_mappings"]
+        if not isinstance(raw_metrics, dict):
+            raise ValueError("kernel_metric_mappings must be a dictionary")
+
+        for kernel_name, metrics in raw_metrics.items():
+            if not isinstance(metrics, dict):
+                raise ValueError(
+                    f"Invalid metrics for kernel '{kernel_name}': must be a dictionary"
+                )
+            kernel_metric_mappings[kernel_name] = metrics
+
+    # Process hardware-specific overrides if present
+    if "hardware_overrides" in config and is_cuda():
+        gpu_model = get_nvidia_gpu_model()
+        if gpu_model in config["hardware_overrides"]:
+            hw_config = config["hardware_overrides"][gpu_model]
+
+            # Merge hardware-specific kernel mappings - now reusing the helper
+            if "kernel_mappings" in hw_config:
+                for kernel_name, mapping in hw_config["kernel_mappings"].items():
+                    # Reuse the same helper for consistency
+                    kernel_mappings[kernel_name] = process_single_kernel_mapping(
+                        kernel_name, mapping
+                    )
+
+            # Merge hardware-specific metric mappings
+            if "kernel_metric_mappings" in hw_config:
+                for kernel_name, metrics in hw_config["kernel_metric_mappings"].items():
+                    if kernel_name not in kernel_metric_mappings:
+                        kernel_metric_mappings[kernel_name] = {}
+
kernel_metric_mappings[kernel_name].update(metrics) + + return kernel_mappings, kernel_metric_mappings + + +def process_single_kernel_mapping( + kernel_name: str, mapping: dict[str, Any] +) -> tuple[str, ...]: + """Process a single kernel mapping configuration.""" + if not isinstance(mapping, dict): + raise ValueError( + f"Invalid mapping for kernel '{kernel_name}': must be a dictionary" + ) + + if "tritonbench_module" not in mapping: + raise ValueError(f"Missing 'tritonbench_module' for kernel '{kernel_name}'") + + tritonbench_module = mapping["tritonbench_module"] + + # Handle variants + if "variants" in mapping: + variants = [] + for variant in mapping["variants"]: + if "helion_module" not in variant or "helion_func" not in variant: + raise ValueError( + f"Variant in kernel '{kernel_name}' must have 'helion_module' and 'helion_func'" + ) + variants.append((variant["helion_module"], variant["helion_func"])) + + if "args" in mapping: + return (tritonbench_module, variants, mapping["args"]) + else: + return (tritonbench_module, variants) + else: + # Single implementation format + if "helion_module" not in mapping or "helion_func" not in mapping: + raise ValueError( + f"Kernel '{kernel_name}' must have 'helion_module' and 'helion_func' or 'variants'" + ) + + if "args" in mapping: + return ( + tritonbench_module, + mapping["helion_module"], + mapping["helion_func"], + mapping["args"], + ) + else: + return ( + tritonbench_module, + mapping["helion_module"], + mapping["helion_func"], + ) + + +def merge_kernel_configs( + base_mappings: dict[str, tuple[str, ...]], + base_metrics: dict[str, dict[str, str]], + custom_mappings: dict[str, tuple[str, ...]], + custom_metrics: dict[str, dict[str, str]], +) -> tuple[dict[str, tuple[str, ...]], dict[str, dict[str, str]]]: + """Merge custom kernel configurations with base configurations. + + Custom configs extend and can override base configs. 
+ This allows users to: + - Add new kernels not in the base config + - Override existing kernel definitions + - Add or override metric mappings + + Args: + base_mappings: Base kernel mappings (hardcoded) + base_metrics: Base metric mappings (hardcoded) + custom_mappings: Custom kernel mappings from config file + custom_metrics: Custom metric mappings from config file + + Returns: + Tuple of merged (kernel_mappings, kernel_metric_mappings) + """ + # Start with base, then overlay custom (custom takes precedence) + merged_mappings = {**base_mappings, **custom_mappings} + + # For metrics, merge at the kernel level + merged_metrics = dict(base_metrics) + for kernel, metrics in custom_metrics.items(): + if kernel in merged_metrics: + # Merge metrics for this kernel (custom overrides base for same keys) + merged_metrics[kernel] = {**merged_metrics[kernel], **metrics} + else: + # New kernel, add all its metrics + merged_metrics[kernel] = metrics + + return merged_mappings, merged_metrics + + def check_and_setup_tritonbench() -> None: """Ensure a usable tritonbench installation is available.""" @@ -812,17 +988,29 @@ def run_kernel( tritonbench_args: list[str], input_shard_info: tuple[int, int] | None, results: list[RunResult], + kernel_mappings: dict[str, tuple[str, ...]] | None = None, + kernel_metric_mappings: dict[str, dict[str, str]] | None = None, ) -> None: """Run a kernel benchmark, handling both single and multiple variants.""" + # Use provided mappings or default to global mappings + active_mappings = ( + kernel_mappings if kernel_mappings is not None else KERNEL_MAPPINGS + ) + active_metrics = ( + kernel_metric_mappings + if kernel_metric_mappings is not None + else KERNEL_METRIC_MAPPINGS + ) + # Check if kernel is in the mapping table - if kernel_name not in KERNEL_MAPPINGS: + if kernel_name not in active_mappings: print(f"Error: Unknown kernel '{kernel_name}'", file=sys.stderr) print( - f"Available kernels: {', '.join(KERNEL_MAPPINGS.keys())}", file=sys.stderr + f"Available kernels: {', '.join(active_mappings.keys())}", file=sys.stderr ) sys.exit(1) - mapping = KERNEL_MAPPINGS[kernel_name] + mapping = active_mappings[kernel_name] # Extract operator args if present operator_args = {} @@ -859,6 +1047,7 @@ def run_kernel( input_shard_info, operator_args, results, + active_metrics, ) @@ -870,6 +1059,7 @@ def run_kernel_variants( input_shard_info: tuple[int, int] | None, operator_args: dict[str, Any] | None, results: list[RunResult], + kernel_metric_mappings: dict[str, dict[str, str]] | None = None, ) -> None: """Run kernel variants in the same benchmark run.""" @@ -1006,7 +1196,9 @@ def helion_method( if isinstance(kfunc, Kernel): # Helion kernel - we call it in a lambda to delay execution until measurement - measured_func_callable = lambda: kfunc(*args, **kwargs) # noqa: E731 + measured_func_callable = lambda: kfunc( + *args, **kwargs + ) # noqa: E731 else: # tritonbench integration wrapper - pass tritonbench operator instance as first argument # The wrapper must return a callable that does the actual computation, for delayed execution @@ -1112,7 +1304,9 @@ def accuracy_fail_hook( tritonbench_run(tritonbench_args) tmp.seek(0) try: - process_result(kernel_name, tmp.readlines(), results) + process_result( + kernel_name, tmp.readlines(), results, kernel_metric_mappings + ) except Exception: logger.exception("failed to process results") @@ -1133,9 +1327,24 @@ def get_device_name() -> str: def process_result( - kernel_name: str, lines: list[str], results: list[RunResult] + kernel_name: str, + 
lines: list[str], + results: list[RunResult], + kernel_metric_mappings: dict[str, dict[str, str]] | None = None, ) -> None: - assert kernel_name in KERNEL_METRIC_MAPPINGS + # Use provided mappings or default to global KERNEL_METRIC_MAPPINGS + active_metrics = ( + kernel_metric_mappings + if kernel_metric_mappings is not None + else KERNEL_METRIC_MAPPINGS + ) + + if kernel_name not in active_metrics: + logger.warning( + f"No metric mappings found for kernel '{kernel_name}', skipping result processing" + ) + return + names = lines[0].strip().split(";") shape = [] @@ -1148,15 +1357,13 @@ def process_result( if idx == 0: shape.append(item) else: - if name not in KERNEL_METRIC_MAPPINGS[kernel_name]: + if name not in active_metrics[kernel_name]: logger.info(f"ignoring {name}") else: if item == "": # if benchmark failed, tritonbench emits empty string item = 0.0 - metrics[KERNEL_METRIC_MAPPINGS[kernel_name][name]].append( - float(item) - ) + metrics[active_metrics[kernel_name][name]].append(float(item)) results.append( RunResult( @@ -1246,6 +1453,12 @@ def main() -> None: action="store_true", help="List implementations to be run on Benchmark CI for specified kernel(s).", ) + parser.add_argument( + "--kernel-config", + type=str, + help="Path to YAML or JSON configuration file for additional kernel mappings. " + "Custom mappings extend and can override base mappings.", + ) # Parse known args to get the kernel name, pass rest to tritonbench args, tritonbench_args = parser.parse_known_args() @@ -1265,23 +1478,69 @@ def main() -> None: ) sys.exit(1) + # Load custom kernel configurations if provided + active_kernel_mappings = KERNEL_MAPPINGS + active_metric_mappings = KERNEL_METRIC_MAPPINGS + + if args.kernel_config: + try: + print( + f"Loading custom kernel configuration from: {args.kernel_config}", + file=sys.stderr, + ) + custom_mappings, custom_metrics = load_kernel_config(args.kernel_config) + + # Report what was loaded + if custom_mappings: + print( + f"Loaded {len(custom_mappings)} kernel mapping(s): {', '.join(custom_mappings.keys())}", + file=sys.stderr, + ) + if custom_metrics: + print( + f"Loaded metric mappings for {len(custom_metrics)} kernel(s): {', '.join(custom_metrics.keys())}", + file=sys.stderr, + ) + + # Merge with base configurations + active_kernel_mappings, active_metric_mappings = merge_kernel_configs( + KERNEL_MAPPINGS, KERNEL_METRIC_MAPPINGS, custom_mappings, custom_metrics + ) + + # Report if any kernels were overridden + overridden = set(custom_mappings.keys()) & set(KERNEL_MAPPINGS.keys()) + if overridden: + print( + f"Overriding base mappings for: {', '.join(overridden)}", + file=sys.stderr, + ) + + except ( + FileNotFoundError, + ValueError, + yaml.YAMLError, + json.JSONDecodeError, + ) as e: + print(f"Error loading kernel configuration: {e}", file=sys.stderr) + sys.exit(1) + # Handle --list-impls-for-benchmark-ci flag if args.list_impls_for_benchmark_ci: - assert args.kernel, ( - "--op or --kernel must be specified with --list-impls-for-benchmark-ci" - ) + assert ( + args.kernel + ), "--op or --kernel must be specified with --list-impls-for-benchmark-ci" # List implementations for specified kernels to be run on Benchmark CI kernel_names = [k.strip() for k in args.kernel.split(",")] for kernel in kernel_names: - assert kernel in KERNEL_METRIC_MAPPINGS, ( - f"Unable to find kernel in KERNEL_METRIC_MAPPINGS: {kernel}" - ) + assert ( + kernel in active_metric_mappings + ), f"Unable to find kernel in metric mappings: {kernel}" # Extract implementation names that have speedup 
metrics
         implementations = []
         baseline_impl = ""
-        for metric_key, metric_value in KERNEL_METRIC_MAPPINGS[kernel].items():
+        for metric_key, metric_value in active_metric_mappings[kernel].items():
             # Find the baseline implementation
             if metric_value == "baseline":
                 baseline_impl = metric_key
@@ -1327,21 +1586,28 @@ def main() -> None:
         kernel_names = [k.strip() for k in args.kernel.split(",")]
 
         # Validate all kernel names first
-        invalid_kernels = [k for k in kernel_names if k not in KERNEL_MAPPINGS]
+        invalid_kernels = [k for k in kernel_names if k not in active_kernel_mappings]
         if invalid_kernels:
             print(
                 f"Error: Unknown kernel(s): {', '.join(invalid_kernels)}",
                 file=sys.stderr,
             )
             print(
-                f"Available kernels: {', '.join(KERNEL_MAPPINGS.keys())}",
+                f"Available kernels: {', '.join(active_kernel_mappings.keys())}",
                 file=sys.stderr,
             )
             sys.exit(1)
 
         # Run specified kernels
         if len(kernel_names) == 1:
-            run_kernel(kernel_names[0], tritonbench_args, input_shard_info, results)
+            run_kernel(
+                kernel_names[0],
+                tritonbench_args,
+                input_shard_info,
+                results,
+                active_kernel_mappings,
+                active_metric_mappings,
+            )
         else:
             print(
                 f"Running {len(kernel_names)} kernels: {', '.join(kernel_names)}...\n",
@@ -1352,16 +1618,30 @@ def main() -> None:
             print(f"Kernel: {kernel_name}", file=sys.stderr)
             print(f"{'=' * 60}\n", file=sys.stderr)
             run_kernel(
-                kernel_name, tritonbench_args.copy(), input_shard_info, results
+                kernel_name,
+                tritonbench_args.copy(),
+                input_shard_info,
+                results,
+                active_kernel_mappings,
+                active_metric_mappings,
             )
     else:
         # Run all kernels
-        print(f"Running all {len(KERNEL_MAPPINGS)} kernels...\n", file=sys.stderr)
-        for kernel_name in KERNEL_MAPPINGS:
+        print(
+            f"Running all {len(active_kernel_mappings)} kernels...\n", file=sys.stderr
+        )
+        for kernel_name in active_kernel_mappings:
             print(f"\n{'=' * 60}", file=sys.stderr)
             print(f"Kernel: {kernel_name}", file=sys.stderr)
             print(f"{'=' * 60}\n", file=sys.stderr)
-            run_kernel(kernel_name, tritonbench_args.copy(), input_shard_info, results)
+            run_kernel(
+                kernel_name,
+                tritonbench_args.copy(),
+                input_shard_info,
+                results,
+                active_kernel_mappings,
+                active_metric_mappings,
+            )
 
     if args.output:
         write_results_to_json(

From c66baebe9964c3e28f4af3a8fe2ce02e48026229 Mon Sep 17 00:00:00 2001
From: Alessandro Sangiorgi
Date: Fri, 21 Nov 2025 12:50:23 -0600
Subject: [PATCH 2/3] Import yaml lazily and remove unused warnings import

---
 benchmarks/run.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/benchmarks/run.py b/benchmarks/run.py
index a0a423e74..02a325872 100644
--- a/benchmarks/run.py
+++ b/benchmarks/run.py
@@ -38,12 +38,10 @@
 from typing import Any
 from typing import Callable
 from typing import cast
-import warnings
 
 import torch
 from torch.utils._pytree import tree_leaves
 from torch.utils._pytree import tree_map
-import yaml
 
 from helion._testing import get_nvidia_gpu_model
 from helion._utils import counters
@@ -678,6 +676,12 @@ def load_kernel_config(
     # Load configuration
     with open(config_file, "r") as f:
         if config_file.suffix in [".yaml", ".yml"]:
+            try:
+                import yaml
+            except ImportError as e:
+                raise RuntimeError(
+                    "YAML configuration requested but PyYAML is not installed."
+                ) from e
             config = yaml.safe_load(f)
         elif config_file.suffix == ".json":
             config = json.load(f)
@@ -1518,7 +1522,7 @@ def main() -> None:
         except (
             FileNotFoundError,
             ValueError,
-            yaml.YAMLError,
+            RuntimeError,
             json.JSONDecodeError,
         ) as e:
             print(f"Error loading kernel configuration: {e}", file=sys.stderr)
             sys.exit(1)

From aca75cf47563f781f270cb9db245b434f0148877 Mon Sep 17 00:00:00 2001
From: Alessandro Sangiorgi
Date: Fri, 21 Nov 2025 12:56:13 -0600
Subject: [PATCH 3/3] Remove redundant comments and simplify control flow

---
 benchmarks/run.py | 66 ++++++++++++++++++-----------------------------
 1 file changed, 25 insertions(+), 41 deletions(-)

diff --git a/benchmarks/run.py b/benchmarks/run.py
index 02a325872..8575703d5 100644
--- a/benchmarks/run.py
+++ b/benchmarks/run.py
@@ -674,7 +674,7 @@ def load_kernel_config(
         raise FileNotFoundError(f"Configuration file not found: {config_path}")
 
     # Load configuration
-    with open(config_file, "r") as f:
+    with open(config_file) as f:
         if config_file.suffix in [".yaml", ".yml"]:
             try:
                 import yaml
@@ -696,19 +696,16 @@
     kernel_mappings = {}
     kernel_metric_mappings = {}
 
-    # Process kernel mappings - now reusing process_single_kernel_mapping
     if "kernel_mappings" in config:
         raw_mappings = config["kernel_mappings"]
         if not isinstance(raw_mappings, dict):
             raise ValueError("kernel_mappings must be a dictionary")
 
         for kernel_name, mapping in raw_mappings.items():
-            # Reuse the single source of truth for processing
             kernel_mappings[kernel_name] = process_single_kernel_mapping(
                 kernel_name, mapping
             )
 
-    # Process kernel metric mappings
     if "kernel_metric_mappings" in config:
         raw_metrics = config["kernel_metric_mappings"]
         if not isinstance(raw_metrics, dict):
             raise ValueError("kernel_metric_mappings must be a dictionary")
@@ -727,15 +724,12 @@
         if gpu_model in config["hardware_overrides"]:
             hw_config = config["hardware_overrides"][gpu_model]
 
-            # Merge hardware-specific kernel mappings - now reusing the helper
             if "kernel_mappings" in hw_config:
                 for kernel_name, mapping in hw_config["kernel_mappings"].items():
-                    # Reuse the same helper for consistency
                     kernel_mappings[kernel_name] = process_single_kernel_mapping(
                         kernel_name, mapping
                     )
 
-            # Merge hardware-specific metric mappings
             if "kernel_metric_mappings" in hw_config:
                 for kernel_name, metrics in hw_config["kernel_metric_mappings"].items():
                     if kernel_name not in kernel_metric_mappings:
@@ -759,7 +753,6 @@
 
     tritonbench_module = mapping["tritonbench_module"]
 
-    # Handle variants
     if "variants" in mapping:
         variants = []
         for variant in mapping["variants"]:
@@ -771,28 +764,24 @@
 
         if "args" in mapping:
             return (tritonbench_module, variants, mapping["args"])
-        else:
-            return (tritonbench_module, variants)
-    else:
-        # Single implementation format
-        if "helion_module" not in mapping or "helion_func" not in mapping:
-            raise ValueError(
-                f"Kernel '{kernel_name}' must have 'helion_module' and 'helion_func' or 'variants'"
-            )
+        return (tritonbench_module, variants)
+    if "helion_module" not in mapping or "helion_func" not in mapping:
+        raise ValueError(
+            f"Kernel '{kernel_name}' must have 'helion_module' and 'helion_func' or 'variants'"
+        )
 
-        if "args" in mapping:
-            return (
-                tritonbench_module,
-                mapping["helion_module"],
-                mapping["helion_func"],
-                mapping["args"],
-            )
-        else:
-            return (
-                tritonbench_module,
-                mapping["helion_module"],
-                mapping["helion_func"],
-            )
+    if "args" in mapping:
+        return (
+            tritonbench_module,
+            mapping["helion_module"],
+            mapping["helion_func"],
+            mapping["args"],
+        )
+    return (
+
tritonbench_module, + mapping["helion_module"], + mapping["helion_func"], + ) def merge_kernel_configs( @@ -818,17 +807,14 @@ def merge_kernel_configs( Returns: Tuple of merged (kernel_mappings, kernel_metric_mappings) """ - # Start with base, then overlay custom (custom takes precedence) merged_mappings = {**base_mappings, **custom_mappings} # For metrics, merge at the kernel level merged_metrics = dict(base_metrics) for kernel, metrics in custom_metrics.items(): if kernel in merged_metrics: - # Merge metrics for this kernel (custom overrides base for same keys) merged_metrics[kernel] = {**merged_metrics[kernel], **metrics} else: - # New kernel, add all its metrics merged_metrics[kernel] = metrics return merged_mappings, merged_metrics @@ -1200,9 +1186,7 @@ def helion_method( if isinstance(kfunc, Kernel): # Helion kernel - we call it in a lambda to delay execution until measurement - measured_func_callable = lambda: kfunc( - *args, **kwargs - ) # noqa: E731 + measured_func_callable = lambda: kfunc(*args, **kwargs) # noqa: E731 else: # tritonbench integration wrapper - pass tritonbench operator instance as first argument # The wrapper must return a callable that does the actual computation, for delayed execution @@ -1530,15 +1514,15 @@ def main() -> None: # Handle --list-impls-for-benchmark-ci flag if args.list_impls_for_benchmark_ci: - assert ( - args.kernel - ), "--op or --kernel must be specified with --list-impls-for-benchmark-ci" + assert args.kernel, ( + "--op or --kernel must be specified with --list-impls-for-benchmark-ci" + ) # List implementations for specified kernels to be run on Benchmark CI kernel_names = [k.strip() for k in args.kernel.split(",")] for kernel in kernel_names: - assert ( - kernel in active_metric_mappings - ), f"Unable to find kernel in metric mappings: {kernel}" + assert kernel in active_metric_mappings, ( + f"Unable to find kernel in metric mappings: {kernel}" + ) # Extract implementation names that have speedup metrics implementations = []
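
For reference, a minimal sketch of a configuration file that the new
--kernel-config flag would accept, based on the schema enforced by
load_kernel_config and process_single_kernel_mapping. Every kernel, module,
function, and metric name below is a hypothetical placeholder: the real column
names come from the tritonbench operator being wrapped, and an optional
per-kernel "args" entry is also accepted in either form.

    # custom_kernels.yaml -- hypothetical example
    kernel_mappings:
      my_softmax:                          # single-implementation form
        tritonbench_module: softmax        # tritonbench operator to run
        helion_module: examples.softmax    # module that defines the Helion kernel
        helion_func: softmax               # callable inside that module
      my_attention:                        # multi-variant form
        tritonbench_module: attention
        variants:
          - helion_module: examples.attention
            helion_func: attention
          - helion_module: examples.attention_persistent
            helion_func: attention_persistent

    kernel_metric_mappings:
      my_softmax:
        speedup: helion_speedup            # tritonbench column -> reported metric
        torch_compile: baseline            # the value "baseline" marks the reference impl

    hardware_overrides:                    # optional; keyed by get_nvidia_gpu_model()
      "NVIDIA H100":                       # hypothetical model string
        kernel_mappings:
          my_softmax:
            tritonbench_module: softmax
            helion_module: examples.softmax_h100
            helion_func: softmax

A file like this would then be passed on the command line, e.g.:

    python benchmarks/run.py --kernel my_softmax --kernel-config custom_kernels.yaml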