140 lines
3.9 KiB

6 years ago
from __future__ import absolute_import, division, print_function
from contextlib import contextmanager
__all__ = ['Callback', 'add_callbacks']
class Callback(object):
""" Base class for using the callback mechanism
Create a callback with functions of the following signatures:
>>> def start(dsk):
... pass
>>> def start_state(dsk, state):
... pass
>>> def pretask(key, dsk, state):
... pass
>>> def posttask(key, result, dsk, state, worker_id):
... pass
>>> def finish(dsk, state, failed):
... pass
You may then construct a callback object with any number of them
>>> cb = Callback(pretask=pretask, finish=finish) # doctest: +SKIP
And use it either as a context manager over a compute/get call
>>> with cb: # doctest: +SKIP
... x.compute() # doctest: +SKIP
Or globally with the ``register`` method
>>> cb.register() # doctest: +SKIP
>>> cb.unregister() # doctest: +SKIP
Alternatively subclass the ``Callback`` class with your own methods.
>>> class PrintKeys(Callback):
... def _pretask(self, key, dask, state):
... print("Computing: {0}!".format(repr(key)))
>>> with PrintKeys(): # doctest: +SKIP
... x.compute() # doctest: +SKIP
"""
active = set()
def __init__(self, start=None, start_state=None, pretask=None, posttask=None, finish=None):
if start:
self._start = start
if start_state:
self._start_state = start_state
if pretask:
self._pretask = pretask
if posttask:
self._posttask = posttask
if finish:
self._finish = finish
@property
def _callback(self):
fields = ['_start', '_start_state', '_pretask', '_posttask', '_finish']
return tuple(getattr(self, i, None) for i in fields)
def __enter__(self):
self._cm = add_callbacks(self)
self._cm.__enter__()
return self
def __exit__(self, *args):
self._cm.__exit__(*args)
def register(self):
Callback.active.add(self._callback)
def unregister(self):
Callback.active.remove(self._callback)
def unpack_callbacks(cbs):
"""Take an iterable of callbacks, return a list of each callback."""
if cbs:
return [[i for i in f if i] for f in zip(*cbs)]
else:
return [(), (), (), (), ()]
@contextmanager
def local_callbacks(callbacks=None):
"""Allows callbacks to work with nested schedulers.
Callbacks will only be used by the first started scheduler they encounter.
This means that only the outermost scheduler will use global callbacks."""
global_callbacks = callbacks is None
if global_callbacks:
callbacks, Callback.active = Callback.active, set()
try:
yield callbacks or ()
finally:
if global_callbacks:
Callback.active = callbacks
def normalize_callback(cb):
"""Normalizes a callback to a tuple"""
if isinstance(cb, Callback):
return cb._callback
elif isinstance(cb, tuple):
return cb
else:
raise TypeError("Callbacks must be either `Callback` or `tuple`")
class add_callbacks(object):
"""Context manager for callbacks.
Takes several callbacks and applies them only in the enclosed context.
Callbacks can either be represented as a ``Callback`` object, or as a tuple
of length 4.
Examples
--------
>>> def pretask(key, dsk, state):
... print("Now running {0}").format(key)
>>> callbacks = (None, pretask, None, None)
>>> with add_callbacks(callbacks): # doctest: +SKIP
... res.compute()
"""
def __init__(self, *callbacks):
self.callbacks = [normalize_callback(c) for c in callbacks]
Callback.active.update(self.callbacks)
def __enter__(self):
return
def __exit__(self, type, value, traceback):
for c in self.callbacks:
Callback.active.discard(c)