# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Base class for optimizers."""
# pylint: disable=g-bad-name

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc

from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import slot_creator
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


def get_filtered_grad_fn(grad_fn):
  # `distributed_context.join()` requires that its arguments are parallel
  # across threads, and in particular that `grads_and_vars` has the same
  # variables in the same order.

  # When computing gradients in eager mode with multiple threads, you
  # can get extra variables with a gradient of `None`. This happens when
  # those variables are accessed in another thread during the gradient
  # computation. To get a consistent set of variables, we filter out
  # those with `None` gradients.
  def filtered_grad_fn(x=None):
    return [(g, v) for g, v in grad_fn(x) if g is not None]

  return filtered_grad_fn


def _deduplicate_indexed_slices(values, indices):
  """Sums `values` associated with any non-unique `indices`.

  Args:
    values: A `Tensor` with rank >= 1.
    indices: A one-dimensional integer `Tensor`, indexing into the first
      dimension of `values` (as in an IndexedSlices object).
  Returns:
    A tuple of (`summed_values`, `unique_indices`) where `unique_indices` is a
    de-duplicated version of `indices` and `summed_values` contains the sum of
    `values` slices associated with each unique index.
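
  For example (illustrative): `values=[1., 2., 3.]` with `indices=[0, 1, 0]`
  yields `summed_values=[4., 2.]` and `unique_indices=[0, 1]`.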
  """
  unique_indices, new_index_positions = array_ops.unique(indices)
  summed_values = math_ops.unsorted_segment_sum(
      values, new_index_positions,
      array_ops.shape(unique_indices)[0])
  return (summed_values, unique_indices)


def _var_key(var):
  # TODO(ashankar): Consolidate handling for eager and graph
  if hasattr(var, "op"):
    return (var.op.graph, var.op.name)
  return var._unique_id  # pylint: disable=protected-access


class _OptimizableVariable(object):
  """Interface for abstracting over variables in the optimizers."""

  @abc.abstractmethod
  def target(self):
    """Returns the optimization target for this variable."""
    raise NotImplementedError("Calling an abstract method.")

  @abc.abstractmethod
  def update_op(self, optimizer, g):
    """Returns the update ops for updating the variable."""
    raise NotImplementedError("Calling an abstract method.")


class _RefVariableProcessor(_OptimizableVariable):
  """Processor for Variable."""

  def __init__(self, v):
    self._v = v

  def __str__(self):
    return "<_RefVariableProcessor(%s)>" % self._v

  def target(self):
    return self._v._ref()  # pylint: disable=protected-access

  def update_op(self, optimizer, g):
    if isinstance(g, ops.Tensor):
      update_op = optimizer._apply_dense(g, self._v)  # pylint: disable=protected-access
      if self._v.constraint is not None:
        with ops.control_dependencies([update_op]):
          return self._v.assign(self._v.constraint(self._v))
      else:
        return update_op
    else:
      assert isinstance(g, ops.IndexedSlices), ("Gradient ", g, " is neither a "
                                                "tensor nor IndexedSlices.")
      if self._v.constraint is not None:
        raise RuntimeError(
            "Cannot use a constraint function on a sparse variable.")
      # pylint: disable=protected-access
      return optimizer._apply_sparse_duplicate_indices(g, self._v)


class _DenseReadResourceVariableProcessor(_OptimizableVariable):
  """Processor for dense ResourceVariables."""

  def __init__(self, v):
    self._v = v

  def target(self):
    return self._v

  def update_op(self, optimizer, g):
    # pylint: disable=protected-access
    update_op = optimizer._resource_apply_dense(g, self._v.op.inputs[0])
    if self._v.constraint is not None:
      with ops.control_dependencies([update_op]):
        return self._v.assign(self._v.constraint(self._v))
    else:
      return update_op


class _DenseResourceVariableProcessor(_OptimizableVariable):
  """Processor for dense ResourceVariables."""

  def __init__(self, v):
    self._v = v

  def target(self):
    return self._v

  def update_op(self, optimizer, g):
    # pylint: disable=protected-access
    if isinstance(g, ops.IndexedSlices):
      if self._v.constraint is not None:
        raise RuntimeError(
            "Cannot use a constraint function on a sparse variable.")
      return optimizer._resource_apply_sparse_duplicate_indices(
          g.values, self._v, g.indices)
    update_op = optimizer._resource_apply_dense(g, self._v)
    if self._v.constraint is not None:
      with ops.control_dependencies([update_op]):
        return self._v.assign(self._v.constraint(self._v))
    else:
      return update_op


class _TensorProcessor(_OptimizableVariable):
  """Processor for ordinary Tensors.

  Even though a Tensor can't really be updated, sometimes it is useful to
  compute the gradients with respect to a Tensor using the optimizer. Updating
  the Tensor is, of course, unsupported.
  """

  def __init__(self, v):
    self._v = v

  def target(self):
    return self._v

  def update_op(self, optimizer, g):
    raise NotImplementedError("Trying to update a Tensor ", self._v)


def _get_processor(v):
  """The processor of v."""
  if context.executing_eagerly():
    if isinstance(v, ops.Tensor):
      return _TensorProcessor(v)
    else:
      return _DenseResourceVariableProcessor(v)
  if isinstance(
      v, resource_variable_ops.ResourceVariable) and not v._in_graph_mode:  # pylint: disable=protected-access
    # True if and only if `v` was initialized eagerly.
    return _DenseResourceVariableProcessor(v)
  if v.op.type == "VarHandleOp":
    return _DenseResourceVariableProcessor(v)
  if isinstance(v, variables.Variable):
    return _RefVariableProcessor(v)
  if isinstance(v, ops.Tensor):
    return _TensorProcessor(v)
  raise NotImplementedError("Trying to optimize unsupported type ", v)


@tf_export("train.Optimizer")
class Optimizer(
    # Optimizers inherit from CheckpointableBase rather than Checkpointable
    # since they do most of their dependency management themselves (slot
    # variables are special-cased, and non-slot variables are keyed to graphs).
    checkpointable.CheckpointableBase):
  """Base class for optimizers.

  This class defines the API to add Ops to train a model. You never use this
  class directly, but instead instantiate one of its subclasses such as
  `GradientDescentOptimizer`, `AdagradOptimizer`, or `MomentumOptimizer`.

  ### Usage

  ```python
  # Create an optimizer with the desired parameters.
  opt = GradientDescentOptimizer(learning_rate=0.1)
  # Add Ops to the graph to minimize a cost by updating a list of variables.
  # "cost" is a Tensor, and the list of variables contains tf.Variable
  # objects.
  opt_op = opt.minimize(cost, var_list=<list of variables>)
  ```

  In the training program you will just have to run the returned Op.

  ```python
  # Execute opt_op to do one step of training:
  opt_op.run()
  ```

  ### Processing gradients before applying them.

  Calling `minimize()` takes care of both computing the gradients and
  applying them to the variables. If you want to process the gradients
  before applying them you can instead use the optimizer in three steps:

  1.  Compute the gradients with `compute_gradients()`.
  2.  Process the gradients as you wish.
  3.  Apply the processed gradients with `apply_gradients()`.

  Example:

  ```python
  # Create an optimizer.
  opt = GradientDescentOptimizer(learning_rate=0.1)

  # Compute the gradients for a list of variables.
  grads_and_vars = opt.compute_gradients(loss, <list of variables>)

  # grads_and_vars is a list of tuples (gradient, variable).  Do whatever you
  # need to the 'gradient' part, for example cap them, etc.
  capped_grads_and_vars = [(MyCapper(gv[0]), gv[1]) for gv in grads_and_vars]

  # Ask the optimizer to apply the capped gradients.
  opt.apply_gradients(capped_grads_and_vars)
  ```

  ### Gating Gradients

  Both `minimize()` and `compute_gradients()` accept a `gate_gradients`
  argument that controls the degree of parallelism during the application of
  the gradients.

  The possible values are: `GATE_NONE`, `GATE_OP`, and `GATE_GRAPH`.

  <b>`GATE_NONE`</b>: Compute and apply gradients in parallel. This provides
  the maximum parallelism in execution, at the cost of some non-reproducibility
  in the results. For example the two gradients of `matmul` depend on the input
  values: With `GATE_NONE` one of the gradients could be applied to one of the
  inputs _before_ the other gradient is computed resulting in non-reproducible
  results.

  <b>`GATE_OP`</b>: For each Op, make sure all gradients are computed before
  they are used. This prevents race conditions for Ops that generate gradients
  for multiple inputs where the gradients depend on the inputs.

  <b>`GATE_GRAPH`</b>: Make sure all gradients for all variables are computed
  before any one of them is used. This provides the least parallelism but can
  be useful if you want to process all gradients before applying any of them.
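
  For example, a sketch of requesting the most conservative gating (it assumes
  `loss` and the trainable variables already exist in the graph):

  ```python
  opt = GradientDescentOptimizer(learning_rate=0.1)
  grads_and_vars = opt.compute_gradients(
      loss, gate_gradients=GradientDescentOptimizer.GATE_GRAPH)
  train_op = opt.apply_gradients(grads_and_vars)
  ```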

  ### Slots

  Some optimizer subclasses, such as `MomentumOptimizer` and `AdagradOptimizer`
  allocate and manage additional variables associated with the variables to
  train. These are called <i>Slots</i>. Slots have names and you can ask the
  optimizer for the names of the slots that it uses. Once you have a slot name
  you can ask the optimizer for the variable it created to hold the slot value.

  This can be useful if you want to log or debug a training algorithm, report
  stats about the slots, etc.
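
  For example (a sketch; `loss` and `var` are assumed to already exist):

  ```python
  opt = MomentumOptimizer(learning_rate=0.1, momentum=0.9)
  train_op = opt.minimize(loss)
  # "momentum" appears in opt.get_slot_names() once the training op exists.
  momentum_slot = opt.get_slot(var, "momentum")
  ```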
  """

  # Values for gate_gradients.
  GATE_NONE = 0
  GATE_OP = 1
  GATE_GRAPH = 2

  def __init__(self, use_locking, name):
    """Create a new Optimizer.

    This must be called by the constructors of subclasses.

    Args:
      use_locking: Bool. If True, use locks to prevent concurrent updates
        to variables.
      name: A non-empty string. The name to use for accumulators created
        for the optimizer.

    Raises:
      ValueError: If name is malformed.
    """
    if not name:
      raise ValueError("Must specify the optimizer name")
    self._use_locking = use_locking
    self._name = name
    # Dictionary of slots.
    #  {slot_name :
    #      {_var_key(variable_to_train): slot_for_the_variable, ... },
    #   ... }
    self._slots = {}
    self._non_slot_dict = {}
    # For implementing Checkpointable. Stores information about how to restore
    # slot variables which have not yet been created
    # (checkpointable._CheckpointPosition objects).
    #  {slot_name :
    #      {_var_key(variable_to_train): [checkpoint_position, ... ], ... },
    #   ... }
    self._deferred_slot_restorations = {}

    # TODO(isaprykin): When using a DistributionStrategy, and when an
    # optimizer is created in each tower, it might be dangerous to
    # rely on some Optimizer methods.  When such methods are called on a
    # per-tower optimizer, an exception needs to be thrown.  We do
    # allow creation of per-tower optimizers, however, because the
    # compute_gradients()->apply_gradients() sequence is safe.

  def get_name(self):
    return self._name

  def minimize(self, loss, global_step=None, var_list=None,
               gate_gradients=GATE_OP, aggregation_method=None,
               colocate_gradients_with_ops=False, name=None,
               grad_loss=None):
    """Add operations to minimize `loss` by updating `var_list`.

    This method simply combines calls to `compute_gradients()` and
    `apply_gradients()`. If you want to process the gradients before applying
    them, call `compute_gradients()` and `apply_gradients()` explicitly instead
    of using this function.

    Args:
      loss: A `Tensor` containing the value to minimize.
      global_step: Optional `Variable` to increment by one after the
        variables have been updated.
      var_list: Optional list or tuple of `Variable` objects to update to
        minimize `loss`. Defaults to the list of variables collected in
        the graph under the key `GraphKeys.TRAINABLE_VARIABLES`.
      gate_gradients: How to gate the computation of gradients. Can be
        `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
      aggregation_method: Specifies the method used to combine gradient terms.
        Valid values are defined in the class `AggregationMethod`.
      colocate_gradients_with_ops: If True, try colocating gradients with
        the corresponding op.
      name: Optional name for the returned operation.
      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.

    Returns:
      An Operation that updates the variables in `var_list`. If `global_step`
      was not `None`, that operation also increments `global_step`.

    Raises:
      ValueError: If some of the variables are not `Variable` objects.

    @compatibility(eager)
    When eager execution is enabled, `loss` should be a Python function that
    takes elements of `var_list` as arguments and computes the value to be
    minimized. If `var_list` is None, `loss` should take no arguments.
    Minimization (and gradient computation) is done with respect to the
    elements of `var_list` if not None, else with respect to any trainable
    variables created during the execution of the `loss` function.
    `gate_gradients`, `aggregation_method`, `colocate_gradients_with_ops` and
    `grad_loss` are ignored when eager execution is enabled.
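
    For example, a sketch of eager-mode usage (it assumes a zero-argument
    callable `loss_fn` that closes over the model's trainable variables):

    ```python
    opt = GradientDescentOptimizer(learning_rate=0.1)
    opt.minimize(loss_fn)  # Computes and applies gradients in one step.
    ```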
    @end_compatibility
    """
    grads_and_vars = self.compute_gradients(
        loss, var_list=var_list, gate_gradients=gate_gradients,
        aggregation_method=aggregation_method,
        colocate_gradients_with_ops=colocate_gradients_with_ops,
        grad_loss=grad_loss)

    vars_with_grad = [v for g, v in grads_and_vars if g is not None]
    if not vars_with_grad:
      raise ValueError(
          "No gradients provided for any variable, check your graph for ops"
          " that do not support gradients, between variables %s and loss %s." %
          ([str(v) for _, v in grads_and_vars], loss))

    return self.apply_gradients(grads_and_vars, global_step=global_step,
                                name=name)

  def compute_gradients(self, loss, var_list=None,
                        gate_gradients=GATE_OP,
                        aggregation_method=None,
                        colocate_gradients_with_ops=False,
                        grad_loss=None):
    """Compute gradients of `loss` for the variables in `var_list`.

    This is the first part of `minimize()`.  It returns a list
    of (gradient, variable) pairs where "gradient" is the gradient
    for "variable".  Note that "gradient" can be a `Tensor`, an
    `IndexedSlices`, or `None` if there is no gradient for the
    given variable.

    Args:
      loss: A Tensor containing the value to minimize or a callable taking
        no arguments which returns the value to minimize. When eager execution
        is enabled it must be a callable.
      var_list: Optional list or tuple of `tf.Variable` to update to minimize
        `loss`.  Defaults to the list of variables collected in the graph
        under the key `GraphKeys.TRAINABLE_VARIABLES`.
      gate_gradients: How to gate the computation of gradients.  Can be
        `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
      aggregation_method: Specifies the method used to combine gradient terms.
        Valid values are defined in the class `AggregationMethod`.
      colocate_gradients_with_ops: If True, try colocating gradients with
        the corresponding op.
      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.

    Returns:
      A list of (gradient, variable) pairs. Variable is always present, but
      gradient can be `None`.

    Raises:
      TypeError: If `var_list` contains anything other than `Variable` objects.
      ValueError: If some arguments are invalid.
      RuntimeError: If called with eager execution enabled and `loss` is
        not callable.

    @compatibility(eager)
    When eager execution is enabled, `gate_gradients`, `aggregation_method`,
    and `colocate_gradients_with_ops` are ignored.
    @end_compatibility
    """
    if callable(loss):
      with backprop.GradientTape() as tape:
        if var_list is not None:
          tape.watch(var_list)
        loss_value = loss()

        # Scale loss if using a "mean" loss reduction and multiple towers.
        # Have to be careful to call distribute_lib.get_loss_reduction()
        # *after* loss() is evaluated, so we know what loss reduction it uses.
        # TODO(josh11b): Test that we handle weight decay in a reasonable way.
        if (distribute_lib.get_loss_reduction() ==
            variable_scope.VariableAggregation.MEAN):
          num_towers = distribute_lib.get_distribution_strategy().num_towers
          if num_towers > 1:
            loss_value *= (1. / num_towers)

      if var_list is None:
        var_list = tape.watched_variables()
      grads = tape.gradient(loss_value, var_list, grad_loss)
      return list(zip(grads, var_list))

    # Non-callable/Tensor loss case
    if context.executing_eagerly():
      raise RuntimeError(
          "`loss` passed to Optimizer.compute_gradients should "
          "be a function when eager execution is enabled.")

    # Scale loss if using a "mean" loss reduction and multiple towers.
    if (distribute_lib.get_loss_reduction() ==
        variable_scope.VariableAggregation.MEAN):
      num_towers = distribute_lib.get_distribution_strategy().num_towers
      if num_towers > 1:
        loss *= (1. / num_towers)

    if gate_gradients not in [Optimizer.GATE_NONE, Optimizer.GATE_OP,
                              Optimizer.GATE_GRAPH]:
      raise ValueError("gate_gradients must be one of: Optimizer.GATE_NONE, "
                       "Optimizer.GATE_OP, Optimizer.GATE_GRAPH.  Not %s" %
                       gate_gradients)
    self._assert_valid_dtypes([loss])
    if grad_loss is not None:
      self._assert_valid_dtypes([grad_loss])
    if var_list is None:
      var_list = (
          variables.trainable_variables() +
          ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
    else:
      var_list = nest.flatten(var_list)
    # pylint: disable=protected-access
    var_list += ops.get_collection(ops.GraphKeys._STREAMING_MODEL_PORTS)
    # pylint: enable=protected-access
    processors = [_get_processor(v) for v in var_list]
    if not var_list:
      raise ValueError("No variables to optimize.")
    var_refs = [p.target() for p in processors]
    grads = gradients.gradients(
        loss, var_refs, grad_ys=grad_loss,
        gate_gradients=(gate_gradients == Optimizer.GATE_OP),
        aggregation_method=aggregation_method,
        colocate_gradients_with_ops=colocate_gradients_with_ops)
    if gate_gradients == Optimizer.GATE_GRAPH:
      grads = control_flow_ops.tuple(grads)
    grads_and_vars = list(zip(grads, var_list))
    self._assert_valid_dtypes(
        [v for g, v in grads_and_vars
         if g is not None and v.dtype != dtypes.resource])
    return grads_and_vars

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply gradients to variables.

    This is the second part of `minimize()`. It returns an `Operation` that
    applies gradients.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()`.
      global_step: Optional `Variable` to increment by one after the
        variables have been updated.
      name: Optional name for the returned operation.  Defaults to the
        name passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that applies the specified gradients. If `global_step`
      was not None, that operation also increments `global_step`.

    Raises:
      TypeError: If `grads_and_vars` is malformed.
      ValueError: If none of the variables have gradients.
      RuntimeError: If you should use `_distributed_apply()` instead.
    """
    # This is a default implementation of apply_gradients() that can be shared
    # by most optimizers.  It relies on the subclass implementing the following
    # methods: _create_slots(), _prepare(), _apply_dense(), and _apply_sparse().

    # Handle DistributionStrategy case.
    if distribute_lib.get_cross_tower_context():
      raise RuntimeError("Use `_distributed_apply()` instead of "
                         "`apply_gradients()` in a cross-tower context.")
    # TODO(isaprykin): Get rid of `has_distribution_strategy()` check by
    # always calling _distributed_apply(), using the default distribution
    # as needed.
    if distribute_lib.has_distribution_strategy():
      grads_and_vars = get_filtered_grad_fn(lambda _: grads_and_vars)()
      return distribute_lib.get_tower_context().merge_call(
          self._distributed_apply, grads_and_vars, global_step, name)

    # No DistributionStrategy case.
    grads_and_vars = tuple(grads_and_vars)  # Make sure repeat iteration works.
    if not grads_and_vars:
      raise ValueError("No variables provided.")
    converted_grads_and_vars = []
    for g, v in grads_and_vars:
      if g is not None:
        try:
          # Convert the grad to Tensor or IndexedSlices if necessary.
          g = ops.convert_to_tensor_or_indexed_slices(g)
        except TypeError:
          raise TypeError(
              "Gradient must be convertible to a Tensor"
              " or IndexedSlices, or None: %s" % g)
        if not isinstance(g, (ops.Tensor, ops.IndexedSlices)):
          raise TypeError(
              "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
      p = _get_processor(v)
      converted_grads_and_vars.append((g, v, p))

    converted_grads_and_vars = tuple(converted_grads_and_vars)
    var_list = [v for g, v, _ in converted_grads_and_vars if g is not None]
    if not var_list:
      raise ValueError("No gradients provided for any variable: %s." %
                       ([str(v) for _, _, v in converted_grads_and_vars],))
    with ops.init_scope():
      self._create_slots(var_list)
    update_ops = []
    with ops.name_scope(name, self._name) as name:
      self._prepare()
      for grad, var, processor in converted_grads_and_vars:
        if grad is None:
          continue
        # We colocate all ops created in _apply_dense or _apply_sparse
        # on the same device as the variable.
        # TODO(apassos): figure out how to get the variable name here.
        if context.executing_eagerly() or isinstance(
            var,
            resource_variable_ops.ResourceVariable) and not var._in_graph_mode:  # pylint: disable=protected-access
          scope_name = ""
        else:
          scope_name = var.op.name
        with ops.name_scope("update_" + scope_name), ops.colocate_with(var):
          update_ops.append(processor.update_op(self, grad))
      if global_step is None:
        apply_updates = self._finish(update_ops, name)
      else:
        with ops.control_dependencies([self._finish(update_ops, "update")]):
          with ops.colocate_with(global_step):
            if isinstance(global_step, resource_variable_ops.ResourceVariable):
              # TODO(apassos): the implicit read in assign_add is slow; consider
              # making it less so.
              apply_updates = resource_variable_ops.assign_add_variable_op(
                  global_step.handle,
                  ops.convert_to_tensor(1, dtype=global_step.dtype),
                  name=name)
            else:
              apply_updates = state_ops.assign_add(global_step, 1, name=name)

      if not context.executing_eagerly():
        if isinstance(apply_updates, ops.Tensor):
          apply_updates = apply_updates.op
        train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
        if apply_updates not in train_op:
          train_op.append(apply_updates)

      return apply_updates

  def _distributed_apply(self,
                         distribution,
                         grads_and_vars,
                         global_step=None,
                         name=None):
    """A version of `apply_gradients` for cross-tower context.

    This is a version of `apply_gradients()` for when you are using a
    `DistributionStrategy` and are in a cross-tower context. If in a
    tower context, use `apply_gradients()` as normal.

    Args:
      distribution: A `DistributionStrategy` object.
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()`, and then aggregated across towers.
      global_step: Optional (mirrored) `Variable` to increment by one
        after the variables have been updated.
      name: Optional name for the returned operation.  Defaults to the
        name passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that applies the specified gradients across all
      towers. If `global_step` was not None, that operation also
      increments `global_step`.
    """
    reduced_grads = distribution.batch_reduce(
        variable_scope.VariableAggregation.SUM, grads_and_vars)
    var_list = [v for _, v in grads_and_vars]
    grads_and_vars = zip(reduced_grads, var_list)
    # Note that this is called in a cross-tower context.
    self._create_slots(var_list)

    def update(v, g):
      """Apply gradients to a replica variable."""
      assert v is not None

      try:
        # Convert the grad to Tensor or IndexedSlices if necessary.
        g = ops.convert_to_tensor_or_indexed_slices(g)
      except TypeError:
        raise TypeError("Gradient must be convertible to a Tensor"
                        " or IndexedSlices, or None: %s" % g)
      if not isinstance(g, (ops.Tensor, ops.IndexedSlices)):
        raise TypeError(
            "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
      p = _get_processor(v)

      scope_name = "" if context.executing_eagerly() else v.op.name
      # device_policy is set because non-mirrored tensors will be read in
      # `update_op`. For example, `_resource_apply_dense` reads values such as
      # `lr_t`, `beta1_t` and `beta2_t`.
      with ops.name_scope("update_" + scope_name):
        return p.update_op(self, g)

    with ops.name_scope(name, self._name) as name:
      self._prepare()

      update_ops = [
          op
          for grad, var in grads_and_vars
          for op in distribution.unwrap(distribution.update(var, update, grad))
      ]

      def finish(self, update_ops):
        return self._finish(update_ops, "update")

      non_slot_devices = distribution.non_slot_devices(var_list)
      finish_updates = distribution.update_non_slot(
          non_slot_devices, finish, self, update_ops)
      if global_step is None:
        apply_updates = distribution.group(finish_updates, name=name)
      else:
        with ops.control_dependencies(distribution.unwrap(finish_updates)):
          apply_updates = distribution.group(distribution.update(
              global_step, state_ops.assign_add, 1, name=name))

      if not context.executing_eagerly():
        if isinstance(apply_updates, ops.Tensor):
          apply_updates = apply_updates.op
        train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
        if apply_updates not in train_op:
          train_op.append(apply_updates)

      return apply_updates

  def get_slot(self, var, name):
    """Return a slot named `name` created for `var` by the Optimizer.

    Some `Optimizer` subclasses use additional variables.  For example
    `Momentum` and `Adagrad` use variables to accumulate updates.  This method
    gives access to these `Variable` objects if for some reason you need them.

    Use `get_slot_names()` to get the list of slot names created by the
    `Optimizer`.

    Args:
      var: A variable passed to `minimize()` or `apply_gradients()`.
      name: A string.

    Returns:
      The `Variable` for the slot if it was created, `None` otherwise.
    """
    # pylint: disable=protected-access
    named_slots = self._slots.get(name, None)
    if not named_slots:
      return None

    if hasattr(var, "_distributed_container"):
      # NOTE: If this isn't patched, then there is no `handle` in
      # `_resource_apply_dense`.
      distributed_container = var._distributed_container()
      assert distributed_container is not None
      if context.executing_eagerly():
        key = distributed_container._unique_id
      else:
        key = (distributed_container.graph, distributed_container._shared_name)
      # pylint: enable=protected-access
      mirrored_slot = named_slots.get(key, None)
      if mirrored_slot is None: return None
      return mirrored_slot.get(device=var.device)

    return named_slots.get(_var_key(var), None)

  def get_slot_names(self):
    """Return a list of the names of slots created by the `Optimizer`.

    See `get_slot()`.

    Returns:
      A list of strings.
    """
    return sorted(self._slots.keys())

  def variables(self):
    """A list of variables which encode the current state of `Optimizer`.

    Includes slot variables and additional global variables created by the
    optimizer in the current default graph.
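
    For example, the returned list can be used to initialize or reset the
    optimizer's state (a sketch; `loss` is assumed to exist):

    ```python
    opt = AdagradOptimizer(learning_rate=0.1)
    train_op = opt.minimize(loss)
    reset_opt_op = tf.variables_initializer(opt.variables())
    ```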

    Returns:
      A list of variables.
    """
    executing_eagerly = context.executing_eagerly()
    current_graph = ops.get_default_graph()

    def _from_current_graph(variable):
      if executing_eagerly:
        # No variable.op in eager mode. We don't expect lots of eager graphs,
        # but behavior should be consistent with graph mode.
        return variable._graph_key == current_graph._graph_key  # pylint: disable=protected-access
      else:
        return variable.op.graph is current_graph

    optimizer_variables = [v for v in self._non_slot_variables()
                           if _from_current_graph(v)]
    for _, variable_dict in self._slots.items():
      for _, slot_for_variable in variable_dict.items():
        if _from_current_graph(slot_for_variable):
          optimizer_variables.append(slot_for_variable)
    # Sort variables by name so that the return is deterministic.
    return sorted(optimizer_variables, key=lambda v: v.name)

  def _create_non_slot_variable(self, initial_value, name, colocate_with):
    """Add an extra variable, not associated with a slot."""
    # Recommendation: Use OptimizerV2 if your optimizer uses non-slot variables.
    eager = context.executing_eagerly()
    graph = None if eager else colocate_with.graph

    key = (name, graph)
    v = self._non_slot_dict.get(key, None)
    if v is None:
      self._maybe_initialize_checkpointable()
      distribution_strategy = distribute_lib.get_distribution_strategy()
      with distribution_strategy.colocate_vars_with(colocate_with):
        if eager:
          restored_initial_value = self._preload_simple_restoration(
              name=name, shape=None)
          if restored_initial_value is not None:
            initial_value = restored_initial_value
        v = variable_scope.variable(initial_value, name=name, trainable=False)
      # Restore this variable by name if necessary, but don't add a
      # Checkpointable dependency. Optimizers return the current graph's
      # non-slot variables from _checkpoint_dependencies explicitly rather
      # than unconditionally adding dependencies (since there may be multiple
      # non-slot variables with the same name in different graphs, trying to
      # save all of them would result in errors).
      self._handle_deferred_dependencies(name=name, checkpointable=v)
      self._non_slot_dict[key] = v

    return v

  @property
  def _checkpoint_dependencies(self):
    """From Checkpointable. Gather graph-specific non-slot variables to save."""
    current_graph_non_slot_variables = []
    current_graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
    for (name, _), variable_object in sorted(self._non_slot_dict.items(),
                                             # Avoid comparing graphs
                                             key=lambda item: item[0][0]):
      if variable_object._graph_key == current_graph_key:  # pylint: disable=protected-access
        current_graph_non_slot_variables.append(
            checkpointable.CheckpointableReference(
                name=name, ref=variable_object))
    return (super(Optimizer, self)._checkpoint_dependencies
            + current_graph_non_slot_variables)

  def _lookup_dependency(self, name):
    """From Checkpointable. Find a non-slot variable in the current graph."""
    unconditional = super(Optimizer, self)._lookup_dependency(name)
    if unconditional is not None:
      return unconditional
    graph = None if context.executing_eagerly() else ops.get_default_graph()
    return self._get_non_slot_variable(name, graph=graph)

  def _get_non_slot_variable(self, name, graph=None):
    non_slot = self._non_slot_dict.get((name, graph), None)
    if hasattr(non_slot, "_distributed_container"):
      # This is a mirrored non-slot.  In order to enable code like `_finish`
      # to assign to a non-slot, return the current context replica.
      return non_slot.get()
    else:
      return non_slot

  def _non_slot_variables(self):
    """Additional variables created by the `Optimizer`.

    Returns:
      A list or tuple of variables.
    """
    return self._non_slot_dict.values()

  def _assert_valid_dtypes(self, tensors):
    """Asserts tensors are all valid types (see `_valid_dtypes`).

    Args:
      tensors: Tensors to check.

    Raises:
      ValueError: If any tensor is not a valid type.
    """
    valid_dtypes = self._valid_dtypes()
    for t in tensors:
      dtype = t.dtype.base_dtype
      if dtype not in valid_dtypes:
        raise ValueError(
            "Invalid type %r for %s, expected: %s." % (
                dtype, t.name, [v for v in valid_dtypes]))

  # --------------
  # Methods to be implemented by subclasses if they want to use the
  # inherited implementation of apply_gradients() or compute_gradients().
  # --------------
  def _valid_dtypes(self):
    """Valid types for loss, variables and gradients.

    Subclasses should override to allow other float types.
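    For example, a subclass that also wants to accept `complex64` values could
    return `super(MyOptimizer, self)._valid_dtypes() | set([dtypes.complex64])`
    (where `MyOptimizer` stands in for the subclass name).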

    Returns:
      Valid types for loss, variables and gradients.
    """
    return set(
        [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64])

  def _create_slots(self, var_list):
    """Create all slots needed by the variables.

    Args:
      var_list: A list of `Variable` objects.
    """
    # No slots needed by default
    pass

  def _prepare(self):
    """Create all needed tensors before applying gradients.

    This is called with the name_scope using the "name" that
    users have chosen for the application of gradients.
    """
    pass

  def _apply_dense(self, grad, var):
    """Add ops to apply dense gradients to `var`.

    Args:
      grad: A `Tensor`.
      var: A `Variable` object.

    Returns:
      An `Operation`.
    """
    raise NotImplementedError()

  def _resource_apply_dense(self, grad, handle):
    """Add ops to apply dense gradients to the variable `handle`.

    Args:
      grad: a `Tensor` representing the gradient.
      handle: a `Tensor` of dtype `resource` which points to the variable
        to be updated.

    Returns:
      An `Operation` which updates the value of the variable.
    """
    raise NotImplementedError()

  def _resource_apply_sparse_duplicate_indices(self, grad, handle, indices):
    """Add ops to apply sparse gradients to `handle`, with repeated indices.

    Optimizers which override this method must deal with repeated indices. See
    the docstring of `_apply_sparse_duplicate_indices` for details. By default
    the correct behavior, to sum non-unique indices and their associated
    gradients, is enforced by first pre-processing `grad` and `indices` and
    passing them on to `_resource_apply_sparse`. Optimizers which deal correctly
    with duplicate indices may instead override this method to avoid the
    overhead of summing.

    Args:
      grad: a `Tensor` representing the gradient for the affected indices.
      handle: a `Tensor` of dtype `resource` which points to the variable
        to be updated.
      indices: a `Tensor` of integral type representing the indices for
        which the gradient is nonzero. Indices may be repeated.

    Returns:
      An `Operation` which updates the value of the variable.
    """
    summed_grad, unique_indices = _deduplicate_indexed_slices(
        values=grad, indices=indices)
    return self._resource_apply_sparse(summed_grad, handle, unique_indices)

  def _resource_apply_sparse(self, grad, handle, indices):
    """Add ops to apply sparse gradients to the variable `handle`.

    Similar to `_apply_sparse`, the `indices` argument to this method has been
    de-duplicated. Optimizers which deal correctly with non-unique indices may
    instead override `_resource_apply_sparse_duplicate_indices` to avoid this
    overhead.

    Args:
      grad: a `Tensor` representing the gradient for the affected indices.
      handle: a `Tensor` of dtype `resource` which points to the variable
        to be updated.
      indices: a `Tensor` of integral type representing the indices for
        which the gradient is nonzero. Indices are unique.

    Returns:
      An `Operation` which updates the value of the variable.
    """
    raise NotImplementedError()

  def _apply_sparse_duplicate_indices(self, grad, var):
    """Add ops to apply sparse gradients to `var`, with repeated sparse indices.

    Optimizers which override this method must deal with IndexedSlices objects
    such as the following:

      IndexedSlicesValue(values=[1, 1], indices=[0, 0], dense_shape=[1])

    The correct interpretation is:

      IndexedSlicesValue(values=[2], indices=[0], dense_shape=[1])

    Many optimizers deal incorrectly with repeated indices when updating based
    on sparse gradients (e.g. summing squares rather than squaring the sum, or
    applying momentum terms multiple times). Adding first is always the correct
    behavior, so this is enforced here by reconstructing the IndexedSlices to
    have only unique indices, then calling _apply_sparse.

    Optimizers which deal correctly with repeated indices may instead override
    this method to avoid the overhead of summing indices.

    Args:
      grad: `IndexedSlices`.
      var: A `Variable` object.

    Returns:
      An `Operation`.
    """
    summed_values, unique_indices = _deduplicate_indexed_slices(
        values=grad.values, indices=grad.indices)
    gradient_no_duplicate_indices = ops.IndexedSlices(
        indices=unique_indices,
        values=summed_values,
        dense_shape=grad.dense_shape)
    return self._apply_sparse(gradient_no_duplicate_indices, var)

  def _apply_sparse(self, grad, var):
    """Add ops to apply sparse gradients to `var`.

    The IndexedSlices object passed to `grad` in this function is by default
    pre-processed in `_apply_sparse_duplicate_indices` to remove duplicate
    indices (see its docstring for details). Optimizers which can tolerate or
    have correct special cases for duplicate sparse indices may override
    `_apply_sparse_duplicate_indices` instead of this function, avoiding that
    overhead.

    Args:
      grad: `IndexedSlices`, with no repeated indices.
      var: A `Variable` object.

    Returns:
      An `Operation`.
    """
    raise NotImplementedError()

  def _finish(self, update_ops, name_scope):
    """Do what is needed to finish the update.

    This is called with the `name_scope` using the "name" that
    users have chosen for the application of gradients.

    Args:
      update_ops: List of `Operation` objects to update variables. This list
        contains the values returned by the `_apply_dense()` and
        `_apply_sparse()` calls.
      name_scope: String.  Name to use for the returned operation.

    Returns:
      The operation to apply updates.
    """
    return control_flow_ops.group(*update_ops, name=name_scope)

  # --------------
  # Utility methods for subclasses.
  # --------------

  def _slot_dict(self, slot_name):
    """Returns a dict for caching slots created under the given name.

    Args:
      slot_name: Name for the slot.

    Returns:
      A dict that maps primary `Variable` objects to the slot created
      for that variable, under the given slot name.
    """
    named_slots = self._slots.get(slot_name, None)
    if named_slots is None:
      named_slots = {}
      self._slots[slot_name] = named_slots
    return named_slots

  def _get_or_make_slot(self, var, val, slot_name, op_name):
    """Find or create a slot for a variable.

    Args:
      var: A `Variable` object.
      val: A `Tensor`.  The initial value of the slot.
      slot_name: Name for the slot.
      op_name: Name to use when scoping the Variable that
        needs to be created for the slot.

    Returns:
      A `Variable` object.
    """
    named_slots = self._slot_dict(slot_name)
    if _var_key(var) not in named_slots:
      new_slot_variable = slot_creator.create_slot(var, val, op_name)
      self._restore_slot_variable(
          slot_name=slot_name, variable=var,
          slot_variable=new_slot_variable)
      named_slots[_var_key(var)] = new_slot_variable
    return named_slots[_var_key(var)]

  def _get_or_make_slot_with_initializer(self, var, initializer, shape, dtype,
                                         slot_name, op_name):
    """Find or create a slot for a variable, using an Initializer.

    Args:
      var: A `Variable` object.
      initializer: An `Initializer`.  The initial value of the slot.
      shape: Shape of the initial value of the slot.
      dtype: Type of the value of the slot.
      slot_name: Name for the slot.
      op_name: Name to use when scoping the Variable that
        needs to be created for the slot.

    Returns:
      A `Variable` object.
    """
    named_slots = self._slot_dict(slot_name)
    if _var_key(var) not in named_slots:
      new_slot_variable = slot_creator.create_slot_with_initializer(
          var, initializer, shape, dtype, op_name)
      self._restore_slot_variable(
          slot_name=slot_name, variable=var,
          slot_variable=new_slot_variable)
      named_slots[_var_key(var)] = new_slot_variable
    return named_slots[_var_key(var)]

  def _zeros_slot(self, var, slot_name, op_name):
    """Find or create a slot initialized with 0.0.

    Args:
      var: A `Variable` object.
      slot_name: Name for the slot.
      op_name: Name to use when scoping the Variable that
        needs to be created for the slot.

    Returns:
      A `Variable` object.
    """
    named_slots = self._slot_dict(slot_name)
    if _var_key(var) not in named_slots:
      new_slot_variable = slot_creator.create_zeros_slot(var, op_name)
      self._restore_slot_variable(
          slot_name=slot_name, variable=var,
          slot_variable=new_slot_variable)
      named_slots[_var_key(var)] = new_slot_variable
    return named_slots[_var_key(var)]

  # --------------
  # For implementing the Checkpointable interface.
  # --------------

  def _restore_slot_variable(self, slot_name, variable, slot_variable):
    """Restore a newly created slot variable's value."""
    variable_key = _var_key(variable)
    deferred_restorations = self._deferred_slot_restorations.get(
        slot_name, {}).pop(variable_key, [])
    # Iterate over restores, highest restore UID first to minimize the number
    # of assignments.
    deferred_restorations.sort(key=lambda position: position.restore_uid,
                               reverse=True)
    for checkpoint_position in deferred_restorations:
      checkpoint_position.restore(slot_variable)

  def _create_or_restore_slot_variable(
      self, slot_variable_position, slot_name, variable):
    """Restore a slot variable's value, possibly creating it.

    Called when a variable which has an associated slot variable is created or
    restored. When executing eagerly, we create the slot variable with a
    restoring initializer.

    No new variables are created when graph building. Instead,
    _restore_slot_variable catches these after normal creation and adds restore
    ops to the graph. This method is nonetheless important when graph building
    for the case when a slot variable has already been created but `variable`
    has just been added to a dependency graph (causing us to realize that the
    slot variable needs to be restored).

    Args:
      slot_variable_position: A `checkpointable._CheckpointPosition` object
        indicating the slot variable `Checkpointable` object to be restored.
      slot_name: The name of this `Optimizer`'s slot to restore into.
      variable: The variable object this slot is being created for.
    """
    named_slots = self._slot_dict(slot_name)
    variable_key = _var_key(variable)
    slot_variable = named_slots.get(variable_key, None)
    if (slot_variable is None and context.executing_eagerly() and
        slot_variable_position.is_simple_variable()
        # Defer slot variable creation if there is an active variable creator
        # scope. Generally we'd like to eagerly create/restore slot variables
        # when possible, but this may mean that scopes intended to catch
        # `variable` also catch its eagerly created slot variable
        # unintentionally (specifically make_template would add a dependency on
        # a slot variable if not for this case). Deferring is mostly harmless
        # (aside from double initialization), and makes variable creator scopes
        # behave the same way they do when graph building.
        and not ops.get_default_graph()._variable_creator_stack):  # pylint: disable=protected-access
      initializer = checkpointable.CheckpointInitialValue(
          checkpoint_position=slot_variable_position)
      slot_variable = self._get_or_make_slot(
          var=variable,
          val=initializer,
          slot_name=slot_name,
          op_name=self._name)
    # Slot variables are not owned by any one object (because we don't want to
    # save the slot variable if the optimizer is saved without the non-slot
    # variable, or if the non-slot variable is saved without the optimizer;
    # it's a dependency hypergraph with edges of the form (optimizer, non-slot
    # variable, variable)). So we don't _track_ slot variables anywhere, and
    # instead special-case this dependency and otherwise pretend it's a normal
    # graph.
    if slot_variable is not None:
      # If we've either made this slot variable, or if we've pulled out an
      # existing slot variable, we should restore it.
      slot_variable_position.restore(slot_variable)
    else:
      # We didn't make the slot variable. Defer restoring until it gets created
      # normally. We keep a list rather than the one with the highest restore
      # UID in case slot variables have their own dependencies, in which case
      # those could differ between restores.
      self._deferred_slot_restorations.setdefault(
          slot_name, {}).setdefault(variable_key, []).append(
              slot_variable_position)

  def _call_if_callable(self, param):
    """Call the function if param is callable."""
    return param() if callable(param) else param
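

# Example (an illustrative sketch, not part of this module): a minimal
# subclass only needs to implement the apply methods used by the inherited
# `minimize()`/`apply_gradients()` machinery. The class below performs plain
# gradient descent; its name is hypothetical and it is kept in a comment so
# that importing this module is unaffected.
#
#   class _ExampleGradientDescent(Optimizer):
#     """Sketch of a gradient-descent optimizer built on this base class."""
#
#     def __init__(self, learning_rate, use_locking=False, name="ExampleGD"):
#       super(_ExampleGradientDescent, self).__init__(use_locking, name)
#       self._learning_rate = learning_rate
#
#     def _apply_dense(self, grad, var):
#       # var <- var - learning_rate * grad
#       return state_ops.assign_sub(var, self._learning_rate * grad,
#                                   use_locking=self._use_locking)
#
#     def _resource_apply_dense(self, grad, handle):
#       # Same update for resource variables, addressed through their handle.
#       return resource_variable_ops.assign_sub_variable_op(
#           handle, self._learning_rate * grad)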