_logsumexp.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. from __future__ import division, print_function, absolute_import
  2. import numpy as np
  3. from scipy._lib._util import _asarray_validated
  4. __all__ = ["logsumexp", "softmax"]
  5. def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
  6. """Compute the log of the sum of exponentials of input elements.
  7. Parameters
  8. ----------
  9. a : array_like
  10. Input array.
  11. axis : None or int or tuple of ints, optional
  12. Axis or axes over which the sum is taken. By default `axis` is None,
  13. and all elements are summed.
  14. .. versionadded:: 0.11.0
  15. keepdims : bool, optional
  16. If this is set to True, the axes which are reduced are left in the
  17. result as dimensions with size one. With this option, the result
  18. will broadcast correctly against the original array.
  19. .. versionadded:: 0.15.0
  20. b : array-like, optional
  21. Scaling factor for exp(`a`) must be of the same shape as `a` or
  22. broadcastable to `a`. These values may be negative in order to
  23. implement subtraction.
  24. .. versionadded:: 0.12.0
  25. return_sign : bool, optional
  26. If this is set to True, the result will be a pair containing sign
  27. information; if False, results that are negative will be returned
  28. as NaN. Default is False (no sign information).
  29. .. versionadded:: 0.16.0
  30. Returns
  31. -------
  32. res : ndarray
  33. The result, ``np.log(np.sum(np.exp(a)))`` calculated in a numerically
  34. more stable way. If `b` is given then ``np.log(np.sum(b*np.exp(a)))``
  35. is returned.
  36. sgn : ndarray
  37. If return_sign is True, this will be an array of floating-point
  38. numbers matching res and +1, 0, or -1 depending on the sign
  39. of the result. If False, only one result is returned.
  40. See Also
  41. --------
  42. numpy.logaddexp, numpy.logaddexp2
  43. Notes
  44. -----
  45. Numpy has a logaddexp function which is very similar to `logsumexp`, but
  46. only handles two arguments. `logaddexp.reduce` is similar to this
  47. function, but may be less stable.
  48. Examples
  49. --------
  50. >>> from scipy.special import logsumexp
  51. >>> a = np.arange(10)
  52. >>> np.log(np.sum(np.exp(a)))
  53. 9.4586297444267107
  54. >>> logsumexp(a)
  55. 9.4586297444267107
  56. With weights
  57. >>> a = np.arange(10)
  58. >>> b = np.arange(10, 0, -1)
  59. >>> logsumexp(a, b=b)
  60. 9.9170178533034665
  61. >>> np.log(np.sum(b*np.exp(a)))
  62. 9.9170178533034647
  63. Returning a sign flag
  64. >>> logsumexp([1,2],b=[1,-1],return_sign=True)
  65. (1.5413248546129181, -1.0)
  66. Notice that `logsumexp` does not directly support masked arrays. To use it
  67. on a masked array, convert the mask into zero weights:
  68. >>> a = np.ma.array([np.log(2), 2, np.log(3)],
  69. ... mask=[False, True, False])
  70. >>> b = (~a.mask).astype(int)
  71. >>> logsumexp(a.data, b=b), np.log(5)
  72. 1.6094379124341005, 1.6094379124341005
  73. """
  74. a = _asarray_validated(a, check_finite=False)
  75. if b is not None:
  76. a, b = np.broadcast_arrays(a, b)
  77. if np.any(b == 0):
  78. a = a + 0. # promote to at least float
  79. a[b == 0] = -np.inf
  80. a_max = np.amax(a, axis=axis, keepdims=True)
  81. if a_max.ndim > 0:
  82. a_max[~np.isfinite(a_max)] = 0
  83. elif not np.isfinite(a_max):
  84. a_max = 0
  85. if b is not None:
  86. b = np.asarray(b)
  87. tmp = b * np.exp(a - a_max)
  88. else:
  89. tmp = np.exp(a - a_max)
  90. # suppress warnings about log of zero
  91. with np.errstate(divide='ignore'):
  92. s = np.sum(tmp, axis=axis, keepdims=keepdims)
  93. if return_sign:
  94. sgn = np.sign(s)
  95. s *= sgn # /= makes more sense but we need zero -> zero
  96. out = np.log(s)
  97. if not keepdims:
  98. a_max = np.squeeze(a_max, axis=axis)
  99. out += a_max
  100. if return_sign:
  101. return out, sgn
  102. else:
  103. return out
  104. def softmax(x, axis=None):
  105. r"""
  106. Softmax function
  107. The softmax function transforms each element of a collection by
  108. computing the exponential of each element divided by the sum of the
  109. exponentials of all the elements. That is, if `x` is a one-dimensional
  110. numpy array::
  111. softmax(x) = np.exp(x)/sum(np.exp(x))
  112. Parameters
  113. ----------
  114. x : array_like
  115. Input array.
  116. axis : int or tuple of ints, optional
  117. Axis to compute values along. Default is None and softmax will be
  118. computed over the entire array `x`.
  119. Returns
  120. -------
  121. s : ndarray
  122. An array the same shape as `x`. The result will sum to 1 along the
  123. specified axis.
  124. Notes
  125. -----
  126. The formula for the softmax function :math:`\sigma(x)` for a vector
  127. :math:`x = \{x_0, x_1, ..., x_{n-1}\}` is
  128. .. math:: \sigma(x)_j = \frac{e^{x_j}}{\sum_k e^{x_k}}
  129. The `softmax` function is the gradient of `logsumexp`.
  130. .. versionadded:: 1.2.0
  131. Examples
  132. --------
  133. >>> from scipy.special import softmax
  134. >>> np.set_printoptions(precision=5)
  135. >>> x = np.array([[1, 0.5, 0.2, 3],
  136. ... [1, -1, 7, 3],
  137. ... [2, 12, 13, 3]])
  138. ...
  139. Compute the softmax transformation over the entire array.
  140. >>> m = softmax(x)
  141. >>> m
  142. array([[ 4.48309e-06, 2.71913e-06, 2.01438e-06, 3.31258e-05],
  143. [ 4.48309e-06, 6.06720e-07, 1.80861e-03, 3.31258e-05],
  144. [ 1.21863e-05, 2.68421e-01, 7.29644e-01, 3.31258e-05]])
  145. >>> m.sum()
  146. 1.0000000000000002
  147. Compute the softmax transformation along the first axis (i.e. the columns).
  148. >>> m = softmax(x, axis=0)
  149. >>> m
  150. array([[ 2.11942e-01, 1.01300e-05, 2.75394e-06, 3.33333e-01],
  151. [ 2.11942e-01, 2.26030e-06, 2.47262e-03, 3.33333e-01],
  152. [ 5.76117e-01, 9.99988e-01, 9.97525e-01, 3.33333e-01]])
  153. >>> m.sum(axis=0)
  154. array([ 1., 1., 1., 1.])
  155. Compute the softmax transformation along the second axis (i.e. the rows).
  156. >>> m = softmax(x, axis=1)
  157. >>> m
  158. array([[ 1.05877e-01, 6.42177e-02, 4.75736e-02, 7.82332e-01],
  159. [ 2.42746e-03, 3.28521e-04, 9.79307e-01, 1.79366e-02],
  160. [ 1.22094e-05, 2.68929e-01, 7.31025e-01, 3.31885e-05]])
  161. >>> m.sum(axis=1)
  162. array([ 1., 1., 1.])
  163. """
  164. # compute in log space for numerical stability
  165. return np.exp(x - logsumexp(x, axis=axis, keepdims=True))