You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ORPA-pyOpenRPA/Resources/WPy64-3720/python-3.7.2.amd64/Lib/site-packages/dask/tests/test_threaded.py

172 lines
4.0 KiB

import os
import sys
import signal
import threading
from multiprocessing.pool import ThreadPool
from time import time, sleep
import pytest
import dask
from dask.compatibility import PY2
from dask.threaded import get
from dask.utils_test import inc, add
def test_get():
    """The threaded scheduler resolves both a single key and a list of keys."""
    graph = {
        'x': 1,
        'y': 2,
        'z': (inc, 'x'),
        'w': (add, 'z', 'y'),
    }
    single = get(graph, 'w')
    assert single == 4
    several = get(graph, ['w', 'z'])
    # A list of keys yields a tuple of results in the same order.
    assert several == (4, 2)
def test_nested_get():
    """Keys nested inside a list argument are resolved before the task runs."""
    graph = {'x': 1, 'y': 2}
    graph['a'] = (add, 'x', 'y')
    # `sum` receives the list ['x', 'y'] with each key replaced by its value.
    graph['b'] = (sum, ['x', 'y'])
    assert get(graph, ['a', 'b']) == (3, 3)
def test_get_without_computation():
    """A key that maps to a plain literal is returned unchanged."""
    result = get({'x': 1}, 'x')
    assert result == 1
def test_broken_callback():
    """An exception raised by a callback's ``start`` hook propagates out of get."""
    from dask.callbacks import Callback

    def _noop(*args, **kwargs):
        pass

    def _explode(*args, **kwargs):
        raise ValueError('my_exception')

    graph = {'x': 1}
    # One failing callback and one well-behaved callback are active together.
    with Callback(start=_explode, finish=_noop), \
            Callback(start=_noop, finish=_noop):
        with pytest.raises(ValueError, match='my_exception'):
            get(graph, 'x')
def bad(x):
    """Task body that ignores its argument and always raises ``ValueError``."""
    raise ValueError
def test_exceptions_rise_to_top():
    """An error raised inside a task propagates out of ``get`` to the caller."""
    graph = {'x': 1, 'y': (bad, 'x')}
    with pytest.raises(ValueError):
        get(graph, 'y')
def test_reuse_pool():
    """Two get calls succeed while a ThreadPool is configured via dask.config.

    The pool is explicitly closed and joined afterwards so its worker threads
    do not outlive the test. The original never shut the pool down. This uses
    try/finally rather than ``with ThreadPool()`` because the file still
    supports Python 2, where Pool is not a context manager.
    """
    pool = ThreadPool()
    try:
        with dask.config.set(pool=pool):
            assert get({'x': (inc, 1)}, 'x') == 2
            assert get({'x': (inc, 1)}, 'x') == 2
    finally:
        pool.close()
        pool.join()
@pytest.mark.skipif(PY2, reason="threading API changed")
def test_pool_kwarg():
    """With ``pool=ThreadPool(3)``, tasks run on exactly three distinct threads."""
    def thread_id():
        # The short sleep keeps tasks overlapping so all pool threads get used.
        sleep(0.01)
        return threading.get_ident()

    graph = {('x', n): (thread_id,) for n in range(30)}
    task_keys = list(graph)
    # Count the distinct thread ids reported by the 30 tasks.
    graph['x'] = (len, (set, task_keys))
    with ThreadPool(3) as pool:
        assert get(graph, 'x', pool=pool) == 3
def test_threaded_within_thread():
    """get works from non-main threads, and the thread count later drops back.

    get is called 20 times, each from a fresh daemon thread. Afterwards the
    test waits up to 5 seconds for the active thread count to fall to at
    most 10 above the starting count.
    """
    results = []

    def run_in_thread(value):
        results.append(get({'x': (lambda: value,)}, 'x', num_workers=2))

    baseline = threading.active_count()

    for _ in range(20):
        worker = threading.Thread(target=run_in_thread, args=(1,))
        worker.daemon = True
        worker.start()
        worker.join()
        assert results == [1]
        results[:] = []

    # Give most of the scheduler's worker threads a chance to exit.
    deadline = time() + 5
    while threading.active_count() > baseline + 10:
        sleep(0.01)
        assert time() < deadline
def test_dont_spawn_too_many_threads():
    """Twenty get calls with num_workers=4 leave at most 8 extra live threads.

    Each task lambda binds its index as a default argument. A bare
    ``lambda: i`` is late-binding, so every task would return 9. Each result
    is also checked, so the test verifies the graph and not just thread counts.
    """
    before = threading.active_count()

    dsk = {('x', i): (lambda i=i: i,) for i in range(10)}
    # Sum over the ten ('x', i) keys. This list is built before 'x' is added.
    dsk['x'] = (sum, list(dsk))
    expected = sum(range(10))
    for _ in range(20):
        assert get(dsk, 'x', num_workers=4) == expected

    after = threading.active_count()

    assert after <= before + 8
def test_thread_safety():
    """Twenty threads calling get on one shared graph all receive ``1``."""
    def one(_):
        return 1

    dsk = {'x': (sleep, 0.05), 'y': (one, 'x')}
    results = []

    def compute():
        results.append(get(dsk, 'y'))

    threads = [threading.Thread(target=compute) for _ in range(20)]
    for thread in threads:
        thread.daemon = True
        thread.start()
    for thread in threads:
        thread.join()

    assert results == [1] * 20
@pytest.mark.xfail('xdist' in sys.modules,
                   reason=("This test fails intermittently when using "
                           "pytest-xdist (maybe)"))
def test_interrupt():
    """An interrupt sent to the main thread after 0.5s aborts get within 4s.

    The graph consists of 20 tasks that each sleep 5 seconds, so finishing
    in under 4 seconds means the interrupt actually stopped the computation.
    The timer is always cancelled and joined so it cannot fire later and
    interrupt an unrelated test.
    """
    # Python 2 and windows 2 & 3 both implement `queue.get` using polling,
    # which means we can set an exception to interrupt the call to `get`.
    # Python 3 on other platforms requires sending SIGINT to the main thread.
    if PY2:
        from thread import interrupt_main
    elif os.name == 'nt':
        from _thread import interrupt_main
    else:
        main_thread = threading.get_ident()

        def interrupt_main():
            signal.pthread_kill(main_thread, signal.SIGINT)

    def long_task():
        sleep(5)

    dsk = {('x', i): (long_task,) for i in range(20)}
    dsk['x'] = (len, list(dsk.keys()))

    interrupter = threading.Timer(0.5, interrupt_main)
    interrupter.start()
    start = time()
    try:
        get(dsk, 'x')
    except KeyboardInterrupt:
        pass
    except Exception:
        # The original exception is kept as context in the failure report.
        pytest.fail("Failed to interrupt")
    finally:
        # No-op if the timer already fired. Otherwise it prevents a stray
        # interrupt from landing in a later test.
        interrupter.cancel()
        interrupter.join()
    stop = time()

    if stop - start > 4:
        pytest.fail("Failed to interrupt")