250 lines
8.6 KiB
Python
250 lines
8.6 KiB
Python
|
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# ==============================================================================
|
||
|
"""Python TF-Lite interpreter."""
|
||
|
from __future__ import absolute_import
|
||
|
from __future__ import division
|
||
|
from __future__ import print_function
|
||
|
|
||
|
import sys
|
||
|
from tensorflow.python.util.lazy_loader import LazyLoader
|
||
|
|
||
|
# Lazy load since some of the performance benchmark skylark rules
# break dependencies. Must use double quotes to match code internal rewrite
# rule.
# NOTE(review): the module path below is split across two adjacent string
# literals that Python concatenates; keep the exact double-quoted form, since
# (per the note above) an internal rewrite rule matches on it.
# pylint: disable=g-inconsistent-quotes
_interpreter_wrapper = LazyLoader(
    "_interpreter_wrapper", globals(),
    "tensorflow.contrib.lite.python.interpreter_wrapper."
    "tensorflow_wrap_interpreter_wrapper")
# pylint: enable=g-inconsistent-quotes

# The loader class is only needed to build `_interpreter_wrapper`; drop the
# name from the module namespace (presumably so it is not picked up as part
# of this module's surface).
del LazyLoader
|
||
|
|
||
|
|
||
|
class Interpreter(object):
  """Interpreter interface for TF-Lite Models.

  Thin Python layer over the C++ interpreter wrapper object created through
  `_interpreter_wrapper`. Operations that may reallocate the wrapper's
  internal memory (`allocate_tensors()`, `resize_tensor_input()`, `invoke()`)
  first verify that no numpy views of that memory are still alive.
  """

  def __init__(self, model_path=None, model_content=None):
    """Constructor.

    Exactly one of `model_path` or `model_content` must be provided.

    Args:
      model_path: Path to TF-Lite Flatbuffer file.
      model_content: Content of model.

    Raises:
      ValueError: If neither or both of `model_path` and `model_content` are
        given, or if the interpreter could not be created.
    """
    if model_path and not model_content:
      self._interpreter = (
          _interpreter_wrapper.InterpreterWrapper_CreateWrapperCPPFromFile(
              model_path))
      if not self._interpreter:
        raise ValueError('Failed to open {}'.format(model_path))
    elif model_content and not model_path:
      self._interpreter = (
          _interpreter_wrapper.InterpreterWrapper_CreateWrapperCPPFromBuffer(
              model_content))
      # Mirror the file-based branch. Without this check a failed creation
      # would only surface later, as a confusing error from another method
      # (e.g. the RuntimeError raised by _ensure_safe()).
      if not self._interpreter:
        raise ValueError('Failed to create interpreter from model content')
    elif not model_path and not model_content:
      raise ValueError('`model_path` or `model_content` must be specified.')
    else:
      raise ValueError('Can\'t both provide `model_path` and `model_content`')

  def allocate_tensors(self):
    """Allocates the tensors of the underlying interpreter.

    This may reallocate internal memory, so no numpy views obtained through
    a `tensor()` accessor may be held while calling it.

    Returns:
      Whatever the wrapper's `AllocateTensors()` returns.

    Raises:
      RuntimeError: If numpy arrays pointing to internal buffers are alive.
    """
    self._ensure_safe()
    return self._interpreter.AllocateTensors()

  def _safe_to_run(self):
    """Returns True if there exist no numpy array buffers.

    This means it is safe to run tflite calls that may destroy internally
    allocated memory. This works, because in the wrapper.cc we have made
    the numpy base be the self._interpreter.
    """
    # NOTE, our tensor() call in cpp will use _interpreter as a base pointer.
    # If this environment is the only _interpreter, then the ref count should be
    # 2 (1 in self and 1 in temporary of sys.getrefcount).
    return sys.getrefcount(self._interpreter) == 2

  def _ensure_safe(self):
    """Makes sure no numpy arrays pointing to internal buffers are active.

    This should be called from any function that will call a function on
    _interpreter that may reallocate memory e.g. invoke(), ...

    Raises:
      RuntimeError: If there exist numpy objects pointing to internal memory
        then we throw.
    """
    if not self._safe_to_run():
      raise RuntimeError("""There is at least 1 reference to internal data
      in the interpreter in the form of a numpy array or slice. Be sure to
      only hold the function returned from tensor() if you are using raw
      data access.""")

  def _get_tensor_details(self, tensor_index):
    """Gets tensor details.

    Args:
      tensor_index: Tensor index of tensor to query. Converted with `int()`.

    Returns:
      A dictionary with keys 'name', 'index', 'shape', 'dtype' and
      'quantization' describing the tensor.

    Raises:
      ValueError: If tensor_index is invalid (the wrapper returned an empty
        name or type).
    """
    tensor_index = int(tensor_index)
    tensor_name = self._interpreter.TensorName(tensor_index)
    tensor_size = self._interpreter.TensorSize(tensor_index)
    tensor_type = self._interpreter.TensorType(tensor_index)
    tensor_quantization = self._interpreter.TensorQuantization(tensor_index)

    if not tensor_name or not tensor_type:
      raise ValueError('Could not get tensor details')

    details = {
        'name': tensor_name,
        'index': tensor_index,
        'shape': tensor_size,
        'dtype': tensor_type,
        'quantization': tensor_quantization,
    }

    return details

  def get_input_details(self):
    """Gets model input details.

    Returns:
      A list of input details, one dictionary per input tensor (see the
      keys produced by `_get_tensor_details`).
    """
    return [
        self._get_tensor_details(i) for i in self._interpreter.InputIndices()
    ]

  def set_tensor(self, tensor_index, value):
    """Sets the value of the input tensor. Note this copies data in `value`.

    If you want to avoid copying, you can use the `tensor()` function to get a
    numpy buffer pointing to the input buffer in the tflite interpreter.

    Args:
      tensor_index: Tensor index of tensor to set. This value can be gotten from
        the 'index' field in get_input_details.
      value: Value of tensor to set.

    Raises:
      ValueError: If the interpreter could not set the tensor.
    """
    self._interpreter.SetTensor(tensor_index, value)

  def resize_tensor_input(self, input_index, tensor_size):
    """Resizes an input tensor.

    Args:
      input_index: Tensor index of input to set. This value can be gotten from
        the 'index' field in get_input_details.
      tensor_size: The tensor_shape to resize the input to.

    Raises:
      ValueError: If the interpreter could not resize the input tensor.
      RuntimeError: If numpy arrays pointing to internal buffers are alive.
    """
    self._ensure_safe()
    self._interpreter.ResizeInputTensor(input_index, tensor_size)

  def get_output_details(self):
    """Gets model output details.

    Returns:
      A list of output details, one dictionary per output tensor (see the
      keys produced by `_get_tensor_details`).
    """
    return [
        self._get_tensor_details(i) for i in self._interpreter.OutputIndices()
    ]

  def get_tensor(self, tensor_index):
    """Gets the value of a tensor (returns a copy).

    If you wish to avoid the copy, use `tensor()`.

    Args:
      tensor_index: Tensor index of tensor to get. This value can be gotten from
        the 'index' field in get_output_details.

    Returns:
      a numpy array.
    """
    return self._interpreter.GetTensor(tensor_index)

  def tensor(self, tensor_index):
    """Returns function that gives a numpy view of the current tensor buffer.

    This allows reading and writing to this tensor w/o copies. This more
    closely mirrors the C++ Interpreter class interface's tensor() member, hence
    the name. Be careful to not hold these output references through calls
    to `allocate_tensors()` and `invoke()`.

    Usage:

      interpreter.allocate_tensors()
      input = interpreter.tensor(interpreter.get_input_details()[0]["index"])
      output = interpreter.tensor(interpreter.get_output_details()[0]["index"])
      for i in range(10):
        input().fill(3.)
        interpreter.invoke()
        print("inference %s" % output())

    Notice how this function avoids making a numpy array directly. This is
    because it is important to not hold actual numpy views to the data longer
    than necessary. If you do, then the interpreter can no longer be invoked,
    because it is possible the interpreter would resize and invalidate the
    referenced tensors. The NumPy API doesn't allow any mutability of the
    underlying buffers.

    WRONG:

      input = interpreter.tensor(interpreter.get_input_details()[0]["index"])()
      output = interpreter.tensor(interpreter.get_output_details()[0]["index"])()
      interpreter.allocate_tensors()  # This will throw RuntimeError
      for i in range(10):
        input.fill(3.)
        interpreter.invoke()  # This will throw RuntimeError since input and
                              # output still hold views of internal buffers.

    Args:
      tensor_index: Tensor index of tensor to get. This value can be gotten from
        the 'index' field in get_output_details.

    Returns:
      A function that can return a new numpy array pointing to the internal
      TFLite tensor state at any point. It is safe to hold the function forever,
      but it is not safe to hold the numpy array forever.
    """
    # The wrapper is passed as the base object of each view, which is what
    # lets _safe_to_run() detect live views through its reference count.
    return lambda: self._interpreter.tensor(self._interpreter, tensor_index)

  def invoke(self):
    """Invoke the interpreter.

    Be sure to set the input sizes, allocate tensors and fill values before
    calling this.

    Raises:
      ValueError: When the underlying interpreter fails raise ValueError.
      RuntimeError: If numpy arrays pointing to internal buffers are alive.
    """
    self._ensure_safe()
    self._interpreter.Invoke()

  def reset_all_variables_to_zero(self):
    """Resets variable tensors to zero via the wrapper.

    Returns:
      Whatever the wrapper's `ResetVariableTensorsToZero()` returns.
    """
    return self._interpreter.ResetVariableTensorsToZero()