utils.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. from __future__ import division, print_function, absolute_import
  2. __docformat__ = "restructuredtext en"
  3. __all__ = []
  4. from numpy import asanyarray, asarray, asmatrix, array, matrix, zeros
  5. from scipy.sparse.linalg.interface import aslinearoperator, LinearOperator, \
  6. IdentityOperator
  7. _coerce_rules = {('f','f'):'f', ('f','d'):'d', ('f','F'):'F',
  8. ('f','D'):'D', ('d','f'):'d', ('d','d'):'d',
  9. ('d','F'):'D', ('d','D'):'D', ('F','f'):'F',
  10. ('F','d'):'D', ('F','F'):'F', ('F','D'):'D',
  11. ('D','f'):'D', ('D','d'):'D', ('D','F'):'D',
  12. ('D','D'):'D'}
  13. def coerce(x,y):
  14. if x not in 'fdFD':
  15. x = 'd'
  16. if y not in 'fdFD':
  17. y = 'd'
  18. return _coerce_rules[x,y]
  19. def id(x):
  20. return x
  21. def make_system(A, M, x0, b):
  22. """Make a linear system Ax=b
  23. Parameters
  24. ----------
  25. A : LinearOperator
  26. sparse or dense matrix (or any valid input to aslinearoperator)
  27. M : {LinearOperator, Nones}
  28. preconditioner
  29. sparse or dense matrix (or any valid input to aslinearoperator)
  30. x0 : {array_like, None}
  31. initial guess to iterative method
  32. b : array_like
  33. right hand side
  34. Returns
  35. -------
  36. (A, M, x, b, postprocess)
  37. A : LinearOperator
  38. matrix of the linear system
  39. M : LinearOperator
  40. preconditioner
  41. x : rank 1 ndarray
  42. initial guess
  43. b : rank 1 ndarray
  44. right hand side
  45. postprocess : function
  46. converts the solution vector to the appropriate
  47. type and dimensions (e.g. (N,1) matrix)
  48. """
  49. A_ = A
  50. A = aslinearoperator(A)
  51. if A.shape[0] != A.shape[1]:
  52. raise ValueError('expected square matrix, but got shape=%s' % (A.shape,))
  53. N = A.shape[0]
  54. b = asanyarray(b)
  55. if not (b.shape == (N,1) or b.shape == (N,)):
  56. raise ValueError('A and b have incompatible dimensions')
  57. if b.dtype.char not in 'fdFD':
  58. b = b.astype('d') # upcast non-FP types to double
  59. def postprocess(x):
  60. if isinstance(b,matrix):
  61. x = asmatrix(x)
  62. return x.reshape(b.shape)
  63. if hasattr(A,'dtype'):
  64. xtype = A.dtype.char
  65. else:
  66. xtype = A.matvec(b).dtype.char
  67. xtype = coerce(xtype, b.dtype.char)
  68. b = asarray(b,dtype=xtype) # make b the same type as x
  69. b = b.ravel()
  70. if x0 is None:
  71. x = zeros(N, dtype=xtype)
  72. else:
  73. x = array(x0, dtype=xtype)
  74. if not (x.shape == (N,1) or x.shape == (N,)):
  75. raise ValueError('A and x have incompatible dimensions')
  76. x = x.ravel()
  77. # process preconditioner
  78. if M is None:
  79. if hasattr(A_,'psolve'):
  80. psolve = A_.psolve
  81. else:
  82. psolve = id
  83. if hasattr(A_,'rpsolve'):
  84. rpsolve = A_.rpsolve
  85. else:
  86. rpsolve = id
  87. if psolve is id and rpsolve is id:
  88. M = IdentityOperator(shape=A.shape, dtype=A.dtype)
  89. else:
  90. M = LinearOperator(A.shape, matvec=psolve, rmatvec=rpsolve,
  91. dtype=A.dtype)
  92. else:
  93. M = aslinearoperator(M)
  94. if A.shape != M.shape:
  95. raise ValueError('matrix and preconditioner have different shapes')
  96. return A, M, x, b, postprocess