from __future__ import absolute_import, division, print_function from operator import getitem from itertools import product from numbers import Integral from toolz import merge, pipe, concat, partial from toolz.curried import map from . import chunk, wrap from .core import Array, map_blocks, concatenate, concatenate3, reshapelist from ..highlevelgraph import HighLevelGraph from ..base import tokenize from ..core import flatten from ..utils import concrete def fractional_slice(task, axes): """ >>> fractional_slice(('x', 5.1), {0: 2}) # doctest: +SKIP (getitem, ('x', 6), (slice(0, 2),)) >>> fractional_slice(('x', 3, 5.1), {0: 2, 1: 3}) # doctest: +SKIP (getitem, ('x', 3, 5), (slice(None, None, None), slice(-3, None))) >>> fractional_slice(('x', 2.9, 5.1), {0: 2, 1: 3}) # doctest: +SKIP (getitem, ('x', 3, 5), (slice(0, 2), slice(-3, None))) """ rounded = (task[0],) + tuple(int(round(i)) for i in task[1:]) index = [] for i, (t, r) in enumerate(zip(task[1:], rounded[1:])): depth = axes.get(i, 0) if t == r: index.append(slice(None, None, None)) elif t < r: index.append(slice(0, depth)) elif t > r and depth == 0: index.append(slice(0, 0)) else: index.append(slice(-depth, None)) index = tuple(index) if all(ind == slice(None, None, None) for ind in index): return task else: return (getitem, rounded, index) def expand_key(k, dims, name=None, axes=None): """ Get all neighboring keys around center Parameters ---------- k: tuple They key around which to generate new keys dims: Sequence[int] The number of chunks in each dimension name: Option[str] The name to include in the output keys, or none to include no name axes: Dict[int, int] The axes active in the expansion. We don't expand on non-active axes Examples -------- >>> expand_key(('x', 2, 3), dims=[5, 5], name='y', axes={0: 1, 1: 1}) # doctest: +NORMALIZE_WHITESPACE [[('y', 1.1, 2.1), ('y', 1.1, 3), ('y', 1.1, 3.9)], [('y', 2, 2.1), ('y', 2, 3), ('y', 2, 3.9)], [('y', 2.9, 2.1), ('y', 2.9, 3), ('y', 2.9, 3.9)]] >>> expand_key(('x', 0, 4), dims=[5, 5], name='y', axes={0: 1, 1: 1}) # doctest: +NORMALIZE_WHITESPACE [[('y', 0, 3.1), ('y', 0, 4)], [('y', 0.9, 3.1), ('y', 0.9, 4)]] """ def inds(i, ind): rv = [] if ind - 0.9 > 0: rv.append(ind - 0.9) rv.append(ind) if ind + 0.9 < dims[i] - 1: rv.append(ind + 0.9) return rv shape = [] for i, ind in enumerate(k[1:]): num = 1 if ind > 0: num += 1 if ind < dims[i] - 1: num += 1 shape.append(num) args = [inds(i, ind) if axes.get(i, 0) else [ind] for i, ind in enumerate(k[1:])] if name is not None: args = [[name]] + args seq = list(product(*args)) shape2 = [d if axes.get(i, 0) else 1 for i, d in enumerate(shape)] result = reshapelist(shape2, seq) return result def overlap_internal(x, axes): """ Share boundaries between neighboring blocks Parameters ---------- x: da.Array A dask array axes: dict The size of the shared boundary per axis The axes input informs how many cells to overlap between neighboring blocks {0: 2, 2: 5} means share two cells in 0 axis, 5 cells in 2 axis """ dims = list(map(len, x.chunks)) expand_key2 = partial(expand_key, dims=dims, axes=axes) # Make keys for each of the surrounding sub-arrays interior_keys = pipe(x.__dask_keys__(), flatten, map(expand_key2), map(flatten), concat, list) name = 'overlap-' + tokenize(x, axes) getitem_name = 'getitem-' + tokenize(x, axes) interior_slices = {} overlap_blocks = {} for k in interior_keys: frac_slice = fractional_slice((x.name,) + k, axes) if (x.name,) + k != frac_slice: interior_slices[(getitem_name,) + k] = frac_slice else: interior_slices[(getitem_name,) + k] = (x.name,) + k overlap_blocks[(name,) + k] = (concatenate3, (concrete, expand_key2((None,) + k, name=getitem_name))) chunks = [] for i, bds in enumerate(x.chunks): if len(bds) == 1: chunks.append(bds) else: left = [bds[0] + axes.get(i, 0)] right = [bds[-1] + axes.get(i, 0)] mid = [] for bd in bds[1:-1]: mid.append(bd + axes.get(i, 0) * 2) chunks.append(left + mid + right) dsk = merge(interior_slices, overlap_blocks) graph = HighLevelGraph.from_collections(name, dsk, dependencies=[x]) return Array(graph, name, chunks, dtype=x.dtype) def trim_overlap(x, depth, boundary=None): """Trim sides from each block. This couples well with the ``map_overlap`` operation which may leave excess data on each block. See also -------- dask.array.overlap.map_overlap """ # parameter to be passed to trim_internal axes = coerce_depth(x.ndim, depth) boundary2 = coerce_boundary(x.ndim, boundary) return trim_internal(x, axes=axes, boundary=boundary2) def trim_internal(x, axes, boundary=None): """ Trim sides from each block This couples well with the overlap operation, which may leave excess data on each block See also -------- dask.array.chunk.trim dask.array.map_blocks """ boundary = coerce_boundary(x.ndim, boundary) olist = [] for i, bd in enumerate(x.chunks): bdy = boundary.get(i, 'none') ilist = [] for j, d in enumerate(bd): if bdy != 'none': d = d - axes.get(i, 0) * 2 else: d = d - axes.get(i, 0) if j != 0 else d d = d - axes.get(i, 0) if j != len(bd) - 1 else d ilist.append(d) olist.append(tuple(ilist)) chunks = tuple(olist) return map_blocks(partial(_trim, axes=axes, boundary=boundary), x, chunks=chunks, dtype=x.dtype) def _trim(x, axes, boundary, block_info): """Similar to dask.array.chunk.trim but requires one to specificy the boundary condition. ``axes``, and ``boundary`` are assumed to have been coerced. """ axes = [axes.get(i, 0) for i in range(x.ndim)] axes_back = (-ax if ax else None for ax in axes) trim_front = ( 0 if (chunk_location == 0 and boundary.get(i, 'none') == 'none') else ax for i, (chunk_location, ax) in enumerate( zip(block_info[0]['chunk-location'], axes))) trim_back = ( None if (chunk_location == chunks - 1 and boundary.get(i, 'none') == 'none') else ax for i, (chunks, chunk_location, ax) in enumerate(zip( block_info[0]['num-chunks'], block_info[0]['chunk-location'], axes_back))) return x[tuple(slice(front, back) for front, back in zip(trim_front, trim_back))] def periodic(x, axis, depth): """ Copy a slice of an array around to its other side Useful to create periodic boundary conditions for overlap """ left = ((slice(None, None, None),) * axis + (slice(0, depth),) + (slice(None, None, None),) * (x.ndim - axis - 1)) right = ((slice(None, None, None),) * axis + (slice(-depth, None),) + (slice(None, None, None),) * (x.ndim - axis - 1)) l = x[left] r = x[right] l, r = _remove_overlap_boundaries(l, r, axis, depth) return concatenate([r, x, l], axis=axis) def reflect(x, axis, depth): """ Reflect boundaries of array on the same side This is the converse of ``periodic`` """ if depth == 1: left = ((slice(None, None, None),) * axis + (slice(0, 1),) + (slice(None, None, None),) * (x.ndim - axis - 1)) else: left = ((slice(None, None, None),) * axis + (slice(depth - 1, None, -1),) + (slice(None, None, None),) * (x.ndim - axis - 1)) right = ((slice(None, None, None),) * axis + (slice(-1, -depth - 1, -1),) + (slice(None, None, None),) * (x.ndim - axis - 1)) l = x[left] r = x[right] l, r = _remove_overlap_boundaries(l, r, axis, depth) return concatenate([l, x, r], axis=axis) def nearest(x, axis, depth): """ Each reflect each boundary value outwards This mimics what the skimage.filters.gaussian_filter(... mode="nearest") does. """ left = ((slice(None, None, None),) * axis + (slice(0, 1),) + (slice(None, None, None),) * (x.ndim - axis - 1)) right = ((slice(None, None, None),) * axis + (slice(-1, -2, -1),) + (slice(None, None, None),) * (x.ndim - axis - 1)) l = concatenate([x[left]] * depth, axis=axis) r = concatenate([x[right]] * depth, axis=axis) l, r = _remove_overlap_boundaries(l, r, axis, depth) return concatenate([l, x, r], axis=axis) def constant(x, axis, depth, value): """ Add constant slice to either side of array """ chunks = list(x.chunks) chunks[axis] = (depth,) c = wrap.full(tuple(map(sum, chunks)), value, chunks=tuple(chunks), dtype=x.dtype) return concatenate([c, x, c], axis=axis) def _remove_overlap_boundaries(l, r, axis, depth): lchunks = list(l.chunks) lchunks[axis] = (depth,) rchunks = list(r.chunks) rchunks[axis] = (depth,) l = l.rechunk(tuple(lchunks)) r = r.rechunk(tuple(rchunks)) return l, r def boundaries(x, depth=None, kind=None): """ Add boundary conditions to an array before overlaping See Also -------- periodic constant """ if not isinstance(kind, dict): kind = dict((i, kind) for i in range(x.ndim)) if not isinstance(depth, dict): depth = dict((i, depth) for i in range(x.ndim)) for i in range(x.ndim): d = depth.get(i, 0) if d == 0: continue this_kind = kind.get(i, 'none') if this_kind == 'none': continue elif this_kind == 'periodic': x = periodic(x, i, d) elif this_kind == 'reflect': x = reflect(x, i, d) elif this_kind == 'nearest': x = nearest(x, i, d) elif i in kind: x = constant(x, i, d, kind[i]) return x def overlap(x, depth, boundary): """ Share boundaries between neighboring blocks Parameters ---------- x: da.Array A dask array depth: dict The size of the shared boundary per axis boundary: dict The boundary condition on each axis. Options are 'reflect', 'periodic', 'nearest', 'none', or an array value. Such a value will fill the boundary with that value. The depth input informs how many cells to overlap between neighboring blocks ``{0: 2, 2: 5}`` means share two cells in 0 axis, 5 cells in 2 axis. Axes missing from this input will not be overlapped. Examples -------- >>> import numpy as np >>> import dask.array as da >>> x = np.arange(64).reshape((8, 8)) >>> d = da.from_array(x, chunks=(4, 4)) >>> d.chunks ((4, 4), (4, 4)) >>> g = da.overlap.overlap(d, depth={0: 2, 1: 1}, ... boundary={0: 100, 1: 'reflect'}) >>> g.chunks ((8, 8), (6, 6)) >>> np.array(g) array([[100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100], [100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100], [ 0, 0, 1, 2, 3, 4, 3, 4, 5, 6, 7, 7], [ 8, 8, 9, 10, 11, 12, 11, 12, 13, 14, 15, 15], [ 16, 16, 17, 18, 19, 20, 19, 20, 21, 22, 23, 23], [ 24, 24, 25, 26, 27, 28, 27, 28, 29, 30, 31, 31], [ 32, 32, 33, 34, 35, 36, 35, 36, 37, 38, 39, 39], [ 40, 40, 41, 42, 43, 44, 43, 44, 45, 46, 47, 47], [ 16, 16, 17, 18, 19, 20, 19, 20, 21, 22, 23, 23], [ 24, 24, 25, 26, 27, 28, 27, 28, 29, 30, 31, 31], [ 32, 32, 33, 34, 35, 36, 35, 36, 37, 38, 39, 39], [ 40, 40, 41, 42, 43, 44, 43, 44, 45, 46, 47, 47], [ 48, 48, 49, 50, 51, 52, 51, 52, 53, 54, 55, 55], [ 56, 56, 57, 58, 59, 60, 59, 60, 61, 62, 63, 63], [100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100], [100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100]]) """ depth2 = coerce_depth(x.ndim, depth) boundary2 = coerce_boundary(x.ndim, boundary) # is depth larger than chunk size? depth_values = [depth2.get(i, 0) for i in range(x.ndim)] for d, c in zip(depth_values, x.chunks): if d > min(c): raise ValueError("The overlapping depth %d is larger than your\n" "smallest chunk size %d. Rechunk your array\n" "with a larger chunk size or a chunk size that\n" "more evenly divides the shape of your array." % (d, min(c))) x2 = boundaries(x, depth2, boundary2) x3 = overlap_internal(x2, depth2) trim = dict((k, v * 2 if boundary2.get(k, 'none') != 'none' else 0) for k, v in depth2.items()) x4 = chunk.trim(x3, trim) return x4 def add_dummy_padding(x, depth, boundary): """ Pads an array which has 'none' as the boundary type. Used to simplify trimming arrays which use 'none'. >>> import dask.array as da >>> x = da.arange(6, chunks=3) >>> add_dummy_padding(x, {0: 1}, {0: 'none'}).compute() # doctest: +NORMALIZE_WHITESPACE array([..., 0, 1, 2, 3, 4, 5, ...]) """ for k, v in boundary.items(): d = depth.get(k, 0) if v == 'none' and d > 0: empty_shape = list(x.shape) empty_shape[k] = d empty_chunks = list(x.chunks) empty_chunks[k] = (d,) empty = wrap.empty(empty_shape, chunks=empty_chunks, dtype=x.dtype) out_chunks = list(x.chunks) ax_chunks = list(out_chunks[k]) ax_chunks[0] += d ax_chunks[-1] += d out_chunks[k] = tuple(ax_chunks) x = concatenate([empty, x, empty], axis=k) x = x.rechunk(out_chunks) return x def map_overlap(x, func, depth, boundary=None, trim=True, **kwargs): """ Map a function over blocks of the array with some overlap We share neighboring zones between blocks of the array, then map a function, then trim away the neighboring strips. Parameters ---------- func: function The function to apply to each extended block depth: int, tuple, or dict The number of elements that each block should share with its neighbors If a tuple or dict then this can be different per axis boundary: str, tuple, dict How to handle the boundaries. Values include 'reflect', 'periodic', 'nearest', 'none', or any constant value like 0 or np.nan trim: bool Whether or not to trim ``depth`` elements from each block after calling the map function. Set this to False if your mapping function already does this for you **kwargs: Other keyword arguments valid in ``map_blocks`` Examples -------- >>> import numpy as np >>> import dask.array as da >>> x = np.array([1, 1, 2, 3, 3, 3, 2, 1, 1]) >>> x = da.from_array(x, chunks=5) >>> def derivative(x): ... return x - np.roll(x, 1) >>> y = x.map_overlap(derivative, depth=1, boundary=0) >>> y.compute() array([ 1, 0, 1, 1, 0, 0, -1, -1, 0]) >>> x = np.arange(16).reshape((4, 4)) >>> d = da.from_array(x, chunks=(2, 2)) >>> d.map_overlap(lambda x: x + x.size, depth=1).compute() array([[16, 17, 18, 19], [20, 21, 22, 23], [24, 25, 26, 27], [28, 29, 30, 31]]) >>> func = lambda x: x + x.size >>> depth = {0: 1, 1: 1} >>> boundary = {0: 'reflect', 1: 'none'} >>> d.map_overlap(func, depth, boundary).compute() # doctest: +NORMALIZE_WHITESPACE array([[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23], [24, 25, 26, 27]]) """ depth2 = coerce_depth(x.ndim, depth) boundary2 = coerce_boundary(x.ndim, boundary) assert all(type(c) is int for cc in x.chunks for c in cc) g = overlap(x, depth=depth2, boundary=boundary2) assert all(type(c) is int for cc in g.chunks for c in cc) g2 = g.map_blocks(func, **kwargs) assert all(type(c) is int for cc in g2.chunks for c in cc) if trim: return trim_internal(g2, depth2, boundary2) else: return g2 def coerce_depth(ndim, depth): if isinstance(depth, Integral): depth = (depth,) * ndim if isinstance(depth, tuple): depth = dict(zip(range(ndim), depth)) return depth def coerce_boundary(ndim, boundary): if boundary is None: boundary = 'reflect' if not isinstance(boundary, (tuple, dict)): boundary = (boundary,) * ndim if isinstance(boundary, tuple): boundary = dict(zip(range(ndim), boundary)) return boundary