Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bug #303] Reduce JWT error logging in WebSockets #319

Merged
merged 3 commits into from
Dec 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
package cz.cvut.kbss.termit.websocket.handler;

import jakarta.annotation.Nonnull;
import jakarta.annotation.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.messaging.Message;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.web.socket.messaging.StompSubProtocolErrorHandler;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;

/**
* calls {@link WebSocketExceptionHandler} when possible, otherwise logs exception as error
*/
Expand All @@ -26,54 +24,20 @@ public StompExceptionHandler(WebSocketExceptionHandler webSocketExceptionHandler

@Override
protected @Nonnull Message<byte[]> handleInternal(@Nonnull StompHeaderAccessor errorHeaderAccessor,
@Nonnull byte[] errorPayload, Throwable cause,
StompHeaderAccessor clientHeaderAccessor) {
@Nonnull byte[] errorPayload,
@Nullable Throwable cause,
@Nullable StompHeaderAccessor clientHeaderAccessor) {
final Message<?> message = MessageBuilder.withPayload(errorPayload).setHeaders(errorHeaderAccessor).build();
boolean handled = false;
try {
handled = delegate(message, cause);
} catch (InvocationTargetException e) {
LOG.error("Exception thrown during exception handler invocation", e);
} catch (IllegalAccessException unexpected) {
// is checked by delegate
Throwable causeToHandle = cause;
if (causeToHandle != null && causeToHandle.getCause() != null) {
causeToHandle = causeToHandle.getCause();
}
final boolean handled = webSocketExceptionHandler.delegate(message, causeToHandle);

if (!handled) {
LOG.error("STOMP sub-protocol exception", cause);
}

return super.handleInternal(errorHeaderAccessor, errorPayload, cause, clientHeaderAccessor);
}

/**
* Tries to match method on {@link #webSocketExceptionHandler}
*
* @return true when a method was found and called, false otherwise
* @throws IllegalArgumentException never
*/
private boolean delegate(Message<?> message, Throwable throwable)
throws InvocationTargetException, IllegalAccessException {
if (throwable instanceof Exception exception) {
Method[] methods = webSocketExceptionHandler.getClass().getMethods();
for (final Method method : methods) {
if (!method.canAccess(webSocketExceptionHandler)) {
continue;
}
Class<?>[] params = method.getParameterTypes();
if (params.length != 2) {
continue;
}
if (params[0].isAssignableFrom(message.getClass()) && params[1].isAssignableFrom(exception.getClass())) {
// message, exception
method.invoke(webSocketExceptionHandler, message, exception);
return true;
} else if (params[0].isAssignableFrom(exception.getClass()) && params[1].isAssignableFrom(message.getClass())) {
// exception, message
method.invoke(webSocketExceptionHandler, exception, message);
return true;
}
}
}
return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import cz.cvut.kbss.termit.exception.InvalidParameterException;
import cz.cvut.kbss.termit.exception.InvalidPasswordChangeRequestException;
import cz.cvut.kbss.termit.exception.InvalidTermStateException;
import cz.cvut.kbss.termit.exception.JwtException;
import cz.cvut.kbss.termit.exception.NotFoundException;
import cz.cvut.kbss.termit.exception.PersistenceException;
import cz.cvut.kbss.termit.exception.ResourceExistsException;
Expand All @@ -25,6 +26,7 @@
import cz.cvut.kbss.termit.exception.importing.UnsupportedImportMediaTypeException;
import cz.cvut.kbss.termit.exception.importing.VocabularyImportException;
import cz.cvut.kbss.termit.rest.handler.ErrorInfo;
import cz.cvut.kbss.termit.util.ExceptionUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.messaging.Message;
Expand All @@ -33,18 +35,25 @@
import org.springframework.messaging.simp.annotation.SendToUser;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.security.access.AccessDeniedException;
import org.springframework.security.authentication.AuthenticationServiceException;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.userdetails.UsernameNotFoundException;
import org.springframework.web.bind.annotation.ControllerAdvice;
import org.springframework.web.context.request.async.AsyncRequestNotUsableException;
import org.springframework.web.multipart.MaxUploadSizeExceededException;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.URISyntaxException;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;

import static cz.cvut.kbss.termit.util.ExceptionUtils.findCause;

/**
* @implSpec Should reflect {@link cz.cvut.kbss.termit.rest.handler.RestExceptionHandler}
* @implSpec Should reflect {@link cz.cvut.kbss.termit.rest.handler.RestExceptionHandler}.<br>
* In order for the delegation to work, the method signature of MessageExceptionHandler methods must be {@code (Message<?>, Exception)}
*/
@SendToUser
@ControllerAdvice
Expand Down Expand Up @@ -92,8 +101,83 @@ private static ErrorInfo errorInfo(Message<?> message, TermItException e) {
e.getParameters());
}

/**
* Searches available methods annotated with {@link MessageExceptionHandler} in this class
* when the method signature matches {@code (Message<?>, Exception)}
* and the exception parameter is assignable from the supplied throwable
* the method is called.
*
* @param message the associated message
* @param throwable the exception to handle
* @return true when a method was found and called, false otherwise
*/
public boolean delegate(Message<?> message, Throwable throwable) {
try {
return delegateInternal(message, throwable);
} catch (InvocationTargetException invEx) {
// Exception handler method threw an exception
LOG.error("Exception thrown during exception handler invocation", invEx);
} catch (IllegalAccessException unexpected) {
// is checked by delegateInternal
}
return false;
}

/**
* Searches available methods annotated with {@link MessageExceptionHandler} in this class
* when the method signature matches {@code (Message<?>, Exception)}
* and the exception parameter is assignable from the supplied throwable
* the method is called.
*
* @param message the associated message
* @param throwable the exception to handle
* @return true when a method was found and called, false otherwise
* @throws IllegalArgumentException never
* @throws IllegalAccessException never
* @throws InvocationTargetException when the exception handler method throws an exception
*/
private boolean delegateInternal(Message<?> message, Throwable throwable)
lukaskabc marked this conversation as resolved.
Show resolved Hide resolved
throws InvocationTargetException, IllegalAccessException {
// handle only exceptions
if (throwable instanceof Exception exception) {
// find all methods annotated with MessageExceptionHandler
List<Method> methods = Arrays.stream(this.getClass().getMethods())
.filter(m -> m.isAnnotationPresent(MessageExceptionHandler.class)).toList();
for (final Method method : methods) {
// check for reflection access to prevent IllegalAccessException
if (!method.canAccess(this)) {
continue;
}
// we are interested only in methods with exactly two parameters (message, exception)
Class<?>[] params = method.getParameterTypes();
if (params.length != 2) {
continue;
}
// check if the MessageExceptionHandler annotation has value with allowed exceptions
Class<? extends Throwable>[] allowedExceptions = Optional.ofNullable(method.getAnnotation(MessageExceptionHandler.class))
.map(MessageExceptionHandler::value).orElseGet(() -> new Class[0]);
// if the exception is not allowed by the annotation, skip the method
if (allowedExceptions.length > 0 && Arrays.stream(allowedExceptions).noneMatch(e -> e.isAssignableFrom(exception.getClass()))) {
continue;
}
// validate the method signature
if (params[0].isAssignableFrom(message.getClass()) && params[1].isAssignableFrom(exception.getClass())) {
// call the method with message, exception parameters
method.invoke(this, message, exception);
return true; // exception was handled
}
}
}
// throwable is not an exception or no suitable method was found
return false;
}

@MessageExceptionHandler
public void messageDeliveryException(Message<?> message, MessageDeliveryException e) {
if (!(e.getCause() instanceof MessageDeliveryException) && delegate(message, e.getCause())) {
return;
}

// messages without destination will be logged only on trace
(hasDestination(message) ? LOG.atError() : LOG.atTrace())
.setMessage("Failed to send message with destination {}: {}")
Expand Down Expand Up @@ -144,11 +228,16 @@ public ErrorInfo authorizationException(Message<?> message, AuthorizationExcepti
return errorInfo(message, e);
}

@MessageExceptionHandler(AuthenticationException.class)
@MessageExceptionHandler({AuthenticationException.class, AuthenticationServiceException.class})
public ErrorInfo authenticationException(Message<?> message, AuthenticationException e) {
LOG.atDebug().setCause(e).log(e.getMessage());
LOG.atError().setMessage("Authentication failure during message processing: {}\nMessage: {}")
LOG.atWarn().setMessage("Authentication failure during message processing: {}\nMessage: {}")
.addArgument(e.getMessage()).addArgument(message::toString).log();

if (ExceptionUtils.findCause(e, JwtException.class).isPresent()) {
return errorInfo(message, e);
}

LOG.atDebug().setCause(e).log(e.getMessage());
return errorInfo(message, e);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import static org.awaitility.Awaitility.await;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.isNull;
import static org.mockito.ArgumentMatchers.notNull;
import static org.mockito.Mockito.verify;

Expand Down Expand Up @@ -86,7 +87,8 @@ void connectionIsClosedOnAnyMessageBeforeConnect(String stompCommand, Boolean wi
assertTrue(receivedError.get());
assertFalse(session.isOpen());
assertFalse(receivedReply.get());
verify(webSocketExceptionHandler).messageDeliveryException(notNull(), notNull());
verify(webSocketExceptionHandler).delegate(notNull(), notNull());
verify(webSocketExceptionHandler).accessDeniedException(notNull(), notNull());
}

WebSocketHandler makeWebSocketHandler(AtomicBoolean receivedReply, AtomicBoolean receivedError) {
Expand Down Expand Up @@ -131,7 +133,8 @@ void connectWithInvalidAuthorizationIsRejected() throws Throwable {
assertTrue(receivedError.get());
assertFalse(session.isOpen());
assertFalse(receivedReply.get());
verify(webSocketExceptionHandler).messageDeliveryException(notNull(), notNull());
verify(webSocketExceptionHandler).delegate(notNull(), notNull());
verify(webSocketExceptionHandler).authenticationException(notNull(), notNull());
}

/**
Expand Down Expand Up @@ -167,7 +170,8 @@ void connectWithInvalidJwtAuthorizationIsRejected() throws Throwable {
assertFalse(session.isOpen());
assertFalse(receivedReply.get());

verify(webSocketExceptionHandler).messageDeliveryException(notNull(), notNull());
verify(webSocketExceptionHandler).delegate(notNull(), notNull());
verify(webSocketExceptionHandler).authenticationException(notNull(), notNull());
}

/**
Expand All @@ -186,5 +190,6 @@ void connectionIsNotClosedWhenConnectMessageIsSent() throws Throwable {
assertTrue(session.isConnected());
session.disconnect();
await().atMost(OPERATION_TIMEOUT, TimeUnit.SECONDS).until(() -> !session.isConnected());
verify(webSocketExceptionHandler).delegate(notNull(), isNull());
}
}
Loading