diff --git a/.travis.yml b/.travis.yml index 116d2d2ff..a40bdf55e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -23,7 +23,7 @@ matrix: - jdk: openjdk8 - jdk: openjdk11 env: SKIP_RELEASE=true - - jdk: openjdk12 + - jdk: openjdk13 env: SKIP_RELEASE=true env: diff --git a/benchmarks/README.md b/benchmarks/README.md new file mode 100644 index 000000000..6ba6755a6 --- /dev/null +++ b/benchmarks/README.md @@ -0,0 +1,47 @@ +## Usage of JMH tasks + +Only execute specific benchmark(s) (wildcards are added before and after): +``` +../gradlew jmh --include="(BenchmarkPrimary|OtherBench)" +``` +If you want to specify the wildcards yourself, you can pass the full regexp: +``` +../gradlew jmh --fullInclude=.*MyBenchmark.* +``` + +Specify extra profilers: +``` +../gradlew jmh --profilers="gc,stack" +``` + +Prominent profilers (for full list call `jmhProfilers` task): +- comp - JitCompilations, tune your iterations +- stack - which methods used most time +- gc - print garbage collection stats +- hs_thr - thread usage + +Change report format from JSON to one of [CSV, JSON, NONE, SCSV, TEXT]: +``` +./gradlew jmh --format=csv +``` + +Specify JVM arguments: +``` +../gradlew jmh --jvmArgs="-Dtest.cluster=local" +``` + +Run in verification mode (execute benchmarks with minimum of fork/warmup-/benchmark-iterations): +``` +../gradlew jmh --verify=true +``` + +## Comparing with the baseline +If you wish you run two sets of benchmarks, one for the current change and another one for the "baseline", +there is an additional task `jmhBaseline` that will use the latest release: +``` +../gradlew jmh jmhBaseline --include=MyBenchmark +``` + +## Resources +- http://tutorials.jenkov.com/java-performance/jmh.html (Introduction) +- http://hg.openjdk.java.net/code-tools/jmh/file/tip/jmh-samples/src/main/java/org/openjdk/jmh/samples/ (Samples) diff --git a/benchmarks/build.gradle b/benchmarks/build.gradle new file mode 100644 index 000000000..fa1d6e04b --- /dev/null +++ b/benchmarks/build.gradle @@ -0,0 +1,163 @@ +apply plugin: 'java' +apply plugin: 'idea' + +configurations { + current + baseline { + resolutionStrategy.cacheChangingModulesFor 0, 'seconds' + } +} + +dependencies { + // Use the baseline to avoid using new APIs in the benchmarks + compileOnly "io.rsocket:rsocket-core:${perfBaselineVersion}" + compileOnly "io.rsocket:rsocket-transport-local:${perfBaselineVersion}" + + implementation "org.openjdk.jmh:jmh-core:1.21" + annotationProcessor "org.openjdk.jmh:jmh-generator-annprocess:1.21" + + current project(':rsocket-core') + current project(':rsocket-transport-local') + baseline "io.rsocket:rsocket-core:${perfBaselineVersion}", { + changing = true + } + baseline "io.rsocket:rsocket-transport-local:${perfBaselineVersion}", { + changing = true + } +} + +task jmhProfilers(type: JavaExec, description:'Lists the available profilers for the jmh task', group: 'Development') { + classpath = sourceSets.main.runtimeClasspath + main = 'org.openjdk.jmh.Main' + args '-lprof' +} + +task jmh(type: JmhExecTask, description: 'Executing JMH benchmarks') { + classpath = sourceSets.main.runtimeClasspath + configurations.current +} + +task jmhBaseline(type: JmhExecTask, description: 'Executing JMH baseline benchmarks') { + classpath = sourceSets.main.runtimeClasspath + configurations.baseline +} + +class JmhExecTask extends JavaExec { + + private String include; + private String fullInclude; + private String exclude; + private String format = "json"; + private String profilers; + private String jmhJvmArgs; + private String verify; + + public JmhExecTask() { + super(); + } + + public String getInclude() { + return include; + } + + @Option(option = "include", description="configure bench inclusion using substring") + public void setInclude(String include) { + this.include = include; + } + + public String getFullInclude() { + return fullInclude; + } + + @Option(option = "fullInclude", description = "explicitly configure bench inclusion using full JMH style regexp") + public void setFullInclude(String fullInclude) { + this.fullInclude = fullInclude; + } + + public String getExclude() { + return exclude; + } + + @Option(option = "exclude", description = "explicitly configure bench exclusion using full JMH style regexp") + public void setExclude(String exclude) { + this.exclude = exclude; + } + + public String getFormat() { + return format; + } + + @Option(option = "format", description = "configure report format") + public void setFormat(String format) { + this.format = format; + } + + public String getProfilers() { + return profilers; + } + + @Option(option = "profilers", description = "configure jmh profiler(s) to use, comma separated") + public void setProfilers(String profilers) { + this.profilers = profilers; + } + + public String getJmhJvmArgs() { + return jmhJvmArgs; + } + + @Option(option = "jvmArgs", description = "configure additional JMH JVM arguments, comma separated") + public void setJmhJvmArgs(String jvmArgs) { + this.jmhJvmArgs = jvmArgs; + } + + public String getVerify() { + return verify; + } + + @Option(option = "verify", description = "run in verify mode") + public void setVerify(String verify) { + this.verify = verify; + } + + @TaskAction + public void exec() { + setMain("org.openjdk.jmh.Main"); + File resultFile = getProject().file("build/reports/" + getName() + "/result." + format); + + if (include != null) { + args(".*" + include + ".*"); + } + else if (fullInclude != null) { + args(fullInclude); + } + + if(exclude != null) { + args("-e", exclude); + } + if(verify != null) { // execute benchmarks with the minimum amount of execution (only to check if they are working) + System.out.println("Running in verify mode"); + args("-f", 1); + args("-wi", 1); + args("-i", 1); + } + args("-foe", "true"); //fail-on-error + args("-v", "NORMAL"); //verbosity [SILENT, NORMAL, EXTRA] + if(profilers != null) { + for (String prof : profilers.split(",")) { + args("-prof", prof); + } + } + args("-jvmArgsPrepend", "-Xmx3072m"); + args("-jvmArgsPrepend", "-Xms3072m"); + if(jmhJvmArgs != null) { + for(String jvmArg : jmhJvmArgs.split(" ")) { + args("-jvmArgsPrepend", jvmArg); + } + } + args("-rf", format); + args("-rff", resultFile); + + System.out.println("\nExecuting JMH with: " + getArgs() + "\n"); + resultFile.getParentFile().mkdirs(); + + super.exec(); + } +} \ No newline at end of file diff --git a/rsocket-core/src/jmh/java/io/rsocket/MaxPerfSubscriber.java b/benchmarks/src/main/java/io/rsocket/MaxPerfSubscriber.java similarity index 71% rename from rsocket-core/src/jmh/java/io/rsocket/MaxPerfSubscriber.java rename to benchmarks/src/main/java/io/rsocket/MaxPerfSubscriber.java index ace985a39..2e6fa6acc 100644 --- a/rsocket-core/src/jmh/java/io/rsocket/MaxPerfSubscriber.java +++ b/benchmarks/src/main/java/io/rsocket/MaxPerfSubscriber.java @@ -5,12 +5,12 @@ import org.reactivestreams.Subscription; import reactor.core.CoreSubscriber; -public class MaxPerfSubscriber implements CoreSubscriber { +public class MaxPerfSubscriber extends CountDownLatch implements CoreSubscriber { - final CountDownLatch latch = new CountDownLatch(1); final Blackhole blackhole; public MaxPerfSubscriber(Blackhole blackhole) { + super(1); this.blackhole = blackhole; } @@ -20,19 +20,18 @@ public void onSubscribe(Subscription s) { } @Override - public void onNext(Payload payload) { - payload.release(); + public void onNext(T payload) { blackhole.consume(payload); } @Override public void onError(Throwable t) { blackhole.consume(t); - latch.countDown(); + countDown(); } @Override public void onComplete() { - latch.countDown(); + countDown(); } } diff --git a/benchmarks/src/main/java/io/rsocket/PayloadsMaxPerfSubscriber.java b/benchmarks/src/main/java/io/rsocket/PayloadsMaxPerfSubscriber.java new file mode 100644 index 000000000..7a7a1fdd6 --- /dev/null +++ b/benchmarks/src/main/java/io/rsocket/PayloadsMaxPerfSubscriber.java @@ -0,0 +1,16 @@ +package io.rsocket; + +import org.openjdk.jmh.infra.Blackhole; + +public class PayloadsMaxPerfSubscriber extends MaxPerfSubscriber { + + public PayloadsMaxPerfSubscriber(Blackhole blackhole) { + super(blackhole); + } + + @Override + public void onNext(Payload payload) { + payload.release(); + super.onNext(payload); + } +} diff --git a/benchmarks/src/main/java/io/rsocket/PayloadsPerfSubscriber.java b/benchmarks/src/main/java/io/rsocket/PayloadsPerfSubscriber.java new file mode 100644 index 000000000..efc116958 --- /dev/null +++ b/benchmarks/src/main/java/io/rsocket/PayloadsPerfSubscriber.java @@ -0,0 +1,16 @@ +package io.rsocket; + +import org.openjdk.jmh.infra.Blackhole; + +public class PayloadsPerfSubscriber extends PerfSubscriber { + + public PayloadsPerfSubscriber(Blackhole blackhole) { + super(blackhole); + } + + @Override + public void onNext(Payload payload) { + payload.release(); + super.onNext(payload); + } +} diff --git a/rsocket-core/src/jmh/java/io/rsocket/PerfSubscriber.java b/benchmarks/src/main/java/io/rsocket/PerfSubscriber.java similarity index 72% rename from rsocket-core/src/jmh/java/io/rsocket/PerfSubscriber.java rename to benchmarks/src/main/java/io/rsocket/PerfSubscriber.java index 98c5edd3b..92577d95c 100644 --- a/rsocket-core/src/jmh/java/io/rsocket/PerfSubscriber.java +++ b/benchmarks/src/main/java/io/rsocket/PerfSubscriber.java @@ -5,14 +5,14 @@ import org.reactivestreams.Subscription; import reactor.core.CoreSubscriber; -public class PerfSubscriber implements CoreSubscriber { +public class PerfSubscriber extends CountDownLatch implements CoreSubscriber { - final CountDownLatch latch = new CountDownLatch(1); final Blackhole blackhole; Subscription s; public PerfSubscriber(Blackhole blackhole) { + super(1); this.blackhole = blackhole; } @@ -23,8 +23,7 @@ public void onSubscribe(Subscription s) { } @Override - public void onNext(Payload payload) { - payload.release(); + public void onNext(T payload) { blackhole.consume(payload); s.request(1); } @@ -32,11 +31,11 @@ public void onNext(Payload payload) { @Override public void onError(Throwable t) { blackhole.consume(t); - latch.countDown(); + countDown(); } @Override public void onComplete() { - latch.countDown(); + countDown(); } } diff --git a/rsocket-core/src/jmh/java/io/rsocket/RSocketPerf.java b/benchmarks/src/main/java/io/rsocket/RSocketPerf.java similarity index 64% rename from rsocket-core/src/jmh/java/io/rsocket/RSocketPerf.java rename to benchmarks/src/main/java/io/rsocket/RSocketPerf.java index 476d6c814..0c6515140 100644 --- a/rsocket-core/src/jmh/java/io/rsocket/RSocketPerf.java +++ b/benchmarks/src/main/java/io/rsocket/RSocketPerf.java @@ -4,15 +4,20 @@ import io.rsocket.transport.local.LocalClientTransport; import io.rsocket.transport.local.LocalServerTransport; import io.rsocket.util.EmptyPayload; +import java.lang.reflect.Field; +import java.util.Queue; +import java.util.concurrent.locks.LockSupport; import java.util.stream.IntStream; import org.openjdk.jmh.annotations.Benchmark; import org.openjdk.jmh.annotations.BenchmarkMode; import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; import org.openjdk.jmh.annotations.Measurement; import org.openjdk.jmh.annotations.Mode; import org.openjdk.jmh.annotations.Scope; import org.openjdk.jmh.annotations.Setup; import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; import org.openjdk.jmh.annotations.Warmup; import org.openjdk.jmh.infra.Blackhole; import org.reactivestreams.Publisher; @@ -36,11 +41,26 @@ public class RSocketPerf { RSocket client; Closeable server; + Queue clientsQueue; + + @TearDown + public void tearDown() { + client.dispose(); + server.dispose(); + } + + @TearDown(Level.Iteration) + public void awaitToBeConsumed() { + while (!clientsQueue.isEmpty()) { + LockSupport.parkNanos(1000); + } + } @Setup - public void setUp() { + public void setUp() throws NoSuchFieldException, IllegalAccessException { server = RSocketFactory.receive() + .frameDecoder(PayloadDecoder.ZERO_COPY) .acceptor( (setup, sendingSocket) -> Mono.just( @@ -75,67 +95,73 @@ public Flux requestChannel(Publisher payloads) { client = RSocketFactory.connect() + .singleSubscriberRequester() .frameDecoder(PayloadDecoder.ZERO_COPY) .transport(LocalClientTransport.create("server")) .start() .block(); + + Field sendProcessorField = RSocketRequester.class.getDeclaredField("sendProcessor"); + sendProcessorField.setAccessible(true); + + clientsQueue = (Queue) sendProcessorField.get(client); } @Benchmark @SuppressWarnings("unchecked") - public PerfSubscriber fireAndForget(Blackhole blackhole) throws InterruptedException { - PerfSubscriber subscriber = new PerfSubscriber(blackhole); + public PayloadsPerfSubscriber fireAndForget(Blackhole blackhole) throws InterruptedException { + PayloadsPerfSubscriber subscriber = new PayloadsPerfSubscriber(blackhole); client.fireAndForget(PAYLOAD).subscribe((CoreSubscriber) subscriber); - subscriber.latch.await(); + subscriber.await(); return subscriber; } @Benchmark - public PerfSubscriber requestResponse(Blackhole blackhole) throws InterruptedException { - PerfSubscriber subscriber = new PerfSubscriber(blackhole); + public PayloadsPerfSubscriber requestResponse(Blackhole blackhole) throws InterruptedException { + PayloadsPerfSubscriber subscriber = new PayloadsPerfSubscriber(blackhole); client.requestResponse(PAYLOAD).subscribe(subscriber); - subscriber.latch.await(); + subscriber.await(); return subscriber; } @Benchmark - public PerfSubscriber requestStreamWithRequestByOneStrategy(Blackhole blackhole) + public PayloadsPerfSubscriber requestStreamWithRequestByOneStrategy(Blackhole blackhole) throws InterruptedException { - PerfSubscriber subscriber = new PerfSubscriber(blackhole); + PayloadsPerfSubscriber subscriber = new PayloadsPerfSubscriber(blackhole); client.requestStream(PAYLOAD).subscribe(subscriber); - subscriber.latch.await(); + subscriber.await(); return subscriber; } @Benchmark - public MaxPerfSubscriber requestStreamWithRequestAllStrategy(Blackhole blackhole) + public PayloadsMaxPerfSubscriber requestStreamWithRequestAllStrategy(Blackhole blackhole) throws InterruptedException { - MaxPerfSubscriber subscriber = new MaxPerfSubscriber(blackhole); + PayloadsMaxPerfSubscriber subscriber = new PayloadsMaxPerfSubscriber(blackhole); client.requestStream(PAYLOAD).subscribe(subscriber); - subscriber.latch.await(); + subscriber.await(); return subscriber; } @Benchmark - public PerfSubscriber requestChannelWithRequestByOneStrategy(Blackhole blackhole) + public PayloadsPerfSubscriber requestChannelWithRequestByOneStrategy(Blackhole blackhole) throws InterruptedException { - PerfSubscriber subscriber = new PerfSubscriber(blackhole); + PayloadsPerfSubscriber subscriber = new PayloadsPerfSubscriber(blackhole); client.requestChannel(PAYLOAD_FLUX).subscribe(subscriber); - subscriber.latch.await(); + subscriber.await(); return subscriber; } @Benchmark - public MaxPerfSubscriber requestChannelWithRequestAllStrategy(Blackhole blackhole) + public PayloadsMaxPerfSubscriber requestChannelWithRequestAllStrategy(Blackhole blackhole) throws InterruptedException { - MaxPerfSubscriber subscriber = new MaxPerfSubscriber(blackhole); + PayloadsMaxPerfSubscriber subscriber = new PayloadsMaxPerfSubscriber(blackhole); client.requestChannel(PAYLOAD_FLUX).subscribe(subscriber); - subscriber.latch.await(); + subscriber.await(); return subscriber; } diff --git a/rsocket-core/src/jmh/java/io/rsocket/StreamIdSupplierPerf.java b/benchmarks/src/main/java/io/rsocket/StreamIdSupplierPerf.java similarity index 100% rename from rsocket-core/src/jmh/java/io/rsocket/StreamIdSupplierPerf.java rename to benchmarks/src/main/java/io/rsocket/StreamIdSupplierPerf.java diff --git a/rsocket-core/src/jmh/java/io/rsocket/frame/FrameHeaderFlyweightPerf.java b/benchmarks/src/main/java/io/rsocket/frame/FrameHeaderFlyweightPerf.java similarity index 100% rename from rsocket-core/src/jmh/java/io/rsocket/frame/FrameHeaderFlyweightPerf.java rename to benchmarks/src/main/java/io/rsocket/frame/FrameHeaderFlyweightPerf.java diff --git a/rsocket-core/src/jmh/java/io/rsocket/frame/FrameTypePerf.java b/benchmarks/src/main/java/io/rsocket/frame/FrameTypePerf.java similarity index 100% rename from rsocket-core/src/jmh/java/io/rsocket/frame/FrameTypePerf.java rename to benchmarks/src/main/java/io/rsocket/frame/FrameTypePerf.java diff --git a/rsocket-core/src/jmh/java/io/rsocket/frame/PayloadFlyweightPerf.java b/benchmarks/src/main/java/io/rsocket/frame/PayloadFlyweightPerf.java similarity index 100% rename from rsocket-core/src/jmh/java/io/rsocket/frame/PayloadFlyweightPerf.java rename to benchmarks/src/main/java/io/rsocket/frame/PayloadFlyweightPerf.java diff --git a/benchmarks/src/main/java/io/rsocket/internal/UnicastVsDefaultMonoProcessorPerf.java b/benchmarks/src/main/java/io/rsocket/internal/UnicastVsDefaultMonoProcessorPerf.java new file mode 100644 index 000000000..963067ef0 --- /dev/null +++ b/benchmarks/src/main/java/io/rsocket/internal/UnicastVsDefaultMonoProcessorPerf.java @@ -0,0 +1,46 @@ +package io.rsocket.internal; + +import io.rsocket.MaxPerfSubscriber; +import java.util.concurrent.TimeUnit; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; +import reactor.core.publisher.MonoProcessor; + +@BenchmarkMode({Mode.Throughput, Mode.SampleTime}) +@Fork(1) +@Warmup(iterations = 10) +@Measurement(iterations = 10, time = 20) +@State(Scope.Benchmark) +@OutputTimeUnit(TimeUnit.MICROSECONDS) +public class UnicastVsDefaultMonoProcessorPerf { + + @Benchmark + public void monoProcessorPerf(Blackhole bh) { + MaxPerfSubscriber subscriber = new MaxPerfSubscriber<>(bh); + MonoProcessor monoProcessor = MonoProcessor.create(); + monoProcessor.onNext(1); + monoProcessor.subscribe(subscriber); + + bh.consume(monoProcessor); + bh.consume(subscriber); + } + + @Benchmark + public void unicastMonoProcessorPerf(Blackhole bh) { + MaxPerfSubscriber subscriber = new MaxPerfSubscriber<>(bh); + UnicastMonoProcessor monoProcessor = UnicastMonoProcessor.create(); + monoProcessor.onNext(1); + monoProcessor.subscribe(subscriber); + + bh.consume(monoProcessor); + bh.consume(subscriber); + } +} diff --git a/rsocket-core/src/jmh/java/io/rsocket/metadata/WellKnownMimeTypePerf.java b/benchmarks/src/main/java/io/rsocket/metadata/WellKnownMimeTypePerf.java similarity index 100% rename from rsocket-core/src/jmh/java/io/rsocket/metadata/WellKnownMimeTypePerf.java rename to benchmarks/src/main/java/io/rsocket/metadata/WellKnownMimeTypePerf.java diff --git a/build.gradle b/build.gradle index c354869df..e834e21bd 100644 --- a/build.gradle +++ b/build.gradle @@ -15,13 +15,11 @@ */ plugins { - id 'com.gradle.build-scan' version '2.4.2' - id 'com.github.sherter.google-java-format' version '0.8' apply false - id 'com.jfrog.artifactory' version '4.9.10' apply false + id 'com.jfrog.artifactory' version '4.11.0' apply false id 'com.jfrog.bintray' version '1.8.4' apply false - id 'me.champeau.gradle.jmh' version '0.4.8' apply false - id 'io.spring.dependency-management' version '1.0.7.RELEASE' apply false + id 'me.champeau.gradle.jmh' version '0.5.0' apply false + id 'io.spring.dependency-management' version '1.0.8.RELEASE' apply false id 'io.morethan.jmhreport' version '0.9.0' apply false } @@ -35,7 +33,7 @@ subprojects { ext['netty-bom.version'] = '4.1.37.Final' ext['netty-boringssl.version'] = '2.0.25.Final' ext['hdrhistogram.version'] = '2.1.10' - ext['mockito.version'] = '2.25.1' + ext['mockito.version'] = '3.2.0' ext['slf4j.version'] = '1.7.25' ext['jmh.version'] = '1.21' ext['junit.version'] = '5.5.2' @@ -145,7 +143,6 @@ subprojects { } } } - } apply from: "${rootDir}/gradle/publications.gradle" diff --git a/ci/travis.sh b/ci/travis.sh index 9154da33b..d190a59ec 100755 --- a/ci/travis.sh +++ b/ci/travis.sh @@ -5,13 +5,24 @@ if [ "$TRAVIS_PULL_REQUEST" != "false" ]; then echo -e "Building PR #$TRAVIS_PULL_REQUEST [$TRAVIS_PULL_REQUEST_SLUG/$TRAVIS_PULL_REQUEST_BRANCH => $TRAVIS_REPO_SLUG/$TRAVIS_BRANCH]" ./gradlew build -elif [ "$TRAVIS_PULL_REQUEST" == "false" ] && [ "$TRAVIS_TAG" == "" ] && [ "$bintrayUser" != "" ] ; then +elif [ "$TRAVIS_PULL_REQUEST" == "false" ] && [ "$TRAVIS_TAG" == "" ] && [ "$bintrayUser" != "" ] && [ "$TRAVIS_BRANCH" == "develop" ] ; then - echo -e "Building Snapshot $TRAVIS_REPO_SLUG/$TRAVIS_BRANCH" + echo -e "Building Develop Snapshot $TRAVIS_REPO_SLUG/$TRAVIS_BRANCH/$TRAVIS_BUILD_NUMBER" ./gradlew \ -PbintrayUser="${bintrayUser}" -PbintrayKey="${bintrayKey}" \ -PsonatypeUsername="${sonatypeUsername}" -PsonatypePassword="${sonatypePassword}" \ -PversionSuffix="-SNAPSHOT" \ + -PbuildNumber="$TRAVIS_BUILD_NUMBER" \ + build artifactoryPublish --stacktrace + +elif [ "$TRAVIS_PULL_REQUEST" == "false" ] && [ "$TRAVIS_TAG" == "" ] && [ "$bintrayUser" != "" ] ; then + + echo -e "Building Branch Snapshot $TRAVIS_REPO_SLUG/$TRAVIS_BRANCH/$TRAVIS_BUILD_NUMBER" + ./gradlew \ + -PbintrayUser="${bintrayUser}" -PbintrayKey="${bintrayKey}" \ + -PsonatypeUsername="${sonatypeUsername}" -PsonatypePassword="${sonatypePassword}" \ + -PversionSuffix="-${TRAVIS_BRANCH//\//-}-SNAPSHOT" \ + -PbuildNumber="$TRAVIS_BUILD_NUMBER" \ build artifactoryPublish --stacktrace elif [ "$TRAVIS_PULL_REQUEST" == "false" ] && [ "$TRAVIS_TAG" != "" ] && [ "$bintrayUser" != "" ] ; then @@ -21,6 +32,7 @@ elif [ "$TRAVIS_PULL_REQUEST" == "false" ] && [ "$TRAVIS_TAG" != "" ] && [ "$bin -Pversion="$TRAVIS_TAG" \ -PbintrayUser="${bintrayUser}" -PbintrayKey="${bintrayKey}" \ -PsonatypeUsername="${sonatypeUsername}" -PsonatypePassword="${sonatypePassword}" \ + -PbuildNumber="$TRAVIS_BUILD_NUMBER" \ build bintrayUpload --stacktrace else diff --git a/gradle.properties b/gradle.properties index eddfb6cb3..13a89e30c 100644 --- a/gradle.properties +++ b/gradle.properties @@ -11,4 +11,5 @@ # See the License for the specific language governing permissions and # limitations under the License. # -version=1.0.0-RC5 +version=1.0.0-RC6 +perfBaselineVersion=1.0.0-RC5 diff --git a/gradle/artifactory.gradle b/gradle/artifactory.gradle index 7f4369242..cdffb2741 100644 --- a/gradle/artifactory.gradle +++ b/gradle/artifactory.gradle @@ -33,6 +33,10 @@ if (project.hasProperty('bintrayUser') && project.hasProperty('bintrayKey')) { defaults { publications(publishing.publications.maven) } + + if (project.hasProperty('buildNumber')) { + clientConfig.info.setBuildNumber(project.property('buildNumber').toString()) + } } } tasks.named("artifactoryPublish").configure { diff --git a/gradle/wrapper/gradle-wrapper.jar b/gradle/wrapper/gradle-wrapper.jar index 29953ea14..5c2d1cf01 100644 Binary files a/gradle/wrapper/gradle-wrapper.jar and b/gradle/wrapper/gradle-wrapper.jar differ diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index 7c4388a92..94920145f 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,5 +1,5 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-5.6.2-bin.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-6.0.1-bin.zip zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists diff --git a/gradlew b/gradlew index cccdd3d51..83f2acfdc 100755 --- a/gradlew +++ b/gradlew @@ -1,5 +1,21 @@ #!/usr/bin/env sh +# +# Copyright 2015 the original author or authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + ############################################################################## ## ## Gradle start up script for UN*X @@ -28,7 +44,7 @@ APP_NAME="Gradle" APP_BASE_NAME=`basename "$0"` # Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. -DEFAULT_JVM_OPTS="" +DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' # Use the maximum available, or set MAX_FD != -1 to use that value. MAX_FD="maximum" @@ -109,8 +125,8 @@ if $darwin; then GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" fi -# For Cygwin, switch paths to Windows format before running java -if $cygwin ; then +# For Cygwin or MSYS, switch paths to Windows format before running java +if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then APP_HOME=`cygpath --path --mixed "$APP_HOME"` CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` JAVACMD=`cygpath --unix "$JAVACMD"` diff --git a/gradlew.bat b/gradlew.bat index e95643d6a..24467a141 100644 --- a/gradlew.bat +++ b/gradlew.bat @@ -1,3 +1,19 @@ +@rem +@rem Copyright 2015 the original author or authors. +@rem +@rem Licensed under the Apache License, Version 2.0 (the "License"); +@rem you may not use this file except in compliance with the License. +@rem You may obtain a copy of the License at +@rem +@rem https://www.apache.org/licenses/LICENSE-2.0 +@rem +@rem Unless required by applicable law or agreed to in writing, software +@rem distributed under the License is distributed on an "AS IS" BASIS, +@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +@rem See the License for the specific language governing permissions and +@rem limitations under the License. +@rem + @if "%DEBUG%" == "" @echo off @rem ########################################################################## @rem @@ -14,7 +30,7 @@ set APP_BASE_NAME=%~n0 set APP_HOME=%DIRNAME% @rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. -set DEFAULT_JVM_OPTS= +set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" @rem Find java.exe if defined JAVA_HOME goto findJavaFromJavaHome diff --git a/rsocket-bom/build.gradle b/rsocket-bom/build.gradle index ca48a87c0..2efc20a91 100755 --- a/rsocket-bom/build.gradle +++ b/rsocket-bom/build.gradle @@ -22,9 +22,11 @@ plugins { description = 'RSocket Java Bill of materials.' +def excluded = ["rsocket-examples", "benchmarks"] + dependencies { constraints { - parent.subprojects.findAll { it.name != project.name }.sort { "$it.name" }.each { + parent.subprojects.findAll { it.name != project.name && !excluded.contains(it.name) } .sort { "$it.name" }.each { api it } } @@ -34,12 +36,6 @@ publishing { publications { maven(MavenPublication) { from components.javaPlatform - // remove scope information from published BOM - pom.withXml { - asNode().dependencyManagement.first().dependencies.first().each { - it.remove(it.scope.first()) - } - } } } } \ No newline at end of file diff --git a/rsocket-core/build.gradle b/rsocket-core/build.gradle index d62452619..ca2ae0e65 100644 --- a/rsocket-core/build.gradle +++ b/rsocket-core/build.gradle @@ -46,6 +46,4 @@ dependencies { testRuntimeOnly 'org.junit.vintage:junit-vintage-engine' } -description = "Core functionality for the RSocket library" - -apply from: 'jmh.gradle' +description = "Core functionality for the RSocket library" \ No newline at end of file diff --git a/rsocket-core/jmh.gradle b/rsocket-core/jmh.gradle deleted file mode 100644 index 2a2b4d7cd..000000000 --- a/rsocket-core/jmh.gradle +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -dependencies { - jmh configurations.api - jmh configurations.implementation - jmh 'org.openjdk.jmh:jmh-core' - jmh 'org.openjdk.jmh:jmh-generator-annprocess' - jmh 'io.projectreactor:reactor-test' - jmh project(':rsocket-transport-local') -} - -jmhCompileGeneratedClasses.enabled = false - -jmh { - includeTests = false - profilers = ['gc'] - resultFormat = 'JSON' - - jvmArgs = ['-XX:+UnlockCommercialFeatures', '-XX:+FlightRecorder'] - // jvmArgsAppend = ['-XX:+UseG1GC', '-Xms4g', '-Xmx4g'] -} - -jmhJar { - from project.configurations.jmh -} - -tasks.jmh.finalizedBy tasks.jmhReport - -jmhReport { - jmhResultPath = project.file('build/reports/jmh/results.json') - jmhReportOutput = project.file('build/reports/jmh') -} diff --git a/rsocket-core/src/main/java/io/rsocket/DuplexConnection.java b/rsocket-core/src/main/java/io/rsocket/DuplexConnection.java index 7739a34c0..b87ed0570 100644 --- a/rsocket-core/src/main/java/io/rsocket/DuplexConnection.java +++ b/rsocket-core/src/main/java/io/rsocket/DuplexConnection.java @@ -30,9 +30,9 @@ public interface DuplexConnection extends Availability, Closeable { * Sends the source of Frames on this connection and returns the {@code Publisher} representing * the result of this send. * - *

Flow control

+ *

Flow control * - * The passed {@code Publisher} must + *

The passed {@code Publisher} must * * @param frames Stream of {@code Frame}s to send on the connection. * @return {@code Publisher} that completes when all the frames are written on the connection @@ -56,20 +56,20 @@ default Mono sendOne(ByteBuf frame) { /** * Returns a stream of all {@code Frame}s received on this connection. * - *

Completion

+ *

Completion * - * Returned {@code Publisher} MUST never emit a completion event ({@link + *

Returned {@code Publisher} MUST never emit a completion event ({@link * Subscriber#onComplete()}. * - *

Error

+ *

Error * - * Returned {@code Publisher} can error with various transport errors. If the underlying physical - * connection is closed by the peer, then the returned stream from here MUST emit an - * {@link ClosedChannelException}. + *

Returned {@code Publisher} can error with various transport errors. If the underlying + * physical connection is closed by the peer, then the returned stream from here MUST + * emit an {@link ClosedChannelException}. * - *

Multiple Subscriptions

+ *

Multiple Subscriptions * - * Returned {@code Publisher} is not required to support multiple concurrent subscriptions. + *

Returned {@code Publisher} is not required to support multiple concurrent subscriptions. * RSocket will never have multiple subscriptions to this source. Implementations MUST * emit an {@link IllegalStateException} for subsequent concurrent subscriptions, if they do not * support multiple concurrent subscriptions. diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketFactory.java b/rsocket-core/src/main/java/io/rsocket/RSocketFactory.java index b6c268464..44f64e550 100644 --- a/rsocket-core/src/main/java/io/rsocket/RSocketFactory.java +++ b/rsocket-core/src/main/java/io/rsocket/RSocketFactory.java @@ -44,7 +44,6 @@ import io.rsocket.util.MultiSubscriberRSocket; import java.time.Duration; import java.util.Objects; -import java.util.function.BiFunction; import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; @@ -277,17 +276,7 @@ public ClientTransportAcceptor acceptor(Function acceptor) { } public ClientTransportAcceptor acceptor(Supplier> acceptor) { - return acceptor( - (SocketAcceptor) - (setup, sendingSocket) -> Mono.just(acceptor.get().apply(sendingSocket))); - } - - @Deprecated - public ClientTransportAcceptor acceptor( - BiFunction biAcceptor) { - return acceptor( - (SocketAcceptor) - (setup, sendingSocket) -> Mono.just(biAcceptor.apply(setup, sendingSocket))); + return acceptor((setup, sendingSocket) -> Mono.just(acceptor.get().apply(sendingSocket))); } public ClientTransportAcceptor acceptor(SocketAcceptor acceptor) { diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketRequester.java b/rsocket-core/src/main/java/io/rsocket/RSocketRequester.java index f921365da..5590a9df0 100644 --- a/rsocket-core/src/main/java/io/rsocket/RSocketRequester.java +++ b/rsocket-core/src/main/java/io/rsocket/RSocketRequester.java @@ -25,28 +25,43 @@ import io.netty.util.collection.IntObjectMap; import io.rsocket.exceptions.ConnectionErrorException; import io.rsocket.exceptions.Exceptions; -import io.rsocket.frame.*; +import io.rsocket.frame.CancelFrameFlyweight; +import io.rsocket.frame.ErrorFrameFlyweight; +import io.rsocket.frame.FrameHeaderFlyweight; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.MetadataPushFrameFlyweight; +import io.rsocket.frame.PayloadFrameFlyweight; +import io.rsocket.frame.RequestChannelFrameFlyweight; +import io.rsocket.frame.RequestFireAndForgetFrameFlyweight; +import io.rsocket.frame.RequestNFrameFlyweight; +import io.rsocket.frame.RequestResponseFrameFlyweight; +import io.rsocket.frame.RequestStreamFrameFlyweight; import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.internal.RateLimitableRequestPublisher; import io.rsocket.internal.SynchronizedIntObjectHashMap; import io.rsocket.internal.UnboundedProcessor; +import io.rsocket.internal.UnicastMonoEmpty; import io.rsocket.internal.UnicastMonoProcessor; import io.rsocket.keepalive.KeepAliveFramesAcceptor; import io.rsocket.keepalive.KeepAliveHandler; import io.rsocket.keepalive.KeepAliveSupport; import io.rsocket.lease.RequesterLeaseHandler; -import io.rsocket.util.OnceConsumer; +import io.rsocket.util.MonoLifecycleHandler; import java.nio.channels.ClosedChannelException; import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; import java.util.function.Consumer; import java.util.function.LongConsumer; +import java.util.function.Supplier; import javax.annotation.Nonnull; import javax.annotation.Nullable; import org.reactivestreams.Processor; import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; -import reactor.core.publisher.*; +import reactor.core.publisher.BaseSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.SignalType; +import reactor.core.publisher.UnicastProcessor; import reactor.util.concurrent.Queues; /** @@ -56,6 +71,11 @@ class RSocketRequester implements RSocket { private static final AtomicReferenceFieldUpdater TERMINATION_ERROR = AtomicReferenceFieldUpdater.newUpdater( RSocketRequester.class, Throwable.class, "terminationError"); + private static final Exception CLOSED_CHANNEL_EXCEPTION = new ClosedChannelException(); + + static { + CLOSED_CHANNEL_EXCEPTION.setStackTrace(new StackTraceElement[0]); + } private final DuplexConnection connection; private final PayloadDecoder payloadDecoder; @@ -91,11 +111,11 @@ class RSocketRequester implements RSocket { // DO NOT Change the order here. The Send processor must be subscribed to before receiving this.sendProcessor = new UnboundedProcessor<>(); - connection.onClose().doFinally(signalType -> terminate()).subscribe(null, errorConsumer); connection - .send(sendProcessor) - .doFinally(this::handleSendProcessorCancel) - .subscribe(null, this::handleSendProcessorError); + .onClose() + .doFinally(signalType -> tryTerminateOnConnectionClose()) + .subscribe(null, errorConsumer); + connection.send(sendProcessor).subscribe(null, this::handleSendProcessorError); connection.receive().subscribe(this::handleIncomingFrames, errorConsumer); @@ -103,57 +123,13 @@ class RSocketRequester implements RSocket { KeepAliveSupport keepAliveSupport = new ClientKeepAliveSupport(allocator, keepAliveTickPeriod, keepAliveAckTimeout); this.keepAliveFramesAcceptor = - keepAliveHandler.start(keepAliveSupport, sendProcessor::onNext, this::terminate); + keepAliveHandler.start( + keepAliveSupport, sendProcessor::onNext, this::tryTerminateOnKeepAlive); } else { keepAliveFramesAcceptor = null; } } - private void terminate(KeepAlive keepAlive) { - String message = - String.format("No keep-alive acks for %d ms", keepAlive.getTimeout().toMillis()); - ConnectionErrorException err = new ConnectionErrorException(message); - setTerminationError(err); - errorConsumer.accept(err); - connection.dispose(); - } - - private void handleSendProcessorError(Throwable t) { - Throwable terminationError = this.terminationError; - Throwable err = terminationError != null ? terminationError : t; - receivers - .values() - .forEach( - subscriber -> { - try { - subscriber.onError(err); - } catch (Throwable e) { - errorConsumer.accept(e); - } - }); - - senders.values().forEach(RateLimitableRequestPublisher::cancel); - } - - private void handleSendProcessorCancel(SignalType t) { - if (SignalType.ON_ERROR == t) { - return; - } - - receivers - .values() - .forEach( - subscriber -> { - try { - subscriber.onError(new Throwable("closed connection")); - } catch (Throwable e) { - errorConsumer.accept(e); - } - }); - - senders.values().forEach(RateLimitableRequestPublisher::cancel); - } - @Override public Mono fireAndForget(Payload payload) { return handleFireAndForget(payload); @@ -208,23 +184,19 @@ private Mono handleFireAndForget(Payload payload) { final int streamId = streamIdSupplier.nextStreamId(receivers); - return emptyUnicastMono() - .doOnSubscribe( - new OnceConsumer() { - @Override - public void acceptOnce(@Nonnull Subscription subscription) { - ByteBuf requestFrame = - RequestFireAndForgetFrameFlyweight.encode( - allocator, - streamId, - false, - payload.hasMetadata() ? payload.sliceMetadata().retain() : null, - payload.sliceData().retain()); - payload.release(); - - sendProcessor.onNext(requestFrame); - } - }); + return UnicastMonoEmpty.newInstance( + () -> { + ByteBuf requestFrame = + RequestFireAndForgetFrameFlyweight.encode( + allocator, + streamId, + false, + payload.hasMetadata() ? payload.sliceMetadata().retain() : null, + payload.sliceData().retain()); + payload.release(); + + sendProcessor.onNext(requestFrame); + }); } private Mono handleRequestResponse(final Payload payload) { @@ -237,14 +209,11 @@ private Mono handleRequestResponse(final Payload payload) { int streamId = streamIdSupplier.nextStreamId(receivers); final UnboundedProcessor sendProcessor = this.sendProcessor; - UnicastMonoProcessor receiver = UnicastMonoProcessor.create(); - receivers.put(streamId, receiver); - - return receiver - .doOnSubscribe( - new OnceConsumer() { + UnicastMonoProcessor receiver = + UnicastMonoProcessor.create( + new MonoLifecycleHandler() { @Override - public void acceptOnce(@Nonnull Subscription subscription) { + public void doOnSubscribe() { final ByteBuf requestFrame = RequestResponseFrameFlyweight.encode( allocator, @@ -256,16 +225,23 @@ public void acceptOnce(@Nonnull Subscription subscription) { sendProcessor.onNext(requestFrame); } - }) - .doOnError(t -> sendProcessor.onNext(ErrorFrameFlyweight.encode(allocator, streamId, t))) - .doFinally( - s -> { - if (s == SignalType.CANCEL) { - sendProcessor.onNext(CancelFrameFlyweight.encode(allocator, streamId)); - } - receivers.remove(streamId); + @Override + public void doOnTerminal( + @Nonnull SignalType signalType, + @Nullable Payload element, + @Nullable Throwable e) { + if (signalType == SignalType.ON_ERROR) { + sendProcessor.onNext(ErrorFrameFlyweight.encode(allocator, streamId, e)); + } else if (signalType == SignalType.CANCEL) { + sendProcessor.onNext(CancelFrameFlyweight.encode(allocator, streamId)); + } + removeStreamReceiver(streamId); + } }); + receivers.put(streamId, receiver); + + return receiver; } private Flux handleRequestStream(final Payload payload) { @@ -318,7 +294,7 @@ public void accept(long n) { sendProcessor.onNext(CancelFrameFlyweight.encode(allocator, streamId)); } }) - .doFinally(s -> receivers.remove(streamId)); + .doFinally(s -> removeStreamReceiver(streamId)); } private Flux handleChannel(Flux request) { @@ -419,14 +395,7 @@ protected void hookOnError(Throwable t) { sendProcessor.onNext(CancelFrameFlyweight.encode(allocator, streamId)); } }) - .doFinally( - s -> { - receivers.remove(streamId); - RateLimitableRequestPublisher sender = senders.remove(streamId); - if (sender != null) { - sender.cancel(); - } - }); + .doFinally(s -> removeStreamReceiverAndSender(streamId)); } private Mono handleMetadataPush(Payload payload) { @@ -436,24 +405,14 @@ private Mono handleMetadataPush(Payload payload) { return Mono.error(err); } - return emptyUnicastMono() - .doOnSubscribe( - new OnceConsumer() { - @Override - public void acceptOnce(@Nonnull Subscription subscription) { - ByteBuf metadataPushFrame = - MetadataPushFrameFlyweight.encode(allocator, payload.sliceMetadata().retain()); - payload.release(); - - sendProcessor.onNext(metadataPushFrame); - } - }); - } + return UnicastMonoEmpty.newInstance( + () -> { + ByteBuf metadataPushFrame = + MetadataPushFrameFlyweight.encode(allocator, payload.sliceMetadata().retain()); + payload.release(); - private static UnicastMonoProcessor emptyUnicastMono() { - UnicastMonoProcessor result = UnicastMonoProcessor.create(); - result.onComplete(); - return result; + sendProcessor.onNext(metadataPushFrame); + }); } private Throwable checkAvailable() { @@ -472,40 +431,6 @@ private boolean contains(int streamId) { return receivers.containsKey(streamId); } - private void terminate() { - setTerminationError(new ClosedChannelException()); - leaseHandler.dispose(); - try { - receivers.values().forEach(this::cleanUpSubscriber); - senders.values().forEach(this::cleanUpLimitableRequestPublisher); - } finally { - senders.clear(); - receivers.clear(); - sendProcessor.dispose(); - } - } - - private void setTerminationError(Throwable error) { - TERMINATION_ERROR.compareAndSet(this, null, error); - } - - private synchronized void cleanUpLimitableRequestPublisher( - RateLimitableRequestPublisher limitableRequestPublisher) { - try { - limitableRequestPublisher.cancel(); - } catch (Throwable t) { - errorConsumer.accept(t); - } - } - - private synchronized void cleanUpSubscriber(Processor subscriber) { - try { - subscriber.onError(terminationError); - } catch (Throwable t) { - errorConsumer.accept(t); - } - } - private void handleIncomingFrames(ByteBuf frame) { try { int streamId = FrameHeaderFlyweight.streamId(frame); @@ -525,10 +450,7 @@ private void handleIncomingFrames(ByteBuf frame) { private void handleStreamZero(FrameType type, ByteBuf frame) { switch (type) { case ERROR: - RuntimeException error = Exceptions.from(frame); - setTerminationError(error); - errorConsumer.accept(error); - connection.dispose(); + tryTerminateOnZeroError(frame); break; case LEASE: leaseHandler.receive(frame); @@ -553,7 +475,7 @@ private void handleFrame(int streamId, FrameType type, ByteBuf frame) { } else { switch (type) { case ERROR: - receiver.onError(Exceptions.from(frame)); + receiver.onError(Exceptions.from(streamId, frame)); receivers.remove(streamId); break; case NEXT_COMPLETE: @@ -614,4 +536,86 @@ private void handleMissingResponseProcessor(int streamId, FrameType type, ByteBu // receiving a frame after a given stream has been cancelled/completed, // so ignore (cancellation is async so there is a race condition) } + + private void tryTerminateOnKeepAlive(KeepAlive keepAlive) { + tryTerminate( + () -> + new ConnectionErrorException( + String.format("No keep-alive acks for %d ms", keepAlive.getTimeout().toMillis()))); + } + + private void tryTerminateOnConnectionClose() { + tryTerminate(() -> CLOSED_CHANNEL_EXCEPTION); + } + + private void tryTerminateOnZeroError(ByteBuf errorFrame) { + tryTerminate(() -> Exceptions.from(0, errorFrame)); + } + + private void tryTerminate(Supplier errorSupplier) { + if (terminationError == null) { + Exception e = errorSupplier.get(); + if (TERMINATION_ERROR.compareAndSet(this, null, e)) { + terminate(e); + } + } + } + + private void terminate(Exception e) { + connection.dispose(); + leaseHandler.dispose(); + + synchronized (receivers) { + receivers + .values() + .forEach( + receiver -> { + try { + receiver.onError(e); + } catch (Throwable t) { + errorConsumer.accept(t); + } + }); + } + synchronized (senders) { + senders + .values() + .forEach( + sender -> { + try { + sender.cancel(); + } catch (Throwable t) { + errorConsumer.accept(t); + } + }); + } + senders.clear(); + receivers.clear(); + sendProcessor.dispose(); + errorConsumer.accept(e); + } + + private void removeStreamReceiver(int streamId) { + /*on termination receivers are explicitly cleared to avoid removing from map while iterating over one + of its views*/ + if (terminationError == null) { + receivers.remove(streamId); + } + } + + private void removeStreamReceiverAndSender(int streamId) { + /*on termination senders & receivers are explicitly cleared to avoid removing from map while iterating over one + of its views*/ + if (terminationError == null) { + receivers.remove(streamId); + RateLimitableRequestPublisher sender = senders.remove(streamId); + if (sender != null) { + sender.cancel(); + } + } + } + + private void handleSendProcessorError(Throwable t) { + connection.dispose(); + } } diff --git a/rsocket-core/src/main/java/io/rsocket/buffer/AbstractTupleByteBuf.java b/rsocket-core/src/main/java/io/rsocket/buffer/AbstractTupleByteBuf.java index eaff4caa0..a80605877 100644 --- a/rsocket-core/src/main/java/io/rsocket/buffer/AbstractTupleByteBuf.java +++ b/rsocket-core/src/main/java/io/rsocket/buffer/AbstractTupleByteBuf.java @@ -16,6 +16,7 @@ abstract class AbstractTupleByteBuf extends AbstractReferenceCountedByteBuf { static final int DEFAULT_DIRECT_MEMORY_CACHE_ALIGNMENT = SystemPropertyUtil.getInt("io.netty.allocator.directMemoryCacheAlignment", 0); static final ByteBuffer EMPTY_NIO_BUFFER = Unpooled.EMPTY_BUFFER.nioBuffer(); + static final int NOT_ENOUGH_BYTES_AT_MAX_CAPACITY_CODE = 3; final ByteBufAllocator allocator; final int capacity; @@ -294,7 +295,7 @@ public ByteBuf ensureWritable(int minWritableBytes) { @Override public int ensureWritable(int minWritableBytes, boolean force) { - return 0; + return NOT_ENOUGH_BYTES_AT_MAX_CAPACITY_CODE; } @Override @@ -549,7 +550,7 @@ public int writeCharSequence(CharSequence sequence, Charset charset) { @Override public ByteBuffer internalNioBuffer(int index, int length) { - throw new UnsupportedOperationException(); + return nioBuffer(index, length); } @Override diff --git a/rsocket-core/src/main/java/io/rsocket/buffer/Tuple2ByteBuf.java b/rsocket-core/src/main/java/io/rsocket/buffer/Tuple2ByteBuf.java index 66c68009a..ba6620cb0 100644 --- a/rsocket-core/src/main/java/io/rsocket/buffer/Tuple2ByteBuf.java +++ b/rsocket-core/src/main/java/io/rsocket/buffer/Tuple2ByteBuf.java @@ -134,6 +134,12 @@ public ByteBuffer[] _nioBuffers(int index, int length) { @Override public ByteBuf getBytes(int index, ByteBuf dst, int dstIndex, int length) { + checkDstIndex(index, length, dstIndex, dst.capacity()); + if (length == 0) { + return this; + } + + // FIXME: check twice here long ri = calculateRelativeIndex(index); index = (int) (ri & Integer.MAX_VALUE); switch ((int) ((ri & MASK) >>> 32L)) { @@ -165,20 +171,22 @@ public ByteBuf getBytes(int index, ByteBuf dst, int dstIndex, int length) { @Override public ByteBuf getBytes(int index, byte[] dst, int dstIndex, int length) { ByteBuf dstBuf = Unpooled.wrappedBuffer(dst); - int min = Math.min(dst.length, capacity); - return getBytes(0, dstBuf, index, min); + return getBytes(index, dstBuf, dstIndex, length); } @Override public ByteBuf getBytes(int index, ByteBuffer dst) { ByteBuf dstBuf = Unpooled.wrappedBuffer(dst); - int min = Math.min(dst.limit(), capacity); - return getBytes(0, dstBuf, index, min); + return getBytes(index, dstBuf); } @Override public ByteBuf getBytes(int index, final OutputStream out, int length) throws IOException { checkIndex(index, length); + if (length == 0) { + return this; + } + long ri = calculateRelativeIndex(index); index = (int) (ri & Integer.MAX_VALUE); switch ((int) ((ri & MASK) >>> 32L)) { @@ -354,18 +362,12 @@ protected void deallocate() { @Override public String toString(Charset charset) { - StringBuilder builder = new StringBuilder(3); + StringBuilder builder = new StringBuilder(capacity); builder.append(one.toString(charset)); builder.append(two.toString(charset)); return builder.toString(); } - @Override - public String toString(int index, int length, Charset charset) { - // TODO - make this smarter - return toString(charset).substring(index, length); - } - @Override public String toString() { return "Tuple2ByteBuf{" diff --git a/rsocket-core/src/main/java/io/rsocket/buffer/Tuple3ByteBuf.java b/rsocket-core/src/main/java/io/rsocket/buffer/Tuple3ByteBuf.java index d02b22586..be593019f 100644 --- a/rsocket-core/src/main/java/io/rsocket/buffer/Tuple3ByteBuf.java +++ b/rsocket-core/src/main/java/io/rsocket/buffer/Tuple3ByteBuf.java @@ -164,7 +164,7 @@ public ByteBuffer[] _nioBuffers(int index, int length) { ByteBuffer[] twoBuffer; ByteBuffer[] threeBuffer; int l = Math.min(twoReadableBytes - index, length); - twoBuffer = two.nioBuffers(index, length); + twoBuffer = two.nioBuffers(index, l); length -= l; if (length != 0) { threeBuffer = three.nioBuffers(threeReadIndex, length); @@ -235,15 +235,13 @@ public ByteBuf getBytes(int index, ByteBuf dst, int dstIndex, int length) { @Override public ByteBuf getBytes(int index, byte[] dst, int dstIndex, int length) { ByteBuf dstBuf = Unpooled.wrappedBuffer(dst); - int min = Math.min(dst.length, capacity); - return getBytes(0, dstBuf, index, min); + return getBytes(index, dstBuf, dstIndex, length); } @Override public ByteBuf getBytes(int index, ByteBuffer dst) { ByteBuf dstBuf = Unpooled.wrappedBuffer(dst); - int min = Math.min(dst.limit(), capacity); - return getBytes(0, dstBuf, index, min); + return getBytes(index, dstBuf); } @Override @@ -539,12 +537,6 @@ public String toString(Charset charset) { return builder.toString(); } - @Override - public String toString(int index, int length, Charset charset) { - // TODO - make this smarter - return toString(charset).substring(index, length); - } - @Override public String toString() { return "Tuple3ByteBuf{" diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/CustomRSocketException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/CustomRSocketException.java new file mode 100644 index 000000000..6315206b5 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/CustomRSocketException.java @@ -0,0 +1,51 @@ +package io.rsocket.exceptions; + +import io.rsocket.frame.ErrorType; + +public class CustomRSocketException extends RSocketException { + private static final long serialVersionUID = 7873267740343446585L; + + private final int errorCode; + + /** + * Constructs a new exception with the specified message. + * + * @param errorCode customizable error code. Should be in range [0x00000301-0xFFFFFFFE] + * @param message the message + * @throws NullPointerException if {@code message} is {@code null} + * @throws IllegalArgumentException if {@code errorCode} is out of allowed range + */ + public CustomRSocketException(int errorCode, String message) { + super(message); + if (errorCode > ErrorType.MAX_USER_ALLOWED_ERROR_CODE + && errorCode < ErrorType.MIN_USER_ALLOWED_ERROR_CODE) { + throw new IllegalArgumentException( + "Allowed errorCode value should be in range [0x00000301-0xFFFFFFFE]"); + } + this.errorCode = errorCode; + } + + /** + * Constructs a new exception with the specified message and cause. + * + * @param errorCode customizable error code. Should be in range [0x00000301-0xFFFFFFFE] + * @param message the message + * @param cause the cause of this exception + * @throws NullPointerException if {@code message} or {@code cause} is {@code null} + * @throws IllegalArgumentException if {@code errorCode} is out of allowed range + */ + public CustomRSocketException(int errorCode, String message, Throwable cause) { + super(message, cause); + if (errorCode > ErrorType.MAX_USER_ALLOWED_ERROR_CODE + && errorCode < ErrorType.MIN_USER_ALLOWED_ERROR_CODE) { + throw new IllegalArgumentException( + "Allowed errorCode value should be in range [0x00000301-0xFFFFFFFE]"); + } + this.errorCode = errorCode; + } + + @Override + public int errorCode() { + return errorCode; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/Exceptions.java b/rsocket-core/src/main/java/io/rsocket/exceptions/Exceptions.java index 97de65a96..3a10410f0 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/Exceptions.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/Exceptions.java @@ -34,36 +34,50 @@ private Exceptions() {} * @return a {@link RSocketException} that matches the error code in the Frame * @throws NullPointerException if {@code frame} is {@code null} */ - public static RuntimeException from(ByteBuf frame) { + public static RuntimeException from(int streamId, ByteBuf frame) { Objects.requireNonNull(frame, "frame must not be null"); int errorCode = ErrorFrameFlyweight.errorCode(frame); String message = ErrorFrameFlyweight.dataUtf8(frame); - switch (errorCode) { - case APPLICATION_ERROR: - return new ApplicationErrorException(message); - case CANCELED: - return new CanceledException(message); - case CONNECTION_CLOSE: - return new ConnectionCloseException(message); - case CONNECTION_ERROR: - return new ConnectionErrorException(message); - case INVALID: - return new InvalidException(message); - case INVALID_SETUP: - return new InvalidSetupException(message); - case REJECTED: - return new RejectedException(message); - case REJECTED_RESUME: - return new RejectedResumeException(message); - case REJECTED_SETUP: - return new RejectedSetupException(message); - case UNSUPPORTED_SETUP: - return new UnsupportedSetupException(message); - default: - return new IllegalArgumentException( - String.format("Invalid Error frame: %d '%s'", errorCode, message)); + if (streamId == 0) { + switch (errorCode) { + case INVALID_SETUP: + return new InvalidSetupException(message); + case UNSUPPORTED_SETUP: + return new UnsupportedSetupException(message); + case REJECTED_SETUP: + return new RejectedSetupException(message); + case REJECTED_RESUME: + return new RejectedResumeException(message); + case CONNECTION_ERROR: + return new ConnectionErrorException(message); + case CONNECTION_CLOSE: + return new ConnectionCloseException(message); + default: + return new IllegalArgumentException( + String.format("Invalid Error frame in Stream ID 0: 0x%08X '%s'", errorCode, message)); + } + } else { + switch (errorCode) { + case APPLICATION_ERROR: + return new ApplicationErrorException(message); + case REJECTED: + return new RejectedException(message); + case CANCELED: + return new CanceledException(message); + case INVALID: + return new InvalidException(message); + default: + if (errorCode >= MIN_USER_ALLOWED_ERROR_CODE + || errorCode <= MAX_USER_ALLOWED_ERROR_CODE) { + return new CustomRSocketException(errorCode, message); + } + return new IllegalArgumentException( + String.format( + "Invalid Error frame in Stream ID %d: 0x%08X '%s'", + streamId, errorCode, message)); + } } } } diff --git a/rsocket-core/src/main/java/io/rsocket/frame/ErrorFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/ErrorFrameFlyweight.java index 55e23541e..df9d39ba8 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/ErrorFrameFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/ErrorFrameFlyweight.java @@ -8,17 +8,21 @@ public class ErrorFrameFlyweight { - // defined error codes + // defined zero stream id error codes public static final int INVALID_SETUP = 0x00000001; public static final int UNSUPPORTED_SETUP = 0x00000002; public static final int REJECTED_SETUP = 0x00000003; public static final int REJECTED_RESUME = 0x00000004; public static final int CONNECTION_ERROR = 0x00000101; public static final int CONNECTION_CLOSE = 0x00000102; + // defined non-zero stream id error codes public static final int APPLICATION_ERROR = 0x00000201; public static final int REJECTED = 0x00000202; public static final int CANCELED = 0x00000203; public static final int INVALID = 0x00000204; + // defined user-allowed error codes range + public static final int MIN_USER_ALLOWED_ERROR_CODE = 0x00000301; + public static final int MAX_USER_ALLOWED_ERROR_CODE = 0xFFFFFFFE; public static ByteBuf encode( ByteBufAllocator allocator, int streamId, Throwable t, ByteBuf data) { diff --git a/rsocket-core/src/main/java/io/rsocket/frame/ErrorType.java b/rsocket-core/src/main/java/io/rsocket/frame/ErrorType.java index ccbff374e..ffd99930d 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/ErrorType.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/ErrorType.java @@ -11,7 +11,7 @@ public final class ErrorType { /** * Application layer logic generating a Reactive Streams onError event. Stream ID MUST be > 0. */ - public static final int APPLICATION_ERROR = 0x00000201;; + public static final int APPLICATION_ERROR = 0x00000201; /** * The Responder canceled the request but may have started processing it (similar to REJECTED but @@ -70,5 +70,14 @@ public final class ErrorType { */ public static final int UNSUPPORTED_SETUP = 0x00000002; + /** Minimum allowed user defined error code value */ + public static final int MIN_USER_ALLOWED_ERROR_CODE = 0x00000301; + + /** + * Maximum allowed user defined error code value. Note, the value is above signed integer maximum, + * so it will be negative after overflow. + */ + public static final int MAX_USER_ALLOWED_ERROR_CODE = 0xFFFFFFFE; + private ErrorType() {} } diff --git a/rsocket-core/src/main/java/io/rsocket/frame/decoder/DefaultPayloadDecoder.java b/rsocket-core/src/main/java/io/rsocket/frame/decoder/DefaultPayloadDecoder.java index 74186f1d1..692dcb363 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/decoder/DefaultPayloadDecoder.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/decoder/DefaultPayloadDecoder.java @@ -11,7 +11,7 @@ class DefaultPayloadDecoder implements PayloadDecoder { @Override - public synchronized Payload apply(ByteBuf byteBuf) { + public Payload apply(ByteBuf byteBuf) { ByteBuf m; ByteBuf d; FrameType type = FrameHeaderFlyweight.frameType(byteBuf); diff --git a/rsocket-core/src/main/java/io/rsocket/internal/UnicastMonoEmpty.java b/rsocket-core/src/main/java/io/rsocket/internal/UnicastMonoEmpty.java new file mode 100644 index 000000000..eb8a1aa11 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/UnicastMonoEmpty.java @@ -0,0 +1,86 @@ +package io.rsocket.internal; + +import java.time.Duration; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import reactor.core.CoreSubscriber; +import reactor.core.Scannable; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.util.annotation.Nullable; + +/** + * Represents an empty publisher which only calls onSubscribe and onComplete. + * + *

This Publisher is effectively stateless and only a single instance exists. Use the {@link + * #instance()} method to obtain a properly type-parametrized view of it. + * + * @see Reactive-Streams-Commons + */ +public final class UnicastMonoEmpty extends Mono implements Scannable { + + final Runnable onSubscribe; + + volatile int once; + + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater ONCE = + AtomicIntegerFieldUpdater.newUpdater(UnicastMonoEmpty.class, "once"); + + UnicastMonoEmpty(Runnable onSubscribe) { + this.onSubscribe = onSubscribe; + } + + @Override + public void subscribe(CoreSubscriber actual) { + if (once == 0 && ONCE.compareAndSet(this, 0, 1)) { + onSubscribe.run(); + Operators.complete(actual); + } else { + Operators.error( + actual, new IllegalStateException("UnicastMonoEmpty allows only a single Subscriber")); + } + } + + /** + * Returns a properly parametrized instance of this empty Publisher. + * + * @param the output type + * @return a properly parametrized instance of this empty Publisher + */ + @SuppressWarnings("unchecked") + public static Mono newInstance(Runnable onSubscribe) { + return (Mono) new UnicastMonoEmpty(onSubscribe); + } + + @Override + @Nullable + public Object block(Duration m) { + if (once == 0 && ONCE.compareAndSet(this, 0, 1)) { + onSubscribe.run(); + return null; + } else { + throw new IllegalStateException("UnicastMonoEmpty allows only a single Subscriber"); + } + } + + @Override + @Nullable + public Object block() { + if (once == 0 && ONCE.compareAndSet(this, 0, 1)) { + onSubscribe.run(); + return null; + } else { + throw new IllegalStateException("UnicastMonoEmpty allows only a single Subscriber"); + } + } + + @Override + public Object scanUnsafe(Attr key) { + return null; // no particular key to be represented, still useful in hooks + } + + @Override + public String stepName() { + return "source(UnicastMonoEmpty)"; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/UnicastMonoProcessor.java b/rsocket-core/src/main/java/io/rsocket/internal/UnicastMonoProcessor.java index d5958028a..af4c8768b 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/UnicastMonoProcessor.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/UnicastMonoProcessor.java @@ -1,5 +1,22 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package io.rsocket.internal; +import io.rsocket.util.MonoLifecycleHandler; import java.util.Objects; import java.util.concurrent.CancellationException; import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; @@ -13,29 +30,60 @@ import reactor.core.Scannable; import reactor.core.publisher.Mono; import reactor.core.publisher.Operators; +import reactor.core.publisher.SignalType; +import reactor.util.annotation.NonNull; import reactor.util.annotation.Nullable; import reactor.util.context.Context; public class UnicastMonoProcessor extends Mono implements Processor, CoreSubscriber, Disposable, Subscription, Scannable { + static final MonoLifecycleHandler DEFAULT_LIFECYCLE = new MonoLifecycleHandler() {}; + /** * Create a {@link UnicastMonoProcessor} that will eagerly request 1 on {@link - * #onSubscribe(Subscription)}, cache and emit the eventual result for 1 or N subscribers. + * #onSubscribe(Subscription)}, cache and emit the eventual result for a single subscriber. * * @param type of the expected value * @return A {@link UnicastMonoProcessor}. */ + @SuppressWarnings("unchecked") public static UnicastMonoProcessor create() { - return new UnicastMonoProcessor<>(); + return new UnicastMonoProcessor(DEFAULT_LIFECYCLE); + } + + /** + * Create a {@link UnicastMonoProcessor} that will eagerly request 1 on {@link + * #onSubscribe(Subscription)}, cache and emit the eventual result for a single subscriber. + * + * @param lifecycleHandler lifecycle handler + * @param type of the expected value + * @return A {@link UnicastMonoProcessor}. + */ + public static UnicastMonoProcessor create(MonoLifecycleHandler lifecycleHandler) { + return new UnicastMonoProcessor<>(lifecycleHandler); } - volatile CoreSubscriber actual; + /** Indicates this Subscription has no value and not requested yet. */ + static final int NO_SUBSCRIBER_NO_RESULT = 0; + /** Indicates this Subscription has no value and not requested yet. */ + static final int NO_SUBSCRIBER_HAS_RESULT = 1; + /** Indicates this Subscription has no value and not requested yet. */ + static final int NO_REQUEST_NO_RESULT = 4; + /** Indicates this Subscription has a value but not requested yet. */ + static final int NO_REQUEST_HAS_RESULT = 5; + /** Indicates this Subscription has been requested but there is no value yet. */ + static final int HAS_REQUEST_NO_RESULT = 6; + /** Indicates this Subscription has both request and value. */ + static final int HAS_REQUEST_HAS_RESULT = 7; + /** Indicates the Subscription has been cancelled. */ + static final int CANCELLED = 8; + + volatile int state; @SuppressWarnings("rawtypes") - static final AtomicReferenceFieldUpdater ACTUAL = - AtomicReferenceFieldUpdater.newUpdater( - UnicastMonoProcessor.class, CoreSubscriber.class, "actual"); + static final AtomicIntegerFieldUpdater STATE = + AtomicIntegerFieldUpdater.newUpdater(UnicastMonoProcessor.class, "state"); volatile int once; @@ -43,165 +91,296 @@ public static UnicastMonoProcessor create() { static final AtomicIntegerFieldUpdater ONCE = AtomicIntegerFieldUpdater.newUpdater(UnicastMonoProcessor.class, "once"); - Throwable error; - volatile boolean terminated; - O value; - volatile Subscription subscription; + + @SuppressWarnings("rawtypes") static final AtomicReferenceFieldUpdater UPSTREAM = AtomicReferenceFieldUpdater.newUpdater( UnicastMonoProcessor.class, Subscription.class, "subscription"); + CoreSubscriber actual; + + Throwable error; + O value; + + final MonoLifecycleHandler lifecycleHandler; + + UnicastMonoProcessor(MonoLifecycleHandler lifecycleHandler) { + this.lifecycleHandler = lifecycleHandler; + } + @Override - public final void cancel() { - if (isTerminated()) { - return; + @NonNull + public Context currentContext() { + final CoreSubscriber a = this.actual; + return a != null ? a.currentContext() : Context.empty(); + } + + @Override + public final void onSubscribe(Subscription subscription) { + if (Operators.setOnce(UPSTREAM, this, subscription)) { + subscription.request(Long.MAX_VALUE); } + } - final Subscription s = UPSTREAM.getAndSet(this, Operators.cancelledSubscription()); - if (s == Operators.cancelledSubscription()) { + @Override + public final void onComplete() { + onNext(null); + } + + @Override + public final void onError(Throwable cause) { + Objects.requireNonNull(cause, "onError cannot be null"); + + if (UPSTREAM.getAndSet(this, Operators.cancelledSubscription()) + == Operators.cancelledSubscription()) { + Operators.onErrorDropped(cause, currentContext()); return; } - if (s != null) { - s.cancel(); - } + complete(cause); } @Override - @SuppressWarnings("unchecked") - public void dispose() { - final Subscription s = UPSTREAM.getAndSet(this, Operators.cancelledSubscription()); - if (s == Operators.cancelledSubscription()) { + public final void onNext(@Nullable O value) { + final Subscription s; + if ((s = UPSTREAM.getAndSet(this, Operators.cancelledSubscription())) + == Operators.cancelledSubscription()) { + if (value != null) { + Operators.onNextDropped(value, currentContext()); + } return; } - final CancellationException e = new CancellationException("Disposed"); - error = e; - value = null; - terminated = true; - if (s != null) { - s.cancel(); - } + if (value == null) { + complete(); + } else { + if (s != null) { + s.cancel(); + } - final CoreSubscriber a = this.actual; - ACTUAL.lazySet(this, null); - if (a != null) { - a.onError(e); + complete(value); } } /** - * Return the produced {@link Throwable} error if any or null + * Tries to emit the value and complete the underlying subscriber or stores the value away until + * there is a request for it. * - * @return the produced {@link Throwable} error if any or null - */ - @Nullable - public final Throwable getError() { - return isTerminated() ? error : null; - } - - /** - * Indicates whether this {@code UnicastMonoProcessor} has been interrupted via cancellation. + *

Make sure this method is called at most once * - * @return {@code true} if this {@code UnicastMonoProcessor} is cancelled, {@code false} - * otherwise. + * @param v the value to emit */ - public boolean isCancelled() { - return isDisposed() && !isTerminated(); + private void complete(O v) { + for (; ; ) { + int state = this.state; + + // if state is >= HAS_CANCELLED or bit zero is set (*_HAS_VALUE) case, return + if ((state & ~HAS_REQUEST_NO_RESULT) != 0) { + this.value = null; + Operators.onDiscard(v, currentContext()); + return; + } + + if (state == HAS_REQUEST_NO_RESULT) { + if (STATE.compareAndSet(this, HAS_REQUEST_NO_RESULT, HAS_REQUEST_HAS_RESULT)) { + final Subscriber a = actual; + actual = null; + value = null; + lifecycleHandler.doOnTerminal(SignalType.ON_COMPLETE, v, null); + a.onNext(v); + a.onComplete(); + return; + } + } + setValue(v); + if (state == NO_REQUEST_NO_RESULT + && STATE.compareAndSet(this, NO_REQUEST_NO_RESULT, NO_REQUEST_HAS_RESULT)) { + return; + } + if (state == NO_SUBSCRIBER_NO_RESULT + && STATE.compareAndSet(this, NO_SUBSCRIBER_NO_RESULT, NO_SUBSCRIBER_HAS_RESULT)) { + return; + } + } } /** - * Indicates whether this {@code UnicastMonoProcessor} has been completed with an error. + * Tries to emit completion the underlying subscriber * - * @return {@code true} if this {@code UnicastMonoProcessor} was completed with an error, {@code - * false} otherwise. + *

Make sure this method is called at most once */ - public final boolean isError() { - return getError() != null; + private void complete() { + for (; ; ) { + int state = this.state; + + // if state is >= HAS_CANCELLED or bit zero is set (*_HAS_VALUE) case, return + if ((state & ~HAS_REQUEST_NO_RESULT) != 0) { + return; + } + + if (state == HAS_REQUEST_NO_RESULT || state == NO_REQUEST_NO_RESULT) { + if (STATE.compareAndSet(this, state, HAS_REQUEST_HAS_RESULT)) { + final Subscriber a = actual; + actual = null; + lifecycleHandler.doOnTerminal(SignalType.ON_COMPLETE, null, null); + a.onComplete(); + return; + } + } + if (state == NO_SUBSCRIBER_NO_RESULT + && STATE.compareAndSet(this, NO_SUBSCRIBER_NO_RESULT, NO_SUBSCRIBER_HAS_RESULT)) { + return; + } + } } /** - * Indicates whether this {@code UnicastMonoProcessor} has been terminated by the source producer - * with a success or an error. + * Tries to emit error the underlying subscriber or stores the value away until there is a request + * for it. * - * @return {@code true} if this {@code UnicastMonoProcessor} is successful, {@code false} - * otherwise. + *

Make sure this method is called at most once + * + * @param e the error to emit */ - public final boolean isTerminated() { - return terminated; - } + private void complete(Throwable e) { + for (; ; ) { + int state = this.state; - @Override - public boolean isDisposed() { - return subscription == Operators.cancelledSubscription(); - } + // if state is >= HAS_CANCELLED or bit zero is set (*_HAS_VALUE) case, return + if ((state & ~HAS_REQUEST_NO_RESULT) != 0) { + return; + } - @Override - public final void onComplete() { - onNext(null); + setError(e); + if (state == HAS_REQUEST_NO_RESULT || state == NO_REQUEST_NO_RESULT) { + if (STATE.compareAndSet(this, state, HAS_REQUEST_HAS_RESULT)) { + final Subscriber a = actual; + actual = null; + lifecycleHandler.doOnTerminal(SignalType.ON_ERROR, null, e); + a.onError(e); + return; + } + } + if (state == NO_SUBSCRIBER_NO_RESULT + && STATE.compareAndSet(this, NO_SUBSCRIBER_NO_RESULT, NO_SUBSCRIBER_HAS_RESULT)) { + return; + } + } } @Override - @SuppressWarnings("unchecked") - public final void onError(Throwable cause) { - Objects.requireNonNull(cause, "onError cannot be null"); + public void subscribe(CoreSubscriber actual) { + Objects.requireNonNull(actual, "subscribe"); - if (UPSTREAM.getAndSet(this, Operators.cancelledSubscription()) - == Operators.cancelledSubscription()) { - Operators.onErrorDropped(cause, currentContext()); - return; - } + if (once == 0 && ONCE.compareAndSet(this, 0, 1)) { + final MonoLifecycleHandler lh = this.lifecycleHandler; - error = cause; - value = null; - terminated = true; + lh.doOnSubscribe(); - final CoreSubscriber a = actual; - ACTUAL.lazySet(this, null); - if (a != null) { - a.onError(cause); + this.actual = actual; + + int state = this.state; + + // possible states within the racing between [onNext / onComplete / onError / dispose] and + // setting subscriber + // are NO_SUBSCRIBER_[NO_RESULT or HAS_RESULT] + if (state == NO_SUBSCRIBER_NO_RESULT) { + if (STATE.compareAndSet(this, NO_SUBSCRIBER_NO_RESULT, NO_REQUEST_NO_RESULT)) { + state = NO_REQUEST_NO_RESULT; + } else { + // the possible false position is racing with [onNext / onError / onComplete / dispose] + // which are going to put the state in the NO_REQUEST_HAS_RESULT + STATE.set(this, NO_REQUEST_HAS_RESULT); + state = NO_REQUEST_HAS_RESULT; + } + } else { + STATE.set(this, NO_REQUEST_HAS_RESULT); + state = NO_REQUEST_HAS_RESULT; + } + + // check if state is with a result then there is a chance of immediate termination if there is + // no value + // e.g. [onError / onComplete / dispose] only + if (state == NO_REQUEST_HAS_RESULT && this.value == null) { + this.actual = null; + Throwable e = this.error; + // barrier to flush changes + STATE.set(this, HAS_REQUEST_HAS_RESULT); + if (e == null) { + lh.doOnTerminal(SignalType.ON_COMPLETE, null, null); + Operators.complete(actual); + } else { + lh.doOnTerminal(SignalType.ON_ERROR, null, e); + Operators.error(actual, e); + } + return; + } + + // call onSubscribe if has value in the result or no result delivered so far + actual.onSubscribe(this); + } else { + Operators.error( + actual, + new IllegalStateException("UnicastMonoProcessor allows only a single Subscriber")); } } @Override - @SuppressWarnings("unchecked") - public final void onNext(@Nullable O value) { - final Subscription s; - if ((s = UPSTREAM.getAndSet(this, Operators.cancelledSubscription())) - == Operators.cancelledSubscription()) { - if (value != null) { - Operators.onNextDropped(value, currentContext()); + public final void request(long n) { + if (Operators.validate(n)) { + for (; ; ) { + int s = state; + // if the any bits 1-31 are set, we are either in fusion mode (FUSED_*) + // or request has been called (HAS_REQUEST_*) + if ((s & ~NO_REQUEST_HAS_RESULT) != 0) { + return; + } + if (s == NO_REQUEST_HAS_RESULT) { + if (STATE.compareAndSet(this, NO_REQUEST_HAS_RESULT, HAS_REQUEST_HAS_RESULT)) { + final Subscriber a = actual; + final O v = value; + actual = null; + value = null; + lifecycleHandler.doOnTerminal(SignalType.ON_COMPLETE, v, null); + a.onNext(v); + a.onComplete(); + return; + } + } + if (STATE.compareAndSet(this, NO_REQUEST_NO_RESULT, HAS_REQUEST_NO_RESULT)) { + return; + } } - return; } + } - this.value = value; - terminated = true; - - final CoreSubscriber a = actual; - ACTUAL.lazySet(this, null); - if (value == null) { - if (a != null) { - a.onComplete(); - } - } else { - if (s != null) { + @Override + public final void cancel() { + if (STATE.getAndSet(this, CANCELLED) <= HAS_REQUEST_NO_RESULT) { + Operators.onDiscard(value, currentContext()); + value = null; + actual = null; + lifecycleHandler.doOnTerminal(SignalType.CANCEL, null, null); + final Subscription s = UPSTREAM.getAndSet(this, Operators.cancelledSubscription()); + if (s != null && s != Operators.cancelledSubscription()) { s.cancel(); } - - if (a != null) { - a.onNext(value); - a.onComplete(); - } } } @Override - public final void onSubscribe(Subscription subscription) { - if (Operators.setOnce(UPSTREAM, this, subscription)) { - subscription.request(Long.MAX_VALUE); + public void dispose() { + final Subscription s = UPSTREAM.getAndSet(this, Operators.cancelledSubscription()); + if (s == Operators.cancelledSubscription()) { + return; } + + if (s != null) { + s.cancel(); + } + + complete(new CancellationException("Disposed")); } /** @@ -215,7 +394,7 @@ public final void onSubscribe(Subscription subscription) { */ @Nullable public O peek() { - if (!isTerminated()) { + if (isCancelled()) { return null; } @@ -232,25 +411,75 @@ public O peek() { return null; } - @Override - public final void request(long n) { - Operators.validate(n); + /** + * Set the value internally, without impacting request tracking state. + * + * @param value the new value. + * @see #complete(Object) + */ + private void setValue(O value) { + this.value = value; + } + + /** + * Set the error internally, without impacting request tracking state. + * + * @param throwable the error. + * @see #complete(Object) + */ + private void setError(Throwable throwable) { + this.error = throwable; + } + + /** + * Return the produced {@link Throwable} error if any or null + * + * @return the produced {@link Throwable} error if any or null + */ + @Nullable + public final Throwable getError() { + return isDisposed() ? error : null; + } + + /** + * Indicates whether this {@code UnicastMonoProcessor} has been completed with an error. + * + * @return {@code true} if this {@code UnicastMonoProcessor} was completed with an error, {@code + * false} otherwise. + */ + public final boolean isError() { + return getError() != null; + } + + /** + * Indicates whether this {@code UnicastMonoProcessor} has been interrupted via cancellation. + * + * @return {@code true} if this {@code UnicastMonoProcessor} is cancelled, {@code false} + * otherwise. + */ + public boolean isCancelled() { + return state == CANCELLED; + } + + public final boolean isTerminated() { + int state = this.state; + return (state < CANCELLED && state % 2 == 1); } @Override - public Context currentContext() { - final CoreSubscriber a = this.actual; - return a != null ? a.currentContext() : Context.empty(); + public boolean isDisposed() { + int state = this.state; + return state == CANCELLED || (state < CANCELLED && state % 2 == 1); } @Override @Nullable public Object scanUnsafe(Attr key) { // touch guard - boolean c = isCancelled(); + int state = this.state; if (key == Attr.TERMINATED) { - return isTerminated(); + return (state < CANCELLED && state % 2 == 1); } if (key == Attr.PARENT) { return subscription; @@ -262,7 +491,7 @@ public Object scanUnsafe(Attr key) { return Integer.MAX_VALUE; } if (key == Attr.CANCELLED) { - return c; + return state == CANCELLED; } return null; } @@ -273,32 +502,6 @@ public Object scanUnsafe(Attr key) { * @return true if any {@link Subscriber} is actively subscribed */ public final boolean hasDownstream() { - return actual != null; - } - - @Override - public void subscribe(CoreSubscriber actual) { - Objects.requireNonNull(actual, "subscribe"); - if (once == 0 && ONCE.compareAndSet(this, 0, 1)) { - actual.onSubscribe(this); - ACTUAL.lazySet(this, actual); - if (isTerminated()) { - Throwable ex = error; - if (ex != null) { - actual.onError(ex); - } else { - O v = value; - if (v != null) { - actual.onNext(v); - } - actual.onComplete(); - } - ACTUAL.lazySet(this, null); - } - } else { - Operators.error( - actual, - new IllegalStateException("UnicastMonoProcessor allows only a single Subscriber")); - } + return state > NO_SUBSCRIBER_HAS_RESULT && actual != null; } } diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/WellKnownMimeType.java b/rsocket-core/src/main/java/io/rsocket/metadata/WellKnownMimeType.java index 21c2daf93..2743f604d 100644 --- a/rsocket-core/src/main/java/io/rsocket/metadata/WellKnownMimeType.java +++ b/rsocket-core/src/main/java/io/rsocket/metadata/WellKnownMimeType.java @@ -50,7 +50,7 @@ public enum WellKnownMimeType { AUDIO_OPUS("audio/opus", (byte) 0x12), AUDIO_VORBIS("audio/vorbis", (byte) 0x13), IMAGE_BMP("image/bmp", (byte) 0x14), - IMAGE_GIG("image/gif", (byte) 0x15), + IMAGE_GIF("image/gif", (byte) 0x15), IMAGE_HEIC_SEQUENCE("image/heic-sequence", (byte) 0x16), IMAGE_HEIC("image/heic", (byte) 0x17), IMAGE_HEIF_SEQUENCE("image/heif-sequence", (byte) 0x18), @@ -73,6 +73,7 @@ public enum WellKnownMimeType { // ... reserved for future use ... + MESSAGE_RSOCKET_AUTHENTICATION("message/x.rsocket.authentication.v0", (byte) 0x7C), MESSAGE_RSOCKET_TRACING_ZIPKIN("message/x.rsocket.tracing-zipkin.v0", (byte) 0x7D), MESSAGE_RSOCKET_ROUTING("message/x.rsocket.routing.v0", (byte) 0x7E), MESSAGE_RSOCKET_COMPOSITE_METADATA("message/x.rsocket.composite-metadata.v0", (byte) 0x7F); diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/security/AuthMetadataFlyweight.java b/rsocket-core/src/main/java/io/rsocket/metadata/security/AuthMetadataFlyweight.java new file mode 100644 index 000000000..f0f5cf54e --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/security/AuthMetadataFlyweight.java @@ -0,0 +1,336 @@ +package io.rsocket.metadata.security; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.rsocket.buffer.TupleByteBuf; +import io.rsocket.util.CharByteBufUtil; + +public class AuthMetadataFlyweight { + + static final int STREAM_METADATA_KNOWN_MASK = 0x80; // 1000 0000 + static final byte STREAM_METADATA_LENGTH_MASK = 0x7F; // 0111 1111 + + static final int USERNAME_BYTES_LENGTH = 1; + static final int AUTH_TYPE_ID_LENGTH = 1; + + static final char[] EMPTY_CHARS_ARRAY = new char[0]; + + private AuthMetadataFlyweight() {} + + /** + * Encode a Authentication CompositeMetadata payload using custom authentication type + * + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param customAuthType the custom mime type to encode. + * @param metadata the metadata value to encode. + * @throws IllegalArgumentException in case of {@code customAuthType} is non US_ASCII string or + * empty string or its length is greater than 128 bytes + */ + public static ByteBuf encodeMetadata( + ByteBufAllocator allocator, String customAuthType, ByteBuf metadata) { + + int actualASCIILength = ByteBufUtil.utf8Bytes(customAuthType); + if (actualASCIILength != customAuthType.length()) { + throw new IllegalArgumentException("custom auth type must be US_ASCII characters only"); + } + if (actualASCIILength < 1 || actualASCIILength > 128) { + throw new IllegalArgumentException( + "custom auth type must have a strictly positive length that fits on 7 unsigned bits, ie 1-128"); + } + + int capacity = 1 + actualASCIILength; + ByteBuf headerBuffer = allocator.buffer(capacity, capacity); + // encoded length is one less than actual length, since 0 is never a valid length, which gives + // wider representation range + headerBuffer.writeByte(actualASCIILength - 1); + + ByteBufUtil.reserveAndWriteUtf8(headerBuffer, customAuthType, actualASCIILength); + + return TupleByteBuf.of(allocator, headerBuffer, metadata); + } + + /** + * Encode a Authentication CompositeMetadata payload using custom authentication type + * + * @param allocator the {@link ByteBufAllocator} to create intermediate buffers as needed. + * @param authType the well-known mime type to encode. + * @param metadata the metadata value to encode. + * @throws IllegalArgumentException in case of {@code authType} is {@link + * WellKnownAuthType#UNPARSEABLE_AUTH_TYPE} or {@link + * WellKnownAuthType#UNKNOWN_RESERVED_AUTH_TYPE} + */ + public static ByteBuf encodeMetadata( + ByteBufAllocator allocator, WellKnownAuthType authType, ByteBuf metadata) { + + if (authType == WellKnownAuthType.UNPARSEABLE_AUTH_TYPE + || authType == WellKnownAuthType.UNKNOWN_RESERVED_AUTH_TYPE) { + throw new IllegalArgumentException("only allowed AuthType should be used"); + } + + int capacity = AUTH_TYPE_ID_LENGTH; + ByteBuf headerBuffer = + allocator + .buffer(capacity, capacity) + .writeByte(authType.getIdentifier() | STREAM_METADATA_KNOWN_MASK); + + return TupleByteBuf.of(allocator, headerBuffer, metadata); + } + + /** + * Encode a Authentication CompositeMetadata payload using Simple Authentication format + * + * @throws IllegalArgumentException if the username length is greater than 255 + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param username the char sequence which represents user name. + * @param password the char sequence which represents user password. + */ + public static ByteBuf encodeSimpleMetadata( + ByteBufAllocator allocator, char[] username, char[] password) { + + int usernameLength = CharByteBufUtil.utf8Bytes(username); + if (usernameLength > 255) { + throw new IllegalArgumentException( + "Username should be shorter than or equal to 255 bytes length in UTF-8 encoding"); + } + + int passwordLength = CharByteBufUtil.utf8Bytes(password); + int capacity = AUTH_TYPE_ID_LENGTH + USERNAME_BYTES_LENGTH + usernameLength + passwordLength; + final ByteBuf buffer = + allocator + .buffer(capacity, capacity) + .writeByte(WellKnownAuthType.SIMPLE.getIdentifier() | STREAM_METADATA_KNOWN_MASK) + .writeByte(usernameLength); + + CharByteBufUtil.writeUtf8(buffer, username); + CharByteBufUtil.writeUtf8(buffer, password); + + return buffer; + } + + /** + * Encode a Authentication CompositeMetadata payload using Bearer Authentication format + * + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param token the char sequence which represents BEARER token. + */ + public static ByteBuf encodeBearerMetadata(ByteBufAllocator allocator, char[] token) { + + int tokenLength = CharByteBufUtil.utf8Bytes(token); + int capacity = AUTH_TYPE_ID_LENGTH + tokenLength; + final ByteBuf buffer = + allocator + .buffer(capacity, capacity) + .writeByte(WellKnownAuthType.BEARER.getIdentifier() | STREAM_METADATA_KNOWN_MASK); + + CharByteBufUtil.writeUtf8(buffer, token); + + return buffer; + } + + /** + * Encode a new Authentication Metadata payload information, first verifying if the passed {@link + * String} matches a {@link WellKnownAuthType} (in which case it will be encoded in a compressed + * fashion using the mime id of that type). + * + *

Prefer using {@link #encodeMetadata(ByteBufAllocator, String, ByteBuf)} if you already know + * that the mime type is not a {@link WellKnownAuthType}. + * + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param authType the mime type to encode, as a {@link String}. well known mime types are + * compressed. + * @param metadata the metadata value to encode. + * @see #encodeMetadata(ByteBufAllocator, WellKnownAuthType, ByteBuf) + * @see #encodeMetadata(ByteBufAllocator, String, ByteBuf) + */ + public static ByteBuf encodeMetadataWithCompression( + ByteBufAllocator allocator, String authType, ByteBuf metadata) { + WellKnownAuthType wkn = WellKnownAuthType.fromString(authType); + if (wkn == WellKnownAuthType.UNPARSEABLE_AUTH_TYPE) { + return AuthMetadataFlyweight.encodeMetadata(allocator, authType, metadata); + } else { + return AuthMetadataFlyweight.encodeMetadata(allocator, wkn, metadata); + } + } + + /** + * Get the first {@code byte} from a {@link ByteBuf} and check whether it is length or {@link + * WellKnownAuthType}. Assuming said buffer properly contains such a {@code byte} + * + * @param metadata byteBuf used to get information from + */ + public static boolean isWellKnownAuthType(ByteBuf metadata) { + byte lengthOrId = metadata.getByte(0); + return (lengthOrId & STREAM_METADATA_LENGTH_MASK) != lengthOrId; + } + + /** + * Read first byte from the given {@code metadata} and tries to convert it's value to {@link + * WellKnownAuthType}. + * + * @param metadata given metadata buffer to read from + * @return Return on of the know Auth types or {@link WellKnownAuthType#UNPARSEABLE_AUTH_TYPE} if + * field's value is length or unknown auth type + * @throws IllegalStateException if not enough readable bytes in the given {@link ByteBuf} + */ + public static WellKnownAuthType decodeWellKnownAuthType(ByteBuf metadata) { + if (metadata.readableBytes() < 1) { + throw new IllegalStateException( + "Unable to decode Well Know Auth type. Not enough readable bytes"); + } + byte lengthOrId = metadata.readByte(); + int normalizedId = (byte) (lengthOrId & STREAM_METADATA_LENGTH_MASK); + + if (normalizedId != lengthOrId) { + return WellKnownAuthType.fromIdentifier(normalizedId); + } + + return WellKnownAuthType.UNPARSEABLE_AUTH_TYPE; + } + + /** + * Read up to 129 bytes from the given metadata in order to get the custom Auth Type + * + * @param metadata + * @return + */ + public static CharSequence decodeCustomAuthType(ByteBuf metadata) { + if (metadata.readableBytes() < 2) { + throw new IllegalStateException( + "Unable to decode custom Auth type. Not enough readable bytes"); + } + + byte encodedLength = metadata.readByte(); + if (encodedLength < 0) { + throw new IllegalStateException( + "Unable to decode custom Auth type. Incorrect auth type length"); + } + + // encoded length is realLength - 1 in order to avoid intersection with 0x00 authtype + int realLength = encodedLength + 1; + if (metadata.readableBytes() < realLength) { + throw new IllegalArgumentException( + "Unable to decode custom Auth type. Malformed length or auth type string"); + } + + return metadata.readCharSequence(realLength, CharsetUtil.US_ASCII); + } + + /** + * Read all remaining {@code bytes} from the given {@link ByteBuf} and return sliced + * representation of a payload + * + * @param metadata metadata to get payload from. Please note, the {@code metadata#readIndex} + * should be set to the beginning of the payload bytes + * @return sliced {@link ByteBuf} or {@link Unpooled#EMPTY_BUFFER} if no bytes readable in the + * given one + */ + public static ByteBuf decodePayload(ByteBuf metadata) { + if (metadata.readableBytes() == 0) { + return Unpooled.EMPTY_BUFFER; + } + + return metadata.readSlice(metadata.readableBytes()); + } + + /** + * Read up to 257 {@code bytes} from the given {@link ByteBuf} where the first byte is username + * length and the subsequent number of bytes equal to decoded length + * + * @param simpleAuthMetadata the given metadata to read username from. Please note, the {@code + * simpleAuthMetadata#readIndex} should be set to the username length byte + * @return sliced {@link ByteBuf} or {@link Unpooled#EMPTY_BUFFER} if username length is zero + */ + public static ByteBuf decodeUsername(ByteBuf simpleAuthMetadata) { + short usernameLength = decodeUsernameLength(simpleAuthMetadata); + + if (usernameLength == 0) { + return Unpooled.EMPTY_BUFFER; + } + + return simpleAuthMetadata.readSlice(usernameLength); + } + + /** + * Read all the remaining {@code byte}s from the given {@link ByteBuf} which represents user's + * password + * + * @param simpleAuthMetadata the given metadata to read password from. Please note, the {@code + * simpleAuthMetadata#readIndex} should be set to the beginning of the password bytes + * @return sliced {@link ByteBuf} or {@link Unpooled#EMPTY_BUFFER} if password length is zero + */ + public static ByteBuf decodePassword(ByteBuf simpleAuthMetadata) { + if (simpleAuthMetadata.readableBytes() == 0) { + return Unpooled.EMPTY_BUFFER; + } + + return simpleAuthMetadata.readSlice(simpleAuthMetadata.readableBytes()); + } + /** + * Read up to 257 {@code bytes} from the given {@link ByteBuf} where the first byte is username + * length and the subsequent number of bytes equal to decoded length + * + * @param simpleAuthMetadata the given metadata to read username from. Please note, the {@code + * simpleAuthMetadata#readIndex} should be set to the username length byte + * @return {@code char[]} which represents UTF-8 username + */ + public static char[] decodeUsernameAsCharArray(ByteBuf simpleAuthMetadata) { + short usernameLength = decodeUsernameLength(simpleAuthMetadata); + + if (usernameLength == 0) { + return EMPTY_CHARS_ARRAY; + } + + return CharByteBufUtil.readUtf8(simpleAuthMetadata, usernameLength); + } + + /** + * Read all the remaining {@code byte}s from the given {@link ByteBuf} which represents user's + * password + * + * @param simpleAuthMetadata the given metadata to read username from. Please note, the {@code + * simpleAuthMetadata#readIndex} should be set to the beginning of the password bytes + * @return {@code char[]} which represents UTF-8 password + */ + public static char[] decodePasswordAsCharArray(ByteBuf simpleAuthMetadata) { + if (simpleAuthMetadata.readableBytes() == 0) { + return EMPTY_CHARS_ARRAY; + } + + return CharByteBufUtil.readUtf8(simpleAuthMetadata, simpleAuthMetadata.readableBytes()); + } + + /** + * Read all the remaining {@code bytes} from the given {@link ByteBuf} where the first byte is + * username length and the subsequent number of bytes equal to decoded length + * + * @param bearerAuthMetadata the given metadata to read username from. Please note, the {@code + * simpleAuthMetadata#readIndex} should be set to the beginning of the password bytes + * @return {@code char[]} which represents UTF-8 password + */ + public static char[] decodeBearerTokenAsCharArray(ByteBuf bearerAuthMetadata) { + if (bearerAuthMetadata.readableBytes() == 0) { + return EMPTY_CHARS_ARRAY; + } + + return CharByteBufUtil.readUtf8(bearerAuthMetadata, bearerAuthMetadata.readableBytes()); + } + + private static short decodeUsernameLength(ByteBuf simpleAuthMetadata) { + if (simpleAuthMetadata.readableBytes() < 1) { + throw new IllegalStateException( + "Unable to decode custom username. Not enough readable bytes"); + } + + short usernameLength = simpleAuthMetadata.readUnsignedByte(); + + if (simpleAuthMetadata.readableBytes() < usernameLength) { + throw new IllegalArgumentException( + "Unable to decode username. Malformed username length or content"); + } + + return usernameLength; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/security/WellKnownAuthType.java b/rsocket-core/src/main/java/io/rsocket/metadata/security/WellKnownAuthType.java new file mode 100644 index 000000000..bd4b656b8 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/security/WellKnownAuthType.java @@ -0,0 +1,121 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.metadata.security; + +import java.util.Arrays; +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * Enumeration of Well Known Auth Types, as defined in the eponymous extension. Such auth types are + * used in composite metadata (which can include routing and/or tracing metadata). Per + * specification, identifiers are between 0 and 127 (inclusive). + */ +public enum WellKnownAuthType { + UNPARSEABLE_AUTH_TYPE("UNPARSEABLE_AUTH_TYPE_DO_NOT_USE", (byte) -2), + UNKNOWN_RESERVED_AUTH_TYPE("UNKNOWN_YET_RESERVED_DO_NOT_USE", (byte) -1), + + SIMPLE("simple", (byte) 0x00), + BEARER("bearer", (byte) 0x01); + // ... reserved for future use ... + + static final WellKnownAuthType[] TYPES_BY_AUTH_ID; + static final Map TYPES_BY_AUTH_STRING; + + static { + // precompute an array of all valid auth ids, filling the blanks with the RESERVED enum + TYPES_BY_AUTH_ID = new WellKnownAuthType[128]; // 0-127 inclusive + Arrays.fill(TYPES_BY_AUTH_ID, UNKNOWN_RESERVED_AUTH_TYPE); + // also prepare a Map of the types by auth string + TYPES_BY_AUTH_STRING = new LinkedHashMap<>(128); + + for (WellKnownAuthType value : values()) { + if (value.getIdentifier() >= 0) { + TYPES_BY_AUTH_ID[value.getIdentifier()] = value; + TYPES_BY_AUTH_STRING.put(value.getString(), value); + } + } + } + + private final byte identifier; + private final String str; + + WellKnownAuthType(String str, byte identifier) { + this.str = str; + this.identifier = identifier; + } + + /** + * Find the {@link WellKnownAuthType} for the given identifier (as an {@code int}). Valid + * identifiers are defined to be integers between 0 and 127, inclusive. Identifiers outside of + * this range will produce the {@link #UNPARSEABLE_AUTH_TYPE}. Additionally, some identifiers in + * that range are still only reserved and don't have a type associated yet: this method returns + * the {@link #UNKNOWN_RESERVED_AUTH_TYPE} when passing such an identifier, which lets call sites + * potentially detect this and keep the original representation when transmitting the associated + * metadata buffer. + * + * @param id the looked up identifier + * @return the {@link WellKnownAuthType}, or {@link #UNKNOWN_RESERVED_AUTH_TYPE} if the id is out + * of the specification's range, or {@link #UNKNOWN_RESERVED_AUTH_TYPE} if the id is one that + * is merely reserved but unknown to this implementation. + */ + public static WellKnownAuthType fromIdentifier(int id) { + if (id < 0x00 || id > 0x7F) { + return UNPARSEABLE_AUTH_TYPE; + } + return TYPES_BY_AUTH_ID[id]; + } + + /** + * Find the {@link WellKnownAuthType} for the given {@link String} representation. If the + * representation is {@code null} or doesn't match a {@link WellKnownAuthType}, the {@link + * #UNPARSEABLE_AUTH_TYPE} is returned. + * + * @param authType the looked up auth type + * @return the matching {@link WellKnownAuthType}, or {@link #UNPARSEABLE_AUTH_TYPE} if none + * matches + */ + public static WellKnownAuthType fromString(String authType) { + if (authType == null) throw new IllegalArgumentException("type must be non-null"); + + // force UNPARSEABLE if by chance UNKNOWN_RESERVED_AUTH_TYPE's text has been used + if (authType.equals(UNKNOWN_RESERVED_AUTH_TYPE.str)) { + return UNPARSEABLE_AUTH_TYPE; + } + + return TYPES_BY_AUTH_STRING.getOrDefault(authType, UNPARSEABLE_AUTH_TYPE); + } + + /** @return the byte identifier of the auth type, guaranteed to be positive or zero. */ + public byte getIdentifier() { + return identifier; + } + + /** + * @return the auth type represented as a {@link String}, which is made of US_ASCII compatible + * characters only + */ + public String getString() { + return str; + } + + /** @see #getString() */ + @Override + public String toString() { + return str; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/util/CharByteBufUtil.java b/rsocket-core/src/main/java/io/rsocket/util/CharByteBufUtil.java new file mode 100644 index 000000000..e011d2a6f --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/util/CharByteBufUtil.java @@ -0,0 +1,208 @@ +package io.rsocket.util; + +import static io.netty.util.internal.StringUtil.isSurrogate; + +import io.netty.buffer.ByteBuf; +import io.netty.util.CharsetUtil; +import io.netty.util.internal.MathUtil; +import java.nio.ByteBuffer; +import java.nio.CharBuffer; +import java.nio.charset.CharacterCodingException; +import java.nio.charset.CharsetDecoder; +import java.nio.charset.CoderResult; +import java.util.Arrays; + +public class CharByteBufUtil { + + private static final byte WRITE_UTF_UNKNOWN = (byte) '?'; + + private CharByteBufUtil() {} + + /** + * Returns the exact bytes length of UTF8 character sequence. + * + *

This method is producing the exact length according to {@link #writeUtf8(ByteBuf, char[])}. + */ + public static int utf8Bytes(final char[] seq) { + return utf8ByteCount(seq, 0, seq.length); + } + + /** + * This method is producing the exact length according to {@link #writeUtf8(ByteBuf, char[], int, + * int)}. + */ + public static int utf8Bytes(final char[] seq, int start, int end) { + return utf8ByteCount(checkCharSequenceBounds(seq, start, end), start, end); + } + + private static int utf8ByteCount(final char[] seq, int start, int end) { + int i = start; + // ASCII fast path + while (i < end && seq[i] < 0x80) { + ++i; + } + // !ASCII is packed in a separate method to let the ASCII case be smaller + return i < end ? (i - start) + utf8BytesNonAscii(seq, i, end) : i - start; + } + + private static int utf8BytesNonAscii(final char[] seq, final int start, final int end) { + int encodedLength = 0; + for (int i = start; i < end; i++) { + final char c = seq[i]; + // making it 100% branchless isn't rewarding due to the many bit operations necessary! + if (c < 0x800) { + // branchless version of: (c <= 127 ? 0:1) + 1 + encodedLength += ((0x7f - c) >>> 31) + 1; + } else if (isSurrogate(c)) { + if (!Character.isHighSurrogate(c)) { + encodedLength++; + // WRITE_UTF_UNKNOWN + continue; + } + final char c2; + try { + // Surrogate Pair consumes 2 characters. Optimistically try to get the next character to + // avoid + // duplicate bounds checking with charAt. + c2 = seq[++i]; + } catch (IndexOutOfBoundsException ignored) { + encodedLength++; + // WRITE_UTF_UNKNOWN + break; + } + if (!Character.isLowSurrogate(c2)) { + // WRITE_UTF_UNKNOWN + (Character.isHighSurrogate(c2) ? WRITE_UTF_UNKNOWN : c2) + encodedLength += 2; + continue; + } + // See http://www.unicode.org/versions/Unicode7.0.0/ch03.pdf#G2630. + encodedLength += 4; + } else { + encodedLength += 3; + } + } + return encodedLength; + } + + private static char[] checkCharSequenceBounds(char[] seq, int start, int end) { + if (MathUtil.isOutOfBounds(start, end - start, seq.length)) { + throw new IndexOutOfBoundsException( + "expected: 0 <= start(" + + start + + ") <= end (" + + end + + ") <= seq.length(" + + seq.length + + ')'); + } + return seq; + } + + /** + * Encode a {@link char[]} in UTF-8 and write it + * into {@link ByteBuf}. + * + *

This method returns the actual number of bytes written. + */ + public static int writeUtf8(ByteBuf buf, char[] seq) { + return writeUtf8(buf, seq, 0, seq.length); + } + + /** + * Equivalent to {@link #writeUtf8(ByteBuf, char[]) + * writeUtf8(buf, seq.subSequence(start, end), reserveBytes)} but avoids subsequence object + * allocation if possible. + * + * @return actual number of bytes written + */ + public static int writeUtf8(ByteBuf buf, char[] seq, int start, int end) { + return writeUtf8(buf, buf.writerIndex(), checkCharSequenceBounds(seq, start, end), start, end); + } + + // Fast-Path implementation + static int writeUtf8(ByteBuf buffer, int writerIndex, char[] seq, int start, int end) { + int oldWriterIndex = writerIndex; + + // We can use the _set methods as these not need to do any index checks and reference checks. + // This is possible as we called ensureWritable(...) before. + for (int i = start; i < end; i++) { + char c = seq[i]; + if (c < 0x80) { + buffer.setByte(writerIndex++, (byte) c); + } else if (c < 0x800) { + buffer.setByte(writerIndex++, (byte) (0xc0 | (c >> 6))); + buffer.setByte(writerIndex++, (byte) (0x80 | (c & 0x3f))); + } else if (isSurrogate(c)) { + if (!Character.isHighSurrogate(c)) { + buffer.setByte(writerIndex++, WRITE_UTF_UNKNOWN); + continue; + } + final char c2; + if (seq.length > ++i) { + // Surrogate Pair consumes 2 characters. Optimistically try to get the next character to + // avoid + // duplicate bounds checking with charAt. If an IndexOutOfBoundsException is thrown we + // will + // re-throw a more informative exception describing the problem. + c2 = seq[i]; + } else { + buffer.setByte(writerIndex++, WRITE_UTF_UNKNOWN); + break; + } + // Extra method to allow inlining the rest of writeUtf8 which is the most likely code path. + writerIndex = writeUtf8Surrogate(buffer, writerIndex, c, c2); + } else { + buffer.setByte(writerIndex++, (byte) (0xe0 | (c >> 12))); + buffer.setByte(writerIndex++, (byte) (0x80 | ((c >> 6) & 0x3f))); + buffer.setByte(writerIndex++, (byte) (0x80 | (c & 0x3f))); + } + } + buffer.writerIndex(writerIndex); + return writerIndex - oldWriterIndex; + } + + private static int writeUtf8Surrogate(ByteBuf buffer, int writerIndex, char c, char c2) { + if (!Character.isLowSurrogate(c2)) { + buffer.setByte(writerIndex++, WRITE_UTF_UNKNOWN); + buffer.setByte(writerIndex++, Character.isHighSurrogate(c2) ? WRITE_UTF_UNKNOWN : c2); + return writerIndex; + } + int codePoint = Character.toCodePoint(c, c2); + // See http://www.unicode.org/versions/Unicode7.0.0/ch03.pdf#G2630. + buffer.setByte(writerIndex++, (byte) (0xf0 | (codePoint >> 18))); + buffer.setByte(writerIndex++, (byte) (0x80 | ((codePoint >> 12) & 0x3f))); + buffer.setByte(writerIndex++, (byte) (0x80 | ((codePoint >> 6) & 0x3f))); + buffer.setByte(writerIndex++, (byte) (0x80 | (codePoint & 0x3f))); + return writerIndex; + } + + public static char[] readUtf8(ByteBuf byteBuf, int length) { + CharsetDecoder charsetDecoder = CharsetUtil.UTF_8.newDecoder(); + int en = (int) (length * (double) charsetDecoder.maxCharsPerByte()); + char[] ca = new char[en]; + + CharBuffer charBuffer = CharBuffer.wrap(ca); + ByteBuffer byteBuffer = byteBuf.internalNioBuffer(byteBuf.readerIndex(), length); + byteBuffer.mark(); + try { + CoderResult cr = charsetDecoder.decode(byteBuffer, charBuffer, true); + if (!cr.isUnderflow()) cr.throwException(); + cr = charsetDecoder.flush(charBuffer); + if (!cr.isUnderflow()) cr.throwException(); + + byteBuffer.reset(); + byteBuf.skipBytes(length); + + return safeTrim(charBuffer.array(), charBuffer.position()); + } catch (CharacterCodingException x) { + // Substitution is always enabled, + // so this shouldn't happen + throw new IllegalStateException("unable to decode char array from the given buffer", x); + } + } + + private static char[] safeTrim(char[] ca, int len) { + if (len == ca.length) return ca; + else return Arrays.copyOf(ca, len); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/util/MonoLifecycleHandler.java b/rsocket-core/src/main/java/io/rsocket/util/MonoLifecycleHandler.java new file mode 100644 index 000000000..4d47c03d6 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/util/MonoLifecycleHandler.java @@ -0,0 +1,21 @@ +package io.rsocket.util; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import reactor.core.publisher.SignalType; + +public interface MonoLifecycleHandler { + + default void doOnSubscribe() {} + + /** + * Handler which is invoked on the terminal activity within a given Monoø + * + * @param signalType a type of signal which explain what happened + * @param element an carried element. May not be present if stream is empty or cancelled or + * errored + * @param e an carried error. May not be present if stream is cancelled or completed successfully + */ + default void doOnTerminal( + @Nonnull SignalType signalType, @Nullable T element, @Nullable Throwable e) {} +} diff --git a/rsocket-core/src/test/java/io/rsocket/RSocketLeaseTest.java b/rsocket-core/src/test/java/io/rsocket/RSocketLeaseTest.java index 2a2567843..3af8916cd 100644 --- a/rsocket-core/src/test/java/io/rsocket/RSocketLeaseTest.java +++ b/rsocket-core/src/test/java/io/rsocket/RSocketLeaseTest.java @@ -141,7 +141,8 @@ public void serverRSocketFactoryRejectsUnsupportedLease() { Assertions.assertThat(sent).hasSize(1); ByteBuf error = sent.iterator().next(); Assertions.assertThat(FrameHeaderFlyweight.frameType(error)).isEqualTo(ERROR); - Assertions.assertThat(Exceptions.from(error).getMessage()).isEqualTo("lease is not supported"); + Assertions.assertThat(Exceptions.from(0, error).getMessage()) + .isEqualTo("lease is not supported"); } @Test diff --git a/rsocket-core/src/test/java/io/rsocket/RSocketTest.java b/rsocket-core/src/test/java/io/rsocket/RSocketTest.java index 5d9672fb9..80865ec47 100644 --- a/rsocket-core/src/test/java/io/rsocket/RSocketTest.java +++ b/rsocket-core/src/test/java/io/rsocket/RSocketTest.java @@ -24,6 +24,7 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.rsocket.exceptions.ApplicationErrorException; +import io.rsocket.exceptions.CustomRSocketException; import io.rsocket.lease.RequesterLeaseHandler; import io.rsocket.lease.ResponderLeaseHandler; import io.rsocket.test.util.LocalDuplexConnection; @@ -38,6 +39,7 @@ import org.junit.rules.ExternalResource; import org.junit.runner.Description; import org.junit.runners.model.Statement; +import org.mockito.ArgumentCaptor; import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; import reactor.core.publisher.DirectProcessor; @@ -86,6 +88,34 @@ public Mono requestResponse(Payload payload) { rule.assertServerError("java.lang.NullPointerException: Deliberate exception."); } + @Test(timeout = 2000) + public void testHandlerEmitsCustomError() { + rule.setRequestAcceptor( + new AbstractRSocket() { + @Override + public Mono requestResponse(Payload payload) { + return Mono.error( + new CustomRSocketException(0x00000501, "Deliberate Custom exception.")); + } + }); + Subscriber subscriber = TestSubscriber.create(); + rule.crs.requestResponse(EmptyPayload.INSTANCE).subscribe(subscriber); + ArgumentCaptor customRSocketExceptionArgumentCaptor = + ArgumentCaptor.forClass(CustomRSocketException.class); + verify(subscriber).onError(customRSocketExceptionArgumentCaptor.capture()); + + Assert.assertEquals( + "Deliberate Custom exception.", + customRSocketExceptionArgumentCaptor.getValue().getMessage()); + Assert.assertEquals(0x00000501, customRSocketExceptionArgumentCaptor.getValue().errorCode()); + + // Client sees error through normal API + rule.assertNoClientErrors(); + + rule.assertServerError( + "io.rsocket.exceptions.CustomRSocketException: Deliberate Custom exception."); + } + @Test(timeout = 2000) public void testStream() throws Exception { Flux responses = rule.crs.requestStream(DefaultPayload.create("Payload In")); diff --git a/rsocket-core/src/test/java/io/rsocket/SetupRejectionTest.java b/rsocket-core/src/test/java/io/rsocket/SetupRejectionTest.java index 75e9f5a85..e6972eec0 100644 --- a/rsocket-core/src/test/java/io/rsocket/SetupRejectionTest.java +++ b/rsocket-core/src/test/java/io/rsocket/SetupRejectionTest.java @@ -18,13 +18,11 @@ import java.time.Duration; import java.util.ArrayList; import java.util.List; -import org.junit.Ignore; import org.junit.jupiter.api.Test; import reactor.core.publisher.Mono; import reactor.core.publisher.UnicastProcessor; import reactor.test.StepVerifier; -@Ignore public class SetupRejectionTest { @Test @@ -39,7 +37,7 @@ void responderRejectSetup() { ByteBuf sentFrame = transport.awaitSent(); assertThat(FrameHeaderFlyweight.frameType(sentFrame)).isEqualTo(FrameType.ERROR); - RuntimeException error = Exceptions.from(sentFrame); + RuntimeException error = Exceptions.from(0, sentFrame); assertThat(errorMsg).isEqualTo(error.getMessage()); assertThat(error).isInstanceOf(RejectedSetupException.class); RSocket acceptorSender = acceptor.senderRSocket().block(); @@ -64,15 +62,16 @@ void requesterStreamsTerminatedOnZeroErrorFrame() { String errorMsg = "error"; - Mono.delay(Duration.ofMillis(100)) - .doOnTerminate( - () -> - conn.addToReceivedBuffer( - ErrorFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, 0, new RejectedSetupException(errorMsg)))) - .subscribe(); - - StepVerifier.create(rSocket.requestResponse(DefaultPayload.create("test"))) + StepVerifier.create( + rSocket + .requestResponse(DefaultPayload.create("test")) + .doOnRequest( + ignored -> + conn.addToReceivedBuffer( + ErrorFrameFlyweight.encode( + ByteBufAllocator.DEFAULT, + 0, + new RejectedSetupException(errorMsg))))) .expectErrorMatches( err -> err instanceof RejectedSetupException && errorMsg.equals(err.getMessage())) .verify(Duration.ofSeconds(5)); diff --git a/rsocket-core/src/test/java/io/rsocket/exceptions/ExceptionsTest.java b/rsocket-core/src/test/java/io/rsocket/exceptions/ExceptionsTest.java index c7bbfadf6..e646080c7 100644 --- a/rsocket-core/src/test/java/io/rsocket/exceptions/ExceptionsTest.java +++ b/rsocket-core/src/test/java/io/rsocket/exceptions/ExceptionsTest.java @@ -16,131 +16,220 @@ package io.rsocket.exceptions; +import static io.rsocket.frame.ErrorFrameFlyweight.APPLICATION_ERROR; +import static io.rsocket.frame.ErrorFrameFlyweight.CANCELED; +import static io.rsocket.frame.ErrorFrameFlyweight.CONNECTION_CLOSE; +import static io.rsocket.frame.ErrorFrameFlyweight.CONNECTION_ERROR; +import static io.rsocket.frame.ErrorFrameFlyweight.INVALID; +import static io.rsocket.frame.ErrorFrameFlyweight.INVALID_SETUP; +import static io.rsocket.frame.ErrorFrameFlyweight.REJECTED; +import static io.rsocket.frame.ErrorFrameFlyweight.REJECTED_RESUME; +import static io.rsocket.frame.ErrorFrameFlyweight.REJECTED_SETUP; +import static io.rsocket.frame.ErrorFrameFlyweight.UNSUPPORTED_SETUP; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNullPointerException; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.UnpooledByteBufAllocator; +import io.rsocket.frame.ErrorFrameFlyweight; +import java.util.concurrent.ThreadLocalRandom; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + final class ExceptionsTest { - /* @DisplayName("from returns ApplicationErrorException") @Test void fromApplicationException() { - ByteBuf byteBuf = createErrorFrame(APPLICATION_ERROR, "test-message"); + ByteBuf byteBuf = createErrorFrame(1, APPLICATION_ERROR, "test-message"); - assertThat(Exceptions.from(Frame.from(byteBuf))) + assertThat(Exceptions.from(1, byteBuf)) .isInstanceOf(ApplicationErrorException.class) - .withFailMessage("test-message"); + .hasMessage("test-message"); + + assertThat(Exceptions.from(0, byteBuf)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Invalid Error frame in Stream ID 0: 0x%08X '%s'", APPLICATION_ERROR, "test-message"); } @DisplayName("from returns CanceledException") @Test void fromCanceledException() { - ByteBuf byteBuf = createErrorFrame(CANCELED, "test-message"); + ByteBuf byteBuf = createErrorFrame(1, CANCELED, "test-message"); - assertThat(Exceptions.from(Frame.from(byteBuf))) + assertThat(Exceptions.from(1, byteBuf)) .isInstanceOf(CanceledException.class) - .withFailMessage("test-message"); + .hasMessage("test-message"); + + assertThat(Exceptions.from(0, byteBuf)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Invalid Error frame in Stream ID 0: 0x%08X '%s'", CANCELED, "test-message"); } @DisplayName("from returns ConnectionCloseException") @Test void fromConnectionCloseException() { - ByteBuf byteBuf = createErrorFrame(CONNECTION_CLOSE, "test-message"); + ByteBuf byteBuf = createErrorFrame(0, CONNECTION_CLOSE, "test-message"); - assertThat(Exceptions.from(Frame.from(byteBuf))) + assertThat(Exceptions.from(0, byteBuf)) .isInstanceOf(ConnectionCloseException.class) - .withFailMessage("test-message"); + .hasMessage("test-message"); + + assertThat(Exceptions.from(1, byteBuf)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Invalid Error frame in Stream ID 1: 0x%08X '%s'", CONNECTION_CLOSE, "test-message"); } @DisplayName("from returns ConnectionErrorException") @Test void fromConnectionErrorException() { - ByteBuf byteBuf = createErrorFrame(CONNECTION_ERROR, "test-message"); + ByteBuf byteBuf = createErrorFrame(0, CONNECTION_ERROR, "test-message"); - assertThat(Exceptions.from(Frame.from(byteBuf))) + assertThat(Exceptions.from(0, byteBuf)) .isInstanceOf(ConnectionErrorException.class) - .withFailMessage("test-message"); + .hasMessage("test-message"); + + assertThat(Exceptions.from(1, byteBuf)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Invalid Error frame in Stream ID 1: 0x%08X '%s'", CONNECTION_ERROR, "test-message"); } @DisplayName("from returns IllegalArgumentException if error frame has illegal error code") @Test void fromIllegalErrorFrame() { - ByteBuf byteBuf = createErrorFrame(0x00000000, "test-message"); + ByteBuf byteBuf = createErrorFrame(0, 0x00000000, "test-message"); - assertThat(Exceptions.from(Frame.from(byteBuf))) - .isInstanceOf(IllegalArgumentException.class) - .withFailMessage("Invalid Error frame: %d, '%s'", 0, "test-message"); + assertThat(Exceptions.from(0, byteBuf)) + .hasMessage("Invalid Error frame in Stream ID 0: 0x%08X '%s'", 0, "test-message") + .isInstanceOf(IllegalArgumentException.class); + + assertThat(Exceptions.from(1, byteBuf)) + .hasMessage("Invalid Error frame in Stream ID 1: 0x%08X '%s'", 0x00000000, "test-message") + .isInstanceOf(IllegalArgumentException.class); } @DisplayName("from returns InvalidException") @Test void fromInvalidException() { - ByteBuf byteBuf = createErrorFrame(INVALID, "test-message"); + ByteBuf byteBuf = createErrorFrame(1, INVALID, "test-message"); - assertThat(Exceptions.from(Frame.from(byteBuf))) + assertThat(Exceptions.from(1, byteBuf)) .isInstanceOf(InvalidException.class) - .withFailMessage("test-message"); + .hasMessage("test-message"); + + assertThat(Exceptions.from(0, byteBuf)) + .hasMessage("Invalid Error frame in Stream ID 0: 0x%08X '%s'", INVALID, "test-message") + .isInstanceOf(IllegalArgumentException.class); } @DisplayName("from returns InvalidSetupException") @Test void fromInvalidSetupException() { - ByteBuf byteBuf = createErrorFrame(INVALID_SETUP, "test-message"); + ByteBuf byteBuf = createErrorFrame(0, INVALID_SETUP, "test-message"); - assertThat(Exceptions.from(Frame.from(byteBuf))) + assertThat(Exceptions.from(0, byteBuf)) .isInstanceOf(InvalidSetupException.class) - .withFailMessage("test-message"); + .hasMessage("test-message"); + + assertThat(Exceptions.from(1, byteBuf)) + .hasMessage( + "Invalid Error frame in Stream ID 1: 0x%08X '%s'", INVALID_SETUP, "test-message") + .isInstanceOf(IllegalArgumentException.class); } @DisplayName("from returns RejectedException") @Test void fromRejectedException() { - ByteBuf byteBuf = createErrorFrame(REJECTED, "test-message"); + ByteBuf byteBuf = createErrorFrame(1, REJECTED, "test-message"); - assertThat(Exceptions.from(Frame.from(byteBuf))) + assertThat(Exceptions.from(1, byteBuf)) .isInstanceOf(RejectedException.class) .withFailMessage("test-message"); + + assertThat(Exceptions.from(0, byteBuf)) + .hasMessage("Invalid Error frame in Stream ID 0: 0x%08X '%s'", REJECTED, "test-message") + .isInstanceOf(IllegalArgumentException.class); } @DisplayName("from returns RejectedResumeException") @Test void fromRejectedResumeException() { - ByteBuf byteBuf = createErrorFrame(REJECTED_RESUME, "test-message"); + ByteBuf byteBuf = createErrorFrame(0, REJECTED_RESUME, "test-message"); - assertThat(Exceptions.from(Frame.from(byteBuf))) + assertThat(Exceptions.from(0, byteBuf)) .isInstanceOf(RejectedResumeException.class) - .withFailMessage("test-message"); + .hasMessage("test-message"); + + assertThat(Exceptions.from(1, byteBuf)) + .hasMessage( + "Invalid Error frame in Stream ID 1: 0x%08X '%s'", REJECTED_RESUME, "test-message") + .isInstanceOf(IllegalArgumentException.class); } @DisplayName("from returns RejectedSetupException") @Test void fromRejectedSetupException() { - ByteBuf byteBuf = createErrorFrame(REJECTED_SETUP, "test-message"); + ByteBuf byteBuf = createErrorFrame(0, REJECTED_SETUP, "test-message"); - assertThat(Exceptions.from(Frame.from(byteBuf))) + assertThat(Exceptions.from(0, byteBuf)) .isInstanceOf(RejectedSetupException.class) .withFailMessage("test-message"); + + assertThat(Exceptions.from(1, byteBuf)) + .hasMessage( + "Invalid Error frame in Stream ID 1: 0x%08X '%s'", REJECTED_SETUP, "test-message") + .isInstanceOf(IllegalArgumentException.class); } @DisplayName("from returns UnsupportedSetupException") @Test void fromUnsupportedSetupException() { - ByteBuf byteBuf = createErrorFrame(UNSUPPORTED_SETUP, "test-message"); + ByteBuf byteBuf = createErrorFrame(0, UNSUPPORTED_SETUP, "test-message"); - assertThat(Exceptions.from(Frame.from(byteBuf))) + assertThat(Exceptions.from(0, byteBuf)) .isInstanceOf(UnsupportedSetupException.class) - .withFailMessage("test-message"); + .hasMessage("test-message"); + + assertThat(Exceptions.from(1, byteBuf)) + .hasMessage( + "Invalid Error frame in Stream ID 1: 0x%08X '%s'", UNSUPPORTED_SETUP, "test-message") + .isInstanceOf(IllegalArgumentException.class); + } + + @DisplayName("from returns CustomRSocketException") + @Test + void fromCustomRSocketException() { + for (int i = 0; i < 1000; i++) { + int randomCode = + ThreadLocalRandom.current().nextBoolean() + ? ThreadLocalRandom.current() + .nextInt(Integer.MIN_VALUE, ErrorFrameFlyweight.MAX_USER_ALLOWED_ERROR_CODE) + : ThreadLocalRandom.current() + .nextInt(ErrorFrameFlyweight.MIN_USER_ALLOWED_ERROR_CODE, Integer.MAX_VALUE); + ByteBuf byteBuf = createErrorFrame(0, randomCode, "test-message"); + + assertThat(Exceptions.from(1, byteBuf)) + .isInstanceOf(CustomRSocketException.class) + .hasMessage("test-message"); + + assertThat(Exceptions.from(0, byteBuf)) + .hasMessage("Invalid Error frame in Stream ID 0: 0x%08X '%s'", randomCode, "test-message") + .isInstanceOf(IllegalArgumentException.class); + } } @DisplayName("from throws NullPointerException with null frame") @Test void fromWithNullFrame() { assertThatNullPointerException() - .isThrownBy(() -> Exceptions.from(null)) + .isThrownBy(() -> Exceptions.from(0, null)) .withMessage("frame must not be null"); } - private ByteBuf createErrorFrame(int errorCode, String message) { - ByteBuf byteBuf = Unpooled.buffer(); - - ErrorFrameFlyweight.encode(byteBuf, 0, errorCode, Unpooled.copiedBuffer(message, UTF_8)); - - return byteBuf; - }*/ + private ByteBuf createErrorFrame(int streamId, int errorCode, String message) { + return ErrorFrameFlyweight.encode( + UnpooledByteBufAllocator.DEFAULT, streamId, new TestRSocketException(errorCode, message)); + } } diff --git a/rsocket-core/src/test/java/io/rsocket/exceptions/TestRSocketException.java b/rsocket-core/src/test/java/io/rsocket/exceptions/TestRSocketException.java new file mode 100644 index 000000000..6c2e63730 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/exceptions/TestRSocketException.java @@ -0,0 +1,39 @@ +package io.rsocket.exceptions; + +public class TestRSocketException extends RSocketException { + private static final long serialVersionUID = 7873267740343446585L; + + private final int errorCode; + + /** + * Constructs a new exception with the specified message. + * + * @param errorCode customizable error code + * @param message the message + * @throws NullPointerException if {@code message} is {@code null} + * @throws IllegalArgumentException if {@code errorCode} is out of allowed range + */ + public TestRSocketException(int errorCode, String message) { + super(message); + this.errorCode = errorCode; + } + + /** + * Constructs a new exception with the specified message and cause. + * + * @param errorCode customizable error code + * @param message the message + * @param cause the cause of this exception + * @throws NullPointerException if {@code message} or {@code cause} is {@code null} + * @throws IllegalArgumentException if {@code errorCode} is out of allowed range + */ + public TestRSocketException(int errorCode, String message, Throwable cause) { + super(message, cause); + this.errorCode = errorCode; + } + + @Override + public int errorCode() { + return errorCode; + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/internal/SchedulerUtils.java b/rsocket-core/src/test/java/io/rsocket/internal/SchedulerUtils.java new file mode 100644 index 000000000..d73f92b85 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/internal/SchedulerUtils.java @@ -0,0 +1,23 @@ +package io.rsocket.internal; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import reactor.core.scheduler.Scheduler; + +public class SchedulerUtils { + + public static void warmup(Scheduler scheduler) throws InterruptedException { + warmup(scheduler, 10000); + } + + public static void warmup(Scheduler scheduler, int times) throws InterruptedException { + scheduler.start(); + + // warm up + CountDownLatch latch = new CountDownLatch(times); + for (int i = 0; i < times; i++) { + scheduler.schedule(latch::countDown); + } + latch.await(5, TimeUnit.SECONDS); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/internal/UnicastMonoEmptyTest.java b/rsocket-core/src/test/java/io/rsocket/internal/UnicastMonoEmptyTest.java new file mode 100644 index 000000000..76bb953a4 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/internal/UnicastMonoEmptyTest.java @@ -0,0 +1,97 @@ +package io.rsocket.internal; + +import static io.rsocket.internal.SchedulerUtils.warmup; + +import java.time.Duration; +import java.util.concurrent.atomic.AtomicInteger; +import org.assertj.core.api.Assertions; +import org.junit.Test; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; +import reactor.test.util.RaceTestUtils; + +public class UnicastMonoEmptyTest { + + @Test + public void shouldSupportASingleSubscriber() throws InterruptedException { + warmup(Schedulers.single()); + + for (int i = 0; i < 10000; i++) { + AtomicInteger times = new AtomicInteger(); + Mono unicastMono = UnicastMonoEmpty.newInstance(times::incrementAndGet); + + Assertions.assertThatThrownBy( + () -> + RaceTestUtils.race( + unicastMono::subscribe, unicastMono::subscribe, Schedulers.single())) + .hasCause(new IllegalStateException("UnicastMonoEmpty allows only a single Subscriber")); + Assertions.assertThat(times.get()).isEqualTo(1); + } + } + + @Test + public void shouldSupportASingleBlock() throws InterruptedException { + warmup(Schedulers.single()); + + for (int i = 0; i < 10000; i++) { + AtomicInteger times = new AtomicInteger(); + Mono unicastMono = UnicastMonoEmpty.newInstance(times::incrementAndGet); + + Assertions.assertThatThrownBy( + () -> RaceTestUtils.race(unicastMono::block, unicastMono::block, Schedulers.single())) + .hasMessage("UnicastMonoEmpty allows only a single Subscriber"); + Assertions.assertThat(times.get()).isEqualTo(1); + } + } + + @Test + public void shouldSupportASingleBlockWithTimeout() throws InterruptedException { + warmup(Schedulers.single()); + + for (int i = 0; i < 10000; i++) { + AtomicInteger times = new AtomicInteger(); + Mono unicastMono = UnicastMonoEmpty.newInstance(times::incrementAndGet); + + Assertions.assertThatThrownBy( + () -> + RaceTestUtils.race( + () -> unicastMono.block(Duration.ofMinutes(1)), + () -> unicastMono.block(Duration.ofMinutes(1)), + Schedulers.single())) + .hasMessage("UnicastMonoEmpty allows only a single Subscriber"); + Assertions.assertThat(times.get()).isEqualTo(1); + } + } + + @Test + public void shouldSupportASingleSubscribeOrBlock() throws InterruptedException { + warmup(Schedulers.parallel()); + + for (int i = 0; i < 10000; i++) { + AtomicInteger times = new AtomicInteger(); + Mono unicastMono = UnicastMonoEmpty.newInstance(times::incrementAndGet); + + Assertions.assertThatThrownBy( + () -> + RaceTestUtils.race( + unicastMono::subscribe, + () -> + RaceTestUtils.race( + unicastMono::block, + () -> unicastMono.block(Duration.ofMinutes(1)), + Schedulers.parallel()), + Schedulers.parallel())) + .matches( + t -> { + Assertions.assertThat(t.getSuppressed()).hasSize(2); + Assertions.assertThat(t.getSuppressed()[0]) + .hasMessageContaining("UnicastMonoEmpty allows only a single Subscriber"); + Assertions.assertThat(t.getSuppressed()[1]) + .hasMessageContaining("UnicastMonoEmpty allows only a single Subscriber"); + + return true; + }); + Assertions.assertThat(times.get()).isEqualTo(1); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/internal/UnicastMonoProcessorTest.java b/rsocket-core/src/test/java/io/rsocket/internal/UnicastMonoProcessorTest.java index 20ed4b469..a836dd509 100644 --- a/rsocket-core/src/test/java/io/rsocket/internal/UnicastMonoProcessorTest.java +++ b/rsocket-core/src/test/java/io/rsocket/internal/UnicastMonoProcessorTest.java @@ -1,12 +1,31 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package io.rsocket.internal; +import static io.rsocket.internal.SchedulerUtils.warmup; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.assertj.core.api.Assumptions.assumeThat; +import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.util.MonoLifecycleHandler; import java.lang.ref.WeakReference; import java.time.Duration; +import java.util.ArrayList; import java.util.Date; import java.util.concurrent.CancellationException; import java.util.concurrent.CompletableFuture; @@ -19,15 +38,1107 @@ import org.reactivestreams.Subscription; import reactor.core.Scannable; import reactor.core.publisher.Flux; +import reactor.core.publisher.Hooks; import reactor.core.publisher.Mono; import reactor.core.publisher.MonoProcessor; import reactor.core.publisher.Operators; +import reactor.core.publisher.SignalType; +import reactor.core.scheduler.Schedulers; import reactor.test.StepVerifier; import reactor.test.publisher.TestPublisher; +import reactor.test.util.RaceTestUtils; +import reactor.util.context.Context; import reactor.util.function.Tuple2; public class UnicastMonoProcessorTest { + static class VerifyMonoLifecycleHandler implements MonoLifecycleHandler { + private final AtomicInteger onSubscribeCounter = new AtomicInteger(); + private final AtomicInteger onTerminalCounter = new AtomicInteger(); + private final AtomicReference valueReference = new AtomicReference<>(); + private final AtomicReference errorReference = new AtomicReference<>(); + private final AtomicReference signalTypeReference = new AtomicReference<>(); + + @Override + public void doOnSubscribe() { + onSubscribeCounter.incrementAndGet(); + } + + @Override + public void doOnTerminal(SignalType signalType, T element, Throwable e) { + onTerminalCounter.incrementAndGet(); + signalTypeReference.set(signalType); + valueReference.set(element); + errorReference.set(e); + } + + public VerifyMonoLifecycleHandler assertSubscribed() { + assertThat(onSubscribeCounter.get()).isOne(); + return this; + } + + public VerifyMonoLifecycleHandler assertNotSubscribed() { + assertThat(onSubscribeCounter.get()).isZero(); + return this; + } + + public VerifyMonoLifecycleHandler assertTerminated() { + assertThat(onTerminalCounter.get()).describedAs("Expected a single terminal signal").isOne(); + return this; + } + + public VerifyMonoLifecycleHandler assertNotTerminated() { + assertThat(onTerminalCounter.get()).describedAs("Expected zero terminal signals").isZero(); + return this; + } + + public VerifyMonoLifecycleHandler assertCompleted() { + assertTerminated(); + assertThat(signalTypeReference.get()) + .describedAs("Expected ON_COMPLETE signal") + .isEqualTo(SignalType.ON_COMPLETE); + assertThat(errorReference.get()).describedAs("Expected error to be absent").isNull(); + assertThat(valueReference.get()).isNull(); + return this; + } + + public VerifyMonoLifecycleHandler assertCompleted(T value) { + assertTerminated(); + assertThat(signalTypeReference.get()) + .describedAs("Expected ON_COMPLETE signal") + .isEqualTo(SignalType.ON_COMPLETE); + assertThat(errorReference.get()).describedAs("Expected error to be absent").isNull(); + assertThat(valueReference.get()).isEqualTo(value); + return this; + } + + public VerifyMonoLifecycleHandler assertErrored() { + assertTerminated(); + assertThat(signalTypeReference.get()) + .describedAs("Expected ON_ERROR signal") + .isEqualTo(SignalType.ON_ERROR); + assertThat(errorReference.get()).describedAs("Expected error to be present").isNotNull(); + assertThat(valueReference.get()).isNull(); + return this; + } + + public VerifyMonoLifecycleHandler assertCancelled() { + assertTerminated(); + assertThat(signalTypeReference.get()) + .describedAs("Expected ON_ERROR signal") + .isEqualTo(SignalType.CANCEL); + assertThat(errorReference.get()).describedAs("Expected error to be absent").isNull(); + assertThat(valueReference.get()).isNull(); + return this; + } + } + + @Test + public void testUnicast() throws InterruptedException { + warmup(Schedulers.single()); + + for (int i = 0; i < 10000; i++) { + VerifyMonoLifecycleHandler verifyMonoLifecycleHandler = + new VerifyMonoLifecycleHandler<>(); + UnicastMonoProcessor processor = + UnicastMonoProcessor.create(verifyMonoLifecycleHandler); + verifyMonoLifecycleHandler.assertNotSubscribed(); + assertThatThrownBy(() -> RaceTestUtils.race(processor::subscribe, processor::subscribe)) + .hasCause( + new IllegalStateException("UnicastMonoProcessor allows only a single Subscriber")); + verifyMonoLifecycleHandler.assertSubscribed().assertNotTerminated(); + } + } + + @Test + public void stateFlowTest1_Next() { + VerifyMonoLifecycleHandler verifyMonoLifecycleHandler = + new VerifyMonoLifecycleHandler<>(); + UnicastMonoProcessor processor = + UnicastMonoProcessor.create(verifyMonoLifecycleHandler); + AssertSubscriber assertSubscriber = AssertSubscriber.create(0); + + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_SUBSCRIBER_NO_RESULT); + + processor.onNext(1); + + verifyMonoLifecycleHandler.assertNotSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_SUBSCRIBER_HAS_RESULT); + + processor.subscribe(assertSubscriber); + + verifyMonoLifecycleHandler.assertSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_REQUEST_HAS_RESULT); + + assertSubscriber.assertNoEvents(); + assertSubscriber.request(1); + + verifyMonoLifecycleHandler.assertCompleted(1); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.HAS_REQUEST_HAS_RESULT); + + assertSubscriber.assertValues(1); + assertSubscriber.assertComplete(); + } + + @Test + public void stateFlowTest1_Complete() { + VerifyMonoLifecycleHandler verifyMonoLifecycleHandler = + new VerifyMonoLifecycleHandler<>(); + UnicastMonoProcessor processor = + UnicastMonoProcessor.create(verifyMonoLifecycleHandler); + AssertSubscriber assertSubscriber = AssertSubscriber.create(0); + + verifyMonoLifecycleHandler.assertNotSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_SUBSCRIBER_NO_RESULT); + + processor.onComplete(); + + verifyMonoLifecycleHandler.assertNotSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_SUBSCRIBER_HAS_RESULT); + + processor.subscribe(assertSubscriber); + + verifyMonoLifecycleHandler.assertSubscribed().assertCompleted(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.HAS_REQUEST_HAS_RESULT); + + assertSubscriber.assertNoValues(); + assertSubscriber.assertComplete(); + } + + @Test + public void stateFlowTest1_Error() { + VerifyMonoLifecycleHandler verifyMonoLifecycleHandler = + new VerifyMonoLifecycleHandler<>(); + UnicastMonoProcessor processor = + UnicastMonoProcessor.create(verifyMonoLifecycleHandler); + AssertSubscriber assertSubscriber = AssertSubscriber.create(0); + RuntimeException testError = new RuntimeException("test"); + + verifyMonoLifecycleHandler.assertNotSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_SUBSCRIBER_NO_RESULT); + + processor.onError(testError); + + verifyMonoLifecycleHandler.assertNotSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_SUBSCRIBER_HAS_RESULT); + + processor.subscribe(assertSubscriber); + + verifyMonoLifecycleHandler.assertSubscribed().assertErrored(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.HAS_REQUEST_HAS_RESULT); + + assertSubscriber.assertNoValues(); + assertSubscriber.assertError(RuntimeException.class); + assertSubscriber.assertErrorMessage("test"); + } + + @Test + public void stateFlowTest1_Dispose() { + VerifyMonoLifecycleHandler verifyMonoLifecycleHandler = + new VerifyMonoLifecycleHandler<>(); + UnicastMonoProcessor processor = + UnicastMonoProcessor.create(verifyMonoLifecycleHandler); + AssertSubscriber assertSubscriber = AssertSubscriber.create(0); + + verifyMonoLifecycleHandler.assertNotSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_SUBSCRIBER_NO_RESULT); + + processor.dispose(); + + verifyMonoLifecycleHandler.assertNotSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_SUBSCRIBER_HAS_RESULT); + + processor.subscribe(assertSubscriber); + + verifyMonoLifecycleHandler.assertSubscribed().assertErrored(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.HAS_REQUEST_HAS_RESULT); + + assertSubscriber.assertNoValues(); + assertSubscriber.assertError(CancellationException.class); + assertSubscriber.assertErrorMessage("Disposed"); + } + + @Test + public void stateFlowTest2_Next() { + VerifyMonoLifecycleHandler verifyMonoLifecycleHandler = + new VerifyMonoLifecycleHandler<>(); + UnicastMonoProcessor processor = + UnicastMonoProcessor.create(verifyMonoLifecycleHandler); + AssertSubscriber assertSubscriber = AssertSubscriber.create(0); + + verifyMonoLifecycleHandler.assertNotSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_SUBSCRIBER_NO_RESULT); + + processor.subscribe(assertSubscriber); + + verifyMonoLifecycleHandler.assertSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_REQUEST_NO_RESULT); + + processor.onNext(1); + + verifyMonoLifecycleHandler.assertSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_REQUEST_HAS_RESULT); + + assertSubscriber.assertNoEvents(); + assertSubscriber.request(1); + + verifyMonoLifecycleHandler.assertSubscribed().assertCompleted(1); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.HAS_REQUEST_HAS_RESULT); + + assertSubscriber.assertValues(1); + assertSubscriber.assertComplete(); + } + + @Test + public void stateFlowTest2_Complete() { + VerifyMonoLifecycleHandler verifyMonoLifecycleHandler = + new VerifyMonoLifecycleHandler<>(); + UnicastMonoProcessor processor = + UnicastMonoProcessor.create(verifyMonoLifecycleHandler); + AssertSubscriber assertSubscriber = AssertSubscriber.create(0); + + verifyMonoLifecycleHandler.assertNotSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_SUBSCRIBER_NO_RESULT); + + processor.subscribe(assertSubscriber); + + verifyMonoLifecycleHandler.assertSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_REQUEST_NO_RESULT); + + processor.onComplete(); + + verifyMonoLifecycleHandler.assertSubscribed().assertCompleted(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.HAS_REQUEST_HAS_RESULT); + + assertSubscriber.assertNoValues(); + assertSubscriber.assertComplete(); + } + + @Test + public void stateFlowTest2_Error() { + VerifyMonoLifecycleHandler verifyMonoLifecycleHandler = + new VerifyMonoLifecycleHandler<>(); + UnicastMonoProcessor processor = + UnicastMonoProcessor.create(verifyMonoLifecycleHandler); + AssertSubscriber assertSubscriber = AssertSubscriber.create(0); + + verifyMonoLifecycleHandler.assertNotSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_SUBSCRIBER_NO_RESULT); + + processor.subscribe(assertSubscriber); + + verifyMonoLifecycleHandler.assertSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_REQUEST_NO_RESULT); + + processor.onError(new RuntimeException("Test")); + + verifyMonoLifecycleHandler.assertSubscribed().assertErrored(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.HAS_REQUEST_HAS_RESULT); + + assertSubscriber.assertNoValues(); + assertSubscriber.assertError(RuntimeException.class); + assertSubscriber.assertErrorMessage("Test"); + } + + @Test + public void stateFlowTest2_Dispose() { + VerifyMonoLifecycleHandler verifyMonoLifecycleHandler = + new VerifyMonoLifecycleHandler<>(); + UnicastMonoProcessor processor = + UnicastMonoProcessor.create(verifyMonoLifecycleHandler); + AssertSubscriber assertSubscriber = AssertSubscriber.create(0); + + verifyMonoLifecycleHandler.assertNotSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_SUBSCRIBER_NO_RESULT); + + processor.subscribe(assertSubscriber); + + verifyMonoLifecycleHandler.assertSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_REQUEST_NO_RESULT); + + processor.dispose(); + + verifyMonoLifecycleHandler.assertSubscribed().assertErrored(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.HAS_REQUEST_HAS_RESULT); + + assertSubscriber.assertNoValues(); + assertSubscriber.assertError(CancellationException.class); + assertSubscriber.assertErrorMessage("Disposed"); + } + + @Test + public void stateFlowTest3_Next() { + VerifyMonoLifecycleHandler verifyMonoLifecycleHandler = + new VerifyMonoLifecycleHandler<>(); + UnicastMonoProcessor processor = + UnicastMonoProcessor.create(verifyMonoLifecycleHandler); + AssertSubscriber assertSubscriber = AssertSubscriber.create(0); + + verifyMonoLifecycleHandler.assertNotSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_SUBSCRIBER_NO_RESULT); + + processor.subscribe(assertSubscriber); + + verifyMonoLifecycleHandler.assertSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_REQUEST_NO_RESULT); + + assertSubscriber.assertNoEvents(); + assertSubscriber.request(1); + + verifyMonoLifecycleHandler.assertSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.HAS_REQUEST_NO_RESULT); + + assertSubscriber.assertNoEvents(); + processor.onNext(1); + + verifyMonoLifecycleHandler.assertSubscribed().assertCompleted(1); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.HAS_REQUEST_HAS_RESULT); + + assertSubscriber.assertValues(1); + assertSubscriber.assertComplete(); + } + + @Test + public void stateFlowTest3_Complete() { + VerifyMonoLifecycleHandler verifyMonoLifecycleHandler = + new VerifyMonoLifecycleHandler<>(); + UnicastMonoProcessor processor = + UnicastMonoProcessor.create(verifyMonoLifecycleHandler); + AssertSubscriber assertSubscriber = AssertSubscriber.create(0); + + verifyMonoLifecycleHandler.assertNotSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_SUBSCRIBER_NO_RESULT); + + processor.subscribe(assertSubscriber); + + verifyMonoLifecycleHandler.assertSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_REQUEST_NO_RESULT); + + assertSubscriber.request(1); + + verifyMonoLifecycleHandler.assertSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.HAS_REQUEST_NO_RESULT); + + processor.onComplete(); + + verifyMonoLifecycleHandler.assertSubscribed().assertCompleted(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.HAS_REQUEST_HAS_RESULT); + + assertSubscriber.assertNoValues(); + assertSubscriber.assertComplete(); + } + + @Test + public void stateFlowTest3_Error() { + VerifyMonoLifecycleHandler verifyMonoLifecycleHandler = + new VerifyMonoLifecycleHandler<>(); + UnicastMonoProcessor processor = + UnicastMonoProcessor.create(verifyMonoLifecycleHandler); + AssertSubscriber assertSubscriber = AssertSubscriber.create(0); + + verifyMonoLifecycleHandler.assertNotSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_SUBSCRIBER_NO_RESULT); + + processor.subscribe(assertSubscriber); + + verifyMonoLifecycleHandler.assertSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_REQUEST_NO_RESULT); + + assertSubscriber.request(1); + + verifyMonoLifecycleHandler.assertSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.HAS_REQUEST_NO_RESULT); + + processor.onError(new RuntimeException("Test")); + + verifyMonoLifecycleHandler.assertSubscribed().assertErrored(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.HAS_REQUEST_HAS_RESULT); + + assertSubscriber.assertNoValues(); + assertSubscriber.assertError(RuntimeException.class); + assertSubscriber.assertErrorMessage("Test"); + } + + @Test + public void stateFlowTest3_Dispose() { + VerifyMonoLifecycleHandler verifyMonoLifecycleHandler = + new VerifyMonoLifecycleHandler<>(); + UnicastMonoProcessor processor = + UnicastMonoProcessor.create(verifyMonoLifecycleHandler); + AssertSubscriber assertSubscriber = AssertSubscriber.create(0); + + verifyMonoLifecycleHandler.assertNotSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_SUBSCRIBER_NO_RESULT); + + processor.subscribe(assertSubscriber); + + verifyMonoLifecycleHandler.assertSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_REQUEST_NO_RESULT); + + assertSubscriber.request(1); + + verifyMonoLifecycleHandler.assertSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.HAS_REQUEST_NO_RESULT); + + processor.dispose(); + + verifyMonoLifecycleHandler.assertSubscribed().assertErrored(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.HAS_REQUEST_HAS_RESULT); + + assertSubscriber.assertNoValues(); + assertSubscriber.assertError(RuntimeException.class); + assertSubscriber.assertErrorMessage("Disposed"); + } + + @Test + public void stateFlowTest4_Next() { + ArrayList discarded = new ArrayList<>(); + VerifyMonoLifecycleHandler verifyMonoLifecycleHandler = + new VerifyMonoLifecycleHandler<>(); + UnicastMonoProcessor processor = + UnicastMonoProcessor.create(verifyMonoLifecycleHandler); + // Context discardingContext = Operators.enableOnDiscard(null, discarded::add); + Hooks.onNextDropped(discarded::add); + AssertSubscriber assertSubscriber = new AssertSubscriber<>(0); + + try { + verifyMonoLifecycleHandler.assertNotSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_SUBSCRIBER_NO_RESULT); + + processor.subscribe(assertSubscriber); + + verifyMonoLifecycleHandler.assertSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_REQUEST_NO_RESULT); + + assertSubscriber.request(1); + assertSubscriber.assertNoEvents(); + + verifyMonoLifecycleHandler.assertSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.HAS_REQUEST_NO_RESULT); + + assertSubscriber.cancel(); + assertSubscriber.assertNoEvents(); + + verifyMonoLifecycleHandler.assertSubscribed().assertCancelled(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.CANCELLED); + + processor.onNext(1); + + verifyMonoLifecycleHandler.assertSubscribed().assertCancelled(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.CANCELLED); + + assertSubscriber.assertNoEvents(); + assertThat(discarded).containsExactly(1); + } finally { + Hooks.resetOnNextDropped(); + } + } + + @Test + public void stateFlowTest4_Error() { + ArrayList discarded = new ArrayList<>(); + VerifyMonoLifecycleHandler verifyMonoLifecycleHandler = + new VerifyMonoLifecycleHandler<>(); + UnicastMonoProcessor processor = + UnicastMonoProcessor.create(verifyMonoLifecycleHandler); + // Context discardingContext = Operators.enableOnDiscard(null, discarded::add); + Hooks.onErrorDropped(discarded::add); + AssertSubscriber assertSubscriber = new AssertSubscriber<>(0); + + try { + + verifyMonoLifecycleHandler.assertNotSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_SUBSCRIBER_NO_RESULT); + + processor.subscribe(assertSubscriber); + + verifyMonoLifecycleHandler.assertSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_REQUEST_NO_RESULT); + + assertSubscriber.request(1); + assertSubscriber.assertNoEvents(); + + verifyMonoLifecycleHandler.assertSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.HAS_REQUEST_NO_RESULT); + + assertSubscriber.cancel(); + assertSubscriber.assertNoEvents(); + + verifyMonoLifecycleHandler.assertSubscribed().assertCancelled(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.CANCELLED); + + RuntimeException testError = new RuntimeException("test"); + processor.onError(testError); + + verifyMonoLifecycleHandler.assertSubscribed().assertCancelled(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.CANCELLED); + + assertSubscriber.assertNoEvents(); + assertThat(discarded).containsExactly(testError); + + } finally { + Hooks.resetOnErrorDropped(); + } + } + + @Test + public void stateFlowTest4_Dispose() { + ArrayList discarded = new ArrayList<>(); + VerifyMonoLifecycleHandler verifyMonoLifecycleHandler = + new VerifyMonoLifecycleHandler<>(); + UnicastMonoProcessor processor = + UnicastMonoProcessor.create(verifyMonoLifecycleHandler); + // Context discardingContext = Operators.enableOnDiscard(null, discarded::add); + Hooks.onErrorDropped(discarded::add); + try { + AssertSubscriber assertSubscriber = new AssertSubscriber<>(0); + + verifyMonoLifecycleHandler.assertNotSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_SUBSCRIBER_NO_RESULT); + + processor.subscribe(assertSubscriber); + + verifyMonoLifecycleHandler.assertSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_REQUEST_NO_RESULT); + + assertSubscriber.request(1); + assertSubscriber.assertNoEvents(); + + verifyMonoLifecycleHandler.assertSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.HAS_REQUEST_NO_RESULT); + + assertSubscriber.cancel(); + assertSubscriber.assertNoEvents(); + + verifyMonoLifecycleHandler.assertSubscribed().assertCancelled(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.CANCELLED); + + processor.dispose(); + + verifyMonoLifecycleHandler.assertSubscribed().assertCancelled(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.CANCELLED); + + assertSubscriber.assertNoEvents(); + assertThat(discarded).isEmpty(); + } finally { + Hooks.resetOnErrorDropped(); + } + } + + @Test + public void stateFlowTest4_Complete() { + ArrayList discarded = new ArrayList<>(); + VerifyMonoLifecycleHandler verifyMonoLifecycleHandler = + new VerifyMonoLifecycleHandler<>(); + UnicastMonoProcessor processor = + UnicastMonoProcessor.create(verifyMonoLifecycleHandler); + // Context discardingContext = Operators.enableOnDiscard(null, discarded::add); + Hooks.onErrorDropped(discarded::add); + AssertSubscriber assertSubscriber = new AssertSubscriber<>(0); + + try { + verifyMonoLifecycleHandler.assertNotSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_SUBSCRIBER_NO_RESULT); + + processor.subscribe(assertSubscriber); + + verifyMonoLifecycleHandler.assertSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_REQUEST_NO_RESULT); + + assertSubscriber.request(1); + assertSubscriber.assertNoEvents(); + + verifyMonoLifecycleHandler.assertSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.HAS_REQUEST_NO_RESULT); + + assertSubscriber.cancel(); + assertSubscriber.assertNoEvents(); + + verifyMonoLifecycleHandler.assertSubscribed().assertCancelled(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.CANCELLED); + + processor.onComplete(); + + verifyMonoLifecycleHandler.assertSubscribed().assertCancelled(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.CANCELLED); + + assertSubscriber.assertNoEvents(); + assertThat(discarded).isEmpty(); + } finally { + Hooks.resetOnErrorDropped(); + } + } + + @Test + public void stateFlowTest5_Next() { + ArrayList discarded = new ArrayList<>(); + VerifyMonoLifecycleHandler verifyMonoLifecycleHandler = + new VerifyMonoLifecycleHandler<>(); + UnicastMonoProcessor processor = + UnicastMonoProcessor.create(verifyMonoLifecycleHandler); + Context discardingContext = Operators.enableOnDiscard(null, discarded::add); + AssertSubscriber assertSubscriber = new AssertSubscriber<>(discardingContext, 0); + + verifyMonoLifecycleHandler.assertNotSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_SUBSCRIBER_NO_RESULT); + + processor.subscribe(assertSubscriber); + + verifyMonoLifecycleHandler.assertSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_REQUEST_NO_RESULT); + + processor.onNext(1); + + verifyMonoLifecycleHandler.assertSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_REQUEST_HAS_RESULT); + + assertSubscriber.cancel(); + + verifyMonoLifecycleHandler.assertSubscribed().assertCancelled(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.CANCELLED); + + assertSubscriber.request(1); + + verifyMonoLifecycleHandler.assertSubscribed().assertCancelled(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.CANCELLED); + + assertSubscriber.assertNoEvents(); + assertThat(discarded).containsExactly(1); + } + + @Test + public void stateFlowTest5_Complete() { + ArrayList discarded = new ArrayList<>(); + VerifyMonoLifecycleHandler verifyMonoLifecycleHandler = + new VerifyMonoLifecycleHandler<>(); + UnicastMonoProcessor processor = + UnicastMonoProcessor.create(verifyMonoLifecycleHandler); + Context discardingContext = Operators.enableOnDiscard(null, discarded::add); + AssertSubscriber assertSubscriber = new AssertSubscriber<>(discardingContext, 0); + + verifyMonoLifecycleHandler.assertNotSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_SUBSCRIBER_NO_RESULT); + + processor.subscribe(assertSubscriber); + + verifyMonoLifecycleHandler.assertSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_REQUEST_NO_RESULT); + + processor.onComplete(); + + verifyMonoLifecycleHandler.assertSubscribed().assertCompleted(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.HAS_REQUEST_HAS_RESULT); + + assertSubscriber.cancel(); + + verifyMonoLifecycleHandler.assertSubscribed().assertCompleted(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.CANCELLED); + + assertSubscriber.assertComplete(); + assertThat(discarded).isEmpty(); + } + + @Test + public void stateFlowTest5_Error() { + ArrayList discarded = new ArrayList<>(); + VerifyMonoLifecycleHandler verifyMonoLifecycleHandler = + new VerifyMonoLifecycleHandler<>(); + UnicastMonoProcessor processor = + UnicastMonoProcessor.create(verifyMonoLifecycleHandler); + Context discardingContext = Operators.enableOnDiscard(null, discarded::add); + Hooks.onErrorDropped(discarded::add); + AssertSubscriber assertSubscriber = new AssertSubscriber<>(discardingContext, 0); + + try { + verifyMonoLifecycleHandler.assertNotSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_SUBSCRIBER_NO_RESULT); + + processor.subscribe(assertSubscriber); + + verifyMonoLifecycleHandler.assertSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_REQUEST_NO_RESULT); + + processor.onError(new RuntimeException("test")); + + verifyMonoLifecycleHandler.assertSubscribed().assertErrored(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.HAS_REQUEST_HAS_RESULT); + + assertSubscriber.cancel(); + + verifyMonoLifecycleHandler.assertSubscribed().assertErrored(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.CANCELLED); + + assertSubscriber.assertError(RuntimeException.class); + assertSubscriber.assertErrorMessage("test"); + assertThat(discarded).isEmpty(); + } finally { + Hooks.resetOnErrorDropped(); + } + } + + @Test + public void stateFlowTest5_Dispose() { + ArrayList discarded = new ArrayList<>(); + VerifyMonoLifecycleHandler verifyMonoLifecycleHandler = + new VerifyMonoLifecycleHandler<>(); + UnicastMonoProcessor processor = + UnicastMonoProcessor.create(verifyMonoLifecycleHandler); + Context discardingContext = Operators.enableOnDiscard(null, discarded::add); + Hooks.onErrorDropped(discarded::add); + AssertSubscriber assertSubscriber = new AssertSubscriber<>(discardingContext, 0); + + try { + verifyMonoLifecycleHandler.assertNotSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_SUBSCRIBER_NO_RESULT); + + processor.subscribe(assertSubscriber); + + verifyMonoLifecycleHandler.assertSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_REQUEST_NO_RESULT); + + processor.dispose(); + + verifyMonoLifecycleHandler.assertSubscribed().assertErrored(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.HAS_REQUEST_HAS_RESULT); + + assertSubscriber.cancel(); + + verifyMonoLifecycleHandler.assertSubscribed().assertErrored(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.CANCELLED); + + assertSubscriber.assertError(CancellationException.class); + assertSubscriber.assertErrorMessage("Disposed"); + assertThat(discarded).isEmpty(); + } finally { + Hooks.resetOnErrorDropped(); + } + } + + @Test + public void stateFlowTest6_Next() { + ArrayList discarded = new ArrayList<>(); + VerifyMonoLifecycleHandler verifyMonoLifecycleHandler = + new VerifyMonoLifecycleHandler<>(); + UnicastMonoProcessor processor = + UnicastMonoProcessor.create(verifyMonoLifecycleHandler); + Context discardingContext = Operators.enableOnDiscard(null, discarded::add); + AssertSubscriber assertSubscriber = new AssertSubscriber<>(discardingContext, 0); + + verifyMonoLifecycleHandler.assertNotSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_SUBSCRIBER_NO_RESULT); + + processor.onNext(1); + + verifyMonoLifecycleHandler.assertNotSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_SUBSCRIBER_HAS_RESULT); + + processor.subscribe(assertSubscriber); + + verifyMonoLifecycleHandler.assertSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_REQUEST_HAS_RESULT); + + assertSubscriber.cancel(); + + verifyMonoLifecycleHandler.assertSubscribed().assertCancelled(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.CANCELLED); + + assertSubscriber.request(1); + + verifyMonoLifecycleHandler.assertSubscribed().assertCancelled(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.CANCELLED); + + assertSubscriber.assertNoEvents(); + assertThat(discarded).containsExactly(1); + } + + @Test + public void stateFlowTest7_Next() throws InterruptedException { + warmup(Schedulers.single()); + + for (int i = 0; i < 10000; i++) { + VerifyMonoLifecycleHandler verifyMonoLifecycleHandler = + new VerifyMonoLifecycleHandler<>(); + UnicastMonoProcessor processor = + UnicastMonoProcessor.create(verifyMonoLifecycleHandler); + AssertSubscriber assertSubscriber = new AssertSubscriber<>(); + + verifyMonoLifecycleHandler.assertNotSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_SUBSCRIBER_NO_RESULT); + + RaceTestUtils.race( + () -> processor.onNext(1), + () -> processor.subscribe(assertSubscriber), + Schedulers.single()); + + verifyMonoLifecycleHandler.assertSubscribed().assertCompleted(1); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.HAS_REQUEST_HAS_RESULT); + + assertSubscriber.assertValues(1); + assertSubscriber.assertComplete(); + } + } + + @Test + public void stateFlowTest7_Complete() throws InterruptedException { + warmup(Schedulers.single()); + + for (int i = 0; i < 10000; i++) { + UnicastMonoProcessor processor = UnicastMonoProcessor.create(); + AssertSubscriber assertSubscriber = new AssertSubscriber<>(0); + + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_SUBSCRIBER_NO_RESULT); + + RaceTestUtils.race( + processor::onComplete, () -> processor.subscribe(assertSubscriber), Schedulers.single()); + + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.HAS_REQUEST_HAS_RESULT); + + assertSubscriber.assertNoValues(); + assertSubscriber.assertComplete(); + } + } + + @Test + public void stateFlowTest7_Error() throws InterruptedException { + warmup(Schedulers.single()); + + for (int i = 0; i < 10000; i++) { + UnicastMonoProcessor processor = UnicastMonoProcessor.create(); + AssertSubscriber assertSubscriber = new AssertSubscriber<>(0); + + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_SUBSCRIBER_NO_RESULT); + + RaceTestUtils.race( + () -> processor.onError(new RuntimeException("test")), + () -> processor.subscribe(assertSubscriber), + Schedulers.single()); + + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.HAS_REQUEST_HAS_RESULT); + + assertSubscriber.assertNoValues(); + assertSubscriber.assertError(RuntimeException.class); + assertSubscriber.assertErrorMessage("test"); + } + } + + @Test + public void stateFlowTest7_Dispose() throws InterruptedException { + warmup(Schedulers.single()); + + for (int i = 0; i < 10000; i++) { + VerifyMonoLifecycleHandler verifyMonoLifecycleHandler = + new VerifyMonoLifecycleHandler<>(); + UnicastMonoProcessor processor = + UnicastMonoProcessor.create(verifyMonoLifecycleHandler); + AssertSubscriber assertSubscriber = new AssertSubscriber<>(0); + + verifyMonoLifecycleHandler.assertNotSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_SUBSCRIBER_NO_RESULT); + + RaceTestUtils.race( + processor::dispose, () -> processor.subscribe(assertSubscriber), Schedulers.single()); + + verifyMonoLifecycleHandler.assertSubscribed().assertErrored(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.HAS_REQUEST_HAS_RESULT); + + assertSubscriber.assertNoValues(); + assertSubscriber.assertError(CancellationException.class); + assertSubscriber.assertErrorMessage("Disposed"); + } + } + + @Test + public void stateFlowTest8_Next() throws InterruptedException { + warmup(Schedulers.single()); + + for (int i = 0; i < 10000; i++) { + VerifyMonoLifecycleHandler verifyMonoLifecycleHandler = + new VerifyMonoLifecycleHandler<>(); + UnicastMonoProcessor processor = + UnicastMonoProcessor.create(verifyMonoLifecycleHandler); + AssertSubscriber assertSubscriber = new AssertSubscriber<>(); + + verifyMonoLifecycleHandler.assertNotSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_SUBSCRIBER_NO_RESULT); + + processor.subscribe(assertSubscriber); + + verifyMonoLifecycleHandler.assertSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.HAS_REQUEST_NO_RESULT); + + RaceTestUtils.race(() -> processor.onNext(1), assertSubscriber::cancel, Schedulers.single()); + + verifyMonoLifecycleHandler.assertSubscribed().assertTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.CANCELLED); + + if (assertSubscriber.values().isEmpty()) { + verifyMonoLifecycleHandler.assertSubscribed().assertCancelled(); + assertSubscriber.assertNoEvents(); + } else { + verifyMonoLifecycleHandler.assertSubscribed().assertCompleted(1); + assertSubscriber.assertValues(1); + assertSubscriber.assertComplete(); + } + } + } + + @Test + public void stateFlowTest9_Next() throws InterruptedException { + warmup(Schedulers.single()); + + for (int i = 0; i < 10000; i++) { + VerifyMonoLifecycleHandler verifyMonoLifecycleHandler = + new VerifyMonoLifecycleHandler<>(); + UnicastMonoProcessor processor = + UnicastMonoProcessor.create(verifyMonoLifecycleHandler); + AssertSubscriber assertSubscriber = new AssertSubscriber<>(); + + verifyMonoLifecycleHandler.assertNotSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_SUBSCRIBER_NO_RESULT); + + processor.subscribe(assertSubscriber); + + verifyMonoLifecycleHandler.assertSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.HAS_REQUEST_NO_RESULT); + + RaceTestUtils.race(() -> processor.onNext(1), processor::dispose, Schedulers.single()); + + verifyMonoLifecycleHandler.assertSubscribed().assertTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.HAS_REQUEST_HAS_RESULT); + + if (processor.isError()) { + verifyMonoLifecycleHandler.assertSubscribed().assertErrored(); + assertSubscriber.assertNoValues(); + assertSubscriber.assertErrorMessage("Disposed"); + } else { + verifyMonoLifecycleHandler.assertSubscribed().assertCompleted(1); + assertSubscriber.assertValues(1); + assertSubscriber.assertComplete(); + } + } + } + + @Test + public void stateFlowTest13_Next() throws InterruptedException { + warmup(Schedulers.single()); + + for (int i = 0; i < 10000; i++) { + VerifyMonoLifecycleHandler verifyMonoLifecycleHandler = + new VerifyMonoLifecycleHandler<>(); + UnicastMonoProcessor processor = + UnicastMonoProcessor.create(verifyMonoLifecycleHandler); + AssertSubscriber assertSubscriber = new AssertSubscriber<>(0); + + verifyMonoLifecycleHandler.assertNotSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_SUBSCRIBER_NO_RESULT); + + processor.onNext(1); + processor.subscribe(assertSubscriber); + + verifyMonoLifecycleHandler.assertSubscribed().assertNotTerminated(); + + RaceTestUtils.race( + () -> assertSubscriber.request(1), + () -> assertSubscriber.request(1), + Schedulers.single()); + + verifyMonoLifecycleHandler.assertSubscribed().assertCompleted(1); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.HAS_REQUEST_HAS_RESULT); + + assertSubscriber.assertValues(1); + assertSubscriber.assertComplete(); + } + } + + @Test + public void stateFlowTest14_Next() throws InterruptedException { + warmup(Schedulers.single()); + + for (int i = 0; i < 10000; i++) { + VerifyMonoLifecycleHandler verifyMonoLifecycleHandler = + new VerifyMonoLifecycleHandler<>(); + UnicastMonoProcessor processor = + UnicastMonoProcessor.create(verifyMonoLifecycleHandler); + AssertSubscriber assertSubscriber = new AssertSubscriber<>(0); + + verifyMonoLifecycleHandler.assertNotSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_SUBSCRIBER_NO_RESULT); + + processor.subscribe(assertSubscriber); + assertSubscriber.request(1); + + verifyMonoLifecycleHandler.assertSubscribed().assertNotTerminated(); + RaceTestUtils.race(() -> processor.onNext(1), () -> processor.onNext(1), Schedulers.single()); + + verifyMonoLifecycleHandler.assertSubscribed().assertCompleted(1); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.HAS_REQUEST_HAS_RESULT); + + assertSubscriber.assertValues(1); + assertSubscriber.assertComplete(); + } + } + + @Test + public void stateFlowTest15_Next() throws InterruptedException { + warmup(Schedulers.single()); + + for (int i = 0; i < 10000; i++) { + VerifyMonoLifecycleHandler verifyMonoLifecycleHandler = + new VerifyMonoLifecycleHandler<>(); + UnicastMonoProcessor processor = + UnicastMonoProcessor.create(verifyMonoLifecycleHandler); + AssertSubscriber assertSubscriber = new AssertSubscriber<>(0); + + verifyMonoLifecycleHandler.assertNotSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_SUBSCRIBER_NO_RESULT); + + processor.subscribe(assertSubscriber); + processor.onNext(1); + + verifyMonoLifecycleHandler.assertSubscribed().assertNotTerminated(); + RaceTestUtils.race( + () -> assertSubscriber.request(1), assertSubscriber::cancel, Schedulers.single()); + + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.CANCELLED); + + if (assertSubscriber.values().isEmpty()) { + verifyMonoLifecycleHandler.assertSubscribed().assertCancelled(); + assertSubscriber.assertNoEvents(); + } else { + verifyMonoLifecycleHandler.assertSubscribed().assertCompleted(1); + assertSubscriber.assertValues(1); + assertSubscriber.assertComplete(); + } + } + } + + @Test + public void stateFlowTest16_Next() throws InterruptedException { + warmup(Schedulers.single()); + + for (int i = 0; i < 10000; i++) { + VerifyMonoLifecycleHandler verifyMonoLifecycleHandler = + new VerifyMonoLifecycleHandler<>(); + UnicastMonoProcessor processor = + UnicastMonoProcessor.create(verifyMonoLifecycleHandler); + AssertSubscriber assertSubscriber = new AssertSubscriber<>(0); + + verifyMonoLifecycleHandler.assertNotSubscribed().assertNotTerminated(); + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.NO_SUBSCRIBER_NO_RESULT); + + processor.subscribe(assertSubscriber); + processor.onNext(1); + + verifyMonoLifecycleHandler.assertSubscribed().assertNotTerminated(); + RaceTestUtils.race(assertSubscriber::cancel, assertSubscriber::cancel, Schedulers.single()); + + assertThat(processor.state).isEqualTo(UnicastMonoProcessor.CANCELLED); + + verifyMonoLifecycleHandler.assertSubscribed().assertCancelled(); + assertSubscriber.assertNoEvents(); + } + } + @Test public void noRetentionOnTermination() throws InterruptedException { Date date = new Date(); @@ -54,7 +1165,7 @@ public void noRetentionOnTermination() throws InterruptedException { Thread.sleep(100); } - assumeThat(refFuture.get()).isNull(); + assertThat(refFuture.get()).isNull(); assertThat(refDate.get()).isNull(); assertThat(cycles).isNotZero().isPositive(); } @@ -83,7 +1194,7 @@ public void noRetentionOnTerminationError() throws InterruptedException { Thread.sleep(100); } - assumeThat(refFuture.get()).isNull(); + assertThat(refFuture.get()).isNull(); assertThat(cycles).isNotZero().isPositive(); } @@ -114,7 +1225,7 @@ public void noRetentionOnTerminationCancel() throws InterruptedException { Thread.sleep(100); } - assumeThat(refFuture.get()).isNull(); + assertThat(refFuture.get()).isNull(); assertThat(cycles).isNotZero().isPositive(); } @@ -169,7 +1280,7 @@ public void MonoProcessorSuccessDoOnSuccessOrError() { mp.onNext("test"); assertThat(ref.get()).isEqualToIgnoringCase("test"); - assertThat(mp.isTerminated()).isTrue(); + assertThat(mp.isDisposed()).isTrue(); assertThat(mp.isError()).isFalse(); } @@ -182,7 +1293,7 @@ public void MonoProcessorSuccessDoOnTerminate() { mp.onNext("test"); assertThat(invoked.get()).isEqualTo(1); - assertThat(mp.isTerminated()).isTrue(); + assertThat(mp.isDisposed()).isTrue(); assertThat(mp.isError()).isFalse(); } @@ -195,7 +1306,7 @@ public void MonoProcessorSuccessSubscribeCallback() { mp.onNext("test"); assertThat(ref.get()).isEqualToIgnoringCase("test"); - assertThat(mp.isTerminated()).isTrue(); + assertThat(mp.isDisposed()).isTrue(); assertThat(mp.isError()).isFalse(); } @@ -227,7 +1338,7 @@ public void MonoProcessorSuccessDoOnSuccess() { mp.onNext("test"); assertThat(ref.get()).isEqualToIgnoringCase("test"); - assertThat(mp.isTerminated()).isTrue(); + assertThat(mp.isDisposed()).isTrue(); assertThat(mp.isError()).isFalse(); } @@ -240,7 +1351,7 @@ public void MonoProcessorSuccessChainTogether() { mp.onNext("test"); assertThat(mp2.peek()).isEqualToIgnoringCase("test"); - assertThat(mp.isTerminated()).isTrue(); + assertThat(mp.isDisposed()).isTrue(); assertThat(mp.isError()).isFalse(); } @@ -278,7 +1389,7 @@ public void MonoProcessorNullFulfill() { mp.onNext(null); - assertThat(mp.isTerminated()).isTrue(); + assertThat(mp.isDisposed()).isTrue(); assertThat(mp.peek()).isNull(); } @@ -290,10 +1401,13 @@ public void MonoProcessorMapFulfill() { UnicastMonoProcessor mp2 = mp.map(s -> s * 2).subscribeWith(UnicastMonoProcessor.create()); - mp2.subscribe(); + assertThat(mp2.isDisposed()).isTrue(); assertThat(mp2.isTerminated()).isTrue(); + assertThat(mp2.isCancelled()).isFalse(); assertThat(mp2.peek()).isEqualTo(2); + + mp2.subscribe(); } @Test @@ -304,10 +1418,13 @@ public void MonoProcessorThenFulfill() { UnicastMonoProcessor mp2 = mp.flatMap(s -> Mono.just(s * 2)).subscribeWith(UnicastMonoProcessor.create()); - mp2.subscribe(); + assertThat(mp2.isDisposed()).isTrue(); assertThat(mp2.isTerminated()).isTrue(); + assertThat(mp2.isCancelled()).isFalse(); assertThat(mp2.peek()).isEqualTo(2); + + mp2.subscribe(); } @Test @@ -328,7 +1445,7 @@ public void MonoProcessorMapError() { .thenRequest(1) .then( () -> { - assertThat(mp2.isTerminated()).isTrue(); + assertThat(mp2.isDisposed()).isTrue(); assertThat(mp2.getError()).hasMessage("test"); }) .verifyErrorMessage("test"); @@ -356,17 +1473,18 @@ public void zipMonoProcessor() { UnicastMonoProcessor mp2 = UnicastMonoProcessor.create(); UnicastMonoProcessor> mp3 = UnicastMonoProcessor.create(); - StepVerifier.create(Mono.zip(mp, mp2).subscribeWith(mp3)) - .then(() -> assertThat(mp3.isTerminated()).isFalse()) + StepVerifier.create(Mono.zip(mp, mp2).subscribeWith(mp3), 0) + .then(() -> assertThat(mp3.isDisposed()).isFalse()) .then(() -> mp.onNext(1)) - .then(() -> assertThat(mp3.isTerminated()).isFalse()) + .then(() -> assertThat(mp3.isDisposed()).isFalse()) .then(() -> mp2.onNext(2)) .then( () -> { - assertThat(mp3.isTerminated()).isTrue(); + assertThat(mp3.isDisposed()).isTrue(); assertThat(mp3.peek().getT1()).isEqualTo(1); assertThat(mp3.peek().getT2()).isEqualTo(2); }) + .thenRequest(1) .expectNextMatches(t -> t.getT1() == 1 && t.getT2() == 2) .verifyComplete(); } @@ -376,14 +1494,15 @@ public void zipMonoProcessor2() { UnicastMonoProcessor mp = UnicastMonoProcessor.create(); UnicastMonoProcessor mp3 = UnicastMonoProcessor.create(); - StepVerifier.create(Mono.zip(d -> (Integer) d[0], mp).subscribeWith(mp3)) - .then(() -> assertThat(mp3.isTerminated()).isFalse()) + StepVerifier.create(Mono.zip(d -> (Integer) d[0], mp).subscribeWith(mp3), 0) + .then(() -> assertThat(mp3.isDisposed()).isFalse()) .then(() -> mp.onNext(1)) .then( () -> { - assertThat(mp3.isTerminated()).isTrue(); + assertThat(mp3.isDisposed()).isTrue(); assertThat(mp3.peek()).isEqualTo(1); }) + .thenRequest(1) .expectNext(1) .verifyComplete(); } @@ -395,11 +1514,11 @@ public void zipMonoProcessorRejected() { UnicastMonoProcessor> mp3 = UnicastMonoProcessor.create(); StepVerifier.create(Mono.zip(mp, mp2).subscribeWith(mp3)) - .then(() -> assertThat(mp3.isTerminated()).isFalse()) + .then(() -> assertThat(mp3.isDisposed()).isFalse()) .then(() -> mp.onError(new Exception("test"))) .then( () -> { - assertThat(mp3.isTerminated()).isTrue(); + assertThat(mp3.isDisposed()).isTrue(); assertThat(mp3.getError()).hasMessage("test"); }) .verifyErrorMessage("test"); @@ -409,12 +1528,13 @@ public void zipMonoProcessorRejected() { public void filterMonoProcessor() { UnicastMonoProcessor mp = UnicastMonoProcessor.create(); UnicastMonoProcessor mp2 = UnicastMonoProcessor.create(); - StepVerifier.create(mp.filter(s -> s % 2 == 0).subscribeWith(mp2)) + StepVerifier.create(mp.filter(s -> s % 2 == 0).subscribeWith(mp2), 0) .then(() -> mp.onNext(2)) .then(() -> assertThat(mp2.isError()).isFalse()) - .then(() -> assertThat(mp2.isTerminated()).isTrue()) + .then(() -> assertThat(mp2.isDisposed()).isTrue()) .then(() -> assertThat(mp2.peek()).isEqualTo(2)) - .then(() -> assertThat(mp2.isTerminated()).isTrue()) + .then(() -> assertThat(mp2.isDisposed()).isTrue()) + .thenRequest(1) .expectNext(2) .verifyComplete(); } @@ -426,9 +1546,9 @@ public void filterMonoProcessorNot() { StepVerifier.create(mp.filter(s -> s % 2 == 0).subscribeWith(mp2)) .then(() -> mp.onNext(1)) .then(() -> assertThat(mp2.isError()).isFalse()) - .then(() -> assertThat(mp2.isTerminated()).isTrue()) + .then(() -> assertThat(mp2.isDisposed()).isTrue()) .then(() -> assertThat(mp2.peek()).isNull()) - .then(() -> assertThat(mp2.isTerminated()).isTrue()) + .then(() -> assertThat(mp2.isDisposed()).isTrue()) .verifyComplete(); } @@ -444,9 +1564,9 @@ public void filterMonoProcessorError() { .subscribeWith(mp2)) .then(() -> mp.onNext(2)) .then(() -> assertThat(mp2.isError()).isTrue()) - .then(() -> assertThat(mp2.isTerminated()).isTrue()) + .then(() -> assertThat(mp2.isDisposed()).isTrue()) .then(() -> assertThat(mp2.getError()).hasMessage("test")) - .then(() -> assertThat(mp2.isTerminated()).isTrue()) + .then(() -> assertThat(mp2.isDisposed()).isTrue()) .verifyErrorMessage("test"); } @@ -466,9 +1586,9 @@ public void doOnSuccessMonoProcessorError() { .then(() -> mp.onNext(2)) .then(() -> assertThat(mp2.isError()).isTrue()) .then(() -> assertThat(ref.get()).hasMessage("test")) - .then(() -> assertThat(mp2.isTerminated()).isTrue()) + .then(() -> assertThat(mp2.isDisposed()).isTrue()) .then(() -> assertThat(mp2.getError()).hasMessage("test")) - .then(() -> assertThat(mp2.isTerminated()).isTrue()) + .then(() -> assertThat(mp2.isDisposed()).isTrue()) .verifyErrorMessage("test"); } @@ -544,7 +1664,6 @@ public void scanProcessorError() { test.onError(new IllegalStateException("boom")); assertThat(test.scan(Scannable.Attr.ERROR)).hasMessage("boom"); - assertThat(test.scan(Scannable.Attr.TERMINATED)).isTrue(); } @Test diff --git a/rsocket-core/src/test/java/io/rsocket/internal/subscriber/AssertSubscriber.java b/rsocket-core/src/test/java/io/rsocket/internal/subscriber/AssertSubscriber.java new file mode 100644 index 000000000..84a589a8d --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/internal/subscriber/AssertSubscriber.java @@ -0,0 +1,1154 @@ +/* + * Copyright (c) 2011-2017 Pivotal Software Inc, All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.internal.subscriber; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.function.BooleanSupplier; +import java.util.function.Consumer; +import java.util.function.Supplier; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Fuseable; +import reactor.core.publisher.Operators; +import reactor.util.annotation.NonNull; +import reactor.util.context.Context; + +/** + * A Subscriber implementation that hosts assertion tests for its state and allows asynchronous + * cancellation and requesting. + * + *

To create a new instance of {@link AssertSubscriber}, you have the choice between these static + * methods: + * + *

    + *
  • {@link AssertSubscriber#create()}: create a new {@link AssertSubscriber} and requests an + * unbounded number of elements. + *
  • {@link AssertSubscriber#create(long)}: create a new {@link AssertSubscriber} and requests + * {@code n} elements (can be 0 if you want no initial demand). + *
+ * + *

If you are testing asynchronous publishers, don't forget to use one of the {@code await*()} + * methods to wait for the data to assert. + * + *

You can extend this class but only the onNext, onError and onComplete can be overridden. You + * can call {@link #request(long)} and {@link #cancel()} from any thread or from within the + * overridable methods but you should avoid calling the assertXXX methods asynchronously. + * + *

Usage: + * + *

{@code
+ * AssertSubscriber
+ *   .subscribe(publisher)
+ *   .await()
+ *   .assertValues("ABC", "DEF");
+ * }
+ * + * @param the value type. + * @author Sebastien Deleuze + * @author David Karnok + * @author Anatoly Kadyshev + * @author Stephane Maldini + * @author Brian Clozel + */ +public class AssertSubscriber implements CoreSubscriber, Subscription { + + /** Default timeout for waiting next values to be received */ + public static final Duration DEFAULT_VALUES_TIMEOUT = Duration.ofSeconds(3); + + @SuppressWarnings("rawtypes") + private static final AtomicLongFieldUpdater REQUESTED = + AtomicLongFieldUpdater.newUpdater(AssertSubscriber.class, "requested"); + + @SuppressWarnings("rawtypes") + private static final AtomicReferenceFieldUpdater NEXT_VALUES = + AtomicReferenceFieldUpdater.newUpdater(AssertSubscriber.class, List.class, "values"); + + @SuppressWarnings("rawtypes") + private static final AtomicReferenceFieldUpdater S = + AtomicReferenceFieldUpdater.newUpdater(AssertSubscriber.class, Subscription.class, "s"); + + private final Context context; + + private final List errors = new LinkedList<>(); + + private final CountDownLatch cdl = new CountDownLatch(1); + + volatile Subscription s; + + volatile long requested; + + volatile List values = new LinkedList<>(); + + /** The fusion mode to request. */ + private int requestedFusionMode = -1; + + /** The established fusion mode. */ + private volatile int establishedFusionMode = -1; + + /** The fuseable QueueSubscription in case a fusion mode was specified. */ + private Fuseable.QueueSubscription qs; + + private int subscriptionCount = 0; + + private int completionCount = 0; + + private volatile long valueCount = 0L; + + private volatile long nextValueAssertedCount = 0L; + + private Duration valuesTimeout = DEFAULT_VALUES_TIMEOUT; + + private boolean valuesStorage = true; + + // + // ============================================================================================================== + // Static methods + // + // ============================================================================================================== + + /** + * Blocking method that waits until {@code conditionSupplier} returns true, or if it does not + * before the specified timeout, throws an {@link AssertionError} with the specified error message + * supplier. + * + * @param timeout the timeout duration + * @param errorMessageSupplier the error message supplier + * @param conditionSupplier condition to break out of the wait loop + * @throws AssertionError + */ + public static void await( + Duration timeout, Supplier errorMessageSupplier, BooleanSupplier conditionSupplier) { + + Objects.requireNonNull(errorMessageSupplier); + Objects.requireNonNull(conditionSupplier); + Objects.requireNonNull(timeout); + + long timeoutNs = timeout.toNanos(); + long startTime = System.nanoTime(); + do { + if (conditionSupplier.getAsBoolean()) { + return; + } + try { + Thread.sleep(100); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + } while (System.nanoTime() - startTime < timeoutNs); + throw new AssertionError(errorMessageSupplier.get()); + } + + /** + * Blocking method that waits until {@code conditionSupplier} returns true, or if it does not + * before the specified timeout, throw an {@link AssertionError} with the specified error message. + * + * @param timeout the timeout duration + * @param errorMessage the error message + * @param conditionSupplier condition to break out of the wait loop + * @throws AssertionError + */ + public static void await( + Duration timeout, final String errorMessage, BooleanSupplier conditionSupplier) { + await( + timeout, + new Supplier() { + @Override + public String get() { + return errorMessage; + } + }, + conditionSupplier); + } + + /** + * Create a new {@link AssertSubscriber} that requests an unbounded number of elements. + * + *

Be sure at least a publisher has subscribed to it via {@link + * Publisher#subscribe(Subscriber)} before use assert methods. + * + * @param the observed value type + * @return a fresh AssertSubscriber instance + */ + public static AssertSubscriber create() { + return new AssertSubscriber<>(); + } + + /** + * Create a new {@link AssertSubscriber} that requests initially {@code n} elements. You can then + * manage the demand with {@link Subscription#request(long)}. + * + *

Be sure at least a publisher has subscribed to it via {@link + * Publisher#subscribe(Subscriber)} before use assert methods. + * + * @param n Number of elements to request (can be 0 if you want no initial demand). + * @param the observed value type + * @return a fresh AssertSubscriber instance + */ + public static AssertSubscriber create(long n) { + return new AssertSubscriber<>(n); + } + + // + // ============================================================================================================== + // constructors + // + // ============================================================================================================== + + public AssertSubscriber() { + this(Context.empty(), Long.MAX_VALUE); + } + + public AssertSubscriber(long n) { + this(Context.empty(), n); + } + + public AssertSubscriber(Context context) { + this(context, Long.MAX_VALUE); + } + + public AssertSubscriber(Context context, long n) { + if (n < 0) { + throw new IllegalArgumentException("initialRequest >= required but it was " + n); + } + this.context = context; + REQUESTED.lazySet(this, n); + } + + // + // ============================================================================================================== + // Configuration + // + // ============================================================================================================== + + /** + * Enable or disabled the values storage. It is enabled by default, and can be disable in order to + * be able to perform performance benchmarks or tests with a huge amount values. + * + * @param enabled enable value storage? + * @return this + */ + public final AssertSubscriber configureValuesStorage(boolean enabled) { + this.valuesStorage = enabled; + return this; + } + + /** + * Configure the timeout in seconds for waiting next values to be received (3 seconds by default). + * + * @param timeout the new default value timeout duration + * @return this + */ + public final AssertSubscriber configureValuesTimeout(Duration timeout) { + this.valuesTimeout = timeout; + return this; + } + + /** + * Returns the established fusion mode or -1 if it was not enabled + * + * @return the fusion mode, see Fuseable constants + */ + public final int establishedFusionMode() { + return establishedFusionMode; + } + + // + // ============================================================================================================== + // Assertions + // + // ============================================================================================================== + + /** + * Assert a complete successfully signal has been received. + * + * @return this + */ + public final AssertSubscriber assertComplete() { + assertNoError(); + int c = completionCount; + if (c == 0) { + throw new AssertionError("Not completed", null); + } + if (c > 1) { + throw new AssertionError("Multiple completions: " + c, null); + } + return this; + } + + /** + * Assert the specified values have been received. Values storage should be enabled to use this + * method. + * + * @param expectedValues the values to assert + * @see #configureValuesStorage(boolean) + * @return this + */ + public final AssertSubscriber assertContainValues(Set expectedValues) { + if (!valuesStorage) { + throw new IllegalStateException("Using assertNoValues() requires enabling values storage"); + } + if (expectedValues.size() > values.size()) { + throw new AssertionError("Actual contains fewer elements" + values, null); + } + + Iterator expected = expectedValues.iterator(); + + for (; ; ) { + boolean n2 = expected.hasNext(); + if (n2) { + T t2 = expected.next(); + if (!values.contains(t2)) { + throw new AssertionError( + "The element is not contained in the " + + "received results" + + " = " + + valueAndClass(t2), + null); + } + } else { + break; + } + } + return this; + } + + /** + * Assert an error signal has been received. + * + * @return this + */ + public final AssertSubscriber assertError() { + assertNotComplete(); + int s = errors.size(); + if (s == 0) { + throw new AssertionError("No error", null); + } + if (s > 1) { + throw new AssertionError("Multiple errors: " + s, null); + } + return this; + } + + /** + * Assert an error signal has been received. + * + * @param clazz The class of the exception contained in the error signal + * @return this + */ + public final AssertSubscriber assertError(Class clazz) { + assertNotComplete(); + int s = errors.size(); + if (s == 0) { + throw new AssertionError("No error", null); + } + if (s == 1) { + Throwable e = errors.get(0); + if (!clazz.isInstance(e)) { + throw new AssertionError( + "Error class incompatible: expected = " + clazz + ", actual = " + e, null); + } + } + if (s > 1) { + throw new AssertionError("Multiple errors: " + s, null); + } + return this; + } + + public final AssertSubscriber assertErrorMessage(String message) { + assertNotComplete(); + int s = errors.size(); + if (s == 0) { + assertionError("No error", null); + } + if (s == 1) { + if (!Objects.equals(message, errors.get(0).getMessage())) { + assertionError( + "Error class incompatible: expected = \"" + + message + + "\", actual = \"" + + errors.get(0).getMessage() + + "\"", + null); + } + } + if (s > 1) { + assertionError("Multiple errors: " + s, null); + } + + return this; + } + + /** + * Assert an error signal has been received. + * + * @param expectation A method that can verify the exception contained in the error signal and + * throw an exception (like an {@link AssertionError}) if the exception is not valid. + * @return this + */ + public final AssertSubscriber assertErrorWith(Consumer expectation) { + assertNotComplete(); + int s = errors.size(); + if (s == 0) { + throw new AssertionError("No error", null); + } + if (s == 1) { + expectation.accept(errors.get(0)); + } + if (s > 1) { + throw new AssertionError("Multiple errors: " + s, null); + } + return this; + } + + /** + * Assert that the upstream was a Fuseable source. + * + * @return this + */ + public final AssertSubscriber assertFuseableSource() { + if (qs == null) { + throw new AssertionError("Upstream was not Fuseable"); + } + return this; + } + + /** + * Assert that the fusion mode was granted. + * + * @return this + */ + public final AssertSubscriber assertFusionEnabled() { + if (establishedFusionMode != Fuseable.SYNC && establishedFusionMode != Fuseable.ASYNC) { + throw new AssertionError("Fusion was not enabled"); + } + return this; + } + + public final AssertSubscriber assertFusionMode(int expectedMode) { + if (establishedFusionMode != expectedMode) { + throw new AssertionError( + "Wrong fusion mode: expected: " + + fusionModeName(expectedMode) + + ", actual: " + + fusionModeName(establishedFusionMode)); + } + return this; + } + + /** + * Assert that the fusion mode was granted. + * + * @return this + */ + public final AssertSubscriber assertFusionRejected() { + if (establishedFusionMode != Fuseable.NONE) { + throw new AssertionError("Fusion was granted"); + } + return this; + } + + /** + * Assert no error signal has been received. + * + * @return this + */ + public final AssertSubscriber assertNoError() { + int s = errors.size(); + if (s == 1) { + Throwable e = errors.get(0); + String valueAndClass = e == null ? null : e + " (" + e.getClass().getSimpleName() + ")"; + throw new AssertionError("Error present: " + valueAndClass, null); + } + if (s > 1) { + throw new AssertionError("Multiple errors: " + s, null); + } + return this; + } + + /** + * Assert no values have been received. + * + * @return this + */ + public final AssertSubscriber assertNoValues() { + if (valueCount != 0) { + throw new AssertionError( + "No values expected but received: [length = " + values.size() + "] " + values, null); + } + return this; + } + + /** + * Assert that the upstream was not a Fuseable source. + * + * @return this + */ + public final AssertSubscriber assertNonFuseableSource() { + if (qs != null) { + throw new AssertionError("Upstream was Fuseable"); + } + return this; + } + + /** + * Assert no complete successfully signal has been received. + * + * @return this + */ + public final AssertSubscriber assertNotComplete() { + int c = completionCount; + if (c == 1) { + throw new AssertionError("Completed", null); + } + if (c > 1) { + throw new AssertionError("Multiple completions: " + c, null); + } + return this; + } + + /** + * Assert no subscription occurred. + * + * @return this + */ + public final AssertSubscriber assertNotSubscribed() { + int s = subscriptionCount; + + if (s == 1) { + throw new AssertionError("OnSubscribe called once", null); + } + if (s > 1) { + throw new AssertionError("OnSubscribe called multiple times: " + s, null); + } + + return this; + } + + /** + * Assert no complete successfully or error signal has been received. + * + * @return this + */ + public final AssertSubscriber assertNotTerminated() { + if (cdl.getCount() == 0) { + throw new AssertionError("Terminated", null); + } + return this; + } + + /** + * Assert subscription occurred (once). + * + * @return this + */ + public final AssertSubscriber assertSubscribed() { + int s = subscriptionCount; + + if (s == 0) { + throw new AssertionError("OnSubscribe not called", null); + } + if (s > 1) { + throw new AssertionError("OnSubscribe called multiple times: " + s, null); + } + + return this; + } + + /** + * Assert either complete successfully or error signal has been received. + * + * @return this + */ + public final AssertSubscriber assertTerminated() { + if (cdl.getCount() != 0) { + throw new AssertionError("Not terminated", null); + } + return this; + } + + /** + * Assert {@code n} values has been received. + * + * @param n the expected value count + * @return this + */ + public final AssertSubscriber assertValueCount(long n) { + if (valueCount != n) { + throw new AssertionError( + "Different value count: expected = " + n + ", actual = " + valueCount, null); + } + return this; + } + + /** + * Assert the specified values have been received in the same order read by the passed {@link + * Iterable}. Values storage should be enabled to use this method. + * + * @param expectedSequence the values to assert + * @see #configureValuesStorage(boolean) + * @return this + */ + public final AssertSubscriber assertValueSequence(Iterable expectedSequence) { + if (!valuesStorage) { + throw new IllegalStateException("Using assertNoValues() requires enabling values storage"); + } + Iterator actual = values.iterator(); + Iterator expected = expectedSequence.iterator(); + int i = 0; + for (; ; ) { + boolean n1 = actual.hasNext(); + boolean n2 = expected.hasNext(); + if (n1 && n2) { + T t1 = actual.next(); + T t2 = expected.next(); + if (!Objects.equals(t1, t2)) { + throw new AssertionError( + "The element with index " + + i + + " does not match: expected = " + + valueAndClass(t2) + + ", actual = " + + valueAndClass(t1), + null); + } + i++; + } else if (n1 && !n2) { + throw new AssertionError("Actual contains more elements" + values, null); + } else if (!n1 && n2) { + throw new AssertionError("Actual contains fewer elements: " + values, null); + } else { + break; + } + } + return this; + } + + /** + * Assert the specified values have been received in the declared order. Values storage should be + * enabled to use this method. + * + * @param expectedValues the values to assert + * @return this + * @see #configureValuesStorage(boolean) + */ + @SafeVarargs + public final AssertSubscriber assertValues(T... expectedValues) { + return assertValueSequence(Arrays.asList(expectedValues)); + } + + /** + * Assert the specified values have been received in the declared order. Values storage should be + * enabled to use this method. + * + * @param expectations One or more methods that can verify the values and throw a exception (like + * an {@link AssertionError}) if the value is not valid. + * @return this + * @see #configureValuesStorage(boolean) + */ + @SafeVarargs + public final AssertSubscriber assertValuesWith(Consumer... expectations) { + if (!valuesStorage) { + throw new IllegalStateException("Using assertNoValues() requires enabling values storage"); + } + final int expectedValueCount = expectations.length; + if (expectedValueCount != values.size()) { + throw new AssertionError( + "Different value count: expected = " + expectedValueCount + ", actual = " + valueCount, + null); + } + for (int i = 0; i < expectedValueCount; i++) { + Consumer consumer = expectations[i]; + T actualValue = values.get(i); + consumer.accept(actualValue); + } + return this; + } + + // + // ============================================================================================================== + // Await methods + // + // ============================================================================================================== + + /** + * Blocking method that waits until a complete successfully or error signal is received. + * + * @return this + */ + public final AssertSubscriber await() { + if (cdl.getCount() == 0) { + return this; + } + try { + cdl.await(); + } catch (InterruptedException ex) { + throw new AssertionError("Wait interrupted", ex); + } + return this; + } + + /** + * Blocking method that waits until a complete successfully or error signal is received or until a + * timeout occurs. + * + * @param timeout The timeout value + * @return this + */ + public final AssertSubscriber await(Duration timeout) { + if (cdl.getCount() == 0) { + return this; + } + try { + if (!cdl.await(timeout.toMillis(), TimeUnit.MILLISECONDS)) { + throw new AssertionError("No complete or error signal before timeout"); + } + return this; + } catch (InterruptedException ex) { + throw new AssertionError("Wait interrupted", ex); + } + } + + /** + * Blocking method that waits until {@code n} next values have been received. + * + * @param n the value count to assert + * @return this + */ + public final AssertSubscriber awaitAndAssertNextValueCount(final long n) { + await( + valuesTimeout, + () -> { + if (valuesStorage) { + return String.format( + "%d out of %d next values received within %d, " + "values : %s", + valueCount - nextValueAssertedCount, + n, + valuesTimeout.toMillis(), + values.toString()); + } + return String.format( + "%d out of %d next values received within %d", + valueCount - nextValueAssertedCount, n, valuesTimeout.toMillis()); + }, + () -> valueCount >= (nextValueAssertedCount + n)); + nextValueAssertedCount += n; + return this; + } + + /** + * Blocking method that waits until {@code n} next values have been received (n is the number of + * values provided) to assert them. + * + * @param values the values to assert + * @return this + */ + @SafeVarargs + @SuppressWarnings("unchecked") + public final AssertSubscriber awaitAndAssertNextValues(T... values) { + final int expectedNum = values.length; + final List> expectations = new ArrayList<>(); + for (int i = 0; i < expectedNum; i++) { + final T expectedValue = values[i]; + expectations.add( + actualValue -> { + if (!actualValue.equals(expectedValue)) { + throw new AssertionError( + String.format( + "Expected Next signal: %s, but got: %s", expectedValue, actualValue)); + } + }); + } + awaitAndAssertNextValuesWith(expectations.toArray((Consumer[]) new Consumer[0])); + return this; + } + + /** + * Blocking method that waits until {@code n} next values have been received (n is the number of + * expectations provided) to assert them. + * + * @param expectations One or more methods that can verify the values and throw a exception (like + * an {@link AssertionError}) if the value is not valid. + * @return this + */ + @SafeVarargs + public final AssertSubscriber awaitAndAssertNextValuesWith(Consumer... expectations) { + valuesStorage = true; + final int expectedValueCount = expectations.length; + await( + valuesTimeout, + () -> { + if (valuesStorage) { + return String.format( + "%d out of %d next values received within %d, " + "values : %s", + valueCount - nextValueAssertedCount, + expectedValueCount, + valuesTimeout.toMillis(), + values.toString()); + } + return String.format( + "%d out of %d next values received within %d ms", + valueCount - nextValueAssertedCount, expectedValueCount, valuesTimeout.toMillis()); + }, + () -> valueCount >= (nextValueAssertedCount + expectedValueCount)); + List nextValuesSnapshot; + List empty = new ArrayList<>(); + for (; ; ) { + nextValuesSnapshot = values; + if (NEXT_VALUES.compareAndSet(this, values, empty)) { + break; + } + } + if (nextValuesSnapshot.size() < expectedValueCount) { + throw new AssertionError( + String.format( + "Expected %d number of signals but received %d", + expectedValueCount, nextValuesSnapshot.size())); + } + for (int i = 0; i < expectedValueCount; i++) { + Consumer consumer = expectations[i]; + T actualValue = nextValuesSnapshot.get(i); + consumer.accept(actualValue); + } + nextValueAssertedCount += expectedValueCount; + return this; + } + + // + // ============================================================================================================== + // Overrides + // + // ============================================================================================================== + + @Override + public void cancel() { + Subscription a = s; + if (a != Operators.cancelledSubscription()) { + a = S.getAndSet(this, Operators.cancelledSubscription()); + if (a != null && a != Operators.cancelledSubscription()) { + a.cancel(); + } + } + } + + final boolean isCancelled() { + return s == Operators.cancelledSubscription(); + } + + public final boolean isTerminated() { + return cdl.getCount() == 0; + } + + @Override + public void onComplete() { + completionCount++; + cdl.countDown(); + } + + @Override + public void onError(Throwable t) { + errors.add(t); + cdl.countDown(); + } + + @Override + public void onNext(T t) { + if (establishedFusionMode == Fuseable.ASYNC) { + for (; ; ) { + t = qs.poll(); + if (t == null) { + break; + } + valueCount++; + if (valuesStorage) { + List nextValuesSnapshot; + for (; ; ) { + nextValuesSnapshot = values; + nextValuesSnapshot.add(t); + if (NEXT_VALUES.compareAndSet(this, nextValuesSnapshot, nextValuesSnapshot)) { + break; + } + } + } + } + } else { + valueCount++; + if (valuesStorage) { + List nextValuesSnapshot; + for (; ; ) { + nextValuesSnapshot = values; + nextValuesSnapshot.add(t); + if (NEXT_VALUES.compareAndSet(this, nextValuesSnapshot, nextValuesSnapshot)) { + break; + } + } + } + } + } + + @Override + @SuppressWarnings("unchecked") + public void onSubscribe(Subscription s) { + subscriptionCount++; + int requestMode = requestedFusionMode; + if (requestMode >= 0) { + if (!setWithoutRequesting(s)) { + if (!isCancelled()) { + errors.add(new IllegalStateException("Subscription already set: " + subscriptionCount)); + } + } else { + if (s instanceof Fuseable.QueueSubscription) { + this.qs = (Fuseable.QueueSubscription) s; + + int m = qs.requestFusion(requestMode); + establishedFusionMode = m; + + if (m == Fuseable.SYNC) { + for (; ; ) { + T v = qs.poll(); + if (v == null) { + onComplete(); + break; + } + + onNext(v); + } + } else { + requestDeferred(); + } + } else { + requestDeferred(); + } + } + } else { + if (!set(s)) { + if (!isCancelled()) { + errors.add(new IllegalStateException("Subscription already set: " + subscriptionCount)); + } + } + } + } + + @Override + public void request(long n) { + if (Operators.validate(n)) { + if (establishedFusionMode != Fuseable.SYNC) { + normalRequest(n); + } + } + } + + @Override + @NonNull + public Context currentContext() { + return context; + } + + /** + * Setup what fusion mode should be requested from the incoming Subscription if it happens to be + * QueueSubscription + * + * @param requestMode the mode to request, see Fuseable constants + * @return this + */ + public final AssertSubscriber requestedFusionMode(int requestMode) { + this.requestedFusionMode = requestMode; + return this; + } + + public Subscription upstream() { + return s; + } + + // + // ============================================================================================================== + // Non public methods + // + // ============================================================================================================== + + protected final void normalRequest(long n) { + Subscription a = s; + if (a != null) { + a.request(n); + } else { + Operators.addCap(REQUESTED, this, n); + + a = s; + + if (a != null) { + long r = REQUESTED.getAndSet(this, 0L); + + if (r != 0L) { + a.request(r); + } + } + } + } + + /** Requests the deferred amount if not zero. */ + protected final void requestDeferred() { + long r = REQUESTED.getAndSet(this, 0L); + + if (r != 0L) { + s.request(r); + } + } + + /** + * Atomically sets the single subscription and requests the missed amount from it. + * + * @param s + * @return false if this arbiter is cancelled or there was a subscription already set + */ + protected final boolean set(Subscription s) { + Objects.requireNonNull(s, "s"); + Subscription a = this.s; + if (a == Operators.cancelledSubscription()) { + s.cancel(); + return false; + } + if (a != null) { + s.cancel(); + Operators.reportSubscriptionSet(); + return false; + } + + if (S.compareAndSet(this, null, s)) { + + long r = REQUESTED.getAndSet(this, 0L); + + if (r != 0L) { + s.request(r); + } + + return true; + } + + a = this.s; + + if (a != Operators.cancelledSubscription()) { + s.cancel(); + return false; + } + + Operators.reportSubscriptionSet(); + return false; + } + + /** + * Sets the Subscription once but does not request anything. + * + * @param s the Subscription to set + * @return true if successful, false if the current subscription is not null + */ + protected final boolean setWithoutRequesting(Subscription s) { + Objects.requireNonNull(s, "s"); + for (; ; ) { + Subscription a = this.s; + if (a == Operators.cancelledSubscription()) { + s.cancel(); + return false; + } + if (a != null) { + s.cancel(); + Operators.reportSubscriptionSet(); + return false; + } + + if (S.compareAndSet(this, null, s)) { + return true; + } + } + } + + /** + * Prepares and throws an AssertionError exception based on the message, cause, the active state + * and the potential errors so far. + * + * @param message the message + * @param cause the optional Throwable cause + * @throws AssertionError as expected + */ + protected final void assertionError(String message, Throwable cause) { + StringBuilder b = new StringBuilder(); + + if (cdl.getCount() != 0) { + b.append("(active) "); + } + b.append(message); + + List err = errors; + if (!err.isEmpty()) { + b.append(" (+ ").append(err.size()).append(" errors)"); + } + AssertionError e = new AssertionError(b.toString(), cause); + + for (Throwable t : err) { + e.addSuppressed(t); + } + + throw e; + } + + protected final String fusionModeName(int mode) { + switch (mode) { + case -1: + return "Disabled"; + case Fuseable.NONE: + return "None"; + case Fuseable.SYNC: + return "Sync"; + case Fuseable.ASYNC: + return "Async"; + default: + return "Unknown(" + mode + ")"; + } + } + + protected final String valueAndClass(Object o) { + if (o == null) { + return null; + } + return o + " (" + o.getClass().getSimpleName() + ")"; + } + + public List values() { + return values; + } + + public final AssertSubscriber assertNoEvents() { + return assertNoValues().assertNoError().assertNotComplete(); + } + + @SafeVarargs + public final AssertSubscriber assertIncomplete(T... values) { + return assertValues(values).assertNotComplete().assertNoError(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/metadata/security/AuthMetadataFlyweightTest.java b/rsocket-core/src/test/java/io/rsocket/metadata/security/AuthMetadataFlyweightTest.java new file mode 100644 index 000000000..13d910e15 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/metadata/security/AuthMetadataFlyweightTest.java @@ -0,0 +1,470 @@ +package io.rsocket.metadata.security; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCountUtil; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; + +class AuthMetadataFlyweightTest { + + public static final int AUTH_TYPE_ID_LENGTH = 1; + public static final int USER_NAME_BYTES_LENGTH = 1; + public static final String TEST_BEARER_TOKEN = + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyLCJpYXQxIjoxNTE2MjM5MDIyLCJpYXQyIjoxNTE2MjM5MDIyLCJpYXQzIjoxNTE2MjM5MDIyLCJpYXQ0IjoxNTE2MjM5MDIyfQ.ljYuH-GNyyhhLcx-rHMchRkGbNsR2_4aSxo8XjrYrSM"; + + @Test + void shouldCorrectlyEncodeData() { + String username = "test"; + String password = "tset1234"; + + int usernameLength = username.length(); + int passwordLength = password.length(); + + ByteBuf byteBuf = + AuthMetadataFlyweight.encodeSimpleMetadata( + ByteBufAllocator.DEFAULT, username.toCharArray(), password.toCharArray()); + + byteBuf.markReaderIndex(); + checkSimpleAuthMetadataEncoding( + username, password, usernameLength, passwordLength, byteBuf.retain()); + byteBuf.resetReaderIndex(); + checkSimpleAuthMetadataEncodingUsingDecoders( + username, password, usernameLength, passwordLength, byteBuf); + } + + @Test + void shouldCorrectlyEncodeData1() { + String username = "𠜎𠜱𠝹𠱓𠱸𠲖𠳏𠳕𠴕𠵼𠵿𠸎"; + String password = "tset1234"; + + int usernameLength = username.getBytes(CharsetUtil.UTF_8).length; + int passwordLength = password.length(); + + ByteBuf byteBuf = + AuthMetadataFlyweight.encodeSimpleMetadata( + ByteBufAllocator.DEFAULT, username.toCharArray(), password.toCharArray()); + + byteBuf.markReaderIndex(); + checkSimpleAuthMetadataEncoding( + username, password, usernameLength, passwordLength, byteBuf.retain()); + byteBuf.resetReaderIndex(); + checkSimpleAuthMetadataEncodingUsingDecoders( + username, password, usernameLength, passwordLength, byteBuf); + } + + @Test + void shouldCorrectlyEncodeData2() { + String username = "𠜎𠜱𠝹𠱓𠱸𠲖𠳏𠳕𠴕𠵼𠵿𠸎1234567#4? "; + String password = "tset1234"; + + int usernameLength = username.getBytes(CharsetUtil.UTF_8).length; + int passwordLength = password.length(); + + ByteBuf byteBuf = + AuthMetadataFlyweight.encodeSimpleMetadata( + ByteBufAllocator.DEFAULT, username.toCharArray(), password.toCharArray()); + + byteBuf.markReaderIndex(); + checkSimpleAuthMetadataEncoding( + username, password, usernameLength, passwordLength, byteBuf.retain()); + byteBuf.resetReaderIndex(); + checkSimpleAuthMetadataEncodingUsingDecoders( + username, password, usernameLength, passwordLength, byteBuf); + } + + private static void checkSimpleAuthMetadataEncoding( + String username, String password, int usernameLength, int passwordLength, ByteBuf byteBuf) { + Assertions.assertThat(byteBuf.capacity()) + .isEqualTo(AUTH_TYPE_ID_LENGTH + USER_NAME_BYTES_LENGTH + usernameLength + passwordLength); + + Assertions.assertThat(byteBuf.readUnsignedByte() & ~0x80) + .isEqualTo(WellKnownAuthType.SIMPLE.getIdentifier()); + Assertions.assertThat(byteBuf.readUnsignedByte()).isEqualTo((short) usernameLength); + + Assertions.assertThat(byteBuf.readCharSequence(usernameLength, CharsetUtil.UTF_8)) + .isEqualTo(username); + Assertions.assertThat(byteBuf.readCharSequence(passwordLength, CharsetUtil.UTF_8)) + .isEqualTo(password); + + ReferenceCountUtil.release(byteBuf); + } + + private static void checkSimpleAuthMetadataEncodingUsingDecoders( + String username, String password, int usernameLength, int passwordLength, ByteBuf byteBuf) { + Assertions.assertThat(byteBuf.capacity()) + .isEqualTo(AUTH_TYPE_ID_LENGTH + USER_NAME_BYTES_LENGTH + usernameLength + passwordLength); + + Assertions.assertThat(AuthMetadataFlyweight.decodeWellKnownAuthType(byteBuf)) + .isEqualTo(WellKnownAuthType.SIMPLE); + byteBuf.markReaderIndex(); + Assertions.assertThat(AuthMetadataFlyweight.decodeUsername(byteBuf).toString(CharsetUtil.UTF_8)) + .isEqualTo(username); + Assertions.assertThat(AuthMetadataFlyweight.decodePassword(byteBuf).toString(CharsetUtil.UTF_8)) + .isEqualTo(password); + byteBuf.resetReaderIndex(); + + Assertions.assertThat(new String(AuthMetadataFlyweight.decodeUsernameAsCharArray(byteBuf))) + .isEqualTo(username); + Assertions.assertThat(new String(AuthMetadataFlyweight.decodePasswordAsCharArray(byteBuf))) + .isEqualTo(password); + + ReferenceCountUtil.release(byteBuf); + } + + @Test + void shouldThrowExceptionIfUsernameLengthExitsAllowedBounds() { + String username = + "𠜎𠜱𠝹𠱓𠱸𠲖𠳏𠳕𠴕𠵼𠵿𠸎𠸏𠹷𠺝𠺢𠻗𠻹𠻺𠼭𠼮𠽌𠾴𠾼𠿪𡁜𡁯𡁵𡁶𡁻𡃁𡃉𡇙𢃇𢞵𢫕𢭃𢯊𢱑𢱕𢳂𢴈𢵌𢵧𢺳𣲷𤓓𤶸𤷪𥄫𦉘𦟌𦧲𦧺𧨾𨅝𨈇𨋢𨳊𨳍𨳒𩶘𠜎𠜱𠝹"; + String password = "tset1234"; + + Assertions.assertThatThrownBy( + () -> + AuthMetadataFlyweight.encodeSimpleMetadata( + ByteBufAllocator.DEFAULT, username.toCharArray(), password.toCharArray())) + .hasMessage( + "Username should be shorter than or equal to 255 bytes length in UTF-8 encoding"); + } + + @Test + void shouldEncodeBearerMetadata() { + String testToken = TEST_BEARER_TOKEN; + + ByteBuf byteBuf = + AuthMetadataFlyweight.encodeBearerMetadata( + ByteBufAllocator.DEFAULT, testToken.toCharArray()); + + byteBuf.markReaderIndex(); + checkBearerAuthMetadataEncoding(testToken, byteBuf); + byteBuf.resetReaderIndex(); + checkBearerAuthMetadataEncodingUsingDecoders(testToken, byteBuf); + } + + private static void checkBearerAuthMetadataEncoding(String testToken, ByteBuf byteBuf) { + Assertions.assertThat(byteBuf.capacity()) + .isEqualTo(testToken.getBytes(CharsetUtil.UTF_8).length + AUTH_TYPE_ID_LENGTH); + Assertions.assertThat( + byteBuf.readUnsignedByte() & ~AuthMetadataFlyweight.STREAM_METADATA_KNOWN_MASK) + .isEqualTo(WellKnownAuthType.BEARER.getIdentifier()); + Assertions.assertThat(byteBuf.readSlice(byteBuf.capacity() - 1).toString(CharsetUtil.UTF_8)) + .isEqualTo(testToken); + } + + private static void checkBearerAuthMetadataEncodingUsingDecoders( + String testToken, ByteBuf byteBuf) { + Assertions.assertThat(byteBuf.capacity()) + .isEqualTo(testToken.getBytes(CharsetUtil.UTF_8).length + AUTH_TYPE_ID_LENGTH); + Assertions.assertThat(AuthMetadataFlyweight.isWellKnownAuthType(byteBuf)).isTrue(); + Assertions.assertThat(AuthMetadataFlyweight.decodeWellKnownAuthType(byteBuf)) + .isEqualTo(WellKnownAuthType.BEARER); + byteBuf.markReaderIndex(); + Assertions.assertThat(new String(AuthMetadataFlyweight.decodeBearerTokenAsCharArray(byteBuf))) + .isEqualTo(testToken); + byteBuf.resetReaderIndex(); + Assertions.assertThat( + AuthMetadataFlyweight.decodePayload(byteBuf).toString(CharsetUtil.UTF_8).toString()) + .isEqualTo(testToken); + } + + @Test + void shouldEncodeCustomAuth() { + String payloadAsAText = "testsecuritybuffer"; + ByteBuf testSecurityPayload = + Unpooled.wrappedBuffer(payloadAsAText.getBytes(CharsetUtil.UTF_8)); + + String customAuthType = "myownauthtype"; + ByteBuf buffer = + AuthMetadataFlyweight.encodeMetadata( + ByteBufAllocator.DEFAULT, customAuthType, testSecurityPayload); + + checkCustomAuthMetadataEncoding(testSecurityPayload, customAuthType, buffer); + } + + private static void checkCustomAuthMetadataEncoding( + ByteBuf testSecurityPayload, String customAuthType, ByteBuf buffer) { + Assertions.assertThat(buffer.capacity()) + .isEqualTo(1 + customAuthType.length() + testSecurityPayload.capacity()); + Assertions.assertThat(buffer.readUnsignedByte()) + .isEqualTo((short) (customAuthType.length() - 1)); + Assertions.assertThat( + buffer.readCharSequence(customAuthType.length(), CharsetUtil.US_ASCII).toString()) + .isEqualTo(customAuthType); + Assertions.assertThat(buffer.readSlice(testSecurityPayload.capacity())) + .isEqualTo(testSecurityPayload); + + ReferenceCountUtil.release(buffer); + } + + @Test + void shouldThrowOnNonASCIIChars() { + ByteBuf testSecurityPayload = ByteBufAllocator.DEFAULT.buffer(); + String customAuthType = "1234567#4? 𠜎𠜱𠝹𠱓𠱸𠲖𠳏𠳕𠴕𠵼𠵿𠸎"; + + Assertions.assertThatThrownBy( + () -> + AuthMetadataFlyweight.encodeMetadata( + ByteBufAllocator.DEFAULT, customAuthType, testSecurityPayload)) + .hasMessage("custom auth type must be US_ASCII characters only"); + } + + @Test + void shouldThrowOnOutOfAllowedSizeType() { + ByteBuf testSecurityPayload = ByteBufAllocator.DEFAULT.buffer(); + // 130 chars + String customAuthType = + "0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789"; + + Assertions.assertThatThrownBy( + () -> + AuthMetadataFlyweight.encodeMetadata( + ByteBufAllocator.DEFAULT, customAuthType, testSecurityPayload)) + .hasMessage( + "custom auth type must have a strictly positive length that fits on 7 unsigned bits, ie 1-128"); + } + + @Test + void shouldThrowOnOutOfAllowedSizeType1() { + ByteBuf testSecurityPayload = ByteBufAllocator.DEFAULT.buffer(); + String customAuthType = ""; + + Assertions.assertThatThrownBy( + () -> + AuthMetadataFlyweight.encodeMetadata( + ByteBufAllocator.DEFAULT, customAuthType, testSecurityPayload)) + .hasMessage( + "custom auth type must have a strictly positive length that fits on 7 unsigned bits, ie 1-128"); + } + + @Test + void shouldEncodeUsingWellKnownAuthType() { + ByteBuf byteBuf = + AuthMetadataFlyweight.encodeMetadata( + ByteBufAllocator.DEFAULT, + WellKnownAuthType.SIMPLE, + ByteBufAllocator.DEFAULT.buffer(3, 3).writeByte(1).writeByte('u').writeByte('p')); + + checkSimpleAuthMetadataEncoding("u", "p", 1, 1, byteBuf); + } + + @Test + void shouldEncodeUsingWellKnownAuthType1() { + ByteBuf byteBuf = + AuthMetadataFlyweight.encodeMetadata( + ByteBufAllocator.DEFAULT, + WellKnownAuthType.SIMPLE, + ByteBufAllocator.DEFAULT.buffer().writeByte(1).writeByte('u').writeByte('p')); + + checkSimpleAuthMetadataEncoding("u", "p", 1, 1, byteBuf); + } + + @Test + void shouldEncodeUsingWellKnownAuthType2() { + ByteBuf byteBuf = + AuthMetadataFlyweight.encodeMetadata( + ByteBufAllocator.DEFAULT, + WellKnownAuthType.BEARER, + Unpooled.copiedBuffer(TEST_BEARER_TOKEN, CharsetUtil.UTF_8)); + + byteBuf.markReaderIndex(); + checkBearerAuthMetadataEncoding(TEST_BEARER_TOKEN, byteBuf); + byteBuf.resetReaderIndex(); + checkBearerAuthMetadataEncodingUsingDecoders(TEST_BEARER_TOKEN, byteBuf); + } + + @Test + void shouldThrowIfWellKnownAuthTypeIsUnsupportedOrUnknown() { + ByteBuf buffer = ByteBufAllocator.DEFAULT.buffer(); + + Assertions.assertThatThrownBy( + () -> + AuthMetadataFlyweight.encodeMetadata( + ByteBufAllocator.DEFAULT, WellKnownAuthType.UNPARSEABLE_AUTH_TYPE, buffer)) + .hasMessage("only allowed AuthType should be used"); + + Assertions.assertThatThrownBy( + () -> + AuthMetadataFlyweight.encodeMetadata( + ByteBufAllocator.DEFAULT, WellKnownAuthType.UNPARSEABLE_AUTH_TYPE, buffer)) + .hasMessage("only allowed AuthType should be used"); + + buffer.release(); + } + + @Test + void shouldCompressMetadata() { + ByteBuf byteBuf = + AuthMetadataFlyweight.encodeMetadataWithCompression( + ByteBufAllocator.DEFAULT, + "simple", + ByteBufAllocator.DEFAULT.buffer().writeByte(1).writeByte('u').writeByte('p')); + + checkSimpleAuthMetadataEncoding("u", "p", 1, 1, byteBuf); + } + + @Test + void shouldCompressMetadata1() { + ByteBuf byteBuf = + AuthMetadataFlyweight.encodeMetadataWithCompression( + ByteBufAllocator.DEFAULT, + "bearer", + Unpooled.copiedBuffer(TEST_BEARER_TOKEN, CharsetUtil.UTF_8)); + + byteBuf.markReaderIndex(); + checkBearerAuthMetadataEncoding(TEST_BEARER_TOKEN, byteBuf); + byteBuf.resetReaderIndex(); + checkBearerAuthMetadataEncodingUsingDecoders(TEST_BEARER_TOKEN, byteBuf); + } + + @Test + void shouldNotCompressMetadata() { + ByteBuf testMetadataPayload = + Unpooled.wrappedBuffer(TEST_BEARER_TOKEN.getBytes(CharsetUtil.UTF_8)); + String customAuthType = "testauthtype"; + ByteBuf byteBuf = + AuthMetadataFlyweight.encodeMetadataWithCompression( + ByteBufAllocator.DEFAULT, customAuthType, testMetadataPayload); + + checkCustomAuthMetadataEncoding(testMetadataPayload, customAuthType, byteBuf); + } + + @Test + void shouldConfirmWellKnownAuthType() { + ByteBuf metadata = + AuthMetadataFlyweight.encodeMetadataWithCompression( + ByteBufAllocator.DEFAULT, "simple", Unpooled.EMPTY_BUFFER); + + int initialReaderIndex = metadata.readerIndex(); + + Assertions.assertThat(AuthMetadataFlyweight.isWellKnownAuthType(metadata)).isTrue(); + Assertions.assertThat(metadata.readerIndex()).isEqualTo(initialReaderIndex); + + ReferenceCountUtil.release(metadata); + } + + @Test + void shouldConfirmGivenMetadataIsNotAWellKnownAuthType() { + ByteBuf metadata = + AuthMetadataFlyweight.encodeMetadataWithCompression( + ByteBufAllocator.DEFAULT, "simple/afafgafadf", Unpooled.EMPTY_BUFFER); + + int initialReaderIndex = metadata.readerIndex(); + + Assertions.assertThat(AuthMetadataFlyweight.isWellKnownAuthType(metadata)).isFalse(); + Assertions.assertThat(metadata.readerIndex()).isEqualTo(initialReaderIndex); + + ReferenceCountUtil.release(metadata); + } + + @Test + void shouldReadSimpleWellKnownAuthType() { + ByteBuf metadata = + AuthMetadataFlyweight.encodeMetadataWithCompression( + ByteBufAllocator.DEFAULT, "simple", Unpooled.EMPTY_BUFFER); + WellKnownAuthType expectedType = WellKnownAuthType.SIMPLE; + checkDecodeWellKnowAuthTypeCorrectly(metadata, expectedType); + } + + @Test + void shouldReadSimpleWellKnownAuthType1() { + ByteBuf metadata = + AuthMetadataFlyweight.encodeMetadataWithCompression( + ByteBufAllocator.DEFAULT, "bearer", Unpooled.EMPTY_BUFFER); + WellKnownAuthType expectedType = WellKnownAuthType.BEARER; + checkDecodeWellKnowAuthTypeCorrectly(metadata, expectedType); + } + + @Test + void shouldReadSimpleWellKnownAuthType2() { + ByteBuf metadata = + ByteBufAllocator.DEFAULT + .buffer() + .writeByte(3 | AuthMetadataFlyweight.STREAM_METADATA_KNOWN_MASK); + WellKnownAuthType expectedType = WellKnownAuthType.UNKNOWN_RESERVED_AUTH_TYPE; + checkDecodeWellKnowAuthTypeCorrectly(metadata, expectedType); + } + + @Test + void shouldNotReadSimpleWellKnownAuthTypeIfEncodedLength() { + ByteBuf metadata = ByteBufAllocator.DEFAULT.buffer().writeByte(3); + WellKnownAuthType expectedType = WellKnownAuthType.UNPARSEABLE_AUTH_TYPE; + checkDecodeWellKnowAuthTypeCorrectly(metadata, expectedType); + } + + @Test + void shouldNotReadSimpleWellKnownAuthTypeIfEncodedLength1() { + ByteBuf metadata = + AuthMetadataFlyweight.encodeMetadata( + ByteBufAllocator.DEFAULT, "testmetadataauthtype", Unpooled.EMPTY_BUFFER); + WellKnownAuthType expectedType = WellKnownAuthType.UNPARSEABLE_AUTH_TYPE; + checkDecodeWellKnowAuthTypeCorrectly(metadata, expectedType); + } + + @Test + void shouldThrowExceptionIsNotEnoughReadableBytes() { + Assertions.assertThatThrownBy( + () -> AuthMetadataFlyweight.decodeWellKnownAuthType(Unpooled.EMPTY_BUFFER)) + .hasMessage("Unable to decode Well Know Auth type. Not enough readable bytes"); + } + + private static void checkDecodeWellKnowAuthTypeCorrectly( + ByteBuf metadata, WellKnownAuthType expectedType) { + int initialReaderIndex = metadata.readerIndex(); + + WellKnownAuthType wellKnownAuthType = AuthMetadataFlyweight.decodeWellKnownAuthType(metadata); + + Assertions.assertThat(wellKnownAuthType).isEqualTo(expectedType); + Assertions.assertThat(metadata.readerIndex()) + .isNotEqualTo(initialReaderIndex) + .isEqualTo(initialReaderIndex + 1); + + ReferenceCountUtil.release(metadata); + } + + @Test + void shouldReadCustomEncodedAuthType() { + String testAuthType = "TestAuthType"; + ByteBuf byteBuf = + AuthMetadataFlyweight.encodeMetadata( + ByteBufAllocator.DEFAULT, testAuthType, Unpooled.EMPTY_BUFFER); + checkDecodeCustomAuthTypeCorrectly(testAuthType, byteBuf); + } + + @Test + void shouldThrowExceptionOnEmptyMetadata() { + Assertions.assertThatThrownBy( + () -> AuthMetadataFlyweight.decodeCustomAuthType(Unpooled.EMPTY_BUFFER)) + .hasMessage("Unable to decode custom Auth type. Not enough readable bytes"); + } + + @Test + void shouldThrowExceptionOnMalformedMetadata_wellknowninstead() { + Assertions.assertThatThrownBy( + () -> + AuthMetadataFlyweight.decodeCustomAuthType( + AuthMetadataFlyweight.encodeMetadata( + ByteBufAllocator.DEFAULT, + WellKnownAuthType.BEARER, + Unpooled.copiedBuffer(new byte[] {'a', 'b'})))) + .hasMessage("Unable to decode custom Auth type. Incorrect auth type length"); + } + + @Test + void shouldThrowExceptionOnMalformedMetadata_length() { + Assertions.assertThatThrownBy( + () -> + AuthMetadataFlyweight.decodeCustomAuthType( + ByteBufAllocator.DEFAULT.buffer().writeByte(127).writeChar('a').writeChar('b'))) + .hasMessage("Unable to decode custom Auth type. Malformed length or auth type string"); + } + + private static void checkDecodeCustomAuthTypeCorrectly(String testAuthType, ByteBuf byteBuf) { + int initialReaderIndex = byteBuf.readerIndex(); + + Assertions.assertThat(AuthMetadataFlyweight.decodeCustomAuthType(byteBuf).toString()) + .isEqualTo(testAuthType); + Assertions.assertThat(byteBuf.readerIndex()) + .isEqualTo(initialReaderIndex + testAuthType.length() + 1); + } +} diff --git a/rsocket-load-balancer/src/main/java/io/rsocket/client/LoadBalancedRSocketMono.java b/rsocket-load-balancer/src/main/java/io/rsocket/client/LoadBalancedRSocketMono.java index 376dcdf73..ed7550233 100644 --- a/rsocket-load-balancer/src/main/java/io/rsocket/client/LoadBalancedRSocketMono.java +++ b/rsocket-load-balancer/src/main/java/io/rsocket/client/LoadBalancedRSocketMono.java @@ -583,7 +583,6 @@ private class WeightedSocket extends AbstractRSocket implements LoadBalancerSock activeSockets.remove(WeightedSocket.this); logger.debug( "Removed {} from factory {} from activeSockets", WeightedSocket.this, factory); - refreshSockets(); }) .subscribe(); diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/BaseWebsocketServerTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/BaseWebsocketServerTransport.java new file mode 100644 index 000000000..26fb44535 --- /dev/null +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/BaseWebsocketServerTransport.java @@ -0,0 +1,40 @@ +package io.rsocket.transport.netty.server; + +import static io.netty.channel.ChannelHandler.*; + +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.handler.codec.http.websocketx.PongWebSocketFrame; +import io.netty.util.ReferenceCountUtil; +import io.rsocket.Closeable; +import io.rsocket.transport.ServerTransport; +import java.util.function.Function; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.netty.http.server.HttpServer; + +abstract class BaseWebsocketServerTransport implements ServerTransport { + private static final Logger logger = LoggerFactory.getLogger(BaseWebsocketServerTransport.class); + private static final ChannelHandler pongHandler = new PongHandler(); + + static Function serverConfigurer = + server -> + server.tcpConfiguration( + tcpServer -> + tcpServer.doOnConnection(connection -> connection.addHandlerLast(pongHandler))); + + @Sharable + private static class PongHandler extends ChannelInboundHandlerAdapter { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + if (msg instanceof PongWebSocketFrame) { + logger.debug("received WebSocket Pong Frame"); + ReferenceCountUtil.safeRelease(msg); + ctx.read(); + } else { + ctx.fireChannelRead(msg); + } + } + } +} diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketRouteTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketRouteTransport.java index 9b78ece60..30aa0fa96 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketRouteTransport.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketRouteTransport.java @@ -46,7 +46,7 @@ * An implementation of {@link ServerTransport} that connects via Websocket and listens on specified * routes. */ -public final class WebsocketRouteTransport implements ServerTransport { +public final class WebsocketRouteTransport extends BaseWebsocketServerTransport { private final UriPathTemplate template; @@ -63,8 +63,7 @@ public final class WebsocketRouteTransport implements ServerTransport */ public WebsocketRouteTransport( HttpServer server, Consumer routesBuilder, String path) { - - this.server = Objects.requireNonNull(server, "server must not be null"); + this.server = serverConfigurer.apply(Objects.requireNonNull(server, "server must not be null")); this.routesBuilder = Objects.requireNonNull(routesBuilder, "routesBuilder must not be null"); this.template = new UriPathTemplate(Objects.requireNonNull(path, "path must not be null")); } diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketServerTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketServerTransport.java index 205f419a2..948d6f573 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketServerTransport.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketServerTransport.java @@ -19,12 +19,6 @@ import static io.rsocket.frame.FrameLengthFlyweight.FRAME_LENGTH_MASK; import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.Unpooled; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundHandlerAdapter; -import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; -import io.netty.handler.codec.http.websocketx.PingWebSocketFrame; -import io.netty.handler.codec.http.websocketx.PongWebSocketFrame; import io.rsocket.DuplexConnection; import io.rsocket.fragmentation.FragmentationDuplexConnection; import io.rsocket.transport.ClientTransport; @@ -46,8 +40,8 @@ * An implementation of {@link ServerTransport} that connects to a {@link ClientTransport} via a * Websocket. */ -public final class WebsocketServerTransport - implements ServerTransport, TransportHeaderAware { +public final class WebsocketServerTransport extends BaseWebsocketServerTransport + implements TransportHeaderAware { private static final Logger logger = LoggerFactory.getLogger(WebsocketServerTransport.class); private final HttpServer server; @@ -55,7 +49,7 @@ public final class WebsocketServerTransport private Supplier> transportHeaders = Collections::emptyMap; private WebsocketServerTransport(HttpServer server) { - this.server = server; + this.server = serverConfigurer.apply(Objects.requireNonNull(server, "server must not be null")); } /** @@ -107,33 +101,7 @@ public static WebsocketServerTransport create(InetSocketAddress address) { public static WebsocketServerTransport create(final HttpServer server) { Objects.requireNonNull(server, "server must not be null"); - return new WebsocketServerTransport( - server.tcpConfiguration( - tcpServer -> - tcpServer.doOnConnection( - connection -> - connection.addHandlerLast( - new ChannelInboundHandlerAdapter() { - @Override - public void channelRead(ChannelHandlerContext ctx, Object msg) - throws Exception { - if (msg instanceof PongWebSocketFrame) { - logger.debug("received WebSocket Pong Frame"); - } else if (msg instanceof PingWebSocketFrame) { - logger.debug( - "received WebSocket Ping Frame - sending Pong Frame"); - PongWebSocketFrame pongWebSocketFrame = - new PongWebSocketFrame(Unpooled.EMPTY_BUFFER); - ctx.writeAndFlush(pongWebSocketFrame); - } else if (msg instanceof CloseWebSocketFrame) { - logger.warn( - "received WebSocket Close Frame - connection is closing"); - ctx.close(); - } else { - ctx.fireChannelRead(msg); - } - } - })))); + return new WebsocketServerTransport(server); } @Override diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPingPongIntegrationTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPingPongIntegrationTest.java new file mode 100644 index 000000000..eac091dd8 --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPingPongIntegrationTest.java @@ -0,0 +1,160 @@ +package io.rsocket.transport.netty; + +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.handler.codec.http.websocketx.PingWebSocketFrame; +import io.netty.handler.codec.http.websocketx.PongWebSocketFrame; +import io.netty.handler.codec.http.websocketx.WebSocketFrame; +import io.netty.util.ReferenceCountUtil; +import io.rsocket.*; +import io.rsocket.transport.ServerTransport; +import io.rsocket.transport.netty.client.WebsocketClientTransport; +import io.rsocket.transport.netty.server.WebsocketRouteTransport; +import io.rsocket.transport.netty.server.WebsocketServerTransport; +import io.rsocket.util.DefaultPayload; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.stream.Stream; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoProcessor; +import reactor.netty.http.client.HttpClient; +import reactor.netty.http.server.HttpServer; +import reactor.test.StepVerifier; + +public class WebsocketPingPongIntegrationTest { + private static final String host = "localhost"; + private static final int port = 8088; + + private Closeable server; + + @AfterEach + void tearDown() { + server.dispose(); + } + + @ParameterizedTest + @MethodSource("provideServerTransport") + void webSocketPingPong(ServerTransport serverTransport) { + server = + RSocketFactory.receive() + .acceptor((setup, sendingSocket) -> Mono.just(new EchoRSocket())) + .transport(serverTransport) + .start() + .block(); + + String expectedData = "data"; + String expectedPing = "ping"; + + PingSender pingSender = new PingSender(); + + HttpClient httpClient = + HttpClient.create() + .tcpConfiguration( + tcpClient -> + tcpClient + .doOnConnected(b -> b.addHandlerLast(pingSender)) + .host(host) + .port(port)); + + RSocket rSocket = + RSocketFactory.connect() + .transport(WebsocketClientTransport.create(httpClient, "/")) + .start() + .block(); + + rSocket + .requestResponse(DefaultPayload.create(expectedData)) + .delaySubscription(pingSender.sendPing(expectedPing)) + .as(StepVerifier::create) + .expectNextMatches(p -> expectedData.equals(p.getDataUtf8())) + .expectComplete() + .verify(Duration.ofSeconds(5)); + + pingSender + .receivePong() + .as(StepVerifier::create) + .expectNextMatches(expectedPing::equals) + .expectComplete() + .verify(Duration.ofSeconds(5)); + + rSocket + .requestResponse(DefaultPayload.create(expectedData)) + .delaySubscription(pingSender.sendPong()) + .as(StepVerifier::create) + .expectNextMatches(p -> expectedData.equals(p.getDataUtf8())) + .expectComplete() + .verify(Duration.ofSeconds(5)); + } + + private static Stream provideServerTransport() { + return Stream.of( + Arguments.of(WebsocketServerTransport.create(host, port)), + Arguments.of( + new WebsocketRouteTransport( + HttpServer.create().host(host).port(port), routes -> {}, "/"))); + } + + private static class EchoRSocket extends AbstractRSocket { + @Override + public Mono requestResponse(Payload payload) { + return Mono.just(payload); + } + } + + private static class PingSender extends ChannelInboundHandlerAdapter { + private final MonoProcessor channel = MonoProcessor.create(); + private final MonoProcessor pong = MonoProcessor.create(); + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (msg instanceof PongWebSocketFrame) { + pong.onNext(((PongWebSocketFrame) msg).content().toString(StandardCharsets.UTF_8)); + ReferenceCountUtil.safeRelease(msg); + ctx.read(); + } else { + super.channelRead(ctx, msg); + } + } + + @Override + public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception { + Channel ch = ctx.channel(); + if (!channel.isTerminated() && ch.isWritable()) { + channel.onNext(ctx.channel()); + } + super.channelWritabilityChanged(ctx); + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + Channel ch = ctx.channel(); + if (ch.isWritable()) { + channel.onNext(ch); + } + super.handlerAdded(ctx); + } + + public Mono sendPing(String data) { + return send( + new PingWebSocketFrame(Unpooled.wrappedBuffer(data.getBytes(StandardCharsets.UTF_8)))); + } + + public Mono sendPong() { + return send(new PongWebSocketFrame()); + } + + public Mono receivePong() { + return pong; + } + + private Mono send(WebSocketFrame webSocketFrame) { + return channel.doOnNext(ch -> ch.writeAndFlush(webSocketFrame)).then(); + } + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/WebsocketRouteTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/WebsocketRouteTransportTest.java index 26f598c2d..e94bef13c 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/WebsocketRouteTransportTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/WebsocketRouteTransportTest.java @@ -16,93 +16,16 @@ package io.rsocket.transport.netty.server; -import static io.rsocket.frame.FrameLengthFlyweight.FRAME_LENGTH_MASK; import static org.assertj.core.api.Assertions.assertThatNullPointerException; -import java.util.function.BiFunction; -import java.util.function.Consumer; -import java.util.function.Predicate; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; -import org.mockito.ArgumentCaptor; -import org.mockito.Mockito; import reactor.core.publisher.Mono; import reactor.netty.http.server.HttpServer; -import reactor.netty.http.server.HttpServerRoutes; import reactor.test.StepVerifier; final class WebsocketRouteTransportTest { - @Test - public void testThatSetupWithUnSpecifiedFrameSizeShouldSetMaxFrameSize() { - ArgumentCaptor captor = ArgumentCaptor.forClass(Consumer.class); - HttpServer httpServer = Mockito.spy(HttpServer.create()); - HttpServerRoutes routes = Mockito.mock(HttpServerRoutes.class); - Mockito.doAnswer(a -> httpServer).when(httpServer).route(captor.capture()); - Mockito.doAnswer(a -> Mono.empty()).when(httpServer).bind(); - - WebsocketRouteTransport serverTransport = - new WebsocketRouteTransport(httpServer, (r) -> {}, ""); - - serverTransport.start(c -> Mono.empty(), 0).subscribe(); - - captor.getValue().accept(routes); - - Mockito.verify(routes) - .ws( - Mockito.any(Predicate.class), - Mockito.any(BiFunction.class), - Mockito.nullable(String.class), - Mockito.eq(FRAME_LENGTH_MASK)); - } - - @Test - public void testThatSetupWithSpecifiedFrameSizeButLowerThanWsDefaultShouldSetToWsDefault() { - ArgumentCaptor captor = ArgumentCaptor.forClass(Consumer.class); - HttpServer httpServer = Mockito.spy(HttpServer.create()); - HttpServerRoutes routes = Mockito.mock(HttpServerRoutes.class); - Mockito.doAnswer(a -> httpServer).when(httpServer).route(captor.capture()); - Mockito.doAnswer(a -> Mono.empty()).when(httpServer).bind(); - - WebsocketRouteTransport serverTransport = - new WebsocketRouteTransport(httpServer, (r) -> {}, ""); - - serverTransport.start(c -> Mono.empty(), 1000).subscribe(); - - captor.getValue().accept(routes); - - Mockito.verify(routes) - .ws( - Mockito.any(Predicate.class), - Mockito.any(BiFunction.class), - Mockito.nullable(String.class), - Mockito.eq(FRAME_LENGTH_MASK)); - } - - @Test - public void - testThatSetupWithSpecifiedFrameSizeButHigherThanWsDefaultShouldSetToSpecifiedFrameSize() { - ArgumentCaptor captor = ArgumentCaptor.forClass(Consumer.class); - HttpServer httpServer = Mockito.spy(HttpServer.create()); - HttpServerRoutes routes = Mockito.mock(HttpServerRoutes.class); - Mockito.doAnswer(a -> httpServer).when(httpServer).route(captor.capture()); - Mockito.doAnswer(a -> Mono.empty()).when(httpServer).bind(); - - WebsocketRouteTransport serverTransport = - new WebsocketRouteTransport(httpServer, (r) -> {}, ""); - - serverTransport.start(c -> Mono.empty(), 65536 + 1000).subscribe(); - - captor.getValue().accept(routes); - - Mockito.verify(routes) - .ws( - Mockito.any(Predicate.class), - Mockito.any(BiFunction.class), - Mockito.nullable(String.class), - Mockito.eq(FRAME_LENGTH_MASK)); - } - @DisplayName("creates server") @Test void constructor() { diff --git a/settings.gradle b/settings.gradle index 625633774..25c3feee5 100644 --- a/settings.gradle +++ b/settings.gradle @@ -13,14 +13,29 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +plugins { + id 'com.gradle.enterprise' version '3.1' +} rootProject.name = 'rsocket-java' include 'rsocket-core' -include 'rsocket-examples' include 'rsocket-load-balancer' include 'rsocket-micrometer' include 'rsocket-test' include 'rsocket-transport-local' include 'rsocket-transport-netty' include 'rsocket-bom' + +include 'rsocket-examples' +include 'benchmarks' + + + +gradleEnterprise { + buildScan { + termsOfServiceUrl = 'https://gradle.com/terms-of-service' + termsOfServiceAgree = 'yes' + } +} +