From 752f7344515f54722bef12f7e42521b810a9efaf Mon Sep 17 00:00:00 2001 From: Alessandro Autiero Date: Tue, 16 Jan 2024 23:39:34 +0100 Subject: [PATCH] Fixed websocket message reading --- .../auties/whatsapp/socket/SocketSession.java | 85 ++++++++----------- 1 file changed, 37 insertions(+), 48 deletions(-) diff --git a/src/main/java/it/auties/whatsapp/socket/SocketSession.java b/src/main/java/it/auties/whatsapp/socket/SocketSession.java index 49bca82a..033abe6d 100644 --- a/src/main/java/it/auties/whatsapp/socket/SocketSession.java +++ b/src/main/java/it/auties/whatsapp/socket/SocketSession.java @@ -15,8 +15,6 @@ import java.nio.channels.AsynchronousChannelGroup; import java.nio.channels.AsynchronousSocketChannel; import java.nio.channels.CompletionHandler; -import java.util.ArrayList; -import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; import java.util.concurrent.ExecutorService; @@ -55,11 +53,11 @@ static SocketSession of(URI proxy, ExecutorService executor, boolean webSocket) public static final class WebSocketSession extends SocketSession implements WebSocket.Listener { private WebSocket session; - private final List inputParts; + private byte[] message; + private int messageOffset; WebSocketSession(URI proxy, ExecutorService executor) { super(proxy, executor); - this.inputParts = new ArrayList<>(5); } @SuppressWarnings("resource") // Not needed @@ -101,7 +99,9 @@ public CompletableFuture sendBinary(byte[] bytes) { } private boolean isOpen() { - return session != null && !session.isInputClosed() && !session.isOutputClosed(); + return session != null + && !session.isInputClosed() + && !session.isOutputClosed(); } @Override @@ -112,7 +112,7 @@ public void onOpen(WebSocket webSocket) { @Override public CompletionStage onClose(WebSocket webSocket, int statusCode, String reason) { - inputParts.clear(); + message = null; listener.onClose(); return WebSocket.Listener.super.onClose(webSocket, statusCode, reason); } @@ -122,55 +122,44 @@ public void onError(WebSocket webSocket, Throwable error) { listener.onError(error); } - // Ugly but necessary to keep byte[] allocations to a minimum @Override public CompletionStage onBinary(WebSocket webSocket, ByteBuffer data, boolean last) { - inputParts.add(data); - if (!last) { - return WebSocket.Listener.super.onBinary(webSocket, data, false); - } - - var inputPartsCounter = 0; - var length = 0; - var written = 0; - byte[] result = null; - while (inputPartsCounter < inputParts.size()) { - var inputPart = inputParts.get(inputPartsCounter); - if(length <= 0) { - if(inputPart.remaining() >= MESSAGE_LENGTH) { - length = (inputPart.get() << 16) | Short.toUnsignedInt(inputPart.getShort()); - } - - if (length <= 0) { - break; - } - - result = new byte[length]; + if (message == null) { + var length = (data.get() << 16) | Short.toUnsignedInt(data.getShort()); + if(length < 0) { + return WebSocket.Listener.super.onBinary(webSocket, data, last); } - var inputPartSize = inputPart.remaining(); - var readLength = Math.min(inputPartSize, length); - inputPart.get(result, written, readLength); - if(inputPart.remaining() < MESSAGE_LENGTH) { - inputPartsCounter++; - } + this.message = new byte[length]; + this.messageOffset = 0; + } - written += readLength; - length -= readLength; - if(length <= 0) { - try { - listener.onMessage(result); - } catch (Throwable throwable) { - listener.onError(throwable); - } - - written = 0; - result = null; - } + var currentDataLength = data.remaining(); + var remainingDataLength = message.length - messageOffset; + var actualDataLength = Math.min(currentDataLength, remainingDataLength); + data.get(message, messageOffset, actualDataLength); + messageOffset += actualDataLength; + if (messageOffset != message.length) { + return WebSocket.Listener.super.onBinary(webSocket, data, last); + } + + notifyMessage(); + if(remainingDataLength - currentDataLength != 0) { + return onBinary(webSocket, data, true); } - inputParts.clear(); - return WebSocket.Listener.super.onBinary(webSocket, data, true); + + return WebSocket.Listener.super.onBinary(webSocket, data, last); + } + + private void notifyMessage() { + try { + listener.onMessage(message); + } catch (Throwable throwable) { + listener.onError(throwable); + }finally { + this.message = null; + } } }