laywerrobot/lib/python3.6/site-packages/tensorflow/python/ops/spectral_grad.py

186 lines
6.8 KiB
Python
Raw Normal View History

2020-08-27 21:55:39 +02:00
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradients for operators defined in spectral_ops.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import spectral_ops
def _FFTSizeForGrad(grad, rank):
  """Returns the product of the trailing `rank` dimension sizes of `grad`.

  Args:
    grad: A `Tensor`.
    rank: Number of trailing dimensions whose sizes are multiplied together.

  Returns:
    A scalar `Tensor` with the same integer dtype as `array_ops.shape(grad)`.
  """
  trailing_dims = array_ops.shape(grad)[-rank:]
  return math_ops.reduce_prod(trailing_dims)
@ops.RegisterGradient("FFT")
def _FFTGrad(_, grad):
  """Gradient for FFT: IFFT of `grad` multiplied by the innermost dim size."""
  num_elements = _FFTSizeForGrad(grad, 1)
  scale = math_ops.cast(num_elements, grad.dtype)
  inverse = spectral_ops.ifft(grad)
  return inverse * scale
@ops.RegisterGradient("IFFT")
def _IFFTGrad(_, grad):
  """Gradient for IFFT: FFT of `grad` divided by the innermost dim size."""
  # Take the reciprocal in the real dtype, then promote it to grad's dtype.
  real_count = math_ops.cast(_FFTSizeForGrad(grad, 1), grad.dtype.real_dtype)
  inv_scale = math_ops.cast(1. / real_count, grad.dtype)
  forward = spectral_ops.fft(grad)
  return forward * inv_scale
@ops.RegisterGradient("FFT2D")
def _FFT2DGrad(_, grad):
  """Gradient for FFT2D: IFFT2D of `grad` times the inner 2-D element count."""
  num_elements = _FFTSizeForGrad(grad, 2)
  scale = math_ops.cast(num_elements, grad.dtype)
  inverse = spectral_ops.ifft2d(grad)
  return inverse * scale
@ops.RegisterGradient("IFFT2D")
def _IFFT2DGrad(_, grad):
  """Gradient for IFFT2D: FFT2D of `grad` divided by the inner 2-D count."""
  # Take the reciprocal in the real dtype, then promote it to grad's dtype.
  real_count = math_ops.cast(_FFTSizeForGrad(grad, 2), grad.dtype.real_dtype)
  inv_scale = math_ops.cast(1. / real_count, grad.dtype)
  forward = spectral_ops.fft2d(grad)
  return forward * inv_scale
@ops.RegisterGradient("FFT3D")
def _FFT3DGrad(_, grad):
  """Gradient for FFT3D: IFFT3D of `grad` times the inner 3-D element count."""
  num_elements = _FFTSizeForGrad(grad, 3)
  scale = math_ops.cast(num_elements, grad.dtype)
  inverse = spectral_ops.ifft3d(grad)
  return inverse * scale
@ops.RegisterGradient("IFFT3D")
def _IFFT3DGrad(_, grad):
  """Gradient for IFFT3D: FFT3D of `grad` divided by the inner 3-D count."""
  # Take the reciprocal in the real dtype, then promote it to grad's dtype.
  real_count = math_ops.cast(_FFTSizeForGrad(grad, 3), grad.dtype.real_dtype)
  inv_scale = math_ops.cast(1. / real_count, grad.dtype)
  forward = spectral_ops.fft3d(grad)
  return forward * inv_scale
def _RFFTGradHelper(rank, irfft_fn):
  """Returns a gradient function for an RFFT of the provided rank.

  Args:
    rank: Number of trailing dimensions transformed by the RFFT. Only 1 and 2
      are supported.
    irfft_fn: Inverse real FFT of the same rank, called as
      `irfft_fn(grad, fft_length)`.

  Returns:
    A function `_Grad(op, grad)` for use with `ops.RegisterGradient`. It
    returns `(input_grad, None)`, where `None` is the gradient for the
    `fft_length` input.
  """
  # Can't happen because we don't register a gradient for RFFT3D.
  assert rank in (1, 2), "Gradient for RFFT3D is not implemented."

  def _Grad(op, grad):
    """A gradient function for RFFT with the provided `rank` and `irfft_fn`."""
    fft_length = op.inputs[1]
    # Derive dtypes from the incoming (complex) gradient instead of hard-coding
    # complex64/float32, as the complex FFT gradients in this file do. For a
    # complex64 `grad` this is identical to the previous behavior.
    complex_dtype = grad.dtype
    real_dtype = complex_dtype.real_dtype
    input_shape = array_ops.shape(op.inputs[0])
    is_even = math_ops.cast(1 - (fft_length[-1] % 2), complex_dtype)

    def _TileForBroadcasting(matrix, t):
      """Tiles a 2-D `matrix` over the leading (batch) dimensions of `t`."""
      expanded = array_ops.reshape(
          matrix,
          array_ops.concat([
              array_ops.ones([array_ops.rank(t) - 2], dtypes.int32),
              array_ops.shape(matrix)
          ], 0))
      return array_ops.tile(
          expanded, array_ops.concat([array_ops.shape(t)[:-2], [1, 1]], 0))

    def _MaskMatrix(length):
      """Returns the `length` x `length` matrix exp(-2*pi*i*j*k / length)."""
      # TODO(rjryan): Speed up computation of twiddle factors using the
      # following recurrence relation and cache them across invocations of RFFT.
      #
      # t_n = exp(sqrt(-1) * pi * n^2 / line_len)
      # for n = 0, 1,..., line_len-1.
      # For n > 2, use t_n = t_{n-1}^2 / t_{n-2} * t_1^2
      a = array_ops.tile(
          array_ops.expand_dims(math_ops.range(length), 0), (length, 1))
      b = array_ops.transpose(a, [1, 0])
      # exp(-2*pi*i*k / length) is periodic in k with period `length`, so
      # reduce the integer exponent first. This keeps the phase in
      # [0, 2*pi) instead of up to 2*pi*(length-1), avoiding the precision
      # loss of feeding large products a*b through a float32-based complex
      # dtype.
      exponent = (a * b) % length
      return math_ops.exp(-2j * np.pi * math_ops.cast(exponent, complex_dtype) /
                          math_ops.cast(length, complex_dtype))

    def _YMMask(length):
      """A sequence of [1+0j, -1+0j, 1+0j, -1+0j, ...] with length `length`."""
      return math_ops.cast(1 - 2 * (math_ops.range(length) % 2),
                           complex_dtype)

    y0 = grad[..., 0:1]
    if rank == 1:
      ym = grad[..., -1:]
      extra_terms = y0 + is_even * ym * _YMMask(input_shape[-1])
    elif rank == 2:
      # Create a mask matrix for y0 and ym.
      base_mask = _MaskMatrix(input_shape[-2])

      # Tile base_mask to match y0 in shape so that we can batch-matmul the
      # inner 2 dimensions.
      tiled_mask = _TileForBroadcasting(base_mask, y0)

      y0_term = math_ops.matmul(tiled_mask, math_ops.conj(y0))
      extra_terms = y0_term

      ym = grad[..., -1:]
      ym_term = math_ops.matmul(tiled_mask, math_ops.conj(ym))

      inner_dim = input_shape[-1]
      ym_term = array_ops.tile(
          ym_term,
          array_ops.concat([
              array_ops.ones([array_ops.rank(grad) - 1], dtypes.int32),
              [inner_dim]
          ], 0)) * _YMMask(inner_dim)

      extra_terms += is_even * ym_term

    # The gradient of RFFT is the IRFFT of the incoming gradient times a scaling
    # factor, plus some additional terms to make up for the components dropped
    # due to Hermitian symmetry.
    input_size = math_ops.cast(
        _FFTSizeForGrad(op.inputs[0], rank), real_dtype)
    irfft = irfft_fn(grad, fft_length)
    return 0.5 * (irfft * input_size + math_ops.real(extra_terms)), None

  return _Grad
def _IRFFTGradHelper(rank, rfft_fn):
  """Returns a gradient function for an IRFFT of the provided rank.

  Args:
    rank: Number of trailing dimensions transformed by the IRFFT. Only 1 and 2
      are supported.
    rfft_fn: Real FFT of the same rank, called as `rfft_fn(grad, fft_length)`.

  Returns:
    A function `_Grad(op, grad)` for use with `ops.RegisterGradient`. It
    returns `(input_grad, None)`, where `None` is the gradient for the
    `fft_length` input.
  """
  # Can't happen because we don't register a gradient for IRFFT3D.
  assert rank in (1, 2), "Gradient for IRFFT3D is not implemented."

  def _Grad(op, grad):
    """A gradient function for IRFFT with the provided `rank` and `rfft_fn`."""
    # Generate a simple mask like [1.0, 2.0, ..., 2.0, 1.0] for even-length FFTs
    # and [1.0, 2.0, ..., 2.0] for odd-length FFTs. The mask is always built
    # with graph ops from the runtime FFT length and input shape; there is no
    # special case for values that are known at graph construction time.
    fft_length = op.inputs[1]
    # `grad` is real-valued (IRFFT's output). Derive matching real/complex
    # dtypes from it instead of hard-coding float32/complex64; float32 keeps
    # the previous behavior exactly.
    real_dtype = grad.dtype
    complex_dtype = (dtypes.complex128 if real_dtype == dtypes.float64
                     else dtypes.complex64)
    is_odd = math_ops.mod(fft_length[-1], 2)
    input_last_dimension = array_ops.shape(op.inputs[0])[-1]
    # Length is 1 + (last_dim - 2 + is_odd) + (1 - is_odd) == last_dim.
    mask = array_ops.concat(
        [[1.0],
         2.0 * array_ops.ones([input_last_dimension - 2 + is_odd], real_dtype),
         array_ops.ones([1 - is_odd], real_dtype)], 0)

    rsize = math_ops.reciprocal(
        math_ops.cast(_FFTSizeForGrad(grad, rank), real_dtype))

    # The gradient of IRFFT is the RFFT of the incoming gradient times a scaling
    # factor and a mask. The mask scales the gradient for the Hermitian
    # symmetric components of the RFFT by a factor of two, since these
    # components are de-duplicated in the RFFT.
    rfft = rfft_fn(grad, fft_length)
    return rfft * math_ops.cast(rsize * mask, complex_dtype), None

  return _Grad
# Register real-FFT gradients for ranks 1 and 2 only. The helpers assert on
# any other rank, so there are no RFFT3D/IRFFT3D gradients.
for _op_type, _grad_fn in (
    ("RFFT", _RFFTGradHelper(1, spectral_ops.irfft)),
    ("IRFFT", _IRFFTGradHelper(1, spectral_ops.rfft)),
    ("RFFT2D", _RFFTGradHelper(2, spectral_ops.irfft2d)),
    ("IRFFT2D", _IRFFTGradHelper(2, spectral_ops.rfft2d)),
):
  ops.RegisterGradient(_op_type)(_grad_fn)
# Keep the loop temporaries out of the module namespace.
del _op_type, _grad_fn