123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184 |
- """ Testing
- """
- from __future__ import division, print_function, absolute_import
- import os
- import sys
- import zlib
- from io import BytesIO
- if sys.version_info[0] >= 3:
- cStringIO = BytesIO
- else:
- from cStringIO import StringIO as cStringIO
- from tempfile import mkstemp
- from contextlib import contextmanager
- import numpy as np
- from numpy.testing import assert_, assert_equal
- from pytest import raises as assert_raises
- from scipy.io.matlab.streams import (make_stream,
- GenericStream, cStringStream, FileStream, ZlibInputStream,
- _read_into, _read_string)
- IS_PYPY = ('__pypy__' in sys.modules)
- @contextmanager
- def setup_test_file():
- val = b'a\x00string'
- fd, fname = mkstemp()
- with os.fdopen(fd, 'wb') as fs:
- fs.write(val)
- with open(fname, 'rb') as fs:
- gs = BytesIO(val)
- cs = cStringIO(val)
- yield fs, gs, cs
- os.unlink(fname)
- def test_make_stream():
- with setup_test_file() as (fs, gs, cs):
- # test stream initialization
- assert_(isinstance(make_stream(gs), GenericStream))
- if sys.version_info[0] < 3 and not IS_PYPY:
- assert_(isinstance(make_stream(cs), cStringStream))
- assert_(isinstance(make_stream(fs), FileStream))
- def test_tell_seek():
- with setup_test_file() as (fs, gs, cs):
- for s in (fs, gs, cs):
- st = make_stream(s)
- res = st.seek(0)
- assert_equal(res, 0)
- assert_equal(st.tell(), 0)
- res = st.seek(5)
- assert_equal(res, 0)
- assert_equal(st.tell(), 5)
- res = st.seek(2, 1)
- assert_equal(res, 0)
- assert_equal(st.tell(), 7)
- res = st.seek(-2, 2)
- assert_equal(res, 0)
- assert_equal(st.tell(), 6)
- def test_read():
- with setup_test_file() as (fs, gs, cs):
- for s in (fs, gs, cs):
- st = make_stream(s)
- st.seek(0)
- res = st.read(-1)
- assert_equal(res, b'a\x00string')
- st.seek(0)
- res = st.read(4)
- assert_equal(res, b'a\x00st')
- # read into
- st.seek(0)
- res = _read_into(st, 4)
- assert_equal(res, b'a\x00st')
- res = _read_into(st, 4)
- assert_equal(res, b'ring')
- assert_raises(IOError, _read_into, st, 2)
- # read alloc
- st.seek(0)
- res = _read_string(st, 4)
- assert_equal(res, b'a\x00st')
- res = _read_string(st, 4)
- assert_equal(res, b'ring')
- assert_raises(IOError, _read_string, st, 2)
- class TestZlibInputStream(object):
- def _get_data(self, size):
- data = np.random.randint(0, 256, size).astype(np.uint8).tostring()
- compressed_data = zlib.compress(data)
- stream = BytesIO(compressed_data)
- return stream, len(compressed_data), data
- def test_read(self):
- block_size = 131072
- SIZES = [0, 1, 10, block_size//2, block_size-1,
- block_size, block_size+1, 2*block_size-1]
- READ_SIZES = [block_size//2, block_size-1,
- block_size, block_size+1]
- def check(size, read_size):
- compressed_stream, compressed_data_len, data = self._get_data(size)
- stream = ZlibInputStream(compressed_stream, compressed_data_len)
- data2 = b''
- so_far = 0
- while True:
- block = stream.read(min(read_size,
- size - so_far))
- if not block:
- break
- so_far += len(block)
- data2 += block
- assert_equal(data, data2)
- for size in SIZES:
- for read_size in READ_SIZES:
- check(size, read_size)
- def test_read_max_length(self):
- size = 1234
- data = np.random.randint(0, 256, size).astype(np.uint8).tostring()
- compressed_data = zlib.compress(data)
- compressed_stream = BytesIO(compressed_data + b"abbacaca")
- stream = ZlibInputStream(compressed_stream, len(compressed_data))
- stream.read(len(data))
- assert_equal(compressed_stream.tell(), len(compressed_data))
- assert_raises(IOError, stream.read, 1)
- def test_seek(self):
- compressed_stream, compressed_data_len, data = self._get_data(1024)
- stream = ZlibInputStream(compressed_stream, compressed_data_len)
- stream.seek(123)
- p = 123
- assert_equal(stream.tell(), p)
- d1 = stream.read(11)
- assert_equal(d1, data[p:p+11])
- stream.seek(321, 1)
- p = 123+11+321
- assert_equal(stream.tell(), p)
- d2 = stream.read(21)
- assert_equal(d2, data[p:p+21])
- stream.seek(641, 0)
- p = 641
- assert_equal(stream.tell(), p)
- d3 = stream.read(11)
- assert_equal(d3, data[p:p+11])
- assert_raises(IOError, stream.seek, 10, 2)
- assert_raises(IOError, stream.seek, -1, 1)
- assert_raises(ValueError, stream.seek, 1, 123)
- stream.seek(10000, 1)
- assert_raises(IOError, stream.read, 12)
- def test_all_data_read(self):
- compressed_stream, compressed_data_len, data = self._get_data(1024)
- stream = ZlibInputStream(compressed_stream, compressed_data_len)
- assert_(not stream.all_data_read())
- stream.seek(512)
- assert_(not stream.all_data_read())
- stream.seek(1024)
- assert_(stream.all_data_read())
|