关于机器学习:机器学习算法系列二-口袋算法Pocket-Algorithm

46次阅读

共计 1872 个字符,预计需要花费 5 分钟才能阅读完成。

浏览本文须要的背景知识点:感知器学习算法、一丢丢编程常识

一、引言

  后面一节咱们学习了机器学习算法系列(一)- 感知器学习算法(PLA),该算法能够将数据集完满的分成两种类型,但有一个前提条件就是假设数据集是线性可分的。

  在理论收集数据的过程中,可能因为各种各样的起因(例如反垃圾邮件的例子中收集的邮件单词谬误或者是人工分类谬误,将不是垃圾邮件的误认为是垃圾邮件)使得数据集中存在谬误数据,这时数据集就可能不是线性可分的,感知器学习算法是没有方法停下来的,所以人们又基于感知器学习算法设计了一个能够解决线性不可分的算法——口袋算法(Pocket Algorithm)

二、模型介绍

  口袋算法(Pocket Algorithm)是一个二元分类算法,将一个数据集通过线性组合的形式分成两种类型。如下图所示

  该算法是在感知器学习算法的根底上做的改良,其核心思想与感知器学习算法的思维统一,也是以谬误为驱动,如果以后后果比口袋中的后果好,则将口袋中的后果替换为以后后果,口袋中放弃着以后看到最好的后果,最初找到一个绝对不错的答案,因而被命名为口袋算法。

三、算法步骤

初始化向量 w,例如 w 初始化为零向量

循环 t = 0,1,2 …

  找到一个随机的谬误数据,即 h(x) 与目标值 y 不符

  $$ \operatorname{sign}\left(w_{t}^{T} x_{n(t)}\right) \neq y_{n(t)} $$
  尝试修改向量 w,如果更新后的 w 的谬误点绝对更新前的 w 更少的时,则更新 w,反之进入下一次循环。

  $$ w_{t+1} \leftarrow w_{t}+y_{n(t)} x_{n(t)} $$
直到达到设定的最大循环数时退出循环,所得的 w 即为一组方程的解

  由下面的步骤能够看到,因为不晓得什么时候循环应该停下来,所以须要人为定义一个最大的循环次数来作为退出条件,所以口袋算法绝对感知器学习算法来说,运行工夫会更慢一些。在循环中是随机选取谬误点,最初的输入后果在每次运行时也不是一个稳固的后果。

四、代码实现

应用 Python 实现口袋算法:

import numpy as np

def errorIndexes(w, X, y):
    """
    获取谬误点的下标汇合
    args:
        w - 权重系数
        X - 训练数据集
        y - 指标标签值
    return:
        errorIndexes - 谬误点的下标汇合
    """
    errorIndexes = []
    # 遍历训练数据集
    for index in range(len(X)):
        x = X[index]
        # 断定是否与目标值不符
        if x.dot(w) * y[index] <= 0:
            errorIndexes.append(index)
    return errorIndexes

def pocket(X, y, iteration, maxIterNoChange = 10):
    """
    口袋算法实现
    args:
        X - 训练数据集
        y - 指标标签值
        iteration - 最大迭代次数
        maxIterNoChange - 在提前进行之前没有晋升的迭代次数
    return:
        w - 权重系数
    """
    np.random.seed(42)
    # 初始化权重系数
    w = np.zeros(X.shape[1])
    # 获取谬误点的下标汇合
    errors = errorIndexes(w, X, y)
    iterNoChange = 0
    # 循环
    for i in range(iteration):
        iterNoChange = iterNoChange + 1
        # 随机获取谬误点下标
        errorIndex = np.random.randint(0, len(errors))
        # 计算长期权重系数
        tmpw = w + y[errors[errorIndex]] * X[errorIndex]
        # 获取长期权重系数下谬误点的下标汇合
        tmpErrors = errorIndexes(tmpw, X, y)
        # 如果谬误点数量更少,就更新权重系数
        if len(errors) >= len(tmpErrors):
            iterNoChange = 0
            # 修改权重系数
            w = tmpw
            errors = tmpErrors
        # 提前进行
        if iterNoChange >= maxIterNoChange:
            break
    return w

五、动画演示

简略训练数据集分类:

简单训练数据集分类:

六、思维导图

七、参考文献

  1. https://zh.wikipedia.org/wiki…
  2. https://www.coursera.org/lear…

残缺演示请点击这里

注:本文力求精确并通俗易懂,但因为笔者也是初学者,程度无限,如文中存在谬误或脱漏之处,恳请读者通过留言的形式批评指正

本文首发于——AI 导图 ,欢送关注

正文完
 0