laywerrobot/lib/python3.6/site-packages/tensorflow/contrib/specs/python/summaries.py
2020-08-27 21:55:39 +02:00

313 lines
8.6 KiB
Python

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for summarizing and describing TensorFlow graphs.
This contains functions that generate string descriptions from
TensorFlow graphs, for debugging, testing, and model size
estimation.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import re
from tensorflow.contrib.specs.python import specs
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
# These are short abbreviations for common TensorFlow operations used
# in test cases with tf_structure to verify that specs_lib generates a
# graph structure with the right operations. Operations outside the
# scope of specs (e.g., Const and Placeholder) are just assigned "_"
# since they are not relevant to testing.
SHORT_NAMES_SRC = """
BiasAdd biasadd
Const _
Conv2D conv
MatMul dot
Placeholder _
Sigmoid sig
Variable var
""".split()
SHORT_NAMES = {
x: y
for x, y in zip(SHORT_NAMES_SRC[::2], SHORT_NAMES_SRC[1::2])
}
def _truncate_structure(x):
"""A helper function that disables recursion in tf_structure.
Some constructs (e.g., HorizontalLstm) are complex unrolled
structures and don't need to be represented in the output
of tf_structure or tf_print. This helper function defines
which tree branches should be pruned. This is a very imperfect
way of dealing with unrolled LSTM's (since it truncates
useful information as well), but it's not worth doing something
better until the new fused and unrolled ops are ready.
Args:
x: a Tensor or Op
Returns:
A bool indicating whether the subtree should be pruned.
"""
if "/HorizontalLstm/" in x.name:
return True
return False
def tf_structure(x, include_shapes=False, finished=None):
"""A postfix expression summarizing the TF graph.
This is intended to be used as part of test cases to
check for gross differences in the structure of the graph.
The resulting string is not invertible or unabiguous
and cannot be used to reconstruct the graph accurately.
Args:
x: a tf.Tensor or tf.Operation
include_shapes: include shapes in the output string
finished: a set of ops that have already been output
Returns:
A string representing the structure as a string of
postfix operations.
"""
if finished is None:
finished = set()
if isinstance(x, ops.Tensor):
shape = x.get_shape().as_list()
x = x.op
else:
shape = []
if x in finished:
return " <>"
finished |= {x}
result = ""
if not _truncate_structure(x):
for y in x.inputs:
result += tf_structure(y, include_shapes, finished)
if include_shapes:
result += " %s" % (shape,)
if x.type != "Identity":
name = SHORT_NAMES.get(x.type, x.type.lower())
result += " " + name
return result
def tf_print(x, depth=0, finished=None, printer=print):
"""A simple print function for a TensorFlow graph.
Args:
x: a tf.Tensor or tf.Operation
depth: current printing depth
finished: set of nodes already output
printer: print function to use
Returns:
Total number of parameters found in the
subtree.
"""
if finished is None:
finished = set()
if isinstance(x, ops.Tensor):
shape = x.get_shape().as_list()
x = x.op
else:
shape = ""
if x.type == "Identity":
x = x.inputs[0].op
if x in finished:
printer("%s<%s> %s %s" % (" " * depth, x.name, x.type, shape))
return
finished |= {x}
printer("%s%s %s %s" % (" " * depth, x.name, x.type, shape))
if not _truncate_structure(x):
for y in x.inputs:
tf_print(y, depth + 1, finished, printer=printer)
def tf_num_params(x):
"""Number of parameters in a TensorFlow subgraph.
Args:
x: root of the subgraph (Tensor, Operation)
Returns:
Total number of elements found in all Variables
in the subgraph.
"""
if isinstance(x, ops.Tensor):
shape = x.get_shape()
x = x.op
if x.type in ["Variable", "VariableV2"]:
return shape.num_elements()
totals = [tf_num_params(y) for y in x.inputs]
return sum(totals)
def tf_left_split(op):
"""Split the parameters of op for left recursion.
Args:
op: tf.Operation
Returns:
A tuple of the leftmost input tensor and a list of the
remaining arguments.
"""
if len(op.inputs) < 1:
return None, []
if op.type == "Concat":
return op.inputs[1], op.inputs[2:]
return op.inputs[0], op.inputs[1:]
def tf_parameter_iter(x):
"""Iterate over the left branches of a graph and yield sizes.
Args:
x: root of the subgraph (Tensor, Operation)
Yields:
A triple of name, number of params, and shape.
"""
while 1:
if isinstance(x, ops.Tensor):
shape = x.get_shape().as_list()
x = x.op
else:
shape = ""
left, right = tf_left_split(x)
totals = [tf_num_params(y) for y in right]
total = sum(totals)
yield x.name, total, shape
if left is None:
break
x = left
def _combine_filter(x):
"""A filter for combining successive layers with similar names."""
last_name = None
last_total = 0
last_shape = None
for name, total, shape in x:
name = re.sub("/.*", "", name)
if name == last_name:
last_total += total
continue
if last_name is not None:
yield last_name, last_total, last_shape
last_name = name
last_total = total
last_shape = shape
if last_name is not None:
yield last_name, last_total, last_shape
def tf_parameter_summary(x, printer=print, combine=True):
"""Summarize parameters by depth.
Args:
x: root of the subgraph (Tensor, Operation)
printer: print function for output
combine: combine layers by top-level scope
"""
seq = tf_parameter_iter(x)
if combine:
seq = _combine_filter(seq)
seq = reversed(list(seq))
for name, total, shape in seq:
printer("%10d %-20s %s" % (total, name, shape))
def tf_spec_structure(spec,
inputs=None,
input_shape=None,
input_type=dtypes.float32):
"""Return a postfix representation of the specification.
This is intended to be used as part of test cases to
check for gross differences in the structure of the graph.
The resulting string is not invertible or unabiguous
and cannot be used to reconstruct the graph accurately.
Args:
spec: specification
inputs: input to the spec construction (usually a Tensor)
input_shape: tensor shape (in lieu of inputs)
input_type: type of the input tensor
Returns:
A string with a postfix representation of the
specification.
"""
if inputs is None:
inputs = array_ops.placeholder(input_type, input_shape)
outputs = specs.create_net(spec, inputs)
return str(tf_structure(outputs).strip())
def tf_spec_summary(spec,
inputs=None,
input_shape=None,
input_type=dtypes.float32):
"""Output a summary of the specification.
This prints a list of left-most tensor operations and summarized the
variables found in the right branches. This kind of representation
is particularly useful for networks that are generally structured
like pipelines.
Args:
spec: specification
inputs: input to the spec construction (usually a Tensor)
input_shape: optional shape of input
input_type: type of the input tensor
"""
if inputs is None:
inputs = array_ops.placeholder(input_type, input_shape)
outputs = specs.create_net(spec, inputs)
tf_parameter_summary(outputs)
def tf_spec_print(spec,
inputs=None,
input_shape=None,
input_type=dtypes.float32):
"""Print a tree representing the spec.
Args:
spec: specification
inputs: input to the spec construction (usually a Tensor)
input_shape: optional shape of input
input_type: type of the input tensor
"""
if inputs is None:
inputs = array_ops.placeholder(input_type, input_shape)
outputs = specs.create_net(spec, inputs)
tf_print(outputs)