151 lines
4.8 KiB
151 lines
4.8 KiB
import contextlib
|
|
import json
|
|
import os
|
|
import shutil
|
|
import sys
|
|
import tempfile
|
|
|
|
# Public API of this module: the MockCommand context manager and the
# assert_calls() context-manager helper.
__all__ = ['MockCommand', 'assert_calls']
|
|
|
|
# Directory containing this module; MockCommand._copy_exe() looks for the
# bundled cli-32.exe / cli-64.exe launchers here. Resolved to an absolute path
# so it stays valid if the working directory changes after import (on some
# interpreters __file__ can be relative).
pkgdir = os.path.dirname(os.path.abspath(__file__))

# Shared temporary directory holding every MockCommand's recording file.
# Created lazily by the first MockCommand() and reused afterwards.
recording_dir = None
|
|
|
|
def prepend_to_path(dir):
    """Put *dir* at the front of the ``PATH`` environment variable.

    If ``PATH`` is unset or empty, it becomes just *dir*. Joining onto an
    empty value would leave a trailing empty entry, which POSIX interprets
    as the current working directory.
    """
    current = os.environ.get('PATH', '')
    os.environ['PATH'] = dir + os.pathsep + current if current else dir
|
|
|
|
def remove_from_path(dir):
    """Remove the first occurrence of *dir* from ``PATH``.

    This undoes prepend_to_path() and is used during cleanup, so it is
    tolerant: if ``PATH`` is unset, or no longer contains *dir* (e.g. the
    code under test replaced ``PATH``), it leaves the environment unchanged
    instead of raising and masking the caller's own exception.
    """
    path = os.environ.get('PATH')
    if path is None:
        return
    path_dirs = path.split(os.pathsep)
    try:
        path_dirs.remove(dir)  # only the first occurrence, as before
    except ValueError:
        return
    os.environ['PATH'] = os.pathsep.join(path_dirs)
|
|
|
|
|
|
_record_run = """#!{python}
|
|
import os, sys
|
|
import json
|
|
|
|
with open({recording_file!r}, 'a') as f:
|
|
json.dump({{'env': dict(os.environ),
|
|
'argv': sys.argv,
|
|
'cwd': os.getcwd()}},
|
|
f)
|
|
f.write('\\x1e') # ASCII record separator
|
|
"""
|
|
|
|
# TODO: Concurrent calls to the same mocked command all append to one recording file, so their writes may interleave and garble the records.
|
|
|
|
class MockCommand(object):
    """Context manager to mock a system command.

    The mock command will be written to a directory at the front of $PATH,
    taking precedence over any existing command with the same name.

    By specifying content as a string, you can determine what running the
    command will do. The default content records each time the command is
    called and exits: you can access these records with mockcmd.get_calls().

    On Windows, the specified content will be run by the Python interpreter in
    use. On Unix, it should start with a shebang (``#!/path/to/interpreter``).

    :param name: name of the command to mock, as it will be looked up on $PATH.
    :param content: source of the mock script. If ``None``, a Python script is
        used that records argv, environment and cwd for each invocation.

    An instance may be entered again after it has exited; a fresh command
    directory is created in that case. With the default content, recorded
    calls keep accumulating in the same recording file across uses.
    """

    def __init__(self, name, content=None):
        global recording_dir
        self.name = name
        self.content = content

        # One temp dir, shared by all instances, holds the recording files.
        # It is created on first use and not removed by this class, so the
        # recordings stay readable through get_calls() after the context exits.
        if recording_dir is None:
            recording_dir = tempfile.mkdtemp()
        fd, self.recording_file = tempfile.mkstemp(dir=recording_dir,
                                                   prefix=name, suffix='.json')
        os.close(fd)  # only the path is needed; the mock script appends to it
        self.command_dir = tempfile.mkdtemp()

    def _copy_exe(self):
        """Copy the bundled launcher executable into command_dir (Windows).

        Picks ``cli-64.exe`` or ``cli-32.exe`` from pkgdir to match the running
        interpreter's bitness and names it ``<name>.exe``. Presumably the
        launcher runs the adjacent ``<name>-script.py`` — confirm.
        """
        bitness = '64' if (sys.maxsize > 2**32) else '32'
        src = os.path.join(pkgdir, 'cli-%s.exe' % bitness)
        dst = os.path.join(self.command_dir, self.name + '.exe')
        shutil.copy(src, dst)

    @property
    def _cmd_path(self):
        """Path of the mock script inside command_dir.

        On Windows the script is ``<name>-script.py`` (run via the copied
        launcher); elsewhere it is just ``<name>``.
        """
        # Can only be used once command_dir has been set
        p = os.path.join(self.command_dir, self.name)
        if os.name == 'nt':
            p += '-script.py'
        return p

    def __enter__(self):
        """Write the mock script, make it runnable, and prepend its dir to PATH.

        :raises EnvironmentError: if the script already exists, e.g. when the
            same instance is entered again while still active.
        """
        # __exit__ deletes command_dir; if this instance is being reused after
        # a previous with-block, start over in a new temporary directory.
        if not os.path.isdir(self.command_dir):
            self.command_dir = tempfile.mkdtemp()

        if os.path.isfile(self._cmd_path):
            raise EnvironmentError("Command %r already exists at %s" %
                                   (self.name, self._cmd_path))

        if self.content is None:
            self.content = _record_run.format(python=sys.executable,
                                              recording_file=self.recording_file)

        # NOTE(review): written in the locale's default encoding; if the paths
        # substituted above contain non-ASCII characters, a Python interpreter
        # reading the script as UTF-8 may reject it — confirm on Windows.
        with open(self._cmd_path, 'w') as f:
            f.write(self.content)

        if os.name == 'nt':
            self._copy_exe()
        else:
            os.chmod(self._cmd_path, 0o755)  # Set executable bit

        prepend_to_path(self.command_dir)

        return self

    def __exit__(self, etype, evalue, tb):
        """Take command_dir off PATH and delete it; exceptions propagate."""
        # try/finally: the temporary directory is removed even if restoring
        # PATH fails, so a cleanup error cannot also leak the directory.
        try:
            remove_from_path(self.command_dir)
        finally:
            shutil.rmtree(self.command_dir, ignore_errors=True)

    def get_calls(self):
        """Get a list of calls made to this mocked command.

        This relies on the default script content, so it will return an
        empty list if you specified a different content parameter.

        For each time the command was run, the list will contain a dictionary
        with keys argv, env and cwd.
        """
        if recording_dir is None:
            return []
        if not os.path.isfile(self.recording_file):
            return []

        with open(self.recording_file, 'r') as f:
            # Records are terminated by 0x1E (ASCII record separator). The
            # chunk after the final separator is empty (or a record still
            # being written) and is dropped.
            chunks = f.read().split('\x1e')[:-1]

        return [json.loads(c) for c in chunks]
|
|
|
|
|
|
@contextlib.contextmanager
def assert_calls(cmd, args=None):
    """Assert that a block of code runs the given command.

    If args is passed, also check that it was called at least once with the
    given arguments (not including the command name).

    Use as a context manager, e.g.::

        with assert_calls('git'):
            some_function_wrapping_git()

        with assert_calls('git', ['add', myfile]):
            some_other_function()

    :param cmd: name of the command to mock for the duration of the block.
    :param args: optional sequence (list or tuple) of expected arguments,
        compared element-wise against each recorded call's argv[1:].
    :raises AssertionError: if the command was never run, or never run with
        *args*. Raised explicitly rather than via ``assert``, so the check
        also works under ``python -O``.
    """
    with MockCommand(cmd) as mc:
        yield

    calls = mc.get_calls()
    if not calls:
        raise AssertionError("Command %r was not called" % cmd)

    if args is not None:
        # Recorded argv is decoded from JSON as a list; normalise the expected
        # args to a list too, otherwise a tuple would never compare equal.
        expected = list(args)
        if not any(c['argv'][1:] == expected for c in calls):
            msg = ["Command %r was not called with specified args (%r)" %
                   (cmd, args),
                   "It was called with these arguments: "]
            for c in calls:
                msg.append(' %r' % c['argv'][1:])
            raise AssertionError('\n'.join(msg))
|