laywerrobot/lib/python3.6/site-packages/tensorboard/util.py
2020-08-27 21:55:39 +02:00

529 lines
16 KiB
Python

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TensorBoard helper routine module.
This module is a trove of succinct generic helper routines that don't
pull in any heavyweight dependencies aside from TensorFlow.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import locale
import logging
import os
import re
import sys
import threading
import time
import numpy as np
import six
import tensorflow as tf
def setup_logging(streams=(sys.stderr,)):
"""Configures Python logging the way the TensorBoard team likes it.
This should be called exactly once at the beginning of main().
Args:
streams: An iterable of open files. Logs are written to each.
:type streams: tuple[file]
"""
# NOTE: Adding a level parameter to this method would be a bad idea
# because Python and ABSL disagree on the level numbers.
locale.setlocale(locale.LC_ALL, '')
tf.logging.set_verbosity(tf.logging.WARN)
# TODO(jart): Make the default TensorFlow logger behavior great again.
logging.currentframe = _hack_the_main_frame
handlers = [LogHandler(s) for s in streams]
formatter = LogFormatter()
for handler in handlers:
handler.setFormatter(formatter)
tensorflow_logger = logging.getLogger('tensorflow')
tensorflow_logger.handlers = handlers
tensorboard_logger = logging.getLogger('tensorboard')
tensorboard_logger.handlers = handlers
werkzeug_logger = logging.getLogger('werkzeug')
werkzeug_logger.setLevel(logging.WARNING)
werkzeug_logger.handlers = handlers
def closeable(class_):
"""Makes a class with a close method able to be a context manager.
This decorator is a great way to avoid having to choose between the
boilerplate of __enter__ and __exit__ methods, versus the boilerplate
of using contextlib.closing on every with statement.
Args:
class_: The class being decorated.
Raises:
ValueError: If class didn't have a close method, or already
implements __enter__ or __exit__.
"""
if 'close' not in class_.__dict__:
# coffee is for closers
raise ValueError('Class does not define a close() method: %s' % class_)
if '__enter__' in class_.__dict__ or '__exit__' in class_.__dict__:
raise ValueError('Class already defines __enter__ or __exit__: ' + class_)
class_.__enter__ = lambda self: self
class_.__exit__ = lambda self, t, v, b: self.close() and None
return class_
def close_all(resources):
"""Safely closes multiple resources.
The close method on all resources is guaranteed to be called. If
multiple close methods throw exceptions, then the first will be
raised and the rest will be logged.
Args:
resources: An iterable of object instances whose classes implement
the close method.
Raises:
Exception: To rethrow the last exception raised by a close method.
"""
exc_info = None
for resource in resources:
try:
resource.close()
except Exception as e: # pylint: disable=broad-except
if exc_info is not None:
tf.logging.error('Suppressing close(%s) failure: %s', resource, e,
exc_info=exc_info)
exc_info = sys.exc_info()
if exc_info is not None:
six.reraise(*exc_info)
def guarded_by(field):
"""Indicates method should be called from within a lock.
This decorator is purely for documentation purposes. It has the same
semantics as Java's @GuardedBy annotation.
Args:
field: The string name of the lock field, e.g. "_lock".
"""
del field
return lambda method: method
class Retrier(object):
"""Helper class for retrying things with exponential back-off."""
DELAY = 0.1
def __init__(self, is_transient, max_attempts=8, sleep=time.sleep):
"""Creates new instance.
:type is_transient: (Exception) -> bool
:type max_attempts: int
:type sleep: (float) -> None
"""
self._is_transient = is_transient
self._max_attempts = max_attempts
self._sleep = sleep
def run(self, callback):
"""Invokes callback, retrying on transient exceptions.
After the first failure, we wait 100ms, and then double with each
subsequent failed attempt. The default max attempts is 8 which
equates to about thirty seconds of sleeping total.
:type callback: () -> T
:rtype: T
"""
failures = 0
while True:
try:
return callback()
except Exception as e: # pylint: disable=broad-except
failures += 1
if failures == self._max_attempts or not self._is_transient(e):
raise
tf.logging.warn('Retrying on transient %s', e)
self._sleep(2 ** (failures - 1) * Retrier.DELAY)
class LogFormatter(logging.Formatter):
"""Google style log formatter.
The format is in essence the following:
[DIWEF]mmdd hh:mm:ss.uuuuuu thread_name file:line] msg
This class is meant to be used with LogHandler.
"""
DATE_FORMAT = '%m%d %H:%M:%S'
LOG_FORMAT = ('%(levelname)s%(asctime)s %(threadName)s '
'%(filename)s:%(lineno)d] %(message)s')
LEVEL_NAMES = {
logging.FATAL: 'F',
logging.ERROR: 'E',
logging.WARN: 'W',
logging.INFO: 'I',
logging.DEBUG: 'D',
}
def __init__(self):
"""Creates new instance."""
super(LogFormatter, self).__init__(LogFormatter.LOG_FORMAT,
LogFormatter.DATE_FORMAT)
def format(self, record):
"""Formats the log record.
:type record: logging.LogRecord
:rtype: str
"""
record.levelname = LogFormatter.LEVEL_NAMES[record.levelno]
return super(LogFormatter, self).format(record)
def formatTime(self, record, datefmt=None):
"""Return creation time of the specified LogRecord as formatted text.
This override adds microseconds.
:type record: logging.LogRecord
:rtype: str
"""
return (super(LogFormatter, self).formatTime(record, datefmt) +
'.%06d' % (record.created * 1e6 % 1e6))
class Ansi(object):
"""ANSI terminal codes container."""
ESCAPE = '\x1b['
ESCAPE_PATTERN = re.compile(re.escape(ESCAPE) + r'\??(?:\d+)(?:;\d+)*[mlh]')
RESET = ESCAPE + '0m'
BOLD = ESCAPE + '1m'
FLIP = ESCAPE + '7m'
RED = ESCAPE + '31m'
YELLOW = ESCAPE + '33m'
MAGENTA = ESCAPE + '35m'
CURSOR_HIDE = ESCAPE + '?25l'
CURSOR_SHOW = ESCAPE + '?25h'
class LogHandler(logging.StreamHandler):
"""Log handler that supports ANSI colors and ephemeral records.
Colors are applied on a line-by-line basis to non-INFO records. The
goal is to help the user visually distinguish meaningful information,
even when logging is verbose.
This handler will also strip ANSI color codes from emitted log
records automatically when the output stream is not a terminal.
Ephemeral log records are only emitted to a teletype emulator, only
display on the final row, and get overwritten as soon as another
ephemeral record is outputted. Ephemeral records are also sticky. If
a normal record is written then the previous ephemeral record is
restored right beneath it. When an ephemeral record with an empty
message is emitted, then the last ephemeral record turns into a
normal record and is allowed to spool.
This class is thread safe.
"""
EPHEMERAL = '.ephemeral' # Name suffix for ephemeral loggers.
COLORS = {
logging.FATAL: Ansi.BOLD + Ansi.RED,
logging.ERROR: Ansi.RED,
logging.WARN: Ansi.YELLOW,
logging.INFO: '',
logging.DEBUG: Ansi.MAGENTA,
}
def __init__(self, stream, type_='detect'):
"""Creates new instance.
Args:
stream: A file-like object.
type_: If "detect", will call stream.isatty() and perform system
checks to determine if it's safe to output ANSI terminal
codes. If type is "ansi" then this forces the use of ANSI
terminal codes.
Raises:
ValueError: If type is not "detect" or "ansi".
"""
if type_ not in ('detect', 'ansi'):
raise ValueError('type should be detect or ansi')
super(LogHandler, self).__init__(stream)
self._stream = stream
self._disable_flush = False
self._is_tty = (type_ == 'ansi' or
(hasattr(stream, 'isatty') and
stream.isatty() and
os.name != 'nt'))
self._ephemeral = ''
def emit(self, record):
"""Emits a log record.
:type record: logging.LogRecord
"""
self.acquire()
try:
is_ephemeral = record.name.endswith(LogHandler.EPHEMERAL)
color = LogHandler.COLORS.get(record.levelno)
if is_ephemeral:
if self._is_tty:
ephemeral = record.getMessage()
if ephemeral:
if color:
ephemeral = color + ephemeral + Ansi.RESET
self._clear_line()
self._stream.write(ephemeral)
else:
if self._ephemeral:
self._stream.write('\n')
self._ephemeral = ephemeral
else:
self._clear_line()
if self._is_tty and color:
self._stream.write(color)
self._disable_flush = True # prevent double flush
super(LogHandler, self).emit(record)
self._disable_flush = False
if self._is_tty and color:
self._stream.write(Ansi.RESET)
if self._ephemeral:
self._stream.write(self._ephemeral)
self.flush()
finally:
self._disable_flush = False
self.release()
def format(self, record):
"""Turns a log record into a string.
:type record: logging.LogRecord
:rtype: str
"""
message = super(LogHandler, self).format(record)
if not self._is_tty:
message = Ansi.ESCAPE_PATTERN.sub('', message)
return message
def flush(self):
"""Flushes output stream."""
self.acquire()
try:
if not self._disable_flush:
super(LogHandler, self).flush()
finally:
self.release()
def _clear_line(self):
if self._is_tty and self._ephemeral:
# We're counting columns in the terminal, not bytes. So we don't
# want to take UTF-8 or color codes into consideration.
text = Ansi.ESCAPE_PATTERN.sub('', tf.compat.as_text(self._ephemeral))
self._stream.write('\r' + ' ' * len(text) + '\r')
def _hack_the_main_frame():
"""Returns caller frame and skips over tf_logging.
This works around a bug in TensorFlow's open source logging module
where the Python logging module attributes log entries to the
delegate functions in tf_logging.py.
"""
if hasattr(sys, '_getframe'):
frame = sys._getframe(3)
else:
try:
raise Exception
except Exception: # pylint: disable=broad-except
frame = sys.exc_info()[2].tb_frame.f_back
if (frame is not None and
hasattr(frame.f_back, 'f_code') and
'tf_logging.py' in frame.f_back.f_code.co_filename):
return frame.f_back
return frame
class PersistentOpEvaluator(object):
"""Evaluate a fixed TensorFlow graph repeatedly, safely, efficiently.
Extend this class to create a particular kind of op evaluator, like an
image encoder. In `initialize_graph`, create an appropriate TensorFlow
graph with placeholder inputs. In `run`, evaluate this graph and
return its result. This class will manage a singleton graph and
session to preserve memory usage, and will ensure that this graph and
session do not interfere with other concurrent sessions.
A subclass of this class offers a threadsafe, highly parallel Python
entry point for evaluating a particular TensorFlow graph.
Example usage:
class FluxCapacitanceEvaluator(PersistentOpEvaluator):
\"\"\"Compute the flux capacitance required for a system.
Arguments:
x: Available power input, as a `float`, in jigawatts.
Returns:
A `float`, in nanofarads.
\"\"\"
def initialize_graph(self):
self._placeholder = tf.placeholder(some_dtype)
self._op = some_op(self._placeholder)
def run(self, x):
return self._op.eval(feed_dict: {self._placeholder: x})
evaluate_flux_capacitance = FluxCapacitanceEvaluator()
for x in xs:
evaluate_flux_capacitance(x)
"""
def __init__(self):
super(PersistentOpEvaluator, self).__init__()
self._session = None
self._initialization_lock = threading.Lock()
def _lazily_initialize(self):
"""Initialize the graph and session, if this has not yet been done."""
with self._initialization_lock:
if self._session:
return
graph = tf.Graph()
with graph.as_default():
self.initialize_graph()
# Don't reserve GPU because libpng can't run on GPU.
config = tf.ConfigProto(device_count={'GPU': 0})
self._session = tf.Session(graph=graph, config=config)
def initialize_graph(self):
"""Create the TensorFlow graph needed to compute this operation.
This should write ops to the default graph and return `None`.
"""
raise NotImplementedError('Subclasses must implement "initialize_graph".')
def run(self, *args, **kwargs):
"""Evaluate the ops with the given input.
When this function is called, the default session will have the
graph defined by a previous call to `initialize_graph`. This
function should evaluate any ops necessary to compute the result of
the query for the given *args and **kwargs, likely returning the
result of a call to `some_op.eval(...)`.
"""
raise NotImplementedError('Subclasses must implement "run".')
def __call__(self, *args, **kwargs):
self._lazily_initialize()
with self._session.as_default():
return self.run(*args, **kwargs)
class _TensorFlowPngEncoder(PersistentOpEvaluator):
"""Encode an image to PNG.
This function is thread-safe, and has high performance when run in
parallel. See `encode_png_benchmark.py` for details.
Arguments:
image: A numpy array of shape `[height, width, channels]`, where
`channels` is 1, 3, or 4, and of dtype uint8.
Returns:
A bytestring with PNG-encoded data.
"""
def __init__(self):
super(_TensorFlowPngEncoder, self).__init__()
self._image_placeholder = None
self._encode_op = None
def initialize_graph(self):
self._image_placeholder = tf.placeholder(
dtype=tf.uint8, name='image_to_encode')
self._encode_op = tf.image.encode_png(self._image_placeholder)
def run(self, image): # pylint: disable=arguments-differ
if not isinstance(image, np.ndarray):
raise ValueError("'image' must be a numpy array: %r" % image)
if image.dtype != np.uint8:
raise ValueError("'image' dtype must be uint8, but is %r" % image.dtype)
return self._encode_op.eval(feed_dict={self._image_placeholder: image})
encode_png = _TensorFlowPngEncoder()
class _TensorFlowWavEncoder(PersistentOpEvaluator):
"""Encode an audio clip to WAV.
This function is thread-safe and exhibits good parallel performance.
Arguments:
audio: A numpy array of shape `[samples, channels]`.
samples_per_second: A positive `int`, in Hz.
Returns:
A bytestring with WAV-encoded data.
"""
def __init__(self):
super(_TensorFlowWavEncoder, self).__init__()
self._audio_placeholder = None
self._samples_per_second_placeholder = None
self._encode_op = None
def initialize_graph(self):
self._audio_placeholder = tf.placeholder(
dtype=tf.float32, name='image_to_encode')
self._samples_per_second_placeholder = tf.placeholder(
dtype=tf.int32, name='samples_per_second')
self._encode_op = tf.contrib.ffmpeg.encode_audio(
self._audio_placeholder,
file_format='wav',
samples_per_second=self._samples_per_second_placeholder)
def run(self, audio, samples_per_second): # pylint: disable=arguments-differ
if not isinstance(audio, np.ndarray):
raise ValueError("'audio' must be a numpy array: %r" % audio)
if not isinstance(samples_per_second, int):
raise ValueError("'samples_per_second' must be an int: %r"
% samples_per_second)
feed_dict = {
self._audio_placeholder: audio,
self._samples_per_second_placeholder: samples_per_second,
}
return self._encode_op.eval(feed_dict=feed_dict)
encode_wav = _TensorFlowWavEncoder()