439 lines
12 KiB
439 lines
12 KiB
from __future__ import absolute_import, division, print_function
|
|
|
|
from collections import deque
|
|
|
|
from dask.core import istask, subs
|
|
|
|
|
|
def head(task):
    """Return the top level node of a task.

    For a task tuple this is its first element (``task[0]``); for a list it
    is the ``list`` type itself; any other object is its own head.
    """
    if istask(task):
        return task[0]
    if isinstance(task, list):
        return list
    return task
|
|
|
|
|
|
def args(task):
    """Get the arguments for the current task.

    A task yields everything after its first element, a list yields itself
    (its elements), and any other object has no arguments (an empty tuple).
    """
    if istask(task):
        return task[1:]
    return task if isinstance(task, list) else ()
|
|
|
|
|
|
class Traverser(object):
    """Traverser interface for tasks.

    Holds the state of a preorder traversal over a task so that the
    traversal can be advanced one node at a time, advanced past a whole
    subtree, or copied for later backtracking.

    Parameters
    ----------
    term : task
        The task to be traversed
    stack : deque, optional
        Pending subterms still to be visited. Used by `copy` to resume an
        existing traversal; when omitted (or empty) a fresh stack holding
        only `END` is created.

    Attributes
    ----------
    term
        The current element in the traversal
    current
        The head of the current element in the traversal. This is simply
        `head` applied to the attribute `term`.
    """

    def __init__(self, term, stack=None):
        self.term = term
        # Seeding with END makes the traversal report END once everything
        # has been consumed.
        self._stack = stack if stack else deque([END])

    def __iter__(self):
        # Yields heads in preorder; note this advances (consumes) self.
        while True:
            node = self.current
            if node is END:
                break
            yield node
            self.next()

    def copy(self):
        """Copy the traverser in its current state.

        The copy has its own stack, so it can be pushed aside and resumed
        later for easy backtracking."""
        pending = deque(self._stack)
        return Traverser(self.term, pending)

    def next(self):
        """Proceed to the next term in the preorder traversal."""
        subterms = args(self.term)
        if subterms:
            # Descend into the first child. The remaining siblings go on the
            # stack reversed, so the leftmost of them is popped first.
            self.term = subterms[0]
            self._stack.extend(reversed(subterms[1:]))
        else:
            # Leaf: resume with the next pending subterm (or END).
            self.term = self._stack.pop()

    @property
    def current(self):
        """Head of the term currently being visited."""
        return head(self.term)

    def skip(self):
        """Skip over all subterms of the current level in the traversal"""
        self.term = self._stack.pop()
|
|
|
|
|
|
class Token(object):
    """A token object.

    Sentinel used to mark special positions in the traversal of a task or
    pattern. No ``__eq__`` is defined, so a token is equal only to itself;
    its ``repr`` is simply its name."""

    def __init__(self, name):
        # Label shown by repr(); it plays no part in comparisons.
        self.name = name

    def __repr__(self):
        label = self.name
        return label
|
|
|
|
|
|
# Wildcard edge label: every variable of every rule is stored in the
# discrimination net under this one token.
VAR = Token('?')

# Sentinel marking the end of a traversal. A dedicated token is used rather
# than `None`, `False`, etc., because any value may appear as an argument to
# a function.
END = Token('end')
|
|
|
|
|
|
class Node(tuple):
    """A Discrimination Net node.

    Stored as a 2-tuple ``(edges, patterns)``; falsy arguments (including
    ``None``) are replaced by a fresh empty dict / list per node."""

    __slots__ = ()

    def __new__(cls, edges=None, patterns=None):
        return tuple.__new__(cls, (edges or {}, patterns or []))

    @property
    def edges(self):
        """A dictionary, where the keys are edges, and the values are nodes"""
        return self[0]

    @property
    def patterns(self):
        """A list of all patterns that currently match at this node"""
        return self[1]
|
|
|
|
|
|
class RewriteRule(object):
    """A rewrite rule.

    Expresses `lhs` -> `rhs`, for variables `vars`.

    Parameters
    ----------
    lhs : task
        The left-hand-side of the rewrite rule.
    rhs : task or function
        The right-hand-side of the rewrite rule. If it's a task, variables in
        `rhs` will be replaced by terms in the subject that match the variables
        in `lhs`. The replacement is simultaneous: a matched value is never
        itself searched for other variables. If it's a function, the function
        will be called with a dict of such matches.
    vars: tuple, optional
        Tuple of variables found in the lhs. Variables can be represented as
        any hashable object; a good convention is to use strings. If there are
        no variables, this can be omitted.

    Raises
    ------
    TypeError
        If `vars` is not a tuple.

    Examples
    --------
    Here's a `RewriteRule` to replace all nested calls to `list`, so that
    `(list, (list, 'x'))` is replaced with `(list, 'x')`, where `'x'` is a
    variable.

    >>> lhs = (list, (list, 'x'))
    >>> rhs = (list, 'x')
    >>> variables = ('x',)
    >>> rule = RewriteRule(lhs, rhs, variables)

    Here's a more complicated rule that uses a callable right-hand-side. A
    callable `rhs` takes in a dictionary mapping variables to their matching
    values. This rule replaces all occurrences of `(list, 'x')` with `'x'` if
    `'x'` is a list itself.

    >>> lhs = (list, 'x')
    >>> def repl_list(sd):
    ...     x = sd['x']
    ...     if isinstance(x, list):
    ...         return x
    ...     else:
    ...         return (list, x)
    >>> rule = RewriteRule(lhs, repl_list, variables)
    """

    def __init__(self, lhs, rhs, vars=()):
        if not isinstance(vars, tuple):
            raise TypeError("vars must be a tuple of variables")
        self.lhs = lhs
        if callable(rhs):
            self.subs = rhs
        else:
            self.subs = self._apply
        self.rhs = rhs
        # Variables in the order they are met in a preorder traversal of lhs,
        # repeats included; _process_match zips this against matched terms.
        self._varlist = [t for t in Traverser(lhs) if t in vars]
        # Reduce vars down to just the distinct variables found in lhs.
        try:
            self.vars = tuple(sorted(set(self._varlist)))
        except TypeError:
            # Variables of mutually unorderable types (e.g. 'x' and 1) can't
            # be sorted on Python 3; fall back to order of first appearance.
            seen = set()
            ordered = []
            for v in self._varlist:
                if v not in seen:
                    seen.add(v)
                    ordered.append(v)
            self.vars = tuple(ordered)

    def _apply(self, sub_dict):
        """Substitute the matched values in `sub_dict` into the task `rhs`.

        The substitution is simultaneous. Every variable is first replaced by
        a private placeholder, and only then are the placeholders replaced by
        their values. A value that happens to equal another variable (e.g.
        ``{'x': 'y', 'y': 3}``) is therefore never substituted a second time.
        """
        placeholders = {}
        term = self.rhs
        # Phase 1: variable -> unique placeholder. A bare object() compares
        # equal only to itself, so it can't collide with any user value.
        for key in sub_dict:
            marker = object()
            placeholders[key] = marker
            term = subs(term, key, marker)
        # Phase 2: placeholder -> matched value.
        for key, val in sub_dict.items():
            term = subs(term, placeholders[key], val)
        return term

    def __str__(self):
        return "RewriteRule({0}, {1}, {2})".format(self.lhs, self.rhs,
                                                   self.vars)

    def __repr__(self):
        return str(self)
|
|
|
|
|
|
class RuleSet(object):
    """A set of rewrite rules.

    Forms a structure for fast rewriting over a set of rewrite rules. This
    allows for syntactic matching of terms to patterns for many patterns at
    the same time.

    Examples
    --------

    >>> def f(*args): pass
    >>> def g(*args): pass
    >>> def h(*args): pass
    >>> from operator import add

    >>> rs = RuleSet(                 # Make RuleSet with two Rules
    ...         RewriteRule((add, 'x', 0), 'x', ('x',)),
    ...         RewriteRule((f, (g, 'x'), 'y'),
    ...                     (h, 'x', 'y'),
    ...                     ('x', 'y')))

    >>> rs.rewrite((add, 2, 0))       # Apply ruleset to single task
    2

    >>> rs.rewrite((f, (g, 'a'), 3))  # doctest: +SKIP
    (h, 'a', 3)

    >>> dsk = {'a': (add, 2, 0),      # Apply ruleset to full dask graph
    ...        'b': (f, (g, 'a'), 3)}

    >>> from toolz import valmap
    >>> valmap(rs.rewrite, dsk)       # doctest: +SKIP
    {'a': 2,
     'b': (h, 'a', 3)}

    Attributes
    ----------
    rules : list
        A list of `RewriteRule`s included in the `RuleSet`.
    """

    def __init__(self, *rules):
        """Create a `RuleSet` for a number of rules

        Parameters
        ----------
        rules
            One or more instances of RewriteRule
        """
        self._net = Node()
        self.rules = []
        for p in rules:
            self.add(p)

    def add(self, rule):
        """Add a rule to the RuleSet.

        Parameters
        ----------
        rule : RewriteRule

        Raises
        ------
        TypeError
            If `rule` is not a `RewriteRule`.
        """
        if not isinstance(rule, RewriteRule):
            raise TypeError("rule must be instance of RewriteRule")
        vars = rule.vars
        curr_node = self._net
        ind = len(self.rules)
        # Walk (extending where needed) the net along the preorder traversal
        # of the lhs. Every variable is collapsed onto the shared VAR edge.
        for t in Traverser(rule.lhs):
            if t in vars:
                t = VAR
            child = curr_node.edges.get(t)
            if child is None:
                child = curr_node.edges[t] = Node()
            curr_node = child
        # We've reached the leaf for this lhs. Record the rule's index here.
        curr_node.patterns.append(ind)
        self.rules.append(rule)

    def iter_matches(self, term):
        """A generator that lazily finds matchings for term from the RuleSet.

        Parameters
        ----------
        term : task

        Yields
        ------
        Tuples of `(rule, subs)`, where `rule` is the rewrite rule being
        matched, and `subs` is a dictionary mapping the variables in the lhs
        of the rule to their matching values in the term."""
        S = Traverser(term)
        for m, syms in _match(S, self._net):
            for i in m:
                rule = self.rules[i]
                subs = _process_match(rule, syms)
                # None means a repeated variable bound to unequal subterms.
                if subs is not None:
                    yield rule, subs

    def _rewrite(self, term):
        """Apply the rewrite rules in RuleSet to top level of term.

        Only the first match yielded by `iter_matches` is applied; `term` is
        returned unchanged when no rule matches."""
        for rule, sd in self.iter_matches(term):
            # Only the first match is wanted. Returning here stops the lazy
            # matcher from doing any further work.
            return rule.subs(sd)
        return term

    def rewrite(self, task, strategy="bottom_up"):
        """Apply the `RuleSet` to `task`.

        At each position, the rule applied is the first match found. Matching
        prefers a literal match over a variable at every step of the
        traversal, so more specific rules win. Rules are applied using the
        provided strategy.

        Parameters
        ----------
        task: a task
            The task to be rewritten
        strategy: str, optional
            The rewriting strategy to use. Options are "bottom_up" (default),
            or "top_level".

        Raises
        ------
        KeyError
            If `strategy` is not one of the supported names.

        Examples
        --------
        Suppose there was a function `add` that returned the sum of 2 numbers,
        and another function `double` that returned twice its input:

        >>> add = lambda x, y: x + y
        >>> double = lambda x: 2*x

        Now suppose `double` was *significantly* faster than `add`, so
        you'd like to replace all expressions `(add, x, x)` with `(double,
        x)`, where `x` is a variable. This can be expressed as a rewrite rule:

        >>> rule = RewriteRule((add, 'x', 'x'), (double, 'x'), ('x',))
        >>> rs = RuleSet(rule)

        This can then be applied to terms to perform the rewriting:

        >>> term = (add, (add, 2, 2), (add, 2, 2))
        >>> rs.rewrite(term)  # doctest: +SKIP
        (double, (double, 2))

        If we only wanted to apply this to the top level of the term, the
        `strategy` kwarg can be set to "top_level".

        >>> rs.rewrite(term, strategy="top_level")  # doctest: +SKIP
        (double, (add, 2, 2))
        """
        return strategies[strategy](self, task)
|
|
|
|
|
|
def _top_level(net, term):
|
|
return net._rewrite(term)
|
|
|
|
|
|
def _bottom_up(net, term):
    """Rewrite every subterm of `term` before rewriting `term` itself.

    Task arguments and list elements are rewritten recursively. The node is
    then rebuilt from the rewritten children, and `net._rewrite` is applied
    to the rebuilt node. Any other object is passed to `net._rewrite`
    directly.
    """
    if istask(term):
        children = [_bottom_up(net, sub) for sub in args(term)]
        return net._rewrite((head(term),) + tuple(children))
    if isinstance(term, list):
        return net._rewrite([_bottom_up(net, sub) for sub in args(term)])
    return net._rewrite(term)
|
|
|
|
|
|
# Maps the `strategy` names accepted by `RuleSet.rewrite` to their
# implementations.
strategies = dict(top_level=_top_level, bottom_up=_bottom_up)
|
|
|
|
|
|
def _match(S, N):
    """Structural matching of term S to discrimination net node N.

    Walks the net depth-first in lockstep with the preorder traversal `S`.
    At every step a literal edge is preferred over the VAR edge, and the
    walk backtracks when neither applies.

    Parameters
    ----------
    S : Traverser
        Traversal of the term being matched. It is advanced in place; on
        backtracking the local name is rebound to a previously saved copy.
    N : Node
        Root of the discrimination net.

    Yields
    ------
    Tuples of ``(patterns, matches)``. ``patterns`` is the list of rule
    indices stored at the node reached when the traversal of the term ended.
    ``matches`` is a tuple of the subterms consumed by VAR edges along the
    way, in traversal order.
    """
    stack = deque()
    # True right after backtracking: the literal edge out of the restored
    # state was already explored, so only the VAR edge remains to be tried.
    restore_state_flag = False
    # matches are stored in a tuple, because all mutations result in a copy,
    # preventing operations from changing matches stored on the stack.
    matches = ()
    while True:
        if S.current is END:
            yield N.patterns, matches
            # The term is exhausted, so nothing more can be consumed from
            # this state. Following a VAR edge here would pop from the
            # traverser's empty stack. That happens when one rule's lhs is a
            # prefix of another's, so go straight to backtracking.
        else:
            try:
                # Un-hashable subterms can't be literal edges, but they may
                # still be matched by a variable below.
                n = N.edges.get(S.current, None)
            except TypeError:
                n = None
            if n is not None and not restore_state_flag:
                stack.append((S.copy(), N, matches))
                N = n
                S.next()
                continue
            n = N.edges.get(VAR, None)
            if n is not None:
                restore_state_flag = False
                matches = matches + (S.term,)
                S.skip()
                N = n
                continue
        if not stack:
            # Every alternative has been explored.
            return
        # Backtrack to the most recent literal-edge choice point.
        S, N, matches = stack.pop()
        restore_state_flag = True
|
|
|
|
|
|
def _process_match(rule, syms):
|
|
"""Process a match to determine if it is correct, and to find the correct
|
|
substitution that will convert the term into the pattern.
|
|
|
|
Parameters
|
|
----------
|
|
rule : RewriteRule
|
|
syms : iterable
|
|
Iterable of subterms that match a corresponding variable.
|
|
|
|
Returns
|
|
-------
|
|
A dictionary of {vars : subterms} describing the substitution to make the
|
|
pattern equivalent with the term. Returns `None` if the match is
|
|
invalid."""
|
|
|
|
subs = {}
|
|
varlist = rule._varlist
|
|
if not len(varlist) == len(syms):
|
|
raise RuntimeError("length of varlist doesn't match length of syms.")
|
|
for v, s in zip(varlist, syms):
|
|
if v in subs and subs[v] != s:
|
|
return None
|
|
else:
|
|
subs[v] = s
|
|
return subs
|