CountDownLatch详解

简介

       CountDownLatch是一个同步工具类,它允许一个或多个线程等待,直到在其它线程中执行的一组操作完成。

       CountDownLatch随JDK 1.5一起引入,并与java.util.concurrent包中的其它并发实用程序(如CyclicBarrier、Semaphore、ConcurrentHashMap和BlockingQueue)一起引入。

CountDownLatch 的原理

       CountDownLatch是通过一个计数器来实现的,维护一个count的变量,并且其操作都是原子操作,计数器的初始值为线程的数量。每当一个线程完成了自己的任务后,计数器的值就会减1,当计数器值到达0时,表示所有的线程已经完成了任务,然后在闭锁上等待的线程就可以恢复执行任务。

       CountDownLatch主要通过countDown()和await()两个方法实现功能,首先建立CountDownLatch对象,并且传入参数即为count初始值。如果一个线程调用了await()方法,那么这个线程便进入阻塞状态,并进入阻塞队列。如果一个线程调用了countDown()方法,则会使count-1;当count的值为0时,这时候阻塞队列中调用await()方法的线程便会逐个被唤醒,从而进入后续的操作。

CountDownLatch 原理图

       CountDownLatch的伪代码可以这样编写:

1
2
3
4
5
6
//主线程启动
//为N个线程创建CountDownLatch
//创建并启动N个线程
//主线程在闩锁上等待
// N个线程完成那里的任务并返回
//主线程恢复执行

CountDownLatch 如何工作

       CountDownLatch类定义了一个构造函数:

1
2
//Constructs a CountDownLatch initialized with the given count.
public CountDownLatch(int count) {...}

       此计数本质上是闩锁应等待的线程数。该值只能设置一次,并且CountDownLatch没有提供其它机制来重置此count。

       第一次与CountDownLatch的交互是与等待其它线程的主线程进行的。此主线程必须在启动其它线程后立即调用CountDownLatch.await()方法,这样主线程的操作就会在这个方法上阻塞,直到其它线程完成各自的任务为止。

       其它N个线程必须引用闩锁对象,因为它们如果完成了任务需要通知CountDownLatch对象。该通知通过CountDownLatch.countDown()方法完成,每次调用计数减少1。当所有N个线程都调用此方法时,计数达到0,主线程可以在await()方法之后继续执行。

       重要的三个方法如下:

1
2
3
4
5
6
//调用await()方法的线程会被挂起,它会等待直到count值为0才继续执行
public void await() throws InterruptedException { };
//和await()类似,只不过等待一定的时间后count值还没变为0的话就会继续执行
public boolean await(long timeout, TimeUnit unit) throws InterruptedException { };
//将count值减1
public void countDown() { };

CountDownLatch 的使用场景

实现最大的并行性

       想同时启动多个线程,实现最大程度的并行性。例如想测试一个单例类。如果创建一个初始计数器为1的CountDownLatch并让其它所有线程都在这个锁上等待,只需要调用一次countDown()方法就可以让其它所有等待的线程同时恢复执行。

开始执行前等待 N 个线程完成各自任务

       例如应用程序启动类要确保在处理用户请求前,所有N个外部系统都已经启动和运行了。

死锁检测

       用N个线程去访问共享资源,在每个测试阶段线程数量不同,并尝试产生死锁。

CountDownLatch 使用案例

案例一

       模拟一个应用程序启动类,开始就启动N个线程,去检查N个外部服务是否正常并通知闩锁。启动类一直在闩锁上等待,一旦验证和检查了所有外部服务,就恢复启动类执行。

       BaseHealthChecker实现Runnable接口,并且是所有特定外部服务运行状况检查程序的父类。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
public abstract class BaseHealthChecker implements Runnable {

private CountDownLatch _latch;
private String _serviceName;
private boolean _serviceUp;

//Get latch object in constructor so that after completing the task, thread can countDown() the latch
public BaseHealthChecker(String _serviceName, CountDownLatch _latch) {
this._latch = _latch;
this._serviceName = _serviceName;
this._serviceUp = false;
}

@Override
public void run() {
try {
verifyService();
_serviceUp = true;
} catch (Throwable t) {
t.printStackTrace(System.err);
_serviceUp = false;
} finally {
if (_latch != null) {
_latch.countDown();
}
}
}

public String getServiceName() {
return _serviceName;
}

public boolean isServiceUp() {
return _serviceUp;
}

//This methods needs to be implemented by all specific service checker
public abstract void verifyService();
}

       以下三个类都继承自BaseHealthChecker,引用CountDownLatch 实例,除了服务名和休眠时间不同外,都实现各自的verifyService()方法。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
public class NetworkHealthChecker extends BaseHealthChecker {

public NetworkHealthChecker(CountDownLatch latch) {
super("Network Service", latch);
}

@Override
public void verifyService() {
System.out.println("Checking " + this.getServiceName());
try {
Thread.sleep(7000);
} catch (InterruptedException e) {
e.printStackTrace();
}
System.out.println(this.getServiceName() + " is UP");
}
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
public class DatabaseHealthChecker extends BaseHealthChecker {

public DatabaseHealthChecker(CountDownLatch latch) {
super("Database Service", latch);
}

@Override
public void verifyService() {
System.out.println("Checking " + this.getServiceName());
try {
Thread.sleep(2000);
} catch (InterruptedException e) {
e.printStackTrace();
}
System.out.println(this.getServiceName() + " is UP");
}
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
public class CacheHealthChecker extends BaseHealthChecker {

public CacheHealthChecker(CountDownLatch latch) {
super("Cache Service", latch);
}

@Override
public void verifyService() {
System.out.println("Checking " + this.getServiceName());
try {
Thread.sleep(5000);
} catch (InterruptedException e) {
e.printStackTrace();
}
System.out.println(this.getServiceName() + " is UP");
}
}

       ApplicationStartupUtil类是主要的启动类,它将初始化闩锁并等待该闩锁,所有服务都被检查完成后,再恢复执行。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
public class ApplicationStartupUtil {

//List of service checkers
private static List<BaseHealthChecker> _services;

//This latch will be used to wait on
private static CountDownLatch _latch;

private ApplicationStartupUtil() {
}

private final static ApplicationStartupUtil INSTANCE = new ApplicationStartupUtil();

public static ApplicationStartupUtil getInstance() {
return INSTANCE;
}

public static boolean checkExternalServices() throws Exception {
//Initialize the latch with number of service checkers
_latch = new CountDownLatch(3);

//All add checker in lists
_services = new ArrayList<BaseHealthChecker>();
_services.add(new NetworkHealthChecker(_latch));
_services.add(new CacheHealthChecker(_latch));
_services.add(new DatabaseHealthChecker(_latch));

//Start service checkers using executor framework
Executor executor = Executors.newFixedThreadPool(_services.size());

for (final BaseHealthChecker v : _services) {
executor.execute(v);
}

//Now wait till all services are checked
_latch.await();

//Services are file and now proceed startup
for (final BaseHealthChecker v : _services) {
if (!v.isServiceUp()) {
return false;
}
}
return true;
}
}

       测试代码如下:

1
2
3
4
5
6
7
8
9
public static void main(String[] args) {
boolean result = false;
try {
result = ApplicationStartupUtil.checkExternalServices();
} catch (Exception e) {
e.printStackTrace();
}
System.out.println("External services validation completed !! Result was :: " + result);
}
1
2
3
4
5
6
7
Checking Cache Service
Checking Database Service
Checking Network Service
Database Service is UP
Cache Service is UP
Network Service is UP
External services validation completed !! Result was :: true

通过join实现CountDownLatch功能

       下面的例子有两个操作,一个是读操作一个是写操作,现在规定必须进行完写操作才能进行读操作。所以当最开始调用读操作时,需要用await()方法使其阻塞,当写操作结束时,则需要使count等于0。因此count的初始值可以定为写操作的记录数,这样便可以使得进行完写操作,然后进行读操作。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
public class CountDownLatchDemo {

private final static CountDownLatch cdl = new CountDownLatch(3);
private final static Vector v = new Vector();

private static class WriteThread extends Thread {
private final String writeThreadName;
private final int stopTime;
private final String str;

public WriteThread(String name, int time, String str) {
this.writeThreadName = name;
this.stopTime = time;
this.str = str;
}

public void run() {
System.out.println(writeThreadName + "开始写入工作");
try {
Thread.sleep(stopTime);
} catch (InterruptedException e) {
e.printStackTrace();
}
cdl.countDown();
v.add(str);
System.out.println(writeThreadName + "写入内容为:" + str + "。写入工作结束!");
}
}

private static class ReadThread extends Thread {
public void run() {
System.out.println("读操作之前必须先进行写操作");
try {
cdl.await();//该线程进行等待,直到countDown减到0,然后逐个苏醒过来。
//Thread.sleep(3000);
} catch (InterruptedException e) {
e.printStackTrace();
}
for (int i = 0; i < v.size(); i++) {
System.out.println("读取第" + (i + 1) + "条记录内容为:" + v.get(i));
}
System.out.println("读操作结束!");
}
}

public static void main(String[] args) {
new ReadThread().start();
new WriteThread("writeThread1", 1000, "多线程知识点").start();
new WriteThread("writeThread2", 2000, "多线程CountDownLatch的知识点").start();
new WriteThread("writeThread3", 3000, "多线程中控制顺序可以使用CountDownLatch").start();
}
}

       运行代码,结果如下:

1
2
3
4
5
6
7
8
9
10
11
读操作之前必须先进行写操作
writeThread1开始写入工作
writeThread2开始写入工作
writeThread3开始写入工作
writeThread1写入内容为:多线程知识点。写入工作结束!
writeThread2写入内容为:多线程CountDownLatch的知识点。写入工作结束!
writeThread3写入内容为:多线程中控制顺序可以使用CountDownLatch。写入工作结束!
读取第1条记录内容为:多线程知识点
读取第2条记录内容为:多线程CountDownLatch的知识点
读取第3条记录内容为:多线程中控制顺序可以使用CountDownLatch
读操作结束!

       从以上过程可以看出,可以使得先进行写操作然后进行读操作。

       其实上述CountDownLatch这种功能可以通过Thread对象的join方法实现同样的功能,只是这里无须调用await()方法和countDown()方法,而是使用sleep()进行控制时间,然后将读操作以及写操作通过在主线程通过join()方法使其加入主线程,使其实现只有进行写操作结束,才能进行读操作。具体代码如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
public class JoinDemo {
private final static Vector v = new Vector();
Lock lock = new ReentrantLock();
final Condition condition = lock.newCondition();//创建condition对象

private static class WriteThread extends Thread {
private final String writeThreadName;
private final int stopTime;
private final String str;
Lock lock = new ReentrantLock();
final Condition condition = lock.newCondition();//创建condition对象

public WriteThread(String name, int time, String str) {
this.writeThreadName = name;
this.stopTime = time;
this.str = str;
}

public void run() {
System.out.println(writeThreadName + "开始写入工作");
try {
Thread.sleep(stopTime);
} catch (InterruptedException e) {
e.printStackTrace();
}

v.add(str);
System.out.println(writeThreadName + "写入内容为:" + str + "。写入工作结束!");
}
}

private static class ReadThread extends Thread {
Lock lock = new ReentrantLock();
final Condition condition = lock.newCondition();//创建condition对象

public void run() {
System.out.println("读操作之前必须先进行写操作");
try {
Thread.sleep(10000);//该线程进行暂停,时间控制在写操作结束才使线程苏醒过来。
} catch (InterruptedException e) {
e.printStackTrace();
}
for (int i = 0; i < v.size(); i++) {
System.out.println("读取第" + (i + 1) + "条记录内容为:" + v.get(i));
}
System.out.println("读操作结束!");
}
}

public static void main(String[] args) {

ReadThread readThread = new ReadThread();
readThread.start();
long start = System.currentTimeMillis();
Thread[] write = new Thread[3];
String[] str = {"多线程知识点", "多线程CountDownLatch的知识点", "多线程中控制顺序可以使用CountDownLatch"};

for (int i = 0; i < 3; i++) {
Thread t1 = new WriteThread("writeThread" + (i + 1), 1000 * (i + 1), str[i]);
t1.start();
write[i] = t1;
}
try {
readThread.join();
} catch (InterruptedException e) {
e.printStackTrace();
}
//等待线程结束
for (Thread t : write) {
try {
t.join();
} catch (InterruptedException e) {
e.printStackTrace();
}
}
//等待线程结束
}
}

       运行上述程序,可以得到如下结果:

1
2
3
4
5
6
7
8
9
10
11
读操作之前必须先进行写操作
writeThread2开始写入工作
writeThread1开始写入工作
writeThread3开始写入工作
writeThread1写入内容为:多线程知识点。写入工作结束!
writeThread2写入内容为:多线程CountDownLatch的知识点。写入工作结束!
writeThread3写入内容为:多线程中控制顺序可以使用CountDownLatch。写入工作结束!
读取第1条记录内容为:多线程知识点
读取第2条记录内容为:多线程CountDownLatch的知识点
读取第3条记录内容为:多线程中控制顺序可以使用CountDownLatch
读操作结束!

       通过上述的比较,可以通过join方法实现CountDownLatch的按顺序执行线程的功能,但是CountDownLatch有join实现不了的情况,比如使用线程池时,线程池的线程不能直接使用,所以只能使用CountDownLatch实现按顺序执行线程,而无法使用join()方法。具体代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
public class CountDownLatchDemo2 {

private final static CountDownLatch cdl = new CountDownLatch(3);
private final static Vector v = new Vector();
private final static ThreadPoolExecutor threadPool = new ThreadPoolExecutor(10, 15, 60, TimeUnit.SECONDS, new LinkedBlockingQueue<Runnable>());//使用线程池

private static class WriteThread extends Thread {
private final String writeThreadName;
private final int stopTime;
private final String str;

public WriteThread(String name, int time, String str) {
this.writeThreadName = name;
this.stopTime = time;
this.str = str;
}

public void run() {
System.out.println(writeThreadName + "开始写入工作");
try {
Thread.sleep(stopTime);
} catch (InterruptedException e) {
e.printStackTrace();
}
cdl.countDown();
v.add(str);
System.out.println(writeThreadName + "写入内容为:" + str + "。写入工作结束!");
}
}

private static class ReadThread extends Thread {
public void run() {
System.out.println("读操作之前必须先进行写操作");
try {
cdl.await();//该线程进行等待,直到countDown减到0,然后逐个苏醒过来。
//Thread.sleep(3000);
} catch (InterruptedException e) {
e.printStackTrace();
}
for (int i = 0; i < v.size(); i++) {
System.out.println("读取第" + (i + 1) + "条记录内容为:" + v.get(i));
}
System.out.println("读操作结束!");
}
}

public static void main(String[] args) {
Thread read = new ReadThread();
threadPool.execute(read);
String[] str = {"多线程知识点", "多线程CountDownLatch的知识点", "多线程中控制顺序可以使用CountDownLatch"};
for (int i = 0; i < 3; i++) {
Thread t1 = new WriteThread("writeThread" + (i + 1), 1000 * (i + 1), str[i]);
threadPool.execute(t1);
}
//new WriteThread("writeThread1",1000,"多线程知识点").start();
//new WriteThread("writeThread2",2000,"多线程CountDownLatch的知识点").start();
//new WriteThread("writeThread3",3000,"多线程中控制顺序可以使用CountDownLatch").start();
}
}

       运行如上程序,得到以下结果:

1
2
3
4
5
6
7
8
9
10
11
读操作之前必须先进行写操作
writeThread1开始写入工作
writeThread3开始写入工作
writeThread2开始写入工作
writeThread1写入内容为:多线程知识点。写入工作结束!
writeThread2写入内容为:多线程CountDownLatch的知识点。写入工作结束!
writeThread3写入内容为:多线程中控制顺序可以使用CountDownLatch。写入工作结束!
读取第1条记录内容为:多线程知识点
读取第2条记录内容为:多线程CountDownLatch的知识点
读取第3条记录内容为:多线程中控制顺序可以使用CountDownLatch
读操作结束!

总结

       CountDownLatch是一次性的,计数器的值只能在构造方法中初始化一次,之后没有任何机制再次对其设置值,当CountDownLatch使用完毕后,不能再次被使用。

       Thread的join()方法可以实现相同的功能,但是当使用了线程池时,则join()方法便无法实现,CountDownLatch依然可以实现功能。

       CountDownLatch类主要使用的场景有明显的顺序要求。比如只有等跑完步才能计算排名,只有等所有记录都写入才能进行统计工作等等,因此CountDownLatch完善的是某种逻辑上的功能,使得线程按照正确的逻辑进行。

参考资料:
HowToDoInJava Java concurrency – CountDownLatch Example
羽杰 CountDownLatch 相关整理
carson0408 CountDownLatch的工作原理以及实例
指尖架构141319 countDownLatch

Fork me on GitHub