laywerrobot/lib/python3.6/site-packages/tensorflow/python/eager/context.py
2020-08-27 21:55:39 +02:00

741 lines
24 KiB
Python

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Experimental API for TensorFlow's "Eager" mode of execution."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import contextlib
import copy
import random
import threading
from tensorflow.core.protobuf import config_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import device as pydev
from tensorflow.python.util import compat
from tensorflow.python.util import is_in_graph_mode
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export
# Values stored in `_EagerContext.mode` and passed to `Context._mode`.
GRAPH_MODE = 0
EAGER_MODE = 1
# Default execution mode.
_default_mode = GRAPH_MODE
# Cache from (old_device_name, partial_new_device_name) -> (new_device_name,
# new_device_spec).
# Note that we do not protect this with a lock and instead rely on python's GIL
# and the idempotent nature of writes to provide thread safety.
_device_parsing_cache = {}
# Inclusive upper bound of the seeds drawn by
# `Context._internal_operation_seed`.
_MAXINT32 = 2**31 - 1
# Device placement policies, re-exported from the `pywrap_tensorflow` wrapper.
DEVICE_PLACEMENT_EXPLICIT = pywrap_tensorflow.TFE_DEVICE_PLACEMENT_EXPLICIT
DEVICE_PLACEMENT_WARN = pywrap_tensorflow.TFE_DEVICE_PLACEMENT_WARN
DEVICE_PLACEMENT_SILENT = pywrap_tensorflow.TFE_DEVICE_PLACEMENT_SILENT
DEVICE_PLACEMENT_SILENT_FOR_INT32 = (
pywrap_tensorflow.TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32)
# Execution modes accepted by `Context(execution_mode=...)` and
# `Context.set_execution_mode`.
SYNC = 0
ASYNC = 1
class _TensorCache(object):
"""Simple cache which evicts items based on length in a FIFO manner."""
def __init__(self, max_items=256):
self._data = collections.OrderedDict()
self._max_items = max_items if max_items else 256
def put(self, key, value):
self._data[key] = value
if len(self._data) > self._max_items:
self._data.popitem(last=False)
def get(self, key):
return self._data.get(key, None)
def flush(self):
self._data = {}
# TODO(agarwal): better name ?
class _EagerContext(threading.local):
  """Thread local eager context.

  Because this derives from `threading.local`, every thread that touches an
  instance gets its own copy of the attributes assigned in `__init__`.
  """

  def __init__(self):
    """Sets the per-thread defaults."""
    super(_EagerContext, self).__init__()
    # Mode state starts from the process-level default.
    initial_mode = _default_mode
    self.mode = initial_mode
    self.is_eager = initial_mode == EAGER_MODE
    # Device placement starts out as the spec parsed from the empty string.
    empty_spec = pydev.DeviceSpec.from_string("")
    self.device_spec = empty_spec
    self.device_name = empty_spec.to_string()
    # Naming and summary state.
    self.scope_name = ""
    self.recording_summaries = False
    self.summary_writer_resource = None
    # Caches.
    self.scalar_cache = {}
    self.ones_rank_cache = _TensorCache()
    # None means "use the mode configured on the owning Context".
    self.execution_mode = None
# Record of a single context switch: whether the entered context is building
# a function, and the callable that performs the switch.
ContextSwitch = collections.namedtuple(
    "ContextSwitch", ("is_building_function", "enter_context_fn"))
# `_ContextSwitchStack` is a `threading.local` to match the semantics of
# `DefaultGraphStack`, which is also a `threading.local`.
class _ContextSwitchStack(threading.local):
"""A thread-local stack of context switches."""
def __init__(self, eager):
super(_ContextSwitchStack, self).__init__()
self.stack = []
if eager:
# Initialize the stack with a pointer to enter the eager context; this
# ensures that the fact that eager execution was enabled is propagated
# across threads, since (1) `enable_eager_execution` modifies a
# process-level flag (`_default_mode`) and (2) `__init__` is called each
# time a threading.local object is used in a separate thread.
self.push(is_building_function=False, enter_context_fn=eager_mode)
def push(self, is_building_function, enter_context_fn):
"""Push metadata about a context switch onto the stack.
A context switch can take one of two forms: installing a graph as the
default graph, or entering the eager context. For each context switch,
we record whether or not the entered context is building a function.
Args:
is_building_function: (bool.) Whether the context is building a function.
enter_context_fn: (function.) A callable that executes the context switch.
For example, `graph.as_default` or `eager_mode`.
"""
self.stack.append(
ContextSwitch(is_building_function, enter_context_fn))
def pop(self):
"""Pop the stack."""
self.stack.pop()
# TODO(agarwal): rename to EagerContext / EagerRuntime ?
# TODO(agarwal): consider keeping the corresponding Graph here.
class Context(object):
  """Environment in which eager operations execute.

  Holds the thread-local eager state, the stack of context switches and the
  lazily created C API context handle together with its device list.
  """

  # TODO(agarwal): create and link in some documentation for `execution_mode`.
  # pylint: disable=redefined-outer-name
  def __init__(self,
               config=None,
               device_policy=None,
               execution_mode=None,
               server_def=None):
    """Creates a new Context.

    Args:
      config: (Optional.) A `ConfigProto` protocol buffer with configuration
        options for the Context. Note that a lot of these options may be
        currently unimplemented or irrelevant when eager execution is enabled.
      device_policy: (Optional.) What policy to use when trying to run an
        operation on a device with inputs which are not on that device.
        When set to None, an appropriate value will be picked automatically.
        The value picked may change between TensorFlow releases.
        Defaults to tf.contrib.eager.DEVICE_PLACEMENT_SILENT_FOR_INT32.
        Valid values:
        - tfe.DEVICE_PLACEMENT_EXPLICIT: raises an error if the placement is
          not correct.
        - tfe.DEVICE_PLACEMENT_WARN: copies the tensors which are not on the
          right device but raises a warning.
        - tfe.DEVICE_PLACEMENT_SILENT: silently copies the tensors. This might
          hide performance problems.
        - tfe.DEVICE_PLACEMENT_SILENT_FOR_INT32: silently copies int32 tensors,
          raising errors on the other ones.
      execution_mode: (Optional.) Policy controlling how operations dispatched
        are actually executed. When set to None, an appropriate value will be
        picked automatically. The value picked may change between TensorFlow
        releases.
        Valid values:
        - tf.contrib.eager.SYNC: executes each operation synchronously.
        - tf.contrib.eager.ASYNC: executes each operation asynchronously. These
          operations may return "non-ready" handles.
      server_def: (Optional.) When not None, it is serialized with
        `SerializeToString()` and handed to the context options the first time
        the underlying handle is created.

    Raises:
      ValueError: If execution_mode is not valid.
    """
    self._eager_context = _EagerContext()
    # Must come after `_eager_context`: `executing_eagerly` reads it.
    self._context_switches = _ContextSwitchStack(self.executing_eagerly())
    # Handle and device list are created lazily by
    # `_initialize_handle_and_devices`.
    self._context_handle = None
    self._context_devices = None
    self._post_execution_callbacks = []
    self._config = config
    self._seed = None
    # Created by `_set_global_seed`, or lazily by `_internal_operation_seed`.
    self._rng = None
    self._initialize_lock = threading.Lock()
    self._device_policy = device_policy
    if execution_mode not in (None, SYNC, ASYNC):
      raise ValueError(
          "execution_mode should be None/SYNC/ASYNC. Got %s" % execution_mode)
    if execution_mode is None:
      execution_mode = SYNC
    self._execution_mode = execution_mode
    self._server_def = server_def
  # pylint: enable=redefined-outer-name

  def _set_global_seed(self, seed):
    """Set a global eager mode seed for random ops.

    Args:
      seed: Seed stored on the context and used to create the generator that
        produces operation seeds.
    """
    self._seed = seed
    self._rng = random.Random(self._seed)
    # Also clear the kernel cache, to reset any existing seeds
    if self._context_handle is not None:
      pywrap_tensorflow.TFE_ContextClearCaches(self._context_handle)

  def _internal_operation_seed(self):
    """Returns a fake operation seed.

    In eager mode, user shouldn't set or depend on operation seed.
    Here, we generate a random seed based on global seed to make
    operation's randomness different and depend on the global seed.

    Returns:
      A fake operation seed based on global seed, in `[0, _MAXINT32]`.
    """
    if self._rng is None:
      # No global seed was set yet: build the generator from the current
      # (possibly None) seed instead of failing with AttributeError.
      self._rng = random.Random(self._seed)
    return self._rng.randint(0, _MAXINT32)

  def _initialize_handle_and_devices(self):
    """Initialize handle and devices.

    Does nothing when the handle already exists. Otherwise creates the context
    handle from the stored options and records the canonical device names and
    the number of devices whose type is "GPU".
    """
    with self._initialize_lock:
      if self._context_handle is not None:
        return
      assert self._context_devices is None
      opts = pywrap_tensorflow.TFE_NewContextOptions()
      try:
        if self._config is not None:
          config_str = self._config.SerializeToString()
          pywrap_tensorflow.TFE_ContextOptionsSetConfig(opts, config_str)
        if self._device_policy is not None:
          pywrap_tensorflow.TFE_ContextOptionsSetDevicePlacementPolicy(
              opts, self._device_policy)
        if self._execution_mode == ASYNC:
          pywrap_tensorflow.TFE_ContextOptionsSetAsync(opts, True)
        if self._server_def is not None:
          server_def_str = self._server_def.SerializeToString()
          pywrap_tensorflow.TFE_ContextOptionsSetServerDef(opts, server_def_str)
        self._context_handle = pywrap_tensorflow.TFE_NewContext(opts)
      finally:
        # The options object is released whether or not creation succeeded.
        pywrap_tensorflow.TFE_DeleteContextOptions(opts)
      # Store list of devices
      self._context_devices = []
      device_list = pywrap_tensorflow.TFE_ContextListDevices(
          self._context_handle)
      try:
        self._num_gpus = 0
        for i in range(pywrap_tensorflow.TF_DeviceListCount(device_list)):
          dev_name = pywrap_tensorflow.TF_DeviceListName(device_list, i)
          self._context_devices.append(pydev.canonical_name(dev_name))
          dev_type = pywrap_tensorflow.TF_DeviceListType(device_list, i)
          if dev_type == "GPU":
            self._num_gpus += 1
      finally:
        pywrap_tensorflow.TF_DeleteDeviceList(device_list)

  @property
  def _handle(self):
    """Returns the context handle, creating it on first access."""
    ctx = self._context_handle
    if ctx is None:
      self._initialize_handle_and_devices()
      return self._context_handle
    else:
      return ctx

  @property
  def _devices(self):
    """Returns the list of device names, initializing the context if needed."""
    devices = self._context_devices
    if devices is None:
      self._initialize_handle_and_devices()
      return self._context_devices
    else:
      return devices

  def __str__(self):
    """Returns a description of the context without forcing initialization."""
    if self._context_handle is None:
      return "Eager TensorFlow Context. Devices currently uninitialized."
    else:
      devices = self._devices
      lines = ["Eager TensorFlow Context with %d devices" % (len(devices))]
      for i, d in enumerate(devices):
        lines.append(" Device %d: %s" % (i, d))
      return "\n".join(lines)

  @tf_contextlib.contextmanager
  def _mode(self, mode):
    """A context manager to allow setting the mode to EAGER/GRAPH.

    Args:
      mode: Either `EAGER_MODE` or `GRAPH_MODE`.

    Yields:
      Nothing. The previous mode is restored on exit.
    """
    ctx = self._eager_context
    old_mode = ctx.mode
    old_is_eager = ctx.is_eager
    ctx.mode = mode
    ctx.is_eager = mode == EAGER_MODE
    if mode == EAGER_MODE:
      # Entering graph mode does not provide us with sufficient information to
      # record a context switch; graph-based context switches are only logged
      # when a graph is registered as the default graph.
      self.context_switches.push(False, eager_mode)
    try:
      yield
    finally:
      ctx.is_eager = old_is_eager
      ctx.mode = old_mode
      if mode == EAGER_MODE:
        self.context_switches.pop()

  def executing_eagerly(self):
    """Returns True if current thread has eager executing enabled."""
    return self._eager_context.is_eager

  def scalar_cache(self):
    """Per-device cache for scalars."""
    return self._eager_context.scalar_cache

  def ones_rank_cache(self):
    """Returns the current thread's `ones_rank_cache` (a `_TensorCache`)."""
    return self._eager_context.ones_rank_cache

  @property
  def scope_name(self):
    """Returns scope name for the current thread."""
    return self._eager_context.scope_name

  @scope_name.setter
  def scope_name(self, s):
    """Sets scope name for the current thread."""
    self._eager_context.scope_name = s

  @property
  def summary_writer_resource(self):
    """Returns summary writer resource."""
    return self._eager_context.summary_writer_resource

  @summary_writer_resource.setter
  def summary_writer_resource(self, resource):
    """Sets summary writer resource."""
    self._eager_context.summary_writer_resource = resource

  @property
  def device_name(self):
    """Returns the device name for the current thread."""
    return self._eager_context.device_name

  @property
  def device_spec(self):
    """Returns the device spec for the current thread."""
    return self._eager_context.device_spec

  @tf_contextlib.contextmanager
  def device(self, name):
    """Context-manager to force placement of operations and Tensors on a device.

    Args:
      name: Name of the device or None to get default placement.

    Yields:
      Nothing.

    Raises:
      ValueError: If name is not a string or is an invalid device name.
    """
    eager_context = self._eager_context
    old_device_name = eager_context.device_name
    old_device_spec = eager_context.device_spec
    cache_key = (old_device_name, name)
    try:
      new_device_name, new_device_spec = _device_parsing_cache[cache_key]
    except TypeError:
      # Error while trying to compute the cache key.
      raise ValueError("Expecting a string device name. Got %s(%s)" %
                       (type(name), name))
    except KeyError:
      # Handle a cache miss.
      if name is not None:
        if not isinstance(name, str):
          raise ValueError("Expecting a string device name. Got %s(%s)" %
                           (type(name), name))
        device_spec = pydev.DeviceSpec.from_string(name)
        if old_device_name:
          # Start from a copy of the enclosing device and overlay `name`.
          new_device_spec = copy.copy(old_device_spec)
        else:
          new_device_spec = pydev.DeviceSpec.from_string(
              "/job:localhost/replica:0/task:0/device:CPU:0")
        new_device_spec.merge_from(device_spec)
      else:
        new_device_spec = pydev.DeviceSpec.from_string("")
      new_device_name = new_device_spec.to_string()
      _device_parsing_cache[cache_key] = (new_device_name, new_device_spec)
    try:
      eager_context.device_name = new_device_name
      eager_context.device_spec = new_device_spec
      yield
    finally:
      eager_context.device_name = old_device_name
      eager_context.device_spec = old_device_spec

  def devices(self):
    """List of the names of devices available to execute operations."""
    return self._devices

  def get_execution_mode(self):
    """Returns the execution mode in effect for the current thread.

    Falls back to the mode configured on this Context when no thread-local
    mode has been set.
    """
    mode = self._eager_context.execution_mode
    if mode is None:
      mode = self._execution_mode
    return mode

  def set_execution_mode(self, mode):
    """Sets execution mode for current thread.

    Args:
      mode: One of None, `SYNC` or `ASYNC`; None is treated as `SYNC`.

    Raises:
      ValueError: If `mode` is not one of the accepted values.
    """
    if mode not in (None, SYNC, ASYNC):
      raise ValueError(
          "Execution mode should be None/SYNC/ASYNC. Got %s" % mode)
    if mode is None:
      mode = SYNC
    self._eager_context.execution_mode = mode
    pywrap_tensorflow.TFE_ContextSetAsyncForThread(self._handle, mode == ASYNC)

  @tf_contextlib.contextmanager
  def execution_mode(self, mode):
    """Context manager for setting execution mode for current thread."""
    old_mode = self.get_execution_mode()
    try:
      self.set_execution_mode(mode)
      yield
    finally:
      self.set_execution_mode(old_mode)

  def async_wait(self):
    """Waits for ops dispatched in ASYNC mode to finish."""
    pywrap_tensorflow.TFE_ContextAsyncWait(self._handle)

  def async_clear_error(self):
    """Clears errors raised during ASYNC execution."""
    pywrap_tensorflow.TFE_ContextAsyncClearError(self._handle)

  def num_gpus(self):
    """The number of GPUs available to execute operations."""
    # `_num_gpus` only exists once the handle and devices are initialized.
    self._initialize_handle_and_devices()
    return self._num_gpus

  def add_function(self, fn):
    """Add a function definition to the context.

    Once added, the function (identified by its name) can be executed like any
    other operation.

    Args:
      fn: A wrapped TF_Function (returned from TF_GraphToFunction_wrapper).
    """
    pywrap_tensorflow.TFE_ContextAddFunction(
        self._handle,  # pylint: disable=protected-access
        fn)

  def add_function_def(self, fdef):
    """Add a function definition to the context.

    Once added, the function (identified by its name) can be executed like any
    other operation.

    Args:
      fdef: A FunctionDef protocol buffer message.
    """
    fdef_string = fdef.SerializeToString()
    pywrap_tensorflow.TFE_ContextAddFunctionDef(
        self._handle,  # pylint: disable=protected-access
        fdef_string,
        len(fdef_string))

  def add_post_execution_callback(self, callback):
    """Add a post-execution callback to the context.

    A post-execution callback is invoked immediately after an eager operation or
    function has finished execution, providing access to the op's type, name
    input and output tensors. Multiple execution callbacks can be added, in
    which case the callbacks will be invoked in the order in which they are
    added.

    Args:
      callback: a callable of the signature
        `f(op_type, op_name, attrs, inputs, outputs)`.
        `op_type` is the type of the operation that was just executed (e.g.,
          `MatMul`).
        `op_name` is the name of the operation that was just executed. This
          name is set by the client who created the operation and can be `None`
          if it is unset.
        `attrs` contains the attributes of the operation as a `tuple` of
          alternating attribute names and attribute values.
        `inputs` is the `list` of input `Tensor`(s) to the op.
        `outputs` is the `list` of output `Tensor`(s) from the op.
        Return value(s) from the callback are ignored.
    """
    # TODO(cais): (b/64674139) Allow access to function-internal operations.
    self._post_execution_callbacks.append(callback)

  def clear_post_execution_callbacks(self):
    """Clear all post-execution callbacks added to the context."""
    # Cleared in place so existing references to the list observe the change.
    del self._post_execution_callbacks[:]

  @property
  def post_execution_callbacks(self):
    """Get the list of post-execution callbacks added to the context."""
    return self._post_execution_callbacks

  def enable_run_metadata(self):
    """Enables tracing of op execution via RunMetadata.

    To retrieve the accumulated metadata call context.export_run_metadata()
    and to stop tracing call context.disable_run_metadata().
    """
    pywrap_tensorflow.TFE_ContextEnableRunMetadata(self._handle)

  @tf_contextlib.contextmanager
  def device_policy(self, policy):
    """Context manager that sets the thread-local device placement policy.

    Args:
      policy: Device placement policy to apply inside the `with` block.

    Yields:
      Nothing. The policy read before entering is set again on exit.
    """
    handle = self._handle
    old = pywrap_tensorflow.TFE_ContextGetDevicePlacementPolicy(handle)
    pywrap_tensorflow.TFE_ContextSetThreadLocalDevicePlacementPolicy(
        handle, policy)
    try:
      yield
    finally:
      pywrap_tensorflow.TFE_ContextSetThreadLocalDevicePlacementPolicy(
          handle, old)

  def disable_run_metadata(self):
    """Disables tracing of op execution via RunMetadata.

    Does nothing when the context handle has not been created.
    """
    if not self._context_handle:
      return
    pywrap_tensorflow.TFE_ContextDisableRunMetadata(self._context_handle)

  def export_run_metadata(self):
    """Returns a RunMetadata proto with accumulated information.

    The returned protocol buffer contains information since the most recent call
    to either enable_run_metadata or export_run_metadata.

    Returns:
      A RunMetadata protocol buffer, or None if the context handle has not
      been created.
    """
    if not self._context_handle:
      return None
    with c_api_util.tf_buffer() as buffer_:
      pywrap_tensorflow.TFE_ContextExportRunMetadata(
          self._context_handle, buffer_)
      proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
    run_metadata = config_pb2.RunMetadata()
    run_metadata.ParseFromString(compat.as_bytes(proto_data))
    return run_metadata

  @property
  def context_switches(self):
    """Returns a stack of context switches."""
    return self._context_switches
# Process-wide singleton `Context`; created lazily by `_initialize_context`.
_context = None
# Held while `_initialize_context` creates the singleton.
_context_lock = threading.Lock()
def _initialize_context():
  """Creates the module-level `Context` singleton if it does not exist yet."""
  global _context
  with _context_lock:
    if _context is not None:
      # Another caller created it before we acquired the lock.
      return
    _context = Context()
def context():
  """Returns a singleton context object, creating it on first use."""
  current = _context
  if current is None:
    _initialize_context()
    current = _context
  return current
def context_safe():
  """Returns the singleton context, or None if it has not been created.

  Unlike `context()`, this never creates the context.
  """
  current = _context
  return current
# TODO(agarwal): remove this.
def get_default_context():
  """Same as context(): returns the singleton, creating it if needed."""
  return context()
def set_global_seed(seed):
  """Sets the eager mode seed on the singleton context.

  Args:
    seed: Value forwarded to `Context._set_global_seed`.
  """
  ctx = context()
  ctx._set_global_seed(seed)  # pylint: disable=protected-access
def global_seed():
  """Returns the eager mode seed stored on the singleton context."""
  ctx = context()
  return ctx._seed  # pylint: disable=protected-access
def internal_operation_seed():
  """Returns the operation seed generated based on global seed."""
  ctx = context()
  return ctx._internal_operation_seed()  # pylint: disable=protected-access
@tf_export("executing_eagerly")
def executing_eagerly():
  """Returns True if the current thread has eager execution enabled.

  Eager execution is typically enabled via @{tf.enable_eager_execution},
  but may also be enabled within the context of a Python function via
  tf.contrib.eager.py_func.
  """
  ctx = context()
  return ctx.executing_eagerly()
def in_eager_mode():
  """Use executing_eagerly() instead. This function will be removed."""
  eager = executing_eagerly()
  return eager
def graph_mode():
  """Context-manager to disable eager execution for the current thread."""
  ctx = context()
  return ctx._mode(GRAPH_MODE)  # pylint: disable=protected-access
def eager_mode():
  """Context-manager to enable eager execution for the current thread."""
  ctx = context()
  return ctx._mode(EAGER_MODE)  # pylint: disable=protected-access
# TODO(agarwal): get rid of this and use ops.name_scope instead.
@contextlib.contextmanager
def namescope(name):
  """ContextManager for creating hierarchical name scopes.

  Args:
    name: Scope component to enter. It is appended to the current scope name
      with a "/" separator, or used as-is when there is no current scope.

  Yields:
    Nothing. The previous scope name is restored on exit.
  """
  ctx = context()
  outer_scope = ctx.scope_name
  if outer_scope:
    ctx.scope_name = "%s/%s" % (outer_scope, name)
  else:
    ctx.scope_name = name
  try:
    yield
  finally:
    ctx.scope_name = outer_scope
def scope_name():
  """Name of the current scope."""
  ctx = context()
  return ctx.scope_name
def device(name):
  """Context-manager to force placement of operations and Tensors on a device.

  Example:
  ```python
  with tfe.device('gpu:0'):
    with tfe.device('cpu:0'):
      shape = tf.constant([], dtype=tf.int32)
    x = tf.truncated_normal(shape, tf.float32)
  ```
  will ensure that the `shape` Tensor is on CPU but the `truncated_normal`
  operation runs on GPU 0.

  Args:
    name: Name of the device (see context().devices()), or None to
      perform automatic placement.

  Returns:
    Context manager for setting the device.
  """
  ctx = context()
  return ctx.device(name)
def list_devices():
  """List the names of the available devices.

  Returns:
    Names of the available devices, as a `list`.
  """
  ctx = context()
  return ctx.devices()
def set_execution_mode(mode):
  """Sets execution mode for the current thread.

  Args:
    mode: Value forwarded to `Context.set_execution_mode`.
  """
  ctx = context()
  ctx.set_execution_mode(mode)
def execution_mode(mode):
  """Context manager for setting execution mode for current thread.

  Args:
    mode: Value forwarded to `Context.execution_mode`.

  Returns:
    The context manager produced by the singleton context.
  """
  ctx = context()
  return ctx.execution_mode(mode)
def async_wait():
  """Waits for ops dispatched in ASYNC mode to finish."""
  ctx = context()
  return ctx.async_wait()
def async_clear_error():
  """Clears errors raised during ASYNC execution mode."""
  ctx = context()
  return ctx.async_clear_error()
def num_gpus():
  """Get the number of available GPU devices.

  Returns:
    The number of available GPU devices.
  """
  ctx = context()
  return ctx.num_gpus()
def enable_run_metadata():
  """Enables tracing of op execution via RunMetadata.

  To retrieve the accumulated metadata call context.export_run_metadata()
  and to stop tracing call context.disable_run_metadata().
  """
  ctx = context()
  ctx.enable_run_metadata()
def disable_run_metadata():
  """Disables tracing of op execution via RunMetadata."""
  ctx = context()
  ctx.disable_run_metadata()
def export_run_metadata():
  """Returns a RunMetadata proto with accumulated information.

  The returned protocol buffer contains information since the most recent call
  to either enable_run_metadata or export_run_metadata.

  Returns:
    Whatever `Context.export_run_metadata` returns for the singleton context:
    a RunMetadata protocol buffer, or None.
  """
  ctx = context()
  return ctx.export_run_metadata()
# Not every user creates a Context via context.context()
# (for example, enable_eager_execution in python/framework/ops.py),
# but they do all import this file. Note that IS_IN_GRAPH_MODE and
# in_graph_mode are both parameterless functions.
def _tmp_in_graph_mode():
  """Returns True when the current thread is not executing eagerly."""
  eager = executing_eagerly()
  return not eager


# Register the parameterless predicate on the `is_in_graph_mode` module.
is_in_graph_mode.IS_IN_GRAPH_MODE = _tmp_in_graph_mode