2828_LOG = logging .getLogger (__name__ )
2929
3030
31+ def calculate_dcrs (data : np .ndarray | None , query : np .ndarray | None ) -> np .ndarray | None :
32+ """
33+ Calculate Distance to Closest Records (DCRs).
34+
35+ Args:
36+ data: Embeddings of the training data.
37+ query: Embeddings of the query set.
38+
39+ Returns:
40+ """
41+ if data is None or query is None :
42+ return None
43+ # sort data by first dimension to enforce deterministic results
44+ data = data [data [:, 0 ].argsort ()]
45+ _LOG .info (f"calculate DCRs for { data .shape = } and { query .shape = } " )
46+ index = NearestNeighbors (n_neighbors = 1 , algorithm = "auto" , metric = "cosine" , n_jobs = min (cpu_count () - 1 , 16 ))
47+ index .fit (data )
48+ dcrs , _ = index .kneighbors (query )
49+ return dcrs [:, 0 ]
50+
51+
3152def calculate_distances (
3253 * , syn_embeds : np .ndarray , trn_embeds : np .ndarray , hol_embeds : np .ndarray | None
3354) -> tuple [np .ndarray , np .ndarray | None , np .ndarray | None ]:
@@ -47,28 +68,13 @@ def calculate_distances(
4768 """
4869 if hol_embeds is not None :
4970 assert trn_embeds .shape == hol_embeds .shape
50- # calculate DCR for synthetic to training
51- index_syn = NearestNeighbors (n_neighbors = 1 , algorithm = "brute" , metric = "l2" , n_jobs = min (cpu_count () - 1 , 16 ))
52- index_syn .fit (syn_embeds )
53- _LOG .info (f"calculate DCRs for { len (syn_embeds ):,} synthetic to { len (trn_embeds ):,} training" )
54- dcrs_syn_trn , _ = index_syn .kneighbors (trn_embeds )
55- dcr_syn_trn = dcrs_syn_trn [:, 0 ]
5671
57- dcr_syn_hol = None
58- dcr_trn_hol = None
59-
60- if hol_embeds is not None :
61- # calculate DCR for synthetic to holdout
62- _LOG .info (f"calculate DCRs for { len (syn_embeds ):,} synthetic to { len (hol_embeds ):,} holdout" )
63- dcrs_syn_hol , _ = index_syn .kneighbors (hol_embeds )
64- dcr_syn_hol = dcrs_syn_hol [:, 0 ]
65-
66- # calculate DCR for training to holdout
67- _LOG .info (f"calculate DCRs for { len (trn_embeds ):,} training to { len (hol_embeds ):,} holdout" )
68- index_trn = NearestNeighbors (n_neighbors = 1 , algorithm = "brute" , metric = "l2" , n_jobs = min (cpu_count () - 1 , 16 ))
69- index_trn .fit (trn_embeds )
70- dcrs_trn_hol , _ = index_trn .kneighbors (hol_embeds )
71- dcr_trn_hol = dcrs_trn_hol [:, 0 ]
72+ # calculate DCR for synthetic to training
73+ dcr_syn_trn = calculate_dcrs (data = trn_embeds , query = syn_embeds )
74+ # calculate DCR for synthetic to holdout
75+ dcr_syn_hol = calculate_dcrs (data = hol_embeds , query = syn_embeds )
76+ # calculate DCR for holdout to training
77+ dcr_trn_hol = calculate_dcrs (data = trn_embeds , query = hol_embeds )
7278
7379 dcr_syn_trn_deciles = np .round (np .quantile (dcr_syn_trn , np .linspace (0 , 1 , 11 )), 3 )
7480 _LOG .info (f"DCR deciles for synthetic to training: { dcr_syn_trn_deciles } " )
0 commit comments