# test_common.py (2.0 KB)

  1. import numpy as np
  2. import pytest
  3. from pandas.core.dtypes import dtypes
  4. from pandas.core.dtypes.common import is_extension_array_dtype
  5. import pandas as pd
  6. from pandas.core.arrays import ExtensionArray
  7. import pandas.util.testing as tm
  8. class DummyDtype(dtypes.ExtensionDtype):
  9. pass
  10. class DummyArray(ExtensionArray):
  11. def __init__(self, data):
  12. self.data = data
  13. def __array__(self, dtype):
  14. return self.data
  15. @property
  16. def dtype(self):
  17. return DummyDtype()
  18. def astype(self, dtype, copy=True):
  19. # we don't support anything but a single dtype
  20. if isinstance(dtype, DummyDtype):
  21. if copy:
  22. return type(self)(self.data)
  23. return self
  24. return np.array(self, dtype=dtype, copy=copy)
  25. class TestExtensionArrayDtype(object):
  26. @pytest.mark.parametrize('values', [
  27. pd.Categorical([]),
  28. pd.Categorical([]).dtype,
  29. pd.Series(pd.Categorical([])),
  30. DummyDtype(),
  31. DummyArray(np.array([1, 2])),
  32. ])
  33. def test_is_extension_array_dtype(self, values):
  34. assert is_extension_array_dtype(values)
  35. @pytest.mark.parametrize('values', [
  36. np.array([]),
  37. pd.Series(np.array([])),
  38. ])
  39. def test_is_not_extension_array_dtype(self, values):
  40. assert not is_extension_array_dtype(values)
  41. def test_astype():
  42. arr = DummyArray(np.array([1, 2, 3]))
  43. expected = np.array([1, 2, 3], dtype=object)
  44. result = arr.astype(object)
  45. tm.assert_numpy_array_equal(result, expected)
  46. result = arr.astype('object')
  47. tm.assert_numpy_array_equal(result, expected)
  48. def test_astype_no_copy():
  49. arr = DummyArray(np.array([1, 2, 3], dtype=np.int64))
  50. result = arr.astype(arr.dtype, copy=False)
  51. assert arr is result
  52. result = arr.astype(arr.dtype)
  53. assert arr is not result
  54. @pytest.mark.parametrize('dtype', [
  55. dtypes.CategoricalDtype(),
  56. dtypes.IntervalDtype(),
  57. ])
  58. def test_is_extension_array_dtype(dtype):
  59. assert isinstance(dtype, dtypes.ExtensionDtype)
  60. assert is_extension_array_dtype(dtype)