Initial Commit GPT-OSS #485
base: main
Changes from 7 commits
Commits: c52dbfd, bc3d704, f3e2553, 1d35ae9, a350ae9, 1928416, 596433a, 2a2968a, 2ecb728, 051a0c0, d4eee4d
@@ -0,0 +1,72 @@
import vllm
import os
from vllm.entrypoints.llm import LLM

RUN_20B_MODEL = True  # Set to False to run the 120B model instead
MODEL_PATH = "lmsys/gpt-oss-20b-BF16"
MODEL_PATH_120 = "lmsys/gpt-oss-120b-BF16"
# reference https://github.com/huggingface/transformers/blob/68eb1a9a6353911f491b1c8139eb73d052a8e9b9/tests/models/gpt_oss/test_modeling_gpt_oss.py#L397
original_output = "Roses are red, violets are blue, I love you, and I love you too.\n\nRoses are red, vio"
# reference https://github.com/huggingface/transformers/blob/68eb1a9a6353911f491b1c8139eb73d052a8e9b9/tests/models/gpt_oss/test_modeling_gpt_oss.py#L462
original_output_120 = "Roses are red, violets are blue,\nI am a language model, not a human being"


def do_sample(llm: LLM, original_output: str, rtol: float, atol: float, max_num_seqs: int) -> list[str]:
    prompts = [
        "Roses are red, violets",
    ] * max_num_seqs

    sampling_params = vllm.SamplingParams(
        temperature=0,
        max_tokens=20,
    )
    outputs = llm.generate(prompts, sampling_params)

    # Print the outputs.
    generated_texts: list[str] = []
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

    assert prompts[0] + generated_texts[0] == original_output, "Generated text does not match the expected output."
    return generated_texts


expected_output = [
    "are blue, I love you, and I love you too.\n\nRoses are red, vio"  # noqa: E501
]


def _test_gpt_oss():
    """Main function that sets up and runs the prompt processing."""
    if RUN_20B_MODEL:
        llm = LLM(
            MODEL_PATH,
            max_num_seqs=8,
            dtype='bfloat16',
            enforce_eager=True,
            max_model_len=512,
            max_num_batched_tokens=2048,
            tensor_parallel_size=1,
        )
        generated_texts = do_sample(llm, original_output=original_output, rtol=1e-01, atol=1e-01, max_num_seqs=1)
    else:
        llm = LLM(
            MODEL_PATH_120,
            max_num_seqs=8,
            dtype='bfloat16',
            enforce_eager=False,
            max_model_len=512,
            max_num_batched_tokens=2048,
            tensor_parallel_size=4,
        )
        generated_texts = do_sample(llm, original_output=original_output_120, rtol=1e-01, atol=1e-01, max_num_seqs=1)
    assert generated_texts == expected_output


def test_gpt_oss_1x():
    os.environ['VLLM_PROMPT_USE_FUSEDSDPA'] = '0'
    _test_gpt_oss()
    os.environ['VLLM_PROMPT_USE_FUSEDSDPA'] = '1'
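One possible cleanup for the environment-variable handling in test_gpt_oss_1x, sketched here as a suggestion rather than as part of the diff: pytest's built-in monkeypatch fixture restores VLLM_PROMPT_USE_FUSEDSDPA at teardown even if _test_gpt_oss() raises, whereas the manual reset assumes the prior value was '1' and is skipped on failure.

```python
import pytest

def test_gpt_oss_1x(monkeypatch: pytest.MonkeyPatch):
    # monkeypatch.setenv records the variable's previous value (or absence) and
    # restores it automatically when the test finishes, whether it passes or fails.
    monkeypatch.setenv('VLLM_PROMPT_USE_FUSEDSDPA', '0')
    _test_gpt_oss()
```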
@@ -167,6 +167,7 @@ def __init__(
         qk_head_dim: int,
         v_head_dim: int,
         kv_b_proj: ColumnParallelLinear,
+        sinks: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> None:
         torch.nn.Module.__init__(self)
@@ -218,6 +219,11 @@ def __init__(
                 "encoder/decoder cross-attention "
                 "are not implemented for "
                 "TritonMLAImpl")
+        self.sinks = sinks
+        if sinks is not None:
+            assert sinks.shape[0] == num_heads, ("Sinks must have the same number of heads as the number of "
+                                                 f"heads in the layer. Sinks shape: {sinks.shape}, "
+                                                 f"num_heads: {num_heads}.")

     def forward(
         self,
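For context on what the new sinks tensor represents (the hunks above and below only store and validate it): GPT-OSS attaches one learnable "attention sink" logit per head that participates in the softmax normalization but contributes nothing to the output. The sketch below is a minimal eager-mode illustration of that idea with assumed tensor shapes; the fused HPU/Triton kernels targeted by this PR may fold the sink in differently.

```python
import torch

def sdpa_with_sinks(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    sinks: torch.Tensor) -> torch.Tensor:
    # q, k, v: [batch, num_heads, seq_len, head_dim]; sinks: [num_heads].
    # Causal / sliding-window masking is omitted to keep the sketch short.
    scale = q.shape[-1] ** -0.5
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale  # [B, H, Q, K]
    # Append one sink logit per head as an extra "virtual token" column.
    sink_col = sinks.view(1, -1, 1, 1).expand(scores.shape[0], -1, scores.shape[2], 1)
    probs = torch.softmax(torch.cat([scores, sink_col], dim=-1), dim=-1)
    # Drop the sink column: it absorbs probability mass but attends to nothing.
    return torch.matmul(probs[..., :-1], v)
```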
@@ -389,6 +395,7 @@ def __init__(
         attn_type: str = AttentionType.DECODER,
         kv_sharing_target_layer_name: Optional[str] = None,
         use_irope: bool = False,
+        sinks: Optional[torch.Tensor] = None,
     ) -> None:
         super(AttentionImpl, self).__init__()
         if kv_sharing_target_layer_name is not None:
@@ -453,6 +460,11 @@ def __init__(
             raise NotImplementedError("Encoder self-attention "
                                       "is not implemented for "
                                       "HPUAttentionImpl")
+        self.sinks = sinks
+        if sinks is not None:
+            assert sinks.shape[0] == num_heads, ("Sinks must have the same number of heads as the number of "
+                                                 f"heads in the layer. Sinks shape: {sinks.shape}, "
+                                                 f"num_heads: {num_heads}.")

     def _maybe_init_alibi_biases(
         self,
@@ -534,6 +546,12 @@ def forward(
         # Reshape the input keys and values and store them in the cache.
         # If kv_cache is not provided, the new key and value tensors are
         # not cached. This happens during the initial memory profiling run.
+        if key.dtype != key_cache.dtype:
+            key = key.to(key_cache.dtype)
+        if value.dtype != value_cache.dtype:
+            value = value.to(value_cache.dtype)
+        if query.dtype != key.dtype:
+            query = query.to(key.dtype)
         key_cache = self.k_cache(key, key_cache, slot_mapping)
         value_cache = self.v_cache(value, value_cache, slot_mapping)
@@ -570,13 +588,17 @@ def forward(

         common_args = self.common_attention_args(block_list, key_cache, value_cache, attn_metadata.block_size)

-        if self.sliding_window and hasattr(attn_metadata,
-                'window_attn_bias') and attn_metadata.window_attn_bias is not None \
-                and self.prefill_impl == 'naive_impl':
-            attn_bias = attn_metadata.window_attn_bias
+        if self.sliding_window:
+            if hasattr(attn_metadata, 'window_attn_bias') and attn_metadata.window_attn_bias is not None:
+                attn_bias = attn_metadata.window_attn_bias
+            else:
+                attn_bias = None
+            window_size = (self.sliding_window, 0)
+            common_args['window_size'] = window_size
+            # TODO - change 128 to proper window size
Suggested change:
-            # TODO - change 128 to proper window size
+            # TODO: change 128 to proper window size
Copilot AI · Nov 10, 2025 (outdated)
Magic number 128 used for window size. Consider defining this as a named constant or deriving it from self.sliding_window as indicated by the TODO comment.
Suggested change:
-            # TODO - change 128 to proper window size
-            window_size = (
-                128,
+            # Use self.sliding_window for window size instead of hardcoded 128
+            window_size = (
+                self.sliding_window,
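To make the reviewer's point concrete: window_size = (self.sliding_window, 0) reads like the common FlashAttention-style (left, right) window convention, so the literal 128 elsewhere in the diff should be derived from self.sliding_window rather than hardcoded. The helper below is a small illustration under that assumption; the exact semantics of the HPU fused-SDPA window argument may differ.

```python
import torch

def window_mask(seq_len: int, window_size: tuple[int, int]) -> torch.Tensor:
    # Boolean mask for (left, right) local attention; True = query may attend to key.
    # (self.sliding_window, 0) means "look back up to sliding_window tokens, never
    # look ahead", which is why a hardcoded 128 only matches one configuration.
    left, right = window_size
    q_idx = torch.arange(seq_len).unsqueeze(1)  # query positions, column vector
    k_idx = torch.arange(seq_len).unsqueeze(0)  # key positions, row vector
    return (k_idx >= q_idx - left) & (k_idx <= q_idx + right)
```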
Assertion compares single generated text with expected output incorrectly. The function returns a list but only validates the first element earlier. This assertion will fail unless generated_texts contains exactly one element matching expected_output[0]. Consider assert generated_texts[0] == expected_output[0] or assert generated_texts == expected_output after validating the list length.
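A sketch of the fix this comment suggests, reusing the names from the test file above (do_sample already checks the first element against original_output, so the length guard is the only new piece). Note that expected_output as written only covers the 20B branch; the 120B branch would presumably need its own expected list.

```python
# Inside _test_gpt_oss(), after do_sample() returns:
assert len(generated_texts) == len(expected_output), (
    f"Expected {len(expected_output)} sequence(s), got {len(generated_texts)}")
assert generated_texts == expected_output
```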