
Commit 11d231e

Add multi table support to ResultsExplorer (#505)
1 parent c13ed60 commit 11d231e

File tree

26 files changed: +412 -78 lines changed

sdgym/benchmark.py

Lines changed: 3 additions & 3 deletions
@@ -1832,7 +1832,7 @@ def benchmark_multi_table(
     output_destination=None,
     show_progress=False,
 ):
-    """Run the SDGym benchmark on single-table datasets.
+    """Run the SDGym benchmark on multi-table datasets.

     Args:
         synthesizers (list[string]):
@@ -1844,8 +1844,8 @@ def benchmark_multi_table(
             or ``create_synthesizer_variant``). Defaults to ``None``.
         sdv_datasets (list[str] or ``None``):
             Names of the SDV demo datasets to use for the benchmark. Defaults to
-            ``[adult, alarm, census, child, expedia_hotel_logs, insurance, intrusion, news,
-            covtype]``. Use ``None`` to disable using any sdv datasets.
+            ``[NBA, financial, Student_loan, Biodegradability, fake_hotels, restbase,
+            airbnb-simplified]``. Use ``None`` to disable using any sdv datasets.
         additional_datasets_folder (str or ``None``):
             The path to a folder (local or an S3 bucket). Datasets found in this folder are
             run in addition to the SDV datasets. If ``None``, no additional datasets are used.
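
The docstring fix above matters because the multi-table entry point previously advertised itself as single-table. A minimal usage sketch of the corrected API (hedged: it assumes `benchmark_multi_table` is importable from the package root and returns a results DataFrame like its single-table counterpart; the synthesizer and dataset names come from this commit's defaults and test fixtures):

from sdgym import benchmark_multi_table

# Benchmark two multi-table synthesizers on one of the new default datasets.
results = benchmark_multi_table(
    synthesizers=['HMASynthesizer', 'MultiTableUniformSynthesizer'],
    sdv_datasets=['fake_hotels'],
    show_progress=True,
)
print(results[['Synthesizer', 'Dataset', 'Quality_Score']])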

sdgym/datasets.py

Lines changed: 36 additions & 6 deletions
@@ -34,6 +34,40 @@ def _get_bucket_name(bucket):
     return bucket[len(S3_PREFIX) :] if bucket.startswith(S3_PREFIX) else bucket


+def _raise_dataset_not_found_error(
+    s3_client,
+    bucket_name,
+    dataset_name,
+    current_modality,
+    bucket,
+    modality,
+):
+    display_name = dataset_name
+    if isinstance(dataset_name, Path):
+        display_name = dataset_name.name
+
+    available_modalities = []
+    for other_modality in MODALITIES:
+        if other_modality == current_modality:
+            continue
+
+        other_prefix = f'{other_modality.lower()}/{dataset_name}/'
+        other_contents = _list_s3_bucket_contents(s3_client, bucket_name, other_prefix)
+        if other_contents:
+            available_modalities.append(other_modality)
+
+    if available_modalities:
+        available_list = ', '.join(sorted(available_modalities))
+        raise ValueError(
+            f"Dataset '{display_name}' not found in bucket '{bucket}' "
+            f"for modality '{modality}'. It is available under modality: '{available_list}'."
+        )
+    else:
+        raise ValueError(
+            f"Dataset '{display_name}' not found in bucket '{bucket}' for modality '{modality}'."
+        )
+
+
 def _download_dataset(
     modality,
     dataset_name,
@@ -53,12 +87,8 @@ def _download_dataset(

     contents = _list_s3_bucket_contents(s3_client, bucket_name, prefix)
     if not contents:
-        display_name = dataset_name
-        if isinstance(dataset_name, Path):
-            display_name = dataset_name.name
-
-        raise ValueError(
-            f"Dataset '{display_name}' not found in bucket '{bucket}' for modality '{modality}'."
+        _raise_dataset_not_found_error(
+            s3_client, bucket_name, dataset_name, modality, bucket, modality
         )

     for obj in contents:
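
With this refactor, a failed download probes the other modality folders before raising, so requesting a multi-table dataset as 'single_table' produces an error that names the correct modality. A standalone sketch of that lookup (hypothetical helper name `find_dataset_modalities`; it assumes the bucket lays datasets out as '<modality>/<dataset>/...', matching the prefixes built above, and the modality list is illustrative):

import boto3

ASSUMED_MODALITIES = ('single_table', 'multi_table', 'sequential')  # illustrative values

def find_dataset_modalities(bucket_name, dataset_name):
    """Return the modality folders in the bucket that contain the dataset."""
    s3_client = boto3.client('s3')
    found = []
    for modality in ASSUMED_MODALITIES:
        prefix = f'{modality}/{dataset_name}/'
        # A single key under the prefix is enough to prove the dataset exists there.
        response = s3_client.list_objects_v2(Bucket=bucket_name, Prefix=prefix, MaxKeys=1)
        if response.get('KeyCount', 0) > 0:
            found.append(modality)

    return found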

sdgym/result_explorer/result_explorer.py

Lines changed: 51 additions & 18 deletions
@@ -2,10 +2,14 @@

 import os

-from sdgym.benchmark import DEFAULT_SINGLE_TABLE_DATASETS
 from sdgym.datasets import load_dataset
-from sdgym.result_explorer.result_handler import LocalResultsHandler, S3ResultsHandler
+from sdgym.result_explorer.result_handler import (
+    SYNTHESIZER_BASELINE,
+    LocalResultsHandler,
+    S3ResultsHandler,
+)
 from sdgym.s3 import _get_s3_client, is_s3_path
+from sdgym.synthesizers.base import _validate_modality


 def _validate_local_path(path):
@@ -14,20 +18,51 @@ def _validate_local_path(path):
         raise ValueError(f"The provided path '{path}' is not a valid local directory.")


+_BASELINE_BY_MODALITY = {
+    'single_table': SYNTHESIZER_BASELINE,
+    'multi_table': 'MultiTableUniformSynthesizer',
+}
+
+
+def _resolve_effective_path(path, modality):
+    """Append the modality folder to the given base path if provided."""
+    # Avoid double-appending if already included
+    if str(path).rstrip('/').endswith(('/' + modality, modality)):
+        return path
+
+    if is_s3_path(path):
+        return path.rstrip('/') + '/' + modality
+
+    return os.path.join(path, modality)
+
+
 class ResultsExplorer:
     """Explorer for SDGym benchmark results, supporting both local and S3 storage."""

-    def __init__(self, path, aws_access_key_id=None, aws_secret_access_key=None):
+    def _create_results_handler(self, original_path, effective_path):
+        """Create the appropriate results handler for local or S3 storage."""
+        baseline_synthesizer = _BASELINE_BY_MODALITY.get(self.modality, SYNTHESIZER_BASELINE)
+        if is_s3_path(original_path):
+            s3_client = _get_s3_client(
+                original_path, self.aws_access_key_id, self.aws_secret_access_key
+            )
+            return S3ResultsHandler(
+                effective_path, s3_client, baseline_synthesizer=baseline_synthesizer
+            )
+
+        _validate_local_path(effective_path)
+        return LocalResultsHandler(effective_path, baseline_synthesizer=baseline_synthesizer)
+
+    def __init__(
+        self, path, modality='single_table', aws_access_key_id=None, aws_secret_access_key=None
+    ):
         self.path = path
+        _validate_modality(modality)
+        self.modality = modality.lower()
         self.aws_access_key_id = aws_access_key_id
         self.aws_secret_access_key = aws_secret_access_key
-
-        if is_s3_path(path):
-            s3_client = _get_s3_client(path, aws_access_key_id, aws_secret_access_key)
-            self._handler = S3ResultsHandler(path, s3_client)
-        else:
-            _validate_local_path(path)
-            self._handler = LocalResultsHandler(path)
+        effective_path = _resolve_effective_path(path, self.modality)
+        self._handler = self._create_results_handler(path, effective_path)

     def list(self):
         """List all runs available in the results directory."""
@@ -37,7 +72,11 @@ def _get_file_path(self, results_folder_name, dataset_name, synthesizer_name, fi
         """Validate access to the synthesizer or synthetic data file."""
         end_filename = f'{synthesizer_name}'
         if file_type == 'synthetic_data':
-            end_filename += '_synthetic_data.csv'
+            # Multi-table synthetic data is zipped (multiple CSVs), single table is CSV
+            if self.modality == 'multi_table':
+                end_filename += '_synthetic_data.zip'
+            else:
+                end_filename += '_synthetic_data.csv'
         elif file_type == 'synthesizer':
             end_filename += '.pkl'

@@ -62,14 +101,8 @@ def load_synthetic_data(self, results_folder_name, dataset_name, synthesizer_nam

     def load_real_data(self, dataset_name):
         """Load the real data for a given dataset."""
-        if dataset_name not in DEFAULT_SINGLE_TABLE_DATASETS:
-            raise ValueError(
-                f"Dataset '{dataset_name}' is not a SDGym dataset. "
-                'Please provide a valid dataset name.'
-            )
-
         data, _ = load_dataset(
-            modality='single_table',
+            modality=self.modality,
             dataset=dataset_name,
             aws_access_key_id=self.aws_access_key_id,
             aws_secret_access_key=self.aws_secret_access_key,
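
Taken together, the explorer now namespaces results by modality (via _resolve_effective_path), validates the requested modality, and picks the matching baseline synthesizer and synthetic-data file format. A hedged usage sketch (the results directory, run-folder name, and dataset are placeholders; load_synthetic_data's argument order comes from the hunk header above):

from sdgym.result_explorer.result_explorer import ResultsExplorer

# Runs are read from 'results/multi_table/' thanks to _resolve_effective_path.
explorer = ResultsExplorer('results/', modality='multi_table')
print(explorer.list())

# For multi_table, this resolves to '<synthesizer>_synthetic_data.zip'.
synthetic_data = explorer.load_synthetic_data(
    'SDGym_results_01_01_2025', 'fake_hotels', 'HMASynthesizer'
)
real_data = explorer.load_real_data('fake_hotels')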

sdgym/result_explorer/result_handler.py

Lines changed: 32 additions & 15 deletions
@@ -11,6 +11,8 @@
 import yaml
 from botocore.exceptions import ClientError

+from sdgym._dataset_utils import _read_zipped_data
+
 SYNTHESIZER_BASELINE = 'GaussianCopulaSynthesizer'
 RESULTS_FOLDER_PREFIX = 'SDGym_results_'
 metainfo_PREFIX = 'metainfo'
@@ -22,6 +24,9 @@
 class ResultsHandler(ABC):
     """Abstract base class for handling results storage and retrieval."""

+    def __init__(self, baseline_synthesizer=SYNTHESIZER_BASELINE):
+        self.baseline_synthesizer = baseline_synthesizer or SYNTHESIZER_BASELINE
+
     @abstractmethod
     def list(self):
         """List all runs in the results directory."""
@@ -59,7 +64,8 @@ def _compute_wins(self, result):
         result['Win'] = 0
         for dataset in datasets:
             score_baseline = result.loc[
-                (result['Synthesizer'] == SYNTHESIZER_BASELINE) & (result['Dataset'] == dataset)
+                (result['Synthesizer'] == self.baseline_synthesizer)
+                & (result['Dataset'] == dataset)
             ]['Quality_Score'].to_numpy()
             if score_baseline.size == 0:
                 continue
@@ -84,7 +90,7 @@ def _get_summarize_table(self, folder_to_results, folder_infos):
                 f' - # datasets: {folder_infos[folder]["# datasets"]}'
                 f' - sdgym version: {folder_infos[folder]["sdgym_version"]}'
             )
-            results = results.loc[results['Synthesizer'] != SYNTHESIZER_BASELINE]
+            results = results.loc[results['Synthesizer'] != self.baseline_synthesizer]
             column_data = results.groupby(['Synthesizer'])['Win'].sum()
             columns.append((date_obj, column_name, column_data))

@@ -107,9 +113,11 @@ def _get_column_name_infos(self, folder_to_results):
                 continue

             metainfo_info = self._load_yaml_file(folder, yaml_files[0])
-            num_datasets = results.loc[
-                results['Synthesizer'] == SYNTHESIZER_BASELINE, 'Dataset'
-            ].nunique()
+            baseline_mask = results['Synthesizer'] == self.baseline_synthesizer
+            if baseline_mask.any():
+                num_datasets = results.loc[baseline_mask, 'Dataset'].nunique()
+            else:
+                num_datasets = results['Dataset'].nunique()
             folder_to_info[folder] = {
                 'date': metainfo_info.get('starting_date')[:NUM_DIGITS_DATE],
                 'sdgym_version': metainfo_info.get('sdgym_version'),
@@ -236,7 +244,8 @@ def all_runs_complete(self, folder_name):
 class LocalResultsHandler(ResultsHandler):
     """Results handler for local filesystem."""

-    def __init__(self, base_path):
+    def __init__(self, base_path, baseline_synthesizer=SYNTHESIZER_BASELINE):
+        super().__init__(baseline_synthesizer=baseline_synthesizer)
         self.base_path = base_path

     def list(self):
@@ -262,8 +271,12 @@ def load_synthesizer(self, file_path):
             return cloudpickle.load(f)

     def load_synthetic_data(self, file_path):
-        """Load synthetic data from a CSV file."""
-        return pd.read_csv(os.path.join(self.base_path, file_path))
+        """Load synthetic data from a CSV or ZIP file."""
+        full_path = os.path.join(self.base_path, file_path)
+        if full_path.endswith('.zip'):
+            return _read_zipped_data(full_path, modality='multi_table')
+
+        return pd.read_csv(full_path)

     def _get_results_files(self, folder_name, prefix, suffix):
         return [
@@ -287,7 +300,8 @@ def _load_yaml_file(self, folder_name, file_name):
 class S3ResultsHandler(ResultsHandler):
     """Results handler for AWS S3 storage."""

-    def __init__(self, path, s3_client):
+    def __init__(self, path, s3_client, baseline_synthesizer=SYNTHESIZER_BASELINE):
+        super().__init__(baseline_synthesizer=baseline_synthesizer)
         self.s3_client = s3_client
         self.bucket_name = path.split('/')[2]
         self.prefix = '/'.join(path.split('/')[3:]).rstrip('/') + '/'
@@ -374,10 +388,13 @@ def load_synthesizer(self, file_path):

     def load_synthetic_data(self, file_path):
         """Load synthetic data from S3."""
-        response = self.s3_client.get_object(
-            Bucket=self.bucket_name, Key=f'{self.prefix}{file_path}'
-        )
-        return pd.read_csv(io.BytesIO(response['Body'].read()))
+        key = f'{self.prefix}{file_path}'
+        response = self.s3_client.get_object(Bucket=self.bucket_name, Key=key)
+        body = response['Body'].read()
+        if file_path.endswith('.zip'):
+            return _read_zipped_data(io.BytesIO(body), modality='multi_table')
+
+        return pd.read_csv(io.BytesIO(body))

     def _get_results_files(self, folder_name, prefix, suffix):
         s3_prefix = f'{self.prefix}{folder_name}/'
@@ -396,8 +413,8 @@ def _get_results(self, folder_name, file_names):
         for file_name in file_names:
             s3_key = f'{self.prefix}{folder_name}/{file_name}'
             response = self.s3_client.get_object(Bucket=self.bucket_name, Key=s3_key)
-            df = pd.read_csv(io.BytesIO(response['Body'].read()))
-            results.append(df)
+            result_df = pd.read_csv(io.BytesIO(response['Body'].read()))
+            results.append(result_df)

         return results
sdgym/run_benchmark/upload_benchmark_results.py

Lines changed: 1 addition & 0 deletions
@@ -116,6 +116,7 @@ def upload_results(
     run_date = folder_infos['date']
     result_explorer = ResultsExplorer(
         OUTPUT_DESTINATION_AWS,
+        modality='single_table',
         aws_access_key_id=aws_access_key_id,
         aws_secret_access_key=aws_secret_access_key,
     )
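
Passing modality='single_table' here is currently a no-op, since that is the ResultsExplorer default, but it pins the public leaderboard upload to single-table results in case the default ever changes.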
(new file, path not shown in this view)

Lines changed: 2 additions & 0 deletions

@@ -0,0 +1,2 @@
+Synthesizer,Dataset,Dataset_Size_MB,Train_Time,Peak_Memory_MB,Synthesizer_Size_MB,Sample_Time,Evaluate_Time,Diagnostic_Score,Quality_Score
+HMASynthesizer,fake_hotels,0.048698,22.852492,33.315142,0.988611,2.723049,0.082362,1.0,0.7353482911012336
(new file, path not shown in this view)

Lines changed: 2 additions & 0 deletions

@@ -0,0 +1,2 @@
+Synthesizer,Dataset,Dataset_Size_MB,Train_Time,Peak_Memory_MB,Synthesizer_Size_MB,Sample_Time,Evaluate_Time,Diagnostic_Score,Quality_Score
+MultiTableUniformSynthesizer,fake_hotels,0.048698,0.201284,0.851853,0.109464,0.02749,0.081629,0.9122678149273894,0.5962941240006595
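
The two two-line CSVs above appear to be test fixtures: expected benchmark result rows for HMASynthesizer and for the new MultiTableUniformSynthesizer baseline on fake_hotels, the same baseline name registered in _BASELINE_BY_MODALITY.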

0 commit comments