laywerrobot/lib/python3.6/site-packages/tensorflow/contrib/eager/python/saver.py

182 lines
6.5 KiB
Python
Raw Normal View History

2020-08-27 21:55:39 +02:00
"""Saver for eager mode TensorFlow."""
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.training import checkpoint_utils
from tensorflow.python.training import saver as _saver
def _init_from_checkpoint(self, *args, **kwargs):
  """Runs the saved original initializer, then loads the checkpoint value.

  Intended to be installed on a variable class in place of its normal
  initializer. It expects the instance to expose `_old_init`, `_map_func`,
  `_ckpt_var_cache`, `_shared_name` and `assign`.

  Args:
    *args: Positional arguments forwarded unchanged to `self._old_init`.
    **kwargs: Keyword arguments forwarded unchanged to `self._old_init`.

  Raises:
    NotFoundError: If the mapped variable name is not a key of the checkpoint
      cache. Note that `self._old_init` has already run at that point.
  """
  # pylint: disable=protected-access
  self._old_init(*args, **kwargs)

  mapped_name = self._map_func(self._shared_name)
  cache = self._ckpt_var_cache
  if mapped_name not in cache:
    raise errors.NotFoundError(None, None,
                               "%s not found in checkpoint" % mapped_name)

  cached_value = cache[mapped_name]
  if cached_value is None:
    # The entry was already consumed by an earlier initialization.
    return

  self.assign(cached_value)
  # Blank the entry so the same checkpoint value is never assigned twice.
  cache[mapped_name] = None
  # pylint: enable=protected-access
@contextlib.contextmanager
def restore_variables_on_create(save_path, map_func=None):
  """ContextManager that restores variables on creation.

  When save_path is None (e.g. No checkpoint), does nothing.
  Otherwise, it preloads all values from checkpoint. When the
  corresponding variable is first created, it assigns the checkpoint
  value to the variable.

  ```python
  with restore_variables_on_create(
      tf.train.latest_checkpoint(checkpoint_dir)):
    # Create the variables to be restored inside this block.
  ```

  While the context is active, `ResourceVariable._init_from_args` is replaced
  by `_init_from_checkpoint`; the original method is put back on exit, whether
  or not the body raised.

  Args:
    save_path: The checkpoint file prefix. Any falsy value (None, "") disables
      restoring.
    map_func: A function that given the variable name as argument
      and returns a variable name in checkpoint for restore. If
      None, use the variable with the same name in checkpoint to restore.
      It's an error that the mapped variable name doesn't exist in
      checkpoint.

  Yields:
    Nothing.

  Raises:
    NotFoundError: If the variable is not found in checkpoint.
    ValueError: If not used in eager mode, or if `save_path` is set and
      `map_func` is not callable.
    AssertionError: If `ResourceVariable` has no `_init_from_args` method.
  """
  if not context.executing_eagerly():
    raise ValueError(
        "Currently, restore_variables_on_create can only be used with "
        "eager execution enabled.")

  variable_cls = resource_variable_ops.ResourceVariable

  if save_path:
    # The wrappers take `self` because they are stored as class attributes
    # and therefore invoked as bound methods (`self._map_func(name)`).
    if map_func is None:
      map_func_wrapper = lambda self, x: x
    else:
      if not callable(map_func):
        raise ValueError("map_func must be callable.")
      map_func_wrapper = lambda self, x: map_func(x)

    # Preload every tensor stored in the checkpoint, keyed by its name.
    reader = checkpoint_utils.load_checkpoint(save_path)
    ckpt_var_cache = {
        name: reader.get_tensor(name)
        for name, _ in checkpoint_utils.list_variables(save_path)
    }

    old_init = getattr(variable_cls, "_init_from_args", None)
    # Raised explicitly (not via `assert`) so the guard also holds under -O.
    if not old_init:
      raise AssertionError("ResourceVariable misses _init_from_args method.")

    # pylint: disable=protected-access
    variable_cls._init_from_args = _init_from_checkpoint
    variable_cls._old_init = old_init
    variable_cls._map_func = map_func_wrapper
    variable_cls._ckpt_var_cache = ckpt_var_cache
    # pylint: enable=protected-access

  # A plain try/finally lets any exception from the body propagate with its
  # original traceback while still undoing the patch.
  try:
    yield
  finally:
    if save_path:
      # pylint: disable=protected-access
      variable_cls._init_from_args = old_init
      variable_cls._old_init = None
      variable_cls._map_func = None
      variable_cls._ckpt_var_cache = None
      # pylint: enable=protected-access
class Saver(object):
  """Eager-execution wrapper around `tf.train.Saver`.

  Holds an underlying saver built from the supplied variables and forwards
  `save` / `restore` to it without a Session.
  """

  def __init__(self, var_list):
    """Builds the wrapped saver for `var_list`.

    The API, and on-disk format, mimic tf.train.Saver except that no
    Session is needed.

    Args:
      var_list: The variables to save and restore: either a list of
        `tfe.Variable` objects, or a dictionary mapping names to
        `tfe.Variable` objects.

    Raises:
      RuntimeError: if invoked when eager execution has not been enabled.
    """
    eager_enabled = context.executing_eagerly()
    if not eager_enabled:
      raise RuntimeError("tfe.Saver can only be used when eager "
                         "execution is enabled. Use tf.train.Saver when "
                         "building graphs.")
    self._saver = _saver.Saver(var_list=var_list)

  def save(self, file_prefix, global_step=None):
    """Writes the variables to a checkpoint.

    Args:
      file_prefix: Path prefix of files created for the checkpoint.
      global_step: Optional. When given, the global step number is appended to
        file_prefix to form the checkpoint filename. May be a Tensor, a
        Variable, or an integer.

    Returns:
      A string: prefix of filenames created for the checkpoint. This may be
      an extension of file_prefix that is suitable to pass as an argument
      to a subsequent call to `restore()`.
    """
    # The wrapped saver is called under a CPU device scope; None is passed as
    # its first argument because no Session is used here.
    with ops.device("/device:CPU:0"):
      saved_prefix = self._saver.save(
          None, file_prefix, global_step=global_step, write_meta_graph=False)
    return saved_prefix

  def restore(self, file_prefix):
    """Loads variable values from a previously written checkpoint.

    Args:
      file_prefix: Path prefix where parameters were previously saved.
        Typically obtained from a previous `save()` call, or from
        `tf.train.latest_checkpoint`.
    """
    cpu_scope = ops.device("/device:CPU:0")
    with cpu_scope:
      self._saver.restore(None, file_prefix)
def get_optimizer_variables(optimizer):
  """Fetches the variables held by a `tf.train.Optimizer`.

  This is a thin alias for calling `optimizer.variables()` directly.

  Args:
    optimizer: An instance of `tf.train.Optimizer` which has created variables
      (typically after a call to `Optimizer.minimize`).

  Returns:
    Whatever `optimizer.variables()` returns: a list of the variables created
    by the `Optimizer`.
  """
  optimizer_vars = optimizer.variables()
  return optimizer_vars