# Module-level logger, named after this module so log records can be filtered per module.
_LOG = logging.getLogger(__name__)
2929
3030
31- def calculate_dcrs (data : np .ndarray | None , query : np .ndarray | None ) -> np .ndarray | None :
31+ def calculate_dcrs_nndrs (
32+ data : np .ndarray | None , query : np .ndarray | None
33+ ) -> tuple [np .ndarray | None , np .ndarray | None ]:
3234 """
33- Calculate Distance to Closest Records (DCRs).
35+ Calculate Distance to Closest Records (DCRs) and Nearest Neighbor Distance Ratios (NNDRs) .
3436
3537 Args:
3638 data: Embeddings of the training data.
@@ -39,19 +41,21 @@ def calculate_dcrs(data: np.ndarray | None, query: np.ndarray | None) -> np.ndar
3941 Returns:
4042 """
4143 if data is None or query is None :
42- return None
44+ return None , None
4345 # sort data by first dimension to enforce deterministic results
4446 data = data [data [:, 0 ].argsort ()]
4547 _LOG .info (f"calculate DCRs for { data .shape = } and { query .shape = } " )
46- index = NearestNeighbors (n_neighbors = 1 , algorithm = "auto" , metric = "cosine" , n_jobs = min (cpu_count () - 1 , 16 ))
48+ index = NearestNeighbors (n_neighbors = 2 , algorithm = "auto" , metric = "cosine" , n_jobs = min (cpu_count () - 1 , 16 ))
4749 index .fit (data )
4850 dcrs , _ = index .kneighbors (query )
49- return dcrs [:, 0 ]
51+ dcr = dcrs [:, 0 ]
52+ nndr = (dcrs [:, 0 ] + 1e-8 ) / (dcrs [:, 1 ] + 1e-8 )
53+ return dcr , nndr
5054
5155
5256def calculate_distances (
5357 * , syn_embeds : np .ndarray , trn_embeds : np .ndarray , hol_embeds : np .ndarray | None
54- ) -> tuple [ np . ndarray , np .ndarray | None , np . ndarray | None ]:
58+ ) -> dict [ str , np .ndarray ]:
5559 """
5660 Calculates distances to the closest records (DCR).
5761
@@ -61,52 +65,96 @@ def calculate_distances(
6165 hol_embeds: Embeddings of holdout data.
6266
6367 Returns:
64- Tuple containing:
68+ Dictionary containing:
6569 - dcr_syn_trn: DCR for synthetic to training.
6670 - dcr_syn_hol: DCR for synthetic to holdout.
6771 - dcr_trn_hol: DCR for training to holdout.
72+ - nndr_syn_trn: NNDR for synthetic to training.
73+ - nndr_syn_hol: NNDR for synthetic to holdout.
74+ - nndr_trn_hol: NNDR for training to holdout.
6875 """
6976 if hol_embeds is not None :
7077 assert trn_embeds .shape == hol_embeds .shape
7178
72- # calculate DCR for synthetic to training
73- dcr_syn_trn = calculate_dcrs (data = trn_embeds , query = syn_embeds )
74- # calculate DCR for synthetic to holdout
75- dcr_syn_hol = calculate_dcrs (data = hol_embeds , query = syn_embeds )
76- # calculate DCR for holdout to training
77- dcr_trn_hol = calculate_dcrs (data = trn_embeds , query = hol_embeds )
79+ # calculate DCR / NNDR for synthetic to training
80+ dcr_syn_trn , nndr_syn_trn = calculate_dcrs_nndrs (data = trn_embeds , query = syn_embeds )
81+ # calculate DCR / NNDR for synthetic to holdout
82+ dcr_syn_hol , nndr_syn_hol = calculate_dcrs_nndrs (data = hol_embeds , query = syn_embeds )
83+ # calculate DCR / NNDR for holdout to training
84+ dcr_trn_hol , nndr_trn_hol = calculate_dcrs_nndrs (data = trn_embeds , query = hol_embeds )
7885
79- dcr_syn_trn_deciles = np .round (np .quantile (dcr_syn_trn , np .linspace (0 , 1 , 11 )), 3 )
80- _LOG .info (f"DCR deciles for synthetic to training: { dcr_syn_trn_deciles } " )
86+ # log statistics
87+ def deciles (x ):
88+ return np .round (np .quantile (x , np .linspace (0 , 1 , 11 )), 3 )
89+
90+ _LOG .info (f"DCR deciles for synthetic to training: { deciles (dcr_syn_trn )} " )
91+ _LOG .info (f"NNDR deciles for synthetic to training: { deciles (nndr_syn_trn )} " )
8192 if dcr_syn_hol is not None :
82- dcr_syn_hol_deciles = np .round (np .quantile (dcr_syn_hol , np .linspace (0 , 1 , 11 )), 3 )
83- _LOG .info (f"DCR deciles for synthetic to holdout: { dcr_syn_hol_deciles } " )
84- # calculate share of dcr_syn_trn != dcr_syn_hol
93+ _LOG .info (f"DCR deciles for synthetic to holdout: { deciles (dcr_syn_hol )} " )
94+ _LOG .info (f"NNDR deciles for synthetic to holdout: { deciles (nndr_syn_hol )} " )
8595 _LOG .info (f"share of dcr_syn_trn < dcr_syn_hol: { np .mean (dcr_syn_trn < dcr_syn_hol ):.1%} " )
96+ _LOG .info (f"share of nndr_syn_trn < nndr_syn_hol: { np .mean (nndr_syn_trn < nndr_syn_hol ):.1%} " )
8697 _LOG .info (f"share of dcr_syn_trn > dcr_syn_hol: { np .mean (dcr_syn_trn > dcr_syn_hol ):.1%} " )
87-
98+ _LOG . info ( f"share of nndr_syn_trn > nndr_syn_hol: { np . mean ( nndr_syn_trn > nndr_syn_hol ):.1% } " )
8899 if dcr_trn_hol is not None :
89- dcr_trn_hol_deciles = np .round (np .quantile (dcr_trn_hol , np .linspace (0 , 1 , 11 )), 3 )
90- _LOG .info (f"DCR deciles for training to holdout: { dcr_trn_hol_deciles } " )
100+ _LOG .info (f"DCR deciles for training to holdout: { deciles (dcr_trn_hol )} " )
101+ _LOG .info (f"NNDR deciles for training to holdout: { deciles (nndr_trn_hol )} " )
102+ return {
103+ "dcr_syn_trn" : dcr_syn_trn ,
104+ "nndr_syn_trn" : nndr_syn_trn ,
105+ "dcr_syn_hol" : dcr_syn_hol ,
106+ "nndr_syn_hol" : nndr_syn_hol ,
107+ "dcr_trn_hol" : dcr_trn_hol ,
108+ "nndr_trn_hol" : nndr_trn_hol ,
109+ }
91110
92- return dcr_syn_trn , dcr_syn_hol , dcr_trn_hol
93111
112+ def plot_distances (plot_title : str , distances : dict [str , np .ndarray ]) -> go .Figure :
113+ dcr_syn_trn = distances ["dcr_syn_trn" ]
114+ dcr_syn_hol = distances ["dcr_syn_hol" ]
115+ dcr_trn_hol = distances ["dcr_trn_hol" ]
116+ nndr_syn_trn = distances ["nndr_syn_trn" ]
117+ nndr_syn_hol = distances ["nndr_syn_hol" ]
118+ nndr_trn_hol = distances ["nndr_trn_hol" ]
94119
95- def plot_distances (
96- plot_title : str , dcr_syn_trn : np .ndarray , dcr_syn_hol : np .ndarray | None , dcr_trn_hol : np .ndarray | None
97- ) -> go .Figure :
98- # calculate quantiles
120+ # calculate quantiles for DCR
99121 y = np .linspace (0 , 1 , 101 )
100- x_syn_trn = np .quantile (dcr_syn_trn , y )
122+
123+ # Calculate max values to use later
124+ max_dcr_syn_trn = np .max (dcr_syn_trn )
125+ max_dcr_syn_hol = None if dcr_syn_hol is None else np .max (dcr_syn_hol )
126+ max_dcr_trn_hol = None if dcr_trn_hol is None else np .max (dcr_trn_hol )
127+ max_nndr_syn_trn = np .max (nndr_syn_trn )
128+ max_nndr_syn_hol = None if nndr_syn_hol is None else np .max (nndr_syn_hol )
129+ max_nndr_trn_hol = None if nndr_trn_hol is None else np .max (nndr_trn_hol )
130+
131+ # Ensure first point is always at x=0 for all lines
132+ # and last point is at the maximum x value with y=1
133+ x_dcr_syn_trn = np .concatenate ([[0 ], np .quantile (dcr_syn_trn , y [1 :- 1 ]), [max_dcr_syn_trn ]])
101134 if dcr_syn_hol is not None :
102- x_syn_hol = np .quantile (dcr_syn_hol , y )
135+ x_dcr_syn_hol = np .concatenate ([[ 0 ], np . quantile (dcr_syn_hol , y [ 1 : - 1 ]), [ max_dcr_syn_hol ]] )
103136 else :
104- x_syn_hol = None
137+ x_dcr_syn_hol = None
105138
106139 if dcr_trn_hol is not None :
107- x_trn_hol = np .quantile (dcr_trn_hol , y )
140+ x_dcr_trn_hol = np .concatenate ([[ 0 ], np . quantile (dcr_trn_hol , y [ 1 : - 1 ]), [ max_dcr_trn_hol ]] )
108141 else :
109- x_trn_hol = None
142+ x_dcr_trn_hol = None
143+
144+ # calculate quantiles for NNDR
145+ x_nndr_syn_trn = np .concatenate ([[0 ], np .quantile (nndr_syn_trn , y [1 :- 1 ]), [max_nndr_syn_trn ]])
146+ if nndr_syn_hol is not None :
147+ x_nndr_syn_hol = np .concatenate ([[0 ], np .quantile (nndr_syn_hol , y [1 :- 1 ]), [max_nndr_syn_hol ]])
148+ else :
149+ x_nndr_syn_hol = None
150+
151+ if nndr_trn_hol is not None :
152+ x_nndr_trn_hol = np .concatenate ([[0 ], np .quantile (nndr_trn_hol , y [1 :- 1 ]), [max_nndr_trn_hol ]])
153+ else :
154+ x_nndr_trn_hol = None
155+
156+ # Adjust y to match the new x arrays with the added 0 and 1 points
157+ y = np .concatenate ([[0 ], y [1 :- 1 ], [1 ]])
110158
111159 # prepare layout
112160 layout = go .Layout (
@@ -120,80 +168,132 @@ def plot_distances(
120168 plot_bgcolor = CHARTS_COLORS ["background" ],
121169 autosize = True ,
122170 height = 500 ,
123- margin = dict (l = 20 , r = 20 , b = 20 , t = 40 , pad = 5 ),
171+ margin = dict (l = 20 , r = 20 , b = 20 , t = 60 , pad = 5 ),
124172 showlegend = True ,
125- yaxis = dict (
126- showticklabels = False ,
127- zeroline = True ,
128- zerolinewidth = 1 ,
129- zerolinecolor = "#999999" ,
130- rangemode = "tozero" ,
173+ )
174+
175+ # Create a figure with two subplots side by side
176+ fig = go .Figure (layout = layout ).set_subplots (
177+ rows = 1 ,
178+ cols = 2 ,
179+ horizontal_spacing = 0.05 ,
180+ subplot_titles = ("Distance to Closest Record (DCR)" , "Nearest Neighbor Distance Ratio (NNDR)" ),
181+ )
182+ fig .update_annotations (font_size = 12 )
183+
184+ # Configure axes for both subplots
185+ for i in range (1 , 3 ):
186+ fig .update_xaxes (
187+ col = i ,
131188 showline = True ,
132189 linewidth = 1 ,
133190 linecolor = "#999999" ,
134- ),
135- yaxis2 = dict (
136- overlaying = "y" ,
137- side = "right" ,
191+ hoverformat = ".3f" ,
192+ )
193+
194+ # Only show y-axis on the right side with percentage labels
195+ fig .update_yaxes (
196+ col = i ,
138197 tickformat = ".0%" ,
139198 showgrid = False ,
140- range = [0 , 1 ],
199+ range = [- 0.01 , 1.01 ],
141200 showline = True ,
142201 linewidth = 1 ,
143202 linecolor = "#999999" ,
203+ side = "right" ,
204+ tickvals = [0 , 0.2 , 0.4 , 0.6 , 0.8 , 1.0 ],
205+ )
206+
207+ # Add traces for DCR plot (left subplot)
208+ # training vs holdout (light gray)
209+ if x_dcr_trn_hol is not None :
210+ fig .add_trace (
211+ go .Scatter (
212+ mode = "lines" ,
213+ x = x_dcr_trn_hol ,
214+ y = y ,
215+ name = "Training vs. Holdout Data" ,
216+ line = dict (color = "#999999" , width = 5 ),
217+ showlegend = True ,
218+ ),
219+ row = 1 ,
220+ col = 1 ,
221+ )
222+
223+ # synthetic vs holdout (gray)
224+ if x_dcr_syn_hol is not None :
225+ fig .add_trace (
226+ go .Scatter (
227+ mode = "lines" ,
228+ x = x_dcr_syn_hol ,
229+ y = y ,
230+ name = "Synthetic vs. Holdout Data" ,
231+ line = dict (color = "#666666" , width = 5 ),
232+ showlegend = True ,
233+ ),
234+ row = 1 ,
235+ col = 1 ,
236+ )
237+
238+ # synthetic vs training (green)
239+ fig .add_trace (
240+ go .Scatter (
241+ mode = "lines" ,
242+ x = x_dcr_syn_trn ,
243+ y = y ,
244+ name = "Synthetic vs. Training Data" ,
245+ line = dict (color = "#24db96" , width = 5 ),
246+ showlegend = True ,
144247 ),
145- xaxis = dict (
146- showline = True ,
147- linewidth = 1 ,
148- linecolor = "#999999" ,
149- hoverformat = ".3f" ,
150- ),
248+ row = 1 ,
249+ col = 1 ,
151250 )
152- fig = go .Figure (layout = layout )
153-
154- traces = []
155251
252+ # Add traces for NNDR plot (right subplot)
156253 # training vs holdout (light gray)
157- if x_trn_hol is not None :
158- traces . append (
254+ if x_nndr_trn_hol is not None :
255+ fig . add_trace (
159256 go .Scatter (
160257 mode = "lines" ,
161- x = x_trn_hol ,
258+ x = x_nndr_trn_hol ,
162259 y = y ,
163260 name = "Training vs. Holdout Data" ,
164261 line = dict (color = "#999999" , width = 5 ),
165- yaxis = "y2" ,
166- )
262+ showlegend = False ,
263+ ),
264+ row = 1 ,
265+ col = 2 ,
167266 )
168267
169268 # synthetic vs holdout (gray)
170- if x_syn_hol is not None :
171- traces . append (
269+ if x_nndr_syn_hol is not None :
270+ fig . add_trace (
172271 go .Scatter (
173272 mode = "lines" ,
174- x = x_syn_hol ,
273+ x = x_nndr_syn_hol ,
175274 y = y ,
176275 name = "Synthetic vs. Holdout Data" ,
177276 line = dict (color = "#666666" , width = 5 ),
178- yaxis = "y2" ,
179- )
277+ showlegend = False ,
278+ ),
279+ row = 1 ,
280+ col = 2 ,
180281 )
181282
182283 # synthetic vs training (green)
183- traces . append (
284+ fig . add_trace (
184285 go .Scatter (
185286 mode = "lines" ,
186- x = x_syn_trn ,
287+ x = x_nndr_syn_trn ,
187288 y = y ,
188289 name = "Synthetic vs. Training Data" ,
189290 line = dict (color = "#24db96" , width = 5 ),
190- yaxis = "y2" ,
191- )
291+ showlegend = False ,
292+ ),
293+ row = 1 ,
294+ col = 2 ,
192295 )
193296
194- for trace in traces :
195- fig .add_trace (trace )
196-
197297 fig .update_layout (
198298 legend = dict (
199299 orientation = "h" ,
@@ -210,12 +310,11 @@ def plot_distances(
210310
211311
def plot_store_distances(
    distances: dict[str, np.ndarray | None],
    workspace: TemporaryWorkspace,
) -> None:
    """
    Plot the cumulative distributions of the distance metrics and store the figure as HTML.

    Args:
        distances: Distance arrays keyed by metric name, passed unchanged to `plot_distances`
            (which reads the `dcr_*` and `nndr_*` keys). Holdout-based entries may be None.
        workspace: Workspace whose `store_figure_html` is used to persist the figure.
    """
    fig = plot_distances(
        "Cumulative Distributions of Distance Metrics",
        distances,
    )
    # the figure is stored under the key "distances_dcr", although it also contains the NNDR subplot
    workspace.store_figure_html(fig, "distances_dcr")
0 commit comments