blob: 3043c44beb003969e4a9c41f5a343fbdb870538b [file] [log] [blame]
# This program is public domain
# Authors: Paul Kienzle, Nadav Horesh
"""
Chirp z-transform.

Transform factories (build once, apply to many signals of the same length):

CZT(n, m, w, a) -> callable (x, axis=-1) -> array
    define a chirp-z transform that can be applied to different signals
ZoomFFT(n, f1, f2, m, Fs) -> callable (x, axis=-1) -> array
    define a Fourier transform on a range of frequencies
ScaledFFT(n, m, scale) -> callable (x, axis=-1) -> array
    define a limited frequency FFT

One-shot functions:

czt: array
    compute the chirp-z transform for a signal
zoomfft: array
    compute the Fourier transform on a range of frequencies
scaledfft: array
    compute a limited frequency FFT for a signal
"""
# The reusable transform factories are part of the documented API (the
# one-shot functions point users at them), so export them as well.
__all__ = ['czt', 'zoomfft', 'scaledfft', 'CZT', 'ZoomFFT', 'ScaledFFT']

22import math, cmath
23
24import numpy as np
25from numpy import pi, arange
26from scipy.fftpack import fft, ifft, fftshift
27
class CZT:
    """
    Chirp-Z Transform.

    Transform to compute the frequency response around a spiral.
    Objects of this class are callables which can compute the
    chirp-z transform on their inputs.  This object precalculates
    constants for the given transform.

    If w does not lie on the unit circle, then the transform will be
    around a spiral with exponentially changing radius.  Regardless,
    the angle will increase linearly.

    The chirp-z transform can be faster than an equivalent fft with
    zero padding.  Try it with your own array sizes to see.  It is
    theoretically faster for large prime Fourier transforms, but not
    in practice.

    The chirp-z transform is considerably less precise than the
    equivalent zero-padded FFT, with differences on the order of 1e-11
    from the direct transform rather than on the order of 1e-15 as
    seen with zero-padding.

    See zoomfft for a friendlier interface to partial fft calculations.
    """
    def __init__(self, n, m=None, w=1, a=1):
        """
        Chirp-Z transform definition.

        Parameters
        ----------
        n : int
            The size of the signal.
        m : int, optional
            The number of points desired.  The default is n.
        w : complex or real, optional
            If w is complex, it is the ratio between successive points.
            If w is real (Python or numpy int/float), it serves as a
            frequency scaling factor: for instance, w=0.5 makes the
            result span half of the frequency range that fft would
            cover, at half of the frequency step size.  None behaves
            like w=1.  The default, 1, samples m equally spaced points
            around the unit circle.
        a : complex, optional
            The starting point in the complex plane.  The default is 1.

        Notes
        -----
        ``self.w`` stores the square root of the step ratio, because the
        algorithm works with ``w**(k**2)`` == ratio**(k**2/2).
        """
        if m is None:
            m = n
        if w is None:
            w = cmath.exp(-1j*pi/m)
        elif isinstance(w, (int, float, np.integer, np.floating)):
            # Real w is a frequency scale.  isinstance (not an exact type
            # check) so numpy real scalars such as np.float64 take this path
            # instead of being treated as a complex ratio.
            w = cmath.exp(-1j*pi/m * w)
        else:
            w = cmath.sqrt(w)
        # Promote a to complex: with the default integer a=1, ``a**-k``
        # would raise integers to negative integer powers, which numpy
        # rejects with ValueError.
        a = complex(a)
        self.w, self.a = w, a
        self.m, self.n = m, n

        k = arange(max(m, n))
        wk2 = w**(k**2)
        # Smallest power of two >= n+m-1 (the length of the linear
        # convolution), computed exactly with integer arithmetic so that the
        # circular FFT convolution does not wrap around.
        nfft = 1 << int(n + m - 2).bit_length()
        self._Awk2 = (a**-k * wk2)[:n]
        self._nfft = nfft
        self._Fwk2 = fft(1/np.hstack((wk2[n-1:0:-1], wk2[:m])), nfft)
        self._wk2 = wk2[:m]
        self._yidx = slice(n-1, n+m-1)

    def __call__(self, x, axis=-1):
        """
        Apply the transform to x.

        Parameters
        ----------
        x : array
            The signal to transform.  Its length along axis must be n.
        axis : int
            Array dimension to operate over.  The default is the final
            dimension.

        Returns
        -------
        array
            An array of the same dimensions as x, but with the length of
            the transformed axis set to m.  When axis is not the last
            dimension the result is a transposed (possibly non-contiguous)
            view of a freshly computed array.

        Raises
        ------
        ValueError
            If x does not have length n along axis.
        """
        x = np.asarray(x)
        if x.shape[axis] != self.n:
            raise ValueError("CZT defined for length %d, not %d" %
                             (self.n, x.shape[axis]))
        # Move the transform axis last so the 1-D precomputed kernels
        # broadcast against it; swapping the same pair back restores order.
        x = np.swapaxes(x, axis, -1)
        y = ifft(self._Fwk2 * fft(x*self._Awk2, self._nfft))
        y = y[..., self._yidx] * self._wk2
        return np.swapaxes(y, axis, -1)


def nextpow2(n):
    """
    Return the exponent of the smallest power of two >= n.

    That is, the smallest integer p with 2**p >= n, so the matching FFT
    length is ``2**nextpow2(n)``.  The result is exact, unlike a
    floating-point ``ceil(log(n)/log(2))``, which overshoots for some
    exact powers of two (e.g. 2**29).

    Raises
    ------
    ValueError
        If n is not positive.
    """
    if n <= 0:
        raise ValueError("nextpow2 requires a positive argument, got %r" % (n,))
    # frexp gives n ~= mantissa * 2**exponent with 0.5 <= mantissa < 1, so the
    # answer is exponent-1 or exponent; the exact comparison below picks the
    # right one even if converting n to float rounded it.
    exponent = math.frexp(n)[1]
    return exponent - 1 if 2**(exponent - 1) >= n else exponent


def ZoomFFT(n, f1, f2=None, m=None, Fs=2):
    """
    Zoom FFT transform definition.

    Computes the Fourier transform for a set of equally spaced
    frequencies.

    Parameters
    ----------
    n : int
        Size of the signal.
    f1, f2 : float
        Start and end frequencies, both included in the output.  If f2
        is not specified, the range 0 to f1 is used.
    m : int, optional
        Size of the output.  The default is n.
    Fs : float, optional
        Sampling frequency (default 2).

    Returns
    -------
    CZT
        A callable object f(x, axis=-1) computing the Fourier transform
        of x at the m frequencies linspace(f1, f2, m).

    Notes
    -----
    Sampling frequency is 1/dt, the time step between samples in the
    signal x.  The unit circle corresponds to frequencies from 0 up
    to the sampling frequency.  The default sampling frequency of 2
    means that f1, f2 values up to the Nyquist frequency are in the
    range [0, 1).  For f1, f2 values expressed in radians per sample,
    a sampling frequency of 2*pi should be used.

    To graph the magnitude of the resulting transform, use::

        plot(linspace(f1, f2, m), abs(zoomfft(x, f1, f2, m)))

    Use the zoomfft wrapper if you only need to compute one transform.
    """
    if m is None:
        m = n
    if f2 is None:
        f1, f2 = 0., f1
    if m == 1:
        # A single output point has no step between frequencies; any ratio
        # gives the same value at f1, and this avoids dividing by m-1 == 0.
        w = complex(1)
    else:
        # Ratio between successive points: a step of (f2-f1)/(m-1).
        w = cmath.exp(-2j * pi * (f2-f1) / ((m-1)*Fs))
    a = cmath.exp(2j * pi * f1/Fs)
    return CZT(n, m=m, w=w, a=a)


def ScaledFFT(n, m=None, scale=1.0):
    """
    Scaled fft transform definition.

    Similar to fft, but the frequency range is scaled by a factor 'scale'
    and sampled at m points spaced delta = scale*Fs/m apart, for sampling
    frequency Fs.  Like the FFT, the output starts at frequency 0 and
    increases by delta up to about scale*Fs/2, followed by the negative
    frequencies from about -scale*Fs/2 up to -delta.  For even m these
    are 0 ... scale*Fs/2-delta followed by -scale*Fs/2 ... -delta.  The
    intended use is in a convolution of two signals, each with its own
    sampling step.

    For even m this is equivalent to::

        fftshift(zoomfft(x, -scale, scale*(m-2.)/m, m=m))

    For example (frequencies in zoomfft's default units, Fs=2)::

        m, n = 10, len(x)
        sf = ScaledFFT(n, m=m, scale=0.25)
        X = fftshift(fft(x))
        W = linspace(-1, (n-2.)/n, n)
        SX = fftshift(sf(x))
        SW = linspace(-0.25, 0.25*(m-2.)/m, m)
        plot(W, abs(X), SW, abs(SX))

    Parameters
    ----------
    n : int
        Size of the signal.
    m : int, optional
        The size of the output.  Default: m=n.
    scale : float, optional
        Frequency scaling factor.  Default: scale=1.0.

    Returns
    -------
    function
        A callable f(x, axis=-1) for computing the scaled FFT on x.
    """
    if m is None:
        m = n
    w = np.exp(-2j * pi / m * scale)
    # Start m//2 steps below zero, so CZT output k is frequency
    # (k - m//2)*delta, ordered from most negative to most positive.
    a = w**(m//2)
    transform = CZT(n=n, m=m, a=a, w=w)
    # ifftshift rotates the zero-frequency sample at index m//2 to index 0
    # for both even and odd m (fftshift only does so for even m).
    return lambda x, axis=-1: np.fft.ifftshift(transform(x, axis), axes=(axis,))


def scaledfft(x, m=None, scale=1.0, axis=-1):
    """
    Compute an FFT with a scaled frequency range.

    See the ScaledFFT documentation for details.

    Parameters
    ----------
    x : array_like
        Input array.
    m : int, optional
        The length of the output signal.  Default: length of x along axis.
    scale : float, optional
        A frequency scaling factor.
    axis : int, optional
        The array dimension to operate over.  The default is the
        final dimension.

    Returns
    -------
    array
        An array of the same rank as 'x', but with the size of
        the 'axis' dimension set to 'm'.
    """
    # Accept any array_like (e.g. lists), matching czt and zoomfft.
    x = np.asarray(x)
    return ScaledFFT(x.shape[axis], m, scale)(x, axis)


def czt(x, m=None, w=1.0, a=1, axis=-1):
    """
    Compute the frequency response around a spiral.

    Parameters
    ----------
    x : array_like
        The set of data to transform.
    m : int, optional
        The number of points desired.  The default is the length of the
        input data along axis.
    w : complex or float, optional
        If w is complex, it is the ratio between points in each step.
        If w is float (the default, 1.0), it is the frequency step scale
        relative to the normal DFT frequency step.
    a : complex, optional
        The starting point in the complex plane.  The default is 1.
    axis : int, optional
        Array dimension to operate over.  The default is the final
        dimension.

    Returns
    -------
    array
        An array of the same dimensions as x, but with the length of the
        transformed axis set to m.

    See zoomfft for a friendlier interface to partial fft calculations.

    If the transform needs to be repeated, use CZT to construct a
    specialized transform function which can be reused without
    recomputing constants.
    """
    signal = np.asarray(x)
    length = signal.shape[axis]
    return CZT(length, m=m, w=w, a=a)(signal, axis=axis)


def zoomfft(x, f1, f2=None, m=None, Fs=2, axis=-1):
    """
    Compute the Fourier transform of x for frequencies in [f1, f2].

    Parameters
    ----------
    x : array_like
        The input signal.
    f1, f2 : float
        The frequency range.  If f2 is not specified, the range 0-f1 is
        assumed.
    m : int, optional
        The number of points to evaluate.  The default is the length of x
        along axis.
    Fs : float, optional
        The sampling frequency.  With a sampling frequency of 10 kHz, for
        example, f1 and f2 can be expressed in kHz.  The default sampling
        frequency is 2, so f1 and f2 should be in the range [0, 1] to keep
        the transform below the Nyquist frequency.
    axis : int, optional
        The array dimension the transform operates over.  The default is
        the final dimension.

    Returns
    -------
    array
        The transformed signal.  The Fourier transform is calculated at
        the points f1, f1+df, f1+2df, ..., f2, where df=(f2-f1)/(m-1).

    Notes
    -----
    zoomfft(x, 0, 2-2./len(x)) is equivalent to fft(x).

    To graph the magnitude of the resulting transform, use::

        plot(linspace(f1, f2, m), abs(zoomfft(x, f1, f2, m)))

    If the transform needs to be repeated, use ZoomFFT to construct a
    specialized transform function which can be reused without
    recomputing constants.
    """
    signal = np.asarray(x)
    length = signal.shape[axis]
    zoom = ZoomFFT(length, f1, f2=f2, m=m, Fs=Fs)
    return zoom(signal, axis=axis)


def _test1(x, show=False, plots=(1, 2, 3, 4)):
    """
    Check zoomfft against fft and zero-padded fft on the signal x.

    Asserts that zoomfft matches the plain fft, a 10x zero-padded fft,
    and a sub-range of the padded fft to a relative error below 1e-10.
    When show is true, the figures whose numbers (0-3) appear in plots
    are drawn with pylab before the checks run.
    """
    norm = np.linalg.norm

    # Normal fft and zero-padded fft equivalent to 10x oversampling
    over = 10
    w = np.linspace(0, 2-2./len(x), len(x))
    y = fft(x)
    wover = np.linspace(0, 2-2./(over*len(x)), over*len(x))
    yover = fft(x, over*len(x))

    # Check that zoomfft is the equivalent of fft
    y1 = zoomfft(x, 0, 2-2./len(y))

    # Check that zoomfft with oversampling is equivalent to zero padding
    y2 = zoomfft(x, 0, 2-2./len(yover), m=len(yover))

    # Check that zoomfft works on a subrange
    f1, f2 = w[3], w[6]
    y3 = zoomfft(x, f1, f2, m=3*over+1)
    w3 = np.linspace(f1, f2, len(y3))
    idx3 = slice(3*over, 6*over+1)

    # Immutable default plus truthiness test: works for lists and tuples.
    if not show:
        plots = ()
    if plots:
        import pylab
    if 0 in plots:
        pylab.figure(0)
        pylab.plot(x)
        pylab.ylabel('Intensity')
    if 1 in plots:
        pylab.figure(1)
        pylab.subplot(311)
        pylab.plot(w, abs(y), 'o', w, abs(y1))
        pylab.legend(['fft', 'zoom'])
        pylab.ylabel('Magnitude')
        pylab.title('FFT equivalent')
        pylab.subplot(312)
        pylab.plot(w, np.angle(y), 'o', w, np.angle(y1))
        pylab.legend(['fft', 'zoom'])
        pylab.ylabel('Phase (radians)')
        pylab.subplot(313)
        pylab.plot(w, abs(y)-abs(y1))
        pylab.ylabel('Residuals')
    if 2 in plots:
        pylab.figure(2)
        pylab.subplot(211)
        pylab.plot(w, abs(y), 'o', wover, abs(y2), wover, abs(yover))
        pylab.ylabel('Magnitude')
        pylab.title('Oversampled FFT')
        pylab.legend(['fft', 'zoom', 'pad'])
        pylab.subplot(212)
        pylab.plot(wover, abs(yover)-abs(y2),
                   w, abs(y)-abs(y2[0::over]), 'o',
                   w, abs(y)-abs(yover[0::over]), 'x')
        pylab.legend(['pad-zoom', 'fft-zoom', 'fft-pad'])
        pylab.ylabel('Residuals')
    if 3 in plots:
        pylab.figure(3)
        ax1 = pylab.subplot(211)
        pylab.plot(w, abs(y), 'o', w3, abs(y3), wover, abs(yover),
                   w[3:7], abs(y3[::over]), 'x')
        pylab.title('Zoomed FFT')
        pylab.ylabel('Magnitude')
        pylab.legend(['fft', 'zoom', 'pad'])
        pylab.plot(w3, abs(y3), 'x')
        ax1.set_xlim(f1, f2)
        ax2 = pylab.subplot(212)
        pylab.plot(wover[idx3], abs(yover[idx3])-abs(y3),
                   w[3:7], abs(y[3:7])-abs(y3[::over]), 'o',
                   w[3:7], abs(y[3:7])-abs(yover[3*over:6*over+1:over]), 'x')
        pylab.legend(['pad-zoom', 'fft-zoom', 'fft-pad'])
        ax2.set_xlim(f1, f2)
        pylab.ylabel('Residuals')
    if plots:
        pylab.show()

    err = norm(y-y1)/norm(y)
    assert err < 1e-10, "error for direct transform is %g"%(err,)
    err = norm(yover-y2)/norm(yover)
    assert err < 1e-10, "error for oversampling is %g"%(err,)
    err = norm(yover[idx3]-y3)/norm(yover[idx3])
    assert err < 1e-10, "error for subrange is %g"%(err,)

def _testscaled(x):
    """
    Check scaledfft against fft on the signal x.

    Checks the unscaled case against fft(x), and the half-range
    (scale=0.5) case against the middle half of the shifted fft.
    NOTE(review): the half-range check lines up bin-for-bin only when
    len(x) is divisible by 4.
    """
    n = len(x)
    norm = np.linalg.norm
    assert norm(fft(x)-scaledfft(x)) < 1e-10
    # Integer division: slice bounds and the output length must be ints.
    middle = fftshift(fft(x))[n//4:3*n//4]
    assert norm(middle - fftshift(scaledfft(x, scale=0.5, m=n//2))) < 1e-10

def test(demo=None, plots=(1, 2, 3)):
    """
    Run the self-checks, optionally plotting one demo.

    demo selects which case (0-6) to plot; None runs the checks only.
    plots lists the figure numbers handed to _test1 for that demo.
    """
    # 0: Gauss
    t = np.linspace(-2, 2, 128)
    x = np.exp(-t**2/0.01)
    _test1(x, show=(demo == 0), plots=plots)

    # 1: Linear
    x = [1, 2, 3, 4, 5, 6, 7]
    _test1(x, show=(demo == 1), plots=plots)

    # Check near powers of two (lengths 126-31 through 130-31)
    for length in range(126-31, 131-31):
        _test1(range(length), show=False)

    # Check transform on n-D array input
    x = np.reshape(np.arange(3*2*28), (3, 2, 28))
    y1 = zoomfft(x, 0, 2-2./28)
    y2 = zoomfft(x[2, 0, :], 0, 2-2./28)
    err = np.linalg.norm(y2-y1[2, 0])
    assert err < 1e-15, "error for n-D array is %g"%(err,)

    # 2: Random (not a test condition)
    if demo == 2:
        x = np.random.rand(101)
        _test1(x, show=True, plots=plots)

    # 3: Spikes
    t = np.linspace(0, 1, 128)
    x = np.sin(2*pi*t*5)+np.sin(2*pi*t*13)
    _test1(x, show=(demo == 3), plots=plots)

    # 4: Sines
    x = np.zeros(100)
    x[[1, 5, 21]] = 1
    _test1(x, show=(demo == 4), plots=plots)

    # 5: Sines plus complex component.  x is float64 here, so build a new
    # complex array; an in-place += would fail numpy's same_kind casting.
    x = x + 1j*np.linspace(0, 0.5, x.shape[0])
    _test1(x, show=(demo == 5), plots=plots)

    # 6: Scaled FFT on complex sines
    x = x + 1j*np.linspace(0, 0.5, x.shape[0])
    if demo == 6:
        demo_scaledfft(x, 0.25, 200)
    _testscaled(x)


def demo_scaledfft(v, scale, m):
    """
    Plot |fft|, |scaledfft| and |zoomfft| of v on a log scale.

    The full-band curves use the axis [-0.5, 0.5); the scaled and zoomed
    curves use m points on [-scale/2, scale/2).
    """
    import pylab
    n = len(v)
    full_axis = np.linspace(-0.5, 0.5 - 1./n, n)
    zoom_axis = np.linspace(-scale*0.5, scale*0.5*(m-2.)/m, m)
    scaled_label = 'x' + str(scale) + ' scaled fft'

    pylab.figure()
    pylab.plot(full_axis, fftshift(abs(fft(v))), label='fft')
    pylab.plot(full_axis, fftshift(abs(scaledfft(v))), 'ro',
               label='x1 scaled fft')
    zoomed = zoomfft(v, -scale, scale*(m-2.)/m, m=m)
    pylab.plot(zoom_axis, abs(zoomed), 'bo', label='zoomfft')
    scaled = scaledfft(v, m=m, scale=scale)
    pylab.plot(zoom_axis, fftshift(abs(scaled)), 'gx', label=scaled_label)
    pylab.gca().set_yscale('log')
    pylab.legend()
    pylab.show()

if __name__ == "__main__":
    # Choose a demo in [0, 6] to show its plots, or None for testing only.
    test(demo=None)