@@ -11,6 +11,8 @@
 import yaml
 from botocore.exceptions import ClientError
 
+from sdgym._dataset_utils import _read_zipped_data
+
 SYNTHESIZER_BASELINE = 'GaussianCopulaSynthesizer'
 RESULTS_FOLDER_PREFIX = 'SDGym_results_'
 metainfo_PREFIX = 'metainfo'
@@ -22,6 +24,9 @@
 class ResultsHandler(ABC):
     """Abstract base class for handling results storage and retrieval."""
 
+    def __init__(self, baseline_synthesizer=SYNTHESIZER_BASELINE):
+        self.baseline_synthesizer = baseline_synthesizer or SYNTHESIZER_BASELINE
+
     @abstractmethod
     def list(self):
         """List all runs in the results directory."""
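Note that the new constructor uses an `or` fallback, so passing `baseline_synthesizer=None` resolves to the default baseline rather than storing `None`. A minimal sketch of the behavior; the base path and the alternative synthesizer name here are placeholders:

handler = LocalResultsHandler('results/', baseline_synthesizer=None)
assert handler.baseline_synthesizer == 'GaussianCopulaSynthesizer'

handler = LocalResultsHandler('results/', baseline_synthesizer='CTGANSynthesizer')
assert handler.baseline_synthesizer == 'CTGANSynthesizer'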
@@ -59,7 +64,8 @@ def _compute_wins(self, result):
         result['Win'] = 0
         for dataset in datasets:
             score_baseline = result.loc[
-                (result['Synthesizer'] == SYNTHESIZER_BASELINE) & (result['Dataset'] == dataset)
+                (result['Synthesizer'] == self.baseline_synthesizer)
+                & (result['Dataset'] == dataset)
             ]['Quality_Score'].to_numpy()
             if score_baseline.size == 0:
                 continue
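The comparison step itself falls outside this hunk, so the sketch below reconstructs the surrounding win logic on toy data, assuming a win means beating the baseline's Quality_Score on the same dataset; all data values are invented:

import pandas as pd

result = pd.DataFrame({
    'Synthesizer': ['GaussianCopulaSynthesizer', 'CTGANSynthesizer'] * 2,
    'Dataset': ['adult', 'adult', 'census', 'census'],
    'Quality_Score': [0.80, 0.85, 0.70, 0.65],
})

baseline = 'GaussianCopulaSynthesizer'
result['Win'] = 0
for dataset in result['Dataset'].unique():
    score_baseline = result.loc[
        (result['Synthesizer'] == baseline) & (result['Dataset'] == dataset)
    ]['Quality_Score'].to_numpy()
    if score_baseline.size == 0:
        continue  # same guard as above: no baseline run for this dataset

    in_dataset = result['Dataset'] == dataset
    result.loc[in_dataset, 'Win'] = (
        result.loc[in_dataset, 'Quality_Score'] > score_baseline[0]
    ).astype(int)

# CTGANSynthesizer gets Win=1 on 'adult' (0.85 > 0.80) and Win=0 on 'census'.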
@@ -84,7 +90,7 @@ def _get_summarize_table(self, folder_to_results, folder_infos):
                 f' - # datasets: {folder_infos[folder]["# datasets"]}'
                 f' - sdgym version: {folder_infos[folder]["sdgym_version"]}'
             )
-            results = results.loc[results['Synthesizer'] != SYNTHESIZER_BASELINE]
+            results = results.loc[results['Synthesizer'] != self.baseline_synthesizer]
             column_data = results.groupby(['Synthesizer'])['Win'].sum()
             columns.append((date_obj, column_name, column_data))
@@ -107,9 +113,11 @@ def _get_column_name_infos(self, folder_to_results):
                 continue
 
             metainfo_info = self._load_yaml_file(folder, yaml_files[0])
-            num_datasets = results.loc[
-                results['Synthesizer'] == SYNTHESIZER_BASELINE, 'Dataset'
-            ].nunique()
+            baseline_mask = results['Synthesizer'] == self.baseline_synthesizer
+            if baseline_mask.any():
+                num_datasets = results.loc[baseline_mask, 'Dataset'].nunique()
+            else:
+                num_datasets = results['Dataset'].nunique()
             folder_to_info[folder] = {
                 'date': metainfo_info.get('starting_date')[:NUM_DIGITS_DATE],
                 'sdgym_version': metainfo_info.get('sdgym_version'),
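The new branch keeps the '# datasets' count meaningful for runs that never executed the baseline synthesizer: when no baseline rows exist, it counts unique datasets across all synthesizers instead of returning 0. A small illustration with an invented results frame:

import pandas as pd

# Hypothetical run without any baseline rows.
results = pd.DataFrame({
    'Synthesizer': ['CTGANSynthesizer', 'CTGANSynthesizer'],
    'Dataset': ['adult', 'census'],
})
baseline_mask = results['Synthesizer'] == 'GaussianCopulaSynthesizer'
assert not baseline_mask.any()  # the old code would have counted 0 datasets here
assert results['Dataset'].nunique() == 2  # the fallback still reports both datasets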
@@ -236,7 +244,8 @@ def all_runs_complete(self, folder_name):
 class LocalResultsHandler(ResultsHandler):
     """Results handler for local filesystem."""
 
-    def __init__(self, base_path):
+    def __init__(self, base_path, baseline_synthesizer=SYNTHESIZER_BASELINE):
+        super().__init__(baseline_synthesizer=baseline_synthesizer)
         self.base_path = base_path
 
     def list(self):
@@ -262,8 +271,12 @@ def load_synthesizer(self, file_path):
             return cloudpickle.load(f)
 
     def load_synthetic_data(self, file_path):
-        """Load synthetic data from a CSV file."""
-        return pd.read_csv(os.path.join(self.base_path, file_path))
+        """Load synthetic data from a CSV or ZIP file."""
+        full_path = os.path.join(self.base_path, file_path)
+        if full_path.endswith('.zip'):
+            return _read_zipped_data(full_path, modality='multi_table')
+
+        return pd.read_csv(full_path)
 
     def _get_results_files(self, folder_name, prefix, suffix):
         return [
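A hypothetical usage of the zip-aware loader; the base path and file names are placeholders, and _read_zipped_data is assumed to unpack the archive into per-table data for the multi-table modality:

handler = LocalResultsHandler('results/')

# Single-table output is still read as a plain CSV into a DataFrame.
single_table = handler.load_synthetic_data('run_1/synthetic_data.csv')

# Multi-table output arrives zipped and is routed to _read_zipped_data.
multi_table = handler.load_synthetic_data('run_1/synthetic_data.zip')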
@@ -287,7 +300,8 @@ def _load_yaml_file(self, folder_name, file_name):
 class S3ResultsHandler(ResultsHandler):
     """Results handler for AWS S3 storage."""
 
-    def __init__(self, path, s3_client):
+    def __init__(self, path, s3_client, baseline_synthesizer=SYNTHESIZER_BASELINE):
+        super().__init__(baseline_synthesizer=baseline_synthesizer)
         self.s3_client = s3_client
         self.bucket_name = path.split('/')[2]
         self.prefix = '/'.join(path.split('/')[3:]).rstrip('/') + '/'
@@ -374,10 +388,13 @@ def load_synthesizer(self, file_path):
 
     def load_synthetic_data(self, file_path):
         """Load synthetic data from S3."""
-        response = self.s3_client.get_object(
-            Bucket=self.bucket_name, Key=f'{self.prefix}{file_path}'
-        )
-        return pd.read_csv(io.BytesIO(response['Body'].read()))
+        key = f'{self.prefix}{file_path}'
+        response = self.s3_client.get_object(Bucket=self.bucket_name, Key=key)
+        body = response['Body'].read()
+        if file_path.endswith('.zip'):
+            return _read_zipped_data(io.BytesIO(body), modality='multi_table')
+
+        return pd.read_csv(io.BytesIO(body))
 
     def _get_results_files(self, folder_name, prefix, suffix):
         s3_prefix = f'{self.prefix}{folder_name}/'
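Since the constructor splits `path` on '/', it expects an 's3://bucket/prefix' style URI: index 2 yields the bucket name and the remainder becomes the key prefix. A hypothetical wiring, with placeholder bucket, prefix, and file names:

import boto3

s3_client = boto3.client('s3')
handler = S3ResultsHandler('s3://my-bucket/sdgym-runs', s3_client)
# handler.bucket_name == 'my-bucket'; handler.prefix == 'sdgym-runs/'

# Zipped multi-table data and plain CSVs are now both supported.
data = handler.load_synthetic_data('run_1/synthetic_data.zip')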
@@ -396,8 +413,8 @@ def _get_results(self, folder_name, file_names):
         for file_name in file_names:
             s3_key = f'{self.prefix}{folder_name}/{file_name}'
             response = self.s3_client.get_object(Bucket=self.bucket_name, Key=s3_key)
-            df = pd.read_csv(io.BytesIO(response['Body'].read()))
-            results.append(df)
+            result_df = pd.read_csv(io.BytesIO(response['Body'].read()))
+            results.append(result_df)
 
         return results