You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
172 lines
4.0 KiB
172 lines
4.0 KiB
import os
|
|
import sys
|
|
import signal
|
|
import threading
|
|
from multiprocessing.pool import ThreadPool
|
|
from time import time, sleep
|
|
|
|
import pytest
|
|
|
|
import dask
|
|
from dask.compatibility import PY2
|
|
from dask.threaded import get
|
|
from dask.utils_test import inc, add
|
|
|
|
|
|
def test_get():
    """Compute single keys and lists of keys from a small threaded graph."""
    # z = inc(x) -> 2, w = add(z, y) -> 4
    graph = {'x': 1, 'y': 2, 'z': (inc, 'x'), 'w': (add, 'z', 'y')}

    # A single key yields a single value ...
    single = get(graph, 'w')
    assert single == 4

    # ... while a list of keys yields a tuple in the requested order.
    several = get(graph, ['w', 'z'])
    assert several == (4, 2)
|
|
|
|
|
|
def test_nested_get():
    """A task argument may itself be a list of keys (here fed to ``sum``)."""
    graph = {
        'x': 1,
        'y': 2,
        'a': (add, 'x', 'y'),
        'b': (sum, ['x', 'y']),
    }
    expected = (3, 3)
    assert get(graph, ['a', 'b']) == expected
|
|
|
|
|
|
def test_get_without_computation():
    """Requesting a key that holds a plain literal returns it unchanged."""
    assert get({'x': 1}, 'x') == 1
|
|
|
|
|
|
def test_broken_callback():
    """An exception raised by a callback's ``start`` hook reaches the caller."""
    from dask.callbacks import Callback

    def _noop(*args, **kwargs):
        pass

    def _explode(*args, **kwargs):
        raise ValueError('my_exception')

    graph = {'x': 1}

    # The failing callback is entered first, then a well-behaved one;
    # a multi-item ``with`` enters them in the same order as nesting would.
    with Callback(start=_explode, finish=_noop), Callback(start=_noop, finish=_noop):
        with pytest.raises(ValueError, match='my_exception'):
            get(graph, 'x')
|
|
|
|
|
|
def bad(x):
    """Task body that always fails with an argument-less ``ValueError``.

    The argument *x* is accepted only so the function can appear in a task
    tuple; it is ignored.
    """
    # Raising the class instantiates it with no arguments.
    raise ValueError
|
|
|
|
|
|
def test_exceptions_rise_to_top():
    """An exception raised inside a task propagates out of ``get``."""
    graph = {'x': 1, 'y': (bad, 'x')}
    with pytest.raises(ValueError):
        get(graph, 'y')
|
|
|
|
|
|
def test_reuse_pool():
    """A pool installed via ``dask.config`` can serve repeated ``get`` calls.

    The test creates the pool, so the test must also shut it down. Otherwise
    its worker threads may outlive the test and skew thread-count checks
    elsewhere in the suite.
    """
    pool = ThreadPool()
    try:
        with dask.config.set(pool=pool):
            # Two calls in a row check that the same pool is reusable.
            assert get({'x': (inc, 1)}, 'x') == 2
            assert get({'x': (inc, 1)}, 'x') == 2
    finally:
        pool.close()
        pool.join()
|
|
|
|
|
|
@pytest.mark.skipif(PY2, reason="threading API changed")
def test_pool_kwarg():
    """Tasks run on the threads of a pool passed explicitly via ``pool=``."""
    def thread_id():
        # Sleep briefly so the pool's workers are all busy at once.
        sleep(0.01)
        return threading.get_ident()

    n_tasks = 30
    graph = {('x', i): (thread_id,) for i in range(n_tasks)}
    # 'x' counts the distinct thread ids that executed the leaf tasks.
    graph['x'] = (len, (set, [('x', i) for i in range(n_tasks)]))

    with ThreadPool(3) as pool:
        assert get(graph, 'x', pool=pool) == 3
|
|
|
|
|
|
def test_threaded_within_thread():
    """``get`` works when called repeatedly from non-main threads.

    Each iteration runs ``get`` in a fresh daemon thread and checks the value
    it produced. Afterwards the live-thread count must fall back to within 10
    of the starting count, and the test fails if that takes longer than 5 s.
    """
    results = []

    def f(i):
        results.append(get({'x': (lambda: i,)}, 'x', num_workers=2))

    before = threading.active_count()

    for i in range(20):
        # Pass a different value on every iteration so a stale or shared
        # result from an earlier call would be detected.
        t = threading.Thread(target=f, args=(i,))
        t.daemon = True
        t.start()
        t.join()
        assert results == [i]
        del results[:]

    # Wait, with a bounded timeout, for most worker threads to exit.
    deadline = time() + 5
    while threading.active_count() > before + 10:
        assert time() < deadline, "worker threads did not shut down"
        sleep(0.01)
|
|
|
|
|
|
def test_dont_spawn_too_many_threads():
    """Repeated ``get`` calls with ``num_workers=4`` do not accumulate threads.

    After 20 runs the live-thread count may grow by at most 8.
    """
    before = threading.active_count()

    # Bind ``i`` as a default argument. A bare ``lambda: i`` closes over the
    # comprehension variable, so every task would return the final value (9).
    dsk = {('x', i): (lambda i=i: i,) for i in range(10)}
    dsk['x'] = (sum, list(dsk))

    for _ in range(20):
        # Also check the computed value, so the graph must actually run correctly.
        assert get(dsk, 'x', num_workers=4) == sum(range(10))

    after = threading.active_count()

    assert after <= before + 8
|
|
|
|
|
|
def test_thread_safety():
    """Twenty threads may call ``get`` on the same graph concurrently."""
    def constant_one(x):
        return 1

    graph = {'x': (sleep, 0.05), 'y': (constant_one, 'x')}
    outputs = []

    def worker():
        outputs.append(get(graph, 'y'))

    threads = [threading.Thread(target=worker) for _ in range(20)]
    for thread in threads:
        thread.daemon = True
        thread.start()

    for thread in threads:
        thread.join()

    assert outputs == [1] * 20
|
|
|
|
|
|
@pytest.mark.xfail('xdist' in sys.modules,
                   reason=("This test fails intermittently when using "
                           "pytest-xdist (maybe)"))
def test_interrupt():
    """A KeyboardInterrupt aimed at the main thread aborts a running ``get``.

    A timer fires an interrupt 0.5 s into a graph whose 20 tasks each sleep
    for 5 s. ``get`` must raise ``KeyboardInterrupt`` and return within 4 s.
    """
    # Python 2 and Windows (Python 2 & 3) implement `queue.get` using polling,
    # which means we can set an exception to interrupt the call to `get`.
    # Python 3 on other platforms requires sending SIGINT to the main thread.
    if PY2:
        from thread import interrupt_main
    elif os.name == 'nt':
        from _thread import interrupt_main
    else:
        main_thread = threading.get_ident()

        def interrupt_main():
            signal.pthread_kill(main_thread, signal.SIGINT)

    def long_task():
        sleep(5)

    dsk = {('x', i): (long_task,) for i in range(20)}
    dsk['x'] = (len, list(dsk.keys()))

    interrupter = threading.Timer(0.5, interrupt_main)
    interrupter.start()
    start = time()
    try:
        # Any other exception propagates with its real traceback, and
        # returning normally fails the test ("DID NOT RAISE").
        with pytest.raises(KeyboardInterrupt):
            get(dsk, 'x')
    finally:
        # If ``get`` failed before the timer fired, cancel it so a stray
        # interrupt cannot land in whatever test runs next.
        interrupter.cancel()
    stop = time()

    assert stop - start <= 4, "Failed to interrupt"
|