CountDownLatch

CountDownLatch是基于AQS实现的,它使用AQS中的state成员变量作为计数器,在state不为0的情况下,凡是调用await()方法的线程将会被阻塞,并放入AQS维护的同步队列中,而当state减至0时,队列中的节点会被唤醒,被阻塞的线程即可恢复运行。先来看看它的构造函数:

1
2
3
4
public CountDownLatch(int count) {
if (count < 0) throw new IllegalArgumentException("count < 0");
this.sync = new Sync(count);
}

可以看出,它创建了一个Sync对象,并将参数传入,这个参数就是计数器的值。因此,关于CountDownLatch的分析将从这个Sync类开始。

Sync

CountDownLatch中有个Sync内部类,它实现了AQS中的几个重要方法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
private static final class Sync extends AbstractQueuedSynchronizer {

Sync(int count) {
setState(count); // 设置 AQS 的 state 变量,也就是计数器的值
}

int getCount() {
return getState(); // 获取 AQS 的 state 变量值,也就是计数器的值
}

// 该方法主要是在 await() 中用到
protected int tryAcquireShared(int acquires) {
// 调用 getState() 方法获取 state 变量的值,
// 如果等于0,则返回正数,后续将不会阻塞线程
// 如果不等于0,则返回负数,后续将会阻塞线程
return (getState() == 0) ? 1 : -1;
}

// 该方法主要是在 countDown() 中用到
protected boolean tryReleaseShared(int releases) {
// 因为可能有多个线程同时调用该方法
// 所以这里使用 CAS + 循环的方式保证线程安全
for (;;) {
int c = getState();
if (c == 0)
return false;
int nextc = c-1;
if (compareAndSetState(c, nextc)) // CAS 将 state 的值减一
return nextc == 0; // 如果减到0了,就返回true
}
}
}

await()

使用CountDownLatch同步组件时,基本都会使用到await()方法,当计数器不为0时,这可以阻塞调用该方法的线程。同时,通过这个方法我们也将知道上面介绍的tryAcquireShared()是在何处被调用的:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
public void await() throws InterruptedException {
// 调用 AQS 的 acquireSharedInterruptibly() 方法
sync.acquireSharedInterruptibly(1);
}

// 此方法在AQS中实现
public final void acquireSharedInterruptibly(int arg)
throws InterruptedException {
if (Thread.interrupted()) // 响应中断
throw new InterruptedException();
if (tryAcquireShared(arg) < 0) // 该方法由子类 Sync 实现,如果返回值大于0,那么将直接返回
doAcquireSharedInterruptibly(arg); // 否则,将会放入同步队列中被阻塞
}

// 此方法在AQS中实现
private void doAcquireSharedInterruptibly(int arg)
throws InterruptedException {
final Node node = addWaiter(Node.SHARED);
boolean failed = true;
try {
for (;;) {
final Node p = node.predecessor();
if (p == head) {
int r = tryAcquireShared(arg); // 由 CountDownLatch 的 Sync 具体实现
if (r >= 0) {
setHeadAndPropagate(node, r);
p.next = null; // help GC
failed = false;
return;
}
}
if (shouldParkAfterFailedAcquire(p, node) &&
parkAndCheckInterrupt())
throw new InterruptedException();
}
} finally {
if (failed)
cancelAcquire(node);
}
}

countDown()

这个方法也是使用CountDownLatch组件时必不可少的一个方法,当一个线程调用上面的await()方法而被阻塞时,通过countDown()方法能将计数器的值(也就是变量state的值)减一,当计数器的值减为0时,阻塞在await()上的线程也就可以正常返回了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
public void countDown() {
// 就像 await() 调用 AQS 的 acquireSharedInterruptibly() 方法一样
// 这里调用 AQS 的 releaseShared() 方法
sync.releaseShared(1);
}

public final boolean releaseShared(int arg) {
if (tryReleaseShared(arg)) { // 该方法由子类 Sync 实现,会将 state--,如果 state 为0了,就返回 true
doReleaseShared(); // 唤醒同步队列中的线程
return true;
}
return false;
}

private void doReleaseShared() {
for (;;) {
Node h = head;
if (h != null && h != tail) {
int ws = h.waitStatus;
if (ws == Node.SIGNAL) {
if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
continue; // loop to recheck cases
unparkSuccessor(h);
}
else if (ws == 0 &&
!compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
continue; // loop on failed CAS
}
if (h == head) // loop if head changed
break;
}
}

CyclicBarrier

CyclicBarrier的作用和CountDownLatch类似,它是在计数器(等待线程数)达到指定数量后,再唤醒等待线程。它的实现和CountDownLatch不同,并没有直接通过AQS实现同步功能,而是在重入锁ReentrantLock的基础上实现的。先来了解一下它的几个成员变量:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
   private final ReentrantLock lock = new ReentrantLock();

private final Condition trip = lock.newCondition();

// 当 parties 个线程到达屏障后,屏障才会放行
private final int parties;

// 还剩下没到达屏障的线程数,会在新一轮开启或者当前屏障被破坏时重置为 parties
private int count;

// 当第 parties 个线程到达时回调
private final Runnable barrierCommand;

// 代表每一轮的运行状况,仅有一个成员变量 broken 表示屏障是否被破坏
private Generation generation = new Generation();

private static class Generation {
boolean broken = false;
}

接下来看看它的构造函数,与CountDownLatch一样需要传入一个计数器的初始值,除此之外,还可以传入一个回调对象,当最后一个线程到达屏障时会执行该回调逻辑:

1
2
3
4
5
6
7
8
9
10
public CyclicBarrier(int parties) {
this(parties, null);
}

public CyclicBarrier(int parties, Runnable barrierAction) {
if (parties <= 0) throw new IllegalArgumentException();
this.parties = parties;
this.count = parties; // 初始时有 parties 个线程未到达屏障
this.barrierCommand = barrierAction;
}

await()

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
public int await() throws InterruptedException, BrokenBarrierException {
try {
return dowait(false, 0L);
} catch (TimeoutException toe) {
throw new Error(toe); // cannot happen
}
}

private int dowait(boolean timed, long nanos)
throws InterruptedException, BrokenBarrierException,
TimeoutException {
final ReentrantLock lock = this.lock;
lock.lock(); // 加锁
try {
final Generation g = generation;
// 如果 g.broken = true,表示屏障被破坏了,这里直接抛出异常
if (g.broken)
throw new BrokenBarrierException();
// 如果线程中断,则调用 breakBarrier() 破坏屏障
if (Thread.interrupted()) {
breakBarrier();
throw new InterruptedException();
}

// index 表示线程到达屏障的顺序,如果为 parties-1 表明当前是第一个到达屏障的
// 如果 index 为0,表示当前线程是最后一个到达屏障的
int index = --count;
if (index == 0) { // 如果 index 为0,唤醒所有处于等待状态的线程
boolean ranAction = false;
try {
final Runnable command = barrierCommand;
if (command != null)
command.run();
ranAction = true;
nextGeneration(); // 重置屏障状态,使其进入新一轮的运行过程中
return 0; // 返回
} finally {
if (!ranAction) // 若执行过程中发生异常,则调用 breakBarrier() 破坏屏障
breakBarrier();
}
}

// 运行到此处的线程都会被屏障挡住,并进入等待状态
for (;;) {
try {
if (!timed) // timed 一般传入 false,因此这里条件成立
trip.await(); // 阻塞在 Condition 上
else if (nanos > 0L)
nanos = trip.awaitNanos(nanos);
} catch (InterruptedException ie) {
if (g == generation && ! g.broken) {
breakBarrier();
throw ie;
} else {
// We're about to finish waiting even if we had not
// been interrupted, so this interrupt is deemed to
// "belong" to subsequent execution.
Thread.currentThread().interrupt();
}
}

// 屏障被破坏,抛出 BrokenBarrierException 异常
if (g.broken)
throw new BrokenBarrierException();

// 屏障进入新的运行轮次,此时返回线程在上一轮次到达屏障的顺序
if (g != generation)
return index;

// 超时判断
if (timed && nanos <= 0L) {
breakBarrier();
throw new TimeoutException();
}
}
} finally {
lock.unlock();
}
}

/**
* 开启新的一轮运行过程
*/
private void nextGeneration() {
// 唤醒所有处于等待状态中的线程
trip.signalAll();
// 重置 count
count = parties;
// 重新创建 Generation
generation = new Generation();
}

/**
* 破坏屏障
*/
private void breakBarrier() {
// 设置屏障被破坏的标志
generation.broken = true;
// 重置 count
count = parties;
// 唤醒所有处于等待状态中的线程
trip.signalAll();
}

reset()

CyclicBarrier的计数器可以在正常结束一轮后自动重置,当然我们也可以使用reset()方法强制重置,代码如下:

1
2
3
4
5
6
7
8
9
10
public void reset() {
final ReentrantLock lock = this.lock;
lock.lock();
try {
breakBarrier(); // 破坏屏障
nextGeneration(); // 开启新一轮的运行过程
} finally {
lock.unlock();
}
}

两者区别

总的来说,CountDownLatchCyclicBarrier能够实现的功能差不多,但是CyclicBarrier可以循环使用,并且可以设置回调,因此对于复杂的业务场景,使用CyclicBarrier更合适一些。关于具体的使用场景可以参考之前的一篇文章:Java中的并发工具类

参考资料