diff --git a/helion/autotuner/base_search.py b/helion/autotuner/base_search.py
index 6342b43af..121703e76 100644
--- a/helion/autotuner/base_search.py
+++ b/helion/autotuner/base_search.py
@@ -105,6 +105,8 @@ class BaseSearch(BaseAutotuner):
     _baseline_post_args: Sequence[object] | None
     _jobs: int
     _precompile_result_counter: count[int]
+    _effective_atol: float
+    _effective_rtol: float
 
     def __init__(self, kernel: BoundKernel, args: Sequence[object]) -> None:
         """
@@ -134,6 +136,9 @@ def __init__(self, kernel: BoundKernel, args: Sequence[object]) -> None:
             self._kernel_mutates_args,
             self._baseline_post_args,
         ) = self._compute_baseline()
+        self._effective_atol, self._effective_rtol = (
+            self._compute_effective_tolerances()
+        )
         self._jobs = self._decide_num_jobs()
 
     def _next_precompile_result_path(self) -> str:
@@ -222,6 +227,66 @@ def _compute_baseline(self) -> tuple[object, bool, Sequence[object] | None]:
             baseline_post_args = self._clone_args(new_args)
         return baseline_output, mutated, baseline_post_args
 
+    def _compute_effective_tolerances(self) -> tuple[float, float]:
+        """
+        Compute effective tolerances based on the dtypes in the baseline output.
+
+        For fp8 dtypes, tolerance-based comparison is not meaningful, so this
+        method detects fp8-only baseline outputs and falls back to exact
+        bitwise comparison (atol=0.0, rtol=0.0) unless tolerances are set.
+
+        Returns:
+            A tuple of (atol, rtol) to use for accuracy validation.
+        """
+        # Default tolerance when not user-specified
+        DEFAULT_TOL = 1e-2
+
+        # Get user-specified or default tolerances
+        atol = self.settings.autotune_baseline_atol
+        rtol = self.settings.autotune_baseline_rtol
+
+        # Collect all dtypes from baseline output and mutated args
+        dtypes = set()
+
+        def collect_dtypes(obj: object) -> object:
+            if isinstance(obj, torch.Tensor):
+                dtypes.add(obj.dtype)
+            return obj
+
+        tree_map_only(torch.Tensor, collect_dtypes, self._baseline_output)
+        if self._kernel_mutates_args and self._baseline_post_args is not None:
+            tree_map_only(torch.Tensor, collect_dtypes, self._baseline_post_args)
+
+        # Check for fp8 dtypes - these require exact bitwise comparison
+        fp8_dtypes = {
+            torch.float8_e4m3fn,
+            torch.float8_e5m2,
+            torch.float8_e4m3fnuz,
+            torch.float8_e5m2fnuz,
+            torch.float8_e8m0fnu,
+        }
+
+        # Only apply strict tolerances if ALL dtypes are fp8.
+        # Mixed dtypes (fp8 + fp32) would be too strict with atol=0.0, rtol=0.0.
+        all_dtypes_are_fp8 = dtypes and all(dtype in fp8_dtypes for dtype in dtypes)
+
+        if all_dtypes_are_fp8:
+            # All dtypes are fp8 - use bitwise comparison
+            # unless the user explicitly set either tolerance value (i.e., not None)
+            user_set_either = atol is not None or rtol is not None
+            if not user_set_either:
+                self.log(
+                    f"Detected fp8 dtype(s) in output: {dtypes}. "
+                    "Using bitwise comparison (atol=0.0, rtol=0.0) for autotuning accuracy check."
+                )
+                return 0.0, 0.0
+
+        # Use user-specified values or defaults
+        return (
+            atol if atol is not None else DEFAULT_TOL,
+            rtol if rtol is not None else DEFAULT_TOL,
+        )
+
     def _decide_num_jobs(self) -> int:
         if not self.settings.autotune_precompile:
             return 1
@@ -278,15 +343,15 @@ def _validate_against_baseline(
             torch.testing.assert_close(
                 output,
                 self._baseline_output,
-                atol=self.settings.autotune_baseline_atol,
-                rtol=self.settings.autotune_baseline_rtol,
+                atol=self._effective_atol,
+                rtol=self._effective_rtol,
             )
             if self._kernel_mutates_args:
                 torch.testing.assert_close(
                     args,
                     self._baseline_post_args,
-                    atol=self.settings.autotune_baseline_atol,
-                    rtol=self.settings.autotune_baseline_rtol,
+                    atol=self._effective_atol,
+                    rtol=self._effective_rtol,
                 )
         except AssertionError as e:
             self.counters["accuracy_mismatch"] += 1
diff --git a/helion/runtime/settings.py b/helion/runtime/settings.py
index 60fad2ddb..ce3db4989 100644
--- a/helion/runtime/settings.py
+++ b/helion/runtime/settings.py
@@ -399,8 +399,8 @@ class _Settings:
     )
     autotuner_fn: AutotunerFunction = default_autotuner_fn
    autotune_baseline_fn: Callable[..., object] | None = None
-    autotune_baseline_atol: float = 1e-2
-    autotune_baseline_rtol: float = 1e-2
+    autotune_baseline_atol: float | None = None
+    autotune_baseline_rtol: float | None = None
 
 
 class Settings(_Settings):
@@ -478,11 +478,13 @@ class Settings(_Settings):
         ),
         "autotune_baseline_atol": (
             "Absolute tolerance for baseline output comparison during autotuning accuracy checks. "
-            "Defaults to 1e-2. Pass as @helion.kernel(..., autotune_baseline_atol=1e-3)."
+            "Defaults to 1e-2, or 0.0 for fp8 dtypes (automatic bitwise comparison). "
+            "Pass as @helion.kernel(..., autotune_baseline_atol=1e-3)."
         ),
         "autotune_baseline_rtol": (
             "Relative tolerance for baseline output comparison during autotuning accuracy checks. "
-            "Defaults to 1e-2. Pass as @helion.kernel(..., autotune_baseline_rtol=1e-3)."
+            "Defaults to 1e-2, or 0.0 for fp8 dtypes (automatic bitwise comparison). "
+            "Pass as @helion.kernel(..., autotune_baseline_rtol=1e-3)."
         ),
         "autotune_cache": (
             "The name of the autotuner cache class to use. "
@@ -495,7 +497,6 @@ def __init__(self, **settings: object) -> None:
         """
         Initialize the Settings object with the provided dictionary of settings.
         """
-        # pyrefly: ignore [bad-argument-type]
         super().__init__(**settings)
diff --git a/test/test_autotuner.py b/test/test_autotuner.py
index 03d554233..0f503f225 100644
--- a/test/test_autotuner.py
+++ b/test/test_autotuner.py
@@ -960,6 +960,77 @@ def add(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
         self.assertIn(winner, (cfg1, cfg2))
         self.assertEqual(search.counters["accuracy_mismatch"], 0)
 
+    @skipIfCpu("fails on Triton CPU backend")
+    @skipIfRocm("fp8 dtypes not supported on ROCm")
+    @unittest.skipIf(
+        not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9,
+        "FP8 requires GPU with compute capability >= 9.0 (e.g., H100)",
+    )
+    def test_autotune_fp8_automatic_tolerance(self) -> None:
+        """Test that fp8 dtypes automatically get 0.0 tolerances."""
+        cfg1 = helion.Config(block_sizes=[16], num_warps=4)
+        cfg2 = helion.Config(block_sizes=[32], num_warps=8)
+
+        # Test with float8_e4m3fn as a representative fp8 dtype
+        @helion.kernel(configs=[cfg1, cfg2])
+        def cast_to_fp8(x: torch.Tensor) -> torch.Tensor:
+            out = torch.empty(x.size(), dtype=torch.float8_e4m3fn, device=x.device)
+            for t in hl.tile(x.size()):
+                out[t] = x[t].to(torch.float8_e4m3fn)
+            return out
+
+        x = torch.randn([64], device=DEVICE)
+        bound = cast_to_fp8.bind((x,))
+        search = FiniteSearch(bound, (x,), configs=[cfg1, cfg2])
+
+        # Verify that effective tolerances were set to 0.0 automatically
+        self.assertEqual(
+            search._effective_atol,
+            0.0,
+            f"Expected automatic atol=0.0 for fp8, got {search._effective_atol}",
+        )
+        self.assertEqual(
+            search._effective_rtol,
+            0.0,
+            f"Expected automatic rtol=0.0 for fp8, got {search._effective_rtol}",
+        )
+
+        # Should successfully autotune without error
+        winner = search.autotune()
+        self.assertIn(winner, (cfg1, cfg2))
+        self.assertEqual(search.counters["accuracy_mismatch"], 0)
+
+    @skipIfCpu("fails on Triton CPU backend")
+    @skipIfRocm("fp8 dtypes not supported on ROCm")
+    @unittest.skipIf(
+        not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9,
+        "FP8 requires GPU with compute capability >= 9.0 (e.g., H100)",
+    )
+    def test_autotune_fp8_explicit_tolerance_override(self) -> None:
+        """Test that explicit tolerances override automatic fp8 detection."""
+        cfg1 = helion.Config(block_sizes=[16], num_warps=4)
+        cfg2 = helion.Config(block_sizes=[32], num_warps=8)
+
+        # User explicitly sets non-zero tolerances despite fp8 output
+        @helion.kernel(
+            configs=[cfg1, cfg2],
+            autotune_baseline_atol=1e-5,
+            autotune_baseline_rtol=1e-5,
+        )
+        def cast_to_fp8(x: torch.Tensor) -> torch.Tensor:
+            out = torch.empty(x.size(), dtype=torch.float8_e4m3fn, device=x.device)
+            for t in hl.tile(x.size()):
+                out[t] = x[t].to(torch.float8_e4m3fn)
+            return out
+
+        x = torch.randn([64], device=DEVICE)
+        bound = cast_to_fp8.bind((x,))
+        search = FiniteSearch(bound, (x,), configs=[cfg1, cfg2])
+
+        # Should respect user's explicit tolerances, not override to 0.0
+        self.assertEqual(search._effective_atol, 1e-5)
+        self.assertEqual(search._effective_rtol, 1e-5)
+
     @skipIfCpu("fails on Triton CPU backend")
     def test_max_generations(self):
         """Autotuner max generation respects explicit kwargs then setting override."""