diff --git a/influent-java/src/main/java/influent/forward/NioForwardConnection.java b/influent-java/src/main/java/influent/forward/NioForwardConnection.java index a7e3859..9f29033 100644 --- a/influent-java/src/main/java/influent/forward/NioForwardConnection.java +++ b/influent-java/src/main/java/influent/forward/NioForwardConnection.java @@ -1,315 +1,128 @@ -/* - * Copyright 2016 okumin - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - package influent.forward; +import influent.internal.msgpack.MsgpackStreamUnpacker; +import influent.internal.nio.NioEventLoop; +import influent.internal.nio.NioTcpChannel; +import influent.internal.util.ThreadSafeQueue; +import org.msgpack.core.MessageBufferPacker; +import org.msgpack.core.MessagePack; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import java.io.IOException; import java.net.InetSocketAddress; import java.nio.ByteBuffer; -import java.nio.channels.SocketChannel; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; import java.security.SecureRandom; import java.util.Optional; -import java.util.function.Consumer; -import java.util.function.Supplier; -import org.msgpack.core.MessageBufferPacker; -import org.msgpack.core.MessagePack; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import influent.exception.InfluentIOException; -import influent.internal.msgpack.MsgpackStreamUnpacker; -import influent.internal.nio.NioAttachment; -import influent.internal.nio.NioEventLoop; -import influent.internal.nio.NioTcpChannel; -import influent.internal.nio.NioTcpConfig; -import influent.internal.util.ThreadSafeQueue; - -/** - * A connection for forward protocol. - */ -final class NioForwardConnection implements NioAttachment { - private static final Logger logger = LoggerFactory.getLogger(NioForwardConnection.class); - private static final String ACK_KEY = "ack"; - - private final NioTcpChannel channel; - private final NioEventLoop eventLoop; - private final ForwardCallback callback; - private final MsgpackStreamUnpacker unpacker; - private final MsgpackForwardRequestDecoder decoder; - private final ForwardSecurity security; - private MsgPackPingDecoder pingDecoder; - private Optional node; - - final ThreadSafeQueue responses = new ThreadSafeQueue<>(); - - private final byte[] nonce = new byte[16]; - private final byte[] userAuth = new byte[16]; - - enum ConnectionState { - HELO, PINGPONG, ESTABLISHED - } - private ConnectionState state; +abstract class NioForwardConnection { + private static final Logger logger = LoggerFactory.getLogger(NioForwardConnection.class); - NioForwardConnection(final NioTcpChannel channel, final NioEventLoop eventLoop, - final ForwardCallback callback, final MsgpackStreamUnpacker unpacker, - final MsgpackForwardRequestDecoder decoder, final ForwardSecurity security) { - this.channel = channel; - this.eventLoop = eventLoop; - this.callback = callback; - this.unpacker = unpacker; - this.decoder = decoder; - this.security = security; - state = ConnectionState.ESTABLISHED; - } + protected static final String ACK_KEY = "ack"; + protected final NioTcpChannel channel; + protected final NioEventLoop eventLoop; + protected final ForwardCallback callback; + protected final MsgpackStreamUnpacker unpacker; + protected final MsgpackForwardRequestDecoder decoder; + protected final ForwardSecurity security; + protected final ThreadSafeQueue responses = new ThreadSafeQueue<>(); - NioForwardConnection(final NioTcpChannel channel, final NioEventLoop eventLoop, - final ForwardCallback callback, final long chunkSizeLimit, final ForwardSecurity security) { - this(channel, eventLoop, callback, new MsgpackStreamUnpacker(chunkSizeLimit), - new MsgpackForwardRequestDecoder(), security); - } + protected final byte[] nonce = new byte[16]; + protected final byte[] userAuth = new byte[16]; - /** - * Constructs a new {@code NioForwardConnection}. - * - * @param socketChannel the inbound channel - * @param eventLoop the {@code NioEventLoop} to which this {@code NioForwardConnection} belongs - * @param callback the callback to handle requests - * @param chunkSizeLimit the allowable size of a chunk - * @param tcpConfig the {@code NioTcpConfig} - * @throws InfluentIOException if some IO error occurs - */ - NioForwardConnection(final SocketChannel socketChannel, final NioEventLoop eventLoop, - final ForwardCallback callback, final long chunkSizeLimit, final NioTcpConfig tcpConfig, - final ForwardSecurity security) { - this(new NioTcpChannel(socketChannel, tcpConfig), eventLoop, callback, chunkSizeLimit, security); + protected MsgPackPingDecoder pingDecoder; + protected Optional node; + protected ConnectionState state; - if (this.security.isEnabled()) { - try { - // SecureRandom secureRandom = SecureRandom.getInstanceStrong(); - // Above secureRandom may block... - // TODO: reuse SecureRandom instance - SecureRandom secureRandom = SecureRandom.getInstance("NativePRNGNonBlocking"); - logger.debug(secureRandom.getAlgorithm()); - secureRandom.nextBytes(nonce); - secureRandom.nextBytes(userAuth); - } catch (NoSuchAlgorithmException e) { - e.printStackTrace(); - } - node = security.findNode(((InetSocketAddress) channel.getRemoteAddress()).getAddress()); - state = ConnectionState.HELO; - pingDecoder = new MsgPackPingDecoder(this.security, node.orElse(null), nonce, userAuth); - channel.register(eventLoop, false, true, this); - responses.enqueue(generateHelo()); - } else { - state = ConnectionState.ESTABLISHED; - channel.register(eventLoop, true, false, this); + enum ConnectionState { + HELO, PINGPONG, ESTABLISHED } - } - /** - * Handles a write event. - * - * @throws InfluentIOException if some IO error occurs - */ - @Override - public void onWritable() { - if (sendResponses()) { - channel.disableOpWrite(eventLoop); - if (state == ConnectionState.HELO) { - state = ConnectionState.PINGPONG; - channel.enableOpRead(eventLoop); - // TODO disconnect after writing failed PONG - } + NioForwardConnection(final NioTcpChannel channel, final NioEventLoop eventLoop, + final ForwardCallback callback, final MsgpackStreamUnpacker unpacker, + final MsgpackForwardRequestDecoder decoder, final ForwardSecurity security) { + state = ConnectionState.ESTABLISHED; + this.channel = channel; + this.eventLoop = eventLoop; + this.callback = callback; + this.unpacker = unpacker; + this.decoder = decoder; + this.security = security; + + if (security.isEnabled()) { + try { + // SecureRandom secureRandom = SecureRandom.getInstanceStrong(); + // Above secureRandom may block... + // TODO: reuse SecureRandom instance + SecureRandom secureRandom = SecureRandom.getInstance("NativePRNGNonBlocking"); + logger.debug(secureRandom.getAlgorithm()); + secureRandom.nextBytes(nonce); + secureRandom.nextBytes(userAuth); + } catch (NoSuchAlgorithmException e) { + e.printStackTrace(); + } + node = security.findNode(((InetSocketAddress) channel.getRemoteAddress()).getAddress()); + pingDecoder = new MsgPackPingDecoder(this.security, node.orElse(null), nonce, userAuth); + } } - } - private boolean sendResponses() { - // TODO: gathering - while (responses.nonEmpty()) { - final ByteBuffer head = responses.peek(); - channel.write(head); - if (head.hasRemaining()) { - return false; + // TODO Set keepalive on HELO message true/false according to ForwardServer configuration + // ForwardServer.keepAliveEnabled set SO_KEEPALIVE. + // See also https://github.com/okumin/influent/pull/32#discussion_r145196969 + protected ByteBuffer generateHelo() { + // ['HELO', options(hash)] + // ['HELO', {'nonce' => nonce, 'auth' => user_auth_salt/empty string, 'keepalive' => true/false}].to_msgpack + MessageBufferPacker packer = MessagePack.newDefaultBufferPacker(); + try { + packer.packArrayHeader(2).packString("HELO").packMapHeader(3).packString("nonce") + .packBinaryHeader(16).writePayload(nonce).packString("auth").packBinaryHeader(16) + .writePayload(userAuth).packString("keepalive").packBoolean(true); + } catch (IOException e) { + logger.error("Failed to pack HELO message", e); } - responses.dequeue(); - } - return true; - } - /** - * Handles a read event. - * - * @throws InfluentIOException if some IO error occurs - */ - @Override - public void onReadable() { - switch (state) { - case PINGPONG: - receivePing(result -> { - responses.enqueue(generatePong(result)); - channel.enableOpWrite(eventLoop); - state = ConnectionState.ESTABLISHED; - }); - break; - case ESTABLISHED: - receiveRequests(); - break; + return packer.toMessageBuffer().sliceAsByteBuffer(); } - if (!channel.isOpen()) { - close(); - } - } - private void receivePing(Consumer checkPingResultConsumer) { - // TODO: optimize - final Supplier supplier = () -> { - final ByteBuffer buffer = ByteBuffer.allocate(1024); - if (!channel.read(buffer)) { - return null; - } - buffer.flip(); - return buffer; - }; - unpacker.feed(supplier, channel); - while (unpacker.hasNext()) { + protected ByteBuffer generatePong(CheckPingResult checkPingResult) { + // [ + // 'PONG', + // bool(authentication result), + // 'reason if authentication failed', + // self_hostname, + // sha512_hex(salt + self_hostname + nonce + sharedkey) + // ] + MessageBufferPacker packer = MessagePack.newDefaultBufferPacker(); try { - checkPingResultConsumer.accept(pingDecoder.decode(unpacker.next())); - } catch (final IllegalArgumentException e) { - logger.error( - "Received an invalid ping message. remote address = " + channel.getRemoteAddress(), e - ); - } - } - } - - private void receiveRequests() { - // TODO: optimize - final Supplier supplier = () -> { - final ByteBuffer buffer = ByteBuffer.allocate(1024); - if (!channel.read(buffer)) { - return null; - } - buffer.flip(); - return buffer; - }; - unpacker.feed(supplier, channel); - while (unpacker.hasNext()) { - try { - decoder.decode(unpacker.next()).ifPresent(result -> { - logger.debug( - "Received a forward request from {}. chunk_id = {}", - channel.getRemoteAddress(), result.getOption() - ); - callback.consume(result.getStream()).thenRun(() -> { - // Executes on user's callback thread since the queue never block. - result.getOption().getChunk().ifPresent(chunk -> completeTask(chunk)); - logger.debug("Completed the task. chunk_id = {}.", result.getOption()); - }); - }); - } catch (final IllegalArgumentException e) { - logger.error( - "Received an invalid message. remote address = " + channel.getRemoteAddress(), e - ); + if (checkPingResult.isSucceeded()) { + MessageDigest md = MessageDigest.getInstance("SHA-512"); + md.update(checkPingResult.getSharedKeySalt().getBytes()); + md.update(security.getSelfHostname().getBytes()); + md.update(nonce); + md.update(checkPingResult.getSharedKey().getBytes()); + packer.packArrayHeader(5).packString("PONG").packBoolean(checkPingResult.isSucceeded()) + .packString("").packString(security.getSelfHostname()) + .packString(generateHexString(md.digest())); + } else { + packer.packArrayHeader(5).packString("PONG").packBoolean(checkPingResult.isSucceeded()) + .packString(checkPingResult.getReason()).packString("").packString(""); + } + } catch (IOException e) { + logger.error("Failed to pack PONG message", e); + } catch (NoSuchAlgorithmException e) { + logger.error(e.getMessage(), e); } - } - } - // This method is thread-safe. - private void completeTask(final String chunk) { - try { - final MessageBufferPacker packer = MessagePack.newDefaultBufferPacker(); - packer.packMapHeader(1); - packer.packString(ACK_KEY); - packer.packString(chunk); - final ByteBuffer buffer = packer.toMessageBuffer().sliceAsByteBuffer(); - responses.enqueue(buffer); - channel.enableOpWrite(eventLoop); - } catch (final IOException e) { - logger.error("Failed packing. chunk = " + chunk, e); + return packer.toMessageBuffer().sliceAsByteBuffer(); } - } - - // TODO Set keepalive on HELO message true/false according to ForwardServer configuration - // ForwardServer.keepAliveEnabled set SO_KEEPALIVE. - // See also https://github.com/okumin/influent/pull/32#discussion_r145196969 - private ByteBuffer generateHelo() { - // ['HELO', options(hash)] - // ['HELO', {'nonce' => nonce, 'auth' => user_auth_salt/empty string, 'keepalive' => true/false}].to_msgpack - MessageBufferPacker packer = MessagePack.newDefaultBufferPacker(); - try { - packer.packArrayHeader(2).packString("HELO").packMapHeader(3).packString("nonce") - .packBinaryHeader(16).writePayload(nonce).packString("auth").packBinaryHeader(16) - .writePayload(userAuth).packString("keepalive").packBoolean(true); - } catch (IOException e) { - logger.error("Failed to pack HELO message", e); - } - - return packer.toMessageBuffer().sliceAsByteBuffer(); - } - private ByteBuffer generatePong(CheckPingResult checkPingResult) { - // [ - // 'PONG', - // bool(authentication result), - // 'reason if authentication failed', - // self_hostname, - // sha512_hex(salt + self_hostname + nonce + sharedkey) - // ] - MessageBufferPacker packer = MessagePack.newDefaultBufferPacker(); - try { - if (checkPingResult.isSucceeded()) { - MessageDigest md = MessageDigest.getInstance("SHA-512"); - md.update(checkPingResult.getSharedKeySalt().getBytes()); - md.update(security.getSelfHostname().getBytes()); - md.update(nonce); - md.update(checkPingResult.getSharedKey().getBytes()); - packer.packArrayHeader(5).packString("PONG").packBoolean(checkPingResult.isSucceeded()) - .packString("").packString(security.getSelfHostname()) - .packString(generateHexString(md.digest())); - } else { - packer.packArrayHeader(5).packString("PONG").packBoolean(checkPingResult.isSucceeded()) - .packString(checkPingResult.getReason()).packString("").packString(""); + private String generateHexString(final byte[] digest) { + StringBuilder sb = new StringBuilder(); + for (byte b : digest) { + sb.append(String.format("%02x", b)); } - } catch (IOException e) { - logger.error("Failed to pack PONG message", e); - } catch (NoSuchAlgorithmException e) { - logger.error(e.getMessage(), e); + return sb.toString(); } - - return packer.toMessageBuffer().sliceAsByteBuffer(); - } - - private String generateHexString(final byte[] digest) { - StringBuilder sb = new StringBuilder(); - for (byte b : digest) { - sb.append(String.format("%02x", b)); - } - return sb.toString(); - } - - @Override - public void close() { - channel.close(); - logger.debug("NioForwardConnection bound with {} closed.", channel.getRemoteAddress()); - } - - @Override - public String toString() { - return "NioForwardConnection(" + channel.getRemoteAddress() + ")"; - } } diff --git a/influent-java/src/main/java/influent/forward/NioForwardServer.java b/influent-java/src/main/java/influent/forward/NioForwardServer.java index 8ce893f..7aa0a21 100644 --- a/influent-java/src/main/java/influent/forward/NioForwardServer.java +++ b/influent-java/src/main/java/influent/forward/NioForwardServer.java @@ -60,10 +60,10 @@ final class NioForwardServer implements ForwardServer { if (channelConfig.isSslEnabled()) { channelFactory = (socketChannel) -> new NioSslForwardConnection( socketChannel, workerEventLoopPool.next(), callback, - channelConfig.createSSLEngine(), chunkSizeLimit, tcpConfig + channelConfig.createSSLEngine(), chunkSizeLimit, tcpConfig, security ); } else { - channelFactory = (socketChannel) -> new NioForwardConnection( + channelFactory = (socketChannel) -> new NioTcpForwardConnection( socketChannel, workerEventLoopPool.next(), callback, chunkSizeLimit, tcpConfig, security ); } diff --git a/influent-java/src/main/java/influent/forward/NioSslForwardConnection.java b/influent-java/src/main/java/influent/forward/NioSslForwardConnection.java index 1ffdac1..dd27a8d 100644 --- a/influent-java/src/main/java/influent/forward/NioSslForwardConnection.java +++ b/influent-java/src/main/java/influent/forward/NioSslForwardConnection.java @@ -16,64 +16,54 @@ package influent.forward; +import influent.exception.InfluentIOException; +import influent.internal.msgpack.MsgpackStreamUnpacker; +import influent.internal.nio.NioAttachment; +import influent.internal.nio.NioEventLoop; +import influent.internal.nio.NioTcpChannel; +import influent.internal.nio.NioTcpConfig; +import org.msgpack.core.MessageBufferPacker; +import org.msgpack.core.MessagePack; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLEngineResult; +import javax.net.ssl.SSLException; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.ReadOnlyBufferException; import java.nio.channels.SocketChannel; import java.util.LinkedList; import java.util.Queue; +import java.util.function.Consumer; import java.util.function.Supplier; -import javax.net.ssl.SSLEngine; -import javax.net.ssl.SSLEngineResult; -import javax.net.ssl.SSLException; -import org.msgpack.core.MessageBufferPacker; -import org.msgpack.core.MessagePack; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import influent.exception.InfluentIOException; -import influent.internal.msgpack.MsgpackStreamUnpacker; -import influent.internal.nio.NioAttachment; -import influent.internal.nio.NioEventLoop; -import influent.internal.nio.NioTcpChannel; -import influent.internal.nio.NioTcpConfig; -import influent.internal.util.ThreadSafeQueue; /** * A connection for SSL/TLS forward protocol. */ -final class NioSslForwardConnection implements NioAttachment { +final class NioSslForwardConnection extends NioForwardConnection implements NioAttachment { private static final Logger logger = LoggerFactory.getLogger(NioSslForwardConnection.class); - private static final String ACK_KEY = "ack"; - - private final NioTcpChannel channel; - private final NioEventLoop eventLoop; - private final ForwardCallback callback; - private final SSLEngine engine; - private final MsgpackStreamUnpacker unpacker; - private final MsgpackForwardRequestDecoder decoder; - private final ThreadSafeQueue responses = new ThreadSafeQueue<>(); + private final SSLEngine engine; - // Prepare a ByteBuffer with sufficient size + // Prepare a ByteBuffer with sufficient size private ByteBuffer inboundNetworkBuffer = ByteBuffer.allocate(1024 * 1024); private final Queue outboundNetworkBuffers = new LinkedList<>(); NioSslForwardConnection(final NioTcpChannel channel, final NioEventLoop eventLoop, final ForwardCallback callback, final SSLEngine engine, final MsgpackStreamUnpacker unpacker, - final MsgpackForwardRequestDecoder decoder) { - this.channel = channel; - this.eventLoop = eventLoop; - this.callback = callback; - this.engine = engine; - this.unpacker = unpacker; - this.decoder = decoder; - inboundNetworkBuffer.position(inboundNetworkBuffer.limit()); + final MsgpackForwardRequestDecoder decoder, final ForwardSecurity security) { + super(channel, eventLoop, callback, unpacker, decoder, security); + this.engine = engine; + inboundNetworkBuffer.position(inboundNetworkBuffer.limit()); } NioSslForwardConnection(final NioTcpChannel channel, final NioEventLoop eventLoop, - final ForwardCallback callback, final SSLEngine engine, final long chunkSizeLimit) { + final ForwardCallback callback, final SSLEngine engine, final long chunkSizeLimit, + final ForwardSecurity security) { this(channel, eventLoop, callback, engine, new MsgpackStreamUnpacker(chunkSizeLimit), - new MsgpackForwardRequestDecoder()); + new MsgpackForwardRequestDecoder(), security); } /** @@ -88,10 +78,18 @@ final class NioSslForwardConnection implements NioAttachment { */ NioSslForwardConnection(final SocketChannel socketChannel, final NioEventLoop eventLoop, final ForwardCallback callback, final SSLEngine engine, final long chunkSizeLimit, - final NioTcpConfig tcpConfig) { - this(new NioTcpChannel(socketChannel, tcpConfig), eventLoop, callback, engine, chunkSizeLimit); + final NioTcpConfig tcpConfig, final ForwardSecurity security) { + this(new NioTcpChannel(socketChannel, tcpConfig), eventLoop, callback, engine, chunkSizeLimit, + security); - channel.register(eventLoop, true, false, this); + if (security.isEnabled()) { + state = ConnectionState.HELO; + channel.register(eventLoop, false, true, this); + responses.enqueue(generateHelo()); + } else { + state = ConnectionState.ESTABLISHED; + channel.register(eventLoop, true, false, this); + } } /** @@ -108,10 +106,21 @@ public void onWritable() { return; } + boolean isWrittenAll = false; while (responses.nonEmpty()) { final ByteBuffer head = responses.dequeue(); - wrapAndSend(head); + isWrittenAll &= wrapAndSend(head); } + + if (isWrittenAll) { + channel.disableOpWrite(eventLoop); + if (state == ConnectionState.HELO) { + state = ConnectionState.PINGPONG; + channel.enableOpRead(eventLoop); + // TODO disconnect after writing failed PONG + } + } + if (!channel.isOpen()) { close(); } @@ -131,12 +140,46 @@ public void onReadable() { return; } - receiveRequests(); + switch (state) { + case PINGPONG: + receivePing(result -> { + responses.enqueue(generatePong(result)); + channel.enableOpWrite(eventLoop); + state = ConnectionState.ESTABLISHED; + }); + break; + case ESTABLISHED: + receiveRequests(); + break; + } if (!channel.isOpen()) { close(); } } + private void receivePing(Consumer checkPingResultConsumer) { + // TODO: optimize + final Supplier supplier = () -> { + final ByteBuffer buffer = ByteBuffer.allocate(1024); + receiveAndUnwrap(buffer); + buffer.flip(); + if (!buffer.hasRemaining()) { + return null; + } + return buffer; + }; + unpacker.feed(supplier, channel); + while (unpacker.hasNext()) { + try { + checkPingResultConsumer.accept(pingDecoder.decode(unpacker.next())); + } catch (final IllegalArgumentException e) { + logger.error( + "Received an invalid ping message. remote address = " + channel.getRemoteAddress(), e + ); + } + } + } + private void receiveRequests() { // TODO: optimize final Supplier supplier = () -> { @@ -305,7 +348,7 @@ private static boolean isHandshaking(final SSLEngineResult.HandshakeStatus statu && status != SSLEngineResult.HandshakeStatus.FINISHED; } - @Override + @Override public void close() { // TODO: graceful stop channel.close(); diff --git a/influent-java/src/main/java/influent/forward/NioTcpForwardConnection.java b/influent-java/src/main/java/influent/forward/NioTcpForwardConnection.java new file mode 100644 index 0000000..63b991f --- /dev/null +++ b/influent-java/src/main/java/influent/forward/NioTcpForwardConnection.java @@ -0,0 +1,212 @@ +/* + * Copyright 2016 okumin + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package influent.forward; + +import influent.exception.InfluentIOException; +import influent.internal.msgpack.MsgpackStreamUnpacker; +import influent.internal.nio.NioAttachment; +import influent.internal.nio.NioEventLoop; +import influent.internal.nio.NioTcpChannel; +import influent.internal.nio.NioTcpConfig; +import org.msgpack.core.MessageBufferPacker; +import org.msgpack.core.MessagePack; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.SocketChannel; +import java.util.function.Consumer; +import java.util.function.Supplier; + +/** + * A connection for forward protocol. + */ +final class NioTcpForwardConnection extends NioForwardConnection implements NioAttachment { + private static final Logger logger = LoggerFactory.getLogger(NioTcpForwardConnection.class); + + NioTcpForwardConnection(final NioTcpChannel channel, final NioEventLoop eventLoop, + final ForwardCallback callback, final MsgpackStreamUnpacker unpacker, + final MsgpackForwardRequestDecoder decoder, final ForwardSecurity security) { + super(channel, eventLoop, callback, unpacker, decoder, security); + } + + NioTcpForwardConnection(final NioTcpChannel channel, final NioEventLoop eventLoop, + final ForwardCallback callback, final long chunkSizeLimit, final ForwardSecurity security) { + this(channel, eventLoop, callback, new MsgpackStreamUnpacker(chunkSizeLimit), + new MsgpackForwardRequestDecoder(), security); + } + + /** + * Constructs a new {@code NioTcpForwardConnection}. + * + * @param socketChannel the inbound channel + * @param eventLoop the {@code NioEventLoop} to which this {@code NioTcpForwardConnection} belongs + * @param callback the callback to handle requests + * @param chunkSizeLimit the allowable size of a chunk + * @param tcpConfig the {@code NioTcpConfig} + * @throws InfluentIOException if some IO error occurs + */ + NioTcpForwardConnection(final SocketChannel socketChannel, final NioEventLoop eventLoop, + final ForwardCallback callback, final long chunkSizeLimit, final NioTcpConfig tcpConfig, + final ForwardSecurity security) { + this(new NioTcpChannel(socketChannel, tcpConfig), eventLoop, callback, chunkSizeLimit, security); + + if (this.security.isEnabled()) { + state = ConnectionState.HELO; + channel.register(eventLoop, false, true, this); + responses.enqueue(generateHelo()); + } else { + state = ConnectionState.ESTABLISHED; + channel.register(eventLoop, true, false, this); + } + } + + /** + * Handles a write event. + * + * @throws InfluentIOException if some IO error occurs + */ + @Override + public void onWritable() { + if (sendResponses()) { + channel.disableOpWrite(eventLoop); + if (state == ConnectionState.HELO) { + state = ConnectionState.PINGPONG; + channel.enableOpRead(eventLoop); + // TODO disconnect after writing failed PONG + } + } + } + + private boolean sendResponses() { + // TODO: gathering + while (responses.nonEmpty()) { + final ByteBuffer head = responses.peek(); + channel.write(head); + if (head.hasRemaining()) { + return false; + } + responses.dequeue(); + } + return true; + } + + /** + * Handles a read event. + * + * @throws InfluentIOException if some IO error occurs + */ + @Override + public void onReadable() { + switch (state) { + case PINGPONG: + receivePing(result -> { + responses.enqueue(generatePong(result)); + channel.enableOpWrite(eventLoop); + state = ConnectionState.ESTABLISHED; + }); + break; + case ESTABLISHED: + receiveRequests(); + break; + } + if (!channel.isOpen()) { + close(); + } + } + + private void receivePing(Consumer checkPingResultConsumer) { + // TODO: optimize + final Supplier supplier = () -> { + final ByteBuffer buffer = ByteBuffer.allocate(1024); + if (!channel.read(buffer)) { + return null; + } + buffer.flip(); + return buffer; + }; + unpacker.feed(supplier, channel); + while (unpacker.hasNext()) { + try { + checkPingResultConsumer.accept(pingDecoder.decode(unpacker.next())); + } catch (final IllegalArgumentException e) { + logger.error( + "Received an invalid ping message. remote address = " + channel.getRemoteAddress(), e + ); + } + } + } + + private void receiveRequests() { + // TODO: optimize + final Supplier supplier = () -> { + final ByteBuffer buffer = ByteBuffer.allocate(1024); + if (!channel.read(buffer)) { + return null; + } + buffer.flip(); + return buffer; + }; + unpacker.feed(supplier, channel); + while (unpacker.hasNext()) { + try { + decoder.decode(unpacker.next()).ifPresent(result -> { + logger.debug( + "Received a forward request from {}. chunk_id = {}", + channel.getRemoteAddress(), result.getOption() + ); + callback.consume(result.getStream()).thenRun(() -> { + // Executes on user's callback thread since the queue never block. + result.getOption().getChunk().ifPresent(chunk -> completeTask(chunk)); + logger.debug("Completed the task. chunk_id = {}.", result.getOption()); + }); + }); + } catch (final IllegalArgumentException e) { + logger.error( + "Received an invalid message. remote address = " + channel.getRemoteAddress(), e + ); + } + } + } + + // This method is thread-safe. + private void completeTask(final String chunk) { + try { + final MessageBufferPacker packer = MessagePack.newDefaultBufferPacker(); + packer.packMapHeader(1); + packer.packString(ACK_KEY); + packer.packString(chunk); + final ByteBuffer buffer = packer.toMessageBuffer().sliceAsByteBuffer(); + responses.enqueue(buffer); + channel.enableOpWrite(eventLoop); + } catch (final IOException e) { + logger.error("Failed packing. chunk = " + chunk, e); + } + } + + @Override + public void close() { + channel.close(); + logger.debug("NioTcpForwardConnection bound with {} closed.", channel.getRemoteAddress()); + } + + @Override + public String toString() { + return "NioTcpForwardConnection(" + channel.getRemoteAddress() + ")"; + } +} diff --git a/influent-java/src/test/scala/influent/forward/NioForwardConnectionSpec.scala b/influent-java/src/test/scala/influent/forward/NioTcpForwardConnectionSpec.scala similarity index 92% rename from influent-java/src/test/scala/influent/forward/NioForwardConnectionSpec.scala rename to influent-java/src/test/scala/influent/forward/NioTcpForwardConnectionSpec.scala index 00f4b75..74bd661 100644 --- a/influent-java/src/test/scala/influent/forward/NioForwardConnectionSpec.scala +++ b/influent-java/src/test/scala/influent/forward/NioTcpForwardConnectionSpec.scala @@ -35,7 +35,7 @@ import org.msgpack.value.impl.ImmutableStringValueImpl import org.scalatest.WordSpec import org.scalatest.mockito.MockitoSugar -class NioForwardConnectionSpec extends WordSpec with MockitoSugar { +class NioTcpForwardConnectionSpec extends WordSpec with MockitoSugar { private[this] def success: CompletableFuture[Void] = { val future = new CompletableFuture[Void]() future.complete(null) @@ -57,8 +57,8 @@ class NioForwardConnectionSpec extends WordSpec with MockitoSugar { } "onWritable" should { - def createConnection(channel: NioTcpChannel, eventLoop: NioEventLoop): NioForwardConnection = { - new NioForwardConnection(channel, eventLoop, mock[ForwardCallback], Int.MaxValue, mock[ForwardSecurity]) + def createConnection(channel: NioTcpChannel, eventLoop: NioEventLoop): NioTcpForwardConnection = { + new NioTcpForwardConnection(channel, eventLoop, mock[ForwardCallback], Int.MaxValue, mock[ForwardSecurity]) } "send responses" in { @@ -187,7 +187,7 @@ class NioForwardConnectionSpec extends WordSpec with MockitoSugar { val eventLoop = mock[NioEventLoop] val security = mock[ForwardSecurity] - val connection = new NioForwardConnection(channel, eventLoop, callback, unpacker, decoder, security) + val connection = new NioTcpForwardConnection(channel, eventLoop, callback, unpacker, decoder, security) assert(connection.onReadable() === ()) requests.filter(_.isPresent).foreach { request => @@ -229,7 +229,7 @@ class NioForwardConnectionSpec extends WordSpec with MockitoSugar { val eventLoop = mock[NioEventLoop] val security = mock[ForwardSecurity] - val connection = new NioForwardConnection(channel, eventLoop, callback, unpacker, decoder, security) + val connection = new NioTcpForwardConnection(channel, eventLoop, callback, unpacker, decoder, security) assert(connection.onReadable() === ()) verify(callback).consume(request.getStream) assert(!connection.responses.nonEmpty()) @@ -260,7 +260,7 @@ class NioForwardConnectionSpec extends WordSpec with MockitoSugar { val eventLoop = mock[NioEventLoop] val security = mock[ForwardSecurity] - val connection = new NioForwardConnection(channel, eventLoop, callback, unpacker, decoder, security) + val connection = new NioTcpForwardConnection(channel, eventLoop, callback, unpacker, decoder, security) assert(connection.onReadable() === ()) verify(callback).consume(request.getStream) verifyZeroInteractions(eventLoop) @@ -278,7 +278,7 @@ class NioForwardConnectionSpec extends WordSpec with MockitoSugar { val callback = mock[ForwardCallback] val decoder = mock[MsgpackForwardRequestDecoder] val security = mock[ForwardSecurity] - val connection = new NioForwardConnection(channel, eventLoop, callback, unpacker, decoder, security) + val connection = new NioTcpForwardConnection(channel, eventLoop, callback, unpacker, decoder, security) assert(connection.onReadable() === ()) verify(channel).close() @@ -306,7 +306,7 @@ class NioForwardConnectionSpec extends WordSpec with MockitoSugar { val eventLoop = mock[NioEventLoop] val security = mock[ForwardSecurity] - val connection = new NioForwardConnection(channel, eventLoop, callback, unpacker, decoder, security) + val connection = new NioTcpForwardConnection(channel, eventLoop, callback, unpacker, decoder, security) assert(connection.onReadable() === ()) verify(callback).consume(request.getStream) @@ -337,7 +337,7 @@ class NioForwardConnectionSpec extends WordSpec with MockitoSugar { val eventLoop = mock[NioEventLoop] val security = mock[ForwardSecurity] - val connection = new NioForwardConnection(channel, eventLoop, callback, unpacker, decoder, security) + val connection = new NioTcpForwardConnection(channel, eventLoop, callback, unpacker, decoder, security) assert(connection.onReadable() === ()) verify(callback).consume(request.getStream) @@ -354,7 +354,7 @@ class NioForwardConnectionSpec extends WordSpec with MockitoSugar { val security = mock[ForwardSecurity] when(unpacker.feed(any[Supplier[ByteBuffer]], ArgumentMatchers.eq[NioTcpChannel](channel))) .thenThrow(new InfluentIOException()) - val connection = new NioForwardConnection( + val connection = new NioTcpForwardConnection( channel, mock[NioEventLoop], callback, unpacker, mock[MsgpackForwardRequestDecoder], security ) @@ -368,7 +368,7 @@ class NioForwardConnectionSpec extends WordSpec with MockitoSugar { "closes the channel" in { val channel = mock[NioTcpChannel] val security = mock[ForwardSecurity] - val connection = new NioForwardConnection( + val connection = new NioTcpForwardConnection( channel, mock[NioEventLoop], mock[ForwardCallback], Int.MaxValue, security ) assert(connection.close() === ())