ndgriddata.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
"""
Convenience interface to N-D interpolation.

This module provides `griddata`, a single entry point that dispatches to
`NearestNDInterpolator` (defined here), to `LinearNDInterpolator` or
`CloughTocher2DInterpolator` (re-exported from ``interpnd``), or, for 1-D
data, to `interp1d`.

.. versionadded:: 0.9
"""
  5. from __future__ import division, print_function, absolute_import
  6. import numpy as np
  7. from .interpnd import LinearNDInterpolator, NDInterpolatorBase, \
  8. CloughTocher2DInterpolator, _ndim_coords_from_arrays
  9. from scipy.spatial import cKDTree
  10. __all__ = ['griddata', 'NearestNDInterpolator', 'LinearNDInterpolator',
  11. 'CloughTocher2DInterpolator']
  12. #------------------------------------------------------------------------------
  13. # Nearest-neighbour interpolation
  14. #------------------------------------------------------------------------------
  15. class NearestNDInterpolator(NDInterpolatorBase):
  16. """
  17. NearestNDInterpolator(x, y)
  18. Nearest-neighbour interpolation in N dimensions.
  19. .. versionadded:: 0.9
  20. Methods
  21. -------
  22. __call__
  23. Parameters
  24. ----------
  25. x : (Npoints, Ndims) ndarray of floats
  26. Data point coordinates.
  27. y : (Npoints,) ndarray of float or complex
  28. Data values.
  29. rescale : boolean, optional
  30. Rescale points to unit cube before performing interpolation.
  31. This is useful if some of the input dimensions have
  32. incommensurable units and differ by many orders of magnitude.
  33. .. versionadded:: 0.14.0
  34. tree_options : dict, optional
  35. Options passed to the underlying ``cKDTree``.
  36. .. versionadded:: 0.17.0
  37. Notes
  38. -----
  39. Uses ``scipy.spatial.cKDTree``
  40. """
  41. def __init__(self, x, y, rescale=False, tree_options=None):
  42. NDInterpolatorBase.__init__(self, x, y, rescale=rescale,
  43. need_contiguous=False,
  44. need_values=False)
  45. if tree_options is None:
  46. tree_options = dict()
  47. self.tree = cKDTree(self.points, **tree_options)
  48. self.values = y
  49. def __call__(self, *args):
  50. """
  51. Evaluate interpolator at given points.
  52. Parameters
  53. ----------
  54. xi : ndarray of float, shape (..., ndim)
  55. Points where to interpolate data at.
  56. """
  57. xi = _ndim_coords_from_arrays(args, ndim=self.points.shape[1])
  58. xi = self._check_call_shape(xi)
  59. xi = self._scale_x(xi)
  60. dist, i = self.tree.query(xi)
  61. return self.values[i]
  62. #------------------------------------------------------------------------------
  63. # Convenience interface function
  64. #------------------------------------------------------------------------------
  65. def griddata(points, values, xi, method='linear', fill_value=np.nan,
  66. rescale=False):
  67. """
  68. Interpolate unstructured D-dimensional data.
  69. Parameters
  70. ----------
  71. points : ndarray of floats, shape (n, D)
  72. Data point coordinates. Can either be an array of
  73. shape (n, D), or a tuple of `ndim` arrays.
  74. values : ndarray of float or complex, shape (n,)
  75. Data values.
  76. xi : 2-D ndarray of float or tuple of 1-D array, shape (M, D)
  77. Points at which to interpolate data.
  78. method : {'linear', 'nearest', 'cubic'}, optional
  79. Method of interpolation. One of
  80. ``nearest``
  81. return the value at the data point closest to
  82. the point of interpolation. See `NearestNDInterpolator` for
  83. more details.
  84. ``linear``
  85. tessellate the input point set to n-dimensional
  86. simplices, and interpolate linearly on each simplex. See
  87. `LinearNDInterpolator` for more details.
  88. ``cubic`` (1-D)
  89. return the value determined from a cubic
  90. spline.
  91. ``cubic`` (2-D)
  92. return the value determined from a
  93. piecewise cubic, continuously differentiable (C1), and
  94. approximately curvature-minimizing polynomial surface. See
  95. `CloughTocher2DInterpolator` for more details.
  96. fill_value : float, optional
  97. Value used to fill in for requested points outside of the
  98. convex hull of the input points. If not provided, then the
  99. default is ``nan``. This option has no effect for the
  100. 'nearest' method.
  101. rescale : bool, optional
  102. Rescale points to unit cube before performing interpolation.
  103. This is useful if some of the input dimensions have
  104. incommensurable units and differ by many orders of magnitude.
  105. .. versionadded:: 0.14.0
  106. Returns
  107. -------
  108. ndarray
  109. Array of interpolated values.
  110. Notes
  111. -----
  112. .. versionadded:: 0.9
  113. Examples
  114. --------
  115. Suppose we want to interpolate the 2-D function
  116. >>> def func(x, y):
  117. ... return x*(1-x)*np.cos(4*np.pi*x) * np.sin(4*np.pi*y**2)**2
  118. on a grid in [0, 1]x[0, 1]
  119. >>> grid_x, grid_y = np.mgrid[0:1:100j, 0:1:200j]
  120. but we only know its values at 1000 data points:
  121. >>> points = np.random.rand(1000, 2)
  122. >>> values = func(points[:,0], points[:,1])
  123. This can be done with `griddata` -- below we try out all of the
  124. interpolation methods:
  125. >>> from scipy.interpolate import griddata
  126. >>> grid_z0 = griddata(points, values, (grid_x, grid_y), method='nearest')
  127. >>> grid_z1 = griddata(points, values, (grid_x, grid_y), method='linear')
  128. >>> grid_z2 = griddata(points, values, (grid_x, grid_y), method='cubic')
  129. One can see that the exact result is reproduced by all of the
  130. methods to some degree, but for this smooth function the piecewise
  131. cubic interpolant gives the best results:
  132. >>> import matplotlib.pyplot as plt
  133. >>> plt.subplot(221)
  134. >>> plt.imshow(func(grid_x, grid_y).T, extent=(0,1,0,1), origin='lower')
  135. >>> plt.plot(points[:,0], points[:,1], 'k.', ms=1)
  136. >>> plt.title('Original')
  137. >>> plt.subplot(222)
  138. >>> plt.imshow(grid_z0.T, extent=(0,1,0,1), origin='lower')
  139. >>> plt.title('Nearest')
  140. >>> plt.subplot(223)
  141. >>> plt.imshow(grid_z1.T, extent=(0,1,0,1), origin='lower')
  142. >>> plt.title('Linear')
  143. >>> plt.subplot(224)
  144. >>> plt.imshow(grid_z2.T, extent=(0,1,0,1), origin='lower')
  145. >>> plt.title('Cubic')
  146. >>> plt.gcf().set_size_inches(6, 6)
  147. >>> plt.show()
  148. """
  149. points = _ndim_coords_from_arrays(points)
  150. if points.ndim < 2:
  151. ndim = points.ndim
  152. else:
  153. ndim = points.shape[-1]
  154. if ndim == 1 and method in ('nearest', 'linear', 'cubic'):
  155. from .interpolate import interp1d
  156. points = points.ravel()
  157. if isinstance(xi, tuple):
  158. if len(xi) != 1:
  159. raise ValueError("invalid number of dimensions in xi")
  160. xi, = xi
  161. # Sort points/values together, necessary as input for interp1d
  162. idx = np.argsort(points)
  163. points = points[idx]
  164. values = values[idx]
  165. if method == 'nearest':
  166. fill_value = 'extrapolate'
  167. ip = interp1d(points, values, kind=method, axis=0, bounds_error=False,
  168. fill_value=fill_value)
  169. return ip(xi)
  170. elif method == 'nearest':
  171. ip = NearestNDInterpolator(points, values, rescale=rescale)
  172. return ip(xi)
  173. elif method == 'linear':
  174. ip = LinearNDInterpolator(points, values, fill_value=fill_value,
  175. rescale=rescale)
  176. return ip(xi)
  177. elif method == 'cubic' and ndim == 2:
  178. ip = CloughTocher2DInterpolator(points, values, fill_value=fill_value,
  179. rescale=rescale)
  180. return ip(xi)
  181. else:
  182. raise ValueError("Unknown interpolation method %r for "
  183. "%d dimensional data" % (method, ndim))