diff --git a/src/main/java/io/r2dbc/pool/ConnectionPool.java b/src/main/java/io/r2dbc/pool/ConnectionPool.java index 733ccd2..96e1236 100644 --- a/src/main/java/io/r2dbc/pool/ConnectionPool.java +++ b/src/main/java/io/r2dbc/pool/ConnectionPool.java @@ -48,6 +48,7 @@ import java.util.List; import java.util.Optional; import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiPredicate; import java.util.function.Consumer; @@ -98,21 +99,22 @@ public ConnectionPool(ConnectionPoolConfiguration configuration) { }); } - String acqName = String.format("Connection Acquisition from [%s]", configuration.getConnectionFactory()); - String timeoutMessage = String.format("Connection Acquisition timed out after %dms", this.maxAcquireTime.toMillis()); + String acqName = String.format("Connection acquisition from [%s]", configuration.getConnectionFactory()); + String timeoutMessage = String.format("Connection acquisition timed out after %dms", this.maxAcquireTime.toMillis()); Function> allocateValidation = getValidationFunction(configuration); Mono create = Mono.defer(() -> { AtomicReference> emitted = new AtomicReference<>(); + AtomicBoolean cancelled = new AtomicBoolean(); Mono mono = this.connectionPool.acquire() .doOnNext(emitted::set) .doOnSubscribe(subscription -> { if (logger.isDebugEnabled()) { - logger.debug("Obtaining new connection from the driver"); + logger.debug("Obtaining new connection from the pool"); } }) .flatMap(ref -> { @@ -128,24 +130,23 @@ public ConnectionPool(ConnectionPoolConfiguration configuration) { prepare = prepare == null ? postAllocate : prepare.then(postAllocate); } - return prepare == null ? Mono.just(ref) : prepare.thenReturn(ref).onErrorResume(throwable -> { - return ref.invalidate().then(Mono.error(throwable)); - }); - }) - .flatMap(ref -> { + Mono conn; + if (prepare == null) { + conn = getValidConnection(allocateValidation, ref); + } else { + conn = prepare.then(getValidConnection(allocateValidation, ref)); + } - PooledConnection connection = new PooledConnection(ref, this.preRelease); - return allocateValidation.apply(connection).thenReturn((Connection) connection).onErrorResume(throwable -> { - return ref.invalidate().then(Mono.error(throwable)); - }); + return conn.onErrorResume(throwable -> { + emitted.set(null); // prevent release on cancel + return ref.invalidate().then(Mono.error(throwable)); + }) + .doFinally(s -> cleanup(cancelled, emitted)) + .as(self -> Operators.discardOnCancel(self, () -> cancelled.set(true))); }) - .doOnCancel(() -> { - - PooledRef ref = emitted.get(); - if (ref != null && emitted.compareAndSet(ref, null)) { - ref.release().subscribe(); - } - }).name(acqName); + .as(self -> Operators.discardOnCancel(self, () -> cancelled.set(true))) + .name(acqName) + .doOnNext(it -> emitted.set(null)); if (!this.maxAcquireTime.isNegative()) { mono = mono.timeout(this.maxAcquireTime).onErrorMap(TimeoutException.class, e -> new R2dbcTimeoutException(timeoutMessage, e)); @@ -155,6 +156,24 @@ public ConnectionPool(ConnectionPoolConfiguration configuration) { this.create = configuration.getAcquireRetry() > 0 ? create.retry(configuration.getAcquireRetry()) : create; } + static void cleanup(AtomicBoolean cancelled, AtomicReference> emitted) { + + if (cancelled.compareAndSet(true, false)) { + + PooledRef savedRef = emitted.get(); + if (savedRef != null && emitted.compareAndSet(savedRef, null)) { + logger.debug("Releasing connection after cancellation"); + savedRef.release().subscribe(ignore -> { + }, e -> logger.warn("Error during release", e)); + } + } + } + + private Mono getValidConnection(Function> allocateValidation, PooledRef ref) { + PooledConnection connection = new PooledConnection(ref, this.preRelease); + return allocateValidation.apply(connection).thenReturn(connection); + } + private Function> getValidationFunction(ConnectionPoolConfiguration configuration) { String timeoutMessage = String.format("Validation timed out after %dms", this.maxAcquireTime.toMillis()); diff --git a/src/main/java/io/r2dbc/pool/MonoDiscardOnCancel.java b/src/main/java/io/r2dbc/pool/MonoDiscardOnCancel.java new file mode 100644 index 0000000..e58ed50 --- /dev/null +++ b/src/main/java/io/r2dbc/pool/MonoDiscardOnCancel.java @@ -0,0 +1,126 @@ +/* + * Copyright 2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.r2dbc.pool; + +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoOperator; +import reactor.core.publisher.Operators; +import reactor.util.Logger; +import reactor.util.Loggers; +import reactor.util.context.Context; + +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * A decorating operator that replays signals from its source to a {@link Subscriber} and drains the source upon {@link Subscription#cancel() cancel} and drops data signals until termination. + * Draining data is required to complete a particular request/response window and clear the protocol state as client code expects to start a request/response conversation without any previous + * response state. + */ +class MonoDiscardOnCancel extends MonoOperator { + + private static final Logger logger = Loggers.getLogger(MonoDiscardOnCancel.class); + + private final Runnable cancelConsumer; + + MonoDiscardOnCancel(Mono source, Runnable cancelConsumer) { + super(source); + this.cancelConsumer = cancelConsumer; + } + + @Override + public void subscribe(CoreSubscriber actual) { + this.source.subscribe(new MonoDiscardOnCancelSubscriber<>(actual, this.cancelConsumer)); + } + + static class MonoDiscardOnCancelSubscriber extends AtomicBoolean implements CoreSubscriber, Subscription { + + final CoreSubscriber actual; + + final Context ctx; + + final Runnable cancelConsumer; + + Subscription s; + + MonoDiscardOnCancelSubscriber(CoreSubscriber actual, Runnable cancelConsumer) { + + this.actual = actual; + this.ctx = actual.currentContext(); + this.cancelConsumer = cancelConsumer; + } + + @Override + public void onSubscribe(Subscription s) { + + if (Operators.validate(this.s, s)) { + this.s = s; + this.actual.onSubscribe(this); + } + } + + @Override + public void onNext(T t) { + + if (this.get()) { + Operators.onDiscard(t, this.ctx); + return; + } + + this.actual.onNext(t); + } + + @Override + public void onError(Throwable t) { + if (!this.get()) { + this.actual.onError(t); + } + } + + @Override + public void onComplete() { + if (!this.get()) { + this.actual.onComplete(); + } + } + + @Override + public void request(long n) { + this.s.request(n); + } + + @Override + public void cancel() { + + if (compareAndSet(false, true)) { + if (logger.isDebugEnabled()) { + logger.debug("received cancel signal"); + } + try { + this.cancelConsumer.run(); + } catch (Exception e) { + Operators.onErrorDropped(e, this.ctx); + } + this.s.request(Long.MAX_VALUE); + } + } + + } + +} diff --git a/src/main/java/io/r2dbc/pool/Operators.java b/src/main/java/io/r2dbc/pool/Operators.java new file mode 100644 index 0000000..f181fa1 --- /dev/null +++ b/src/main/java/io/r2dbc/pool/Operators.java @@ -0,0 +1,63 @@ +/* + * Copyright 2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.r2dbc.pool; + +import org.reactivestreams.Subscription; +import reactor.core.publisher.Mono; + +/** + * Operator utility. + * + * @since 0.8.8 + */ +final class Operators { + + private Operators() { + } + + /** + * Replay signals from {@link Mono the source} until cancellation. Drains the source for data signals if the subscriber cancels the subscription. + *

+ * Draining data is required to complete a particular request/response window and clear the protocol state as client code expects to start a request/response conversation without leaving + * previous frames on the stack. + * + * @param source the source to decorate. + * @param The type of values in both source and output sequences. + * @return decorated {@link Mono}. + */ + public static Mono discardOnCancel(Mono source) { + return new MonoDiscardOnCancel<>(source, () -> { + }); + } + + /** + * Replay signals from {@link Mono the source} until cancellation. Drains the source for data signals if the subscriber cancels the subscription. + *

+ * Draining data is required to complete a particular request/response window and clear the protocol state as client code expects to start a request/response conversation without leaving + * previous frames on the stack. + *

Propagate the {@link Subscription#cancel()} signal to a {@link Runnable consumer}. + * + * @param source the source to decorate. + * @param cancelConsumer {@link Runnable} notified when the resulting {@link Mono} receives a {@link Subscription#cancel() cancel} signal. + * @param The type of values in both source and output sequences. + * @return decorated {@link Mono}. + */ + public static Mono discardOnCancel(Mono source, Runnable cancelConsumer) { + return new MonoDiscardOnCancel<>(source, cancelConsumer); + } + +} diff --git a/src/main/java/io/r2dbc/pool/Validation.java b/src/main/java/io/r2dbc/pool/Validation.java index 158aec1..829d0c5 100644 --- a/src/main/java/io/r2dbc/pool/Validation.java +++ b/src/main/java/io/r2dbc/pool/Validation.java @@ -38,7 +38,7 @@ static Mono validate(Connection connection, ValidationDepth depth) { return Flux.from(connection.validate(depth)).handle((state, sink) -> { if (state) { - sink.complete(); + sink.next(state); return; } diff --git a/src/test/java/io/r2dbc/pool/ConnectionPoolUnitTests.java b/src/test/java/io/r2dbc/pool/ConnectionPoolUnitTests.java index 1f11973..dc02353 100644 --- a/src/test/java/io/r2dbc/pool/ConnectionPoolUnitTests.java +++ b/src/test/java/io/r2dbc/pool/ConnectionPoolUnitTests.java @@ -28,6 +28,7 @@ import org.junit.jupiter.api.Test; import org.reactivestreams.Publisher; import org.springframework.util.ReflectionUtils; +import reactor.core.Disposable; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; @@ -43,6 +44,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicBoolean; @@ -928,6 +930,44 @@ void shouldInvokePostAllocateInOrder() { assertThat(order).containsExactly("Lifecycle.postAllocate", "Lifecycle.postAllocate.subscribe", "postAllocate", "postAllocate.subscribe"); } + @Test + void cancelDuringAllocationShouldCompleteAtomically() throws InterruptedException { + + ConnectionFactory connectionFactoryMock = mock(ConnectionFactory.class); + ConnectionWithLifecycle connectionMock = mock(ConnectionWithLifecycle.class); + + CountDownLatch prepareLatch = new CountDownLatch(1); + CountDownLatch validateLatch = new CountDownLatch(1); + AtomicBoolean seenCancel = new AtomicBoolean(); + Mono prepare = Mono.empty().delayElement(Duration.ofMillis(100)).doOnSuccess(s -> prepareLatch.countDown()).doOnCancel(() -> { + seenCancel.set(true); + }); + Mono validate = Mono.just(true).delayElement(Duration.ofSeconds(1)).doOnSuccess(s -> validateLatch.countDown()).doOnCancel(() -> { + seenCancel.set(true); + }); + + when(connectionFactoryMock.create()).thenAnswer(it -> Mono.just(connectionMock)); + when(connectionMock.validate(any())).thenReturn(validate); + when(connectionMock.postAllocate()).thenReturn(prepare); + + ConnectionPoolConfiguration configuration = ConnectionPoolConfiguration.builder(connectionFactoryMock) + .build(); + + ConnectionPool pool = new ConnectionPool(configuration); + Disposable subscribe = pool.create().subscribe(); + prepareLatch.await(); + subscribe.dispose(); + validateLatch.await(); + + PoolMetrics poolMetrics = pool.getMetrics().get(); + await().atMost(Duration.ofSeconds(1)).until(() -> poolMetrics.idleSize() == 10); + + assertThat(seenCancel).isFalse(); + assertThat(poolMetrics.pendingAcquireSize()).isEqualTo(0); + assertThat(poolMetrics.allocatedSize()).isEqualTo(10); + assertThat(poolMetrics.idleSize()).isEqualTo(10); + } + interface ConnectionWithLifecycle extends Connection, Lifecycle { }