laywerrobot/lib/python3.6/site-packages/tensorboard/plugins/scalar/scalars_plugin.py
2020-08-27 21:55:39 +02:00

196 lines
6.4 KiB
Python

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The TensorBoard Scalars plugin.
See `http_api.md` in this directory for specifications of the routes for
this plugin.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import csv
import six
from six import StringIO
from werkzeug import wrappers
import numpy as np
import tensorflow as tf
from tensorboard import plugin_util
from tensorboard.backend import http_util
from tensorboard.plugins import base_plugin
from tensorboard.plugins.scalar import metadata
class OutputFormat(object):
  """Output formats accepted by the scalars API's `format` query parameter.

  The members are plain strings, so they compare directly against the raw
  value read from the request.
  """
  CSV = 'csv'
  JSON = 'json'
class ScalarsPlugin(base_plugin.TBPlugin):
  """Scalars Plugin for TensorBoard.

  Data is read from a SQL database when the TensorBoard context supplies a
  `db_connection_provider`; otherwise it is read from the event multiplexer.
  """

  plugin_name = metadata.PLUGIN_NAME

  def __init__(self, context):
    """Instantiates ScalarsPlugin via TensorBoard core.

    Args:
      context: A base_plugin.TBContext instance.
    """
    self._multiplexer = context.multiplexer
    self._db_connection_provider = context.db_connection_provider

  def get_plugin_apps(self):
    """Returns a mapping from route path to the WSGI app serving it."""
    return {
        '/scalars': self.scalars_route,
        '/tags': self.tags_route,
    }

  def is_active(self):
    """The scalars plugin is active iff any run has at least one scalar tag.

    The database is consulted when a connection provider is configured;
    otherwise the multiplexer is asked for this plugin's tags. With neither
    data source available the plugin is inactive.
    """
    if self._db_connection_provider:
      # The plugin is active if one relevant tag can be found in the database.
      db = self._db_connection_provider()
      cursor = db.execute('''
        SELECT
          1
        FROM Tags
        WHERE Tags.plugin_name = ?
        LIMIT 1
      ''', (metadata.PLUGIN_NAME,))
      return bool(list(cursor))
    if not self._multiplexer:
      return False
    return bool(self._multiplexer.PluginRunToTagToContent(metadata.PLUGIN_NAME))

  def index_impl(self):
    """Return {runName: {tagName: {displayName: ..., description: ...}}}.

    On the database path only runs that have at least one scalar tag appear.
    On the multiplexer path every known run appears, possibly with an empty
    tag dict.
    """
    if self._db_connection_provider:
      # Read tags from the database.
      db = self._db_connection_provider()
      cursor = db.execute('''
        SELECT
          Tags.tag_name,
          Tags.display_name,
          Runs.run_name
        FROM Tags
        JOIN Runs
          ON Tags.run_id = Runs.run_id
        WHERE
          Tags.plugin_name = ?
      ''', (metadata.PLUGIN_NAME,))
      result = collections.defaultdict(dict)
      for row in cursor:
        tag_name, display_name, run_name = row
        result[run_name][tag_name] = {
            'displayName': display_name,
            # TODO(chihuahua): Populate the description. Currently, the tags
            # table does not link with the description table.
            'description': '',
        }
      return result

    runs = self._multiplexer.Runs()
    result = {run: {} for run in runs}
    mapping = self._multiplexer.PluginRunToTagToContent(metadata.PLUGIN_NAME)
    for (run, tag_to_content) in six.iteritems(mapping):
      for (tag, content) in six.iteritems(tag_to_content):
        # The parsed result is not used below; the call is kept, presumably
        # to validate the plugin metadata. NOTE(review): confirm it is needed.
        metadata.parse_plugin_metadata(content)
        summary_metadata = self._multiplexer.SummaryMetadata(run, tag)
        result[run][tag] = {'displayName': summary_metadata.display_name,
                            'description': plugin_util.markdown_to_safe_html(
                                summary_metadata.summary_description)}
    return result

  def scalars_impl(self, tag, run, output_format):
    """Result of the form `(body, mime_type)`.

    Args:
      tag: The tag whose scalar series is requested.
      run: The run containing `tag`.
      output_format: `OutputFormat.CSV` for CSV; any other value yields JSON.

    Returns:
      For CSV, a `(csv_text, 'text/csv')` pair whose text starts with a
      `Wall time,Step,Value` header row. Otherwise a
      `(values, 'application/json')` pair, where `values` is a list of
      `(wall_time, step, value)` tuples.
    """
    if self._db_connection_provider:
      db = self._db_connection_provider()
      # We select for steps greater than -1 because the writer inserts
      # placeholder rows en masse. The check for step filters out those rows.
      cursor = db.execute('''
        SELECT
          Tensors.step,
          Tensors.computed_time,
          Tensors.data,
          Tensors.dtype
        FROM Tensors
        JOIN Tags
          ON Tensors.series = Tags.tag_id
        JOIN Runs
          ON Tags.run_id = Runs.run_id
        WHERE
          Runs.run_name = ?
          AND Tags.tag_name = ?
          AND Tags.plugin_name = ?
          AND Tensors.shape = ''
          AND Tensors.step > -1
        ORDER BY Tensors.step
      ''', (run, tag, metadata.PLUGIN_NAME))
      # Rows come back as (step, time, ...); reorder to (wall_time, step, value).
      values = [(wall_time, step, self._get_value(data, dtype_enum))
                for (step, wall_time, data, dtype_enum) in cursor]
    else:
      tensor_events = self._multiplexer.Tensors(run, tag)
      values = [(tensor_event.wall_time,
                 tensor_event.step,
                 tf.make_ndarray(tensor_event.tensor_proto).item())
                for tensor_event in tensor_events]

    if output_format == OutputFormat.CSV:
      string_io = StringIO()
      writer = csv.writer(string_io)
      writer.writerow(['Wall time', 'Step', 'Value'])
      writer.writerows(values)
      return (string_io.getvalue(), 'text/csv')
    return (values, 'application/json')

  def _get_value(self, scalar_data_blob, dtype_enum):
    """Obtains value for scalar event given blob and dtype enum.

    Args:
      scalar_data_blob: The raw tensor bytes obtained from the database.
      dtype_enum: The TensorFlow dtype enum value of the stored tensor.

    Returns:
      The scalar value as a native Python number.

    Raises:
      ValueError: If the blob does not decode to exactly one element.
    """
    tensorflow_dtype = tf.DType(dtype_enum)
    buf = np.frombuffer(scalar_data_blob, dtype=tensorflow_dtype.as_numpy_dtype)
    # `np.asscalar` (deprecated in NumPy 1.16, removed in 1.23) was simply
    # `a.item()`; call that directly so this works on current NumPy.
    return buf.item()

  @wrappers.Request.application
  def tags_route(self, request):
    """Responds with the JSON run -> tag -> metadata index."""
    index = self.index_impl()
    return http_util.Respond(request, index, 'application/json')

  @wrappers.Request.application
  def scalars_route(self, request):
    """Given a tag and single run, return array of ScalarEvents.

    Query parameters:
      tag: Required. The tag whose scalars to return.
      run: Required. The run containing the tag.
      format: Optional. `csv` for CSV output; anything else yields JSON.

    Responds with HTTP 400 if `tag` or `run` is missing.
    """
    tag = request.args.get('tag')
    run = request.args.get('run')
    if tag is None or run is None:
      # Without this check a missing parameter either fails inside the
      # multiplexer lookup (HTTP 500) or silently yields an empty series.
      return http_util.Respond(
          request, 'query parameters "tag" and "run" are required',
          'text/plain', code=400)
    output_format = request.args.get('format')
    (body, mime_type) = self.scalars_impl(tag, run, output_format)
    return http_util.Respond(request, body, mime_type)