test_linesearch.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283
  1. """
  2. Tests for line search routines
  3. """
  4. from __future__ import division, print_function, absolute_import
  5. from numpy.testing import assert_, assert_equal, \
  6. assert_array_almost_equal, assert_array_almost_equal_nulp, assert_warns
  7. from scipy._lib._numpy_compat import suppress_warnings
  8. import scipy.optimize.linesearch as ls
  9. from scipy.optimize.linesearch import LineSearchWarning
  10. import numpy as np
  11. def assert_wolfe(s, phi, derphi, c1=1e-4, c2=0.9, err_msg=""):
  12. """
  13. Check that strong Wolfe conditions apply
  14. """
  15. phi1 = phi(s)
  16. phi0 = phi(0)
  17. derphi0 = derphi(0)
  18. derphi1 = derphi(s)
  19. msg = "s = %s; phi(0) = %s; phi(s) = %s; phi'(0) = %s; phi'(s) = %s; %s" % (
  20. s, phi0, phi1, derphi0, derphi1, err_msg)
  21. assert_(phi1 <= phi0 + c1*s*derphi0, "Wolfe 1 failed: " + msg)
  22. assert_(abs(derphi1) <= abs(c2*derphi0), "Wolfe 2 failed: " + msg)
  23. def assert_armijo(s, phi, c1=1e-4, err_msg=""):
  24. """
  25. Check that Armijo condition applies
  26. """
  27. phi1 = phi(s)
  28. phi0 = phi(0)
  29. msg = "s = %s; phi(0) = %s; phi(s) = %s; %s" % (s, phi0, phi1, err_msg)
  30. assert_(phi1 <= (1 - c1*s)*phi0, msg)
  31. def assert_line_wolfe(x, p, s, f, fprime, **kw):
  32. assert_wolfe(s, phi=lambda sp: f(x + p*sp),
  33. derphi=lambda sp: np.dot(fprime(x + p*sp), p), **kw)
  34. def assert_line_armijo(x, p, s, f, **kw):
  35. assert_armijo(s, phi=lambda sp: f(x + p*sp), **kw)
  36. def assert_fp_equal(x, y, err_msg="", nulp=50):
  37. """Assert two arrays are equal, up to some floating-point rounding error"""
  38. try:
  39. assert_array_almost_equal_nulp(x, y, nulp)
  40. except AssertionError as e:
  41. raise AssertionError("%s\n%s" % (e, err_msg))
  42. class TestLineSearch(object):
  43. # -- scalar functions; must have dphi(0.) < 0
  44. def _scalar_func_1(self, s):
  45. self.fcount += 1
  46. p = -s - s**3 + s**4
  47. dp = -1 - 3*s**2 + 4*s**3
  48. return p, dp
  49. def _scalar_func_2(self, s):
  50. self.fcount += 1
  51. p = np.exp(-4*s) + s**2
  52. dp = -4*np.exp(-4*s) + 2*s
  53. return p, dp
  54. def _scalar_func_3(self, s):
  55. self.fcount += 1
  56. p = -np.sin(10*s)
  57. dp = -10*np.cos(10*s)
  58. return p, dp
  59. # -- n-d functions
  60. def _line_func_1(self, x):
  61. self.fcount += 1
  62. f = np.dot(x, x)
  63. df = 2*x
  64. return f, df
  65. def _line_func_2(self, x):
  66. self.fcount += 1
  67. f = np.dot(x, np.dot(self.A, x)) + 1
  68. df = np.dot(self.A + self.A.T, x)
  69. return f, df
  70. # --
  71. def setup_method(self):
  72. self.scalar_funcs = []
  73. self.line_funcs = []
  74. self.N = 20
  75. self.fcount = 0
  76. def bind_index(func, idx):
  77. # Remember Python's closure semantics!
  78. return lambda *a, **kw: func(*a, **kw)[idx]
  79. for name in sorted(dir(self)):
  80. if name.startswith('_scalar_func_'):
  81. value = getattr(self, name)
  82. self.scalar_funcs.append(
  83. (name, bind_index(value, 0), bind_index(value, 1)))
  84. elif name.startswith('_line_func_'):
  85. value = getattr(self, name)
  86. self.line_funcs.append(
  87. (name, bind_index(value, 0), bind_index(value, 1)))
  88. np.random.seed(1234)
  89. self.A = np.random.randn(self.N, self.N)
  90. def scalar_iter(self):
  91. for name, phi, derphi in self.scalar_funcs:
  92. for old_phi0 in np.random.randn(3):
  93. yield name, phi, derphi, old_phi0
  94. def line_iter(self):
  95. for name, f, fprime in self.line_funcs:
  96. k = 0
  97. while k < 9:
  98. x = np.random.randn(self.N)
  99. p = np.random.randn(self.N)
  100. if np.dot(p, fprime(x)) >= 0:
  101. # always pick a descent direction
  102. continue
  103. k += 1
  104. old_fv = float(np.random.randn())
  105. yield name, f, fprime, x, p, old_fv
  106. # -- Generic scalar searches
  107. def test_scalar_search_wolfe1(self):
  108. c = 0
  109. for name, phi, derphi, old_phi0 in self.scalar_iter():
  110. c += 1
  111. s, phi1, phi0 = ls.scalar_search_wolfe1(phi, derphi, phi(0),
  112. old_phi0, derphi(0))
  113. assert_fp_equal(phi0, phi(0), name)
  114. assert_fp_equal(phi1, phi(s), name)
  115. assert_wolfe(s, phi, derphi, err_msg=name)
  116. assert_(c > 3) # check that the iterator really works...
  117. def test_scalar_search_wolfe2(self):
  118. for name, phi, derphi, old_phi0 in self.scalar_iter():
  119. s, phi1, phi0, derphi1 = ls.scalar_search_wolfe2(
  120. phi, derphi, phi(0), old_phi0, derphi(0))
  121. assert_fp_equal(phi0, phi(0), name)
  122. assert_fp_equal(phi1, phi(s), name)
  123. if derphi1 is not None:
  124. assert_fp_equal(derphi1, derphi(s), name)
  125. assert_wolfe(s, phi, derphi, err_msg="%s %g" % (name, old_phi0))
  126. def test_scalar_search_armijo(self):
  127. for name, phi, derphi, old_phi0 in self.scalar_iter():
  128. s, phi1 = ls.scalar_search_armijo(phi, phi(0), derphi(0))
  129. assert_fp_equal(phi1, phi(s), name)
  130. assert_armijo(s, phi, err_msg="%s %g" % (name, old_phi0))
  131. # -- Generic line searches
  132. def test_line_search_wolfe1(self):
  133. c = 0
  134. smax = 100
  135. for name, f, fprime, x, p, old_f in self.line_iter():
  136. f0 = f(x)
  137. g0 = fprime(x)
  138. self.fcount = 0
  139. s, fc, gc, fv, ofv, gv = ls.line_search_wolfe1(f, fprime, x, p,
  140. g0, f0, old_f,
  141. amax=smax)
  142. assert_equal(self.fcount, fc+gc)
  143. assert_fp_equal(ofv, f(x))
  144. if s is None:
  145. continue
  146. assert_fp_equal(fv, f(x + s*p))
  147. assert_array_almost_equal(gv, fprime(x + s*p), decimal=14)
  148. if s < smax:
  149. c += 1
  150. assert_line_wolfe(x, p, s, f, fprime, err_msg=name)
  151. assert_(c > 3) # check that the iterator really works...
  152. def test_line_search_wolfe2(self):
  153. c = 0
  154. smax = 512
  155. for name, f, fprime, x, p, old_f in self.line_iter():
  156. f0 = f(x)
  157. g0 = fprime(x)
  158. self.fcount = 0
  159. with suppress_warnings() as sup:
  160. sup.filter(LineSearchWarning,
  161. "The line search algorithm could not find a solution")
  162. sup.filter(LineSearchWarning,
  163. "The line search algorithm did not converge")
  164. s, fc, gc, fv, ofv, gv = ls.line_search_wolfe2(f, fprime, x, p,
  165. g0, f0, old_f,
  166. amax=smax)
  167. assert_equal(self.fcount, fc+gc)
  168. assert_fp_equal(ofv, f(x))
  169. assert_fp_equal(fv, f(x + s*p))
  170. if gv is not None:
  171. assert_array_almost_equal(gv, fprime(x + s*p), decimal=14)
  172. if s < smax:
  173. c += 1
  174. assert_line_wolfe(x, p, s, f, fprime, err_msg=name)
  175. assert_(c > 3) # check that the iterator really works...
  176. def test_line_search_wolfe2_bounds(self):
  177. # See gh-7475
  178. # For this f and p, starting at a point on axis 0, the strong Wolfe
  179. # condition 2 is met if and only if the step length s satisfies
  180. # |x + s| <= c2 * |x|
  181. f = lambda x: np.dot(x, x)
  182. fp = lambda x: 2 * x
  183. p = np.array([1, 0])
  184. # Smallest s satisfying strong Wolfe conditions for these arguments is 30
  185. x = -60 * p
  186. c2 = 0.5
  187. s, _, _, _, _, _ = ls.line_search_wolfe2(f, fp, x, p, amax=30, c2=c2)
  188. assert_line_wolfe(x, p, s, f, fp)
  189. s, _, _, _, _, _ = assert_warns(LineSearchWarning,
  190. ls.line_search_wolfe2, f, fp, x, p,
  191. amax=29, c2=c2)
  192. assert_(s is None)
  193. # s=30 will only be tried on the 6th iteration, so this won't converge
  194. assert_warns(LineSearchWarning, ls.line_search_wolfe2, f, fp, x, p,
  195. c2=c2, maxiter=5)
  196. def test_line_search_armijo(self):
  197. c = 0
  198. for name, f, fprime, x, p, old_f in self.line_iter():
  199. f0 = f(x)
  200. g0 = fprime(x)
  201. self.fcount = 0
  202. s, fc, fv = ls.line_search_armijo(f, x, p, g0, f0)
  203. c += 1
  204. assert_equal(self.fcount, fc)
  205. assert_fp_equal(fv, f(x + s*p))
  206. assert_line_armijo(x, p, s, f, err_msg=name)
  207. assert_(c >= 9)
  208. # -- More specific tests
  209. def test_armijo_terminate_1(self):
  210. # Armijo should evaluate the function only once if the trial step
  211. # is already suitable
  212. count = [0]
  213. def phi(s):
  214. count[0] += 1
  215. return -s + 0.01*s**2
  216. s, phi1 = ls.scalar_search_armijo(phi, phi(0), -1, alpha0=1)
  217. assert_equal(s, 1)
  218. assert_equal(count[0], 2)
  219. assert_armijo(s, phi)
  220. def test_wolfe_terminate(self):
  221. # wolfe1 and wolfe2 should also evaluate the function only a few
  222. # times if the trial step is already suitable
  223. def phi(s):
  224. count[0] += 1
  225. return -s + 0.05*s**2
  226. def derphi(s):
  227. count[0] += 1
  228. return -1 + 0.05*2*s
  229. for func in [ls.scalar_search_wolfe1, ls.scalar_search_wolfe2]:
  230. count = [0]
  231. r = func(phi, derphi, phi(0), None, derphi(0))
  232. assert_(r[0] is not None, (r, func))
  233. assert_(count[0] <= 2 + 2, (count, func))
  234. assert_wolfe(r[0], phi, derphi, err_msg=str(func))