# test_decomp_polar.py (2.6 KB)
  1. from __future__ import division, print_function, absolute_import
  2. import numpy as np
  3. from numpy.linalg import norm
  4. from numpy.testing import (assert_, assert_allclose, assert_equal)
  5. from scipy.linalg import polar, eigh
  6. diag2 = np.array([[2, 0], [0, 3]])
  7. a13 = np.array([[1, 2, 2]])
  8. precomputed_cases = [
  9. [[[0]], 'right', [[1]], [[0]]],
  10. [[[0]], 'left', [[1]], [[0]]],
  11. [[[9]], 'right', [[1]], [[9]]],
  12. [[[9]], 'left', [[1]], [[9]]],
  13. [diag2, 'right', np.eye(2), diag2],
  14. [diag2, 'left', np.eye(2), diag2],
  15. [a13, 'right', a13/norm(a13[0]), a13.T.dot(a13)/norm(a13[0])],
  16. ]
  17. verify_cases = [
  18. [[1, 2], [3, 4]],
  19. [[1, 2, 3]],
  20. [[1], [2], [3]],
  21. [[1, 2, 3], [3, 4, 0]],
  22. [[1, 2], [3, 4], [5, 5]],
  23. [[1, 2], [3, 4+5j]],
  24. [[1, 2, 3j]],
  25. [[1], [2], [3j]],
  26. [[1, 2, 3+2j], [3, 4-1j, -4j]],
  27. [[1, 2], [3-2j, 4+0.5j], [5, 5]],
  28. [[10000, 10, 1], [-1, 2, 3j], [0, 1, 2]],
  29. ]
  30. def check_precomputed_polar(a, side, expected_u, expected_p):
  31. # Compare the result of the polar decomposition to a
  32. # precomputed result.
  33. u, p = polar(a, side=side)
  34. assert_allclose(u, expected_u, atol=1e-15)
  35. assert_allclose(p, expected_p, atol=1e-15)
  36. def verify_polar(a):
  37. # Compute the polar decomposition, and then verify that
  38. # the result has all the expected properties.
  39. product_atol = np.sqrt(np.finfo(float).eps)
  40. aa = np.asarray(a)
  41. m, n = aa.shape
  42. u, p = polar(a, side='right')
  43. assert_equal(u.shape, (m, n))
  44. assert_equal(p.shape, (n, n))
  45. # a = up
  46. assert_allclose(u.dot(p), a, atol=product_atol)
  47. if m >= n:
  48. assert_allclose(u.conj().T.dot(u), np.eye(n), atol=1e-15)
  49. else:
  50. assert_allclose(u.dot(u.conj().T), np.eye(m), atol=1e-15)
  51. # p is Hermitian positive semidefinite.
  52. assert_allclose(p.conj().T, p)
  53. evals = eigh(p, eigvals_only=True)
  54. nonzero_evals = evals[abs(evals) > 1e-14]
  55. assert_((nonzero_evals >= 0).all())
  56. u, p = polar(a, side='left')
  57. assert_equal(u.shape, (m, n))
  58. assert_equal(p.shape, (m, m))
  59. # a = pu
  60. assert_allclose(p.dot(u), a, atol=product_atol)
  61. if m >= n:
  62. assert_allclose(u.conj().T.dot(u), np.eye(n), atol=1e-15)
  63. else:
  64. assert_allclose(u.dot(u.conj().T), np.eye(m), atol=1e-15)
  65. # p is Hermitian positive semidefinite.
  66. assert_allclose(p.conj().T, p)
  67. evals = eigh(p, eigvals_only=True)
  68. nonzero_evals = evals[abs(evals) > 1e-14]
  69. assert_((nonzero_evals >= 0).all())
  70. def test_precomputed_cases():
  71. for a, side, expected_u, expected_p in precomputed_cases:
  72. check_precomputed_polar(a, side, expected_u, expected_p)
  73. def test_verify_cases():
  74. for a in verify_cases:
  75. verify_polar(a)