目录
ForkJoinWorkerThread
简单说明
一、主程序
二、fork/join线程工厂类
三、自定义fork/join线程类
四、分治的任务类
五、执行结果
ForkJoinWorkerThread
该类继承自Thread类,并提供了以下可供子类重写的扩展方法:
onStart()方法:在工作线程开始运行、尚未执行任何任务之前调用,可用于初始化;重写时必须在方法开头调用super.onStart()。onTermination()方法:在工作线程结束时调用,可用于资源清理;重写时必须在方法末尾调用super.onTermination()。
ForkJoinPool通过ForkJoinWorkerThreadFactory接口的实现来创建其使用的工作线程。
简单说明
要创建自定义的ForkJoin线程,需要扩展(即继承)ForkJoinWorkerThread类。由于线程池是通过线程工厂来创建工作线程的,所以还要实现ForkJoinWorkerThreadFactory接口,使其返回自定义的ForkJoin线程对象,并在构造ForkJoinPool时传入该工厂。
一、主程序
package xyz.jangle.thread.test.n8_7.forkjointhreadfactory;
import java.util.Arrays;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;
/**
 * 8.7: Implementing a custom fork/join worker-thread class (extending what fork/join
 * worker threads can do).
 *
 * <p>Demo entry point: builds a {@link ForkJoinPool} whose workers are created by
 * {@link MyWorkerThreadFactory}, sums an array of 100,000 ones with
 * {@link MyRecursiveTask}, then shuts the pool down and waits for it, so every worker's
 * {@code onTermination} hook runs before the result is printed.
 *
 * @author jangle
 * @email jangle@jangle.xyz
 * @time 2020-10-03 17:33:20
 */
public class M {

    /**
     * Runs the demo.
     *
     * @param args unused
     * @throws Exception if the task fails or waiting for pool termination is interrupted
     */
    public static void main(String[] args) throws Exception {
        // Custom factory so the pool creates MyWorkerThread instances.
        var factory = new MyWorkerThreadFactory();
        // Parallelism 4, no custom uncaught-exception handler (null), asyncMode=false
        // (the default, locally stack-based scheduling of forked tasks).
        var pool = new ForkJoinPool(4, factory, null, false);

        int[] array = new int[100000];
        Arrays.fill(array, 1);

        // Task that sums the whole array by divide and conquer.
        var task = new MyRecursiveTask(array, 0, array.length);
        int result;
        try {
            // invoke() submits the task and waits for its result in one call.
            result = pool.invoke(task);
        } finally {
            // Shut down even if the task failed, so the workers are allowed to terminate.
            pool.shutdown();
        }
        // Wait for the workers to finish; their onTermination hooks print per-thread counts.
        pool.awaitTermination(1, TimeUnit.DAYS);

        System.out.println("Main: resutl:" + result);
        System.out.println("Main:结束");
    }
}
二、fork/join线程工厂类
package xyz.jangle.thread.test.n8_7.forkjointhreadfactory;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinPool.ForkJoinWorkerThreadFactory;
import java.util.concurrent.ForkJoinWorkerThread;
/**
 * Worker-thread factory that makes a {@link ForkJoinPool} create
 * {@link MyWorkerThread} instances instead of plain {@link ForkJoinWorkerThread}s.
 *
 * @author jangle
 * @email jangle@jangle.xyz
 * @time 2020-10-03 17:43:11
 */
public class MyWorkerThreadFactory implements ForkJoinWorkerThreadFactory {

    /**
     * Builds a new custom worker for the requesting pool.
     *
     * @param pool the pool that asked for a new worker thread
     * @return a freshly constructed {@link MyWorkerThread} bound to {@code pool}
     */
    @Override
    public ForkJoinWorkerThread newThread(ForkJoinPool pool) {
        final MyWorkerThread worker = new MyWorkerThread(pool);
        return worker;
    }
}
三、自定义fork/join线程类
package xyz.jangle.thread.test.n8_7.forkjointhreadfactory;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinWorkerThread;
/**
 * Custom fork/join worker thread that counts how many tasks it executes.
 *
 * <p>{@link #onStart()} prints the thread id and resets the counter;
 * {@link #onTermination(Throwable)} prints the thread id together with the counter value.
 *
 * @author jangle
 * @email jangle@jangle.xyz
 * @time 2020-10-03 17:34:00
 */
public class MyWorkerThread extends ForkJoinWorkerThread {

    /**
     * Number of tasks recorded by the current thread.
     *
     * <p>Because this is a {@link ThreadLocal}, the value belongs to whichever thread calls
     * {@link #addTask()}, not necessarily to the {@code MyWorkerThread} object it is invoked
     * on. Starting every thread at 0 avoids a NullPointerException (auto-unboxing of a
     * missing value) on threads where {@link #onStart()} has not run.
     */
    private static final ThreadLocal<Integer> taskCounter = ThreadLocal.withInitial(() -> 0);

    /**
     * Creates a worker bound to the given pool.
     *
     * @param pool the pool this worker belongs to
     */
    protected MyWorkerThread(ForkJoinPool pool) {
        super(pool);
    }

    /** Prints the thread id and resets this thread's task counter to 0. */
    @Override
    protected void onStart() {
        // The ForkJoinWorkerThread contract requires super.onStart() first.
        super.onStart();
        System.out.println("MyWorkerThread: onStart getId():" + getId());
        taskCounter.set(0);
    }

    /**
     * Prints the thread id and the number of tasks this thread recorded.
     *
     * @param exception the exception that caused termination, or null if it completed normally
     */
    @Override
    protected void onTermination(Throwable exception) {
        System.out.println("MyWorkerThread: onTermination " + getId() + ":" + taskCounter.get());
        // The ForkJoinWorkerThread contract requires super.onTermination() last.
        super.onTermination(exception);
    }

    /**
     * Increments the task counter of the calling thread. Intended to be called by a worker
     * on itself, e.g. {@code ((MyWorkerThread) Thread.currentThread()).addTask()}.
     */
    public void addTask() {
        taskCounter.set(taskCounter.get() + 1);
    }
}
四、分治的任务类
package xyz.jangle.thread.test.n8_7.forkjointhreadfactory;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.RecursiveTask;
import java.util.concurrent.TimeUnit;
/**
 * Divide-and-conquer task that sums the elements {@code array[start]} through
 * {@code array[end - 1]}.
 *
 * <p>Ranges of at most 100 elements are summed directly; larger ranges are split in half
 * and both halves are run with {@code invokeAll}. Each {@link #compute()} call records
 * itself on the executing {@link MyWorkerThread} (when it runs on one) and then sleeps
 * 10 ms before returning.
 *
 * @author jangle
 * @email jangle@jangle.xyz
 * @time 2020-10-03 17:53:45
 */
public class MyRecursiveTask extends RecursiveTask<Integer> {
    private static final long serialVersionUID = 1L;

    /** Largest range length that is summed directly instead of being split. */
    private static final int THRESHOLD = 100;

    private final int[] array;
    private final int start;
    private final int end;

    /**
     * Creates a task summing {@code array[start, end)}.
     *
     * @param array the data to sum (not copied)
     * @param start first index, inclusive
     * @param end   last index, exclusive
     */
    public MyRecursiveTask(int[] array, int start, int end) {
        super();
        this.array = array;
        this.start = start;
        this.end = end;
    }

    /**
     * Sums this task's range.
     *
     * @return the sum of {@code array[start, end)}
     */
    @Override
    protected Integer compute() {
        recordTaskOnCurrentWorker();
        int ret;
        if (end - start <= THRESHOLD) {
            ret = sumDirectly();
        } else {
            // Unsigned shift avoids int overflow of (start + end) on very large arrays.
            int mid = (start + end) >>> 1;
            var task1 = new MyRecursiveTask(array, start, mid);
            var task2 = new MyRecursiveTask(array, mid, end);
            invokeAll(task1, task2);
            ret = addResults(task1, task2);
        }
        pause();
        return ret;
    }

    /** Counts this task on the current worker, if it is a MyWorkerThread (else no-op). */
    private static void recordTaskOnCurrentWorker() {
        Thread current = Thread.currentThread();
        if (current instanceof MyWorkerThread) {
            ((MyWorkerThread) current).addTask();
        }
    }

    /** Sums the range sequentially. */
    private int sumDirectly() {
        int sum = 0;
        for (int i = start; i < end; i++) {
            sum += array[i];
        }
        return sum;
    }

    /**
     * Adds the results of two subtasks that {@code invokeAll} has already completed.
     * invokeAll rethrows any subtask failure, so join() here cannot hide an error behind
     * a default value.
     */
    private Integer addResults(MyRecursiveTask task1, MyRecursiveTask task2) {
        return task1.join() + task2.join();
    }

    /**
     * Artificial 10 ms delay (presumably to simulate work so tasks spread across workers).
     * If interrupted, the thread's interrupt status is restored rather than swallowed.
     */
    private static void pause() {
        try {
            TimeUnit.MILLISECONDS.sleep(10);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }
}
五、执行结果
MyWorkerThread: onStart getId():13
MyWorkerThread: onStart getId():14
MyWorkerThread: onStart getId():15
MyWorkerThread: onStart getId():16
MyWorkerThread: onStart getId():17
MyWorkerThread: onTermination 15:569
MyWorkerThread: onTermination 16:576
MyWorkerThread: onTermination 13:428
MyWorkerThread: onTermination 17:0
MyWorkerThread: onTermination 14:474
Main: resutl:100000
Main:结束