test_wavfile.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. from __future__ import division, print_function, absolute_import
  2. import os
  3. import sys
  4. import tempfile
  5. from io import BytesIO
  6. import numpy as np
  7. from numpy.testing import assert_equal, assert_, assert_array_equal
  8. from pytest import raises as assert_raises
  9. from scipy._lib._numpy_compat import suppress_warnings
  10. from scipy.io import wavfile
  11. def datafile(fn):
  12. return os.path.join(os.path.dirname(__file__), 'data', fn)
  13. def test_read_1():
  14. for mmap in [False, True]:
  15. rate, data = wavfile.read(datafile('test-44100Hz-le-1ch-4bytes.wav'),
  16. mmap=mmap)
  17. assert_equal(rate, 44100)
  18. assert_(np.issubdtype(data.dtype, np.int32))
  19. assert_equal(data.shape, (4410,))
  20. del data
  21. def test_read_2():
  22. for mmap in [False, True]:
  23. rate, data = wavfile.read(datafile('test-8000Hz-le-2ch-1byteu.wav'),
  24. mmap=mmap)
  25. assert_equal(rate, 8000)
  26. assert_(np.issubdtype(data.dtype, np.uint8))
  27. assert_equal(data.shape, (800, 2))
  28. del data
  29. def test_read_3():
  30. for mmap in [False, True]:
  31. rate, data = wavfile.read(datafile('test-44100Hz-2ch-32bit-float-le.wav'),
  32. mmap=mmap)
  33. assert_equal(rate, 44100)
  34. assert_(np.issubdtype(data.dtype, np.float32))
  35. assert_equal(data.shape, (441, 2))
  36. del data
  37. def test_read_4():
  38. for mmap in [False, True]:
  39. with suppress_warnings() as sup:
  40. sup.filter(wavfile.WavFileWarning,
  41. "Chunk .non-data. not understood, skipping it")
  42. rate, data = wavfile.read(datafile('test-48000Hz-2ch-64bit-float-le-wavex.wav'),
  43. mmap=mmap)
  44. assert_equal(rate, 48000)
  45. assert_(np.issubdtype(data.dtype, np.float64))
  46. assert_equal(data.shape, (480, 2))
  47. del data
  48. def test_read_5():
  49. for mmap in [False, True]:
  50. rate, data = wavfile.read(datafile('test-44100Hz-2ch-32bit-float-be.wav'),
  51. mmap=mmap)
  52. assert_equal(rate, 44100)
  53. assert_(np.issubdtype(data.dtype, np.float32))
  54. assert_(data.dtype.byteorder == '>' or (sys.byteorder == 'big' and
  55. data.dtype.byteorder == '='))
  56. assert_equal(data.shape, (441, 2))
  57. del data
  58. def test_read_fail():
  59. for mmap in [False, True]:
  60. fp = open(datafile('example_1.nc'), 'rb')
  61. assert_raises(ValueError, wavfile.read, fp, mmap=mmap)
  62. fp.close()
  63. def test_read_early_eof():
  64. for mmap in [False, True]:
  65. fp = open(datafile('test-44100Hz-le-1ch-4bytes-early-eof.wav'), 'rb')
  66. assert_raises(ValueError, wavfile.read, fp, mmap=mmap)
  67. fp.close()
  68. def test_read_incomplete_chunk():
  69. for mmap in [False, True]:
  70. fp = open(datafile('test-44100Hz-le-1ch-4bytes-incomplete-chunk.wav'), 'rb')
  71. assert_raises(ValueError, wavfile.read, fp, mmap=mmap)
  72. fp.close()
  73. def _check_roundtrip(realfile, rate, dtype, channels):
  74. if realfile:
  75. fd, tmpfile = tempfile.mkstemp(suffix='.wav')
  76. os.close(fd)
  77. else:
  78. tmpfile = BytesIO()
  79. try:
  80. data = np.random.rand(100, channels)
  81. if channels == 1:
  82. data = data[:,0]
  83. if dtype.kind == 'f':
  84. # The range of the float type should be in [-1, 1]
  85. data = data.astype(dtype)
  86. else:
  87. data = (data*128).astype(dtype)
  88. wavfile.write(tmpfile, rate, data)
  89. for mmap in [False, True]:
  90. rate2, data2 = wavfile.read(tmpfile, mmap=mmap)
  91. assert_equal(rate, rate2)
  92. assert_(data2.dtype.byteorder in ('<', '=', '|'), msg=data2.dtype)
  93. assert_array_equal(data, data2)
  94. del data2
  95. finally:
  96. if realfile:
  97. os.unlink(tmpfile)
  98. def test_write_roundtrip():
  99. for realfile in (False, True):
  100. for dtypechar in ('i', 'u', 'f', 'g', 'q'):
  101. for size in (1, 2, 4, 8):
  102. if size == 1 and dtypechar == 'i':
  103. # signed 8-bit integer PCM is not allowed
  104. continue
  105. if size > 1 and dtypechar == 'u':
  106. # unsigned > 8-bit integer PCM is not allowed
  107. continue
  108. if (size == 1 or size == 2) and dtypechar == 'f':
  109. # 8- or 16-bit float PCM is not expected
  110. continue
  111. if dtypechar in 'gq':
  112. # no size allowed for these types
  113. if size == 1:
  114. size = ''
  115. else:
  116. continue
  117. for endianness in ('>', '<'):
  118. if size == 1 and endianness == '<':
  119. continue
  120. for rate in (8000, 32000):
  121. for channels in (1, 2, 5):
  122. dt = np.dtype('%s%s%s' % (endianness, dtypechar, size))
  123. _check_roundtrip(realfile, rate, dt, channels)