from __future__ import absolute_import, division, print_function

import multiprocessing
import traceback
import pickle
import sys

from warnings import warn

import cloudpickle

from . import config
from .compatibility import copyreg
from .local import get_async  # TODO: get better get
from .optimization import fuse, cull


def _reduce_method_descriptor(m):
    return getattr, (m.__objclass__, m.__name__)


# type(set.union) is used as a proxy to <class 'method_descriptor'>
copyreg.pickle(type(set.union), _reduce_method_descriptor)


def _dumps(x):
    return cloudpickle.dumps(x, protocol=pickle.HIGHEST_PROTOCOL)


_loads = cloudpickle.loads


def _process_get_id():
    return multiprocessing.current_process().ident


# -- Remote Exception Handling --
# By default, tracebacks can't be serialized using pickle. However, the
# `tblib` library can enable support for this. Since we don't mandate
# that tblib is installed, we do the following:
#
# - If tblib is installed, use it to serialize the traceback and reraise
#   in the scheduler process
# - Otherwise, use a ``RemoteException`` class to contain a serialized
#   version of the formatted traceback, which will then print in the
#   scheduler process.
#
# To enable testing of the ``RemoteException`` class even when tblib is
# installed, we don't wrap the class in the try block below
class RemoteException(Exception):
    """ Remote Exception

    Contains the exception and traceback from a remotely run task
    """
    def __init__(self, exception, traceback):
        self.exception = exception
        self.traceback = traceback

    def __str__(self):
        return (str(self.exception) + "\n\n"
                "Traceback\n"
                "---------\n" +
                self.traceback)

    def __dir__(self):
        return sorted(set(dir(type(self)) +
                          list(self.__dict__) +
                          dir(self.exception)))

    def __getattr__(self, key):
        try:
            return object.__getattribute__(self, key)
        except AttributeError:
            return getattr(self.exception, key)


exceptions = dict()


def remote_exception(exc, tb):
    """ Wrap an exception in a RemoteException subclass specific to its type """
    if type(exc) in exceptions:
        typ = exceptions[type(exc)]
        return typ(exc, tb)
    else:
        try:
            typ = type(exc.__class__.__name__,
                       (RemoteException, type(exc)),
                       {'exception_type': type(exc)})
            exceptions[type(exc)] = typ
            return typ(exc, tb)
        except TypeError:
            return exc


try:
    import tblib.pickling_support
    tblib.pickling_support.install()

    from dask.compatibility import reraise

    def _pack_traceback(tb):
        return tb

except ImportError:
    def _pack_traceback(tb):
        return ''.join(traceback.format_tb(tb))

    def reraise(exc, tb):
        exc = remote_exception(exc, tb)
        raise exc


def pack_exception(e, dumps):
    exc_type, exc_value, exc_traceback = sys.exc_info()
    tb = _pack_traceback(exc_traceback)
    try:
        result = dumps((e, tb))
    except BaseException as e:
        exc_type, exc_value, exc_traceback = sys.exc_info()
        tb = _pack_traceback(exc_traceback)
        result = dumps((e, tb))
    return result


_CONTEXT_UNSUPPORTED = """\
The 'multiprocessing.context' configuration option will be ignored on Python 2
and on Windows, because they each only support a single context.
"""
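
# A minimal, illustrative sketch of how the 'multiprocessing.context' option is
# intended to be used. It assumes a dask version that exposes ``dask.config.set``;
# the "spawn" value is an example, not a stated default:
#
#     >>> import dask
#     >>> with dask.config.set({"multiprocessing.context": "spawn"}):  # doctest: +SKIP
#     ...     ctx = get_context()   # multiprocessing.get_context("spawn")
#     ...                           # on Python 3, non-Windows platforms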
""" def get_context(): """ Return the current multiprocessing context.""" if sys.platform == "win32" or sys.version_info.major == 2: # Just do the default, since we can't change it: if config.get("multiprocessing.context", None) is not None: warn(_CONTEXT_UNSUPPORTED, UserWarning) return multiprocessing context_name = config.get("multiprocessing.context", None) return multiprocessing.get_context(context_name) def get(dsk, keys, num_workers=None, func_loads=None, func_dumps=None, optimize_graph=True, pool=None, **kwargs): """ Multiprocessed get function appropriate for Bags Parameters ---------- dsk : dict dask graph keys : object or list Desired results from graph num_workers : int Number of worker processes (defaults to number of cores) func_dumps : function Function to use for function serialization (defaults to cloudpickle.dumps) func_loads : function Function to use for function deserialization (defaults to cloudpickle.loads) optimize_graph : bool If True [default], `fuse` is applied to the graph before computation. """ pool = pool or config.get('pool', None) num_workers = num_workers or config.get('num_workers', None) if pool is None: context = get_context() pool = context.Pool(num_workers, initializer=initialize_worker_process) cleanup = True else: cleanup = False # Optimize Dask dsk2, dependencies = cull(dsk, keys) if optimize_graph: dsk3, dependencies = fuse(dsk2, keys, dependencies) else: dsk3 = dsk2 # We specify marshalling functions in order to catch serialization # errors and report them to the user. loads = func_loads or config.get('func_loads', None) or _loads dumps = func_dumps or config.get('func_dumps', None) or _dumps # Note former versions used a multiprocessing Manager to share # a Queue between parent and workers, but this is fragile on Windows # (issue #1652). try: # Run result = get_async(pool.apply_async, len(pool._pool), dsk3, keys, get_id=_process_get_id, dumps=dumps, loads=loads, pack_exception=pack_exception, raise_exception=reraise, **kwargs) finally: if cleanup: pool.close() return result def initialize_worker_process(): """ Initialize a worker process before running any tasks in it. """ # If Numpy is already imported, presumably its random state was # inherited from the parent => re-seed it. np = sys.modules.get('numpy') if np is not None: np.random.seed()