You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

287 lines
9.3 KiB

4 years ago
  1. import os
  2. import sys
  3. import re
  4. from glob import glob
  5. import matplotlib as mpl
  6. from jupyter_core.paths import jupyter_config_dir
  7. # path to install (~/.jupyter/custom/)
  8. jupyter_custom = os.path.join(jupyter_config_dir(), 'custom')
  9. # path to local site-packages/jupyterthemes
  10. package_dir = os.path.dirname(os.path.realpath(__file__))
  11. # theme colors, layout, and font directories
  12. styles_dir = os.path.join(package_dir, 'styles')
  13. # text file containing name of currently installed theme
  14. theme_name_file = os.path.join(jupyter_custom, 'current_theme.txt')
  15. # base style params
  16. base_style = {
  17. 'axes.axisbelow': True,
  18. 'figure.autolayout': True,
  19. 'grid.linestyle': u'-',
  20. 'lines.solid_capstyle': u'round',
  21. 'legend.frameon': False,
  22. "legend.numpoints": 1,
  23. "legend.scatterpoints": 1}
  24. # base context params
  25. base_context = {
  26. 'axes.linewidth': 1.4,
  27. "grid.linewidth": 1.4,
  28. "lines.linewidth": 1.5,
  29. "patch.linewidth": .2,
  30. "lines.markersize": 7,
  31. "lines.markeredgewidth": 0,
  32. "xtick.major.width": 1,
  33. "ytick.major.width": 1,
  34. "xtick.minor.width": .5,
  35. "ytick.minor.width": .5,
  36. "xtick.major.pad": 7,
  37. "ytick.major.pad": 7,
  38. "xtick.major.size": 0,
  39. "ytick.major.size": 0,
  40. "xtick.minor.size": 0,
  41. "ytick.minor.size": 0}
  42. # base font params
  43. base_font = {
  44. "font.size": 11,
  45. "axes.labelsize": 12,
  46. "axes.titlesize": 12,
  47. "xtick.labelsize": 10.5,
  48. "ytick.labelsize": 10.5,
  49. "legend.fontsize": 10.5}
  50. def remove_non_colors(clist):
  51. checkHex = r'^#(?:[0-9a-fA-F]{3}){1,2}$'
  52. return [clr for clr in clist if re.search(checkHex, clr)]
  53. def infer_theme():
  54. """ checks jupyter_config_dir() for text file containing theme name
  55. (updated whenever user installs a new theme)
  56. """
  57. themes = [os.path.basename(theme).replace('.less', '')
  58. for theme in glob('{0}/*.less'.format(styles_dir))]
  59. if os.path.exists(theme_name_file):
  60. with open(theme_name_file) as f:
  61. theme = f.readlines()[0]
  62. if theme not in themes:
  63. theme = 'default'
  64. else:
  65. theme = 'default'
  66. return theme
  67. def style(theme=None, context='paper', grid=True, gridlines=u'-', ticks=False, spines=True, fscale=1.2, figsize=(8., 7.)):
  68. """
  69. main function for styling matplotlib according to theme
  70. ::Arguments::
  71. theme (str): 'oceans16', 'grade3', 'chesterish', 'onedork', 'monokai', 'solarizedl', 'solarizedd'. If no theme name supplied the currently installed notebook theme will be used.
  72. context (str): 'paper' (Default), 'notebook', 'talk', or 'poster'
  73. grid (bool): removes axis grid lines if False
  74. gridlines (str): set grid linestyle (e.g., '--' for dashed grid)
  75. ticks (bool): make major x and y ticks visible
  76. spines (bool): removes x (bottom) and y (left) axis spines if False
  77. fscale (float): scale font size for axes labels, legend, etc.
  78. figsize (tuple): default figure size of matplotlib figures
  79. """
  80. # set context and font rc parameters, return rcdict
  81. rcdict = set_context(context=context, fscale=fscale, figsize=figsize)
  82. # read in theme name from ~/.jupyter/custom/current_theme.txt
  83. if theme is None:
  84. theme = infer_theme()
  85. # combine context & font rcparams with theme style
  86. set_style(rcdict, theme=theme, grid=grid, gridlines=gridlines, ticks=ticks, spines=spines)
  87. def set_style(rcdict, theme=None, grid=True, gridlines=u'-', ticks=False, spines=True):
  88. """
  89. This code has been modified from seaborn.rcmod.set_style()
  90. ::Arguments::
  91. rcdict (str): dict of "context" properties (filled by set_context())
  92. theme (str): name of theme to use when setting color properties
  93. grid (bool): turns off axis grid if False (default: True)
  94. ticks (bool): removes x,y axis ticks if True (default: False)
  95. spines (bool): removes axis spines if False (default: True)
  96. """
  97. # extract style and color info for theme
  98. styleMap, clist = get_theme_style(theme)
  99. # extract style variables
  100. figureFace = styleMap['figureFace']
  101. axisFace = styleMap['axisFace']
  102. textColor = styleMap['textColor']
  103. edgeColor = styleMap['edgeColor']
  104. gridColor = styleMap['gridColor']
  105. if not spines:
  106. edgeColor = 'none'
  107. style_dict = {
  108. 'figure.edgecolor': figureFace,
  109. 'figure.facecolor': figureFace,
  110. 'axes.facecolor': axisFace,
  111. 'axes.edgecolor': edgeColor,
  112. 'axes.labelcolor': textColor,
  113. 'axes.grid': grid,
  114. 'grid.linestyle': gridlines,
  115. 'grid.color': gridColor,
  116. 'text.color': textColor,
  117. 'xtick.color': textColor,
  118. 'ytick.color': textColor,
  119. 'patch.edgecolor': axisFace,
  120. 'patch.facecolor': gridColor,
  121. 'savefig.facecolor': figureFace,
  122. 'savefig.edgecolor': figureFace}
  123. # update rcdict with style params
  124. rcdict.update(style_dict)
  125. # Show or hide the axes ticks
  126. if ticks:
  127. rcdict.update({
  128. "xtick.major.size": 6,
  129. "ytick.major.size": 6,
  130. "xtick.minor.size": 3,
  131. "ytick.minor.size": 3})
  132. base_style.update(rcdict)
  133. # update matplotlib with rcdict (incl. context, font, & style)
  134. mpl.rcParams.update(rcdict)
  135. # update seaborn with rcdict (incl. context, font, & style)
  136. try:
  137. import seaborn as sns
  138. sns.set_style(rc=rcdict)
  139. except Exception:
  140. pass
  141. try:
  142. from cycler import cycler
  143. # set color cycle to jt-style color list
  144. mpl.rcParams['axes.prop_cycle'] = cycler(color=clist)
  145. except Exception:
  146. pass
  147. # replace default blue, green, etc. with jt colors
  148. for code, color in zip("bgrmyck", clist[:7]):
  149. rgb = mpl.colors.colorConverter.to_rgb(color)
  150. mpl.colors.colorConverter.colors[code] = rgb
  151. mpl.colors.colorConverter.cache[code] = rgb
  152. def set_context(context='paper', fscale=1., figsize=(8., 7.)):
  153. """
  154. Most of this code has been copied/modified from seaborn.rcmod.plotting_context()
  155. ::Arguments::
  156. context (str): 'paper', 'notebook', 'talk', or 'poster'
  157. fscale (float): font-size scalar applied to axes ticks, legend, labels, etc.
  158. """
  159. # scale all the parameters by the same factor depending on the context
  160. scaling = dict(paper=.8, notebook=1, talk=1.3, poster=1.6)[context]
  161. context_dict = {k: v * scaling for k, v in base_context.items()}
  162. # scale default figsize
  163. figX, figY = figsize
  164. context_dict["figure.figsize"] = (figX*scaling, figY*scaling)
  165. # independently scale the fonts
  166. font_dict = {k: v * fscale for k, v in base_font.items()}
  167. font_dict["font.family"] = ["sans-serif"]
  168. font_dict["font.sans-serif"] = ["Helvetica", "Helvetica Neue", "Arial",
  169. "DejaVu Sans", "Liberation Sans", "sans-serif"]
  170. context_dict.update(font_dict)
  171. return context_dict
  172. def figsize(x=8, y=7., aspect=1.):
  173. """ manually set the default figure size of plots
  174. ::Arguments::
  175. x (float): x-axis size
  176. y (float): y-axis size
  177. aspect (float): aspect ratio scalar
  178. """
  179. # update rcparams with adjusted figsize params
  180. mpl.rcParams.update({'figure.figsize': (x*aspect, y)})
  181. def get_theme_style(theme):
  182. """
  183. read-in theme style info and populate styleMap (dict of with mpl.rcParams)
  184. and clist (list of hex codes passed to color cylcler)
  185. ::Arguments::
  186. theme (str): theme name
  187. ::Returns::
  188. styleMap (dict): dict containing theme-specific colors for figure properties
  189. clist (list): list of colors to replace mpl's default color_cycle
  190. """
  191. styleMap, clist = get_default_jtstyle()
  192. if theme == 'default':
  193. return styleMap, clist
  194. syntaxVars = ['@yellow:', '@orange:', '@red:', '@magenta:', '@violet:', '@blue:', '@cyan:', '@green:']
  195. get_hex_code = lambda line: line.split(':')[-1].split(';')[0][-7:]
  196. themeFile = os.path.join(styles_dir, theme+'.less')
  197. with open(themeFile) as f:
  198. for line in f:
  199. for k, v in styleMap.items():
  200. if k in line.strip():
  201. styleMap[k] = get_hex_code(line)
  202. for c in syntaxVars:
  203. if c in line.strip():
  204. syntaxVars[syntaxVars.index(c)] = get_hex_code(line)
  205. # remove duplicate hexcolors
  206. syntaxVars = list(set(syntaxVars))
  207. clist.extend(syntaxVars)
  208. clist = remove_non_colors(clist)
  209. return styleMap, clist
  210. def get_default_jtstyle():
  211. styleMap = {'axisFace': 'white',
  212. 'figureFace': 'white',
  213. 'textColor': '.15',
  214. 'edgeColor': '.8',
  215. 'gridColor': '.8'}
  216. return styleMap, get_color_list()
  217. def get_color_list():
  218. return ['#3572C6', '#83a83b', '#c44e52', '#8172b2', "#ff914d",
  219. "#77BEDB", "#222222", "#4168B7", "#27ae60", "#e74c3c",'#bc89e0',
  220. "#ff711a", "#3498db", '#6C7A89']
  221. def reset():
  222. """ full reset of matplotlib default style and colors
  223. """
  224. colors = [(0., 0., 1.), (0., .5, 0.), (1., 0., 0.), (.75, .75, 0.),
  225. (.75, .75, 0.), (0., .75, .75), (0., 0., 0.)]
  226. for code, color in zip("bgrmyck", colors):
  227. rgb = mpl.colors.colorConverter.to_rgb(color)
  228. mpl.colors.colorConverter.colors[code] = rgb
  229. mpl.colors.colorConverter.cache[code] = rgb
  230. mpl.rcParams.update(mpl.rcParamsDefault)
  231. mpl.rcParams['figure.facecolor'] = 'white'
  232. mpl.rcParams['axes.facecolor'] = 'white'