from __future__ import absolute_import, division, print_function from collections import deque from dask.core import istask, subs def head(task): """Return the top level node of a task""" if istask(task): return task[0] elif isinstance(task, list): return list else: return task def args(task): """Get the arguments for the current task""" if istask(task): return task[1:] elif isinstance(task, list): return task else: return () class Traverser(object): """Traverser interface for tasks. Class for storing the state while performing a preorder-traversal of a task. Parameters ---------- term : task The task to be traversed Attributes ---------- term The current element in the traversal current The head of the current element in the traversal. This is simply `head` applied to the attribute `term`. """ def __init__(self, term, stack=None): self.term = term if not stack: self._stack = deque([END]) else: self._stack = stack def __iter__(self): while self.current is not END: yield self.current self.next() def copy(self): """Copy the traverser in its current state. This allows the traversal to be pushed onto a stack, for easy backtracking.""" return Traverser(self.term, deque(self._stack)) def next(self): """Proceed to the next term in the preorder traversal.""" subterms = args(self.term) if not subterms: # No subterms, pop off stack self.term = self._stack.pop() else: self.term = subterms[0] self._stack.extend(reversed(subterms[1:])) @property def current(self): return head(self.term) def skip(self): """Skip over all subterms of the current level in the traversal""" self.term = self._stack.pop() class Token(object): """A token object. Used to express certain objects in the traversal of a task or pattern.""" def __init__(self, name): self.name = name def __repr__(self): return self.name # A variable to represent *all* variables in a discrimination net VAR = Token('?') # Represents the end of the traversal of an expression. We can't use `None`, # 'False', etc... here, as anything may be an argument to a function. END = Token('end') class Node(tuple): """A Discrimination Net node.""" __slots__ = () def __new__(cls, edges=None, patterns=None): edges = edges if edges else {} patterns = patterns if patterns else [] return tuple.__new__(cls, (edges, patterns)) @property def edges(self): """A dictionary, where the keys are edges, and the values are nodes""" return self[0] @property def patterns(self): """A list of all patterns that currently match at this node""" return self[1] class RewriteRule(object): """A rewrite rule. Expresses `lhs` -> `rhs`, for variables `vars`. Parameters ---------- lhs : task The left-hand-side of the rewrite rule. rhs : task or function The right-hand-side of the rewrite rule. If it's a task, variables in `rhs` will be replaced by terms in the subject that match the variables in `lhs`. If it's a function, the function will be called with a dict of such matches. vars: tuple, optional Tuple of variables found in the lhs. Variables can be represented as any hashable object; a good convention is to use strings. If there are no variables, this can be omitted. Examples -------- Here's a `RewriteRule` to replace all nested calls to `list`, so that `(list, (list, 'x'))` is replaced with `(list, 'x')`, where `'x'` is a variable. >>> lhs = (list, (list, 'x')) >>> rhs = (list, 'x') >>> variables = ('x',) >>> rule = RewriteRule(lhs, rhs, variables) Here's a more complicated rule that uses a callable right-hand-side. A callable `rhs` takes in a dictionary mapping variables to their matching values. This rule replaces all occurrences of `(list, 'x')` with `'x'` if `'x'` is a list itself. >>> lhs = (list, 'x') >>> def repl_list(sd): ... x = sd['x'] ... if isinstance(x, list): ... return x ... else: ... return (list, x) >>> rule = RewriteRule(lhs, repl_list, variables) """ def __init__(self, lhs, rhs, vars=()): if not isinstance(vars, tuple): raise TypeError("vars must be a tuple of variables") self.lhs = lhs if callable(rhs): self.subs = rhs else: self.subs = self._apply self.rhs = rhs self._varlist = [t for t in Traverser(lhs) if t in vars] # Reduce vars down to just variables found in lhs self.vars = tuple(sorted(set(self._varlist))) def _apply(self, sub_dict): term = self.rhs for key, val in sub_dict.items(): term = subs(term, key, val) return term def __str__(self): return "RewriteRule({0}, {1}, {2})".format(self.lhs, self.rhs, self.vars) def __repr__(self): return str(self) class RuleSet(object): """A set of rewrite rules. Forms a structure for fast rewriting over a set of rewrite rules. This allows for syntactic matching of terms to patterns for many patterns at the same time. Examples -------- >>> def f(*args): pass >>> def g(*args): pass >>> def h(*args): pass >>> from operator import add >>> rs = RuleSet( # Make RuleSet with two Rules ... RewriteRule((add, 'x', 0), 'x', ('x',)), ... RewriteRule((f, (g, 'x'), 'y'), ... (h, 'x', 'y'), ... ('x', 'y'))) >>> rs.rewrite((add, 2, 0)) # Apply ruleset to single task 2 >>> rs.rewrite((f, (g, 'a', 3))) # doctest: +SKIP (h, 'a', 3) >>> dsk = {'a': (add, 2, 0), # Apply ruleset to full dask graph ... 'b': (f, (g, 'a', 3))} >>> from toolz import valmap >>> valmap(rs.rewrite, dsk) # doctest: +SKIP {'a': 2, 'b': (h, 'a', 3)} Attributes ---------- rules : list A list of `RewriteRule`s included in the `RuleSet`. """ def __init__(self, *rules): """Create a `RuleSet` for a number of rules Parameters ---------- rules One or more instances of RewriteRule """ self._net = Node() self.rules = [] for p in rules: self.add(p) def add(self, rule): """Add a rule to the RuleSet. Parameters ---------- rule : RewriteRule """ if not isinstance(rule, RewriteRule): raise TypeError("rule must be instance of RewriteRule") vars = rule.vars curr_node = self._net ind = len(self.rules) # List of variables, in order they appear in the POT of the term for t in Traverser(rule.lhs): prev_node = curr_node if t in vars: t = VAR if t in curr_node.edges: curr_node = curr_node.edges[t] else: curr_node.edges[t] = Node() curr_node = curr_node.edges[t] # We've reached a leaf node. Add the term index to this leaf. prev_node.edges[t].patterns.append(ind) self.rules.append(rule) def iter_matches(self, term): """A generator that lazily finds matchings for term from the RuleSet. Parameters ---------- term : task Yields ------ Tuples of `(rule, subs)`, where `rule` is the rewrite rule being matched, and `subs` is a dictionary mapping the variables in the lhs of the rule to their matching values in the term.""" S = Traverser(term) for m, syms in _match(S, self._net): for i in m: rule = self.rules[i] subs = _process_match(rule, syms) if subs is not None: yield rule, subs def _rewrite(self, term): """Apply the rewrite rules in RuleSet to top level of term""" for rule, sd in self.iter_matches(term): # We use for (...) because it's fast in all cases for getting the # first element from the match iterator. As we only want that # element, we break here term = rule.subs(sd) break return term def rewrite(self, task, strategy="bottom_up"): """Apply the `RuleSet` to `task`. This applies the most specific matching rule in the RuleSet to the task, using the provided strategy. Parameters ---------- term: a task The task to be rewritten strategy: str, optional The rewriting strategy to use. Options are "bottom_up" (default), or "top_level". Examples -------- Suppose there was a function `add` that returned the sum of 2 numbers, and another function `double` that returned twice its input: >>> add = lambda x, y: x + y >>> double = lambda x: 2*x Now suppose `double` was *significantly* faster than `add`, so you'd like to replace all expressions `(add, x, x)` with `(double, x)`, where `x` is a variable. This can be expressed as a rewrite rule: >>> rule = RewriteRule((add, 'x', 'x'), (double, 'x'), ('x',)) >>> rs = RuleSet(rule) This can then be applied to terms to perform the rewriting: >>> term = (add, (add, 2, 2), (add, 2, 2)) >>> rs.rewrite(term) # doctest: +SKIP (double, (double, 2)) If we only wanted to apply this to the top level of the term, the `strategy` kwarg can be set to "top_level". >>> rs.rewrite(term) # doctest: +SKIP (double, (add, 2, 2)) """ return strategies[strategy](self, task) def _top_level(net, term): return net._rewrite(term) def _bottom_up(net, term): if istask(term): term = (head(term),) + tuple(_bottom_up(net, t) for t in args(term)) elif isinstance(term, list): term = [_bottom_up(net, t) for t in args(term)] return net._rewrite(term) strategies = {'top_level': _top_level, 'bottom_up': _bottom_up} def _match(S, N): """Structural matching of term S to discrimination net node N.""" stack = deque() restore_state_flag = False # matches are stored in a tuple, because all mutations result in a copy, # preventing operations from changing matches stored on the stack. matches = () while True: if S.current is END: yield N.patterns, matches try: # This try-except block is to catch hashing errors from un-hashable # types. This allows for variables to be matched with un-hashable # objects. n = N.edges.get(S.current, None) if n and not restore_state_flag: stack.append((S.copy(), N, matches)) N = n S.next() continue except TypeError: pass n = N.edges.get(VAR, None) if n: restore_state_flag = False matches = matches + (S.term,) S.skip() N = n continue try: # Backtrack here (S, N, matches) = stack.pop() restore_state_flag = True except Exception: return def _process_match(rule, syms): """Process a match to determine if it is correct, and to find the correct substitution that will convert the term into the pattern. Parameters ---------- rule : RewriteRule syms : iterable Iterable of subterms that match a corresponding variable. Returns ------- A dictionary of {vars : subterms} describing the substitution to make the pattern equivalent with the term. Returns `None` if the match is invalid.""" subs = {} varlist = rule._varlist if not len(varlist) == len(syms): raise RuntimeError("length of varlist doesn't match length of syms.") for v, s in zip(varlist, syms): if v in subs and subs[v] != s: return None else: subs[v] = s return subs