from __future__ import absolute_import, division, print_function from itertools import chain from .utils_test import add, inc # noqa: F401 def ishashable(x): """ Is x hashable? Examples -------- >>> ishashable(1) True >>> ishashable([1]) False """ try: hash(x) return True except TypeError: return False def istask(x): """ Is x a runnable task? A task is a tuple with a callable first argument Examples -------- >>> inc = lambda x: x + 1 >>> istask((inc, 1)) True >>> istask(1) False """ return type(x) is tuple and x and callable(x[0]) def has_tasks(dsk, x): """Whether ``x`` has anything to compute. Returns True if: - ``x`` is a task - ``x`` is a key in ``dsk`` - ``x`` is a list that contains any tasks or keys """ if istask(x): return True try: if x in dsk: return True except Exception: pass if isinstance(x, list): for i in x: if has_tasks(dsk, i): return True return False def preorder_traversal(task): """A generator to preorder-traverse a task.""" for item in task: if istask(item): for i in preorder_traversal(item): yield i elif isinstance(item, list): yield list for i in preorder_traversal(item): yield i else: yield item def lists_to_tuples(res, keys): if isinstance(keys, list): return tuple(lists_to_tuples(r, k) for r, k in zip(res, keys)) return res def _execute_task(arg, cache, dsk=None): """ Do the actual work of collecting data and executing a function Examples -------- >>> cache = {'x': 1, 'y': 2} Compute tasks against a cache >>> _execute_task((add, 'x', 1), cache) # Compute task in naive manner 2 >>> _execute_task((add, (inc, 'x'), 1), cache) # Support nested computation 3 Also grab data from cache >>> _execute_task('x', cache) 1 Support nested lists >>> list(_execute_task(['x', 'y'], cache)) [1, 2] >>> list(map(list, _execute_task([['x', 'y'], ['y', 'x']], cache))) [[1, 2], [2, 1]] >>> _execute_task('foo', cache) # Passes through on non-keys 'foo' """ if isinstance(arg, list): return [_execute_task(a, cache) for a in arg] elif istask(arg): func, args = arg[0], arg[1:] args2 = [_execute_task(a, cache) for a in args] return func(*args2) elif not ishashable(arg): return arg elif arg in cache: return cache[arg] else: return arg def get(dsk, out, cache=None): """ Get value from Dask Examples -------- >>> inc = lambda x: x + 1 >>> d = {'x': 1, 'y': (inc, 'x')} >>> get(d, 'x') 1 >>> get(d, 'y') 2 """ for k in (flatten(out) if isinstance(out, list) else [out]): if k not in dsk: raise KeyError("{0} is not a key in the graph".format(k)) if cache is None: cache = {} for key in toposort(dsk): task = dsk[key] result = _execute_task(task, cache) cache[key] = result result = _execute_task(out, cache) if isinstance(out, list): result = lists_to_tuples(result, out) return result def get_dependencies(dsk, key=None, task=None, as_list=False): """ Get the immediate tasks on which this task depends Examples -------- >>> dsk = {'x': 1, ... 'y': (inc, 'x'), ... 'z': (add, 'x', 'y'), ... 'w': (inc, 'z'), ... 'a': (add, (inc, 'x'), 1)} >>> get_dependencies(dsk, 'x') set() >>> get_dependencies(dsk, 'y') {'x'} >>> get_dependencies(dsk, 'z') # doctest: +SKIP {'x', 'y'} >>> get_dependencies(dsk, 'w') # Only direct dependencies {'z'} >>> get_dependencies(dsk, 'a') # Ignore non-keys {'x'} >>> get_dependencies(dsk, task=(inc, 'x')) # provide tasks directly {'x'} """ if key is not None: arg = dsk[key] elif task is not None: arg = task else: raise ValueError("Provide either key or task") result = [] work = [arg] while work: new_work = [] for w in work: typ = type(w) if typ is tuple and w and callable(w[0]): # istask(w) new_work += w[1:] elif typ is list: new_work += w elif typ is dict: new_work += w.values() else: try: if w in dsk: result.append(w) except TypeError: # not hashable pass work = new_work return result if as_list else set(result) def get_deps(dsk): """ Get dependencies and dependents from dask dask graph >>> dsk = {'a': 1, 'b': (inc, 'a'), 'c': (inc, 'b')} >>> dependencies, dependents = get_deps(dsk) >>> dependencies {'a': set(), 'b': {'a'}, 'c': {'b'}} >>> dependents {'a': {'b'}, 'b': {'c'}, 'c': set()} """ dependencies = {k: get_dependencies(dsk, task=v) for k, v in dsk.items()} dependents = reverse_dict(dependencies) return dependencies, dependents def flatten(seq, container=list): """ >>> list(flatten([1])) [1] >>> list(flatten([[1, 2], [1, 2]])) [1, 2, 1, 2] >>> list(flatten([[[1], [2]], [[1], [2]]])) [1, 2, 1, 2] >>> list(flatten(((1, 2), (1, 2)))) # Don't flatten tuples [(1, 2), (1, 2)] >>> list(flatten((1, 2, [3, 4]))) # support heterogeneous [1, 2, 3, 4] """ if isinstance(seq, str): yield seq else: for item in seq: if isinstance(item, container): for item2 in flatten(item, container=container): yield item2 else: yield item def reverse_dict(d): """ >>> a, b, c = 'abc' >>> d = {a: [b, c], b: [c]} >>> reverse_dict(d) # doctest: +SKIP {'a': set([]), 'b': set(['a']}, 'c': set(['a', 'b'])} """ terms = list(d.keys()) + list(chain.from_iterable(d.values())) result = {t: set() for t in terms} for k, vals in d.items(): for val in vals: result[val].add(k) return result def subs(task, key, val): """ Perform a substitution on a task Examples -------- >>> subs((inc, 'x'), 'x', 1) # doctest: +SKIP (inc, 1) """ type_task = type(task) if not (type_task is tuple and task and callable(task[0])): # istask(task): try: if type_task is type(key) and task == key: return val except Exception: pass if type_task is list: return [subs(x, key, val) for x in task] return task newargs = [] for arg in task[1:]: type_arg = type(arg) if type_arg is tuple and arg and callable(arg[0]): # istask(task): arg = subs(arg, key, val) elif type_arg is list: arg = [subs(x, key, val) for x in arg] elif type_arg is type(key): try: # Can't do a simple equality check, since this may trigger # a FutureWarning from NumPy about array equality # https://github.com/dask/dask/pull/2457 if len(arg) == len(key) and all(type(aa) == type(bb) and aa == bb for aa, bb in zip(arg, key)): arg = val except (TypeError, AttributeError): # Handle keys which are not sized (len() fails), but are hashable if arg == key: arg = val newargs.append(arg) return task[:1] + tuple(newargs) def _toposort(dsk, keys=None, returncycle=False, dependencies=None): # Stack-based depth-first search traversal. This is based on Tarjan's # method for topological sorting (see wikipedia for pseudocode) if keys is None: keys = dsk elif not isinstance(keys, list): keys = [keys] if not returncycle: ordered = [] # Nodes whose descendents have been completely explored. # These nodes are guaranteed to not be part of a cycle. completed = set() # All nodes that have been visited in the current traversal. Because # we are doing depth-first search, going "deeper" should never result # in visiting a node that has already been seen. The `seen` and # `completed` sets are mutually exclusive; it is okay to visit a node # that has already been added to `completed`. seen = set() if dependencies is None: dependencies = dict((k, get_dependencies(dsk, k)) for k in dsk) for key in keys: if key in completed: continue nodes = [key] while nodes: # Keep current node on the stack until all descendants are visited cur = nodes[-1] if cur in completed: # Already fully traversed descendants of cur nodes.pop() continue seen.add(cur) # Add direct descendants of cur to nodes stack next_nodes = [] for nxt in dependencies[cur]: if nxt not in completed: if nxt in seen: # Cycle detected! cycle = [nxt] while nodes[-1] != nxt: cycle.append(nodes.pop()) cycle.append(nodes.pop()) cycle.reverse() if returncycle: return cycle else: cycle = '->'.join(str(x) for x in cycle) raise RuntimeError('Cycle detected in Dask: %s' % cycle) next_nodes.append(nxt) if next_nodes: nodes.extend(next_nodes) else: # cur has no more descendants to explore, so we're done with it if not returncycle: ordered.append(cur) completed.add(cur) seen.remove(cur) nodes.pop() if returncycle: return [] return ordered def toposort(dsk, dependencies=None): """ Return a list of keys of dask sorted in topological order.""" return _toposort(dsk, dependencies=dependencies) def getcycle(d, keys): """ Return a list of nodes that form a cycle if Dask is not a DAG. Returns an empty list if no cycle is found. ``keys`` may be a single key or list of keys. Examples -------- >>> d = {'x': (inc, 'z'), 'y': (inc, 'x'), 'z': (inc, 'y')} >>> getcycle(d, 'x') ['x', 'z', 'y', 'x'] See Also -------- isdag """ return _toposort(d, keys=keys, returncycle=True) def isdag(d, keys): """ Does Dask form a directed acyclic graph when calculating keys? ``keys`` may be a single key or list of keys. Examples -------- >>> inc = lambda x: x + 1 >>> isdag({'x': 0, 'y': (inc, 'x')}, 'y') True >>> isdag({'x': (inc, 'y'), 'y': (inc, 'x')}, 'y') False See Also -------- getcycle """ return not getcycle(d, keys) class literal(object): """A small serializable object to wrap literal values without copying""" __slots__ = ('data',) def __init__(self, data): self.data = data def __repr__(self): return 'literal' % type(self.data).__name__ def __reduce__(self): return (literal, (self.data,)) def __call__(self): return self.data def quote(x): """ Ensure that this value remains this value in a dask graph Some values in dask graph take on special meaning. Sometimes we want to ensure that our data is not interpreted but remains literal. >>> quote((add, 1, 2)) # doctest: +SKIP (literal,) """ if istask(x) or type(x) is list: return (literal(x),) return x