# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for summarizing and describing TensorFlow graphs.

This contains functions that generate string descriptions from
TensorFlow graphs, for debugging, testing, and model size
estimation.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re
from tensorflow.contrib.specs.python import specs
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops

# These are short abbreviations for common TensorFlow operations used
# in test cases with tf_structure to verify that specs_lib generates a
# graph structure with the right operations. Operations outside the
# scope of specs (e.g., Const and Placeholder) are just assigned "_"
# since they are not relevant to testing.

SHORT_NAMES_SRC = """
BiasAdd biasadd
Const _
Conv2D conv
MatMul dot
Placeholder _
Sigmoid sig
Variable var
""".split()

SHORT_NAMES = {
    x: y
    for x, y in zip(SHORT_NAMES_SRC[::2], SHORT_NAMES_SRC[1::2])
}
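
# Illustration (not in the original file): tf_structure below renders each op
# through this table, falling back to op.type.lower() for anything not listed.
# A conv layer built on a placeholder might therefore appear in a structure
# string roughly as "_ var conv var biasadd sig" -- with "variablev2" in place
# of "var" on TensorFlow versions whose variable ops are typed VariableV2.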


def _truncate_structure(x):
  """A helper function that disables recursion in tf_structure.

  Some constructs (e.g., HorizontalLstm) are complex unrolled
  structures and don't need to be represented in the output
  of tf_structure or tf_print. This helper function defines
  which tree branches should be pruned. This is a very imperfect
  way of dealing with unrolled LSTMs (since it truncates
  useful information as well), but it's not worth doing something
  better until the new fused and unrolled ops are ready.

  Args:
    x: a Tensor or Op

  Returns:
    A bool indicating whether the subtree should be pruned.
  """
  if "/HorizontalLstm/" in x.name:
    return True
  return False


def tf_structure(x, include_shapes=False, finished=None):
  """A postfix expression summarizing the TF graph.

  This is intended to be used as part of test cases to
  check for gross differences in the structure of the graph.
  The resulting string is not invertible or unambiguous
  and cannot be used to reconstruct the graph accurately.

  Args:
    x: a tf.Tensor or tf.Operation
    include_shapes: include shapes in the output string
    finished: a set of ops that have already been output

  Returns:
    A string representing the structure as a string of
    postfix operations.
  """
  if finished is None:
    finished = set()
  if isinstance(x, ops.Tensor):
    shape = x.get_shape().as_list()
    x = x.op
  else:
    shape = []
  if x in finished:
    return " <>"
  finished |= {x}
  result = ""
  if not _truncate_structure(x):
    for y in x.inputs:
      result += tf_structure(y, include_shapes, finished)
  if include_shapes:
    result += " %s" % (shape,)
  if x.type != "Identity":
    name = SHORT_NAMES.get(x.type, x.type.lower())
    result += " " + name
  return result
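
# Usage sketch (added for illustration, not part of the original module).
# Given the output tensor of a small network -- here called `net`, a
# hypothetical tf.Tensor -- the structure string can be checked in a test
# roughly like:
#
#   structure = tf_structure(net).strip()
#   assert "conv" in structure                     # short name from SHORT_NAMES
#   print(tf_structure(net, include_shapes=True))  # appends shape lists
#
# The exact string depends on the graph and the TensorFlow version.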


def tf_print(x, depth=0, finished=None, printer=print):
  """A simple print function for a TensorFlow graph.

  Args:
    x: a tf.Tensor or tf.Operation
    depth: current printing depth
    finished: set of nodes already output
    printer: print function to use
  """

  if finished is None:
    finished = set()
  if isinstance(x, ops.Tensor):
    shape = x.get_shape().as_list()
    x = x.op
  else:
    shape = ""
  if x.type == "Identity":
    x = x.inputs[0].op
  if x in finished:
    printer("%s<%s> %s %s" % (" " * depth, x.name, x.type, shape))
    return
  finished |= {x}
  printer("%s%s %s %s" % (" " * depth, x.name, x.type, shape))
  if not _truncate_structure(x):
    for y in x.inputs:
      tf_print(y, depth + 1, finished, printer=printer)
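
# Usage sketch (illustrative, not part of the original module): tf_print walks
# the graph behind `net` (a hypothetical output tensor) and prints one line per
# op, indented by depth, with <...> marking ops that were already printed:
#
#   tf_print(net)
#   # conv1/Sigmoid Sigmoid [1, 64, 64, 8]          <- hypothetical output
#   #  conv1/BiasAdd BiasAdd [1, 64, 64, 8]
#   #   conv1/Conv2D Conv2D [1, 64, 64, 8]
#   #    ...
#
# Pass a different callable as printer (e.g. a list's append method) to
# capture the lines instead of printing them.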


def tf_num_params(x):
  """Number of parameters in a TensorFlow subgraph.

  Args:
    x: root of the subgraph (Tensor, Operation)

  Returns:
    Total number of elements found in all Variables
    in the subgraph.
  """

  if isinstance(x, ops.Tensor):
    shape = x.get_shape()
    x = x.op
  if x.type in ["Variable", "VariableV2"]:
    # Variables are expected to be reached through their output Tensor, so
    # `shape` has been set in the branch above.
    return shape.num_elements()
  totals = [tf_num_params(y) for y in x.inputs]
  return sum(totals)
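
# Worked example (illustrative): for a conv layer whose subgraph contains a
# [3, 3, 1, 8] weight variable and an [8] bias variable, tf_num_params applied
# to the layer's output tensor reports 3 * 3 * 1 * 8 + 8 = 80. Only
# Variable/VariableV2 ops reachable through the inputs are counted.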


def tf_left_split(op):
  """Split the inputs of op for left recursion.

  Args:
    op: tf.Operation

  Returns:
    A tuple of the leftmost input tensor and a list of the
    remaining arguments.
  """

  if len(op.inputs) < 1:
    return None, []
  if op.type == "Concat":
    # For "Concat" ops, inputs[0] is the concat dimension; skip it.
    return op.inputs[1], op.inputs[2:]
  return op.inputs[0], op.inputs[1:]


def tf_parameter_iter(x):
  """Iterate over the left branches of a graph and yield sizes.

  Args:
    x: root of the subgraph (Tensor, Operation)

  Yields:
    A triple of name, number of params, and shape.
  """

  while True:
    if isinstance(x, ops.Tensor):
      shape = x.get_shape().as_list()
      x = x.op
    else:
      shape = ""
    # Walk down the leftmost input chain, summing the parameters found in
    # the right branches at each step.
    left, right = tf_left_split(x)
    totals = [tf_num_params(y) for y in right]
    total = sum(totals)
    yield x.name, total, shape
    if left is None:
      break
    x = left
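
# Illustrative yields (hypothetical op names and shapes) for a single conv
# layer applied to a placeholder, walking from the output back to the input:
#
#   ("conv1/Sigmoid", 0, [1, 64, 64, 8])    # no variables on right branches
#   ("conv1/BiasAdd", 8, [1, 64, 64, 8])    # bias variable
#   ("conv1/Conv2D", 72, [1, 64, 64, 8])    # 3x3x1x8 weight variable
#   ("Placeholder", 0, [1, 64, 64, 1])
#
# _combine_filter below merges the three "conv1/..." entries into one.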


def _combine_filter(x):
  """A filter for combining successive layers with the same top-level scope."""
  last_name = None
  last_total = 0
  last_shape = None
  for name, total, shape in x:
    # Reduce the name to its top-level scope (everything before the first "/").
    name = re.sub("/.*", "", name)
    if name == last_name:
      last_total += total
      continue
    if last_name is not None:
      yield last_name, last_total, last_shape
    last_name = name
    last_total = total
    last_shape = shape
  if last_name is not None:
    yield last_name, last_total, last_shape
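
# Illustrative behavior (hypothetical input): the per-op triples shown above
# for tf_parameter_iter,
#
#   ("conv1/Sigmoid", 0, [1, 64, 64, 8]), ("conv1/BiasAdd", 8, [1, 64, 64, 8]),
#   ("conv1/Conv2D", 72, [1, 64, 64, 8]), ("Placeholder", 0, [1, 64, 64, 1])
#
# are reduced to their top-level scope and merged into
#
#   ("conv1", 80, [1, 64, 64, 8]), ("Placeholder", 0, [1, 64, 64, 1])
#
# where the shape kept for each run is that of its first entry.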


def tf_parameter_summary(x, printer=print, combine=True):
  """Summarize parameters by depth.

  Args:
    x: root of the subgraph (Tensor, Operation)
    printer: print function for output
    combine: combine layers by top-level scope
  """
  seq = tf_parameter_iter(x)
  if combine:
    seq = _combine_filter(seq)
  seq = reversed(list(seq))
  for name, total, shape in seq:
    printer("%10d %-20s %s" % (total, name, shape))
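
# Example output (hypothetical values): with combine=True, the graph sketched
# above would be summarized input-first as
#
#            0 Placeholder          [1, 64, 64, 1]
#           80 conv1                [1, 64, 64, 8]
#
# one "%10d %-20s %s" line per pipeline stage: parameter count, top-level
# scope, shape.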


def tf_spec_structure(spec,
                      inputs=None,
                      input_shape=None,
                      input_type=dtypes.float32):
  """Return a postfix representation of the specification.

  This is intended to be used as part of test cases to
  check for gross differences in the structure of the graph.
  The resulting string is not invertible or unambiguous
  and cannot be used to reconstruct the graph accurately.

  Args:
    spec: specification
    inputs: input to the spec construction (usually a Tensor)
    input_shape: tensor shape (in lieu of inputs)
    input_type: type of the input tensor

  Returns:
    A string with a postfix representation of the
    specification.
  """

  if inputs is None:
    inputs = array_ops.placeholder(input_type, input_shape)
  outputs = specs.create_net(spec, inputs)
  return str(tf_structure(outputs).strip())
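
# Usage sketch (illustrative; the spec string is an assumed example of the
# specs mini-language, and the exact output depends on the TensorFlow version):
#
#   structure = tf_spec_structure("net = Cr(8, [3, 3])",
#                                 input_shape=(1, 64, 64, 1))
#   # might look roughly like "_ var conv var biasadd relu"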


def tf_spec_summary(spec,
                    inputs=None,
                    input_shape=None,
                    input_type=dtypes.float32):
  """Output a summary of the specification.

  This prints a list of left-most tensor operations and summarizes the
  variables found in the right branches. This kind of representation
  is particularly useful for networks that are generally structured
  like pipelines.

  Args:
    spec: specification
    inputs: input to the spec construction (usually a Tensor)
    input_shape: optional shape of input
    input_type: type of the input tensor
  """

  if inputs is None:
    inputs = array_ops.placeholder(input_type, input_shape)
  outputs = specs.create_net(spec, inputs)
  tf_parameter_summary(outputs)
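
# Usage sketch (illustrative; the spec string is an assumed example of the
# specs mini-language):
#
#   tf_spec_summary("net = Cr(8, [3, 3])", input_shape=(1, 64, 64, 1))
#
# This builds the network on a placeholder and prints one line per pipeline
# stage via tf_parameter_summary, in the format shown above.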


def tf_spec_print(spec,
                  inputs=None,
                  input_shape=None,
                  input_type=dtypes.float32):
  """Print a tree representing the spec.

  Args:
    spec: specification
    inputs: input to the spec construction (usually a Tensor)
    input_shape: optional shape of input
    input_type: type of the input tensor
  """

  if inputs is None:
    inputs = array_ops.placeholder(input_type, input_shape)
  outputs = specs.create_net(spec, inputs)
  tf_print(outputs)