test_streams.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. """ Testing
  2. """
  3. from __future__ import division, print_function, absolute_import
  4. import os
  5. import sys
  6. import zlib
  7. from io import BytesIO
  8. if sys.version_info[0] >= 3:
  9. cStringIO = BytesIO
  10. else:
  11. from cStringIO import StringIO as cStringIO
  12. from tempfile import mkstemp
  13. from contextlib import contextmanager
  14. import numpy as np
  15. from numpy.testing import assert_, assert_equal
  16. from pytest import raises as assert_raises
  17. from scipy.io.matlab.streams import (make_stream,
  18. GenericStream, cStringStream, FileStream, ZlibInputStream,
  19. _read_into, _read_string)
  20. IS_PYPY = ('__pypy__' in sys.modules)
  21. @contextmanager
  22. def setup_test_file():
  23. val = b'a\x00string'
  24. fd, fname = mkstemp()
  25. with os.fdopen(fd, 'wb') as fs:
  26. fs.write(val)
  27. with open(fname, 'rb') as fs:
  28. gs = BytesIO(val)
  29. cs = cStringIO(val)
  30. yield fs, gs, cs
  31. os.unlink(fname)
  32. def test_make_stream():
  33. with setup_test_file() as (fs, gs, cs):
  34. # test stream initialization
  35. assert_(isinstance(make_stream(gs), GenericStream))
  36. if sys.version_info[0] < 3 and not IS_PYPY:
  37. assert_(isinstance(make_stream(cs), cStringStream))
  38. assert_(isinstance(make_stream(fs), FileStream))
  39. def test_tell_seek():
  40. with setup_test_file() as (fs, gs, cs):
  41. for s in (fs, gs, cs):
  42. st = make_stream(s)
  43. res = st.seek(0)
  44. assert_equal(res, 0)
  45. assert_equal(st.tell(), 0)
  46. res = st.seek(5)
  47. assert_equal(res, 0)
  48. assert_equal(st.tell(), 5)
  49. res = st.seek(2, 1)
  50. assert_equal(res, 0)
  51. assert_equal(st.tell(), 7)
  52. res = st.seek(-2, 2)
  53. assert_equal(res, 0)
  54. assert_equal(st.tell(), 6)
  55. def test_read():
  56. with setup_test_file() as (fs, gs, cs):
  57. for s in (fs, gs, cs):
  58. st = make_stream(s)
  59. st.seek(0)
  60. res = st.read(-1)
  61. assert_equal(res, b'a\x00string')
  62. st.seek(0)
  63. res = st.read(4)
  64. assert_equal(res, b'a\x00st')
  65. # read into
  66. st.seek(0)
  67. res = _read_into(st, 4)
  68. assert_equal(res, b'a\x00st')
  69. res = _read_into(st, 4)
  70. assert_equal(res, b'ring')
  71. assert_raises(IOError, _read_into, st, 2)
  72. # read alloc
  73. st.seek(0)
  74. res = _read_string(st, 4)
  75. assert_equal(res, b'a\x00st')
  76. res = _read_string(st, 4)
  77. assert_equal(res, b'ring')
  78. assert_raises(IOError, _read_string, st, 2)
  79. class TestZlibInputStream(object):
  80. def _get_data(self, size):
  81. data = np.random.randint(0, 256, size).astype(np.uint8).tostring()
  82. compressed_data = zlib.compress(data)
  83. stream = BytesIO(compressed_data)
  84. return stream, len(compressed_data), data
  85. def test_read(self):
  86. block_size = 131072
  87. SIZES = [0, 1, 10, block_size//2, block_size-1,
  88. block_size, block_size+1, 2*block_size-1]
  89. READ_SIZES = [block_size//2, block_size-1,
  90. block_size, block_size+1]
  91. def check(size, read_size):
  92. compressed_stream, compressed_data_len, data = self._get_data(size)
  93. stream = ZlibInputStream(compressed_stream, compressed_data_len)
  94. data2 = b''
  95. so_far = 0
  96. while True:
  97. block = stream.read(min(read_size,
  98. size - so_far))
  99. if not block:
  100. break
  101. so_far += len(block)
  102. data2 += block
  103. assert_equal(data, data2)
  104. for size in SIZES:
  105. for read_size in READ_SIZES:
  106. check(size, read_size)
  107. def test_read_max_length(self):
  108. size = 1234
  109. data = np.random.randint(0, 256, size).astype(np.uint8).tostring()
  110. compressed_data = zlib.compress(data)
  111. compressed_stream = BytesIO(compressed_data + b"abbacaca")
  112. stream = ZlibInputStream(compressed_stream, len(compressed_data))
  113. stream.read(len(data))
  114. assert_equal(compressed_stream.tell(), len(compressed_data))
  115. assert_raises(IOError, stream.read, 1)
  116. def test_seek(self):
  117. compressed_stream, compressed_data_len, data = self._get_data(1024)
  118. stream = ZlibInputStream(compressed_stream, compressed_data_len)
  119. stream.seek(123)
  120. p = 123
  121. assert_equal(stream.tell(), p)
  122. d1 = stream.read(11)
  123. assert_equal(d1, data[p:p+11])
  124. stream.seek(321, 1)
  125. p = 123+11+321
  126. assert_equal(stream.tell(), p)
  127. d2 = stream.read(21)
  128. assert_equal(d2, data[p:p+21])
  129. stream.seek(641, 0)
  130. p = 641
  131. assert_equal(stream.tell(), p)
  132. d3 = stream.read(11)
  133. assert_equal(d3, data[p:p+11])
  134. assert_raises(IOError, stream.seek, 10, 2)
  135. assert_raises(IOError, stream.seek, -1, 1)
  136. assert_raises(ValueError, stream.seek, 1, 123)
  137. stream.seek(10000, 1)
  138. assert_raises(IOError, stream.read, 12)
  139. def test_all_data_read(self):
  140. compressed_stream, compressed_data_len, data = self._get_data(1024)
  141. stream = ZlibInputStream(compressed_stream, compressed_data_len)
  142. assert_(not stream.all_data_read())
  143. stream.seek(512)
  144. assert_(not stream.all_data_read())
  145. stream.seek(1024)
  146. assert_(stream.all_data_read())