
Conversation

@gmagogsfm
Contributor

This change automatically sets atol=0.0 and rtol=0.0 for autotuning accuracy checks when fp8 dtypes are detected in kernel outputs, while respecting user-specified tolerance values.

Behavior:

  • If user sets neither tolerance: automatic 0.0 for fp8 dtypes
  • If user sets either tolerance: both values are respected (no override)
  • Supports all fp8 variants: float8_e4m3fn, float8_e5m2, float8_e4m3fnuz, float8_e5m2fnuz, float8_e8m0fnu

Fixes an issue where kernels returning fp8 types would fail autotuning with "Rtol=0.01 and atol=0.01 are not supported for bitwise comparison of low dimensional floats": PyTorch compares fp8 tensors bitwise, and bitwise comparison does not accept nonzero tolerances. A detection sketch follows below.
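
A minimal sketch of the fp8 detection this implies (FP8_DTYPES and contains_fp8 are illustrative names, not necessarily those in the PR; the getattr guard for PyTorch builds missing some variants is an assumption):

```python
import torch

# fp8 dtypes that require exact (bitwise) comparison; the getattr guard
# is an assumption for PyTorch builds that lack some variants.
FP8_DTYPES = {
    dtype
    for name in (
        "float8_e4m3fn",
        "float8_e5m2",
        "float8_e4m3fnuz",
        "float8_e5m2fnuz",
        "float8_e8m0fnu",
    )
    if (dtype := getattr(torch, name, None)) is not None
}


def contains_fp8(output) -> bool:
    """Return True if any tensor in the (possibly nested) output is fp8."""
    if isinstance(output, torch.Tensor):
        return output.dtype in FP8_DTYPES
    if isinstance(output, (list, tuple)):
        return any(contains_fp8(item) for item in output)
    return False
```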

meta-cla bot added the CLA Signed label on Nov 21, 2025
@gmagogsfm force-pushed the auto-fp8-tolerance branch 2 times, most recently from b5bbd30 to fa64d26, on November 21, 2025 00:17
@gmagogsfm
Contributor Author

@jansel @yf225 When testing with a vLLM kernel, I discovered that quantized outputs (fp8 types) trigger an assert in PyTorch when the tolerance is set to anything non-zero. This PR automatically sets the tolerance correctly for these kernels.

This change automatically sets atol=0.0 and rtol=0.0 for autotuning
accuracy checks when fp8 dtypes are detected in kernel outputs, while
respecting user-specified tolerance values.

Changes:
- Added _user_set_atol and _user_set_rtol flags to Settings to track
  whether users explicitly set tolerance values
- Added _compute_effective_tolerances() method to BaseSearch that
  detects fp8 dtypes and automatically sets tolerances to 0.0 when
  user hasn't explicitly specified them
- Updated _validate_against_baseline() to use effective tolerances
  instead of settings values directly

Behavior:
- If user sets neither tolerance: automatic 0.0 for fp8 dtypes
- If user sets either tolerance: both values are respected (no override)
- Supports all fp8 variants: float8_e4m3fn, float8_e5m2, float8_e4m3fnuz,
  float8_e5m2fnuz, float8_e8m0fnu

Tests:
- test_autotune_fp8_automatic_tolerance: verifies automatic detection
- test_autotune_fp8_explicit_tolerance_override: verifies user values respected
- test_autotune_fp8_explicit_default_tolerance: verifies explicit 1e-2 respected
- test_autotune_fp8_partial_tolerance_override: verifies partial specification respected

Fixes issue where fp8 kernels would fail autotuning with:
"Rtol=0.01 and atol=0.01 are not supported for bitwise comparison of low dimensional floats"
autotune_baseline_rtol: float = dataclasses.field(default=1e-2)
# Internal fields to track if user explicitly set tolerance values
_user_set_atol: bool = dataclasses.field(default=False, init=False, repr=False)
_user_set_rtol: bool = dataclasses.field(default=False, init=False, repr=False)
Contributor


Not sure if it's possible, but I would really love to avoid needing these private fields on the settings class, to keep things simpler

Contributor Author


Good point, I changed the implementation a bit. The autotune_baseline_atol field is now an Optional[float]; a None value represents an unset tolerance.
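
A sketch of the revised shape, assuming both tolerance fields become Optional[float] and the previous 1e-2 default carries over as the non-fp8 fallback (Settings and resolve_tolerances are illustrative stand-ins, again reusing contains_fp8 from above):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Settings:
    # None means "not set by the user"; resolved at check time.
    autotune_baseline_atol: Optional[float] = None
    autotune_baseline_rtol: Optional[float] = None


def resolve_tolerances(settings: Settings, output) -> tuple[float, float]:
    atol, rtol = settings.autotune_baseline_atol, settings.autotune_baseline_rtol
    if atol is None and rtol is None and contains_fp8(output):
        return 0.0, 0.0  # bitwise comparison for fp8 outputs
    # Respect any user-set value; fall back to the old 1e-2 default otherwise.
    return (atol if atol is not None else 1e-2,
            rtol if rtol is not None else 1e-2)
```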

Contributor Author

@gmagogsfm left a comment


PTAL


@gmagogsfm gmagogsfm requested a review from yf225 November 21, 2025 17:02
