Skip to content

Conversation

@bzamanlooy
Copy link
Collaborator

PR Type

Feauture:
TF Attack

Short Description

Adding TF attack to the midst toolkit. Refactored the code, made sure it works with midst toolkit, wrote an integration test.

#Next step:

  • fix typing errors
  • unit tests

@coderabbitai
Copy link

coderabbitai bot commented Dec 3, 2025

📝 Walkthrough

Walkthrough

This PR introduces a comprehensive membership inference attack (MIA) framework for tabular differentially private models. The implementation includes three new Python modules providing data loading utilities (data_utils.py), diffusion-based attack orchestration (tf_attack.py), and classifier training infrastructure (classification.py). Supporting infrastructure consists of integration and unit tests, along with JSON configuration and metadata files for test assets. The attack workflow combines dataset preparation, diffusion score computation, MLP classifier training, and performance evaluation across multiple model instances and timesteps.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

  • src/midst_toolkit/attacks/tf/tf_attack.py: Core attack orchestration with mixed_loss computation for diffusion models, dataset preparation, classifier training, and end-to-end attack pipeline; contains multiple interconnected functions with conditional logic and external dependencies.
  • src/midst_toolkit/attacks/tf/data_utils.py: Substantial data utilities module with 8+ functions spanning ROC plotting, multi-table dataset loading, deduplication logic, attack data preparation, and performance evaluation; complex validation and error handling patterns.
  • src/midst_toolkit/attacks/tf/classification.py: PyTorch-based neural network training module with custom loss functions, model checkpointing, and iterative training loops; moderate complexity but requires verification of optimization and loss computation correctness.
  • tests/integration/attacks/tf/test_tf_attack.py: Integration test with hardcoded performance assertions; verify expected ROC-AUC and TPR@FPR thresholds are achievable and representative.
  • JSON configuration and metadata files (trans_domain.json, args, updated_config.json variants): Homogeneous, low-complexity additions primarily for test infrastructure.

Pre-merge checks and finishing touches

❌ Failed checks (1 warning, 1 inconclusive)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 40.00% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
Title check ❓ Inconclusive The PR title 'Bz/tf' is a branch name reference that does not meaningfully describe the changeset. It does not convey what the PR does or what the main change is. Revise the title to clearly describe the main change, e.g., 'Add TensorFlow membership attack module to midst toolkit' or 'Implement TF attack for membership inference attacks'.
✅ Passed checks (1 passed)
Check name Status Explanation
Description check ✅ Passed The PR description covers PR Type and a brief description of the changes, but lacks detail on test coverage and is incomplete per the repository template.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch bz/tf

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 17

Note

Due to the large number of review comments, Critical, Major severity comments were prioritized as inline comments.

🟡 Minor comments (4)
src/midst_toolkit/attacks/tf/data_utils.py-157-165 (1)

157-165: Deduplication key validation happens after drop_duplicates is called.

The validation for missing keys (lines 162-165) occurs after drop_duplicates is already called (lines 158-159). If keys are missing, drop_duplicates will raise a KeyError before your descriptive ValueError is reached.

+    # Ensure all keys for deduplication exist in both DataFrames before deduplication
+    missing_keys_merge = [key for key in keys_for_deduplication if key not in df_merge.columns]
+    missing_keys_challenge = [key for key in keys_for_deduplication if key not in df_challenge.columns]
+    if missing_keys_merge or missing_keys_challenge:
+        raise ValueError(f"Missing columns for deduplication: {missing_keys_merge + missing_keys_challenge}")
+
     # Deduplicate the datasets once
     df_merge = df_merge.drop_duplicates(subset=keys_for_deduplication)
     df_challenge = df_challenge.drop_duplicates(subset=keys_for_deduplication)
-
-    # Ensure all keys for deduplication exist in both DataFrames
-    missing_keys_merge = [key for key in keys_for_deduplication if key not in df_merge.columns]
-    missing_keys_challenge = [key for key in keys_for_deduplication if key not in df_challenge.columns]
-    if missing_keys_merge or missing_keys_challenge:
-        raise ValueError(f"Missing columns for deduplication: {missing_keys_merge + missing_keys_challenge}")
src/midst_toolkit/attacks/tf/data_utils.py-174-179 (1)

174-179: Potential ValueError on edge case.

If all FPR values are >= max_fpr, tpr[fpr < max_fpr] will be an empty array and max() will raise a ValueError.

 def get_tpr_at_fpr(true_membership: list[int], predictions: list[float], max_fpr: float = 0.1) -> float:
     """
     Calculates the best True Positive Rate when the False Positive Rate is at most `max_fpr`.
     """
     fpr, tpr, _ = roc_curve(true_membership, predictions)
-    return max(tpr[fpr < max_fpr])
+    valid_tpr = tpr[fpr <= max_fpr]
+    if len(valid_tpr) == 0:
+        return 0.0
+    return float(max(valid_tpr))
src/midst_toolkit/attacks/tf/classifcation.py-133-134 (1)

133-134: Condition x_val is not None is always true after tensor conversion.

On line 116, x_val is unconditionally converted to a tensor via torch.tensor(x_val, ...). If the original x_val parameter was None, this would raise an error before reaching line 133. The check should happen before tensor conversion.

+    has_validation = x_val is not None
     x_train = torch.tensor(x_train, dtype=torch.float32).to(device)
     y_train = torch.tensor(x_train_label, dtype=torch.float32).to(device)
-    x_val = torch.tensor(x_val, dtype=torch.float32).to(device)
-    y_test = torch.tensor(x_val_label, dtype=torch.float32).to(device)
+    if has_validation:
+        x_val = torch.tensor(x_val, dtype=torch.float32).to(device)
+        y_val = torch.tensor(x_val_label, dtype=torch.float32).to(device)

Then use has_validation in the condition on line 133.

Committable suggestion skipped: line range outside the PR's diff.

tests/integration/attacks/tf/test_tf_attack.py-51-53 (1)

51-53: Use explicit key access instead of dict unpacking for clarity.

The dictionaries returned have keys "max_tpr" and "roc_auc". While Python 3.7+ guarantees dictionary insertion order, explicit key access makes the code clearer and more maintainable.

-    tpr_at_fpr_train, roc_auc_train = mia_performance_train.values()
-    tpr_at_fpr_val, roc_auc_val = mia_performance_val.values()
-    tpr_at_fpr_test, roc_auc_test = mia_performance_test.values()
+    tpr_at_fpr_train = mia_performance_train["max_tpr"]
+    roc_auc_train = mia_performance_train["roc_auc"]
+    tpr_at_fpr_val = mia_performance_val["max_tpr"]
+    roc_auc_val = mia_performance_val["roc_auc"]
+    tpr_at_fpr_test = mia_performance_test["max_tpr"]
+    roc_auc_test = mia_performance_test["roc_auc"]
🧹 Nitpick comments (15)
src/midst_toolkit/attacks/tf/data_utils.py (2)

1-6: Excessive linting suppressions reduce code quality.

Suppressing D102, D105, D103, D200 (docstring rules) and multiple mypy error codes across the entire file is a significant code smell. As per the PR description, "fix typing errors" is listed as a next step — consider addressing these rather than suppressing them.


59-59: Unused verbose parameter.

The verbose parameter is declared but never used in the function body. Either implement verbose logging or remove the parameter.

src/midst_toolkit/attacks/tf/classifcation.py (2)

59-62: Remove dead code: x = x.float() is unused.

The variable x is reassigned on line 60 but never used afterward. This appears to be leftover code.

 def custom_loss_fn(model, x, y):
     confidences = model(x)
-    x = x.float()
     y = y.float()
     return nn.BCELoss()(confidences, y.unsqueeze(1))

114-117: Inconsistent variable naming: y_test vs x_val.

Line 117 uses y_test for labels corresponding to x_val, mixing "test" and "val" terminology. This is confusing given the function parameters use x_val and x_val_label.

     x_train = torch.tensor(x_train, dtype=torch.float32).to(device)
     y_train = torch.tensor(x_train_label, dtype=torch.float32).to(device)
     x_val = torch.tensor(x_val, dtype=torch.float32).to(device)
-    y_test = torch.tensor(x_val_label, dtype=torch.float32).to(device)
+    y_val = torch.tensor(x_val_label, dtype=torch.float32).to(device)

Then update references on lines 134, 150 accordingly.

src/midst_toolkit/attacks/tf/tf_attack.py (11)

1-6: Clean up unused lint suppressions.

Static analysis indicates the noqa directive on line 1 is unused. Since the PR author already plans to "fix typing errors", consider removing unnecessary suppressions once proper type annotations are added rather than blanket-disabling mypy checks.


37-82: Several issues in mixed_loss function.

  1. Unused parameter: no_mean is never used (Line 44).
  2. Redundant device assignment: device is assigned on line 51 but immediately overwritten on line 61.
  3. Typo: "defeualt" → "default" (Line 63).
  4. Redundant conditional: The check if not return_random: on line 74 is always True since we already returned on line 67 when return_random=True.
 def mixed_loss(
     diffusion,
     x,
     out_dict,
     noise=None,
     t=None,
     return_random=False,
-    no_mean=False,
     parallel_batch=None,
     addt_value=None,
 ):
     x_num = x[:, : diffusion.num_numerical_features]
     x_cat = x[:, diffusion.num_numerical_features :]

-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    device = x.device
     noise_tensor = torch.tensor(noise, device=device, dtype=torch.float)
     batch_noise = noise_tensor.repeat(x_num.shape[0], 1)

     # there is actually no categorical classes, as we have examined the DM, so we just ignore x_cat here and later
     x_num = x_num.repeat_interleave(parallel_batch, dim=0)
     x_cat = x_cat.repeat_interleave(parallel_batch, dim=0)

     b = x_num.shape[0]

-    device = x.device
     if t is None:
-        # the defeualt is uniform sampling
+        # the default is uniform sampling
         t, pt = diffusion.sample_time(b, device)

     if return_random:
         return noise, t, pt

     additional_t = t * 0 + addt_value

     # forward x_num_t with (t+additional_t) timestamps
     x_num_t = diffusion.gaussian_q_sample(x_num, t + additional_t, noise=batch_noise)

-    if not return_random:
-        current_t = t
-        # predict noises with t timestamps
-        predicted_noise = diffusion._denoise_fn(x_num_t, current_t, **out_dict)
-        current_loss = diffusion._gaussian_loss(predicted_noise, batch_noise, batch_noise, current_t, batch_noise)
-        transformed_current_loss = current_loss.reshape(-1, parallel_batch)
+    current_t = t
+    # predict noises with t timestamps
+    predicted_noise = diffusion._denoise_fn(x_num_t, current_t, **out_dict)
+    current_loss = diffusion._gaussian_loss(predicted_noise, batch_noise, batch_noise, current_t, batch_noise)
+    transformed_current_loss = current_loss.reshape(-1, parallel_batch)

     return transformed_current_loss * 0, transformed_current_loss

136-142: Prefix unused variables with underscore.

label_encoders and column_orders are unpacked but never used. Per convention, prefix with underscore to indicate intentional non-use.

-        dataset, label_encoders, column_orders = Dataset.from_df(
+        dataset, _label_encoders, _column_orders = Dataset.from_df(

207-229: Misplaced docstring.

The docstring is placed in the middle of the function after executable code (lines 201-205). It should be the first statement after the function signature to be recognized by documentation tools and IDEs.

Move the docstring to immediately after line 198 (def get_score(...):) before any executable statements.


247-252: Remove debug print and simplify loop logic.

  1. Debug print: print(iter_max) should be removed or replaced with proper logging.
  2. Confusing loop: The assert iter_max == 1 followed by while iter_id < iter_max: means the loop always runs exactly once. Consider simplifying to remove the loop entirely or documenting why this structure exists for future extensibility.
-        print(iter_max)
         iter_max = iter_max // batch_size
         return_res = torch.zeros([batch_size, parallel_batch])
         assert iter_max == 1
-        iter_id = 0
-        while iter_id < iter_max:
+        # Process single batch (currently only supports iter_max == 1)
+        x, out_dict = next(train_loader)
         ...
-            iter_id += 1

236-236: Prefix unused variable with underscore.

challenge_dataset is unpacked but never used.

-        train_loader, iter_max, challenge_dataset = train_loader_list[loader_count]
+        train_loader, iter_max, _challenge_dataset = train_loader_list[loader_count]

261-269: Prefix unused variables with underscore.

noise and pt are unpacked but never used in this context.

-                noise, t_cur, pt = mixed_loss(
+                _noise, t_cur, _pt = mixed_loss(

309-316: Hardcoded deduplication keys reduce reusability.

The keys ["trans_id", "balance"] are hardcoded in both prepare_data_for_attack calls. Consider making these configurable via a parameter.


479-479: Document magic number for noise dimension.

The noise dimension size=8 is hardcoded without explanation. Consider extracting this to a named constant or parameter with documentation explaining why 8 is the appropriate value.

+    NOISE_DIMENSION = 8  # Must match the diffusion model's expected noise dimension
-    input_noise: list[list[float]] = [np.random.normal(size=8).tolist() for _ in range(num_noise_per_time_step)]
+    input_noise: list[list[float]] = [np.random.normal(size=NOISE_DIMENSION).tolist() for _ in range(num_noise_per_time_step)]

187-198: Unused phase parameter.

The phase parameter is accepted but never used in the function body. Either remove it or implement the intended behavior.

 def get_score(
     data_path,
     save_dir,
     input_noise,
     type="tabddpm",
-    phase=None,
     challenge_name=None,
     batch_size=None,
     parallel_batch=None,
     addt_value=None,
     t_value=None,
 ):

If phase is intended for future use, add a TODO comment or raise NotImplementedError when a non-None value is passed.


191-191: Avoid shadowing built-in type.

Using type as a parameter name shadows Python's built-in type() function, which can cause subtle bugs if the built-in is needed within this function.

 def get_score(
     data_path,
     save_dir,
     input_noise,
-    type="tabddpm",
+    model_type="tabddpm",
     ...
 ):
-    if type == "tabddpm":
+    if model_type == "tabddpm":
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6eb18e9 and 62a6fec.

⛔ Files ignored due to path filters (49)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/None_trans_ckpt.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/challenge_label.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/challenge_with_id.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/cluster_ckpt.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/data_for_training_MIA.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/predictions_test_2.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/predictions_test_222.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/train_with_id.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/trans_label_encoders.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/trans_synthetic.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/None_trans_ckpt.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/challenge_label.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/challenge_with_id.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/cluster_ckpt.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/data_for_training_MIA.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/predictions_test_2.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/predictions_test_222.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/train_with_id.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/trans_label_encoders.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/trans_synthetic.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/None_trans_ckpt.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/challenge_label.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/challenge_with_id.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/cluster_ckpt.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/data_for_validating_MIA.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/predictions_test_2.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/predictions_test_222.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/train_with_id.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/trans_label_encoders.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/trans_synthetic.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/None_trans_ckpt.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/challenge_label.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/challenge_with_id.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/cluster_ckpt.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/data_for_validating_MIA.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/predictions_test_2.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/predictions_test_222.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/train_with_id.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/trans_label_encoders.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/trans_synthetic.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/None_trans_ckpt.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/challenge_label.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/challenge_with_id.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/cluster_ckpt.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/predictions_test_2.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/predictions_test_222.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/train_with_id.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/trans_label_encoders.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/trans_synthetic.csv is excluded by !**/*.csv
📒 Files selected for processing (22)
  • src/midst_toolkit/attacks/tf/classifcation.py (1 hunks)
  • src/midst_toolkit/attacks/tf/data_utils.py (1 hunks)
  • src/midst_toolkit/attacks/tf/tf_attack.py (1 hunks)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/trans_domain.json (1 hunks)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/updated_config.json (1 hunks)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/workspace/train_1/args (1 hunks)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/trans_domain.json (1 hunks)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/updated_config.json (1 hunks)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/workspace/train_1/args (1 hunks)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/trans_domain.json (1 hunks)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/updated_config.json (1 hunks)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/workspace/train_1/args (1 hunks)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/trans_domain.json (1 hunks)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/updated_config.json (1 hunks)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/workspace/train_1/args (1 hunks)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/trans_domain.json (1 hunks)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/updated_config.json (1 hunks)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/workspace/train_1/args (1 hunks)
  • tests/integration/attacks/tf/data_configs/dataset_meta.json (1 hunks)
  • tests/integration/attacks/tf/data_configs/trans.json (1 hunks)
  • tests/integration/attacks/tf/data_configs/trans_domain.json (1 hunks)
  • tests/integration/attacks/tf/test_tf_attack.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (13)
tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/workspace/train_1/args (2)
tests/unit/attacks/ensemble/test_shadow_model_utils.py (1)
  • test_save_additional_tabddpm_config (19-54)
src/midst_toolkit/attacks/ensemble/shadow_model_utils.py (1)
  • save_additional_tabddpm_config (36-76)
tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/updated_config.json (2)
src/midst_toolkit/attacks/ensemble/shadow_model_utils.py (1)
  • save_additional_tabddpm_config (36-76)
src/midst_toolkit/models/clavaddpm/model.py (1)
  • DiffusionParameters (19-31)
tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/updated_config.json (3)
tests/integration/attacks/ensemble/test_shadow_model_training.py (1)
  • test_train_and_fine_tune_tabddpm (135-187)
tests/unit/attacks/ensemble/test_shadow_model_utils.py (1)
  • test_save_additional_tabddpm_config (19-54)
src/midst_toolkit/attacks/ensemble/shadow_model_utils.py (1)
  • save_additional_tabddpm_config (36-76)
tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/workspace/train_1/args (1)
src/midst_toolkit/attacks/ensemble/shadow_model_utils.py (1)
  • save_additional_tabddpm_config (36-76)
tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/workspace/train_1/args (3)
tests/integration/attacks/ensemble/test_shadow_model_training.py (1)
  • test_train_and_fine_tune_tabddpm (135-187)
tests/unit/attacks/ensemble/test_shadow_model_utils.py (1)
  • test_save_additional_tabddpm_config (19-54)
src/midst_toolkit/attacks/ensemble/shadow_model_utils.py (1)
  • save_additional_tabddpm_config (36-76)
tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/updated_config.json (2)
tests/unit/attacks/ensemble/test_shadow_model_utils.py (1)
  • test_save_additional_tabddpm_config (19-54)
src/midst_toolkit/attacks/ensemble/shadow_model_utils.py (1)
  • save_additional_tabddpm_config (36-76)
tests/integration/attacks/tf/data_configs/trans_domain.json (1)
tests/integration/attacks/ensemble/test_shadow_model_training.py (1)
  • test_train_and_fine_tune_tabddpm (135-187)
tests/integration/attacks/tf/data_configs/trans.json (2)
tests/integration/attacks/ensemble/test_shadow_model_training.py (1)
  • test_train_and_fine_tune_tabddpm (135-187)
src/midst_toolkit/attacks/ensemble/blending.py (1)
  • __init__ (27-66)
tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/updated_config.json (3)
tests/integration/attacks/ensemble/test_shadow_model_training.py (1)
  • test_train_and_fine_tune_tabddpm (135-187)
tests/unit/attacks/ensemble/test_shadow_model_utils.py (1)
  • test_save_additional_tabddpm_config (19-54)
src/midst_toolkit/attacks/ensemble/shadow_model_utils.py (1)
  • save_additional_tabddpm_config (36-76)
tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/trans_domain.json (1)
tests/integration/attacks/ensemble/test_shadow_model_training.py (1)
  • test_train_and_fine_tune_tabddpm (135-187)
tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/updated_config.json (3)
tests/integration/attacks/ensemble/test_shadow_model_training.py (1)
  • test_train_and_fine_tune_tabddpm (135-187)
tests/unit/attacks/ensemble/test_shadow_model_utils.py (1)
  • test_save_additional_tabddpm_config (19-54)
src/midst_toolkit/attacks/ensemble/shadow_model_utils.py (1)
  • save_additional_tabddpm_config (36-76)
tests/integration/attacks/tf/test_tf_attack.py (2)
src/midst_toolkit/attacks/tf/tf_attack.py (1)
  • tf_attack (459-559)
src/midst_toolkit/common/random.py (2)
  • set_all_random_seeds (11-55)
  • unset_all_random_seeds (58-67)
src/midst_toolkit/attacks/tf/data_utils.py (3)
src/midst_toolkit/models/clavaddpm/dataset.py (1)
  • Dataset (77-397)
src/midst_toolkit/models/clavaddpm/enumerations.py (1)
  • Normalization (58-63)
src/midst_toolkit/models/clavaddpm/data_loaders.py (1)
  • FastTensorDataLoader (473-537)
🪛 Ruff (0.14.7)
src/midst_toolkit/attacks/tf/tf_attack.py

1-1: Unused noqa directive (non-enabled: D102, D105, D103, D200, PLR0915)

Remove unused noqa directive

(RUF100)


44-44: Unused function argument: no_mean

(ARG001)


85-85: Avoid specifying long messages outside the exception class

(TRY003)


136-136: Unpacked variable label_encoders is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


136-136: Unpacked variable column_orders is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


192-192: Unused function argument: phase

(ARG001)


203-203: Avoid specifying long messages outside the exception class

(TRY003)


236-236: Unpacked variable challenge_dataset is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


261-261: Unpacked variable noise is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


261-261: Unpacked variable pt is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

src/midst_toolkit/attacks/tf/data_utils.py

1-1: Unused noqa directive (non-enabled: D102, D105, D103, D200)

Remove unused noqa directive

(RUF100)


59-59: Unused function argument: verbose

(ARG001)


75-75: Avoid specifying long messages outside the exception class

(TRY003)


95-95: Avoid specifying long messages outside the exception class

(TRY003)


101-101: Avoid specifying long messages outside the exception class

(TRY003)


143-143: Avoid specifying long messages outside the exception class

(TRY003)


165-165: Avoid specifying long messages outside the exception class

(TRY003)


190-190: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: run-code-check
  • GitHub Check: unit-tests
  • GitHub Check: integration-tests
🔇 Additional comments (16)
src/midst_toolkit/attacks/tf/data_utils.py (1)

218-261: LGTM!

The FastTensorDataLoader implementation correctly handles batching and shuffling. The pattern matches the reference implementation from src/midst_toolkit/models/clavaddpm/data_loaders.py.

tests/integration/attacks/tf/data_configs/trans.json (1)

1-50: LGTM!

This configuration uses relative paths consistently, making it portable across different environments unlike some of the other tabddpm config files.

tests/integration/attacks/tf/data_configs/dataset_meta.json (1)

1-1: ✓ Test metadata structure is appropriate.

The JSON correctly describes a single-table dataset with no relationships, which aligns with the trans domain schema and test configuration patterns.

tests/integration/attacks/tf/data_configs/trans_domain.json (1)

1-1: ✓ Domain schema is consistent and well-formed.

The feature definitions (8 fields with appropriate continuous/discrete classifications) align with the dataset_meta.json and are consistent across all tabddpm model variants used in tests.

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/trans_domain.json (1)

1-1: ✓ Consistent test asset.

Domain schema matches across all tabddpm variants, which is appropriate for uniform test configuration.

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/trans_domain.json (1)

1-1: ✓ Consistent test asset.

Matches schema across all tabddpm variants for uniform testing.

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/trans_domain.json (1)

1-1: ✓ Consistent test asset.

Matches schema across all tabddpm variants for uniform testing.

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/trans_domain.json (1)

1-1: ✓ Consistent test asset.

Matches schema across all tabddpm variants for uniform testing.

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/updated_config.json (2)

1-8: Verify absolute workspace_dir path will not break tests.

Line 5 contains an absolute path /projects/midst-experiments/tabddpm_midst_toolkit/train/tabddpm_2/workspace that does not exist in typical CI/test environments. This may cause test failures when the configuration is used to create directories or write outputs.

Check whether:

  • The integration test overrides this path at runtime (e.g., using save_additional_tabddpm_config)
  • The code gracefully handles missing parent directories
  • Tests are expected to run in an environment where /projects/ exists

The relative paths for data_dir and test_data_dir (lines 3, 7) are appropriate for test portability.


14-42: ✓ Hyperparameter choices are appropriate for fast integration testing.

Minimal values (iterations=2–3, batch_size=1, num_timesteps=3) prioritize test speed while maintaining realistic configuration structure. This aligns with typical test fixtures for ML frameworks.

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/updated_config.json (2)

1-8: Verify absolute workspace_dir path will not break tests.

Line 5 contains an absolute path /projects/midst-experiments/tabddpm_midst_toolkit/train/tabddpm_5/workspace that does not exist in typical CI/test environments. This matches the same issue in tabddpm_2/updated_config.json and suggests a systematic configuration pattern.

Verify that:

  • The integration test overrides workspace_dir paths at runtime (expected based on save_additional_tabddpm_config utility)
  • Absolute paths are intentionally used as templates and not written to during tests

The relative test data paths (lines 3, 7) are appropriate for portability.


14-42: ✓ Hyperparameter choices are appropriate for fast integration testing.

Configuration values match the appropriate minimal test settings (iterations=2–3, batch_size=1, num_timesteps=3), balancing test speed with realistic model structure.

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/trans_domain.json (1)

1-1: LGTM!

The domain schema defines the expected field structure with appropriate size and type metadata for the tabddpm_4 model configuration.

tests/integration/attacks/tf/test_tf_attack.py (1)

20-28: Good use of deterministic seeding for reproducible tests.

The combination of set_all_random_seeds with use_deterministic_torch_algos=True and CUBLAS_WORKSPACE_CONFIG environment variable ensures reproducible results, which is essential for testing ML code with fixed expected values.

src/midst_toolkit/attacks/tf/classifcation.py (1)

8-40: LGTM: MLP architecture is straightforward and well-documented.

The 3-layer MLP with tanh activations and sigmoid output is appropriate for binary classification in the attack context.

src/midst_toolkit/attacks/tf/tf_attack.py (1)

81-81: Clarify intent of returning zeroed tensor.

The return statement returns transformed_current_loss * 0 as the first element, which is always a zero tensor. If this is intentional (e.g., for API compatibility or specific algorithm requirements), consider adding a comment explaining why. Otherwise, this may be leftover debugging code.

@bzamanlooy
Copy link
Collaborator Author

@coderabbitai full review

@coderabbitai
Copy link

coderabbitai bot commented Dec 4, 2025

✅ Actions performed

Full review triggered.

Copy link

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 4

♻️ Duplicate comments (7)
src/midst_toolkit/attacks/tf/data_utils.py (1)

138-171: Still only returning the last model’s challenge_label.csv

Inside prepare_data_for_attack, df_challenge_labels is reassigned on every iteration (Line 152) and never accumulated. With multiple indices, you return labels only from the last model, while df_merge and df_challenge are concatenated across all models. This is the same issue previously flagged.

If you intend to aggregate across models, collect labels in a list and concatenate:

 def prepare_data_for_attack(indices, model_type, models_base_dir, keys_for_deduplication):
@@
-    df_challenge_list = []
+    df_challenge_list = []
+    df_challenge_labels_list = []
@@
-        df_challenge_list.append(pd.read_csv(os.path.join(base_path, "challenge_with_id.csv")))
-        df_challenge_labels = pd.read_csv(os.path.join(base_path, "challenge_label.csv"))
+        df_challenge_list.append(pd.read_csv(os.path.join(base_path, "challenge_with_id.csv")))
+        df_challenge_labels_list.append(pd.read_csv(os.path.join(base_path, "challenge_label.csv")))
@@
-    df_challenge = pd.concat(df_challenge_list, ignore_index=True)
+    df_challenge = pd.concat(df_challenge_list, ignore_index=True)
+    df_challenge_labels = pd.concat(df_challenge_labels_list, ignore_index=True)
@@
-    return df_merge_without_challenge, df_challenge, df_challenge_labels
+    return df_merge_without_challenge, df_challenge, df_challenge_labels

If only the last model’s labels are truly desired, add an explicit comment and enforce that indices has length 1.

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/workspace/train_1/args (1)

2-8: Replace hardcoded absolute workspace_dir with a project-relative path

workspace_dir is an absolute path (Line 5), which will fail on other machines and CI. For a test asset, it should be under the repo with a relative path, similar to data_dir and other configs.

-        "workspace_dir": "/projects/midst-experiments/tabddpm_midst_toolkit/train/tabddpm_4/workspace",
+        "workspace_dir": "tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/workspace",
tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/updated_config.json (1)

5-5: Hardcoded absolute path breaks CI/portability.

This issue was previously flagged. The workspace_dir contains an absolute path specific to a particular machine, which will fail in CI and for other developers.

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/workspace/train_1/args (1)

5-5: Hardcoded absolute path breaks CI/portability.

This issue was previously flagged. The workspace_dir contains an absolute path specific to a particular machine, which will fail in CI and for other developers.

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/workspace/train_1/args (1)

3-7: Avoid hardcoded absolute workspace_dir in test config.

"workspace_dir" is an absolute, developer/cluster-specific path and will not exist on other machines/CI. Use a project-relative path (like the other fields) or a placeholder that your test harness fills in.

-        "workspace_dir": "/projects/midst-experiments/tabddpm_midst_toolkit/train/tabddpm_5/workspace",
+        "workspace_dir": "tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/workspace",
src/midst_toolkit/attacks/tf/tf_attack.py (2)

383-435: Use addt_value_list when building training/validation features; current loop ignores it.

Here you still hardcode for addt_value in [0]:, so addt_value_list is effectively unused. At the same time, the feature matrix width is computed as len(input_noise) * len(timesteps_list) * len(addt_value_list), meaning any addt_value_list with length > 1 leaves trailing zero columns, and the classifier input_dim is larger than the actually populated features.

Refactor the nested loop to iterate over addt_value_list and index feature blocks accordingly:

-        t_value_count = 0
-        for t_value in timesteps_list:
-            for addt_value in [0]:
+        feature_block_index = 0
+        for t_value in timesteps_list:
+            for addt_value in addt_value_list:
                 if model_number in train_indices:
@@
-                    x_train[
-                        samples_per_train_model * 2 * train_count : samples_per_train_model * 2 * (train_count + 1),
-                        t_value_count * num_noise_per_time_step : (t_value_count + 1) * num_noise_per_time_step,
-                    ] = predictions.detach().squeeze().cpu().numpy()
+                    start = feature_block_index * num_noise_per_time_step
+                    end = (feature_block_index + 1) * num_noise_per_time_step
+                    x_train[
+                        samples_per_train_model * 2 * train_count : samples_per_train_model * 2 * (train_count + 1),
+                        start:end,
+                    ] = predictions.detach().squeeze().cpu().numpy()
@@
-                    x_train_label[
-                        samples_per_train_model * 2 * train_count : samples_per_train_model * 2 * (train_count + 1)
-                    ] = np.concatenate([np.zeros(samples_per_train_model), np.ones(samples_per_train_model)])
-                    t_value_count += 1
+                    x_train_label[
+                        samples_per_train_model * 2 * train_count : samples_per_train_model * 2 * (train_count + 1)
+                    ] = np.concatenate([np.zeros(samples_per_train_model), np.ones(samples_per_train_model)])
+                    feature_block_index += 1
@@
-                    x_val[
-                        sample_per_val_model * 2 * val_count : sample_per_val_model * 2 * (val_count + 1),
-                        t_value_count * num_noise_per_time_step : (t_value_count + 1) * num_noise_per_time_step,
-                    ] = predictions.detach().squeeze().cpu().numpy()
+                    start = feature_block_index * num_noise_per_time_step
+                    end = (feature_block_index + 1) * num_noise_per_time_step
+                    x_val[
+                        sample_per_val_model * 2 * val_count : sample_per_val_model * 2 * (val_count + 1),
+                        start:end,
+                    ] = predictions.detach().squeeze().cpu().numpy()
@@
-                    x_val_label[sample_per_val_model * 2 * val_count : sample_per_val_model * 2 * (val_count + 1)] = (
-                        np.concatenate([np.zeros(sample_per_val_model), np.ones(sample_per_val_model)])
-                    )
-                    t_value_count += 1
+                    x_val_label[sample_per_val_model * 2 * val_count : sample_per_val_model * 2 * (val_count + 1)] = (
+                        np.concatenate([np.zeros(sample_per_val_model), np.ones(sample_per_val_model)])
+                    )
+                    feature_block_index += 1

This way, x_train/x_val and the MLP’s input_dim stay consistent for arbitrary addt_value_list.


503-528: Respect addt_value_list, avoid hardcoded batch_size, and guard min–max normalization against zero range.

Three issues remain in tf_attack’s scoring loop:

  1. Hardcoded batch_size = 200:

    • This magic number is baked into the function, making it harder to reuse in other scenarios or tests.
  2. addt_value_list is ignored:

    • The loop still uses for addt_value in [0]:, so callers can’t vary addt_value despite it being a parameter.
  3. Potential division by zero in min–max normalization:

    • If all model outputs are identical (min_output == max_output), (max_output - min_output) is zero, leading to nans and failing the subsequent assertion.

A possible fix:

-def tf_attack(
+def tf_attack(
@@
-    classifier_hidden_dim: int,
-    addt_value_list: list[int],
-    meta_dir: Path,
-) -> tuple[Any, Any, Any]:
+    classifier_hidden_dim: int,
+    addt_value_list: list[int],
+    meta_dir: Path,
+    batch_size: int = 200,
+) -> tuple[Any, Any, Any]:
@@
-        model_dir = tabddpm_data_dir / model_folder
-        model_path = model_dir / target_model_subdir
-        batch_size = 200
-        t_value_count = 0
+        model_dir = tabddpm_data_dir / model_folder
+        model_path = model_dir / target_model_subdir
         current_input = []
         for t_value in timesteps_list:
-            for addt_value in [0]:
+            for addt_value in addt_value_list:
                 predictions: torch.Tensor = get_score(
@@
-                )
-                t_value_count += 1
-                current_input.append(predictions)
+                )
+                current_input.append(predictions)
@@
-        predictions = regression_model(predictions).detach().cpu().numpy()
-        # clip to [0, 1]
-        min_output, max_output = np.min(predictions), np.max(predictions)
-        predictions = (predictions - min_output) / (max_output - min_output)
+        predictions = regression_model(predictions).detach().cpu().numpy()
+        # Rescale to [0, 1] safely
+        min_output, max_output = np.min(predictions), np.max(predictions)
+        range_output = max_output - min_output
+        if range_output > 0:
+            predictions = (predictions - min_output) / range_output
+        else:
+            # All outputs identical; fall back to a neutral constant
+            predictions = np.full_like(predictions, 0.5)
         predictions = torch.tensor(predictions)

This keeps the existing integration test behavior (batch_size defaults to 200, addt_value_list=[0]) while making the function robust and reusable.

🧹 Nitpick comments (8)
src/midst_toolkit/attacks/tf/data_utils.py (3)

1-18: Tighten lint/typing pragmas and clean leftover comment

  • Line 1: # ruff: noqa: D102, D105, D103, D200 is flagged as unused; if these rules aren’t enabled, drop this directive instead of carrying dead config.
  • Lines 2–6: multiple broad mypy disables (no-untyped-def, has-type, index, attr-defined, assignment) effectively opt this module out of type checking. Given the PR TODO “fix typing errors”, it would be better long‑term to narrow these or remove them once annotations are in place.
  • Line 16: Comment # at very top of file (optional but helpful) looks like a review note rather than code documentation and can be removed.
-# ruff: noqa: D102, D105, D103, D200
-# mypy: disable-error-code=no-untyped-def
-# mypy: disable-error-code=has-type
-# mypy: disable-error-code=index
-# mypy: disable-error-code=attr-defined
-# mypy: disable-error-code=assignment
-from __future__ import annotations
-
-# at very top of file (optional but helpful)
+from __future__ import annotations
@@
-from typing import Any, Literal
+from typing import Any, Literal

(And later, once typing is addressed, consider re‑enabling mypy checks instead of disabling them globally for this file.)


59-76: verbose parameter in load_multi_table_customized is currently unused

The verbose argument (Line 59) is never read, which is flagged by static analysis and can confuse callers expecting logging or extra checks.

Either remove the parameter if not needed:

-def load_multi_table_customized(data_dir, meta_dir=None, train_name="train.csv", verbose=True):
+def load_multi_table_customized(data_dir, meta_dir=None, train_name="train.csv"):

or actually use it (e.g., to print/log dataset/table info when verbose is True).


267-304: prepare_fast_dataloader behavior (infinite batches, feature concatenation) is intentional but worth documenting clearly

  • Correctly handles three cases: both numerical+categorical, categorical‑only, and numerical‑only.
  • Returns an endless stream of (x, y) batches via while True: yield from dataloader, reshuffling each epoch when split == "train".

Consider making the “infinite generator” behavior even more explicit in the docstring (e.g., “This is an infinite generator; callers must bound iteration externally”) to avoid misuse in code that expects a single finite epoch.

tests/unit/evaluation/quality/test_alpha_precision_naive.py (1)

53-58: Redundant conditional branches with identical assertions.

Both the if is_apple_silicon() and else branches contain identical assertions. Either the branching is unnecessary and can be simplified, or the expected values for the non-Apple Silicon path should differ.

If the values are truly architecture-independent for naive metrics, simplify:

-    if is_apple_silicon():
-        assert pytest.approx(0.05994074074074074, abs=1e-8) == quality_results["delta_precision_alpha_naive"]
-        assert pytest.approx(0.005229629629629584, abs=1e-8) == quality_results["delta_coverage_beta_naive"]
-    else:
-        assert pytest.approx(0.05994074074074074, abs=1e-8) == quality_results["delta_precision_alpha_naive"]
-        assert pytest.approx(0.005229629629629584, abs=1e-8) == quality_results["delta_coverage_beta_naive"]
+    assert pytest.approx(0.05994074074074074, abs=1e-8) == quality_results["delta_precision_alpha_naive"]
+    assert pytest.approx(0.005229629629629584, abs=1e-8) == quality_results["delta_coverage_beta_naive"]
tests/integration/attacks/tf/test_tf_attack.py (2)

45-54: Remove redundant .values() unpacking before explicit key-based access.

You first unpack mia_performance_* .values() into variables, then immediately overwrite them using explicit key lookups. The unpacking is unnecessary and relies on dict insertion order; just keep the key-based access:

-    mia_performance_train, mia_performance_val, mia_performance_test = tf_attack(**config)
-    tpr_at_fpr_train, roc_auc_train = mia_performance_train.values()
-    tpr_at_fpr_val, roc_auc_val = mia_performance_val.values()
-    tpr_at_fpr_test, roc_auc_test = mia_performance_test.values()
-    tpr_at_fpr_train = mia_performance_train["max_tpr"]
-    roc_auc_train = mia_performance_train["roc_auc"]
-    tpr_at_fpr_val = mia_performance_val["max_tpr"]
-    roc_auc_val = mia_performance_val["roc_auc"]
-    tpr_at_fpr_test = mia_performance_test["max_tpr"]
-    roc_auc_test = mia_performance_test["roc_auc"]
+    mia_performance_train, mia_performance_val, mia_performance_test = tf_attack(**config)
+    tpr_at_fpr_train = mia_performance_train["max_tpr"]
+    roc_auc_train = mia_performance_train["roc_auc"]
+    tpr_at_fpr_val = mia_performance_val["max_tpr"]
+    roc_auc_val = mia_performance_val["roc_auc"]
+    tpr_at_fpr_test = mia_performance_test["max_tpr"]
+    roc_auc_test = mia_performance_test["roc_auc"]

69-119: Consider deleting the large commented-out alternative test with absolute paths.

This block is dead code and contains hardcoded, developer-specific absolute paths. Keeping it commented out adds noise and may confuse future readers about which setup is authoritative. Prefer removing it or moving the scenario into a separate, properly parameterized test/config.

src/midst_toolkit/attacks/tf/tf_attack.py (2)

1-7: Remove or narrow unused ruff: noqa directive.

Ruff reports this noqa as unused for D102, D105, D103, D200, PLR0915. If those checks aren’t actually enabled, the directive is unnecessary noise. Either remove it or restrict it to the specific rules you need to silence.

-# ruff: noqa: D102, D105, D103, D200, PLR0915
+# ruff: noqa

or simply delete the line if you don’t need it at all.


113-181: Guard get_dataset against missing target_model_dir/batch_size instead of relying on fragile defaults.

get_dataset defaults target_model_dir=None and batch_size=None, but the body assumes both are set:

  • os.path.join(target_model_dir, ...) will fail if target_model_dir is None.
  • prepare_fast_dataloader(..., batch_size=batch_size, ...) likely expects an integer.

From this module it’s always called with non-None arguments, but the signature suggests they’re optional. Either make them required or fail fast with clear errors:

-def get_dataset(data_path, target_model_dir=None, train_name="train_with_id.csv", batch_size=None, meta_dir=""):
+def get_dataset(data_path, target_model_dir=None, train_name="train_with_id.csv", batch_size=None, meta_dir=""):
@@
-    tables, relation_order, _ = load_multi_table_customized(
+    if target_model_dir is None:
+        raise ValueError("target_model_dir must be provided.")
+    if batch_size is None:
+        raise ValueError("batch_size must be provided.")
+
+    tables, relation_order, _ = load_multi_table_customized(
         data_path,
         meta_dir=meta_dir,
         train_name=train_name,

This makes incorrect external use of get_dataset fail deterministically instead of with a less obvious TypeError.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6eb18e9 and b1d9f8a.

⛔ Files ignored due to path filters (52)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/None_trans_ckpt.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/challenge_label.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/challenge_with_id.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/cluster_ckpt.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/data_for_training_MIA.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/predictions_test_2.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/predictions_test_222.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/test.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/train_with_id.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/trans_label_encoders.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/trans_synthetic.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/None_trans_ckpt.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/challenge_label.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/challenge_with_id.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/cluster_ckpt.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/data_for_training_MIA.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/test.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/train_with_id.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/trans_label_encoders.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/trans_synthetic.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/None_trans_ckpt.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/challenge_label.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/challenge_with_id.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/cluster_ckpt.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/data_for_validating_MIA.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/predictions_test_2.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/predictions_test_222.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/test.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/train_with_id.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/trans_label_encoders.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/trans_synthetic.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/None_trans_ckpt.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/challenge_label.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/challenge_with_id.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/cluster_ckpt.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/data_for_validating_MIA.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/predictions_test_2.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/predictions_test_222.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/test.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/train_with_id.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/trans_label_encoders.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/trans_synthetic.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/None_trans_ckpt.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/challenge_label.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/challenge_with_id.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/cluster_ckpt.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/predictions_test_2.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/predictions_test_222.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/test.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/train_with_id.csv is excluded by !**/*.csv
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/trans_label_encoders.pkl is excluded by !**/*.pkl
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/trans_synthetic.csv is excluded by !**/*.csv
📒 Files selected for processing (18)
  • src/midst_toolkit/attacks/tf/classification.py (1 hunks)
  • src/midst_toolkit/attacks/tf/data_utils.py (1 hunks)
  • src/midst_toolkit/attacks/tf/tf_attack.py (1 hunks)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/trans_domain.json (1 hunks)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/trans_domain.json (1 hunks)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/trans_domain.json (1 hunks)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/workspace/train_1/args (1 hunks)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/trans_domain.json (1 hunks)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/updated_config.json (1 hunks)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/workspace/train_1/args (1 hunks)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/trans_domain.json (1 hunks)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/updated_config.json (1 hunks)
  • tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/workspace/train_1/args (1 hunks)
  • tests/integration/attacks/tf/data_configs/dataset_meta.json (1 hunks)
  • tests/integration/attacks/tf/data_configs/trans.json (1 hunks)
  • tests/integration/attacks/tf/data_configs/trans_domain.json (1 hunks)
  • tests/integration/attacks/tf/test_tf_attack.py (1 hunks)
  • tests/unit/evaluation/quality/test_alpha_precision_naive.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
tests/unit/evaluation/quality/test_alpha_precision_naive.py (4)
src/midst_toolkit/common/random.py (2)
  • set_all_random_seeds (11-55)
  • unset_all_random_seeds (58-67)
src/midst_toolkit/data_processing/midst_data_processing.py (2)
  • load_midst_data (94-121)
  • process_midst_data_for_alpha_precision_evaluation (17-91)
src/midst_toolkit/evaluation/utils.py (2)
  • extract_columns_based_on_meta_info (45-87)
  • one_hot_encode_categoricals_and_merge_with_numerical (90-128)
tests/utils/architecture.py (1)
  • is_apple_silicon (4-6)
src/midst_toolkit/attacks/tf/data_utils.py (2)
src/midst_toolkit/models/clavaddpm/dataset.py (1)
  • Dataset (77-397)
src/midst_toolkit/models/clavaddpm/enumerations.py (1)
  • Normalization (58-63)
src/midst_toolkit/attacks/tf/tf_attack.py (4)
src/midst_toolkit/attacks/tf/classification.py (2)
  • MLP (10-42)
  • fitmodel (67-161)
src/midst_toolkit/attacks/tf/data_utils.py (6)
  • CustomUnpickler (51-56)
  • TaskType (42-48)
  • evaluate_attack_performance (185-218)
  • load_multi_table_customized (59-103)
  • prepare_data_for_attack (138-171)
  • prepare_fast_dataloader (267-304)
src/midst_toolkit/models/clavaddpm/dataset.py (3)
  • Dataset (77-397)
  • from_df (276-397)
  • size (199-211)
src/midst_toolkit/models/clavaddpm/gaussian_multinomial_diffusion.py (3)
  • sample_time (945-982)
  • gaussian_q_sample (302-321)
  • _gaussian_loss (505-543)
🪛 Ruff (0.14.7)
src/midst_toolkit/attacks/tf/data_utils.py

1-1: Unused noqa directive (non-enabled: D102, D105, D103, D200)

Remove unused noqa directive

(RUF100)


59-59: Unused function argument: verbose

(ARG001)


75-75: Avoid specifying long messages outside the exception class

(TRY003)


95-95: Avoid specifying long messages outside the exception class

(TRY003)


101-101: Avoid specifying long messages outside the exception class

(TRY003)


143-143: Avoid specifying long messages outside the exception class

(TRY003)


165-165: Avoid specifying long messages outside the exception class

(TRY003)


193-193: Avoid specifying long messages outside the exception class

(TRY003)

src/midst_toolkit/attacks/tf/tf_attack.py

1-1: Unused noqa directive (non-enabled: D102, D105, D103, D200, PLR0915)

Remove unused noqa directive

(RUF100)


82-82: Avoid specifying long messages outside the exception class

(TRY003)


222-222: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (11)
src/midst_toolkit/attacks/tf/data_utils.py (3)

70-103: Table loading and validation logic looks sound for single-table use

The per‑table loading/validation flow (train CSV + {table}_domain.json, ID column removal, '?' value check, and numeric‑column string guard) is consistent and defensive. For the current test dataset (single "trans" table), this is straightforward and appropriate.


185-219: Attack evaluation aggregation and metric computation look coherent

The aggregation in evaluate_attack_performance—looping over indices, skipping missing files, concatenating predictions and labels, then computing max TPR at a fixed FPR and ROC AUC—is consistent and matches the intended evaluation logic. Guard clauses for empty indices and no predictions give safe fallbacks.

Please double‑check that the shape of challenge_label.csv (loaded via np.loadtxt(..., skiprows=1)) matches the expected 1D label array so roc_curve receives the correct y_true.


221-265: FastTensorDataLoader implementation is minimal and correct

The loader enforces equal leading dimensions, supports optional shuffling via a permuted index, and computes the correct number of batches including a final partial batch. __iter__/__next__ semantics are standard and should integrate cleanly with for loops or yield from.

tests/integration/attacks/tf/data_configs/dataset_meta.json (1)

1-1: Dataset meta format is minimal but consistent with loader expectations

Single‑table relation_order and tables.trans structure align with how load_multi_table_customized reads dataset_meta.json. No issues spotted.

tests/integration/attacks/tf/data_configs/trans.json (1)

1-50: Config parameters and relative paths look appropriate for tests

The general/clustering/diffusion/classifier/sampling/matching sections are consistent and use repo‑relative paths (including workspace_dir), which is good for portability in CI and local runs.

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_3/trans_domain.json (1)

1-1: Domain descriptor matches the expected schema

Field names, sizes, and types mirror the other trans_domain.json assets, so this should integrate smoothly with the domain‑aware tooling.

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_1/trans_domain.json (1)

1-1: Tabddpm_1 domain metadata is consistent with other variants

Schema and typing for all fields match the other trans_domain.json assets, which keeps the attack tests’ domain assumptions uniform.

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_5/trans_domain.json (1)

1-1: Tabddpm_5 domain descriptor aligns with the rest of the suite

Field set and type declarations are consistent with the other tabddpm model domain files, so downstream utilities can treat all models uniformly.

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_2/trans_domain.json (1)

1-1: LGTM!

Valid JSON domain metadata defining feature sizes and types for test fixtures. Structure is consistent with other trans_domain.json files in the PR.

tests/integration/attacks/tf/data_configs/trans_domain.json (1)

1-1: LGTM!

Valid JSON domain metadata for test data configuration.

tests/integration/attacks/tf/assets/tabddpm_models/tabddpm_4/trans_domain.json (1)

1-1: LGTM!

Valid JSON domain metadata consistent with other tabddpm model test assets.

Comment on lines +15 to +21
set_all_random_seeds(
seed=133742,
use_deterministic_torch_algos=True,
disable_torch_benchmarking=True,
)

os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Ensure RNG and env cleanup even when the test fails.

If any assertion fails before the end of the test, unset_all_random_seeds() and the CUBLAS_WORKSPACE_CONFIG cleanup won’t run, leaving global state polluted for subsequent tests. Wrap the body in a try/finally:

 def test_tf_attack_whitebox_small_config():
-    # Set deterministic behavior
-    set_all_random_seeds(
-        seed=133742,
-        use_deterministic_torch_algos=True,
-        disable_torch_benchmarking=True,
-    )
-
-    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
-
-    base_path = ...
-    ...
-    assert tpr_at_fpr_test == pytest.approx(0.12, abs=1e-8)
-
-    unset_all_random_seeds()
-    os.environ.pop("CUBLAS_WORKSPACE_CONFIG", None)
+    # Set deterministic behavior
+    set_all_random_seeds(
+        seed=133742,
+        use_deterministic_torch_algos=True,
+        disable_torch_benchmarking=True,
+    )
+    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
+    try:
+        base_path = ...
+        ...
+        assert tpr_at_fpr_test == pytest.approx(0.12, abs=1e-8)
+    finally:
+        unset_all_random_seeds()
+        os.environ.pop("CUBLAS_WORKSPACE_CONFIG", None)

Also applies to: 65-66

🤖 Prompt for AI Agents
In tests/integration/attacks/tf/test_tf_attack.py around lines 15-21 (and
similarly at lines 65-66), the test sets global RNG seeds and the
CUBLAS_WORKSPACE_CONFIG env var but does not guarantee cleanup if the test
fails; wrap the test body that calls set_all_random_seeds and sets
os.environ["CUBLAS_WORKSPACE_CONFIG"] in a try/finally so that
unset_all_random_seeds() and deletion (or restoration) of
CUBLAS_WORKSPACE_CONFIG always run in the finally block; apply the same
try/finally pattern to the other occurrence at lines 65-66 to ensure global
state is restored even on assertion errors.

Copy link
Collaborator

@emersodb emersodb left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I left some preliminary comments for you to consider. I know this is a first pass at getting this code into the toolkit. So we don't need to address everything. However, I think most of my comments are at least worth thinking about addressing to true to improve the clarity of the code.

We'll certainly have to work on making it a bit easier to use.

@bzamanlooy bzamanlooy self-assigned this Dec 8, 2025
Copy link
Collaborator

@lotif lotif left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a partial review because there seems to be some bigger issues with this (e.g. tartan_federer and tf folders seem to have duplicated code).

Also, both data_configs and test_tf_attack_results folders should live inside the assets folder. It's annoying we can't add comments on the folders themselves in github.



class CustomUnpickler(pickle.Unpickler):
def find_class(self, module: str, name: str) -> Any:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing docstrings here.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, this return type can be Callable instead of Any.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this is overriding a specific function within the Unpickler class, I think we need to keep the signature the same?

Returns:
_description_
"""
df_train_merge, _, _ = prepare_data_for_attack(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see where these are being used at all. I left them, because they were part of your original implementation, but it seems like they aren't used in the training, we just use the dataloader?

Copy link
Collaborator Author

@bzamanlooy bzamanlooy Dec 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So these were used to create population dataset within the structure of the competition for training and validating the classifier and later passed onto prepare_data. A more straight forward approach is to just load all of the transaction table, take out the training points for the indices corresponding to the challenge dataset.

for timestep in timesteps:
for additional_timestep in additional_timesteps:
if model_number in train_indices:
batch_size = samples_per_train_model * 2
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really understand why we're multiplying samples_per_train_model by 2 here to get the batch size (same with the validation set up). The batch size is very heavily tied to the overall dataset size and will throw an error if not set correctly. The way it's setup now, if you don't set samples_per_train_model just right, it will break.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's say the number of samples is 100. Then we take 100 samples from the training points (members) used to train the model and 100 samples from the rest of the data available to the adversary. This is why we multiply it by 2.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see. That makes sense in theory. I don't think that's what actually happening in practice for the code though.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay. I think I'm starting too see this now.

Copy link
Collaborator

@lotif lotif left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks a lot better, thanks for addressing the earlier comments.

Approving with mostly minor comments. If you can't address them right now, you can merge and address them in a follow up PR.

Returns:
Filtered dataframe reading for classifier training (or testing)
"""
raw_data = pd.read_csv(model_dir / "train_with_id.csv")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to the lines above, we should make this into a parameter with the default value being a constant with value "train_with_id.csv".

Args:
model_dir: Model directory from which to load data. This directory must contain a file named
"train_with_id.csv" and "data_for_training_MIA.csv"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The name of the MIA dataset is defined by the mia_dataset_name parameter, though.

Maybe we could make a constant on this file with the name of these datasets so we don't need to be hardcoding them in multiple places? The default value of that parameter can also be set to this constant.

input_noise,
model_type,
meta_dir=meta_dir,
challenge_name="challenge_with_id.csv",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be put into a constant and replaced where it's being hardcoded.

@bzamanlooy bzamanlooy merged commit d8619b1 into main Dec 12, 2025
6 of 7 checks passed
@bzamanlooy bzamanlooy deleted the bz/tf branch December 12, 2025 03:37
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

9 participants