diff --git a/pom.xml b/pom.xml index 0841e4f4f..62ee8fd44 100644 --- a/pom.xml +++ b/pom.xml @@ -149,6 +149,14 @@ org.springframework.boot spring-boot-starter-oauth2-resource-server + + org.springframework.boot + spring-boot-starter-websocket + + + org.springframework.security + spring-security-messaging + org.springframework.data spring-data-commons diff --git a/src/main/java/cz/cvut/kbss/termit/config/SecurityConfig.java b/src/main/java/cz/cvut/kbss/termit/config/SecurityConfig.java index 85904dce6..6e17751a6 100644 --- a/src/main/java/cz/cvut/kbss/termit/config/SecurityConfig.java +++ b/src/main/java/cz/cvut/kbss/termit/config/SecurityConfig.java @@ -23,26 +23,36 @@ import cz.cvut.kbss.termit.security.JwtAuthorizationFilter; import cz.cvut.kbss.termit.security.JwtUtils; import cz.cvut.kbss.termit.security.SecurityConstants; +import cz.cvut.kbss.termit.security.WebSocketJwtAuthorizationInterceptor; import cz.cvut.kbss.termit.service.security.TermItUserDetailsService; import cz.cvut.kbss.termit.util.Constants; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Scope; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; +import org.springframework.messaging.Message; +import org.springframework.messaging.simp.SimpMessageType; +import org.springframework.messaging.simp.annotation.support.SimpAnnotationMethodMessageHandler; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.authentication.AuthenticationProvider; +import org.springframework.security.authorization.AuthorizationManager; import org.springframework.security.config.annotation.authentication.builders.AuthenticationManagerBuilder; import org.springframework.security.config.annotation.method.configuration.EnableMethodSecurity; import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer; +import org.springframework.security.config.annotation.web.socket.EnableWebSocketSecurity; +import org.springframework.security.messaging.access.intercept.MessageMatcherDelegatingAuthorizationManager; import org.springframework.security.web.SecurityFilterChain; import org.springframework.security.web.authentication.AuthenticationFailureHandler; import org.springframework.security.web.authentication.HttpStatusEntryPoint; +import org.springframework.util.AntPathMatcher; import org.springframework.web.cors.CorsConfiguration; import org.springframework.web.cors.CorsConfigurationSource; import org.springframework.web.cors.UrlBasedCorsConfigurationSource; @@ -143,4 +153,33 @@ protected static CorsConfigurationSource createCorsConfiguration( source.registerCorsConfiguration("/**", corsConfiguration); return source; } + + /** + * Part of {@link EnableWebSocketSecurity @EnableWebSocketSecurity} replacement + * @see WebSocketConfig + */ + @Bean + @Scope("prototype") + public MessageMatcherDelegatingAuthorizationManager.Builder messageAuthorizationManagerBuilder( + ApplicationContext context) { + return MessageMatcherDelegatingAuthorizationManager.builder().simpDestPathMatcher( + () -> (context.getBeanNamesForType(SimpAnnotationMethodMessageHandler.class).length > 0) + ? context.getBean(SimpAnnotationMethodMessageHandler.class).getPathMatcher() + : new AntPathMatcher()); + } + + /** + * WebSocket endpoint authorization + */ + @Bean + public AuthorizationManager> messageAuthorizationManager( + MessageMatcherDelegatingAuthorizationManager.Builder messages) { + return messages.simpTypeMatchers(SimpMessageType.DISCONNECT).permitAll() + .anyMessage().authenticated().build(); + } + + @Bean + public WebSocketJwtAuthorizationInterceptor webSocketJwtAuthorizationInterceptor() { + return new WebSocketJwtAuthorizationInterceptor(jwtUtils, userDetailsService); + } } diff --git a/src/main/java/cz/cvut/kbss/termit/config/WebAppConfig.java b/src/main/java/cz/cvut/kbss/termit/config/WebAppConfig.java index da01d8ece..61781c11a 100644 --- a/src/main/java/cz/cvut/kbss/termit/config/WebAppConfig.java +++ b/src/main/java/cz/cvut/kbss/termit/config/WebAppConfig.java @@ -153,12 +153,12 @@ public SimpleUrlHandlerMapping sparqlQueryControllerMapping() throws Exception { } @Bean - public HttpMessageConverter stringMessageConverter() { + public HttpMessageConverter termitStringHttpMessageConverter() { return new StringHttpMessageConverter(StandardCharsets.UTF_8); } @Bean - public HttpMessageConverter jsonLdMessageConverter() { + public HttpMessageConverter termitJsonLdHttpMessageConverter() { final MappingJackson2HttpMessageConverter converter = new MappingJackson2HttpMessageConverter( jsonLdObjectMapper()); converter.setSupportedMediaTypes(Collections.singletonList(MediaType.valueOf(JsonLd.MEDIA_TYPE))); @@ -166,14 +166,14 @@ public HttpMessageConverter jsonLdMessageConverter() { } @Bean - public HttpMessageConverter jsonMessageConverter() { + public HttpMessageConverter termitJsonHttpMessageConverter() { final MappingJackson2HttpMessageConverter converter = new MappingJackson2HttpMessageConverter(); converter.setObjectMapper(objectMapper()); return converter; } @Bean - public HttpMessageConverter resourceMessageConverter() { + public HttpMessageConverter termitResourceHttpMessageConverter() { return new ResourceHttpMessageConverter(); } diff --git a/src/main/java/cz/cvut/kbss/termit/config/WebSocketConfig.java b/src/main/java/cz/cvut/kbss/termit/config/WebSocketConfig.java new file mode 100644 index 000000000..a0ad16d27 --- /dev/null +++ b/src/main/java/cz/cvut/kbss/termit/config/WebSocketConfig.java @@ -0,0 +1,136 @@ +package cz.cvut.kbss.termit.config; + +import com.fasterxml.jackson.databind.ObjectMapper; +import cz.cvut.kbss.termit.security.WebSocketJwtAuthorizationInterceptor; +import cz.cvut.kbss.termit.util.Constants; +import cz.cvut.kbss.termit.websocket.handler.StompExceptionHandler; +import cz.cvut.kbss.termit.websocket.handler.WebSocketMessageWithHeadersValueHandler; +import org.jetbrains.annotations.NotNull; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Qualifier; +import org.springframework.context.ApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Lazy; +import org.springframework.core.Ordered; +import org.springframework.core.annotation.Order; +import org.springframework.messaging.Message; +import org.springframework.messaging.converter.MappingJackson2MessageConverter; +import org.springframework.messaging.converter.MessageConverter; +import org.springframework.messaging.converter.StringMessageConverter; +import org.springframework.messaging.handler.invocation.HandlerMethodArgumentResolver; +import org.springframework.messaging.handler.invocation.HandlerMethodReturnValueHandler; +import org.springframework.messaging.simp.SimpMessagingTemplate; +import org.springframework.messaging.simp.config.ChannelRegistration; +import org.springframework.messaging.simp.config.MessageBrokerRegistry; +import org.springframework.security.authorization.AuthorizationManager; +import org.springframework.security.authorization.SpringAuthorizationEventPublisher; +import org.springframework.security.config.annotation.web.socket.EnableWebSocketSecurity; +import org.springframework.security.messaging.access.intercept.AuthorizationChannelInterceptor; +import org.springframework.security.messaging.context.AuthenticationPrincipalArgumentResolver; +import org.springframework.security.messaging.context.SecurityContextChannelInterceptor; +import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker; +import org.springframework.web.socket.config.annotation.StompEndpointRegistry; +import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer; +import org.springframework.web.socket.config.annotation.WebSocketTransportRegistration; + +import java.nio.charset.StandardCharsets; +import java.util.List; + +/* +We are not using @EnableWebSocketSecurity +it automatically requires CSRF which cannot be configured (disabled) at the moment +(will probably change in the future) +*/ +@Configuration +@EnableWebSocketMessageBroker +@Order(Ordered.HIGHEST_PRECEDENCE + 99) // ensures priority above Spring Security +public class WebSocketConfig implements WebSocketMessageBrokerConfigurer { + + private final cz.cvut.kbss.termit.util.Configuration configuration; + + private final ApplicationContext context; + + private final AuthorizationManager> messageAuthorizationManager; + + private final WebSocketJwtAuthorizationInterceptor jwtAuthorizationInterceptor; + + private final ObjectMapper jsonLdMapper; + + private final SimpMessagingTemplate simpMessagingTemplate; + + @Autowired + public WebSocketConfig(cz.cvut.kbss.termit.util.Configuration configuration, ApplicationContext context, + AuthorizationManager> messageAuthorizationManager, + WebSocketJwtAuthorizationInterceptor jwtAuthorizationInterceptor, + @Qualifier("jsonLdMapper") ObjectMapper jsonLdMapper, + @Lazy SimpMessagingTemplate simpMessagingTemplate) { + this.configuration = configuration; + this.context = context; + this.messageAuthorizationManager = messageAuthorizationManager; + this.jwtAuthorizationInterceptor = jwtAuthorizationInterceptor; + this.jsonLdMapper = jsonLdMapper; + this.simpMessagingTemplate = simpMessagingTemplate; + } + + /** + * WebSocket security setup (replaces {@link EnableWebSocketSecurity @EnableWebSocketSecurity}) + */ + @Override + public void addArgumentResolvers(List argumentResolvers) { + AuthenticationPrincipalArgumentResolver resolver = new AuthenticationPrincipalArgumentResolver(); + argumentResolvers.add(resolver); + } + + /** + * WebSocket security setup (replaces {@link EnableWebSocketSecurity @EnableWebSocketSecurity}) + * @see Spring security source + */ + @Override + public void configureClientInboundChannel(@NotNull ChannelRegistration registration) { + AuthorizationChannelInterceptor interceptor = new AuthorizationChannelInterceptor(this.messageAuthorizationManager); + interceptor.setAuthorizationEventPublisher(new SpringAuthorizationEventPublisher(this.context)); + registration.interceptors(jwtAuthorizationInterceptor, new SecurityContextChannelInterceptor(), interceptor); + } + + @Override + public void addReturnValueHandlers(List returnValueHandlers) { + returnValueHandlers.add(new WebSocketMessageWithHeadersValueHandler(simpMessagingTemplate)); + } + + @Override + public void registerStompEndpoints(StompEndpointRegistry registry) { + registry.addEndpoint("/ws").setAllowedOrigins(configuration.getCors().getAllowedOrigins().split(",")); + registry.setErrorHandler(new StompExceptionHandler()); + } + + @Override + public void configureMessageBroker(MessageBrokerRegistry registry) { + registry.setApplicationDestinationPrefixes("/") + .setUserDestinationPrefix("/user"); + } + + @Override + public void configureWebSocketTransport(WebSocketTransportRegistration registry) { + registry.setTimeToFirstMessage(Constants.WEBSOCKET_TIME_TO_FIRST_MESSAGE); + registry.setSendBufferSizeLimit(Constants.WEBSOCKET_SEND_BUFFER_SIZE_LIMIT); + } + + @Override + public boolean configureMessageConverters(List messageConverters) { + messageConverters.add(termitJsonLdMessageConverter()); + messageConverters.add(termitStringMessageConverter()); + return false; // do not add default converters + } + + @Bean + public MessageConverter termitStringMessageConverter() { + return new StringMessageConverter(StandardCharsets.UTF_8); + } + + @Bean + public MessageConverter termitJsonLdMessageConverter() { + return new MappingJackson2MessageConverter(jsonLdMapper); + } + +} diff --git a/src/main/java/cz/cvut/kbss/termit/rest/VocabularyController.java b/src/main/java/cz/cvut/kbss/termit/rest/VocabularyController.java index 454f457e2..881e6b71b 100644 --- a/src/main/java/cz/cvut/kbss/termit/rest/VocabularyController.java +++ b/src/main/java/cz/cvut/kbss/termit/rest/VocabularyController.java @@ -27,7 +27,6 @@ import cz.cvut.kbss.termit.model.acl.AccessControlRecord; import cz.cvut.kbss.termit.model.acl.AccessLevel; import cz.cvut.kbss.termit.model.changetracking.AbstractChangeRecord; -import cz.cvut.kbss.termit.model.validation.ValidationResult; import cz.cvut.kbss.termit.rest.doc.ApiDocConstants; import cz.cvut.kbss.termit.rest.util.RestUtils; import cz.cvut.kbss.termit.security.SecurityConstants; @@ -413,27 +412,6 @@ public List termsRelations(@Parameter(description = ApiDoc.ID_LOC return vocabularyService.getTermRelations(vocabulary); } - @Operation(description = "Validates the terms in a vocabulary with the specified identifier.") - @ApiResponses({ - @ApiResponse(responseCode = "200", description = "A collection of validation results."), - @ApiResponse(responseCode = "404", description = ApiDoc.ID_NOT_FOUND_DESCRIPTION) - }) - @PreAuthorize("permitAll()") // TODO Authorize? - @GetMapping(value = "/{localName}/validate", - produces = {MediaType.APPLICATION_JSON_VALUE, JsonLd.MEDIA_TYPE}) - public List validateVocabulary( - @Parameter(description = ApiDoc.ID_LOCAL_NAME_DESCRIPTION, - example = ApiDoc.ID_LOCAL_NAME_EXAMPLE) - @PathVariable String localName, - @Parameter(description = ApiDoc.ID_NAMESPACE_DESCRIPTION, - example = ApiDoc.ID_NAMESPACE_EXAMPLE) - @RequestParam(name = QueryParams.NAMESPACE, - required = false) Optional namespace) { - final URI identifier = resolveIdentifier(namespace.orElse(config.getNamespace().getVocabulary()), localName); - final Vocabulary vocabulary = vocabularyService.getReference(identifier); - return vocabularyService.validateContents(vocabulary); - } - @Operation(security = {@SecurityRequirement(name = "bearer-key")}, description = "Creates a snapshot of the vocabulary with the specified identifier.") @ApiResponses({ diff --git a/src/main/java/cz/cvut/kbss/termit/rest/handler/RestExceptionHandler.java b/src/main/java/cz/cvut/kbss/termit/rest/handler/RestExceptionHandler.java index c17a253d1..90eb63cac 100644 --- a/src/main/java/cz/cvut/kbss/termit/rest/handler/RestExceptionHandler.java +++ b/src/main/java/cz/cvut/kbss/termit/rest/handler/RestExceptionHandler.java @@ -56,6 +56,7 @@ * The general pattern should be that unless an exception can be handled in a more appropriate place it bubbles up to a * REST controller which originally received the request. There, it is caught by this handler, logged and a reasonable * error message is returned to the user. + * @implSpec Should reflect {@link cz.cvut.kbss.termit.websocket.handler.WebSocketExceptionHandler} */ @RestControllerAdvice public class RestExceptionHandler { diff --git a/src/main/java/cz/cvut/kbss/termit/security/WebSocketJwtAuthorizationInterceptor.java b/src/main/java/cz/cvut/kbss/termit/security/WebSocketJwtAuthorizationInterceptor.java new file mode 100644 index 000000000..71e5627de --- /dev/null +++ b/src/main/java/cz/cvut/kbss/termit/security/WebSocketJwtAuthorizationInterceptor.java @@ -0,0 +1,65 @@ +package cz.cvut.kbss.termit.security; + +import cz.cvut.kbss.termit.exception.AuthorizationException; +import cz.cvut.kbss.termit.exception.JwtException; +import cz.cvut.kbss.termit.security.model.TermItUserDetails; +import cz.cvut.kbss.termit.service.security.SecurityUtils; +import cz.cvut.kbss.termit.service.security.TermItUserDetailsService; +import org.jetbrains.annotations.NotNull; +import org.springframework.http.HttpHeaders; +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.simp.stomp.StompCommand; +import org.springframework.messaging.simp.stomp.StompHeaderAccessor; +import org.springframework.messaging.support.ChannelInterceptor; +import org.springframework.messaging.support.MessageHeaderAccessor; +import org.springframework.security.authentication.DisabledException; +import org.springframework.security.authentication.LockedException; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.userdetails.UsernameNotFoundException; + +/** + * Authorizes STOMP CONNECT messages + *

+ * Retrieves token from the {@code Authorization} header of STOMP message and validates JWT token. + */ +public class WebSocketJwtAuthorizationInterceptor implements ChannelInterceptor { + + private final JwtUtils jwtUtils; + + private final TermItUserDetailsService userDetailsService; + + public WebSocketJwtAuthorizationInterceptor(JwtUtils jwtUtils, TermItUserDetailsService userDetailsService) { + this.jwtUtils = jwtUtils; + this.userDetailsService = userDetailsService; + } + + @Override + public Message preSend(@NotNull Message message, @NotNull MessageChannel channel) { + StompHeaderAccessor headerAccessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class); + if (headerAccessor != null && StompCommand.CONNECT.equals(headerAccessor.getCommand()) && headerAccessor.isMutable()) { + final String authHeader = headerAccessor.getFirstNativeHeader(HttpHeaders.AUTHORIZATION); + if (authHeader != null && authHeader.startsWith(SecurityConstants.JWT_TOKEN_PREFIX)) { + headerAccessor.removeNativeHeader(HttpHeaders.AUTHORIZATION); + return process(message, authHeader, headerAccessor); + } + throw new AuthorizationException("Authorization header is invalid"); + } + return message; + } + + private Message process(final @NotNull Message message, final @NotNull String authHeader, + final @NotNull StompHeaderAccessor headerAccessor) { + final String authToken = authHeader.substring(SecurityConstants.JWT_TOKEN_PREFIX.length()); + try { + final TermItUserDetails userDetails = jwtUtils.extractUserInfo(authToken); + final TermItUserDetails existingDetails = userDetailsService.loadUserByUsername(userDetails.getUsername()); + SecurityUtils.verifyAccountStatus(existingDetails.getUser()); + Authentication authentication = SecurityUtils.setCurrentUser(existingDetails); + headerAccessor.setUser(authentication); + return message; + } catch (JwtException | DisabledException | LockedException | UsernameNotFoundException e) { + throw new AuthorizationException(e.getMessage()); + } + } +} diff --git a/src/main/java/cz/cvut/kbss/termit/util/Constants.java b/src/main/java/cz/cvut/kbss/termit/util/Constants.java index d46203655..fb0959d8f 100644 --- a/src/main/java/cz/cvut/kbss/termit/util/Constants.java +++ b/src/main/java/cz/cvut/kbss/termit/util/Constants.java @@ -243,4 +243,15 @@ private QueryParams() { throw new AssertionError(); } } + + /** + * the maximum amount of data to buffer when sending messages to a WebSocket session + */ + public static final int WEBSOCKET_SEND_BUFFER_SIZE_LIMIT = Integer.MAX_VALUE; + + /** + * Set the maximum time allowed in milliseconds after the WebSocket connection is established + * and before the first sub-protocol message is received. + */ + public static final int WEBSOCKET_TIME_TO_FIRST_MESSAGE = 15 * 1000 /* 15s */; } diff --git a/src/main/java/cz/cvut/kbss/termit/websocket/ResultWithHeaders.java b/src/main/java/cz/cvut/kbss/termit/websocket/ResultWithHeaders.java new file mode 100644 index 000000000..1ec5c0ef6 --- /dev/null +++ b/src/main/java/cz/cvut/kbss/termit/websocket/ResultWithHeaders.java @@ -0,0 +1,68 @@ +package cz.cvut.kbss.termit.websocket; + +import cz.cvut.kbss.termit.websocket.handler.WebSocketMessageWithHeadersValueHandler; +import org.springframework.messaging.handler.invocation.HandlerMethodReturnValueHandler; +import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; +import org.springframework.messaging.handler.annotation.SendTo; +import org.springframework.messaging.simp.annotation.SendToUser; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +/** + * Wrapper carrying a result from WebSocket controller + * including the {@link #payload}, {@link #destination} and {@link #headers} for the resulting message. + *

+ * Do not combine with other method-return-value handlers (like {@link SendTo @SendTo}) + *

+ * The {@code ResultWithHeaders} is then handled by {@link WebSocketMessageWithHeadersValueHandler}. + * Every value returned from a controller method + * can be handled only by a single {@link HandlerMethodReturnValueHandler}. + * Annotations like {@link SendTo @SendTo}/{@link SendToUser @SendToUser} + * are handled by separate return value handlers, so only one can be used simultaneously. + * + * @param payload The actual result of the method + * @param destination The destination channel where the message will be sent + * @param headers Headers that will overwrite headers in the message. + * @param The type of the payload + * @see WebSocketMessageWithHeadersValueHandler + * @see HandlerMethodReturnValueHandler + */ +public record ResultWithHeaders(T payload, @NotNull String destination, @NotNull Map headers, + boolean toUser) { + + public static ResultWithHeadersBuilder result(T payload) { + return new ResultWithHeadersBuilder<>(payload); + } + + public static class ResultWithHeadersBuilder { + + private final T payload; + + private @Nullable Map headers = null; + + private ResultWithHeadersBuilder(T payload) { + this.payload = payload; + } + + /** + * All values will be mapped to strings with {@link Object#toString()} + */ + public ResultWithHeadersBuilder withHeaders(@NotNull Map headers) { + this.headers = new HashMap<>(); + headers.forEach((key, value) -> this.headers.put(key, value.toString())); + this.headers = Collections.unmodifiableMap(this.headers); + return this; + } + + public ResultWithHeaders sendTo(String destination) { + return new ResultWithHeaders<>(payload, destination, headers == null ? Map.of() : headers, false); + } + + public ResultWithHeaders sendToUser(String userDestination) { + return new ResultWithHeaders<>(payload, userDestination, headers == null ? Map.of() : headers, true); + } + } +} diff --git a/src/main/java/cz/cvut/kbss/termit/websocket/VocabularySocketController.java b/src/main/java/cz/cvut/kbss/termit/websocket/VocabularySocketController.java new file mode 100644 index 000000000..67c1af291 --- /dev/null +++ b/src/main/java/cz/cvut/kbss/termit/websocket/VocabularySocketController.java @@ -0,0 +1,51 @@ +package cz.cvut.kbss.termit.websocket; + +import cz.cvut.kbss.termit.model.Vocabulary; +import cz.cvut.kbss.termit.model.validation.ValidationResult; +import cz.cvut.kbss.termit.rest.BaseController; +import cz.cvut.kbss.termit.security.SecurityConstants; +import cz.cvut.kbss.termit.service.IdentifierResolver; +import cz.cvut.kbss.termit.service.business.VocabularyService; +import cz.cvut.kbss.termit.util.Configuration; +import cz.cvut.kbss.termit.util.Constants; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.messaging.handler.annotation.DestinationVariable; +import org.springframework.messaging.handler.annotation.Header; +import org.springframework.messaging.handler.annotation.MessageMapping; +import org.springframework.security.access.prepost.PreAuthorize; +import org.springframework.stereotype.Controller; + +import java.net.URI; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static cz.cvut.kbss.termit.websocket.ResultWithHeaders.result; + +@Controller +@MessageMapping("/vocabularies") +@PreAuthorize("hasRole('" + SecurityConstants.ROLE_RESTRICTED_USER + "')") +public class VocabularySocketController extends BaseController { + + private final VocabularyService vocabularyService; + + protected VocabularySocketController(IdentifierResolver idResolver, Configuration config, + VocabularyService vocabularyService) { + super(idResolver, config); + this.vocabularyService = vocabularyService; + } + + /** + * Validates the terms in a vocabulary with the specified identifier. + */ + @MessageMapping("/{localName}/validate") + public ResultWithHeaders> validateVocabulary(@DestinationVariable String localName, + @Header(name = Constants.QueryParams.NAMESPACE, + required = false) Optional namespace) { + final URI identifier = resolveIdentifier(namespace.orElse(config.getNamespace().getVocabulary()), localName); + final Vocabulary vocabulary = vocabularyService.getReference(identifier); + return result(vocabularyService.validateContents(vocabulary)).withHeaders(Map.of("vocabulary", identifier)) + .sendToUser("/vocabularies/validation"); + } +} diff --git a/src/main/java/cz/cvut/kbss/termit/websocket/handler/StompExceptionHandler.java b/src/main/java/cz/cvut/kbss/termit/websocket/handler/StompExceptionHandler.java new file mode 100644 index 000000000..4f78bf920 --- /dev/null +++ b/src/main/java/cz/cvut/kbss/termit/websocket/handler/StompExceptionHandler.java @@ -0,0 +1,21 @@ +package cz.cvut.kbss.termit.websocket.handler; + +import org.jetbrains.annotations.NotNull; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.messaging.Message; +import org.springframework.messaging.simp.stomp.StompHeaderAccessor; +import org.springframework.web.socket.messaging.StompSubProtocolErrorHandler; + +public class StompExceptionHandler extends StompSubProtocolErrorHandler { + + private static final Logger LOG = LoggerFactory.getLogger(StompExceptionHandler.class); + + @Override + protected @NotNull Message handleInternal(@NotNull StompHeaderAccessor errorHeaderAccessor, + byte @NotNull [] errorPayload, + Throwable cause, StompHeaderAccessor clientHeaderAccessor) { + LOG.error("STOMP sub-protocol exception", cause); + return super.handleInternal(errorHeaderAccessor, errorPayload, cause, clientHeaderAccessor); + } +} diff --git a/src/main/java/cz/cvut/kbss/termit/websocket/handler/WebSocketExceptionHandler.java b/src/main/java/cz/cvut/kbss/termit/websocket/handler/WebSocketExceptionHandler.java new file mode 100644 index 000000000..50411c650 --- /dev/null +++ b/src/main/java/cz/cvut/kbss/termit/websocket/handler/WebSocketExceptionHandler.java @@ -0,0 +1,216 @@ +package cz.cvut.kbss.termit.websocket.handler; + +import cz.cvut.kbss.jopa.exceptions.EntityNotFoundException; +import cz.cvut.kbss.jopa.exceptions.OWLPersistenceException; +import cz.cvut.kbss.jsonld.exception.JsonLdException; +import cz.cvut.kbss.termit.exception.AnnotationGenerationException; +import cz.cvut.kbss.termit.exception.AssetRemovalException; +import cz.cvut.kbss.termit.exception.AuthorizationException; +import cz.cvut.kbss.termit.exception.InvalidLanguageConstantException; +import cz.cvut.kbss.termit.exception.InvalidParameterException; +import cz.cvut.kbss.termit.exception.InvalidPasswordChangeRequestException; +import cz.cvut.kbss.termit.exception.InvalidTermStateException; +import cz.cvut.kbss.termit.exception.NotFoundException; +import cz.cvut.kbss.termit.exception.PersistenceException; +import cz.cvut.kbss.termit.exception.ResourceExistsException; +import cz.cvut.kbss.termit.exception.SnapshotNotEditableException; +import cz.cvut.kbss.termit.exception.SuppressibleLogging; +import cz.cvut.kbss.termit.exception.TermItException; +import cz.cvut.kbss.termit.exception.UnsupportedOperationException; +import cz.cvut.kbss.termit.exception.UnsupportedSearchFacetException; +import cz.cvut.kbss.termit.exception.ValidationException; +import cz.cvut.kbss.termit.exception.WebServiceIntegrationException; +import cz.cvut.kbss.termit.exception.importing.UnsupportedImportMediaTypeException; +import cz.cvut.kbss.termit.exception.importing.VocabularyImportException; +import cz.cvut.kbss.termit.rest.handler.ErrorInfo; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageDeliveryException; +import org.springframework.messaging.handler.annotation.MessageExceptionHandler; +import org.springframework.messaging.simp.annotation.SendToUser; +import org.springframework.messaging.simp.stomp.StompHeaderAccessor; +import org.springframework.security.core.userdetails.UsernameNotFoundException; +import org.springframework.web.bind.annotation.ControllerAdvice; +import org.springframework.web.multipart.MaxUploadSizeExceededException; + +/** + * @implSpec Should reflect {@link cz.cvut.kbss.termit.rest.handler.RestExceptionHandler} + */ +@SendToUser +@ControllerAdvice +public class WebSocketExceptionHandler { + + private static final Logger LOG = LoggerFactory.getLogger(WebSocketExceptionHandler.class); + + private static String destination(Message message) { + return message.getHeaders().getOrDefault("destination", "missing destination").toString(); + } + + private static void logException(TermItException ex, Message message) { + if (shouldSuppressLogging(ex)) { + return; + } + logException("Exception caught when processing request to '" + destination(message) + "'.", ex); + } + + private static boolean shouldSuppressLogging(TermItException ex) { + return ex.getClass().getAnnotation(SuppressibleLogging.class) != null; + } + + private static void logException(Throwable ex, Message message) { + logException("Exception caught when processing request to '" + destination(message) + "'.", ex); + } + + private static void logException(String message, Throwable ex) { + LOG.error(message, ex); + } + + private static ErrorInfo errorInfo(Message message, Throwable e) { + return ErrorInfo.createWithMessage(e.getMessage(), destination(message)); + } + + @MessageExceptionHandler + public void messageDeliveryException(Message message, MessageDeliveryException e) { + final StompHeaderAccessor headerAccessor = StompHeaderAccessor.wrap(message); + LOG.error("Failed to send message with destination {}: {}", headerAccessor.getDestination(), e.getMessage()); + } + + @MessageExceptionHandler(PersistenceException.class) + public ErrorInfo persistenceException(Message message, PersistenceException e) { + logException(e, message); + return errorInfo(message, e.getCause()); + } + + @MessageExceptionHandler(OWLPersistenceException.class) + public ErrorInfo jopaException(Message message, OWLPersistenceException e) { + logException("Persistence exception caught.", e); + return errorInfo(message, e); + } + + @MessageExceptionHandler(ResourceExistsException.class) + public ErrorInfo resourceExistsException(Message message, ResourceExistsException e) { + logException(e, message); + return errorInfo(message, e); + } + + @MessageExceptionHandler(NotFoundException.class) + public ErrorInfo resourceNotFound(Message message, NotFoundException e) { + // Not necessary to log NotFoundException, they may be quite frequent and do not represent an issue with the application + return errorInfo(message, e); + } + + @MessageExceptionHandler(UsernameNotFoundException.class) + public ErrorInfo usernameNotFound(Message message, UsernameNotFoundException e) { + return errorInfo(message, e); + } + + @MessageExceptionHandler(EntityNotFoundException.class) + public ErrorInfo entityNotFoundException(Message message, EntityNotFoundException e) { + logException(e, message); + return errorInfo(message, e); + } + + @MessageExceptionHandler(AuthorizationException.class) + public ErrorInfo authorizationException(Message message, AuthorizationException e) { + logException(e, message); + return errorInfo(message, e); + } + + @MessageExceptionHandler(ValidationException.class) + public ErrorInfo validationException(Message message, ValidationException e) { + logException(e, message); + return errorInfo(message, e); + } + + @MessageExceptionHandler(WebServiceIntegrationException.class) + public ErrorInfo webServiceIntegrationException(Message message, WebServiceIntegrationException e) { + logException(e.getCause(), message); + return errorInfo(message, e); + } + + @MessageExceptionHandler(AnnotationGenerationException.class) + public ErrorInfo annotationGenerationException(Message message, AnnotationGenerationException e) { + logException(e.getCause(), message); + return errorInfo(message, e); + } + + @MessageExceptionHandler(TermItException.class) + public ErrorInfo termItException(Message message, TermItException e) { + logException(e, message); + return errorInfo(message, e); + } + + @MessageExceptionHandler(JsonLdException.class) + public ErrorInfo jsonLdException(Message message, JsonLdException e) { + logException(e, message); + return ErrorInfo.createWithMessage("Error when processing JSON-LD.", destination(message)); + } + + @MessageExceptionHandler(UnsupportedOperationException.class) + public ErrorInfo unsupportedAssetOperationException(Message message, UnsupportedOperationException e) { + logException(e, message); + return errorInfo(message, e); + } + + @MessageExceptionHandler(VocabularyImportException.class) + public ErrorInfo vocabularyImportException(Message message, VocabularyImportException e) { + logException(e, message); + return ErrorInfo.createWithMessageAndMessageId(e.getMessage(), e.getMessageId(), destination(message)); + } + + @MessageExceptionHandler + public ErrorInfo unsupportedImportMediaTypeException(Message message, UnsupportedImportMediaTypeException e) { + logException(e, message); + return errorInfo(message, e); + } + + @MessageExceptionHandler + public ErrorInfo assetRemovalException(Message message, AssetRemovalException e) { + logException(e, message); + return errorInfo(message, e); + } + + @MessageExceptionHandler + public ErrorInfo invalidParameter(Message message, InvalidParameterException e) { + logException(e, message); + return errorInfo(message, e); + } + + @MessageExceptionHandler + public ErrorInfo maxUploadSizeExceededException(Message message, MaxUploadSizeExceededException e) { + logException(e, message); + return ErrorInfo.createWithMessageAndMessageId(e.getMessage(), "error.file.maxUploadSizeExceeded", destination(message)); + } + + @MessageExceptionHandler + public ErrorInfo snapshotNotEditableException(Message message, SnapshotNotEditableException e) { + logException(e, message); + return ErrorInfo.createWithMessage(e.getMessage(), destination(message)); + } + + @MessageExceptionHandler + public ErrorInfo unsupportedSearchFacetException(Message message, UnsupportedSearchFacetException e) { + logException(e, message); + return errorInfo(message, e); + } + + @MessageExceptionHandler + public ErrorInfo invalidLanguageConstantException(Message message, InvalidLanguageConstantException e) { + logException(e, message); + return errorInfo(message, e); + } + + @MessageExceptionHandler + public ErrorInfo invalidTermStateException(Message message, InvalidTermStateException e) { + logException(e, message); + return ErrorInfo.createWithMessageAndMessageId(e.getMessage(), e.getMessageId(), destination(message)); + } + + @MessageExceptionHandler + public ErrorInfo invalidPasswordChangeRequestException(Message message, + InvalidPasswordChangeRequestException e) { + logException(e, message); + return ErrorInfo.createWithMessageAndMessageId(e.getMessage(), e.getMessageId(), destination(message)); + } +} diff --git a/src/main/java/cz/cvut/kbss/termit/websocket/handler/WebSocketMessageWithHeadersValueHandler.java b/src/main/java/cz/cvut/kbss/termit/websocket/handler/WebSocketMessageWithHeadersValueHandler.java new file mode 100644 index 000000000..5494294f2 --- /dev/null +++ b/src/main/java/cz/cvut/kbss/termit/websocket/handler/WebSocketMessageWithHeadersValueHandler.java @@ -0,0 +1,46 @@ +package cz.cvut.kbss.termit.websocket.handler; + +import cz.cvut.kbss.termit.exception.UnsupportedOperationException; +import cz.cvut.kbss.termit.websocket.ResultWithHeaders; +import org.jetbrains.annotations.NotNull; +import org.springframework.core.MethodParameter; +import org.springframework.messaging.Message; +import org.springframework.messaging.handler.invocation.HandlerMethodReturnValueHandler; +import org.springframework.messaging.simp.SimpMessageHeaderAccessor; +import org.springframework.messaging.simp.SimpMessagingTemplate; +import org.springframework.messaging.simp.annotation.support.MissingSessionUserException; +import org.springframework.messaging.simp.stomp.StompHeaderAccessor; + +public class WebSocketMessageWithHeadersValueHandler implements HandlerMethodReturnValueHandler { + + private final SimpMessagingTemplate simpMessagingTemplate; + + public WebSocketMessageWithHeadersValueHandler(SimpMessagingTemplate simpMessagingTemplate) { + this.simpMessagingTemplate = simpMessagingTemplate; + } + + @Override + public boolean supportsReturnType(MethodParameter returnType) { + return ResultWithHeaders.class.isAssignableFrom(returnType.getParameterType()); + } + + @Override + public void handleReturnValue(Object returnValue, @NotNull MethodParameter returnType, @NotNull Message message) + throws Exception { + if (returnValue instanceof ResultWithHeaders resultWithHeaders) { + final StompHeaderAccessor headerAccessor = StompHeaderAccessor.wrap(message); + resultWithHeaders.headers().forEach(headerAccessor::setNativeHeader); + if (resultWithHeaders.toUser()) { + final String sessionId = SimpMessageHeaderAccessor.getSessionId(headerAccessor.toMessageHeaders()); + if (sessionId == null || sessionId.isBlank()) { + throw new MissingSessionUserException(message); + } + simpMessagingTemplate.convertAndSendToUser(sessionId, resultWithHeaders.destination(), resultWithHeaders.payload(), headerAccessor.toMessageHeaders()); + } else { + simpMessagingTemplate.convertAndSend(resultWithHeaders.destination(), resultWithHeaders.payload(), headerAccessor.toMessageHeaders()); + } + return; + } + throw new UnsupportedOperationException("Unable to process returned value: " + returnValue + " of type " + returnType.getParameterType() + " from " + returnType.getMethod()); + } +} diff --git a/src/test/java/cz/cvut/kbss/termit/environment/Environment.java b/src/test/java/cz/cvut/kbss/termit/environment/Environment.java index e9db7449e..9f7bd62a1 100644 --- a/src/test/java/cz/cvut/kbss/termit/environment/Environment.java +++ b/src/test/java/cz/cvut/kbss/termit/environment/Environment.java @@ -41,6 +41,7 @@ import org.springframework.http.converter.ResourceHttpMessageConverter; import org.springframework.http.converter.StringHttpMessageConverter; import org.springframework.http.converter.json.MappingJackson2HttpMessageConverter; +import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextImpl; @@ -81,11 +82,13 @@ private static DtoMapper initDtoMapper() { * * @param user User to set as currently authenticated */ - public static void setCurrentUser(UserAccount user) { + public static Authentication setCurrentUser(UserAccount user) { final TermItUserDetails userDetails = new TermItUserDetails(user, new HashSet<>()); SecurityContext context = new SecurityContextImpl(); - context.setAuthentication(new AuthenticationToken(userDetails.getAuthorities(), userDetails)); + Authentication authentication = new AuthenticationToken(userDetails.getAuthorities(), userDetails); + context.setAuthentication(authentication); SecurityContextHolder.setContext(context); + return authentication; } /** diff --git a/src/test/java/cz/cvut/kbss/termit/environment/config/TestWebSocketConfig.java b/src/test/java/cz/cvut/kbss/termit/environment/config/TestWebSocketConfig.java new file mode 100644 index 000000000..f3cd62e7d --- /dev/null +++ b/src/test/java/cz/cvut/kbss/termit/environment/config/TestWebSocketConfig.java @@ -0,0 +1,98 @@ +package cz.cvut.kbss.termit.environment.config; + +import cz.cvut.kbss.termit.config.WebAppConfig; +import cz.cvut.kbss.termit.config.WebSocketConfig; +import cz.cvut.kbss.termit.util.Configuration; +import cz.cvut.kbss.termit.websocket.util.ReturnValueCollectingSimpMessagingTemplate; +import org.jetbrains.annotations.NotNull; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.boot.test.context.TestConfiguration; +import org.springframework.context.ApplicationListener; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.ComponentScan; +import org.springframework.context.annotation.Import; +import org.springframework.context.annotation.Lazy; +import org.springframework.context.annotation.Primary; +import org.springframework.context.event.ContextRefreshedEvent; +import org.springframework.core.task.SyncTaskExecutor; +import org.springframework.messaging.MessageHandler; +import org.springframework.messaging.SubscribableChannel; +import org.springframework.messaging.converter.CompositeMessageConverter; +import org.springframework.messaging.simp.SimpMessagingTemplate; +import org.springframework.messaging.simp.annotation.support.SimpAnnotationMethodMessageHandler; +import org.springframework.messaging.simp.config.ChannelRegistration; +import org.springframework.messaging.simp.config.MessageBrokerRegistry; +import org.springframework.messaging.support.AbstractSubscribableChannel; +import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; + +@TestConfiguration +@EnableConfigurationProperties(Configuration.class) +@Import({TestSecurityConfig.class, TestRestSecurityConfig.class, WebAppConfig.class, WebSocketConfig.class}) +@ComponentScan(basePackages = "cz.cvut.kbss.termit.websocket") +public class TestWebSocketConfig + implements ApplicationListener, WebSocketMessageBrokerConfigurer { + + private final List channels; + + private final List handlers; + + @Autowired + @Lazy + public TestWebSocketConfig(List channels, List handlers) { + this.channels = channels; + this.handlers = handlers; + } + + /** + * Unregisters MessageHandler's from the message channels to reduce processing during the test. + * Also stops further processing so for example user responses remain in the broker channel. + * @param event the event to respond to + */ + @Override + public void onApplicationEvent(@NotNull ContextRefreshedEvent event) { + for (MessageHandler handler : handlers) { + if (handler instanceof SimpAnnotationMethodMessageHandler) { + continue; + } + for (SubscribableChannel channel : channels) { + channel.unsubscribe(handler); + } + } + } + + @Override + public void configureClientInboundChannel(ChannelRegistration registration) { + registration.executor(new SyncTaskExecutor()); + } + + @Override + public void configureClientOutboundChannel(ChannelRegistration registration) { + registration.executor(new SyncTaskExecutor()); + } + + @Override + public void configureMessageBroker(MessageBrokerRegistry registry) { + registry.configureBrokerChannel().executor(new SyncTaskExecutor()); + } + + @Bean + public Map returnedValuesMap() { + return new HashMap<>(); + } + + @Bean + @Primary + public SimpMessagingTemplate brokerMessagingTemplate( + AbstractSubscribableChannel brokerChannel, CompositeMessageConverter brokerMessageConverter) { + + SimpMessagingTemplate template = new ReturnValueCollectingSimpMessagingTemplate(brokerChannel, returnedValuesMap()); + template.setMessageConverter(brokerMessageConverter); + return template; + } +} diff --git a/src/test/java/cz/cvut/kbss/termit/rest/VocabularyControllerTest.java b/src/test/java/cz/cvut/kbss/termit/rest/VocabularyControllerTest.java index cccd38e09..76b2e0521 100644 --- a/src/test/java/cz/cvut/kbss/termit/rest/VocabularyControllerTest.java +++ b/src/test/java/cz/cvut/kbss/termit/rest/VocabularyControllerTest.java @@ -467,30 +467,6 @@ void getHistoryOfContentReturnsListOfAggregatedChangeObjectsForTermsInSpecifiedV verify(serviceMock).getChangesOfContent(vocabulary); } - @Test - void validateExecutesServiceValidate() throws Exception { - final Vocabulary vocabulary = generateVocabularyAndInitReferenceResolution(); - final List records = Collections.singletonList(new ValidationResult() - .setTermUri(Generator.generateUri()) - .setIssueCauseUri( - Generator.generateUri()) - .setSeverity(URI.create( - SH.Violation.toString()))); - when(serviceMock.validateContents(vocabulary)).thenReturn(records); - - - final MvcResult mvcResult = mockMvc.perform(get(PATH + "/" + FRAGMENT + "/validate")) - .andExpect(status().isOk()) - .andReturn(); - final List result = - readValue(mvcResult, new TypeReference>() { - }); - assertNotNull(result); - assertEquals(records.stream().map(ValidationResult::getId).collect(Collectors.toList()), - result.stream().map(ValidationResult::getId).collect(Collectors.toList())); - verify(serviceMock).validateContents(vocabulary); - } - private Vocabulary generateVocabularyAndInitReferenceResolution() { final Vocabulary vocabulary = generateVocabulary(); vocabulary.setUri(VOCABULARY_URI); diff --git a/src/test/java/cz/cvut/kbss/termit/service/document/html/HtmlTermOccurrenceResolverTest.java b/src/test/java/cz/cvut/kbss/termit/service/document/html/HtmlTermOccurrenceResolverTest.java index 627a24feb..bf85ea85e 100644 --- a/src/test/java/cz/cvut/kbss/termit/service/document/html/HtmlTermOccurrenceResolverTest.java +++ b/src/test/java/cz/cvut/kbss/termit/service/document/html/HtmlTermOccurrenceResolverTest.java @@ -32,6 +32,8 @@ import org.jsoup.Jsoup; import org.jsoup.select.Elements; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.DisabledOnOs; +import org.junit.jupiter.api.condition.OS; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.InjectMocks; import org.mockito.Mock; @@ -177,6 +179,7 @@ void findTermOccurrencesMarksOccurrencesAsSuggested() { } @Test + @DisabledOnOs(OS.WINDOWS) // TODO: https://github.com/kbss-cvut/termit/issues/275 void findTermOccurrencesSetsFoundOccurrencesAsApprovedWhenCorrespondingExistingOccurrenceWasApproved() throws Exception { when(termService.exists(TERM_URI)).thenReturn(true); final File file = initFile(); diff --git a/src/test/java/cz/cvut/kbss/termit/websocket/BaseWebSocketControllerTestRunner.java b/src/test/java/cz/cvut/kbss/termit/websocket/BaseWebSocketControllerTestRunner.java new file mode 100644 index 000000000..4c4c944d8 --- /dev/null +++ b/src/test/java/cz/cvut/kbss/termit/websocket/BaseWebSocketControllerTestRunner.java @@ -0,0 +1,103 @@ +package cz.cvut.kbss.termit.websocket; + +import cz.cvut.kbss.termit.environment.config.TestWebSocketConfig; +import cz.cvut.kbss.termit.util.Configuration; +import cz.cvut.kbss.termit.websocket.util.CachingChannelInterceptor; +import jakarta.annotation.PostConstruct; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.junit.jupiter.MockitoExtension; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Qualifier; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.boot.test.context.ConfigDataApplicationContextInitializer; +import org.springframework.messaging.Message; +import org.springframework.messaging.support.AbstractSubscribableChannel; +import org.springframework.test.annotation.DirtiesContext; +import org.springframework.test.context.ActiveProfiles; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit.jupiter.SpringExtension; + +import java.util.Map; +import java.util.Optional; +import java.util.UUID; + +import static cz.cvut.kbss.termit.websocket.util.ReturnValueCollectingSimpMessagingTemplate.MESSAGE_IDENTIFIER_HEADER; + +@ActiveProfiles("test") +@ExtendWith(SpringExtension.class) +@ExtendWith(MockitoExtension.class) +@EnableConfigurationProperties({Configuration.class}) +@DirtiesContext(classMode = DirtiesContext.ClassMode.AFTER_CLASS) +@ContextConfiguration(classes = {TestWebSocketConfig.class}, + initializers = {ConfigDataApplicationContextInitializer.class}) +public abstract class BaseWebSocketControllerTestRunner { + + private static final Logger LOG = LoggerFactory.getLogger(BaseWebSocketControllerTestRunner.class); + + /** + * Simulated messages from client to server + */ + @Autowired + @Qualifier("clientInboundChannel") + protected AbstractSubscribableChannel serverInboundChannel; + + /** + * Messages sent from the server to the client + */ + @Autowired + @Qualifier("clientOutboundChannel") + protected AbstractSubscribableChannel serverOutboundChannel; + + @Autowired + protected AbstractSubscribableChannel brokerChannel; + + /** + * Holds message ids mapped to the values returned from the controllers + */ + @Autowired + protected Map returnedValuesMap; + + /** + * Caches any messages sent from the server to the client + */ + protected CachingChannelInterceptor serverOutboundChannelInterceptor; + + protected CachingChannelInterceptor brokerChannelInterceptor; + + @PostConstruct + protected void runnerPostConstruct() { + this.brokerChannelInterceptor = new CachingChannelInterceptor(); + this.serverOutboundChannelInterceptor = new CachingChannelInterceptor(); + + this.brokerChannel.addInterceptor(this.brokerChannelInterceptor); + this.serverOutboundChannel.addInterceptor(this.serverOutboundChannelInterceptor); + } + + @BeforeEach + protected void runnerBeforeEach() { + this.serverOutboundChannelInterceptor.reset(); + this.brokerChannelInterceptor.reset(); + this.returnedValuesMap.clear(); + } + + /** + * Returns result of controller method associated with the specified message. + * + * @param message The message sent from some controller + */ + @SuppressWarnings("unchecked") + protected Optional readPayload(Message message) throws ClassCastException { + final UUID id = message.getHeaders().get(MESSAGE_IDENTIFIER_HEADER, UUID.class); + if (id == null) { + LOG.error("Unable to read message payload. Message id is null."); + return Optional.empty(); + } + if (returnedValuesMap.containsKey(id)) { + return Optional.of((T) returnedValuesMap.get(id)); + } + return Optional.empty(); + } +} diff --git a/src/test/java/cz/cvut/kbss/termit/websocket/BaseWebSocketIntegrationTestRunner.java b/src/test/java/cz/cvut/kbss/termit/websocket/BaseWebSocketIntegrationTestRunner.java new file mode 100644 index 000000000..59ba7d502 --- /dev/null +++ b/src/test/java/cz/cvut/kbss/termit/websocket/BaseWebSocketIntegrationTestRunner.java @@ -0,0 +1,169 @@ +package cz.cvut.kbss.termit.websocket; + +import cz.cvut.kbss.termit.config.AppConfig; +import cz.cvut.kbss.termit.config.SecurityConfig; +import cz.cvut.kbss.termit.config.WebAppConfig; +import cz.cvut.kbss.termit.config.WebSocketConfig; +import cz.cvut.kbss.termit.environment.Generator; +import cz.cvut.kbss.termit.environment.config.TestConfig; +import cz.cvut.kbss.termit.environment.config.TestPersistenceConfig; +import cz.cvut.kbss.termit.environment.config.TestSecurityConfig; +import cz.cvut.kbss.termit.environment.config.TestServiceConfig; +import cz.cvut.kbss.termit.security.JwtUtils; +import cz.cvut.kbss.termit.security.model.TermItUserDetails; +import cz.cvut.kbss.termit.service.security.TermItUserDetailsService; +import cz.cvut.kbss.termit.util.Configuration; +import org.jetbrains.annotations.NotNull; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Answers; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.boot.autoconfigure.AutoConfigureOrder; +import org.springframework.boot.autoconfigure.EnableAutoConfiguration; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.boot.test.context.ConfigDataApplicationContextInitializer; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.boot.test.mock.mockito.MockBean; +import org.springframework.boot.test.mock.mockito.SpyBean; +import org.springframework.context.annotation.ComponentScan; +import org.springframework.context.annotation.EnableAspectJAutoProxy; +import org.springframework.context.annotation.aspectj.EnableSpringConfigured; +import org.springframework.messaging.simp.stomp.StompCommand; +import org.springframework.messaging.simp.stomp.StompHeaders; +import org.springframework.messaging.simp.stomp.StompSession; +import org.springframework.messaging.simp.stomp.StompSessionHandlerAdapter; +import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; +import org.springframework.test.annotation.DirtiesContext; +import org.springframework.test.context.ActiveProfiles; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.transaction.annotation.EnableTransactionManagement; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketMessage; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.client.standard.StandardWebSocketClient; +import org.springframework.web.socket.messaging.WebSocketStompClient; + +import java.util.concurrent.Future; +import java.util.concurrent.atomic.AtomicReference; + +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.when; + +@ActiveProfiles("test") +@EnableSpringConfigured +@EnableAutoConfiguration +@EnableTransactionManagement +@ExtendWith(MockitoExtension.class) +@EnableAspectJAutoProxy(proxyTargetClass = true) +@EnableConfigurationProperties({Configuration.class}) +@ContextConfiguration( + classes = {TestConfig.class, TestPersistenceConfig.class, TestConfig.class, + TestServiceConfig.class, AppConfig.class, SecurityConfig.class, WebAppConfig.class, WebSocketConfig.class}, + initializers = {ConfigDataApplicationContextInitializer.class}) +@ComponentScan("cz.cvut.kbss.termit.security") +@DirtiesContext(classMode = DirtiesContext.ClassMode.AFTER_CLASS) +@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT) +public abstract class BaseWebSocketIntegrationTestRunner { + + protected Logger LOG = LoggerFactory.getLogger(this.getClass()); + + protected WebSocketStompClient stompClient; + + @Value("ws://localhost:${local.server.port}/ws") + protected String url; + + @SpyBean + protected TermItUserDetailsService userDetailsService; + + @SpyBean + protected JwtUtils jwtUtils; + + protected TermItUserDetails userDetails; + + protected Future connect(StompSessionHandlerAdapter sessionHandler) { + return stompClient.connectAsync(url, sessionHandler); + } + + protected String generateToken() { + return jwtUtils.generateToken(userDetails.getUser(), userDetails.getAuthorities()); + } + + @BeforeEach + void runnerSetup() { + stompClient = new WebSocketStompClient(new StandardWebSocketClient()); + + userDetails = new TermItUserDetails(Generator.generateUserAccountWithPassword()); + doReturn(userDetails).when(userDetailsService).loadUserByUsername(userDetails.getUsername()); + } + + protected class TestWebSocketSessionHandler implements WebSocketHandler { + + @Override + public void afterConnectionEstablished(WebSocketSession session) throws Exception { + LOG.info("WebSocket connection established"); + } + + @Override + public void handleMessage(WebSocketSession session, WebSocketMessage message) throws Exception { + LOG.info("WebSocket message received: {}", message.getPayload()); + } + + @Override + public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception { + LOG.error("WebSocket transport error", exception); + } + + @Override + public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception { + LOG.info("WebSocket connection closed"); + } + + @Override + public boolean supportsPartialMessages() { + return false; + } + } + + protected class TestStompSessionHandler extends StompSessionHandlerAdapter { + + private final AtomicReference exception = new AtomicReference<>(); + + @Override + public void afterConnected(StompSession session, StompHeaders connectedHeaders) { + super.afterConnected(session, connectedHeaders); + LOG.info("STOMP session connected"); + } + + @Override + public void handleFrame(@NotNull StompHeaders headers, Object payload) { + super.handleFrame(headers, payload); + exception.set(new Exception(headers.toString())); + LOG.error("STOMP frame: {}", headers); + } + + @Override + public void handleException(@NotNull StompSession session, StompCommand command, @NotNull StompHeaders headers, + byte @NotNull [] payload, @NotNull Throwable exception) { + super.handleException(session, command, headers, payload, exception); + this.exception.set(exception); + LOG.error("STOMP exception", exception); + } + + @Override + public void handleTransportError(@NotNull StompSession session, @NotNull Throwable exception) { + super.handleTransportError(session, exception); + this.exception.set(exception); + LOG.error("STOMP transport error", exception); + } + + public Throwable getException() { + return exception.get(); + } + } +} diff --git a/src/test/java/cz/cvut/kbss/termit/websocket/IntegrationWebSocketSecurityTest.java b/src/test/java/cz/cvut/kbss/termit/websocket/IntegrationWebSocketSecurityTest.java new file mode 100644 index 000000000..b3bbcc043 --- /dev/null +++ b/src/test/java/cz/cvut/kbss/termit/websocket/IntegrationWebSocketSecurityTest.java @@ -0,0 +1,183 @@ +package cz.cvut.kbss.termit.websocket; + +import com.fasterxml.jackson.databind.ObjectMapper; +import cz.cvut.kbss.termit.security.SecurityConstants; +import cz.cvut.kbss.termit.util.Utils; +import io.jsonwebtoken.Jwts; +import io.jsonwebtoken.SignatureAlgorithm; +import io.jsonwebtoken.jackson.io.JacksonSerializer; +import io.jsonwebtoken.security.Keys; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.http.HttpHeaders; +import org.springframework.messaging.simp.stomp.StompCommand; +import org.springframework.messaging.simp.stomp.StompHeaders; +import org.springframework.messaging.simp.stomp.StompSession; +import org.springframework.messaging.simp.stomp.StompSessionHandler; +import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketMessage; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.client.WebSocketClient; +import org.springframework.web.socket.client.standard.StandardWebSocketClient; + +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.time.Instant; +import java.util.Arrays; +import java.util.Date; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.stream.Stream; + +import static org.awaitility.Awaitility.await; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class IntegrationWebSocketSecurityTest extends BaseWebSocketIntegrationTestRunner { + + /** + * The number of seconds after which some operations will time out. + */ + private static final int OPERATION_TIMEOUT = 15; + + @Autowired + ObjectMapper objectMapper; + + /** + * @return Stream of argument pairs with StompCommand (CONNECT excluded) and true + false value for each command + */ + public static Stream stompCommands() { + return Arrays.stream(StompCommand.values()).filter(c -> c != StompCommand.CONNECT).map(Enum::name) + .flatMap(name -> Stream.of(Arguments.of(name, true), Arguments.of(name, false))); + } + + /** + * Ensures that connection is closed on receiving any message other than CONNECT + * (even with valid authorization token) + */ + @ParameterizedTest + @MethodSource("stompCommands") + void connectionIsClosedOnAnyMessageBeforeConnect(String stompCommand, Boolean withAuth) throws Exception { + final AtomicBoolean receivedReply = new AtomicBoolean(false); + final AtomicBoolean receivedError = new AtomicBoolean(false); + + final String auth = withAuth ? HttpHeaders.AUTHORIZATION + ":" + SecurityConstants.JWT_TOKEN_PREFIX + generateToken() + "\n" : ""; + final TextMessage message = new TextMessage(stompCommand + "\n" + auth + "\n\0"); + + final WebSocketClient wsClient = new StandardWebSocketClient(); + Future connectFuture = wsClient.execute(makeWebSocketHandler(receivedReply, receivedError), url); + + WebSocketSession session = connectFuture.get(OPERATION_TIMEOUT, TimeUnit.SECONDS); + + assertTrue(session.isOpen()); + + session.sendMessage(message); + + await().atMost(OPERATION_TIMEOUT, TimeUnit.SECONDS).until(() -> !session.isOpen()); + + assertTrue(receivedError.get()); + assertFalse(session.isOpen()); + assertFalse(receivedReply.get()); + } + + WebSocketHandler makeWebSocketHandler(AtomicBoolean receivedReply, AtomicBoolean receivedError) { + return new TestWebSocketSessionHandler() { + @Override + public void handleMessage(WebSocketSession session, WebSocketMessage message) throws Exception { + super.handleMessage(session, message); + if (message instanceof TextMessage textMessage) { + final String command = textMessage.getPayload().split("\n")[0]; + if (command.equals(StompCommand.ERROR.name())) { + receivedError.set(true); + return; + } + } + receivedReply.set(true); + session.close(); + } + }; + } + + /** + * STOMP CONNECT message is rejected with invalid auth token + */ + @Test + void connectWithInvalidAuthorizationIsRejected() throws Throwable { + final AtomicBoolean receivedReply = new AtomicBoolean(false); + final AtomicBoolean receivedError = new AtomicBoolean(false); + + final TextMessage message = new TextMessage(StompCommand.CONNECT + "\n" + HttpHeaders.AUTHORIZATION + ":" + SecurityConstants.JWT_TOKEN_PREFIX + "DefinitelyNotValidToken\n\n\0"); + + final WebSocketClient wsClient = new StandardWebSocketClient(); + Future connectFuture = wsClient.execute(makeWebSocketHandler(receivedReply, receivedError), url); + + WebSocketSession session = connectFuture.get(OPERATION_TIMEOUT, TimeUnit.SECONDS); + + assertTrue(session.isOpen()); + + session.sendMessage(message); + + await().atMost(OPERATION_TIMEOUT, TimeUnit.SECONDS).until(() -> !session.isOpen()); + + assertTrue(receivedError.get()); + assertFalse(session.isOpen()); + assertFalse(receivedReply.get()); + } + + /** + * STOMP CONNECT message is rejected with invalid JWT token + */ + @Test + void connectWithInvalidJwtAuthorizationIsRejected() throws Throwable { + final AtomicBoolean receivedReply = new AtomicBoolean(false); + final AtomicBoolean receivedError = new AtomicBoolean(false); + + final Instant issued = Utils.timestamp(); + // creates "valid" JWT token but with invalid signature + final String token = Jwts.builder().setSubject(userDetails.getUser().getUsername()) + .setId(userDetails.getUser().getUri().toString()).setIssuedAt(Date.from(issued)) + .setExpiration(Date.from(issued.plusMillis(SecurityConstants.SESSION_TIMEOUT))) + .signWith(Keys.hmacShaKeyFor("my very secure and really private key".getBytes(StandardCharsets.UTF_8)), SignatureAlgorithm.HS256) + .serializeToJsonWith(new JacksonSerializer<>(objectMapper)).compact(); + + final TextMessage message = new TextMessage(StompCommand.CONNECT + "\n" + HttpHeaders.AUTHORIZATION + ":" + SecurityConstants.JWT_TOKEN_PREFIX + token + "\n\n\0"); + + final WebSocketClient wsClient = new StandardWebSocketClient(); + Future connectFuture = wsClient.execute(makeWebSocketHandler(receivedReply, receivedError), url); + + WebSocketSession session = connectFuture.get(OPERATION_TIMEOUT, TimeUnit.SECONDS); + + assertTrue(session.isOpen()); + + session.sendMessage(message); + + await().atMost(OPERATION_TIMEOUT, TimeUnit.SECONDS).until(() -> !session.isOpen()); + + assertTrue(receivedError.get()); + assertFalse(session.isOpen()); + assertFalse(receivedReply.get()); + } + + /** + * Checks that it is possible to establish STOMP connection with valid authorization + */ + @Test + void connectionIsNotClosedWhenConnectMessageIsSent() throws Throwable { + final StompSessionHandler handler = new TestStompSessionHandler(); + + final StompHeaders headers = new StompHeaders(); + headers.add(HttpHeaders.AUTHORIZATION, SecurityConstants.JWT_TOKEN_PREFIX + generateToken()); + + Future connectFuture = stompClient.connectAsync(URI.create(url), null, headers, handler); + + StompSession session = connectFuture.get(OPERATION_TIMEOUT, TimeUnit.SECONDS); + assertTrue(session.isConnected()); + session.disconnect(); + await().atMost(OPERATION_TIMEOUT, TimeUnit.SECONDS).until(() -> !session.isConnected()); + } +} diff --git a/src/test/java/cz/cvut/kbss/termit/websocket/VocabularySocketControllerTest.java b/src/test/java/cz/cvut/kbss/termit/websocket/VocabularySocketControllerTest.java new file mode 100644 index 000000000..65507b909 --- /dev/null +++ b/src/test/java/cz/cvut/kbss/termit/websocket/VocabularySocketControllerTest.java @@ -0,0 +1,105 @@ +package cz.cvut.kbss.termit.websocket; + +import cz.cvut.kbss.jopa.model.MultilingualString; +import cz.cvut.kbss.termit.environment.Environment; +import cz.cvut.kbss.termit.environment.Generator; +import cz.cvut.kbss.termit.model.Vocabulary; +import cz.cvut.kbss.termit.model.validation.ValidationResult; +import cz.cvut.kbss.termit.service.IdentifierResolver; +import cz.cvut.kbss.termit.service.business.VocabularyService; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.springframework.boot.test.mock.mockito.MockBean; +import org.springframework.boot.test.mock.mockito.SpyBean; +import org.springframework.messaging.Message; +import org.springframework.messaging.simp.stomp.StompCommand; +import org.springframework.messaging.simp.stomp.StompHeaderAccessor; +import org.springframework.messaging.support.MessageBuilder; +import org.springframework.security.core.Authentication; + +import java.util.HashMap; +import java.util.List; +import java.util.Optional; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + + +class VocabularySocketControllerTest extends BaseWebSocketControllerTestRunner { + + @MockBean + IdentifierResolver idResolver; + + @MockBean + VocabularyService vocabularyService; + + @SpyBean + VocabularySocketController sut; + + Vocabulary vocabulary; + + String fragment; + + String namespace; + + StompHeaderAccessor messageHeaders; + + @BeforeEach + public void setup() { + vocabulary = Generator.generateVocabularyWithId(); + fragment = IdentifierResolver.extractIdentifierFragment(vocabulary.getUri()).substring(1); + namespace = vocabulary.getUri().toString().substring(0, vocabulary.getUri().toString().lastIndexOf('/')); + when(idResolver.resolveIdentifier(namespace, fragment)).thenReturn(vocabulary.getUri()); + when(vocabularyService.getReference(vocabulary.getUri())).thenReturn(vocabulary); + + messageHeaders = StompHeaderAccessor.create(StompCommand.MESSAGE); + messageHeaders.setSessionId("0"); + messageHeaders.setSubscriptionId("0"); + Authentication auth = Environment.setCurrentUser(Generator.generateUserAccountWithPassword()); + messageHeaders.setUser(auth); + messageHeaders.setSessionAttributes(new HashMap<>()); + } + + @Test + void validateVocabularyValidatesContents() { + messageHeaders.setContentLength(0); + messageHeaders.setHeader("namespace", namespace); + messageHeaders.setDestination("/vocabularies/" + fragment + "/validate"); + + this.serverInboundChannel.send(MessageBuilder.withPayload("").setHeaders(messageHeaders).build()); + + verify(vocabularyService).validateContents(vocabulary); + } + + @Test + void validateVocabularyReturnsValidationResults() { + messageHeaders.setContentLength(0); + messageHeaders.setHeader("namespace", namespace); + messageHeaders.setDestination("/vocabularies/" + fragment + "/validate"); + + final ValidationResult validationResult = new ValidationResult().setTermUri(Generator.generateUri()) + .setResultPath(Generator.generateUri()) + .setMessage(MultilingualString.create("message", "en")) + .setSeverity(Generator.generateUri()) + .setIssueCauseUri(Generator.generateUri()); + final List validationResults = List.of(validationResult); + when(vocabularyService.validateContents(vocabulary)).thenReturn(validationResults); + + this.serverInboundChannel.send(MessageBuilder.withPayload("").setHeaders(messageHeaders).build()); + + assertEquals(1, this.brokerChannelInterceptor.getMessages().size()); + Message reply = this.brokerChannelInterceptor.getMessages().get(0); + + assertNotNull(reply); + StompHeaderAccessor replyHeaders = StompHeaderAccessor.wrap(reply); + // as reply is sent to a common channel for all vocabularies, there must be header with vocabulary uri + assertEquals(vocabulary.getUri().toString(), replyHeaders.getFirstNativeHeader("vocabulary"), "Invalid or missing vocabulary header in the reply"); + + Optional> payload = readPayload(reply); + assertTrue(payload.isPresent()); + assertEquals(validationResults, payload.get()); + } +} diff --git a/src/test/java/cz/cvut/kbss/termit/websocket/WebSocketExceptionHandlerTest.java b/src/test/java/cz/cvut/kbss/termit/websocket/WebSocketExceptionHandlerTest.java new file mode 100644 index 000000000..52df68045 --- /dev/null +++ b/src/test/java/cz/cvut/kbss/termit/websocket/WebSocketExceptionHandlerTest.java @@ -0,0 +1,59 @@ +package cz.cvut.kbss.termit.websocket; + +import cz.cvut.kbss.termit.environment.Environment; +import cz.cvut.kbss.termit.environment.Generator; +import cz.cvut.kbss.termit.exception.PersistenceException; +import cz.cvut.kbss.termit.websocket.handler.WebSocketExceptionHandler; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.springframework.boot.test.mock.mockito.MockBean; +import org.springframework.boot.test.mock.mockito.SpyBean; +import org.springframework.messaging.simp.stomp.StompCommand; +import org.springframework.messaging.simp.stomp.StompHeaderAccessor; +import org.springframework.messaging.support.MessageBuilder; +import org.springframework.security.core.Authentication; + +import java.util.HashMap; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.notNull; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +class WebSocketExceptionHandlerTest extends BaseWebSocketControllerTestRunner { + + @SpyBean + WebSocketExceptionHandler sut; + + @MockBean + VocabularySocketController controller; + + StompHeaderAccessor messageHeaders; + + @BeforeEach + public void setup() { + messageHeaders = StompHeaderAccessor.create(StompCommand.MESSAGE); + messageHeaders.setSessionId("0"); + messageHeaders.setSubscriptionId("0"); + Authentication auth = Environment.setCurrentUser(Generator.generateUserAccountWithPassword()); + messageHeaders.setUser(auth); + messageHeaders.setSessionAttributes(new HashMap<>()); + messageHeaders.setContentLength(0); + messageHeaders.setHeader("namespace", "namespace"); + messageHeaders.setDestination("/vocabularies/fragment/validate"); + } + + void sendMessage() { + this.serverInboundChannel.send(MessageBuilder.withPayload("").setHeaders(messageHeaders).build()); + } + + @Test + void handlerIsCalledForPersistenceException() { + final PersistenceException e = new PersistenceException(new Exception("mocked exception")); + when(controller.validateVocabulary(any(), any())).thenThrow(e); + sendMessage(); + verify(sut).persistenceException(notNull(), eq(e)); + } + +} diff --git a/src/test/java/cz/cvut/kbss/termit/websocket/util/CachingChannelInterceptor.java b/src/test/java/cz/cvut/kbss/termit/websocket/util/CachingChannelInterceptor.java new file mode 100644 index 000000000..43a454b81 --- /dev/null +++ b/src/test/java/cz/cvut/kbss/termit/websocket/util/CachingChannelInterceptor.java @@ -0,0 +1,32 @@ +package cz.cvut.kbss.termit.websocket.util; + +import org.jetbrains.annotations.NotNull; +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.support.ChannelInterceptor; + +import java.util.List; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; + +/** + * Caches any message sent to the intercepted channel + */ +public class CachingChannelInterceptor implements ChannelInterceptor { + + private final BlockingQueue> messages = new ArrayBlockingQueue<>(100); + + @Override + public Message preSend(@NotNull Message message, @NotNull MessageChannel channel) { + this.messages.add(message); + return message; + } + + public void reset() { + this.messages.clear(); + } + + public List> getMessages() { + return List.copyOf(messages); + } +} diff --git a/src/test/java/cz/cvut/kbss/termit/websocket/util/ReturnValueCollectingSimpMessagingTemplate.java b/src/test/java/cz/cvut/kbss/termit/websocket/util/ReturnValueCollectingSimpMessagingTemplate.java new file mode 100644 index 000000000..6a9dd91a5 --- /dev/null +++ b/src/test/java/cz/cvut/kbss/termit/websocket/util/ReturnValueCollectingSimpMessagingTemplate.java @@ -0,0 +1,43 @@ +package cz.cvut.kbss.termit.websocket.util; + +import org.jetbrains.annotations.NotNull; +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.core.MessagePostProcessor; +import org.springframework.messaging.simp.SimpMessagingTemplate; +import org.springframework.messaging.support.MessageBuilder; + +import java.util.Map; +import java.util.UUID; + +/** + * Intercepts doConvert method and caches the returned payload before conversion + * mapped by resulting message id. + * Allows reading raw-returned values before serialization. + */ +public class ReturnValueCollectingSimpMessagingTemplate extends SimpMessagingTemplate { + + public static final String MESSAGE_IDENTIFIER_HEADER = "test-message-id"; + + private final Map returnedValuesMap; + + public ReturnValueCollectingSimpMessagingTemplate(MessageChannel messageChannel, + Map returnedValuesMap) { + super(messageChannel); + this.returnedValuesMap = returnedValuesMap; + } + + @Override + protected @NotNull Message doConvert(@NotNull Object payload, Map headers, + MessagePostProcessor postProcessor) { + final Message converted = super.doConvert(payload, headers, postProcessor); + + UUID id = converted.getHeaders().getId(); + if (id == null) { + id = UUID.randomUUID(); + } + + returnedValuesMap.put(id, payload); + return MessageBuilder.fromMessage(converted).copyHeaders(Map.of(MESSAGE_IDENTIFIER_HEADER, id)).build(); + } +}