"""Helpers for mocking system commands in tests by shadowing them on $PATH."""
import contextlib
import json
import os
import shutil
import sys
import tempfile

__all__ = ['MockCommand', 'assert_calls']

# Directory containing this module; the Windows launcher executables
# (cli-32.exe / cli-64.exe) are looked up here by MockCommand._copy_exe.
pkgdir = os.path.dirname(__file__)

# Shared temporary directory holding the call-recording files of every
# MockCommand. Created lazily by the first MockCommand instance.
recording_dir = None


def prepend_to_path(dir):
    """Put *dir* at the front of the PATH environment variable."""
    os.environ['PATH'] = dir + os.pathsep + os.environ['PATH']


def remove_from_path(dir):
    """Remove the first occurrence of *dir* from the PATH environment variable.

    Raises ValueError if *dir* is not currently an entry of PATH.
    """
    path_dirs = os.environ['PATH'].split(os.pathsep)
    path_dirs.remove(dir)
    os.environ['PATH'] = os.pathsep.join(path_dirs)


# Template for the default mock script. It appends one JSON record per
# invocation to the recording file, each terminated by the ASCII record
# separator (0x1E) so that get_calls() can split the records apart again.
_record_run = """#!{python}
import os, sys
import json

with open({recording_file!r}, 'a') as f:
    json.dump({{'env': dict(os.environ),
               'argv': sys.argv,
               'cwd': os.getcwd()}},
              f)
    f.write('\\x1e') # ASCII record separator
"""

# TODO: Overlapping calls to the same command may interleave writes.


class MockCommand(object):
    """Context manager to mock a system command.

    The mock command will be written to a directory at the front of $PATH,
    taking precedence over any existing command with the same name.

    By specifying content as a string, you can determine what running the
    command will do. The default content records each time the command is
    called and exits: you can access these records with mockcmd.get_calls().

    On Windows, the specified content will be run by the Python interpreter in
    use. On Unix, it should start with a shebang (``#!/path/to/interpreter``).
    """

    def __init__(self, name, content=None):
        """Prepare (but do not yet install) a mock for the command *name*.

        *content* is the text of the script to install; None selects the
        default call-recording script when the context is entered.
        """
        global recording_dir
        self.name = name
        self.content = content

        if recording_dir is None:
            recording_dir = tempfile.mkdtemp()
        # mkstemp reserves a unique file name; only the path is needed here,
        # so the descriptor is closed straight away.
        fd, self.recording_file = tempfile.mkstemp(
            dir=recording_dir, prefix=name, suffix='.json')
        os.close(fd)
        self.command_dir = tempfile.mkdtemp()

    def _copy_exe(self):
        """Copy the bundled launcher matching this interpreter's bitness.

        The launcher is copied to ``<command_dir>/<name>.exe``.
        NOTE(review): presumably the .exe runs the adjacent
        ``<name>-script.py``; the launcher itself is not part of this file.
        """
        bitness = '64' if (sys.maxsize > 2**32) else '32'
        src = os.path.join(pkgdir, 'cli-%s.exe' % bitness)
        dst = os.path.join(self.command_dir, self.name + '.exe')
        shutil.copy(src, dst)

    @property
    def _cmd_path(self):
        """Path of the script file that implements the mock command."""
        # Can only be used once command_dir has been set
        p = os.path.join(self.command_dir, self.name)
        if os.name == 'nt':
            p += '-script.py'
        return p

    def __enter__(self):
        """Write the mock script and put its directory at the front of PATH.

        Raises EnvironmentError if the script file already exists.
        """
        if os.path.isfile(self._cmd_path):
            raise EnvironmentError("Command %r already exists at %s" %
                                   (self.name, self._cmd_path))

        if self.content is None:
            self.content = _record_run.format(
                python=sys.executable, recording_file=self.recording_file)

        with open(self._cmd_path, 'w') as f:
            f.write(self.content)

        if os.name == 'nt':
            self._copy_exe()
        else:
            os.chmod(self._cmd_path, 0o755)  # Set executable bit

        prepend_to_path(self.command_dir)

        return self

    def __exit__(self, etype, evalue, tb):
        """Take the mock off PATH and delete its temporary directory."""
        try:
            remove_from_path(self.command_dir)
        finally:
            # Always delete the directory, even if PATH was altered while the
            # mock was active and remove_from_path() raised ValueError.
            shutil.rmtree(self.command_dir, ignore_errors=True)

    def get_calls(self):
        """Get a list of calls made to this mocked command.

        This relies on the default script content, so it will return an
        empty list if you specified a different content parameter.

        For each time the command was run, the list will contain a dictionary
        with keys argv, env and cwd.
        """
        if recording_dir is None:
            return []
        if not os.path.isfile(self.recording_file):
            return []

        with open(self.recording_file, 'r') as f:
            # 1E is ASCII record separator, last chunk is empty
            chunks = f.read().split('\x1e')[:-1]

        return [json.loads(c) for c in chunks]


@contextlib.contextmanager
def assert_calls(cmd, args=None):
    """Assert that a block of code runs the given command.

    If args is passed, also check that it was called at least once with the
    given arguments (not including the command name).

    Raises AssertionError if the command was never called, or was never
    called with the requested arguments.

    Use as a context manager, e.g.::

        with assert_calls('git'):
            some_function_wrapping_git()

        with assert_calls('git', ['add', myfile]):
            some_other_function()
    """
    with MockCommand(cmd) as mc:
        yield

    calls = mc.get_calls()
    # Raise explicitly rather than using ``assert`` so the check still runs
    # when Python is started with -O, which strips assert statements.
    if not calls:
        raise AssertionError("Command %r was not called" % cmd)

    if args is not None:
        if not any(args == c['argv'][1:] for c in calls):
            msg = ["Command %r was not called with specified args (%r)" %
                   (cmd, args),
                   "It was called with these arguments: "]
            for c in calls:
                msg.append('  %r' % c['argv'][1:])
            raise AssertionError('\n'.join(msg))