360 lines
14 KiB
Python
360 lines
14 KiB
Python
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# ==============================================================================
|
||
|
"""tfdbg CLI as SessionRunHook."""
|
||
|
|
||
|
from __future__ import absolute_import
|
||
|
from __future__ import division
|
||
|
from __future__ import print_function
|
||
|
|
||
|
from tensorflow.core.protobuf import config_pb2
|
||
|
from tensorflow.python.debug.lib import debug_utils
|
||
|
from tensorflow.python.debug.lib import stepper
|
||
|
from tensorflow.python.debug.wrappers import dumping_wrapper
|
||
|
from tensorflow.python.debug.wrappers import framework
|
||
|
from tensorflow.python.debug.wrappers import grpc_wrapper
|
||
|
from tensorflow.python.debug.wrappers import local_cli_wrapper
|
||
|
from tensorflow.python.training import session_run_hook
|
||
|
|
||
|
|
||
|
class LocalCLIDebugHook(session_run_hook.SessionRunHook):
  """Command-line-interface debugger hook.

  Can be used as a hook for `tf.train.MonitoredSession`s and
  `tf.estimator.Estimator`s. Provides a substitute for
  `tfdbg.LocalCLIDebugWrapperSession` in cases where the session is not directly
  available.
  """

  def __init__(self, ui_type="curses", dump_root=None, thread_name_filter=None):
    """Create a local debugger command-line interface (CLI) hook.

    Args:
      ui_type: (`str`) requested user-interface type. Currently supported:
        (curses | readline).
      dump_root: (`str`) optional path to the dump root directory. Must be a
        directory that does not exist or an empty directory. If the directory
        does not exist, it will be created by the debugger core during debug
        `run()` calls and removed afterwards.
      thread_name_filter: Regular-expression white list for threads on which the
        wrapper session will be active. See doc of `BaseDebugWrapperSession` for
        more details.
    """

    self._ui_type = ui_type
    self._dump_root = dump_root
    self._thread_name_filter = thread_name_filter
    # The wrapper session is created lazily in `before_run()`, because the
    # underlying session only becomes available through the run context.
    self._session_wrapper = None
    # Filters added before the wrapper session exists, keyed by filter name.
    self._pending_tensor_filters = {}
    # Action chosen in `before_run()` and read back in `after_run()`. Defined
    # here so that the attribute exists for the whole lifetime of the hook.
    self._performed_action = None

  def add_tensor_filter(self, filter_name, tensor_filter):
    """Add a tensor filter.

    See doc of `LocalCLIDebugWrapperSession.add_tensor_filter()` for details.
    Override default behavior to accommodate the possibility of this method
    being called prior to the initialization of the underlying
    `LocalCLIDebugWrapperSession` object.

    Args:
      filter_name: See doc of `LocalCLIDebugWrapperSession.add_tensor_filter()`
        for details.
      tensor_filter: See doc of
        `LocalCLIDebugWrapperSession.add_tensor_filter()` for details.
    """

    if self._session_wrapper:
      self._session_wrapper.add_tensor_filter(filter_name, tensor_filter)
    else:
      # No wrapper session yet: remember the filter so that `before_run()` can
      # register it once the wrapper session has been constructed.
      self._pending_tensor_filters[filter_name] = tensor_filter

  def begin(self):
    """No-op; the wrapper session is created on the first `before_run()`."""
    pass

  def before_run(self, run_context):
    """Prepare the run options according to the action chosen in the CLI.

    On the first call, constructs the `LocalCLIDebugWrapperSession` around
    `run_context.session` and registers any pending tensor filters. On every
    call, increments the wrapper's run-call counter, invokes the wrapper's
    `on_run_start()` and records the resulting action for `after_run()`.

    Args:
      run_context: The run context; its `session` and `original_args`
        (`fetches` and `feed_dict`) are read.

    Returns:
      A `session_run_hook.SessionRunArgs` with no fetches and no feed_dict,
      whose `options` are decorated for debugging or profiling when the
      performed action is `DEBUG_RUN` or `PROFILE_RUN`, respectively.
    """
    if not self._session_wrapper:
      self._session_wrapper = local_cli_wrapper.LocalCLIDebugWrapperSession(
          run_context.session,
          ui_type=self._ui_type,
          dump_root=self._dump_root,
          thread_name_filter=self._thread_name_filter)

      # Actually register tensor filters registered prior to the construction
      # of the underlying LocalCLIDebugWrapperSession object.
      for filter_name, tensor_filter in self._pending_tensor_filters.items():
        self._session_wrapper.add_tensor_filter(filter_name, tensor_filter)

    # Increment run call counter.
    self._session_wrapper.increment_run_call_count()

    # Adapt run_context to an instance of OnRunStartRequest for invoking
    # superclass on_run_start().
    on_run_start_request = framework.OnRunStartRequest(
        run_context.original_args.fetches, run_context.original_args.feed_dict,
        None, None, self._session_wrapper.run_call_count)

    on_run_start_response = self._session_wrapper.on_run_start(
        on_run_start_request)
    self._performed_action = on_run_start_response.action

    run_args = session_run_hook.SessionRunArgs(
        None, feed_dict=None, options=config_pb2.RunOptions())
    if self._performed_action == framework.OnRunStartAction.DEBUG_RUN:
      # pylint: disable=protected-access
      self._session_wrapper._decorate_run_options_for_debug(
          run_args.options,
          on_run_start_response.debug_urls,
          debug_ops=on_run_start_response.debug_ops,
          node_name_regex_whitelist=(
              on_run_start_response.node_name_regex_whitelist),
          op_type_regex_whitelist=(
              on_run_start_response.op_type_regex_whitelist),
          tensor_dtype_regex_whitelist=(
              on_run_start_response.tensor_dtype_regex_whitelist),
          tolerate_debug_op_creation_failures=(
              on_run_start_response.tolerate_debug_op_creation_failures))
      # pylint: enable=protected-access
    elif self._performed_action == framework.OnRunStartAction.PROFILE_RUN:
      # pylint: disable=protected-access
      self._session_wrapper._decorate_run_options_for_profile(run_args.options)
      # pylint: enable=protected-access
    elif self._performed_action == framework.OnRunStartAction.INVOKE_STEPPER:
      # The _finalized property must be set to False so that the NodeStepper
      # can insert ops for retrieving TensorHandles.
      # pylint: disable=protected-access
      run_context.session.graph._finalized = False
      # pylint: enable=protected-access

      with stepper.NodeStepper(
          run_context.session, run_context.original_args.fetches,
          run_context.original_args.feed_dict) as node_stepper:
        self._session_wrapper.invoke_node_stepper(
            node_stepper, restore_variable_values_on_exit=True)

    return run_args

  def after_run(self, run_context, run_values):
    """Forward the end of the run to the wrapper session.

    Args:
      run_context: The run context (unused).
      run_values: The run values; its `run_metadata` is passed on to the
        wrapper session's `on_run_end()`.
    """
    # Adapt run_context and run_values to OnRunEndRequest and invoke superclass
    # on_run_end()
    on_run_end_request = framework.OnRunEndRequest(self._performed_action,
                                                   run_values.run_metadata)
    self._session_wrapper.on_run_end(on_run_end_request)
|
||
|
|
||
|
|
||
|
class DumpingDebugHook(session_run_hook.SessionRunHook):
  """A debugger hook that dumps debug data to filesystem.

  Can be used as a hook for `tf.train.MonitoredSession`s and
  `tf.estimator.Estimator`s.
  """

  def __init__(self,
               session_root,
               watch_fn=None,
               thread_name_filter=None,
               log_usage=True):
    """Create a hook that dumps debug data to the filesystem.

    The arguments are stored as-is and handed to
    `dumping_wrapper.DumpingDebugWrapperSession` when the wrapper session is
    created on the first `before_run()` call.

    Args:
      session_root: See doc of
        `dumping_wrapper.DumpingDebugWrapperSession.__init__`.
      watch_fn: See doc of
        `dumping_wrapper.DumpingDebugWrapperSession.__init__`.
      thread_name_filter: Regular-expression white list for threads on which the
        wrapper session will be active. See doc of `BaseDebugWrapperSession` for
        more details.
      log_usage: (bool) Whether usage is to be logged.
    """

    # Created lazily, once a session is reachable through the run context.
    self._session_wrapper = None

    self._session_root = session_root
    self._watch_fn = watch_fn
    self._thread_name_filter = thread_name_filter
    self._log_usage = log_usage

  def begin(self):
    """No-op; all set-up happens in `before_run()`."""
    pass

  def before_run(self, run_context):
    """Build run options that watch the graph for the upcoming run.

    Args:
      run_context: The run context; its `session` and `original_args`
        (`fetches` and `feed_dict`) are read.

    Returns:
      A `session_run_hook.SessionRunArgs` with no fetches and no feed_dict,
      carrying `RunOptions` configured through `debug_utils.watch_graph()`.
    """
    if not self._session_wrapper:
      self._session_wrapper = dumping_wrapper.DumpingDebugWrapperSession(
          run_context.session,
          self._session_root,
          watch_fn=self._watch_fn,
          thread_name_filter=self._thread_name_filter,
          log_usage=self._log_usage)

    wrapper = self._session_wrapper
    wrapper.increment_run_call_count()

    original_args = run_context.original_args
    # pylint: disable=protected-access
    debug_urls, watch_options = wrapper._prepare_run_watch_config(
        original_args.fetches, original_args.feed_dict)
    # pylint: enable=protected-access

    options = config_pb2.RunOptions()
    debug_utils.watch_graph(
        options,
        run_context.session.graph,
        debug_urls=debug_urls,
        debug_ops=watch_options.debug_ops,
        node_name_regex_whitelist=watch_options.node_name_regex_whitelist,
        op_type_regex_whitelist=watch_options.op_type_regex_whitelist,
        tensor_dtype_regex_whitelist=watch_options.tensor_dtype_regex_whitelist,
        tolerate_debug_op_creation_failures=(
            watch_options.tolerate_debug_op_creation_failures))

    return session_run_hook.SessionRunArgs(
        None, feed_dict=None, options=options)

  def after_run(self, run_context, run_values):
    """No-op; this hook does nothing after the run."""
    pass
|
||
|
|
||
|
|
||
|
class GrpcDebugHook(session_run_hook.SessionRunHook):
  """A hook that streams debugger-related events to any grpc_debug_server.

  For example, the debugger data server is a grpc_debug_server. The debugger
  data server writes debugger-related events it receives via GRPC to logdir.
  This enables debugging features in Tensorboard such as health pills.

  When the arguments of debug_utils.watch_graph changes, strongly consider
  changing arguments here too so that features are available to tflearn users.

  Can be used as a hook for `tf.train.MonitoredSession`s and
  `tf.estimator.Estimator`s.
  """

  def __init__(self,
               grpc_debug_server_addresses,
               watch_fn=None,
               thread_name_filter=None,
               log_usage=True):
    """Constructs a GrpcDebugHook.

    Args:
      grpc_debug_server_addresses: (`list` of `str`) A list of the gRPC debug
        server addresses, in the format of <host:port>, with or without the
        "grpc://" prefix. For example: ["localhost:7000", "192.168.0.2:8000"]
        A value that is not a `list` is wrapped in a single-element list.
      watch_fn: A function that allows for customizing which ops to watch at
        which specific steps. See doc of
        `dumping_wrapper.DumpingDebugWrapperSession.__init__` for details.
        If `None`, a default-constructed `framework.WatchOptions` is used for
        every run.
      thread_name_filter: Regular-expression white list for threads on which the
        wrapper session will be active. See doc of `BaseDebugWrapperSession` for
        more details.
      log_usage: (bool) Whether usage is to be logged.
    """
    # Created lazily in `before_run()`, once a session is available.
    self._grpc_debug_wrapper_session = None
    self._thread_name_filter = thread_name_filter
    self._grpc_debug_server_addresses = (
        grpc_debug_server_addresses
        if isinstance(grpc_debug_server_addresses, list) else
        [grpc_debug_server_addresses])

    self._watch_fn = watch_fn
    self._log_usage = log_usage

  def before_run(self, run_context):
    """Called right before a session is run.

    Args:
      run_context: A session_run_hook.SessionRunContext. Encapsulates
        information on the run.

    Returns:
      A session_run_hook.SessionRunArgs object.
    """

    if not self._grpc_debug_wrapper_session:
      self._grpc_debug_wrapper_session = grpc_wrapper.GrpcDebugWrapperSession(
          run_context.session,
          self._grpc_debug_server_addresses,
          watch_fn=self._watch_fn,
          thread_name_filter=self._thread_name_filter,
          log_usage=self._log_usage)

    fetches = run_context.original_args.fetches
    feed_dict = run_context.original_args.feed_dict
    # `watch_fn` is optional (its default is None), so it must not be called
    # unconditionally; fall back to default watch options in that case.
    if self._watch_fn is None:
      watch_options = framework.WatchOptions()
    else:
      watch_options = self._watch_fn(fetches, feed_dict)
    run_options = config_pb2.RunOptions()
    debug_utils.watch_graph(
        run_options,
        run_context.session.graph,
        debug_urls=self._grpc_debug_wrapper_session.prepare_run_debug_urls(
            fetches, feed_dict),
        debug_ops=watch_options.debug_ops,
        node_name_regex_whitelist=watch_options.node_name_regex_whitelist,
        op_type_regex_whitelist=watch_options.op_type_regex_whitelist,
        tensor_dtype_regex_whitelist=watch_options.tensor_dtype_regex_whitelist,
        tolerate_debug_op_creation_failures=(
            watch_options.tolerate_debug_op_creation_failures))

    return session_run_hook.SessionRunArgs(
        None, feed_dict=None, options=run_options)
|
||
|
|
||
|
|
||
|
class TensorBoardDebugHook(GrpcDebugHook):
  """A tfdbg hook that can be used with TensorBoard Debugger Plugin.

  This hook is the same as `GrpcDebugHook`, except that it uses a predefined
  `watch_fn` that
    1) uses `DebugIdentity` debug ops with the `gated_grpc` attribute set to
       `True`, to allow the interactive enabling and disabling of tensor
       breakpoints.
    2) watches all tensors in the graph.
  This saves the need for the user to define a `watch_fn`.
  """

  def __init__(self,
               grpc_debug_server_addresses,
               thread_name_filter=None,
               send_traceback_and_source_code=True,
               log_usage=True):
    """Constructor of TensorBoardDebugHook.

    Besides storing its arguments, this registers a signal handler through
    `grpc_wrapper.register_signal_handler()`.

    Args:
      grpc_debug_server_addresses: gRPC address(es) of debug server(s), as a
        `str` or a `list` of `str`s. E.g., "localhost:2333",
        "grpc://localhost:2333", ["192.168.0.7:2333", "192.168.0.8:2333"].
      thread_name_filter: Optional filter for thread names.
      send_traceback_and_source_code: Whether traceback of graph elements and
        the source code are to be sent to the debug server(s).
      log_usage: Whether the usage of this class is to be logged (if
        applicable).
    """

    def _watch_with_gated_grpc(fetches, feeds):
      # The watch options are the same for every run, whatever is fetched or
      # fed.
      del fetches, feeds  # Unused.
      gated_debug_ops = ["DebugIdentity(gated_grpc=true)"]
      return framework.WatchOptions(debug_ops=gated_debug_ops)

    super(TensorBoardDebugHook, self).__init__(
        grpc_debug_server_addresses,
        watch_fn=_watch_with_gated_grpc,
        thread_name_filter=thread_name_filter,
        log_usage=log_usage)

    # NOTE(review): this replaces the value just stored by
    # `GrpcDebugHook.__init__` (which wraps a non-list argument in a list) with
    # the raw argument -- confirm that this is intended.
    self._grpc_debug_server_addresses = grpc_debug_server_addresses
    self._send_traceback_and_source_code = send_traceback_and_source_code
    # Graph version returned by the latest `publish_traceback()` call; starts
    # at -1, before any call has been made.
    self._sent_graph_version = -1
    grpc_wrapper.register_signal_handler()

  def before_run(self, run_context):
    """Optionally publish tracebacks, then defer to `GrpcDebugHook`.

    Args:
      run_context: The run context; its `session.graph` and `original_args`
        (`feed_dict` and `fetches`) are passed to
        `grpc_wrapper.publish_traceback()` when sending of tracebacks and
        source code is enabled.

    Returns:
      The return value of `GrpcDebugHook.before_run()`.
    """
    if self._send_traceback_and_source_code:
      original_args = run_context.original_args
      self._sent_graph_version = grpc_wrapper.publish_traceback(
          self._grpc_debug_server_addresses,
          run_context.session.graph,
          original_args.feed_dict,
          original_args.fetches,
          self._sent_graph_version)
    return super(TensorBoardDebugHook, self).before_run(run_context)
|