You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ORPA-pyOpenRPA/WPy32-3720/python-3.7.2/Lib/site-packages/dask/array/optimization.py

301 lines
11 KiB

6 years ago
from __future__ import absolute_import, division, print_function
from operator import getitem
import numpy as np
from .core import getter, getter_nofancy, getter_inline
from ..blockwise import optimize_blockwise
from ..compatibility import zip_longest
from ..core import flatten, reverse_dict
from ..optimization import cull, fuse, inline_functions
from ..utils import ensure_dict
from ..highlevelgraph import HighLevelGraph
from numbers import Integral
# All get* functions the optimizations know about
GETTERS = (getter, getter_nofancy, getter_inline, getitem)
# These get* functions are never completely removed from the graph,
# even if the index should be a no-op by numpy semantics. Some array-likes
# don't completely follow those semantics, making indexing always necessary.
GETNOREMOVE = (getter, getter_nofancy)
def optimize(dsk, keys, fuse_keys=None, fast_functions=None,
             inline_functions_fast_functions=(getter_inline,),
             rename_fused_keys=True, **kwargs):
    """ Optimize dask for array computation

    1.  Cull tasks not necessary to evaluate keys
    2.  Remove full slicing, e.g. x[:]
    3.  Inline fast functions like getitem and np.transpose

    Parameters
    ----------
    dsk : dict-like or HighLevelGraph
        Graph to optimize. A ``HighLevelGraph`` is first run through
        ``optimize_blockwise`` and then converted with ``ensure_dict``.
    keys : key or (nested) list of keys
        Output keys; flattened before use.
    fuse_keys : list, optional
        Extra keys handed to ``fuse`` together with the held keys and the
        output keys.
    fast_functions : optional
        When not None, replaces ``inline_functions_fast_functions``.
    inline_functions_fast_functions : optional
        Functions passed to ``inline_functions``; a falsy value skips the
        inlining step.
    rename_fused_keys : bool
        Forwarded to ``fuse`` as ``rename_keys``.
    **kwargs
        Accepted and ignored.

    Returns
    -------
    The optimized low-level graph produced by ``optimize_slices``.
    """
    keys = list(flatten(keys))

    # High level stage optimization
    if isinstance(dsk, HighLevelGraph):
        dsk = optimize_blockwise(dsk, keys=keys)

    # Low level task optimizations
    graph = ensure_dict(dsk)
    fast = (inline_functions_fast_functions if fast_functions is None
            else fast_functions)

    culled, dependencies = cull(graph, keys)
    protected = hold_keys(culled, dependencies) + keys + (fuse_keys or [])
    result, dependencies = fuse(culled, protected, dependencies,
                                rename_keys=rename_fused_keys)
    if fast:
        result = inline_functions(result, keys, dependencies=dependencies,
                                  fast_functions=fast)
    return optimize_slices(result)
def hold_keys(dsk, dependencies):
    """ Find keys to avoid fusion

    Raw data stored in the graph (any value that is neither a tuple nor a
    str) is kept unfused because it is easier to serialize as a raw value.

    Chains of getitem/GETTERS tasks hanging off such data are not fused
    past their end either: we want to move around only small pieces of
    data, rather than the underlying arrays.

    Returns a list of keys; it may contain duplicates.
    """
    dependents = reverse_dict(dependencies)
    data = {key for key, value in dsk.items()
            if type(value) not in (tuple, str)}

    held = list(data)
    for datum in data:
        for dep in dependents[datum]:
            task = dsk[dep]
            if not (type(task) is tuple and task and task[0] in GETTERS):
                continue
            # The task is a get* function: walk up the chain and stop when
            # there is more than one dependent, or the single dependent is
            # neither a get* function nor an alias. The key reached last is
            # the one that must not be fused.
            try:
                while len(dependents[dep]) == 1:
                    parent = next(iter(dependents[dep]))
                    parent_task = dsk[parent]
                    if not (parent_task[0] in GETTERS or parent_task in dsk):
                        break
                    # still a get* or an alias: keep climbing the chain
                    dep = parent
            except (IndexError, TypeError):
                # parent task is not indexable/hashable; stop at current key
                pass
            held.append(dep)
    return held
def optimize_slices(dsk):
    """ Optimize slices

    1.  Fuse repeated slices, like x[5:][2:6] -> x[7:11]
    2.  Remove full slices, like x[:] -> x

    The rewrite is applied to ``dsk.copy()``, which is returned. Only values
    that are non-empty tuples of length 3, ``(get, array, index)``, or
    length 5, ``(get, array, index, asarray, lock)``, whose first element is
    in ``GETTERS`` are considered. Empty tuples are left untouched instead
    of raising ``IndexError``.

    See also:
        fuse_slice
    """
    fancy_ind_types = (list, np.ndarray)
    dsk = dsk.copy()
    for k, v in dsk.items():
        # ``v and`` guards against an empty tuple before indexing v[0]
        if type(v) is tuple and v and v[0] in GETTERS and len(v) in (3, 5):
            if len(v) == 3:
                get, a, a_index = v
                # getter defaults to asarray=True, getitem is semantically False
                a_asarray = get is not getitem
                a_lock = None
            else:
                get, a, a_index, a_asarray, a_lock = v
            # Keep absorbing nested get* calls while they can be fused
            while (type(a) is tuple and a and a[0] in GETTERS and
                   len(a) in (3, 5)):
                if len(a) == 3:
                    f2, b, b_index = a
                    b_asarray = f2 is not getitem
                    b_lock = None
                else:
                    f2, b, b_index, b_asarray, b_lock = a

                # An outer lock may only be merged with the very same lock
                if a_lock and a_lock is not b_lock:
                    break
                # Tuple and non-tuple indices are not mixed
                if (type(a_index) is tuple) != (type(b_index) is tuple):
                    break
                if type(a_index) is tuple:
                    indices = b_index + a_index
                    if (len(a_index) != len(b_index) and
                            any(i is None for i in indices)):
                        break
                    if (f2 is getter_nofancy and
                            any(isinstance(i, fancy_ind_types)
                                for i in indices)):
                        break
                elif (f2 is getter_nofancy and
                        (type(a_index) in fancy_ind_types or
                         type(b_index) in fancy_ind_types)):
                    break
                try:
                    c_index = fuse_slice(b_index, a_index)
                    # rely on fact that nested gets never decrease in
                    # strictness e.g. `(getter_nofancy, (getter, ...))` never
                    # happens
                    get = getter if f2 is getter_inline else f2
                except NotImplementedError:
                    break
                a, a_index, a_lock = b, c_index, b_lock
                a_asarray |= b_asarray

            # Skip the get call if not from from_array and nothing to do
            if (get not in GETNOREMOVE and
                ((type(a_index) is slice and not a_index.start and
                  a_index.stop is None and a_index.step is None) or
                 (type(a_index) is tuple and
                  all(type(s) is slice and not s.start and s.stop is None and
                      s.step is None for s in a_index)))):
                dsk[k] = a
            elif get is getitem or (a_asarray and not a_lock):
                # Default settings are fine, drop the extra parameters. Since
                # we always fall back to inner `get` functions, `get is
                # getitem` can only occur if all gets are getitem, meaning all
                # parameters must be getitem defaults.
                dsk[k] = (get, a, a_index)
            else:
                dsk[k] = (get, a, a_index, a_asarray, a_lock)

    return dsk
def normalize_slice(s):
    """ Replace Nones in slices with integers

    A missing start becomes 0 and a missing step becomes 1; stop is passed
    through unchanged. Negative start, step or stop is not supported and
    raises ``NotImplementedError``.

    >>> normalize_slice(slice(None, None, None))
    slice(0, None, 1)
    """
    start = 0 if s.start is None else s.start
    step = 1 if s.step is None else s.step
    stop = s.stop

    if start < 0 or step < 0 or (stop is not None and stop < 0):
        raise NotImplementedError()
    return slice(start, stop, step)
def check_for_nonfusible_fancy_indexing(fancy, normal):
    """ Reject index pairs mixing fancy and integer indexing on other axes

    ``fancy`` and ``normal`` are walked in lockstep, the shorter one padded
    with full slices. If any position holds an integer in ``normal`` while
    the matching entry of ``fancy`` is not a list, ``NotImplementedError``
    is raised. E.g. this disallows things like:

    x[:, [1, 2], :][0, :, :] -> x[0, [1, 2], :] or
    x[0, :, :][:, [1, 2], :] -> x[0, [1, 2], :]

    Returns None when no such position exists.
    """
    pairs = zip_longest(fancy, normal, fillvalue=slice(None))
    misaligned = any(isinstance(n, Integral) and type(f) is not list
                     for f, n in pairs)
    if misaligned:
        raise NotImplementedError("Can't handle normal indexing with "
                                  "integers and fancy indexing if the "
                                  "integers and fancy indices don't "
                                  "align with the same dimensions.")
def fuse_slice(a, b):
    """ Fuse stacked slices together

    Return a single index ``c`` such that ``x[a][b]`` can be written
    ``x[c]``. ``a`` is the inner (first applied) index and ``b`` the outer
    one. Raises ``NotImplementedError`` for combinations that are not
    handled (negative values, two list indices in tuples, unsupported
    types, ...).

    Fuse a pair of repeated slices into a single slice:

    >>> fuse_slice(slice(1000, 2000), slice(10, 15))
    slice(1010, 1015, None)

    This also works for tuples of slices

    >>> fuse_slice((slice(100, 200), slice(100, 200, 10)),
    ...            (slice(10, 15), [5, 2]))
    (slice(110, 115, None), [150, 120])

    And a variety of other interesting cases

    >>> fuse_slice(slice(1000, 2000), 10)  # integers
    1010

    >>> fuse_slice(slice(1000, 2000, 5), slice(10, 20, 2))
    slice(1050, 1100, 10)

    >>> fuse_slice(slice(1000, 2000, 5), [1, 2, 3])  # lists
    [1005, 1010, 1015]

    >>> fuse_slice(None, slice(None, None))  # doctest: +SKIP
    None

    Trailing ``None`` entries on the right-hand side are inserted and the
    remaining left-hand axes pass through unchanged

    >>> fuse_slice((slice(None), slice(None)), (slice(None), None))
    (slice(0, None, None), None, slice(None, None, None))
    """
    # None only works if the second side is a full slice
    if a is None and isinstance(b, slice) and b == slice(None, None):
        return None

    # Replace None with 0 and one in start and step
    if isinstance(a, slice):
        a = normalize_slice(a)
    if isinstance(b, slice):
        b = normalize_slice(b)

    if isinstance(a, slice) and isinstance(b, Integral):
        if b < 0:
            raise NotImplementedError()
        return a.start + b * a.step

    if isinstance(a, slice) and isinstance(b, slice):
        start = a.start + a.step * b.start
        if b.stop is not None:
            stop = a.start + a.step * b.stop
        else:
            stop = None
        if a.stop is not None:
            # the fused slice may never reach past the inner slice's end
            if stop is not None:
                stop = min(a.stop, stop)
            else:
                stop = a.stop
        step = a.step * b.step
        if step == 1:
            step = None
        return slice(start, stop, step)

    if isinstance(b, list):
        return [fuse_slice(a, bb) for bb in b]
    if isinstance(a, list) and isinstance(b, (Integral, slice)):
        return a[b]

    if isinstance(a, tuple) and not isinstance(b, tuple):
        b = (b,)

    # If given two tuples walk through both, being mindful of uneven sizes
    # and newaxes
    if isinstance(a, tuple) and isinstance(b, tuple):

        # Check for non-fusible cases with fancy-indexing
        a_has_lists = any(isinstance(item, list) for item in a)
        b_has_lists = any(isinstance(item, list) for item in b)
        if a_has_lists and b_has_lists:
            raise NotImplementedError("Can't handle multiple list indexing")
        elif a_has_lists:
            check_for_nonfusible_fancy_indexing(a, b)
        elif b_has_lists:
            check_for_nonfusible_fancy_indexing(b, a)

        j = 0
        result = []
        for item in a:
            # axis ceased to exist or we're out of b
            if isinstance(item, Integral) or j == len(b):
                result.append(item)
                continue
            # insert any Nones on the rhs; bounded so that a trailing None
            # in b cannot index past the end of b
            while j < len(b) and b[j] is None:
                result.append(None)
                j += 1
            if j == len(b):
                # b was exhausted by newaxis entries: keep this axis as is
                result.append(item)
                continue
            result.append(fuse_slice(item, b[j]))  # Common case
            j += 1
        # anything leftover on the right?
        result.extend(b[j:])
        return tuple(result)
    raise NotImplementedError()