diff --git a/tfutil.py b/tfutil.py
index cf7ad0ad..f4322976 100755
--- a/tfutil.py
+++ b/tfutil.py
@@ -452,7 +452,7 @@ def _init_fields(self):
         self._build_func_name = None # Name of the build function.
         self._build_module_src = None # Full source code of the module containing the build function.
         self._run_cache = dict() # Cached graph data for Network.run().
-        
+
     def _init_graph(self):
         # Collect inputs.
         self.input_names = []
@@ -466,7 +466,7 @@ def _init_graph(self):
         if self.name is None:
             self.name = self._build_func_name
         self.scope = tf.get_default_graph().unique_name(self.name.replace('/', '_'), mark_as_used=False)
-        
+
         # Build template graph.
         with tf.variable_scope(self.scope, reuse=tf.AUTO_REUSE):
             assert tf.get_variable_scope().name == self.scope
@@ -474,14 +474,14 @@ def _init_graph(self):
                 with tf.control_dependencies(None): # ignore surrounding control_dependencies
                     self.input_templates = [tf.placeholder(tf.float32, name=name) for name in self.input_names]
                     out_expr = self._build_func(*self.input_templates, is_template_graph=True, **self.static_kwargs)
-        
+
         # Collect outputs.
         assert is_tf_expression(out_expr) or isinstance(out_expr, tuple)
         self.output_templates = [out_expr] if is_tf_expression(out_expr) else list(out_expr)
         self.output_names = [t.name.split('/')[-1].split(':')[0] for t in self.output_templates]
         self.num_outputs = len(self.output_templates)
         assert self.num_outputs >= 1
-        
+
         # Populate remaining fields.
         self.input_shapes = [shape_to_list(t.shape) for t in self.input_templates]
         self.output_shapes = [shape_to_list(t.shape) for t in self.output_templates]
@@ -530,7 +530,7 @@ def find_var(self, var_or_localname):
     # Note: This method is very inefficient -- prefer to use tfutil.run(list_of_vars) whenever possible.
     def get_var(self, var_or_localname):
         return self.find_var(var_or_localname).eval()
-    
+
     # Set the value of a given variable based on the given NumPy array.
     # Note: This method is very inefficient -- prefer to use tfutil.set_vars() whenever possible.
     def set_var(self, var_or_localname, new_value):
@@ -560,13 +560,13 @@ def __setstate__(self, state):
         self.static_kwargs = state['static_kwargs']
         self._build_module_src = state['build_module_src']
         self._build_func_name = state['build_func_name']
-        
+
         # Parse imported module.
         module = imp.new_module('_tfutil_network_import_module_%d' % len(_network_import_modules))
         exec(self._build_module_src, module.__dict__)
         self._build_func = find_obj_in_module(module, self._build_func_name)
         _network_import_modules.append(module) # avoid gc
-        
+
         # Init graph.
         self._init_graph()
         self.reset_vars()
@@ -746,4 +746,135 @@ def setup_weight_histograms(self, title=None):
                 name = title + '_toplevel/' + localname
                 tf.summary.histogram(name, var)
 
+    # This function takes everything that run() takes, plus etalon images, and finds the latents
+    # that approximate those etalon images.
+    # Usage: Gs.reverse_gan_for_etalons(latents, labels, etalons)
+    # where etalons.shape is e.g. (?, 1024, 1024, 3) with values in [-1, 1].
+    # Returns the history of latents, with the last solution being the best.
+    def reverse_gan_for_etalons(self,
+        *in_arrays,                 # Starting values of the latents, any labels, and the etalon images.
+        itterations = 100,          # Number of optimisation iterations to run. An empirically good value is 2000.
+        learning_rate = 0.000001,   # Initial learning rate.
+        stohastic_clipping = True,  # True = stochastic clipping, False = standard clipping of latents to [-1, 1].
+        return_as_list = False,     # True = return a list of NumPy arrays, False = return a single NumPy array, or a tuple if there are multiple outputs.
+        print_progress = False,     # Print progress to the console? Useful for very large input arrays.
+        minibatch_size = None,      # Maximum minibatch size to use, None = disable batching.
+        num_gpus = 1,               # Number of GPUs to use.
+        out_mul = 1.0,              # Multiplicative constant to apply to the output(s).
+        out_add = 0.0,              # Additive constant to apply to the output(s).
+        out_shrink = 1,             # Shrink the spatial dimensions of the output(s) by the given factor.
+        out_dtype = None,           # Convert the output to the specified data type.
+        **dynamic_kwargs):          # Additional keyword arguments to pass into the network construction function.
+
+        assert len(in_arrays) == 3
+        num_items = in_arrays[0].shape[0]
+        if minibatch_size is None:
+            minibatch_size = num_items
+        key = str([list(sorted(dynamic_kwargs.items())),
+                   num_gpus,
+                   out_mul,
+                   out_add,
+                   out_shrink,
+                   out_dtype])
+        # Build graph. Same as in the run() method.
+        if key not in self._run_cache:
+            with absolute_name_scope(self.scope + '/Run'), tf.control_dependencies(None):
+                in_split = list(zip(*[tf.split(x, num_gpus) for x in self.input_templates]))
+                out_split = []
+                for gpu in range(num_gpus):
+                    with tf.device('/gpu:%d' % gpu):
+                        out_expr = self.get_output_for(*in_split[gpu],
+                                                       return_as_list=True,
+                                                       **dynamic_kwargs)
+                        if out_mul != 1.0:
+                            out_expr = [x * out_mul for x in out_expr]
+                        if out_add != 0.0:
+                            out_expr = [x + out_add for x in out_expr]
+                        if out_shrink > 1:
+                            ksize = [1, 1, out_shrink, out_shrink]
+                            out_expr = [tf.nn.avg_pool(x,
+                                                       ksize=ksize,
+                                                       strides=ksize,
+                                                       padding='VALID',
+                                                       data_format='NCHW') for x in out_expr]
+                        if out_dtype is not None:
+                            if tf.as_dtype(out_dtype).is_integer:
+                                out_expr = [tf.round(x) for x in out_expr]
+                            out_expr = [tf.saturate_cast(x, out_dtype) for x in out_expr]
+                        out_split.append(out_expr)
+                self._run_cache[key] = [tf.concat(outputs, axis=0) for outputs in zip(*out_split)]
+
+        # Up to this point we have only built the model's output tensor.
+        out_expr = self._run_cache[key]
+        # Build the loss function:
+        psy_name = str(self.scope + '/etalon')
+        psy = tf.placeholder(tf.float32, out_expr[0].shape, name=psy_name)
+        # Sum of squared errors over all etalons.
+        loss = tf.reduce_sum(tf.pow(out_expr[0] - psy, 2))
+        latents_name = self.input_templates[0].name
+        input_latents = tf.get_default_graph().get_tensor_by_name(latents_name)
+        # Compute the gradient of the loss with respect to the input latents:
+        gradient = tf.gradients(loss, input_latents)
+        # Extend the existing input templates so that the etalons can be fed
+        # into the loss and gradient tensors:
+        templ = self.input_templates
+        templ.append(psy)
+        # Create a new feed dictionary:
+        feed_dict = dict(zip(templ, in_arrays))
+        # The loss and gradient are evaluated with this feed dictionary below.
+        l_rate = learning_rate
+        latents = in_arrays[0]
+        samples_num = latents.shape[0]
+        print("Batch has {} etalons to reverse.".format(samples_num))
+        # For recording the history of iterations.
+        history = []
+        c_min = 1e+9
+        x_min = None
+        # Main optimisation loop. Stochastic clipping is from
+        # 'Precise Recovery of Latent Vectors from Generative
+        # Adversarial Networks', ICLR 2017 workshop track:
+        # https://arxiv.org/abs/1702.04782
+        for i in range(itterations):
+            g = tf.get_default_session().run(
+                [loss, gradient],
+                feed_dict=feed_dict)
+            latents = latents - l_rate * g[1][0]
+            # Clip the updated latents back into [-1, 1].
+            if stohastic_clipping:
+                # Stochastic clipping: resample out-of-range components uniformly.
+                for j in range(samples_num):
+                    edge1 = np.where(latents[j] >= 1.)[0]
+                    edge2 = np.where(latents[j] <= -1)[0]
+                    if edge1.shape[0] > 0:
+                        rand_el1 = np.random.uniform(-1,
+                                                     1,
+                                                     size=(1, edge1.shape[0]))
+                        latents[j, edge1] = rand_el1
+                    if edge2.shape[0] > 0:
+                        rand_el2 = np.random.uniform(-1,
+                                                     1,
+                                                     size=(1, edge2.shape[0]))
+                        latents[j, edge2] = rand_el2
+            else:
+                latents = np.clip(latents, -1, 1)
+
+            # Update the feed dictionary for the next iteration.
+            feed_dict[input_latents] = latents
+
+            if i % 50 == 49:
+                # Halve the learning rate every 50 iterations
+                l_rate /= 2
+                # and record the history.
+                history.append((g[0], latents))
+                print(i, g[0] / samples_num)
+
+            if g[0] < c_min:
+                # Save the best latents seen so far.
+                c_min = g[0]
+                x_min = latents
+
+        # Return the optimisation history of the latents; the best solution is appended last.
+        history.append((c_min, x_min))
+        return history
+
 #----------------------------------------------------------------------------
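Usage note (not part of the diff): a minimal sketch of how the new method might be called, assuming a ProGAN-style setup where networks are unpickled as (G, D, Gs); the snapshot filename and the load_target_images() helper are hypothetical placeholders.

    import pickle
    import numpy as np
    import tfutil

    tfutil.init_tf()
    with open('network-snapshot.pkl', 'rb') as f:  # hypothetical snapshot path
        G, D, Gs = pickle.load(f)

    # Hypothetical helper: etalon images scaled to [-1, 1] and shaped like Gs.output_shape.
    etalons = load_target_images()
    batch = etalons.shape[0]
    latents = np.random.uniform(-1., 1., size=[batch] + Gs.input_shapes[0][1:])
    labels = np.zeros([batch] + Gs.input_shapes[1][1:])

    # Returns a history of (loss, latents) pairs; the last entry is the best solution found.
    history = Gs.reverse_gan_for_etalons(latents, labels, etalons,
                                         itterations=2000, learning_rate=1e-6)
    best_loss, best_latents = history[-1]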