# test_differentiable_functions.py (21 KB)


  1. from __future__ import division, print_function, absolute_import
  2. import numpy as np
  3. from numpy.testing import (TestCase, assert_array_almost_equal,
  4. assert_array_equal, assert_)
  5. from scipy.sparse import csr_matrix
  6. from scipy.sparse.linalg import LinearOperator
  7. from scipy.optimize._differentiable_functions import (ScalarFunction,
  8. VectorFunction,
  9. LinearVectorFunction,
  10. IdentityVectorFunction)
  11. class ExScalarFunction:
  12. def __init__(self):
  13. self.nfev = 0
  14. self.ngev = 0
  15. self.nhev = 0
  16. def fun(self, x):
  17. self.nfev += 1
  18. return 2*(x[0]**2 + x[1]**2 - 1) - x[0]
  19. def grad(self, x):
  20. self.ngev += 1
  21. return np.array([4*x[0]-1, 4*x[1]])
  22. def hess(self, x):
  23. self.nhev += 1
  24. return 4*np.eye(2)
  25. class TestScalarFunction(TestCase):
  26. def test_finite_difference_grad(self):
  27. ex = ExScalarFunction()
  28. nfev = 0
  29. ngev = 0
  30. x0 = [1.0, 0.0]
  31. analit = ScalarFunction(ex.fun, x0, (), ex.grad,
  32. ex.hess, None, (-np.inf, np.inf))
  33. nfev += 1
  34. ngev += 1
  35. assert_array_equal(ex.nfev, nfev)
  36. assert_array_equal(analit.nfev, nfev)
  37. assert_array_equal(ex.ngev, ngev)
  38. assert_array_equal(analit.ngev, nfev)
  39. approx = ScalarFunction(ex.fun, x0, (), '2-point',
  40. ex.hess, None, (-np.inf, np.inf))
  41. nfev += 3
  42. assert_array_equal(ex.nfev, nfev)
  43. assert_array_equal(analit.nfev+approx.nfev, nfev)
  44. assert_array_equal(ex.ngev, ngev)
  45. assert_array_equal(analit.ngev+approx.ngev, ngev)
  46. assert_array_equal(analit.f, approx.f)
  47. assert_array_almost_equal(analit.g, approx.g)
  48. x = [10, 0.3]
  49. f_analit = analit.fun(x)
  50. g_analit = analit.grad(x)
  51. nfev += 1
  52. ngev += 1
  53. assert_array_equal(ex.nfev, nfev)
  54. assert_array_equal(analit.nfev+approx.nfev, nfev)
  55. assert_array_equal(ex.ngev, ngev)
  56. assert_array_equal(analit.ngev+approx.ngev, ngev)
  57. f_approx = approx.fun(x)
  58. g_approx = approx.grad(x)
  59. nfev += 3
  60. assert_array_equal(ex.nfev, nfev)
  61. assert_array_equal(analit.nfev+approx.nfev, nfev)
  62. assert_array_equal(ex.ngev, ngev)
  63. assert_array_equal(analit.ngev+approx.ngev, ngev)
  64. assert_array_almost_equal(f_analit, f_approx)
  65. assert_array_almost_equal(g_analit, g_approx)
  66. x = [2.0, 1.0]
  67. g_analit = analit.grad(x)
  68. ngev += 1
  69. assert_array_equal(ex.nfev, nfev)
  70. assert_array_equal(analit.nfev+approx.nfev, nfev)
  71. assert_array_equal(ex.ngev, ngev)
  72. assert_array_equal(analit.ngev+approx.ngev, ngev)
  73. g_approx = approx.grad(x)
  74. nfev += 3
  75. assert_array_equal(ex.nfev, nfev)
  76. assert_array_equal(analit.nfev+approx.nfev, nfev)
  77. assert_array_equal(ex.ngev, ngev)
  78. assert_array_equal(analit.ngev+approx.ngev, ngev)
  79. assert_array_almost_equal(g_analit, g_approx)
  80. x = [2.5, 0.3]
  81. f_analit = analit.fun(x)
  82. g_analit = analit.grad(x)
  83. nfev += 1
  84. ngev += 1
  85. assert_array_equal(ex.nfev, nfev)
  86. assert_array_equal(analit.nfev+approx.nfev, nfev)
  87. assert_array_equal(ex.ngev, ngev)
  88. assert_array_equal(analit.ngev+approx.ngev, ngev)
  89. f_approx = approx.fun(x)
  90. g_approx = approx.grad(x)
  91. nfev += 3
  92. assert_array_equal(ex.nfev, nfev)
  93. assert_array_equal(analit.nfev+approx.nfev, nfev)
  94. assert_array_equal(ex.ngev, ngev)
  95. assert_array_equal(analit.ngev+approx.ngev, ngev)
  96. assert_array_almost_equal(f_analit, f_approx)
  97. assert_array_almost_equal(g_analit, g_approx)
  98. x = [2, 0.3]
  99. f_analit = analit.fun(x)
  100. g_analit = analit.grad(x)
  101. nfev += 1
  102. ngev += 1
  103. assert_array_equal(ex.nfev, nfev)
  104. assert_array_equal(analit.nfev+approx.nfev, nfev)
  105. assert_array_equal(ex.ngev, ngev)
  106. assert_array_equal(analit.ngev+approx.ngev, ngev)
  107. f_approx = approx.fun(x)
  108. g_approx = approx.grad(x)
  109. nfev += 3
  110. assert_array_equal(ex.nfev, nfev)
  111. assert_array_equal(analit.nfev+approx.nfev, nfev)
  112. assert_array_equal(ex.ngev, ngev)
  113. assert_array_equal(analit.ngev+approx.ngev, ngev)
  114. assert_array_almost_equal(f_analit, f_approx)
  115. assert_array_almost_equal(g_analit, g_approx)
  116. def test_finite_difference_hess_linear_operator(self):
  117. ex = ExScalarFunction()
  118. nfev = 0
  119. ngev = 0
  120. nhev = 0
  121. x0 = [1.0, 0.0]
  122. analit = ScalarFunction(ex.fun, x0, (), ex.grad,
  123. ex.hess, None, (-np.inf, np.inf))
  124. nfev += 1
  125. ngev += 1
  126. nhev += 1
  127. assert_array_equal(ex.nfev, nfev)
  128. assert_array_equal(analit.nfev, nfev)
  129. assert_array_equal(ex.ngev, ngev)
  130. assert_array_equal(analit.ngev, ngev)
  131. assert_array_equal(ex.nhev, nhev)
  132. assert_array_equal(analit.nhev, nhev)
  133. approx = ScalarFunction(ex.fun, x0, (), ex.grad,
  134. '2-point', None, (-np.inf, np.inf))
  135. assert_(isinstance(approx.H, LinearOperator))
  136. for v in ([1.0, 2.0], [3.0, 4.0], [5.0, 2.0]):
  137. assert_array_equal(analit.f, approx.f)
  138. assert_array_almost_equal(analit.g, approx.g)
  139. assert_array_almost_equal(analit.H.dot(v), approx.H.dot(v))
  140. nfev += 1
  141. ngev += 4
  142. assert_array_equal(ex.nfev, nfev)
  143. assert_array_equal(analit.nfev+approx.nfev, nfev)
  144. assert_array_equal(ex.ngev, ngev)
  145. assert_array_equal(analit.ngev+approx.ngev, ngev)
  146. assert_array_equal(ex.nhev, nhev)
  147. assert_array_equal(analit.nhev+approx.nhev, nhev)
  148. x = [2.0, 1.0]
  149. H_analit = analit.hess(x)
  150. nhev += 1
  151. assert_array_equal(ex.nfev, nfev)
  152. assert_array_equal(analit.nfev+approx.nfev, nfev)
  153. assert_array_equal(ex.ngev, ngev)
  154. assert_array_equal(analit.ngev+approx.ngev, ngev)
  155. assert_array_equal(ex.nhev, nhev)
  156. assert_array_equal(analit.nhev+approx.nhev, nhev)
  157. H_approx = approx.hess(x)
  158. assert_(isinstance(H_approx, LinearOperator))
  159. for v in ([1.0, 2.0], [3.0, 4.0], [5.0, 2.0]):
  160. assert_array_almost_equal(H_analit.dot(v), H_approx.dot(v))
  161. ngev += 4
  162. assert_array_equal(ex.nfev, nfev)
  163. assert_array_equal(analit.nfev+approx.nfev, nfev)
  164. assert_array_equal(ex.ngev, ngev)
  165. assert_array_equal(analit.ngev+approx.ngev, ngev)
  166. assert_array_equal(ex.nhev, nhev)
  167. assert_array_equal(analit.nhev+approx.nhev, nhev)
  168. x = [2.1, 1.2]
  169. H_analit = analit.hess(x)
  170. nhev += 1
  171. assert_array_equal(ex.nfev, nfev)
  172. assert_array_equal(analit.nfev+approx.nfev, nfev)
  173. assert_array_equal(ex.ngev, ngev)
  174. assert_array_equal(analit.ngev+approx.ngev, ngev)
  175. assert_array_equal(ex.nhev, nhev)
  176. assert_array_equal(analit.nhev+approx.nhev, nhev)
  177. H_approx = approx.hess(x)
  178. assert_(isinstance(H_approx, LinearOperator))
  179. for v in ([1.0, 2.0], [3.0, 4.0], [5.0, 2.0]):
  180. assert_array_almost_equal(H_analit.dot(v), H_approx.dot(v))
  181. ngev += 4
  182. assert_array_equal(ex.nfev, nfev)
  183. assert_array_equal(analit.nfev+approx.nfev, nfev)
  184. assert_array_equal(ex.ngev, ngev)
  185. assert_array_equal(analit.ngev+approx.ngev, ngev)
  186. assert_array_equal(ex.nhev, nhev)
  187. assert_array_equal(analit.nhev+approx.nhev, nhev)
  188. x = [2.5, 0.3]
  189. _ = analit.grad(x)
  190. H_analit = analit.hess(x)
  191. ngev += 1
  192. nhev += 1
  193. assert_array_equal(ex.nfev, nfev)
  194. assert_array_equal(analit.nfev+approx.nfev, nfev)
  195. assert_array_equal(ex.ngev, ngev)
  196. assert_array_equal(analit.ngev+approx.ngev, ngev)
  197. assert_array_equal(ex.nhev, nhev)
  198. assert_array_equal(analit.nhev+approx.nhev, nhev)
  199. _ = approx.grad(x)
  200. H_approx = approx.hess(x)
  201. assert_(isinstance(H_approx, LinearOperator))
  202. for v in ([1.0, 2.0], [3.0, 4.0], [5.0, 2.0]):
  203. assert_array_almost_equal(H_analit.dot(v), H_approx.dot(v))
  204. ngev += 4
  205. assert_array_equal(ex.nfev, nfev)
  206. assert_array_equal(analit.nfev+approx.nfev, nfev)
  207. assert_array_equal(ex.ngev, ngev)
  208. assert_array_equal(analit.ngev+approx.ngev, ngev)
  209. assert_array_equal(ex.nhev, nhev)
  210. assert_array_equal(analit.nhev+approx.nhev, nhev)
  211. x = [5.2, 2.3]
  212. _ = analit.grad(x)
  213. H_analit = analit.hess(x)
  214. ngev += 1
  215. nhev += 1
  216. assert_array_equal(ex.nfev, nfev)
  217. assert_array_equal(analit.nfev+approx.nfev, nfev)
  218. assert_array_equal(ex.ngev, ngev)
  219. assert_array_equal(analit.ngev+approx.ngev, ngev)
  220. assert_array_equal(ex.nhev, nhev)
  221. assert_array_equal(analit.nhev+approx.nhev, nhev)
  222. _ = approx.grad(x)
  223. H_approx = approx.hess(x)
  224. assert_(isinstance(H_approx, LinearOperator))
  225. for v in ([1.0, 2.0], [3.0, 4.0], [5.0, 2.0]):
  226. assert_array_almost_equal(H_analit.dot(v), H_approx.dot(v))
  227. ngev += 4
  228. assert_array_equal(ex.nfev, nfev)
  229. assert_array_equal(analit.nfev+approx.nfev, nfev)
  230. assert_array_equal(ex.ngev, ngev)
  231. assert_array_equal(analit.ngev+approx.ngev, ngev)
  232. assert_array_equal(ex.nhev, nhev)
  233. assert_array_equal(analit.nhev+approx.nhev, nhev)
  234. class ExVectorialFunction:
  235. def __init__(self):
  236. self.nfev = 0
  237. self.njev = 0
  238. self.nhev = 0
  239. def fun(self, x):
  240. self.nfev += 1
  241. return np.array([2*(x[0]**2 + x[1]**2 - 1) - x[0],
  242. 4*(x[0]**3 + x[1]**2 - 4) - 3*x[0]])
  243. def jac(self, x):
  244. self.njev += 1
  245. return np.array([[4*x[0]-1, 4*x[1]],
  246. [12*x[0]**2-3, 8*x[1]]])
  247. def hess(self, x, v):
  248. self.nhev += 1
  249. return v[0]*4*np.eye(2) + v[1]*np.array([[24*x[0], 0],
  250. [0, 8]])
  251. class TestVectorialFunction(TestCase):
  252. def test_finite_difference_jac(self):
  253. ex = ExVectorialFunction()
  254. nfev = 0
  255. njev = 0
  256. x0 = [1.0, 0.0]
  257. v0 = [0.0, 1.0]
  258. analit = VectorFunction(ex.fun, x0, ex.jac, ex.hess, None, None,
  259. (-np.inf, np.inf), None)
  260. nfev += 1
  261. njev += 1
  262. assert_array_equal(ex.nfev, nfev)
  263. assert_array_equal(analit.nfev, nfev)
  264. assert_array_equal(ex.njev, njev)
  265. assert_array_equal(analit.njev, njev)
  266. approx = VectorFunction(ex.fun, x0, '2-point', ex.hess, None, None,
  267. (-np.inf, np.inf), None)
  268. nfev += 3
  269. assert_array_equal(ex.nfev, nfev)
  270. assert_array_equal(analit.nfev+approx.nfev, nfev)
  271. assert_array_equal(ex.njev, njev)
  272. assert_array_equal(analit.njev+approx.njev, njev)
  273. assert_array_equal(analit.f, approx.f)
  274. assert_array_almost_equal(analit.J, approx.J)
  275. x = [10, 0.3]
  276. f_analit = analit.fun(x)
  277. J_analit = analit.jac(x)
  278. nfev += 1
  279. njev += 1
  280. assert_array_equal(ex.nfev, nfev)
  281. assert_array_equal(analit.nfev+approx.nfev, nfev)
  282. assert_array_equal(ex.njev, njev)
  283. assert_array_equal(analit.njev+approx.njev, njev)
  284. f_approx = approx.fun(x)
  285. J_approx = approx.jac(x)
  286. nfev += 3
  287. assert_array_equal(ex.nfev, nfev)
  288. assert_array_equal(analit.nfev+approx.nfev, nfev)
  289. assert_array_equal(ex.njev, njev)
  290. assert_array_equal(analit.njev+approx.njev, njev)
  291. assert_array_almost_equal(f_analit, f_approx)
  292. assert_array_almost_equal(J_analit, J_approx, decimal=4)
  293. x = [2.0, 1.0]
  294. J_analit = analit.jac(x)
  295. njev += 1
  296. assert_array_equal(ex.nfev, nfev)
  297. assert_array_equal(analit.nfev+approx.nfev, nfev)
  298. assert_array_equal(ex.njev, njev)
  299. assert_array_equal(analit.njev+approx.njev, njev)
  300. J_approx = approx.jac(x)
  301. nfev += 3
  302. assert_array_equal(ex.nfev, nfev)
  303. assert_array_equal(analit.nfev+approx.nfev, nfev)
  304. assert_array_equal(ex.njev, njev)
  305. assert_array_equal(analit.njev+approx.njev, njev)
  306. assert_array_almost_equal(J_analit, J_approx)
  307. x = [2.5, 0.3]
  308. f_analit = analit.fun(x)
  309. J_analit = analit.jac(x)
  310. nfev += 1
  311. njev += 1
  312. assert_array_equal(ex.nfev, nfev)
  313. assert_array_equal(analit.nfev+approx.nfev, nfev)
  314. assert_array_equal(ex.njev, njev)
  315. assert_array_equal(analit.njev+approx.njev, njev)
  316. f_approx = approx.fun(x)
  317. J_approx = approx.jac(x)
  318. nfev += 3
  319. assert_array_equal(ex.nfev, nfev)
  320. assert_array_equal(analit.nfev+approx.nfev, nfev)
  321. assert_array_equal(ex.njev, njev)
  322. assert_array_equal(analit.njev+approx.njev, njev)
  323. assert_array_almost_equal(f_analit, f_approx)
  324. assert_array_almost_equal(J_analit, J_approx)
  325. x = [2, 0.3]
  326. f_analit = analit.fun(x)
  327. J_analit = analit.jac(x)
  328. nfev += 1
  329. njev += 1
  330. assert_array_equal(ex.nfev, nfev)
  331. assert_array_equal(analit.nfev+approx.nfev, nfev)
  332. assert_array_equal(ex.njev, njev)
  333. assert_array_equal(analit.njev+approx.njev, njev)
  334. f_approx = approx.fun(x)
  335. J_approx = approx.jac(x)
  336. nfev += 3
  337. assert_array_equal(ex.nfev, nfev)
  338. assert_array_equal(analit.nfev+approx.nfev, nfev)
  339. assert_array_equal(ex.njev, njev)
  340. assert_array_equal(analit.njev+approx.njev, njev)
  341. assert_array_almost_equal(f_analit, f_approx)
  342. assert_array_almost_equal(J_analit, J_approx)
  343. def test_finite_difference_hess_linear_operator(self):
  344. ex = ExVectorialFunction()
  345. nfev = 0
  346. njev = 0
  347. nhev = 0
  348. x0 = [1.0, 0.0]
  349. v0 = [1.0, 2.0]
  350. analit = VectorFunction(ex.fun, x0, ex.jac, ex.hess, None, None,
  351. (-np.inf, np.inf), None)
  352. nfev += 1
  353. njev += 1
  354. nhev += 1
  355. assert_array_equal(ex.nfev, nfev)
  356. assert_array_equal(analit.nfev, nfev)
  357. assert_array_equal(ex.njev, njev)
  358. assert_array_equal(analit.njev, njev)
  359. assert_array_equal(ex.nhev, nhev)
  360. assert_array_equal(analit.nhev, nhev)
  361. approx = VectorFunction(ex.fun, x0, ex.jac, '2-point', None, None,
  362. (-np.inf, np.inf), None)
  363. assert_(isinstance(approx.H, LinearOperator))
  364. for p in ([1.0, 2.0], [3.0, 4.0], [5.0, 2.0]):
  365. assert_array_equal(analit.f, approx.f)
  366. assert_array_almost_equal(analit.J, approx.J)
  367. assert_array_almost_equal(analit.H.dot(p), approx.H.dot(p))
  368. nfev += 1
  369. njev += 4
  370. assert_array_equal(ex.nfev, nfev)
  371. assert_array_equal(analit.nfev+approx.nfev, nfev)
  372. assert_array_equal(ex.njev, njev)
  373. assert_array_equal(analit.njev+approx.njev, njev)
  374. assert_array_equal(ex.nhev, nhev)
  375. assert_array_equal(analit.nhev+approx.nhev, nhev)
  376. x = [2.0, 1.0]
  377. H_analit = analit.hess(x, v0)
  378. nhev += 1
  379. assert_array_equal(ex.nfev, nfev)
  380. assert_array_equal(analit.nfev+approx.nfev, nfev)
  381. assert_array_equal(ex.njev, njev)
  382. assert_array_equal(analit.njev+approx.njev, njev)
  383. assert_array_equal(ex.nhev, nhev)
  384. assert_array_equal(analit.nhev+approx.nhev, nhev)
  385. H_approx = approx.hess(x, v0)
  386. assert_(isinstance(H_approx, LinearOperator))
  387. for p in ([1.0, 2.0], [3.0, 4.0], [5.0, 2.0]):
  388. assert_array_almost_equal(H_analit.dot(p), H_approx.dot(p),
  389. decimal=5)
  390. njev += 4
  391. assert_array_equal(ex.nfev, nfev)
  392. assert_array_equal(analit.nfev+approx.nfev, nfev)
  393. assert_array_equal(ex.njev, njev)
  394. assert_array_equal(analit.njev+approx.njev, njev)
  395. assert_array_equal(ex.nhev, nhev)
  396. assert_array_equal(analit.nhev+approx.nhev, nhev)
  397. x = [2.1, 1.2]
  398. v = [1.0, 1.0]
  399. H_analit = analit.hess(x, v)
  400. nhev += 1
  401. assert_array_equal(ex.nfev, nfev)
  402. assert_array_equal(analit.nfev+approx.nfev, nfev)
  403. assert_array_equal(ex.njev, njev)
  404. assert_array_equal(analit.njev+approx.njev, njev)
  405. assert_array_equal(ex.nhev, nhev)
  406. assert_array_equal(analit.nhev+approx.nhev, nhev)
  407. H_approx = approx.hess(x, v)
  408. assert_(isinstance(H_approx, LinearOperator))
  409. for v in ([1.0, 2.0], [3.0, 4.0], [5.0, 2.0]):
  410. assert_array_almost_equal(H_analit.dot(v), H_approx.dot(v))
  411. njev += 4
  412. assert_array_equal(ex.nfev, nfev)
  413. assert_array_equal(analit.nfev+approx.nfev, nfev)
  414. assert_array_equal(ex.njev, njev)
  415. assert_array_equal(analit.njev+approx.njev, njev)
  416. assert_array_equal(ex.nhev, nhev)
  417. assert_array_equal(analit.nhev+approx.nhev, nhev)
  418. x = [2.5, 0.3]
  419. _ = analit.jac(x)
  420. H_analit = analit.hess(x, v0)
  421. njev += 1
  422. nhev += 1
  423. assert_array_equal(ex.nfev, nfev)
  424. assert_array_equal(analit.nfev+approx.nfev, nfev)
  425. assert_array_equal(ex.njev, njev)
  426. assert_array_equal(analit.njev+approx.njev, njev)
  427. assert_array_equal(ex.nhev, nhev)
  428. assert_array_equal(analit.nhev+approx.nhev, nhev)
  429. _ = approx.jac(x)
  430. H_approx = approx.hess(x, v0)
  431. assert_(isinstance(H_approx, LinearOperator))
  432. for v in ([1.0, 2.0], [3.0, 4.0], [5.0, 2.0]):
  433. assert_array_almost_equal(H_analit.dot(v), H_approx.dot(v), decimal=4)
  434. njev += 4
  435. assert_array_equal(ex.nfev, nfev)
  436. assert_array_equal(analit.nfev+approx.nfev, nfev)
  437. assert_array_equal(ex.njev, njev)
  438. assert_array_equal(analit.njev+approx.njev, njev)
  439. assert_array_equal(ex.nhev, nhev)
  440. assert_array_equal(analit.nhev+approx.nhev, nhev)
  441. x = [5.2, 2.3]
  442. v = [2.3, 5.2]
  443. _ = analit.jac(x)
  444. H_analit = analit.hess(x, v)
  445. njev += 1
  446. nhev += 1
  447. assert_array_equal(ex.nfev, nfev)
  448. assert_array_equal(analit.nfev+approx.nfev, nfev)
  449. assert_array_equal(ex.njev, njev)
  450. assert_array_equal(analit.njev+approx.njev, njev)
  451. assert_array_equal(ex.nhev, nhev)
  452. assert_array_equal(analit.nhev+approx.nhev, nhev)
  453. _ = approx.jac(x)
  454. H_approx = approx.hess(x, v)
  455. assert_(isinstance(H_approx, LinearOperator))
  456. for v in ([1.0, 2.0], [3.0, 4.0], [5.0, 2.0]):
  457. assert_array_almost_equal(H_analit.dot(v), H_approx.dot(v), decimal=4)
  458. njev += 4
  459. assert_array_equal(ex.nfev, nfev)
  460. assert_array_equal(analit.nfev+approx.nfev, nfev)
  461. assert_array_equal(ex.njev, njev)
  462. assert_array_equal(analit.njev+approx.njev, njev)
  463. assert_array_equal(ex.nhev, nhev)
  464. assert_array_equal(analit.nhev+approx.nhev, nhev)
  465. def test_LinearVectorFunction():
  466. A_dense = np.array([
  467. [-1, 2, 0],
  468. [0, 4, 2]
  469. ])
  470. x0 = np.zeros(3)
  471. A_sparse = csr_matrix(A_dense)
  472. x = np.array([1, -1, 0])
  473. v = np.array([-1, 1])
  474. Ax = np.array([-3, -4])
  475. f1 = LinearVectorFunction(A_dense, x0, None)
  476. assert_(not f1.sparse_jacobian)
  477. f2 = LinearVectorFunction(A_dense, x0, True)
  478. assert_(f2.sparse_jacobian)
  479. f3 = LinearVectorFunction(A_dense, x0, False)
  480. assert_(not f3.sparse_jacobian)
  481. f4 = LinearVectorFunction(A_sparse, x0, None)
  482. assert_(f4.sparse_jacobian)
  483. f5 = LinearVectorFunction(A_sparse, x0, True)
  484. assert_(f5.sparse_jacobian)
  485. f6 = LinearVectorFunction(A_sparse, x0, False)
  486. assert_(not f6.sparse_jacobian)
  487. assert_array_equal(f1.fun(x), Ax)
  488. assert_array_equal(f2.fun(x), Ax)
  489. assert_array_equal(f1.jac(x), A_dense)
  490. assert_array_equal(f2.jac(x).toarray(), A_sparse.toarray())
  491. assert_array_equal(f1.hess(x, v).toarray(), np.zeros((3, 3)))
  492. def test_LinearVectorFunction_memoization():
  493. A = np.array([[-1, 2, 0], [0, 4, 2]])
  494. x0 = np.array([1, 2, -1])
  495. fun = LinearVectorFunction(A, x0, False)
  496. assert_array_equal(x0, fun.x)
  497. assert_array_equal(A.dot(x0), fun.f)
  498. x1 = np.array([-1, 3, 10])
  499. assert_array_equal(A, fun.jac(x1))
  500. assert_array_equal(x1, fun.x)
  501. assert_array_equal(A.dot(x0), fun.f)
  502. assert_array_equal(A.dot(x1), fun.fun(x1))
  503. assert_array_equal(A.dot(x1), fun.f)
  504. def test_IdentityVectorFunction():
  505. x0 = np.zeros(3)
  506. f1 = IdentityVectorFunction(x0, None)
  507. f2 = IdentityVectorFunction(x0, False)
  508. f3 = IdentityVectorFunction(x0, True)
  509. assert_(f1.sparse_jacobian)
  510. assert_(not f2.sparse_jacobian)
  511. assert_(f3.sparse_jacobian)
  512. x = np.array([-1, 2, 1])
  513. v = np.array([-2, 3, 0])
  514. assert_array_equal(f1.fun(x), x)
  515. assert_array_equal(f2.fun(x), x)
  516. assert_array_equal(f1.jac(x).toarray(), np.eye(3))
  517. assert_array_equal(f2.jac(x), np.eye(3))
  518. assert_array_equal(f1.hess(x, v).toarray(), np.zeros((3, 3)))