Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bug #303] Reduce JWT error logging in WebSockets #319

Merged
merged 3 commits into from
Dec 8, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.web.socket.messaging.StompSubProtocolErrorHandler;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;

/**
* calls {@link WebSocketExceptionHandler} when possible, otherwise logs exception as error
*/
Expand All @@ -29,51 +26,12 @@ public StompExceptionHandler(WebSocketExceptionHandler webSocketExceptionHandler
@Nonnull byte[] errorPayload, Throwable cause,
StompHeaderAccessor clientHeaderAccessor) {
final Message<?> message = MessageBuilder.withPayload(errorPayload).setHeaders(errorHeaderAccessor).build();
boolean handled = false;
try {
handled = delegate(message, cause);
} catch (InvocationTargetException e) {
LOG.error("Exception thrown during exception handler invocation", e);
} catch (IllegalAccessException unexpected) {
// is checked by delegate
}
final boolean handled = webSocketExceptionHandler.delegate(message, cause);

if (!handled) {
LOG.error("STOMP sub-protocol exception", cause);
}

return super.handleInternal(errorHeaderAccessor, errorPayload, cause, clientHeaderAccessor);
}

/**
* Tries to match method on {@link #webSocketExceptionHandler}
*
* @return true when a method was found and called, false otherwise
* @throws IllegalArgumentException never
*/
private boolean delegate(Message<?> message, Throwable throwable)
throws InvocationTargetException, IllegalAccessException {
if (throwable instanceof Exception exception) {
Method[] methods = webSocketExceptionHandler.getClass().getMethods();
for (final Method method : methods) {
if (!method.canAccess(webSocketExceptionHandler)) {
continue;
}
Class<?>[] params = method.getParameterTypes();
if (params.length != 2) {
continue;
}
if (params[0].isAssignableFrom(message.getClass()) && params[1].isAssignableFrom(exception.getClass())) {
// message, exception
method.invoke(webSocketExceptionHandler, message, exception);
return true;
} else if (params[0].isAssignableFrom(exception.getClass()) && params[1].isAssignableFrom(message.getClass())) {
// exception, message
method.invoke(webSocketExceptionHandler, exception, message);
return true;
}
}
}
return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import cz.cvut.kbss.termit.exception.InvalidParameterException;
import cz.cvut.kbss.termit.exception.InvalidPasswordChangeRequestException;
import cz.cvut.kbss.termit.exception.InvalidTermStateException;
import cz.cvut.kbss.termit.exception.JwtException;
import cz.cvut.kbss.termit.exception.NotFoundException;
import cz.cvut.kbss.termit.exception.PersistenceException;
import cz.cvut.kbss.termit.exception.ResourceExistsException;
Expand All @@ -25,6 +26,7 @@
import cz.cvut.kbss.termit.exception.importing.UnsupportedImportMediaTypeException;
import cz.cvut.kbss.termit.exception.importing.VocabularyImportException;
import cz.cvut.kbss.termit.rest.handler.ErrorInfo;
import cz.cvut.kbss.termit.util.ExceptionUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.messaging.Message;
Expand All @@ -33,12 +35,15 @@
import org.springframework.messaging.simp.annotation.SendToUser;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.security.access.AccessDeniedException;
import org.springframework.security.authentication.AuthenticationServiceException;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.userdetails.UsernameNotFoundException;
import org.springframework.web.bind.annotation.ControllerAdvice;
import org.springframework.web.context.request.async.AsyncRequestNotUsableException;
import org.springframework.web.multipart.MaxUploadSizeExceededException;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.URISyntaxException;

import static cz.cvut.kbss.termit.util.ExceptionUtils.findCause;
Expand Down Expand Up @@ -92,8 +97,61 @@ private static ErrorInfo errorInfo(Message<?> message, TermItException e) {
e.getParameters());
}

/**
* Tries to match method on this object that matches signature with params
*
* @return true when a method was found and called, false otherwise
* @throws IllegalArgumentException never
*/
public boolean delegate(Message<?> message, Throwable throwable) {
try {
return delegateInternal(message, throwable.getCause());
} catch (InvocationTargetException invEx) {
LOG.error("Exception thrown during exception handler invocation", invEx);
} catch (IllegalAccessException unexpected) {
// is checked by delegate
}
return false;
}

/**
* Tries to match method on this object that matches signature with params
*
* @return true when a method was found and called, false otherwise
* @throws IllegalArgumentException never
*/
private boolean delegateInternal(Message<?> message, Throwable throwable)
lukaskabc marked this conversation as resolved.
Show resolved Hide resolved
throws InvocationTargetException, IllegalAccessException {
if (throwable instanceof Exception exception) {
Method[] methods = this.getClass().getMethods();
for (final Method method : methods) {
if (!method.canAccess(this) || method.getName().equals("delegate") || method.getName().equals("delegateInternal")) {
continue;
}
Class<?>[] params = method.getParameterTypes();
if (params.length != 2) {
continue;
}
if (params[0].isAssignableFrom(message.getClass()) && params[1].isAssignableFrom(exception.getClass())) {
// message, exception
method.invoke(this, message, exception);
return true;
} else if (params[0].isAssignableFrom(exception.getClass()) && params[1].isAssignableFrom(message.getClass())) {
// exception, message
method.invoke(this, exception, message);
return true;
}
lukaskabc marked this conversation as resolved.
Show resolved Hide resolved
}
}
return false;
}

@MessageExceptionHandler
public void messageDeliveryException(Message<?> message, MessageDeliveryException e) {
if (!(e.getCause() instanceof MessageDeliveryException) && delegate(message, e.getCause())) {
return;
}

// messages without destination will be logged only on trace
(hasDestination(message) ? LOG.atError() : LOG.atTrace())
.setMessage("Failed to send message with destination {}: {}")
Expand Down Expand Up @@ -144,11 +202,16 @@ public ErrorInfo authorizationException(Message<?> message, AuthorizationExcepti
return errorInfo(message, e);
}

@MessageExceptionHandler(AuthenticationException.class)
@MessageExceptionHandler({AuthenticationException.class, AuthenticationServiceException.class})
public ErrorInfo authenticationException(Message<?> message, AuthenticationException e) {
LOG.atDebug().setCause(e).log(e.getMessage());
LOG.atError().setMessage("Authentication failure during message processing: {}\nMessage: {}")
LOG.atWarn().setMessage("Authentication failure during message processing: {}\nMessage: {}")
.addArgument(e.getMessage()).addArgument(message::toString).log();

if (ExceptionUtils.findCause(e, JwtException.class).isPresent()) {
return errorInfo(message, e);
}

LOG.atDebug().setCause(e).log(e.getMessage());
return errorInfo(message, e);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import static org.awaitility.Awaitility.await;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.isNull;
import static org.mockito.ArgumentMatchers.notNull;
import static org.mockito.Mockito.verify;

Expand Down Expand Up @@ -86,7 +87,8 @@ void connectionIsClosedOnAnyMessageBeforeConnect(String stompCommand, Boolean wi
assertTrue(receivedError.get());
assertFalse(session.isOpen());
assertFalse(receivedReply.get());
verify(webSocketExceptionHandler).messageDeliveryException(notNull(), notNull());
verify(webSocketExceptionHandler).delegate(notNull(), notNull());
verify(webSocketExceptionHandler).accessDeniedException(notNull(), notNull());
}

WebSocketHandler makeWebSocketHandler(AtomicBoolean receivedReply, AtomicBoolean receivedError) {
Expand Down Expand Up @@ -131,7 +133,8 @@ void connectWithInvalidAuthorizationIsRejected() throws Throwable {
assertTrue(receivedError.get());
assertFalse(session.isOpen());
assertFalse(receivedReply.get());
verify(webSocketExceptionHandler).messageDeliveryException(notNull(), notNull());
verify(webSocketExceptionHandler).delegate(notNull(), notNull());
verify(webSocketExceptionHandler).authenticationException(notNull(), notNull());
}

/**
Expand Down Expand Up @@ -167,7 +170,8 @@ void connectWithInvalidJwtAuthorizationIsRejected() throws Throwable {
assertFalse(session.isOpen());
assertFalse(receivedReply.get());

verify(webSocketExceptionHandler).messageDeliveryException(notNull(), notNull());
verify(webSocketExceptionHandler).delegate(notNull(), notNull());
verify(webSocketExceptionHandler).authenticationException(notNull(), notNull());
}

/**
Expand All @@ -186,5 +190,6 @@ void connectionIsNotClosedWhenConnectMessageIsSent() throws Throwable {
assertTrue(session.isConnected());
session.disconnect();
await().atMost(OPERATION_TIMEOUT, TimeUnit.SECONDS).until(() -> !session.isConnected());
verify(webSocketExceptionHandler).delegate(notNull(), isNull());
}
}