由于 JDK 1.7 的 ConcurrentHashMap 的实现使用到了 ReentrantLock,刚好前面也已经看过了 ReentrantLock 的实现,所以顺势看下 1.7 版本的 ConcurrentHashMap 是如何实现的。

一. 简介

与 Hashtable 使用 synchronized 关键字对读写方法进行加锁不同,JDK 1.7 的 ConcurrentHashMap(以下简称 ConcurrentHashMap)使用了分段锁的思想,它将数据散列成一个 Segment 数组,每个 Segment 对象实际类似一个 HashMap,它们各自持有一个 HashEntry 数组,实际的键值对数据是存放到 HashEntry 数组中的,如果发生哈希冲突,则通过在 HashEntry 对象上构建链表存放哈希值一样的键值对。

特别地,Segment 继承自 ReentrantLock,说明一个 Segment 就是一把锁,每个想要往 ConcurrentHashMap 写数据的线程都需要拿到相应分段的锁才行,而拿不到锁的线程,则需要重试与自旋,期间可以预先构建好需要的节点信息,然后如果有限次尝试后还是没能拿到锁,则会加入等待队列然后进行阻塞。

二. 构造函数

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
public ConcurrentHashMap() {
    this(DEFAULT_INITIAL_CAPACITY, DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL);
}
/**
 * The default initial capacity for this table,
 * used when not otherwise specified in a constructor.
 */
static final int DEFAULT_INITIAL_CAPACITY = 16;

/**
 * The default load factor for this table, used when not
 * otherwise specified in a constructor.
 */
static final float DEFAULT_LOAD_FACTOR = 0.75f;
/**
 * The default concurrency level for this table, used when not
 * otherwise specified in a constructor.
 */
static final int DEFAULT_CONCURRENCY_LEVEL = 16;

归根到底都是调用下面的这个构造方法,主要以默认的入参进行分析

  1. 检查入参的合法性
  2. concurrencyLevelSegment 数组的初始长度的计算有关,因为 Segment 数组的长度 ssize 必须是 2 的次幂,所以在计算 ssize 是会进行控制,如输入 concurrencyLevel 为 16,则 ssize 会从 1 开始右移直到确保值不比 concurrencyLevel 小,所以 sshift 为 4,ssize 等于 16,但如果 concurrencyLevel 等于 17,则 sshift 为 5,ssize 等于 32。
  3. segmentShift 赋值为 32 - 4 = 28,segmentMask 赋值为 16 - 1= 15。
  4. initialCapacity 为总共能够存放的键值对数,均摊到每个 Segment 中,则能知道每个 Segment 至少应该存放多少个,然后由于传进来的 initialCapacity 可能很大,concurrencyLevel 比较小,为了保证每个 Segment 都能有足够的数组空间存放键值对,避免哈希冲突,需要对 c 进行向上取整,然后为了保证 HashEntry 数组的长度是 2 的次幂,需要对 cap 进行左移运算。
  5. 然后创建一个 Segment 数组 ss,长度为 16,另外还创建了一个 Segment 对象 s0 ,它的 loadFactor 为 0.75,扩容阈值 threshold 为 12,创建 HashEntry 数组且长度为 2,然后通过 Unsafe 操作内存,将 s0 放到 ss 的起始地址上,刚好占满第一个索引位置,即 s0 = ss[0]。
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
public ConcurrentHashMap(int initialCapacity,
                         float loadFactor, int concurrencyLevel) {
    if (!(loadFactor > 0) || initialCapacity < 0 || concurrencyLevel <= 0)
        throw new IllegalArgumentException();
    if (concurrencyLevel > MAX_SEGMENTS)
        concurrencyLevel = MAX_SEGMENTS;
    // Find power-of-two sizes best matching arguments
    int sshift = 0;
    int ssize = 1;
    while (ssize < concurrencyLevel) {
        ++sshift;
        ssize <<= 1;
    }
    this.segmentShift = 32 - sshift;
    this.segmentMask = ssize - 1;
    if (initialCapacity > MAXIMUM_CAPACITY)
        initialCapacity = MAXIMUM_CAPACITY;
    int c = initialCapacity / ssize;
    if (c * ssize < initialCapacity)
        ++c;
    int cap = MIN_SEGMENT_TABLE_CAPACITY;
    while (cap < c)
        cap <<= 1;
    // create segments and segments[0]
    Segment<K,V> s0 =
        new Segment<K,V>(loadFactor, (int)(cap * loadFactor),
                         (HashEntry<K,V>[])new HashEntry[cap]);
    Segment<K,V>[] ss = (Segment<K,V>[])new Segment[ssize];
    UNSAFE.putOrderedObject(ss, SBASE, s0); // ordered write of segments[0]
    this.segments = ss;
}

/**
 * The minimum capacity for per-segment tables.  Must be a power
 * of two, at least two to avoid immediate resizing on next use
 * after lazy construction.
 * 每个 segment 存放的 Entry 数组的最小长度,至少是 2 且为 2 的次幂
 */
static final int MIN_SEGMENT_TABLE_CAPACITY = 2;

所以调用完默认的无参构造函数后,得到的 ConcurrentHashMap 数据结构状态如下图所示:

Snipaste_2021-08-28_01-31-42.png

三. 添加元素

添加元素操作在 put 方法中实现,主要过程:

  1. 不允许键值对的值为空的情况。
  2. 计算得到扰动后的哈希值,采用了 Wang/Jenkins hash 算法变体。
  3. hash >>> segmentShift 哈希值无符号右移 28 位,这时高 4 位移到了低 4 位,然后通过 segmentMask 掩码屏蔽新的高 28 位,即只关心新的低 4 位,也就是让 hash 值的高 n 位参与运算,这个 n 与 Segment 数组的长度对应,其实也就是为了增大哈希值的随机性,作二次扰动。
  4. 获取 Segment 数组的第 j 个元素,如果该位置的元素还没有初始化,则调用 ensureSegment 初始化该位置的 Segment 对象,然后再将键值对放到里面。
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
public V put(K key, V value) {
    Segment<K,V> s;
    if (value == null)  throw new NullPointerException();
    int hash = hash(key);
    int j = (hash >>> segmentShift) & segmentMask;
    if ((s = (Segment<K,V>)UNSAFE.getObject          // nonvolatile; recheck
         (segments, (j << SSHIFT) + SBASE)) == null) //  in ensureSegment
        s = ensureSegment(j);
    return s.put(key, hash, value, false);
}

ensureSegment 入参是 Segment 数组的下标,顾名思义就是保证 Segment 数组该下标的元素一定要初始化完成。主要过程

  1. 获取该位置的 Segment 对象,如果是空的,则使用 ss[0] 作为复制的原型。
  2. 基于原型的参数,计算与初始化新 Segment 的参数。
  3. 由于此前可能有另一条线程初始化了该位置的 Segment,所以需要进行二次检查。
  4. 校验通过则创建新的 Segment 对象 s,不停自旋,直到 CAS 成功将该位置从 null 设置为 s,并返回 s;或者有另一条并发的线程往该位置设置了另外一个新的 Segment 对象 s1,则退出自旋,并返回 s1。
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
private Segment<K,V> ensureSegment(int k) {
    final Segment<K,V>[] ss = this.segments;
    long u = (k << SSHIFT) + SBASE; // raw offset
    Segment<K,V> seg;
    if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) {
        Segment<K,V> proto = ss[0]; // use segment 0 as prototype
        int cap = proto.table.length;
        float lf = proto.loadFactor;
        int threshold = (int)(cap * lf);
        HashEntry<K,V>[] tab = (HashEntry<K,V>[])new HashEntry[cap];
        if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
            == null) { // recheck
            Segment<K,V> s = new Segment<K,V>(lf, threshold, tab);
            while ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
                   == null) {
                if (UNSAFE.compareAndSwapObject(ss, u, null, seg = s))
                    break;
            }
        }
    }
    return seg;
}

然后调用 Segment 的 put 方法,将键值对放到具体的 HashEntry 数组中,主要过程

img

  1. 通过 tryLock 方法尝试获取锁,变量 node 代表插入的键值对的信息,如果获取锁成功则在临界区内寻找插入位置,否则调用 scanAndLockForPut 预创建 HashEntry 节点。
  2. 当 node 为 null:即一次 CAS 就能获取到锁,则通过哈希值对应节点在 HashEntry 数组中的位置,然后遍历其冲突链表
    1. 如果遍历期间找到 key 相等的键值对,则更新 value 然后退出。
    2. 如果遍历完整个链表都没找到可以更新键值对,则将节点插入到链表的头部,插入后可能超过 Segment 的容量限制,则需要检查是否需要扩容,是则扩容。
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
final V put(K key, int hash, V value, boolean onlyIfAbsent) {
    HashEntry<K,V> node = tryLock() ? null : scanAndLockForPut(key, hash, value);
    V oldValue;
    try {
        HashEntry<K,V>[] tab = table;
        int index = (tab.length - 1) & hash;
        HashEntry<K,V> first = entryAt(tab, index);
        for (HashEntry<K,V> e = first;;) {
            if (e != null) {
                K k;
                if ((k = e.key) == key ||
                    (e.hash == hash && key.equals(k))) {
                    oldValue = e.value;
                    if (!onlyIfAbsent) {
                        e.value = value;
                        ++modCount;
                    }
                    break;
                }
                e = e.next;
            }
            else {
                if (node != null) {
                    node.setNext(first);
                } else {
                    node = new HashEntry<K,V>(hash, key, value, first);
                }
                int c = count + 1;
                if (c > threshold && tab.length < MAXIMUM_CAPACITY) {
                    rehash(node);
                } else {
                    setEntryAt(tab, index, node);
                }
                ++modCount;
                count = c;
                oldValue = null;
                break;
            }
        }
    } finally {
        unlock();
    }
    return oldValue;
}

下面再看 scanAndLockForPut 方法,获取不到锁的线程将进入到这里,

img ]

  1. 获取目标插入到 HashEntry 数组的节点 e
  2. 自旋尝试加锁,如果失败
    1. 如果第一次重试(retries = -1),
      1. 如果 e 没初始化,则预初始化一个节点并让 node 指向它, retries++,然后重试
      2. 如果 e 的 key 与要插入的 key 相同,则为更新,但由于没获取到锁,不做任何动作,仅 retries++,然后重试
      3. 以上两种情况都不是,则继续往后遍历,但不会有 retries++,即下一次自旋还是会进入到 retries < 0 的 if 中
    2. 第 n 次重试(-1 < retries < MAX_SCAN_RETRIES):如果 retries 的次数超过了 MAX_SCAN_RETRIES,则调用父类 ReentrantLock 的 lock 方法并跳出死循环,当然 lock 方法中还会进行几次重试,如果还是不成功,则进入等待队列并阻塞。值得一提的是,MAX_SCAN_RETRIES 的值与当前操作系统的处理器数有关。
    3. (retries & 1) == 0 表示偶数次重试:尝试获取目标 HashEntry 数组的入口,如果已经变了,即不等于一开始拿到的 first,说明别的线程将该位置初始化,那么当前线程前期初始化的 HashEntry 需要作废,所以将 retries 改回 -1,等下一次死循环将进入 retries < 0 的 if 重新生成新的 HashEntry 节点并继续自旋重试。
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
private HashEntry<K,V> scanAndLockForPut(K key, int hash, V value) {
    HashEntry<K,V> first = entryForHash(this, hash);
    HashEntry<K,V> e = first;
    HashEntry<K,V> node = null;
    int retries = -1; // negative while locating node
    while (!tryLock()) {
        HashEntry<K,V> f; // to recheck first below
        if (retries < 0) {
            if (e == null) {
                if (node == null) // speculatively create node
                    node = new HashEntry<K,V>(hash, key, value, null);
                retries = 0;
            }
            else if (key.equals(e.key))
                retries = 0;
            else
                e = e.next;
        }
        else if (++retries > MAX_SCAN_RETRIES) {
            lock();
            break;
        }
        else if ((retries & 1) == 0 &&
                 (f = entryForHash(this, hash)) != first) {
            e = first = f; // re-traverse if entry changed
            retries = -1;
        }
    }
    return node;
}

static final int MAX_SCAN_RETRIES = Runtime.getRuntime().availableProcessors() > 1 ? 64 : 1;

entryForHash 用于获取 hash 值对应在当前 Segment 的 HashEntry 节点,如 hash = 16,则 (tab.length - 1) & h) = 0,即获取 tab[0] 的节点,如果 table 没初始化或者 tab[0] 未初始化都会返回 null,否则返回具体的 HashEntry 节点。

1
2
3
4
5
6
static final <K,V> HashEntry<K,V> entryForHash(Segment<K,V> seg, int h) {
    HashEntry<K,V>[] tab;
    return (seg == null || (tab = seg.table) == null) ? null :
    (HashEntry<K,V>) UNSAFE.getObjectVolatile
        (tab, ((long)(((tab.length - 1) & h)) << TSHIFT) + TBASE);
}

四. 扩容

上面的 put 方法中提到,在插入一个键值对前,需要先检查容量是否超过阈值,如果是,则需扩容,下面是扩容的方法逻辑,其中入参是当前插入的节点。注意这里的扩容指的是 Segment 内部的扩容,而不是整个 ConcurrentHashMap 的扩容,也就是说 Segment 数组的长度是不可变的了,但是每个 Segment 内部的 HashEntry 数组可以进行扩展。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
private void rehash(HashEntry<K,V> node) {
    HashEntry<K,V>[] oldTable = table;
    int oldCapacity = oldTable.length;
    int newCapacity = oldCapacity << 1;
    threshold = (int)(newCapacity * loadFactor);
    HashEntry<K,V>[] newTable = (HashEntry<K,V>[]) new HashEntry[newCapacity];
    int sizeMask = newCapacity - 1;
    for (int i = 0; i < oldCapacity ; i++) {
        HashEntry<K,V> e = oldTable[i];
        if (e != null) {
            HashEntry<K,V> next = e.next;
            int idx = e.hash & sizeMask;
            if (next == null)   //  Single node on list
                newTable[idx] = e;
            else { // Reuse consecutive sequence at same slot
                HashEntry<K,V> lastRun = e;
                int lastIdx = idx;
                for (HashEntry<K,V> last = next; last != null; last = last.next) {
                    int k = last.hash & sizeMask;
                    if (k != lastIdx) {
                        lastIdx = k;
                        lastRun = last;
                    }
                }
                newTable[lastIdx] = lastRun;
                // Clone remaining nodes
                for (HashEntry<K,V> p = e; p != lastRun; p = p.next) {
                    V v = p.value;
                    int h = p.hash;
                    int k = h & sizeMask;
                    HashEntry<K,V> n = newTable[k];
                    newTable[k] = new HashEntry<K,V>(h, p.key, v, n);
                }
            }
        }
    }
    int nodeIndex = node.hash & sizeMask; // add the new node
    node.setNext(newTable[nodeIndex]);
    newTable[nodeIndex] = node;
    table = newTable;
}

上述代码的大致过程:

  1. 基于旧 table 计算新的 table 的容量(原来的两倍)、扩容阈值、容量掩码等;
  2. 遍历 HashEntry 数组,计算每个节点 e 在新数组中的索引 idx(通过新的容量掩码)
    1. 如果 e 的后继为 null,说明只有链表中只有一个节点,则直接将 e 放入新的数组中的新位置处(newTable[idx])
    2. 如果后继还有非空节点,则往后遍历
      1. 找到最后一个与前面节点的哈希结果不一样的节点 lastRun,也就是说 lastRun 后面的节点都是要搬到新数组的同一个索引里面的,为了省事,直接一整条链表搬过去。
      2. 然后对于 lastRun 前面的节点,由于它们去往的 “桶” 不一样,所以需要一个个头插到相应的 “桶口” 位置。
  3. 处理完历史的节点,需要将 node 节点放在新的 HashEntry 数组里面,结束。

五. 查询元素

相比于写入,查询元素的逻辑比较简单,注意这里没有加锁,只是获取 Segment 和 HashEntry 时都用到了 getObjectVolatile 方法,字面上理解,getObjectVolatileVolatile 字样,则其他线程对相应对象的改动能被当前线程感知到。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
public V get(Object key) {
    Segment<K,V> s; // manually integrate access methods to reduce overhead
    HashEntry<K,V>[] tab;
    int h = hash(key);
    long u = (((h >>> segmentShift) & segmentMask) << SSHIFT) + SBASE;
    if ((s = (Segment<K,V>)UNSAFE.getObjectVolatile(segments, u)) != null &&
        (tab = s.table) != null) {
        for (HashEntry<K,V> e = (HashEntry<K,V>) UNSAFE.getObjectVolatile
             (tab, ((long)(((tab.length - 1) & h)) << TSHIFT) + TBASE);
             e != null; e = e.next) {
            K k;
            if ((k = e.key) == key || (e.hash == h && key.equals(k)))
                return e.value;
        }
    }
    return null;
}

上面代码比较好理解,主要过程为:

  1. 计算哈希值 h,并通过 h 计算到目标的 Segment 数组下标
  2. 如果 Segment 节点已经初始化且其中的 HashEntry 数组 tab 也初始化了,则获取计算目标节点在 tab 的索引 j,并拿到 e=tab[j],遍历以 e 为头节点的链表,如果找到目标,则返回,结束。