使用ForkJoin并行计算,实现一个Master-Worker并行计算框架
程序员文章站
2022-06-30 20:43:09
...
java.util.concurrent 包提供了 ForkJoin 框架:它把一个大任务分割成许多小任务,并行执行这些小任务以提高效率。它的使用很简单:继承 RecursiveTask(有返回值)或 RecursiveAction(无返回值),在 compute() 方法中实现任务的拆分与结果合并即可。这个工具类也是使用空间换时间的思路。
代码清单一:ForkJoin的使用
package com.jack.jucstudy;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.Future;
import java.util.concurrent.RecursiveTask;
/**
 * Sums every integer in the closed range [start, end] by recursively splitting the range with
 * the fork/join framework, and compares the timing against a plain loop in {@link #main}.
 *
 * <p>Ranges whose width is at most {@link #THRESHOLD} are summed sequentially; wider ranges are
 * halved, the left half is forked and the right half is computed in the current thread.
 */
public class UseForkJoinTask extends RecursiveTask<Long> {
    /**
     * Largest {@code end - start} that is summed sequentially. The original value of 2 produced
     * millions of tiny subtasks whose scheduling overhead dwarfed the additions themselves.
     */
    private static final int THRESHOLD = 10_000;

    private final int start;
    private final int end;

    /**
     * @param start first value of the range (inclusive)
     * @param end   last value of the range (inclusive); if {@code end < start} the sum is 0
     */
    public UseForkJoinTask(Integer start, Integer end) {
        this.start = start;
        this.end = end;
    }

    public static void main(String[] args) throws InterruptedException, ExecutionException {
        long start = System.currentTimeMillis();
        ForkJoinPool pool = new ForkJoinPool();
        UseForkJoinTask ufj = new UseForkJoinTask(0, 10000000);
        Future<Long> submit = pool.submit(ufj);
        // get() blocks until the computation is done; the stop time must be taken after it,
        // otherwise only the near-instant submission is measured.
        Long forkJoinResult = submit.get();
        long end = System.currentTimeMillis();
        pool.shutdown();
        System.out.println(String.format("使用forkJoinTask执行的结果为:%s,使用的时间为:%s毫秒", forkJoinResult, end - start));

        start = System.currentTimeMillis();
        long sum = 0;
        for (int i = 0; i <= 10000000; i++) {
            sum += i;
        }
        end = System.currentTimeMillis();
        System.out.println(String.format("使用普通的for循环执行的结果为:%s,使用的时间为:%s毫秒", sum, end - start));
    }

    @Override
    protected Long compute() {
        // Widen to long: end - start overflows int when the range spans more than Integer.MAX_VALUE.
        if ((long) end - start <= THRESHOLD) {
            return sumDirectly();
        }
        // Computed in long to avoid the int overflow of (start + end) / 2; the result always lies
        // in [start, end) here because the range is wider than THRESHOLD.
        int middle = (int) (start + ((long) end - start) / 2);
        UseForkJoinTask leftTask = new UseForkJoinTask(start, middle);
        UseForkJoinTask rightTask = new UseForkJoinTask(middle + 1, end);
        // Fork only one half and compute the other here, so this worker does useful work
        // instead of merely blocking in join().
        leftTask.fork();
        long rightResult = rightTask.compute();
        long leftResult = leftTask.join();
        return leftResult + rightResult;
    }

    /**
     * Sums [start, end] sequentially. The index is a long so that {@code i <= end} still
     * terminates when end == Integer.MAX_VALUE (an int index would wrap around forever).
     */
    private long sumDirectly() {
        long sum = 0;
        for (long i = start; i <= end; i++) {
            sum += i;
        }
        return sum;
    }
}
运行结果:
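从结果看,使用并行计算的方式似乎可以大大提升效率。但要注意两点:一是计时必须在 submit.get() 返回之后才结束,否则测到的只是提交任务所花的时间,而不是真正的计算时间;二是阈值过小(例如 2)会产生数以百万计的子任务,其调度开销可能远超计算本身,对于这种简单的累加,普通 for 循环往往反而更快。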
自己实现一个并行计算的框架:Master-Worker模式
上图是Master-Worker模式的原理示意图:客户端提交了很多Task,Master需要将这些Task存储在一个任务队列中,然后由各个Worker从队列中取出任务执行,每个Worker是一个工作线程。这种模式也是一种并行计算模式,同样利用以空间换时间的思想提高效率。
代码清单二:Task类
package com.jack.jucstudy.masterworker;
/**
 * A unit of work for the Master-Worker framework: an identifier together with the integer
 * value that a worker records as the task's result.
 */
public class Task {

    /** Identifier of the task (the worker uses it as the result-map key). */
    private String taskId;

    /** Value the worker stores as this task's result. */
    private Integer count;

    /**
     * @param taskId identifier of the task
     * @param count  value recorded as this task's result
     */
    public Task(String taskId, Integer count) {
        this.taskId = taskId;
        this.count = count;
    }

    /** @return the task identifier */
    public String getTaskId() {
        return this.taskId;
    }

    /** @return the value recorded as this task's result */
    public Integer getCount() {
        return this.count;
    }

    /** @param taskId the new task identifier */
    public void setTaskId(String taskId) {
        this.taskId = taskId;
    }

    /** @param count the new result value */
    public void setCount(Integer count) {
        this.count = count;
    }
}
代码清单三:Worker类
package com.jack.jucstudy.masterworker;
import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ThreadLocalRandom;
/**
 * Worker side of the Master-Worker pattern. It repeatedly takes a {@link Task} from the shared
 * queue, simulates some work, and stores the task's count under its id in the shared result
 * map. It counts down the shared latch once for every task it takes, until the queue is empty.
 *
 * <p>One instance may be run by several threads at once (Master does exactly that). All state
 * it mutates lives in the concurrent objects injected through the setters, which must all be
 * called before {@link #run()}.
 */
public class Worker implements Runnable {
    private ConcurrentLinkedQueue<Task> taskQuere;
    private ConcurrentHashMap<String, Integer> resultMap;
    private CountDownLatch countDownLatch;

    @Override
    public void run() {
        Task task;
        while ((task = taskQuere.poll()) != null) {
            try {
                process(task);
            } catch (RuntimeException e) {
                // e.g. a null id or count, which ConcurrentHashMap rejects. Report it and keep
                // draining the queue rather than letting this worker thread die.
                System.err.println(Thread.currentThread().getName() + " failed on task "
                        + task.getTaskId() + ": " + e);
            } finally {
                // Count down for every task taken, even a failed one; otherwise the master's
                // await() on this latch would block forever.
                countDownLatch.countDown();
            }
        }
    }

    /** Simulates the work for one task and records its count as the result. */
    private void process(Task task) {
        System.out.println(Thread.currentThread().getName() + "开始执行任务--" + task.getTaskId());
        try {
            // Simulated task duration: a random multiple of 200 ms, from 0 to 1800 ms.
            // ThreadLocalRandom avoids contention on one Random shared by all worker threads.
            Thread.sleep(200 * ThreadLocalRandom.current().nextInt(10));
        } catch (InterruptedException e) {
            // Restore the interrupt status so the thread's owner can observe it. The task is
            // still recorded below so the result map and latch stay consistent.
            Thread.currentThread().interrupt();
        }
        resultMap.put(task.getTaskId(), task.getCount());
    }

    /** @param taskQuere queue this worker polls for tasks (shared with other workers) */
    public void setTaskQuere(ConcurrentLinkedQueue<Task> taskQuere) {
        this.taskQuere = taskQuere;
    }

    /** @param resultMap map receiving task id -> task count for each processed task */
    public void setResultMap(ConcurrentHashMap<String, Integer> resultMap) {
        this.resultMap = resultMap;
    }

    /** @param countDownLatch latch counted down once per task taken from the queue */
    public void setCountDownLatch(CountDownLatch countDownLatch) {
        this.countDownLatch = countDownLatch;
    }
}
代码清单四:Master类
package com.jack.jucstudy.masterworker;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CountDownLatch;
/**
 * Master side of the Master-Worker pattern. It holds the task queue and starts a fixed number
 * of worker threads that drain it. Once every task has finished, it sums the per-task results.
 *
 * <p>Usage: construct with the number of tasks that will be added, call {@link #addTask} exactly
 * that many times, then call {@link #Execute()} followed by {@link #getResult()}.
 */
public class Master {
    // Queue of pending tasks, shared with every worker thread.
    private final ConcurrentLinkedQueue<Task> taskQueue = new ConcurrentLinkedQueue<>();
    // Worker threads keyed by their index ("0", "1", ...).
    private final Map<String, Thread> workers = new HashMap<>();
    // task id -> task result. Tasks run concurrently, so a thread-safe map is required.
    private final ConcurrentHashMap<String, Integer> resultMap = new ConcurrentHashMap<>();
    // Counted down once per finished task; getResult() waits for it to reach zero.
    private final CountDownLatch countDownLatch;

    /**
     * Creates the worker threads (not yet started).
     *
     * @param workerCount number of worker threads; must be at least 1
     * @param taskCount   number of tasks that will be added before {@link #Execute()}
     * @throws IllegalArgumentException if {@code workerCount < 1} or {@code taskCount < 0}
     */
    public Master(int workerCount, int taskCount) {
        if (workerCount < 1) {
            // With no workers the latch could never reach zero and getResult() would hang.
            throw new IllegalArgumentException("workerCount must be >= 1: " + workerCount);
        }
        countDownLatch = new CountDownLatch(taskCount); // rejects a negative taskCount
        Worker worker = new Worker();
        worker.setResultMap(resultMap);
        worker.setTaskQuere(taskQueue);
        worker.setCountDownLatch(countDownLatch);
        // A single Worker instance is shared by all threads; its state is in the concurrent
        // collections and the latch above.
        for (int i = 0; i < workerCount; i++) {
            workers.put(Integer.toString(i), new Thread(worker));
        }
    }

    /**
     * Adds a task to the queue. Call this before {@link #Execute()}.
     *
     * @param task the task to enqueue
     */
    public void addTask(Task task) {
        taskQueue.add(task);
    }

    /**
     * Starts all worker threads.
     *
     * @throws IllegalStateException if the number of queued tasks differs from the
     *                               {@code taskCount} given to the constructor
     */
    public void Execute() {
        // The latch was sized from taskCount. With fewer queued tasks, getResult() would block
        // forever; with more, it would return before all tasks finish. Fail fast instead.
        long queued = taskQueue.size();
        long expected = countDownLatch.getCount();
        if (queued != expected) {
            throw new IllegalStateException(
                    "expected " + expected + " tasks but " + queued + " were added");
        }
        for (Thread workerThread : workers.values()) {
            workerThread.start();
        }
    }

    /**
     * Waits until every task has finished, then returns the sum of all task results.
     *
     * @return sum of the values stored in the result map
     * @throws IllegalStateException if the calling thread is interrupted while waiting; the
     *                               interrupt status is restored
     */
    public int getResult() {
        try {
            countDownLatch.await();
        } catch (InterruptedException e) {
            // Previously a partial sum was returned after claiming all tasks had finished.
            Thread.currentThread().interrupt();
            throw new IllegalStateException("interrupted while waiting for tasks to finish", e);
        }
        System.out.println("所有的任务已执行完成,开始统计结果");
        int ret = 0;
        for (int value : resultMap.values()) {
            ret += value;
        }
        return ret;
    }
}
代码清单五:测试类
package com.jack.jucstudy.masterworker;
import java.util.Random;
/**
 * Demo driver: queues {@code taskCount} tasks with random counts, runs them on one worker per
 * available processor, and prints the elapsed time and the summed result.
 */
public class MainTest {
    public static void main(String[] args) {
        Random r = new Random();
        int taskCount = 100;
        // One worker thread per available processor.
        Master m = new Master(Runtime.getRuntime().availableProcessors(), taskCount);
        // Create the tasks; each count is random in [0, 20).
        for (int i = 0; i < taskCount; i++) {
            m.addTask(new Task("task-" + i, r.nextInt(20)));
        }
        long start = System.currentTimeMillis();
        m.Execute();
        int ret = m.getResult();
        long end = System.currentTimeMillis();
        // Use taskCount rather than a hard-coded "100" so the message stays correct if it changes.
        System.out.println("执行" + taskCount + "个任务耗时:" + (end - start) + "ms,统计的结果为:" + ret);
    }
}