@@ -2593,9 +2593,17 @@ def projected_optimization(loss_fn,
   :param project_perturbation: A function, which will be used to enforce
                                some constraint. It should have the same
                                signature as `_project_perturbation`.
-  :param early_stop_loss_threshold: A float or None. If specified, the
-                                    attack will end if the loss is below
-                                    `early_stop_loss_threshold`.
+  :param early_stop_loss_threshold: A float or None. If specified, the attack will end if the loss is below
+     `early_stop_loss_threshold`.
+     Enabling this option can have several different effects:
+      - Setting the threshold to 0. guarantees that if a successful attack is found, it is returned.
+        This increases the attack success rate, because without early stopping the optimizer can accidentally
+        bounce back to a point where the attack fails.
+      - Early stopping can make the attack run faster because it may run for fewer steps.
+      - Early stopping can make the attack run slower because the loss must be calculated at each step.
+        The loss is not calculated as part of the normal SPSA optimization procedure.
+        For most reasonable choices of hyperparameters, early stopping makes the attack much faster because
+        it decreases the number of steps dramatically.
   :param is_debug: A bool. If True, print debug info for attack progress.
 
   Returns:
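
A toy sketch of the docstring's first bullet (not part of the diff; the loss trace is hypothetical, using the convention that a loss below 0. means the attack currently succeeds). Without early stopping, the optimizer can pass through a successful perturbation and then bounce back to a failing one; a threshold of 0. stops at the first success:

# Hypothetical loss values over five optimization steps; the attack
# succeeds at step 2 but the optimizer bounces back afterwards.
losses = [0.7, 0.3, -0.1, 0.2, 0.5]

def final_loss(losses, early_stop_loss_threshold=None):
  for loss in losses:
    # 0. is a valid threshold, so test against None rather than truthiness.
    if early_stop_loss_threshold is not None and loss < early_stop_loss_threshold:
      return loss  # stop at the first successful perturbation
  return losses[-1]  # otherwise keep whatever the final step produced

assert final_loss(losses) == 0.5                                 # attack fails
assert final_loss(losses, early_stop_loss_threshold=0.) == -0.1  # attack succeeds
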
@@ -2635,20 +2643,28 @@ def wrapped_loss_fn(x):
 
     new_perturbation_list, new_optim_state = optimizer.minimize(
         wrapped_loss_fn, [perturbation], optim_state)
-    loss = reduce_mean(wrapped_loss_fn(perturbation), axis=0)
-    if is_debug:
-      with tf.device("/cpu:0"):
-        loss = tf.Print(loss, [loss], "Total batch loss")
     projected_perturbation = project_perturbation(new_perturbation_list[0],
                                                   epsilon, input_image,
                                                   clip_min=clip_min,
                                                   clip_max=clip_max)
-    with tf.control_dependencies([loss]):
-      i = tf.identity(i)
-      if early_stop_loss_threshold:
-        i = tf.cond(
-            tf.less(loss, early_stop_loss_threshold),
-            lambda: float(num_steps), lambda: i)
+
+    # Be careful with this bool. A value of 0. is a valid threshold but evaluates to False, so we must explicitly
+    # check whether the value is None.
+    early_stop = early_stop_loss_threshold is not None
+    compute_loss = is_debug or early_stop
+    # Don't waste time building the loss graph if we're not going to use it
+    if compute_loss:
+      # NOTE: this step is not actually redundant with the optimizer step.
+      # SPSA calculates the loss at randomly perturbed points but doesn't calculate the loss at the current point.
+      loss = reduce_mean(wrapped_loss_fn(projected_perturbation), axis=0)
+
+      if is_debug:
+        with tf.device("/cpu:0"):
+          loss = tf.Print(loss, [loss], "Total batch loss")
+
+      if early_stop:
+        i = tf.cond(tf.less(loss, early_stop_loss_threshold), lambda: float(num_steps), lambda: i)
+
     return i + 1, projected_perturbation, nest.flatten(new_optim_state)
 
   def cond(i, *_):
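
A minimal standalone sketch of the early-stop mechanism added above, assuming TF1-style graph execution (tf.Session); num_steps, the threshold, and the fake per-step loss are hypothetical stand-ins. When the loss drops below the threshold, the loop index is forced to num_steps so that cond returns False and tf.while_loop terminates early. It also demonstrates why the threshold must be compared with `is not None`: 0. is a valid threshold but is falsy.

import tensorflow as tf

num_steps = 100.
early_stop_loss_threshold = 0.  # falsy, so `if early_stop_loss_threshold:` would wrongly disable early stopping
early_stop = early_stop_loss_threshold is not None

def body(i, loss):
  loss = loss - 0.1  # stand-in for one optimization step
  if early_stop:
    # Forcing i to num_steps makes the loop condition fail at the next check.
    i = tf.cond(tf.less(loss, early_stop_loss_threshold),
                lambda: float(num_steps), lambda: i)
  return i + 1, loss

def cond(i, *_):
  return tf.less(i, num_steps)

final_i, final_loss = tf.while_loop(cond, body,
                                    [tf.constant(0.), tf.constant(0.55)])
with tf.Session() as sess:
  print(sess.run([final_i, final_loss]))  # loop runs 6 iterations, not 100

Note that the loop counter is a float here, matching the diff's `lambda: float(num_steps)`: both branches of tf.cond must return the same dtype, so forcing the counter to num_steps is simplest when the counter itself is a float.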