-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathTrunkClassLoader.java
More file actions
162 lines (135 loc) · 4.65 KB
/
TrunkClassLoader.java
File metadata and controls
162 lines (135 loc) · 4.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
package xyz.spruceloader.trunk;
import org.apache.commons.io.IOUtils;
import xyz.spruceloader.trunk.api.*;
import java.io.*;
import java.net.*;
import java.nio.file.Path;
import java.util.*;
import java.util.function.Predicate;
/**
 * A child-first {@link URLClassLoader} that defines classes itself, passing their bytes through
 * the transformers obtained from {@link Trunk#getTransformerManager()} before definition.
 *
 * <p>Loading rules, applied in order by {@link #loadClass(String, boolean)}:
 * <ol>
 *   <li>A class already known to this loader is returned as-is.</li>
 *   <li>A class name matched by a loading filter is delegated to the {@code fallback} loader.</li>
 *   <li>Otherwise the class bytes are looked up via {@link #getResourceAsStream(String)},
 *       transformed (unless matched by a transformation filter), and defined by this loader.</li>
 *   <li>If no bytes are available, the class is delegated to the {@code fallback} loader.</li>
 * </ol>
 *
 * <p>The loader is created with a {@code null} parent, so normal parent-first delegation does not
 * apply. Callers will usually want {@link #addDefaultLoadingFilters()} so that packages such as
 * {@code java.*} are served by the fallback loader.
 */
public class TrunkClassLoader extends URLClassLoader {
    private final Trunk trunk;
    /** Predicates on binary class names; a match means "load through the fallback loader". */
    private final List<Predicate<String>> filters = new ArrayList<>();
    /** Predicates on binary class names; a match means "define without running transformers". */
    private final List<Predicate<String>> transformerFilters = new ArrayList<>();
    private final ClassLoader fallback;

    /**
     * Creates a loader over the given URLs.
     *
     * @param trunk    source of the transformers applied to class bytes.
     * @param urls     the initial search path for classes and resources.
     * @param fallback loader used for filtered classes and for classes whose bytes are not found
     *                 by this loader; must not be {@code null}.
     * @throws NullPointerException if {@code fallback} is {@code null}.
     */
    public TrunkClassLoader(Trunk trunk, URL[] urls, ClassLoader fallback) {
        super(urls, null);
        this.trunk = trunk;
        // Every filtered class (including java.* with the default filters) goes through the
        // fallback, so a null fallback could never work; fail at construction instead of later.
        this.fallback = Objects.requireNonNull(fallback, "fallback");
    }

    /**
     * Appends a path to this loader's search path.
     *
     * @param path the path to add.
     * @throws IllegalArgumentException if the path cannot be converted to a URL.
     */
    public void addPath(Path path) {
        try {
            addURL(path.toUri().toURL());
        } catch (MalformedURLException e) {
            throw new IllegalArgumentException(path.toString(), e);
        }
    }

    /**
     * Adds the default loading filters. This is subject to change!
     *
     * <p>Currently these route the {@code java}, {@code jdk}, {@code javax}, {@code sun},
     * {@code com.sun}, {@code org.apache.logging.log4j} and {@code org.slf4j} packages (and their
     * sub-packages) to the fallback loader.
     */
    public void addDefaultLoadingFilters() {
        addPackageLoadingFilter("java");
        addPackageLoadingFilter("jdk");
        addPackageLoadingFilter("javax");
        addPackageLoadingFilter("sun");
        addPackageLoadingFilter("com.sun");
        addPackageLoadingFilter("org.apache.logging.log4j");
        addPackageLoadingFilter("org.slf4j");
    }

    /**
     * Adds a loading filter. Class names for which the filter returns {@code true} are not defined
     * by this loader but are loaded through the fallback loader.
     *
     * @param filter the filter, tested against binary class names.
     * @throws NullPointerException if {@code filter} is {@code null}.
     */
    public void addLoadingFilter(Predicate<String> filter) {
        filters.add(Objects.requireNonNull(filter, "filter"));
    }

    /**
     * Routes every class in a package (and its sub-packages) to the fallback loader.
     *
     * @param packageName the package name, e.g. {@code "org.slf4j"}.
     * @throws NullPointerException if {@code packageName} is {@code null}.
     */
    public void addPackageLoadingFilter(String packageName) {
        filters.add(packagePredicate(packageName));
    }

    /**
     * Routes a single class to the fallback loader.
     *
     * @param className the binary class name.
     * @throws NullPointerException if {@code className} is {@code null}.
     */
    public void addClassLoadingFilter(String className) {
        filters.add(Objects.requireNonNull(className, "className")::equals);
    }

    /**
     * Adds a transformation filter. Class names for which the filter returns {@code true} are
     * still defined by this loader, but with their original, untransformed bytes.
     *
     * @param filter the filter, tested against binary class names.
     * @throws NullPointerException if {@code filter} is {@code null}.
     */
    public void addTransformationFilter(Predicate<String> filter) {
        transformerFilters.add(Objects.requireNonNull(filter, "filter"));
    }

    /**
     * Excludes every class in a package (and its sub-packages) from transformation.
     *
     * @param packageName the package name.
     * @throws NullPointerException if {@code packageName} is {@code null}.
     */
    public void addPackageTransformationFilter(String packageName) {
        transformerFilters.add(packagePredicate(packageName));
    }

    /**
     * Excludes a single class from transformation.
     *
     * @param className the binary class name.
     * @throws NullPointerException if {@code className} is {@code null}.
     */
    public void addClassTransformationFilter(String className) {
        transformerFilters.add(Objects.requireNonNull(className, "className")::equals);
    }

    @Override
    public Class<?> loadClass(String name) throws ClassNotFoundException {
        return loadClass(name, false);
    }

    /**
     * Loads a class using this loader's child-first rules (see the class documentation).
     *
     * <p>Overriding this two-argument form, rather than only {@link #loadClass(String)}, ensures
     * that the same rules apply when a child loader delegates here via
     * {@code parent.loadClass(name, false)}.
     *
     * @param name    the binary class name.
     * @param resolve whether to link the resulting class.
     * @return the loaded class.
     * @throws ClassNotFoundException if transforming/defining the class fails, or if the fallback
     *                                loader cannot find it.
     */
    @Override
    protected Class<?> loadClass(String name, boolean resolve) throws ClassNotFoundException {
        synchronized (getClassLoadingLock(name)) {
            Class<?> result = findLoadedClass(name);
            if (result == null) {
                if (!filter(filters, name))
                    result = defineTransformedClass(name);
                // Filtered, or no bytes available on this loader: delegate.
                if (result == null)
                    result = fallback.loadClass(name);
            }
            if (resolve)
                resolveClass(result);
            return result;
        }
    }

    /**
     * Finds, transforms and defines the named class on this loader (no filtering, no fallback).
     *
     * @param name the binary class name.
     * @return the defined class.
     * @throws ClassNotFoundException if no bytes are available, or transformation/definition fails.
     */
    @Override
    protected Class<?> findClass(String name) throws ClassNotFoundException {
        Class<?> result = defineTransformedClass(name);
        // Thrown outside the try block in defineTransformedClass so it is not wrapped twice.
        if (result == null)
            throw new ClassNotFoundException(name);
        return result;
    }

    /**
     * Reads, transforms and defines the named class.
     *
     * @param name the binary class name.
     * @return the defined class, or {@code null} if, after transformation, there are no bytes.
     * @throws ClassNotFoundException wrapping any failure while reading, transforming or defining.
     */
    private Class<?> defineTransformedClass(String name) throws ClassNotFoundException {
        try {
            byte[] data = transformClassBytes(name);
            if (data == null)
                return null;
            return defineClass(name, data, 0, data.length);
        } catch (VirtualMachineError e) {
            // Out-of-memory / stack-overflow are not "class not found"; let them propagate as-is.
            throw e;
        } catch (Throwable e) {
            throw new ClassNotFoundException(name, e);
        }
    }

    /** Reads the raw bytes of {@code name} and runs them through the transformers. */
    private byte[] transformClassBytes(String name) throws IOException {
        return transformClassBytes(name, getClassBytes(name));
    }

    /**
     * Runs {@code bytes} through every transformer in order, unless {@code name} matches a
     * transformation filter. Note that {@code bytes} may be {@code null} (resource not found);
     * transformers still receive it in that case.
     */
    private byte[] transformClassBytes(String name, byte[] bytes) {
        if (filter(transformerFilters, name))
            return bytes;
        for (Transformer transformer : trunk.getTransformerManager())
            bytes = transformer.transform(name, bytes);
        return bytes;
    }

    /**
     * Reads the {@code .class} resource for a binary class name.
     *
     * @return the bytes, or {@code null} if the resource is not found.
     */
    private byte[] getClassBytes(String name) throws IOException {
        try (InputStream in = getResourceAsStream(name.replace('.', '/').concat(".class"))) {
            if (in == null)
                return null;
            return IOUtils.toByteArray(in);
        }
    }

    /** Returns {@code true} if any of {@code predicates} matches {@code className}. */
    private static boolean filter(List<Predicate<String>> predicates, String className) {
        return predicates.stream().anyMatch((filter) -> filter.test(className));
    }

    /** Builds a predicate matching every class in {@code packageName} or its sub-packages. */
    private static Predicate<String> packagePredicate(String packageName) {
        String suffixed = Objects.requireNonNull(packageName, "packageName") + '.';
        return name -> name.startsWith(suffixed);
    }
}