You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

602 lines
19 KiB

4 years ago
  1. import numpy as np
  2. from numpy import ma
  3. from matplotlib import cbook, docstring, rcParams
  4. from matplotlib.ticker import (
  5. NullFormatter, ScalarFormatter, LogFormatterSciNotation, LogitFormatter,
  6. NullLocator, LogLocator, AutoLocator, AutoMinorLocator,
  7. SymmetricalLogLocator, LogitLocator)
  8. from matplotlib.transforms import Transform, IdentityTransform
  9. class ScaleBase(object):
  10. """
  11. The base class for all scales.
  12. Scales are separable transformations, working on a single dimension.
  13. Any subclasses will want to override:
  14. - :attr:`name`
  15. - :meth:`get_transform`
  16. - :meth:`set_default_locators_and_formatters`
  17. And optionally:
  18. - :meth:`limit_range_for_scale`
  19. """
  20. def get_transform(self):
  21. """
  22. Return the :class:`~matplotlib.transforms.Transform` object
  23. associated with this scale.
  24. """
  25. raise NotImplementedError()
  26. def set_default_locators_and_formatters(self, axis):
  27. """
  28. Set the :class:`~matplotlib.ticker.Locator` and
  29. :class:`~matplotlib.ticker.Formatter` objects on the given
  30. axis to match this scale.
  31. """
  32. raise NotImplementedError()
  33. def limit_range_for_scale(self, vmin, vmax, minpos):
  34. """
  35. Returns the range *vmin*, *vmax*, possibly limited to the
  36. domain supported by this scale.
  37. *minpos* should be the minimum positive value in the data.
  38. This is used by log scales to determine a minimum value.
  39. """
  40. return vmin, vmax
  41. class LinearScale(ScaleBase):
  42. """
  43. The default linear scale.
  44. """
  45. name = 'linear'
  46. def __init__(self, axis, **kwargs):
  47. pass
  48. def set_default_locators_and_formatters(self, axis):
  49. """
  50. Set the locators and formatters to reasonable defaults for
  51. linear scaling.
  52. """
  53. axis.set_major_locator(AutoLocator())
  54. axis.set_major_formatter(ScalarFormatter())
  55. axis.set_minor_formatter(NullFormatter())
  56. # update the minor locator for x and y axis based on rcParams
  57. if rcParams['xtick.minor.visible']:
  58. axis.set_minor_locator(AutoMinorLocator())
  59. else:
  60. axis.set_minor_locator(NullLocator())
  61. def get_transform(self):
  62. """
  63. The transform for linear scaling is just the
  64. :class:`~matplotlib.transforms.IdentityTransform`.
  65. """
  66. return IdentityTransform()
  67. class LogTransformBase(Transform):
  68. input_dims = 1
  69. output_dims = 1
  70. is_separable = True
  71. has_inverse = True
  72. def __init__(self, nonpos='clip'):
  73. Transform.__init__(self)
  74. self._clip = {"clip": True, "mask": False}[nonpos]
  75. def transform_non_affine(self, a):
  76. # Ignore invalid values due to nans being passed to the transform
  77. with np.errstate(divide="ignore", invalid="ignore"):
  78. out = np.log(a)
  79. out /= np.log(self.base)
  80. if self._clip:
  81. # SVG spec says that conforming viewers must support values up
  82. # to 3.4e38 (C float); however experiments suggest that
  83. # Inkscape (which uses cairo for rendering) runs into cairo's
  84. # 24-bit limit (which is apparently shared by Agg).
  85. # Ghostscript (used for pdf rendering appears to overflow even
  86. # earlier, with the max value around 2 ** 15 for the tests to
  87. # pass. On the other hand, in practice, we want to clip beyond
  88. # np.log10(np.nextafter(0, 1)) ~ -323
  89. # so 1000 seems safe.
  90. out[a <= 0] = -1000
  91. return out
  92. def __str__(self):
  93. return "{}({!r})".format(
  94. type(self).__name__, "clip" if self._clip else "mask")
  95. class InvertedLogTransformBase(Transform):
  96. input_dims = 1
  97. output_dims = 1
  98. is_separable = True
  99. has_inverse = True
  100. def transform_non_affine(self, a):
  101. return ma.power(self.base, a)
  102. def __str__(self):
  103. return "{}()".format(type(self).__name__)
  104. class Log10Transform(LogTransformBase):
  105. base = 10.0
  106. def inverted(self):
  107. return InvertedLog10Transform()
  108. class InvertedLog10Transform(InvertedLogTransformBase):
  109. base = 10.0
  110. def inverted(self):
  111. return Log10Transform()
  112. class Log2Transform(LogTransformBase):
  113. base = 2.0
  114. def inverted(self):
  115. return InvertedLog2Transform()
  116. class InvertedLog2Transform(InvertedLogTransformBase):
  117. base = 2.0
  118. def inverted(self):
  119. return Log2Transform()
  120. class NaturalLogTransform(LogTransformBase):
  121. base = np.e
  122. def inverted(self):
  123. return InvertedNaturalLogTransform()
  124. class InvertedNaturalLogTransform(InvertedLogTransformBase):
  125. base = np.e
  126. def inverted(self):
  127. return NaturalLogTransform()
  128. class LogTransform(LogTransformBase):
  129. def __init__(self, base, nonpos='clip'):
  130. LogTransformBase.__init__(self, nonpos)
  131. self.base = base
  132. def inverted(self):
  133. return InvertedLogTransform(self.base)
  134. class InvertedLogTransform(InvertedLogTransformBase):
  135. def __init__(self, base):
  136. InvertedLogTransformBase.__init__(self)
  137. self.base = base
  138. def inverted(self):
  139. return LogTransform(self.base)
  140. class LogScale(ScaleBase):
  141. """
  142. A standard logarithmic scale. Care is taken so non-positive
  143. values are not plotted.
  144. For computational efficiency (to push as much as possible to Numpy
  145. C code in the common cases), this scale provides different
  146. transforms depending on the base of the logarithm:
  147. - base 10 (:class:`Log10Transform`)
  148. - base 2 (:class:`Log2Transform`)
  149. - base e (:class:`NaturalLogTransform`)
  150. - arbitrary base (:class:`LogTransform`)
  151. """
  152. name = 'log'
  153. # compatibility shim
  154. LogTransformBase = LogTransformBase
  155. Log10Transform = Log10Transform
  156. InvertedLog10Transform = InvertedLog10Transform
  157. Log2Transform = Log2Transform
  158. InvertedLog2Transform = InvertedLog2Transform
  159. NaturalLogTransform = NaturalLogTransform
  160. InvertedNaturalLogTransform = InvertedNaturalLogTransform
  161. LogTransform = LogTransform
  162. InvertedLogTransform = InvertedLogTransform
  163. def __init__(self, axis, **kwargs):
  164. """
  165. *basex*/*basey*:
  166. The base of the logarithm
  167. *nonposx*/*nonposy*: {'mask', 'clip'}
  168. non-positive values in *x* or *y* can be masked as
  169. invalid, or clipped to a very small positive number
  170. *subsx*/*subsy*:
  171. Where to place the subticks between each major tick.
  172. Should be a sequence of integers. For example, in a log10
  173. scale: ``[2, 3, 4, 5, 6, 7, 8, 9]``
  174. will place 8 logarithmically spaced minor ticks between
  175. each major tick.
  176. """
  177. if axis.axis_name == 'x':
  178. base = kwargs.pop('basex', 10.0)
  179. subs = kwargs.pop('subsx', None)
  180. nonpos = kwargs.pop('nonposx', 'clip')
  181. else:
  182. base = kwargs.pop('basey', 10.0)
  183. subs = kwargs.pop('subsy', None)
  184. nonpos = kwargs.pop('nonposy', 'clip')
  185. if len(kwargs):
  186. raise ValueError(("provided too many kwargs, can only pass "
  187. "{'basex', 'subsx', nonposx'} or "
  188. "{'basey', 'subsy', nonposy'}. You passed ") +
  189. "{!r}".format(kwargs))
  190. if nonpos not in ['mask', 'clip']:
  191. raise ValueError("nonposx, nonposy kwarg must be 'mask' or 'clip'")
  192. if base <= 0 or base == 1:
  193. raise ValueError('The log base cannot be <= 0 or == 1')
  194. if base == 10.0:
  195. self._transform = self.Log10Transform(nonpos)
  196. elif base == 2.0:
  197. self._transform = self.Log2Transform(nonpos)
  198. elif base == np.e:
  199. self._transform = self.NaturalLogTransform(nonpos)
  200. else:
  201. self._transform = self.LogTransform(base, nonpos)
  202. self.base = base
  203. self.subs = subs
  204. def set_default_locators_and_formatters(self, axis):
  205. """
  206. Set the locators and formatters to specialized versions for
  207. log scaling.
  208. """
  209. axis.set_major_locator(LogLocator(self.base))
  210. axis.set_major_formatter(LogFormatterSciNotation(self.base))
  211. axis.set_minor_locator(LogLocator(self.base, self.subs))
  212. axis.set_minor_formatter(
  213. LogFormatterSciNotation(self.base,
  214. labelOnlyBase=(self.subs is not None)))
  215. def get_transform(self):
  216. """
  217. Return a :class:`~matplotlib.transforms.Transform` instance
  218. appropriate for the given logarithm base.
  219. """
  220. return self._transform
  221. def limit_range_for_scale(self, vmin, vmax, minpos):
  222. """
  223. Limit the domain to positive values.
  224. """
  225. if not np.isfinite(minpos):
  226. minpos = 1e-300 # This value should rarely if ever
  227. # end up with a visible effect.
  228. return (minpos if vmin <= 0 else vmin,
  229. minpos if vmax <= 0 else vmax)
class SymmetricalLogTransform(Transform):
    """
    Forward transform used by :class:`SymmetricalLogScale`.

    Inputs with ``|a| <= linthresh`` are scaled linearly by
    ``linscale / (1 - 1/base)``.  Inputs outside that band map to
    ``sign(a) * linthresh * (linscale_adj + log_base(|a| / linthresh))``.
    The two pieces agree at ``|a| == linthresh``.
    """
    input_dims = 1
    output_dims = 1
    is_separable = True
    has_inverse = True

    def __init__(self, base, linthresh, linscale):
        """
        *base*: base of the logarithm used outside the linear band.
        *linthresh*: half-width of the linear band around zero.
        *linscale*: stretch factor applied to the linear band.
        """
        Transform.__init__(self)
        self.base = base
        self.linthresh = linthresh
        self.linscale = linscale
        # Slope of the linear band.  It is also the offset added to the log
        # branch, so the two pieces meet at +/-linthresh.
        self._linscale_adj = (linscale / (1.0 - self.base ** -1))
        # Precomputed for the change of base: log_base(x) = ln(x) / ln(base).
        self._log_base = np.log(base)

    def transform_non_affine(self, a):
        # Sign of each input, reapplied after taking the log of |a|.
        sign = np.sign(a)
        # Mask the inputs inside the linear band [-linthresh, linthresh], so
        # the logarithm below only contributes for the log-scaled inputs.
        masked = ma.masked_inside(a,
                                  -self.linthresh,
                                  self.linthresh,
                                  copy=False)
        # Log branch, with the sign of the input restored.
        log = sign * self.linthresh * (
            self._linscale_adj +
            ma.log(np.abs(masked) / self.linthresh) / self._log_base)
        if masked.mask.any():
            # Some inputs are in the linear band: use the linear mapping
            # there and the log branch everywhere else.
            return ma.where(masked.mask, a * self._linscale_adj, log)
        else:
            # No input is in the linear band.
            return log

    def inverted(self):
        """Return the matching :class:`InvertedSymmetricalLogTransform`."""
        return InvertedSymmetricalLogTransform(self.base, self.linthresh,
                                               self.linscale)
class InvertedSymmetricalLogTransform(Transform):
    """
    Inverse of :class:`SymmetricalLogTransform`.

    Transformed values with ``|a| <= invlinthresh`` (the image of
    *linthresh* under the forward transform) are divided by the linear
    slope.  All other values are mapped back through
    ``sign(a) * linthresh * base ** (|a| / linthresh - linscale_adj)``.
    """
    input_dims = 1
    output_dims = 1
    is_separable = True
    has_inverse = True

    def __init__(self, base, linthresh, linscale):
        """
        *base*, *linthresh*, *linscale*: same meaning as for
        :class:`SymmetricalLogTransform`.
        """
        Transform.__init__(self)
        # The forward transform is built here only to locate the edge of
        # the linear band in transformed coordinates.
        symlog = SymmetricalLogTransform(base, linthresh, linscale)
        self.base = base
        self.linthresh = linthresh
        self.invlinthresh = symlog.transform(linthresh)
        self.linscale = linscale
        # Same linear slope / log offset as the forward transform.
        self._linscale_adj = (linscale / (1.0 - self.base ** -1))

    def transform_non_affine(self, a):
        sign = np.sign(a)
        # Mask the transformed values that lie in the linear band.
        masked = ma.masked_inside(a, -self.invlinthresh,
                                  self.invlinthresh, copy=False)
        # Undo the log branch.  sign * (masked / linthresh) equals
        # |a| / linthresh.
        exp = sign * self.linthresh * (
            ma.power(self.base, (sign * (masked / self.linthresh))
                     - self._linscale_adj))
        if masked.mask.any():
            # Undo the linear branch where the value was in the band.
            return ma.where(masked.mask, a / self._linscale_adj, exp)
        else:
            return exp

    def inverted(self):
        """Return the matching forward :class:`SymmetricalLogTransform`."""
        return SymmetricalLogTransform(self.base,
                                       self.linthresh, self.linscale)
  285. class SymmetricalLogScale(ScaleBase):
  286. """
  287. The symmetrical logarithmic scale is logarithmic in both the
  288. positive and negative directions from the origin.
  289. Since the values close to zero tend toward infinity, there is a
  290. need to have a range around zero that is linear. The parameter
  291. *linthresh* allows the user to specify the size of this range
  292. (-*linthresh*, *linthresh*).
  293. """
  294. name = 'symlog'
  295. # compatibility shim
  296. SymmetricalLogTransform = SymmetricalLogTransform
  297. InvertedSymmetricalLogTransform = InvertedSymmetricalLogTransform
  298. def __init__(self, axis, **kwargs):
  299. """
  300. *basex*/*basey*:
  301. The base of the logarithm
  302. *linthreshx*/*linthreshy*:
  303. A single float which defines the range (-*x*, *x*), within
  304. which the plot is linear. This avoids having the plot go to
  305. infinity around zero.
  306. *subsx*/*subsy*:
  307. Where to place the subticks between each major tick.
  308. Should be a sequence of integers. For example, in a log10
  309. scale: ``[2, 3, 4, 5, 6, 7, 8, 9]``
  310. will place 8 logarithmically spaced minor ticks between
  311. each major tick.
  312. *linscalex*/*linscaley*:
  313. This allows the linear range (-*linthresh* to *linthresh*)
  314. to be stretched relative to the logarithmic range. Its
  315. value is the number of decades to use for each half of the
  316. linear range. For example, when *linscale* == 1.0 (the
  317. default), the space used for the positive and negative
  318. halves of the linear range will be equal to one decade in
  319. the logarithmic range.
  320. """
  321. if axis.axis_name == 'x':
  322. base = kwargs.pop('basex', 10.0)
  323. linthresh = kwargs.pop('linthreshx', 2.0)
  324. subs = kwargs.pop('subsx', None)
  325. linscale = kwargs.pop('linscalex', 1.0)
  326. else:
  327. base = kwargs.pop('basey', 10.0)
  328. linthresh = kwargs.pop('linthreshy', 2.0)
  329. subs = kwargs.pop('subsy', None)
  330. linscale = kwargs.pop('linscaley', 1.0)
  331. if base <= 1.0:
  332. raise ValueError("'basex/basey' must be larger than 1")
  333. if linthresh <= 0.0:
  334. raise ValueError("'linthreshx/linthreshy' must be positive")
  335. if linscale <= 0.0:
  336. raise ValueError("'linscalex/linthreshy' must be positive")
  337. self._transform = self.SymmetricalLogTransform(base,
  338. linthresh,
  339. linscale)
  340. self.base = base
  341. self.linthresh = linthresh
  342. self.linscale = linscale
  343. self.subs = subs
  344. def set_default_locators_and_formatters(self, axis):
  345. """
  346. Set the locators and formatters to specialized versions for
  347. symmetrical log scaling.
  348. """
  349. axis.set_major_locator(SymmetricalLogLocator(self.get_transform()))
  350. axis.set_major_formatter(LogFormatterSciNotation(self.base))
  351. axis.set_minor_locator(SymmetricalLogLocator(self.get_transform(),
  352. self.subs))
  353. axis.set_minor_formatter(NullFormatter())
  354. def get_transform(self):
  355. """
  356. Return a :class:`SymmetricalLogTransform` instance.
  357. """
  358. return self._transform
  359. class LogitTransform(Transform):
  360. input_dims = 1
  361. output_dims = 1
  362. is_separable = True
  363. has_inverse = True
  364. def __init__(self, nonpos='mask'):
  365. Transform.__init__(self)
  366. self._nonpos = nonpos
  367. self._clip = {"clip": True, "mask": False}[nonpos]
  368. def transform_non_affine(self, a):
  369. """logit transform (base 10), masked or clipped"""
  370. with np.errstate(divide="ignore", invalid="ignore"):
  371. out = np.log10(a / (1 - a))
  372. if self._clip: # See LogTransform for choice of clip value.
  373. out[a <= 0] = -1000
  374. out[1 <= a] = 1000
  375. return out
  376. def inverted(self):
  377. return LogisticTransform(self._nonpos)
  378. def __str__(self):
  379. return "{}({!r})".format(type(self).__name__,
  380. "clip" if self._clip else "mask")
  381. class LogisticTransform(Transform):
  382. input_dims = 1
  383. output_dims = 1
  384. is_separable = True
  385. has_inverse = True
  386. def __init__(self, nonpos='mask'):
  387. Transform.__init__(self)
  388. self._nonpos = nonpos
  389. def transform_non_affine(self, a):
  390. """logistic transform (base 10)"""
  391. return 1.0 / (1 + 10**(-a))
  392. def inverted(self):
  393. return LogitTransform(self._nonpos)
  394. def __str__(self):
  395. return "{}({!r})".format(type(self).__name__, self._nonpos)
  396. class LogitScale(ScaleBase):
  397. """
  398. Logit scale for data between zero and one, both excluded.
  399. This scale is similar to a log scale close to zero and to one, and almost
  400. linear around 0.5. It maps the interval ]0, 1[ onto ]-infty, +infty[.
  401. """
  402. name = 'logit'
  403. def __init__(self, axis, nonpos='mask'):
  404. """
  405. *nonpos*: {'mask', 'clip'}
  406. values beyond ]0, 1[ can be masked as invalid, or clipped to a number
  407. very close to 0 or 1
  408. """
  409. if nonpos not in ['mask', 'clip']:
  410. raise ValueError("nonposx, nonposy kwarg must be 'mask' or 'clip'")
  411. self._transform = LogitTransform(nonpos)
  412. def get_transform(self):
  413. """
  414. Return a :class:`LogitTransform` instance.
  415. """
  416. return self._transform
  417. def set_default_locators_and_formatters(self, axis):
  418. # ..., 0.01, 0.1, 0.5, 0.9, 0.99, ...
  419. axis.set_major_locator(LogitLocator())
  420. axis.set_major_formatter(LogitFormatter())
  421. axis.set_minor_locator(LogitLocator(minor=True))
  422. axis.set_minor_formatter(LogitFormatter())
  423. def limit_range_for_scale(self, vmin, vmax, minpos):
  424. """
  425. Limit the domain to values between 0 and 1 (excluded).
  426. """
  427. if not np.isfinite(minpos):
  428. minpos = 1e-7 # This value should rarely if ever
  429. # end up with a visible effect.
  430. return (minpos if vmin <= 0 else vmin,
  431. 1 - minpos if vmax >= 1 else vmax)
  432. _scale_mapping = {
  433. 'linear': LinearScale,
  434. 'log': LogScale,
  435. 'symlog': SymmetricalLogScale,
  436. 'logit': LogitScale,
  437. }
  438. def get_scale_names():
  439. return sorted(_scale_mapping)
  440. def scale_factory(scale, axis, **kwargs):
  441. """
  442. Return a scale class by name.
  443. ACCEPTS: [ %(names)s ]
  444. """
  445. scale = scale.lower()
  446. if scale is None:
  447. scale = 'linear'
  448. if scale not in _scale_mapping:
  449. raise ValueError("Unknown scale type '%s'" % scale)
  450. return _scale_mapping[scale](axis, **kwargs)
  451. scale_factory.__doc__ = cbook.dedent(scale_factory.__doc__) % \
  452. {'names': " | ".join(get_scale_names())}
  453. def register_scale(scale_class):
  454. """
  455. Register a new kind of scale.
  456. *scale_class* must be a subclass of :class:`ScaleBase`.
  457. """
  458. _scale_mapping[scale_class.name] = scale_class
  459. def get_scale_docs():
  460. """
  461. Helper function for generating docstrings related to scales.
  462. """
  463. docs = []
  464. for name in get_scale_names():
  465. scale_class = _scale_mapping[name]
  466. docs.append(" '%s'" % name)
  467. docs.append("")
  468. class_docs = cbook.dedent(scale_class.__init__.__doc__)
  469. class_docs = "".join([" %s\n" %
  470. x for x in class_docs.split("\n")])
  471. docs.append(class_docs)
  472. docs.append("")
  473. return "\n".join(docs)
  474. docstring.interpd.update(
  475. scale=' | '.join([repr(x) for x in get_scale_names()]),
  476. scale_docs=get_scale_docs().rstrip(),
  477. )