ID3(来自西瓜书)
ID3 算法应用信息增益的大小抉择特色。信息增益 = 信息熵 – 条件熵
信息熵
例如西瓜书中的数据集,数据集有两个类别:好瓜
, 坏瓜
,即 \(k=2 \)。
则信息熵计算公式为:
条件熵
针对特色 A 的条件熵定义如下:
公式中 \(m \)为特色离散值的品种,比方所特色 色泽
为例,特色取值{青绿, 漆黑, 浅自},即 \(m=3 \)。
信息增益
信息增益 = 信息熵 – 条件熵
code
#encoding=utf-8
import math
import json
def compute_ent(data):
"""计算信息熵"""
cdict = {}
num = len(data)
for e in data:
if e[-1] in cdict:
cdict[e[-1]] += 1
else:
cdict[e[-1]] = 1
res = 0
for key, value in cdict.items():
res += - (value / num) * math.log2(value / num)
return res
def split_data_by_feat(data, feat, pos):
ndata = []
for e in data:
if e[pos] == feat:
ndata.append(e[:pos] + e[pos+1:])
return ndata
def split_features(data, num_feat):
"""寻找最佳划分特色,循环便当每一个特色,选取划分后信息增益最大的特色"""
ent = compute_ent(data)
pos = -1
best = -1
# 遍历所有特色
for i in range(num_feat):
con_ent = 0
feats = set(e[i] for e in data)
# 计算特色条件熵
for sub in feats:
ndata = split_data_by_feat(data, sub, i)
con_ent += len(ndata) / len(data) * compute_ent(ndata)
# 计算信息增益
gain = ent - con_ent
if best < gain:
best = gain
pos = i
return pos
def count(data):
cdict = {}
for e in data:
if e[0] in cdict:
cdict[e[0]] += 1
else:
cdict[e[0]] = 1
cls = ''
ccount = 0
for key, value in cdict.items():
if value > ccount:
cls = key
return cls
def ID3Tree(data, names):
classes = [e[-1] for e in data]
# 如果划分后的数据集中只有一个类别,间接返回
if len(set(classes)) == 1:
return classes[0]
# 如果依照特征值划分后数据集中有多个类别,则依照类别较多的样本定义为划分类别后果
if len(data[0]) == 1 or len(names) == 0:
return count(data)
# 找到最佳的特色
pos = split_features(data, len(names))
feat = names[pos]
tree = {feat:{}}
del(names[pos])
feat_values = set([e[pos] for e in data])
# 依照特征值划分子树
for value in feat_values:
ndata, subnames = split_data_by_feat(data, value, pos), names[:]
tree[feat][value] = ID3Tree(ndata, subnames)
return tree
if __name__ == '__main__':
names = ['色泽', '根蒂', '敲声', '纹理', '脐部', '触感']
data = [\
['青绿', '伸直', '浊响', '清晰', '凸起', '硬滑', '是'],\
['漆黑', '伸直', '爽朗', '清晰', '凸起', '硬滑', '是'],\
['漆黑', '伸直', '浊响', '清晰', '凸起', '硬滑', '是'],\
['青绿', '伸直', '爽朗', '清晰', '凸起', '硬滑', '是'],\
['浅白', '伸直', '浊响', '清晰', '凸起', '硬滑', '是'],\
['青绿', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '是'],\
['漆黑', '稍蜷', '浊响', '稍糊', '稍凹', '软粘', '是'],\
['漆黑', '稍蜷', '浊响', '清晰', '稍凹', '硬滑', '是'],\
['漆黑', '稍蜷', '爽朗', '稍糊', '稍凹', '硬滑', '否'],\
['青绿', '硬挺', '清脆', '清晰', '平坦', '软粘', '否'],\
['浅白', '硬挺', '清脆', '含糊', '平坦', '硬滑', '否'],\
['浅白', '伸直', '浊响', '含糊', '平坦', '软粘', '否'],\
['青绿', '稍蜷', '浊响', '稍糊', '凸起', '硬滑', '否'],\
['浅白', '稍蜷', '爽朗', '稍糊', '凸起', '硬滑', '否'],\
['漆黑', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '否'],\
['浅白', '伸直', '浊响', '含糊', '平坦', '硬滑', '否'],\
['青绿', '伸直', '爽朗', '稍糊', '稍凹', '硬滑', '否',]]
print(json.dumps(ID3Tree(data, names), indent=1, ensure_ascii=False))