# Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """The TensorBoard Scalars plugin. See `http_api.md` in this directory for specifications of the routes for this plugin. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import collections import csv import six from six import StringIO from werkzeug import wrappers import numpy as np import tensorflow as tf from tensorboard import plugin_util from tensorboard.backend import http_util from tensorboard.plugins import base_plugin from tensorboard.plugins.scalar import metadata class OutputFormat(object): """An enum used to list the valid output formats for API calls.""" JSON = 'json' CSV = 'csv' class ScalarsPlugin(base_plugin.TBPlugin): """Scalars Plugin for TensorBoard.""" plugin_name = metadata.PLUGIN_NAME def __init__(self, context): """Instantiates ScalarsPlugin via TensorBoard core. Args: context: A base_plugin.TBContext instance. """ self._multiplexer = context.multiplexer self._db_connection_provider = context.db_connection_provider def get_plugin_apps(self): return { '/scalars': self.scalars_route, '/tags': self.tags_route, } def is_active(self): """The scalars plugin is active iff any run has at least one scalar tag.""" if self._db_connection_provider: # The plugin is active if one relevant tag can be found in the database. db = self._db_connection_provider() cursor = db.execute(''' SELECT 1 FROM Tags WHERE Tags.plugin_name = ? LIMIT 1 ''', (metadata.PLUGIN_NAME,)) return bool(list(cursor)) if not self._multiplexer: return False return bool(self._multiplexer.PluginRunToTagToContent(metadata.PLUGIN_NAME)) def index_impl(self): """Return {runName: {tagName: {displayName: ..., description: ...}}}.""" if self._db_connection_provider: # Read tags from the database. db = self._db_connection_provider() cursor = db.execute(''' SELECT Tags.tag_name, Tags.display_name, Runs.run_name FROM Tags JOIN Runs ON Tags.run_id = Runs.run_id WHERE Tags.plugin_name = ? ''', (metadata.PLUGIN_NAME,)) result = collections.defaultdict(dict) for row in cursor: tag_name, display_name, run_name = row result[run_name][tag_name] = { 'displayName': display_name, # TODO(chihuahua): Populate the description. Currently, the tags # table does not link with the description table. 'description': '', } return result runs = self._multiplexer.Runs() result = {run: {} for run in runs} mapping = self._multiplexer.PluginRunToTagToContent(metadata.PLUGIN_NAME) for (run, tag_to_content) in six.iteritems(mapping): for (tag, content) in six.iteritems(tag_to_content): content = metadata.parse_plugin_metadata(content) summary_metadata = self._multiplexer.SummaryMetadata(run, tag) result[run][tag] = {'displayName': summary_metadata.display_name, 'description': plugin_util.markdown_to_safe_html( summary_metadata.summary_description)} return result def scalars_impl(self, tag, run, output_format): """Result of the form `(body, mime_type)`.""" if self._db_connection_provider: db = self._db_connection_provider() # We select for steps greater than -1 because the writer inserts # placeholder rows en masse. The check for step filters out those rows. cursor = db.execute(''' SELECT Tensors.step, Tensors.computed_time, Tensors.data, Tensors.dtype FROM Tensors JOIN Tags ON Tensors.series = Tags.tag_id JOIN Runs ON Tags.run_id = Runs.run_id WHERE Runs.run_name = ? AND Tags.tag_name = ? AND Tags.plugin_name = ? AND Tensors.shape = '' AND Tensors.step > -1 ORDER BY Tensors.step ''', (run, tag, metadata.PLUGIN_NAME)) values = [(wall_time, step, self._get_value(data, dtype_enum)) for (step, wall_time, data, dtype_enum) in cursor] else: tensor_events = self._multiplexer.Tensors(run, tag) values = [(tensor_event.wall_time, tensor_event.step, tf.make_ndarray(tensor_event.tensor_proto).item()) for tensor_event in tensor_events] if output_format == OutputFormat.CSV: string_io = StringIO() writer = csv.writer(string_io) writer.writerow(['Wall time', 'Step', 'Value']) writer.writerows(values) return (string_io.getvalue(), 'text/csv') else: return (values, 'application/json') def _get_value(self, scalar_data_blob, dtype_enum): """Obtains value for scalar event given blob and dtype enum. Args: scalar_data_blob: The blob obtained from the database. dtype_enum: The enum representing the dtype. Returns: The scalar value. """ tensorflow_dtype = tf.DType(dtype_enum) buf = np.frombuffer(scalar_data_blob, dtype=tensorflow_dtype.as_numpy_dtype) return np.asscalar(buf) @wrappers.Request.application def tags_route(self, request): index = self.index_impl() return http_util.Respond(request, index, 'application/json') @wrappers.Request.application def scalars_route(self, request): """Given a tag and single run, return array of ScalarEvents.""" # TODO: return HTTP status code for malformed requests tag = request.args.get('tag') run = request.args.get('run') output_format = request.args.get('format') (body, mime_type) = self.scalars_impl(tag, run, output_format) return http_util.Respond(request, body, mime_type)