You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
41 lines
1.3 KiB
41 lines
1.3 KiB
from .core import tensordot_lookup, concatenate_lookup, einsum_lookup
|
|
|
|
|
|
@einsum_lookup.register_lazy('cupy')
@tensordot_lookup.register_lazy('cupy')
@concatenate_lookup.register_lazy('cupy')
def register_cupy():
    """Register cupy implementations for concatenate, tensordot and einsum.

    cupy is imported inside the function, so importing this module does not
    require cupy to be installed.  The function is hooked under the
    ``'cupy'`` key on every lookup it populates (including
    ``einsum_lookup``), so that whichever of the three operations is
    dispatched first can trigger the registration.
    NOTE(review): assumes ``einsum_lookup`` exposes the same
    ``register_lazy`` API as the other two lookups - confirm in ``.core``.
    """
    import cupy

    concatenate_lookup.register(cupy.ndarray, cupy.concatenate)
    tensordot_lookup.register(cupy.ndarray, cupy.tensordot)

    @einsum_lookup.register(cupy.ndarray)
    def _cupy_einsum(*args, **kwargs):
        """Call ``cupy.einsum`` after dropping kwargs it does not accept."""
        # NB: cupy does not accept `order` or `casting` kwargs - ignore
        kwargs.pop('casting', None)
        kwargs.pop('order', None)
        return cupy.einsum(*args, **kwargs)
|
|
|
|
|
|
@tensordot_lookup.register_lazy('sparse')
@concatenate_lookup.register_lazy('sparse')
def register_sparse():
    """Register the ``sparse`` package's COO implementations.

    ``sparse`` is imported inside the function, so it is only needed once
    this registration actually runs.
    """
    import sparse

    # (lookup, implementation) pairs, in the order they are registered.
    implementations = (
        (concatenate_lookup, sparse.concatenate),
        (tensordot_lookup, sparse.tensordot),
    )
    for lookup, implementation in implementations:
        lookup.register(sparse.COO, implementation)
|
|
|
|
|
|
@concatenate_lookup.register_lazy('scipy')
def register_scipy_sparse():
    """Register a concatenate implementation for ``scipy.sparse.spmatrix``.

    ``scipy.sparse`` is imported inside the function, so it is only needed
    once this registration actually runs.
    """
    import scipy.sparse

    def _concatenate(L, axis=0):
        """Stack the sparse matrices in ``L`` along ``axis``.

        ``axis == 0`` stacks vertically (``scipy.sparse.vstack``) and
        ``axis == 1`` stacks horizontally (``scipy.sparse.hstack``).

        Raises
        ------
        ValueError
            If ``axis`` is neither 0 nor 1.
        """
        if axis == 0:
            return scipy.sparse.vstack(L)
        elif axis == 1:
            return scipy.sparse.hstack(L)
        else:
            # Wrap ``axis`` in a 1-tuple: a bare tuple on the right of ``%``
            # is unpacked as the argument list, which would misreport a
            # 1-tuple and raise TypeError (not ValueError) for longer ones.
            msg = ("Can only concatenate scipy sparse matrices for axis in "
                   "{0, 1}. Got %s" % (axis,))
            raise ValueError(msg)

    concatenate_lookup.register(scipy.sparse.spmatrix, _concatenate)
|