Python3-scipy
References:
1. http://www.scipy-lectures.org/index.html
2. https://github.com/scipy-lectures/scipy-lecture-notes/tree/master/intro/scipy/examples
3. https://github.com/jayleicn/scipy-lecture-notes-zh-CN
4. http://scipy-cookbook.readthedocs.io/
5. https://docs.scipy.org/doc/scipy/reference/
1.5. Scipy : high-level scientific computing
Chapters contents
- File input/output: scipy.io
- Special functions: scipy.special
- Linear algebra operations: scipy.linalg
- Interpolation: scipy.interpolate
- Optimization and fit: scipy.optimize
- Statistics and random numbers: scipy.stats
- Numerical integration: scipy.integrate
- Fast Fourier transforms: scipy.fftpack
- Signal processing: scipy.signal
- Image manipulation: scipy.ndimage
- Summary exercises on scientific computing
- Full code examples for the scipy chapter
scipy is composed of task-specific sub-modules:
| scipy.cluster | Vector quantization / Kmeans |
| scipy.constants | Physical and mathematical constants |
| scipy.fftpack | Fourier transform |
| scipy.integrate | Integration routines |
| scipy.interpolate | Interpolation |
| scipy.io | Data input and output |
| scipy.linalg | Linear algebra routines |
| scipy.ndimage | n-dimensional image package |
| scipy.odr | Orthogonal distance regression |
| scipy.optimize | Optimization |
| scipy.signal | Signal processing |
| scipy.sparse | Sparse matrices |
| scipy.spatial | Spatial data structures and algorithms |
| scipy.special | Any special mathematical functions |
| scipy.stats | Statistics |
1.5.1 File input/output: scipy.io
Matlab files: Loading and saving:
import numpy as np
from scipy import io as spio
a = np.ones((3, 3))
spio.savemat('file.mat', {'a': a}) # savemat expects a dictionary
data = spio.loadmat('file.mat')
data['a']
Image files: Reading images:
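>>> # scipy.misc.imread has been removed from SciPy; the imageio package offers an equivalent reader
>>> import imageio.v3 as iio
>>> iio.imread('fname.png')
array(...)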
>>> # Matplotlib also has a similar function
>>> import matplotlib.pyplot as plt
>>> plt.imread('fname.png')
array(...)
See also
- Load text files: numpy.loadtxt() / numpy.savetxt()
- Clever loading of text/csv files: numpy.genfromtxt() / numpy.recfromcsv()
- Fast and efficient, but numpy-specific, binary format: numpy.save() / numpy.load()
- More advanced input/output of images in scikit-image: skimage.io
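Returning to the Matlab round trip at the start of this section, the sketch below checks that an array written with savemat comes back unchanged from loadmat. The temporary directory and the file name roundtrip.mat are assumptions made for illustration; they are not part of the original notes.

```python
import os
import tempfile

import numpy as np
from scipy import io as spio

a = np.ones((3, 3))

# Write into a throwaway directory (hypothetical file name) so nothing is left behind
with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, "roundtrip.mat")
    spio.savemat(path, {"a": a})   # savemat expects a dictionary
    data = spio.loadmat(path)      # loadmat also returns a dictionary

loaded = data["a"]
print(loaded.shape, np.array_equal(loaded, a))
```

The dictionary returned by loadmat also carries a few metadata keys, so the array is looked up by the name it was saved under.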
1.5.3 Linear algebra operations: scipy.linalg
- The scipy.linalg.det() function computes the determinant of a square matrix:
>>> from scipy import linalg
>>> arr = np.array([[1, 2],
...                 [3, 4]])
>>> linalg.det(arr)
-2.0
>>> arr = np.array([[3, 2],
...                 [6, 4]])
>>> linalg.det(arr)
0.0
>>> linalg.det(np.ones((3, 4)))
Traceback (most recent call last):
...
ValueError: expected square matrix
- The scipy.linalg.inv() function computes the inverse of a square matrix:
>>> arr = np.array([[1, 2],
...                 [3, 4]])
>>> iarr = linalg.inv(arr)
>>> iarr
array([[-2. ,  1. ],
       [ 1.5, -0.5]])
>>> np.allclose(np.dot(arr, iarr), np.eye(2))
True
Finally computing the inverse of a singular matrix (its determinant is zero) will raise LinAlgError:
>>> arr = np.array([[3, 2],
...                 [6, 4]])
>>> linalg.inv(arr)
Traceback (most recent call last):
...
...LinAlgError: singular matrix
- More advanced operations are available, for example singular-value decomposition (SVD):
>>> arr = np.arange(9).reshape((3, 3)) + np.diag([1, 0, 1])
>>> uarr, spec, vharr = linalg.svd(arr)
The resulting array spectrum is:
>>> spec
array([ 14.88982544,   0.45294236,   0.29654967])
The original matrix can be re-composed by matrix multiplication of the outputs of svd with np.dot:
>>> sarr = np.diag(spec)
>>> svd_mat = uarr.dot(sarr).dot(vharr)
>>> np.allclose(svd_mat, arr)
True
SVD is commonly used in statistics and signal processing. Many other standard decompositions (QR, LU, Cholesky, Schur), as well as solvers for linear systems, are available in scipy.linalg.
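The solvers for linear systems mentioned above can be illustrated with scipy.linalg.solve(). This is a minimal sketch; the right-hand side b is an arbitrary value chosen for the example.

```python
import numpy as np
from scipy import linalg

arr = np.array([[1, 2],
                [3, 4]])
b = np.array([5, 6])   # arbitrary right-hand side, chosen for illustration

# Solve arr @ x = b directly, without forming the inverse
x = linalg.solve(arr, b)

# Same result through the explicit inverse, shown only for comparison
x_via_inv = linalg.inv(arr) @ b

print(x)
print(np.allclose(arr @ x, b))
```

Calling a solver is generally preferred over multiplying by the inverse when only the solution of one system is needed.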
The remaining sections are skipped.
1.5.12 Full code examples for the scipy chapter
1.5.12.1 Finding the minimum of a smooth function
Demos various methods to find the minimum of a function.
import numpy as np
import matplotlib.pyplot as plt
def f(x):
    return x**2 + 10*np.sin(x)
x = np.arange(-10, 10, 0.1)
plt.plot(x, f(x))
Now find the minimum with a few methods
from scipy import optimize
# The default (BFGS, since no bounds or constraints are given)
print(optimize.minimize(f, x0=0))
Out:
fun: -7.945823375615215
hess_inv: array([[ 0.08589237]])
jac: array([ -1.19209290e-06])
message: 'Optimization terminated successfully.'
nfev: 18
nit: 5
njev: 6
status: 0
success: True
x: array([-1.30644012])
print(optimize.minimize(f, x0=0, method="L-BFGS-B"))
Out:
fun: array([-7.94582338])
hess_inv: <1x1 LbfgsInvHessProduct with dtype=float64>
jac: array([ -1.42108547e-06])
message: b'CONVERGENCE: NORM_OF_PROJECTED_GRADIENT_<=_PGTOL'
nfev: 12
nit: 5
status: 0
success: True
x: array([-1.30644013])
plt.show()
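Both runs above start from x0=0. Since these are local methods, the result can depend on the starting point. The sketch below repeats the minimization from a second starting point, x0=3, which is an assumption picked for illustration.

```python
import numpy as np
from scipy import optimize

def f(x):
    return x**2 + 10*np.sin(x)

# Same function, two different starting points
res_left = optimize.minimize(f, x0=0)
res_right = optimize.minimize(f, x0=3)

print(res_left.x, res_left.fun)
print(res_right.x, res_right.fun)
```

The second run is expected to settle in the shallower minimum to the right of the origin, which is why a global strategy or several starting points is worth considering for functions like this one.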
1.5.12.2 Detrending a signal
scipy.signal.detrend() removes a linear trend.
Generate a random signal with a trend
import numpy as np
t = np.linspace(0, 5, 100)
x = t + np.random.normal(size=100)
Detrend
from scipy import signal
x_detrended = signal.detrend(x)
Plot
from matplotlib import pyplot as plt
plt.figure(figsize=(5, 4))
plt.plot(t, x, label="x")
plt.plot(t, x_detrended, label="x_detrended")
plt.legend(loc='best')
plt.show()
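One way to check the detrending numerically is to fit a straight line before and after. The sketch below uses np.polyfit for that; the seeded generator is an assumption added for reproducibility.

```python
import numpy as np
from scipy import signal

rng = np.random.default_rng(0)   # seeded generator, an assumption for reproducibility
t = np.linspace(0, 5, 100)
x = t + rng.normal(size=100)

x_detrended = signal.detrend(x)

# Slope of the least-squares line through each signal
slope_before = np.polyfit(t, x, 1)[0]
slope_after = np.polyfit(t, x_detrended, 1)[0]
print(slope_before, slope_after)
```

After detrending, the fitted slope and the mean should both be zero up to rounding error, because the least-squares line has been subtracted.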
1.5.12.3 Resample a signal with scipy.signal.resample
scipy.signal.resample() uses FFT to resample a 1D signal.
Generate a signal with 100 data points
import numpy as np
t = np.linspace(0, 5, 100)
x = np.sin(t)
Downsample it by a factor of 4
from scipy import signal
x_resampled = signal.resample(x, 25)
Plot
from matplotlib import pyplot as plt
plt.figure(figsize=(5, 4))
plt.plot(t, x, label='Original signal')
plt.plot(t[::4], x_resampled, 'ko', label='Resampled signal')
plt.legend(loc='best')
plt.show()
1.5.12.4 Integrating a simple ODE
Solve the ODE dy/dt = -2y between t = 0..4, with the initial condition y(t=0) = 1.
import numpy as np
from scipy.integrate import odeint
from matplotlib import pyplot as plt
def calc_derivative(ypos, time):
    return -2*ypos
time_vec = np.linspace(0, 4, 40)
yvec = odeint(calc_derivative, 1, time_vec)
plt.figure(figsize=(4, 3))
plt.plot(time_vec, yvec)
plt.xlabel('t: Time')
plt.ylabel('y: Position')
plt.tight_layout()
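This ODE has the closed-form solution y(t) = exp(-2t), so the numerical result can be compared against it. The sketch below does that; the variable names exact and max_error are introduced here for illustration.

```python
import numpy as np
from scipy.integrate import odeint

def calc_derivative(ypos, time):
    return -2 * ypos

time_vec = np.linspace(0, 4, 40)
yvec = odeint(calc_derivative, 1, time_vec)

# Analytical solution of dy/dt = -2y with y(0) = 1
exact = np.exp(-2 * time_vec)
max_error = np.max(np.abs(yvec[:, 0] - exact))
print(yvec.shape, max_error)
```

Note that odeint returns a 2D array with one column per state variable, hence the yvec[:, 0] indexing.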
1.5.12.5 Comparing 2 sets of samples from Gaussians
import numpy as np
from matplotlib import pyplot as plt
# Generates 2 sets of observations
samples1 = np.random.normal(0, size=1000)
samples2 = np.random.normal(1, size=1000)
# Compute a histogram of the sample
bins = np.linspace(-4, 4, 30)
# density=True replaces the older normed=True keyword, which recent NumPy and Matplotlib no longer accept
histogram1, bins = np.histogram(samples1, bins=bins, density=True)
histogram2, bins = np.histogram(samples2, bins=bins, density=True)
plt.figure(figsize=(6, 4))
plt.hist(samples1, bins=bins, density=True, label="Samples 1")
plt.hist(samples2, bins=bins, density=True, label="Samples 2")
plt.legend(loc='best')
plt.show()
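The histograms give a visual comparison only. As a possible numerical complement, a two-sample t-test from scipy.stats can test whether the two means differ. This goes beyond what the example above does and is offered as a sketch; the seeded generator is an assumption.

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)   # seeded generator, an assumption for reproducibility
samples1 = rng.normal(0, size=1000)
samples2 = rng.normal(1, size=1000)

# Two-sample t-test: the null hypothesis is that both means are equal
result = stats.ttest_ind(samples1, samples2)
print(result.statistic, result.pvalue)
```

A very small p-value indicates that the difference between the two sample means is unlikely to be due to chance alone.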
1.5.12.6 Integrate the Damped spring-mass oscillator
import numpy as np
from scipy.integrate import odeint
from matplotlib import pyplot as plt
mass = 0.5 # kg
kspring = 4 # N/m
cviscous = 0.4 # N s/m
eps = cviscous / (2 * mass * np.sqrt(kspring/mass))
omega = np.sqrt(kspring / mass)
def calc_deri(yvec, time, eps, omega):
    # y'' = -2*eps*omega*y' - omega**2*y, with damping ratio eps = c / (2*m*omega)
    return (yvec[1], -2 * eps * omega * yvec[1] - omega **2 * yvec[0])
time_vec = np.linspace(0, 10, 100)
yinit = (1, 0)
yarr = odeint(calc_deri, yinit, time_vec, args=(eps, omega))
plt.figure(figsize=(4, 3))
plt.plot(time_vec, yarr[:, 0], label='y')
plt.plot(time_vec, yarr[:, 1], label="y'")
plt.legend(loc='best')
plt.show()
1.5.12.7 Normal distribution: histogram and PDF
Explore the normal distribution: a histogram built from samples and the PDF (probability density function).
import numpy as np
# Sample from a normal distribution using numpy's random number generator
samples = np.random.normal(size=10000)
# Compute a histogram of the sample
bins = np.linspace(-5, 5, 30)
histogram, bins = np.histogram(samples, bins=bins, density=True)
bin_centers = 0.5*(bins[1:] + bins[:-1])
# Compute the PDF on the bin centers from scipy distribution object
from scipy import stats
pdf = stats.norm.pdf(bin_centers)
from matplotlib import pyplot as plt
plt.figure(figsize=(6, 4))
plt.plot(bin_centers, histogram, label="Histogram of samples")
plt.plot(bin_centers, pdf, label="PDF")
plt.legend()
plt.show()
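Two quick numerical checks can accompany the plot: a density-normalized histogram has unit area, and stats.norm.fit() should recover parameters close to the ones used for sampling. The sketch below assumes a seeded generator for reproducibility.

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)   # seeded generator, an assumption for reproducibility
samples = rng.normal(size=10000)

bins = np.linspace(-5, 5, 30)
histogram, bins = np.histogram(samples, bins=bins, density=True)

# Area under the histogram: bar height times bar width, summed over bins
area = np.sum(histogram * np.diff(bins))

# Maximum-likelihood estimates of the mean (loc) and standard deviation (scale)
loc, scale = stats.norm.fit(samples)
print(area, loc, scale)
```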
1.5.12.8 Curve fitting
Demos a simple curve fitting
First generate some data
import numpy as np
# Seed the random number generator for reproducibility
np.random.seed(0)
x_data = np.linspace(-5, 5, num=50)
y_data = 2.9 * np.sin(1.5 * x_data) + np.random.normal(size=50)
# And plot it
import matplotlib.pyplot as plt
plt.figure(figsize=(6, 4))
plt.scatter(x_data, y_data)
Now fit a simple sine function to the data
from scipy import optimize
def test_func(x, a, b):
    return a * np.sin(b * x)
params, params_covariance = optimize.curve_fit(test_func, x_data, y_data,
                                               p0=[2, 2])
print(params)
Out:
[ 3.05931973 1.45754553]
And plot the resulting curve on the data
plt.figure(figsize=(6, 4))
plt.scatter(x_data, y_data, label='Data')
plt.plot(x_data, test_func(x_data, params[0], params[1]),
         label='Fitted function')
plt.legend(loc='best')
plt.show()
1.5.12.9 Spectrogram, power spectral density
Demo spectrogram and power spectral density on a frequency chirp.
import numpy as np
from matplotlib import pyplot as plt
Generate a chirp signal
# Seed the random number generator
np.random.seed(0)
time_step = .01
time_vec = np.arange(0, 70, time_step)
# A signal with a small frequency chirp
sig = np.sin(0.5 * np.pi * time_vec * (1 + .1 * time_vec))
plt.figure(figsize=(8, 5))
plt.plot(time_vec, sig)
Compute and plot the spectrogram
The spectrum of the signal on consecutive time windows
from scipy import signal
freqs, times, spectrogram = signal.spectrogram(sig)
plt.figure(figsize=(5, 4))
plt.imshow(spectrogram, aspect='auto', cmap='hot_r', origin='lower')
plt.title('Spectrogram')
plt.ylabel('Frequency band')
plt.xlabel('Time window')
plt.tight_layout()
Compute and plot the power spectral density (PSD)
The power of the signal per frequency band
freqs, psd = signal.welch(sig)
plt.figure(figsize=(5, 4))
plt.semilogx(freqs, psd)
plt.title('PSD: power spectral density')
plt.xlabel('Frequency')
plt.ylabel('Power')
plt.tight_layout()
plt.show()
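The calls above leave the sampling frequency at its default, so the frequency axis is in normalized units. Assuming the signal is sampled every time_step seconds, passing fs=1/time_step should give frequencies in Hz. A minimal sketch without plotting:

```python
import numpy as np
from scipy import signal

time_step = .01
time_vec = np.arange(0, 70, time_step)
sig = np.sin(0.5 * np.pi * time_vec * (1 + .1 * time_vec))

# Sampling frequency in Hz, derived from the time step
fs = 1 / time_step

freqs, psd = signal.welch(sig, fs=fs)
freqs_s, times_s, spectrogram = signal.spectrogram(sig, fs=fs)
print(freqs[-1], spectrogram.shape)
```

With fs set, the highest reported frequency is the Nyquist frequency fs/2, and the spectrogram has one row per frequency and one column per time window.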
1.5.12.10 A demo of 1D interpolation
# Generate data
import numpy as np
np.random.seed(0)
measured_time = np.linspace(0, 1, 10)
noise = 1e-1 * (np.random.random(10)*2 - 1)
measures = np.sin(2 * np.pi * measured_time) + noise
# Interpolate it to new time points
from scipy.interpolate import interp1d
linear_interp = interp1d(measured_time, measures)
interpolation_time = np.linspace(0, 1, 50)
linear_results = linear_interp(interpolation_time)
cubic_interp = interp1d(measured_time, measures, kind='cubic')
cubic_results = cubic_interp(interpolation_time)
# Plot the data and the interpolation
from matplotlib import pyplot as plt
plt.figure(figsize=(6, 4))
plt.plot(measured_time, measures, 'o', ms=6, label='measures')
plt.plot(interpolation_time, linear_results, label='linear interp')
plt.plot(interpolation_time, cubic_results, label='cubic interp')
plt.legend()
plt.show()
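For the linear case, np.interp gives the same values as interp1d, and any interpolant should reproduce the measurements at the measured time points. The sketch below checks both properties on the data from this example; the name numpy_results is introduced here.

```python
import numpy as np
from scipy.interpolate import interp1d

np.random.seed(0)
measured_time = np.linspace(0, 1, 10)
noise = 1e-1 * (np.random.random(10)*2 - 1)
measures = np.sin(2 * np.pi * measured_time) + noise
interpolation_time = np.linspace(0, 1, 50)

# Linear interpolation, once with SciPy and once with NumPy
linear_results = interp1d(measured_time, measures)(interpolation_time)
numpy_results = np.interp(interpolation_time, measured_time, measures)

# The cubic interpolant passes through the measured points
cubic_interp = interp1d(measured_time, measures, kind='cubic')
print(np.allclose(linear_results, numpy_results))
print(np.allclose(cubic_interp(measured_time), measures))
```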
1.5.12.11 Demo mathematical morphology
A basic demo of binary opening and closing.
# Generate some binary data
import numpy as np
np.random.seed(0)
a = np.zeros((50, 50))
a[10:-10, 10:-10] = 1
a += 0.25 * np.random.standard_normal(a.shape)
mask = a>=0.5
# Apply mathematical morphology
from scipy import ndimage
opened_mask = ndimage.binary_opening(mask)
closed_mask = ndimage.binary_closing(opened_mask)
# Plot
from matplotlib import pyplot as plt
plt.figure(figsize=(12, 3.5))
plt.subplot(141)
plt.imshow(a, cmap=plt.cm.gray)
plt.axis('off')
plt.title('a')
plt.subplot(142)
plt.imshow(mask, cmap=plt.cm.gray)
plt.axis('off')
plt.title('mask')
plt.subplot(143)
plt.imshow(opened_mask, cmap=plt.cm.gray)
plt.axis('off')
plt.title('opened_mask')
plt.subplot(144)
plt.imshow(closed_mask, cmap=plt.cm.gray)
plt.title('closed_mask')
plt.axis('off')
plt.subplots_adjust(wspace=.05, left=.01, bottom=.01, right=.99, top=.99)
plt.show()
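The effect of opening is easier to see on a tiny mask. In the sketch below, a hand-made 7x7 mask (an assumption for illustration) holds a 3x3 block and one isolated pixel. With the default cross-shaped structuring element, opening removes the isolated pixel and trims the corners of the block.

```python
import numpy as np
from scipy import ndimage

mask = np.zeros((7, 7), dtype=bool)
mask[1:4, 1:4] = True    # a 3x3 block
mask[5, 5] = True        # an isolated pixel

# Opening = erosion followed by dilation, default structuring element
opened = ndimage.binary_opening(mask)
print(opened.astype(int))
```

Erosion keeps only the center of the block, and the following dilation grows it back into a cross of five pixels.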
1.5.12.12 Plot geometrical transformations on images
Demo geometrical transformations of images.
# Load some data
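try:
    from scipy.datasets import face as load_face  # SciPy >= 1.10 (needs the pooch package)
except ImportError:
    from scipy.misc import face as load_face  # older SciPy releases
face = load_face(gray=True)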
# Apply a variety of transformations
from scipy import ndimage
from matplotlib import pyplot as plt
shifted_face = ndimage.shift(face, (50, 50))
shifted_face2 = ndimage.shift(face, (50, 50), mode='nearest')
rotated_face = ndimage.rotate(face, 30)
cropped_face = face[50:-50, 50:-50]
zoomed_face = ndimage.zoom(face, 2)
zoomed_face.shape
plt.figure(figsize=(15, 3))
plt.subplot(151)
plt.imshow(shifted_face, cmap=plt.cm.gray)
plt.axis('off')
plt.subplot(152)
plt.imshow(shifted_face2, cmap=plt.cm.gray)
plt.axis('off')
plt.subplot(153)
plt.imshow(rotated_face, cmap=plt.cm.gray)
plt.axis('off')
plt.subplot(154)
plt.imshow(cropped_face, cmap=plt.cm.gray)
plt.axis('off')
plt.subplot(155)
plt.imshow(zoomed_face, cmap=plt.cm.gray)
plt.axis('off')
plt.subplots_adjust(wspace=.05, left=.01, bottom=.01, right=.99, top=.99)
plt.show()
1.5.12.13 Demo connected components
Extracting and labeling connected components in a 2D array
import numpy as np
from matplotlib import pyplot as plt
Generate some binary data
np.random.seed(0)
x, y = np.indices((100, 100))
sig = np.sin(2*np.pi*x/50.) * np.sin(2*np.pi*y/50.) * (1+x*y/50.**2)**2
mask = sig > 1
plt.figure(figsize=(7, 3.5))
plt.subplot(1, 2, 1)
plt.imshow(sig)
plt.axis('off')
plt.title('sig')
plt.subplot(1, 2, 2)
plt.imshow(mask, cmap=plt.cm.gray)
plt.axis('off')
plt.title('mask')
plt.subplots_adjust(wspace=.05, left=.01, bottom=.01, right=.99, top=.9)
Label connected components
from scipy import ndimage
labels, nb = ndimage.label(mask)
plt.figure(figsize=(3.5, 3.5))
plt.imshow(labels)
plt.title('label')
plt.axis('off')
plt.subplots_adjust(wspace=.05, left=.01, bottom=.01, right=.99, top=.9)
Extract the 4th connected component, and crop the array around it
sl = ndimage.find_objects((labels == 4).astype(int))
plt.figure(figsize=(3.5, 3.5))
plt.imshow(sig[sl[0]])
plt.title('Cropped connected component')
plt.axis('off')
plt.subplots_adjust(wspace=.05, left=.01, bottom=.01, right=.99, top=.9)
plt.show()
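The labeling steps can be followed by hand on a small array. The 3x4 mask below is made up for illustration; it contains two components under the default connectivity, where diagonal neighbors do not count as connected.

```python
import numpy as np
from scipy import ndimage

mask = np.array([[1, 1, 0, 0],
                 [0, 0, 0, 1],
                 [0, 0, 1, 1]])

labels, nb = ndimage.label(mask)

# Number of pixels per component, skipping the background label 0
sizes = np.bincount(labels.ravel())[1:]

# One tuple of slices (bounding box) per labeled component
slices = ndimage.find_objects(labels)
second = mask[slices[1]]
print(nb, sizes, slices)
```

Passing the full label array to find_objects returns the bounding boxes of all components at once, indexed by label minus one.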
1.5.12.14 Minima and roots of a function
Demos finding minima and roots of a function.
import numpy as np
x = np.arange(-10, 10, 0.1)
def f(x):
    return x**2 + 10*np.sin(x)
Find minima
from scipy import optimize
# Global optimization
grid = (-10, 10, 0.1)
xmin_global = optimize.brute(f, (grid, ))
print("Global minima found %s" % xmin_global)
# Constrained optimization
xmin_local = optimize.fminbound(f, 0, 10)
print("Local minimum found %s" % xmin_local)
Out:
Global minima found [-1.30641113]
Local minimum found 3.8374671195
Root finding
root = optimize.root(f, 1) # our initial guess is 1
print("First root found %s" % root.x)
root2 = optimize.root(f, -2.5)
print("Second root found %s" % root2.x)
Out:
First root found [ 0.]
Second root found [-2.47948183]
Plot function, minima, and roots
import matplotlib.pyplot as plt
fig = plt.figure(figsize=(6, 4))
ax = fig.add_subplot(111)
# Plot the function
ax.plot(x, f(x), 'b-', label="f(x)")
# Plot the minima
xmins = np.array([xmin_global[0], xmin_local])
ax.plot(xmins, f(xmins), 'go', label="Minima")
# Plot the roots
roots = np.array([root.x, root2.x])
ax.plot(roots, f(roots), 'kv', label="Roots")
# Decorate the figure
ax.legend(loc='best')
ax.set_xlabel('x')
ax.set_ylabel('f(x)')
ax.axhline(0, color='gray')
plt.show()
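For a scalar function, a bracketing method is an alternative to optimize.root when an interval with a sign change is known. The sketch below uses optimize.brentq; the bracket [-3, -2] is an assumption, chosen because f is positive at -3 and negative at -2.

```python
import numpy as np
from scipy import optimize

def f(x):
    return x**2 + 10*np.sin(x)

# f(-3) > 0 and f(-2) < 0, so a root lies inside the bracket
root_bracketed = optimize.brentq(f, -3, -2)
print(root_bracketed, f(root_bracketed))
```

Unlike an initial guess, a valid bracket guarantees that the method converges to a root inside the interval.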
1.5.12.15 Optimization of a two-parameter function
import numpy as np
# Define the function that we are interested in
def sixhump(x):
    return ((4 - 2.1*x[0]**2 + x[0]**4 / 3.) * x[0]**2 + x[0] * x[1]
            + (-4 + 4*x[1]**2) * x[1] **2)
# Make a grid to evaluate the function (for plotting)
x = np.linspace(-2, 2)
y = np.linspace(-1, 1)
xg, yg = np.meshgrid(x, y)
A 2D image plot of the function
Simple visualization in 2D
import matplotlib.pyplot as plt
plt.figure()
plt.imshow(sixhump([xg, yg]), extent=[-2, 2, -1, 1])
plt.colorbar()
A 3D surface plot of the function
from mpl_toolkits.mplot3d import Axes3D
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
surf = ax.plot_surface(xg, yg, sixhump([xg, yg]), rstride=1, cstride=1,
                       cmap=plt.cm.jet, linewidth=0, antialiased=False)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('f(x, y)')
ax.set_title('Six-hump Camelback function')
Find the minima
from scipy import optimize
x_min = optimize.minimize(sixhump, x0=[0, 0])
plt.figure()
# Show the function in 2D
plt.imshow(sixhump([xg, yg]), extent=[-2, 2, -1, 1])
plt.colorbar()
# And the minimum that we've found:
plt.scatter(x_min.x[0], x_min.x[1])
plt.show()
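The gradient of this function vanishes at the origin, so a gradient-based method started at x0=[0, 0] may stop right there without reaching a minimum. A simple workaround, sketched below, is to run the minimizer from several starting points and keep the best result. The list of starting points is hypothetical.

```python
import numpy as np
from scipy import optimize

def sixhump(x):
    return ((4 - 2.1*x[0]**2 + x[0]**4 / 3.) * x[0]**2 + x[0] * x[1]
            + (-4 + 4*x[1]**2) * x[1]**2)

# Hypothetical starting points, chosen for illustration
starts = [(-1, -0.5), (-1, 0.5), (1, -0.5), (1, 0.5), (0.1, -0.7)]

# Run a local minimization from each start and keep the lowest value found
results = [optimize.minimize(sixhump, x0=s) for s in starts]
best = min(results, key=lambda r: r.fun)
print(best.x, best.fun)
```

The function is symmetric under (x, y) -> (-x, -y), so it has two global minima with the same value; which one is reported depends on the starting points.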
1.5.12.16 Plot filtering on images
Demo filtering for denoising of images.
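# Load some data
# (scipy.misc.face was removed in SciPy 1.12; on recent versions use
# scipy.datasets.face instead)
from scipy import misc
face = misc.face(gray=True)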
face = face[:512, -512:] # crop out square on right
# Apply a variety of filters
from scipy import ndimage
from scipy import signal
from matplotlib import pyplot as plt
import numpy as np
# The np.float alias was removed in NumPy 1.24; the builtin float is equivalent
noisy_face = np.copy(face).astype(float)
noisy_face += face.std() * 0.5 * np.random.standard_normal(face.shape)
blurred_face = ndimage.gaussian_filter(noisy_face, sigma=3)
median_face = ndimage.median_filter(noisy_face, size=5)
wiener_face = signal.wiener(noisy_face, (5, 5))
plt.figure(figsize=(12, 3.5))
plt.subplot(141)
plt.imshow(noisy_face, cmap=plt.cm.gray)
plt.axis('off')
plt.title('noisy')
plt.subplot(142)
plt.imshow(blurred_face, cmap=plt.cm.gray)
plt.axis('off')
plt.title('Gaussian filter')
plt.subplot(143)
plt.imshow(median_face, cmap=plt.cm.gray)
plt.axis('off')
plt.title('median filter')
plt.subplot(144)
plt.imshow(wiener_face, cmap=plt.cm.gray)
plt.title('Wiener filter')
plt.axis('off')
plt.subplots_adjust(wspace=.05, left=.01, bottom=.01, right=.99, top=.99)
plt.show()
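The figure compares the three filters only visually. As a rough numerical complement, the sketch below measures the root-mean-square error against a known clean image. The synthetic square image, the noise level and the rmse helper are assumptions made for this illustration; they are not part of the original example, and the ranking of the filters may differ on real photographs.

```python
import numpy as np
from scipy import ndimage, signal

rng = np.random.default_rng(0)

# Synthetic test image (an assumption for this sketch): a bright square
clean = np.zeros((64, 64))
clean[16:48, 16:48] = 1.0
noisy = clean + 0.3 * rng.standard_normal(clean.shape)

filtered = {
    'noisy': noisy,
    'Gaussian filter': ndimage.gaussian_filter(noisy, sigma=2),
    'median filter': ndimage.median_filter(noisy, size=5),
    'Wiener filter': signal.wiener(noisy, (5, 5)),
}


def rmse(image, reference):
    """Root-mean-square error between two images of the same shape."""
    return np.sqrt(np.mean((image - reference) ** 2))


errors = {name: rmse(image, clean) for name, image in filtered.items()}
for name, error in errors.items():
    print(f"{name:16s} RMSE = {error:.3f}")
```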
1.5.12.17 Plotting and manipulating FFTs for filtering
Plot the power of the FFT of a signal and inverse FFT back to reconstruct a signal.
This example demonstrates scipy.fftpack.fft(), scipy.fftpack.fftfreq() and scipy.fftpack.ifft().
It implements a basic filter that is very suboptimal, and should not be used.
import numpy as np
from scipy import fftpack
from matplotlib import pyplot as plt
Generate the signal
# Seed the random number generator
np.random.seed(1234)
time_step = 0.02
period = 5.
time_vec = np.arange(0, 20, time_step)
sig = (np.sin(2 * np.pi / period * time_vec)
+ 0.5 * np.random.randn(time_vec.size))
plt.figure(figsize=(6, 5))
plt.plot(time_vec, sig, label='Original signal')
Compute and plot the power
# The FFT of the signal
sig_fft = fftpack.fft(sig)
# And the power (sig_fft is of complex dtype)
power = np.abs(sig_fft)
# The corresponding frequencies
sample_freq = fftpack.fftfreq(sig.size, d=time_step)
# Plot the FFT power
plt.figure(figsize=(6, 5))
plt.plot(sample_freq, power)
plt.xlabel('Frequency [Hz]')
plt.ylabel('Power')
# Find the peak frequency: we can focus on only the positive frequencies
pos_mask = np.where(sample_freq > 0)
freqs = sample_freq[pos_mask]
peak_freq = freqs[power[pos_mask].argmax()]
# Check that it does indeed correspond to the frequency that we generate
# the signal with
np.allclose(peak_freq, 1./period)
# An inner plot to show the peak frequency
axes = plt.axes([0.55, 0.3, 0.3, 0.5])
plt.title('Peak frequency')
# Use the same positive-frequency mask for the power as for the frequencies
plt.plot(freqs[:8], power[pos_mask][:8])
plt.setp(axes, yticks=[])
# scipy.signal.find_peaks_cwt can also be used for more advanced
# peak detection
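As a sketch of the peak-detection alternative mentioned in the comment, scipy.signal.find_peaks (a simpler relative of find_peaks_cwt) can pick out the dominant frequency. The height threshold of half the maximum power is an arbitrary assumption for this illustration, and the signal is regenerated here with NumPy's Generator API so that the block is self-contained.

```python
import numpy as np
from scipy import fftpack, signal

rng = np.random.default_rng(1234)
time_step = 0.02
period = 5.
time_vec = np.arange(0, 20, time_step)
sig = (np.sin(2 * np.pi / period * time_vec)
       + 0.5 * rng.standard_normal(time_vec.size))

power = np.abs(fftpack.fft(sig))
sample_freq = fftpack.fftfreq(sig.size, d=time_step)

# Restrict to positive frequencies, with the same mask for both arrays
pos = sample_freq > 0
freqs, pos_power = sample_freq[pos], power[pos]

# Keep only local maxima that reach at least half of the largest value
peaks, _ = signal.find_peaks(pos_power, height=0.5 * pos_power.max())
peak_freqs = freqs[peaks]
print("peak frequencies [Hz]:", peak_freqs)
```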
Remove all the high frequencies
We now remove all the high frequencies and transform back from frequencies to signal.
high_freq_fft = sig_fft.copy()
high_freq_fft[np.abs(sample_freq) > peak_freq] = 0
# The spectrum stays symmetric, so the imaginary part is only rounding noise
filtered_sig = fftpack.ifft(high_freq_fft).real
plt.figure(figsize=(6, 5))
plt.plot(time_vec, sig, label='Original signal')
plt.plot(time_vec, filtered_sig, linewidth=3, label='Filtered signal')
plt.xlabel('Time [s]')
plt.ylabel('Amplitude')
plt.legend(loc='best')
Note: This is actually a bad way of creating a filter: such a brutal cut-off in frequency space does not control distortion of the signal.
Filters should be created using the SciPy filter design code in scipy.signal.
plt.show()
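To illustrate what using the filter design code could look like, here is a minimal sketch based on a Butterworth low-pass filter from scipy.signal. The filter order (4) and the 1 Hz cutoff are assumed values chosen for this illustration, not recommendations from the original text; the noise-free reference signal is kept only to measure the improvement.

```python
import numpy as np
from scipy import signal

rng = np.random.default_rng(1234)
time_step = 0.02
period = 5.
time_vec = np.arange(0, 20, time_step)
clean = np.sin(2 * np.pi / period * time_vec)
sig = clean + 0.5 * rng.standard_normal(time_vec.size)

# 4th-order Butterworth low-pass; the 1 Hz cutoff is an assumed value
# chosen to sit well above the 0.2 Hz signal frequency
sos = signal.butter(4, 1.0, btype='low', fs=1. / time_step, output='sos')

# Forward-backward filtering, which introduces no phase shift
filtered_sig = signal.sosfiltfilt(sos, sig)

error_before = np.sqrt(np.mean((sig - clean) ** 2))
error_after = np.sqrt(np.mean((filtered_sig - clean) ** 2))
print("RMS error before:", error_before, "after:", error_after)
```

The second-order-sections form (output='sos') is used here because it is generally more robust numerically than transfer-function coefficients.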
"""
=============================================
Plotting and manipulating FFTs for filtering
=============================================
Plot the power of the FFT of a signal and inverse FFT back to reconstruct
a signal.
This example demonstrates :func:`scipy.fftpack.fft`,
:func:`scipy.fftpack.fftfreq` and :func:`scipy.fftpack.ifft`. It
implements a basic filter that is very suboptimal, and should not be
used.
"""
import numpy as np
from scipy import fftpack
from matplotlib import pyplot as plt
############################################################
# Generate the signal
############################################################
# Seed the random number generator
np.random.seed(1234)
time_step = 0.02
period = 5.
time_vec = np.arange(0, 20, time_step)
sig = (np.sin(2 * np.pi / period * time_vec)
+ 0.5 * np.random.randn(time_vec.size))
plt.figure(figsize=(6, 5))
plt.plot(time_vec, sig, label='Original signal')
############################################################
# Compute and plot the power
############################################################
# The FFT of the signal
sig_fft = fftpack.fft(sig)
# And the power (sig_fft is of complex dtype)
power = np.abs(sig_fft)
# The corresponding frequencies
sample_freq = fftpack.fftfreq(sig.size, d=time_step)
# Plot the FFT power
plt.figure(figsize=(6, 5))
plt.plot(sample_freq, power)
plt.xlabel('Frequency [Hz]')
plt.ylabel('Power')
# Find the peak frequency: we can focus on only the positive frequencies
pos_mask = np.where(sample_freq > 0)
freqs = sample_freq[pos_mask]
peak_freq = freqs[power[pos_mask].argmax()]
# Check that it does indeed correspond to the frequency that we generate
# the signal with
np.allclose(peak_freq, 1./period)
# An inner plot to show the peak frequency
axes = plt.axes([0.55, 0.3, 0.3, 0.5])
plt.title('Peak frequency')
# Use the same positive-frequency mask for the power as for the frequencies
plt.plot(freqs[:8], power[pos_mask][:8])
plt.setp(axes, yticks=[])
# scipy.signal.find_peaks_cwt can also be used for more advanced
# peak detection
############################################################
# Remove all the high frequencies
############################################################
#
# We now remove all the high frequencies and transform back from
# frequencies to signal.
high_freq_fft = sig_fft.copy()
high_freq_fft[np.abs(sample_freq) > peak_freq] = 0
# The spectrum stays symmetric, so the imaginary part is only rounding noise
filtered_sig = fftpack.ifft(high_freq_fft).real
plt.figure(figsize=(6, 5))
plt.plot(time_vec, sig, label='Original signal')
plt.plot(time_vec, filtered_sig, linewidth=3, label='Filtered signal')
plt.xlabel('Time [s]')
plt.ylabel('Amplitude')
plt.legend(loc='best')
############################################################
#
# **Note** This is actually a bad way of creating a filter: such a brutal
# cut-off in frequency space does not control distortion of the signal.
#
# Filters should be created using the scipy filter design code
plt.show()
1.5.12.18 Solutions of the exercises for scipy
Crude periodicity finding
Discover the periods in the evolution of animal populations (../../data/populations.txt)
Load the data
import numpy as np
data = np.loadtxt('../../../../data/populations.txt')
years = data[:, 0]
populations = data[:, 1:]
Plot the data
import matplotlib.pyplot as plt
plt.figure()
plt.plot(years, populations * 1e-3)
plt.xlabel('Year')
# Raw string so that the LaTeX backslash is not read as an escape sequence
plt.ylabel(r'Population number ($\cdot10^3$)')
plt.legend(['hare', 'lynx', 'carrot'], loc=1)
Plot its periods
from scipy import fftpack
ft_populations = fftpack.fft(populations, axis=0)
frequencies = fftpack.fftfreq(populations.shape[0], years[1] - years[0])
periods = 1 / frequencies
plt.figure()
plt.plot(periods, abs(ft_populations) * 1e-3, 'o')
plt.xlim(0, 22)
plt.xlabel('Period')
plt.ylabel(r'Power ($\cdot10^3$)')
plt.show()
There’s probably a period of around 10 years (obvious from the plot), but for this crude a method, there’s not enough data to say much more.
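In the code above, 1 / frequencies also divides by the zero-frequency term, which yields an infinite period (and typically a NumPy warning). The sketch below shows one way to read off the dominant period while skipping that term. It uses a synthetic series with a built-in 10-year cycle rather than populations.txt; the series length, amplitudes and noise level are assumptions made for this illustration.

```python
import numpy as np
from scipy import fftpack

# Synthetic yearly series (an assumption, not the populations.txt data):
# a 10-year cycle sampled once per year for 40 years
years = np.arange(1900, 1940)
rng = np.random.default_rng(0)
population = (30000 + 20000 * np.sin(2 * np.pi * (years - 1900) / 10.)
              + 1000 * rng.standard_normal(years.size))

spectrum = np.abs(fftpack.fft(population))
frequencies = fftpack.fftfreq(population.size, d=years[1] - years[0])

# Keep strictly positive frequencies so that 1 / frequency is well defined
positive = frequencies > 0
periods = 1. / frequencies[positive]
dominant_period = periods[spectrum[positive].argmax()]
print("dominant period:", dominant_period, "years")
```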
Curve fitting: temperature as a function of month of the year
We have the min and max temperatures in Alaska for each month of the year. We would like to find a function to describe this yearly evolution.
For this, we will fit a periodic function.
The data
import numpy as np
temp_max = np.array([17, 19, 21, 28, 33, 38, 37, 37, 31, 23, 19, 18])
temp_min = np.array([-62, -59, -56, -46, -32, -18, -9, -13, -25, -46, -52, -58])
import matplotlib.pyplot as plt
months = np.arange(12)
plt.plot(months, temp_max, 'ro')
plt.plot(months, temp_min, 'bo')
plt.xlabel('Month')
plt.ylabel('Min and max temperature')
Fitting it to a periodic function
from scipy import optimize
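def yearly_temps(times, avg, ampl, time_offset):
    # Fixed period of 12 months; using times.max() would make the period
    # depend on the input (11 for the fit, 12 for the plot)
    return (avg
            + ampl * np.cos((times + time_offset) * 2 * np.pi / 12))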
res_max, cov_max = optimize.curve_fit(yearly_temps, months,
temp_max, [20, 10, 0])
res_min, cov_min = optimize.curve_fit(yearly_temps, months,
temp_min, [-40, 20, 0])
Plotting the fit
days = np.linspace(0, 12, num=365)
plt.figure()
plt.plot(months, temp_max, 'ro')
plt.plot(days, yearly_temps(days, *res_max), 'r-')
plt.plot(months, temp_min, 'bo')
plt.plot(days, yearly_temps(days, *res_min), 'b-')
plt.xlabel('Month')
# Raw string so that the LaTeX backslash is not read as an escape sequence
plt.ylabel(r'Temperature ($^\circ$C)')
plt.show()
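The covariance matrices returned by curve_fit (cov_max, cov_min) are not used above. As a sketch of what they offer, the square roots of their diagonal entries give approximate standard errors of the fitted parameters. The version below assumes a fixed 12-month period in the model and only fits the maximum temperatures; both choices are simplifications for this illustration.

```python
import numpy as np
from scipy import optimize

temp_max = np.array([17, 19, 21, 28, 33, 38, 37, 37, 31, 23, 19, 18])
months = np.arange(12)


def yearly_temps(times, avg, ampl, time_offset):
    # Assumed fixed period of 12 months
    return avg + ampl * np.cos((times + time_offset) * 2 * np.pi / 12)


params, cov = optimize.curve_fit(yearly_temps, months, temp_max,
                                 p0=[20, 10, 0])

# One-sigma uncertainty estimates from the diagonal of the covariance
stderr = np.sqrt(np.diag(cov))
for name, value, err in zip(['avg', 'ampl', 'time_offset'], params, stderr):
    print(f"{name:12s} = {value:8.3f} +/- {err:.3f}")
```

The sign of ampl and the value of time_offset are not unique (a negative amplitude with a shifted offset describes the same curve), so only the average and the absolute amplitude should be compared between runs.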
"""
==============================================================
Curve fitting: temperature as a function of month of the year
==============================================================
We have the min and max temperatures in Alaska for each month of the
year. We would like to find a function to describe this yearly evolution.
For this, we will fit a periodic function.
"""
############################################################
# The data
############################################################
import numpy as np
temp_max = np.array([17, 19, 21, 28, 33, 38, 37, 37, 31, 23, 19, 18])
temp_min = np.array([-62, -59, -56, -46, -32, -18, -9, -13, -25, -46, -52, -58])
import matplotlib.pyplot as plt
months = np.arange(12)
plt.plot(months, temp_max, 'ro')
plt.plot(months, temp_min, 'bo')
plt.xlabel('Month')
plt.ylabel('Min and max temperature')
############################################################
# Fitting it to a periodic function
############################################################
from scipy import optimize
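def yearly_temps(times, avg, ampl, time_offset):
    # Fixed period of 12 months; using times.max() would make the period
    # depend on the input (11 for the fit, 12 for the plot)
    return (avg
            + ampl * np.cos((times + time_offset) * 2 * np.pi / 12))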
res_max, cov_max = optimize.curve_fit(yearly_temps, months,
temp_max, [20, 10, 0])
res_min, cov_min = optimize.curve_fit(yearly_temps, months,
temp_min, [-40, 20, 0])
############################################################
# Plotting the fit
############################################################
days = np.linspace(0, 12, num=365)
plt.figure()
plt.plot(months, temp_max, 'ro')
plt.plot(days, yearly_temps(days, *res_max), 'r-')
plt.plot(months, temp_min, 'bo')
plt.plot(days, yearly_temps(days, *res_min), 'b-')
plt.xlabel('Month')
# Raw string so that the LaTeX backslash is not read as an escape sequence
plt.ylabel(r'Temperature ($^\circ$C)')
plt.show()
Simple image blur by convolution with a Gaussian kernel
Blur an image (../../../../data/elephant.png) using a Gaussian kernel.
Convolution is easy to perform with FFT: convolving two signals boils down to multiplying their FFTs (and performing an inverse FFT).
import numpy as np
from scipy import fftpack
import matplotlib.pyplot as plt
The original image
# read image
img = plt.imread('../../../../data/elephant.png')
plt.figure()
plt.imshow(img)
Prepare a Gaussian convolution kernel
# First a 1-D Gaussian
t = np.linspace(-10, 10, 30)
bump = np.exp(-0.1*t**2)
bump /= np.trapz(bump) # normalize the integral to 1
# make a 2-D kernel out of it
kernel = bump[:, np.newaxis] * bump[np.newaxis, :]
Implement convolution via FFT
# Padded Fourier transform, with the same shape as the image
# We use scipy.fftpack.fft2 to compute a 2D FFT
kernel_ft = fftpack.fft2(kernel, shape=img.shape[:2], axes=(0, 1))
# convolve
img_ft = fftpack.fft2(img, axes=(0, 1))
# the 'newaxis' is to match to color direction
img2_ft = kernel_ft[:, :, np.newaxis] * img_ft
img2 = fftpack.ifft2(img2_ft, axes=(0, 1)).real
# clip values to range
img2 = np.clip(img2, 0, 1)
# plot output
plt.figure()
plt.imshow(img2)
Further exercise (only if you are familiar with this stuff):
A “wrapped border” appears in the upper left and top edges of the image. This is because the padding is not done correctly, and does not take the kernel size into account (so the convolution “flows out of bounds of the image”). Try to remove this artifact.
A function to do it: scipy.signal.fftconvolve()
The above exercise was only for didactic reasons: there exists a function in scipy that will do this for us, and probably do a better job: scipy.signal.fftconvolve()
from scipy import signal
# mode='same' is there to enforce the same output shape as input arrays
# (ie avoid border effects)
img3 = signal.fftconvolve(img, kernel[:, :, np.newaxis], mode='same')
plt.figure()
plt.imshow(img3)
Note that we still have a decay to zero at the border of the image. Using scipy.ndimage.gaussian_filter() would get rid of this artifact.
plt.show()
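For the wrapped-border exercise, one possible approach (a sketch, not the only solution) is to zero-pad both arrays to the full linear-convolution size before transforming, then crop the central part. The random single-channel test image and the 9-point kernel are assumptions for this illustration; the result is checked against scipy.signal.fftconvolve.

```python
import numpy as np
from scipy import fftpack, signal

rng = np.random.default_rng(0)
img = rng.random((32, 48))  # stand-in for a single-channel image

# Small normalized Gaussian kernel (sizes chosen for illustration)
t = np.linspace(-10, 10, 9)
bump = np.exp(-0.1 * t**2)
bump /= bump.sum()
kernel = bump[:, np.newaxis] * bump[np.newaxis, :]

# Pad to the size of the full linear convolution so nothing wraps around
full_shape = (img.shape[0] + kernel.shape[0] - 1,
              img.shape[1] + kernel.shape[1] - 1)
img_ft = fftpack.fft2(img, shape=full_shape)
kernel_ft = fftpack.fft2(kernel, shape=full_shape)
full = fftpack.ifft2(img_ft * kernel_ft).real

# Crop the centre so that the output has the same shape as the input
r0 = (kernel.shape[0] - 1) // 2
c0 = (kernel.shape[1] - 1) // 2
same = full[r0:r0 + img.shape[0], c0:c0 + img.shape[1]]

reference = signal.fftconvolve(img, kernel, mode='same')
print("max difference:", np.abs(same - reference).max())
```

This removes the wrap-around, but the image is still implicitly surrounded by zeros, so the decay towards the border remains.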
"""
=======================================================
Simple image blur by convolution with a Gaussian kernel
=======================================================
Blur an image (:download:`../../../../data/elephant.png`) using a
Gaussian kernel.
Convolution is easy to perform with FFT: convolving two signals boils
down to multiplying their FFTs (and performing an inverse FFT).
"""
import numpy as np
from scipy import fftpack
import matplotlib.pyplot as plt
#####################################################################
# The original image
#####################################################################
# read image
img = plt.imread('../../../../data/elephant.png')
plt.figure()
plt.imshow(img)
#####################################################################
# Prepare a Gaussian convolution kernel
#####################################################################
# First a 1-D Gaussian
t = np.linspace(-10, 10, 30)
bump = np.exp(-0.1*t**2)
bump /= np.trapz(bump) # normalize the integral to 1
# make a 2-D kernel out of it
kernel = bump[:, np.newaxis] * bump[np.newaxis, :]
#####################################################################
# Implement convolution via FFT
#####################################################################
# Padded Fourier transform, with the same shape as the image
# We use :func:`scipy.fftpack.fft2` to compute a 2D FFT
kernel_ft = fftpack.fft2(kernel, shape=img.shape[:2], axes=(0, 1))
# convolve
img_ft = fftpack.fft2(img, axes=(0, 1))
# the 'newaxis' is to match to color direction
img2_ft = kernel_ft[:, :, np.newaxis] * img_ft
img2 = fftpack.ifft2(img2_ft, axes=(0, 1)).real
# clip values to range
img2 = np.clip(img2, 0, 1)
# plot output
plt.figure()
plt.imshow(img2)
#####################################################################
# Further exercise (only if you are familiar with this stuff):
#
# A "wrapped border" appears in the upper left and top edges of the
# image. This is because the padding is not done correctly, and does
# not take the kernel size into account (so the convolution "flows out
# of bounds of the image"). Try to remove this artifact.
#####################################################################
# A function to do it: :func:`scipy.signal.fftconvolve`
#####################################################################
#
# The above exercise was only for didactic reasons: there exists a
# function in scipy that will do this for us, and probably do a better
# job: :func:`scipy.signal.fftconvolve`
from scipy import signal
# mode='same' is there to enforce the same output shape as input arrays
# (ie avoid border effects)
img3 = signal.fftconvolve(img, kernel[:, :, np.newaxis], mode='same')
plt.figure()
plt.imshow(img3)
#####################################################################
# Note that we still have a decay to zero at the border of the image.
# Using :func:`scipy.ndimage.gaussian_filter` would get rid of this
# artifact
plt.show()
Image denoising by FFT
Denoise an image (../../../../data/moonlanding.png) by implementing a blur with an FFT.
Implements, via FFT, the following convolution:
f_1(t) = \int dt'\, K(t-t') f_0(t')
\tilde{f}_1(\omega) = \tilde{K}(\omega) \tilde{f}_0(\omega)
Read and plot the image
import numpy as np
import matplotlib.pyplot as plt
im = plt.imread('../../../../data/moonlanding.png').astype(float)
plt.figure()
plt.imshow(im, plt.cm.gray)
plt.title('Original image')
Compute the 2d FFT of the input image
from scipy import fftpack
im_fft = fftpack.fft2(im)
# Show the results
def plot_spectrum(im_fft):
    from matplotlib.colors import LogNorm
    # A logarithmic colormap
    plt.imshow(np.abs(im_fft), norm=LogNorm(vmin=5))
    plt.colorbar()
plt.figure()
plot_spectrum(im_fft)
plt.title('Fourier transform')
Filter in FFT
# In the lines following, we'll make a copy of the original spectrum and
# truncate coefficients.
# Define the fraction of coefficients (in each direction) we keep
keep_fraction = 0.1
# Call im_fft2 a copy of the original transform. Numpy arrays have a copy
# method for this purpose.
im_fft2 = im_fft.copy()
# Set r and c to be the number of rows and columns of the array.
r, c = im_fft2.shape
# Set to zero all rows with indices between r*keep_fraction and
# r*(1-keep_fraction):
im_fft2[int(r*keep_fraction):int(r*(1-keep_fraction))] = 0
# Similarly with the columns:
im_fft2[:, int(c*keep_fraction):int(c*(1-keep_fraction))] = 0
plt.figure()
plot_spectrum(im_fft2)
plt.title('Filtered Spectrum')
Reconstruct the final image
# Reconstruct the denoised image from the filtered spectrum, keep only the
# real part for display.
im_new = fftpack.ifft2(im_fft2).real
plt.figure()
plt.imshow(im_new, plt.cm.gray)
plt.title('Reconstructed Image')
Easier and better: scipy.ndimage.gaussian_filter()
Implementing filtering directly with FFTs is tricky and time consuming. We can use the Gaussian filter from scipy.ndimage
from scipy import ndimage
im_blur = ndimage.gaussian_filter(im, 4)
plt.figure()
plt.imshow(im_blur, plt.cm.gray)
plt.title('Blurred image')
plt.show()
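To make the effect of the coefficient truncation measurable, the sketch below wraps the same steps in a helper and applies it to a synthetic image whose noise-free version is known. The fft_lowpass name, the test pattern and the noise level are assumptions introduced for this illustration; they do not come from the original example.

```python
import numpy as np
from scipy import fftpack


def fft_lowpass(im, keep_fraction=0.1):
    """Zero all but the lowest-frequency FFT coefficients of a 2D array."""
    im_fft = fftpack.fft2(im)
    r, c = im_fft.shape
    # Same truncation as above: drop the middle rows and columns
    im_fft[int(r * keep_fraction):int(r * (1 - keep_fraction))] = 0
    im_fft[:, int(c * keep_fraction):int(c * (1 - keep_fraction))] = 0
    return fftpack.ifft2(im_fft).real


def rmse(a, b):
    """Root-mean-square difference between two arrays."""
    return np.sqrt(np.mean((a - b) ** 2))


# Smooth periodic pattern plus white noise (synthetic stand-in image)
rng = np.random.default_rng(0)
n = 128
yy, xx = np.mgrid[0:n, 0:n]
clean = np.sin(2 * np.pi * 3 * xx / n) + np.cos(2 * np.pi * 2 * yy / n)
noisy = clean + 0.5 * rng.standard_normal(clean.shape)

denoised = fft_lowpass(noisy)
print("RMSE noisy:   ", rmse(noisy, clean))
print("RMSE denoised:", rmse(denoised, clean))
```

The pattern here lives entirely in the retained low frequencies, which is the favourable case; an image with sharp edges would lose detail under the same truncation.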
r"""
======================
Image denoising by FFT
======================
Denoise an image (:download:`../../../../data/moonlanding.png`) by
implementing a blur with an FFT.
Implements, via FFT, the following convolution:
.. math::
f_1(t) = \int dt'\, K(t-t') f_0(t')
.. math::
\tilde{f}_1(\omega) = \tilde{K}(\omega) \tilde{f}_0(\omega)
"""
############################################################
# Read and plot the image
############################################################
import numpy as np
import matplotlib.pyplot as plt
im = plt.imread('../../../../data/moonlanding.png').astype(float)
plt.figure()
plt.imshow(im, plt.cm.gray)
plt.title('Original image')
############################################################
# Compute the 2d FFT of the input image
############################################################
from scipy import fftpack
im_fft = fftpack.fft2(im)
# Show the results
def plot_spectrum(im_fft):
    from matplotlib.colors import LogNorm
    # A logarithmic colormap
    plt.imshow(np.abs(im_fft), norm=LogNorm(vmin=5))
    plt.colorbar()
plt.figure()
plot_spectrum(im_fft)
plt.title('Fourier transform')
############################################################
# Filter in FFT
############################################################
# In the lines following, we'll make a copy of the original spectrum and
# truncate coefficients.
# Define the fraction of coefficients (in each direction) we keep
keep_fraction = 0.1
# Call im_fft2 a copy of the original transform. Numpy arrays have a copy
# method for this purpose.
im_fft2 = im_fft.copy()
# Set r and c to be the number of rows and columns of the array.
r, c = im_fft2.shape
# Set to zero all rows with indices between r*keep_fraction and
# r*(1-keep_fraction):
im_fft2[int(r*keep_fraction):int(r*(1-keep_fraction))] = 0
# Similarly with the columns:
im_fft2[:, int(c*keep_fraction):int(c*(1-keep_fraction))] = 0
plt.figure()
plot_spectrum(im_fft2)
plt.title('Filtered Spectrum')
############################################################
# Reconstruct the final image
############################################################
# Reconstruct the denoised image from the filtered spectrum, keep only the
# real part for display.
im_new = fftpack.ifft2(im_fft2).real
plt.figure()
plt.imshow(im_new, plt.cm.gray)
plt.title('Reconstructed Image')
############################################################
# Easier and better: :func:`scipy.ndimage.gaussian_filter`
############################################################
#
# Implementing filtering directly with FFTs is tricky and time consuming.
# We can use the Gaussian filter from :mod:`scipy.ndimage`
from scipy import ndimage
im_blur = ndimage.gaussian_filter(im, 4)
plt.figure()
plt.imshow(im_blur, plt.cm.gray)
plt.title('Blurred image')
plt.show()
See also
References to go further
- Some chapters of the advanced and the packages and applications parts of the scipy lectures
- The scipy cookbook