The Rand index: how it works, with numpy and pytorch implementations
What is the Rand index?
I found that the definition of the Rand index is already summarized very well on Wikipedia, so I will not repeat it here. For the completeness of this post, and in case the site cannot be opened from inside China, I will simply act as a porter; of course, if you are able to, I still recommend reading the original article.
Rand Index
The Rand index or Rand measure (named after William M. Rand) in statistics, and in particular in data clustering, is a measure of the similarity between two data clusterings. A form of the Rand index may be defined that is adjusted for the chance grouping of elements, this is the adjusted Rand index. From a mathematical standpoint, Rand index is related to the accuracy, but is applicable even when class labels are not used.
Definition
Given a set of $n$ elements $S = \{o_1, \ldots, o_n\}$ and two partitions of $S$ to compare, $X = \{X_1, \ldots, X_r\}$, a partition of $S$ into $r$ subsets, and $Y = \{Y_1, \ldots, Y_s\}$, a partition of $S$ into $s$ subsets, define the following:
- $a$, the number of pairs of elements in $S$ that are in the same subset in $X$ and in the same subset in $Y$
- $b$, the number of pairs of elements in $S$ that are in different subsets in $X$ and in different subsets in $Y$
- $c$, the number of pairs of elements in $S$ that are in the same subset in $X$ and in different subsets in $Y$
- $d$, the number of pairs of elements in $S$ that are in different subsets in $X$ and in the same subset in $Y$
The Rand index, $R$, is:

$$R = \frac{a + b}{a + b + c + d} = \frac{a + b}{\binom{n}{2}}$$

Intuitively, $a + b$ can be considered as the number of agreements between $X$ and $Y$, and $c + d$ as the number of disagreements between $X$ and $Y$.
Since the denominator is the total number of pairs, $\binom{n}{2} = \frac{n(n-1)}{2}$, the Rand index represents the frequency of occurrence of agreements over the total pairs, or the probability that $X$ and $Y$ will agree on a randomly chosen pair.
Similarly, one can also view the Rand index as a measure of the percentage of correct decisions made by the algorithm. It can be computed using the following formula:

$$RI = \frac{TP + TN}{TP + FP + FN + TN}$$

where $TP$ is the number of true positives, $TN$ is the number of true negatives, $FP$ is the number of false positives, and $FN$ is the number of false negatives.
Properties
The Rand index has a value between 0 and 1, with 0 indicating that the two data clusterings do not agree on any pair of points and 1 indicating that the data clusterings are exactly the same.
In mathematical terms, $a$, $b$, $c$, $d$ are defined as follows:
- $a = |S^*|$, where $S^* = \{(o_i, o_j) \mid o_i, o_j \in X_k,\ o_i, o_j \in Y_l\}$
- $b = |S^*|$, where $S^* = \{(o_i, o_j) \mid o_i \in X_{k_1}, o_j \in X_{k_2},\ o_i \in Y_{l_1}, o_j \in Y_{l_2}\}$
- $c = |S^*|$, where $S^* = \{(o_i, o_j) \mid o_i, o_j \in X_k,\ o_i \in Y_{l_1}, o_j \in Y_{l_2}\}$
- $d = |S^*|$, where $S^* = \{(o_i, o_j) \mid o_i \in X_{k_1}, o_j \in X_{k_2},\ o_i, o_j \in Y_l\}$
for some $1 \le i, j \le n$, $i \ne j$, $1 \le k, k_1, k_2 \le r$, $k_1 \ne k_2$, $1 \le l, l_1, l_2 \le s$, $l_1 \ne l_2$.
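To make these definitions concrete, here is a minimal brute-force sketch (my own illustration, not part of the quoted article) that counts $a$, $b$, $c$, $d$ by enumerating every pair of elements; `x_labels` and `y_labels` are assumed to be flat cluster-label vectors.

```python
from itertools import combinations

def rand_index_bruteforce(x_labels, y_labels):
    # Enumerate all pairs of elements and sort them into a, b, c, d (O(n^2) sketch).
    a = b = c = d = 0
    for i, j in combinations(range(len(x_labels)), 2):
        same_x = x_labels[i] == x_labels[j]
        same_y = y_labels[i] == y_labels[j]
        if same_x and same_y:
            a += 1      # same subset in X and same subset in Y
        elif not same_x and not same_y:
            b += 1      # different subsets in X and in Y
        elif same_x:
            c += 1      # same subset in X, different subsets in Y
        else:
            d += 1      # different subsets in X, same subset in Y
    return (a + b) / (a + b + c + d)

# X = {{0, 1}, {2, 3}} and Y = {{0, 1, 2}, {3}}, written as label vectors
print(rand_index_bruteforce([0, 0, 1, 1], [0, 0, 0, 1]))  # 0.5
```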
Relationship with classification accuracy
The Rand index can also be viewed through the prism of binary classification accuracy over the pairs of elements in $S$. The two class labels are "$o_i$ and $o_j$ are in the same subset in $X$ and $Y$" and "$o_i$ and $o_j$ are in different subsets in $X$ and $Y$".
In that setting, $a$ is the number of pairs correctly labeled as belonging to the same subset (true positives), and $b$ is the number of pairs correctly labeled as belonging to different subsets (true negatives).
The contingency table
Given a set $S$ of $n$ elements, and two groupings or partitions (e.g. clusterings) of these elements, namely $X = \{X_1, \ldots, X_r\}$ and $Y = \{Y_1, \ldots, Y_s\}$, the overlap between $X$ and $Y$ can be summarized in a contingency table $[n_{ij}]$ where each entry $n_{ij}$ denotes the number of objects in common between $X_i$ and $Y_j$: $n_{ij} = |X_i \cap Y_j|$.

| X \ Y | $Y_1$ | $Y_2$ | … | $Y_s$ | Sums |
|---|---|---|---|---|---|
| $X_1$ | $n_{11}$ | $n_{12}$ | … | $n_{1s}$ | $a_1$ |
| $X_2$ | $n_{21}$ | $n_{22}$ | … | $n_{2s}$ | $a_2$ |
| ⋮ | ⋮ | ⋮ | ⋱ | ⋮ | ⋮ |
| $X_r$ | $n_{r1}$ | $n_{r2}$ | … | $n_{rs}$ | $a_r$ |
| Sums | $b_1$ | $b_2$ | … | $b_s$ | $N$ |
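As a small sketch of how such a table can be built in practice (my own helper, not from the quoted article, assuming integer label vectors with labels $0 \ldots r-1$ and $0 \ldots s-1$):

```python
import numpy as np

def contingency_table(x_labels, y_labels):
    # n_ij = number of elements that fall in X_i and Y_j at the same time
    x_labels = np.asarray(x_labels)
    y_labels = np.asarray(y_labels)
    T = np.zeros((x_labels.max() + 1, y_labels.max() + 1), dtype=np.int64)
    np.add.at(T, (x_labels, y_labels), 1)  # scatter-add one count per element
    return T

print(contingency_table([0, 0, 1, 1], [0, 0, 0, 1]))
# [[2 0]
#  [1 1]]
```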
Adjusted Rand index
The adjusted Rand index is the corrected-for-chance version of the Rand index. Such a correction for chance establishes a baseline by using the expected similarity of all pair-wise comparisons between clusterings specified by a random model. Traditionally, the Rand Index was corrected using the Permutation Model for clusterings (the number and size of clusters within a clustering are fixed, and all random clusterings are generated by shuffling the elements between the fixed clusters). However, the premises of the permutation model are frequently violated; in many clustering scenarios, either the number of clusters or the size distribution of those clusters vary drastically. For example, consider that in K-means the number of clusters is fixed by the practitioner, but the sizes of those clusters are inferred from the data. Variations of the adjusted Rand Index account for different models of random clusterings.
Though the Rand Index may only yield a value between 0 and +1, the adjusted Rand index can yield negative values if the index is less than the expected index.
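For reference, the commonly used permutation-model correction can be written in terms of the contingency table as

$$ARI = \frac{\sum_{ij}\binom{n_{ij}}{2} - \left[\sum_i\binom{a_i}{2}\sum_j\binom{b_j}{2}\right]/\binom{n}{2}}{\frac{1}{2}\left[\sum_i\binom{a_i}{2} + \sum_j\binom{b_j}{2}\right] - \left[\sum_i\binom{a_i}{2}\sum_j\binom{b_j}{2}\right]/\binom{n}{2}}$$

and, if scikit-learn is available, both indices can be checked against its reference implementations (`rand_score` requires scikit-learn >= 0.24). The label vectors below are made up purely for illustration:

```python
from sklearn.metrics import rand_score, adjusted_rand_score

labels_true = [0, 0, 1, 1]
labels_pred = [0, 0, 0, 1]
print(rand_score(labels_true, labels_pred))           # plain Rand index, in [0, 1]
print(adjusted_rand_score(labels_true, labels_pred))  # chance-corrected, can be negative
```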
Everything above is taken from Wikipedia; I have merely acted as a porter. As said before, if you are able to, go read the original article there.
What follows is my own plain-language explanation of the Rand index. I hope it helps a little; if there are any mistakes, please point them out. Thanks!
A binary (two-class) example
As shown in the figure above, we will use the simplest possible example to explain how the Rand index is implemented in code.
Suppose $X$ is the prediction result: $X = \{X_1, X_2\}$, with $X_1 \cup X_2 = S$, where $S$ is the universal set (all pixels).
$Y$ is the groundtruth (GT) result: $Y = \{Y_1, Y_2\}$, with $Y_1 \cup Y_2 = S$.
$X_1$ is the left circle and $Y_1$ is the right circle; we call $X_1$ and $Y_1$ the foreground, and $X_2$ and $Y_2$ the background.
Now place the two on top of each other and suppose the situation shown above arises, i.e., the foregrounds only partially overlap. The whole set $S$ is then divided into four parts: $X_1 \cap Y_1$, $X_1 \cap Y_2$, $X_2 \cap Y_1$, $X_2 \cap Y_2$.
So the contingency table ($T$) is:
| X \ Y | $Y_1$ | $Y_2$ | Sums |
|---|---|---|---|
| $X_1$ | $n_{11}$ | $n_{12}$ | $a_1$ |
| $X_2$ | $n_{21}$ | $n_{22}$ | $a_2$ |
| Sums | $b_1$ | $b_2$ | $N$ |
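To make the table concrete, here is a hypothetical instantiation (the pixel counts are invented purely for illustration and do not come from the original figure):

```python
import numpy as np

# Rows are the prediction X (X1 = foreground, X2 = background),
# columns are the ground truth Y (Y1 = foreground, Y2 = background).
T = np.array([[30, 20],    # |X1 ∩ Y1|, |X1 ∩ Y2|
              [10, 40]])   # |X2 ∩ Y1|, |X2 ∩ Y2|
print(T.sum(1))   # row sums a_i = [50 50]: pixels predicted as X1 / X2
print(T.sum(0))   # column sums b_j = [40 60]: pixels actually in Y1 / Y2
print(T.sum())    # N = 100, the total number of pixels
```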
In the figure, the arrows indicate pairs of regions:
For the four sub-regions, there are 10 kinds of pair connections in total:
We roughly divide these 10 kinds of pairs into two categories, namely the $a + b$ (agreement) pairs and the $c + d$ (disagreement) pairs, drawn in blue and red respectively.
Take the red pair (1) as an example: in $X$, its two endpoints belong to the same subset (both belong to $X_1$), but in $Y$ they do not: the left endpoint belongs to $Y_2$ and the right endpoint to $Y_1$, so they are not in the same subset. Therefore the red pair (1) counts toward the $c + d$ case. The other cases are not listed one by one; the reasoning is the same.
From the figure we can also see that counting $c + d$ is somewhat easier than counting $a + b$, so we usually rewrite the Rand index as:

$$R = 1 - \frac{c + d}{\binom{N}{2}} = 1 - \frac{c + d}{N(N - 1)/2}$$

where:

$$c + d = \sum_i \binom{a_i}{2} + \sum_j \binom{b_j}{2} - 2\sum_{i,j} \binom{n_{ij}}{2} = \frac{\sum_i a_i^2 + \sum_j b_j^2}{2} - \sum_{i,j} n_{ij}^2$$

Here $n_{ij}$ denotes the entries of the contingency table, $a_i = \sum_j n_{ij}$ the row sums, and $b_j = \sum_i n_{ij}$ the column sums. The point of the simplification above is to reach the matrix form in the last step. We could of course count $c + d$ directly from the pair definition, but as soon as there are more than two classes that approach is very inefficient (for loops are unavoidable). Once simplified to the last form, no loops are needed at all: everything reduces to matrix operations (a single line of code), which is both concise and efficient.
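A quick sanity check (my own sketch, not part of the original post) that the closed-form expression for $c + d$ agrees with brute-force pair counting on random labels:

```python
import numpy as np
from itertools import combinations

def c_plus_d_bruteforce(x_labels, y_labels):
    # Count pairs on which the two partitions disagree, directly from the definition.
    return sum(
        (x_labels[i] == x_labels[j]) != (y_labels[i] == y_labels[j])
        for i, j in combinations(range(len(x_labels)), 2)
    )

rng = np.random.default_rng(0)
x = rng.integers(0, 3, size=50)   # random partition with 3 clusters
y = rng.integers(0, 4, size=50)   # random partition with 4 clusters

T = np.zeros((3, 4))
np.add.at(T, (x, y), 1)           # contingency table
closed_form = (np.square(T.sum(0)).sum() + np.square(T.sum(1)).sum()) / 2 - np.square(T).sum()

print(closed_form, c_plus_d_bruteforce(x, y))  # the two values should match
```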
numpy implementation
```python
import numpy as np

def Rand_index_numpy(predMasks, gtMasks):
    '''
    predMasks: numpy array, prediction result; shape: [r, H, W], (r >= 1)
    gtMasks:   numpy array, ground truth;      shape: [s, H, W], (s >= 1)
    '''
    gtMasks = np.concatenate(
        [gtMasks, np.clip(1 - np.sum(gtMasks, axis=0, keepdims=True), a_min=0, a_max=1)], axis=0)
    # Append one extra class to the GT: every pixel outside the s foreground classes
    # is grouped into a single background class.
    predMasks = np.concatenate(
        [predMasks, np.clip(1 - np.sum(predMasks, axis=0, keepdims=True), a_min=0, a_max=1)], axis=0)
    # Append one extra class to the prediction in the same way: every pixel outside
    # the r foreground classes becomes one background class.
    T = (np.expand_dims(gtMasks, axis=1) * predMasks).sum(-1).sum(-1).astype(np.float32)
    # The contingency table, shape [s + 1, r + 1]
    N = T.sum()
    # The total number of pixels
    RI = 1 - ((np.power(T.sum(0), 2).sum() + np.power(T.sum(1), 2).sum()) / 2
              - np.power(T, 2).sum()) / (N * (N - 1) / 2)
    return RI
```
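A small usage sketch (the toy masks below are made up for illustration): each input is a stack of binary foreground masks, one $[H, W]$ mask per foreground class.

```python
pred = np.zeros((1, 4, 4))
pred[0, :, :2] = 1                 # predicted foreground: left half of a 4x4 image
gt = np.zeros((1, 4, 4))
gt[0, :2, :] = 1                   # ground-truth foreground: top half of the image
print(Rand_index_numpy(pred, gt))  # ≈ 0.467 for this toy case
```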
pytorch implementation
```python
import torch

def Rand_index_torch(predMasks, gtMasks):
    '''
    predMasks: Tensor, prediction result; shape: [r, H, W], (r >= 1)
    gtMasks:   Tensor, ground truth;      shape: [s, H, W], (s >= 1)
    '''
    gtMasks = torch.cat([gtMasks, torch.clamp(1 - gtMasks.sum(0, keepdim=True), min=0)], dim=0)
    # Append one extra class to the GT: every pixel outside the s foreground classes
    # is grouped into a single background class.
    predMasks = torch.cat([predMasks, torch.clamp(1 - predMasks.sum(0, keepdim=True), min=0)], dim=0)
    # Append one extra class to the prediction in the same way: every pixel outside
    # the r foreground classes becomes one background class.
    T = (gtMasks.unsqueeze(1) * predMasks).sum(-1).sum(-1).float()
    # The contingency table, shape [s + 1, r + 1]
    N = T.sum()
    # The total number of pixels
    RI = 1 - ((T.sum(0).pow(2).sum() + T.sum(1).pow(2).sum()) / 2 - T.pow(2).sum()) / (N * (N - 1) / 2)
    return RI
```
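As a quick consistency check (my own sketch, reusing `Rand_index_numpy` from above), the torch version should return the same value as the numpy version when given the same one-hot masks; the random labels below are only for illustration.

```python
labels_pred = torch.randint(0, 3, (8, 8))   # random prediction map with 3 classes
labels_gt = torch.randint(0, 3, (8, 8))     # random ground-truth map with 3 classes
pred = torch.stack([(labels_pred == k).float() for k in range(3)])  # [3, 8, 8] one-hot masks
gt = torch.stack([(labels_gt == k).float() for k in range(3)])      # [3, 8, 8] one-hot masks

print(Rand_index_torch(pred, gt))
print(Rand_index_numpy(pred.numpy(), gt.numpy()))  # should print (numerically) the same value
```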