530 lines
16 KiB
Python
530 lines
16 KiB
Python
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
|
||
|
"""TensorBoard helper routine module.
|
||
|
|
||
|
This module is a trove of succinct generic helper routines that don't
|
||
|
pull in any heavyweight dependencies aside from TensorFlow.
|
||
|
"""
|
||
|
|
||
|
from __future__ import absolute_import
|
||
|
from __future__ import division
|
||
|
from __future__ import print_function
|
||
|
|
||
|
import locale
|
||
|
import logging
|
||
|
import os
|
||
|
import re
|
||
|
import sys
|
||
|
import threading
|
||
|
import time
|
||
|
|
||
|
import numpy as np
|
||
|
import six
|
||
|
import tensorflow as tf
|
||
|
|
||
|
|
||
|
def setup_logging(streams=(sys.stderr,)):
|
||
|
"""Configures Python logging the way the TensorBoard team likes it.
|
||
|
|
||
|
This should be called exactly once at the beginning of main().
|
||
|
|
||
|
Args:
|
||
|
streams: An iterable of open files. Logs are written to each.
|
||
|
|
||
|
:type streams: tuple[file]
|
||
|
"""
|
||
|
# NOTE: Adding a level parameter to this method would be a bad idea
|
||
|
# because Python and ABSL disagree on the level numbers.
|
||
|
locale.setlocale(locale.LC_ALL, '')
|
||
|
tf.logging.set_verbosity(tf.logging.WARN)
|
||
|
# TODO(jart): Make the default TensorFlow logger behavior great again.
|
||
|
logging.currentframe = _hack_the_main_frame
|
||
|
handlers = [LogHandler(s) for s in streams]
|
||
|
formatter = LogFormatter()
|
||
|
for handler in handlers:
|
||
|
handler.setFormatter(formatter)
|
||
|
tensorflow_logger = logging.getLogger('tensorflow')
|
||
|
tensorflow_logger.handlers = handlers
|
||
|
tensorboard_logger = logging.getLogger('tensorboard')
|
||
|
tensorboard_logger.handlers = handlers
|
||
|
werkzeug_logger = logging.getLogger('werkzeug')
|
||
|
werkzeug_logger.setLevel(logging.WARNING)
|
||
|
werkzeug_logger.handlers = handlers
|
||
|
|
||
|
|
||
|
def closeable(class_):
|
||
|
"""Makes a class with a close method able to be a context manager.
|
||
|
|
||
|
This decorator is a great way to avoid having to choose between the
|
||
|
boilerplate of __enter__ and __exit__ methods, versus the boilerplate
|
||
|
of using contextlib.closing on every with statement.
|
||
|
|
||
|
Args:
|
||
|
class_: The class being decorated.
|
||
|
|
||
|
Raises:
|
||
|
ValueError: If class didn't have a close method, or already
|
||
|
implements __enter__ or __exit__.
|
||
|
"""
|
||
|
if 'close' not in class_.__dict__:
|
||
|
# coffee is for closers
|
||
|
raise ValueError('Class does not define a close() method: %s' % class_)
|
||
|
if '__enter__' in class_.__dict__ or '__exit__' in class_.__dict__:
|
||
|
raise ValueError('Class already defines __enter__ or __exit__: ' + class_)
|
||
|
class_.__enter__ = lambda self: self
|
||
|
class_.__exit__ = lambda self, t, v, b: self.close() and None
|
||
|
return class_
|
||
|
|
||
|
|
||
|
def close_all(resources):
|
||
|
"""Safely closes multiple resources.
|
||
|
|
||
|
The close method on all resources is guaranteed to be called. If
|
||
|
multiple close methods throw exceptions, then the first will be
|
||
|
raised and the rest will be logged.
|
||
|
|
||
|
Args:
|
||
|
resources: An iterable of object instances whose classes implement
|
||
|
the close method.
|
||
|
|
||
|
Raises:
|
||
|
Exception: To rethrow the last exception raised by a close method.
|
||
|
"""
|
||
|
exc_info = None
|
||
|
for resource in resources:
|
||
|
try:
|
||
|
resource.close()
|
||
|
except Exception as e: # pylint: disable=broad-except
|
||
|
if exc_info is not None:
|
||
|
tf.logging.error('Suppressing close(%s) failure: %s', resource, e,
|
||
|
exc_info=exc_info)
|
||
|
exc_info = sys.exc_info()
|
||
|
if exc_info is not None:
|
||
|
six.reraise(*exc_info)
|
||
|
|
||
|
|
||
|
def guarded_by(field):
|
||
|
"""Indicates method should be called from within a lock.
|
||
|
|
||
|
This decorator is purely for documentation purposes. It has the same
|
||
|
semantics as Java's @GuardedBy annotation.
|
||
|
|
||
|
Args:
|
||
|
field: The string name of the lock field, e.g. "_lock".
|
||
|
"""
|
||
|
del field
|
||
|
return lambda method: method
|
||
|
|
||
|
|
||
|
class Retrier(object):
|
||
|
"""Helper class for retrying things with exponential back-off."""
|
||
|
|
||
|
DELAY = 0.1
|
||
|
|
||
|
def __init__(self, is_transient, max_attempts=8, sleep=time.sleep):
|
||
|
"""Creates new instance.
|
||
|
|
||
|
:type is_transient: (Exception) -> bool
|
||
|
:type max_attempts: int
|
||
|
:type sleep: (float) -> None
|
||
|
"""
|
||
|
self._is_transient = is_transient
|
||
|
self._max_attempts = max_attempts
|
||
|
self._sleep = sleep
|
||
|
|
||
|
def run(self, callback):
|
||
|
"""Invokes callback, retrying on transient exceptions.
|
||
|
|
||
|
After the first failure, we wait 100ms, and then double with each
|
||
|
subsequent failed attempt. The default max attempts is 8 which
|
||
|
equates to about thirty seconds of sleeping total.
|
||
|
|
||
|
:type callback: () -> T
|
||
|
:rtype: T
|
||
|
"""
|
||
|
failures = 0
|
||
|
while True:
|
||
|
try:
|
||
|
return callback()
|
||
|
except Exception as e: # pylint: disable=broad-except
|
||
|
failures += 1
|
||
|
if failures == self._max_attempts or not self._is_transient(e):
|
||
|
raise
|
||
|
tf.logging.warn('Retrying on transient %s', e)
|
||
|
self._sleep(2 ** (failures - 1) * Retrier.DELAY)
|
||
|
|
||
|
|
||
|
class LogFormatter(logging.Formatter):
|
||
|
"""Google style log formatter.
|
||
|
|
||
|
The format is in essence the following:
|
||
|
|
||
|
[DIWEF]mmdd hh:mm:ss.uuuuuu thread_name file:line] msg
|
||
|
|
||
|
This class is meant to be used with LogHandler.
|
||
|
"""
|
||
|
|
||
|
DATE_FORMAT = '%m%d %H:%M:%S'
|
||
|
LOG_FORMAT = ('%(levelname)s%(asctime)s %(threadName)s '
|
||
|
'%(filename)s:%(lineno)d] %(message)s')
|
||
|
|
||
|
LEVEL_NAMES = {
|
||
|
logging.FATAL: 'F',
|
||
|
logging.ERROR: 'E',
|
||
|
logging.WARN: 'W',
|
||
|
logging.INFO: 'I',
|
||
|
logging.DEBUG: 'D',
|
||
|
}
|
||
|
|
||
|
def __init__(self):
|
||
|
"""Creates new instance."""
|
||
|
super(LogFormatter, self).__init__(LogFormatter.LOG_FORMAT,
|
||
|
LogFormatter.DATE_FORMAT)
|
||
|
|
||
|
def format(self, record):
|
||
|
"""Formats the log record.
|
||
|
|
||
|
:type record: logging.LogRecord
|
||
|
:rtype: str
|
||
|
"""
|
||
|
record.levelname = LogFormatter.LEVEL_NAMES[record.levelno]
|
||
|
return super(LogFormatter, self).format(record)
|
||
|
|
||
|
def formatTime(self, record, datefmt=None):
|
||
|
"""Return creation time of the specified LogRecord as formatted text.
|
||
|
|
||
|
This override adds microseconds.
|
||
|
|
||
|
:type record: logging.LogRecord
|
||
|
:rtype: str
|
||
|
"""
|
||
|
return (super(LogFormatter, self).formatTime(record, datefmt) +
|
||
|
'.%06d' % (record.created * 1e6 % 1e6))
|
||
|
|
||
|
|
||
|
class Ansi(object):
|
||
|
"""ANSI terminal codes container."""
|
||
|
|
||
|
ESCAPE = '\x1b['
|
||
|
ESCAPE_PATTERN = re.compile(re.escape(ESCAPE) + r'\??(?:\d+)(?:;\d+)*[mlh]')
|
||
|
RESET = ESCAPE + '0m'
|
||
|
BOLD = ESCAPE + '1m'
|
||
|
FLIP = ESCAPE + '7m'
|
||
|
RED = ESCAPE + '31m'
|
||
|
YELLOW = ESCAPE + '33m'
|
||
|
MAGENTA = ESCAPE + '35m'
|
||
|
CURSOR_HIDE = ESCAPE + '?25l'
|
||
|
CURSOR_SHOW = ESCAPE + '?25h'
|
||
|
|
||
|
|
||
|
class LogHandler(logging.StreamHandler):
|
||
|
"""Log handler that supports ANSI colors and ephemeral records.
|
||
|
|
||
|
Colors are applied on a line-by-line basis to non-INFO records. The
|
||
|
goal is to help the user visually distinguish meaningful information,
|
||
|
even when logging is verbose.
|
||
|
|
||
|
This handler will also strip ANSI color codes from emitted log
|
||
|
records automatically when the output stream is not a terminal.
|
||
|
|
||
|
Ephemeral log records are only emitted to a teletype emulator, only
|
||
|
display on the final row, and get overwritten as soon as another
|
||
|
ephemeral record is outputted. Ephemeral records are also sticky. If
|
||
|
a normal record is written then the previous ephemeral record is
|
||
|
restored right beneath it. When an ephemeral record with an empty
|
||
|
message is emitted, then the last ephemeral record turns into a
|
||
|
normal record and is allowed to spool.
|
||
|
|
||
|
This class is thread safe.
|
||
|
"""
|
||
|
|
||
|
EPHEMERAL = '.ephemeral' # Name suffix for ephemeral loggers.
|
||
|
|
||
|
COLORS = {
|
||
|
logging.FATAL: Ansi.BOLD + Ansi.RED,
|
||
|
logging.ERROR: Ansi.RED,
|
||
|
logging.WARN: Ansi.YELLOW,
|
||
|
logging.INFO: '',
|
||
|
logging.DEBUG: Ansi.MAGENTA,
|
||
|
}
|
||
|
|
||
|
def __init__(self, stream, type_='detect'):
|
||
|
"""Creates new instance.
|
||
|
|
||
|
Args:
|
||
|
stream: A file-like object.
|
||
|
type_: If "detect", will call stream.isatty() and perform system
|
||
|
checks to determine if it's safe to output ANSI terminal
|
||
|
codes. If type is "ansi" then this forces the use of ANSI
|
||
|
terminal codes.
|
||
|
|
||
|
Raises:
|
||
|
ValueError: If type is not "detect" or "ansi".
|
||
|
"""
|
||
|
if type_ not in ('detect', 'ansi'):
|
||
|
raise ValueError('type should be detect or ansi')
|
||
|
super(LogHandler, self).__init__(stream)
|
||
|
self._stream = stream
|
||
|
self._disable_flush = False
|
||
|
self._is_tty = (type_ == 'ansi' or
|
||
|
(hasattr(stream, 'isatty') and
|
||
|
stream.isatty() and
|
||
|
os.name != 'nt'))
|
||
|
self._ephemeral = ''
|
||
|
|
||
|
def emit(self, record):
|
||
|
"""Emits a log record.
|
||
|
|
||
|
:type record: logging.LogRecord
|
||
|
"""
|
||
|
self.acquire()
|
||
|
try:
|
||
|
is_ephemeral = record.name.endswith(LogHandler.EPHEMERAL)
|
||
|
color = LogHandler.COLORS.get(record.levelno)
|
||
|
if is_ephemeral:
|
||
|
if self._is_tty:
|
||
|
ephemeral = record.getMessage()
|
||
|
if ephemeral:
|
||
|
if color:
|
||
|
ephemeral = color + ephemeral + Ansi.RESET
|
||
|
self._clear_line()
|
||
|
self._stream.write(ephemeral)
|
||
|
else:
|
||
|
if self._ephemeral:
|
||
|
self._stream.write('\n')
|
||
|
self._ephemeral = ephemeral
|
||
|
else:
|
||
|
self._clear_line()
|
||
|
if self._is_tty and color:
|
||
|
self._stream.write(color)
|
||
|
self._disable_flush = True # prevent double flush
|
||
|
super(LogHandler, self).emit(record)
|
||
|
self._disable_flush = False
|
||
|
if self._is_tty and color:
|
||
|
self._stream.write(Ansi.RESET)
|
||
|
if self._ephemeral:
|
||
|
self._stream.write(self._ephemeral)
|
||
|
self.flush()
|
||
|
finally:
|
||
|
self._disable_flush = False
|
||
|
self.release()
|
||
|
|
||
|
def format(self, record):
|
||
|
"""Turns a log record into a string.
|
||
|
|
||
|
:type record: logging.LogRecord
|
||
|
:rtype: str
|
||
|
"""
|
||
|
message = super(LogHandler, self).format(record)
|
||
|
if not self._is_tty:
|
||
|
message = Ansi.ESCAPE_PATTERN.sub('', message)
|
||
|
return message
|
||
|
|
||
|
def flush(self):
|
||
|
"""Flushes output stream."""
|
||
|
self.acquire()
|
||
|
try:
|
||
|
if not self._disable_flush:
|
||
|
super(LogHandler, self).flush()
|
||
|
finally:
|
||
|
self.release()
|
||
|
|
||
|
def _clear_line(self):
|
||
|
if self._is_tty and self._ephemeral:
|
||
|
# We're counting columns in the terminal, not bytes. So we don't
|
||
|
# want to take UTF-8 or color codes into consideration.
|
||
|
text = Ansi.ESCAPE_PATTERN.sub('', tf.compat.as_text(self._ephemeral))
|
||
|
self._stream.write('\r' + ' ' * len(text) + '\r')
|
||
|
|
||
|
|
||
|
def _hack_the_main_frame():
|
||
|
"""Returns caller frame and skips over tf_logging.
|
||
|
|
||
|
This works around a bug in TensorFlow's open source logging module
|
||
|
where the Python logging module attributes log entries to the
|
||
|
delegate functions in tf_logging.py.
|
||
|
"""
|
||
|
if hasattr(sys, '_getframe'):
|
||
|
frame = sys._getframe(3)
|
||
|
else:
|
||
|
try:
|
||
|
raise Exception
|
||
|
except Exception: # pylint: disable=broad-except
|
||
|
frame = sys.exc_info()[2].tb_frame.f_back
|
||
|
if (frame is not None and
|
||
|
hasattr(frame.f_back, 'f_code') and
|
||
|
'tf_logging.py' in frame.f_back.f_code.co_filename):
|
||
|
return frame.f_back
|
||
|
return frame
|
||
|
|
||
|
|
||
|
class PersistentOpEvaluator(object):
|
||
|
"""Evaluate a fixed TensorFlow graph repeatedly, safely, efficiently.
|
||
|
|
||
|
Extend this class to create a particular kind of op evaluator, like an
|
||
|
image encoder. In `initialize_graph`, create an appropriate TensorFlow
|
||
|
graph with placeholder inputs. In `run`, evaluate this graph and
|
||
|
return its result. This class will manage a singleton graph and
|
||
|
session to preserve memory usage, and will ensure that this graph and
|
||
|
session do not interfere with other concurrent sessions.
|
||
|
|
||
|
A subclass of this class offers a threadsafe, highly parallel Python
|
||
|
entry point for evaluating a particular TensorFlow graph.
|
||
|
|
||
|
Example usage:
|
||
|
|
||
|
class FluxCapacitanceEvaluator(PersistentOpEvaluator):
|
||
|
\"\"\"Compute the flux capacitance required for a system.
|
||
|
|
||
|
Arguments:
|
||
|
x: Available power input, as a `float`, in jigawatts.
|
||
|
|
||
|
Returns:
|
||
|
A `float`, in nanofarads.
|
||
|
\"\"\"
|
||
|
|
||
|
def initialize_graph(self):
|
||
|
self._placeholder = tf.placeholder(some_dtype)
|
||
|
self._op = some_op(self._placeholder)
|
||
|
|
||
|
def run(self, x):
|
||
|
return self._op.eval(feed_dict: {self._placeholder: x})
|
||
|
|
||
|
evaluate_flux_capacitance = FluxCapacitanceEvaluator()
|
||
|
|
||
|
for x in xs:
|
||
|
evaluate_flux_capacitance(x)
|
||
|
"""
|
||
|
|
||
|
def __init__(self):
|
||
|
super(PersistentOpEvaluator, self).__init__()
|
||
|
self._session = None
|
||
|
self._initialization_lock = threading.Lock()
|
||
|
|
||
|
def _lazily_initialize(self):
|
||
|
"""Initialize the graph and session, if this has not yet been done."""
|
||
|
with self._initialization_lock:
|
||
|
if self._session:
|
||
|
return
|
||
|
graph = tf.Graph()
|
||
|
with graph.as_default():
|
||
|
self.initialize_graph()
|
||
|
# Don't reserve GPU because libpng can't run on GPU.
|
||
|
config = tf.ConfigProto(device_count={'GPU': 0})
|
||
|
self._session = tf.Session(graph=graph, config=config)
|
||
|
|
||
|
def initialize_graph(self):
|
||
|
"""Create the TensorFlow graph needed to compute this operation.
|
||
|
|
||
|
This should write ops to the default graph and return `None`.
|
||
|
"""
|
||
|
raise NotImplementedError('Subclasses must implement "initialize_graph".')
|
||
|
|
||
|
def run(self, *args, **kwargs):
|
||
|
"""Evaluate the ops with the given input.
|
||
|
|
||
|
When this function is called, the default session will have the
|
||
|
graph defined by a previous call to `initialize_graph`. This
|
||
|
function should evaluate any ops necessary to compute the result of
|
||
|
the query for the given *args and **kwargs, likely returning the
|
||
|
result of a call to `some_op.eval(...)`.
|
||
|
"""
|
||
|
raise NotImplementedError('Subclasses must implement "run".')
|
||
|
|
||
|
def __call__(self, *args, **kwargs):
|
||
|
self._lazily_initialize()
|
||
|
with self._session.as_default():
|
||
|
return self.run(*args, **kwargs)
|
||
|
|
||
|
|
||
|
class _TensorFlowPngEncoder(PersistentOpEvaluator):
|
||
|
"""Encode an image to PNG.
|
||
|
|
||
|
This function is thread-safe, and has high performance when run in
|
||
|
parallel. See `encode_png_benchmark.py` for details.
|
||
|
|
||
|
Arguments:
|
||
|
image: A numpy array of shape `[height, width, channels]`, where
|
||
|
`channels` is 1, 3, or 4, and of dtype uint8.
|
||
|
|
||
|
Returns:
|
||
|
A bytestring with PNG-encoded data.
|
||
|
"""
|
||
|
|
||
|
def __init__(self):
|
||
|
super(_TensorFlowPngEncoder, self).__init__()
|
||
|
self._image_placeholder = None
|
||
|
self._encode_op = None
|
||
|
|
||
|
def initialize_graph(self):
|
||
|
self._image_placeholder = tf.placeholder(
|
||
|
dtype=tf.uint8, name='image_to_encode')
|
||
|
self._encode_op = tf.image.encode_png(self._image_placeholder)
|
||
|
|
||
|
def run(self, image): # pylint: disable=arguments-differ
|
||
|
if not isinstance(image, np.ndarray):
|
||
|
raise ValueError("'image' must be a numpy array: %r" % image)
|
||
|
if image.dtype != np.uint8:
|
||
|
raise ValueError("'image' dtype must be uint8, but is %r" % image.dtype)
|
||
|
return self._encode_op.eval(feed_dict={self._image_placeholder: image})
|
||
|
|
||
|
|
||
|
encode_png = _TensorFlowPngEncoder()
|
||
|
|
||
|
|
||
|
class _TensorFlowWavEncoder(PersistentOpEvaluator):
|
||
|
"""Encode an audio clip to WAV.
|
||
|
|
||
|
This function is thread-safe and exhibits good parallel performance.
|
||
|
|
||
|
Arguments:
|
||
|
audio: A numpy array of shape `[samples, channels]`.
|
||
|
samples_per_second: A positive `int`, in Hz.
|
||
|
|
||
|
Returns:
|
||
|
A bytestring with WAV-encoded data.
|
||
|
"""
|
||
|
|
||
|
def __init__(self):
|
||
|
super(_TensorFlowWavEncoder, self).__init__()
|
||
|
self._audio_placeholder = None
|
||
|
self._samples_per_second_placeholder = None
|
||
|
self._encode_op = None
|
||
|
|
||
|
def initialize_graph(self):
|
||
|
self._audio_placeholder = tf.placeholder(
|
||
|
dtype=tf.float32, name='image_to_encode')
|
||
|
self._samples_per_second_placeholder = tf.placeholder(
|
||
|
dtype=tf.int32, name='samples_per_second')
|
||
|
self._encode_op = tf.contrib.ffmpeg.encode_audio(
|
||
|
self._audio_placeholder,
|
||
|
file_format='wav',
|
||
|
samples_per_second=self._samples_per_second_placeholder)
|
||
|
|
||
|
def run(self, audio, samples_per_second): # pylint: disable=arguments-differ
|
||
|
if not isinstance(audio, np.ndarray):
|
||
|
raise ValueError("'audio' must be a numpy array: %r" % audio)
|
||
|
if not isinstance(samples_per_second, int):
|
||
|
raise ValueError("'samples_per_second' must be an int: %r"
|
||
|
% samples_per_second)
|
||
|
feed_dict = {
|
||
|
self._audio_placeholder: audio,
|
||
|
self._samples_per_second_placeholder: samples_per_second,
|
||
|
}
|
||
|
return self._encode_op.eval(feed_dict=feed_dict)
|
||
|
|
||
|
|
||
|
encode_wav = _TensorFlowWavEncoder()
|