_doctools.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. import numpy as np
  2. import pandas.compat as compat
  3. import pandas as pd
  4. class TablePlotter(object):
  5. """
  6. Layout some DataFrames in vertical/horizontal layout for explanation.
  7. Used in merging.rst
  8. """
  9. def __init__(self, cell_width=0.37, cell_height=0.25, font_size=7.5):
  10. self.cell_width = cell_width
  11. self.cell_height = cell_height
  12. self.font_size = font_size
  13. def _shape(self, df):
  14. """
  15. Calculate table chape considering index levels.
  16. """
  17. row, col = df.shape
  18. return row + df.columns.nlevels, col + df.index.nlevels
  19. def _get_cells(self, left, right, vertical):
  20. """
  21. Calculate appropriate figure size based on left and right data.
  22. """
  23. if vertical:
  24. # calculate required number of cells
  25. vcells = max(sum(self._shape(l)[0] for l in left),
  26. self._shape(right)[0])
  27. hcells = (max(self._shape(l)[1] for l in left) +
  28. self._shape(right)[1])
  29. else:
  30. vcells = max([self._shape(l)[0] for l in left] +
  31. [self._shape(right)[0]])
  32. hcells = sum([self._shape(l)[1] for l in left] +
  33. [self._shape(right)[1]])
  34. return hcells, vcells
  35. def plot(self, left, right, labels=None, vertical=True):
  36. """
  37. Plot left / right DataFrames in specified layout.
  38. Parameters
  39. ----------
  40. left : list of DataFrames before operation is applied
  41. right : DataFrame of operation result
  42. labels : list of str to be drawn as titles of left DataFrames
  43. vertical : bool
  44. If True, use vertical layout. If False, use horizontal layout.
  45. """
  46. import matplotlib.pyplot as plt
  47. import matplotlib.gridspec as gridspec
  48. if not isinstance(left, list):
  49. left = [left]
  50. left = [self._conv(l) for l in left]
  51. right = self._conv(right)
  52. hcells, vcells = self._get_cells(left, right, vertical)
  53. if vertical:
  54. figsize = self.cell_width * hcells, self.cell_height * vcells
  55. else:
  56. # include margin for titles
  57. figsize = self.cell_width * hcells, self.cell_height * vcells
  58. fig = plt.figure(figsize=figsize)
  59. if vertical:
  60. gs = gridspec.GridSpec(len(left), hcells)
  61. # left
  62. max_left_cols = max(self._shape(l)[1] for l in left)
  63. max_left_rows = max(self._shape(l)[0] for l in left)
  64. for i, (l, label) in enumerate(zip(left, labels)):
  65. ax = fig.add_subplot(gs[i, 0:max_left_cols])
  66. self._make_table(ax, l, title=label,
  67. height=1.0 / max_left_rows)
  68. # right
  69. ax = plt.subplot(gs[:, max_left_cols:])
  70. self._make_table(ax, right, title='Result', height=1.05 / vcells)
  71. fig.subplots_adjust(top=0.9, bottom=0.05, left=0.05, right=0.95)
  72. else:
  73. max_rows = max(self._shape(df)[0] for df in left + [right])
  74. height = 1.0 / np.max(max_rows)
  75. gs = gridspec.GridSpec(1, hcells)
  76. # left
  77. i = 0
  78. for l, label in zip(left, labels):
  79. sp = self._shape(l)
  80. ax = fig.add_subplot(gs[0, i:i + sp[1]])
  81. self._make_table(ax, l, title=label, height=height)
  82. i += sp[1]
  83. # right
  84. ax = plt.subplot(gs[0, i:])
  85. self._make_table(ax, right, title='Result', height=height)
  86. fig.subplots_adjust(top=0.85, bottom=0.05, left=0.05, right=0.95)
  87. return fig
  88. def _conv(self, data):
  89. """Convert each input to appropriate for table outplot"""
  90. if isinstance(data, pd.Series):
  91. if data.name is None:
  92. data = data.to_frame(name='')
  93. else:
  94. data = data.to_frame()
  95. data = data.fillna('NaN')
  96. return data
  97. def _insert_index(self, data):
  98. # insert is destructive
  99. data = data.copy()
  100. idx_nlevels = data.index.nlevels
  101. if idx_nlevels == 1:
  102. data.insert(0, 'Index', data.index)
  103. else:
  104. for i in range(idx_nlevels):
  105. data.insert(i, 'Index{0}'.format(i),
  106. data.index._get_level_values(i))
  107. col_nlevels = data.columns.nlevels
  108. if col_nlevels > 1:
  109. col = data.columns._get_level_values(0)
  110. values = [data.columns._get_level_values(i).values
  111. for i in range(1, col_nlevels)]
  112. col_df = pd.DataFrame(values)
  113. data.columns = col_df.columns
  114. data = pd.concat([col_df, data])
  115. data.columns = col
  116. return data
  117. def _make_table(self, ax, df, title, height=None):
  118. if df is None:
  119. ax.set_visible(False)
  120. return
  121. import pandas.plotting as plotting
  122. idx_nlevels = df.index.nlevels
  123. col_nlevels = df.columns.nlevels
  124. # must be convert here to get index levels for colorization
  125. df = self._insert_index(df)
  126. tb = plotting.table(ax, df, loc=9)
  127. tb.set_fontsize(self.font_size)
  128. if height is None:
  129. height = 1.0 / (len(df) + 1)
  130. props = tb.properties()
  131. for (r, c), cell in compat.iteritems(props['celld']):
  132. if c == -1:
  133. cell.set_visible(False)
  134. elif r < col_nlevels and c < idx_nlevels:
  135. cell.set_visible(False)
  136. elif r < col_nlevels or c < idx_nlevels:
  137. cell.set_facecolor('#AAAAAA')
  138. cell.set_height(height)
  139. ax.set_title(title, size=self.font_size)
  140. ax.axis('off')
class _WritableDoc(type):
    # Empty metaclass kept only for Python 2 compatibility.
    #
    # Remove this when Python2 support is dropped.
    # __doc__ is not mutable for new-style classes in Python2, which means
    # we can't use @Appender to share class docstrings. This can be used
    # with `add_metaclass` to make cls.__doc__ mutable.
    pass
  147. if __name__ == "__main__":
  148. import matplotlib.pyplot as plt
  149. p = TablePlotter()
  150. df1 = pd.DataFrame({'A': [10, 11, 12],
  151. 'B': [20, 21, 22],
  152. 'C': [30, 31, 32]})
  153. df2 = pd.DataFrame({'A': [10, 12],
  154. 'C': [30, 32]})
  155. p.plot([df1, df2], pd.concat([df1, df2]),
  156. labels=['df1', 'df2'], vertical=True)
  157. plt.show()
  158. df3 = pd.DataFrame({'X': [10, 12],
  159. 'Z': [30, 32]})
  160. p.plot([df1, df3], pd.concat([df1, df3], axis=1),
  161. labels=['df1', 'df2'], vertical=False)
  162. plt.show()
  163. idx = pd.MultiIndex.from_tuples([(1, 'A'), (1, 'B'), (1, 'C'),
  164. (2, 'A'), (2, 'B'), (2, 'C')])
  165. col = pd.MultiIndex.from_tuples([(1, 'A'), (1, 'B')])
  166. df3 = pd.DataFrame({'v1': [1, 2, 3, 4, 5, 6],
  167. 'v2': [5, 6, 7, 8, 9, 10]},
  168. index=idx)
  169. df3.columns = col
  170. p.plot(df3, df3, labels=['df3'])
  171. plt.show()