Skip to content

Commit

Permalink
GH-2891: Add rabbitConnection.addShutdownListener(this)
Browse files Browse the repository at this point in the history
Fixes: #2891
Issue link: #2891

The `ShutdownListener` is not registered into connections created by the `AbstractConnectionFactory`

* Fix `AbstractConnectionFactory.createBareConnection()` add itself into just created connection as a `ShutdownListener`
* Fix tests with mocks where `mockConnectionFactory.newConnection()` did not return an instance of `Connection`
  • Loading branch information
artembilan committed Nov 11, 2024
1 parent a2ac767 commit 5471679
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 94 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ public void handleRecovery(Recoverable recoverable) {

@Nullable
private BackOff connectionCreatingBackOff;

/**
* Create a new AbstractConnectionFactory for the given target ConnectionFactory, with no publisher connection
* factory.
Expand Down Expand Up @@ -580,8 +581,8 @@ public ConnectionFactory getPublisherConnectionFactory() {
protected final Connection createBareConnection() {
try {
String connectionName = this.connectionNameStrategy.obtainNewConnectionName(this);

com.rabbitmq.client.Connection rabbitConnection = connect(connectionName);
rabbitConnection.addShutdownListener(this);
Connection connection = new SimpleConnection(rabbitConnection, this.closeTimeout,
this.connectionCreatingBackOff == null ? null : this.connectionCreatingBackOff.start());
if (rabbitConnection instanceof AutorecoveringConnection auto) {
Expand Down Expand Up @@ -732,16 +733,8 @@ public String toString() {
}
}

private static final class ConnectionBlockedListener implements BlockedListener {

private final Connection connection;

private final ApplicationEventPublisher applicationEventPublisher;

ConnectionBlockedListener(Connection connection, ApplicationEventPublisher applicationEventPublisher) {
this.connection = connection;
this.applicationEventPublisher = applicationEventPublisher;
}
private record ConnectionBlockedListener(Connection connection, ApplicationEventPublisher applicationEventPublisher)
implements BlockedListener {

@Override
public void handleBlocked(String reason) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyList;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.BDDMockito.given;
import static org.mockito.BDDMockito.willCallRealMethod;
Expand Down Expand Up @@ -64,11 +65,11 @@ public abstract class AbstractConnectionFactoryTests {

@Test
public void testWithListener() throws Exception {

com.rabbitmq.client.ConnectionFactory mockConnectionFactory = mock(com.rabbitmq.client.ConnectionFactory.class);
com.rabbitmq.client.Connection mockConnection = mock(com.rabbitmq.client.Connection.class);
com.rabbitmq.client.ConnectionFactory mockConnectionFactory = mock();
com.rabbitmq.client.Connection mockConnection = mock();

given(mockConnectionFactory.newConnection(any(ExecutorService.class), anyString())).willReturn(mockConnection);
given(mockConnectionFactory.newConnection(any(), anyList(), anyString())).willReturn(mockConnection);

final AtomicInteger called = new AtomicInteger(0);
AbstractConnectionFactory connectionFactory = createConnectionFactory(mockConnectionFactory);
Expand Down Expand Up @@ -125,9 +126,8 @@ public void onClose(Connection connection) {

@Test
public void testWithListenerRegisteredAfterOpen() throws Exception {

com.rabbitmq.client.ConnectionFactory mockConnectionFactory = mock(com.rabbitmq.client.ConnectionFactory.class);
com.rabbitmq.client.Connection mockConnection = mock(com.rabbitmq.client.Connection.class);
com.rabbitmq.client.ConnectionFactory mockConnectionFactory = mock();
com.rabbitmq.client.Connection mockConnection = mock();

given(mockConnectionFactory.newConnection(any(ExecutorService.class), anyString())).willReturn(mockConnection);

Expand Down Expand Up @@ -168,10 +168,9 @@ public void onClose(Connection connection) {

@Test
public void testCloseInvalidConnection() throws Exception {

com.rabbitmq.client.ConnectionFactory mockConnectionFactory = mock(com.rabbitmq.client.ConnectionFactory.class);
com.rabbitmq.client.Connection mockConnection1 = mock(com.rabbitmq.client.Connection.class);
com.rabbitmq.client.Connection mockConnection2 = mock(com.rabbitmq.client.Connection.class);
com.rabbitmq.client.ConnectionFactory mockConnectionFactory = mock();
com.rabbitmq.client.Connection mockConnection1 = mock();
com.rabbitmq.client.Connection mockConnection2 = mock();

given(mockConnectionFactory.newConnection(any(ExecutorService.class), anyString()))
.willReturn(mockConnection1, mockConnection2);
Expand All @@ -194,8 +193,7 @@ public void testCloseInvalidConnection() throws Exception {

@Test
public void testDestroyBeforeUsed() throws Exception {

com.rabbitmq.client.ConnectionFactory mockConnectionFactory = mock(com.rabbitmq.client.ConnectionFactory.class);
com.rabbitmq.client.ConnectionFactory mockConnectionFactory = mock();

AbstractConnectionFactory connectionFactory = createConnectionFactory(mockConnectionFactory);
connectionFactory.destroy();
Expand All @@ -205,7 +203,7 @@ public void testDestroyBeforeUsed() throws Exception {

@Test
public void testCreatesConnectionWithGivenFactory() {
com.rabbitmq.client.ConnectionFactory mockConnectionFactory = mock(com.rabbitmq.client.ConnectionFactory.class);
com.rabbitmq.client.ConnectionFactory mockConnectionFactory = mock();
willCallRealMethod().given(mockConnectionFactory).params(any(ExecutorService.class));
willCallRealMethod().given(mockConnectionFactory).setThreadFactory(any(ThreadFactory.class));
willCallRealMethod().given(mockConnectionFactory).getThreadFactory();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyList;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.argThat;
Expand Down Expand Up @@ -117,7 +118,7 @@ void stringRepresentation() {
assertThat(ccf.toString()).contains(", addresses=[h3:1236, h4:1237]")
.doesNotContain("host")
.doesNotContain("port");
ccf.setAddressResolver(() -> {
ccf.setAddressResolver(() -> {
throw new IOException("test");
});
ccf.setPort(0);
Expand Down Expand Up @@ -710,7 +711,7 @@ public void testCheckoutLimitWithPublisherConfirmsLogicalAlreadyCloses() throws
willAnswer(invoc -> {
open.set(false); // so the logical close detects a closed delegate
return null;
}).given(mockChannel).basicPublish(any(), any(), anyBoolean(), any(), any());
}).given(mockChannel).basicPublish(any(), any(), anyBoolean(), any(), any());

CachingConnectionFactory ccf = new CachingConnectionFactory(mockConnectionFactory);
ccf.setExecutor(mock(ExecutorService.class));
Expand All @@ -722,7 +723,7 @@ public void testCheckoutLimitWithPublisherConfirmsLogicalAlreadyCloses() throws
rabbitTemplate.convertAndSend("foo", "bar");
open.set(true);
rabbitTemplate.convertAndSend("foo", "bar");
verify(mockChannel, times(2)).basicPublish(any(), any(), anyBoolean(), any(), any());
verify(mockChannel, times(2)).basicPublish(any(), any(), anyBoolean(), any(), any());
}

@Test
Expand Down Expand Up @@ -1300,7 +1301,6 @@ public void onClose(Connection connection) {
verify(mockConnections.get(3)).close(30000);
}


@Test
public void testWithConnectionFactoryCachedConnectionAndChannels() throws Exception {
com.rabbitmq.client.ConnectionFactory mockConnectionFactory = mock(com.rabbitmq.client.ConnectionFactory.class);
Expand Down Expand Up @@ -1644,6 +1644,8 @@ private void verifyChannelIs(Channel mockChannel, Channel channel) {
@Test
public void setAddressesEmpty() throws Exception {
ConnectionFactory mock = mock(com.rabbitmq.client.ConnectionFactory.class);
given(mock.newConnection(any(ExecutorService.class), anyString()))
.willReturn(mock(com.rabbitmq.client.Connection.class));
CachingConnectionFactory ccf = new CachingConnectionFactory(mock);
ccf.setExecutor(mock(ExecutorService.class));
ccf.setHost("abc");
Expand All @@ -1663,6 +1665,8 @@ public void setAddressesEmpty() throws Exception {
@Test
public void setAddressesOneHost() throws Exception {
ConnectionFactory mock = mock(com.rabbitmq.client.ConnectionFactory.class);
given(mock.newConnection(any(), anyList(), anyString()))
.willReturn(mock(com.rabbitmq.client.Connection.class));
CachingConnectionFactory ccf = new CachingConnectionFactory(mock);
ccf.setAddresses("mq1");
ccf.createConnection();
Expand All @@ -1674,16 +1678,18 @@ public void setAddressesOneHost() throws Exception {

@Test
public void setAddressesTwoHosts() throws Exception {
ConnectionFactory mock = mock(com.rabbitmq.client.ConnectionFactory.class);
ConnectionFactory mock = mock();
willReturn(true).given(mock).isAutomaticRecoveryEnabled();
willReturn(mock(com.rabbitmq.client.Connection.class)).given(mock).newConnection(any(), anyList(), anyString());
CachingConnectionFactory ccf = new CachingConnectionFactory(mock);
ccf.setAddresses("mq1,mq2");
ccf.createConnection();
verify(mock).isAutomaticRecoveryEnabled();
verify(mock).setAutomaticRecoveryEnabled(false);
verify(mock).newConnection(
isNull(),
argThat((ArgumentMatcher<List<Address>>) a -> a.size() == 2 && a.contains(new Address("mq1")) && a.contains(new Address("mq2"))),
argThat((ArgumentMatcher<List<Address>>) a -> a.size() == 2
&& a.contains(new Address("mq1")) && a.contains(new Address("mq2"))),
anyString());
verifyNoMoreInteractions(mock);
}
Expand All @@ -1692,7 +1698,9 @@ public void setAddressesTwoHosts() throws Exception {
public void setUri() throws Exception {
URI uri = new URI("amqp://localhost:1234/%2f");

ConnectionFactory mock = mock(com.rabbitmq.client.ConnectionFactory.class);
ConnectionFactory mock = mock();
given(mock.newConnection(any(ExecutorService.class), anyString()))
.willReturn(mock(com.rabbitmq.client.Connection.class));
CachingConnectionFactory ccf = new CachingConnectionFactory(mock);
ccf.setExecutor(mock(ExecutorService.class));

Expand Down Expand Up @@ -1854,12 +1862,12 @@ public void testFirstConnectionDoesntWait() throws IOException, TimeoutException
@SuppressWarnings("unchecked")
@Test
public void testShuffleRandom() throws IOException, TimeoutException {
com.rabbitmq.client.ConnectionFactory mockConnectionFactory = mock(com.rabbitmq.client.ConnectionFactory.class);
com.rabbitmq.client.Connection mockConnection = mock(com.rabbitmq.client.Connection.class);
com.rabbitmq.client.ConnectionFactory mockConnectionFactory = mock();
com.rabbitmq.client.Connection mockConnection = mock();
Channel mockChannel = mock(Channel.class);

given(mockConnectionFactory.newConnection((ExecutorService) isNull(), any(List.class), anyString()))
.willReturn(mockConnection);
given(mockConnectionFactory.newConnection(any(), anyList(), anyString()))
.willReturn(mockConnection);
given(mockConnection.createChannel()).willReturn(mockChannel);
given(mockChannel.isOpen()).willReturn(true);
given(mockConnection.isOpen()).willReturn(true);
Expand All @@ -1873,11 +1881,11 @@ public void testShuffleRandom() throws IOException, TimeoutException {
ArgumentCaptor<List<Address>> captor = ArgumentCaptor.forClass(List.class);
verify(mockConnectionFactory, times(100)).newConnection(isNull(), captor.capture(), anyString());
List<String> firstAddress = captor.getAllValues()
.stream()
.map(addresses -> addresses.get(0).getHost())
.distinct()
.sorted()
.collect(Collectors.toList());
.stream()
.map(addresses -> addresses.get(0).getHost())
.distinct()
.sorted()
.collect(Collectors.toList());
assertThat(firstAddress).containsExactly("host1", "host2", "host3");
}

Expand All @@ -1888,8 +1896,8 @@ public void testShuffleInOrder() throws IOException, TimeoutException {
com.rabbitmq.client.Connection mockConnection = mock(com.rabbitmq.client.Connection.class);
Channel mockChannel = mock(Channel.class);

given(mockConnectionFactory.newConnection((ExecutorService) isNull(), any(List.class), anyString()))
.willReturn(mockConnection);
given(mockConnectionFactory.newConnection(isNull(), anyList(), anyString()))
.willReturn(mockConnection);
given(mockConnection.createChannel()).willReturn(mockChannel);
given(mockChannel.isOpen()).willReturn(true);
given(mockConnection.isOpen()).willReturn(true);
Expand All @@ -1903,17 +1911,17 @@ public void testShuffleInOrder() throws IOException, TimeoutException {
ArgumentCaptor<List<Address>> captor = ArgumentCaptor.forClass(List.class);
verify(mockConnectionFactory, times(3)).newConnection(isNull(), captor.capture(), anyString());
List<String> connectAddresses = captor.getAllValues()
.stream()
.map(addresses -> addresses.get(0).getHost())
.collect(Collectors.toList());
.stream()
.map(addresses -> addresses.get(0).getHost())
.collect(Collectors.toList());
assertThat(connectAddresses).containsExactly("host1", "host2", "host3");
}

@Test
void testResolver() throws Exception {
com.rabbitmq.client.ConnectionFactory mockConnectionFactory = mock(com.rabbitmq.client.ConnectionFactory.class);
com.rabbitmq.client.Connection mockConnection = mock(com.rabbitmq.client.Connection.class);
Channel mockChannel = mock(Channel.class);
com.rabbitmq.client.ConnectionFactory mockConnectionFactory = mock();
com.rabbitmq.client.Connection mockConnection = mock();
Channel mockChannel = mock();

AddressResolver resolver = () -> Collections.singletonList(Address.parseAddress("foo:5672"));
given(mockConnectionFactory.newConnection(any(ExecutorService.class), eq(resolver), anyString()))
Expand All @@ -1934,7 +1942,7 @@ void testResolver() throws Exception {

@Test
void nullShutdownCause() {
com.rabbitmq.client.ConnectionFactory mockConnectionFactory = mock(com.rabbitmq.client.ConnectionFactory.class);
com.rabbitmq.client.ConnectionFactory mockConnectionFactory = mock();
AbstractConnectionFactory cf = createConnectionFactory(mockConnectionFactory);
AtomicBoolean connShutDown = new AtomicBoolean();
cf.addConnectionListener(new ConnectionListener() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,9 @@
import static org.mockito.Mockito.verify;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

import org.junit.jupiter.api.Test;
Expand All @@ -50,7 +47,6 @@
import org.springframework.amqp.core.Exchange;
import org.springframework.amqp.core.Queue;
import org.springframework.amqp.rabbit.connection.CachingConnectionFactory;
import org.springframework.amqp.rabbit.connection.CachingConnectionFactory.CacheMode;
import org.springframework.amqp.rabbit.connection.Connection;
import org.springframework.amqp.rabbit.connection.ConnectionFactory;
import org.springframework.amqp.rabbit.connection.ConnectionListener;
Expand Down Expand Up @@ -104,47 +100,6 @@ public void testUnconditional() throws Exception {
verify(channel).queueBind("foo", "bar", "foo", new HashMap<>());
}

@Test
public void testNoDeclareWithCachedConnections() throws Exception {
com.rabbitmq.client.ConnectionFactory mockConnectionFactory = mock(com.rabbitmq.client.ConnectionFactory.class);

List<Channel> mockChannels = new ArrayList<>();

AtomicInteger connectionNumber = new AtomicInteger();
willAnswer(invocation -> {
com.rabbitmq.client.Connection connection = mock(com.rabbitmq.client.Connection.class);
AtomicInteger channelNumber = new AtomicInteger();
willAnswer(invocation1 -> {
Channel channel = mock(Channel.class);
given(channel.isOpen()).willReturn(true);
int channelNum = channelNumber.incrementAndGet();
given(channel.toString()).willReturn("mockChannel" + channelNum);
mockChannels.add(channel);
return channel;
}).given(connection).createChannel();
int connectionNum = connectionNumber.incrementAndGet();
given(connection.toString()).willReturn("mockConnection" + connectionNum);
given(connection.isOpen()).willReturn(true);
return connection;
}).given(mockConnectionFactory).newConnection((ExecutorService) null);

CachingConnectionFactory ccf = new CachingConnectionFactory(mockConnectionFactory);
ccf.setCacheMode(CacheMode.CONNECTION);
ccf.afterPropertiesSet();

RabbitAdmin admin = new RabbitAdmin(ccf);
GenericApplicationContext context = new GenericApplicationContext();
Queue queue = new Queue("foo");
context.getBeanFactory().registerSingleton("foo", queue);
context.refresh();
admin.setApplicationContext(context);
admin.afterPropertiesSet();
ccf.createConnection().close();
ccf.destroy();

assertThat(mockChannels.size()).as("Admin should not have created a channel").isEqualTo(0);
}

@Test
public void testUnconditionalWithExplicitFactory() throws Exception {
ConnectionFactory cf = mock(ConnectionFactory.class);
Expand Down

0 comments on commit 5471679

Please sign in to comment.