array.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. import decimal
  2. import numbers
  3. import random
  4. import sys
  5. import numpy as np
  6. from pandas.core.dtypes.base import ExtensionDtype
  7. import pandas as pd
  8. from pandas.api.extensions import register_extension_dtype
  9. from pandas.core.arrays import ExtensionArray, ExtensionScalarOpsMixin
  10. @register_extension_dtype
  11. class DecimalDtype(ExtensionDtype):
  12. type = decimal.Decimal
  13. name = 'decimal'
  14. na_value = decimal.Decimal('NaN')
  15. _metadata = ('context',)
  16. def __init__(self, context=None):
  17. self.context = context or decimal.getcontext()
  18. def __repr__(self):
  19. return 'DecimalDtype(context={})'.format(self.context)
  20. @classmethod
  21. def construct_array_type(cls):
  22. """Return the array type associated with this dtype
  23. Returns
  24. -------
  25. type
  26. """
  27. return DecimalArray
  28. @classmethod
  29. def construct_from_string(cls, string):
  30. if string == cls.name:
  31. return cls()
  32. else:
  33. raise TypeError("Cannot construct a '{}' from "
  34. "'{}'".format(cls, string))
  35. @property
  36. def _is_numeric(self):
  37. return True
  38. class DecimalArray(ExtensionArray, ExtensionScalarOpsMixin):
  39. __array_priority__ = 1000
  40. def __init__(self, values, dtype=None, copy=False, context=None):
  41. for val in values:
  42. if not isinstance(val, decimal.Decimal):
  43. raise TypeError("All values must be of type " +
  44. str(decimal.Decimal))
  45. values = np.asarray(values, dtype=object)
  46. self._data = values
  47. # Some aliases for common attribute names to ensure pandas supports
  48. # these
  49. self._items = self.data = self._data
  50. # those aliases are currently not working due to assumptions
  51. # in internal code (GH-20735)
  52. # self._values = self.values = self.data
  53. self._dtype = DecimalDtype(context)
  54. @property
  55. def dtype(self):
  56. return self._dtype
  57. @classmethod
  58. def _from_sequence(cls, scalars, dtype=None, copy=False):
  59. return cls(scalars)
  60. @classmethod
  61. def _from_sequence_of_strings(cls, strings, dtype=None, copy=False):
  62. return cls._from_sequence([decimal.Decimal(x) for x in strings],
  63. dtype, copy)
  64. @classmethod
  65. def _from_factorized(cls, values, original):
  66. return cls(values)
  67. def __getitem__(self, item):
  68. if isinstance(item, numbers.Integral):
  69. return self._data[item]
  70. else:
  71. return type(self)(self._data[item])
  72. def take(self, indexer, allow_fill=False, fill_value=None):
  73. from pandas.api.extensions import take
  74. data = self._data
  75. if allow_fill and fill_value is None:
  76. fill_value = self.dtype.na_value
  77. result = take(data, indexer, fill_value=fill_value,
  78. allow_fill=allow_fill)
  79. return self._from_sequence(result)
  80. def copy(self, deep=False):
  81. if deep:
  82. return type(self)(self._data.copy())
  83. return type(self)(self)
  84. def astype(self, dtype, copy=True):
  85. if isinstance(dtype, type(self.dtype)):
  86. return type(self)(self._data, context=dtype.context)
  87. return np.asarray(self, dtype=dtype)
  88. def __setitem__(self, key, value):
  89. if pd.api.types.is_list_like(value):
  90. if pd.api.types.is_scalar(key):
  91. raise ValueError("setting an array element with a sequence.")
  92. value = [decimal.Decimal(v) for v in value]
  93. else:
  94. value = decimal.Decimal(value)
  95. self._data[key] = value
  96. def __len__(self):
  97. return len(self._data)
  98. @property
  99. def nbytes(self):
  100. n = len(self)
  101. if n:
  102. return n * sys.getsizeof(self[0])
  103. return 0
  104. def isna(self):
  105. return np.array([x.is_nan() for x in self._data], dtype=bool)
  106. @property
  107. def _na_value(self):
  108. return decimal.Decimal('NaN')
  109. @classmethod
  110. def _concat_same_type(cls, to_concat):
  111. return cls(np.concatenate([x._data for x in to_concat]))
  112. def _reduce(self, name, skipna=True, **kwargs):
  113. if skipna:
  114. raise NotImplementedError("decimal does not support skipna=True")
  115. try:
  116. op = getattr(self.data, name)
  117. except AttributeError:
  118. raise NotImplementedError("decimal does not support "
  119. "the {} operation".format(name))
  120. return op(axis=0)
  121. def to_decimal(values, context=None):
  122. return DecimalArray([decimal.Decimal(x) for x in values], context=context)
  123. def make_data():
  124. return [decimal.Decimal(random.random()) for _ in range(100)]
  125. DecimalArray._add_arithmetic_ops()
  126. DecimalArray._add_comparison_ops()