1. 多项式拟合问题

  多项式拟合(polynominal curve fitting)是一种线性模型,模型和拟合参数的关系是线性的。多项式拟合的输入是一维的,即 x=x x = x <script type="math/tex" id="MathJax-Element-341">\textbf{x} = x</script>,这是多项式拟合和线性回归问题的主要区别之一。

  多项式拟合的目标是构造输入 x x <script type="math/tex" id="MathJax-Element-342">x</script>的 M <script type="math/tex" id="MathJax-Element-343">M</script>阶多项式函数,使得该多项式能够近似表示输入 x x <script type="math/tex" id="MathJax-Element-344">x</script>和输出 y <script type="math/tex" id="MathJax-Element-345">y</script>的关系,虽然实际上 x x <script type="math/tex" id="MathJax-Element-346">x</script>和 y <script type="math/tex" id="MathJax-Element-347">y</script>的关系并不一定是多项式,但使用足够多的阶数,总是可以逼近表示输入 x x <script type="math/tex" id="MathJax-Element-348">x</script>和输出 y <script type="math/tex" id="MathJax-Element-349">y</script>的关系的。

  多项式拟合问题的输入可以表示如下:

D={(x1,y1),(x2,y2),...,(xi,yi),...,(xN,yN)}xiRyiR D = { ( x 1 , y 1 ) , ( x 2 , y 2 ) , . . . , ( x i , y i ) , . . . , ( x N , y N ) } x i ∈ R y i ∈ R
<script type="math/tex; mode=display" id="MathJax-Element-10">\begin{equation*} D = \{(x_1,y_1),(x_2,y_2),...,(x_i,y_i),...,(x_N,y_N)\} \\ x_i \in R \\ y_i \in R \end{equation*}</script>

  目标输出是得到一个多项式函数:

f(x)=w1x1+w2x2+wixi+...+wMxM+b=(i=1Mwixi)+b f ( x ) = w 1 x 1 + w 2 x 2 + w i x i + . . . + w M x M + b = ( ∑ i = 1 M w i x i ) + b
<script type="math/tex; mode=display" id="MathJax-Element-364">\begin{equation*} \begin{aligned} f(x) &= w_1 x^1 + w_2 x^2 + w_i x^i + ... + w_Mx^M + b\\ &= (\sum_{i=1}^{M} w_i x^i) + b \end{aligned} \end{equation*}</script>

其中 M M <script type="math/tex" id="MathJax-Element-365">M</script>表示最高阶数为 M <script type="math/tex" id="MathJax-Element-366">M</script>。

  可见在线性拟合的模型中,共包括了 (M+1) ( M + 1 ) <script type="math/tex" id="MathJax-Element-367">(M+1)</script>个参数,而该模型虽然不是输入 x x <script type="math/tex" id="MathJax-Element-368">x</script>的线性函数,但却是 ( M + 1 ) <script type="math/tex" id="MathJax-Element-369">(M+1)</script>个拟合参数的线性函数,所以称多项式拟合为线性模型。对于多项式拟合问题,其实就是要确定这 (M+1) ( M + 1 ) <script type="math/tex" id="MathJax-Element-370">(M+1)</script>个参数,这里先假设阶数 M M <script type="math/tex" id="MathJax-Element-371">M</script>是固定的( M <script type="math/tex" id="MathJax-Element-372">M</script>是一个超参数,可以用验证集来确定 M M <script type="math/tex" id="MathJax-Element-373">M</script>最优的值,详细的关于 M <script type="math/tex" id="MathJax-Element-374">M</script>值确定的问题,后面再讨论),重点就在于如何求出这 (M+1) ( M + 1 ) <script type="math/tex" id="MathJax-Element-375">(M+1)</script>个参数的值。

2.优化目标

  多项式拟合是利用多项式函数逼近输入 x x <script type="math/tex" id="MathJax-Element-380">x</script>和输出 y <script type="math/tex" id="MathJax-Element-381">y</script>的函数关系,通过什么指标来衡量某个多项式函数的逼近程度呢?(其实这就是误差/损失函数)。拟合/回归问题常用的评价指标是均方误差(在机器学习中的模型评估与度量博客中,我进行了介绍)。多项式拟合问题也同样采用该评价指标,以均方误差作为误差/损失函数,误差函数越小,模型越好。

E(w,b)=1Ni=1N[f(xi)yi]2 E ( w , b ) = 1 N ∑ i = 1 N [ f ( x i ) − y i ] 2
<script type="math/tex; mode=display" id="MathJax-Element-46">\begin{equation*} E(\mathbf {w}, b) = \frac{1}{N} \sum_{i=1}^{N} {\lbrack f(x_i) - y_i\rbrack}^2 \end{equation*}</script>

  系数 1N 1 N <script type="math/tex" id="MathJax-Element-47">\frac{1}{N}</script>是一常数,对优化结果无影响,可以去除,即将均方误差替换为平方误差:

E(w,b)=i=1N[f(xi)yi]2 E ( w , b ) = ∑ i = 1 N [ f ( x i ) − y i ] 2
<script type="math/tex; mode=display" id="MathJax-Element-52">\begin{equation*} E(\mathbf{w}, b) = \sum_{i=1}^{N} {\lbrack f(x_i) - y_i\rbrack}^2 \end{equation*}</script>
   到这里,就成功把多项式拟合问题变成了最优化问题,优化问题可表示为:

argminw,bE(w,b) arg ⁡ min w , b ⁡ E ( w , b )
<script type="math/tex; mode=display" id="MathJax-Element-388">\begin{equation*} \mathop {\arg \min} _ {\mathbf{w} , b} E(\mathbf{w},b) \end{equation*}</script>

即需要求得参数 {w1,...,wM,b} { w 1 , . . . , w M , b } <script type="math/tex" id="MathJax-Element-389">\{w_1, ...,w_M, b\}</script>的值,使得 E(w,b) E ( w , b ) <script type="math/tex" id="MathJax-Element-390">E(\mathbf{w}, b)</script>最小化。那么如何对该最优化问题求解呢?

3. 优化问题求解

3.1 求偏导,联立方程求解

   直观的想法是,直接对所有参数求偏导,令偏导为0,再联立这 M+1 M + 1 <script type="math/tex" id="MathJax-Element-557">M+1</script>个方程求解(因为共有 M+1 M + 1 <script type="math/tex" id="MathJax-Element-558">M+1</script>个参数,故求偏导后也是得到 M+1 M + 1 <script type="math/tex" id="MathJax-Element-559">M+1</script>个方程)。

E(w,b)=i=1N[f(xi)yi]2=i=1N[(w1x1i+w2x2i+wixji+...+wMxMi+b)yi]2 E ( w , b ) = ∑ i = 1 N [ f ( x i ) − y i ] 2 = ∑ i = 1 N [ ( w 1 x i 1 + w 2 x i 2 + w i x i j + . . . + w M x i M + b ) − y i ] 2
<script type="math/tex; mode=display" id="MathJax-Element-518">\begin{equation*} \begin{aligned} E(\mathbf{w}, b) &= \sum_{i=1}^{N} {\lbrack f(x_i) - y_i\rbrack}^2 \\ &= \sum_{i=1}^{N} {\lbrack (w_1 x_i^1 + w_2 x_i^2 + w_i x_i^j + ... + w_Mx_i^M + b) - y_i\rbrack}^2 \end{aligned} \end{equation*}</script>

利用 E(w,b) E ( w , b ) <script type="math/tex" id="MathJax-Element-519">E(\mathbf{w},b)</script>对各个参数求偏导,如下:

E(w,b)wjE(w,b)b=2i=1N[(w1x1i+w2x2i+wixji+...+wMxMi+b)yi]xji=2i=1N[(w1x1i+w2x2i+wixji+...+wMxMi+b)yi] ∂ E ( w , b ) ∂ w j = 2 ∑ i = 1 N [ ( w 1 x i 1 + w 2 x i 2 + w i x i j + . . . + w M x i M + b ) − y i ] x i j ∂ E ( w , b ) ∂ b = 2 ∑ i = 1 N [ ( w 1 x i 1 + w 2 x i 2 + w i x i j + . . . + w M x i M + b ) − y i ]
<script type="math/tex; mode=display" id="MathJax-Element-2129">\begin{equation*} \begin{aligned} \frac{\partial E(\mathbf{w}, b)}{\partial w_j} &= 2 \sum_{i=1}^{N} {\lbrack (w_1 x_i^1 + w_2 x_i^2 + w_i x_i^j + ... + w_Mx_i^M + b) - y_i\rbrack} x_i^j \\ \frac{\partial E(\mathbf{w}, b)}{\partial b} &= 2 \sum_{i=1}^{N} {\lbrack (w_1 x_i^1 + w_2 x_i^2 + w_i x_i^j + ... + w_Mx_i^M + b) - y_i\rbrack} \end{aligned} \end{equation*}</script>

求导之后,将各个点 (xi,yi) ( x i , y i ) <script type="math/tex" id="MathJax-Element-2130">(x_i,y_i)</script>的值带入偏导公式,联立方程求解即可。

  针对该解法,可以举个例子详细说明,比如有两个点 (2,3),(5,8) ( 2 , 3 ) , ( 5 , 8 ) <script type="math/tex" id="MathJax-Element-2131">(2,3),(5,8)</script>,需要利用二阶多项式 f(x)=w1x+w2x2+b f ( x ) = w 1 x + w 2 x 2 + b <script type="math/tex" id="MathJax-Element-2132">f(x) = w_1x + w_2x^2 + b</script>拟合。求解过程如下:

  1. 该二阶多项式对参数求偏导得到

    E(w,b)wjE(w,b)b=2i=12[(w1x1i+w2x2i+b)yi]xji=[(w1x1+w2x21+b)y1]xj1+[(w1x2+w2x22+b)y2]xj2=2i=12[(w1x1i+w2x2i+b)yi]=[(w1x1+w2x21+b)y1]+[(w1x2+w2x22+b)y2] ∂ E ( w , b ) ∂ w j = 2 ∑ i = 1 2 [ ( w 1 x i 1 + w 2 x i 2 + b ) − y i ] x i j = [ ( w 1 x 1 + w 2 x 1 2 + b ) − y 1 ] x 1 j + [ ( w 1 x 2 + w 2 x 2 2 + b ) − y 2 ] x 2 j ∂ E ( w , b ) ∂ b = 2 ∑ i = 1 2 [ ( w 1 x i 1 + w 2 x i 2 + b ) − y i ] = [ ( w 1 x 1 + w 2 x 1 2 + b ) − y 1 ] + [ ( w 1 x 2 + w 2 x 2 2 + b ) − y 2 ]
    <script type="math/tex; mode=display" id="MathJax-Element-2133">\begin{equation*} \begin{aligned} \frac{\partial E(\mathbf{w}, b)}{\partial w_j} &= 2 \sum_{i=1}^{2} {\lbrack (w_1 x_i^1 + w_2 x_i^2 + b) - y_i\rbrack} x_i^j \\ &= [(w_1x_1 + w_2x_1^2 + b) - y_1] x_1^j + [(w_1x_2 + w_2x_2^2 + b) - y_2] x_2^j\\ \frac{\partial E(\mathbf{w}, b)}{\partial b} &= 2 \sum_{i=1}^{2} {\lbrack (w_1 x_i^1 + w_2 x_i^2 + b) - y_i\rbrack} \\ &= [(w_1x_1 + w_2x_1^2 + b) - y_1] + [(w_1x_2 + w_2x_2^2 + b) - y_2]\\ \end{aligned} \end{equation*}</script>

  2. 将点 (2,3),(5,8) ( 2 , 3 ) , ( 5 , 8 ) <script type="math/tex" id="MathJax-Element-2134">(2,3),(5,8)</script>带入方程,可以得到3个方程,

    2b+7w1+29w2=117b+29w1+133w2=4629b+133w1+641w2=212 2 b + 7 w 1 + 29 w 2 = 11 7 b + 29 w 1 + 133 w 2 = 46 29 b + 133 w 1 + 641 w 2 = 212
    <script type="math/tex; mode=display" id="MathJax-Element-2135">\begin{equation*} \begin{aligned} 2b + 7w_1 + 29w_2= 11 \\ 7b + 29w_1 + 133 w_2 = 46 \\ 29b + 133w_1 + 641w_2 = 212 \end{aligned} \end{equation*}</script>

  3. 联立这三个方程求解,发现有无穷多的解,只能得到 3w1+21w2=5 3 w 1 + 21 w 2 = 5 <script type="math/tex" id="MathJax-Element-2136">3w_1 + 21w_2 = 5</script>,这三个方程是线性相关的,故没有唯一解。

  该方法通过求偏导,再联立方程求解,比较复杂,看着也很不美观。那么有没有更加方便的方法呢?

3.2 最小二乘法

   其实求解该最优化问题(平方和的最小值)一般会采用最小二乘法(其实最小二乘法和求偏导再联立方程求解的方法无本质区别,求偏导也是最小二乘法,只是这里介绍最小二乘的矩阵形式而已)。最小二乘法(least squares),从英文名非常容易想到,该方法就是求解平方和的最小值的方法。

  可以将误差函数以矩阵的表示( N N <script type="math/tex" id="MathJax-Element-3333">N</script>个点,最高 M <script type="math/tex" id="MathJax-Element-3334">M</script>阶)为:

Xwy2 ‖ X w − y ‖ 2
<script type="math/tex; mode=display" id="MathJax-Element-3526">\begin{equation*} {\lVert \mathbf{Xw - y} \rVert}_2 \end{equation*}</script>

其中,把偏置 b b <script type="math/tex" id="MathJax-Element-3527">b</script>融合到了参数 w <script type="math/tex" id="MathJax-Element-3528">\bf w</script>中,

w={b,w1,w2,...,wM} w = { b , w 1 , w 2 , . . . , w M }
<script type="math/tex; mode=display" id="MathJax-Element-3529">\begin{equation*} \mathbf{w} = \{b, w_1, w_2, ..., w_M\} \end{equation*}</script>

X X <script type="math/tex" id="MathJax-Element-3530">\mathbf X</script>则表示输入矩阵,

11...1x1x2...xNx21x22...x2N............xM1xM2...xMN [ 1 x 1 x 1 2 . . . x 1 M 1 x 2 x 2 2 . . . x 2 M . . . . . . . . . . . . . . . 1 x N x N 2 . . . x N M ]
<script type="math/tex; mode=display" id="MathJax-Element-3531">\begin{gather*} \begin{bmatrix} 1 & x_1 & x_1^2 &... &x_1^M \\ 1 & x_2 & x_2^2 & ... & x_2^M \\ ... & ... & ... & ... & ... \\ 1 & x_N & x_N^2 & ... & x_N^M \\ \end{bmatrix} \end{gather*}</script>

y y <script type="math/tex" id="MathJax-Element-3532">\mathbf y</script>则表示标注向量,

y={y1,y2,...,yN}T y = { y 1 , y 2 , . . . , y N } T
<script type="math/tex; mode=display" id="MathJax-Element-3534">\begin{equation*} \mathbf{y} = \{y_1,y_2,...,y_N\}^T \end{equation*}</script>

因此,最优化问题可以重新表示为

minwXwy2 min w ‖ X w − y ‖ 2
<script type="math/tex; mode=display" id="MathJax-Element-3535">\begin{equation*} \min_{\mathbf w} {\lVert \mathbf{Xw - y} \rVert}_2 \end{equation*}</script>

对其求导,

Xwy2w=(Xwy)T(Xwy)w=(wTXTyT)(Xwy)w=(wTXTXwyTXwwTXTy+yTy)w ∂ ‖ X w − y ‖ 2 ∂ w = ∂ ( X w − y ) T ( X w − y ) ∂ w = ∂ ( w T X T − y T ) ( X w − y ) ∂ w = ∂ ( w T X T X w − y T X w − w T X T y + y T y ) ∂ w
<script type="math/tex; mode=display" id="MathJax-Element-3552">\begin{equation*} \begin{aligned} \frac {\partial {\lVert \mathbf{Xw - y} \rVert}_2}{\partial \mathbf{w}} &= \frac{\partial (\mathbf{Xw} - \mathbf{y})^T (\mathbf{Xw} - \mathbf{y})}{\partial \mathbf{w}} \\ &= \frac{\partial (\mathbf{w}^T\mathbf{X}^T - \mathbf{y}^T) (\mathbf{Xw} - \mathbf{y})}{\partial \mathbf{w}}\\ &= \frac{\partial (\mathbf{w}^T\mathbf{X}^T\mathbf{Xw} - \mathbf{y}^T\mathbf{Xw} - \mathbf{w}^T\mathbf{X}^T\mathbf{y} + \mathbf{y}^T\mathbf{y}) }{\partial \mathbf{w}} \\ \end{aligned} \end{equation*}</script>

在继续对其求导之前,需要先补充一些矩阵求导的先验知识(常见的一些矩阵求导公式可以参见转载的博客https://blog.csdn.net/lipengcn/article/details/52815429),如下:

xTax=aaxx=aTxTAx=Ax+ATx ∂ x T a ∂ x = a ∂ a x ∂ x = a T ∂ x T A ∂ x = A x + A T x
<script type="math/tex; mode=display" id="MathJax-Element-3553">\begin{equation*} \frac{\partial \mathbf{x}^T\mathbf{a}}{\partial \mathbf{x}} = \mathbf{a} \\ \frac{\partial \mathbf{ax}}{\partial \mathbf{x}} = \mathbf{a}^T \\ \frac{\partial \mathbf{x}^T\mathbf{A}}{\partial \mathbf{x}} = \mathbf{Ax} + \mathbf{A}^T\mathbf{x} \end{equation*}</script>

根据上面的矩阵求导规则,继续进行损失函数的求导

Xwy2w=(wTXTXwyTXwwTXTy+yTy)w=XTXw+(XTX)Tw(yTX)TXTy=2XTXw2XTy ∂ ‖ X w − y ‖ 2 ∂ w = ∂ ( w T X T X w − y T X w − w T X T y + y T y ) ∂ w = X T X w + ( X T X ) T w − ( y T X ) T − X T y = 2 X T X w − 2 X T y
<script type="math/tex; mode=display" id="MathJax-Element-3568">\begin{equation*} \begin{aligned} \frac {\partial {\lVert \mathbf{Xw - y} \rVert}_2}{\partial \mathbf{w}} &= \frac{\partial (\mathbf{w}^T\mathbf{X}^T\mathbf{Xw} - \mathbf{y}^T\mathbf{Xw} - \mathbf{w}^T\mathbf{X}^T\mathbf{y} + \mathbf{y}^T\mathbf{y}) }{\partial \mathbf{w}} \\ &= \mathbf{X}^T\mathbf{Xw} + (\mathbf{X}^T\mathbf{X})^T\mathbf{w} - (\mathbf{y}^T\mathbf{X})^T - \mathbf{X}^T\mathbf{y} \\ &= 2 \mathbf{X}^T\mathbf{Xw} - 2\mathbf{X}^T\mathbf{y} \end{aligned} \end{equation*}</script>

其中 XTXw=(XTX)Tw X T X w = ( X T X ) T w <script type="math/tex" id="MathJax-Element-3569">\mathbf{X}^T\mathbf{Xw} = (\mathbf{X}^T\mathbf{X})^T\mathbf{w}</script>.令求导结果等于0,即可以求导问题的最小值。

2XTXw2XTy=0w=(XTX)1XTy 2 X T X w − 2 X T y = 0 w = ( X T X ) − 1 X T y
<script type="math/tex; mode=display" id="MathJax-Element-5424">\begin{equation*} \begin{aligned} 2 \mathbf{X}^T\mathbf{Xw} - 2\mathbf{X}^T\mathbf{y} = 0 \\ \mathbf{w} = (\mathbf{X}^T\mathbf{X})^{-1}\mathbf{X}^T\mathbf{y} \end{aligned} \end{equation*}</script>

  再利用最小二乘法的矩阵形式对前面的例子进行求解,用二阶多项式拟合即两个点 (2,3),(5,8) ( 2 , 3 ) , ( 5 , 8 ) <script type="math/tex" id="MathJax-Element-5425">(2,3),(5,8)</script>。

  1. 表示输入矩阵 X X <script type="math/tex" id="MathJax-Element-5426">\mathbf{X}</script>和标签向量 y y <script type="math/tex" id="MathJax-Element-5427">\mathbf{y}</script>

    X=[1125425]y=[38]T X = [ 1 2 4 1 5 25 ] y = [ 3 8 ] T
    <script type="math/tex; mode=display" id="MathJax-Element-5428">\begin{gather*} \mathbf{X} = \begin{bmatrix} 1 & 2& 4 \\ 1 & 5 & 25\\ \end{bmatrix} \\ \mathbf{y} = \begin{bmatrix} 3 & 8 \\ \end{bmatrix} ^T\\ \end{gather*}</script>

  2. 计算 XTX X T X <script type="math/tex" id="MathJax-Element-5429">\mathbf{X}^T\mathbf{X}</script>

    XTX=272972913329133641 X T X = [ 2 7 29 7 29 133 29 133 641 ]
    <script type="math/tex; mode=display" id="MathJax-Element-5430">\begin{gather*} \mathbf{X}^T\mathbf{X} = \begin{bmatrix} 2 & 7& 29 \\ 7 & 29 & 133\\ 29 & 133 & 641\\ \end{bmatrix} \\ \end{gather*}</script>

  3. 矩阵求逆,再做矩阵乘法运算
    XTX X T X <script type="math/tex" id="MathJax-Element-5431">\mathbf{X}^T\mathbf{X} </script>不可逆,故无唯一解。

  关于矩阵的逆是否存在,可以通过判断矩阵的行列式是否为0( det(A)=?0 d e t ( A ) = ? 0 <script type="math/tex" id="MathJax-Element-5432">det(\mathbf{A}) \stackrel{?}{=} 0</script> 来判断,也可以通过初等行变换,观察矩阵的行向量是否线性相关,在这个例子下,矩阵不可逆,故有无穷多解。但如果新增一个点 (4,7) ( 4 , 7 ) <script type="math/tex" id="MathJax-Element-5433">(4,7)</script>,则就可以解了。

  其实这和数据集的点数和选择的阶数有关,如果点数小于阶数则会出现无穷解的情况,如果点数等于阶数,那么刚好有解可以完全拟合所有数据点,如果点数大于阶数,则会求的近似解。

  那么对于点数小于阶数的情况,如何求解?在python的多项式拟合函数中是可以拟合的,而且效果不错,具体算法不是很了解,可以想办法参考python的ployfit()函数的实现。

4. 拟合阶数的选择

   在前面的推导中,多项式的阶数被固定了,那么实际场景下应该如何选择合适的阶数 M M <script type="math/tex" id="MathJax-Element-5569">M</script>呢?

  1. 一般会选择阶数 M <script type="math/tex" id="MathJax-Element-5570">M</script>小于点数 N N <script type="math/tex" id="MathJax-Element-5571">N</script>
  2. 把训练数据分为训练集合验证集,在训练集上,同时用不同的 M <script type="math/tex" id="MathJax-Element-5572">M</script>值训练多个模型,然后选择在验证集误差最小的阶数 M M <script type="math/tex" id="MathJax-Element-5573"></script>

    5. 后续

      如果后续还想写的话,可以考虑正则化问题。

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐