339 lines
12 KiB
Python
339 lines
12 KiB
Python
|
# -*- coding: utf-8 -*-
|
||
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# ==============================================================================
|
||
|
|
||
|
"""The TensorBoard plugin for performance profiling."""
|
||
|
|
||
|
from __future__ import absolute_import
|
||
|
from __future__ import division
|
||
|
from __future__ import print_function
|
||
|
|
||
|
import logging
|
||
|
import os
|
||
|
from werkzeug import wrappers
|
||
|
|
||
|
import tensorflow as tf
|
||
|
|
||
|
from tensorboard.backend import http_util
|
||
|
from tensorboard.backend.event_processing import plugin_asset_util
|
||
|
from tensorboard.plugins import base_plugin
|
||
|
from tensorboard.plugins.profile import trace_events_json
|
||
|
from tensorboard.plugins.profile import trace_events_pb2
|
||
|
|
||
|
# The prefix of routes provided by this plugin.
|
||
|
PLUGIN_NAME = 'profile'
|
||
|
|
||
|
# HTTP routes
|
||
|
LOGDIR_ROUTE = '/logdir'
|
||
|
DATA_ROUTE = '/data'
|
||
|
TOOLS_ROUTE = '/tools'
|
||
|
HOSTS_ROUTE = '/hosts'
|
||
|
|
||
|
# Available profiling tools -> file name of the tool data.
|
||
|
_FILE_NAME = 'TOOL_FILE_NAME'
|
||
|
TOOLS = {
|
||
|
'trace_viewer': 'trace',
|
||
|
'trace_viewer@': 'tracetable', #streaming traceviewer
|
||
|
'op_profile': 'op_profile.json',
|
||
|
'input_pipeline_analyzer': 'input_pipeline.json',
|
||
|
'overview_page': 'overview_page.json',
|
||
|
'memory_viewer': 'memory_viewer.json',
|
||
|
'google_chart_demo': 'google_chart_demo.json',
|
||
|
}
|
||
|
|
||
|
# Tools that consume raw data.
|
||
|
_RAW_DATA_TOOLS = frozenset(['input_pipeline_analyzer',
|
||
|
'op_profile',
|
||
|
'overview_page',
|
||
|
'memory_viewer',
|
||
|
'google_chart_demo',])
|
||
|
|
||
|
def process_raw_trace(raw_trace):
|
||
|
"""Processes raw trace data and returns the UI data."""
|
||
|
trace = trace_events_pb2.Trace()
|
||
|
trace.ParseFromString(raw_trace)
|
||
|
return ''.join(trace_events_json.TraceEventsJsonStream(trace))
|
||
|
|
||
|
|
||
|
class ProfilePluginLoader(base_plugin.TBLoader):
|
||
|
"""Loader for Profile Plugin."""
|
||
|
|
||
|
def define_flags(self, parser):
|
||
|
group = parser.add_argument_group('profile plugin')
|
||
|
group.add_argument(
|
||
|
'--master_tpu_unsecure_channel',
|
||
|
metavar='ADDR',
|
||
|
type=str,
|
||
|
default='',
|
||
|
help='''\
|
||
|
IP address of "master tpu", used for getting streaming trace data
|
||
|
through tpu profiler analysis grpc. The grpc channel is not secured.\
|
||
|
''')
|
||
|
|
||
|
def load(self, context):
|
||
|
return ProfilePlugin(context)
|
||
|
|
||
|
|
||
|
class ProfilePlugin(base_plugin.TBPlugin):
|
||
|
"""Profile Plugin for TensorBoard."""
|
||
|
|
||
|
plugin_name = PLUGIN_NAME
|
||
|
|
||
|
def __init__(self, context):
|
||
|
"""Constructs a profiler plugin for TensorBoard.
|
||
|
|
||
|
This plugin adds handlers for performance-related frontends.
|
||
|
|
||
|
Args:
|
||
|
context: A base_plugin.TBContext instance.
|
||
|
"""
|
||
|
self.logdir = context.logdir
|
||
|
self.plugin_logdir = plugin_asset_util.PluginDirectory(
|
||
|
self.logdir, ProfilePlugin.plugin_name)
|
||
|
self.stub = None
|
||
|
self.master_tpu_unsecure_channel = context.flags.master_tpu_unsecure_channel
|
||
|
|
||
|
@wrappers.Request.application
|
||
|
def logdir_route(self, request):
|
||
|
return http_util.Respond(request, {'logdir': self.plugin_logdir},
|
||
|
'application/json')
|
||
|
|
||
|
def _run_dir(self, run):
|
||
|
run_dir = os.path.join(self.plugin_logdir, run)
|
||
|
return run_dir if tf.gfile.IsDirectory(run_dir) else None
|
||
|
|
||
|
def start_grpc_stub_if_necessary(self):
|
||
|
# We will enable streaming trace viewer on two conditions:
|
||
|
# 1. user specify the flags master_tpu_unsecure_channel to the ip address of
|
||
|
# as "master" TPU. grpc will be used to fetch streaming trace data.
|
||
|
# 2. the logdir is on google cloud storage.
|
||
|
if self.master_tpu_unsecure_channel and self.logdir.startswith('gs://'):
|
||
|
if self.stub is None:
|
||
|
import grpc
|
||
|
from tensorflow.contrib.tpu.profiler import tpu_profiler_analysis_pb2_grpc # pylint: disable=line-too-long
|
||
|
# Workaround the grpc's 4MB message limitation.
|
||
|
gigabyte = 1024 * 1024 * 1024
|
||
|
options = [('grpc.max_message_length', gigabyte),
|
||
|
('grpc.max_send_message_length', gigabyte),
|
||
|
('grpc.max_receive_message_length', gigabyte)]
|
||
|
tpu_profiler_port = self.master_tpu_unsecure_channel + ':8466'
|
||
|
channel = grpc.insecure_channel(tpu_profiler_port, options)
|
||
|
self.stub = tpu_profiler_analysis_pb2_grpc.TPUProfileAnalysisStub(
|
||
|
channel)
|
||
|
|
||
|
def index_impl(self):
|
||
|
"""Returns available runs and available tool data in the log directory.
|
||
|
|
||
|
In the plugin log directory, each directory contains profile data for a
|
||
|
single run (identified by the directory name), and files in the run
|
||
|
directory contains data for different tools. The file that contains profile
|
||
|
for a specific tool "x" will have a suffix name TOOLS["x"].
|
||
|
Example:
|
||
|
log/
|
||
|
run1/
|
||
|
plugins/
|
||
|
profile/
|
||
|
host1.trace
|
||
|
host2.trace
|
||
|
run2/
|
||
|
plugins/
|
||
|
profile/
|
||
|
host1.trace
|
||
|
host2.trace
|
||
|
|
||
|
Returns:
|
||
|
A map from runs to tool names e.g.
|
||
|
{"run1": ["trace_viewer"], "run2": ["trace_viewer"]} for the example.
|
||
|
"""
|
||
|
# TODO(ioeric): use the following structure and use EventMultiplexer so that
|
||
|
# the plugin still works when logdir is set to root_logdir/run1/
|
||
|
# root_logdir/
|
||
|
# run1/
|
||
|
# plugins/
|
||
|
# profile/
|
||
|
# host1.trace
|
||
|
# run2/
|
||
|
# plugins/
|
||
|
# profile/
|
||
|
# host2.trace
|
||
|
run_to_tools = {}
|
||
|
if not tf.gfile.IsDirectory(self.plugin_logdir):
|
||
|
return run_to_tools
|
||
|
|
||
|
self.start_grpc_stub_if_necessary()
|
||
|
for run in tf.gfile.ListDirectory(self.plugin_logdir):
|
||
|
run_dir = self._run_dir(run)
|
||
|
if not run_dir:
|
||
|
continue
|
||
|
run_to_tools[run] = []
|
||
|
for tool in TOOLS:
|
||
|
tool_pattern = '*' + TOOLS[tool]
|
||
|
path = os.path.join(run_dir, tool_pattern)
|
||
|
try:
|
||
|
files = tf.gfile.Glob(path)
|
||
|
if len(files) >= 1:
|
||
|
run_to_tools[run].append(tool)
|
||
|
except tf.errors.OpError as e:
|
||
|
logging.warning("Cannot read asset directory: %s, OpError %s",
|
||
|
run_dir, e)
|
||
|
if 'trace_viewer@' in run_to_tools[run]:
|
||
|
# streaming trace viewer always override normal trace viewer.
|
||
|
# the trailing '@' is to inform tf-profile-dashboard.html and
|
||
|
# tf-trace-viewer.html that stream trace viewer should be used.
|
||
|
removed_tool = 'trace_viewer@' if self.stub is None else 'trace_viewer'
|
||
|
if removed_tool in run_to_tools[run]:
|
||
|
run_to_tools[run].remove(removed_tool)
|
||
|
run_to_tools[run].sort()
|
||
|
op = 'overview_page'
|
||
|
if op in run_to_tools[run]:
|
||
|
# keep overview page at the top of the list
|
||
|
run_to_tools[run].remove(op)
|
||
|
run_to_tools[run].insert(0, op)
|
||
|
return run_to_tools
|
||
|
|
||
|
@wrappers.Request.application
|
||
|
def tools_route(self, request):
|
||
|
run_to_tools = self.index_impl()
|
||
|
|
||
|
return http_util.Respond(request, run_to_tools, 'application/json')
|
||
|
|
||
|
def host_impl(self, run, tool):
|
||
|
"""Returns available hosts for the run and tool in the log directory.
|
||
|
|
||
|
In the plugin log directory, each directory contains profile data for a
|
||
|
single run (identified by the directory name), and files in the run
|
||
|
directory contains data for different tools and hosts. The file that
|
||
|
contains profile for a specific tool "x" will have a prefix name TOOLS["x"].
|
||
|
|
||
|
Example:
|
||
|
log/
|
||
|
run1/
|
||
|
plugins/
|
||
|
profile/
|
||
|
host1.trace
|
||
|
host2.trace
|
||
|
run2/
|
||
|
plugins/
|
||
|
profile/
|
||
|
host1.trace
|
||
|
host2.trace
|
||
|
|
||
|
Returns:
|
||
|
A list of host names e.g.
|
||
|
{"host1", "host2", "host3"} for the example.
|
||
|
"""
|
||
|
hosts = {}
|
||
|
if not tf.gfile.IsDirectory(self.plugin_logdir):
|
||
|
return hosts
|
||
|
run_dir = self._run_dir(run)
|
||
|
if not run_dir:
|
||
|
logging.warning("Cannot find asset directory: %s", run_dir)
|
||
|
return hosts
|
||
|
tool_pattern = '*' + TOOLS[tool]
|
||
|
try:
|
||
|
files = tf.gfile.Glob(os.path.join(run_dir, tool_pattern))
|
||
|
hosts = [os.path.basename(f).replace(TOOLS[tool], '') for f in files]
|
||
|
except tf.errors.OpError as e:
|
||
|
logging.warning("Cannot read asset directory: %s, OpError %s",
|
||
|
run_dir, e)
|
||
|
return hosts
|
||
|
|
||
|
|
||
|
@wrappers.Request.application
|
||
|
def hosts_route(self, request):
|
||
|
run = request.args.get('run')
|
||
|
tool = request.args.get('tag')
|
||
|
hosts = self.host_impl(run, tool)
|
||
|
return http_util.Respond(request, hosts, 'application/json')
|
||
|
|
||
|
def data_impl(self, request):
|
||
|
"""Retrieves and processes the tool data for a run and a host.
|
||
|
|
||
|
Args:
|
||
|
request: XMLHttpRequest
|
||
|
|
||
|
Returns:
|
||
|
A string that can be served to the frontend tool or None if tool,
|
||
|
run or host is invalid.
|
||
|
"""
|
||
|
run = request.args.get('run')
|
||
|
tool = request.args.get('tag')
|
||
|
host = request.args.get('host')
|
||
|
|
||
|
if tool not in TOOLS:
|
||
|
return None
|
||
|
|
||
|
self.start_grpc_stub_if_necessary()
|
||
|
if tool == 'trace_viewer@' and self.stub is not None:
|
||
|
from tensorflow.contrib.tpu.profiler import tpu_profiler_analysis_pb2
|
||
|
grpc_request = tpu_profiler_analysis_pb2.ProfileSessionDataRequest()
|
||
|
grpc_request.repository_root = self.plugin_logdir
|
||
|
grpc_request.session_id = run[:-1]
|
||
|
grpc_request.tool_name = 'trace_viewer'
|
||
|
# Remove the trailing dot if present
|
||
|
grpc_request.host_name = host.rstrip('.')
|
||
|
|
||
|
grpc_request.parameters['resolution'] = request.args.get('resolution')
|
||
|
if request.args.get('start_time_ms') is not None:
|
||
|
grpc_request.parameters['start_time_ms'] = request.args.get(
|
||
|
'start_time_ms')
|
||
|
if request.args.get('end_time_ms') is not None:
|
||
|
grpc_request.parameters['end_time_ms'] = request.args.get('end_time_ms')
|
||
|
grpc_response = self.stub.GetSessionToolData(grpc_request)
|
||
|
return grpc_response.output
|
||
|
|
||
|
if tool not in TOOLS:
|
||
|
return None
|
||
|
tool_name = str(host) + TOOLS[tool]
|
||
|
rel_data_path = os.path.join(run, tool_name)
|
||
|
asset_path = os.path.join(self.plugin_logdir, rel_data_path)
|
||
|
raw_data = None
|
||
|
try:
|
||
|
with tf.gfile.Open(asset_path, 'rb') as f:
|
||
|
raw_data = f.read()
|
||
|
except tf.errors.NotFoundError:
|
||
|
logging.warning('Asset path %s not found', asset_path)
|
||
|
except tf.errors.OpError as e:
|
||
|
logging.warning("Couldn't read asset path: %s, OpError %s", asset_path, e)
|
||
|
|
||
|
if raw_data is None:
|
||
|
return None
|
||
|
if tool == 'trace_viewer':
|
||
|
return process_raw_trace(raw_data)
|
||
|
if tool in _RAW_DATA_TOOLS:
|
||
|
return raw_data
|
||
|
return None
|
||
|
|
||
|
@wrappers.Request.application
|
||
|
def data_route(self, request):
|
||
|
# params
|
||
|
# request: XMLHTTPRequest.
|
||
|
data = self.data_impl(request)
|
||
|
if data is None:
|
||
|
return http_util.Respond(request, '404 Not Found', 'text/plain', code=404)
|
||
|
return http_util.Respond(request, data, 'application/json')
|
||
|
|
||
|
def get_plugin_apps(self):
|
||
|
return {
|
||
|
LOGDIR_ROUTE: self.logdir_route,
|
||
|
TOOLS_ROUTE: self.tools_route,
|
||
|
HOSTS_ROUTE: self.hosts_route,
|
||
|
DATA_ROUTE: self.data_route,
|
||
|
}
|
||
|
|
||
|
def is_active(self):
|
||
|
"""The plugin is active iff any run has at least one active tool/tag."""
|
||
|
return any(self.index_impl().values())
|