Bz/tf #104
Conversation
📝 Walkthrough
This PR introduces a comprehensive membership inference attack (MIA) framework for tabular differentially private models. The implementation includes three new Python modules providing data loading utilities (…).

Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes

Pre-merge checks and finishing touches
❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (1 passed)
✨ Finishing touches
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
Actionable comments posted: 17
Note
Due to the large number of review comments, Critical and Major severity comments were prioritized as inline comments.
🟡 Minor comments (4)
src/midst_toolkit/attacks/tf/data_utils.py-157-165 (1)
157-165: Deduplication key validation happens after `drop_duplicates` is called.
The validation for missing keys (lines 162-165) occurs after `drop_duplicates` is already called (lines 158-159). If keys are missing, `drop_duplicates` will raise a `KeyError` before your descriptive `ValueError` is reached.

+ # Ensure all keys for deduplication exist in both DataFrames before deduplication
+ missing_keys_merge = [key for key in keys_for_deduplication if key not in df_merge.columns]
+ missing_keys_challenge = [key for key in keys_for_deduplication if key not in df_challenge.columns]
+ if missing_keys_merge or missing_keys_challenge:
+     raise ValueError(f"Missing columns for deduplication: {missing_keys_merge + missing_keys_challenge}")
+
  # Deduplicate the datasets once
  df_merge = df_merge.drop_duplicates(subset=keys_for_deduplication)
  df_challenge = df_challenge.drop_duplicates(subset=keys_for_deduplication)
-
- # Ensure all keys for deduplication exist in both DataFrames
- missing_keys_merge = [key for key in keys_for_deduplication if key not in df_merge.columns]
- missing_keys_challenge = [key for key in keys_for_deduplication if key not in df_challenge.columns]
- if missing_keys_merge or missing_keys_challenge:
-     raise ValueError(f"Missing columns for deduplication: {missing_keys_merge + missing_keys_challenge}")

src/midst_toolkit/attacks/tf/data_utils.py-174-179 (1)
174-179: Potential `ValueError` on edge case.
If all FPR values are >= `max_fpr`, `tpr[fpr < max_fpr]` will be an empty array and `max()` will raise a `ValueError`.

  def get_tpr_at_fpr(true_membership: list[int], predictions: list[float], max_fpr: float = 0.1) -> float:
      """
      Calculates the best True Positive Rate when the False Positive Rate is at most `max_fpr`.
      """
      fpr, tpr, _ = roc_curve(true_membership, predictions)
-     return max(tpr[fpr < max_fpr])
+     valid_tpr = tpr[fpr <= max_fpr]
+     if len(valid_tpr) == 0:
+         return 0.0
+     return float(max(valid_tpr))

src/midst_toolkit/attacks/tf/classifcation.py-133-134 (1)
133-134: Condition `x_val is not None` is always true after tensor conversion.
On line 116, `x_val` is unconditionally converted to a tensor via `torch.tensor(x_val, ...)`. If the original `x_val` parameter was `None`, this would raise an error before reaching line 133. The check should happen before tensor conversion.

+ has_validation = x_val is not None
  x_train = torch.tensor(x_train, dtype=torch.float32).to(device)
  y_train = torch.tensor(x_train_label, dtype=torch.float32).to(device)
- x_val = torch.tensor(x_val, dtype=torch.float32).to(device)
- y_test = torch.tensor(x_val_label, dtype=torch.float32).to(device)
+ if has_validation:
+     x_val = torch.tensor(x_val, dtype=torch.float32).to(device)
+     y_val = torch.tensor(x_val_label, dtype=torch.float32).to(device)

Then use `has_validation` in the condition on line 133.
Committable suggestion skipped: line range outside the PR's diff.
tests/integration/attacks/tf/test_tf_attack.py-51-53 (1)
51-53: Use explicit key access instead of dict unpacking for clarity.
The dictionaries returned have keys `"max_tpr"` and `"roc_auc"`. While Python 3.7+ guarantees dictionary insertion order, explicit key access makes the code clearer and more maintainable.

- tpr_at_fpr_train, roc_auc_train = mia_performance_train.values()
- tpr_at_fpr_val, roc_auc_val = mia_performance_val.values()
- tpr_at_fpr_test, roc_auc_test = mia_performance_test.values()
+ tpr_at_fpr_train = mia_performance_train["max_tpr"]
+ roc_auc_train = mia_performance_train["roc_auc"]
+ tpr_at_fpr_val = mia_performance_val["max_tpr"]
+ roc_auc_val = mia_performance_val["roc_auc"]
+ tpr_at_fpr_test = mia_performance_test["max_tpr"]
+ roc_auc_test = mia_performance_test["roc_auc"]
🧹 Nitpick comments (15)
src/midst_toolkit/attacks/tf/data_utils.py (2)
1-6: Excessive linting suppressions reduce code quality.
Suppressing D102, D105, D103, D200 (docstring rules) and multiple mypy error codes across the entire file is a significant code smell. As per the PR description, "fix typing errors" is listed as a next step — consider addressing these rather than suppressing them.
59-59: Unused `verbose` parameter.
The `verbose` parameter is declared but never used in the function body. Either implement verbose logging or remove the parameter.

src/midst_toolkit/attacks/tf/classifcation.py (2)
59-62: Remove dead code: `x = x.float()` is unused.
The variable `x` is reassigned on line 60 but never used afterward. This appears to be leftover code.

  def custom_loss_fn(model, x, y):
      confidences = model(x)
-     x = x.float()
      y = y.float()
      return nn.BCELoss()(confidences, y.unsqueeze(1))
114-117: Inconsistent variable naming: `y_test` vs `x_val`.
Line 117 uses `y_test` for labels corresponding to `x_val`, mixing "test" and "val" terminology. This is confusing given the function parameters use `x_val` and `x_val_label`.

  x_train = torch.tensor(x_train, dtype=torch.float32).to(device)
  y_train = torch.tensor(x_train_label, dtype=torch.float32).to(device)
  x_val = torch.tensor(x_val, dtype=torch.float32).to(device)
- y_test = torch.tensor(x_val_label, dtype=torch.float32).to(device)
+ y_val = torch.tensor(x_val_label, dtype=torch.float32).to(device)

Then update references on lines 134, 150 accordingly.
src/midst_toolkit/attacks/tf/tf_attack.py (11)
1-6: Clean up unused lint suppressions.
Static analysis indicates the `noqa` directive on line 1 is unused. Since the PR author already plans to "fix typing errors", consider removing unnecessary suppressions once proper type annotations are added rather than blanket-disabling mypy checks.
37-82: Several issues in the `mixed_loss` function.

- Unused parameter: `no_mean` is never used (Line 44).
- Redundant device assignment: `device` is assigned on line 51 but immediately overwritten on line 61.
- Typo: "defeualt" → "default" (Line 63).
- Redundant conditional: The check `if not return_random:` on line 74 is always `True`, since we already returned on line 67 when `return_random=True`.

  def mixed_loss(
      diffusion,
      x,
      out_dict,
      noise=None,
      t=None,
      return_random=False,
-     no_mean=False,
      parallel_batch=None,
      addt_value=None,
  ):
      x_num = x[:, : diffusion.num_numerical_features]
      x_cat = x[:, diffusion.num_numerical_features :]
-     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     device = x.device
      noise_tensor = torch.tensor(noise, device=device, dtype=torch.float)
      batch_noise = noise_tensor.repeat(x_num.shape[0], 1)
      # there is actually no categorical classes, as we have examined the DM, so we just ignore x_cat here and later
      x_num = x_num.repeat_interleave(parallel_batch, dim=0)
      x_cat = x_cat.repeat_interleave(parallel_batch, dim=0)
      b = x_num.shape[0]
-     device = x.device
      if t is None:
-         # the defeualt is uniform sampling
+         # the default is uniform sampling
          t, pt = diffusion.sample_time(b, device)
      if return_random:
          return noise, t, pt
      additional_t = t * 0 + addt_value
      # forward x_num_t with (t+additional_t) timestamps
      x_num_t = diffusion.gaussian_q_sample(x_num, t + additional_t, noise=batch_noise)
-     if not return_random:
-         current_t = t
-         # predict noises with t timestamps
-         predicted_noise = diffusion._denoise_fn(x_num_t, current_t, **out_dict)
-         current_loss = diffusion._gaussian_loss(predicted_noise, batch_noise, batch_noise, current_t, batch_noise)
-         transformed_current_loss = current_loss.reshape(-1, parallel_batch)
+     current_t = t
+     # predict noises with t timestamps
+     predicted_noise = diffusion._denoise_fn(x_num_t, current_t, **out_dict)
+     current_loss = diffusion._gaussian_loss(predicted_noise, batch_noise, batch_noise, current_t, batch_noise)
+     transformed_current_loss = current_loss.reshape(-1, parallel_batch)
      return transformed_current_loss * 0, transformed_current_loss
136-142: Prefix unused variables with underscore.
`label_encoders` and `column_orders` are unpacked but never used. Per convention, prefix with underscore to indicate intentional non-use.

- dataset, label_encoders, column_orders = Dataset.from_df(
+ dataset, _label_encoders, _column_orders = Dataset.from_df(
207-229: Misplaced docstring.
The docstring is placed in the middle of the function after executable code (lines 201-205). It should be the first statement after the function signature to be recognized by documentation tools and IDEs.
Move the docstring to immediately after line 198 (`def get_score(...):`), before any executable statements, as in the sketch below.
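A minimal illustration of the placement being asked for; the abbreviated signature and docstring wording are placeholders, not the actual `get_score` implementation:

```python
def get_score(data_path, save_dir, input_noise):
    """Compute per-record attack scores for one target model.

    The docstring must be the first statement in the function body so that
    help(), IDEs, and documentation tools can pick it up.
    """
    # Executable statements follow the docstring, never precede it.
    scores = []
    return scores
```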
247-252: Remove debug print and simplify loop logic.

- Debug print: `print(iter_max)` should be removed or replaced with proper logging.
- Confusing loop: The `assert iter_max == 1` followed by `while iter_id < iter_max:` means the loop always runs exactly once. Consider simplifying to remove the loop entirely, or documenting why this structure exists for future extensibility.

- print(iter_max)
  iter_max = iter_max // batch_size
  return_res = torch.zeros([batch_size, parallel_batch])
  assert iter_max == 1
- iter_id = 0
- while iter_id < iter_max:
+ # Process single batch (currently only supports iter_max == 1)
+ x, out_dict = next(train_loader)
  ...
- iter_id += 1
236-236: Prefix unused variable with underscore.
`challenge_dataset` is unpacked but never used.

- train_loader, iter_max, challenge_dataset = train_loader_list[loader_count]
+ train_loader, iter_max, _challenge_dataset = train_loader_list[loader_count]
261-269: Prefix unused variables with underscore.
`noise` and `pt` are unpacked but never used in this context.

- noise, t_cur, pt = mixed_loss(
+ _noise, t_cur, _pt = mixed_loss(
309-316: Hardcoded deduplication keys reduce reusability.
The keys `["trans_id", "balance"]` are hardcoded in both `prepare_data_for_attack` calls. Consider making these configurable via a parameter, for example as sketched below.
479-479: Document magic number for noise dimension.
The noise dimension `size=8` is hardcoded without explanation. Consider extracting this to a named constant or parameter with documentation explaining why 8 is the appropriate value.

+ NOISE_DIMENSION = 8  # Must match the diffusion model's expected noise dimension
- input_noise: list[list[float]] = [np.random.normal(size=8).tolist() for _ in range(num_noise_per_time_step)]
+ input_noise: list[list[float]] = [np.random.normal(size=NOISE_DIMENSION).tolist() for _ in range(num_noise_per_time_step)]
187-198: Unused `phase` parameter.
The `phase` parameter is accepted but never used in the function body. Either remove it or implement the intended behavior.

  def get_score(
      data_path,
      save_dir,
      input_noise,
      type="tabddpm",
-     phase=None,
      challenge_name=None,
      batch_size=None,
      parallel_batch=None,
      addt_value=None,
      t_value=None,
  ):

If `phase` is intended for future use, add a TODO comment or raise `NotImplementedError` when a non-None value is passed.
191-191: Avoid shadowing built-in `type`.
Using `type` as a parameter name shadows Python's built-in `type()` function, which can cause subtle bugs if the built-in is needed within this function.

  def get_score(
      data_path,
      save_dir,
      input_noise,
-     type="tabddpm",
+     model_type="tabddpm",
      ...
  ):
-     if type == "tabddpm":
+     if model_type == "tabddpm":
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
⛔ Files ignored due to path filters (49)
All of the following test assets under tests/integration/attacks/tf/assets/tabddpm_models/ are excluded by the `!**/*.pkl` and `!**/*.csv` path filters:
- tabddpm_1: None_trans_ckpt.pkl, challenge_label.csv, challenge_with_id.csv, cluster_ckpt.pkl, data_for_training_MIA.csv, predictions_test_2.csv, predictions_test_222.csv, train_with_id.csv, trans_label_encoders.pkl, trans_synthetic.csv
- tabddpm_2: None_trans_ckpt.pkl, challenge_label.csv, challenge_with_id.csv, cluster_ckpt.pkl, data_for_training_MIA.csv, predictions_test_2.csv, predictions_test_222.csv, train_with_id.csv, trans_label_encoders.pkl, trans_synthetic.csv
- tabddpm_3: None_trans_ckpt.pkl, challenge_label.csv, challenge_with_id.csv, cluster_ckpt.pkl, data_for_validating_MIA.csv, predictions_test_2.csv, predictions_test_222.csv, train_with_id.csv, trans_label_encoders.pkl, trans_synthetic.csv
- tabddpm_4: None_trans_ckpt.pkl, challenge_label.csv, challenge_with_id.csv, cluster_ckpt.pkl, data_for_validating_MIA.csv, predictions_test_2.csv, predictions_test_222.csv, train_with_id.csv, trans_label_encoders.pkl, trans_synthetic.csv
- tabddpm_5: None_trans_ckpt.pkl, challenge_label.csv, challenge_with_id.csv, cluster_ckpt.pkl, predictions_test_2.csv, predictions_test_222.csv, train_with_id.csv, trans_label_encoders.pkl, trans_synthetic.csv
📒 Files selected for processing (22)
- src/midst_toolkit/attacks/tf/classifcation.py (1 hunks)
- src/midst_toolkit/attacks/tf/data_utils.py (1 hunks)
- src/midst_toolkit/attacks/tf/tf_attack.py (1 hunks)
- tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/trans_domain.json (1 hunks)
- tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/updated_config.json (1 hunks)
- tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/workspace/train_1/args (1 hunks)
- tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/trans_domain.json (1 hunks)
- tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/updated_config.json (1 hunks)
- tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/workspace/train_1/args (1 hunks)
- tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/trans_domain.json (1 hunks)
- tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/updated_config.json (1 hunks)
- tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/workspace/train_1/args (1 hunks)
- tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/trans_domain.json (1 hunks)
- tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/updated_config.json (1 hunks)
- tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/workspace/train_1/args (1 hunks)
- tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/trans_domain.json (1 hunks)
- tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/updated_config.json (1 hunks)
- tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/workspace/train_1/args (1 hunks)
- tests/integration/attacks/tf/data_configs/dataset_meta.json (1 hunks)
- tests/integration/attacks/tf/data_configs/trans.json (1 hunks)
- tests/integration/attacks/tf/data_configs/trans_domain.json (1 hunks)
- tests/integration/attacks/tf/test_tf_attack.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (13)
tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/workspace/train_1/args (2)
tests/unit/attacks/ensemble/test_shadow_model_utils.py (1)
test_save_additional_tabddpm_config(19-54)
src/midst_toolkit/attacks/ensemble/shadow_model_utils.py (1)
save_additional_tabddpm_config(36-76)
tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/updated_config.json (2)
src/midst_toolkit/attacks/ensemble/shadow_model_utils.py (1)
save_additional_tabddpm_config(36-76)
src/midst_toolkit/models/clavaddpm/model.py (1)
DiffusionParameters(19-31)
tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/updated_config.json (3)
tests/integration/attacks/ensemble/test_shadow_model_training.py (1)
test_train_and_fine_tune_tabddpm(135-187)
tests/unit/attacks/ensemble/test_shadow_model_utils.py (1)
test_save_additional_tabddpm_config(19-54)
src/midst_toolkit/attacks/ensemble/shadow_model_utils.py (1)
save_additional_tabddpm_config(36-76)
tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/workspace/train_1/args (1)
src/midst_toolkit/attacks/ensemble/shadow_model_utils.py (1)
save_additional_tabddpm_config(36-76)
tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/workspace/train_1/args (3)
tests/integration/attacks/ensemble/test_shadow_model_training.py (1)
test_train_and_fine_tune_tabddpm(135-187)
tests/unit/attacks/ensemble/test_shadow_model_utils.py (1)
test_save_additional_tabddpm_config(19-54)
src/midst_toolkit/attacks/ensemble/shadow_model_utils.py (1)
save_additional_tabddpm_config(36-76)
tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/updated_config.json (2)
tests/unit/attacks/ensemble/test_shadow_model_utils.py (1)
test_save_additional_tabddpm_config(19-54)
src/midst_toolkit/attacks/ensemble/shadow_model_utils.py (1)
save_additional_tabddpm_config(36-76)
tests/integration/attacks/tf/data_configs/trans_domain.json (1)
tests/integration/attacks/ensemble/test_shadow_model_training.py (1)
test_train_and_fine_tune_tabddpm(135-187)
tests/integration/attacks/tf/data_configs/trans.json (2)
tests/integration/attacks/ensemble/test_shadow_model_training.py (1)
test_train_and_fine_tune_tabddpm(135-187)
src/midst_toolkit/attacks/ensemble/blending.py (1)
__init__(27-66)
tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/updated_config.json (3)
tests/integration/attacks/ensemble/test_shadow_model_training.py (1)
test_train_and_fine_tune_tabddpm(135-187)
tests/unit/attacks/ensemble/test_shadow_model_utils.py (1)
test_save_additional_tabddpm_config(19-54)
src/midst_toolkit/attacks/ensemble/shadow_model_utils.py (1)
save_additional_tabddpm_config(36-76)
tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/trans_domain.json (1)
tests/integration/attacks/ensemble/test_shadow_model_training.py (1)
test_train_and_fine_tune_tabddpm(135-187)
tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/updated_config.json (3)
tests/integration/attacks/ensemble/test_shadow_model_training.py (1)
test_train_and_fine_tune_tabddpm(135-187)
tests/unit/attacks/ensemble/test_shadow_model_utils.py (1)
test_save_additional_tabddpm_config(19-54)
src/midst_toolkit/attacks/ensemble/shadow_model_utils.py (1)
save_additional_tabddpm_config(36-76)
tests/integration/attacks/tf/test_tf_attack.py (2)
src/midst_toolkit/attacks/tf/tf_attack.py (1)
tf_attack(459-559)
src/midst_toolkit/common/random.py (2)
set_all_random_seeds(11-55)
unset_all_random_seeds(58-67)
src/midst_toolkit/attacks/tf/data_utils.py (3)
src/midst_toolkit/models/clavaddpm/dataset.py (1)
Dataset(77-397)
src/midst_toolkit/models/clavaddpm/enumerations.py (1)
Normalization(58-63)
src/midst_toolkit/models/clavaddpm/data_loaders.py (1)
FastTensorDataLoader(473-537)
🪛 Ruff (0.14.7)
src/midst_toolkit/attacks/tf/tf_attack.py
1-1: Unused noqa directive (non-enabled: D102, D105, D103, D200, PLR0915)
Remove unused noqa directive
(RUF100)
44-44: Unused function argument: no_mean
(ARG001)
85-85: Avoid specifying long messages outside the exception class
(TRY003)
136-136: Unpacked variable label_encoders is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
136-136: Unpacked variable column_orders is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
192-192: Unused function argument: phase
(ARG001)
203-203: Avoid specifying long messages outside the exception class
(TRY003)
236-236: Unpacked variable challenge_dataset is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
261-261: Unpacked variable noise is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
261-261: Unpacked variable pt is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
src/midst_toolkit/attacks/tf/data_utils.py
1-1: Unused noqa directive (non-enabled: D102, D105, D103, D200)
Remove unused noqa directive
(RUF100)
59-59: Unused function argument: verbose
(ARG001)
75-75: Avoid specifying long messages outside the exception class
(TRY003)
95-95: Avoid specifying long messages outside the exception class
(TRY003)
101-101: Avoid specifying long messages outside the exception class
(TRY003)
143-143: Avoid specifying long messages outside the exception class
(TRY003)
165-165: Avoid specifying long messages outside the exception class
(TRY003)
190-190: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: run-code-check
- GitHub Check: unit-tests
- GitHub Check: integration-tests
🔇 Additional comments (16)
src/midst_toolkit/attacks/tf/data_utils.py (1)
218-261: LGTM!
The `FastTensorDataLoader` implementation correctly handles batching and shuffling. The pattern matches the reference implementation from `src/midst_toolkit/models/clavaddpm/data_loaders.py`.

tests/integration/attacks/tf/data_configs/trans.json (1)
1-50: LGTM!
This configuration uses relative paths consistently, making it portable across different environments, unlike some of the other tabddpm config files.
tests/integration/attacks/tf/data_configs/dataset_meta.json (1)
1-1: ✓ Test metadata structure is appropriate.
The JSON correctly describes a single-table dataset with no relationships, which aligns with the trans domain schema and test configuration patterns.
tests/integration/attacks/tf/data_configs/trans_domain.json (1)
1-1: ✓ Domain schema is consistent and well-formed.
The feature definitions (8 fields with appropriate continuous/discrete classifications) align with the dataset_meta.json and are consistent across all tabddpm model variants used in tests.
tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/trans_domain.json (1)
1-1: ✓ Consistent test asset.
Domain schema matches across all tabddpm variants, which is appropriate for uniform test configuration.
tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/trans_domain.json (1)
1-1: ✓ Consistent test asset.
Matches schema across all tabddpm variants for uniform testing.
tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/trans_domain.json (1)
1-1: ✓ Consistent test asset.
Matches schema across all tabddpm variants for uniform testing.
tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/trans_domain.json (1)
1-1: ✓ Consistent test asset.
Matches schema across all tabddpm variants for uniform testing.
tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/updated_config.json (2)
1-8: Verify absolute workspace_dir path will not break tests.
Line 5 contains an absolute path `/projects/midst-experiments/tabddpm_midst_toolkit/train/tabddpm_2/workspace` that does not exist in typical CI/test environments. This may cause test failures when the configuration is used to create directories or write outputs.

Check whether:
- The integration test overrides this path at runtime (e.g., using `save_additional_tabddpm_config`)
- The code gracefully handles missing parent directories
- Tests are expected to run in an environment where `/projects/` exists

The relative paths for `data_dir` and `test_data_dir` (lines 3, 7) are appropriate for test portability.
14-42: ✓ Hyperparameter choices are appropriate for fast integration testing.
Minimal values (iterations=2–3, batch_size=1, num_timesteps=3) prioritize test speed while maintaining realistic configuration structure. This aligns with typical test fixtures for ML frameworks.

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/updated_config.json (2)
tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/updated_config.json (2)
1-8: Verify absolute workspace_dir path will not break tests.
Line 5 contains an absolute path `/projects/midst-experiments/tabddpm_midst_toolkit/train/tabddpm_5/workspace` that does not exist in typical CI/test environments. This matches the same issue in `tabddpm_2/updated_config.json` and suggests a systematic configuration pattern.

Verify that:
- The integration test overrides workspace_dir paths at runtime (expected based on the `save_additional_tabddpm_config` utility)
- Absolute paths are intentionally used as templates and not written to during tests

The relative test data paths (lines 3, 7) are appropriate for portability.
14-42: ✓ Hyperparameter choices are appropriate for fast integration testing.
Configuration values match the appropriate minimal test settings (iterations=2–3, batch_size=1, num_timesteps=3), balancing test speed with realistic model structure.
tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/trans_domain.json (1)
1-1: LGTM!
The domain schema defines the expected field structure with appropriate size and type metadata for the tabddpm_4 model configuration.
tests/integration/attacks/tf/test_tf_attack.py (1)
20-28: Good use of deterministic seeding for reproducible tests.
The combination of `set_all_random_seeds` with `use_deterministic_torch_algos=True` and the `CUBLAS_WORKSPACE_CONFIG` environment variable ensures reproducible results, which is essential for testing ML code with fixed expected values.

src/midst_toolkit/attacks/tf/classifcation.py (1)
8-40: LGTM: MLP architecture is straightforward and well-documented.
The 3-layer MLP with tanh activations and sigmoid output is appropriate for binary classification in the attack context.
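As a point of reference, a hedged sketch of an MLP of this shape in PyTorch; the class and layer names are illustrative and the toolkit's `MLP` may differ in detail:

```python
import torch.nn as nn


class ThreeLayerMLP(nn.Module):
    """Binary classifier: two tanh hidden layers followed by a sigmoid output."""

    def __init__(self, input_dim: int, hidden_dim: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),  # outputs a membership probability in (0, 1)
        )

    def forward(self, x):
        return self.net(x)
```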
src/midst_toolkit/attacks/tf/tf_attack.py (1)
81-81: Clarify intent of returning zeroed tensor.The return statement returns
transformed_current_loss * 0as the first element, which is always a zero tensor. If this is intentional (e.g., for API compatibility or specific algorithm requirements), consider adding a comment explaining why. Otherwise, this may be leftover debugging code.
Outdated review comments (resolved):
- tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/workspace/train_1/args
- tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/updated_config.json
- tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/workspace/train_1/args
- tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/workspace/train_1/args
@coderabbitai full review

✅ Actions performed: Full review triggered.
Actionable comments posted: 4
♻️ Duplicate comments (7)
src/midst_toolkit/attacks/tf/data_utils.py (1)
138-171: Still only returning the last model's `challenge_label.csv`.
Inside `prepare_data_for_attack`, `df_challenge_labels` is reassigned on every iteration (Line 152) and never accumulated. With multiple `indices`, you return labels only from the last model, while `df_merge` and `df_challenge` are concatenated across all models. This is the same issue previously flagged.

If you intend to aggregate across models, collect labels in a list and concatenate:

  def prepare_data_for_attack(indices, model_type, models_base_dir, keys_for_deduplication):
@@
-     df_challenge_list = []
+     df_challenge_list = []
+     df_challenge_labels_list = []
@@
-     df_challenge_list.append(pd.read_csv(os.path.join(base_path, "challenge_with_id.csv")))
-     df_challenge_labels = pd.read_csv(os.path.join(base_path, "challenge_label.csv"))
+     df_challenge_list.append(pd.read_csv(os.path.join(base_path, "challenge_with_id.csv")))
+     df_challenge_labels_list.append(pd.read_csv(os.path.join(base_path, "challenge_label.csv")))
@@
-     df_challenge = pd.concat(df_challenge_list, ignore_index=True)
+     df_challenge = pd.concat(df_challenge_list, ignore_index=True)
+     df_challenge_labels = pd.concat(df_challenge_labels_list, ignore_index=True)
@@
-     return df_merge_without_challenge, df_challenge, df_challenge_labels
+     return df_merge_without_challenge, df_challenge, df_challenge_labels

If only the last model's labels are truly desired, add an explicit comment and enforce that `indices` has length 1 (see the sketch below).
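If that single-model behavior is intended, a minimal sketch of the guard being suggested; the error message and truncated body are illustrative:

```python
def prepare_data_for_attack(indices, model_type, models_base_dir, keys_for_deduplication):
    # Labels are intentionally taken from a single model's challenge_label.csv,
    # so reject multi-model calls instead of silently dropping labels.
    if len(indices) != 1:
        raise ValueError("prepare_data_for_attack currently expects exactly one model index.")
    ...
```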
tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/workspace/train_1/args (1)
2-8: Replace hardcoded absolute `workspace_dir` with a project-relative path.
`workspace_dir` is an absolute path (Line 5), which will fail on other machines and CI. For a test asset, it should be under the repo with a relative path, similar to `data_dir` and other configs.

- "workspace_dir": "/projects/midst-experiments/tabddpm_midst_toolkit/train/tabddpm_4/workspace",
+ "workspace_dir": "tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/workspace",

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/updated_config.json (1)
5-5: Hardcoded absolute path breaks CI/portability.
This issue was previously flagged. The `workspace_dir` contains an absolute path specific to a particular machine, which will fail in CI and for other developers.

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/workspace/train_1/args (1)
5-5: Hardcoded absolute path breaks CI/portability.
This issue was previously flagged. The `workspace_dir` contains an absolute path specific to a particular machine, which will fail in CI and for other developers.

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/workspace/train_1/args (1)
3-7: Avoid hardcoded absolute `workspace_dir` in test config.
`"workspace_dir"` is an absolute, developer/cluster-specific path and will not exist on other machines/CI. Use a project-relative path (like the other fields) or a placeholder that your test harness fills in.

- "workspace_dir": "/projects/midst-experiments/tabddpm_midst_toolkit/train/tabddpm_5/workspace",
+ "workspace_dir": "tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/workspace",

src/midst_toolkit/attacks/tf/tf_attack.py (2)
383-435: Use `addt_value_list` when building training/validation features; the current loop ignores it.
Here you still hardcode `for addt_value in [0]:`, so `addt_value_list` is effectively unused. At the same time, the feature matrix width is computed as `len(input_noise) * len(timesteps_list) * len(addt_value_list)`, meaning any `addt_value_list` with length > 1 leaves trailing zero columns, and the classifier input_dim is larger than the actually populated features.

Refactor the nested loop to iterate over `addt_value_list` and index feature blocks accordingly:

- t_value_count = 0
- for t_value in timesteps_list:
-     for addt_value in [0]:
+ feature_block_index = 0
+ for t_value in timesteps_list:
+     for addt_value in addt_value_list:
          if model_number in train_indices:
@@
-             x_train[
-                 samples_per_train_model * 2 * train_count : samples_per_train_model * 2 * (train_count + 1),
-                 t_value_count * num_noise_per_time_step : (t_value_count + 1) * num_noise_per_time_step,
-             ] = predictions.detach().squeeze().cpu().numpy()
+             start = feature_block_index * num_noise_per_time_step
+             end = (feature_block_index + 1) * num_noise_per_time_step
+             x_train[
+                 samples_per_train_model * 2 * train_count : samples_per_train_model * 2 * (train_count + 1),
+                 start:end,
+             ] = predictions.detach().squeeze().cpu().numpy()
@@
-             x_train_label[
-                 samples_per_train_model * 2 * train_count : samples_per_train_model * 2 * (train_count + 1)
-             ] = np.concatenate([np.zeros(samples_per_train_model), np.ones(samples_per_train_model)])
-             t_value_count += 1
+             x_train_label[
+                 samples_per_train_model * 2 * train_count : samples_per_train_model * 2 * (train_count + 1)
+             ] = np.concatenate([np.zeros(samples_per_train_model), np.ones(samples_per_train_model)])
+             feature_block_index += 1
@@
-             x_val[
-                 sample_per_val_model * 2 * val_count : sample_per_val_model * 2 * (val_count + 1),
-                 t_value_count * num_noise_per_time_step : (t_value_count + 1) * num_noise_per_time_step,
-             ] = predictions.detach().squeeze().cpu().numpy()
+             start = feature_block_index * num_noise_per_time_step
+             end = (feature_block_index + 1) * num_noise_per_time_step
+             x_val[
+                 sample_per_val_model * 2 * val_count : sample_per_val_model * 2 * (val_count + 1),
+                 start:end,
+             ] = predictions.detach().squeeze().cpu().numpy()
@@
-             x_val_label[sample_per_val_model * 2 * val_count : sample_per_val_model * 2 * (val_count + 1)] = (
-                 np.concatenate([np.zeros(sample_per_val_model), np.ones(sample_per_val_model)])
-             )
-             t_value_count += 1
+             x_val_label[sample_per_val_model * 2 * val_count : sample_per_val_model * 2 * (val_count + 1)] = (
+                 np.concatenate([np.zeros(sample_per_val_model), np.ones(sample_per_val_model)])
+             )
+             feature_block_index += 1

This way, `x_train`/`x_val` and the MLP's `input_dim` stay consistent for arbitrary `addt_value_list`.
503-528: Respect `addt_value_list`, avoid hardcoded `batch_size`, and guard min–max normalization against zero range.
Three issues remain in `tf_attack`'s scoring loop:

1. Hardcoded `batch_size = 200`: this magic number is baked into the function, making it harder to reuse in other scenarios or tests.
2. `addt_value_list` is ignored: the loop still uses `for addt_value in [0]:`, so callers can't vary `addt_value` despite it being a parameter.
3. Potential division by zero in min–max normalization: if all model outputs are identical (`min_output == max_output`), `(max_output - min_output)` is zero, leading to `nan`s and failing the subsequent assertion.

A possible fix:

-def tf_attack(
+def tf_attack(
@@
-     classifier_hidden_dim: int,
-     addt_value_list: list[int],
-     meta_dir: Path,
-) -> tuple[Any, Any, Any]:
+     classifier_hidden_dim: int,
+     addt_value_list: list[int],
+     meta_dir: Path,
+     batch_size: int = 200,
+) -> tuple[Any, Any, Any]:
@@
-     model_dir = tabddpm_data_dir / model_folder
-     model_path = model_dir / target_model_subdir
-     batch_size = 200
-     t_value_count = 0
+     model_dir = tabddpm_data_dir / model_folder
+     model_path = model_dir / target_model_subdir
      current_input = []
      for t_value in timesteps_list:
-         for addt_value in [0]:
+         for addt_value in addt_value_list:
              predictions: torch.Tensor = get_score(
@@
-             )
-             t_value_count += 1
-             current_input.append(predictions)
+             )
+             current_input.append(predictions)
@@
-     predictions = regression_model(predictions).detach().cpu().numpy()
-     # clip to [0, 1]
-     min_output, max_output = np.min(predictions), np.max(predictions)
-     predictions = (predictions - min_output) / (max_output - min_output)
+     predictions = regression_model(predictions).detach().cpu().numpy()
+     # Rescale to [0, 1] safely
+     min_output, max_output = np.min(predictions), np.max(predictions)
+     range_output = max_output - min_output
+     if range_output > 0:
+         predictions = (predictions - min_output) / range_output
+     else:
+         # All outputs identical; fall back to a neutral constant
+         predictions = np.full_like(predictions, 0.5)
      predictions = torch.tensor(predictions)

This keeps the existing integration test behavior (`batch_size` defaults to 200, `addt_value_list=[0]`) while making the function robust and reusable.
🧹 Nitpick comments (8)
src/midst_toolkit/attacks/tf/data_utils.py (3)
1-18: Tighten lint/typing pragmas and clean leftover comment.

- Line 1: `# ruff: noqa: D102, D105, D103, D200` is flagged as unused; if these rules aren't enabled, drop this directive instead of carrying dead config.
- Lines 2–6: multiple broad mypy disables (`no-untyped-def`, `has-type`, `index`, `attr-defined`, `assignment`) effectively opt this module out of type checking. Given the PR TODO "fix typing errors", it would be better long-term to narrow these or remove them once annotations are in place.
- Line 16: Comment `# at very top of file (optional but helpful)` looks like a review note rather than code documentation and can be removed.

-# ruff: noqa: D102, D105, D103, D200
-# mypy: disable-error-code=no-untyped-def
-# mypy: disable-error-code=has-type
-# mypy: disable-error-code=index
-# mypy: disable-error-code=attr-defined
-# mypy: disable-error-code=assignment
-from __future__ import annotations
-
-# at very top of file (optional but helpful)
+from __future__ import annotations
@@
-from typing import Any, Literal
+from typing import Any, Literal

(And later, once typing is addressed, consider re-enabling mypy checks instead of disabling them globally for this file.)
59-76: `verbose` parameter in `load_multi_table_customized` is currently unused.
The `verbose` argument (Line 59) is never read, which is flagged by static analysis and can confuse callers expecting logging or extra checks.

Either remove the parameter if not needed:

-def load_multi_table_customized(data_dir, meta_dir=None, train_name="train.csv", verbose=True):
+def load_multi_table_customized(data_dir, meta_dir=None, train_name="train.csv"):

or actually use it (e.g., to print/log dataset/table info when `verbose` is True).
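A minimal sketch of the second option, wiring the flag to Python's standard logging module; the function name and message here are stand-ins, not the reviewed implementation:

```python
import logging

logger = logging.getLogger(__name__)


def load_tables(data_dir: str, train_name: str = "train.csv", verbose: bool = True) -> None:
    # Hypothetical stand-in for load_multi_table_customized: the point is that
    # the verbose flag drives optional logging instead of being ignored.
    if verbose:
        logger.info("Loading tables from %s (train file: %s)", data_dir, train_name)
    # ... actual loading logic would go here ...
```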
267-304: `prepare_fast_dataloader` behavior (infinite batches, feature concatenation) is intentional but worth documenting clearly.

- Correctly handles three cases: both numerical+categorical, categorical-only, and numerical-only.
- Returns an endless stream of `(x, y)` batches via `while True: yield from dataloader`, reshuffling each epoch when `split == "train"`.

Consider making the "infinite generator" behavior even more explicit in the docstring (e.g., "This is an infinite generator; callers must bound iteration externally") to avoid misuse in code that expects a single finite epoch; see the sketch below.
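A hedged sketch of such a docstring and of how a caller would bound the stream; the function name, parameters, and dummy batches are illustrative, not the toolkit's actual implementation:

```python
from itertools import islice


def prepare_fast_dataloader_sketch(batches, split: str = "train"):
    """Yield (x, y) batches indefinitely.

    This is an infinite generator: it restarts (and, for split == "train",
    would reshuffle) the underlying loader after every pass, so callers must
    bound iteration externally, e.g. with a step counter or itertools.islice.
    """
    while True:
        # Stand-in for iterating a FastTensorDataLoader over the split's features.
        yield from batches


# Callers bound the infinite stream themselves:
first_three = list(islice(prepare_fast_dataloader_sketch([("x0", "y0"), ("x1", "y1")]), 3))
```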
tests/unit/evaluation/quality/test_alpha_precision_naive.py (1)
53-58: Redundant conditional branches with identical assertions.
Both the `if is_apple_silicon()` and `else` branches contain identical assertions. Either the branching is unnecessary and can be simplified, or the expected values for the non-Apple Silicon path should differ.

If the values are truly architecture-independent for naive metrics, simplify:

- if is_apple_silicon():
-     assert pytest.approx(0.05994074074074074, abs=1e-8) == quality_results["delta_precision_alpha_naive"]
-     assert pytest.approx(0.005229629629629584, abs=1e-8) == quality_results["delta_coverage_beta_naive"]
- else:
-     assert pytest.approx(0.05994074074074074, abs=1e-8) == quality_results["delta_precision_alpha_naive"]
-     assert pytest.approx(0.005229629629629584, abs=1e-8) == quality_results["delta_coverage_beta_naive"]
+ assert pytest.approx(0.05994074074074074, abs=1e-8) == quality_results["delta_precision_alpha_naive"]
+ assert pytest.approx(0.005229629629629584, abs=1e-8) == quality_results["delta_coverage_beta_naive"]

tests/integration/attacks/tf/test_tf_attack.py (2)
45-54: Remove redundant `.values()` unpacking before explicit key-based access.
You first unpack `mia_performance_*.values()` into variables, then immediately overwrite them using explicit key lookups. The unpacking is unnecessary and relies on dict insertion order; just keep the key-based access:

- mia_performance_train, mia_performance_val, mia_performance_test = tf_attack(**config)
- tpr_at_fpr_train, roc_auc_train = mia_performance_train.values()
- tpr_at_fpr_val, roc_auc_val = mia_performance_val.values()
- tpr_at_fpr_test, roc_auc_test = mia_performance_test.values()
- tpr_at_fpr_train = mia_performance_train["max_tpr"]
- roc_auc_train = mia_performance_train["roc_auc"]
- tpr_at_fpr_val = mia_performance_val["max_tpr"]
- roc_auc_val = mia_performance_val["roc_auc"]
- tpr_at_fpr_test = mia_performance_test["max_tpr"]
- roc_auc_test = mia_performance_test["roc_auc"]
+ mia_performance_train, mia_performance_val, mia_performance_test = tf_attack(**config)
+ tpr_at_fpr_train = mia_performance_train["max_tpr"]
+ roc_auc_train = mia_performance_train["roc_auc"]
+ tpr_at_fpr_val = mia_performance_val["max_tpr"]
+ roc_auc_val = mia_performance_val["roc_auc"]
+ tpr_at_fpr_test = mia_performance_test["max_tpr"]
+ roc_auc_test = mia_performance_test["roc_auc"]
69-119: Consider deleting the large commented-out alternative test with absolute paths.
This block is dead code and contains hardcoded, developer-specific absolute paths. Keeping it commented out adds noise and may confuse future readers about which setup is authoritative. Prefer removing it or moving the scenario into a separate, properly parameterized test/config.
src/midst_toolkit/attacks/tf/tf_attack.py (2)
1-7: Remove or narrow unused `ruff: noqa` directive.
Ruff reports this `noqa` as unused for `D102, D105, D103, D200, PLR0915`. If those checks aren't actually enabled, the directive is unnecessary noise. Either remove it or restrict it to the specific rules you need to silence.

-# ruff: noqa: D102, D105, D103, D200, PLR0915
+# ruff: noqa

or simply delete the line if you don't need it at all.
113-181: Guard `get_dataset` against missing `target_model_dir`/`batch_size` instead of relying on fragile defaults.
`get_dataset` defaults `target_model_dir=None` and `batch_size=None`, but the body assumes both are set:

- `os.path.join(target_model_dir, ...)` will fail if `target_model_dir` is `None`.
- `prepare_fast_dataloader(..., batch_size=batch_size, ...)` likely expects an integer.

From this module it's always called with non-`None` arguments, but the signature suggests they're optional. Either make them required or fail fast with clear errors:

-def get_dataset(data_path, target_model_dir=None, train_name="train_with_id.csv", batch_size=None, meta_dir=""):
+def get_dataset(data_path, target_model_dir=None, train_name="train_with_id.csv", batch_size=None, meta_dir=""):
@@
-     tables, relation_order, _ = load_multi_table_customized(
+     if target_model_dir is None:
+         raise ValueError("target_model_dir must be provided.")
+     if batch_size is None:
+         raise ValueError("batch_size must be provided.")
+
+     tables, relation_order, _ = load_multi_table_customized(
          data_path,
          meta_dir=meta_dir,
          train_name=train_name,

This makes incorrect external use of `get_dataset` fail deterministically instead of with a less obvious `TypeError`.
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
⛔ Files ignored due to path filters (52)
All of the following test assets under tests/integration/attacks/tf/assets/tabddpm_models/ are excluded by the `!**/*.pkl` and `!**/*.csv` path filters:
- tabddpm_1: None_trans_ckpt.pkl, challenge_label.csv, challenge_with_id.csv, cluster_ckpt.pkl, data_for_training_MIA.csv, predictions_test_2.csv, predictions_test_222.csv, test.csv, train_with_id.csv, trans_label_encoders.pkl, trans_synthetic.csv
- tabddpm_2: None_trans_ckpt.pkl, challenge_label.csv, challenge_with_id.csv, cluster_ckpt.pkl, data_for_training_MIA.csv, test.csv, train_with_id.csv, trans_label_encoders.pkl, trans_synthetic.csv
- tabddpm_3: None_trans_ckpt.pkl, challenge_label.csv, challenge_with_id.csv, cluster_ckpt.pkl, data_for_validating_MIA.csv, predictions_test_2.csv, predictions_test_222.csv, test.csv, train_with_id.csv, trans_label_encoders.pkl, trans_synthetic.csv
- tabddpm_4: None_trans_ckpt.pkl, challenge_label.csv, challenge_with_id.csv, cluster_ckpt.pkl, data_for_validating_MIA.csv, predictions_test_2.csv, predictions_test_222.csv, test.csv, train_with_id.csv, trans_label_encoders.pkl, trans_synthetic.csv
- tabddpm_5: None_trans_ckpt.pkl, challenge_label.csv, challenge_with_id.csv, cluster_ckpt.pkl, predictions_test_2.csv, predictions_test_222.csv, test.csv, train_with_id.csv, trans_label_encoders.pkl, trans_synthetic.csv
📒 Files selected for processing (18)
- src/midst_toolkit/attacks/tf/classification.py (1 hunks)
- src/midst_toolkit/attacks/tf/data_utils.py (1 hunks)
- src/midst_toolkit/attacks/tf/tf_attack.py (1 hunks)
- tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/trans_domain.json (1 hunks)
- tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/trans_domain.json (1 hunks)
- tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/trans_domain.json (1 hunks)
- tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/workspace/train_1/args (1 hunks)
- tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/trans_domain.json (1 hunks)
- tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/updated_config.json (1 hunks)
- tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/workspace/train_1/args (1 hunks)
- tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/trans_domain.json (1 hunks)
- tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/updated_config.json (1 hunks)
- tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/workspace/train_1/args (1 hunks)
- tests/integration/attacks/tf/data_configs/dataset_meta.json (1 hunks)
- tests/integration/attacks/tf/data_configs/trans.json (1 hunks)
- tests/integration/attacks/tf/data_configs/trans_domain.json (1 hunks)
- tests/integration/attacks/tf/test_tf_attack.py (1 hunks)
- tests/unit/evaluation/quality/test_alpha_precision_naive.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
tests/unit/evaluation/quality/test_alpha_precision_naive.py (4)
src/midst_toolkit/common/random.py (2)
set_all_random_seeds(11-55)
unset_all_random_seeds(58-67)
src/midst_toolkit/data_processing/midst_data_processing.py (2)
load_midst_data(94-121)
process_midst_data_for_alpha_precision_evaluation(17-91)
src/midst_toolkit/evaluation/utils.py (2)
extract_columns_based_on_meta_info(45-87)
one_hot_encode_categoricals_and_merge_with_numerical(90-128)
tests/utils/architecture.py (1)
is_apple_silicon(4-6)
src/midst_toolkit/attacks/tf/data_utils.py (2)
src/midst_toolkit/models/clavaddpm/dataset.py (1)
Dataset(77-397)
src/midst_toolkit/models/clavaddpm/enumerations.py (1)
Normalization(58-63)
src/midst_toolkit/attacks/tf/tf_attack.py (4)
src/midst_toolkit/attacks/tf/classification.py (2)
MLP(10-42)
fitmodel(67-161)
src/midst_toolkit/attacks/tf/data_utils.py (6)
CustomUnpickler(51-56)
TaskType(42-48)
evaluate_attack_performance(185-218)
load_multi_table_customized(59-103)
prepare_data_for_attack(138-171)
prepare_fast_dataloader(267-304)
src/midst_toolkit/models/clavaddpm/dataset.py (3)
Dataset(77-397)
from_df(276-397)
size(199-211)
src/midst_toolkit/models/clavaddpm/gaussian_multinomial_diffusion.py (3)
sample_time(945-982)
gaussian_q_sample(302-321)
_gaussian_loss(505-543)
🪛 Ruff (0.14.7)
src/midst_toolkit/attacks/tf/data_utils.py
1-1: Unused noqa directive (non-enabled: D102, D105, D103, D200)
Remove unused noqa directive
(RUF100)
59-59: Unused function argument: verbose
(ARG001)
75-75: Avoid specifying long messages outside the exception class
(TRY003)
95-95: Avoid specifying long messages outside the exception class
(TRY003)
101-101: Avoid specifying long messages outside the exception class
(TRY003)
143-143: Avoid specifying long messages outside the exception class
(TRY003)
165-165: Avoid specifying long messages outside the exception class
(TRY003)
193-193: Avoid specifying long messages outside the exception class
(TRY003)
src/midst_toolkit/attacks/tf/tf_attack.py
1-1: Unused noqa directive (non-enabled: D102, D105, D103, D200, PLR0915)
Remove unused noqa directive
(RUF100)
82-82: Avoid specifying long messages outside the exception class
(TRY003)
222-222: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (11)
src/midst_toolkit/attacks/tf/data_utils.py (3)
70-103: Table loading and validation logic looks sound for single-table use.
The per-table loading/validation flow (train CSV + `{table}_domain.json`, ID column removal, `'?'` value check, and numeric-column string guard) is consistent and defensive. For the current test dataset (single `"trans"` table), this is straightforward and appropriate.
185-219: Attack evaluation aggregation and metric computation look coherent.
The aggregation in `evaluate_attack_performance` — looping over `indices`, skipping missing files, concatenating predictions and labels, then computing max TPR at a fixed FPR and ROC AUC — is consistent and matches the intended evaluation logic. Guard clauses for empty `indices` and no predictions give safe fallbacks.

Please double-check that the shape of `challenge_label.csv` (loaded via `np.loadtxt(..., skiprows=1)`) matches the expected 1D label array so `roc_curve` receives the correct `y_true`; a quick check is sketched below.
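One way to perform that check, assuming a single-column CSV with a header row; the file name and dummy predictions are illustrative:

```python
import numpy as np
from sklearn.metrics import roc_curve

labels = np.loadtxt("challenge_label.csv", skiprows=1)
# roc_curve expects a 1D array of binary labels; fail early if the file parsed
# into something else (e.g. a 2D array because of extra columns).
assert labels.ndim == 1, f"expected 1D labels, got shape {labels.shape}"
assert set(np.unique(labels)) <= {0.0, 1.0}, "labels must be binary"

# Example usage with dummy predictions of matching length:
predictions = np.random.rand(labels.shape[0])
fpr, tpr, _ = roc_curve(labels, predictions)
```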
221-265: `FastTensorDataLoader` implementation is minimal and correct.
The loader enforces equal leading dimensions, supports optional shuffling via a permuted index, and computes the correct number of batches including a final partial batch. `__iter__`/`__next__` semantics are standard and should integrate cleanly with `for` loops or `yield from` (a stripped-down sketch of the pattern follows).
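For readers unfamiliar with the pattern, a hedged, stripped-down sketch of such a loader; this illustrates the idea and is not the toolkit's actual `FastTensorDataLoader`:

```python
import torch


class TinyTensorLoader:
    """Batch one or more equally-sized tensors without per-item indexing overhead."""

    def __init__(self, *tensors: torch.Tensor, batch_size: int = 32, shuffle: bool = False):
        assert all(t.shape[0] == tensors[0].shape[0] for t in tensors), "equal leading dims required"
        self.tensors, self.batch_size, self.shuffle = tensors, batch_size, shuffle
        self.dataset_len = tensors[0].shape[0]
        # Ceiling division so the final partial batch is included.
        self.n_batches = (self.dataset_len + batch_size - 1) // batch_size

    def __iter__(self):
        if self.shuffle:
            order = torch.randperm(self.dataset_len)
            self.tensors = tuple(t[order] for t in self.tensors)
        self._i = 0
        return self

    def __next__(self):
        if self._i >= self.dataset_len:
            raise StopIteration
        batch = tuple(t[self._i : self._i + self.batch_size] for t in self.tensors)
        self._i += self.batch_size
        return batch
```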
tests/integration/attacks/tf/data_configs/dataset_meta.json (1)
1-1: Dataset meta format is minimal but consistent with loader expectations.
Single-table `relation_order` and `tables.trans` structure align with how `load_multi_table_customized` reads `dataset_meta.json`. No issues spotted.

tests/integration/attacks/tf/data_configs/trans.json (1)
1-50: Config parameters and relative paths look appropriate for tests.
The general/clustering/diffusion/classifier/sampling/matching sections are consistent and use repo-relative paths (including `workspace_dir`), which is good for portability in CI and local runs.

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/trans_domain.json (1)
1-1: Domain descriptor matches the expected schema.
Field names, sizes, and types mirror the other `trans_domain.json` assets, so this should integrate smoothly with the domain-aware tooling.

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/trans_domain.json (1)
1-1: Tabddpm_1 domain metadata is consistent with other variantsSchema and typing for all fields match the other
trans_domain.jsonassets, which keeps the attack tests’ domain assumptions uniform.tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/trans_domain.json (1)
1-1: Tabddpm_5 domain descriptor aligns with the rest of the suiteField set and type declarations are consistent with the other tabddpm model domain files, so downstream utilities can treat all models uniformly.
tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/trans_domain.json (1)
1-1: LGTM!Valid JSON domain metadata defining feature sizes and types for test fixtures. Structure is consistent with other
trans_domain.jsonfiles in the PR.tests/integration/attacks/tf/data_configs/trans_domain.json (1)
1-1: LGTM!Valid JSON domain metadata for test data configuration.
tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/trans_domain.json (1)
1-1: LGTM!Valid JSON domain metadata consistent with other tabddpm model test assets.
set_all_random_seeds(
    seed=133742,
    use_deterministic_torch_algos=True,
    disable_torch_benchmarking=True,
)

os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
Ensure RNG and env cleanup even when the test fails.
If any assertion fails before the end of the test, unset_all_random_seeds() and the CUBLAS_WORKSPACE_CONFIG cleanup won’t run, leaving global state polluted for subsequent tests. Wrap the body in a try/finally:
 def test_tf_attack_whitebox_small_config():
-    # Set deterministic behavior
-    set_all_random_seeds(
-        seed=133742,
-        use_deterministic_torch_algos=True,
-        disable_torch_benchmarking=True,
-    )
-
-    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
-
-    base_path = ...
-    ...
-    assert tpr_at_fpr_test == pytest.approx(0.12, abs=1e-8)
-
-    unset_all_random_seeds()
-    os.environ.pop("CUBLAS_WORKSPACE_CONFIG", None)
+    # Set deterministic behavior
+    set_all_random_seeds(
+        seed=133742,
+        use_deterministic_torch_algos=True,
+        disable_torch_benchmarking=True,
+    )
+    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
+    try:
+        base_path = ...
+        ...
+        assert tpr_at_fpr_test == pytest.approx(0.12, abs=1e-8)
+    finally:
+        unset_all_random_seeds()
+        os.environ.pop("CUBLAS_WORKSPACE_CONFIG", None)

Also applies to: 65-66
🤖 Prompt for AI Agents
In tests/integration/attacks/tf/test_tf_attack.py around lines 15-21 (and
similarly at lines 65-66), the test sets global RNG seeds and the
CUBLAS_WORKSPACE_CONFIG env var but does not guarantee cleanup if the test
fails; wrap the test body that calls set_all_random_seeds and sets
os.environ["CUBLAS_WORKSPACE_CONFIG"] in a try/finally so that
unset_all_random_seeds() and deletion (or restoration) of
CUBLAS_WORKSPACE_CONFIG always run in the finally block; apply the same
try/finally pattern to the other occurrence at lines 65-66 to ensure global
state is restored even on assertion errors.
emersodb left a comment
I left some preliminary comments for you to consider. I know this is a first pass at getting this code into the toolkit, so we don't need to address everything. However, I think most of my comments are at least worth thinking about addressing to try to improve the clarity of the code.
We'll certainly have to work on making it a bit easier to use.
lotif left a comment
Just a partial review because there seem to be some bigger issues with this (e.g. the tartan_federer and tf folders seem to have duplicated code).
Also, both the data_configs and test_tf_attack_results folders should live inside the assets folder. It's annoying that we can't add comments on the folders themselves in GitHub.
class CustomUnpickler(pickle.Unpickler):
    def find_class(self, module: str, name: str) -> Any:
Missing docstrings here.
Also, this return type can be Callable instead of Any.
Since this is overriding a specific function within the Unpickler class, I think we need to keep the signature the same?
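A minimal sketch of keeping the base Unpickler signature while adding the requested docstrings; the module-remapping logic inside the method is hypothetical and not taken from data_utils.py:

import pickle
from typing import Any


class CustomUnpickler(pickle.Unpickler):
    """Unpickler that resolves classes whose module paths may have moved."""

    def find_class(self, module: str, name: str) -> Any:
        """Resolve name from module, keeping the signature of pickle.Unpickler.find_class."""
        # Hypothetical remapping; the real lookup logic lives in data_utils.py.
        if module.startswith("legacy_package."):
            module = module.replace("legacy_package.", "new_package.", 1)
        return super().find_class(module, name)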
    Returns:
        _description_
    """
    df_train_merge, _, _ = prepare_data_for_attack(
I don't see where these are being used at all. I left them because they were part of your original implementation, but it seems like they aren't used in the training; we just use the dataloader?
So these were used to create the population dataset, within the structure of the competition, for training and validating the classifier, and they were later passed on to prepare_data. A more straightforward approach is to just load all of the transaction table and take out the training points for the indices corresponding to the challenge dataset.
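A rough sketch of that simpler approach; the file names and the trans_id key column are illustrative assumptions, not the toolkit's actual layout:

import pandas as pd

# Hypothetical layout: "trans.csv" is the full transaction table and
# "train_with_id.csv" holds the training points of the challenged model.
all_trans = pd.read_csv("trans.csv")
model_train = pd.read_csv("train_with_id.csv")

# Population for the attack classifier: everything except the model's training points.
population = all_trans[~all_trans["trans_id"].isin(model_train["trans_id"])]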
for timestep in timesteps:
    for additional_timestep in additional_timesteps:
        if model_number in train_indices:
            batch_size = samples_per_train_model * 2
I don't really understand why we're multiplying samples_per_train_model by 2 here to get the batch size (same with the validation setup). The batch size is very heavily tied to the overall dataset size and will throw an error if not set correctly. The way it's set up now, if you don't set samples_per_train_model just right, it will break.
Let's say the number of samples is 100. Then we take 100 samples from the training points (members) used to train the model and 100 samples from the rest of the data available to the adversary. This is why we multiply it by 2.
Ah, I see. That makes sense in theory. I don't think that's what's actually happening in practice in the code, though.
Okay. I think I'm starting to see this now.
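To make the arithmetic in this thread concrete, a small sketch (variable names are illustrative, not the attack's actual code): each batch pairs an equal number of member and non-member rows, so the batch size is twice samples_per_train_model.

import numpy as np

samples_per_train_model = 100

# Hypothetical feature matrices for members (training points of the target model)
# and non-members (other records available to the adversary).
member_rows = np.random.randn(samples_per_train_model, 8)
non_member_rows = np.random.randn(samples_per_train_model, 8)

features = np.concatenate([member_rows, non_member_rows], axis=0)
labels = np.concatenate(
    [np.ones(samples_per_train_model), np.zeros(samples_per_train_model)]
)

# One batch per model covers all of its member/non-member pairs, hence the * 2.
batch_size = samples_per_train_model * 2
assert features.shape[0] == batch_size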
lotif left a comment
It looks a lot better, thanks for addressing the earlier comments.
Approving with mostly minor comments. If you can't address them right now, you can merge and address them in a follow-up PR.
    Returns:
        Filtered dataframe ready for classifier training (or testing)
    """
    raw_data = pd.read_csv(model_dir / "train_with_id.csv")
Similar to the lines above, we should make this into a parameter whose default is a constant with the value "train_with_id.csv".
    Args:
        model_dir: Model directory from which to load data. This directory must contain a file named
            "train_with_id.csv" and "data_for_training_MIA.csv"
The name of the MIA dataset is defined by the mia_dataset_name parameter, though.
Maybe we could make a constant in this file with the name of these datasets so we don't need to hardcode them in multiple places? The default value of that parameter can also be set to this constant.
    input_noise,
    model_type,
    meta_dir=meta_dir,
    challenge_name="challenge_with_id.csv",
This should be put into a constant and replaced where it's being hardcoded.
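A small sketch of the constant-plus-default pattern suggested in this and the earlier comments; the module-level constant names and the load_classifier_data helper are illustrative, not existing toolkit code:

from pathlib import Path

import pandas as pd

# Module-level filename constants (values taken from the comments above).
TRAIN_WITH_ID_FILE = "train_with_id.csv"
MIA_TRAINING_DATA_FILE = "data_for_training_MIA.csv"
CHALLENGE_WITH_ID_FILE = "challenge_with_id.csv"


def load_classifier_data(
    model_dir: Path,
    train_file_name: str = TRAIN_WITH_ID_FILE,
    mia_dataset_name: str = MIA_TRAINING_DATA_FILE,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    # Hypothetical helper: callers can override the names, but the defaults
    # come from one place instead of being hardcoded at every call site.
    raw_data = pd.read_csv(model_dir / train_file_name)
    mia_data = pd.read_csv(model_dir / mia_dataset_name)
    return raw_data, mia_data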
PR Type
Feature: TF Attack

Short Description
Adding the TF attack to the midst toolkit. Refactored the code, made sure it works with the midst toolkit, and wrote an integration test.

Next step: