Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 35 additions & 1 deletion fastembed/late_interaction_multimodal/colpali.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,10 @@ class ColPali(LateInteractionMultimodalEmbeddingBase, OnnxMultimodalModel[NumpyA
BOS_TOKEN = "<s>"
PAD_TOKEN = "<pad>"
QUERY_MARKER_TOKEN_ID = [2, 5098]
IMAGE_TOKEN_ID = 257152 # The '<image>' special token
IMAGE_PLACEHOLDER_SIZE = (3, 448, 448)
EMPTY_TEXT_PLACEHOLDER = np.array(
[257152] * 1024 + [2, 50721, 573, 2416, 235265, 108]
[IMAGE_TOKEN_ID] * 1024 + [2, 50721, 573, 2416, 235265, 108]
) # This is a tokenization of '<image>' * 1024 + '<bos>Describe the image.\n' line which is used as placeholder
# while processing an image
EVEN_ATTENTION_MASK = np.array([1] * 1030)
Expand Down Expand Up @@ -298,6 +299,39 @@ def embed_image(
**kwargs,
)

def get_image_mask(
self,
images: ImageInput | Iterable[ImageInput],
**kwargs: Any,
) -> list[NumpyArray]:
"""
Generate image token masks for ColPali embeddings.

For ColPali, image embeddings use 1030 tokens:
- Tokens 0-1023: Image tokens (token ID 257152)
- Tokens 1024-1029: Text tokens from prompt "Describe the image.\\n"

Args:
images: Single image or iterable of images
**kwargs: Additional processing arguments (reserved for future use)

Returns:
List of binary masks (dtype=bool) where True = image token (ID 257152), False = other tokens.
"""
from pathlib import Path

# Ensure images is iterable
is_single = isinstance(images, (str, bytes, Path)) or hasattr(images, "read")
images_to_process: Iterable[ImageInput] = [images] if is_single else images # type: ignore[assignment, list-item]

# Generate masks - all images get the same mask based on fixed tokenization pattern
masks: list[NumpyArray] = []
for _ in images_to_process:
mask: NumpyArray = self.EMPTY_TEXT_PLACEHOLDER == self.IMAGE_TOKEN_ID
masks.append(mask)

return masks

@classmethod
def _get_text_worker_class(cls) -> Type[TextEmbeddingWorker[NumpyArray]]:
return ColPaliTextEmbeddingWorker
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,3 +183,36 @@ def token_count(
return self.model.token_count(
texts, batch_size=batch_size, include_extension=include_extension, **kwargs
)

def get_image_mask(
self,
images: ImageInput | Iterable[ImageInput],
**kwargs: Any,
) -> list[NumpyArray]:
"""
Generate binary masks identifying image tokens in processed image sequences.

This method processes images and returns masks indicating which tokens in the
resulting sequence correspond to image content (value=1) vs text/special tokens (value=0).

Args:
images: Single image or iterable of images (file paths, bytes, or PIL Image objects)
**kwargs: Additional keyword arguments (reserved for future use)

Returns:
List of binary masks (numpy arrays with dtype=bool), one per image. Each mask has shape (sequence_length,)
where sequence_length is the number of tokens in the processed image representation.
Values are True for image tokens, False for non-image tokens (text, special tokens, etc.).

Raises:
NotImplementedError: If the underlying model doesn't support image mask generation.

Example:
```python
model = LateInteractionMultimodalEmbedding("Qdrant/colpali-v1.3-fp16")
masks = model.get_image_mask(["image1.jpg", "image2.jpg"])
# masks[0] is a numpy array of shape (1030,) with dtype=bool for ColPali
# First 1024 values are True (image tokens), last 6 are False (text tokens)
```
"""
return self.model.get_image_mask(images, **kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,39 @@ def token_count(
) -> int:
"""Returns the number of tokens in the texts."""
raise NotImplementedError("Subclasses must implement this method")

def get_image_mask(
self,
images: ImageInput | Iterable[ImageInput],
**kwargs: Any,
) -> list[NumpyArray]:
"""
Generate binary masks identifying image tokens in processed image sequences.

This method processes images and returns masks indicating which tokens in the
resulting sequence correspond to image content (value=1) vs text/special tokens (value=0).

Args:
images: Single image or iterable of images (file paths, bytes, or PIL Image objects)
**kwargs: Additional keyword arguments (reserved for future use)

Returns:
List of binary masks (numpy arrays with dtype=bool), one per image. Each mask has shape (sequence_length,)
where sequence_length is the number of tokens in the processed image representation.
Values are True for image tokens, False for non-image tokens (text, special tokens, etc.).

Raises:
NotImplementedError: If the model doesn't support image mask generation.

Example:
```python
model = ColPali(model_name="Qdrant/colpali-v1.3-fp16")
masks = model.get_image_mask(["image1.jpg", "image2.jpg"])
# masks[0] is a numpy array of shape (1030,) with dtype=bool for ColPali
# First 1024 values are True (image tokens), last 6 are False (text tokens)
```
"""
raise NotImplementedError(
f"{self.__class__.__name__} does not support image mask generation. "
"Override this method in subclasses to provide model-specific implementation."
)
61 changes: 61 additions & 0 deletions tests/test_late_interaction_multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,64 @@ def test_token_count() -> None:
assert short_doc_token_count + long_doc_token_count < model.token_count(
documents, include_extension=True
)


def test_colpali_image_mask():
"""Test that get_image_mask returns correct masks for image tokens."""
if os.getenv("CI"):
pytest.skip("Colpali is too large to test in CI")

model = LateInteractionMultimodalEmbedding(model_name="Qdrant/colpali-v1.3-fp16")

# Get mask for single image
masks = model.get_image_mask([images[0]])

assert len(masks) == 1, "Should return one mask per image"
mask = masks[0]

# ColPali uses 1030 tokens total: 1024 image + 6 text
assert mask.shape == (1030,), f"Expected shape (1030,), got {mask.shape}"
assert mask.dtype == np.bool_, f"Expected bool dtype, got {mask.dtype}"

# First 1024 tokens should be image tokens (value=True)
assert np.all(mask[:1024]), "First 1024 tokens should be image tokens (True)"

# Last 6 tokens should be text tokens (value=False)
assert np.all(~mask[1024:]), "Last 6 tokens should be text tokens (False)"

# Test with multiple images
masks = model.get_image_mask([images[0], images[1]])
assert len(masks) == 2, "Should return two masks for two images"
assert all(m.shape == (1030,) for m in masks), "All masks should have same shape"


def test_colpali_image_mask_single_image():
"""Test get_image_mask with a single image (not in a list)."""
if os.getenv("CI"):
pytest.skip("Colpali is too large to test in CI")

model = LateInteractionMultimodalEmbedding(model_name="Qdrant/colpali-v1.3-fp16")

# Pass single image without list
masks = model.get_image_mask(images[0])

assert len(masks) == 1, "Should return one mask for single image"
assert masks[0].shape == (1030,), "Mask should have correct shape"


def test_base_class_raises_not_implemented():
"""Test that base class raises NotImplementedError."""
from fastembed.late_interaction_multimodal.late_interaction_multimodal_embedding_base import (
LateInteractionMultimodalEmbeddingBase,
)

# Create a minimal subclass that doesn't implement get_image_mask
class MinimalModel(LateInteractionMultimodalEmbeddingBase):
pass

model = MinimalModel(model_name="test", cache_dir="/tmp")

with pytest.raises(NotImplementedError) as exc_info:
model.get_image_mask(["dummy.jpg"])

assert "does not support image mask generation" in str(exc_info.value)