test_fortran.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. ''' Tests for fortran sequential files '''
  2. import tempfile
  3. import shutil
  4. from os import path, unlink
  5. from glob import iglob
  6. import re
  7. from numpy.testing import assert_equal, assert_allclose
  8. import numpy as np
  9. from scipy.io import FortranFile, _test_fortran
# Directory named 'data' that sits next to this module; the tests below
# build the paths of their sample Fortran data files from it.
DATA_PATH = path.join(path.dirname(__file__), 'data')
  11. def test_fortranfiles_read():
  12. for filename in iglob(path.join(DATA_PATH, "fortran-*-*x*x*.dat")):
  13. m = re.search(r'fortran-([^-]+)-(\d+)x(\d+)x(\d+).dat', filename, re.I)
  14. if not m:
  15. raise RuntimeError("Couldn't match %s filename to regex" % filename)
  16. dims = (int(m.group(2)), int(m.group(3)), int(m.group(4)))
  17. dtype = m.group(1).replace('s', '<')
  18. f = FortranFile(filename, 'r', '<u4')
  19. data = f.read_record(dtype=dtype).reshape(dims, order='F')
  20. f.close()
  21. expected = np.arange(np.prod(dims)).reshape(dims).astype(dtype)
  22. assert_equal(data, expected)
  23. def test_fortranfiles_mixed_record():
  24. filename = path.join(DATA_PATH, "fortran-mixed.dat")
  25. with FortranFile(filename, 'r', '<u4') as f:
  26. record = f.read_record('<i4,<f4,<i8,(2)<f8')
  27. assert_equal(record['f0'][0], 1)
  28. assert_allclose(record['f1'][0], 2.3)
  29. assert_equal(record['f2'][0], 4)
  30. assert_allclose(record['f3'][0], [5.6, 7.8])
  31. def test_fortranfiles_write():
  32. for filename in iglob(path.join(DATA_PATH, "fortran-*-*x*x*.dat")):
  33. m = re.search(r'fortran-([^-]+)-(\d+)x(\d+)x(\d+).dat', filename, re.I)
  34. if not m:
  35. raise RuntimeError("Couldn't match %s filename to regex" % filename)
  36. dims = (int(m.group(2)), int(m.group(3)), int(m.group(4)))
  37. dtype = m.group(1).replace('s', '<')
  38. data = np.arange(np.prod(dims)).reshape(dims).astype(dtype)
  39. tmpdir = tempfile.mkdtemp()
  40. try:
  41. testFile = path.join(tmpdir,path.basename(filename))
  42. f = FortranFile(testFile, 'w','<u4')
  43. f.write_record(data.T)
  44. f.close()
  45. originalfile = open(filename, 'rb')
  46. newfile = open(testFile, 'rb')
  47. assert_equal(originalfile.read(), newfile.read(),
  48. err_msg=filename)
  49. originalfile.close()
  50. newfile.close()
  51. finally:
  52. shutil.rmtree(tmpdir)
  53. def test_fortranfile_read_mixed_record():
  54. # The data file fortran-3x3d-2i.dat contains the program that
  55. # produced it at the end.
  56. #
  57. # double precision :: a(3,3)
  58. # integer :: b(2)
  59. # ...
  60. # open(1, file='fortran-3x3d-2i.dat', form='unformatted')
  61. # write(1) a, b
  62. # close(1)
  63. #
  64. filename = path.join(DATA_PATH, "fortran-3x3d-2i.dat")
  65. with FortranFile(filename, 'r', '<u4') as f:
  66. record = f.read_record('(3,3)f8', '2i4')
  67. ax = np.arange(3*3).reshape(3, 3).astype(np.double)
  68. bx = np.array([-1, -2], dtype=np.int32)
  69. assert_equal(record[0], ax.T)
  70. assert_equal(record[1], bx.T)
  71. def test_fortranfile_write_mixed_record(tmpdir):
  72. tf = path.join(str(tmpdir), 'test.dat')
  73. records = [
  74. (('f4', 'f4', 'i4'), (np.float32(2), np.float32(3), np.int32(100))),
  75. (('4f4', '(3,3)f4', '8i4'), (np.random.randint(255, size=[4]).astype(np.float32),
  76. np.random.randint(255, size=[3, 3]).astype(np.float32),
  77. np.random.randint(255, size=[8]).astype(np.int32)))
  78. ]
  79. for dtype, a in records:
  80. with FortranFile(tf, 'w') as f:
  81. f.write_record(*a)
  82. with FortranFile(tf, 'r') as f:
  83. b = f.read_record(*dtype)
  84. assert_equal(len(a), len(b))
  85. for aa, bb in zip(a, b):
  86. assert_equal(bb, aa)
  87. def test_fortran_roundtrip(tmpdir):
  88. filename = path.join(str(tmpdir), 'test.dat')
  89. np.random.seed(1)
  90. # double precision
  91. m, n, k = 5, 3, 2
  92. a = np.random.randn(m, n, k)
  93. with FortranFile(filename, 'w') as f:
  94. f.write_record(a.T)
  95. a2 = _test_fortran.read_unformatted_double(m, n, k, filename)
  96. with FortranFile(filename, 'r') as f:
  97. a3 = f.read_record('(2,3,5)f8').T
  98. assert_equal(a2, a)
  99. assert_equal(a3, a)
  100. # integer
  101. m, n, k = 5, 3, 2
  102. a = np.random.randn(m, n, k).astype(np.int32)
  103. with FortranFile(filename, 'w') as f:
  104. f.write_record(a.T)
  105. a2 = _test_fortran.read_unformatted_int(m, n, k, filename)
  106. with FortranFile(filename, 'r') as f:
  107. a3 = f.read_record('(2,3,5)i4').T
  108. assert_equal(a2, a)
  109. assert_equal(a3, a)
  110. # mixed
  111. m, n, k = 5, 3, 2
  112. a = np.random.randn(m, n)
  113. b = np.random.randn(k).astype(np.intc)
  114. with FortranFile(filename, 'w') as f:
  115. f.write_record(a.T, b.T)
  116. a2, b2 = _test_fortran.read_unformatted_mixed(m, n, k, filename)
  117. with FortranFile(filename, 'r') as f:
  118. a3, b3 = f.read_record('(3,5)f8', '2i4')
  119. a3 = a3.T
  120. assert_equal(a2, a)
  121. assert_equal(a3, a)
  122. assert_equal(b2, b)
  123. assert_equal(b3, b)