You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

626 lines
24 KiB

4 years ago
  1. # -*- coding: utf-8 -*-
  2. # Natural Language Toolkit: ASCII visualization of NLTK trees
  3. #
  4. # Copyright (C) 2001-2019 NLTK Project
  5. # Author: Andreas van Cranenburgh <A.W.vanCranenburgh@uva.nl>
  6. # Peter Ljunglöf <peter.ljunglof@gu.se>
  7. # URL: <http://nltk.org/>
  8. # For license information, see LICENSE.TXT
  9. """
  10. Pretty-printing of discontinuous trees.
  11. Adapted from the disco-dop project, by Andreas van Cranenburgh.
  12. https://github.com/andreasvc/disco-dop
  13. Interesting reference (not used for this code):
  14. T. Eschbach et al., Orth. Hypergraph Drawing, Journal of
  15. Graph Algorithms and Applications, 10(2) 141--157 (2006)149.
  16. http://jgaa.info/accepted/2006/EschbachGuentherBecker2006.10.2.pdf
  17. """
from __future__ import division, print_function, unicode_literals

import re
from collections import defaultdict
from operator import itemgetter

try:
    from html import escape
except ImportError:  # Python 2: html.escape does not exist yet
    from cgi import escape

from nltk.util import OrderedDict
from nltk.compat import python_2_unicode_compatible
from nltk.tree import Tree
  26. ANSICOLOR = {
  27. 'black': 30,
  28. 'red': 31,
  29. 'green': 32,
  30. 'yellow': 33,
  31. 'blue': 34,
  32. 'magenta': 35,
  33. 'cyan': 36,
  34. 'white': 37,
  35. }
  36. @python_2_unicode_compatible
  37. class TreePrettyPrinter(object):
  38. """
  39. Pretty-print a tree in text format, either as ASCII or Unicode.
  40. The tree can be a normal tree, or discontinuous.
  41. ``TreePrettyPrinter(tree, sentence=None, highlight=())``
  42. creates an object from which different visualizations can be created.
  43. :param tree: a Tree object.
  44. :param sentence: a list of words (strings). If `sentence` is given,
  45. `tree` must contain integers as leaves, which are taken as indices
  46. in `sentence`. Using this you can display a discontinuous tree.
  47. :param highlight: Optionally, a sequence of Tree objects in `tree` which
  48. should be highlighted. Has the effect of only applying colors to nodes
  49. in this sequence (nodes should be given as Tree objects, terminals as
  50. indices).
  51. >>> from nltk.tree import Tree
  52. >>> tree = Tree.fromstring('(S (NP Mary) (VP walks))')
  53. >>> print(TreePrettyPrinter(tree).text())
  54. ... # doctest: +NORMALIZE_WHITESPACE
  55. S
  56. ____|____
  57. NP VP
  58. | |
  59. Mary walks
  60. """
  61. def __init__(self, tree, sentence=None, highlight=()):
  62. if sentence is None:
  63. leaves = tree.leaves()
  64. if (
  65. leaves
  66. and not any(len(a) == 0 for a in tree.subtrees())
  67. and all(isinstance(a, int) for a in leaves)
  68. ):
  69. sentence = [str(a) for a in leaves]
  70. else:
  71. # this deals with empty nodes (frontier non-terminals)
  72. # and multiple/mixed terminals under non-terminals.
  73. tree = tree.copy(True)
  74. sentence = []
  75. for a in tree.subtrees():
  76. if len(a) == 0:
  77. a.append(len(sentence))
  78. sentence.append(None)
  79. elif any(not isinstance(b, Tree) for b in a):
  80. for n, b in enumerate(a):
  81. if not isinstance(b, Tree):
  82. a[n] = len(sentence)
  83. if type(b) == tuple:
  84. b = '/'.join(b)
  85. sentence.append('%s' % b)
  86. self.nodes, self.coords, self.edges, self.highlight = self.nodecoords(
  87. tree, sentence, highlight
  88. )
  89. def __str__(self):
  90. return self.text()
  91. def __repr__(self):
  92. return '<TreePrettyPrinter with %d nodes>' % len(self.nodes)
  93. @staticmethod
  94. def nodecoords(tree, sentence, highlight):
  95. """
  96. Produce coordinates of nodes on a grid.
  97. Objective:
  98. - Produce coordinates for a non-overlapping placement of nodes and
  99. horizontal lines.
  100. - Order edges so that crossing edges cross a minimal number of previous
  101. horizontal lines (never vertical lines).
  102. Approach:
  103. - bottom up level order traversal (start at terminals)
  104. - at each level, identify nodes which cannot be on the same row
  105. - identify nodes which cannot be in the same column
  106. - place nodes into a grid at (row, column)
  107. - order child-parent edges with crossing edges last
  108. Coordinates are (row, column); the origin (0, 0) is at the top left;
  109. the root node is on row 0. Coordinates do not consider the size of a
  110. node (which depends on font, &c), so the width of a column of the grid
  111. should be automatically determined by the element with the greatest
  112. width in that column. Alternatively, the integer coordinates could be
  113. converted to coordinates in which the distances between adjacent nodes
  114. are non-uniform.
  115. Produces tuple (nodes, coords, edges, highlighted) where:
  116. - nodes[id]: Tree object for the node with this integer id
  117. - coords[id]: (n, m) coordinate where to draw node with id in the grid
  118. - edges[id]: parent id of node with this id (ordered dictionary)
  119. - highlighted: set of ids that should be highlighted
  120. """
  121. def findcell(m, matrix, startoflevel, children):
  122. """
  123. Find vacant row, column index for node ``m``.
  124. Iterate over current rows for this level (try lowest first)
  125. and look for cell between first and last child of this node,
  126. add new row to level if no free row available.
  127. """
  128. candidates = [a for _, a in children[m]]
  129. minidx, maxidx = min(candidates), max(candidates)
  130. leaves = tree[m].leaves()
  131. center = scale * sum(leaves) // len(leaves) # center of gravity
  132. if minidx < maxidx and not minidx < center < maxidx:
  133. center = sum(candidates) // len(candidates)
  134. if max(candidates) - min(candidates) > 2 * scale:
  135. center -= center % scale # round to unscaled coordinate
  136. if minidx < maxidx and not minidx < center < maxidx:
  137. center += scale
  138. if ids[m] == 0:
  139. startoflevel = len(matrix)
  140. for rowidx in range(startoflevel, len(matrix) + 1):
  141. if rowidx == len(matrix): # need to add a new row
  142. matrix.append(
  143. [
  144. vertline if a not in (corner, None) else None
  145. for a in matrix[-1]
  146. ]
  147. )
  148. row = matrix[rowidx]
  149. i = j = center
  150. if len(children[m]) == 1: # place unaries directly above child
  151. return rowidx, next(iter(children[m]))[1]
  152. elif all(
  153. a is None or a == vertline
  154. for a in row[min(candidates) : max(candidates) + 1]
  155. ):
  156. # find free column
  157. for n in range(scale):
  158. i = j = center + n
  159. while j > minidx or i < maxidx:
  160. if i < maxidx and (
  161. matrix[rowidx][i] is None or i in candidates
  162. ):
  163. return rowidx, i
  164. elif j > minidx and (
  165. matrix[rowidx][j] is None or j in candidates
  166. ):
  167. return rowidx, j
  168. i += scale
  169. j -= scale
  170. raise ValueError(
  171. 'could not find a free cell for:\n%s\n%s'
  172. 'min=%d; max=%d' % (tree[m], minidx, maxidx, dumpmatrix())
  173. )
  174. def dumpmatrix():
  175. """Dump matrix contents for debugging purposes."""
  176. return '\n'.join(
  177. '%2d: %s' % (n, ' '.join(('%2r' % i)[:2] for i in row))
  178. for n, row in enumerate(matrix)
  179. )
  180. leaves = tree.leaves()
  181. if not all(isinstance(n, int) for n in leaves):
  182. raise ValueError('All leaves must be integer indices.')
  183. if len(leaves) != len(set(leaves)):
  184. raise ValueError('Indices must occur at most once.')
  185. if not all(0 <= n < len(sentence) for n in leaves):
  186. raise ValueError(
  187. 'All leaves must be in the interval 0..n '
  188. 'with n=len(sentence)\ntokens: %d indices: '
  189. '%r\nsentence: %s' % (len(sentence), tree.leaves(), sentence)
  190. )
  191. vertline, corner = -1, -2 # constants
  192. tree = tree.copy(True)
  193. for a in tree.subtrees():
  194. a.sort(key=lambda n: min(n.leaves()) if isinstance(n, Tree) else n)
  195. scale = 2
  196. crossed = set()
  197. # internal nodes and lexical nodes (no frontiers)
  198. positions = tree.treepositions()
  199. maxdepth = max(map(len, positions)) + 1
  200. childcols = defaultdict(set)
  201. matrix = [[None] * (len(sentence) * scale)]
  202. nodes = {}
  203. ids = dict((a, n) for n, a in enumerate(positions))
  204. highlighted_nodes = set(
  205. n for a, n in ids.items() if not highlight or tree[a] in highlight
  206. )
  207. levels = dict((n, []) for n in range(maxdepth - 1))
  208. terminals = []
  209. for a in positions:
  210. node = tree[a]
  211. if isinstance(node, Tree):
  212. levels[maxdepth - node.height()].append(a)
  213. else:
  214. terminals.append(a)
  215. for n in levels:
  216. levels[n].sort(key=lambda n: max(tree[n].leaves()) - min(tree[n].leaves()))
  217. terminals.sort()
  218. positions = set(positions)
  219. for m in terminals:
  220. i = int(tree[m]) * scale
  221. assert matrix[0][i] is None, (matrix[0][i], m, i)
  222. matrix[0][i] = ids[m]
  223. nodes[ids[m]] = sentence[tree[m]]
  224. if nodes[ids[m]] is None:
  225. nodes[ids[m]] = '...'
  226. highlighted_nodes.discard(ids[m])
  227. positions.remove(m)
  228. childcols[m[:-1]].add((0, i))
  229. # add other nodes centered on their children,
  230. # if the center is already taken, back off
  231. # to the left and right alternately, until an empty cell is found.
  232. for n in sorted(levels, reverse=True):
  233. nodesatdepth = levels[n]
  234. startoflevel = len(matrix)
  235. matrix.append(
  236. [vertline if a not in (corner, None) else None for a in matrix[-1]]
  237. )
  238. for m in nodesatdepth: # [::-1]:
  239. if n < maxdepth - 1 and childcols[m]:
  240. _, pivot = min(childcols[m], key=itemgetter(1))
  241. if set(
  242. a[:-1]
  243. for row in matrix[:-1]
  244. for a in row[:pivot]
  245. if isinstance(a, tuple)
  246. ) & set(
  247. a[:-1]
  248. for row in matrix[:-1]
  249. for a in row[pivot:]
  250. if isinstance(a, tuple)
  251. ):
  252. crossed.add(m)
  253. rowidx, i = findcell(m, matrix, startoflevel, childcols)
  254. positions.remove(m)
  255. # block positions where children of this node branch out
  256. for _, x in childcols[m]:
  257. matrix[rowidx][x] = corner
  258. # assert m == () or matrix[rowidx][i] in (None, corner), (
  259. # matrix[rowidx][i], m, str(tree), ' '.join(sentence))
  260. # node itself
  261. matrix[rowidx][i] = ids[m]
  262. nodes[ids[m]] = tree[m]
  263. # add column to the set of children for its parent
  264. if m != ():
  265. childcols[m[:-1]].add((rowidx, i))
  266. assert len(positions) == 0
  267. # remove unused columns, right to left
  268. for m in range(scale * len(sentence) - 1, -1, -1):
  269. if not any(isinstance(row[m], (Tree, int)) for row in matrix):
  270. for row in matrix:
  271. del row[m]
  272. # remove unused rows, reverse
  273. matrix = [
  274. row
  275. for row in reversed(matrix)
  276. if not all(a is None or a == vertline for a in row)
  277. ]
  278. # collect coordinates of nodes
  279. coords = {}
  280. for n, _ in enumerate(matrix):
  281. for m, i in enumerate(matrix[n]):
  282. if isinstance(i, int) and i >= 0:
  283. coords[i] = n, m
  284. # move crossed edges last
  285. positions = sorted(
  286. [a for level in levels.values() for a in level],
  287. key=lambda a: a[:-1] in crossed,
  288. )
  289. # collect edges from node to node
  290. edges = OrderedDict()
  291. for i in reversed(positions):
  292. for j, _ in enumerate(tree[i]):
  293. edges[ids[i + (j,)]] = ids[i]
  294. return nodes, coords, edges, highlighted_nodes
  295. def text(
  296. self,
  297. nodedist=1,
  298. unicodelines=False,
  299. html=False,
  300. ansi=False,
  301. nodecolor='blue',
  302. leafcolor='red',
  303. funccolor='green',
  304. abbreviate=None,
  305. maxwidth=16,
  306. ):
  307. """
  308. :return: ASCII art for a discontinuous tree.
  309. :param unicodelines: whether to use Unicode line drawing characters
  310. instead of plain (7-bit) ASCII.
  311. :param html: whether to wrap output in html code (default plain text).
  312. :param ansi: whether to produce colors with ANSI escape sequences
  313. (only effective when html==False).
  314. :param leafcolor, nodecolor: specify colors of leaves and phrasal
  315. nodes; effective when either html or ansi is True.
  316. :param abbreviate: if True, abbreviate labels longer than 5 characters.
  317. If integer, abbreviate labels longer than `abbr` characters.
  318. :param maxwidth: maximum number of characters before a label starts to
  319. wrap; pass None to disable.
  320. """
  321. if abbreviate == True:
  322. abbreviate = 5
  323. if unicodelines:
  324. horzline = '\u2500'
  325. leftcorner = '\u250c'
  326. rightcorner = '\u2510'
  327. vertline = ' \u2502 '
  328. tee = horzline + '\u252C' + horzline
  329. bottom = horzline + '\u2534' + horzline
  330. cross = horzline + '\u253c' + horzline
  331. ellipsis = '\u2026'
  332. else:
  333. horzline = '_'
  334. leftcorner = rightcorner = ' '
  335. vertline = ' | '
  336. tee = 3 * horzline
  337. cross = bottom = '_|_'
  338. ellipsis = '.'
  339. def crosscell(cur, x=vertline):
  340. """Overwrite center of this cell with a vertical branch."""
  341. splitl = len(cur) - len(cur) // 2 - len(x) // 2 - 1
  342. lst = list(cur)
  343. lst[splitl : splitl + len(x)] = list(x)
  344. return ''.join(lst)
  345. result = []
  346. matrix = defaultdict(dict)
  347. maxnodewith = defaultdict(lambda: 3)
  348. maxnodeheight = defaultdict(lambda: 1)
  349. maxcol = 0
  350. minchildcol = {}
  351. maxchildcol = {}
  352. childcols = defaultdict(set)
  353. labels = {}
  354. wrapre = re.compile(
  355. '(.{%d,%d}\\b\\W*|.{%d})' % (maxwidth - 4, maxwidth, maxwidth)
  356. )
  357. # collect labels and coordinates
  358. for a in self.nodes:
  359. row, column = self.coords[a]
  360. matrix[row][column] = a
  361. maxcol = max(maxcol, column)
  362. label = (
  363. self.nodes[a].label()
  364. if isinstance(self.nodes[a], Tree)
  365. else self.nodes[a]
  366. )
  367. if abbreviate and len(label) > abbreviate:
  368. label = label[:abbreviate] + ellipsis
  369. if maxwidth and len(label) > maxwidth:
  370. label = wrapre.sub(r'\1\n', label).strip()
  371. label = label.split('\n')
  372. maxnodeheight[row] = max(maxnodeheight[row], len(label))
  373. maxnodewith[column] = max(maxnodewith[column], max(map(len, label)))
  374. labels[a] = label
  375. if a not in self.edges:
  376. continue # e.g., root
  377. parent = self.edges[a]
  378. childcols[parent].add((row, column))
  379. minchildcol[parent] = min(minchildcol.get(parent, column), column)
  380. maxchildcol[parent] = max(maxchildcol.get(parent, column), column)
  381. # bottom up level order traversal
  382. for row in sorted(matrix, reverse=True):
  383. noderows = [
  384. [''.center(maxnodewith[col]) for col in range(maxcol + 1)]
  385. for _ in range(maxnodeheight[row])
  386. ]
  387. branchrow = [''.center(maxnodewith[col]) for col in range(maxcol + 1)]
  388. for col in matrix[row]:
  389. n = matrix[row][col]
  390. node = self.nodes[n]
  391. text = labels[n]
  392. if isinstance(node, Tree):
  393. # draw horizontal branch towards children for this node
  394. if n in minchildcol and minchildcol[n] < maxchildcol[n]:
  395. i, j = minchildcol[n], maxchildcol[n]
  396. a, b = (maxnodewith[i] + 1) // 2 - 1, maxnodewith[j] // 2
  397. branchrow[i] = ((' ' * a) + leftcorner).ljust(
  398. maxnodewith[i], horzline
  399. )
  400. branchrow[j] = (rightcorner + (' ' * b)).rjust(
  401. maxnodewith[j], horzline
  402. )
  403. for i in range(minchildcol[n] + 1, maxchildcol[n]):
  404. if i == col and any(a == i for _, a in childcols[n]):
  405. line = cross
  406. elif i == col:
  407. line = bottom
  408. elif any(a == i for _, a in childcols[n]):
  409. line = tee
  410. else:
  411. line = horzline
  412. branchrow[i] = line.center(maxnodewith[i], horzline)
  413. else: # if n and n in minchildcol:
  414. branchrow[col] = crosscell(branchrow[col])
  415. text = [a.center(maxnodewith[col]) for a in text]
  416. color = nodecolor if isinstance(node, Tree) else leafcolor
  417. if isinstance(node, Tree) and node.label().startswith('-'):
  418. color = funccolor
  419. if html:
  420. text = [escape(a) for a in text]
  421. if n in self.highlight:
  422. text = ['<font color=%s>%s</font>' % (color, a) for a in text]
  423. elif ansi and n in self.highlight:
  424. text = ['\x1b[%d;1m%s\x1b[0m' % (ANSICOLOR[color], a) for a in text]
  425. for x in range(maxnodeheight[row]):
  426. # draw vertical lines in partially filled multiline node
  427. # labels, but only if it's not a frontier node.
  428. noderows[x][col] = (
  429. text[x]
  430. if x < len(text)
  431. else (vertline if childcols[n] else ' ').center(
  432. maxnodewith[col], ' '
  433. )
  434. )
  435. # for each column, if there is a node below us which has a parent
  436. # above us, draw a vertical branch in that column.
  437. if row != max(matrix):
  438. for n, (childrow, col) in self.coords.items():
  439. if n > 0 and self.coords[self.edges[n]][0] < row < childrow:
  440. branchrow[col] = crosscell(branchrow[col])
  441. if col not in matrix[row]:
  442. for noderow in noderows:
  443. noderow[col] = crosscell(noderow[col])
  444. branchrow = [
  445. a + ((a[-1] if a[-1] != ' ' else b[0]) * nodedist)
  446. for a, b in zip(branchrow, branchrow[1:] + [' '])
  447. ]
  448. result.append(''.join(branchrow))
  449. result.extend(
  450. (' ' * nodedist).join(noderow) for noderow in reversed(noderows)
  451. )
  452. return '\n'.join(reversed(result)) + '\n'
  453. def svg(self, nodecolor='blue', leafcolor='red', funccolor='green'):
  454. """
  455. :return: SVG representation of a tree.
  456. """
  457. fontsize = 12
  458. hscale = 40
  459. vscale = 25
  460. hstart = vstart = 20
  461. width = max(col for _, col in self.coords.values())
  462. height = max(row for row, _ in self.coords.values())
  463. result = [
  464. '<svg version="1.1" xmlns="http://www.w3.org/2000/svg" '
  465. 'width="%dem" height="%dem" viewBox="%d %d %d %d">'
  466. % (
  467. width * 3,
  468. height * 2.5,
  469. -hstart,
  470. -vstart,
  471. width * hscale + 3 * hstart,
  472. height * vscale + 3 * vstart,
  473. )
  474. ]
  475. children = defaultdict(set)
  476. for n in self.nodes:
  477. if n:
  478. children[self.edges[n]].add(n)
  479. # horizontal branches from nodes to children
  480. for node in self.nodes:
  481. if not children[node]:
  482. continue
  483. y, x = self.coords[node]
  484. x *= hscale
  485. y *= vscale
  486. x += hstart
  487. y += vstart + fontsize // 2
  488. childx = [self.coords[c][1] for c in children[node]]
  489. xmin = hstart + hscale * min(childx)
  490. xmax = hstart + hscale * max(childx)
  491. result.append(
  492. '\t<polyline style="stroke:black; stroke-width:1; fill:none;" '
  493. 'points="%g,%g %g,%g" />' % (xmin, y, xmax, y)
  494. )
  495. result.append(
  496. '\t<polyline style="stroke:black; stroke-width:1; fill:none;" '
  497. 'points="%g,%g %g,%g" />' % (x, y, x, y - fontsize // 3)
  498. )
  499. # vertical branches from children to parents
  500. for child, parent in self.edges.items():
  501. y, _ = self.coords[parent]
  502. y *= vscale
  503. y += vstart + fontsize // 2
  504. childy, childx = self.coords[child]
  505. childx *= hscale
  506. childy *= vscale
  507. childx += hstart
  508. childy += vstart - fontsize
  509. result += [
  510. '\t<polyline style="stroke:white; stroke-width:10; fill:none;"'
  511. ' points="%g,%g %g,%g" />' % (childx, childy, childx, y + 5),
  512. '\t<polyline style="stroke:black; stroke-width:1; fill:none;"'
  513. ' points="%g,%g %g,%g" />' % (childx, childy, childx, y),
  514. ]
  515. # write nodes with coordinates
  516. for n, (row, column) in self.coords.items():
  517. node = self.nodes[n]
  518. x = column * hscale + hstart
  519. y = row * vscale + vstart
  520. if n in self.highlight:
  521. color = nodecolor if isinstance(node, Tree) else leafcolor
  522. if isinstance(node, Tree) and node.label().startswith('-'):
  523. color = funccolor
  524. else:
  525. color = 'black'
  526. result += [
  527. '\t<text style="text-anchor: middle; fill: %s; '
  528. 'font-size: %dpx;" x="%g" y="%g">%s</text>'
  529. % (
  530. color,
  531. fontsize,
  532. x,
  533. y,
  534. escape(node.label() if isinstance(node, Tree) else node),
  535. )
  536. ]
  537. result += ['</svg>']
  538. return '\n'.join(result)
  539. def test():
  540. """Do some tree drawing tests."""
  541. def print_tree(n, tree, sentence=None, ansi=True, **xargs):
  542. print()
  543. print('{0}: "{1}"'.format(n, ' '.join(sentence or tree.leaves())))
  544. print(tree)
  545. print()
  546. drawtree = TreePrettyPrinter(tree, sentence)
  547. try:
  548. print(drawtree.text(unicodelines=ansi, ansi=ansi, **xargs))
  549. except (UnicodeDecodeError, UnicodeEncodeError):
  550. print(drawtree.text(unicodelines=False, ansi=False, **xargs))
  551. from nltk.corpus import treebank
  552. for n in [0, 1440, 1591, 2771, 2170]:
  553. tree = treebank.parsed_sents()[n]
  554. print_tree(n, tree, nodedist=2, maxwidth=8)
  555. print()
  556. print('ASCII version:')
  557. print(TreePrettyPrinter(tree).text(nodedist=2))
  558. tree = Tree.fromstring(
  559. '(top (punct 8) (smain (noun 0) (verb 1) (inf (verb 5) (inf (verb 6) '
  560. '(conj (inf (pp (prep 2) (np (det 3) (noun 4))) (verb 7)) (inf (verb 9)) '
  561. '(vg 10) (inf (verb 11)))))) (punct 12))',
  562. read_leaf=int,
  563. )
  564. sentence = (
  565. 'Ze had met haar moeder kunnen gaan winkelen ,'
  566. ' zwemmen of terrassen .'.split()
  567. )
  568. print_tree('Discontinuous tree', tree, sentence, nodedist=2)
  569. __all__ = ['TreePrettyPrinter']
  570. if __name__ == '__main__':
  571. test()