test_decomp_ldl.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
from __future__ import division, print_function, absolute_import

import itertools

from numpy.testing import assert_array_almost_equal, assert_allclose, assert_
from numpy import (array, eye, zeros, empty_like, empty, tril_indices_from,
                   tril, triu_indices_from, spacing, float32, float64,
                   complex64, complex128)
from numpy.random import rand, randint, seed
from scipy.linalg import ldl
from pytest import raises as assert_raises, warns

try:
    # NumPy >= 1.25 provides warning classes in ``numpy.exceptions``;
    # NumPy 2.0 removed ``ComplexWarning`` from the top-level namespace.
    from numpy.exceptions import ComplexWarning
except ImportError:
    from numpy import ComplexWarning
  11. def test_args():
  12. A = eye(3)
  13. # Nonsquare array
  14. assert_raises(ValueError, ldl, A[:, :2])
  15. # Complex matrix with imaginary diagonal entries with "hermitian=True"
  16. with warns(ComplexWarning):
  17. ldl(A*1j)
  18. def test_empty_array():
  19. a = empty((0, 0), dtype=complex)
  20. l, d, p = ldl(empty((0, 0)))
  21. assert_array_almost_equal(l, empty_like(a))
  22. assert_array_almost_equal(d, empty_like(a))
  23. assert_array_almost_equal(p, array([], dtype=int))
  24. def test_simple():
  25. a = array([[-0.39-0.71j, 5.14-0.64j, -7.86-2.96j, 3.80+0.92j],
  26. [5.14-0.64j, 8.86+1.81j, -3.52+0.58j, 5.32-1.59j],
  27. [-7.86-2.96j, -3.52+0.58j, -2.83-0.03j, -1.54-2.86j],
  28. [3.80+0.92j, 5.32-1.59j, -1.54-2.86j, -0.56+0.12j]])
  29. b = array([[5., 10, 1, 18],
  30. [10., 2, 11, 1],
  31. [1., 11, 19, 9],
  32. [18., 1, 9, 0]])
  33. c = array([[52., 97, 112, 107, 50],
  34. [97., 114, 89, 98, 13],
  35. [112., 89, 64, 33, 6],
  36. [107., 98, 33, 60, 73],
  37. [50., 13, 6, 73, 77]])
  38. d = array([[2., 2, -4, 0, 4],
  39. [2., -2, -2, 10, -8],
  40. [-4., -2, 6, -8, -4],
  41. [0., 10, -8, 6, -6],
  42. [4., -8, -4, -6, 10]])
  43. e = array([[-1.36+0.00j, 0+0j, 0+0j, 0+0j],
  44. [1.58-0.90j, -8.87+0j, 0+0j, 0+0j],
  45. [2.21+0.21j, -1.84+0.03j, -4.63+0j, 0+0j],
  46. [3.91-1.50j, -1.78-1.18j, 0.11-0.11j, -1.84+0.00j]])
  47. for x in (b, c, d):
  48. l, d, p = ldl(x)
  49. assert_allclose(l.dot(d).dot(l.T), x, atol=spacing(1000.), rtol=0)
  50. u, d, p = ldl(x, lower=False)
  51. assert_allclose(u.dot(d).dot(u.T), x, atol=spacing(1000.), rtol=0)
  52. l, d, p = ldl(a, hermitian=False)
  53. assert_allclose(l.dot(d).dot(l.T), a, atol=spacing(1000.), rtol=0)
  54. u, d, p = ldl(a, lower=False, hermitian=False)
  55. assert_allclose(u.dot(d).dot(u.T), a, atol=spacing(1000.), rtol=0)
  56. # Use upper part for the computation and use the lower part for comparison
  57. l, d, p = ldl(e.conj().T, lower=0)
  58. assert_allclose(tril(l.dot(d).dot(l.conj().T)-e), zeros((4, 4)),
  59. atol=spacing(1000.), rtol=0)
  60. def test_permutations():
  61. seed(1234)
  62. for _ in range(10):
  63. n = randint(1, 100)
  64. # Random real/complex array
  65. x = rand(n, n) if randint(2) else rand(n, n) + rand(n, n)*1j
  66. x = x + x.conj().T
  67. x += eye(n)*randint(5, 1e6)
  68. l_ind = tril_indices_from(x, k=-1)
  69. u_ind = triu_indices_from(x, k=1)
  70. # Test whether permutations lead to a triangular array
  71. u, d, p = ldl(x, lower=0)
  72. # lower part should be zero
  73. assert_(not any(u[p, :][l_ind]), 'Spin {} failed'.format(_))
  74. l, d, p = ldl(x, lower=1)
  75. # upper part should be zero
  76. assert_(not any(l[p, :][u_ind]), 'Spin {} failed'.format(_))
  77. def test_ldl_type_size_combinations():
  78. seed(1234)
  79. sizes = [30, 750]
  80. real_dtypes = [float32, float64]
  81. complex_dtypes = [complex64, complex128]
  82. for n, dtype in itertools.product(sizes, real_dtypes):
  83. msg = ("Failed for size: {}, dtype: {}".format(n, dtype))
  84. x = rand(n, n).astype(dtype)
  85. x = x + x.T
  86. x += eye(n, dtype=dtype)*dtype(randint(5, 1e6))
  87. l, d1, p = ldl(x)
  88. u, d2, p = ldl(x, lower=0)
  89. rtol = 1e-4 if dtype is float32 else 1e-10
  90. assert_allclose(l.dot(d1).dot(l.T), x, rtol=rtol, err_msg=msg)
  91. assert_allclose(u.dot(d2).dot(u.T), x, rtol=rtol, err_msg=msg)
  92. for n, dtype in itertools.product(sizes, complex_dtypes):
  93. msg1 = ("Her failed for size: {}, dtype: {}".format(n, dtype))
  94. msg2 = ("Sym failed for size: {}, dtype: {}".format(n, dtype))
  95. # Complex hermitian upper/lower
  96. x = (rand(n, n)+1j*rand(n, n)).astype(dtype)
  97. x = x+x.conj().T
  98. x += eye(n, dtype=dtype)*dtype(randint(5, 1e6))
  99. l, d1, p = ldl(x)
  100. u, d2, p = ldl(x, lower=0)
  101. rtol = 1e-4 if dtype is complex64 else 1e-10
  102. assert_allclose(l.dot(d1).dot(l.conj().T), x, rtol=rtol, err_msg=msg1)
  103. assert_allclose(u.dot(d2).dot(u.conj().T), x, rtol=rtol, err_msg=msg1)
  104. # Complex symmetric upper/lower
  105. x = (rand(n, n)+1j*rand(n, n)).astype(dtype)
  106. x = x+x.T
  107. x += eye(n, dtype=dtype)*dtype(randint(5, 1e6))
  108. l, d1, p = ldl(x, hermitian=0)
  109. u, d2, p = ldl(x, lower=0, hermitian=0)
  110. assert_allclose(l.dot(d1).dot(l.T), x, rtol=rtol, err_msg=msg2)
  111. assert_allclose(u.dot(d2).dot(u.T), x, rtol=rtol, err_msg=msg2)