337 lines
12 KiB
Python
337 lines
12 KiB
Python
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Execution Callbacks for Eager Mode."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import functools
|
|
|
|
import numpy as np
|
|
|
|
from tensorflow.python import pywrap_tensorflow
|
|
from tensorflow.python.eager import context
|
|
from tensorflow.python.eager import core
|
|
from tensorflow.python.eager import execute
|
|
from tensorflow.python.platform import tf_logging as logging
|
|
|
|
_DEFAULT_CALLBACK_ACTION = "raise"
|
|
_VALID_CALLBACK_ACTIONS = (None, "ignore", "print", "raise", "warn")
|
|
|
|
|
|
# TODO(cais): Consider moving this exception class to errors_impl.py.
|
|
class InfOrNanError(Exception):
|
|
"""Exception for inf and/or nan being present in tensor."""
|
|
|
|
def __init__(self,
|
|
op_type,
|
|
op_name,
|
|
output_index,
|
|
num_outputs,
|
|
value):
|
|
"""Constructor of InfOrNanError.
|
|
|
|
Args:
|
|
op_type: Type name of the op that generated the tensor that generated the
|
|
`inf`(s) or `nan`(s) (e.g., `Div`).
|
|
op_name: Name of the op that generated the tensor with `inf`(s) or
|
|
`nan`(s). This name is set by client and can be `None` if it is unset.
|
|
output_index: The 0-based output index of the tensor that contains
|
|
`inf`(s) or `nan`(s).
|
|
num_outputs: Total number of outputs of the operation.
|
|
value: The tensor value that contains `inf`(s) or `nan`(s).
|
|
"""
|
|
self._op_type = op_type
|
|
self._op_name = op_name
|
|
self._output_index = output_index
|
|
self._num_outputs = num_outputs
|
|
self._value = value
|
|
|
|
self._total_count = np.size(value)
|
|
self._inf_count = np.count_nonzero(np.isinf(value))
|
|
self._nan_count = np.count_nonzero(np.isnan(value))
|
|
|
|
super(InfOrNanError, self).__init__(self._get_error_message())
|
|
|
|
def _get_error_message(self):
|
|
"""Get the error message describing this InfOrNanError object."""
|
|
name_str = (("'%s'" % self._op_name) if self._op_name is not None
|
|
else str(self._op_name))
|
|
msg = "Output %d of %d of TFE operation %s (name: %s) contains " % (
|
|
self._output_index + 1, self._num_outputs, self._op_type, name_str)
|
|
if self._inf_count and self._nan_count:
|
|
msg += "%d inf(s) and %d nan(s) " % (self._inf_count, self._nan_count)
|
|
elif self._inf_count:
|
|
msg += "%d inf(s) " % self._inf_count
|
|
else:
|
|
msg += "%d nan(s) " % self._nan_count
|
|
msg += "out of a total of %d element(s). Tensor value: %s" % (
|
|
self._total_count, self._value)
|
|
return msg
|
|
|
|
@property
|
|
def op_type(self):
|
|
return self._op_type
|
|
|
|
@property
|
|
def op_name(self):
|
|
return self._op_name
|
|
|
|
@property
|
|
def output_index(self):
|
|
return self._output_index
|
|
|
|
@property
|
|
def num_outputs(self):
|
|
return self._num_outputs
|
|
|
|
@property
|
|
def value(self):
|
|
return self._value
|
|
|
|
|
|
def inf_nan_callback(op_type,
|
|
inputs,
|
|
attrs,
|
|
outputs,
|
|
op_name,
|
|
check_inf=True,
|
|
check_nan=True,
|
|
action=_DEFAULT_CALLBACK_ACTION):
|
|
"""An execution callback that checks for `inf`s and `nan`s in output tensors.
|
|
|
|
This callback can be used with `tfe.add_execute_callback` to check for invalid
|
|
numeric values. E.g.,
|
|
```python
|
|
tfe.add_execute_callback(tfe.inf_nan_callback)
|
|
```
|
|
|
|
Args:
|
|
op_type: Name of the TFE operation type (e.g., `MatMul`).
|
|
inputs: The `list` of input tensors to the operation, currently unused by
|
|
this callback.
|
|
attrs: Attributes of the TFE operation, as a tuple of alternating attribute
|
|
names and attribute values.
|
|
outputs: The `list` of output tensors from the operation, checked by this
|
|
callback for `inf` and `nan` values.
|
|
op_name: Name of the TFE operation. This name is set by client and can be
|
|
`None` if it unset.
|
|
check_inf: (`bool`) Whether this callback should check for `inf` values in
|
|
the output tensor values.
|
|
check_nan: (`bool`) Whether this callback should check for `nan` values in
|
|
the output tensor values.
|
|
action: (`str`) Action to be taken by the callback when `inf` or `nan`
|
|
values are detected. Possible values {"raise", "warn", "print"}
|
|
`"raise"`: Raise a `InfOrNanError`.
|
|
`"warn"`: Log a warning using `tf.logging.warn`.
|
|
`"print"`: Print a message to `sys.stdout`.
|
|
|
|
Raises:
|
|
InfOrNanError: iff `inf` or `nan` values are seen in any of `outputs` and
|
|
`action` is `"raise"`.
|
|
ValueError: iff the value of `action` is invalid.
|
|
"""
|
|
del attrs, inputs # Not used.
|
|
|
|
ctx = context.get_default_context()
|
|
|
|
for index, output in enumerate(outputs):
|
|
if not output.dtype.is_numpy_compatible:
|
|
continue
|
|
|
|
numpy_dtype = output.dtype.as_numpy_dtype
|
|
if (np.issubdtype(numpy_dtype, np.floating) or
|
|
np.issubdtype(numpy_dtype, np.complex) or
|
|
np.issubdtype(numpy_dtype, np.integer)):
|
|
try:
|
|
check_numerics_op_attrs = (
|
|
"message", "Eager-mode inf/nan check",
|
|
"T", outputs[0].dtype.as_datatype_enum)
|
|
# TODO(cais): Consider moving this into execute.py.
|
|
# pylint: disable=protected-access
|
|
pywrap_tensorflow.TFE_Py_Execute(
|
|
ctx._handle, output.device, "CheckNumerics", [output],
|
|
check_numerics_op_attrs, 1)
|
|
# pylint: enable=protected-access
|
|
except core._NotOkStatusException: # pylint: disable=protected-access
|
|
value = output.numpy()
|
|
inf_detected = np.any(np.isinf(value)) and check_inf
|
|
nan_detected = np.any(np.isnan(value)) and check_nan
|
|
if not inf_detected and not nan_detected:
|
|
continue
|
|
|
|
error = InfOrNanError(op_type, op_name, index, len(outputs), value)
|
|
if action == "print":
|
|
print("Warning: %s" % str(error))
|
|
elif action == "warn":
|
|
logging.warn(str(error))
|
|
elif action == "raise":
|
|
raise error
|
|
else:
|
|
raise ValueError(
|
|
"Invalid action for inf_nan_callback: %s. Valid actions are: "
|
|
"{print | warn | raise}" % action)
|
|
|
|
|
|
def inf_callback(op_type,
|
|
inputs,
|
|
attrs,
|
|
outputs,
|
|
op_name,
|
|
action=_DEFAULT_CALLBACK_ACTION):
|
|
"""A specialization of `inf_nan_callback` that checks for `inf`s only."""
|
|
inf_nan_callback(
|
|
op_type,
|
|
inputs,
|
|
attrs,
|
|
outputs,
|
|
op_name,
|
|
check_inf=True,
|
|
check_nan=False,
|
|
action=action)
|
|
|
|
|
|
def nan_callback(op_type,
|
|
inputs,
|
|
attrs,
|
|
outputs,
|
|
op_name,
|
|
action=_DEFAULT_CALLBACK_ACTION):
|
|
"""A specialization of `inf_nan_callback` that checks for `nan`s only."""
|
|
inf_nan_callback(
|
|
op_type,
|
|
inputs,
|
|
attrs,
|
|
outputs,
|
|
op_name,
|
|
check_inf=False,
|
|
check_nan=True,
|
|
action=action)
|
|
|
|
|
|
def add_execution_callback(callback):
|
|
"""Add an execution callback to the default eager context.
|
|
|
|
An execution callback is invoked immediately after an eager operation or
|
|
function has finished execution, providing access to the op's type, name
|
|
input and output tensors. Multiple execution callbacks can be added, in
|
|
which case the callbacks will be invoked in the order in which they are
|
|
added. To clear all execution callbacks that have been added, use
|
|
`clear_execution_callbacks()`.
|
|
|
|
Example:
|
|
```python
|
|
def print_even_callback(op_type, op_name, attrs, inputs, outputs):
|
|
# A callback that prints only the even output values.
|
|
if outputs[0].numpy() % 2 == 0:
|
|
print("Even output from %s: %s" % (op_name or op_type, outputs))
|
|
tfe.add_execution_callback(print_even_callback)
|
|
|
|
x = tf.pow(2.0, 3.0) - 3.0
|
|
y = tf.multiply(x, tf.add(1.0, 5.0))
|
|
# When the line above is run, you will see all intermediate outputs that are
|
|
# even numbers printed to the console.
|
|
|
|
tfe.clear_execution_callbacks()
|
|
```
|
|
|
|
Args:
|
|
callback: a callable of the signature
|
|
`f(op_type, op_name, attrs, inputs, outputs)`.
|
|
`op_type` is the type of the operation that was just executed (e.g.,
|
|
`MatMul`).
|
|
`op_name` is the name of the operation that was just executed. This
|
|
name is set by the client who created the operation and can be `None` if
|
|
it is unset.
|
|
`attrs` contains the attributes of the operation as a `tuple` of
|
|
alternating attribute name and attribute value.
|
|
`inputs` is the `list` of input `Tensor`(s) to the op.
|
|
`outputs` is the `list` of output `Tensor`(s) from the op.
|
|
Return value(s) from the callback are ignored.
|
|
"""
|
|
execute.execute = execute.execute_with_callbacks
|
|
context.get_default_context().add_post_execution_callback(callback)
|
|
|
|
|
|
def clear_execution_callbacks():
|
|
"""Clear all execution callbacks from the default eager context."""
|
|
context.get_default_context().clear_post_execution_callbacks()
|
|
|
|
|
|
def seterr(inf_or_nan=None):
|
|
"""Set how abnormal conditions are handled by the default eager context.
|
|
|
|
Example:
|
|
```python
|
|
tfe.seterr(inf_or_nan="raise")
|
|
a = tf.constant(10.0)
|
|
b = tf.constant(0.0)
|
|
try:
|
|
c = a / b # <-- Raises InfOrNanError.
|
|
except Exception as e:
|
|
print("Caught Exception: %s" % e)
|
|
|
|
tfe.seterr(inf_or_nan="ignore")
|
|
c = a / b # <-- Does NOT raise exception anymore.
|
|
```
|
|
|
|
Args:
|
|
inf_or_nan: Set action for infinity (`inf`) and NaN (`nan`) values.
|
|
Possible values: `{"ignore", "print", "raise", "warn"}`.
|
|
`"ignore"`: take no action when `inf` values appear.
|
|
`"print"`: print a warning to `stdout`.
|
|
`"raise"`: raise an `InfOrNanError`.
|
|
`"warn"`: print a warning using `tf.logging.warn`.
|
|
A value of `None` leads to no change in the action of the condition.
|
|
|
|
Returns:
|
|
A dictionary of old actions.
|
|
|
|
Raises:
|
|
ValueError: If the value of any keyword arguments is invalid.
|
|
"""
|
|
if inf_or_nan not in _VALID_CALLBACK_ACTIONS:
|
|
raise ValueError(
|
|
"Invalid action value for inf_or_nan: %s. "
|
|
"Valid actions are %s." % (inf_or_nan, _VALID_CALLBACK_ACTIONS))
|
|
|
|
old_settings = {"inf_or_nan": "ignore"}
|
|
default_context = context.get_default_context()
|
|
|
|
carryover_callbacks = []
|
|
for callback in default_context.post_execution_callbacks:
|
|
# Check whether the callback is inf_nan_callback or a partial object of
|
|
# inf_nan_callback.
|
|
if (callback == inf_nan_callback or
|
|
isinstance(callback, functools.partial) and
|
|
callback.func == inf_nan_callback):
|
|
if callback == inf_nan_callback:
|
|
old_settings["inf_or_nan"] = _DEFAULT_CALLBACK_ACTION
|
|
else:
|
|
old_settings["inf_or_nan"] = callback.keywords.get(
|
|
"action", _DEFAULT_CALLBACK_ACTION)
|
|
elif inf_or_nan is not None:
|
|
carryover_callbacks.append(callback)
|
|
|
|
if inf_or_nan is not None:
|
|
default_context.clear_post_execution_callbacks()
|
|
for callback in carryover_callbacks:
|
|
default_context.add_post_execution_callback(callback)
|
|
if inf_or_nan != "ignore":
|
|
default_context.add_post_execution_callback(
|
|
functools.partial(inf_nan_callback, action=inf_or_nan))
|
|
|
|
return old_settings
|