382 lines
14 KiB
Python
382 lines
14 KiB
Python
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
r"""Converts checkpoint variables into Const ops in a standalone GraphDef file.

This script is designed to take a GraphDef proto, a SaverDef proto, and a set of
variable values stored in a checkpoint file, and output a GraphDef with all of
the variable ops converted into const ops containing the values of the
variables.

It's useful to do this when we need to load a single file in C++, especially in
environments like mobile or embedded where we may not have access to the
RestoreTensor ops and file loading calls that they rely on.

An example of command-line usage is:
bazel build tensorflow/python/tools:freeze_graph && \
bazel-bin/tensorflow/python/tools/freeze_graph \
--input_graph=some_graph_def.pb \
--input_checkpoint=model.ckpt-8361242 \
--output_graph=/tmp/frozen_graph.pb --output_node_names=softmax

You can also look at freeze_graph_test.py for an example of how to use it.

"""
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import argparse
|
|
import sys
|
|
|
|
from google.protobuf import text_format
|
|
|
|
from tensorflow.core.framework import graph_pb2
|
|
from tensorflow.core.protobuf import saver_pb2
|
|
from tensorflow.core.protobuf.meta_graph_pb2 import MetaGraphDef
|
|
from tensorflow.python import pywrap_tensorflow
|
|
from tensorflow.python.client import session
|
|
from tensorflow.python.framework import graph_util
|
|
from tensorflow.python.framework import importer
|
|
from tensorflow.python.platform import app
|
|
from tensorflow.python.platform import gfile
|
|
from tensorflow.python.saved_model import loader
|
|
from tensorflow.python.saved_model import tag_constants
|
|
from tensorflow.python.tools import saved_model_utils
|
|
from tensorflow.python.training import saver as saver_lib
|
|
|
|
|
|
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_blacklist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
  """Freezes a graph by converting its variables into constants.

  Variable values are restored into a session (from `input_checkpoint` or a
  SavedModel directory) and then handed to
  `graph_util.convert_variables_to_constants`.

  Args:
    input_graph_def: GraphDef proto to import and freeze. May be None.
    input_saver_def: SaverDef proto. When set, it is used to build the Saver
      that restores the checkpoint and the other restore paths are skipped.
    input_checkpoint: Checkpoint path (or prefix) holding the variable values.
    output_node_names: Comma-separated names of the nodes to keep as outputs.
    restore_op_name: Unused.
    filename_tensor_name: Unused.
    output_graph: Path the serialized result is written to; nothing is written
      when empty.
    clear_devices: If true, the `device` field of every node in the input
      graph is blanked before freezing. The passed-in proto is modified.
    initializer_nodes: Comma-separated names of ops run after restoring. Only
      honoured on the meta-graph path and the plain-checkpoint path.
    variable_names_whitelist: Comma-separated variable names forwarded to the
      conversion as a list, or as None when empty.
    variable_names_blacklist: Comma-separated variable names forwarded to the
      conversion as a list, or as None when empty.
    input_meta_graph_def: MetaGraphDef proto. When set, its `graph_def` is the
      graph that gets converted.
    input_saved_model_dir: SavedModel directory to load instead of a
      checkpoint.
    saved_model_tags: Tags passed to the SavedModel loader; None means [].
    checkpoint_version: Saver write version passed to the Savers built here.

  Returns:
    The frozen GraphDef, or -1 if the checkpoint is missing or no output node
    names were given.
  """
  del restore_op_name, filename_tensor_name  # Unused by updated loading code.

  # With the Saver V2 format 'input_checkpoint' may be a prefix rather than a
  # file. The existence check is skipped when loading from a SavedModel.
  checkpoint_required = not input_saved_model_dir
  if checkpoint_required and not saver_lib.checkpoint_exists(input_checkpoint):
    print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
    return -1

  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1

  def split_names(names):
    """Turns a comma-separated string into a list, dropping all spaces."""
    return names.replace(" ", "").split(",")

  # Dropping explicit device placements makes the frozen graph more portable.
  if clear_devices:
    if input_meta_graph_def:
      nodes_to_reset = input_meta_graph_def.graph_def.node
    elif input_graph_def:
      nodes_to_reset = input_graph_def.node
    else:
      nodes_to_reset = ()
    for graph_node in nodes_to_reset:
      graph_node.device = ""

  if input_graph_def:
    _ = importer.import_graph_def(input_graph_def, name="")

  with session.Session() as sess:
    # Only the meta-graph and plain-checkpoint paths run initializer nodes.
    initializers_allowed = False

    if input_saver_def:
      saver_lib.Saver(
          saver_def=input_saver_def,
          write_version=checkpoint_version).restore(sess, input_checkpoint)
    elif input_meta_graph_def:
      meta_graph_saver = saver_lib.import_meta_graph(
          input_meta_graph_def, clear_devices=True)
      meta_graph_saver.restore(sess, input_checkpoint)
      initializers_allowed = True
    elif input_saved_model_dir:
      loader.load(sess,
                  [] if saved_model_tags is None else saved_model_tags,
                  input_saved_model_dir)
    else:
      checkpoint_reader = pywrap_tensorflow.NewCheckpointReader(
          input_checkpoint)
      restorable = {}
      for var_name in checkpoint_reader.get_variable_to_shape_map():
        try:
          restorable[var_name] = sess.graph.get_tensor_by_name(var_name + ":0")
        except KeyError:
          # The checkpoint entry has no matching tensor in the graph (for
          # example 'global_step' or similar housekeeping state); skip it.
          continue
      saver_lib.Saver(
          var_list=restorable,
          write_version=checkpoint_version).restore(sess, input_checkpoint)
      initializers_allowed = True

    if initializers_allowed and initializer_nodes:
      sess.run(split_names(initializer_nodes))

    whitelist = None
    if variable_names_whitelist:
      whitelist = split_names(variable_names_whitelist)
    blacklist = None
    if variable_names_blacklist:
      blacklist = split_names(variable_names_blacklist)

    # A meta graph, when supplied, is the graph that gets converted.
    if input_meta_graph_def:
      graph_to_convert = input_meta_graph_def.graph_def
    else:
      graph_to_convert = input_graph_def

    output_graph_def = graph_util.convert_variables_to_constants(
        sess,
        graph_to_convert,
        split_names(output_node_names),
        variable_names_whitelist=whitelist,
        variable_names_blacklist=blacklist)

  # Write GraphDef to file if output path has been given.
  if output_graph:
    with gfile.GFile(output_graph, "wb") as frozen_file:
      frozen_file.write(output_graph_def.SerializeToString())

  return output_graph_def
|
|
|
|
|
|
def _parse_input_graph_proto(input_graph, input_binary):
  """Loads a GraphDef proto from a file.

  Args:
    input_graph: Path of the file holding the GraphDef.
    input_binary: If true the file is parsed as a serialized binary proto,
      otherwise as text format.

  Returns:
    The parsed GraphDef, or -1 (after printing a message) if the file does not
    exist.
  """
  if not gfile.Exists(input_graph):
    print("Input graph file '" + input_graph + "' does not exist!")
    return -1

  graph_def = graph_pb2.GraphDef()
  read_mode = "r" if not input_binary else "rb"
  with gfile.FastGFile(input_graph, read_mode) as graph_file:
    contents = graph_file.read()
    if not input_binary:
      text_format.Merge(contents, graph_def)
    else:
      graph_def.ParseFromString(contents)
  return graph_def
|
|
|
|
|
|
def _parse_input_meta_graph_proto(input_graph, input_binary):
  """Parses a TensorFlow meta graph file into a MetaGraphDef proto.

  Args:
    input_graph: Path of the file holding the MetaGraphDef.
    input_binary: If true the file is parsed as a serialized binary proto,
      otherwise as text format.

  Returns:
    The parsed MetaGraphDef, or -1 (after printing a message) if the file does
    not exist.
  """
  if not gfile.Exists(input_graph):
    print("Input meta graph file '" + input_graph + "' does not exist!")
    return -1
  input_meta_graph_def = MetaGraphDef()
  mode = "rb" if input_binary else "r"
  with gfile.FastGFile(input_graph, mode) as f:
    if input_binary:
      input_meta_graph_def.ParseFromString(f.read())
    else:
      text_format.Merge(f.read(), input_meta_graph_def)
  # The closing quote was missing, leaving the file name visually unterminated.
  print("Loaded meta graph file '" + input_graph + "'")
  return input_meta_graph_def
|
|
|
|
|
|
def _parse_input_saver_proto(input_saver, input_binary):
  """Loads a SaverDef proto from a file.

  Args:
    input_saver: Path of the file holding the SaverDef.
    input_binary: If true the file is parsed as a serialized binary proto,
      otherwise as text format.

  Returns:
    The parsed SaverDef, or -1 (after printing a message) if the file does not
    exist.
  """
  if not gfile.Exists(input_saver):
    print("Input saver file '" + input_saver + "' does not exist!")
    return -1

  read_mode = "r" if not input_binary else "rb"
  with gfile.FastGFile(input_saver, read_mode) as saver_file:
    parsed_saver = saver_pb2.SaverDef()
    contents = saver_file.read()
    if not input_binary:
      text_format.Merge(contents, parsed_saver)
    else:
      parsed_saver.ParseFromString(contents)
  return parsed_saver
|
|
|
|
|
|
def freeze_graph(input_graph,
                 input_saver,
                 input_binary,
                 input_checkpoint,
                 output_node_names,
                 restore_op_name,
                 filename_tensor_name,
                 output_graph,
                 clear_devices,
                 initializer_nodes,
                 variable_names_whitelist="",
                 variable_names_blacklist="",
                 input_meta_graph=None,
                 input_saved_model_dir=None,
                 saved_model_tags=tag_constants.SERVING,
                 checkpoint_version=saver_pb2.SaverDef.V2):
  """Converts all variables in a graph and checkpoint into constants.

  File-based front end for `freeze_graph_with_def_protos`: the graph, meta
  graph and saver files are parsed into protos and forwarded together with the
  remaining arguments.

  Args:
    input_graph: Path of a GraphDef file. Ignored when `input_saved_model_dir`
      is set, in which case the graph is read from the SavedModel.
    input_saver: Path of a SaverDef file, or empty.
    input_binary: Whether the input files are binary protos (else text format).
    input_checkpoint: Checkpoint path (or prefix) holding the variable values.
    output_node_names: Comma-separated names of the output nodes.
    restore_op_name: Forwarded unchanged.
    filename_tensor_name: Forwarded unchanged.
    output_graph: Path the frozen GraphDef is written to.
    clear_devices: Whether to remove device specifications.
    initializer_nodes: Comma-separated initializer nodes to run before
      freezing.
    variable_names_whitelist: Comma-separated variables to convert.
    variable_names_blacklist: Comma-separated variables to leave unconverted.
    input_meta_graph: Path of a MetaGraphDef file, or empty/None.
    input_saved_model_dir: SavedModel directory, or empty/None.
    saved_model_tags: Comma-separated SavedModel tags. None is forwarded as
      None.
    checkpoint_version: Saver write version forwarded to the freezing routine.

  Returns:
    The value returned by `freeze_graph_with_def_protos`, or -1 if any of the
    input files does not exist.
  """
  input_graph_def = None
  if input_saved_model_dir:
    input_graph_def = saved_model_utils.get_meta_graph_def(
        input_saved_model_dir, saved_model_tags).graph_def
  elif input_graph:
    input_graph_def = _parse_input_graph_proto(input_graph, input_binary)
  input_meta_graph_def = None
  if input_meta_graph:
    input_meta_graph_def = _parse_input_meta_graph_proto(
        input_meta_graph, input_binary)
  input_saver_def = None
  if input_saver:
    input_saver_def = _parse_input_saver_proto(input_saver, input_binary)

  # The parsers report a missing file by returning the int -1 instead of a
  # proto. Passing that on would crash later with an unrelated AttributeError
  # or TypeError, so stop here. All files are parsed first so that every
  # missing-file message is still printed.
  parsed_inputs = (input_graph_def, input_meta_graph_def, input_saver_def)
  if any(isinstance(parsed, int) for parsed in parsed_inputs):
    return -1

  # Tolerate saved_model_tags=None; the freezing routine treats None as "no
  # tags" on its own.
  if saved_model_tags is None:
    saved_model_tag_list = None
  else:
    saved_model_tag_list = saved_model_tags.replace(" ", "").split(",")

  # Propagate the result so callers can see failures (-1) and use the frozen
  # GraphDef; previously it was silently dropped.
  return freeze_graph_with_def_protos(
      input_graph_def,
      input_saver_def,
      input_checkpoint,
      output_node_names,
      restore_op_name,
      filename_tensor_name,
      output_graph,
      clear_devices,
      initializer_nodes,
      variable_names_whitelist,
      variable_names_blacklist,
      input_meta_graph_def,
      input_saved_model_dir,
      saved_model_tag_list,
      checkpoint_version=checkpoint_version)
|
|
|
|
|
|
def main(unused_args, flags):
  """Validates the checkpoint version flag and runs `freeze_graph`.

  Args:
    unused_args: Ignored.
    flags: Parsed command-line flags; one attribute per `freeze_graph`
      argument plus `checkpoint_version`.

  Returns:
    -1 if `flags.checkpoint_version` is neither 1 nor 2, otherwise None.
  """
  requested_version = flags.checkpoint_version
  if requested_version not in (1, 2):
    print("Invalid checkpoint version (must be '1' or '2'): %d" %
          flags.checkpoint_version)
    return -1

  checkpoint_version = (
      saver_pb2.SaverDef.V1 if requested_version == 1
      else saver_pb2.SaverDef.V2)

  freeze_graph(
      flags.input_graph,
      flags.input_saver,
      flags.input_binary,
      flags.input_checkpoint,
      flags.output_node_names,
      flags.restore_op_name,
      flags.filename_tensor_name,
      flags.output_graph,
      flags.clear_devices,
      flags.initializer_nodes,
      flags.variable_names_whitelist,
      flags.variable_names_blacklist,
      flags.input_meta_graph,
      flags.input_saved_model_dir,
      flags.saved_model_tags,
      checkpoint_version)
|
|
|
|
def run_main():
  """Parses the command-line flags and runs `main` through `app.run`.

  Arguments the parser does not recognise are passed on to `app.run` after the
  program name.
  """
  parser = argparse.ArgumentParser()
  # Custom "bool" type: only the text "true" (in any case) yields True.
  parser.register("type", "bool", lambda v: v.lower() == "true")
  add_flag = parser.add_argument

  add_flag("--input_graph", type=str, default="",
           help="TensorFlow 'GraphDef' file to load.")
  add_flag("--input_saver", type=str, default="",
           help="TensorFlow saver file to load.")
  add_flag("--input_checkpoint", type=str, default="",
           help="TensorFlow variables file to load.")
  add_flag("--checkpoint_version", type=int, default=2,
           help="Tensorflow variable file format")
  add_flag("--output_graph", type=str, default="",
           help="Output 'GraphDef' file name.")
  add_flag("--input_binary", nargs="?", const=True, type="bool",
           default=False,
           help="Whether the input files are in binary format.")
  add_flag("--output_node_names", type=str, default="",
           help="The name of the output nodes, comma separated.")
  add_flag(
      "--restore_op_name",
      type=str,
      default="save/restore_all",
      help="""\
      The name of the master restore operator. Deprecated, unused by updated \
      loading code.
      """)
  add_flag(
      "--filename_tensor_name",
      type=str,
      default="save/Const:0",
      help="""\
      The name of the tensor holding the save path. Deprecated, unused by \
      updated loading code.
      """)
  add_flag("--clear_devices", nargs="?", const=True, type="bool",
           default=True,
           help="Whether to remove device specifications.")
  add_flag(
      "--initializer_nodes",
      type=str,
      default="",
      help="Comma separated list of initializer nodes to run before freezing.")
  add_flag(
      "--variable_names_whitelist",
      type=str,
      default="",
      help="""\
      Comma separated list of variables to convert to constants. If specified, \
      only those variables will be converted to constants.\
      """)
  add_flag(
      "--variable_names_blacklist",
      type=str,
      default="",
      help="""\
      Comma separated list of variables to skip converting to constants.\
      """)
  add_flag("--input_meta_graph", type=str, default="",
           help="TensorFlow 'MetaGraphDef' file to load.")
  add_flag(
      "--input_saved_model_dir",
      type=str,
      default="",
      help="Path to the dir with TensorFlow 'SavedModel' file and variables.")
  add_flag(
      "--saved_model_tags",
      type=str,
      default="serve",
      help="""\
      Group of tag(s) of the MetaGraphDef to load, in string format,\
      separated by \',\'. For tag-set contains multiple tags, all tags \
      must be passed in.\
      """)

  flags, unparsed = parser.parse_known_args()

  def main_with_flags(unused_args):
    """Adapts `main` to the one-argument callable that app.run invokes."""
    return main(unused_args, flags)

  app.run(main=main_with_flags, argv=[sys.argv[0]] + unparsed)
|
|
|
|
# Parse flags and freeze only when run as a script, never on import.
if __name__ == "__main__":
  run_main()
|