当前位置: 代码迷 >> GIS >> 机器学习-Logistic回归算法
  详细解决方案

机器学习-Logistic回归算法

热度:434   发布时间:2016-05-05 06:06:58.0
机器学习--Logistic回归算法

一、基本原理
        假设有一些数据点,用一条直线对这些点进行拟合(该线称为最佳拟合直线),这个拟合过程就称作回归。训练分类器就是为了寻找最佳拟合参数,使用的是最优化算法。      

        现实生活中有一些情况,如判断邮件是否为垃圾邮件,判断患者癌细胞为恶性的还是良性的,以及预测患有疝病的马的存活问题等,这就属于分类问题了,是线性回归无法解决的。这里以线性回归为基础,讲解Logistic回归实现分类的思想。      

         Logistic回归想要的函数是,能够接受所有的输入然后预测出类别。这个函数就是Sigmoid函数,具体计算公式如下:u(z) = 1/(1+e^(-z))。      

         当z为0时,Sigmoid函数值为0.5。随着z的增大,Sigmoid值将逼近于1;而随着z的减小,Sigmoid值将逼近于0。如果坐标刻度足够大,Sigmoid函数看起来很像一个阶跃函数。因此,为了实现Logistic回归分类器,我们可以在每个特征上都乘以一个回归系数,然后把所有的结果值相加,将这个总和代入Sigmoid函数中,进而得到一个范围在0~1之间的数值。任何大于0.5的数据被分入1类,小于0.5的数据被分入0类。所以,Logistic回归也可以看成是一种概率估计。     

         确定了分类器的函数形式后,现在的问题就变成了:最佳回归系数是多少?如何确定它的大小?

         Sigmoid函数的输入记为z,z = w(T)x,其中向量x是分类器的输入数据,w就是我们要找的最佳回归系数。为了寻找该最佳系数,需要学习最优化理论的一些知识。
 
         梯度上升法:要找到某函数的最大值,最好的方法是沿着该函数的梯度方向探寻。具体如下:从p0开始,计算完该点的梯度,函数就根据梯度移到下一个点p1。在p1点,梯度再次被重新计算,并沿新的梯度方向移到p2。如此循环迭代,直到满足停止条件。迭代过程中,梯度算子总是保证我们能选取到最佳的移动方向。

 
二、梯度上升算法流程
        每个回归系数初始化为1

        重复N次:

               计算整个数据集梯度

               使用alpha * gradient更新回归系数的向量

               返回回归系数
三、算法的特点
        优点:计算代价不高,易于理解和实现。
        缺点:容易欠拟合,分类精度可能不高。
        适用数据范围:数值型和标称型。
 
四、python代码实现

1、获取文本文件,每行前两个值为分类器输入数据,第三个值是数据对应的类别标签。为了方便计算,将分类器输入数据增加一列,设为X0,初始值为1.0。

########################################
#功能:读取文本文件
#输入变量:无
#输出变量:data_mat, label_mat 数据集,类别标签
########################################
def load_data_set():

    data_mat = []
    label_mat = []

    fr = open('testSet.txt')
    for line in fr.readlines():
        line_arr = line.strip().split()

        # 每行前两个值分别为X1和X2,X0设为1.0
        data_mat.append([1.0, float(line_arr[0]), float(line_arr[1])])

        label_mat.append(int(line_arr[2]))
    return data_mat, label_mat

 

2、Sigmoid函数

########################################
#功能:Sigmoid函数
########################################
def sigmoid(inx):
    return 1.0/(1 + exp(-inx))

 

3、梯度上升算法的具体实现

########################################
#功能:梯度上升算法
#输入变量:data_mat_in, class_labels 数据集,类别标签
#输出变量:weights 回归系数
########################################
def grad_ascent(data_mat_in, class_labels):

    data_matrix = mat(data_mat_in)  # data_mat_in为100*3
    label_matrix = mat(class_labels).transpose()  # 对矩阵进行转置,成为100*1
    m, n = shape(data_mat_in)

    alpha = 0.001  # 向目标移动的步长
    max_cycles = 500  # 迭代次数
    weights = ones((n, 1))  # 回归系数,3*1

    for k in xrange(max_cycles):

        h = sigmoid(data_matrix*weights)  # h不是一个数,是一个列向量,100*1
        error = (label_matrix - h)  # 计算梯度
        weights += alpha * data_matrix.transpose() * error

    return weights

 

4、考虑到梯度上升算法的计算复杂度太高,人们提出一种改进方法是一次仅用一个样本点来更新回归系数,该方法称为随机梯度上升算法。这种算法可以在新样本到来时,对分类器进行增量式更新。与梯度上升算法不同的是:梯度上升算法中的变量h和error都是向量,并且有矩阵转换过程。而这里对应的变量全是数值,数据处理的变量类型全是NumPy数组。

########################################
#功能:随机梯度上升算法
########################################
def rand_grad_ascent0(data_matrix, class_labels):

    m, n = shape(data_matrix)
    alpha = 0.01
    weights = ones(n)  # 1*3列的向量

    for i in xrange(m):

        h = sigmoid(sum(data_matrix[i] * weights))  # h是一个数值,不是向量
        error = class_labels[i] - h   # error是一个数值
        weights += alpha * error * data_matrix[i]

    return weights

 

由于训练集中存在一些不能正确分类的样本点,从而使得随机梯度上升算法在每次迭代时会引发系数的剧烈改变。我们期望算法能够避免来回波动,收敛到某个值。于是,对随机梯度上升算法进行了改进。

########################################
#功能:随机梯度上升优化算法
########################################
def rand_grad_ascent1(data_matrix, class_labels, num_iterator):

    m, n = shape(data_matrix)
    weights = ones(n)  # 1*3列的向量

    for j in xrange(num_iterator):

        data_index = range(m)

        for i in xrange(m):

            alpha = 4/(1.0+j+i) + 0.001
            rand_index = int(random.uniform(0, len(data_index)))
            h = sigmoid(sum(data_matrix[rand_index] * weights))
            error = class_labels[rand_index] - h
            weights += alpha * error * data_matrix[rand_index]
            del(data_index[rand_index])

    return weights

改进算法中主要做了三个方面处理:

1)alpha在每次迭代的时候都会调整,这会缓解数据波动或者高频波动。

2)通过随机选取样本来更新回归系数,这样可以减少周期性波动。

3)增加了一个迭代参数。

  相关解决方案