Commit cff4768

Changes DPOptimizerClass to generically accept and use any dp_sum_query.

This enables creation of generic DPOptimizers by users passing their own queries. The most common case, the Gaussian query, is still constructed automatically for convenience and backwards compatibility. Byproducts of this update:
- ensures consistent implementations between the internal (and legacy) `get_gradients` and the newer `_compute_gradients` for all queries
- refactors for Python readability

PiperOrigin-RevId: 470883774

1 parent ed16033 · commit cff4768
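
To illustrate the new construction path described above, here is a minimal sketch (not part of the commit; the hyperparameter values are placeholders) comparing the generic-query path with the backwards-compatible one:

    # Sketch only: assumes a tensorflow_privacy build that includes this commit.
    from tensorflow_privacy.privacy.dp_query import gaussian_query
    from tensorflow_privacy.privacy.optimizers import dp_optimizer_keras

    # New path: build any DPQuery and pass it to a generic DP optimizer.
    query = gaussian_query.GaussianSumQuery(l2_norm_clip=1.0, stddev=0.5)
    opt = dp_optimizer_keras.GenericDPSGDOptimizer(
        dp_sum_query=query,
        num_microbatches=1,
        learning_rate=0.1)

    # Backwards-compatible path: the Gaussian query is built internally from
    # l2_norm_clip and noise_multiplier (stddev = l2_norm_clip * noise_multiplier).
    legacy_opt = dp_optimizer_keras.DPKerasSGDOptimizer(
        l2_norm_clip=1.0,
        noise_multiplier=0.5,
        num_microbatches=1,
        learning_rate=0.1)

With these placeholder values the two optimizers apply the same Gaussian sum query; the generic path simply makes the query explicit and swappable.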

4 files changed: +296, -201 lines

tensorflow_privacy/__init__.py

Lines changed: 5 additions & 0 deletions

@@ -61,9 +61,14 @@
 from tensorflow_privacy.privacy.keras_models.dp_keras_model import make_dp_model_class

 # Optimizers
+from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import GenericDPAdagradOptimizer
+from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import GenericDPAdamOptimizer
+from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import GenericDPSGDOptimizer
 from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasAdagradOptimizer
 from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasAdamOptimizer
 from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasSGDOptimizer
+from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import make_gaussian_query_optimizer_class
+from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import make_keras_generic_optimizer_class
 from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import make_keras_optimizer_class

 from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras_vectorized import VectorizedDPKerasAdagradOptimizer
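
Per the import additions above, the new symbols are also re-exported at the package top level. A quick sketch (assuming the package is installed from a revision containing this commit):

    import tensorflow_privacy as tfp  # illustrative alias

    # Added to the public surface by this commit:
    new_symbols = (tfp.GenericDPSGDOptimizer,
                   tfp.make_keras_generic_optimizer_class,
                   tfp.make_gaussian_query_optimizer_class)

    # Pre-existing names remain importable unchanged:
    legacy_symbol = tfp.DPKerasSGDOptimizer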

tensorflow_privacy/privacy/optimizers/BUILD

Lines changed: 18 additions & 3 deletions

@@ -18,6 +18,18 @@ py_library(
     deps = ["//tensorflow_privacy/privacy/dp_query:gaussian_query"],
 )

+py_library(
+    name = "dp_optimizer_factory",
+    srcs = [
+        "dp_optimizer_keras.py",
+    ],
+    srcs_version = "PY3",
+    deps = [
+        "//tensorflow_privacy/privacy/dp_query",
+        "//tensorflow_privacy/privacy/dp_query:gaussian_query",
+    ],
+)
+
 py_library(
     name = "dp_optimizer_vectorized",
     srcs = [
@@ -32,7 +44,10 @@ py_library(
         "dp_optimizer_keras.py",
     ],
     srcs_version = "PY3",
-    deps = ["//tensorflow_privacy/privacy/dp_query:gaussian_query"],
+    deps = [
+        "//tensorflow_privacy/privacy/dp_query",
+        "//tensorflow_privacy/privacy/dp_query:gaussian_query",
+    ],
 )

 py_library(
@@ -84,7 +99,7 @@ py_test(
     python_version = "PY3",
     srcs_version = "PY3",
     deps = [
-        "//tensorflow_privacy/privacy/optimizers:dp_optimizer_keras",
-        "//tensorflow_privacy/privacy/optimizers:dp_optimizer_keras_vectorized",
+        ":dp_optimizer_keras",
+        ":dp_optimizer_keras_vectorized",
     ],
 )

tensorflow_privacy/privacy/optimizers/dp_optimizer_keras.py

Lines changed: 151 additions & 73 deletions

@@ -13,21 +13,28 @@
 # limitations under the License.
 # ==============================================================================
 """Differentially private version of Keras optimizer v2."""
+from typing import Optional, Type
+import warnings

 import tensorflow as tf
-
+from tensorflow_privacy.privacy.dp_query import dp_query
 from tensorflow_privacy.privacy.dp_query import gaussian_query


-def make_keras_optimizer_class(cls):
-  """Given a subclass of `tf.keras.optimizers.Optimizer`, returns a DP-SGD subclass of it.
+def _normalize(microbatch_gradient: tf.Tensor,
+               num_microbatches: float) -> tf.Tensor:
+  """Normalizes `microbatch_gradient` by `num_microbatches`."""
+  return tf.truediv(microbatch_gradient,
+                    tf.cast(num_microbatches, microbatch_gradient.dtype))
+
+
+def make_keras_generic_optimizer_class(
+    cls: Type[tf.keras.optimizers.Optimizer]):
+  """Returns a differentially private (DP) subclass of `cls`.

   Args:
     cls: Class from which to derive a DP subclass. Should be a subclass of
       `tf.keras.optimizers.Optimizer`.
-
-  Returns:
-    A DP-SGD subclass of `cls`.
   """

   class DPOptimizerClass(cls):  # pylint: disable=empty-docstring
@@ -138,24 +145,23 @@ class DPOptimizerClass(cls):  # pylint: disable=empty-docstring

     def __init__(
         self,
-        l2_norm_clip,
-        noise_multiplier,
-        num_microbatches=None,
-        gradient_accumulation_steps=1,
+        dp_sum_query: dp_query.DPQuery,
+        num_microbatches: Optional[int] = None,
+        gradient_accumulation_steps: int = 1,
         *args,  # pylint: disable=keyword-arg-before-vararg, g-doc-args
         **kwargs):
-      """Initialize the DPOptimizerClass.
+      """Initializes the DPOptimizerClass.

       Args:
-        l2_norm_clip: Clipping norm (max L2 norm of per microbatch gradients).
-        noise_multiplier: Ratio of the standard deviation to the clipping norm.
+        dp_sum_query: `DPQuery` object, specifying differential privacy
+          mechanism to use.
         num_microbatches: Number of microbatches into which each minibatch is
-          split. Default is `None` which means that number of microbatches
-          is equal to batch size (i.e. each microbatch contains exactly one
+          split. Default is `None` which means that number of microbatches is
+          equal to batch size (i.e. each microbatch contains exactly one
          example). If `gradient_accumulation_steps` is greater than 1 and
          `num_microbatches` is not `None` then the effective number of
-          microbatches is equal to
-          `num_microbatches * gradient_accumulation_steps`.
+          microbatches is equal to `num_microbatches *
+          gradient_accumulation_steps`.
        gradient_accumulation_steps: If greater than 1 then optimizer will be
          accumulating gradients for this number of optimizer steps before
          applying them to update model weights. If this argument is set to 1
@@ -165,13 +171,13 @@ def __init__(
       """
       super().__init__(*args, **kwargs)
       self.gradient_accumulation_steps = gradient_accumulation_steps
-      self._l2_norm_clip = l2_norm_clip
-      self._noise_multiplier = noise_multiplier
       self._num_microbatches = num_microbatches
-      self._dp_sum_query = gaussian_query.GaussianSumQuery(
-          l2_norm_clip, l2_norm_clip * noise_multiplier)
-      self._global_state = None
+      self._dp_sum_query = dp_sum_query
       self._was_dp_gradients_called = False
+      # We initialize the self.`_global_state` within the gradient functions
+      # (and not here) because tensors must be initialized within the graph.
+
+      self._global_state = None

     def _create_slots(self, var_list):
       super()._create_slots(var_list)  # pytype: disable=attribute-error
@@ -235,66 +241,62 @@ def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None):
       """DP-SGD version of base class method."""

       self._was_dp_gradients_called = True
+      if self._global_state is None:
+        self._global_state = self._dp_sum_query.initial_global_state()
+
       # Compute loss.
       if not callable(loss) and tape is None:
         raise ValueError('`tape` is required when a `Tensor` loss is passed.')
+
       tape = tape if tape is not None else tf.GradientTape()

-      if callable(loss):
-        with tape:
+      with tape:
+        if callable(loss):
           if not callable(var_list):
             tape.watch(var_list)

           loss = loss()
-          if self._num_microbatches is None:
-            num_microbatches = tf.shape(input=loss)[0]
-          else:
-            num_microbatches = self._num_microbatches
-          microbatch_losses = tf.reduce_mean(
-              tf.reshape(loss, [num_microbatches, -1]), axis=1)
-
-          if callable(var_list):
-            var_list = var_list()
-      else:
-        with tape:
-          if self._num_microbatches is None:
-            num_microbatches = tf.shape(input=loss)[0]
-          else:
-            num_microbatches = self._num_microbatches
-          microbatch_losses = tf.reduce_mean(
-              tf.reshape(loss, [num_microbatches, -1]), axis=1)
+        if self._num_microbatches is None:
+          num_microbatches = tf.shape(input=loss)[0]
+        else:
+          num_microbatches = self._num_microbatches
+        microbatch_losses = tf.reduce_mean(
+            tf.reshape(loss, [num_microbatches, -1]), axis=1)
+
+        if callable(var_list):
+          var_list = var_list()

       var_list = tf.nest.flatten(var_list)

+      sample_params = (
+          self._dp_sum_query.derive_sample_params(self._global_state))
+
       # Compute the per-microbatch losses using helpful jacobian method.
       with tf.keras.backend.name_scope(self._name + '/gradients'):
-        jacobian = tape.jacobian(
+        jacobian_per_var = tape.jacobian(
            microbatch_losses, var_list, unconnected_gradients='zero')

-        # Clip gradients to given l2_norm_clip.
-        def clip_gradients(g):
-          return tf.clip_by_global_norm(g, self._l2_norm_clip)[0]
+        def process_microbatch(sample_state, microbatch_jacobians):
+          """Process one microbatch (record) with privacy helper."""
+          sample_state = self._dp_sum_query.accumulate_record(
+              sample_params, sample_state, microbatch_jacobians)
+          return sample_state

-        clipped_gradients = tf.map_fn(clip_gradients, jacobian)
+        sample_state = self._dp_sum_query.initial_sample_state(var_list)
+        for idx in range(num_microbatches):
+          microbatch_jacobians_per_var = [
+              jacobian[idx] for jacobian in jacobian_per_var
+          ]
+          sample_state = process_microbatch(sample_state,
+                                            microbatch_jacobians_per_var)

-        def reduce_noise_normalize_batch(g):
-          # Sum gradients over all microbatches.
-          summed_gradient = tf.reduce_sum(g, axis=0)
+        grad_sums, self._global_state, _ = (
+            self._dp_sum_query.get_noised_result(sample_state,
+                                                 self._global_state))
+        final_grads = tf.nest.map_structure(_normalize, grad_sums,
+                                            [num_microbatches] * len(grad_sums))

-          # Add noise to summed gradients.
-          noise_stddev = self._l2_norm_clip * self._noise_multiplier
-          noise = tf.random.normal(
-              tf.shape(input=summed_gradient), stddev=noise_stddev)
-          noised_gradient = tf.add(summed_gradient, noise)
-
-          # Normalize by number of microbatches and return.
-          return tf.truediv(noised_gradient,
-                            tf.cast(num_microbatches, tf.float32))
-
-        final_gradients = tf.nest.map_structure(reduce_noise_normalize_batch,
-                                                clipped_gradients)
-
-      return list(zip(final_gradients, var_list))
+      return list(zip(final_grads, var_list))

     def get_gradients(self, loss, params):
       """DP-SGD version of base class method."""
@@ -322,17 +324,13 @@ def process_microbatch(i, sample_state):
       sample_state = self._dp_sum_query.initial_sample_state(params)
       for idx in range(self._num_microbatches):
         sample_state = process_microbatch(idx, sample_state)
+
       grad_sums, self._global_state, _ = (
           self._dp_sum_query.get_noised_result(sample_state,
                                                self._global_state))

-      def normalize(v):
-        try:
-          return tf.truediv(v, tf.cast(self._num_microbatches, tf.float32))
-        except TypeError:
-          return None
-
-      final_grads = tf.nest.map_structure(normalize, grad_sums)
+      final_grads = tf.nest.map_structure(
+          _normalize, grad_sums, [self._num_microbatches] * len(grad_sums))

       return final_grads

@@ -368,7 +366,87 @@ def apply_gradients(self, *args, **kwargs):
   return DPOptimizerClass


-DPKerasAdagradOptimizer = make_keras_optimizer_class(
+def make_gaussian_query_optimizer_class(cls):
+  """Returns a differentially private optimizer using the `GaussianSumQuery`.
+
+  Args:
+    cls: `DPOptimizerClass`, the output of `make_keras_optimizer_class`.
+  """
+
+  def return_gaussian_query_optimizer(
+      l2_norm_clip: float,
+      noise_multiplier: float,
+      num_microbatches: Optional[int] = None,
+      gradient_accumulation_steps: int = 1,
+      *args,  # pylint: disable=keyword-arg-before-vararg, g-doc-args
+      **kwargs):
+    """Returns a `DPOptimizerClass` `cls` using the `GaussianSumQuery`.
+
+    This function is a thin wrapper around
+    `make_keras_optimizer_class.<locals>.DPOptimizerClass` which can be used to
+    apply a `GaussianSumQuery` to any `DPOptimizerClass`.
+
+    When combined with stochastic gradient descent, this creates the canonical
+    DP-SGD algorithm of "Deep Learning with Differential Privacy"
+    (see https://arxiv.org/abs/1607.00133).
+
+    Args:
+      l2_norm_clip: Clipping norm (max L2 norm of per microbatch gradients).
+      noise_multiplier: Ratio of the standard deviation to the clipping norm.
+      num_microbatches: Number of microbatches into which each minibatch is
+        split. Default is `None` which means that number of microbatches is
+        equal to batch size (i.e. each microbatch contains exactly one example).
+        If `gradient_accumulation_steps` is greater than 1 and
+        `num_microbatches` is not `None` then the effective number of
+        microbatches is equal to `num_microbatches *
+        gradient_accumulation_steps`.
+      gradient_accumulation_steps: If greater than 1 then optimizer will be
+        accumulating gradients for this number of optimizer steps before
+        applying them to update model weights. If this argument is set to 1 then
+        updates will be applied on each optimizer step.
+      *args: These will be passed on to the base class `__init__` method.
+      **kwargs: These will be passed on to the base class `__init__` method.
+    """
+    dp_sum_query = gaussian_query.GaussianSumQuery(
+        l2_norm_clip, l2_norm_clip * noise_multiplier)
+    return cls(
+        dp_sum_query=dp_sum_query,
+        num_microbatches=num_microbatches,
+        gradient_accumulation_steps=gradient_accumulation_steps,
+        *args,
+        **kwargs)
+
+  return return_gaussian_query_optimizer
+
+
+def make_keras_optimizer_class(cls: Type[tf.keras.optimizers.Optimizer]):
+  """Returns a differentially private optimizer using the `GaussianSumQuery`.
+
+  For backwards compatibility, we create this symbol to match the previous
+  output of `make_keras_optimizer_class` but using the new logic.
+
+  Args:
+    cls: Class from which to derive a DP subclass. Should be a subclass of
+      `tf.keras.optimizers.Optimizer`.
+  """
+  warnings.warn(
+      '`make_keras_optimizer_class` will be depracated on 2023-02-23. '
+      'Please switch to `make_gaussian_query_optimizer_class` and the '
+      'generic optimizers (`make_keras_generic_optimizer_class`).')
+  return make_gaussian_query_optimizer_class(
+      make_keras_generic_optimizer_class(cls))
+
+
+GenericDPAdagradOptimizer = make_keras_generic_optimizer_class(
    tf.keras.optimizers.Adagrad)
-DPKerasAdamOptimizer = make_keras_optimizer_class(tf.keras.optimizers.Adam)
-DPKerasSGDOptimizer = make_keras_optimizer_class(tf.keras.optimizers.SGD)
+GenericDPAdamOptimizer = make_keras_generic_optimizer_class(
+    tf.keras.optimizers.Adam)
+GenericDPSGDOptimizer = make_keras_generic_optimizer_class(
+    tf.keras.optimizers.SGD)
+
+# We keep the same names for backwards compatibility.
+DPKerasAdagradOptimizer = make_gaussian_query_optimizer_class(
+    GenericDPAdagradOptimizer)
+DPKerasAdamOptimizer = make_gaussian_query_optimizer_class(
+    GenericDPAdamOptimizer)
+DPKerasSGDOptimizer = make_gaussian_query_optimizer_class(GenericDPSGDOptimizer)
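
For readers less familiar with the `DPQuery` interface that both gradient paths now delegate to, here is a small standalone sketch (not from the commit; the gradients and parameters are toy placeholders) of the clip, accumulate, noise, and normalize flow that replaces the previous hand-rolled clipping and noising:

    # Sketch only: mirrors the query calls used by the refactored optimizer.
    import tensorflow as tf
    from tensorflow_privacy.privacy.dp_query import gaussian_query

    # Toy per-microbatch gradients for one variable (placeholders).
    microbatch_grads = [tf.constant([3.0, 4.0]), tf.constant([0.5, 0.0])]

    query = gaussian_query.GaussianSumQuery(l2_norm_clip=1.0, stddev=0.5)
    global_state = query.initial_global_state()
    sample_params = query.derive_sample_params(global_state)

    # Each record is clipped as it is accumulated, as in `process_microbatch`.
    sample_state = query.initial_sample_state(tf.constant([0.0, 0.0]))
    for grad in microbatch_grads:
      sample_state = query.accumulate_record(sample_params, sample_state, grad)

    # Noise the clipped sum, then normalize by the number of microbatches,
    # mirroring `_normalize` above; in the revision shown in this diff,
    # `get_noised_result` returns a three-tuple, matching this unpacking.
    noised_sum, global_state, _ = query.get_noised_result(sample_state, global_state)
    dp_grad = noised_sum / len(microbatch_grads)

Swapping `GaussianSumQuery` for any other `DPQuery` changes the privacy mechanism without touching the optimizer code, which is the point of the generic `dp_sum_query` argument.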
