diff --git a/docs/source/operation-guide/llm-track/commands.md b/docs/source/operation-guide/llm-track/commands.md index 6f63c7c..dde277a 100644 --- a/docs/source/operation-guide/llm-track/commands.md +++ b/docs/source/operation-guide/llm-track/commands.md @@ -76,6 +76,18 @@ python scripts/generate_kernel_and_verify.py \ --model-name claude-opus-4-6 ``` +### Third-Party Providers + +Use `--base-url` to connect to any OpenAI-compatible provider. + +```bash +python scripts/generate_kernel_and_verify.py \ + --server-type openai \ + --model-name \ + --base-url \ + --api-key +``` + ## Advanced Options ### Enable Reflection diff --git a/docs/source/operation-guide/llm-track/parameters.md b/docs/source/operation-guide/llm-track/parameters.md index 68f9880..6c8b97c 100644 --- a/docs/source/operation-guide/llm-track/parameters.md +++ b/docs/source/operation-guide/llm-track/parameters.md @@ -15,6 +15,8 @@ LLM Track command-line parameters. |-----------|---------|-------------| | `--op-name` | All | Test a single operator (e.g., `aten::add`) | | `--single-test` | Off | Randomly select 1 operator for quick testing | +| `--base-url` | `http://localhost:8000/v1` | API base URL for OpenAI-compatible providers (e.g., DashScope, vLLM server) | +| `--api-key` | Env var | API key (overrides `OPENAI_API_KEY` / `ANTHROPIC_API_KEY` env var) | | `--dataset` | Auto | Dataset: `KernelGenBench`, `KernelGenBench-aten`, `KernelGenBench-vllm`, `KernelGenBench-cublas` | | `--max-rounds` | 10 | Number of Pass@K rounds | | `--device-count` | 8 | Number of GPUs for verification | @@ -63,6 +65,24 @@ Number of independent kernel samples to generate: - Higher values → better Pass@K coverage - Higher cost → more API calls +### --base-url + +Specify a custom API endpoint for OpenAI-compatible providers: + +```bash +--server-type openai --model-name --base-url +``` + +### --api-key + +Override the default API key from environment variables: + +```bash +--api-key +``` + +If not set, reads from `OPENAI_API_KEY` or `ANTHROPIC_API_KEY` depending on `--server-type`. + ## Output Results saved to `output/pass_at_k//`: diff --git a/scripts/generate_kernel_and_verify.py b/scripts/generate_kernel_and_verify.py index 510ef2f..2c1ec05 100644 --- a/scripts/generate_kernel_and_verify.py +++ b/scripts/generate_kernel_and_verify.py @@ -740,6 +740,8 @@ def main(): # Generation config parser.add_argument("--server-type", type=str, default="openai") parser.add_argument("--model-name", type=str, default="gpt-4o-mini") + parser.add_argument("--base-url", type=str, default=None, help="API base URL (for OpenAI-compatible providers)") + parser.add_argument("--api-key", type=str, default=None, help="API key (overrides OPENAI_API_KEY / ANTHROPIC_API_KEY env var)") parser.add_argument("--temperature", type=float, default=0.8) parser.add_argument("--max-tokens", type=int, default=16384) parser.add_argument("--num-workers", type=int, default=150) @@ -793,10 +795,18 @@ def main(): run_name = output_dir.name # Create generation config + # Set API key in env if provided + if args.api_key: + os.environ["OPENAI_API_KEY"] = args.api_key + os.environ["ANTHROPIC_API_KEY"] = args.api_key + + base_url = args.base_url if args.base_url else "http://localhost:8000/v1" + gen_config = GenerationConfig( run_name="", server_type=args.server_type, model_name=args.model_name, + base_url=base_url, temperature=args.temperature, max_tokens=args.max_tokens, num_workers=args.num_workers,