197 lines
6.6 KiB
Python
197 lines
6.6 KiB
Python
|
import numpy as np
|
||
|
import pandas as pd
|
||
|
import pandas.compat as compat
|
||
|
|
||
|
|
||
|
class TablePlotter(object):
    """
    Layout some DataFrames in vertical/horizontal layout for explanation.

    Used in merging.rst

    Parameters
    ----------
    cell_width : float
        Width given to one table cell when sizing the figure.
    cell_height : float
        Height given to one table cell when sizing the figure.
    font_size : float
        Font size used for table text and titles.
    """

    def __init__(self, cell_width=0.37, cell_height=0.25, font_size=7.5):
        self.cell_width = cell_width
        self.cell_height = cell_height
        self.font_size = font_size

    def _shape(self, df):
        """
        Calculate table shape considering index levels.

        Returns
        -------
        tuple of int
            ``(rows, cols)`` where the column header levels are added to the
            row count and the index levels are added to the column count.
        """
        row, col = df.shape
        return row + df.columns.nlevels, col + df.index.nlevels

    def _get_cells(self, left, right, vertical):
        """
        Calculate the number of grid cells needed for left and right data.

        Parameters
        ----------
        left : list of DataFrames
        right : DataFrame
        vertical : bool
            If True, left tables are stacked on top of each other next to
            the result; otherwise every table is placed side by side.

        Returns
        -------
        tuple of int
            ``(hcells, vcells)``: horizontal and vertical cell counts.
        """
        left_shapes = [self._shape(frame) for frame in left]
        right_rows, right_cols = self._shape(right)

        if vertical:
            # left tables are stacked, so their rows add up while only the
            # widest one determines the width of the left-hand side
            vcells = max(sum(rows for rows, _ in left_shapes), right_rows)
            hcells = max(cols for _, cols in left_shapes) + right_cols
        else:
            # every table sits in one row, so widths add up and the tallest
            # table determines the height
            vcells = max([rows for rows, _ in left_shapes] + [right_rows])
            hcells = sum([cols for _, cols in left_shapes] + [right_cols])
        return hcells, vcells

    def _fill_labels(self, labels, count):
        """
        Return a list holding at least ``count`` titles.

        ``None`` is treated as "no titles" and missing entries are padded
        with empty strings, so that pairing the titles with the left tables
        never drops a table.
        """
        labels = [] if labels is None else list(labels)
        # a non-positive multiplier yields an empty list: nothing is padded
        labels.extend([''] * (count - len(labels)))
        return labels

    def plot(self, left, right, labels=None, vertical=True):
        """
        Plot left / right DataFrames in specified layout.

        Parameters
        ----------
        left : list of DataFrames before operation is applied
        right : DataFrame of operation result
        labels : list of str to be drawn as titles of left DataFrames
            May be omitted or shorter than ``left``; missing titles are
            drawn as empty strings.
        vertical : bool
            If True, use vertical layout. If False, use horizontal layout.

        Returns
        -------
        matplotlib Figure holding one axes per table.
        """
        import matplotlib.pyplot as plt
        import matplotlib.gridspec as gridspec

        if not isinstance(left, list):
            left = [left]
        left = [self._conv(frame) for frame in left]
        right = self._conv(right)
        labels = self._fill_labels(labels, len(left))

        hcells, vcells = self._get_cells(left, right, vertical)

        # both layouts size the figure from the cell counts in the same way
        figsize = self.cell_width * hcells, self.cell_height * vcells
        fig = plt.figure(figsize=figsize)

        if vertical:
            gs = gridspec.GridSpec(len(left), hcells)
            # left
            max_left_cols = max(self._shape(frame)[1] for frame in left)
            max_left_rows = max(self._shape(frame)[0] for frame in left)
            for i, (frame, label) in enumerate(zip(left, labels)):
                ax = fig.add_subplot(gs[i, 0:max_left_cols])
                self._make_table(ax, frame, title=label,
                                 height=1.0 / max_left_rows)
            # right
            ax = fig.add_subplot(gs[:, max_left_cols:])
            self._make_table(ax, right, title='Result', height=1.05 / vcells)
            fig.subplots_adjust(top=0.9, bottom=0.05, left=0.05, right=0.95)
        else:
            max_rows = max(self._shape(df)[0] for df in left + [right])
            height = 1.0 / max_rows
            gs = gridspec.GridSpec(1, hcells)
            # left
            i = 0
            for frame, label in zip(left, labels):
                sp = self._shape(frame)
                ax = fig.add_subplot(gs[0, i:i + sp[1]])
                self._make_table(ax, frame, title=label, height=height)
                i += sp[1]
            # right
            ax = fig.add_subplot(gs[0, i:])
            self._make_table(ax, right, title='Result', height=height)
            fig.subplots_adjust(top=0.85, bottom=0.05, left=0.05, right=0.95)

        return fig

    def _conv(self, data):
        """Convert each input to a DataFrame appropriate for table output."""
        if isinstance(data, pd.Series):
            if data.name is None:
                # unnamed Series get an empty column header
                data = data.to_frame(name='')
            else:
                data = data.to_frame()
        # missing values are displayed as the text 'NaN'
        data = data.fillna('NaN')
        return data

    def _insert_index(self, data):
        """
        Return a copy of ``data`` with its index (and any extra column
        levels) moved into ordinary cells.

        Each index level becomes a leading column. When the columns have
        more than one level, levels 1.. are prepended as rows and level 0
        is kept as the column header.
        """
        # insert is destructive
        data = data.copy()
        idx_nlevels = data.index.nlevels
        if idx_nlevels == 1:
            data.insert(0, 'Index', data.index)
        else:
            for i in range(idx_nlevels):
                data.insert(i, 'Index{0}'.format(i),
                            data.index._get_level_values(i))

        col_nlevels = data.columns.nlevels
        if col_nlevels > 1:
            col = data.columns._get_level_values(0)
            values = [data.columns._get_level_values(i).values
                      for i in range(1, col_nlevels)]
            col_df = pd.DataFrame(values)
            # align both frames on the same labels so concat stacks them
            data.columns = col_df.columns
            data = pd.concat([col_df, data])
            data.columns = col
        return data

    def _make_table(self, ax, df, title, height=None):
        """
        Draw ``df`` as a table on ``ax`` and give the axes a title.

        Parameters
        ----------
        ax : matplotlib axes to draw on
        df : DataFrame or None
            If None, the axes is hidden and nothing is drawn.
        title : str
        height : float, optional
            Height applied to every cell; defaults to
            ``1.0 / (number of table rows + 1)``.
        """
        if df is None:
            ax.set_visible(False)
            return

        import pandas.plotting as plotting

        idx_nlevels = df.index.nlevels
        col_nlevels = df.columns.nlevels
        # must be converted here to get index levels for colorization
        df = self._insert_index(df)
        tb = plotting.table(ax, df, loc=9)
        tb.set_fontsize(self.font_size)

        if height is None:
            height = 1.0 / (len(df) + 1)

        props = tb.properties()
        # dict.items() works on Python 2 and 3 alike, no compat shim needed
        for (r, c), cell in props['celld'].items():
            if c == -1:
                # presumably the row-label column added by the table helper;
                # the index is already shown through the inserted columns
                cell.set_visible(False)
            elif r < col_nlevels and c < idx_nlevels:
                # corner where header rows meet index columns
                cell.set_visible(False)
            elif r < col_nlevels or c < idx_nlevels:
                # header rows and index columns are shaded
                cell.set_facecolor('#AAAAAA')
            cell.set_height(height)

        ax.set_title(title, size=self.font_size)
        ax.axis('off')
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Manual demo: draws three example figures, one after another.
    import matplotlib.pyplot as plt

    p = TablePlotter()

    df1 = pd.DataFrame({'A': [10, 11, 12],
                        'B': [20, 21, 22],
                        'C': [30, 31, 32]})
    df2 = pd.DataFrame({'A': [10, 12],
                        'C': [30, 32]})

    # row-wise concat shown in the vertical layout
    p.plot([df1, df2], pd.concat([df1, df2]),
           labels=['df1', 'df2'], vertical=True)
    plt.show()

    df3 = pd.DataFrame({'X': [10, 12],
                        'Z': [30, 32]})

    # column-wise concat shown in the horizontal layout; the second title
    # names the frame that is actually drawn (df3)
    p.plot([df1, df3], pd.concat([df1, df3], axis=1),
           labels=['df1', 'df3'], vertical=False)
    plt.show()

    # frame with MultiIndex on both rows and columns
    idx = pd.MultiIndex.from_tuples([(1, 'A'), (1, 'B'), (1, 'C'),
                                     (2, 'A'), (2, 'B'), (2, 'C')])
    col = pd.MultiIndex.from_tuples([(1, 'A'), (1, 'B')])
    df_multi = pd.DataFrame({'v1': [1, 2, 3, 4, 5, 6],
                             'v2': [5, 6, 7, 8, 9, 10]},
                            index=idx)
    df_multi.columns = col
    p.plot(df_multi, df_multi, labels=['df3'])
    plt.show()
|