You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
186 lines
6.0 KiB
186 lines
6.0 KiB
import pytest
|
|
distributed = pytest.importorskip('distributed')
|
|
|
|
from functools import partial
|
|
from operator import add
|
|
from tornado import gen
|
|
|
|
import dask
|
|
from dask import persist, delayed, compute
|
|
from dask.compatibility import get_named_args
|
|
from dask.delayed import Delayed
|
|
from dask.utils import tmpdir
|
|
from distributed.client import wait, Client
|
|
from distributed.utils_test import gen_cluster, inc, cluster, loop # noqa F401
|
|
|
|
|
|
# Newer versions of distributed's test helpers accept a
# ``should_check_state`` flag; when available, turn it off so these tests
# are not failed by distributed's internal state validation between runs.
if 'should_check_state' in get_named_args(gen_cluster):
    gen_cluster = partial(gen_cluster, should_check_state=False)
    cluster = partial(cluster, should_check_state=False)
|
|
|
|
|
|
def test_can_import_client():
    """``Client`` is re-exported under the ``dask.distributed`` namespace."""
    from dask.distributed import Client  # noqa: F401
|
|
|
|
|
|
@gen_cluster(client=True)
def test_persist(c, s, a, b):
    """Persisting a Delayed schedules its result onto the cluster workers."""
    task = delayed(inc)(1)
    [persisted] = persist(task)

    yield wait(persisted)
    # The computed value must live on at least one of the two workers.
    assert persisted.key in a.data or persisted.key in b.data

    other = delayed(inc)(10)
    # Non-dask arguments (the literal 1) pass through persist untouched.
    persisted_other, passthrough = persist(other, 1)

    yield wait(persisted_other)
    assert persisted_other.key in a.data or persisted_other.key in b.data
|
|
|
|
|
|
def test_persist_nested(loop):
    """persist() traverses nested containers, persisting each dask leaf."""
    with cluster() as (s, [w1, w2]):
        with Client(s['address'], loop=loop):
            base = delayed(1) + 5
            plus_one = base + 1
            plus_two = base + 2
            persisted = persist({'a': base, 'b': [1, 2, plus_one]},
                                (plus_two, 2), 4, [5])
            assert isinstance(persisted[0]['a'], Delayed)
            assert isinstance(persisted[0]['b'][2], Delayed)
            assert isinstance(persisted[1][0], Delayed)

            expected = ({'a': 6, 'b': [1, 2, 7]}, (8, 2), 4, [5])
            assert compute(*persisted) == expected

            # With traverse=False, containers are returned untouched and
            # only top-level dask objects are persisted.
            untraversed = persist([base, plus_one], plus_two, 4, [5],
                                  traverse=False)
            assert untraversed[0][0] is base
            assert untraversed[0][1] is plus_one
            assert untraversed[1].compute() == 8
            assert untraversed[2:] == (4, [5])
|
|
|
|
|
|
def test_futures_to_delayed_dataframe(loop):
    """``dd.from_delayed`` accepts distributed Futures as partitions."""
    pd = pytest.importorskip('pandas')
    dd = pytest.importorskip('dask.dataframe')
    frame = pd.DataFrame({'x': [1, 2, 3]})
    with cluster() as (s, [w1, w2]):
        with Client(s['address'], loop=loop) as client:
            remote = client.scatter([frame, frame])
            ddf = dd.from_delayed(remote)
            dd.utils.assert_eq(ddf.compute(), pd.concat([frame, frame], axis=0))

            # Plain scalars are not valid partition sources.
            with pytest.raises(TypeError):
                ddf = dd.from_delayed([1, 2])
|
|
|
|
|
|
def test_futures_to_delayed_bag(loop):
    """``db.from_delayed`` accepts distributed Futures as partitions."""
    db = pytest.importorskip('dask.bag')
    items = [1, 2, 3]
    with cluster() as (s, [w1, w2]):
        with Client(s['address'], loop=loop) as client:
            remote = client.scatter([items, items])
            bag = db.from_delayed(remote)
            assert list(bag) == items + items
|
|
|
|
|
|
def test_futures_to_delayed_array(loop):
    """``da.from_delayed`` accepts distributed Futures as chunk sources."""
    da = pytest.importorskip('dask.array')
    from dask.array.utils import assert_eq
    np = pytest.importorskip('numpy')
    arr = np.arange(5)
    with cluster() as (s, [w1, w2]):
        with Client(s['address'], loop=loop) as client:
            remote = client.scatter([arr, arr])
            pieces = []
            for fut in remote:
                pieces.append(
                    da.from_delayed(fut, shape=arr.shape, dtype=arr.dtype))
            stacked = da.concatenate(pieces, axis=0)
            assert_eq(stacked.compute(), np.concatenate([arr, arr], axis=0))
|
|
|
|
|
|
@gen_cluster(client=True)
def test_local_get_with_distributed_active(c, s, a, b):
    # Selecting the synchronous scheduler (via config or the ``scheduler=``
    # keyword) must bypass the active distributed client entirely: after a
    # short grace period the scheduler should have seen no tasks at all.
    with dask.config.set(scheduler='sync'):
        x = delayed(inc)(1).persist()
    yield gen.sleep(0.01)
    assert not s.tasks  # scheduler hasn't done anything

    x = delayed(inc)(2).persist(scheduler='sync')  # noqa F841
    yield gen.sleep(0.01)
    assert not s.tasks  # scheduler hasn't done anything
|
|
|
|
|
|
def test_to_hdf_distributed(loop):
    """Re-run the dataframe HDF round-trip test with a distributed client active."""
    from ..dataframe.io.tests.test_hdf import test_to_hdf

    with cluster() as (s, [w1, w2]):
        with distributed.Client(s['address'], loop=loop):
            test_to_hdf()
|
|
|
|
|
|
@pytest.mark.xfail(reason='HDF not multi-process safe')
@pytest.mark.parametrize('npartitions', [1, 4, 10])
def test_to_hdf_scheduler_distributed(npartitions, loop):
    """HDF writes through the distributed scheduler (expected to fail: HDF is
    not multi-process safe)."""
    from ..dataframe.io.tests.test_hdf import test_to_hdf_schedulers

    with cluster() as (s, [w1, w2]):
        with distributed.Client(s['address'], loop=loop):
            test_to_hdf_schedulers(None, npartitions)
|
|
|
|
|
|
@gen_cluster(client=True)
def test_serializable_groupby_agg(c, s, a, b):
    """Smoke test: a groupby-agg graph must serialize to the scheduler."""
    pd = pytest.importorskip('pandas')
    dd = pytest.importorskip('dask.dataframe')
    frame = pd.DataFrame({'x': [1, 2, 3, 4], 'y': [1, 0, 1, 0]})
    ddf = dd.from_pandas(frame, npartitions=2)

    agged = ddf.groupby('y').agg('count')

    # Only checks that computation completes without a serialization error.
    yield c.compute(agged)
|
|
|
|
|
|
def test_futures_in_graph(loop):
    """A Delayed built on top of a persisted (future-backed) value computes."""
    with cluster() as (s, [w1, w2]):
        with Client(s['address'], loop=loop) as client:
            one = delayed(1)
            two = delayed(2)
            doubled_one = delayed(add)(one, one)
            doubled_two = delayed(add)(two, two)
            combined = delayed(add)(doubled_one, doubled_two)

            persisted = client.persist(combined)
            # New graph referencing the persisted (future-backed) node.
            shifted = delayed(add)(persisted, 10)

            assert shifted.compute(scheduler='dask.distributed') == ((1 + 1) + (2 + 2)) + 10
|
|
|
|
|
|
def test_zarr_distributed_roundtrip(loop):
    """An on-disk zarr store round-trips through a distributed client."""
    da = pytest.importorskip('dask.array')
    pytest.importorskip('zarr')
    assert_eq = da.utils.assert_eq
    with cluster() as (s, [w1, w2]):
        with Client(s['address'], loop=loop):
            with tmpdir() as path:
                original = da.zeros((3, 3), chunks=(1, 1))
                original.to_zarr(path)
                restored = da.from_zarr(path)
                assert_eq(original, restored)
                # Chunking must survive the round trip as well.
                assert restored.chunks == original.chunks
|
|
|
|
|
|
def test_zarr_in_memory_distributed_err(loop):
    # Writing to a client-local, in-memory zarr store while a distributed
    # client is active should raise RuntimeError -- presumably because the
    # remote workers cannot reach the in-memory store; confirm against
    # dask.array.to_zarr's documented behavior.
    da = pytest.importorskip('dask.array')
    zarr = pytest.importorskip('zarr')
    with cluster() as (s, [a, b]):
        # NOTE(review): ``client_kwargs`` is not a documented ``Client``
        # parameter -- it looks like ``set_as_default=True`` was meant to be
        # passed directly; verify against the distributed.Client API.
        with Client(s['address'], loop=loop,
                    client_kwargs={'set_as_default': True}) as c:
            with pytest.raises(RuntimeError):
                # NOTE(review): ``c`` here shadows the client handle above,
                # re-using the name as a chunks spec -- confusing but harmless
                # since the client is only needed implicitly as default.
                c = (1, 1)
                a = da.ones((3, 3), chunks=c)
                z = zarr.zeros_like(a, chunks=c)
                a.to_zarr(z)
|
|
|
|
|
|
def test_scheduler_equals_client(loop):
    """Passing a Client instance as ``scheduler=`` routes compute through it."""
    with cluster() as (s, [w1, w2]):
        with Client(s['address'], loop=loop) as client:
            task = delayed(lambda: 1)()
            assert task.compute(scheduler=client) == 1
            # The scheduler's story for the key proves it ran the task.
            story = client.run_on_scheduler(
                lambda dask_scheduler: dask_scheduler.story(task.key))
            assert story
|