2121 plot_store_distances ,
2222)
2323from mostlyai .qa ._sampling import calculate_embeddings
24+ from mostlyai .qa .assets import load_embedder
2425
2526
2627@pytest .fixture ()
@@ -42,9 +43,10 @@ def cat_with_rare_and_none():
4243
4344def test_calculate_distances ():
4445 n = 10
45- syn_embeds = calculate_embeddings (["a 0 1.0" ] * n )
46- trn_embeds = calculate_embeddings (["a 0 0.0" ] * n )
47- hol_embeds = calculate_embeddings (["a 0 1.0" ] * n )
46+ embedder = load_embedder ()
47+ syn_embeds = calculate_embeddings (["a 0 1.0" ] * n , embedder = embedder )
48+ trn_embeds = calculate_embeddings (["a 0 0.0" ] * n , embedder = embedder )
49+ hol_embeds = calculate_embeddings (["a 0 1.0" ] * n , embedder = embedder )
4850 dcr_syn_trn , dcr_syn_hol , dcr_trn_hol = calculate_distances (
4951 syn_embeds = syn_embeds , trn_embeds = trn_embeds , hol_embeds = hol_embeds
5052 )
@@ -55,9 +57,9 @@ def test_calculate_distances():
5557 assert dcr_syn_hol .max () == 0
5658
5759 # test specifically that near matches do not report a distance of 0 due to rounding
58- syn_embeds = calculate_embeddings (["a 0.0002" ] * n )
59- trn_embeds = calculate_embeddings (["a 0.0001" ] * n )
60- hol_embeds = calculate_embeddings (["a 0.0001" ] * n )
60+ syn_embeds = calculate_embeddings (["a 0.0002" ] * n , embedder = embedder )
61+ trn_embeds = calculate_embeddings (["a 0.0001" ] * n , embedder = embedder )
62+ hol_embeds = calculate_embeddings (["a 0.0001" ] * n , embedder = embedder )
6163 dcr_syn_trn , dcr_syn_hol , dcr_trn_hol = calculate_distances (
6264 syn_embeds = syn_embeds , trn_embeds = trn_embeds , hol_embeds = hol_embeds
6365 )
@@ -66,7 +68,8 @@ def test_calculate_distances():
6668
6769
6870def test_plot_store_dcr (workspace ):
69- embeds = calculate_embeddings (["a 0.0002" ] * 100 )
71+ embedder = load_embedder ()
72+ embeds = calculate_embeddings (["a 0.0002" ] * 100 , embedder = embedder )
7073 dcr_syn_trn , dcr_syn_hol , dcr_trn_hol = calculate_distances (syn_embeds = embeds , trn_embeds = embeds , hol_embeds = embeds )
7174 plot_store_distances (dcr_syn_trn , dcr_syn_hol , dcr_trn_hol , workspace )
7275 output_dir = workspace .workspace_dir / "figures"
0 commit comments