欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

Java ForkJoinPool: 3秒计算100万的阶乘

程序员文章站 2022-06-19 14:57:53
Java ForkJoinPool: 3秒计算100万的阶乘...

问题背景&思路

如果需要计算100的阶乘,那应该怎么做?

方法1:
for循环(默认,单线程)

方法2:
多线程,MapReduce思想
main线程开启多个子任务(个数=CPU核心数),放到线程池执行,每个子任务统计from ~ to的正整数,然后返回本次计算的结果;然后main线程拿到所有子任务的返回结果后再次统计
任务: SumTask implements Callable,返回统计结果Future

方法3:
ForkJoinPool,递归任务,MapReduce思想
任务:SumTask extends RecursiveTask,返回统计结果BigDecimal
任务中判断,当需要统计的数>某个阈值时,拆分成两个任务;否则直接执行并返回
这个阈值时多少合适?经过测试后,当计算10000的阶乘时,方法1略快于方法2,所以这个阈值就定位10000

上代码

为了方便测试,先定义了一个接口:

package com.wz.poc.forkjoin;

import java.math.BigDecimal;

/**
 * @author liweizhi
 * @date 2020/12/31
 */
public interface Calculator {
    /**
     * 求传进来数的阶乘
     *
     * @param number
     * @return 总和
     */
    BigDecimal factorial(long number);
}

单线程循环

package com.wz.poc.forkjoin;

import java.math.BigDecimal;

/**
 * @author liweizhi
 * @date 2020/12/31
 */
public class ForLoopCalculator implements Calculator {
    @Override
    public BigDecimal factorial(long number) {
        BigDecimal ret = new BigDecimal(1);
        for (long i = 1; i <= number; i++) {
            ret = ret.multiply(BigDecimal.valueOf(i));
        }
        return ret;
    }
}

普通线程池,多线程

package com.wz.poc.forkjoin;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

/**
 * @author liweizhi
 * @date 2020/12/31
 */
public class ExecutorServiceCalculator implements Calculator {
    @Override
    public BigDecimal factorial(long number) {
        List<Future<BigDecimal>> results = new ArrayList<>();

        // 把任务分解为 n 份,交给 n 个线程处理   4核心 就等分成4份呗
        // 然后把每一份都扔个一个SumTask线程 进行处理
        long pageSize = number / parallism;
        for (int i = 0; i < parallism; i++) {
            long from = i * pageSize + 1; //开始位置
            long to = i == parallism - 1 ? number : Math.min(from + pageSize - 1, number); //结束位置

            //扔给线程池计算
            results.add(pool.submit(new SumTask(from, to)));
        }

        // 把每个线程的结果相加,得到最终结果 get()方法 是阻塞的
        // 优化方案:可以采用CompletableFuture来优化  JDK1.8的新特性
        BigDecimal ret = BigDecimal.valueOf(1);
        for (Future<BigDecimal> f : results) {
            try {
                ret = f.get().multiply(ret);
            } catch (Exception ignore) {
            }
        }

        // 方便测试程序退出
        pool.shutdown();
        return ret;
    }

    private static final int parallism = Runtime.getRuntime().availableProcessors();
    private static final ExecutorService pool = Executors.newFixedThreadPool(parallism);

    //处理计算任务的线程
    private static class SumTask implements Callable<BigDecimal> {

        private long from;
        private long to;

        public SumTask(long from, long to) {
            this.from = from;
            this.to = to;
        }

        @Override
        public BigDecimal call() {
            BigDecimal ret = new BigDecimal(1);
            for (long i = from; i <= to; i++) {
                ret = ret.multiply(BigDecimal.valueOf(i));
            }
            return ret;
        }
    }

}

ForkJoinPool

package com.wz.poc.forkjoin;

import java.math.BigDecimal;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

/**
 * @author liweizhi
 * @date 2020/12/31
 */
public class ForkJoinCalculator implements Calculator {
    @Override
    public BigDecimal factorial(long number) {
        BigDecimal invoke = pool.invoke(new SumTask(1, number));
        // 方便测试程序退出
        pool.shutdown();
        return invoke;
    }

    private static final ForkJoinPool pool = new ForkJoinPool();

    //执行任务RecursiveTask:有返回值  RecursiveAction:无返回值
    private static class SumTask extends RecursiveTask<BigDecimal> {
        private long from;
        private long to;

        public SumTask(long from, long to) {
            this.from = from;
            this.to = to;
        }

        //此方法为ForkJoin的核心方法:对任务进行拆分  拆分的好坏决定了效率的高低
        @Override
        protected BigDecimal compute() {
            // 当需要计算的数字个数小于1_0000时,直接采用for loop方式计算结果
            if (to - from < 1_0000) {
                BigDecimal ret = new BigDecimal(1);
                for (long i = from; i <= to; i++) {
                    ret = ret.multiply(BigDecimal.valueOf(i));
                }
                return ret;
            } else { // 否则,把任务一分为二,递归拆分(注意此处有递归)到底拆分成多少分 需要根据具体情况而定
                long middle = (from + to) / 2;
                SumTask taskLeft = new SumTask(from, middle);
                SumTask taskRight = new SumTask(middle + 1, to);
                taskLeft.fork();
                taskRight.fork();
                return taskLeft.join().multiply(taskRight.join());
            }
        }
    }

}

测试main方法

package com.wz.poc.forkjoin;

import java.math.BigDecimal;
import java.time.Duration;
import java.time.Instant;

/**
 * @author liweizhi
 * @date 2020/12/31
 */
public class MainTest {
    public static void main(String[] args) {
        long numbers = 100_0000;

        Calculator forLoopCalculator = new ForLoopCalculator();
        Calculator executorServiceCalculator = new ExecutorServiceCalculator();
        Calculator forkJoinCalculator = new ForkJoinCalculator();

        Instant start, end;

        // 热热身
        forLoopCalculator.factorial(10000);

        start = Instant.now();
        BigDecimal result_1 = forLoopCalculator.factorial(numbers);
        end = Instant.now();
        System.out.println("forLoopCalculator耗时:" + Duration.between(start, end).toMillis() + "ms");

        start = Instant.now();
        BigDecimal result_2 = executorServiceCalculator.factorial(numbers);
        end = Instant.now();
        System.out.println("executorServiceCalculator:" + Duration.between(start, end).toMillis() + "ms");

        start = Instant.now();
        BigDecimal result_3 = forkJoinCalculator.factorial(numbers);
        end = Instant.now();
        System.out.println("forkJoinCalculator:" + Duration.between(start, end).toMillis() + "ms");

        System.out.println("三者是否相等" + (result_1.equals(result_2) && result_1.equals(result_3)));
    }
}

测试结果

电脑信息

我的电脑是一台笔记本,联想的thinkbook2021,
鲁大师电脑概览:

电脑型号	联想 20VF 笔记本电脑  (扫描时间:2020年12月31日)
操作系统	Windows 10 64位 ( DirectX 12 )
	
处理器	AMD Ryzen 5 4600U with Radeon Graphics 六核
主板	联想 LNVNB161216 ( AMD PCI 标准主机 CPU 桥 )
内存	16 GB ( DDR4 3200MHz )
主硬盘	三星 MZALQ512HALU-000L2 ( 512 GB / 固态硬盘 )
主显卡	AMD Radeon Graphics ( 512 MB / 联想 )
显示器	友达 AUO683D ( 14 英寸  )
声卡	瑞昱  @ AMD High Definition Audio 控制器
网卡	瑞昱 RTL8168/8111/8112 Gigabit Ethernet Controller / 联想

三种方法输出结果(计算1万,10万,100万的阶乘)

10000:
forLoopCalculator耗时:27ms
executorServiceCalculator:28ms
forkJoinCalculator:36ms

10_0000:
forLoopCalculator耗时:2905ms
executorServiceCalculator:393ms
forkJoinCalculator:151ms


100_0000:
forLoopCalculator耗时:351565ms
executorServiceCalculator:17955ms
forkJoinCalculator:2667ms
三者是否相等true

本文地址:https://blog.csdn.net/weixin_42008012/article/details/112008229