# laywerrobot/lib/python3.6/site-packages/tensorflow/python/ops/linalg_grad.py

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradients for operators defined in linalg_ops.py.
Useful reference for derivative formulas is
An extended collection of matrix derivative results for forward and reverse
mode algorithmic differentiation by Mike Giles:
http://eprints.maths.ox.ac.uk/1079/1/NA-08-01.pdf
A detailed derivation of formulas for backpropagating through spectral layers
(SVD and Eig) by Ionescu, Vantzos & Sminchisescu:
https://arxiv.org/pdf/1509.07838v4.pdf
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linalg_impl as _linalg
@ops.RegisterGradient("MatrixInverse")
def _MatrixInverseGrad(op, grad):
"""Gradient for MatrixInverse."""
ainv = op.outputs[0]
return -math_ops.matmul(
ainv, math_ops.matmul(grad, ainv, adjoint_b=True), adjoint_a=True)
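
# The rule above encodes the identity d(A^{-1}) = -A^{-1} dA A^{-1}. A minimal
# NumPy sketch follows (not part of TensorFlow and never invoked by it; the
# helper name is hypothetical and NumPy is assumed to be available): it checks
# that for L = sum(G * inv(A)) the gradient w.r.t. A is
# -inv(A)^T @ G @ inv(A)^T, i.e. the real-valued case of the formula above.
def _sketch_check_matrix_inverse_grad(n=4, eps=1e-6, seed=0):
  import numpy as np
  rng = np.random.RandomState(seed)
  a = rng.randn(n, n) + n * np.eye(n)   # well-conditioned test matrix
  g = rng.randn(n, n)                   # arbitrary upstream gradient
  ainv = np.linalg.inv(a)
  analytic = -ainv.T @ g @ ainv.T
  numeric = np.zeros_like(a)
  for i in range(n):
    for j in range(n):
      da = np.zeros_like(a)
      da[i, j] = eps
      l_plus = np.sum(g * np.linalg.inv(a + da))
      l_minus = np.sum(g * np.linalg.inv(a - da))
      numeric[i, j] = (l_plus - l_minus) / (2 * eps)
  return np.max(np.abs(analytic - numeric))  # should be close to zero
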
@ops.RegisterGradient("MatrixDeterminant")
def _MatrixDeterminantGrad(op, grad):
"""Gradient for MatrixDeterminant."""
a = op.inputs[0]
c = op.outputs[0]
a_adj_inv = linalg_ops.matrix_inverse(a, adjoint=True)
multipliers = array_ops.reshape(grad * c,
array_ops.concat([array_ops.shape(c), [1, 1]],
0))
return multipliers * a_adj_inv
@ops.RegisterGradient("LogMatrixDeterminant")
def _LogMatrixDeterminantGrad(op, _, grad_b):
"""Gradient for LogMatrixDeterminant."""
a = op.inputs[0]
c = op.outputs[1]
a_adj_inv = linalg_ops.matrix_inverse(a, adjoint=True)
multipliers = array_ops.reshape(
grad_b, array_ops.concat([array_ops.shape(c), [1, 1]], 0))
return multipliers * a_adj_inv
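
# Both rules above follow Jacobi's formula. A minimal NumPy sketch (not part
# of TensorFlow; the helper name is hypothetical) checks the underlying
# identities d det(A)/dA = det(A) * inv(A)^T and d log|det(A)|/dA = inv(A)^T
# against central finite differences.
def _sketch_check_determinant_grads(n=4, eps=1e-6, seed=0):
  import numpy as np
  rng = np.random.RandomState(seed)
  a = rng.randn(n, n) + n * np.eye(n)       # well-conditioned test matrix
  det_grad = np.linalg.det(a) * np.linalg.inv(a).T
  logdet_grad = np.linalg.inv(a).T
  num_det = np.zeros_like(a)
  num_logdet = np.zeros_like(a)
  for i in range(n):
    for j in range(n):
      da = np.zeros_like(a)
      da[i, j] = eps
      num_det[i, j] = (np.linalg.det(a + da) -
                       np.linalg.det(a - da)) / (2 * eps)
      num_logdet[i, j] = (np.linalg.slogdet(a + da)[1] -
                          np.linalg.slogdet(a - da)[1]) / (2 * eps)
  return (np.max(np.abs(det_grad - num_det)),
          np.max(np.abs(logdet_grad - num_logdet)))  # both should be small
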
@ops.RegisterGradient("Cholesky")
def _CholeskyGrad(op, grad):
"""Gradient for Cholesky."""
# Gradient is l^{-H} @ ((l^{H} @ grad) * (tril(ones)-1/2*eye)) @ l^{-1}
l = op.outputs[0]
num_rows = array_ops.shape(l)[-1]
batch_shape = array_ops.shape(l)[:-2]
l_inverse = linalg_ops.matrix_triangular_solve(l,
linalg_ops.eye(
num_rows,
batch_shape=batch_shape,
dtype=l.dtype))
middle = math_ops.matmul(l, grad, adjoint_a=True)
middle = array_ops.matrix_set_diag(middle,
0.5 * array_ops.matrix_diag_part(middle))
middle = array_ops.matrix_band_part(middle, -1, 0)
grad_a = math_ops.matmul(
math_ops.matmul(l_inverse, middle, adjoint_a=True), l_inverse)
grad_a += _linalg.adjoint(grad_a)
return grad_a * 0.5
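
# A minimal NumPy sketch of the formula in the comment above (not part of
# TensorFlow; the helper name is hypothetical). With Phi(X) = tril(X) with a
# halved diagonal, the symmetric gradient is 0.5 * (P + P^T) where
# P = L^{-T} Phi(L^T grad) L^{-1}; the check compares it with a directional
# finite difference along a random symmetric perturbation of A.
def _sketch_check_cholesky_grad(n=4, eps=1e-6, seed=0):
  import numpy as np
  rng = np.random.RandomState(seed)
  m = rng.randn(n, n)
  a = m @ m.T + n * np.eye(n)               # symmetric positive definite
  g = rng.randn(n, n)                       # upstream gradient w.r.t. chol(a)
  l = np.linalg.cholesky(a)
  middle = np.tril(l.T @ g)                 # Phi(L^T grad): lower triangle...
  middle -= 0.5 * np.diag(np.diag(middle))  # ...with the diagonal halved
  l_inv = np.linalg.inv(l)
  p = l_inv.T @ middle @ l_inv
  grad_sym = 0.5 * (p + p.T)                # symmetrized gradient w.r.t. A
  d = rng.randn(n, n)
  d = 0.5 * (d + d.T)                       # random symmetric direction
  l_plus = np.sum(g * np.linalg.cholesky(a + eps * d))
  l_minus = np.sum(g * np.linalg.cholesky(a - eps * d))
  numeric = (l_plus - l_minus) / (2 * eps)
  return abs(np.sum(grad_sym * d) - numeric)  # should be close to zero
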
@ops.RegisterGradient("Qr")
def _QrGrad(op, dq, dr):
"""Gradient for Qr."""
q, r = op.outputs
if q.dtype.is_complex:
raise NotImplementedError("QrGrad not implemented for dtype: %s" % q.dtype)
if (r.shape.ndims is None or r.shape.as_list()[-2] is None or
r.shape.as_list()[-1] is None):
raise NotImplementedError("QrGrad not implemented with dynamic shapes.")
if r.shape[-2].value != r.shape[-1].value:
raise NotImplementedError("QrGrad not implemented when ncols > nrows "
"or full_matrices is true and ncols != nrows.")
qdq = math_ops.matmul(q, dq, adjoint_a=True)
qdq_ = qdq - _linalg.adjoint(qdq)
rdr = math_ops.matmul(r, dr, adjoint_b=True)
rdr_ = rdr - _linalg.adjoint(rdr)
tril = array_ops.matrix_band_part(qdq_ + rdr_, -1, 0)
def _TriangularSolve(x, r):
"""Equiv to matmul(x, adjoint(matrix_inverse(r))) if r is upper-tri."""
return _linalg.adjoint(
linalg_ops.matrix_triangular_solve(
r, _linalg.adjoint(x), lower=False, adjoint=False))
grad_a = math_ops.matmul(q, dr + _TriangularSolve(tril, r))
grad_b = _TriangularSolve(dq - math_ops.matmul(q, qdq), r)
return grad_a + grad_b
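
# A small NumPy sketch of the equivalence claimed in _TriangularSolve's
# docstring (not part of TensorFlow; the helper name is hypothetical, and a
# generic dense solve stands in for the triangular solve since the identity
# is purely algebraic): adjoint(solve(r, adjoint(x))) == x @ adjoint(inv(r))
# for upper-triangular r, shown here for the real case.
def _sketch_check_triangular_solve_identity(p=4, m=3, seed=0):
  import numpy as np
  rng = np.random.RandomState(seed)
  x = rng.randn(p, m)
  r = np.triu(rng.randn(m, m)) + m * np.eye(m)  # well-conditioned upper-tri
  lhs = np.linalg.solve(r, x.T).T               # adjoint(solve(r, adjoint(x)))
  rhs = x @ np.linalg.inv(r).T                  # matmul(x, adjoint(inv(r)))
  return np.max(np.abs(lhs - rhs))              # should be ~machine precision
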
@ops.RegisterGradient("MatrixSolve")
def _MatrixSolveGrad(op, grad):
"""Gradient for MatrixSolve."""
a = op.inputs[0]
adjoint_a = op.get_attr("adjoint")
c = op.outputs[0]
grad_b = linalg_ops.matrix_solve(a, grad, adjoint=not adjoint_a)
if adjoint_a:
grad_a = -math_ops.matmul(c, grad_b, adjoint_b=True)
else:
grad_a = -math_ops.matmul(grad_b, c, adjoint_b=True)
return (grad_a, grad_b)
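
# A minimal NumPy sketch for the adjoint=False case above (not part of
# TensorFlow; the helper name is hypothetical): with X = A^{-1} B and
# upstream gradient G, it checks grad_B = A^{-T} G and grad_A = -grad_B X^T
# against directional finite differences.
def _sketch_check_matrix_solve_grad(n=4, k=3, eps=1e-6, seed=0):
  import numpy as np
  rng = np.random.RandomState(seed)
  a = rng.randn(n, n) + n * np.eye(n)
  b = rng.randn(n, k)
  g = rng.randn(n, k)                      # upstream gradient w.r.t. X
  x = np.linalg.solve(a, b)
  grad_b = np.linalg.solve(a.T, g)         # A^{-T} G
  grad_a = -grad_b @ x.T
  da = rng.randn(n, n)                     # random perturbation directions
  db = rng.randn(n, k)
  num_a = (np.sum(g * np.linalg.solve(a + eps * da, b)) -
           np.sum(g * np.linalg.solve(a - eps * da, b))) / (2 * eps)
  num_b = (np.sum(g * np.linalg.solve(a, b + eps * db)) -
           np.sum(g * np.linalg.solve(a, b - eps * db))) / (2 * eps)
  return (abs(np.sum(grad_a * da) - num_a),
          abs(np.sum(grad_b * db) - num_b))  # both should be close to zero
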
@ops.RegisterGradient("MatrixSolveLs")
def _MatrixSolveLsGrad(op, grad):
"""Gradients for MatrixSolveLs."""
# TODO(rmlarsen): The implementation could be more efficient:
# a) Output the Cholesky factorization from forward op instead of
# recomputing it here.
# b) Implement a symmetric rank-k update op instead of computing
# x*z + transpose(x*z). This pattern occurs other places in TensorFlow.
def _Overdetermined(op, grad):
"""Gradients for the overdetermined case of MatrixSolveLs.
This is the backprop for the solution to the normal equations of the first
kind:
X = F(A, B) = (A^T * A + lambda * I)^{-1} * A^T * B
    which solves the least squares problem
min ||A * X - B||_F^2 + lambda ||X||_F^2.
"""
a = op.inputs[0]
b = op.inputs[1]
x = op.outputs[0]
l2_regularizer = math_ops.cast(op.inputs[2], a.dtype.base_dtype)
# pylint: disable=protected-access
chol = linalg_ops._RegularizedGramianCholesky(
a, l2_regularizer=l2_regularizer, first_kind=True)
# pylint: enable=protected-access
# Temporary z = (A^T * A + lambda * I)^{-1} * grad.
z = linalg_ops.cholesky_solve(chol, grad)
xzt = math_ops.matmul(x, z, adjoint_b=True)
zx_sym = xzt + array_ops.matrix_transpose(xzt)
grad_a = -math_ops.matmul(a, zx_sym) + math_ops.matmul(b, z, adjoint_b=True)
grad_b = math_ops.matmul(a, z)
return (grad_a, grad_b, None)
def _Underdetermined(op, grad):
"""Gradients for the underdetermined case of MatrixSolveLs.
This is the backprop for the solution to the normal equations of the second
kind:
      X = F(A, B) = A^T * (A * A^T + lambda * I)^{-1} * B
    which (for lambda=0) solves the least squares problem
min ||X||_F subject to A*X = B.
"""
a = op.inputs[0]
b = op.inputs[1]
l2_regularizer = math_ops.cast(op.inputs[2], a.dtype.base_dtype)
# pylint: disable=protected-access
chol = linalg_ops._RegularizedGramianCholesky(
a, l2_regularizer=l2_regularizer, first_kind=False)
# pylint: enable=protected-access
grad_b = linalg_ops.cholesky_solve(chol, math_ops.matmul(a, grad))
# Temporary tmp = (A * A^T + lambda * I)^{-1} * B.
tmp = linalg_ops.cholesky_solve(chol, b)
a1 = math_ops.matmul(tmp, a, adjoint_a=True)
a1 = -math_ops.matmul(grad_b, a1)
a2 = grad - math_ops.matmul(a, grad_b, adjoint_a=True)
a2 = math_ops.matmul(tmp, a2, adjoint_b=True)
grad_a = a1 + a2
return (grad_a, grad_b, None)
fast = op.get_attr("fast")
if fast is False:
raise ValueError("Gradient not defined for fast=False")
matrix_shape = op.inputs[0].get_shape()[-2:]
if matrix_shape.is_fully_defined():
if matrix_shape[-2] >= matrix_shape[-1]:
return _Overdetermined(op, grad)
else:
return _Underdetermined(op, grad)
else:
# We have to defer determining the shape to runtime and use
# conditional execution of the appropriate graph.
matrix_shape = array_ops.shape(op.inputs[0])[-2:]
return control_flow_ops.cond(matrix_shape[-2] >= matrix_shape[-1],
lambda: _Overdetermined(op, grad),
lambda: _Underdetermined(op, grad))
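
# A minimal NumPy sketch of the two closed forms quoted in the docstrings
# above (not part of TensorFlow; the helper name is hypothetical). For
# lambda = 0, the overdetermined solution (A^T A)^{-1} A^T B and the
# minimum-norm underdetermined solution A^T (A A^T)^{-1} B both agree with
# np.linalg.lstsq on full-rank inputs.
def _sketch_check_solve_ls_forms(seed=0):
  import numpy as np
  rng = np.random.RandomState(seed)
  # Overdetermined: more rows than columns (normal equations of the 1st kind).
  a_tall, b_tall = rng.randn(6, 3), rng.randn(6, 2)
  x_normal = np.linalg.solve(a_tall.T @ a_tall, a_tall.T @ b_tall)
  err_over = np.max(np.abs(
      x_normal - np.linalg.lstsq(a_tall, b_tall, rcond=None)[0]))
  # Underdetermined: more columns than rows (normal equations of the 2nd kind).
  a_wide, b_wide = rng.randn(3, 6), rng.randn(3, 2)
  x_min_norm = a_wide.T @ np.linalg.solve(a_wide @ a_wide.T, b_wide)
  err_under = np.max(np.abs(
      x_min_norm - np.linalg.lstsq(a_wide, b_wide, rcond=None)[0]))
  return err_over, err_under               # both should be close to zero
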
@ops.RegisterGradient("MatrixTriangularSolve")
def _MatrixTriangularSolveGrad(op, grad):
"""Gradient for MatrixTriangularSolve."""
a = op.inputs[0]
adjoint_a = op.get_attr("adjoint")
lower_a = op.get_attr("lower")
c = op.outputs[0]
grad_b = linalg_ops.matrix_triangular_solve(
a, grad, lower=lower_a, adjoint=not adjoint_a)
if adjoint_a:
grad_a = -math_ops.matmul(c, grad_b, adjoint_b=True)
else:
grad_a = -math_ops.matmul(grad_b, c, adjoint_b=True)
if lower_a:
grad_a = array_ops.matrix_band_part(grad_a, -1, 0)
else:
grad_a = array_ops.matrix_band_part(grad_a, 0, -1)
return (grad_a, grad_b)
@ops.RegisterGradient("SelfAdjointEigV2")
def _SelfAdjointEigV2Grad(op, grad_e, grad_v):
"""Gradient for SelfAdjointEigV2."""
e = op.outputs[0]
compute_v = op.get_attr("compute_v")
# a = op.inputs[0], which satisfies
  # a[...,:,:] * v[...,:,i] = e[...,i] * v[...,:,i]
with ops.control_dependencies([grad_e, grad_v]):
if compute_v:
v = op.outputs[1]
      # Construct the matrix f(i, j) = (i != j ? 1 / (e_j - e_i) : 0),
      # matching the broadcasting below.
# Notice that because of the term involving f, the gradient becomes
# infinite (or NaN in practice) when eigenvalues are not unique.
# Mathematically this should not be surprising, since for (k-fold)
# degenerate eigenvalues, the corresponding eigenvectors are only defined
# up to arbitrary rotation in a (k-dimensional) subspace.
f = array_ops.matrix_set_diag(
math_ops.reciprocal(
array_ops.expand_dims(e, -2) - array_ops.expand_dims(e, -1)),
array_ops.zeros_like(e))
grad_a = math_ops.matmul(
v,
math_ops.matmul(
array_ops.matrix_diag(grad_e) +
f * math_ops.matmul(v, grad_v, adjoint_a=True),
v,
adjoint_b=True))
else:
_, v = linalg_ops.self_adjoint_eig(op.inputs[0])
grad_a = math_ops.matmul(v,
math_ops.matmul(
array_ops.matrix_diag(grad_e),
v,
adjoint_b=True))
# The forward op only depends on the lower triangular part of a, so here we
# symmetrize and take the lower triangle
grad_a = array_ops.matrix_band_part(grad_a + _linalg.adjoint(grad_a), -1, 0)
grad_a = array_ops.matrix_set_diag(grad_a,
0.5 * array_ops.matrix_diag_part(grad_a))
return grad_a
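
# A minimal NumPy sketch for the eigenvalue part of the gradient above (not
# part of TensorFlow; the helper name is hypothetical). For symmetric A with
# distinct eigenvalues and loss L = sum(grad_e * e), the gradient w.r.t. A is
# V @ diag(grad_e) @ V^T; the check uses a directional finite difference along
# a random symmetric perturbation. It also builds the matrix f used above,
# whose off-diagonal entries 1 / (e_j - e_i) blow up when eigenvalues (nearly)
# coincide, which is the degeneracy discussed in the comments.
def _sketch_check_self_adjoint_eig_grad(n=4, eps=1e-6, seed=0):
  import numpy as np
  rng = np.random.RandomState(seed)
  m = rng.randn(n, n)
  a = 0.5 * (m + m.T)                      # symmetric test matrix
  e, v = np.linalg.eigh(a)
  grad_e = rng.randn(n)
  analytic = v @ np.diag(grad_e) @ v.T     # gradient of sum(grad_e * e)
  d = rng.randn(n, n)
  d = 0.5 * (d + d.T)                      # random symmetric direction
  numeric = (np.sum(grad_e * np.linalg.eigvalsh(a + eps * d)) -
             np.sum(grad_e * np.linalg.eigvalsh(a - eps * d))) / (2 * eps)
  diff = e[None, :] - e[:, None] + np.eye(n)  # avoid division by 0 on diag
  f = (1.0 / diff) * (1.0 - np.eye(n))        # f(i,j) = 1/(e_j - e_i), 0 diag
  return abs(np.sum(analytic * d) - numeric), np.max(np.abs(f))
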
@ops.RegisterGradient("Svd")
def _SvdGrad(op, grad_s, grad_u, grad_v):
"""Gradient for the singular value decomposition."""
# The derivation for the compute_uv=False case, and most of
# the derivation for the full_matrices=True case, are in
# Giles' paper (see reference at top of file). A derivation for
# the full_matrices=False case is available at
# https://j-towns.github.io/papers/svd-derivative.pdf
a = op.inputs[0]
a_shape = a.get_shape().with_rank_at_least(2)
grad_s_mat = array_ops.matrix_diag(grad_s)
if not op.get_attr("compute_uv"):
s, u, v = linalg_ops.svd(a, compute_uv=True)
grad_a = math_ops.matmul(u, math_ops.matmul(grad_s_mat, v, adjoint_b=True))
grad_a.set_shape(a_shape)
return grad_a
full_matrices = op.get_attr("full_matrices")
# TODO(rmlarsen): Make this work with complex types.
if a.dtype.is_complex:
raise NotImplementedError(
"SVD gradient is not implemented for complex types and "
"compute_uv=True.")
grad_u_shape = grad_u.get_shape().with_rank_at_least(2)
grad_v_shape = grad_v.get_shape().with_rank_at_least(2)
m = a_shape[-2].merge_with(grad_u_shape[-2])
n = a_shape[-1].merge_with(grad_v_shape[-2])
batch_shape = a_shape[:-2].merge_with(grad_u_shape[:-2]).merge_with(
grad_v_shape[:-2])
a_shape = batch_shape.concatenate([m, n])
m = a_shape[-2].value
n = a_shape[-1].value
# TODO(rmlarsen): Make this work with placeholders.
if m is None or n is None:
raise NotImplementedError(
"SVD gradient has not been implemented for input with unknown "
"inner matrix shape.")
s = op.outputs[0]
u = op.outputs[1]
v = op.outputs[2]
use_adjoint = False
if m > n:
# Compute the gradient for A^H = V * S^T * U^H, and (implicitly) take the
# Hermitian transpose of the gradient at the end.
use_adjoint = True
m, n = n, m
u, v = v, u
grad_u, grad_v = grad_v, grad_u
with ops.control_dependencies([grad_s, grad_u, grad_v]):
if full_matrices and abs(m - n) > 1:
raise NotImplementedError(
"svd gradient is not implemented for abs(m - n) > 1 "
"when full_matrices is True")
s_mat = array_ops.matrix_diag(s)
s2 = math_ops.square(s)
# NOTICE: Because of the term involving f, the gradient becomes
# infinite (or NaN in practice) when singular values are not unique.
# Mathematically this should not be surprising, since for (k-fold)
# degenerate singular values, the corresponding singular vectors are
    # only defined up to a (k-dimensional) subspace. In practice, this can
# lead to numerical instability when singular values are close but not
# exactly equal.
f = array_ops.matrix_set_diag(
math_ops.reciprocal(
array_ops.expand_dims(s2, -2) - array_ops.expand_dims(s2, -1)),
array_ops.zeros_like(s))
s_inv_mat = array_ops.matrix_diag(math_ops.reciprocal(s))
v1 = v[..., :, :m]
grad_v1 = grad_v[..., :, :m]
u_gu = math_ops.matmul(u, grad_u, adjoint_a=True)
v_gv = math_ops.matmul(v1, grad_v1, adjoint_a=True)
f_u = f * u_gu
f_v = f * v_gv
term1_nouv = (
grad_s_mat + math_ops.matmul(f_u + _linalg.adjoint(f_u), s_mat) +
math_ops.matmul(s_mat, f_v + _linalg.adjoint(f_v)))
term1 = math_ops.matmul(u, math_ops.matmul(term1_nouv, v1, adjoint_b=True))
if m == n:
grad_a_before_transpose = term1
else:
gv1t = array_ops.matrix_transpose(grad_v1)
gv1t_v1 = math_ops.matmul(gv1t, v1)
term2_nous = gv1t - math_ops.matmul(gv1t_v1, v1, adjoint_b=True)
if full_matrices:
v2 = v[..., :, m:n]
grad_v2 = grad_v[..., :, m:n]
v1t_gv2 = math_ops.matmul(v1, grad_v2, adjoint_a=True)
term2_nous -= math_ops.matmul(v1t_gv2, v2, adjoint_b=True)
u_s_inv = math_ops.matmul(u, s_inv_mat)
term2 = math_ops.matmul(u_s_inv, term2_nous)
grad_a_before_transpose = term1 + term2
if use_adjoint:
grad_a = array_ops.matrix_transpose(grad_a_before_transpose)
else:
grad_a = grad_a_before_transpose
grad_a.set_shape(a_shape)
return grad_a
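
# A minimal NumPy sketch for the compute_uv=False branch above (not part of
# TensorFlow; the helper name is hypothetical). With s the singular values of
# A and loss L = sum(grad_s * s), the gradient w.r.t. A is
# U @ diag(grad_s) @ V^T from the thin SVD; the check uses a directional
# finite difference. Distinct singular values are assumed, in line with the
# degeneracy caveat in the comments above.
def _sketch_check_svd_grad(m=5, n=3, eps=1e-6, seed=0):
  import numpy as np
  rng = np.random.RandomState(seed)
  a = rng.randn(m, n)
  u, s, vt = np.linalg.svd(a, full_matrices=False)
  grad_s = rng.randn(n)
  analytic = u @ np.diag(grad_s) @ vt      # gradient of sum(grad_s * s)
  d = rng.randn(m, n)                      # random perturbation direction
  s_plus = np.linalg.svd(a + eps * d, compute_uv=False)
  s_minus = np.linalg.svd(a - eps * d, compute_uv=False)
  numeric = (np.sum(grad_s * s_plus) - np.sum(grad_s * s_minus)) / (2 * eps)
  return abs(np.sum(analytic * d) - numeric)  # should be close to zero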