216 lines
8.4 KiB
Python
216 lines
8.4 KiB
Python
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Contains functions for evaluation and summarization of metrics."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import time
|
|
import math
|
|
|
|
from tensorflow.python.framework import dtypes
|
|
from tensorflow.python.framework import ops
|
|
from tensorflow.python.ops import array_ops
|
|
from tensorflow.python.ops import init_ops
|
|
from tensorflow.python.ops import state_ops
|
|
from tensorflow.python.ops import variable_scope
|
|
from tensorflow.python.platform import tf_logging as logging
|
|
from tensorflow.python.training import basic_session_run_hooks
|
|
from tensorflow.python.training import monitored_session
|
|
from tensorflow.python.training import session_run_hook
|
|
|
|
|
|
def _get_or_create_eval_step():
  """Returns the eval step counter of the default graph, creating it if needed.

  The counter is looked up in the default graph's `tf.GraphKeys.EVAL_STEP`
  collection. If that collection is empty, a new non-trainable scalar `int64`
  variable named `eval_step`, initialized with zeros, is created and
  registered in both the `LOCAL_VARIABLES` and `EVAL_STEP` collections.

  Returns:
    A `Tensor` representing a counter for the evaluation step.

  Raises:
    ValueError: If multiple `Tensors` have been added to the
      `tf.GraphKeys.EVAL_STEP` collection.
  """
  existing = ops.get_default_graph().get_collection(ops.GraphKeys.EVAL_STEP)
  if len(existing) > 1:
    raise ValueError('Multiple tensors added to tf.GraphKeys.EVAL_STEP')
  if existing:
    # Exactly one counter is registered; reuse it.
    return existing[0]
  # No counter yet. Registering it as a local variable means it is set up by
  # the local-variable initialization rather than restored from a checkpoint.
  return variable_scope.get_variable(
      'eval_step',
      shape=[],
      dtype=dtypes.int64,
      initializer=init_ops.zeros_initializer(),
      trainable=False,
      collections=[ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.EVAL_STEP])
|
|
|
|
|
|
def _get_latest_eval_step_value(update_ops):
  """Reads the eval step counter only after `update_ops` have executed.

  Args:
    update_ops: A list of `Tensors` or a dictionary of names to `Tensors`.
      For a dictionary, only its values are used as dependencies.

  Returns:
    A `Tensor` holding the value of the evaluation step, read under a
    control dependency on `update_ops`.
  """
  if isinstance(update_ops, dict):
    dependencies = list(update_ops.values())
  else:
    dependencies = update_ops

  # Both the counter lookup and the read happen inside the control-dependency
  # context, so the read op is ordered after every op in `dependencies`.
  with ops.control_dependencies(dependencies):
    counter = _get_or_create_eval_step()
    return array_ops.identity(counter.read_value())
|
|
|
|
|
|
class _StopAfterNEvalsHook(session_run_hook.SessionRunHook):
  """Run hook used by the evaluation routines to run the `eval_ops` N times."""

  def __init__(self, num_evals, log_progress=True):
    """Constructs the run hook.

    Args:
      num_evals: The number of evaluations to run for. If `None`, this hook
        never requests a stop, so evaluation continues until the session is
        stopped some other way (e.g. when all inputs are exhausted).
      log_progress: Whether to log evaluation progress, defaults to True.
    """
    # Target number of evaluations (or None for "no limit").
    self._num_evals = num_evals
    # Tensor fetched on every run to learn how many evals have completed;
    # supplied later through `_set_evals_completed_tensor`.
    self._evals_completed = None
    self._log_progress = log_progress
    # With 20 or more evaluations, log about every tenth of the way through
    # instead of after every single step.
    if num_evals is None or num_evals < 20:
      self._log_frequency = 1
    else:
      self._log_frequency = math.floor(num_evals / 10.)

  def _set_evals_completed_tensor(self, updated_eval_step):
    """Sets the tensor that is fetched to obtain the completed-eval count."""
    self._evals_completed = updated_eval_step

  def before_run(self, run_context):
    """Adds the completed-eval counter to the fetches of the upcoming run."""
    fetches = {'evals_completed': self._evals_completed}
    return session_run_hook.SessionRunArgs(fetches)

  def after_run(self, run_context, run_values):
    """Optionally logs progress and stops once `num_evals` is reached."""
    completed = run_values.results['evals_completed']
    target = self._num_evals

    if self._log_progress:
      if target is None:
        logging.info('Evaluation [%d]', completed)
      elif (completed % self._log_frequency) == 0 or target == completed:
        logging.info('Evaluation [%d/%d]', completed, target)

    if target is not None and completed >= target:
      run_context.request_stop()
|
|
|
|
|
|
def _evaluate_once(checkpoint_path,
                   master='',
                   scaffold=None,
                   eval_ops=None,
                   feed_dict=None,
                   final_ops=None,
                   final_ops_feed_dict=None,
                   hooks=None,
                   config=None):
  """Evaluates the model at the given checkpoint path.

  During a single evaluation, the `eval_ops` is run until the session is
  interrupted or requested to finish. This is typically requested via a
  `_StopAfterNEvalsHook` (or `tf.contrib.training.StopAfterNEvalsHook`),
  which results in `eval_ops` running the requested number of times. Any
  `_StopAfterNEvalsHook` in `hooks` is automatically wired to the eval step
  counter maintained by this function.

  Optionally, a user can pass in `final_ops`, a single `Tensor`, a list of
  `Tensors` or a dictionary from names to `Tensors`. The `final_ops` is
  evaluated a single time after `eval_ops` has finished running and the fetched
  values of `final_ops` are returned. If `final_ops` is left as `None`, then
  `None` is returned.

  One may also consider using a `tf.contrib.training.SummaryAtEndHook` to record
  summaries after the `eval_ops` have run. If `eval_ops` is `None`, the
  summaries run immediately after the model checkpoint has been restored.

  Note that this function creates a local variable used to track the number of
  evaluations run (via `_get_or_create_eval_step`). Consequently, if a custom
  local init op is provided via a `scaffold`, the caller should ensure that the
  local init op also initializes the eval step.

  Args:
    checkpoint_path: The path to a checkpoint to use for evaluation.
    master: The BNS address of the TensorFlow master.
    scaffold: A `tf.train.Scaffold` instance for initializing variables and
      restoring variables. Note that `scaffold.init_fn` is used by the function
      to restore the checkpoint. If you supply a custom init_fn, then it must
      also take care of restoring the model from its checkpoint.
    eval_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names
      to `Tensors`, which is run until the session is requested to stop,
      commonly done by a `_StopAfterNEvalsHook`. The caller's object is not
      modified; the eval step increment is added to a copy.
    feed_dict: The feed dictionary to use when executing the `eval_ops`.
    final_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names
      to `Tensors`.
    final_ops_feed_dict: A feed dictionary to use when evaluating `final_ops`.
    hooks: List of `tf.train.SessionRunHook` callbacks which are run inside the
      evaluation loop.
    config: An instance of `tf.ConfigProto` that will be used to
      configure the `Session`. If left as `None`, the default will be used.

  Returns:
    The fetched values of `final_ops` or `None` if `final_ops` is `None`.
  """
  eval_step = _get_or_create_eval_step()

  # Copy so that appending the final-ops hook below never mutates the caller's
  # list.
  hooks = list(hooks or [])

  if eval_ops is not None:
    update_eval_step = state_ops.assign_add(eval_step, 1, use_locking=True)

    # Attach the counter increment to `eval_ops` so every evaluation step also
    # advances the eval step. Every branch builds a new container, so the
    # caller's `eval_ops` object is never modified in place.
    if isinstance(eval_ops, dict):
      eval_ops = dict(eval_ops)
      eval_ops['update_eval_step'] = update_eval_step
    elif isinstance(eval_ops, (tuple, list)):
      eval_ops = list(eval_ops) + [update_eval_step]
    else:
      eval_ops = [eval_ops, update_eval_step]

    # Read of the eval step that is control-dependent on all of `eval_ops`
    # (including the increment), so hooks see the count after each step.
    eval_step_value = _get_latest_eval_step_value(eval_ops)

    for h in hooks:
      if isinstance(h, _StopAfterNEvalsHook):
        h._set_evals_completed_tensor(eval_step_value)  # pylint: disable=protected-access

  time_format = '%Y-%m-%d-%H:%M:%S'
  logging.info('Starting evaluation at %s',
               time.strftime(time_format, time.gmtime()))

  # Prepare the session creator.
  session_creator = monitored_session.ChiefSessionCreator(
      scaffold=scaffold,
      checkpoint_filename_with_path=checkpoint_path,
      master=master,
      config=config)

  # Runs `final_ops` once when the session ends and keeps the fetched values.
  final_ops_hook = basic_session_run_hooks.FinalOpsHook(
      final_ops, final_ops_feed_dict)
  hooks.append(final_ops_hook)

  with monitored_session.MonitoredSession(
      session_creator=session_creator, hooks=hooks) as session:
    if eval_ops is not None:
      while not session.should_stop():
        session.run(eval_ops, feed_dict)

  logging.info('Finished evaluation at %s',
               time.strftime(time_format, time.gmtime()))
  return final_ops_hook.final_ops_values
|