439 lines
12 KiB

from __future__ import absolute_import, division, print_function
from collections import deque
from dask.core import istask, subs
def head(task):
    """Return the top level node of a task.

    If ``istask(task)`` holds, the head is the first element of the task.
    A list has the ``list`` type itself as its head, and any other object
    is returned unchanged.
    """
    if istask(task):
        return task[0]
    return list if isinstance(task, list) else task
def args(task):
    """Get the arguments for the current task.

    If ``istask(task)`` holds, the arguments are everything after the first
    element. A list is returned as-is (its elements are its arguments), and
    any other object has no arguments, so an empty tuple is returned.
    """
    if istask(task):
        return task[1:]
    if isinstance(task, list):
        return task
    return ()
class Traverser(object):
    """Traverser interface for tasks.

    Class for storing the state while performing a preorder-traversal of a
    task.

    Parameters
    ----------
    term : task
        The task to be traversed
    stack : deque, optional
        Subterms still waiting to be visited. When omitted, a new stack
        holding only the `END` sentinel is created. `copy` passes a
        duplicate of the current stack here.

    Attributes
    ----------
    term
        The current element in the traversal
    current
        The head of the current element in the traversal. This is simply `head`
        applied to the attribute `term`.
    """

    def __init__(self, term, stack=None):
        self.term = term
        # Compare against None rather than relying on truthiness: once `END`
        # has been popped the stack is empty, and a copy taken in that state
        # must stay empty instead of silently regaining an `END` sentinel.
        if stack is None:
            self._stack = deque([END])
        else:
            self._stack = stack

    def __iter__(self):
        """Yield the head of each element, in preorder, until `END`."""
        while self.current is not END:
            yield self.current
            self.next()

    def copy(self):
        """Copy the traverser in its current state.

        This allows the traversal to be pushed onto a stack, for easy
        backtracking."""
        return Traverser(self.term, deque(self._stack))

    def next(self):
        """Proceed to the next term in the preorder traversal."""
        subterms = args(self.term)
        if not subterms:
            # No subterms, pop off stack
            self.term = self._stack.pop()
        else:
            self.term = subterms[0]
            # Push the remaining siblings reversed so the leftmost one is
            # popped (and therefore visited) first.
            self._stack.extend(reversed(subterms[1:]))

    @property
    def current(self):
        """The head of the element the traversal is currently positioned on."""
        return head(self.term)

    def skip(self):
        """Skip over all subterms of the current level in the traversal"""
        self.term = self._stack.pop()
class Token(object):
    """A named marker object.

    Instances stand in for special positions in the traversal of a task or
    pattern. The name given at construction is stored on ``name`` and is
    also what ``repr`` returns.
    """

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        label = self.name
        return label
# A single token standing for *every* variable in a discrimination net:
# `RuleSet.add` stores each pattern variable under this one edge key.
VAR = Token('?')
# Represents the end of the traversal of an expression. We can't use `None`,
# `False`, etc. here, as anything may be an argument to a function.
END = Token('end')
class Node(tuple):
    """A Discrimination Net node.

    A node is a 2-tuple ``(edges, patterns)`` with named accessors for each
    element.

    Parameters
    ----------
    edges : dict, optional
        Mapping from edge keys to child nodes. A new empty dict is created
        when omitted.
    patterns : list, optional
        Patterns that match at this node. A new empty list is created when
        omitted.
    """

    __slots__ = ()

    def __new__(cls, edges=None, patterns=None):
        # Test for None explicitly: an empty container supplied by the caller
        # must be stored as-is (so later mutations stay visible to the
        # caller), not silently swapped for a fresh one.
        if edges is None:
            edges = {}
        if patterns is None:
            patterns = []
        return tuple.__new__(cls, (edges, patterns))

    @property
    def edges(self):
        """A dictionary, where the keys are edges, and the values are nodes"""
        return self[0]

    @property
    def patterns(self):
        """A list of all patterns that currently match at this node"""
        return self[1]
class RewriteRule(object):
    """A rewrite rule.

    Expresses `lhs` -> `rhs`, for variables `vars`.

    Parameters
    ----------
    lhs : task
        The left-hand-side of the rewrite rule.
    rhs : task or function
        The right-hand-side of the rewrite rule. If it's a task, variables in
        `rhs` will be replaced by terms in the subject that match the variables
        in `lhs`. If it's a function, the function will be called with a dict
        of such matches.
    vars: tuple, optional
        Tuple of variables found in the lhs. Variables can be represented as
        any hashable object; a good convention is to use strings. If there are
        no variables, this can be omitted.

    Raises
    ------
    TypeError
        If `vars` is not a tuple.

    Examples
    --------
    Here's a `RewriteRule` to replace all nested calls to `list`, so that
    `(list, (list, 'x'))` is replaced with `(list, 'x')`, where `'x'` is a
    variable.

    >>> lhs = (list, (list, 'x'))
    >>> rhs = (list, 'x')
    >>> variables = ('x',)
    >>> rule = RewriteRule(lhs, rhs, variables)

    Here's a more complicated rule that uses a callable right-hand-side. A
    callable `rhs` takes in a dictionary mapping variables to their matching
    values. This rule replaces all occurrences of `(list, 'x')` with `'x'` if
    `'x'` is a list itself.

    >>> lhs = (list, 'x')
    >>> def repl_list(sd):
    ...     x = sd['x']
    ...     if isinstance(x, list):
    ...         return x
    ...     else:
    ...         return (list, x)
    >>> rule = RewriteRule(lhs, repl_list, variables)
    """

    def __init__(self, lhs, rhs, vars=()):
        if not isinstance(vars, tuple):
            raise TypeError("vars must be a tuple of variables")
        self.lhs = lhs
        # Always record rhs, callable or not: __str__/__repr__ read it, so
        # leaving it unset for callable rules would break printing them.
        self.rhs = rhs
        self.subs = rhs if callable(rhs) else self._apply
        # Variables in the order they occur in the preorder traversal of lhs
        # (repeats included).
        self._varlist = [t for t in Traverser(lhs) if t in vars]
        # Reduce vars down to just variables found in lhs
        try:
            self.vars = tuple(sorted(set(self._varlist)))
        except TypeError:
            # Variables may be any hashable objects, and mixed types cannot
            # be ordered on Python 3. Fall back to first-appearance order.
            seen = set()
            ordered = []
            for var in self._varlist:
                if var not in seen:
                    seen.add(var)
                    ordered.append(var)
            self.vars = tuple(ordered)

    def _apply(self, sub_dict):
        """Substitute each entry of `sub_dict` into `rhs` and return the result.

        NOTE(review): substitutions are applied one after another, so a value
        inserted for one variable that equals another variable's key may be
        substituted again by a later iteration -- confirm whether a
        simultaneous substitution is intended.
        """
        term = self.rhs
        for key, val in sub_dict.items():
            term = subs(term, key, val)
        return term

    def __str__(self):
        return "RewriteRule({0}, {1}, {2})".format(self.lhs, self.rhs,
                                                   self.vars)

    def __repr__(self):
        return str(self)
class RuleSet(object):
    """A set of rewrite rules.

    Forms a structure for fast rewriting over a set of rewrite rules. This
    allows for syntactic matching of terms to patterns for many patterns at
    the same time.

    Examples
    --------

    >>> def f(*args): pass
    >>> def g(*args): pass
    >>> def h(*args): pass
    >>> from operator import add

    >>> rs = RuleSet(                 # Make RuleSet with two Rules
    ...         RewriteRule((add, 'x', 0), 'x', ('x',)),
    ...         RewriteRule((f, (g, 'x'), 'y'),
    ...                     (h, 'x', 'y'),
    ...                     ('x', 'y')))

    >>> rs.rewrite((add, 2, 0))       # Apply ruleset to single task
    2

    >>> rs.rewrite((f, (g, 'a', 3)))  # doctest: +SKIP
    (h, 'a', 3)

    >>> dsk = {'a': (add, 2, 0),      # Apply ruleset to full dask graph
    ...        'b': (f, (g, 'a', 3))}

    >>> from toolz import valmap
    >>> valmap(rs.rewrite, dsk)  # doctest: +SKIP
    {'a': 2,
     'b': (h, 'a', 3)}

    Attributes
    ----------
    rules : list
        A list of `RewriteRule`s included in the `RuleSet`.
    """

    def __init__(self, *rules):
        """Create a `RuleSet` for a number of rules

        Parameters
        ----------
        rules
            One or more instances of RewriteRule
        """
        self._net = Node()
        self.rules = []
        for p in rules:
            self.add(p)

    def add(self, rule):
        """Add a rule to the RuleSet.

        Parameters
        ----------
        rule : RewriteRule

        Raises
        ------
        TypeError
            If `rule` is not a `RewriteRule`.
        """
        if not isinstance(rule, RewriteRule):
            raise TypeError("rule must be instance of RewriteRule")
        variables = rule.vars
        curr_node = self._net
        # Index the rule will have in `self.rules` once appended below.
        ind = len(self.rules)
        # Follow the preorder traversal of the lhs through the net, creating
        # nodes where no edge exists yet. Every variable shares the VAR edge.
        for t in Traverser(rule.lhs):
            if t in variables:
                t = VAR
            if t not in curr_node.edges:
                curr_node.edges[t] = Node()
            curr_node = curr_node.edges[t]
        # We've reached a leaf node. Add the term index to this leaf.
        curr_node.patterns.append(ind)
        self.rules.append(rule)

    def iter_matches(self, term):
        """A generator that lazily finds matchings for term from the RuleSet.

        Parameters
        ----------
        term : task

        Yields
        ------
        Tuples of `(rule, subs)`, where `rule` is the rewrite rule being
        matched, and `subs` is a dictionary mapping the variables in the lhs
        of the rule to their matching values in the term."""
        S = Traverser(term)
        for m, syms in _match(S, self._net):
            for i in m:
                rule = self.rules[i]
                # `_process_match` returns None when a repeated variable was
                # bound to differing subterms; such candidates are dropped.
                sub_dict = _process_match(rule, syms)
                if sub_dict is not None:
                    yield rule, sub_dict

    def _rewrite(self, term):
        """Apply the rewrite rules in RuleSet to top level of term"""
        for rule, sd in self.iter_matches(term):
            # We use for (...) because it's fast in all cases for getting the
            # first element from the match iterator. As we only want that
            # element, we break here
            term = rule.subs(sd)
            break
        return term

    def rewrite(self, task, strategy="bottom_up"):
        """Apply the `RuleSet` to `task`.

        This applies the most specific matching rule in the RuleSet to the
        task, using the provided strategy.

        Parameters
        ----------
        task: a task
            The task to be rewritten
        strategy: str, optional
            The rewriting strategy to use. Options are "bottom_up" (default),
            or "top_level".

        Raises
        ------
        KeyError
            If `strategy` is not one of the known strategy names.

        Examples
        --------
        Suppose there was a function `add` that returned the sum of 2 numbers,
        and another function `double` that returned twice its input:

        >>> add = lambda x, y: x + y
        >>> double = lambda x: 2*x

        Now suppose `double` was *significantly* faster than `add`, so
        you'd like to replace all expressions `(add, x, x)` with `(double,
        x)`, where `x` is a variable. This can be expressed as a rewrite rule:

        >>> rule = RewriteRule((add, 'x', 'x'), (double, 'x'), ('x',))
        >>> rs = RuleSet(rule)

        This can then be applied to terms to perform the rewriting:

        >>> term = (add, (add, 2, 2), (add, 2, 2))
        >>> rs.rewrite(term)  # doctest: +SKIP
        (double, (double, 2))

        If we only wanted to apply this to the top level of the term, the
        `strategy` kwarg can be set to "top_level".

        >>> rs.rewrite(term, strategy="top_level")  # doctest: +SKIP
        (double, (add, 2, 2))
        """
        return strategies[strategy](self, task)
def _top_level(net, term):
return net._rewrite(term)
def _bottom_up(net, term):
    """Rewrite `term` from the innermost subterms outwards.

    The arguments of a task, or the elements of a list, are rewritten
    recursively first; the rebuilt term is then passed to `net._rewrite`.
    Any other object goes straight to `net._rewrite`.
    """
    if istask(term):
        top = head(term)
        rewritten_args = tuple(_bottom_up(net, sub) for sub in args(term))
        term = (top,) + rewritten_args
    elif isinstance(term, list):
        term = [_bottom_up(net, sub) for sub in args(term)]
    return net._rewrite(term)
# Dispatch table from the `strategy` names accepted by `RuleSet.rewrite` to
# the functions implementing them; each is called as `func(ruleset, task)`.
strategies = {'top_level': _top_level,
'bottom_up': _bottom_up}
def _match(S, N):
    """Structural matching of term S to discrimination net node N.

    Parameters
    ----------
    S : Traverser
        Traverser positioned at the start of the term to match.
    N : Node
        Node of the discrimination net at which matching starts.

    Yields
    ------
    Tuples of `(patterns, matches)`, one each time the end of the term is
    reached: `patterns` is the pattern list of the node arrived at, and
    `matches` is a tuple of the subterms that were consumed by variable
    edges, in the order they were consumed.
    """
    stack = deque()
    restore_state_flag = False
    # matches are stored in a tuple, because all mutations result in a copy,
    # preventing operations from changing matches stored on the stack.
    matches = ()
    while True:
        if S.current is END:
            yield N.patterns, matches
            # The term is exhausted, so nothing is left for an edge to
            # consume. Go straight to backtracking: following a VAR edge here
            # would bind END to a variable and pop the traverser's now-empty
            # stack (IndexError) whenever a pattern is longer than the term.
        else:
            try:
                # This try-except block is to catch hashing errors from
                # un-hashable types. This allows for variables to be matched
                # with un-hashable objects.
                n = N.edges.get(S.current, None)
            except TypeError:
                n = None
            # After a backtrack the literal edge of the restored state has
            # already been explored, so only the variable edge is tried.
            if n is not None and not restore_state_flag:
                stack.append((S.copy(), N, matches))
                N = n
                S.next()
                continue
            n = N.edges.get(VAR, None)
            if n is not None:
                restore_state_flag = False
                # A variable consumes the whole current subterm.
                matches = matches + (S.term,)
                S.skip()
                N = n
                continue
        # Backtrack here
        try:
            (S, N, matches) = stack.pop()
        except IndexError:
            # Nothing left to backtrack to: every alternative was explored.
            return
        restore_state_flag = True
def _process_match(rule, syms):
"""Process a match to determine if it is correct, and to find the correct
substitution that will convert the term into the pattern.
Parameters
----------
rule : RewriteRule
syms : iterable
Iterable of subterms that match a corresponding variable.
Returns
-------
A dictionary of {vars : subterms} describing the substitution to make the
pattern equivalent with the term. Returns `None` if the match is
invalid."""
subs = {}
varlist = rule._varlist
if not len(varlist) == len(syms):
raise RuntimeError("length of varlist doesn't match length of syms.")
for v, s in zip(varlist, syms):
if v in subs and subs[v] != s:
return None
else:
subs[v] = s
return subs