import numbers
import warnings

import toolz

from .. import base, utils
from ..delayed import unpack_collections
from ..highlevelgraph import HighLevelGraph
from ..blockwise import blockwise as core_blockwise


def _unaligned_chunks(args):
    """Map each index to a chunking without touching the input arrays.

    ``args`` is the flat ``(array, index, array, index, ...)`` sequence given
    to ``blockwise``.  Pairs whose index is ``None`` are literals and are
    ignored.
    """
    indexed = [(a, i) for (a, i) in toolz.partition(2, args) if i is not None]
    if not indexed:
        # Only literals were passed: there is nothing to take chunks from.
        return {}

    # The argument with the longest index decides the chunking of every
    # index it carries.
    arg, ind = max(indexed, key=lambda ai: len(ai[1]))
    chunkss = dict(zip(ind, arg.chunks))

    # Indices that only appear on other arguments (e.g. the ``j`` of an outer
    # product) would otherwise be missing and fail with a ``KeyError`` when
    # the output chunks are assembled.  For those, keep the chunking with the
    # most blocks so that single-block (broadcast) dimensions do not win.
    settled = set(chunkss)
    for other, other_ind in indexed:
        for i, c in zip(other_ind, getattr(other, 'chunks', ())):
            if i in settled:
                continue
            if i not in chunkss or len(c) > len(chunkss[i]):
                chunkss[i] = c
    return chunkss


def _apply_adjust_chunks(chunks, out_ind, adjust_chunks):
    """Rewrite the entries of ``chunks`` whose index is in ``adjust_chunks``.

    ``chunks`` is a list ordered like ``out_ind``; it is modified in place
    and returned.
    """
    for i, ind in enumerate(out_ind):
        if ind not in adjust_chunks:
            continue
        adjustment = adjust_chunks[ind]
        if callable(adjustment):
            # Function of the old chunk size
            chunks[i] = tuple(map(adjustment, chunks[i]))
        elif isinstance(adjustment, numbers.Integral):
            # Same size for every chunk along this index
            chunks[i] = tuple(adjustment for _ in chunks[i])
        elif isinstance(adjustment, (tuple, list)):
            # Explicit chunk sizes
            chunks[i] = tuple(adjustment)
        else:
            raise NotImplementedError(
                "adjust_chunks values must be callable, int, or tuple")
    return chunks


def blockwise(func, out_ind, *args, **kwargs):
    """ Tensor operation: Generalized inner and outer products

    A broad class of blocked algorithms and patterns can be specified with a
    concise multi-index notation.  The ``blockwise`` function applies an
    in-memory function across multiple blocks of multiple inputs in a variety
    of ways.  Many dask.array operations are special cases of blockwise
    including elementwise, broadcasting, reductions, tensordot, and transpose.

    Parameters
    ----------
    func : callable
        Function to apply to individual tuples of blocks
    out_ind : iterable
        Block pattern of the output, something like 'ijk' or (1, 2, 3)
    *args : sequence of Array, index pairs
        Sequence like (x, 'ij', y, 'jk', z, 'i')
    **kwargs : dict
        Extra keyword arguments to pass to function
    dtype : np.dtype
        Datatype of resulting array.  Required.
    concatenate : bool, keyword only
        If true concatenate arrays along dummy indices, else provide lists
    adjust_chunks : dict
        Dictionary mapping index to what replaces the chunk sizes along that
        index: a function applied to each chunk size, an integer used for
        every chunk, or a tuple/list of explicit chunk sizes
    new_axes : dict, keyword only
        New indexes and their dimension lengths.  ``None`` is treated like an
        empty dict.
    name : str, keyword only
        Name of the output array.  Generated from ``token`` (or the name of
        ``func``) and a hash of the inputs if not given.
    token : str, keyword only
        Prefix used for the generated name when ``name`` is not given
    align_arrays : bool, keyword only
        If true (default) the inputs are passed through ``unify_chunks``
        first.  If false the inputs are used as they are and the chunks of
        each index are read from the inputs.

    Returns
    -------
    Array

    Raises
    ------
    ValueError
        If ``out_ind`` repeats an index, contains an index that is neither on
        an input nor in ``new_axes``, if ``dtype`` is missing, or if an index
        does not match the number of dimensions of its array.
    NotImplementedError
        If a value of ``adjust_chunks`` is not a callable, int, tuple or list.

    Examples
    --------
    2D embarrassingly parallel operation from two arrays, x, and y.

    >>> z = blockwise(operator.add, 'ij', x, 'ij', y, 'ij', dtype='f8')  # z = x + y  # doctest: +SKIP

    Outer product multiplying x by y, two 1-d vectors

    >>> z = blockwise(operator.mul, 'ij', x, 'i', y, 'j', dtype='f8')  # doctest: +SKIP

    z = x.T

    >>> z = blockwise(np.transpose, 'ji', x, 'ij', dtype=x.dtype)  # doctest: +SKIP

    The transpose case above is illustrative because it does same transposition
    both on each in-memory block by calling ``np.transpose`` and on the order
    of the blocks themselves, by switching the order of the index ``ij -> ji``.

    We can compose these same patterns with more variables and more complex
    in-memory functions

    z = X + Y.T

    >>> z = blockwise(lambda x, y: x + y.T, 'ij', x, 'ij', y, 'ji', dtype='f8')  # doctest: +SKIP

    Any index, like ``i`` missing from the output index is interpreted as a
    contraction (note that this differs from Einstein convention; repeated
    indices do not imply contraction.)  In the case of a contraction the passed
    function should expect an iterable of blocks on any array that holds that
    index.  To receive arrays concatenated along contracted dimensions instead
    pass ``concatenate=True``.

    Inner product multiplying x by y, two 1-d vectors

    >>> def sequence_dot(x_blocks, y_blocks):
    ...     result = 0
    ...     for x, y in zip(x_blocks, y_blocks):
    ...         result += x.dot(y)
    ...     return result

    >>> z = blockwise(sequence_dot, '', x, 'i', y, 'i', dtype='f8')  # doctest: +SKIP

    Add new single-chunk dimensions with the ``new_axes=`` keyword, including
    the length of the new dimension.  New dimensions will always be in a single
    chunk.

    >>> def f(x):
    ...     return x[:, None] * np.ones((1, 5))

    >>> z = blockwise(f, 'az', x, 'a', new_axes={'z': 5}, dtype=x.dtype)  # doctest: +SKIP

    If the applied function changes the size of each chunk you can specify this
    with a ``adjust_chunks={...}`` dictionary holding a function for each index
    that modifies the dimension size in that index.

    >>> def double(x):
    ...     return np.concatenate([x, x])

    >>> y = blockwise(double, 'ij', x, 'ij',
    ...               adjust_chunks={'i': lambda n: 2 * n}, dtype=x.dtype)  # doctest: +SKIP

    Include literals by indexing with None

    >>> y = blockwise(add, 'ij', x, 'ij', 1234, None, dtype=x.dtype)  # doctest: +SKIP

    See Also
    --------
    top - dict formulation of this function, contains most logic
    """
    out = kwargs.pop('name', None)      # May be None at this point
    token = kwargs.pop('token', None)
    dtype = kwargs.pop('dtype', None)
    adjust_chunks = kwargs.pop('adjust_chunks', None)
    # An explicit ``new_axes=None`` means "no new axes"; normalise it so the
    # ``.items()`` call below cannot fail.
    new_axes = kwargs.pop('new_axes', None) or {}
    align_arrays = kwargs.pop('align_arrays', True)

    # Input Validation
    if len(set(out_ind)) != len(out_ind):
        raise ValueError("Repeated elements not allowed in output index",
                         [k for k, v in toolz.frequencies(out_ind).items() if v > 1])
    new = (set(out_ind)
           - {a for arg in args[1::2] if arg is not None for a in arg}
           - set(new_axes or ()))
    if new:
        raise ValueError("Unknown dimension", new)

    # Imported here rather than at module level, as in the original code
    from .core import Array, unify_chunks, normalize_arg

    if dtype is None:
        raise ValueError("Must specify dtype of output array")

    if align_arrays:
        chunkss, arrays = unify_chunks(*args)
    else:
        chunkss = _unaligned_chunks(args)
        arrays = args[::2]

    # New axes are described by their length(s); a bare length becomes a
    # single chunk.
    for k, v in new_axes.items():
        if not isinstance(v, tuple):
            v = (v,)
        chunkss[k] = v

    arginds = list(zip(arrays, args[1::2]))

    for arg, ind in arginds:
        if hasattr(arg, 'ndim') and hasattr(ind, '__len__') and arg.ndim != len(ind):
            raise ValueError("Index string %s does not match array dimension %d"
                             % (ind, arg.ndim))

    numblocks = {a.name: a.numblocks for a, ind in arginds if ind is not None}

    dependencies = []
    arrays = []

    # Normalize arguments
    argindsstr = []
    for a, ind in arginds:
        if ind is None:
            # Literal argument: pull out any collections it holds
            a = normalize_arg(a)
            a, collections = unpack_collections(a)
            dependencies.extend(collections)
        else:
            # Array argument: referred to by name in the graph
            arrays.append(a)
            a = a.name
        argindsstr.extend((a, ind))

    # Normalize keyword arguments
    kwargs2 = {}
    for k, v in kwargs.items():
        v = normalize_arg(v)
        v, collections = unpack_collections(v)
        dependencies.extend(collections)
        kwargs2[k] = v

    # Finish up the name
    if not out:
        out = '%s-%s' % (token or utils.funcname(func).strip('_'),
                         base.tokenize(func, out_ind, argindsstr, dtype, **kwargs))

    graph = core_blockwise(func, out, out_ind, *argindsstr, numblocks=numblocks,
                           dependencies=dependencies, new_axes=new_axes, **kwargs2)
    graph = HighLevelGraph.from_collections(out, graph,
                                            dependencies=arrays + dependencies)

    chunks = [chunkss[i] for i in out_ind]
    if adjust_chunks:
        chunks = _apply_adjust_chunks(chunks, out_ind, adjust_chunks)
    chunks = tuple(chunks)

    return Array(graph, out, chunks, dtype=dtype)


def atop(*args, **kwargs):
    """Former name of :func:`blockwise`; warns and forwards all arguments."""
    # stacklevel=2 attributes the warning to the caller of ``atop`` rather
    # than to this wrapper.
    warnings.warn("The da.atop function has moved to da.blockwise",
                  stacklevel=2)
    return blockwise(*args, **kwargs)