@@ -40,8 +40,8 @@ def tf_model_train(sess, x, y, predictions, X_train, Y_train, save=False,
4040 :param X_train: numpy array with training inputs
4141 :param Y_train: numpy array with training outputs
4242 :param save: Boolean controling the save operation
43- :param predictions_adv: if set with the adversarial example tensor,
44- will run adversarial training
43+ :param predictions_adv: if set with the adversarial example tensor,
44+ will run adversarial training
4545 :return: True if model trained
4646 """
4747 print "Starting model training using TensorFlow."
@@ -63,7 +63,8 @@ def tf_model_train(sess, x, y, predictions, X_train, Y_train, save=False,
6363 print ("Epoch " + str (epoch ))
6464
6565 # Compute number of batches
66- nb_batches = int (math .ceil (len (X_train ) / FLAGS .batch_size ))
66+ nb_batches = int (math .ceil (float (len (X_train )) / FLAGS .batch_size ))
67+ assert nb_batches * FLAGS .batch_size >= len (X_train )
6768
6869 prev = time .time ()
6970 for batch in range (nb_batches ):
@@ -80,6 +81,7 @@ def tf_model_train(sess, x, y, predictions, X_train, Y_train, save=False,
8081 train_step .run (feed_dict = {x : X_train [start :end ],
8182 y : Y_train [start :end ],
8283 keras .backend .learning_phase (): 1 })
84+ assert end >= len (X_train ) # Check that all examples were used
8385
8486
8587 if save :
@@ -112,21 +114,29 @@ def tf_model_eval(sess, x, y, model, X_test, Y_test):
112114
113115 with sess .as_default ():
114116 # Compute number of batches
115- nb_batches = int (math .ceil (len (X_test ) / FLAGS .batch_size ))
117+ nb_batches = int (math .ceil (float (len (X_test )) / FLAGS .batch_size ))
118+ assert nb_batches * FLAGS .batch_size >= len (X_test )
116119
117120 for batch in range (nb_batches ):
118121 if batch % 100 == 0 and batch > 0 :
119122 print ("Batch " + str (batch ))
120123
121- # Compute batch start and end indices
122- start , end = batch_indices (batch , len (X_test ), FLAGS .batch_size )
124+ # Must not use the `batch_indices` function here, because it
125+ # repeats some examples.
126+ # It's acceptable to repeat during training, but not eval.
127+ start = batch * FLAGS .batch_size
128+ end = min (len (X_test ), start + FLAGS .batch_size )
129+ cur_batch_size = end - start + 1
123130
124- accuracy += acc_value .eval (feed_dict = {x : X_test [start :end ],
131+ # The last batch may be smaller than all others, so we need to
132+ # account for variable batch size here
133+ accuracy += cur_batch_size * acc_value .eval (feed_dict = {x : X_test [start :end ],
125134 y : Y_test [start :end ],
126135 keras .backend .learning_phase (): 0 })
136+ assert end >= len (X_test )
127137
128- # Divide by number of batches to get final value
129- accuracy /= nb_batches
138+ # Divide by number of examples to get final value
139+ accuracy /= len ( X_test )
130140
131141 return accuracy
132142
0 commit comments