关于人工智能:自训练和半监督学习介绍

9次阅读

共计 5282 个字符,预计需要花费 14 分钟才能阅读完成。

作者 |Doug Steen
编译 |VK
起源 |Towards Data Science

当波及到机器学习分类工作时,用于训练算法的数据越多越好。在监督学习中,这些数据必须依据指标类进行标记,否则,这些算法将无奈学习独立变量和指标变量之间的关系。然而,在构建用于分类的大型标记数据集时,会呈现两个问题:

  1. 标记数据可能很耗时 。假如咱们有 1000000 张狗图像,咱们想将它们输出到分类算法中,目标是预测每个图像是否蕴含波士顿狗。如果咱们想将所有这些图像用于监督分类工作,咱们须要一个人查看每个图像并确定是否存在波士顿狗。
  2. 标记数据可能很低廉 。起因一:要想让人费尽心思去搜 100 万张狗狗照片,咱们可能得掏钱。

那么,这些未标记的数据能够用在分类算法中吗?

这就是半监督学习的用武之地。在半监督办法中,咱们能够在大量的标记数据上训练分类器,而后应用该分类器对未标记的数据进行预测。

因为这些预测可能比随机猜想更好,未标记的数据预测能够作为“伪标签”在随后的分类器迭代中采纳。尽管半监督学习有很多种格调,但这种非凡的技术称为自训练。

自训练

在概念层面上,自训练的工作原理如下:

步骤 1 :将标记的数据实例拆分为训练集和测试集。而后,对标记的训练数据训练一个分类算法。

步骤 2 :应用经过训练的分类器来预测所有未标记数据实例的类标签。在这些预测的类标签中,正确率最高的被认为是“伪标签”。

(第 2 步的几个变动:a)所有预测的标签能够同时作为“伪标签”应用,而不思考概率;或者 b)“伪标签”数据能够通过预测的置信度进行加权。)

步骤 3 :将“伪标记”数据与正确标记的训练数据连接起来。在组合的“伪标记”和正确标记训练数据上从新训练分类器。

步骤 4 :应用经过训练的分类器来预测已标记的测试数据实例的类标签。应用你抉择的度量来评估分类器性能。

(能够反复步骤 1 到 4,直到步骤 2 中的预测类标签不再满足特定的概率阈值,或者直到没有更多未标记的数据保留。)

好的,明确了吗?很好!让咱们通过一个例子解释。

示例:应用自训练改良分类器

为了演示自训练,我应用 Python 和 surgical_deepnet 数据集,能够在 Kaggle 上找到:https://www.kaggle.com/omnama…

此数据集用于二分类,蕴含 14.6k+ 手术的数据。这些属性是 bmi、年龄等各种测量值,而指标变量 complexing 则记录患者是否因手术而呈现并发症。显然,可能精确地预测患者是否会因手术而呈现并发症,这对医疗保健和保险供应商都是最无利的。

导入库

对于本教程,我将导入 numpy、pandas 和 matplotlib。我还将应用 sklearn 中的 LogisticRegression 分类器,以及用于模型评估的 f1_score 和 plot_confusion_matrix 函数

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.linear_model import LogisticRegression

from sklearn.metrics import f1_score
from sklearn.metrics import plot_confusion_matrix

加载数据

# 加载数据

df = pd.read_csv('surgical_deepnet.csv')
df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 14635 entries, 0 to 14634
Data columns (total 25 columns):
bmi                    14635 non-null float64
Age                    14635 non-null float64
asa_status             14635 non-null int64
baseline_cancer        14635 non-null int64
baseline_charlson      14635 non-null int64
baseline_cvd           14635 non-null int64
baseline_dementia      14635 non-null int64
baseline_diabetes      14635 non-null int64
baseline_digestive     14635 non-null int64
baseline_osteoart      14635 non-null int64
baseline_psych         14635 non-null int64
baseline_pulmonary     14635 non-null int64
ahrq_ccs               14635 non-null int64
ccsComplicationRate    14635 non-null float64
ccsMort30Rate          14635 non-null float64
complication_rsi       14635 non-null float64
dow                    14635 non-null int64
gender                 14635 non-null int64
hour                   14635 non-null float64
month                  14635 non-null int64
moonphase              14635 non-null int64
mort30                 14635 non-null int64
mortality_rsi          14635 non-null float64
race                   14635 non-null int64
complication           14635 non-null int64
dtypes: float64(7), int64(18)
memory usage: 2.8 MB

数据集中的属性都是数值型的,没有缺失值。因为我这里的重点不是数据清理,所以我将持续对数据进行划分。

数据划分

为了测试自训练的成果,我须要将数据分成三局部:训练集、测试集和未标记集。我将按以下比例拆分数据:

  • 1% 训练
  • 25% 测试
  • 74% 未标记

对于未标记集,我将简略地放弃指标变量 complexing,并伪装它从未存在过。

所以,在这个病例中,咱们认为 74% 的手术病例没有对于并发症的信息。我这样做是为了模仿这样一个事实:在理论的分类问题中,可用的大部分数据可能没有类标签。然而,如果咱们有一小部分数据的类标签(在本例中为 1%),那么能够应用半监督学习技术从未标记的数据中得出结论。

上面,我随机化数据,生成索引来划分数据,而后创立测试、训练和未标记的划分。而后我查看各个集的大小,确保所有都按计划进行。

X_train dimensions: (146, 24)
y_train dimensions: (146,)

X_test dimensions: (3659, 24)
y_test dimensions: (3659,)

X_unlabeled dimensions: (10830, 24)

类散布

少数类的样本数((并发症))是少数类(并发症)的两倍多。在这样一个不均衡的类的状况下,我想准确度可能不是最佳的评估指标。

抉择 F1 分数作为分类指标来判断分类器的有效性。F1 分数对类别不均衡的影响比准确度更为持重,当类别近似均衡时,这一点更为适合。F1 得分计算如下:

其中 precision 是预测正例中正确预测的比例,recall 是实在正例中正确预测的比例。

初始分类器(监督)

为了使半监督学习的后果更实在,我首先应用标记的训练数据训练一个简略的 Logistic 回归分类器,并对测试数据集进行预测。

Train f1 Score: 0.5846153846153846
Test f1 Score: 0.5002908667830134

分类器的 F1 分数为 0.5。混同矩阵通知咱们,分类器能够很好地预测没有并发症的手术,准确率为 86%。然而,分类器更难正确辨认有并发症的手术,准确率只有 47%。

预测概率

对于自训练算法,咱们须要晓得 Logistic 回归分类器预测的概率。侥幸的是,sklearn 提供了.predict_proba() 办法,它容许咱们查看属于任一类的预测的概率。如下所示,在二元分类问题中,每个预测的总概率总和为 1.0。

array([[0.93931367, 0.06068633],
       [0.2327203 , 0.7672797],
       [0.93931367, 0.06068633],
       ...,
       [0.61940353, 0.38059647],
       [0.41240068, 0.58759932],
       [0.24306008, 0.75693992]])

自训练分类器(半监督)

既然咱们晓得了如何应用 sklearn 取得预测概率,咱们能够持续编码自训练分类器。以下是简要概述:

第 1 步 :首先,在标记的训练数据上训练 Logistic 回归分类器。

第 2 步 :接下来,应用分类器预测所有未标记数据的标签,以及这些预测的概率。在这种状况下,我只对概率大于 99% 的预测采纳“伪标签”。

第 3 步 :将“伪标记”数据与标记的训练数据连接起来,并在连贯的数据上从新训练分类器。

第 4 步 :应用训练好的分类器对标记的测试数据进行预测,并对分类器进行评估。

反复步骤 1 到 4,直到没有更多的预测具备大于 99% 的概率,或者没有未标记的数据保留。

上面的代码应用 while 循环在 Python 中实现这些步骤。

Iteration 0
Train f1: 0.5846153846153846
Test f1: 0.5002908667830134
Now predicting labels for unlabeled data...
42 high-probability predictions added to training data.
10788 unlabeled instances remaining.

Iteration 1
Train f1: 0.7627118644067796
Test f1: 0.5037463976945246
Now predicting labels for unlabeled data...
30 high-probability predictions added to training data.
10758 unlabeled instances remaining.

Iteration 2
Train f1: 0.8181818181818182
Test f1: 0.505431675242996
Now predicting labels for unlabeled data...
20 high-probability predictions added to training data.
10738 unlabeled instances remaining.

Iteration 3
Train f1: 0.847457627118644
Test f1: 0.5076835515082526
Now predicting labels for unlabeled data...
21 high-probability predictions added to training data.
10717 unlabeled instances remaining.

...
Iteration 44
Train f1: 0.9481216457960644
Test f1: 0.5259179265658748
Now predicting labels for unlabeled data...
0 high-probability predictions added to training data.
10079 unlabeled instances remaining.

自训练算法通过 44 次迭代,就不能以 99% 的概率预测更多的未标记实例了。即便一开始有 10,830 个未标记的实例,在自训练之后依然有 10,079 个实例未标记 (并且未被分类器应用)。

通过 44 次迭代,F1 的分数从 0.50 进步到 0.525!尽管这只是一个小的增长,但看起来自训练曾经改善了分类器在测试数据集上的性能。上图的顶部面板显示,这种改良大部分产生在算法的晚期迭代中。同样,底部面板显示,增加到训练数据中的大多数“伪标签”都是在前 20-30 次迭代中呈现的。

最初的混同矩阵显示有并发症的手术分类有所改善,但没有并发症的手术分类略有降落。有了 F1 分数的进步,我认为这是一个能够承受的提高 - 可能更重要的是确定会导致并发症的手术病例(真正例),并且可能值得减少假正例率来达到这个后果。

正告语

所以你可能会想:用这么多未标记的数据进行自训练有危险吗?答案当然是必定的。请记住,只管咱们将“伪标记”数据与标记的训练数据一起蕴含在内,但某些“伪标记”数据必定会不正确。当足够多的“伪标签”不正确时,自训练算法会强化蹩脚的分类决策,而分类器的性能实际上会变得更糟。

能够应用分类器在训练期间没有看到的测试集,或者应用“伪标签”预测的概率阈值,能够加重这种危险。

原文链接:https://towardsdatascience.co…

欢送关注磐创 AI 博客站:
http://panchuang.net/

sklearn 机器学习中文官网文档:
http://sklearn123.com/

欢送关注磐创博客资源汇总站:
http://docs.panchuang.net/

正文完
 0