Skip to content

Commit

Permalink
[Bug #303] Silence JWT exception from websocket connection and match …
Browse files Browse the repository at this point in the history
…log level with HTTP authentication exception logging
  • Loading branch information
lukaskabc authored and ledsoft committed Dec 8, 2024
1 parent c1efa5e commit 3ef98de
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.web.socket.messaging.StompSubProtocolErrorHandler;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;

/**
* calls {@link WebSocketExceptionHandler} when possible, otherwise logs exception as error
*/
Expand All @@ -29,51 +26,12 @@ public StompExceptionHandler(WebSocketExceptionHandler webSocketExceptionHandler
@Nonnull byte[] errorPayload, Throwable cause,
StompHeaderAccessor clientHeaderAccessor) {
final Message<?> message = MessageBuilder.withPayload(errorPayload).setHeaders(errorHeaderAccessor).build();
boolean handled = false;
try {
handled = delegate(message, cause);
} catch (InvocationTargetException e) {
LOG.error("Exception thrown during exception handler invocation", e);
} catch (IllegalAccessException unexpected) {
// is checked by delegate
}
final boolean handled = webSocketExceptionHandler.delegate(message, cause);

if (!handled) {
LOG.error("STOMP sub-protocol exception", cause);
}

return super.handleInternal(errorHeaderAccessor, errorPayload, cause, clientHeaderAccessor);
}

/**
* Tries to match method on {@link #webSocketExceptionHandler}
*
* @return true when a method was found and called, false otherwise
* @throws IllegalArgumentException never
*/
private boolean delegate(Message<?> message, Throwable throwable)
throws InvocationTargetException, IllegalAccessException {
if (throwable instanceof Exception exception) {
Method[] methods = webSocketExceptionHandler.getClass().getMethods();
for (final Method method : methods) {
if (!method.canAccess(webSocketExceptionHandler)) {
continue;
}
Class<?>[] params = method.getParameterTypes();
if (params.length != 2) {
continue;
}
if (params[0].isAssignableFrom(message.getClass()) && params[1].isAssignableFrom(exception.getClass())) {
// message, exception
method.invoke(webSocketExceptionHandler, message, exception);
return true;
} else if (params[0].isAssignableFrom(exception.getClass()) && params[1].isAssignableFrom(message.getClass())) {
// exception, message
method.invoke(webSocketExceptionHandler, exception, message);
return true;
}
}
}
return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import cz.cvut.kbss.termit.exception.InvalidParameterException;
import cz.cvut.kbss.termit.exception.InvalidPasswordChangeRequestException;
import cz.cvut.kbss.termit.exception.InvalidTermStateException;
import cz.cvut.kbss.termit.exception.JwtException;
import cz.cvut.kbss.termit.exception.NotFoundException;
import cz.cvut.kbss.termit.exception.PersistenceException;
import cz.cvut.kbss.termit.exception.ResourceExistsException;
Expand All @@ -25,6 +26,7 @@
import cz.cvut.kbss.termit.exception.importing.UnsupportedImportMediaTypeException;
import cz.cvut.kbss.termit.exception.importing.VocabularyImportException;
import cz.cvut.kbss.termit.rest.handler.ErrorInfo;
import cz.cvut.kbss.termit.util.ExceptionUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.messaging.Message;
Expand All @@ -33,12 +35,15 @@
import org.springframework.messaging.simp.annotation.SendToUser;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.security.access.AccessDeniedException;
import org.springframework.security.authentication.AuthenticationServiceException;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.userdetails.UsernameNotFoundException;
import org.springframework.web.bind.annotation.ControllerAdvice;
import org.springframework.web.context.request.async.AsyncRequestNotUsableException;
import org.springframework.web.multipart.MaxUploadSizeExceededException;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.URISyntaxException;

import static cz.cvut.kbss.termit.util.ExceptionUtils.findCause;
Expand Down Expand Up @@ -92,8 +97,61 @@ private static ErrorInfo errorInfo(Message<?> message, TermItException e) {
e.getParameters());
}

/**
* Tries to match method on this object that matches signature with params
*
* @return true when a method was found and called, false otherwise
* @throws IllegalArgumentException never
*/
public boolean delegate(Message<?> message, Throwable throwable) {
try {
return delegateInternal(message, throwable.getCause());
} catch (InvocationTargetException invEx) {
LOG.error("Exception thrown during exception handler invocation", invEx);
} catch (IllegalAccessException unexpected) {
// is checked by delegate
}
return false;
}

/**
* Tries to match method on this object that matches signature with params
*
* @return true when a method was found and called, false otherwise
* @throws IllegalArgumentException never
*/
private boolean delegateInternal(Message<?> message, Throwable throwable)
throws InvocationTargetException, IllegalAccessException {
if (throwable instanceof Exception exception) {
Method[] methods = this.getClass().getMethods();
for (final Method method : methods) {
if (!method.canAccess(this) || method.getName().equals("delegate") || method.getName().equals("delegateInternal")) {
continue;
}
Class<?>[] params = method.getParameterTypes();
if (params.length != 2) {
continue;
}
if (params[0].isAssignableFrom(message.getClass()) && params[1].isAssignableFrom(exception.getClass())) {
// message, exception
method.invoke(this, message, exception);
return true;
} else if (params[0].isAssignableFrom(exception.getClass()) && params[1].isAssignableFrom(message.getClass())) {
// exception, message
method.invoke(this, exception, message);
return true;
}
}
}
return false;
}

@MessageExceptionHandler
public void messageDeliveryException(Message<?> message, MessageDeliveryException e) {
if (!(e.getCause() instanceof MessageDeliveryException) && delegate(message, e.getCause())) {
return;
}

// messages without destination will be logged only on trace
(hasDestination(message) ? LOG.atError() : LOG.atTrace())
.setMessage("Failed to send message with destination {}: {}")
Expand Down Expand Up @@ -144,11 +202,16 @@ public ErrorInfo authorizationException(Message<?> message, AuthorizationExcepti
return errorInfo(message, e);
}

@MessageExceptionHandler(AuthenticationException.class)
@MessageExceptionHandler({AuthenticationException.class, AuthenticationServiceException.class})
public ErrorInfo authenticationException(Message<?> message, AuthenticationException e) {
LOG.atDebug().setCause(e).log(e.getMessage());
LOG.atError().setMessage("Authentication failure during message processing: {}\nMessage: {}")
LOG.atWarn().setMessage("Authentication failure during message processing: {}\nMessage: {}")
.addArgument(e.getMessage()).addArgument(message::toString).log();

if (ExceptionUtils.findCause(e, JwtException.class).isPresent()) {
return errorInfo(message, e);
}

LOG.atDebug().setCause(e).log(e.getMessage());
return errorInfo(message, e);
}

Expand Down

0 comments on commit 3ef98de

Please sign in to comment.