Skip to content

Commit

Permalink
Adds support for users to control the outcome of a WebSocket upgrade …
Browse files Browse the repository at this point in the history
…request (#8594)

* Adds support for users to control the outcome of a WebSocket upgrade request. If the user handler returns a non-101 code, the protocol upgrade fails and a response is written back based on the data returned by the handler, including the error code, headers and the reason for the failure. See issue 7953. Some new tests.

* Removes all possible upgrade headers that don't belong in a failed response. Some other minor cleanup.

Signed-off-by: Santiago Pericas-Geertsen <[email protected]>

---------

Signed-off-by: Santiago Pericas-Geertsen <[email protected]>
  • Loading branch information
spericas authored Apr 3, 2024
1 parent 17938b1 commit 47f273e
Show file tree
Hide file tree
Showing 5 changed files with 421 additions and 11 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, 2023 Oracle and/or its affiliates.
* Copyright (c) 2022, 2024 Oracle and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -70,6 +70,7 @@ class WebSocketHandler extends SimpleChannelInboundHandler<Object> {
private volatile Connection connection;
private final WebSocketEngine.UpgradeInfo upgradeInfo;
private final BufferedEmittingPublisher<ByteBuf> emitter;
private final TyrusUpgradeResponse upgradeResponse = new TyrusUpgradeResponse();

WebSocketHandler(ChannelHandlerContext ctx, String path,
FullHttpRequest upgradeRequest,
Expand Down Expand Up @@ -140,6 +141,10 @@ public WebSocketEngine getWebSocketEngine() {
this.upgradeInfo = upgrade(ctx);
}

TyrusUpgradeResponse upgradeResponse() {
return upgradeResponse;
}

@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
LOGGER.log(Level.SEVERE, "WS handler ERROR ", cause);
Expand Down Expand Up @@ -195,9 +200,7 @@ WebSocketEngine.UpgradeInfo upgrade(ChannelHandlerContext ctx) {
upgradeRequest.headers().forEach(e -> requestContext.getHeaders().put(e.getKey(), List.of(e.getValue())));

// Use Tyrus to process a WebSocket upgrade request
final TyrusUpgradeResponse upgradeResponse = new TyrusUpgradeResponse();
final WebSocketEngine.UpgradeInfo upgradeInfo = engine.upgrade(requestContext, upgradeResponse);

WebSocketEngine.UpgradeInfo upgradeInfo = engine.upgrade(requestContext, upgradeResponse);
upgradeResponse.getHeaders().forEach(this.upgradeResponseHeaders::add);
return upgradeInfo;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022 Oracle and/or its affiliates.
* Copyright (c) 2022, 2024 Oracle and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -15,23 +15,38 @@
*/
package io.helidon.webserver.websocket;

import java.nio.charset.Charset;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;

import io.helidon.common.http.Http;
import io.helidon.webserver.ForwardingHandler;

import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.EmptyHttpHeaders;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpServerUpgradeHandler;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import org.glassfish.tyrus.core.TyrusUpgradeResponse;

class WebSocketUpgradeCodec implements HttpServerUpgradeHandler.UpgradeCodec {

private static final Logger LOGGER = Logger.getLogger(WebSocketUpgradeCodec.class.getName());

private static final String SEC_WEBSOCKET_ACCEPT = "Sec-WebSocket-Accept";
private static final String SEC_WEBSOCKET_PROTOCOL = "Sec-WebSocket-Protocol";

private final WebSocketRouting webSocketRouting;
private String path;
private WebSocketHandler wsHandler;
Expand All @@ -52,14 +67,43 @@ public boolean prepareUpgradeResponse(ChannelHandlerContext ctx,
HttpHeaders upgradeResponseHeaders) {
try {
path = upgradeRequest.uri();
upgradeResponseHeaders.remove("upgrade");
upgradeResponseHeaders.remove("connection");
this.wsHandler = new WebSocketHandler(ctx, path, upgradeRequest, upgradeResponseHeaders, webSocketRouting);
return true;
upgradeResponseHeaders.remove(Http.Header.UPGRADE);
upgradeResponseHeaders.remove(Http.Header.CONNECTION);
wsHandler = new WebSocketHandler(ctx, path, upgradeRequest, upgradeResponseHeaders, webSocketRouting);

// if not 101 code, create and write to channel a custom user response of
// type text/plain using reason as payload and return false back to Netty
TyrusUpgradeResponse upgradeResponse = wsHandler.upgradeResponse();
if (upgradeResponse.getStatus() != Http.Status.SWITCHING_PROTOCOLS_101.code()) {
// prepare headers for failed response
Map<String, List<String>> upgradeHeaders = upgradeResponse.getHeaders();
upgradeHeaders.remove(Http.Header.UPGRADE);
upgradeHeaders.remove(Http.Header.CONNECTION);
upgradeHeaders.remove(SEC_WEBSOCKET_ACCEPT);
upgradeHeaders.remove(SEC_WEBSOCKET_PROTOCOL);
HttpHeaders headers = new DefaultHttpHeaders();
upgradeHeaders.forEach(headers::add);

// set payload as text/plain with reason phrase
headers.add(Http.Header.CONTENT_TYPE, "text/plain");
String reasonPhrase = upgradeResponse.getReasonPhrase() == null ? ""
: upgradeResponse.getReasonPhrase();
HttpResponse httpResponse = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1,
HttpResponseStatus.valueOf(upgradeResponse.getStatus()),
Unpooled.wrappedBuffer(reasonPhrase.getBytes(Charset.defaultCharset())),
headers,
EmptyHttpHeaders.INSTANCE); // trailing headers

// write, flush and later close connection
ChannelFuture writeComplete = ctx.writeAndFlush(httpResponse);
writeComplete.addListener(ChannelFutureListener.CLOSE);
return false;
}
} catch (Throwable cause) {
LOGGER.log(Level.SEVERE, "Error during upgrade to WebSocket", cause);
return false;
}
return true;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022 Oracle and/or its affiliates.
* Copyright (c) 2022, 2024 Oracle and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,6 +16,7 @@
package io.helidon.webserver.websocket.test;

import java.io.IOException;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.logging.Logger;

Expand All @@ -28,6 +29,7 @@
import jakarta.websocket.server.HandshakeRequest;
import jakarta.websocket.server.ServerEndpoint;
import jakarta.websocket.server.ServerEndpointConfig;
import org.glassfish.tyrus.core.TyrusUpgradeResponse;

import static io.helidon.webserver.websocket.test.UppercaseCodec.isDecoded;

Expand Down Expand Up @@ -86,6 +88,19 @@ public void modifyHandshake(ServerEndpointConfig sec, HandshakeRequest request,
LOGGER.info("ServerConfigurator called during handshake");
super.modifyHandshake(sec, request, response);
EchoEndpoint.modifyHandshakeCalled.set(true);

// if not user Helidon, fail to authenticate, return reason and user header
String user = getUserFromParams(request);
if (!user.equals("Helidon") && response instanceof TyrusUpgradeResponse tyrusResponse) {
tyrusResponse.setStatus(401);
tyrusResponse.setReasonPhrase("Failed to authenticate");
tyrusResponse.getHeaders().put("Endpoint", List.of("EchoEndpoint"));
}
}

private String getUserFromParams(HandshakeRequest request) {
List<String> values = request.getParameterMap().get("user");
return values != null && !values.isEmpty() ? values.get(0) : "";
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/*
* Copyright (c) 2022, 2024 Oracle and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.helidon.webserver.websocket.test;

import java.net.URI;
import java.util.List;
import java.util.Map;

import io.helidon.common.http.Http;

import jakarta.websocket.DeploymentException;
import jakarta.websocket.HandshakeResponse;
import jakarta.websocket.server.HandshakeRequest;
import jakarta.websocket.server.ServerEndpointConfig;
import org.glassfish.tyrus.client.auth.AuthenticationException;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.fail;

class HandshakeFailureTest extends TyrusSupportBaseTest {

@BeforeAll
static void startServer() throws Exception {
webServer(true, EchoEndpoint.class);
}

/**
* Should fail because user is not Helidon. See server handshake at
* {@link EchoEndpoint.ServerConfigurator#modifyHandshake(ServerEndpointConfig, HandshakeRequest, HandshakeResponse)}.
*/
@Test
void testEchoSingleUpgradeFail() {
URI uri = URI.create("ws://localhost:" + webServer().port() + "/tyrus/echo?user=Unknown");
EchoClient echoClient = new EchoClient(uri);
try {
echoClient.echo("One");
} catch (Exception e) {
assertThat(e, instanceOf(DeploymentException.class));
assertThat(e.getCause(), instanceOf(AuthenticationException.class));
AuthenticationException ae = (AuthenticationException) e.getCause();
assertThat(ae.getHttpStatusCode(), is(401));
assertThat(ae.getMessage(), is("Authentication failed."));
return;
}
fail("Exception not thrown");
}

/**
* Should fail because user is not Helidon. See server handshake at
* {@link EchoEndpoint.ServerConfigurator#modifyHandshake(ServerEndpointConfig, HandshakeRequest, HandshakeResponse)}.
*/
@Test
void testEchoSingleUpgradeFailRaw() throws Exception {
String response = SocketHttpClient.sendAndReceive("/tyrus/echo?user=Unknown",
Http.Method.GET,
List.of("Connection:Upgrade",
"Upgrade:websocket",
"Sec-WebSocket-Key:0SBbaRkS/idPrmvImDNHBA==",
"Sec-WebSocket-Version:13"),
webServer());

assertThat(SocketHttpClient.statusFromResponse(response),
is(Http.Status.UNAUTHORIZED_401));
assertThat(SocketHttpClient.entityFromResponse(response, false),
is("Failed to authenticate\n"));
Map<String, String> headers = SocketHttpClient.headersFromResponse(response);
assertThat(headers.get("Endpoint"), is("EchoEndpoint"));
assertFalse(headers.containsKey("Connection") || headers.containsKey("connection"));
assertFalse(headers.containsKey("Upgrade") || headers.containsKey("upgrade"));
}
}
Loading

0 comments on commit 47f273e

Please sign in to comment.