130 lines
3.8 KiB
Python
130 lines
3.8 KiB
Python
|
from __future__ import division, print_function, absolute_import
|
||
|
|
||
|
import numpy as np
|
||
|
from scipy._lib._util import _asarray_validated
|
||
|
|
||
|
# Public API: the only name exported by ``from <this module> import *``.
__all__ = ["logsumexp"]
|
||
|
|
||
|
|
||
|
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
    """Return ``log(sum(b * exp(a)))`` evaluated without overflow.

    The maximum of `a` along the reduced axes is subtracted before
    exponentiating and added back after taking the logarithm, so very
    large or very small entries do not overflow or underflow on the way.

    Parameters
    ----------
    a : array_like
        Input array.
    axis : None or int or tuple of ints, optional
        Axis or axes to reduce over. The default, None, reduces over every
        element of `a`.

        .. versionadded:: 0.11.0
    b : array_like, optional
        Weights multiplying ``exp(a)``. Must have the same shape as `a` or
        be broadcastable against it. Negative weights are allowed, which
        makes it possible to express differences of exponentials.

        .. versionadded:: 0.12.0
    keepdims : bool, optional
        When True, the reduced axes are kept in the result with length
        one, so the result broadcasts against the original array.

        .. versionadded:: 0.15.0
    return_sign : bool, optional
        When True, return a ``(res, sgn)`` pair carrying the sign of the
        sum separately. When False (the default), only ``res`` is
        returned and a negative sum yields NaN.

        .. versionadded:: 0.16.0

    Returns
    -------
    res : ndarray
        ``np.log(np.sum(np.exp(a)))``, or ``np.log(np.sum(b*np.exp(a)))``
        when `b` is given, computed in a numerically more stable way.
    sgn : ndarray
        Only returned when `return_sign` is True: floating-point values
        matching `res`, equal to +1, 0 or -1 according to the sign of the
        sum.

    See Also
    --------
    numpy.logaddexp, numpy.logaddexp2

    Notes
    -----
    NumPy's `logaddexp` is closely related but takes exactly two
    arguments. `logaddexp.reduce` resembles this function, but may be
    less stable.

    Examples
    --------
    >>> from scipy.special import logsumexp
    >>> a = np.arange(10)
    >>> np.log(np.sum(np.exp(a)))
    9.4586297444267107
    >>> logsumexp(a)
    9.4586297444267107

    With weights

    >>> a = np.arange(10)
    >>> b = np.arange(10, 0, -1)
    >>> logsumexp(a, b=b)
    9.9170178533034665
    >>> np.log(np.sum(b*np.exp(a)))
    9.9170178533034647

    Returning a sign flag

    >>> logsumexp([1,2],b=[1,-1],return_sign=True)
    (1.5413248546129181, -1.0)

    Masked arrays are not supported directly. To reduce over a masked
    array, turn the mask into zero weights:

    >>> a = np.ma.array([np.log(2), 2, np.log(3)],
    ...                  mask=[False, True, False])
    >>> b = (~a.mask).astype(int)
    >>> logsumexp(a.data, b=b), np.log(5)
    (1.6094379124341005, 1.6094379124341005)

    """
    a = _asarray_validated(a, check_finite=False)
    weighted = b is not None

    if weighted:
        a, b = np.broadcast_arrays(a, b)
        zero_weight = (b == 0)
        if np.any(zero_weight):
            # ``a + 0.`` yields a fresh array of at least float dtype, so
            # -inf can be stored without writing into the caller's data.
            a = a + 0.
            # Entries with zero weight must not influence the maximum.
            a[zero_weight] = -np.inf

    # keepdims=True so the shift broadcasts against `a` in the subtraction.
    shift = np.amax(a, axis=axis, keepdims=True)

    # An infinite or NaN shift would make ``a - shift`` evaluate to NaN
    # (e.g. inf - inf), so fall back to no shift at all in that case.
    nonfinite = ~np.isfinite(shift)
    if shift.ndim == 0:
        if nonfinite:
            shift = 0
    else:
        shift[nonfinite] = 0

    scaled = np.exp(a - shift)
    if weighted:
        scaled = np.asarray(b) * scaled

    # log(0) is a legitimate outcome here (-inf); silence its warning.
    with np.errstate(divide='ignore'):
        total = np.sum(scaled, axis=axis, keepdims=keepdims)
        if return_sign:
            sign = np.sign(total)
            # Multiply rather than divide by the sign: a zero sum must
            # stay zero instead of turning into 0/0.
            total *= sign
        result = np.log(total)

    # Drop the kept axes of the shift when the caller did not ask for them,
    # then undo the shift in log space.
    if not keepdims:
        shift = np.squeeze(shift, axis=axis)
    result += shift

    return (result, sign) if return_sign else result
|