from __future__ import absolute_import, division, print_function

import multiprocessing
import traceback
import pickle
import sys
from warnings import warn

import cloudpickle

from . import config
from .compatibility import copyreg
from .local import get_async  # TODO: get better get
from .optimization import fuse, cull


def _reduce_method_descriptor(m):
    """ Reduce a method descriptor to a picklable ``(callable, args)`` pair

    The descriptor is rebuilt on load by looking up its name on the class
    that owns it, i.e. ``getattr(owner, name)``.
    """
    owner, attr_name = m.__objclass__, m.__name__
    return getattr, (owner, attr_name)


# Register the reducer above with copyreg so that method descriptors
# (e.g. ``set.union``) can be pickled.
# type(set.union) is used as a proxy to <class 'method_descriptor'>
copyreg.pickle(type(set.union), _reduce_method_descriptor)


def _dumps(x):
    """ Serialize ``x`` with cloudpickle using the highest pickle protocol """
    protocol = pickle.HIGHEST_PROTOCOL
    return cloudpickle.dumps(x, protocol=protocol)


# Default deserialization function; the counterpart of ``_dumps``.
_loads = cloudpickle.loads


def _process_get_id():
    """ Return the identifier of the process this function runs in """
    this_process = multiprocessing.current_process()
    return this_process.ident


# -- Remote Exception Handling --
# By default, tracebacks can't be serialized using pickle. However, the
# `tblib` library can enable support for this. Since we don't mandate
# that tblib is installed, we do the following:
#
# - If tblib is installed, use it to serialize the traceback and reraise
#   in the scheduler process
# - Otherwise, use a ``RemoteException`` class to contain a serialized
#   version of the formatted traceback, which will then print in the
#   scheduler process.
#
# To enable testing of the ``RemoteException`` class even when tblib is
# installed, we don't wrap the class in the try block below
class RemoteException(Exception):
    """ Remote Exception

    Contains the exception and traceback from a remotely run task

    Parameters
    ----------
    exception : BaseException
        The exception raised by the remote task.  Attributes that are not
        found on the wrapper itself are looked up on this object.
    traceback : str
        Traceback text of the remote failure; it is appended to the message
        produced by ``__str__``.
    """
    def __init__(self, exception, traceback):
        self.exception = exception
        self.traceback = traceback

    def __str__(self):
        # Message of the wrapped exception followed by the remote traceback.
        return (str(self.exception) + "\n\n"
                "Traceback\n"
                "---------\n" +
                self.traceback)

    def __dir__(self):
        # Advertise the wrapped exception's attributes too, since
        # ``__getattr__`` forwards to them.
        return sorted(set(dir(type(self)) +
                      list(self.__dict__) +
                      dir(self.exception)))

    def __getattr__(self, key):
        # Only reached when normal attribute lookup has already failed.
        try:
            return object.__getattribute__(self, key)
        except AttributeError:
            if key == 'exception':
                # ``self.exception`` itself is missing (e.g. the instance
                # was created without running ``__init__``).  Delegating to
                # it would re-enter this method forever, so report the
                # missing attribute instead.
                raise
            return getattr(self.exception, key)


# Cache mapping an original exception type to the wrapper class that
# ``remote_exception`` builds for it, so each wrapper type is created once.
exceptions = dict()


def remote_exception(exc, tb):
    """ Wrap ``exc`` in a ``RemoteException`` subclass carrying ``tb``

    The wrapper class inherits from both ``RemoteException`` and the type of
    ``exc`` and is cached in ``exceptions`` for reuse.  If the wrapper cannot
    be built or instantiated for a new type (``TypeError``), ``exc`` is
    returned unchanged.
    """
    exc_type = type(exc)
    wrapper = exceptions.get(exc_type)
    if wrapper is not None:
        # Already built for this exception type: reuse it.
        return wrapper(exc, tb)

    try:
        wrapper = type(exc.__class__.__name__,
                       (RemoteException, exc_type),
                       {'exception_type': exc_type})
        exceptions[exc_type] = wrapper
        return wrapper(exc, tb)
    except TypeError:
        return exc


# Pick the traceback transport strategy once, at import time, depending on
# whether the imports below succeed.
try:
    import tblib.pickling_support
    # Per the notes above, tblib is what makes traceback objects picklable.
    tblib.pickling_support.install()
    # NOTE(review): presumably re-raises ``exc`` with traceback ``tb`` --
    # confirm against dask.compatibility.
    from dask.compatibility import reraise

    def _pack_traceback(tb):
        """ Return ``tb`` unchanged; it is shipped as a traceback object """
        return tb

except ImportError:
    # Fallback: ship the traceback as formatted text and raise a
    # ``RemoteException`` wrapper that carries it.
    def _pack_traceback(tb):
        """ Return the traceback ``tb`` formatted as a single string """
        return ''.join(traceback.format_tb(tb))

    def reraise(exc, tb):
        """ Raise the result of ``remote_exception(exc, tb)`` """
        exc = remote_exception(exc, tb)
        raise exc


def pack_exception(e, dumps):
    """ Serialize exception ``e`` and the current traceback with ``dumps``

    Must be called while an exception is being handled, because the
    traceback is taken from ``sys.exc_info()``.  If serializing ``e`` itself
    fails, the serialization error and its traceback are packed instead.
    """
    packed_tb = _pack_traceback(sys.exc_info()[2])
    try:
        return dumps((e, packed_tb))
    except BaseException as dump_error:
        # The original failure could not be serialized; report why instead.
        packed_tb = _pack_traceback(sys.exc_info()[2])
        return dumps((dump_error, packed_tb))


# Warning text emitted by ``get_context`` when a multiprocessing context is
# configured on a platform where it cannot be honoured.
_CONTEXT_UNSUPPORTED = """\
The 'multiprocessing.context' configuration option will be ignored on Python 2
and on Windows, because they each only support a single context.
"""


def get_context():
    """ Return the current multiprocessing context.

    On Windows and on Python 2 the ``multiprocessing`` module itself is
    returned, and a ``UserWarning`` is issued if a context was configured.
    Elsewhere the context named by the ``multiprocessing.context`` config
    value is returned.
    """
    configured = config.get("multiprocessing.context", None)
    single_context = sys.platform == "win32" or sys.version_info.major == 2
    if not single_context:
        return multiprocessing.get_context(configured)

    # Just do the default, since we can't change it:
    if configured is not None:
        warn(_CONTEXT_UNSUPPORTED, UserWarning)
    return multiprocessing


def get(dsk, keys, num_workers=None, func_loads=None, func_dumps=None,
        optimize_graph=True, pool=None, **kwargs):
    """ Multiprocessed get function appropriate for Bags

    Parameters
    ----------
    dsk : dict
        dask graph
    keys : object or list
        Desired results from graph
    num_workers : int
        Number of worker processes (defaults to number of cores)
    func_dumps : function
        Function to use for function serialization
        (defaults to cloudpickle.dumps)
    func_loads : function
        Function to use for function deserialization
        (defaults to cloudpickle.loads)
    optimize_graph : bool
        If True [default], `fuse` is applied to the graph before computation.
    pool : multiprocessing pool, optional
        Pool to run tasks in.  If omitted, the ``pool`` config value is used;
        failing that, a new pool is created for this call and closed before
        returning.  A pool supplied by the caller or the config is never
        closed here.
    **kwargs
        Passed through to ``get_async``.

    Returns
    -------
    The value returned by ``get_async`` for ``keys``.
    """
    pool = pool or config.get('pool', None)
    num_workers = num_workers or config.get('num_workers', None)
    if pool is None:
        context = get_context()
        pool = context.Pool(num_workers,
                            initializer=initialize_worker_process)
        cleanup = True
    else:
        cleanup = False

    # Everything after the pool exists runs under try/finally so that a pool
    # created here is closed even if graph optimization fails, not only when
    # the computation itself fails.
    try:
        # Optimize Dask
        dsk2, dependencies = cull(dsk, keys)
        if optimize_graph:
            dsk3, dependencies = fuse(dsk2, keys, dependencies)
        else:
            dsk3 = dsk2

        # We specify marshalling functions in order to catch serialization
        # errors and report them to the user.
        loads = func_loads or config.get('func_loads', None) or _loads
        dumps = func_dumps or config.get('func_dumps', None) or _dumps

        # Note former versions used a multiprocessing Manager to share
        # a Queue between parent and workers, but this is fragile on Windows
        # (issue #1652).

        # Run.  ``len(pool._pool)`` reads the pool's private list of worker
        # processes to obtain the worker count.
        result = get_async(pool.apply_async, len(pool._pool), dsk3, keys,
                           get_id=_process_get_id, dumps=dumps, loads=loads,
                           pack_exception=pack_exception,
                           raise_exception=reraise, **kwargs)
    finally:
        if cleanup:
            pool.close()
    return result


def initialize_worker_process():
    """
    Prepare a freshly started worker process before it runs any task.

    Numpy is only touched if it has already been imported in this process;
    it is never imported here.
    """
    numpy_module = sys.modules.get('numpy')
    if numpy_module is None:
        return
    # Numpy was imported before the worker started, so its random state was
    # presumably inherited from the parent => re-seed it.
    numpy_module.random.seed()