73 changes: 69 additions & 4 deletions helion/autotuner/base_search.py
@@ -105,6 +105,8 @@ class BaseSearch(BaseAutotuner):
_baseline_post_args: Sequence[object] | None
_jobs: int
_precompile_result_counter: count[int]
_effective_atol: float
_effective_rtol: float

def __init__(self, kernel: BoundKernel, args: Sequence[object]) -> None:
"""
@@ -134,6 +136,9 @@ def __init__(self, kernel: BoundKernel, args: Sequence[object]) -> None:
self._kernel_mutates_args,
self._baseline_post_args,
) = self._compute_baseline()
self._effective_atol, self._effective_rtol = (
self._compute_effective_tolerances()
)
self._jobs = self._decide_num_jobs()

def _next_precompile_result_path(self) -> str:
@@ -222,6 +227,66 @@ def _compute_baseline(self) -> tuple[object, bool, Sequence[object] | None]:
baseline_post_args = self._clone_args(new_args)
return baseline_output, mutated, baseline_post_args

def _compute_effective_tolerances(self) -> tuple[float, float]:
"""
Compute effective tolerances based on the dtypes in the baseline output.

For low-precision dtypes (fp8), we use zero tolerances so the comparison
is exact (bitwise). This method automatically detects such dtypes and
adjusts the tolerances accordingly.

Returns:
A tuple of (atol, rtol) to use for accuracy validation.
"""
# Default tolerance when not user-specified
DEFAULT_TOL = 1e-2

# Get user-specified or default tolerances
atol = self.settings.autotune_baseline_atol
rtol = self.settings.autotune_baseline_rtol

# Collect all dtypes from baseline output and mutated args
dtypes = set()

def collect_dtypes(obj: object) -> object:
if isinstance(obj, torch.Tensor):
dtypes.add(obj.dtype)
return obj

tree_map_only(torch.Tensor, collect_dtypes, self._baseline_output)
if self._kernel_mutates_args and self._baseline_post_args is not None:
tree_map_only(torch.Tensor, collect_dtypes, self._baseline_post_args)

# Check for fp8 dtypes - these require exact bitwise comparison
fp8_dtypes = {
torch.float8_e4m3fn,
torch.float8_e5m2,
torch.float8_e4m3fnuz,
torch.float8_e5m2fnuz,
torch.float8_e8m0fnu,
}

# Only apply strict tolerances when ALL dtypes are fp8.
# For mixed dtypes (e.g., fp8 + fp32), atol=0.0/rtol=0.0 would be too strict.
all_dtypes_are_fp8 = dtypes and all(dtype in fp8_dtypes for dtype in dtypes)

if all_dtypes_are_fp8:
# All dtypes are fp8 - use bitwise comparison
# unless the user explicitly set either tolerance value (i.e., not None)
user_set_either = atol is not None or rtol is not None
if not user_set_either:
self.log(
f"Detected fp8 dtype(s) in output: {dtypes}. "
"Using bitwise comparison (atol=0.0, rtol=0.0) for autotuning accuracy check."
)
return 0.0, 0.0

# Use user-specified values or defaults
return (
atol if atol is not None else DEFAULT_TOL,
rtol if rtol is not None else DEFAULT_TOL,
)

def _decide_num_jobs(self) -> int:
if not self.settings.autotune_precompile:
return 1
@@ -278,15 +343,15 @@ def _validate_against_baseline(
torch.testing.assert_close(
output,
self._baseline_output,
atol=self.settings.autotune_baseline_atol,
rtol=self.settings.autotune_baseline_rtol,
atol=self._effective_atol,
rtol=self._effective_rtol,
)
if self._kernel_mutates_args:
torch.testing.assert_close(
args,
self._baseline_post_args,
atol=self.settings.autotune_baseline_atol,
rtol=self.settings.autotune_baseline_rtol,
atol=self._effective_atol,
rtol=self._effective_rtol,
)
except AssertionError as e:
self.counters["accuracy_mismatch"] += 1
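For reference, a minimal standalone sketch of the tolerance-selection rule added above, assuming a PyTorch build that defines the fp8 dtypes and only public pytree utilities; the helper name select_tolerances and the example tensors are illustrative and not part of this change:

import torch
from torch.utils._pytree import tree_map_only

# fp8 dtypes that trigger bitwise comparison (e8m0fnu omitted for older PyTorch builds).
_FP8_DTYPES = {
    torch.float8_e4m3fn,
    torch.float8_e5m2,
    torch.float8_e4m3fnuz,
    torch.float8_e5m2fnuz,
}

def select_tolerances(
    output: object, atol: float | None = None, rtol: float | None = None
) -> tuple[float, float]:
    DEFAULT_TOL = 1e-2
    dtypes: set[torch.dtype] = set()
    # Walk any nested output structure and record every tensor dtype.
    tree_map_only(torch.Tensor, lambda t: dtypes.add(t.dtype) or t, output)
    # Bitwise comparison only when *all* dtypes are fp8 and the user set neither tolerance.
    if dtypes and all(d in _FP8_DTYPES for d in dtypes) and atol is None and rtol is None:
        return 0.0, 0.0
    return (
        atol if atol is not None else DEFAULT_TOL,
        rtol if rtol is not None else DEFAULT_TOL,
    )

# An all-fp8 output selects (0.0, 0.0); mixed fp8 + fp32 keeps the 1e-2 defaults.
fp8_out = torch.randn(8).to(torch.float8_e4m3fn)
print(select_tolerances(fp8_out))                    # (0.0, 0.0)
print(select_tolerances((fp8_out, torch.randn(8))))  # (0.01, 0.01)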
11 changes: 6 additions & 5 deletions helion/runtime/settings.py
@@ -399,8 +399,8 @@ class _Settings:
)
autotuner_fn: AutotunerFunction = default_autotuner_fn
autotune_baseline_fn: Callable[..., object] | None = None
autotune_baseline_atol: float = 1e-2
autotune_baseline_rtol: float = 1e-2
autotune_baseline_atol: float | None = None
autotune_baseline_rtol: float | None = None


class Settings(_Settings):
@@ -478,11 +478,13 @@ class Settings(_Settings):
),
"autotune_baseline_atol": (
"Absolute tolerance for baseline output comparison during autotuning accuracy checks. "
"Defaults to 1e-2. Pass as @helion.kernel(..., autotune_baseline_atol=1e-3)."
"Defaults to 1e-2, or 0.0 for fp8 dtypes (automatic bitwise comparison). "
"Pass as @helion.kernel(..., autotune_baseline_atol=1e-3)."
),
"autotune_baseline_rtol": (
"Relative tolerance for baseline output comparison during autotuning accuracy checks. "
"Defaults to 1e-2. Pass as @helion.kernel(..., autotune_baseline_rtol=1e-3)."
"Defaults to 1e-2, or 0.0 for fp8 dtypes (automatic bitwise comparison). "
"Pass as @helion.kernel(..., autotune_baseline_rtol=1e-3)."
),
"autotune_cache": (
"The name of the autotuner cache class to use. "
@@ -495,7 +497,6 @@ def __init__(self, **settings: object) -> None:
"""
Initialize the Settings object with the provided dictionary of settings.
"""

# pyrefly: ignore [bad-argument-type]
super().__init__(**settings)

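As a usage sketch of the new None defaults (the kernels below are illustrative only and not part of this change; the decorator arguments follow the settings documented above):

import torch
import helion
import helion.language as hl

# Leave both tolerances unset: fp8-only outputs are compared bitwise,
# everything else falls back to the 1e-2 defaults.
@helion.kernel()
def double(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    for t in hl.tile(x.size()):
        out[t] = x[t] * 2
    return out

# Explicit values always win, even when the output is fp8.
@helion.kernel(autotune_baseline_atol=1e-3, autotune_baseline_rtol=1e-3)
def double_loose(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    for t in hl.tile(x.size()):
        out[t] = x[t] * 2
    return out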
71 changes: 71 additions & 0 deletions test/test_autotuner.py
@@ -960,6 +960,77 @@ def add(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
self.assertIn(winner, (cfg1, cfg2))
self.assertEqual(search.counters["accuracy_mismatch"], 0)

@skipIfCpu("fails on Triton CPU backend")
@skipIfRocm("fp8 dtypes not supported on ROCm")
@unittest.skipIf(
not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9,
"FP8 requires GPU with compute capability >= 9.0 (e.g., H100)",
)
def test_autotune_fp8_automatic_tolerance(self) -> None:
"""Test that fp8 dtypes automatically get 0.0 tolerances."""
cfg1 = helion.Config(block_sizes=[16], num_warps=4)
cfg2 = helion.Config(block_sizes=[32], num_warps=8)

# Test with float8_e4m3fn as a representative fp8 dtype
@helion.kernel(configs=[cfg1, cfg2])
def cast_to_fp8(x: torch.Tensor) -> torch.Tensor:
out = torch.empty(x.size(), dtype=torch.float8_e4m3fn, device=x.device)
for t in hl.tile(x.size()):
out[t] = x[t].to(torch.float8_e4m3fn)
return out

x = torch.randn([64], device=DEVICE)
bound = cast_to_fp8.bind((x,))
search = FiniteSearch(bound, (x,), configs=[cfg1, cfg2])

# Verify that effective tolerances were set to 0.0 automatically
self.assertEqual(
search._effective_atol,
0.0,
f"Expected automatic atol=0.0 for fp8, got {search._effective_atol}",
)
self.assertEqual(
search._effective_rtol,
0.0,
f"Expected automatic rtol=0.0 for fp8, got {search._effective_rtol}",
)

# Should successfully autotune without error
winner = search.autotune()
self.assertIn(winner, (cfg1, cfg2))
self.assertEqual(search.counters["accuracy_mismatch"], 0)

@skipIfCpu("fails on Triton CPU backend")
@skipIfRocm("fp8 dtypes not supported on ROCm")
@unittest.skipIf(
not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9,
"FP8 requires GPU with compute capability >= 9.0 (e.g., H100)",
)
def test_autotune_fp8_explicit_tolerance_override(self) -> None:
"""Test that explicit tolerances override automatic fp8 detection."""
cfg1 = helion.Config(block_sizes=[16], num_warps=4)
cfg2 = helion.Config(block_sizes=[32], num_warps=8)

# User explicitly sets non-zero tolerances despite fp8 output
@helion.kernel(
configs=[cfg1, cfg2],
autotune_baseline_atol=1e-5,
autotune_baseline_rtol=1e-5,
)
def cast_to_fp8(x: torch.Tensor) -> torch.Tensor:
out = torch.empty(x.size(), dtype=torch.float8_e4m3fn, device=x.device)
for t in hl.tile(x.size()):
out[t] = x[t].to(torch.float8_e4m3fn)
return out

x = torch.randn([64], device=DEVICE)
bound = cast_to_fp8.bind((x,))
search = FiniteSearch(bound, (x,), configs=[cfg1, cfg2])

# Should respect user's explicit tolerances, not override to 0.0
self.assertEqual(search._effective_atol, 1e-5)
self.assertEqual(search._effective_rtol, 1e-5)

@skipIfCpu("fails on Triton CPU backend")
def test_max_generations(self):
"""Autotuner max generation respects explicit kwargs then setting override."""