diff --git a/src/main/java/cz/cvut/kbss/termit/websocket/handler/StompExceptionHandler.java b/src/main/java/cz/cvut/kbss/termit/websocket/handler/StompExceptionHandler.java index 4aeeb1883..d471f242f 100644 --- a/src/main/java/cz/cvut/kbss/termit/websocket/handler/StompExceptionHandler.java +++ b/src/main/java/cz/cvut/kbss/termit/websocket/handler/StompExceptionHandler.java @@ -1,6 +1,7 @@ package cz.cvut.kbss.termit.websocket.handler; import jakarta.annotation.Nonnull; +import jakarta.annotation.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.messaging.Message; @@ -8,9 +9,6 @@ import org.springframework.messaging.support.MessageBuilder; import org.springframework.web.socket.messaging.StompSubProtocolErrorHandler; -import java.lang.reflect.InvocationTargetException; -import java.lang.reflect.Method; - /** * calls {@link WebSocketExceptionHandler} when possible, otherwise logs exception as error */ @@ -26,17 +24,15 @@ public StompExceptionHandler(WebSocketExceptionHandler webSocketExceptionHandler @Override protected @Nonnull Message handleInternal(@Nonnull StompHeaderAccessor errorHeaderAccessor, - @Nonnull byte[] errorPayload, Throwable cause, - StompHeaderAccessor clientHeaderAccessor) { + @Nonnull byte[] errorPayload, + @Nullable Throwable cause, + @Nullable StompHeaderAccessor clientHeaderAccessor) { final Message message = MessageBuilder.withPayload(errorPayload).setHeaders(errorHeaderAccessor).build(); - boolean handled = false; - try { - handled = delegate(message, cause); - } catch (InvocationTargetException e) { - LOG.error("Exception thrown during exception handler invocation", e); - } catch (IllegalAccessException unexpected) { - // is checked by delegate + Throwable causeToHandle = cause; + if (causeToHandle != null && causeToHandle.getCause() != null) { + causeToHandle = causeToHandle.getCause(); } + final boolean handled = webSocketExceptionHandler.delegate(message, causeToHandle); if (!handled) { LOG.error("STOMP sub-protocol exception", cause); @@ -44,36 +40,4 @@ public StompExceptionHandler(WebSocketExceptionHandler webSocketExceptionHandler return super.handleInternal(errorHeaderAccessor, errorPayload, cause, clientHeaderAccessor); } - - /** - * Tries to match method on {@link #webSocketExceptionHandler} - * - * @return true when a method was found and called, false otherwise - * @throws IllegalArgumentException never - */ - private boolean delegate(Message message, Throwable throwable) - throws InvocationTargetException, IllegalAccessException { - if (throwable instanceof Exception exception) { - Method[] methods = webSocketExceptionHandler.getClass().getMethods(); - for (final Method method : methods) { - if (!method.canAccess(webSocketExceptionHandler)) { - continue; - } - Class[] params = method.getParameterTypes(); - if (params.length != 2) { - continue; - } - if (params[0].isAssignableFrom(message.getClass()) && params[1].isAssignableFrom(exception.getClass())) { - // message, exception - method.invoke(webSocketExceptionHandler, message, exception); - return true; - } else if (params[0].isAssignableFrom(exception.getClass()) && params[1].isAssignableFrom(message.getClass())) { - // exception, message - method.invoke(webSocketExceptionHandler, exception, message); - return true; - } - } - } - return false; - } } diff --git a/src/main/java/cz/cvut/kbss/termit/websocket/handler/WebSocketExceptionHandler.java b/src/main/java/cz/cvut/kbss/termit/websocket/handler/WebSocketExceptionHandler.java index c6042bb9a..dd5c574ad 100644 --- a/src/main/java/cz/cvut/kbss/termit/websocket/handler/WebSocketExceptionHandler.java +++ b/src/main/java/cz/cvut/kbss/termit/websocket/handler/WebSocketExceptionHandler.java @@ -11,6 +11,7 @@ import cz.cvut.kbss.termit.exception.InvalidParameterException; import cz.cvut.kbss.termit.exception.InvalidPasswordChangeRequestException; import cz.cvut.kbss.termit.exception.InvalidTermStateException; +import cz.cvut.kbss.termit.exception.JwtException; import cz.cvut.kbss.termit.exception.NotFoundException; import cz.cvut.kbss.termit.exception.PersistenceException; import cz.cvut.kbss.termit.exception.ResourceExistsException; @@ -25,6 +26,7 @@ import cz.cvut.kbss.termit.exception.importing.UnsupportedImportMediaTypeException; import cz.cvut.kbss.termit.exception.importing.VocabularyImportException; import cz.cvut.kbss.termit.rest.handler.ErrorInfo; +import cz.cvut.kbss.termit.util.ExceptionUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.messaging.Message; @@ -33,18 +35,25 @@ import org.springframework.messaging.simp.annotation.SendToUser; import org.springframework.messaging.simp.stomp.StompHeaderAccessor; import org.springframework.security.access.AccessDeniedException; +import org.springframework.security.authentication.AuthenticationServiceException; import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.userdetails.UsernameNotFoundException; import org.springframework.web.bind.annotation.ControllerAdvice; import org.springframework.web.context.request.async.AsyncRequestNotUsableException; import org.springframework.web.multipart.MaxUploadSizeExceededException; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; import java.net.URISyntaxException; +import java.util.Arrays; +import java.util.List; +import java.util.Optional; import static cz.cvut.kbss.termit.util.ExceptionUtils.findCause; /** - * @implSpec Should reflect {@link cz.cvut.kbss.termit.rest.handler.RestExceptionHandler} + * @implSpec Should reflect {@link cz.cvut.kbss.termit.rest.handler.RestExceptionHandler}.
+ * In order for the delegation to work, the method signature of MessageExceptionHandler methods must be {@code (Message, Exception)} */ @SendToUser @ControllerAdvice @@ -92,8 +101,83 @@ private static ErrorInfo errorInfo(Message message, TermItException e) { e.getParameters()); } + /** + * Searches available methods annotated with {@link MessageExceptionHandler} in this class + * when the method signature matches {@code (Message, Exception)} + * and the exception parameter is assignable from the supplied throwable + * the method is called. + * + * @param message the associated message + * @param throwable the exception to handle + * @return true when a method was found and called, false otherwise + */ + public boolean delegate(Message message, Throwable throwable) { + try { + return delegateInternal(message, throwable); + } catch (InvocationTargetException invEx) { + // Exception handler method threw an exception + LOG.error("Exception thrown during exception handler invocation", invEx); + } catch (IllegalAccessException unexpected) { + // is checked by delegateInternal + } + return false; + } + + /** + * Searches available methods annotated with {@link MessageExceptionHandler} in this class + * when the method signature matches {@code (Message, Exception)} + * and the exception parameter is assignable from the supplied throwable + * the method is called. + * + * @param message the associated message + * @param throwable the exception to handle + * @return true when a method was found and called, false otherwise + * @throws IllegalArgumentException never + * @throws IllegalAccessException never + * @throws InvocationTargetException when the exception handler method throws an exception + */ + private boolean delegateInternal(Message message, Throwable throwable) + throws InvocationTargetException, IllegalAccessException { + // handle only exceptions + if (throwable instanceof Exception exception) { + // find all methods annotated with MessageExceptionHandler + List methods = Arrays.stream(this.getClass().getMethods()) + .filter(m -> m.isAnnotationPresent(MessageExceptionHandler.class)).toList(); + for (final Method method : methods) { + // check for reflection access to prevent IllegalAccessException + if (!method.canAccess(this)) { + continue; + } + // we are interested only in methods with exactly two parameters (message, exception) + Class[] params = method.getParameterTypes(); + if (params.length != 2) { + continue; + } + // check if the MessageExceptionHandler annotation has value with allowed exceptions + Class[] allowedExceptions = Optional.ofNullable(method.getAnnotation(MessageExceptionHandler.class)) + .map(MessageExceptionHandler::value).orElseGet(() -> new Class[0]); + // if the exception is not allowed by the annotation, skip the method + if (allowedExceptions.length > 0 && Arrays.stream(allowedExceptions).noneMatch(e -> e.isAssignableFrom(exception.getClass()))) { + continue; + } + // validate the method signature + if (params[0].isAssignableFrom(message.getClass()) && params[1].isAssignableFrom(exception.getClass())) { + // call the method with message, exception parameters + method.invoke(this, message, exception); + return true; // exception was handled + } + } + } + // throwable is not an exception or no suitable method was found + return false; + } + @MessageExceptionHandler public void messageDeliveryException(Message message, MessageDeliveryException e) { + if (!(e.getCause() instanceof MessageDeliveryException) && delegate(message, e.getCause())) { + return; + } + // messages without destination will be logged only on trace (hasDestination(message) ? LOG.atError() : LOG.atTrace()) .setMessage("Failed to send message with destination {}: {}") @@ -144,11 +228,16 @@ public ErrorInfo authorizationException(Message message, AuthorizationExcepti return errorInfo(message, e); } - @MessageExceptionHandler(AuthenticationException.class) + @MessageExceptionHandler({AuthenticationException.class, AuthenticationServiceException.class}) public ErrorInfo authenticationException(Message message, AuthenticationException e) { - LOG.atDebug().setCause(e).log(e.getMessage()); - LOG.atError().setMessage("Authentication failure during message processing: {}\nMessage: {}") + LOG.atWarn().setMessage("Authentication failure during message processing: {}\nMessage: {}") .addArgument(e.getMessage()).addArgument(message::toString).log(); + + if (ExceptionUtils.findCause(e, JwtException.class).isPresent()) { + return errorInfo(message, e); + } + + LOG.atDebug().setCause(e).log(e.getMessage()); return errorInfo(message, e); } diff --git a/src/test/java/cz/cvut/kbss/termit/websocket/IntegrationWebSocketSecurityTest.java b/src/test/java/cz/cvut/kbss/termit/websocket/IntegrationWebSocketSecurityTest.java index c1d2d81b0..0d47b608d 100644 --- a/src/test/java/cz/cvut/kbss/termit/websocket/IntegrationWebSocketSecurityTest.java +++ b/src/test/java/cz/cvut/kbss/termit/websocket/IntegrationWebSocketSecurityTest.java @@ -37,6 +37,7 @@ import static org.awaitility.Awaitility.await; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.isNull; import static org.mockito.ArgumentMatchers.notNull; import static org.mockito.Mockito.verify; @@ -86,7 +87,8 @@ void connectionIsClosedOnAnyMessageBeforeConnect(String stompCommand, Boolean wi assertTrue(receivedError.get()); assertFalse(session.isOpen()); assertFalse(receivedReply.get()); - verify(webSocketExceptionHandler).messageDeliveryException(notNull(), notNull()); + verify(webSocketExceptionHandler).delegate(notNull(), notNull()); + verify(webSocketExceptionHandler).accessDeniedException(notNull(), notNull()); } WebSocketHandler makeWebSocketHandler(AtomicBoolean receivedReply, AtomicBoolean receivedError) { @@ -131,7 +133,8 @@ void connectWithInvalidAuthorizationIsRejected() throws Throwable { assertTrue(receivedError.get()); assertFalse(session.isOpen()); assertFalse(receivedReply.get()); - verify(webSocketExceptionHandler).messageDeliveryException(notNull(), notNull()); + verify(webSocketExceptionHandler).delegate(notNull(), notNull()); + verify(webSocketExceptionHandler).authenticationException(notNull(), notNull()); } /** @@ -167,7 +170,8 @@ void connectWithInvalidJwtAuthorizationIsRejected() throws Throwable { assertFalse(session.isOpen()); assertFalse(receivedReply.get()); - verify(webSocketExceptionHandler).messageDeliveryException(notNull(), notNull()); + verify(webSocketExceptionHandler).delegate(notNull(), notNull()); + verify(webSocketExceptionHandler).authenticationException(notNull(), notNull()); } /** @@ -186,5 +190,6 @@ void connectionIsNotClosedWhenConnectMessageIsSent() throws Throwable { assertTrue(session.isConnected()); session.disconnect(); await().atMost(OPERATION_TIMEOUT, TimeUnit.SECONDS).until(() -> !session.isConnected()); + verify(webSocketExceptionHandler).delegate(notNull(), isNull()); } }