Skip to content

Commit 2e75851

Browse files
MSD-1170: filter out empty buckets while calculating embeddings (#23)
1 parent da671a1 commit 2e75851

File tree

2 files changed

+17
-5
lines changed

2 files changed

+17
-5
lines changed

src/mostlyai/qa/report.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def report(
160160
check_min_sample_size(syn_sample_size, 100, "synthetic")
161161
check_min_sample_size(trn_sample_size, 90, "training")
162162
if hol_tgt_data is not None:
163-
check_min_sample_size(trn_sample_size, 10, "holdout")
163+
check_min_sample_size(hol_sample_size, 10, "holdout")
164164
except PrerequisiteNotMetError as err:
165165
_LOG.info(err)
166166
statistics.mark_early_exit()
@@ -242,10 +242,12 @@ def _calc_pull_embeds(df_tgt: pd.DataFrame, df_ctx: pd.DataFrame, start: int, st
242242
)
243243
# split into buckets for calculating embeddings to avoid memory issues and report continuous progress
244244
buckets = np.array_split(strings, stop - start)
245+
buckets = [b for b in buckets if len(b) > 0]
245246
embeds = []
246247
for i, bucket in enumerate(buckets, 1):
247248
embeds += [calculate_embeddings(bucket.tolist())]
248249
on_progress(current=start + i, total=100)
250+
on_progress(current=stop, total=100)
249251
embeds = np.concatenate(embeds, axis=0)
250252
_LOG.info(f"calculated embeddings {embeds.shape}")
251253
return embeds

tests/end_to_end/test_report.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -193,22 +193,32 @@ def make_dfs(
193193

194194
for test_idx, df_dict in enumerate(test_dfs):
195195
ctx_df, tgt_df = df_dict.pop("dfs")
196-
syn_ctx_data = trn_ctx_data = val_ctx_data = ctx_df
197-
syn_tgt_data = trn_tgt_data = val_tgt_data = tgt_df
196+
syn_ctx_data = trn_ctx_data = hol_ctx_data = ctx_df
197+
syn_tgt_data = trn_tgt_data = hol_tgt_data = tgt_df
198198
early_term = df_dict.pop("early_term")
199199
_, metrics = report(
200200
syn_tgt_data=syn_tgt_data,
201201
trn_tgt_data=trn_tgt_data,
202-
hol_tgt_data=val_tgt_data,
202+
hol_tgt_data=hol_tgt_data,
203203
syn_ctx_data=syn_ctx_data,
204204
trn_ctx_data=trn_ctx_data,
205-
hol_ctx_data=val_ctx_data,
205+
hol_ctx_data=hol_ctx_data,
206206
tgt_context_key="ck",
207207
ctx_primary_key="pk",
208208
)
209209
assert metrics is None if early_term else metrics is not None, f"Test {test_idx} failed"
210210

211211

212+
def test_report_few_holdout_records(tmp_path):
213+
tgt = pd.DataFrame({"id": list(range(100)), "col": ["a"] * 100})
214+
_, metrics = report(
215+
syn_tgt_data=tgt,
216+
trn_tgt_data=tgt,
217+
hol_tgt_data=tgt[:10],
218+
)
219+
assert metrics is not None
220+
221+
212222
def test_report_sequential_few_records(tmp_path):
213223
# ensure that we don't crash in case of dominant zero-seq-length
214224
ctx = pd.DataFrame({"id": list(range(1000))})

0 commit comments

Comments
 (0)