bool.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. """Rudimentary Apache Arrow-backed ExtensionArray.
  2. At the moment, just a boolean array / type is implemented.
  3. Eventually, we'll want to parametrize the type and support
  4. multiple dtypes. Not all methods are implemented yet, and the
  5. current implementation is not efficient.
  6. """
  7. import copy
  8. import itertools
  9. import numpy as np
  10. import pyarrow as pa
  11. import pandas as pd
  12. from pandas.api.extensions import (
  13. ExtensionArray, ExtensionDtype, register_extension_dtype, take)
  14. @register_extension_dtype
  15. class ArrowBoolDtype(ExtensionDtype):
  16. type = np.bool_
  17. kind = 'b'
  18. name = 'arrow_bool'
  19. na_value = pa.NULL
  20. @classmethod
  21. def construct_from_string(cls, string):
  22. if string == cls.name:
  23. return cls()
  24. else:
  25. raise TypeError("Cannot construct a '{}' from "
  26. "'{}'".format(cls, string))
  27. @classmethod
  28. def construct_array_type(cls):
  29. return ArrowBoolArray
  30. def _is_boolean(self):
  31. return True
  32. class ArrowBoolArray(ExtensionArray):
  33. def __init__(self, values):
  34. if not isinstance(values, pa.ChunkedArray):
  35. raise ValueError
  36. assert values.type == pa.bool_()
  37. self._data = values
  38. self._dtype = ArrowBoolDtype()
  39. def __repr__(self):
  40. return "ArrowBoolArray({})".format(repr(self._data))
  41. @classmethod
  42. def from_scalars(cls, values):
  43. arr = pa.chunked_array([pa.array(np.asarray(values))])
  44. return cls(arr)
  45. @classmethod
  46. def from_array(cls, arr):
  47. assert isinstance(arr, pa.Array)
  48. return cls(pa.chunked_array([arr]))
  49. @classmethod
  50. def _from_sequence(cls, scalars, dtype=None, copy=False):
  51. return cls.from_scalars(scalars)
  52. def __getitem__(self, item):
  53. if pd.api.types.is_scalar(item):
  54. return self._data.to_pandas()[item]
  55. else:
  56. vals = self._data.to_pandas()[item]
  57. return type(self).from_scalars(vals)
  58. def __len__(self):
  59. return len(self._data)
  60. def astype(self, dtype, copy=True):
  61. # needed to fix this astype for the Series constructor.
  62. if isinstance(dtype, type(self.dtype)) and dtype == self.dtype:
  63. if copy:
  64. return self.copy()
  65. return self
  66. return super(ArrowBoolArray, self).astype(dtype, copy)
  67. @property
  68. def dtype(self):
  69. return self._dtype
  70. @property
  71. def nbytes(self):
  72. return sum(x.size for chunk in self._data.chunks
  73. for x in chunk.buffers()
  74. if x is not None)
  75. def isna(self):
  76. nas = pd.isna(self._data.to_pandas())
  77. return type(self).from_scalars(nas)
  78. def take(self, indices, allow_fill=False, fill_value=None):
  79. data = self._data.to_pandas()
  80. if allow_fill and fill_value is None:
  81. fill_value = self.dtype.na_value
  82. result = take(data, indices, fill_value=fill_value,
  83. allow_fill=allow_fill)
  84. return self._from_sequence(result, dtype=self.dtype)
  85. def copy(self, deep=False):
  86. if deep:
  87. return type(self)(copy.deepcopy(self._data))
  88. else:
  89. return type(self)(copy.copy(self._data))
  90. def _concat_same_type(cls, to_concat):
  91. chunks = list(itertools.chain.from_iterable(x._data.chunks
  92. for x in to_concat))
  93. arr = pa.chunked_array(chunks)
  94. return cls(arr)
  95. def __invert__(self):
  96. return type(self).from_scalars(
  97. ~self._data.to_pandas()
  98. )
  99. def _reduce(self, method, skipna=True, **kwargs):
  100. if skipna:
  101. arr = self[~self.isna()]
  102. else:
  103. arr = self
  104. try:
  105. op = getattr(arr, method)
  106. except AttributeError:
  107. raise TypeError
  108. return op(**kwargs)
  109. def any(self, axis=0, out=None):
  110. return self._data.to_pandas().any()
  111. def all(self, axis=0, out=None):
  112. return self._data.to_pandas().all()