ConcurrentHashMap源码详解
1. ConcurrentHashMap概述
ConcurrentHashMap是线程安全的哈希表,不同于HashTable,后者在方法上增加synchronized关键字,利用对象同步锁实现线程之间的同步。显然,HashTable实现线程安全的方式太“重”,并发度高的情况下,很多线程争用同一把锁,吞吐量较低。
ConcurrentHashMap通过锁分段技术,只有在同一个段内,才会存在锁竞争,提高了并发处理能力。它的内部数据结构其实是一个Segment数组,该数组的大小代表了ConcurrentHashMap的并发度,Segment同时也是一把可重入锁,该锁用来确保该段数据并发访问的线程安全。每一个Segment其实是一个类似于HashMap的哈希表,用来存储key-value。看下ConcurrentHashMap结构图:
ConcurrentHashMap维护了一个Segment数组segments,每个Segment是一个哈希表。当线程需要访问segments[1]处的哈希表,首先需要获取该段的锁,然后才能访问该段的哈希表。上图中segments数组大小为8,因此并发度为8,最多支持8个线程在不同的段同时访问。
2. HashEntry
HashEntry代表了哈希表的一个key-value项,它是ConcurrentHashMap的一个内部静态类,看下HashEntry的数据结构:
static final class HashEntry<K,V> {
final int hash;
final K key;
volatile V value;
volatile HashEntry<K,V> next;
HashEntry(int hash, K key, V value, HashEntry<K,V> next) {
this.hash = hash;
this.key = key;
this.value = value;
this.next = next;
}
//……
}
HashEntry数据结构也很简单,它是一个单链表结点,每个结点包括了key-value对、哈希值、指向下一个节点的引用。
3. Segment
ConcurrentHashMap最重要的概念就是Segment了,它是一个有锁功能的(继承了ReentrantLock)哈希表,ConcurrentHashMap正是由Segment数组组成的数据结构。
看下Segment的类声明:
static final class Segment<K,V> extends ReentrantLock implements Serializable
Segment通过继承ReentrantLock拥有了锁的功能。
接着看下Segment的几个成员变量:
//获取锁失败后的尝试次数,和机器可用的cpu核数量有关
static final int MAX_SCAN_RETRIES =
Runtime.getRuntime().availableProcessors() > 1 ? 64 : 1;
//哈希表,一个segment对应一个哈希表
transient volatile HashEntry<K,V>[] table;
//哈希表kv元素的个数,注意:ConcurrentHashMap的元素数量是所有segment的元素数量之和
transient int count;
//哈希表改变的次数
transient int modCount;
//哈希表重哈希的阀值,元素数量超过这个值,需要扩充哈希表,否则哈希冲突会增加
transient int threshold;
//加载因子
final float loadFactor;
这几个变量我们之前在学习HashMap的时候基本上都学习过,看**释就可以了。
看下Segment唯一的一个构造方法:
Segment(float lf, int threshold, HashEntry<K,V>[] tab) {
this.loadFactor = lf;
this.threshold = threshold;
this.table = tab;
}
Segment没有默认构造方法。
接着看下Segment的put方法:
final V put(K key, int hash, V value, boolean onlyIfAbsent) {
HashEntry<K,V> node = tryLock() ? null :
scanAndLockForPut(key, hash, value);
//到这,一定成功获取锁了。
//注意,此时node可能为空,也可能不为空。如果为空,接下来put的时候需要创建一个新的结点,如果不为空
//可以直接使用该节点。
//返回key对应老的value值
V oldValue;
try {
HashEntry<K,V>[] tab = table;
//定位到HashEntry索引
int index = (tab.length - 1) & hash;
HashEntry<K,V> first = entryAt(tab, index);
for (HashEntry<K,V> e = first;;) {
if (e != null) {
K k;
//如果key已经存在,更新对应的value,跳出for循环
if ((k = e.key) == key ||
(e.hash == hash && key.equals(k))) {
oldValue = e.value;
if (!onlyIfAbsent) {
e.value = value;
++modCount;
}
break;
}
e = e.next;
}
else {
//node不为空,将node插入到链表的头部
if (node != null)
node.setNext(first);
else
//创建一个新的节点并插入到链表的头部
node = new HashEntry<K,V>(hash, key, value, first);
int c = count + 1;
//元素数量超过阀值,重哈希
if (c > threshold && tab.length < MAXIMUM_CAPACITY)
rehash(node);
else
//更新哈希表table处索引index处的值为node,每次插入都是插入到链表的头部
setEntryAt(tab, index, node);
++modCount;
count = c;
oldValue = null;
break;
}
}
} finally {
//释放锁
unlock();
}
return oldValue;
}
该方法将指定的key-value添加到哈希表,如果key已经存在,更新对应的value值,否则创建一个新的节点,加入到哈希表。基本思路很简单,但是put之前加锁操作比较复杂。
put方法开始的时候,尝试获取锁,如果获取锁不成功,调用scanAndLockForPut方法,这个方法在尝试获取锁失败的情况下,如果key对应的节点存在,返回null,否则为后续put操作创建一个新的节点,该方法返回之前一定成功获取到锁。注意,虽然scanAndLockForPut方法在发现key对应的节点存在的情况下,返回了null,put方法还是会判断该key对应的节点是否存在,如果存在则更新value。
如果插入节点后的元素个数大于threshold,需要对该哈希表重哈希,重哈希后的哈希表容量是原来的2倍。
Segment的重哈希过程做了一个优化,找到该segment的HashEntry链表的某个元素lastIdx,使得从该元素开始到链表末尾的所有元素在新哈希表相同的桶中。这样,只需要将该元素之前的元素一个个的添加到新的链表即可,一定程度上复用了原来同一个槽上的部分节点,看下示意图:
上图中红色节点在新哈希表的位置相同,直接复用这几个节点。红色节点之前的元素需要一个个的添加到新的哈希表中。原理介绍完了,看下代码的实现:
private void rehash(HashEntry<K,V> node) {
HashEntry<K,V>[] oldTable = table;
int oldCapacity = oldTable.length;
//新容量扩充为原来的2倍
int newCapacity = oldCapacity << 1;
threshold = (int)(newCapacity * loadFactor);
HashEntry<K,V>[] newTable =
(HashEntry<K,V>[]) new HashEntry[newCapacity];
int sizeMask = newCapacity - 1;
for (int i = 0; i < oldCapacity ; i++) {
HashEntry<K,V> e = oldTable[i];
if (e != null) {
HashEntry<K,V> next = e.next;
int idx = e.hash & sizeMask;
//只有一个节点的链表,直接放到新表即可
if (next == null)
newTable[idx] = e;
else {
//这段代码就是找到上图中lastIdx的节点
HashEntry<K,V> lastRun = e;
int lastIdx = idx;
for (HashEntry<K,V> last = next;
last != null;
last = last.next) {
int k = last.hash & sizeMask;
if (k != lastIdx) {
lastIdx = k;
lastRun = last;
}
}
//将lastIdx和后面的节点放到新的哈希表
newTable[lastIdx] = lastRun;
//lastIdx之前的节点一个个加到新的哈希表
for (HashEntry<K,V> p = e; p != lastRun; p = p.next) {
V v = p.value;
int h = p.hash;
int k = h & sizeMask;
HashEntry<K,V> n = newTable[k];
newTable[k] = new HashEntry<K,V>(h, p.key, v, n);
}
}
}
}
//新添加的节点node放到链表头部
int nodeIndex = node.hash & sizeMask;
node.setNext(newTable[nodeIndex]);
newTable[nodeIndex] = node;
table = newTable;
}
接着看下Segment的remove方法:
final V remove(Object key, int hash, Object value) {
//删除元素之前要获取锁
if (!tryLock())
scanAndLock(key, hash);
V oldValue = null;
try {
HashEntry<K,V>[] tab = table;
int index = (tab.length - 1) & hash;
HashEntry<K,V> e = entryAt(tab, index);
HashEntry<K,V> pred = null;
while (e != null) {
K k;
HashEntry<K,V> next = e.next;
if ((k = e.key) == key ||
(e.hash == hash && key.equals(k))) {
V v = e.value;
if (value == null || value == v || value.equals(v)) {
//value为空,只要key相同就删除,value不为空,要比较value和v是否相同
if (pred == null)
setEntryAt(tab, index, next);
else
pred.setNext(next);
++modCount;
--count;
oldValue = v;
}
break;
}
pred = e;
e = next;
}
} finally {
unlock();
}
return oldValue;
}
删除操作很简单,注意一点,当入参value为null,只要key相同就删除,否则需要比较value和当前节点的值是否相同。
删除之前需要获取锁,如果通过tryLock获取锁失败,调用scanAndLock获取锁。scanAndLock通过tryLock尝试获取次数越过MAX_SCAN_RETRIES,则调用lock方法阻塞等待锁,显然,lock方法将引起线程上下文切换,增加额外开销。
Segment还有两个replace的重载方法和一个clear方法,代码逻辑都很简单,不再说明了。
接着看下ConcurrentHashMap的put方法:
4. put
该方法将指定key-value对添加到哈希表中,其中key和value都不能为null,看下源码:
public V put(K key, V value) {
Segment<K,V> s;
//value不能为null
if (value == null)
throw new NullPointerException();
//如果key为null,hash方法会抛出NPE异常
int hash = hash(key);
int j = (hash >>> segmentShift) & segmentMask;
if ((s = (Segment<K,V>)UNSAFE.getObject
(segments, (j << SSHIFT) + SBASE)) == null)
s = ensureSegment(j);
//委托给了segment的put方法
return s.put(key, hash, value, false);
}
通过hash方法得到键key的哈希值,将该哈希值右移segmentShift位后和segmentMask执行“与”操作,得到key对应的segment索引。
看下segmentShift和segmentMask,构造函数中初始化了这两个值:
this.segmentShift = 32 - sshift;
this.segmentMask = ssize - 1;
其中ssize是segments数组的大小,它是2的n次方,例如16、32、64等。sshift是数字n的大小,例如4、5、6等。获取key对应的segment段索引时,其实是通过键key哈希码的高sshift位来决定segment索引的。put方法最终委托给了segment的put方法,真正执行添加操作。
5. putAll
putAll将指定的Map添加到该ConcurrentHashMap,看下源码:
public void putAll(Map<? extends K, ? extends V> m) {
for (Map.Entry<? extends K, ? extends V> e : m.entrySet())
put(e.getKey(), e.getValue());
}
源码看简单,遍历指定的map,通过put方法一个个的添加。
6. get
该方法取key对应的value值,源码看着比较复杂:
public V get(Object key) {
Segment<K,V> s;
HashEntry<K,V>[] tab;
//取键的哈希码
int h = hash(key);
//segment的内存偏移
long u = (((h >>> segmentShift) & segmentMask) << SSHIFT) + SBASE;
if ((s = (Segment<K,V>)UNSAFE.getObjectVolatile(segments, u)) != null &&
(tab = s.table) != null) {
//找到key所在HashEntry链表
for (HashEntry<K,V> e = (HashEntry<K,V>) UNSAFE.getObjectVolatile
(tab, ((long)(((tab.length - 1) & h)) << TSHIFT) + TBASE);
e != null; e = e.next) {
K k;
if ((k = e.key) == key || (e.hash == h && key.equals(k)))
//找到了,返回对应的value
return e.value;
}
}
//没有找到,返回null
return null;
}
7. containsKey
判断是否存在指定的key,看下源码:
public boolean containsKey(Object key) {
Segment<K,V> s;
HashEntry<K,V>[] tab;
//计算key对应的哈希码
int h = hash(key);
//计算key对应的segment内存偏移
long u = (((h >>> segmentShift) & segmentMask) << SSHIFT) + SBASE;
if ((s = (Segment<K,V>)UNSAFE.getObjectVolatile(segments, u)) != null &&
(tab = s.table) != null) {
//找到key所在的桶,然后沿着链表查找
for (HashEntry<K,V> e = (HashEntry<K,V>) UNSAFE.getObjectVolatile
(tab, ((long)(((tab.length - 1) & h)) << TSHIFT) + TBASE);
e != null; e = e.next) {
K k;
//找到了就返回true
if ((k = e.key) == key || (e.hash == h && key.equals(k)))
return true;
}
}
return false;
}
containsKey方法逻辑上也很简单,首先通过key的哈希码,找到所在的segment,然后找到该key在该segment所在的桶,在这个桶的链表上查找该key,如果找到,返回true,否则返回false。
8. containsValue
该方法查找ConcurrentHashMap是否存在指定的value,如果存在key对应的value为该值,返回true,否则返回false。注意,该方法无法快速定位到segment和桶,只能整个遍历ConcurrentHashMap并比较value值,因此相对于containsKey,该方法就显得很慢了。
9. size
该方法返回ConcurrentHashMap的key-value对数量。将每个segment的key-value对数量相加,如果相加后发现modCount和上次保存的modCount不一样,说明相加过程中有线程修改了ConcurrentHashMap,为了获取准确的size,需要重试。如果重试次数超过指定的次数,锁住所有的segment,然后再执行相加操作,确保相加过程中没有线程能够修改。
看下源码:
public int size() {
final Segment<K,V>[] segments = this.segments;
int size;
boolean overflow;
long sum;
long last = 0L;
//重试的次数
int retries = -1;
try {
//外循环,直到本次的sum和上次的sum相同为止,本次sum和上次sum相同,
//说明计算过程中,没有线程修改(改变modCount),计算的元素个数一定是准确的
for (;;) {
//重试超过指定次数(默认为2),将所有的segment都锁住,防止计算size
//过程被线程修改,元素个数计算完成后再解锁
if (retries++ == RETRIES_BEFORE_LOCK) {
for (int j = 0; j < segments.length; ++j)
ensureSegment(j).lock();
}
//保存所有segment的修改次数
sum = 0L;
//元素个数
size = 0;
overflow = false;
//遍历所有的segment,累加每个segment的元素个数
for (int j = 0; j < segments.length; ++j) {
Segment<K,V> seg = segmentAt(segments, j);
if (seg != null) {
sum += seg.modCount;
int c = seg.count;
if (c < 0 || (size += c) < 0)
overflow = true;
}
}
//相等,说明计算过程中没有线程进行修改操作,计算结果是正确的,跳出外循环,否则继续重试计算
if (sum == last)
break;
last = sum;
}
} finally {
//重试超过指定次数,将所有的segment解锁
if (retries > RETRIES_BEFORE_LOCK) {
for (int j = 0; j < segments.length; ++j)
segmentAt(segments, j).unlock();
}
}
//size溢出,返回Inter.MAX_VALUE
return overflow ? Integer.MAX_VALUE : size;
}
10. isEmpty
该方法判断ConcurrentHashMap的元素个数是否为0。遍历每个segment,若当前segment的元素个数不为0,返回false。否则将每个segment的modCount累加,结果计为sum。若sum等于0,说明遍历过程中元素个数未变,返回true。否则继续第二次遍历segment,同样的道理,遍历过程中若发现当前segment的元素个数不为0,返回false,否则将sum减去当前segment的modCount,若遍历结束后sum不等于0,说明第二次遍历过程,元素个数有改变,返回false。
看下源码:
public boolean isEmpty() {
long sum = 0L;
final Segment<K,V>[] segments = this.segments;
//第一次遍历segments
for (int j = 0; j < segments.length; ++j) {
Segment<K,V> seg = segmentAt(segments, j);
if (seg != null) {
//若当前segment的元素个数不等于0,说明不为空,返回false
if (seg.count != 0)
return false;
//累加modCount
sum += seg.modCount;
}
}
if (sum != 0L) {
//第二次遍历
for (int j = 0; j < segments.length; ++j) {
Segment<K,V> seg = segmentAt(segments, j);
if (seg != null) {
//同样的道理,若当前segment的元素个数不等于0,说明不为空,返回false
if (seg.count != 0)
return false;
//减去当前的modCount
sum -= seg.modCount;
}
}
//说明第二次遍历过程中元素个数有改变,认为不为空
if (sum != 0L)
return false;
}
return true;
}
为何要进行两次遍历操作?正常情况下应该将所有的segment锁住,遍历所有的segment,累加元素个数,判断是否为0,然后再解锁。但是为了减少加锁对性能的影响,采用两次操作来判断是否为空。这种方法可以避免加锁的性能影响,但是也会失去100%的正确性,某些情况下,该方法返回true并不真正意味着该Map为空。例如第一次遍历过程中,当遍历到第二个segment,有其他线程已经往第一个segment添加了元素,但是我们遍历第一个segment的时候,该segment的modCount为0,第一次遍历结束后我们可能得到sum等于0,返回了true。但是第一个segment的的元素个数已经不为0了。这也是ConcurrentHashMap的弱一致性表现,为了性能,这种折衷和妥协也是可以理解的。
11. 迭代器
ConcurrentHashMap的迭代器实现了Iterator接口,并继承了内部抽象类HashIterator。迭代器的功能都委托给了HashIterator。看下HashIterator源码:
abstract class HashIterator {
int nextSegmentIndex;
int nextTableIndex;
HashEntry<K,V>[] currentTable;
HashEntry<K, V> nextEntry;
HashEntry<K, V> lastReturned;
HashIterator() {
nextSegmentIndex = segments.length - 1;
nextTableIndex = -1;
advance();
}
//将nextEntry指向非空的HashEntry节点
final void advance() {
for (;;) {
if (nextTableIndex >= 0) {
if ((nextEntry = entryAt(currentTable,
nextTableIndex--)) != null)
break;
}
else if (nextSegmentIndex >= 0) {
Segment<K,V> seg = segmentAt(segments, nextSegmentIndex--);
if (seg != null && (currentTable = seg.table) != null)
nextTableIndex = currentTable.length - 1;
}
else
break;
}
}
//返回下一个节点,如果返回节点的下一个节点为空,需要调用advance找到下一个非空的节点
final HashEntry<K,V> nextEntry() {
HashEntry<K,V> e = nextEntry;
if (e == null)
throw new NoSuchElementException();
lastReturned = e;
//下一个节点为空,找到下一个非空的节点
if ((nextEntry = e.next) == null)
advance();
return e;
}
public final boolean hasNext() { return nextEntry != null; }
public final boolean hasMoreElements() { return nextEntry != null; }
//删除操作委托给了外部类的remove方法,注意删除节点后需要将lastReturned设置为null
public final void remove() {
if (lastReturned == null)
throw new IllegalStateException();
ConcurrentHashMap.this.remove(lastReturned.key);
lastReturned = null;
}
}
这里需要注意的是,构造函数和nextEntry方法中,需要确保下一个节点非空,如果为空,说明迭代器遍历结束了。
ConcurrentHashMap的迭代器,如KeyIterator、ValueIterator、EntryIterator等都继承了HashEntry,并委托给了HashIterator,方法都很简单,不再说明。
参考源码:jdk1.7.0_79
上一篇: JAVA筛选法求N以内的质数
下一篇: Mxnet练习: ResNet网络