Skip to content

Commit

Permalink
GH-2609: Restore ClassLoader for SerMC & SimpleMC
Browse files Browse the repository at this point in the history
Fixes: #2609

The deserialization functionality must rely on the `ClassLoader` from the application context (at least, by default).

* Fix `SimpleMessageConverter` to accept `BeanClassLoaderAware` and use it for `ConfigurableObjectInputStream`
* Fix `SerializerMessageConverter` to accept `BeanClassLoaderAware`
* Remove reflection for the `new DirectFieldAccessor(deserializer).getPropertyValue("classLoader")`
from the `SerializerMessageConverter` since it is not this converter responsibility
to interfere into provided `Deserializer` logic

**Cherry-pick to `3.0.x`**
  • Loading branch information
artembilan committed Feb 6, 2024
1 parent 37d9641 commit e7be534
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 83 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -26,13 +26,14 @@

import org.springframework.amqp.core.Message;
import org.springframework.amqp.core.MessageProperties;
import org.springframework.beans.DirectFieldAccessor;
import org.springframework.beans.factory.BeanClassLoaderAware;
import org.springframework.core.ConfigurableObjectInputStream;
import org.springframework.core.serializer.DefaultDeserializer;
import org.springframework.core.serializer.DefaultSerializer;
import org.springframework.core.serializer.Deserializer;
import org.springframework.core.serializer.Serializer;
import org.springframework.lang.Nullable;
import org.springframework.util.ClassUtils;

/**
* Implementation of {@link MessageConverter} that can work with Strings or native objects
Expand All @@ -50,26 +51,26 @@
* @author Gary Russell
* @author Artem Bilan
*/
public class SerializerMessageConverter extends AllowedListDeserializingMessageConverter {
public class SerializerMessageConverter extends AllowedListDeserializingMessageConverter
implements BeanClassLoaderAware {

public static final String DEFAULT_CHARSET = StandardCharsets.UTF_8.name();

private volatile String defaultCharset = DEFAULT_CHARSET;
private String defaultCharset = DEFAULT_CHARSET;

private volatile Serializer<Object> serializer = new DefaultSerializer();
private Serializer<Object> serializer = new DefaultSerializer();

private volatile Deserializer<Object> deserializer = new DefaultDeserializer();
private Deserializer<Object> deserializer = new DefaultDeserializer();

private volatile boolean ignoreContentType = false;
private boolean ignoreContentType = false;

private volatile ClassLoader defaultDeserializerClassLoader;
private ClassLoader defaultDeserializerClassLoader = ClassUtils.getDefaultClassLoader();

private volatile boolean usingDefaultDeserializer = true;
private boolean usingDefaultDeserializer = true;

/**
* Flag to signal that the content type should be ignored and the deserializer used irrespective if it is a text
* message. Defaults to false, in which case the default encoding is used to convert a text message to a String.
*
* @param ignoreContentType the flag value to set
*/
public void setIgnoreContentType(boolean ignoreContentType) {
Expand All @@ -79,7 +80,6 @@ public void setIgnoreContentType(boolean ignoreContentType) {
/**
* Specify the default charset to use when converting to or from text-based Message body content. If not specified,
* the charset will be "UTF-8".
*
* @param defaultCharset The default charset.
*/
public void setDefaultCharset(@Nullable String defaultCharset) {
Expand All @@ -88,7 +88,6 @@ public void setDefaultCharset(@Nullable String defaultCharset) {

/**
* The serializer to use for converting Java objects to message bodies.
*
* @param serializer the serializer to set
*/
public void setSerializer(Serializer<Object> serializer) {
Expand All @@ -97,24 +96,16 @@ public void setSerializer(Serializer<Object> serializer) {

/**
* The deserializer to use for converting from message body to Java object.
*
* @param deserializer the deserializer to set
*/
public void setDeserializer(Deserializer<Object> deserializer) {
this.deserializer = deserializer;
if (this.deserializer.getClass().equals(DefaultDeserializer.class)) {
try {
this.defaultDeserializerClassLoader = (ClassLoader) new DirectFieldAccessor(deserializer)
.getPropertyValue("classLoader");
}
catch (Exception e) {
// no-op
}
this.usingDefaultDeserializer = true;
}
else {
this.usingDefaultDeserializer = false;
}
this.usingDefaultDeserializer = false;
}

@Override
public void setBeanClassLoader(ClassLoader classLoader) {
this.defaultDeserializerClassLoader = classLoader;
}

/**
Expand Down Expand Up @@ -170,17 +161,17 @@ private Object asString(Message message, MessageProperties properties) {

private Object deserialize(ByteArrayInputStream inputStream) throws IOException {
try (ObjectInputStream objectInputStream = new ConfigurableObjectInputStream(inputStream,
this.defaultDeserializerClassLoader) {

@Override
protected Class<?> resolveClass(ObjectStreamClass classDesc)
throws IOException, ClassNotFoundException {
Class<?> clazz = super.resolveClass(classDesc);
checkAllowedList(clazz);
return clazz;
}
this.defaultDeserializerClassLoader) {

@Override
protected Class<?> resolveClass(ObjectStreamClass classDesc)
throws IOException, ClassNotFoundException {
Class<?> clazz = super.resolveClass(classDesc);
checkAllowedList(clazz);
return clazz;
}

}) {
}) {
return objectInputStream.readObject();
}
catch (ClassNotFoundException ex) {
Expand All @@ -194,6 +185,7 @@ protected Class<?> resolveClass(ObjectStreamClass classDesc)
@Override
protected Message createMessage(Object object, MessageProperties messageProperties)
throws MessageConversionException {

byte[] bytes;
if (object instanceof String) {
try {
Expand All @@ -220,9 +212,8 @@ else if (object instanceof byte[]) {
bytes = output.toByteArray();
messageProperties.setContentType(MessageProperties.CONTENT_TYPE_SERIALIZED_OBJECT);
}
if (bytes != null) {
messageProperties.setContentLength(bytes.length);
}

messageProperties.setContentLength(bytes.length);
return new Message(bytes, messageProperties);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -27,6 +27,10 @@
import org.springframework.amqp.core.Message;
import org.springframework.amqp.core.MessageProperties;
import org.springframework.amqp.utils.SerializationUtils;
import org.springframework.beans.factory.BeanClassLoaderAware;
import org.springframework.core.ConfigurableObjectInputStream;
import org.springframework.lang.Nullable;
import org.springframework.util.ClassUtils;

/**
* Implementation of {@link MessageConverter} that can work with Strings, Serializable
Expand All @@ -38,23 +42,30 @@
* @author Mark Fisher
* @author Oleg Zhurakousky
* @author Gary Russell
* @author Artem Bilan
*/
public class SimpleMessageConverter extends AllowedListDeserializingMessageConverter {
public class SimpleMessageConverter extends AllowedListDeserializingMessageConverter implements BeanClassLoaderAware {

public static final String DEFAULT_CHARSET = "UTF-8";

private volatile String defaultCharset = DEFAULT_CHARSET;
private String defaultCharset = DEFAULT_CHARSET;

private ClassLoader classLoader = ClassUtils.getDefaultClassLoader();

/**
* Specify the default charset to use when converting to or from text-based
* Message body content. If not specified, the charset will be "UTF-8".
*
* @param defaultCharset The default charset.
*/
public void setDefaultCharset(String defaultCharset) {
public void setDefaultCharset(@Nullable String defaultCharset) {
this.defaultCharset = (defaultCharset != null) ? defaultCharset : DEFAULT_CHARSET;
}

@Override
public void setBeanClassLoader(ClassLoader classLoader) {
this.classLoader = classLoader;
}

/**
* Converts from a AMQP Message to an Object.
*/
Expand All @@ -73,8 +84,7 @@ public Object fromMessage(Message message) throws MessageConversionException {
content = new String(message.getBody(), encoding);
}
catch (UnsupportedEncodingException e) {
throw new MessageConversionException(
"failed to convert text-based Message content", e);
throw new MessageConversionException("failed to convert text-based Message content", e);
}
}
else if (contentType != null &&
Expand All @@ -84,8 +94,7 @@ else if (contentType != null &&
createObjectInputStream(new ByteArrayInputStream(message.getBody())));
}
catch (IOException | IllegalArgumentException | IllegalStateException e) {
throw new MessageConversionException(
"failed to convert serialized Message content", e);
throw new MessageConversionException("failed to convert serialized Message content", e);
}
}
}
Expand All @@ -99,7 +108,9 @@ else if (contentType != null &&
* Creates an AMQP Message from the provided Object.
*/
@Override
protected Message createMessage(Object object, MessageProperties messageProperties) throws MessageConversionException {
protected Message createMessage(Object object, MessageProperties messageProperties)
throws MessageConversionException {

byte[] bytes = null;
if (object instanceof byte[]) {
bytes = (byte[]) object;
Expand All @@ -110,8 +121,7 @@ else if (object instanceof String) {
bytes = ((String) object).getBytes(this.defaultCharset);
}
catch (UnsupportedEncodingException e) {
throw new MessageConversionException(
"failed to convert to Message content", e);
throw new MessageConversionException("failed to convert to Message content", e);
}
messageProperties.setContentType(MessageProperties.CONTENT_TYPE_TEXT_PLAIN);
messageProperties.setContentEncoding(this.defaultCharset);
Expand All @@ -121,8 +131,7 @@ else if (object instanceof Serializable) {
bytes = SerializationUtils.serialize(object);
}
catch (IllegalArgumentException e) {
throw new MessageConversionException(
"failed to convert to serialized Message content", e);
throw new MessageConversionException("failed to convert to serialized Message content", e);
}
messageProperties.setContentType(MessageProperties.CONTENT_TYPE_SERIALIZED_OBJECT);
}
Expand All @@ -135,15 +144,15 @@ else if (object instanceof Serializable) {
}

/**
* Create an ObjectInputStream for the given InputStream and codebase. The default
* implementation creates an ObjectInputStream.
* Create an ObjectInputStream for the given InputStream. The default
* implementation creates an {@link ConfigurableObjectInputStream} against configured {@link ClassLoader}.
* The class for object to deserialize is checked against {@code allowedListPatterns}.
* @param is the InputStream to read from
* @return the new ObjectInputStream instance to use
* @throws IOException if creation of the ObjectInputStream failed
*/
@SuppressWarnings("deprecation")
protected ObjectInputStream createObjectInputStream(InputStream is) throws IOException {
return new ObjectInputStream(is) {
return new ConfigurableObjectInputStream(is, this.classLoader) {

@Override
protected Class<?> resolveClass(ObjectStreamClass classDesc) throws IOException, ClassNotFoundException {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -18,25 +18,18 @@

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.nio.charset.StandardCharsets;

import org.junit.jupiter.api.Test;
import org.mockito.Mockito;

import org.springframework.amqp.core.Message;
import org.springframework.amqp.core.MessageProperties;
import org.springframework.amqp.utils.test.TestUtils;
import org.springframework.core.serializer.DefaultDeserializer;
import org.springframework.core.serializer.Deserializer;

/**
* @author Mark Fisher
Expand Down Expand Up @@ -151,24 +144,6 @@ public void serializedObjectToMessage() throws Exception {
assertThat(deserializedObject).isEqualTo(testBean);
}

@SuppressWarnings("unchecked")
@Test
public void testDefaultDeserializerClassLoader() throws Exception {
SerializerMessageConverter converter = new SerializerMessageConverter();
ClassLoader loader = mock(ClassLoader.class);
Deserializer<Object> deserializer = new DefaultDeserializer(loader);
converter.setDeserializer(deserializer);
assertThat(TestUtils.getPropertyValue(converter, "defaultDeserializerClassLoader")).isSameAs(loader);
assertThat(TestUtils.getPropertyValue(converter, "usingDefaultDeserializer", Boolean.class)).isTrue();
Deserializer<Object> mock = mock(Deserializer.class);
converter.setDeserializer(mock);
assertThat(TestUtils.getPropertyValue(converter, "usingDefaultDeserializer", Boolean.class)).isFalse();
TestBean testBean = new TestBean("foo");
Message message = converter.toMessage(testBean, new MessageProperties());
converter.fromMessage(message);
verify(mock).deserialize(Mockito.any(InputStream.class));
}

@Test
public void messageConversionExceptionForClassNotFound() {
SerializerMessageConverter converter = new SerializerMessageConverter();
Expand Down

0 comments on commit e7be534

Please sign in to comment.