import itertools

import numpy as np

try:
    import cytoolz as toolz
except ImportError:
    import toolz

from . import core, utils
from .compatibility import apply, Mapping
from .delayed import to_task_dask
from .highlevelgraph import HighLevelGraph
from .optimization import SubgraphCallable


def subs(task, substitution):
    """ Create a new task with the values substituted

    This is like dask.core.subs, but takes a dict of many substitutions to
    perform simultaneously.  It is not as concerned with micro performance.
    """
    if isinstance(task, dict):
        return {k: subs(v, substitution) for k, v in task.items()}
    if type(task) in (tuple, list, set):
        return type(task)([subs(x, substitution) for x in task])
    try:
        return substitution[task]
    except (KeyError, TypeError):
        return task


def index_subs(ind, substitution):
    """ A simple subs function that works both on tuples and strings """
    if ind is None:
        return ind
    else:
        return tuple([substitution.get(c, c) for c in ind])


def blockwise_token(i, prefix='_'):
    return prefix + '%d' % i


def blockwise(func, output, output_indices, *arrind_pairs, **kwargs):
    """ Create a symbolic Blockwise mapping

    This is like the ``make_blockwise_graph`` function, but rather than
    construct a dict, it returns a symbolic Blockwise object.

    See Also
    --------
    make_blockwise_graph
    Blockwise
    """
    numblocks = kwargs.pop('numblocks')
    concatenate = kwargs.pop('concatenate', None)
    new_axes = kwargs.pop('new_axes', {})
    dependencies = kwargs.pop('dependencies', [])
    arrind_pairs = list(arrind_pairs)

    # Transform indices to canonical elements
    # We use terms like _0, and _1 rather than provided index elements
    unique_indices = {i for ii in arrind_pairs[1::2]
                      if ii is not None
                      for i in ii} | set(output_indices)
    sub = {k: blockwise_token(i, '.')
           for i, k in enumerate(sorted(unique_indices))}
    output_indices = index_subs(tuple(output_indices), sub)
    arrind_pairs[1::2] = [tuple(a) if a is not None else a
                          for a in arrind_pairs[1::2]]
    arrind_pairs[1::2] = [index_subs(a, sub) for a in arrind_pairs[1::2]]
    new_axes = {index_subs((k,), sub)[0]: v for k, v in new_axes.items()}

    # Unpack dask values in non-array arguments
    argpairs = list(toolz.partition(2, arrind_pairs))

    # separate argpairs into two separate tuples
    inputs = tuple([name for name, _ in argpairs])
    inputs_indices = tuple([index for _, index in argpairs])

    # Unpack delayed objects in kwargs
    new_keys = {n for c in dependencies for n in c.__dask_layers__()}
    if kwargs:
        # replace keys in kwargs with _0 tokens
        new_tokens = tuple(blockwise_token(i)
                           for i in range(len(inputs),
                                          len(inputs) + len(new_keys)))
        sub = dict(zip(new_keys, new_tokens))
        inputs = inputs + tuple(new_keys)
        inputs_indices = inputs_indices + (None,) * len(new_keys)
        kwargs = subs(kwargs, sub)

    indices = [(k, v) for k, v in zip(inputs, inputs_indices)]
    keys = tuple(map(blockwise_token, range(len(inputs))))

    # Construct local graph
    if not kwargs:
        subgraph = {output: (func,) + keys}
    else:
        _keys = list(keys)
        if new_keys:
            _keys = _keys[:-len(new_keys)]
        kwargs2 = (dict, list(map(list, kwargs.items())))
        subgraph = {output: (apply, func, _keys, kwargs2)}

    # Construct final output
    subgraph = Blockwise(output, output_indices, subgraph, indices,
                         numblocks=numblocks, concatenate=concatenate,
                         new_axes=new_axes)
    return subgraph
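
# --- Illustrative sketch (editorial example, not used by the library) ---
# ``blockwise`` builds a symbolic layer rather than a concrete graph.  The
# helper below, with a stand-in ``inc`` function and a hypothetical input
# collection named 'x' split into 2x2 blocks, shows the shape of the result.
def _example_blockwise():
    """ Minimal sketch of building a Blockwise layer for an elementwise op """
    def inc(x):
        return x + 1

    layer = blockwise(inc, 'y', 'ij', 'x', 'ij', numblocks={'x': (2, 2)})
    # The layer is symbolic; converting it to a dict materializes tasks like
    # {('y', 0, 0): (subgraph_callable, ('x', 0, 0)), ...}
    return dict(layer)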

class Blockwise(Mapping):
    """ Tensor Operation

    This is a lazily constructed mapping for tensor operation graphs.
    This defines a dictionary using an operation and an indexing pattern.
    It is built for many operations like elementwise, transpose, tensordot,
    and so on.  We choose to keep these as symbolic mappings rather than raw
    dictionaries because we are able to fuse them during optimization,
    sometimes resulting in much lower overhead.

    Parameters
    ----------
    output: str
        The name of the output collection.  Used in keynames
    output_indices: tuple
        The output indices, like ``('i', 'j', 'k')`` used to determine the
        structure of the block computations
    dsk: dict
        A small graph to apply per-output-block.  May include keys from the
        input indices.
    indices: Tuple[str, Tuple[str, str]]
        An ordered mapping from input key name, like ``'x'``
        to input indices, like ``('i', 'j')``
        Or includes literals, which have ``None`` for an index value
    numblocks: Dict[key, Sequence[int]]
        Number of blocks along each dimension for each input
    concatenate: boolean
        Whether or not to pass contracted dimensions as a list of inputs or a
        single input to the block function
    new_axes: Dict
        New index dimensions that may have been created, and their extent

    See Also
    --------
    dask.blockwise.blockwise
    dask.array.blockwise
    """
    def __init__(self, output, output_indices, dsk, indices,
                 numblocks, concatenate=None, new_axes=None):
        self.output = output
        self.output_indices = tuple(output_indices)
        self.dsk = dsk
        self.indices = tuple((name, tuple(ind) if ind is not None else ind)
                             for name, ind in indices)
        self.numblocks = numblocks
        self.concatenate = concatenate
        self.new_axes = new_axes or {}

    @property
    def _dict(self):
        if hasattr(self, '_cached_dict'):
            return self._cached_dict
        else:
            keys = tuple(map(blockwise_token, range(len(self.indices))))
            func = SubgraphCallable(self.dsk, self.output, keys)
            self._cached_dict = make_blockwise_graph(
                func, self.output, self.output_indices,
                *list(toolz.concat(self.indices)),
                new_axes=self.new_axes,
                numblocks=self.numblocks,
                concatenate=self.concatenate
            )
        return self._cached_dict

    def __getitem__(self, key):
        return self._dict[key]

    def __iter__(self):
        return iter(self._dict)

    def __len__(self):
        return int(np.prod(list(self._out_numblocks().values())))

    def _out_numblocks(self):
        d = {}
        indices = {k: v for k, v in self.indices if v is not None}
        for k, v in self.numblocks.items():
            for a, b in zip(indices[k], v):
                d[a] = max(d.get(a, 0), b)

        return {k: v for k, v in d.items() if k in self.output_indices}
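
# --- Illustrative sketch (editorial example, not used by the library) ---
# A Blockwise layer behaves like a read-only dict but builds its concrete
# task graph lazily: ``__len__`` is computed from block counts alone, and
# tasks are only generated on first key access.  ``add`` is a stand-in.
def _example_blockwise_mapping():
    """ Show the lazy Mapping behaviour of a Blockwise layer """
    def add(a, b):
        return a + b

    layer = blockwise(add, 'z', 'ij', 'x', 'ij', 'y', 'ij',
                      numblocks={'x': (2, 2), 'y': (2, 2)})
    assert len(layer) == 4        # 2 x 2 output blocks, no graph built yet
    task = layer[('z', 1, 1)]     # first access materializes the graph once
    return task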

def make_blockwise_graph(func, output, out_indices, *arrind_pairs, **kwargs):
    """ Tensor operation

    Applies a function, ``func``, across blocks from many different input
    collections.  We arrange the pattern with which those blocks interact
    with sets of matching indices.  E.g.::

        make_blockwise_graph(func, 'z', 'i', 'x', 'i', 'y', 'i')

    yields an embarrassingly parallel communication pattern and is read as

        $$ z_i = func(x_i, y_i) $$

    More complex patterns may emerge, including multiple indices::

        make_blockwise_graph(func, 'z', 'ij', 'x', 'ij', 'y', 'ji')

        $$ z_{ij} = func(x_{ij}, y_{ji}) $$

    Indices missing in the output but present in the inputs result in many
    inputs being sent to one function (see examples).

    Examples
    --------
    Simple embarrassingly parallel map operation

    >>> inc = lambda x: x + 1
    >>> make_blockwise_graph(inc, 'z', 'ij', 'x', 'ij', numblocks={'x': (2, 2)})  # doctest: +SKIP
    {('z', 0, 0): (inc, ('x', 0, 0)),
     ('z', 0, 1): (inc, ('x', 0, 1)),
     ('z', 1, 0): (inc, ('x', 1, 0)),
     ('z', 1, 1): (inc, ('x', 1, 1))}

    Simple operation on two datasets

    >>> add = lambda x, y: x + y
    >>> make_blockwise_graph(add, 'z', 'ij', 'x', 'ij', 'y', 'ij', numblocks={'x': (2, 2),
    ...                                                                       'y': (2, 2)})  # doctest: +SKIP
    {('z', 0, 0): (add, ('x', 0, 0), ('y', 0, 0)),
     ('z', 0, 1): (add, ('x', 0, 1), ('y', 0, 1)),
     ('z', 1, 0): (add, ('x', 1, 0), ('y', 1, 0)),
     ('z', 1, 1): (add, ('x', 1, 1), ('y', 1, 1))}

    Operation that flips one of the datasets

    >>> addT = lambda x, y: x + y.T  # Transpose each chunk
    >>> #                                        z_ij ~ x_ij y_ji
    >>> #               ..         ..         .. notice swap
    >>> make_blockwise_graph(addT, 'z', 'ij', 'x', 'ij', 'y', 'ji', numblocks={'x': (2, 2),
    ...                                                                        'y': (2, 2)})  # doctest: +SKIP
    {('z', 0, 0): (addT, ('x', 0, 0), ('y', 0, 0)),
     ('z', 0, 1): (addT, ('x', 0, 1), ('y', 1, 0)),
     ('z', 1, 0): (addT, ('x', 1, 0), ('y', 0, 1)),
     ('z', 1, 1): (addT, ('x', 1, 1), ('y', 1, 1))}

    Dot product with contraction over ``j`` index.  Yields list arguments

    >>> make_blockwise_graph(dotmany, 'z', 'ik', 'x', 'ij', 'y', 'jk', numblocks={'x': (2, 2),
    ...                                                                           'y': (2, 2)})  # doctest: +SKIP
    {('z', 0, 0): (dotmany, [('x', 0, 0), ('x', 0, 1)],
                            [('y', 0, 0), ('y', 1, 0)]),
     ('z', 0, 1): (dotmany, [('x', 0, 0), ('x', 0, 1)],
                            [('y', 0, 1), ('y', 1, 1)]),
     ('z', 1, 0): (dotmany, [('x', 1, 0), ('x', 1, 1)],
                            [('y', 0, 0), ('y', 1, 0)]),
     ('z', 1, 1): (dotmany, [('x', 1, 0), ('x', 1, 1)],
                            [('y', 0, 1), ('y', 1, 1)])}

    Pass ``concatenate=True`` to concatenate arrays ahead of time

    >>> make_blockwise_graph(f, 'z', 'i', 'x', 'ij', 'y', 'ij', concatenate=True,
    ...                      numblocks={'x': (2, 2), 'y': (2, 2,)})  # doctest: +SKIP
    {('z', 0): (f, (concatenate_axes, [('x', 0, 0), ('x', 0, 1)], (1,)),
                   (concatenate_axes, [('y', 0, 0), ('y', 0, 1)], (1,))),
     ('z', 1): (f, (concatenate_axes, [('x', 1, 0), ('x', 1, 1)], (1,)),
                   (concatenate_axes, [('y', 1, 0), ('y', 1, 1)], (1,)))}

    Supports broadcasting rules

    >>> make_blockwise_graph(add, 'z', 'ij', 'x', 'ij', 'y', 'ij', numblocks={'x': (1, 2),
    ...                                                                       'y': (2, 2)})  # doctest: +SKIP
    {('z', 0, 0): (add, ('x', 0, 0), ('y', 0, 0)),
     ('z', 0, 1): (add, ('x', 0, 1), ('y', 0, 1)),
     ('z', 1, 0): (add, ('x', 0, 0), ('y', 1, 0)),
     ('z', 1, 1): (add, ('x', 0, 1), ('y', 1, 1))}

    Support keyword arguments with apply

    >>> def f(a, b=0): return a + b
    >>> make_blockwise_graph(f, 'z', 'i', 'x', 'i', numblocks={'x': (2,)}, b=10)  # doctest: +SKIP
    {('z', 0): (apply, f, [('x', 0)], {'b': 10}),
     ('z', 1): (apply, f, [('x', 1)], {'b': 10})}

    Include literals by indexing with ``None``

    >>> make_blockwise_graph(add, 'z', 'i', 'x', 'i', 100, None, numblocks={'x': (2,)})  # doctest: +SKIP
    {('z', 0): (add, ('x', 0), 100),
     ('z', 1): (add, ('x', 1), 100)}

    See Also
    --------
    dask.array.blockwise
    dask.blockwise.blockwise
    """
    numblocks = kwargs.pop('numblocks')
    concatenate = kwargs.pop('concatenate', None)
    new_axes = kwargs.pop('new_axes', {})
    argpairs = list(toolz.partition(2, arrind_pairs))

    if concatenate is True:
        from dask.array.core import concatenate_axes as concatenate

    assert set(numblocks) == {name for name, ind in argpairs
                              if ind is not None}

    all_indices = {x for _, ind in argpairs if ind for x in ind}
    dummy_indices = all_indices - set(out_indices)

    # Dictionary mapping {i: 3, j: 4, ...} for i, j, ... the dimensions
    dims = broadcast_dimensions(argpairs, numblocks)
    for k in new_axes:
        dims[k] = 1

    # (0, 0), (0, 1), (0, 2), (1, 0), ...
    keytups = list(itertools.product(*[range(dims[i]) for i in out_indices]))
    # {i: 0, j: 0}, {i: 0, j: 1}, ...
    keydicts = [dict(zip(out_indices, tup)) for tup in keytups]

    # {j: [1, 2, 3], ...}  For j a dummy index of dimension 3
    dummies = dict((i, list(range(dims[i]))) for i in dummy_indices)

    dsk = {}

    # Create argument lists
    valtups = []
    for kd in keydicts:
        args = []
        for arg, ind in argpairs:
            if ind is None:
                args.append(arg)
            else:
                tups = lol_tuples((arg,), ind, kd, dummies)
                if any(nb == 1 for nb in numblocks[arg]):
                    tups2 = zero_broadcast_dimensions(tups, numblocks[arg])
                else:
                    tups2 = tups
                if concatenate and isinstance(tups2, list):
                    axes = [n for n, i in enumerate(ind) if i in dummies]
                    tups2 = (concatenate, tups2, axes)
                args.append(tups2)
        valtups.append(args)

    if not kwargs:  # will not be used in an apply, should be a tuple
        valtups = [tuple(vt) for vt in valtups]

    # Add heads to tuples
    keys = [(output,) + kt for kt in keytups]

    # Unpack delayed objects in kwargs
    if kwargs:
        task, dsk2 = to_task_dask(kwargs)
        if dsk2:
            dsk.update(utils.ensure_dict(dsk2))
            kwargs2 = task
        else:
            kwargs2 = kwargs
        vals = [(apply, func, vt, kwargs2) for vt in valtups]
    else:
        vals = [(func,) + vt for vt in valtups]

    dsk.update(dict(zip(keys, vals)))
    return dsk
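
# --- Illustrative sketch (editorial example, not used by the library) ---
# When an index such as ``j`` appears in the inputs but not the output, each
# output task receives *lists* of blocks along that contracted dimension and
# the block function must reduce them itself.  ``dotmany`` is a stand-in.
def _example_contraction():
    """ Show the list-of-blocks arguments produced by index contraction """
    def dotmany(xs, ys):
        return sum(x * y for x, y in zip(xs, ys))

    dsk = make_blockwise_graph(dotmany, 'z', 'ik', 'x', 'ij', 'y', 'jk',
                               numblocks={'x': (2, 2), 'y': (2, 2)})
    # ('z', 0, 0) gathers a row of 'x' blocks and a column of 'y' blocks:
    # (dotmany, [('x', 0, 0), ('x', 0, 1)], [('y', 0, 0), ('y', 1, 0)])
    return dsk[('z', 0, 0)]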

def lol_tuples(head, ind, values, dummies):
    """ List of list of tuple keys

    Parameters
    ----------
    head : tuple
        The known tuple so far
    ind : Iterable
        An iterable of indices not yet covered
    values : dict
        Known values for non-dummy indices
    dummies : dict
        Ranges of values for dummy indices

    Examples
    --------
    >>> lol_tuples(('x',), 'ij', {'i': 1, 'j': 0}, {})
    ('x', 1, 0)

    >>> lol_tuples(('x',), 'ij', {'i': 1}, {'j': range(3)})
    [('x', 1, 0), ('x', 1, 1), ('x', 1, 2)]

    >>> lol_tuples(('x',), 'ijk', {'i': 1}, {'j': [0, 1, 2], 'k': [0, 1]})  # doctest: +NORMALIZE_WHITESPACE
    [[('x', 1, 0, 0), ('x', 1, 0, 1)],
     [('x', 1, 1, 0), ('x', 1, 1, 1)],
     [('x', 1, 2, 0), ('x', 1, 2, 1)]]
    """
    if not ind:
        return head
    if ind[0] not in dummies:
        return lol_tuples(head + (values[ind[0]],), ind[1:], values, dummies)
    else:
        return [lol_tuples(head + (v,), ind[1:], values, dummies)
                for v in dummies[ind[0]]]

def optimize_blockwise(graph, keys=()):
    """ High level optimization of stacked Blockwise layers

    For operations that have multiple Blockwise operations one after the
    other, like ``x.T + 123``, we can fuse these into a single Blockwise
    operation.  This happens before any actual tasks are generated, and so
    can reduce overhead.

    This finds groups of Blockwise operations that can be safely fused, and
    then passes them to ``rewrite_blockwise`` for rewriting.

    Parameters
    ----------
    graph : HighLevelGraph
    keys : Iterable
        The keys of all outputs of all collections.
        Used to make sure that we don't fuse a layer needed by an output

    Returns
    -------
    HighLevelGraph

    See Also
    --------
    rewrite_blockwise
    """
    out = _optimize_blockwise(graph, keys=keys)
    while out.dependencies != graph.dependencies:
        graph = out
        out = _optimize_blockwise(graph, keys=keys)
    return out


def _optimize_blockwise(full_graph, keys=()):
    keep = {k[0] if type(k) is tuple else k for k in keys}
    layers = full_graph.dicts
    dependents = core.reverse_dict(full_graph.dependencies)
    roots = {k for k in full_graph.dicts
             if not dependents.get(k)}
    stack = list(roots)

    out = {}
    dependencies = {}
    seen = set()

    while stack:
        layer = stack.pop()
        if layer in seen or layer not in layers:
            continue
        seen.add(layer)

        # Outer loop walks through possible output Blockwise layers
        if isinstance(layers[layer], Blockwise):
            blockwise_layers = {layer}
            deps = set(blockwise_layers)
            while deps:  # we gather as many sub-layers as we can
                dep = deps.pop()
                if dep not in layers:
                    stack.append(dep)
                    continue
                if not isinstance(layers[dep], Blockwise):
                    stack.append(dep)
                    continue
                if (dep != layer and dep in keep):
                    stack.append(dep)
                    continue
                if layers[dep].concatenate != layers[layer].concatenate:
                    stack.append(dep)
                    continue
                if sum(k == dep for k, ind in layers[layer].indices
                       if ind is not None) > 1:
                    stack.append(dep)
                    continue

                # passed everything, proceed
                blockwise_layers.add(dep)

                # traverse further to this child's children
                for d in full_graph.dependencies.get(dep, ()):
                    # Don't allow reductions to proceed
                    output_indices = set(layers[dep].output_indices)
                    input_indices = {i for _, ind in layers[dep].indices
                                     if ind for i in ind}

                    if (len(dependents[d]) <= 1 and
                            output_indices.issuperset(input_indices)):
                        deps.add(d)
                    else:
                        stack.append(d)

            # Merge these Blockwise layers into one
            new_layer = rewrite_blockwise([layers[l]
                                           for l in blockwise_layers])
            out[layer] = new_layer
            dependencies[layer] = {k for k, v in new_layer.indices
                                   if v is not None}
        else:
            out[layer] = layers[layer]
            dependencies[layer] = full_graph.dependencies.get(layer, set())
            stack.extend(full_graph.dependencies.get(layer, ()))

    return HighLevelGraph(out, dependencies)
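
# --- Illustrative sketch (editorial example; assumes ``dask.array`` is
# importable and that elementwise ops and transposes produce Blockwise
# layers, as they do in this codebase).  ``(x + 1).T`` stacks one Blockwise
# layer per operation; ``optimize_blockwise`` fuses the stack before any
# tasks are generated.
def _example_optimize_blockwise():
    """ Fuse the stacked Blockwise layers behind ``(x + 1).T`` """
    import dask.array as da

    x = da.ones((10, 10), chunks=(5, 5))
    y = (x + 1).T
    hlg = y.__dask_graph__()  # HighLevelGraph with stacked Blockwise layers
    fused = optimize_blockwise(hlg,
                               keys=list(core.flatten(y.__dask_keys__())))
    return len(fused.dicts) < len(hlg.dicts)  # fewer layers after fusion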

def rewrite_blockwise(inputs):
    """ Rewrite a stack of Blockwise expressions into a single blockwise expression

    Given a set of Blockwise layers, combine them into a single layer.  The
    provided layers are expected to fit well together.  That job is handled
    by ``optimize_blockwise``

    Parameters
    ----------
    inputs : List[Blockwise]

    Returns
    -------
    blockwise: Blockwise

    See Also
    --------
    optimize_blockwise
    """
    inputs = {inp.output: inp for inp in inputs}
    dependencies = {inp.output: {d for d, v in inp.indices
                                 if v is not None and d in inputs}
                    for inp in inputs.values()}
    dependents = core.reverse_dict(dependencies)

    new_index_iter = (c + (str(d) if d else '')  # A, B, ... A1, B1, ...
                      for d in itertools.count()
                      for c in 'ABCDEFGHIJKLMNOPQRSTUVWXYZ')

    [root] = [k for k, v in dependents.items() if not v]

    # Our final results.  These will change during fusion below
    indices = list(inputs[root].indices)
    new_axes = inputs[root].new_axes
    concatenate = inputs[root].concatenate
    dsk = dict(inputs[root].dsk)

    changed = True
    while changed:
        changed = False
        for i, (dep, ind) in enumerate(indices):
            if ind is None:
                continue
            if dep not in inputs:
                continue

            changed = True

            # Replace _n with dep name in existing tasks
            # (inc, _0) -> (inc, 'b')
            dsk = {k: subs(v, {blockwise_token(i): dep})
                   for k, v in dsk.items()}

            # Remove current input from input indices
            # [('a', 'i'), ('b', 'i')] -> [('a', 'i')]
            _, current_dep_indices = indices.pop(i)
            sub = {blockwise_token(i): blockwise_token(i - 1)
                   for i in range(i + 1, len(indices) + 1)}
            dsk = subs(dsk, sub)

            # Change new input_indices to match the index from the current
            # computation
            # [('c', 'j')] -> [('c', 'i')]
            new_indices = inputs[dep].indices
            sub = dict(zip(inputs[dep].output_indices, current_dep_indices))
            contracted = {x for _, j in new_indices
                          if j is not None
                          for x in j
                          if x not in inputs[dep].output_indices}
            extra = dict(zip(contracted, new_index_iter))
            sub.update(extra)
            new_indices = [(x, index_subs(j, sub)) for x, j in new_indices]

            # Update new_axes
            for k, v in inputs[dep].new_axes.items():
                new_axes[sub[k]] = v

            # Bump new inputs up in list
            sub = {}
            for i, index in enumerate(new_indices):
                try:
                    contains = index in indices
                except (ValueError, TypeError):
                    contains = False

                if contains:  # use old inputs if available
                    sub[blockwise_token(i)] = \
                        blockwise_token(indices.index(index))
                else:
                    sub[blockwise_token(i)] = blockwise_token(len(indices))
                    indices.append(index)
            new_dsk = subs(inputs[dep].dsk, sub)

            dsk.update(new_dsk)

    indices = [(a, tuple(b) if isinstance(b, list) else b)
               for a, b in indices]

    # De-duplicate indices like [(a, ij), (b, i), (a, ij)] -> [(a, ij), (b, i)]
    # Make sure that we map everything else appropriately as we remove inputs
    new_indices = []
    seen = {}
    sub = {}  # like {_0: _0, _1: _0, _2: _1}
    for i, x in enumerate(indices):
        if x[1] is not None and x in seen:
            sub[i] = seen[x]
        else:
            if x[1] is not None:
                seen[x] = len(new_indices)
            sub[i] = len(new_indices)
            new_indices.append(x)

    sub = {blockwise_token(k): blockwise_token(v) for k, v in sub.items()}
    dsk = {k: subs(v, sub) for k, v in dsk.items()}

    indices_check = {k for k, v in indices if v is not None}
    numblocks = toolz.merge([inp.numblocks for inp in inputs.values()])
    numblocks = {k: v for k, v in numblocks.items()
                 if v is None or k in indices_check}

    out = Blockwise(root, inputs[root].output_indices, dsk, new_indices,
                    numblocks=numblocks, new_axes=new_axes,
                    concatenate=concatenate)

    return out
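
# --- Illustrative sketch (editorial example, not used by the library) ---
# Fuse two chained Blockwise layers, ``b = inc(a)`` and ``c = double(b)``,
# into a single layer rooted at 'c'.  All names here are stand-ins.
def _example_rewrite_blockwise():
    """ Fuse inc-then-double into one Blockwise layer """
    def inc(x):
        return x + 1

    def double(x):
        return 2 * x

    b = blockwise(inc, 'b', 'i', 'a', 'i', numblocks={'a': (3,)})
    c = blockwise(double, 'c', 'i', 'b', 'i', numblocks={'b': (3,)})
    fused = rewrite_blockwise([b, c])
    # The fused layer depends only on 'a'; 'b' survives solely as an
    # internal key of the per-block subgraph.
    assert [name for name, _ in fused.indices] == ['a']
    return fused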

def zero_broadcast_dimensions(lol, nblocks):
    """ Replace block indices of singleton dimensions with zero

    >>> nblocks = (4, 1, 2)  # note singleton dimension in second place
    >>> lol = [[('x', 1, 0, 0), ('x', 1, 0, 1)],
    ...        [('x', 1, 1, 0), ('x', 1, 1, 1)],
    ...        [('x', 1, 2, 0), ('x', 1, 2, 1)]]

    >>> zero_broadcast_dimensions(lol, nblocks)  # doctest: +NORMALIZE_WHITESPACE
    [[('x', 1, 0, 0), ('x', 1, 0, 1)],
     [('x', 1, 0, 0), ('x', 1, 0, 1)],
     [('x', 1, 0, 0), ('x', 1, 0, 1)]]

    See Also
    --------
    lol_tuples
    """
    f = lambda t: (t[0],) + tuple(0 if d == 1 else i
                                  for i, d in zip(t[1:], nblocks))
    return utils.homogeneous_deepmap(f, lol)


def broadcast_dimensions(argpairs, numblocks, sentinels=(1, (1,)),
                         consolidate=None):
    """ Find block dimensions from arguments

    Parameters
    ----------
    argpairs: iterable
        name, ijk index pairs
    numblocks: dict
        maps {name: number of blocks}
    sentinels: iterable (optional)
        values for singleton dimensions
    consolidate: func (optional)
        use this to reduce each set of common blocks into a smaller set

    Examples
    --------
    >>> argpairs = [('x', 'ij'), ('y', 'ji')]
    >>> numblocks = {'x': (2, 3), 'y': (3, 2)}
    >>> broadcast_dimensions(argpairs, numblocks)
    {'i': 2, 'j': 3}

    Supports numpy broadcasting rules

    >>> argpairs = [('x', 'ij'), ('y', 'ij')]
    >>> numblocks = {'x': (2, 1), 'y': (1, 3)}
    >>> broadcast_dimensions(argpairs, numblocks)
    {'i': 2, 'j': 3}

    Works in other contexts too

    >>> argpairs = [('x', 'ij'), ('y', 'ij')]
    >>> d = {'x': ('Hello', 1), 'y': (1, (2, 3))}
    >>> broadcast_dimensions(argpairs, d)
    {'i': 'Hello', 'j': (2, 3)}
    """
    # List like [('i', 2), ('j', 1), ('i', 1), ('j', 2)]
    argpairs2 = [(a, ind) for a, ind in argpairs if ind is not None]
    L = toolz.concat([zip(inds, dims) for (x, inds), (x, dims)
                      in toolz.join(toolz.first, argpairs2,
                                    toolz.first, numblocks.items())])

    g = toolz.groupby(0, L)
    g = dict((k, set([d for i, d in v])) for k, v in g.items())

    g2 = dict((k, v - set(sentinels) if len(v) > 1 else v)
              for k, v in g.items())

    if consolidate:
        return toolz.valmap(consolidate, g2)

    if g2 and not set(map(len, g2.values())) == set([1]):
        raise ValueError("Shapes do not align %s" % g)

    return toolz.valmap(toolz.first, g2)
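
# --- Illustrative sketch (editorial example, not used by the library) ---
# The ``consolidate`` hook reduces each set of per-index block counts
# itself, e.g. taking the maximum rather than requiring exact agreement.
def _example_broadcast_consolidate():
    """ Resolve mismatched block counts by taking the maximum """
    argpairs = [('x', 'ij'), ('y', 'ij')]
    numblocks = {'x': (2, 4), 'y': (4, 4)}
    # Without ``consolidate`` this would raise "Shapes do not align" because
    # {2, 4} along 'i' cannot be reconciled by broadcasting alone.
    return broadcast_dimensions(argpairs, numblocks,
                                consolidate=max)  # {'i': 4, 'j': 4}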