Browse Source

add ValidateObjectInputStream

Looly 5 years ago
parent
commit
3921a568dd

+ 1 - 0
CHANGELOG.md

@@ -17,6 +17,7 @@
 * 【core   】     NetUtil增加parseCookies方法
 * 【core   】     CollUtil增加toMap方法
 * 【core   】     CollUtil和IterUtil废弃一些方法
+* 【core   】     添加ValidateObjectInputStream避免对象反序列化漏洞风险
 
 ### Bug修复
 * 【extra  】     修复SpringUtil使用devtools重启报错问题

+ 14 - 0
hutool-core/src/main/java/cn/hutool/core/collection/CollUtil.java

@@ -2348,7 +2348,9 @@ public class CollUtil {
 	 * @param <V>        Value类型
 	 * @param map        {@link Map}
 	 * @param kvConsumer {@link KVConsumer} 遍历的每条数据处理器
+	 * @deprecated JDK8+中使用map.forEach
 	 */
+	@Deprecated
 	public static <K, V> void forEach(Map<K, V> map, KVConsumer<K, V> kvConsumer) {
 		int index = 0;
 		for (Entry<K, V> entry : map.entrySet()) {
@@ -2527,6 +2529,18 @@ public class CollUtil {
 		return Collections.min(coll);
 	}
 
+	/**
+	 * 转为只读集合
+	 *
+	 * @param <T> 元素类型
+	 * @param c   集合
+	 * @return 只读集合
+	 * @since 5.2.6
+	 */
+	public static <T> Collection<T> unmodifiable(Collection<? extends T> c) {
+		return Collections.unmodifiableCollection(c);
+	}
+
 	// ---------------------------------------------------------------------------------------------- Interface start
 
 	/**

+ 12 - 0
hutool-core/src/main/java/cn/hutool/core/collection/ListUtil.java

@@ -444,4 +444,16 @@ public class ListUtil {
 		}
 		return Convert.convert(int[].class, indexList);
 	}
+
+	/**
+	 * 将对应List转换为不可修改的List
+	 *
+	 * @param list Map
+	 * @param <T> 元素类型
+	 * @return 不修改Map
+	 * @since 5.2.6
+	 */
+	public static <T> List<T> unmodifiable(List<T> list) {
+		return Collections.unmodifiableList(list);
+	}
 }

+ 19 - 5
hutool-core/src/main/java/cn/hutool/core/io/IoUtil.java

@@ -636,24 +636,38 @@ public class IoUtil {
 	}
 
 	/**
-	 * 从流中读取内容,读到输出流中
+	 * 从流中读取对象,即对象的反序列化
 	 * 
 	 * @param <T> 读取对象的类型
 	 * @param in 输入流
 	 * @return 输出流
 	 * @throws IORuntimeException IO异常
 	 * @throws UtilException ClassNotFoundException包装
+	 * @deprecated 由于存在对象反序列化漏洞风险,请使用{@link #readObj(InputStream, Class)}
 	 */
+	@Deprecated
 	public static <T> T readObj(InputStream in) throws IORuntimeException, UtilException {
+		return readObj(in, null);
+	}
+
+	/**
+	 * 从流中读取对象,即对象的反序列化,读取后不关闭流
+	 *
+	 * @param <T> 读取对象的类型
+	 * @param in 输入流
+	 * @return 输出流
+	 * @throws IORuntimeException IO异常
+	 * @throws UtilException ClassNotFoundException包装
+	 */
+	public static <T> T readObj(InputStream in, Class<T> clazz) throws IORuntimeException, UtilException {
 		if (in == null) {
 			throw new IllegalArgumentException("The InputStream must not be null");
 		}
 		ObjectInputStream ois;
 		try {
-			ois = new ObjectInputStream(in);
-			@SuppressWarnings("unchecked") // may fail with CCE if serialised form is incorrect
-			final T obj = (T) ois.readObject();
-			return obj;
+			ois = new ValidateObjectInputStream(in, clazz);
+			//noinspection unchecked
+			return (T) ois.readObject();
 		} catch (IOException e) {
 			throw new IORuntimeException(e);
 		} catch (ClassNotFoundException e) {

+ 53 - 0
hutool-core/src/main/java/cn/hutool/core/io/ValidateObjectInputStream.java

@@ -0,0 +1,53 @@
+package cn.hutool.core.io;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InvalidClassException;
+import java.io.ObjectInputStream;
+import java.io.ObjectStreamClass;
+
+/**
+ * 带有类验证的对象流,用于避免反序列化漏洞<br>
+ * 详细见:https://xz.aliyun.com/t/41/
+ *
+ * @author looly
+ * @since 5.2.6
+ */
+public class ValidateObjectInputStream extends ObjectInputStream {
+
+	private Class<?> acceptClass;
+
+	/**
+	 * 构造
+	 *
+	 * @param inputStream 流
+	 * @param acceptClass 接受的类
+	 * @throws IOException IO异常
+	 */
+	public ValidateObjectInputStream(InputStream inputStream, Class<?> acceptClass) throws IOException {
+		super(inputStream);
+		this.acceptClass = acceptClass;
+	}
+
+	/**
+	 * 接受反序列化的类,用于反序列化验证
+	 *
+	 * @param acceptClass 接受反序列化的类
+	 */
+	public void accept(Class<?> acceptClass) {
+		this.acceptClass = acceptClass;
+	}
+
+	/**
+	 * 只允许反序列化SerialObject class
+	 */
+	@Override
+	protected Class<?> resolveClass(ObjectStreamClass desc) throws IOException, ClassNotFoundException {
+		if (null != this.acceptClass && false == desc.getName().equals(acceptClass.getName())) {
+			throw new InvalidClassException(
+					"Unauthorized deserialization attempt",
+					desc.getName());
+		}
+		return super.resolveClass(desc);
+	}
+}

+ 15 - 2
hutool-core/src/main/java/cn/hutool/core/map/MapUtil.java

@@ -560,7 +560,7 @@ public class MapUtil {
 	public static <K, V> String join(Map<K, V> map, String separator, String keyValueSeparator, boolean isIgnoreNull, String... otherParams) {
 		final StringBuilder strBuilder = StrUtil.builder();
 		boolean isFirst = true;
-		if(isNotEmpty(map)){
+		if (isNotEmpty(map)) {
 			for (Entry<K, V> entry : map.entrySet()) {
 				if (false == isIgnoreNull || entry.getKey() != null && entry.getValue() != null) {
 					if (isFirst) {
@@ -733,7 +733,7 @@ public class MapUtil {
 	 * @since 4.0.1
 	 */
 	public static <K, V> TreeMap<K, V> sort(Map<K, V> map, Comparator<? super K> comparator) {
-		if(null == map){
+		if (null == map) {
 			return null;
 		}
 
@@ -777,6 +777,19 @@ public class MapUtil {
 		return new MapWrapper<>(map);
 	}
 
+	/**
+	 * 将对应Map转换为不可修改的Map
+	 *
+	 * @param map Map
+	 * @param <K> 键类型
+	 * @param <V> 值类型
+	 * @return 不修改Map
+	 * @since 5.2.6
+	 */
+	public static <K, V> Map<K, V> unmodifiable(Map<K, V> map) {
+		return Collections.unmodifiableMap(map);
+	}
+
 	// ----------------------------------------------------------------------------------------------- builder
 
 	/**

+ 21 - 9
hutool-http/src/main/java/cn/hutool/http/server/HttpServerRequest.java

@@ -20,6 +20,7 @@ import java.net.HttpCookie;
 import java.net.URI;
 import java.nio.charset.Charset;
 import java.util.Collection;
+import java.util.Collections;
 import java.util.Map;
 
 /**
@@ -28,7 +29,7 @@ import java.util.Map;
  * @author looly
  * @since 5.2.6
  */
-public class HttpServerRequest extends HttpServerBase{
+public class HttpServerRequest extends HttpServerBase {
 
 	private Map<String, HttpCookie> cookieCache;
 
@@ -149,6 +150,21 @@ public class HttpServerRequest extends HttpServerBase{
 	}
 
 	/**
+	 * 获取编码,获取失败默认使用UTF-8,获取规则如下:
+	 *
+	 * <pre>
+	 *     1、从Content-Type头中获取编码,类似于:text/html;charset=utf-8
+	 * </pre>
+	 *
+	 * @return 编码,默认UTF-8
+	 */
+	public Charset getCharset() {
+		final String contentType = getContentType();
+		final String charsetStr = HttpUtil.getCharset(contentType);
+		return CharsetUtil.parse(charsetStr, CharsetUtil.CHARSET_UTF_8);
+	}
+
+	/**
 	 * 获得User-Agent
 	 *
 	 * @return User-Agent字符串
@@ -191,10 +207,10 @@ public class HttpServerRequest extends HttpServerBase{
 	 */
 	public Map<String, HttpCookie> getCookieMap() {
 		if (null == this.cookieCache) {
-			cookieCache = CollUtil.toMap(
+			cookieCache = Collections.unmodifiableMap(CollUtil.toMap(
 					NetUtil.parseCookies(getCookiesStr()),
 					new CaseInsensitiveMap<>(),
-					HttpCookie::getName);
+					HttpCookie::getName));
 		}
 		return cookieCache;
 	}
@@ -220,16 +236,12 @@ public class HttpServerRequest extends HttpServerBase{
 
 	/**
 	 * 获取请求体文本,可以是form表单、json、xml等任意内容<br>
-	 * 根据请求的Content-Type判断编码,判断失败使用UTF-8编码
+	 * 使用{@link #getCharset()}判断编码,判断失败使用UTF-8编码
 	 *
 	 * @return 请求
 	 */
 	public String getBody() {
-		final String contentType = getContentType();
-		final String charsetStr = HttpUtil.getCharset(contentType);
-		final Charset charset = CharsetUtil.parse(charsetStr, CharsetUtil.CHARSET_UTF_8);
-
-		return getBody(charset);
+		return getBody(getCharset());
 	}
 
 	/**

+ 164 - 1
hutool-http/src/main/java/cn/hutool/http/server/HttpServerResponse.java

@@ -1,18 +1,33 @@
 package cn.hutool.http.server;
 
+import cn.hutool.core.io.FileUtil;
 import cn.hutool.core.io.IORuntimeException;
 import cn.hutool.core.io.IoUtil;
+import cn.hutool.core.util.CharsetUtil;
+import cn.hutool.core.util.ObjectUtil;
+import cn.hutool.core.util.StrUtil;
+import cn.hutool.core.util.URLUtil;
+import cn.hutool.http.Header;
+import cn.hutool.http.HttpUtil;
+import com.sun.net.httpserver.Headers;
 import com.sun.net.httpserver.HttpExchange;
 
+import java.io.BufferedInputStream;
 import java.io.ByteArrayInputStream;
+import java.io.File;
 import java.io.IOException;
 import java.io.InputStream;
 import java.io.OutputStream;
+import java.nio.charset.Charset;
+import java.util.List;
+import java.util.Map;
 
 /**
  * Http响应对象,用于写出数据到客户端
  */
-public class HttpServerResponse extends HttpServerBase{
+public class HttpServerResponse extends HttpServerBase {
+
+	private Charset charset;
 
 	/**
 	 * 构造
@@ -50,6 +65,111 @@ public class HttpServerResponse extends HttpServerBase{
 	}
 
 	/**
+	 * 获得所有响应头,获取后可以添加新的响应头
+	 *
+	 * @return 响应头
+	 */
+	public Headers getHeaders() {
+		return this.httpExchange.getResponseHeaders();
+	}
+
+	/**
+	 * 添加响应头,如果已经存在,则追加
+	 *
+	 * @param header 头key
+	 * @param value  值
+	 * @return this
+	 */
+	public HttpServerResponse addHeader(String header, String value) {
+		getHeaders().add(header, value);
+		return this;
+	}
+
+	/**
+	 * 设置响应头,如果已经存在,则覆盖
+	 *
+	 * @param header 头key
+	 * @param value  值
+	 * @return this
+	 */
+	public HttpServerResponse setHeader(Header header, String value) {
+		return setHeader(header.getValue(), value);
+	}
+
+	/**
+	 * 设置响应头,如果已经存在,则覆盖
+	 *
+	 * @param header 头key
+	 * @param value  值
+	 * @return this
+	 */
+	public HttpServerResponse setHeader(String header, String value) {
+		getHeaders().set(header, value);
+		return this;
+	}
+
+	/**
+	 * 设置响应头,如果已经存在,则覆盖
+	 *
+	 * @param header 头key
+	 * @param value  值列表
+	 * @return this
+	 */
+	public HttpServerResponse setHeader(String header, List<String> value) {
+		getHeaders().put(header, value);
+		return this;
+	}
+
+	/**
+	 * 设置所有响应头,如果已经存在,则覆盖
+	 *
+	 * @param headers 响应头map
+	 * @return this
+	 */
+	public HttpServerResponse setHeaders(Map<String, List<String>> headers) {
+		getHeaders().putAll(headers);
+		return this;
+	}
+
+	/**
+	 * 设置Content-Type头,类似于:text/html;charset=utf-8<br>
+	 * 如果用户传入的信息无charset信息,自动根据charset补充,charset设置见{@link #setCharset(Charset)}
+	 *
+	 * @param contentType Content-Type头内容
+	 * @return this
+	 */
+	public HttpServerResponse setContentType(String contentType) {
+		if (null != contentType && null != this.charset) {
+			if (false == contentType.contains(";charset=")) {
+				contentType += ";charset=" + this.charset;
+			}
+		}
+
+		return setHeader(Header.CONTENT_TYPE, contentType);
+	}
+
+	/**
+	 * 设置Content-Length头
+	 *
+	 * @param contentLength Content-Length头内容
+	 * @return this
+	 */
+	public HttpServerResponse setContentLength(long contentLength) {
+		return setHeader(Header.CONTENT_LENGTH, String.valueOf(contentLength));
+	}
+
+	/**
+	 * 设置响应的编码
+	 *
+	 * @param charset 编码
+	 * @return this
+	 */
+	public HttpServerResponse setCharset(Charset charset) {
+		this.charset = charset;
+		return this;
+	}
+
+	/**
 	 * 获取响应数据流
 	 *
 	 * @return 响应数据流
@@ -59,6 +179,15 @@ public class HttpServerResponse extends HttpServerBase{
 	}
 
 	/**
+	 * 获取响应数据流
+	 *
+	 * @return 响应数据流
+	 */
+	public OutputStream getWriter() {
+		return this.httpExchange.getResponseBody();
+	}
+
+	/**
 	 * 写出数据到客户端
 	 *
 	 * @param data 数据
@@ -86,4 +215,38 @@ public class HttpServerResponse extends HttpServerBase{
 		}
 		return this;
 	}
+
+	/**
+	 * 返回文件给客户端(文件下载)
+	 *
+	 * @param file 写出的文件对象
+	 * @since 5.2.6
+	 */
+	public HttpServerResponse write(File file) {
+		final String fileName = file.getName();
+		final String contentType = ObjectUtil.defaultIfNull(HttpUtil.getMimeType(fileName), "application/octet-stream");
+		BufferedInputStream in = null;
+		try {
+			in = FileUtil.getInputStream(file);
+			write(in, contentType, fileName);
+		} finally {
+			IoUtil.close(in);
+		}
+		return this;
+	}
+
+	/**
+	 * 返回数据给客户端
+	 *
+	 * @param in          需要返回客户端的内容
+	 * @param contentType 返回的类型
+	 * @param fileName    文件名
+	 * @since 5.2.6
+	 */
+	public void write(InputStream in, String contentType, String fileName) {
+		final Charset charset = ObjectUtil.defaultIfNull(this.charset, CharsetUtil.CHARSET_UTF_8);
+		setHeader("Content-Disposition", StrUtil.format("attachment;filename={}", URLUtil.encode(fileName, charset)));
+		setContentType(contentType);
+		write(in);
+	}
 }