145 changes: 138 additions & 7 deletions tfutil.py
@@ -452,7 +452,7 @@ def _init_fields(self):
        self._build_func_name = None    # Name of the build function.
        self._build_module_src = None   # Full source code of the module containing the build function.
        self._run_cache = dict()        # Cached graph data for Network.run().

    def _init_graph(self):
        # Collect inputs.
        self.input_names = []
@@ -466,22 +466,22 @@ def _init_graph(self):
        if self.name is None:
            self.name = self._build_func_name
        self.scope = tf.get_default_graph().unique_name(self.name.replace('/', '_'), mark_as_used=False)

        # Build template graph.
        with tf.variable_scope(self.scope, reuse=tf.AUTO_REUSE):
            assert tf.get_variable_scope().name == self.scope
            with absolute_name_scope(self.scope): # ignore surrounding name_scope
                with tf.control_dependencies(None): # ignore surrounding control_dependencies
                    self.input_templates = [tf.placeholder(tf.float32, name=name) for name in self.input_names]
                    out_expr = self._build_func(*self.input_templates, is_template_graph=True, **self.static_kwargs)

        # Collect outputs.
        assert is_tf_expression(out_expr) or isinstance(out_expr, tuple)
        self.output_templates = [out_expr] if is_tf_expression(out_expr) else list(out_expr)
        self.output_names = [t.name.split('/')[-1].split(':')[0] for t in self.output_templates]
        self.num_outputs = len(self.output_templates)
        assert self.num_outputs >= 1

        # Populate remaining fields.
        self.input_shapes = [shape_to_list(t.shape) for t in self.input_templates]
        self.output_shapes = [shape_to_list(t.shape) for t in self.output_templates]
@@ -530,7 +530,7 @@ def find_var(self, var_or_localname):
    # Note: This method is very inefficient -- prefer to use tfutil.run(list_of_vars) whenever possible.
    def get_var(self, var_or_localname):
        return self.find_var(var_or_localname).eval()

    # Set the value of a given variable based on the given NumPy array.
    # Note: This method is very inefficient -- prefer to use tfutil.set_vars() whenever possible.
    def set_var(self, var_or_localname, new_value):
@@ -560,13 +560,13 @@ def __setstate__(self, state):
        self.static_kwargs = state['static_kwargs']
        self._build_module_src = state['build_module_src']
        self._build_func_name = state['build_func_name']

        # Parse imported module.
        module = imp.new_module('_tfutil_network_import_module_%d' % len(_network_import_modules))
        exec(self._build_module_src, module.__dict__)
        self._build_func = find_obj_in_module(module, self._build_func_name)
        _network_import_modules.append(module) # avoid gc

        # Init graph.
        self._init_graph()
        self.reset_vars()
@@ -746,4 +746,135 @@ def setup_weight_histograms(self, title=None):
                    name = title + '_toplevel/' + localname
                tf.summary.histogram(name, var)
    # This method takes everything that run() takes, plus etalon (target) images,
    # and finds the latents that approximate those etalons.
    # Usage: Gs.reverse_gan_for_etalons(latents, labels, etalons)
    # where etalons.shape is e.g. (?, 1024, 1024, 3) with values in [-1, 1].
    # Returns the history of latents, with the last entry being the best solution.
    # (See the usage sketch after this method.)
    def reverse_gan_for_etalons(self,
        *in_arrays,                         # Expects start values of the latents, any labels, and the etalon images.
        iterations          = 100,          # Number of optimization iterations to run. Empirically, 2000 is a good value.
        learning_rate       = 0.000001,     # Initial learning rate.
        stochastic_clipping = True,         # Use stochastic clipping (see below) instead of standard clipping.
        return_as_list      = False,        # True = return a list of NumPy arrays, False = return a single NumPy array, or a tuple if there are multiple outputs.
        print_progress      = False,        # Print progress to the console? Useful for very large input arrays.
        minibatch_size      = None,         # Maximum minibatch size to use, None = disable batching.
        num_gpus            = 1,            # Number of GPUs to use.
        out_mul             = 1.0,          # Multiplicative constant to apply to the output(s).
        out_add             = 0.0,          # Additive constant to apply to the output(s).
        out_shrink          = 1,            # Shrink the spatial dimensions of the output(s) by the given factor.
        out_dtype           = None,         # Convert the output to the specified data type.
        **dynamic_kwargs):                  # Additional keyword arguments to pass into the network construction function.

        assert len(in_arrays) == 3
        num_items = in_arrays[0].shape[0]
        if minibatch_size is None:
            minibatch_size = num_items
        key = str([list(sorted(dynamic_kwargs.items())), num_gpus, out_mul, out_add, out_shrink, out_dtype])
        # Build graph. Same as in the run() method.
        if key not in self._run_cache:
            with absolute_name_scope(self.scope + '/Run'), tf.control_dependencies(None):
                in_split = list(zip(*[tf.split(x, num_gpus) for x in self.input_templates]))
                out_split = []
                for gpu in range(num_gpus):
                    with tf.device('/gpu:%d' % gpu):
                        out_expr = self.get_output_for(*in_split[gpu], return_as_list=True, **dynamic_kwargs)
                        if out_mul != 1.0:
                            out_expr = [x * out_mul for x in out_expr]
                        if out_add != 0.0:
                            out_expr = [x + out_add for x in out_expr]
                        if out_shrink > 1:
                            ksize = [1, 1, out_shrink, out_shrink]
                            out_expr = [tf.nn.avg_pool(x, ksize=ksize, strides=ksize, padding='VALID', data_format='NCHW') for x in out_expr]
                        if out_dtype is not None:
                            if tf.as_dtype(out_dtype).is_integer:
                                out_expr = [tf.round(x) for x in out_expr]
                            out_expr = [tf.saturate_cast(x, out_dtype) for x in out_expr]
                        out_split.append(out_expr)
                self._run_cache[key] = [tf.concat(outputs, axis=0) for outputs in zip(*out_split)]
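        # The cached expressions are keyed by the output-transform settings above,
        # so repeated calls with the same settings reuse the same graph nodes
        # instead of growing the graph.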

        # Up to this point we have only built (or fetched) the cached output
        # tensors for the model.
        out_expr = self._run_cache[key]
        # Build the loss function:
        psy_name = str(self.scope + '/etalon')
        psy = tf.placeholder(tf.float32, out_expr[0].shape, name=psy_name)
        # Sum-of-squared-errors loss over all etalons.
        loss = tf.reduce_sum(tf.pow(out_expr[0] - psy, 2))
        latents_name = self.input_templates[0].name
        input_latents = tf.get_default_graph().get_tensor_by_name(latents_name)
        # Gradient of the loss with respect to the input latents.
        # tf.gradients returns a list, hence the g[1][0] indexing below.
        gradient = tf.gradients(loss, input_latents)
        # Extend a copy of the input templates so the etalons can be fed into
        # the loss and gradient tensors. (Copying avoids mutating
        # self.input_templates in place across calls.)
        templ = list(self.input_templates)
        templ.append(psy)
        # Create the feed dictionary used throughout the optimization:
        feed_dict = dict(zip(templ, in_arrays))
        l_rate = learning_rate
        latents = in_arrays[0]
        samples_num = latents.shape[0]
print("Batch has {} etalons to reverese.".format(samples_num))
# for recording the story of itterations
history = []
c_min = 1e+9
x_min = None
# Here is main optimisation logic. Stohastic clipping is
# from 'Precise Recovery of Latent Vectors from Generative
# Adversarial Networks', ICLR 2017 workshop track
# [arxiv]. https://arxiv.org/abs/1702.04782
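        # In short: after each gradient step, any latent component that has left
        # the valid range [-1, 1] is resampled uniformly at random from [-1, 1]
        # rather than clamped to the boundary, which keeps components from
        # sticking to the hypercube surface.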
        for i in range(iterations):
            # g[0] is the current loss, g[1] the gradient w.r.t. the latents.
            g = tf.get_default_session().run([loss, gradient], feed_dict=feed_dict)
            if g[0] < c_min:
                # Save the best latents seen so far. g[0] is the loss for the
                # latents currently in feed_dict, so record them before the
                # gradient step; copy() protects against the in-place clipping below.
                c_min = g[0]
                x_min = latents.copy()
            latents = latents - l_rate * g[1][0]
            if stochastic_clipping:
                # Stochastic clipping: resample out-of-range components.
                for j in range(samples_num):
                    edge1 = np.where(latents[j] >= 1.)[0]
                    edge2 = np.where(latents[j] <= -1)[0]
                    if edge1.shape[0] > 0:
                        latents[j, edge1] = np.random.uniform(-1, 1, size=edge1.shape[0])
                    if edge2.shape[0] > 0:
                        latents[j, edge2] = np.random.uniform(-1, 1, size=edge2.shape[0])
            else:
                # Standard clipping: clamp to the valid range.
                latents = np.clip(latents, -1, 1)

            # Update the feed dictionary for the next iteration.
            feed_dict[input_latents] = latents

            if i % 50 == 49:
                # Halve the learning rate every 50 iterations. Note that l_rate
                # is the value actually applied in the update step above.
                l_rate /= 2
                # And record the history.
                history.append((g[0], latents))
                print(i, g[0] / samples_num)

        # Return the optimization history; the final entry holds the best
        # loss and latents found.
        history.append((c_min, x_min))
        return history
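
    # A minimal usage sketch (an illustration, not part of this diff): it
    # assumes a ProGAN-style pickle that unpacks to (G, D, Gs), a hypothetical
    # file name 'network.pkl', and etalon images prepared by the caller;
    # tfutil.init_tf() and Gs.input_shapes are as defined in this file.
    #
    #   import pickle
    #   import numpy as np
    #   import tfutil
    #
    #   tfutil.init_tf()                                       # initialize the TensorFlow session
    #   with open('network.pkl', 'rb') as f:
    #       G, D, Gs = pickle.load(f)
    #
    #   latents = np.random.randn(1, *Gs.input_shapes[0][1:])  # starting guess for the latents
    #   labels  = np.zeros([1] + Gs.input_shapes[1][1:])       # zero labels for an unconditional model
    #   etalons = ...                                          # target images shaped like the generator output, in [-1, 1]
    #
    #   history = Gs.reverse_gan_for_etalons(latents, labels, etalons, iterations=2000)
    #   best_loss, best_latents = history[-1]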

#----------------------------------------------------------------------------