Commit 3f6d0ac

Add the ability to use sample weights in the membership attack models, where they are supported by the underlying scikit-learn estimators. Only the logistic regression and random forest estimators support sample weights.
PiperOrigin-RevId: 478542133
1 parent feddd28 commit 3f6d0ac

15 files changed: +552 -78 lines
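For context, both supported estimators expose per-example weights through their scikit-learn `fit` methods. The sketch below is illustrative only, not this library's trainer code; `features`, `membership_labels`, and `weights` are made-up names:

# Illustrative sketch only, not the library's code: both supported
# scikit-learn estimators accept per-example weights in fit().
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression

features = np.random.rand(100, 1)                 # e.g. per-example losses
membership_labels = np.random.randint(0, 2, 100)  # 1 = training member
weights = np.random.rand(100)                     # per-example sample weights

for clf in (LogisticRegression(), RandomForestClassifier(n_estimators=10)):
  clf.fit(features, membership_labels, sample_weight=weights)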

tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/advanced_mia.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
"""Functions for advanced membership inference attacks."""
1515

1616
import functools
17-
from typing import Sequence, Union
17+
from typing import Optional, Sequence, Union
1818
import numpy as np
1919
import scipy.stats
2020
from tensorflow_privacy.privacy.privacy_tests.utils import log_loss
@@ -197,6 +197,7 @@ def convert_logit_to_prob(logit: np.ndarray) -> np.ndarray:
197197

198198
def calculate_statistic(pred: np.ndarray,
199199
labels: np.ndarray,
200+
sample_weight: Optional[np.ndarray] = None,
200201
is_logits: bool = True,
201202
option: str = 'logit',
202203
small_value: float = 1e-45):
@@ -215,6 +216,10 @@ def calculate_statistic(pred: np.ndarray,
215216
An array of size n by c where n is the number of samples and c is the
216217
number of classes
217218
labels: true labels of samples (integer valued)
219+
sample_weight: a vector of weights of shape (num_samples, ) that are
220+
assigned to individual samples. If not provided, then each sample is
221+
given unit weight. Only the LogisticRegressionAttacker and the
222+
RandomForestAttacker support sample weights.
218223
is_logits: whether pred is logits or probability vectors
219224
option: confidence using probability, xe loss, logit of confidence,
220225
confidence using logits, hinge loss
@@ -241,7 +246,7 @@ def calculate_statistic(pred: np.ndarray,
241246
if option in ['conf with prob', 'conf with logit']:
242247
return pred[range(n), labels]
243248
if option == 'xe':
244-
return log_loss(labels, pred)
249+
return log_loss(labels, pred, sample_weight=sample_weight)
245250
if option == 'logit':
246251
p_true = pred[range(n), labels]
247252
pred[range(n), labels] = 0
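A hedged usage sketch of the new signature, with inputs copied from the unit tests below. In the hunks shown, only the 'xe' branch consumes the weights, scaling each per-example cross-entropy loss:

# Sketch of the new calculate_statistic signature; inputs mirror the tests.
import numpy as np
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import advanced_mia as amia

logits = np.array([[1., 2., -3.], [-1., 1., 0.]])
labels = np.array([1, 2])
sample_weight = np.array([1.0, 0.5])

# Each per-example cross-entropy loss is scaled by its weight, so the
# second value is halved relative to the unweighted statistic.
stat = amia.calculate_statistic(
    logits, labels, sample_weight=sample_weight, is_logits=True, option='xe')
print(stat)  # approximately [0.31817543, 0.70380298]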

tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/advanced_mia_example.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
import functools
1717
import gc
1818
import os
19+
from typing import Optional
20+
1921
from absl import app
2022
from absl import flags
2123
import matplotlib.pyplot as plt
@@ -69,7 +71,11 @@ def plot_curve_with_area(x, y, xlabel, ylabel, ax, label, title=None):
6971
ax.title.set_text(title)
7072

7173

72-
def get_stat_and_loss_aug(model, x, y, batch_size=4096):
74+
def get_stat_and_loss_aug(model,
75+
x,
76+
y,
77+
sample_weight: Optional[np.ndarray] = None,
78+
batch_size=4096):
7379
"""A helper function to get the statistics and losses.
7480
7581
Here we get the statistics and losses for the original and
@@ -80,6 +86,10 @@ def get_stat_and_loss_aug(model, x, y, batch_size=4096):
8086
model: model to make prediction
8187
x: samples
8288
y: true labels of samples (integer valued)
89+
sample_weight: a vector of weights of shape (n_samples, ) that are
90+
assigned to individual samples. If not provided, then each sample is
91+
given unit weight. Only the LogisticRegressionAttacker and the
92+
RandomForestAttacker support sample weights.
8393
batch_size: the batch size for model.predict
8494
8595
Returns:
@@ -89,8 +99,10 @@ def get_stat_and_loss_aug(model, x, y, batch_size=4096):
8999
for data in [x, x[:, :, ::-1, :]]:
90100
prob = amia.convert_logit_to_prob(
91101
model.predict(data, batch_size=batch_size))
92-
losses.append(utils.log_loss(y, prob))
93-
stat.append(amia.calculate_statistic(prob, y, convert_to_prob=False))
102+
losses.append(utils.log_loss(y, prob, sample_weight=sample_weight))
103+
stat.append(
104+
amia.calculate_statistic(
105+
prob, y, sample_weight=sample_weight, convert_to_prob=False))
94106
return np.vstack(stat).transpose(1, 0), np.vstack(losses).transpose(1, 0)
95107

96108

@@ -103,6 +115,8 @@ def main(unused_argv):
103115

104116
# Load data.
105117
x, y = load_cifar10()
118+
# Sample weights are set to `None` by default, but can be changed here.
119+
sample_weight = None
106120
n = x.shape[0]
107121

108122
# Train the target and shadow models. We will use one of the model in `models`
@@ -144,7 +158,7 @@ def main(unused_argv):
144158
print(f'Trained model #{i} with {in_indices[-1].sum()} examples.')
145159

146160
# Get the statistics of the current model.
147-
s, l = get_stat_and_loss_aug(model, x, y)
161+
s, l = get_stat_and_loss_aug(model, x, y, sample_weight)
148162
stat.append(s)
149163
losses.append(l)
150164

@@ -175,7 +189,9 @@ def main(unused_argv):
175189
stat_target, stat_in, stat_out, fix_variance=True)
176190
attack_input = AttackInputData(
177191
loss_train=scores[in_indices_target],
178-
loss_test=scores[~in_indices_target])
192+
loss_test=scores[~in_indices_target],
193+
sample_weight_train=sample_weight,
194+
sample_weight_test=sample_weight)
179195
result_lira = mia.run_attacks(attack_input).single_attack_results[0]
180196
print('Advanced MIA attack with Gaussian:',
181197
f'auc = {result_lira.get_auc():.4f}',
@@ -187,7 +203,9 @@ def main(unused_argv):
187203
scores = -amia.compute_score_offset(stat_target, stat_in, stat_out)
188204
attack_input = AttackInputData(
189205
loss_train=scores[in_indices_target],
190-
loss_test=scores[~in_indices_target])
206+
loss_test=scores[~in_indices_target],
207+
sample_weight_train=sample_weight,
208+
sample_weight_test=sample_weight)
191209
result_offset = mia.run_attacks(attack_input).single_attack_results[0]
192210
print('Advanced MIA attack with offset:',
193211
f'auc = {result_offset.get_auc():.4f}',
@@ -197,7 +215,9 @@ def main(unused_argv):
197215
loss_target = losses[idx][:, 0]
198216
attack_input = AttackInputData(
199217
loss_train=loss_target[in_indices_target],
200-
loss_test=loss_target[~in_indices_target])
218+
loss_test=loss_target[~in_indices_target],
219+
sample_weight_train=sample_weight,
220+
sample_weight_test=sample_weight)
201221
result_baseline = mia.run_attacks(attack_input).single_attack_results[0]
202222
print('Baseline MIA attack:', f'auc = {result_baseline.get_auc():.4f}',
203223
f'adv = {result_baseline.get_attacker_advantage():.4f}')
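In the example script `sample_weight` stays `None`, so passing the same (null) value to both `sample_weight_train` and `sample_weight_test` is safe. With a real weight vector it would presumably need slicing by membership so shapes line up, as in this hedged sketch with stand-in arrays (not the script's data):

# Hedged sketch with stand-in data; with non-null weights, slice them the
# same way the scores are sliced by the membership mask.
import numpy as np
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import membership_inference_attack as mia
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData

scores = np.random.rand(100)                   # stand-in for the MIA scores
in_indices_target = np.random.rand(100) < 0.5  # stand-in membership mask
sample_weight = np.ones(100)                   # unit weights, for illustration

attack_input = AttackInputData(
    loss_train=scores[in_indices_target],
    loss_test=scores[~in_indices_target],
    sample_weight_train=sample_weight[in_indices_target],
    sample_weight_test=sample_weight[~in_indices_target])
result = mia.run_attacks(attack_input).single_attack_results[0]
print(f'auc = {result.get_auc():.4f}')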

tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/advanced_mia_test.py

Lines changed: 67 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -158,19 +158,21 @@ def test_calculate_statistic_logit(self):
158158
# [0.09003057, 0.66524096, 0.24472847]])
159159
labels = np.array([1, 2])
160160

161-
stat = amia.calculate_statistic(logit, labels, is_logits, 'conf with prob')
161+
stat = amia.calculate_statistic(logit, labels, None, is_logits,
162+
'conf with prob')
162163
np.testing.assert_allclose(stat, np.array([0.72747516, 0.24472847]))
163164

164-
stat = amia.calculate_statistic(logit, labels, is_logits, 'xe')
165+
stat = amia.calculate_statistic(logit, labels, None, is_logits, 'xe')
165166
np.testing.assert_allclose(stat, np.array([0.31817543, 1.40760596]))
166167

167-
stat = amia.calculate_statistic(logit, labels, is_logits, 'logit')
168+
stat = amia.calculate_statistic(logit, labels, None, is_logits, 'logit')
168169
np.testing.assert_allclose(stat, np.array([0.98185009, -1.12692802]))
169170

170-
stat = amia.calculate_statistic(logit, labels, is_logits, 'conf with logit')
171+
stat = amia.calculate_statistic(logit, labels, None, is_logits,
172+
'conf with logit')
171173
np.testing.assert_allclose(stat, np.array([2, 0.]))
172174

173-
stat = amia.calculate_statistic(logit, labels, is_logits, 'hinge')
175+
stat = amia.calculate_statistic(logit, labels, None, is_logits, 'hinge')
174176
np.testing.assert_allclose(stat, np.array([1, -1.]))
175177

176178
def test_calculate_statistic_prob(self):
@@ -179,19 +181,74 @@ def test_calculate_statistic_prob(self):
179181
prob = np.array([[0.1, 0.85, 0.05], [0.1, 0.5, 0.4]])
180182
labels = np.array([1, 2])
181183

182-
stat = amia.calculate_statistic(prob, labels, is_logits, 'conf with prob')
184+
stat = amia.calculate_statistic(prob, labels, None, is_logits,
185+
'conf with prob')
183186
np.testing.assert_allclose(stat, np.array([0.85, 0.4]))
184187

185-
stat = amia.calculate_statistic(prob, labels, is_logits, 'xe')
188+
stat = amia.calculate_statistic(prob, labels, None, is_logits, 'xe')
186189
np.testing.assert_allclose(stat, np.array([0.16251893, 0.91629073]))
187190

188-
stat = amia.calculate_statistic(prob, labels, is_logits, 'logit')
191+
stat = amia.calculate_statistic(prob, labels, None, is_logits, 'logit')
189192
np.testing.assert_allclose(stat, np.array([1.73460106, -0.40546511]))
190193

191194
np.testing.assert_raises(ValueError, amia.calculate_statistic, prob, labels,
192-
is_logits, 'conf with logit')
195+
None, is_logits, 'conf with logit')
193196
np.testing.assert_raises(ValueError, amia.calculate_statistic, prob, labels,
194-
is_logits, 'hinge')
197+
None, is_logits, 'hinge')
198+
199+
def test_calculate_statistic_logit_with_sample_weights(self):
200+
"""Test calculate_statistic with input as logit."""
201+
is_logits = True
202+
logit = np.array([[1, 2, -3.], [-1, 1, 0]])
203+
# expected probability vector
204+
# array([[0.26762315, 0.72747516, 0.00490169],
205+
# [0.09003057, 0.66524096, 0.24472847]])
206+
labels = np.array([1, 2])
207+
sample_weight = np.array([1.0, 0.5])
208+
209+
stat = amia.calculate_statistic(logit, labels, sample_weight, is_logits,
210+
'conf with prob')
211+
np.testing.assert_allclose(stat, np.array([0.72747516, 0.24472847]))
212+
213+
stat = amia.calculate_statistic(logit, labels, sample_weight, is_logits,
214+
'xe')
215+
np.testing.assert_allclose(stat, np.array([0.31817543, 0.70380298]))
216+
217+
stat = amia.calculate_statistic(logit, labels, sample_weight, is_logits,
218+
'logit')
219+
np.testing.assert_allclose(stat, np.array([0.98185009, -1.12692802]))
220+
221+
stat = amia.calculate_statistic(logit, labels, sample_weight, is_logits,
222+
'conf with logit')
223+
np.testing.assert_allclose(stat, np.array([2, 0.]))
224+
225+
stat = amia.calculate_statistic(logit, labels, sample_weight, is_logits,
226+
'hinge')
227+
np.testing.assert_allclose(stat, np.array([1, -1.]))
228+
229+
def test_calculate_statistic_prob_with_sample_weights(self):
230+
"""Test calculate_statistic with input as probability vector."""
231+
is_logits = False
232+
prob = np.array([[0.1, 0.85, 0.05], [0.1, 0.5, 0.4]])
233+
labels = np.array([1, 2])
234+
sample_weight = np.array([1.0, 0.5])
235+
236+
stat = amia.calculate_statistic(prob, labels, sample_weight, is_logits,
237+
'conf with prob')
238+
np.testing.assert_allclose(stat, np.array([0.85, 0.4]))
239+
240+
stat = amia.calculate_statistic(prob, labels, sample_weight, is_logits,
241+
'xe')
242+
np.testing.assert_allclose(stat, np.array([0.16251893, 0.458145365]))
243+
244+
stat = amia.calculate_statistic(prob, labels, sample_weight, is_logits,
245+
'logit')
246+
np.testing.assert_allclose(stat, np.array([1.73460106, -0.40546511]))
247+
248+
np.testing.assert_raises(ValueError, amia.calculate_statistic, prob, labels,
249+
None, is_logits, 'conf with logit')
250+
np.testing.assert_raises(ValueError, amia.calculate_statistic, prob, labels,
251+
None, is_logits, 'hinge')
195252

196253

197254
if __name__ == '__main__':
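Sanity check on the weighted 'xe' expectations above: each weighted value is the unweighted per-example loss from the earlier tests multiplied by its weight, e.g. 1.40760596 × 0.5 = 0.70380298.

# Quick arithmetic check of the weighted 'xe' test values (plain numpy).
import numpy as np

weights = np.array([1.0, 0.5])
np.testing.assert_allclose(
    np.array([0.31817543, 1.40760596]) * weights,  # unweighted logit-test xe
    [0.31817543, 0.70380298])
np.testing.assert_allclose(
    np.array([0.16251893, 0.91629073]) * weights,  # unweighted prob-test xe
    [0.16251893, 0.458145365])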

tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/data_structures.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import logging
2121
import os
2222
import pickle
23-
from typing import Any, Callable, Iterable, MutableSequence, Optional, Union
23+
from typing import Any, Iterable, MutableSequence, Optional, Union
2424

2525
import numpy as np
2626
import pandas as pd
@@ -203,6 +203,10 @@ class AttackInputData:
203203
labels_train: Optional[np.ndarray] = None
204204
labels_test: Optional[np.ndarray] = None
205205

206+
# Sample weights, if provided.
207+
sample_weight_train: Optional[np.ndarray] = None
208+
sample_weight_test: Optional[np.ndarray] = None
209+
206210
# Explicitly specified loss. If provided, this is used instead of deriving
207211
# loss from logits and labels
208212
loss_train: Optional[np.ndarray] = None
@@ -219,8 +223,7 @@ class AttackInputData:
219223
# string representation, or a callable.
220224
# If a callable is provided, it should take in two argument, the 1st is
221225
# labels, the 2nd is logits or probs.
222-
loss_function: Union[Callable[[np.ndarray, np.ndarray], np.ndarray], str,
223-
utils.LossFunction] = utils.LossFunction.CROSS_ENTROPY
226+
loss_function: utils.LossFunctionCallable = utils.LossFunction.CROSS_ENTROPY
224227
# Whether `loss_function` will be called with logits or probs. If not set
225228
# (None), will decide by availablity of logits and probs and logits is
226229
# preferred when both are available.
@@ -309,7 +312,8 @@ def get_loss_train(self):
309312
self.loss_function_using_logits = (self.logits_train is not None)
310313
return utils.get_loss(self.loss_train, self.labels_train, self.logits_train,
311314
self.probs_train, self.loss_function,
312-
self.loss_function_using_logits, self.multilabel_data)
315+
self.loss_function_using_logits, self.multilabel_data,
316+
self.sample_weight_train)
313317

314318
def get_loss_test(self):
315319
"""Calculates (if needed) cross-entropy losses for the test set.
@@ -321,7 +325,8 @@ def get_loss_test(self):
321325
self.loss_function_using_logits = bool(self.logits_test)
322326
return utils.get_loss(self.loss_test, self.labels_test, self.logits_test,
323327
self.probs_test, self.loss_function,
324-
self.loss_function_using_logits, self.multilabel_data)
328+
self.loss_function_using_logits, self.multilabel_data,
329+
self.sample_weight_test)
325330

326331
def get_entropy_train(self):
327332
"""Calculates prediction entropy for the training set."""
@@ -367,6 +372,11 @@ def get_test_size(self):
367372
"""Returns the number of examples of the test set."""
368373
return self.get_test_shape()[0]
369374

375+
def has_nonnull_sample_weights(self):
376+
"""Whether both the train and test input data have sample weights."""
377+
return (self.sample_weight_train is not None and
378+
self.sample_weight_test is not None)
379+
370380
def is_multihot_labels(self, arr, arr_name) -> bool:
371381
"""Check if the 2D array is multihot, with values in [0, 1].
372382
@@ -556,6 +566,8 @@ def __str__(self):
556566
_append_array_shape(self.probs_test, 'probs_test', result)
557567
_append_array_shape(self.labels_train, 'labels_train', result)
558568
_append_array_shape(self.labels_test, 'labels_test', result)
569+
_append_array_shape(self.sample_weight_train, 'sample_weight_train', result)
570+
_append_array_shape(self.sample_weight_test, 'sample_weight_test', result)
559571
result.append(')')
560572
return '\n'.join(result)
561573
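A minimal sketch of the new fields and helper on AttackInputData (illustrative values):

# Minimal sketch of the new AttackInputData fields (illustrative values).
import numpy as np
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData

attack_input = AttackInputData(
    loss_train=np.array([0.1, 0.2]),
    loss_test=np.array([0.3, 0.4]),
    sample_weight_train=np.array([1.0, 0.5]),
    sample_weight_test=np.array([1.0, 1.0]))
assert attack_input.has_nonnull_sample_weights()  # both weight fields set
print(attack_input)  # __str__ now reports the sample weight shapes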
