You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
77 lines
1.9 KiB
77 lines
1.9 KiB
6 years ago
|
from __future__ import absolute_import, division, print_function
|
||
|
|
||
|
from functools import partial
|
||
|
from itertools import product
|
||
|
|
||
|
import numpy as np
|
||
|
|
||
|
try:
|
||
|
from cytoolz import curry
|
||
|
except ImportError:
|
||
|
from toolz import curry
|
||
|
|
||
|
from ..base import tokenize
|
||
|
from ..utils import funcname
|
||
|
from .core import Array, normalize_chunks
|
||
|
|
||
|
|
||
|
def wrap_func_shape_as_first_arg(func, *args, **kwargs):
|
||
|
"""
|
||
|
Transform np creation function into blocked version
|
||
|
"""
|
||
|
if 'shape' not in kwargs:
|
||
|
shape, args = args[0], args[1:]
|
||
|
else:
|
||
|
shape = kwargs.pop('shape')
|
||
|
|
||
|
if isinstance(shape, np.ndarray):
|
||
|
shape = shape.tolist()
|
||
|
|
||
|
if not isinstance(shape, (tuple, list)):
|
||
|
shape = (shape,)
|
||
|
|
||
|
chunks = kwargs.pop('chunks', 'auto')
|
||
|
|
||
|
dtype = kwargs.pop('dtype', None)
|
||
|
if dtype is None:
|
||
|
dtype = func(shape, *args, **kwargs).dtype
|
||
|
dtype = np.dtype(dtype)
|
||
|
|
||
|
chunks = normalize_chunks(chunks, shape, dtype=dtype)
|
||
|
name = kwargs.pop('name', None)
|
||
|
|
||
|
name = name or funcname(func) + '-' + tokenize(func, shape, chunks, dtype, args, kwargs)
|
||
|
|
||
|
keys = product([name], *[range(len(bd)) for bd in chunks])
|
||
|
shapes = product(*chunks)
|
||
|
func = partial(func, dtype=dtype, **kwargs)
|
||
|
vals = ((func,) + (s,) + args for s in shapes)
|
||
|
|
||
|
dsk = dict(zip(keys, vals))
|
||
|
return Array(dsk, name, chunks, dtype=dtype)
|
||
|
|
||
|
|
||
|
@curry
|
||
|
def wrap(wrap_func, func, **kwargs):
|
||
|
f = partial(wrap_func, func, **kwargs)
|
||
|
template = """
|
||
|
Blocked variant of %(name)s
|
||
|
|
||
|
Follows the signature of %(name)s exactly except that it also requires a
|
||
|
keyword argument chunks=(...)
|
||
|
|
||
|
Original signature follows below.
|
||
|
"""
|
||
|
if func.__doc__ is not None:
|
||
|
f.__doc__ = template % {'name': func.__name__} + func.__doc__
|
||
|
f.__name__ = 'blocked_' + func.__name__
|
||
|
return f
|
||
|
|
||
|
|
||
|
w = wrap(wrap_func_shape_as_first_arg)
|
||
|
|
||
|
ones = w(np.ones, dtype='f8')
|
||
|
zeros = w(np.zeros, dtype='f8')
|
||
|
empty = w(np.empty, dtype='f8')
|
||
|
full = w(np.full)
|