You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
40 lines
935 B
40 lines
935 B
import pytest
|
|
import numpy as np
|
|
|
|
import dask.array as da
|
|
from dask.array.numpy_compat import _make_sliced_dtype
|
|
from dask.array.utils import assert_eq
|
|
|
|
|
|
@pytest.fixture(
    params=[
        [("A", ("f4", (3, 2))), ("B", ("f4", 3)), ("C", ("f8", 3))],
        [("A", ("i4", (3, 2))), ("B", ("f4", 3)), ("C", ("S4", 3))],
    ]
)
def dtype(request):
    """Provide a structured NumPy dtype built from each parametrized spec.

    Every spec declares three fields named ``A``, ``B`` and ``C``, each with
    a sub-array shape.  The second spec uses an ``i4`` base for ``A`` and an
    ``S4`` base for ``C`` in place of the float bases of the first.
    """
    field_spec = request.param
    return np.dtype(field_spec)
|
|
|
|
|
|
@pytest.fixture(
    params=[
        ["A"],
        ["A", "B"],
        ["A", "B", "C"],
    ]
)
def index(request):
    """Provide a list of field names used to index a structured array.

    The parametrized selections grow from a single field to all three
    fields; each is passed through unchanged as a list.
    """
    selected_fields = request.param
    return selected_fields
|
|
|
|
|
|
def test_basic():
    """Sanity check: multi-field selection on a dask array matches NumPy."""
    record_dtype = [("a", "f8"), ("b", "f8"), ("c", "f8")]
    columns = ["a", "b"]

    numpy_ones = np.ones((5, 3), dtype=record_dtype)
    dask_ones = da.ones((5, 3), dtype=record_dtype, chunks=3)

    # Index the dask array first, then NumPy, and compare the two selections.
    from_dask = dask_ones[columns]
    from_numpy = numpy_ones[columns]
    assert_eq(from_dask, from_numpy)
|
|
|
|
|
|
def test_slice_dtype(dtype, index):
    """Check ``_make_sliced_dtype`` against NumPy's own field-sliced dtype.

    The reference value is the dtype NumPy reports after indexing a
    structured array of ``dtype`` with the same list of field names.
    """
    result = _make_sliced_dtype(dtype, index)

    reference = np.ones((5, len(dtype)), dtype=dtype)
    expected = reference[index].dtype

    assert result == expected
|