186 lines
6.8 KiB
Python
186 lines
6.8 KiB
Python
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# ==============================================================================
|
||
|
"""Gradients for operators defined in spectral_ops.py."""
|
||
|
from __future__ import absolute_import
|
||
|
from __future__ import division
|
||
|
from __future__ import print_function
|
||
|
|
||
|
import numpy as np
|
||
|
|
||
|
from tensorflow.python.framework import dtypes
|
||
|
from tensorflow.python.framework import ops
|
||
|
from tensorflow.python.ops import array_ops
|
||
|
from tensorflow.python.ops import math_ops
|
||
|
from tensorflow.python.ops import spectral_ops
|
||
|
|
||
|
|
||
|
def _FFTSizeForGrad(grad, rank):
  """Returns the product of the sizes of the innermost `rank` dims of `grad`.

  Args:
    grad: Tensor whose dynamic shape is inspected.
    rank: Python integer; how many trailing dimensions to include.

  Returns:
    A scalar tensor holding the product of the last `rank` entries of
    `array_ops.shape(grad)`.
  """
  inner_dims = array_ops.shape(grad)[-rank:]
  return math_ops.reduce_prod(inner_dims)
|
||
|
|
||
|
|
||
|
@ops.RegisterGradient("FFT")
def _FFTGrad(_, grad):
  """Gradient for FFT: `ifft(grad)` scaled by the innermost dimension size.

  Args:
    _: The forward op (unused).
    grad: Incoming gradient with respect to the FFT output.

  Returns:
    The gradient with respect to the FFT input.
  """
  num_points = _FFTSizeForGrad(grad, 1)
  # Cast the integer size to the gradient's dtype so the product is well-typed.
  scale = math_ops.cast(num_points, grad.dtype)
  inverse = spectral_ops.ifft(grad)
  return inverse * scale
|
||
|
|
||
|
|
||
|
@ops.RegisterGradient("IFFT")
def _IFFTGrad(_, grad):
  """Gradient for IFFT: `fft(grad)` divided by the innermost dimension size.

  Args:
    _: The forward op (unused).
    grad: Incoming gradient with respect to the IFFT output.

  Returns:
    The gradient with respect to the IFFT input.
  """
  complex_dtype = grad.dtype
  # Take the reciprocal in the matching real dtype first, then promote it to
  # the gradient's dtype for the final multiplication.
  num_points = math_ops.cast(
      _FFTSizeForGrad(grad, 1), complex_dtype.real_dtype)
  inverse_size = math_ops.cast(1. / num_points, complex_dtype)
  forward = spectral_ops.fft(grad)
  return forward * inverse_size
|
||
|
|
||
|
|
||
|
@ops.RegisterGradient("FFT2D")
def _FFT2DGrad(_, grad):
  """Gradient for FFT2D: `ifft2d(grad)` scaled by the two innermost dim sizes.

  Args:
    _: The forward op (unused).
    grad: Incoming gradient with respect to the FFT2D output.

  Returns:
    The gradient with respect to the FFT2D input.
  """
  num_points = _FFTSizeForGrad(grad, 2)
  # Cast the integer size to the gradient's dtype so the product is well-typed.
  scale = math_ops.cast(num_points, grad.dtype)
  inverse = spectral_ops.ifft2d(grad)
  return inverse * scale
|
||
|
|
||
|
|
||
|
@ops.RegisterGradient("IFFT2D")
def _IFFT2DGrad(_, grad):
  """Gradient for IFFT2D: `fft2d(grad)` divided by the two innermost dim sizes.

  Args:
    _: The forward op (unused).
    grad: Incoming gradient with respect to the IFFT2D output.

  Returns:
    The gradient with respect to the IFFT2D input.
  """
  complex_dtype = grad.dtype
  # Take the reciprocal in the matching real dtype first, then promote it to
  # the gradient's dtype for the final multiplication.
  num_points = math_ops.cast(
      _FFTSizeForGrad(grad, 2), complex_dtype.real_dtype)
  inverse_size = math_ops.cast(1. / num_points, complex_dtype)
  forward = spectral_ops.fft2d(grad)
  return forward * inverse_size
|
||
|
|
||
|
|
||
|
@ops.RegisterGradient("FFT3D")
def _FFT3DGrad(_, grad):
  """Gradient for FFT3D: `ifft3d(grad)` scaled by the three innermost dim sizes.

  Args:
    _: The forward op (unused).
    grad: Incoming gradient with respect to the FFT3D output.

  Returns:
    The gradient with respect to the FFT3D input.
  """
  num_points = _FFTSizeForGrad(grad, 3)
  # Cast the integer size to the gradient's dtype so the product is well-typed.
  scale = math_ops.cast(num_points, grad.dtype)
  inverse = spectral_ops.ifft3d(grad)
  return inverse * scale
|
||
|
|
||
|
|
||
|
@ops.RegisterGradient("IFFT3D")
def _IFFT3DGrad(_, grad):
  """Gradient for IFFT3D: `fft3d(grad)` divided by the three innermost dims.

  Args:
    _: The forward op (unused).
    grad: Incoming gradient with respect to the IFFT3D output.

  Returns:
    The gradient with respect to the IFFT3D input.
  """
  complex_dtype = grad.dtype
  # Take the reciprocal in the matching real dtype first, then promote it to
  # the gradient's dtype for the final multiplication.
  num_points = math_ops.cast(
      _FFTSizeForGrad(grad, 3), complex_dtype.real_dtype)
  inverse_size = math_ops.cast(1. / num_points, complex_dtype)
  forward = spectral_ops.fft3d(grad)
  return forward * inverse_size
|
||
|
|
||
|
|
||
|
def _RFFTGradHelper(rank, irfft_fn):
  """Returns a gradient function for an RFFT of the provided rank.

  Args:
    rank: Python integer, the number of innermost dimensions transformed.
      Must be 1 or 2.
    irfft_fn: Callable invoked as `irfft_fn(grad, fft_length)` to compute the
      inverse real FFT of the incoming gradient.

  Returns:
    A function `_Grad(op, grad)` that can be passed to `ops.RegisterGradient`.
    It returns a pair: the gradient for the RFFT input, and `None` for the
    `fft_length` input.
  """
  # Can't happen because we don't register a gradient for RFFT3D.
  assert rank in (1, 2), "Gradient for RFFT3D is not implemented."

  def _Grad(op, grad):
    """A gradient function for RFFT with the provided `rank` and `irfft_fn`."""
    fft_length = op.inputs[1]
    # Derive the working dtypes from the incoming gradient instead of
    # hard-coding complex64/float32, so that every intermediate agrees with
    # `grad` whatever its complex precision is.
    complex_dtype = grad.dtype
    real_dtype = complex_dtype.real_dtype
    input_shape = array_ops.shape(op.inputs[0])
    # 1 when the innermost FFT length is even, 0 when it is odd.
    is_even = math_ops.cast(1 - (fft_length[-1] % 2), complex_dtype)

    def _TileForBroadcasting(matrix, t):
      """Tiles `matrix` over the leading (all but the last two) dims of `t`."""
      # Prepend rank(t) - 2 singleton dimensions to `matrix` ...
      expanded = array_ops.reshape(
          matrix,
          array_ops.concat([
              array_ops.ones([array_ops.rank(t) - 2], dtypes.int32),
              array_ops.shape(matrix)
          ], 0))
      # ... then repeat it along those dimensions to match `t`.
      return array_ops.tile(
          expanded, array_ops.concat([array_ops.shape(t)[:-2], [1, 1]], 0))

    def _MaskMatrix(length):
      """Returns the [length, length] matrix exp(-2j * pi * j * k / length)."""
      # TODO(rjryan): Speed up computation of twiddle factors using the
      # following recurrence relation and cache them across invocations of RFFT.
      #
      # t_n = exp(sqrt(-1) * pi * n^2 / line_len)
      # for n = 0, 1,..., line_len-1.
      # For n > 2, use t_n = t_{n-1}^2 / t_{n-2} * t_1^2
      a = array_ops.tile(
          array_ops.expand_dims(math_ops.range(length), 0), (length, 1))
      b = array_ops.transpose(a, [1, 0])
      return math_ops.exp(-2j * np.pi * math_ops.cast(a * b, complex_dtype) /
                          math_ops.cast(length, complex_dtype))

    def _YMMask(length):
      """A sequence of [1+0j, -1+0j, 1+0j, -1+0j, ...] with length `length`."""
      return math_ops.cast(1 - 2 * (math_ops.range(length) % 2),
                           complex_dtype)

    # NOTE(review): the extra terms below are sized from the shape of the RFFT
    # input, while `irfft` further down is sized by `fft_length`; this looks
    # like it assumes the two agree -- confirm against callers.
    y0 = grad[..., 0:1]
    if rank == 1:
      ym = grad[..., -1:]
      extra_terms = y0 + is_even * ym * _YMMask(input_shape[-1])
    elif rank == 2:
      # Create a mask matrix for y0 and ym.
      base_mask = _MaskMatrix(input_shape[-2])

      # Tile base_mask to match y0 in shape so that we can batch-matmul the
      # inner 2 dimensions.
      tiled_mask = _TileForBroadcasting(base_mask, y0)

      y0_term = math_ops.matmul(tiled_mask, math_ops.conj(y0))
      extra_terms = y0_term

      ym = grad[..., -1:]
      ym_term = math_ops.matmul(tiled_mask, math_ops.conj(ym))

      # Repeat the single ym column across the innermost input dimension and
      # alternate its sign.
      inner_dim = input_shape[-1]
      ym_term = array_ops.tile(
          ym_term,
          array_ops.concat([
              array_ops.ones([array_ops.rank(grad) - 1], dtypes.int32),
              [inner_dim]
          ], 0)) * _YMMask(inner_dim)

      extra_terms += is_even * ym_term

    # The gradient of RFFT is the IRFFT of the incoming gradient times a scaling
    # factor, plus some additional terms to make up for the components dropped
    # due to Hermitian symmetry.
    input_size = math_ops.cast(_FFTSizeForGrad(op.inputs[0], rank), real_dtype)
    irfft = irfft_fn(grad, fft_length)
    return 0.5 * (irfft * input_size + math_ops.real(extra_terms)), None

  return _Grad
|
||
|
|
||
|
|
||
|
def _IRFFTGradHelper(rank, rfft_fn):
  """Returns a gradient function for an IRFFT of the provided rank.

  Args:
    rank: Python integer, the number of innermost dimensions transformed.
      Must be 1 or 2.
    rfft_fn: Callable invoked as `rfft_fn(grad, fft_length)` to compute the
      real FFT of the incoming gradient.

  Returns:
    A function `_Grad(op, grad)` that can be passed to `ops.RegisterGradient`.
    It returns a pair: the gradient for the IRFFT input, and `None` for the
    `fft_length` input.
  """
  # Can't happen because we don't register a gradient for IRFFT3D.
  assert rank in (1, 2), "Gradient for IRFFT3D is not implemented."

  def _Grad(op, grad):
    """A gradient function for IRFFT with the provided `rank` and `rfft_fn`."""
    fft_length = op.inputs[1]
    # Follow the precision of the incoming (real) gradient instead of
    # hard-coding float32/complex64. Anything other than float64 keeps the
    # previous complex64 behavior.
    real_dtype = grad.dtype
    if real_dtype == dtypes.float64:
      complex_dtype = dtypes.complex128
    else:
      complex_dtype = dtypes.complex64

    # Generate a simple mask like [1.0, 2.0, ..., 2.0, 1.0] for even-length FFTs
    # and [1.0, 2.0, ..., 2.0] for odd-length FFTs. The mask is assembled from
    # graph ops, so its length (the last dimension of the IRFFT input) and the
    # parity of the FFT length may be known only at run time.
    is_odd = math_ops.mod(fft_length[-1], 2)
    input_last_dimension = array_ops.shape(op.inputs[0])[-1]
    mask = array_ops.concat([
        array_ops.ones([1], real_dtype),
        2.0 * array_ops.ones([input_last_dimension - 2 + is_odd], real_dtype),
        array_ops.ones([1 - is_odd], real_dtype),
    ], 0)

    rsize = math_ops.reciprocal(
        math_ops.cast(_FFTSizeForGrad(grad, rank), real_dtype))

    # The gradient of IRFFT is the RFFT of the incoming gradient times a scaling
    # factor and a mask. The mask scales the gradient for the Hermitian
    # symmetric components of the RFFT by a factor of two, since these
    # components are de-duplicated in the RFFT.
    rfft = rfft_fn(grad, fft_length)
    return rfft * math_ops.cast(rsize * mask, complex_dtype), None

  return _Grad
|
||
|
|
||
|
|
||
|
# Register gradients for the real-valued transforms. Each gradient function is
# built from the opposite-direction transform of the same rank; only ranks 1
# and 2 are registered.
ops.RegisterGradient("RFFT")(
    _RFFTGradHelper(rank=1, irfft_fn=spectral_ops.irfft))
ops.RegisterGradient("IRFFT")(
    _IRFFTGradHelper(rank=1, rfft_fn=spectral_ops.rfft))
ops.RegisterGradient("RFFT2D")(
    _RFFTGradHelper(rank=2, irfft_fn=spectral_ops.irfft2d))
ops.RegisterGradient("IRFFT2D")(
    _IRFFTGradHelper(rank=2, rfft_fn=spectral_ops.rfft2d))
|