mmio.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835
  1. """
  2. Matrix Market I/O in Python.
  3. See http://math.nist.gov/MatrixMarket/formats.html
  4. for information about the Matrix Market format.
  5. """
  6. #
  7. # Author: Pearu Peterson <pearu@cens.ioc.ee>
  8. # Created: October, 2004
  9. #
  10. # References:
  11. # http://math.nist.gov/MatrixMarket/
  12. #
  13. from __future__ import division, print_function, absolute_import
  14. import os
  15. import sys
  16. from numpy import (asarray, real, imag, conj, zeros, ndarray, concatenate,
  17. ones, can_cast)
  18. from numpy.compat import asbytes, asstr
  19. from scipy._lib.six import string_types
  20. from scipy.sparse import coo_matrix, isspmatrix
  21. __all__ = ['mminfo', 'mmread', 'mmwrite', 'MMFile']
  22. # -----------------------------------------------------------------------------
  23. def mminfo(source):
  24. """
  25. Return size and storage parameters from Matrix Market file-like 'source'.
  26. Parameters
  27. ----------
  28. source : str or file-like
  29. Matrix Market filename (extension .mtx) or open file-like object
  30. Returns
  31. -------
  32. rows : int
  33. Number of matrix rows.
  34. cols : int
  35. Number of matrix columns.
  36. entries : int
  37. Number of non-zero entries of a sparse matrix
  38. or rows*cols for a dense matrix.
  39. format : str
  40. Either 'coordinate' or 'array'.
  41. field : str
  42. Either 'real', 'complex', 'pattern', or 'integer'.
  43. symmetry : str
  44. Either 'general', 'symmetric', 'skew-symmetric', or 'hermitian'.
  45. """
  46. return MMFile.info(source)
  47. # -----------------------------------------------------------------------------
  48. def mmread(source):
  49. """
  50. Reads the contents of a Matrix Market file-like 'source' into a matrix.
  51. Parameters
  52. ----------
  53. source : str or file-like
  54. Matrix Market filename (extensions .mtx, .mtz.gz)
  55. or open file-like object.
  56. Returns
  57. -------
  58. a : ndarray or coo_matrix
  59. Dense or sparse matrix depending on the matrix format in the
  60. Matrix Market file.
  61. """
  62. return MMFile().read(source)
  63. # -----------------------------------------------------------------------------
  64. def mmwrite(target, a, comment='', field=None, precision=None, symmetry=None):
  65. """
  66. Writes the sparse or dense array `a` to Matrix Market file-like `target`.
  67. Parameters
  68. ----------
  69. target : str or file-like
  70. Matrix Market filename (extension .mtx) or open file-like object.
  71. a : array like
  72. Sparse or dense 2D array.
  73. comment : str, optional
  74. Comments to be prepended to the Matrix Market file.
  75. field : None or str, optional
  76. Either 'real', 'complex', 'pattern', or 'integer'.
  77. precision : None or int, optional
  78. Number of digits to display for real or complex values.
  79. symmetry : None or str, optional
  80. Either 'general', 'symmetric', 'skew-symmetric', or 'hermitian'.
  81. If symmetry is None the symmetry type of 'a' is determined by its
  82. values.
  83. """
  84. MMFile().write(target, a, comment, field, precision, symmetry)
  85. ###############################################################################
  86. class MMFile (object):
  87. __slots__ = ('_rows',
  88. '_cols',
  89. '_entries',
  90. '_format',
  91. '_field',
  92. '_symmetry')
  93. @property
  94. def rows(self):
  95. return self._rows
  96. @property
  97. def cols(self):
  98. return self._cols
  99. @property
  100. def entries(self):
  101. return self._entries
  102. @property
  103. def format(self):
  104. return self._format
  105. @property
  106. def field(self):
  107. return self._field
  108. @property
  109. def symmetry(self):
  110. return self._symmetry
  111. @property
  112. def has_symmetry(self):
  113. return self._symmetry in (self.SYMMETRY_SYMMETRIC,
  114. self.SYMMETRY_SKEW_SYMMETRIC,
  115. self.SYMMETRY_HERMITIAN)
  116. # format values
  117. FORMAT_COORDINATE = 'coordinate'
  118. FORMAT_ARRAY = 'array'
  119. FORMAT_VALUES = (FORMAT_COORDINATE, FORMAT_ARRAY)
  120. @classmethod
  121. def _validate_format(self, format):
  122. if format not in self.FORMAT_VALUES:
  123. raise ValueError('unknown format type %s, must be one of %s' %
  124. (format, self.FORMAT_VALUES))
  125. # field values
  126. FIELD_INTEGER = 'integer'
  127. FIELD_UNSIGNED = 'unsigned-integer'
  128. FIELD_REAL = 'real'
  129. FIELD_COMPLEX = 'complex'
  130. FIELD_PATTERN = 'pattern'
  131. FIELD_VALUES = (FIELD_INTEGER, FIELD_UNSIGNED, FIELD_REAL, FIELD_COMPLEX, FIELD_PATTERN)
  132. @classmethod
  133. def _validate_field(self, field):
  134. if field not in self.FIELD_VALUES:
  135. raise ValueError('unknown field type %s, must be one of %s' %
  136. (field, self.FIELD_VALUES))
  137. # symmetry values
  138. SYMMETRY_GENERAL = 'general'
  139. SYMMETRY_SYMMETRIC = 'symmetric'
  140. SYMMETRY_SKEW_SYMMETRIC = 'skew-symmetric'
  141. SYMMETRY_HERMITIAN = 'hermitian'
  142. SYMMETRY_VALUES = (SYMMETRY_GENERAL, SYMMETRY_SYMMETRIC,
  143. SYMMETRY_SKEW_SYMMETRIC, SYMMETRY_HERMITIAN)
  144. @classmethod
  145. def _validate_symmetry(self, symmetry):
  146. if symmetry not in self.SYMMETRY_VALUES:
  147. raise ValueError('unknown symmetry type %s, must be one of %s' %
  148. (symmetry, self.SYMMETRY_VALUES))
  149. DTYPES_BY_FIELD = {FIELD_INTEGER: 'intp',
  150. FIELD_UNSIGNED: 'uint64',
  151. FIELD_REAL: 'd',
  152. FIELD_COMPLEX: 'D',
  153. FIELD_PATTERN: 'd'}
  154. # -------------------------------------------------------------------------
  155. @staticmethod
  156. def reader():
  157. pass
  158. # -------------------------------------------------------------------------
  159. @staticmethod
  160. def writer():
  161. pass
  162. # -------------------------------------------------------------------------
  163. @classmethod
  164. def info(self, source):
  165. """
  166. Return size, storage parameters from Matrix Market file-like 'source'.
  167. Parameters
  168. ----------
  169. source : str or file-like
  170. Matrix Market filename (extension .mtx) or open file-like object
  171. Returns
  172. -------
  173. rows : int
  174. Number of matrix rows.
  175. cols : int
  176. Number of matrix columns.
  177. entries : int
  178. Number of non-zero entries of a sparse matrix
  179. or rows*cols for a dense matrix.
  180. format : str
  181. Either 'coordinate' or 'array'.
  182. field : str
  183. Either 'real', 'complex', 'pattern', or 'integer'.
  184. symmetry : str
  185. Either 'general', 'symmetric', 'skew-symmetric', or 'hermitian'.
  186. """
  187. stream, close_it = self._open(source)
  188. try:
  189. # read and validate header line
  190. line = stream.readline()
  191. mmid, matrix, format, field, symmetry = \
  192. [asstr(part.strip()) for part in line.split()]
  193. if not mmid.startswith('%%MatrixMarket'):
  194. raise ValueError('source is not in Matrix Market format')
  195. if not matrix.lower() == 'matrix':
  196. raise ValueError("Problem reading file header: " + line)
  197. # http://math.nist.gov/MatrixMarket/formats.html
  198. if format.lower() == 'array':
  199. format = self.FORMAT_ARRAY
  200. elif format.lower() == 'coordinate':
  201. format = self.FORMAT_COORDINATE
  202. # skip comments
  203. while line.startswith(b'%'):
  204. line = stream.readline()
  205. line = line.split()
  206. if format == self.FORMAT_ARRAY:
  207. if not len(line) == 2:
  208. raise ValueError("Header line not of length 2: " + line)
  209. rows, cols = map(int, line)
  210. entries = rows * cols
  211. else:
  212. if not len(line) == 3:
  213. raise ValueError("Header line not of length 3: " + line)
  214. rows, cols, entries = map(int, line)
  215. return (rows, cols, entries, format, field.lower(),
  216. symmetry.lower())
  217. finally:
  218. if close_it:
  219. stream.close()
  220. # -------------------------------------------------------------------------
  221. @staticmethod
  222. def _open(filespec, mode='rb'):
  223. """ Return an open file stream for reading based on source.
  224. If source is a file name, open it (after trying to find it with mtx and
  225. gzipped mtx extensions). Otherwise, just return source.
  226. Parameters
  227. ----------
  228. filespec : str or file-like
  229. String giving file name or file-like object
  230. mode : str, optional
  231. Mode with which to open file, if `filespec` is a file name.
  232. Returns
  233. -------
  234. fobj : file-like
  235. Open file-like object.
  236. close_it : bool
  237. True if the calling function should close this file when done,
  238. false otherwise.
  239. """
  240. close_it = False
  241. if isinstance(filespec, string_types):
  242. close_it = True
  243. # open for reading
  244. if mode[0] == 'r':
  245. # determine filename plus extension
  246. if not os.path.isfile(filespec):
  247. if os.path.isfile(filespec+'.mtx'):
  248. filespec = filespec + '.mtx'
  249. elif os.path.isfile(filespec+'.mtx.gz'):
  250. filespec = filespec + '.mtx.gz'
  251. elif os.path.isfile(filespec+'.mtx.bz2'):
  252. filespec = filespec + '.mtx.bz2'
  253. # open filename
  254. if filespec.endswith('.gz'):
  255. import gzip
  256. stream = gzip.open(filespec, mode)
  257. elif filespec.endswith('.bz2'):
  258. import bz2
  259. stream = bz2.BZ2File(filespec, 'rb')
  260. else:
  261. stream = open(filespec, mode)
  262. # open for writing
  263. else:
  264. if filespec[-4:] != '.mtx':
  265. filespec = filespec + '.mtx'
  266. stream = open(filespec, mode)
  267. else:
  268. stream = filespec
  269. return stream, close_it
  270. # -------------------------------------------------------------------------
  271. @staticmethod
  272. def _get_symmetry(a):
  273. m, n = a.shape
  274. if m != n:
  275. return MMFile.SYMMETRY_GENERAL
  276. issymm = True
  277. isskew = True
  278. isherm = a.dtype.char in 'FD'
  279. # sparse input
  280. if isspmatrix(a):
  281. # check if number of nonzero entries of lower and upper triangle
  282. # matrix are equal
  283. a = a.tocoo()
  284. (row, col) = a.nonzero()
  285. if (row < col).sum() != (row > col).sum():
  286. return MMFile.SYMMETRY_GENERAL
  287. # define iterator over symmetric pair entries
  288. a = a.todok()
  289. def symm_iterator():
  290. for ((i, j), aij) in a.items():
  291. if i > j:
  292. aji = a[j, i]
  293. yield (aij, aji)
  294. # non-sparse input
  295. else:
  296. # define iterator over symmetric pair entries
  297. def symm_iterator():
  298. for j in range(n):
  299. for i in range(j+1, n):
  300. aij, aji = a[i][j], a[j][i]
  301. yield (aij, aji)
  302. # check for symmetry
  303. for (aij, aji) in symm_iterator():
  304. if issymm and aij != aji:
  305. issymm = False
  306. if isskew and aij != -aji:
  307. isskew = False
  308. if isherm and aij != conj(aji):
  309. isherm = False
  310. if not (issymm or isskew or isherm):
  311. break
  312. # return symmetry value
  313. if issymm:
  314. return MMFile.SYMMETRY_SYMMETRIC
  315. if isskew:
  316. return MMFile.SYMMETRY_SKEW_SYMMETRIC
  317. if isherm:
  318. return MMFile.SYMMETRY_HERMITIAN
  319. return MMFile.SYMMETRY_GENERAL
  320. # -------------------------------------------------------------------------
  321. @staticmethod
  322. def _field_template(field, precision):
  323. return {MMFile.FIELD_REAL: '%%.%ie\n' % precision,
  324. MMFile.FIELD_INTEGER: '%i\n',
  325. MMFile.FIELD_UNSIGNED: '%u\n',
  326. MMFile.FIELD_COMPLEX: '%%.%ie %%.%ie\n' %
  327. (precision, precision)
  328. }.get(field, None)
  329. # -------------------------------------------------------------------------
  330. def __init__(self, **kwargs):
  331. self._init_attrs(**kwargs)
  332. # -------------------------------------------------------------------------
  333. def read(self, source):
  334. """
  335. Reads the contents of a Matrix Market file-like 'source' into a matrix.
  336. Parameters
  337. ----------
  338. source : str or file-like
  339. Matrix Market filename (extensions .mtx, .mtz.gz)
  340. or open file object.
  341. Returns
  342. -------
  343. a : ndarray or coo_matrix
  344. Dense or sparse matrix depending on the matrix format in the
  345. Matrix Market file.
  346. """
  347. stream, close_it = self._open(source)
  348. try:
  349. self._parse_header(stream)
  350. return self._parse_body(stream)
  351. finally:
  352. if close_it:
  353. stream.close()
  354. # -------------------------------------------------------------------------
  355. def write(self, target, a, comment='', field=None, precision=None,
  356. symmetry=None):
  357. """
  358. Writes sparse or dense array `a` to Matrix Market file-like `target`.
  359. Parameters
  360. ----------
  361. target : str or file-like
  362. Matrix Market filename (extension .mtx) or open file-like object.
  363. a : array like
  364. Sparse or dense 2D array.
  365. comment : str, optional
  366. Comments to be prepended to the Matrix Market file.
  367. field : None or str, optional
  368. Either 'real', 'complex', 'pattern', or 'integer'.
  369. precision : None or int, optional
  370. Number of digits to display for real or complex values.
  371. symmetry : None or str, optional
  372. Either 'general', 'symmetric', 'skew-symmetric', or 'hermitian'.
  373. If symmetry is None the symmetry type of 'a' is determined by its
  374. values.
  375. """
  376. stream, close_it = self._open(target, 'wb')
  377. try:
  378. self._write(stream, a, comment, field, precision, symmetry)
  379. finally:
  380. if close_it:
  381. stream.close()
  382. else:
  383. stream.flush()
  384. # -------------------------------------------------------------------------
  385. def _init_attrs(self, **kwargs):
  386. """
  387. Initialize each attributes with the corresponding keyword arg value
  388. or a default of None
  389. """
  390. attrs = self.__class__.__slots__
  391. public_attrs = [attr[1:] for attr in attrs]
  392. invalid_keys = set(kwargs.keys()) - set(public_attrs)
  393. if invalid_keys:
  394. raise ValueError('''found %s invalid keyword arguments, please only
  395. use %s''' % (tuple(invalid_keys),
  396. public_attrs))
  397. for attr in attrs:
  398. setattr(self, attr, kwargs.get(attr[1:], None))
  399. # -------------------------------------------------------------------------
  400. def _parse_header(self, stream):
  401. rows, cols, entries, format, field, symmetry = \
  402. self.__class__.info(stream)
  403. self._init_attrs(rows=rows, cols=cols, entries=entries, format=format,
  404. field=field, symmetry=symmetry)
  405. # -------------------------------------------------------------------------
  406. def _parse_body(self, stream):
  407. rows, cols, entries, format, field, symm = (self.rows, self.cols,
  408. self.entries, self.format,
  409. self.field, self.symmetry)
  410. try:
  411. from scipy.sparse import coo_matrix
  412. except ImportError:
  413. coo_matrix = None
  414. dtype = self.DTYPES_BY_FIELD.get(field, None)
  415. has_symmetry = self.has_symmetry
  416. is_integer = field == self.FIELD_INTEGER
  417. is_unsigned_integer = field == self.FIELD_UNSIGNED
  418. is_complex = field == self.FIELD_COMPLEX
  419. is_skew = symm == self.SYMMETRY_SKEW_SYMMETRIC
  420. is_herm = symm == self.SYMMETRY_HERMITIAN
  421. is_pattern = field == self.FIELD_PATTERN
  422. if format == self.FORMAT_ARRAY:
  423. a = zeros((rows, cols), dtype=dtype)
  424. line = 1
  425. i, j = 0, 0
  426. if is_skew:
  427. a[i, j] = 0
  428. if i < rows - 1:
  429. i += 1
  430. while line:
  431. line = stream.readline()
  432. if not line or line.startswith(b'%'):
  433. continue
  434. if is_integer:
  435. aij = int(line)
  436. elif is_unsigned_integer:
  437. aij = int(line)
  438. elif is_complex:
  439. aij = complex(*map(float, line.split()))
  440. else:
  441. aij = float(line)
  442. a[i, j] = aij
  443. if has_symmetry and i != j:
  444. if is_skew:
  445. a[j, i] = -aij
  446. elif is_herm:
  447. a[j, i] = conj(aij)
  448. else:
  449. a[j, i] = aij
  450. if i < rows-1:
  451. i = i + 1
  452. else:
  453. j = j + 1
  454. if not has_symmetry:
  455. i = 0
  456. else:
  457. i = j
  458. if is_skew:
  459. a[i, j] = 0
  460. if i < rows-1:
  461. i += 1
  462. if is_skew:
  463. if not (i in [0, j] and j == cols - 1):
  464. raise ValueError("Parse error, did not read all lines.")
  465. else:
  466. if not (i in [0, j] and j == cols):
  467. raise ValueError("Parse error, did not read all lines.")
  468. elif format == self.FORMAT_COORDINATE and coo_matrix is None:
  469. # Read sparse matrix to dense when coo_matrix is not available.
  470. a = zeros((rows, cols), dtype=dtype)
  471. line = 1
  472. k = 0
  473. while line:
  474. line = stream.readline()
  475. if not line or line.startswith(b'%'):
  476. continue
  477. l = line.split()
  478. i, j = map(int, l[:2])
  479. i, j = i-1, j-1
  480. if is_integer:
  481. aij = int(l[2])
  482. elif is_unsigned_integer:
  483. aij = int(l[2])
  484. elif is_complex:
  485. aij = complex(*map(float, l[2:]))
  486. else:
  487. aij = float(l[2])
  488. a[i, j] = aij
  489. if has_symmetry and i != j:
  490. if is_skew:
  491. a[j, i] = -aij
  492. elif is_herm:
  493. a[j, i] = conj(aij)
  494. else:
  495. a[j, i] = aij
  496. k = k + 1
  497. if not k == entries:
  498. ValueError("Did not read all entries")
  499. elif format == self.FORMAT_COORDINATE:
  500. # Read sparse COOrdinate format
  501. if entries == 0:
  502. # empty matrix
  503. return coo_matrix((rows, cols), dtype=dtype)
  504. I = zeros(entries, dtype='intc')
  505. J = zeros(entries, dtype='intc')
  506. if is_pattern:
  507. V = ones(entries, dtype='int8')
  508. elif is_integer:
  509. V = zeros(entries, dtype='intp')
  510. elif is_unsigned_integer:
  511. V = zeros(entries, dtype='uint64')
  512. elif is_complex:
  513. V = zeros(entries, dtype='complex')
  514. else:
  515. V = zeros(entries, dtype='float')
  516. entry_number = 0
  517. for line in stream:
  518. if not line or line.startswith(b'%'):
  519. continue
  520. if entry_number+1 > entries:
  521. raise ValueError("'entries' in header is smaller than "
  522. "number of entries")
  523. l = line.split()
  524. I[entry_number], J[entry_number] = map(int, l[:2])
  525. if not is_pattern:
  526. if is_integer:
  527. V[entry_number] = int(l[2])
  528. elif is_unsigned_integer:
  529. V[entry_number] = int(l[2])
  530. elif is_complex:
  531. V[entry_number] = complex(*map(float, l[2:]))
  532. else:
  533. V[entry_number] = float(l[2])
  534. entry_number += 1
  535. if entry_number < entries:
  536. raise ValueError("'entries' in header is larger than "
  537. "number of entries")
  538. I -= 1 # adjust indices (base 1 -> base 0)
  539. J -= 1
  540. if has_symmetry:
  541. mask = (I != J) # off diagonal mask
  542. od_I = I[mask]
  543. od_J = J[mask]
  544. od_V = V[mask]
  545. I = concatenate((I, od_J))
  546. J = concatenate((J, od_I))
  547. if is_skew:
  548. od_V *= -1
  549. elif is_herm:
  550. od_V = od_V.conjugate()
  551. V = concatenate((V, od_V))
  552. a = coo_matrix((V, (I, J)), shape=(rows, cols), dtype=dtype)
  553. else:
  554. raise NotImplementedError(format)
  555. return a
  556. # ------------------------------------------------------------------------
  557. def _write(self, stream, a, comment='', field=None, precision=None,
  558. symmetry=None):
  559. if isinstance(a, list) or isinstance(a, ndarray) or \
  560. isinstance(a, tuple) or hasattr(a, '__array__'):
  561. rep = self.FORMAT_ARRAY
  562. a = asarray(a)
  563. if len(a.shape) != 2:
  564. raise ValueError('Expected 2 dimensional array')
  565. rows, cols = a.shape
  566. if field is not None:
  567. if field == self.FIELD_INTEGER:
  568. if not can_cast(a.dtype, 'intp'):
  569. raise OverflowError("mmwrite does not support integer "
  570. "dtypes larger than native 'intp'.")
  571. a = a.astype('intp')
  572. elif field == self.FIELD_REAL:
  573. if a.dtype.char not in 'fd':
  574. a = a.astype('d')
  575. elif field == self.FIELD_COMPLEX:
  576. if a.dtype.char not in 'FD':
  577. a = a.astype('D')
  578. else:
  579. if not isspmatrix(a):
  580. raise ValueError('unknown matrix type: %s' % type(a))
  581. rep = 'coordinate'
  582. rows, cols = a.shape
  583. typecode = a.dtype.char
  584. if precision is None:
  585. if typecode in 'fF':
  586. precision = 8
  587. else:
  588. precision = 16
  589. if field is None:
  590. kind = a.dtype.kind
  591. if kind == 'i':
  592. if not can_cast(a.dtype, 'intp'):
  593. raise OverflowError("mmwrite does not support integer "
  594. "dtypes larger than native 'intp'.")
  595. field = 'integer'
  596. elif kind == 'f':
  597. field = 'real'
  598. elif kind == 'c':
  599. field = 'complex'
  600. elif kind == 'u':
  601. field = 'unsigned-integer'
  602. else:
  603. raise TypeError('unexpected dtype kind ' + kind)
  604. if symmetry is None:
  605. symmetry = self._get_symmetry(a)
  606. # validate rep, field, and symmetry
  607. self.__class__._validate_format(rep)
  608. self.__class__._validate_field(field)
  609. self.__class__._validate_symmetry(symmetry)
  610. # write initial header line
  611. stream.write(asbytes('%%MatrixMarket matrix {0} {1} {2}\n'.format(rep,
  612. field, symmetry)))
  613. # write comments
  614. for line in comment.split('\n'):
  615. stream.write(asbytes('%%%s\n' % (line)))
  616. template = self._field_template(field, precision)
  617. # write dense format
  618. if rep == self.FORMAT_ARRAY:
  619. # write shape spec
  620. stream.write(asbytes('%i %i\n' % (rows, cols)))
  621. if field in (self.FIELD_INTEGER, self.FIELD_REAL, self.FIELD_UNSIGNED):
  622. if symmetry == self.SYMMETRY_GENERAL:
  623. for j in range(cols):
  624. for i in range(rows):
  625. stream.write(asbytes(template % a[i, j]))
  626. elif symmetry == self.SYMMETRY_SKEW_SYMMETRIC:
  627. for j in range(cols):
  628. for i in range(j + 1, rows):
  629. stream.write(asbytes(template % a[i, j]))
  630. else:
  631. for j in range(cols):
  632. for i in range(j, rows):
  633. stream.write(asbytes(template % a[i, j]))
  634. elif field == self.FIELD_COMPLEX:
  635. if symmetry == self.SYMMETRY_GENERAL:
  636. for j in range(cols):
  637. for i in range(rows):
  638. aij = a[i, j]
  639. stream.write(asbytes(template % (real(aij),
  640. imag(aij))))
  641. else:
  642. for j in range(cols):
  643. for i in range(j, rows):
  644. aij = a[i, j]
  645. stream.write(asbytes(template % (real(aij),
  646. imag(aij))))
  647. elif field == self.FIELD_PATTERN:
  648. raise ValueError('pattern type inconsisted with dense format')
  649. else:
  650. raise TypeError('Unknown field type %s' % field)
  651. # write sparse format
  652. else:
  653. coo = a.tocoo() # convert to COOrdinate format
  654. # if symmetry format used, remove values above main diagonal
  655. if symmetry != self.SYMMETRY_GENERAL:
  656. lower_triangle_mask = coo.row >= coo.col
  657. coo = coo_matrix((coo.data[lower_triangle_mask],
  658. (coo.row[lower_triangle_mask],
  659. coo.col[lower_triangle_mask])),
  660. shape=coo.shape)
  661. # write shape spec
  662. stream.write(asbytes('%i %i %i\n' % (rows, cols, coo.nnz)))
  663. template = self._field_template(field, precision-1)
  664. if field == self.FIELD_PATTERN:
  665. for r, c in zip(coo.row+1, coo.col+1):
  666. stream.write(asbytes("%i %i\n" % (r, c)))
  667. elif field in (self.FIELD_INTEGER, self.FIELD_REAL, self.FIELD_UNSIGNED):
  668. for r, c, d in zip(coo.row+1, coo.col+1, coo.data):
  669. stream.write(asbytes(("%i %i " % (r, c)) +
  670. (template % d)))
  671. elif field == self.FIELD_COMPLEX:
  672. for r, c, d in zip(coo.row+1, coo.col+1, coo.data):
  673. stream.write(asbytes(("%i %i " % (r, c)) +
  674. (template % (d.real, d.imag))))
  675. else:
  676. raise TypeError('Unknown field type %s' % field)
  677. def _is_fromfile_compatible(stream):
  678. """
  679. Check whether `stream` is compatible with numpy.fromfile.
  680. Passing a gzipped file object to ``fromfile/fromstring`` doesn't work with
  681. Python3.
  682. """
  683. if sys.version_info[0] < 3:
  684. return True
  685. bad_cls = []
  686. try:
  687. import gzip
  688. bad_cls.append(gzip.GzipFile)
  689. except ImportError:
  690. pass
  691. try:
  692. import bz2
  693. bad_cls.append(bz2.BZ2File)
  694. except ImportError:
  695. pass
  696. bad_cls = tuple(bad_cls)
  697. return not isinstance(stream, bad_cls)
  698. # -----------------------------------------------------------------------------
  699. if __name__ == '__main__':
  700. import time
  701. for filename in sys.argv[1:]:
  702. print('Reading', filename, '...', end=' ')
  703. sys.stdout.flush()
  704. t = time.time()
  705. mmread(filename)
  706. print('took %s seconds' % (time.time() - t))