Commit 4429d37

Fix bugs in SPSA early stopping (#906)
* tweak doc
* add to documentation of early_stop_loss_threshold
* more detail to doc
* fix bugs in early stopping
* Update cleverhans/attacks_tf.py

Co-Authored-By: goodfeli <[email protected]>
1 parent bfae77b commit 4429d37

2 files changed: +31, -15 lines

cleverhans/attacks/__init__.py

Lines changed: 29 additions & 13 deletions
@@ -2593,9 +2593,17 @@ def projected_optimization(loss_fn,
     :param project_perturbation: A function, which will be used to enforce
                                  some constraint. It should have the same
                                  signature as `_project_perturbation`.
-    :param early_stop_loss_threshold: A float or None. If specified, the
-                                      attack will end if the loss is below
-                                      `early_stop_loss_threshold`.
+    :param early_stop_loss_threshold: A float or None. If specified, the attack will end if the loss is below
+       `early_stop_loss_threshold`.
+         Enabling this option can have several different effects:
+           - Setting the threshold to 0. guarantees that if a successful attack is found, it is returned.
+             This increases the attack success rate, because without early stopping the optimizer can accidentally
+             bounce back to a point where the attack fails.
+           - Early stopping can make the attack run faster because it may run for fewer steps.
+           - Early stopping can make the attack run slower because the loss must be calculated at each step.
+             The loss is not calculated as part of the normal SPSA optimization procedure.
+             For most reasonable choices of hyperparameters, early stopping makes the attack much faster because
+             it decreases the number of steps dramatically.
     :param is_debug: A bool. If True, print debug info for attack progress.

   Returns:
@@ -2635,20 +2643,28 @@ def wrapped_loss_fn(x):

     new_perturbation_list, new_optim_state = optimizer.minimize(
         wrapped_loss_fn, [perturbation], optim_state)
-    loss = reduce_mean(wrapped_loss_fn(perturbation), axis=0)
-    if is_debug:
-      with tf.device("/cpu:0"):
-        loss = tf.Print(loss, [loss], "Total batch loss")
     projected_perturbation = project_perturbation(new_perturbation_list[0],
                                                   epsilon, input_image,
                                                   clip_min=clip_min,
                                                   clip_max=clip_max)
-    with tf.control_dependencies([loss]):
-      i = tf.identity(i)
-      if early_stop_loss_threshold:
-        i = tf.cond(
-            tf.less(loss, early_stop_loss_threshold),
-            lambda: float(num_steps), lambda: i)
+
+    # Be careful with this bool. A value of 0. is a valid threshold but evaluates to False, so we must explicitly
+    # check whether the value is None.
+    early_stop = early_stop_loss_threshold is not None
+    compute_loss = is_debug or early_stop
+    # Don't waste time building the loss graph if we're not going to use it
+    if compute_loss:
+      # NOTE: this step is not actually redundant with the optimizer step.
+      # SPSA calculates the loss at randomly perturbed points but doesn't calculate the loss at the current point.
+      loss = reduce_mean(wrapped_loss_fn(projected_perturbation), axis=0)
+
+      if is_debug:
+        with tf.device("/cpu:0"):
+          loss = tf.Print(loss, [loss], "Total batch loss")
+
+      if early_stop:
+        i = tf.cond(tf.less(loss, early_stop_loss_threshold), lambda: float(num_steps), lambda: i)
+
     return i + 1, projected_perturbation, nest.flatten(new_optim_state)

   def cond(i, *_):
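
For readers who want to see the fixed pattern in isolation, here is a small self-contained sketch (not part of the commit) of the same early-stopping idea in TensorFlow 1.x graph mode: an explicit `is not None` test so that a threshold of 0. still enables early stopping (the old `if early_stop_loss_threshold:` silently skipped it), and a `tf.cond` that jumps the loop counter to `num_steps` once the loss falls below the threshold. The toy loss, learning rate, and step count are invented for illustration; the real attack uses the SPSA machinery shown in the diff above.

import tensorflow as tf

num_steps = 100
threshold = 0.  # a valid threshold that is falsy, hence the `is not None` check

def toy_loss_fn(x):
  # Stand-in for wrapped_loss_fn; goes negative once x is close enough to 2.0,
  # loosely mimicking an attack loss that drops below 0. on success.
  return tf.square(x - 2.0) - 0.05

def loop_body(i, x):
  x = x - 0.1 * 2.0 * (x - 2.0)  # hand-coded gradient step on the toy loss
  early_stop = threshold is not None  # explicit None test, so 0. still counts
  if early_stop:
    loss = toy_loss_fn(x)
    # Jump the counter to num_steps so the while_loop condition fails next time.
    i = tf.cond(tf.less(loss, threshold),
                lambda: float(num_steps), lambda: i)
  return i + 1, x

def cond(i, *_):
  return tf.less(i, num_steps)

final_i, final_x = tf.while_loop(cond, loop_body,
                                 (tf.constant(0.), tf.constant(10.0)),
                                 back_prop=False)

with tf.Session() as sess:
  steps, x_val = sess.run([final_i, final_x])
  # final_i reads num_steps + 1 because the counter was jumped;
  # only about 17 real iterations ran instead of 100.
  print(steps, x_val)

As in the attack itself, the Python-level `if` runs at graph-construction time, so the extra loss computation is only added to the graph when debugging or early stopping actually needs it.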

cleverhans/attacks_tf.py

Lines changed: 2 additions & 2 deletions
@@ -1485,8 +1485,8 @@ class TensorOptimizer(object):
   behaviors when being assigned multiple times within a single sess.run()
   call, particularly in Distributed TF, so this avoids thinking about those
   issues. These are helper classes for the `projected_optimization`
-  method. Apart from not using Variables, they follow the same interface as
-  tf.Optimizer.
+  method. Apart from not using Variables, they follow an interface very
+  similar to tf.Optimizer.
   """

   def _compute_gradients(self, loss_fn, x, unused_optim_state):
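
As a companion illustration, here is a hypothetical minimal optimizer in the spirit of these Variable-free helpers (this class is not in the repository): `minimize` returns updated tensors plus new optimizer state instead of assigning `tf.Variable`s. Only the `optimizer.minimize(wrapped_loss_fn, [perturbation], optim_state)` call and the `_compute_gradients(self, loss_fn, x, unused_optim_state)` signature come from this commit; the gradient-descent logic and the `init_state` layout are assumptions for the sketch.

import tensorflow as tf

class ToySGD(object):
  """Plain gradient descent that returns new tensors instead of mutating Variables."""

  def __init__(self, lr=0.01):
    self._lr = lr

  def init_state(self, x):
    # Vanilla SGD carries no state between steps.
    return []

  def _compute_gradients(self, loss_fn, x, unused_optim_state):
    # x is a list of tensors, as in the diff above; differentiate the summed loss.
    loss = tf.reduce_sum(loss_fn(*x))
    return tf.gradients(loss, x)

  def _apply_gradients(self, grads, x, optim_state):
    new_x = [xi - self._lr * gi for xi, gi in zip(x, grads)]
    return new_x, optim_state

  def minimize(self, loss_fn, x, optim_state):
    grads = self._compute_gradients(loss_fn, x, optim_state)
    return self._apply_gradients(grads, x, optim_state)

# Because no Variables are assigned, a step like this is safe inside a tf.while_loop body:
#   new_x_list, new_state = ToySGD().minimize(lambda p: tf.square(p), [perturbation], [])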
