common.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544
  1. #!/usr/bin/env python
  2. # coding: utf-8
  3. import os
  4. import warnings
  5. import numpy as np
  6. from numpy import random
  7. import pytest
  8. from pandas.compat import iteritems, zip
  9. from pandas.util._decorators import cache_readonly
  10. import pandas.util._test_decorators as td
  11. from pandas.core.dtypes.api import is_list_like
  12. from pandas import DataFrame, Series
  13. import pandas.util.testing as tm
  14. from pandas.util.testing import (
  15. assert_is_valid_plot_return_object, ensure_clean)
  16. import pandas.plotting as plotting
  17. from pandas.plotting._tools import _flatten
# NOTE: this string comes after the imports, so Python does not treat it as
# the module docstring (``__doc__``); it is evaluated as a no-op expression.
"""
This is a common base class used for various plotting tests
"""
  21. def _skip_if_no_scipy_gaussian_kde():
  22. try:
  23. from scipy.stats import gaussian_kde # noqa
  24. except ImportError:
  25. pytest.skip("scipy version doesn't support gaussian_kde")
  26. def _ok_for_gaussian_kde(kind):
  27. if kind in ['kde', 'density']:
  28. try:
  29. from scipy.stats import gaussian_kde # noqa
  30. except ImportError:
  31. return False
  32. return True
  33. @td.skip_if_no_mpl
  34. class TestPlotBase(object):
  35. def setup_method(self, method):
  36. import matplotlib as mpl
  37. mpl.rcdefaults()
  38. self.mpl_ge_2_0_1 = plotting._compat._mpl_ge_2_0_1()
  39. self.mpl_ge_2_1_0 = plotting._compat._mpl_ge_2_1_0()
  40. self.mpl_ge_2_2_0 = plotting._compat._mpl_ge_2_2_0()
  41. self.mpl_ge_2_2_2 = plotting._compat._mpl_ge_2_2_2()
  42. self.mpl_ge_3_0_0 = plotting._compat._mpl_ge_3_0_0()
  43. self.bp_n_objects = 7
  44. self.polycollection_factor = 2
  45. self.default_figsize = (6.4, 4.8)
  46. self.default_tick_position = 'left'
  47. n = 100
  48. with tm.RNGContext(42):
  49. gender = np.random.choice(['Male', 'Female'], size=n)
  50. classroom = np.random.choice(['A', 'B', 'C'], size=n)
  51. self.hist_df = DataFrame({'gender': gender,
  52. 'classroom': classroom,
  53. 'height': random.normal(66, 4, size=n),
  54. 'weight': random.normal(161, 32, size=n),
  55. 'category': random.randint(4, size=n)})
  56. self.tdf = tm.makeTimeDataFrame()
  57. self.hexbin_df = DataFrame({"A": np.random.uniform(size=20),
  58. "B": np.random.uniform(size=20),
  59. "C": np.arange(20) + np.random.uniform(
  60. size=20)})
  61. def teardown_method(self, method):
  62. tm.close()
  63. @cache_readonly
  64. def plt(self):
  65. import matplotlib.pyplot as plt
  66. return plt
  67. @cache_readonly
  68. def colorconverter(self):
  69. import matplotlib.colors as colors
  70. return colors.colorConverter
  71. def _check_legend_labels(self, axes, labels=None, visible=True):
  72. """
  73. Check each axes has expected legend labels
  74. Parameters
  75. ----------
  76. axes : matplotlib Axes object, or its list-like
  77. labels : list-like
  78. expected legend labels
  79. visible : bool
  80. expected legend visibility. labels are checked only when visible is
  81. True
  82. """
  83. if visible and (labels is None):
  84. raise ValueError('labels must be specified when visible is True')
  85. axes = self._flatten_visible(axes)
  86. for ax in axes:
  87. if visible:
  88. assert ax.get_legend() is not None
  89. self._check_text_labels(ax.get_legend().get_texts(), labels)
  90. else:
  91. assert ax.get_legend() is None
  92. def _check_data(self, xp, rs):
  93. """
  94. Check each axes has identical lines
  95. Parameters
  96. ----------
  97. xp : matplotlib Axes object
  98. rs : matplotlib Axes object
  99. """
  100. xp_lines = xp.get_lines()
  101. rs_lines = rs.get_lines()
  102. def check_line(xpl, rsl):
  103. xpdata = xpl.get_xydata()
  104. rsdata = rsl.get_xydata()
  105. tm.assert_almost_equal(xpdata, rsdata)
  106. assert len(xp_lines) == len(rs_lines)
  107. [check_line(xpl, rsl) for xpl, rsl in zip(xp_lines, rs_lines)]
  108. tm.close()
  109. def _check_visible(self, collections, visible=True):
  110. """
  111. Check each artist is visible or not
  112. Parameters
  113. ----------
  114. collections : matplotlib Artist or its list-like
  115. target Artist or its list or collection
  116. visible : bool
  117. expected visibility
  118. """
  119. from matplotlib.collections import Collection
  120. if not isinstance(collections,
  121. Collection) and not is_list_like(collections):
  122. collections = [collections]
  123. for patch in collections:
  124. assert patch.get_visible() == visible
  125. def _get_colors_mapped(self, series, colors):
  126. unique = series.unique()
  127. # unique and colors length can be differed
  128. # depending on slice value
  129. mapped = dict(zip(unique, colors))
  130. return [mapped[v] for v in series.values]
  131. def _check_colors(self, collections, linecolors=None, facecolors=None,
  132. mapping=None):
  133. """
  134. Check each artist has expected line colors and face colors
  135. Parameters
  136. ----------
  137. collections : list-like
  138. list or collection of target artist
  139. linecolors : list-like which has the same length as collections
  140. list of expected line colors
  141. facecolors : list-like which has the same length as collections
  142. list of expected face colors
  143. mapping : Series
  144. Series used for color grouping key
  145. used for andrew_curves, parallel_coordinates, radviz test
  146. """
  147. from matplotlib.lines import Line2D
  148. from matplotlib.collections import (
  149. Collection, PolyCollection, LineCollection
  150. )
  151. conv = self.colorconverter
  152. if linecolors is not None:
  153. if mapping is not None:
  154. linecolors = self._get_colors_mapped(mapping, linecolors)
  155. linecolors = linecolors[:len(collections)]
  156. assert len(collections) == len(linecolors)
  157. for patch, color in zip(collections, linecolors):
  158. if isinstance(patch, Line2D):
  159. result = patch.get_color()
  160. # Line2D may contains string color expression
  161. result = conv.to_rgba(result)
  162. elif isinstance(patch, (PolyCollection, LineCollection)):
  163. result = tuple(patch.get_edgecolor()[0])
  164. else:
  165. result = patch.get_edgecolor()
  166. expected = conv.to_rgba(color)
  167. assert result == expected
  168. if facecolors is not None:
  169. if mapping is not None:
  170. facecolors = self._get_colors_mapped(mapping, facecolors)
  171. facecolors = facecolors[:len(collections)]
  172. assert len(collections) == len(facecolors)
  173. for patch, color in zip(collections, facecolors):
  174. if isinstance(patch, Collection):
  175. # returned as list of np.array
  176. result = patch.get_facecolor()[0]
  177. else:
  178. result = patch.get_facecolor()
  179. if isinstance(result, np.ndarray):
  180. result = tuple(result)
  181. expected = conv.to_rgba(color)
  182. assert result == expected
  183. def _check_text_labels(self, texts, expected):
  184. """
  185. Check each text has expected labels
  186. Parameters
  187. ----------
  188. texts : matplotlib Text object, or its list-like
  189. target text, or its list
  190. expected : str or list-like which has the same length as texts
  191. expected text label, or its list
  192. """
  193. if not is_list_like(texts):
  194. assert texts.get_text() == expected
  195. else:
  196. labels = [t.get_text() for t in texts]
  197. assert len(labels) == len(expected)
  198. for label, e in zip(labels, expected):
  199. assert label == e
  200. def _check_ticks_props(self, axes, xlabelsize=None, xrot=None,
  201. ylabelsize=None, yrot=None):
  202. """
  203. Check each axes has expected tick properties
  204. Parameters
  205. ----------
  206. axes : matplotlib Axes object, or its list-like
  207. xlabelsize : number
  208. expected xticks font size
  209. xrot : number
  210. expected xticks rotation
  211. ylabelsize : number
  212. expected yticks font size
  213. yrot : number
  214. expected yticks rotation
  215. """
  216. from matplotlib.ticker import NullFormatter
  217. axes = self._flatten_visible(axes)
  218. for ax in axes:
  219. if xlabelsize or xrot:
  220. if isinstance(ax.xaxis.get_minor_formatter(), NullFormatter):
  221. # If minor ticks has NullFormatter, rot / fontsize are not
  222. # retained
  223. labels = ax.get_xticklabels()
  224. else:
  225. labels = ax.get_xticklabels() + ax.get_xticklabels(
  226. minor=True)
  227. for label in labels:
  228. if xlabelsize is not None:
  229. tm.assert_almost_equal(label.get_fontsize(),
  230. xlabelsize)
  231. if xrot is not None:
  232. tm.assert_almost_equal(label.get_rotation(), xrot)
  233. if ylabelsize or yrot:
  234. if isinstance(ax.yaxis.get_minor_formatter(), NullFormatter):
  235. labels = ax.get_yticklabels()
  236. else:
  237. labels = ax.get_yticklabels() + ax.get_yticklabels(
  238. minor=True)
  239. for label in labels:
  240. if ylabelsize is not None:
  241. tm.assert_almost_equal(label.get_fontsize(),
  242. ylabelsize)
  243. if yrot is not None:
  244. tm.assert_almost_equal(label.get_rotation(), yrot)
  245. def _check_ax_scales(self, axes, xaxis='linear', yaxis='linear'):
  246. """
  247. Check each axes has expected scales
  248. Parameters
  249. ----------
  250. axes : matplotlib Axes object, or its list-like
  251. xaxis : {'linear', 'log'}
  252. expected xaxis scale
  253. yaxis : {'linear', 'log'}
  254. expected yaxis scale
  255. """
  256. axes = self._flatten_visible(axes)
  257. for ax in axes:
  258. assert ax.xaxis.get_scale() == xaxis
  259. assert ax.yaxis.get_scale() == yaxis
  260. def _check_axes_shape(self, axes, axes_num=None, layout=None,
  261. figsize=None):
  262. """
  263. Check expected number of axes is drawn in expected layout
  264. Parameters
  265. ----------
  266. axes : matplotlib Axes object, or its list-like
  267. axes_num : number
  268. expected number of axes. Unnecessary axes should be set to
  269. invisible.
  270. layout : tuple
  271. expected layout, (expected number of rows , columns)
  272. figsize : tuple
  273. expected figsize. default is matplotlib default
  274. """
  275. if figsize is None:
  276. figsize = self.default_figsize
  277. visible_axes = self._flatten_visible(axes)
  278. if axes_num is not None:
  279. assert len(visible_axes) == axes_num
  280. for ax in visible_axes:
  281. # check something drawn on visible axes
  282. assert len(ax.get_children()) > 0
  283. if layout is not None:
  284. result = self._get_axes_layout(_flatten(axes))
  285. assert result == layout
  286. tm.assert_numpy_array_equal(
  287. visible_axes[0].figure.get_size_inches(),
  288. np.array(figsize, dtype=np.float64))
  289. def _get_axes_layout(self, axes):
  290. x_set = set()
  291. y_set = set()
  292. for ax in axes:
  293. # check axes coordinates to estimate layout
  294. points = ax.get_position().get_points()
  295. x_set.add(points[0][0])
  296. y_set.add(points[0][1])
  297. return (len(y_set), len(x_set))
  298. def _flatten_visible(self, axes):
  299. """
  300. Flatten axes, and filter only visible
  301. Parameters
  302. ----------
  303. axes : matplotlib Axes object, or its list-like
  304. """
  305. axes = _flatten(axes)
  306. axes = [ax for ax in axes if ax.get_visible()]
  307. return axes
  308. def _check_has_errorbars(self, axes, xerr=0, yerr=0):
  309. """
  310. Check axes has expected number of errorbars
  311. Parameters
  312. ----------
  313. axes : matplotlib Axes object, or its list-like
  314. xerr : number
  315. expected number of x errorbar
  316. yerr : number
  317. expected number of y errorbar
  318. """
  319. axes = self._flatten_visible(axes)
  320. for ax in axes:
  321. containers = ax.containers
  322. xerr_count = 0
  323. yerr_count = 0
  324. for c in containers:
  325. has_xerr = getattr(c, 'has_xerr', False)
  326. has_yerr = getattr(c, 'has_yerr', False)
  327. if has_xerr:
  328. xerr_count += 1
  329. if has_yerr:
  330. yerr_count += 1
  331. assert xerr == xerr_count
  332. assert yerr == yerr_count
  333. def _check_box_return_type(self, returned, return_type, expected_keys=None,
  334. check_ax_title=True):
  335. """
  336. Check box returned type is correct
  337. Parameters
  338. ----------
  339. returned : object to be tested, returned from boxplot
  340. return_type : str
  341. return_type passed to boxplot
  342. expected_keys : list-like, optional
  343. group labels in subplot case. If not passed,
  344. the function checks assuming boxplot uses single ax
  345. check_ax_title : bool
  346. Whether to check the ax.title is the same as expected_key
  347. Intended to be checked by calling from ``boxplot``.
  348. Normal ``plot`` doesn't attach ``ax.title``, it must be disabled.
  349. """
  350. from matplotlib.axes import Axes
  351. types = {'dict': dict, 'axes': Axes, 'both': tuple}
  352. if expected_keys is None:
  353. # should be fixed when the returning default is changed
  354. if return_type is None:
  355. return_type = 'dict'
  356. assert isinstance(returned, types[return_type])
  357. if return_type == 'both':
  358. assert isinstance(returned.ax, Axes)
  359. assert isinstance(returned.lines, dict)
  360. else:
  361. # should be fixed when the returning default is changed
  362. if return_type is None:
  363. for r in self._flatten_visible(returned):
  364. assert isinstance(r, Axes)
  365. return
  366. assert isinstance(returned, Series)
  367. assert sorted(returned.keys()) == sorted(expected_keys)
  368. for key, value in iteritems(returned):
  369. assert isinstance(value, types[return_type])
  370. # check returned dict has correct mapping
  371. if return_type == 'axes':
  372. if check_ax_title:
  373. assert value.get_title() == key
  374. elif return_type == 'both':
  375. if check_ax_title:
  376. assert value.ax.get_title() == key
  377. assert isinstance(value.ax, Axes)
  378. assert isinstance(value.lines, dict)
  379. elif return_type == 'dict':
  380. line = value['medians'][0]
  381. axes = line.axes
  382. if check_ax_title:
  383. assert axes.get_title() == key
  384. else:
  385. raise AssertionError
  386. def _check_grid_settings(self, obj, kinds, kws={}):
  387. # Make sure plot defaults to rcParams['axes.grid'] setting, GH 9792
  388. import matplotlib as mpl
  389. def is_grid_on():
  390. xoff = all(not g.gridOn
  391. for g in self.plt.gca().xaxis.get_major_ticks())
  392. yoff = all(not g.gridOn
  393. for g in self.plt.gca().yaxis.get_major_ticks())
  394. return not (xoff and yoff)
  395. spndx = 1
  396. for kind in kinds:
  397. if not _ok_for_gaussian_kde(kind):
  398. continue
  399. self.plt.subplot(1, 4 * len(kinds), spndx)
  400. spndx += 1
  401. mpl.rc('axes', grid=False)
  402. obj.plot(kind=kind, **kws)
  403. assert not is_grid_on()
  404. self.plt.subplot(1, 4 * len(kinds), spndx)
  405. spndx += 1
  406. mpl.rc('axes', grid=True)
  407. obj.plot(kind=kind, grid=False, **kws)
  408. assert not is_grid_on()
  409. if kind != 'pie':
  410. self.plt.subplot(1, 4 * len(kinds), spndx)
  411. spndx += 1
  412. mpl.rc('axes', grid=True)
  413. obj.plot(kind=kind, **kws)
  414. assert is_grid_on()
  415. self.plt.subplot(1, 4 * len(kinds), spndx)
  416. spndx += 1
  417. mpl.rc('axes', grid=False)
  418. obj.plot(kind=kind, grid=True, **kws)
  419. assert is_grid_on()
  420. def _unpack_cycler(self, rcParams, field='color'):
  421. """
  422. Auxiliary function for correctly unpacking cycler after MPL >= 1.5
  423. """
  424. return [v[field] for v in rcParams['axes.prop_cycle']]
  425. def _check_plot_works(f, filterwarnings='always', **kwargs):
  426. import matplotlib.pyplot as plt
  427. ret = None
  428. with warnings.catch_warnings():
  429. warnings.simplefilter(filterwarnings)
  430. try:
  431. try:
  432. fig = kwargs['figure']
  433. except KeyError:
  434. fig = plt.gcf()
  435. plt.clf()
  436. ax = kwargs.get('ax', fig.add_subplot(211)) # noqa
  437. ret = f(**kwargs)
  438. assert_is_valid_plot_return_object(ret)
  439. try:
  440. kwargs['ax'] = fig.add_subplot(212)
  441. ret = f(**kwargs)
  442. except Exception:
  443. pass
  444. else:
  445. assert_is_valid_plot_return_object(ret)
  446. with ensure_clean(return_filelike=True) as path:
  447. plt.savefig(path)
  448. finally:
  449. tm.close(fig)
  450. return ret
  451. def curpath():
  452. pth, _ = os.path.split(os.path.abspath(__file__))
  453. return pth