diff --git a/README.md b/README.md index 5eacd94f0..53158ea04 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ Learn more at http://rsocket.io ## Build and Binaries -[![Build Status](https://travis-ci.org/rsocket/rsocket-java.svg?branch=1.0.x)](https://travis-ci.org/rsocket/rsocket-java) +[![Build Status](https://travis-ci.org/rsocket/rsocket-java.svg?branch=develop)](https://travis-ci.org/rsocket/rsocket-java) Releases are available via Maven Central. diff --git a/gradle.properties b/gradle.properties index 044cb3c77..740dee4db 100644 --- a/gradle.properties +++ b/gradle.properties @@ -11,4 +11,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -version=0.12.2-RC3 +version=0.12.2-RC4 diff --git a/rsocket-core/src/jmh/java/io/rsocket/metadata/WellKnownMimeTypePerf.java b/rsocket-core/src/jmh/java/io/rsocket/metadata/WellKnownMimeTypePerf.java new file mode 100644 index 000000000..8f429fc19 --- /dev/null +++ b/rsocket-core/src/jmh/java/io/rsocket/metadata/WellKnownMimeTypePerf.java @@ -0,0 +1,96 @@ +package io.rsocket.metadata; + +import org.openjdk.jmh.annotations.*; +import org.openjdk.jmh.infra.Blackhole; + +@BenchmarkMode(Mode.Throughput) +@Fork(value = 1) +@Warmup(iterations = 10) +@Measurement(iterations = 10) +@State(Scope.Thread) +public class WellKnownMimeTypePerf { + + // this is the old values() looping implementation of fromIdentifier + private WellKnownMimeType fromIdValuesLoop(int id) { + if (id < 0 || id > 127) { + return WellKnownMimeType.UNPARSEABLE_MIME_TYPE; + } + for (WellKnownMimeType value : WellKnownMimeType.values()) { + if (value.getIdentifier() == id) { + return value; + } + } + return WellKnownMimeType.UNKNOWN_RESERVED_MIME_TYPE; + } + + // this is the core of the old values() looping implementation of fromString + private WellKnownMimeType fromStringValuesLoop(String mimeType) { + for (WellKnownMimeType value : WellKnownMimeType.values()) { + if (mimeType.equals(value.getString())) { + return value; + } + } + return WellKnownMimeType.UNPARSEABLE_MIME_TYPE; + } + + @Benchmark + public void fromIdArrayLookup(final Blackhole bh) { + // negative lookup + bh.consume(WellKnownMimeType.fromIdentifier(-10)); + bh.consume(WellKnownMimeType.fromIdentifier(-1)); + // too large lookup + bh.consume(WellKnownMimeType.fromIdentifier(129)); + // first lookup + bh.consume(WellKnownMimeType.fromIdentifier(0)); + // middle lookup + bh.consume(WellKnownMimeType.fromIdentifier(37)); + // reserved lookup + bh.consume(WellKnownMimeType.fromIdentifier(63)); + // last lookup + bh.consume(WellKnownMimeType.fromIdentifier(127)); + } + + @Benchmark + public void fromIdValuesLoopLookup(final Blackhole bh) { + // negative lookup + bh.consume(fromIdValuesLoop(-10)); + bh.consume(fromIdValuesLoop(-1)); + // too large lookup + bh.consume(fromIdValuesLoop(129)); + // first lookup + bh.consume(fromIdValuesLoop(0)); + // middle lookup + bh.consume(fromIdValuesLoop(37)); + // reserved lookup + bh.consume(fromIdValuesLoop(63)); + // last lookup + bh.consume(fromIdValuesLoop(127)); + } + + @Benchmark + public void fromStringMapLookup(final Blackhole bh) { + // unknown lookup + bh.consume(WellKnownMimeType.fromString("foo/bar")); + // first lookup + bh.consume(WellKnownMimeType.fromString(WellKnownMimeType.APPLICATION_AVRO.getString())); + // middle lookup + bh.consume(WellKnownMimeType.fromString(WellKnownMimeType.VIDEO_VP8.getString())); + // last lookup + bh.consume( + WellKnownMimeType.fromString( + WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString())); + } + + @Benchmark + public void fromStringValuesLoopLookup(final Blackhole bh) { + // unknown lookup + bh.consume(fromStringValuesLoop("foo/bar")); + // first lookup + bh.consume(fromStringValuesLoop(WellKnownMimeType.APPLICATION_AVRO.getString())); + // middle lookup + bh.consume(fromStringValuesLoop(WellKnownMimeType.VIDEO_VP8.getString())); + // last lookup + bh.consume( + fromStringValuesLoop(WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString())); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketFactory.java b/rsocket-core/src/main/java/io/rsocket/RSocketFactory.java index 34c113481..1743bd6da 100644 --- a/rsocket-core/src/main/java/io/rsocket/RSocketFactory.java +++ b/rsocket-core/src/main/java/io/rsocket/RSocketFactory.java @@ -134,14 +134,25 @@ public ClientRSocketFactory addConnectionPlugin(DuplexConnectionInterceptor inte plugins.addConnectionPlugin(interceptor); return this; } - + /** Deprecated. Use {@link #addRequesterPlugin(RSocketInterceptor)} instead */ + @Deprecated public ClientRSocketFactory addClientPlugin(RSocketInterceptor interceptor) { - plugins.addClientPlugin(interceptor); + return addRequesterPlugin(interceptor); + } + + public ClientRSocketFactory addRequesterPlugin(RSocketInterceptor interceptor) { + plugins.addRequesterPlugin(interceptor); return this; } + /** Deprecated. Use {@link #addResponderPlugin(RSocketInterceptor)} instead */ + @Deprecated public ClientRSocketFactory addServerPlugin(RSocketInterceptor interceptor) { - plugins.addServerPlugin(interceptor); + return addResponderPlugin(interceptor); + } + + public ClientRSocketFactory addResponderPlugin(RSocketInterceptor interceptor) { + plugins.addResponderPlugin(interceptor); return this; } @@ -291,8 +302,8 @@ public Mono start() { ClientServerInputMultiplexer multiplexer = new ClientServerInputMultiplexer(wrappedConnection, plugins); - RSocketClient rSocketClient = - new RSocketClient( + RSocketRequester rSocketRequester = + new RSocketRequester( allocator, multiplexer.asClientConnection(), payloadDecoder, @@ -314,27 +325,27 @@ public Mono start() { setupPayload.sliceMetadata(), setupPayload.sliceData()); - RSocket wrappedRSocketClient = plugins.applyClient(rSocketClient); + RSocket wrappedRSocketRequester = plugins.applyRequester(rSocketRequester); - RSocket unwrappedServerSocket; + RSocket rSocketHandler; if (biAcceptor != null) { ConnectionSetupPayload setup = ConnectionSetupPayload.create(setupFrame); - unwrappedServerSocket = biAcceptor.apply(setup, wrappedRSocketClient); + rSocketHandler = biAcceptor.apply(setup, wrappedRSocketRequester); } else { - unwrappedServerSocket = acceptor.get().apply(wrappedRSocketClient); + rSocketHandler = acceptor.get().apply(wrappedRSocketRequester); } - RSocket wrappedRSocketServer = plugins.applyServer(unwrappedServerSocket); + RSocket wrappedRSocketHandler = plugins.applyResponder(rSocketHandler); - RSocketServer rSocketServer = - new RSocketServer( + RSocketResponder rSocketResponder = + new RSocketResponder( allocator, multiplexer.asServerConnection(), - wrappedRSocketServer, + wrappedRSocketHandler, payloadDecoder, errorConsumer); - return wrappedConnection.sendOne(setupFrame).thenReturn(wrappedRSocketClient); + return wrappedConnection.sendOne(setupFrame).thenReturn(wrappedRSocketRequester); }); } @@ -397,14 +408,25 @@ public ServerRSocketFactory addConnectionPlugin(DuplexConnectionInterceptor inte plugins.addConnectionPlugin(interceptor); return this; } - + /** Deprecated. Use {@link #addRequesterPlugin(RSocketInterceptor)} instead */ + @Deprecated public ServerRSocketFactory addClientPlugin(RSocketInterceptor interceptor) { - plugins.addClientPlugin(interceptor); + return addRequesterPlugin(interceptor); + } + + public ServerRSocketFactory addRequesterPlugin(RSocketInterceptor interceptor) { + plugins.addRequesterPlugin(interceptor); return this; } + /** Deprecated. Use {@link #addResponderPlugin(RSocketInterceptor)} instead */ + @Deprecated public ServerRSocketFactory addServerPlugin(RSocketInterceptor interceptor) { - plugins.addServerPlugin(interceptor); + return addResponderPlugin(interceptor); + } + + public ServerRSocketFactory addResponderPlugin(RSocketInterceptor interceptor) { + plugins.addResponderPlugin(interceptor); return this; } @@ -525,29 +547,29 @@ private Mono acceptSetup( (keepAliveHandler, wrappedMultiplexer) -> { ConnectionSetupPayload setupPayload = ConnectionSetupPayload.create(setupFrame); - RSocketClient rSocketClient = - new RSocketClient( + RSocketRequester rSocketRequester = + new RSocketRequester( allocator, wrappedMultiplexer.asServerConnection(), payloadDecoder, errorConsumer, StreamIdSupplier.serverSupplier()); - RSocket wrappedRSocketClient = plugins.applyClient(rSocketClient); + RSocket wrappedRSocketRequester = plugins.applyRequester(rSocketRequester); return acceptor - .accept(setupPayload, wrappedRSocketClient) + .accept(setupPayload, wrappedRSocketRequester) .onErrorResume( err -> sendError(multiplexer, rejectedSetupError(err)).then(Mono.error(err))) .doOnNext( - unwrappedServerSocket -> { - RSocket wrappedRSocketServer = plugins.applyServer(unwrappedServerSocket); + rSocketHandler -> { + RSocket wrappedRSocketHandler = plugins.applyResponder(rSocketHandler); - RSocketServer rSocketServer = - new RSocketServer( + RSocketResponder rSocketResponder = + new RSocketResponder( allocator, wrappedMultiplexer.asClientConnection(), - wrappedRSocketServer, + wrappedRSocketHandler, payloadDecoder, errorConsumer, setupPayload.keepAliveInterval(), diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketClient.java b/rsocket-core/src/main/java/io/rsocket/RSocketRequester.java similarity index 99% rename from rsocket-core/src/main/java/io/rsocket/RSocketClient.java rename to rsocket-core/src/main/java/io/rsocket/RSocketRequester.java index 539e104b4..9c01db295 100644 --- a/rsocket-core/src/main/java/io/rsocket/RSocketClient.java +++ b/rsocket-core/src/main/java/io/rsocket/RSocketRequester.java @@ -45,8 +45,10 @@ import org.reactivestreams.Subscriber; import reactor.core.publisher.*; -/** Client Side of a RSocket socket. Sends {@link ByteBuf}s to a {@link RSocketServer} */ -class RSocketClient implements RSocket { +/** + * Requester Side of a RSocket socket. Sends {@link ByteBuf}s to a {@link RSocketResponder} of peer + */ +class RSocketRequester implements RSocket { private final DuplexConnection connection; private final PayloadDecoder payloadDecoder; @@ -60,7 +62,7 @@ class RSocketClient implements RSocket { private final KeepAliveFramesAcceptor keepAliveFramesAcceptor; /*client requester*/ - RSocketClient( + RSocketRequester( ByteBufAllocator allocator, DuplexConnection connection, PayloadDecoder payloadDecoder, @@ -99,7 +101,7 @@ class RSocketClient implements RSocket { } /*server requester*/ - RSocketClient( + RSocketRequester( ByteBufAllocator allocator, DuplexConnection connection, PayloadDecoder payloadDecoder, diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketServer.java b/rsocket-core/src/main/java/io/rsocket/RSocketResponder.java similarity index 98% rename from rsocket-core/src/main/java/io/rsocket/RSocketServer.java rename to rsocket-core/src/main/java/io/rsocket/RSocketResponder.java index 6ce5ef88d..f861303a4 100644 --- a/rsocket-core/src/main/java/io/rsocket/RSocketServer.java +++ b/rsocket-core/src/main/java/io/rsocket/RSocketResponder.java @@ -43,8 +43,8 @@ import reactor.core.Exceptions; import reactor.core.publisher.*; -/** Server side RSocket. Receives {@link ByteBuf}s from a {@link RSocketClient} */ -class RSocketServer implements ResponderRSocket { +/** Responder side of RSocket. Receives {@link ByteBuf}s from a peer's {@link RSocketRequester} */ +class RSocketResponder implements ResponderRSocket { private final DuplexConnection connection; private final RSocket requestHandler; @@ -61,7 +61,7 @@ class RSocketServer implements ResponderRSocket { private final KeepAliveFramesAcceptor keepAliveFramesAcceptor; /*client responder*/ - RSocketServer( + RSocketResponder( ByteBufAllocator allocator, DuplexConnection connection, RSocket requestHandler, @@ -71,7 +71,7 @@ class RSocketServer implements ResponderRSocket { } /*server responder*/ - RSocketServer( + RSocketResponder( ByteBufAllocator allocator, DuplexConnection connection, RSocket requestHandler, diff --git a/rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationDuplexConnection.java b/rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationDuplexConnection.java index 68e15fe70..cbe989d4b 100644 --- a/rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationDuplexConnection.java +++ b/rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationDuplexConnection.java @@ -26,6 +26,7 @@ import io.rsocket.frame.FrameLengthFlyweight; import io.rsocket.frame.FrameType; import java.util.Objects; +import javax.annotation.Nullable; import org.reactivestreams.Publisher; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -57,13 +58,10 @@ public FragmentationDuplexConnection( String type) { Objects.requireNonNull(delegate, "delegate must not be null"); Objects.requireNonNull(allocator, "byteBufAllocator must not be null"); - if (mtu < MIN_MTU_SIZE) { - throw new IllegalArgumentException("smallest allowed mtu size is " + MIN_MTU_SIZE + " bytes"); - } this.encodeLength = encodeLength; this.allocator = allocator; this.delegate = delegate; - this.mtu = mtu; + this.mtu = assertMtu(mtu); this.frameReassembler = new FrameReassembler(allocator); this.type = type; @@ -74,6 +72,32 @@ private boolean shouldFragment(FrameType frameType, int readableBytes) { return frameType.isFragmentable() && readableBytes > mtu; } + /*TODO this is nullable and not returning empty to workaround javac 11.0.3 compiler issue on ubuntu (at least) */ + @Nullable + public static Mono checkMtu(int mtu) { + if (isInsufficientMtu(mtu)) { + String msg = + String.format("smallest allowed mtu size is %d bytes, provided: %d", MIN_MTU_SIZE, mtu); + return Mono.error(new IllegalArgumentException(msg)); + } else { + return null; + } + } + + private static int assertMtu(int mtu) { + if (isInsufficientMtu(mtu)) { + String msg = + String.format("smallest allowed mtu size is %d bytes, provided: %d", MIN_MTU_SIZE, mtu); + throw new IllegalArgumentException(msg); + } else { + return mtu; + } + } + + private static boolean isInsufficientMtu(int mtu) { + return mtu > 0 && mtu < MIN_MTU_SIZE || mtu < 0; + } + @Override public Mono send(Publisher frames) { return Flux.from(frames).concatMap(this::sendOne).then(); @@ -89,13 +113,13 @@ public Mono sendOne(ByteBuf frame) { Flux.from(fragmentFrame(allocator, mtu, frame, frameType, encodeLength)) .doOnNext( byteBuf -> { - ByteBuf frame1 = FrameLengthFlyweight.frame(byteBuf); + ByteBuf f = encodeLength ? FrameLengthFlyweight.frame(byteBuf) : byteBuf; logger.debug( "{} - stream id {} - frame type {} - \n {}", type, - FrameHeaderFlyweight.streamId(frame1), - FrameHeaderFlyweight.frameType(frame1), - ByteBufUtil.prettyHexDump(frame1)); + FrameHeaderFlyweight.streamId(f), + FrameHeaderFlyweight.frameType(f), + ByteBufUtil.prettyHexDump(f)); })); } else { return delegate.send( @@ -108,7 +132,7 @@ public Mono sendOne(ByteBuf frame) { private ByteBuf encode(ByteBuf frame) { if (encodeLength) { - return FrameLengthFlyweight.encode(allocator, frame.readableBytes(), frame).retain(); + return FrameLengthFlyweight.encode(allocator, frame.readableBytes(), frame); } else { return frame; } diff --git a/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameReassembler.java b/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameReassembler.java index 4f65ebf63..5660e3615 100644 --- a/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameReassembler.java +++ b/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameReassembler.java @@ -151,6 +151,7 @@ void handleNoFollowsFlag(ByteBuf frame, SynchronousSink sink, int strea ByteBuf assembledFrame = FragmentationFlyweight.encode(allocator, header, data); sink.next(assembledFrame); } + frame.release(); } else { sink.next(frame); } @@ -220,9 +221,8 @@ void handleFollowsFlag(ByteBuf frame, int streamId, FrameType frameType) { throw new IllegalStateException("unsupported fragment type"); } - if (data != Unpooled.EMPTY_BUFFER) { - getData(streamId).addComponents(true, data); - } + getData(streamId).addComponents(true, data); + frame.release(); } void reassembleFrame(ByteBuf frame, SynchronousSink sink) { @@ -259,25 +259,21 @@ private ByteBuf assembleFrameWithMetadata(ByteBuf frame, int streamId, ByteBuf h ByteBuf metadata; CompositeByteBuf cm = removeMetadata(streamId); if (cm != null) { - ByteBuf m = PayloadFrameFlyweight.metadata(frame); - metadata = cm.addComponents(true, m); + metadata = cm.addComponents(true, PayloadFrameFlyweight.metadata(frame).retain()); } else { - metadata = PayloadFrameFlyweight.metadata(frame); + metadata = PayloadFrameFlyweight.metadata(frame).retain(); } ByteBuf data = assembleData(frame, streamId); - return FragmentationFlyweight.encode(allocator, header, metadata.retain(), data); + return FragmentationFlyweight.encode(allocator, header, metadata, data); } private ByteBuf assembleData(ByteBuf frame, int streamId) { ByteBuf data; CompositeByteBuf cd = removeData(streamId); if (cd != null) { - ByteBuf d = PayloadFrameFlyweight.data(frame); - if (d != null) { - cd.addComponents(true, d); - } + cd.addComponents(true, PayloadFrameFlyweight.data(frame).retain()); data = cd; } else { data = Unpooled.EMPTY_BUFFER; diff --git a/rsocket-core/src/main/java/io/rsocket/frame/FrameUtil.java b/rsocket-core/src/main/java/io/rsocket/frame/FrameUtil.java index 417e44857..a4d862135 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/FrameUtil.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/FrameUtil.java @@ -70,7 +70,7 @@ private static ByteBuf getMetadata(ByteBuf frame, FrameType frameType) { default: return Unpooled.EMPTY_BUFFER; } - return metadata.retain(); + return metadata; } else { return Unpooled.EMPTY_BUFFER; } diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/CompositeMetadata.java b/rsocket-core/src/main/java/io/rsocket/metadata/CompositeMetadata.java new file mode 100644 index 000000000..9eb349396 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/CompositeMetadata.java @@ -0,0 +1,220 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.metadata; + +import static io.rsocket.metadata.CompositeMetadataFlyweight.computeNextEntryIndex; +import static io.rsocket.metadata.CompositeMetadataFlyweight.decodeMimeAndContentBuffersSlices; +import static io.rsocket.metadata.CompositeMetadataFlyweight.decodeMimeIdFromMimeBuffer; +import static io.rsocket.metadata.CompositeMetadataFlyweight.decodeMimeTypeFromMimeBuffer; +import static io.rsocket.metadata.CompositeMetadataFlyweight.hasEntry; +import static io.rsocket.metadata.CompositeMetadataFlyweight.isWellKnownMimeType; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.rsocket.metadata.CompositeMetadata.Entry; +import java.util.Iterator; +import reactor.util.annotation.Nullable; + +/** + * An {@link Iterable} wrapper around a {@link ByteBuf} that exposes metadata entry information at + * each decoding step. This is only possible on frame types used to initiate interactions, if the + * SETUP metadata mime type was {@link WellKnownMimeType#MESSAGE_RSOCKET_COMPOSITE_METADATA}. + * + *

This allows efficient incremental decoding of the entries (without moving the source's {@link + * io.netty.buffer.ByteBuf#readerIndex()}). The buffer is assumed to contain just enough bytes to + * represent one or more entries (mime type compressed or not). The decoding stops when the buffer + * reaches 0 readable bytes, and fails if it contains bytes but not enough to correctly decode an + * entry. + * + *

A note on future-proofness: it is possible to come across a compressed mime type that this + * implementation doesn't recognize. This is likely to be due to the use of a byte id that is merely + * reserved in this implementation, but maps to a {@link WellKnownMimeType} in the implementation + * that encoded the metadata. This can be detected by detecting that an entry is a {@link + * ReservedMimeTypeEntry}. In this case {@link Entry#getMimeType()} will return {@code null}. The + * encoded id can be retrieved using {@link ReservedMimeTypeEntry#getType()}. The byte and content + * buffer should be kept around and re-encoded using {@link + * CompositeMetadataFlyweight#encodeAndAddMetadata(CompositeByteBuf, ByteBufAllocator, byte, + * ByteBuf)} in case passing that entry through is required. + */ +public final class CompositeMetadata implements Iterable { + + private final boolean retainSlices; + + private final ByteBuf source; + + public CompositeMetadata(ByteBuf source, boolean retainSlices) { + this.source = source; + this.retainSlices = retainSlices; + } + + @Override + public Iterator iterator() { + return new Iterator() { + + private int entryIndex = 0; + + @Override + public boolean hasNext() { + return hasEntry(CompositeMetadata.this.source, this.entryIndex); + } + + @Override + public Entry next() { + ByteBuf[] headerAndData = + decodeMimeAndContentBuffersSlices( + CompositeMetadata.this.source, + this.entryIndex, + CompositeMetadata.this.retainSlices); + + ByteBuf header = headerAndData[0]; + ByteBuf data = headerAndData[1]; + + this.entryIndex = computeNextEntryIndex(this.entryIndex, header, data); + + if (!isWellKnownMimeType(header)) { + CharSequence typeString = decodeMimeTypeFromMimeBuffer(header); + if (typeString == null) { + throw new IllegalStateException("MIME type cannot be null"); + } + + return new ExplicitMimeTimeEntry(data, typeString.toString()); + } + + byte id = decodeMimeIdFromMimeBuffer(header); + WellKnownMimeType type = WellKnownMimeType.fromIdentifier(id); + + if (WellKnownMimeType.UNKNOWN_RESERVED_MIME_TYPE == type) { + return new ReservedMimeTypeEntry(data, id); + } + + return new WellKnownMimeTypeEntry(data, type); + } + }; + } + + /** An entry in the {@link CompositeMetadata}. */ + public interface Entry { + + /** + * Returns the un-decoded content of the {@link Entry}. + * + * @return the un-decoded content of the {@link Entry} + */ + ByteBuf getContent(); + + /** + * Returns the MIME type of the entry, if it can be decoded. + * + * @return the MIME type of the entry, if it can be decoded, otherwise {@code null}. + */ + @Nullable + String getMimeType(); + } + + /** An {@link Entry} backed by an explicitly declared MIME type. */ + public static final class ExplicitMimeTimeEntry implements Entry { + + private final ByteBuf content; + + private final String type; + + public ExplicitMimeTimeEntry(ByteBuf content, String type) { + this.content = content; + this.type = type; + } + + @Override + public ByteBuf getContent() { + return this.content; + } + + @Override + public String getMimeType() { + return this.type; + } + } + + /** + * An {@link Entry} backed by a {@link WellKnownMimeType} entry, but one that is not understood by + * this implementation. + */ + public static final class ReservedMimeTypeEntry implements Entry { + private final ByteBuf content; + private final int type; + + public ReservedMimeTypeEntry(ByteBuf content, int type) { + this.content = content; + this.type = type; + } + + @Override + public ByteBuf getContent() { + return this.content; + } + + /** + * {@inheritDoc} Since this entry represents a compressed id that couldn't be decoded, this is + * always {@code null}. + */ + @Override + public String getMimeType() { + return null; + } + + /** + * Returns the reserved, but unknown {@link WellKnownMimeType} for this entry. Range is 0-127 + * (inclusive). + * + * @return the reserved, but unknown {@link WellKnownMimeType} for this entry + */ + public int getType() { + return this.type; + } + } + + /** An {@link Entry} backed by a {@link WellKnownMimeType}. */ + public static final class WellKnownMimeTypeEntry implements Entry { + + private final ByteBuf content; + private final WellKnownMimeType type; + + public WellKnownMimeTypeEntry(ByteBuf content, WellKnownMimeType type) { + this.content = content; + this.type = type; + } + + @Override + public ByteBuf getContent() { + return this.content; + } + + @Override + public String getMimeType() { + return this.type.getString(); + } + + /** + * Returns the {@link WellKnownMimeType} for this entry. + * + * @return the {@link WellKnownMimeType} for this entry + */ + public WellKnownMimeType getType() { + return this.type; + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/CompositeMetadataFlyweight.java b/rsocket-core/src/main/java/io/rsocket/metadata/CompositeMetadataFlyweight.java new file mode 100644 index 000000000..9abd638cb --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/CompositeMetadataFlyweight.java @@ -0,0 +1,383 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.metadata; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.CompositeByteBuf; +import io.netty.util.CharsetUtil; +import io.rsocket.util.NumberUtils; +import reactor.util.annotation.Nullable; + +/** + * A flyweight class that can be used to encode/decode composite metadata information to/from {@link + * ByteBuf}. This is intended for low-level efficient manipulation of such buffers. See {@link + * CompositeMetadata} for an Iterator-like approach to decoding entries. + */ +public class CompositeMetadataFlyweight { + + static final int STREAM_METADATA_KNOWN_MASK = 0x80; // 1000 0000 + + static final byte STREAM_METADATA_LENGTH_MASK = 0x7F; // 0111 1111 + + private CompositeMetadataFlyweight() {} + + public static int computeNextEntryIndex( + int currentEntryIndex, ByteBuf headerSlice, ByteBuf contentSlice) { + return currentEntryIndex + + headerSlice.readableBytes() // this includes the mime length byte + + 3 // 3 bytes of the content length, which are excluded from the slice + + contentSlice.readableBytes(); + } + + /** + * Decode the next metadata entry (a mime header + content pair of {@link ByteBuf}) from a {@link + * ByteBuf} that contains at least enough bytes for one more such entry. These buffers are + * actually slices of the full metadata buffer, and this method doesn't move the full metadata + * buffer's {@link ByteBuf#readerIndex()}. As such, it requires the user to provide an {@code + * index} to read from. The next index is computed by calling {@link #computeNextEntryIndex(int, + * ByteBuf, ByteBuf)}. Size of the first buffer (the "header buffer") drives which decoding method + * should be further applied to it. + * + *

The header buffer is either: + * + *

    + *
  • made up of a single byte: this represents an encoded mime id, which can be further + * decoded using {@link #decodeMimeIdFromMimeBuffer(ByteBuf)} + *
  • made up of 2 or more bytes: this represents an encoded mime String + its length, which + * can be further decoded using {@link #decodeMimeTypeFromMimeBuffer(ByteBuf)}. Note the + * encoded length, in the first byte, is skipped by this decoding method because the + * remaining length of the buffer is that of the mime string. + *
+ * + * @param compositeMetadata the source {@link ByteBuf} that originally contains one or more + * metadata entries + * @param entryIndex the {@link ByteBuf#readerIndex()} to start decoding from. original reader + * index is kept on the source buffer + * @param retainSlices should produced metadata entry buffers {@link ByteBuf#slice() slices} be + * {@link ByteBuf#retainedSlice() retained}? + * @return a {@link ByteBuf} array of length 2 containing the mime header buffer + * slice and the content buffer slice, or one of the + * zero-length error constant arrays + */ + public static ByteBuf[] decodeMimeAndContentBuffersSlices( + ByteBuf compositeMetadata, int entryIndex, boolean retainSlices) { + compositeMetadata.markReaderIndex(); + compositeMetadata.readerIndex(entryIndex); + + if (compositeMetadata.isReadable()) { + ByteBuf mime; + int ridx = compositeMetadata.readerIndex(); + byte mimeIdOrLength = compositeMetadata.readByte(); + if ((mimeIdOrLength & STREAM_METADATA_KNOWN_MASK) == STREAM_METADATA_KNOWN_MASK) { + mime = + retainSlices + ? compositeMetadata.retainedSlice(ridx, 1) + : compositeMetadata.slice(ridx, 1); + } else { + // M flag unset, remaining 7 bits are the length of the mime + int mimeLength = Byte.toUnsignedInt(mimeIdOrLength) + 1; + + if (compositeMetadata.isReadable( + mimeLength)) { // need to be able to read an extra mimeLength bytes + // here we need a way for the returned ByteBuf to differentiate between a + // 1-byte length mime type and a 1 byte encoded mime id, preferably without + // re-applying the byte mask. The easiest way is to include the initial byte + // and have further decoding ignore the first byte. 1 byte buffer == id, 2+ byte + // buffer == full mime string. + mime = + retainSlices + ? + // we accommodate that we don't read from current readerIndex, but + // readerIndex - 1 ("0"), for a total slice size of mimeLength + 1 + compositeMetadata.retainedSlice(ridx, mimeLength + 1) + : compositeMetadata.slice(ridx, mimeLength + 1); + // we thus need to skip the bytes we just sliced, but not the flag/length byte + // which was already skipped in initial read + compositeMetadata.skipBytes(mimeLength); + } else { + compositeMetadata.resetReaderIndex(); + throw new IllegalStateException("metadata is malformed"); + } + } + + if (compositeMetadata.isReadable(3)) { + // ensures the length medium can be read + final int metadataLength = compositeMetadata.readUnsignedMedium(); + if (compositeMetadata.isReadable(metadataLength)) { + ByteBuf metadata = + retainSlices + ? compositeMetadata.readRetainedSlice(metadataLength) + : compositeMetadata.readSlice(metadataLength); + compositeMetadata.resetReaderIndex(); + return new ByteBuf[] {mime, metadata}; + } else { + compositeMetadata.resetReaderIndex(); + throw new IllegalStateException("metadata is malformed"); + } + } else { + compositeMetadata.resetReaderIndex(); + throw new IllegalStateException("metadata is malformed"); + } + } + compositeMetadata.resetReaderIndex(); + throw new IllegalArgumentException( + String.format("entry index %d is larger than buffer size", entryIndex)); + } + + /** + * Decode a {@code byte} compressed mime id from a {@link ByteBuf}, assuming said buffer properly + * contains such an id. + * + *

The buffer must have exactly one readable byte, which is assumed to have been tested for + * mime id encoding via the {@link #STREAM_METADATA_KNOWN_MASK} mask ({@code firstByte & + * STREAM_METADATA_KNOWN_MASK) == STREAM_METADATA_KNOWN_MASK}). + * + *

If there is no readable byte, the negative identifier of {@link + * WellKnownMimeType#UNPARSEABLE_MIME_TYPE} is returned. + * + * @param mimeBuffer the buffer that should next contain the compressed mime id byte + * @return the compressed mime id, between 0 and 127, or a negative id if the input is invalid + * @see #decodeMimeTypeFromMimeBuffer(ByteBuf) + */ + public static byte decodeMimeIdFromMimeBuffer(ByteBuf mimeBuffer) { + if (mimeBuffer.readableBytes() != 1) { + return WellKnownMimeType.UNPARSEABLE_MIME_TYPE.getIdentifier(); + } + return (byte) (mimeBuffer.readByte() & STREAM_METADATA_LENGTH_MASK); + } + + /** + * Decode a {@link CharSequence} custome mime type from a {@link ByteBuf}, assuming said buffer + * properly contains such a mime type. + * + *

The buffer must at least have two readable bytes, which distinguishes it from the {@link + * #decodeMimeIdFromMimeBuffer(ByteBuf) compressed id} case. The first byte is a size and the + * remaining bytes must correspond to the {@link CharSequence}, encoded fully in US_ASCII. As a + * result, the first byte can simply be skipped, and the remaining of the buffer be decoded to the + * mime type. + * + *

If the mime header buffer is less than 2 bytes long, returns {@code null}. + * + * @param flyweightMimeBuffer the mime header {@link ByteBuf} that contains length + custom mime + * type + * @return the decoded custom mime type, as a {@link CharSequence}, or null if the input is + * invalid + * @see #decodeMimeIdFromMimeBuffer(ByteBuf) + */ + @Nullable + public static CharSequence decodeMimeTypeFromMimeBuffer(ByteBuf flyweightMimeBuffer) { + if (flyweightMimeBuffer.readableBytes() < 2) { + throw new IllegalStateException("unable to decode explicit MIME type"); + } + // the encoded length is assumed to be kept at the start of the buffer + // but also assumed to be irrelevant because the rest of the slice length + // actually already matches _decoded_length + flyweightMimeBuffer.skipBytes(1); + int mimeStringLength = flyweightMimeBuffer.readableBytes(); + return flyweightMimeBuffer.readCharSequence(mimeStringLength, CharsetUtil.US_ASCII); + } + + /** + * Encode a new sub-metadata information into a composite metadata {@link CompositeByteBuf + * buffer}, without checking if the {@link String} can be matched with a well known compressable + * mime type. Prefer using this method and {@link #encodeAndAddMetadata(CompositeByteBuf, + * ByteBufAllocator, WellKnownMimeType, ByteBuf)} if you know in advance whether or not the mime + * is well known. Otherwise use {@link #encodeAndAddMetadataWithCompression(CompositeByteBuf, + * ByteBufAllocator, String, ByteBuf)} + * + * @param compositeMetaData the buffer that will hold all composite metadata information. + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param customMimeType the custom mime type to encode. + * @param metadata the metadata value to encode. + */ + // see #encodeMetadataHeader(ByteBufAllocator, String, int) + public static void encodeAndAddMetadata( + CompositeByteBuf compositeMetaData, + ByteBufAllocator allocator, + String customMimeType, + ByteBuf metadata) { + compositeMetaData.addComponents( + true, encodeMetadataHeader(allocator, customMimeType, metadata.readableBytes()), metadata); + } + + /** + * Encode a new sub-metadata information into a composite metadata {@link CompositeByteBuf + * buffer}. + * + * @param compositeMetaData the buffer that will hold all composite metadata information. + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param knownMimeType the {@link WellKnownMimeType} to encode. + * @param metadata the metadata value to encode. + */ + // see #encodeMetadataHeader(ByteBufAllocator, byte, int) + public static void encodeAndAddMetadata( + CompositeByteBuf compositeMetaData, + ByteBufAllocator allocator, + WellKnownMimeType knownMimeType, + ByteBuf metadata) { + compositeMetaData.addComponents( + true, + encodeMetadataHeader(allocator, knownMimeType.getIdentifier(), metadata.readableBytes()), + metadata); + } + + /** + * Encode a new sub-metadata information into a composite metadata {@link CompositeByteBuf + * buffer}, first verifying if the passed {@link String} matches a {@link WellKnownMimeType} (in + * which case it will be encoded in a compressed fashion using the mime id of that type). + * + *

Prefer using {@link #encodeAndAddMetadata(CompositeByteBuf, ByteBufAllocator, String, + * ByteBuf)} if you already know that the mime type is not a {@link WellKnownMimeType}. + * + * @param compositeMetaData the buffer that will hold all composite metadata information. + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param mimeType the mime type to encode, as a {@link String}. well known mime types are + * compressed. + * @param metadata the metadata value to encode. + * @see #encodeAndAddMetadata(CompositeByteBuf, ByteBufAllocator, WellKnownMimeType, ByteBuf) + */ + // see #encodeMetadataHeader(ByteBufAllocator, String, int) + public static void encodeAndAddMetadataWithCompression( + CompositeByteBuf compositeMetaData, + ByteBufAllocator allocator, + String mimeType, + ByteBuf metadata) { + WellKnownMimeType wkn = WellKnownMimeType.fromString(mimeType); + if (wkn == WellKnownMimeType.UNPARSEABLE_MIME_TYPE) { + compositeMetaData.addComponents( + true, encodeMetadataHeader(allocator, mimeType, metadata.readableBytes()), metadata); + } else { + compositeMetaData.addComponents( + true, + encodeMetadataHeader(allocator, wkn.getIdentifier(), metadata.readableBytes()), + metadata); + } + } + + /** + * Returns whether there is another entry available at a given index + * + * @param compositeMetadata the buffer to inspect + * @param entryIndex the index to check at + * @return whether there is another entry available at a given index + */ + public static boolean hasEntry(ByteBuf compositeMetadata, int entryIndex) { + return compositeMetadata.writerIndex() - entryIndex > 0; + } + + /** + * Returns whether the header represents a well-known MIME type. + * + * @param header the header to inspect + * @return whether the header represents a well-known MIME type + */ + public static boolean isWellKnownMimeType(ByteBuf header) { + return header.readableBytes() == 1; + } + + /** + * Encode a new sub-metadata information into a composite metadata {@link CompositeByteBuf + * buffer}. + * + * @param compositeMetaData the buffer that will hold all composite metadata information. + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param unknownCompressedMimeType the id of the {@link + * WellKnownMimeType#UNKNOWN_RESERVED_MIME_TYPE} to encode. + * @param metadata the metadata value to encode. + */ + // see #encodeMetadataHeader(ByteBufAllocator, byte, int) + static void encodeAndAddMetadata( + CompositeByteBuf compositeMetaData, + ByteBufAllocator allocator, + byte unknownCompressedMimeType, + ByteBuf metadata) { + compositeMetaData.addComponents( + true, + encodeMetadataHeader(allocator, unknownCompressedMimeType, metadata.readableBytes()), + metadata); + } + + /** + * Encode a custom mime type and a metadata value length into a newly allocated {@link ByteBuf}. + * + *

This larger representation encodes the mime type representation's length on a single byte, + * then the representation itself, then the unsigned metadata value length on 3 additional bytes. + * + * @param allocator the {@link ByteBufAllocator} to use to create the buffer. + * @param customMime a custom mime type to encode. + * @param metadataLength the metadata length to append to the buffer as an unsigned 24 bits + * integer. + * @return the encoded mime and metadata length information + */ + static ByteBuf encodeMetadataHeader( + ByteBufAllocator allocator, String customMime, int metadataLength) { + ByteBuf metadataHeader = allocator.buffer(4 + customMime.length()); + // reserve 1 byte for the customMime length + int writerIndexInitial = metadataHeader.writerIndex(); + metadataHeader.writerIndex(writerIndexInitial + 1); + + // write the custom mime in UTF8 but validate it is all ASCII-compatible + // (which produces the right result since ASCII chars are still encoded on 1 byte in UTF8) + int customMimeLength = ByteBufUtil.writeUtf8(metadataHeader, customMime); + if (!ByteBufUtil.isText(metadataHeader, CharsetUtil.US_ASCII)) { + metadataHeader.release(); + throw new IllegalArgumentException("custom mime type must be US_ASCII characters only"); + } + if (customMimeLength < 1 || customMimeLength > 128) { + metadataHeader.release(); + throw new IllegalArgumentException( + "custom mime type must have a strictly positive length that fits on 7 unsigned bits, ie 1-128"); + } + metadataHeader.markWriterIndex(); + + // go back to beginning and write the length + // encoded length is one less than actual length, since 0 is never a valid length, which gives + // wider representation range + metadataHeader.writerIndex(writerIndexInitial); + metadataHeader.writeByte(customMimeLength - 1); + + // go back to post-mime type and write the metadata content length + metadataHeader.resetWriterIndex(); + NumberUtils.encodeUnsignedMedium(metadataHeader, metadataLength); + + return metadataHeader; + } + + /** + * Encode a {@link WellKnownMimeType well known mime type} and a metadata value length into a + * newly allocated {@link ByteBuf}. + * + *

This compact representation encodes the mime type via its ID on a single byte, and the + * unsigned value length on 3 additional bytes. + * + * @param allocator the {@link ByteBufAllocator} to use to create the buffer. + * @param mimeType a byte identifier of a {@link WellKnownMimeType} to encode. + * @param metadataLength the metadata length to append to the buffer as an unsigned 24 bits + * integer. + * @return the encoded mime and metadata length information + */ + static ByteBuf encodeMetadataHeader( + ByteBufAllocator allocator, byte mimeType, int metadataLength) { + ByteBuf buffer = allocator.buffer(4, 4).writeByte(mimeType | STREAM_METADATA_KNOWN_MASK); + + NumberUtils.encodeUnsignedMedium(buffer, metadataLength); + + return buffer; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/WellKnownMimeType.java b/rsocket-core/src/main/java/io/rsocket/metadata/WellKnownMimeType.java new file mode 100644 index 000000000..9ecaf0859 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/WellKnownMimeType.java @@ -0,0 +1,162 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.metadata; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; + +/** + * Enumeration of Well Known Mime Types, as defined in the eponymous extension. Such mime types are + * used in composite metadata (which can include routing and/or tracing metadata). Per + * specification, identifiers are between 0 and 127 (inclusive). + */ +public enum WellKnownMimeType { + UNPARSEABLE_MIME_TYPE("UNPARSEABLE_MIME_TYPE_DO_NOT_USE", (byte) -2), + UNKNOWN_RESERVED_MIME_TYPE("UNKNOWN_YET_RESERVED_DO_NOT_USE", (byte) -1), + + APPLICATION_AVRO("application/avro", (byte) 0x00), + APPLICATION_CBOR("application/cbor", (byte) 0x01), + APPLICATION_GRAPHQL("application/graphql", (byte) 0x02), + APPLICATION_GZIP("application/gzip", (byte) 0x03), + APPLICATION_JAVASCRIPT("application/javascript", (byte) 0x04), + APPLICATION_JSON("application/json", (byte) 0x05), + APPLICATION_OCTET_STREAM("application/octet-stream", (byte) 0x06), + APPLICATION_PDF("application/pdf", (byte) 0x07), + APPLICATION_THRIFT("application/vnd.apache.thrift.binary", (byte) 0x08), + APPLICATION_PROTOBUF("application/vnd.google.protobuf", (byte) 0x09), + APPLICATION_XML("application/xml", (byte) 0x0A), + APPLICATION_ZIP("application/zip", (byte) 0x0B), + AUDIO_AAC("audio/aac", (byte) 0x0C), + AUDIO_MP3("audio/mp3", (byte) 0x0D), + AUDIO_MP4("audio/mp4", (byte) 0x0E), + AUDIO_MPEG3("audio/mpeg3", (byte) 0x0F), + AUDIO_MPEG("audio/mpeg", (byte) 0x10), + AUDIO_OGG("audio/ogg", (byte) 0x11), + AUDIO_OPUS("audio/opus", (byte) 0x12), + AUDIO_VORBIS("audio/vorbis", (byte) 0x13), + IMAGE_BMP("image/bmp", (byte) 0x14), + IMAGE_GIG("image/gif", (byte) 0x15), + IMAGE_HEIC_SEQUENCE("image/heic-sequence", (byte) 0x16), + IMAGE_HEIC("image/heic", (byte) 0x17), + IMAGE_HEIF_SEQUENCE("image/heif-sequence", (byte) 0x18), + IMAGE_HEIF("image/heif", (byte) 0x19), + IMAGE_JPEG("image/jpeg", (byte) 0x1A), + IMAGE_PNG("image/png", (byte) 0x1B), + IMAGE_TIFF("image/tiff", (byte) 0x1C), + MULTIPART_MIXED("multipart/mixed", (byte) 0x1D), + TEXT_CSS("text/css", (byte) 0x1E), + TEXT_CSV("text/csv", (byte) 0x1F), + TEXT_HTML("text/html", (byte) 0x20), + TEXT_PLAIN("text/plain", (byte) 0x21), + TEXT_XML("text/xml", (byte) 0x22), + VIDEO_H264("video/H264", (byte) 0x23), + VIDEO_H265("video/H265", (byte) 0x24), + VIDEO_VP8("video/VP8", (byte) 0x25), + + // ... reserved for future use ... + + MESSAGE_RSOCKET_TRACING_ZIPKIN("message/x.rsocket.tracing-zipkin.v0", (byte) 0x7D), + MESSAGE_RSOCKET_ROUTING("message/x.rsocket.routing.v0", (byte) 0x7E), + MESSAGE_RSOCKET_COMPOSITE_METADATA("message/x.rsocket.composite-metadata.v0", (byte) 0x7F); + + static final WellKnownMimeType[] TYPES_BY_MIME_ID; + static final Map TYPES_BY_MIME_STRING; + + static { + // precompute an array of all valid mime ids, filling the blanks with the RESERVED enum + TYPES_BY_MIME_ID = new WellKnownMimeType[128]; // 0-127 inclusive + Arrays.fill(TYPES_BY_MIME_ID, UNKNOWN_RESERVED_MIME_TYPE); + // also prepare a Map of the types by mime string + TYPES_BY_MIME_STRING = new HashMap<>(128); + + for (WellKnownMimeType value : values()) { + if (value.getIdentifier() >= 0) { + TYPES_BY_MIME_ID[value.getIdentifier()] = value; + TYPES_BY_MIME_STRING.put(value.getString(), value); + } + } + } + + private final byte identifier; + private final String str; + + WellKnownMimeType(String str, byte identifier) { + this.str = str; + this.identifier = identifier; + } + + /** + * Find the {@link WellKnownMimeType} for the given identifier (as an {@code int}). Valid + * identifiers are defined to be integers between 0 and 127, inclusive. Identifiers outside of + * this range will produce the {@link #UNPARSEABLE_MIME_TYPE}. Additionally, some identifiers in + * that range are still only reserved and don't have a type associated yet: this method returns + * the {@link #UNKNOWN_RESERVED_MIME_TYPE} when passing such an identifier, which lets call sites + * potentially detect this and keep the original representation when transmitting the associated + * metadata buffer. + * + * @param id the looked up identifier + * @return the {@link WellKnownMimeType}, or {@link #UNKNOWN_RESERVED_MIME_TYPE} if the id is out + * of the specification's range, or {@link #UNKNOWN_RESERVED_MIME_TYPE} if the id is one that + * is merely reserved but unknown to this implementation. + */ + public static WellKnownMimeType fromIdentifier(int id) { + if (id < 0x00 || id > 0x7F) { + return UNPARSEABLE_MIME_TYPE; + } + return TYPES_BY_MIME_ID[id]; + } + + /** + * Find the {@link WellKnownMimeType} for the given {@link String} representation. If the + * representation is {@code null} or doesn't match a {@link WellKnownMimeType}, the {@link + * #UNPARSEABLE_MIME_TYPE} is returned. + * + * @param mimeType the looked up mime type + * @return the matching {@link WellKnownMimeType}, or {@link #UNPARSEABLE_MIME_TYPE} if none + * matches + */ + public static WellKnownMimeType fromString(String mimeType) { + if (mimeType == null) throw new IllegalArgumentException("type must be non-null"); + + // force UNPARSEABLE if by chance UNKNOWN_RESERVED_MIME_TYPE's text has been used + if (mimeType.equals(UNKNOWN_RESERVED_MIME_TYPE.str)) { + return UNPARSEABLE_MIME_TYPE; + } + + return TYPES_BY_MIME_STRING.getOrDefault(mimeType, UNPARSEABLE_MIME_TYPE); + } + + /** @return the byte identifier of the mime type, guaranteed to be positive or zero. */ + public byte getIdentifier() { + return identifier; + } + + /** + * @return the mime type represented as a {@link String}, which is made of US_ASCII compatible + * characters only + */ + public String getString() { + return str; + } + + /** @see #getString() */ + @Override + public String toString() { + return str; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/PluginRegistry.java b/rsocket-core/src/main/java/io/rsocket/plugins/PluginRegistry.java index 873f6babb..676cfc19c 100644 --- a/rsocket-core/src/main/java/io/rsocket/plugins/PluginRegistry.java +++ b/rsocket-core/src/main/java/io/rsocket/plugins/PluginRegistry.java @@ -23,39 +23,63 @@ public class PluginRegistry { private List connections = new ArrayList<>(); - private List clients = new ArrayList<>(); - private List servers = new ArrayList<>(); + private List requesters = new ArrayList<>(); + private List responders = new ArrayList<>(); public PluginRegistry() {} public PluginRegistry(PluginRegistry defaults) { this.connections.addAll(defaults.connections); - this.clients.addAll(defaults.clients); - this.servers.addAll(defaults.servers); + this.requesters.addAll(defaults.requesters); + this.responders.addAll(defaults.responders); } public void addConnectionPlugin(DuplexConnectionInterceptor interceptor) { connections.add(interceptor); } + /** Deprecated. Use {@link #addRequesterPlugin(RSocketInterceptor)} instead */ + @Deprecated public void addClientPlugin(RSocketInterceptor interceptor) { - clients.add(interceptor); + addRequesterPlugin(interceptor); } + public void addRequesterPlugin(RSocketInterceptor interceptor) { + requesters.add(interceptor); + } + + /** Deprecated. Use {@link #addResponderPlugin(RSocketInterceptor)} instead */ + @Deprecated public void addServerPlugin(RSocketInterceptor interceptor) { - servers.add(interceptor); + addResponderPlugin(interceptor); + } + + public void addResponderPlugin(RSocketInterceptor interceptor) { + responders.add(interceptor); } + /** Deprecated. Use {@link #applyRequester(RSocket)} instead */ + @Deprecated public RSocket applyClient(RSocket rSocket) { - for (RSocketInterceptor i : clients) { + return applyRequester(rSocket); + } + + public RSocket applyRequester(RSocket rSocket) { + for (RSocketInterceptor i : requesters) { rSocket = i.apply(rSocket); } return rSocket; } + /** Deprecated. Use {@link #applyResponder(RSocket)} instead */ + @Deprecated public RSocket applyServer(RSocket rSocket) { - for (RSocketInterceptor i : servers) { + return applyResponder(rSocket); + } + + public RSocket applyResponder(RSocket rSocket) { + for (RSocketInterceptor i : responders) { rSocket = i.apply(rSocket); } diff --git a/rsocket-core/src/main/java/io/rsocket/util/NumberUtils.java b/rsocket-core/src/main/java/io/rsocket/util/NumberUtils.java index 12e3cee45..3ff720447 100644 --- a/rsocket-core/src/main/java/io/rsocket/util/NumberUtils.java +++ b/rsocket-core/src/main/java/io/rsocket/util/NumberUtils.java @@ -16,6 +16,7 @@ package io.rsocket.util; +import io.netty.buffer.ByteBuf; import java.util.Objects; public final class NumberUtils { @@ -143,4 +144,21 @@ public static int requireUnsignedShort(int i) { return i; } + + /** + * Encode an unsigned medium integer on 3 bytes / 24 bits. This can be decoded directly by the + * {@link ByteBuf#readUnsignedMedium()} method. + * + * @param byteBuf the {@link ByteBuf} into which to write the bits + * @param i the medium integer to encode + * @see #requireUnsignedMedium(int) + */ + public static void encodeUnsignedMedium(ByteBuf byteBuf, int i) { + requireUnsignedMedium(i); + // Write each byte separately in reverse order, this mean we can write 1 << 23 without + // overflowing. + byteBuf.writeByte(i >> 16); + byteBuf.writeByte(i >> 8); + byteBuf.writeByte(i); + } } diff --git a/rsocket-core/src/test/java/io/rsocket/KeepAliveTest.java b/rsocket-core/src/test/java/io/rsocket/KeepAliveTest.java index aededb804..106e61097 100644 --- a/rsocket-core/src/test/java/io/rsocket/KeepAliveTest.java +++ b/rsocket-core/src/test/java/io/rsocket/KeepAliveTest.java @@ -65,8 +65,8 @@ static Supplier requester(int tickPeriod, int timeout) { return () -> { TestDuplexConnection connection = new TestDuplexConnection(); Errors errors = new Errors(); - RSocketClient rSocket = - new RSocketClient( + RSocketRequester rSocket = + new RSocketRequester( ByteBufAllocator.DEFAULT, connection, DefaultPayload::create, @@ -84,8 +84,8 @@ static Supplier responder(int tickPeriod, int timeout) { TestDuplexConnection connection = new TestDuplexConnection(); AbstractRSocket handler = new AbstractRSocket() {}; Errors errors = new Errors(); - RSocketServer rSocket = - new RSocketServer( + RSocketResponder rSocket = + new RSocketResponder( ByteBufAllocator.DEFAULT, connection, handler, @@ -110,8 +110,8 @@ static Supplier resumableRequester(int tickPeriod, int ti false); Errors errors = new Errors(); - RSocketClient rSocket = - new RSocketClient( + RSocketRequester rSocket = + new RSocketRequester( ByteBufAllocator.DEFAULT, resumableConnection, DefaultPayload::create, @@ -136,8 +136,8 @@ static Supplier resumableResponder(int tickPeriod, int ti Duration.ofSeconds(10), false); Errors errors = new Errors(); - RSocketServer rSocket = - new RSocketServer( + RSocketResponder rSocket = + new RSocketResponder( ByteBufAllocator.DEFAULT, resumableConnection, handler, diff --git a/rsocket-core/src/test/java/io/rsocket/RSocketClientTerminationTest.java b/rsocket-core/src/test/java/io/rsocket/RSocketRequesterTerminationTest.java similarity index 86% rename from rsocket-core/src/test/java/io/rsocket/RSocketClientTerminationTest.java rename to rsocket-core/src/test/java/io/rsocket/RSocketRequesterTerminationTest.java index ae3bfc489..a2c17cf95 100644 --- a/rsocket-core/src/test/java/io/rsocket/RSocketClientTerminationTest.java +++ b/rsocket-core/src/test/java/io/rsocket/RSocketRequesterTerminationTest.java @@ -1,6 +1,6 @@ package io.rsocket; -import io.rsocket.RSocketClientTest.ClientSocketRule; +import io.rsocket.RSocketRequesterTest.ClientSocketRule; import io.rsocket.util.EmptyPayload; import java.nio.channels.ClosedChannelException; import java.time.Duration; @@ -16,18 +16,18 @@ import reactor.test.StepVerifier; @RunWith(Parameterized.class) -public class RSocketClientTerminationTest { +public class RSocketRequesterTerminationTest { @Rule public final ClientSocketRule rule = new ClientSocketRule(); private Function> interaction; - public RSocketClientTerminationTest(Function> interaction) { + public RSocketRequesterTerminationTest(Function> interaction) { this.interaction = interaction; } @Test public void testCurrentStreamIsTerminatedOnConnectionClose() { - RSocketClient rSocket = rule.socket; + RSocketRequester rSocket = rule.socket; Mono.delay(Duration.ofSeconds(1)).doOnNext(v -> rule.connection.dispose()).subscribe(); @@ -38,7 +38,7 @@ public void testCurrentStreamIsTerminatedOnConnectionClose() { @Test public void testSubsequentStreamIsTerminatedAfterConnectionClose() { - RSocketClient rSocket = rule.socket; + RSocketRequester rSocket = rule.socket; rule.connection.dispose(); StepVerifier.create(interaction.apply(rSocket)) diff --git a/rsocket-core/src/test/java/io/rsocket/RSocketClientTest.java b/rsocket-core/src/test/java/io/rsocket/RSocketRequesterTest.java similarity index 98% rename from rsocket-core/src/test/java/io/rsocket/RSocketClientTest.java rename to rsocket-core/src/test/java/io/rsocket/RSocketRequesterTest.java index c60dba312..7337161e0 100644 --- a/rsocket-core/src/test/java/io/rsocket/RSocketClientTest.java +++ b/rsocket-core/src/test/java/io/rsocket/RSocketRequesterTest.java @@ -47,7 +47,7 @@ import reactor.core.publisher.MonoProcessor; import reactor.core.publisher.UnicastProcessor; -public class RSocketClientTest { +public class RSocketRequesterTest { @Rule public final ClientSocketRule rule = new ClientSocketRule(); @@ -223,10 +223,10 @@ public int sendRequestResponse(Publisher response) { return streamId; } - public static class ClientSocketRule extends AbstractSocketRule { + public static class ClientSocketRule extends AbstractSocketRule { @Override - protected RSocketClient newRSocket() { - return new RSocketClient( + protected RSocketRequester newRSocket() { + return new RSocketRequester( ByteBufAllocator.DEFAULT, connection, DefaultPayload::create, diff --git a/rsocket-core/src/test/java/io/rsocket/RSocketServerTest.java b/rsocket-core/src/test/java/io/rsocket/RSocketResponderTest.java similarity index 97% rename from rsocket-core/src/test/java/io/rsocket/RSocketServerTest.java rename to rsocket-core/src/test/java/io/rsocket/RSocketResponderTest.java index 32c0406b9..c14a73e69 100644 --- a/rsocket-core/src/test/java/io/rsocket/RSocketServerTest.java +++ b/rsocket-core/src/test/java/io/rsocket/RSocketResponderTest.java @@ -37,7 +37,7 @@ import org.reactivestreams.Subscriber; import reactor.core.publisher.Mono; -public class RSocketServerTest { +public class RSocketResponderTest { @Rule public final ServerSocketRule rule = new ServerSocketRule(); @@ -106,7 +106,7 @@ public Mono requestResponse(Payload payload) { assertThat("Subscription not cancelled.", cancelled.get(), is(true)); } - public static class ServerSocketRule extends AbstractSocketRule { + public static class ServerSocketRule extends AbstractSocketRule { private RSocket acceptingSocket; @@ -140,8 +140,8 @@ public void setAcceptingSocket(RSocket acceptingSocket, int prefetch) { } @Override - protected RSocketServer newRSocket() { - return new RSocketServer( + protected RSocketResponder newRSocket() { + return new RSocketResponder( ByteBufAllocator.DEFAULT, connection, acceptingSocket, diff --git a/rsocket-core/src/test/java/io/rsocket/RSocketTest.java b/rsocket-core/src/test/java/io/rsocket/RSocketTest.java index 9ef37a398..4a7ad45ef 100644 --- a/rsocket-core/src/test/java/io/rsocket/RSocketTest.java +++ b/rsocket-core/src/test/java/io/rsocket/RSocketTest.java @@ -102,10 +102,10 @@ public static class SocketRule extends ExternalResource { DirectProcessor serverProcessor; DirectProcessor clientProcessor; - private RSocketClient crs; + private RSocketRequester crs; @SuppressWarnings("unused") - private RSocketServer srs; + private RSocketResponder srs; private RSocket requestAcceptor; private ArrayList clientErrors = new ArrayList<>(); @@ -163,7 +163,7 @@ public Flux requestChannel(Publisher payloads) { }; srs = - new RSocketServer( + new RSocketResponder( ByteBufAllocator.DEFAULT, serverConnection, requestAcceptor, @@ -171,7 +171,7 @@ public Flux requestChannel(Publisher payloads) { throwable -> serverErrors.add(throwable)); crs = - new RSocketClient( + new RSocketRequester( ByteBufAllocator.DEFAULT, clientConnection, DefaultPayload::create, diff --git a/rsocket-core/src/test/java/io/rsocket/SetupRejectionTest.java b/rsocket-core/src/test/java/io/rsocket/SetupRejectionTest.java index 16b7eafb5..74bbe083c 100644 --- a/rsocket-core/src/test/java/io/rsocket/SetupRejectionTest.java +++ b/rsocket-core/src/test/java/io/rsocket/SetupRejectionTest.java @@ -49,8 +49,8 @@ void responderRejectSetup() { void requesterStreamsTerminatedOnZeroErrorFrame() { TestDuplexConnection conn = new TestDuplexConnection(); List errors = new ArrayList<>(); - RSocketClient rSocket = - new RSocketClient( + RSocketRequester rSocket = + new RSocketRequester( ByteBufAllocator.DEFAULT, conn, DefaultPayload::create, @@ -79,8 +79,8 @@ void requesterStreamsTerminatedOnZeroErrorFrame() { @Test void requesterNewStreamsTerminatedAfterZeroErrorFrame() { TestDuplexConnection conn = new TestDuplexConnection(); - RSocketClient rSocket = - new RSocketClient( + RSocketRequester rSocket = + new RSocketRequester( ByteBufAllocator.DEFAULT, conn, DefaultPayload::create, diff --git a/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationDuplexConnectionTest.java b/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationDuplexConnectionTest.java index a16f6d28d..3d96bfd12 100644 --- a/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationDuplexConnectionTest.java +++ b/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationDuplexConnectionTest.java @@ -64,7 +64,7 @@ void constructorInvalidMaxFragmentSize() { () -> new FragmentationDuplexConnection( delegate, allocator, Integer.MIN_VALUE, false, "")) - .withMessage("smallest allowed mtu size is 64 bytes"); + .withMessage("smallest allowed mtu size is 64 bytes, provided: -2147483648"); } @DisplayName("constructor throws IllegalArgumentException with negative maxFragmentLength") @@ -72,7 +72,7 @@ void constructorInvalidMaxFragmentSize() { void constructorMtuLessThanMin() { assertThatIllegalArgumentException() .isThrownBy(() -> new FragmentationDuplexConnection(delegate, allocator, 2, false, "")) - .withMessage("smallest allowed mtu size is 64 bytes"); + .withMessage("smallest allowed mtu size is 64 bytes, provided: 2"); } @DisplayName("constructor throws NullPointerException with null byteBufAllocator") diff --git a/rsocket-core/src/test/java/io/rsocket/metadata/CompositeMetadataFlyweightTest.java b/rsocket-core/src/test/java/io/rsocket/metadata/CompositeMetadataFlyweightTest.java new file mode 100644 index 000000000..1a22e9e23 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/metadata/CompositeMetadataFlyweightTest.java @@ -0,0 +1,527 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.metadata; + +import static io.rsocket.metadata.CompositeMetadataFlyweight.decodeMimeAndContentBuffersSlices; +import static io.rsocket.metadata.CompositeMetadataFlyweight.decodeMimeIdFromMimeBuffer; +import static io.rsocket.metadata.CompositeMetadataFlyweight.decodeMimeTypeFromMimeBuffer; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.assertj.core.api.Assertions.assertThatIllegalStateException; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.buffer.UnpooledByteBufAllocator; +import io.netty.util.CharsetUtil; +import io.rsocket.test.util.ByteBufUtils; +import io.rsocket.util.NumberUtils; +import org.junit.jupiter.api.Test; + +class CompositeMetadataFlyweightTest { + + static String byteToBitsString(byte b) { + return String.format("%8s", Integer.toBinaryString(b & 0xFF)).replace(' ', '0'); + } + + static String toHeaderBits(ByteBuf encoded) { + encoded.markReaderIndex(); + byte headerByte = encoded.readByte(); + String byteAsString = byteToBitsString(headerByte); + encoded.resetReaderIndex(); + return byteAsString; + } + // ==== + + @Test + void customMimeHeaderLatin1_encodingFails() { + String mimeNotAscii = "mime/typé"; + + assertThatIllegalArgumentException() + .isThrownBy( + () -> + CompositeMetadataFlyweight.encodeMetadataHeader( + ByteBufAllocator.DEFAULT, mimeNotAscii, 0)) + .withMessage("custom mime type must be US_ASCII characters only"); + } + + @Test + void customMimeHeaderLength0_encodingFails() { + assertThatIllegalArgumentException() + .isThrownBy( + () -> CompositeMetadataFlyweight.encodeMetadataHeader(ByteBufAllocator.DEFAULT, "", 0)) + .withMessage( + "custom mime type must have a strictly positive length that fits on 7 unsigned bits, ie 1-128"); + } + + @Test + void customMimeHeaderLength127() { + StringBuilder builder = new StringBuilder(127); + for (int i = 0; i < 127; i++) { + builder.append('a'); + } + String mimeString = builder.toString(); + ByteBuf encoded = + CompositeMetadataFlyweight.encodeMetadataHeader(ByteBufAllocator.DEFAULT, mimeString, 0); + + // remember actual length = encoded length + 1 + assertThat(toHeaderBits(encoded)).startsWith("0").isEqualTo("01111110"); + + final ByteBuf[] byteBufs = decodeMimeAndContentBuffersSlices(encoded, 0, false); + assertThat(byteBufs).hasSize(2).doesNotContainNull(); + + ByteBuf header = byteBufs[0]; + ByteBuf content = byteBufs[1]; + header.markReaderIndex(); + + assertThat(header.readableBytes()).as("metadata header size").isGreaterThan(1); + + assertThat((int) header.readByte()) + .as("mime length") + .isEqualTo(127 - 1); // encoded as actual length - 1 + + assertThat(header.readCharSequence(127, CharsetUtil.US_ASCII)) + .as("mime string") + .hasToString(mimeString); + + header.resetReaderIndex(); + assertThat(CompositeMetadataFlyweight.decodeMimeTypeFromMimeBuffer(header)) + .as("decoded mime string") + .hasToString(mimeString); + + assertThat(content.readableBytes()).as("no metadata content").isZero(); + } + + @Test + void customMimeHeaderLength128() { + StringBuilder builder = new StringBuilder(128); + for (int i = 0; i < 128; i++) { + builder.append('a'); + } + String mimeString = builder.toString(); + ByteBuf encoded = + CompositeMetadataFlyweight.encodeMetadataHeader(ByteBufAllocator.DEFAULT, mimeString, 0); + + // remember actual length = encoded length + 1 + assertThat(toHeaderBits(encoded)).startsWith("0").isEqualTo("01111111"); + + final ByteBuf[] byteBufs = decodeMimeAndContentBuffersSlices(encoded, 0, false); + assertThat(byteBufs).hasSize(2).doesNotContainNull(); + + ByteBuf header = byteBufs[0]; + ByteBuf content = byteBufs[1]; + header.markReaderIndex(); + + assertThat(header.readableBytes()).as("metadata header size").isGreaterThan(1); + + assertThat((int) header.readByte()) + .as("mime length") + .isEqualTo(128 - 1); // encoded as actual length - 1 + + assertThat(header.readCharSequence(128, CharsetUtil.US_ASCII)) + .as("mime string") + .hasToString(mimeString); + + header.resetReaderIndex(); + assertThat(CompositeMetadataFlyweight.decodeMimeTypeFromMimeBuffer(header)) + .as("decoded mime string") + .hasToString(mimeString); + + assertThat(content.readableBytes()).as("no metadata content").isZero(); + } + + @Test + void customMimeHeaderLength129_encodingFails() { + StringBuilder builder = new StringBuilder(129); + for (int i = 0; i < 129; i++) { + builder.append('a'); + } + + assertThatIllegalArgumentException() + .isThrownBy( + () -> + CompositeMetadataFlyweight.encodeMetadataHeader( + ByteBufAllocator.DEFAULT, builder.toString(), 0)) + .withMessage( + "custom mime type must have a strictly positive length that fits on 7 unsigned bits, ie 1-128"); + } + + @Test + void customMimeHeaderLengthOne() { + String mimeString = "w"; + ByteBuf encoded = + CompositeMetadataFlyweight.encodeMetadataHeader(ByteBufAllocator.DEFAULT, mimeString, 0); + + // remember actual length = encoded length + 1 + assertThat(toHeaderBits(encoded)).startsWith("0").isEqualTo("00000000"); + + final ByteBuf[] byteBufs = decodeMimeAndContentBuffersSlices(encoded, 0, false); + assertThat(byteBufs).hasSize(2).doesNotContainNull(); + + ByteBuf header = byteBufs[0]; + ByteBuf content = byteBufs[1]; + header.markReaderIndex(); + + assertThat(header.readableBytes()).as("metadata header size").isGreaterThan(1); + + assertThat((int) header.readByte()).as("mime length").isZero(); // encoded as actual length - 1 + + assertThat(header.readCharSequence(1, CharsetUtil.US_ASCII)) + .as("mime string") + .hasToString(mimeString); + + header.resetReaderIndex(); + assertThat(CompositeMetadataFlyweight.decodeMimeTypeFromMimeBuffer(header)) + .as("decoded mime string") + .hasToString(mimeString); + + assertThat(content.readableBytes()).as("no metadata content").isZero(); + } + + @Test + void customMimeHeaderLengthTwo() { + String mimeString = "ww"; + ByteBuf encoded = + CompositeMetadataFlyweight.encodeMetadataHeader(ByteBufAllocator.DEFAULT, mimeString, 0); + + // remember actual length = encoded length + 1 + assertThat(toHeaderBits(encoded)).startsWith("0").isEqualTo("00000001"); + + final ByteBuf[] byteBufs = decodeMimeAndContentBuffersSlices(encoded, 0, false); + assertThat(byteBufs).hasSize(2).doesNotContainNull(); + + ByteBuf header = byteBufs[0]; + ByteBuf content = byteBufs[1]; + header.markReaderIndex(); + + assertThat(header.readableBytes()).as("metadata header size").isGreaterThan(1); + + assertThat((int) header.readByte()) + .as("mime length") + .isEqualTo(2 - 1); // encoded as actual length - 1 + + assertThat(header.readCharSequence(2, CharsetUtil.US_ASCII)) + .as("mime string") + .hasToString(mimeString); + + header.resetReaderIndex(); + assertThat(CompositeMetadataFlyweight.decodeMimeTypeFromMimeBuffer(header)) + .as("decoded mime string") + .hasToString(mimeString); + + assertThat(content.readableBytes()).as("no metadata content").isZero(); + } + + @Test + void customMimeHeaderUtf8_encodingFails() { + String mimeNotAscii = + "mime/tyࠒe"; // this is the SAMARITAN LETTER QUF u+0812 represented on 3 bytes + assertThatIllegalArgumentException() + .isThrownBy( + () -> + CompositeMetadataFlyweight.encodeMetadataHeader( + ByteBufAllocator.DEFAULT, mimeNotAscii, 0)) + .withMessage("custom mime type must be US_ASCII characters only"); + } + + @Test + void decodeEntryAtEndOfBuffer() { + ByteBuf fakeEntry = Unpooled.buffer(); + + assertThatIllegalArgumentException() + .isThrownBy(() -> decodeMimeAndContentBuffersSlices(fakeEntry, 0, false)); + } + + @Test + void decodeEntryHasNoContentLength() { + ByteBuf fakeEntry = Unpooled.buffer(); + fakeEntry.writeByte(0); + fakeEntry.writeCharSequence("w", CharsetUtil.US_ASCII); + + assertThatIllegalStateException() + .isThrownBy(() -> decodeMimeAndContentBuffersSlices(fakeEntry, 0, false)); + } + + @Test + void decodeEntryTooShortForContentLength() { + ByteBuf fakeEntry = Unpooled.buffer(); + fakeEntry.writeByte(1); + fakeEntry.writeCharSequence("w", CharsetUtil.US_ASCII); + NumberUtils.encodeUnsignedMedium(fakeEntry, 456); + fakeEntry.writeChar('w'); + + assertThatIllegalStateException() + .isThrownBy(() -> decodeMimeAndContentBuffersSlices(fakeEntry, 0, false)); + } + + @Test + void decodeEntryTooShortForMimeLength() { + ByteBuf fakeEntry = Unpooled.buffer(); + fakeEntry.writeByte(120); + + assertThatIllegalStateException() + .isThrownBy(() -> decodeMimeAndContentBuffersSlices(fakeEntry, 0, false)); + } + + @Test + void decodeIdMinusTwoWhenMoreThanOneByte() { + ByteBuf fakeIdBuffer = Unpooled.buffer(2); + fakeIdBuffer.writeInt(200); + + assertThat(decodeMimeIdFromMimeBuffer(fakeIdBuffer)) + .isEqualTo((WellKnownMimeType.UNPARSEABLE_MIME_TYPE.getIdentifier())); + } + + @Test + void decodeIdMinusTwoWhenZeroByte() { + ByteBuf fakeIdBuffer = Unpooled.buffer(0); + + assertThat(decodeMimeIdFromMimeBuffer(fakeIdBuffer)) + .isEqualTo((WellKnownMimeType.UNPARSEABLE_MIME_TYPE.getIdentifier())); + } + + @Test + void decodeStringNullIfLengthOne() { + ByteBuf fakeTypeBuffer = Unpooled.buffer(2); + fakeTypeBuffer.writeByte(1); + + assertThatIllegalStateException() + .isThrownBy(() -> decodeMimeTypeFromMimeBuffer(fakeTypeBuffer)); + } + + @Test + void decodeStringNullIfLengthZero() { + ByteBuf fakeTypeBuffer = Unpooled.buffer(2); + + assertThatIllegalStateException() + .isThrownBy(() -> decodeMimeTypeFromMimeBuffer(fakeTypeBuffer)); + } + + @Test + void decodeTypeSkipsFirstByte() { + ByteBuf fakeTypeBuffer = Unpooled.buffer(2); + fakeTypeBuffer.writeByte(128); + fakeTypeBuffer.writeCharSequence("example", CharsetUtil.US_ASCII); + + assertThat(decodeMimeTypeFromMimeBuffer(fakeTypeBuffer)).hasToString("example"); + } + + @Test + void encodeMetadataCustomTypeDelegates() { + ByteBuf expected = + CompositeMetadataFlyweight.encodeMetadataHeader(ByteBufAllocator.DEFAULT, "foo", 2); + + CompositeByteBuf test = ByteBufAllocator.DEFAULT.compositeBuffer(); + + CompositeMetadataFlyweight.encodeAndAddMetadata( + test, ByteBufAllocator.DEFAULT, "foo", ByteBufUtils.getRandomByteBuf(2)); + + assertThat((Iterable) test).hasSize(2).first().isEqualTo(expected); + } + + @Test + void encodeMetadataKnownTypeDelegates() { + ByteBuf expected = + CompositeMetadataFlyweight.encodeMetadataHeader( + ByteBufAllocator.DEFAULT, + WellKnownMimeType.APPLICATION_OCTET_STREAM.getIdentifier(), + 2); + + CompositeByteBuf test = ByteBufAllocator.DEFAULT.compositeBuffer(); + + CompositeMetadataFlyweight.encodeAndAddMetadata( + test, + ByteBufAllocator.DEFAULT, + WellKnownMimeType.APPLICATION_OCTET_STREAM, + ByteBufUtils.getRandomByteBuf(2)); + + assertThat((Iterable) test).hasSize(2).first().isEqualTo(expected); + } + + @Test + void encodeMetadataReservedTypeDelegates() { + ByteBuf expected = + CompositeMetadataFlyweight.encodeMetadataHeader(ByteBufAllocator.DEFAULT, (byte) 120, 2); + + CompositeByteBuf test = ByteBufAllocator.DEFAULT.compositeBuffer(); + + CompositeMetadataFlyweight.encodeAndAddMetadata( + test, ByteBufAllocator.DEFAULT, (byte) 120, ByteBufUtils.getRandomByteBuf(2)); + + assertThat((Iterable) test).hasSize(2).first().isEqualTo(expected); + } + + @Test + void encodeTryCompressWithCompressableType() { + ByteBuf metadata = ByteBufUtils.getRandomByteBuf(2); + CompositeByteBuf target = UnpooledByteBufAllocator.DEFAULT.compositeBuffer(); + + CompositeMetadataFlyweight.encodeAndAddMetadataWithCompression( + target, + UnpooledByteBufAllocator.DEFAULT, + WellKnownMimeType.APPLICATION_AVRO.getString(), + metadata); + + assertThat(target.readableBytes()).as("readableBytes 1 + 3 + 2").isEqualTo(6); + } + + @Test + void encodeTryCompressWithCustomType() { + ByteBuf metadata = ByteBufUtils.getRandomByteBuf(2); + CompositeByteBuf target = UnpooledByteBufAllocator.DEFAULT.compositeBuffer(); + + CompositeMetadataFlyweight.encodeAndAddMetadataWithCompression( + target, UnpooledByteBufAllocator.DEFAULT, "custom/example", metadata); + + assertThat(target.readableBytes()).as("readableBytes 1 + 14 + 3 + 2").isEqualTo(20); + } + + @Test + void hasEntry() { + WellKnownMimeType mime = WellKnownMimeType.APPLICATION_AVRO; + + CompositeByteBuf buffer = + Unpooled.compositeBuffer() + .addComponent( + true, + CompositeMetadataFlyweight.encodeMetadataHeader( + ByteBufAllocator.DEFAULT, mime.getIdentifier(), 0)) + .addComponent( + true, + CompositeMetadataFlyweight.encodeMetadataHeader( + ByteBufAllocator.DEFAULT, mime.getIdentifier(), 0)); + + assertThat(CompositeMetadataFlyweight.hasEntry(buffer, 0)).isTrue(); + assertThat(CompositeMetadataFlyweight.hasEntry(buffer, 4)).isTrue(); + assertThat(CompositeMetadataFlyweight.hasEntry(buffer, 8)).isFalse(); + } + + @Test + void isWellKnownMimeType() { + ByteBuf wellKnown = Unpooled.buffer().writeByte(0); + assertThat(CompositeMetadataFlyweight.isWellKnownMimeType(wellKnown)).isTrue(); + + ByteBuf explicit = Unpooled.buffer().writeByte(2).writeChar('a'); + assertThat(CompositeMetadataFlyweight.isWellKnownMimeType(explicit)).isFalse(); + } + + @Test + void knownMimeHeader120_reserved() { + byte mime = (byte) 120; + ByteBuf encoded = + CompositeMetadataFlyweight.encodeMetadataHeader(ByteBufAllocator.DEFAULT, mime, 0); + + assertThat(mime) + .as("smoke test RESERVED_120 unsigned 7 bits representation") + .isEqualTo((byte) 0b01111000); + + assertThat(toHeaderBits(encoded)).startsWith("1").isEqualTo("11111000"); + + final ByteBuf[] byteBufs = decodeMimeAndContentBuffersSlices(encoded, 0, false); + assertThat(byteBufs).hasSize(2).doesNotContainNull(); + + ByteBuf header = byteBufs[0]; + ByteBuf content = byteBufs[1]; + header.markReaderIndex(); + + assertThat(header.readableBytes()).as("metadata header size").isOne(); + + assertThat(byteToBitsString(header.readByte())) + .as("header bit representation") + .isEqualTo("11111000"); + + header.resetReaderIndex(); + assertThat(decodeMimeIdFromMimeBuffer(header)).as("decoded mime id").isEqualTo(mime); + + assertThat(content.readableBytes()).as("no metadata content").isZero(); + } + + @Test + void knownMimeHeader127_compositeMetadata() { + WellKnownMimeType mime = WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA; + assertThat(mime.getIdentifier()) + .as("smoke test COMPOSITE unsigned 7 bits representation") + .isEqualTo((byte) 127) + .isEqualTo((byte) 0b01111111); + ByteBuf encoded = + CompositeMetadataFlyweight.encodeMetadataHeader( + ByteBufAllocator.DEFAULT, mime.getIdentifier(), 0); + + assertThat(toHeaderBits(encoded)) + .startsWith("1") + .isEqualTo("11111111") + .isEqualTo(byteToBitsString(mime.getIdentifier()).replaceFirst("0", "1")); + + final ByteBuf[] byteBufs = decodeMimeAndContentBuffersSlices(encoded, 0, false); + assertThat(byteBufs).hasSize(2).doesNotContainNull(); + + ByteBuf header = byteBufs[0]; + ByteBuf content = byteBufs[1]; + header.markReaderIndex(); + + assertThat(header.readableBytes()).as("metadata header size").isOne(); + + assertThat(byteToBitsString(header.readByte())) + .as("header bit representation") + .isEqualTo("11111111"); + + header.resetReaderIndex(); + assertThat(decodeMimeIdFromMimeBuffer(header)) + .as("decoded mime id") + .isEqualTo(mime.getIdentifier()); + + assertThat(content.readableBytes()).as("no metadata content").isZero(); + } + + @Test + void knownMimeHeaderZero_avro() { + WellKnownMimeType mime = WellKnownMimeType.APPLICATION_AVRO; + assertThat(mime.getIdentifier()) + .as("smoke test AVRO unsigned 7 bits representation") + .isEqualTo((byte) 0) + .isEqualTo((byte) 0b00000000); + ByteBuf encoded = + CompositeMetadataFlyweight.encodeMetadataHeader( + ByteBufAllocator.DEFAULT, mime.getIdentifier(), 0); + + assertThat(toHeaderBits(encoded)) + .startsWith("1") + .isEqualTo("10000000") + .isEqualTo(byteToBitsString(mime.getIdentifier()).replaceFirst("0", "1")); + + final ByteBuf[] byteBufs = decodeMimeAndContentBuffersSlices(encoded, 0, false); + assertThat(byteBufs).hasSize(2).doesNotContainNull(); + + ByteBuf header = byteBufs[0]; + ByteBuf content = byteBufs[1]; + header.markReaderIndex(); + + assertThat(header.readableBytes()).as("metadata header size").isOne(); + + assertThat(byteToBitsString(header.readByte())) + .as("header bit representation") + .isEqualTo("10000000"); + + header.resetReaderIndex(); + assertThat(decodeMimeIdFromMimeBuffer(header)) + .as("decoded mime id") + .isEqualTo(mime.getIdentifier()); + + assertThat(content.readableBytes()).as("no metadata content").isZero(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/metadata/CompositeMetadataTest.java b/rsocket-core/src/test/java/io/rsocket/metadata/CompositeMetadataTest.java new file mode 100644 index 000000000..cc00df7d4 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/metadata/CompositeMetadataTest.java @@ -0,0 +1,157 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.metadata; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.assertj.core.api.Assertions.assertThatIllegalStateException; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.rsocket.metadata.CompositeMetadata.Entry; +import io.rsocket.metadata.CompositeMetadata.ReservedMimeTypeEntry; +import io.rsocket.metadata.CompositeMetadata.WellKnownMimeTypeEntry; +import io.rsocket.test.util.ByteBufUtils; +import io.rsocket.util.NumberUtils; +import java.util.Iterator; +import org.junit.jupiter.api.Test; + +class CompositeMetadataTest { + + @Test + void decodeEntryHasNoContentLength() { + ByteBuf fakeEntry = Unpooled.buffer(); + fakeEntry.writeByte(0); + fakeEntry.writeCharSequence("w", CharsetUtil.US_ASCII); + CompositeMetadata compositeMetadata = new CompositeMetadata(fakeEntry, false); + + assertThatIllegalStateException() + .isThrownBy(() -> compositeMetadata.iterator().next()) + .withMessage("metadata is malformed"); + } + + @Test + void decodeEntryOnDoneBufferThrowsIllegalArgument() { + ByteBuf fakeBuffer = ByteBufUtils.getRandomByteBuf(0); + CompositeMetadata compositeMetadata = new CompositeMetadata(fakeBuffer, false); + + assertThatIllegalArgumentException() + .isThrownBy(() -> compositeMetadata.iterator().next()) + .withMessage("entry index 0 is larger than buffer size"); + } + + @Test + void decodeEntryTooShortForContentLength() { + ByteBuf fakeEntry = Unpooled.buffer(); + fakeEntry.writeByte(1); + fakeEntry.writeCharSequence("w", CharsetUtil.US_ASCII); + NumberUtils.encodeUnsignedMedium(fakeEntry, 456); + fakeEntry.writeChar('w'); + CompositeMetadata compositeMetadata = new CompositeMetadata(fakeEntry, false); + + assertThatIllegalStateException() + .isThrownBy(() -> compositeMetadata.iterator().next()) + .withMessage("metadata is malformed"); + } + + @Test + void decodeEntryTooShortForMimeLength() { + ByteBuf fakeEntry = Unpooled.buffer(); + fakeEntry.writeByte(120); + CompositeMetadata compositeMetadata = new CompositeMetadata(fakeEntry, false); + + assertThatIllegalStateException() + .isThrownBy(() -> compositeMetadata.iterator().next()) + .withMessage("metadata is malformed"); + } + + @Test + void decodeThreeEntries() { + // metadata 1: well known + WellKnownMimeType mimeType1 = WellKnownMimeType.APPLICATION_PDF; + ByteBuf metadata1 = Unpooled.buffer(); + metadata1.writeCharSequence("abcdefghijkl", CharsetUtil.UTF_8); + + // metadata 2: custom + String mimeType2 = "application/custom"; + ByteBuf metadata2 = Unpooled.buffer(); + metadata2.writeChar('E'); + metadata2.writeChar('∑'); + metadata2.writeChar('é'); + metadata2.writeBoolean(true); + metadata2.writeChar('W'); + + // metadata 3: reserved but unknown + byte reserved = 120; + assertThat(WellKnownMimeType.fromIdentifier(reserved)) + .as("ensure UNKNOWN RESERVED used in test") + .isSameAs(WellKnownMimeType.UNKNOWN_RESERVED_MIME_TYPE); + ByteBuf metadata3 = Unpooled.buffer(); + metadata3.writeByte(88); + + CompositeByteBuf compositeMetadataBuffer = ByteBufAllocator.DEFAULT.compositeBuffer(); + CompositeMetadataFlyweight.encodeAndAddMetadata( + compositeMetadataBuffer, ByteBufAllocator.DEFAULT, mimeType1, metadata1); + CompositeMetadataFlyweight.encodeAndAddMetadata( + compositeMetadataBuffer, ByteBufAllocator.DEFAULT, mimeType2, metadata2); + CompositeMetadataFlyweight.encodeAndAddMetadata( + compositeMetadataBuffer, ByteBufAllocator.DEFAULT, reserved, metadata3); + + Iterator iterator = new CompositeMetadata(compositeMetadataBuffer, true).iterator(); + + assertThat(iterator.next()) + .as("entry1") + .isNotNull() + .satisfies( + e -> + assertThat(e.getMimeType()).as("entry1 mime type").isEqualTo(mimeType1.getString())) + .satisfies( + e -> + assertThat(((WellKnownMimeTypeEntry) e).getType()) + .as("entry1 mime id") + .isEqualTo(WellKnownMimeType.APPLICATION_PDF)) + .satisfies( + e -> + assertThat(e.getContent().toString(CharsetUtil.UTF_8)) + .as("entry1 decoded") + .isEqualTo("abcdefghijkl")); + + assertThat(iterator.next()) + .as("entry2") + .isNotNull() + .satisfies(e -> assertThat(e.getMimeType()).as("entry2 mime type").isEqualTo(mimeType2)) + .satisfies( + e -> assertThat(e.getContent()).as("entry2 decoded").isEqualByComparingTo(metadata2)); + + assertThat(iterator.next()) + .as("entry3") + .isNotNull() + .satisfies(e -> assertThat(e.getMimeType()).as("entry3 mime type").isNull()) + .satisfies( + e -> + assertThat(((ReservedMimeTypeEntry) e).getType()) + .as("entry3 mime id") + .isEqualTo(reserved)) + .satisfies( + e -> assertThat(e.getContent()).as("entry3 decoded").isEqualByComparingTo(metadata3)); + + assertThat(iterator.hasNext()).as("has no more than 3 entries").isFalse(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/metadata/WellKnownMimeTypeTest.java b/rsocket-core/src/test/java/io/rsocket/metadata/WellKnownMimeTypeTest.java new file mode 100644 index 000000000..316aaf091 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/metadata/WellKnownMimeTypeTest.java @@ -0,0 +1,74 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.metadata; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.junit.jupiter.api.Test; + +class WellKnownMimeTypeTest { + + @Test + void fromIdentifierGreaterThan127() { + assertThat(WellKnownMimeType.fromIdentifier(128)) + .isSameAs(WellKnownMimeType.UNPARSEABLE_MIME_TYPE); + } + + @Test + void fromIdentifierMatchFromMimeType() { + for (WellKnownMimeType mimeType : WellKnownMimeType.values()) { + if (mimeType == WellKnownMimeType.UNPARSEABLE_MIME_TYPE + || mimeType == WellKnownMimeType.UNKNOWN_RESERVED_MIME_TYPE) { + continue; + } + assertThat(WellKnownMimeType.fromString(mimeType.toString())) + .as("mimeType string for " + mimeType.name()) + .isSameAs(mimeType); + + assertThat(WellKnownMimeType.fromIdentifier(mimeType.getIdentifier())) + .as("mimeType ID for " + mimeType.name()) + .isSameAs(mimeType); + } + } + + @Test + void fromIdentifierNegative() { + assertThat(WellKnownMimeType.fromIdentifier(-1)) + .isSameAs(WellKnownMimeType.fromIdentifier(-2)) + .isSameAs(WellKnownMimeType.fromIdentifier(-12)) + .isSameAs(WellKnownMimeType.UNPARSEABLE_MIME_TYPE); + } + + @Test + void fromIdentifierReserved() { + assertThat(WellKnownMimeType.fromIdentifier(120)) + .isSameAs(WellKnownMimeType.UNKNOWN_RESERVED_MIME_TYPE); + } + + @Test + void fromStringUnknown() { + assertThat(WellKnownMimeType.fromString("foo/bar")) + .isSameAs(WellKnownMimeType.UNPARSEABLE_MIME_TYPE); + } + + @Test + void fromStringUnknownReservedStillReturnsUnparseable() { + assertThat( + WellKnownMimeType.fromString(WellKnownMimeType.UNKNOWN_RESERVED_MIME_TYPE.getString())) + .isSameAs(WellKnownMimeType.UNPARSEABLE_MIME_TYPE); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/util/NumberUtilsTest.java b/rsocket-core/src/test/java/io/rsocket/util/NumberUtilsTest.java index 988bd523d..46e0f77f4 100644 --- a/rsocket-core/src/test/java/io/rsocket/util/NumberUtilsTest.java +++ b/rsocket-core/src/test/java/io/rsocket/util/NumberUtilsTest.java @@ -18,6 +18,8 @@ import static org.assertj.core.api.Assertions.*; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; @@ -158,4 +160,28 @@ void requireUnsignedShortOverFlow() { .isThrownBy(() -> NumberUtils.requireUnsignedShort(1 << 16)) .withMessage("%d is larger than 16 bits", 1 << 16); } + + @Test + void encodeUnsignedMedium() { + ByteBuf buffer = ByteBufAllocator.DEFAULT.buffer(); + NumberUtils.encodeUnsignedMedium(buffer, 129); + buffer.markReaderIndex(); + + assertThat(buffer.readUnsignedMedium()).as("reading as unsigned medium").isEqualTo(129); + + buffer.resetReaderIndex(); + assertThat(buffer.readMedium()).as("reading as signed medium").isEqualTo(129); + } + + @Test + void encodeUnsignedMediumLarge() { + ByteBuf buffer = ByteBufAllocator.DEFAULT.buffer(); + NumberUtils.encodeUnsignedMedium(buffer, 0xFFFFFC); + buffer.markReaderIndex(); + + assertThat(buffer.readUnsignedMedium()).as("reading as unsigned medium").isEqualTo(16777212); + + buffer.resetReaderIndex(); + assertThat(buffer.readMedium()).as("reading as signed medium").isEqualTo(-4); + } } diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/ws/WebSocketHeadersSample.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/ws/WebSocketHeadersSample.java new file mode 100644 index 000000000..d3865c01b --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/ws/WebSocketHeadersSample.java @@ -0,0 +1,141 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.examples.transport.ws; + +import io.netty.handler.codec.http.HttpResponseStatus; +import io.rsocket.AbstractRSocket; +import io.rsocket.ConnectionSetupPayload; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.RSocketFactory; +import io.rsocket.SocketAcceptor; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.transport.ServerTransport; +import io.rsocket.transport.netty.WebsocketDuplexConnection; +import io.rsocket.transport.netty.client.WebsocketClientTransport; +import io.rsocket.util.ByteBufPayload; +import java.time.Duration; +import java.util.HashMap; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; +import reactor.netty.Connection; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; + +public class WebSocketHeadersSample { + static final Payload payload1 = ByteBufPayload.create("Hello "); + + public static void main(String[] args) { + + ServerTransport.ConnectionAcceptor acceptor = + RSocketFactory.receive() + .frameDecoder(PayloadDecoder.ZERO_COPY) + .acceptor(new SocketAcceptorImpl()) + .toConnectionAcceptor(); + + DisposableServer disposableServer = + HttpServer.create() + .host("localhost") + .port(0) + .route( + routes -> + routes.ws( + "/", + (in, out) -> { + if (in.headers().containsValue("Authorization", "test", true)) { + DuplexConnection connection = + new WebsocketDuplexConnection((Connection) in); + return acceptor.apply(connection).then(out.neverComplete()); + } + + return out.sendClose( + HttpResponseStatus.UNAUTHORIZED.code(), + HttpResponseStatus.UNAUTHORIZED.reasonPhrase()); + })) + .bindNow(); + + WebsocketClientTransport clientTransport = + WebsocketClientTransport.create(disposableServer.host(), disposableServer.port()); + + clientTransport.setTransportHeaders( + () -> { + HashMap map = new HashMap<>(); + map.put("Authorization", "test"); + return map; + }); + + RSocket socket = + RSocketFactory.connect() + .keepAliveAckTimeout(Duration.ofMinutes(10)) + .frameDecoder(PayloadDecoder.ZERO_COPY) + .transport(clientTransport) + .start() + .block(); + + Flux.range(0, 100) + .concatMap(i -> socket.fireAndForget(payload1.retain())) + // .doOnNext(p -> { + //// System.out.println(p.getDataUtf8()); + // p.release(); + // }) + .blockLast(); + socket.dispose(); + + WebsocketClientTransport clientTransport2 = + WebsocketClientTransport.create(disposableServer.host(), disposableServer.port()); + + RSocket rSocket = + RSocketFactory.connect() + .keepAliveAckTimeout(Duration.ofMinutes(10)) + .frameDecoder(PayloadDecoder.ZERO_COPY) + .transport(clientTransport2) + .start() + .block(); + + // expect error here because of closed channel + rSocket.requestResponse(payload1).block(); + } + + private static class SocketAcceptorImpl implements SocketAcceptor { + @Override + public Mono accept(ConnectionSetupPayload setupPayload, RSocket reactiveSocket) { + return Mono.just( + new AbstractRSocket() { + + @Override + public Mono fireAndForget(Payload payload) { + // System.out.println(payload.getDataUtf8()); + payload.release(); + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + return Mono.just(payload); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return Flux.from(payloads).subscribeOn(Schedulers.single()); + } + }); + } + } +} diff --git a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalClientTransport.java b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalClientTransport.java index e41d9e7db..40db8ef74 100644 --- a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalClientTransport.java +++ b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalClientTransport.java @@ -76,7 +76,8 @@ private Mono connect() { @Override public Mono connect(int mtu) { - Mono connect = connect(); + Mono isError = FragmentationDuplexConnection.checkMtu(mtu); + Mono connect = isError != null ? isError : connect(); if (mtu > 0) { return connect.map( duplexConnection -> diff --git a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalServerTransport.java b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalServerTransport.java index f649d30ce..d755859d2 100644 --- a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalServerTransport.java +++ b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalServerTransport.java @@ -107,17 +107,20 @@ public LocalClientTransport clientTransport() { public Mono start(ConnectionAcceptor acceptor, int mtu) { Objects.requireNonNull(acceptor, "acceptor must not be null"); - return Mono.create( - sink -> { - ServerDuplexConnectionAcceptor serverDuplexConnectionAcceptor = - new ServerDuplexConnectionAcceptor(name, acceptor, mtu); - - if (registry.putIfAbsent(name, serverDuplexConnectionAcceptor) != null) { - throw new IllegalStateException("name already registered: " + name); - } - - sink.success(serverDuplexConnectionAcceptor); - }); + Mono isError = FragmentationDuplexConnection.checkMtu(mtu); + return isError != null + ? isError + : Mono.create( + sink -> { + ServerDuplexConnectionAcceptor serverDuplexConnectionAcceptor = + new ServerDuplexConnectionAcceptor(name, acceptor, mtu); + + if (registry.putIfAbsent(name, serverDuplexConnectionAcceptor) != null) { + throw new IllegalStateException("name already registered: " + name); + } + + sink.success(serverDuplexConnectionAcceptor); + }); } /** diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/TcpClientTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/TcpClientTransport.java index 250d2eee0..f5e79e9bf 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/TcpClientTransport.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/TcpClientTransport.java @@ -94,21 +94,24 @@ public static TcpClientTransport create(TcpClient client) { @Override public Mono connect(int mtu) { - return client - .doOnConnected(c -> c.addHandlerLast(new RSocketLengthCodec())) - .connect() - .map( - c -> { - if (mtu > 0) { - return new FragmentationDuplexConnection( - new TcpDuplexConnection(c, false), - ByteBufAllocator.DEFAULT, - mtu, - true, - "client"); - } else { - return new TcpDuplexConnection(c); - } - }); + Mono isError = FragmentationDuplexConnection.checkMtu(mtu); + return isError != null + ? isError + : client + .doOnConnected(c -> c.addHandlerLast(new RSocketLengthCodec())) + .connect() + .map( + c -> { + if (mtu > 0) { + return new FragmentationDuplexConnection( + new TcpDuplexConnection(c, false), + ByteBufAllocator.DEFAULT, + mtu, + true, + "client"); + } else { + return new TcpDuplexConnection(c); + } + }); } } diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/WebsocketClientTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/WebsocketClientTransport.java index 0da5b04d8..5049119a5 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/WebsocketClientTransport.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/WebsocketClientTransport.java @@ -151,21 +151,24 @@ private static TcpClient createClient(URI uri) { @Override public Mono connect(int mtu) { - return client - .headers(headers -> transportHeaders.get().forEach(headers::set)) - .websocket(FRAME_LENGTH_MASK) - .uri(path) - .connect() - .map( - c -> { - DuplexConnection connection = new WebsocketDuplexConnection(c); - if (mtu > 0) { - connection = - new FragmentationDuplexConnection( - connection, ByteBufAllocator.DEFAULT, mtu, false, "client"); - } - return connection; - }); + Mono isError = FragmentationDuplexConnection.checkMtu(mtu); + return isError != null + ? isError + : client + .headers(headers -> transportHeaders.get().forEach(headers::set)) + .websocket(FRAME_LENGTH_MASK) + .uri(path) + .connect() + .map( + c -> { + DuplexConnection connection = new WebsocketDuplexConnection(c); + if (mtu > 0) { + connection = + new FragmentationDuplexConnection( + connection, ByteBufAllocator.DEFAULT, mtu, false, "client"); + } + return connection; + }); } @Override diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/TcpServerTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/TcpServerTransport.java index b7f60aa6c..54ef016c0 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/TcpServerTransport.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/TcpServerTransport.java @@ -94,26 +94,31 @@ public static TcpServerTransport create(TcpServer server) { @Override public Mono start(ConnectionAcceptor acceptor, int mtu) { Objects.requireNonNull(acceptor, "acceptor must not be null"); - - return server - .doOnConnection( - c -> { - c.addHandlerLast(new RSocketLengthCodec()); - DuplexConnection connection; - if (mtu > 0) { - connection = - new FragmentationDuplexConnection( - new TcpDuplexConnection(c, false), - ByteBufAllocator.DEFAULT, - mtu, - true, - "server"); - } else { - connection = new TcpDuplexConnection(c); - } - acceptor.apply(connection).then(Mono.never()).subscribe(c.disposeSubscriber()); - }) - .bind() - .map(CloseableChannel::new); + Mono isError = FragmentationDuplexConnection.checkMtu(mtu); + return isError != null + ? isError + : server + .doOnConnection( + c -> { + c.addHandlerLast(new RSocketLengthCodec()); + DuplexConnection connection; + if (mtu > 0) { + connection = + new FragmentationDuplexConnection( + new TcpDuplexConnection(c, false), + ByteBufAllocator.DEFAULT, + mtu, + true, + "server"); + } else { + connection = new TcpDuplexConnection(c); + } + acceptor + .apply(connection) + .then(Mono.never()) + .subscribe(c.disposeSubscriber()); + }) + .bind() + .map(CloseableChannel::new); } } diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketServerTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketServerTransport.java index 4ac68cdfd..ee0b8e3d3 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketServerTransport.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketServerTransport.java @@ -111,24 +111,28 @@ public void setTransportHeaders(Supplier> transportHeaders) public Mono start(ConnectionAcceptor acceptor, int mtu) { Objects.requireNonNull(acceptor, "acceptor must not be null"); - return server - .handle( - (request, response) -> { - transportHeaders.get().forEach(response::addHeader); - return response.sendWebsocket( - null, - FRAME_LENGTH_MASK, - (in, out) -> { - DuplexConnection connection = new WebsocketDuplexConnection((Connection) in); - if (mtu > 0) { - connection = - new FragmentationDuplexConnection( - connection, ByteBufAllocator.DEFAULT, mtu, false, "server"); - } - return acceptor.apply(connection).then(out.neverComplete()); - }); - }) - .bind() - .map(CloseableChannel::new); + Mono isError = FragmentationDuplexConnection.checkMtu(mtu); + return isError != null + ? isError + : server + .handle( + (request, response) -> { + transportHeaders.get().forEach(response::addHeader); + return response.sendWebsocket( + null, + FRAME_LENGTH_MASK, + (in, out) -> { + DuplexConnection connection = + new WebsocketDuplexConnection((Connection) in); + if (mtu > 0) { + connection = + new FragmentationDuplexConnection( + connection, ByteBufAllocator.DEFAULT, mtu, false, "server"); + } + return acceptor.apply(connection).then(out.neverComplete()); + }); + }) + .bind() + .map(CloseableChannel::new); } } diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/RSocketFactoryNettyTransportFragmentationTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/RSocketFactoryNettyTransportFragmentationTest.java new file mode 100644 index 000000000..07e9378fa --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/RSocketFactoryNettyTransportFragmentationTest.java @@ -0,0 +1,125 @@ +package io.rsocket.transport.netty; + +import static io.rsocket.RSocketFactory.*; + +import io.rsocket.RSocket; +import io.rsocket.RSocketFactory; +import io.rsocket.SocketAcceptor; +import io.rsocket.transport.ClientTransport; +import io.rsocket.transport.ServerTransport; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.CloseableChannel; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.transport.netty.server.WebsocketServerTransport; +import java.time.Duration; +import java.util.function.Function; +import java.util.stream.Stream; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.Mockito; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +class RSocketFactoryNettyTransportFragmentationTest { + + @ParameterizedTest + @MethodSource("serverTransportProvider") + void serverErrorsWithEnabledFragmentationOnInsufficientMtu( + ServerTransport serverTransport) { + Mono server = createServer(serverTransport, f -> f.fragment(2)); + + StepVerifier.create(server) + .expectErrorMatches( + err -> + err instanceof IllegalArgumentException + && "smallest allowed mtu size is 64 bytes, provided: 2" + .equals(err.getMessage())) + .verify(Duration.ofSeconds(5)); + } + + @ParameterizedTest + @MethodSource("serverTransportProvider") + void serverSucceedsWithEnabledFragmentationOnSufficientMtu( + ServerTransport serverTransport) { + Mono server = + createServer(serverTransport, f -> f.fragment(100)).doOnNext(CloseableChannel::dispose); + StepVerifier.create(server).expectNextCount(1).expectComplete().verify(Duration.ofSeconds(5)); + } + + @ParameterizedTest + @MethodSource("serverTransportProvider") + void serverSucceedsWithDisabledFragmentation() { + Mono server = + createServer(TcpServerTransport.create("localhost", 0), Function.identity()) + .doOnNext(CloseableChannel::dispose); + StepVerifier.create(server).expectNextCount(1).expectComplete().verify(Duration.ofSeconds(5)); + } + + @ParameterizedTest + @MethodSource("serverTransportProvider") + void clientErrorsWithEnabledFragmentationOnInsufficientMtu( + ServerTransport serverTransport) { + CloseableChannel server = createServer(serverTransport, f -> f.fragment(100)).block(); + + Mono rSocket = + createClient(TcpClientTransport.create(server.address()), f -> f.fragment(2)) + .doFinally(s -> server.dispose()); + + StepVerifier.create(rSocket) + .expectErrorMatches( + err -> + err instanceof IllegalArgumentException + && "smallest allowed mtu size is 64 bytes, provided: 2" + .equals(err.getMessage())) + .verify(Duration.ofSeconds(5)); + } + + @ParameterizedTest + @MethodSource("serverTransportProvider") + void clientSucceedsWithEnabledFragmentationOnSufficientMtu( + ServerTransport serverTransport) { + CloseableChannel server = createServer(serverTransport, f -> f.fragment(100)).block(); + + Mono rSocket = + createClient(TcpClientTransport.create(server.address()), f -> f.fragment(100)) + .doFinally(s -> server.dispose()); + StepVerifier.create(rSocket).expectNextCount(1).expectComplete().verify(Duration.ofSeconds(5)); + } + + @ParameterizedTest + @MethodSource("serverTransportProvider") + void clientSucceedsWithDisabledFragmentation() { + CloseableChannel server = + createServer(TcpServerTransport.create("localhost", 0), Function.identity()).block(); + + Mono rSocket = + createClient(TcpClientTransport.create(server.address()), Function.identity()) + .doFinally(s -> server.dispose()); + StepVerifier.create(rSocket).expectNextCount(1).expectComplete().verify(Duration.ofSeconds(5)); + } + + private Mono createClient( + ClientTransport transport, Function f) { + return f.apply(RSocketFactory.connect()).transport(transport).start(); + } + + private Mono createServer( + ServerTransport transport, + Function f) { + return f.apply(receive()).acceptor(mockAcceptor()).transport(transport).start(); + } + + private SocketAcceptor mockAcceptor() { + SocketAcceptor mock = Mockito.mock(SocketAcceptor.class); + Mockito.when(mock.accept(Mockito.any(), Mockito.any())) + .thenReturn(Mono.just(Mockito.mock(RSocket.class))); + return mock; + } + + static Stream> serverTransportProvider() { + String host = "localhost"; + int port = 0; + return Stream.of( + TcpServerTransport.create(host, port), WebsocketServerTransport.create(host, port)); + } +}