You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

217 lines
7.8 KiB

4 years ago
  1. from __future__ import absolute_import
  2. import os
  3. import unittest
  4. import tempfile
  5. from .Compiler import Errors
  6. from .CodeWriter import CodeWriter
  7. from .Compiler.TreeFragment import TreeFragment, strip_common_indent
  8. from .Compiler.Visitor import TreeVisitor, VisitorTransform
  9. from .Compiler import TreePath
  10. class NodeTypeWriter(TreeVisitor):
  11. def __init__(self):
  12. super(NodeTypeWriter, self).__init__()
  13. self._indents = 0
  14. self.result = []
  15. def visit_Node(self, node):
  16. if not self.access_path:
  17. name = u"(root)"
  18. else:
  19. tip = self.access_path[-1]
  20. if tip[2] is not None:
  21. name = u"%s[%d]" % tip[1:3]
  22. else:
  23. name = tip[1]
  24. self.result.append(u" " * self._indents +
  25. u"%s: %s" % (name, node.__class__.__name__))
  26. self._indents += 1
  27. self.visitchildren(node)
  28. self._indents -= 1
  29. def treetypes(root):
  30. """Returns a string representing the tree by class names.
  31. There's a leading and trailing whitespace so that it can be
  32. compared by simple string comparison while still making test
  33. cases look ok."""
  34. w = NodeTypeWriter()
  35. w.visit(root)
  36. return u"\n".join([u""] + w.result + [u""])
  37. class CythonTest(unittest.TestCase):
  38. def setUp(self):
  39. self.listing_file = Errors.listing_file
  40. self.echo_file = Errors.echo_file
  41. Errors.listing_file = Errors.echo_file = None
  42. def tearDown(self):
  43. Errors.listing_file = self.listing_file
  44. Errors.echo_file = self.echo_file
  45. def assertLines(self, expected, result):
  46. "Checks that the given strings or lists of strings are equal line by line"
  47. if not isinstance(expected, list):
  48. expected = expected.split(u"\n")
  49. if not isinstance(result, list):
  50. result = result.split(u"\n")
  51. for idx, (expected_line, result_line) in enumerate(zip(expected, result)):
  52. self.assertEqual(expected_line, result_line,
  53. "Line %d:\nExp: %s\nGot: %s" % (idx, expected_line, result_line))
  54. self.assertEqual(len(expected), len(result),
  55. "Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(expected), u"\n".join(result)))
  56. def codeToLines(self, tree):
  57. writer = CodeWriter()
  58. writer.write(tree)
  59. return writer.result.lines
  60. def codeToString(self, tree):
  61. return "\n".join(self.codeToLines(tree))
  62. def assertCode(self, expected, result_tree):
  63. result_lines = self.codeToLines(result_tree)
  64. expected_lines = strip_common_indent(expected.split("\n"))
  65. for idx, (line, expected_line) in enumerate(zip(result_lines, expected_lines)):
  66. self.assertEqual(expected_line, line,
  67. "Line %d:\nGot: %s\nExp: %s" % (idx, line, expected_line))
  68. self.assertEqual(len(result_lines), len(expected_lines),
  69. "Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(result_lines), expected))
  70. def assertNodeExists(self, path, result_tree):
  71. self.assertNotEqual(TreePath.find_first(result_tree, path), None,
  72. "Path '%s' not found in result tree" % path)
  73. def fragment(self, code, pxds=None, pipeline=None):
  74. "Simply create a tree fragment using the name of the test-case in parse errors."
  75. if pxds is None:
  76. pxds = {}
  77. if pipeline is None:
  78. pipeline = []
  79. name = self.id()
  80. if name.startswith("__main__."):
  81. name = name[len("__main__."):]
  82. name = name.replace(".", "_")
  83. return TreeFragment(code, name, pxds, pipeline=pipeline)
  84. def treetypes(self, root):
  85. return treetypes(root)
  86. def should_fail(self, func, exc_type=Exception):
  87. """Calls "func" and fails if it doesn't raise the right exception
  88. (any exception by default). Also returns the exception in question.
  89. """
  90. try:
  91. func()
  92. self.fail("Expected an exception of type %r" % exc_type)
  93. except exc_type as e:
  94. self.assertTrue(isinstance(e, exc_type))
  95. return e
  96. def should_not_fail(self, func):
  97. """Calls func and succeeds if and only if no exception is raised
  98. (i.e. converts exception raising into a failed testcase). Returns
  99. the return value of func."""
  100. try:
  101. return func()
  102. except Exception as exc:
  103. self.fail(str(exc))
  104. class TransformTest(CythonTest):
  105. """
  106. Utility base class for transform unit tests. It is based around constructing
  107. test trees (either explicitly or by parsing a Cython code string); running
  108. the transform, serialize it using a customized Cython serializer (with
  109. special markup for nodes that cannot be represented in Cython),
  110. and do a string-comparison line-by-line of the result.
  111. To create a test case:
  112. - Call run_pipeline. The pipeline should at least contain the transform you
  113. are testing; pyx should be either a string (passed to the parser to
  114. create a post-parse tree) or a node representing input to pipeline.
  115. The result will be a transformed result.
  116. - Check that the tree is correct. If wanted, assertCode can be used, which
  117. takes a code string as expected, and a ModuleNode in result_tree
  118. (it serializes the ModuleNode to a string and compares line-by-line).
  119. All code strings are first stripped for whitespace lines and then common
  120. indentation.
  121. Plans: One could have a pxd dictionary parameter to run_pipeline.
  122. """
  123. def run_pipeline(self, pipeline, pyx, pxds=None):
  124. if pxds is None:
  125. pxds = {}
  126. tree = self.fragment(pyx, pxds).root
  127. # Run pipeline
  128. for T in pipeline:
  129. tree = T(tree)
  130. return tree
  131. class TreeAssertVisitor(VisitorTransform):
  132. # actually, a TreeVisitor would be enough, but this needs to run
  133. # as part of the compiler pipeline
  134. def visit_CompilerDirectivesNode(self, node):
  135. directives = node.directives
  136. if 'test_assert_path_exists' in directives:
  137. for path in directives['test_assert_path_exists']:
  138. if TreePath.find_first(node, path) is None:
  139. Errors.error(
  140. node.pos,
  141. "Expected path '%s' not found in result tree" % path)
  142. if 'test_fail_if_path_exists' in directives:
  143. for path in directives['test_fail_if_path_exists']:
  144. if TreePath.find_first(node, path) is not None:
  145. Errors.error(
  146. node.pos,
  147. "Unexpected path '%s' found in result tree" % path)
  148. self.visitchildren(node)
  149. return node
  150. visit_Node = VisitorTransform.recurse_to_children
  151. def unpack_source_tree(tree_file, dir=None):
  152. if dir is None:
  153. dir = tempfile.mkdtemp()
  154. header = []
  155. cur_file = None
  156. f = open(tree_file)
  157. try:
  158. lines = f.readlines()
  159. finally:
  160. f.close()
  161. del f
  162. try:
  163. for line in lines:
  164. if line[:5] == '#####':
  165. filename = line.strip().strip('#').strip().replace('/', os.path.sep)
  166. path = os.path.join(dir, filename)
  167. if not os.path.exists(os.path.dirname(path)):
  168. os.makedirs(os.path.dirname(path))
  169. if cur_file is not None:
  170. f, cur_file = cur_file, None
  171. f.close()
  172. cur_file = open(path, 'w')
  173. elif cur_file is not None:
  174. cur_file.write(line)
  175. elif line.strip() and not line.lstrip().startswith('#'):
  176. if line.strip() not in ('"""', "'''"):
  177. header.append(line)
  178. finally:
  179. if cur_file is not None:
  180. cur_file.close()
  181. return dir, ''.join(header)