欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

ThreadLocal 相关,看这里就够了

程序员文章站 2024-03-05 18:32:07
...


写在开头

       在 Java 的多线程模块中,ThreadLocal是经常被提问到的一个知识点。

1.什么是 ThreadLocal

       早在 JDK 1.2 中就已经为我们提供了ThreadLocal类。ThreadLocal 类的出现,为解决多线程程序的并发问题提供了一种新的思路。使用这个工具类,我们可以很简洁的来编写多线程程序。

       当多线程同步访问共享数据时,为了保证数据的原子操作,我们可以使用 1.synchronized加锁方式或者2.使用AtomicInteger等原子类的方式来解决多线程之间的共享数据问题。很多人会将 ThreadLocal 和共享数据等为一谈,但这两个却是完全不同的概念。

       重点区别:共享数据是多个线程对同一数据的访问;而ThreadLocal 使用到的是线程封闭(即:数据都被封闭在各自的线程之中,各自线程操作各自的值,并不会涉及到多线程的同步问题,这种通过将数据封闭在各自线程中而避免使用同步的技术,就叫做线程封闭

       ThreadLocal 作为一个线程级变量。在多线程任务中,ThreadLocal 会自动的在每一个线程上创建一个它的副本,副本之间彼此独立,互不影响。每个线程都会拥有自己独立的一个 ThreadLocal 变量(线程封闭概念的体现)。类似于单线程操作各自变量,不存在任何的竞争,在并发模式下是绝对安全的变量。通常我们使用static final来修饰定义 ThreadLocal,定义如下:

private static final ThreadLocal<T> threadLocal = new ThreadLocal<T>();

2.ThreadLocal源码

       ThreadLocal 作为一个工具类来说,它的使用是比较简单的。因为该类只提供了:set(T value)get()remove()三个方法供我们使用。在介绍源码前,先来简单说明一下 ThreadLocal 的实现思路。

2.1 实现思路

  1. 上面介绍我们知道,每个线程都有自己独立的一个 ThreadLocal 副本;
  2. ThreadLocal 数据,是保存在 ThreadLocal 类下的 ThreadLocalMap 中;(ThreadLocalMap 是 ThreadLocal 类下的一个内部类);
  3. get() 从 ThreadLocal 中取值,就是将当前线程作为key,从自己线程的 ThreadLocalMap 中依次循环获取指定的value值。

2.2 分析前你该知道

1.ThreadLocalMap 规定了 table 的大小必须是2的幂次方

/**
 * The initial capacity -- MUST be a power of two.
 */
private static final int INITIAL_CAPACITY = 16;

       这个规定其实在 ConcurrentHashMap中也有提及到,而且在很多面试中也会被问到这个问题,这两个问题的原理都是一样的。从计算机的角度来讲,它对位运算的操作效率明显高于数学运算

       在 ThreadLocal 中,它是通过int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);这个公式来计算数据落位的下标。比如说当前 table 长度为16,那么 16-1=15,二进制就是1111。firstKey.threadLocalHashCode计算的是当前线程的 HashCode 值,比如说为23,二进制就是00010111。将这两个数字进行 & 运算(与运算)。结果如下:

00010111
&
00001111 ------------》&运算结果为:00000111

       结果:0000011转为十进制就是7,同我们取模运算 23/16=7 结果相同。结果相同,效率优先,显然就选择位运算了。此处仅仅是拿 23 举个例子,你可以随机举例,但是你会发现 &运算后,结果永远都是在0-15这个范围内,正好可以匹配数组的下标。机智吧,位运算用到如此地步,牛逼。 ConcurrentHashMap也是通过key.hash & (length-1) 的公式来计算下标的。

2.3 源码分析

接下来就从这三个方法入手,来了解 ThreadLocal 的源码实现。

1.set(T value)

1.1 提前罗列出:set() 源码分析会用到的一些变量

/**
 * The initial capacity -- MUST be a power of two.
 */
private static final int INITIAL_CAPACITY = 16;//必须是2的幂次方

/**
 * The table, resized as necessary.
 * table.length MUST always be a power of two.
 */
private Entry[] table;//是一个Entry[]数组

/**
 * The number of entries in the table.
 */
private int size = 0;

/**
 * The next size value at which to resize.
 */
private int threshold; // Default to 0  //扩容使用

1.2 set(T value) 方法源码分析

public void set(T value) {
	//获取当前线程(调用者线程)
	Thread t = Thread.currentThread();
	//以当前线程作为key值,去查找对应的线程变量,找到对应的Map
	ThreadLocalMap map = getMap(t);    //返回来的是一个 ThreadLocal.ThreadLocalMap对象
	//如果map不等于null,就直接添加本地变量,key为当前线程,值为要添加的变量值
	if (map != null)
		//在下面1.3中分析
		map.set(this, value);
	//如果 map == null,说明是首次添加,需要首先创建对应的Map
	else
		//创建Map方法,向下看
		createMap(t, value);
}

void createMap(Thread t, T firstValue) {
	//使用构造器的方式创建Map,源码分析继续向下看
	t.threadLocals = new ThreadLocalMap(this, firstValue);
}

ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
	//初始化table
	table = new Entry[INITIAL_CAPACITY];
	//通过公式计算得到下标
	int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
	//当前线程为key,值为value,组装Entry后,赋值到table数组某个坑位中
	table[i] = new Entry(firstKey, firstValue);
	size = 1;
	//扩容相关
	setThreshold(INITIAL_CAPACITY);
}

1.3 map.set(this, value);方法实现

private void set(ThreadLocal<?> key, Object value) {

	// We don't use a fast path as with get() because it is at
	// least as common to use set() to create new entries as
	// it is to replace existing ones, in which case, a fast
	// path would fail more often than not.
	
	Entry[] tab = table;
	int len = tab.length;
	//获取table下标
	int i = key.threadLocalHashCode & (len-1);
	//for循环,循环遍历判断当前坑位是否有值,有值的话开始比较,key相同的话,值覆盖;key为空的话,赋值;
	//key不相同的话,使用nextIndex()方法,下标 i+1,继续判断坑位是否为空,为空赋值,不为空继续判断,直到扩容(此处不介绍扩容)
	for (Entry e = tab[i];
		 e != null;
		 e = tab[i = nextIndex(i, len)]) {
		ThreadLocal<?> k = e.get();
		//
		if (k == key) {
			e.value = value;
			return;
		}

		if (k == null) {
			replaceStaleEntry(key, value, i);
			return;
		}
	}

	tab[i] = new Entry(key, value);
	int sz = ++size;
	if (!cleanSomeSlots(i, sz) && sz >= threshold)
		rehash();
}

//nextIndex 算法
private static int nextIndex(int i, int len) {
	return ((i + 1 < len) ? i + 1 : 0);
}
set源码分析小结

1.获取当前线程,再获取当前线程的 ThreadLocalMap;
2.ThreadLocalMap 为空,则通过 createMap() 方法创建,并赋值;不为空的话,使用 map.set(this, value);赋值;
3. 赋值时,根据int i = key.threadLocalHashCode & (len-1);公式获取当前线程所在 table(即:Entry[]数组)的下标,进行赋值;
4.在赋值时,通过for循环进行判断。当前下标有值并且下标==当前线程,进行覆盖    当前下标有值并且下标 != 当前线程,使用nextIndex() 方法,下标+1,继续判断    当前下标没值,进行赋值操作。

2.get()

    如果你已经理解了set(T value)方法的实现,接下来的get()方法就更简单了。

2.1 get() 源码基础实现

public T get() {
	//获取当前线程
	Thread t = Thread.currentThread();
	//从当前线程中获取到 ThreadLocalMap
	ThreadLocalMap map = getMap(t);
	if (map != null) {
		//从ThreadLocalMap中,根据key找出当前线程所对应的Entry
		//(具体实现方法介绍,参考下文2.2)
		ThreadLocalMap.Entry e = map.getEntry(this);
		if (e != null) {
			@SuppressWarnings("unchecked")
			//如果Entry不为空,直接返回value值
			T result = (T)e.value;
			return result;
		}
	}
	//否则,调用setInitialValue()方法,设置初始值并返回(在Entry[]数组上指定下标,设置值为null)
	return setInitialValue();
}

private T setInitialValue() {
	T value = initialValue();//此处initialValue()返回为 null,所以默认value为null
	Thread t = Thread.currentThread();
	ThreadLocalMap map = getMap(t);
	if (map != null)
		//set()/creatMap()方法,同之前介绍的一样,不再介绍
		map.set(this, value);
	else
		createMap(t, value);
	return value;
}

2.2 map.getEntry(this);方法实现

private Entry getEntry(ThreadLocal<?> key) {
	//获取Entry[]数组下标
	int i = key.threadLocalHashCode & (table.length - 1);
	//找到指定Entry
	Entry e = table[i];
	//Entry不为空,并且key==当前线程
	if (e != null && e.get() == key)
		//直接返回当前Entry
		return e;
	else
		//反之调用 getEntryAfterMiss()
		return getEntryAfterMiss(key, i, e);
}

private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
	Entry[] tab = table;
	int len = tab.length;
	//Entry不为空
	while (e != null) {
		ThreadLocal<?> k = e.get();
		//key==当前线程,返回当前 Entry
		if (k == key)
			return e;
		if (k == null)//key为空,重新rehash(此处不做分析)
			expungeStaleEntry(i);
		else//否则,下标+1,继续遍历查找
			i = nextIndex(i, len);
		e = tab[i];
	}
	return null;
}
get图解

    图片copy来源: https://www.cnblogs.com/cjsblog/p/9773079.html
ThreadLocal 相关,看这里就够了

get源码分析小结

1.获取当前线程,再获取当前线程的 ThreadLocalMap;
2.ThreadLocalMap 不为空,使用getEntry()方法获取 Entry,开始进行判断;
3.①Entry不为空并且key当前线程,直接返回Entry;    ②Entry不为空并且keynull,重新rehash,此处不作介绍     ③还找不到的话,使用nextIndex() 方法,下标+1,继续循环遍历查找。

3.remove()

    remove() 实现,也是比较简单的

3.1 remove() 源码基础实现

private void remove(ThreadLocal<?> key) {
	Entry[] tab = table;
	int len = tab.length;
	//获取下标
	int i = key.threadLocalHashCode & (len-1);
	//for循环遍历(如果key != 当前线程,使用nextIndex()方法,下标+1,继续遍历)
	for (Entry e = tab[i];
		 e != null;
		 e = tab[i = nextIndex(i, len)]) {
		//如果key==当前线程,直接删除
		if (e.get() == key) {
			e.clear();
			expungeStaleEntry(i);
			return;
		}
	}
}
remove源码分析小结

1.根据当前线程,使用公式计算,获得下标;
2.for循环遍历,判断当前线程==Entry中的key,等于的话,直接clear()删除;
3.反之,使用nextIndex()方法,下标+1,继续遍历查找删除。