You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

508 lines
18 KiB

4 years ago
"""
lxml-based doctest output comparison.

Note: normally, you should just import the `lxml.usedoctest` and
`lxml.html.usedoctest` modules from within a doctest, instead of this
one::

    >>> import lxml.usedoctest # for XML output

    >>> import lxml.html.usedoctest # for HTML output

To use this module directly, you must call ``lxmldoctest.install()``,
which will cause doctest to use this in all subsequent calls.

This changes the way output is checked and comparisons are made for
XML or HTML-like content.

XML or HTML content is noticed because the example starts with ``<``
(it's HTML if it starts with ``<html``).  You can also use the
``PARSE_HTML`` and ``PARSE_XML`` flags to force parsing.

Some rough wildcard-like things are allowed.  Whitespace is generally
ignored (except in attributes).  In text (attributes and text in the
body) you can use ``...`` as a wildcard.  In an example it also
matches any trailing tags in the element, though it does not match
leading tags.  You may create a tag ``<any>`` or include an ``any``
attribute in the tag.  An ``any`` tag matches any tag, while the
attribute matches any and all attributes.

When a match fails, the reformatted example and the actual output are
displayed (indented), and a rough diff-like output is given.  Anything
marked with ``+`` is in the output but wasn't supposed to be, and
similarly ``-`` means it's in the example but wasn't in the output.

You can disable parsing on one line with ``# doctest:+NOPARSE_MARKUP``
"""
  28. from lxml import etree
  29. import sys
  30. import re
  31. import doctest
  32. try:
  33. from html import escape as html_escape
  34. except ImportError:
  35. from cgi import escape as html_escape
  36. __all__ = ['PARSE_HTML', 'PARSE_XML', 'NOPARSE_MARKUP', 'LXMLOutputChecker',
  37. 'LHTMLOutputChecker', 'install', 'temp_install']
  38. try:
  39. _basestring = basestring
  40. except NameError:
  41. _basestring = (str, bytes)
  42. _IS_PYTHON_3 = sys.version_info[0] >= 3
  43. PARSE_HTML = doctest.register_optionflag('PARSE_HTML')
  44. PARSE_XML = doctest.register_optionflag('PARSE_XML')
  45. NOPARSE_MARKUP = doctest.register_optionflag('NOPARSE_MARKUP')
  46. OutputChecker = doctest.OutputChecker
  47. def strip(v):
  48. if v is None:
  49. return None
  50. else:
  51. return v.strip()
  52. def norm_whitespace(v):
  53. return _norm_whitespace_re.sub(' ', v)
  54. _html_parser = etree.HTMLParser(recover=False, remove_blank_text=True)
  55. def html_fromstring(html):
  56. return etree.fromstring(html, _html_parser)
  57. # We use this to distinguish repr()s from elements:
  58. _repr_re = re.compile(r'^<[^>]+ (at|object) ')
  59. _norm_whitespace_re = re.compile(r'[ \t\n][ \t\n]+')
  60. class LXMLOutputChecker(OutputChecker):
  61. empty_tags = (
  62. 'param', 'img', 'area', 'br', 'basefont', 'input',
  63. 'base', 'meta', 'link', 'col')
  64. def get_default_parser(self):
  65. return etree.XML
  66. def check_output(self, want, got, optionflags):
  67. alt_self = getattr(self, '_temp_override_self', None)
  68. if alt_self is not None:
  69. super_method = self._temp_call_super_check_output
  70. self = alt_self
  71. else:
  72. super_method = OutputChecker.check_output
  73. parser = self.get_parser(want, got, optionflags)
  74. if not parser:
  75. return super_method(
  76. self, want, got, optionflags)
  77. try:
  78. want_doc = parser(want)
  79. except etree.XMLSyntaxError:
  80. return False
  81. try:
  82. got_doc = parser(got)
  83. except etree.XMLSyntaxError:
  84. return False
  85. return self.compare_docs(want_doc, got_doc)
  86. def get_parser(self, want, got, optionflags):
  87. parser = None
  88. if NOPARSE_MARKUP & optionflags:
  89. return None
  90. if PARSE_HTML & optionflags:
  91. parser = html_fromstring
  92. elif PARSE_XML & optionflags:
  93. parser = etree.XML
  94. elif (want.strip().lower().startswith('<html')
  95. and got.strip().startswith('<html')):
  96. parser = html_fromstring
  97. elif (self._looks_like_markup(want)
  98. and self._looks_like_markup(got)):
  99. parser = self.get_default_parser()
  100. return parser
  101. def _looks_like_markup(self, s):
  102. s = s.strip()
  103. return (s.startswith('<')
  104. and not _repr_re.search(s))
  105. def compare_docs(self, want, got):
  106. if not self.tag_compare(want.tag, got.tag):
  107. return False
  108. if not self.text_compare(want.text, got.text, True):
  109. return False
  110. if not self.text_compare(want.tail, got.tail, True):
  111. return False
  112. if 'any' not in want.attrib:
  113. want_keys = sorted(want.attrib.keys())
  114. got_keys = sorted(got.attrib.keys())
  115. if want_keys != got_keys:
  116. return False
  117. for key in want_keys:
  118. if not self.text_compare(want.attrib[key], got.attrib[key], False):
  119. return False
  120. if want.text != '...' or len(want):
  121. want_children = list(want)
  122. got_children = list(got)
  123. while want_children or got_children:
  124. if not want_children or not got_children:
  125. return False
  126. want_first = want_children.pop(0)
  127. got_first = got_children.pop(0)
  128. if not self.compare_docs(want_first, got_first):
  129. return False
  130. if not got_children and want_first.tail == '...':
  131. break
  132. return True
  133. def text_compare(self, want, got, strip):
  134. want = want or ''
  135. got = got or ''
  136. if strip:
  137. want = norm_whitespace(want).strip()
  138. got = norm_whitespace(got).strip()
  139. want = '^%s$' % re.escape(want)
  140. want = want.replace(r'\.\.\.', '.*')
  141. if re.search(want, got):
  142. return True
  143. else:
  144. return False
  145. def tag_compare(self, want, got):
  146. if want == 'any':
  147. return True
  148. if (not isinstance(want, _basestring)
  149. or not isinstance(got, _basestring)):
  150. return want == got
  151. want = want or ''
  152. got = got or ''
  153. if want.startswith('{...}'):
  154. # Ellipsis on the namespace
  155. return want.split('}')[-1] == got.split('}')[-1]
  156. else:
  157. return want == got
  158. def output_difference(self, example, got, optionflags):
  159. want = example.want
  160. parser = self.get_parser(want, got, optionflags)
  161. errors = []
  162. if parser is not None:
  163. try:
  164. want_doc = parser(want)
  165. except etree.XMLSyntaxError:
  166. e = sys.exc_info()[1]
  167. errors.append('In example: %s' % e)
  168. try:
  169. got_doc = parser(got)
  170. except etree.XMLSyntaxError:
  171. e = sys.exc_info()[1]
  172. errors.append('In actual output: %s' % e)
  173. if parser is None or errors:
  174. value = OutputChecker.output_difference(
  175. self, example, got, optionflags)
  176. if errors:
  177. errors.append(value)
  178. return '\n'.join(errors)
  179. else:
  180. return value
  181. html = parser is html_fromstring
  182. diff_parts = []
  183. diff_parts.append('Expected:')
  184. diff_parts.append(self.format_doc(want_doc, html, 2))
  185. diff_parts.append('Got:')
  186. diff_parts.append(self.format_doc(got_doc, html, 2))
  187. diff_parts.append('Diff:')
  188. diff_parts.append(self.collect_diff(want_doc, got_doc, html, 2))
  189. return '\n'.join(diff_parts)
  190. def html_empty_tag(self, el, html=True):
  191. if not html:
  192. return False
  193. if el.tag not in self.empty_tags:
  194. return False
  195. if el.text or len(el):
  196. # This shouldn't happen (contents in an empty tag)
  197. return False
  198. return True
  199. def format_doc(self, doc, html, indent, prefix=''):
  200. parts = []
  201. if not len(doc):
  202. # No children...
  203. parts.append(' '*indent)
  204. parts.append(prefix)
  205. parts.append(self.format_tag(doc))
  206. if not self.html_empty_tag(doc, html):
  207. if strip(doc.text):
  208. parts.append(self.format_text(doc.text))
  209. parts.append(self.format_end_tag(doc))
  210. if strip(doc.tail):
  211. parts.append(self.format_text(doc.tail))
  212. parts.append('\n')
  213. return ''.join(parts)
  214. parts.append(' '*indent)
  215. parts.append(prefix)
  216. parts.append(self.format_tag(doc))
  217. if not self.html_empty_tag(doc, html):
  218. parts.append('\n')
  219. if strip(doc.text):
  220. parts.append(' '*indent)
  221. parts.append(self.format_text(doc.text))
  222. parts.append('\n')
  223. for el in doc:
  224. parts.append(self.format_doc(el, html, indent+2))
  225. parts.append(' '*indent)
  226. parts.append(self.format_end_tag(doc))
  227. parts.append('\n')
  228. if strip(doc.tail):
  229. parts.append(' '*indent)
  230. parts.append(self.format_text(doc.tail))
  231. parts.append('\n')
  232. return ''.join(parts)
  233. def format_text(self, text, strip=True):
  234. if text is None:
  235. return ''
  236. if strip:
  237. text = text.strip()
  238. return html_escape(text, 1)
  239. def format_tag(self, el):
  240. attrs = []
  241. if isinstance(el, etree.CommentBase):
  242. # FIXME: probably PIs should be handled specially too?
  243. return '<!--'
  244. for name, value in sorted(el.attrib.items()):
  245. attrs.append('%s="%s"' % (name, self.format_text(value, False)))
  246. if not attrs:
  247. return '<%s>' % el.tag
  248. return '<%s %s>' % (el.tag, ' '.join(attrs))
  249. def format_end_tag(self, el):
  250. if isinstance(el, etree.CommentBase):
  251. # FIXME: probably PIs should be handled specially too?
  252. return '-->'
  253. return '</%s>' % el.tag
  254. def collect_diff(self, want, got, html, indent):
  255. parts = []
  256. if not len(want) and not len(got):
  257. parts.append(' '*indent)
  258. parts.append(self.collect_diff_tag(want, got))
  259. if not self.html_empty_tag(got, html):
  260. parts.append(self.collect_diff_text(want.text, got.text))
  261. parts.append(self.collect_diff_end_tag(want, got))
  262. parts.append(self.collect_diff_text(want.tail, got.tail))
  263. parts.append('\n')
  264. return ''.join(parts)
  265. parts.append(' '*indent)
  266. parts.append(self.collect_diff_tag(want, got))
  267. parts.append('\n')
  268. if strip(want.text) or strip(got.text):
  269. parts.append(' '*indent)
  270. parts.append(self.collect_diff_text(want.text, got.text))
  271. parts.append('\n')
  272. want_children = list(want)
  273. got_children = list(got)
  274. while want_children or got_children:
  275. if not want_children:
  276. parts.append(self.format_doc(got_children.pop(0), html, indent+2, '+'))
  277. continue
  278. if not got_children:
  279. parts.append(self.format_doc(want_children.pop(0), html, indent+2, '-'))
  280. continue
  281. parts.append(self.collect_diff(
  282. want_children.pop(0), got_children.pop(0), html, indent+2))
  283. parts.append(' '*indent)
  284. parts.append(self.collect_diff_end_tag(want, got))
  285. parts.append('\n')
  286. if strip(want.tail) or strip(got.tail):
  287. parts.append(' '*indent)
  288. parts.append(self.collect_diff_text(want.tail, got.tail))
  289. parts.append('\n')
  290. return ''.join(parts)
  291. def collect_diff_tag(self, want, got):
  292. if not self.tag_compare(want.tag, got.tag):
  293. tag = '%s (got: %s)' % (want.tag, got.tag)
  294. else:
  295. tag = got.tag
  296. attrs = []
  297. any = want.tag == 'any' or 'any' in want.attrib
  298. for name, value in sorted(got.attrib.items()):
  299. if name not in want.attrib and not any:
  300. attrs.append('+%s="%s"' % (name, self.format_text(value, False)))
  301. else:
  302. if name in want.attrib:
  303. text = self.collect_diff_text(want.attrib[name], value, False)
  304. else:
  305. text = self.format_text(value, False)
  306. attrs.append('%s="%s"' % (name, text))
  307. if not any:
  308. for name, value in sorted(want.attrib.items()):
  309. if name in got.attrib:
  310. continue
  311. attrs.append('-%s="%s"' % (name, self.format_text(value, False)))
  312. if attrs:
  313. tag = '<%s %s>' % (tag, ' '.join(attrs))
  314. else:
  315. tag = '<%s>' % tag
  316. return tag
  317. def collect_diff_end_tag(self, want, got):
  318. if want.tag != got.tag:
  319. tag = '%s (got: %s)' % (want.tag, got.tag)
  320. else:
  321. tag = got.tag
  322. return '</%s>' % tag
  323. def collect_diff_text(self, want, got, strip=True):
  324. if self.text_compare(want, got, strip):
  325. if not got:
  326. return ''
  327. return self.format_text(got, strip)
  328. text = '%s (got: %s)' % (want, got)
  329. return self.format_text(text, strip)
  330. class LHTMLOutputChecker(LXMLOutputChecker):
  331. def get_default_parser(self):
  332. return html_fromstring
  333. def install(html=False):
  334. """
  335. Install doctestcompare for all future doctests.
  336. If html is true, then by default the HTML parser will be used;
  337. otherwise the XML parser is used.
  338. """
  339. if html:
  340. doctest.OutputChecker = LHTMLOutputChecker
  341. else:
  342. doctest.OutputChecker = LXMLOutputChecker
  343. def temp_install(html=False, del_module=None):
  344. """
  345. Use this *inside* a doctest to enable this checker for this
  346. doctest only.
  347. If html is true, then by default the HTML parser will be used;
  348. otherwise the XML parser is used.
  349. """
  350. if html:
  351. Checker = LHTMLOutputChecker
  352. else:
  353. Checker = LXMLOutputChecker
  354. frame = _find_doctest_frame()
  355. dt_self = frame.f_locals['self']
  356. checker = Checker()
  357. old_checker = dt_self._checker
  358. dt_self._checker = checker
  359. # The unfortunate thing is that there is a local variable 'check'
  360. # in the function that runs the doctests, that is a bound method
  361. # into the output checker. We have to update that. We can't
  362. # modify the frame, so we have to modify the object in place. The
  363. # only way to do this is to actually change the func_code
  364. # attribute of the method. We change it, and then wait for
  365. # __record_outcome to be run, which signals the end of the __run
  366. # method, at which point we restore the previous check_output
  367. # implementation.
  368. if _IS_PYTHON_3:
  369. check_func = frame.f_locals['check'].__func__
  370. checker_check_func = checker.check_output.__func__
  371. else:
  372. check_func = frame.f_locals['check'].im_func
  373. checker_check_func = checker.check_output.im_func
  374. # Because we can't patch up func_globals, this is the only global
  375. # in check_output that we care about:
  376. doctest.etree = etree
  377. _RestoreChecker(dt_self, old_checker, checker,
  378. check_func, checker_check_func,
  379. del_module)
  380. class _RestoreChecker(object):
  381. def __init__(self, dt_self, old_checker, new_checker, check_func, clone_func,
  382. del_module):
  383. self.dt_self = dt_self
  384. self.checker = old_checker
  385. self.checker._temp_call_super_check_output = self.call_super
  386. self.checker._temp_override_self = new_checker
  387. self.check_func = check_func
  388. self.clone_func = clone_func
  389. self.del_module = del_module
  390. self.install_clone()
  391. self.install_dt_self()
  392. def install_clone(self):
  393. if _IS_PYTHON_3:
  394. self.func_code = self.check_func.__code__
  395. self.func_globals = self.check_func.__globals__
  396. self.check_func.__code__ = self.clone_func.__code__
  397. else:
  398. self.func_code = self.check_func.func_code
  399. self.func_globals = self.check_func.func_globals
  400. self.check_func.func_code = self.clone_func.func_code
  401. def uninstall_clone(self):
  402. if _IS_PYTHON_3:
  403. self.check_func.__code__ = self.func_code
  404. else:
  405. self.check_func.func_code = self.func_code
  406. def install_dt_self(self):
  407. self.prev_func = self.dt_self._DocTestRunner__record_outcome
  408. self.dt_self._DocTestRunner__record_outcome = self
  409. def uninstall_dt_self(self):
  410. self.dt_self._DocTestRunner__record_outcome = self.prev_func
  411. def uninstall_module(self):
  412. if self.del_module:
  413. import sys
  414. del sys.modules[self.del_module]
  415. if '.' in self.del_module:
  416. package, module = self.del_module.rsplit('.', 1)
  417. package_mod = sys.modules[package]
  418. delattr(package_mod, module)
  419. def __call__(self, *args, **kw):
  420. self.uninstall_clone()
  421. self.uninstall_dt_self()
  422. del self.checker._temp_override_self
  423. del self.checker._temp_call_super_check_output
  424. result = self.prev_func(*args, **kw)
  425. self.uninstall_module()
  426. return result
  427. def call_super(self, *args, **kw):
  428. self.uninstall_clone()
  429. try:
  430. return self.check_func(*args, **kw)
  431. finally:
  432. self.install_clone()
  433. def _find_doctest_frame():
  434. import sys
  435. frame = sys._getframe(1)
  436. while frame:
  437. l = frame.f_locals
  438. if 'BOOM' in l:
  439. # Sign of doctest
  440. return frame
  441. frame = frame.f_back
  442. raise LookupError(
  443. "Could not find doctest (only use this function *inside* a doctest)")
  444. __test__ = {
  445. 'basic': '''
  446. >>> temp_install()
  447. >>> print """<xml a="1" b="2">stuff</xml>"""
  448. <xml b="2" a="1">...</xml>
  449. >>> print """<xml xmlns="http://example.com"><tag attr="bar" /></xml>"""
  450. <xml xmlns="...">
  451. <tag attr="..." />
  452. </xml>
  453. >>> print """<xml>blahblahblah<foo /></xml>""" # doctest: +NOPARSE_MARKUP, +ELLIPSIS
  454. <xml>...foo /></xml>
  455. '''}
  456. if __name__ == '__main__':
  457. import doctest
  458. doctest.testmod()