You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ORPA-pyOpenRPA/WPy32-3720/python-3.7.2/Lib/site-packages/dask/array/fft.py

295 lines
7.0 KiB

6 years ago
from __future__ import absolute_import, division, print_function
from ..compatibility import Sequence
from functools import wraps
import inspect
import numpy as np
try:
import scipy
import scipy.fftpack
except ImportError:
scipy = None
from .core import concatenate as _concatenate
from .creation import arange as _arange
# Template for the error raised when an FFT is requested along an axis that
# is split into more than one chunk; formatted with
# ``(axis, chunks_along_that_axis)``.
chunk_error = (
    "Dask array only supports taking an FFT along an axis that \n"
    + "has a single chunk. An FFT operation was tried on axis %s \n"
    + "which has chunks %s. To change the array's chunks use "
    + "dask.Array.rechunk."
)
# Text prepended to the docstring of every function built by ``fft_wrap``;
# formatted with ``(wrapped_fullname, wrapped_fullname)``.
fft_preamble = """
Wrapping of %s
The axis along which the FFT is applied must have only one chunk. To change
the array's chunking use dask.Array.rechunk.
The %s docstring follows below:
"""
def _fft_out_chunks(a, s, axes):
    """Compute output chunks for the complex ``fft*`` / ``ifft*`` transforms.

    Without ``s`` the transform preserves the shape, so ``a.chunks`` itself is
    returned. Otherwise a new list is returned in which each transformed axis
    ``axes[k]`` becomes a single chunk of length ``s[k]``; all other axes keep
    their chunks.
    """
    if s is not None:
        new_chunks = list(a.chunks)
        for k, ax in enumerate(axes):
            new_chunks[ax] = (s[k],)
        return new_chunks
    return a.chunks
def _rfft_out_chunks(a, s, axes):
    """Compute output chunks for the real-input ``rfft*`` transforms.

    ``s`` gives the transform length along each of ``axes``; when it is
    ``None`` the current (single-chunk) length of each axis is used. The last
    transformed axis shrinks to ``length // 2 + 1``. Every transformed axis
    becomes a single chunk. A new list is returned and ``s`` is not mutated.
    """
    lengths = list(s) if s is not None else [a.chunks[ax][0] for ax in axes]
    lengths[-1] = lengths[-1] // 2 + 1
    new_chunks = list(a.chunks)
    for k, ax in enumerate(axes):
        new_chunks[ax] = (lengths[k],)
    return new_chunks
def _irfft_out_chunks(a, s, axes):
    """Compute output chunks for the inverse real ``irfft*`` transforms.

    When ``s`` is given, each transformed axis ``axes[k]`` becomes a single
    chunk of length ``s[k]``. Otherwise the current length of each axis is
    used, except that the last transformed axis of length ``m`` becomes
    ``2 * (m - 1)``. Returns a new list of chunks.
    """
    if s is not None:
        lengths = s
    else:
        lengths = [a.chunks[ax][0] for ax in axes]
        # Default output length along the last axis: 2 * (m - 1).
        lengths[-1] = 2 * (lengths[-1] - 1)
    new_chunks = list(a.chunks)
    for k, ax in enumerate(axes):
        new_chunks[ax] = (lengths[k],)
    return new_chunks
def _hfft_out_chunks(a, s, axes):
    """Compute output chunks for the 1-D Hermitian ``hfft`` transform.

    Exactly one axis is expected. The output length along it is ``s[0]`` when
    ``s`` is given, otherwise ``2 * (m - 1)`` where ``m`` is the axis's current
    (single-chunk) length. Returns a new list of chunks.
    """
    assert len(axes) == 1
    ax = axes[0]
    out_len = 2 * (a.chunks[ax][0] - 1) if s is None else s[0]
    new_chunks = list(a.chunks)
    new_chunks[ax] = (out_len,)
    return new_chunks
def _ihfft_out_chunks(a, s, axes):
    """Compute output chunks for the 1-D inverse Hermitian ``ihfft`` transform.

    Exactly one axis (and, if given, exactly one length in ``s``) is expected.
    The input length ``n`` is ``s[0]`` when ``s`` is given, otherwise the
    axis's current (single-chunk) length. The output keeps ``n // 2 + 1``
    entries along that axis, as a single chunk. Returns a new list of chunks.
    """
    assert len(axes) == 1
    axis = axes[0]
    if s is None:
        n = a.chunks[axis][0]
    else:
        assert len(s) == 1
        n = s[0]
    chunks = list(a.chunks)
    # For odd integer n, (n + 1) // 2 == n // 2 + 1, so no parity split is
    # needed: the output length is n // 2 + 1 for every n.
    chunks[axis] = (n // 2 + 1,)
    return chunks
# Map each base transform kind (the ``kind`` with any trailing "2"/"n"
# stripped) to the function that computes its output chunks. The complex
# forward and inverse transforms share one rule.
_out_chunk_fns = dict(
    fft=_fft_out_chunks,
    ifft=_fft_out_chunks,
    rfft=_rfft_out_chunks,
    irfft=_irfft_out_chunks,
    hfft=_hfft_out_chunks,
    ihfft=_ihfft_out_chunks,
)
def fft_wrap(fft_func, kind=None, dtype=None):
    """ Wrap 1D, 2D, and ND real and complex FFT functions

    Takes a function that behaves like ``numpy.fft`` functions and
    a specified kind to match it to that are named after the functions
    in the ``numpy.fft`` API.

    Supported kinds include:

        * fft
        * fft2
        * fftn
        * ifft
        * ifft2
        * ifftn
        * rfft
        * rfft2
        * rfftn
        * irfft
        * irfft2
        * irfftn
        * hfft
        * ihfft

    Parameters
    ----------
    fft_func : callable
        The in-memory FFT function applied to each block.
    kind : str, optional
        One of the kinds above. Defaults to ``fft_func.__name__``.
    dtype : dtype, optional
        Output dtype. When omitted, it is inferred by applying ``fft_func``
        to a small array of ones with the input's dtype.

    Returns
    -------
    callable
        ``func(a, n=None, axis=None)`` for kinds ending in "fft", otherwise
        ``func(a, s=None, axes=None)``. It applies ``fft_func`` blockwise via
        ``a.map_blocks`` and requires every transformed axis of ``a`` to
        consist of a single chunk.

    Raises
    ------
    ValueError
        If ``fft_func`` is ``scipy.fftpack.rfft``/``irfft`` (which do not
        match the NumPy API) or ``kind`` is unknown. The returned function
        raises ``ValueError`` for duplicate axes, for ``s``/``axes`` of
        different lengths, or for a transformed axis with several chunks.

    Examples
    --------
    >>> parallel_fft = fft_wrap(np.fft.fft)
    >>> parallel_ifft = fft_wrap(np.fft.ifft)
    """
    if scipy is not None:
        if fft_func is scipy.fftpack.rfft:
            raise ValueError("SciPy's `rfft` doesn't match the NumPy API.")
        elif fft_func is scipy.fftpack.irfft:
            raise ValueError("SciPy's `irfft` doesn't match the NumPy API.")

    if kind is None:
        kind = fft_func.__name__
    try:
        # "fft2"/"fftn" use the same output-chunk rule as "fft", and so on.
        out_chunk_fn = _out_chunk_fns[kind.rstrip("2n")]
    except KeyError:
        raise ValueError("Given unknown `kind` %s." % kind)

    def func(a, s=None, axes=None):
        if axes is None:
            if kind.endswith('2'):
                axes = (-2, -1)
            elif kind.endswith('n'):
                if s is None:
                    axes = tuple(range(a.ndim))
                else:
                    axes = tuple(range(len(s)))
            else:
                axes = (-1,)
        else:
            if len(set(axes)) < len(axes):
                raise ValueError("Duplicate axes not allowed.")

        # Reject mismatched lengths up front, as NumPy does. Otherwise the
        # chunk computation below would raise IndexError or silently ignore
        # extra entries of ``s``.
        if s is not None and len(s) != len(axes):
            raise ValueError("Shape and axes have different lengths.")

        _dtype = dtype
        if _dtype is None:
            # Infer the output dtype from a tiny in-memory sample. Functions
            # that take no ``axes`` keyword (1-D ones) raise TypeError, in
            # which case they are called with the sample alone.
            sample = np.ones(a.ndim * (8,), dtype=a.dtype)
            try:
                _dtype = fft_func(sample, axes=axes).dtype
            except TypeError:
                _dtype = fft_func(sample).dtype

        for each_axis in axes:
            if len(a.chunks[each_axis]) != 1:
                raise ValueError(chunk_error % (each_axis, a.chunks[each_axis]))

        chunks = out_chunk_fn(a, s, axes)

        args = (s, axes)
        if kind.endswith('fft'):
            # 1-D functions take (n, axis) rather than (s, axes).
            axis = None if axes is None else axes[0]
            n = None if s is None else s[0]
            args = (n, axis)

        return a.map_blocks(fft_func, *args, dtype=_dtype,
                            chunks=chunks)

    if kind.endswith('fft'):
        _func = func

        # 1-D kinds expose the ``(a, n, axis)`` signature and forward to the
        # general implementation with one-element ``s``/``axes``.
        def func(a, n=None, axis=None):
            s = None
            if n is not None:
                s = (n,)

            axes = None
            if axis is not None:
                axes = (axis,)

            return _func(a, s, axes)

    func_name = fft_func.__name__
    # ``inspect.getmodule`` returns None when it cannot determine the module;
    # fall back to ``__module__`` and finally to the bare function name.
    func_mod = inspect.getmodule(fft_func)
    mod_name = (getattr(func_mod, "__name__", None)
                or getattr(fft_func, "__module__", None))
    if mod_name is None:
        func_fullname = func_name
    else:
        func_fullname = mod_name + "." + func_name
    if fft_func.__doc__ is not None:
        func.__doc__ = (fft_preamble % (2 * (func_fullname,)))
        func.__doc__ += fft_func.__doc__
    func.__name__ = func_name
    return func
# Public wrappers around the ``numpy.fft`` transforms, grouped by family.
# Each requires the transformed axes of its input to be a single chunk.
fft, fft2, fftn = map(fft_wrap, (np.fft.fft, np.fft.fft2, np.fft.fftn))
ifft, ifft2, ifftn = map(fft_wrap, (np.fft.ifft, np.fft.ifft2, np.fft.ifftn))
rfft, rfft2, rfftn = map(fft_wrap, (np.fft.rfft, np.fft.rfft2, np.fft.rfftn))
irfft, irfft2, irfftn = map(
    fft_wrap, (np.fft.irfft, np.fft.irfft2, np.fft.irfftn)
)
hfft, ihfft = map(fft_wrap, (np.fft.hfft, np.fft.ihfft))
def _fftfreq_block(i, n, d):
    """Turn one block of sample indices into DFT sample frequencies.

    ``i`` is an array of indices into a signal of length ``n`` sampled with
    spacing ``d``. Indices with ``i >= (n + 1) // 2`` wrap to the negative
    frequencies ``i - n``. The result is then divided by ``n * d``. A new
    array is returned and ``i`` is left unchanged.
    """
    freqs = np.where(i >= (n + 1) // 2, i - n, i)
    freqs /= n * d
    return freqs
@wraps(np.fft.fftfreq)
def fftfreq(n, d=1.0, chunks='auto'):
    # ``wraps`` replaces any docstring here with NumPy's. A float index range
    # 0..n-1 is built with the requested chunking, and each block is turned
    # into frequencies by ``_fftfreq_block``.
    n, d = int(n), float(d)
    indices = _arange(n, dtype=float, chunks=chunks)
    return indices.map_blocks(_fftfreq_block, dtype=float, n=n, d=d)
@wraps(np.fft.rfftfreq)
def rfftfreq(n, d=1.0, chunks='auto'):
    # ``wraps`` replaces any docstring here with NumPy's. The result holds the
    # non-negative frequencies k / (n * d) for k = 0 .. n // 2.
    n, d = int(n), float(d)
    freqs = _arange(n // 2 + 1, dtype=float, chunks=chunks)
    freqs /= n * d
    return freqs
def _fftshift_helper(x, axes=None, inverse=False):
    """Shared implementation of ``fftshift`` and ``ifftshift``.

    Parameters
    ----------
    x : array
        Input array (must provide ``ndim``, ``shape``, ``chunks``, slicing
        and ``rechunk``).
    axes : int, sequence of int, or None
        Axes to shift; ``None`` means all axes.
    inverse : bool
        Sets the split point along each axis of length ``n``:
        ``(n + 1) // 2`` for the forward shift and ``n // 2`` for the inverse.

    Returns
    -------
    The rotated array. The part from the split point onward moves to the
    front. Axes that were a single chunk in ``x`` are rechunked back to
    ``x``'s chunks.
    """
    if axes is None:
        axes = list(range(x.ndim))
    elif not isinstance(axes, Sequence):
        axes = (axes,)

    y = x
    for i in axes:
        # Normalize negative axes. The ``{i: ...}`` mapping passed to
        # ``rechunk`` below is keyed by axis position, and a negative key may
        # not be matched to its axis, which would leave it split in several
        # chunks. Out-of-range axes still raise IndexError on ``y.shape[i]``.
        if i < 0:
            i += y.ndim
        n = y.shape[i]
        n_2 = (n + int(inverse is False)) // 2

        head = y.ndim * [slice(None)]
        head[i] = slice(None, n_2)
        tail = y.ndim * [slice(None)]
        tail[i] = slice(n_2, None)

        y = _concatenate([y[tuple(tail)], y[tuple(head)]], axis=i)

        # The concatenation may split axis ``i`` into several chunks; restore
        # the single chunk if ``x`` had one there.
        if len(x.chunks[i]) == 1:
            y = y.rechunk({i: x.chunks[i]})

    return y
@wraps(np.fft.fftshift)
def fftshift(x, axes=None):
    # Forward shift; ``wraps`` supplies NumPy's docstring.
    return _fftshift_helper(x, inverse=False, axes=axes)
@wraps(np.fft.ifftshift)
def ifftshift(x, axes=None):
    # Inverse shift; ``wraps`` supplies NumPy's docstring.
    return _fftshift_helper(x, inverse=True, axes=axes)