common.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307
  1. """ common utilities """
  2. import itertools
  3. from warnings import catch_warnings, filterwarnings
  4. import numpy as np
  5. import pytest
  6. from pandas.compat import lrange
  7. from pandas.core.dtypes.common import is_scalar
  8. from pandas import (
  9. DataFrame, Float64Index, MultiIndex, Panel, Series, UInt64Index,
  10. date_range)
  11. from pandas.util import testing as tm
  12. from pandas.io.formats.printing import pprint_thing
  13. _verbose = False
  14. def _mklbl(prefix, n):
  15. return ["%s%s" % (prefix, i) for i in range(n)]
  16. def _axify(obj, key, axis):
  17. # create a tuple accessor
  18. axes = [slice(None)] * obj.ndim
  19. axes[axis] = key
  20. return tuple(axes)
  21. @pytest.mark.filterwarnings("ignore:\\nPanel:FutureWarning")
  22. class Base(object):
  23. """ indexing comprehensive base class """
  24. _objs = {'series', 'frame', 'panel'}
  25. _typs = {'ints', 'uints', 'labels', 'mixed', 'ts', 'floats', 'empty',
  26. 'ts_rev', 'multi'}
  27. def setup_method(self, method):
  28. self.series_ints = Series(np.random.rand(4), index=lrange(0, 8, 2))
  29. self.frame_ints = DataFrame(np.random.randn(4, 4),
  30. index=lrange(0, 8, 2),
  31. columns=lrange(0, 12, 3))
  32. with catch_warnings(record=True):
  33. self.panel_ints = Panel(np.random.rand(4, 4, 4),
  34. items=lrange(0, 8, 2),
  35. major_axis=lrange(0, 12, 3),
  36. minor_axis=lrange(0, 16, 4))
  37. self.series_uints = Series(np.random.rand(4),
  38. index=UInt64Index(lrange(0, 8, 2)))
  39. self.frame_uints = DataFrame(np.random.randn(4, 4),
  40. index=UInt64Index(lrange(0, 8, 2)),
  41. columns=UInt64Index(lrange(0, 12, 3)))
  42. self.panel_uints = Panel(np.random.rand(4, 4, 4),
  43. items=UInt64Index(lrange(0, 8, 2)),
  44. major_axis=UInt64Index(lrange(0, 12, 3)),
  45. minor_axis=UInt64Index(lrange(0, 16, 4)))
  46. self.series_floats = Series(np.random.rand(4),
  47. index=Float64Index(range(0, 8, 2)))
  48. self.frame_floats = DataFrame(np.random.randn(4, 4),
  49. index=Float64Index(range(0, 8, 2)),
  50. columns=Float64Index(range(0, 12, 3)))
  51. self.panel_floats = Panel(np.random.rand(4, 4, 4),
  52. items=Float64Index(range(0, 8, 2)),
  53. major_axis=Float64Index(range(0, 12, 3)),
  54. minor_axis=Float64Index(range(0, 16, 4)))
  55. m_idces = [MultiIndex.from_product([[1, 2], [3, 4]]),
  56. MultiIndex.from_product([[5, 6], [7, 8]]),
  57. MultiIndex.from_product([[9, 10], [11, 12]])]
  58. self.series_multi = Series(np.random.rand(4),
  59. index=m_idces[0])
  60. self.frame_multi = DataFrame(np.random.randn(4, 4),
  61. index=m_idces[0],
  62. columns=m_idces[1])
  63. self.panel_multi = Panel(np.random.rand(4, 4, 4),
  64. items=m_idces[0],
  65. major_axis=m_idces[1],
  66. minor_axis=m_idces[2])
  67. self.series_labels = Series(np.random.randn(4), index=list('abcd'))
  68. self.frame_labels = DataFrame(np.random.randn(4, 4),
  69. index=list('abcd'), columns=list('ABCD'))
  70. self.panel_labels = Panel(np.random.randn(4, 4, 4),
  71. items=list('abcd'),
  72. major_axis=list('ABCD'),
  73. minor_axis=list('ZYXW'))
  74. self.series_mixed = Series(np.random.randn(4), index=[2, 4, 'null', 8])
  75. self.frame_mixed = DataFrame(np.random.randn(4, 4),
  76. index=[2, 4, 'null', 8])
  77. self.panel_mixed = Panel(np.random.randn(4, 4, 4),
  78. items=[2, 4, 'null', 8])
  79. self.series_ts = Series(np.random.randn(4),
  80. index=date_range('20130101', periods=4))
  81. self.frame_ts = DataFrame(np.random.randn(4, 4),
  82. index=date_range('20130101', periods=4))
  83. self.panel_ts = Panel(np.random.randn(4, 4, 4),
  84. items=date_range('20130101', periods=4))
  85. dates_rev = (date_range('20130101', periods=4)
  86. .sort_values(ascending=False))
  87. self.series_ts_rev = Series(np.random.randn(4),
  88. index=dates_rev)
  89. self.frame_ts_rev = DataFrame(np.random.randn(4, 4),
  90. index=dates_rev)
  91. self.panel_ts_rev = Panel(np.random.randn(4, 4, 4),
  92. items=dates_rev)
  93. self.frame_empty = DataFrame({})
  94. self.series_empty = Series({})
  95. self.panel_empty = Panel({})
  96. # form agglomerates
  97. for o in self._objs:
  98. d = dict()
  99. for t in self._typs:
  100. d[t] = getattr(self, '%s_%s' % (o, t), None)
  101. setattr(self, o, d)
  102. def generate_indices(self, f, values=False):
  103. """ generate the indices
  104. if values is True , use the axis values
  105. is False, use the range
  106. """
  107. axes = f.axes
  108. if values:
  109. axes = [lrange(len(a)) for a in axes]
  110. return itertools.product(*axes)
  111. def get_result(self, obj, method, key, axis):
  112. """ return the result for this obj with this key and this axis """
  113. if isinstance(key, dict):
  114. key = key[axis]
  115. # use an artificial conversion to map the key as integers to the labels
  116. # so ix can work for comparisons
  117. if method == 'indexer':
  118. method = 'ix'
  119. key = obj._get_axis(axis)[key]
  120. # in case we actually want 0 index slicing
  121. with catch_warnings(record=True):
  122. try:
  123. xp = getattr(obj, method).__getitem__(_axify(obj, key, axis))
  124. except AttributeError:
  125. xp = getattr(obj, method).__getitem__(key)
  126. return xp
  127. def get_value(self, f, i, values=False):
  128. """ return the value for the location i """
  129. # check against values
  130. if values:
  131. return f.values[i]
  132. # this is equiv of f[col][row].....
  133. # v = f
  134. # for a in reversed(i):
  135. # v = v.__getitem__(a)
  136. # return v
  137. with catch_warnings(record=True):
  138. filterwarnings("ignore", "\\n.ix", DeprecationWarning)
  139. return f.ix[i]
  140. def check_values(self, f, func, values=False):
  141. if f is None:
  142. return
  143. axes = f.axes
  144. indicies = itertools.product(*axes)
  145. for i in indicies:
  146. result = getattr(f, func)[i]
  147. # check against values
  148. if values:
  149. expected = f.values[i]
  150. else:
  151. expected = f
  152. for a in reversed(i):
  153. expected = expected.__getitem__(a)
  154. tm.assert_almost_equal(result, expected)
  155. def check_result(self, name, method1, key1, method2, key2, typs=None,
  156. objs=None, axes=None, fails=None):
  157. def _eq(t, o, a, obj, k1, k2):
  158. """ compare equal for these 2 keys """
  159. if a is not None and a > obj.ndim - 1:
  160. return
  161. def _print(result, error=None):
  162. if error is not None:
  163. error = str(error)
  164. v = ("%-16.16s [%-16.16s]: [typ->%-8.8s,obj->%-8.8s,"
  165. "key1->(%-4.4s),key2->(%-4.4s),axis->%s] %s" %
  166. (name, result, t, o, method1, method2, a, error or ''))
  167. if _verbose:
  168. pprint_thing(v)
  169. try:
  170. rs = getattr(obj, method1).__getitem__(_axify(obj, k1, a))
  171. try:
  172. xp = self.get_result(obj, method2, k2, a)
  173. except Exception:
  174. result = 'no comp'
  175. _print(result)
  176. return
  177. detail = None
  178. try:
  179. if is_scalar(rs) and is_scalar(xp):
  180. assert rs == xp
  181. elif xp.ndim == 1:
  182. tm.assert_series_equal(rs, xp)
  183. elif xp.ndim == 2:
  184. tm.assert_frame_equal(rs, xp)
  185. elif xp.ndim == 3:
  186. tm.assert_panel_equal(rs, xp)
  187. result = 'ok'
  188. except AssertionError as e:
  189. detail = str(e)
  190. result = 'fail'
  191. # reverse the checks
  192. if fails is True:
  193. if result == 'fail':
  194. result = 'ok (fail)'
  195. _print(result)
  196. if not result.startswith('ok'):
  197. raise AssertionError(detail)
  198. except AssertionError:
  199. raise
  200. except Exception as detail:
  201. # if we are in fails, the ok, otherwise raise it
  202. if fails is not None:
  203. if isinstance(detail, fails):
  204. result = 'ok (%s)' % type(detail).__name__
  205. _print(result)
  206. return
  207. result = type(detail).__name__
  208. raise AssertionError(_print(result, error=detail))
  209. if typs is None:
  210. typs = self._typs
  211. if objs is None:
  212. objs = self._objs
  213. if axes is not None:
  214. if not isinstance(axes, (tuple, list)):
  215. axes = [axes]
  216. else:
  217. axes = list(axes)
  218. else:
  219. axes = [0, 1, 2]
  220. # check
  221. for o in objs:
  222. if o not in self._objs:
  223. continue
  224. d = getattr(self, o)
  225. for a in axes:
  226. for t in typs:
  227. if t not in self._typs:
  228. continue
  229. obj = d[t]
  230. if obj is None:
  231. continue
  232. def _call(obj=obj):
  233. obj = obj.copy()
  234. k2 = key2
  235. _eq(t, o, a, obj, key1, k2)
  236. # Panel deprecations
  237. if isinstance(obj, Panel):
  238. with catch_warnings():
  239. filterwarnings("ignore", "\nPanel*", FutureWarning)
  240. _call()
  241. else:
  242. _call()