欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

三次样条+线性插值

程序员文章站 2023-12-27 23:38:15
...
               三次样条插值是要解决,在模拟函数的时候,插值节点处不平滑的问题
S(x)是f(x) (x0 .... xn)的三次样条插值函数,则S(x)要在区间(xj,xj+1)上满足S(xj-0) = S(xj+1+0) , S'(xj-0) = S'(xj+1 + 0), s''(xj - 0) = S''(xj+1 +0),
       这样在每一个这样的区间建立一个等式条件,同时在两个端点出处得到两个确定值
(S'(x0) = f'(x0),S'(xn) = f'(xn)   ,S''(x0) = f''(x0) S''(xn) = f''(xn))
通过上述的等式可以得出S(x)的最终形式
因为S(x)是三次的,导数两次之后就是线性的,在两点之间可以求出两次导数后的S''(x)在往回积分得到下面的式子
S(x) = Mj*(xj+1 - x)^3/(6*hj) + Mj+1(x - xj)^3/(6 * hj) + (yj - Mj*hj^2/6)*(xj+1 - x)/hj+(yj+1 - Mj+1 * hj^2)*(x - xj)/hj           j = 0,1,......n-1
       hj 就是 xj+1 - xj,现在就是需要求 Mj  (Mj是S''(xj)的值)
        由求得的S(x)导数一次  +  S'(xj-0) = S'(xj+1+0) 可以得到
        uj * Mj-1 + 2Mj  +lj * Mj+1 = dj
        uj = hj-1/(hj-1 + hj)      lj = 1 - uj

        dj = 6f[xj-1,xj,xj+1]均差    

其中 d0 = 6/h0 * (f[x0,x1] - f'(x0))    dn = 6/hn-1 * (f'(xn) - f[xn-1,xn])

       可以根据这些写出矩阵

2 l0

         u1 2 l1                               *M = d[i =0 ....n]

               u2 2 l2

              .....

                              un 2        


      这样可以得到M了,其他需要的也可以计算了.


   

画图的时候用到了matplotlib 

import scipy.interpolate as itp
import matplotlib.pyplot as plt
import numpy as np

''''
线性插值
'''

def line_insert(xn,yn):
    def line(x):
        for value in range(len(xn)-1):
            if xn[value] <= x <= xn[value+1]:
                result1 = ((x-xn[value+1])/(xn[value] - xn[value+1]))*yn[value]
                result2 = ((x-xn[value])/(xn[value+1] - xn[value]))*yn[value+1]
                return result1+result2
    return line
这个是线性插值,返回一个直接计算的函数

接下来是实现的三次样条的茬插值


'''
三次样条
'''
last_trax = []
#均差 求d的时候
def sub_spl(x,y):
    if len(x) == 2:
        return (y[1] -y[0])/(x[1] - x[0])
    return (sub_spl(x[1:],y[1:]) - sub_spl(x[:len(x)-1],y[:len(y) - 1]))/(x[len(x) - 1] - x[0])
def three_spline_trax(tup_x,tup_y,s0,sn):
    trax = []
    h = []
    for i in range(len(tup_x) - 1):
        h.append(tup_x[i+1] - tup_x[i])
    d = [0 for i in  range(len(tup_x))]
    d[0] = 6/h[0] * (sub_spl(tup_x[:2],tup_y[:2]) - s0)
    d [-1] = 6/h[-1] * (sn - sub_spl(tup_x[-2:],tup_y[-2:]))
    u = [0 for i in  range(len(tup_x)-1)]
    for j in  range(1,len(tup_x) - 1):
        d[j] = 6 * sub_spl(tup_x[j-1:j+2],tup_y[j-1:j+2])
        u[j] = h[j-1]/(h[j-1] + h[j])
    l =  [1-i for i in u]
    u.append(1) #un = 1
    for i in range(len(tup_x)):
        trax.append([0 for j in range(len(tup_x))])
        trax[i][i] = 2

    for i in range(len(tup_x) - 1):
        trax[i][i+1] = l[i]
        trax[i+1][i] = u[i+1]
        trax[i].append(d[i])
    trax[-1].append(d[-1])
    return trax,h
#对得到的矩阵简化上三角
def return_trax(trax,m,n,num):
    if num is 1:
        last_trax.append(trax[n-1][n-1:])
        return 
    else:
        #寻找主元
        Max = abs(trax[0][0])
        t1 = 0
        for i in  range(1,m):
            if abs(trax[i][0]) > Max:
                Max = trax[i][0]
                t1 = i
        trax[0],trax[t1] = trax[t1],trax[0]
        tmp2 = []
        last_trax.append(trax[0])
        for i in range(1,n):
            #倍数
            tmp = -trax[i][0]/trax[0][0]
            def cal(tup):
                return tup[0] + tup[1] * tmp
            tmp2.append(map(cal,zip(trax[i][1:],trax[0][1:])))
        trax = tmp2
        return return_trax(trax,m-1,n-1,num-1)

def s_um(tup):
    return reduce(lambda x,y:x*y,tup)

#这个函数得到M0 - Mn
def calculate(trax,m):
    #倒着循环
    l = [0 for i in range(m)]
    for i in range(m-1,-1,-1):
        if i == m-1:
            #直接求出最后一个参数的数值
            l[i] = trax[i][m-i]/trax[i][0]
        else:
            tmp3 = trax[i][m-i]
            tmp3 -= sum(map(s_um,zip(l[i:],trax[i][:m-i])))
            l[i] = tmp3/trax[i][0]
    return l

def res(l,tup_x1,tup_y1,h):
    def outcome(x):
        for j in  range(len(tup_x1) - 1):
            if tup_x1[j] <= x <= tup_x1[j+1]:
                sum_func = l[j]*(tup_x1[j+1] - x)**3/(6*h[j]) + l[j+1]*(x - tup_x1[j])**3/(6*h[j])+(tup_y1[j] - l[j]/6*h[j]**2)*(tup_x1[j+1] - x)/h[j] + (tup_y1[j+1]-l[j+1]/6*h[j]**2)*(x - tup_x1[j])/h[j]
                return sum_func
    return outcome
在计算矩阵的乘法的时候,直接用上次的列主元消去....

最后的res函数就是在构造S(x)

three_line_trax 就是返回上面写的那个带2的矩阵,h就是相邻x的插值

最后再用return_trax返回上三角矩阵,calculate计算就可以

可能中间有些地方说的不准确,如果发现请指正


上一篇:

下一篇: