Skip to content

Conversation

@AndreSlavescu
Copy link
Contributor

@AndreSlavescu AndreSlavescu commented Jul 12, 2025

Summary

Optimizing Softmax and RMSNorm runtime performance on hidden_size >= 64k

Testing Done

Added large tests for 64K dim

  • Hardware Type: RTX 3090
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

W: (H,)
"""
Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode, row_mode)
num_stages = calculate_num_stages()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: We can just return num_stages from rms_norm_forward() like num_warps to avoid calling it again.

}


def calculate_num_stages():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a table where we can look up these properties?

),
],
)
def test_large_64k_softmax_correctness(dtype, atol, rtol):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any considerations of not just adding a test case to original tests?

(GemmaRMSNorm, 1.0, "gemma"),
],
)
def test_large_64k_correctness(dtype, atol, rtol, reference, offset, casting_mode):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

Comment on lines +130 to +131
device = torch.cuda.current_device()
torch_device_props = torch.cuda.get_device_properties(device)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


num_warps = 4
if BLOCK_SIZE >= 32768:
if BLOCK_SIZE >= 65536:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By the way, I'm always wondering why we don't take element_size into account. Do you have any idea?

@shimizust
Copy link
Collaborator

Can you show some plots comparing the perf before/after?

@upskyy upskyy mentioned this pull request Nov 8, 2025
3 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants