test_categorical.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
"""
This file contains a minimal set of tests for compliance with the extension
array interface test suite, and should contain no other tests.
The test suite for the full functionality of the array is located in
`pandas/tests/arrays/`.
The tests in this file are inherited from the BaseExtensionTests, and only
minimal tweaks should be applied to get the tests passing (by overwriting a
parent method).
Additional tests should either be added to one of the BaseExtensionTests
classes (if they are relevant for the extension interface for all dtypes), or
be added to the array-specific tests in `pandas/tests/arrays/`.
"""
  13. import string
  14. import numpy as np
  15. import pytest
  16. import pandas as pd
  17. from pandas import Categorical
  18. from pandas.api.types import CategoricalDtype
  19. from pandas.tests.extension import base
  20. def make_data():
  21. while True:
  22. values = np.random.choice(list(string.ascii_letters), size=100)
  23. # ensure we meet the requirements
  24. # 1. first two not null
  25. # 2. first and second are different
  26. if values[0] != values[1]:
  27. break
  28. return values
  29. @pytest.fixture
  30. def dtype():
  31. return CategoricalDtype()
  32. @pytest.fixture
  33. def data():
  34. """Length-100 array for this type.
  35. * data[0] and data[1] should both be non missing
  36. * data[0] and data[1] should not gbe equal
  37. """
  38. return Categorical(make_data())
  39. @pytest.fixture
  40. def data_missing():
  41. """Length 2 array with [NA, Valid]"""
  42. return Categorical([np.nan, 'A'])
  43. @pytest.fixture
  44. def data_for_sorting():
  45. return Categorical(['A', 'B', 'C'], categories=['C', 'A', 'B'],
  46. ordered=True)
  47. @pytest.fixture
  48. def data_missing_for_sorting():
  49. return Categorical(['A', None, 'B'], categories=['B', 'A'],
  50. ordered=True)
  51. @pytest.fixture
  52. def na_value():
  53. return np.nan
  54. @pytest.fixture
  55. def data_for_grouping():
  56. return Categorical(['a', 'a', None, None, 'b', 'b', 'a', 'c'])
  57. class TestDtype(base.BaseDtypeTests):
  58. pass
  59. class TestInterface(base.BaseInterfaceTests):
  60. @pytest.mark.skip(reason="Memory usage doesn't match")
  61. def test_memory_usage(self, data):
  62. # Is this deliberate?
  63. super(TestInterface, self).test_memory_usage(data)
  64. class TestConstructors(base.BaseConstructorsTests):
  65. pass
  66. class TestReshaping(base.BaseReshapingTests):
  67. pass
  68. class TestGetitem(base.BaseGetitemTests):
  69. skip_take = pytest.mark.skip(reason="GH-20664.")
  70. @pytest.mark.skip(reason="Backwards compatibility")
  71. def test_getitem_scalar(self, data):
  72. # CategoricalDtype.type isn't "correct" since it should
  73. # be a parent of the elements (object). But don't want
  74. # to break things by changing.
  75. super(TestGetitem, self).test_getitem_scalar(data)
  76. @skip_take
  77. def test_take(self, data, na_value, na_cmp):
  78. # TODO remove this once Categorical.take is fixed
  79. super(TestGetitem, self).test_take(data, na_value, na_cmp)
  80. @skip_take
  81. def test_take_negative(self, data):
  82. super().test_take_negative(data)
  83. @skip_take
  84. def test_take_pandas_style_negative_raises(self, data, na_value):
  85. super().test_take_pandas_style_negative_raises(data, na_value)
  86. @skip_take
  87. def test_take_non_na_fill_value(self, data_missing):
  88. super().test_take_non_na_fill_value(data_missing)
  89. @skip_take
  90. def test_take_out_of_bounds_raises(self, data, allow_fill):
  91. return super().test_take_out_of_bounds_raises(data, allow_fill)
  92. @pytest.mark.skip(reason="GH-20747. Unobserved categories.")
  93. def test_take_series(self, data):
  94. super().test_take_series(data)
  95. @skip_take
  96. def test_reindex_non_na_fill_value(self, data_missing):
  97. super().test_reindex_non_na_fill_value(data_missing)
  98. @pytest.mark.skip(reason="Categorical.take buggy")
  99. def test_take_empty(self, data, na_value, na_cmp):
  100. super().test_take_empty(data, na_value, na_cmp)
  101. @pytest.mark.skip(reason="test not written correctly for categorical")
  102. def test_reindex(self, data, na_value):
  103. super().test_reindex(data, na_value)
  104. class TestSetitem(base.BaseSetitemTests):
  105. pass
  106. class TestMissing(base.BaseMissingTests):
  107. @pytest.mark.skip(reason="Not implemented")
  108. def test_fillna_limit_pad(self, data_missing):
  109. super().test_fillna_limit_pad(data_missing)
  110. @pytest.mark.skip(reason="Not implemented")
  111. def test_fillna_limit_backfill(self, data_missing):
  112. super().test_fillna_limit_backfill(data_missing)
  113. class TestReduce(base.BaseNoReduceTests):
  114. pass
  115. class TestMethods(base.BaseMethodsTests):
  116. @pytest.mark.skip(reason="Unobserved categories included")
  117. def test_value_counts(self, all_data, dropna):
  118. return super().test_value_counts(all_data, dropna)
  119. def test_combine_add(self, data_repeated):
  120. # GH 20825
  121. # When adding categoricals in combine, result is a string
  122. orig_data1, orig_data2 = data_repeated(2)
  123. s1 = pd.Series(orig_data1)
  124. s2 = pd.Series(orig_data2)
  125. result = s1.combine(s2, lambda x1, x2: x1 + x2)
  126. expected = pd.Series(([a + b for (a, b) in
  127. zip(list(orig_data1), list(orig_data2))]))
  128. self.assert_series_equal(result, expected)
  129. val = s1.iloc[0]
  130. result = s1.combine(val, lambda x1, x2: x1 + x2)
  131. expected = pd.Series([a + val for a in list(orig_data1)])
  132. self.assert_series_equal(result, expected)
  133. @pytest.mark.skip(reason="Not Applicable")
  134. def test_fillna_length_mismatch(self, data_missing):
  135. super().test_fillna_length_mismatch(data_missing)
  136. def test_searchsorted(self, data_for_sorting):
  137. if not data_for_sorting.ordered:
  138. raise pytest.skip(reason="searchsorted requires ordered data.")
  139. class TestCasting(base.BaseCastingTests):
  140. pass
  141. class TestArithmeticOps(base.BaseArithmeticOpsTests):
  142. def test_arith_series_with_scalar(self, data, all_arithmetic_operators):
  143. op_name = all_arithmetic_operators
  144. if op_name != '__rmod__':
  145. super(TestArithmeticOps, self).test_arith_series_with_scalar(
  146. data, op_name)
  147. else:
  148. pytest.skip('rmod never called when string is first argument')
  149. def test_add_series_with_extension_array(self, data):
  150. ser = pd.Series(data)
  151. with pytest.raises(TypeError, match="cannot perform"):
  152. ser + data
  153. def _check_divmod_op(self, s, op, other, exc=NotImplementedError):
  154. return super(TestArithmeticOps, self)._check_divmod_op(
  155. s, op, other, exc=TypeError
  156. )
  157. class TestComparisonOps(base.BaseComparisonOpsTests):
  158. def _compare_other(self, s, data, op_name, other):
  159. op = self.get_op_from_name(op_name)
  160. if op_name == '__eq__':
  161. result = op(s, other)
  162. expected = s.combine(other, lambda x, y: x == y)
  163. assert (result == expected).all()
  164. elif op_name == '__ne__':
  165. result = op(s, other)
  166. expected = s.combine(other, lambda x, y: x != y)
  167. assert (result == expected).all()
  168. else:
  169. with pytest.raises(TypeError):
  170. op(data, other)
  171. class TestParsing(base.BaseParsingTests):
  172. pass