# test_c_api.py (3.7 KB)

  1. from __future__ import division, print_function, absolute_import
  2. import numpy as np
  3. from numpy.testing import assert_allclose
  4. from scipy import ndimage
  5. from scipy.ndimage import _ctest
  6. from scipy.ndimage import _ctest_oldapi
  7. from scipy.ndimage import _cytest
  8. from scipy._lib._ccallback import LowLevelCallable
  9. FILTER1D_FUNCTIONS = [
  10. lambda filter_size: _ctest.filter1d(filter_size),
  11. lambda filter_size: _ctest_oldapi.filter1d(filter_size),
  12. lambda filter_size: _cytest.filter1d(filter_size, with_signature=False),
  13. lambda filter_size: LowLevelCallable(_cytest.filter1d(filter_size, with_signature=True)),
  14. lambda filter_size: LowLevelCallable.from_cython(_cytest, "_filter1d",
  15. _cytest.filter1d_capsule(filter_size)),
  16. ]
  17. FILTER2D_FUNCTIONS = [
  18. lambda weights: _ctest.filter2d(weights),
  19. lambda weights: _ctest_oldapi.filter2d(weights),
  20. lambda weights: _cytest.filter2d(weights, with_signature=False),
  21. lambda weights: LowLevelCallable(_cytest.filter2d(weights, with_signature=True)),
  22. lambda weights: LowLevelCallable.from_cython(_cytest, "_filter2d", _cytest.filter2d_capsule(weights)),
  23. ]
  24. TRANSFORM_FUNCTIONS = [
  25. lambda shift: _ctest.transform(shift),
  26. lambda shift: _ctest_oldapi.transform(shift),
  27. lambda shift: _cytest.transform(shift, with_signature=False),
  28. lambda shift: LowLevelCallable(_cytest.transform(shift, with_signature=True)),
  29. lambda shift: LowLevelCallable.from_cython(_cytest, "_transform", _cytest.transform_capsule(shift)),
  30. ]
  31. def test_generic_filter():
  32. def filter2d(footprint_elements, weights):
  33. return (weights*footprint_elements).sum()
  34. def check(j):
  35. func = FILTER2D_FUNCTIONS[j]
  36. im = np.ones((20, 20))
  37. im[:10,:10] = 0
  38. footprint = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]])
  39. footprint_size = np.count_nonzero(footprint)
  40. weights = np.ones(footprint_size)/footprint_size
  41. res = ndimage.generic_filter(im, func(weights),
  42. footprint=footprint)
  43. std = ndimage.generic_filter(im, filter2d, footprint=footprint,
  44. extra_arguments=(weights,))
  45. assert_allclose(res, std, err_msg="#{} failed".format(j))
  46. for j, func in enumerate(FILTER2D_FUNCTIONS):
  47. check(j)
  48. def test_generic_filter1d():
  49. def filter1d(input_line, output_line, filter_size):
  50. for i in range(output_line.size):
  51. output_line[i] = 0
  52. for j in range(filter_size):
  53. output_line[i] += input_line[i+j]
  54. output_line /= filter_size
  55. def check(j):
  56. func = FILTER1D_FUNCTIONS[j]
  57. im = np.tile(np.hstack((np.zeros(10), np.ones(10))), (10, 1))
  58. filter_size = 3
  59. res = ndimage.generic_filter1d(im, func(filter_size),
  60. filter_size)
  61. std = ndimage.generic_filter1d(im, filter1d, filter_size,
  62. extra_arguments=(filter_size,))
  63. assert_allclose(res, std, err_msg="#{} failed".format(j))
  64. for j, func in enumerate(FILTER1D_FUNCTIONS):
  65. check(j)
  66. def test_geometric_transform():
  67. def transform(output_coordinates, shift):
  68. return output_coordinates[0] - shift, output_coordinates[1] - shift
  69. def check(j):
  70. func = TRANSFORM_FUNCTIONS[j]
  71. im = np.arange(12).reshape(4, 3).astype(np.float64)
  72. shift = 0.5
  73. res = ndimage.geometric_transform(im, func(shift))
  74. std = ndimage.geometric_transform(im, transform, extra_arguments=(shift,))
  75. assert_allclose(res, std, err_msg="#{} failed".format(j))
  76. for j, func in enumerate(TRANSFORM_FUNCTIONS):
  77. check(j)