You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
114 lines
2.5 KiB
114 lines
2.5 KiB
6 years ago
|
from dask.local import get_sync
|
||
|
from dask.threaded import get as get_threaded
|
||
|
from dask.callbacks import Callback
|
||
|
from dask.utils_test import add
|
||
|
|
||
|
|
||
|
def test_start_callback():
    """A ``Callback`` subclass's ``_start`` hook runs when a graph is computed."""
    calls = []

    class StartRecorder(Callback):
        def _start(self, dsk):
            # Record that the hook fired; the graph itself is not inspected here.
            calls.append(True)

    with StartRecorder():
        get_sync({'x': 1}, 'x')

    assert calls
|
||
|
|
||
|
|
||
|
def test_start_state_callback():
    """``_start_state`` receives the graph and the scheduler's initial state."""
    calls = []

    class StateRecorder(Callback):
        def _start_state(self, dsk, state):
            calls.append(None)
            assert dsk['x'] == 1
            # Exactly one cache entry is expected at this point -- presumably
            # the literal (non-task) value stored under 'x'.
            assert len(state['cache']) == 1

    graph = {'x': 1}
    with StateRecorder():
        get_sync(graph, 'x')

    assert len(calls) > 0
|
||
|
|
||
|
|
||
|
def test_finish_always_called():
    """``_finish`` runs, with ``errored`` set, even when a task fails.

    Three failure modes are covered:

    * an ordinary exception under ``get_sync`` (``raise_on_exception=True``),
    * the same exception under ``get_threaded`` (``raise_on_exception=False``),
    * a ``KeyboardInterrupt`` raised from inside a task under ``get_sync``.

    Each scenario must actually raise the expected exception. Without that
    requirement the test would silently pass if the scheduler stopped
    propagating errors.
    """
    flag = [False]

    class MyCallback(Callback):
        def _finish(self, dsk, state, errored):
            flag[0] = True
            assert errored

    def check(get, dsk, expected):
        # Run ``dsk`` with scheduler ``get`` under MyCallback. Require that
        # ``expected`` is raised and that the ``_finish`` hook fired.
        flag[0] = False
        try:
            with MyCallback():
                get(dsk, 'x')
        except expected:
            pass
        else:
            raise AssertionError("%s was not raised" % expected.__name__)
        assert flag[0]

    dsk = {'x': (lambda: 1 / 0,)}

    # `raise_on_exception=True`
    check(get_sync, dsk, ZeroDivisionError)

    # `raise_on_exception=False`
    check(get_threaded, dsk, ZeroDivisionError)

    # KeyboardInterrupt
    def raise_keyboard():
        raise KeyboardInterrupt()

    check(get_sync, {'x': (raise_keyboard,)}, KeyboardInterrupt)
|
||
|
|
||
|
|
||
|
def test_nested_schedulers():
    """A callback entered inside a running task tracks only the nested graph.

    An outer threaded computation runs a task that enters a second callback
    and launches an inner threaded computation. Each callback must see only
    its own graph, and no callback may remain registered afterwards.
    """

    class GraphTracker(Callback):
        def _start(self, dsk):
            self.dsk = dsk

        def _pretask(self, key, dsk, state):
            # Every key reported to this callback must belong to the graph
            # that was recorded in ``_start``.
            assert key in self.dsk

    inner_cb = GraphTracker()
    inner_graph = {
        'x': (add, 1, 2),
        'y': (add, 'x', 3),
    }

    def run_inner(x):
        # The test expects no callback to be registered while a task runs,
        # even though the outer callback is entered around the computation.
        assert not Callback.active
        with inner_cb:
            inner_result = get_threaded(inner_graph, 'y')
        return inner_result + x

    outer_cb = GraphTracker()
    outer_graph = {
        'a': (run_inner, 1),
        'b': (add, 'a', 2),
    }

    with outer_cb:
        get_threaded(outer_graph, 'b')

    assert not Callback.active
    assert outer_cb.dsk == outer_graph
    assert inner_cb.dsk == inner_graph
    assert not Callback.active
|
||
|
|
||
|
|
||
|
def test_add_remove_mutates_not_replaces():
    """Entering/exiting a ``Callback`` mutates ``Callback.active`` in place.

    The registry must be non-empty only while the context manager is open.
    It must also remain the *same* container object throughout, because
    code that captured a reference to it would otherwise see stale contents.
    """
    registry = Callback.active
    assert not registry

    with Callback():
        assert Callback.active
        assert Callback.active is registry

    assert not Callback.active
    assert Callback.active is registry
|