You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
381 lines
8.3 KiB
381 lines
8.3 KiB
6 years ago
|
import functools
|
||
|
import operator
|
||
|
import pickle
|
||
|
|
||
|
import numpy as np
|
||
|
import pytest
|
||
|
|
||
|
from dask.compatibility import PY2
|
||
|
from dask.utils import (takes_multiple_arguments, Dispatch, random_state_data,
|
||
|
memory_repr, methodcaller, M, skip_doctest,
|
||
|
SerializableLock, funcname, ndeepmap, ensure_dict,
|
||
|
extra_titles, asciitable, itemgetter, partial_by_order,
|
||
|
has_keyword)
|
||
|
from dask.utils_test import inc
|
||
|
from dask.highlevelgraph import HighLevelGraph
|
||
|
|
||
|
|
||
|
def test_takes_multiple_arguments():
    """Check ``takes_multiple_arguments`` against builtins, functions and classes."""

    # Local callables with different signatures, defined up front.
    def multi(a, b, c):
        return a, b, c

    def f():
        pass

    def vararg(*args):
        pass

    class Singular(object):
        def __init__(self, a):
            pass

    class Multi(object):
        def __init__(self, a, b):
            pass

    # Builtins.
    assert takes_multiple_arguments(map)
    assert not takes_multiple_arguments(sum)

    # Plain function and class constructors.
    assert takes_multiple_arguments(multi)
    assert not takes_multiple_arguments(Singular)
    assert takes_multiple_arguments(Multi)

    # Zero-argument function.
    assert not takes_multiple_arguments(f)

    # ``*args`` counts as multiple arguments unless ``varargs=False``.
    assert takes_multiple_arguments(vararg)
    assert not takes_multiple_arguments(vararg, varargs=False)
|
||
|
|
||
|
|
||
|
def test_dispatch():
    """Register handlers for several types on a ``Dispatch`` and call them."""

    def f(a):
        """ My Docstring """
        return a

    class Bar(object):
        pass

    dispatcher = Dispatch()
    dispatcher.register(int, lambda a: a + 1)
    dispatcher.register(float, lambda a: a - 1)
    # The tuple handler re-enters the dispatcher for every element.
    dispatcher.register(tuple, lambda a: tuple(dispatcher(i) for i in a))
    dispatcher.register(object, f)

    instance = Bar()
    assert dispatcher(1) == 2
    assert dispatcher.dispatch(int)(1) == 2
    assert dispatcher(1.0) == 0.0
    assert dispatcher(instance) == instance
    assert dispatcher((1, 2.0, instance)) == (2, 1.0, instance)

    assert dispatcher.__doc__ == f.__doc__
|
||
|
|
||
|
|
||
|
def test_dispatch_kwargs():
    """Keyword arguments are passed through to the registered handler."""

    def add(a, b=10):
        return a + b

    dispatcher = Dispatch()
    dispatcher.register(int, add)

    result = dispatcher(1, b=20)
    assert result == 21
|
||
|
|
||
|
|
||
|
def test_dispatch_variadic_on_first_argument():
    """Two-argument handlers are selected by the type of the first argument."""

    def add(a, b):
        return a + b

    def subtract(a, b):
        return a - b

    dispatcher = Dispatch()
    dispatcher.register(int, add)
    dispatcher.register(float, subtract)

    assert dispatcher(1, 2) == 3
    assert dispatcher(1.0, 2.0) == -1
|
||
|
|
||
|
|
||
|
def test_dispatch_lazy():
    """Exercise lazy registration (the recursive component of dispatch)."""
    lazy_dispatch = Dispatch()
    lazy_dispatch.register(int, lambda a: a)

    import decimal

    # Defined outside the lazy registration hook so the test can compare
    # against it directly.
    def increment(a):
        return a + 1

    @lazy_dispatch.register_lazy("decimal")
    def _register_decimal():
        import decimal
        lazy_dispatch.register(decimal.Decimal, increment)

    # This assertion must run *before* any other call on the dispatcher.
    assert lazy_dispatch.dispatch(decimal.Decimal) == increment
    assert lazy_dispatch(decimal.Decimal(1)) == decimal.Decimal(2)
    assert lazy_dispatch(1) == 1
|
||
|
|
||
|
|
||
|
def test_random_state_data():
    """Compare ``random_state_data`` output for int and RandomState seeds."""
    seed = 37
    count = 10000
    rng = np.random.RandomState(seed)

    # Seeded with an integer.
    from_int = random_state_data(count, seed)
    assert len(from_int) == count

    # Seeded with a RandomState object built from the same integer.
    from_rng = random_state_data(count, rng)
    for left, right in zip(from_int, from_rng):
        assert left.shape == (624,)
        assert (left == right).all()

    # Consistent ordering: a shorter request matches the prefix of a longer one.
    short = random_state_data(10, 1234)
    long_prefix = random_state_data(20, 1234)[:10]

    for left, right in zip(short, long_prefix):
        assert (left == right).all()
|
||
|
|
||
|
|
||
|
def test_memory_repr():
    """``memory_repr`` of successive powers of 1024 yields the expected labels."""
    expected_labels = ['1.0 bytes', '1.0 KB', '1.0 MB', '1.0 GB']
    for exponent, label in enumerate(expected_labels):
        nbytes = 1024 ** exponent
        assert memory_repr(nbytes) == label
|
||
|
|
||
|
|
||
|
def test_method_caller():
    """Check calling, identity, pickling and repr of ``methodcaller`` / ``M``."""
    values = [1, 2, 3, 3, 3]
    count_caller = methodcaller('count')

    # Calling forwards to the named method on the first argument.
    assert count_caller(values, 3) == values.count(3)

    # The same object is returned for the same method name, via every route.
    assert methodcaller('count') is count_caller
    assert M.count is count_caller
    assert pickle.loads(pickle.dumps(count_caller)) is count_caller
    assert 'count' in dir(M)

    # The method name appears in both string forms.
    assert 'count' in str(methodcaller('count'))
    assert 'count' in repr(methodcaller('count'))
|
||
|
|
||
|
|
||
|
def test_skip_doctest():
    """Check that ``skip_doctest`` appends ``+SKIP`` to doctest example lines.

    Bare ``>>>`` prompts and comment-only prompts are expected to be left
    untouched, ``None`` maps to the empty string, and an existing
    ``# doctest:`` directive is extended with ``, +SKIP``.
    """
    example = """>>> xxx
>>>
>>> # comment
>>> xxx"""

    res = skip_doctest(example)
    # NOTE(review): the expected suffix is written with two spaces before
    # ``# doctest:``; this copy of the file had runs of spaces collapsed to
    # one -- confirm against the separator ``dask.utils.skip_doctest`` emits.
    assert res == """>>> xxx  # doctest: +SKIP
>>>
>>> # comment
>>> xxx  # doctest: +SKIP"""

    assert skip_doctest(None) == ''

    # A line that already carries a doctest directive gets ``, +SKIP`` appended.
    example = """
>>> 1 + 2 # doctest: +ELLIPSES
3"""

    expected = """
>>> 1 + 2 # doctest: +ELLIPSES, +SKIP
3"""
    res = skip_doctest(example)
    assert res == expected
|
||
|
|
||
|
|
||
|
def test_extra_titles():
    """A repeated section title is expected to come back prefixed with "Extra"."""
    # Input with the "Notes" section appearing twice.
    docstring = """

Notes
-----
hello

Foo
---

Notes
-----
bar
"""

    # Second "Notes" renamed, with its underline lengthened to match.
    renamed = """

Notes
-----
hello

Foo
---

Extra Notes
-----------
bar
"""

    assert extra_titles(docstring) == renamed
|
||
|
|
||
|
|
||
|
def test_asciitable():
    """Render a two-column table with ``asciitable`` and compare the exact text.

    Every row is padded to the column width fixed by the border line
    (``+--------+--------+``: 8 characters per column, i.e. 6 characters of
    content plus one space on each side), so all lines have the same length.
    """
    res = asciitable(['fruit', 'color'],
                     [('apple', 'red'),
                      ('banana', 'yellow'),
                      ('tomato', 'red'),
                      ('pear', 'green')])
    # Cells are left-aligned and padded to the widest entry in the column
    # ('banana' / 'yellow'); unpadded rows could never line up with the border.
    assert res == ('+--------+--------+\n'
                   '| fruit  | color  |\n'
                   '+--------+--------+\n'
                   '| apple  | red    |\n'
                   '| banana | yellow |\n'
                   '| tomato | red    |\n'
                   '| pear   | green  |\n'
                   '+--------+--------+')
|
||
|
|
||
|
|
||
|
def test_SerializableLock():
    """Exercise ``SerializableLock`` acquisition across pickled copies."""
    lock_a = SerializableLock()
    lock_b = SerializableLock()

    # Plain and nested acquisition of distinct locks.
    with lock_a:
        pass

    with lock_a:
        with lock_b:
            pass

    # A held lock cannot be re-acquired without blocking.
    with lock_a:
        assert not lock_a.acquire(False)

    a_copy1 = pickle.loads(pickle.dumps(lock_a))
    a_copy2 = pickle.loads(pickle.dumps(lock_a))
    a_copy_of_copy = pickle.loads(pickle.dumps(a_copy1))
    a_family = [lock_a, a_copy1, a_copy2, a_copy_of_copy]

    # Holding any member of the "a" family blocks every member of it.
    for holder in a_family:
        for other in a_family:
            with holder:
                assert not other.acquire(False)

    b_copy = pickle.loads(pickle.dumps(lock_b))
    b_copy_of_copy = pickle.loads(pickle.dumps(b_copy))
    b_family = [lock_b, b_copy, b_copy_of_copy]

    # The two families can be nested in either order.
    for from_a in a_family:
        for from_b in b_family:
            with from_a:
                with from_b:
                    pass
            with from_b:
                with from_a:
                    pass
|
||
|
|
||
|
|
||
|
def test_SerializableLock_name_collision():
    """Locks with equal names share ``.lock``; other locks do not."""
    first_a = SerializableLock('a')
    named_b = SerializableLock('b')
    second_a = SerializableLock('a')
    unnamed = SerializableLock()

    assert first_a.lock is not named_b.lock
    assert first_a.lock is second_a.lock
    assert unnamed.lock not in (first_a.lock, named_b.lock, second_a.lock)
|
||
|
|
||
|
|
||
|
def test_SerializableLock_locked():
    """``locked()`` is true only while the lock is held."""
    lock = SerializableLock('a')

    assert not lock.locked()
    with lock:
        assert lock.locked()
    assert not lock.locked()
|
||
|
|
||
|
|
||
|
@pytest.mark.skipif(PY2, reason="no blocking= keyword in Python 2")
def test_SerializableLock_acquire_blocking():
    """``acquire`` accepts the ``blocking=`` keyword.

    The first (blocking) acquire must succeed; a second non-blocking acquire
    on the already-held lock must fail.
    """
    a = SerializableLock('a')
    assert a.acquire(blocking=True)
    try:
        assert not a.acquire(blocking=False)
    finally:
        # Always release: locks created with the same name share one
        # underlying lock (see test_SerializableLock_name_collision), so a
        # failed assertion here would otherwise leave 'a' held for every
        # later test that uses that name.
        a.release()
|
||
|
|
||
|
|
||
|
def test_funcname():
    """``funcname`` for functions, partials, method callers, lambdas and classes."""

    # The names ``foo`` and ``Foo`` are themselves the expected values below.
    def foo(a, b, c):
        pass

    class Foo(object):
        pass

    partial_foo = functools.partial(foo, a=1)

    assert funcname(foo) == 'foo'
    assert funcname(partial_foo) == 'foo'
    assert funcname(M.sum) == 'sum'
    assert funcname(lambda: 1) == 'lambda'

    assert funcname(Foo) == 'Foo'
    assert 'Foo' in funcname(Foo())
|
||
|
|
||
|
|
||
|
def test_funcname_toolz():
    """``funcname`` sees through ``toolz.curry``; skipped if toolz is missing."""
    toolz = pytest.importorskip('toolz')

    @toolz.curry
    def foo(a, b, c):
        pass

    partially_applied = foo(1)

    assert funcname(foo) == 'foo'
    assert funcname(partially_applied) == 'foo'
|
||
|
|
||
|
|
||
|
def test_funcname_multipledispatch():
    """``funcname`` on a multipledispatch function; skipped if it is missing."""
    md = pytest.importorskip('multipledispatch')

    @md.dispatch(int, int, int)
    def foo(a, b, c):
        pass

    partial_foo = functools.partial(foo, a=1)

    assert funcname(foo) == 'foo'
    assert funcname(partial_foo) == 'foo'
|
||
|
|
||
|
|
||
|
def test_ndeepmap():
    """Apply ``inc`` through ``ndeepmap`` at nesting depths 0 to 3."""
    # (depth, input, expected result)
    cases = [
        (0, 1, 2),
        (0, [1], 2),
        (1, [1, 2, 3], [2, 3, 4]),
        (2, [[1, 2], [3, 4]], [[2, 3], [4, 5]]),
        (3,
         [[[1, 2], [3, 4, 5]], [[6], []]],
         [[[2, 3], [4, 5, 6]], [[7], []]]),
    ]

    for depth, nested, expected in cases:
        assert ndeepmap(depth, inc, nested) == expected
|
||
|
|
||
|
|
||
|
def test_ensure_dict():
    """``ensure_dict`` on a plain dict, a HighLevelGraph and a dict subclass."""

    class mydict(dict):
        pass

    plain = {'x': 1}

    # A plain dict is returned as the very same object.
    assert ensure_dict(plain) is plain

    # A HighLevelGraph is converted to an equal, exact ``dict``.
    graph = HighLevelGraph.from_collections('x', plain)
    assert type(ensure_dict(graph)) is dict
    assert ensure_dict(graph) == plain

    # A dict subclass is converted to an exact ``dict`` as well.
    subclassed = mydict()
    subclassed['x'] = 1
    assert type(ensure_dict(subclassed)) is dict
    assert ensure_dict(subclassed) == plain
|
||
|
|
||
|
|
||
|
def test_itemgetter():
    """``itemgetter`` indexes its argument and survives a pickle round trip."""
    sequence = [1, 2, 3]

    getter = itemgetter(1)
    assert getter(sequence) == 2

    restored = pickle.loads(pickle.dumps(getter))
    assert restored(sequence) == 2
    assert restored.index == 1
|
||
|
|
||
|
|
||
|
def test_partial_by_order():
    """``partial_by_order`` with ``operator.add`` and one extra positional value."""
    result = partial_by_order(5, function=operator.add, other=[(1, 20)])
    assert result == 25
|
||
|
|
||
|
|
||
|
def test_has_keyword():
    """``has_keyword`` finds parameters on a function and on a partial of it."""

    def foo(a, b, c=None):
        pass

    for name in ('a', 'b', 'c'):
        assert has_keyword(foo, name)

    # With ``a`` bound through a partial, the remaining names are still found.
    bar = functools.partial(foo, a=1)
    for name in ('b', 'c'):
        assert has_keyword(bar, name)
|