diff --git a/.runpod/tests.json b/.runpod/tests.json index 02b8a4f..114e67f 100644 --- a/.runpod/tests.json +++ b/.runpod/tests.json @@ -1,12 +1,111 @@ { "tests": [ { - "name": "basic_test", + "name": "text_embedding_explicit_modality", "input": { - "model": "BAAI/bge-small-en-v1.5", - "input": "Hello, world!" + "openai_route": "/v1/embeddings", + "openai_input": { + "model": "patrickjohncyh/fashion-clip", + "input": "A beautiful red dress", + "modality": "text" + } + }, + "expected_output": { + "status": "COMPLETED" }, "timeout": 10000 + }, + { + "name": "text_embedding_default_modality", + "input": { + "openai_route": "/v1/embeddings", + "openai_input": { + "model": "patrickjohncyh/fashion-clip", + "input": "A beautiful red dress" + } + }, + "expected_output": { + "status": "COMPLETED" + }, + "timeout": 10000 + }, + { + "name": "image_url_embedding", + "input": { + "openai_route": "/v1/embeddings", + "openai_input": { + "model": "patrickjohncyh/fashion-clip", + "input": "https://raw.githubusercontent.com/pytorch/hub/master/images/dog.jpg", + "modality": "image" + } + }, + "expected_output": { + "status": "COMPLETED" + }, + "timeout": 15000 + }, + { + "name": "multiple_images", + "input": { + "openai_route": "/v1/embeddings", + "openai_input": { + "model": "patrickjohncyh/fashion-clip", + "input": [ + "https://raw.githubusercontent.com/pytorch/hub/master/images/dog.jpg", + "https://raw.githubusercontent.com/pytorch/hub/master/images/dog.jpg" + ], + "modality": "image" + } + }, + "expected_output": { + "status": "COMPLETED" + }, + "timeout": 20000 + }, + { + "name": "multiple_texts", + "input": { + "openai_route": "/v1/embeddings", + "openai_input": { + "model": "patrickjohncyh/fashion-clip", + "input": [ + "A red dress", + "A blue shirt", + "Black shoes" + ], + "modality": "text" + } + }, + "expected_output": { + "status": "COMPLETED" + }, + "timeout": 15000 + }, + { + "name": "audio_not_implemented", + "input": { + "openai_route": "/v1/embeddings", + 
"openai_input": { + "model": "patrickjohncyh/fashion-clip", + "input": "audio data", + "modality": "audio" + } + }, + "expected_output": { + "status": "FAILED" + }, + "timeout": 5000 + }, + { + "name": "get_models_list", + "input": { + "openai_route": "/v1/models", + "openai_input": {} + }, + "expected_output": { + "status": "COMPLETED" + }, + "timeout": 5000 } ], "config": { @@ -16,7 +115,7 @@ "env": [ { "key": "MODEL_NAMES", - "value": "BAAI/bge-small-en-v1.5" + "value": "patrickjohncyh/fashion-clip" } ] } diff --git a/Dockerfile b/Dockerfile index 0155c94..f99387f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -29,4 +29,4 @@ ADD src . COPY test_input.json /test_input.json # start the handler -CMD python -u /handler.py +CMD ["python", "-u", "/handler.py"] diff --git a/README.md b/README.md index b22672b..d8c55fe 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,9 @@ --- -High-throughput, OpenAI-compatible text embedding & reranker powered by [Infinity](https://github.com/michaelfeil/infinity) +High-throughput, OpenAI-compatible **text & image embedding** & reranker powered by [Infinity](https://github.com/michaelfeil/infinity) + +**✨ New: Multimodal Support!** Now supports text and image embeddings (URLs & base64) with an explicit `modality` switch per request. --- @@ -11,14 +13,19 @@ High-throughput, OpenAI-compatible text embedding & reranker powered by [Infinit --- 1. [Quickstart](#quickstart) -2. [Endpoint Configuration](#endpoint-configuration) -3. [API Specification](#api-specification) +2. [Multimodal Features](#multimodal-features) +3. [Endpoint Configuration](#endpoint-configuration) +4. [API Specification](#api-specification) 1. [List Models](#list-models) 2. [Create Embeddings](#create-embeddings) 3. [Rerank Documents](#rerank-documents) -4. [Usage](#usage) -5. [Further Documentation](#further-documentation) -6. [Acknowledgements](#acknowledgements) +5. [Usage](#usage) + 1. [List Models](#list-models-1) + 2. [Text Embeddings](#text-embeddings) + 3. 
[Image Embeddings](#image-embeddings) + 4. [Reranking](#reranking) +6. [Further Documentation](#further-documentation) +7. [Acknowledgements](#acknowledgements) --- @@ -31,18 +38,56 @@ High-throughput, OpenAI-compatible text embedding & reranker powered by [Infinit --- +## Multimodal Features + +### Supported Modalities + +- ✅ **Text** – traditional text embeddings +- ✅ **Image URLs** – `http://` or `https://` links to images (`.jpg`, `.png`, `.gif`, etc.) +- ✅ **Base64 Images** – data URI format (`data:image/png;base64,...`) + +Each request targets a single modality: + +| Modality | How to request | Notes | +| -------- | ------------------------------------------------ | ------------------------------------------------- | +| `text` | Default; or set `modality="text"` | Works with any deployed embedding model | +| `image` | Set `modality="image"` | Requires a multimodal model (see below) | +| `audio` | Planned | Returns a clear `NotImplementedError` for now | + +> **Tip:** For OpenAI-compatible requests, include `"modality": "…"` alongside `model` and `input`. For native `/runsync` requests, pass `modality` inside the `input` object. If omitted, the worker assumes `text`. + +### Validation & Image Fetching Defaults + +- All inputs are validated eagerly for the chosen modality with detailed, index-aware error messages. +- Image downloads run through a shared `httpx.AsyncClient` with tuned keep-alive limits, timeouts, and a desktop browser User-Agent—improving compatibility with CDNs that block generic clients. All of these knobs can be overridden using the `HTTP_CLIENT_*` environment variables listed below. 
+ +### Multimodal Models + +To use image embeddings, deploy a multimodal model such as: +- `patrickjohncyh/fashion-clip` – Fashion-focused CLIP model +- `jinaai/jina-clip-v1` – General-purpose multimodal embeddings +- Any other CLIP-based model with `image_embed` support + +> **Note:** Text-only models (like `BAAI/bge-small-en-v1.5`) will reject image inputs with a clear error message. + +--- + ## Endpoint Configuration All behaviour is controlled through environment variables: -| Variable | Required | Default | Description | -| ------------------------ | -------- | ------- | ---------------------------------------------------------------------------------------------------------------- | -| `MODEL_NAMES` | **Yes** | — | One or more Hugging-Face model IDs. Separate multiple IDs with a semicolon.
Example: `BAAI/bge-small-en-v1.5` | -| `BATCH_SIZES` | No | `32` | Per-model batch size; semicolon-separated list matching `MODEL_NAMES`. | -| `BACKEND` | No | `torch` | Inference engine for _all_ models: `torch`, `optimum`, or `ctranslate2`. | -| `DTYPES` | No | `auto` | Precision per model (`auto`, `fp16`, `fp8`). Semicolon-separated, must match `MODEL_NAMES`. | -| `INFINITY_QUEUE_SIZE` | No | `48000` | Max items queueable inside the Infinity engine. | -| `RUNPOD_MAX_CONCURRENCY` | No | `300` | Max concurrent requests the RunPod wrapper will accept. | +| Variable | Required | Default | Description | +| ------------------------ | -------- | ------- | ---------------------------------------------------------------------------------------------------------------------------------------------------- | +| `MODEL_NAMES` | **Yes** | — | One or more Hugging-Face model IDs. Separate multiple IDs with a semicolon.
Example: `BAAI/bge-small-en-v1.5;patrickjohncyh/fashion-clip` | +| `BATCH_SIZES` | No | `32` | Per-model batch size; semicolon-separated list matching `MODEL_NAMES`.
Example: `32;16` | +| `BACKEND` | No | `torch` | Inference engine for _all_ models: `torch`, `optimum`, or `ctranslate2`. | +| `DTYPES` | No | `auto` | Precision per model (`auto`, `fp16`, `fp8`). Semicolon-separated, must match `MODEL_NAMES`.
Example: `auto;auto` | +| `INFINITY_QUEUE_SIZE` | No | `48000` | Max items queueable inside the Infinity engine. | +| `RUNPOD_MAX_CONCURRENCY` | No | `300` | Max concurrent requests the RunPod wrapper will accept. | +| `HTTP_CLIENT_USER_AGENT` | No | `Mozilla/5.0 ... Chrome/120.0.0.0 Safari/537.36` | Override the browser-style User-Agent used for outbound image downloads. | +| `HTTP_CLIENT_TIMEOUT` | No | `10.0` | Request timeout (seconds) for outbound image fetches. | +| `HTTP_CLIENT_MAX_CONNECTIONS` | No | `50` | Concurrent connection pool size for the shared `httpx` client. | +| `HTTP_CLIENT_MAX_KEEPALIVE_CONNECTIONS` | No | `20` | Max keep-alive sockets retained by the shared `httpx` client. | --- @@ -80,17 +125,18 @@ Except for transport (path + wrapper object) the JSON you send/receive is identi #### Request Fields (shared) -| Field | Type | Required | Description | -| ------- | ------------------- | -------- | ------------------------------------------------- | -| `model` | string | **Yes** | One of the IDs supplied via `MODEL_NAMES`. | -| `input` | string | array | **Yes** | A single text string _or_ list of texts to embed. | +| Field | Type | Required | Description | +| ---------- | ------------------- | -------- | -------------------------------------------------------------------------------------------------------------------------- | +| `model` | string | **Yes** | One of the IDs supplied via `MODEL_NAMES`. | +| `input` | string \| array | **Yes** | Text string(s) or image URL/base64 list matching the selected modality. Order is preserved. | +| `modality` | string | No | Optional; defaults to `text`. Accepts `text` or `image`, and must be set to `image` whenever `input` contains image URLs or base64 data. With the OpenAI SDK, pass it via `extra_body={"modality": "image"}` so it is sent as a top-level body field next to `model` and `input`. | OpenAI route vs. 
Standard: -| Flavour | Method | Path | Body | -| -------- | ------ | ---------------- | --------------------------------------------- | -| OpenAI | `POST` | `/v1/embeddings` | `{ "model": "…", "input": "…" }` | -| Standard | `POST` | `/runsync` | `{ "input": { "model": "…", "input": "…" } }` | +| Flavour | Method | Path | Body | +| -------- | ------ | ---------------- | ---------------------------------------------------------------------- | +| OpenAI | `POST` | `/v1/embeddings` | `{ "model": "…", "input": "…", "modality": "text" }` (modality optional for text) | +| Standard | `POST` | `/runsync` | `{ "input": { "model": "…", "input": "…", "modality": "text" } }` | #### Response (both flavours) @@ -146,34 +192,90 @@ Below are minimal `curl` snippets so you can copy-paste from any machine. > Replace `<ENDPOINT_ID>` with your endpoint ID and `<API_KEY>` with a [RunPod API key](https://docs.runpod.io/get-started/api-keys). -### OpenAI-Compatible Calls +### List Models ```bash -# List models +# OpenAI-compatible format curl -H "Authorization: Bearer <API_KEY>" \ https://api.runpod.ai/v2/<ENDPOINT_ID>/openai/v1/models -# Create embeddings +# Standard RunPod format +curl -X POST \ + -H "Content-Type: application/json" \ + -d '{"input":{"openai_route":"/v1/models"}}' \ + https://api.runpod.ai/v2/<ENDPOINT_ID>/runsync +``` + +### Text Embeddings + +```bash +# OpenAI-compatible format curl -X POST \ -H "Authorization: Bearer <API_KEY>" \ -H "Content-Type: application/json" \ - -d '{"model":"BAAI/bge-small-en-v1.5","input":"Hello world"}' \ + -d '{"model":"BAAI/bge-small-en-v1.5","input":"Hello world","modality":"text"}' \ https://api.runpod.ai/v2/<ENDPOINT_ID>/openai/v1/embeddings + +# Standard RunPod format +curl -X POST \ + -H "Content-Type: application/json" \ + -d '{"input":{"model":"BAAI/bge-small-en-v1.5","input":"Hello world","modality":"text"}}' \ + https://api.runpod.ai/v2/<ENDPOINT_ID>/runsync ``` -### Standard RunPod Calls +### Image Embeddings ```bash -# Create embeddings (wait for result) +# OpenAI-compatible format (image URL) +curl -X POST \ + -H 
"Authorization: Bearer <API_KEY>" \ + -H "Content-Type: application/json" \ + -d '{"model":"patrickjohncyh/fashion-clip","input":"https://example.com/image.jpg","modality":"image"}' \ + https://api.runpod.ai/v2/<ENDPOINT_ID>/openai/v1/embeddings + +# Standard RunPod format (base64 image) curl -X POST \ -H "Content-Type: application/json" \ - -d '{"input":{"model":"BAAI/bge-small-en-v1.5","input":"Hello world"}}' \ + -d '{"input":{"model":"patrickjohncyh/fashion-clip","input":"data:image/png;base64,iVBORw0KG...","modality":"image"}}' \ https://api.runpod.ai/v2/<ENDPOINT_ID>/runsync +``` + +> **Note:** Send one request per modality. If you need both text and image embeddings, issue two calls so each payload is validated consistently. + +### Reranking + +```bash +# OpenAI-compatible format +curl -X POST \ + -H "Authorization: Bearer <API_KEY>" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "BAAI/bge-reranker-large", + "query": "Which product has warranty coverage?", + "docs": [ + "Product A comes with a 2-year warranty", + "Product B is available in red and blue colors", + "All electronics include a standard 1-year warranty" + ], + "return_docs": true + }' \ + https://api.runpod.ai/v2/<ENDPOINT_ID>/openai/v1/rerank -# Rerank +# Standard RunPod format curl -X POST \ -H "Content-Type: application/json" \ - -d '{"input":{"model":"BAAI/bge-reranker-large","query":"Which product has warranty coverage?","docs":["Product A comes with a 2-year warranty","Product B is available in red and blue colors","All electronics include a standard 1-year warranty"],"return_docs":true}}' \ + -d '{ + "input": { + "model": "BAAI/bge-reranker-large", + "query": "Which product has warranty coverage?", + "docs": [ + "Product A comes with a 2-year warranty", + "Product B is available in red and blue colors", + "All electronics include a standard 1-year warranty" + ], + "return_docs": true + } + }' \ https://api.runpod.ai/v2/<ENDPOINT_ID>/runsync ``` diff --git a/docker-compose.yml b/docker-compose.yml index 2c81997..0bd5734 100644 --- 
a/docker-compose.yml +++ b/docker-compose.yml @@ -13,7 +13,7 @@ services: count: all capabilities: [gpu] environment: - MODEL_NAMES: "BAAI/bge-small-en-v1.5" + MODEL_NAMES: "BAAI/bge-small-en-v1.5;patrickjohncyh/fashion-clip" NVIDIA_VISIBLE_DEVICES: "all" volumes: - ./data/runpod-volume:/runpod-volume diff --git a/requirements.txt b/requirements.txt index 9812a81..3bb060a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,6 @@ runpod~=1.7.0 -infinity-emb[all]==0.0.76 +infinity-emb[all]==0.0.77 +optimum==1.24.0 einops # deployment of custom code with nomic +httpx>=0.27.0 git+https://github.com/pytorch-labs/float8_experimental.git diff --git a/src/config.py b/src/config.py index 6266ed5..bfa28cd 100644 --- a/src/config.py +++ b/src/config.py @@ -1,10 +1,20 @@ import os from dotenv import load_dotenv from functools import cached_property +from typing import Optional DEFAULT_BATCH_SIZE = 32 DEFAULT_BACKEND = "torch" +DEFAULT_HTTP_CLIENT_USER_AGENT = ( + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) " + "AppleWebKit/537.36 (KHTML, like Gecko) " + "Chrome/120.0.0.0 Safari/537.36" +) +DEFAULT_HTTP_CLIENT_TIMEOUT_SECONDS = 10.0 +DEFAULT_HTTP_CLIENT_MAX_CONNECTIONS = 50 +DEFAULT_HTTP_CLIENT_MAX_KEEPALIVE_CONNECTIONS = 20 + if not os.environ.get("INFINITY_QUEUE_SIZE"): # how many items can be in the queue os.environ["INFINITY_QUEUE_SIZE"] = "48000" @@ -56,3 +66,60 @@ def dtypes(self) -> list[str]: @cached_property def runpod_max_concurrency(self) -> int: return int(os.environ.get("RUNPOD_MAX_CONCURRENCY", 300)) + + +class HttpClientConfig: + ENV_USER_AGENT = "HTTP_CLIENT_USER_AGENT" + ENV_TIMEOUT = "HTTP_CLIENT_TIMEOUT" + ENV_MAX_CONNECTIONS = "HTTP_CLIENT_MAX_CONNECTIONS" + ENV_MAX_KEEPALIVE = "HTTP_CLIENT_MAX_KEEPALIVE_CONNECTIONS" + + def __init__(self): + load_dotenv() + + def _get_env_float(self, key: str, default: float) -> float: + value = os.environ.get(key) + if value is None: + return default + try: + return float(value) + except ValueError as exc: + 
raise ValueError(f"Environment variable {key} must be a float, got {value!r}") from exc + + def _get_env_int(self, key: str, default: int) -> int: + value = os.environ.get(key) + if value is None: + return default + try: + return int(value) + except ValueError as exc: + raise ValueError(f"Environment variable {key} must be an integer, got {value!r}") from exc + + def _get_env_str(self, key: str, default: str) -> str: + value: Optional[str] = os.environ.get(key) + if value is None: + return default + value = value.strip() + return value or default + + @cached_property + def user_agent(self) -> str: + return self._get_env_str(self.ENV_USER_AGENT, DEFAULT_HTTP_CLIENT_USER_AGENT) + + @cached_property + def timeout_seconds(self) -> float: + return self._get_env_float(self.ENV_TIMEOUT, DEFAULT_HTTP_CLIENT_TIMEOUT_SECONDS) + + @cached_property + def max_connections(self) -> int: + return self._get_env_int( + self.ENV_MAX_CONNECTIONS, + DEFAULT_HTTP_CLIENT_MAX_CONNECTIONS, + ) + + @cached_property + def max_keepalive_connections(self) -> int: + return self._get_env_int( + self.ENV_MAX_KEEPALIVE, + DEFAULT_HTTP_CLIENT_MAX_KEEPALIVE_CONNECTIONS, + ) diff --git a/src/embedding_service.py b/src/embedding_service.py index a5afc44..0dc4e30 100644 --- a/src/embedding_service.py +++ b/src/embedding_service.py @@ -1,13 +1,21 @@ -from config import EmbeddingServiceConfig +import asyncio +import logging + from infinity_emb.engine import AsyncEngineArray, EngineArgs +from infinity_emb.primitives import ModelNotDeployedError +from PIL import Image + +from config import EmbeddingServiceConfig +from http_client import create_http_client +from multimodal_utils import validate_item_for_modality from utils import ( - OpenAIModelInfo, ModelInfo, + OpenAIModelInfo, list_embeddings_to_response, to_rerank_response, ) -import asyncio +logger = logging.getLogger(__name__) class EmbeddingService: @@ -31,18 +39,28 @@ def __init__(self): self.engine_array = AsyncEngineArray.from_args(engine_args) 
self.is_running = False self.sepamore = asyncio.Semaphore(1) + self.http_client = None async def start(self): """starts the engine background loop""" async with self.sepamore: if not self.is_running: await self.engine_array.astart() + if self.http_client is None: + self.http_client = create_http_client() + logger.info("Created persistent HTTP client for image downloads") + self.is_running = True async def stop(self): """stops the engine background loop""" async with self.sepamore: if self.is_running: + if self.http_client is not None: + await self.http_client.aclose() + self.http_client = None + logger.info("Closed HTTP client") + await self.engine_array.astop() self.is_running = False @@ -56,25 +74,141 @@ def list_models(self) -> list[str]: async def route_openai_get_embeddings( self, - embedding_input: str | list[str], + embedding_input: str | list[str] | list[str | bytes | Image.Image], model_name: str, + modality: str = "text", return_as_list: bool = False, ): - """returns embeddings for the input text""" - if not self.is_running: - await self.start() - if not isinstance(embedding_input, list): - embedding_input = [embedding_input] - - embeddings, usage = await self.engine_array[model_name].embed(embedding_input) - if return_as_list: - return [ - list_embeddings_to_response(embeddings, model=model_name, usage=usage) - ] - else: - return list_embeddings_to_response( - embeddings, model=model_name, usage=usage + """ + Returns embeddings for the input based on specified modality. 
+ + Args: + embedding_input: Input text(s) or image(s) to embed + model_name: Name of the model to use + modality: Type of input - "text", "image", or "audio" (not yet implemented) + return_as_list: Whether to return results as a list + + Raises: + ValueError: If model not available, modality invalid, or validation fails + NotImplementedError: If modality is "audio" + """ + try: + if not self.is_running: + await self.start() + + available_models = self.list_models() + if model_name not in available_models: + available_models_msg = ( + ", ".join(available_models) if available_models else "none" + ) + error_msg = ( + f"Model '{model_name}' is not deployed. " + f"Available deployments: {available_models_msg}. " + f"Deploy the requested model or choose one of the available models." + ) + logger.error(error_msg) + raise ValueError(error_msg) + + if not isinstance(embedding_input, list): + embedding_input = [embedding_input] + + # Validate all items for the specified modality in parallel + try: + validated_items = await asyncio.gather( + *[ + validate_item_for_modality( + item, modality, idx, client=self.http_client + ) + for idx, item in enumerate(embedding_input) + ] + ) + except (ValueError, NotImplementedError) as e: + logger.error(f"Validation failed for modality '{modality}': {e}") + raise + + logger.info( + f"Processing {len(validated_items)} {modality} items for model '{model_name}'" + ) + + # Route to appropriate embedding method based on modality + try: + if modality == "text": + logger.debug( + f"Calling .embed() with {len(validated_items)} text items" + ) + embeddings, usage = await self.engine_array[model_name].embed( + validated_items + ) + logger.debug( + f"Successfully got {len(embeddings)} text embeddings" + ) + + elif modality == "image": + logger.debug( + f"Calling .image_embed() with {len(validated_items)} image items" + ) + embeddings, usage = await self.engine_array[model_name].image_embed( + images=validated_items + ) + logger.debug( + 
f"Successfully got {len(embeddings)} image embeddings" + ) + + elif modality == "audio": + raise NotImplementedError( + "Audio modality is not yet implemented. " + "Currently supported modalities: 'text', 'image'" + ) + + else: + raise ValueError( + f"Invalid modality: '{modality}'. " + f"Supported modalities: 'text', 'image', 'audio' (not yet implemented)" + ) + except ModelNotDeployedError as e: + available_capabilities = sorted( + getattr(self.engine_array[model_name], "capabilities", set()) + ) + capabilities_hint = ( + ", ".join(available_capabilities) if available_capabilities else "none" + ) + error_msg = ( + f"Model '{model_name}' does not expose the '{modality}' capability. " + f"Detected capabilities: {capabilities_hint}. Deploy a model that supports " + f"this modality or adjust the request." + ) + logger.error(f"{error_msg} Original error: {e}") + raise ValueError(error_msg) from e + + if return_as_list: + return [ + list_embeddings_to_response( + embeddings, + model=model_name, + usage=usage, # type: ignore[arg-type] + ) + ] + else: + return list_embeddings_to_response( + embeddings, + model=model_name, + usage=usage, # type: ignore[arg-type] + ) + + except (ValueError, NotImplementedError) as e: + logger.warning( + f"Expected error in route_openai_get_embeddings: {type(e).__name__}: {e}" + ) + raise + except Exception as e: + logger.exception( + f"Unexpected error in route_openai_get_embeddings: " + f"model='{model_name}', modality='{modality}', " + f"input_length={len(embedding_input) if isinstance(embedding_input, list) else 1}" ) + raise RuntimeError( + f"Internal error while processing embeddings: {type(e).__name__}: {str(e)}" + ) from e async def infinity_rerank( self, query: str, docs: str, return_docs: str, model_name: str diff --git a/src/handler.py b/src/handler.py index 8f38e78..0d43789 100644 --- a/src/handler.py +++ b/src/handler.py @@ -25,16 +25,20 @@ async def async_generator_handler(job: dict[str, Any]): if openai_route and openai_route 
== "/v1/models": call_fn, kwargs = embedding_service.route_openai_models, {} elif openai_route and openai_route == "/v1/embeddings": - model_name = openai_input.get("model") if not openai_input: return create_error_response("Missing input").model_dump() + model_name = openai_input.get("model") if not model_name: return create_error_response( "Did not specify model in openai_input" ).model_dump() + + modality = openai_input.get("modality", "text") + call_fn, kwargs = embedding_service.route_openai_get_embeddings, { "embedding_input": openai_input.get("input"), "model_name": model_name, + "modality": modality, "return_as_list": True, } else: @@ -51,9 +55,11 @@ async def async_generator_handler(job: dict[str, Any]): "model_name": job_input.get("model"), } elif job_input.get("input"): + modality = job_input.get("modality", "text") call_fn, kwargs = embedding_service.route_openai_get_embeddings, { "embedding_input": job_input.get("input"), "model_name": job_input.get("model"), + "modality": modality, } else: return create_error_response(f"Invalid input: {job}").model_dump() diff --git a/src/http_client.py b/src/http_client.py new file mode 100644 index 0000000..dc01a8b --- /dev/null +++ b/src/http_client.py @@ -0,0 +1,32 @@ +import httpx + +from config import HttpClientConfig + +_CLIENT_CONFIG = HttpClientConfig() + + +def create_http_client() -> httpx.AsyncClient: + """ + Create a configured httpx.AsyncClient. + + Returns: + httpx.AsyncClient with proper timeout, limits, and headers configured. 
+ """ + limits = httpx.Limits( + max_connections=_CLIENT_CONFIG.max_connections, + max_keepalive_connections=_CLIENT_CONFIG.max_keepalive_connections, + ) + + timeout = httpx.Timeout(_CLIENT_CONFIG.timeout_seconds) + + headers = { + "User-Agent": _CLIENT_CONFIG.user_agent, + } + + return httpx.AsyncClient( + limits=limits, + timeout=timeout, + headers=headers, + follow_redirects=True, + trust_env=True, + ) diff --git a/src/multimodal_utils.py b/src/multimodal_utils.py new file mode 100644 index 0000000..53a8b28 --- /dev/null +++ b/src/multimodal_utils.py @@ -0,0 +1,204 @@ +import base64 +import io +import logging +import re +from typing import Any +from PIL import Image + +logger = logging.getLogger(__name__) + + +def _ensure_rgb(img: Image.Image) -> Image.Image: + """ + Ensure image is in RGB format (required for CLIP models). + + Args: + img: PIL Image in any mode + + Returns: + PIL Image in RGB mode + """ + if img.mode != 'RGB': + logger.debug(f"Converting image from {img.mode} to RGB") + return img.convert('RGB') + return img + + +def _is_url(text: str) -> bool: + """Check if string is a URL""" + return text.startswith('http://') or text.startswith('https://') + + +async def _download_image_from_url(url: str, client) -> Image.Image: + """ + Download image from URL using httpx with proper User-Agent and timeout. 
+ + Args: + url: HTTP(S) URL to download image from + client: httpx.AsyncClient instance with configured timeout and limits + + Returns: + PIL.Image in RGB format + + Raises: + ValueError: If download fails or content is not a valid image + """ + try: + logger.debug(f"Downloading image from URL: {url}") + response = await client.get(url) + response.raise_for_status() + + img_bytes = response.content + logger.debug(f"Downloaded {len(img_bytes)} bytes from {url} (status: {response.status_code})") + except Exception as e: + raise ValueError(f"Failed to download image from URL: {type(e).__name__}: {e}") from e + + try: + img = Image.open(io.BytesIO(img_bytes)) + img.load() # Force load to validate it's a real image + logger.debug(f"Successfully loaded image from URL: {img.size} {img.mode}") + + return _ensure_rgb(img) + except Exception as e: + raise ValueError(f"Failed to decode image from URL: {type(e).__name__}: {e}") from e + + +def _is_base64_image(data: str) -> Image.Image | None: + """ + Try to decode string as base64-encoded image. + Supports both data URI format (data:image/...) and raw base64. + + Returns: + PIL.Image in RGB format, or None if not a valid base64 image + + Note: Converts all images to RGB format for compatibility with multimodal models. + """ + try: + # Handle data URI format: data:image/png;base64,iVBORw0KG... 
+ if data.startswith('data:'): + match = re.match(r'data:image/[^;]+;base64,(.+)', data) + if match: + base64_data = match.group(1) + logger.debug(f"Matched data URI, extracted base64 data (length: {len(base64_data)})") + else: + logger.debug("data: URI does not match expected format") + return None + else: + # Try raw base64 + base64_data = data + logger.debug("Treating as raw base64") + + img_bytes = base64.b64decode(base64_data) + logger.debug(f"Decoded base64 to {len(img_bytes)} bytes") + + img = Image.open(io.BytesIO(img_bytes)) + img.load() # Force load to validate it's a real image + logger.debug(f"Successfully loaded image: {img.size} {img.mode}") + + return _ensure_rgb(img) + except Exception as e: + logger.warning(f"Failed to decode base64 image: {type(e).__name__}: {e}") + return None + + +def validate_text_item(item: Any) -> str: + """ + Validate and convert item to text string. + Raises ValueError if item cannot be converted to text. + """ + if isinstance(item, str): + return item + + # Try to convert to string + try: + return str(item) + except Exception as e: + raise ValueError(f"Cannot convert item to text: {type(item).__name__}") from e + + +async def validate_image_item(item: Any, client=None) -> Image.Image: + """ + Validate and process image item. + Accepts: PIL.Image, bytes, URL string, or base64 string. + Returns PIL.Image in RGB format. + Raises ValueError if item is not a valid image format. + + Args: + item: The image item to validate (PIL.Image, bytes, URL string, or base64 string) + client: Optional httpx.AsyncClient for downloading URL images with timeout and User-Agent + + Note: All images are converted to RGB format for CLIP compatibility. 
+ """ + # PIL Image - convert to RGB if needed + if isinstance(item, Image.Image): + return _ensure_rgb(item) + + # Bytes - decode to PIL Image + if isinstance(item, bytes): + try: + img = Image.open(io.BytesIO(item)) + img.load() + logger.debug(f"Loaded image from bytes: {img.size} {img.mode}") + return _ensure_rgb(img) + except Exception as e: + raise ValueError(f"Failed to decode image from bytes: {type(e).__name__}: {e}") from e + + if isinstance(item, str): + # URL images - download with proper User-Agent and timeout + if _is_url(item): + if client is None: + raise ValueError("HTTP client required for downloading images from URLs") + return await _download_image_from_url(item, client) + + # Base64 images - decode and validate (already converted to RGB) + img = _is_base64_image(item) + if img is not None: + return img + + # Not a valid image format + raise ValueError( + "String is not a valid image format (must be URL starting with http:// or https://, " + "or base64-encoded image with 'data:image/...' prefix)" + ) + + raise ValueError( + f"Invalid image type: {type(item).__name__}. " + f"Expected PIL.Image, bytes, URL string, or base64 string." + ) + + +async def validate_item_for_modality(item: Any, modality: str, index: int, client=None) -> Any: + """ + Validate a single item for the specified modality. 
+ + Args: + item: The input item to validate + modality: One of "text", "image", "audio" + index: The index of the item in the batch (for error messages) + client: Optional httpx.AsyncClient for downloading URL images with timeout + + Returns: + Validated and processed item suitable for infinity_emb + + Raises: + ValueError: If item is not valid for the specified modality + NotImplementedError: If modality is "audio" + """ + try: + if modality == "text": + return validate_text_item(item) + elif modality == "image": + return await validate_image_item(item, client=client) + elif modality == "audio": + raise NotImplementedError( + "Audio modality is not yet implemented. " + "Currently supported modalities: 'text', 'image'" + ) + else: + raise ValueError( + f"Invalid modality: '{modality}'. " + f"Supported modalities: 'text', 'image', 'audio' (not yet implemented)" + ) + except (ValueError, NotImplementedError) as e: + # Re-raise with index information for better error messages + raise type(e)(f"Item at index {index}: {str(e)}") from e diff --git a/src/utils.py b/src/utils.py index 1a59304..4d8b42e 100644 --- a/src/utils.py +++ b/src/utils.py @@ -1,8 +1,7 @@ -from http import HTTPStatus -from typing import Annotated, Any, Dict, List, Literal, Optional, Union -from typing import Any, Dict, Iterable, List, Optional, Union -from uuid import uuid4 import time +from http import HTTPStatus +from typing import Annotated, Any, Dict, Iterable, List, Literal, Optional, Union + import numpy as np import numpy.typing as npt from pydantic import BaseModel, Field, conlist diff --git a/tests/test_modality.py b/tests/test_modality.py new file mode 100644 index 0000000..870188d --- /dev/null +++ b/tests/test_modality.py @@ -0,0 +1,447 @@ +"""Integration tests for the explicit modality API using pytest.""" + +import base64 +import io +from collections.abc import Iterable, Mapping +from typing import Any + +import pytest +import requests +from PIL import Image + +BASE_URL = 
"http://localhost:8000" +RUNSYNC_URL = f"{BASE_URL}/runsync" +DEFAULT_TIMEOUT_SECONDS = 30 + + +def _post_runsync( + payload: Mapping[str, Any], timeout: int | float = DEFAULT_TIMEOUT_SECONDS +) -> requests.Response: + """Send a request to the worker and print useful diagnostics.""" + response = requests.post(RUNSYNC_URL, json=payload, timeout=timeout) + print(f"POST {RUNSYNC_URL} -> {response.status_code}") + return response + + +def _extract_output(result: Mapping[str, Any]) -> Mapping[str, Any] | None: + """Normalise RunPod /runsync output into a single mapping.""" + output = result.get("output") + if isinstance(output, list) and output: + return output[0] + if isinstance(output, Mapping): + return output + return None + + +def _extract_error_message(result: Mapping[str, Any]) -> str | None: + output = _extract_output(result) + if output and output.get("object") == "error": + message = output.get("message") + return str(message) if message is not None else None + return None + + +def generate_red_square_base64() -> str: + """Generate a 5x5 red square PNG encoded as a data URI.""" + img = Image.new("RGB", (5, 5), color=(255, 0, 0)) + buffer = io.BytesIO() + img.save(buffer, format="PNG") + buffer.seek(0) + img_base64 = base64.b64encode(buffer.read()).decode("utf-8") + return f"data:image/png;base64,{img_base64}" + + +@pytest.mark.parametrize( + "name,payload,expected_count", + [ + ( + "text modality", + { + "input": { + "openai_route": "/v1/embeddings", + "openai_input": { + "model": "BAAI/bge-small-en-v1.5", + "input": ["Hello world", "How are you?"], + "modality": "text", + }, + } + }, + 2, + ), + ( + "image modality (url)", + { + "input": { + "openai_route": "/v1/embeddings", + "openai_input": { + "model": "patrickjohncyh/fashion-clip", + "input": [ + "https://raw.githubusercontent.com/pytorch/hub/master/images/dog.jpg" + ], + "modality": "image", + }, + } + }, + 1, + ), + ( + "image modality (base64)", + { + "input": { + "openai_route": "/v1/embeddings", + 
"openai_input": { + "model": "patrickjohncyh/fashion-clip", + "input": [generate_red_square_base64()], + "modality": "image", + }, + } + }, + 1, + ), + ( + "default text modality", + { + "input": { + "openai_route": "/v1/embeddings", + "openai_input": { + "model": "BAAI/bge-small-en-v1.5", + "input": ["Default modality test"], + }, + } + }, + 1, + ), + ], +) +def test_modality_success( + name: str, payload: Mapping[str, Any], expected_count: int +) -> None: + response = _post_runsync(payload) + assert response.status_code == 200, f"HTTP error for {name}: {response.text}" + + result = response.json() + assert result.get("status") == "COMPLETED", ( + f"Unexpected status for {name}: {result}" + ) + + output = _extract_output(result) or {} + data = output.get("data", []) + assert isinstance(data, Iterable), f"Missing data for {name}: {output}" + + if isinstance(data, list): + assert len(data) == expected_count, ( + f"Expected {expected_count} embeddings for {name}, got {len(data)}" + ) + else: + pytest.fail(f"Unexpected data format for {name}: {type(data)}") + + +def test_wrong_modality_error() -> None: + payload = { + "input": { + "openai_route": "/v1/embeddings", + "openai_input": { + "model": "BAAI/bge-small-en-v1.5", + "input": [ + "https://raw.githubusercontent.com/pytorch/hub/master/images/dog.jpg" + ], + "modality": "image", + }, + } + } + + response = _post_runsync(payload) + if response.status_code != 200: + assert response.status_code >= 400 + return + + result = response.json() + error_msg = _extract_error_message(result) + assert error_msg is not None, f"Expected error object, got: {result}" + lowered = error_msg.lower() + assert "does not expose the 'image' capability" in lowered + assert "detected capabilities" in lowered + + +def test_audio_not_implemented() -> None: + payload = { + "input": { + "openai_route": "/v1/embeddings", + "openai_input": { + "model": "BAAI/bge-small-en-v1.5", + "input": ["audio data"], + "modality": "audio", + }, + } + } + + 
def test_audio_not_implemented() -> None:
    """The audio modality must be answered with a 'not yet implemented' error.

    Any non-200 HTTP status of 400 or above is accepted as a rejection too.
    """
    embeddings_request = {
        "model": "BAAI/bge-small-en-v1.5",
        "input": ["audio data"],
        "modality": "audio",
    }
    reply = _post_runsync(
        {
            "input": {
                "openai_route": "/v1/embeddings",
                "openai_input": embeddings_request,
            }
        }
    )
    if reply.status_code != 200:
        assert reply.status_code >= 400
        return

    body = reply.json()
    error_text = _extract_error_message(body)
    assert error_text is not None, f"Expected NotImplementedError output, got: {body}"
    assert "not yet implemented" in error_text.lower()


def test_validation_flexibility() -> None:
    """A URL string sent with the text modality must be embedded as plain text."""
    embeddings_request = {
        "model": "BAAI/bge-small-en-v1.5",
        "input": ["https://raw.githubusercontent.com/pytorch/hub/master/images/dog.jpg"],
        "modality": "text",
    }
    reply = _post_runsync(
        {
            "input": {
                "openai_route": "/v1/embeddings",
                "openai_input": embeddings_request,
            }
        }
    )
    assert reply.status_code == 200, f"HTTP error: {reply.text}"

    body = reply.json()
    assert body.get("status") == "COMPLETED", (
        f"Expected success treating URL as text: {body}"
    )


@pytest.mark.parametrize(
    "payload,expected_count",
    [
        (
            {
                "input": {
                    "openai_route": "/v1/embeddings",
                    "openai_input": {
                        "model": "BAAI/bge-small-en-v1.5",
                        "input": [],
                        "modality": "text",
                    },
                }
            },
            0,
        ),
        (
            {
                "input": {
                    "openai_route": "/v1/embeddings",
                    "openai_input": {
                        "model": "BAAI/bge-small-en-v1.5",
                        "input": "Single text string",
                        "modality": "text",
                    },
                }
            },
            1,
        ),
    ],
)
def test_text_edge_cases(payload: Mapping[str, Any], expected_count: int) -> None:
    """Empty-list and bare-string inputs either fail cleanly or return the expected count.

    A non-200 HTTP status of 400 or above, or a non-COMPLETED job carrying an
    error message, both count as a clean failure.
    """
    reply = _post_runsync(payload)
    if reply.status_code != 200:
        assert reply.status_code >= 400
        return

    body = reply.json()
    if body.get("status") != "COMPLETED":
        assert _extract_error_message(body) is not None
        return

    envelope = _extract_output(body) or {}
    embeddings = envelope.get("data", [])
    assert isinstance(embeddings, list)
    assert len(embeddings) == expected_count
" * 200 + payload = { + "input": { + "openai_route": "/v1/embeddings", + "openai_input": { + "model": "BAAI/bge-small-en-v1.5", + "input": long_text, + "modality": "text", + }, + } + } + + response = _post_runsync(payload) + assert response.status_code == 200, f"HTTP error: {response.text}" + + result = response.json() + assert result.get("status") == "COMPLETED", f"Unexpected status: {result}" + + +def test_edge_empty_string() -> None: + payload = { + "input": { + "openai_route": "/v1/embeddings", + "openai_input": { + "model": "BAAI/bge-small-en-v1.5", + "input": "", + "modality": "text", + }, + } + } + + response = _post_runsync(payload) + if response.status_code != 200: + assert response.status_code >= 400 + return + + result = response.json() + if result.get("status") == "COMPLETED": + output = _extract_output(result) or {} + data = output.get("data", []) + assert isinstance(data, list) + else: + assert _extract_error_message(result) is not None + + +def test_edge_invalid_modality() -> None: + payload = { + "input": { + "openai_route": "/v1/embeddings", + "openai_input": { + "model": "BAAI/bge-small-en-v1.5", + "input": "test", + "modality": "video", + }, + } + } + + response = _post_runsync(payload) + if response.status_code != 200: + assert response.status_code >= 400 + return + + result = response.json() + error_msg = _extract_error_message(result) + assert error_msg is not None, f"Expected invalid modality error: {result}" + assert "invalid modality" in error_msg.lower() + + +def test_edge_missing_model() -> None: + payload = { + "input": { + "openai_route": "/v1/embeddings", + "openai_input": { + "input": "test", + "modality": "text", + }, + } + } + + response = _post_runsync(payload) + if response.status_code != 200: + assert response.status_code >= 400 + return + + result = response.json() + error_msg = _extract_error_message(result) + assert error_msg is not None, f"Expected missing model error: {result}" + + +def test_edge_nonexistent_model() -> None: 
def test_edge_nonexistent_model() -> None:
    """A model that is not deployed must be reported along with the available ones.

    Any non-200 HTTP status of 400 or above is accepted as a rejection too.
    """
    embeddings_request = {
        "model": "nonexistent/model-12345",
        "input": "test",
        "modality": "text",
    }
    reply = _post_runsync(
        {
            "input": {
                "openai_route": "/v1/embeddings",
                "openai_input": embeddings_request,
            }
        }
    )
    if reply.status_code != 200:
        assert reply.status_code >= 400
        return

    body = reply.json()
    error_text = _extract_error_message(body)
    assert error_text is not None, f"Expected model missing error: {body}"

    normalized = error_text.lower()
    assert "is not deployed" in normalized
    assert "available deployments" in normalized


def test_edge_invalid_image_url() -> None:
    """An image URL that cannot be fetched must end in a terminal job status.

    Both COMPLETED and FAILED are accepted, as is any non-200 HTTP status of
    400 or above.
    """
    embeddings_request = {
        "model": "patrickjohncyh/fashion-clip",
        "input": "https://example.com/nonexistent-image-12345.jpg",
        "modality": "image",
    }
    reply = _post_runsync(
        {
            "input": {
                "openai_route": "/v1/embeddings",
                "openai_input": embeddings_request,
            }
        }
    )
    if reply.status_code != 200:
        assert reply.status_code >= 400
        return

    body = reply.json()
    assert body.get("status") in {"COMPLETED", "FAILED"}, (
        f"Unexpected status: {body}"
    )


def test_edge_special_characters() -> None:
    """Text with non-ASCII characters, an emoji and punctuation must complete."""
    embeddings_request = {
        "model": "BAAI/bge-small-en-v1.5",
        "input": "Hello 世界! 🌍 Special chars: @#$%^&*()_+-=[]{}|;:',.<>?/~`",
        "modality": "text",
    }
    reply = _post_runsync(
        {
            "input": {
                "openai_route": "/v1/embeddings",
                "openai_input": embeddings_request,
            }
        }
    )
    assert reply.status_code == 200, f"HTTP error: {reply.text}"

    body = reply.json()
    assert body.get("status") == "COMPLETED", f"Unexpected status: {body}"


def test_edge_corrupted_base64() -> None:
    """Every corrupted or non-image base64 data URI must be rejected with an error.

    Variants are sent one after another; a non-200 HTTP status of 400 or above
    counts as a rejection and moves on to the next variant.
    """
    corrupted_uris = [
        "data:image/png;base64,NotValidBase64Data!!!",
        "data:image/png;base64,iVBORw0KGgoAAAA",
        "data:image/jpeg;base64,/9j/4AAQSkZJRg",
        "data:image/png;base64,SGVsbG8gV29ybGQh",
    ]

    for position, data_uri in enumerate(corrupted_uris, start=1):
        print(
            f"\n Test variant {position}/{len(corrupted_uris)}: {data_uri[:50]}..."
        )
        embeddings_request = {
            "model": "patrickjohncyh/fashion-clip",
            "input": data_uri,
            "modality": "image",
        }
        reply = _post_runsync(
            {
                "input": {
                    "openai_route": "/v1/embeddings",
                    "openai_input": embeddings_request,
                }
            }
        )
        if reply.status_code != 200:
            assert reply.status_code >= 400
            continue

        body = reply.json()
        error_text = _extract_error_message(body)
        assert error_text is not None, f"Expected error for corrupted base64: {body}"

    print("\n✓ All corrupted base64 variants handled without crash")