Browse Source

add method and test

Looly 5 years ago
parent
commit
e2428714a0

+ 3 - 0
CHANGELOG.md

@@ -10,7 +10,10 @@
 * 【extra  】     增加Sftp.lsEntries方法,Ftp和Sftp增加recursiveDownloadFolder(pr#121@Gitee)
 * 【system 】     OshiUtil增加getNetworkIFs方法
 * 【core   】     CollUtil增加unionDistinct、unionAll方法(pr#122@Gitee)
+* 【core   】     增加IoUtil.readObj重载,通过ValidateObjectInputStream由用户自定义安全检查。
+
 ### Bug修复
+* 【core   】     修复IoUtil.readObj中反序列化安全检查导致的一些问题,去掉安全检查。
 
 -------------------------------------------------------------------------------------------------------------
 

+ 1 - 1
hutool-core/src/main/java/cn/hutool/core/collection/CollUtil.java

@@ -293,7 +293,7 @@ public class CollUtil {
 			return coll1;
 		}
 
-		final ArrayList<T> result = new ArrayList<>();
+		final List<T> result = new ArrayList<>();
 		final Map<T, Integer> map1 = countMap(coll1);
 		final Map<T, Integer> map2 = countMap(coll2);
 		final Set<T> elts = newHashSet(coll2);

+ 36 - 7
hutool-core/src/main/java/cn/hutool/core/io/IoUtil.java

@@ -20,7 +20,6 @@ import java.io.Flushable;
 import java.io.IOException;
 import java.io.InputStream;
 import java.io.InputStreamReader;
-import java.io.ObjectInputStream;
 import java.io.ObjectOutputStream;
 import java.io.OutputStream;
 import java.io.OutputStreamWriter;
@@ -648,14 +647,16 @@ public class IoUtil {
 	/**
 	 * 从流中读取对象,即对象的反序列化
 	 *
+	 * <p>
+	 * 注意!!! 此方法不会检查反序列化安全,可能存在反序列化漏洞风险!!!
+	 * </p>
+	 *
 	 * @param <T> 读取对象的类型
 	 * @param in  输入流
 	 * @return 输出流
 	 * @throws IORuntimeException IO异常
 	 * @throws UtilException      ClassNotFoundException包装
-	 * @deprecated 由于存在对象反序列化漏洞风险,请使用{@link #readObj(InputStream, Class)}
 	 */
-	@Deprecated
 	public static <T> T readObj(InputStream in) throws IORuntimeException, UtilException {
 		return readObj(in, null);
 	}
@@ -663,6 +664,10 @@ public class IoUtil {
 	/**
 	 * 从流中读取对象,即对象的反序列化,读取后不关闭流
 	 *
+	 * <p>
+	 * 注意!!! 此方法不会检查反序列化安全,可能存在反序列化漏洞风险!!!
+	 * </p>
+	 *
 	 * @param <T>   读取对象的类型
 	 * @param in    输入流
 	 * @param clazz 读取对象类型
@@ -671,14 +676,38 @@ public class IoUtil {
 	 * @throws UtilException      ClassNotFoundException包装
 	 */
 	public static <T> T readObj(InputStream in, Class<T> clazz) throws IORuntimeException, UtilException {
+		try {
+			return readObj((in instanceof ValidateObjectInputStream) ?
+							(ValidateObjectInputStream) in : new ValidateObjectInputStream(in),
+					clazz);
+		} catch (IOException e) {
+			throw new IORuntimeException(e);
+		}
+	}
+
+	/**
+	 * 从流中读取对象,即对象的反序列化,读取后不关闭流
+	 *
+	 * <p>
+	 * 此方法使用了{@link ValidateObjectInputStream}中的黑白名单方式过滤类,用于避免反序列化漏洞<br>
+	 * 通过构造{@link ValidateObjectInputStream},调用{@link ValidateObjectInputStream#accept(Class[])}
+	 * 或者{@link ValidateObjectInputStream#refuse(Class[])}方法添加可以被序列化的类或者禁止序列化的类。
+	 * </p>
+	 *
+	 * @param <T>   读取对象的类型
+	 * @param in    输入流,使用{@link ValidateObjectInputStream}中的黑白名单方式过滤类,用于避免反序列化漏洞
+	 * @param clazz 读取对象类型
+	 * @return 输出流
+	 * @throws IORuntimeException IO异常
+	 * @throws UtilException      ClassNotFoundException包装
+	 */
+	public static <T> T readObj(ValidateObjectInputStream in, Class<T> clazz) throws IORuntimeException, UtilException {
 		if (in == null) {
 			throw new IllegalArgumentException("The InputStream must not be null");
 		}
-		ObjectInputStream ois;
 		try {
-			ois = new ValidateObjectInputStream(in, clazz);
 			//noinspection unchecked
-			return (T) ois.readObject();
+			return (T) in.readObject();
 		} catch (IOException e) {
 			throw new IORuntimeException(e);
 		} catch (ClassNotFoundException e) {
@@ -989,7 +1018,7 @@ public class IoUtil {
 	 *
 	 * @param out        输出流
 	 * @param isCloseOut 写入完毕是否关闭输出流
-	 * @param obj   写入的对象内容
+	 * @param obj        写入的对象内容
 	 * @throws IORuntimeException IO异常
 	 * @since 5.3.3
 	 */

+ 60 - 12
hutool-core/src/main/java/cn/hutool/core/io/ValidateObjectInputStream.java

@@ -1,10 +1,14 @@
 package cn.hutool.core.io;
 
+import cn.hutool.core.collection.CollUtil;
+
 import java.io.IOException;
 import java.io.InputStream;
 import java.io.InvalidClassException;
 import java.io.ObjectInputStream;
 import java.io.ObjectStreamClass;
+import java.util.HashSet;
+import java.util.Set;
 
 /**
  * 带有类验证的对象流,用于避免反序列化漏洞<br>
@@ -15,27 +19,48 @@ import java.io.ObjectStreamClass;
  */
 public class ValidateObjectInputStream extends ObjectInputStream {
 
-	private Class<?> acceptClass;
+	private Set<String> whiteClassSet;
+	private Set<String> blackClassSet;
 
 	/**
 	 * 构造
 	 *
 	 * @param inputStream 流
-	 * @param acceptClass 接受的类
+	 * @param acceptClasses 白名单的类
 	 * @throws IOException IO异常
 	 */
-	public ValidateObjectInputStream(InputStream inputStream, Class<?> acceptClass) throws IOException {
+	public ValidateObjectInputStream(InputStream inputStream, Class<?>... acceptClasses) throws IOException {
 		super(inputStream);
-		this.acceptClass = acceptClass;
+		accept(acceptClasses);
+	}
+
+	/**
+	 * 禁止反序列化的类,用于反序列化验证
+	 *
+	 * @param refuseClasses 禁止反序列化的类
+	 * @since 5.3.5
+	 */
+	public void refuse(Class<?>... refuseClasses) {
+		if(null == this.blackClassSet){
+			this.blackClassSet = new HashSet<>();
+		}
+		for (Class<?> acceptClass : refuseClasses) {
+			this.blackClassSet.add(acceptClass.getName());
+		}
 	}
 
 	/**
 	 * 接受反序列化的类,用于反序列化验证
 	 *
-	 * @param acceptClass 接受反序列化的类
+	 * @param acceptClasses 接受反序列化的类
 	 */
-	public void accept(Class<?> acceptClass) {
-		this.acceptClass = acceptClass;
+	public void accept(Class<?>... acceptClasses) {
+		if(null == this.whiteClassSet){
+			this.whiteClassSet = new HashSet<>();
+		}
+		for (Class<?> acceptClass : acceptClasses) {
+			this.whiteClassSet.add(acceptClass.getName());
+		}
 	}
 
 	/**
@@ -43,11 +68,34 @@ public class ValidateObjectInputStream extends ObjectInputStream {
 	 */
 	@Override
 	protected Class<?> resolveClass(ObjectStreamClass desc) throws IOException, ClassNotFoundException {
-		if (null != this.acceptClass && false == desc.getName().equals(acceptClass.getName())) {
-			throw new InvalidClassException(
-					"Unauthorized deserialization attempt",
-					desc.getName());
-		}
+		validateClassName(desc.getName());
 		return super.resolveClass(desc);
 	}
+
+	/**
+	 * 验证反序列化的类是否合法
+	 * @param className 类名
+	 * @throws InvalidClassException 非法类
+	 */
+	private void validateClassName(String className) throws InvalidClassException {
+		// 黑名单
+		if(CollUtil.isNotEmpty(this.blackClassSet)){
+			if(this.blackClassSet.contains(className)){
+				throw new InvalidClassException("Unauthorized deserialization attempt by black list", className);
+			}
+		}
+
+		if(CollUtil.isEmpty(this.whiteClassSet)){
+			return;
+		}
+		if(className.startsWith("java.")){
+			// java中的类默认在白名单中
+			return;
+		}
+		if(this.whiteClassSet.contains(className)){
+			return;
+		}
+
+		throw new InvalidClassException("Unauthorized deserialization attempt", className);
+	}
 }