机器学习Rank-中使用的指标及示例代码

49次阅读

共计 1855 个字符,预计需要花费 5 分钟才能阅读完成。

作者:LogM

本文原载于 https://segmentfault.com/u/logm/articles,不允许转载~

1. P@K

P@K,代表前 K 个预测值中有多少的准确率 (Precision)。

比如,一个模型输出了一组排序,其输出的好坏依次为:好、坏、好、坏、好。

那么,

Prec@3 = 2/3

Prec@4 = 2/4

Prec@5 = 3/5

def precision(gt, pred, K):
    """ Computes the average precision.
        gt: list, ground truth, all relevant docs' index
        pred: list, prediction
    """
    hit_num = len(gt & set(pred[:K]))
    return float(1.0 * hit_num / K)

2. MAP

AP 是 average precision 的缩写,计算方式是把所有相关文档的 P@K 求平均。

借用上面的例子,一个模型输出了一组排序,依次为:好的结果、坏的结果、好的结果、坏的结果、好的结果。

那么,

AP = (1/1 + 2/3 + 3/5) / 3 = 0.76

注意,不是 (1/1 + 1/2 + 2/3 + 2/4 + 3/5) / 5。因为在实际情况中,总有一些 ” 好的结果 ” 是模型漏召回的,指标的计算公式应该怎么处理这部分漏召回?

AP 会把所有 ” 好的结果 ” 都算上,但是排序模型可能会对某些 ” 好的结果 ” 漏召回,这些 ” 漏召回 ” 的 P@K 视为 0。

MAP 是 mean average precision 的缩写,就是把所有的 AP 求平均。

def average_precision(gt, pred):
    """ Computes the average precision.
        gt: list, ground truth, all relevant docs' index
        pred: list, prediction
    """
    if not gt:
        return 0.0

    score = 0.0
    num_hits = 0.0
    for i, p in enumerate(pred):
        if p in gt and p not in pred[:i]:
            num_hits += 1.0
            score += num_hits / (i + 1.0)

    return score / max(1.0, len(gt))

3. MRR

某些场景下,每个问题只有一个标准答案,我们只关心标准答案的排序位置。此时,我们是把标准答案在模型输出结果中的排序取倒数作为它的 RR (Reciprocal Rank)。

比如,模型的输出一组排序,依次为:非标准答案,标准答案,非标准答案,非标准答案

那么:

RR = 1/2

对所有的问题取平均,得到 MRR (Mean Reciprocal Rank)。

def reciprocal_rank(gt, pred):
    """ Computes the reciprocal rank.
        gt: list, ground truth, all relevant docs' index
        pred: list, prediction
    """
    if not gt:
        return 0.0

    score = 0.0
    for rank, item in enumerate(pred):
        if item in gt:
            score = 1.0 / (rank + 1.0)
            break
    return score

4. NDCG

这块涉及到一些数学公式,我懒得打了,可以谷歌到。直接看下面的代码来理解估计会更清楚一些。

DCG: Discounted Cumulative Gain

IDCG: ideal DCG,是 ground truth 情况下的 DCG

NDCG: normalized DCG,是 DCG 和 IDCG 相除

NDCG = DCG / IDCG

def NDCG(gt, pred, use_graded_scores=False):
    """ Computes the NDCG.
        gt: list, ground truth, all relevant docs' index
        pred: list, prediction
    """
    score = 0.0
    for rank, item in enumerate(pred):
        if item in gt:
            if use_graded_scores:
                grade = 1.0 / (gt.index(item) + 1)
            else:
                grade = 1.0
            score += grade / np.log2(rank + 2)

    norm = 0.0
    for rank in range(len(gt)):
        if use_graded_scores:
            grade = 1.0 / (rank + 1)
        else:
            grade = 1.0
        norm += grade / np.log2(rank + 2)
    return score / max(0.3, norm)

正文完
 0