# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import secrets
from abc import ABC, abstractmethod
from collections.abc import Iterable
from dataclasses import dataclass

from loguru import logger


class ConversationFormatter(ABC):
    """
    Represents a way of formatting a conversation with an LLM
    such that it can respond appropriately.
    """

    @abstractmethod
    def format_conversation(self, conv: list[dict]) -> str:
        msg = "format_conversation must be implemented by subclasses"
        raise NotImplementedError(msg)


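# --- Illustrative example (not part of the original module) -----------------
# A minimal sketch of a concrete ConversationFormatter, assuming messages use
# the common chat format of dicts with "role" and "content" keys. The tag
# layout below is an arbitrary choice for illustration.
class SimpleRoleTagFormatter(ConversationFormatter):
    """Formats a conversation as newline-separated '<role>: <content>' lines."""

    def format_conversation(self, conv: list[dict]) -> str:
        # Render each turn, then leave a trailing assistant tag for the model
        # to complete.
        lines = [f"{turn['role']}: {turn['content']}" for turn in conv]
        lines.append("assistant:")
        return "\n".join(lines)

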
@dataclass
class GenerationConfig:
    """Configuration class for LLM generation parameters."""

    max_tokens: int | None = 2048
    n: int | None = 1
    seed: int | None = 0
    stop: str | list[str] | None = None
    stream: bool = False
    temperature: float | None = 0.0
    top_k: int | None = None
    top_p: float | None = 0.95


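# Illustrative note (not part of the original module): the clients below accept
# either a GenerationConfig instance or a plain dict with the same field names;
# AsyncLLMClient.query_model converts a dict via GenerationConfig(**generation_config),
# so these two hypothetical calls request the same sampling settings:
#
#     client.query_model(messages=msgs, model="my-model",
#                        generation_config=GenerationConfig(temperature=0.7, max_tokens=512))
#     client.query_model(messages=msgs, model="my-model",
#                        generation_config={"temperature": 0.7, "max_tokens": 512})

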
class LLMClient(ABC):
    """
    Interface representing a client connecting to an LLM inference server
    and making requests synchronously.
    """

    @abstractmethod
    def setup(self) -> None:
        """
        Set up the client.
        """

    @abstractmethod
    def query_model(
        self,
        *,
        messages: Iterable,
        model: str,
        conversation_formatter: ConversationFormatter | None = None,
        generation_config: GenerationConfig | dict | None = None,
    ) -> list[str]:
        msg = "Subclass of LLMClient must implement 'query_model'"
        raise NotImplementedError(msg)


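# --- Illustrative example (not part of the original module) -----------------
# A minimal, dependency-free sketch of an LLMClient subclass. It does not call
# a real inference server; it only shows how the formatter and generation
# config plug into the interface. All names here are hypothetical.
class EchoLLMClient(LLMClient):
    """Toy client that "responds" by echoing the formatted prompt back."""

    def setup(self) -> None:
        # A real client would create an HTTP session or SDK client here.
        pass

    def query_model(
        self,
        *,
        messages: Iterable,
        model: str,
        conversation_formatter: ConversationFormatter | None = None,
        generation_config: GenerationConfig | dict | None = None,
    ) -> list[str]:
        if isinstance(generation_config, dict):
            generation_config = GenerationConfig(**generation_config)
        generation_config = generation_config or GenerationConfig()
        prompt = (
            conversation_formatter.format_conversation(list(messages))
            if conversation_formatter
            else str(list(messages))
        )
        # Return one "completion" per requested sample.
        return [f"[{model}] {prompt}"] * (generation_config.n or 1)

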
class AsyncLLMClient(ABC):
    """
    Interface representing a client connecting to an LLM inference server
    and making requests asynchronously.
    """

    def __init__(self, max_concurrent_requests: int = 5, max_retries: int = 3, base_delay: float = 1.0):
        """
        Initialize the async client with concurrency and retry settings.

        Args:
            max_concurrent_requests: Maximum number of concurrent requests
            max_retries: Maximum number of retry attempts for rate-limited requests
            base_delay: Base delay for exponential backoff (in seconds)
        """
        self.max_concurrent_requests = max_concurrent_requests
        self.max_retries = max_retries
        self.base_delay = base_delay
        # Semaphore for controlling concurrent requests
        self._semaphore = None
        self._semaphore_loop = None

    @abstractmethod
    def setup(self) -> None:
        """
        Set up the client.
        """

    @abstractmethod
    async def _query_model_impl(
        self,
        *,
        messages: Iterable,
        model: str,
        conversation_formatter: ConversationFormatter | None = None,
        generation_config: GenerationConfig | dict | None = None,
    ) -> list[str]:
        """
        Internal implementation of query_model without retry/concurrency logic.
        Subclasses should implement this method instead of query_model.
        """
        msg = "Subclass of AsyncLLMClient must implement '_query_model_impl'"
        raise NotImplementedError(msg)

    async def query_model(  # noqa: C901, PLR0912
        self,
        *,
        messages: Iterable,
        model: str,
        conversation_formatter: ConversationFormatter | None = None,
        generation_config: GenerationConfig | dict | None = None,
    ) -> list[str]:
        """
        Query the model with automatic retry and concurrency control.
        """
        # Use default config if none provided
        if generation_config is None:
            generation_config = GenerationConfig()
        elif isinstance(generation_config, dict):
            generation_config = GenerationConfig(**generation_config)

        # Initialize semaphore if not already done or if we're in a different event loop
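        # (note) an asyncio.Semaphore becomes tied to the event loop it is used
        # with, so one left over from a previous asyncio.run() call cannot be
        # reused; recreating it per loop keeps the client usable across loops.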
        current_loop = asyncio.get_running_loop()
        if self._semaphore is None or self._semaphore_loop != current_loop:
            self._semaphore = asyncio.Semaphore(self.max_concurrent_requests)
            self._semaphore_loop = current_loop

        async with self._semaphore:  # Limit concurrent requests
            # Retry logic with exponential backoff
            last_exception = None

            for attempt in range(self.max_retries + 1):
                # Check if this is a retry attempt and if we should delay
                if attempt > 0 and last_exception:
                    is_rate_limit = "429" in str(last_exception) or "rate" in str(last_exception).lower()
                    is_connection_error = (
                        "connection" in str(last_exception).lower()
                        or "ReadError" in str(last_exception)
                        or "BrokenResourceError" in str(last_exception)
                        or "APIConnectionError" in str(last_exception)
                        or "httpx.ReadError" in str(last_exception)
                    )

                    if is_rate_limit or is_connection_error:
                        if is_rate_limit:
                            logger.warning(
                                f"Rate limit error (429) detected. Attempt {attempt + 1}/{self.max_retries + 1}. Retrying in {self.base_delay * (2 ** (attempt - 1)):.1f}s..."
                            )
                        else:
                            logger.warning(
                                f"Connection error detected. Attempt {attempt + 1}/{self.max_retries + 1}. Retrying in {self.base_delay * (2 ** (attempt - 1)):.1f}s..."
                            )
                        logger.warning(f"Error details: {str(last_exception)[:200]}...")
                        if "localhost" in str(last_exception):
                            logger.warning(
                                "Local API server issue - consider reducing --max-concurrent-requests or checking server resources"
                            )

                        # Exponential backoff with jitter
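                        # For example, with the default base_delay of 1.0s the
                        # successive retries wait roughly 1s, 2s, then 4s, plus
                        # up to ~1s of random jitter to avoid retry storms.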
                        delay = self.base_delay * (2 ** (attempt - 1)) + secrets.randbelow(100) / 100.0
                        await asyncio.sleep(delay)
                    else:
                        # Re-raise if not a retryable error
                        raise last_exception

                # Attempt the query
                try:
                    return await self._query_model_impl(
                        messages=messages,
                        model=model,
                        conversation_formatter=conversation_formatter,
                        generation_config=generation_config,
                    )
                except Exception as e:
                    last_exception = e
                    # If this is the last attempt, provide a helpful error message
                    if attempt == self.max_retries:
                        if "connection" in str(e).lower() or "ReadError" in str(e):
                            logger.error(f"Connection error after {self.max_retries + 1} attempts!")
                            logger.error(f"Final error: {str(e)[:200]}...")
                            if "localhost" in str(e):
                                logger.error("Suggestions for local API server:")
                                logger.error("- Check if server is running and has sufficient resources")
                                logger.error("- Reduce concurrent requests: --max-concurrent-requests 1")
                                logger.error("- Increase timeout: --timeout 900")
                                logger.error("- Check server logs for memory/GPU issues")
                        raise
                    # Otherwise, continue to the next iteration
                    continue

            # This line should never be reached due to the raise in the except block,
            # but if we get here, re-raise the last exception
            if last_exception:
                raise last_exception

        # This should never be reached, but add an explicit return for the linter
        logger.warning(
            "Unexpected code path: AsyncLLMClient.query_model completed without returning a result or raising an exception"
        )
        return []
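

# --- Illustrative usage sketch (not part of the original module) ------------
# A dependency-free AsyncLLMClient subclass plus a small driver showing how
# query_model layers retries and the concurrency semaphore on top of
# _query_model_impl. The sleep stands in for a real network call; the class
# name, model name, and parameter values here are hypothetical.
class FakeAsyncLLMClient(AsyncLLMClient):
    """Toy async client that simulates an inference request."""

    def setup(self) -> None:
        pass

    async def _query_model_impl(
        self,
        *,
        messages: Iterable,
        model: str,
        conversation_formatter: ConversationFormatter | None = None,
        generation_config: GenerationConfig | dict | None = None,
    ) -> list[str]:
        await asyncio.sleep(0.01)  # pretend to wait on the server
        prompt = (
            conversation_formatter.format_conversation(list(messages))
            if conversation_formatter
            else str(list(messages))
        )
        return [f"[{model}] {prompt}"]


if __name__ == "__main__":

    async def _demo() -> None:
        # At most 2 requests are in flight at a time; failures matching the
        # rate-limit / connection heuristics above would be retried with backoff.
        client = FakeAsyncLLMClient(max_concurrent_requests=2, max_retries=2, base_delay=0.5)
        client.setup()
        convs = [[{"role": "user", "content": f"question {i}"}] for i in range(5)]
        results = await asyncio.gather(
            *(client.query_model(messages=c, model="demo-model") for c in convs)
        )
        for r in results:
            logger.info(r[0])

    asyncio.run(_demo())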