namespace) {
- final URI identifier = resolveIdentifier(namespace.orElse(config.getNamespace().getVocabulary()), localName);
- final Vocabulary vocabulary = vocabularyService.getReference(identifier);
- return vocabularyService.validateContents(vocabulary);
- }
-
@Operation(security = {@SecurityRequirement(name = "bearer-key")},
description = "Creates a snapshot of the vocabulary with the specified identifier.")
@ApiResponses({
diff --git a/src/main/java/cz/cvut/kbss/termit/rest/handler/RestExceptionHandler.java b/src/main/java/cz/cvut/kbss/termit/rest/handler/RestExceptionHandler.java
index c17a253d1..90eb63cac 100644
--- a/src/main/java/cz/cvut/kbss/termit/rest/handler/RestExceptionHandler.java
+++ b/src/main/java/cz/cvut/kbss/termit/rest/handler/RestExceptionHandler.java
@@ -56,6 +56,7 @@
* The general pattern should be that unless an exception can be handled in a more appropriate place it bubbles up to a
* REST controller which originally received the request. There, it is caught by this handler, logged and a reasonable
* error message is returned to the user.
+ * @implSpec Should reflect {@link cz.cvut.kbss.termit.websocket.handler.WebSocketExceptionHandler}
*/
@RestControllerAdvice
public class RestExceptionHandler {
diff --git a/src/main/java/cz/cvut/kbss/termit/security/WebSocketJwtAuthorizationInterceptor.java b/src/main/java/cz/cvut/kbss/termit/security/WebSocketJwtAuthorizationInterceptor.java
new file mode 100644
index 000000000..71e5627de
--- /dev/null
+++ b/src/main/java/cz/cvut/kbss/termit/security/WebSocketJwtAuthorizationInterceptor.java
@@ -0,0 +1,65 @@
+package cz.cvut.kbss.termit.security;
+
+import cz.cvut.kbss.termit.exception.AuthorizationException;
+import cz.cvut.kbss.termit.exception.JwtException;
+import cz.cvut.kbss.termit.security.model.TermItUserDetails;
+import cz.cvut.kbss.termit.service.security.SecurityUtils;
+import cz.cvut.kbss.termit.service.security.TermItUserDetailsService;
+import org.jetbrains.annotations.NotNull;
+import org.springframework.http.HttpHeaders;
+import org.springframework.messaging.Message;
+import org.springframework.messaging.MessageChannel;
+import org.springframework.messaging.simp.stomp.StompCommand;
+import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
+import org.springframework.messaging.support.ChannelInterceptor;
+import org.springframework.messaging.support.MessageHeaderAccessor;
+import org.springframework.security.authentication.DisabledException;
+import org.springframework.security.authentication.LockedException;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.core.userdetails.UsernameNotFoundException;
+
+/**
+ * Authorizes STOMP CONNECT messages
+ *
+ * Retrieves token from the {@code Authorization} header of STOMP message and validates JWT token.
+ */
+public class WebSocketJwtAuthorizationInterceptor implements ChannelInterceptor {
+
+ private final JwtUtils jwtUtils;
+
+ private final TermItUserDetailsService userDetailsService;
+
+ public WebSocketJwtAuthorizationInterceptor(JwtUtils jwtUtils, TermItUserDetailsService userDetailsService) {
+ this.jwtUtils = jwtUtils;
+ this.userDetailsService = userDetailsService;
+ }
+
+ @Override
+ public Message> preSend(@NotNull Message> message, @NotNull MessageChannel channel) {
+ StompHeaderAccessor headerAccessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
+ if (headerAccessor != null && StompCommand.CONNECT.equals(headerAccessor.getCommand()) && headerAccessor.isMutable()) {
+ final String authHeader = headerAccessor.getFirstNativeHeader(HttpHeaders.AUTHORIZATION);
+ if (authHeader != null && authHeader.startsWith(SecurityConstants.JWT_TOKEN_PREFIX)) {
+ headerAccessor.removeNativeHeader(HttpHeaders.AUTHORIZATION);
+ return process(message, authHeader, headerAccessor);
+ }
+ throw new AuthorizationException("Authorization header is invalid");
+ }
+ return message;
+ }
+
+ private Message> process(final @NotNull Message> message, final @NotNull String authHeader,
+ final @NotNull StompHeaderAccessor headerAccessor) {
+ final String authToken = authHeader.substring(SecurityConstants.JWT_TOKEN_PREFIX.length());
+ try {
+ final TermItUserDetails userDetails = jwtUtils.extractUserInfo(authToken);
+ final TermItUserDetails existingDetails = userDetailsService.loadUserByUsername(userDetails.getUsername());
+ SecurityUtils.verifyAccountStatus(existingDetails.getUser());
+ Authentication authentication = SecurityUtils.setCurrentUser(existingDetails);
+ headerAccessor.setUser(authentication);
+ return message;
+ } catch (JwtException | DisabledException | LockedException | UsernameNotFoundException e) {
+ throw new AuthorizationException(e.getMessage());
+ }
+ }
+}
diff --git a/src/main/java/cz/cvut/kbss/termit/util/Constants.java b/src/main/java/cz/cvut/kbss/termit/util/Constants.java
index d46203655..fb0959d8f 100644
--- a/src/main/java/cz/cvut/kbss/termit/util/Constants.java
+++ b/src/main/java/cz/cvut/kbss/termit/util/Constants.java
@@ -243,4 +243,15 @@ private QueryParams() {
throw new AssertionError();
}
}
+
+ /**
+ * the maximum amount of data to buffer when sending messages to a WebSocket session
+ */
+ public static final int WEBSOCKET_SEND_BUFFER_SIZE_LIMIT = Integer.MAX_VALUE;
+
+ /**
+ * Set the maximum time allowed in milliseconds after the WebSocket connection is established
+ * and before the first sub-protocol message is received.
+ */
+ public static final int WEBSOCKET_TIME_TO_FIRST_MESSAGE = 15 * 1000 /* 15s */;
}
diff --git a/src/main/java/cz/cvut/kbss/termit/websocket/ResultWithHeaders.java b/src/main/java/cz/cvut/kbss/termit/websocket/ResultWithHeaders.java
new file mode 100644
index 000000000..1ec5c0ef6
--- /dev/null
+++ b/src/main/java/cz/cvut/kbss/termit/websocket/ResultWithHeaders.java
@@ -0,0 +1,68 @@
+package cz.cvut.kbss.termit.websocket;
+
+import cz.cvut.kbss.termit.websocket.handler.WebSocketMessageWithHeadersValueHandler;
+import org.springframework.messaging.handler.invocation.HandlerMethodReturnValueHandler;
+import org.jetbrains.annotations.NotNull;
+import org.jetbrains.annotations.Nullable;
+import org.springframework.messaging.handler.annotation.SendTo;
+import org.springframework.messaging.simp.annotation.SendToUser;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Wrapper carrying a result from WebSocket controller
+ * including the {@link #payload}, {@link #destination} and {@link #headers} for the resulting message.
+ *
+ * Do not combine with other method-return-value handlers (like {@link SendTo @SendTo})
+ *
+ * The {@code ResultWithHeaders} is then handled by {@link WebSocketMessageWithHeadersValueHandler}.
+ * Every value returned from a controller method
+ * can be handled only by a single {@link HandlerMethodReturnValueHandler}.
+ * Annotations like {@link SendTo @SendTo}/{@link SendToUser @SendToUser}
+ * are handled by separate return value handlers, so only one can be used simultaneously.
+ *
+ * @param payload The actual result of the method
+ * @param destination The destination channel where the message will be sent
+ * @param headers Headers that will overwrite headers in the message.
+ * @param The type of the payload
+ * @see WebSocketMessageWithHeadersValueHandler
+ * @see HandlerMethodReturnValueHandler
+ */
+public record ResultWithHeaders(T payload, @NotNull String destination, @NotNull Map headers,
+ boolean toUser) {
+
+ public static ResultWithHeadersBuilder result(T payload) {
+ return new ResultWithHeadersBuilder<>(payload);
+ }
+
+ public static class ResultWithHeadersBuilder {
+
+ private final T payload;
+
+ private @Nullable Map headers = null;
+
+ private ResultWithHeadersBuilder(T payload) {
+ this.payload = payload;
+ }
+
+ /**
+ * All values will be mapped to strings with {@link Object#toString()}
+ */
+ public ResultWithHeadersBuilder withHeaders(@NotNull Map headers) {
+ this.headers = new HashMap<>();
+ headers.forEach((key, value) -> this.headers.put(key, value.toString()));
+ this.headers = Collections.unmodifiableMap(this.headers);
+ return this;
+ }
+
+ public ResultWithHeaders sendTo(String destination) {
+ return new ResultWithHeaders<>(payload, destination, headers == null ? Map.of() : headers, false);
+ }
+
+ public ResultWithHeaders sendToUser(String userDestination) {
+ return new ResultWithHeaders<>(payload, userDestination, headers == null ? Map.of() : headers, true);
+ }
+ }
+}
diff --git a/src/main/java/cz/cvut/kbss/termit/websocket/VocabularySocketController.java b/src/main/java/cz/cvut/kbss/termit/websocket/VocabularySocketController.java
new file mode 100644
index 000000000..67c1af291
--- /dev/null
+++ b/src/main/java/cz/cvut/kbss/termit/websocket/VocabularySocketController.java
@@ -0,0 +1,51 @@
+package cz.cvut.kbss.termit.websocket;
+
+import cz.cvut.kbss.termit.model.Vocabulary;
+import cz.cvut.kbss.termit.model.validation.ValidationResult;
+import cz.cvut.kbss.termit.rest.BaseController;
+import cz.cvut.kbss.termit.security.SecurityConstants;
+import cz.cvut.kbss.termit.service.IdentifierResolver;
+import cz.cvut.kbss.termit.service.business.VocabularyService;
+import cz.cvut.kbss.termit.util.Configuration;
+import cz.cvut.kbss.termit.util.Constants;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.messaging.handler.annotation.DestinationVariable;
+import org.springframework.messaging.handler.annotation.Header;
+import org.springframework.messaging.handler.annotation.MessageMapping;
+import org.springframework.security.access.prepost.PreAuthorize;
+import org.springframework.stereotype.Controller;
+
+import java.net.URI;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+import static cz.cvut.kbss.termit.websocket.ResultWithHeaders.result;
+
+@Controller
+@MessageMapping("/vocabularies")
+@PreAuthorize("hasRole('" + SecurityConstants.ROLE_RESTRICTED_USER + "')")
+public class VocabularySocketController extends BaseController {
+
+ private final VocabularyService vocabularyService;
+
+ protected VocabularySocketController(IdentifierResolver idResolver, Configuration config,
+ VocabularyService vocabularyService) {
+ super(idResolver, config);
+ this.vocabularyService = vocabularyService;
+ }
+
+ /**
+ * Validates the terms in a vocabulary with the specified identifier.
+ */
+ @MessageMapping("/{localName}/validate")
+ public ResultWithHeaders> validateVocabulary(@DestinationVariable String localName,
+ @Header(name = Constants.QueryParams.NAMESPACE,
+ required = false) Optional namespace) {
+ final URI identifier = resolveIdentifier(namespace.orElse(config.getNamespace().getVocabulary()), localName);
+ final Vocabulary vocabulary = vocabularyService.getReference(identifier);
+ return result(vocabularyService.validateContents(vocabulary)).withHeaders(Map.of("vocabulary", identifier))
+ .sendToUser("/vocabularies/validation");
+ }
+}
diff --git a/src/main/java/cz/cvut/kbss/termit/websocket/handler/StompExceptionHandler.java b/src/main/java/cz/cvut/kbss/termit/websocket/handler/StompExceptionHandler.java
new file mode 100644
index 000000000..4f78bf920
--- /dev/null
+++ b/src/main/java/cz/cvut/kbss/termit/websocket/handler/StompExceptionHandler.java
@@ -0,0 +1,21 @@
+package cz.cvut.kbss.termit.websocket.handler;
+
+import org.jetbrains.annotations.NotNull;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.messaging.Message;
+import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
+import org.springframework.web.socket.messaging.StompSubProtocolErrorHandler;
+
+public class StompExceptionHandler extends StompSubProtocolErrorHandler {
+
+ private static final Logger LOG = LoggerFactory.getLogger(StompExceptionHandler.class);
+
+ @Override
+ protected @NotNull Message handleInternal(@NotNull StompHeaderAccessor errorHeaderAccessor,
+ byte @NotNull [] errorPayload,
+ Throwable cause, StompHeaderAccessor clientHeaderAccessor) {
+ LOG.error("STOMP sub-protocol exception", cause);
+ return super.handleInternal(errorHeaderAccessor, errorPayload, cause, clientHeaderAccessor);
+ }
+}
diff --git a/src/main/java/cz/cvut/kbss/termit/websocket/handler/WebSocketExceptionHandler.java b/src/main/java/cz/cvut/kbss/termit/websocket/handler/WebSocketExceptionHandler.java
new file mode 100644
index 000000000..50411c650
--- /dev/null
+++ b/src/main/java/cz/cvut/kbss/termit/websocket/handler/WebSocketExceptionHandler.java
@@ -0,0 +1,216 @@
+package cz.cvut.kbss.termit.websocket.handler;
+
+import cz.cvut.kbss.jopa.exceptions.EntityNotFoundException;
+import cz.cvut.kbss.jopa.exceptions.OWLPersistenceException;
+import cz.cvut.kbss.jsonld.exception.JsonLdException;
+import cz.cvut.kbss.termit.exception.AnnotationGenerationException;
+import cz.cvut.kbss.termit.exception.AssetRemovalException;
+import cz.cvut.kbss.termit.exception.AuthorizationException;
+import cz.cvut.kbss.termit.exception.InvalidLanguageConstantException;
+import cz.cvut.kbss.termit.exception.InvalidParameterException;
+import cz.cvut.kbss.termit.exception.InvalidPasswordChangeRequestException;
+import cz.cvut.kbss.termit.exception.InvalidTermStateException;
+import cz.cvut.kbss.termit.exception.NotFoundException;
+import cz.cvut.kbss.termit.exception.PersistenceException;
+import cz.cvut.kbss.termit.exception.ResourceExistsException;
+import cz.cvut.kbss.termit.exception.SnapshotNotEditableException;
+import cz.cvut.kbss.termit.exception.SuppressibleLogging;
+import cz.cvut.kbss.termit.exception.TermItException;
+import cz.cvut.kbss.termit.exception.UnsupportedOperationException;
+import cz.cvut.kbss.termit.exception.UnsupportedSearchFacetException;
+import cz.cvut.kbss.termit.exception.ValidationException;
+import cz.cvut.kbss.termit.exception.WebServiceIntegrationException;
+import cz.cvut.kbss.termit.exception.importing.UnsupportedImportMediaTypeException;
+import cz.cvut.kbss.termit.exception.importing.VocabularyImportException;
+import cz.cvut.kbss.termit.rest.handler.ErrorInfo;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.messaging.Message;
+import org.springframework.messaging.MessageDeliveryException;
+import org.springframework.messaging.handler.annotation.MessageExceptionHandler;
+import org.springframework.messaging.simp.annotation.SendToUser;
+import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
+import org.springframework.security.core.userdetails.UsernameNotFoundException;
+import org.springframework.web.bind.annotation.ControllerAdvice;
+import org.springframework.web.multipart.MaxUploadSizeExceededException;
+
+/**
+ * @implSpec Should reflect {@link cz.cvut.kbss.termit.rest.handler.RestExceptionHandler}
+ */
+@SendToUser
+@ControllerAdvice
+public class WebSocketExceptionHandler {
+
+ private static final Logger LOG = LoggerFactory.getLogger(WebSocketExceptionHandler.class);
+
+ private static String destination(Message> message) {
+ return message.getHeaders().getOrDefault("destination", "missing destination").toString();
+ }
+
+ private static void logException(TermItException ex, Message> message) {
+ if (shouldSuppressLogging(ex)) {
+ return;
+ }
+ logException("Exception caught when processing request to '" + destination(message) + "'.", ex);
+ }
+
+ private static boolean shouldSuppressLogging(TermItException ex) {
+ return ex.getClass().getAnnotation(SuppressibleLogging.class) != null;
+ }
+
+ private static void logException(Throwable ex, Message> message) {
+ logException("Exception caught when processing request to '" + destination(message) + "'.", ex);
+ }
+
+ private static void logException(String message, Throwable ex) {
+ LOG.error(message, ex);
+ }
+
+ private static ErrorInfo errorInfo(Message> message, Throwable e) {
+ return ErrorInfo.createWithMessage(e.getMessage(), destination(message));
+ }
+
+ @MessageExceptionHandler
+ public void messageDeliveryException(Message> message, MessageDeliveryException e) {
+ final StompHeaderAccessor headerAccessor = StompHeaderAccessor.wrap(message);
+ LOG.error("Failed to send message with destination {}: {}", headerAccessor.getDestination(), e.getMessage());
+ }
+
+ @MessageExceptionHandler(PersistenceException.class)
+ public ErrorInfo persistenceException(Message> message, PersistenceException e) {
+ logException(e, message);
+ return errorInfo(message, e.getCause());
+ }
+
+ @MessageExceptionHandler(OWLPersistenceException.class)
+ public ErrorInfo jopaException(Message> message, OWLPersistenceException e) {
+ logException("Persistence exception caught.", e);
+ return errorInfo(message, e);
+ }
+
+ @MessageExceptionHandler(ResourceExistsException.class)
+ public ErrorInfo resourceExistsException(Message> message, ResourceExistsException e) {
+ logException(e, message);
+ return errorInfo(message, e);
+ }
+
+ @MessageExceptionHandler(NotFoundException.class)
+ public ErrorInfo resourceNotFound(Message> message, NotFoundException e) {
+ // Not necessary to log NotFoundException, they may be quite frequent and do not represent an issue with the application
+ return errorInfo(message, e);
+ }
+
+ @MessageExceptionHandler(UsernameNotFoundException.class)
+ public ErrorInfo usernameNotFound(Message> message, UsernameNotFoundException e) {
+ return errorInfo(message, e);
+ }
+
+ @MessageExceptionHandler(EntityNotFoundException.class)
+ public ErrorInfo entityNotFoundException(Message> message, EntityNotFoundException e) {
+ logException(e, message);
+ return errorInfo(message, e);
+ }
+
+ @MessageExceptionHandler(AuthorizationException.class)
+ public ErrorInfo authorizationException(Message> message, AuthorizationException e) {
+ logException(e, message);
+ return errorInfo(message, e);
+ }
+
+ @MessageExceptionHandler(ValidationException.class)
+ public ErrorInfo validationException(Message> message, ValidationException e) {
+ logException(e, message);
+ return errorInfo(message, e);
+ }
+
+ @MessageExceptionHandler(WebServiceIntegrationException.class)
+ public ErrorInfo webServiceIntegrationException(Message> message, WebServiceIntegrationException e) {
+ logException(e.getCause(), message);
+ return errorInfo(message, e);
+ }
+
+ @MessageExceptionHandler(AnnotationGenerationException.class)
+ public ErrorInfo annotationGenerationException(Message> message, AnnotationGenerationException e) {
+ logException(e.getCause(), message);
+ return errorInfo(message, e);
+ }
+
+ @MessageExceptionHandler(TermItException.class)
+ public ErrorInfo termItException(Message> message, TermItException e) {
+ logException(e, message);
+ return errorInfo(message, e);
+ }
+
+ @MessageExceptionHandler(JsonLdException.class)
+ public ErrorInfo jsonLdException(Message> message, JsonLdException e) {
+ logException(e, message);
+ return ErrorInfo.createWithMessage("Error when processing JSON-LD.", destination(message));
+ }
+
+ @MessageExceptionHandler(UnsupportedOperationException.class)
+ public ErrorInfo unsupportedAssetOperationException(Message> message, UnsupportedOperationException e) {
+ logException(e, message);
+ return errorInfo(message, e);
+ }
+
+ @MessageExceptionHandler(VocabularyImportException.class)
+ public ErrorInfo vocabularyImportException(Message> message, VocabularyImportException e) {
+ logException(e, message);
+ return ErrorInfo.createWithMessageAndMessageId(e.getMessage(), e.getMessageId(), destination(message));
+ }
+
+ @MessageExceptionHandler
+ public ErrorInfo unsupportedImportMediaTypeException(Message> message, UnsupportedImportMediaTypeException e) {
+ logException(e, message);
+ return errorInfo(message, e);
+ }
+
+ @MessageExceptionHandler
+ public ErrorInfo assetRemovalException(Message> message, AssetRemovalException e) {
+ logException(e, message);
+ return errorInfo(message, e);
+ }
+
+ @MessageExceptionHandler
+ public ErrorInfo invalidParameter(Message> message, InvalidParameterException e) {
+ logException(e, message);
+ return errorInfo(message, e);
+ }
+
+ @MessageExceptionHandler
+ public ErrorInfo maxUploadSizeExceededException(Message> message, MaxUploadSizeExceededException e) {
+ logException(e, message);
+ return ErrorInfo.createWithMessageAndMessageId(e.getMessage(), "error.file.maxUploadSizeExceeded", destination(message));
+ }
+
+ @MessageExceptionHandler
+ public ErrorInfo snapshotNotEditableException(Message> message, SnapshotNotEditableException e) {
+ logException(e, message);
+ return ErrorInfo.createWithMessage(e.getMessage(), destination(message));
+ }
+
+ @MessageExceptionHandler
+ public ErrorInfo unsupportedSearchFacetException(Message> message, UnsupportedSearchFacetException e) {
+ logException(e, message);
+ return errorInfo(message, e);
+ }
+
+ @MessageExceptionHandler
+ public ErrorInfo invalidLanguageConstantException(Message> message, InvalidLanguageConstantException e) {
+ logException(e, message);
+ return errorInfo(message, e);
+ }
+
+ @MessageExceptionHandler
+ public ErrorInfo invalidTermStateException(Message> message, InvalidTermStateException e) {
+ logException(e, message);
+ return ErrorInfo.createWithMessageAndMessageId(e.getMessage(), e.getMessageId(), destination(message));
+ }
+
+ @MessageExceptionHandler
+ public ErrorInfo invalidPasswordChangeRequestException(Message> message,
+ InvalidPasswordChangeRequestException e) {
+ logException(e, message);
+ return ErrorInfo.createWithMessageAndMessageId(e.getMessage(), e.getMessageId(), destination(message));
+ }
+}
diff --git a/src/main/java/cz/cvut/kbss/termit/websocket/handler/WebSocketMessageWithHeadersValueHandler.java b/src/main/java/cz/cvut/kbss/termit/websocket/handler/WebSocketMessageWithHeadersValueHandler.java
new file mode 100644
index 000000000..5494294f2
--- /dev/null
+++ b/src/main/java/cz/cvut/kbss/termit/websocket/handler/WebSocketMessageWithHeadersValueHandler.java
@@ -0,0 +1,46 @@
+package cz.cvut.kbss.termit.websocket.handler;
+
+import cz.cvut.kbss.termit.exception.UnsupportedOperationException;
+import cz.cvut.kbss.termit.websocket.ResultWithHeaders;
+import org.jetbrains.annotations.NotNull;
+import org.springframework.core.MethodParameter;
+import org.springframework.messaging.Message;
+import org.springframework.messaging.handler.invocation.HandlerMethodReturnValueHandler;
+import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
+import org.springframework.messaging.simp.SimpMessagingTemplate;
+import org.springframework.messaging.simp.annotation.support.MissingSessionUserException;
+import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
+
+public class WebSocketMessageWithHeadersValueHandler implements HandlerMethodReturnValueHandler {
+
+ private final SimpMessagingTemplate simpMessagingTemplate;
+
+ public WebSocketMessageWithHeadersValueHandler(SimpMessagingTemplate simpMessagingTemplate) {
+ this.simpMessagingTemplate = simpMessagingTemplate;
+ }
+
+ @Override
+ public boolean supportsReturnType(MethodParameter returnType) {
+ return ResultWithHeaders.class.isAssignableFrom(returnType.getParameterType());
+ }
+
+ @Override
+ public void handleReturnValue(Object returnValue, @NotNull MethodParameter returnType, @NotNull Message> message)
+ throws Exception {
+ if (returnValue instanceof ResultWithHeaders> resultWithHeaders) {
+ final StompHeaderAccessor headerAccessor = StompHeaderAccessor.wrap(message);
+ resultWithHeaders.headers().forEach(headerAccessor::setNativeHeader);
+ if (resultWithHeaders.toUser()) {
+ final String sessionId = SimpMessageHeaderAccessor.getSessionId(headerAccessor.toMessageHeaders());
+ if (sessionId == null || sessionId.isBlank()) {
+ throw new MissingSessionUserException(message);
+ }
+ simpMessagingTemplate.convertAndSendToUser(sessionId, resultWithHeaders.destination(), resultWithHeaders.payload(), headerAccessor.toMessageHeaders());
+ } else {
+ simpMessagingTemplate.convertAndSend(resultWithHeaders.destination(), resultWithHeaders.payload(), headerAccessor.toMessageHeaders());
+ }
+ return;
+ }
+ throw new UnsupportedOperationException("Unable to process returned value: " + returnValue + " of type " + returnType.getParameterType() + " from " + returnType.getMethod());
+ }
+}
diff --git a/src/test/java/cz/cvut/kbss/termit/environment/Environment.java b/src/test/java/cz/cvut/kbss/termit/environment/Environment.java
index e9db7449e..9f7bd62a1 100644
--- a/src/test/java/cz/cvut/kbss/termit/environment/Environment.java
+++ b/src/test/java/cz/cvut/kbss/termit/environment/Environment.java
@@ -41,6 +41,7 @@
import org.springframework.http.converter.ResourceHttpMessageConverter;
import org.springframework.http.converter.StringHttpMessageConverter;
import org.springframework.http.converter.json.MappingJackson2HttpMessageConverter;
+import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextImpl;
@@ -81,11 +82,13 @@ private static DtoMapper initDtoMapper() {
*
* @param user User to set as currently authenticated
*/
- public static void setCurrentUser(UserAccount user) {
+ public static Authentication setCurrentUser(UserAccount user) {
final TermItUserDetails userDetails = new TermItUserDetails(user, new HashSet<>());
SecurityContext context = new SecurityContextImpl();
- context.setAuthentication(new AuthenticationToken(userDetails.getAuthorities(), userDetails));
+ Authentication authentication = new AuthenticationToken(userDetails.getAuthorities(), userDetails);
+ context.setAuthentication(authentication);
SecurityContextHolder.setContext(context);
+ return authentication;
}
/**
diff --git a/src/test/java/cz/cvut/kbss/termit/environment/config/TestWebSocketConfig.java b/src/test/java/cz/cvut/kbss/termit/environment/config/TestWebSocketConfig.java
new file mode 100644
index 000000000..f3cd62e7d
--- /dev/null
+++ b/src/test/java/cz/cvut/kbss/termit/environment/config/TestWebSocketConfig.java
@@ -0,0 +1,98 @@
+package cz.cvut.kbss.termit.environment.config;
+
+import cz.cvut.kbss.termit.config.WebAppConfig;
+import cz.cvut.kbss.termit.config.WebSocketConfig;
+import cz.cvut.kbss.termit.util.Configuration;
+import cz.cvut.kbss.termit.websocket.util.ReturnValueCollectingSimpMessagingTemplate;
+import org.jetbrains.annotations.NotNull;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.boot.context.properties.EnableConfigurationProperties;
+import org.springframework.boot.test.context.TestConfiguration;
+import org.springframework.context.ApplicationListener;
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.ComponentScan;
+import org.springframework.context.annotation.Import;
+import org.springframework.context.annotation.Lazy;
+import org.springframework.context.annotation.Primary;
+import org.springframework.context.event.ContextRefreshedEvent;
+import org.springframework.core.task.SyncTaskExecutor;
+import org.springframework.messaging.MessageHandler;
+import org.springframework.messaging.SubscribableChannel;
+import org.springframework.messaging.converter.CompositeMessageConverter;
+import org.springframework.messaging.simp.SimpMessagingTemplate;
+import org.springframework.messaging.simp.annotation.support.SimpAnnotationMethodMessageHandler;
+import org.springframework.messaging.simp.config.ChannelRegistration;
+import org.springframework.messaging.simp.config.MessageBrokerRegistry;
+import org.springframework.messaging.support.AbstractSubscribableChannel;
+import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.UUID;
+
+@TestConfiguration
+@EnableConfigurationProperties(Configuration.class)
+@Import({TestSecurityConfig.class, TestRestSecurityConfig.class, WebAppConfig.class, WebSocketConfig.class})
+@ComponentScan(basePackages = "cz.cvut.kbss.termit.websocket")
+public class TestWebSocketConfig
+ implements ApplicationListener, WebSocketMessageBrokerConfigurer {
+
+ private final List channels;
+
+ private final List handlers;
+
+ @Autowired
+ @Lazy
+ public TestWebSocketConfig(List channels, List handlers) {
+ this.channels = channels;
+ this.handlers = handlers;
+ }
+
+ /**
+ * Unregisters MessageHandler's from the message channels to reduce processing during the test.
+ * Also stops further processing so for example user responses remain in the broker channel.
+ * @param event the event to respond to
+ */
+ @Override
+ public void onApplicationEvent(@NotNull ContextRefreshedEvent event) {
+ for (MessageHandler handler : handlers) {
+ if (handler instanceof SimpAnnotationMethodMessageHandler) {
+ continue;
+ }
+ for (SubscribableChannel channel : channels) {
+ channel.unsubscribe(handler);
+ }
+ }
+ }
+
+ @Override
+ public void configureClientInboundChannel(ChannelRegistration registration) {
+ registration.executor(new SyncTaskExecutor());
+ }
+
+ @Override
+ public void configureClientOutboundChannel(ChannelRegistration registration) {
+ registration.executor(new SyncTaskExecutor());
+ }
+
+ @Override
+ public void configureMessageBroker(MessageBrokerRegistry registry) {
+ registry.configureBrokerChannel().executor(new SyncTaskExecutor());
+ }
+
+ @Bean
+ public Map returnedValuesMap() {
+ return new HashMap<>();
+ }
+
+ @Bean
+ @Primary
+ public SimpMessagingTemplate brokerMessagingTemplate(
+ AbstractSubscribableChannel brokerChannel, CompositeMessageConverter brokerMessageConverter) {
+
+ SimpMessagingTemplate template = new ReturnValueCollectingSimpMessagingTemplate(brokerChannel, returnedValuesMap());
+ template.setMessageConverter(brokerMessageConverter);
+ return template;
+ }
+}
diff --git a/src/test/java/cz/cvut/kbss/termit/rest/VocabularyControllerTest.java b/src/test/java/cz/cvut/kbss/termit/rest/VocabularyControllerTest.java
index cccd38e09..76b2e0521 100644
--- a/src/test/java/cz/cvut/kbss/termit/rest/VocabularyControllerTest.java
+++ b/src/test/java/cz/cvut/kbss/termit/rest/VocabularyControllerTest.java
@@ -467,30 +467,6 @@ void getHistoryOfContentReturnsListOfAggregatedChangeObjectsForTermsInSpecifiedV
verify(serviceMock).getChangesOfContent(vocabulary);
}
- @Test
- void validateExecutesServiceValidate() throws Exception {
- final Vocabulary vocabulary = generateVocabularyAndInitReferenceResolution();
- final List records = Collections.singletonList(new ValidationResult()
- .setTermUri(Generator.generateUri())
- .setIssueCauseUri(
- Generator.generateUri())
- .setSeverity(URI.create(
- SH.Violation.toString())));
- when(serviceMock.validateContents(vocabulary)).thenReturn(records);
-
-
- final MvcResult mvcResult = mockMvc.perform(get(PATH + "/" + FRAGMENT + "/validate"))
- .andExpect(status().isOk())
- .andReturn();
- final List result =
- readValue(mvcResult, new TypeReference>() {
- });
- assertNotNull(result);
- assertEquals(records.stream().map(ValidationResult::getId).collect(Collectors.toList()),
- result.stream().map(ValidationResult::getId).collect(Collectors.toList()));
- verify(serviceMock).validateContents(vocabulary);
- }
-
private Vocabulary generateVocabularyAndInitReferenceResolution() {
final Vocabulary vocabulary = generateVocabulary();
vocabulary.setUri(VOCABULARY_URI);
diff --git a/src/test/java/cz/cvut/kbss/termit/service/document/html/HtmlTermOccurrenceResolverTest.java b/src/test/java/cz/cvut/kbss/termit/service/document/html/HtmlTermOccurrenceResolverTest.java
index 627a24feb..bf85ea85e 100644
--- a/src/test/java/cz/cvut/kbss/termit/service/document/html/HtmlTermOccurrenceResolverTest.java
+++ b/src/test/java/cz/cvut/kbss/termit/service/document/html/HtmlTermOccurrenceResolverTest.java
@@ -32,6 +32,8 @@
import org.jsoup.Jsoup;
import org.jsoup.select.Elements;
import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.condition.DisabledOnOs;
+import org.junit.jupiter.api.condition.OS;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
@@ -177,6 +179,7 @@ void findTermOccurrencesMarksOccurrencesAsSuggested() {
}
@Test
+ @DisabledOnOs(OS.WINDOWS) // TODO: https://github.com/kbss-cvut/termit/issues/275
void findTermOccurrencesSetsFoundOccurrencesAsApprovedWhenCorrespondingExistingOccurrenceWasApproved() throws Exception {
when(termService.exists(TERM_URI)).thenReturn(true);
final File file = initFile();
diff --git a/src/test/java/cz/cvut/kbss/termit/websocket/BaseWebSocketControllerTestRunner.java b/src/test/java/cz/cvut/kbss/termit/websocket/BaseWebSocketControllerTestRunner.java
new file mode 100644
index 000000000..4c4c944d8
--- /dev/null
+++ b/src/test/java/cz/cvut/kbss/termit/websocket/BaseWebSocketControllerTestRunner.java
@@ -0,0 +1,103 @@
+package cz.cvut.kbss.termit.websocket;
+
+import cz.cvut.kbss.termit.environment.config.TestWebSocketConfig;
+import cz.cvut.kbss.termit.util.Configuration;
+import cz.cvut.kbss.termit.websocket.util.CachingChannelInterceptor;
+import jakarta.annotation.PostConstruct;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.junit.jupiter.MockitoExtension;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.beans.factory.annotation.Qualifier;
+import org.springframework.boot.context.properties.EnableConfigurationProperties;
+import org.springframework.boot.test.context.ConfigDataApplicationContextInitializer;
+import org.springframework.messaging.Message;
+import org.springframework.messaging.support.AbstractSubscribableChannel;
+import org.springframework.test.annotation.DirtiesContext;
+import org.springframework.test.context.ActiveProfiles;
+import org.springframework.test.context.ContextConfiguration;
+import org.springframework.test.context.junit.jupiter.SpringExtension;
+
+import java.util.Map;
+import java.util.Optional;
+import java.util.UUID;
+
+import static cz.cvut.kbss.termit.websocket.util.ReturnValueCollectingSimpMessagingTemplate.MESSAGE_IDENTIFIER_HEADER;
+
+@ActiveProfiles("test")
+@ExtendWith(SpringExtension.class)
+@ExtendWith(MockitoExtension.class)
+@EnableConfigurationProperties({Configuration.class})
+@DirtiesContext(classMode = DirtiesContext.ClassMode.AFTER_CLASS)
+@ContextConfiguration(classes = {TestWebSocketConfig.class},
+ initializers = {ConfigDataApplicationContextInitializer.class})
+public abstract class BaseWebSocketControllerTestRunner {
+
+ private static final Logger LOG = LoggerFactory.getLogger(BaseWebSocketControllerTestRunner.class);
+
+ /**
+ * Simulated messages from client to server
+ */
+ @Autowired
+ @Qualifier("clientInboundChannel")
+ protected AbstractSubscribableChannel serverInboundChannel;
+
+ /**
+ * Messages sent from the server to the client
+ */
+ @Autowired
+ @Qualifier("clientOutboundChannel")
+ protected AbstractSubscribableChannel serverOutboundChannel;
+
+ @Autowired
+ protected AbstractSubscribableChannel brokerChannel;
+
+ /**
+ * Holds message ids mapped to the values returned from the controllers
+ */
+ @Autowired
+ protected Map returnedValuesMap;
+
+ /**
+ * Caches any messages sent from the server to the client
+ */
+ protected CachingChannelInterceptor serverOutboundChannelInterceptor;
+
+ protected CachingChannelInterceptor brokerChannelInterceptor;
+
+ @PostConstruct
+ protected void runnerPostConstruct() {
+ this.brokerChannelInterceptor = new CachingChannelInterceptor();
+ this.serverOutboundChannelInterceptor = new CachingChannelInterceptor();
+
+ this.brokerChannel.addInterceptor(this.brokerChannelInterceptor);
+ this.serverOutboundChannel.addInterceptor(this.serverOutboundChannelInterceptor);
+ }
+
+ @BeforeEach
+ protected void runnerBeforeEach() {
+ this.serverOutboundChannelInterceptor.reset();
+ this.brokerChannelInterceptor.reset();
+ this.returnedValuesMap.clear();
+ }
+
+ /**
+ * Returns result of controller method associated with the specified message.
+ *
+ * @param message The message sent from some controller
+ */
+ @SuppressWarnings("unchecked")
+ protected Optional readPayload(Message> message) throws ClassCastException {
+ final UUID id = message.getHeaders().get(MESSAGE_IDENTIFIER_HEADER, UUID.class);
+ if (id == null) {
+ LOG.error("Unable to read message payload. Message id is null.");
+ return Optional.empty();
+ }
+ if (returnedValuesMap.containsKey(id)) {
+ return Optional.of((T) returnedValuesMap.get(id));
+ }
+ return Optional.empty();
+ }
+}
diff --git a/src/test/java/cz/cvut/kbss/termit/websocket/BaseWebSocketIntegrationTestRunner.java b/src/test/java/cz/cvut/kbss/termit/websocket/BaseWebSocketIntegrationTestRunner.java
new file mode 100644
index 000000000..59ba7d502
--- /dev/null
+++ b/src/test/java/cz/cvut/kbss/termit/websocket/BaseWebSocketIntegrationTestRunner.java
@@ -0,0 +1,169 @@
+package cz.cvut.kbss.termit.websocket;
+
+import cz.cvut.kbss.termit.config.AppConfig;
+import cz.cvut.kbss.termit.config.SecurityConfig;
+import cz.cvut.kbss.termit.config.WebAppConfig;
+import cz.cvut.kbss.termit.config.WebSocketConfig;
+import cz.cvut.kbss.termit.environment.Generator;
+import cz.cvut.kbss.termit.environment.config.TestConfig;
+import cz.cvut.kbss.termit.environment.config.TestPersistenceConfig;
+import cz.cvut.kbss.termit.environment.config.TestSecurityConfig;
+import cz.cvut.kbss.termit.environment.config.TestServiceConfig;
+import cz.cvut.kbss.termit.security.JwtUtils;
+import cz.cvut.kbss.termit.security.model.TermItUserDetails;
+import cz.cvut.kbss.termit.service.security.TermItUserDetailsService;
+import cz.cvut.kbss.termit.util.Configuration;
+import org.jetbrains.annotations.NotNull;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.Answers;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.beans.factory.annotation.Value;
+import org.springframework.boot.autoconfigure.AutoConfigureOrder;
+import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
+import org.springframework.boot.context.properties.EnableConfigurationProperties;
+import org.springframework.boot.test.context.ConfigDataApplicationContextInitializer;
+import org.springframework.boot.test.context.SpringBootTest;
+import org.springframework.boot.test.mock.mockito.MockBean;
+import org.springframework.boot.test.mock.mockito.SpyBean;
+import org.springframework.context.annotation.ComponentScan;
+import org.springframework.context.annotation.EnableAspectJAutoProxy;
+import org.springframework.context.annotation.aspectj.EnableSpringConfigured;
+import org.springframework.messaging.simp.stomp.StompCommand;
+import org.springframework.messaging.simp.stomp.StompHeaders;
+import org.springframework.messaging.simp.stomp.StompSession;
+import org.springframework.messaging.simp.stomp.StompSessionHandlerAdapter;
+import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
+import org.springframework.test.annotation.DirtiesContext;
+import org.springframework.test.context.ActiveProfiles;
+import org.springframework.test.context.ContextConfiguration;
+import org.springframework.transaction.annotation.EnableTransactionManagement;
+import org.springframework.web.socket.CloseStatus;
+import org.springframework.web.socket.WebSocketHandler;
+import org.springframework.web.socket.WebSocketMessage;
+import org.springframework.web.socket.WebSocketSession;
+import org.springframework.web.socket.client.standard.StandardWebSocketClient;
+import org.springframework.web.socket.messaging.WebSocketStompClient;
+
+import java.util.concurrent.Future;
+import java.util.concurrent.atomic.AtomicReference;
+
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.when;
+
+@ActiveProfiles("test")
+@EnableSpringConfigured
+@EnableAutoConfiguration
+@EnableTransactionManagement
+@ExtendWith(MockitoExtension.class)
+@EnableAspectJAutoProxy(proxyTargetClass = true)
+@EnableConfigurationProperties({Configuration.class})
+@ContextConfiguration(
+ classes = {TestConfig.class, TestPersistenceConfig.class, TestConfig.class,
+ TestServiceConfig.class, AppConfig.class, SecurityConfig.class, WebAppConfig.class, WebSocketConfig.class},
+ initializers = {ConfigDataApplicationContextInitializer.class})
+@ComponentScan("cz.cvut.kbss.termit.security")
+@DirtiesContext(classMode = DirtiesContext.ClassMode.AFTER_CLASS)
+@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
+public abstract class BaseWebSocketIntegrationTestRunner {
+
+ protected Logger LOG = LoggerFactory.getLogger(this.getClass());
+
+ protected WebSocketStompClient stompClient;
+
+ @Value("ws://localhost:${local.server.port}/ws")
+ protected String url;
+
+ @SpyBean
+ protected TermItUserDetailsService userDetailsService;
+
+ @SpyBean
+ protected JwtUtils jwtUtils;
+
+ protected TermItUserDetails userDetails;
+
+ protected Future connect(StompSessionHandlerAdapter sessionHandler) {
+ return stompClient.connectAsync(url, sessionHandler);
+ }
+
+ protected String generateToken() {
+ return jwtUtils.generateToken(userDetails.getUser(), userDetails.getAuthorities());
+ }
+
+ @BeforeEach
+ void runnerSetup() {
+ stompClient = new WebSocketStompClient(new StandardWebSocketClient());
+
+ userDetails = new TermItUserDetails(Generator.generateUserAccountWithPassword());
+ doReturn(userDetails).when(userDetailsService).loadUserByUsername(userDetails.getUsername());
+ }
+
+ protected class TestWebSocketSessionHandler implements WebSocketHandler {
+
+ @Override
+ public void afterConnectionEstablished(WebSocketSession session) throws Exception {
+ LOG.info("WebSocket connection established");
+ }
+
+ @Override
+ public void handleMessage(WebSocketSession session, WebSocketMessage> message) throws Exception {
+ LOG.info("WebSocket message received: {}", message.getPayload());
+ }
+
+ @Override
+ public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
+ LOG.error("WebSocket transport error", exception);
+ }
+
+ @Override
+ public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception {
+ LOG.info("WebSocket connection closed");
+ }
+
+ @Override
+ public boolean supportsPartialMessages() {
+ return false;
+ }
+ }
+
+ protected class TestStompSessionHandler extends StompSessionHandlerAdapter {
+
+ private final AtomicReference exception = new AtomicReference<>();
+
+ @Override
+ public void afterConnected(StompSession session, StompHeaders connectedHeaders) {
+ super.afterConnected(session, connectedHeaders);
+ LOG.info("STOMP session connected");
+ }
+
+ @Override
+ public void handleFrame(@NotNull StompHeaders headers, Object payload) {
+ super.handleFrame(headers, payload);
+ exception.set(new Exception(headers.toString()));
+ LOG.error("STOMP frame: {}", headers);
+ }
+
+ @Override
+ public void handleException(@NotNull StompSession session, StompCommand command, @NotNull StompHeaders headers,
+ byte @NotNull [] payload, @NotNull Throwable exception) {
+ super.handleException(session, command, headers, payload, exception);
+ this.exception.set(exception);
+ LOG.error("STOMP exception", exception);
+ }
+
+ @Override
+ public void handleTransportError(@NotNull StompSession session, @NotNull Throwable exception) {
+ super.handleTransportError(session, exception);
+ this.exception.set(exception);
+ LOG.error("STOMP transport error", exception);
+ }
+
+ public Throwable getException() {
+ return exception.get();
+ }
+ }
+}
diff --git a/src/test/java/cz/cvut/kbss/termit/websocket/IntegrationWebSocketSecurityTest.java b/src/test/java/cz/cvut/kbss/termit/websocket/IntegrationWebSocketSecurityTest.java
new file mode 100644
index 000000000..b3bbcc043
--- /dev/null
+++ b/src/test/java/cz/cvut/kbss/termit/websocket/IntegrationWebSocketSecurityTest.java
@@ -0,0 +1,183 @@
+package cz.cvut.kbss.termit.websocket;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import cz.cvut.kbss.termit.security.SecurityConstants;
+import cz.cvut.kbss.termit.util.Utils;
+import io.jsonwebtoken.Jwts;
+import io.jsonwebtoken.SignatureAlgorithm;
+import io.jsonwebtoken.jackson.io.JacksonSerializer;
+import io.jsonwebtoken.security.Keys;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.http.HttpHeaders;
+import org.springframework.messaging.simp.stomp.StompCommand;
+import org.springframework.messaging.simp.stomp.StompHeaders;
+import org.springframework.messaging.simp.stomp.StompSession;
+import org.springframework.messaging.simp.stomp.StompSessionHandler;
+import org.springframework.web.socket.TextMessage;
+import org.springframework.web.socket.WebSocketHandler;
+import org.springframework.web.socket.WebSocketMessage;
+import org.springframework.web.socket.WebSocketSession;
+import org.springframework.web.socket.client.WebSocketClient;
+import org.springframework.web.socket.client.standard.StandardWebSocketClient;
+
+import java.net.URI;
+import java.nio.charset.StandardCharsets;
+import java.time.Instant;
+import java.util.Arrays;
+import java.util.Date;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.stream.Stream;
+
+import static org.awaitility.Awaitility.await;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+class IntegrationWebSocketSecurityTest extends BaseWebSocketIntegrationTestRunner {
+
+ /**
+ * The number of seconds after which some operations will time out.
+ */
+ private static final int OPERATION_TIMEOUT = 15;
+
+ @Autowired
+ ObjectMapper objectMapper;
+
+ /**
+ * @return Stream of argument pairs with StompCommand (CONNECT excluded) and true + false value for each command
+ */
+ public static Stream stompCommands() {
+ return Arrays.stream(StompCommand.values()).filter(c -> c != StompCommand.CONNECT).map(Enum::name)
+ .flatMap(name -> Stream.of(Arguments.of(name, true), Arguments.of(name, false)));
+ }
+
+ /**
+ * Ensures that connection is closed on receiving any message other than CONNECT
+ * (even with valid authorization token)
+ */
+ @ParameterizedTest
+ @MethodSource("stompCommands")
+ void connectionIsClosedOnAnyMessageBeforeConnect(String stompCommand, Boolean withAuth) throws Exception {
+ final AtomicBoolean receivedReply = new AtomicBoolean(false);
+ final AtomicBoolean receivedError = new AtomicBoolean(false);
+
+ final String auth = withAuth ? HttpHeaders.AUTHORIZATION + ":" + SecurityConstants.JWT_TOKEN_PREFIX + generateToken() + "\n" : "";
+ final TextMessage message = new TextMessage(stompCommand + "\n" + auth + "\n\0");
+
+ final WebSocketClient wsClient = new StandardWebSocketClient();
+ Future connectFuture = wsClient.execute(makeWebSocketHandler(receivedReply, receivedError), url);
+
+ WebSocketSession session = connectFuture.get(OPERATION_TIMEOUT, TimeUnit.SECONDS);
+
+ assertTrue(session.isOpen());
+
+ session.sendMessage(message);
+
+ await().atMost(OPERATION_TIMEOUT, TimeUnit.SECONDS).until(() -> !session.isOpen());
+
+ assertTrue(receivedError.get());
+ assertFalse(session.isOpen());
+ assertFalse(receivedReply.get());
+ }
+
+ WebSocketHandler makeWebSocketHandler(AtomicBoolean receivedReply, AtomicBoolean receivedError) {
+ return new TestWebSocketSessionHandler() {
+ @Override
+ public void handleMessage(WebSocketSession session, WebSocketMessage> message) throws Exception {
+ super.handleMessage(session, message);
+ if (message instanceof TextMessage textMessage) {
+ final String command = textMessage.getPayload().split("\n")[0];
+ if (command.equals(StompCommand.ERROR.name())) {
+ receivedError.set(true);
+ return;
+ }
+ }
+ receivedReply.set(true);
+ session.close();
+ }
+ };
+ }
+
+ /**
+ * STOMP CONNECT message is rejected with invalid auth token
+ */
+ @Test
+ void connectWithInvalidAuthorizationIsRejected() throws Throwable {
+ final AtomicBoolean receivedReply = new AtomicBoolean(false);
+ final AtomicBoolean receivedError = new AtomicBoolean(false);
+
+ final TextMessage message = new TextMessage(StompCommand.CONNECT + "\n" + HttpHeaders.AUTHORIZATION + ":" + SecurityConstants.JWT_TOKEN_PREFIX + "DefinitelyNotValidToken\n\n\0");
+
+ final WebSocketClient wsClient = new StandardWebSocketClient();
+ Future connectFuture = wsClient.execute(makeWebSocketHandler(receivedReply, receivedError), url);
+
+ WebSocketSession session = connectFuture.get(OPERATION_TIMEOUT, TimeUnit.SECONDS);
+
+ assertTrue(session.isOpen());
+
+ session.sendMessage(message);
+
+ await().atMost(OPERATION_TIMEOUT, TimeUnit.SECONDS).until(() -> !session.isOpen());
+
+ assertTrue(receivedError.get());
+ assertFalse(session.isOpen());
+ assertFalse(receivedReply.get());
+ }
+
+ /**
+ * STOMP CONNECT message is rejected with invalid JWT token
+ */
+ @Test
+ void connectWithInvalidJwtAuthorizationIsRejected() throws Throwable {
+ final AtomicBoolean receivedReply = new AtomicBoolean(false);
+ final AtomicBoolean receivedError = new AtomicBoolean(false);
+
+ final Instant issued = Utils.timestamp();
+ // creates "valid" JWT token but with invalid signature
+ final String token = Jwts.builder().setSubject(userDetails.getUser().getUsername())
+ .setId(userDetails.getUser().getUri().toString()).setIssuedAt(Date.from(issued))
+ .setExpiration(Date.from(issued.plusMillis(SecurityConstants.SESSION_TIMEOUT)))
+ .signWith(Keys.hmacShaKeyFor("my very secure and really private key".getBytes(StandardCharsets.UTF_8)), SignatureAlgorithm.HS256)
+ .serializeToJsonWith(new JacksonSerializer<>(objectMapper)).compact();
+
+ final TextMessage message = new TextMessage(StompCommand.CONNECT + "\n" + HttpHeaders.AUTHORIZATION + ":" + SecurityConstants.JWT_TOKEN_PREFIX + token + "\n\n\0");
+
+ final WebSocketClient wsClient = new StandardWebSocketClient();
+ Future connectFuture = wsClient.execute(makeWebSocketHandler(receivedReply, receivedError), url);
+
+ WebSocketSession session = connectFuture.get(OPERATION_TIMEOUT, TimeUnit.SECONDS);
+
+ assertTrue(session.isOpen());
+
+ session.sendMessage(message);
+
+ await().atMost(OPERATION_TIMEOUT, TimeUnit.SECONDS).until(() -> !session.isOpen());
+
+ assertTrue(receivedError.get());
+ assertFalse(session.isOpen());
+ assertFalse(receivedReply.get());
+ }
+
+ /**
+ * Checks that it is possible to establish STOMP connection with valid authorization
+ */
+ @Test
+ void connectionIsNotClosedWhenConnectMessageIsSent() throws Throwable {
+ final StompSessionHandler handler = new TestStompSessionHandler();
+
+ final StompHeaders headers = new StompHeaders();
+ headers.add(HttpHeaders.AUTHORIZATION, SecurityConstants.JWT_TOKEN_PREFIX + generateToken());
+
+ Future connectFuture = stompClient.connectAsync(URI.create(url), null, headers, handler);
+
+ StompSession session = connectFuture.get(OPERATION_TIMEOUT, TimeUnit.SECONDS);
+ assertTrue(session.isConnected());
+ session.disconnect();
+ await().atMost(OPERATION_TIMEOUT, TimeUnit.SECONDS).until(() -> !session.isConnected());
+ }
+}
diff --git a/src/test/java/cz/cvut/kbss/termit/websocket/VocabularySocketControllerTest.java b/src/test/java/cz/cvut/kbss/termit/websocket/VocabularySocketControllerTest.java
new file mode 100644
index 000000000..65507b909
--- /dev/null
+++ b/src/test/java/cz/cvut/kbss/termit/websocket/VocabularySocketControllerTest.java
@@ -0,0 +1,105 @@
+package cz.cvut.kbss.termit.websocket;
+
+import cz.cvut.kbss.jopa.model.MultilingualString;
+import cz.cvut.kbss.termit.environment.Environment;
+import cz.cvut.kbss.termit.environment.Generator;
+import cz.cvut.kbss.termit.model.Vocabulary;
+import cz.cvut.kbss.termit.model.validation.ValidationResult;
+import cz.cvut.kbss.termit.service.IdentifierResolver;
+import cz.cvut.kbss.termit.service.business.VocabularyService;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.springframework.boot.test.mock.mockito.MockBean;
+import org.springframework.boot.test.mock.mockito.SpyBean;
+import org.springframework.messaging.Message;
+import org.springframework.messaging.simp.stomp.StompCommand;
+import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
+import org.springframework.messaging.support.MessageBuilder;
+import org.springframework.security.core.Authentication;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Optional;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+
+class VocabularySocketControllerTest extends BaseWebSocketControllerTestRunner {
+
+ @MockBean
+ IdentifierResolver idResolver;
+
+ @MockBean
+ VocabularyService vocabularyService;
+
+ @SpyBean
+ VocabularySocketController sut;
+
+ Vocabulary vocabulary;
+
+ String fragment;
+
+ String namespace;
+
+ StompHeaderAccessor messageHeaders;
+
+ @BeforeEach
+ public void setup() {
+ vocabulary = Generator.generateVocabularyWithId();
+ fragment = IdentifierResolver.extractIdentifierFragment(vocabulary.getUri()).substring(1);
+ namespace = vocabulary.getUri().toString().substring(0, vocabulary.getUri().toString().lastIndexOf('/'));
+ when(idResolver.resolveIdentifier(namespace, fragment)).thenReturn(vocabulary.getUri());
+ when(vocabularyService.getReference(vocabulary.getUri())).thenReturn(vocabulary);
+
+ messageHeaders = StompHeaderAccessor.create(StompCommand.MESSAGE);
+ messageHeaders.setSessionId("0");
+ messageHeaders.setSubscriptionId("0");
+ Authentication auth = Environment.setCurrentUser(Generator.generateUserAccountWithPassword());
+ messageHeaders.setUser(auth);
+ messageHeaders.setSessionAttributes(new HashMap<>());
+ }
+
+ @Test
+ void validateVocabularyValidatesContents() {
+ messageHeaders.setContentLength(0);
+ messageHeaders.setHeader("namespace", namespace);
+ messageHeaders.setDestination("/vocabularies/" + fragment + "/validate");
+
+ this.serverInboundChannel.send(MessageBuilder.withPayload("").setHeaders(messageHeaders).build());
+
+ verify(vocabularyService).validateContents(vocabulary);
+ }
+
+ @Test
+ void validateVocabularyReturnsValidationResults() {
+ messageHeaders.setContentLength(0);
+ messageHeaders.setHeader("namespace", namespace);
+ messageHeaders.setDestination("/vocabularies/" + fragment + "/validate");
+
+ final ValidationResult validationResult = new ValidationResult().setTermUri(Generator.generateUri())
+ .setResultPath(Generator.generateUri())
+ .setMessage(MultilingualString.create("message", "en"))
+ .setSeverity(Generator.generateUri())
+ .setIssueCauseUri(Generator.generateUri());
+ final List validationResults = List.of(validationResult);
+ when(vocabularyService.validateContents(vocabulary)).thenReturn(validationResults);
+
+ this.serverInboundChannel.send(MessageBuilder.withPayload("").setHeaders(messageHeaders).build());
+
+ assertEquals(1, this.brokerChannelInterceptor.getMessages().size());
+ Message> reply = this.brokerChannelInterceptor.getMessages().get(0);
+
+ assertNotNull(reply);
+ StompHeaderAccessor replyHeaders = StompHeaderAccessor.wrap(reply);
+ // as reply is sent to a common channel for all vocabularies, there must be header with vocabulary uri
+ assertEquals(vocabulary.getUri().toString(), replyHeaders.getFirstNativeHeader("vocabulary"), "Invalid or missing vocabulary header in the reply");
+
+ Optional> payload = readPayload(reply);
+ assertTrue(payload.isPresent());
+ assertEquals(validationResults, payload.get());
+ }
+}
diff --git a/src/test/java/cz/cvut/kbss/termit/websocket/WebSocketExceptionHandlerTest.java b/src/test/java/cz/cvut/kbss/termit/websocket/WebSocketExceptionHandlerTest.java
new file mode 100644
index 000000000..52df68045
--- /dev/null
+++ b/src/test/java/cz/cvut/kbss/termit/websocket/WebSocketExceptionHandlerTest.java
@@ -0,0 +1,59 @@
+package cz.cvut.kbss.termit.websocket;
+
+import cz.cvut.kbss.termit.environment.Environment;
+import cz.cvut.kbss.termit.environment.Generator;
+import cz.cvut.kbss.termit.exception.PersistenceException;
+import cz.cvut.kbss.termit.websocket.handler.WebSocketExceptionHandler;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.springframework.boot.test.mock.mockito.MockBean;
+import org.springframework.boot.test.mock.mockito.SpyBean;
+import org.springframework.messaging.simp.stomp.StompCommand;
+import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
+import org.springframework.messaging.support.MessageBuilder;
+import org.springframework.security.core.Authentication;
+
+import java.util.HashMap;
+
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.ArgumentMatchers.notNull;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+class WebSocketExceptionHandlerTest extends BaseWebSocketControllerTestRunner {
+
+ @SpyBean
+ WebSocketExceptionHandler sut;
+
+ @MockBean
+ VocabularySocketController controller;
+
+ StompHeaderAccessor messageHeaders;
+
+ @BeforeEach
+ public void setup() {
+ messageHeaders = StompHeaderAccessor.create(StompCommand.MESSAGE);
+ messageHeaders.setSessionId("0");
+ messageHeaders.setSubscriptionId("0");
+ Authentication auth = Environment.setCurrentUser(Generator.generateUserAccountWithPassword());
+ messageHeaders.setUser(auth);
+ messageHeaders.setSessionAttributes(new HashMap<>());
+ messageHeaders.setContentLength(0);
+ messageHeaders.setHeader("namespace", "namespace");
+ messageHeaders.setDestination("/vocabularies/fragment/validate");
+ }
+
+ void sendMessage() {
+ this.serverInboundChannel.send(MessageBuilder.withPayload("").setHeaders(messageHeaders).build());
+ }
+
+ @Test
+ void handlerIsCalledForPersistenceException() {
+ final PersistenceException e = new PersistenceException(new Exception("mocked exception"));
+ when(controller.validateVocabulary(any(), any())).thenThrow(e);
+ sendMessage();
+ verify(sut).persistenceException(notNull(), eq(e));
+ }
+
+}
diff --git a/src/test/java/cz/cvut/kbss/termit/websocket/util/CachingChannelInterceptor.java b/src/test/java/cz/cvut/kbss/termit/websocket/util/CachingChannelInterceptor.java
new file mode 100644
index 000000000..43a454b81
--- /dev/null
+++ b/src/test/java/cz/cvut/kbss/termit/websocket/util/CachingChannelInterceptor.java
@@ -0,0 +1,32 @@
+package cz.cvut.kbss.termit.websocket.util;
+
+import org.jetbrains.annotations.NotNull;
+import org.springframework.messaging.Message;
+import org.springframework.messaging.MessageChannel;
+import org.springframework.messaging.support.ChannelInterceptor;
+
+import java.util.List;
+import java.util.concurrent.ArrayBlockingQueue;
+import java.util.concurrent.BlockingQueue;
+
+/**
+ * Caches any message sent to the intercepted channel
+ */
+public class CachingChannelInterceptor implements ChannelInterceptor {
+
+ private final BlockingQueue> messages = new ArrayBlockingQueue<>(100);
+
+ @Override
+ public Message> preSend(@NotNull Message> message, @NotNull MessageChannel channel) {
+ this.messages.add(message);
+ return message;
+ }
+
+ public void reset() {
+ this.messages.clear();
+ }
+
+ public List> getMessages() {
+ return List.copyOf(messages);
+ }
+}
diff --git a/src/test/java/cz/cvut/kbss/termit/websocket/util/ReturnValueCollectingSimpMessagingTemplate.java b/src/test/java/cz/cvut/kbss/termit/websocket/util/ReturnValueCollectingSimpMessagingTemplate.java
new file mode 100644
index 000000000..6a9dd91a5
--- /dev/null
+++ b/src/test/java/cz/cvut/kbss/termit/websocket/util/ReturnValueCollectingSimpMessagingTemplate.java
@@ -0,0 +1,43 @@
+package cz.cvut.kbss.termit.websocket.util;
+
+import org.jetbrains.annotations.NotNull;
+import org.springframework.messaging.Message;
+import org.springframework.messaging.MessageChannel;
+import org.springframework.messaging.core.MessagePostProcessor;
+import org.springframework.messaging.simp.SimpMessagingTemplate;
+import org.springframework.messaging.support.MessageBuilder;
+
+import java.util.Map;
+import java.util.UUID;
+
+/**
+ * Intercepts doConvert method and caches the returned payload before conversion
+ * mapped by resulting message id.
+ * Allows reading raw-returned values before serialization.
+ */
+public class ReturnValueCollectingSimpMessagingTemplate extends SimpMessagingTemplate {
+
+ public static final String MESSAGE_IDENTIFIER_HEADER = "test-message-id";
+
+ private final Map returnedValuesMap;
+
+ public ReturnValueCollectingSimpMessagingTemplate(MessageChannel messageChannel,
+ Map returnedValuesMap) {
+ super(messageChannel);
+ this.returnedValuesMap = returnedValuesMap;
+ }
+
+ @Override
+ protected @NotNull Message> doConvert(@NotNull Object payload, Map headers,
+ MessagePostProcessor postProcessor) {
+ final Message> converted = super.doConvert(payload, headers, postProcessor);
+
+ UUID id = converted.getHeaders().getId();
+ if (id == null) {
+ id = UUID.randomUUID();
+ }
+
+ returnedValuesMap.put(id, payload);
+ return MessageBuilder.fromMessage(converted).copyHeaders(Map.of(MESSAGE_IDENTIFIER_HEADER, id)).build();
+ }
+}