laywerrobot/lib/python3.6/site-packages/tensorflow/contrib/autograph/utils/builtins.py
2020-08-27 21:55:39 +02:00

126 lines
4.1 KiB
Python

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Builtin conversion utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import six
from tensorflow.contrib.autograph.utils import py_func
from tensorflow.contrib.autograph.utils import type_check
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
def dynamic_builtin(f, *args, **kwargs):
  """Routes a call to a Python builtin to its dynamic replacement.

  Args:
    f: the builtin being called. It is compared by identity against the
      supported builtins: len, range, int, float and, under Python 2, xrange.
    *args: positional arguments forwarded unchanged to the replacement.
    **kwargs: keyword arguments forwarded unchanged to the replacement.

  Returns:
    Whatever the matching dynamic_* function returns.

  Raises:
    NotImplementedError: if `f` is none of the supported builtins.
  """
  # Ordered (builtin, replacement) pairs; matching is by identity, not equality.
  dispatch = [(len, dynamic_len)]
  if six.PY2:
    # xrange only exists on Python 2, so it is only named behind this guard.
    dispatch.append((xrange, dynamic_range))
  dispatch.extend([
      (range, dynamic_range),
      (int, dynamic_int),
      (float, dynamic_float),
  ])

  for builtin_fn, replacement in dispatch:
    if f is builtin_fn:
      return replacement(*args, **kwargs)

  raise NotImplementedError(
      'The "%s" builtin is not yet supported.' % f.__name__)
def dynamic_len(list_or_tensor):
  """Computes len() for either a tensor or a regular Python object.

  Args:
    list_or_tensor: the object whose length is requested.

  Returns:
    For tensors, the first element of `array_ops.shape(list_or_tensor)`.
    For anything else, the result of the builtin `len`.

  Raises:
    ValueError: if the argument is a tensor whose static `shape` is falsy.
  """
  # Non-tensors go straight to the builtin.
  if not tensor_util.is_tensor(list_or_tensor):
    return len(list_or_tensor)

  # NOTE(review): the error below talks about rank, but the check relies on
  # the truthiness of the static shape object - confirm that a rank-0 shape
  # actually evaluates as falsy, otherwise scalar tensors bypass this check.
  if list_or_tensor.shape:
    return array_ops.shape(list_or_tensor)[0]

  raise ValueError(
      'len requires non-zero rank for tensor "%s"' % list_or_tensor)
def dynamic_int(num_or_tensor, **kwargs):
  """Converts a value to an integer, handling tensors and Python values.

  Args:
    num_or_tensor: the value to convert.
    **kwargs: forwarded to `math_ops.cast` when the value is a tensor;
      ignored for plain Python values.

  Returns:
    The tensor cast to int32, or `int(num_or_tensor)` for non-tensors.
  """
  if not tensor_util.is_tensor(num_or_tensor):
    return int(num_or_tensor)
  return math_ops.cast(num_or_tensor, dtype=dtypes.int32, **kwargs)
def dynamic_float(num_or_tensor, **kwargs):
  """Converts a value to a float, handling tensors and Python values.

  Args:
    num_or_tensor: the value to convert.
    **kwargs: forwarded to `math_ops.cast` when the value is a tensor;
      ignored for plain Python values.

  Returns:
    The tensor cast to float32, or `float(num_or_tensor)` for non-tensors.
  """
  if not tensor_util.is_tensor(num_or_tensor):
    return float(num_or_tensor)
  return math_ops.cast(num_or_tensor, dtype=dtypes.float32, **kwargs)
def dynamic_range(start_or_stop, stop=None, step=None):
  """Builds a range, using the TF op when the arguments involve tensors.

  Arguments left as None are not passed on, mirroring the one-, two- and
  three-argument forms of the builtin `range`. A non-None `step` selects the
  three-argument form regardless of `stop`.

  Args:
    start_or_stop: the stop value when it is the only argument, otherwise
      the start value.
    stop: optional stop value.
    step: optional step value.

  Returns:
    The result of `math_ops.range` when `type_check.is_tensor` reports true
    for the arguments, otherwise the result of the builtin `range`.
  """
  # Work out which positional form is being requested.
  if step is not None:
    range_args = (start_or_stop, stop, step)
  elif stop is not None:
    range_args = (start_or_stop, stop)
  else:
    range_args = (start_or_stop,)

  if type_check.is_tensor(start_or_stop, stop, step):
    return math_ops.range(*range_args)
  return range(*range_args)
def is_tf_print_compatible(value):
  """Reports whether `value` may be printed through tf.Print.

  The check is currently switched off, so the answer is always False.

  Args:
    value: the candidate value; not inspected at present.

  Returns:
    False, unconditionally.
  """
  # TODO(mdan): Enable once we can reliably test this.
  # Disabled for now: the output of op kernels cannot be captured from
  # Python, so the tf.Print path cannot be tested.
  del value  # Intentionally unused while the check is disabled.
  return False
def dynamic_print(*values):
  """Implementation of print using dynamic dispatch.

  tf.Print is used when every value passes `is_tf_print_compatible`;
  otherwise the values are printed from Python through py_func.

  Args:
    *values: values to print

  Returns:
    The result of `logging_ops.Print` when all values are tf.Print
    compatible; otherwise the result of `py_func.wrap_py_func`, which is
    called with `use_dummy_return=True`.
  """
  if all(is_tf_print_compatible(candidate) for candidate in values):
    return logging_ops.Print(1, values)

  def print_wrapper(*vals):
    printable = vals
    if six.PY3:
      # TensorFlow doesn't seem to generate Unicode when passing strings to
      # py_func, so bytes are decoded here to keep the "b'" wrapper out of
      # the printed output.
      printable = tuple(
          item.decode() if isinstance(item, bytes) else item
          for item in vals)
    print(*printable)
    # Flushing helps avoid garbled output in IPython.
    sys.stdout.flush()

  return py_func.wrap_py_func(
      print_wrapper, None, values, use_dummy_return=True)