laywerrobot/lib/python3.6/site-packages/tensorboard/plugins/profile/profile_plugin.py

339 lines
12 KiB
Python
Raw Normal View History

2020-08-27 21:55:39 +02:00
# -*- coding: utf-8 -*-
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The TensorBoard plugin for performance profiling."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import os
from werkzeug import wrappers
import tensorflow as tf
from tensorboard.backend import http_util
from tensorboard.backend.event_processing import plugin_asset_util
from tensorboard.plugins import base_plugin
from tensorboard.plugins.profile import trace_events_json
from tensorboard.plugins.profile import trace_events_pb2
# Name of this plugin; also used as the prefix of the routes it serves.
PLUGIN_NAME = 'profile'

# HTTP routes, relative to the plugin prefix.
LOGDIR_ROUTE = '/logdir'
DATA_ROUTE = '/data'
TOOLS_ROUTE = '/tools'
HOSTS_ROUTE = '/hosts'

# NOTE(review): appears to be unused within this module - confirm before
# relying on it.
_FILE_NAME = 'TOOL_FILE_NAME'

# Available profiling tools -> file name suffix of the tool data. A file in a
# run directory belongs to tool "x" when its name ends with TOOLS["x"].
TOOLS = {
    'trace_viewer': 'trace',
    # Streaming trace viewer.
    'trace_viewer@': 'tracetable',
    'op_profile': 'op_profile.json',
    'input_pipeline_analyzer': 'input_pipeline.json',
    'overview_page': 'overview_page.json',
    'memory_viewer': 'memory_viewer.json',
    'google_chart_demo': 'google_chart_demo.json',
}

# Tools that consume raw data, i.e. whose files are served without any
# server-side processing.
_RAW_DATA_TOOLS = frozenset((
    'input_pipeline_analyzer',
    'op_profile',
    'overview_page',
    'memory_viewer',
    'google_chart_demo',
))
def process_raw_trace(raw_trace):
  """Converts raw trace data into the string served to the trace viewer UI.

  Args:
    raw_trace: Serialized bytes that are parsed into a
      `trace_events_pb2.Trace` message.

  Returns:
    The concatenation of all chunks yielded by
    `trace_events_json.TraceEventsJsonStream` for the parsed trace.
  """
  parsed_trace = trace_events_pb2.Trace()
  parsed_trace.ParseFromString(raw_trace)
  json_chunks = trace_events_json.TraceEventsJsonStream(parsed_trace)
  return ''.join(json_chunks)
class ProfilePluginLoader(base_plugin.TBLoader):
  """Registers the profile plugin's flags and instantiates the plugin."""

  def define_flags(self, parser):
    """Adds the profile plugin's command line flags to `parser`.

    Args:
      parser: The argument parser to extend; a 'profile plugin' argument
        group is added to it.
    """
    help_text = (
        'IP address of "master tpu", used for getting streaming trace data\n'
        'through tpu profiler analysis grpc. The grpc channel is not secured.')
    profile_flags = parser.add_argument_group('profile plugin')
    profile_flags.add_argument(
        '--master_tpu_unsecure_channel',
        metavar='ADDR',
        type=str,
        default='',
        help=help_text)

  def load(self, context):
    """Returns a ProfilePlugin constructed from `context`.

    Args:
      context: Passed unchanged to the ProfilePlugin constructor.
    """
    return ProfilePlugin(context)
class ProfilePlugin(base_plugin.TBPlugin):
  """Profile Plugin for TensorBoard.

  Serves the profile data found under the plugin's asset directory
  (`self.plugin_logdir`) through the logdir, tools, hosts and data routes.
  """

  plugin_name = PLUGIN_NAME

  def __init__(self, context):
    """Constructs a profiler plugin for TensorBoard.

    This plugin adds handlers for performance-related frontends.

    Args:
      context: A base_plugin.TBContext instance.
    """
    self.logdir = context.logdir
    self.plugin_logdir = plugin_asset_util.PluginDirectory(
        self.logdir, ProfilePlugin.plugin_name)
    # gRPC stub for the streaming trace viewer; created on demand by
    # start_grpc_stub_if_necessary().
    self.stub = None
    self.master_tpu_unsecure_channel = context.flags.master_tpu_unsecure_channel

  @staticmethod
  def _is_safe_path_component(name):
    """Returns True if joining `name` onto a directory cannot escape it.

    Run and host names arrive from the request's query string, so before they
    are joined into a file system path they must not be missing, absolute, or
    contain a '..' component. A trailing slash is accepted.

    Args:
      name: The run or host name to check; may be None.

    Returns:
      True if `name` is a string that stays below the directory it is joined
      onto, False otherwise.
    """
    if name is None or os.path.isabs(name):
      return False
    return '..' not in name.replace('\\', '/').split('/')

  @wrappers.Request.application
  def logdir_route(self, request):
    """Responds with the plugin's asset directory as JSON."""
    return http_util.Respond(request, {'logdir': self.plugin_logdir},
                             'application/json')

  def _run_dir(self, run):
    """Returns the asset directory of `run`.

    Args:
      run: Name of the run, i.e. of a directory below `self.plugin_logdir`.

    Returns:
      The path of the run directory, or None if `run` is missing, unsafe to
      join into a path, or does not name an existing directory.
    """
    if not self._is_safe_path_component(run):
      return None
    run_dir = os.path.join(self.plugin_logdir, run)
    return run_dir if tf.gfile.IsDirectory(run_dir) else None

  def start_grpc_stub_if_necessary(self):
    """Creates the gRPC stub for the streaming trace viewer if it is needed.

    Does nothing when streaming is not enabled or the stub already exists.
    """
    # We will enable streaming trace viewer on two conditions:
    # 1. The user sets the flag master_tpu_unsecure_channel to the IP address
    #    of the "master" TPU. gRPC will be used to fetch streaming trace data.
    # 2. The logdir is on Google Cloud Storage.
    streaming_enabled = (self.master_tpu_unsecure_channel and
                         self.logdir.startswith('gs://'))
    if not streaming_enabled or self.stub is not None:
      return
    # Imported here rather than at module level, so these modules are only
    # loaded once streaming is actually enabled.
    import grpc
    from tensorflow.contrib.tpu.profiler import tpu_profiler_analysis_pb2_grpc  # pylint: disable=line-too-long
    # Workaround the grpc's 4MB message limitation.
    gigabyte = 1024 * 1024 * 1024
    options = [('grpc.max_message_length', gigabyte),
               ('grpc.max_send_message_length', gigabyte),
               ('grpc.max_receive_message_length', gigabyte)]
    tpu_profiler_port = self.master_tpu_unsecure_channel + ':8466'
    channel = grpc.insecure_channel(tpu_profiler_port, options)
    self.stub = tpu_profiler_analysis_pb2_grpc.TPUProfileAnalysisStub(channel)

  def index_impl(self):
    """Returns available runs and available tool data in the log directory.

    In the plugin log directory, each directory contains profile data for a
    single run (identified by the directory name), and files in the run
    directory contain data for different tools. The file that contains the
    profile for a specific tool "x" has a name ending with TOOLS["x"].

    Example:
      <plugin_logdir>/
        run1/
          host1.trace
          host2.trace
        run2/
          host1.trace
          host2.trace

    Returns:
      A map from runs to sorted lists of tool names, with 'overview_page'
      moved to the front when present, e.g.
        {"run1": ["trace_viewer"], "run2": ["trace_viewer"]} for the example.
      Runs without any tool data map to an empty list.
    """
    # TODO(ioeric): use the following structure and use EventMultiplexer so that
    # the plugin still works when logdir is set to root_logdir/run1/
    #     root_logdir/
    #       run1/
    #         plugins/
    #           profile/
    #             host1.trace
    #       run2/
    #         plugins/
    #           profile/
    #             host2.trace
    run_to_tools = {}
    if not tf.gfile.IsDirectory(self.plugin_logdir):
      return run_to_tools

    self.start_grpc_stub_if_necessary()
    for run in tf.gfile.ListDirectory(self.plugin_logdir):
      run_dir = self._run_dir(run)
      if not run_dir:
        continue
      tools = []
      for tool, suffix in TOOLS.items():
        try:
          if tf.gfile.Glob(os.path.join(run_dir, '*' + suffix)):
            tools.append(tool)
        except tf.errors.OpError as e:
          logging.warning("Cannot read asset directory: %s, OpError %s",
                          run_dir, e)
      if 'trace_viewer@' in tools:
        # Streaming trace viewer always overrides the normal trace viewer when
        # the gRPC stub is available; without a stub the streaming entry is
        # dropped instead. The trailing '@' is to inform
        # tf-profile-dashboard.html and tf-trace-viewer.html that the
        # streaming trace viewer should be used.
        removed_tool = 'trace_viewer@' if self.stub is None else 'trace_viewer'
        if removed_tool in tools:
          tools.remove(removed_tool)
      tools.sort()
      op = 'overview_page'
      if op in tools:
        # Keep the overview page at the top of the list.
        tools.remove(op)
        tools.insert(0, op)
      run_to_tools[run] = tools
    return run_to_tools

  @wrappers.Request.application
  def tools_route(self, request):
    """Responds with the run -> tools map of index_impl() as JSON."""
    run_to_tools = self.index_impl()
    return http_util.Respond(request, run_to_tools, 'application/json')

  def host_impl(self, run, tool):
    """Returns available hosts for the run and tool in the log directory.

    In the plugin log directory, each directory contains profile data for a
    single run (identified by the directory name), and files in the run
    directory contain data for different tools and hosts. The file that
    contains the profile for a specific tool "x" has a name ending with
    TOOLS["x"]; whatever precedes that suffix is the host name.

    Example:
      <plugin_logdir>/
        run1/
          host1.trace
          host2.trace
        run2/
          host1.trace
          host2.trace

    Args:
      run: Name of the run directory below the plugin log directory.
      tool: Name of the tool; must be a key of TOOLS.

    Returns:
      A list of host names, e.g. ["host1.", "host2."] for run "run1" and tool
      "trace_viewer" in the example (only the suffix "trace" is removed, so
      the dot stays). The list is empty if the tool is unknown, the run
      directory cannot be found, or it cannot be read.
    """
    hosts = []
    if tool not in TOOLS:
      return hosts
    if not tf.gfile.IsDirectory(self.plugin_logdir):
      return hosts
    run_dir = self._run_dir(run)
    if not run_dir:
      logging.warning("Cannot find asset directory for run: %s", run)
      return hosts
    suffix = TOOLS[tool]
    try:
      files = tf.gfile.Glob(os.path.join(run_dir, '*' + suffix))
    except tf.errors.OpError as e:
      logging.warning("Cannot read asset directory: %s, OpError %s",
                      run_dir, e)
      return hosts
    for path in files:
      file_name = os.path.basename(path)
      # Strip only the trailing suffix: data_impl() rebuilds the file name as
      # host + suffix, so removing other occurrences would break the lookup.
      if file_name.endswith(suffix):
        file_name = file_name[:-len(suffix)]
      hosts.append(file_name)
    return hosts

  @wrappers.Request.application
  def hosts_route(self, request):
    """Responds with the hosts of the requested 'run' and 'tag' as JSON."""
    run = request.args.get('run')
    tool = request.args.get('tag')
    hosts = self.host_impl(run, tool)
    return http_util.Respond(request, hosts, 'application/json')

  def _streaming_trace_data(self, request, run, host):
    """Fetches streaming trace viewer data through the gRPC stub.

    Args:
      request: The HTTP request; its optional 'resolution', 'start_time_ms'
        and 'end_time_ms' arguments are forwarded as request parameters.
      run: Name of the run (not None).
      host: Name of the host (not None).

    Returns:
      The `output` field of the gRPC response.
    """
    from tensorflow.contrib.tpu.profiler import tpu_profiler_analysis_pb2
    grpc_request = tpu_profiler_analysis_pb2.ProfileSessionDataRequest()
    grpc_request.repository_root = self.plugin_logdir
    # NOTE(review): this unconditionally drops the last character of the run
    # name - presumably the trailing '/' of directory names listed on GCS;
    # confirm.
    grpc_request.session_id = run[:-1]
    grpc_request.tool_name = 'trace_viewer'
    # Remove the trailing dot if present
    grpc_request.host_name = host.rstrip('.')
    # Only forward the parameters that were actually supplied; a missing
    # argument is None, which cannot be stored in the parameters map.
    for name in ('resolution', 'start_time_ms', 'end_time_ms'):
      value = request.args.get(name)
      if value is not None:
        grpc_request.parameters[name] = value
    grpc_response = self.stub.GetSessionToolData(grpc_request)
    return grpc_response.output

  def data_impl(self, request):
    """Retrieves and processes the tool data for a run and a host.

    Args:
      request: XMLHttpRequest carrying the 'run', 'tag' (tool) and 'host'
        arguments.

    Returns:
      A string that can be served to the frontend tool or None if tool,
      run or host is invalid (unknown tool, missing argument, or a run/host
      that is absolute or contains a '..' component), or if the data cannot
      be read.
    """
    run = request.args.get('run')
    tool = request.args.get('tag')
    host = request.args.get('host')

    if tool not in TOOLS:
      return None
    # run and host are untrusted and end up in a file system path below.
    if not (self._is_safe_path_component(run) and
            self._is_safe_path_component(host)):
      return None

    self.start_grpc_stub_if_necessary()
    if tool == 'trace_viewer@' and self.stub is not None:
      return self._streaming_trace_data(request, run, host)

    # Nothing else can be served (e.g. 'trace_viewer@' without a stub), so do
    # not bother reading the file.
    if tool != 'trace_viewer' and tool not in _RAW_DATA_TOOLS:
      return None

    asset_path = os.path.join(self.plugin_logdir, run, host + TOOLS[tool])
    raw_data = None
    try:
      with tf.gfile.Open(asset_path, 'rb') as f:
        raw_data = f.read()
    except tf.errors.NotFoundError:
      logging.warning('Asset path %s not found', asset_path)
    except tf.errors.OpError as e:
      logging.warning("Couldn't read asset path: %s, OpError %s", asset_path, e)

    if raw_data is None:
      return None
    if tool == 'trace_viewer':
      return process_raw_trace(raw_data)
    return raw_data

  @wrappers.Request.application
  def data_route(self, request):
    """Responds with the tool data of data_impl(), or 404 if there is none.

    Args:
      request: XMLHTTPRequest.
    """
    data = self.data_impl(request)
    if data is None:
      return http_util.Respond(request, '404 Not Found', 'text/plain', code=404)
    return http_util.Respond(request, data, 'application/json')

  def get_plugin_apps(self):
    """Returns the map from route to the handler serving it."""
    return {
        LOGDIR_ROUTE: self.logdir_route,
        TOOLS_ROUTE: self.tools_route,
        HOSTS_ROUTE: self.hosts_route,
        DATA_ROUTE: self.data_route,
    }

  def is_active(self):
    """The plugin is active iff any run has at least one active tool/tag."""
    return any(self.index_impl().values())