"""Pickle related utilities. Perhaps this should be called 'can'.

"Canning" wraps objects that plain pickle handles poorly (functions, closure
cells, interactively defined classes, numpy arrays, byte buffers) in
``CannedObject`` wrappers that *can* be pickled; "uncanning" rebuilds the
original objects from those wrappers.
"""

# Copyright (c) IPython Development Team.
# Distributed under the terms of the Modified BSD License.
import typing
import warnings

warnings.warn(
    "ipykernel.pickleutil is deprecated. It has moved to ipyparallel.",
    DeprecationWarning,
    stacklevel=2,
)

import copy
import pickle
import sys
from types import FunctionType

# This registers a hook when it's imported
from ipyparallel.serialize import codeutil  # noqa F401
from traitlets.log import get_logger
from traitlets.utils.importstring import import_item

# Aliases kept as module-level names; other code refers to them by these names.
buffer = memoryview
class_type = type

PICKLE_PROTOCOL = pickle.DEFAULT_PROTOCOL


def _get_cell_type(a=None):
    """Return the type of a closure cell.

    The type of a closure cell doesn't seem to be importable,
    so just create a closure and inspect one of its cells.
    """

    def inner():
        return a

    return type(inner.__closure__[0])  # type:ignore[index]


cell_type = _get_cell_type()

# -------------------------------------------------------------------------------
# Functions
# -------------------------------------------------------------------------------


def interactive(f):
    """decorator for making functions appear as interactively defined.

    This results in the function being linked to the user_ns as globals()
    instead of the module globals().

    Parameters
    ----------
    f : callable
        The function to re-home. Only plain ``FunctionType`` objects are
        rebuilt; any other object just has its ``__module__`` overwritten.

    Returns
    -------
    The (possibly rebuilt) callable, with ``__module__`` set to ``"__main__"``.
    """
    # build new FunctionType, so it can have the right globals
    # interactive functions never have closures, that's kind of the point
    if isinstance(f, FunctionType):
        mainmod = __import__("__main__")
        f = FunctionType(
            f.__code__,
            mainmod.__dict__,
            f.__name__,
            f.__defaults__,
        )
    # associate with __main__ for uncanning
    f.__module__ = "__main__"
    return f


def use_dill():
    """use dill to expand serialization support

    adds support for object methods and closures to serialization.

    Rebinds this module's ``pickle`` name (and ``ipykernel.serialize.pickle``
    when that module is importable) to ``dill``, and removes the special
    function canner from ``can_map``.
    """
    # import dill causes most of the magic
    import dill

    # dill doesn't work with cPickle,
    # tell the two relevant modules to use plain pickle

    global pickle
    pickle = dill

    try:
        from ipykernel import serialize
    except ImportError:
        # best effort: nothing to patch if serialize is unavailable
        pass
    else:
        serialize.pickle = dill  # type:ignore[attr-defined]

    # disable special function handling, let dill take care of it
    can_map.pop(FunctionType, None)


def use_cloudpickle():
    """use cloudpickle to expand serialization support

    adds support for object methods and closures to serialization.

    Rebinds this module's ``pickle`` name (and ``ipykernel.serialize.pickle``
    when that module is importable) to ``cloudpickle``, and removes the
    special function canner from ``can_map``.
    """
    import cloudpickle

    global pickle
    pickle = cloudpickle

    try:
        from ipykernel import serialize
    except ImportError:
        # best effort: nothing to patch if serialize is unavailable
        pass
    else:
        serialize.pickle = cloudpickle  # type:ignore[attr-defined]

    # disable special function handling, let cloudpickle take care of it
    can_map.pop(FunctionType, None)


# -------------------------------------------------------------------------------
# Classes
# -------------------------------------------------------------------------------


class CannedObject:
    """Base wrapper holding a shallow copy of an object prepared for pickling."""

    def __init__(self, obj, keys=None, hook=None):
        """can an object for safe pickling

        Parameters
        ----------
        obj
            The object to be canned
        keys : list (optional)
            list of attribute names that will be explicitly canned / uncanned
        hook : callable (optional)
            An optional extra callable,
            which can do additional processing of the uncanned object.

        Notes
        -----
        large data may be offloaded into the buffers list,
        used for zero-copy transfers.
        """
        self.keys = keys or []
        self.obj = copy.copy(obj)
        self.hook = can(hook)
        # Iterate the normalised list: the raw ``keys`` argument defaults to
        # None, which is not iterable.
        for key in self.keys:
            setattr(self.obj, key, can(getattr(obj, key)))

        self.buffers = []

    def get_object(self, g=None):
        """Uncan the listed attributes in place and return the wrapped object.

        Parameters
        ----------
        g : dict (optional)
            Namespace passed through to ``uncan``; defaults to a new dict.
        """
        if g is None:
            g = {}
        obj = self.obj
        for key in self.keys:
            setattr(obj, key, uncan(getattr(obj, key), g))

        if self.hook:
            self.hook = uncan(self.hook, g)
            self.hook(obj, g)
        return self.obj


class Reference(CannedObject):
    """object for wrapping a remote reference by name."""

    def __init__(self, name):
        if not isinstance(name, str):
            raise TypeError("illegal name: %r" % name)
        self.name = name
        self.buffers = []

    def __repr__(self):
        # An empty format string here would raise TypeError on every call
        # ("not all arguments converted"), so the name is embedded explicitly.
        return "<Reference: %r>" % self.name

    def get_object(self, g=None):
        """Resolve the reference by evaluating its name in namespace ``g``."""
        if g is None:
            g = {}

        # SECURITY NOTE(review): ``eval`` executes arbitrary expressions; the
        # name must only ever come from a trusted peer.
        return eval(self.name, g)


class CannedCell(CannedObject):
    """Can a closure cell"""

    def __init__(self, cell):
        self.cell_contents = can(cell.cell_contents)

    def get_object(self, g=None):
        """Return a new closure cell holding the uncanned contents."""
        cell_contents = uncan(self.cell_contents, g)

        # A cell can't be constructed directly here; build a closure over the
        # value and hand back its cell.
        def inner():
            return cell_contents

        return inner.__closure__[0]  # type:ignore[index]


class CannedFunction(CannedObject):
    """Can a plain function: code object, defaults, closure, module and name."""

    def __init__(self, f):
        self._check_type(f)
        self.code = f.__code__
        self.defaults: typing.Optional[typing.List[typing.Any]]
        if f.__defaults__:
            self.defaults = [can(fd) for fd in f.__defaults__]
        else:
            self.defaults = None

        self.closure: typing.Any
        closure = f.__closure__
        if closure:
            self.closure = tuple(can(cell) for cell in closure)
        else:
            self.closure = None

        self.module = f.__module__ or "__main__"
        self.__name__ = f.__name__
        self.buffers = []

    def _check_type(self, obj):
        assert isinstance(obj, FunctionType), "Not a function type"

    def get_object(self, g=None):
        """Rebuild the function.

        Parameters
        ----------
        g : dict (optional)
            Globals for the rebuilt function. Ignored when the function's
            module name does not start with ``"__"``: the module is then
            imported and its ``__dict__`` is used instead.
        """
        # try to load function back into its module:
        if not self.module.startswith("__"):
            __import__(self.module)
            g = sys.modules[self.module].__dict__

        if g is None:
            g = {}
        if self.defaults:
            defaults = tuple(uncan(cfd, g) for cfd in self.defaults)
        else:
            defaults = None
        if self.closure:
            closure = tuple(uncan(cell, g) for cell in self.closure)
        else:
            closure = None
        newFunc = FunctionType(self.code, g, self.__name__, defaults, closure)
        return newFunc


class CannedClass(CannedObject):
    """Can a class: its name, canned namespace and canned parent classes."""

    def __init__(self, cls):
        self._check_type(cls)
        self.name = cls.__name__
        self.old_style = not isinstance(cls, type)
        self._canned_dict = {}
        for k, v in cls.__dict__.items():
            # these two entries are skipped rather than canned
            if k not in ("__weakref__", "__dict__"):
                self._canned_dict[k] = can(v)
        if self.old_style:
            mro = []
        else:
            mro = cls.mro()

        # mro[0] is the class itself; only its ancestors are stored
        self.parents = [can(c) for c in mro[1:]]
        self.buffers = []

    def _check_type(self, obj):
        assert isinstance(obj, class_type), "Not a class type"

    def get_object(self, g=None):
        """Recreate the class via ``type(name, parents, namespace)``."""
        parents = tuple(uncan(p, g) for p in self.parents)
        return type(self.name, parents, uncan_dict(self._canned_dict, g=g))


class CannedArray(CannedObject):
    """Can a numpy array, either as a raw buffer or as a pickle.

    The array is pickled when its shape sums to zero, when it has object
    dtype, or when it has a structured dtype with an object field. Otherwise
    it is made contiguous and exposed through a ``memoryview``.
    """

    def __init__(self, obj):
        from numpy import ascontiguousarray

        self.shape = obj.shape
        self.dtype = obj.dtype.descr if obj.dtype.fields else obj.dtype.str
        self.pickled = False
        if sum(obj.shape) == 0:
            self.pickled = True
        elif obj.dtype == "O":
            # can't handle object dtype with buffer approach
            self.pickled = True
        elif obj.dtype.fields and any(dt == "O" for dt, sz in obj.dtype.fields.values()):
            self.pickled = True
        if self.pickled:
            # just pickle it
            self.buffers = [pickle.dumps(obj, PICKLE_PROTOCOL)]
        else:
            # ensure contiguous
            obj = ascontiguousarray(obj, dtype=None)
            self.buffers = [buffer(obj)]

    def get_object(self, g=None):
        """Rebuild the array from ``self.buffers[0]``."""
        from numpy import frombuffer

        data = self.buffers[0]
        if self.pickled:
            # we just pickled it
            # SECURITY NOTE(review): unpickling runs arbitrary code; buffers
            # must only ever come from a trusted peer.
            return pickle.loads(data)
        else:
            return frombuffer(data, dtype=self.dtype).reshape(self.shape)


class CannedBytes(CannedObject):
    """Can a bytes-like object by placing it directly in ``buffers``."""

    @staticmethod
    def wrap(buf: typing.Union[memoryview, bytes, typing.SupportsBytes]) -> bytes:
        """Cast a buffer or memoryview object to bytes"""
        if isinstance(buf, memoryview):
            return buf.tobytes()
        if not isinstance(buf, bytes):
            return bytes(buf)
        return buf

    def __init__(self, obj):
        self.buffers = [obj]

    def get_object(self, g=None):
        """Return the first buffer converted through ``wrap``."""
        data = self.buffers[0]
        return self.wrap(data)


class CannedBuffer(CannedBytes):
    """Like ``CannedBytes`` but uncans through ``buffer``."""

    wrap = buffer  # type:ignore[assignment]


class CannedMemoryView(CannedBytes):
    """Like ``CannedBytes`` but uncans to a ``memoryview``."""

    wrap = memoryview  # type:ignore[assignment]


# -------------------------------------------------------------------------------
# Functions
# -------------------------------------------------------------------------------


def _import_mapping(mapping, original=None):
    """import any string-keys in a type mapping

    String keys are replaced in place by the object they import to; keys that
    fail to import are dropped. Failures are logged only for keys that are
    not in ``original``.
    """
    log = get_logger()
    log.debug("Importing canning map")
    # iterate a snapshot of the keys: the mapping is mutated inside the loop
    for key in list(mapping):
        if isinstance(key, str):
            try:
                cls = import_item(key)
            except Exception:
                if original and key not in original:
                    # only message on user-added classes
                    log.error("canning class not importable: %r", key, exc_info=True)
                mapping.pop(key)
            else:
                mapping[cls] = mapping.pop(key)


def istype(obj, check):
    """like isinstance(obj, check), but strict

    This won't catch subclasses.
    """
    if isinstance(check, tuple):
        return any(type(obj) is cls for cls in check)
    return type(obj) is check


def can(obj):
    """prepare an object for pickling

    Returns the result of the first canner in ``can_map`` whose key is
    exactly ``type(obj)``, or ``obj`` itself when no entry matches.
    """

    import_needed = False

    for cls, canner in can_map.items():
        if isinstance(cls, str):
            import_needed = True
            break
        elif istype(obj, cls):
            return canner(obj)

    if import_needed:
        # perform can_map imports, then try again
        # this will usually only happen once
        _import_mapping(can_map, _original_can_map)
        return can(obj)

    return obj


def can_class(obj):
    """Can a class only when it lives in ``__main__``; otherwise return it as is."""
    if isinstance(obj, class_type) and obj.__module__ == "__main__":
        return CannedClass(obj)
    else:
        return obj


def can_dict(obj):
    """can the *values* of a dict"""
    if istype(obj, dict):
        return {k: can(v) for k, v in obj.items()}
    return obj


sequence_types = (list, tuple, set)


def can_sequence(obj):
    """can the elements of a sequence"""
    if istype(obj, sequence_types):
        t = type(obj)
        return t([can(i) for i in obj])
    else:
        return obj


def uncan(obj, g=None):
    """invert canning

    Returns the result of the first uncanner in ``uncan_map`` whose key
    ``obj`` is an instance of, or ``obj`` itself when no entry matches.
    """

    import_needed = False
    for cls, uncanner in uncan_map.items():
        if isinstance(cls, str):
            import_needed = True
            break
        elif isinstance(obj, cls):
            return uncanner(obj, g)

    if import_needed:
        # perform uncan_map imports, then try again
        # this will usually only happen once
        _import_mapping(uncan_map, _original_uncan_map)
        return uncan(obj, g)

    return obj


def uncan_dict(obj, g=None):
    """uncan the *values* of a dict"""
    if istype(obj, dict):
        return {k: uncan(v, g) for k, v in obj.items()}
    return obj


def uncan_sequence(obj, g=None):
    """uncan the elements of a sequence"""
    if istype(obj, sequence_types):
        t = type(obj)
        return t([uncan(i, g) for i in obj])
    else:
        return obj


# -------------------------------------------------------------------------------
# API dictionaries
# -------------------------------------------------------------------------------

# These dicts can be extended for custom serialization of new objects

can_map = {
    "numpy.ndarray": CannedArray,
    FunctionType: CannedFunction,
    bytes: CannedBytes,
    memoryview: CannedMemoryView,
    cell_type: CannedCell,
    class_type: can_class,
}
if buffer is not memoryview:
    can_map[buffer] = CannedBuffer

uncan_map: typing.Dict[type, typing.Any] = {
    CannedObject: lambda obj, g: obj.get_object(g),
    dict: uncan_dict,
}

# for use in _import_mapping:
_original_can_map = can_map.copy()
_original_uncan_map = uncan_map.copy()