You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
241 lines
6.4 KiB
241 lines
6.4 KiB
6 years ago
|
from collections import namedtuple
|
||
|
|
||
|
import pytest
|
||
|
import pickle
|
||
|
|
||
|
from dask.utils_test import GetFunctionTestMixin, inc, add
|
||
|
from dask import core
|
||
|
from dask.core import (istask, get_dependencies, get_deps, flatten, subs,
|
||
|
preorder_traversal, literal, quote, has_tasks)
|
||
|
|
||
|
|
||
|
def contains(a, b):
    """Return True when every key/value pair of ``b`` is also present in ``a``.

    A key that is missing from ``a`` never matches, even when the value
    expected by ``b`` is ``None``.

    >>> contains({'x': 1, 'y': 2}, {'x': 1})
    True
    >>> contains({'x': 1, 'y': 2}, {'z': 3})
    False
    >>> contains({'x': None}, {'y': None})
    False
    """
    # Test membership explicitly: ``a.get(k)`` yields None for a missing key,
    # which would wrongly compare equal to an expected value of None.
    return all(k in a and a[k] == v for k, v in b.items())
|
||
|
|
||
|
|
||
|
def test_istask():
    """Check istask on a task tuple, a scalar, a data tuple and a namedtuple."""
    assert istask((inc, 1))

    # Neither a bare scalar nor a tuple of plain data counts as a task.
    for non_task in (1, (1, 2)):
        assert not istask(non_task)

    # A namedtuple instance is rejected even though its first field is
    # callable.
    record_type = namedtuple('f', ['x', 'y'])
    record = record_type(sum, 2)
    assert not istask(record)
|
||
|
|
||
|
|
||
|
def test_has_tasks():
    """Check has_tasks on literal lists, key references and embedded tasks."""
    graph = {
        'a': [1, 2, 3],
        'b': 'a',
        'c': [1, (inc, 1)],
        'd': [(sum, 'a')],
        'e': ['a', 'b'],
        'f': [['a', 'b'], 2, 3],
    }

    # 'a' is a list of plain numbers: the only entry expected to be False.
    assert not has_tasks(graph, graph['a'])

    # Every other entry contains a task or refers to another key of the graph.
    for key in ('b', 'c', 'd', 'e', 'f'):
        assert has_tasks(graph, graph[key])
|
||
|
|
||
|
|
||
|
def test_preorder_traversal():
    """Check the node order produced by preorder_traversal for three tasks."""
    flat = (add, 1, 2)
    assert list(preorder_traversal(flat)) == [add, 1, 2]

    nested = (add, (add, 1, 2), (add, 3, 4))
    expected_nested = [add, add, 1, 2, add, 3, 4]
    assert list(preorder_traversal(nested)) == expected_nested

    # A list argument is expected to appear as the ``list`` type itself,
    # followed by the list's items.
    with_list = (add, (sum, [1, 2]), 3)
    expected_with_list = [add, sum, list, 1, 2, 3]
    assert list(preorder_traversal(with_list)) == expected_with_list
|
||
|
|
||
|
|
||
|
class TestGet(GetFunctionTestMixin):
    """Apply the tests inherited from GetFunctionTestMixin to ``core.get``."""

    # Wrapped in staticmethod so that accessing ``self.get`` does not bind
    # the test instance as the first argument.
    get = staticmethod(core.get)
|
||
|
|
||
|
|
||
|
def test_GetFunctionTestMixin_class():
    """The mixin's test_get fails for a bogus ``get`` and passes for core.get."""

    def always_one(x, y):
        # Stand-in scheduler that ignores its inputs.
        return 1

    class TestCustomGetFail(GetFunctionTestMixin):
        get = staticmethod(always_one)

    failing = TestCustomGetFail()
    with pytest.raises(AssertionError):
        failing.test_get()

    class TestCustomGetPass(GetFunctionTestMixin):
        get = staticmethod(core.get)

    passing = TestCustomGetPass()
    passing.test_get()
|
||
|
|
||
|
|
||
|
def test_get_dependencies_nested():
    """Keys buried in nested lists inside a task are reported as dependencies."""
    graph = {
        'x': 1,
        'y': 2,
        'z': (add, (inc, [['x']]), 'y'),
    }

    as_set = get_dependencies(graph, 'z')
    assert as_set == {'x', 'y'}

    # The list form carries no ordering guarantee here, so sort before
    # comparing.
    as_list = get_dependencies(graph, 'z', as_list=True)
    assert sorted(as_list) == ['x', 'y']
|
||
|
|
||
|
|
||
|
def test_get_dependencies_empty():
    """A task with no arguments has no dependencies in either return form."""
    graph = {'x': (inc,)}

    as_set = get_dependencies(graph, 'x')
    assert as_set == set()

    as_list = get_dependencies(graph, 'x', as_list=True)
    assert as_list == []
|
||
|
|
||
|
|
||
|
def test_get_dependencies_list():
    """Dependencies are found when the graph value is a (nested) list."""
    graph = {
        'x': 1,
        'y': 2,
        'z': ['x', [(inc, 'y')]],
    }

    as_set = get_dependencies(graph, 'z')
    assert as_set == {'x', 'y'}

    as_list = get_dependencies(graph, 'z', as_list=True)
    assert sorted(as_list) == ['x', 'y']
|
||
|
|
||
|
|
||
|
def test_get_dependencies_task():
    """Dependencies can be computed for a task passed via ``task=``."""
    graph = {
        'x': 1,
        'y': 2,
        'z': ['x', [(inc, 'y')]],
    }

    as_set = get_dependencies(graph, task=(inc, 'x'))
    assert as_set == {'x'}

    as_list = get_dependencies(graph, task=(inc, 'x'), as_list=True)
    assert as_list == ['x']
|
||
|
|
||
|
|
||
|
def test_get_dependencies_nothing():
    """get_dependencies raises ValueError when given only a graph."""
    empty_graph = {}
    with pytest.raises(ValueError):
        get_dependencies(empty_graph)
|
||
|
|
||
|
|
||
|
def test_get_dependencies_many():
    """get_dependencies accepts a list of graph values through ``task=``."""
    graph = {
        'a': [1, 2, 3],
        'b': 'a',
        'c': [1, (inc, 1)],
        'd': [(sum, 'c')],
        'e': ['a', 'b', 'zzz'],
        'f': [['a', 'b'], 2, 3],
    }

    # Combine the values stored under 'd' and 'f' into one query.
    selected = [graph['d'], graph['f']]

    found = get_dependencies(graph, task=selected)
    assert found == {'a', 'b', 'c'}
    found = get_dependencies(graph, task=selected, as_list=True)
    assert sorted(found) == ['a', 'b', 'c']

    # An empty task list gives an empty result in both return forms.
    found = get_dependencies(graph, task=[])
    assert found == set()
    found = get_dependencies(graph, task=[], as_list=True)
    assert found == []
|
||
|
|
||
|
|
||
|
def test_get_deps():
    """
    >>> dsk = {'a': 1, 'b': (inc, 'a'), 'c': (inc, 'b')}
    >>> dependencies, dependents = get_deps(dsk)
    >>> dependencies
    {'a': set(), 'b': {'a'}, 'c': {'b'}}
    >>> dependents
    {'a': {'b'}, 'b': {'c'}, 'c': set()}
    """
    graph = {
        'a': [1, 2, 3],
        'b': 'a',
        'c': [1, (inc, 1)],
        'd': [(sum, 'c')],
        'e': ['b', 'zzz', 'b'],
        'f': [['a', 'b'], 2, 3],
    }

    # 'zzz' is not a key of the graph and 'b' is listed twice under 'e';
    # the expected dependencies of 'e' are therefore just {'b'}.
    expected_dependencies = {
        'a': set(),
        'b': {'a'},
        'c': set(),
        'd': {'c'},
        'e': {'b'},
        'f': {'a', 'b'},
    }
    expected_dependents = {
        'a': {'b', 'f'},
        'b': {'e', 'f'},
        'c': {'d'},
        'd': set(),
        'e': set(),
        'f': set(),
    }

    dependencies, dependents = get_deps(graph)
    assert dependencies == expected_dependencies
    assert dependents == expected_dependents
|
||
|
|
||
|
|
||
|
def test_flatten():
    """flatten gives nothing for an empty tuple and a string as one item."""
    from_empty = list(flatten(()))
    assert from_empty == []

    # A bare string comes back whole rather than split into characters.
    from_string = list(flatten('foo'))
    assert from_string == ['foo']
|
||
|
|
||
|
|
||
|
def test_subs():
    """subs replaces a key inside list arguments, including nested lists."""
    shallow = (sum, [1, 'x'])
    assert subs(shallow, 'x', 2) == (sum, [1, 2])

    nested = (sum, [1, ['x']])
    assert subs(nested, 'x', 2) == (sum, [1, [2]])
|
||
|
|
||
|
|
||
|
class MutateOnEq(object):
    """Helper object that counts how many times it is compared with ``==``.

    Every comparison reports "not equal".
    """

    # Class-level starting value; the first comparison on an instance creates
    # an instance attribute that shadows it.
    hit_eq = 0

    def __eq__(self, other):
        """Record the comparison and return False."""
        self.hit_eq = self.hit_eq + 1
        return False
|
||
|
|
||
|
|
||
|
def test_subs_no_key_data_eq():
    """subs must not invoke ``==`` between the key and unrelated data values."""
    # Numpy throws a deprecation warning on bool(array == scalar), which
    # pollutes the terminal. This test checks that `subs` never tries to
    # compare keys (scalars) with values (which could be arrays).
    probe = MutateOnEq()

    # The probe used directly as the task.
    subs(probe, 'x', 1)
    assert probe.hit_eq == 0

    # The probe used as an argument of a task.
    task = (add, probe, 'x')
    subs(task, 'x', 1)
    assert probe.hit_eq == 0
|
||
|
|
||
|
|
||
|
def test_subs_with_unfriendly_eq():
    """subs copes with values whose ``==`` is awkward or raises.

    Two independent checks are made: one with a numpy array as a task
    argument (only when numpy is installed) and one with an object whose
    ``__eq__`` always raises.
    """
    try:
        import numpy as np
    except ImportError:
        # numpy is optional; only the array check below depends on it.
        np = None

    if np is not None:
        task = (np.sum, np.array([1, 2]))
        assert (subs(task, (4, 5), 1) == task) is True

    # This part needs no optional dependency, so it must run even when numpy
    # is missing (an early ``return`` above would silently skip it).
    class MyException(Exception):
        pass

    class F():
        def __eq__(self, other):
            raise MyException()

    task = F()
    assert subs(task, 1, 2) is task
|
||
|
|
||
|
|
||
|
def test_subs_with_surprisingly_friendly_eq():
    """subs returns a pandas DataFrame unchanged when it is not the key.

    The test is reported as skipped, rather than silently passing, when
    pandas is not installed.
    """
    pd = pytest.importorskip("pandas")

    df = pd.DataFrame()
    assert subs(df, 'x', 1) is df
|
||
|
|
||
|
|
||
|
def test_subs_unexpected_hashable_key():
    """subs substitutes a key that is an arbitrary hashable object."""

    class UnexpectedButHashable:
        """Hashable object that equals any other instance of its class."""

        def __init__(self):
            self.name = "a"

        def __hash__(self):
            return hash(self.name)

        def __eq__(self, other):
            return isinstance(other, UnexpectedButHashable)

    # Two distinct instances: one inside the task, one used as the key.
    task = (id, UnexpectedButHashable())
    key = UnexpectedButHashable()
    assert subs(task, key, 1) == (id, 1)
|
||
|
|
||
|
|
||
|
def test_quote():
    """Values wrapped with quote come back from core.get unchanged."""
    samples = [
        [1, 2, 3],
        (add, 1, 2),
        [1, [2, 3]],
        (add, 1, (add, 2, 3)),
    ]

    for value in samples:
        graph = {'x': quote(value)}
        # The result must equal the original value, task-shaped tuples
        # included.
        assert core.get(graph, 'x') == value
|
||
|
|
||
|
|
||
|
def test_literal_serializable():
    """A literal keeps its ``data`` through a pickle round trip."""
    wrapped = literal((add, 1, 2))
    restored = pickle.loads(pickle.dumps(wrapped))
    assert restored.data == (add, 1, 2)
|