KM算法学习笔记

43次阅读

共计 2343 个字符,预计需要花费 6 分钟才能阅读完成。

二分图定义
图的顶点恰好可以分成两个集合,同一个集合内的顶点间不允许有边,处在不同集合的顶点允许有边相连。
问题分类

最大匹配问题:匈牙利算法、Hopcroft–Karp 算法
最优权值匹配问题:Kuhn-Munkras 算法

关键思想
增广路(augmenting path):假设目前已有一个匹配结果,存在一组未匹配定点的 OD,能够找到一条路径,这条路径上匹配和未匹配的定点交替出现,称为增广路
增广路上的匹配和未匹配取反,则匹配数增加 1。
KM 算法
基本思想:通过引入顶标,将最优权值匹配转化为最大匹配问题。

步骤 1:将边权值转化为顶标 / 标杆,一般来讲,初始化时,X 集合的元素取对应权重的最大值,Y 集合的元素取 0。取出满足以下条件的边,构建二分图:weight(i,j) = label(i) + label(j);该二分图称为相等子图

步骤 2:寻找增广路,从 X0 开始,找到 X0Y4;在 X1,找不到增广路,需要调整顶标,扩大相等子图;当找不到增广路径时,对于搜索过的路径上的 XY 点,设该路径上的 X 顶点集为 S,Y 顶点集为 T,对所有在 S 中的点 xi 及不在 T 中的点 yj,计算 d =min{(L(xi)+L(yj)-weight(xiyj))},从 S 集中的 X 标杆中减去 d,并将其加入到 T 集中的 Y 的标杆中;本例寻找增广路的过程中,访问了 X1、Y4、X0 三个节点,因此对应的边是 X1Y0,d 为 2(从贪心选边的角度看,我们可以为 X0 选择新的边而抛弃原先的二分子图中的匹配边,也可以为 X1 选择新的边而抛弃原先的二分子图中的匹配边,因为我们不能同时选择 X0Y4 和 X1Y4,因为这是一个不合法匹配,这个时候,d=min{(L(xi)+L(yj)-weight(xiyj))} 的意义就在于,我们选择一条新的边,这条边将被加入匹配子图中使得匹配合法,选择这条边形成的匹配子图,将比原先的匹配子图加上这条非法边组成的非法匹配子图的权重和(如果它是合法的,它将是最大的)小最少,即权重最大了);此时再次为 X1 寻找增广路,得到 X1Y0.

步骤 3:对 X2 寻找增广路,搜索范围如上图蓝色路径所示,找不到增广路,需要扩大相等子图;按照步骤 2 同一规则,会将边 X0Y2、X2Y1 加入,d=1.

步骤 4:在新的相等子图上,对 X2 重新寻找增广路。如果是深度优先,得到的路线是 X2Y0->Y0X1->X1Y4->Y4X0->X0Y2,此时将匹配结果取反,则得到 X2Y0、X1Y4、X0Y2 三个匹配;如果是宽度优先,得到的匹配结果是 X0Y4、X1Y0、X2Y1,如下图:

Python 实现
import numpy as np

# 声明数据结构
adj_matrix = build_graph() # np array with dimension N*N

# 初始化顶标
label_left = np.max(adj_matrix, axis=1) # init label for the left set
label_right = np.zeros(N) # init label for the right set

# 初始化匹配结果
match_right = np.empty(N) * np.nan

# 初始化辅助变量
visit_left = np.empty(N) * False
visit_right = np.empty(N) * False
slack_right = np.empty(N) * np.inf

# 寻找增广路,深度优先
def find_path(i):
visit_left[i] = True
for j, match_weight in enumerate(adj_matrix[i]):
if visit_right[j]: continue # 已被匹配(解决递归中的冲突)
gap = label_left[i] + label_right[j] – match_weight
if gap == 0:
# 找到可行匹配
visit_right[j] = True
if np.isnan(match_right[j]) or find_path(match_right[j]): ## j 未被匹配,或虽然 j 已被匹配,但是 j 的已匹配对象有其他可选备胎
match_right[j] = i
return True
else:
# 计算变为可行匹配需要的顶标改变量
if slack_right[j] < gap: slack_right[j] = gap
return False

# KM 主函数
def KM():
for i in range(N):
# 重置辅助变量
slack_right = np.empty(N) * np.inf
while True:
# 重置辅助变量
visit_left = np.empty(N) * False
visit_right = np.empty(N) * False

# 能找到可行匹配
if find_path(i): break
# 不能找到可行匹配,修改顶标
# (1) 将所有在增广路中的 X 方点的 label 全部减去一个常数 d
# (2) 将所有在增广路中的 Y 方点的 label 全部加上一个常数 d
d = np.inf
for j, slack in enumerate(slack_right):
if not visit_right[j] and slack < d:
d = slack
for k in range(N):
if visit_left[k]: label_left[k] -= d
for n in range(N):
if visit_right[n]: label_right[n] += d
res = 0
for j in range(N):
if match_right[j] >=0 and match_right[j] < N:
res += adj_matrix[match[j]][j]
return res
参考资料
http://blog.sina.com.cn/s/blo…
https://blog.csdn.net/mosquit…

正文完
 0