You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
295 lines
7.0 KiB
295 lines
7.0 KiB
from __future__ import absolute_import, division, print_function
|
|
|
|
from ..compatibility import Sequence
|
|
from functools import wraps
|
|
import inspect
|
|
|
|
import numpy as np
|
|
|
|
try:
|
|
import scipy
|
|
import scipy.fftpack
|
|
except ImportError:
|
|
scipy = None
|
|
|
|
from .core import concatenate as _concatenate
|
|
from .creation import arange as _arange
|
|
|
|
|
|
# Error raised when an FFT is requested along an axis that spans more than
# one chunk; formatted with ``(axis, chunks along that axis)``.
chunk_error = ("Dask array only supports taking an FFT along an axis that \n"
               "has a single chunk. An FFT operation was tried on axis %s \n"
               "which has chunks %s. To change the array's chunks use "
               "dask.Array.rechunk.")

# Text prepended to each wrapped function's docstring by ``fft_wrap``; both
# ``%s`` slots receive the wrapped function's fully qualified name.
fft_preamble = """
    Wrapping of %s

    The axis along which the FFT is applied must have one chunk. To change
    the array's chunking use dask.Array.rechunk.

    The %s docstring follows below:

    """
|
|
|
|
|
|
def _fft_out_chunks(a, s, axes):
    """Output chunks for the complex fft/ifft family.

    When ``s`` is None the transform keeps every axis length, so
    ``a.chunks`` itself is returned. Otherwise a new list of chunks is
    built in which each axis in ``axes`` becomes a single chunk whose
    length is the entry of ``s`` at the same position.
    """
    if s is not None:
        new_chunks = list(a.chunks)
        for pos, ax in enumerate(axes):
            new_chunks[ax] = (s[pos],)
        return new_chunks
    return a.chunks
|
|
|
|
|
|
def _rfft_out_chunks(a, s, axes):
    """Output chunks for the rfft family.

    Input lengths default to the (single) chunk size along each axis in
    ``axes``. The last transformed axis shrinks to ``length // 2 + 1``
    entries; every transformed axis becomes a single chunk.
    """
    if s is None:
        sizes = [a.chunks[ax][0] for ax in axes]
    else:
        sizes = list(s)
    # Only the last transformed axis is halved in a real-input FFT.
    sizes[-1] = sizes[-1] // 2 + 1

    new_chunks = list(a.chunks)
    for pos, ax in enumerate(axes):
        new_chunks[ax] = (sizes[pos],)
    return new_chunks
|
|
|
|
|
|
def _irfft_out_chunks(a, s, axes):
    """Output chunks for the irfft family.

    An explicit ``s`` gives the output lengths directly. Without it, the
    lengths are the input chunk sizes, except the last transformed axis,
    which becomes ``2 * (m - 1)`` for an input length ``m``.
    """
    if s is None:
        lengths = [a.chunks[ax][0] for ax in axes]
        lengths[-1] = 2 * (lengths[-1] - 1)
    else:
        lengths = s

    new_chunks = list(a.chunks)
    for pos, ax in enumerate(axes):
        new_chunks[ax] = (lengths[pos],)
    return new_chunks
|
|
|
|
|
|
def _hfft_out_chunks(a, s, axes):
    """Output chunks for the 1-D hfft.

    Exactly one axis is expected. Its output length is ``s[0]`` when given,
    otherwise ``2 * (m - 1)`` for an input chunk length ``m``.
    """
    assert len(axes) == 1
    axis = axes[0]

    out_len = 2 * (a.chunks[axis][0] - 1) if s is None else s[0]

    new_chunks = list(a.chunks)
    new_chunks[axis] = (out_len,)
    return new_chunks
|
|
|
|
|
|
def _ihfft_out_chunks(a, s, axes):
    """Output chunks for the 1-D ihfft.

    Exactly one axis is expected. With input length ``n`` (``s[0]`` when
    given, else the axis' single chunk size) the output has
    ``n // 2 + 1`` entries in a single chunk.
    """
    assert len(axes) == 1
    axis = axes[0]

    if s is not None:
        assert len(s) == 1
        n = s[0]
    else:
        n = a.chunks[axis][0]

    new_chunks = list(a.chunks)
    # For even n this is n/2 + 1 and for odd n it is (n + 1)/2; both equal
    # n // 2 + 1 under floor division.
    new_chunks[axis] = (n // 2 + 1,)
    return new_chunks
|
|
|
|
|
|
# Output-chunk calculator for each base transform name. ``fft_wrap`` looks
# kinds up here after stripping a trailing "2" or "n", so e.g. "rfftn" and
# "rfft2" share the "rfft" entry.
_out_chunk_fns = dict(
    fft=_fft_out_chunks,
    ifft=_fft_out_chunks,
    rfft=_rfft_out_chunks,
    irfft=_irfft_out_chunks,
    hfft=_hfft_out_chunks,
    ihfft=_ihfft_out_chunks,
)
|
|
|
|
|
|
def fft_wrap(fft_func, kind=None, dtype=None):
    """ Wrap 1D, 2D, and ND real and complex FFT functions

    Takes a function that behaves like ``numpy.fft`` functions and
    a specified kind to match it to that are named after the functions
    in the ``numpy.fft`` API.

    Supported kinds include:

        * fft
        * fft2
        * fftn
        * ifft
        * ifft2
        * ifftn
        * rfft
        * rfft2
        * rfftn
        * irfft
        * irfft2
        * irfftn
        * hfft
        * ihfft

    Parameters
    ----------
    fft_func : callable
        A function with the same call signature as the ``numpy.fft``
        function named by ``kind``.
    kind : str, optional
        Which ``numpy.fft`` function ``fft_func`` mirrors. Defaults to
        ``fft_func.__name__``.
    dtype : numpy dtype, optional
        Output dtype. When omitted it is inferred by applying ``fft_func``
        to a small array of ones with the input's dtype.

    Returns
    -------
    callable
        For 1-D kinds, ``func(a, n=None, axis=None)``; otherwise
        ``func(a, s=None, axes=None)``. Either applies ``fft_func`` to every
        block of ``a`` via ``a.map_blocks``.

    Raises
    ------
    ValueError
        If ``fft_func`` is SciPy's ``fftpack.rfft`` / ``fftpack.irfft``
        (which do not match the NumPy API) or ``kind`` is not recognised.

    Examples
    --------
    >>> parallel_fft = fft_wrap(np.fft.fft)
    >>> parallel_ifft = fft_wrap(np.fft.ifft)
    """
    if scipy is not None:
        if fft_func is scipy.fftpack.rfft:
            raise ValueError("SciPy's `rfft` doesn't match the NumPy API.")
        elif fft_func is scipy.fftpack.irfft:
            raise ValueError("SciPy's `irfft` doesn't match the NumPy API.")

    if kind is None:
        kind = fft_func.__name__
    try:
        # "fft2"/"fftn" etc. share the chunk rule of their 1-D base name.
        out_chunk_fn = _out_chunk_fns[kind.rstrip("2n")]
    except KeyError:
        raise ValueError("Given unknown `kind` %s." % kind)

    def func(a, s=None, axes=None):
        if axes is None:
            if kind.endswith('2'):
                axes = (-2, -1)
            elif kind.endswith('n'):
                if s is None:
                    axes = tuple(range(a.ndim))
                else:
                    # As in numpy.fft.*n, an ``s`` given without ``axes``
                    # applies to the *last* len(s) axes.
                    axes = tuple(range(-len(s), 0))
            else:
                axes = (-1,)
        elif len(set(axes)) < len(axes):
            raise ValueError("Duplicate axes not allowed.")

        # The output-chunk helpers index ``s`` by position in ``axes``, so
        # the two must agree; check up front rather than failing later.
        if s is not None and len(s) != len(axes):
            raise ValueError("Shape and axes have different lengths.")

        _dtype = dtype
        if _dtype is None:
            # Infer the output dtype from a tiny eager sample; functions
            # that do not accept ``axes`` are retried without it.
            sample = np.ones(a.ndim * (8,), dtype=a.dtype)
            try:
                _dtype = fft_func(sample, axes=axes).dtype
            except TypeError:
                _dtype = fft_func(sample).dtype

        # fft_func runs on each block separately, so every transformed axis
        # must consist of a single chunk.
        for each_axis in axes:
            if len(a.chunks[each_axis]) != 1:
                raise ValueError(chunk_error % (each_axis, a.chunks[each_axis]))

        chunks = out_chunk_fn(a, s, axes)

        args = (s, axes)
        if kind.endswith('fft'):
            # 1-D functions take positional ``(n, axis)`` instead.
            axis = None if axes is None else axes[0]
            n = None if s is None else s[0]
            args = (n, axis)

        return a.map_blocks(fft_func, *args, dtype=_dtype,
                            chunks=chunks)

    if kind.endswith('fft'):
        _func = func

        def func(a, n=None, axis=None):
            # Translate the 1-D ``(n, axis)`` signature onto ``(s, axes)``.
            s = None
            if n is not None:
                s = (n,)

            axes = None
            if axis is not None:
                axes = (axis,)

            return _func(a, s, axes)

    func_name = fft_func.__name__
    # inspect.getmodule returns None when it cannot determine a module; fall
    # back to ``__module__`` (or the bare name) so setup does not crash.
    func_mod = inspect.getmodule(fft_func)
    if func_mod is not None:
        mod_name = func_mod.__name__
    else:
        mod_name = getattr(fft_func, "__module__", None)
    func_fullname = mod_name + "." + func_name if mod_name else func_name
    if fft_func.__doc__ is not None:
        func.__doc__ = (fft_preamble % (2 * (func_fullname,)))
        func.__doc__ += fft_func.__doc__
    func.__name__ = func_name

    return func
|
|
|
|
|
|
# Dask counterparts of the numpy.fft transforms. ``kind`` is spelled out
# explicitly (it equals each NumPy function's name) so the mapping onto the
# output-chunk rules is visible at the definition site.

# Complex transforms.
fft = fft_wrap(np.fft.fft, kind="fft")
fft2 = fft_wrap(np.fft.fft2, kind="fft2")
fftn = fft_wrap(np.fft.fftn, kind="fftn")
ifft = fft_wrap(np.fft.ifft, kind="ifft")
ifft2 = fft_wrap(np.fft.ifft2, kind="ifft2")
ifftn = fft_wrap(np.fft.ifftn, kind="ifftn")

# Real-input transforms and their inverses.
rfft = fft_wrap(np.fft.rfft, kind="rfft")
rfft2 = fft_wrap(np.fft.rfft2, kind="rfft2")
rfftn = fft_wrap(np.fft.rfftn, kind="rfftn")
irfft = fft_wrap(np.fft.irfft, kind="irfft")
irfft2 = fft_wrap(np.fft.irfft2, kind="irfft2")
irfftn = fft_wrap(np.fft.irfftn, kind="irfftn")

# Hermitian-symmetric transforms (1-D only).
hfft = fft_wrap(np.fft.hfft, kind="hfft")
ihfft = fft_wrap(np.fft.ihfft, kind="ihfft")
|
|
|
|
|
|
def _fftfreq_block(i, n, d):
    """Turn one block of sample indices ``i`` into FFT sample frequencies.

    Indices at or beyond ``(n + 1) // 2`` are wrapped to negative values by
    subtracting ``n``; everything is then divided by ``n * d``. The input
    block is copied, never modified.
    """
    first_negative = (n + 1) // 2
    freqs = i.copy()
    freqs[freqs >= first_negative] -= n
    freqs /= d * n
    return freqs
|
|
|
|
|
|
@wraps(np.fft.fftfreq)
def fftfreq(n, d=1.0, chunks='auto'):
    # The public docstring is copied from numpy.fft.fftfreq by ``wraps``.
    n, d = int(n), float(d)

    # Lazily build the positions 0..n-1, then map each block to frequencies.
    positions = _arange(n, dtype=float, chunks=chunks)
    return positions.map_blocks(_fftfreq_block, dtype=float, n=n, d=d)
|
|
|
|
|
|
@wraps(np.fft.rfftfreq)
def rfftfreq(n, d=1.0, chunks='auto'):
    # The public docstring is copied from numpy.fft.rfftfreq by ``wraps``.
    n, d = int(n), float(d)

    # Positions 0..n//2 scaled by 1 / (n * d).
    return _arange(n // 2 + 1, dtype=float, chunks=chunks) / (n * d)
|
|
|
|
|
|
def _fftshift_helper(x, axes=None, inverse=False):
    """Swap the two halves of each requested axis of ``x``.

    Shared implementation of ``fftshift`` (``inverse=False``) and
    ``ifftshift`` (``inverse=True``). ``axes`` may be None (all axes), a
    single axis, or a sequence of axes. An axis that was a single chunk in
    ``x`` is rechunked back to that layout after concatenation.
    """
    if axes is None:
        axes = list(range(x.ndim))
    elif not isinstance(axes, Sequence):
        axes = (axes,)

    out = x
    for ax in axes:
        length = out.shape[ax]
        # Forward shift splits at ceil(length / 2), inverse at floor.
        split = (length + (1 if inverse is False else 0)) // 2

        # List assignment (rather than comparing indices) keeps negative
        # axis numbers working.
        front = [slice(None)] * out.ndim
        back = [slice(None)] * out.ndim
        front[ax] = slice(None, split)
        back[ax] = slice(split, None)

        out = _concatenate([out[tuple(back)], out[tuple(front)]], axis=ax)

        # Concatenation splits the axis into two chunks; restore the
        # original single chunk when that was the input layout.
        if len(x.chunks[ax]) == 1:
            out = out.rechunk({ax: x.chunks[ax]})

    return out
|
|
|
|
|
|
@wraps(np.fft.fftshift)
def fftshift(x, axes=None):
    # Forward shift: delegate with ``inverse`` disabled.
    return _fftshift_helper(x, axes, False)
|
|
|
|
|
|
@wraps(np.fft.ifftshift)
def ifftshift(x, axes=None):
    # Inverse shift: delegate with ``inverse`` enabled.
    return _fftshift_helper(x, axes, True)
|