欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

死磕 java线程系列之ForkJoinPool深入解析

程序员文章站 2023-10-29 11:24:28
(手机横屏看源码更方便) 注:java源码分析部分如无特殊说明均基于 java8 版本。 注:本文基于ForkJoinPool分治线程池类。 简介 随着在硬件上多核处理器的发展和广泛使用,并发编程成为程序员必须掌握的一门技术,在面试中也经常考查面试者并发相关的知识。 今天,我们就来看一道面试题: 如 ......

死磕 java线程系列之ForkJoinPool深入解析

(手机横屏看源码更方便)


注:java源码分析部分如无特殊说明均基于 java8 版本。

注:本文基于forkjoinpool分治线程池类。

简介

随着在硬件上多核处理器的发展和广泛使用,并发编程成为程序员必须掌握的一门技术,在面试中也经常考查面试者并发相关的知识。

今天,我们就来看一道面试题:

如何充分利用多核cpu,计算很大数组中所有整数的和?

剖析

  • 单线程相加?

我们最容易想到就是单线程相加,一个for循环搞定。

  • 线程池相加?

如果进一步优化,我们会自然而然地想到使用线程池来分段相加,最后再把每个段的结果相加。

  • 其它?

yes,就是我们今天的主角——forkjoinpool,但是它要怎么实现呢?似乎没怎么用过哈^^

三种实现

ok,剖析完了,我们直接来看三种实现,不墨迹,直接上菜。

/**
 * 计算1亿个整数的和
 */
public class forkjoinpooltest01 {
    public static void main(string[] args) throws executionexception, interruptedexception {
        // 构造数据
        int length = 100000000;
        long[] arr = new long[length];
        for (int i = 0; i < length; i++) {
            arr[i] = threadlocalrandom.current().nextint(integer.max_value);
        }
        // 单线程
        singlethreadsum(arr);
        // threadpoolexecutor线程池
        multithreadsum(arr);
        // forkjoinpool线程池
        forkjoinsum(arr);

    }

    private static void singlethreadsum(long[] arr) {
        long start = system.currenttimemillis();

        long sum = 0;
        for (int i = 0; i < arr.length; i++) {
            // 模拟耗时,本文由公从号“彤哥读源码”原创
            sum += (arr[i]/3*3/3*3/3*3/3*3/3*3);
        }

        system.out.println("sum: " + sum);
        system.out.println("single thread elapse: " + (system.currenttimemillis() - start));

    }

    private static void multithreadsum(long[] arr) throws executionexception, interruptedexception {
        long start = system.currenttimemillis();

        int count = 8;
        executorservice threadpool = executors.newfixedthreadpool(count);
        list<future<long>> list = new arraylist<>();
        for (int i = 0; i < count; i++) {
            int num = i;
            // 分段提交任务
            future<long> future = threadpool.submit(() -> {
                long sum = 0;
                for (int j = arr.length / count * num; j < (arr.length / count * (num + 1)); j++) {
                    try {
                        // 模拟耗时
                        sum += (arr[j]/3*3/3*3/3*3/3*3/3*3);
                    } catch (exception e) {
                        e.printstacktrace();
                    }
                }
                return sum;
            });
            list.add(future);
        }

        // 每个段结果相加
        long sum = 0;
        for (future<long> future : list) {
            sum += future.get();
        }

        system.out.println("sum: " + sum);
        system.out.println("multi thread elapse: " + (system.currenttimemillis() - start));
    }

    private static void forkjoinsum(long[] arr) throws executionexception, interruptedexception {
        long start = system.currenttimemillis();

        forkjoinpool forkjoinpool = forkjoinpool.commonpool();
        // 提交任务
        forkjointask<long> forkjointask = forkjoinpool.submit(new sumtask(arr, 0, arr.length));
        // 获取结果
        long sum = forkjointask.get();

        forkjoinpool.shutdown();

        system.out.println("sum: " + sum);
        system.out.println("fork join elapse: " + (system.currenttimemillis() - start));
    }

    private static class sumtask extends recursivetask<long> {
        private long[] arr;
        private int from;
        private int to;

        public sumtask(long[] arr, int from, int to) {
            this.arr = arr;
            this.from = from;
            this.to = to;
        }

        @override
        protected long compute() {
            // 小于1000的时候直接相加,可灵活调整
            if (to - from <= 1000) {
                long sum = 0;
                for (int i = from; i < to; i++) {
                    // 模拟耗时
                    sum += (arr[i]/3*3/3*3/3*3/3*3/3*3);
                }
                return sum;
            }

            // 分成两段任务,本文由公从号“彤哥读源码”原创
            int middle = (from + to) / 2;
            sumtask left = new sumtask(arr, from, middle);
            sumtask right = new sumtask(arr, middle, to);

            // 提交左边的任务
            left.fork();
            // 右边的任务直接利用当前线程计算,节约开销
            long rightresult = right.compute();
            // 等待左边计算完毕
            long leftresult = left.join();
            // 返回结果
            return leftresult + rightresult;
        }
    }
}

彤哥偷偷地告诉你,实际上计算1亿个整数相加,单线程是最快的,我的电脑大概是100ms左右,使用线程池反而会变慢。

所以,为了演示forkjoinpool的牛逼之处,我把每个数都/3*3/3*3/3*3/3*3/3*3了一顿操作,用来模拟计算耗时。

来看结果:

sum: 107352457433800662
single thread elapse: 789
sum: 107352457433800662
multi thread elapse: 228
sum: 107352457433800662
fork join elapse: 189

可以看到,forkjoinpool相对普通线程池还是有很大提升的。

问题:普通线程池能否实现forkjoinpool这种计算方式呢,即大任务拆中任务,中任务拆小任务,最后再汇总?

死磕 java线程系列之ForkJoinPool深入解析

你可以试试看(-᷅_-᷄)

ok,下面我们正式进入forkjoinpool的解析。

分治法

  • 基本思想

把一个规模大的问题划分为规模较小的子问题,然后分而治之,最后合并子问题的解得到原问题的解。

  • 步骤

(1)分割原问题:

(2)求解子问题:

(3)合并子问题的解为原问题的解。

在分治法中,子问题一般是相互独立的,因此,经常通过递归调用算法来求解子问题。

  • 典型应用场景

(1)二分搜索

(2)大整数乘法

(3)strassen矩阵乘法

(4)棋盘覆盖

(5)归并排序

(6)快速排序

(7)线性时间选择

(8)汉诺塔

forkjoinpool继承体系

forkjoinpool是 java 7 中新增的线程池类,它的继承体系如下:

死磕 java线程系列之ForkJoinPool深入解析

forkjoinpool和threadpoolexecutor都是继承自abstractexecutorservice抽象类,所以它和threadpoolexecutor的使用几乎没有多少区别,除了任务变成了forkjointask以外。

这里又运用到了一种很重要的设计原则——开闭原则——对修改关闭,对扩展开放。

可见整个线程池体系一开始的接口设计就很好,新增一个线程池类,不会对原有的代码造成干扰,还能利用原有的特性。

forkjointask

两个主要方法

  • fork()

fork()方法类似于线程的thread.start()方法,但是它不是真的启动一个线程,而是将任务放入到工作队列中。

  • join()

join()方法类似于线程的thread.join()方法,但是它不是简单地阻塞线程,而是利用工作线程运行其它任务。当一个工作线程中调用了join()方法,它将处理其它任务,直到注意到目标子任务已经完成了。

三个子类

  • recursiveaction

无返回值任务。

  • recursivetask

有返回值任务。

  • countedcompleter

无返回值任务,完成任务后可以触发回调。

forkjoinpool内部原理

forkjoinpool内部使用的是“工作窃取”算法实现的。

死磕 java线程系列之ForkJoinPool深入解析

(1)每个工作线程都有自己的工作队列workqueue;

(2)这是一个双端队列,它是线程私有的;

(3)forkjointask中fork的子任务,将放入运行该任务的工作线程的队头,工作线程将以lifo的顺序来处理工作队列中的任务;

(4)为了最大化地利用cpu,空闲的线程将从其它线程的队列中“窃取”任务来执行;

(5)从工作队列的尾部窃取任务,以减少竞争;

(6)双端队列的操作:push()/pop()仅在其所有者工作线程中调用,poll()是由其它线程窃取任务时调用的;

(7)当只剩下最后一个任务时,还是会存在竞争,是通过cas来实现的;

死磕 java线程系列之ForkJoinPool深入解析

forkjoinpool最佳实践

(1)最适合的是计算密集型任务,本文由公从号“彤哥读源码”原创;

(2)在需要阻塞工作线程时,可以使用managedblocker;

(3)不应该在recursivetask的内部使用forkjoinpool.invoke()/invokeall();

总结

(1)forkjoinpool特别适合于“分而治之”算法的实现;

(2)forkjoinpool和threadpoolexecutor是互补的,不是谁替代谁的关系,二者适用的场景不同;

(3)forkjointask有两个核心方法——fork()和join(),有三个重要子类——recursiveaction、recursivetask和countedcompleter;

(4)forkjoinpool内部基于“工作窃取”算法实现;

(5)每个线程有自己的工作队列,它是一个双端队列,自己从队列头存取任务,其它线程从尾部窃取任务;

(6)forkjoinpool最适合于计算密集型任务,但也可以使用managedblocker以便用于阻塞型任务;

(7)recursivetask内部可以少调用一次fork(),利用当前线程处理,这是一种技巧;

彩蛋

managedblocker怎么使用?

答:managedblocker相当于明确告诉forkjoinpool框架要阻塞了,forkjoinpool就会启另一个线程来运行任务,以最大化地利用cpu。

请看下面的例子,自己琢磨哈^^。

/**
 * 斐波那契数列
 * 一个数是它前面两个数之和
 * 1,1,2,3,5,8,13,21
 */
public class fibonacci {

    public static void main(string[] args) {
        long time = system.currenttimemillis();
        fibonacci fib = new fibonacci();
        int result = fib.f(1_000).bitcount();
        time = system.currenttimemillis() - time;
        system.out.println("result,本文由公从号“彤哥读源码”原创 = " + result);
        system.out.println("test1_000() time = " + time);
    }

    public biginteger f(int n) {
        map<integer, biginteger> cache = new concurrenthashmap<>();
        cache.put(0, biginteger.zero);
        cache.put(1, biginteger.one);
        return f(n, cache);
    }

    private final biginteger reserved = biginteger.valueof(-1000);

    public biginteger f(int n, map<integer, biginteger> cache) {
        biginteger result = cache.putifabsent(n, reserved);
        if (result == null) {

            int half = (n + 1) / 2;

            recursivetask<biginteger> f0_task = new recursivetask<biginteger>() {
                @override
                protected biginteger compute() {
                    return f(half - 1, cache);
                }
            };
            f0_task.fork();

            biginteger f1 = f(half, cache);
            biginteger f0 = f0_task.join();

            long time = n > 10_000 ? system.currenttimemillis() : 0;
            try {

                if (n % 2 == 1) {
                    result = f0.multiply(f0).add(f1.multiply(f1));
                } else {
                    result = f0.shiftleft(1).add(f1).multiply(f1);
                }
                synchronized (reserved) {
                    cache.put(n, result);
                    reserved.notifyall();
                }
            } finally {
                time = n > 10_000 ? system.currenttimemillis() - time : 0;
                if (time > 50)
                    system.out.printf("f(%d) took %d%n", n, time);
            }
        } else if (result == reserved) {
            try {
                reservedfibonacciblocker blocker = new reservedfibonacciblocker(n, cache);
                forkjoinpool.managedblock(blocker);
                result = blocker.result;
            } catch (interruptedexception e) {
                throw new cancellationexception("interrupted");
            }

        }
        return result;
        // return f(n - 1).add(f(n - 2));
    }

    private class reservedfibonacciblocker implements forkjoinpool.managedblocker {
        private biginteger result;
        private final int n;
        private final map<integer, biginteger> cache;

        public reservedfibonacciblocker(int n, map<integer, biginteger> cache) {
            this.n = n;
            this.cache = cache;
        }

        @override
        public boolean block() throws interruptedexception {
            synchronized (reserved) {
                while (!isreleasable()) {
                    reserved.wait();
                }
            }
            return true;
        }

        @override
        public boolean isreleasable() {
            return (result = cache.get(n)) != reserved;
        }
    }
}

欢迎关注我的公众号“彤哥读源码”,查看更多源码系列文章, 与彤哥一起畅游源码的海洋。

死磕 java线程系列之ForkJoinPool深入解析