Skip to content

Commit b1169ac

Browse files
authored
Merge pull request #46 from openai/other_classes
created helper function other_classes
2 parents c26f10a + a1331a0 commit b1169ac

File tree

3 files changed

+21
-4
lines changed

3 files changed

+21
-4
lines changed

cleverhans/attacks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import multiprocessing as mp
1212

1313
from . import utils_tf
14+
from . import utils
1415

1516
from tensorflow.python.platform import flags
1617
FLAGS = flags.FLAGS
@@ -200,8 +201,7 @@ def jacobian(sess, x, grads, target, X):
200201

201202
# Sum over all classes different from the target class to prepare for
202203
# saliency map computation in the next step of the attack
203-
other_classes = list(xrange(FLAGS.nb_classes))
204-
other_classes.remove(target)
204+
other_classes = utils.other_classes(FLAGS.nb_classes, target)
205205
grad_others = np.sum(jacobian_val[other_classes, :, :], axis=0)
206206

207207
return jacobian_val[target], grad_others

cleverhans/utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,3 +79,17 @@ def batch_indices(batch_nb, data_length, batch_size):
7979
end -= shift
8080

8181
return start, end
82+
83+
84+
def other_classes(nb_classes, class_ind):
85+
"""
86+
Heper function that returns a list of class indices without one class
87+
:param nb_classes: number of classes in total
88+
:param class_ind: the class index to be omitted
89+
:return: list of class indices without one class
90+
"""
91+
92+
other_classes_list = list(xrange(nb_classes))
93+
other_classes_list.remove(class_ind)
94+
95+
return other_classes_list

tutorials/mnist_tutorial_jsma.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from cleverhans.utils_mnist import data_mnist, model_mnist
1515
from cleverhans.utils_tf import tf_model_train, tf_model_eval
1616
from cleverhans.attacks import jsma, jacobian_graph
17+
from cleverhans.utils import other_classes
1718

1819
FLAGS = flags.FLAGS
1920

@@ -101,8 +102,7 @@ def main(argv=None):
101102
for sample_ind in xrange(FLAGS.source_samples):
102103
# We want to find an adversarial example for each possible target class
103104
# (i.e. all classes that differ from the label given in the dataset)
104-
target_classes = list(xrange(FLAGS.nb_classes))
105-
target_classes.remove(int(np.argmax(Y_test[sample_ind])))
105+
target_classes = other_classes(FLAGS.nb_classes, int(np.argmax(Y_test[sample_ind])))
106106

107107
# Loop over all target classes
108108
for target in target_classes:
@@ -128,5 +128,8 @@ def main(argv=None):
128128
percentage_perturbed = np.mean(perturbations)
129129
print('Avg. rate of perterbed features {0}'.format(percentage_perturbed))
130130

131+
# Close TF session
132+
sess.close()
133+
131134
if __name__ == '__main__':
132135
app.run()

0 commit comments

Comments
 (0)