You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
398 lines
16 KiB
398 lines
16 KiB
import threading
|
|
import array
|
|
from ctypes import (POINTER, Structure, byref, cast, c_long, memmove, pointer,
|
|
sizeof)
|
|
from comtypes import _safearray, IUnknown, com_interface_registry, npsupport
|
|
from comtypes.patcher import Patch
|
|
|
|
# numpy is obtained through comtypes.npsupport instead of being imported
# directly here.
# NOTE(review): presumably this lets the module load without numpy
# installed -- confirm against npsupport.
numpy = npsupport.numpy
# Maps an item ctypes type to the SAFEARRAY subclass generated for it, so
# that _midlSAFEARRAY builds each type only once.
_safearray_type_cache = {}
|
|
|
|
|
|
class _SafeArrayAsNdArrayContextManager(object):
    '''Context manager allowing safe arrays to be extracted as ndarrays.

    This is thread-safe: the nesting depth is stored in a
    ``threading.local``, so entering the context on one thread does not
    affect what other threads observe.

    The manager may be entered recursively; it is truthy while at least
    one ``with`` block is active on the current thread.

    Example
    -------

    This works in python >= 2.5

    >>> with safearray_as_ndarray:
    ...     my_arr = com_object.AsSafeArray
    >>> type(my_arr)
    numpy.ndarray

    '''
    # Per-thread storage holding the nesting counter ``count``.
    thread_local = threading.local()

    def __enter__(self):
        '''Increment the nesting depth for the current thread.'''
        try:
            self.thread_local.count += 1
        except AttributeError:
            # First time this thread enters the context.
            self.thread_local.count = 1

    def __exit__(self, exc_type, exc_value, traceback):
        '''Decrement the nesting depth; exceptions are not suppressed.'''
        self.thread_local.count -= 1

    def __bool__(self):
        '''True if context manager is currently entered on given thread.

        '''
        return bool(getattr(self.thread_local, 'count', 0))

    # Python 2 looks up __nonzero__ for truth testing; without this alias
    # the manager would always be truthy there.
    __nonzero__ = __bool__
|
|
|
|
|
|
# Global _SafeArrayAsNdArrayContextManager instance.  The unpacking code
# in this module tests its truth value to decide whether to return numpy
# arrays instead of tuples/lists.
safearray_as_ndarray = _SafeArrayAsNdArrayContextManager()
|
|
|
|
|
|
################################################################
|
|
# This is THE PUBLIC function: the gateway to the SAFEARRAY functionality.
|
|
def _midlSAFEARRAY(itemtype):
    """Python counterpart of the IDL idiom 'SAFEARRAY(aType)'.

    Returns a pointer type to a SAFEARRAY subtype specialised for
    'itemtype', which must be one of the supported ctypes.  Instances
    are built with the VT_... typecode that corresponds to 'itemtype'.
    Generated subtypes are memoised in _safearray_type_cache, so each
    item type is only ever built once.
    """
    cached = _safearray_type_cache.get(itemtype)
    if cached is None:
        # Not seen before: build the specialised type and remember it.
        cached = _make_safearray_type(itemtype)
        _safearray_type_cache[itemtype] = cached
    return POINTER(cached)
|
|
|
|
|
|
def _make_safearray_type(itemtype):
    """Create and return a subclass of tagSAFEARRAY for 'itemtype'.

    The new class is named "SAFEARRAY_<itemtype name>".  Two pointer
    types derived from it are patched with extra behaviour:

    * POINTER(sa_type) gains create()/unpack() and the ctypes parameter
      conversion hooks.
    * POINTER(POINTER(sa_type)) gains from_param() and __setitem__().

    Raises TypeError if 'itemtype' has no entry in _ctype_to_vartype and
    is neither a Structure nor a pointer to an IDispatch/IUnknown
    interface.
    """
    # Create and return a subclass of tagSAFEARRAY
    from comtypes.automation import _ctype_to_vartype, VT_RECORD, \
         VT_UNKNOWN, IDispatch, VT_DISPATCH

    # Build the subclass through the metaclass of tagSAFEARRAY.
    meta = type(_safearray.tagSAFEARRAY)
    sa_type = meta.__new__(meta,
                           "SAFEARRAY_%s" % itemtype.__name__,
                           (_safearray.tagSAFEARRAY,), {})

    # Determine the VARTYPE code and the 'extra' argument that is handed
    # to the SafeArrayCreate...Ex functions for this item type.
    try:
        vartype = _ctype_to_vartype[itemtype]
        extra = None
    except KeyError:
        if issubclass(itemtype, Structure):
            try:
                guids = itemtype._recordinfo_
            except AttributeError:
                # No record info declared on the structure; creating an
                # array will then fail later with a TypeError.
                extra = None
            else:
                from comtypes.typeinfo import GetRecordInfoFromGuids
                extra = GetRecordInfoFromGuids(*guids)
            vartype = VT_RECORD
        elif issubclass(itemtype, POINTER(IDispatch)):
            vartype = VT_DISPATCH
            extra = pointer(itemtype._iid_)
        elif issubclass(itemtype, POINTER(IUnknown)):
            vartype = VT_UNKNOWN
            extra = pointer(itemtype._iid_)
        else:
            raise TypeError(itemtype)

    # The members of this class are patched onto POINTER(sa_type).
    @Patch(POINTER(sa_type))
    class _(object):
        # Should explain the ideas how SAFEARRAY is used in comtypes
        _itemtype_ = itemtype # a ctypes type
        _vartype_ = vartype # a VARTYPE value: VT_...
        # Set to True when this Python object must call SafeArrayDestroy
        # on the array in __del__.
        _needsfree = False

        @classmethod
        def create(cls, value, extra=None):
            """Create a POINTER(SAFEARRAY_...) instance of the correct
            type; value is an object containing the items to store.

            Python lists, tuples, and array.array instances containing
            compatible item types can be passed to create
            one-dimensional arrays.  To create multidimensional arrays,
            numpy arrays must be passed.

            Raises TypeError for VT_RECORD arrays created without an
            IRecordInfo, and MemoryError if creation fails otherwise.
            """
            if npsupport.isndarray(value):
                return cls.create_from_ndarray(value, extra)

            # For VT_UNKNOWN or VT_DISPATCH, extra must be a pointer to
            # the GUID of the interface.
            #
            # For VT_RECORD, extra must be a pointer to an IRecordInfo
            # describing the record.

            # XXX How to specify the lbound (3. parameter to CreateVectorEx)?
            # XXX How to write tests for lbound != 0?
            pa = _safearray.SafeArrayCreateVectorEx(cls._vartype_,
                                                    0,
                                                    len(value),
                                                    extra)
            if not pa:
                if cls._vartype_ == VT_RECORD and extra is None:
                    raise TypeError("Cannot create SAFEARRAY type VT_RECORD without IRecordInfo.")
                # Hm, there may be other reasons why the creation fails...
                raise MemoryError()
            # We now have a POINTER(tagSAFEARRAY) instance which we must cast
            # to the correct type:
            pa = cast(pa, cls)
            # Now, fill the data in:
            ptr = POINTER(cls._itemtype_)() # container for the values
            _safearray.SafeArrayAccessData(pa, byref(ptr))
            try:
                if isinstance(value, array.array):
                    # array.array exposes its buffer address, so the
                    # items are copied with a single memmove.
                    addr, n = value.buffer_info()
                    nbytes = len(value) * sizeof(cls._itemtype_)
                    memmove(ptr, addr, nbytes)
                else:
                    for index, item in enumerate(value):
                        ptr[index] = item
            finally:
                # Always pair SafeArrayAccessData with SafeArrayUnaccessData.
                _safearray.SafeArrayUnaccessData(pa)
            return pa

        @classmethod
        def create_from_ndarray(cls, value, extra, lBound=0):
            """Create a POINTER(SAFEARRAY_...) from the numpy array 'value'.

            One SAFEARRAYBOUND is filled per dimension of 'value', each
            using 'lBound' as its lower bound.  'extra' is passed on to
            SafeArrayCreateEx.

            Raises TypeError if the array item type does not match this
            SAFEARRAY type, if the array interface version is not 3, or
            for VT_RECORD arrays without an IRecordInfo; raises
            MemoryError if creation fails otherwise.
            """
            from comtypes.automation import VARIANT
            # If processing VARIANT, makes sure the array type is correct.
            if cls._itemtype_ is VARIANT:
                if value.dtype != npsupport.VARIANT_dtype:
                    value = _ndarray_to_variant_array(value)
            else:
                ai = value.__array_interface__
                if ai["version"] != 3:
                    raise TypeError("only __array_interface__ version 3 supported")
                if cls._itemtype_ != numpy.ctypeslib._typecodes[ai["typestr"]]:
                    raise TypeError("Wrong array item type")

            # SAFEARRAYs have Fortran order; convert the numpy array if needed
            if not value.flags.f_contiguous:
                value = numpy.array(value, order="F")

            # For VT_UNKNOWN or VT_DISPATCH, extra must be a pointer to
            # the GUID of the interface.
            #
            # For VT_RECORD, extra must be a pointer to an IRecordInfo
            # describing the record.
            rgsa = (_safearray.SAFEARRAYBOUND * value.ndim)()
            # nitems accumulates the total element count over all dimensions.
            nitems = 1
            for i, d in enumerate(value.shape):
                nitems *= d
                rgsa[i].cElements = d
                rgsa[i].lBound = lBound
            pa = _safearray.SafeArrayCreateEx(cls._vartype_,
                                              value.ndim, # cDims
                                              rgsa, # rgsaBound
                                              extra) # pvExtra
            if not pa:
                if cls._vartype_ == VT_RECORD and extra is None:
                    raise TypeError("Cannot create SAFEARRAY type VT_RECORD without IRecordInfo.")
                # Hm, there may be other reasons why the creation fails...
                raise MemoryError()
            # We now have a POINTER(tagSAFEARRAY) instance which we must cast
            # to the correct type:
            pa = cast(pa, cls)
            # Now, fill the data in:
            ptr = POINTER(cls._itemtype_)() # pointer to the item values
            _safearray.SafeArrayAccessData(pa, byref(ptr))
            try:
                # The numpy buffer is Fortran-contiguous at this point, so
                # it is copied in one block.
                nbytes = nitems * sizeof(cls._itemtype_)
                memmove(ptr, value.ctypes.data, nbytes)
            finally:
                _safearray.SafeArrayUnaccessData(pa)
            return pa

        @classmethod
        def from_param(cls, value):
            """ctypes hook: convert 'value' for use as a call argument.

            Anything that is not already an instance of this pointer type
            is packed into a new SAFEARRAY, which is marked to be
            destroyed together with the Python object.
            """
            if not isinstance(value, cls):
                # 'extra' is the value computed by the enclosing
                # _make_safearray_type call, not a parameter.
                value = cls.create(value, extra)
                value._needsfree = True
            return value

        def __getitem__(self, index):
            # pparray[0] returns the whole array contents.
            if index != 0:
                raise IndexError("Only index 0 allowed")
            return self.unpack()

        def __setitem__(self, index, value):
            # XXX Need this to implement [in, out] safearrays in COM servers!
##            print "__setitem__", index, value
            raise TypeError("Setting items not allowed")

        def __ctypes_from_outparam__(self):
            # Mark the array for destruction in __del__, then return its
            # unpacked contents.
            self._needsfree = True
            return self[0]

        def __del__(self, _SafeArrayDestroy=_safearray.SafeArrayDestroy):
            # SafeArrayDestroy is bound as a default argument so that it
            # is available as a local name when __del__ runs.
            if self._needsfree:
                _SafeArrayDestroy(self)

        def _get_size(self, dim):
            "Return the number of elements for dimension 'dim'"
            ub = _safearray.SafeArrayGetUBound(self, dim) + 1
            lb = _safearray.SafeArrayGetLBound(self, dim)
            return ub - lb

        def unpack(self):
            """Unpack a POINTER(SAFEARRAY_...) into a Python tuple or ndarray."""
            dim = _safearray.SafeArrayGetDim(self)

            if dim == 1:
                num_elements = self._get_size(1)
                result = self._get_elements_raw(num_elements)
                if safearray_as_ndarray:
                    import numpy
                    return numpy.asarray(result)
                return tuple(result)
            elif dim == 2:
                # get the number of elements in each dimension
                rows, cols = self._get_size(1), self._get_size(2)
                # get all elements
                result = self._get_elements_raw(rows * cols)
                # this must be reshaped and transposed because it is
                # flat, and in VB order
                if safearray_as_ndarray:
                    import numpy
                    return numpy.asarray(result).reshape((cols, rows)).T
                # Row r consists of every rows-th element starting at r.
                result = [tuple(result[r::rows]) for r in range(rows)]
                return tuple(result)
            else:
                # Other dimension counts: fetch element by element,
                # starting from the lower bound of every dimension.
                lowerbounds = [_safearray.SafeArrayGetLBound(self, d)
                               for d in range(1, dim+1)]
                indexes = (c_long * dim)(*lowerbounds)
                upperbounds = [_safearray.SafeArrayGetUBound(self, d)
                               for d in range(1, dim+1)]
                row = self._get_row(0, indexes, lowerbounds, upperbounds)
                if safearray_as_ndarray:
                    import numpy
                    return numpy.asarray(row)
                return row

        def _get_elements_raw(self, num_elements):
            """Returns a flat list or ndarray containing ALL elements in
            the safearray."""
            from comtypes.automation import VARIANT
            # XXX Not sure this is true:
            # For VT_UNKNOWN and VT_DISPATCH, we should retrieve the
            # interface iid by SafeArrayGetIID().
            ptr = POINTER(self._itemtype_)() # container for the values
            _safearray.SafeArrayAccessData(self, byref(ptr))
            try:
                if self._itemtype_ == VARIANT:
                    # We have to loop over each item, so we get no
                    # speedup by creating an ndarray here.
                    return [i.value for i in ptr[:num_elements]]
                elif issubclass(self._itemtype_, POINTER(IUnknown)):
                    iid = _safearray.SafeArrayGetIID(self)
                    itf = com_interface_registry[str(iid)]
                    # COM interface pointers retrieved from array
                    # must be AddRef()'d if non-NULL.
                    elems = ptr[:num_elements]
                    result = []
                    # We have to loop over each item, so we get no
                    # speedup by creating an ndarray here.
                    for p in elems:
                        if bool(p):
                            p.AddRef()
                            result.append(p.QueryInterface(itf))
                        else:
                            # return a NULL-interface pointer.
                            result.append(POINTER(itf)())
                    return result
                else:
                    # If the safearray element are NOT native python
                    # objects, the containing safearray must be kept
                    # alive until all the elements are destroyed.
                    if not issubclass(self._itemtype_, Structure):
                        # Create an ndarray if requested. This is where
                        # we can get the most speed-up.
                        # XXX Only try to convert types known to
                        #     numpy.ctypeslib.
                        if (safearray_as_ndarray and self._itemtype_ in
                                list(numpy.ctypeslib._typecodes.values())):
                            arr = numpy.ctypeslib.as_array(ptr,
                                                           (num_elements,))
                            # Copy so the result does not refer to the
                            # array data after it is unaccessed below.
                            return arr.copy()
                        return ptr[:num_elements]

                    def keep_safearray(v):
                        # Store a reference to this safearray on the
                        # element so the array outlives it.
                        v.__keepref = self
                        return v
                    return [keep_safearray(x) for x in ptr[:num_elements]]
            finally:
                _safearray.SafeArrayUnaccessData(self)

        def _get_row(self, dim, indices, lowerbounds, upperbounds):
            """Recursively collect the elements along dimension 'dim'.

            'indices' is a mutable index array that is shared between the
            recursion levels.  Returns a (possibly nested) tuple.
            """
            # loop over the index of dimension 'dim'
            # we have to restore the index of the dimension we're looping over
            restore = indices[dim]

            result = []
            obj = self._itemtype_()
            pobj = byref(obj)
            if dim+1 == len(indices):
                # It should be faster to lock the array and get a whole row at once?
                # How to calculate the pointer offset?
                for i in range(indices[dim], upperbounds[dim]+1):
                    indices[dim] = i
                    _safearray.SafeArrayGetElement(self, indices, pobj)
                    result.append(obj.value)
            else:
                for i in range(indices[dim], upperbounds[dim]+1):
                    indices[dim] = i
                    result.append(self._get_row(dim+1, indices, lowerbounds, upperbounds))
            indices[dim] = restore
            return tuple(result) # for compatibility with pywin32.

    # The members of this class are patched onto POINTER(POINTER(sa_type)).
    @Patch(POINTER(POINTER(sa_type)))
    class __(object):

        @classmethod
        def from_param(cls, value):
            """ctypes hook: pass 'value' by reference, packing it into a
            new SAFEARRAY first unless it already is one."""
            if isinstance(value, cls._type_):
                return byref(value)
            return byref(cls._type_.create(value, extra))

        def __setitem__(self, index, value):
            # create an LP_SAFEARRAY_... instance
            pa = self._type_.create(value, extra)
            # XXX Must we destroy the currently contained data?
            # fill it into self
            super(POINTER(POINTER(sa_type)), self).__setitem__(index, pa)

    return sa_type
|
|
|
|
|
|
def _ndarray_to_variant_array(value):
    """Return a Fortran-ordered VARIANT_dtype array built from 'value'.

    datetime64 input is delegated to
    _datetime64_ndarray_to_variant_array; every other element is wrapped
    in a VARIANT individually.

    Raises RuntimeError when npsupport provides no VARIANT dtype.
    """
    # Without a VARIANT dtype there is nothing to convert to.
    if npsupport.VARIANT_dtype is None:
        raise RuntimeError("VARIANT ndarrays require NumPy 1.7 or newer.")

    # Dates need their own conversion.
    is_datetime = numpy.issubdtype(value.dtype, npsupport.datetime64)
    if is_datetime:
        return _datetime64_ndarray_to_variant_array(value)

    from comtypes.automation import VARIANT
    # Start from a zeroed array of the target dtype and shape, then fill
    # it with one VARIANT per source element.
    variants = numpy.zeros(value.shape, npsupport.VARIANT_dtype, order='F')
    variants.flat = [VARIANT(element) for element in value.flat]
    return variants
|
|
|
|
|
|
def _datetime64_ndarray_to_variant_array(value):
    """Return a VARIANT_dtype array of VT_DATE entries for datetime64 input.

    The OLE automation date format is a floating point value counting
    days since midnight 30 December 1899; hours and minutes are
    represented as fractional days.
    """
    from comtypes.automation import VT_DATE
    # Normalise to nanosecond resolution, then measure the offset from
    # the COM null date.
    as_ns = numpy.array(value, "datetime64[ns]")
    since_null_date = as_ns - npsupport.com_null_date64
    # Express that offset in (fractional) days.
    days = since_null_date / numpy.timedelta64(1, 'D')
    # Every entry is tagged as a date and carries the day count in its
    # VT_R8 field.
    variants = numpy.zeros(days.shape, npsupport.VARIANT_dtype, order='F')
    variants['vt'] = VT_DATE
    variants['_']['VT_R8'].flat = days.flat
    return variants
|