
Commit 8136ad6

SDG Pipeline (#1215)
* Add SDG pipeline
* pre-commit
* comment resolved
* check for existing loop before calling asyncio.run
* handle edge case for languagefilter
* comment resolved
* Add readme for sdg tutorial
* remove the copyright for this blank file
* remove quickstart
* Add test file to cicd-main
* resolve pr comments
* pr comment resolved
* add RayClient
* ruff
* ruff
* adding configurable generation parameters
* change to public model

Signed-off-by: Ao Tang <[email protected]>
1 parent 55ddb34 commit 8136ad6

File tree

15 files changed: +1763 −9 lines


.github/workflows/cicd-main.yml

Lines changed: 1 addition & 1 deletion
@@ -127,7 +127,7 @@ jobs:
       matrix:
         os: [ubuntu-latest]
         python-version: ["3.10", "3.12"]
-        folder: ["backends", "core", "models", "pipelines", "stages-audio", "stages-common", "stages-deduplication", "stages-image", "stages-text", "stages-video", "tasks", "utils"]
+        folder: ["backends", "core", "models", "pipelines", "stages-audio", "stages-common", "stages-deduplication", "stages-image", "stages-synthetic", "stages-text", "stages-video", "tasks", "utils"]
     needs: [pre-flight, cicd-wait-in-queue]
     runs-on: ${{ matrix.os }}
     name: Unit_Test_${{ matrix.folder}}_CPU_python-${{ matrix.python-version }}
nemo_curator/models/client/__init__.py

Lines changed: 23 additions & 0 deletions

@@ -0,0 +1,23 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .llm_client import AsyncLLMClient, LLMClient
from .openai_client import AsyncOpenAIClient, OpenAIClient

__all__ = [
    "AsyncLLMClient",
    "AsyncOpenAIClient",
    "LLMClient",
    "OpenAIClient",
]
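
Note: assuming this __init__.py sits at the package root implied by the absolute import in openai_client.py below (nemo_curator.models.client), the re-exports let callers pull the public classes straight from the package. A minimal sketch with placeholder arguments:

# Sketch only -- the package path is inferred from the imports in this commit,
# and the api_key value is a placeholder.
from nemo_curator.models.client import AsyncOpenAIClient, OpenAIClient

sync_client = OpenAIClient(api_key="YOUR_API_KEY")
async_client = AsyncOpenAIClient(max_concurrent_requests=5, api_key="YOUR_API_KEY")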
nemo_curator/models/client/llm_client.py

Lines changed: 213 additions & 0 deletions

@@ -0,0 +1,213 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import secrets
from abc import ABC, abstractmethod
from collections.abc import Iterable
from dataclasses import dataclass

from loguru import logger


class ConversationFormatter(ABC):
    """
    Represents a way of formatting a conversation with an LLM
    such that it can respond appropriately
    """

    @abstractmethod
    def format_conversation(self, conv: list[dict]) -> str:
        msg = "format_conversation must be implemented by subclasses"
        raise NotImplementedError(msg)


@dataclass
class GenerationConfig:
    """Configuration class for LLM generation parameters."""

    max_tokens: int | None = 2048
    n: int | None = 1
    seed: int | None = 0
    stop: str | None | list[str] = None
    stream: bool = False
    temperature: float | None = 0.0
    top_k: int | None = None
    top_p: float | None = 0.95


class LLMClient(ABC):
    """
    Interface representing a client connecting to an LLM inference server
    and making requests synchronously
    """

    @abstractmethod
    def setup(self) -> None:
        """
        Setup the client.
        """

    @abstractmethod
    def query_model(
        self,
        *,
        messages: Iterable,
        model: str,
        conversation_formatter: ConversationFormatter | None = None,
        generation_config: GenerationConfig | dict | None = None,
    ) -> list[str]:
        msg = "Subclass of LLMClient must implement 'query_model'"
        raise NotImplementedError(msg)


class AsyncLLMClient(ABC):
    """
    Interface representing a client connecting to an LLM inference server
    and making requests asynchronously
    """

    def __init__(self, max_concurrent_requests: int = 5, max_retries: int = 3, base_delay: float = 1.0):
        """
        Initialize the async client with concurrency and retry settings.

        Args:
            max_concurrent_requests: Maximum number of concurrent requests
            max_retries: Maximum number of retry attempts for rate-limited requests
            base_delay: Base delay for exponential backoff (in seconds)
        """
        self.max_concurrent_requests = max_concurrent_requests
        self.max_retries = max_retries
        self.base_delay = base_delay
        # Semaphore for controlling concurrent requests
        self._semaphore = None
        self._semaphore_loop = None

    @abstractmethod
    def setup(self) -> None:
        """
        Setup the client.
        """

    @abstractmethod
    async def _query_model_impl(
        self,
        *,
        messages: Iterable,
        model: str,
        conversation_formatter: ConversationFormatter | None = None,
        generation_config: GenerationConfig | dict | None = None,
    ) -> list[str]:
        """
        Internal implementation of query_model without retry/concurrency logic.
        Subclasses should implement this method instead of query_model.
        """
        msg = "Subclass of AsyncLLMClient must implement '_query_model_impl'"
        raise NotImplementedError(msg)

    async def query_model(  # noqa: C901, PLR0912
        self,
        *,
        messages: Iterable,
        model: str,
        conversation_formatter: ConversationFormatter | None = None,
        generation_config: GenerationConfig | dict | None = None,
    ) -> list[str]:
        """
        Query the model with automatic retry and concurrency control.
        """
        # Use default config if none provided
        if generation_config is None:
            generation_config = GenerationConfig()
        elif isinstance(generation_config, dict):
            generation_config = GenerationConfig(**generation_config)

        # Initialize semaphore if not already done or if we're in a different event loop
        current_loop = asyncio.get_running_loop()
        if self._semaphore is None or self._semaphore_loop != current_loop:
            self._semaphore = asyncio.Semaphore(self.max_concurrent_requests)
            self._semaphore_loop = current_loop

        async with self._semaphore:  # Limit concurrent requests
            # Retry logic with exponential backoff
            last_exception = None

            for attempt in range(self.max_retries + 1):
                # Check if this is a retry attempt and if we should delay
                if attempt > 0 and last_exception:
                    is_rate_limit = "429" in str(last_exception) or "rate" in str(last_exception).lower()
                    is_connection_error = (
                        "connection" in str(last_exception).lower()
                        or "ReadError" in str(last_exception)
                        or "BrokenResourceError" in str(last_exception)
                        or "APIConnectionError" in str(last_exception)
                        or "httpx.ReadError" in str(last_exception)
                    )

                    if is_rate_limit or is_connection_error:
                        if is_rate_limit:
                            logger.warning(
                                f"Rate limit error (429) detected. Attempt {attempt + 1}/{self.max_retries + 1}. Retrying in {self.base_delay * (2 ** (attempt - 1)):.1f}s..."
                            )
                        else:
                            logger.warning(
                                f"Connection error detected. Attempt {attempt + 1}/{self.max_retries + 1}. Retrying in {self.base_delay * (2 ** (attempt - 1)):.1f}s..."
                            )
                            logger.warning(f"Error details: {str(last_exception)[:200]}...")
                            if "localhost" in str(last_exception):
                                logger.warning(
                                    "Local API server issue - consider reducing --max-concurrent-requests or checking server resources"
                                )

                        # Exponential backoff with jitter
                        delay = self.base_delay * (2 ** (attempt - 1)) + secrets.randbelow(100) / 100.0
                        await asyncio.sleep(delay)
                    else:
                        # Re-raise if not a retryable error
                        raise last_exception

                # Attempt the query
                try:
                    return await self._query_model_impl(
                        messages=messages,
                        model=model,
                        conversation_formatter=conversation_formatter,
                        generation_config=generation_config,
                    )
                except Exception as e:
                    last_exception = e
                    # If this is the last attempt, provide helpful error message
                    if attempt == self.max_retries:
                        if "connection" in str(e).lower() or "ReadError" in str(e):
                            logger.error(f"Connection error after {self.max_retries + 1} attempts!")
                            logger.error(f"Final error: {str(e)[:200]}...")
                            if "localhost" in str(e):
                                logger.error("Suggestions for local API server:")
                                logger.error("- Check if server is running and has sufficient resources")
                                logger.error("- Reduce concurrent requests: --max-concurrent-requests 1")
                                logger.error("- Increase timeout: --timeout 900")
                                logger.error("- Check server logs for memory/GPU issues")
                        raise
                    # Otherwise, continue to next iteration
                    continue

            # This line should never be reached due to the raise in the except block
            # but if we get here, re-raise the last exception
            if last_exception:
                raise last_exception

        # This should never be reached, but add explicit return for linter
        logger.warning(
            "Unexpected code path: AsyncLLMClient.query_model completed without returning a result or raising an exception"
        )
        return []
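
Note: subclasses implement only _query_model_impl; the inherited query_model adds the semaphore and exponential backoff (roughly base_delay * 2^(attempt-1) seconds plus up to 1 s of jitter, so about 1 s, 2 s, 4 s with the defaults). A minimal toy subclass, purely illustrative and not part of this commit (the EchoClient name and its canned reply are hypothetical):

# Hypothetical sketch -- EchoClient is not part of this commit.
import asyncio

class EchoClient(AsyncLLMClient):  # assumes AsyncLLMClient from the module above is importable
    def setup(self) -> None:
        pass  # nothing to initialize for this toy client

    async def _query_model_impl(
        self,
        *,
        messages,
        model,
        conversation_formatter=None,
        generation_config=None,
    ) -> list[str]:
        # Pretend the "model" simply echoes the last user message back.
        return [list(messages)[-1]["content"]]

async def main() -> None:
    client = EchoClient(max_concurrent_requests=2, max_retries=3, base_delay=1.0)
    client.setup()
    out = await client.query_model(
        messages=[{"role": "user", "content": "hello"}],
        model="echo",
    )
    print(out)  # ['hello']

# asyncio.run(main())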
nemo_curator/models/client/openai_client.py

Lines changed: 139 additions & 0 deletions

@@ -0,0 +1,139 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from collections.abc import Iterable

from openai import AsyncOpenAI, OpenAI

from nemo_curator.models.client.llm_client import AsyncLLMClient, ConversationFormatter, GenerationConfig, LLMClient


class OpenAIClient(LLMClient):
    """
    A wrapper around OpenAI's Python client for querying models
    """

    def __init__(self, **kwargs) -> None:
        # Extract timeout if provided, default to 120 for backward compatibility
        self.timeout = kwargs.pop("timeout", 120)
        self.openai_kwargs = kwargs

    def setup(self) -> None:
        """
        Setup the client.
        """
        self.client = OpenAI(**self.openai_kwargs)

    def query_model(
        self,
        *,
        messages: Iterable,
        model: str,
        conversation_formatter: ConversationFormatter | None = None,
        generation_config: GenerationConfig | dict | None = None,
    ) -> list[str]:
        if conversation_formatter is not None:
            warnings.warn("conversation_formatter is not used in an OpenAIClient", stacklevel=2)

        # Use default config if none provided
        if generation_config is None:
            generation_config = GenerationConfig()
        elif isinstance(generation_config, dict):
            generation_config = GenerationConfig(**generation_config)

        if generation_config.top_k is not None:
            warnings.warn("top_k is not used in an OpenAIClient", stacklevel=2)

        response = self.client.chat.completions.create(
            messages=messages,
            model=model,
            max_tokens=generation_config.max_tokens,
            n=generation_config.n,
            seed=generation_config.seed,
            stop=generation_config.stop,
            stream=generation_config.stream,
            temperature=generation_config.temperature,
            top_p=generation_config.top_p,
            timeout=self.timeout,
        )

        return [choice.message.content for choice in response.choices]


class AsyncOpenAIClient(AsyncLLMClient):
    """
    A wrapper around OpenAI's Python async client for querying models
    """

    def __init__(
        self, max_concurrent_requests: int = 5, max_retries: int = 3, base_delay: float = 1.0, **kwargs
    ) -> None:
        """
        Initialize the AsyncOpenAI client.

        Args:
            max_concurrent_requests: Maximum number of concurrent requests
            max_retries: Maximum number of retry attempts for rate-limited requests
            base_delay: Base delay for exponential backoff (in seconds)
            **kwargs: Additional arguments passed to OpenAI client
        """
        super().__init__(max_concurrent_requests, max_retries, base_delay)
        # Extract timeout if provided, default to 120 for backward compatibility
        self.timeout = kwargs.pop("timeout", 120)
        self.openai_kwargs = kwargs

    def setup(self) -> None:
        """
        Setup the client.
        """
        self.client = AsyncOpenAI(**self.openai_kwargs)

    async def _query_model_impl(
        self,
        *,
        messages: Iterable,
        model: str,
        conversation_formatter: ConversationFormatter | None = None,
        generation_config: GenerationConfig | dict | None = None,
    ) -> list[str]:
        """
        Internal implementation of query_model without retry/concurrency logic.
        """
        if conversation_formatter is not None:
            warnings.warn("conversation_formatter is not used in an AsyncOpenAIClient", stacklevel=2)

        # Use default config if none provided
        if generation_config is None:
            generation_config = GenerationConfig()
        elif isinstance(generation_config, dict):
            generation_config = GenerationConfig(**generation_config)

        if generation_config.top_k is not None:
            warnings.warn("top_k is not used in an AsyncOpenAIClient", stacklevel=2)

        response = await self.client.chat.completions.create(
            messages=messages,
            model=model,
            max_tokens=generation_config.max_tokens,
            n=generation_config.n,
            seed=generation_config.seed,
            stop=generation_config.stop,
            stream=generation_config.stream,
            temperature=generation_config.temperature,
            top_p=generation_config.top_p,
            timeout=self.timeout,
        )

        return [choice.message.content for choice in response.choices]
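
Note: a rough usage sketch of the synchronous client against any OpenAI-compatible endpoint; the base_url, api_key, and model id below are placeholders, not values from this commit. GenerationConfig can also be passed as a plain dict, as handled in the code above.

# Hypothetical usage sketch -- endpoint, key, and model id are placeholders.
from nemo_curator.models.client.llm_client import GenerationConfig
from nemo_curator.models.client.openai_client import OpenAIClient

client = OpenAIClient(
    base_url="https://example-openai-compatible-endpoint/v1",
    api_key="YOUR_API_KEY",
    timeout=120,  # popped in __init__ and forwarded with every request
)
client.setup()  # instantiates the underlying OpenAI client

responses = client.query_model(
    messages=[{"role": "user", "content": "Write one sentence about data curation."}],
    model="your-model-id",
    generation_config=GenerationConfig(temperature=0.2, max_tokens=256),
)
print(responses[0])

The AsyncOpenAIClient works the same way except query_model is awaited, so several calls can be scheduled with asyncio.gather while the inherited semaphore caps in-flight requests at max_concurrent_requests.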

nemo_curator/stages/synthetic/__init__.py

Whitespace-only changes.
