实现原理

每个线程读写ThreadLocal是线程隔离的,互相之间不会影响。其原因就是在于Thread类有一个ThreadLocal.ThreadLocalMap类型的属性,也就是说每个线程有一个自己的ThreadLocalMap,读写某个ThreadLocal时都会获取当前线程以及当前线程的ThreadLocalMap属性,对其进行读写,以此实现线程隔离。以下是ThreadLocal的几个关键方法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
public void set(T value) {
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null)
map.set(this, value); // 将自己作为 key
else
createMap(t, value);
}

public T get() {
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null) {
ThreadLocalMap.Entry e = map.getEntry(this); // 将自己作为 key
if (e != null) {
@SuppressWarnings("unchecked")
T result = (T)e.value;
return result;
}
}
return setInitialValue();
}

可以看出,ThreadLocal本身并没有太多东西,它只是作为ThreadLocalMap的key,核心源码其实都在ThreadLocalMap中。

弱引用

在开始源码分析之前,需要先了解弱引用这个概念,因为在ThreadLocalMapThreadLocal并不是直接作为key的,而是使用的弱引用对象Entry。在Java中存在四种引用,分别是强引用、软引用、弱引用和虚引用,它们的区别如下:

  • 强引用:通常我们通过new来创建一个新对象时返回的引用就是一个强引用,若一个对象通过一系列强引用可到达,它就是强可达的,那么它就不被回收
  • 软引用:软引用和弱引用的区别在于,若一个对象是弱引用可达,无论当前内存是否充足它都会被回收,而软引用可达的对象在内存不充足时才会被回收,因此软引用要比弱引用“强”一些
  • 虚引用:虚引用是Java中最弱的引用,通过虚引用甚至无法获取到被引用的对象,虚引用存在的唯一作用就是当它指向的对象被回收后,虚引用本身会被加入到引用队列中,用作记录它指向的对象已被回收

之所以需要弱引用,是因为在类似HashMap的结构中,如果存放了一个key为Product对象且value为1的节点,此时我们有一个变量product指向了这个Product对象,当我们不再需要这个对象时,如果直接将product设为nullProduct对象其实并不会被回收,因为通过HashMap它还存在一条强引用链,如果我们想让它被垃圾收集器回收,就必须将其彻底从HashMap中移除,让它不再存在任何强引用。如果上述过程我们不想自己手动去实现,而是想告诉垃圾收集器在只有HashMap中的key引用着Product对象的情况下,就可以回收相应的Product对象了,那么就可以使用弱引用。

Java中的弱引用具体指的是java.lang.ref.WeakReference<T>类,我们使用一个指向Product对象的弱引用对象来作为HashMapkey,只需这样定义这个弱引用对象:

1
2
Product product = new Product(...);
WeakReference<Product> weakProduct = new WeakReference<>(product);

而如果要通过weakProduct获取它所指向的Product对象,我们只需要通过这行代码:Product product = weakProductA.get();即可。WeakReference的构造函数如下:

1
2
3
4
//创建一个指向给定对象的弱引用
WeakReference(T referent)
//创建一个指向给定对象并且登记到给定引用队列的弱引用
WeakReference(T referent, ReferenceQueue<? super T> q)

通过将原始对象包装成弱引用对象,当变量product设为null时,指向这个Product对象的就只剩弱引用对象weakProduct了,显然这时候相应的Product对象是弱可达的,所以指向它的弱引用会被清除,这个Product对象随即会被回收,指向它的弱引用对象会进入引用队列中,在引用队列中可以对这些被清除的弱引用对象进行统一管理。

源码分析

Entry节点

上面说过,ThreadLocalMap并不是简单的使用ThreadLocal作为key的,其实它内部存储着一个Entry节点数组,而Entry继承了弱引用类WeakReference

1
2
3
4
5
6
7
8
9
static class Entry extends WeakReference<ThreadLocal<?>> {
/** The value associated with this ThreadLocal. */
Object value;

Entry(ThreadLocal<?> k, Object v) {
super(k);
value = v;
}
}

当构造一个Entry节点时,会先调用父类WeakReference的构造函数将ThreadLocal传入,并设置了一个类型为Objectvalue,用于存放ThreadLocal对应的值。

这里之所以要使用弱引用Entry节点而不是简单的key-value形式的节点,是因为如果简单的使用key-value形式会造成节点的生命周期与线程强绑定,只要线程存在,那么作为属性的ThreadLocalMap也就存在,在不显式移除的情况下,key对象就依然被强引用着,没办法被回收。在这里通过使用弱引用节点,当我们将某个ThreadLocal对象的强引用设为null后,这个ThreadLocal对象就只剩下弱引用了,之后会被GC回收掉,有效的避免了内存泄漏的问题。

成员变量

1
2
3
4
5
6
7
8
9
10
11
  // 初始容量默认为16
private static final int INITIAL_CAPACITY = 16;

// Entry 数组,大小必须为2的幂
private Entry[] table;

// 数组中 Entry 的实际个数
private int size = 0;

// 扩容时的阈值
private int threshold;

ThreadLocalMapHashMap不同,它是使用的线性探测法而非拉链法解决碰撞冲突的,所以实际上Entry[]数组在逻辑上是作为一个环形存在的。

1
2
3
4
5
6
7
8
9
// 环形意义的下一个索引下标
private static int nextIndex(int i, int len) {
return ((i + 1 < len) ? i + 1 : 0);
}

// 环形意义的前一个索引下标
private static int prevIndex(int i, int len) {
return ((i - 1 >= 0) ? i - 1 : len - 1);
}

构造函数

1
2
3
4
5
6
7
8
9
10
11
12
ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
// 初始化 table 数组
table = new Entry[INITIAL_CAPACITY];
// 用位运算而非取模得到下标,这也是为什么容量需要为偶数的原因
int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
// 构造并设置该节点
table[i] = new Entry(firstKey, firstValue);
// 更新表(数组)的大小
size = 1;
// 设置扩容阈值
setThreshold(INITIAL_CAPACITY);
}

这个构造函数我们重点需要关注其中的threadLocalHashCode,这是传入的ThreadLocal对象的哈希值:

1
2
3
4
5
6
7
8
9
   private static AtomicInteger nextHashCode = new AtomicInteger();

private final int threadLocalHashCode = nextHashCode();

private static int nextHashCode() {
return nextHashCode.getAndAdd(HASH_INCREMENT);
}

private static final int HASH_INCREMENT = 0x61c88647;

这个哈希值在对象创建时就会生成,每次都会累加0x61c88647,通过这种方式使得与2的幂取模(实际是位运算)后均匀分布,也就提高了线性探测时的效率。

getEntry

getEntry()方法会被ThreadLocalget()方法直接调用,上面也说过,get()方法内部就是先拿到当前线程的ThreadLocalMap,然后将自己this作为参数调用其getEntry()方法。这里要提前说明一点的是,每个索引(slot)上的状态有三种:有效(ThreadLocal未回收),失效(ThreadLocal已回收),空(null):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
private Entry getEntry(ThreadLocal<?> key) {
// 获取这个 key 的索引下标
int i = key.threadLocalHashCode & (table.length - 1);
Entry e = table[i];
// 对应的 entry 不为空且未失效,且弱引用指向的 ThreadLocal 就是传入的 key,则命中返回
if (e != null && e.get() == key)
return e;
else
// 因为用的是线性探测,所以往后还是有可能找到目标 Entry 的
return getEntryAfterMiss(key, i, e);
}

/**
* 调用 getEntry() 未直接命中时调用此方法
*/
private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
Entry[] tab = table;
int len = tab.length;

// 基于线性探测法不断向后探测直到遇到 null
while (e != null) {
ThreadLocal<?> k = e.get();
if (k == key)
return e;
if (k == null) // 如果该 entry 对应的 ThreadLocal 已经被回收(失效),调用 expungeStaleEntry() 来清理无效的 entry
expungeStaleEntry(i);
else // 如果该 entry 对应的 ThreadLocal 未被回收,但与传入的 key 不等,则继续向后探测
i = nextIndex(i, len); // 环形意义下往后面走,线性探测
e = tab[i];
}
// 没找到指定的 key
return null;
}

/**
* ThreadLocal 的核心清理函数,从 staleSlot 下标开始遍历,将无效的的 entry 清理,
* 即将 entry 中的 value 置为 null,指向这个 entry 的 table[i] 置为 null,直到遍历到空 entry。
* 另外,在这个过程中还会对非空的 entry 作 rehash。
*/
private int expungeStaleEntry(int staleSlot) {
Entry[] tab = table;
int len = tab.length;

// 因为 entry 对应的 ThreadLocal 已经被回收,此时为了垃圾回收:
// 将 entry 中的 value 置为 null,显示断开强引用
tab[staleSlot].value = null;
// 将指向这个 entry 的 table[i] 置为 null
tab[staleSlot] = null;
// 将实际 entry 数减一
size--;

Entry e;
int i;
// 从 staleSlot 的下一个索引开始,不断向后遍历,直到遇到 null
for (i = nextIndex(staleSlot, len); (e = tab[i]) != null; i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();
// 如果当前 entry 中的 ThreadLocal 已经被回收,则做一次清理
if (k == null) {
e.value = null;
tab[i] = null;
size--;
} else { // 如果 entry 对应的 ThreadLocal 还没被回收,需要做一次 rehash
// 如果 ThreadLocal 计算出的 hash 对应的索引h与当前位置不同,
// 则从 h 开始向后线性探测直到第一个空的 slot,把当前的 entry 给挪过去
int h = k.threadLocalHashCode & (len - 1);
if (h != i) {
tab[i] = null; // 先将当前索引置 null

while (tab[h] != null) // 遍历找到从索引h开始的第一个空 slot
h = nextIndex(h, len);
tab[h] = e; // 将 entry 挪到这个空 slot 上
}
}
}
// 返回 staleSlot 之后第一个 entry 为 null 的索引下标
return i;
}

总的来说,getEntry()会经历以下几步:

  1. 根据传入的ThreadLocal的哈希值定位到某个索引下标
  2. 如果该下标对应的entry存在,且其中的ThreadLocal和方法传入的ThreadLocal相同,则直接命中返回
  3. 否则,调用getEntryAfterMiss()进行线性探测,过程中每次碰到失效的 slot,就调用expungeStaleEntry进行段清理(清理并rehash,直到遇到null)
  4. 遍历直到 null 都未命中 key,直接返回 null

set

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
private void set(ThreadLocal<?> key, Object value) {
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1); // 获取该键对应的索引下标

// 线性探测
for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) {
ThreadLocal<?> k = e.get();
// 找到 key 相同的 entry,覆盖 value
if (k == key) {
e.value = value;
return;
}

// 如果当前 entry 失效,则替换失效的 entry
if (k == null) {
replaceStaleEntry(key, value, i);
return;
}
}
// 线性探测过程中没有遇到 key 相同的 entry,也没遇到失效的 entry,当遇到 null 时跳出循环
// 在 null 的位置上建立新的 entry
tab[i] = new Entry(key, value);
int sz = ++size;
if (!cleanSomeSlots(i, sz) && sz >= threshold)
rehash();
}

/**
* 替换失效的 entry
*/
private void replaceStaleEntry(ThreadLocal<?> key, Object value, int staleSlot) {
Entry[] tab = table;
int len = tab.length;
Entry e;

// 从 staleSlot 向前遍历,查找最前的一个无效的 slot
int slotToExpunge = staleSlot;
for (int i = prevIndex(staleSlot, len); (e = tab[i]) != null; i = prevIndex(i, len))
if (e.get() == null)
slotToExpunge = i;

// 从 staleSlot 向后遍历,看能不能找到相同的 key,如果找到了则和无效的 staleSlot 交换
for (int i = nextIndex(staleSlot, len); (e = tab[i]) != null; i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();

// 找到了 key,将其与无效的slot交换
if (k == key) {
e.value = value;

tab[i] = tab[staleSlot];
tab[staleSlot] = e;

if (slotToExpunge == staleSlot)
slotToExpunge = i;
// 从 slotToExpunge 开始做一次连续段的清理,再做一次启发式清理
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
return;
}

// 如果当前的 slot 已经无效,并且向前扫描过程中没有无效 slot,则更新 slotToExpunge 为当前位置
if (k == null && slotToExpunge == staleSlot)
slotToExpunge = i;
}

// 如果找不到相同的 key,则直接设置在失效的 staleSlot 下标上
tab[staleSlot].value = null;
tab[staleSlot] = new Entry(key, value);

// 在探测过程中如果发现任何无效 slot,则做一次清理(连续段清理+启发式清理)
if (slotToExpunge != staleSlot)
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}

/**
* 做一次全量清理,并且调低阈值决定是否扩容
*/
private void rehash() {
// 做一次全量清理
expungeStaleEntries();
// 因为做了一次清理,所以 size 很可能会变小
// 这里是调低阈值判断是否需要扩容,下面一行相当于 if(size >= len / 2)
if (size >= threshold - threshold / 4)
resize();
}

/**
* 全量清理
*/
private void expungeStaleEntries() {
Entry[] tab = table;
int len = tab.length;
for (int j = 0; j < len; j++) {
Entry e = tab[j];
if (e != null && e.get() == null)
// 这里其实可以将 j 设为返回值,j 之前的 entry 其实已经被清理过了,肯定为 null
expungeStaleEntry(j);
}
}

/**
* 扩容为原来的两倍
*/
private void resize() {
Entry[] oldTab = table;
int oldLen = oldTab.length;
int newLen = oldLen * 2;
Entry[] newTab = new Entry[newLen];
int count = 0;

for (int j = 0; j < oldLen; ++j) {
Entry e = oldTab[j];
if (e != null) {
ThreadLocal<?> k = e.get();
if (k == null) {
e.value = null; // Help the GC
} else {
int h = k.threadLocalHashCode & (newLen - 1); // 计算新容量时哈希值对应的索引下标
while (newTab[h] != null) // 线性探测解决碰撞冲突
h = nextIndex(h, newLen);
newTab[h] = e;
count++;
}
}
}

// 设置新阈值
setThreshold(newLen);
size = count;
table = newTab;
}

set()方法总体过程如下:

  1. 在遍历(也就是线性探测)遇到null之前,如果遇到了相同的key,则直接覆盖;如果遇到了失效的entry,则调用replaceStaleEntry,效果是最终一定会把key和value放在这个slot上,并且会尽可能地清理无效entry
  2. 遍历过程既没遇到相同的key,也没遇到失效的entry,也就是当前索引上为null,则直接将key和value插在这个空slot上
  3. 如果插入后的size大于阈值,那么做一次全量清理,再根据调低的阈值决定是否需要扩容,扩容两倍(因为容量必须为2的幂)

remove

remove()方法相对比较简单,只需要找到对应的key,然后将弱引用显式的断开,并做一次段清理即可。

1
2
3
4
5
6
7
8
9
10
11
12
private void remove(ThreadLocal<?> key) {
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);
for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) {
if (e.get() == key) {
e.clear(); // 显式断开弱引用
expungeStaleEntry(i); // 进行段清理
return;
}
}
}

这里光做e.clear();其实是不够的,因为value此时还被强引用着,所以才需要进行段清理,将table[i] = null;彻底断开强引用。

内存泄露

经过上面的分析我们已经清楚在每个Thread中有一个ThreadLocalMap,每个线程在对某个ThreadLocal对象操作时都会先获取当前线程的ThreadLocalMap,然后对ThreadLocalMap进行操作,并且,ThreadLocal不是简单的作为key的,而是将key和value包装成继承自弱引用WeakReferenceEntry类。但这里要注意的是,弱引用只是针对key(Entry中的ThreadLocal),当没有任何强引用指向ThreadLocal的时候,它就只剩下弱引用了,GC时将会被回收,但是value却不会被回收,因为它存在一条当前Thread->ThreadLocalMap->Entry数组->Entry->value的强引用,所以除非线程销毁,否则它将与线程的生命周期绑定,尤其是在有线程复用比如线程池的场景中,一个线程的寿命很长,大对象长期不被回收会影响系统运行效率与安全,也就造成了人们常说的内存泄露。

但是在源码中我们也会发现,ThreadLocalMap实现中是有一套自我清理的机制的,当我们调用get()或者set()方法时会有很高的概率顺便清理掉失效的Entry,防止出现内存泄露。当然,显示地进行remove()是个良好的编程习惯,它可以确保不会发生内存泄露。

参考资料