Automatically use zero tolerance for bitwise comparison for fp8 dtypes during autotuning #1158
Conversation
Force-pushed from b5bbd30 to fa64d26
This change automatically sets atol=0.0 and rtol=0.0 for autotuning accuracy checks when fp8 dtypes are detected in kernel outputs, while respecting user-specified tolerance values.

Changes:
- Added _user_set_atol and _user_set_rtol flags to Settings to track whether the user explicitly set tolerance values
- Added a _compute_effective_tolerances() method to BaseSearch that detects fp8 dtypes and automatically sets both tolerances to 0.0 when the user has not explicitly specified them
- Updated _validate_against_baseline() to use the effective tolerances instead of reading the settings values directly

Behavior:
- If the user sets neither tolerance: 0.0 is used automatically for fp8 dtypes
- If the user sets either tolerance: both values are respected (no override)
- Supports all fp8 variants: float8_e4m3fn, float8_e5m2, float8_e4m3fnuz, float8_e5m2fnuz, float8_e8m0fnu

Tests:
- test_autotune_fp8_automatic_tolerance: verifies automatic detection
- test_autotune_fp8_explicit_tolerance_override: verifies user values are respected
- test_autotune_fp8_explicit_default_tolerance: verifies an explicit 1e-2 is respected
- test_autotune_fp8_partial_tolerance_override: verifies partial specification is respected

Fixes an issue where fp8 kernels would fail autotuning with: "Rtol=0.01 and atol=0.01 are not supported for bitwise comparison of low dimensional floats"
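A rough sketch of the mechanism described above (the names `compute_effective_tolerances` and `_FP8_DTYPES` are illustrative, not the exact helpers in this PR):

```python
from __future__ import annotations

import torch

# Collect the fp8 dtypes present in this torch build; some variants only
# exist in newer releases, hence the hasattr guard.
_FP8_DTYPES = frozenset(
    getattr(torch, name)
    for name in (
        "float8_e4m3fn",
        "float8_e5m2",
        "float8_e4m3fnuz",
        "float8_e5m2fnuz",
        "float8_e8m0fnu",
    )
    if hasattr(torch, name)
)


def compute_effective_tolerances(
    outputs: object,
    atol: float | None,
    rtol: float | None,
    default: float = 1e-2,
) -> tuple[float, float]:
    """Return (atol, rtol) for the baseline check, forcing 0.0 for fp8 outputs."""
    if atol is not None or rtol is not None:
        # The user explicitly set at least one tolerance: respect both values.
        return (
            atol if atol is not None else default,
            rtol if rtol is not None else default,
        )
    tensors = outputs if isinstance(outputs, (list, tuple)) else [outputs]
    if any(isinstance(t, torch.Tensor) and t.dtype in _FP8_DTYPES for t in tensors):
        # fp8 outputs and no user override: fall back to an exact, bitwise comparison.
        return (0.0, 0.0)
    return (default, default)
```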
Force-pushed from fa64d26 to de76da9
helion/runtime/settings.py
Outdated
    autotune_baseline_rtol: float = dataclasses.field(default=1e-2)
    # Internal fields to track if user explicitly set tolerance values
    _user_set_atol: bool = dataclasses.field(default=False, init=False, repr=False)
    _user_set_rtol: bool = dataclasses.field(default=False, init=False, repr=False)
Not sure if it's possible, but I would really love to avoid needing these private fields on the Settings class, to keep things simpler.
Good point, I changed the implementation a bit. Now the autotune_baseline_atol field is an Optional[float]; a None value represents an unset tolerance.
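A minimal sketch of that revised approach (the surrounding Settings class is simplified here; only the tolerance fields are shown, and the helper property is hypothetical):

```python
from __future__ import annotations

import dataclasses


@dataclasses.dataclass
class Settings:
    # None means "unset": the autotuner may then choose 0.0 for fp8 outputs.
    autotune_baseline_atol: float | None = None
    autotune_baseline_rtol: float | None = None

    @property
    def user_set_tolerances(self) -> bool:
        # Hypothetical convenience helper, not part of the PR: True if the
        # user explicitly specified either tolerance.
        return (
            self.autotune_baseline_atol is not None
            or self.autotune_baseline_rtol is not None
        )
```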
gmagogsfm
left a comment
PTAL
This change automatically sets atol=0.0 and rtol=0.0 for autotuning accuracy checks when fp8 dtypes are detected in kernel outputs, while respecting user-specified tolerance values.

Behavior:
- If the user sets neither tolerance: 0.0 is used automatically for fp8 dtypes
- If the user sets either tolerance: both values are respected (no override)

Fixes an issue where kernels returning fp8 types would fail autotuning with: "Rtol=0.01 and atol=0.01 are not supported for bitwise comparison of low dimensional floats", due to PyTorch's bitwise comparison of low-precision floats.
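For illustration, assuming the baseline check ultimately relies on torch.testing.assert_close (an assumption about the underlying comparison; exact behavior may vary between torch releases), fp8 outputs behave roughly like this:

```python
import torch

a = torch.randn(8).to(torch.float8_e4m3fn)
b = a.clone()

# Zero tolerances request a bitwise comparison, which fp8 tensors support.
torch.testing.assert_close(b, a, rtol=0.0, atol=0.0)

# Nonzero tolerances trigger the error quoted above for low-precision floats.
# torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2)
```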