# Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """The TensorBoard Histograms plugin. See `http_api.md` in this directory for specifications of the routes for this plugin. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import collections import random import numpy as np import six import tensorflow as tf from werkzeug import wrappers from tensorboard import plugin_util from tensorboard.backend import http_util from tensorboard.plugins import base_plugin from tensorboard.plugins.histogram import metadata class HistogramsPlugin(base_plugin.TBPlugin): """Histograms Plugin for TensorBoard. This supports both old-style summaries (created with TensorFlow ops that output directly to the `histo` field of the proto) and new-style summaries (as created by the `tensorboard.plugins.histogram.summary` module). """ plugin_name = metadata.PLUGIN_NAME # Use a round number + 1 since sampling includes both start and end steps, # so N+1 samples corresponds to dividing the step sequence into N intervals. SAMPLE_SIZE = 51 def __init__(self, context): """Instantiates HistogramsPlugin via TensorBoard core. Args: context: A base_plugin.TBContext instance. """ self._db_connection_provider = context.db_connection_provider self._multiplexer = context.multiplexer def get_plugin_apps(self): return { '/histograms': self.histograms_route, '/tags': self.tags_route, } def is_active(self): """This plugin is active iff any run has at least one histograms tag.""" if self._db_connection_provider: # The plugin is active if one relevant tag can be found in the database. db = self._db_connection_provider() cursor = db.execute(''' SELECT 1 FROM Tags WHERE Tags.plugin_name = ? LIMIT 1 ''', (metadata.PLUGIN_NAME,)) return bool(list(cursor)) return bool(self._multiplexer) and any(self.index_impl().values()) def index_impl(self): """Return {runName: {tagName: {displayName: ..., description: ...}}}.""" if self._db_connection_provider: # Read tags from the database. db = self._db_connection_provider() cursor = db.execute(''' SELECT Tags.tag_name, Tags.display_name, Runs.run_name FROM Tags JOIN Runs ON Tags.run_id = Runs.run_id WHERE Tags.plugin_name = ? ''', (metadata.PLUGIN_NAME,)) result = collections.defaultdict(dict) for row in cursor: tag_name, display_name, run_name = row result[run_name][tag_name] = { 'displayName': display_name, # TODO(chihuahua): Populate the description. Currently, the tags # table does not link with the description table. 'description': '', } return result runs = self._multiplexer.Runs() result = {run: {} for run in runs} mapping = self._multiplexer.PluginRunToTagToContent(metadata.PLUGIN_NAME) for (run, tag_to_content) in six.iteritems(mapping): for (tag, content) in six.iteritems(tag_to_content): content = metadata.parse_plugin_metadata(content) summary_metadata = self._multiplexer.SummaryMetadata(run, tag) result[run][tag] = {'displayName': summary_metadata.display_name, 'description': plugin_util.markdown_to_safe_html( summary_metadata.summary_description)} return result def histograms_impl(self, tag, run, downsample_to=None): """Result of the form `(body, mime_type)`, or `ValueError`. At most `downsample_to` events will be returned. If this value is `None`, then no downsampling will be performed. """ if self._db_connection_provider: # Serve data from the database. db = self._db_connection_provider() cursor = db.cursor() # Prefetch the tag ID matching this run and tag. cursor.execute( ''' SELECT tag_id FROM Tags JOIN Runs USING (run_id) WHERE Runs.run_name = :run AND Tags.tag_name = :tag AND Tags.plugin_name = :plugin ''', {'run': run, 'tag': tag, 'plugin': metadata.PLUGIN_NAME}) row = cursor.fetchone() if not row: raise ValueError('No histogram tag %r for run %r' % (tag, run)) (tag_id,) = row # Fetch tensor values, optionally with linear-spaced sampling by step. # For steps ranging from s_min to s_max and sample size k, this query # divides the range into k - 1 equal-sized intervals and returns the # lowest step at or above each of the k interval boundaries (which always # includes s_min and s_max, and may be fewer than k results if there are # intervals where no steps are present). For contiguous steps the results # can be formally expressed as the following: # [s_min + math.ceil(i / k * (s_max - s_min)) for i in range(0, k + 1)] cursor.execute( ''' SELECT MIN(step) AS step, computed_time, data, dtype, shape FROM Tensors INNER JOIN ( SELECT MIN(step) AS min_step, MAX(step) AS max_step FROM Tensors /* Filter out NULL so we can use TensorSeriesStepIndex. */ WHERE series = :tag_id AND step IS NOT NULL ) /* Ensure we omit reserved rows, which have NULL step values. */ WHERE series = :tag_id AND step IS NOT NULL /* Bucket rows into sample_size linearly spaced buckets, or do no sampling if sample_size is NULL. */ GROUP BY IFNULL(:sample_size - 1, max_step - min_step) * (step - min_step) / (max_step - min_step) ORDER BY step ''', {'tag_id': tag_id, 'sample_size': downsample_to}) events = [(computed_time, step, self._get_values(data, dtype, shape)) for step, computed_time, data, dtype, shape in cursor] else: # Serve data from events files. try: tensor_events = self._multiplexer.Tensors(run, tag) except KeyError: raise ValueError('No histogram tag %r for run %r' % (tag, run)) events = [[e.wall_time, e.step, tf.make_ndarray(e.tensor_proto).tolist()] for e in tensor_events] if downsample_to is not None and len(events) > downsample_to: indices = sorted(random.Random(0).sample(list(range(len(events))), downsample_to)) events = [events[i] for i in indices] return (events, 'application/json') def _get_values(self, data_blob, dtype_enum, shape_string): """Obtains values for histogram data given blob and dtype enum. Args: data_blob: The blob obtained from the database. dtype_enum: The enum representing the dtype. shape_string: A comma-separated string of numbers denoting shape. Returns: The histogram values as a list served to the frontend. """ buf = np.frombuffer(data_blob, dtype=tf.DType(dtype_enum).as_numpy_dtype) return buf.reshape([int(i) for i in shape_string.split(',')]).tolist() @wrappers.Request.application def tags_route(self, request): index = self.index_impl() return http_util.Respond(request, index, 'application/json') @wrappers.Request.application def histograms_route(self, request): """Given a tag and single run, return array of histogram values.""" tag = request.args.get('tag') run = request.args.get('run') try: (body, mime_type) = self.histograms_impl( tag, run, downsample_to=self.SAMPLE_SIZE) code = 200 except ValueError as e: (body, mime_type) = (str(e), 'text/plain') code = 400 return http_util.Respond(request, body, mime_type, code=code)