_ni_support.py 3.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. # Copyright (C) 2003-2005 Peter J. Verveer
  2. #
  3. # Redistribution and use in source and binary forms, with or without
  4. # modification, are permitted provided that the following conditions
  5. # are met:
  6. #
  7. # 1. Redistributions of source code must retain the above copyright
  8. # notice, this list of conditions and the following disclaimer.
  9. #
  10. # 2. Redistributions in binary form must reproduce the above
  11. # copyright notice, this list of conditions and the following
  12. # disclaimer in the documentation and/or other materials provided
  13. # with the distribution.
  14. #
  15. # 3. The name of the author may not be used to endorse or promote
  16. # products derived from this software without specific prior
  17. # written permission.
  18. #
  19. # THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS
  20. # OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
  21. # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
  22. # ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
  23. # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
  24. # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
  25. # GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
  26. # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
  27. # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
  28. # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
  29. # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  30. from __future__ import division, print_function, absolute_import
  31. import numpy
  32. from scipy._lib.six import string_types
  33. def _extend_mode_to_code(mode):
  34. """Convert an extension mode to the corresponding integer code.
  35. """
  36. if mode == 'nearest':
  37. return 0
  38. elif mode == 'wrap':
  39. return 1
  40. elif mode == 'reflect':
  41. return 2
  42. elif mode == 'mirror':
  43. return 3
  44. elif mode == 'constant':
  45. return 4
  46. else:
  47. raise RuntimeError('boundary mode not supported')
  48. def _normalize_sequence(input, rank):
  49. """If input is a scalar, create a sequence of length equal to the
  50. rank by duplicating the input. If input is a sequence,
  51. check if its length is equal to the length of array.
  52. """
  53. is_str = isinstance(input, string_types)
  54. if hasattr(input, '__iter__') and not is_str:
  55. normalized = list(input)
  56. if len(normalized) != rank:
  57. err = "sequence argument must have length equal to input rank"
  58. raise RuntimeError(err)
  59. else:
  60. normalized = [input] * rank
  61. return normalized
  62. def _get_output(output, input, shape=None):
  63. if shape is None:
  64. shape = input.shape
  65. if output is None:
  66. output = numpy.zeros(shape, dtype=input.dtype.name)
  67. elif type(output) in [type(type), type(numpy.zeros((4,)).dtype)]:
  68. output = numpy.zeros(shape, dtype=output)
  69. elif isinstance(output, string_types):
  70. output = numpy.typeDict[output]
  71. output = numpy.zeros(shape, dtype=output)
  72. elif output.shape != shape:
  73. raise RuntimeError("output shape not correct")
  74. return output
  75. def _check_axis(axis, rank):
  76. if axis < 0:
  77. axis += rank
  78. if axis < 0 or axis >= rank:
  79. raise ValueError('invalid axis')
  80. return axis