You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ORPA-pyOpenRPA/Resources/WPy64-3720/python-3.7.2.amd64/Lib/site-packages/dask/array/tests/test_cupy.py

58 lines
1.7 KiB

import numpy as np
import pytest
import dask.array as da
from dask.array.utils import assert_eq
cupy = pytest.importorskip('cupy')
# Operations that ``test_basic`` applies both to a CuPy-backed and to a
# NumPy-backed dask array; the two results are then compared.
#
# Cases known to fail with CuPy are wrapped as
# ``pytest.param(func, marks=pytest.mark.xfail(reason=...))``.
# Do not write ``pytest.mark.xfail(func, reason=...)`` directly as a list
# element: that relies on pytest's legacy "marks applied to parametrize
# values" support, which was deprecated in pytest 3 and removed in pytest 4.
# With that form, newer pytest hands the MarkDecorator object itself to the
# test as ``func``, so the case errors instead of being xfailed.
functions = [
    lambda x: x,
    pytest.param(
        lambda x: da.expm1(x),
        marks=pytest.mark.xfail(reason="expm1 isn't a proper ufunc"),
    ),
    lambda x: 2 * x,
    lambda x: x / 2,
    lambda x: x**2,
    lambda x: x + x,
    lambda x: x * x,
    lambda x: x[0],
    lambda x: x[:, 1],
    lambda x: x[:1, None, 1:3],
    lambda x: x.T,
    lambda x: da.transpose(x, (1, 2, 0)),
    lambda x: x.sum(),
    # On the CuPy run, the dot/tensordot cases below mix a CuPy-backed
    # operand with a NumPy array.
    pytest.param(
        lambda x: x.dot(np.arange(x.shape[-1])),
        marks=pytest.mark.xfail(reason='cupy.dot(numpy) fails'),
    ),
    pytest.param(
        lambda x: x.dot(np.eye(x.shape[-1])),
        marks=pytest.mark.xfail(reason='cupy.dot(numpy) fails'),
    ),
    pytest.param(
        lambda x: da.tensordot(x, np.ones(x.shape[:2]), axes=[(0, 1), (0, 1)]),
        marks=pytest.mark.xfail(reason='cupy.dot(numpy) fails'),
    ),
    lambda x: x.sum(axis=0),
    lambda x: x.max(axis=0),
    lambda x: x.sum(axis=(1, 2)),
    lambda x: x.astype(np.complex128),
    # Inner parameter renamed so it does not shadow the outer ``x``.
    lambda x: x.map_blocks(lambda block: block * 2),
    pytest.param(
        lambda x: x.round(1),
        marks=pytest.mark.xfail(reason="cupy doesn't support round"),
    ),
    lambda x: x.reshape((x.shape[0] * x.shape[1], x.shape[2])),
    lambda x: abs(x),
    lambda x: x > 0.5,
    lambda x: x.rechunk((4, 4, 4)),
    lambda x: x.rechunk((2, 2, 1)),
    lambda x: da.einsum("ijk,ijk", x, x),
]
@pytest.mark.parametrize('func', functions)
def test_basic(func):
    """Check that ``func`` behaves the same on CuPy- and NumPy-backed dask arrays.

    A random (2, 3, 4) CuPy array and its host copy (obtained via ``.get()``)
    are each wrapped in a dask array with identical chunking. ``func`` is
    applied to both wrappers and the results are compared with ``assert_eq``.
    When the result is not 0-d, the CuPy-backed result is also computed on
    the single-threaded scheduler to confirm it is a ``cupy.ndarray``.
    """
    gpu_data = cupy.random.random((2, 3, 4))
    host_data = gpu_data.get()

    # asarray=False: dask passes the CuPy chunks through unchanged instead of
    # converting them with np.asarray.
    gpu_arr = da.from_array(gpu_data, chunks=(1, 2, 2), asarray=False)
    host_arr = da.from_array(host_data, chunks=(1, 2, 2))

    gpu_out = func(gpu_arr)
    host_out = func(host_arr)
    assert_eq(gpu_out, host_out)

    # A 0-d result (empty shape tuple, e.g. a full reduction) skips the
    # container-type check.
    if not gpu_out.shape:
        return

    computed = gpu_out.compute(scheduler='single-threaded')
    assert isinstance(computed, cupy.ndarray)