55import re
66import string
77from collections .abc import Iterable
8+ from functools import partial
89from typing import Callable , List , Optional , Sequence , TypeVar
910
1011import numpy as np
@@ -99,7 +100,7 @@ def bleu(items):
99100
100101
101102@register_aggregation ("chrf" )
102- def chrf (items ):
103+ def chrf (items , char_order = 6 , word_order = 0 , ** kwargs ):
103104 """chrF++ is a tool for automatic evaluation of machine translation output
104105 based on character n-gram precision and recall enhanced with word n-grams.
105106 Source: https://github.com/m-popovic/chrF
@@ -110,7 +111,9 @@ def chrf(items):
110111 refs = list (zip (* items ))[0 ]
111112 preds = list (zip (* items ))[1 ]
112113 refs , preds = _sacreformat (refs , preds )
113- return sacrebleu .corpus_chrf (preds , refs ).score
114+ return sacrebleu .corpus_chrf (
115+ preds , refs , char_order = char_order , word_order = word_order , ** kwargs
116+ ).score
114117
115118
116119@register_aggregation ("ter" )
@@ -482,7 +485,7 @@ def _bootstrap_internal_no_mp(
482485 chunk_size = min (1000 , iters )
483486 from tqdm import tqdm
484487
485- print (f"bootstrapping for stddev: { f . __name__ } " )
488+ print (f"bootstrapping for stddev: { getattr ( f , ' __name__' , repr ( f )) } " )
486489
487490 # A single loop replaces the multiprocessing pool.
488491 for i in tqdm (range (iters // chunk_size )):
@@ -515,7 +518,7 @@ def bootstrap_stderr(
515518 chunk_size = min (1000 , iters )
516519 from tqdm import tqdm
517520
518- print ("bootstrapping for stddev:" , f . __name__ )
521+ print ("bootstrapping for stddev:" , getattr ( f , " __name__" , repr ( f )) )
519522 with mp .Pool (mp .cpu_count ()) as pool :
520523 for bootstrap in tqdm (
521524 pool .imap (
@@ -533,7 +536,9 @@ def bootstrap_stderr(
533536
534537
535538def stderr_for_metric (
536- metric : Callable [[Sequence [T ]], float ], bootstrap_iters : int
539+ metric : Callable [[Sequence [T ]], float ],
540+ bootstrap_iters : int ,
541+ metric_kwargs : dict = None ,
537542) -> Optional [Callable [[Sequence [T ]], float ]]:
538543 """
539544 Return a function that estimates the standard error of `metric(xs)`.
@@ -548,6 +553,9 @@ def stderr_for_metric(
548553 # return no function (don't compute stderr) if bootstrap iters = 0
549554 return None
550555
556+ if metric_kwargs is None :
557+ metric_kwargs = {}
558+
551559 bootstrappable = [
552560 median ,
553561 matthews_corrcoef ,
@@ -560,7 +568,10 @@ def stderr_for_metric(
560568 ]
561569
562570 if metric in bootstrappable :
563- return lambda x : bootstrap_stderr (metric , x , iters = bootstrap_iters )
571+ metric_with_kwargs = (
572+ partial (metric , ** metric_kwargs ) if metric_kwargs else metric
573+ )
574+ return lambda x : bootstrap_stderr (metric_with_kwargs , x , iters = bootstrap_iters )
564575
565576 stderr = {mean : mean_stderr , acc_all : acc_all_stderr }
566577
0 commit comments