_misc.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640
  1. # being a bit too dynamic
  2. # pylint: disable=E1101
  3. from __future__ import division
  4. import numpy as np
  5. from pandas.compat import lmap, lrange, range, zip
  6. from pandas.util._decorators import deprecate_kwarg
  7. from pandas.core.dtypes.missing import notna
  8. from pandas.io.formats.printing import pprint_thing
  9. from pandas.plotting._style import _get_standard_colors
  10. from pandas.plotting._tools import _set_ticks_props, _subplots
  11. def scatter_matrix(frame, alpha=0.5, figsize=None, ax=None, grid=False,
  12. diagonal='hist', marker='.', density_kwds=None,
  13. hist_kwds=None, range_padding=0.05, **kwds):
  14. """
  15. Draw a matrix of scatter plots.
  16. Parameters
  17. ----------
  18. frame : DataFrame
  19. alpha : float, optional
  20. amount of transparency applied
  21. figsize : (float,float), optional
  22. a tuple (width, height) in inches
  23. ax : Matplotlib axis object, optional
  24. grid : bool, optional
  25. setting this to True will show the grid
  26. diagonal : {'hist', 'kde'}
  27. pick between 'kde' and 'hist' for
  28. either Kernel Density Estimation or Histogram
  29. plot in the diagonal
  30. marker : str, optional
  31. Matplotlib marker type, default '.'
  32. hist_kwds : other plotting keyword arguments
  33. To be passed to hist function
  34. density_kwds : other plotting keyword arguments
  35. To be passed to kernel density estimate plot
  36. range_padding : float, optional
  37. relative extension of axis range in x and y
  38. with respect to (x_max - x_min) or (y_max - y_min),
  39. default 0.05
  40. kwds : other plotting keyword arguments
  41. To be passed to scatter function
  42. Examples
  43. --------
  44. >>> df = pd.DataFrame(np.random.randn(1000, 4), columns=['A','B','C','D'])
  45. >>> scatter_matrix(df, alpha=0.2)
  46. """
  47. df = frame._get_numeric_data()
  48. n = df.columns.size
  49. naxes = n * n
  50. fig, axes = _subplots(naxes=naxes, figsize=figsize, ax=ax,
  51. squeeze=False)
  52. # no gaps between subplots
  53. fig.subplots_adjust(wspace=0, hspace=0)
  54. mask = notna(df)
  55. marker = _get_marker_compat(marker)
  56. hist_kwds = hist_kwds or {}
  57. density_kwds = density_kwds or {}
  58. # GH 14855
  59. kwds.setdefault('edgecolors', 'none')
  60. boundaries_list = []
  61. for a in df.columns:
  62. values = df[a].values[mask[a].values]
  63. rmin_, rmax_ = np.min(values), np.max(values)
  64. rdelta_ext = (rmax_ - rmin_) * range_padding / 2.
  65. boundaries_list.append((rmin_ - rdelta_ext, rmax_ + rdelta_ext))
  66. for i, a in zip(lrange(n), df.columns):
  67. for j, b in zip(lrange(n), df.columns):
  68. ax = axes[i, j]
  69. if i == j:
  70. values = df[a].values[mask[a].values]
  71. # Deal with the diagonal by drawing a histogram there.
  72. if diagonal == 'hist':
  73. ax.hist(values, **hist_kwds)
  74. elif diagonal in ('kde', 'density'):
  75. from scipy.stats import gaussian_kde
  76. y = values
  77. gkde = gaussian_kde(y)
  78. ind = np.linspace(y.min(), y.max(), 1000)
  79. ax.plot(ind, gkde.evaluate(ind), **density_kwds)
  80. ax.set_xlim(boundaries_list[i])
  81. else:
  82. common = (mask[a] & mask[b]).values
  83. ax.scatter(df[b][common], df[a][common],
  84. marker=marker, alpha=alpha, **kwds)
  85. ax.set_xlim(boundaries_list[j])
  86. ax.set_ylim(boundaries_list[i])
  87. ax.set_xlabel(b)
  88. ax.set_ylabel(a)
  89. if j != 0:
  90. ax.yaxis.set_visible(False)
  91. if i != n - 1:
  92. ax.xaxis.set_visible(False)
  93. if len(df.columns) > 1:
  94. lim1 = boundaries_list[0]
  95. locs = axes[0][1].yaxis.get_majorticklocs()
  96. locs = locs[(lim1[0] <= locs) & (locs <= lim1[1])]
  97. adj = (locs - lim1[0]) / (lim1[1] - lim1[0])
  98. lim0 = axes[0][0].get_ylim()
  99. adj = adj * (lim0[1] - lim0[0]) + lim0[0]
  100. axes[0][0].yaxis.set_ticks(adj)
  101. if np.all(locs == locs.astype(int)):
  102. # if all ticks are int
  103. locs = locs.astype(int)
  104. axes[0][0].yaxis.set_ticklabels(locs)
  105. _set_ticks_props(axes, xlabelsize=8, xrot=90, ylabelsize=8, yrot=0)
  106. return axes
  107. def _get_marker_compat(marker):
  108. import matplotlib.lines as mlines
  109. if marker not in mlines.lineMarkers:
  110. return 'o'
  111. return marker
  112. def radviz(frame, class_column, ax=None, color=None, colormap=None, **kwds):
  113. """
  114. Plot a multidimensional dataset in 2D.
  115. Each Series in the DataFrame is represented as a evenly distributed
  116. slice on a circle. Each data point is rendered in the circle according to
  117. the value on each Series. Highly correlated `Series` in the `DataFrame`
  118. are placed closer on the unit circle.
  119. RadViz allow to project a N-dimensional data set into a 2D space where the
  120. influence of each dimension can be interpreted as a balance between the
  121. influence of all dimensions.
  122. More info available at the `original article
  123. <http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.135.889>`_
  124. describing RadViz.
  125. Parameters
  126. ----------
  127. frame : `DataFrame`
  128. Pandas object holding the data.
  129. class_column : str
  130. Column name containing the name of the data point category.
  131. ax : :class:`matplotlib.axes.Axes`, optional
  132. A plot instance to which to add the information.
  133. color : list[str] or tuple[str], optional
  134. Assign a color to each category. Example: ['blue', 'green'].
  135. colormap : str or :class:`matplotlib.colors.Colormap`, default None
  136. Colormap to select colors from. If string, load colormap with that
  137. name from matplotlib.
  138. kwds : optional
  139. Options to pass to matplotlib scatter plotting method.
  140. Returns
  141. -------
  142. axes : :class:`matplotlib.axes.Axes`
  143. See Also
  144. --------
  145. pandas.plotting.andrews_curves : Plot clustering visualization.
  146. Examples
  147. --------
  148. .. plot::
  149. :context: close-figs
  150. >>> df = pd.DataFrame({
  151. ... 'SepalLength': [6.5, 7.7, 5.1, 5.8, 7.6, 5.0, 5.4, 4.6,
  152. ... 6.7, 4.6],
  153. ... 'SepalWidth': [3.0, 3.8, 3.8, 2.7, 3.0, 2.3, 3.0, 3.2,
  154. ... 3.3, 3.6],
  155. ... 'PetalLength': [5.5, 6.7, 1.9, 5.1, 6.6, 3.3, 4.5, 1.4,
  156. ... 5.7, 1.0],
  157. ... 'PetalWidth': [1.8, 2.2, 0.4, 1.9, 2.1, 1.0, 1.5, 0.2,
  158. ... 2.1, 0.2],
  159. ... 'Category': ['virginica', 'virginica', 'setosa',
  160. ... 'virginica', 'virginica', 'versicolor',
  161. ... 'versicolor', 'setosa', 'virginica',
  162. ... 'setosa']
  163. ... })
  164. >>> rad_viz = pd.plotting.radviz(df, 'Category') # doctest: +SKIP
  165. """
  166. import matplotlib.pyplot as plt
  167. import matplotlib.patches as patches
  168. def normalize(series):
  169. a = min(series)
  170. b = max(series)
  171. return (series - a) / (b - a)
  172. n = len(frame)
  173. classes = frame[class_column].drop_duplicates()
  174. class_col = frame[class_column]
  175. df = frame.drop(class_column, axis=1).apply(normalize)
  176. if ax is None:
  177. ax = plt.gca(xlim=[-1, 1], ylim=[-1, 1])
  178. to_plot = {}
  179. colors = _get_standard_colors(num_colors=len(classes), colormap=colormap,
  180. color_type='random', color=color)
  181. for kls in classes:
  182. to_plot[kls] = [[], []]
  183. m = len(frame.columns) - 1
  184. s = np.array([(np.cos(t), np.sin(t))
  185. for t in [2.0 * np.pi * (i / float(m))
  186. for i in range(m)]])
  187. for i in range(n):
  188. row = df.iloc[i].values
  189. row_ = np.repeat(np.expand_dims(row, axis=1), 2, axis=1)
  190. y = (s * row_).sum(axis=0) / row.sum()
  191. kls = class_col.iat[i]
  192. to_plot[kls][0].append(y[0])
  193. to_plot[kls][1].append(y[1])
  194. for i, kls in enumerate(classes):
  195. ax.scatter(to_plot[kls][0], to_plot[kls][1], color=colors[i],
  196. label=pprint_thing(kls), **kwds)
  197. ax.legend()
  198. ax.add_patch(patches.Circle((0.0, 0.0), radius=1.0, facecolor='none'))
  199. for xy, name in zip(s, df.columns):
  200. ax.add_patch(patches.Circle(xy, radius=0.025, facecolor='gray'))
  201. if xy[0] < 0.0 and xy[1] < 0.0:
  202. ax.text(xy[0] - 0.025, xy[1] - 0.025, name,
  203. ha='right', va='top', size='small')
  204. elif xy[0] < 0.0 and xy[1] >= 0.0:
  205. ax.text(xy[0] - 0.025, xy[1] + 0.025, name,
  206. ha='right', va='bottom', size='small')
  207. elif xy[0] >= 0.0 and xy[1] < 0.0:
  208. ax.text(xy[0] + 0.025, xy[1] - 0.025, name,
  209. ha='left', va='top', size='small')
  210. elif xy[0] >= 0.0 and xy[1] >= 0.0:
  211. ax.text(xy[0] + 0.025, xy[1] + 0.025, name,
  212. ha='left', va='bottom', size='small')
  213. ax.axis('equal')
  214. return ax
  215. @deprecate_kwarg(old_arg_name='data', new_arg_name='frame')
  216. def andrews_curves(frame, class_column, ax=None, samples=200, color=None,
  217. colormap=None, **kwds):
  218. """
  219. Generates a matplotlib plot of Andrews curves, for visualising clusters of
  220. multivariate data.
  221. Andrews curves have the functional form:
  222. f(t) = x_1/sqrt(2) + x_2 sin(t) + x_3 cos(t) +
  223. x_4 sin(2t) + x_5 cos(2t) + ...
  224. Where x coefficients correspond to the values of each dimension and t is
  225. linearly spaced between -pi and +pi. Each row of frame then corresponds to
  226. a single curve.
  227. Parameters
  228. ----------
  229. frame : DataFrame
  230. Data to be plotted, preferably normalized to (0.0, 1.0)
  231. class_column : Name of the column containing class names
  232. ax : matplotlib axes object, default None
  233. samples : Number of points to plot in each curve
  234. color : list or tuple, optional
  235. Colors to use for the different classes
  236. colormap : str or matplotlib colormap object, default None
  237. Colormap to select colors from. If string, load colormap with that name
  238. from matplotlib.
  239. kwds : keywords
  240. Options to pass to matplotlib plotting method
  241. Returns
  242. -------
  243. ax : Matplotlib axis object
  244. """
  245. from math import sqrt, pi
  246. import matplotlib.pyplot as plt
  247. def function(amplitudes):
  248. def f(t):
  249. x1 = amplitudes[0]
  250. result = x1 / sqrt(2.0)
  251. # Take the rest of the coefficients and resize them
  252. # appropriately. Take a copy of amplitudes as otherwise numpy
  253. # deletes the element from amplitudes itself.
  254. coeffs = np.delete(np.copy(amplitudes), 0)
  255. coeffs.resize(int((coeffs.size + 1) / 2), 2)
  256. # Generate the harmonics and arguments for the sin and cos
  257. # functions.
  258. harmonics = np.arange(0, coeffs.shape[0]) + 1
  259. trig_args = np.outer(harmonics, t)
  260. result += np.sum(coeffs[:, 0, np.newaxis] * np.sin(trig_args) +
  261. coeffs[:, 1, np.newaxis] * np.cos(trig_args),
  262. axis=0)
  263. return result
  264. return f
  265. n = len(frame)
  266. class_col = frame[class_column]
  267. classes = frame[class_column].drop_duplicates()
  268. df = frame.drop(class_column, axis=1)
  269. t = np.linspace(-pi, pi, samples)
  270. used_legends = set()
  271. color_values = _get_standard_colors(num_colors=len(classes),
  272. colormap=colormap, color_type='random',
  273. color=color)
  274. colors = dict(zip(classes, color_values))
  275. if ax is None:
  276. ax = plt.gca(xlim=(-pi, pi))
  277. for i in range(n):
  278. row = df.iloc[i].values
  279. f = function(row)
  280. y = f(t)
  281. kls = class_col.iat[i]
  282. label = pprint_thing(kls)
  283. if label not in used_legends:
  284. used_legends.add(label)
  285. ax.plot(t, y, color=colors[kls], label=label, **kwds)
  286. else:
  287. ax.plot(t, y, color=colors[kls], **kwds)
  288. ax.legend(loc='upper right')
  289. ax.grid()
  290. return ax
  291. def bootstrap_plot(series, fig=None, size=50, samples=500, **kwds):
  292. """
  293. Bootstrap plot on mean, median and mid-range statistics.
  294. The bootstrap plot is used to estimate the uncertainty of a statistic
  295. by relaying on random sampling with replacement [1]_. This function will
  296. generate bootstrapping plots for mean, median and mid-range statistics
  297. for the given number of samples of the given size.
  298. .. [1] "Bootstrapping (statistics)" in \
  299. https://en.wikipedia.org/wiki/Bootstrapping_%28statistics%29
  300. Parameters
  301. ----------
  302. series : pandas.Series
  303. Pandas Series from where to get the samplings for the bootstrapping.
  304. fig : matplotlib.figure.Figure, default None
  305. If given, it will use the `fig` reference for plotting instead of
  306. creating a new one with default parameters.
  307. size : int, default 50
  308. Number of data points to consider during each sampling. It must be
  309. greater or equal than the length of the `series`.
  310. samples : int, default 500
  311. Number of times the bootstrap procedure is performed.
  312. **kwds :
  313. Options to pass to matplotlib plotting method.
  314. Returns
  315. -------
  316. fig : matplotlib.figure.Figure
  317. Matplotlib figure
  318. See Also
  319. --------
  320. pandas.DataFrame.plot : Basic plotting for DataFrame objects.
  321. pandas.Series.plot : Basic plotting for Series objects.
  322. Examples
  323. --------
  324. .. plot::
  325. :context: close-figs
  326. >>> s = pd.Series(np.random.uniform(size=100))
  327. >>> fig = pd.plotting.bootstrap_plot(s) # doctest: +SKIP
  328. """
  329. import random
  330. import matplotlib.pyplot as plt
  331. # random.sample(ndarray, int) fails on python 3.3, sigh
  332. data = list(series.values)
  333. samplings = [random.sample(data, size) for _ in range(samples)]
  334. means = np.array([np.mean(sampling) for sampling in samplings])
  335. medians = np.array([np.median(sampling) for sampling in samplings])
  336. midranges = np.array([(min(sampling) + max(sampling)) * 0.5
  337. for sampling in samplings])
  338. if fig is None:
  339. fig = plt.figure()
  340. x = lrange(samples)
  341. axes = []
  342. ax1 = fig.add_subplot(2, 3, 1)
  343. ax1.set_xlabel("Sample")
  344. axes.append(ax1)
  345. ax1.plot(x, means, **kwds)
  346. ax2 = fig.add_subplot(2, 3, 2)
  347. ax2.set_xlabel("Sample")
  348. axes.append(ax2)
  349. ax2.plot(x, medians, **kwds)
  350. ax3 = fig.add_subplot(2, 3, 3)
  351. ax3.set_xlabel("Sample")
  352. axes.append(ax3)
  353. ax3.plot(x, midranges, **kwds)
  354. ax4 = fig.add_subplot(2, 3, 4)
  355. ax4.set_xlabel("Mean")
  356. axes.append(ax4)
  357. ax4.hist(means, **kwds)
  358. ax5 = fig.add_subplot(2, 3, 5)
  359. ax5.set_xlabel("Median")
  360. axes.append(ax5)
  361. ax5.hist(medians, **kwds)
  362. ax6 = fig.add_subplot(2, 3, 6)
  363. ax6.set_xlabel("Midrange")
  364. axes.append(ax6)
  365. ax6.hist(midranges, **kwds)
  366. for axis in axes:
  367. plt.setp(axis.get_xticklabels(), fontsize=8)
  368. plt.setp(axis.get_yticklabels(), fontsize=8)
  369. return fig
  370. @deprecate_kwarg(old_arg_name='colors', new_arg_name='color')
  371. @deprecate_kwarg(old_arg_name='data', new_arg_name='frame', stacklevel=3)
  372. def parallel_coordinates(frame, class_column, cols=None, ax=None, color=None,
  373. use_columns=False, xticks=None, colormap=None,
  374. axvlines=True, axvlines_kwds=None, sort_labels=False,
  375. **kwds):
  376. """Parallel coordinates plotting.
  377. Parameters
  378. ----------
  379. frame : DataFrame
  380. class_column : str
  381. Column name containing class names
  382. cols : list, optional
  383. A list of column names to use
  384. ax : matplotlib.axis, optional
  385. matplotlib axis object
  386. color : list or tuple, optional
  387. Colors to use for the different classes
  388. use_columns : bool, optional
  389. If true, columns will be used as xticks
  390. xticks : list or tuple, optional
  391. A list of values to use for xticks
  392. colormap : str or matplotlib colormap, default None
  393. Colormap to use for line colors.
  394. axvlines : bool, optional
  395. If true, vertical lines will be added at each xtick
  396. axvlines_kwds : keywords, optional
  397. Options to be passed to axvline method for vertical lines
  398. sort_labels : bool, False
  399. Sort class_column labels, useful when assigning colors
  400. .. versionadded:: 0.20.0
  401. kwds : keywords
  402. Options to pass to matplotlib plotting method
  403. Returns
  404. -------
  405. ax: matplotlib axis object
  406. Examples
  407. --------
  408. >>> from matplotlib import pyplot as plt
  409. >>> df = pd.read_csv('https://raw.github.com/pandas-dev/pandas/master'
  410. '/pandas/tests/data/iris.csv')
  411. >>> pd.plotting.parallel_coordinates(
  412. df, 'Name',
  413. color=('#556270', '#4ECDC4', '#C7F464'))
  414. >>> plt.show()
  415. """
  416. if axvlines_kwds is None:
  417. axvlines_kwds = {'linewidth': 1, 'color': 'black'}
  418. import matplotlib.pyplot as plt
  419. n = len(frame)
  420. classes = frame[class_column].drop_duplicates()
  421. class_col = frame[class_column]
  422. if cols is None:
  423. df = frame.drop(class_column, axis=1)
  424. else:
  425. df = frame[cols]
  426. used_legends = set()
  427. ncols = len(df.columns)
  428. # determine values to use for xticks
  429. if use_columns is True:
  430. if not np.all(np.isreal(list(df.columns))):
  431. raise ValueError('Columns must be numeric to be used as xticks')
  432. x = df.columns
  433. elif xticks is not None:
  434. if not np.all(np.isreal(xticks)):
  435. raise ValueError('xticks specified must be numeric')
  436. elif len(xticks) != ncols:
  437. raise ValueError('Length of xticks must match number of columns')
  438. x = xticks
  439. else:
  440. x = lrange(ncols)
  441. if ax is None:
  442. ax = plt.gca()
  443. color_values = _get_standard_colors(num_colors=len(classes),
  444. colormap=colormap, color_type='random',
  445. color=color)
  446. if sort_labels:
  447. classes = sorted(classes)
  448. color_values = sorted(color_values)
  449. colors = dict(zip(classes, color_values))
  450. for i in range(n):
  451. y = df.iloc[i].values
  452. kls = class_col.iat[i]
  453. label = pprint_thing(kls)
  454. if label not in used_legends:
  455. used_legends.add(label)
  456. ax.plot(x, y, color=colors[kls], label=label, **kwds)
  457. else:
  458. ax.plot(x, y, color=colors[kls], **kwds)
  459. if axvlines:
  460. for i in x:
  461. ax.axvline(i, **axvlines_kwds)
  462. ax.set_xticks(x)
  463. ax.set_xticklabels(df.columns)
  464. ax.set_xlim(x[0], x[-1])
  465. ax.legend(loc='upper right')
  466. ax.grid()
  467. return ax
  468. def lag_plot(series, lag=1, ax=None, **kwds):
  469. """Lag plot for time series.
  470. Parameters
  471. ----------
  472. series : Time series
  473. lag : lag of the scatter plot, default 1
  474. ax : Matplotlib axis object, optional
  475. kwds : Matplotlib scatter method keyword arguments, optional
  476. Returns
  477. -------
  478. ax: Matplotlib axis object
  479. """
  480. import matplotlib.pyplot as plt
  481. # workaround because `c='b'` is hardcoded in matplotlibs scatter method
  482. kwds.setdefault('c', plt.rcParams['patch.facecolor'])
  483. data = series.values
  484. y1 = data[:-lag]
  485. y2 = data[lag:]
  486. if ax is None:
  487. ax = plt.gca()
  488. ax.set_xlabel("y(t)")
  489. ax.set_ylabel("y(t + {lag})".format(lag=lag))
  490. ax.scatter(y1, y2, **kwds)
  491. return ax
  492. def autocorrelation_plot(series, ax=None, **kwds):
  493. """Autocorrelation plot for time series.
  494. Parameters:
  495. -----------
  496. series: Time series
  497. ax: Matplotlib axis object, optional
  498. kwds : keywords
  499. Options to pass to matplotlib plotting method
  500. Returns:
  501. -----------
  502. ax: Matplotlib axis object
  503. """
  504. import matplotlib.pyplot as plt
  505. n = len(series)
  506. data = np.asarray(series)
  507. if ax is None:
  508. ax = plt.gca(xlim=(1, n), ylim=(-1.0, 1.0))
  509. mean = np.mean(data)
  510. c0 = np.sum((data - mean) ** 2) / float(n)
  511. def r(h):
  512. return ((data[:n - h] - mean) *
  513. (data[h:] - mean)).sum() / float(n) / c0
  514. x = np.arange(n) + 1
  515. y = lmap(r, x)
  516. z95 = 1.959963984540054
  517. z99 = 2.5758293035489004
  518. ax.axhline(y=z99 / np.sqrt(n), linestyle='--', color='grey')
  519. ax.axhline(y=z95 / np.sqrt(n), color='grey')
  520. ax.axhline(y=0.0, color='black')
  521. ax.axhline(y=-z95 / np.sqrt(n), color='grey')
  522. ax.axhline(y=-z99 / np.sqrt(n), linestyle='--', color='grey')
  523. ax.set_xlabel("Lag")
  524. ax.set_ylabel("Autocorrelation")
  525. ax.plot(x, y, **kwds)
  526. if 'label' in kwds:
  527. ax.legend()
  528. ax.grid()
  529. return ax