test_json.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304
  1. import collections
  2. import operator
  3. import pytest
  4. from pandas.compat import PY2, PY36
  5. import pandas as pd
  6. from pandas.tests.extension import base
  7. import pandas.util.testing as tm
  8. from .array import JSONArray, JSONDtype, make_data
  9. pytestmark = pytest.mark.skipif(PY2, reason="Py2 doesn't have a UserDict")
  10. @pytest.fixture
  11. def dtype():
  12. return JSONDtype()
  13. @pytest.fixture
  14. def data():
  15. """Length-100 PeriodArray for semantics test."""
  16. data = make_data()
  17. # Why the while loop? NumPy is unable to construct an ndarray from
  18. # equal-length ndarrays. Many of our operations involve coercing the
  19. # EA to an ndarray of objects. To avoid random test failures, we ensure
  20. # that our data is coercable to an ndarray. Several tests deal with only
  21. # the first two elements, so that's what we'll check.
  22. while len(data[0]) == len(data[1]):
  23. data = make_data()
  24. return JSONArray(data)
  25. @pytest.fixture
  26. def data_missing():
  27. """Length 2 array with [NA, Valid]"""
  28. return JSONArray([{}, {'a': 10}])
  29. @pytest.fixture
  30. def data_for_sorting():
  31. return JSONArray([{'b': 1}, {'c': 4}, {'a': 2, 'c': 3}])
  32. @pytest.fixture
  33. def data_missing_for_sorting():
  34. return JSONArray([{'b': 1}, {}, {'a': 4}])
  35. @pytest.fixture
  36. def na_value(dtype):
  37. return dtype.na_value
  38. @pytest.fixture
  39. def na_cmp():
  40. return operator.eq
  41. @pytest.fixture
  42. def data_for_grouping():
  43. return JSONArray([
  44. {'b': 1}, {'b': 1},
  45. {}, {},
  46. {'a': 0, 'c': 2}, {'a': 0, 'c': 2},
  47. {'b': 1},
  48. {'c': 2},
  49. ])
  50. class BaseJSON(object):
  51. # NumPy doesn't handle an array of equal-length UserDicts.
  52. # The default assert_series_equal eventually does a
  53. # Series.values, which raises. We work around it by
  54. # converting the UserDicts to dicts.
  55. def assert_series_equal(self, left, right, **kwargs):
  56. if left.dtype.name == 'json':
  57. assert left.dtype == right.dtype
  58. left = pd.Series(JSONArray(left.values.astype(object)),
  59. index=left.index, name=left.name)
  60. right = pd.Series(JSONArray(right.values.astype(object)),
  61. index=right.index, name=right.name)
  62. tm.assert_series_equal(left, right, **kwargs)
  63. def assert_frame_equal(self, left, right, *args, **kwargs):
  64. tm.assert_index_equal(
  65. left.columns, right.columns,
  66. exact=kwargs.get('check_column_type', 'equiv'),
  67. check_names=kwargs.get('check_names', True),
  68. check_exact=kwargs.get('check_exact', False),
  69. check_categorical=kwargs.get('check_categorical', True),
  70. obj='{obj}.columns'.format(obj=kwargs.get('obj', 'DataFrame')))
  71. jsons = (left.dtypes == 'json').index
  72. for col in jsons:
  73. self.assert_series_equal(left[col], right[col],
  74. *args, **kwargs)
  75. left = left.drop(columns=jsons)
  76. right = right.drop(columns=jsons)
  77. tm.assert_frame_equal(left, right, *args, **kwargs)
  78. class TestDtype(BaseJSON, base.BaseDtypeTests):
  79. pass
  80. class TestInterface(BaseJSON, base.BaseInterfaceTests):
  81. def test_custom_asserts(self):
  82. # This would always trigger the KeyError from trying to put
  83. # an array of equal-length UserDicts inside an ndarray.
  84. data = JSONArray([collections.UserDict({'a': 1}),
  85. collections.UserDict({'b': 2}),
  86. collections.UserDict({'c': 3})])
  87. a = pd.Series(data)
  88. self.assert_series_equal(a, a)
  89. self.assert_frame_equal(a.to_frame(), a.to_frame())
  90. b = pd.Series(data.take([0, 0, 1]))
  91. with pytest.raises(AssertionError):
  92. self.assert_series_equal(a, b)
  93. with pytest.raises(AssertionError):
  94. self.assert_frame_equal(a.to_frame(), b.to_frame())
  95. class TestConstructors(BaseJSON, base.BaseConstructorsTests):
  96. @pytest.mark.skip(reason="not implemented constructor from dtype")
  97. def test_from_dtype(self, data):
  98. # construct from our dtype & string dtype
  99. pass
  100. class TestReshaping(BaseJSON, base.BaseReshapingTests):
  101. @pytest.mark.skip(reason="Different definitions of NA")
  102. def test_stack(self):
  103. """
  104. The test does .astype(object).stack(). If we happen to have
  105. any missing values in `data`, then we'll end up with different
  106. rows since we consider `{}` NA, but `.astype(object)` doesn't.
  107. """
  108. @pytest.mark.xfail(reason="dict for NA")
  109. def test_unstack(self, data, index):
  110. # The base test has NaN for the expected NA value.
  111. # this matches otherwise
  112. return super().test_unstack(data, index)
  113. class TestGetitem(BaseJSON, base.BaseGetitemTests):
  114. pass
  115. class TestMissing(BaseJSON, base.BaseMissingTests):
  116. @pytest.mark.skip(reason="Setting a dict as a scalar")
  117. def test_fillna_series(self):
  118. """We treat dictionaries as a mapping in fillna, not a scalar."""
  119. @pytest.mark.skip(reason="Setting a dict as a scalar")
  120. def test_fillna_frame(self):
  121. """We treat dictionaries as a mapping in fillna, not a scalar."""
  122. unhashable = pytest.mark.skip(reason="Unhashable")
  123. unstable = pytest.mark.skipif(not PY36, # 3.6 or higher
  124. reason="Dictionary order unstable")
  125. class TestReduce(base.BaseNoReduceTests):
  126. pass
  127. class TestMethods(BaseJSON, base.BaseMethodsTests):
  128. @unhashable
  129. def test_value_counts(self, all_data, dropna):
  130. pass
  131. @unhashable
  132. def test_sort_values_frame(self):
  133. # TODO (EA.factorize): see if _values_for_factorize allows this.
  134. pass
  135. @unstable
  136. def test_argsort(self, data_for_sorting):
  137. super(TestMethods, self).test_argsort(data_for_sorting)
  138. @unstable
  139. def test_argsort_missing(self, data_missing_for_sorting):
  140. super(TestMethods, self).test_argsort_missing(
  141. data_missing_for_sorting)
  142. @unstable
  143. @pytest.mark.parametrize('ascending', [True, False])
  144. def test_sort_values(self, data_for_sorting, ascending):
  145. super(TestMethods, self).test_sort_values(
  146. data_for_sorting, ascending)
  147. @unstable
  148. @pytest.mark.parametrize('ascending', [True, False])
  149. def test_sort_values_missing(self, data_missing_for_sorting, ascending):
  150. super(TestMethods, self).test_sort_values_missing(
  151. data_missing_for_sorting, ascending)
  152. @pytest.mark.skip(reason="combine for JSONArray not supported")
  153. def test_combine_le(self, data_repeated):
  154. pass
  155. @pytest.mark.skip(reason="combine for JSONArray not supported")
  156. def test_combine_add(self, data_repeated):
  157. pass
  158. @pytest.mark.skip(reason="combine for JSONArray not supported")
  159. def test_combine_first(self, data):
  160. pass
  161. @unhashable
  162. def test_hash_pandas_object_works(self, data, kind):
  163. super().test_hash_pandas_object_works(data, kind)
  164. @pytest.mark.skip(reason="broadcasting error")
  165. def test_where_series(self, data, na_value):
  166. # Fails with
  167. # *** ValueError: operands could not be broadcast together
  168. # with shapes (4,) (4,) (0,)
  169. super().test_where_series(data, na_value)
  170. @pytest.mark.skip(reason="Can't compare dicts.")
  171. def test_searchsorted(self, data_for_sorting):
  172. super(TestMethods, self).test_searchsorted(data_for_sorting)
  173. class TestCasting(BaseJSON, base.BaseCastingTests):
  174. @pytest.mark.skip(reason="failing on np.array(self, dtype=str)")
  175. def test_astype_str(self):
  176. """This currently fails in NumPy on np.array(self, dtype=str) with
  177. *** ValueError: setting an array element with a sequence
  178. """
  179. # We intentionally don't run base.BaseSetitemTests because pandas'
  180. # internals has trouble setting sequences of values into scalar positions.
  181. class TestGroupby(BaseJSON, base.BaseGroupbyTests):
  182. @unhashable
  183. def test_groupby_extension_transform(self):
  184. """
  185. This currently fails in Series.name.setter, since the
  186. name must be hashable, but the value is a dictionary.
  187. I think this is what we want, i.e. `.name` should be the original
  188. values, and not the values for factorization.
  189. """
  190. @unhashable
  191. def test_groupby_extension_apply(self):
  192. """
  193. This fails in Index._do_unique_check with
  194. > hash(val)
  195. E TypeError: unhashable type: 'UserDict' with
  196. I suspect that once we support Index[ExtensionArray],
  197. we'll be able to dispatch unique.
  198. """
  199. @unstable
  200. @pytest.mark.parametrize('as_index', [True, False])
  201. def test_groupby_extension_agg(self, as_index, data_for_grouping):
  202. super(TestGroupby, self).test_groupby_extension_agg(
  203. as_index, data_for_grouping
  204. )
  205. class TestArithmeticOps(BaseJSON, base.BaseArithmeticOpsTests):
  206. def test_error(self, data, all_arithmetic_operators):
  207. pass
  208. def test_add_series_with_extension_array(self, data):
  209. ser = pd.Series(data)
  210. with pytest.raises(TypeError, match="unsupported"):
  211. ser + data
  212. def _check_divmod_op(self, s, op, other, exc=NotImplementedError):
  213. return super(TestArithmeticOps, self)._check_divmod_op(
  214. s, op, other, exc=TypeError
  215. )
  216. class TestComparisonOps(BaseJSON, base.BaseComparisonOpsTests):
  217. pass
  218. class TestPrinting(BaseJSON, base.BasePrintingTests):
  219. pass