# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Logic to update a TensorFlow model graph with quantization operations."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re
from tensorflow.contrib import graph_editor
from tensorflow.contrib.quantize.python import common
from tensorflow.contrib.quantize.python import graph_matcher
from tensorflow.contrib.quantize.python import input_to_ops
from tensorflow.contrib.quantize.python import quant_ops
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging

# Quantizable operation types that are supported by the quantization rewrite.
_QUANTIZABLE_TYPES = {'Conv2D', 'MatMul', 'DepthwiseConv2dNative'}

# Activations that are supported by the quantization rewrite.
_ACTIVATION_TYPES = {'Relu', 'Relu6'}


def Quantize(graph,
             is_training,
             weight_bits=8,
             activation_bits=8,
             ema_decay=0.999,
             quant_delay=None,
             vars_collection=ops.GraphKeys.GLOBAL_VARIABLES,
             scope=None):
  """Updates graph with quantization operations.

  Currently we quantize the following tensors:
  * Conv/MatMul: Quantize the weights if it matches.
  * Activation: Quantize the output if it matches.
  * Bypass/Post-activation Bypass: Quantize both input and output
    if it matches.

  Args:
    graph: Graph to modify.
    is_training: Whether quantizing training graph or eval graph.
    weight_bits: Number of bits to use for quantizing weights.
    activation_bits: Number of bits to use for quantizing activations.
    ema_decay: (Optional) Float, EMA decay parameter. EMA is used to update
      quantization intervals for quantizing activations (see here about EMA:
      https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average).
    quant_delay: (Optional, default None) Int, count of global steps for which
      to delay quantization. This helps weights stabilize at the start of
      training.
    vars_collection: (Optional) Collection where to store the variables for
      quantization interval ends.
    scope: The scope to be transformed. If it's not None, only the ops which
      are in this scope will be transformed.

  Raises:
    ValueError: When quantization fails.
  """
  if scope and not scope.endswith('/'):
    scope += '/'

  input_to_ops_map = input_to_ops.InputToOps(graph)
  for layer_match in _FindLayersToQuantize(graph):
    # Quantize the weights.
    context = _GetContextFromOp(layer_match.layer_op)

    # If `scope` is given, only quantize it if the consumer of weights
    # (the layer op) is in the right scope.
    _InsertQuantOp(
        context,
        'weights_quant',
        layer_match.weight_tensor.op, [layer_match.layer_op],
        is_training,
        moving_avg=False,
        ema_decay=ema_decay,
        quant_delay=quant_delay,
        narrow_range=True,
        vars_collection=vars_collection,
        bits=weight_bits,
        consumer_scope=scope)

    # Quantize the activations.
    consumer_ops = input_to_ops_map.ConsumerOperations(
        layer_match.activation_op)
    add_context = context
    if layer_match.bypass_op:
      add_context = re.search(r'^(.*)/([^/]+)', context).group(1)

    # If `scope` is given, only quantize it if the producer of weights
    # (usually it's the layer op) is in the right scope.
    _InsertQuantOp(
        add_context,
        'act_quant',
        layer_match.activation_op,
        consumer_ops,
        is_training,
        moving_avg=True,
        ema_decay=ema_decay,
        quant_delay=quant_delay,
        vars_collection=vars_collection,
        bits=activation_bits,
        init_min=0.0,
        producer_scope=scope)

    # Quantize the inputs and output to the bypass (if it exists). The input
    # to the bypass is the bias add, and the output is the activation.
    if layer_match.bypass_op is not None:
      # If `scope` is given, only quantize it if both the producer and the
      # consumer are in the right scope.
      _InsertQuantOp(
          context,
          'conv_quant',
          layer_match.bias_add_op, [layer_match.bypass_op],
          is_training,
          moving_avg=True,
          ema_decay=ema_decay,
          quant_delay=quant_delay,
          vars_collection=vars_collection,
          bits=activation_bits,
          producer_scope=scope,
          consumer_scope=scope)
      # Make sure the op following this isn't an activation. If it is, we
      # shouldn't quantize it, since the activation will be fused into the
      # Add at inference time.
      consumers = input_to_ops_map.ConsumerOperations(layer_match.bypass_op)
      if any([consumer.type in _ACTIVATION_TYPES for consumer in consumers]):
        logging.info('Skipping %s, because it is followed by an activation.',
                     layer_match.bypass_op.name)
      else:
        _InsertQuantOp(
            add_context,
            'add_quant',
            layer_match.bypass_op,
            input_to_ops_map.ConsumerOperations(layer_match.bypass_op),
            is_training,
            moving_avg=True,
            ema_decay=ema_decay,
            quant_delay=quant_delay,
            vars_collection=vars_collection,
            bits=activation_bits,
            producer_scope=scope,
            consumer_scope=scope)

    # Quantize bypass ops that occur after the activation.
    if layer_match.post_activation_bypass_op is not None:
      post_activation_bypass_context = re.search(
          r'^(.*)/([^/]+)',
          layer_match.post_activation_bypass_op.name).group(1)
      # If `scope` is given, only quantize it if the producer is in the right
      # scope.
      # Make sure the op following this isn't an activation. If it is, we
      # shouldn't quantize it, since the activation will be fused into the
      # Add at inference time.
      consumers = input_to_ops_map.ConsumerOperations(
          layer_match.post_activation_bypass_op)
      if any([consumer.type in _ACTIVATION_TYPES for consumer in consumers]):
        logging.info('Skipping %s, because it is followed by an activation.',
                     layer_match.post_activation_bypass_op.name)
      else:
        _InsertQuantOp(
            post_activation_bypass_context,
            'post_activation_bypass_quant',
            layer_match.post_activation_bypass_op,
            consumers,
            is_training,
            moving_avg=True,
            ema_decay=ema_decay,
            quant_delay=quant_delay,
            vars_collection=vars_collection,
            bits=activation_bits,
            producer_scope=scope)
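
# A minimal usage sketch of the rewrite above (illustrative only; the model
# construction step and the 'MobilenetV1/' scope are hypothetical stand-ins,
# not part of this module):
#
#   g = ops.Graph()
#   with g.as_default():
#     ...  # build conv/fc layers ending in Relu or Relu6
#     Quantize(g, is_training=True, weight_bits=8, activation_bits=8,
#              quant_delay=2000, scope='MobilenetV1/')
#
# After the call, matched weights, activations and bypass adds are wrapped in
# FakeQuant* ops; with quant_delay=2000 those ops pass values through
# unchanged until global step 2000.
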
def _FindLayersToQuantize(graph):
  """Matches layers in graph to quantize.

  The following patterns get matched. Nodes surrounded by [] will be
  optionally matched:

          weight|folded_weight
                /
         conv|fc
            |
    [batch_to_space_nd]
            |
    [post_conv_correction]
            |
     biasadd|folded_bias
            |
         [bypass]
            |
        activation
            |
    [post_activation_bypass]

  Match replacements:
    If weight|folded_weight is found, FakeQuant is added afterwards.
    If bypass is found, FakeQuant is added before and after.
    If activation is found, FakeQuant is added afterwards.
    If post_activation_bypass is found, FakeQuant is added afterwards.

  Args:
    graph: Graph to perform match on.

  Returns:
    list of _LayerMatches.
  """
  input_pattern = graph_matcher.OpTypePattern('*')
  weight_var_pattern = graph_matcher.OpTypePattern('Variable|VariableV2')
  weight_partition_identity_pattern = graph_matcher.OpTypePattern(
      'Identity', inputs=[weight_var_pattern])
  weight_partition_concat_pattern = graph_matcher.OpTypePattern(
      'ConcatV2', inputs=[weight_partition_identity_pattern, '*', '*'])
  weight_identity_pattern = graph_matcher.OpTypePattern(
      'Identity',
      inputs=[
          graph_matcher.OneofPattern([
              weight_partition_identity_pattern,
              weight_partition_concat_pattern,
              weight_var_pattern,
          ])
      ])
  weight_resource_var_pattern = graph_matcher.OpTypePattern('ReadVariableOp')
  folded_weight_pattern = graph_matcher.OpTypePattern('Mul')

  # The weight inputs to the layer operation can come either from the Variable
  # or from the folded weight (Mul).
  layer_pattern = graph_matcher.OpTypePattern(
      '|'.join(_QUANTIZABLE_TYPES),
      inputs=[
          input_pattern,
          graph_matcher.OneofPattern([
              weight_identity_pattern, weight_resource_var_pattern,
              folded_weight_pattern
          ])
      ],
      ordered_inputs=False)

  # For atrous convolutions a BatchToSpaceND will occur after the depthwise
  # convolution.
  batch_to_space_pattern = graph_matcher.OpTypePattern(
      'BatchToSpaceND',
      inputs=[
          layer_pattern,
          graph_matcher.OpTypePattern('*'),
          graph_matcher.OpTypePattern('*')
      ])

  layer_output_pattern = graph_matcher.OneofPattern(
      [batch_to_space_pattern, layer_pattern])
  folded_bias_mul_pattern = graph_matcher.OpTypePattern(
      'Mul',
      inputs=[graph_matcher.OpTypePattern('*'), layer_output_pattern],
      ordered_inputs=False)
  post_layer_op_correction_pattern = graph_matcher.OpTypePattern(
      'Add',
      inputs=[folded_bias_mul_pattern,
              graph_matcher.OpTypePattern('*')],
      ordered_inputs=False)
  folded_bias_add_pattern = graph_matcher.OpTypePattern(
      'Add',
      inputs=[
          post_layer_op_correction_pattern,
          graph_matcher.OpTypePattern('*')
      ],
      ordered_inputs=False)

  # batch_norms with forced updates have an Identity operation at the end.
  # TODO(suharshs): Find a way to easily skip extra Identity operations. The
  # current issue is that doing so can often match patterns across many layers
  # incorrectly.
  batch_norm_identity = graph_matcher.OpTypePattern(
      'Identity', inputs=[folded_bias_add_pattern])

  bias_add_pattern = graph_matcher.OpTypePattern(
      'Add|BiasAdd', inputs=[layer_output_pattern, '*'], ordered_inputs=False)

  # The bias can come from the bias add or the folded bias add.
  bypass_pattern = graph_matcher.OpTypePattern(
      'Add',
      inputs=[
          graph_matcher.OneofPattern(
              [bias_add_pattern, folded_bias_add_pattern, batch_norm_identity]),
          '*'
      ],
      ordered_inputs=False)

  # The input to the activation can come from the bias add, the folded bias
  # add, or the bypasses.
  # TODO(suharshs): We should ideally skip Identity operations instead of
  # treating them as activations.
  activation_pattern = graph_matcher.OpTypePattern(
      '|'.join(_ACTIVATION_TYPES) + '|Identity',
      inputs=[
          graph_matcher.OneofPattern([
              bias_add_pattern,
              folded_bias_add_pattern,
              batch_norm_identity,
              bypass_pattern,
          ])
      ])

  post_activation_bypass_pattern = graph_matcher.OpTypePattern(
      'Add', inputs=['*', activation_pattern], ordered_inputs=False)

  # The order of the following matching blocks is very important. Since matches
  # aren't guaranteed to be disjoint, we structure matches from largest to
  # smallest to guarantee that the largest match always wins. Additionally, we
  # ensure that we don't match layers multiple times.

  layer_matches = []
  # We use matched_layer_set to ensure that layers aren't matched multiple
  # times.
  matched_layer_set = set()

  # First, we match layers that have a post activation bypass. We do this first
  # to ensure we don't match only the first part of this layer, missing the
  # post activation bypass node.
  post_activation_bypass_layer_matcher = graph_matcher.GraphMatcher(
      post_activation_bypass_pattern)
  for match_result in post_activation_bypass_layer_matcher.match_graph(graph):
    layer_op = match_result.get_op(layer_pattern)
    weight_tensor = match_result.get_tensor(weight_identity_pattern)
    if weight_tensor is None:
      weight_tensor = match_result.get_tensor(weight_resource_var_pattern)
    if weight_tensor is None:
      weight_tensor = match_result.get_tensor(folded_weight_pattern)
    activation_op = match_result.get_op(activation_pattern)
    bias_add_op = match_result.get_op(bias_add_pattern)
    if bias_add_op is None:
      bias_add_op = match_result.get_op(folded_bias_add_pattern)
    bypass_op = match_result.get_op(bypass_pattern)
    post_activation_bypass_op = match_result.get_op(
        post_activation_bypass_pattern)
    if layer_op not in matched_layer_set:
      matched_layer_set.add(layer_op)
      layer_matches.append(
          _LayerMatch(layer_op, weight_tensor, activation_op, bypass_op,
                      post_activation_bypass_op, bias_add_op))

  # Now, we match the basic layer ending at an activation. We may get duplicate
  # matches from above, but we don't add them to layer_matches.
  layer_matcher = graph_matcher.GraphMatcher(activation_pattern)
  for match_result in layer_matcher.match_graph(graph):
    layer_op = match_result.get_op(layer_pattern)
    weight_tensor = match_result.get_tensor(weight_identity_pattern)
    if weight_tensor is None:
      weight_tensor = match_result.get_tensor(weight_resource_var_pattern)
    if weight_tensor is None:
      weight_tensor = match_result.get_tensor(folded_weight_pattern)
    activation_op = match_result.get_op(activation_pattern)
    bias_add_op = match_result.get_op(bias_add_pattern)
    if bias_add_op is None:
      bias_add_op = match_result.get_op(folded_bias_add_pattern)
    bypass_op = match_result.get_op(bypass_pattern)
    if layer_op not in matched_layer_set:
      matched_layer_set.add(layer_op)
      layer_matches.append(
          _LayerMatch(layer_op, weight_tensor, activation_op, bypass_op, None,
                      bias_add_op))

  # Match the final layer, where there may not be an activation and instead
  # the output of the final BiasAdd must be quantized. So we treat the BiasAdd
  # as the 'activation_op' in the _LayerMatch, to ensure that its output is
  # quantized.
  final_layer_matcher = graph_matcher.GraphMatcher(
      graph_matcher.OneofPattern([bias_add_pattern, folded_bias_add_pattern]))
  for match_result in final_layer_matcher.match_graph(graph):
    layer_op = match_result.get_op(layer_pattern)
    weight_tensor = match_result.get_tensor(weight_identity_pattern)
    if weight_tensor is None:
      weight_tensor = match_result.get_tensor(weight_resource_var_pattern)
    if weight_tensor is None:
      weight_tensor = match_result.get_tensor(folded_weight_pattern)
    activation_op = match_result.get_op(bias_add_pattern)
    if activation_op is None:
      activation_op = match_result.get_op(folded_bias_add_pattern)
    if layer_op not in matched_layer_set:
      matched_layer_set.add(layer_op)
      layer_matches.append(
          _LayerMatch(layer_op, weight_tensor, activation_op, None, None, None))

  return layer_matches
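
# An illustrative sketch of a typical match (the tf.slim-style op names such
# as 'conv1/Conv2D' are hypothetical, not required by this module): a plain
# conv + bias + ReLU layer would yield a _LayerMatch whose layer_op is
# 'conv1/Conv2D', whose weight_tensor is the output of 'conv1/weights/read',
# whose activation_op is 'conv1/Relu' and bias_add_op is 'conv1/BiasAdd',
# and whose bypass_op and post_activation_bypass_op are None because those
# optional pattern nodes are absent.
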
class _LayerMatch(object):
  """Contains all information related to a matched Layer."""

  def __init__(self, layer_op, weight_tensor, activation_op, bypass_op,
               post_activation_bypass_op, bias_add_op):
    self._layer_op = layer_op
    self._weight_tensor = weight_tensor
    self._activation_op = activation_op
    self._bypass_op = bypass_op
    self._post_activation_bypass_op = post_activation_bypass_op
    self._bias_add_op = bias_add_op

  @property
  def layer_op(self):
    return self._layer_op

  @property
  def weight_tensor(self):
    return self._weight_tensor

  @property
  def activation_op(self):
    return self._activation_op

  @property
  def bypass_op(self):
    return self._bypass_op

  @property
  def post_activation_bypass_op(self):
    return self._post_activation_bypass_op

  @property
  def bias_add_op(self):
    return self._bias_add_op


def _InsertQuantOp(context,
                   name,
                   producer,
                   consumers,
                   is_training,
                   moving_avg=True,
                   init_min=-6.0,
                   init_max=6.0,
                   bits=8,
                   ema_decay=0.999,
                   quant_delay=None,
                   vars_collection=ops.GraphKeys.GLOBAL_VARIABLES,
                   narrow_range=False,
                   producer_scope=None,
                   consumer_scope=None):
  """Inserts a quant op between a producer op and (multiple) consumer ops.

  Args:
    context: Context where producer and consumer operations are nested.
    name: Name for the new quantization op within the context.
    producer: Producer operation of the pairs where quantization will be
      inserted.
    consumers: Consumer operations of the pairs.
    is_training: Whether quantizing training graph or eval graph.
    moving_avg: Specifies whether to use exponential moving average or just
      the last value seen.
    init_min: Starting minimum value for the new quantization op.
    init_max: Starting maximum value for the new quantization op.
    bits: Number of bits to use for quantization, must be between 2 and 8.
    ema_decay: (Optional) Float, EMA decay parameter. EMA is used to update
      quantization intervals for quantizing activations (see here about EMA:
      https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average).
    quant_delay: (Optional, default None) Int, count of global steps for which
      to delay quantization. This helps weights stabilize at the start of
      training.
    vars_collection: (Optional) Collection where to store the variables for
      quantization interval ends.
    narrow_range: Whether to use the narrow quantization range
      [1; 2^bits - 1] or wide range [0; 2^bits - 1].
    producer_scope: The restriction of producer scope. If not None, the new op
      will be inserted only when the producer is in this scope.
    consumer_scope: The restriction of consumer scope. If not None, the new op
      will be inserted only when all the consumers are in this scope.

  Raises:
    ValueError: When producer operation is not directly connected to the
      consumer operation.
  """
  if producer_scope and not producer.name.startswith(producer_scope):
    logging.info(
        '_InsertQuantOp ignores context="%s" name="%s" '
        'because producer "%s" is not in scope "%s"',
        context, name, producer.name, producer_scope)
    return

  if consumer_scope:
    consumers_in_scope = []
    for consumer in consumers:
      if consumer.name.startswith(consumer_scope):
        consumers_in_scope.append(consumer)
      else:
        logging.info(
            '_InsertQuantOp context="%s" name="%s" ignores '
            'consumer "%s" because it is not in scope "%s"',
            context, name, consumer.name, consumer_scope)
        return
    consumers = consumers_in_scope

  name_prefix = _AddContextToName(context, name)
  # This is needed on TPU where name_scope == 'TPUReplicate/loop', and
  # name_prefix starts with 'TPUReplicate/loop/'; without dropping it
  # variables are created as TPUReplicate/loop/TPUReplicate/loop/..., which
  # breaks things later.
  name_scope = ops.get_name_scope()
  if name_scope:
    name_prefix = common.DropStringPrefix(name_prefix, name_scope + '/')

  inputs = producer.outputs[0]
  # Prevent ops from being quantized multiple times. Bypass ops can sometimes
  # overlap between multiple matches, so we need to ensure that we don't
  # add duplicate FakeQuant operations.
  fake_quant_ops = set([
      'FakeQuantWithMinMaxVars',
      'FakeQuantWithMinMaxArgs'
  ])
  if fake_quant_ops.intersection(set([c.type for c in inputs.consumers()])):
    return

  if moving_avg:
    quant = (
        quant_ops.MovingAvgQuantize(
            inputs,
            init_min=init_min,
            init_max=init_max,
            ema_decay=ema_decay,
            is_training=is_training,
            num_bits=bits,
            narrow_range=narrow_range,
            vars_collection=vars_collection,
            name_prefix=name_prefix))
  else:
    quant = (
        quant_ops.LastValueQuantize(
            inputs,
            init_min=init_min,
            init_max=init_max,
            is_training=is_training,
            num_bits=bits,
            narrow_range=narrow_range,
            vars_collection=vars_collection,
            name_prefix=name_prefix))

  if quant_delay and quant_delay > 0:
    activate_quant = math_ops.greater_equal(
        common.CreateOrGetQuantizationStep(),
        quant_delay,
        name=name_prefix + '/activate_quant')
    quant = control_flow_ops.cond(
        activate_quant,
        lambda: quant,
        lambda: inputs,
        name=name_prefix + '/delayed_quant')

  if consumers:
    tensors_modified_count = graph_editor.reroute_ts(
        [quant], [inputs], can_modify=consumers)
    # Some operations can have multiple output tensors going to the same
    # consumer. Since consumers is a set, we need to ensure that
    # tensors_modified_count is greater than or equal to the length of the set
    # of consumers.
    if tensors_modified_count < len(consumers):
      raise ValueError('No inputs quantized for ops: [%s]' % ', '.join(
          [consumer.name for consumer in consumers]))
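
# Rewiring sketch for the reroute above (illustrative only; op names are
# hypothetical): before the call, producer.outputs[0] feeds each consumer
# directly, e.g. conv1/Relu -> add_1. After the call the edge reads
# conv1/Relu -> <name_prefix>/FakeQuantWithMinMaxVars -> add_1 (or the
# delayed_quant cond output when quant_delay is set), while consumers outside
# `can_modify` keep their original, unquantized input.
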
def _GetContextFromOp(op):
  """Gets the root context name from the op name."""
  context_re = re.search(r'^(.*)/([^/]+)', op.name)
  if context_re:
    return context_re.group(1)
  return ''
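
# For example, an op named 'MobilenetV1/Conv2d_0/Conv2D' (an illustrative
# name) yields the context 'MobilenetV1/Conv2d_0', since the greedy group
# captures everything up to the last '/'; an op name without any '/' yields
# the empty context ''.
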
def _AddContextToName(context, name):
  """Adds the context to the name if it exists."""
  if not context:
    return name
  return context + '/' + name