test_assert_series_equal.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. # -*- coding: utf-8 -*-
  2. import pytest
  3. from pandas import Categorical, DataFrame, Series
  4. from pandas.util.testing import assert_series_equal
  5. def _assert_series_equal_both(a, b, **kwargs):
  6. """
  7. Check that two Series equal.
  8. This check is performed commutatively.
  9. Parameters
  10. ----------
  11. a : Series
  12. The first Series to compare.
  13. b : Series
  14. The second Series to compare.
  15. kwargs : dict
  16. The arguments passed to `assert_series_equal`.
  17. """
  18. assert_series_equal(a, b, **kwargs)
  19. assert_series_equal(b, a, **kwargs)
  20. def _assert_not_series_equal(a, b, **kwargs):
  21. """
  22. Check that two Series are not equal.
  23. Parameters
  24. ----------
  25. a : Series
  26. The first Series to compare.
  27. b : Series
  28. The second Series to compare.
  29. kwargs : dict
  30. The arguments passed to `assert_series_equal`.
  31. """
  32. try:
  33. assert_series_equal(a, b, **kwargs)
  34. msg = "The two Series were equal when they shouldn't have been"
  35. pytest.fail(msg=msg)
  36. except AssertionError:
  37. pass
  38. def _assert_not_series_equal_both(a, b, **kwargs):
  39. """
  40. Check that two Series are not equal.
  41. This check is performed commutatively.
  42. Parameters
  43. ----------
  44. a : Series
  45. The first Series to compare.
  46. b : Series
  47. The second Series to compare.
  48. kwargs : dict
  49. The arguments passed to `assert_series_equal`.
  50. """
  51. _assert_not_series_equal(a, b, **kwargs)
  52. _assert_not_series_equal(b, a, **kwargs)
  53. @pytest.mark.parametrize("data", [
  54. range(3), list("abc"), list(u"áàä"),
  55. ])
  56. def test_series_equal(data):
  57. _assert_series_equal_both(Series(data), Series(data))
  58. @pytest.mark.parametrize("data1,data2", [
  59. (range(3), range(1, 4)),
  60. (list("abc"), list("xyz")),
  61. (list(u"áàä"), list(u"éèë")),
  62. (list(u"áàä"), list(b"aaa")),
  63. (range(3), range(4)),
  64. ])
  65. def test_series_not_equal_value_mismatch(data1, data2):
  66. _assert_not_series_equal_both(Series(data1), Series(data2))
  67. @pytest.mark.parametrize("kwargs", [
  68. dict(dtype="float64"), # dtype mismatch
  69. dict(index=[1, 2, 4]), # index mismatch
  70. dict(name="foo"), # name mismatch
  71. ])
  72. def test_series_not_equal_metadata_mismatch(kwargs):
  73. data = range(3)
  74. s1 = Series(data)
  75. s2 = Series(data, **kwargs)
  76. _assert_not_series_equal_both(s1, s2)
  77. @pytest.mark.parametrize("data1,data2", [(0.12345, 0.12346), (0.1235, 0.1236)])
  78. @pytest.mark.parametrize("dtype", ["float32", "float64"])
  79. @pytest.mark.parametrize("check_less_precise", [False, True, 0, 1, 2, 3, 10])
  80. def test_less_precise(data1, data2, dtype, check_less_precise):
  81. s1 = Series([data1], dtype=dtype)
  82. s2 = Series([data2], dtype=dtype)
  83. kwargs = dict(check_less_precise=check_less_precise)
  84. if ((check_less_precise is False or check_less_precise == 10) or
  85. ((check_less_precise is True or check_less_precise >= 3) and
  86. abs(data1 - data2) >= 0.0001)):
  87. msg = "Series values are different"
  88. with pytest.raises(AssertionError, match=msg):
  89. assert_series_equal(s1, s2, **kwargs)
  90. else:
  91. _assert_series_equal_both(s1, s2, **kwargs)
  92. @pytest.mark.parametrize("s1,s2,msg", [
  93. # Index
  94. (Series(["l1", "l2"], index=[1, 2]),
  95. Series(["l1", "l2"], index=[1., 2.]),
  96. "Series\\.index are different"),
  97. # MultiIndex
  98. (DataFrame.from_records({"a": [1, 2], "b": [2.1, 1.5],
  99. "c": ["l1", "l2"]}, index=["a", "b"]).c,
  100. DataFrame.from_records({"a": [1., 2.], "b": [2.1, 1.5],
  101. "c": ["l1", "l2"]}, index=["a", "b"]).c,
  102. "MultiIndex level \\[0\\] are different")
  103. ])
  104. def test_series_equal_index_dtype(s1, s2, msg, check_index_type):
  105. kwargs = dict(check_index_type=check_index_type)
  106. if check_index_type:
  107. with pytest.raises(AssertionError, match=msg):
  108. assert_series_equal(s1, s2, **kwargs)
  109. else:
  110. assert_series_equal(s1, s2, **kwargs)
  111. def test_series_equal_length_mismatch(check_less_precise):
  112. msg = """Series are different
  113. Series length are different
  114. \\[left\\]: 3, RangeIndex\\(start=0, stop=3, step=1\\)
  115. \\[right\\]: 4, RangeIndex\\(start=0, stop=4, step=1\\)"""
  116. s1 = Series([1, 2, 3])
  117. s2 = Series([1, 2, 3, 4])
  118. with pytest.raises(AssertionError, match=msg):
  119. assert_series_equal(s1, s2, check_less_precise=check_less_precise)
  120. def test_series_equal_values_mismatch(check_less_precise):
  121. msg = """Series are different
  122. Series values are different \\(33\\.33333 %\\)
  123. \\[left\\]: \\[1, 2, 3\\]
  124. \\[right\\]: \\[1, 2, 4\\]"""
  125. s1 = Series([1, 2, 3])
  126. s2 = Series([1, 2, 4])
  127. with pytest.raises(AssertionError, match=msg):
  128. assert_series_equal(s1, s2, check_less_precise=check_less_precise)
  129. def test_series_equal_categorical_mismatch(check_categorical):
  130. msg = """Attributes are different
  131. Attribute "dtype" are different
  132. \\[left\\]: CategoricalDtype\\(categories=\\[u?'a', u?'b'\\], ordered=False\\)
  133. \\[right\\]: CategoricalDtype\\(categories=\\[u?'a', u?'b', u?'c'\\], \
  134. ordered=False\\)"""
  135. s1 = Series(Categorical(["a", "b"]))
  136. s2 = Series(Categorical(["a", "b"], categories=list("abc")))
  137. if check_categorical:
  138. with pytest.raises(AssertionError, match=msg):
  139. assert_series_equal(s1, s2, check_categorical=check_categorical)
  140. else:
  141. _assert_series_equal_both(s1, s2, check_categorical=check_categorical)