8 changes: 4 additions & 4 deletions mteb/_create_dataloaders.py
@@ -115,7 +115,7 @@ def _create_text_dataloader_for_queries(


def _convert_conv_history_to_query(
- row: dict[str, list[str] | Conversation],
+ row: dict[str, str | list[str] | Conversation],
) -> dict[str, str | Conversation]:
"""Convert a conversation history to a single query string.

@@ -130,10 +130,10 @@ def _convert_conv_history_to_query(
conversation = row["text"]
# if it's a list of strings, just join them
if isinstance(conversation, list) and isinstance(conversation[0], str):
- conversation = cast(list[str], conversation)
- conv_str = "; ".join(conversation)
+ conversation_ = cast(list[str], conversation)
+ conv_str = "; ".join(conversation_)
current_conversation = [
- ConversationTurn(role="user", content=message) for message in conversation
+ ConversationTurn(role="user", content=message) for message in conversation_
]
if not _warned_about_user_role:
logger.warning(
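For context on the rename above, here is a minimal, hypothetical sketch of the cast-plus-rebind pattern the diff adopts; join_texts and its union type are illustrative, not mteb code.

from typing import cast

def join_texts(value: str | list[str] | list[dict[str, str]]) -> str:
    """Join a list of strings into one string; pass plain strings through."""
    if isinstance(value, list) and isinstance(value[0], str):
        # Checking the first element does not narrow the list's element type
        # for the type checker, so an explicit cast is used. Binding the cast
        # result to a fresh name gives it an unambiguous list[str] type while
        # the original union-typed variable is left untouched.
        value_ = cast(list[str], value)
        return "; ".join(value_)
    if isinstance(value, str):
        return value
    raise TypeError("unsupported input")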
@@ -1,4 +1,5 @@
import logging
+ from collections.abc import Sequence
from typing import Any

import torch
@@ -56,8 +57,8 @@ class ImageTextPairClassificationEvaluator(Evaluator):
def __init__(
self,
dataset,
- images_column_names: str | list[str],
- texts_column_names: str | list[str],
+ images_column_names: str | Sequence[str],
+ texts_column_names: str | Sequence[str],
num_images_per_sample: int,
num_texts_per_sample: int,
task_metadata: TaskMetadata,
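A small illustration (hypothetical function, not from mteb) of why the signature above moves from list[str] to Sequence[str]: the parameter then accepts tuples and other read-only sequences, not just lists.

from collections.abc import Sequence

def first_column(columns: Sequence[str]) -> str:
    # Sequence[str] accepts lists, tuples, and other read-only sequences;
    # a list[str] annotation would reject a tuple even though nothing here
    # mutates the argument.
    return columns[0]

first_column(["image_a", "image_b"])
first_column(("caption_a", "caption_b"))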
6 changes: 3 additions & 3 deletions mteb/abstasks/_stratification.py
@@ -211,8 +211,8 @@ def _prepare_stratification(
rows = sp.lil_matrix(y).rows
rows_used = dict.fromkeys(range(self.n_samples), False)
all_combinations = []
- per_row_combinations = [[] for i in range(self.n_samples)]
- samples_with_combination = {}
+ per_row_combinations: list[list[int]] = [[] for i in range(self.n_samples)]
+ samples_with_combination: dict[str, int] = {}
folds = [[] for _ in range(self.n_splits)] # type: ignore

# for every row
@@ -229,7 +229,7 @@
all_combinations.append(combination)
per_row_combinations[sample_index].append(combination)

- all_combinations = [list(x) for x in set(all_combinations)]
+ all_combinations: list[list[int]] = [list(x) for x in set(all_combinations)]

self.desired_samples_per_combination_per_fold = {
combination: np.array(
18 changes: 12 additions & 6 deletions mteb/abstasks/abstask.py
@@ -1,10 +1,10 @@
import json
import logging
from abc import ABC, abstractmethod
- from collections.abc import Sequence
+ from collections.abc import Mapping, Sequence
from copy import copy
from pathlib import Path
- from typing import Any, cast
+ from typing import Any, TypedDict, cast

import numpy as np
from datasets import ClassLabel, Dataset, DatasetDict, load_dataset
@@ -60,6 +60,12 @@ def _multilabel_subsampling(
return dataset_dict


+ class AbsMetrics(TypedDict):
+ """The abstract class for the metrics returned by the tasks"""
+
+ ...

Comment on lines +63 to +67

Contributor:
Unsure why this is added?

Member Author:
I tried to standardize this, because dict is not compatible with Mapping, but I think I will remove it and change _evaluate_subset to Mapping, which can handle both.
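To illustrate the incompatibility mentioned in the reply, a rough sketch (TaskMetrics, as_dict, and as_mapping are made-up names; only AbsMetrics comes from this PR): dict is invariant in its value type, while Mapping is covariant, so only the Mapping-returning signature accepts the narrower scores.

from collections.abc import Mapping
from typing import TypedDict

class AbsMetrics(TypedDict):
    ...

class TaskMetrics(AbsMetrics):
    accuracy: float

def as_dict(scores: dict[str, TaskMetrics]) -> dict[str, AbsMetrics]:
    return scores  # rejected by mypy: dict is invariant in its value type

def as_mapping(scores: dict[str, TaskMetrics]) -> Mapping[str, AbsMetrics]:
    return scores  # accepted: Mapping is covariant in its value type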


class AbsTask(ABC):
"""The abstract class for the tasks

@@ -123,7 +129,7 @@ def evaluate(
encode_kwargs: dict[str, Any],
prediction_folder: Path | None = None,
**kwargs: Any,
- ) -> dict[HFSubset, ScoresDict]:
+ ) -> Mapping[HFSubset, ScoresDict]:
"""Evaluates an MTEB compatible model on the task.

Args:
@@ -198,12 +204,12 @@ def _evaluate_subset(
model: EncoderProtocol,
data_split: Dataset,
*,
- encode_kwargs: dict[str, Any],
hf_split: str,
hf_subset: str,
+ encode_kwargs: dict[str, Any],
prediction_folder: Path | None = None,
**kwargs: Any,
- ) -> ScoresDict:
+ ) -> AbsMetrics:
raise NotImplementedError(
"If you are using the default evaluate method, you must implement _evaluate_subset method."
)
@@ -499,7 +505,7 @@ def filter_languages(
self.hf_subsets = subsets_to_keep
return self

- def _add_main_score(self, scores: dict[HFSubset, ScoresDict]) -> None:
+ def _add_main_score(self, scores: dict[HFSubset, ScoresDict | AbsMetrics]) -> None:
scores["main_score"] = scores[self.metadata.main_score]

def _upload_dataset_to_hub(
6 changes: 3 additions & 3 deletions mteb/abstasks/image/image_text_pair_classification.py
@@ -1,7 +1,7 @@
import logging
from collections.abc import Sequence
from pathlib import Path
- from typing import Any, TypedDict
+ from typing import Any

import torch
from datasets import Dataset, concatenate_datasets
@@ -11,7 +11,7 @@
calculate_image_statistics,
calculate_text_statistics,
)
- from mteb.abstasks.abstask import AbsTask
+ from mteb.abstasks.abstask import AbsMetrics, AbsTask
from mteb.models.models_protocols import EncoderProtocol
from mteb.types.statistics import (
ImageStatistics,
@@ -36,7 +36,7 @@ class ImageTextPairClassificationDescriptiveStatistics(SplitDescriptiveStatistic
image_statistics: ImageStatistics


- class ImageTextPairClassificationMetrics(TypedDict):
+ class ImageTextPairClassificationMetrics(AbsMetrics):
"""ImageTextPairClassification metrics.

Attributes:
10 changes: 6 additions & 4 deletions mteb/abstasks/task_metadata.py
@@ -150,7 +150,7 @@
"InstructionReranking",
) + MIEB_TASK_TYPE

- TaskType = Literal[_TASK_TYPE]
+ TaskType = Literal[_TASK_TYPE] # type: ignore[valid-type]
"""The type of the task. E.g. includes "Classification", "Retrieval" and "Clustering"."""


@@ -193,7 +193,9 @@


PromptDict = TypedDict(
"PromptDict", {prompt_type.value: str for prompt_type in PromptType}, total=False
"PromptDict",
{prompt_type.value: str for prompt_type in PromptType},
total=False, # type: ignore[misc]
)
"""A dictionary containing the prompt used for the task.

@@ -447,7 +449,7 @@ def get_modalities(self, prompt_type: PromptType | None = None) -> list[Modaliti
Raises:
ValueError: If the prompt type is not recognized.
"""
- if prompt_type is None:
+ if prompt_type is None or self.category is None:
return self.modalities
query_modalities, doc_modalities = self.category.split("2")
category_to_modality: dict[str, Modalities] = {
@@ -711,7 +713,7 @@ def _hf_languages(self) -> list[str]:
readme_langs.append(lang_name)
return sorted(set(readme_langs))

- def _hf_license(self) -> str:
+ def _hf_license(self) -> str | None:
dataset_license = self.license
if dataset_license:
license_mapping = {
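The PromptDict change above uses the functional TypedDict form with a dynamically built field dict, which type checkers cannot verify statically, hence the added # type: ignore[misc]. A hedged sketch of the difference (StaticPrompts, DynamicPrompts, and field_types are illustrative names):

from typing import TypedDict

# Static functional syntax: field names and types are spelled out, so type
# checkers can verify this form.
StaticPrompts = TypedDict("StaticPrompts", {"query": str, "document": str}, total=False)

# Dynamically built fields work at runtime but cannot be verified statically,
# which is the situation the diff silences with `# type: ignore[misc]`.
field_types = {name: str for name in ("query", "document")}
DynamicPrompts = TypedDict("DynamicPrompts", field_types, total=False)  # type: ignore[misc]

p: StaticPrompts = {"query": "a search prompt"}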
62 changes: 44 additions & 18 deletions mteb/abstasks/text/bitext_mining.py
@@ -1,15 +1,16 @@
import logging
from collections import defaultdict
from pathlib import Path
- from typing import Any, ClassVar, TypedDict
+ from typing import Any, ClassVar, cast

from datasets import Dataset, DatasetDict
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

from mteb._evaluators import BitextMiningEvaluator
from mteb.abstasks._statistics_calculation import calculate_text_statistics
- from mteb.abstasks.abstask import AbsTask
+ from mteb.abstasks.abstask import AbsMetrics, AbsTask
from mteb.models import EncoderProtocol, MTEBModels
+ from mteb.models.models_protocols import CrossEncoderProtocol, SearchProtocol
from mteb.types import HFSubset, ScoresDict
from mteb.types.statistics import SplitDescriptiveStatistics, TextStatistics

@@ -36,7 +37,7 @@ class BitextDescriptiveStatistics(SplitDescriptiveStatistics):
sentence2_statistics: TextStatistics


- class BitextMiningMetrics(TypedDict):
+ class BitextMiningMetrics(AbsMetrics):
"""Metrics for BitextMining tasks

Attributes:
@@ -78,6 +79,23 @@ def evaluate(
**kwargs: Any,
) -> dict[HFSubset, ScoresDict]:
"""Added load for "parallel" datasets"""
+ if isinstance(model, CrossEncoderProtocol) and not self._support_cross_encoder:
+ raise TypeError(
+ f"Model {model} is a CrossEncoder, but this task {self.metadata.name} does not support CrossEncoders. "
+ "Please use a Encoder model instead."
+ )
+
+ # encoders might implement search protocols
+ if (
+ isinstance(model, SearchProtocol)
+ and not isinstance(model, EncoderProtocol)
+ and not self._support_search
+ ):
+ raise TypeError(
+ f"Model {model} is a SearchProtocol, but this task {self.metadata.name} does not support Search. "
+ "Please use a Encoder model instead."
+ )
+
if not self.data_loaded:
self.load_data()
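The guards added above rely on isinstance checks against protocol classes; such checks only work for protocols declared @runtime_checkable (as the mteb protocols presumably are) and only test member names, so one model can satisfy several protocols. A minimal sketch with made-up protocol and model names, not the mteb definitions:

from typing import Protocol, runtime_checkable

@runtime_checkable
class EncoderLike(Protocol):
    def encode(self, texts: list[str]) -> list[list[float]]: ...

@runtime_checkable
class SearcherLike(Protocol):
    def search(self, query: str) -> list[str]: ...

class ToyModel:
    def encode(self, texts: list[str]) -> list[list[float]]:
        return [[0.0] for _ in texts]

model = ToyModel()
# Runtime protocol checks only verify that the method names exist, which is
# why the task rejects unsupported protocol/task combinations explicitly.
print(isinstance(model, EncoderLike))   # True
print(isinstance(model, SearcherLike))  # False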

@@ -87,11 +105,16 @@ def evaluate(
if subsets_to_run is not None:
hf_subsets = [s for s in hf_subsets if s in subsets_to_run]

- scores = {}
+ encoder_model = cast(EncoderProtocol, model)
+
+ if self.dataset is None:
+ raise ValueError("Dataset is not loaded.")
+
+ scores: dict[str, BitextMiningMetrics] = {}
if self.parallel_subsets:
- scores = self._evaluate_subset(
- model,
- self.dataset[split], # type: ignore
+ scores = self._evaluate_subset( # type: ignore[assignment]
+ encoder_model,
+ self.dataset[split],
parallel=True,
hf_split=split,
hf_subset="parallel",
@@ -109,8 +132,8 @@ def evaluate(
data_split = self.dataset[split]
else:
data_split = self.dataset[hf_subset][split]
- scores[hf_subset] = self._evaluate_subset(
- model,
+ scores[hf_subset] = self._evaluate_subset( # type: ignore[assignment]
+ encoder_model,
data_split,
hf_split=split,
hf_subset=hf_subset,
@@ -124,21 +147,21 @@ def _get_pairs(self, parallel: bool) -> list[tuple[str, str]]:
def _get_pairs(self, parallel: bool) -> list[tuple[str, str]]:
pairs = self._DEFAULT_PAIR
if parallel:
- pairs = [langpair.split("-") for langpair in self.hf_subsets]
+ pairs = [langpair.split("-") for langpair in self.hf_subsets] # type: ignore[misc]
return pairs

- def _evaluate_subset(
+ def _evaluate_subset( # type: ignore[override]
self,
model: EncoderProtocol,
data_split: Dataset,
*,
hf_split: str,
hf_subset: str,
- parallel: bool = False,
encode_kwargs: dict[str, Any],
prediction_folder: Path | None = None,
+ parallel: bool = False,
**kwargs,
- ) -> ScoresDict:
+ ) -> BitextMiningMetrics | dict[str, BitextMiningMetrics]:
pairs = self._get_pairs(parallel)

evaluator = BitextMiningEvaluator(
@@ -250,8 +273,11 @@ def _calculate_descriptive_statistics_from_split(
)

def _push_dataset_to_hub(self, repo_name: str) -> None:
+ if self.dataset is None:
+ raise ValueError("Dataset is not loaded.")

if self.metadata.is_multilingual:
- dataset = defaultdict(dict)
+ dataset: dict[str, dict[str, list[str]]] = defaultdict(dict)
for config in self.metadata.eval_langs:
logger.info(f"Converting {config} of {self.metadata.name}")

@@ -266,10 +292,10 @@ def _push_dataset_to_hub(self, repo_name: str) -> None:
for split in self.dataset[config]:
dataset[split][lang_1] = self.dataset[config][split][sent_1]
dataset[split][lang_2] = self.dataset[config][split][sent_2]
- for split in dataset:
- dataset[split] = Dataset.from_dict(dataset[split])
- dataset = DatasetDict(dataset)
- dataset.push_to_hub(repo_name)
+ dataset_dict = DatasetDict(
+ {split: Dataset.from_dict(dataset[split]) for split in dataset}
+ )
+ dataset_dict.push_to_hub(repo_name)
else:
sentences = {}
for split in self.dataset:
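A rough, self-contained sketch of the push pattern the diff moves to, using toy data; the repo id is hypothetical and pushing requires Hugging Face credentials, so that line is commented out.

from datasets import Dataset, DatasetDict

# Building the DatasetDict from a fresh variable in one expression avoids
# re-assigning `dataset` to values of a different type mid-loop, which is
# what the replaced lines above did.
splits = {"train": {"eng": ["hello"], "deu": ["hallo"]}}
dataset_dict = DatasetDict(
    {split: Dataset.from_dict(columns) for split, columns in splits.items()}
)
# dataset_dict.push_to_hub("user/repo")  # hypothetical repo id; needs HF auth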
8 changes: 5 additions & 3 deletions mteb/abstasks/text/reranking.py
@@ -16,7 +16,7 @@

logger = logging.getLogger(__name__)

- OLD_FORMAT_RERANKING_TASKS = []
+ OLD_FORMAT_RERANKING_TASKS: list[str] = []


@deprecated(
@@ -105,7 +105,9 @@ def transform_old_dataset_format(self, given_dataset: Dataset | None = None):
)

given_dataset = copy(given_dataset)
- self.dataset = defaultdict(lambda: defaultdict(dict))
+ self.dataset: dict[str, dict[str, RetrievalSplitData]] = defaultdict(
+ lambda: defaultdict(dict) # type: ignore[arg-type]
+ )

hf_subsets = self.hf_subsets

@@ -127,7 +129,7 @@ def transform_old_dataset_format(self, given_dataset: Dataset | None = None):
for split in cur_dataset:
corpus = []
queries = []
- relevant_docs = defaultdict(dict)
+ relevant_docs: dict[str, dict[str, int]] = defaultdict(dict)
top_ranked = defaultdict(list)

# Create an enumerated dataset to pass indices
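A small sketch of the typed nested-defaultdict pattern used in the annotations above; SplitData is a made-up stand-in for RetrievalSplitData.

from collections import defaultdict
from typing import TypedDict

class SplitData(TypedDict, total=False):
    corpus: list[str]
    queries: list[str]

# The annotation records the intended shape; the inner factory produces plain
# dicts that a checker cannot prove are SplitData, which is why the diff keeps
# a `# type: ignore[arg-type]` on the equivalent line.
dataset: dict[str, dict[str, SplitData]] = defaultdict(lambda: defaultdict(dict))  # type: ignore[arg-type]

dataset["default"]["test"] = {"corpus": ["doc text"], "queries": ["a query"]}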