4141 ProgressCallback ,
4242 PrerequisiteNotMetError ,
4343 check_min_sample_size ,
44- MAX_SAMPLE_SIZE_ACCURACY ,
45- MAX_SAMPLE_SIZE_EMBEDDINGS ,
4644 add_tqdm ,
4745 NXT_COLUMN ,
4846 CTX_COLUMN_PREFIX ,
@@ -69,8 +67,8 @@ def report(
6967 report_subtitle : str = "" ,
7068 report_credits : str = REPORT_CREDITS ,
7169 report_extra_info : str = "" ,
72- max_sample_size_accuracy : int = MAX_SAMPLE_SIZE_ACCURACY ,
73- max_sample_size_embeddings : int = MAX_SAMPLE_SIZE_EMBEDDINGS ,
70+ max_sample_size_accuracy : int | None = None ,
71+ max_sample_size_embeddings : int | None = None ,
7472 statistics_path : str | Path | None = None ,
7573 on_progress : ProgressCallback | None = None ,
7674) -> tuple [Path , dict | None ]:
@@ -225,6 +223,11 @@ def report(
225223 )
226224 on_progress (current = 30 , total = 100 )
227225
226+ # ensure that embeddings are all of equal size for a fair 3-way comparison
227+ max_sample_size_embeddings = min (syn_sample_size , trn_sample_size )
228+ if hol_sample_size != 0 :
229+ max_sample_size_embeddings = min (max_sample_size_embeddings , hol_sample_size )
230+
228231 # calculate embeddings
229232 syn_embeds = calculate_embeddings (
230233 pull_data_for_embeddings (
@@ -237,21 +240,16 @@ def report(
237240 )
238241 _LOG .info (f"calculated embeddings for synthetic { syn_embeds .shape } " )
239242 on_progress (current = 50 , total = 100 )
240- # ensure that `trn` and `hol` are of equal size
241- max_sample_size = min (max_sample_size_embeddings , trn_sample_size )
242- if hol_tgt_data is not None :
243- max_sample_size = min (max_sample_size_embeddings , hol_sample_size )
244243 trn_embeds = calculate_embeddings (
245244 pull_data_for_embeddings (
246245 df_tgt = trn_tgt_data ,
247246 df_ctx = trn_ctx_data ,
248247 ctx_primary_key = ctx_primary_key ,
249248 tgt_context_key = tgt_context_key ,
250- max_sample_size = max_sample_size ,
249+ max_sample_size = max_sample_size_embeddings ,
251250 )
252251 )
253252 _LOG .info (f"calculated embeddings for training { trn_embeds .shape } " )
254-
255253 on_progress (current = 60 , total = 100 )
256254 if hol_tgt_data is not None :
257255 hol_embeds = calculate_embeddings (
@@ -260,7 +258,7 @@ def report(
260258 df_ctx = hol_ctx_data ,
261259 ctx_primary_key = ctx_primary_key ,
262260 tgt_context_key = tgt_context_key ,
263- max_sample_size = max_sample_size ,
261+ max_sample_size = max_sample_size_embeddings ,
264262 )
265263 )
266264 _LOG .info (f"calculated embeddings for holdout { hol_embeds .shape } " )
0 commit comments