# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""tfdbg CLI as SessionRunHook."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.debug.lib import debug_utils
from tensorflow.python.debug.lib import stepper
from tensorflow.python.debug.wrappers import dumping_wrapper
from tensorflow.python.debug.wrappers import framework
from tensorflow.python.debug.wrappers import grpc_wrapper
from tensorflow.python.debug.wrappers import local_cli_wrapper
from tensorflow.python.training import session_run_hook


class LocalCLIDebugHook(session_run_hook.SessionRunHook):
  """Command-line-interface debugger hook.

  Can be used as a hook for `tf.train.MonitoredSession`s and
  `tf.estimator.Estimator`s. Provides a substitute for
  `tfdbg.LocalCLIDebugWrapperSession` in cases where the session is not directly
  available.
  """

  def __init__(self, ui_type="curses", dump_root=None, thread_name_filter=None):
    """Create a local debugger command-line interface (CLI) hook.

    Args:
      ui_type: (`str`) requested user-interface type. Currently supported:
        (curses | readline).
      dump_root: (`str`) optional path to the dump root directory. Must be a
        directory that does not exist or an empty directory. If the directory
        does not exist, it will be created by the debugger core during debug
        `run()` calls and removed afterwards.
      thread_name_filter: Regular-expression white list for threads on which the
        wrapper session will be active. See doc of `BaseDebugWrapperSession` for
        more details.
    """
    self._ui_type = ui_type
    self._dump_root = dump_root
    self._thread_name_filter = thread_name_filter

    # The wrapper session is created lazily in before_run(), because the
    # underlying session only becomes available through the run context.
    self._session_wrapper = None
    # Filters added before the wrapper session exists, keyed by filter name.
    self._pending_tensor_filters = {}
    # Action chosen by the most recent on_run_start() call; set in before_run()
    # and read back in after_run().
    self._performed_action = None

  def add_tensor_filter(self, filter_name, tensor_filter):
    """Add a tensor filter.

    See doc of `LocalCLIDebugWrapperSession.add_tensor_filter()` for details.
    Override default behavior to accommodate the possibility of this method
    being called prior to the initialization of the underlying
    `LocalCLIDebugWrapperSession` object.

    Args:
      filter_name: See doc of `LocalCLIDebugWrapperSession.add_tensor_filter()`
        for details.
      tensor_filter: See doc of
        `LocalCLIDebugWrapperSession.add_tensor_filter()` for details.
    """
    if self._session_wrapper:
      self._session_wrapper.add_tensor_filter(filter_name, tensor_filter)
    else:
      # No wrapper session yet: remember the filter so that before_run() can
      # register it once the wrapper has been constructed.
      self._pending_tensor_filters[filter_name] = tensor_filter

  def begin(self):
    """No-op; the wrapper session is created lazily in `before_run()`."""
    pass

  def before_run(self, run_context):
    """Set up the debugger for the upcoming `run()` call.

    Lazily creates the `LocalCLIDebugWrapperSession`, asks it which action to
    perform for this run and decorates the run options accordingly.

    Args:
      run_context: A `session_run_hook.SessionRunContext` for the upcoming run.

    Returns:
      A `session_run_hook.SessionRunArgs` carrying no extra fetches or feeds,
      only the (possibly decorated) `RunOptions`.
    """
    if not self._session_wrapper:
      self._session_wrapper = local_cli_wrapper.LocalCLIDebugWrapperSession(
          run_context.session,
          ui_type=self._ui_type,
          dump_root=self._dump_root,
          thread_name_filter=self._thread_name_filter)

      # Actually register tensor filters registered prior to the construction
      # of the underlying LocalCLIDebugWrapperSession object.
      for filter_name in self._pending_tensor_filters:
        self._session_wrapper.add_tensor_filter(
            filter_name, self._pending_tensor_filters[filter_name])

    # Increment run call counter.
    self._session_wrapper.increment_run_call_count()

    # Adapt run_context to an instance of OnRunStartRequest for invoking
    # the wrapper session's on_run_start().
    on_run_start_request = framework.OnRunStartRequest(
        run_context.original_args.fetches, run_context.original_args.feed_dict,
        None, None, self._session_wrapper.run_call_count)

    on_run_start_response = self._session_wrapper.on_run_start(
        on_run_start_request)
    self._performed_action = on_run_start_response.action

    run_args = session_run_hook.SessionRunArgs(
        None, feed_dict=None, options=config_pb2.RunOptions())
    if self._performed_action == framework.OnRunStartAction.DEBUG_RUN:
      # pylint: disable=protected-access
      self._session_wrapper._decorate_run_options_for_debug(
          run_args.options,
          on_run_start_response.debug_urls,
          debug_ops=on_run_start_response.debug_ops,
          node_name_regex_whitelist=(
              on_run_start_response.node_name_regex_whitelist),
          op_type_regex_whitelist=(
              on_run_start_response.op_type_regex_whitelist),
          tensor_dtype_regex_whitelist=(
              on_run_start_response.tensor_dtype_regex_whitelist),
          tolerate_debug_op_creation_failures=(
              on_run_start_response.tolerate_debug_op_creation_failures))
      # pylint: enable=protected-access
    elif self._performed_action == framework.OnRunStartAction.PROFILE_RUN:
      # pylint: disable=protected-access
      self._session_wrapper._decorate_run_options_for_profile(run_args.options)
      # pylint: enable=protected-access
    elif self._performed_action == framework.OnRunStartAction.INVOKE_STEPPER:
      # The _finalized property must be set to False so that the NodeStepper
      # can insert ops for retrieving TensorHandles.
      # pylint: disable=protected-access
      run_context.session.graph._finalized = False
      # pylint: enable=protected-access

      with stepper.NodeStepper(
          run_context.session,
          run_context.original_args.fetches,
          run_context.original_args.feed_dict) as node_stepper:
        self._session_wrapper.invoke_node_stepper(
            node_stepper, restore_variable_values_on_exit=True)

    return run_args

  def after_run(self, run_context, run_values):
    """Report the end of the run to the wrapper session.

    Args:
      run_context: A `session_run_hook.SessionRunContext`. Unused.
      run_values: A `session_run_hook.SessionRunValues`; its `run_metadata` is
        forwarded to the wrapper session.
    """
    # Adapt run_context and run_values to OnRunEndRequest and invoke the
    # wrapper session's on_run_end().
    on_run_end_request = framework.OnRunEndRequest(self._performed_action,
                                                   run_values.run_metadata)
    self._session_wrapper.on_run_end(on_run_end_request)


class DumpingDebugHook(session_run_hook.SessionRunHook):
  """A debugger hook that dumps debug data to filesystem.

  Can be used as a hook for `tf.train.MonitoredSession`s and
  `tf.estimator.Estimator`s.
  """

  def __init__(self,
               session_root,
               watch_fn=None,
               thread_name_filter=None,
               log_usage=True):
    """Create a debugger hook that dumps debug data to the filesystem.

    Args:
      session_root: See doc of
        `dumping_wrapper.DumpingDebugWrapperSession.__init__`.
      watch_fn: See doc of
        `dumping_wrapper.DumpingDebugWrapperSession.__init__`.
      thread_name_filter: Regular-expression white list for threads on which the
        wrapper session will be active. See doc of `BaseDebugWrapperSession` for
        more details.
      log_usage: (bool) Whether usage is to be logged.
    """
    self._session_root = session_root
    self._watch_fn = watch_fn
    self._thread_name_filter = thread_name_filter
    self._log_usage = log_usage
    # Created lazily in before_run(), once a session is available.
    self._session_wrapper = None

  def begin(self):
    """No-op; the wrapper session is created lazily in `before_run()`."""
    pass

  def before_run(self, run_context):
    """Configure graph watching for the upcoming `run()` call.

    Args:
      run_context: A `session_run_hook.SessionRunContext` for the upcoming run.

    Returns:
      A `session_run_hook.SessionRunArgs` carrying no extra fetches or feeds,
      only the `RunOptions` decorated by `debug_utils.watch_graph()`.
    """
    if not self._session_wrapper:
      self._session_wrapper = dumping_wrapper.DumpingDebugWrapperSession(
          run_context.session,
          self._session_root,
          watch_fn=self._watch_fn,
          thread_name_filter=self._thread_name_filter,
          log_usage=self._log_usage)

    self._session_wrapper.increment_run_call_count()

    # pylint: disable=protected-access
    debug_urls, watch_options = self._session_wrapper._prepare_run_watch_config(
        run_context.original_args.fetches, run_context.original_args.feed_dict)
    # pylint: enable=protected-access
    run_options = config_pb2.RunOptions()
    debug_utils.watch_graph(
        run_options,
        run_context.session.graph,
        debug_urls=debug_urls,
        debug_ops=watch_options.debug_ops,
        node_name_regex_whitelist=watch_options.node_name_regex_whitelist,
        op_type_regex_whitelist=watch_options.op_type_regex_whitelist,
        tensor_dtype_regex_whitelist=watch_options.tensor_dtype_regex_whitelist,
        tolerate_debug_op_creation_failures=(
            watch_options.tolerate_debug_op_creation_failures))

    run_args = session_run_hook.SessionRunArgs(
        None, feed_dict=None, options=run_options)
    return run_args

  def after_run(self, run_context, run_values):
    """No-op; this hook does nothing after the run."""
    pass


class GrpcDebugHook(session_run_hook.SessionRunHook):
  """A hook that streams debugger-related events to any grpc_debug_server.

  For example, the debugger data server is a grpc_debug_server. The debugger
  data server writes debugger-related events it receives via GRPC to logdir.
  This enables debugging features in Tensorboard such as health pills.

  When the arguments of debug_utils.watch_graph changes, strongly consider
  changing arguments here too so that features are available to tflearn users.

  Can be used as a hook for `tf.train.MonitoredSession`s and
  `tf.estimator.Estimator`s.
  """

  def __init__(self,
               grpc_debug_server_addresses,
               watch_fn=None,
               thread_name_filter=None,
               log_usage=True):
    """Constructs a GrpcDebugHook.

    Args:
      grpc_debug_server_addresses: (`list` of `str`) A list of the gRPC debug
        server addresses, in the format of `host:port`, with or without the
        "grpc://" prefix. For example: ["localhost:7000", "192.168.0.2:8000"].
        A value that is not a `list` is wrapped into a single-element list.
      watch_fn: A function that allows for customizing which ops to watch at
        which specific steps. See doc of
        `dumping_wrapper.DumpingDebugWrapperSession.__init__` for details.
      thread_name_filter: Regular-expression white list for threads on which the
        wrapper session will be active. See doc of `BaseDebugWrapperSession` for
        more details.
      log_usage: (bool) Whether usage is to be logged.
    """
    # Created lazily in before_run(), once a session is available.
    self._grpc_debug_wrapper_session = None
    self._thread_name_filter = thread_name_filter
    self._grpc_debug_server_addresses = (
        grpc_debug_server_addresses
        if isinstance(grpc_debug_server_addresses, list) else
        [grpc_debug_server_addresses])

    self._watch_fn = watch_fn
    self._log_usage = log_usage

  def before_run(self, run_context):
    """Called right before a session is run.

    Args:
      run_context: A session_run_hook.SessionRunContext. Encapsulates
        information on the run.

    Returns:
      A session_run_hook.SessionRunArgs object.
    """
    if not self._grpc_debug_wrapper_session:
      self._grpc_debug_wrapper_session = grpc_wrapper.GrpcDebugWrapperSession(
          run_context.session,
          self._grpc_debug_server_addresses,
          watch_fn=self._watch_fn,
          thread_name_filter=self._thread_name_filter,
          log_usage=self._log_usage)

    fetches = run_context.original_args.fetches
    feed_dict = run_context.original_args.feed_dict

    if self._watch_fn is None:
      # `watch_fn` defaults to None in the constructor, so calling it
      # unconditionally would raise a TypeError. Fall back to default watch
      # options instead.
      # NOTE(review): assumes `framework.WatchOptions()` can be constructed
      # without arguments - confirm against its definition.
      watch_options = framework.WatchOptions()
    else:
      watch_options = self._watch_fn(fetches, feed_dict)

    run_options = config_pb2.RunOptions()
    debug_utils.watch_graph(
        run_options,
        run_context.session.graph,
        debug_urls=self._grpc_debug_wrapper_session.prepare_run_debug_urls(
            fetches, feed_dict),
        debug_ops=watch_options.debug_ops,
        node_name_regex_whitelist=watch_options.node_name_regex_whitelist,
        op_type_regex_whitelist=watch_options.op_type_regex_whitelist,
        tensor_dtype_regex_whitelist=watch_options.tensor_dtype_regex_whitelist,
        tolerate_debug_op_creation_failures=(
            watch_options.tolerate_debug_op_creation_failures))

    return session_run_hook.SessionRunArgs(
        None, feed_dict=None, options=run_options)


class TensorBoardDebugHook(GrpcDebugHook):
  """A tfdbg hook that can be used with TensorBoard Debugger Plugin.

  This hook is the same as `GrpcDebugHook`, except that it uses a predefined
  `watch_fn` that
    1) uses `DebugIdentity` debug ops with the `gated_grpc` attribute set to
       `True`, to allow the interactive enabling and disabling of tensor
       breakpoints.
    2) watches all tensors in the graph.
  This saves the need for the user to define a `watch_fn`.
  """

  def __init__(self,
               grpc_debug_server_addresses,
               thread_name_filter=None,
               send_traceback_and_source_code=True,
               log_usage=True):
    """Constructor of TensorBoardDebugHook.

    Args:
      grpc_debug_server_addresses: gRPC address(es) of debug server(s), as a
        `str` or a `list` of `str`s. E.g., "localhost:2333",
        "grpc://localhost:2333", ["192.168.0.7:2333", "192.168.0.8:2333"].
      thread_name_filter: Optional filter for thread names.
      send_traceback_and_source_code: Whether traceback of graph elements and
        the source code are to be sent to the debug server(s).
      log_usage: Whether the usage of this class is to be logged (if
        applicable).
    """

    def _gated_grpc_watch_fn(fetches, feeds):
      del fetches, feeds  # Unused.
      return framework.WatchOptions(
          debug_ops=["DebugIdentity(gated_grpc=true)"])

    super(TensorBoardDebugHook, self).__init__(
        grpc_debug_server_addresses,
        watch_fn=_gated_grpc_watch_fn,
        thread_name_filter=thread_name_filter,
        log_usage=log_usage)

    # NOTE(review): this replaces the list-normalized value stored by the base
    # constructor with the raw argument, which may be a bare `str` - confirm
    # that the consumers of this attribute accept both forms.
    self._grpc_debug_server_addresses = grpc_debug_server_addresses
    self._send_traceback_and_source_code = send_traceback_and_source_code
    # Graph version last handed back by publish_traceback(); starts at -1.
    self._sent_graph_version = -1
    grpc_wrapper.register_signal_handler()

  def before_run(self, run_context):
    """Optionally publish tracebacks, then defer to `GrpcDebugHook`.

    Args:
      run_context: A `session_run_hook.SessionRunContext` for the upcoming run.

    Returns:
      The `session_run_hook.SessionRunArgs` produced by
      `GrpcDebugHook.before_run()`.
    """
    if self._send_traceback_and_source_code:
      self._sent_graph_version = grpc_wrapper.publish_traceback(
          self._grpc_debug_server_addresses, run_context.session.graph,
          run_context.original_args.feed_dict,
          run_context.original_args.fetches, self._sent_graph_version)
    return super(TensorBoardDebugHook, self).before_run(run_context)