test_rbf.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. # Created by John Travers, Robert Hetland, 2007
  2. """ Test functions for rbf module """
  3. from __future__ import division, print_function, absolute_import
  4. import numpy as np
  5. from numpy.testing import (assert_, assert_array_almost_equal,
  6. assert_almost_equal)
  7. from numpy import linspace, sin, random, exp, allclose
  8. from scipy.interpolate.rbf import Rbf
  9. FUNCTIONS = ('multiquadric', 'inverse multiquadric', 'gaussian',
  10. 'cubic', 'quintic', 'thin-plate', 'linear')
  11. def check_rbf1d_interpolation(function):
  12. # Check that the Rbf function interpolates through the nodes (1D)
  13. x = linspace(0,10,9)
  14. y = sin(x)
  15. rbf = Rbf(x, y, function=function)
  16. yi = rbf(x)
  17. assert_array_almost_equal(y, yi)
  18. assert_almost_equal(rbf(float(x[0])), y[0])
  19. def check_rbf2d_interpolation(function):
  20. # Check that the Rbf function interpolates through the nodes (2D).
  21. x = random.rand(50,1)*4-2
  22. y = random.rand(50,1)*4-2
  23. z = x*exp(-x**2-1j*y**2)
  24. rbf = Rbf(x, y, z, epsilon=2, function=function)
  25. zi = rbf(x, y)
  26. zi.shape = x.shape
  27. assert_array_almost_equal(z, zi)
  28. def check_rbf3d_interpolation(function):
  29. # Check that the Rbf function interpolates through the nodes (3D).
  30. x = random.rand(50, 1)*4 - 2
  31. y = random.rand(50, 1)*4 - 2
  32. z = random.rand(50, 1)*4 - 2
  33. d = x*exp(-x**2 - y**2)
  34. rbf = Rbf(x, y, z, d, epsilon=2, function=function)
  35. di = rbf(x, y, z)
  36. di.shape = x.shape
  37. assert_array_almost_equal(di, d)
  38. def test_rbf_interpolation():
  39. for function in FUNCTIONS:
  40. check_rbf1d_interpolation(function)
  41. check_rbf2d_interpolation(function)
  42. check_rbf3d_interpolation(function)
  43. def check_rbf1d_regularity(function, atol):
  44. # Check that the Rbf function approximates a smooth function well away
  45. # from the nodes.
  46. x = linspace(0, 10, 9)
  47. y = sin(x)
  48. rbf = Rbf(x, y, function=function)
  49. xi = linspace(0, 10, 100)
  50. yi = rbf(xi)
  51. # import matplotlib.pyplot as plt
  52. # plt.figure()
  53. # plt.plot(x, y, 'o', xi, sin(xi), ':', xi, yi, '-')
  54. # plt.plot(x, y, 'o', xi, yi-sin(xi), ':')
  55. # plt.title(function)
  56. # plt.show()
  57. msg = "abs-diff: %f" % abs(yi - sin(xi)).max()
  58. assert_(allclose(yi, sin(xi), atol=atol), msg)
  59. def test_rbf_regularity():
  60. tolerances = {
  61. 'multiquadric': 0.1,
  62. 'inverse multiquadric': 0.15,
  63. 'gaussian': 0.15,
  64. 'cubic': 0.15,
  65. 'quintic': 0.1,
  66. 'thin-plate': 0.1,
  67. 'linear': 0.2
  68. }
  69. for function in FUNCTIONS:
  70. check_rbf1d_regularity(function, tolerances.get(function, 1e-2))
  71. def check_rbf1d_stability(function):
  72. # Check that the Rbf function with default epsilon is not subject
  73. # to overshoot. Regression for issue #4523.
  74. #
  75. # Generate some data (fixed random seed hence deterministic)
  76. np.random.seed(1234)
  77. x = np.linspace(0, 10, 50)
  78. z = x + 4.0 * np.random.randn(len(x))
  79. rbf = Rbf(x, z, function=function)
  80. xi = np.linspace(0, 10, 1000)
  81. yi = rbf(xi)
  82. # subtract the linear trend and make sure there no spikes
  83. assert_(np.abs(yi-xi).max() / np.abs(z-x).max() < 1.1)
  84. def test_rbf_stability():
  85. for function in FUNCTIONS:
  86. check_rbf1d_stability(function)
  87. def test_default_construction():
  88. # Check that the Rbf class can be constructed with the default
  89. # multiquadric basis function. Regression test for ticket #1228.
  90. x = linspace(0,10,9)
  91. y = sin(x)
  92. rbf = Rbf(x, y)
  93. yi = rbf(x)
  94. assert_array_almost_equal(y, yi)
  95. def test_function_is_callable():
  96. # Check that the Rbf class can be constructed with function=callable.
  97. x = linspace(0,10,9)
  98. y = sin(x)
  99. linfunc = lambda x:x
  100. rbf = Rbf(x, y, function=linfunc)
  101. yi = rbf(x)
  102. assert_array_almost_equal(y, yi)
  103. def test_two_arg_function_is_callable():
  104. # Check that the Rbf class can be constructed with a two argument
  105. # function=callable.
  106. def _func(self, r):
  107. return self.epsilon + r
  108. x = linspace(0,10,9)
  109. y = sin(x)
  110. rbf = Rbf(x, y, function=_func)
  111. yi = rbf(x)
  112. assert_array_almost_equal(y, yi)
  113. def test_rbf_epsilon_none():
  114. x = linspace(0, 10, 9)
  115. y = sin(x)
  116. rbf = Rbf(x, y, epsilon=None)
  117. def test_rbf_epsilon_none_collinear():
  118. # Check that collinear points in one dimension doesn't cause an error
  119. # due to epsilon = 0
  120. x = [1, 2, 3]
  121. y = [4, 4, 4]
  122. z = [5, 6, 7]
  123. rbf = Rbf(x, y, z, epsilon=None)
  124. assert_(rbf.epsilon > 0)