Python科学实验----最小二乘法
最小二乘法应用场景很多,这里我们要演示如何通过最小二乘法将给定的点拟合成比较接近的直线或者曲线。
我们先以直线y=ax+b分析。问题则变为,已知一系列点,求a,b,使得直线在某种意义上是最优解。
最小二乘则是将某种意义具化了,具化的方法其实十分常见。那就是我们要使所有的点到直线的总距离最短。不过,这样可能会出现正负相消的情况,不能反映出来真正的总距离,这样的解也会有无数个。当然,我们可以将每个点到直线的距离都取正(个人没有推导出来)。不过也可以用平方和作为总距离,让这些点到直线的距离的平方和最短。
为何要用平方和作为量度标准?用平方的好处在于我们可以放大点和直线之间的距离差异。接近的点更接近,远离的点更远离。这应该就是其精妙之处吧。
下面来推导最小二乘法的公式。
假设这些点的序列为(x1,y1),(x2,y2),(x3,y3)...(xn,yn),设我们要拟合的方程为y=ax+b形式,S(a,b)为总距离
S(a,b)=(x1*a+b-y1)^2+(x2*a+b-y2)^2+.....+(xn*a+b-yn)^2
如果我们要求S(a,b)的最小值,需要对S(a,b)求导。使其对a和b的偏导数同时为0。
S(a,b)对a的偏导=2*(x1*a+b-y1)*x1+.....+2*(xn*a+b-yn)*xn=2*[(x1*x1+x2*x2+......+xn*xn)+2(x1+x2+.....+xn)b-(x1*y1+x2*y2+......+xn*yn)]=0
S(a,b)对b的偏导=2*(x1*a+b-y1)*1+....+2*(xn*a+b-yn)*1=2*[(x1+x2+.....+xn)+2*n*b-(y1+y2+......+yn)]=0
联立方程组求出a,b。
推导了最小二乘法,那么让我们用python来实现它。个人选择python实现也是因为它有matplotlib,提供了和matlab一样的画图功能,十分直观。
1 import matplotlib.pyplot as plt 2 import numpy as np 3 import random 4 #points=[(0,1),(1,2.5),(2,2),(3,1),(4,3),(5,5),(6,5.5),(7,4),(8,3)] 5 #points=[(10,62),(20,68),(30,75),(40,81),(60,89),(60,95),(70,102),(80,108),(90,115),(100,122)] 6 points=[] 7 #generate our experiment data 8 for i in range(1,10): 9 points.append((i,((5.5-random.random())*i+1+random.random()-0.5))) 10 11 def least_squares_line(points): 12 if len(points)==0: 13 return (0,0) 14 k1=k2=k3=k4=0 15 for point in points: 16 x=point[0] 17 y=point[1] 18 k1=x*x+k1 19 k2=x+k2 20 k3=x*y+k3 21 k4=y+k4 22 #I don‘t check the special condition. 23 n=len(points) 24 a=(n*k3-k4)/(n*k1-k2) 25 b=(k1*k4-k2*k3)/(k1*n*k2-k2*k2) 26 return (a,b) 27 28 def convert_points_format(points): 29 convertx=[] 30 converty=[] 31 for point in points: 32 convertx.append(point[0]) 33 converty.append(point[1]) 34 return (convertx,converty) 35 36 37 (a,b)=least_squares_line(points) 38 if b>=0: 39 print "y=%fx+%f"%(a,b) 40 else: 41 print "y=%fx%f"%(a,b) 42 43 (convertx,converty)=convert_points_format(points) 44 #show plot 45 x=np.arange(0,10,0.1) 46 y=np.add(np.multiply(x,a),b) 47 plt.plot(x,y,‘g‘) 48 plt.plot(convertx,converty,‘.‘) 49 plt.grid(True) 50 plt.show()
同理,我们可以得到曲线的拟合。
1 import matplotlib.pyplot as plt 2 import numpy as np 3 import random 4 import numpy.linalg as lina 5 points=[] 6 for i in range(1,10): 7 points.append((i+random.random(),8*i*i+9)) 8 9 def least_squares_curve(points): 10 if len(points)==0: 11 return (0,0) 12 k1=k2=k3=k4=0 13 for point in points: 14 x=point[0] 15 y=point[1] 16 k1=x*x*x*x+k1 17 k2=x*x+k2 18 k3=x*x*y+k3 19 k4=y+k4 20 #I don‘t check the special condition. 21 n=len(points) 25 return lina.solve([[k1,k2],[k2,n]],[k3,k4]) 26 def convert_points_format(points): 27 convertx=[] 28 converty=[] 29 for point in points: 30 convertx.append(point[0]) 31 converty.append(point[1]) 32 return (convertx,converty) 33 34 35 (a,b)=least_squares_curve(points) 36 if b>=0: 37 print "y=%fx^2+%f"%(a,b) 38 else: 39 print "y=%fx^2%f"%(a,b) 40 41 42 (convertx,converty)=convert_points_format(points) 43 #show plot 44 x=np.arange(0,10,0.01) 45 y=np.add(np.multiply(np.multiply(x,x),a),b) 46 plt.plot(x,y,‘g‘) 47 plt.plot(convertx,converty,‘.‘) 48 plt.grid(True) 49 plt.show()
郑重声明:本站内容如果来自互联网及其他传播媒体,其版权均属原媒体及文章作者所有。转载目的在于传递更多信息及用于网络分享,并不代表本站赞同其观点和对其真实性负责,也不构成任何其他建议。