Optimize Runtime Perf #806
base: main
Conversation
W: (H,)
"""
Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode, row_mode)
num_stages = calculate_num_stages()
nit: We can just return num_stages from rms_norm_forward(), like num_warps, to avoid calling it again. A rough sketch follows.
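A minimal sketch of that suggestion, assuming the call signature shown in the diff; the extra return value and the stub bodies are hypothetical, not the PR's actual code:

```python
# Hypothetical sketch: compute num_stages once inside the forward pass and
# return it alongside num_warps, so callers don't call calculate_num_stages()
# a second time. The real kernel launch is elided.
def calculate_num_stages():
    return 4  # placeholder; the real heuristic queries device properties

def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
    BLOCK_SIZE, num_warps = 65536, 32    # stand-in values for illustration
    num_stages = calculate_num_stages()  # computed once, here
    Y, RSTD = X.clone(), None            # stand-in outputs
    return Y, X, RSTD, BLOCK_SIZE, num_warps, num_stages, casting_mode
```

The caller would then unpack num_stages from the returned tuple instead of recomputing it.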
}

def calculate_num_stages():
Is there a table where we can look up these properties?
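For reference, a minimal sketch of how such a lookup could be done directly from the runtime instead of a hand-kept table; the thresholds below are illustrative, not taken from the PR:

```python
import torch

def calculate_num_stages():
    # Read the properties from the runtime instead of a hard-coded table,
    # keying the stage count off the device's compute capability.
    props = torch.cuda.get_device_properties(torch.cuda.current_device())
    if props.major >= 9:   # Hopper and newer
        return 4
    if props.major >= 8:   # Ampere / Ada
        return 3
    return 2               # older architectures
```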
),
],
)
def test_large_64k_softmax_correctness(dtype, atol, rtol):
Are there any considerations against just adding a test case to the original tests?
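For illustration, one way the 64K width could be folded into the existing parametrization instead of a separate test; the sizes and tolerances below are hypothetical, not copied from the PR:

```python
import pytest
import torch

@pytest.mark.parametrize("n_cols", [1024, 8192, 65536])  # 64K joins the existing sizes
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        (torch.float32, 1e-5, 1e-5),
        (torch.bfloat16, 1e-2, 1e-2),
    ],
)
def test_softmax_correctness(n_cols, dtype, atol, rtol):
    # Body elided; only the parametrization shape matters here.
    ...
```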
(GemmaRMSNorm, 1.0, "gemma"),
],
)
def test_large_64k_correctness(dtype, atol, rtol, reference, offset, casting_mode):
ditto
device = torch.cuda.current_device()
torch_device_props = torch.cuda.get_device_properties(device)
We should make it XPU-compatible:
https://docs.pytorch.org/docs/stable/generated/torch.xpu.get_device_properties.html#torch.xpu.get_device_properties
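A minimal sketch of a backend-agnostic lookup, assuming the XPU path mirrors the CUDA one; the helper name is made up:

```python
import torch

def get_accel_device_properties():
    # Dispatch on whichever accelerator backend is present so the same
    # property lookup works on CUDA and XPU builds of PyTorch.
    if torch.cuda.is_available():
        return torch.cuda.get_device_properties(torch.cuda.current_device())
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.xpu.get_device_properties(torch.xpu.current_device())
    raise RuntimeError("No supported accelerator (CUDA or XPU) found")
```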
num_warps = 4
if BLOCK_SIZE >= 32768:
if BLOCK_SIZE >= 65536:
By the way, I've always wondered why we don't take element_size into account. Do you have any idea?
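Purely as an illustration of that question, a heuristic keyed on bytes per row rather than element count; the thresholds and helper name are hypothetical:

```python
def pick_num_warps(X, BLOCK_SIZE):
    # Scale the warp count by the bytes actually processed per row
    # (BLOCK_SIZE * element size), so fp32 and bf16 rows of the same width
    # are treated differently.
    bytes_per_row = BLOCK_SIZE * X.element_size()
    if bytes_per_row >= 256 * 1024:
        return 32
    if bytes_per_row >= 128 * 1024:
        return 16
    return 4
```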
Can you show some plots comparing the perf before/after?
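A small sketch of how such before/after numbers could be gathered with triton.testing.do_bench; torch.softmax stands in for the two kernel versions only to keep the snippet self-contained:

```python
import torch
import triton

# In practice the two entries would be the kernel from main and from this branch.
candidates = {
    "before (main)": lambda x: torch.softmax(x, dim=-1),
    "after (this PR)": lambda x: torch.softmax(x, dim=-1),
}

x = torch.randn(4, 65536, device="cuda", dtype=torch.bfloat16)
for name, fn in candidates.items():
    ms = triton.testing.do_bench(lambda: fn(x))  # runtime in milliseconds
    print(f"{name}: {ms:.3f} ms")
```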
Summary
Optimize Softmax and RMSNorm runtime performance for hidden_size >= 64K.
Testing Done
Added large tests for 64K dim
make test to ensure correctness
make checkstyle to ensure code style
make test-convergence to ensure convergence