You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ORPA-pyOpenRPA/WPy32-3720/python-3.7.2/Lib/site-packages/dask/array/utils.py

162 lines
4.4 KiB

6 years ago
from __future__ import absolute_import, division, print_function
import difflib
import functools
import math
import numbers
import os
import numpy as np
from toolz import frequencies, concat
from .core import Array
from ..highlevelgraph import HighLevelGraph
# Resolve the exception class NumPy raises for an out-of-range ``axis``.
try:
    AxisError = np.AxisError
except AttributeError:
    # ``np.AxisError`` is not available on this NumPy: trigger an invalid-axis
    # reduction on purpose and reuse whatever exception type comes back.
    try:
        np.array([0]).sum(axis=5)
    except Exception as exc:
        AxisError = type(exc)
def normalize_to_array(x):
    """Return ``x.get()`` when the type name of ``x`` mentions cupy, else ``x``.

    The check is a plain substring match on ``str(type(x))``; any other
    object is returned unchanged.
    """
    # TODO: avoid explicit reference to cupy
    type_name = str(type(x))
    return x.get() if 'cupy' in type_name else x
def allclose(a, b, equal_nan=False, **kwargs):
    """Compare ``a`` and ``b`` for closeness, with special handling of object dtype.

    Both operands are first passed through ``normalize_to_array``.  Unless the
    first operand has dtype ``'O'``, the comparison is delegated to
    ``np.allclose`` (``equal_nan`` and ``**kwargs`` are forwarded).  For
    object arrays the elements are compared with ``==``; when ``equal_nan``
    is true the shapes must match and a NaN on the left only matches a NaN on
    the right.
    """
    lhs = normalize_to_array(a)
    rhs = normalize_to_array(b)

    if getattr(lhs, 'dtype', None) != 'O':
        return np.allclose(lhs, rhs, equal_nan=equal_nan, **kwargs)

    if not equal_nan:
        return (lhs == rhs).all()

    def _same(x, y):
        # A NaN on the left is matched only by a NaN on the right.
        return np.isnan(y) if np.isnan(x) else x == y

    return (lhs.shape == rhs.shape and
            all(_same(x, y) for (x, y) in zip(lhs.flat, rhs.flat)))
def same_keys(a, b):
    """Return True when iterating ``a.dask`` and ``b.dask`` yields equal keys.

    Keys are sorted before comparison.  String keys are padded into a
    4-tuple ``(key, -1, -1, -1)`` for ordering purposes so they can be sorted
    alongside tuple keys.
    """
    def sort_key(k):
        return (k, -1, -1, -1) if isinstance(k, str) else k

    keys_a = sorted(a.dask, key=sort_key)
    keys_b = sorted(b.dask, key=sort_key)
    return keys_a == keys_b
def _not_empty(x):
return x.shape and 0 not in x.shape
def _check_dsk(dsk):
    """Check that a ``HighLevelGraph`` is well named and non-overlapping.

    Anything that is not a ``HighLevelGraph`` is accepted without checks.
    Otherwise every layer name must be a tuple or a string, and every key
    must appear in exactly one of the graph's per-layer dicts; the offending
    ``{key: count}`` mapping is used as the assertion message.
    """
    if isinstance(dsk, HighLevelGraph):
        assert all(isinstance(name, (tuple, str)) for name in dsk.layers)

        # Count how many layer dicts each key shows up in.
        key_counts = frequencies(concat(dsk.dicts.values()))
        duplicated = {key: count
                      for key, count in key_counts.items()
                      if count != 1}
        assert not duplicated, duplicated
def assert_eq_shape(a, b, check_nan=True, check_ndim=True):
    """Assert that two shapes agree, dimension by dimension.

    Parameters
    ----------
    a, b : sequence of numbers
        The shapes to compare.  Entries may be NaN.
    check_nan : bool, optional
        Controls pairs in which at least one entry is NaN.  When true, both
        entries must be NaN; when false, such pairs are skipped.
    check_ndim : bool, optional
        When true (the default), the two shapes must have the same length.
        Without this check ``zip`` would silently drop the trailing
        dimensions of the longer shape, so ``(2,)`` and ``(2, 3)`` would
        compare equal.  Pass ``False`` to compare only the common prefix.

    Raises
    ------
    AssertionError
        If the shapes differ according to the rules above.
    """
    if check_ndim:
        assert len(a) == len(b), (a, b)

    for aa, bb in zip(a, b):
        if math.isnan(aa) or math.isnan(bb):
            if check_nan:
                assert math.isnan(aa) == math.isnan(bb)
        else:
            assert aa == bb
def _computed_with_dtype(x, check_shape=True, check_graph=True):
    """Normalise one operand of ``assert_eq`` and return ``(value, dtype)``.

    For a dask ``Array`` the declared dtype is recorded, the graph is
    optionally validated, and the array is computed with the synchronous
    scheduler.  The result is densified if it exposes ``todense`` and wrapped
    in an object array if it has no ``dtype``; it is then checked against the
    declared dtype and (optionally) shape.  Any other input is only wrapped
    in an object array when it has no ``dtype``.
    """
    if not isinstance(x, Array):
        if not hasattr(x, 'dtype'):
            x = np.array(x, dtype='O')
        return x, getattr(x, 'dtype', None)

    original = x
    assert x.dtype is not None
    declared_dtype = x.dtype
    if check_graph:
        _check_dsk(x.dask)
    x = x.compute(scheduler='sync')
    # Densify first, then fall back to an object array, so that both
    # operands of assert_eq go through the exact same sequence of steps.
    if hasattr(x, 'todense'):
        x = x.todense()
    if not hasattr(x, 'dtype'):
        x = np.array(x, dtype='O')
    # The dtype comparison is skipped for results that _not_empty rejects.
    if _not_empty(x):
        assert x.dtype == original.dtype
    if check_shape:
        assert_eq_shape(original.shape, x.shape, check_nan=False)
    return x, declared_dtype


def assert_eq(a, b, check_shape=True, check_graph=True, **kwargs):
    """Assert that ``a`` and ``b`` are equal; return True on success.

    Parameters
    ----------
    a, b : Array or array-like
        Operands to compare.  Dask ``Array`` inputs are computed with the
        synchronous scheduler and checked against their declared dtype.
    check_shape : bool, optional
        For ``Array`` inputs, also compare the declared shape with the shape
        of the computed result.
    check_graph : bool, optional
        For ``Array`` inputs, validate the task graph with ``_check_dsk``.
    **kwargs
        Forwarded to ``allclose``.

    Raises
    ------
    AssertionError
        If the dtype string representations differ, if the shapes differ, or
        if the values do not compare equal.
    """
    a, adt = _computed_with_dtype(a, check_shape, check_graph)
    b, bdt = _computed_with_dtype(b, check_shape, check_graph)

    if str(adt) != str(bdt):
        diff = difflib.ndiff(str(adt).splitlines(), str(bdt).splitlines())
        raise AssertionError('string repr are different' + os.linesep +
                             os.linesep.join(diff))

    try:
        assert a.shape == b.shape
        assert allclose(a, b, **kwargs)
        return True
    except TypeError:
        # The closeness check is not applicable to these values; fall back
        # to plain equality below.
        pass

    c = a == b
    if isinstance(c, np.ndarray):
        assert c.all()
    else:
        assert c
    return True
def safe_wraps(wrapped, assigned=functools.WRAPPER_ASSIGNMENTS):
    """Like functools.wraps, but safe to use even if wrapped is not a function.

    If ``wrapped`` lacks any attribute listed in ``assigned``, an identity
    decorator is returned instead of ``functools.wraps(...)``.

    Only needed on Python 2.
    """
    if any(not hasattr(wrapped, attr) for attr in assigned):
        return lambda x: x
    return functools.wraps(wrapped, assigned=assigned)
def validate_axis(axis, ndim):
    """Validate an input to axis= keywords.

    A tuple or list is validated element by element and returned as a tuple.
    A single axis must be an integer in ``[-ndim, ndim)``; negative values
    are converted to the equivalent non-negative index.

    Raises ``TypeError`` for a non-integer axis and ``AxisError`` for an axis
    that is out of bounds.
    """
    if isinstance(axis, (tuple, list)):
        return tuple(validate_axis(item, ndim) for item in axis)

    if not isinstance(axis, numbers.Integral):
        raise TypeError("Axis value must be an integer, got %s" % axis)

    if not (-ndim <= axis < ndim):
        raise AxisError("Axis %d is out of bounds for array of dimension %d"
                        % (axis, ndim))

    return axis + ndim if axis < 0 else axis