Java并发编程之CountDownLatch源码解析

38次阅读

共计 5626 个字符,预计需要花费 15 分钟才能阅读完成。

一、导语
最近在学习并发编程原理,所以准备整理一下自己学到的知识,先写一篇 CountDownLatch 的源码分析,之后希望可以慢慢写完整个并发编程。
二、什么是 CountDownLatch
CountDownLatch 是 java 的 JUC 并发包里的一个工具类,可以理解为一个倒计时器,主要是用来控制多个线程之间的通信。比如有一个主线程 A,它要等待其他 4 个子线程执行完毕之后才能执行,此时就可以利用 CountDownLatch 来实现这种功能了。
三、简单使用
public static void main(String[] args){
System.out.println(“ 主线程和他的两个小兄弟约好去吃火锅 ”);
System.out.println(“ 主线程进入了饭店 ”);
System.out.println(“ 主线程想要开始动筷子吃饭 ”);
//new 一个计数器,初始值为 2,当计数器为 0 时,主线程开始执行
CountDownLatch latch = new CountDownLatch(2);

new Thread(){
public void run() {
try {
System.out.println(“ 子线程 1——小兄弟 A 正在到饭店的路上 ”);
Thread.sleep(3000);
System.out.println(“ 子线程 1——小兄弟 A 到饭店了 ”);
// 一个小兄弟到了,计数器 -1
latch.countDown();
} catch (InterruptedException e) {
e.printStackTrace();
}
};
}.start();

new Thread(){
public void run() {
try {
System.out.println(“ 子线程 2——小兄弟 B 正在到饭店的路上 ”);
Thread.sleep(3000);
System.out.println(“ 子线程 2——小兄弟 B 到饭店了 ”);
// 另一个小兄弟到了,计数器 -1
latch.countDown();
} catch (InterruptedException e) {
e.printStackTrace();
}
};
}.start();

// 主线程等待,直到其他两个小兄弟也进入饭店(计数器 ==0),主线程才能吃饭
latch.await();
System.out.println(“ 主线程终于可以开始吃饭了~”);
}
四、源码分析
核心代码:
CountDownLatch latch = new CountDownLatch(1);
latch.await();
latch.countDown();
其中构造函数的参数是计数器的值;await() 方法是用来阻塞线程,直到计数器的值为 0 countDown() 方法是执行计数器 - 1 操作
1、首先来看构造函数的代码
public CountDownLatch(int count) {
if (count < 0) throw new IllegalArgumentException(“count < 0”);
this.sync = new Sync(count);
}
这段代码很简单,首先 if 判断传入的 count 是否 <0,如果小于 0 直接抛异常。然后 new 一个类 Sync,这个 Sync 是什么呢?我们一起来看下
private static final class Sync extends AbstractQueuedSynchronizer {
private static final long serialVersionUID = 4982264981922014374L;

Sync(int count) {
setState(count);
}

int getCount() {
return getState();
}
// 尝试获取共享锁
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}
// 尝试释放共享锁
protected boolean tryReleaseShared(int releases) {
// Decrement count; signal when transition to zero
for (;;) {
int c = getState();
if (c == 0)
return false;
int nextc = c-1;
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}
}
可以看到 Sync 是一个内部类,继承了 AQS,AQS 是一个同步器,之后我们会详细讲。其中有几个核心点:

变量 state 是父类 AQS 里面的变量,在这里的语义是计数器的值
getState() 方法也是父类 AQS 里的方法,很简单,就是获取 state 的值
tryAcquireShared 和 tryReleaseShared 也是父类 AQS 里面的方法,在这里 CountDownLatch 对他们进行了重写,先有个印象,之后详讲。

2、了解了 CountDownLatch 的构造函数之后,我们再来看它的核心代码,首先是 await()。
public void await() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}
可以看到,其实是通过内部类 Sync 调用了父类 AQS 的 acquireSharedInterruptibly() 方法。
public final void acquireSharedInterruptibly(int arg)
throws InterruptedException {
// 判断线程是否是中断状态
if (Thread.interrupted())
throw new InterruptedException();
// 尝试获取 state 的值
if (tryAcquireShared(arg) < 0)//step1
doAcquireSharedInterruptibly(arg);//step2
}
tryAcquireShared(arg) 这个方法就是我们刚才在 Sync 内看到的重写父类 AQS 的方法,意思就是判断是否 getState() == 0,如果 state 为 0, 返回 1,则 step1 处不进入 if 体内 acquireSharedInterruptibly(int arg) 方法执行完毕。若 state!=0,则返回 -1,进入 if 体内 step2 处。
下面我们来看 acquireSharedInterruptibly(int arg) 方法:
private void doAcquireSharedInterruptibly(int arg)
throws InterruptedException {
//step1、把当前线程封装为共享类型的 Node,加入队列尾部
final Node node = addWaiter(Node.SHARED);
boolean failed = true;
try {
for (;;) {
//step2、获取当前 node 的前一个元素
final Node p = node.predecessor();
//step3、如果前一个元素是队首
if (p == head) {
//step4、再次调用 tryAcquireShared() 方法,判断 state 的值是否为 0
int r = tryAcquireShared(arg);
//step5、如果 state 的值 ==0
if (r >= 0) {
//step6、设置当前 node 为队首,并尝试释放共享锁
setHeadAndPropagate(node, r);
p.next = null; // help GC
failed = false;
return;
}
}
//step7、是否可以安心挂起当前线程,是就挂起;并且判断当前线程是否中断
if (shouldParkAfterFailedAcquire(p, node) &&
parkAndCheckInterrupt())
throw new InterruptedException();
}
} finally {
//step8、如果出现异常,failed 没有更新为 false,则把当前 node 从队列中取消
if (failed)
cancelAcquire(node);
}
}
按照代码中的注释,我们可以大概了解该方法的内容,下面我们来仔细看下其中调用的一些方法是干什么的。1、首先看 addWaiter()
//step1
private Node addWaiter(Node mode) {
// 把当前线程封装为 node
Node node = new Node(Thread.currentThread(), mode);
// Try the fast path of enq; backup to full enq on failure
// 获取当前队列的队尾 tail,并赋值给 pred
Node pred = tail;
// 如果 pred!=null,即当前队尾不为 null
if (pred != null) {
// 把当前队尾 tail,变成当前 node 的前继节点
node.prev = pred;
//cas 更新当前 node 为新的队尾
if (compareAndSetTail(pred, node)) {
pred.next = node;
return node;
}
}
// 如果队尾为空,走 enq 方法
enq(node);//step1.1
return node;
}

—————————————————————–
//step1.1
private Node enq(final Node node) {
for (;;) {
Node t = tail;
// 如果队尾 tail 为 null,初始化队列
if (t == null) {// Must initialize
//cas 设置一个新的空 node 为队首
if (compareAndSetHead(new Node()))
tail = head;
} else {
//cas 把当前 node 设置为新队尾,把前队尾设置成当前 node 的前继节点
node.prev = t;
if (compareAndSetTail(t, node)) {
t.next = node;
return t;
}
}
}
}
2、接下来我们在来看 setHeadAndPropagate() 方法,看其内部实现
//step6
private void setHeadAndPropagate(Node node, int propagate) {
// 获取队首 head
Node h = head; // Record old head for check below
// 设置当前 node 为队首,并取消 node 所关联的线程
setHead(node);
//
if (propagate > 0 || h == null || h.waitStatus < 0 ||
(h = head) == null || h.waitStatus < 0) {
Node s = node.next;
// 如果当前 node 的后继节点为 null 或者是 shared 类型的
if (s == null || s.isShared())
// 释放锁,唤醒下一个线程
doReleaseShared();//step6.1
}
}
——————————————————————–
//step6.1
private void doReleaseShared() {
for (;;) {
// 找到头节点
Node h = head;
if (h != null && h != tail) {
// 获取头节点状态
int ws = h.waitStatus;
if (ws == Node.SIGNAL) {
if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
continue; // loop to recheck cases
// 唤醒 head 节点的 next 节点
unparkSuccessor(h);
}
else if (ws == 0 &&
!compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
continue; // loop on failed CAS
}
if (h == head) // loop if head changed
break;
}
}
3、接下来我们来看 countDown() 方法。
public void countDown() {
sync.releaseShared(1);
}
可以看到调用的是父类 AQS 的 releaseShared 方法
public final boolean releaseShared(int arg) {
//state-1
if (tryReleaseShared(arg)) {//step1
// 唤醒等待线程,内部调用的是 LockSupport.unpark 方法
doReleaseShared();//step2
return true;
}
return false;
}
——————————————————————
//step1
protected boolean tryReleaseShared(int releases) {
// Decrement count; signal when transition to zero
for (;;) {
// 获取当前 state 的值
int c = getState();
if (c == 0)
return false;
int nextc = c-1;
//cas 操作来进行原子减 1
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}
五、总结
CountDownLatch 主要是通过计数器 state 来控制是否可以执行其他操作,如果不能就通过 LockSupport.park() 方法挂起线程,直到其他线程执行完毕后唤醒它。下面我们通过一个简单的图来帮助我们理解一下:PS:本人也是还在学习的路上,理解的也不是特别透彻,如有错误,愿倾听教诲。^_^

正文完
 0