欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

非线性曲线拟合(高斯分布为例)--scipy求解器中optimize的curve_fit的应用--附代码

程序员文章站 2024-01-16 21:36:58
...
  1. 安装求解包scipy
pip install scipy
  1. 定义对应的高斯函数
import numpy as np
from scipy import stats
import matplotlib.pyplot as plt
from scipy import optimize as op

def get_boundaries(obj, length=5):
    """
    获取目标数最近边界
    
    Parameters: obj: 目标数
                length:区间长度
    Returns: 边界值
    
    Examples:
        >>>input: get_boundaries(-23.5, length=5)
        >>>output: -25
    """
    return round(obj - (obj % length)) if obj < 0 else round(obj + length - (obj % length))


def f_gauss(x, A, B, C, sigma):
    """高斯函数的变换形式"""
    return A*np.exp(-(x-B)**2/(2*sigma**2)) + C


def normfun(x, mu, sigma, C=0):
    """标准高斯分布"""
    pdf = np.exp(-((x - mu)**2)/(2*sigma**2)) / (sigma * np.sqrt(2*np.pi)) + C
    return pdf


def norm_fitt(x, y, func=f_gauss, fitt_num=100):
    """
    拟合函数
    
    params: x: 测试数据的自变量
            y: 频数
            func: 选择拟合函数
            fitt_num:拟合图的x数量(默认100)
    """
    x = np.array(x)
    y = np.array(y)
    
    print('x:', list(x))
    print('y:', list(y))
    popt, pcov = op.curve_fit(func, x, y)
    print('参数:', list(popt))
    
    x_1 = np.linspace(x.min(), x.max(), fitt_num)
    plt.figure(figsize=(12, 8))
    
    A, B, C, sigma = popt
    label = r'$f(x) = %s*e^\frac{(x-%s)^2}{2%s^2}+%s$' %(A.round(2), B.round(2), sigma.round(2), C.round(2))
    plt.plot(x_1, func(x_1, *popt), label=label, alpha=2.5, linewidth=3)
    plt.plot(x, y, 'o-', alpha=0.5, label='Original Line', linewidth=1)
    plt.legend(fontsize=16)
    

def main(data_list, bins=0, length=5, is_freq=True, func=f_gauss, fitt_num=100, adj_x=[], adj_y=[]):
    """
    主函数
    
    Parameters: data_list:需要统计的数据列表
                bins: 柱状图个数(默认0,自动调整)
                length: 区间长度(默认5)
                is_freq: 是否用频率计算(默认True)
                func: 选择拟合函数
                fitt_num: 拟合图的x数量(默认100)
                adj_x: 手动输入 x 的值,默认[]
                adj_y: 手动输入 y 的值,默认[]
    """
    # 获取x轴边界
    left, right = get_boundaries(min(data_list), length), get_boundaries(max(data_list), length)   # 左右边界
    print('左边界:', left, ',右边界:', right)
    
    # 根据边界以及区间长度计算柱状图数量
    bins = int(bins if bins else (right - left) / length)
    print(bins)
    
    # 画直方图,获取直方图的x,y坐标
    plt.figure(figsize=(10, 6))
    y, x, _ = plt.hist(data_list, bins=bins, range=[left, right])
    
    # 获取两两x的中间值
    x_ = [(x[i] + x[i - 1]) / 2 for i in range(1, len(x))]
    
    # 将频数转换为频率
    y_ = ((y / y.sum()) * 100).round(4) if is_freq else y
    
    # 是否使用手动输入x,y
    x_ = adj_x if adj_x else x_
    y_ = adj_y if adj_y else y_

    
    # 高斯拟合
    norm_fitt(x_, y_, func=func, fitt_num=fitt_num)
  1. 输入离散数据
data_list = [-2.360e+01, -1.994e+01, -1.564e+01, -1.396e+01, -1.294e+01,
            -1.235e+01, -1.057e+01, -9.290e+00, -8.430e+00, -8.100e+00,
            -7.380e+00, -7.160e+00, -6.960e+00, -6.580e+00, -6.210e+00,
            -5.930e+00, -5.790e+00, -5.660e+00, -5.330e+00, -4.960e+00,
            -4.870e+00, -4.300e+00, -4.230e+00, -3.890e+00, -3.870e+00,
            -3.870e+00, -3.780e+00, -3.690e+00, -3.570e+00, -3.430e+00,
            -3.410e+00, -2.990e+00, -2.810e+00, -1.830e+00, -1.440e+00,
            -8.700e-01, -1.800e-01,  1.000e-02,  3.200e-01,  3.400e-01,
            4.800e-01,  5.600e-01,  7.600e-01,  1.070e+00,  1.600e+00,
            1.930e+00,  1.960e+00,  2.030e+00,  2.230e+00,  2.280e+00,
            2.450e+00,  2.670e+00,  2.870e+00,  3.280e+00,  3.430e+00,
            3.710e+00,  4.070e+00,  4.550e+00,  4.660e+00,  5.090e+00,
            5.620e+00,  5.900e+00,  6.670e+00,  6.870e+00,  7.140e+00,
            7.350e+00,  7.410e+00,  7.520e+00,  7.870e+00,  7.880e+00,
            8.680e+00,  9.050e+00,  9.130e+00,  9.240e+00,  9.650e+00,
            9.930e+00,  9.970e+00,  1.019e+01,  1.022e+01,  1.030e+01,
            1.041e+01,  1.045e+01,  1.052e+01,  1.057e+01,  1.066e+01,
            1.161e+01,  1.180e+01,  1.180e+01,  1.228e+01,  1.229e+01,
            1.250e+01,  1.410e+01,  1.427e+01,  1.462e+01,  1.517e+01,
            1.562e+01,  1.562e+01,  1.574e+01,  1.732e+01,  1.735e+01,
            1.994e+01,  2.326e+01,  2.360e+01]

y = [0.97, 1.94, 3.88, 11.65, 17.48, 21.36, 17.50, 16.5, 6.8, 1.94]
main(data_list, adj_y=y)
# main(data_list, bins=10)
  1. 得出拟合后的高斯分布函数
    非线性曲线拟合(高斯分布为例)--scipy求解器中optimize的curve_fit的应用--附代码
    即拟合后的均值为3.63,标准差为29.84。

  2. 注意点:
    a. 公式里自定义了频数分布区间为5,可以改左右区间以及区间个数bins;
    b. adj_y=y表示用自定义的y代替了原来的y,即绘图的x、y也可以自定义;
    c. 当前使用相对频率进行了绘图,可以自行选择频数/相对频率进行绘图。

转载请附出处,谢谢。