test_decimal.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401
  1. import decimal
  2. import math
  3. import operator
  4. import numpy as np
  5. import pytest
  6. import pandas as pd
  7. from pandas import compat
  8. from pandas.tests.extension import base
  9. import pandas.util.testing as tm
  10. from .array import DecimalArray, DecimalDtype, make_data, to_decimal
  @pytest.fixture
  def dtype():
      """Return a newly constructed ``DecimalDtype``."""
      decimal_dtype = DecimalDtype()
      return decimal_dtype
  @pytest.fixture
  def data():
      """Return a ``DecimalArray`` wrapping the values from ``make_data()``."""
      values = make_data()
      return DecimalArray(values)
  @pytest.fixture
  def data_missing():
      """Return a two-element ``DecimalArray``: ``Decimal('NaN')`` followed
      by ``Decimal(1)``."""
      missing = decimal.Decimal('NaN')
      valid = decimal.Decimal(1)
      return DecimalArray([missing, valid])
  @pytest.fixture
  def data_for_sorting():
      """Return a ``DecimalArray`` holding 1, 2 and 0, in that order."""
      values = [decimal.Decimal(text) for text in ('1', '2', '0')]
      return DecimalArray(values)
  @pytest.fixture
  def data_missing_for_sorting():
      """Return a ``DecimalArray`` holding 1, NaN and 0, in that order."""
      values = [decimal.Decimal(text) for text in ('1', 'NaN', '0')]
      return DecimalArray(values)
  @pytest.fixture
  def na_cmp():
      """Return a two-argument predicate that is truthy only when both
      arguments report ``is_nan()``."""

      def both_nan(x, y):
          return x.is_nan() and y.is_nan()

      return both_nan
  @pytest.fixture
  def na_value():
      """Return ``Decimal("NaN")``."""
      nan = decimal.Decimal("NaN")
      return nan
  @pytest.fixture
  def data_for_grouping():
      """Return an eight-element ``DecimalArray`` laid out as
      ``[B, B, NA, NA, A, A, B, C]`` with A=0.0, B=1.0, C=2.0 and NA=NaN."""
      a, b, c = (decimal.Decimal(text) for text in ('0.0', '1.0', '2.0'))
      na = decimal.Decimal('NaN')
      layout = [b, b, na, na, a, a, b, c]
      return DecimalArray(layout)
  class BaseDecimal(object):
      """Mixin supplying ``assert_series_equal`` / ``assert_frame_equal``
      variants that compare missing-value positions separately from the
      remaining values."""

      def assert_series_equal(self, left, right, *args, **kwargs):
          """Assert that ``left`` and ``right`` have missing values in the
          same positions and are equal everywhere else.

          Extra positional and keyword arguments are forwarded to
          ``tm.assert_series_equal`` for the comparison of non-missing values.
          """

          def is_nan(value):
              # need to convert array([Decimal(NaN)], dtype='object') to np.NaN
              # because Series[object].isnan doesn't recognize decimal(NaN) as
              # NA.
              try:
                  return math.isnan(value)
              except TypeError:
                  return False

          def na_mask(series):
              # object-dtype Series are probed element by element; any other
              # dtype relies on the Series' own isna().
              if series.dtype == 'object':
                  return series.apply(is_nan)
              return series.isna()

          left_na = na_mask(left)
          right_na = na_mask(right)

          # First the missing-value masks must agree ...
          tm.assert_series_equal(left_na, right_na)
          # ... then the values outside the masks are compared.
          return tm.assert_series_equal(left[~left_na], right[~right_na],
                                        *args, **kwargs)

      def assert_frame_equal(self, left, right, *args, **kwargs):
          """Assert that two DataFrames have equal columns and equal content.

          Column labels are compared with ``tm.assert_index_equal``, selected
          columns are compared one by one through ``self.assert_series_equal``
          and whatever remains after dropping them goes through
          ``tm.assert_frame_equal``.
          """
          # TODO(EA): select_dtypes
          columns_label = '{obj}.columns'.format(
              obj=kwargs.get('obj', 'DataFrame'))
          tm.assert_index_equal(
              left.columns, right.columns,
              exact=kwargs.get('check_column_type', 'equiv'),
              check_names=kwargs.get('check_names', True),
              check_exact=kwargs.get('check_exact', False),
              check_categorical=kwargs.get('check_categorical', True),
              obj=columns_label)

          # NOTE(review): ``.index`` of the boolean comparison is every column
          # label, not only the labels where the comparison is True - confirm
          # whether filtering on the mask was intended.
          decimals = (left.dtypes == 'decimal').index

          for col in decimals:
              self.assert_series_equal(left[col], right[col],
                                       *args, **kwargs)

          remaining_left = left.drop(columns=decimals)
          remaining_right = right.drop(columns=decimals)
          tm.assert_frame_equal(remaining_left, remaining_right,
                                *args, **kwargs)
  class TestDtype(BaseDecimal, base.BaseDtypeTests):
      """``base.BaseDtypeTests`` combined with the ``BaseDecimal`` assertion
      helpers."""

      @pytest.mark.skipif(
          compat.PY2,
          reason="Context not hashable.",
      )
      def test_hashable(self, dtype):
          """Intentionally empty; skipped when ``compat.PY2`` is true."""
          # NOTE(review): the body asserts nothing, so whenever it is not
          # skipped this test passes trivially - confirm that is intended.
  class TestInterface(BaseDecimal, base.BaseInterfaceTests):
      """``base.BaseInterfaceTests`` combined with the ``BaseDecimal``
      assertion helpers; the whole class is skipped when ``compat.PY2``."""

      pytestmark = pytest.mark.skipif(
          compat.PY2,
          reason="Unhashble dtype in Py2.",
      )
  class TestConstructors(BaseDecimal, base.BaseConstructorsTests):
      """``base.BaseConstructorsTests`` combined with the ``BaseDecimal``
      assertion helpers."""

      @pytest.mark.skip(
          reason="not implemented constructor from dtype",
      )
      def test_from_dtype(self, data):
          """Always-skipped placeholder: construct from our dtype & string
          dtype."""
  class TestReshaping(BaseDecimal, base.BaseReshapingTests):
      """``base.BaseReshapingTests`` combined with the ``BaseDecimal``
      assertion helpers; the whole class is skipped when ``compat.PY2``."""

      pytestmark = pytest.mark.skipif(
          compat.PY2,
          reason="Unhashble dtype in Py2.",
      )
  class TestGetitem(BaseDecimal, base.BaseGetitemTests):
      """``base.BaseGetitemTests`` plus one decimal-specific ``take`` test."""

      def test_take_na_value_other_decimal(self):
          """``take`` with ``allow_fill=True`` places the supplied
          ``fill_value`` wherever the indexer is -1."""
          one = decimal.Decimal('1.0')
          two = decimal.Decimal('2.0')
          filler = decimal.Decimal('-1.0')

          source = DecimalArray([one, two])
          taken = source.take([0, -1], allow_fill=True, fill_value=filler)

          wanted = DecimalArray([decimal.Decimal('1.0'),
                                 decimal.Decimal('-1.0')])
          self.assert_extension_array_equal(taken, wanted)
  class TestMissing(BaseDecimal, base.BaseMissingTests):
      """``base.BaseMissingTests`` combined with the ``BaseDecimal``
      assertion helpers; nothing is overridden here."""
  class Reduce(object):
      """Mixin providing ``check_reduce`` for the reduction test classes."""

      def check_reduce(self, s, op_name, skipna):
          """Check the reduction ``op_name`` on ``s``.

          When ``skipna`` is truthy, or the operation is one of ``median``,
          ``skew`` or ``kurt``, the call must raise ``NotImplementedError``.
          Otherwise its result is compared with the same-named method of
          ``np.asarray(s)``.
          """
          must_raise = skipna or op_name in ('median', 'skew', 'kurt')
          if must_raise:
              with pytest.raises(NotImplementedError):
                  getattr(s, op_name)(skipna=skipna)
              return

          actual = getattr(s, op_name)(skipna=skipna)
          reference = getattr(np.asarray(s), op_name)()
          tm.assert_almost_equal(actual, reference)
  class TestNumericReduce(Reduce, base.BaseNumericReduceTests):
      """``base.BaseNumericReduceTests`` using ``Reduce.check_reduce``."""
  class TestBooleanReduce(Reduce, base.BaseBooleanReduceTests):
      """``base.BaseBooleanReduceTests`` using ``Reduce.check_reduce``."""
  class TestMethods(BaseDecimal, base.BaseMethodsTests):
      """``base.BaseMethodsTests`` with ``test_value_counts`` redefined and
      marked as an expected failure."""

      @pytest.mark.parametrize('dropna', [True, False])
      @pytest.mark.xfail(reason="value_counts not implemented yet.")
      def test_value_counts(self, all_data, dropna):
          """Compare ``value_counts`` of the first ten elements against the
          counts of a reference built from the same elements."""
          all_data = all_data[:10]
          # With dropna the reference is an ndarray of the non-missing
          # entries; otherwise it is the sliced data itself.
          other = np.array(all_data[~all_data.isna()]) if dropna else all_data

          counted = pd.Series(all_data).value_counts(dropna=dropna)
          result = counted.sort_index()
          reference = pd.Series(other).value_counts(dropna=dropna)
          expected = reference.sort_index()

          tm.assert_series_equal(result, expected)
  class TestCasting(BaseDecimal, base.BaseCastingTests):
      """``base.BaseCastingTests`` combined with the ``BaseDecimal``
      assertion helpers; the whole class is skipped when ``compat.PY2``."""

      pytestmark = pytest.mark.skipif(
          compat.PY2,
          reason="Unhashble dtype in Py2.",
      )
  class TestGroupby(BaseDecimal, base.BaseGroupbyTests):
      """``base.BaseGroupbyTests`` combined with the ``BaseDecimal``
      assertion helpers; the whole class is skipped when ``compat.PY2``."""

      pytestmark = pytest.mark.skipif(
          compat.PY2,
          reason="Unhashble dtype in Py2.",
      )
  class TestSetitem(BaseDecimal, base.BaseSetitemTests):
      """``base.BaseSetitemTests`` combined with the ``BaseDecimal``
      assertion helpers; nothing is overridden here."""
  class TestPrinting(BaseDecimal, base.BasePrintingTests):
      """``base.BasePrintingTests`` combined with the ``BaseDecimal``
      assertion helpers; the whole class is skipped when ``compat.PY2``."""

      pytestmark = pytest.mark.skipif(
          compat.PY2,
          reason="Unhashble dtype in Py2.",
      )
  # TODO(extension)
  @pytest.mark.xfail(
      reason=("raising AssertionError as this is not implemented, "
              "though easy enough to do"))
  def test_series_constructor_coerce_data_to_extension_dtype_raises():
      """Expect a ``ValueError`` with a specific message when plain ints are
      passed to ``pd.Series`` together with ``DecimalDtype()``.

      Marked ``xfail`` because that error is not raised yet.
      """
      expected_message = ("Cannot cast data to extension dtype 'decimal'. Pass the "
                          "extension array directly.")
      raises_value_error = pytest.raises(ValueError, match=expected_message)
      with raises_value_error:
          pd.Series([0, 1, 2], dtype=DecimalDtype())
  def test_series_constructor_with_dtype():
      """Building a Series from a ``DecimalArray`` with an explicit dtype.

      ``DecimalDtype()`` must give the same Series as passing no dtype, and
      ``'int64'`` must give the integer Series ``[10]``.
      """
      decimals = DecimalArray([decimal.Decimal('10.0')])

      # Same extension dtype requested explicitly.
      with_decimal_dtype = pd.Series(decimals, dtype=DecimalDtype())
      without_dtype = pd.Series(decimals)
      tm.assert_series_equal(with_decimal_dtype, without_dtype)

      # A plain numpy dtype requested instead.
      as_int = pd.Series(decimals, dtype='int64')
      plain_int = pd.Series([10])
      tm.assert_series_equal(as_int, plain_int)
  def test_dataframe_constructor_with_dtype():
      """Building a DataFrame from a ``DecimalArray`` column with an explicit
      dtype.

      ``DecimalDtype()`` must give the same frame as passing no dtype, and
      ``'int64'`` must give the integer column ``[10]``.
      """
      # Same extension dtype requested explicitly.
      decimals = DecimalArray([decimal.Decimal('10.0')])
      with_decimal_dtype = pd.DataFrame({"A": decimals}, dtype=DecimalDtype())
      without_dtype = pd.DataFrame({"A": decimals})
      tm.assert_frame_equal(with_decimal_dtype, without_dtype)

      # A plain numpy dtype requested instead, starting from a fresh array.
      fresh_decimals = DecimalArray([decimal.Decimal('10.0')])
      as_int = pd.DataFrame({"A": fresh_decimals}, dtype='int64')
      plain_int = pd.DataFrame({"A": [10]})
      tm.assert_frame_equal(as_int, plain_int)
  @pytest.mark.parametrize("frame", [True, False])
  def test_astype_dispatches(frame):
      """``astype`` with a ``DecimalDtype`` carrying a custom context must
      yield a result whose dtype context has that precision.

      Runs once on a Series and once on the single-column frame built from it.
      """
      # This is a dtype-specific test that ensures Series[decimal].astype
      # gets all the way through to ExtensionArray.astype
      # Designing a reliable smoke test that works for arbitrary data types
      # is difficult.
      obj = pd.Series(DecimalArray([decimal.Decimal(2)]), name='a')
      context = decimal.Context(prec=5)

      if frame:
          obj = obj.to_frame()

      converted = obj.astype(DecimalDtype(context))
      # For the frame case, look at the single column named 'a'.
      result = converted['a'] if frame else converted

      assert result.dtype.context.prec == context.prec
  class TestArithmeticOps(BaseDecimal, base.BaseArithmeticOpsTests):
      """``base.BaseArithmeticOpsTests`` with overrides for ``DecimalArray``."""

      def check_opname(self, s, op_name, other, exc=None):
          """Delegate to the parent ``check_opname``, always forwarding
          ``exc=None`` regardless of the ``exc`` argument received."""
          super(TestArithmeticOps, self).check_opname(s, op_name,
                                                      other, exc=None)

      def test_arith_series_with_array(self, data, all_arithmetic_operators):
          """Check the operator named by ``all_arithmetic_operators`` between
          a decimal Series and an int Series, a doubled copy, 0 and 5.

          The ``DivisionByZero`` and ``InvalidOperation`` traps of the current
          decimal context are switched off while the checks run and are
          restored afterwards, even when a check fails.
          """
          op_name = all_arithmetic_operators
          s = pd.Series(data)

          context = decimal.getcontext()
          divbyzerotrap = context.traps[decimal.DivisionByZero]
          invalidoptrap = context.traps[decimal.InvalidOperation]
          context.traps[decimal.DivisionByZero] = 0
          context.traps[decimal.InvalidOperation] = 0

          try:
              # Decimal supports ops with int, but not float
              other = pd.Series([int(d * 100) for d in data])
              self.check_opname(s, op_name, other)

              # The doubled-copy operand is not exercised for the "mod"
              # operators.
              if "mod" not in op_name:
                  self.check_opname(s, op_name, s * 2)

              self.check_opname(s, op_name, 0)
              self.check_opname(s, op_name, 5)
          finally:
              # Restore the traps unconditionally so a failing check cannot
              # leave the shared context altered for whatever runs next.
              context.traps[decimal.DivisionByZero] = divbyzerotrap
              context.traps[decimal.InvalidOperation] = invalidoptrap

      def _check_divmod_op(self, s, op, other, exc=NotImplementedError):
          """Delegate to the parent ``_check_divmod_op`` with ``exc=None``;
          the ``exc`` argument received is ignored."""
          # We implement divmod
          super(TestArithmeticOps, self)._check_divmod_op(
              s, op, other, exc=None
          )

      def test_error(self):
          """Intentionally empty test body."""
          # NOTE(review): presumably shadows a same-named test inherited from
          # base.BaseArithmeticOpsTests - confirm.
          pass
  class TestComparisonOps(BaseDecimal, base.BaseComparisonOpsTests):
      """``base.BaseComparisonOpsTests`` with overrides for ``DecimalArray``."""

      def check_opname(self, s, op_name, other, exc=None):
          """Delegate to the parent ``check_opname``, always forwarding
          ``exc=None`` regardless of the ``exc`` argument received."""
          parent = super(TestComparisonOps, self)
          parent.check_opname(s, op_name, other, exc=None)

      def _compare_other(self, s, data, op_name, other):
          """Run ``check_opname`` for ``s`` against ``other``; ``data`` is
          accepted but not used."""
          self.check_opname(s, op_name, other)

      def test_compare_scalar(self, data, all_compare_operators):
          """Compare a decimal Series against the scalar ``0.5``."""
          series = pd.Series(data)
          self._compare_other(series, data, all_compare_operators, 0.5)

      def test_compare_array(self, data, all_compare_operators):
          """Compare a decimal Series against a randomly scaled copy of
          itself."""
          op_name = all_compare_operators
          series = pd.Series(data)

          exponents = np.random.choice([-1, 0, 1], len(data))
          # Randomly double, halve or keep same value
          factors = [decimal.Decimal(pow(2.0, exponent))
                     for exponent in exponents]
          scaled = pd.Series(data) * factors

          self._compare_other(series, data, op_name, scaled)
  class DecimalArrayWithoutFromSequence(DecimalArray):
      """Helper class for testing error handling in _from_sequence."""

      # Declared as a classmethod so that calls made through the class and
      # calls made through an instance both reach the KeyError below; without
      # the decorator a class-level call fails earlier with a TypeError for
      # the missing ``scalars`` argument.
      @classmethod
      def _from_sequence(cls, scalars, dtype=None, copy=False):
          """Always raise ``KeyError("For the test")``; the arguments are
          ignored."""
          raise KeyError("For the test")
  class DecimalArrayWithoutCoercion(DecimalArrayWithoutFromSequence):
      """Variant whose arithmetic methods are created with
      ``coerce_to_dtype=False``."""

      @classmethod
      def _create_arithmetic_method(cls, op):
          """Return ``cls._create_method(op, coerce_to_dtype=False)``."""
          method = cls._create_method(op, coerce_to_dtype=False)
          return method


  # Install the arithmetic operators on the class defined just above.
  DecimalArrayWithoutCoercion._add_arithmetic_ops()
  def test_combine_from_sequence_raises():
      """``Series.combine`` on an array whose ``_from_sequence`` raises must
      still return the element-wise sums, as an object-dtype Series."""
      # https://github.com/pandas-dev/pandas/issues/22850
      values = [decimal.Decimal("1.0"), decimal.Decimal("2.0")]
      ser = pd.Series(DecimalArrayWithoutFromSequence(values))

      combined = ser.combine(ser, operator.add)

      # note: object dtype
      sums = [decimal.Decimal("2.0"), decimal.Decimal("4.0")]
      expected = pd.Series(sums, dtype="object")
      tm.assert_series_equal(combined, expected)
  @pytest.mark.parametrize("class_", [DecimalArrayWithoutFromSequence,
                                      DecimalArrayWithoutCoercion])
  def test_scalar_ops_from_sequence_raises(class_):
      """Adding an array of ``class_`` to itself must produce an object-dtype
      ndarray of the element-wise sums."""
      # op(EA, EA) should return an EA, or an ndarray if it's not possible
      # to return an EA with the return values.
      operand = class_([decimal.Decimal("1.0"), decimal.Decimal("2.0")])

      total = operand + operand

      sums = [decimal.Decimal("2.0"), decimal.Decimal("4.0")]
      expected = np.array(sums, dtype="object")
      tm.assert_numpy_array_equal(total, expected)
  @pytest.mark.parametrize("reverse, expected_div, expected_mod", [
      (False, [0, 1, 1, 2], [1, 0, 1, 0]),
      (True, [2, 1, 0, 0], [0, 0, 2, 2]),
  ])
  def test_divmod_array(reverse, expected_div, expected_mod):
      """``divmod`` between the decimal array ``[1, 2, 3, 4]`` and ``2``.

      With ``reverse`` the scalar is the left operand, otherwise the array
      is. Both returned arrays are checked against the parametrized values.
      """
      # https://github.com/pandas-dev/pandas/issues/22930
      arr = to_decimal([1, 2, 3, 4])
      quotient, remainder = divmod(2, arr) if reverse else divmod(arr, 2)

      wanted_quotient = to_decimal(expected_div)
      wanted_remainder = to_decimal(expected_mod)

      tm.assert_extension_array_equal(quotient, wanted_quotient)
      tm.assert_extension_array_equal(remainder, wanted_remainder)
  def test_formatting_values_deprecated():
      """``repr`` of a Series backed by an array that defines
      ``_formatting_values`` must emit a ``DeprecationWarning``."""

      class DecimalArray2(DecimalArray):
          def _formatting_values(self):
              return np.array(self)

      backing = DecimalArray2([decimal.Decimal('1.0')])
      ser = pd.Series(backing)

      # different levels for 2 vs. 3
      expect_warning = tm.assert_produces_warning(
          DeprecationWarning, check_stacklevel=compat.PY3)
      with expect_warning:
          repr(ser)