You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ORPA-pyOpenRPA/Resources/WPy64-3720/python-3.7.2.amd64/Lib/site-packages/dask/array/tests/test_sparse.py

136 lines
3.5 KiB

import random
from distutils.version import LooseVersion
import numpy as np
import pytest
import dask.array as da
from dask.array.utils import assert_eq
# Skip this whole module when the optional ``sparse`` package is missing.
# ``importorskip`` either returns the imported module (always truthy) or
# raises pytest's Skipped exception, so no truthiness check is needed after it.
sparse = pytest.importorskip('sparse')

# Test failures on older versions of Numba.
# Conda-Forge provides 0.35.0 on windows right now, causing failures like
# searchsorted() got an unexpected keyword argument 'side'
pytest.importorskip("numba", minversion="0.40.0")

# Skip every test in this module on NumPy releases older than 1.11.2;
# a module-level ``pytestmark`` applies the mark to all tests here.
if LooseVersion(np.__version__) < '1.11.2':
    pytestmark = pytest.mark.skip(reason="requires numpy >= 1.11.2")
# Operations that the parametrized tests in this module apply both to a dense
# dask array and to a ``sparse.COO``-backed copy of it, then compare.
# Each entry takes a 3-D array and returns an array (possibly 0-d) result.
functions = [
    # identity and element-wise arithmetic
    lambda arr: arr,
    lambda arr: da.expm1(arr),
    lambda arr: 2 * arr,
    lambda arr: arr / 2,
    lambda arr: arr ** 2,
    lambda arr: arr + arr,
    lambda arr: arr * arr,
    # indexing, new axes and axis permutation
    lambda arr: arr[0],
    lambda arr: arr[:, 1],
    lambda arr: arr[:1, None, 1:3],
    lambda arr: arr.T,
    lambda arr: da.transpose(arr, (1, 2, 0)),
    # reductions and contractions
    lambda arr: arr.sum(),
    lambda arr: arr.dot(np.arange(arr.shape[-1])),
    lambda arr: arr.dot(np.eye(arr.shape[-1])),
    lambda arr: da.tensordot(arr, np.ones(arr.shape[:2]),
                             axes=[(0, 1), (0, 1)]),
    lambda arr: arr.sum(axis=0),
    lambda arr: arr.max(axis=0),
    lambda arr: arr.sum(axis=(1, 2)),
    # dtype casts, blockwise maps and other element-wise transforms
    lambda arr: arr.astype(np.complex128),
    lambda arr: arr.map_blocks(lambda block: block * 2),
    lambda arr: arr.round(1),
    lambda arr: arr.reshape((arr.shape[0] * arr.shape[1], arr.shape[2])),
    lambda arr: abs(arr),
    lambda arr: arr > 0.5,
    # re-chunking
    lambda arr: arr.rechunk((4, 4, 4)),
    lambda arr: arr.rechunk((2, 2, 1)),
]
@pytest.mark.parametrize('func', functions)
def test_basic(func):
    """Check that ``func`` gives the same result on dense and COO-backed data."""
    dense = da.random.random((2, 3, 4), chunks=(1, 2, 2))
    dense[dense < 0.8] = 0  # zero out entries below 0.8 to sparsify the data
    coo_backed = dense.map_blocks(sparse.COO.from_numpy)

    expected = func(dense)
    actual = func(coo_backed)
    assert_eq(expected, actual)

    # 0-d results have no further density check.
    if not actual.shape:
        return
    materialized = actual.compute()
    if isinstance(materialized, sparse.COO):
        return
    # A non-COO result is expected to be "mostly dense": more than half of
    # its entries differ from 1.
    assert (materialized != 1).sum() > np.prod(materialized.shape) / 2
def test_tensordot():
    """Check ``da.tensordot`` on COO-backed blocks against the dense result."""
    lhs = da.random.random((2, 3, 4), chunks=(1, 2, 2))
    lhs[lhs < 0.8] = 0
    rhs = da.random.random((4, 3, 2), chunks=(2, 2, 1))
    rhs[rhs < 0.8] = 0

    lhs_coo = lhs.map_blocks(sparse.COO.from_numpy)
    rhs_coo = rhs.map_blocks(sparse.COO.from_numpy)

    # Contract over last-vs-first axes, the shared middle axis, and two
    # axis pairs at once.
    axes_cases = [(2, 0), (1, 1), ((1, 2), (1, 0))]
    for axes in axes_cases:
        assert_eq(da.tensordot(lhs, rhs, axes=axes),
                  da.tensordot(lhs_coo, rhs_coo, axes=axes))
@pytest.mark.xfail(reason="upstream change", strict=False)
@pytest.mark.parametrize('func', functions)
def test_mixed_concatenate(func):
    """Compare a dense+COO concatenation against the all-dense one under ``func``."""
    dense_part = da.random.random((2, 3, 4), chunks=(1, 2, 2))
    masked = da.random.random((2, 3, 4), chunks=(1, 2, 2))
    masked[masked < 0.8] = 0
    masked_coo = masked.map_blocks(sparse.COO.from_numpy)

    all_dense = da.concatenate([dense_part, masked], axis=0)
    mixed = da.concatenate([dense_part, masked_coo], axis=0)

    assert_eq(func(all_dense), func(mixed))
@pytest.mark.xfail(reason="upstream change", strict=False)
@pytest.mark.parametrize('func', functions)
def test_mixed_random(func):
    """Check that blocks randomly mixed between dense and COO match all-dense data."""
    d = da.random.random((4, 3, 4), chunks=(1, 2, 2))
    d[d < 0.7] = 0

    def maybe_sparse(block):
        # Convert a block to COO with probability 0.5, otherwise keep it
        # dense. ``random`` is not seeded here, so the dense/COO layout of
        # the blocks can differ from run to run.
        if random.random() < 0.5:
            return sparse.COO.from_numpy(block)
        return block

    s = d.map_blocks(maybe_sparse)
    dd = func(d)
    ss = func(s)
    assert_eq(dd, ss)
@pytest.mark.xfail(reason="upstream change", strict=False)
def test_mixed_output_type():
    """Check that concatenating dense zeros with COO blocks yields a COO result."""
    values = da.random.random((10, 10), chunks=(5, 5))
    values[values < 0.8] = 0
    values_coo = values.map_blocks(sparse.COO.from_numpy)

    zeros_column = da.zeros((10, 1), chunks=(5, 1))
    combined = da.concatenate([zeros_column, values_coo], axis=1)
    assert combined.shape == (10, 11)

    result = combined.compute()
    assert isinstance(result, sparse.COO)
    # The prepended column is all zeros, so the nonzero count must match
    # that of the COO part alone.
    assert result.nnz == values_coo.compute().nnz