spfuncs.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. """ Functions that operate on sparse matrices
  2. """
  3. from __future__ import division, print_function, absolute_import
  4. __all__ = ['count_blocks','estimate_blocksize']
  5. from .csr import isspmatrix_csr, csr_matrix
  6. from .csc import isspmatrix_csc
  7. from ._sparsetools import csr_count_blocks
  8. def extract_diagonal(A):
  9. raise NotImplementedError('use .diagonal() instead')
  10. #def extract_diagonal(A):
  11. # """extract_diagonal(A) returns the main diagonal of A."""
  12. # #TODO extract k-th diagonal
  13. # if isspmatrix_csr(A) or isspmatrix_csc(A):
  14. # fn = getattr(sparsetools, A.format + "_diagonal")
  15. # y = empty( min(A.shape), dtype=upcast(A.dtype) )
  16. # fn(A.shape[0],A.shape[1],A.indptr,A.indices,A.data,y)
  17. # return y
  18. # elif isspmatrix_bsr(A):
  19. # M,N = A.shape
  20. # R,C = A.blocksize
  21. # y = empty( min(M,N), dtype=upcast(A.dtype) )
  22. # fn = sparsetools.bsr_diagonal(M//R, N//C, R, C, \
  23. # A.indptr, A.indices, ravel(A.data), y)
  24. # return y
  25. # else:
  26. # return extract_diagonal(csr_matrix(A))
  27. def estimate_blocksize(A,efficiency=0.7):
  28. """Attempt to determine the blocksize of a sparse matrix
  29. Returns a blocksize=(r,c) such that
  30. - A.nnz / A.tobsr( (r,c) ).nnz > efficiency
  31. """
  32. if not (isspmatrix_csr(A) or isspmatrix_csc(A)):
  33. A = csr_matrix(A)
  34. if A.nnz == 0:
  35. return (1,1)
  36. if not 0 < efficiency < 1.0:
  37. raise ValueError('efficiency must satisfy 0.0 < efficiency < 1.0')
  38. high_efficiency = (1.0 + efficiency) / 2.0
  39. nnz = float(A.nnz)
  40. M,N = A.shape
  41. if M % 2 == 0 and N % 2 == 0:
  42. e22 = nnz / (4 * count_blocks(A,(2,2)))
  43. else:
  44. e22 = 0.0
  45. if M % 3 == 0 and N % 3 == 0:
  46. e33 = nnz / (9 * count_blocks(A,(3,3)))
  47. else:
  48. e33 = 0.0
  49. if e22 > high_efficiency and e33 > high_efficiency:
  50. e66 = nnz / (36 * count_blocks(A,(6,6)))
  51. if e66 > efficiency:
  52. return (6,6)
  53. else:
  54. return (3,3)
  55. else:
  56. if M % 4 == 0 and N % 4 == 0:
  57. e44 = nnz / (16 * count_blocks(A,(4,4)))
  58. else:
  59. e44 = 0.0
  60. if e44 > efficiency:
  61. return (4,4)
  62. elif e33 > efficiency:
  63. return (3,3)
  64. elif e22 > efficiency:
  65. return (2,2)
  66. else:
  67. return (1,1)
  68. def count_blocks(A,blocksize):
  69. """For a given blocksize=(r,c) count the number of occupied
  70. blocks in a sparse matrix A
  71. """
  72. r,c = blocksize
  73. if r < 1 or c < 1:
  74. raise ValueError('r and c must be positive')
  75. if isspmatrix_csr(A):
  76. M,N = A.shape
  77. return csr_count_blocks(M,N,r,c,A.indptr,A.indices)
  78. elif isspmatrix_csc(A):
  79. return count_blocks(A.T,(c,r))
  80. else:
  81. return count_blocks(csr_matrix(A),blocksize)