laywerrobot/lib/python3.6/site-packages/tensorflow/contrib/gan/python/train.py
2020-08-27 21:55:39 +02:00

1044 lines
41 KiB
Python

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The TFGAN project provides a lightweight GAN training/testing framework.
This file contains the core helper functions to create and train a GAN model.
See the README or examples in `tensorflow_models` for details on how to use.
TFGAN training occurs in four steps:
1) Create a model
2) Add a loss
3) Create train ops
4) Run the train ops
The functions in this file are organized around these four steps. Each function
corresponds to one of the steps.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.framework.python.ops import variables as variables_lib
from tensorflow.contrib.gan.python import losses as tfgan_losses
from tensorflow.contrib.gan.python import namedtuples
from tensorflow.contrib.slim.python.slim import learning as slim_learning
from tensorflow.contrib.training.python.training import training
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.distributions import distribution as ds
from tensorflow.python.ops.losses import losses
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import sync_replicas_optimizer
from tensorflow.python.training import training_util
__all__ = [
'gan_model',
'infogan_model',
'acgan_model',
'cyclegan_model',
'gan_loss',
'cyclegan_loss',
'gan_train_ops',
'gan_train',
'get_sequential_train_hooks',
'get_joint_train_hooks',
'get_sequential_train_steps',
'RunTrainOpsHook',
]
def gan_model(
# Lambdas defining models.
generator_fn,
discriminator_fn,
# Real data and conditioning.
real_data,
generator_inputs,
# Optional scopes.
generator_scope='Generator',
discriminator_scope='Discriminator',
# Options.
check_shapes=True):
"""Returns GAN model outputs and variables.
Args:
generator_fn: A python lambda that takes `generator_inputs` as inputs and
returns the outputs of the GAN generator.
discriminator_fn: A python lambda that takes `real_data`/`generated data`
and `generator_inputs`. Outputs a Tensor in the range [-inf, inf].
real_data: A Tensor representing the real data.
generator_inputs: A Tensor or list of Tensors to the generator. In the
vanilla GAN case, this might be a single noise Tensor. In the conditional
GAN case, this might be the generator's conditioning.
generator_scope: Optional generator variable scope. Useful if you want to
reuse a subgraph that has already been created.
discriminator_scope: Optional discriminator variable scope. Useful if you
want to reuse a subgraph that has already been created.
check_shapes: If `True`, check that generator produces Tensors that are the
same shape as real data. Otherwise, skip this check.
Returns:
A GANModel namedtuple.
Raises:
ValueError: If the generator outputs a Tensor that isn't the same shape as
`real_data`.
"""
# Create models
with variable_scope.variable_scope(generator_scope) as gen_scope:
generator_inputs = _convert_tensor_or_l_or_d(generator_inputs)
generated_data = generator_fn(generator_inputs)
with variable_scope.variable_scope(discriminator_scope) as dis_scope:
discriminator_gen_outputs = discriminator_fn(generated_data,
generator_inputs)
with variable_scope.variable_scope(dis_scope, reuse=True):
real_data = ops.convert_to_tensor(real_data)
discriminator_real_outputs = discriminator_fn(real_data, generator_inputs)
if check_shapes:
if not generated_data.shape.is_compatible_with(real_data.shape):
raise ValueError(
'Generator output shape (%s) must be the same shape as real data '
'(%s).' % (generated_data.shape, real_data.shape))
# Get model-specific variables.
generator_variables = variables_lib.get_trainable_variables(gen_scope)
discriminator_variables = variables_lib.get_trainable_variables(dis_scope)
return namedtuples.GANModel(
generator_inputs,
generated_data,
generator_variables,
gen_scope,
generator_fn,
real_data,
discriminator_real_outputs,
discriminator_gen_outputs,
discriminator_variables,
dis_scope,
discriminator_fn)
def infogan_model(
# Lambdas defining models.
generator_fn,
discriminator_fn,
# Real data and conditioning.
real_data,
unstructured_generator_inputs,
structured_generator_inputs,
# Optional scopes.
generator_scope='Generator',
discriminator_scope='Discriminator'):
"""Returns an InfoGAN model outputs and variables.
See https://arxiv.org/abs/1606.03657 for more details.
Args:
generator_fn: A python lambda that takes a list of Tensors as inputs and
returns the outputs of the GAN generator.
discriminator_fn: A python lambda that takes `real_data`/`generated data`
and `generator_inputs`. Outputs a 2-tuple of (logits, distribution_list).
`logits` are in the range [-inf, inf], and `distribution_list` is a list
of Tensorflow distributions representing the predicted noise distribution
of the ith structure noise.
real_data: A Tensor representing the real data.
unstructured_generator_inputs: A list of Tensors to the generator.
These tensors represent the unstructured noise or conditioning.
structured_generator_inputs: A list of Tensors to the generator.
These tensors must have high mutual information with the recognizer.
generator_scope: Optional generator variable scope. Useful if you want to
reuse a subgraph that has already been created.
discriminator_scope: Optional discriminator variable scope. Useful if you
want to reuse a subgraph that has already been created.
Returns:
An InfoGANModel namedtuple.
Raises:
ValueError: If the generator outputs a Tensor that isn't the same shape as
`real_data`.
ValueError: If the discriminator output is malformed.
"""
# Create models
with variable_scope.variable_scope(generator_scope) as gen_scope:
unstructured_generator_inputs = _convert_tensor_or_l_or_d(
unstructured_generator_inputs)
structured_generator_inputs = _convert_tensor_or_l_or_d(
structured_generator_inputs)
generator_inputs = (
unstructured_generator_inputs + structured_generator_inputs)
generated_data = generator_fn(generator_inputs)
with variable_scope.variable_scope(discriminator_scope) as disc_scope:
dis_gen_outputs, predicted_distributions = discriminator_fn(
generated_data, generator_inputs)
_validate_distributions(predicted_distributions, structured_generator_inputs)
with variable_scope.variable_scope(disc_scope, reuse=True):
real_data = ops.convert_to_tensor(real_data)
dis_real_outputs, _ = discriminator_fn(real_data, generator_inputs)
if not generated_data.get_shape().is_compatible_with(real_data.get_shape()):
raise ValueError(
'Generator output shape (%s) must be the same shape as real data '
'(%s).' % (generated_data.get_shape(), real_data.get_shape()))
# Get model-specific variables.
generator_variables = variables_lib.get_trainable_variables(gen_scope)
discriminator_variables = variables_lib.get_trainable_variables(
disc_scope)
return namedtuples.InfoGANModel(
generator_inputs,
generated_data,
generator_variables,
gen_scope,
generator_fn,
real_data,
dis_real_outputs,
dis_gen_outputs,
discriminator_variables,
disc_scope,
lambda x, y: discriminator_fn(x, y)[0], # conform to non-InfoGAN API
structured_generator_inputs,
predicted_distributions,
discriminator_fn)
def acgan_model(
# Lambdas defining models.
generator_fn,
discriminator_fn,
# Real data and conditioning.
real_data,
generator_inputs,
one_hot_labels,
# Optional scopes.
generator_scope='Generator',
discriminator_scope='Discriminator',
# Options.
check_shapes=True):
"""Returns an ACGANModel contains all the pieces needed for ACGAN training.
The `acgan_model` is the same as the `gan_model` with the only difference
being that the discriminator additionally outputs logits to classify the input
(real or generated).
Therefore, an explicit field holding one_hot_labels is necessary, as well as a
discriminator_fn that outputs a 2-tuple holding the logits for real/fake and
classification.
See https://arxiv.org/abs/1610.09585 for more details.
Args:
generator_fn: A python lambda that takes `generator_inputs` as inputs and
returns the outputs of the GAN generator.
discriminator_fn: A python lambda that takes `real_data`/`generated data`
and `generator_inputs`. Outputs a tuple consisting of two Tensors:
(1) real/fake logits in the range [-inf, inf]
(2) classification logits in the range [-inf, inf]
real_data: A Tensor representing the real data.
generator_inputs: A Tensor or list of Tensors to the generator. In the
vanilla GAN case, this might be a single noise Tensor. In the conditional
GAN case, this might be the generator's conditioning.
one_hot_labels: A Tensor holding one-hot-labels for the batch. Needed by
acgan_loss.
generator_scope: Optional generator variable scope. Useful if you want to
reuse a subgraph that has already been created.
discriminator_scope: Optional discriminator variable scope. Useful if you
want to reuse a subgraph that has already been created.
check_shapes: If `True`, check that generator produces Tensors that are the
same shape as real data. Otherwise, skip this check.
Returns:
A ACGANModel namedtuple.
Raises:
ValueError: If the generator outputs a Tensor that isn't the same shape as
`real_data`.
TypeError: If the discriminator does not output a tuple consisting of
(discrimination logits, classification logits).
"""
# Create models
with variable_scope.variable_scope(generator_scope) as gen_scope:
generator_inputs = _convert_tensor_or_l_or_d(generator_inputs)
generated_data = generator_fn(generator_inputs)
with variable_scope.variable_scope(discriminator_scope) as dis_scope:
with ops.name_scope(dis_scope.name+'/generated/'):
(discriminator_gen_outputs, discriminator_gen_classification_logits
) = _validate_acgan_discriminator_outputs(
discriminator_fn(generated_data, generator_inputs))
with variable_scope.variable_scope(dis_scope, reuse=True):
with ops.name_scope(dis_scope.name+'/real/'):
real_data = ops.convert_to_tensor(real_data)
(discriminator_real_outputs, discriminator_real_classification_logits
) = _validate_acgan_discriminator_outputs(
discriminator_fn(real_data, generator_inputs))
if check_shapes:
if not generated_data.shape.is_compatible_with(real_data.shape):
raise ValueError(
'Generator output shape (%s) must be the same shape as real data '
'(%s).' % (generated_data.shape, real_data.shape))
# Get model-specific variables.
generator_variables = variables_lib.get_trainable_variables(gen_scope)
discriminator_variables = variables_lib.get_trainable_variables(
dis_scope)
return namedtuples.ACGANModel(
generator_inputs, generated_data, generator_variables, gen_scope,
generator_fn, real_data, discriminator_real_outputs,
discriminator_gen_outputs, discriminator_variables, dis_scope,
discriminator_fn, one_hot_labels,
discriminator_real_classification_logits,
discriminator_gen_classification_logits)
def cyclegan_model(
# Lambdas defining models.
generator_fn,
discriminator_fn,
# data X and Y.
data_x,
data_y,
# Optional scopes.
generator_scope='Generator',
discriminator_scope='Discriminator',
model_x2y_scope='ModelX2Y',
model_y2x_scope='ModelY2X',
# Options.
check_shapes=True):
"""Returns a CycleGAN model outputs and variables.
See https://arxiv.org/abs/1703.10593 for more details.
Args:
generator_fn: A python lambda that takes `data_x` or `data_y` as inputs and
returns the outputs of the GAN generator.
discriminator_fn: A python lambda that takes `real_data`/`generated data`
and `generator_inputs`. Outputs a Tensor in the range [-inf, inf].
data_x: A `Tensor` of dataset X. Must be the same shape as `data_y`.
data_y: A `Tensor` of dataset Y. Must be the same shape as `data_x`.
generator_scope: Optional generator variable scope. Useful if you want to
reuse a subgraph that has already been created. Defaults to 'Generator'.
discriminator_scope: Optional discriminator variable scope. Useful if you
want to reuse a subgraph that has already been created. Defaults to
'Discriminator'.
model_x2y_scope: Optional variable scope for model x2y variables. Defaults
to 'ModelX2Y'.
model_y2x_scope: Optional variable scope for model y2x variables. Defaults
to 'ModelY2X'.
check_shapes: If `True`, check that generator produces Tensors that are the
same shape as `data_x` (`data_y`). Otherwise, skip this check.
Returns:
A `CycleGANModel` namedtuple.
Raises:
ValueError: If `check_shapes` is True and `data_x` or the generator output
does not have the same shape as `data_y`.
"""
# Create models.
def _define_partial_model(input_data, output_data):
return gan_model(
generator_fn=generator_fn,
discriminator_fn=discriminator_fn,
real_data=output_data,
generator_inputs=input_data,
generator_scope=generator_scope,
discriminator_scope=discriminator_scope,
check_shapes=check_shapes)
with variable_scope.variable_scope(model_x2y_scope):
model_x2y = _define_partial_model(data_x, data_y)
with variable_scope.variable_scope(model_y2x_scope):
model_y2x = _define_partial_model(data_y, data_x)
with variable_scope.variable_scope(model_y2x.generator_scope, reuse=True):
reconstructed_x = model_y2x.generator_fn(model_x2y.generated_data)
with variable_scope.variable_scope(model_x2y.generator_scope, reuse=True):
reconstructed_y = model_x2y.generator_fn(model_y2x.generated_data)
return namedtuples.CycleGANModel(model_x2y, model_y2x, reconstructed_x,
reconstructed_y)
def _validate_aux_loss_weight(aux_loss_weight, name='aux_loss_weight'):
if isinstance(aux_loss_weight, ops.Tensor):
aux_loss_weight.shape.assert_is_compatible_with([])
with ops.control_dependencies(
[check_ops.assert_greater_equal(aux_loss_weight, 0.0)]):
aux_loss_weight = array_ops.identity(aux_loss_weight)
elif aux_loss_weight is not None and aux_loss_weight < 0:
raise ValueError('`%s` must be greater than 0. Instead, was %s' %
(name, aux_loss_weight))
return aux_loss_weight
def _use_aux_loss(aux_loss_weight):
if aux_loss_weight is not None:
if not isinstance(aux_loss_weight, ops.Tensor):
return aux_loss_weight > 0
else:
return True
else:
return False
def _tensor_pool_adjusted_model(model, tensor_pool_fn):
"""Adjusts model using `tensor_pool_fn`.
Args:
model: A GANModel tuple.
tensor_pool_fn: A function that takes (generated_data, generator_inputs),
stores them in an internal pool and returns a previously stored
(generated_data, generator_inputs) with some probability. For example
tfgan.features.tensor_pool.
Returns:
A new GANModel tuple where discriminator outputs are adjusted by taking
pooled generator outputs as inputs. Returns the original model if
`tensor_pool_fn` is None.
Raises:
ValueError: If tensor pool does not support the `model`.
"""
if tensor_pool_fn is None:
return model
pooled_generated_data, pooled_generator_inputs = tensor_pool_fn(
(model.generated_data, model.generator_inputs))
if isinstance(model, namedtuples.GANModel):
with variable_scope.variable_scope(model.discriminator_scope, reuse=True):
dis_gen_outputs = model.discriminator_fn(pooled_generated_data,
pooled_generator_inputs)
return model._replace(discriminator_gen_outputs=dis_gen_outputs)
elif isinstance(model, namedtuples.ACGANModel):
with variable_scope.variable_scope(model.discriminator_scope, reuse=True):
(dis_pooled_gen_outputs,
dis_pooled_gen_classification_logits) = model.discriminator_fn(
pooled_generated_data, pooled_generator_inputs)
return model._replace(
discriminator_gen_outputs=dis_pooled_gen_outputs,
discriminator_gen_classification_logits=
dis_pooled_gen_classification_logits)
elif isinstance(model, namedtuples.InfoGANModel):
with variable_scope.variable_scope(model.discriminator_scope, reuse=True):
(dis_pooled_gen_outputs,
pooled_predicted_distributions) = model.discriminator_and_aux_fn(
pooled_generated_data, pooled_generator_inputs)
return model._replace(
discriminator_gen_outputs=dis_pooled_gen_outputs,
predicted_distributions=pooled_predicted_distributions)
else:
raise ValueError('Tensor pool does not support `model`: %s.' % type(model))
def gan_loss(
# GANModel.
model,
# Loss functions.
generator_loss_fn=tfgan_losses.wasserstein_generator_loss,
discriminator_loss_fn=tfgan_losses.wasserstein_discriminator_loss,
# Auxiliary losses.
gradient_penalty_weight=None,
gradient_penalty_epsilon=1e-10,
gradient_penalty_target=1.0,
gradient_penalty_one_sided=False,
mutual_information_penalty_weight=None,
aux_cond_generator_weight=None,
aux_cond_discriminator_weight=None,
tensor_pool_fn=None,
# Options.
add_summaries=True):
"""Returns losses necessary to train generator and discriminator.
Args:
model: A GANModel tuple.
generator_loss_fn: The loss function on the generator. Takes a GANModel
tuple.
discriminator_loss_fn: The loss function on the discriminator. Takes a
GANModel tuple.
gradient_penalty_weight: If not `None`, must be a non-negative Python number
or Tensor indicating how much to weight the gradient penalty. See
https://arxiv.org/pdf/1704.00028.pdf for more details.
gradient_penalty_epsilon: If `gradient_penalty_weight` is not None, the
small positive value used by the gradient penalty function for numerical
stability. Note some applications will need to increase this value to
avoid NaNs.
gradient_penalty_target: If `gradient_penalty_weight` is not None, a Python
number or `Tensor` indicating the target value of gradient norm. See the
CIFAR10 section of https://arxiv.org/abs/1710.10196. Defaults to 1.0.
gradient_penalty_one_sided: If `True`, penalty proposed in
https://arxiv.org/abs/1709.08894 is used. Defaults to `False`.
mutual_information_penalty_weight: If not `None`, must be a non-negative
Python number or Tensor indicating how much to weight the mutual
information penalty. See https://arxiv.org/abs/1606.03657 for more
details.
aux_cond_generator_weight: If not None: add a classification loss as in
https://arxiv.org/abs/1610.09585
aux_cond_discriminator_weight: If not None: add a classification loss as in
https://arxiv.org/abs/1610.09585
tensor_pool_fn: A function that takes (generated_data, generator_inputs),
stores them in an internal pool and returns previous stored
(generated_data, generator_inputs). For example
`tf.gan.features.tensor_pool`. Defaults to None (not using tensor pool).
add_summaries: Whether or not to add summaries for the losses.
Returns:
A GANLoss 2-tuple of (generator_loss, discriminator_loss). Includes
regularization losses.
Raises:
ValueError: If any of the auxiliary loss weights is provided and negative.
ValueError: If `mutual_information_penalty_weight` is provided, but the
`model` isn't an `InfoGANModel`.
"""
# Validate arguments.
gradient_penalty_weight = _validate_aux_loss_weight(gradient_penalty_weight,
'gradient_penalty_weight')
mutual_information_penalty_weight = _validate_aux_loss_weight(
mutual_information_penalty_weight, 'infogan_weight')
aux_cond_generator_weight = _validate_aux_loss_weight(
aux_cond_generator_weight, 'aux_cond_generator_weight')
aux_cond_discriminator_weight = _validate_aux_loss_weight(
aux_cond_discriminator_weight, 'aux_cond_discriminator_weight')
# Verify configuration for mutual information penalty
if (_use_aux_loss(mutual_information_penalty_weight) and
not isinstance(model, namedtuples.InfoGANModel)):
raise ValueError(
'When `mutual_information_penalty_weight` is provided, `model` must be '
'an `InfoGANModel`. Instead, was %s.' % type(model))
# Verify configuration for mutual auxiliary condition loss (ACGAN).
if ((_use_aux_loss(aux_cond_generator_weight) or
_use_aux_loss(aux_cond_discriminator_weight)) and
not isinstance(model, namedtuples.ACGANModel)):
raise ValueError(
'When `aux_cond_generator_weight` or `aux_cond_discriminator_weight` '
'is provided, `model` must be an `ACGANModel`. Instead, was %s.' %
type(model))
# Create standard losses.
gen_loss = generator_loss_fn(model, add_summaries=add_summaries)
dis_loss = discriminator_loss_fn(
_tensor_pool_adjusted_model(model, tensor_pool_fn),
add_summaries=add_summaries)
# Add optional extra losses.
if _use_aux_loss(gradient_penalty_weight):
gp_loss = tfgan_losses.wasserstein_gradient_penalty(
model,
epsilon=gradient_penalty_epsilon,
target=gradient_penalty_target,
one_sided=gradient_penalty_one_sided,
add_summaries=add_summaries)
dis_loss += gradient_penalty_weight * gp_loss
if _use_aux_loss(mutual_information_penalty_weight):
info_loss = tfgan_losses.mutual_information_penalty(
model, add_summaries=add_summaries)
dis_loss += mutual_information_penalty_weight * info_loss
gen_loss += mutual_information_penalty_weight * info_loss
if _use_aux_loss(aux_cond_generator_weight):
ac_gen_loss = tfgan_losses.acgan_generator_loss(
model, add_summaries=add_summaries)
gen_loss += aux_cond_generator_weight * ac_gen_loss
if _use_aux_loss(aux_cond_discriminator_weight):
ac_disc_loss = tfgan_losses.acgan_discriminator_loss(
model, add_summaries=add_summaries)
dis_loss += aux_cond_discriminator_weight * ac_disc_loss
# Gathers auxiliary losses.
if model.generator_scope:
gen_reg_loss = losses.get_regularization_loss(model.generator_scope.name)
else:
gen_reg_loss = 0
if model.discriminator_scope:
dis_reg_loss = losses.get_regularization_loss(
model.discriminator_scope.name)
else:
dis_reg_loss = 0
return namedtuples.GANLoss(gen_loss + gen_reg_loss, dis_loss + dis_reg_loss)
def cyclegan_loss(
model,
# Loss functions.
generator_loss_fn=tfgan_losses.least_squares_generator_loss,
discriminator_loss_fn=tfgan_losses.least_squares_discriminator_loss,
# Auxiliary losses.
cycle_consistency_loss_fn=tfgan_losses.cycle_consistency_loss,
cycle_consistency_loss_weight=10.0,
# Options
**kwargs):
"""Returns the losses for a `CycleGANModel`.
See https://arxiv.org/abs/1703.10593 for more details.
Args:
model: A `CycleGANModel` namedtuple.
generator_loss_fn: The loss function on the generator. Takes a `GANModel`
named tuple.
discriminator_loss_fn: The loss function on the discriminator. Takes a
`GANModel` namedtuple.
cycle_consistency_loss_fn: The cycle consistency loss function. Takes a
`CycleGANModel` namedtuple.
cycle_consistency_loss_weight: A non-negative Python number or a scalar
`Tensor` indicating how much to weigh the cycle consistency loss.
**kwargs: Keyword args to pass directly to `gan_loss` to construct the loss
for each partial model of `model`.
Returns:
A `CycleGANLoss` namedtuple.
Raises:
ValueError: If `model` is not a `CycleGANModel` namedtuple.
"""
# Sanity checks.
if not isinstance(model, namedtuples.CycleGANModel):
raise ValueError(
'`model` must be a `CycleGANModel`. Instead, was %s.' % type(model))
# Defines cycle consistency loss.
cycle_consistency_loss = cycle_consistency_loss_fn(
model, add_summaries=kwargs.get('add_summaries', True))
cycle_consistency_loss_weight = _validate_aux_loss_weight(
cycle_consistency_loss_weight, 'cycle_consistency_loss_weight')
aux_loss = cycle_consistency_loss_weight * cycle_consistency_loss
# Defines losses for each partial model.
def _partial_loss(partial_model):
partial_loss = gan_loss(
partial_model,
generator_loss_fn=generator_loss_fn,
discriminator_loss_fn=discriminator_loss_fn,
**kwargs)
return partial_loss._replace(
generator_loss=partial_loss.generator_loss + aux_loss)
with ops.name_scope('cyclegan_loss_x2y'):
loss_x2y = _partial_loss(model.model_x2y)
with ops.name_scope('cyclegan_loss_y2x'):
loss_y2x = _partial_loss(model.model_y2x)
return namedtuples.CycleGANLoss(loss_x2y, loss_y2x)
def _get_update_ops(kwargs, gen_scope, dis_scope, check_for_unused_ops=True):
"""Gets generator and discriminator update ops.
Args:
kwargs: A dictionary of kwargs to be passed to `create_train_op`.
`update_ops` is removed, if present.
gen_scope: A scope for the generator.
dis_scope: A scope for the discriminator.
check_for_unused_ops: A Python bool. If `True`, throw Exception if there are
unused update ops.
Returns:
A 2-tuple of (generator update ops, discriminator train ops).
Raises:
ValueError: If there are update ops outside of the generator or
discriminator scopes.
"""
if 'update_ops' in kwargs:
update_ops = set(kwargs['update_ops'])
del kwargs['update_ops']
else:
update_ops = set(ops.get_collection(ops.GraphKeys.UPDATE_OPS))
all_gen_ops = set(ops.get_collection(ops.GraphKeys.UPDATE_OPS, gen_scope))
all_dis_ops = set(ops.get_collection(ops.GraphKeys.UPDATE_OPS, dis_scope))
if check_for_unused_ops:
unused_ops = update_ops - all_gen_ops - all_dis_ops
if unused_ops:
raise ValueError('There are unused update ops: %s' % unused_ops)
gen_update_ops = list(all_gen_ops & update_ops)
dis_update_ops = list(all_dis_ops & update_ops)
return gen_update_ops, dis_update_ops
def gan_train_ops(
model,
loss,
generator_optimizer,
discriminator_optimizer,
check_for_unused_update_ops=True,
# Optional args to pass directly to the `create_train_op`.
**kwargs):
"""Returns GAN train ops.
The highest-level call in TFGAN. It is composed of functions that can also
be called, should a user require more control over some part of the GAN
training process.
Args:
model: A GANModel.
loss: A GANLoss.
generator_optimizer: The optimizer for generator updates.
discriminator_optimizer: The optimizer for the discriminator updates.
check_for_unused_update_ops: If `True`, throws an exception if there are
update ops outside of the generator or discriminator scopes.
**kwargs: Keyword args to pass directly to
`training.create_train_op` for both the generator and
discriminator train op.
Returns:
A GANTrainOps tuple of (generator_train_op, discriminator_train_op) that can
be used to train a generator/discriminator pair.
"""
if isinstance(model, namedtuples.CycleGANModel):
# Get and store all arguments other than model and loss from locals.
# Contents of locals should not be modified, may not affect values. So make
# a copy. https://docs.python.org/2/library/functions.html#locals.
saved_params = dict(locals())
saved_params.pop('model', None)
saved_params.pop('loss', None)
kwargs = saved_params.pop('kwargs', {})
saved_params.update(kwargs)
with ops.name_scope('cyclegan_x2y_train'):
train_ops_x2y = gan_train_ops(model.model_x2y, loss.loss_x2y,
**saved_params)
with ops.name_scope('cyclegan_y2x_train'):
train_ops_y2x = gan_train_ops(model.model_y2x, loss.loss_y2x,
**saved_params)
return namedtuples.GANTrainOps(
(train_ops_x2y.generator_train_op, train_ops_y2x.generator_train_op),
(train_ops_x2y.discriminator_train_op,
train_ops_y2x.discriminator_train_op),
training_util.get_or_create_global_step().assign_add(1))
# Create global step increment op.
global_step = training_util.get_or_create_global_step()
global_step_inc = global_step.assign_add(1)
# Get generator and discriminator update ops. We split them so that update
# ops aren't accidentally run multiple times. For now, throw an error if
# there are update ops that aren't associated with either the generator or
# the discriminator. Might modify the `kwargs` dictionary.
gen_update_ops, dis_update_ops = _get_update_ops(
kwargs, model.generator_scope.name, model.discriminator_scope.name,
check_for_unused_update_ops)
generator_global_step = None
if isinstance(generator_optimizer,
sync_replicas_optimizer.SyncReplicasOptimizer):
# TODO(joelshor): Figure out a way to get this work without including the
# dummy global step in the checkpoint.
# WARNING: Making this variable a local variable causes sync replicas to
# hang forever.
generator_global_step = variable_scope.get_variable(
'dummy_global_step_generator',
shape=[],
dtype=global_step.dtype.base_dtype,
initializer=init_ops.zeros_initializer(),
trainable=False,
collections=[ops.GraphKeys.GLOBAL_VARIABLES])
gen_update_ops += [generator_global_step.assign(global_step)]
with ops.name_scope('generator_train'):
gen_train_op = training.create_train_op(
total_loss=loss.generator_loss,
optimizer=generator_optimizer,
variables_to_train=model.generator_variables,
global_step=generator_global_step,
update_ops=gen_update_ops,
**kwargs)
discriminator_global_step = None
if isinstance(discriminator_optimizer,
sync_replicas_optimizer.SyncReplicasOptimizer):
# See comment above `generator_global_step`.
discriminator_global_step = variable_scope.get_variable(
'dummy_global_step_discriminator',
shape=[],
dtype=global_step.dtype.base_dtype,
initializer=init_ops.zeros_initializer(),
trainable=False,
collections=[ops.GraphKeys.GLOBAL_VARIABLES])
dis_update_ops += [discriminator_global_step.assign(global_step)]
with ops.name_scope('discriminator_train'):
disc_train_op = training.create_train_op(
total_loss=loss.discriminator_loss,
optimizer=discriminator_optimizer,
variables_to_train=model.discriminator_variables,
global_step=discriminator_global_step,
update_ops=dis_update_ops,
**kwargs)
return namedtuples.GANTrainOps(gen_train_op, disc_train_op, global_step_inc)
# TODO(joelshor): Implement a dynamic GAN train loop, as in `Real-Time Adaptive
# Image Compression` (https://arxiv.org/abs/1705.05823)
class RunTrainOpsHook(session_run_hook.SessionRunHook):
"""A hook to run train ops a fixed number of times."""
def __init__(self, train_ops, train_steps):
"""Run train ops a certain number of times.
Args:
train_ops: A train op or iterable of train ops to run.
train_steps: The number of times to run the op(s).
"""
if not isinstance(train_ops, (list, tuple)):
train_ops = [train_ops]
self._train_ops = train_ops
self._train_steps = train_steps
def before_run(self, run_context):
for _ in range(self._train_steps):
run_context.session.run(self._train_ops)
def get_sequential_train_hooks(train_steps=namedtuples.GANTrainSteps(1, 1)):
"""Returns a hooks function for sequential GAN training.
Args:
train_steps: A `GANTrainSteps` tuple that determines how many generator
and discriminator training steps to take.
Returns:
A function that takes a GANTrainOps tuple and returns a list of hooks.
"""
def get_hooks(train_ops):
generator_hook = RunTrainOpsHook(train_ops.generator_train_op,
train_steps.generator_train_steps)
discriminator_hook = RunTrainOpsHook(train_ops.discriminator_train_op,
train_steps.discriminator_train_steps)
return [generator_hook, discriminator_hook]
return get_hooks
def get_joint_train_hooks(train_steps=namedtuples.GANTrainSteps(1, 1)):
"""Returns a hooks function for sequential GAN training.
When using these train hooks, IT IS RECOMMENDED TO USE `use_locking=True` ON
ALL OPTIMIZERS TO AVOID RACE CONDITIONS.
The order of steps taken is:
1) Combined generator and discriminator steps
2) Generator only steps, if any remain
3) Discriminator only steps, if any remain
**NOTE**: Unlike `get_sequential_train_hooks`, this method performs updates
for the generator and discriminator simultaneously whenever possible. This
reduces the number of `tf.Session` calls, and can also change the training
semantics.
To illustrate the difference look at the following example:
`train_steps=namedtuples.GANTrainSteps(3, 5)` will cause
`get_sequential_train_hooks` to make 8 session calls:
1) 3 generator steps
2) 5 discriminator steps
In contrast, `get_joint_train_steps` will make 5 session calls:
1) 3 generator + discriminator steps
2) 2 discriminator steps
Args:
train_steps: A `GANTrainSteps` tuple that determines how many generator
and discriminator training steps to take.
Returns:
A function that takes a GANTrainOps tuple and returns a list of hooks.
"""
g_steps = train_steps.generator_train_steps
d_steps = train_steps.discriminator_train_steps
# Get the number of each type of step that should be run.
num_d_and_g_steps = min(g_steps, d_steps)
num_g_steps = g_steps - num_d_and_g_steps
num_d_steps = d_steps - num_d_and_g_steps
def get_hooks(train_ops):
g_op = train_ops.generator_train_op
d_op = train_ops.discriminator_train_op
joint_hook = RunTrainOpsHook([g_op, d_op], num_d_and_g_steps)
g_hook = RunTrainOpsHook(g_op, num_g_steps)
d_hook = RunTrainOpsHook(d_op, num_d_steps)
return [joint_hook, g_hook, d_hook]
return get_hooks
# TODO(joelshor): This function currently returns the global step. Find a
# good way for it to return the generator, discriminator, and final losses.
def gan_train(
train_ops,
logdir,
get_hooks_fn=get_sequential_train_hooks(),
master='',
is_chief=True,
scaffold=None,
hooks=None,
chief_only_hooks=None,
save_checkpoint_secs=600,
save_summaries_steps=100,
config=None):
"""A wrapper around `contrib.training.train` that uses GAN hooks.
Args:
train_ops: A GANTrainOps named tuple.
logdir: The directory where the graph and checkpoints are saved.
get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list
of hooks.
master: The URL of the master.
is_chief: Specifies whether or not the training is being run by the primary
replica during replica training.
scaffold: An tf.train.Scaffold instance.
hooks: List of `tf.train.SessionRunHook` callbacks which are run inside the
training loop.
chief_only_hooks: List of `tf.train.SessionRunHook` instances which are run
inside the training loop for the chief trainer only.
save_checkpoint_secs: The frequency, in seconds, that a checkpoint is saved
using a default checkpoint saver. If `save_checkpoint_secs` is set to
`None`, then the default checkpoint saver isn't used.
save_summaries_steps: The frequency, in number of global steps, that the
summaries are written to disk using a default summary saver. If
`save_summaries_steps` is set to `None`, then the default summary saver
isn't used.
config: An instance of `tf.ConfigProto`.
Returns:
Output of the call to `training.train`.
"""
new_hooks = get_hooks_fn(train_ops)
if hooks is not None:
hooks = list(hooks) + list(new_hooks)
else:
hooks = new_hooks
return training.train(
train_ops.global_step_inc_op,
logdir,
master=master,
is_chief=is_chief,
scaffold=scaffold,
hooks=hooks,
chief_only_hooks=chief_only_hooks,
save_checkpoint_secs=save_checkpoint_secs,
save_summaries_steps=save_summaries_steps,
config=config)
def get_sequential_train_steps(
train_steps=namedtuples.GANTrainSteps(1, 1)):
"""Returns a thin wrapper around slim.learning.train_step, for GANs.
This function is to provide support for the Supervisor. For new code, please
use `MonitoredSession` and `get_sequential_train_hooks`.
Args:
train_steps: A `GANTrainSteps` tuple that determines how many generator
and discriminator training steps to take.
Returns:
A function that can be used for `train_step_fn` for GANs.
"""
def sequential_train_steps(sess, train_ops, global_step, train_step_kwargs):
"""A thin wrapper around slim.learning.train_step, for GANs.
Args:
sess: A Tensorflow session.
train_ops: A GANTrainOps tuple of train ops to run.
global_step: The global step.
train_step_kwargs: Dictionary controlling `train_step` behavior.
Returns:
A scalar final loss and a bool whether or not the train loop should stop.
"""
# Only run `should_stop` at the end, if required. Make a local copy of
# `train_step_kwargs`, if necessary, so as not to modify the caller's
# dictionary.
should_stop_op, train_kwargs = None, train_step_kwargs
if 'should_stop' in train_step_kwargs:
should_stop_op = train_step_kwargs['should_stop']
train_kwargs = train_step_kwargs.copy()
del train_kwargs['should_stop']
# Run generator training steps.
gen_loss = 0
for _ in range(train_steps.generator_train_steps):
cur_gen_loss, _ = slim_learning.train_step(
sess, train_ops.generator_train_op, global_step, train_kwargs)
gen_loss += cur_gen_loss
# Run discriminator training steps.
dis_loss = 0
for _ in range(train_steps.discriminator_train_steps):
cur_dis_loss, _ = slim_learning.train_step(
sess, train_ops.discriminator_train_op, global_step, train_kwargs)
dis_loss += cur_dis_loss
sess.run(train_ops.global_step_inc_op)
# Run the `should_stop` op after the global step has been incremented, so
# that the `should_stop` aligns with the proper `global_step` count.
if should_stop_op is not None:
should_stop = sess.run(should_stop_op)
else:
should_stop = False
return gen_loss + dis_loss, should_stop
return sequential_train_steps
# Helpers
def _convert_tensor_or_l_or_d(tensor_or_l_or_d):
"""Convert input, list of inputs, or dictionary of inputs to Tensors."""
if isinstance(tensor_or_l_or_d, (list, tuple)):
return [ops.convert_to_tensor(x) for x in tensor_or_l_or_d]
elif isinstance(tensor_or_l_or_d, dict):
return {k: ops.convert_to_tensor(v) for k, v in tensor_or_l_or_d.items()}
else:
return ops.convert_to_tensor(tensor_or_l_or_d)
def _validate_distributions(distributions_l, noise_l):
if not isinstance(distributions_l, (tuple, list)):
raise ValueError('`predicted_distributions` must be a list. Instead, found '
'%s.' % type(distributions_l))
for dist in distributions_l:
if not isinstance(dist, ds.Distribution):
raise ValueError('Every element in `predicted_distributions` must be a '
'`tf.Distribution`. Instead, found %s.' % type(dist))
if len(distributions_l) != len(noise_l):
raise ValueError('Length of `predicted_distributions` %i must be the same '
'as the length of structured noise %i.' %
(len(distributions_l), len(noise_l)))
def _validate_acgan_discriminator_outputs(discriminator_output):
try:
a, b = discriminator_output
except (TypeError, ValueError):
raise TypeError(
'A discriminator function for ACGAN must output a tuple '
'consisting of (discrimination logits, classification logits).')
return a, b