Gradient Descent

Gradient descent helps us find a local minimum, or the global minimum, of a function.

For typical problems, the appeal of gradient descent is that it requires no symbolic derivation, which makes it a blessing for many engineers.

For high-dimensional problems (more than 1000 dimensions, say), gradient descent and its many variants are the most commonly used family of optimization methods.

All the demo code for this page is available here. This tutorial is written for readers with a computer science background: it explains the ideas through code rather than through off-putting mathematical notation.

Example 1: Derivative Descent

Given x^3 + 2x + e^x - 3 = 0, solve for x.

The task is really to find a number x that makes the left-hand side of the equation as close to zero as possible (exactly zero, ideally).

Solving for an analytic solution directly requires painful mathematical derivation, which scares many people away.

Instead, we can use derivative descent to find a numerical solution. The demonstration below uses Python.

First we write a function that computes x^3 + 2x + e^x - 3:

def problem(x):
    e = 2.71828182845904590
    return x**3 + 2*x + e**x - 3

We call this function problem. Our goal is to find a number x such that problem(x) = 0.

Next we define an error function, error. error(x) is the square of the difference between problem(x) and 0:

def error(x):
    return (problem(x)-0)**2

The graph of the error function looks like a bowl: it dips in the middle and rises on both sides. The lowest point of the bowl, i.e. the minimum of error(x), is exactly the point we are looking for.
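
If you want to see this bowl shape for yourself, here is a minimal plotting sketch (assuming numpy and matplotlib are installed; the plotting range is an arbitrary choice):

import numpy as np
import matplotlib.pyplot as plt

xs = np.linspace(-2.0, 2.0, 400)    # sample points around the minimum
ys = [error(x) for x in xs]         # error(x) at each sample point

plt.plot(xs, ys)
plt.xlabel('x')
plt.ylabel('error(x)')
plt.show()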

We observe that to the left of the lowest point, the derivative of error(x) is negative, while to the right of it the derivative is positive.

Therefore, if at each step we subtract a small multiple of the derivative from x, x will gradually move toward the minimum of the error function, making error(x) = 0 and hence problem(x) = 0.
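
We can check this numerically with the same central-difference trick used below (a small sanity check, not part of the original code):

delta = 0.00000001
print((error(0.0 + delta) - error(0.0 - delta)) / (delta * 2))   # negative: the minimum lies to the right of x = 0
print((error(1.0 + delta) - error(1.0 - delta)) / (delta * 2))   # positive: the minimum lies to the left of x = 1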

To implement this, we first compute the derivative of the error function at x, then subtract this derivative (scaled by a small factor) from x. The "derivative descent" step is given below:

def derivative_descent(x):
    delta = 0.00000001
    # numerical derivative of error at x (central difference)
    derivative = (error(x + delta) - error(x - delta)) / (delta * 2)

    alpha = 0.01   # learning rate: how far to move in one step
    x = x - derivative * alpha
    return x

Explanation: delta is a tiny step used to estimate the derivative numerically (by central difference), and alpha is the learning rate, which controls how far x moves on each step.

We can give x an initial value (say x = 0) and then run this descent step repeatedly.

x = 0.0
for i in range(50):
    x = derivative_descent(x)
    print('x = {:6f}, problem(x) = {:6f}'.format(x,problem(x)))

Result:

x = 0.120000, problem(x) = -1.630775
x = 0.223414, problem(x) = -1.291683
x = 0.311250, problem(x) = -0.982215
x = 0.383065, problem(x) = -0.710885
x = 0.438614, problem(x) = -0.487835
x = 0.478886, problem(x) = -0.318127
x = 0.506260, problem(x) = -0.198652
x = 0.523852, problem(x) = -0.120019
x = 0.534682, problem(x) = -0.070872
x = 0.541152, problem(x) = -0.041236
x = 0.544943, problem(x) = -0.023775
x = 0.547138, problem(x) = -0.013634
x = 0.548399, problem(x) = -0.007794
x = 0.549121, problem(x) = -0.004447
x = 0.549534, problem(x) = -0.002535
x = 0.549769, problem(x) = -0.001444
x = 0.549903, problem(x) = -0.000822
x = 0.549979, problem(x) = -0.000468
x = 0.550023, problem(x) = -0.000267
x = 0.550047, problem(x) = -0.000152
x = 0.550062, problem(x) = -0.000086
x = 0.550070, problem(x) = -0.000049
x = 0.550074, problem(x) = -0.000028
x = 0.550077, problem(x) = -0.000016
x = 0.550078, problem(x) = -0.000009
x = 0.550079, problem(x) = -0.000005
x = 0.550080, problem(x) = -0.000003
x = 0.550080, problem(x) = -0.000002
x = 0.550080, problem(x) = -0.000001
x = 0.550080, problem(x) = -0.000001
x = 0.550080, problem(x) = -0.000000
x = 0.550080, problem(x) = -0.000000

...

As you can see, derivative descent quickly found x = 0.550080, which makes error(x) = 0 and hence problem(x) = 0 as well. Run it for more steps and you can obtain x to very high precision.
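
Rather than a fixed 50 iterations, we could keep descending until the answer is accurate enough. A sketch of such a stopping rule (the tolerance 1e-12 and the iteration cap are arbitrary illustrative choices):

x = 0.0
for i in range(10000):              # safety cap on the number of steps
    if abs(problem(x)) < 1e-12:     # stop once problem(x) is essentially zero
        break
    x = derivative_descent(x)
print('x = {:.12f}, problem(x) = {:e}'.format(x, problem(x)))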

Most importantly, we did not have to derive any mathematical formulas in the process.

If the original function is cheap to evaluate, we can approximate the derivative from its definition, as above: f'(x) ≈ (f(x+d) - f(x-d)) / (2d).
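
Written as a small reusable helper (the name numerical_derivative is just an illustration, not part of the original code):

def numerical_derivative(f, x, delta=0.00000001):
    # definition-based (central-difference) derivative: f'(x) ≈ (f(x+d) - f(x-d)) / (2d)
    return (f(x + delta) - f(x - delta)) / (delta * 2)

print(numerical_derivative(error, 0.0))   # derivative of error at x = 0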

If evaluating the original function is difficult but its derivative is easy to compute, we can instead work out the derivative analytically and then apply the same derivative-descent procedure.
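
For this particular problem the analytic derivative is easy: problem'(x) = 3x^2 + 2 + e^x, and by the chain rule error'(x) = 2 * problem(x) * problem'(x). A sketch of derivative descent using this analytic derivative (a variant for illustration, not part of the original code):

def derivative_descent_analytic(x):
    e = 2.71828182845904590
    problem_derivative = 3*x**2 + 2 + e**x            # derivative of x^3 + 2x + e^x - 3
    derivative = 2 * problem(x) * problem_derivative  # chain rule applied to error(x) = problem(x)**2
    alpha = 0.01
    return x - derivative * alpha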

Example 2: Gradient Descent

The derivative descent method above finds the minimum of a one-variable function error(x). What if the function we want to minimize has more than one input variable?

Given x0^5 + e^x1 + x0^3 + x0 + x1 - 5 = 0, solve for x0 and x1.

First define problem(x), where x is a 2-dimensional vector holding x0 and x1. Below, x[0] stands for x0 and x[1] stands for x1.

def problem(x):
    e = 2.71828182845904590
    return x[0]**5 + e**x[1] + x[0]**3 + x[0] + x[1] - 5

def error(x):
    return (problem(x)-0)**2

As the figure shows, because x is a 2-dimensional vector, problem(x) forms a surface.

The error function shown below is then a 2-dimensional valley; the bottom of the valley is its lowest point, and that is where the solution lies. Using gradient descent, we can move x step by step toward the bottom of the valley.
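
To reproduce this picture yourself, a minimal contour-plot sketch (again assuming numpy and matplotlib are installed; the plotting range is an arbitrary choice):

import numpy as np
import matplotlib.pyplot as plt

x0s = np.linspace(-1.0, 2.0, 200)
x1s = np.linspace(-1.0, 2.0, 200)
X0, X1 = np.meshgrid(x0s, x1s)
E = (X0**5 + np.exp(X1) + X0**3 + X0 + X1 - 5)**2   # error([x0, x1]) evaluated on the grid

plt.contourf(X0, X1, E, levels=50)
plt.xlabel('x0')
plt.ylabel('x1')
plt.colorbar()
plt.show()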

Here is the "gradient descent" step:

def gradient_descent(x):
    delta = 0.00000001

    # partial derivative of error with respect to x0 (x1 held fixed)
    derivative_x0 = (error([x[0] + delta, x[1]]) - error([x[0] - delta, x[1]])) / (delta * 2)
    # partial derivative of error with respect to x1 (x0 held fixed)
    derivative_x1 = (error([x[0], x[1] + delta]) - error([x[0], x[1] - delta])) / (delta * 2)

    alpha = 0.01   # learning rate
    x[0] = x[0] - derivative_x0 * alpha
    x[1] = x[1] - derivative_x1 * alpha
    return [x[0],x[1]]

Because there are two unknowns, there are also two derivatives here: the partial derivative of error with respect to x0, and the partial derivative of error with respect to x1. What is a partial derivative? Put simply, the partial derivative of error with respect to x0 is the derivative of error at x0 while all the other variables (x1, x2, x3, ...) are held fixed.
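
As a tiny illustration (not from the original text): for f(x0, x1) = x0*x1 + x1^2, holding x1 fixed and differentiating with respect to x0 gives x1, and the same central-difference trick recovers this numerically:

def f(x0, x1):
    return x0*x1 + x1**2

delta = 0.00000001
x0, x1 = 2.0, 3.0
partial_x0 = (f(x0 + delta, x1) - f(x0 - delta, x1)) / (delta * 2)
print(partial_x0)   # approximately 3.0, i.e. the value of x1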

Apart from this, gradient descent works exactly like derivative descent.

We give x an initial value and run the gradient descent step repeatedly:

x = [0.0, 0.0]
for i in range(50):
    x = gradient_descent(x)
    print('x = {:6f},{:6f}, problem(x) = {:6f}'.format(x[0],x[1],problem(x)))

Result:

x = 0.080000, 0.160000, problem(x) = -3.585974
x = 0.153111, 0.315883, problem(x) = -3.155862
x = 0.220841, 0.465564, problem(x) = -2.709388
x = 0.283601, 0.606068, problem(x) = -2.252478
x = 0.340978, 0.733703, problem(x) = -1.798288
x = 0.391919, 0.844577, problem(x) = -1.367064
x = 0.435085, 0.935541, problem(x) = -0.982829
x = 0.469426, 1.005295, problem(x) = -0.666329
x = 0.494799, 1.055039, problem(x) = -0.427279
x = 0.512182, 1.088128, problem(x) = -0.261371
x = 0.523322, 1.108874, problem(x) = -0.154290
x = 0.530100, 1.121313, problem(x) = -0.088886
x = 0.534078, 1.128546, problem(x) = -0.050422
x = 0.536360, 1.132672, problem(x) = -0.028339
x = 0.537650, 1.134998, problem(x) = -0.015841
x = 0.538374, 1.136300, problem(x) = -0.008828
x = 0.538779, 1.137027, problem(x) = -0.004911
x = 0.539004, 1.137431, problem(x) = -0.002729
x = 0.539129, 1.137656, problem(x) = -0.001516
x = 0.539199, 1.137781, problem(x) = -0.000842
x = 0.539237, 1.137850, problem(x) = -0.000467
x = 0.539259, 1.137889, problem(x) = -0.000259
x = 0.539271, 1.137910, problem(x) = -0.000144
x = 0.539277, 1.137922, problem(x) = -0.000080
x = 0.539281, 1.137929, problem(x) = -0.000044
x = 0.539283, 1.137932, problem(x) = -0.000025
x = 0.539284, 1.137934, problem(x) = -0.000014
x = 0.539285, 1.137936, problem(x) = -0.000008
x = 0.539285, 1.137936, problem(x) = -0.000004
x = 0.539285, 1.137937, problem(x) = -0.000002
x = 0.539285, 1.137937, problem(x) = -0.000001
x = 0.539285, 1.137937, problem(x) = -0.000001
x = 0.539285, 1.137937, problem(x) = -0.000000
x = 0.539285, 1.137937, problem(x) = -0.000000

As you can see, we solved this one successfully as well.

What is a gradient?

A one-variable function f(x) has just one derivative at any given point x.

An N-variable function f(x0, x1, ..., xN-1), on the other hand, has N partial derivatives at a point x = [x0, x1, ..., xN-1]: the partial derivative of f with respect to x0, the partial derivative with respect to x1, and so on.

Collecting these N partial derivatives into an N-dimensional vector gives the gradient of f at the point x.

In textbooks the gradient is usually written with the nabla symbol ∇, so one step of gradient descent on a function f (our error function, with parameters θ) is typically written as:

θ := θ - α∇f(θ)

In Python with the numpy library, where the gradient can be manipulated as a vector, the update above can be written simply as:

x = x - gradient * alpha

Example 3: Gradient Descent with numpy

Given x0*x1 + x0*x2^2 + x1 + x1^3 + x2 + x2^5 + x3 + x3^7 - 15 = 0, solve for x0, x1, x2 and x3.

numpy is one of the most important Python libraries.

With numpy, we can implement gradient descent in very little code, using a single vector to represent all the variables.

import numpy as np
def problem(x):
    return x[0]*x[1] + x[0]*(x[2]**2) + x[1] + x[1]**3 + x[2] + x[2]**5 + x[3] + x[3]**7 - 15

def error(x):
    return (problem(x)-0)**2

def gradient_descent(x):
    delta = 0.00000001
    gradient = np.zeros(x.shape) # gradient vector

    for i in range(len(gradient)): # compute each partial derivative and store it in the gradient vector
        deltavector = np.zeros(x.shape)
        deltavector[i] = delta

        gradient[i] = (error(x+deltavector)-error(x-deltavector)) / (delta*2)

    alpha = 0.001
    x = x - gradient * alpha
    return x

x = np.array([0.0, 0.0, 0.0, 0.0])
for i in range(50):
    x = gradient_descent(x)
    print('x = {}, problem(x) = {:6f}'.format(x,problem(x)))

Result:

x = [ 0.    0.03  0.03  0.03], problem(x) = -14.909973
x = [ 0.00092144  0.05990046  0.05982007  0.05981995], problem(x) = -14.820185
x = [ 0.00280297  0.0898872   0.08946561  0.08946033], problem(x) = -14.730180
x = [ 0.00568689  0.12014423  0.11895018  0.11892079], problem(x) = -14.639463
x = [ 0.00961885  0.15085755  0.14829803  0.1482003 ], problem(x) = -14.547475
x = [ 0.01464791  0.18221879  0.17754634  0.17729741], problem(x) = -14.453574
x = [ 0.02082657  0.21442884  0.20674747  0.20621084], problem(x) = -14.357004
x = [ 0.02821105  0.24770165  0.23597107  0.2349403 ], problem(x) = -14.256859
x = [ 0.03686166  0.28226825  0.26530645  0.26348758], problem(x) = -14.152046
x = [ 0.04684326  0.31838109  0.2948653   0.29185797], problem(x) = -14.041226
x = [ 0.05822582  0.35631888  0.32478498  0.32006192], problem(x) = -13.922748
x = [ 0.07108499  0.39639176  0.35523285  0.34811695], problem(x) = -13.794551
x = [ 0.08550256  0.43894696  0.38641194  0.37604976], problem(x) = -13.654041
x = [ 0.10156685  0.48437466  0.41856863  0.40389843], problem(x) = -13.497923
x = [ 0.11937261  0.53311359  0.45200298  0.43171468], problem(x) = -13.321963
x = [ 0.13902038  0.58565548  0.4870829   0.45956607], problem(x) = -13.120673
x = [ 0.16061453  0.64254666  0.52426338  0.48753792], problem(x) = -12.886867
x = [ 0.18425932  0.70438334  0.56411284  0.5157345 ], problem(x) = -12.611030
x = [ 0.21005156  0.77179502  0.60734891  0.54427884], problem(x) = -12.280455
x = [ 0.2380674   0.8454054   0.65488617  0.57330938], problem(x) = -11.878001
x = [ 0.26833923  0.92575291  0.70789738  0.60297019], problem(x) = -11.380361
x = [ 0.30081586  1.01314031  0.76788356  0.63338792], problem(x) = -10.755628
x = [ 0.33529379  1.10736339  0.83672797  0.66462178], problem(x) = -9.959930
x = [ 0.37129848  1.20724266  0.91664437  0.69655962], problem(x) = -8.933138
x = [ 0.40787932  1.30985946  1.00973995  0.72871095], problem(x) = -7.595421
x = [ 0.44326538  1.40943656  1.11640043  0.75982431], problem(x) = -5.856840
x = [ 0.47437442  1.49615059  1.23068688  0.78731666], problem(x) = -3.697850
x = [ 0.49664095  1.55671969  1.33154595  0.80704267], problem(x) = -1.469685
x = [ 0.50642827  1.58248846  1.38457365  0.81566705], problem(x) = -0.153432
x = [ 0.50750215  1.58525614  1.39094956  0.8166065 ], problem(x) = 0.011811
x = [ 0.507419    1.58504244  1.39045048  0.81653385], problem(x) = -0.001176
x = [ 0.50742727  1.58506371  1.3905001   0.81654108], problem(x) = 0.000115
x = [ 0.50742646  1.58506163  1.39049525  0.81654037], problem(x) = -0.000011
x = [ 0.50742654  1.58506184  1.39049573  0.81654044], problem(x) = 0.000001
x = [ 0.50742654  1.58506182  1.39049568  0.81654043], problem(x) = -0.000000
x = [ 0.50742654  1.58506182  1.39049568  0.81654043], problem(x) = 0.000000
