# test_pilutil.py (10 KB)
  1. from __future__ import division, print_function, absolute_import
  2. import os.path
  3. import tempfile
  4. import shutil
  5. import numpy as np
  6. import glob
  7. import pytest
  8. from pytest import raises as assert_raises
  9. from numpy.testing import (assert_equal, assert_allclose,
  10. assert_array_equal, assert_)
  11. from scipy._lib._numpy_compat import suppress_warnings
  12. from scipy import misc
  13. from numpy.ma.testutils import assert_mask_equal
  14. try:
  15. import PIL.Image
  16. except ImportError:
  17. _have_PIL = False
  18. else:
  19. _have_PIL = True
  20. # Function / method decorator for skipping PIL tests on import failure
  21. _pilskip = pytest.mark.skipif(not _have_PIL, reason='Need to import PIL for this test')
  22. datapath = os.path.dirname(__file__)
  23. @_pilskip
  24. class TestPILUtil(object):
  25. def test_imresize(self):
  26. im = np.random.random((10, 20))
  27. for T in np.sctypes['float'] + [float]:
  28. # 1.1 rounds to below 1.1 for float16, 1.101 works
  29. with suppress_warnings() as sup:
  30. sup.filter(DeprecationWarning)
  31. im1 = misc.imresize(im, T(1.101))
  32. assert_equal(im1.shape, (11, 22))
  33. def test_imresize2(self):
  34. im = np.random.random((20, 30))
  35. with suppress_warnings() as sup:
  36. sup.filter(DeprecationWarning)
  37. im2 = misc.imresize(im, (30, 40), interp='bicubic')
  38. assert_equal(im2.shape, (30, 40))
  39. def test_imresize3(self):
  40. im = np.random.random((15, 30))
  41. with suppress_warnings() as sup:
  42. sup.filter(DeprecationWarning)
  43. im2 = misc.imresize(im, (30, 60), interp='nearest')
  44. assert_equal(im2.shape, (30, 60))
  45. def test_imresize4(self):
  46. im = np.array([[1, 2],
  47. [3, 4]])
  48. # Check that resizing by target size, float and int are the same
  49. with suppress_warnings() as sup:
  50. sup.filter(DeprecationWarning)
  51. im2 = misc.imresize(im, (4, 4), mode='F') # output size
  52. im3 = misc.imresize(im, 2., mode='F') # fraction
  53. im4 = misc.imresize(im, 200, mode='F') # percentage
  54. assert_equal(im2, im3)
  55. assert_equal(im2, im4)
  56. def test_imresize5(self):
  57. im = np.random.random((25, 15))
  58. with suppress_warnings() as sup:
  59. sup.filter(DeprecationWarning)
  60. im2 = misc.imresize(im, (30, 60), interp='lanczos')
  61. assert_equal(im2.shape, (30, 60))
  62. def test_bytescale(self):
  63. x = np.array([0, 1, 2], np.uint8)
  64. y = np.array([0, 1, 2])
  65. with suppress_warnings() as sup:
  66. sup.filter(DeprecationWarning)
  67. assert_equal(misc.bytescale(x), x)
  68. assert_equal(misc.bytescale(y), [0, 128, 255])
  69. def test_bytescale_keywords(self):
  70. x = np.array([40, 60, 120, 200, 300, 500])
  71. with suppress_warnings() as sup:
  72. sup.filter(DeprecationWarning)
  73. res_lowhigh = misc.bytescale(x, low=10, high=143)
  74. assert_equal(res_lowhigh, [10, 16, 33, 56, 85, 143])
  75. res_cmincmax = misc.bytescale(x, cmin=60, cmax=300)
  76. assert_equal(res_cmincmax, [0, 0, 64, 149, 255, 255])
  77. assert_equal(misc.bytescale(np.array([3, 3, 3]), low=4), [4, 4, 4])
  78. def test_bytescale_cscale_lowhigh(self):
  79. a = np.arange(10)
  80. with suppress_warnings() as sup:
  81. sup.filter(DeprecationWarning)
  82. actual = misc.bytescale(a, cmin=3, cmax=6, low=100, high=200)
  83. expected = [100, 100, 100, 100, 133, 167, 200, 200, 200, 200]
  84. assert_equal(actual, expected)
  85. def test_bytescale_mask(self):
  86. a = np.ma.MaskedArray(data=[1, 2, 3], mask=[False, False, True])
  87. with suppress_warnings() as sup:
  88. sup.filter(DeprecationWarning)
  89. actual = misc.bytescale(a)
  90. expected = [0, 255, 3]
  91. assert_equal(expected, actual)
  92. assert_mask_equal(a.mask, actual.mask)
  93. assert_(isinstance(actual, np.ma.MaskedArray))
  94. def test_bytescale_rounding(self):
  95. a = np.array([-0.5, 0.5, 1.5, 2.5, 3.5])
  96. with suppress_warnings() as sup:
  97. sup.filter(DeprecationWarning)
  98. actual = misc.bytescale(a, cmin=0, cmax=10, low=0, high=10)
  99. expected = [0, 1, 2, 3, 4]
  100. assert_equal(actual, expected)
  101. def test_bytescale_low_greaterthan_high(self):
  102. with assert_raises(ValueError):
  103. with suppress_warnings() as sup:
  104. sup.filter(DeprecationWarning)
  105. misc.bytescale(np.arange(3), low=10, high=5)
  106. def test_bytescale_low_lessthan_0(self):
  107. with assert_raises(ValueError):
  108. with suppress_warnings() as sup:
  109. sup.filter(DeprecationWarning)
  110. misc.bytescale(np.arange(3), low=-1)
  111. def test_bytescale_high_greaterthan_255(self):
  112. with assert_raises(ValueError):
  113. with suppress_warnings() as sup:
  114. sup.filter(DeprecationWarning)
  115. misc.bytescale(np.arange(3), high=256)
  116. def test_bytescale_low_equals_high(self):
  117. a = np.arange(3)
  118. with suppress_warnings() as sup:
  119. sup.filter(DeprecationWarning)
  120. actual = misc.bytescale(a, low=10, high=10)
  121. expected = [10, 10, 10]
  122. assert_equal(actual, expected)
  123. def test_imsave(self):
  124. picdir = os.path.join(datapath, "data")
  125. for png in glob.iglob(picdir + "/*.png"):
  126. with suppress_warnings() as sup:
  127. # PIL causes a Py3k ResourceWarning
  128. sup.filter(message="unclosed file")
  129. sup.filter(DeprecationWarning)
  130. img = misc.imread(png)
  131. tmpdir = tempfile.mkdtemp()
  132. try:
  133. fn1 = os.path.join(tmpdir, 'test.png')
  134. fn2 = os.path.join(tmpdir, 'testimg')
  135. with suppress_warnings() as sup:
  136. # PIL causes a Py3k ResourceWarning
  137. sup.filter(message="unclosed file")
  138. sup.filter(DeprecationWarning)
  139. misc.imsave(fn1, img)
  140. misc.imsave(fn2, img, 'PNG')
  141. with suppress_warnings() as sup:
  142. # PIL causes a Py3k ResourceWarning
  143. sup.filter(message="unclosed file")
  144. sup.filter(DeprecationWarning)
  145. data1 = misc.imread(fn1)
  146. data2 = misc.imread(fn2)
  147. assert_allclose(data1, img)
  148. assert_allclose(data2, img)
  149. assert_equal(data1.shape, img.shape)
  150. assert_equal(data2.shape, img.shape)
  151. finally:
  152. shutil.rmtree(tmpdir)
  153. def check_fromimage(filename, irange, shape):
  154. fp = open(filename, "rb")
  155. with suppress_warnings() as sup:
  156. sup.filter(DeprecationWarning)
  157. img = misc.fromimage(PIL.Image.open(fp))
  158. fp.close()
  159. imin, imax = irange
  160. assert_equal(img.min(), imin)
  161. assert_equal(img.max(), imax)
  162. assert_equal(img.shape, shape)
  163. @_pilskip
  164. def test_fromimage():
  165. # Test generator for parametric tests
  166. # Tuples in the list are (filename, (datamin, datamax), shape).
  167. files = [('icon.png', (0, 255), (48, 48, 4)),
  168. ('icon_mono.png', (0, 255), (48, 48, 4)),
  169. ('icon_mono_flat.png', (0, 255), (48, 48, 3))]
  170. for fn, irange, shape in files:
  171. with suppress_warnings() as sup:
  172. sup.filter(DeprecationWarning)
  173. check_fromimage(os.path.join(datapath, 'data', fn), irange, shape)
  174. @_pilskip
  175. def test_imread_indexed_png():
  176. # The file `foo3x5x4indexed.png` was created with this array
  177. # (3x5 is (height)x(width)):
  178. data = np.array([[[127, 0, 255, 255],
  179. [127, 0, 255, 255],
  180. [127, 0, 255, 255],
  181. [127, 0, 255, 255],
  182. [127, 0, 255, 255]],
  183. [[192, 192, 255, 0],
  184. [192, 192, 255, 0],
  185. [0, 0, 255, 0],
  186. [0, 0, 255, 0],
  187. [0, 0, 255, 0]],
  188. [[0, 31, 255, 255],
  189. [0, 31, 255, 255],
  190. [0, 31, 255, 255],
  191. [0, 31, 255, 255],
  192. [0, 31, 255, 255]]], dtype=np.uint8)
  193. filename = os.path.join(datapath, 'data', 'foo3x5x4indexed.png')
  194. with open(filename, 'rb') as f:
  195. with suppress_warnings() as sup:
  196. sup.filter(DeprecationWarning)
  197. im = misc.imread(f)
  198. assert_array_equal(im, data)
  199. @_pilskip
  200. def test_imread_1bit():
  201. # box1.png is a 48x48 grayscale image with bit depth 1.
  202. # The border pixels are 1 and the rest are 0.
  203. filename = os.path.join(datapath, 'data', 'box1.png')
  204. with open(filename, 'rb') as f:
  205. with suppress_warnings() as sup:
  206. sup.filter(DeprecationWarning)
  207. im = misc.imread(f)
  208. assert_equal(im.dtype, np.uint8)
  209. expected = np.zeros((48, 48), dtype=np.uint8)
  210. # When scaled up from 1 bit to 8 bits, 1 becomes 255.
  211. expected[:, 0] = 255
  212. expected[:, -1] = 255
  213. expected[0, :] = 255
  214. expected[-1, :] = 255
  215. assert_equal(im, expected)
  216. @_pilskip
  217. def test_imread_2bit():
  218. # blocks2bit.png is a 12x12 grayscale image with bit depth 2.
  219. # The pattern is 4 square subblocks of size 6x6. Upper left
  220. # is all 0, upper right is all 1, lower left is all 2, lower
  221. # right is all 3.
  222. # When scaled up to 8 bits, the values become [0, 85, 170, 255].
  223. filename = os.path.join(datapath, 'data', 'blocks2bit.png')
  224. with open(filename, 'rb') as f:
  225. with suppress_warnings() as sup:
  226. sup.filter(DeprecationWarning)
  227. im = misc.imread(f)
  228. assert_equal(im.dtype, np.uint8)
  229. expected = np.zeros((12, 12), dtype=np.uint8)
  230. expected[:6, 6:] = 85
  231. expected[6:, :6] = 170
  232. expected[6:, 6:] = 255
  233. assert_equal(im, expected)
  234. @_pilskip
  235. def test_imread_4bit():
  236. # pattern4bit.png is a 12(h) x 31(w) grayscale image with bit depth 4.
  237. # The value in row j and column i is maximum(j, i) % 16.
  238. # When scaled up to 8 bits, the values become [0, 17, 34, ..., 255].
  239. filename = os.path.join(datapath, 'data', 'pattern4bit.png')
  240. with open(filename, 'rb') as f:
  241. with suppress_warnings() as sup:
  242. sup.filter(DeprecationWarning)
  243. im = misc.imread(f)
  244. assert_equal(im.dtype, np.uint8)
  245. j, i = np.meshgrid(np.arange(12), np.arange(31), indexing='ij')
  246. expected = 17*(np.maximum(j, i) % 16).astype(np.uint8)
  247. assert_equal(im, expected)