前面已经了解了 AbstractQueuedSynchronizer 和 ReentrantLock 的原理,接下来看看共享锁组件 CountDownLatch 是如何基于 AbstractQueuedSynchronizer 实现,在我看来,如果能先了解 AbstractQueuedSynchronizer 的运作过程,那么再看 CountDownLatch 会觉得比较容易理解。

一、源码文档

Doug LeaCountDownLatch 做了如下定义:

A synchronization aid that allows one or more threads to wait until a set of operations being performed in other threads completes.

意思是这是一个允许一个或多个线程等待一批线程里的操作完成的同步化工具,按我理解这是一个异步转同步的工具,即可以理解成有一个负责分发任务线程,将一批任务分发给多个线程异步执行然后自己阻塞,直到异步线程全都执行完,然后任务分发线程被唤醒再继续往下执行。

CountDownLatch is initialized with a given count.The await methods block until the current count reaches zero due to invocations of the countDown method, after which all waiting threads are released and any subsequent invocations of await return immediately. This is a one-shot phenomenon – the count cannot be reset. If you need a version that resets the count, consider using a CyclicBarrier.

使用 CountDownLatch 时都需要先指定一个数量 count,调用 await 方法的线程将会阻塞直到 count 经过 countDown 方法将之减到 0,在此之后所有等待的线程都会被唤醒,并立即从 await 方法返回。因为 await 方法实际上是一个获取锁的操作,但是由于锁是共享的,所以此后任何线程都可以获取到锁,并且不需要考虑释放,所以 CountDownLatch 是一次性的,因为 count 不能被重置,如果需要复用 count,可以看看 CyclicBarrier

A CountDownLatch is a versatile synchronization tool and can be used for a number of purposes. A CountDownLatch initialized with a count of one serves as a simple on/off latch, or gate: all threads invoking await wait at the gate until it is opened by a thread invoking countDown. A CountDownLatch initialized to N can be used to make one thread wait until N threads have completed some action, or some action has been completed N times.

CountDownLatch 有很多用途,count 可以作为一个开关,或者把它理解成一道门,所以调用 await 的线程都在门开启。对于初始化 count 为 N 的 CountDownLatch ,可以让一个线程等待 N 条线程完成或者一个操作完成 N 次。

A useful property of a CountDownLatch is that it doesn’t require that threads calling countDown wait for the count to reach zero before proceeding, it simply prevents any thread from proceeding past an await until all threads could pass.

CountDownLatch 有个特性是,调用 countDown 方法的线程不需要等到 count 变为 0 才继续往下执行,它只会阻塞调用 await 的线程,让它等待所有的线程返回可通过的信号(通过 countDown)。

使用示例

Doug Lea 还给出了以下使用示例,比如:有一个 Driver 类,意为驱动线程,用来调度工作线程的执行,其中定义了两个 CountDownLatch ,第一个是一个启动信号,为了避免工作线程在驱动线程还没准备好的情况下执行。第二个 CountDownLatch 是一个完成信号,用于使驱动线程阻塞等待所有工作线程作业完成。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
public class Driver {
    private static final int N = 5;

    public static void main(String[] args) throws InterruptedException {
        CountDownLatch startSignal = new CountDownLatch(1);
        CountDownLatch doneSignal = new CountDownLatch(N);
        for (int i = 0; i < N; ++i) {
            new Thread(new Worker(startSignal, doneSignal)).start();
        }
        // don't let run yet
        doSomethingElse(); 
        // let all threads proceed
        startSignal.countDown();      
        doSomethingElse();
        // wait for all to finish
        doneSignal.await();           
    }

    static void doSomethingElse() {
        System.out.println("do sth else");
    }
}

在 Worker 中需要等待驱动线程在 startSignal 上调用 countDown 发出启动信号才能往下执行任务,有点像是百米竞跑的枪声一样。各自操作完成在发出完成信号,等所有线程都发出信号则驱动线程可以继续执行。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
public class Worker implements Runnable {
    private final CountDownLatch startSignal;
    private final CountDownLatch doneSignal;
    
    Worker(CountDownLatch startSignal, CountDownLatch doneSignal) {
        this.startSignal = startSignal;
        this.doneSignal = doneSignal;
    }
    
    @Override
    public void run() {
        try {
            startSignal.await();
            doWork();
            doneSignal.countDown();
        } catch (InterruptedException ex) {
            System.out.println("catch ex");
        }
    }
    
    void doWork() {
        System.out.println("doWork");
    }
}

CountDownLatch 的另一种用法是将问题分解成 N 份,每份对应一个 Runnable 任务,每个任务处理问题分区,并且将所有的任务丢人线程池里调度,执行完后进行 countDown。当所有的子任务执行完,Driver2 就可以继续往下执行。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
import java.util.concurrent.*;

public class Driver2 {
    private static final int N = 5;

    public static void main(String[] args) throws InterruptedException {
        CountDownLatch doneSignal = new CountDownLatch(N);
        ThreadPoolExecutor executor = new ThreadPoolExecutor(
                N, 20, 0L, 
                TimeUnit.MILLISECONDS, new LinkedBlockingDeque<>(500), 
                r -> new Thread(r, "executor"));
        for (int i = 0; i < N; ++i) {
            executor.execute(new WorkerRunnable(doneSignal, i));
        }
        // wait for all to finish
        doneSignal.await();
    }
}

二、源码分析

通常使用 CountDownLatch 都需要在构造函数中指定一个数值,并赋值给同步状态 state。

1
2
3
4
public CountDownLatch(int count) {
    if (count < 0) throw new IllegalArgumentException("count < 0");
    this.sync = new Sync(count);
}

CountDownLatch 是基于 AbstractQueuedSynchronizer 的共享锁机制的,其中定义了 AbstractQueuedSynchronizer 的实现类 Sync,Sync 的构造函数必须指定一个 int 类型的数值,用于表示共享锁的线程数,或者可以理解成锁个数,通过指定该值调用 CAS 进行加锁。

1
2
3
4
5
6
7
8
private static final class Sync extends AbstractQueuedSynchronizer {
    Sync(int count) {
        setState(count);
    }
    int getCount() {
        return getState();
    }
}

既然使用了共享锁,自然地 Sync 中覆盖了 tryAcquireSharedtryReleaseShared 方法。tryAcquireShared 方法用于判断锁是否被其他线程持有,是则获取失败。因为 CountDownLatch 使用上指定了锁以及对应可持有线程数,后面再有线程来进行加锁则只能阻塞,直到共享锁的线程全部释放锁。

1
2
3
protected int tryAcquireShared(int acquires) {
    return (getState() == 0) ? 1 : -1;
}

每次有一个线程释放锁,AQS 里面的 state 字段就会减 1,如果 state 减为 0,则表示该锁完全释放,其他线程可以尝试获取锁。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
protected boolean tryReleaseShared(int releases) {
    // Decrement count; signal when transition to zero
    for (;;) {
        int c = getState();
        if (c == 0)
            return false;
        int nextc = c-1;
        if (compareAndSetState(c, nextc))
            return nextc == 0;
    }
}

1. countDown 过程

countDown 方法调用的是 AbstractOwnableSynchronizer 里面的 releaseShared 模板方法,注意调用 CountDownLatch 构造就已经上了 N 层锁了(关门),而 countDown 就是一个加锁的过程(开门)。

1
2
3
public void countDown() {
    sync.releaseShared(1);
}

releaseShared 方法中会回调 Sync 中的钩子函数 tryReleaseShared,当 N 层锁都被释放时,整个锁被释放,tryReleaseShared 才返回 true,其他线程才能来获取锁并返回 true,否则都是返回 false(貌似不关心这个返回值,但是方法语义上是这样定义的)

1
2
3
4
5
6
7
public final boolean releaseShared(int arg) {
    if (tryReleaseShared(arg)) {
        doReleaseShared();
        return true;
    }
    return false;
}

再看 doReleaseShared 方法,当 N 层锁都被释放时,将会进入该方法,其中主要的判断:

  • 如果队列为空,判断哨兵的等待状态,如果是 0,则唤醒其后继节点,注意如果唤醒失败就 continue,但是唤醒成功仍然没有跳出循环,这时我们转换到 AQS 中线程阻塞的地方 acquireQueued 方法,第一个非哨兵非取消节点被唤醒,并将原头节点删除,所以下面的代码的当下一轮迭代进来时将会继续往后唤醒线程,直到整个队列为空。注意最后一个 h == head 判断,如果 unpark 唤醒的后继节点获取锁成功,那么头节点将会发生变化,则会再继续唤醒下一个节点,否则就会退出这里的循环,不再唤醒后面的节点。
  • 如果队列为空,则结束。
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
private void doReleaseShared() {
    for (;;) {
        Node h = head;
        if (h != null && h != tail) {
            int ws = h.waitStatus;
            if (ws == Node.SIGNAL) {
                if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0)) {
                    continue;            // loop to recheck cases
                }
                unparkSuccessor(h);
            }
            else if (ws == 0 && !compareAndSetWaitStatus(h, 0, Node.PROPAGATE)) {
                continue;                // loop on failed CAS
            }
        }
        // loop if head changed
        if (h == head) {
            break;
        }
    }
}

2. await 过程

await 方法会调用 AbstractQueuedSynchronizer 中的模板方法 acquireSharedInterruptibly,想想 CountDownLatch 的作用,它是一扇门,那么在工作线程没有 countDown 前,将要把调度线程阻塞,而且这个阻塞可唤醒,那么回顾 AbstractQueuedSynchronizer 及结合 countDown 是一个释放锁的过程,我们不难想到 await 其实是一个请求锁的过程。

1
2
3
public void await() throws InterruptedException {
    sync.acquireSharedInterruptibly(1);
}

acquireSharedInterruptibly 首先会回调 CountDownLatch 中的钩子函数 tryAcquireSharedtryAcquireShared 只判断当前 state 是否为 0,也就是 N 层锁是否都已释放,是则返回 1 ,否则返回 -1,如果 N 层锁没完全释放则会继续调用 doAcquireSharedInterruptibly,它是 AbstractQueuedSynchronizer 中一个私有方法,并且只在此处被调用,其实它相当于 AbstractQueuedSynchronizer 中的 acquire 方法的 acquireQueued(addWaiter(Node.EXCLUSIVE), arg) 这段代码的作用。

1
2
3
4
5
6
7
8
public final void acquireSharedInterruptibly(int arg) throws InterruptedException {
    if (Thread.interrupted()) {
        throw new InterruptedException();
    }
    if (tryAcquireShared(arg) < 0) {
        doAcquireSharedInterruptibly(arg);
    }
}

doAcquireSharedInterruptibly 主要经过以下几个步骤:

  1. 节点化当前线程并入队
  2. 如果当前节点的前继节点是哨兵,则去看 N 层锁是否都释放,是的话会返回 1,即 r = 1,然后继续调用 setHeadAndPropagate 方法,setHeadAndPropagate 比较复杂,主要的作用是将等待队列中此刻在等待的线程逐一唤醒去获取锁。 注意下面方法获取锁调用的是 Sync 中的 tryAcquireShared 方法,但是与 acquireQueued 方法不同,它并没有通过调用 CAS 方法改变 state 的值,所以不同担心 state 的整型越界问题。
  3. 如果当前节点的前继节点不是哨兵。则处理方式与 acquireQueued 方法一样,主要将前继节点的等待状态记为 -1,然后阻塞。
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
private void doAcquireSharedInterruptibly(int arg) throws InteinrruptedException {
    final Node node = addWaiter(Node.SHARED);
    boolean failed = true;
    try {
        for (;;) {
            final Node p = node.predecessor();
            if (p == head) {
                int r = tryAcquireShared(arg);
                if (r >= 0) {
                    setHeadAndPropagate(node, r);
                    p.next = null; // help GC
                    failed = false;
                    return;
                }
            }
            if (shouldParkAfterFailedAcquire(p, node) && parkAndCheckInterrupt()) {
                throw new InterruptedException();
            }
        }
    } finally {
        if (failed)
            cancelAcquire(node);
    }
}

三、常见问题

下面对于 CountDownLatch 的使用中遇到的一些问题进行总结。

1. 线程没有正常 countDown

假设下面有三条线程,其中 t3 中会抛出异常,从而导致后面代码中的 countDown 代码没有执行到,从而导致调度线程没有正常释放。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import java.util.concurrent.TimeUnit;

public class App {
    static CountDownLatch latch = new CountDownLatch(3);
    public static void main(String[] args) {
        for (int i = 1; i < 4; i++) {
            final int finalI = i;
            Thread thread = new Thread(() -> {
                try {
                    System.out.println("Thread-" + finalI +  " is running");
                    TimeUnit.MILLISECONDS.sleep(finalI * 1000);
                    if (finalI == 3) {
                        int num = 1 / 0;
                    }
                    latch.countDown();
                } catch (InterruptedException e) {
                    e.printStackTrace();
                } finally {
                    // latch.countDown();
                }
            });
            thread.setName("t" + finalI);
            thread.start();
        }
        try {
            long start = System.currentTimeMillis();
            latch.await();
            System.out.println("阻塞耗时 " + (System.currentTimeMillis() - start) + " ms ");
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}

我们看到输出是这样的,main 线程进入等待队列后一直处于阻塞状态,并切这种状况是外部无法干预的,其后的代码也一直无法执行到。所以一般来说使用 CountDownLatch 时需要在 finally 中调用 countDown 方法,以保证各个子锁都能释放,要不然调度线程会一直阻塞。

Thread-1 is running
main阻塞, head: [-1,null] => [0,main]
Thread-2 is running
Thread-3 is running
Exception in thread "t3" java.lang.ArithmeticException: / by zero
	at App1.lambda$main$0(App1.java:14)
	at java.lang.Thread.run(Thread.java:748)

2. count 数大于 countDown 的线程数

如果把 App 中的代码 static CountDownLatch latch = new CountDownLatch(3); 中的 3 改成 4,那么结果是一样的,main 线程同样会因为差一个 countDown 而一直无法被唤醒。

3. 没有调用 await 方法

我曾经遇到过这样的线上问题,下面给出代码说明:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
public class DataLoader {
    public void load() {
        List<SomeKindOfData> list = ...
        boolean locked = RedisLock.lock(bizCode);
        if (locked) {
            try {
                driver.drive(list);
            } catech(Exception e) {
                //...
            } finally {
                RedisLock.unLock(bizCode);
            }
        }
    }
}

Driver 中的各个子操作也都调用了 countDown 方法,但是 Driver 里却没有调用 await,从而导致 DataLoader 中的分布式锁提前释放,但是由于 Driver 中的各子操作耗时很大,所以迟迟没执行完,这时由于临界数据没有锁住,下一个请求过来时还会加载到同样的待处理数据,从而出现数据错乱问题。

1
2
3
4
5
6
7
8
public class Driver {
    public void drive(List<SomeKindOfData> list) {
        CountDownLatch latch = new CountDownLatch(3);
        doProcOne(executor, latch);
        doProcTow(executor, latch);
        doProcThree(executor, latch);
    }
}