# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common utilities used across this package."""
|
|
|
|
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import re

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope

# Skip all operations that are backprop related or export summaries.
SKIPPED_PREFIXES = (
    'gradients/', 'RMSProp/', 'Adagrad/', 'Const_', 'HistogramSummary',
    'ScalarSummary')

# Valid activation ops for quantization end points.
_ACTIVATION_OP_SUFFIXES = ['/Relu6', '/Relu', '/Identity']

# Regular expression for recognizing nodes that are part of batch norm group.
_BATCHNORM_RE = re.compile(r'^(.*)/BatchNorm/batchnorm')
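
# Illustration only (hypothetical op name, not from the original file): for an
# op named 'conv1/BatchNorm/batchnorm/mul', _BATCHNORM_RE.search(op.name)
# matches and match.group(1) yields the layer prefix 'conv1'.
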

def BatchNormGroups(graph):
  """Finds batch norm layers, returns their prefixes as a list of strings.

  Args:
    graph: Graph to inspect.

  Returns:
    List of strings, prefixes of batch norm group names found.
  """
  bns = []
  for op in graph.get_operations():
    match = _BATCHNORM_RE.search(op.name)
    if match:
      bn = match.group(1)
      if not bn.startswith(SKIPPED_PREFIXES):
        bns.append(bn)
  # Filter out duplicates.
  return list(collections.OrderedDict.fromkeys(bns))
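
# Usage sketch (illustrative, not part of the original file): with the default
# graph containing batch norm layers, something like
#   layer_prefixes = BatchNormGroups(ops.get_default_graph())
# returns each layer prefix once, in the order it was first encountered
# (duplicates are dropped by the OrderedDict above).
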

def GetEndpointActivationOp(graph, prefix):
  """Returns an Operation with the given prefix and a valid end point suffix.

  Args:
    graph: Graph where to look for the operation.
    prefix: String, prefix of Operation to return.

  Returns:
    The Operation with the given prefix and a valid end point suffix, or None
    if there are no matching operations in the graph for any valid suffix.
  """
  for suffix in _ACTIVATION_OP_SUFFIXES:
    activation = _GetOperationByNameDontThrow(graph, prefix + suffix)
    if activation:
      return activation
  return None
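
# Illustration (hypothetical layer name, not from the original file): for
# prefix 'conv1' this probes 'conv1/Relu6', 'conv1/Relu' and 'conv1/Identity'
# in that order and returns the first operation that exists in the graph.
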

def _GetOperationByNameDontThrow(graph, name):
  """Returns an Operation with the given name.

  Args:
    graph: Graph where to look for the operation.
    name: String, name of Operation to return.

  Returns:
    The Operation with the given name, or None if the name does not correspond
    to any operation in the graph.
  """
  try:
    return graph.get_operation_by_name(name)
  except KeyError:
    return None

def CreateOrGetQuantizationStep():
  """Returns a Tensor of the number of steps the quantized graph has run.

  Returns:
    Quantization step Tensor.
  """
  quantization_step_name = 'fake_quantization_step'
  quantization_step_tensor_name = quantization_step_name + '/Identity:0'
  g = ops.get_default_graph()
  try:
    return g.get_tensor_by_name(quantization_step_tensor_name)
  except KeyError:
    # Create in proper graph and base name_scope.
    with g.name_scope(None):
      quantization_step_tensor = variable_scope.get_variable(
          quantization_step_name,
          shape=[],
          dtype=dtypes.int64,
          initializer=init_ops.zeros_initializer(),
          trainable=False,
          collections=[ops.GraphKeys.GLOBAL_VARIABLES])
      with g.name_scope(quantization_step_tensor.op.name + '/'):
        # We return the incremented variable tensor. Since this is used in
        # conds for quant_delay and freeze_bn_delay, it will run once per graph
        # execution. We return an identity to force resource variables and
        # normal variables to return a tensor of the same name.
        return array_ops.identity(
            state_ops.assign_add(quantization_step_tensor, 1))
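
# Behaviour note (derived from the code above, for illustration): the first
# call creates a non-trainable int64 variable named 'fake_quantization_step'
# and returns the tensor 'fake_quantization_step/Identity:0'; later calls in
# the same graph find that tensor by name, so all callers share one counter
# that is incremented once per graph execution.
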

def DropStringPrefix(s, prefix):
  """If the string starts with this prefix, drops it."""
  if s.startswith(prefix):
    return s[len(prefix):]
  else:
    return s
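
# For example (illustrative strings only):
#   DropStringPrefix('RMSProp/conv1/update', 'RMSProp/')  -> 'conv1/update'
#   DropStringPrefix('conv1/weights', 'RMSProp/')         -> 'conv1/weights' (unchanged)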