"""Saver for eager mode TensorFlow.""" # Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== from __future__ import absolute_import from __future__ import division from __future__ import print_function import contextlib from tensorflow.python.eager import context from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.training import checkpoint_utils from tensorflow.python.training import saver as _saver def _init_from_checkpoint(self, *args, **kwargs): """Overrides default init by loading value from checkpoint.""" # pylint: disable=protected-access self._old_init(*args, **kwargs) ckpt_name = self._map_func(self._shared_name) if ckpt_name not in self._ckpt_var_cache: raise errors.NotFoundError(None, None, "%s not found in checkpoint" % ckpt_name) val = self._ckpt_var_cache.get(ckpt_name, None) if val is not None: self.assign(val) # Avoid assigning for the second time. self._ckpt_var_cache[ckpt_name] = None # pylint: enable=protected-access @contextlib.contextmanager def restore_variables_on_create(save_path, map_func=None): """ContextManager that restores variables on creation. When save_path is None (e.g. No checkpoint), does nothing. Otherwise, it preloads all values from checkpoint. When the corresponding variable is first created, it assigns the checkpoint value to the variable. ```python with restore_variables_on_create( tf.train.latest_checkpoint(checkpoint_dir)): ``` Args: save_path: The checkpoint file prefix. map_func: A function that given the variable name as argument and returns a variable name in checkpoint for restore. If None, use the variable with the same name in checkpoint to restore. It's an error that the mapped variable name doesn't exist in checkpoint. Yields: Nothing. Raises: NotFoundError: If the variable is not found in checkpoint. ValueError: If not used in eager mode or map_func is not callable. """ if not context.executing_eagerly(): raise ValueError( "Currently, restore_variables_on_create can only be used with " "eager execution enabled.") if save_path: if map_func is None: map_func_wrapper = lambda self, x: x else: if not callable(map_func): raise ValueError("map_func must be callable.") map_func_wrapper = lambda self, x: map_func(x) ckpt_var_cache = dict() reader = checkpoint_utils.load_checkpoint(save_path) for k, _ in checkpoint_utils.list_variables(save_path): ckpt_var_cache[k] = reader.get_tensor(k) old_init = getattr(resource_variable_ops.ResourceVariable, "_init_from_args", None) assert old_init, "ResourceVariable misses _init_from_args method." setattr(resource_variable_ops.ResourceVariable, "_init_from_args", _init_from_checkpoint) setattr(resource_variable_ops.ResourceVariable, "_old_init", old_init) setattr(resource_variable_ops.ResourceVariable, "_map_func", map_func_wrapper) setattr(resource_variable_ops.ResourceVariable, "_ckpt_var_cache", ckpt_var_cache) try: yield except Exception as e: raise e finally: if save_path: setattr(resource_variable_ops.ResourceVariable, "_init_from_args", old_init) setattr(resource_variable_ops.ResourceVariable, "_old_init", None) setattr(resource_variable_ops.ResourceVariable, "_map_func", None) setattr(resource_variable_ops.ResourceVariable, "_ckpt_var_cache", None) class Saver(object): """A tf.train.Saver adapter for use when eager execution is enabled. """ def __init__(self, var_list): """A tf.train.Saver adapter for use when eager execution is enabled. The API, and on-disk format, mimic tf.train.Saver except that no Session is needed. Args: var_list: The list of variables that will be saved and restored. Either a list of `tfe.Variable` objects, or a dictionary mapping names to `tfe.Variable` objects. Raises: RuntimeError: if invoked when eager execution has not been enabled. """ if not context.executing_eagerly(): raise RuntimeError("tfe.Saver can only be used when eager " "execution is enabled. Use tf.train.Saver when " "building graphs.") self._saver = _saver.Saver(var_list=var_list) def save(self, file_prefix, global_step=None): """Saves variables. Args: file_prefix: Path prefix of files created for the checkpoint. global_step: If provided the global step number is appended to file_prefix to create the checkpoint filename. The optional argument can be a Tensor, a Variable, or an integer. Returns: A string: prefix of filenames created for the checkpoint. This may be an extension of file_prefix that is suitable to pass as an argument to a subsequent call to `restore()`. """ with ops.device("/device:CPU:0"): return self._saver.save( None, file_prefix, write_meta_graph=False, global_step=global_step) def restore(self, file_prefix): """Restores previously saved variables. Args: file_prefix: Path prefix where parameters were previously saved. Typically obtained from a previous `save()` call, or from @{tf.train.latest_checkpoint}. """ with ops.device("/device:CPU:0"): self._saver.restore(None, file_prefix) def get_optimizer_variables(optimizer): """Returns a list of variables for the given `tf.train.Optimizer`. Equivalent to `optimizer.variables()`. Args: optimizer: An instance of `tf.train.Optimizer` which has created variables (typically after a call to `Optimizer.minimize`). Returns: A list of variables which have been created by the `Optimizer`. """ return optimizer.variables()