test_groupby.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. # coding: utf-8
  2. """ Test cases for GroupBy.plot """
  3. import numpy as np
  4. import pandas.util._test_decorators as td
  5. from pandas import DataFrame, Series
  6. from pandas.tests.plotting.common import TestPlotBase
  7. import pandas.util.testing as tm
  8. @td.skip_if_no_mpl
  9. class TestDataFrameGroupByPlots(TestPlotBase):
  10. def test_series_groupby_plotting_nominally_works(self):
  11. n = 10
  12. weight = Series(np.random.normal(166, 20, size=n))
  13. height = Series(np.random.normal(60, 10, size=n))
  14. with tm.RNGContext(42):
  15. gender = np.random.choice(['male', 'female'], size=n)
  16. weight.groupby(gender).plot()
  17. tm.close()
  18. height.groupby(gender).hist()
  19. tm.close()
  20. # Regression test for GH8733
  21. height.groupby(gender).plot(alpha=0.5)
  22. tm.close()
  23. def test_plotting_with_float_index_works(self):
  24. # GH 7025
  25. df = DataFrame({'def': [1, 1, 1, 2, 2, 2, 3, 3, 3],
  26. 'val': np.random.randn(9)},
  27. index=[1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0])
  28. df.groupby('def')['val'].plot()
  29. tm.close()
  30. df.groupby('def')['val'].apply(lambda x: x.plot())
  31. tm.close()
  32. def test_hist_single_row(self):
  33. # GH10214
  34. bins = np.arange(80, 100 + 2, 1)
  35. df = DataFrame({"Name": ["AAA", "BBB"],
  36. "ByCol": [1, 2],
  37. "Mark": [85, 89]})
  38. df["Mark"].hist(by=df["ByCol"], bins=bins)
  39. df = DataFrame({"Name": ["AAA"], "ByCol": [1], "Mark": [85]})
  40. df["Mark"].hist(by=df["ByCol"], bins=bins)
  41. def test_plot_submethod_works(self):
  42. df = DataFrame({'x': [1, 2, 3, 4, 5],
  43. 'y': [1, 2, 3, 2, 1],
  44. 'z': list('ababa')})
  45. df.groupby('z').plot.scatter('x', 'y')
  46. tm.close()
  47. df.groupby('z')['x'].plot.line()
  48. tm.close()
  49. def test_plot_kwargs(self):
  50. df = DataFrame({'x': [1, 2, 3, 4, 5],
  51. 'y': [1, 2, 3, 2, 1],
  52. 'z': list('ababa')})
  53. res = df.groupby('z').plot(kind='scatter', x='x', y='y')
  54. # check that a scatter plot is effectively plotted: the axes should
  55. # contain a PathCollection from the scatter plot (GH11805)
  56. assert len(res['a'].collections) == 1
  57. res = df.groupby('z').plot.scatter(x='x', y='y')
  58. assert len(res['a'].collections) == 1