
Commit d73fbb1
Parent: e916aa4

Add support for configurable chrF metric parameters in task YAML, fix #2256

2 files changed: 25 additions, 7 deletions

lm_eval/api/metrics.py

Lines changed: 17 additions & 6 deletions
@@ -5,6 +5,7 @@
 import re
 import string
 from collections.abc import Iterable
+from functools import partial
 from typing import Callable, List, Optional, Sequence, TypeVar

 import numpy as np

@@ -99,7 +100,7 @@ def bleu(items):


 @register_aggregation("chrf")
-def chrf(items):
+def chrf(items, char_order=6, word_order=0, **kwargs):
     """chrF++ is a tool for automatic evaluation of machine translation output
     based on character n-gram precision and recall enhanced with word n-grams.
     Source: https://github.com/m-popovic/chrF

@@ -110,7 +111,9 @@ def chrf(items):
     refs = list(zip(*items))[0]
     preds = list(zip(*items))[1]
     refs, preds = _sacreformat(refs, preds)
-    return sacrebleu.corpus_chrf(preds, refs).score
+    return sacrebleu.corpus_chrf(
+        preds, refs, char_order=char_order, word_order=word_order, **kwargs
+    ).score


 @register_aggregation("ter")
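
To show what the new parameters control, here is a minimal standalone sketch (not part of the commit; it assumes sacrebleu >= 2.0 is installed) mirroring the corpus_chrf call the patched aggregation now makes. The defaults keep the old behaviour (plain chrF), while word_order=2 gives the chrF++ variant.

# Standalone sketch of the underlying sacrebleu call (assumes sacrebleu >= 2.0);
# inside the harness, _sacreformat() produces these preds/refs shapes.
import sacrebleu

preds = ["the cat sat on the mat"]
refs = [["the cat is sitting on the mat"]]  # one reference stream, one ref per pred

# Defaults match the previously hard-coded behaviour: character 6-grams only (chrF).
chrf_score = sacrebleu.corpus_chrf(preds, refs, char_order=6, word_order=0).score

# word_order=2 adds word uni- and bigrams, i.e. the chrF++ variant.
chrfpp_score = sacrebleu.corpus_chrf(preds, refs, char_order=6, word_order=2).score

print(chrf_score, chrfpp_score)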
@@ -482,7 +485,7 @@ def _bootstrap_internal_no_mp(
     chunk_size = min(1000, iters)
     from tqdm import tqdm

-    print(f"bootstrapping for stddev: {f.__name__}")
+    print(f"bootstrapping for stddev: {getattr(f, '__name__', repr(f))}")

     # A single loop replaces the multiprocessing pool.
     for i in tqdm(range(iters // chunk_size)):

@@ -515,7 +518,7 @@ def bootstrap_stderr(
     chunk_size = min(1000, iters)
     from tqdm import tqdm

-    print("bootstrapping for stddev:", f.__name__)
+    print("bootstrapping for stddev:", getattr(f, "__name__", repr(f)))
    with mp.Pool(mp.cpu_count()) as pool:
        for bootstrap in tqdm(
            pool.imap(
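
Both print() changes anticipate the functools.partial wrapping added further down: partial objects carry no __name__ attribute, so the old format strings would raise AttributeError once metric kwargs are bound. A toy sketch (the chrf stub here is a stand-in, not the harness's registered aggregation):

# Why getattr(f, "__name__", repr(f)) is needed: functools.partial has no __name__.
from functools import partial

def chrf(items, char_order=6, word_order=0):
    ...  # toy stand-in for the registered aggregation

bound = partial(chrf, word_order=2)

print(getattr(chrf, "__name__", repr(chrf)))    # -> chrf
print(getattr(bound, "__name__", repr(bound)))  # -> functools.partial(<function chrf at 0x...>, word_order=2)
# bound.__name__ would raise AttributeError, which is what the old f-strings hit.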
@@ -533,7 +536,9 @@ def bootstrap_stderr(


 def stderr_for_metric(
-    metric: Callable[[Sequence[T]], float], bootstrap_iters: int
+    metric: Callable[[Sequence[T]], float],
+    bootstrap_iters: int,
+    metric_kwargs: dict = None,
 ) -> Optional[Callable[[Sequence[T]], float]]:
     """
     Return a function that estimates the standard error of `metric(xs)`.

@@ -548,6 +553,9 @@
         # return no function (don't compute stderr) if bootstrap iters = 0
         return None

+    if metric_kwargs is None:
+        metric_kwargs = {}
+
     bootstrappable = [
         median,
         matthews_corrcoef,

@@ -560,7 +568,10 @@
     ]

     if metric in bootstrappable:
-        return lambda x: bootstrap_stderr(metric, x, iters=bootstrap_iters)
+        metric_with_kwargs = (
+            partial(metric, **metric_kwargs) if metric_kwargs else metric
+        )
+        return lambda x: bootstrap_stderr(metric_with_kwargs, x, iters=bootstrap_iters)

     stderr = {mean: mean_stderr, acc_all: acc_all_stderr}
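
With metric_kwargs threaded through, the bootstrap resamples are scored with the same chrF settings as the point estimate. A hedged usage sketch (assumes the patched lm_eval from this commit is importable; items are shaped as the (reference, prediction) pairs the harness collects):

# Usage sketch for the new metric_kwargs argument of stderr_for_metric.
from lm_eval.api.metrics import chrf, stderr_for_metric

if __name__ == "__main__":  # bootstrap_stderr spins up a multiprocessing pool
    items = [
        ("the cat is sitting on the mat", "the cat sat on the mat"),
        ("a dog runs in the park", "the dog is running in a park"),
    ]  # (reference, prediction) pairs

    kwargs = {"word_order": 2}  # score with chrF++ instead of plain chrF

    score = chrf(items, **kwargs)
    stderr_fn = stderr_for_metric(metric=chrf, bootstrap_iters=100, metric_kwargs=kwargs)
    stderr = stderr_fn(items) if stderr_fn is not None else "N/A"
    print(score, stderr)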

lm_eval/evaluator_utils.py

Lines changed: 8 additions & 1 deletion
@@ -111,14 +111,21 @@ def calculate_aggregate_metric(self, bootstrap_iters=100000) -> None:
             # TODO: Handle this better and allow other aggregate functions other than mean.
             agg_fn = mean
             metric_key = f"{metric},{filter_key}"
-            self.agg_metrics[metric_key] = agg_fn(items)
+            metric_kwargs = {}
+            if (
+                hasattr(self.task, "_metric_fn_kwargs")
+                and metric in self.task._metric_fn_kwargs
+            ):
+                metric_kwargs = self.task._metric_fn_kwargs[metric]
+            self.agg_metrics[metric_key] = agg_fn(items, **metric_kwargs)
             self.sample_len = len(items)  # TODO: same sample size for each metric?
             if isinstance(bootstrap_iters, int):
                 stderr_fn = stderr_for_metric(
                     metric=agg_fn,
                     bootstrap_iters=min(bootstrap_iters, 100)
                     if metric in ["bleu", "chrf", "ter"]
                     else bootstrap_iters,
+                    metric_kwargs=metric_kwargs,
                 )
                 self.agg_metrics[f"{metric}_stderr,{filter_key}"] = (
                     stderr_fn(items) if (stderr_fn and len(items) > 1) else "N/A"
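
The YAML schema that fills _metric_fn_kwargs is not shown in this diff; the sketch below only mimics the lookup that calculate_aggregate_metric now performs, using a hypothetical task object whose attribute shape matches the code above.

# Hypothetical stand-in for a task whose YAML sets chrF options; only the
# _metric_fn_kwargs attribute and its {metric_name: kwargs} shape come from
# the diff above, everything else is illustration.
class DummyTask:
    _metric_fn_kwargs = {"chrf": {"char_order": 6, "word_order": 2}}

task = DummyTask()
metric = "chrf"

# The same lookup calculate_aggregate_metric now performs:
metric_kwargs = {}
if hasattr(task, "_metric_fn_kwargs") and metric in task._metric_fn_kwargs:
    metric_kwargs = task._metric_fn_kwargs[metric]

print(metric_kwargs)  # -> {'char_order': 6, 'word_order': 2}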
