diff --git a/.runpod/tests.json b/.runpod/tests.json
index 02b8a4f..114e67f 100644
--- a/.runpod/tests.json
+++ b/.runpod/tests.json
@@ -1,12 +1,111 @@
{
"tests": [
{
- "name": "basic_test",
+ "name": "text_embedding_explicit_modality",
"input": {
- "model": "BAAI/bge-small-en-v1.5",
- "input": "Hello, world!"
+ "openai_route": "/v1/embeddings",
+ "openai_input": {
+ "model": "patrickjohncyh/fashion-clip",
+ "input": "A beautiful red dress",
+ "modality": "text"
+ }
+ },
+ "expected_output": {
+ "status": "COMPLETED"
},
"timeout": 10000
+ },
+ {
+ "name": "text_embedding_default_modality",
+ "input": {
+ "openai_route": "/v1/embeddings",
+ "openai_input": {
+ "model": "patrickjohncyh/fashion-clip",
+ "input": "A beautiful red dress"
+ }
+ },
+ "expected_output": {
+ "status": "COMPLETED"
+ },
+ "timeout": 10000
+ },
+ {
+ "name": "image_url_embedding",
+ "input": {
+ "openai_route": "/v1/embeddings",
+ "openai_input": {
+ "model": "patrickjohncyh/fashion-clip",
+ "input": "https://raw.githubusercontent.com/pytorch/hub/master/images/dog.jpg",
+ "modality": "image"
+ }
+ },
+ "expected_output": {
+ "status": "COMPLETED"
+ },
+ "timeout": 15000
+ },
+ {
+ "name": "multiple_images",
+ "input": {
+ "openai_route": "/v1/embeddings",
+ "openai_input": {
+ "model": "patrickjohncyh/fashion-clip",
+ "input": [
+ "https://raw.githubusercontent.com/pytorch/hub/master/images/dog.jpg",
+ "https://raw.githubusercontent.com/pytorch/hub/master/images/dog.jpg"
+ ],
+ "modality": "image"
+ }
+ },
+ "expected_output": {
+ "status": "COMPLETED"
+ },
+ "timeout": 20000
+ },
+ {
+ "name": "multiple_texts",
+ "input": {
+ "openai_route": "/v1/embeddings",
+ "openai_input": {
+ "model": "patrickjohncyh/fashion-clip",
+ "input": [
+ "A red dress",
+ "A blue shirt",
+ "Black shoes"
+ ],
+ "modality": "text"
+ }
+ },
+ "expected_output": {
+ "status": "COMPLETED"
+ },
+ "timeout": 15000
+ },
+ {
+ "name": "audio_not_implemented",
+ "input": {
+ "openai_route": "/v1/embeddings",
+ "openai_input": {
+ "model": "patrickjohncyh/fashion-clip",
+ "input": "audio data",
+ "modality": "audio"
+ }
+ },
+ "expected_output": {
+ "status": "FAILED"
+ },
+ "timeout": 5000
+ },
+ {
+ "name": "get_models_list",
+ "input": {
+ "openai_route": "/v1/models",
+ "openai_input": {}
+ },
+ "expected_output": {
+ "status": "COMPLETED"
+ },
+ "timeout": 5000
}
],
"config": {
@@ -16,7 +115,7 @@
"env": [
{
"key": "MODEL_NAMES",
- "value": "BAAI/bge-small-en-v1.5"
+ "value": "patrickjohncyh/fashion-clip"
}
]
}
diff --git a/Dockerfile b/Dockerfile
index 0155c94..f99387f 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -29,4 +29,4 @@ ADD src .
COPY test_input.json /test_input.json
# start the handler
-CMD python -u /handler.py
+CMD ["python", "-u", "/handler.py"]
diff --git a/README.md b/README.md
index b22672b..d8c55fe 100644
--- a/README.md
+++ b/README.md
@@ -2,7 +2,9 @@
---
-High-throughput, OpenAI-compatible text embedding & reranker powered by [Infinity](https://github.com/michaelfeil/infinity)
+High-throughput, OpenAI-compatible **text & image embedding** and reranking, powered by [Infinity](https://github.com/michaelfeil/infinity)
+
+**✨ New: Multimodal Support!** Now supports text and image embeddings (URLs & base64) with an explicit `modality` switch per request.
---
@@ -11,14 +13,19 @@ High-throughput, OpenAI-compatible text embedding & reranker powered by [Infinit
---
1. [Quickstart](#quickstart)
-2. [Endpoint Configuration](#endpoint-configuration)
-3. [API Specification](#api-specification)
+2. [Multimodal Features](#multimodal-features)
+3. [Endpoint Configuration](#endpoint-configuration)
+4. [API Specification](#api-specification)
1. [List Models](#list-models)
2. [Create Embeddings](#create-embeddings)
3. [Rerank Documents](#rerank-documents)
-4. [Usage](#usage)
-5. [Further Documentation](#further-documentation)
-6. [Acknowledgements](#acknowledgements)
+5. [Usage](#usage)
+ 1. [List Models](#list-models-1)
+ 2. [Text Embeddings](#text-embeddings)
+ 3. [Image Embeddings](#image-embeddings)
+ 4. [Reranking](#reranking)
+6. [Further Documentation](#further-documentation)
+7. [Acknowledgements](#acknowledgements)
---
@@ -31,18 +38,56 @@ High-throughput, OpenAI-compatible text embedding & reranker powered by [Infinit
---
+## Multimodal Features
+
+### Supported Modalities
+
+- ✅ **Text** – traditional text embeddings
+- ✅ **Image URLs** – `http://` or `https://` links to images (`.jpg`, `.png`, `.gif`, etc.)
+- ✅ **Base64 Images** – data URI format (`data:image/png;base64,...`)
+
+Each request targets a single modality:
+
+| Modality | How to request | Notes |
+| -------- | ------------------------------------------------ | ------------------------------------------------- |
+| `text` | Default; or set `modality="text"` | Works with any deployed embedding model |
+| `image` | Set `modality="image"` | Requires a multimodal model (see below) |
+| `audio`  | Planned                                          | Requests fail with a clear `NotImplementedError` for now |
+
+> **Tip:** For OpenAI-compatible requests, include `"modality": "…"` alongside `model` and `input`. For native `/runsync` requests, pass `modality` inside the `input` object. If omitted, the worker assumes `text`.
+
+### Validation & Image Fetching Defaults
+
+- All inputs are validated eagerly for the chosen modality with detailed, index-aware error messages.
+- Image downloads run through a shared `httpx.AsyncClient` with tuned keep-alive limits, timeouts, and a desktop browser User-Agent—improving compatibility with CDNs that block generic clients. All of these knobs can be overridden using the `HTTP_CLIENT_*` environment variables listed below.
+
+### Multimodal Models
+
+To use image embeddings, deploy a multimodal model such as:
+- `patrickjohncyh/fashion-clip` – Fashion-focused CLIP model
+- `jinaai/jina-clip-v1` – General-purpose multimodal embeddings
+- Any other CLIP-based model with `image_embed` support
+
+> **Note:** Text-only models (like `BAAI/bge-small-en-v1.5`) will reject image inputs with a clear error message.
+
+---
+
## Endpoint Configuration
All behaviour is controlled through environment variables:
-| Variable | Required | Default | Description |
-| ------------------------ | -------- | ------- | ---------------------------------------------------------------------------------------------------------------- |
-| `MODEL_NAMES` | **Yes** | — | One or more Hugging-Face model IDs. Separate multiple IDs with a semicolon.
Example: `BAAI/bge-small-en-v1.5` |
-| `BATCH_SIZES` | No | `32` | Per-model batch size; semicolon-separated list matching `MODEL_NAMES`. |
-| `BACKEND` | No | `torch` | Inference engine for _all_ models: `torch`, `optimum`, or `ctranslate2`. |
-| `DTYPES` | No | `auto` | Precision per model (`auto`, `fp16`, `fp8`). Semicolon-separated, must match `MODEL_NAMES`. |
-| `INFINITY_QUEUE_SIZE` | No | `48000` | Max items queueable inside the Infinity engine. |
-| `RUNPOD_MAX_CONCURRENCY` | No | `300` | Max concurrent requests the RunPod wrapper will accept. |
+| Variable | Required | Default | Description |
+| ------------------------ | -------- | ------- | ---------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `MODEL_NAMES` | **Yes** | — | One or more Hugging-Face model IDs. Separate multiple IDs with a semicolon.
Example: `BAAI/bge-small-en-v1.5;patrickjohncyh/fashion-clip` |
+| `BATCH_SIZES` | No | `32` | Per-model batch size; semicolon-separated list matching `MODEL_NAMES`.
Example: `32;16` |
+| `BACKEND` | No | `torch` | Inference engine for _all_ models: `torch`, `optimum`, or `ctranslate2`. |
+| `DTYPES` | No | `auto` | Precision per model (`auto`, `fp16`, `fp8`). Semicolon-separated, must match `MODEL_NAMES`.
Example: `auto;auto` |
+| `INFINITY_QUEUE_SIZE` | No | `48000` | Max items queueable inside the Infinity engine. |
+| `RUNPOD_MAX_CONCURRENCY` | No | `300` | Max concurrent requests the RunPod wrapper will accept. |
+| `HTTP_CLIENT_USER_AGENT` | No | `Mozilla/5.0 ... Chrome/120.0.0.0 Safari/537.36` | Override the browser-style User-Agent used for outbound image downloads. |
+| `HTTP_CLIENT_TIMEOUT` | No | `10.0` | Request timeout (seconds) for outbound image fetches. |
+| `HTTP_CLIENT_MAX_CONNECTIONS` | No | `50` | Concurrent connection pool size for the shared `httpx` client. |
+| `HTTP_CLIENT_MAX_KEEPALIVE_CONNECTIONS` | No | `20` | Max keep-alive sockets retained by the shared `httpx` client. |
---
@@ -80,17 +125,18 @@ Except for transport (path + wrapper object) the JSON you send/receive is identi
#### Request Fields (shared)
-| Field | Type | Required | Description |
-| ------- | ------------------- | -------- | ------------------------------------------------- |
-| `model` | string | **Yes** | One of the IDs supplied via `MODEL_NAMES`. |
-| `input` | string | array | **Yes** | A single text string _or_ list of texts to embed. |
+| Field | Type | Required | Description |
+| ---------- | ------------------- | -------- | -------------------------------------------------------------------------------------------------------------------------- |
+| `model` | string | **Yes** | One of the IDs supplied via `MODEL_NAMES`. |
+| `input` | string | array | **Yes** | Text string(s) or image URL/base64 list matching the selected modality. Order is preserved. |
+| `modality` | string              | No       | Defaults to `text`; must be set to `image` for image inputs. For OpenAI requests supply via `extra_body.modality`. |
OpenAI route vs. Standard:
-| Flavour | Method | Path | Body |
-| -------- | ------ | ---------------- | --------------------------------------------- |
-| OpenAI | `POST` | `/v1/embeddings` | `{ "model": "…", "input": "…" }` |
-| Standard | `POST` | `/runsync` | `{ "input": { "model": "…", "input": "…" } }` |
+| Flavour | Method | Path | Body |
+| -------- | ------ | ---------------- | ---------------------------------------------------------------------- |
+| OpenAI | `POST` | `/v1/embeddings` | `{ "model": "…", "input": "…", "modality": "text" }` (modality optional for text) |
+| Standard | `POST` | `/runsync` | `{ "input": { "model": "…", "input": "…", "modality": "text" } }` |
#### Response (both flavours)
@@ -146,34 +192,90 @@ Below are minimal `curl` snippets so you can copy-paste from any machine.
> Replace `` with your endpoint ID and `` with a [RunPod API key](https://docs.runpod.io/get-started/api-keys).
-### OpenAI-Compatible Calls
+### List Models
```bash
-# List models
+# OpenAI-compatible format
curl -H "Authorization: Bearer " \
https://api.runpod.ai/v2//openai/v1/models
-# Create embeddings
+# Standard RunPod format
+curl -X POST \
+ -H "Content-Type: application/json" \
+ -d '{"input":{"openai_route":"/v1/models"}}' \
+  https://api.runpod.ai/v2/<ENDPOINT_ID>/runsync
+```
+
+### Text Embeddings
+
+```bash
+# OpenAI-compatible format
curl -X POST \
-H "Authorization: Bearer " \
-H "Content-Type: application/json" \
- -d '{"model":"BAAI/bge-small-en-v1.5","input":"Hello world"}' \
+ -d '{"model":"BAAI/bge-small-en-v1.5","input":"Hello world","modality":"text"}' \
https://api.runpod.ai/v2//openai/v1/embeddings
+
+# Standard RunPod format
+curl -X POST \
+ -H "Content-Type: application/json" \
+ -d '{"input":{"model":"BAAI/bge-small-en-v1.5","input":"Hello world","modality":"text"}}' \
+  https://api.runpod.ai/v2/<ENDPOINT_ID>/runsync
```
-### Standard RunPod Calls
+### Image Embeddings
```bash
-# Create embeddings (wait for result)
+# OpenAI-compatible format (image URL)
+curl -X POST \
+  -H "Authorization: Bearer <API_KEY>" \
+ -H "Content-Type: application/json" \
+ -d '{"model":"patrickjohncyh/fashion-clip","input":"https://example.com/image.jpg","modality":"image"}' \
+  https://api.runpod.ai/v2/<ENDPOINT_ID>/openai/v1/embeddings
+
+# Standard RunPod format (base64 image)
curl -X POST \
-H "Content-Type: application/json" \
- -d '{"input":{"model":"BAAI/bge-small-en-v1.5","input":"Hello world"}}' \
+ -d '{"input":{"model":"patrickjohncyh/fashion-clip","input":"data:image/png;base64,iVBORw0KG...","modality":"image"}}' \
https://api.runpod.ai/v2//runsync
+```
+
+> **Note:** Send one request per modality. If you need both text and image embeddings, issue two calls so each payload is validated consistently.
+
+### Reranking
+
+```bash
+# OpenAI-compatible format
+curl -X POST \
+  -H "Authorization: Bearer <API_KEY>" \
+ -H "Content-Type: application/json" \
+ -d '{
+ "model": "BAAI/bge-reranker-large",
+ "query": "Which product has warranty coverage?",
+ "docs": [
+ "Product A comes with a 2-year warranty",
+ "Product B is available in red and blue colors",
+ "All electronics include a standard 1-year warranty"
+ ],
+ "return_docs": true
+ }' \
+  https://api.runpod.ai/v2/<ENDPOINT_ID>/openai/v1/rerank
-# Rerank
+# Standard RunPod format
curl -X POST \
-H "Content-Type: application/json" \
- -d '{"input":{"model":"BAAI/bge-reranker-large","query":"Which product has warranty coverage?","docs":["Product A comes with a 2-year warranty","Product B is available in red and blue colors","All electronics include a standard 1-year warranty"],"return_docs":true}}' \
+ -d '{
+ "input": {
+ "model": "BAAI/bge-reranker-large",
+ "query": "Which product has warranty coverage?",
+ "docs": [
+ "Product A comes with a 2-year warranty",
+ "Product B is available in red and blue colors",
+ "All electronics include a standard 1-year warranty"
+ ],
+ "return_docs": true
+ }
+ }' \
https://api.runpod.ai/v2//runsync
```
diff --git a/docker-compose.yml b/docker-compose.yml
index 2c81997..0bd5734 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -13,7 +13,7 @@ services:
count: all
capabilities: [gpu]
environment:
- MODEL_NAMES: "BAAI/bge-small-en-v1.5"
+ MODEL_NAMES: "BAAI/bge-small-en-v1.5;patrickjohncyh/fashion-clip"
NVIDIA_VISIBLE_DEVICES: "all"
volumes:
- ./data/runpod-volume:/runpod-volume
diff --git a/requirements.txt b/requirements.txt
index 9812a81..3bb060a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,6 @@
runpod~=1.7.0
-infinity-emb[all]==0.0.76
+infinity-emb[all]==0.0.77
+optimum==1.24.0
einops # deployment of custom code with nomic
+httpx>=0.27.0
git+https://github.com/pytorch-labs/float8_experimental.git
diff --git a/src/config.py b/src/config.py
index 6266ed5..bfa28cd 100644
--- a/src/config.py
+++ b/src/config.py
@@ -1,10 +1,20 @@
import os
from dotenv import load_dotenv
from functools import cached_property
+from typing import Optional
DEFAULT_BATCH_SIZE = 32
DEFAULT_BACKEND = "torch"
+DEFAULT_HTTP_CLIENT_USER_AGENT = (
+ "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
+ "AppleWebKit/537.36 (KHTML, like Gecko) "
+ "Chrome/120.0.0.0 Safari/537.36"
+)
+DEFAULT_HTTP_CLIENT_TIMEOUT_SECONDS = 10.0
+DEFAULT_HTTP_CLIENT_MAX_CONNECTIONS = 50
+DEFAULT_HTTP_CLIENT_MAX_KEEPALIVE_CONNECTIONS = 20
+
if not os.environ.get("INFINITY_QUEUE_SIZE"):
# how many items can be in the queue
os.environ["INFINITY_QUEUE_SIZE"] = "48000"
@@ -56,3 +66,60 @@ def dtypes(self) -> list[str]:
@cached_property
def runpod_max_concurrency(self) -> int:
return int(os.environ.get("RUNPOD_MAX_CONCURRENCY", 300))
+
+
+class HttpClientConfig:
+ ENV_USER_AGENT = "HTTP_CLIENT_USER_AGENT"
+ ENV_TIMEOUT = "HTTP_CLIENT_TIMEOUT"
+ ENV_MAX_CONNECTIONS = "HTTP_CLIENT_MAX_CONNECTIONS"
+ ENV_MAX_KEEPALIVE = "HTTP_CLIENT_MAX_KEEPALIVE_CONNECTIONS"
+
+ def __init__(self):
+ load_dotenv()
+
+ def _get_env_float(self, key: str, default: float) -> float:
+ value = os.environ.get(key)
+ if value is None:
+ return default
+ try:
+ return float(value)
+ except ValueError as exc:
+ raise ValueError(f"Environment variable {key} must be a float, got {value!r}") from exc
+
+ def _get_env_int(self, key: str, default: int) -> int:
+ value = os.environ.get(key)
+ if value is None:
+ return default
+ try:
+ return int(value)
+ except ValueError as exc:
+ raise ValueError(f"Environment variable {key} must be an integer, got {value!r}") from exc
+
+ def _get_env_str(self, key: str, default: str) -> str:
+ value: Optional[str] = os.environ.get(key)
+ if value is None:
+ return default
+ value = value.strip()
+ return value or default
+
+ @cached_property
+ def user_agent(self) -> str:
+ return self._get_env_str(self.ENV_USER_AGENT, DEFAULT_HTTP_CLIENT_USER_AGENT)
+
+ @cached_property
+ def timeout_seconds(self) -> float:
+ return self._get_env_float(self.ENV_TIMEOUT, DEFAULT_HTTP_CLIENT_TIMEOUT_SECONDS)
+
+ @cached_property
+ def max_connections(self) -> int:
+ return self._get_env_int(
+ self.ENV_MAX_CONNECTIONS,
+ DEFAULT_HTTP_CLIENT_MAX_CONNECTIONS,
+ )
+
+ @cached_property
+ def max_keepalive_connections(self) -> int:
+ return self._get_env_int(
+ self.ENV_MAX_KEEPALIVE,
+ DEFAULT_HTTP_CLIENT_MAX_KEEPALIVE_CONNECTIONS,
+ )
diff --git a/src/embedding_service.py b/src/embedding_service.py
index a5afc44..0dc4e30 100644
--- a/src/embedding_service.py
+++ b/src/embedding_service.py
@@ -1,13 +1,21 @@
-from config import EmbeddingServiceConfig
+import asyncio
+import logging
+
from infinity_emb.engine import AsyncEngineArray, EngineArgs
+from infinity_emb.primitives import ModelNotDeployedError
+from PIL import Image
+
+from config import EmbeddingServiceConfig
+from http_client import create_http_client
+from multimodal_utils import validate_item_for_modality
from utils import (
- OpenAIModelInfo,
ModelInfo,
+ OpenAIModelInfo,
list_embeddings_to_response,
to_rerank_response,
)
-import asyncio
+logger = logging.getLogger(__name__)
class EmbeddingService:
@@ -31,18 +39,28 @@ def __init__(self):
self.engine_array = AsyncEngineArray.from_args(engine_args)
self.is_running = False
self.sepamore = asyncio.Semaphore(1)
+ self.http_client = None
async def start(self):
"""starts the engine background loop"""
async with self.sepamore:
if not self.is_running:
await self.engine_array.astart()
+ if self.http_client is None:
+ self.http_client = create_http_client()
+ logger.info("Created persistent HTTP client for image downloads")
+
self.is_running = True
async def stop(self):
"""stops the engine background loop"""
async with self.sepamore:
if self.is_running:
+ if self.http_client is not None:
+ await self.http_client.aclose()
+ self.http_client = None
+ logger.info("Closed HTTP client")
+
await self.engine_array.astop()
self.is_running = False
@@ -56,25 +74,141 @@ def list_models(self) -> list[str]:
async def route_openai_get_embeddings(
self,
- embedding_input: str | list[str],
+ embedding_input: str | list[str] | list[str | bytes | Image.Image],
model_name: str,
+ modality: str = "text",
return_as_list: bool = False,
):
- """returns embeddings for the input text"""
- if not self.is_running:
- await self.start()
- if not isinstance(embedding_input, list):
- embedding_input = [embedding_input]
-
- embeddings, usage = await self.engine_array[model_name].embed(embedding_input)
- if return_as_list:
- return [
- list_embeddings_to_response(embeddings, model=model_name, usage=usage)
- ]
- else:
- return list_embeddings_to_response(
- embeddings, model=model_name, usage=usage
+ """
+ Returns embeddings for the input based on specified modality.
+
+ Args:
+ embedding_input: Input text(s) or image(s) to embed
+ model_name: Name of the model to use
+ modality: Type of input - "text", "image", or "audio" (not yet implemented)
+ return_as_list: Whether to return results as a list
+
+ Raises:
+ ValueError: If model not available, modality invalid, or validation fails
+ NotImplementedError: If modality is "audio"
+ """
+ try:
+ if not self.is_running:
+ await self.start()
+
+ available_models = self.list_models()
+ if model_name not in available_models:
+ available_models_msg = (
+ ", ".join(available_models) if available_models else "none"
+ )
+ error_msg = (
+ f"Model '{model_name}' is not deployed. "
+ f"Available deployments: {available_models_msg}. "
+ f"Deploy the requested model or choose one of the available models."
+ )
+ logger.error(error_msg)
+ raise ValueError(error_msg)
+
+ if not isinstance(embedding_input, list):
+ embedding_input = [embedding_input]
+
+ # Validate all items for the specified modality in parallel
+ try:
+ validated_items = await asyncio.gather(
+ *[
+ validate_item_for_modality(
+ item, modality, idx, client=self.http_client
+ )
+ for idx, item in enumerate(embedding_input)
+ ]
+ )
+ except (ValueError, NotImplementedError) as e:
+ logger.error(f"Validation failed for modality '{modality}': {e}")
+ raise
+
+ logger.info(
+ f"Processing {len(validated_items)} {modality} items for model '{model_name}'"
+ )
+
+ # Route to appropriate embedding method based on modality
+ try:
+ if modality == "text":
+ logger.debug(
+ f"Calling .embed() with {len(validated_items)} text items"
+ )
+ embeddings, usage = await self.engine_array[model_name].embed(
+ validated_items
+ )
+ logger.debug(
+ f"Successfully got {len(embeddings)} text embeddings"
+ )
+
+ elif modality == "image":
+ logger.debug(
+ f"Calling .image_embed() with {len(validated_items)} image items"
+ )
+ embeddings, usage = await self.engine_array[model_name].image_embed(
+ images=validated_items
+ )
+ logger.debug(
+ f"Successfully got {len(embeddings)} image embeddings"
+ )
+
+ elif modality == "audio":
+ raise NotImplementedError(
+ "Audio modality is not yet implemented. "
+ "Currently supported modalities: 'text', 'image'"
+ )
+
+ else:
+ raise ValueError(
+ f"Invalid modality: '{modality}'. "
+ f"Supported modalities: 'text', 'image', 'audio' (not yet implemented)"
+ )
+ except ModelNotDeployedError as e:
+ available_capabilities = sorted(
+ getattr(self.engine_array[model_name], "capabilities", set())
+ )
+ capabilities_hint = (
+ ", ".join(available_capabilities) if available_capabilities else "none"
+ )
+ error_msg = (
+ f"Model '{model_name}' does not expose the '{modality}' capability. "
+ f"Detected capabilities: {capabilities_hint}. Deploy a model that supports "
+ f"this modality or adjust the request."
+ )
+ logger.error(f"{error_msg} Original error: {e}")
+ raise ValueError(error_msg) from e
+
+ if return_as_list:
+ return [
+ list_embeddings_to_response(
+ embeddings,
+ model=model_name,
+ usage=usage, # type: ignore[arg-type]
+ )
+ ]
+ else:
+ return list_embeddings_to_response(
+ embeddings,
+ model=model_name,
+ usage=usage, # type: ignore[arg-type]
+ )
+
+ except (ValueError, NotImplementedError) as e:
+ logger.warning(
+ f"Expected error in route_openai_get_embeddings: {type(e).__name__}: {e}"
+ )
+ raise
+ except Exception as e:
+ logger.exception(
+ f"Unexpected error in route_openai_get_embeddings: "
+ f"model='{model_name}', modality='{modality}', "
+ f"input_length={len(embedding_input) if isinstance(embedding_input, list) else 1}"
)
+ raise RuntimeError(
+ f"Internal error while processing embeddings: {type(e).__name__}: {str(e)}"
+ ) from e
async def infinity_rerank(
self, query: str, docs: str, return_docs: str, model_name: str
diff --git a/src/handler.py b/src/handler.py
index 8f38e78..0d43789 100644
--- a/src/handler.py
+++ b/src/handler.py
@@ -25,16 +25,20 @@ async def async_generator_handler(job: dict[str, Any]):
if openai_route and openai_route == "/v1/models":
call_fn, kwargs = embedding_service.route_openai_models, {}
elif openai_route and openai_route == "/v1/embeddings":
- model_name = openai_input.get("model")
if not openai_input:
return create_error_response("Missing input").model_dump()
+ model_name = openai_input.get("model")
if not model_name:
return create_error_response(
"Did not specify model in openai_input"
).model_dump()
+
+ modality = openai_input.get("modality", "text")
+
call_fn, kwargs = embedding_service.route_openai_get_embeddings, {
"embedding_input": openai_input.get("input"),
"model_name": model_name,
+ "modality": modality,
"return_as_list": True,
}
else:
@@ -51,9 +55,11 @@ async def async_generator_handler(job: dict[str, Any]):
"model_name": job_input.get("model"),
}
elif job_input.get("input"):
+ modality = job_input.get("modality", "text")
call_fn, kwargs = embedding_service.route_openai_get_embeddings, {
"embedding_input": job_input.get("input"),
"model_name": job_input.get("model"),
+ "modality": modality,
}
else:
return create_error_response(f"Invalid input: {job}").model_dump()
diff --git a/src/http_client.py b/src/http_client.py
new file mode 100644
index 0000000..dc01a8b
--- /dev/null
+++ b/src/http_client.py
@@ -0,0 +1,32 @@
+import httpx
+
+from config import HttpClientConfig
+
+_CLIENT_CONFIG = HttpClientConfig()
+
+
+def create_http_client() -> httpx.AsyncClient:
+ """
+ Create a configured httpx.AsyncClient.
+
+ Returns:
+ httpx.AsyncClient with proper timeout, limits, and headers configured.
+ """
+ limits = httpx.Limits(
+ max_connections=_CLIENT_CONFIG.max_connections,
+ max_keepalive_connections=_CLIENT_CONFIG.max_keepalive_connections,
+ )
+
+ timeout = httpx.Timeout(_CLIENT_CONFIG.timeout_seconds)
+
+ headers = {
+ "User-Agent": _CLIENT_CONFIG.user_agent,
+ }
+
+ return httpx.AsyncClient(
+ limits=limits,
+ timeout=timeout,
+ headers=headers,
+ follow_redirects=True,
+ trust_env=True,
+ )
diff --git a/src/multimodal_utils.py b/src/multimodal_utils.py
new file mode 100644
index 0000000..53a8b28
--- /dev/null
+++ b/src/multimodal_utils.py
@@ -0,0 +1,204 @@
+import base64
+import io
+import logging
+import re
+from typing import Any
+from PIL import Image
+
+logger = logging.getLogger(__name__)
+
+
+def _ensure_rgb(img: Image.Image) -> Image.Image:
+ """
+ Ensure image is in RGB format (required for CLIP models).
+
+ Args:
+ img: PIL Image in any mode
+
+ Returns:
+ PIL Image in RGB mode
+ """
+ if img.mode != 'RGB':
+ logger.debug(f"Converting image from {img.mode} to RGB")
+ return img.convert('RGB')
+ return img
+
+
+def _is_url(text: str) -> bool:
+ """Check if string is a URL"""
+ return text.startswith('http://') or text.startswith('https://')
+
+
+async def _download_image_from_url(url: str, client) -> Image.Image:
+ """
+ Download image from URL using httpx with proper User-Agent and timeout.
+
+ Args:
+ url: HTTP(S) URL to download image from
+ client: httpx.AsyncClient instance with configured timeout and limits
+
+ Returns:
+ PIL.Image in RGB format
+
+ Raises:
+ ValueError: If download fails or content is not a valid image
+ """
+ try:
+ logger.debug(f"Downloading image from URL: {url}")
+ response = await client.get(url)
+ response.raise_for_status()
+
+ img_bytes = response.content
+ logger.debug(f"Downloaded {len(img_bytes)} bytes from {url} (status: {response.status_code})")
+ except Exception as e:
+ raise ValueError(f"Failed to download image from URL: {type(e).__name__}: {e}") from e
+
+ try:
+ img = Image.open(io.BytesIO(img_bytes))
+ img.load() # Force load to validate it's a real image
+ logger.debug(f"Successfully loaded image from URL: {img.size} {img.mode}")
+
+ return _ensure_rgb(img)
+ except Exception as e:
+ raise ValueError(f"Failed to decode image from URL: {type(e).__name__}: {e}") from e
+
+
+def _is_base64_image(data: str) -> Image.Image | None:
+ """
+ Try to decode string as base64-encoded image.
+ Supports both data URI format (data:image/...) and raw base64.
+
+ Returns:
+ PIL.Image in RGB format, or None if not a valid base64 image
+
+ Note: Converts all images to RGB format for compatibility with multimodal models.
+ """
+ try:
+ # Handle data URI format: data:image/png;base64,iVBORw0KG...
+ if data.startswith('data:'):
+ match = re.match(r'data:image/[^;]+;base64,(.+)', data)
+ if match:
+ base64_data = match.group(1)
+ logger.debug(f"Matched data URI, extracted base64 data (length: {len(base64_data)})")
+ else:
+ logger.debug("data: URI does not match expected format")
+ return None
+ else:
+ # Try raw base64
+ base64_data = data
+ logger.debug("Treating as raw base64")
+
+ img_bytes = base64.b64decode(base64_data)
+ logger.debug(f"Decoded base64 to {len(img_bytes)} bytes")
+
+ img = Image.open(io.BytesIO(img_bytes))
+ img.load() # Force load to validate it's a real image
+ logger.debug(f"Successfully loaded image: {img.size} {img.mode}")
+
+ return _ensure_rgb(img)
+ except Exception as e:
+ logger.warning(f"Failed to decode base64 image: {type(e).__name__}: {e}")
+ return None
+
+
+def validate_text_item(item: Any) -> str:
+ """
+ Validate and convert item to text string.
+ Raises ValueError if item cannot be converted to text.
+ """
+ if isinstance(item, str):
+ return item
+
+ # Try to convert to string
+ try:
+ return str(item)
+ except Exception as e:
+ raise ValueError(f"Cannot convert item to text: {type(item).__name__}") from e
+
+
+async def validate_image_item(item: Any, client=None) -> Image.Image:
+ """
+ Validate and process image item.
+ Accepts: PIL.Image, bytes, URL string, or base64 string.
+ Returns PIL.Image in RGB format.
+ Raises ValueError if item is not a valid image format.
+
+ Args:
+ item: The image item to validate (PIL.Image, bytes, URL string, or base64 string)
+ client: Optional httpx.AsyncClient for downloading URL images with timeout and User-Agent
+
+ Note: All images are converted to RGB format for CLIP compatibility.
+ """
+ # PIL Image - convert to RGB if needed
+ if isinstance(item, Image.Image):
+ return _ensure_rgb(item)
+
+ # Bytes - decode to PIL Image
+ if isinstance(item, bytes):
+ try:
+ img = Image.open(io.BytesIO(item))
+ img.load()
+ logger.debug(f"Loaded image from bytes: {img.size} {img.mode}")
+ return _ensure_rgb(img)
+ except Exception as e:
+ raise ValueError(f"Failed to decode image from bytes: {type(e).__name__}: {e}") from e
+
+ if isinstance(item, str):
+ # URL images - download with proper User-Agent and timeout
+ if _is_url(item):
+ if client is None:
+ raise ValueError("HTTP client required for downloading images from URLs")
+ return await _download_image_from_url(item, client)
+
+ # Base64 images - decode and validate (already converted to RGB)
+ img = _is_base64_image(item)
+ if img is not None:
+ return img
+
+ # Not a valid image format
+ raise ValueError(
+ "String is not a valid image format (must be URL starting with http:// or https://, "
+ "or base64-encoded image with 'data:image/...' prefix)"
+ )
+
+ raise ValueError(
+ f"Invalid image type: {type(item).__name__}. "
+ f"Expected PIL.Image, bytes, URL string, or base64 string."
+ )
+
+
+async def validate_item_for_modality(item: Any, modality: str, index: int, client=None) -> Any:
+ """
+ Validate a single item for the specified modality.
+
+ Args:
+ item: The input item to validate
+ modality: One of "text", "image", "audio"
+ index: The index of the item in the batch (for error messages)
+ client: Optional httpx.AsyncClient for downloading URL images with timeout
+
+ Returns:
+ Validated and processed item suitable for infinity_emb
+
+ Raises:
+ ValueError: If item is not valid for the specified modality
+ NotImplementedError: If modality is "audio"
+ """
+ try:
+ if modality == "text":
+ return validate_text_item(item)
+ elif modality == "image":
+ return await validate_image_item(item, client=client)
+ elif modality == "audio":
+ raise NotImplementedError(
+ "Audio modality is not yet implemented. "
+ "Currently supported modalities: 'text', 'image'"
+ )
+ else:
+ raise ValueError(
+ f"Invalid modality: '{modality}'. "
+ f"Supported modalities: 'text', 'image', 'audio' (not yet implemented)"
+ )
+ except (ValueError, NotImplementedError) as e:
+ # Re-raise with index information for better error messages
+ raise type(e)(f"Item at index {index}: {str(e)}") from e
diff --git a/src/utils.py b/src/utils.py
index 1a59304..4d8b42e 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -1,8 +1,7 @@
-from http import HTTPStatus
-from typing import Annotated, Any, Dict, List, Literal, Optional, Union
-from typing import Any, Dict, Iterable, List, Optional, Union
-from uuid import uuid4
import time
+from http import HTTPStatus
+from typing import Annotated, Any, Dict, Iterable, List, Literal, Optional, Union
+
import numpy as np
import numpy.typing as npt
from pydantic import BaseModel, Field, conlist
diff --git a/tests/test_modality.py b/tests/test_modality.py
new file mode 100644
index 0000000..870188d
--- /dev/null
+++ b/tests/test_modality.py
@@ -0,0 +1,447 @@
+"""Integration tests for the explicit modality API using pytest."""
+
+import base64
+import io
+from collections.abc import Iterable, Mapping
+from typing import Any
+
+import pytest
+import requests
+from PIL import Image
+
+BASE_URL = "http://localhost:8000"
+RUNSYNC_URL = f"{BASE_URL}/runsync"
+DEFAULT_TIMEOUT_SECONDS = 30
+
+
+def _post_runsync(
+ payload: Mapping[str, Any], timeout: int | float = DEFAULT_TIMEOUT_SECONDS
+) -> requests.Response:
+ """Send a request to the worker and print useful diagnostics."""
+ response = requests.post(RUNSYNC_URL, json=payload, timeout=timeout)
+ print(f"POST {RUNSYNC_URL} -> {response.status_code}")
+ return response
+
+
+def _extract_output(result: Mapping[str, Any]) -> Mapping[str, Any] | None:
+ """Normalise RunPod /runsync output into a single mapping."""
+ output = result.get("output")
+ if isinstance(output, list) and output:
+ return output[0]
+ if isinstance(output, Mapping):
+ return output
+ return None
+
+
+def _extract_error_message(result: Mapping[str, Any]) -> str | None:
+ output = _extract_output(result)
+ if output and output.get("object") == "error":
+ message = output.get("message")
+ return str(message) if message is not None else None
+ return None
+
+
+def generate_red_square_base64() -> str:
+ """Generate a 5x5 red square PNG encoded as a data URI."""
+ img = Image.new("RGB", (5, 5), color=(255, 0, 0))
+ buffer = io.BytesIO()
+ img.save(buffer, format="PNG")
+ buffer.seek(0)
+ img_base64 = base64.b64encode(buffer.read()).decode("utf-8")
+ return f"data:image/png;base64,{img_base64}"
+
+
+@pytest.mark.parametrize(
+ "name,payload,expected_count",
+ [
+ (
+ "text modality",
+ {
+ "input": {
+ "openai_route": "/v1/embeddings",
+ "openai_input": {
+ "model": "BAAI/bge-small-en-v1.5",
+ "input": ["Hello world", "How are you?"],
+ "modality": "text",
+ },
+ }
+ },
+ 2,
+ ),
+ (
+ "image modality (url)",
+ {
+ "input": {
+ "openai_route": "/v1/embeddings",
+ "openai_input": {
+ "model": "patrickjohncyh/fashion-clip",
+ "input": [
+ "https://raw.githubusercontent.com/pytorch/hub/master/images/dog.jpg"
+ ],
+ "modality": "image",
+ },
+ }
+ },
+ 1,
+ ),
+ (
+ "image modality (base64)",
+ {
+ "input": {
+ "openai_route": "/v1/embeddings",
+ "openai_input": {
+ "model": "patrickjohncyh/fashion-clip",
+ "input": [generate_red_square_base64()],
+ "modality": "image",
+ },
+ }
+ },
+ 1,
+ ),
+ (
+ "default text modality",
+ {
+ "input": {
+ "openai_route": "/v1/embeddings",
+ "openai_input": {
+ "model": "BAAI/bge-small-en-v1.5",
+ "input": ["Default modality test"],
+ },
+ }
+ },
+ 1,
+ ),
+ ],
+)
+def test_modality_success(
+ name: str, payload: Mapping[str, Any], expected_count: int
+) -> None:
+ response = _post_runsync(payload)
+ assert response.status_code == 200, f"HTTP error for {name}: {response.text}"
+
+ result = response.json()
+ assert result.get("status") == "COMPLETED", (
+ f"Unexpected status for {name}: {result}"
+ )
+
+ output = _extract_output(result) or {}
+ data = output.get("data", [])
+ assert isinstance(data, Iterable), f"Missing data for {name}: {output}"
+
+ if isinstance(data, list):
+ assert len(data) == expected_count, (
+ f"Expected {expected_count} embeddings for {name}, got {len(data)}"
+ )
+ else:
+ pytest.fail(f"Unexpected data format for {name}: {type(data)}")
+
+
+def test_wrong_modality_error() -> None:
+ payload = {
+ "input": {
+ "openai_route": "/v1/embeddings",
+ "openai_input": {
+ "model": "BAAI/bge-small-en-v1.5",
+ "input": [
+ "https://raw.githubusercontent.com/pytorch/hub/master/images/dog.jpg"
+ ],
+ "modality": "image",
+ },
+ }
+ }
+
+ response = _post_runsync(payload)
+ if response.status_code != 200:
+ assert response.status_code >= 400
+ return
+
+ result = response.json()
+ error_msg = _extract_error_message(result)
+ assert error_msg is not None, f"Expected error object, got: {result}"
+ lowered = error_msg.lower()
+ assert "does not expose the 'image' capability" in lowered
+ assert "detected capabilities" in lowered
+
+
+def test_audio_not_implemented() -> None:
+ payload = {
+ "input": {
+ "openai_route": "/v1/embeddings",
+ "openai_input": {
+ "model": "BAAI/bge-small-en-v1.5",
+ "input": ["audio data"],
+ "modality": "audio",
+ },
+ }
+ }
+
+ response = _post_runsync(payload)
+ if response.status_code != 200:
+ assert response.status_code >= 400
+ return
+
+ result = response.json()
+ error_msg = _extract_error_message(result)
+ assert error_msg is not None, f"Expected NotImplementedError output, got: {result}"
+ assert "not yet implemented" in error_msg.lower()
+
+
+def test_validation_flexibility() -> None:
+ payload = {
+ "input": {
+ "openai_route": "/v1/embeddings",
+ "openai_input": {
+ "model": "BAAI/bge-small-en-v1.5",
+ "input": ["https://raw.githubusercontent.com/pytorch/hub/master/images/dog.jpg"],
+ "modality": "text",
+ },
+ }
+ }
+
+ response = _post_runsync(payload)
+ assert response.status_code == 200, f"HTTP error: {response.text}"
+
+ result = response.json()
+ assert result.get("status") == "COMPLETED", (
+ f"Expected success treating URL as text: {result}"
+ )
+
+
+@pytest.mark.parametrize(
+ "payload,expected_count",
+ [
+ (
+ {
+ "input": {
+ "openai_route": "/v1/embeddings",
+ "openai_input": {
+ "model": "BAAI/bge-small-en-v1.5",
+ "input": [],
+ "modality": "text",
+ },
+ }
+ },
+ 0,
+ ),
+ (
+ {
+ "input": {
+ "openai_route": "/v1/embeddings",
+ "openai_input": {
+ "model": "BAAI/bge-small-en-v1.5",
+ "input": "Single text string",
+ "modality": "text",
+ },
+ }
+ },
+ 1,
+ ),
+ ],
+)
+def test_text_edge_cases(payload: Mapping[str, Any], expected_count: int) -> None:
+ response = _post_runsync(payload)
+ if response.status_code != 200:
+ assert response.status_code >= 400
+ return
+
+ result = response.json()
+ if result.get("status") != "COMPLETED":
+ assert _extract_error_message(result) is not None
+ return
+
+ output = _extract_output(result) or {}
+ data = output.get("data", [])
+ assert isinstance(data, list)
+ assert len(data) == expected_count
+
+
+def test_edge_very_long_text() -> None:
+ long_text = "This is a test sentence. " * 200
+ payload = {
+ "input": {
+ "openai_route": "/v1/embeddings",
+ "openai_input": {
+ "model": "BAAI/bge-small-en-v1.5",
+ "input": long_text,
+ "modality": "text",
+ },
+ }
+ }
+
+ response = _post_runsync(payload)
+ assert response.status_code == 200, f"HTTP error: {response.text}"
+
+ result = response.json()
+ assert result.get("status") == "COMPLETED", f"Unexpected status: {result}"
+
+
+def test_edge_empty_string() -> None:
+ payload = {
+ "input": {
+ "openai_route": "/v1/embeddings",
+ "openai_input": {
+ "model": "BAAI/bge-small-en-v1.5",
+ "input": "",
+ "modality": "text",
+ },
+ }
+ }
+
+ response = _post_runsync(payload)
+ if response.status_code != 200:
+ assert response.status_code >= 400
+ return
+
+ result = response.json()
+ if result.get("status") == "COMPLETED":
+ output = _extract_output(result) or {}
+ data = output.get("data", [])
+ assert isinstance(data, list)
+ else:
+ assert _extract_error_message(result) is not None
+
+
+def test_edge_invalid_modality() -> None:
+ payload = {
+ "input": {
+ "openai_route": "/v1/embeddings",
+ "openai_input": {
+ "model": "BAAI/bge-small-en-v1.5",
+ "input": "test",
+ "modality": "video",
+ },
+ }
+ }
+
+ response = _post_runsync(payload)
+ if response.status_code != 200:
+ assert response.status_code >= 400
+ return
+
+ result = response.json()
+ error_msg = _extract_error_message(result)
+ assert error_msg is not None, f"Expected invalid modality error: {result}"
+ assert "invalid modality" in error_msg.lower()
+
+
+def test_edge_missing_model() -> None:
+ payload = {
+ "input": {
+ "openai_route": "/v1/embeddings",
+ "openai_input": {
+ "input": "test",
+ "modality": "text",
+ },
+ }
+ }
+
+ response = _post_runsync(payload)
+ if response.status_code != 200:
+ assert response.status_code >= 400
+ return
+
+ result = response.json()
+ error_msg = _extract_error_message(result)
+ assert error_msg is not None, f"Expected missing model error: {result}"
+
+
+def test_edge_nonexistent_model() -> None:
+ payload = {
+ "input": {
+ "openai_route": "/v1/embeddings",
+ "openai_input": {
+ "model": "nonexistent/model-12345",
+ "input": "test",
+ "modality": "text",
+ },
+ }
+ }
+
+ response = _post_runsync(payload)
+ if response.status_code != 200:
+ assert response.status_code >= 400
+ return
+
+ result = response.json()
+ error_msg = _extract_error_message(result)
+ assert error_msg is not None, f"Expected model missing error: {result}"
+ lowered = error_msg.lower()
+ assert "is not deployed" in lowered
+ assert "available deployments" in lowered
+
+
+def test_edge_invalid_image_url() -> None:
+ payload = {
+ "input": {
+ "openai_route": "/v1/embeddings",
+ "openai_input": {
+ "model": "patrickjohncyh/fashion-clip",
+ "input": "https://example.com/nonexistent-image-12345.jpg",
+ "modality": "image",
+ },
+ }
+ }
+
+ response = _post_runsync(payload)
+ if response.status_code != 200:
+ assert response.status_code >= 400
+ return
+
+ result = response.json()
+ assert result.get("status") in {"COMPLETED", "FAILED"}, (
+ f"Unexpected status: {result}"
+ )
+
+
+def test_edge_special_characters() -> None:
+ payload = {
+ "input": {
+ "openai_route": "/v1/embeddings",
+ "openai_input": {
+ "model": "BAAI/bge-small-en-v1.5",
+ "input": "Hello 世界! 🌍 Special chars: @#$%^&*()_+-=[]{}|;:',.<>?/~`",
+ "modality": "text",
+ },
+ }
+ }
+
+ response = _post_runsync(payload)
+ assert response.status_code == 200, f"HTTP error: {response.text}"
+
+ result = response.json()
+ assert result.get("status") == "COMPLETED", f"Unexpected status: {result}"
+
+
+def test_edge_corrupted_base64() -> None:
+ invalid_payloads = [
+ "data:image/png;base64,NotValidBase64Data!!!",
+ "data:image/png;base64,iVBORw0KGgoAAAA",
+ "data:image/jpeg;base64,/9j/4AAQSkZJRg",
+ "data:image/png;base64,SGVsbG8gV29ybGQh",
+ ]
+
+ for idx, invalid_base64 in enumerate(invalid_payloads, start=1):
+ print(
+ f"\n Test variant {idx}/{len(invalid_payloads)}: {invalid_base64[:50]}..."
+ )
+ payload = {
+ "input": {
+ "openai_route": "/v1/embeddings",
+ "openai_input": {
+ "model": "patrickjohncyh/fashion-clip",
+ "input": invalid_base64,
+ "modality": "image",
+ },
+ }
+ }
+
+ response = _post_runsync(payload)
+ if response.status_code != 200:
+ assert response.status_code >= 400
+ continue
+
+ result = response.json()
+ error_msg = _extract_error_message(result)
+ assert error_msg is not None, f"Expected error for corrupted base64: {result}"
+
+ print("\n✓ All corrupted base64 variants handled without crash")