Conversation

hlahkar commented Oct 28, 2025

This PR enables GPT-OSS with naive attention. Features enabled:

  1. Sinks in Attention
  2. Bias in MoE

Signed-off-by: Himangshu Lahkar <[email protected]>
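For reference, here is a minimal sketch of the attention-sink mechanism as it would look in naive attention. This is not the PR's kernel code; the function name, shapes, and the scale argument are illustrative. Each head carries a learned sink logit that joins the softmax normalization, so a head can park probability mass on no token at all.

import torch

def naive_attention_with_sinks(query, key, value, sinks, scale):
    # query/key/value: [batch, heads, seq, head_dim]; sinks: [heads] learned logits.
    scores = torch.matmul(query, key.transpose(-2, -1)) * scale
    # Append one extra "sink" column per head before the softmax.
    sink_col = sinks.view(1, -1, 1, 1).to(scores.dtype).expand(*scores.shape[:-1], 1)
    probs = torch.softmax(torch.cat([scores, sink_col], dim=-1), dim=-1)
    # Drop the sink column afterwards: its probability mass is simply discarded.
    return torch.matmul(probs[..., :-1].to(value.dtype), value)

The attn_sink.exp() and dtype-cast lines quoted later in the review appear to handle the same normalization explicitly inside the ops. The review thread below quotes the new bias_fused_weights call and discusses why it currently fails in CI.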
if self.bias is not None:
    w1_bias_list = [self.w13_list[i].bias.squeeze() for i in experts_range]
    w2_bias_list = [self.w2_list[i].bias.squeeze() for i in experts_range]
    return torch.ops.hpu.mixture_of_experts.bias_fused_weights(hidden_states=hidden_states,
Collaborator

Test fails with:
"The underlying op of 'hpu.mixture_of_experts' has no overload name 'bias_fused_weights'. Did you mean: 'fp8_fused_weights'" please fix

Author

The CI is on 1.22.0, but this change needs the 1.23.0 software stack; that's why it's failing. We can merge this only after CI moves to the 1.23.0 release.
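One hedged way to keep the test green on both stacks until then, sketched here rather than taken from the PR, is to probe for the 1.23.0 overload and skip or fall back when it is absent. The probe relies on the fact that accessing a missing overload name on an op packet raises AttributeError with exactly the message quoted above.

import torch

def hpu_moe_supports_bias() -> bool:
    # True only when the bias_fused_weights overload (1.23.0+) is registered.
    try:
        torch.ops.hpu.mixture_of_experts.bias_fused_weights
    except (AttributeError, RuntimeError):
        # Older stacks (e.g. 1.22.0) or a build without the HPU op namespace.
        return False
    return True

A test could then use pytest.mark.skipif(not hpu_moe_supports_bias(), reason="needs 1.23.0") instead of failing outright.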

Copilot AI review requested due to automatic review settings November 10, 2025 05:04
Copilot AI left a comment

Pull Request Overview

This PR enables GPT-OSS model support with two main features: attention sinks for improved context handling and bias support in Mixture of Experts (MoE) layers.

Key Changes:

  • Added sink attention mechanism to handle long-context scenarios across naive, FSDPA, and flat attention implementations
  • Implemented bias support in MoE operations for models requiring biased expert computations
  • Added model-specific routing logic for GPT-OSS in the MoE forward pass

Reviewed Changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 6 comments.

Summary per file:

  • vllm_gaudi/ops/hpu_fused_moe.py: added bias handling in MoE layers and GPT-OSS-specific router weight processing
  • vllm_gaudi/extension/utils.py: extended the FSDPA forward method to accept a sinks parameter
  • vllm_gaudi/extension/ops.py: implemented sink attention logic across multiple attention implementations and added bias support to MoE operations
  • vllm_gaudi/attention/backends/hpu_attn.py: added a sinks parameter to attention implementations with validation and dtype conversion
  • tests/unit_tests/sinks/test_gpt_oss.py: added an integration test for the GPT-OSS model with expected outputs
Comments suppressed due to low confidence (2)

vllm_gaudi/attention/backends/hpu_attn.py

  • Missing space after '#' in comment. Should be '# causal' for proper comment formatting.

vllm_gaudi/attention/backends/hpu_attn.py

  • Inconsistent TODO format: should be 'TODO:' with a colon instead of a dash.

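As a plain-PyTorch reference for the "bias in MoE" feature summarized above, the sketch below shows the intended arithmetic, not the fused HPU kernel: every expert's w13 (gate/up) and w2 (down) projections carry bias terms, and expert outputs are mixed by the router weights. The helper name, the SwiGLU-style activation, and all shapes are assumptions for illustration.

import torch
import torch.nn.functional as F

def moe_reference_with_bias(x, w13_list, b13_list, w2_list, b2_list, topk_ids, topk_probs):
    # x: [tokens, hidden]; topk_ids/topk_probs: [tokens, top_k] from the router.
    out = torch.zeros_like(x)
    for t in range(x.shape[0]):
        for slot in range(topk_ids.shape[1]):
            e = int(topk_ids[t, slot])
            gate_up = x[t] @ w13_list[e].T + b13_list[e]      # biased gate/up projection
            gate, up = gate_up.chunk(2, dim=-1)
            h = F.silu(gate) * up                             # activation choice is an assumption
            out[t] += topk_probs[t, slot] * (h @ w2_list[e].T + b2_list[e])  # biased down projection
    return out

The inline comments that follow concern how the fused kernel receives these per-expert weights and biases when experts are processed in slices.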

Comment on lines +566 to +572
w12=w1_list,
w3=w2_list,
w12_bias=w1_bias_list_slice,
w3_bias=w2_bias_list_slice,
permuted_weights=permuted_weights,
experts_min=self.experts_min,
experts_max=self.experts_max)
Copilot AI Nov 10, 2025

Incorrect weight lists passed to MoE operation. Should use sliced lists w1_list_slice and w2_list_slice instead of full lists w1_list and w2_list to match the expert range being processed.

Suggested change
w12=w1_list,
w3=w2_list,
w12_bias=w1_bias_list_slice,
w3_bias=w2_bias_list_slice,
permuted_weights=permuted_weights,
experts_min=self.experts_min,
experts_max=self.experts_max)
w12=w1_list_slice,
w3=w2_list_slice,
w12_bias=w1_bias_list_slice,
w3_bias=w2_bias_list_slice,
permuted_weights=permuted_weights,
experts_min=min_expert,
experts_max=max_expert)

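To make the comment above concrete, here is a hedged sketch of the per-slice bookkeeping it expects; the helper name, slice_size parameter, and yielding style are illustrative, not the PR's code. Each slice must receive its own weight and bias sub-lists together with the matching experts_min/experts_max values.

def iter_expert_slices(w1_list, w2_list, w1_bias_list, w2_bias_list, slice_size):
    # Yields (w1_slice, w2_slice, w1_bias_slice, w2_bias_slice, min_expert, max_expert)
    # for each contiguous block of experts; these are the values that should feed
    # w12/w3/w12_bias/w3_bias/experts_min/experts_max in the fused MoE call.
    num_experts = len(w1_list)
    for min_expert in range(0, num_experts, slice_size):
        max_expert = min(min_expert + slice_size, num_experts) - 1
        sl = slice(min_expert, max_expert + 1)
        yield (w1_list[sl], w2_list[sl], w1_bias_list[sl], w2_bias_list[sl],
               min_expert, max_expert)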
Comment on lines +571 to +572
experts_min=self.experts_min,
experts_max=self.experts_max)
Copilot AI Nov 10, 2025

Incorrect expert range parameters. Should use min_expert and max_expert (computed for the current slice) instead of self.experts_min and self.experts_max to correctly process the expert slice.

Suggested change
experts_min=self.experts_min,
experts_max=self.experts_max)
experts_min=min_expert,
experts_max=max_expert)

Comment on lines 596 to 598
# TODO - change 128 to proper window size
window_size = (
128,
Copilot AI Nov 10, 2025

Magic number 128 used for window size. Consider defining this as a named constant or deriving it from self.sliding_window as indicated by the TODO comment.

Suggested change
# TODO - change 128 to proper window size
window_size = (
128,
# Use self.sliding_window for window size instead of hardcoded 128
window_size = (
self.sliding_window,

tensor_parallel_size=4,
)
generated_texts = do_sample(llm, original_output=original_output_120, rtol=1e-01, atol=1e-01, max_num_seqs=1)
assert generated_texts == expected_output
Copilot AI Nov 10, 2025

Assertion compares single generated text with expected output incorrectly. The function returns a list but only validates the first element earlier. This assertion will fail unless generated_texts contains exactly one element matching expected_output[0]. Consider assert generated_texts[0] == expected_output[0] or assert generated_texts == expected_output after validating the list length.

Suggested change
assert generated_texts == expected_output
assert len(generated_texts) == len(expected_output)
assert generated_texts[0] == expected_output[0]

attn_sink = attn_sink.exp()
if attn_sink.dtype == torch.float32:
    attn_sink = attn_sink.to(value.dtype)
#TODO: Removing this .sum and using attn_sink directly
Copilot AI Nov 10, 2025

Corrected spacing in TODO comment: should be 'TODO:' with a space after the colon for consistency.

Suggested change
#TODO: Removing this .sum and using attn_sink directly
# TODO: Removing this .sum and using attn_sink directly

attn_bias = None
window_size = (self.sliding_window, 0)
common_args['window_size'] = window_size
# TODO - change 128 to proper window size
Copilot AI Nov 10, 2025

Inconsistent TODO format: should be 'TODO:' with a colon instead of a dash for consistency with project conventions.

Suggested change
# TODO - change 128 to proper window size
# TODO: change 128 to proper window size
