Browse Source

Merge pull request #988 from hyl1995/v5-dev

nio-socket修改
Golden Looly 5 years ago
parent
commit
f527c7af39

+ 151 - 71
hutool-socket/src/main/java/cn/hutool/socket/nio/NioClient.java

@@ -2,12 +2,20 @@ package cn.hutool.socket.nio;
 
 import cn.hutool.core.io.IORuntimeException;
 import cn.hutool.core.io.IoUtil;
+import cn.hutool.core.thread.ThreadFactoryBuilder;
 
 import java.io.Closeable;
 import java.io.IOException;
 import java.net.InetSocketAddress;
 import java.nio.ByteBuffer;
+import java.nio.channels.SelectionKey;
+import java.nio.channels.Selector;
+import java.nio.channels.ServerSocketChannel;
 import java.nio.channels.SocketChannel;
+import java.util.Iterator;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ThreadFactory;
 
 /**
  * NIO客户端
@@ -15,88 +23,160 @@ import java.nio.channels.SocketChannel;
  * @author looly
  * @since 4.4.5
  */
-public class NioClient implements Closeable {
+public abstract class NioClient implements Closeable {
 
-	private SocketChannel channel;
+    private Selector selector;
+    private SocketChannel channel;
+    private ExecutorService executorService;
 
-	/**
-	 * 构造
-	 *
-	 * @param host 服务器地址
-	 * @param port 端口
-	 */
-	public NioClient(String host, int port) {
-		init(new InetSocketAddress(host, port));
-	}
+    /**
+     * 构造
+     *
+     * @param host 服务器地址
+     * @param port 端口
+     */
+    public NioClient(String host, int port) {
+        init(new InetSocketAddress(host, port));
+    }
 
-	/**
-	 * 构造
-	 *
-	 * @param address 服务器地址
-	 */
-	public NioClient(InetSocketAddress address) {
-		init(address);
-	}
+    /**
+     * 构造
+     *
+     * @param address 服务器地址
+     */
+    public NioClient(InetSocketAddress address) {
+        init(address);
+    }
 
-	/**
-	 * 初始化
-	 *
-	 * @param address 地址和端口
-	 * @return this
-	 */
-	public NioClient init(InetSocketAddress address) {
-		try {
-			this.channel = SocketChannel.open(address);
-		} catch (IOException e) {
-			throw new IORuntimeException(e);
-		}
-		return this;
-	}
+    /**
+     * 初始化
+     *
+     * @param address 地址和端口
+     * @return this
+     */
+    public NioClient init(InetSocketAddress address) {
+        try {
+            //创建一个SocketChannel对象,配置成非阻塞模式
+            this.channel = SocketChannel.open();
+            channel.configureBlocking(false);
 
-	/**
-	 * 处理读事件<br>
-	 * 当收到读取准备就绪的信号后,回调此方法,用户可读取从客户端传世来的消息
-	 *
-	 * @param buffer 服务端数据存储缓存
-	 * @return this
-	 */
-	public NioClient read(ByteBuffer buffer) {
-		try {
-			this.channel.read(buffer);
-		} catch (IOException e) {
-			throw new IORuntimeException(e);
-		}
-		return this;
-	}
+            //创建一个选择器,并把SocketChannel交给selector对象
+            this.selector = Selector.open();
+            channel.register(selector, SelectionKey.OP_CONNECT);
+
+            //发起建立连接的请求,这里会立即返回,当连接建立完成后,SocketChannel就会被选取出来
+            channel.connect(address);
+        } catch (IOException e) {
+            throw new IORuntimeException(e);
+        }
+        return this;
+    }
 
 	/**
-	 * 实现写逻辑<br>
-	 * 当收到写出准备就绪的信号后,回调此方法,用户可向客户端发送消息
-	 *
-	 * @param datas 发送的数据
-	 * @return this
+	 * 检查连接是否建立完成
 	 */
-	public NioClient write(ByteBuffer... datas) {
-		try {
-			this.channel.write(datas);
-		} catch (IOException e) {
-			throw new IORuntimeException(e);
+    public boolean waitConnect() throws IOException {
+    	boolean isConnect = false;
+		while (0 != this.selector.select()) {
+			final Iterator<SelectionKey> keyIter = selector.selectedKeys().iterator();
+			while (keyIter.hasNext()) {
+				//连接建立完成
+				SelectionKey key = keyIter.next();
+				if (key.isConnectable()) {
+					if (this.channel.finishConnect()) {
+						this.channel.register(selector, SelectionKey.OP_READ);
+						isConnect = true;
+					}
+				}
+				keyIter.remove();
+				break;
+			}
+			if (isConnect) {
+				break;
+			}
 		}
-		return this;
+		return isConnect;
 	}
 
-	/**
-	 * 获取SocketChannel
-	 *
-	 * @return SocketChannel
-	 * @since 5.3.10
-	 */
-	public SocketChannel getChannel() {
-		return this.channel;
-	}
+    /**
+     * 开始监听
+     */
+    public void listen() {
+		this.executorService = Executors.newSingleThreadExecutor(r -> {
+            final Thread thread = Executors.defaultThreadFactory().newThread(r);
+            thread.setName("nio-client-listen");
+            return thread;
+        });
+		this.executorService.execute(() -> {
+            try {
+                doListen();
+            } catch (IOException e) {
+                e.printStackTrace();
+            }
+        });
+    }
 
-	@Override
-	public void close() {
-		IoUtil.close(this.channel);
+    /**
+     * 开始监听
+     *
+     * @throws IOException IO异常
+     */
+    private void doListen() throws IOException {
+        while (0 != this.selector.select()) {
+            // 返回已选择键的集合
+            final Iterator<SelectionKey> keyIter = selector.selectedKeys().iterator();
+            while (keyIter.hasNext()) {
+                handle(keyIter.next());
+                keyIter.remove();
+            }
+        }
+    }
+
+    /**
+     * 处理SelectionKey
+     *
+     * @param key SelectionKey
+     */
+    private void handle(SelectionKey key) throws IOException {
+        // 读事件就绪
+        if (key.isReadable()) {
+            final SocketChannel socketChannel = (SocketChannel) key.channel();
+            read(socketChannel);
+        }
+    }
+
+    /**
+     * 处理读事件<br>
+     * 当收到读取准备就绪的信号后,回调此方法,用户可读取从客户端传出来的消息
+     *
+     * @param socketChannel SocketChannel
+     */
+    protected abstract void read(SocketChannel socketChannel);
+
+    /**
+     * 实现写逻辑<br>
+     * 当收到写出准备就绪的信号后,回调此方法,用户可向客户端发送消息
+     *
+     * @param datas 发送的数据
+     * @return this
+     */
+    public NioClient write(ByteBuffer... datas) {
+        try {
+            this.channel.write(datas);
+        } catch (IOException e) {
+            throw new IORuntimeException(e);
+        }
+        return this;
+    }
+
+    public void closeListen() {
+		this.executorService.shutdown();
 	}
+
+    @Override
+    public void close() {
+        IoUtil.close(this.selector);
+        IoUtil.close(this.channel);
+		closeListen();
+    }
 }

+ 1 - 1
hutool-socket/src/main/java/cn/hutool/socket/nio/NioServer.java

@@ -137,7 +137,7 @@ public abstract class NioServer implements Closeable {
 
 	/**
 	 * 处理读事件<br>
-	 * 当收到读取准备就绪的信号后,回调此方法,用户可读取从客户端传来的消息
+	 * 当收到读取准备就绪的信号后,回调此方法,用户可读取从客户端传来的消息
 	 * 
 	 * @param socketChannel SocketChannel
 	 */

+ 67 - 0
hutool-socket/src/test/java/cn/hutool/socket/NioClientTest.java

@@ -0,0 +1,67 @@
+package cn.hutool.socket;
+
+import cn.hutool.core.util.StrUtil;
+import cn.hutool.socket.nio.NioClient;
+import lombok.SneakyThrows;
+
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import java.nio.ByteBuffer;
+import java.nio.channels.SelectionKey;
+import java.nio.channels.Selector;
+import java.nio.channels.SocketChannel;
+import java.nio.charset.Charset;
+import java.util.Iterator;
+import java.util.Scanner;
+import java.util.Set;
+
+public class NioClientTest {
+
+    @SneakyThrows
+    public static void main(String[] args) {
+        NioClient client = new NioClient("127.0.0.1", 8080) {
+            @SneakyThrows
+            @Override
+            protected void read(SocketChannel sc) {
+                ByteBuffer readBuffer = ByteBuffer.allocate(1024);
+                //从channel读数据到缓冲区
+                int readBytes = sc.read(readBuffer);
+                if (readBytes > 0){
+                    //Flips this buffer.  The limit is set to the current position and then
+                    // the position is set to zero,就是表示要从起始位置开始读取数据
+                    readBuffer.flip();
+                    //eturns the number of elements between the current position and the  limit.
+                    // 要读取的字节长度
+                    byte[] bytes = new byte[readBuffer.remaining()];
+                    //将缓冲区的数据读到bytes数组
+                    readBuffer.get(bytes);
+                    String body = new String(bytes, "UTF-8");
+                    System.out.println("the read client receive message: " + body);
+                }else if(readBytes < 0){
+                    sc.close();
+                }
+            }
+        };
+        if (client.waitConnect()) {
+            client.listen();
+        }
+        ByteBuffer buffer = ByteBuffer.wrap("client 发生到 server".getBytes());
+        client.write(buffer);
+        buffer = ByteBuffer.wrap("client 再次发生到 server".getBytes());
+        client.write(buffer);
+
+        /**
+         * 在控制台向服务器端发送数据
+         */
+        System.out.println("请在下方畅所欲言");
+        Scanner scanner = new Scanner(System.in);
+        while (scanner.hasNextLine()) {
+            String request = scanner.nextLine();
+            if (request != null && request.trim().length() > 0) {
+                client.write(
+                        Charset.forName("UTF-8")
+                                .encode("测试client" + ": " + request));
+            }
+        }
+    }
+}

+ 82 - 0
hutool-socket/src/test/java/cn/hutool/socket/NioServerTest.java

@@ -0,0 +1,82 @@
+package cn.hutool.socket;
+
+import cn.hutool.core.util.StrUtil;
+import cn.hutool.socket.nio.NioServer;
+import lombok.SneakyThrows;
+
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import java.nio.ByteBuffer;
+import java.nio.channels.SelectionKey;
+import java.nio.channels.Selector;
+import java.nio.channels.ServerSocketChannel;
+import java.nio.channels.SocketChannel;
+import java.util.Set;
+
+public class NioServerTest {
+
+    public static void main(String[] args) {
+        NioServer server = new NioServer(8080) {
+            @SneakyThrows
+            @Override
+            protected void read(SocketChannel sc) {
+                ByteBuffer readBuffer = ByteBuffer.allocate(1024);
+                //从channel读数据到缓冲区
+                int readBytes = sc.read(readBuffer);
+                if (readBytes > 0){
+                    //Flips this buffer.  The limit is set to the current position and then
+                    // the position is set to zero,就是表示要从起始位置开始读取数据
+                    readBuffer.flip();
+                    //eturns the number of elements between the current position and the  limit.
+                    // 要读取的字节长度
+                    byte[] bytes = new byte[readBuffer.remaining()];
+                    //将缓冲区的数据读到bytes数组
+                    readBuffer.get(bytes);
+                    String body = new String(bytes, "UTF-8");
+                    System.out.println("the read server receive message: " + body);
+                    doWrite(sc, body);
+                }else if(readBytes < 0){
+                    sc.close();
+                }
+            }
+
+            @SneakyThrows
+            @Override
+            protected void write(SocketChannel sc) {
+                ByteBuffer readBuffer = ByteBuffer.allocate(1024);
+                //从channel读数据到缓冲区
+                int readBytes = sc.read(readBuffer);
+                if (readBytes > 0){
+                    //Flips this buffer.  The limit is set to the current position and then
+                    // the position is set to zero,就是表示要从起始位置开始读取数据
+                    readBuffer.flip();
+                    //eturns the number of elements between the current position and the  limit.
+                    // 要读取的字节长度
+                    byte[] bytes = new byte[readBuffer.remaining()];
+                    //将缓冲区的数据读到bytes数组
+                    readBuffer.get(bytes);
+                    String body = new String(bytes, "UTF-8");
+                    System.out.println("the write server receive message: " + body);
+                    doWrite(sc, body);
+                }else if(readBytes < 0){
+                    sc.close();
+                }
+            }
+        };
+        server.listen();
+    }
+
+    public static void doWrite(SocketChannel channel, String response) throws IOException {
+        response = "我们已收到消息:"+response;
+        if(!StrUtil.isBlank(response)){
+            byte []  bytes = response.getBytes();
+            //分配一个bytes的length长度的ByteBuffer
+            ByteBuffer write = ByteBuffer.allocate(bytes.length);
+            //将返回数据写入缓冲区
+            write.put(bytes);
+            write.flip();
+            //将缓冲数据写入渠道,返回给客户端
+            channel.write(write);
+        }
+    }
+}