 # limitations under the License.
 # ==============================================================================
 """Differentially private version of Keras optimizer v2."""
+from typing import Optional, Type
+import warnings

 import tensorflow as tf
-
+from tensorflow_privacy.privacy.dp_query import dp_query
 from tensorflow_privacy.privacy.dp_query import gaussian_query


-def make_keras_optimizer_class(cls):
-  """Given a subclass of `tf.keras.optimizers.Optimizer`, returns a DP-SGD subclass of it.
+def _normalize(microbatch_gradient: tf.Tensor,
+               num_microbatches: float) -> tf.Tensor:
+  """Normalizes `microbatch_gradient` by `num_microbatches`."""
+  return tf.truediv(microbatch_gradient,
+                    tf.cast(num_microbatches, microbatch_gradient.dtype))
+
+
+def make_keras_generic_optimizer_class(
+    cls: Type[tf.keras.optimizers.Optimizer]):
+  """Returns a differentially private (DP) subclass of `cls`.

   Args:
     cls: Class from which to derive a DP subclass. Should be a subclass of
       `tf.keras.optimizers.Optimizer`.
-
-  Returns:
-    A DP-SGD subclass of `cls`.
   """

   class DPOptimizerClass(cls):  # pylint: disable=empty-docstring
@@ -138,24 +145,23 @@ class DPOptimizerClass(cls):  # pylint: disable=empty-docstring

     def __init__(
         self,
-        l2_norm_clip,
-        noise_multiplier,
-        num_microbatches=None,
-        gradient_accumulation_steps=1,
+        dp_sum_query: dp_query.DPQuery,
+        num_microbatches: Optional[int] = None,
+        gradient_accumulation_steps: int = 1,
         *args,  # pylint: disable=keyword-arg-before-vararg, g-doc-args
         **kwargs):
-      """Initialize the DPOptimizerClass.
+      """Initializes the DPOptimizerClass.

       Args:
-        l2_norm_clip: Clipping norm (max L2 norm of per microbatch gradients).
-        noise_multiplier: Ratio of the standard deviation to the clipping norm.
+        dp_sum_query: `DPQuery` object, specifying differential privacy
+          mechanism to use.
         num_microbatches: Number of microbatches into which each minibatch is
-          split. Default is `None` which means that number of microbatches
-          is equal to batch size (i.e. each microbatch contains exactly one
+          split. Default is `None` which means that number of microbatches is
+          equal to batch size (i.e. each microbatch contains exactly one
          example). If `gradient_accumulation_steps` is greater than 1 and
           `num_microbatches` is not `None` then the effective number of
-          microbatches is equal to
-          `num_microbatches * gradient_accumulation_steps`.
+          microbatches is equal to `num_microbatches *
+          gradient_accumulation_steps`.
         gradient_accumulation_steps: If greater than 1 then optimizer will be
           accumulating gradients for this number of optimizer steps before
           applying them to update model weights. If this argument is set to 1
@@ -165,13 +171,13 @@ def __init__(
       """
       super().__init__(*args, **kwargs)
       self.gradient_accumulation_steps = gradient_accumulation_steps
-      self._l2_norm_clip = l2_norm_clip
-      self._noise_multiplier = noise_multiplier
       self._num_microbatches = num_microbatches
-      self._dp_sum_query = gaussian_query.GaussianSumQuery(
-          l2_norm_clip, l2_norm_clip * noise_multiplier)
-      self._global_state = None
+      self._dp_sum_query = dp_sum_query
       self._was_dp_gradients_called = False
+      # We initialize `self._global_state` within the gradient functions (and
+      # not here) because tensors must be initialized within the graph.
+
+      self._global_state = None

     def _create_slots(self, var_list):
       super()._create_slots(var_list)  # pytype: disable=attribute-error
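With this change the clipping-and-noising mechanism is supplied as a `DPQuery` object rather than built internally from `l2_norm_clip` and `noise_multiplier`. A minimal sketch of the new constructor contract (not part of the diff), using illustrative hyperparameter values and assuming the module path `tensorflow_privacy.privacy.optimizers.dp_optimizer_keras`:

import tensorflow as tf

from tensorflow_privacy.privacy.dp_query import gaussian_query
from tensorflow_privacy.privacy.optimizers import dp_optimizer_keras

# A GaussianSumQuery with stddev = l2_norm_clip * noise_multiplier reproduces
# what the old constructor built internally from its two scalar arguments.
l2_norm_clip, noise_multiplier = 1.0, 1.1  # illustrative values
opt = dp_optimizer_keras.GenericDPSGDOptimizer(
    dp_sum_query=gaussian_query.GaussianSumQuery(
        l2_norm_clip, l2_norm_clip * noise_multiplier),
    num_microbatches=32,
    learning_rate=0.1)  # forwarded to tf.keras.optimizers.SGD via **kwargs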
@@ -235,66 +241,62 @@ def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None):
       """DP-SGD version of base class method."""

       self._was_dp_gradients_called = True
+      if self._global_state is None:
+        self._global_state = self._dp_sum_query.initial_global_state()
+
       # Compute loss.
       if not callable(loss) and tape is None:
         raise ValueError('`tape` is required when a `Tensor` loss is passed.')
+
       tape = tape if tape is not None else tf.GradientTape()

-      if callable(loss):
-        with tape:
+      with tape:
+        if callable(loss):
           if not callable(var_list):
             tape.watch(var_list)

           loss = loss()
-          if self._num_microbatches is None:
-            num_microbatches = tf.shape(input=loss)[0]
-          else:
-            num_microbatches = self._num_microbatches
-          microbatch_losses = tf.reduce_mean(
-              tf.reshape(loss, [num_microbatches, -1]), axis=1)
-
-          if callable(var_list):
-            var_list = var_list()
-      else:
-        with tape:
-          if self._num_microbatches is None:
-            num_microbatches = tf.shape(input=loss)[0]
-          else:
-            num_microbatches = self._num_microbatches
-          microbatch_losses = tf.reduce_mean(
-              tf.reshape(loss, [num_microbatches, -1]), axis=1)
+        if self._num_microbatches is None:
+          num_microbatches = tf.shape(input=loss)[0]
+        else:
+          num_microbatches = self._num_microbatches
+        microbatch_losses = tf.reduce_mean(
+            tf.reshape(loss, [num_microbatches, -1]), axis=1)
+
+        if callable(var_list):
+          var_list = var_list()

       var_list = tf.nest.flatten(var_list)

+      sample_params = (
+          self._dp_sum_query.derive_sample_params(self._global_state))
+
       # Compute the per-microbatch losses using helpful jacobian method.
       with tf.keras.backend.name_scope(self._name + '/gradients'):
-        jacobian = tape.jacobian(
+        jacobian_per_var = tape.jacobian(
             microbatch_losses, var_list, unconnected_gradients='zero')

-        # Clip gradients to given l2_norm_clip.
-        def clip_gradients(g):
-          return tf.clip_by_global_norm(g, self._l2_norm_clip)[0]
+        def process_microbatch(sample_state, microbatch_jacobians):
+          """Process one microbatch (record) with privacy helper."""
+          sample_state = self._dp_sum_query.accumulate_record(
+              sample_params, sample_state, microbatch_jacobians)
+          return sample_state

-        clipped_gradients = tf.map_fn(clip_gradients, jacobian)
+        sample_state = self._dp_sum_query.initial_sample_state(var_list)
+        for idx in range(num_microbatches):
+          microbatch_jacobians_per_var = [
+              jacobian[idx] for jacobian in jacobian_per_var
+          ]
+          sample_state = process_microbatch(sample_state,
+                                            microbatch_jacobians_per_var)

-        def reduce_noise_normalize_batch(g):
-          # Sum gradients over all microbatches.
-          summed_gradient = tf.reduce_sum(g, axis=0)
+      grad_sums, self._global_state, _ = (
+          self._dp_sum_query.get_noised_result(sample_state,
+                                               self._global_state))
+      final_grads = tf.nest.map_structure(_normalize, grad_sums,
+                                          [num_microbatches] * len(grad_sums))

-          # Add noise to summed gradients.
-          noise_stddev = self._l2_norm_clip * self._noise_multiplier
-          noise = tf.random.normal(
-              tf.shape(input=summed_gradient), stddev=noise_stddev)
-          noised_gradient = tf.add(summed_gradient, noise)
-
-          # Normalize by number of microbatches and return.
-          return tf.truediv(noised_gradient,
-                            tf.cast(num_microbatches, tf.float32))
-
-        final_gradients = tf.nest.map_structure(reduce_noise_normalize_batch,
-                                                clipped_gradients)
-
-      return list(zip(final_gradients, var_list))
+      return list(zip(final_grads, var_list))

     def get_gradients(self, loss, params):
       """DP-SGD version of base class method."""
@@ -322,17 +324,13 @@ def process_microbatch(i, sample_state):
       sample_state = self._dp_sum_query.initial_sample_state(params)
       for idx in range(self._num_microbatches):
         sample_state = process_microbatch(idx, sample_state)
+
       grad_sums, self._global_state, _ = (
           self._dp_sum_query.get_noised_result(sample_state,
                                                self._global_state))

-      def normalize(v):
-        try:
-          return tf.truediv(v, tf.cast(self._num_microbatches, tf.float32))
-        except TypeError:
-          return None
-
-      final_grads = tf.nest.map_structure(normalize, grad_sums)
+      final_grads = tf.nest.map_structure(
+          _normalize, grad_sums, [self._num_microbatches] * len(grad_sums))

       return final_grads

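Both gradient paths now delegate clipping and noising to the `DPQuery` protocol: derive sample parameters from the global state, accumulate each microbatch gradient as a record, then fetch the noised sum and normalize it by the number of microbatches. A minimal sketch of that sequence on toy tensors (not part of the diff), assuming a `GaussianSumQuery`, which clips each record in `accumulate_record` and adds Gaussian noise in `get_noised_result`:

import tensorflow as tf

from tensorflow_privacy.privacy.dp_query import gaussian_query

query = gaussian_query.GaussianSumQuery(l2_norm_clip=1.0, stddev=1.1)
global_state = query.initial_global_state()
sample_params = query.derive_sample_params(global_state)

# Toy per-microbatch "gradients" standing in for one variable's jacobian slices.
records = [tf.constant([3.0, 4.0]), tf.constant([0.6, 0.8])]
sample_state = query.initial_sample_state(records[0])
for record in records:
  # Clips the record to the L2 bound and adds it to the running sum.
  sample_state = query.accumulate_record(sample_params, sample_state, record)

# Adds Gaussian noise to the clipped sum; the optimizer then divides by the
# number of microbatches, mirroring `_normalize` above.
noised_sum, global_state, _ = query.get_noised_result(sample_state, global_state)
noised_mean = noised_sum / len(records)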
@@ -368,7 +366,87 @@ def apply_gradients(self, *args, **kwargs):
   return DPOptimizerClass


-DPKerasAdagradOptimizer = make_keras_optimizer_class(
+def make_gaussian_query_optimizer_class(cls):
+  """Returns a differentially private optimizer using the `GaussianSumQuery`.
+
+  Args:
+    cls: `DPOptimizerClass`, the output of
+      `make_keras_generic_optimizer_class`.
374+ """
375+
376+ def return_gaussian_query_optimizer (
377+ l2_norm_clip : float ,
378+ noise_multiplier : float ,
379+ num_microbatches : Optional [int ] = None ,
380+ gradient_accumulation_steps : int = 1 ,
381+ * args , # pylint: disable=keyword-arg-before-vararg, g-doc-args
382+ ** kwargs ):
383+ """Returns a `DPOptimizerClass` `cls` using the `GaussianSumQuery`.
384+
+    This function is a thin wrapper around
+    `make_keras_generic_optimizer_class.<locals>.DPOptimizerClass` which can be
+    used to apply a `GaussianSumQuery` to any `DPOptimizerClass`.
+
+    When combined with stochastic gradient descent, this creates the canonical
+    DP-SGD algorithm of "Deep Learning with Differential Privacy"
+    (see https://arxiv.org/abs/1607.00133).
+
+    Args:
+      l2_norm_clip: Clipping norm (max L2 norm of per microbatch gradients).
+      noise_multiplier: Ratio of the standard deviation to the clipping norm.
+      num_microbatches: Number of microbatches into which each minibatch is
+        split. Default is `None` which means that number of microbatches is
+        equal to batch size (i.e. each microbatch contains exactly one example).
+        If `gradient_accumulation_steps` is greater than 1 and
+        `num_microbatches` is not `None` then the effective number of
+        microbatches is equal to `num_microbatches *
+        gradient_accumulation_steps`.
+      gradient_accumulation_steps: If greater than 1 then optimizer will be
+        accumulating gradients for this number of optimizer steps before
+        applying them to update model weights. If this argument is set to 1 then
+        updates will be applied on each optimizer step.
+      *args: These will be passed on to the base class `__init__` method.
+      **kwargs: These will be passed on to the base class `__init__` method.
+    """
+    dp_sum_query = gaussian_query.GaussianSumQuery(
+        l2_norm_clip, l2_norm_clip * noise_multiplier)
+    return cls(
+        dp_sum_query=dp_sum_query,
+        num_microbatches=num_microbatches,
+        gradient_accumulation_steps=gradient_accumulation_steps,
+        *args,
+        **kwargs)
+
+  return return_gaussian_query_optimizer
+
+
+def make_keras_optimizer_class(cls: Type[tf.keras.optimizers.Optimizer]):
+  """Returns a differentially private optimizer using the `GaussianSumQuery`.
+
+  For backwards compatibility, we create this symbol to match the previous
+  output of `make_keras_optimizer_class` but using the new logic.
+
+  Args:
+    cls: Class from which to derive a DP subclass. Should be a subclass of
+      `tf.keras.optimizers.Optimizer`.
+  """
+  warnings.warn(
+      '`make_keras_optimizer_class` will be deprecated on 2023-02-23. '
+      'Please switch to `make_gaussian_query_optimizer_class` and the '
+      'generic optimizers (`make_keras_generic_optimizer_class`).')
+  return make_gaussian_query_optimizer_class(
+      make_keras_generic_optimizer_class(cls))
+
+
+GenericDPAdagradOptimizer = make_keras_generic_optimizer_class(
     tf.keras.optimizers.Adagrad)
-DPKerasAdamOptimizer = make_keras_optimizer_class(tf.keras.optimizers.Adam)
-DPKerasSGDOptimizer = make_keras_optimizer_class(tf.keras.optimizers.SGD)
+GenericDPAdamOptimizer = make_keras_generic_optimizer_class(
+    tf.keras.optimizers.Adam)
+GenericDPSGDOptimizer = make_keras_generic_optimizer_class(
+    tf.keras.optimizers.SGD)
+
+# We keep the same names for backwards compatibility.
+DPKerasAdagradOptimizer = make_gaussian_query_optimizer_class(
+    GenericDPAdagradOptimizer)
+DPKerasAdamOptimizer = make_gaussian_query_optimizer_class(
+    GenericDPAdamOptimizer)
+DPKerasSGDOptimizer = make_gaussian_query_optimizer_class(GenericDPSGDOptimizer)
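The re-exported `DPKeras*Optimizer` names keep the old constructor signature, so existing training code should not need changes. A minimal usage sketch (not part of the diff) with illustrative values, assuming a Keras `model` and the usual requirement that the loss stays un-reduced (one value per example) so the optimizer can form per-microbatch losses:

optimizer = DPKerasSGDOptimizer(
    l2_norm_clip=1.0,
    noise_multiplier=1.1,
    num_microbatches=32,
    learning_rate=0.1)
# Per-example losses let the optimizer reshape them into microbatches before
# clipping and noising.
loss = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
model.compile(optimizer=optimizer, loss=loss)  # `model` is assumed to exist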