ID3(来自西瓜书)
ID3算法应用信息增益的大小抉择特色。信息增益 = 信息熵 - 条件熵
信息熵
例如西瓜书中的数据集,数据集有两个类别:好瓜
, 坏瓜
,即\( k=2 \)。
则信息熵计算公式为:
条件熵
针对特色A的条件熵定义如下:
公式中\( m \)为特色离散值的品种,比方所特色色泽
为例,特色取值{青绿,漆黑,浅自},即\( m=3 \)。
信息增益
信息增益 = 信息熵 - 条件熵
code
#encoding=utf-8import mathimport jsondef compute_ent(data): """ 计算信息熵 """ cdict = {} num = len(data) for e in data: if e[-1] in cdict: cdict[e[-1]] += 1 else: cdict[e[-1]] = 1 res = 0 for key, value in cdict.items(): res += - (value / num) * math.log2(value / num) return resdef split_data_by_feat(data, feat, pos): ndata = [] for e in data: if e[pos] == feat: ndata.append(e[:pos] + e[pos+1:]) return ndatadef split_features(data, num_feat): """ 寻找最佳划分特色,循环便当每一个特色,选取划分后信息增益最大的特色 """ ent = compute_ent(data) pos = -1 best = -1 # 遍历所有特色 for i in range(num_feat): con_ent = 0 feats = set(e[i] for e in data) # 计算特色条件熵 for sub in feats: ndata = split_data_by_feat(data, sub, i) con_ent += len(ndata) / len(data) * compute_ent(ndata) # 计算信息增益 gain = ent - con_ent if best < gain: best = gain pos = i return posdef count(data): cdict = {} for e in data: if e[0] in cdict: cdict[e[0]] += 1 else: cdict[e[0]] = 1 cls = '' ccount = 0 for key, value in cdict.items(): if value > ccount: cls = key return clsdef ID3Tree(data, names): classes = [e[-1] for e in data] # 如果划分后的数据集中只有一个类别,间接返回 if len(set(classes)) == 1: return classes[0] # 如果依照特征值划分后数据集中有多个类别,则依照类别较多的样本定义为划分类别后果 if len(data[0]) == 1 or len(names) == 0: return count(data) # 找到最佳的特色 pos = split_features(data, len(names)) feat = names[pos] tree = {feat:{}} del(names[pos]) feat_values = set([e[pos] for e in data]) # 依照特征值划分子树 for value in feat_values: ndata, subnames = split_data_by_feat(data, value, pos), names[:] tree[feat][value] = ID3Tree(ndata, subnames) return treeif __name__ == '__main__': names = ['色泽', '根蒂', '敲声', '纹理', '脐部', '触感'] data = [\ ['青绿', '伸直', '浊响', '清晰', '凸起', '硬滑', '是'],\ ['漆黑', '伸直', '爽朗', '清晰', '凸起', '硬滑', '是'],\ ['漆黑', '伸直', '浊响', '清晰', '凸起', '硬滑', '是'],\ ['青绿', '伸直', '爽朗', '清晰', '凸起', '硬滑', '是'],\ ['浅白', '伸直', '浊响', '清晰', '凸起', '硬滑', '是'],\ ['青绿', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '是'],\ ['漆黑', '稍蜷', '浊响', '稍糊', '稍凹', '软粘', '是'],\ ['漆黑', '稍蜷', '浊响', '清晰', '稍凹', '硬滑', '是'],\ ['漆黑', '稍蜷', '爽朗', '稍糊', '稍凹', '硬滑', '否'],\ ['青绿', '硬挺', '清脆', '清晰', '平坦', '软粘', '否'],\ ['浅白', '硬挺', '清脆', '含糊', '平坦', '硬滑', '否'],\ ['浅白', '伸直', '浊响', '含糊', '平坦', '软粘', '否'],\ ['青绿', '稍蜷', '浊响', '稍糊', '凸起', '硬滑', '否'],\ ['浅白', '稍蜷', '爽朗', '稍糊', '凸起', '硬滑', '否'],\ ['漆黑', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '否'],\ ['浅白', '伸直', '浊响', '含糊', '平坦', '硬滑', '否'],\ ['青绿', '伸直', '爽朗', '稍糊', '稍凹', '硬滑', '否',]] print(json.dumps(ID3Tree(data, names), indent=1, ensure_ascii=False))