快速掌握并发编程---CountDownLatch原理和实战

it2023-12-04  95

关注“Java后端技术全栈”

回复“000”获取大量电子书

常见面试题

如何实现让主线程等所有子线程执行完了后,主要线程再继续执行?即如何实现一个线程等其他线程执行完了后再继续执行?

方法一

在前面的文章中我们介绍了Thread类的join方法:快速掌握并发编程---Thread常用方法,join的工作原理是,不停检查thread是否存活,如果存活则让当前线程永远wait,直到thread线程终止,线程的notifyAll就会被调用。

下面我们就使用join来实现上面面试题。

import java.util.Random; import java.util.concurrent.CountDownLatch; public class CountDownLatchDemo {     public static void main(String[] args) {         System.out.println("主要线程开始等待其他子线程执行");         try {             test();         } catch (InterruptedException e) {             e.printStackTrace();         }     }     public static void test() throws InterruptedException {        Thread thread1 = new Thread(() -> {             System.out.println(Thread.currentThread().getName() + " 线程开始");             Random random = new Random();             try {                 Thread.sleep(random.nextInt(10000) + 1000);             } catch (InterruptedException e) {                 e.printStackTrace();             }             System.out.println( Thread.currentThread().getName() + " 线程执行完毕");         },"线程1");        thread1.start();         Thread thread2 = new Thread(() -> {             System.out.println(Thread.currentThread().getName() + " 线程开始");             Random random = new Random();             try {                 Thread.sleep(random.nextInt(10000) + 1000);             } catch (InterruptedException e) {                 e.printStackTrace();             }             System.out.println(Thread.currentThread().getName() + " 线程执行完毕");         },"线程2");         thread2.start();         Thread thread3 = new Thread(() -> {             System.out.println(Thread.currentThread().getName() + " 线程开始");             Random random = new Random();             try {                 Thread.sleep(random.nextInt(10000) + 1000);             } catch (InterruptedException e) {                 e.printStackTrace();             }             System.out.println( Thread.currentThread().getName() + " 线程执行完毕");         },"线程3");         thread3.start();         Thread thread4 = new Thread(() -> {             System.out.println(Thread.currentThread().getName() + " 线程开始");             Random random = new Random();             try {                 Thread.sleep(random.nextInt(10000) + 1000);             } catch (InterruptedException e) {                 e.printStackTrace();             }             System.out.println(Thread.currentThread().getName() + " 线程执行完毕");         },"线程4");         thread4.start();         //启动了四个线程,然后让四个线程一直检测自己是否已经结束         thread1.join();         thread2.join();         thread3.join();         thread4.join();         System.out.println("主线程继续执行");         //todo 业务代码     } }

运行结果

主要线程开始等待其他子线程执行 线程1 线程开始 线程2 线程开始 线程3 线程开始 线程4 线程开始 线程3 线程执行完毕 线程2 线程执行完毕 线程1 线程执行完毕 线程4 线程执行完毕 主线程继续执行

主线程继续干活是要等前面四个线程全部执行完毕后再继续的。但是这么搞有点麻烦,那就是每个线程都得调用join方法,有没有更好玩的的呢?

答案是有的,它来了。

它就是juc下面的一个很牛逼的并发工具类CountDownLatch。是JDK1.5的时候有的,言外之意就是在JDK1.5之前就只能用join方法了。

方法二

CountDownLatch中我们主要用到两个方法一个是await()方法,调用这个方法的线程会被阻塞,另外一个是countDown()方法,调用这个方法会使计数器减一,当计数器的值为0时,因调用await()方法被阻塞的线程会被唤醒,继续执行。请看代码:

import java.util.Random; import java.util.concurrent.CountDownLatch; public class CountDownLatchDemo {     public static void main(String[] args) {         System.out.println("主要线程开始等待其他子线程执行");         test();     }     public static void test() {         int threadCount = 5;         CountDownLatch countDownLatch = new CountDownLatch(threadCount);         for (int i = 0; i < threadCount; i++) {             final int finalI = i + 1;             new Thread(() -> {                 System.out.println("第 " + finalI + " 线程开始");                 Random random = new Random();                 try {                     Thread.sleep(random.nextInt(10000) + 1000);                 } catch (InterruptedException e) {                     e.printStackTrace();                 }                 System.out.println("第 " + finalI + " 线程执行完毕");                 countDownLatch.countDown();             }).start();         }         try {             countDownLatch.await();         } catch (InterruptedException e) {             e.printStackTrace();         }         System.out.println(threadCount + " 个线程全部执行完毕");         System.out.println("主线程继续执行");         //todo业务代码     } }

输出

主要线程开始等待其他子线程执行 第 1 线程开始 第 2 线程开始 第 3 线程开始 第 4 线程开始 第 5 线程开始 第 1 线程执行完毕 第 2 线程执行完毕 第 5 线程执行完毕 第 4 线程执行完毕 第 3 线程执行完毕 5 个线程全部执行完毕 主线程继续执行

面试能把这两种方式说出来,证明你还是可以解决这个问题。

但问题来了,如果面试官问你实现原理,你却回答不上来,就会给人你在瞎用的感觉,这样好不容易前面拿到点好印象结果被打回原型。

至于join的原理,建议去看看我之前发的线程常用方法里:快速掌握并发编程---Thread常用方法,那里面说的很清楚了,所这里就不在重复了。

今天我们着重了了CountDownLatch。

CountDownLatch

概念

CountDownLatch可以使一个获多个线程等待其他线程各自执行完毕后再执行。

CountDownLatch 定义了一个计数器,和一个阻塞队列, 当计数器的值递减为0之前,阻塞队列里面的线程处于挂起状态,当计数器递减到0时会唤醒阻塞队列所有线程,这里的计数器是一个标志,可以表示一个任务一个线程,也可以表示一个倒计时器,CountDownLatch可以解决那些一个或者多个线程在执行之前必须依赖于某些必要的前提业务先执行的场景。

整体

常用方法

构造方法

我们在上面的案例中

 int threadCount = 5;  CountDownLatch countDownLatch = new CountDownLatch(threadCount);

有用到new CountDownLatch(threadCount);来创建一个CountDownLatch实例对象。我们看看这个构造方法

private final Sync sync; public CountDownLatch(int count) {      //记者count值不能小于0     if (count < 0) throw new IllegalArgumentException("count < 0");     //创建一个Sync实例对象入参就是count     this.sync = new Sync(count); }

然后这里有个内部类Sync,这个Sync内部类也没几行代码,Sync继承了AbstractQueuedSynchronizer抽象队列同步器(以下简称AQS)。

private static final class Sync extends AbstractQueuedSynchronizer {         private static final long serialVersionUID = 4982264981922014374L;         //入参count         Sync(int count) {             //这个setState方法还记得否?就是上篇文章AQS中的setState()方法             //就是给AQS中的state赋值,state=count             setState(count);         }         //获取AQS中state的值         int getCount() {             return getState();         }         protected int tryAcquireShared(int acquires) {             return (getState() == 0) ? 1 : -1;         }         //死循环         protected boolean tryReleaseShared(int releases) {             for (;;) {                 //获取AQS中的state                 int c = getState();                 //如果AQS中的state==0,就返回false                 if (c == 0)  return false;                 int nextc = c-1;                 //nextc=state-1                 //                 if (compareAndSetState(c, nextc))                     return nextc == 0;             }         }  }
countDown方法
public void countDown() {     //调用的就是AQS中的方法     sync.releaseShared(1); }

AQS中releaseShared方法

public final boolean releaseShared(int arg) {     // arg 为固定值 1     // 如果计数器state 为0 返回true,前提是调用 countDown() 之前不能已经为0     //tryReleaseShared在AQS是空方法     if (tryReleaseShared(arg)) {       // 唤醒等待队列的线程         doReleaseShared();          return true;     }     return false; } protected boolean tryReleaseShared(int arg) {    throw new UnsupportedOperationException(); }

这个方法tryReleaseShared()是在CountDownLatch中内部类Sync中实现的

//其实也没什么新招 //还是死循环+CAS配合 实现计数器state减1 protected boolean tryReleaseShared(int releases) {     // Decrement count; signal when transition to zero     for (;;) {         int c = getState();         if (c == 0)  return false;         int nextc = c-1;         if (compareAndSetState(c, nextc))            return nextc == 0;      } }

方法doReleaseShared却是AQS种实现的(因为CountDownLatch和其内部类都没有实现,只有AQS实现了,那就只认AQS中的实现了)。

//实现思路就是从头到尾的遍历列队中所有的节点为shared状态的 private void doReleaseShared() {         //死循环         for (;;) {             //获取当前列队的头节点             Node h = head;             //列队中可能为空列队,也有可能只有一个node节点             if (h != null && h != tail) {                 //获取头节点的状态                 int ws = h.waitStatus;                 //如果头节点为SIGNAL状态, 说明后继节点需要唤醒                 if (ws == Node.SIGNAL) {                     //将头结点的waitstatue设置为0,以后就不会再次唤醒后继节点了。                     //这一步是为了解决并发问题,保证只unpark一次!!不成功就继续                     if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))                         continue;            // loop to recheck cases                     //(释放)唤醒头节点的后继节点                     unparkSuccessor(h);                 }// 状态为0并且不成功,继续                 else if (ws == 0 && !compareAndSetWaitStatus(h, 0, Node.PROPAGATE))                     continue;// loop on failed CAS             }             // 若头结点改变,继续循环               if (h == head) // loop if head changed                 break;         } }

整个调用逻辑大致为

await方法

在CountDownLatch中await犯法

public void await() throws InterruptedException {    sync.acquireSharedInterruptibly(1); }

然后调用AQS中的

public final void acquireSharedInterruptibly(int arg) throws InterruptedException {       //判断是否被中断过       if (Thread.interrupted()) throw new InterruptedException();       //如果state不等于0的时候       if (tryAcquireShared(arg) < 0){             doAcquireSharedInterruptibly(arg);       } }

其中方法tryAcquireShared(arg)是CountDownLatch的内部类Sync的tryAcquireShared方法

protected int tryAcquireShared(int acquires) {   //判断AQS中的state是否已经等于0了,等于翻译1否则返回-1   return (getState() == 0) ? 1 : -1; }

再调用AQS中的doAcquireSharedInterruptibly方法

 //这个方法就是将当前线程封装成node节点加入到列队中,并判断是否需要阻塞当前线程  //这个节点都会被设置成shared状态,这样做的目的时当state值为0时会唤醒所有shared的节点 private void doAcquireSharedInterruptibly(int arg)         throws InterruptedException {         //这个方法应该很熟悉了吧,前面的文章都介绍过,将当前线程封装成节点加入到列队中         final Node node = addWaiter(Node.SHARED);         boolean failed = true;         try {             //(又是死循环)一直执行,直到获取锁,返回             for (;;) {                 //获取前驱节点                 final Node p = node.predecessor();                 //前驱节点为头结点                 if (p == head) {                     //所以再次尝试获取信号量,这就是上面分析的那个获取方法                     int r = tryAcquireShared(arg);                     //如果r大于0证明获取信号量获取成功了证明下一个可以获取信号量的线程是当前线程                     if (r >= 0) {                         //将当前节点变成列队的head节点然后返回                         setHeadAndPropagate(node, r);                         //方便GC                         p.next = null;                          failed = false;                         return;                     }                 }                //判断是否要进入阻塞状态.如果shouldParkAfterFailedAcquire方法                //返回true,表示需要进入阻塞                //调用parkAndCheckInterrupt;否则表示还可以再次尝试获取锁,继续进行for循环                 if (shouldParkAfterFailedAcquire(p, node) &&                     parkAndCheckInterrupt())                     throw new InterruptedException();             }         } finally {             //失败就放弃             if (failed){                 cancelAcquire(node);             }         } }

方法shouldParkAfterFailedAcquire是AQS的

//p是前驱结点,node是当前结点 private static boolean shouldParkAfterFailedAcquire(Node pred, Node node) {     int ws = pred.waitStatus; //获取前驱节点的状态     if (ws == Node.SIGNAL) //表明前驱节点可以运行         return true;     if (ws > 0) { //如果前驱节点状态大于0表明已经中断,         do {             node.prev = pred = pred.prev;          } while (pred.waitStatus > 0);         pred.next = node;     } else {         //等于0进入这里         compareAndSetWaitStatus(pred, ws, Node.SIGNAL);      }     //只有前节点状态为NodeSIGNAL才返回真     return false;  }

我们对shouldParkAfterFailedAcquire来进行一个整体的概述,首先应该明白节点的状态,节点的状态是为了表明当前线程的良好度,如果当前线程被打断了,在唤醒的过程中是不是应该忽略该线程

    static final class Node {         static final int CANCELLED =  1;         static final int SIGNAL    = -1;         static final int CONDITION = -2;         static final int PROPAGATE = -3;        //....

目前你只需知道大于0时表明该线程已近被取消,已近是无效节点,不应该被唤醒,注意:初始化链头节点时头节点状态值为0。

shouldParkAfterFailedAcquire是位于无限for循环内的,这一点需要注意一般每个节点都会经历两次循环后然后被阻塞。

在AQS的doAcquireSharedInterruptibly中可能会再次调用CountDownLatch的内部类Sync的tryAcquireShared方法和AQS的setHeadAndPropagate方法。setHeadAndPropagate方法源码如下。

private void setHeadAndPropagate(Node node, int propagate) {         // 获取头结点         Node h = head;          // 设置头结点         setHead(node);         // 进行判断         if (propagate > 0 || h == null || h.waitStatus < 0 ||             (h = head) == null || h.waitStatus < 0) {             // 获取节点的后继             Node s = node.next;             if (s == null || s.isShared()) // 后继为空或者为共享模式                 // 以共享模式进行释放                 doReleaseShared();         }     }

该方法设置头结点并且释放头结点后面的满足条件的结点,该方法中可能会调用到AQS的doReleaseShared方法,其源码如下。

private void doReleaseShared() {         // 无限循环         for (;;) {             // 保存头结点             Node h = head;             if (h != null && h != tail) { // 头结点不为空并且头结点不为尾结点                 // 获取头结点的等待状态                 int ws = h.waitStatus;                  if (ws == Node.SIGNAL) { // 状态为SIGNAL                     if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0)) // 不成功就继续                         continue;            // loop to recheck cases                     // 释放后继结点                     unparkSuccessor(h);                 }                 else if (ws == 0 &&                          !compareAndSetWaitStatus(h, 0, Node.PROPAGATE)) // 状态为0并且不成功,继续                     continue;                // loop on failed CAS             }             if (h == head) // 若头结点改变,继续循环                   break;         }     }

CountDownLatch的await调用大致会有如下的调用链

经典使用场景

CountDownLatch的一个非常典型的应用场景是:有一个任务想要往下执行,但必须要等到其他的任务执行完毕后才可以继续往下执行。假如我们这个想要继续往下执行的任务调用一个CountDownLatch对象的await()方法,其他的任务执行完自己的任务后调用同一个CountDownLatch对象上的countDown()方法,这个调用await()方法的任务将一直阻塞等待,直到这个CountDownLatch对象的计数值减到0为止。

案例1

举个例子,有三个工人在为老板干活,这个老板有一个习惯,就是当三个工人把一天的活都干完了的时候,他就来检查所有工人所干的活。记住这个条件:三个工人先全部干完活,老板才检查。

案例2

比如读取excel表格,需要把execl表格中的多个sheet进行数据汇总,为了提高汇总的效率我们一般会开启多个线程,每个线程去读取一个sheet,可是线程执行是异步的,我们不知道什么时候数据处理结束了。那么这个时候我们就可以运用CountDownLatch,有几个sheet就把state初始化几。每个线程执行完就调用countDown()方法,在汇总的地方加上await()方法,当所有线程执行完了,就可以进行数据的汇总了。

END

扫描关注公众号“Java后端技术全栈”

解锁程序员的狂野世界

最新回复(0)