共计 7396 个字符,预计需要花费 19 分钟才能阅读完成。
RBF 网络
实验流程
from graphviz import Digraph
dot = Digraph(comment = 'The Round Table')
dot.node('A',"RBF 网络理论理解")
dot.node('B',"结合数据组织 RBF 网络")
dot.node('C',"代码实现")
dot.node('D',"问题以及解决方法")
dot.node('E',"感想和体会")
dot.edges(["AB","BC","CD","DE"])
dot
RBF 网络理解
- rbf 网络在我看来应该是一个可以解决函数拟合问题的网络,与之前学的的 adaline 相比,rbf 网络可以做非线性可分的问题,与 bp 网络相比,rbf 网络只有一层隐藏层,计算方便,而且不太会陷入 bp 网络中常有的局部最优的情况
- rbf 网络应该是以 cover 定理以及插值定理上建立起来的网络,在 cover 定理里,揭示了低维度不可分的样本点在高纬度可分的性质,在插值定理里,揭示了如何利用高维度可分的性质完成函数拟合的问题。
- rbf 网络实现的关键是在 rbf 函数的中心点以及宽度高度调整上,在实现代码的时候围绕这个关键点进行处理
结合数据组织 RBF 网络
- 在本题中使用的数据仍然是双月数据,所以问题的关键是在双月数据中寻找关键点
- 因为我还不是很理解如何利用全局梯度下降调整径向基函数的中心点,宽度,和高度,所以我选择了使用中心自组织的方法进行 rbf 网络的构建
代码实现
代码实现逻辑
from graphviz import Digraph
dot = Digraph(comment = 'The Round Table')
dot.node('A',"相关包导入")
dot.node('B',"导入相关数据并展示")
dot.node('C',"定义高斯函数,误差函数,以及欧拉距离的计算函数")
dot.node('D',"利用 kmeans 算法确定中心点以及宽度")
dot.node('E',"训练调整 rbf 网络高度")
dot.node('F',"决策面展示")
dot.edges(["AB","BC","CD","DE","EF"])
dot
具体代码
# 相关包导入
import numpy as np
import matplotlib.pyplot as plt
import math
LEARNING_RATE = 0.45 # 学习率
def halfmoon(rad, width, d, n_samp):
''' 生成半月数据
@param rad: 半径
@param width: 宽度
@param d: 距离
@param n_samp: 数量
'''
if n_samp%2 != 0: # 确保数量是双数
n_samp += 1
data = np.zeros((3,n_samp))
# 生成 0 矩阵, 生成 3 行 n_samp 列的的矩阵
aa = np.random.random((2,int(n_samp/2)))
radius = (rad-width/2) + width*aa[0,:]
theta = np.pi*aa[1,:]
x = radius*np.cos(theta)
y = radius*np.sin(theta)
label = np.ones((1,len(x))) # label for Class 1
x1 = radius*np.cos(-theta) + rad # 在 x 基础之上向右移动 rad 个单位
y1 = radius*np.sin(-theta) + d # 在 y 取相反数的基础之上向下移动 d 个单位
label1= 0*np.ones((1,len(x1))) # label for Class 2
data[0,:]=np.concatenate([x,x1])
data[1,:]=np.concatenate([y,y1])
data[2,:]=np.concatenate([label,label1],axis=1)
# 合并数据
return data
dataNum = 1000
data = halfmoon(10,5,5,dataNum)
np.savetxt('halfmoon.txt', data.T,fmt='%4f',delimiter=',')
# 导入相关数据并展示
def load_data(file): # 读入数据
x = []
y = []
with open(file, 'r') as f:
lines = f.readlines()
for line in lines:
line = line.strip().split(',')
x_item = [float(line[0]), float(line[1])] # x_item 存储 x,y 值
y_item = float(line[2]) # y_item 存储 z 值 (期望值)
x.append(x_item)
y.append(y_item)
return np.array(x), np.array(y)# x,y 值和 z 值 (期望值), 分别被 x 和 y 分别存储
file="E:\\data\\temp\\workspace\\net_exp3\\adaline_fix\\halfmoon.txt"
INPUTS,OUTPUTS = load_data(file) # 取得 x,y 值和 z 值 (期望值)
# 对数据进行画图
x_aixs = INPUTS[:,0]
y_aixs = INPUTS[:,1]
neg_x_axis = x_aixs[OUTPUTS==0]
neg_y_axis = y_aixs[OUTPUTS==0]
pos_x_axis = x_aixs[OUTPUTS==1]
pos_y_axis = y_aixs[OUTPUTS==1]
plt.figure()
plt.scatter(neg_x_axis,neg_y_axis,c="b",s=10)
plt.scatter(pos_x_axis,pos_y_axis,c="g",s=10)
plt.show()
# 训练数据生成(通过打乱原始数据顺序进行训练数据的生成)data = []
for i in range(len(INPUTS)):
data.append([INPUTS[i][0],INPUTS[i][1],OUTPUTS[i]])
trainData = []
tmp = []
for i in range(1000):
tmp.append(i)
np.random.shuffle(tmp)
for i in range(len(tmp)):
trainData.append(data[tmp[i]])
# 定义高斯函数,误差函数,以及欧拉距离的计算函数
def GaussianFun(center,input,variance):
distance = math.sqrt(math.pow((input[0]-center[0]),2)+math.pow((input[1]-center[1]),2))
GaussianOut = math.exp(-0.5*(1/math.pow(variance,2)*math.pow(distance,2)))
return GaussianOut
def ErrorFun(Error):
Error_sum = 0
for i in Error:
Error_sum+=math.pow(i,2)
Error_sum = Error_sum/2
return Error_sum
def Distance(center,input):
distance = math.sqrt(math.pow((input[0]-center[0]),2)+math.pow((input[1]-center[1]),2))
return distance
利用 kmeans 算法确定中心点以及宽度
在这里先通过调库生成初始的中心点
# kmeans 确定 rbf 中心点
hidGradation_nodeNum = 10 # 隐藏层节点数设置为 10 个
# 生成随机中心
centerNum = hidGradation_nodeNum # 设定的中心数
from sklearn.cluster import KMeans
clusterData = []
for i in trainData:
clusterData.append([i[0],i[1]])
clusterData = np.array(clusterData)
kmeans = KMeans(n_clusters=hidGradation_nodeNum,random_state=0).fit(clusterData)
center = []
for i in kmeans.cluster_centers_:
center.append(list(i))
# 这里确定的是,初始的中心
在调库生成初始中心点的基础之上,利用代价函数对中心进行迭代处理
$$
J(C) = \frac{1}{2}\sum^k_{j=1}\sum_{C(i)=j}\|x_i-u_j\|^2
$$
rbf 函数宽度由下列函数确定
$$
\sigma^2_j = \frac{d_{max}}{\sqrt{2*K}}
$$
# 这里计算的是,经过 k 均值优化之后的中心点
dataLabel = list(kmeans.labels_)
tmpData = []
divideData = []
for i in range(len(clusterData)):
tmpData.append([clusterData[i][0],clusterData[i][1],dataLabel[i]])
for i in range(hidGradation_nodeNum):
divideData.append([])
process = True
for item in tmpData:
divideData[item[2]-1].append([item[0],item[1]])
record = 0
while(process):
record+=1
for i in range(len(divideData)):
x = 0
y = 0
count = 0
tmpCenter = center
for item in divideData[i]:
count+=1
x+=item[0]
y+=item[1]
center[i][0] = x/count
center[i][1] = y/count
count = 0
for i in range(len(center)):
if(abs(tmpCenter[i][0]-center[i][0])==0 and abs(tmpCenter[i][1]-center[i][1])==0):
count+=1
if(count == len(center)):
process = False
for i in range(len(divideData)):
divideData[i] = []
distance = []
for i in range(len(trainData)):
distance=[]
for j in center:
distance.append(Distance(trainData[i],j))
minDistance = np.array(distance).min()
distance = list(distance)
divideData[distance.index(minDistance)].append([trainData[i][0],tmpData[i][1]])
for i in range(len(divideData)):
x = 0
y = 0
count = 0
tmpCenter = center
for item in divideData[i]:
count+=1
x+=item[0]
y+=item[1]
center[i][0] = x/count
center[i][1] = y/count
variance = []
for i in range(hidGradation_nodeNum):
variance.append(0)
key = []
for i in range(len(variance)):
for j in range(len(variance)):
key.append(Distance(center[i],center[j]))
flag = np.array(key).max()
for i in range(len(variance)):
variance[i] = flag/math.sqrt(2*hidGradation_nodeNum)
最后生成的中心点是
center
生成的 rbf 函数宽度是
variance
# 权重随机生成
weight = []
tmp = np.random.uniform(-5,5)
for i in range(10):
weight.append(tmp)
weight = np.array(weight)
del(tmp)
训练调整 rbf 网络高度
在这里的停止规则是:w 自然收敛即停止
# process = True
# while(process)
resultRecord = []
process = True
count = 0
error_count = 0
while(process):
error_record = []
count+=1
print(count,"turn")
for item in range(len(trainData)):
tmp = [trainData[item][0],trainData[item][1]]
tmp_result = 0
result = []
gaussianStore = []
for i in range(hidGradation_nodeNum):
gaussianStore.append(GaussianFun(tmp,center[i],variance[i]))
result.append(weight[i]*GaussianFun(tmp,center[i],variance[i]))
for i in result:
tmp_result += i
error = trainData[item][2]-tmp_result
error_record.append(error)
deltaWeight = np.dot(np.array(gaussianStore),LEARNING_RATE*error)
weight[:]+=deltaWeight
if(ErrorFun(deltaWeight)<0.01):
error_count+=1
if(error_count>10):
process = False
1 turn
2 turn
3 turn
4 turn
5 turn
6 turn
7 turn
8 turn
9 turn
10 turn
11 turn
决策面展示
黑色点是中心点
# 绘制分类结果图
x = []
y = []
red_xaxis = []
red_yaxis = []
green_xaxis = []
green_yaxis = []
for i in range(-15, 25):
x.append(i/1)
for i in range(-8, 13):
y.append(i/1)
for i in x:
for j in y:
tmp = [i,j]
tmp_result = 0
result = []
gaussianStore = []
for k in range(hidGradation_nodeNum):
result.append(weight[k]*GaussianFun(tmp,center[k],variance[k]))
for k in result:
tmp_result += k
if(tmp_result>0.5):
red_xaxis.append(i)
red_yaxis.append(j)
else:
green_xaxis.append(i)
green_yaxis.append(j)
centerX = []
centerY = []
for i in range(len(center)):
centerX.append(center[i][0])
centerY.append(center[i][1])
plt.figure()
plt.scatter(centerX,centerY,c="k",s=200,marker="*")
plt.scatter(green_xaxis,green_yaxis,c="g",s=10)
plt.scatter(red_xaxis,red_yaxis,c="r",s=10)
plt.scatter(neg_x_axis,neg_y_axis,c="b",s=10)
plt.scatter(pos_x_axis,pos_y_axis,c="b",s=10)
plt.show()
问题及解决方法
- 对于有些库可能还是不太理解,比如说在代码中调用的 kmean 库,其实在调用之后,再使用自己定义的代价函数去调整中心点,发现调用的库所给的中心点就已经是代价函数最小的点了,可能有些默认参数我没有吃明白,所以让我花了很多时间处理了不少步骤
- 神经网络参数有很多,首先比如说 kmean 中心点个数的确定,其次是宽度的确定,宽度计算也有两个公式可以进行计算,总的来说可以选择调整的参数有很多,有时候不知道该从何入手
- 在学习 rbf 网络的时候,因为不清楚 rbf 网络的原理,绕了很多弯子,浮于表明,在弄明白几个关键参数的调整之后才意识到这个网络其实可以在 adaline 网络的基础上进行修改,这份代码是我在以前 adaline 网络的基础上进行的修改,其实弄明白后还是很好理解的。
- 在设置的权值调整过程中,其实有误差函数选取的问题,一开始选取的是 $E = \frac{1}{2}\sum^k_{i=1}e^2_i$ 为误差函数,但是因为对这个函数理解不是特别深刻,我发现这个函数在 15 左右就停止下降了,我一开始以为是我的权值调整方法有问题,直到我把决策面画了出来,发现分类效果还可以,再回过头去思考这个函数,发现是选取这个 E 取 15,说明每一个点的误差都在 0.1 以内,这个误差还是可以接受的,关于这一点,只能说自己对误差函数的理解不足,所以造成了自己错误判断。
感想和体会
- 程序应该是数据结构 + 算法过程 + 文档,在这次的作业中,数据算是非常简单的数据,所以困扰我的只是算法过程而已,在理解了 rbf 网络的调整原理之后其实代码也不算难写,只是自己一味的害怕而已。在没有理解 rbf 网络原理之前,我在 github 上找了好几份代码,这个都是因为原理不懂,所以在网上找的代码也看不明白,也无从下手修改,理解基础之后,就决定自己编写,其实也没有花特别多时间就编写出来了。
- 网上的代码不一定适合自己使用,还是需要辩证的去看待网络上的资料
正文完