test_assert_almost_equal.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350
  1. # -*- coding: utf-8 -*-
  2. import numpy as np
  3. import pytest
  4. from pandas import DataFrame, Index, Series, Timestamp
  5. from pandas.util.testing import assert_almost_equal
  6. def _assert_almost_equal_both(a, b, **kwargs):
  7. """
  8. Check that two objects are approximately equal.
  9. This check is performed commutatively.
  10. Parameters
  11. ----------
  12. a : object
  13. The first object to compare.
  14. b : object
  15. The second object to compare.
  16. kwargs : dict
  17. The arguments passed to `assert_almost_equal`.
  18. """
  19. assert_almost_equal(a, b, **kwargs)
  20. assert_almost_equal(b, a, **kwargs)
  21. def _assert_not_almost_equal(a, b, **kwargs):
  22. """
  23. Check that two objects are not approximately equal.
  24. Parameters
  25. ----------
  26. a : object
  27. The first object to compare.
  28. b : object
  29. The second object to compare.
  30. kwargs : dict
  31. The arguments passed to `assert_almost_equal`.
  32. """
  33. try:
  34. assert_almost_equal(a, b, **kwargs)
  35. msg = ("{a} and {b} were approximately equal "
  36. "when they shouldn't have been").format(a=a, b=b)
  37. pytest.fail(msg=msg)
  38. except AssertionError:
  39. pass
  40. def _assert_not_almost_equal_both(a, b, **kwargs):
  41. """
  42. Check that two objects are not approximately equal.
  43. This check is performed commutatively.
  44. Parameters
  45. ----------
  46. a : object
  47. The first object to compare.
  48. b : object
  49. The second object to compare.
  50. kwargs : dict
  51. The arguments passed to `tm.assert_almost_equal`.
  52. """
  53. _assert_not_almost_equal(a, b, **kwargs)
  54. _assert_not_almost_equal(b, a, **kwargs)
  55. @pytest.mark.parametrize("a,b", [
  56. (1.1, 1.1), (1.1, 1.100001), (np.int16(1), 1.000001),
  57. (np.float64(1.1), 1.1), (np.uint32(5), 5),
  58. ])
  59. def test_assert_almost_equal_numbers(a, b):
  60. _assert_almost_equal_both(a, b)
  61. @pytest.mark.parametrize("a,b", [
  62. (1.1, 1), (1.1, True), (1, 2), (1.0001, np.int16(1)),
  63. ])
  64. def test_assert_not_almost_equal_numbers(a, b):
  65. _assert_not_almost_equal_both(a, b)
  66. @pytest.mark.parametrize("a,b", [
  67. (0, 0), (0, 0.0), (0, np.float64(0)), (0.000001, 0),
  68. ])
  69. def test_assert_almost_equal_numbers_with_zeros(a, b):
  70. _assert_almost_equal_both(a, b)
  71. @pytest.mark.parametrize("a,b", [
  72. (0.001, 0), (1, 0),
  73. ])
  74. def test_assert_not_almost_equal_numbers_with_zeros(a, b):
  75. _assert_not_almost_equal_both(a, b)
  76. @pytest.mark.parametrize("a,b", [
  77. (1, "abc"), (1, [1, ]), (1, object()),
  78. ])
  79. def test_assert_not_almost_equal_numbers_with_mixed(a, b):
  80. _assert_not_almost_equal_both(a, b)
  81. @pytest.mark.parametrize(
  82. "left_dtype", ["M8[ns]", "m8[ns]", "float64", "int64", "object"])
  83. @pytest.mark.parametrize(
  84. "right_dtype", ["M8[ns]", "m8[ns]", "float64", "int64", "object"])
  85. def test_assert_almost_equal_edge_case_ndarrays(left_dtype, right_dtype):
  86. # Empty compare.
  87. _assert_almost_equal_both(np.array([], dtype=left_dtype),
  88. np.array([], dtype=right_dtype),
  89. check_dtype=False)
  90. def test_assert_almost_equal_dicts():
  91. _assert_almost_equal_both({"a": 1, "b": 2}, {"a": 1, "b": 2})
  92. @pytest.mark.parametrize("a,b", [
  93. ({"a": 1, "b": 2}, {"a": 1, "b": 3}),
  94. ({"a": 1, "b": 2}, {"a": 1, "b": 2, "c": 3}),
  95. ({"a": 1}, 1), ({"a": 1}, "abc"), ({"a": 1}, [1, ]),
  96. ])
  97. def test_assert_not_almost_equal_dicts(a, b):
  98. _assert_not_almost_equal_both(a, b)
  99. @pytest.mark.parametrize("val", [1, 2])
  100. def test_assert_almost_equal_dict_like_object(val):
  101. dict_val = 1
  102. real_dict = dict(a=val)
  103. class DictLikeObj(object):
  104. def keys(self):
  105. return "a",
  106. def __getitem__(self, item):
  107. if item == "a":
  108. return dict_val
  109. func = (_assert_almost_equal_both if val == dict_val
  110. else _assert_not_almost_equal_both)
  111. func(real_dict, DictLikeObj(), check_dtype=False)
  112. def test_assert_almost_equal_strings():
  113. _assert_almost_equal_both("abc", "abc")
  114. @pytest.mark.parametrize("a,b", [
  115. ("abc", "abcd"), ("abc", "abd"), ("abc", 1), ("abc", [1, ]),
  116. ])
  117. def test_assert_not_almost_equal_strings(a, b):
  118. _assert_not_almost_equal_both(a, b)
  119. @pytest.mark.parametrize("a,b", [
  120. ([1, 2, 3], [1, 2, 3]), (np.array([1, 2, 3]), np.array([1, 2, 3])),
  121. ])
  122. def test_assert_almost_equal_iterables(a, b):
  123. _assert_almost_equal_both(a, b)
  124. @pytest.mark.parametrize("a,b", [
  125. # Class is different.
  126. (np.array([1, 2, 3]), [1, 2, 3]),
  127. # Dtype is different.
  128. (np.array([1, 2, 3]), np.array([1., 2., 3.])),
  129. # Can't compare generators.
  130. (iter([1, 2, 3]), [1, 2, 3]), ([1, 2, 3], [1, 2, 4]),
  131. ([1, 2, 3], [1, 2, 3, 4]), ([1, 2, 3], 1),
  132. ])
  133. def test_assert_not_almost_equal_iterables(a, b):
  134. _assert_not_almost_equal(a, b)
  135. def test_assert_almost_equal_null():
  136. _assert_almost_equal_both(None, None)
  137. @pytest.mark.parametrize("a,b", [
  138. (None, np.NaN), (None, 0), (np.NaN, 0),
  139. ])
  140. def test_assert_not_almost_equal_null(a, b):
  141. _assert_not_almost_equal(a, b)
  142. @pytest.mark.parametrize("a,b", [
  143. (np.inf, np.inf), (np.inf, float("inf")),
  144. (np.array([np.inf, np.nan, -np.inf]),
  145. np.array([np.inf, np.nan, -np.inf])),
  146. (np.array([np.inf, None, -np.inf], dtype=np.object_),
  147. np.array([np.inf, np.nan, -np.inf], dtype=np.object_)),
  148. ])
  149. def test_assert_almost_equal_inf(a, b):
  150. _assert_almost_equal_both(a, b)
  151. def test_assert_not_almost_equal_inf():
  152. _assert_not_almost_equal_both(np.inf, 0)
  153. @pytest.mark.parametrize("a,b", [
  154. (Index([1., 1.1]), Index([1., 1.100001])),
  155. (Series([1., 1.1]), Series([1., 1.100001])),
  156. (np.array([1.1, 2.000001]), np.array([1.1, 2.0])),
  157. (DataFrame({"a": [1., 1.1]}), DataFrame({"a": [1., 1.100001]}))
  158. ])
  159. def test_assert_almost_equal_pandas(a, b):
  160. _assert_almost_equal_both(a, b)
  161. def test_assert_almost_equal_object():
  162. a = [Timestamp("2011-01-01"), Timestamp("2011-01-01")]
  163. b = [Timestamp("2011-01-01"), Timestamp("2011-01-01")]
  164. _assert_almost_equal_both(a, b)
  165. def test_assert_almost_equal_value_mismatch():
  166. msg = "expected 2\\.00000 but got 1\\.00000, with decimal 5"
  167. with pytest.raises(AssertionError, match=msg):
  168. assert_almost_equal(1, 2)
  169. @pytest.mark.parametrize("a,b,klass1,klass2", [
  170. (np.array([1]), 1, "ndarray", "int"),
  171. (1, np.array([1]), "int", "ndarray"),
  172. ])
  173. def test_assert_almost_equal_class_mismatch(a, b, klass1, klass2):
  174. msg = """numpy array are different
  175. numpy array classes are different
  176. \\[left\\]: {klass1}
  177. \\[right\\]: {klass2}""".format(klass1=klass1, klass2=klass2)
  178. with pytest.raises(AssertionError, match=msg):
  179. assert_almost_equal(a, b)
  180. def test_assert_almost_equal_value_mismatch1():
  181. msg = """numpy array are different
  182. numpy array values are different \\(66\\.66667 %\\)
  183. \\[left\\]: \\[nan, 2\\.0, 3\\.0\\]
  184. \\[right\\]: \\[1\\.0, nan, 3\\.0\\]"""
  185. with pytest.raises(AssertionError, match=msg):
  186. assert_almost_equal(np.array([np.nan, 2, 3]),
  187. np.array([1, np.nan, 3]))
  188. def test_assert_almost_equal_value_mismatch2():
  189. msg = """numpy array are different
  190. numpy array values are different \\(50\\.0 %\\)
  191. \\[left\\]: \\[1, 2\\]
  192. \\[right\\]: \\[1, 3\\]"""
  193. with pytest.raises(AssertionError, match=msg):
  194. assert_almost_equal(np.array([1, 2]), np.array([1, 3]))
  195. def test_assert_almost_equal_value_mismatch3():
  196. msg = """numpy array are different
  197. numpy array values are different \\(16\\.66667 %\\)
  198. \\[left\\]: \\[\\[1, 2\\], \\[3, 4\\], \\[5, 6\\]\\]
  199. \\[right\\]: \\[\\[1, 3\\], \\[3, 4\\], \\[5, 6\\]\\]"""
  200. with pytest.raises(AssertionError, match=msg):
  201. assert_almost_equal(np.array([[1, 2], [3, 4], [5, 6]]),
  202. np.array([[1, 3], [3, 4], [5, 6]]))
  203. def test_assert_almost_equal_value_mismatch4():
  204. msg = """numpy array are different
  205. numpy array values are different \\(25\\.0 %\\)
  206. \\[left\\]: \\[\\[1, 2\\], \\[3, 4\\]\\]
  207. \\[right\\]: \\[\\[1, 3\\], \\[3, 4\\]\\]"""
  208. with pytest.raises(AssertionError, match=msg):
  209. assert_almost_equal(np.array([[1, 2], [3, 4]]),
  210. np.array([[1, 3], [3, 4]]))
  211. def test_assert_almost_equal_shape_mismatch_override():
  212. msg = """Index are different
  213. Index shapes are different
  214. \\[left\\]: \\(2L*,\\)
  215. \\[right\\]: \\(3L*,\\)"""
  216. with pytest.raises(AssertionError, match=msg):
  217. assert_almost_equal(np.array([1, 2]),
  218. np.array([3, 4, 5]),
  219. obj="Index")
  220. def test_assert_almost_equal_unicode():
  221. # see gh-20503
  222. msg = """numpy array are different
  223. numpy array values are different \\(33\\.33333 %\\)
  224. \\[left\\]: \\[á, à, ä\\]
  225. \\[right\\]: \\[á, à, å\\]"""
  226. with pytest.raises(AssertionError, match=msg):
  227. assert_almost_equal(np.array([u"á", u"à", u"ä"]),
  228. np.array([u"á", u"à", u"å"]))
  229. def test_assert_almost_equal_timestamp():
  230. a = np.array([Timestamp("2011-01-01"), Timestamp("2011-01-01")])
  231. b = np.array([Timestamp("2011-01-01"), Timestamp("2011-01-02")])
  232. msg = """numpy array are different
  233. numpy array values are different \\(50\\.0 %\\)
  234. \\[left\\]: \\[2011-01-01 00:00:00, 2011-01-01 00:00:00\\]
  235. \\[right\\]: \\[2011-01-01 00:00:00, 2011-01-02 00:00:00\\]"""
  236. with pytest.raises(AssertionError, match=msg):
  237. assert_almost_equal(a, b)
  238. def test_assert_almost_equal_iterable_length_mismatch():
  239. msg = """Iterable are different
  240. Iterable length are different
  241. \\[left\\]: 2
  242. \\[right\\]: 3"""
  243. with pytest.raises(AssertionError, match=msg):
  244. assert_almost_equal([1, 2], [3, 4, 5])
  245. def test_assert_almost_equal_iterable_values_mismatch():
  246. msg = """Iterable are different
  247. Iterable values are different \\(50\\.0 %\\)
  248. \\[left\\]: \\[1, 2\\]
  249. \\[right\\]: \\[1, 3\\]"""
  250. with pytest.raises(AssertionError, match=msg):
  251. assert_almost_equal([1, 2], [1, 3])