You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

557 lines
20 KiB

4 years ago
"""
:mod:`~matplotlib.gridspec` is a module which specifies the location
of the subplot in the figure.

    `GridSpec`
        specifies the geometry of the grid in which a subplot will be
        placed. The number of rows and number of columns of the grid
        need to be set. Optionally, the subplot layout parameters
        (e.g., left, right, etc.) can be tuned.

    `SubplotSpec`
        specifies the location of the subplot in the given `GridSpec`.
"""
  12. import copy
  13. import logging
  14. import warnings
  15. import numpy as np
  16. import matplotlib as mpl
  17. from matplotlib import _pylab_helpers, cbook, tight_layout, rcParams
  18. from matplotlib.transforms import Bbox
  19. import matplotlib._layoutbox as layoutbox
  20. _log = logging.getLogger(__name__)
  21. class GridSpecBase(object):
  22. """
  23. A base class of GridSpec that specifies the geometry of the grid
  24. that a subplot will be placed.
  25. """
  26. def __init__(self, nrows, ncols, height_ratios=None, width_ratios=None):
  27. """
  28. The number of rows and number of columns of the grid need to
  29. be set. Optionally, the ratio of heights and widths of rows and
  30. columns can be specified.
  31. """
  32. self._nrows, self._ncols = nrows, ncols
  33. self.set_height_ratios(height_ratios)
  34. self.set_width_ratios(width_ratios)
  35. def __repr__(self):
  36. height_arg = (', height_ratios=%r' % self._row_height_ratios
  37. if self._row_height_ratios is not None else '')
  38. width_arg = (', width_ratios=%r' % self._col_width_ratios
  39. if self._col_width_ratios is not None else '')
  40. return '{clsname}({nrows}, {ncols}{optionals})'.format(
  41. clsname=self.__class__.__name__,
  42. nrows=self._nrows,
  43. ncols=self._ncols,
  44. optionals=height_arg + width_arg,
  45. )
  46. def get_geometry(self):
  47. 'get the geometry of the grid, e.g., 2,3'
  48. return self._nrows, self._ncols
  49. def get_subplot_params(self, figure=None, fig=None):
  50. pass
  51. def new_subplotspec(self, loc, rowspan=1, colspan=1):
  52. """
  53. create and return a SuplotSpec instance.
  54. """
  55. loc1, loc2 = loc
  56. subplotspec = self[loc1:loc1+rowspan, loc2:loc2+colspan]
  57. return subplotspec
  58. def set_width_ratios(self, width_ratios):
  59. if width_ratios is not None and len(width_ratios) != self._ncols:
  60. raise ValueError('Expected the given number of width ratios to '
  61. 'match the number of columns of the grid')
  62. self._col_width_ratios = width_ratios
  63. def get_width_ratios(self):
  64. return self._col_width_ratios
  65. def set_height_ratios(self, height_ratios):
  66. if height_ratios is not None and len(height_ratios) != self._nrows:
  67. raise ValueError('Expected the given number of height ratios to '
  68. 'match the number of rows of the grid')
  69. self._row_height_ratios = height_ratios
  70. def get_height_ratios(self):
  71. return self._row_height_ratios
  72. def get_grid_positions(self, fig, raw=False):
  73. """
  74. return lists of bottom and top position of rows, left and
  75. right positions of columns.
  76. If raw=True, then these are all in units relative to the container
  77. with no margins. (used for constrained_layout).
  78. """
  79. nrows, ncols = self.get_geometry()
  80. if raw:
  81. left = 0.
  82. right = 1.
  83. bottom = 0.
  84. top = 1.
  85. wspace = 0.
  86. hspace = 0.
  87. else:
  88. subplot_params = self.get_subplot_params(fig)
  89. left = subplot_params.left
  90. right = subplot_params.right
  91. bottom = subplot_params.bottom
  92. top = subplot_params.top
  93. wspace = subplot_params.wspace
  94. hspace = subplot_params.hspace
  95. tot_width = right - left
  96. tot_height = top - bottom
  97. # calculate accumulated heights of columns
  98. cell_h = tot_height / (nrows + hspace*(nrows-1))
  99. sep_h = hspace * cell_h
  100. if self._row_height_ratios is not None:
  101. norm = cell_h * nrows / sum(self._row_height_ratios)
  102. cell_heights = [r * norm for r in self._row_height_ratios]
  103. else:
  104. cell_heights = [cell_h] * nrows
  105. sep_heights = [0] + ([sep_h] * (nrows-1))
  106. cell_hs = np.cumsum(np.column_stack([sep_heights, cell_heights]).flat)
  107. # calculate accumulated widths of rows
  108. cell_w = tot_width / (ncols + wspace*(ncols-1))
  109. sep_w = wspace * cell_w
  110. if self._col_width_ratios is not None:
  111. norm = cell_w * ncols / sum(self._col_width_ratios)
  112. cell_widths = [r * norm for r in self._col_width_ratios]
  113. else:
  114. cell_widths = [cell_w] * ncols
  115. sep_widths = [0] + ([sep_w] * (ncols-1))
  116. cell_ws = np.cumsum(np.column_stack([sep_widths, cell_widths]).flat)
  117. fig_tops, fig_bottoms = (top - cell_hs).reshape((-1, 2)).T
  118. fig_lefts, fig_rights = (left + cell_ws).reshape((-1, 2)).T
  119. return fig_bottoms, fig_tops, fig_lefts, fig_rights
  120. def __getitem__(self, key):
  121. """Create and return a SuplotSpec instance.
  122. """
  123. nrows, ncols = self.get_geometry()
  124. def _normalize(key, size): # Includes last index.
  125. if isinstance(key, slice):
  126. start, stop, _ = key.indices(size)
  127. if stop > start:
  128. return start, stop - 1
  129. else:
  130. if key < 0:
  131. key += size
  132. if 0 <= key < size:
  133. return key, key
  134. raise IndexError("invalid index")
  135. if isinstance(key, tuple):
  136. try:
  137. k1, k2 = key
  138. except ValueError:
  139. raise ValueError("unrecognized subplot spec")
  140. num1, num2 = np.ravel_multi_index(
  141. [_normalize(k1, nrows), _normalize(k2, ncols)], (nrows, ncols))
  142. else: # Single key
  143. num1, num2 = _normalize(key, nrows * ncols)
  144. return SubplotSpec(self, num1, num2)
  145. class GridSpec(GridSpecBase):
  146. """
  147. A class that specifies the geometry of the grid that a subplot
  148. will be placed. The location of grid is determined by similar way
  149. as the SubplotParams.
  150. """
  151. def __init__(self, nrows, ncols, figure=None,
  152. left=None, bottom=None, right=None, top=None,
  153. wspace=None, hspace=None,
  154. width_ratios=None, height_ratios=None):
  155. """
  156. The number of rows and number of columns of the grid need to be set.
  157. Optionally, the subplot layout parameters (e.g., left, right, etc.)
  158. can be tuned.
  159. Parameters
  160. ----------
  161. nrows : int
  162. Number of rows in grid.
  163. ncols : int
  164. Number or columns in grid.
  165. figure : ~.figure.Figure, optional
  166. left, right, top, bottom : float
  167. Extent of the subplots as a fraction of figure width or height.
  168. Left cannot be larger than right, and bottom cannot be larger than
  169. top.
  170. wspace : float
  171. The amount of width reserved for space between subplots,
  172. expressed as a fraction of the average axis width.
  173. hspace : float
  174. The amount of height reserved for space between subplots,
  175. expressed as a fraction of the average axis height.
  176. Notes
  177. -----
  178. See `~.figure.SubplotParams` for descriptions of the layout parameters.
  179. """
  180. self.left = left
  181. self.bottom = bottom
  182. self.right = right
  183. self.top = top
  184. self.wspace = wspace
  185. self.hspace = hspace
  186. self.figure = figure
  187. GridSpecBase.__init__(self, nrows, ncols,
  188. width_ratios=width_ratios,
  189. height_ratios=height_ratios)
  190. if self.figure is None or not self.figure.get_constrained_layout():
  191. self._layoutbox = None
  192. else:
  193. self.figure.init_layoutbox()
  194. self._layoutbox = layoutbox.LayoutBox(
  195. parent=self.figure._layoutbox,
  196. name='gridspec' + layoutbox.seq_id(),
  197. artist=self)
  198. # by default the layoutbox for a gridsepc will fill a figure.
  199. # but this can change below if the gridspec is created from a
  200. # subplotspec. (GridSpecFromSubplotSpec)
  201. _AllowedKeys = ["left", "bottom", "right", "top", "wspace", "hspace"]
  202. def __getstate__(self):
  203. state = self.__dict__
  204. try:
  205. state.pop('_layoutbox')
  206. except KeyError:
  207. pass
  208. return state
  209. def __setstate__(self, state):
  210. self.__dict__ = state
  211. # layoutboxes don't survive pickling...
  212. self._layoutbox = None
  213. def update(self, **kwargs):
  214. """
  215. Update the current values. If any kwarg is None, default to
  216. the current value, if set, otherwise to rc.
  217. """
  218. for k, v in kwargs.items():
  219. if k in self._AllowedKeys:
  220. setattr(self, k, v)
  221. else:
  222. raise AttributeError("%s is unknown keyword" % (k,))
  223. for figmanager in _pylab_helpers.Gcf.figs.values():
  224. for ax in figmanager.canvas.figure.axes:
  225. # copied from Figure.subplots_adjust
  226. if not isinstance(ax, mpl.axes.SubplotBase):
  227. # Check if sharing a subplots axis
  228. if isinstance(ax._sharex, mpl.axes.SubplotBase):
  229. if ax._sharex.get_subplotspec().get_gridspec() == self:
  230. ax._sharex.update_params()
  231. ax._set_position(ax._sharex.figbox)
  232. elif isinstance(ax._sharey, mpl.axes.SubplotBase):
  233. if ax._sharey.get_subplotspec().get_gridspec() == self:
  234. ax._sharey.update_params()
  235. ax._set_position(ax._sharey.figbox)
  236. else:
  237. ss = ax.get_subplotspec().get_topmost_subplotspec()
  238. if ss.get_gridspec() == self:
  239. ax.update_params()
  240. ax._set_position(ax.figbox)
  241. def get_subplot_params(self, figure=None, fig=None):
  242. """
  243. Return a dictionary of subplot layout parameters. The default
  244. parameters are from rcParams unless a figure attribute is set.
  245. """
  246. if fig is not None:
  247. cbook.warn_deprecated("2.2", "fig", obj_type="keyword argument",
  248. alternative="figure")
  249. if figure is None:
  250. figure = fig
  251. if figure is None:
  252. kw = {k: rcParams["figure.subplot."+k] for k in self._AllowedKeys}
  253. subplotpars = mpl.figure.SubplotParams(**kw)
  254. else:
  255. subplotpars = copy.copy(figure.subplotpars)
  256. subplotpars.update(**{k: getattr(self, k) for k in self._AllowedKeys})
  257. return subplotpars
  258. def locally_modified_subplot_params(self):
  259. return [k for k in self._AllowedKeys if getattr(self, k)]
  260. def tight_layout(self, figure, renderer=None,
  261. pad=1.08, h_pad=None, w_pad=None, rect=None):
  262. """
  263. Adjust subplot parameters to give specified padding.
  264. Parameters
  265. ----------
  266. pad : float
  267. Padding between the figure edge and the edges of subplots, as a
  268. fraction of the font-size.
  269. h_pad, w_pad : float, optional
  270. Padding (height/width) between edges of adjacent subplots.
  271. Defaults to ``pad_inches``.
  272. rect : tuple of 4 floats, optional
  273. (left, bottom, right, top) rectangle in normalized figure
  274. coordinates that the whole subplots area (including labels) will
  275. fit into. Default is (0, 0, 1, 1).
  276. """
  277. subplotspec_list = tight_layout.get_subplotspec_list(
  278. figure.axes, grid_spec=self)
  279. if None in subplotspec_list:
  280. warnings.warn("This figure includes Axes that are not compatible "
  281. "with tight_layout, so results might be incorrect.")
  282. if renderer is None:
  283. renderer = tight_layout.get_renderer(figure)
  284. kwargs = tight_layout.get_tight_layout_figure(
  285. figure, figure.axes, subplotspec_list, renderer,
  286. pad=pad, h_pad=h_pad, w_pad=w_pad, rect=rect)
  287. if kwargs:
  288. self.update(**kwargs)
  289. class GridSpecFromSubplotSpec(GridSpecBase):
  290. """
  291. GridSpec whose subplot layout parameters are inherited from the
  292. location specified by a given SubplotSpec.
  293. """
  294. def __init__(self, nrows, ncols,
  295. subplot_spec,
  296. wspace=None, hspace=None,
  297. height_ratios=None, width_ratios=None):
  298. """
  299. The number of rows and number of columns of the grid need to
  300. be set. An instance of SubplotSpec is also needed to be set
  301. from which the layout parameters will be inherited. The wspace
  302. and hspace of the layout can be optionally specified or the
  303. default values (from the figure or rcParams) will be used.
  304. """
  305. self._wspace = wspace
  306. self._hspace = hspace
  307. self._subplot_spec = subplot_spec
  308. GridSpecBase.__init__(self, nrows, ncols,
  309. width_ratios=width_ratios,
  310. height_ratios=height_ratios)
  311. # do the layoutboxes
  312. subspeclb = subplot_spec._layoutbox
  313. if subspeclb is None:
  314. self._layoutbox = None
  315. else:
  316. # OK, this is needed to divide the figure.
  317. self._layoutbox = subspeclb.layout_from_subplotspec(
  318. subplot_spec,
  319. name=subspeclb.name + '.gridspec' + layoutbox.seq_id(),
  320. artist=self)
  321. def get_subplot_params(self, figure=None, fig=None):
  322. """Return a dictionary of subplot layout parameters.
  323. """
  324. if fig is not None:
  325. cbook.warn_deprecated("2.2", "fig", obj_type="keyword argument",
  326. alternative="figure")
  327. if figure is None:
  328. figure = fig
  329. hspace = (self._hspace if self._hspace is not None
  330. else figure.subplotpars.hspace if figure is not None
  331. else rcParams["figure.subplot.hspace"])
  332. wspace = (self._wspace if self._wspace is not None
  333. else figure.subplotpars.wspace if figure is not None
  334. else rcParams["figure.subplot.wspace"])
  335. figbox = self._subplot_spec.get_position(figure)
  336. left, bottom, right, top = figbox.extents
  337. return mpl.figure.SubplotParams(left=left, right=right,
  338. bottom=bottom, top=top,
  339. wspace=wspace, hspace=hspace)
  340. def get_topmost_subplotspec(self):
  341. """Get the topmost SubplotSpec instance associated with the subplot."""
  342. return self._subplot_spec.get_topmost_subplotspec()
  343. class SubplotSpec(object):
  344. """Specifies the location of the subplot in the given `GridSpec`.
  345. """
  346. def __init__(self, gridspec, num1, num2=None):
  347. """
  348. The subplot will occupy the num1-th cell of the given
  349. gridspec. If num2 is provided, the subplot will span between
  350. num1-th cell and num2-th cell.
  351. The index starts from 0.
  352. """
  353. self._gridspec = gridspec
  354. self.num1 = num1
  355. self.num2 = num2
  356. if gridspec._layoutbox is not None:
  357. glb = gridspec._layoutbox
  358. # So note that here we don't assign any layout yet,
  359. # just make the layoutbox that will conatin all items
  360. # associated w/ this axis. This can include other axes like
  361. # a colorbar or a legend.
  362. self._layoutbox = layoutbox.LayoutBox(
  363. parent=glb,
  364. name=glb.name + '.ss' + layoutbox.seq_id(),
  365. artist=self)
  366. else:
  367. self._layoutbox = None
  368. def __getstate__(self):
  369. state = self.__dict__
  370. try:
  371. state.pop('_layoutbox')
  372. except KeyError:
  373. pass
  374. return state
  375. def __setstate__(self, state):
  376. self.__dict__ = state
  377. # layoutboxes don't survive pickling...
  378. self._layoutbox = None
  379. def get_gridspec(self):
  380. return self._gridspec
  381. def get_geometry(self):
  382. """
  383. Get the subplot geometry (``n_rows, n_cols, start, stop``).
  384. start and stop are the index of the start and stop of the
  385. subplot.
  386. """
  387. rows, cols = self.get_gridspec().get_geometry()
  388. return rows, cols, self.num1, self.num2
  389. def get_rows_columns(self):
  390. """
  391. Get the subplot row and column numbers:
  392. (``n_rows, n_cols, row_start, row_stop, col_start, col_stop``)
  393. """
  394. gridspec = self.get_gridspec()
  395. nrows, ncols = gridspec.get_geometry()
  396. row_start, col_start = divmod(self.num1, ncols)
  397. if self.num2 is not None:
  398. row_stop, col_stop = divmod(self.num2, ncols)
  399. else:
  400. row_stop = row_start
  401. col_stop = col_start
  402. return nrows, ncols, row_start, row_stop, col_start, col_stop
  403. def get_position(self, figure, return_all=False):
  404. """Update the subplot position from ``figure.subplotpars``.
  405. """
  406. gridspec = self.get_gridspec()
  407. nrows, ncols = gridspec.get_geometry()
  408. rows, cols = np.unravel_index(
  409. [self.num1] if self.num2 is None else [self.num1, self.num2],
  410. (nrows, ncols))
  411. fig_bottoms, fig_tops, fig_lefts, fig_rights = \
  412. gridspec.get_grid_positions(figure)
  413. fig_bottom = fig_bottoms[rows].min()
  414. fig_top = fig_tops[rows].max()
  415. fig_left = fig_lefts[cols].min()
  416. fig_right = fig_rights[cols].max()
  417. figbox = Bbox.from_extents(fig_left, fig_bottom, fig_right, fig_top)
  418. if return_all:
  419. return figbox, rows[0], cols[0], nrows, ncols
  420. else:
  421. return figbox
  422. def get_topmost_subplotspec(self):
  423. 'get the topmost SubplotSpec instance associated with the subplot'
  424. gridspec = self.get_gridspec()
  425. if hasattr(gridspec, "get_topmost_subplotspec"):
  426. return gridspec.get_topmost_subplotspec()
  427. else:
  428. return self
  429. def __eq__(self, other):
  430. # other may not even have the attributes we are checking.
  431. return ((self._gridspec, self.num1, self.num2)
  432. == (getattr(other, "_gridspec", object()),
  433. getattr(other, "num1", object()),
  434. getattr(other, "num2", object())))
  435. def __hash__(self):
  436. return hash((self._gridspec, self.num1, self.num2))
  437. def subgridspec(self, nrows, ncols, **kwargs):
  438. """
  439. Return a `.GridSpecFromSubplotSpec` that has this subplotspec as
  440. a parent.
  441. Parameters
  442. ----------
  443. nrows : int
  444. Number of rows in grid.
  445. ncols : int
  446. Number or columns in grid.
  447. Returns
  448. -------
  449. gridspec : `.GridSpec`
  450. Other Parameters
  451. ----------------
  452. **kwargs
  453. All other parameters are passed to `.GridSpec`.
  454. See Also
  455. --------
  456. matplotlib.pyplot.subplots
  457. Examples
  458. --------
  459. Adding three subplots in the space occupied by a single subplot::
  460. fig = plt.figure()
  461. gs0 = fig.add_gridspec(3, 1)
  462. ax1 = fig.add_subplot(gs0[0])
  463. ax2 = fig.add_subplot(gs0[1])
  464. gssub = gs0[2].subgridspec(1, 3)
  465. for i in range(3):
  466. fig.add_subplot(gssub[0, i])
  467. """
  468. return GridSpecFromSubplotSpec(nrows, ncols, self, **kwargs)