@@ -13,28 +13,21 @@
 # limitations under the License.
 # ==============================================================================
 """Differentially private version of Keras optimizer v2."""
-from typing import Optional, Type
-import warnings

 import tensorflow as tf
-from tensorflow_privacy.privacy.dp_query import dp_query
-from tensorflow_privacy.privacy.dp_query import gaussian_query
-

-def _normalize(microbatch_gradient: tf.Tensor,
-               num_microbatches: float) -> tf.Tensor:
-  """Normalizes `microbatch_gradient` by `num_microbatches`."""
-  return tf.truediv(microbatch_gradient,
-                    tf.cast(num_microbatches, microbatch_gradient.dtype))
+from tensorflow_privacy.privacy.dp_query import gaussian_query


-def make_keras_generic_optimizer_class(
-    cls: Type[tf.keras.optimizers.Optimizer]):
-  """Returns a differentially private (DP) subclass of `cls`.
+def make_keras_optimizer_class(cls):
+  """Given a subclass of `tf.keras.optimizers.Optimizer`, returns a DP-SGD subclass of it.

   Args:
     cls: Class from which to derive a DP subclass. Should be a subclass of
       `tf.keras.optimizers.Optimizer`.
+
+  Returns:
+    A DP-SGD subclass of `cls`.
   """

   class DPOptimizerClass(cls):  # pylint: disable=empty-docstring
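Note: a minimal usage sketch of the factory described by the `+` docstring above, assuming the module is importable as `tensorflow_privacy.privacy.optimizers.dp_optimizer_keras` (the exported optimizer aliases at the bottom of the file are produced the same way):

import tensorflow as tf
from tensorflow_privacy.privacy.optimizers import dp_optimizer_keras

# Wrapping a plain Keras optimizer class yields its DP-SGD counterpart.
MyDPKerasSGD = dp_optimizer_keras.make_keras_optimizer_class(
    tf.keras.optimizers.SGD)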
@@ -145,23 +138,24 @@ class DPOptimizerClass(cls):  # pylint: disable=empty-docstring

     def __init__(
         self,
-        dp_sum_query: dp_query.DPQuery,
-        num_microbatches: Optional[int] = None,
-        gradient_accumulation_steps: int = 1,
+        l2_norm_clip,
+        noise_multiplier,
+        num_microbatches=None,
+        gradient_accumulation_steps=1,
         *args,  # pylint: disable=keyword-arg-before-vararg, g-doc-args
         **kwargs):
-      """Initializes the DPOptimizerClass.
+      """Initialize the DPOptimizerClass.

       Args:
-        dp_sum_query: `DPQuery` object, specifying differential privacy
-          mechanism to use.
+        l2_norm_clip: Clipping norm (max L2 norm of per microbatch gradients).
+        noise_multiplier: Ratio of the standard deviation to the clipping norm.
         num_microbatches: Number of microbatches into which each minibatch is
-          split. Default is `None` which means that number of microbatches is
-          equal to batch size (i.e. each microbatch contains exactly one
+          split. Default is `None` which means that number of microbatches
+          is equal to batch size (i.e. each microbatch contains exactly one
           example). If `gradient_accumulation_steps` is greater than 1 and
           `num_microbatches` is not `None` then the effective number of
-          microbatches is equal to `num_microbatches *
-          gradient_accumulation_steps`.
+          microbatches is equal to
+          `num_microbatches * gradient_accumulation_steps`.
         gradient_accumulation_steps: If greater than 1 then optimizer will be
           accumulating gradients for this number of optimizer steps before
           applying them to update model weights. If this argument is set to 1
@@ -171,13 +165,13 @@ def __init__(
       """
       super().__init__(*args, **kwargs)
       self.gradient_accumulation_steps = gradient_accumulation_steps
+      self._l2_norm_clip = l2_norm_clip
+      self._noise_multiplier = noise_multiplier
       self._num_microbatches = num_microbatches
-      self._dp_sum_query = dp_sum_query
-      self._was_dp_gradients_called = False
-      # We initialize the self.`_global_state` within the gradient functions
-      # (and not here) because tensors must be initialized within the graph.
-
+      self._dp_sum_query = gaussian_query.GaussianSumQuery(
+          l2_norm_clip, l2_norm_clip * noise_multiplier)
       self._global_state = None
+      self._was_dp_gradients_called = False

     def _create_slots(self, var_list):
       super()._create_slots(var_list)  # pytype: disable=attribute-error
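For orientation, a small sketch of how the two `__init__` signatures in this hunk correspond; the values are illustrative. On the `-` side the caller supplies the `DPQuery`, while on the `+` side the same Gaussian query is built inside `__init__` from two scalars:

from tensorflow_privacy.privacy.dp_query import gaussian_query

l2_norm_clip = 1.0      # max L2 norm allowed for each per-microbatch gradient
noise_multiplier = 1.1  # noise stddev expressed as a multiple of the clip norm

# `-` side: construct the query explicitly and pass it as `dp_sum_query`.
dp_sum_query = gaussian_query.GaussianSumQuery(
    l2_norm_clip, l2_norm_clip * noise_multiplier)

# `+` side: pass only l2_norm_clip and noise_multiplier; the optimizer builds
# the same query itself, and remaining kwargs (e.g. learning_rate) are
# forwarded to the wrapped Keras optimizer.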
@@ -241,62 +235,66 @@ def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None):
       """DP-SGD version of base class method."""

       self._was_dp_gradients_called = True
-      if self._global_state is None:
-        self._global_state = self._dp_sum_query.initial_global_state()
-
       # Compute loss.
       if not callable(loss) and tape is None:
         raise ValueError('`tape` is required when a `Tensor` loss is passed.')
-
       tape = tape if tape is not None else tf.GradientTape()

-      with tape:
-        if callable(loss):
+      if callable(loss):
+        with tape:
           if not callable(var_list):
             tape.watch(var_list)

           loss = loss()
-        if self._num_microbatches is None:
-          num_microbatches = tf.shape(input=loss)[0]
-        else:
-          num_microbatches = self._num_microbatches
-        microbatch_losses = tf.reduce_mean(
-            tf.reshape(loss, [num_microbatches, -1]), axis=1)
-
-        if callable(var_list):
-          var_list = var_list()
+          if self._num_microbatches is None:
+            num_microbatches = tf.shape(input=loss)[0]
+          else:
+            num_microbatches = self._num_microbatches
+          microbatch_losses = tf.reduce_mean(
+              tf.reshape(loss, [num_microbatches, -1]), axis=1)
+
+          if callable(var_list):
+            var_list = var_list()
+      else:
+        with tape:
+          if self._num_microbatches is None:
+            num_microbatches = tf.shape(input=loss)[0]
+          else:
+            num_microbatches = self._num_microbatches
+          microbatch_losses = tf.reduce_mean(
+              tf.reshape(loss, [num_microbatches, -1]), axis=1)

       var_list = tf.nest.flatten(var_list)

-      sample_params = (
-          self._dp_sum_query.derive_sample_params(self._global_state))
-
       # Compute the per-microbatch losses using helpful jacobian method.
       with tf.keras.backend.name_scope(self._name + '/gradients'):
-        jacobian_per_var = tape.jacobian(
+        jacobian = tape.jacobian(
             microbatch_losses, var_list, unconnected_gradients='zero')

-        def process_microbatch(sample_state, microbatch_jacobians):
-          """Process one microbatch (record) with privacy helper."""
-          sample_state = self._dp_sum_query.accumulate_record(
-              sample_params, sample_state, microbatch_jacobians)
-          return sample_state
+        # Clip gradients to given l2_norm_clip.
+        def clip_gradients(g):
+          return tf.clip_by_global_norm(g, self._l2_norm_clip)[0]

-        sample_state = self._dp_sum_query.initial_sample_state(var_list)
-        for idx in range(num_microbatches):
-          microbatch_jacobians_per_var = [
-              jacobian[idx] for jacobian in jacobian_per_var
-          ]
-          sample_state = process_microbatch(sample_state,
-                                            microbatch_jacobians_per_var)
+        clipped_gradients = tf.map_fn(clip_gradients, jacobian)

-        grad_sums, self._global_state, _ = (
-            self._dp_sum_query.get_noised_result(sample_state,
-                                                 self._global_state))
-        final_grads = tf.nest.map_structure(_normalize, grad_sums,
-                                            [num_microbatches] * len(grad_sums))
+        def reduce_noise_normalize_batch(g):
+          # Sum gradients over all microbatches.
+          summed_gradient = tf.reduce_sum(g, axis=0)

-      return list(zip(final_grads, var_list))
+          # Add noise to summed gradients.
+          noise_stddev = self._l2_norm_clip * self._noise_multiplier
+          noise = tf.random.normal(
+              tf.shape(input=summed_gradient), stddev=noise_stddev)
+          noised_gradient = tf.add(summed_gradient, noise)
+
+          # Normalize by number of microbatches and return.
+          return tf.truediv(noised_gradient,
+                            tf.cast(num_microbatches, tf.float32))
+
+        final_gradients = tf.nest.map_structure(reduce_noise_normalize_batch,
+                                                clipped_gradients)
+
+      return list(zip(final_gradients, var_list))

     def get_gradients(self, loss, params):
       """DP-SGD version of base class method."""
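A standalone sketch of the clip / sum / noise / normalize path on the `+` side of this hunk, reduced to a single variable so it runs on its own; shapes and hyperparameters are illustrative, and `tf.clip_by_norm` stands in for the single-variable case of `tf.clip_by_global_norm`:

import tensorflow as tf

l2_norm_clip, noise_multiplier = 1.0, 1.1
num_microbatches = 4

# Per-microbatch gradients for one variable, shape [num_microbatches, num_weights],
# playing the role of one entry of the jacobian computed above.
per_microbatch_grads = tf.random.normal([num_microbatches, 10])

# Clip each microbatch gradient to L2 norm at most l2_norm_clip.
clipped = tf.map_fn(lambda g: tf.clip_by_norm(g, l2_norm_clip),
                    per_microbatch_grads)

# Sum over microbatches, add Gaussian noise, and normalize by the count.
summed = tf.reduce_sum(clipped, axis=0)
noise = tf.random.normal(tf.shape(summed),
                         stddev=l2_norm_clip * noise_multiplier)
dp_gradient = (summed + noise) / tf.cast(num_microbatches, tf.float32)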
@@ -324,13 +322,17 @@ def process_microbatch(i, sample_state):
       sample_state = self._dp_sum_query.initial_sample_state(params)
       for idx in range(self._num_microbatches):
         sample_state = process_microbatch(idx, sample_state)
-
       grad_sums, self._global_state, _ = (
           self._dp_sum_query.get_noised_result(sample_state,
                                                self._global_state))

-      final_grads = tf.nest.map_structure(
-          _normalize, grad_sums, [self._num_microbatches] * len(grad_sums))
+      def normalize(v):
+        try:
+          return tf.truediv(v, tf.cast(self._num_microbatches, tf.float32))
+        except TypeError:
+          return None
+
+      final_grads = tf.nest.map_structure(normalize, grad_sums)

       return final_grads

@@ -366,87 +368,7 @@ def apply_gradients(self, *args, **kwargs):
   return DPOptimizerClass


-def make_gaussian_query_optimizer_class(cls):
-  """Returns a differentially private optimizer using the `GaussianSumQuery`.
-
-  Args:
-    cls: `DPOptimizerClass`, the output of `make_keras_optimizer_class`.
-  """
-
-  def return_gaussian_query_optimizer(
-      l2_norm_clip: float,
-      noise_multiplier: float,
-      num_microbatches: Optional[int] = None,
-      gradient_accumulation_steps: int = 1,
-      *args,  # pylint: disable=keyword-arg-before-vararg, g-doc-args
-      **kwargs):
-    """Returns a `DPOptimizerClass` `cls` using the `GaussianSumQuery`.
-
-    This function is a thin wrapper around
-    `make_keras_optimizer_class.<locals>.DPOptimizerClass` which can be used to
-    apply a `GaussianSumQuery` to any `DPOptimizerClass`.
-
-    When combined with stochastic gradient descent, this creates the canonical
-    DP-SGD algorithm of "Deep Learning with Differential Privacy"
-    (see https://arxiv.org/abs/1607.00133).
-
-    Args:
-      l2_norm_clip: Clipping norm (max L2 norm of per microbatch gradients).
-      noise_multiplier: Ratio of the standard deviation to the clipping norm.
-      num_microbatches: Number of microbatches into which each minibatch is
-        split. Default is `None` which means that number of microbatches is
-        equal to batch size (i.e. each microbatch contains exactly one example).
-        If `gradient_accumulation_steps` is greater than 1 and
-        `num_microbatches` is not `None` then the effective number of
-        microbatches is equal to `num_microbatches *
-        gradient_accumulation_steps`.
-      gradient_accumulation_steps: If greater than 1 then optimizer will be
-        accumulating gradients for this number of optimizer steps before
-        applying them to update model weights. If this argument is set to 1 then
-        updates will be applied on each optimizer step.
-      *args: These will be passed on to the base class `__init__` method.
-      **kwargs: These will be passed on to the base class `__init__` method.
-    """
-    dp_sum_query = gaussian_query.GaussianSumQuery(
-        l2_norm_clip, l2_norm_clip * noise_multiplier)
-    return cls(
-        dp_sum_query=dp_sum_query,
-        num_microbatches=num_microbatches,
-        gradient_accumulation_steps=gradient_accumulation_steps,
-        *args,
-        **kwargs)
-
-  return return_gaussian_query_optimizer
-
-
-def make_keras_optimizer_class(cls: Type[tf.keras.optimizers.Optimizer]):
-  """Returns a differentially private optimizer using the `GaussianSumQuery`.
-
-  For backwards compatibility, we create this symbol to match the previous
-  output of `make_keras_optimizer_class` but using the new logic.
-
-  Args:
-    cls: Class from which to derive a DP subclass. Should be a subclass of
-      `tf.keras.optimizers.Optimizer`.
-  """
-  warnings.warn(
-      '`make_keras_optimizer_class` will be deprecated on 2023-02-23. '
-      'Please switch to `make_gaussian_query_optimizer_class` and the '
-      'generic optimizers (`make_keras_generic_optimizer_class`).')
-  return make_gaussian_query_optimizer_class(
-      make_keras_generic_optimizer_class(cls))
-
-
-GenericDPAdagradOptimizer = make_keras_generic_optimizer_class(
+DPKerasAdagradOptimizer = make_keras_optimizer_class(
     tf.keras.optimizers.Adagrad)
-GenericDPAdamOptimizer = make_keras_generic_optimizer_class(
-    tf.keras.optimizers.Adam)
-GenericDPSGDOptimizer = make_keras_generic_optimizer_class(
-    tf.keras.optimizers.SGD)
-
-# We keep the same names for backwards compatibility.
-DPKerasAdagradOptimizer = make_gaussian_query_optimizer_class(
-    GenericDPAdagradOptimizer)
-DPKerasAdamOptimizer = make_gaussian_query_optimizer_class(
-    GenericDPAdamOptimizer)
-DPKerasSGDOptimizer = make_gaussian_query_optimizer_class(GenericDPSGDOptimizer)
+DPKerasAdamOptimizer = make_keras_optimizer_class(tf.keras.optimizers.Adam)
+DPKerasSGDOptimizer = make_keras_optimizer_class(tf.keras.optimizers.SGD)
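With this change, the exported optimizers are constructed from `l2_norm_clip`, `noise_multiplier`, and `num_microbatches` and otherwise behave like ordinary Keras optimizers. A hedged end-to-end sketch (model, shapes, and hyperparameters are illustrative); the loss is left unreduced so the optimizer can split it into microbatches:

import tensorflow as tf
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import (
    DPKerasSGDOptimizer)

model = tf.keras.Sequential(
    [tf.keras.layers.Dense(10, activation='softmax', input_shape=(20,))])

optimizer = DPKerasSGDOptimizer(
    l2_norm_clip=1.0,
    noise_multiplier=1.1,
    num_microbatches=32,  # assumed to evenly divide the batch size
    learning_rate=0.1)

# Per-example losses (Reduction.NONE) so the optimizer can form microbatches.
loss = tf.keras.losses.SparseCategoricalCrossentropy(
    reduction=tf.keras.losses.Reduction.NONE)

model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])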