diff --git a/language/gpt-oss-120b/.gitignore b/language/gpt-oss-120b/.gitignore
new file mode 100644
index 0000000000..78317dc552
--- /dev/null
+++ b/language/gpt-oss-120b/.gitignore
@@ -0,0 +1,3 @@
+*venv*
+*.pkl
+*.csv
\ No newline at end of file
diff --git a/language/gpt-oss-120b/README.md b/language/gpt-oss-120b/README.md
new file mode 100644
index 0000000000..60e611344b
--- /dev/null
+++ b/language/gpt-oss-120b/README.md
@@ -0,0 +1,154 @@
+# MLPerf Inference reference implementation for GPT-OSS-120B
+This is the reference implementation for GPT-OSS-120B. It is a proposal and a work in progress (WIP).
+
+## Model and Dataset download
+
+#### TODO: Replace this with mlc download link when available
+
+* Model: `openai/gpt-oss-120b`, commit id: [`b5c939d`](https://huggingface.co/openai/gpt-oss-120b/tree/b5c939de8f754692c1647ca79fbf85e8c1e70f8a)
+* Dataset: Please request access at [this link](https://drive.google.com/drive/folders/1DCfEXHqe69okrqKbSyV-8VUw413JqpPY?usp=drive_link) - **this is a tentative dataset**
+
+Datasets are now provided in **Parquet format** (recommended) for better performance and smaller file size (50% smaller than pickle). Pickle format is still supported for backward compatibility.
+
+## Environment setup
+Work on the reference implementation is done using the SGLang containers at [https://hub.docker.com/r/lmsysorg/sglang/tags](https://hub.docker.com/r/lmsysorg/sglang/tags). For enroot setup, a script is provided under [`setup_enroot.sh`](./setup_enroot.sh). All sections below assume this environment has been instantiated.
+
+Once in the environment, install additional requirements using [`setup.sh`](./setup.sh):
+```bash
+./setup.sh
+```
+
+## Running the reference implementation: SGLang
+Use [`./sglang/run_server.sh`](./sglang/run_server.sh) to launch an SGLang server hosting `gpt-oss-120b`.
+
+### Run the server
+```bash
+./run_server.sh \
+    --model_path path/to/gpt-oss-120b/model \
+    --dp N \
+    --stream_interval 100 \
+    --eagle_path optional/path/to/eagle/head
+```
+The script uses `python3 -m sglang.launch_server` to instantiate the model, with `tp=pp=ep=1` and `dp` as specified.
+
+You may also use docker:
+```bash
+docker run --runtime nvidia --gpus all --net host \
+    -v ${HF_HOME}:/root/.cache/huggingface \
+    --env "HUGGING_FACE_HUB_TOKEN=$HF_TOKEN" \
+    --ipc=host lmsysorg/sglang:latest \
+    python3 -m sglang.launch_server --model-path ${MODEL_NAME} \
+    --host 0.0.0.0 --port 3000 --data-parallel-size=1 --max-running-requests 512 \
+    --mem-fraction-static 0.85 --chunked-prefill-size 16384 --ep-size=1 \
+    --enable-metrics --stream-interval 500
+```
+
+Then run a benchmark script that uses the client to send and receive requests.
+### Run the inference
+
+**Note:** All scripts now support both Parquet (`.parquet`) and Pickle (`.pkl`) formats for dataset files.
Parquet is recommended as it offers: +- 50% smaller file size +- Faster loading times +- Cross-language compatibility +- Type-safe schema preservation + +Example usage: +```bash +# first, install loadgen +pip install $(git rev-parse --show-toplevel)/loadgen + +# Using Parquet format (recommended) +python3 run_mlperf.py \ + --scenario offline \ + --input-file /path/to/dataset.parquet \ + --accuracy + +# Using Pickle format (backward compatible) +python3 run_mlperf.py \ + --scenario offline \ + --input-file /path/to/dataset.pkl \ + --accuracy +``` + +Full command-line options: +```bash +python3 run_mlperf.py --help +usage: run_mlperf.py [-h] [--scenario {offline,server}] --input-file INPUT_FILE [--max-samples MAX_SAMPLES] [--mlperf-conf MLPERF_CONF] + [--user-conf USER_CONF] [--accuracy] [--output-dir OUTPUT_DIR] [--backend {sglang}] [--server-url SERVER_URL] + [--generation-config GENERATION_CONFIG] [--max-new-tokens MAX_NEW_TOKENS] [--num-workers NUM_WORKERS] + [--max-concurrency MAX_CONCURRENCY] + +Run MLPerf inference benchmarks for gpt-oss + +options: + -h, --help show this help message and exit + --scenario {offline,server} + MLPerf scenario mode + --input-file INPUT_FILE + Path to tokenized dataset (parquet or pickle file) + --max-samples MAX_SAMPLES + Maximum number of samples to use (None for all) + --mlperf-conf MLPERF_CONF + Path to MLPerf configuration file + --user-conf USER_CONF + Path to user configuration file + --accuracy Run accuracy mode instead of performance + --output-dir OUTPUT_DIR + Directory for MLPerf output logs + --backend {sglang} Backend to use for inference + --server-url SERVER_URL + Server URL for backend (SGLang) + --generation-config GENERATION_CONFIG + Path to generation configuration JSON file + --max-new-tokens MAX_NEW_TOKENS + Override max_new_tokens from generation config (default: use value from config) + --num-workers NUM_WORKERS + Number of worker threads (for server scenario) + --max-concurrency MAX_CONCURRENCY + Maximum concurrent requests to backend (SGLang handles batching internally) + +``` + +### Evaluate the accuracy +Run `run_mlperf.py` with `--accuracy`, and then use the generated `mlperf_log_accuracy.json` to evaluate the accuracy of the run. + +Example usage: +```bash +# Using Parquet format (recommended) +python3 eval_mlperf_accuracy.py \ + --mlperf-log mlperf_results/offline/accuracy/mlperf_log_accuracy.json \ + --reference-data /path/to/acc_eval_inputs.parquet \ + --tokenizer openai/gpt-oss-120b + +# Using Pickle format (backward compatible) +python3 eval_mlperf_accuracy.py \ + --mlperf-log mlperf_results/offline/accuracy/mlperf_log_accuracy.json \ + --reference-data /path/to/acc_eval_inputs.pkl \ + --tokenizer openai/gpt-oss-120b +``` + +Full command-line options: +```bash +python3 eval_mlperf_accuracy.py --help +usage: eval_mlperf_accuracy.py [-h] --mlperf-log MLPERF_LOG --reference-data REFERENCE_DATA [--tokenizer TOKENIZER] [--output-file OUTPUT_FILE] + [--save-outputs SAVE_OUTPUTS] [--num-lcb-workers NUM_LCB_WORKERS] [--verbose] + +Evaluate MLPerf accuracy logs for gpt-oss-120b + +options: + -h, --help show this help message and exit + --mlperf-log MLPERF_LOG + Path to mlperf_log_accuracy.json + --reference-data REFERENCE_DATA + Path to reference parquet or pickle file (DataFrame with dataset, ground_truth, etc.) 
+ --tokenizer TOKENIZER + HuggingFace tokenizer name or path + --output-file OUTPUT_FILE + Output JSON file for results (optional) + --save-outputs SAVE_OUTPUTS + Save detokenized outputs to pickle file (ordered by qsl_idx) for debugging + --num-lcb-workers NUM_LCB_WORKERS + Number of parallel workers for LiveCodeBench evaluation (default: 64) + --verbose Verbose logging + +``` diff --git a/language/gpt-oss-120b/backends/__init__.py b/language/gpt-oss-120b/backends/__init__.py new file mode 100644 index 0000000000..3f68dc171c --- /dev/null +++ b/language/gpt-oss-120b/backends/__init__.py @@ -0,0 +1,10 @@ +#!/usr/bin/env python3 +"""Backend implementations for gpt-oss inference.""" + +from .base_backend import BaseBackend +from .sglang_backend import SGLangBackend + +__all__ = [ + "BaseBackend", + "SGLangBackend", +] diff --git a/language/gpt-oss-120b/backends/base_backend.py b/language/gpt-oss-120b/backends/base_backend.py new file mode 100644 index 0000000000..228de1ced8 --- /dev/null +++ b/language/gpt-oss-120b/backends/base_backend.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python3 +"""Base backend class for gpt-oss inference.""" + +import abc +import logging +from typing import List, Dict, Any, Optional + +logger = logging.getLogger(__name__) + + +class BaseBackend(abc.ABC): + """Abstract base class for inference backends. + + All backends must implement this interface to work with the MLPerf SUT. + """ + + def __init__(self, config: Optional[Dict[str, Any]] = None): + """Initialize the backend. + + Args: + config: Optional configuration dictionary + """ + self.config = config or {} + self.initialized = False + logger.info(f"Initializing {self.__class__.__name__}") + + @abc.abstractmethod + def initialize(self) -> None: + """Initialize the backend (load model, connect to server, etc.).""" + raise NotImplementedError("Subclasses must implement initialize()") + + @abc.abstractmethod + def generate( + self, + prompts: List[List[int]], + max_tokens: int = 100, + temperature: float = 0.001, + top_k: int = 1, + top_p: float = 1.0, + **kwargs + ) -> List[Dict[str, Any]]: + """Generate responses for a batch of prompts. + + Args: + prompts: List of token ID sequences + max_tokens: Maximum tokens to generate per prompt + temperature: Sampling temperature + top_k: Top-k sampling parameter + top_p: Top-p (nucleus) sampling parameter + **kwargs: Additional backend-specific parameters + + Returns: + List of response dictionaries with keys: + - output_ids: List of generated token IDs + - output_text: Generated text (optional) + - metadata: Additional metadata (latencies, etc.) 
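+
+        Example return value (shape only; token IDs and latency are illustrative):
+            [{"output_ids": [200, 318, 1085],
+              "output_text": "...",
+              "metadata": {"latency": 0.42}}]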
+ """ + raise NotImplementedError("Subclasses must implement generate()") + + @abc.abstractmethod + def cleanup(self) -> None: + """Clean up backend resources.""" + raise NotImplementedError("Subclasses must implement cleanup()") + + def __enter__(self): + """Context manager entry.""" + self.initialize() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit.""" + self.cleanup() + + @property + def is_initialized(self) -> bool: + """Check if backend is initialized.""" + return self.initialized diff --git a/language/gpt-oss-120b/backends/sglang_backend.py b/language/gpt-oss-120b/backends/sglang_backend.py new file mode 100644 index 0000000000..83fd12f3ad --- /dev/null +++ b/language/gpt-oss-120b/backends/sglang_backend.py @@ -0,0 +1,426 @@ +#!/usr/bin/env python3 +"""SGLang backend implementation for gpt-oss.""" + +import asyncio +import json +import logging +import requests +import time +from typing import List, Dict, Any, Optional, AsyncIterator +import aiohttp +from .base_backend import BaseBackend + +logger = logging.getLogger(__name__) + + +class SGLangBackend(BaseBackend): + """SGLang inference backend using HTTP API. + + Connects to an SGLang server running the gpt-oss model. + """ + + def __init__( + self, + server_url: str = "http://localhost:30000", + timeout: int = 1200, + max_pool_size: int = 2000, # Default pool size for high concurrency + **kwargs + ): + """Initialize SGLang backend. + + Args: + server_url: URL of the SGLang server + timeout: Request timeout in seconds + max_pool_size: Maximum connection pool size (should be >= max_concurrency) + **kwargs: Additional configuration + """ + config = { + "server_url": server_url, + "timeout": timeout, + "max_pool_size": max_pool_size, + **kwargs + } + super().__init__(config) + self.server_url = server_url + self.timeout = timeout + self.max_pool_size = max_pool_size + self.session = None + + def initialize(self) -> None: + """Initialize connection to SGLang server.""" + if self.initialized: + logger.warning("Backend already initialized") + return + + logger.info(f"Connecting to SGLang server at {self.server_url}") + logger.info( + f"Configuring connection pool with max_pool_size={self.max_pool_size}") + # Create session with larger connection pool for high concurrency + # Default pool size is 10, but we may have 100s-1000s of concurrent + # requests + self.session = requests.Session() + + # Increase connection pool size to support high concurrency + # pool_maxsize should be >= max_concurrency to avoid "pool is full" + # warnings + adapter = requests.adapters.HTTPAdapter( + # Number of connection pools to cache + pool_connections=min(100, self.max_pool_size // 10), + pool_maxsize=self.max_pool_size, # Maximum number of connections in the pool + max_retries=3, # Retry failed requests + # Don't block when pool is full, create new connections + pool_block=False + ) + self.session.mount('http://', adapter) + self.session.mount('https://', adapter) + + # Test connection with a simple request + try: + test_response = self._send_request( + input_ids=[1, 2, 3], + max_tokens=5, + temperature=0.001, + top_k=1, + top_p=1.0 + ) + if "error" in test_response: + raise ConnectionError( + f"Failed to connect to SGLang server: {test_response['error']}" + ) + logger.info("Successfully connected to SGLang server") + self.initialized = True + except Exception as e: + logger.error(f"Failed to initialize SGLang backend: {e}") + raise + + def _send_request( + self, + input_ids: List[int], + max_tokens: int, + 
temperature: float, + top_k: int, + top_p: float + ) -> Dict[str, Any]: + """Send a single request to the SGLang server. + + Args: + input_ids: Token IDs for the prompt + max_tokens: Maximum tokens to generate + temperature: Sampling temperature + top_k: Top-k parameter + top_p: Top-p parameter + + Returns: + Response dictionary from the server + """ + payload = { + "input_ids": input_ids, + "sampling_params": { + "max_new_tokens": max_tokens, + "temperature": temperature, + "top_k": top_k, + "top_p": top_p, + } + } + + try: + response = self.session.post( + f"{self.server_url}/generate", + json=payload, + timeout=self.timeout, + ) + if response.status_code == 200: + return response.json() + else: + logger.error( + f"Request failed with status {response.status_code}: {response.text}" + ) + return {"error": f"HTTP {response.status_code}: {response.text}"} + except requests.exceptions.RequestException as e: + logger.error(f"Request failed: {e}") + return {"error": str(e)} + + def generate( + self, + prompts: List[List[int]], + max_tokens: int = 100, + temperature: float = 0.001, + top_k: int = 1, + top_p: float = 1.0, + **kwargs + ) -> List[Dict[str, Any]]: + """Generate responses for a batch of prompts. + + Args: + prompts: List of token ID sequences + max_tokens: Maximum tokens to generate per prompt + temperature: Sampling temperature + top_k: Top-k sampling parameter + top_p: Top-p (nucleus) sampling parameter + **kwargs: Additional parameters (ignored) + + Returns: + List of response dictionaries with keys: + - output_ids: List of generated token IDs + - output_text: Generated text (if available) + - metadata: Additional metadata (latencies, etc.) + """ + if not self.initialized: + raise RuntimeError( + "Backend not initialized. Call initialize() first.") + + results = [] + for prompt_ids in prompts: + start_time = time.time() + response = self._send_request( + input_ids=prompt_ids, + max_tokens=max_tokens, + temperature=temperature, + top_k=top_k, + top_p=top_p + ) + end_time = time.time() + latency = end_time - start_time + + # Extract output_ids from response + output_ids = [] + output_text = "" + if "error" not in response: + output_ids = response.get("output_ids", []) + output_text = response.get("text", "") + + result = { + "output_ids": output_ids, + "output_text": output_text, + "metadata": { + "latency": latency, + "completion_tokens": response.get("meta_info", {}).get( + "completion_tokens", len(output_ids) + ), + "error": response.get("error"), + } + } + results.append(result) + + return results + + async def generate_stream( + self, + input_ids: List[int], + max_tokens: int = 100, + temperature: float = 0.001, + top_k: int = 1, + top_p: float = 1.0, + **kwargs + ) -> AsyncIterator[Dict[str, Any]]: + """Generate response with streaming support. + + Yields incremental responses as tokens are generated. + + Args: + input_ids: Token IDs for the prompt + max_tokens: Maximum tokens to generate + temperature: Sampling temperature + top_k: Top-k parameter + top_p: Top-p parameter + + Yields: + Dict with: + - delta_token_ids: List of new token IDs in this chunk + - delta_text: New text in this chunk + - is_first_token: True if this is the first token + - is_finished: True if generation is complete + - accumulated_token_ids: All tokens generated so far + - metadata: Additional info (TTFT, completion_tokens, etc.) 
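+
+        Example chunk (values are illustrative only):
+            {"delta_token_ids": [1234], "delta_text": " the",
+             "is_first_token": False, "is_finished": False,
+             "accumulated_token_ids": [17, 1234],
+             "metadata": {"ttft_ms": 85.0, "completion_tokens": 2}}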
+ + Note: + SGLang's streaming API behavior: + - Returns 'output_ids', 'text', and 'meta_info' in each chunk + - 'output_ids' can have retractions (length can decrease between chunks) + - 'meta_info.completion_tokens' is the RELIABLE cumulative token count + - 'finish_reason' in meta_info indicates completion (not a 'finished' flag) + - We use completion_tokens for accurate LoadGen token/sec metrics + """ + if not self.initialized: + raise RuntimeError( + "Backend not initialized. Call initialize() first.") + + payload = { + "input_ids": input_ids, + "sampling_params": { + "max_new_tokens": max_tokens, + "temperature": temperature, + "top_k": top_k, + "top_p": top_p, + }, + "stream": True # Enable streaming + } + + start_time = time.time() + first_token_time = None + accumulated_token_ids = [] + accumulated_text = "" + is_first = True + + try: + async with aiohttp.ClientSession() as session: + async with session.post( + f"{self.server_url}/generate", + json=payload, + timeout=aiohttp.ClientTimeout(total=self.timeout) + ) as response: + if response.status != 200: + error_text = await response.text() + logger.error( + f"Streaming request failed: {response.status} - {error_text}") + yield { + "delta_token_ids": [], + "delta_text": "", + "is_first_token": False, + "is_finished": True, + "accumulated_token_ids": [], + "error": f"HTTP {response.status}: {error_text}", + "metadata": {} + } + return + + # Read streaming response + async for line in response.content: + if not line: + continue + + # SGLang sends data as "data: {...}\n\n" + line_str = line.decode('utf-8').strip() + if not line_str.startswith('data:'): + continue + + try: + # Remove "data:" prefix + json_str = line_str[5:].strip() + if json_str == '[DONE]': + break + + chunk = json.loads(json_str) + + # Extract text delta + delta_text = chunk.get("text", "") + + # Check if this is the final chunk + # SGLang uses 'finish_reason' in meta_info, not + # 'finished' flag + meta_info = chunk.get("meta_info", {}) + finish_reason = meta_info.get("finish_reason") + is_finished = ( + finish_reason is not None and finish_reason != "null") or chunk.get( + "finished", False) + + # Extract token information from chunk + # SGLang's output_ids can have retractions, so use meta_info.completion_tokens + # which is the reliable cumulative count + chunk_output_ids = chunk.get("output_ids", []) + completion_tokens = meta_info.get( + "completion_tokens", 0) + + if completion_tokens > 0: + # Use completion_tokens as the authoritative + # count + previous_count = len(accumulated_token_ids) + + if completion_tokens > previous_count: + # New tokens generated + num_new_tokens = completion_tokens - previous_count + + if chunk_output_ids and len( + chunk_output_ids) >= num_new_tokens: + # Use actual token IDs from chunk + delta_token_ids = chunk_output_ids[-num_new_tokens:] if num_new_tokens > 0 else [ + ] + else: + # Fallback: create placeholder tokens + # for counting + delta_token_ids = list( + range(previous_count, completion_tokens)) + + accumulated_token_ids.extend( + delta_token_ids) + else: + delta_token_ids = [] + + else: + # No completion_tokens - fallback to output_ids + # or text estimation + if chunk_output_ids: + delta_token_ids = chunk_output_ids + accumulated_token_ids.extend( + delta_token_ids) + elif delta_text: + # Estimate from text length + estimated_tokens = max( + 1, len(delta_text) // 4) + delta_token_ids = [0] * estimated_tokens + accumulated_token_ids.extend( + delta_token_ids) + else: + delta_token_ids = [] + + # Accumulate text + if 
delta_text: + accumulated_text += delta_text + + # Mark first token timing + if is_first and (delta_token_ids or delta_text): + first_token_time = time.time() + is_first = False + + yield { + "delta_token_ids": delta_token_ids, + "delta_text": delta_text, + "is_first_token": (first_token_time is not None and is_first is False and len(accumulated_token_ids) <= len(delta_token_ids)), + "is_finished": is_finished, + "accumulated_token_ids": accumulated_token_ids.copy(), + "accumulated_text": accumulated_text, + "metadata": { + "ttft_ms": (first_token_time - start_time) * 1000 if first_token_time else None, + "latency_ms": (time.time() - start_time) * 1000, + **chunk.get("meta_info", {}) + } + } + + if is_finished: + break + + except json.JSONDecodeError as e: + logger.warning( + f"Failed to parse streaming chunk: {e}") + continue + + except asyncio.TimeoutError: + logger.error(f"Streaming request timed out after {self.timeout}s") + yield { + "delta_token_ids": [], + "delta_text": "", + "is_first_token": False, + "is_finished": True, + "accumulated_token_ids": accumulated_token_ids, + "error": "Timeout", + "metadata": {} + } + except Exception as e: + logger.error(f"Streaming request failed: {e}", exc_info=True) + yield { + "delta_token_ids": [], + "delta_text": "", + "is_first_token": False, + "is_finished": True, + "accumulated_token_ids": accumulated_token_ids, + "error": str(e), + "metadata": {} + } + + def cleanup(self) -> None: + """Clean up backend resources.""" + if self.session: + self.session.close() + self.session = None + self.initialized = False + logger.info("SGLang backend cleaned up") diff --git a/language/gpt-oss-120b/eval_accuracy.py b/language/gpt-oss-120b/eval_accuracy.py new file mode 100644 index 0000000000..0169f43c36 --- /dev/null +++ b/language/gpt-oss-120b/eval_accuracy.py @@ -0,0 +1,1096 @@ +#!/usr/bin/env python3 +""" +Standalone evaluation script for mlperf-inference deepseek-r1 dataset. + +Expected input format (pickle file with DataFrame): +- model_output: The model's response text +- tok_model_output_len: The length of the model's response tokens +- ground_truth: The expected answer +- dataset: Dataset name (e.g., 'gpqa', 'mmlu_pro', 'math500', 'livecodebench', 'aime') +- question: The question text + +Output adds columns: +- extracted_answer: Parsed answer from model output +- prompt_accuracy: 100.0 if correct, 0.0 if incorrect +- evaluation_details: Detailed evaluation explanation +""" + +import sys +import os +import argparse +import logging +import pickle +import re +import shutil +import time +from functools import lru_cache +from typing import Dict, Any, Optional, Tuple, Union +import pandas as pd +from tqdm import tqdm +from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError +import multiprocessing +from pathlib import Path +from contextlib import redirect_stdout, redirect_stderr + +# MLPerf log processing imports +import numpy as np +from transformers import AutoTokenizer + +logging.basicConfig(level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +# ============================================================================= +# Input Validation +# ============================================================================= + + +def detect_pass_k(df: pd.DataFrame) -> int: + """Detect if DataFrame has pass@k format and return k. + + Returns: + Number of passes (k) if pass@k format detected, otherwise 1 + """ + # Check for model_output_0, model_output_1, etc. 
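+    # e.g. columns model_output_0 .. model_output_3 give pass_k = 4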
+ pass_k = 0 + while f'model_output_{pass_k}' in df.columns: + pass_k += 1 + + # If no _0 suffix found, check for single model_output column + if pass_k == 0 and 'model_output' in df.columns: + return 1 + + return pass_k + + +def validate_dataframe(df: pd.DataFrame) -> None: + """Validate input DataFrame has required columns.""" + if not isinstance(df, pd.DataFrame): + raise ValueError("Input must be a pandas DataFrame") + + # Detect pass@k format + pass_k = detect_pass_k(df) + + if pass_k == 0: + raise ValueError( + "No model_output columns found (expected 'model_output' or 'model_output_0', 'model_output_1', etc.)") + + # Check for dataset column + if 'dataset' not in df.columns: + raise ValueError("Missing required column: 'dataset'") + + # Check for tok_model_output_len (either single or with suffixes) + has_tok_len = all( + f'tok_model_output_len_{i}' in df.columns for i in range(pass_k)) + + if not has_tok_len: + raise ValueError("Missing required tok_model_output_len column(s)") + + # Check for ground_truth or rubrics depending on dataset + has_ground_truth = 'ground_truth' in df.columns + has_rubrics = 'rubrics' in df.columns + + if not has_ground_truth and not has_rubrics: + raise ValueError( + "DataFrame must have either 'ground_truth' or 'rubrics' column") + + +def validate_text_input(text: Any) -> str: + """Validate and convert text input to string.""" + if pd.isna(text) or text is None: + return "" + return str(text).strip() + + +def validate_dataset_name(dataset: Any) -> str: + """Validate dataset name.""" + if pd.isna(dataset) or not dataset: + raise ValueError("Dataset name cannot be empty") + return str(dataset).lower() + + +# ============================================================================= +# Harmony Format Extraction +# ============================================================================= + + +def extract_final_section(text: str) -> str: + """Extract content from the <|channel|>final<|message|>...<|return|> section. + + The model outputs have two sections: + - <|channel|>analysis<|message|>... (reasoning, may have draft answers) + - <|channel|>final<|message|>... (actual final answer) + + This function extracts only the final section to avoid extracting + wrong answers from the analysis section. + + Uses a flexible regex to handle corrupted markers like: + - <|channel|>final 明<|message|> + - <|channel|>final537<|message|> + + Args: + text: Full model output text + + Returns: + Content of final section if found, otherwise returns original text + """ + text = validate_text_input(text) + if not text: + return "" + + # Flexible pattern to handle corrupted markers (allows chars between final + # and <|message|>) + match = re.search( + r'<\|channel\|>final[^<]*<\|message\|>(.*?)(?:<\|return\|>|$)', + text, re.DOTALL + ) + if match: + return match.group(1).strip() + + # Fallback: return original text if no final section found + return text + + +def strip_markdown_bold(text: str) -> str: + """Remove markdown bold formatting (**text**) from text. + + Args: + text: Text that may contain **bold** formatting + + Returns: + Text with bold markers removed + """ + return re.sub(r'\*\*([^*]+)\*\*', r'\1', text) + + +# ============================================================================= +# Answer Parsing Functions +# ============================================================================= + +def parse_multiple_choice(text: str, max_option: str = 'D') -> Optional[str]: + """Parse multiple choice answer (A-D or A-J). 
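+    Accepts answers such as "Answer: C", "**B**", or a bare letter at the
+    start of the final section (illustrative examples).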
+ + First extracts the final section from harmony-formatted outputs, + then parses the answer from that section only. + """ + text = validate_text_input(text) + if not text: + return None + + # Extract final section first (for harmony format) + final_section = extract_final_section(text) + + # Strip markdown bold formatting (**A** -> A) + final_section = strip_markdown_bold(final_section) + + # Clean artifacts + if final_section.startswith( + ("['", '["')) and final_section.endswith(("']", '"]')): + final_section = final_section[2:-2].strip() + + final_section = final_section.replace(r'\n', '\n').replace(r'\'', "'") + + # Try to extract from final section first + # Priority 1: Single letter answer at start of final section (common in + # harmony format) + single_letter_match = re.match( + rf'^[^a-zA-Z]*([A-{max_option}])(?:[^a-zA-Z]|$)', + final_section.strip(), re.IGNORECASE + ) + if single_letter_match: + return single_letter_match.group(1).upper() + + # Priority 2: "Answer: X" pattern in final section + answer_pattern = rf'\b(?:Answer|ANSWER)\s*[:.]?\s*([A-{max_option}])\b' + answer_match = re.search(answer_pattern, final_section, re.IGNORECASE) + if answer_match: + return answer_match.group(1).upper() + + # Priority 3: Fall back to ANSWER/FINAL ANSWER pattern in full text + # (for backwards compatibility with non-harmony outputs) + full_text = text.replace(r'\n', '\n').replace(r'\'', "'") + pattern = rf"\b(?:ANSWER|FINAL\s*ANSWER)\b\s*[:=]?\s*(?:\(?\s*([A-{max_option}])\s*\)?)(?:\s*$|[^A-Za-z])" + matches = list(re.finditer(pattern, full_text, re.IGNORECASE)) + + if matches: + return matches[-1].group(1).upper() + + # MMLU-Pro fallback: standalone letter in final section + if max_option == 'J': + fallback_matches = list(re.finditer( + rf"\b([A-{max_option}])\b", final_section, re.IGNORECASE)) + if fallback_matches: + return fallback_matches[-1].group(1).upper() + + return None + + +def parse_boxed_math(text: str) -> Optional[str]: + """Parse \\boxed{answer} format.""" + text = validate_text_input(text) + if not text: + return None + + idx = text.rfind(r"\boxed{") + if idx == -1: + return None + + # Find matching brace + depth, i = 0, idx + 7 + content_start = i + while i < len(text): + if text[i] == '{': + depth += 1 + elif text[i] == '}': + if depth == 0: + return text[content_start:i].strip() + depth -= 1 + i += 1 + return None + + +def parse_aime_answer(text: str) -> Optional[int]: + """Parse AIME integer answer (0-999).""" + text = validate_text_input(text) + if not text: + return None + + # Priority 1: \boxed{digits} + boxed_matches = list(re.finditer(r"\\boxed{\s*(\d+)\s*}", text)) + if boxed_matches: + extracted_str = boxed_matches[-1].group(1) + else: + # Priority 2: Answer: + answer_matches = list(re.finditer( + r"Answer:\s*(\d+)(?!\.)\b", text, re.IGNORECASE | re.MULTILINE)) + if not answer_matches: + return None + extracted_str = answer_matches[-1].group(1) + + try: + val = int(extracted_str) + if 0 <= val <= 999: + return val + except ValueError: + pass + + return None + + +def parse_code(text: str) -> Optional[str]: + """Parse code from ```python or plain ``` code block. + + First extracts the final section from harmony-formatted outputs, + then parses code from that section only. This avoids extracting + malformed code blocks from the analysis section. + + Priority: + 1. Code from final section (if harmony format detected) + 2. Last ```python block from full text (fallback) + 3. 
Last plain ``` block from full text (fallback) + """ + text = validate_text_input(text) + if not text: + return None + + # First try to extract from final section (for harmony format) + final_section = extract_final_section(text) + + # Check if we got a different final section (harmony format detected) + if final_section != text: + # Parse code from final section only + python_matches = list( + re.finditer( + r"```python(.*?)```", + final_section, + re.DOTALL)) + if python_matches: + return python_matches[-1].group(1).strip() + + plain_matches = list( + re.finditer( + r"```(.*?)```", + final_section, + re.DOTALL)) + if plain_matches: + code = plain_matches[-1].group(1).strip() + code = re.sub( + r'^(?:python|py)\s*\n', + '', + code, + flags=re.IGNORECASE) + return code + + # Fallback: search full text (for non-harmony outputs or if final section has no code) + # Try ```python blocks first (most specific) + python_matches = list(re.finditer(r"```python(.*?)```", text, re.DOTALL)) + if python_matches: + return python_matches[-1].group(1).strip() + + # Fall back to plain ``` blocks + plain_matches = list(re.finditer(r"```(.*?)```", text, re.DOTALL)) + if plain_matches: + # Get the last match + code = plain_matches[-1].group(1).strip() + # Remove language tag if present (e.g., ```python\n or ```py\n) + code = re.sub(r'^(?:python|py)\s*\n', '', code, flags=re.IGNORECASE) + return code + + return None + + +# ============================================================================= +# Answer Evaluation Functions +# ============================================================================= + +def evaluate_multiple_choice( + parsed: Optional[str], ground_truth: str, valid_options: str) -> bool: + """Evaluate multiple choice answer.""" + if not parsed or not ground_truth: + return False + + parsed = parsed.upper() + ground_truth = ground_truth.upper() + + return parsed in valid_options and parsed == ground_truth + + +def evaluate_math500(parsed: Optional[str], ground_truth: str) -> bool: + """Evaluate MATH-500 using PRM800K grader.""" + if not parsed or not ground_truth: + return False + + parsed = str(parsed).strip() + ground_truth = str(ground_truth) + + if not parsed: + return False + + # Use sys.path approach for proper module importing + workspace_path = os.path.dirname(os.path.abspath(__file__)) + prm800k_module_path = os.path.join( + workspace_path, "submodules", "prm800k", "prm800k") + + if not os.path.exists(prm800k_module_path): + raise FileNotFoundError( + f"PRM800K module not found at: {prm800k_module_path}") + + # Save current directory and sys.path + original_cwd = os.getcwd() + original_syspath = sys.path.copy() + + try: + # Add prm800k module path to sys.path + if prm800k_module_path not in sys.path: + sys.path.insert(0, prm800k_module_path) + + # Change directory as some imports might use relative paths + os.chdir(prm800k_module_path) + + # Now import should work + from grading.grader import grade_answer + result = grade_answer(given_answer=parsed, ground_truth=ground_truth) + except ImportError as e: + raise ImportError(f"Failed to import PRM800K grader: {e}") + finally: + # Always restore original directory and sys.path + os.chdir(original_cwd) + sys.path[:] = original_syspath + + return result + + +def evaluate_aime(parsed: Optional[int], ground_truth: Any) -> bool: + """Evaluate AIME integer answer.""" + if parsed is None: + return False + + try: + gt_int = int(ground_truth) + return int(parsed) == gt_int + except (ValueError, TypeError): + return False + + 
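+# The LiveCodeBench helpers below load the benchmark once per process (cached
+# via lru_cache) and grade extracted solutions against its test cases.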
+@lru_cache(maxsize=1) +def load_lcb_benchmark() -> Dict[str, Any]: + """Load LiveCodeBench benchmark with caching.""" + lcb_dir = os.path.abspath(os.path.join( + os.path.dirname(__file__), "submodules", "LiveCodeBench")) + + if not os.path.isdir(lcb_dir): + raise FileNotFoundError( + f"LiveCodeBench submodule required at: {lcb_dir}") + + original_cwd = os.getcwd() + os.chdir(lcb_dir) + + if lcb_dir not in sys.path: + sys.path.insert(0, lcb_dir) + + try: + os.environ['TQDM_DISABLE'] = '1' + + from lcb_runner.utils.scenarios import Scenario + from lcb_runner.runner.scenario_router import build_prompt_benchmark + + mock_args = argparse.Namespace( + scenario=Scenario.codegeneration, release_version="release_v6", + subset="code_generation", language="python", not_fast=False, + start_date=None, end_date=None, k=[1], num_samples=1, + timeout=60, num_workers=1, num_process_evaluate=1, + model_name="standalone_eval", output_dir="/tmp", + prompt_type="custom", continue_existing=False, evaluate=True + ) + + full_benchmark, _ = build_prompt_benchmark(mock_args) + return {inst.question_id: inst for inst in full_benchmark} + + finally: + os.chdir(original_cwd) + os.environ.pop('TQDM_DISABLE', None) + + +def evaluate_livecodebench(code: Optional[str], question_id: str) -> bool: + """Evaluate LiveCodeBench code generation. + + Returns: + bool: True if all tests passed, False otherwise + """ + result, _ = evaluate_livecodebench_detailed(code, question_id) + return result + + +def evaluate_livecodebench_detailed( + code: Optional[str], question_id: str) -> Tuple[bool, str]: + """Evaluate LiveCodeBench code generation with detailed results. + + Returns: + Tuple[bool, str]: (passed, detailed_reason) + - passed: True if all tests passed, False otherwise + - detailed_reason: Description of test results or error + """ + if not code or not question_id: + return False, "No code or question_id provided" + + lcb_dir = os.path.abspath(os.path.join( + os.path.dirname(__file__), "submodules", "LiveCodeBench")) + + try: + benchmark_map = load_lcb_benchmark() + except Exception as e: + return False, f"Failed to load benchmark: {type(e).__name__}: {e}" + + instance = benchmark_map.get(question_id) + if not instance: + return False, f"Question ID '{question_id}' not found in benchmark" + + original_cwd = os.getcwd() + temp_dir = f"/tmp/temp_lcb_eval_{question_id}_{int(time.time())}" + os.makedirs(temp_dir, exist_ok=True) + + try: + os.chdir(lcb_dir) + os.environ['TQDM_DISABLE'] = '1' + + from lcb_runner.utils.scenarios import Scenario + from lcb_runner.evaluation import extract_instance_results + from lcb_runner.runner.scenario_router import sort_and_extract_save_results, get_metrics + + mock_args = argparse.Namespace( + scenario=Scenario.codegeneration, release_version="release_v6", + subset="code_generation", language="python", not_fast=False, + start_date=None, end_date=None, k=[1], num_samples=1, + timeout=60, num_workers=1, num_process_evaluate=1, + model_name="inline_handler_eval", output_dir=temp_dir, + prompt_type="custom", continue_existing=False, evaluate=True, + ) + + batch_benchmark = [instance] + batch_custom_outputs = [[code]] + + save_results = [inst.insert_output(output, output) + for inst, output in zip(batch_benchmark, batch_custom_outputs)] + + _, combined_results = sort_and_extract_save_results( + mock_args.scenario, save_results) + _, instance_results, _ = get_metrics( + mock_args.scenario, mock_args, batch_benchmark, combined_results + ) + + graded = extract_instance_results(instance_results) + 
passed = graded and graded[0] and graded[0][0] + + # Try to extract detailed results + detailed_reason = "" + try: + if combined_results and len(combined_results) > 0: + result_info = combined_results[0] + if hasattr(result_info, 'result') and result_info.result: + # Extract test results + test_results = result_info.result + if isinstance(test_results, dict): + detailed_reason = f"Test results: {test_results}" + elif isinstance(test_results, list): + num_passed = sum(1 for r in test_results if r) + num_total = len(test_results) + detailed_reason = f"Passed {num_passed}/{num_total} test cases" + else: + detailed_reason = f"Result: {test_results}" + elif hasattr(result_info, 'status'): + detailed_reason = f"Status: {result_info.status}" + except Exception: + pass + + if not detailed_reason: + if passed: + detailed_reason = "All tests passed" + else: + detailed_reason = "Failed one or more test cases" + + return passed, detailed_reason + + except Exception as e: + return False, f"Evaluation error: {type(e).__name__}: {str(e)[:200]}" + finally: + os.chdir(original_cwd) + shutil.rmtree(temp_dir, ignore_errors=True) + os.environ.pop('TQDM_DISABLE', None) + + +def evaluate_livecodebench_worker( + args: Tuple[str, str]) -> Tuple[str, bool, str]: + """Worker function for parallel LiveCodeBench evaluation. + + Returns: + Tuple[str, bool, str]: (question_id, passed, detailed_reason) + """ + code, question_id = args + + # Suppress all stdout/stderr from worker processes to prevent pollution + try: + with open(os.devnull, 'w') as devnull: + with redirect_stdout(devnull), redirect_stderr(devnull): + # Also set environment variable to disable tqdm + os.environ['TQDM_DISABLE'] = '1' + passed, reason = evaluate_livecodebench_detailed( + code, question_id) + return question_id, passed, reason + except Exception as e: + error_msg = f"Error evaluating {question_id}: {type(e).__name__}: {e}" + # Don't use logger here as it might output to stdout in worker process + return question_id, False, error_msg + + +# ============================================================================= +# Dataset Configuration +# ============================================================================= + +DATASET_EVALUATORS = { + 'gpqa': { + 'parse': lambda text: parse_multiple_choice(text, 'D'), + 'evaluate': lambda parsed, gt: evaluate_multiple_choice(parsed, gt, 'ABCD') + }, + 'mmlu_pro': { + 'parse': lambda text: parse_multiple_choice(text, 'J'), + 'evaluate': lambda parsed, gt: evaluate_multiple_choice(parsed, gt, 'ABCDEFGHIJ') + }, + 'math500': { + 'parse': parse_boxed_math, + 'evaluate': evaluate_math500 + }, + 'aime': { + 'parse': parse_aime_answer, + 'evaluate': evaluate_aime + }, + 'livecodebench': { + 'parse': parse_code, + 'evaluate': evaluate_livecodebench + }, + 'mmlu': { + 'parse': lambda text: parse_multiple_choice(text, 'J'), + 'evaluate': lambda parsed, gt: evaluate_multiple_choice(parsed, gt, 'ABCDEFGHIJ') + }, + +} + + +def get_evaluator(dataset_name: str) -> Dict[str, Any]: + """Get evaluator functions for dataset.""" + dataset_lower = validate_dataset_name(dataset_name) + + for key, evaluator in DATASET_EVALUATORS.items(): + if key in dataset_lower: + return evaluator + + raise ValueError(f"No evaluator found for dataset: {dataset_name}") + + +# ============================================================================= +# Main Processing Functions +# ============================================================================= + +def process_row(row: pd.Series) -> Dict[str, Any]: + """Process a 
single row and return extracted answer and accuracy.""" + dataset_name = validate_dataset_name(row['dataset']) + raw_output = validate_text_input(row['model_output_0']) + ground_truth = row['ground_truth'] + + evaluator = get_evaluator(dataset_name) + extracted = evaluator['parse'](raw_output) + + is_correct = False + if extracted is not None and not pd.isna(ground_truth): + is_correct = evaluator['evaluate'](extracted, ground_truth) + + return { + 'extracted_answer': extracted, + 'prompt_accuracy': 100.0 if is_correct else 0.0 + } + + +def process_dataframe(df: pd.DataFrame, + num_lcb_workers: int = 64) -> pd.DataFrame: + """Process entire dataframe with optimized batch processing. + + Args: + df: Input DataFrame to evaluate + num_lcb_workers: Maximum number of parallel workers for LiveCodeBench evaluation + + Supports both single-pass and pass@k formats: + - Single-pass: model_output -> extracted_answer, prompt_accuracy + - Pass@k: model_output_0, model_output_1, ... -> extracted_answer_0, prompt_accuracy_0, ... + and aggregated prompt_accuracy = max(prompt_accuracy_0, prompt_accuracy_1, ...) + """ + validate_dataframe(df) + + df_output = df.copy() + + # Detect pass@k + pass_k = detect_pass_k(df) + logger.info(f"Detected pass@k format with k={pass_k}") + + # Initialize columns for each pass + for pass_num in range(pass_k): + suffix = f'_{pass_num}' + df_output[f'extracted_answer{suffix}'] = None + df_output[f'prompt_accuracy{suffix}'] = 0.0 + df_output[f'evaluation_details{suffix}'] = None + + # Add aggregated columns (max across all passes) + df_output['prompt_accuracy'] = 0.0 + df_output['evaluation_details'] = None + + # Check if we have LiveCodeBench datasets to evaluate + has_livecodebench = any('livecodebench' in str(ds).lower() + for ds in df_output['dataset'].unique()) + + # Pre-load LiveCodeBench benchmark and create shared process pool for all + # LCB evaluations + lcb_executor = None + if has_livecodebench: + try: + logger.info( + "Pre-loading LiveCodeBench benchmark for worker processes...") + # Load benchmark in main process before forking - workers will + # inherit via copy-on-write + _ = load_lcb_benchmark() + logger.info("LiveCodeBench benchmark loaded successfully") + + # Create a single process pool for all LCB evaluations + max_workers = min(multiprocessing.cpu_count(), num_lcb_workers) + lcb_executor = ProcessPoolExecutor(max_workers=max_workers) + logger.info( + f"Created shared ProcessPoolExecutor with {max_workers} workers for LiveCodeBench") + except Exception as e: + logger.warning(f"Failed to pre-load LiveCodeBench benchmark: {e}") + logger.warning("Will fall back to per-evaluation loading") + + try: + # Process by dataset + for dataset_name, group_indices in tqdm(df_output.groupby('dataset').groups.items(), + desc="Processing datasets"): + evaluator = get_evaluator(dataset_name) + + # For LiveCodeBench, always use batched evaluation across all + # passes + is_livecodebench = 'livecodebench' in dataset_name.lower() + if is_livecodebench: + # Validate prerequisites for batched LCB evaluation + if lcb_executor is None: + raise RuntimeError( + "LiveCodeBench evaluation requires a shared executor, but it was not initialized. 
" + "This may indicate the LiveCodeBench benchmark failed to load.") + + # Parse all passes first + logger.info( + f"Parsing {len(group_indices)} rows for dataset '{dataset_name}' across {pass_k} passes") + for pass_num in range(pass_k): + suffix = f'_{pass_num}' + model_output_col = f'model_output{suffix}' + extracted_answer_col = f'extracted_answer{suffix}' + evaluation_details_col = f'evaluation_details{suffix}' + + for idx in group_indices: + row = df_output.loc[idx] + raw_output = validate_text_input(row[model_output_col]) + extracted = evaluator['parse'](raw_output) + df_output.at[idx, extracted_answer_col] = extracted + + if extracted is None or pd.isna(extracted): + df_output.at[idx, + evaluation_details_col] = "No answer extracted from model output" + + # Collect all work items from all passes + all_work_items = [] + work_item_metadata = [] # (idx, pass_num) + for pass_num in range(pass_k): + suffix = f'_{pass_num}' + extracted_answer_col = f'extracted_answer{suffix}' + for idx in group_indices: + row = df_output.loc[idx] + extracted = row.get(extracted_answer_col) + ground_truth = row.get('ground_truth') + + if extracted is not None and not pd.isna(ground_truth): + all_work_items.append((extracted, ground_truth)) + work_item_metadata.append((idx, pass_num)) + + if all_work_items: + # Submit all work at once for maximum parallelism + max_workers = min( + multiprocessing.cpu_count(), len(all_work_items), num_lcb_workers) + logger.info( + f"Evaluating {len(all_work_items)} LiveCodeBench items across {pass_k} passes with {max_workers} workers") + + future_to_metadata = { + lcb_executor.submit(evaluate_livecodebench_worker, work_item): metadata + for work_item, metadata in zip(all_work_items, work_item_metadata) + } + + # Collect results and assign to appropriate pass columns + pass_results = {i: {'correct': 0, 'total': 0} + for i in range(pass_k)} + + for future in tqdm(as_completed(future_to_metadata, timeout=1200), + total=len(future_to_metadata), + desc=f"Evaluating LiveCodeBench (all passes)"): + idx, pass_num = future_to_metadata[future] + suffix = f'_{pass_num}' + prompt_accuracy_col = f'prompt_accuracy{suffix}' + evaluation_details_col = f'evaluation_details{suffix}' + + try: + question_id, is_correct, detailed_reason = future.result( + timeout=80) + df_output.at[idx, + prompt_accuracy_col] = 100.0 if is_correct else 0.0 + df_output.at[idx, + evaluation_details_col] = detailed_reason + pass_results[pass_num]['total'] += 1 + if is_correct: + pass_results[pass_num]['correct'] += 1 + except TimeoutError: + logger.warning( + f"Timeout evaluating row {idx} pass {pass_num}: Test execution exceeded 80s timeout") + df_output.at[idx, prompt_accuracy_col] = 0.0 + df_output.at[idx, + evaluation_details_col] = "Timeout: Test execution exceeded time limit" + pass_results[pass_num]['total'] += 1 + except Exception as e: + logger.error( + f"Error evaluating row {idx} pass {pass_num}: {e}") + df_output.at[idx, prompt_accuracy_col] = 0.0 + df_output.at[idx, + evaluation_details_col] = f"Error: {e}" + pass_results[pass_num]['total'] += 1 + + # Log results for each pass + for pass_num in range(pass_k): + if pass_results[pass_num]['total'] > 0: + correct = pass_results[pass_num]['correct'] + total = pass_results[pass_num]['total'] + accuracy = correct / total * 100 + logger.info( + f"{dataset_name} pass {pass_num} results: {correct}/{total} correct ({accuracy:.1f}% accuracy)") + + else: + # Sequential pass processing for non-LCB datasets + for pass_num in range(pass_k): + suffix = f'_{pass_num}' 
+ model_output_col = f'model_output{suffix}' + extracted_answer_col = f'extracted_answer{suffix}' + prompt_accuracy_col = f'prompt_accuracy{suffix}' + evaluation_details_col = f'evaluation_details{suffix}' + + logger.info( + f"Processing {len(group_indices)} rows for dataset '{dataset_name}', pass {pass_num}") + + # Parse answers for all rows in this dataset for this pass + for idx in group_indices: + row = df_output.loc[idx] + raw_output = validate_text_input(row[model_output_col]) + extracted = evaluator['parse'](raw_output) + df_output.at[idx, extracted_answer_col] = extracted + + # Set initial evaluation details for rows without extracted + # answers + if extracted is None or pd.isna(extracted): + df_output.at[idx, + evaluation_details_col] = "No answer extracted from model output" + + # Evaluate answers for this pass + # Sequential evaluation for all non-LCB datasets + correct_count = 0 + total_evaluated = 0 + + for idx in group_indices: + row = df_output.loc[idx] + extracted = row[extracted_answer_col] + ground_truth = row.get('ground_truth') + + if extracted is not None and not pd.isna(ground_truth): + is_correct = evaluator['evaluate']( + extracted, ground_truth) + df_output.at[idx, + prompt_accuracy_col] = 100.0 if is_correct else 0.0 + total_evaluated += 1 + if is_correct: + correct_count += 1 + + # Log results for this pass + if total_evaluated > 0: + accuracy = correct_count / total_evaluated * 100 + logger.info( + f"{dataset_name} pass {pass_num} results: {correct_count}/{total_evaluated} correct ({accuracy:.1f}% accuracy)") + + # Aggregate results across all passes (take max) + logger.info( + f"Aggregating results across {pass_k} passes for dataset '{dataset_name}'") + for idx in group_indices: + # Get all accuracy values for this row + accuracies = [] + for pass_num in range(pass_k): + acc = df_output.at[idx, f'prompt_accuracy_{pass_num}'] + accuracies.append(acc if not pd.isna(acc) else 0.0) + + # Set aggregated accuracy as max + max_accuracy = max(accuracies) + df_output.at[idx, 'prompt_accuracy'] = max_accuracy + + # Find which pass achieved max accuracy + max_pass = accuracies.index(max_accuracy) + df_output.at[idx, + 'evaluation_details'] = f"Best pass: {max_pass} (accuracy: {max_accuracy:.1f}%)" + + return df_output + finally: + # Clean up shared LiveCodeBench executor + if lcb_executor is not None: + logger.info( + "Shutting down shared LiveCodeBench ProcessPoolExecutor") + lcb_executor.shutdown(wait=True) + + +# ============================================================================= +# Unified Evaluation Utilities +# ============================================================================= + +def print_evaluation_results(df_evaluated: pd.DataFrame, + logger: Optional[logging.Logger] = None) -> Dict[str, Any]: + """Print evaluation results in a unified format. 
+ + Args: + df_evaluated: DataFrame with evaluated results + logger: Optional logger instance (uses module logger if not provided) + + Returns: + Dictionary with evaluation statistics + """ + if logger is None: + logger = logging.getLogger(__name__) + + # Detect pass@k + pass_k = detect_pass_k(df_evaluated) + + # Calculate statistics - always use aggregated prompt_accuracy (max across + # passes) + evaluated = df_evaluated['extracted_answer_0'].notna().sum() + correct = (df_evaluated['prompt_accuracy'] > 0).sum() + accuracy = df_evaluated['prompt_accuracy'].mean() + + # Calculate average token length across all passes + all_output_lens = [] + for i in range(pass_k): + all_output_lens.extend( + df_evaluated[f'tok_model_output_len_{i}'].tolist()) + mean_output_len = float( + sum(all_output_lens) / + len(all_output_lens)) if all_output_lens else 0.0 + + # Use exact_match as the metric key + metric_key = 'exact_match' + + results = { + # 'evaluated': int(evaluated), + # 'correct': int(correct), + metric_key: float(accuracy), + 'tokens_per_sample': mean_output_len, + 'num-samples': len(df_evaluated), + 'pass_k': pass_k, + } + + # Report individual pass accuracies + for i in range(pass_k): + pass_acc = df_evaluated[f'prompt_accuracy_{i}'].mean() + results[f'{metric_key}_pass_{i}'] = float(pass_acc) + + print("\nResults\n") + print(results) + + +def process_and_save_dataframe(df: pd.DataFrame, + output_dir: Optional[Union[str, Path]] = None, + base_filename: Optional[str] = None, + num_lcb_workers: int = 64) -> Tuple[pd.DataFrame, str]: + """Process dataframe for evaluation and save the results. + + Args: + df: Input DataFrame to evaluate + output_dir: Directory to save the evaluated pickle file (defaults to same dir as source) + base_filename: Base filename for output (defaults to auto-generated) + num_lcb_workers: Maximum number of parallel workers for LiveCodeBench evaluation + + Returns: + Tuple of (evaluated_dataframe, saved_file_path) + """ + # Process the dataframe + df_evaluated = process_dataframe(df, num_lcb_workers=num_lcb_workers) + + # Determine output path + if output_dir is None: + # Try to infer from existing path info in the dataframe or use current + # directory + output_dir = Path.cwd() + else: + output_dir = Path(output_dir) + + output_dir.mkdir(parents=True, exist_ok=True) + + # Generate filename if not provided + if base_filename is None: + timestamp = time.strftime("%Y%m%d_%H%M%S") + base_filename = f"results_evaluated_{timestamp}.pkl" + elif not base_filename.endswith('_evaluated.pkl'): + # Ensure it ends with _evaluated.pkl + if base_filename.endswith('.pkl'): + base_filename = base_filename[:-4] + '_evaluated.pkl' + else: + base_filename = base_filename + '_evaluated.pkl' + + output_path = output_dir / base_filename + + # Save the evaluated dataframe + with open(output_path, 'wb') as f: + pickle.dump(df_evaluated, f) + + logger.info(f"Evaluated results saved to: {output_path}") + + return df_evaluated, str(output_path) + + +# ============================================================================= +# Main Function +# ============================================================================= + +def detect_file_type(file_path: Union[str, Path]) -> str: + """Detect whether file is MLPerf JSON or pickle format. 
+ + Returns: + "mlperf_json" or "pickle" + """ + file_path = Path(file_path) + + # Check by extension first + if file_path.suffix.lower() == '.json': + return "mlperf_json" + elif file_path.suffix.lower() in ['.pkl', '.pickle']: + return "pickle" + + # Try to detect by content + try: + # Try reading as JSON first + with open(file_path, 'r') as f: + first_char = f.read(1) + if first_char in ['[', '{']: + # Likely JSON + return "mlperf_json" + except BaseException: + pass + + # Default to pickle + return "pickle" + + +def main(): + parser = argparse.ArgumentParser( + description="Evaluate model outputs - supports both pickle DataFrames and MLPerf JSON logs") + parser.add_argument("--input-file", required=True, + help="Input file (pickle DataFrame or MLPerf JSON log)") + parser.add_argument( + "--output-file", help="Output pickle file (defaults to _evaluated.pkl)") + parser.add_argument("--num-lcb-workers", type=int, default=64, + help="Maximum number of parallel workers for LiveCodeBench evaluation (default: 64)") + parser.add_argument("--verbose", action="store_true", + help="Verbose logging") + + args = parser.parse_args() + + if args.verbose: + logging.getLogger().setLevel(logging.DEBUG) + + if not os.path.exists(args.input_file): + raise FileNotFoundError(f"Input file not found: {args.input_file}") + + input_path = Path(args.input_file) + + # Detect file type + file_type = detect_file_type(input_path) + logger.info(f"Detected input file type: {file_type}") + + # Determine output file path + if args.output_file: + output_path = Path(args.output_file) + output_dir = output_path.parent + output_filename = output_path.name + else: + output_dir = input_path.parent + output_filename = input_path.stem + "_evaluated.pkl" + + logger.info(f"Processing: {args.input_file}") + + # Handle pickle DataFrame format + logger.info("Processing pickle DataFrame file") + + # Load and process data + with open(args.input_file, 'rb') as f: + df = pickle.load(f) + + logger.info(f"Loaded {len(df)} rows") + + # Process and save with unified function + df_evaluated, saved_file_path = process_and_save_dataframe( + df, + output_dir=output_dir, + base_filename=output_filename, + num_lcb_workers=args.num_lcb_workers + ) + + # Print evaluation results with unified function + print_evaluation_results(df_evaluated, logger) + + logger.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/language/gpt-oss-120b/eval_mlperf_accuracy.py b/language/gpt-oss-120b/eval_mlperf_accuracy.py new file mode 100644 index 0000000000..1267c49da4 --- /dev/null +++ b/language/gpt-oss-120b/eval_mlperf_accuracy.py @@ -0,0 +1,713 @@ +#!/usr/bin/env python3 +""" +Evaluate MLPerf accuracy logs for gpt-oss-120b. + +This script takes MLPerf accuracy JSON logs and a reference pickle file, +evaluates the outputs, and generates accuracy scores by dataset and overall. 
+ +Usage: + python eval_mlperf_accuracy.py \ + --mlperf-log mlperf_logs_offline_x8_acc/offline/accuracy/mlperf_log_accuracy.json \ + --reference-data data/accuracy_eval_tokenized_filtered.pkl \ + --output-file accuracy_results.json +""" + +from eval_accuracy import ( + get_evaluator, validate_dataset_name, validate_text_input, DATASET_EVALUATORS, + evaluate_livecodebench_worker, load_lcb_benchmark +) +import argparse +import json +import logging +import pickle +import struct +import multiprocessing +import os +from pathlib import Path +from typing import Dict, Any, List, Tuple +from collections import defaultdict +from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError + +import numpy as np +import pandas as pd +from transformers import AutoTokenizer +from tqdm import tqdm + +# Import evaluation functions from the existing script +import sys +sys.path.insert(0, str(Path(__file__).parent)) + +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + +# Hardcoded repeats per dataset for final score calculation +# Final score = sum(dataset_correct / dataset_repeats) +DATASET_REPEATS = { + 'aime25': 8, + 'gpqa_diamond': 5, + 'livecodebench_v6': 3, +} + + +def load_mlperf_log(log_path: str) -> List[Dict[str, Any]]: + """Load MLPerf accuracy JSON log. + + Args: + log_path: Path to mlperf_log_accuracy.json + + Returns: + List of log entries with seq_id, qsl_idx, data (hex), token_count + """ + logger.info(f"Loading MLPerf log from {log_path}") + with open(log_path, 'r') as f: + log_data = json.load(f) + + logger.info(f"Loaded {len(log_data)} log entries") + + return log_data + + +def decode_hex_to_tokens(hex_data: str) -> List[int]: + """Decode hex string to list of token IDs (int32). + + MLPerf stores token IDs as hex-encoded int32 array. + + Args: + hex_data: Hex string like "450D0300..." + + Returns: + List of token IDs + """ + # Convert hex string to bytes + data_bytes = bytes.fromhex(hex_data) + + # Unpack as int32 array (little-endian) + num_tokens = len(data_bytes) // 4 + token_ids = struct.unpack(f'<{num_tokens}i', data_bytes) + + return list(token_ids) + + +def detokenize(token_ids: List[int], tokenizer) -> str: + """Convert token IDs to text. + + Args: + token_ids: List of integer token IDs + tokenizer: HuggingFace tokenizer + + Returns: + Decoded text string + """ + return tokenizer.decode(token_ids, skip_special_tokens=False) + + +def process_livecodebench_batch( + entries: List[Dict[str, Any]], + reference_df: pd.DataFrame, + tokenizer, + evaluator: Dict[str, Any], + lcb_executor: ProcessPoolExecutor, + dataset_name: str, + args +) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: + """Process a batch of LiveCodeBench entries in parallel. 
+ + Args: + entries: List of MLPerf log entries for this dataset + reference_df: Reference DataFrame + tokenizer: HuggingFace tokenizer + evaluator: Evaluator functions dict + lcb_executor: ProcessPoolExecutor for parallel evaluation + dataset_name: Dataset name + args: Command line arguments + + Returns: + Tuple of (results_list, outputs_list) + """ + # First pass: decode and parse all entries + work_items = [] + # Store (entry, qsl_idx, ref_row, token_ids, model_output) + entry_metadata = [] + + logger.info(f"Parsing {len(entries)} {dataset_name} entries...") + for entry in tqdm(entries, desc=f"Parsing {dataset_name}", unit="entry"): + seq_id = entry['seq_id'] + qsl_idx = entry['qsl_idx'] + hex_data = entry['data'] + + ref_row = reference_df.iloc[qsl_idx] + ground_truth = ref_row.get('ground_truth', None) + + # Decode tokens to text + token_ids = decode_hex_to_tokens(hex_data) + model_output = detokenize(token_ids, tokenizer) + + # Parse code from model output + extracted_code = evaluator['parse'](model_output) + + entry_metadata.append({ + 'entry': entry, + 'qsl_idx': qsl_idx, + 'ref_row': ref_row, + 'token_ids': token_ids, + 'model_output': model_output, + 'extracted_code': extracted_code, + 'ground_truth': ground_truth + }) + + # Add to work queue if code was extracted + if extracted_code is not None and not pd.isna(ground_truth): + work_items.append((extracted_code, ground_truth)) + else: + work_items.append(None) # Placeholder for skipped items + + # Second pass: batch evaluate code in parallel + logger.info( + f"Evaluating {len([w for w in work_items if w is not None])} {dataset_name} code samples with parallel workers...") + + results_list = [] + outputs_list = [] + + # Submit all work items + future_to_idx = {} + for idx, work_item in enumerate(work_items): + if work_item is not None: + future = lcb_executor.submit( + evaluate_livecodebench_worker, work_item) + future_to_idx[future] = idx + + # Collect results with progress bar + eval_results = [None] * len(work_items) + + for future in tqdm(as_completed(future_to_idx.keys(), timeout=1200), + total=len(future_to_idx), + desc=f"Evaluating {dataset_name}", + unit="sample"): + idx = future_to_idx[future] + try: + question_id, is_correct, detailed_reason = future.result( + timeout=80) + eval_results[idx] = (is_correct, detailed_reason) + except TimeoutError: + logger.warning( + f"Timeout evaluating sample {idx}: Test execution exceeded 80s timeout") + eval_results[idx] = ( + False, "Timeout: Test execution exceeded time limit") + except Exception as e: + logger.error(f"Error evaluating sample {idx}: {e}") + eval_results[idx] = (False, f"Error: {e}") + + # Third pass: compile final results + for idx, metadata in enumerate(entry_metadata): + entry = metadata['entry'] + qsl_idx = metadata['qsl_idx'] + token_ids = metadata['token_ids'] + model_output = metadata['model_output'] + extracted_code = metadata['extracted_code'] + ground_truth = metadata['ground_truth'] + + # Get evaluation result + if extracted_code is None or pd.isna(ground_truth): + is_correct = False + eval_details = "No code extracted from model output" if extracted_code is None else "No ground truth available" + else: + is_correct, eval_details = eval_results[idx] + + # Record result + result = { + 'seq_id': entry['seq_id'], + 'qsl_idx': qsl_idx, + 'dataset': dataset_name, + 'is_correct': is_correct, + 'extracted_answer': str(extracted_code)[:200] if extracted_code is not None else None, + 'ground_truth': str(ground_truth) if not pd.isna(ground_truth) else None, + 
'evaluation_details': eval_details, + 'token_count': len(token_ids), + 'model_output_preview': model_output[:200] if args.verbose else None + } + results_list.append(result) + + # Store output data if requested + if args.save_outputs: + output_record = { + 'qsl_idx': qsl_idx, + 'seq_id': entry['seq_id'], + 'dataset': dataset_name, + 'ground_truth': ground_truth, + 'model_output': model_output, + 'output_token_ids': token_ids, + 'extracted_answer': extracted_code, + 'is_correct': is_correct, + 'evaluation_details': eval_details + } + outputs_list.append(output_record) + + return results_list, outputs_list + + +def evaluate_single_entry( + model_output: str, + ground_truth: str, + dataset_name: str +) -> Tuple[bool, Any, str]: + """Evaluate a single model output. + + Args: + model_output: Generated text from model + ground_truth: Expected answer + dataset_name: Dataset name (e.g., 'gpqa', 'math500') + + Returns: + Tuple of (is_correct, extracted_answer, evaluation_details) + """ + evaluator = get_evaluator(dataset_name) + + # Parse answer from model output + extracted = evaluator['parse'](model_output) + + # Evaluate correctness + is_correct = False + evaluation_details = "" + + if extracted is None or pd.isna(extracted): + evaluation_details = "No answer extracted from model output" + else: + if not pd.isna(ground_truth): + try: + is_correct = evaluator['evaluate'](extracted, ground_truth) + if is_correct: + evaluation_details = "Correct" + else: + evaluation_details = f"Incorrect (extracted: {extracted}, ground_truth: {ground_truth})" + except Exception as e: + evaluation_details = f"Evaluation error: {e}" + logger.warning(f"Error evaluating: {e}") + else: + evaluation_details = "No ground truth available" + + return is_correct, extracted, evaluation_details + + +def main(): + parser = argparse.ArgumentParser( + description="Evaluate MLPerf accuracy logs for gpt-oss-120b" + ) + parser.add_argument( + "--mlperf-log", + type=str, + required=True, + help="Path to mlperf_log_accuracy.json" + ) + parser.add_argument( + "--reference-data", + type=str, + required=True, + help="Path to reference parquet or pickle file (DataFrame with dataset, ground_truth, etc.)" + ) + parser.add_argument( + "--tokenizer", + type=str, + default="openai/gpt-oss-120b", + help="HuggingFace tokenizer name or path" + ) + parser.add_argument( + "--output-file", + type=str, + default=None, + help="Output JSON file for results (optional)" + ) + parser.add_argument( + "--save-outputs", + type=str, + default=None, + help="Save detokenized outputs to pickle file (ordered by qsl_idx) for debugging" + ) + parser.add_argument( + "--num-lcb-workers", + type=int, + default=64, + help="Number of parallel workers for LiveCodeBench evaluation (default: 64)" + ) + parser.add_argument( + "--verbose", + action="store_true", + help="Verbose logging" + ) + + args = parser.parse_args() + + if args.verbose: + logging.getLogger().setLevel(logging.DEBUG) + + # Load MLPerf log + mlperf_log = load_mlperf_log(args.mlperf_log) + + # Load reference data + logger.info(f"Loading reference data from {args.reference_data}") + if args.reference_data.endswith('.parquet'): + reference_df = pd.read_parquet(args.reference_data) + logger.info("Loaded reference data from Parquet file") + elif args.reference_data.endswith('.pkl') or args.reference_data.endswith('.pickle'): + with open(args.reference_data, 'rb') as f: + reference_df = pickle.load(f) + logger.info("Loaded reference data from Pickle file") + else: + # Try parquet first, then pickle + try: + 
reference_df = pd.read_parquet(args.reference_data) + logger.info("Auto-detected Parquet format") + except Exception: + with open(args.reference_data, 'rb') as f: + reference_df = pickle.load(f) + logger.info("Auto-detected Pickle format") + + # Convert numpy arrays to native Python types for JSON serialization + for col in reference_df.columns: + # Check if column contains numpy arrays + if reference_df[col].dtype == object: + reference_df[col] = reference_df[col].apply( + lambda x: x.tolist() if isinstance(x, np.ndarray) else x + ) + + logger.info(f"Reference data shape: {reference_df.shape}") + logger.info(f"Reference columns: {list(reference_df.columns)}") + + # Validate required columns exist + required_columns = ['dataset', 'ground_truth'] + missing_columns = [ + col for col in required_columns if col not in reference_df.columns] + if missing_columns: + raise ValueError( + f"Reference data missing required columns: {missing_columns}") + + # Log unique datasets in reference data + if 'dataset' in reference_df.columns: + unique_datasets = reference_df['dataset'].unique() + dataset_counts = reference_df['dataset'].value_counts() + logger.info( + f"Unique datasets in reference data ({len(unique_datasets)} total):") + for ds in sorted(unique_datasets): + logger.info(f" '{ds}' ({dataset_counts[ds]} samples)") + + logger.info("\nSample rows from reference data:") + for idx in [0, 1, 2]: + if idx < len(reference_df): + logger.info( + f" Row {idx}: dataset='{reference_df.iloc[idx]['dataset']}'") + + # Show how each will be mapped to evaluators + logger.info("\nExpected Dataset → Evaluator mapping:") + for ds in sorted(unique_datasets): + try: + ds_lower = validate_dataset_name(ds) + # Find which evaluator key matches + matched_key = None + for key in DATASET_EVALUATORS.keys(): + if key in ds_lower: + matched_key = key + break + logger.info( + f" '{ds}' (normalized: '{ds_lower}') → '{matched_key}'") + except Exception as e: + logger.warning(f" '{ds}' → ERROR: {e}") + + # Load tokenizer + logger.info(f"Loading tokenizer: {args.tokenizer}") + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer) + + # Group MLPerf log entries by dataset + logger.info("Grouping MLPerf log entries by dataset...") + dataset_entries = defaultdict(list) + + for entry in mlperf_log: + qsl_idx = entry['qsl_idx'] + + if qsl_idx >= len(reference_df): + logger.warning( + f"qsl_idx {qsl_idx} out of range (max: {len(reference_df)-1})") + continue + + ref_row = reference_df.iloc[qsl_idx] + dataset_name = validate_dataset_name(ref_row['dataset']) + dataset_entries[dataset_name].append(entry) + + logger.info(f"Grouped entries by dataset:") + total_entries = 0 + for ds_name, entries in sorted(dataset_entries.items()): + logger.info(f" {ds_name}: {len(entries)} entries") + total_entries += len(entries) + logger.info(f"Total entries: {total_entries}") + + # Pre-load LiveCodeBench benchmark if needed + lcb_executor = None + if any('livecodebench' in ds for ds in dataset_entries.keys()): + try: + logger.info( + "Pre-loading LiveCodeBench benchmark for parallel evaluation...") + os.environ['TQDM_DISABLE'] = '1' # Disable tqdm in workers + _ = load_lcb_benchmark() + logger.info("LiveCodeBench benchmark loaded successfully") + + # Create shared ProcessPoolExecutor for all LCB evaluations + max_workers = min( + multiprocessing.cpu_count(), + args.num_lcb_workers) + lcb_executor = ProcessPoolExecutor(max_workers=max_workers) + logger.info( + f"Created ProcessPoolExecutor with {max_workers} workers for LiveCodeBench") + except 
Exception as e: + logger.warning(f"Failed to pre-load LiveCodeBench benchmark: {e}") + logger.warning("LiveCodeBench evaluation may be slower") + + # Process each dataset separately with its own progress bar + logger.info("\nProcessing MLPerf log entries by dataset...") + + results = [] + # Track stats per dataset (simple correct/total) + dataset_stats = defaultdict(lambda: {"correct": 0, "total": 0}) + outputs_data = [] # For saving detokenized outputs + + try: + for dataset_name in sorted(dataset_entries.keys()): + entries = dataset_entries[dataset_name] + logger.info(f"\n{'=' * 80}") + logger.info(f"Processing {dataset_name}: {len(entries)} samples") + logger.info(f"{'=' * 80}") + + evaluator = get_evaluator(dataset_name) + is_livecodebench = 'livecodebench' in dataset_name.lower() + + if is_livecodebench and lcb_executor is not None: + # Batched LiveCodeBench evaluation + results_batch, outputs_batch = process_livecodebench_batch( + entries, reference_df, tokenizer, evaluator, + lcb_executor, dataset_name, args + ) + results.extend(results_batch) + if args.save_outputs: + outputs_data.extend(outputs_batch) + + # Update stats + for res in results_batch: + dataset_stats[dataset_name]["total"] += 1 + if res['is_correct']: + dataset_stats[dataset_name]["correct"] += 1 + else: + # Sequential evaluation for non-LCB datasets + for entry in tqdm( + entries, desc=f"Evaluating {dataset_name}", unit="entry"): + seq_id = entry['seq_id'] + qsl_idx = entry['qsl_idx'] + hex_data = entry['data'] + + ref_row = reference_df.iloc[qsl_idx] + ground_truth = ref_row.get('ground_truth', None) + + # Decode tokens to text + token_ids = decode_hex_to_tokens(hex_data) + model_output = detokenize(token_ids, tokenizer) + + # Evaluate + try: + is_correct, extracted, eval_details = evaluate_single_entry( + model_output, ground_truth, dataset_name + ) + except Exception as e: + logger.warning( + f"Evaluation error for qsl_idx={qsl_idx}, dataset={dataset_name}: {e}") + is_correct = False + extracted = None + eval_details = f"Evaluation error: {e}" + + # Record result + result = { + 'seq_id': seq_id, + 'qsl_idx': qsl_idx, + 'dataset': dataset_name, + 'is_correct': is_correct, + 'extracted_answer': str(extracted) if extracted is not None else None, + 'ground_truth': str(ground_truth) if not pd.isna(ground_truth) else None, + 'evaluation_details': eval_details, + 'token_count': len(token_ids), + 'model_output_preview': model_output[:200] if args.verbose else None + } + results.append(result) + + # Store output data for pickle export + if args.save_outputs: + output_record = { + 'qsl_idx': qsl_idx, + 'seq_id': seq_id, + 'dataset': dataset_name, + 'ground_truth': ground_truth, + 'model_output': model_output, + 'output_token_ids': token_ids, + 'extracted_answer': extracted, + 'is_correct': is_correct, + 'evaluation_details': eval_details + } + outputs_data.append(output_record) + + # Update stats + dataset_stats[dataset_name]["total"] += 1 + if is_correct: + dataset_stats[dataset_name]["correct"] += 1 + + finally: + # Clean up LiveCodeBench executor + if lcb_executor is not None: + logger.info("Shutting down LiveCodeBench ProcessPoolExecutor") + lcb_executor.shutdown(wait=True) + os.environ.pop('TQDM_DISABLE', None) + + # Calculate per-dataset scores and final score + # Final score = sum(dataset_correct / dataset_repeats) + logger.info("\nCalculating final scores...") + + total_correct = sum(stats["correct"] for stats in dataset_stats.values()) + total_samples = sum(stats["total"] for stats in dataset_stats.values()) + 
overall_accuracy = ( + total_correct / + total_samples * + 100) if total_samples > 0 else 0.0 + + # Calculate weighted final score + final_score = 0.0 + max_score = 0.0 + final_score_components = {} + for dataset_name, stats in dataset_stats.items(): + repeats = DATASET_REPEATS.get(dataset_name, 1) + component_score = stats["correct"] / repeats + max_component_score = stats["total"] / repeats + final_score += component_score + max_score += max_component_score + final_score_components[dataset_name] = { + "correct": stats["correct"], + "total": stats["total"], + "repeats": repeats, + "component_score": component_score, + "max_component_score": max_component_score + } + + final_score_percentage = ( + final_score / + max_score * + 100) if max_score > 0 else 0.0 + + # Print results + print("\n" + "=" * 80) + print("MLPerf Accuracy Evaluation Results") + print("=" * 80) + print(f"Total samples evaluated: {total_samples}") + print( + f"Overall raw accuracy: {overall_accuracy:.2f}% ({total_correct}/{total_samples})") + print("=" * 80) + + print("\nPer-Dataset Breakdown:") + print("-" * 80) + print(f"{'Dataset':25s} {'Correct':>8s} {'Total':>8s} {'Repeats':>8s} {'Score':>10s} {'Accuracy':>10s}") + print("-" * 80) + for dataset_name in sorted(dataset_stats.keys()): + stats = dataset_stats[dataset_name] + if stats["total"] > 0: + accuracy = (stats["correct"] / stats["total"] * 100) + repeats = DATASET_REPEATS.get(dataset_name, 1) + component_score = stats["correct"] / repeats + print( + f"{dataset_name:25s} {stats['correct']:8d} {stats['total']:8d} {repeats:8d} {component_score:10.2f} {accuracy:9.2f}%") + + print("=" * 80) + print(f"\nFinal Score Calculation:") + print("-" * 80) + score_parts = [] + value_parts = [] + result_parts = [] + max_parts = [] + for dataset_name in sorted(final_score_components.keys()): + comp = final_score_components[dataset_name] + score_parts.append(f"{dataset_name}/{comp['repeats']}") + value_parts.append(f"{comp['correct']}/{comp['repeats']}") + result_parts.append(f"{comp['component_score']:.2f}") + max_parts.append(f"{comp['total']}/{comp['repeats']}") + print(f"Formula: {' + '.join(score_parts)}") + print(f"Score: = {' + '.join(value_parts)}") + print(f" = {' + '.join(result_parts)}") + print(f" = {final_score:.2f}") + print(f"Max: = {' + '.join(max_parts)}") + print(f" = {max_score:.2f}") + print( + f"\nFINAL SCORE: {final_score_percentage:.2f}% ({final_score:.2f}/{max_score:.2f})") + print("=" * 80) + + print("\n\nPrinting for submission_checker:") + print(f"\n'exact_match': {final_score}") + + # Save detokenized outputs to pickle if requested + if args.save_outputs: + logger.info(f"Saving detokenized outputs to {args.save_outputs}...") + + # Sort by qsl_idx for ordered output + outputs_data_sorted = sorted(outputs_data, key=lambda x: x['qsl_idx']) + + # Convert to DataFrame for easier inspection + outputs_df = pd.DataFrame(outputs_data_sorted) + + output_path = Path(args.save_outputs) + output_path.parent.mkdir(parents=True, exist_ok=True) + + with open(output_path, 'wb') as f: + pickle.dump(outputs_df, f) + + logger.info( + f"Saved {len(outputs_df)} detokenized outputs (ordered by qsl_idx) to: {output_path}") + logger.info(f"Columns: {list(outputs_df.columns)}") + + # Save detailed results if requested + if args.output_file: + # Build per-dataset stats + per_dataset_stats = {} + for dataset_name, stats in dataset_stats.items(): + repeats = DATASET_REPEATS.get(dataset_name, 1) + component_score = stats["correct"] / repeats + max_component_score = stats["total"] / 
repeats + per_dataset_stats[dataset_name] = { + "correct": stats["correct"], + "total": stats["total"], + "accuracy": (stats["correct"] / stats["total"] * 100) if stats["total"] > 0 else 0.0, + "repeats": repeats, + "component_score": component_score, + "max_component_score": max_component_score + } + + summary = { + "total_samples": total_samples, + "total_correct": total_correct, + "overall_accuracy": overall_accuracy, + "final_score": final_score, + "max_score": max_score, + "final_score_percentage": final_score_percentage, + "dataset_repeats": DATASET_REPEATS, + "per_dataset": per_dataset_stats + } + + output_data = { + "summary": summary, + "detailed_results": results if args.verbose else None + } + + output_path = Path(args.output_file) + output_path.parent.mkdir(parents=True, exist_ok=True) + + with open(output_path, 'w') as f: + json.dump(output_data, f, indent=2) + + logger.info(f"Results saved to: {output_path}") + + logger.info("Evaluation complete!") + + +if __name__ == "__main__": + main() diff --git a/language/gpt-oss-120b/eval_mlperf_performance.py b/language/gpt-oss-120b/eval_mlperf_performance.py new file mode 100755 index 0000000000..aa3c275f37 --- /dev/null +++ b/language/gpt-oss-120b/eval_mlperf_performance.py @@ -0,0 +1,645 @@ +#!/usr/bin/env python3 +"""Evaluate MLPerf performance logs and analyze output token lengths. + +This script reads MLPerf accuracy logs (mlperf_log_accuracy.json) and +detokenizes the hex-encoded token IDs to produce human-readable text output. +Optionally includes input prompts and reference data from a pickle file, +and generates histogram plots for token length analysis. + +Usage: + # Basic usage (outputs only) + python eval_mlperf_performance.py \ + --mlperf-log mlperf_logs/offline/accuracy/mlperf_log_accuracy.json \ + --output-file detokenized_outputs.json \ + --tokenizer openai/gpt-oss-120b + + # With reference data (includes inputs and metadata) + python eval_mlperf_performance.py \ + --mlperf-log mlperf_logs/offline/accuracy/mlperf_log_accuracy.json \ + --output-file detokenized_outputs.json \ + --reference-data data/accuracy_eval_tokenized_filtered.pkl \ + --tokenizer openai/gpt-oss-120b + + # With histogram plots (enables plotting when --plot-dir is specified) + python eval_mlperf_performance.py \ + --mlperf-log mlperf_logs/offline/accuracy/mlperf_log_accuracy.json \ + --output-file detokenized_outputs.json \ + --reference-data data/accuracy_eval_tokenized_filtered.pkl \ + --plot-dir plots + +The output JSON format (with reference data): + [ + { + "qsl_idx": 0, + "token_ids": [1, 2, 3, ...], + "text": "detokenized response text", + "num_tokens": 150, + "dataset": "gpqa", + "input_prompt": "Question: ...", + "input_token_ids": [...], + "num_input_tokens": 1024, + "ground_truth": "Answer" + }, + ... 
+ ] +""" + +from tqdm import tqdm +from transformers import AutoTokenizer +import matplotlib.pyplot as plt +import argparse +import json +import logging +import pickle +import sys +from pathlib import Path +from typing import List, Dict, Any, Optional + +import pandas as pd +import matplotlib +matplotlib.use('Agg') # Non-interactive backend for server environments + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + + +def parse_args() -> argparse.Namespace: + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Parse MLPerf accuracy JSON and detokenize responses" + ) + + parser.add_argument( + "--mlperf-log", + type=str, + required=True, + help="Path to mlperf_log_accuracy.json file" + ) + + parser.add_argument( + "--output-file", + type=str, + required=True, + help="Path to output JSON file with detokenized responses" + ) + + parser.add_argument( + "--reference-data", + type=str, + default=None, + help="Path to reference parquet or pickle file (DataFrame with prompts, dataset, etc.) - optional" + ) + + parser.add_argument( + "--tokenizer", + type=str, + default="openai/gpt-oss-120b", + help="Tokenizer to use for detokenization (default: openai/gpt-oss-120b)" + ) + + parser.add_argument( + "--pretty", + action="store_true", + help="Pretty-print the output JSON with indentation" + ) + + parser.add_argument( + "--plot-dir", + type=str, + default=None, + help="Directory to save histogram plots (enables plotting if specified)" + ) + + return parser.parse_args() + + +def decode_hex_to_tokens(hex_string: str) -> List[int]: + """Decode hex-encoded byte array to list of token IDs. + + MLPerf stores token IDs as hex-encoded bytes where each token is a 4-byte + little-endian integer. + + Args: + hex_string: Hex-encoded string from MLPerf log + + Returns: + List of token IDs + """ + # Remove any whitespace + hex_string = hex_string.strip() + + # Convert hex string to bytes + byte_data = bytes.fromhex(hex_string) + + # Each token is stored as 4 bytes (int32) in little-endian format + token_ids = [] + for i in range(0, len(byte_data), 4): + if i + 4 <= len(byte_data): + # Unpack 4 bytes as little-endian int32 + token_id = int.from_bytes( + byte_data[i:i + 4], byteorder='little', signed=True) + token_ids.append(token_id) + + return token_ids + + +def parse_mlperf_log(log_path: str) -> List[Dict[str, Any]]: + """Parse MLPerf accuracy log file. + + Handles multiple formats: + - JSON array: [{"qsl_idx": 0, ...}, ...] 
+ - JSONL: one JSON object per line + - Concatenated JSON: multiple JSON objects on same line + + Args: + log_path: Path to mlperf_log_accuracy.json + + Returns: + List of entries with qsl_idx and hex-encoded data + """ + logger.info(f"Reading MLPerf log: {log_path}") + + entries = [] + + # First try to load as a single JSON array + try: + with open(log_path, 'r') as f: + log_data = json.load(f) + if isinstance(log_data, list): + logger.info(f"Loaded {len(log_data)} entries as JSON array") + return log_data + except json.JSONDecodeError: + pass # Not a valid JSON array, try line-by-line parsing + + # Parse line by line (JSONL or concatenated JSON) + decoder = json.JSONDecoder() + with open(log_path, 'r') as f: + for line_num, line in enumerate(f, 1): + line = line.strip() + if not line: + continue + + # Try to parse as single JSON object first + try: + entry = json.loads(line) + entries.append(entry) + except json.JSONDecodeError: + # Line might have multiple concatenated JSON objects + # Extract them one by one using raw_decode + remaining = line + parsed_count = 0 + while remaining: + remaining = remaining.lstrip() + if not remaining: + break + try: + obj, end_idx = decoder.raw_decode(remaining) + entries.append(obj) + remaining = remaining[end_idx:] + parsed_count += 1 + except json.JSONDecodeError as e: + if parsed_count == 0: + logger.warning( + f"Line {line_num}: Could not parse JSON: {e}") + break + + logger.info(f"Loaded {len(entries)} entries from MLPerf log") + return entries + + +def plot_histograms( + results: List[Dict[str, Any]], + output_dir: str, + has_reference: bool = False +) -> None: + """Generate histogram plots for output token lengths and differences. + + Args: + results: List of parsed results with token lengths + output_dir: Directory to save plots + has_reference: Whether reference data is available for difference plots + """ + output_path = Path(output_dir) + output_path.mkdir(parents=True, exist_ok=True) + + logger.info(f"Generating histogram plots in {output_dir}...") + + # Extract output token lengths + output_lengths = [r['num_tokens'] for r in results] + + # Plot 1: Output Sequence Length (OSL) Histogram + plt.figure(figsize=(12, 6)) + plt.hist( + output_lengths, + bins=50, + edgecolor='black', + alpha=0.7, + color='steelblue') + plt.xlabel('Output Token Length (OSL)', fontsize=12) + plt.ylabel('Frequency', fontsize=12) + plt.title( + f'Distribution of Output Token Lengths\n(n={len(output_lengths)}, mean={sum(output_lengths)/len(output_lengths):.1f}, median={sorted(output_lengths)[len(output_lengths)//2]})', + fontsize=14) + plt.grid(axis='y', alpha=0.3) + + # Add statistics box + stats_text = f'Min: {min(output_lengths)}\nMax: {max(output_lengths)}\nStd: {(sum((x - sum(output_lengths)/len(output_lengths))**2 for x in output_lengths) / len(output_lengths))**0.5:.1f}' + plt.text(0.98, 0.97, stats_text, transform=plt.gca().transAxes, + fontsize=10, verticalalignment='top', horizontalalignment='right', + bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5)) + + osl_plot_path = output_path / 'output_token_length_histogram.png' + plt.tight_layout() + plt.savefig(osl_plot_path, dpi=300, bbox_inches='tight') + plt.close() + logger.info(f"✓ Saved OSL histogram: {osl_plot_path}") + + # Plot 2: Token Length Difference Histogram (if reference data available) + if has_reference: + results_with_diff = [ + r for r in results if 'output_token_len_diff' in r] + if results_with_diff: + differences = [r['output_token_len_diff'] + for r in results_with_diff] + + 
plt.figure(figsize=(12, 6)) + plt.hist( + differences, + bins=50, + edgecolor='black', + alpha=0.7, + color='coral') + plt.xlabel( + 'Token Length Difference (Actual - Reference)', + fontsize=12) + plt.ylabel('Frequency', fontsize=12) + + mean_diff = sum(differences) / len(differences) + median_diff = sorted(differences)[len(differences) // 2] + plt.title( + f'Distribution of Output Token Length Differences\n(n={len(differences)}, mean={mean_diff:.1f}, median={median_diff})', + fontsize=14) + plt.grid(axis='y', alpha=0.3) + plt.axvline( + x=0, + color='red', + linestyle='--', + linewidth=2, + label='Zero difference') + + # Add statistics box + longer = sum(1 for d in differences if d > 0) + shorter = sum(1 for d in differences if d < 0) + exact = sum(1 for d in differences if d == 0) + stats_text = f'Min: {min(differences)}\nMax: {max(differences)}\nStd: {(sum((x - mean_diff)**2 for x in differences) / len(differences))**0.5:.1f}\n\nLonger: {longer} ({longer/len(differences)*100:.1f}%)\nShorter: {shorter} ({shorter/len(differences)*100:.1f}%)\nExact: {exact} ({exact/len(differences)*100:.1f}%)' + plt.text(0.98, 0.97, stats_text, transform=plt.gca().transAxes, + fontsize=9, verticalalignment='top', horizontalalignment='right', + bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.5)) + + plt.legend() + + diff_plot_path = output_path / 'token_length_difference_histogram.png' + plt.tight_layout() + plt.savefig(diff_plot_path, dpi=300, bbox_inches='tight') + plt.close() + logger.info(f"✓ Saved difference histogram: {diff_plot_path}") + + # Plot 3: Combined comparison (side by side) + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6)) + + # Reference vs Actual + ref_lengths = [r['ref_num_output_tokens'] + for r in results_with_diff] + actual_lengths = [r['actual_num_output_tokens'] + for r in results_with_diff] + + ax1.hist([ref_lengths, actual_lengths], bins=50, label=['Reference', 'Actual'], + alpha=0.6, edgecolor='black', color=['steelblue', 'coral']) + ax1.set_xlabel('Output Token Length', fontsize=12) + ax1.set_ylabel('Frequency', fontsize=12) + ax1.set_title( + f'Reference vs Actual Output Token Lengths\n(n={len(results_with_diff)})', + fontsize=13) + ax1.legend() + ax1.grid(axis='y', alpha=0.3) + + # Scatter plot: Reference vs Actual + ax2.scatter( + ref_lengths, + actual_lengths, + alpha=0.4, + s=10, + color='purple') + ax2.plot([min(ref_lengths), max(ref_lengths)], [min(ref_lengths), max(ref_lengths)], + 'r--', linewidth=2, label='y=x (perfect match)') + ax2.set_xlabel('Reference Token Length', fontsize=12) + ax2.set_ylabel('Actual Token Length', fontsize=12) + ax2.set_title( + 'Reference vs Actual Token Lengths (Scatter)', + fontsize=13) + ax2.legend() + ax2.grid(alpha=0.3) + + comparison_plot_path = output_path / 'token_length_comparison.png' + plt.tight_layout() + plt.savefig(comparison_plot_path, dpi=300, bbox_inches='tight') + plt.close() + logger.info(f"✓ Saved comparison plot: {comparison_plot_path}") + else: + logger.warning("No samples with token length differences found") + + logger.info(f"✓ All plots saved to {output_dir}/") + + +def detokenize_responses( + entries: List[Dict[str, Any]], + tokenizer: Any, + reference_df: Optional[pd.DataFrame] = None +) -> List[Dict[str, Any]]: + """Detokenize responses from MLPerf log entries. + + When reference data is provided, input_prompt is generated by detokenizing + input token IDs from the reference data (checks: tok_input, input_token_ids, + input_tokens, tokenized_input). 
This shows exactly what was sent to the model + (after tokenization), not the original text prompt. + + Args: + entries: List of MLPerf log entries with hex-encoded token IDs + tokenizer: HuggingFace tokenizer instance + reference_df: Optional reference DataFrame with input prompts and metadata + + Returns: + List of dictionaries with qsl_idx, token_ids, and detokenized text + """ + logger.info("Detokenizing responses...") + + results = [] + for entry in tqdm(entries, desc="Detokenizing", unit="response"): + qsl_idx = entry.get("qsl_idx") + hex_data = entry.get("data", "") + + # Decode hex to token IDs + try: + token_ids = decode_hex_to_tokens(hex_data) + except Exception as e: + logger.error(f"Error decoding tokens for qsl_idx={qsl_idx}: {e}") + token_ids = [] + + # Detokenize to text + try: + text = tokenizer.decode(token_ids, skip_special_tokens=True) + except Exception as e: + logger.error(f"Error detokenizing qsl_idx={qsl_idx}: {e}") + text = "" + + # Build result record + result = { + "qsl_idx": qsl_idx, + "token_ids": token_ids, + "text": text, + "num_tokens": len(token_ids) + } + + # Add reference data if available + if reference_df is not None and qsl_idx < len(reference_df): + ref_row = reference_df.iloc[qsl_idx] + + # Add common fields from reference data + if 'dataset' in ref_row: + result['dataset'] = ref_row['dataset'] + + # Get input token IDs and detokenize to see what was actually sent to the model + # Check multiple possible field names for input tokens + input_token_ids = None + for field in ['tok_input', 'input_token_ids', + 'input_tokens', 'tokenized_input']: + if field in ref_row: + input_token_ids = ref_row[field] + break + + if input_token_ids is not None: + result['input_token_ids'] = input_token_ids + if isinstance(input_token_ids, list): + result['num_input_tokens'] = len(input_token_ids) + # Detokenize input tokens to show what was actually sent to + # the model + try: + result['input_prompt'] = tokenizer.decode( + input_token_ids, skip_special_tokens=False) + except Exception as e: + logger.warning( + f"Error detokenizing input tokens for qsl_idx={qsl_idx}: {e}") + result['input_prompt'] = None + else: + result['num_input_tokens'] = None + result['input_prompt'] = None + else: + # Fallback to raw prompt field if input token IDs not available + if 'prompt' in ref_row: + result['input_prompt'] = ref_row['prompt'] + elif 'input_text' in ref_row: + result['input_prompt'] = ref_row['input_text'] + elif 'text' in ref_row: + result['input_prompt'] = ref_row['text'] + + if 'ground_truth' in ref_row: + result['ground_truth'] = ref_row['ground_truth'] + + # Compute output token length difference + # Check for reference output token length in various possible field + # names + ref_output_len = None + for field in ['output_token_ids', 'target_token_ids', + 'output_tokens', 'expected_output_token_ids']: + if field in ref_row: + ref_tokens = ref_row[field] + if isinstance(ref_tokens, list): + ref_output_len = len(ref_tokens) + result['ref_output_token_ids'] = ref_tokens + break + elif isinstance(ref_tokens, (int, float)) and not pd.isna(ref_tokens): + ref_output_len = int(ref_tokens) + break + + # Also check for direct length field + if ref_output_len is None: + for field in ['output_len', 'output_length', + 'num_output_tokens', 'target_len']: + if field in ref_row and not pd.isna(ref_row[field]): + ref_output_len = int(ref_row[field]) + break + + if ref_output_len is not None: + actual_output_len = len(token_ids) + result['ref_num_output_tokens'] = ref_output_len + 
result['actual_num_output_tokens'] = actual_output_len + result['output_token_len_diff'] = actual_output_len - \ + ref_output_len + result['output_token_len_ratio'] = actual_output_len / \ + ref_output_len if ref_output_len > 0 else None + + # Add any other columns that might be useful + for col in ['question_id', 'difficulty', 'subject', 'category']: + if col in ref_row: + result[col] = ref_row[col] + + results.append(result) + + return results + + +def main(): + """Main function.""" + args = parse_args() + + # Validate input file exists + log_path = Path(args.mlperf_log) + if not log_path.exists(): + logger.error(f"MLPerf log file not found: {args.mlperf_log}") + sys.exit(1) + + logger.info("=" * 80) + logger.info("MLPerf Accuracy Log Parser") + logger.info("=" * 80) + logger.info(f"Input log: {args.mlperf_log}") + logger.info(f"Output file: {args.output_file}") + logger.info( + f"Reference data: {args.reference_data if args.reference_data else 'None (outputs only)'}") + logger.info(f"Tokenizer: {args.tokenizer}") + logger.info("=" * 80) + + # Load reference data if provided + reference_df = None + if args.reference_data: + logger.info(f"Loading reference data from {args.reference_data}") + try: + if args.reference_data.endswith('.parquet'): + reference_df = pd.read_parquet(args.reference_data) + logger.info("Loaded reference data from Parquet file") + elif args.reference_data.endswith('.pkl') or args.reference_data.endswith('.pickle'): + with open(args.reference_data, 'rb') as f: + reference_df = pickle.load(f) + logger.info("Loaded reference data from Pickle file") + else: + # Try parquet first, then pickle + try: + reference_df = pd.read_parquet(args.reference_data) + logger.info("Auto-detected Parquet format") + except Exception: + with open(args.reference_data, 'rb') as f: + reference_df = pickle.load(f) + logger.info("Auto-detected Pickle format") + + logger.info(f"✓ Reference data loaded: {reference_df.shape}") + logger.info(f" Columns: {list(reference_df.columns)}") + except Exception as e: + logger.error(f"Failed to load reference data: {e}") + sys.exit(1) + + # Load tokenizer + logger.info(f"Loading tokenizer: {args.tokenizer}") + try: + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer) + logger.info("✓ Tokenizer loaded successfully") + except Exception as e: + logger.error(f"Failed to load tokenizer: {e}") + sys.exit(1) + + # Parse MLPerf log + try: + entries = parse_mlperf_log(args.mlperf_log) + except Exception as e: + logger.error(f"Failed to parse MLPerf log: {e}") + sys.exit(1) + + if not entries: + logger.error("No entries found in MLPerf log") + sys.exit(1) + + # Detokenize responses + try: + results = detokenize_responses(entries, tokenizer, reference_df) + except Exception as e: + logger.error(f"Failed to detokenize responses: {e}") + sys.exit(1) + + # Write output JSON + logger.info(f"Writing detokenized outputs to: {args.output_file}") + output_path = Path(args.output_file) + output_path.parent.mkdir(parents=True, exist_ok=True) + + with open(output_path, 'w') as f: + if args.pretty: + json.dump(results, f, indent=2, ensure_ascii=False) + else: + json.dump(results, f, ensure_ascii=False) + + logger.info("=" * 80) + logger.info("✓ Parsing completed successfully") + logger.info("=" * 80) + logger.info(f"Total responses parsed: {len(results)}") + + # Print statistics + total_tokens = sum(r["num_tokens"] for r in results) + avg_tokens = total_tokens / len(results) if results else 0 + logger.info(f"Total output tokens: {total_tokens:,}") + logger.info(f"Average 
tokens per response: {avg_tokens:.1f}") + + # Print token length difference statistics if reference data was provided + if reference_df is not None: + results_with_diff = [ + r for r in results if 'output_token_len_diff' in r] + if results_with_diff: + diffs = [r['output_token_len_diff'] for r in results_with_diff] + ratios = [r['output_token_len_ratio'] + for r in results_with_diff if r['output_token_len_ratio'] is not None] + + logger.info( + f"\nOutput Token Length Analysis ({len(results_with_diff)} samples with reference):") + logger.info( + f" Mean difference (actual - ref): {sum(diffs) / len(diffs):.2f} tokens") + logger.info(f" Min difference: {min(diffs)} tokens") + logger.info(f" Max difference: {max(diffs)} tokens") + if ratios: + logger.info( + f" Mean ratio (actual / ref): {sum(ratios) / len(ratios):.3f}x") + + # Count samples that are longer/shorter + longer = sum(1 for d in diffs if d > 0) + shorter = sum(1 for d in diffs if d < 0) + exact = sum(1 for d in diffs if d == 0) + logger.info( + f" Longer than reference: {longer} ({longer/len(diffs)*100:.1f}%)") + logger.info( + f" Shorter than reference: {shorter} ({shorter/len(diffs)*100:.1f}%)") + logger.info( + f" Exact match: {exact} ({exact/len(diffs)*100:.1f}%)") + + logger.info("=" * 80) + + # Show sample output + if results: + logger.info("Sample output (first entry):") + sample = results[0] + logger.info(f" qsl_idx: {sample['qsl_idx']}") + logger.info(f" num_tokens: {sample['num_tokens']}") + logger.info(f" text preview: {sample['text'][:200]}...") + logger.info("=" * 80) + + # Generate histogram plots if plot directory is specified + if args.plot_dir: + logger.info("\n" + "=" * 80) + logger.info("Generating Histogram Plots") + logger.info("=" * 80) + plot_histograms( + results, args.plot_dir, has_reference=( + reference_df is not None)) + logger.info("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/language/gpt-oss-120b/generation_config.json b/language/gpt-oss-120b/generation_config.json new file mode 100644 index 0000000000..d7b586f927 --- /dev/null +++ b/language/gpt-oss-120b/generation_config.json @@ -0,0 +1,46 @@ +{ + "_comment": "Generation configuration for gpt-oss-120b model", + "_description": "These parameters control the text generation behavior", + + "max_new_tokens": 32768, + "temperature": 1.0, + "top_k": -1, + "top_p": 1.0, + + "_parameter_descriptions": { + "max_new_tokens": "Maximum number of tokens to generate per request (1-32768)", + "temperature": "Sampling temperature (0.0 = deterministic, higher = more random). Typical: 0.001-2.0", + "top_k": "Top-k sampling (number of highest probability tokens to consider). -1 = disabled", + "top_p": "Top-p/nucleus sampling (cumulative probability threshold). 
0.0-1.0, typically 1.0 for no filtering", + + "_additional_params_note": "SGLang supports additional parameters like:", + "repetition_penalty": "Penalty for repeating tokens (typically 1.0-1.2)", + "frequency_penalty": "Penalty based on token frequency (0.0-2.0)", + "presence_penalty": "Penalty for tokens already present (0.0-2.0)", + "min_tokens": "Minimum tokens to generate before stopping", + "stop": "Stop sequences (list of strings)", + "ignore_eos": "Whether to ignore EOS token (boolean)" + }, + + "_presets": { + "deterministic": { + "max_new_tokens": 10240, + "temperature": 0.001, + "top_k": 1, + "top_p": 1.0 + }, + "creative": { + "max_new_tokens": 10240, + "temperature": 1.5, + "top_k": 50, + "top_p": 0.95 + }, + "balanced": { + "max_new_tokens": 10240, + "temperature": 0.7, + "top_k": 40, + "top_p": 0.9 + } + } +} + diff --git a/language/gpt-oss-120b/mlperf/__init__.py b/language/gpt-oss-120b/mlperf/__init__.py new file mode 100644 index 0000000000..c5aaa0d243 --- /dev/null +++ b/language/gpt-oss-120b/mlperf/__init__.py @@ -0,0 +1,14 @@ +#!/usr/bin/env python3 +"""MLPerf inference integration for gpt-oss.""" + +from .base_sut import BaseSUT +from .offline_sut import OfflineSUT +from .server_sut import ServerSUT +from .qsl import QuerySampleLibrary + +__all__ = [ + "BaseSUT", + "OfflineSUT", + "ServerSUT", + "QuerySampleLibrary", +] diff --git a/language/gpt-oss-120b/mlperf/base_sut.py b/language/gpt-oss-120b/mlperf/base_sut.py new file mode 100644 index 0000000000..1919d56a2a --- /dev/null +++ b/language/gpt-oss-120b/mlperf/base_sut.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python3 +"""Base System Under Test (SUT) class for MLPerf inference benchmarks.""" + +import abc +import logging +import threading +from typing import List, Dict, Any, Optional +import mlperf_loadgen as lg + +logger = logging.getLogger(__name__) + + +class BaseSUT(abc.ABC): + """Base class for MLPerf inference System Under Test (SUT). + + This class defines the interface that all SUTs must implement for MLPerf + inference benchmarks. It provides two main methods: + - issue_queries: to enqueue prompt tokens + - flush_queries: to await completion of all issued queries + """ + + def __init__( + self, backend, dataset: List[List[int]], name: str = "BaseSUT", progress_bar=None): + """Initialize the base SUT. + + Args: + backend: Backend instance for inference + dataset: List of tokenized prompts + name: Name of the SUT for logging purposes + progress_bar: Optional tqdm progress bar for real-time updates + """ + self.backend = backend + self.dataset = dataset + self.name = name + self.sut = None + self.results = {} + self.progress_bar = progress_bar + + # Graceful shutdown support (set on KeyboardInterrupt) + self.should_stop = threading.Event() + + logger.info(f"Initializing {self.name}") + + @abc.abstractmethod + def issue_queries(self, query_samples: List[lg.QuerySample]) -> None: + """Issue queries to the SUT. + + This method should enqueue the provided query samples for processing. + It should return immediately without waiting for completion. + + Args: + query_samples: List of MLPerf LoadGen query samples to process + """ + raise NotImplementedError("Subclasses must implement issue_queries") + + @abc.abstractmethod + def flush_queries(self) -> None: + """Flush all pending queries. + + This method should wait for all previously issued queries to complete + before returning. It's called by LoadGen to ensure all work is done. 
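+
+        In this reference implementation, OfflineSUT dispatches its accumulated
+        queries from this call, while ServerSUT waits here for its in-flight
+        streaming queries to finish.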
+ """ + raise NotImplementedError("Subclasses must implement flush_queries") + + def start(self) -> lg.ConstructSUT: + """Start the SUT and return the LoadGen SUT handle. + + Returns: + LoadGen SUT handle for use with LoadGen + """ + self.sut = lg.ConstructSUT(self.issue_queries, self.flush_queries) + logger.info(f"{self.name} started") + return self.sut + + def stop(self) -> None: + """Stop the SUT and clean up resources. + + Signals graceful shutdown and allows subclasses to cancel pending work. + """ + logger.info(f"Stopping {self.name}...") + + # Signal all workers/tasks to stop + self.should_stop.set() + + # Subclasses should override to add their own cleanup + # (e.g., cancel tasks, clear queues) + + if self.sut: + lg.DestroySUT(self.sut) + self.sut = None + logger.info(f"{self.name} stopped") + + def get_results(self) -> Dict[int, Any]: + """Get all results from completed queries. + + Returns: + Dictionary mapping query IDs to results + """ + return self.results + + def __enter__(self): + """Context manager entry.""" + return self.start() + + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit.""" + self.stop() diff --git a/language/gpt-oss-120b/mlperf/offline_sut.py b/language/gpt-oss-120b/mlperf/offline_sut.py new file mode 100644 index 0000000000..7c4e8e9ee4 --- /dev/null +++ b/language/gpt-oss-120b/mlperf/offline_sut.py @@ -0,0 +1,222 @@ +#!/usr/bin/env python3 +"""Offline scenario SUT implementation for gpt-oss.""" + +import logging +import numpy as np +import time +from typing import List, Dict, Any +from concurrent.futures import ThreadPoolExecutor, as_completed +import mlperf_loadgen as lg +from tqdm import tqdm +from .base_sut import BaseSUT + +logger = logging.getLogger(__name__) + + +class OfflineSUT(BaseSUT): + """Offline scenario System Under Test. + + In the Offline scenario, all queries are issued at once and can be + processed in any order. This allows for maximum batching and throughput. + """ + + def __init__( + self, + backend, + dataset: List[List[int]], + max_tokens: int = 32768, + temperature: float = 0.001, + top_k: int = 1, + top_p: float = 1.0, + name: str = "OfflineSUT", + progress_bar=None, + max_concurrency: int = 128 + ): + """Initialize the Offline SUT. + + Args: + backend: Backend instance for inference + dataset: List of tokenized prompts + max_tokens: Maximum tokens to generate + temperature: Sampling temperature + top_k: Top-k sampling parameter + top_p: Top-p sampling parameter + name: Name of the SUT + progress_bar: Optional tqdm progress bar for real-time updates + max_concurrency: Maximum concurrent requests to backend (SGLang does in-flight batching) + """ + super().__init__(backend, dataset, name, progress_bar) + self.max_tokens = max_tokens + self.temperature = temperature + self.top_k = top_k + self.top_p = top_p + self.pending_queries = [] + self.max_concurrency = max_concurrency + + logger.info( + f"OfflineSUT configured with max_concurrency={max_concurrency} (backend handles batching)") + + def issue_queries(self, query_samples: List[lg.QuerySample]) -> None: + """Issue queries to the SUT. + + In Offline mode, we accumulate all queries and process them in batch. 
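+        No backend calls are made here; the accumulated queries are dispatched
+        concurrently later, in flush_queries.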
+ + Args: + query_samples: List of MLPerf LoadGen query samples + """ + logger.info(f"Received {len(query_samples)} queries") + + # Update progress bar total by accumulating (for repeats_per_sample > 1) + # LoadGen may call issue_queries multiple times for repeated sampling + if self.progress_bar is not None: + self.progress_bar.total = ( + self.progress_bar.total or 0) + len(query_samples) + self.progress_bar.refresh() + + # Store queries for batch processing + for qs in query_samples: + self.pending_queries.append(qs) + + def flush_queries(self) -> None: + """Process all accumulated queries with concurrent requests. + + Sends individual requests concurrently up to max_concurrency limit. + SGLang handles batching internally via continuous batching. + """ + if not self.pending_queries: + logger.info("No pending queries to flush") + return + + logger.info( + f"Flushing {len(self.pending_queries)} queries with max_concurrency={self.max_concurrency}") + start_time = time.time() + + def process_single_query(query_sample): + """Process a single query (backend batches automatically via continuous batching).""" + # Check if we should stop (e.g., KeyboardInterrupt) + if self.should_stop.is_set(): + logger.info( + f"Skipping query {query_sample.id} due to shutdown") + return None, None, None + + query_id = query_sample.id + input_ids = self.dataset[query_sample.index] + + # Call backend with single query + # SGLang will batch this with other concurrent requests + # automatically + responses = self.backend.generate( + prompts=[input_ids], # Single query as list + max_tokens=self.max_tokens, + temperature=self.temperature, + top_k=self.top_k, + top_p=self.top_p + ) + + return query_id, query_sample, responses[0] + + try: + # Process queries in parallel with max_concurrency + logger.info( + f"Submitting {len(self.pending_queries)} queries to {self.max_concurrency} concurrent workers...") + with ThreadPoolExecutor(max_workers=self.max_concurrency) as executor: + # Submit all queries at once + futures = [ + executor.submit( + process_single_query, + qs) for qs in self.pending_queries] + + # Process results as they complete + completed_count = 0 + cancelled_count = 0 + + for future in as_completed(futures): + # Check if shutdown was requested + if self.should_stop.is_set(): + logger.info( + "Shutdown requested, cancelling remaining futures...") + for f in futures: + f.cancel() + cancelled_count = sum( + 1 for f in futures if f.cancelled()) + logger.info( + f"Cancelled {cancelled_count} pending futures") + break + try: + query_id, query_sample, response = future.result() + + # Skip if query was cancelled/skipped + if query_id is None: + continue + + output_ids = response.get("output_ids", []) + + # Store results + self.results[query_id] = { + "output_ids": output_ids, + "output_text": response.get("output_text", ""), + "metadata": response.get("metadata", {}) + } + + # Convert output_ids to numpy array for LoadGen + # LoadGen expects int32 token IDs as a contiguous array + if output_ids: + token_array = np.ascontiguousarray( + output_ids, dtype=np.int32) + output_data_ptr = token_array.ctypes.data + output_data_size = token_array.nbytes + n_tokens = len(output_ids) + else: + # Empty response + token_array = np.array([], dtype=np.int32) + output_data_ptr = 0 + output_data_size = 0 + n_tokens = 0 + + # Create response for LoadGen with token count + response_array = [ + lg.QuerySampleResponse( + query_id, + output_data_ptr, + output_data_size, + n_tokens # Number of output tokens for tokens/sec metric + ) + 
] + + # Report completion to LoadGen + lg.QuerySamplesComplete(response_array) + + # Update progress bar + if self.progress_bar is not None: + self.progress_bar.update(1) + self.progress_bar.refresh() + + completed_count += 1 + # Log progress at debug level only (tqdm shows + # progress) + if completed_count % 100 == 0: + logger.debug( + f"Completed {completed_count}/{len(self.pending_queries)} queries") + + except Exception as e: + logger.error( + f"Error processing query: {e}", exc_info=True) + + elapsed = time.time() - start_time + if cancelled_count > 0: + logger.info( + f"Completed {completed_count} queries, cancelled {cancelled_count} queries " + f"in {elapsed:.2f}s" + ) + else: + logger.info( + f"Completed {len(self.pending_queries)} queries in {elapsed:.2f}s " + f"({len(self.pending_queries)/elapsed:.2f} QPS)" + ) + + except Exception as e: + logger.error(f"Error during concurrent flush: {e}", exc_info=True) + raise + finally: + # Clear pending queries + self.pending_queries = [] diff --git a/language/gpt-oss-120b/mlperf/qsl.py b/language/gpt-oss-120b/mlperf/qsl.py new file mode 100644 index 0000000000..e7b06a1bb8 --- /dev/null +++ b/language/gpt-oss-120b/mlperf/qsl.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +"""Query Sample Library for gpt-oss MLPerf integration.""" + +import logging +from typing import List +import mlperf_loadgen as lg + +logger = logging.getLogger(__name__) + + +class QuerySampleLibrary: + """Query Sample Library implementation. + + This class manages the dataset of samples that LoadGen will query. + """ + + def __init__(self, dataset: List[List[int]]): + """Initialize the Query Sample Library. + + Args: + dataset: List of tokenized prompts (list of token ID lists) + """ + self.dataset = dataset + self.qsl = None + logger.info(f"Initializing QSL with {len(dataset)} samples") + + def load_query_samples(self, sample_indices: List[int]) -> None: + """Load specified query samples into memory. + + Args: + sample_indices: List of sample indices to load + """ + # For this implementation, all samples are already in memory + logger.info(f"Loading {len(sample_indices)} query samples") + + def unload_query_samples(self, sample_indices: List[int]) -> None: + """Unload specified query samples from memory. 
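+
+        All samples stay resident in host memory for the lifetime of the run,
+        so this is effectively a no-op beyond logging.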
+ + Args: + sample_indices: List of sample indices to unload + """ + # For this implementation, we keep all samples in memory + logger.info(f"Unloading {len(sample_indices)} query samples") + + def __len__(self) -> int: + """Return the number of samples in the dataset.""" + return len(self.dataset) + + def __enter__(self): + """Context manager entry.""" + self.qsl = lg.ConstructQSL( + len(self.dataset), + len(self.dataset), # performance sample count + self.load_query_samples, + self.unload_query_samples + ) + logger.info("QSL constructed") + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit.""" + if self.qsl: + lg.DestroyQSL(self.qsl) + self.qsl = None + logger.info("QSL destroyed") diff --git a/language/gpt-oss-120b/mlperf/server_sut.py b/language/gpt-oss-120b/mlperf/server_sut.py new file mode 100644 index 0000000000..2f9c83532f --- /dev/null +++ b/language/gpt-oss-120b/mlperf/server_sut.py @@ -0,0 +1,443 @@ +#!/usr/bin/env python3 +"""Server scenario SUT implementation with streaming support for gpt-oss.""" + +import asyncio +import logging +import numpy as np +import queue +import sys +import threading +import time +from dataclasses import dataclass +from typing import List, Dict, Any, Optional +import mlperf_loadgen as lg +from tqdm import tqdm + +from .base_sut import BaseSUT + +logger = logging.getLogger(__name__) + + +@dataclass +class StreamingQueryState: + """State for a streaming query.""" + query_sample: lg.QuerySample + query_id: int + input_ids: List[int] + accumulated_tokens: List[int] + accumulated_text: str + first_token_received: bool + first_token_time: Optional[float] + start_time: float + finished: bool + + +class ServerSUT(BaseSUT): + """Server scenario SUT with streaming support. + + Properly reports FirstTokenComplete and QuerySamplesComplete to LoadGen. + """ + + def __init__( + self, + backend, + dataset: List[List[int]], + max_tokens: int = 32768, + temperature: float = 0.001, + top_k: int = 1, + top_p: float = 1.0, + num_workers: int = 1, + name: str = "ServerSUT", + progress_bar=None + ): + """Initialize the Server SUT. 
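+
+        Incoming queries are placed on a thread-safe queue, picked up by
+        num_workers worker threads, and streamed asynchronously on a dedicated
+        event loop so TTFT can be reported as soon as the first token arrives.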
+ + Args: + backend: Backend instance for inference (must support streaming) + dataset: List of tokenized prompts + max_tokens: Maximum tokens to generate + temperature: Sampling temperature + top_k: Top-k sampling parameter + top_p: Top-p sampling parameter + num_workers: Number of worker threads + name: Name of the SUT + progress_bar: Optional tqdm progress bar for real-time updates + """ + super().__init__(backend, dataset, name, progress_bar) + self.max_tokens = max_tokens + self.temperature = temperature + self.top_k = top_k + self.top_p = top_p + self.num_workers = num_workers + + # Query queue and streaming state + self.query_queue = queue.Queue() + self.active_streams: Dict[int, StreamingQueryState] = {} + self.active_streams_lock = threading.Lock() + + # Track active async tasks for cancellation on KeyboardInterrupt + self.active_tasks = set() + self.active_tasks_lock = threading.Lock() + + # Worker threads + self.workers = [] + + # Progress tracking + self.queries_completed = 0 + self.progress_lock = threading.Lock() + + # Event loop for async streaming + self.loop = None + self.loop_thread = None + + logger.info( + f"ServerSUT configured with num_workers={num_workers} (streaming enabled)") + + def start(self) -> lg.ConstructSUT: + """Start the SUT and worker threads.""" + # Start event loop thread for async streaming + self._start_event_loop() + + # Start worker threads + self._start_workers() + + # Create LoadGen SUT + self.sut = lg.ConstructSUT( + self.issue_queries, + self.flush_queries) + logger.info(f"{self.name} started with streaming support") + return self.sut + + def _start_event_loop(self): + """Start the asyncio event loop in a separate thread.""" + def run_loop(): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + self.loop.run_forever() + + self.loop_thread = threading.Thread(target=run_loop, daemon=True) + self.loop_thread.start() + + # Wait for loop to be ready + while self.loop is None: + time.sleep(0.001) + + logger.info("Async event loop started") + + def _start_workers(self): + """Start worker threads for processing queries.""" + for i in range(self.num_workers): + worker = threading.Thread( + target=self._worker_thread, + name=f"ServerWorker-{i}", + daemon=True + ) + self.workers.append(worker) + worker.start() + logger.info(f"Started {self.num_workers} worker threads") + + def _worker_thread(self): + """Worker thread that processes queries from the queue.""" + try: + while not self.should_stop.is_set(): + try: + query_sample = self.query_queue.get(timeout=0.1) + except queue.Empty: + continue + except KeyboardInterrupt: + logger.info( + "Worker thread interrupted, exiting gracefully...") + break + + # Schedule async streaming processing and track task + if self.loop and not self.should_stop.is_set(): + # Create the coroutine + coro = self._process_streaming_query_tracked(query_sample) + # Schedule it on the event loop + future = asyncio.run_coroutine_threadsafe(coro, self.loop) + # Don't wait for completion - it happens asynchronously + + except Exception as e: + logger.error(f"Worker thread error: {e}", exc_info=True) + + async def _process_streaming_query_tracked( + self, query_sample: lg.QuerySample): + """Wrapper that tracks the async task for cancellation.""" + task = asyncio.current_task() + + # Add to active tasks + with self.active_tasks_lock: + self.active_tasks.add(task) + + try: + await self._process_streaming_query(query_sample) + finally: + # Remove from active tasks + with self.active_tasks_lock: + 
self.active_tasks.discard(task) + + async def _process_streaming_query(self, query_sample: lg.QuerySample): + """Process a single query with streaming support. + + Token reporting to LoadGen: + 1. When first token arrives → lg.FirstTokenComplete([token_0]) + 2. When generation finishes → lg.QuerySamplesComplete([token_1, token_2, ..., token_n]) + Args: + query_sample: MLPerf LoadGen query sample + """ + query_id = query_sample.id + sample_idx = query_sample.index + input_ids = self.dataset[sample_idx] + + # Initialize streaming state + state = StreamingQueryState( + query_sample=query_sample, + query_id=query_id, + input_ids=input_ids, + accumulated_tokens=[], + accumulated_text="", + first_token_received=False, + first_token_time=None, + start_time=time.time(), + finished=False + ) + + with self.active_streams_lock: + self.active_streams[query_id] = state + + try: + # Stream tokens from backend + async for chunk in self.backend.generate_stream( + input_ids=input_ids, + max_tokens=self.max_tokens, + temperature=self.temperature, + top_k=self.top_k, + top_p=self.top_p + ): + # Update state + if chunk.get("delta_token_ids"): + state.accumulated_tokens.extend(chunk["delta_token_ids"]) + if chunk.get("delta_text"): + state.accumulated_text += chunk["delta_text"] + + # Send FirstTokenComplete on first token + if chunk.get( + "is_first_token") and not state.first_token_received: + state.first_token_received = True + state.first_token_time = time.time() + await self._send_first_token_complete(state) + + # Check if finished + if chunk.get("is_finished"): + state.finished = True + await self._send_final_response(state) + break + + # If no explicit finish signal, send final response + if not state.finished: + state.finished = True + await self._send_final_response(state) + + except asyncio.CancelledError: + # Task was cancelled (e.g., KeyboardInterrupt during graceful + # shutdown) + logger.info( + f"Streaming query {query_id} cancelled during shutdown") + # Don't send response to LoadGen - we're shutting down + raise # Re-raise to mark task as cancelled + except Exception as e: + logger.error( + f"Error processing streaming query {query_id}: {e}", + exc_info=True) + # Send empty response to unblock LoadGen + try: + await self._send_final_response(state) + except BaseException: + pass + finally: + # Clean up + with self.active_streams_lock: + self.active_streams.pop(query_id, None) + + async def _send_first_token_complete(self, state: StreamingQueryState): + """Send FirstTokenComplete to LoadGen for TTFT measurement. + + Only sends the first token for TTFT measurement. 
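+        The remaining tokens are reported later via QuerySamplesComplete in
+        _send_final_response, so they are not duplicated here.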
+ """ + try: + logger.debug( + f"First token for query {state.query_id} at {state.first_token_time - state.start_time:.3f}s") + + # LoadGen uses this to measure Time To First Token (TTFT) + if state.accumulated_tokens and len(state.accumulated_tokens) > 0: + # Extract only the first token + first_token_only = [state.accumulated_tokens[0]] + token_array = np.ascontiguousarray( + first_token_only, dtype=np.int32) + else: + # No tokens yet - this shouldn't happen but handle gracefully + token_array = np.array([], dtype=np.int32) + logger.warning( + f"FirstTokenComplete called but no tokens accumulated for query {state.query_id}") + + # Create response + response = lg.QuerySampleResponse( + state.query_id, + token_array.ctypes.data if token_array.size > 0 else 0, + token_array.nbytes, + len(token_array) + ) + + # Report to LoadGen + lg.FirstTokenComplete([response]) + logger.debug( + f"Sent FirstTokenComplete for query {state.query_id}: 1 token") + + except Exception as e: + logger.error( + f"Error sending FirstTokenComplete for query {state.query_id}: {e}", + exc_info=True) + + async def _send_final_response(self, state: StreamingQueryState): + """Send final QuerySamplesComplete to LoadGen. (send all tokens except the first one) + """ + try: + num_total_tokens = len(state.accumulated_tokens) + logger.debug( + f"Final response for query {state.query_id}: {num_total_tokens} total tokens") + + # Store results (all tokens for internal tracking) + self.results[state.query_id] = { + "output_ids": state.accumulated_tokens, + "output_text": state.accumulated_text, + "metadata": { + "latency": time.time() - state.start_time, + "ttft": state.first_token_time - state.start_time if state.first_token_time else None, + } + } + + if state.accumulated_tokens and len(state.accumulated_tokens) > 1: + remaining_tokens = state.accumulated_tokens[1:] + token_array = np.ascontiguousarray( + remaining_tokens, dtype=np.int32) + else: + token_array = np.array([], dtype=np.int32) + + # Create response + response = lg.QuerySampleResponse( + state.query_id, + token_array.ctypes.data if token_array.size > 0 else 0, + token_array.nbytes, + len(token_array) + ) + + # Report to LoadGen + lg.QuerySamplesComplete([response]) + logger.debug( + f"Sent QuerySamplesComplete for query {state.query_id}: " + f"{len(token_array)} remaining tokens (total: {num_total_tokens})" + ) + + # Update progress bar (force refresh for async updates) + if self.progress_bar is not None: + with self.progress_lock: + self.queries_completed += 1 + self.progress_bar.update(1) + self.progress_bar.refresh() # Force redraw from async context + sys.stdout.flush() # Force flush for immediate display in async/threaded context + + except Exception as e: + logger.error( + f"Error sending final response for query {state.query_id}: {e}", + exc_info=True) + + def issue_queries(self, query_samples: List[lg.QuerySample]) -> None: + """Issue queries to the SUT. + + In Server mode, queries are added to a queue for worker threads. + + Args: + query_samples: List of MLPerf LoadGen query samples + """ + # Update progress bar total dynamically as queries arrive + if self.progress_bar is not None: + with self.progress_lock: + self.progress_bar.total = ( + self.progress_bar.total or 0) + len(query_samples) + self.progress_bar.refresh() + + for qs in query_samples: + self.query_queue.put(qs) + + def flush_queries(self) -> None: + """Flush all pending queries. + + Wait for all issued queries to complete. 
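+        Polls until the query queue is empty and no streaming queries remain
+        active.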
+ """ + logger.info("Flushing server queries...") + + # Wait for queue to empty and all streams to complete + while True: + queue_empty = self.query_queue.empty() + + with self.active_streams_lock: + no_active_streams = len(self.active_streams) == 0 + + if queue_empty and no_active_streams: + break + + time.sleep(0.01) + + logger.info("Server queries flushed") + + def stop(self) -> None: + """Stop the SUT and clean up resources.""" + if self.should_stop.is_set(): + logger.info(f"{self.name} already stopping or stopped.") + return + + super().stop() + + # Cancel all active streaming tasks + logger.info("Cancelling active streaming tasks...") + tasks_to_cancel = [] + with self.active_tasks_lock: + tasks_to_cancel = list(self.active_tasks) + + if tasks_to_cancel: + logger.info(f"Cancelling {len(tasks_to_cancel)} active tasks") + for task in tasks_to_cancel: + if not task.done(): + task.cancel() + + # Clear pending queries from queue + pending_count = 0 + try: + while True: + self.query_queue.get_nowait() + pending_count += 1 + except queue.Empty: + pass + + if pending_count > 0: + logger.info(f"Cleared {pending_count} pending queries from queue") + + # Wait for workers with progress bar + with tqdm(total=len(self.workers), desc="Stopping workers", unit="worker") as pbar: + for i, worker in enumerate(self.workers): + worker.join(timeout=5) + if worker.is_alive(): + logger.warning( + f"Worker {i+1} did not terminate gracefully") + pbar.update(1) + + # Stop event loop + if self.loop: + self.loop.call_soon_threadsafe(self.loop.stop) + if self.loop_thread: + self.loop_thread.join(timeout=2) + + logger.info("All workers stopped") + + # Destroy LoadGen SUT + super().stop() diff --git a/language/gpt-oss-120b/mlperf/user.conf b/language/gpt-oss-120b/mlperf/user.conf new file mode 100644 index 0000000000..27c2fe59b4 --- /dev/null +++ b/language/gpt-oss-120b/mlperf/user.conf @@ -0,0 +1,7 @@ +gpt-oss-120b.*.performance_sample_count = 4395 +gpt-oss-120b.Server.target_qps = 1 +gpt-oss-120b.Server.min_duration = 60000 +gpt-oss-120b.Server.performance_sample_count = 4395 +gpt-oss-120b.Server.target_latency = 0 +gpt-oss-120b.Server.ttft_latency = 2000 +gpt-oss-120b.Server.tpot_latency = 20 diff --git a/language/gpt-oss-120b/requirements.txt b/language/gpt-oss-120b/requirements.txt new file mode 100644 index 0000000000..9d8f33995a --- /dev/null +++ b/language/gpt-oss-120b/requirements.txt @@ -0,0 +1,11 @@ +absl-py==2.3.1 +anthropic==0.72.0 +audioread==3.1.0 +datasets==2.21.0 +joblib==1.5.3 +lazy_loader==0.4 +msgpack==1.1.2 +numba==0.63.1 +pooch==1.8.2 +scikit-learn==1.8.0 +soxr==1.0.0 diff --git a/language/gpt-oss-120b/run_mlperf.py b/language/gpt-oss-120b/run_mlperf.py new file mode 100755 index 0000000000..2edfaa37f5 --- /dev/null +++ b/language/gpt-oss-120b/run_mlperf.py @@ -0,0 +1,512 @@ +#!/usr/bin/env python3 +"""MLPerf inference benchmark runner for gpt-oss. + +This script integrates the gpt-oss model with MLPerf LoadGen for +performance and accuracy benchmarking. 
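+
+Results are written as standard LoadGen logs under
+--output-dir/<scenario>/<accuracy|performance>.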
+ +Usage: + # Offline scenario (performance) + python run_mlperf.py --scenario offline --input-file data/accuracy_eval_tokenized.pkl + + # Server scenario (performance) + python run_mlperf.py --scenario server --input-file data/accuracy_eval_tokenized.pkl + + # Accuracy mode + python run_mlperf.py --scenario offline --accuracy --input-file data/accuracy_eval_tokenized.pkl +""" + +import argparse +import json +import logging +import os +import sys +import threading +from pathlib import Path +from typing import Optional, Dict, Any + +import mlperf_loadgen as lg +import pandas as pd +from tqdm import tqdm + +from backends import SGLangBackend +from mlperf import OfflineSUT, ServerSUT, QuerySampleLibrary +from utils import load_tokenized_dataset, StandardTokenizer + +# Disable tokenizers parallelism to avoid forking issues +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + + +def load_generation_config(config_path: str) -> Dict[str, Any]: + """Load generation configuration from JSON file. + + Args: + config_path: Path to generation_config.json + + Returns: + Dictionary with generation parameters + """ + logger.info(f"Loading generation config from {config_path}") + + with open(config_path, 'r') as f: + config = json.load(f) + + # Filter out comment fields (starting with _) + gen_params = {k: v for k, v in config.items() if not k.startswith('_')} + + return gen_params + + +def create_argument_parser() -> argparse.ArgumentParser: + """Create argument parser for MLPerf runner.""" + parser = argparse.ArgumentParser( + description="Run MLPerf inference benchmarks for gpt-oss" + ) + + # Scenario selection + parser.add_argument( + "--scenario", + type=str, + default="offline", + choices=["offline", "server"], + help="MLPerf scenario (offline or server)" + ) + + # Dataset + parser.add_argument( + "--input-file", + type=str, + required=True, + help="Path to tokenized dataset (parquet or pickle file)" + ) + + parser.add_argument( + "--max-samples", + type=int, + default=None, + help="Maximum number of samples to use (None for all)" + ) + + # MLPerf configuration + parser.add_argument( + "--mlperf-conf", + type=str, + default="/home/scratch.shobhitv_coreai/mlcinf-repos/gpt-oss-perf/loadgen/mlperf.conf", + help="Path to MLPerf configuration file" + ) + + parser.add_argument( + "--user-conf", + type=str, + default="mlperf/user.conf", + help="Path to user configuration file" + ) + + parser.add_argument( + "--accuracy", + action="store_true", + help="Run accuracy mode instead of performance" + ) + + # Output configuration + parser.add_argument( + "--output-dir", + type=str, + default="mlperf_results", + help="Directory for MLPerf output logs" + ) + + # Backend configuration + parser.add_argument( + "--backend", + type=str, + default="sglang", + choices=["sglang"], + help="Backend to use for inference" + ) + + parser.add_argument( + "--server-url", + type=str, + default="http://localhost:30000", + help="Server URL for backend (SGLang)" + ) + + # Generation configuration + parser.add_argument( + "--generation-config", + type=str, + default="generation_config.json", + help="Path to generation configuration JSON file" + ) + + parser.add_argument( + "--max-new-tokens", + type=int, + default=None, + help="Override max_new_tokens from generation config (default: use value from config)" + ) + + # Server scenario specific + parser.add_argument( + 
"--num-workers", + type=int, + default=1, + help="Number of worker threads (for server scenario)" + ) + + # Concurrency control + parser.add_argument( + "--max-concurrency", + type=int, + default=128, + help="Maximum concurrent requests to backend (SGLang handles batching internally)" + ) + + parser.add_argument( + "--timeout", + type=int, + default=1200, + help="Timeout for HTTP requests in seconds (default: 1200)" + ) + + return parser + + +def configure_loadgen( + scenario: str, + accuracy_mode: bool, + mlperf_conf: Optional[str] = None, + user_conf: Optional[str] = None, + log_dir: Optional[str] = None, + model_name: str = "gpt-oss-120b" +) -> lg.TestSettings: + """Configure LoadGen test settings. + + Args: + scenario: MLPerf scenario ("offline" or "server") + accuracy_mode: Whether to run in accuracy mode + mlperf_conf: Path to MLPerf config file + user_conf: Path to user config file + log_dir: Directory for logs + model_name: Model name for configuration + + Returns: + LoadGen TestSettings + """ + settings = lg.TestSettings() + + # Set scenario + if scenario.lower() == "offline": + settings.scenario = lg.TestScenario.Offline + elif scenario.lower() == "server": + settings.scenario = lg.TestScenario.Server + else: + raise ValueError(f"Unknown scenario: {scenario}") + + # Set mode + if accuracy_mode: + settings.mode = lg.TestMode.AccuracyOnly + else: + settings.mode = lg.TestMode.PerformanceOnly + + # Load configurations if files exist + # conf_type: 2 = mlperf.conf, 1 = user.conf + # LoadGen tracks config calls and only allows one user.conf for official + # submissions + if mlperf_conf and Path(mlperf_conf).exists(): + logger.debug(f"Loading MLPerf config from {mlperf_conf}") + settings.FromConfig(mlperf_conf, model_name, scenario.capitalize(), 2) + else: + logger.warning(f"MLPerf config not found: {mlperf_conf}") + + if user_conf and Path(user_conf).exists(): + logger.debug(f"Loading user config from {user_conf}") + settings.FromConfig(user_conf, model_name, scenario.capitalize(), 1) + else: + logger.warning(f"User config not found: {user_conf}") + + return settings + + +def main(): + """Main function.""" + parser = create_argument_parser() + args = parser.parse_args() + + # Track resources for cleanup + sut = None + qsl = None + backend = None + pbar = None + cleanup_done = False + + def do_cleanup(): + """Perform cleanup once and only once.""" + nonlocal cleanup_done, pbar, sut, qsl, backend + + if cleanup_done: + return + cleanup_done = True + + logger.info("Performing cleanup...") + + # 1. Close progress bar first (before any LoadGen cleanup) + try: + if pbar is not None: + pbar.close() + pbar = None + logger.debug(" ✓ Progress bar closed") + except Exception as e: + logger.debug(f" ! Error closing progress bar: {e}") + + # Small delay to let LoadGen internal threads finish + import time + time.sleep(0.5) + + # 2. Stop SUT (this will stop worker threads and flush) + try: + if sut is not None: + logger.info(" - Stopping SUT and worker threads...") + sut.stop() + sut = None + logger.info(" ✓ SUT stopped") + except Exception as e: + logger.warning(f" ! Error stopping SUT: {e}") + + # 3. Destroy QSL + try: + if qsl is not None and qsl.qsl is not None: + logger.info(" - Destroying Query Sample Library...") + lg.DestroyQSL(qsl.qsl) + qsl.qsl = None + logger.info(" ✓ QSL destroyed") + except Exception as e: + logger.warning(f" ! Error destroying QSL: {e}") + + # 4. 
Cleanup backend last + try: + if backend is not None and backend.initialized: + logger.info(" - Cleaning up backend connection...") + backend.cleanup() + backend = None + logger.info(" ✓ Backend cleaned up") + except Exception as e: + logger.warning(f" ! Error cleaning up backend: {e}") + + try: + # Create output directories + output_dir = Path(args.output_dir) + log_dir = output_dir / args.scenario / \ + ("accuracy" if args.accuracy else "performance") + log_dir.mkdir(parents=True, exist_ok=True) + + logger.info("=" * 80) + logger.info("MLPerf Inference Benchmark Runner for GPT-OSS") + logger.info("=" * 80) + logger.info(f"Backend: {args.backend}") + logger.info(f"Scenario: {args.scenario}") + logger.info(f"Accuracy: {args.accuracy}") + logger.info(f"Input file: {args.input_file}") + logger.info(f"Output directory: {log_dir}") + logger.info("=" * 80) + + # Load dataset + logger.debug("Loading tokenized dataset...") + with tqdm(total=1, desc="Loading dataset", unit="file") as pbar: + dataset_info = load_tokenized_dataset( + args.input_file, + max_samples=args.max_samples + ) + prompts = dataset_info["prompts"] + df = dataset_info["dataframe"] + pbar.update(1) + + logger.info(f"Loaded {len(prompts)} prompts from dataset") + + # Load generation configuration + logger.info("Loading generation configuration...") + gen_config = load_generation_config(args.generation_config) + + # Extract generation parameters with defaults + # CLI override takes precedence over config file + if args.max_new_tokens is not None: + max_tokens = args.max_new_tokens + logger.info( + f"Using max_new_tokens from CLI override: {max_tokens}") + else: + max_tokens = gen_config.get('max_new_tokens', 10240) + logger.info(f"Using max_new_tokens from config: {max_tokens}") + + temperature = gen_config.get('temperature', 1.0) + top_k = gen_config.get('top_k', -1) + top_p = gen_config.get('top_p', 1.0) + + logger.info("Generation parameters:") + logger.info(f" max_new_tokens: {max_tokens}") + logger.info(f" temperature: {temperature}") + logger.info(f" top_k: {top_k}") + logger.info(f" top_p: {top_p}") + + # Initialize backend + logger.debug(f"Initializing {args.backend} backend...") + if args.backend == "sglang": + # Set pool size to match max_concurrency with small safety margin + # This prevents "connection pool is full" warnings + pool_size = int(args.max_concurrency * 1.1) # 10% safety margin + backend = SGLangBackend( + server_url=args.server_url, + timeout=args.timeout, + max_pool_size=pool_size + ) + else: + raise ValueError(f"Unknown backend: {args.backend}") + + # Initialize backend + backend.initialize() + + # Create progress bar early so subsequent logs print below it + # Total will be dynamically updated by SUT based on actual queries from LoadGen: + # - Offline: Set once when all queries arrive + # - Server: Incremented as queries arrive + pbar = tqdm( + total=0, # Will be updated dynamically by SUT + desc=f"MLPerf {args.scenario}", + unit="query", + leave=True, + position=0, + mininterval=0.1, + smoothing=0.1, + dynamic_ncols=True, + file=sys.stdout # Force unbuffered output for async updates + ) + + # Create SUT with progress bar + logger.debug(f"Creating {args.scenario} SUT...") + if args.scenario == "offline": + sut = OfflineSUT( + backend=backend, + dataset=prompts, + max_tokens=max_tokens, + temperature=temperature, + top_k=top_k, + top_p=top_p, + name=f"gpt-oss-120b_offline_sut", + progress_bar=pbar, + max_concurrency=args.max_concurrency + ) + else: # server + sut = ServerSUT( + backend=backend, + 
dataset=prompts, + max_tokens=max_tokens, + temperature=temperature, + top_k=top_k, + top_p=top_p, + num_workers=args.num_workers, + name=f"gpt-oss-120b_server_sut", + progress_bar=pbar + ) + + # Create QSL + logger.info("Creating Query Sample Library...") + qsl = QuerySampleLibrary(prompts) + qsl.qsl = lg.ConstructQSL( + len(prompts), + len(prompts), + qsl.load_query_samples, + qsl.unload_query_samples + ) + + # Configure LoadGen + settings = configure_loadgen( + scenario=args.scenario, + accuracy_mode=args.accuracy, + mlperf_conf=args.mlperf_conf, + user_conf=args.user_conf, + log_dir=str(log_dir) + ) + + # Configure logging + log_settings = lg.LogSettings() + log_settings.log_output.outdir = str(log_dir) + log_settings.log_output.copy_summary_to_stdout = True + log_settings.enable_trace = False + + # Start the SUT and run test + logger.info("Running LoadGen test...") + sut.start() + lg.StartTestWithLogSettings( + sut.sut, + qsl.qsl, + settings, + log_settings + ) + logger.info("LoadGen test completed successfully") + + # Give LoadGen a moment to finish internal cleanup + import time + time.sleep(0.2) + + # Flush queries + logger.info("Flushing queries...") + with tqdm(total=1, desc="Flushing queries", unit="batch") as pbar: + sut.flush_queries() + pbar.update(1) + + # Get results + logger.info("Retrieving results...") + with tqdm(total=1, desc="Getting results", unit="batch") as pbar: + results = sut.get_results() + pbar.update(1) + logger.info(f"Retrieved {len(results)} results from SUT") + + logger.info(f"MLPerf results saved to: {log_dir}") + + # If in accuracy mode, prompt user to run evaluation + if args.accuracy: + logger.info("=" * 80) + logger.info("Accuracy mode completed!") + logger.info("To evaluate accuracy, run:") + logger.info( + f" python eval_accuracy.py --input-file {log_dir}/mlperf_log_accuracy.json") + logger.info("=" * 80) + + except KeyboardInterrupt: + logger.info("\n" + "=" * 80) + logger.info("⚠️ Test interrupted by user (Ctrl+C)") + logger.info("=" * 80) + do_cleanup() + logger.info("=" * 80) + logger.info("✓ Cleanup completed successfully") + logger.info("=" * 80) + # Exit immediately to prevent finally block from running + os._exit(130) # Use os._exit to skip finally block + + except Exception as e: + logger.error("\n" + "=" * 80) + logger.error(f"❌ Error during test: {e}") + logger.error("=" * 80) + logger.error("Stack trace:", exc_info=True) + do_cleanup() + logger.error("=" * 80) + # Exit immediately to prevent finally block from running + os._exit(1) + + finally: + # Only run cleanup if not already done (normal exit path) + if not cleanup_done: + do_cleanup() + logger.info("=" * 80) + logger.info("✓ Cleanup completed successfully") + logger.info("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/language/gpt-oss-120b/setup.sh b/language/gpt-oss-120b/setup.sh new file mode 100755 index 0000000000..23188a0cbd --- /dev/null +++ b/language/gpt-oss-120b/setup.sh @@ -0,0 +1,3 @@ +pip install -r requirements.txt +git_dir=$(git rev-parse --show-toplevel) +pip install $git_dir/loadgen \ No newline at end of file diff --git a/language/gpt-oss-120b/setup_enroot.sh b/language/gpt-oss-120b/setup_enroot.sh new file mode 100755 index 0000000000..c534ded13e --- /dev/null +++ b/language/gpt-oss-120b/setup_enroot.sh @@ -0,0 +1,28 @@ +#!/bin/bash + +sqsh_location=$(readlink -f $(dirname $0))/sqsh_files +sandbox_name=sglang_v0.5.4.post2 +docker_image=lmsysorg/sglang:v0.5.4.post2 + +while [[ $# -gt 0 ]]; do + case $1 in + --docker_image) + docker_image=$2 + shift 
2 + ;; + --sandbox_name) + sandbox_name=$2 + shift 2 + ;; + *) + echo "Unknown argument: $1" + echo "Usage: $0 --docker_image --sandbox_name " + exit 1 + ;; + esac +done + +mkdir -p $sqsh_location +enroot import -o $sqsh_location/$sandbox_name.sqsh docker://$docker_image +enroot create --name $sandbox_name $sqsh_location/$sandbox_name.sqsh +# enroot start --mount $(pwd):$(pwd) --root --rw $sandbox_name diff --git a/language/gpt-oss-120b/sglang/run_infer.py b/language/gpt-oss-120b/sglang/run_infer.py new file mode 100644 index 0000000000..049ef4112f --- /dev/null +++ b/language/gpt-oss-120b/sglang/run_infer.py @@ -0,0 +1,505 @@ +#!/usr/bin/env python3 +""" +Script to send pre-tokenized requests to SGLang server. + +Usage: + python run_infer.py --input-tokens tokenized_data.pkl [options] + +Arguments: + --input-tokens Path to pickle file containing pre-tokenized data from harmony-tokens.py + --server-url SGLang server URL (default: http://localhost:30000) + --max-samples Maximum number of samples to process (default: all) + --max-tokens Maximum tokens to generate per request (default: 100) + --max-concurrency Maximum number of concurrent requests (default: 256) + --output Output pickle file for responses (optional) + --pass-k Number of inference passes per sample for pass@k strategy (default: 1) +""" + +import requests +import json +import time +import argparse +from typing import List, Dict, Any +import logging +from multiprocessing import Pool +import pandas as pd +from tqdm import tqdm +from transformers import AutoTokenizer + +# Set up logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +# Initialize tokenizer +MODEL_NAME = "openai/gpt-oss-120b" +tokenizer = None + + +def get_tokenizer(): + """Get or initialize the tokenizer.""" + global tokenizer + if tokenizer is None: + logger.info(f"Loading tokenizer for {MODEL_NAME}...") + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + logger.info("Tokenizer loaded successfully") + return tokenizer + + +class SGLangClient: + def __init__(self, + server_url: str = "http://localhost:30000", + temperature: float = 0.001, + top_k: int = 1, + top_p: float = 1.0, + timeout: int = 1200 + ): + self.base_url = server_url + self.session = requests.Session() + self.temperature = temperature + self.top_k = top_k + self.top_p = top_p + self.timeout = timeout + + def send_request( + self, input_ids: List[int], max_tokens: int = 100) -> Dict[str, Any]: + """Send a single request to the SGLang server.""" + # SGLang format with input_ids + payload = { + "input_ids": input_ids, + "sampling_params": { + "max_new_tokens": max_tokens, + "temperature": self.temperature, + "top_k": self.top_k, + "top_p": self.top_p, + } + } + + try: + response = self.session.post( + f"{self.base_url}/generate", + json=payload, + timeout=self.timeout, + ) + if response.status_code == 200: + return response.json() + else: + logger.error( + f"Request failed with status {response.status_code}: {response.text}") + return {"error": f"HTTP {response.status_code}: {response.text}"} + except requests.exceptions.RequestException as e: + logger.error(f"Request failed: {e}") + return {"error": str(e)} + + +def load_tokenized_data(data_file: str) -> pd.DataFrame: + """Load pre-tokenized data from pickle file produced by harmony-tokens.py.""" + logger.info(f"Loading tokenized data from {data_file}") + + # Load DataFrame from pickle + df = pd.read_pickle(data_file) + logger.info(f"Loaded DataFrame with 
shape: {df.shape}") + + # Check if tok_input column exists and has valid data + if 'tok_input' in df.columns: + # Check for any None values in tok_input (indicating failed + # tokenization) + failed_mask = df['tok_input'].isna() + failed_count = failed_mask.sum() + + if failed_count > 0: + failed_indices = df[failed_mask].index.unique() + error_msg = f"Found {failed_count} failed tokenized samples at indices: {failed_indices.tolist()}" + logger.error(error_msg) + raise AssertionError(error_msg) + + # Check first sample + first_tokens = df.iloc[0]['tok_input'] + if isinstance(first_tokens, list): + logger.info(f"First sample token length: {len(first_tokens)}") + else: + logger.warning( + "tok_input column exists but first sample is not a list") + + logger.info(f"All {len(df)} samples were successfully tokenized") + else: + logger.warning("No 'tok_input' column found in DataFrame") + + return df + + +def send_single_request(args_tuple): + """Send a single request - used by multiprocessing pool.""" + input_ids, max_tokens, server_url, sample_id, pass_num, temperature, top_k, top_p, timeout = args_tuple + + # Create a new client for this process + client = SGLangClient( + server_url=server_url, + temperature=temperature, + top_k=top_k, + top_p=top_p, + timeout=timeout) + + try: + # Track latency: time from request sent to response received + start_time = time.time() + response = client.send_request(input_ids, max_tokens=max_tokens) + end_time = time.time() + latency = end_time - start_time + return sample_id, pass_num, response, latency + except Exception as e: + logger.error(f"Request {sample_id} (pass {pass_num}) failed: {e}") + # Return None for latency on error + return sample_id, pass_num, {"error": str(e)}, None + + +def send_requests_parallel(tokenized_df: pd.DataFrame, server_url: str, + max_tokens: int = 100, max_concurrency: int = 128, temperature: float = 0.001, top_k: int = 1, top_p: float = 1.0, timeout: int = 1200, + pass_k: int = 1): + """Send all requests to SGLang server in parallel using multiprocessing. 
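+    Each sample is submitted pass_k times, and results are keyed by
+    (sample_id, pass_num).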
+ + Args: + pass_k: Number of inference passes per sample for pass@k strategy + + Returns: + tuple: (responses_by_pass, latencies_by_pass) - Dict mapping (sample_id, pass_num) to response/latency + """ + num_samples = len(tokenized_df) + total_requests = num_samples * pass_k + logger.info( + f"Sending {total_requests} requests ({num_samples} samples × {pass_k} passes) to server with {max_concurrency} concurrent workers...") + + # Prepare arguments for multiprocessing - create pass_k requests per sample + args_list = [] + for idx, row in tokenized_df.iterrows(): + for pass_num in range(pass_k): + args_list.append(( + row['tok_input'], max_tokens, server_url, + idx, pass_num, temperature, top_k, top_p, timeout + )) + + start_time = time.time() + + with Pool(processes=min(max_concurrency, total_requests)) as pool: + results = list(tqdm( + pool.imap_unordered(send_single_request, args_list), + total=len(args_list), + desc="Sending requests", + unit="request" + )) + + # Group results by sample_id and pass_num + responses_by_pass = {} + latencies_by_pass = {} + for sample_id, pass_num, response, latency in results: + responses_by_pass[(sample_id, pass_num)] = response + latencies_by_pass[(sample_id, pass_num)] = latency + + total_time = time.time() - start_time + logger.info( + f"Completed {total_requests} requests in {total_time:.2f} seconds") + logger.info(f"Average rate: {total_requests/total_time:.2f} requests/sec") + + # Log latency statistics + valid_latencies = [ + lat for lat in latencies_by_pass.values() if lat is not None] + if valid_latencies: + avg_latency = sum(valid_latencies) / len(valid_latencies) + min_latency = min(valid_latencies) + max_latency = max(valid_latencies) + logger.info( + f"Latency stats - Avg: {avg_latency:.3f}s, Min: {min_latency:.3f}s, Max: {max_latency:.3f}s") + + return responses_by_pass, latencies_by_pass + + +def extract_response_ids( + responses_by_pass: Dict[tuple, Dict[str, Any]], tokenized_df: pd.DataFrame, pass_k: int) -> Dict[tuple, List[int]]: + """Extract response output_ids from SGLang responses for all passes. + + Args: + responses_by_pass: Dict mapping (sample_id, pass_num) to response + tokenized_df: DataFrame with samples + pass_k: Number of passes per sample + + Returns: + Dict mapping (sample_id, pass_num) to output_ids list + """ + logger.info("Extracting response output_ids...") + + response_ids_by_pass = {} + total_responses = len(tokenized_df) * pass_k + + with tqdm(total=total_responses, desc="Extracting responses", unit="response") as pbar: + for idx, row in tokenized_df.iterrows(): + for pass_num in range(pass_k): + response = responses_by_pass.get((idx, pass_num), {}) + response_id = [] + if "error" not in response and "output_ids" in response: + try: + # SGLang returns the generated token IDs in the + # 'output_ids' field + response_id = response["output_ids"] + except Exception as e: + logger.warning( + f"Failed to extract response for sample {idx}, pass {pass_num}: {e}") + response_ids_by_pass[(idx, pass_num)] = response_id + pbar.update(1) + + logger.info("Response output_ids extraction complete") + return response_ids_by_pass + + +def detokenize_output_ids( + response_ids_by_pass: Dict[tuple, List[int]], pass_k: int) -> Dict[tuple, str]: + """Detokenize output_ids back to text using AutoTokenizer for all passes. 
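+    Decoding skips special tokens; a failed decode yields an empty string
+    for that (sample_id, pass_num) entry.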
+ + Args: + response_ids_by_pass: Dict mapping (sample_id, pass_num) to output_ids + pass_k: Number of passes per sample + + Returns: + Dict mapping (sample_id, pass_num) to detokenized text + """ + logger.info("Detokenizing output_ids to text...") + + tokenizer = get_tokenizer() + detokenized_texts_by_pass = {} + + for (sample_id, pass_num), token_ids in tqdm( + response_ids_by_pass.items(), desc="Detokenizing outputs", unit="output"): + try: + # Detokenize the token IDs back to text + text = tokenizer.decode(token_ids, skip_special_tokens=True) + detokenized_texts_by_pass[(sample_id, pass_num)] = text + except Exception as e: + logger.warning( + f"Failed to detokenize output for sample {sample_id}, pass {pass_num}: {e}") + detokenized_texts_by_pass[(sample_id, pass_num)] = "" + + logger.info("Output detokenization complete") + return detokenized_texts_by_pass + + +def save_responses(responses_by_pass: Dict[tuple, Dict[str, Any]], + response_ids_by_pass: Dict[tuple, List[int]], + detokenized_texts_by_pass: Dict[tuple, str], + latencies_by_pass: Dict[tuple, float], + tokenized_df: pd.DataFrame, pass_k: int, output_file: str = None) -> pd.DataFrame: + """Save all responses to DataFrame and optionally to pickle file. + + Args: + responses_by_pass: Dict mapping (sample_id, pass_num) to response + response_ids_by_pass: Dict mapping (sample_id, pass_num) to output_ids + detokenized_texts_by_pass: Dict mapping (sample_id, pass_num) to text + latencies_by_pass: Dict mapping (sample_id, pass_num) to latency + tokenized_df: Original DataFrame with samples + pass_k: Number of passes per sample + output_file: Optional output pickle file + + Returns: + DataFrame with columns for each pass (e.g., model_output_0, model_output_1, ...) + """ + logger.info("Processing responses and updating DataFrame...") + + # Work with the original DataFrame + result_df = tokenized_df.copy() + + # Create columns for each pass with _0, _1, _2, ... 
suffixes + for pass_num in range(pass_k): + # Lists to store data for this pass + model_outputs = [] + tok_model_outputs = [] + tok_model_output_lens = [] + infer_times = [] + + for idx in tokenized_df.index: + key = (idx, pass_num) + detokenized_text = detokenized_texts_by_pass.get(key, "") + response_ids = response_ids_by_pass.get(key, []) + latency = latencies_by_pass.get(key, None) + + model_outputs.append(detokenized_text) + tok_model_outputs.append(response_ids) + tok_model_output_lens.append(len(response_ids)) + infer_times.append(latency) + + # Add columns with suffixes + result_df[f'model_output_{pass_num}'] = model_outputs + result_df[f'tok_model_output_{pass_num}'] = tok_model_outputs + result_df[f'tok_model_output_len_{pass_num}'] = tok_model_output_lens + result_df[f'infer_time_{pass_num}'] = infer_times + + # Calculate output token lengths for logging + all_output_token_lengths = [] + for idx in tokenized_df.index: + for pass_num in range(pass_k): + key = (idx, pass_num) + response = responses_by_pass.get(key, {}) + response_ids = response_ids_by_pass.get(key, []) + try: + output_token_length = response.get( + "meta_info", {}).get( + "completion_tokens", len(response_ids)) + all_output_token_lengths.append(output_token_length) + except Exception as e: + logger.warning( + f"Failed to calculate output tokens for sample {idx}, pass {pass_num}: {e}") + all_output_token_lengths.append(len(response_ids)) + + logger.info(f"Updated DataFrame with shape: {result_df.shape}") + new_columns = [ + f'model_output_{i}, tok_model_output_{i}, tok_model_output_len_{i}, infer_time_{i}' for i in range(pass_k)] + logger.info(f"Added columns for {pass_k} passes: {', '.join(new_columns)}") + if all_output_token_lengths: + logger.info( + f"Average output token length: {sum(all_output_token_lengths)/len(all_output_token_lengths):.1f}") + + # Save to pickle file if output_file is provided + if output_file: + logger.info(f"Saving responses to {output_file}...") + result_df.to_pickle(output_file) + logger.info(f"Responses saved to {output_file}") + + return result_df + + +def process_requests(tokenized_df: pd.DataFrame, server_url: str, + max_samples: int = None, max_tokens: int = 100, + max_concurrency: int = 128, output_file: str = None, temperature: float = 0.001, top_k: int = 1, top_p: float = 1.0, + timeout: int = 1200, pass_k: int = 1) -> pd.DataFrame: + """Main processing function that handles requests and response extraction. 
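+    Pipeline: optionally limit samples, send all requests in parallel,
+    extract output_ids, detokenize them, and save the results.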
+ + Args: + pass_k: Number of inference passes per sample for pass@k strategy + """ + + # Step 1: Limit samples if specified + if max_samples is not None: + tokenized_df = tokenized_df.head(max_samples) + logger.info(f"Limited to first {max_samples} samples") + + # Step 2: Send all requests in parallel (k passes per sample) + responses_by_pass, latencies_by_pass = send_requests_parallel( + tokenized_df, + server_url, + max_tokens, + max_concurrency, + temperature, + top_k, + top_p, + timeout, + pass_k) + + # Step 3: Extract response output_ids for all passes + response_ids_by_pass = extract_response_ids( + responses_by_pass, tokenized_df, pass_k) + + # Step 4: Detokenize output_ids to text for model_output for all passes + detokenized_texts_by_pass = detokenize_output_ids( + response_ids_by_pass, pass_k) + + # Step 5: Save all results and return DataFrame + result_df = save_responses( + responses_by_pass, + response_ids_by_pass, + detokenized_texts_by_pass, + latencies_by_pass, + tokenized_df, + pass_k, + output_file) + + return result_df + + +def main(): + parser = argparse.ArgumentParser( + description="Send pre-tokenized requests to SGLang server") + parser.add_argument("--input-tokens", required=True, + help="Path to pickle file containing pre-tokenized data from harmony-tokens.py") + parser.add_argument("--server-url", default="http://localhost:30000", + help="SGLang server URL (default: http://localhost:30000)") + parser.add_argument("--max-samples", type=int, default=None, + help="Maximum number of samples to process (default: all)") + parser.add_argument("--max-tokens", type=int, default=100, + help="Maximum tokens to generate per request") + parser.add_argument("--max-concurrency", type=int, default=256, + help="Maximum number of concurrent requests (default: 256)") + parser.add_argument("--output", default=None, + help="Output pickle file for responses (optional)") + parser.add_argument("--pass-k", type=int, default=1, + help="Number of inference passes per sample for pass@k strategy (default: 1)") + parser.add_argument("--temperature", type=float, default=0.001, + help="Temperature for sampling (default: 0.001)") + parser.add_argument("--top-k", type=int, default=1, + help="Top-k for sampling (default: 1)") + parser.add_argument("--top-p", type=float, default=1.0, + help="Top-p for sampling (default: 1.0)") + parser.add_argument("--timeout", type=int, default=1200, + help="Timeout for requests (default: 1200)") + + args = parser.parse_args() + + # Test connection + logger.info(f"Testing server connection to {args.server_url}...") + test_client = SGLangClient( + server_url=args.server_url, + temperature=args.temperature, + top_k=args.top_k, + top_p=args.top_p, + timeout=args.timeout) + + test_response = test_client.send_request(input_ids=[1, 2, 3], max_tokens=5) + if "error" in test_response: + logger.error(f"Server connection failed: {test_response['error']}") + logger.error("Make sure your SGLang server is running. 
Try:") + logger.error( + " python -m sglang.launch_server --model-path openai/gpt-oss-120b --mem-fraction-static 0.98 --tp 8") + return + logger.info("Server connection successful") + + # Load pre-tokenized data + tokenized_df = load_tokenized_data(args.input_tokens) + + # Process requests and get result DataFrame + result_df = process_requests(tokenized_df, args.server_url, + max_samples=args.max_samples, + max_tokens=args.max_tokens, + max_concurrency=args.max_concurrency, + output_file=args.output, + temperature=args.temperature, + top_k=args.top_k, + top_p=args.top_p, + timeout=args.timeout, + pass_k=args.pass_k) + + # Print summary + logger.info(f"\nProcessing completed:") + logger.info(f" - Total samples processed: {len(result_df)}") + logger.info(f" - Number of passes per sample: {args.pass_k}") + logger.info( + f" - Average input token length: {result_df['tok_input_len'].mean():.1f}") + + # Calculate average output length across all passes + if args.pass_k == 1: + avg_output_len = result_df['tok_model_output_len_0'].mean() + logger.info(f" - Average output token length: {avg_output_len:.1f}") + else: + all_output_lens = [] + for i in range(args.pass_k): + all_output_lens.extend( + result_df[f'tok_model_output_len_{i}'].tolist()) + avg_output_len = sum(all_output_lens) / \ + len(all_output_lens) if all_output_lens else 0 + logger.info( + f" - Average output token length (across all passes): {avg_output_len:.1f}") + + if args.output: + logger.info(f" - Results saved to: {args.output}") + else: + logger.info(" - Results returned as DataFrame (not saved to file)") + + +if __name__ == "__main__": + main() diff --git a/language/gpt-oss-120b/sglang/run_server.sh b/language/gpt-oss-120b/sglang/run_server.sh new file mode 100755 index 0000000000..f988ea6b5f --- /dev/null +++ b/language/gpt-oss-120b/sglang/run_server.sh @@ -0,0 +1,57 @@ +#!/bin/bash + +pip install -r requirements.txt + +dp=1 +model_path=openai/gpt-oss-120b +eagle_path="" +stream_interval=500 +extra_args="" + +while [[ $# -gt 0 ]]; do + case $1 in + --dp) + dp=$2 + shift 2 + ;; + --model_path) + model_path=$2 + shift 2 + ;; + --eagle_path) + eagle_path=$2 + shift 2 + ;; + --stream_interval) + stream_interval=$2 + shift 2 + ;; + *) + extra_args="$extra_args $1" + shift 1 + ;; + esac +done + +args=" --model-path $model_path \ + --host 0.0.0.0 \ + --data-parallel-size=$dp \ + --max-running-requests $((dp * 512)) \ + --mem-fraction-static 0.85 \ + --chunked-prefill-size 16384 \ + --ep-size=1 \ + --enable-metrics \ + --stream-interval $stream_interval " + +if [ -n "$eagle_path" ]; then + args="$args --speculative-draft-model-path $eagle_path \ + --speculative-algorithm EAGLE3" +fi + +# --speculative-num-steps 1 \ +# --speculative-eagle-topk 1 \ +# --speculative-num-draft-tokens 3 \ + + +set -x; +python3 -m sglang.launch_server $args $extra_args diff --git a/language/gpt-oss-120b/submodules/LiveCodeBench b/language/gpt-oss-120b/submodules/LiveCodeBench new file mode 120000 index 0000000000..d1e5c66592 --- /dev/null +++ b/language/gpt-oss-120b/submodules/LiveCodeBench @@ -0,0 +1 @@ +../../deepseek-r1/submodules/LiveCodeBench \ No newline at end of file diff --git a/language/gpt-oss-120b/submodules/prm800k b/language/gpt-oss-120b/submodules/prm800k new file mode 120000 index 0000000000..1b078c3842 --- /dev/null +++ b/language/gpt-oss-120b/submodules/prm800k @@ -0,0 +1 @@ +../../deepseek-r1/submodules/prm800k \ No newline at end of file diff --git a/language/gpt-oss-120b/utils/__init__.py b/language/gpt-oss-120b/utils/__init__.py new 
file mode 100644 index 0000000000..9b3b53963d --- /dev/null +++ b/language/gpt-oss-120b/utils/__init__.py @@ -0,0 +1,9 @@ +#!/usr/bin/env python3 +"""Utilities for gpt-oss MLPerf integration.""" + +from .tokenization import StandardTokenizer, load_tokenized_dataset + +__all__ = [ + "StandardTokenizer", + "load_tokenized_dataset", +] diff --git a/language/gpt-oss-120b/utils/tokenization.py b/language/gpt-oss-120b/utils/tokenization.py new file mode 100644 index 0000000000..533f5c6691 --- /dev/null +++ b/language/gpt-oss-120b/utils/tokenization.py @@ -0,0 +1,162 @@ +#!/usr/bin/env python3 +"""Tokenization utilities for gpt-oss.""" + +import logging +from typing import List, Dict, Any, Optional +import numpy as np +import pandas as pd +from transformers import AutoTokenizer + +logger = logging.getLogger(__name__) + +MODEL_NAME = "openai/gpt-oss-120b" + + +class StandardTokenizer: + """Standard tokenizer wrapper for gpt-oss model.""" + + def __init__(self, model_name: str = MODEL_NAME): + """Initialize the tokenizer. + + Args: + model_name: HuggingFace model name or path + """ + self.model_name = model_name + self.tokenizer = None + logger.info(f"Initializing tokenizer for {model_name}") + + def load(self) -> None: + """Load the tokenizer.""" + if self.tokenizer is None: + logger.info(f"Loading tokenizer from {self.model_name}") + self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) + logger.info("Tokenizer loaded successfully") + + def encode(self, text: str) -> List[int]: + """Encode text to token IDs. + + Args: + text: Input text + + Returns: + List of token IDs + """ + if self.tokenizer is None: + self.load() + return self.tokenizer.encode(text) + + def decode(self, token_ids: List[int], + skip_special_tokens: bool = True) -> str: + """Decode token IDs to text. + + Args: + token_ids: List of token IDs + skip_special_tokens: Whether to skip special tokens + + Returns: + Decoded text + """ + if self.tokenizer is None: + self.load() + return self.tokenizer.decode( + token_ids, skip_special_tokens=skip_special_tokens) + + def __call__(self, text: str) -> List[int]: + """Encode text to token IDs (callable interface). + + Args: + text: Input text + + Returns: + List of token IDs + """ + return self.encode(text) + + +def load_tokenized_dataset( + dataset_path: str, + max_samples: Optional[int] = None +) -> Dict[str, Any]: + """Load a tokenized dataset from parquet or pickle file. 
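+    The file format is chosen from the extension and auto-detected otherwise;
+    numpy arrays in object columns are converted to plain Python lists.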
+ + Args: + dataset_path: Path to the parquet or pickle file containing tokenized data + max_samples: Maximum number of samples to load (None for all) + + Returns: + Dictionary containing: + - prompts: List of tokenized prompts + - dataframe: Original DataFrame + - metadata: Additional metadata + """ + logger.info(f"Loading tokenized dataset from {dataset_path}") + + # Load DataFrame based on file extension + if dataset_path.endswith('.parquet'): + df = pd.read_parquet(dataset_path) + logger.info(f"Loaded Parquet DataFrame with shape: {df.shape}") + elif dataset_path.endswith('.pkl') or dataset_path.endswith('.pickle'): + df = pd.read_pickle(dataset_path) + logger.info(f"Loaded Pickle DataFrame with shape: {df.shape}") + else: + # Try to auto-detect based on file content + try: + df = pd.read_parquet(dataset_path) + logger.info( + f"Auto-detected Parquet format, loaded DataFrame with shape: {df.shape}") + except Exception: + df = pd.read_pickle(dataset_path) + logger.info( + f"Auto-detected Pickle format, loaded DataFrame with shape: {df.shape}") + + # Convert numpy arrays to native Python types for JSON serialization + for col in df.columns: + # Check if column contains numpy arrays + if df[col].dtype == object: + df[col] = df[col].apply( + lambda x: x.tolist() if isinstance(x, np.ndarray) else x + ) + + # Limit samples if specified + if max_samples is not None: + df = df.head(max_samples) + logger.info(f"Limited to {max_samples} samples") + + # Extract tokenized prompts - support both column names + if 'tok_input' in df.columns: # pre-v4.0 + token_col = 'tok_input' + elif 'input_tokens' in df.columns: # v4.0+ + token_col = 'input_tokens' + else: + raise ValueError( + "Dataset must have 'tok_input' or 'input_tokens' column with tokenized prompts") + + # Verify tokenization + failed_mask = df[token_col].isna() + if failed_mask.any(): + failed_count = failed_mask.sum() + logger.error(f"Found {failed_count} samples with failed tokenization") + raise ValueError(f"{failed_count} samples have invalid tokenization") + + prompts = df[token_col].tolist() + logger.info(f"Loaded {len(prompts)} tokenized prompts") + + # Log statistics + prompt_lengths = [len(p) for p in prompts] + logger.info( + f"Prompt length stats - " + f"min: {min(prompt_lengths)}, " + f"max: {max(prompt_lengths)}, " + f"mean: {sum(prompt_lengths)/len(prompt_lengths):.1f}" + ) + + return { + "prompts": prompts, + "dataframe": df, + "metadata": { + "num_samples": len(prompts), + "min_length": min(prompt_lengths), + "max_length": max(prompt_lengths), + "mean_length": sum(prompt_lengths) / len(prompt_lengths) + } + }
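+
+
+if __name__ == "__main__":
+    # Minimal smoke-test sketch (illustrative only, not part of the MLPerf
+    # flow): load a tokenized dataset and print basic statistics. The default
+    # path below is a placeholder.
+    import sys
+
+    path = sys.argv[1] if len(sys.argv) > 1 else "data/accuracy_eval_tokenized.parquet"
+    info = load_tokenized_dataset(path, max_samples=8)
+    meta = info["metadata"]
+    print(f"Loaded {meta['num_samples']} prompts "
+          f"(min {meta['min_length']}, max {meta['max_length']}, "
+          f"mean {meta['mean_length']:.1f} tokens)")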