Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions DOCS.md
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,11 @@ paths:
cleaned: "data/cleaned"
final: "data/final"

# LLM Provider configuration
llm:
# Provider selection: "vllm" or "api-endpoint"
provider: "api-endpoint"

# vllm: Configure VLLM server settings
vllm:
api_base: "http://localhost:8000/v1"
Expand All @@ -547,6 +552,15 @@ vllm:
max_retries: 3
retry_delay: 1.0

# API endpoint configuration
api-endpoint:
api_base: "https://api.llama.com/v1" # Optional base URL for API endpoint (null for default API)
api_key: "llama-api-key" # API key for API endpoint or compatible service (can use env var instead)
model: "Llama-4-Maverick-17B-128E-Instruct-FP8" # Default model to use
azure_api_version: "2024-06-01" # API version required for Azure OpenAI endpoints; set to null for other providers
max_retries: 3 # Number of retries for API calls
retry_delay: 1.0 # Initial delay between retries (seconds)

# generation: Content generation parameters
generation:
temperature: 0.7
Expand Down
26 changes: 23 additions & 3 deletions synthetic_data_kit/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,23 @@ def system_check(
console.print(f"API key source: {'Environment variable' if api_endpoint_key else 'Config file'}")

model = api_endpoint_config.get("model")

# Check for azure api version in environment variables
azure_api_version_key = os.environ.get('AZURE_API_VERSION')
console.print(f"AZURE_API_VERSION_KEY environment variable: {'Found' if azure_api_version_key else 'Not found'}")

# Set API VERSION key with priority: env var > config
azure_api_version = azure_api_version_key or api_endpoint_config.get("azure_api_version")
if azure_api_version:
console.print(f"API version source: {'Environment variable' if azure_api_version_key else 'Config file'}")

# Check API endpoint access
with console.status(f"Checking API endpoint access..."):
try:
# Try to import OpenAI
try:
from openai import OpenAI
from openai import AzureOpenAI
except ImportError:
console.print("L API endpoint package not installed", style="red")
console.print("Install with: pip install openai>=1.0.0", style="yellow")
Expand All @@ -105,10 +115,16 @@ def system_check(
client_kwargs['api_key'] = api_key
if api_base:
client_kwargs['base_url'] = api_base
if azure_api_version:
client_kwargs['api_version'] = azure_api_version

# Check API access
try:
client = OpenAI(**client_kwargs)
if azure_api_version:
client = AzureOpenAI(**client_kwargs)
else:
client = OpenAI(**client_kwargs)

# Try a simple models request to check connectivity
messages = [
{"role": "user", "content": "Hello"}
Expand All @@ -118,7 +134,9 @@ def system_check(
messages=messages,
temperature=0.1
)
console.print(f" API endpoint access confirmed", style="green")
console.print("API endpoint access confirmed.. API Responded with: ",
resp.choices[0].message.content, style="green")

if api_base:
console.print(f"Using custom API base: {api_base}", style="green")
console.print(f"Default model: {model}", style="green")
Expand All @@ -128,8 +146,10 @@ def system_check(
console.print(f"L Error connecting to API endpoint: {str(e)}", style="red")
if api_base:
console.print(f"Using custom API base: {api_base}", style="yellow")
if not api_key and not api_base:
if not api_key and api_base:
console.print("API key is required. Set in config.yaml or as API_ENDPOINT_KEY env var", style="yellow")
if not azure_api_version and api_base:
console.print("Azure API version is required. Set in config.yaml or as AZURE_API_VERSION env var", style="yellow")
return 1
except Exception as e:
console.print(f"L Error: {str(e)}", style="red")
Expand Down
1 change: 1 addition & 0 deletions synthetic_data_kit/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ api-endpoint:
api_base: "https://api.llama.com/v1" # Optional base URL for API endpoint (null for default API)
api_key: "llama-api-key" # API key for API endpoint or compatible service (can use env var instead)
model: "Llama-4-Maverick-17B-128E-Instruct-FP8" # Default model to use
azure_api_version: "2024-06-01" # API version needed for Azure OpenAI endpoints
max_retries: 3 # Number of retries for API calls
retry_delay: 1.0 # Initial delay between retries (seconds)

Expand Down
37 changes: 32 additions & 5 deletions synthetic_data_kit/models/llm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
# Try to import OpenAI, but handle case where it's not installed
try:
from openai import OpenAI
from openai import AzureOpenAI
from openai.types.chat import ChatCompletion
OPENAI_AVAILABLE = True
except ImportError:
Expand All @@ -35,6 +36,7 @@ def __init__(self,
api_base: Optional[str] = None,
api_key: Optional[str] = None,
model_name: Optional[str] = None,
azure_api_version: Optional[str] = None,
max_retries: Optional[int] = None,
retry_delay: Optional[float] = None):
"""Initialize an LLM client that supports multiple providers
Expand All @@ -45,6 +47,7 @@ def __init__(self,
api_base: Override API base URL from config
api_key: Override API key for API endpoint (only needed for 'api-endpoint' provider)
model_name: Override model name from config
azure_api_version: Override azure api version from config. Needed for Azure OpenAI endpoints
max_retries: Override max retries from config
retry_delay: Override retry delay from config
"""
Expand Down Expand Up @@ -74,7 +77,16 @@ def __init__(self,

if not self.api_key and not self.api_base: # Only require API key for official API
raise ValueError("API key is required for API endpoint provider. Set in config or API_ENDPOINT_KEY env var.")


# Check for azure api version in environment variables
azure_api_version_key = os.environ.get('AZURE_API_VERSION')
print(f"AZURE_API_VERSION_KEY environment variable: {'Found' if azure_api_version_key else 'Not found'}")

# Set API VERSION key with priority: CLI arg > env var > config
self.azure_api_version = azure_api_version or azure_api_version_key or api_endpoint_config.get("azure_api_version")
if self.azure_api_version:
print(f"API version source: {'Environment variable' if azure_api_version_key else 'Config file'}")

self.model = model_name or api_endpoint_config.get('model')
self.max_retries = max_retries or api_endpoint_config.get('max_retries')
self.retry_delay = retry_delay or api_endpoint_config.get('retry_delay')
Expand Down Expand Up @@ -115,8 +127,16 @@ def _init_openai_client(self):
if self.api_base:
print(f"Using API base URL: {self.api_base}")
client_kwargs['base_url'] = self.api_base

self.openai_client = OpenAI(**client_kwargs)

# Add Azure api version if provided (Needed for Azure OpenAI APIs)
print(f"Using API VERSION: {self.azure_api_version}")
if self.azure_api_version:
print(f"Using API base URL: {self.api_base}")
client_kwargs['api_version'] = self.azure_api_version
# The OpenAI SDK uses a separate AzureOpenAI client class for Azure endpoints
self.openai_client = AzureOpenAI(**client_kwargs)
else:
self.openai_client = OpenAI(**client_kwargs)

def _check_vllm_server(self) -> tuple:
"""Check if the VLLM server is running and accessible"""
Expand Down Expand Up @@ -353,6 +373,7 @@ async def _process_message_async(self,
"""Process a single message set asynchronously using the OpenAI API"""
try:
from openai import AsyncOpenAI
from openai import AsyncAzureOpenAI
except ImportError:
raise ImportError("The 'openai' package is required for this functionality. Please install it using 'pip install openai>=1.0.0'.")

Expand All @@ -362,8 +383,14 @@ async def _process_message_async(self,
client_kwargs['api_key'] = self.api_key
if self.api_base:
client_kwargs['base_url'] = self.api_base

async_client = AsyncOpenAI(**client_kwargs)

# Add Azure api version if provided (Needed for Azure OpenAI APIs)
if self.azure_api_version:
client_kwargs['api_version'] = self.azure_api_version
# The OpenAI SDK uses a separate AsyncAzureOpenAI client class for Azure endpoints
async_client = AsyncAzureOpenAI(**client_kwargs)
else:
async_client = AsyncOpenAI(**client_kwargs)

for attempt in range(self.max_retries):
try:
Expand Down
1 change: 1 addition & 0 deletions synthetic_data_kit/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def get_openai_config(config: Dict[str, Any]) -> Dict[str, Any]:
'api_base': None, # None means use default API base URL
'api_key': None, # None means use environment variables
'model': 'gpt-4o',
'azure_api_version': None,
'max_retries': 3,
'retry_delay': 1.0
})
Expand Down
23 changes: 22 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def create_api_config(provider="api-endpoint", api_key="mock-key", model="mock-m
"api_base": "https://api.together.xyz/v1",
"api_key": api_key,
"model": model,
"azure_api_version": "2024-06-01",
"max_retries": 3,
"retry_delay": 1,
},
Expand Down Expand Up @@ -232,8 +233,28 @@ def test_env():
@pytest.fixture
def patch_config(config_factory):
"""Patch the config loader to return a mock configuration."""
mock_config = config_factory.create_api_config()

with patch("synthetic_data_kit.utils.config.load_config") as mock_load_config:
mock_load_config.return_value = config_factory.create_api_config()
mock_load_config.return_value = mock_config
yield mock_load_config

@pytest.fixture
def patch_llm_client_config(config_factory):
"""Patch the config loader to return a mock configuration."""
mock_config = config_factory.create_api_config()

with patch("synthetic_data_kit.models.llm_client.load_config") as mock_load_config:
mock_load_config.return_value = mock_config
yield mock_load_config

@pytest.fixture
def patch_cli_config(config_factory):
"""Patch the config loader to return a mock configuration."""
mock_config = config_factory.create_api_config()

with patch("synthetic_data_kit.cli.load_config") as mock_load_config:
mock_load_config.return_value = mock_config
yield mock_load_config


Expand Down
24 changes: 21 additions & 3 deletions tests/functional/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_system_check_command_vllm(patch_config):


@pytest.mark.functional
def test_system_check_command_api_endpoint(patch_config, test_env):
def test_system_check_command_api_endpoint(patch_cli_config, test_env):
"""Test the system-check command with API endpoint provider."""
runner = CliRunner()

Expand All @@ -42,12 +42,30 @@ def test_system_check_command_api_endpoint(patch_config, test_env):
mock_client.models.list.return_value = ["mock-model"]
mock_openai.return_value = mock_client

# Get the mock config from the fixture
config = patch_cli_config.return_value
config["api-endpoint"]["azure_api_version"] = None

result = runner.invoke(app, ["system-check", "--provider", "api-endpoint"])

# Just check exit code, not specific message since it varies
assert result.exit_code == 0
mock_openai.assert_called_once()

@pytest.mark.functional
def test_system_check_command_azure_api_endpoint(patch_cli_config, test_env):
"""Test the system-check command with API endpoint provider."""
runner = CliRunner()

# Mock Azure OpenAI client
with patch("openai.AzureOpenAI") as mock_azure_openai:
mock_client = MagicMock()
mock_client.chat.completions.create.return_value = MagicMock(
choices=[MagicMock(message=MagicMock(content="Hello World!"))]
)
mock_azure_openai.return_value = mock_client

result = runner.invoke(app, ["system-check", "--provider", "api-endpoint"])
assert result.exit_code == 0
mock_azure_openai.assert_called_once()

@pytest.mark.functional
def test_ingest_command(patch_config):
Expand Down
6 changes: 5 additions & 1 deletion tests/unit/test_error_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,16 @@ def test_parse_qa_pairs_invalid_json():


@pytest.mark.unit
def test_llm_client_error_handling(patch_config, test_env):
def test_llm_client_error_handling(patch_llm_client_config, test_env):
"""Test error handling in LLM client."""
with patch("synthetic_data_kit.models.llm_client.OpenAI") as mock_openai:
# Setup mock to raise an exception
mock_openai.side_effect = Exception("API Error")

# Get the mock config from the fixture
config = patch_llm_client_config.return_value
config["api-endpoint"]["azure_api_version"] = None

# Should handle the exception gracefully
with pytest.raises(Exception) as excinfo:
LLMClient(provider="api-endpoint")
Expand Down
Loading