Skip to content

Commit b5a7b24

Browse files
authored
perf: minimize calls to load_embedder (#133)
1 parent 27e0beb commit b5a7b24

File tree

7 files changed

+37
-18
lines changed

7 files changed

+37
-18
lines changed

mostlyai/qa/_sampling.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
ACCURACY_MAX_COLUMNS,
4545
ProgressCallbackWrapper,
4646
)
47-
from mostlyai.qa.assets import load_embedder, load_tokenizer
47+
from mostlyai.qa.assets import load_tokenizer
4848

4949

5050
_LOG = logging.getLogger(__name__)
@@ -290,11 +290,8 @@ def calculate_embeddings(
290290
progress: ProgressCallbackWrapper | None = None,
291291
progress_from: int | None = None,
292292
progress_to: int | None = None,
293+
embedder: Any | None = None,
293294
) -> np.ndarray:
294-
# load embedder
295-
t0 = time.time()
296-
embedder = load_embedder()
297-
_LOG.info(f"loaded load_embedder in {time.time() - t0:.2f}s")
298295
# split into buckets for calculating embeddings to avoid memory issues and report continuous progress
299296
steps = progress_to - progress_from if progress_to is not None and progress_from is not None else 1
300297
buckets = np.array_split(strings, steps)

mostlyai/qa/assets/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ def load_embedder():
4343
Load the embedder model.
4444
Can deal with read-only cache folder by attempting to download the model if it is not locally available.
4545
Users can set MOSTLY_HF_HOME environment variable to override the default cache folder.
46+
47+
Note that this method can take significant time to load the model. Thus, it is recommended to call this method once and reuse the returned object.
4648
"""
4749
from sentence_transformers import SentenceTransformer
4850

mostlyai/qa/reporting.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
plot_store_distinct_categories_per_sequence,
4646
plot_store_sequences_per_distinct_category,
4747
)
48+
from mostlyai.qa.assets import load_embedder
4849
from mostlyai.qa.metrics import ModelMetrics, Accuracy, Similarity, Distances
4950
from mostlyai.qa._sampling import (
5051
calculate_embeddings,
@@ -288,6 +289,9 @@ def report(
288289
acc_cats_per_seq = acc_seqs_per_cat = pd.DataFrame({"column": [], "accuracy": [], "accuracy_max": []})
289290
progress.update(completed=25, total=100)
290291

292+
_LOG.info("load embedder")
293+
embedder = load_embedder()
294+
291295
_LOG.info("calculate embeddings for synthetic")
292296
syn_embeds = calculate_embeddings(
293297
strings=pull_data_for_embeddings(
@@ -300,6 +304,7 @@ def report(
300304
progress=progress,
301305
progress_from=25,
302306
progress_to=45,
307+
embedder=embedder,
303308
)
304309
_LOG.info("calculate embeddings for training")
305310
trn_embeds = calculate_embeddings(
@@ -313,6 +318,7 @@ def report(
313318
progress=progress,
314319
progress_from=45,
315320
progress_to=65,
321+
embedder=embedder,
316322
)
317323
if hol_tgt_data is not None:
318324
_LOG.info("calculate embeddings for holdout")
@@ -327,6 +333,7 @@ def report(
327333
progress=progress,
328334
progress_from=65,
329335
progress_to=85,
336+
embedder=embedder,
330337
)
331338
else:
332339
hol_embeds = None

mostlyai/qa/reporting_from_statistics.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
ProgressCallbackWrapper,
3737
)
3838
from mostlyai.qa._filesystem import Statistics, TemporaryWorkspace
39+
from mostlyai.qa.assets import load_embedder
3940

4041
_LOG = logging.getLogger(__name__)
4142

@@ -162,6 +163,9 @@ def report_from_statistics(
162163
acc_cats_per_seq = acc_seqs_per_cat = pd.DataFrame({"column": [], "accuracy": [], "accuracy_max": []})
163164
progress.update(completed=40, total=100)
164165

166+
_LOG.info("load embedder")
167+
embedder = load_embedder()
168+
165169
_LOG.info("calculate embeddings for synthetic")
166170
syn_embeds = calculate_embeddings(
167171
strings=pull_data_for_embeddings(
@@ -174,6 +178,7 @@ def report_from_statistics(
174178
progress=progress,
175179
progress_from=40,
176180
progress_to=60,
181+
embedder=embedder,
177182
)
178183

179184
_LOG.info("report similarity")

tests/unit/test_distances.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
plot_store_distances,
2222
)
2323
from mostlyai.qa._sampling import calculate_embeddings
24+
from mostlyai.qa.assets import load_embedder
2425

2526

2627
@pytest.fixture()
@@ -42,9 +43,10 @@ def cat_with_rare_and_none():
4243

4344
def test_calculate_distances():
4445
n = 10
45-
syn_embeds = calculate_embeddings(["a 0 1.0"] * n)
46-
trn_embeds = calculate_embeddings(["a 0 0.0"] * n)
47-
hol_embeds = calculate_embeddings(["a 0 1.0"] * n)
46+
embedder = load_embedder()
47+
syn_embeds = calculate_embeddings(["a 0 1.0"] * n, embedder=embedder)
48+
trn_embeds = calculate_embeddings(["a 0 0.0"] * n, embedder=embedder)
49+
hol_embeds = calculate_embeddings(["a 0 1.0"] * n, embedder=embedder)
4850
dcr_syn_trn, dcr_syn_hol, dcr_trn_hol = calculate_distances(
4951
syn_embeds=syn_embeds, trn_embeds=trn_embeds, hol_embeds=hol_embeds
5052
)
@@ -55,9 +57,9 @@ def test_calculate_distances():
5557
assert dcr_syn_hol.max() == 0
5658

5759
# test specifically that near matches do not report a distance of 0 due to rounding
58-
syn_embeds = calculate_embeddings(["a 0.0002"] * n)
59-
trn_embeds = calculate_embeddings(["a 0.0001"] * n)
60-
hol_embeds = calculate_embeddings(["a 0.0001"] * n)
60+
syn_embeds = calculate_embeddings(["a 0.0002"] * n, embedder=embedder)
61+
trn_embeds = calculate_embeddings(["a 0.0001"] * n, embedder=embedder)
62+
hol_embeds = calculate_embeddings(["a 0.0001"] * n, embedder=embedder)
6163
dcr_syn_trn, dcr_syn_hol, dcr_trn_hol = calculate_distances(
6264
syn_embeds=syn_embeds, trn_embeds=trn_embeds, hol_embeds=hol_embeds
6365
)
@@ -66,7 +68,8 @@ def test_calculate_distances():
6668

6769

6870
def test_plot_store_dcr(workspace):
69-
embeds = calculate_embeddings(["a 0.0002"] * 100)
71+
embedder = load_embedder()
72+
embeds = calculate_embeddings(["a 0.0002"] * 100, embedder=embedder)
7073
dcr_syn_trn, dcr_syn_hol, dcr_trn_hol = calculate_distances(syn_embeds=embeds, trn_embeds=embeds, hol_embeds=embeds)
7174
plot_store_distances(dcr_syn_trn, dcr_syn_hol, dcr_trn_hol, workspace)
7275
output_dir = workspace.workspace_dir / "figures"

tests/unit/test_html_report.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from mostlyai.qa import _accuracy, _html_report, _distances, _similarity
1616
from mostlyai.qa._common import CTX_COLUMN_PREFIX, TGT_COLUMN_PREFIX
17+
from mostlyai.qa.assets import load_embedder
1718
from mostlyai.qa.reporting import _calculate_metrics
1819
from mostlyai.qa._sampling import calculate_embeddings, pull_data_for_embeddings
1920
import pandas as pd
@@ -33,9 +34,11 @@ def test_generate_store_report(tmp_path, cols, workspace):
3334
acc_cats_per_seq = pd.DataFrame({"column": acc_uni["column"], "accuracy": 0.5, "accuracy_max": 0.5})
3435
acc_seqs_per_cat = pd.DataFrame({"column": acc_uni["column"], "accuracy": 0.5, "accuracy_max": 0.5})
3536
corr_trn = _accuracy.calculate_correlations(acc_trn)
36-
syn_embeds = calculate_embeddings(pull_data_for_embeddings(df_tgt=syn))
37-
trn_embeds = calculate_embeddings(pull_data_for_embeddings(df_tgt=trn))
38-
hol_embeds = calculate_embeddings(pull_data_for_embeddings(df_tgt=hol))
37+
embedder = load_embedder()
38+
syn_embeds = calculate_embeddings(pull_data_for_embeddings(df_tgt=syn), embedder=embedder)
39+
trn_embeds = calculate_embeddings(pull_data_for_embeddings(df_tgt=trn), embedder=embedder)
40+
hol_embeds = calculate_embeddings(pull_data_for_embeddings(df_tgt=hol), embedder=embedder)
41+
3942
sim_cosine_trn_hol, sim_cosine_trn_syn = _similarity.calculate_cosine_similarities(
4043
syn_embeds=syn_embeds,
4144
trn_embeds=trn_embeds,

tests/unit/test_similarity.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from mostlyai.qa._similarity import calculate_cosine_similarities, calculate_discriminator_auc
1818
from mostlyai.qa._sampling import calculate_embeddings
19+
from mostlyai.qa.assets import load_embedder
1920

2021

2122
def test_calculate_embeddings():
@@ -25,9 +26,10 @@ def test_calculate_embeddings():
2526
# semantically distant synthetic data
2627
syn_distant = ["quantum physics theory", "deep space exploration"]
2728

28-
trn_embeds = calculate_embeddings(trn)
29-
syn_close_embeds = calculate_embeddings(syn_close)
30-
syn_distant_embeds = calculate_embeddings(syn_distant)
29+
embedder = load_embedder()
30+
trn_embeds = calculate_embeddings(trn, embedder=embedder)
31+
syn_close_embeds = calculate_embeddings(syn_close, embedder=embedder)
32+
syn_distant_embeds = calculate_embeddings(syn_distant, embedder=embedder)
3133
assert np.all(trn_embeds[0] == trn_embeds[2]) # check that we retain row order
3234

3335
# check that syn_close is closer to trn than syn_distant

0 commit comments

Comments
 (0)