样条插值的思想做回归

    科技2023-10-09  73

    样条插值的思想做回归

    一、生成数据

    多项式 y = 0.2 x 3 + 0.5 x 2 − 0.8 x + 3 y = 0.2x^3+0.5x^2-0.8x+3 y=0.2x3+0.5x20.8x+3 再加上服从正态分布的噪声 ξ \xi ξ

    import numpy import matplotlib.pyplot as plt numpy.random.seed(1) def cal_poly(x): return 0.2 * x ** 3 + 0.5 * x**2 - 0.8 * x + 3 #生成100个数据 x_data = numpy.linspace(-10, 10, num=100) y_data = [cal_poly(i)+numpy.random.randint(0,1) for i in x] #生成点的图 fig, ax = plt.subplots() ax.scatter(x_data,y_data) plt.show()

    二、样条思想

    介绍: 样条插值法本来是从一个函数中选取若干个点,去做数值分析。在这里使用插值方法,我想到可以在在梯度变化最大的地方取样本,可以根据总长度,平均分几块,每块取中位数的点,看各个块之间的梯度(不看梯度也行,但是这个就比较靠直观,本例去掉了阈值判断部分),梯度大于阈值,则取这个样本。设分的段数为n,这里n选10

    在本例中,分10个区域,得到 x 1 到 x 10 x_1\text{到}x_{10} x1x10,分别计算相互之间的梯度,梯度小于阈值,则去掉 x i + 1 x_{i+1} xi+1,与 x i + 2 x_{i+2} xi+2计算梯度,如果大于阈值,用 x i 到 x i + 2 x_i\text{到}x_{i+2} xixi+2算一个样条,从 x i + 3 x_{i+3} xi+3算另外的样条

    确定样本点后,通过样条插值的思想计算,第i段的直线的偏差 F i = ∑ i = i 1 N 1 1 2 ( a 1 x i + b 1 − y i ) 2 F_i = \sum_{i = i_1}^{N1} {\frac{1}{2}}(a_1x_i + b_1 - y_i)^2 Fi=i=i1N121(a1xi+b1yi)2 要使每段偏差最小,则需要令$ \frac{\partial{F_i}}{\partial{a_i}} = 0$, $ \frac{\partial{F_i}}{\partial{b_i}}=0 $ 产生两个点去确定这个直线,所有上面两个式子可以写成如下形式,从而解出 a j , b j a_j,b_j aj,bj, { ( a j x i + b j − y i ) ∗ x i = 0 , ( a j x i + b j − y i ) = 0 , \begin{cases} (a_jx_i + b_j- y_i) * x_i = 0, \\[2ex] (a_jx_i + b_j- y_i) = 0, \\ \end{cases} (ajxi+bjyi)xi=0,(ajxi+bjyi)=0, 求解得(以x1,x2为例): { a 1 = y 1 − y 2 x 1 − x 2 , b 1 = y 1 − ( y 1 − y 2 ) x 1 x 1 − x 2 , \begin{cases} a_1 = \frac{y_1-y_2}{x_1-x_2}, \\[2ex] b_1 = y_1 - \frac{(y_1 - y_2)x_1}{x_1-x_2}, \\ \end{cases} a1=x1x2y1y2,b1=y1x1x2(y1y2)x1, 在这个个过程记录下残差 ξ \xi ξ

    三、方法实现

    n = 20 #设置分多少个直线 # 取中位数 def get_canter(arr): x = sorted(arr) return x[int(len(arr)/2)] # 切分成n段,返回n*x维的list def slice(arr,n): # x = sorted(arr) #这是一个缺陷,因为本例生成的数据已经按x排好了 result = [] step = int(len(arr)/n) k = 0 for i in range(n): result.append(arr[(k+i*step):(k+(i+1)*step)]) return result # 求两点之间的梯度 def get_gradient(x1,y1,x2,y2): return (y1-y2)/(x1-x2) #获取样条值,每个包含两个点(一条直线) def get_sample(x_data,y_data,threshold): slic_x = slice(x_data,n) #分成十份 slic_y = slice(y_data,n) #分成十份 sample = [] i=0 while i<n: x_temp = slic_x[i] y_temp = slic_y[i] center_x = get_canter(x_temp) center_y = get_canter(y_temp) i+=1 if(i >= n): break x_temp_2 = slic_x[i] y_temp_2 = slic_y[i] center_x_2 = get_canter(x_temp_2) center_y_2 = get_canter(y_temp_2) gradient = get_gradient(center_x,center_y,center_x_2,center_y_2) # while (abs(gradient) < threshold): # i += 1 # if(i >= n): # break # x_temp_2 = slic_x[i] # y_temp_2 = slic_y[i] # center_x_2 = get_canter(x_temp_2) # center_y_2 = get_canter(y_temp_2) # gradient = get_gradient(center_x,center_y,center_x_2,center_y_2) sample.append((center_x,center_y,center_x_2,center_y_2)) return sample,slic_x,slic_y #根据端点值产生直线 def get_line(x1,y1,x2,y2): print(x1,y2,x2,y2) a = (y1-y2)/(x1-x2) b = y1 - ((y1-y2)*x1)/(x1-x2) return a,b #计算残差 def get_residual(y_pre,y): error = y_pre - y return sum(abs(error))

    四、主过程

    sample,slic_x,slic_y = get_sample(x_data,y_data,0.5) #梯度选择阈值设置为0.5 # 获取到端点之后计算直线 line = [] print('产生直线',len(sample),'条') for i in range(len(sample)): (a,b) = get_line(sample[i][0],sample[i][1],sample[i][2],sample[i][3]) line.append((a,b)) print(line) 产生直线 19 条 -9.595959595959595 -79.85752123829874 -8.585858585858587 -79.85752123829874 -8.585858585858587 -49.20107410190057 -7.575757575757576 -49.20107410190057 -7.575757575757576 -26.799814284050584 -6.565656565656566 -26.799814284050584 -6.565656565656566 -11.41700960219479 -5.555555555555555 -11.41700960219479 -5.555555555555555 -1.8159278737791187 -4.545454545454546 -1.8159278737791187 -4.545454545454546 3.2401630837504722 -3.5353535353535355 3.2401630837504722 -3.5353535353535355 4.987995452948009 -2.525252525252525 4.987995452948009 -2.525252525252525 4.6643014163675325 -1.5151515151515156 4.6643014163675325 -1.5151515151515156 3.50581315656308 -0.5050505050505052 3.50581315656308 -0.5050505050505052 2.8090547346745693 0.5050505050505052 2.8090547346745693 0.5050505050505052 3.6313826974984007 1.5151515151515156 3.6313826974984007 1.5151515151515156 7.388904863346241 2.525252525252524 7.388904863346241 2.525252525252524 15.25856153618626 3.5353535353535346 15.25856153618626 3.5353535353535346 28.477084898572492 4.545454545454545 28.477084898572492 4.545454545454545 48.281207133058984 5.555555555555555 48.281207133058984 5.555555555555555 75.90766042219971 6.565656565656564 75.90766042219971 6.565656565656564 112.59317694854879 7.575757575757574 112.59317694854879 7.575757575757574 159.57448889466025 8.585858585858585 159.57448889466025 8.585858585858585 218.08832844308813 9.595959595959595 218.08832844308813 [(39.74688297112539, 261.40359518045454), (30.34988266503418, 180.7222794210857), (22.177247219671482, 118.80837453197431), (15.228976635037231, 73.18841614801205), (9.50507091113152, 41.38893990409144), (5.005530047954293, 20.936481435104035), (1.7303540455055606, 9.357576375941846), (-0.32045709621467156, 4.178760361496818), (-1.1469033772064074, 2.9265690266608537), (-0.6897908376696255, 3.1574339456188247), (0.8141046831955928, 2.3978907532626534), (3.7199469441893673, -2.0049005512733693), (7.7909601061116165, -12.285236818753795), (13.086338128762366, -31.00627023216553), (19.606081012141622, -60.6414651566167), (27.35018875624936, -103.66428595721524), (36.318661361085574, -162.54819699906915), (46.51149882665033, -239.76666264728698), (57.92870115294359, -337.7931472669765)] #得到了a,b的值,开始画线,10个点产生9条线,并计算残差 error = [] fig, ax = plt.subplots() for i in range(len(line)): (a,b) = line[i] y_pre = a*slic_x[i] + b error.append(get_residual(y_pre,slic_y[i])) ax.scatter(slic_x[i],y_pre) ax.scatter(slic_x[len(line)],(a*slic_x[len(line)] + b)) #补上最后一部分数据 plt.show()

    五、误差分析

    #得到误差error,每段误差画出柱状图 fig, ax = plt.subplots() print(len(error)) plt.bar(range(len(line)), error) plt.show() 9

    可以看到边缘直线的误差比较大,这理论上也是可以理解的,曲线值变化大的时候,偏差会变大 #分析其中一段内部的误差 # 计算slice【0】的 def get_error_individual(line,k,slice_x,slice_y): error = [] a,b = line[k] for i in range(len(slic_x[k])): pre = a * slic_x[k][i] + b error.append(abs(pre - slice_y[k][i])) return error error = get_error_individual(line,3,slic_x,slic_y) #画图 fig, ax = plt.subplots() plt.bar(range(len(error)), error) plt.show()

    不太像正态分,从该方法的实现上就可以看出,因为采样选的是中位数点,所以不像
    Processed: 0.028, SQL: 8