
Commit 3a680aa

MSD-1124: harmonize max_sample_size_embeddings + remove default max_sample_size
1 parent a59bd91 commit 3a680aa

File tree

7 files changed (+34, -39 lines):

README.md
examples/benchmark.ipynb
src/mostlyai/qa/common.py
src/mostlyai/qa/report.py
src/mostlyai/qa/report_from_statistics.py
src/mostlyai/qa/sampling.py
tests/end_to_end/test_report.py


README.md

Lines changed: 2 additions & 2 deletions
@@ -67,8 +67,8 @@ def report(
     report_subtitle: str = "",
     report_credits: str = REPORT_CREDITS,
     report_extra_info: str = "",
-    max_sample_size_accuracy: int = MAX_SAMPLE_SIZE_ACCURACY,
-    max_sample_size_embeddings: int = MAX_SAMPLE_SIZE_EMBEDDINGS,
+    max_sample_size_accuracy: int | None = None,
+    max_sample_size_embeddings: int | None = None,
     statistics_path: str | Path | None = None,
     on_progress: ProgressCallback | None = None,
 ) -> tuple[Path, dict | None]:
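
Both caps are now optional; omitting them lets the report pick sample sizes itself. A minimal usage sketch under assumptions not stated in this commit: that `report` is importable as `mostlyai.qa.report`, and that the CSV file names are placeholders.

import pandas as pd
from mostlyai.qa import report  # assumed import path

syn = pd.read_csv("synthetic.csv")  # illustrative file names
trn = pd.read_csv("training.csv")
hol = pd.read_csv("holdout.csv")

# new behavior: leave the caps at None and let report() harmonize sizes
report_path, metrics = report(
    syn_tgt_data=syn,
    trn_tgt_data=trn,
    hol_tgt_data=hol,
)

# an explicit cap still type-checks (int | None), e.g. to bound runtime
report_path, metrics = report(
    syn_tgt_data=syn,
    trn_tgt_data=trn,
    max_sample_size_accuracy=100_000,
)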

examples/benchmark.ipynb

Lines changed: 0 additions & 1 deletion
@@ -58,7 +58,6 @@
     " syn_tgt_data=syn,\n",
     " trn_tgt_data=tgt,\n",
     " hol_tgt_data=hol,\n",
-    " max_sample_size_embeddings=50_000,\n",
     " )\n",
     " row = pd.json_normalize(metrics, sep=\"_\")\n",
     " row.insert(0, \"dataset\", dataset)\n",

src/mostlyai/qa/common.py

Lines changed: 0 additions & 2 deletions
@@ -23,8 +23,6 @@
 _LOG = logging.getLogger(__name__)
 
 
-MAX_SAMPLE_SIZE_ACCURACY = 100_000
-MAX_SAMPLE_SIZE_EMBEDDINGS = 10_000
 ACCURACY_MAX_COLUMNS = 300  # should be an even number and greater than 100
 
 MAX_UNIVARIATE_PLOTS = 300

src/mostlyai/qa/report.py

Lines changed: 9 additions & 11 deletions
@@ -41,8 +41,6 @@
     ProgressCallback,
     PrerequisiteNotMetError,
     check_min_sample_size,
-    MAX_SAMPLE_SIZE_ACCURACY,
-    MAX_SAMPLE_SIZE_EMBEDDINGS,
     add_tqdm,
     NXT_COLUMN,
     CTX_COLUMN_PREFIX,
@@ -69,8 +67,8 @@ def report(
     report_subtitle: str = "",
     report_credits: str = REPORT_CREDITS,
     report_extra_info: str = "",
-    max_sample_size_accuracy: int = MAX_SAMPLE_SIZE_ACCURACY,
-    max_sample_size_embeddings: int = MAX_SAMPLE_SIZE_EMBEDDINGS,
+    max_sample_size_accuracy: int | None = None,
+    max_sample_size_embeddings: int | None = None,
     statistics_path: str | Path | None = None,
     on_progress: ProgressCallback | None = None,
 ) -> tuple[Path, dict | None]:
@@ -225,6 +223,11 @@ def report(
     )
     on_progress(current=30, total=100)
 
+    # ensure that embeddings are all of equal size for a fair 3-way comparison
+    max_sample_size_embeddings = min(syn_sample_size, trn_sample_size)
+    if hol_sample_size != 0:
+        max_sample_size_embeddings = min(max_sample_size_embeddings, hol_sample_size)
+
     # calculate embeddings
     syn_embeds = calculate_embeddings(
         pull_data_for_embeddings(
@@ -237,21 +240,16 @@ def report(
     )
     _LOG.info(f"calculated embeddings for synthetic {syn_embeds.shape}")
     on_progress(current=50, total=100)
-    # ensure that `trn` and `hol` are of equal size
-    max_sample_size = min(max_sample_size_embeddings, trn_sample_size)
-    if hol_tgt_data is not None:
-        max_sample_size = min(max_sample_size_embeddings, hol_sample_size)
     trn_embeds = calculate_embeddings(
         pull_data_for_embeddings(
             df_tgt=trn_tgt_data,
             df_ctx=trn_ctx_data,
             ctx_primary_key=ctx_primary_key,
             tgt_context_key=tgt_context_key,
-            max_sample_size=max_sample_size,
+            max_sample_size=max_sample_size_embeddings,
         )
     )
     _LOG.info(f"calculated embeddings for training {trn_embeds.shape}")
-
     on_progress(current=60, total=100)
     if hol_tgt_data is not None:
         hol_embeds = calculate_embeddings(
@@ -260,7 +258,7 @@ def report(
                 df_ctx=hol_ctx_data,
                 ctx_primary_key=ctx_primary_key,
                 tgt_context_key=tgt_context_key,
-                max_sample_size=max_sample_size,
+                max_sample_size=max_sample_size_embeddings,
             )
         )
         _LOG.info(f"calculated embeddings for holdout {hol_embeds.shape}")

src/mostlyai/qa/report_from_statistics.py

Lines changed: 2 additions & 4 deletions
@@ -27,8 +27,6 @@
     PrerequisiteNotMetError,
     check_min_sample_size,
     add_tqdm,
-    MAX_SAMPLE_SIZE_ACCURACY,
-    MAX_SAMPLE_SIZE_EMBEDDINGS,
     check_statistics_prerequisite,
     determine_data_size,
     REPORT_CREDITS,
@@ -50,8 +48,8 @@ def report_from_statistics(
     report_subtitle: str = "",
     report_credits: str = REPORT_CREDITS,
     report_extra_info: str = "",
-    max_sample_size_accuracy: int = MAX_SAMPLE_SIZE_ACCURACY,
-    max_sample_size_embeddings: int = MAX_SAMPLE_SIZE_EMBEDDINGS,
+    max_sample_size_accuracy: int | None = None,
+    max_sample_size_embeddings: int | None = None,
     on_progress: ProgressCallback | None = None,
 ) -> Path:
     with TemporaryWorkspace() as workspace:

src/mostlyai/qa/sampling.py

Lines changed: 18 additions & 16 deletions
@@ -70,19 +70,20 @@ def pull_data_for_accuracy(
 
     if df_ctx is not None:
         # explicit context
-        df_ctx = df_ctx.sample(frac=1).head(max_sample_size).reset_index(drop=True)
-        df_ctx = df_ctx.rename(columns={ctx_primary_key: tgt_context_key})
-        df_tgt = df_tgt.merge(df_ctx[tgt_context_key], on=tgt_context_key).reset_index(drop=True)
+        df_ctx = df_ctx.sample(frac=1).head(max_sample_size)
+        df_ctx = df_ctx.rename(columns={ctx_primary_key: tgt_context_key}).reset_index(drop=True)
+        df_tgt = df_tgt.merge(df_ctx[tgt_context_key], on=tgt_context_key, how="inner").reset_index(drop=True)
     elif tgt_context_key is not None:
         # implicit context
-        df_ctx = df_tgt[[tgt_context_key]].drop_duplicates().sample(frac=1).head(max_sample_size).reset_index(drop=True)
-        df_tgt = df_tgt.merge(df_ctx[tgt_context_key], on=tgt_context_key).reset_index(drop=True)
+        df_ctx = df_tgt[[tgt_context_key]].drop_duplicates()
+        df_ctx = df_ctx.sample(frac=1).head(max_sample_size).reset_index(drop=True)
+        df_tgt = df_tgt.merge(df_ctx[tgt_context_key], on=tgt_context_key, how="inner").reset_index(drop=True)
     else:
         # no context; flat table
-        df_ctx = pd.DataFrame({key: range(len(df_tgt))})
-        df_tgt = df_tgt.sample(frac=1).head(max_sample_size).reset_index(drop=True)
-        df_tgt[key] = df_ctx[key]
         tgt_context_key = key
+        df_tgt = df_tgt.sample(frac=1).head(max_sample_size).reset_index(drop=True)
+        df_tgt[key] = range(len(df_tgt))
+        df_ctx = df_tgt[[key]]
 
     # consistently use "__KEY" as key column
     df_ctx = df_ctx.rename(columns={tgt_context_key: key})
@@ -188,12 +189,13 @@ def pull_data_for_embeddings(
 
     if df_ctx is not None:
         # explicit context
-        df_ctx = df_ctx.sample(frac=1).head(max_sample_size).reset_index(drop=True)
-        df_ctx = df_ctx.rename(columns={ctx_primary_key: tgt_context_key})
+        df_ctx = df_ctx.sample(frac=1).head(max_sample_size)
+        df_ctx = df_ctx.rename(columns={ctx_primary_key: tgt_context_key}).reset_index(drop=True)
         df_tgt = df_tgt.merge(df_ctx[tgt_context_key], on=tgt_context_key, how="right").reset_index(drop=True)
     elif tgt_context_key is not None:
         # implicit context
-        df_ctx = df_tgt[[tgt_context_key]].drop_duplicates().sample(frac=1).head(max_sample_size).reset_index(drop=True)
+        df_ctx = df_tgt[[tgt_context_key]].drop_duplicates()
+        df_ctx = df_ctx.sample(frac=1).head(max_sample_size).reset_index(drop=True)
         df_tgt = df_tgt.merge(df_ctx[tgt_context_key], on=tgt_context_key, how="right").reset_index(drop=True)
     else:
         # no context; flat table
@@ -214,18 +216,18 @@ def row_to_string(row: pd.Series) -> str:
         # JSON to keep the string length short for faster speed
         return " ".join(row.values.astype(str))
 
-    def sequence_to_json(sequence: pd.DataFrame) -> str:
+    def sequence_to_string(sequence: pd.DataFrame) -> str:
         return ", ".join(sequence.apply(row_to_string, axis=1))
 
-    jsons = (
+    strings = (
         df_tgt.groupby(tgt_context_key)
-        .apply(sequence_to_json, include_groups=False)
+        .apply(sequence_to_string, include_groups=False)
         .sample(frac=1)
         .reset_index(drop=True)
     )
     time_elapsed = time.time() - t0
-    _LOG.info(f"finished pulling data for embeddings ({time_elapsed=:.2f}s, {jsons.shape=})")
-    return jsons
+    _LOG.info(f"finished pulling data for embeddings ({time_elapsed=:.2f}s, {strings.shape=})")
+    return strings
 
 
 def calculate_embeddings(texts: pd.Series | pd.DataFrame) -> np.ndarray:
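
The reworked flat-table branch is the substantive fix in this file: previously `df_ctx` was built from the pre-sampling row count and its key column was assigned onto an already shuffled and truncated `df_tgt`, leaving a context frame longer than the sampled target. A toy reproduction of the new ordering; the table contents are illustrative, while `__KEY` mirrors the module's sentinel key column:

import pandas as pd

key = "__KEY"
max_sample_size = 5
df_tgt = pd.DataFrame({"a": range(10)})  # illustrative flat table

# new order: sample first, then derive keys and the context frame from the result
df_tgt = df_tgt.sample(frac=1).head(max_sample_size).reset_index(drop=True)
df_tgt[key] = range(len(df_tgt))
df_ctx = df_tgt[[key]]

assert len(df_tgt) == len(df_ctx) == max_sample_size
assert (df_tgt[key] == df_ctx[key]).all()

The `how="inner"` added in `pull_data_for_accuracy` only spells out the pandas default, whereas `pull_data_for_embeddings` keeps `how="right"` so that sampled context rows without target records are retained.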

tests/end_to_end/test_report.py

Lines changed: 3 additions & 3 deletions
@@ -42,9 +42,9 @@ def test_report_flat(tmp_path):
     assert report_path.exists()
 
     accuracy = metrics["accuracy"]
-    assert 0.8 <= accuracy["overall"] <= 1.0
-    assert 0.8 <= accuracy["univariate"] <= 1.0
-    assert 0.8 <= accuracy["bivariate"] <= 1.0
+    assert 0.5 <= accuracy["overall"] <= 1.0
+    assert 0.5 <= accuracy["univariate"] <= 1.0
+    assert 0.5 <= accuracy["bivariate"] <= 1.0
     assert accuracy["coherence"] is None
     assert 0.8 <= accuracy["overall_max"] <= 1.0
     assert 0.8 <= accuracy["univariate_max"] <= 1.0