ソースを参照

add setClassLoader

Looly 6 年 前
コミット
476285b302

+ 1 - 0
CHANGELOG.md

@@ -8,6 +8,7 @@
 ### 新特性
 * 【all】        修复注释中的错别字(issue#I12XE6@Gitee)
 * 【core】       CsvWriter支持其它类型的参数(issue#I12XE3@Gitee)
+* 【core】       ClassScaner支持自定义ClassLoader
 
 ### Bug修复
 * 【all】        修复阶乘计算错误bug(issue#I12XE4@Gitee)

+ 91 - 65
hutool-core/src/main/java/cn/hutool/core/lang/ClassScaner.java

@@ -1,5 +1,12 @@
 package cn.hutool.core.lang;
 
+import cn.hutool.core.collection.CollUtil;
+import cn.hutool.core.collection.EnumerationIter;
+import cn.hutool.core.io.FileUtil;
+import cn.hutool.core.io.IORuntimeException;
+import cn.hutool.core.io.resource.ResourceUtil;
+import cn.hutool.core.util.*;
+
 import java.io.File;
 import java.io.IOException;
 import java.io.Serializable;
@@ -12,48 +19,54 @@ import java.util.Set;
 import java.util.jar.JarEntry;
 import java.util.jar.JarFile;
 
-import cn.hutool.core.collection.CollUtil;
-import cn.hutool.core.collection.EnumerationIter;
-import cn.hutool.core.io.FileUtil;
-import cn.hutool.core.io.IORuntimeException;
-import cn.hutool.core.io.resource.ResourceUtil;
-import cn.hutool.core.util.CharUtil;
-import cn.hutool.core.util.CharsetUtil;
-import cn.hutool.core.util.ClassUtil;
-import cn.hutool.core.util.StrUtil;
-import cn.hutool.core.util.URLUtil;
-
 /**
  * 类扫描器
- * 
+ *
  * @author looly
  * @since 4.1.5
- *
  */
-public class ClassScaner implements Serializable{
+public class ClassScaner implements Serializable {
 	private static final long serialVersionUID = 1L;
 
-	/** 包名 */
+	/**
+	 * 包名
+	 */
 	private String packageName;
-	/** 包名,最后跟一个点,表示包名,避免在检查前缀时的歧义 */
+	/**
+	 * 包名,最后跟一个点,表示包名,避免在检查前缀时的歧义
+	 */
 	private String packageNameWithDot;
-	/** 包路径,用于文件中对路径操作 */
+	/**
+	 * 包路径,用于文件中对路径操作
+	 */
 	private String packageDirName;
-	/** 包路径,用于jar中对路径操作,在Linux下与packageDirName一致 */
+	/**
+	 * 包路径,用于jar中对路径操作,在Linux下与packageDirName一致
+	 */
 	private String packagePath;
-	/** 过滤器 */
+	/**
+	 * 过滤器
+	 */
 	private Filter<Class<?>> classFilter;
-	/** 编码 */
+	/**
+	 * 编码
+	 */
 	private Charset charset;
-	/** 是否初始化类 */
+	/**
+	 * 类加载器
+	 */
+	private ClassLoader classLoader;
+	/**
+	 * 是否初始化类
+	 */
 	private boolean initialize;
 
-	private Set<Class<?>> classes = new HashSet<Class<?>>();
-	
+	private Set<Class<?>> classes = new HashSet<>();
+
 	/**
 	 * 扫描指定包路径下所有包含指定注解的类
-	 * 
-	 * @param packageName 包路径
+	 *
+	 * @param packageName     包路径
 	 * @param annotationClass 注解类
 	 * @return 类集合
 	 */
@@ -68,9 +81,9 @@ public class ClassScaner implements Serializable{
 
 	/**
 	 * 扫描指定包路径下所有指定类或接口的子类或实现类
-	 * 
+	 *
 	 * @param packageName 包路径
-	 * @param superClass 父类或接口
+	 * @param superClass  父类或接口
 	 * @return 类集合
 	 */
 	public static Set<Class<?>> scanPackageBySuper(String packageName, final Class<?> superClass) {
@@ -84,7 +97,7 @@ public class ClassScaner implements Serializable{
 
 	/**
 	 * 扫描该包路径下所有class文件
-	 * 
+	 *
 	 * @return 类集合
 	 */
 	public static Set<Class<?>> scanPackage() {
@@ -93,7 +106,7 @@ public class ClassScaner implements Serializable{
 
 	/**
 	 * 扫描该包路径下所有class文件
-	 * 
+	 *
 	 * @param packageName 包路径 com | com. | com.abs | com.abs.
 	 * @return 类集合
 	 */
@@ -105,7 +118,7 @@ public class ClassScaner implements Serializable{
 	 * 扫描包路径下满足class过滤器条件的所有class文件,<br>
 	 * 如果包路径为 com.abs + A.class 但是输入 abs会产生classNotFoundException<br>
 	 * 因为className 应该为 com.abs.A 现在却成为abs.A,此工具类对该异常进行忽略处理<br>
-	 * 
+	 *
 	 * @param packageName 包路径 com | com. | com.abs | com.abs.
 	 * @param classFilter class过滤器,过滤掉不需要的class
 	 * @return 类集合
@@ -123,7 +136,7 @@ public class ClassScaner implements Serializable{
 
 	/**
 	 * 构造,默认UTF-8编码
-	 * 
+	 *
 	 * @param packageName 包名,所有包传入""或者null
 	 */
 	public ClassScaner(String packageName) {
@@ -132,7 +145,7 @@ public class ClassScaner implements Serializable{
 
 	/**
 	 * 构造,默认UTF-8编码
-	 * 
+	 *
 	 * @param packageName 包名,所有包传入""或者null
 	 * @param classFilter 过滤器,无需传入null
 	 */
@@ -142,10 +155,10 @@ public class ClassScaner implements Serializable{
 
 	/**
 	 * 构造
-	 * 
+	 *
 	 * @param packageName 包名,所有包传入""或者null
 	 * @param classFilter 过滤器,无需传入null
-	 * @param charset 编码
+	 * @param charset     编码
 	 */
 	public ClassScaner(String packageName, Filter<Class<?>> classFilter, Charset charset) {
 		packageName = StrUtil.nullToEmpty(packageName);
@@ -159,42 +172,51 @@ public class ClassScaner implements Serializable{
 
 	/**
 	 * 扫描包路径下满足class过滤器条件的所有class文件
-	 * 
+	 *
 	 * @return 类集合
 	 */
 	public Set<Class<?>> scan() {
 		for (URL url : ResourceUtil.getResourceIter(this.packagePath)) {
 			switch (url.getProtocol()) {
-			case "file":
-				scanFile(new File(URLUtil.decode(url.getFile(), this.charset.name())), null);
-				break;
-			case "jar":
-				scanJar(URLUtil.getJarFile(url));
-				break;
+				case "file":
+					scanFile(new File(URLUtil.decode(url.getFile(), this.charset.name())), null);
+					break;
+				case "jar":
+					scanJar(URLUtil.getJarFile(url));
+					break;
 			}
 		}
-		
-		if(CollUtil.isEmpty(this.classes)) {
+
+		if (CollUtil.isEmpty(this.classes)) {
 			scanJavaClassPaths();
 		}
-		
+
 		return Collections.unmodifiableSet(this.classes);
 	}
 
 	/**
 	 * 设置是否在扫描到类时初始化类
-	 * 
+	 *
 	 * @param initialize 是否初始化类
 	 */
 	public void setInitialize(boolean initialize) {
 		this.initialize = initialize;
 	}
 
+	/**
+	 * 设置自定义的类加载器
+	 *
+	 * @param classLoader 类加载器
+	 * @since 4.6.9
+	 */
+	public void setClassLoader(ClassLoader classLoader) {
+		this.classLoader = classLoader;
+	}
+
 	// --------------------------------------------------------------------------------------------------- Private method start
+
 	/**
 	 * 扫描Java指定的ClassPath路径
-	 * 
-	 * @return 扫描到的类
 	 */
 	private void scanJavaClassPaths() {
 		final String[] javaClassPaths = ClassUtil.getJavaClassPaths();
@@ -205,11 +227,11 @@ public class ClassScaner implements Serializable{
 			scanFile(new File(classPath), null);
 		}
 	}
-	
+
 	/**
 	 * 扫描文件或目录中的类
-	 * 
-	 * @param file 文件或目录
+	 *
+	 * @param file    文件或目录
 	 * @param rootDir 包名对应classpath绝对路径
 	 */
 	private void scanFile(File file, String rootDir) {
@@ -238,7 +260,7 @@ public class ClassScaner implements Serializable{
 
 	/**
 	 * 扫描jar包
-	 * 
+	 *
 	 * @param jar jar包
 	 */
 	private void scanJar(JarFile jar) {
@@ -258,14 +280,20 @@ public class ClassScaner implements Serializable{
 
 	/**
 	 * 加载类
-	 * 
+	 *
 	 * @param className 类名
 	 * @return 加载的类
 	 */
 	private Class<?> loadClass(String className) {
+		ClassLoader loader = this.classLoader;
+		if (null == loader) {
+			loader = ClassLoaderUtil.getClassLoader();
+			this.classLoader = loader;
+		}
+
 		Class<?> clazz = null;
 		try {
-			clazz = Class.forName(className, this.initialize, ClassUtil.getClassLoader());
+			clazz = Class.forName(className, this.initialize, loader);
 		} catch (NoClassDefFoundError e) {
 			// 由于依赖库导致的类无法加载,直接跳过此类
 		} catch (UnsupportedClassVersionError e) {
@@ -276,27 +304,26 @@ public class ClassScaner implements Serializable{
 		}
 		return clazz;
 	}
-	
+
 	/**
 	 * 通过过滤器,是否满足接受此类的条件
-	 * 
-	 * @param clazz 类
-	 * @return 是否接受
+	 *
+	 * @param className 类名
 	 */
 	private void addIfAccept(String className) {
-		if(StrUtil.isBlank(className)) {
+		if (StrUtil.isBlank(className)) {
 			return;
 		}
 		int classLen = className.length();
 		int packageLen = this.packageName.length();
-		if(classLen == packageLen) {
+		if (classLen == packageLen) {
 			//类名和包名长度一致,用户可能传入的包名是类名
-			if(className.equals(this.packageName)) {
+			if (className.equals(this.packageName)) {
 				addIfAccept(loadClass(className));
 			}
-		} else if(classLen > packageLen){
+		} else if (classLen > packageLen) {
 			//检查类名是否以指定包名为前缀,包名后加.(避免类似于cn.hutool.A和cn.hutool.ATest这类类名引起的歧义)
-			if(className.startsWith(this.packageNameWithDot)) {
+			if (className.startsWith(this.packageNameWithDot)) {
 				addIfAccept(loadClass(className));
 			}
 		}
@@ -304,9 +331,8 @@ public class ClassScaner implements Serializable{
 
 	/**
 	 * 通过过滤器,是否满足接受此类的条件
-	 * 
+	 *
 	 * @param clazz 类
-	 * @return 是否接受
 	 */
 	private void addIfAccept(Class<?> clazz) {
 		if (null != clazz) {
@@ -319,7 +345,7 @@ public class ClassScaner implements Serializable{
 
 	/**
 	 * 截取文件绝对路径中包名之前的部分
-	 * 
+	 *
 	 * @param file 文件
 	 * @return 包名之前的部分
 	 */

+ 1 - 1
hutool-core/src/test/java/cn/hutool/core/lang/ClassScanerTest.java

@@ -10,7 +10,7 @@ public class ClassScanerTest {
 	@Test
 	@Ignore
 	public void scanTest() {
-		ClassScaner scaner = new ClassScaner("cn.hutool.core.util.StrUtil", null);
+		ClassScaner scaner = new ClassScaner("cn.hutool.core.util", null);
 		Set<Class<?>> set = scaner.scan();
 		for (Class<?> clazz : set) {
 			Console.log(clazz.getName());