You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
162 lines
4.4 KiB
162 lines
4.4 KiB
from __future__ import absolute_import, division, print_function
|
|
|
|
import difflib
|
|
import functools
|
|
import math
|
|
import numbers
|
|
import os
|
|
|
|
import numpy as np
|
|
from toolz import frequencies, concat
|
|
|
|
from .core import Array
|
|
from ..highlevelgraph import HighLevelGraph
|
|
|
|
# Resolve NumPy's "axis out of bounds" exception type. Newer NumPy exposes it
# as ``np.AxisError``; otherwise discover it by provoking the error with an
# out-of-range axis and recording the type of whatever gets raised.
AxisError = getattr(np, 'AxisError', None)
if AxisError is None:
    try:
        np.array([0]).sum(axis=5)
    except Exception as probe_error:
        AxisError = type(probe_error)
|
|
|
|
|
|
def normalize_to_array(x):
    """Return ``x`` in a form NumPy-based comparisons can consume.

    Objects whose type's string representation mentions ``cupy`` are
    converted by calling their ``.get()`` method (for CuPy arrays this
    presumably yields a host-side NumPy array); every other object is
    returned unchanged.
    """
    # TODO: avoid explicit reference to cupy
    looks_like_cupy = 'cupy' in str(type(x))
    return x.get() if looks_like_cupy else x
|
|
|
|
|
|
def allclose(a, b, equal_nan=False, **kwargs):
    """Return whether ``a`` and ``b`` are element-wise equal (within tolerance).

    Both inputs are first passed through ``normalize_to_array``.

    * If ``a`` does not have an object dtype, this defers to
      ``np.allclose`` and forwards ``equal_nan`` and any extra ``kwargs``
      (e.g. ``rtol``/``atol``).
    * For object-dtype ``a`` with ``equal_nan=True``, the shapes must match
      and elements are compared pairwise with ``==``, except that two NaN
      elements count as equal. Elements for which ``np.isnan`` is not
      defined (``None``, strings, ...) are treated as non-NaN instead of
      raising ``TypeError``.
    * For object-dtype ``a`` with ``equal_nan=False``, the result is
      ``(a == b).all()``.

    Tolerance ``kwargs`` are ignored on the object-dtype paths.

    Returns a bool or NumPy boolean.
    """
    a = normalize_to_array(a)
    b = normalize_to_array(b)
    if getattr(a, 'dtype', None) != 'O':
        return np.allclose(a, b, equal_nan=equal_nan, **kwargs)

    if equal_nan:
        def _isnan(value):
            # np.isnan raises TypeError for non-numeric objects; such
            # values are simply not NaN, so compare them with ``==``.
            try:
                return np.isnan(value)
            except TypeError:
                return False

        return (a.shape == b.shape and
                all(_isnan(y) if _isnan(x) else x == y
                    for x, y in zip(a.flat, b.flat)))
    return (a == b).all()
|
|
|
|
|
|
def same_keys(a, b):
    """Return True when the ``.dask`` graphs of ``a`` and ``b`` have the same keys.

    Keys are compared as sorted sequences. Plain string keys sort as
    ``(k, -1, -1, -1)`` so that they can be ordered alongside tuple keys.
    """
    def _sort_key(k):
        return (k, -1, -1, -1) if isinstance(k, str) else k

    keys_a = sorted(a.dask, key=_sort_key)
    keys_b = sorted(b.dask, key=_sort_key)
    return keys_a == keys_b
|
|
|
|
|
|
def _not_empty(x):
|
|
return x.shape and 0 not in x.shape
|
|
|
|
|
|
def _check_dsk(dsk):
    """ Check that graph is well named and non-overlapping

    Only ``HighLevelGraph`` inputs are checked; anything else is ignored.
    Asserts that every layer name is a tuple or a string, and that no key
    appears in more than one of the graph's layer dicts. On failure the
    assertion message maps each repeated key to its occurrence count.
    """
    if isinstance(dsk, HighLevelGraph):
        names_ok = all(isinstance(name, (tuple, str)) for name in dsk.layers)
        assert names_ok
        # Iterating each layer dict yields its keys; count them across layers.
        counts = frequencies(concat(dsk.dicts.values()))
        repeated = dict((key, n) for key, n in counts.items() if n != 1)
        assert not repeated, repeated
|
|
|
|
|
|
def assert_eq_shape(a, b, check_nan=True, check_ndim=True):
    """Assert that the shape tuples ``a`` and ``b`` are equal.

    Parameters
    ----------
    a, b : sequence of numbers
        Shapes to compare. Entries may be NaN (presumably unknown sizes).
    check_nan : bool
        If True, a NaN dimension must be matched by a NaN dimension. If
        False, any pair where either side is NaN is accepted.
    check_ndim : bool
        If True (default), both shapes must have the same length.

    Raises
    ------
    AssertionError
        If the shapes differ.
    """
    if check_ndim:
        # zip() stops at the shorter sequence, so without this check shapes
        # of different rank such as (2,) and (2, 3) would compare equal.
        assert len(a) == len(b), (a, b)
    for aa, bb in zip(a, b):
        if math.isnan(aa) or math.isnan(bb):
            if check_nan:
                assert math.isnan(aa) == math.isnan(bb), (a, b)
        else:
            assert aa == bb, (a, b)
|
|
|
|
|
|
def _get_dt_computed(x, check_shape=True, check_graph=True):
    """Materialize one ``assert_eq`` operand and return ``(value, dtype)``.

    For a dask ``Array``:

    * assert it declares a dtype;
    * validate its graph with ``_check_dsk`` when ``check_graph``;
    * compute it with the synchronous scheduler;
    * densify the result if it has ``todense``;
    * wrap a result without ``dtype`` in an object array;
    * assert the computed dtype matches the declared one (non-empty
      results only);
    * compare the declared and computed shapes, ignoring NaN dimensions,
      when ``check_shape``.

    The returned dtype is the array's declared dtype.

    Any other input is wrapped in an object array if it has no ``dtype``,
    and its own dtype is returned.
    """
    x_original = x
    if isinstance(x, Array):
        assert x.dtype is not None
        dt = x.dtype
        if check_graph:
            _check_dsk(x.dask)
        x = x.compute(scheduler='sync')
        # Densify before the object-array fallback: wrapping first would
        # hide ``todense`` and leave the result un-densified.
        if hasattr(x, 'todense'):
            x = x.todense()
        if not hasattr(x, 'dtype'):
            x = np.array(x, dtype='O')
        if _not_empty(x):
            assert x.dtype == x_original.dtype
        if check_shape:
            assert_eq_shape(x_original.shape, x.shape, check_nan=False)
    else:
        if not hasattr(x, 'dtype'):
            x = np.array(x, dtype='O')
        dt = getattr(x, 'dtype', None)
    return x, dt


def assert_eq(a, b, check_shape=True, check_graph=True, **kwargs):
    """Assert that ``a`` and ``b`` have matching dtypes and equal values.

    Each operand is processed identically by ``_get_dt_computed``. Dask
    arrays are graph-checked, computed and validated against their declared
    dtype and shape. Inputs without a ``dtype`` become object arrays.

    The dtypes' string representations must match; otherwise an
    ``AssertionError`` with an ``ndiff`` of the two is raised. The values
    are then compared by requiring equal shapes and ``allclose(a, b,
    **kwargs)``. If that comparison raises ``TypeError``, the check falls
    back to plain ``a == b``.

    Returns True on success; raises ``AssertionError`` on any mismatch.
    """
    a, adt = _get_dt_computed(a, check_shape=check_shape,
                              check_graph=check_graph)
    b, bdt = _get_dt_computed(b, check_shape=check_shape,
                              check_graph=check_graph)

    if str(adt) != str(bdt):
        diff = difflib.ndiff(str(adt).splitlines(), str(bdt).splitlines())
        raise AssertionError('string repr are different' + os.linesep +
                             os.linesep.join(diff))

    try:
        assert a.shape == b.shape
        assert allclose(a, b, **kwargs)
        return True
    except TypeError:
        # Numeric comparison is not applicable to these values (e.g.
        # non-numeric dtypes); fall back to plain equality below.
        pass

    c = a == b

    if isinstance(c, np.ndarray):
        assert c.all()
    else:
        assert c

    return True
|
|
|
|
|
|
def safe_wraps(wrapped, assigned=functools.WRAPPER_ASSIGNMENTS):
    """Return a decorator like ``functools.wraps(wrapped, assigned=assigned)``.

    If ``wrapped`` lacks any attribute named in ``assigned`` (for example a
    non-function callable), an identity decorator that returns its argument
    unchanged is returned instead.

    Only needed on Python 2.
    """
    if any(not hasattr(wrapped, attr) for attr in assigned):
        return lambda x: x
    return functools.wraps(wrapped, assigned=assigned)
|
|
|
|
|
|
def validate_axis(axis, ndim):
    """Validate and normalize a value passed to an ``axis=`` keyword.

    Parameters
    ----------
    axis : int or tuple/list of int
        Axis index (or indices); negative values count from the end.
    ndim : int
        Number of dimensions of the target array.

    Returns
    -------
    int or tuple of int
        The axis as a non-negative index. Sequences come back as a tuple.

    Raises
    ------
    TypeError
        If an axis is not an integral number.
    AxisError
        If an axis lies outside ``[-ndim, ndim)``.
    """
    if isinstance(axis, (tuple, list)):
        return tuple(validate_axis(each, ndim) for each in axis)
    if not isinstance(axis, numbers.Integral):
        raise TypeError("Axis value must be an integer, got %s" % axis)
    if not -ndim <= axis < ndim:
        raise AxisError("Axis %d is out of bounds for array of dimension %d"
                        % (axis, ndim))
    return axis + ndim if axis < 0 else axis
|