test_splines.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. """Tests for spline filtering."""
  2. from __future__ import division, print_function, absolute_import
  3. import numpy as np
  4. import pytest
  5. from numpy.testing import assert_almost_equal
  6. from scipy import ndimage
  7. def get_spline_knot_values(order):
  8. """Knot values to the right of a B-spline's center."""
  9. knot_values = {0: [1],
  10. 1: [1],
  11. 2: [6, 1],
  12. 3: [4, 1],
  13. 4: [230, 76, 1],
  14. 5: [66, 26, 1]}
  15. return knot_values[order]
  16. def make_spline_knot_matrix(n, order, mode='mirror'):
  17. """Matrix to invert to find the spline coefficients."""
  18. knot_values = get_spline_knot_values(order)
  19. matrix = np.zeros((n, n))
  20. for diag, knot_value in enumerate(knot_values):
  21. indices = np.arange(diag, n)
  22. if diag == 0:
  23. matrix[indices, indices] = knot_value
  24. else:
  25. matrix[indices, indices - diag] = knot_value
  26. matrix[indices - diag, indices] = knot_value
  27. knot_values_sum = knot_values[0] + 2 * sum(knot_values[1:])
  28. if mode == 'mirror':
  29. start, step = 1, 1
  30. elif mode == 'reflect':
  31. start, step = 0, 1
  32. elif mode == 'wrap':
  33. start, step = -1, -1
  34. else:
  35. raise ValueError('unsupported mode {}'.format(mode))
  36. for row in range(len(knot_values) - 1):
  37. for idx, knot_value in enumerate(knot_values[row + 1:]):
  38. matrix[row, start + step*idx] += knot_value
  39. matrix[-row - 1, -start - 1 - step*idx] += knot_value
  40. return matrix / knot_values_sum
  41. @pytest.mark.parametrize('order', [0, 1, 2, 3, 4, 5])
  42. @pytest.mark.parametrize('mode', ['mirror', 'wrap', 'reflect'])
  43. def test_spline_filter_vs_matrix_solution(order, mode):
  44. n = 100
  45. eye = np.eye(n, dtype=float)
  46. spline_filter_axis_0 = ndimage.spline_filter1d(eye, axis=0, order=order,
  47. mode=mode)
  48. spline_filter_axis_1 = ndimage.spline_filter1d(eye, axis=1, order=order,
  49. mode=mode)
  50. matrix = make_spline_knot_matrix(n, order, mode=mode)
  51. assert_almost_equal(eye, np.dot(spline_filter_axis_0, matrix))
  52. assert_almost_equal(eye, np.dot(spline_filter_axis_1, matrix.T))