print_coercion_tables.py 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. #!/usr/bin/env python
  2. """Prints type-coercion tables for the built-in NumPy types
  3. """
  4. from __future__ import division, absolute_import, print_function
  5. import numpy as np
  6. # Generic object that can be added, but doesn't do anything else
  7. class GenericObject(object):
  8. def __init__(self, v):
  9. self.v = v
  10. def __add__(self, other):
  11. return self
  12. def __radd__(self, other):
  13. return self
  14. dtype = np.dtype('O')
  15. def print_cancast_table(ntypes):
  16. print('X', end=' ')
  17. for char in ntypes:
  18. print(char, end=' ')
  19. print()
  20. for row in ntypes:
  21. print(row, end=' ')
  22. for col in ntypes:
  23. print(int(np.can_cast(row, col)), end=' ')
  24. print()
  25. def print_coercion_table(ntypes, inputfirstvalue, inputsecondvalue, firstarray, use_promote_types=False):
  26. print('+', end=' ')
  27. for char in ntypes:
  28. print(char, end=' ')
  29. print()
  30. for row in ntypes:
  31. if row == 'O':
  32. rowtype = GenericObject
  33. else:
  34. rowtype = np.obj2sctype(row)
  35. print(row, end=' ')
  36. for col in ntypes:
  37. if col == 'O':
  38. coltype = GenericObject
  39. else:
  40. coltype = np.obj2sctype(col)
  41. try:
  42. if firstarray:
  43. rowvalue = np.array([rowtype(inputfirstvalue)], dtype=rowtype)
  44. else:
  45. rowvalue = rowtype(inputfirstvalue)
  46. colvalue = coltype(inputsecondvalue)
  47. if use_promote_types:
  48. char = np.promote_types(rowvalue.dtype, colvalue.dtype).char
  49. else:
  50. value = np.add(rowvalue, colvalue)
  51. if isinstance(value, np.ndarray):
  52. char = value.dtype.char
  53. else:
  54. char = np.dtype(type(value)).char
  55. except ValueError:
  56. char = '!'
  57. except OverflowError:
  58. char = '@'
  59. except TypeError:
  60. char = '#'
  61. print(char, end=' ')
  62. print()
  63. print("can cast")
  64. print_cancast_table(np.typecodes['All'])
  65. print()
  66. print("In these tables, ValueError is '!', OverflowError is '@', TypeError is '#'")
  67. print()
  68. print("scalar + scalar")
  69. print_coercion_table(np.typecodes['All'], 0, 0, False)
  70. print()
  71. print("scalar + neg scalar")
  72. print_coercion_table(np.typecodes['All'], 0, -1, False)
  73. print()
  74. print("array + scalar")
  75. print_coercion_table(np.typecodes['All'], 0, 0, True)
  76. print()
  77. print("array + neg scalar")
  78. print_coercion_table(np.typecodes['All'], 0, -1, True)
  79. print()
  80. print("promote_types")
  81. print_coercion_table(np.typecodes['All'], 0, 0, False, True)