|
|
@@ -1,10 +1,14 @@
|
|
|
package cn.hutool.core.io;
|
|
|
|
|
|
+import cn.hutool.core.collection.CollUtil;
|
|
|
+
|
|
|
import java.io.IOException;
|
|
|
import java.io.InputStream;
|
|
|
import java.io.InvalidClassException;
|
|
|
import java.io.ObjectInputStream;
|
|
|
import java.io.ObjectStreamClass;
|
|
|
+import java.util.HashSet;
|
|
|
+import java.util.Set;
|
|
|
|
|
|
/**
|
|
|
* 带有类验证的对象流,用于避免反序列化漏洞<br>
|
|
|
@@ -15,27 +19,48 @@ import java.io.ObjectStreamClass;
|
|
|
*/
|
|
|
public class ValidateObjectInputStream extends ObjectInputStream {
|
|
|
|
|
|
- private Class<?> acceptClass;
|
|
|
+ private Set<String> whiteClassSet;
|
|
|
+ private Set<String> blackClassSet;
|
|
|
|
|
|
/**
|
|
|
* 构造
|
|
|
*
|
|
|
* @param inputStream 流
|
|
|
- * @param acceptClass 接受的类
|
|
|
+ * @param acceptClasses 白名单的类
|
|
|
* @throws IOException IO异常
|
|
|
*/
|
|
|
- public ValidateObjectInputStream(InputStream inputStream, Class<?> acceptClass) throws IOException {
|
|
|
+ public ValidateObjectInputStream(InputStream inputStream, Class<?>... acceptClasses) throws IOException {
|
|
|
super(inputStream);
|
|
|
- this.acceptClass = acceptClass;
|
|
|
+ accept(acceptClasses);
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 禁止反序列化的类,用于反序列化验证
|
|
|
+ *
|
|
|
+ * @param refuseClasses 禁止反序列化的类
|
|
|
+ * @since 5.3.5
|
|
|
+ */
|
|
|
+ public void refuse(Class<?>... refuseClasses) {
|
|
|
+ if(null == this.blackClassSet){
|
|
|
+ this.blackClassSet = new HashSet<>();
|
|
|
+ }
|
|
|
+ for (Class<?> acceptClass : refuseClasses) {
|
|
|
+ this.blackClassSet.add(acceptClass.getName());
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
/**
|
|
|
* 接受反序列化的类,用于反序列化验证
|
|
|
*
|
|
|
- * @param acceptClass 接受反序列化的类
|
|
|
+ * @param acceptClasses 接受反序列化的类
|
|
|
*/
|
|
|
- public void accept(Class<?> acceptClass) {
|
|
|
- this.acceptClass = acceptClass;
|
|
|
+ public void accept(Class<?>... acceptClasses) {
|
|
|
+ if(null == this.whiteClassSet){
|
|
|
+ this.whiteClassSet = new HashSet<>();
|
|
|
+ }
|
|
|
+ for (Class<?> acceptClass : acceptClasses) {
|
|
|
+ this.whiteClassSet.add(acceptClass.getName());
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
/**
|
|
|
@@ -43,11 +68,34 @@ public class ValidateObjectInputStream extends ObjectInputStream {
|
|
|
*/
|
|
|
@Override
|
|
|
protected Class<?> resolveClass(ObjectStreamClass desc) throws IOException, ClassNotFoundException {
|
|
|
- if (null != this.acceptClass && false == desc.getName().equals(acceptClass.getName())) {
|
|
|
- throw new InvalidClassException(
|
|
|
- "Unauthorized deserialization attempt",
|
|
|
- desc.getName());
|
|
|
- }
|
|
|
+ validateClassName(desc.getName());
|
|
|
return super.resolveClass(desc);
|
|
|
}
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 验证反序列化的类是否合法
|
|
|
+ * @param className 类名
|
|
|
+ * @throws InvalidClassException 非法类
|
|
|
+ */
|
|
|
+ private void validateClassName(String className) throws InvalidClassException {
|
|
|
+ // 黑名单
|
|
|
+ if(CollUtil.isNotEmpty(this.blackClassSet)){
|
|
|
+ if(this.blackClassSet.contains(className)){
|
|
|
+ throw new InvalidClassException("Unauthorized deserialization attempt by black list", className);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ if(CollUtil.isEmpty(this.whiteClassSet)){
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ if(className.startsWith("java.")){
|
|
|
+ // java中的类默认在白名单中
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ if(this.whiteClassSet.contains(className)){
|
|
|
+ return;
|
|
|
+ }
|
|
|
+
|
|
|
+ throw new InvalidClassException("Unauthorized deserialization attempt", className);
|
|
|
+ }
|
|
|
}
|