由于 JDK 1.7 的 ConcurrentHashMap 的实现使用到了 ReentrantLock,刚好前面也已经看过了 ReentrantLock 的实现,所以顺势看下 1.7 版本的 ConcurrentHashMap 是如何实现的。
一. 简介
与 Hashtable 使用 synchronized 关键字对读写方法进行加锁不同,JDK 1.7 的 ConcurrentHashMap(以下简称 ConcurrentHashMap)使用了分段锁的思想,它将数据散列成一个 Segment 数组,每个 Segment 对象实际类似一个 HashMap,它们各自持有一个 HashEntry 数组,实际的键值对数据是存放到 HashEntry 数组中的,如果发生哈希冲突,则通过在 HashEntry 对象上构建链表存放哈希值一样的键值对。
特别地,Segment 继承自 ReentrantLock,说明一个 Segment 就是一把锁,每个想要往 ConcurrentHashMap 写数据的线程都需要拿到相应分段的锁才行,而拿不到锁的线程,则需要重试与自旋,期间可以预先构建好需要的节点信息,然后如果有限次尝试后还是没能拿到锁,则会加入等待队列然后进行阻塞。
二. 构造函数
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
|
public ConcurrentHashMap() {
this(DEFAULT_INITIAL_CAPACITY, DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL);
}
/**
* The default initial capacity for this table,
* used when not otherwise specified in a constructor.
*/
static final int DEFAULT_INITIAL_CAPACITY = 16;
/**
* The default load factor for this table, used when not
* otherwise specified in a constructor.
*/
static final float DEFAULT_LOAD_FACTOR = 0.75f;
/**
* The default concurrency level for this table, used when not
* otherwise specified in a constructor.
*/
static final int DEFAULT_CONCURRENCY_LEVEL = 16;
|
归根到底都是调用下面的这个构造方法,主要以默认的入参进行分析
- 检查入参的合法性
concurrencyLevel
与 Segment
数组的初始长度的计算有关,因为 Segment
数组的长度 ssize
必须是 2 的次幂,所以在计算 ssize
是会进行控制,如输入 concurrencyLevel
为 16,则 ssize
会从 1 开始右移直到确保值不比 concurrencyLevel
小,所以 sshift
为 4,ssize
等于 16,但如果 concurrencyLevel
等于 17,则 sshift
为 5,ssize
等于 32。
segmentShift
赋值为 32 - 4 = 28,segmentMask
赋值为 16 - 1= 15。
initialCapacity
为总共能够存放的键值对数,均摊到每个 Segment 中,则能知道每个 Segment 至少应该存放多少个,然后由于传进来的 initialCapacity
可能很大,concurrencyLevel
比较小,为了保证每个 Segment 都能有足够的数组空间存放键值对,避免哈希冲突,需要对 c 进行向上取整,然后为了保证 HashEntry
数组的长度是 2 的次幂,需要对 cap 进行左移运算。
- 然后创建一个 Segment 数组
ss
,长度为 16,另外还创建了一个 Segment 对象 s0
,它的 loadFactor
为 0.75,扩容阈值 threshold
为 12,创建 HashEntry
数组且长度为 2,然后通过 Unsafe 操作内存,将 s0 放到 ss 的起始地址上,刚好占满第一个索引位置,即 s0 = ss[0]。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
|
public ConcurrentHashMap(int initialCapacity,
float loadFactor, int concurrencyLevel) {
if (!(loadFactor > 0) || initialCapacity < 0 || concurrencyLevel <= 0)
throw new IllegalArgumentException();
if (concurrencyLevel > MAX_SEGMENTS)
concurrencyLevel = MAX_SEGMENTS;
// Find power-of-two sizes best matching arguments
int sshift = 0;
int ssize = 1;
while (ssize < concurrencyLevel) {
++sshift;
ssize <<= 1;
}
this.segmentShift = 32 - sshift;
this.segmentMask = ssize - 1;
if (initialCapacity > MAXIMUM_CAPACITY)
initialCapacity = MAXIMUM_CAPACITY;
int c = initialCapacity / ssize;
if (c * ssize < initialCapacity)
++c;
int cap = MIN_SEGMENT_TABLE_CAPACITY;
while (cap < c)
cap <<= 1;
// create segments and segments[0]
Segment<K,V> s0 =
new Segment<K,V>(loadFactor, (int)(cap * loadFactor),
(HashEntry<K,V>[])new HashEntry[cap]);
Segment<K,V>[] ss = (Segment<K,V>[])new Segment[ssize];
UNSAFE.putOrderedObject(ss, SBASE, s0); // ordered write of segments[0]
this.segments = ss;
}
/**
* The minimum capacity for per-segment tables. Must be a power
* of two, at least two to avoid immediate resizing on next use
* after lazy construction.
* 每个 segment 存放的 Entry 数组的最小长度,至少是 2 且为 2 的次幂
*/
static final int MIN_SEGMENT_TABLE_CAPACITY = 2;
|
所以调用完默认的无参构造函数后,得到的 ConcurrentHashMap
数据结构状态如下图所示:
三. 添加元素
添加元素操作在 put 方法中实现,主要过程:
- 不允许键值对的值为空的情况。
- 计算得到扰动后的哈希值,采用了 Wang/Jenkins hash 算法变体。
hash >>> segmentShift
哈希值无符号右移 28 位,这时高 4 位移到了低 4 位,然后通过 segmentMask 掩码屏蔽新的高 28 位,即只关心新的低 4 位,也就是让 hash 值的高 n 位参与运算,这个 n 与 Segment 数组的长度对应,其实也就是为了增大哈希值的随机性,作二次扰动。
- 获取 Segment 数组的第 j 个元素,如果该位置的元素还没有初始化,则调用
ensureSegment
初始化该位置的 Segment 对象,然后再将键值对放到里面。
1
2
3
4
5
6
7
8
9
10
|
public V put(K key, V value) {
Segment<K,V> s;
if (value == null) throw new NullPointerException();
int hash = hash(key);
int j = (hash >>> segmentShift) & segmentMask;
if ((s = (Segment<K,V>)UNSAFE.getObject // nonvolatile; recheck
(segments, (j << SSHIFT) + SBASE)) == null) // in ensureSegment
s = ensureSegment(j);
return s.put(key, hash, value, false);
}
|
ensureSegment
入参是 Segment 数组的下标,顾名思义就是保证 Segment 数组该下标的元素一定要初始化完成。主要过程
- 获取该位置的 Segment 对象,如果是空的,则使用 ss[0] 作为复制的原型。
- 基于原型的参数,计算与初始化新 Segment 的参数。
- 由于此前可能有另一条线程初始化了该位置的 Segment,所以需要进行二次检查。
- 校验通过则创建新的 Segment 对象 s,不停自旋,直到 CAS 成功将该位置从 null 设置为 s,并返回 s;或者有另一条并发的线程往该位置设置了另外一个新的 Segment 对象 s1,则退出自旋,并返回 s1。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
|
private Segment<K,V> ensureSegment(int k) {
final Segment<K,V>[] ss = this.segments;
long u = (k << SSHIFT) + SBASE; // raw offset
Segment<K,V> seg;
if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) {
Segment<K,V> proto = ss[0]; // use segment 0 as prototype
int cap = proto.table.length;
float lf = proto.loadFactor;
int threshold = (int)(cap * lf);
HashEntry<K,V>[] tab = (HashEntry<K,V>[])new HashEntry[cap];
if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
== null) { // recheck
Segment<K,V> s = new Segment<K,V>(lf, threshold, tab);
while ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
== null) {
if (UNSAFE.compareAndSwapObject(ss, u, null, seg = s))
break;
}
}
}
return seg;
}
|
然后调用 Segment 的 put 方法,将键值对放到具体的 HashEntry 数组中,主要过程
- 通过
tryLock
方法尝试获取锁,变量 node 代表插入的键值对的信息,如果获取锁成功则在临界区内寻找插入位置,否则调用 scanAndLockForPut
预创建 HashEntry
节点。
- 当 node 为 null:即一次
CAS
就能获取到锁,则通过哈希值对应节点在 HashEntry
数组中的位置,然后遍历其冲突链表
- 如果遍历期间找到 key 相等的键值对,则更新 value 然后退出。
- 如果遍历完整个链表都没找到可以更新键值对,则将节点插入到链表的头部,插入后可能超过 Segment 的容量限制,则需要检查是否需要扩容,是则扩容。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
|
final V put(K key, int hash, V value, boolean onlyIfAbsent) {
HashEntry<K,V> node = tryLock() ? null : scanAndLockForPut(key, hash, value);
V oldValue;
try {
HashEntry<K,V>[] tab = table;
int index = (tab.length - 1) & hash;
HashEntry<K,V> first = entryAt(tab, index);
for (HashEntry<K,V> e = first;;) {
if (e != null) {
K k;
if ((k = e.key) == key ||
(e.hash == hash && key.equals(k))) {
oldValue = e.value;
if (!onlyIfAbsent) {
e.value = value;
++modCount;
}
break;
}
e = e.next;
}
else {
if (node != null) {
node.setNext(first);
} else {
node = new HashEntry<K,V>(hash, key, value, first);
}
int c = count + 1;
if (c > threshold && tab.length < MAXIMUM_CAPACITY) {
rehash(node);
} else {
setEntryAt(tab, index, node);
}
++modCount;
count = c;
oldValue = null;
break;
}
}
} finally {
unlock();
}
return oldValue;
}
|
下面再看 scanAndLockForPut
方法,获取不到锁的线程将进入到这里,
]
- 获取目标插入到 HashEntry 数组的节点 e
- 自旋尝试加锁,如果失败
- 如果第一次重试(retries = -1),
- 如果 e 没初始化,则预初始化一个节点并让 node 指向它, retries++,然后重试
- 如果 e 的 key 与要插入的 key 相同,则为更新,但由于没获取到锁,不做任何动作,仅 retries++,然后重试
- 以上两种情况都不是,则继续往后遍历,但不会有 retries++,即下一次自旋还是会进入到
retries < 0
的 if 中
- 第 n 次重试(-1 < retries < MAX_SCAN_RETRIES):如果 retries 的次数超过了
MAX_SCAN_RETRIES
,则调用父类 ReentrantLock
的 lock 方法并跳出死循环,当然 lock 方法中还会进行几次重试,如果还是不成功,则进入等待队列并阻塞。值得一提的是,MAX_SCAN_RETRIES
的值与当前操作系统的处理器数有关。
(retries & 1) == 0
表示偶数次重试:尝试获取目标 HashEntry 数组的入口,如果已经变了,即不等于一开始拿到的 first,说明别的线程将该位置初始化,那么当前线程前期初始化的 HashEntry 需要作废,所以将 retries 改回 -1,等下一次死循环将进入 retries < 0
的 if 重新生成新的 HashEntry 节点并继续自旋重试。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
|
private HashEntry<K,V> scanAndLockForPut(K key, int hash, V value) {
HashEntry<K,V> first = entryForHash(this, hash);
HashEntry<K,V> e = first;
HashEntry<K,V> node = null;
int retries = -1; // negative while locating node
while (!tryLock()) {
HashEntry<K,V> f; // to recheck first below
if (retries < 0) {
if (e == null) {
if (node == null) // speculatively create node
node = new HashEntry<K,V>(hash, key, value, null);
retries = 0;
}
else if (key.equals(e.key))
retries = 0;
else
e = e.next;
}
else if (++retries > MAX_SCAN_RETRIES) {
lock();
break;
}
else if ((retries & 1) == 0 &&
(f = entryForHash(this, hash)) != first) {
e = first = f; // re-traverse if entry changed
retries = -1;
}
}
return node;
}
static final int MAX_SCAN_RETRIES = Runtime.getRuntime().availableProcessors() > 1 ? 64 : 1;
|
entryForHash
用于获取 hash 值对应在当前 Segment 的 HashEntry 节点,如 hash = 16,则 (tab.length - 1) & h) = 0
,即获取 tab[0] 的节点,如果 table 没初始化或者 tab[0] 未初始化都会返回 null,否则返回具体的 HashEntry 节点。
1
2
3
4
5
6
|
static final <K,V> HashEntry<K,V> entryForHash(Segment<K,V> seg, int h) {
HashEntry<K,V>[] tab;
return (seg == null || (tab = seg.table) == null) ? null :
(HashEntry<K,V>) UNSAFE.getObjectVolatile
(tab, ((long)(((tab.length - 1) & h)) << TSHIFT) + TBASE);
}
|
四. 扩容
上面的 put 方法中提到,在插入一个键值对前,需要先检查容量是否超过阈值,如果是,则需扩容,下面是扩容的方法逻辑,其中入参是当前插入的节点。注意这里的扩容指的是 Segment 内部的扩容,而不是整个 ConcurrentHashMap 的扩容,也就是说 Segment 数组的长度是不可变的了,但是每个 Segment 内部的 HashEntry 数组可以进行扩展。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
|
private void rehash(HashEntry<K,V> node) {
HashEntry<K,V>[] oldTable = table;
int oldCapacity = oldTable.length;
int newCapacity = oldCapacity << 1;
threshold = (int)(newCapacity * loadFactor);
HashEntry<K,V>[] newTable = (HashEntry<K,V>[]) new HashEntry[newCapacity];
int sizeMask = newCapacity - 1;
for (int i = 0; i < oldCapacity ; i++) {
HashEntry<K,V> e = oldTable[i];
if (e != null) {
HashEntry<K,V> next = e.next;
int idx = e.hash & sizeMask;
if (next == null) // Single node on list
newTable[idx] = e;
else { // Reuse consecutive sequence at same slot
HashEntry<K,V> lastRun = e;
int lastIdx = idx;
for (HashEntry<K,V> last = next; last != null; last = last.next) {
int k = last.hash & sizeMask;
if (k != lastIdx) {
lastIdx = k;
lastRun = last;
}
}
newTable[lastIdx] = lastRun;
// Clone remaining nodes
for (HashEntry<K,V> p = e; p != lastRun; p = p.next) {
V v = p.value;
int h = p.hash;
int k = h & sizeMask;
HashEntry<K,V> n = newTable[k];
newTable[k] = new HashEntry<K,V>(h, p.key, v, n);
}
}
}
}
int nodeIndex = node.hash & sizeMask; // add the new node
node.setNext(newTable[nodeIndex]);
newTable[nodeIndex] = node;
table = newTable;
}
|
上述代码的大致过程:
- 基于旧 table 计算新的 table 的容量(原来的两倍)、扩容阈值、容量掩码等;
- 遍历 HashEntry 数组,计算每个节点 e 在新数组中的索引 idx(通过新的容量掩码)
- 如果 e 的后继为 null,说明只有链表中只有一个节点,则直接将 e 放入新的数组中的新位置处(newTable[idx])
- 如果后继还有非空节点,则往后遍历
- 找到最后一个与前面节点的哈希结果不一样的节点 lastRun,也就是说 lastRun 后面的节点都是要搬到新数组的同一个索引里面的,为了省事,直接一整条链表搬过去。
- 然后对于 lastRun 前面的节点,由于它们去往的 “桶” 不一样,所以需要一个个头插到相应的 “桶口” 位置。
- 处理完历史的节点,需要将 node 节点放在新的 HashEntry 数组里面,结束。
五. 查询元素
相比于写入,查询元素的逻辑比较简单,注意这里没有加锁,只是获取 Segment 和 HashEntry 时都用到了 getObjectVolatile
方法,字面上理解,getObjectVolatile
有 Volatile
字样,则其他线程对相应对象的改动能被当前线程感知到。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
|
public V get(Object key) {
Segment<K,V> s; // manually integrate access methods to reduce overhead
HashEntry<K,V>[] tab;
int h = hash(key);
long u = (((h >>> segmentShift) & segmentMask) << SSHIFT) + SBASE;
if ((s = (Segment<K,V>)UNSAFE.getObjectVolatile(segments, u)) != null &&
(tab = s.table) != null) {
for (HashEntry<K,V> e = (HashEntry<K,V>) UNSAFE.getObjectVolatile
(tab, ((long)(((tab.length - 1) & h)) << TSHIFT) + TBASE);
e != null; e = e.next) {
K k;
if ((k = e.key) == key || (e.hash == h && key.equals(k)))
return e.value;
}
}
return null;
}
|
上面代码比较好理解,主要过程为:
- 计算哈希值 h,并通过 h 计算到目标的 Segment 数组下标
- 如果 Segment 节点已经初始化且其中的 HashEntry 数组 tab 也初始化了,则获取计算目标节点在 tab 的索引 j,并拿到 e=tab[j],遍历以 e 为头节点的链表,如果找到目标,则返回,结束。