You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

150 lines
4.8 KiB

4 years ago
  1. import contextlib
  2. import json
  3. import os
  4. import shutil
  5. import sys
  6. import tempfile
# Public API of this module.
__all__ = ['MockCommand', 'assert_calls']

# Directory containing this module; MockCommand._copy_exe() looks for the
# bundled cli-32.exe / cli-64.exe launchers here.
pkgdir = os.path.dirname(__file__)

# Temporary directory shared by every MockCommand for its recording file.
# Created lazily by the first MockCommand() (see its __init__); no cleanup of
# it is visible in this file.
recording_dir = None
  10. def prepend_to_path(dir):
  11. os.environ['PATH'] = dir + os.pathsep + os.environ['PATH']
  12. def remove_from_path(dir):
  13. path_dirs = os.environ['PATH'].split(os.pathsep)
  14. path_dirs.remove(dir)
  15. os.environ['PATH'] = os.pathsep.join(path_dirs)
  16. _record_run = """#!{python}
  17. import os, sys
  18. import json
  19. with open({recording_file!r}, 'a') as f:
  20. json.dump({{'env': dict(os.environ),
  21. 'argv': sys.argv,
  22. 'cwd': os.getcwd()}},
  23. f)
  24. f.write('\\x1e') # ASCII record separator
  25. """
  26. # TODO: Overlapping calls to the same command may interleave writes.
  27. class MockCommand(object):
  28. """Context manager to mock a system command.
  29. The mock command will be written to a directory at the front of $PATH,
  30. taking precedence over any existing command with the same name.
  31. By specifying content as a string, you can determine what running the
  32. command will do. The default content records each time the command is
  33. called and exits: you can access these records with mockcmd.get_calls().
  34. On Windows, the specified content will be run by the Python interpreter in
  35. use. On Unix, it should start with a shebang (``#!/path/to/interpreter``).
  36. """
  37. def __init__(self, name, content=None):
  38. global recording_dir
  39. self.name = name
  40. self.content = content
  41. if recording_dir is None:
  42. recording_dir = tempfile.mkdtemp()
  43. fd, self.recording_file = tempfile.mkstemp(dir=recording_dir,
  44. prefix=name, suffix='.json')
  45. os.close(fd)
  46. self.command_dir = tempfile.mkdtemp()
  47. def _copy_exe(self):
  48. bitness = '64' if (sys.maxsize > 2**32) else '32'
  49. src = os.path.join(pkgdir, 'cli-%s.exe' % bitness)
  50. dst = os.path.join(self.command_dir, self.name+'.exe')
  51. shutil.copy(src, dst)
  52. @property
  53. def _cmd_path(self):
  54. # Can only be used once commands_dir has been set
  55. p = os.path.join(self.command_dir, self.name)
  56. if os.name == 'nt':
  57. p += '-script.py'
  58. return p
  59. def __enter__(self):
  60. if os.path.isfile(self._cmd_path):
  61. raise EnvironmentError("Command %r already exists at %s" %
  62. (self.name, self._cmd_path))
  63. if self.content is None:
  64. self.content = _record_run.format(python=sys.executable,
  65. recording_file=self.recording_file)
  66. with open(self._cmd_path, 'w') as f:
  67. f.write(self.content)
  68. if os.name == 'nt':
  69. self._copy_exe()
  70. else:
  71. os.chmod(self._cmd_path, 0o755) # Set executable bit
  72. prepend_to_path(self.command_dir)
  73. return self
  74. def __exit__(self, etype, evalue, tb):
  75. remove_from_path(self.command_dir)
  76. shutil.rmtree(self.command_dir, ignore_errors=True)
  77. def get_calls(self):
  78. """Get a list of calls made to this mocked command.
  79. This relies on the default script content, so it will return an
  80. empty list if you specified a different content parameter.
  81. For each time the command was run, the list will contain a dictionary
  82. with keys argv, env and cwd.
  83. """
  84. if recording_dir is None:
  85. return []
  86. if not os.path.isfile(self.recording_file):
  87. return []
  88. with open(self.recording_file, 'r') as f:
  89. # 1E is ASCII record separator, last chunk is empty
  90. chunks = f.read().split('\x1e')[:-1]
  91. return [json.loads(c) for c in chunks]
  92. @contextlib.contextmanager
  93. def assert_calls(cmd, args=None):
  94. """Assert that a block of code runs the given command.
  95. If args is passed, also check that it was called at least once with the
  96. given arguments (not including the command name).
  97. Use as a context manager, e.g.::
  98. with assert_calls('git'):
  99. some_function_wrapping_git()
  100. with assert_calls('git', ['add', myfile]):
  101. some_other_function()
  102. """
  103. with MockCommand(cmd) as mc:
  104. yield
  105. calls = mc.get_calls()
  106. assert calls != [], "Command %r was not called" % cmd
  107. if args is not None:
  108. if not any(args == c['argv'][1:] for c in calls):
  109. msg = ["Command %r was not called with specified args (%r)" %
  110. (cmd, args),
  111. "It was called with these arguments: "]
  112. for c in calls:
  113. msg.append(' %r' % c['argv'][1:])
  114. raise AssertionError('\n'.join(msg))