From 533dab13d09073754c36ff2ba261442c2a600e67 Mon Sep 17 00:00:00 2001 From: Adrian Todt Date: Fri, 14 Feb 2020 22:10:29 -0300 Subject: [PATCH 1/7] An reactive experiment, using Project Reactor --- build.gradle.kts | 1 + src/main/java/com/rethinkdb/ast/ReqlAst.java | 81 +++++-- .../java/com/rethinkdb/net/Connection.java | 74 +++---- src/main/java/com/rethinkdb/net/Cursor.java | 45 ---- .../java/com/rethinkdb/net/CursorImpl.java | 209 ------------------ .../com/rethinkdb/net/ResponseHandler.java | 140 ++++++++++++ src/main/java/com/rethinkdb/net/Util.java | 12 +- .../java/com/rethinkdb/RethinkDBTest.java | 1 - .../java/com/rethinkdb/TestingCommon.java | 1 - 9 files changed, 241 insertions(+), 323 deletions(-) delete mode 100644 src/main/java/com/rethinkdb/net/Cursor.java delete mode 100644 src/main/java/com/rethinkdb/net/CursorImpl.java create mode 100644 src/main/java/com/rethinkdb/net/ResponseHandler.java diff --git a/build.gradle.kts b/build.gradle.kts index b6d67863..7f73b0a7 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -23,6 +23,7 @@ dependencies { compile("org.slf4j:slf4j-api:1.7.12") compile("org.jetbrains:annotations:17.0.0") compile("com.fasterxml.jackson.core:jackson-databind:2.10.0") + compile("io.projectreactor:reactor-core:3.3.2.RELEASE") } signing { diff --git a/src/main/java/com/rethinkdb/ast/ReqlAst.java b/src/main/java/com/rethinkdb/ast/ReqlAst.java index 8ddabcdc..cb18ab48 100644 --- a/src/main/java/com/rethinkdb/ast/ReqlAst.java +++ b/src/main/java/com/rethinkdb/ast/ReqlAst.java @@ -1,11 +1,14 @@ package com.rethinkdb.ast; +import com.fasterxml.jackson.core.type.TypeReference; import com.rethinkdb.gen.exc.ReqlDriverError; import com.rethinkdb.gen.proto.TermType; import com.rethinkdb.model.Arguments; import com.rethinkdb.model.OptArgs; import com.rethinkdb.net.Connection; +import reactor.core.publisher.Flux; +import java.lang.reflect.Type; import java.util.*; import java.util.stream.Collectors; @@ -54,10 +57,9 @@ protected Object build() { * which may be iterated to get a sequence of atom results * * @param conn The connection to run this query - * @param The type of result * @return The result of this query */ - public T run(Connection conn) { + public Flux run(Connection conn) { return conn.run(this, new OptArgs(), null); } @@ -69,46 +71,76 @@ public T run(Connection conn) { * * @param conn The connection to run this query * @param runOpts The options to run this query with - * @param The type of result * @return The result of this query */ - public T run(Connection conn, OptArgs runOpts) { + public Flux run(Connection conn, OptArgs runOpts) { return conn.run(this, runOpts, null); } /** * Runs this query via connection {@code conn} with default options and returns an atom result * or a sequence result as a cursor. The atom result representing a JSON object is converted - * to an object of type {@code Class

} specified with {@code pojoClass}. The cursor + * to an object of type {@code Class} specified with {@code typeRef}. The cursor * is a {@code com.rethinkdb.net.Cursor} which may be iterated to get a sequence of atom results - * of type {@code Class

} + * of type {@code Class} * - * @param conn The connection to run this query - * @param pojoClass The class of POJO to convert to * @param The type of result - * @param

The type of POJO to convert to + * @param conn The connection to run this query + * @param typeRef The class of POJO to convert to * @return The result of this query (either a {@code P or a Cursor

} */ - public T run(Connection conn, Class

pojoClass) { - return conn.run(this, new OptArgs(), pojoClass); + public Flux run(Connection conn, Class typeRef) { + return conn.run(this, new OptArgs(), new ClassReference<>(typeRef)); } /** * Runs this query via connection {@code conn} with options {@code runOpts} and returns an atom result * or a sequence result as a cursor. The atom result representing a JSON object is converted - * to an object of type {@code Class

} specified with {@code pojoClass}. The cursor + * to an object of type {@code Class} specified with {@code typeRef}. The cursor * is a {@code com.rethinkdb.net.Cursor} which may be iterated to get a sequence of atom results - * of type {@code Class

} + * of type {@code Class} * + * @param The type of result * @param conn The connection to run this query * @param runOpts The options to run this query with - * @param pojoClass The class of POJO to convert to + * @param typeRef The class of POJO to convert to + * @return The result of this query (either a {@code P or a Cursor

} + */ + public Flux run(Connection conn, OptArgs runOpts, Class typeRef) { + return conn.run(this, runOpts, new ClassReference<>(typeRef)); + } + + /** + * Runs this query via connection {@code conn} with default options and returns an atom result + * or a sequence result as a cursor. The atom result representing a JSON object is converted + * to an object of type {@code Class} specified with {@code typeRef}. The cursor + * is a {@code com.rethinkdb.net.Cursor} which may be iterated to get a sequence of atom results + * of type {@code Class} + * * @param The type of result - * @param

The type of POJO to convert to + * @param conn The connection to run this query + * @param typeRef The class of POJO to convert to * @return The result of this query (either a {@code P or a Cursor

} */ - public T run(Connection conn, OptArgs runOpts, Class

pojoClass) { - return conn.run(this, runOpts, pojoClass); + public Flux run(Connection conn, TypeReference typeRef) { + return conn.run(this, new OptArgs(), typeRef); + } + + /** + * Runs this query via connection {@code conn} with options {@code runOpts} and returns an atom result + * or a sequence result as a cursor. The atom result representing a JSON object is converted + * to an object of type {@code Class} specified with {@code typeRef}. The cursor + * is a {@code com.rethinkdb.net.Cursor} which may be iterated to get a sequence of atom results + * of type {@code Class} + * + * @param The type of result + * @param conn The connection to run this query + * @param runOpts The options to run this query with + * @param typeRef The class of POJO to convert to + * @return The result of this query (either a {@code P or a Cursor

} + */ + public Flux run(Connection conn, OptArgs runOpts, TypeReference typeRef) { + return conn.run(this, runOpts, typeRef); } public void runNoReply(Connection conn) { @@ -127,4 +159,17 @@ public String toString() { ", optargs=" + optargs + '}'; } -} + + static class ClassReference extends TypeReference { + private Class c; + + ClassReference(Class c) { + this.c = c; + } + + @Override + public Type getType() { + return c; + } + } +} \ No newline at end of file diff --git a/src/main/java/com/rethinkdb/net/Connection.java b/src/main/java/com/rethinkdb/net/Connection.java index c3d0b4bc..78d6a930 100644 --- a/src/main/java/com/rethinkdb/net/Connection.java +++ b/src/main/java/com/rethinkdb/net/Connection.java @@ -1,5 +1,6 @@ package com.rethinkdb.net; +import com.fasterxml.jackson.core.type.TypeReference; import com.rethinkdb.ast.Query; import com.rethinkdb.ast.ReqlAst; import com.rethinkdb.gen.ast.Db; @@ -7,14 +8,16 @@ import com.rethinkdb.model.Arguments; import com.rethinkdb.model.OptArgs; import org.jetbrains.annotations.Nullable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; import javax.net.ssl.SSLContext; import java.io.Closeable; import java.io.IOException; import java.io.InputStream; import java.net.SocketAddress; -import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.*; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.locks.ReentrantLock; @@ -35,7 +38,6 @@ public class Connection implements Closeable { // network stuff private @Nullable SocketWrapper socket; - private Map> cursorCache = new ConcurrentHashMap<>(); // execution stuff private ExecutorService exec; @@ -154,10 +156,12 @@ public void close(boolean shouldNoreplyWait) { nextToken.set(0); // clear cursor cache - for (CursorImpl cursor : cursorCache.values()) { - cursor.setError("Connection is closed."); + for (ResponseHandler handler : tracked) { + try { + handler.onConnectionClosed(); + } catch (InterruptedException ignored) { + } } - cursorCache.clear(); // handle current awaiters this.awaiters.values().forEach(awaiter -> { @@ -197,7 +201,7 @@ public void use(String db) { * @param query the query to execute. * @return a completable future. */ - private Future sendQuery(Query query) { + private CompletableFuture sendQuery(Query query) { // check if response pump is running if (!exec.isShutdown() && !exec.isTerminated()) { final CompletableFuture awaiter = new CompletableFuture<>(); @@ -242,44 +246,16 @@ private void runQueryNoreply(Query query) { throw new ReqlDriverError("Can't write query because response pump is not running."); } - @SuppressWarnings("unchecked") - private T runQuery(Query query, @Nullable Class

pojoClass) { - Response res; - try { - res = sendQuery(query).get(); - } catch (InterruptedException | ExecutionException e) { - throw new ReqlDriverError(e); - } - - if (res.isAtom()) { - try { - Converter.FormatOptions fmt = new Converter.FormatOptions(query.globalOptions); - Object value = ((List) Converter.convertPseudotypes(res.data, fmt)).get(0); - return Util.convertToPojo(value, pojoClass); - } catch (IndexOutOfBoundsException ex) { - throw new ReqlDriverError("Atom response was empty!", ex); - } - } else if (res.isPartial() || res.isSequence()) { - return (T) new CursorImpl<>(this, query, res, (Class) pojoClass); - } else if (res.isWaitComplete()) { - return null; - } else { - throw res.makeError(query); - } + private Flux runQuery(Query query, @Nullable TypeReference typeRef) { + return Mono.fromFuture(sendQuery(query)).onErrorMap(ReqlDriverError::new) + .flatMapMany(res -> Flux.create(new ResponseHandler<>(this, query, res, typeRef))); } private long newToken() { return nextToken.incrementAndGet(); } - void addToCache(long token, CursorImpl cursor) { - cursorCache.put(token, cursor); - } - - void removeFromCache(long token) { - cursorCache.remove(token); - } - + // unused for some reason public void noreplyWait() { runQuery(Query.noreplyWait(newToken()), null); } @@ -296,13 +272,13 @@ private void setDefaultDB(OptArgs globalOpts) { } } - public T run(ReqlAst term, OptArgs globalOpts, @Nullable Class

pojoClass) { + public Flux run(ReqlAst term, OptArgs globalOpts, @Nullable TypeReference typeRef) { setDefaultDB(globalOpts); Query q = Query.start(newToken(), term, globalOpts); if (globalOpts.containsKey("noreply")) { throw new ReqlDriverError("Don't provide the noreply option as an optarg. Use `.runNoReply` instead of `.run`"); } - return runQuery(q, pojoClass); + return runQuery(q, typeRef); } public void runNoReply(ReqlAst term, OptArgs globalOpts) { @@ -311,15 +287,25 @@ public void runNoReply(ReqlAst term, OptArgs globalOpts) { runQueryNoreply(Query.start(newToken(), term, globalOpts)); } - Future continue_(Cursor cursor) { - return sendQuery(Query.continue_(cursor.connectionToken())); + CompletableFuture continueResponse(long token) { + return sendQuery(Query.continue_(token)); } - void stop(Cursor cursor) { + void stop(long token) { // While the server does reply to the stop request, we ignore that reply. // This works because the response pump in `connect` ignores replies for which // no waiter exists. - runQueryNoreply(Query.stop(cursor.connectionToken())); + runQueryNoreply(Query.stop(token)); + } + + Set> tracked = ConcurrentHashMap.newKeySet(); + + public void loseTrackOf(ResponseHandler r) { + tracked.add(r); + } + + public void keepTrackOf(ResponseHandler r) { + tracked.remove(r); } /** diff --git a/src/main/java/com/rethinkdb/net/Cursor.java b/src/main/java/com/rethinkdb/net/Cursor.java deleted file mode 100644 index 9ec6c50f..00000000 --- a/src/main/java/com/rethinkdb/net/Cursor.java +++ /dev/null @@ -1,45 +0,0 @@ -package com.rethinkdb.net; - -import java.io.Closeable; -import java.util.Iterator; -import java.util.List; -import java.util.concurrent.TimeoutException; -import java.util.function.BiConsumer; -import java.util.stream.Collector; -import java.util.stream.Collectors; -import java.util.stream.Stream; -import java.util.stream.StreamSupport; - -public interface Cursor extends Iterator, Iterable, Closeable { - long connectionToken(); - - Converter.FormatOptions formatOptions(); - - @Override - void close(); - - boolean isFeed(); - - int bufferedSize(); - - T next(long timeout) throws TimeoutException; - - default List toList() { - return collect(Collectors.toList()); - } - - default R collect(Collector collector) { - A container = collector.supplier().get(); - BiConsumer accumulator = collector.accumulator(); - forEachRemaining(next -> accumulator.accept(container, next)); - return collector.finisher().apply(container); - } - - default Stream stream() { - return StreamSupport.stream(spliterator(), false); - } - - default Stream parallelStream() { - return StreamSupport.stream(spliterator(), true); - } -} diff --git a/src/main/java/com/rethinkdb/net/CursorImpl.java b/src/main/java/com/rethinkdb/net/CursorImpl.java deleted file mode 100644 index 9ba926a5..00000000 --- a/src/main/java/com/rethinkdb/net/CursorImpl.java +++ /dev/null @@ -1,209 +0,0 @@ -package com.rethinkdb.net; - -import com.rethinkdb.ast.Query; -import com.rethinkdb.gen.exc.ReqlDriverError; -import com.rethinkdb.gen.exc.ReqlRuntimeError; -import com.rethinkdb.gen.proto.ResponseType; -import org.jetbrains.annotations.Nullable; - -import java.util.Deque; -import java.util.Iterator; -import java.util.NoSuchElementException; -import java.util.concurrent.Future; -import java.util.concurrent.LinkedBlockingDeque; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; - -class CursorImpl implements Cursor { - private final Connection connection; - private final Query query; - private final long token; - private final boolean feed; - @Nullable - private final Class pojoClass; - private final Converter.FormatOptions fmt; - private final Deque bufferedItems = new LinkedBlockingDeque<>(); - - // mutable members - private int outstandingRequests = 0; - private int threshold = 1; - @Nullable - private RuntimeException error = null; - private boolean alreadyIterated = false; - private Future awaitingContinue = null; - - //constructor - CursorImpl(Connection connection, Query query, Response firstResponse, @Nullable Class pojoClass) { - this.connection = connection; - this.query = query; - this.token = query.token; - this.feed = firstResponse.isFeed(); - this.pojoClass = pojoClass; - this.fmt = new Converter.FormatOptions(query.globalOptions); - connection.addToCache(query.token, this); - maybeSendContinue(); - extendInternal(firstResponse); - } - - //region Cursor implementation - @Override - public long connectionToken() { - return token; - } - - @Override - public Converter.FormatOptions formatOptions() { - return fmt; - } - - @Override - public boolean isFeed() { - return feed; - } - - @Override - public int bufferedSize() { - return bufferedItems.size(); - } - - @Override - public T next(long timeout) throws TimeoutException { - while (bufferedItems.size() == 0) { - maybeSendContinue(); - waitOnCursorItems(timeout); - - if (bufferedItems.size() != 0) { - break; - } - - if (error != null) { - throw error; - } - } - - return Util.convertToPojo(Converter.convertPseudotypes(bufferedItems.pop(), fmt), pojoClass); - } - - @Override - public void close() { - connection.removeFromCache(this.token); - if (error == null) { - error = new NoSuchElementException(); - if (connection.isOpen()) { - outstandingRequests += 1; - connection.stop(this); - } - } - } - - @Override - public Iterator iterator() { - if (!alreadyIterated) { - alreadyIterated = true; - return this; - } - throw new ReqlDriverError("The results of this query have already been consumed."); - } - - @Override - public boolean hasNext() { - if (bufferedItems.size() > 0) { - return true; - } - if (error != null) { - return false; - } - if (feed) { - return true; - } - - maybeSendContinue(); - waitOnCursorItems(); - - return bufferedItems.size() > 0; - } - - @Override - public T next() { - while (bufferedItems.size() == 0) { - maybeSendContinue(); - waitOnCursorItems(); - - if (bufferedItems.size() != 0) { - break; - } - - if (error != null) { - throw error; - } - } - - return Util.convertToPojo(Converter.convertPseudotypes(bufferedItems.pop(), fmt), pojoClass); - } - - //end - - //region internals - private void maybeSendContinue() { - if (error == null && bufferedItems.size() < threshold && outstandingRequests == 0) { - outstandingRequests += 1; - this.awaitingContinue = connection.continue_(this); - } - } - - private void waitOnCursorItems() { - Response res; - try { - res = this.awaitingContinue.get(); - } catch (Exception e) { - throw new ReqlDriverError(e); - } - extend(res); - } - - private void waitOnCursorItems(long timeout) throws TimeoutException { - Response res; - try { - res = this.awaitingContinue.get(timeout, TimeUnit.MILLISECONDS); - } catch (TimeoutException exc) { - throw exc; - } catch (Exception e) { - throw new ReqlDriverError(e); - } - extend(res); - } - - private void extend(Response response) { - outstandingRequests -= 1; - maybeSendContinue(); - extendInternal(response); - } - - private void extendInternal(Response response) { - threshold = response.data.size(); - if (error == null) { - if (response.isPartial()) { - bufferedItems.addAll(response.data); - } else if (response.isSequence()) { - bufferedItems.addAll(response.data); - error = new NoSuchElementException(); - } else { - error = response.makeError(query); - } - } - if (outstandingRequests == 0 && error != null) { - connection.removeFromCache(response.token); - } - } - - void setError(String errMsg) { - if (error == null) { - error = new ReqlRuntimeError(errMsg); - Response dummyResponse = Response - .make(query.token, ResponseType.SUCCESS_SEQUENCE) - .build(); - extendInternal(dummyResponse); - } - } - //end -} diff --git a/src/main/java/com/rethinkdb/net/ResponseHandler.java b/src/main/java/com/rethinkdb/net/ResponseHandler.java new file mode 100644 index 00000000..7569a21b --- /dev/null +++ b/src/main/java/com/rethinkdb/net/ResponseHandler.java @@ -0,0 +1,140 @@ +package com.rethinkdb.net; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.rethinkdb.ast.Query; +import com.rethinkdb.gen.exc.ReqlDriverError; +import com.rethinkdb.gen.exc.ReqlRuntimeError; +import com.rethinkdb.gen.proto.ResponseType; +import reactor.core.publisher.FluxSink; +import reactor.core.publisher.Mono; + +import java.util.List; +import java.util.concurrent.Semaphore; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; + +public class ResponseHandler implements Consumer> { + private final Connection connection; + private final Query query; + private final Response firstRes; + private final TypeReference typeRef; + private final Converter.FormatOptions fmt; + + // This gets used if it's a partial request. + private final Semaphore requesting = new Semaphore(1); + private final Semaphore emitting = new Semaphore(1); + private final AtomicLong requestCount = new AtomicLong(); + private final AtomicReference currentResponse = new AtomicReference<>(); + private final AtomicReference> sink = new AtomicReference<>(); + + public ResponseHandler(Connection connection, Query query, Response firstRes, TypeReference typeRef) { + this.connection = connection; + this.query = query; + this.firstRes = firstRes; + this.typeRef = typeRef; + fmt = new Converter.FormatOptions(query.globalOptions); + currentResponse.set(firstRes); + } + + @Override + public void accept(final FluxSink sink) { + if (firstRes.isWaitComplete()) { + sink.complete(); + return; + } + + if (firstRes.isAtom() || firstRes.isSequence()) { + try { + emitData(sink); + } catch (IndexOutOfBoundsException ex) { + throw new ReqlDriverError("Atom response was empty!", ex); + } + sink.complete(); + return; + } + + if (firstRes.isPartial()) { + // Welcome to the code documentation of partial sequences, please take a seat. + + // First of all, we emit all of this request. Reactor's buffer should handle this. + emitData(sink); + + // It is a partial response, so connection should be able to kill us if needed, and clients should be able to stop us. + this.sink.set(sink); + sink.onCancel(() -> { + connection.loseTrackOf(this); + connection.stop(firstRes.token); + }); + sink.onDispose(() -> connection.loseTrackOf(this)); + connection.keepTrackOf(this); + + // We can't simply overflow buffers, so we gotta do small batches. + sink.onRequest(amount -> onRequest(sink, amount)); + return; + } + + sink.error(firstRes.makeError(query)); + } + + private void onRequest(FluxSink sink, long amount) { + final Response lastRes = currentResponse.get(); + if (lastRes.isPartial() && requestCount.addAndGet(amount) > 0 && requesting.tryAcquire()) { + // great, we should make a CONTINUE request. + + // TODO isolate this into methods + Mono.fromFuture(connection.continueResponse(lastRes.token)).subscribe( + nextRes -> { // Okay, let's process this response. + boolean shouldContinue = currentResponse.compareAndSet(lastRes, nextRes); + if (nextRes.isSequence()) { + try { + emitting.acquire(); + emitData(sink); + emitting.release(); + sink.complete(); // Completed. This means it's over. + } catch (InterruptedException e) { + sink.error(e); // It errored. This means it's over. + } + } else if (nextRes.isPartial()) { + // Okay, we got another partial response, so there's more. + + requesting.release(); // Request's over, release this for later. + try { + emitting.acquire(); + int count = emitData(sink); + emitting.release(); + if (shouldContinue) { + onRequest(sink, -count); //Recursion! + } + } catch (InterruptedException e) { + sink.error(e); // It errored. This means it's over. + } + } else { + sink.error(nextRes.makeError(query)); // It errored. This means it's over. + } + }, sink::error // It errored. This means it's over. + ); + } + } + + void onConnectionClosed() throws InterruptedException { + // This will spin wait for a bit until it is not null + while (sink.compareAndSet(null, null)) Thread.yield(); + FluxSink sink = this.sink.get(); + currentResponse.set(Response.make(query.token, ResponseType.SUCCESS_SEQUENCE).build()); + try { + emitting.acquire(); + } finally { + sink.error(new ReqlRuntimeError("Connection is closed.")); + } + } + + @SuppressWarnings("unchecked") + private int emitData(final FluxSink sink) { + List objects = (List) Converter.convertPseudotypes(firstRes.data, fmt); + for (Object each : objects) { + sink.next(Util.convertToPojo(each, typeRef)); + } + return objects.size(); + } +} diff --git a/src/main/java/com/rethinkdb/net/Util.java b/src/main/java/com/rethinkdb/net/Util.java index 4019a7d5..f12178b8 100644 --- a/src/main/java/com/rethinkdb/net/Util.java +++ b/src/main/java/com/rethinkdb/net/Util.java @@ -1,5 +1,6 @@ package com.rethinkdb.net; +import com.fasterxml.jackson.core.type.TypeReference; import com.rethinkdb.RethinkDB; import com.rethinkdb.gen.exc.ReqlDriverError; @@ -56,17 +57,18 @@ public static Map toJSON(ByteBuffer buf) { } @SuppressWarnings("unchecked") - public static T convertToPojo(Object value, Class

pojoClass) { - if (pojoClass != null) { - if (pojoClass.isEnum()) { - Enum[] enumConstants = ((Class>) pojoClass).getEnumConstants(); + public static T convertToPojo(Object value, TypeReference typeRef) { + if (typeRef != null) { + Class rawClass = RethinkDB.getInternalMapper().getTypeFactory().constructType(typeRef).getRawClass(); + if (rawClass.isEnum()) { + Enum[] enumConstants = ((Class>) rawClass).getEnumConstants(); for (Enum enumConst : enumConstants) { if (enumConst.name().equals(value)) { return (T) enumConst; } } } else if (value instanceof Map) { - return (T) RethinkDB.getPOJOMapper().convertValue(value, pojoClass); + return (T) RethinkDB.getPOJOMapper().convertValue(value, typeRef); } } return (T) value; diff --git a/src/test/java/com/rethinkdb/RethinkDBTest.java b/src/test/java/com/rethinkdb/RethinkDBTest.java index df006aec..5915a936 100644 --- a/src/test/java/com/rethinkdb/RethinkDBTest.java +++ b/src/test/java/com/rethinkdb/RethinkDBTest.java @@ -5,7 +5,6 @@ import com.rethinkdb.model.MapObject; import com.rethinkdb.model.OptArgs; import com.rethinkdb.net.Connection; -import com.rethinkdb.net.Cursor; import net.jodah.concurrentunit.Waiter; import org.junit.*; import org.junit.rules.ExpectedException; diff --git a/src/test/java/com/rethinkdb/TestingCommon.java b/src/test/java/com/rethinkdb/TestingCommon.java index d1902246..d6d9ffbc 100644 --- a/src/test/java/com/rethinkdb/TestingCommon.java +++ b/src/test/java/com/rethinkdb/TestingCommon.java @@ -3,7 +3,6 @@ import com.rethinkdb.ast.ReqlAst; import com.rethinkdb.model.OptArgs; import com.rethinkdb.net.Connection; -import com.rethinkdb.net.Cursor; import java.time.Instant; import java.time.OffsetDateTime; From 27504124a848b6321d577e312ce1e23e35a79ec0 Mon Sep 17 00:00:00 2001 From: Adrian Todt Date: Fri, 14 Feb 2020 23:53:23 -0300 Subject: [PATCH 2/7] small changes --- src/main/java/com/rethinkdb/net/Connection.java | 4 +--- src/main/java/com/rethinkdb/net/ResponseHandler.java | 8 +++++++- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/main/java/com/rethinkdb/net/Connection.java b/src/main/java/com/rethinkdb/net/Connection.java index 78d6a930..3337441b 100644 --- a/src/main/java/com/rethinkdb/net/Connection.java +++ b/src/main/java/com/rethinkdb/net/Connection.java @@ -38,6 +38,7 @@ public class Connection implements Closeable { // network stuff private @Nullable SocketWrapper socket; + Set> tracked = ConcurrentHashMap.newKeySet(); // execution stuff private ExecutorService exec; @@ -255,7 +256,6 @@ private long newToken() { return nextToken.incrementAndGet(); } - // unused for some reason public void noreplyWait() { runQuery(Query.noreplyWait(newToken()), null); } @@ -298,8 +298,6 @@ void stop(long token) { runQueryNoreply(Query.stop(token)); } - Set> tracked = ConcurrentHashMap.newKeySet(); - public void loseTrackOf(ResponseHandler r) { tracked.add(r); } diff --git a/src/main/java/com/rethinkdb/net/ResponseHandler.java b/src/main/java/com/rethinkdb/net/ResponseHandler.java index 7569a21b..9948cdac 100644 --- a/src/main/java/com/rethinkdb/net/ResponseHandler.java +++ b/src/main/java/com/rethinkdb/net/ResponseHandler.java @@ -133,7 +133,13 @@ void onConnectionClosed() throws InterruptedException { private int emitData(final FluxSink sink) { List objects = (List) Converter.convertPseudotypes(firstRes.data, fmt); for (Object each : objects) { - sink.next(Util.convertToPojo(each, typeRef)); + if (firstRes.isAtom() && each instanceof List) { + for (Object o : ((List) each)) { + sink.next(Util.convertToPojo(o, typeRef)); + } + } else { + sink.next(Util.convertToPojo(each, typeRef)); + } } return objects.size(); } From 037429c67b454c66464f256a69064588048fc502 Mon Sep 17 00:00:00 2001 From: Adrian Todt Date: Sat, 15 Feb 2020 09:31:59 -0300 Subject: [PATCH 3/7] rewrite Connection, detached socket and response pump --- .../java/com/rethinkdb/net/Connection.java | 343 ++++++++---------- .../com/rethinkdb/net/ConnectionSocket.java | 35 ++ src/main/java/com/rethinkdb/net/Crypto.java | 38 +- .../net/DefaultConnectionFactory.java | 251 +++++++++++++ ...{Handshake.java => HandshakeProtocol.java} | 79 ++-- src/main/java/com/rethinkdb/net/Response.java | 8 +- .../com/rethinkdb/net/ResponseHandler.java | 4 +- .../java/com/rethinkdb/net/ResponsePump.java | 21 ++ .../java/com/rethinkdb/net/SocketWrapper.java | 191 ---------- 9 files changed, 511 insertions(+), 459 deletions(-) create mode 100644 src/main/java/com/rethinkdb/net/ConnectionSocket.java create mode 100644 src/main/java/com/rethinkdb/net/DefaultConnectionFactory.java rename src/main/java/com/rethinkdb/net/{Handshake.java => HandshakeProtocol.java} (87%) create mode 100644 src/main/java/com/rethinkdb/net/ResponsePump.java delete mode 100644 src/main/java/com/rethinkdb/net/SocketWrapper.java diff --git a/src/main/java/com/rethinkdb/net/Connection.java b/src/main/java/com/rethinkdb/net/Connection.java index 3337441b..86a62bcf 100644 --- a/src/main/java/com/rethinkdb/net/Connection.java +++ b/src/main/java/com/rethinkdb/net/Connection.java @@ -13,132 +13,113 @@ import javax.net.ssl.SSLContext; import java.io.Closeable; -import java.io.IOException; import java.io.InputStream; -import java.net.SocketAddress; -import java.util.Map; +import java.nio.ByteBuffer; import java.util.Set; -import java.util.concurrent.*; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; public class Connection implements Closeable { - // public immutable - public final String hostname; - public final int port; + private final ConnectionSocket.Factory socketFactory; + private final ResponsePump.Factory pumpFactory; + private final String hostname; + private final int port; + private final @Nullable SSLContext sslContext; + private final @Nullable Long timeout; + private final @Nullable String user; + private final @Nullable String password; private final AtomicLong nextToken = new AtomicLong(); + private final Set> tracked = ConcurrentHashMap.newKeySet(); + private final Lock writeLock = new ReentrantLock(); - // private mutable private @Nullable String dbname; - private @Nullable Long connectTimeout; - private @Nullable SSLContext sslContext; - private final Handshake handshake; + private @Nullable ConnectionSocket socket; + private @Nullable ResponsePump pump; - // network stuff - private @Nullable SocketWrapper socket; - - Set> tracked = ConcurrentHashMap.newKeySet(); - - // execution stuff - private ExecutorService exec; - private final Map> awaiters = new ConcurrentHashMap<>(); - private Exception awaiterException = null; - private final ReentrantLock lock = new ReentrantLock(); - - public Connection(Builder builder) { - dbname = builder.dbname; - if (builder.authKey != null && builder.user != null) { + public Connection(Builder c) { + if (c.authKey != null && c.user != null) { throw new ReqlDriverError("Either `authKey` or `user` can be used, but not both."); } - String user = builder.user != null ? builder.user : "admin"; - String password = builder.password != null ? builder.password : builder.authKey != null ? builder.authKey : ""; - handshake = new Handshake(user, password); - hostname = builder.hostname != null ? builder.hostname : "localhost"; - port = builder.port != null ? builder.port : 28015; - // is certFile provided? if so, it has precedence over SSLContext - sslContext = Crypto.handleCertfile(builder.certFile, builder.sslContext); - connectTimeout = builder.timeout; + this.socketFactory = c.socketFactory != null ? c.socketFactory : DefaultConnectionFactory.INSTANCE; + this.pumpFactory = c.pumpFactory != null ? c.pumpFactory : DefaultConnectionFactory.INSTANCE; + this.hostname = c.hostname != null ? c.hostname : "localhost"; + this.port = c.port != null ? c.port : 28015; + this.dbname = c.dbname; + this.sslContext = c.sslContext; + this.timeout = c.timeout; + this.user = c.user != null ? c.user : "admin"; + this.password = c.password != null ? c.password : c.authKey != null ? c.authKey : ""; } public @Nullable String db() { return dbname; } - public void connect() { - connect(null); + public void use(String db) { + dbname = db; } - public Connection reconnect() { - return reconnect(false, null); + public boolean isOpen() { + return socket != null && socket.isOpen() && pump != null; } - public Connection reconnect(boolean noreplyWait, @Nullable Long timeout) { - if (timeout == null) { - timeout = connectTimeout; + public Connection connect() { + if (socket != null) { + throw new ReqlDriverError("Client already connected!"); } - close(noreplyWait); - connect(timeout); + ConnectionSocket socket = socketFactory.newSocket(hostname, port, sslContext, timeout); + this.socket = socket; + + // execute RethinkDB handshake + HandshakeProtocol handshake = HandshakeProtocol.start(user, password); + + // initialize handshake + ByteBuffer toWrite = handshake.toSend(); + // Sit in the handshake until it's completed. Exceptions will be thrown if + // anything goes wrong. + while (!handshake.isFinished()) { + if (toWrite != null) { + socket.write(toWrite); + } + String serverMsg = socket.readCString(timeout); + handshake = handshake.nextState(serverMsg); + toWrite = handshake.toSend(); + } + + pump = pumpFactory.newPump(socket); return this; } - private void connect(@Nullable Long timeout) { - final SocketWrapper sock = new SocketWrapper(hostname, port, sslContext, timeout != null ? timeout : connectTimeout); - sock.connect(handshake); - socket = sock; - - // start response pump - exec = Executors.newSingleThreadExecutor(); - exec.submit(() -> { - // pump responses until canceled - while (true) { - // validate socket is open - if (!isOpen()) { - awaiterException = new IOException("The socket is closed, exiting response pump."); - this.close(); - break; - } + public Connection reconnect() { + return reconnect(false); + } - // read response and send it to whoever is waiting, if anyone - try { - if (socket == null) { - throw new ReqlDriverError("No socket available."); - } - final Response response = socket.read(); - final CompletableFuture awaiter = awaiters.remove(response.token); - if (awaiter != null) { - awaiter.complete(response); - } - } catch (Exception e) { - awaiterException = e; - this.close(); - break; - } - } - }); + public Connection reconnect(boolean noreplyWait) { + close(noreplyWait); + connect(); + return this; } - @Nullable - public Integer clientPort() { - if (socket != null) { - return socket.clientPort(); - } - return null; + public void noreplyWait() { + runQuery(Query.noreplyWait(nextToken.incrementAndGet()), null); } - @Nullable - public SocketAddress clientAddress() { - if (socket != null) { - return socket.clientAddress(); + public Flux run(ReqlAst term, OptArgs optArgs, @Nullable TypeReference typeRef) { + handleOptArgs(optArgs); + Query q = Query.start(nextToken.incrementAndGet(), term, optArgs); + if (optArgs.containsKey("noreply")) { + throw new ReqlDriverError("Don't provide the noreply option as an optarg. Use `.runNoReply` instead of `.run`"); } - return null; + return runQuery(q, typeRef); } - public boolean isOpen() { - if (socket != null) { - return socket.isOpen(); - } - return false; + public void runNoReply(ReqlAst term, OptArgs optArgs) { + handleOptArgs(optArgs); + optArgs.with("noreply", true); + runQueryNoreply(Query.start(nextToken.incrementAndGet(), term, optArgs)); } @Override @@ -149,9 +130,7 @@ public void close() { public void close(boolean shouldNoreplyWait) { // disconnect try { - if (shouldNoreplyWait) { - runQuery(Query.noreplyWait(newToken()), null); - } + noreplyWait(); } finally { // reset token nextToken.set(0); @@ -164,20 +143,9 @@ public void close(boolean shouldNoreplyWait) { } } - // handle current awaiters - this.awaiters.values().forEach(awaiter -> { - // what happened? - if (this.awaiterException != null) { // an exception - awaiter.completeExceptionally(this.awaiterException); - } else { // probably canceled - awaiter.cancel(true); - } - }); - awaiters.clear(); - // terminate response pump - if (exec != null && !exec.isShutdown()) { - exec.shutdown(); + if (pump != null) { + pump.shutdownPump(); } // close the socket @@ -187,14 +155,30 @@ public void close(boolean shouldNoreplyWait) { } } - public void use(String db) { - dbname = db; + // package-private methods + + void sendStop(long token) { + // While the server does reply to the stop request, we ignore that reply. + // This works because the response pump in `connect` ignores replies for which + // no waiter exists. + runQueryNoreply(Query.stop(token)); } - public @Nullable Long timeout() { - return connectTimeout; + Mono sendContinue(long token) { + return sendQuery(Query.continue_(token)); + } + + void loseTrackOf(ResponseHandler r) { + tracked.add(r); + } + + void keepTrackOf(ResponseHandler r) { + tracked.remove(r); } + // private methods + + /** * Writes a query and returns a completable future. * Said completable future value will eventually be set by the runnable response pump (see {@link #connect}). @@ -202,25 +186,23 @@ public void use(String db) { * @param query the query to execute. * @return a completable future. */ - private CompletableFuture sendQuery(Query query) { - // check if response pump is running - if (!exec.isShutdown() && !exec.isTerminated()) { - final CompletableFuture awaiter = new CompletableFuture<>(); - awaiters.put(query.token, awaiter); - try { - lock.lock(); - if (socket == null) { - throw new ReqlDriverError("No socket available."); - } - socket.write(query.serialize()); - return awaiter.toCompletableFuture(); - } finally { - lock.unlock(); - } + private Mono sendQuery(Query query) { + if (socket == null || !socket.isOpen()) { + throw new ReqlDriverError("Client not connected."); + } + + if (pump == null) { + throw new ReqlDriverError("Response pump is not running."); } - // shouldn't be here - throw new ReqlDriverError("Can't write query because response pump is not running."); + Mono response = pump.await(query.token); + try { + writeLock.lock(); + socket.write(query.serialize()); + return response; + } finally { + writeLock.unlock(); + } } /** @@ -229,103 +211,64 @@ private CompletableFuture sendQuery(Query query) { * @param query the query to execute. */ private void runQueryNoreply(Query query) { - // check if response pump is running - if (!exec.isShutdown() && !exec.isTerminated()) { - try { - lock.lock(); - if (socket == null) { - throw new ReqlDriverError("No socket available."); - } - socket.write(query.serialize()); - return; - } finally { - lock.unlock(); - } + if (socket == null || !socket.isOpen()) { + throw new ReqlDriverError("Client not connected."); + } + + if (pump == null) { + throw new ReqlDriverError("Response pump is not running."); } - // shouldn't be here - throw new ReqlDriverError("Can't write query because response pump is not running."); + try { + writeLock.lock(); + socket.write(query.serialize()); + } finally { + writeLock.unlock(); + } } private Flux runQuery(Query query, @Nullable TypeReference typeRef) { - return Mono.fromFuture(sendQuery(query)).onErrorMap(ReqlDriverError::new) + return sendQuery(query).onErrorMap(ReqlDriverError::new) .flatMapMany(res -> Flux.create(new ResponseHandler<>(this, query, res, typeRef))); } - private long newToken() { - return nextToken.incrementAndGet(); - } - - public void noreplyWait() { - runQuery(Query.noreplyWait(newToken()), null); - } - - private void setDefaultDB(OptArgs globalOpts) { - if (!globalOpts.containsKey("db") && dbname != null) { + private void handleOptArgs(OptArgs optArgs) { + if (!optArgs.containsKey("db") && dbname != null) { // Only override the db global arg if the user hasn't // specified one already and one is specified on the connection - globalOpts.with("db", dbname); + optArgs.with("db", dbname); } - if (globalOpts.containsKey("db")) { + if (optArgs.containsKey("db")) { // The db arg must be wrapped in a db ast object - globalOpts.with("db", new Db(Arguments.make(globalOpts.get("db")))); - } - } - - public Flux run(ReqlAst term, OptArgs globalOpts, @Nullable TypeReference typeRef) { - setDefaultDB(globalOpts); - Query q = Query.start(newToken(), term, globalOpts); - if (globalOpts.containsKey("noreply")) { - throw new ReqlDriverError("Don't provide the noreply option as an optarg. Use `.runNoReply` instead of `.run`"); + optArgs.with("db", new Db(Arguments.make(optArgs.get("db")))); } - return runQuery(q, typeRef); - } - - public void runNoReply(ReqlAst term, OptArgs globalOpts) { - setDefaultDB(globalOpts); - globalOpts.with("noreply", true); - runQueryNoreply(Query.start(newToken(), term, globalOpts)); } - CompletableFuture continueResponse(long token) { - return sendQuery(Query.continue_(token)); - } + // builder - void stop(long token) { - // While the server does reply to the stop request, we ignore that reply. - // This works because the response pump in `connect` ignores replies for which - // no waiter exists. - runQueryNoreply(Query.stop(token)); - } - - public void loseTrackOf(ResponseHandler r) { - tracked.add(r); - } - - public void keepTrackOf(ResponseHandler r) { - tracked.remove(r); - } /** - * Connection.Builder should be used to build a Connection instance. + * Builder should be used to build a Connection instance. */ - public static class Builder implements Cloneable { + public static class Builder { + private @Nullable ConnectionSocket.Factory socketFactory; + private @Nullable ResponsePump.Factory pumpFactory; private @Nullable String hostname; private @Nullable Integer port; private @Nullable String dbname; - private @Nullable InputStream certFile; private @Nullable SSLContext sslContext; private @Nullable Long timeout; private @Nullable String authKey; private @Nullable String user; private @Nullable String password; - public Builder clone() throws CloneNotSupportedException { - Builder c = (Builder) super.clone(); + public Builder copyOf() { + Builder c = new Builder(); + c.socketFactory = socketFactory; + c.pumpFactory = pumpFactory; c.hostname = hostname; c.port = port; c.dbname = dbname; - c.certFile = certFile; c.sslContext = sslContext; c.timeout = timeout; c.authKey = authKey; @@ -334,6 +277,16 @@ public Builder clone() throws CloneNotSupportedException { return c; } + public Builder socketFactory(ConnectionSocket.Factory factory) { + socketFactory = factory; + return this; + } + + public Builder pumpFactory(ResponsePump.Factory factory) { + pumpFactory = factory; + return this; + } + public Builder hostname(String val) { hostname = val; return this; @@ -361,7 +314,7 @@ public Builder user(String user, String password) { } public Builder certFile(InputStream val) { - certFile = val; + sslContext = Crypto.readCertFile(val); return this; } diff --git a/src/main/java/com/rethinkdb/net/ConnectionSocket.java b/src/main/java/com/rethinkdb/net/ConnectionSocket.java new file mode 100644 index 00000000..bf8158dd --- /dev/null +++ b/src/main/java/com/rethinkdb/net/ConnectionSocket.java @@ -0,0 +1,35 @@ +package com.rethinkdb.net; + +import org.jetbrains.annotations.NotNull; +import reactor.util.annotation.Nullable; + +import javax.net.ssl.SSLContext; +import java.io.Closeable; +import java.io.IOError; +import java.io.IOException; +import java.nio.ByteBuffer; + +public interface ConnectionSocket extends Closeable { + interface Factory { + ConnectionSocket newSocket(@NotNull String hostname, + int port, + @Nullable SSLContext sslContext, + @Nullable Long timeoutMs); + } + + boolean isOpen(); + + void close(); + + void write(@NotNull ByteBuffer buffer); + + @NotNull ByteBuffer read(int size); + + /** + * Reads a null-terminated string, under a timeout. If time runs out, it throws instead. + * + * @param timeoutMs the timeout, in milliseconds + * @return the string. + */ + @NotNull String readCString(@Nullable Long timeoutMs); +} diff --git a/src/main/java/com/rethinkdb/net/Crypto.java b/src/main/java/com/rethinkdb/net/Crypto.java index 6c3287bb..3ba65c3a 100644 --- a/src/main/java/com/rethinkdb/net/Crypto.java +++ b/src/main/java/com/rethinkdb/net/Crypto.java @@ -139,27 +139,23 @@ static byte[] fromBase64(String string) { return decoder.decode(string); } - static @Nullable SSLContext handleCertfile(@Nullable InputStream certFile, @Nullable SSLContext sslContext) { - if (certFile != null) { - try { - final CertificateFactory cf = CertificateFactory.getInstance("X.509"); - final X509Certificate caCert = (X509Certificate) cf.generateCertificate(certFile); - - final TrustManagerFactory tmf = TrustManagerFactory - .getInstance(TrustManagerFactory.getDefaultAlgorithm()); - KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType()); - ks.load(null); // You don't need the KeyStore instance to come from a file. - ks.setCertificateEntry("caCert", caCert); - tmf.init(ks); - - final SSLContext ssc = SSLContext.getInstance(DEFAULT_SSL_PROTOCOL); - ssc.init(null, tmf.getTrustManagers(), null); - return ssc; - } catch (IOException | CertificateException | NoSuchAlgorithmException | KeyStoreException | KeyManagementException e) { - throw new ReqlDriverError(e); - } - } else { - return sslContext; + static SSLContext readCertFile(@Nullable InputStream certFile) { + try { + final CertificateFactory cf = CertificateFactory.getInstance("X.509"); + final X509Certificate caCert = (X509Certificate) cf.generateCertificate(certFile); + + final TrustManagerFactory tmf = TrustManagerFactory + .getInstance(TrustManagerFactory.getDefaultAlgorithm()); + KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType()); + ks.load(null); // You don't need the KeyStore instance to come from a file. + ks.setCertificateEntry("caCert", caCert); + tmf.init(ks); + + final SSLContext ssc = SSLContext.getInstance(DEFAULT_SSL_PROTOCOL); + ssc.init(null, tmf.getTrustManagers(), null); + return ssc; + } catch (IOException | CertificateException | NoSuchAlgorithmException | KeyStoreException | KeyManagementException e) { + throw new ReqlDriverError(e); } } } diff --git a/src/main/java/com/rethinkdb/net/DefaultConnectionFactory.java b/src/main/java/com/rethinkdb/net/DefaultConnectionFactory.java new file mode 100644 index 00000000..d2bbb896 --- /dev/null +++ b/src/main/java/com/rethinkdb/net/DefaultConnectionFactory.java @@ -0,0 +1,251 @@ +package com.rethinkdb.net; + +import com.rethinkdb.gen.exc.ReqlDriverError; +import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; +import reactor.core.publisher.Mono; + +import javax.net.SocketFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLSocket; +import javax.net.ssl.SSLSocketFactory; +import java.io.DataInputStream; +import java.io.IOException; +import java.io.OutputStream; +import java.net.InetSocketAddress; +import java.net.Socket; +import java.net.SocketAddress; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.Executors; + +public class DefaultConnectionFactory implements ConnectionSocket.Factory, ResponsePump.Factory { + public static final DefaultConnectionFactory INSTANCE = new DefaultConnectionFactory(); + + private DefaultConnectionFactory() { + } + + @Override + public ConnectionSocket newSocket(@NotNull String hostname, int port, SSLContext sslContext, Long timeoutMs) { + SocketWrapper s = new SocketWrapper(hostname, port, sslContext, timeoutMs); + s.connect(); + return s; + } + + @Override + public ResponsePump newPump(@NotNull ConnectionSocket socket) { + return new ThreadResponsePump(socket); + } + + private static class SocketWrapper implements ConnectionSocket { + // networking stuff + private Socket socket; + private SocketFactory socketFactory = SocketFactory.getDefault(); + private SSLSocket sslSocket; + private OutputStream writeStream; + private DataInputStream readStream; + + // options + private SSLContext sslContext; + private Long timeoutMs; + private final String hostname; + private final int port; + + SocketWrapper(String hostname, + int port, + SSLContext sslContext, + Long timeoutMs) { + this.hostname = hostname; + this.port = port; + this.sslContext = sslContext; + this.timeoutMs = timeoutMs; + } + + void connect() { + Long deadline = timeoutMs == null ? null : System.currentTimeMillis() + timeoutMs; + try { + // establish connection + final InetSocketAddress addr = new InetSocketAddress(hostname, port); + socket = socketFactory.createSocket(); + socket.connect(addr, timeoutMs == null ? 0 : timeoutMs.intValue()); + socket.setTcpNoDelay(true); + socket.setKeepAlive(true); + + // should we secure the connection? + if (sslContext != null) { + socketFactory = sslContext.getSocketFactory(); + SSLSocketFactory sslSf = (SSLSocketFactory) socketFactory; + sslSocket = (SSLSocket) sslSf.createSocket(socket, + socket.getInetAddress().getHostAddress(), + socket.getPort(), + true); + + // replace input/output streams + readStream = new DataInputStream(sslSocket.getInputStream()); + writeStream = sslSocket.getOutputStream(); + + // execute SSL handshake + sslSocket.startHandshake(); + } else { + writeStream = socket.getOutputStream(); + readStream = new DataInputStream(socket.getInputStream()); + } + } catch (IOException e) { + throw new ReqlDriverError("Connection timed out.", e); + } + } + + @Override + public void write(ByteBuffer buffer) { + try { + buffer.flip(); + writeStream.write(buffer.array()); + } catch (IOException e) { + throw new ReqlDriverError(e); + } + } + + @NotNull + @Override + public String readCString(@Nullable Long deadline) { + try { + final StringBuilder sb = new StringBuilder(); + char c; + while ((c = (char) this.readStream.readByte()) != '\0') { + // is there a deadline? + if (deadline != null) { + // have we timed-out? + if (deadline < System.currentTimeMillis()) { // reached time-out + throw new ReqlDriverError("Connection timed out."); + } + } + sb.append(c); + } + + return sb.toString(); + } catch (IOException e) { + throw new ReqlDriverError(e); + } + } + + @NotNull + @Override + public ByteBuffer read(int bufsize) { + try { + byte[] buf = new byte[bufsize]; + int bytesRead = 0; + while (bytesRead < bufsize) { + final int res = this.readStream.read(buf, bytesRead, bufsize - bytesRead); + if (res == -1) { + throw new ReqlDriverError("Reached the end of the read stream."); + } else { + bytesRead += res; + } + } + return ByteBuffer.wrap(buf).order(ByteOrder.LITTLE_ENDIAN); + } catch (IOException e) { + throw new ReqlDriverError(e); + } + } + + @Nullable + Integer clientPort() { + if (socket != null) { + return socket.getLocalPort(); + } + return null; + } + + @Nullable + SocketAddress clientAddress() { + if (socket != null) { + return socket.getLocalSocketAddress(); + } + return null; + } + + @Override + public boolean isOpen() { + return socket != null && (socket.isConnected() && !socket.isClosed()); + } + + @Override + public void close() { + // if needed, disconnect from server + if (socket != null && isOpen()) { + try { + socket.close(); + } catch (IOException e) { + throw new ReqlDriverError(e); + } + } + } + + @Override + public String toString() { + return "ConnectionSocket(" + hostname + ':' + port + ')'; + } + } + + private static class ThreadResponsePump implements ResponsePump { + private final Thread thread; + private Map> awaiting = new ConcurrentHashMap<>(); + + public ThreadResponsePump(ConnectionSocket socket) { + this.thread = new Thread(() -> { + // pump responses until interrupted + while (true) { + // validate socket is open + if (!socket.isOpen()) { + shutdown(new IOException("Socket closed, exiting response pump.")); + return; + } + + if (awaiting == null) { + return; + } + + // read response and send it to whoever is waiting, if anyone + try { + final Response response = Response.readFrom(socket); + final CompletableFuture awaiter = awaiting.remove(response.token); + if (awaiter != null) { + awaiter.complete(response); + } + } catch (Exception e) { + shutdown(e); + return; + } + } + }, "RethinkDB-" + socket + "-ResponsePump"); + + } + + @Override + public Mono await(long token) { + CompletableFuture future = new CompletableFuture<>(); + if (awaiting == null) { + throw new ReqlDriverError("Response pump closed."); + } + awaiting.put(token, future); + return Mono.fromFuture(future); + } + + private void shutdown(Exception e) { + Map> awaiting = this.awaiting; + this.awaiting = null; + thread.interrupt(); + awaiting.forEach((token, future) -> { + future.completeExceptionally(e); + }); + } + + @Override + public void shutdownPump() { + shutdown(new ReqlDriverError("Response pump closed.")); + } + } +} diff --git a/src/main/java/com/rethinkdb/net/Handshake.java b/src/main/java/com/rethinkdb/net/HandshakeProtocol.java similarity index 87% rename from src/main/java/com/rethinkdb/net/Handshake.java rename to src/main/java/com/rethinkdb/net/HandshakeProtocol.java index 723d241a..b548df46 100644 --- a/src/main/java/com/rethinkdb/net/Handshake.java +++ b/src/main/java/com/rethinkdb/net/HandshakeProtocol.java @@ -14,30 +14,28 @@ import static com.rethinkdb.net.Util.toJSON; import static com.rethinkdb.net.Util.toUTF8; -public class Handshake { - static final Version VERSION = Version.V1_0; - static final Long SUB_PROTOCOL_VERSION = 0L; - static final Protocol PROTOCOL = Protocol.JSON; - - private static final String CLIENT_KEY = "Client Key"; - private static final String SERVER_KEY = "Server Key"; - - private final String username; - private final String password; - private ProtocolState state; - - public Handshake(String username, String password) { - this.username = username; - this.password = password; - this.state = new InitialState(username, password); - } +abstract class HandshakeProtocol { + public static final Version VERSION = Version.V1_0; + public static final Long SUB_PROTOCOL_VERSION = 0L; + public static final Protocol PROTOCOL = Protocol.JSON; + + public static final String CLIENT_KEY = "Client Key"; + public static final String SERVER_KEY = "Server Key"; - public ByteBuffer nextMessage(String response) { - this.state = this.state.nextState(response); - return this.state.toSend(); + public static HandshakeProtocol start(String username, String password) { + return new InitialState(username, password); } - private void throwIfFailure(Map json) { + private HandshakeProtocol() {} + + public abstract HandshakeProtocol nextState(String response); + + @Nullable + public abstract ByteBuffer toSend(); + + public abstract boolean isFinished(); + + public static void throwIfFailure(Map json) { if (!(boolean) json.get("success")) { Long errorCode = (Long) json.get("error_code"); if (errorCode >= 10 && errorCode <= 20) { @@ -48,20 +46,7 @@ private void throwIfFailure(Map json) { } } - public void reset() { - this.state = new InitialState(this.username, this.password); - } - - private interface ProtocolState { - ProtocolState nextState(String response); - - @Nullable - ByteBuffer toSend(); - - boolean isFinished(); - } - - private class InitialState implements ProtocolState { + public static class InitialState extends HandshakeProtocol { private final String nonce; private final String username; private final byte[] password; @@ -73,7 +58,7 @@ private class InitialState implements ProtocolState { } @Override - public ProtocolState nextState(String response) { + public HandshakeProtocol nextState(String response) { if (response != null) { throw new ReqlDriverError("Unexpected response"); } @@ -110,7 +95,7 @@ public boolean isFinished() { } } - private class WaitingForProtocolRange implements ProtocolState { + public static class WaitingForProtocolRange extends HandshakeProtocol { private final String nonce; private final ByteBuffer message; private final ScramAttributes clientFirstMessageBare; @@ -128,7 +113,7 @@ private class WaitingForProtocolRange implements ProtocolState { } @Override - public ProtocolState nextState(String response) { + public HandshakeProtocol nextState(String response) { Map json = toJSON(response); throwIfFailure(json); long minVersion = (long) json.get("min_protocol_version"); @@ -152,7 +137,7 @@ public boolean isFinished() { } } - private class WaitingForAuthResponse implements ProtocolState { + public static class WaitingForAuthResponse extends HandshakeProtocol { private final String nonce; private final byte[] password; private final ScramAttributes clientFirstMessageBare; @@ -165,7 +150,7 @@ private class WaitingForAuthResponse implements ProtocolState { } @Override - public ProtocolState nextState(String response) { + public HandshakeProtocol nextState(String response) { Map json = toJSON(response); throwIfFailure(json); String serverFirstMessage = (String) json.get("authentication"); @@ -227,9 +212,9 @@ public boolean isFinished() { } } - private class HandshakeSuccess implements ProtocolState { + public static class HandshakeSuccess extends HandshakeProtocol { @Override - public ProtocolState nextState(String response) { + public HandshakeProtocol nextState(String response) { return this; } @@ -244,7 +229,7 @@ public boolean isFinished() { } } - private class WaitingForAuthSuccess implements ProtocolState { + public static class WaitingForAuthSuccess extends HandshakeProtocol { private final byte[] serverSignature; private final ByteBuffer message; @@ -254,7 +239,7 @@ public WaitingForAuthSuccess(byte[] serverSignature, ByteBuffer message) { } @Override - public ProtocolState nextState(String response) { + public HandshakeProtocol nextState(String response) { Map json = toJSON(response); throwIfFailure(json); ScramAttributes auth = ScramAttributes @@ -276,11 +261,7 @@ public boolean isFinished() { } } - public boolean isFinished() { - return this.state.isFinished(); - } - - static class ScramAttributes { + public static class ScramAttributes { @Nullable String _authIdentity; // a @Nullable String _username; // n @Nullable String _nonce; // r diff --git a/src/main/java/com/rethinkdb/net/Response.java b/src/main/java/com/rethinkdb/net/Response.java index c7ca1836..34812c7c 100644 --- a/src/main/java/com/rethinkdb/net/Response.java +++ b/src/main/java/com/rethinkdb/net/Response.java @@ -13,6 +13,7 @@ import org.slf4j.LoggerFactory; import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -48,7 +49,12 @@ private Response(long token, this.errorType = errorType; } - public static Response parseFrom(long token, ByteBuffer buf) { + public static Response readFrom(ConnectionSocket socket) { + final ByteBuffer header = socket.read(12); + final long token = header.getLong(); + final int responseLength = header.getInt(); + final ByteBuffer buf = socket.read(responseLength).order(ByteOrder.LITTLE_ENDIAN); + if (Response.logger.isDebugEnabled()) { Response.logger.debug( "JSON Recv: Token: {} {}", token, Util.bufferToString(buf)); diff --git a/src/main/java/com/rethinkdb/net/ResponseHandler.java b/src/main/java/com/rethinkdb/net/ResponseHandler.java index 9948cdac..a15b6850 100644 --- a/src/main/java/com/rethinkdb/net/ResponseHandler.java +++ b/src/main/java/com/rethinkdb/net/ResponseHandler.java @@ -64,7 +64,7 @@ public void accept(final FluxSink sink) { this.sink.set(sink); sink.onCancel(() -> { connection.loseTrackOf(this); - connection.stop(firstRes.token); + connection.sendStop(firstRes.token); }); sink.onDispose(() -> connection.loseTrackOf(this)); connection.keepTrackOf(this); @@ -83,7 +83,7 @@ private void onRequest(FluxSink sink, long amount) { // great, we should make a CONTINUE request. // TODO isolate this into methods - Mono.fromFuture(connection.continueResponse(lastRes.token)).subscribe( + connection.sendContinue(lastRes.token).subscribe( nextRes -> { // Okay, let's process this response. boolean shouldContinue = currentResponse.compareAndSet(lastRes, nextRes); if (nextRes.isSequence()) { diff --git a/src/main/java/com/rethinkdb/net/ResponsePump.java b/src/main/java/com/rethinkdb/net/ResponsePump.java new file mode 100644 index 00000000..7b4e09db --- /dev/null +++ b/src/main/java/com/rethinkdb/net/ResponsePump.java @@ -0,0 +1,21 @@ +package com.rethinkdb.net; + +import org.jetbrains.annotations.NotNull; +import reactor.core.publisher.Mono; +import reactor.util.annotation.Nullable; + +import javax.net.ssl.SSLContext; +import java.io.Closeable; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.concurrent.CompletableFuture; + +public interface ResponsePump { + interface Factory { + ResponsePump newPump(@NotNull ConnectionSocket socket); + } + + Mono await(long token); + + void shutdownPump(); +} diff --git a/src/main/java/com/rethinkdb/net/SocketWrapper.java b/src/main/java/com/rethinkdb/net/SocketWrapper.java deleted file mode 100644 index 1dedf6ab..00000000 --- a/src/main/java/com/rethinkdb/net/SocketWrapper.java +++ /dev/null @@ -1,191 +0,0 @@ -package com.rethinkdb.net; - -import com.rethinkdb.gen.exc.ReqlDriverError; -import org.jetbrains.annotations.Nullable; - -import javax.net.SocketFactory; -import javax.net.ssl.SSLContext; -import javax.net.ssl.SSLSocket; -import javax.net.ssl.SSLSocketFactory; -import java.io.DataInputStream; -import java.io.IOException; -import java.io.OutputStream; -import java.net.InetSocketAddress; -import java.net.Socket; -import java.net.SocketAddress; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; - -class SocketWrapper { - // networking stuff - private Socket socket = null; - private SocketFactory socketFactory = SocketFactory.getDefault(); - private SSLSocket sslSocket = null; - private OutputStream writeStream = null; - private DataInputStream readStream = null; - - // options - private SSLContext sslContext; - private Long timeout; - private final String hostname; - private final int port; - - SocketWrapper(String hostname, - int port, SSLContext sslContext, - Long timeout) { - this.hostname = hostname; - this.port = port; - this.sslContext = sslContext; - this.timeout = timeout; - } - - /** - * @param handshake - */ - void connect(Handshake handshake) { - Long deadline = timeout == null ? null : System.currentTimeMillis() + timeout; - try { - handshake.reset(); - // establish connection - final InetSocketAddress addr = new InetSocketAddress(hostname, port); - socket = socketFactory.createSocket(); - socket.connect(addr, timeout == null ? 0 : timeout.intValue()); - socket.setTcpNoDelay(true); - socket.setKeepAlive(true); - - // should we secure the connection? - if (sslContext != null) { - socketFactory = sslContext.getSocketFactory(); - SSLSocketFactory sslSf = (SSLSocketFactory) socketFactory; - sslSocket = (SSLSocket) sslSf.createSocket(socket, - socket.getInetAddress().getHostAddress(), - socket.getPort(), - true); - - // replace input/output streams - readStream = new DataInputStream(sslSocket.getInputStream()); - writeStream = sslSocket.getOutputStream(); - - // execute SSL handshake - sslSocket.startHandshake(); - } else { - writeStream = socket.getOutputStream(); - readStream = new DataInputStream(socket.getInputStream()); - } - - // execute RethinkDB handshake - - // initialize handshake - ByteBuffer toWrite = handshake.nextMessage(null); - // Sit in the handshake until it's completed. Exceptions will be thrown if - // anything goes wrong. - while (!handshake.isFinished()) { - if (toWrite != null) { - write(toWrite); - } - String serverMsg = readNullTerminatedString(deadline); - toWrite = handshake.nextMessage(serverMsg); - } - } catch (IOException e) { - throw new ReqlDriverError("Connection timed out.", e); - } - } - - void write(ByteBuffer buffer) { - try { - buffer.flip(); - writeStream.write(buffer.array()); - } catch (IOException e) { - throw new ReqlDriverError(e); - } - } - - /** - * Tries to read a null-terminated string from the socket. This operation may timeout if a timeout is specified. - * - * @param deadline an optional timeout. - * @return a string. - * @throws IOException - */ - private String readNullTerminatedString(@Nullable Long deadline) throws IOException { - final StringBuilder sb = new StringBuilder(); - char c; - while ((c = (char) this.readStream.readByte()) != '\0') { - // is there a deadline? - if (deadline != null) { - // have we timed-out? - if (deadline < System.currentTimeMillis()) { // reached time-out - throw new ReqlDriverError("Connection timed out."); - } - } - sb.append(c); - } - - return sb.toString(); - } - - /** - * Tries to read a {@link Response} from the socket. This operation is blocking. - * - * @return a {@link Response}. - * @throws IOException - */ - Response read() throws IOException { - final ByteBuffer header = readBytesToBuffer(12); - final long token = header.getLong(); - final int responseLength = header.getInt(); - return Response.parseFrom(token, readBytesToBuffer(responseLength).order(ByteOrder.LITTLE_ENDIAN)); - } - - private ByteBuffer readBytesToBuffer(int bufsize) throws IOException { - byte[] buf = new byte[bufsize]; - int bytesRead = 0; - while (bytesRead < bufsize) { - final int res = this.readStream.read(buf, bytesRead, bufsize - bytesRead); - if (res == -1) { - throw new ReqlDriverError("Reached the end of the read stream."); - } else { - bytesRead += res; - } - } - return ByteBuffer.wrap(buf).order(ByteOrder.LITTLE_ENDIAN); - } - - @Nullable - Integer clientPort() { - if (socket != null) { - return socket.getLocalPort(); - } - return null; - } - - @Nullable - SocketAddress clientAddress() { - if (socket != null) { - return socket.getLocalSocketAddress(); - } - return null; - } - - /** - * Tells whether we have a working connection or not. - * - * @return true if connection is connected and open, false otherwise. - */ - boolean isOpen() { - return socket != null && (socket.isConnected() && !socket.isClosed()); - } - - /** - * Close connection. - */ - void close() { - // if needed, disconnect from server - if (socket != null && isOpen()) - try { - socket.close(); - } catch (IOException e) { - throw new ReqlDriverError(e); - } - } -} From 4892bd47da0e126ed805446249f33c4433fccdd6 Mon Sep 17 00:00:00 2001 From: Adrian Todt Date: Sat, 15 Feb 2020 10:47:08 -0300 Subject: [PATCH 4/7] Minor inlining --- src/main/java/com/rethinkdb/ast/Query.java | 6 +- src/main/java/com/rethinkdb/net/Crypto.java | 12 +-- .../com/rethinkdb/net/HandshakeProtocol.java | 83 +++++++++++++------ src/main/java/com/rethinkdb/net/Response.java | 8 +- src/main/java/com/rethinkdb/net/Util.java | 27 ------ 5 files changed, 73 insertions(+), 63 deletions(-) diff --git a/src/main/java/com/rethinkdb/ast/Query.java b/src/main/java/com/rethinkdb/ast/Query.java index 50225918..fb8be6ef 100644 --- a/src/main/java/com/rethinkdb/ast/Query.java +++ b/src/main/java/com/rethinkdb/ast/Query.java @@ -11,6 +11,7 @@ import java.io.IOException; import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.List; @@ -67,12 +68,13 @@ public ByteBuffer serialize() { } String queryJson = RethinkDB.getInternalMapper().writeValueAsString(queryArr); byte[] queryBytes = queryJson.getBytes(StandardCharsets.UTF_8); - ByteBuffer bb = Util.leByteBuffer(Long.BYTES + Integer.BYTES + queryBytes.length) + ByteBuffer buffer = ByteBuffer.allocate(Long.BYTES + Integer.BYTES + queryBytes.length) + .order(ByteOrder.LITTLE_ENDIAN) .putLong(token) .putInt(queryBytes.length) .put(queryBytes); logger.trace("JSON Send: Token: {} {}", token, queryJson); - return bb; + return buffer; } catch (IOException e) { throw new ReqlRuntimeError(e); } diff --git a/src/main/java/com/rethinkdb/net/Crypto.java b/src/main/java/com/rethinkdb/net/Crypto.java index 3ba65c3a..6a2e4eee 100644 --- a/src/main/java/com/rethinkdb/net/Crypto.java +++ b/src/main/java/com/rethinkdb/net/Crypto.java @@ -11,6 +11,7 @@ import javax.net.ssl.TrustManagerFactory; import java.io.IOException; import java.io.InputStream; +import java.nio.charset.StandardCharsets; import java.security.*; import java.security.cert.CertificateException; import java.security.cert.CertificateFactory; @@ -21,9 +22,6 @@ import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -import static com.rethinkdb.net.Util.fromUTF8; -import static com.rethinkdb.net.Util.toUTF8; - class Crypto { private static final String DEFAULT_SSL_PROTOCOL = "TLSv1.2"; private static final String HMAC_SHA_256 = "HmacSHA256"; @@ -90,7 +88,7 @@ static byte[] hmac(byte[] key, String string) { Mac mac = Mac.getInstance(HMAC_SHA_256); SecretKeySpec secretKey = new SecretKeySpec(key, HMAC_SHA_256); mac.init(secretKey); - return mac.doFinal(toUTF8(string)); + return mac.doFinal(string.getBytes(StandardCharsets.UTF_8)); } catch (InvalidKeyException | NoSuchAlgorithmException e) { throw new ReqlDriverError(e); } @@ -102,7 +100,9 @@ static byte[] pbkdf2(byte[] password, byte[] salt, Integer iterationCount) { return cachedValue; } final PBEKeySpec spec = new PBEKeySpec( - fromUTF8(password).toCharArray(), salt, iterationCount, 256); + new String(password, StandardCharsets.UTF_8).toCharArray(), + salt, iterationCount, 256 + ); final SecretKeyFactory skf; try { skf = SecretKeyFactory.getInstance(PBKDF2_ALGORITHM); @@ -132,7 +132,7 @@ static byte[] xor(byte[] a, byte[] b) { } static String toBase64(byte[] bytes) { - return fromUTF8(encoder.encode(bytes)); + return new String(encoder.encode(bytes), StandardCharsets.UTF_8); } static byte[] fromBase64(String string) { diff --git a/src/main/java/com/rethinkdb/net/HandshakeProtocol.java b/src/main/java/com/rethinkdb/net/HandshakeProtocol.java index b548df46..e2b0915f 100644 --- a/src/main/java/com/rethinkdb/net/HandshakeProtocol.java +++ b/src/main/java/com/rethinkdb/net/HandshakeProtocol.java @@ -7,13 +7,17 @@ import org.jetbrains.annotations.Nullable; import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.charset.StandardCharsets; import java.security.MessageDigest; import java.util.Map; import static com.rethinkdb.net.Crypto.*; import static com.rethinkdb.net.Util.toJSON; -import static com.rethinkdb.net.Util.toUTF8; +/** + * Internal class used by {@link Connection#connect()} to do a proper handshake with the server. + */ abstract class HandshakeProtocol { public static final Version VERSION = Version.V1_0; public static final Long SUB_PROTOCOL_VERSION = 0L; @@ -26,7 +30,8 @@ public static HandshakeProtocol start(String username, String password) { return new InitialState(username, password); } - private HandshakeProtocol() {} + private HandshakeProtocol() { + } public abstract HandshakeProtocol nextState(String response); @@ -53,7 +58,7 @@ public static class InitialState extends HandshakeProtocol { InitialState(String username, String password) { this.username = username; - this.password = toUTF8(password); + this.password = password.getBytes(StandardCharsets.UTF_8); this.nonce = makeNonce(); } @@ -66,18 +71,20 @@ public HandshakeProtocol nextState(String response) { ScramAttributes clientFirstMessageBare = ScramAttributes.create() .username(username) .nonce(nonce); - byte[] jsonBytes = toUTF8( - "{" + - "\"protocol_version\":" + SUB_PROTOCOL_VERSION + "," + - "\"authentication_method\":\"SCRAM-SHA-256\"," + - "\"authentication\":" + "\"n,," + clientFirstMessageBare + "\"" + - "}" - ); - ByteBuffer msg = Util.leByteBuffer( - Integer.BYTES + // size of VERSION - jsonBytes.length + // json auth payload - 1 // terminating null byte - ).putInt(VERSION.value) + byte[] jsonBytes = ("{" + + "\"protocol_version\":" + SUB_PROTOCOL_VERSION + "," + + "\"authentication_method\":\"SCRAM-SHA-256\"," + + "\"authentication\":" + "\"n,," + clientFirstMessageBare + "\"" + + "}").getBytes(StandardCharsets.UTF_8); + // Creating the ByteBuffer over an underlying array makes + // it easier to turn into a string later. + //return ByteBuffer.wrap(new byte[capacity]).order(ByteOrder.LITTLE_ENDIAN); + // size of VERSION + // json auth payload + // terminating null byte + ByteBuffer msg = ByteBuffer.allocate(Integer.BYTES + // size of VERSION + jsonBytes.length + // json auth payload + 1).order(ByteOrder.LITTLE_ENDIAN).putInt(VERSION.value) .put(jsonBytes) .put(new byte[1]); return new WaitingForProtocolRange( @@ -194,8 +201,9 @@ public HandshakeProtocol nextState(String response) { ScramAttributes auth = clientFinalMessageWithoutProof .clientProof(clientProof); - byte[] authJson = toUTF8("{\"authentication\":\"" + auth + "\"}"); - ByteBuffer message = Util.leByteBuffer(authJson.length + 1) + byte[] authJson = ("{\"authentication\":\"" + auth + "\"}").getBytes(StandardCharsets.UTF_8); + + ByteBuffer message = ByteBuffer.allocate(authJson.length + 1).order(ByteOrder.LITTLE_ENDIAN) .put(authJson) .put(new byte[1]); return new WaitingForAuthSuccess(serverSignature, message); @@ -261,7 +269,10 @@ public boolean isFinished() { } } - public static class ScramAttributes { + /** + * Salted Challenge Response Authentication Mechanism (SCRAM) attributes + */ + static class ScramAttributes { @Nullable String _authIdentity; // a @Nullable String _username; // n @Nullable String _nonce; // r @@ -387,22 +398,40 @@ ScramAttributes clientProof(byte[] clientProof) { } // Getters - String authIdentity() { return _authIdentity; } + String authIdentity() { + return _authIdentity; + } - String username() { return _username; } + String username() { + return _username; + } - String nonce() { return _nonce; } + String nonce() { + return _nonce; + } - String headerAndChannelBinding() { return _headerAndChannelBinding; } + String headerAndChannelBinding() { + return _headerAndChannelBinding; + } - byte[] salt() { return _salt; } + byte[] salt() { + return _salt; + } - Integer iterationCount() { return _iterationCount; } + Integer iterationCount() { + return _iterationCount; + } - String clientProof() { return _clientProof; } + String clientProof() { + return _clientProof; + } - byte[] serverSignature() { return _serverSignature; } + byte[] serverSignature() { + return _serverSignature; + } - String error() { return _error; } + String error() { + return _error; + } } } diff --git a/src/main/java/com/rethinkdb/net/Response.java b/src/main/java/com/rethinkdb/net/Response.java index 34812c7c..2dfc2d06 100644 --- a/src/main/java/com/rethinkdb/net/Response.java +++ b/src/main/java/com/rethinkdb/net/Response.java @@ -14,6 +14,7 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -57,7 +58,12 @@ public static Response readFrom(ConnectionSocket socket) { if (Response.logger.isDebugEnabled()) { Response.logger.debug( - "JSON Recv: Token: {} {}", token, Util.bufferToString(buf)); + "JSON Recv: Token: {} {}", token, new String( + buf.array(), + buf.arrayOffset() + buf.position(), + buf.remaining(), + StandardCharsets.UTF_8 + )); } Map jsonResp = Util.toJSON(buf); ResponseType responseType = ResponseType.fromValue(((Long) jsonResp.get("t")).intValue()); diff --git a/src/main/java/com/rethinkdb/net/Util.java b/src/main/java/com/rethinkdb/net/Util.java index f12178b8..73a17572 100644 --- a/src/main/java/com/rethinkdb/net/Util.java +++ b/src/main/java/com/rethinkdb/net/Util.java @@ -13,25 +13,6 @@ public class Util { private Util() {} - public static long deadline(long timeout) { - return System.currentTimeMillis() + timeout; - } - - public static ByteBuffer leByteBuffer(int capacity) { - return ByteBuffer.allocate(capacity).order(ByteOrder.LITTLE_ENDIAN); - // Creating the ByteBuffer over an underlying array makes - // it easier to turn into a string later. - //return ByteBuffer.wrap(new byte[capacity]).order(ByteOrder.LITTLE_ENDIAN); - } - - public static String bufferToString(ByteBuffer buf) { - return new String( - buf.array(), - buf.arrayOffset() + buf.position(), - buf.remaining(), - StandardCharsets.UTF_8 - ); - } @SuppressWarnings("unchecked") public static Map toJSON(String str) { @@ -73,14 +54,6 @@ public static T convertToPojo(Object value, TypeReference typeRef) { } return (T) value; } - - public static byte[] toUTF8(String s) { - return s.getBytes(StandardCharsets.UTF_8); - } - - public static String fromUTF8(byte[] ba) { - return new String(ba, StandardCharsets.UTF_8); - } } From 2efd8f00baf599ff1924c3d926716813d68c118d Mon Sep 17 00:00:00 2001 From: Adrian Todt Date: Sat, 15 Feb 2020 11:03:44 -0300 Subject: [PATCH 5/7] Made all private fields and methods private, so custom implementations are possible. Implements https://github.com/rethinkdb/rethinkdb/issues/5996 --- .../java/com/rethinkdb/net/Connection.java | 40 +++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/src/main/java/com/rethinkdb/net/Connection.java b/src/main/java/com/rethinkdb/net/Connection.java index 86a62bcf..246a4a8d 100644 --- a/src/main/java/com/rethinkdb/net/Connection.java +++ b/src/main/java/com/rethinkdb/net/Connection.java @@ -22,22 +22,22 @@ import java.util.concurrent.locks.ReentrantLock; public class Connection implements Closeable { - private final ConnectionSocket.Factory socketFactory; - private final ResponsePump.Factory pumpFactory; - private final String hostname; - private final int port; - private final @Nullable SSLContext sslContext; - private final @Nullable Long timeout; - private final @Nullable String user; - private final @Nullable String password; - - private final AtomicLong nextToken = new AtomicLong(); - private final Set> tracked = ConcurrentHashMap.newKeySet(); - private final Lock writeLock = new ReentrantLock(); - - private @Nullable String dbname; - private @Nullable ConnectionSocket socket; - private @Nullable ResponsePump pump; + protected final ConnectionSocket.Factory socketFactory; + protected final ResponsePump.Factory pumpFactory; + protected final String hostname; + protected final int port; + protected final @Nullable SSLContext sslContext; + protected final @Nullable Long timeout; + protected final @Nullable String user; + protected final @Nullable String password; + + protected final AtomicLong nextToken = new AtomicLong(); + protected final Set> tracked = ConcurrentHashMap.newKeySet(); + protected final Lock writeLock = new ReentrantLock(); + + protected @Nullable String dbname; + protected @Nullable ConnectionSocket socket; + protected @Nullable ResponsePump pump; public Connection(Builder c) { if (c.authKey != null && c.user != null) { @@ -186,7 +186,7 @@ void keepTrackOf(ResponseHandler r) { * @param query the query to execute. * @return a completable future. */ - private Mono sendQuery(Query query) { + protected Mono sendQuery(Query query) { if (socket == null || !socket.isOpen()) { throw new ReqlDriverError("Client not connected."); } @@ -210,7 +210,7 @@ private Mono sendQuery(Query query) { * * @param query the query to execute. */ - private void runQueryNoreply(Query query) { + protected void runQueryNoreply(Query query) { if (socket == null || !socket.isOpen()) { throw new ReqlDriverError("Client not connected."); } @@ -227,12 +227,12 @@ private void runQueryNoreply(Query query) { } } - private Flux runQuery(Query query, @Nullable TypeReference typeRef) { + protected Flux runQuery(Query query, @Nullable TypeReference typeRef) { return sendQuery(query).onErrorMap(ReqlDriverError::new) .flatMapMany(res -> Flux.create(new ResponseHandler<>(this, query, res, typeRef))); } - private void handleOptArgs(OptArgs optArgs) { + protected void handleOptArgs(OptArgs optArgs) { if (!optArgs.containsKey("db") && dbname != null) { // Only override the db global arg if the user hasn't // specified one already and one is specified on the connection From 5ab589e30fc0500381828b2e61e8a76082c83c78 Mon Sep 17 00:00:00 2001 From: Adrian Todt Date: Sat, 15 Feb 2020 19:58:06 -0300 Subject: [PATCH 6/7] Initial attempt to crack the tests on reactive. --- .../java/com/rethinkdb/net/Connection.java | 17 +- .../java/com/rethinkdb/net/Converter.java | 6 +- src/main/java/com/rethinkdb/net/Response.java | 20 +-- .../com/rethinkdb/net/ResponseHandler.java | 28 +-- src/test/java/com/rethinkdb/AuthTest.java | 4 +- .../java/com/rethinkdb/RethinkDBTest.java | 159 +++++++----------- .../java/com/rethinkdb/TestingCommon.java | 43 +++-- .../java/com/rethinkdb/TestingFramework.java | 42 ++++- 8 files changed, 161 insertions(+), 158 deletions(-) diff --git a/src/main/java/com/rethinkdb/net/Connection.java b/src/main/java/com/rethinkdb/net/Connection.java index 246a4a8d..6427bf05 100644 --- a/src/main/java/com/rethinkdb/net/Connection.java +++ b/src/main/java/com/rethinkdb/net/Connection.java @@ -30,6 +30,7 @@ public class Connection implements Closeable { protected final @Nullable Long timeout; protected final @Nullable String user; protected final @Nullable String password; + protected final boolean unwrapLists; protected final AtomicLong nextToken = new AtomicLong(); protected final Set> tracked = ConcurrentHashMap.newKeySet(); @@ -52,6 +53,7 @@ public Connection(Builder c) { this.timeout = c.timeout; this.user = c.user != null ? c.user : "admin"; this.password = c.password != null ? c.password : c.authKey != null ? c.authKey : ""; + this.unwrapLists = c.unwrapLists; } public @Nullable String db() { @@ -157,22 +159,22 @@ public void close(boolean shouldNoreplyWait) { // package-private methods - void sendStop(long token) { + protected void sendStop(long token) { // While the server does reply to the stop request, we ignore that reply. // This works because the response pump in `connect` ignores replies for which // no waiter exists. runQueryNoreply(Query.stop(token)); } - Mono sendContinue(long token) { + protected Mono sendContinue(long token) { return sendQuery(Query.continue_(token)); } - void loseTrackOf(ResponseHandler r) { + protected void loseTrackOf(ResponseHandler r) { tracked.add(r); } - void keepTrackOf(ResponseHandler r) { + protected void keepTrackOf(ResponseHandler r) { tracked.remove(r); } @@ -261,6 +263,7 @@ public static class Builder { private @Nullable String authKey; private @Nullable String user; private @Nullable String password; + private boolean unwrapLists = false; public Builder copyOf() { Builder c = new Builder(); @@ -274,6 +277,7 @@ public Builder copyOf() { c.authKey = authKey; c.user = user; c.password = password; + c.unwrapLists = unwrapLists; return c; } @@ -323,6 +327,11 @@ public Builder sslContext(SSLContext val) { return this; } + public Builder unwrapLists(boolean val) { + unwrapLists = val; + return this; + } + public Builder timeout(long val) { timeout = val; return this; diff --git a/src/main/java/com/rethinkdb/net/Converter.java b/src/main/java/com/rethinkdb/net/Converter.java index c3f1790a..538fc2cd 100644 --- a/src/main/java/com/rethinkdb/net/Converter.java +++ b/src/main/java/com/rethinkdb/net/Converter.java @@ -44,7 +44,7 @@ public FormatOptions(OptArgs args){ @SuppressWarnings("unchecked") public static Object convertPseudotypes(Object obj, FormatOptions fmt){ if(obj instanceof List) { - return ((List) obj).stream() + return ((List) obj).stream() .map(item -> convertPseudotypes(item, fmt)) .collect(Collectors.toList()); } else if(obj instanceof Map) { @@ -88,9 +88,9 @@ public static Object convertPseudo(Map value, FormatOptions fmt) } @SuppressWarnings("unchecked") - private static List getGrouped(Map value) { + private static List> getGrouped(Map value) { return ((List>) value.get("data")).stream() - .map(g -> new GroupedResult(g.remove(0), g)) + .map(g -> new GroupedResult<>(g.remove(0), g)) .collect(Collectors.toList()); } diff --git a/src/main/java/com/rethinkdb/net/Response.java b/src/main/java/com/rethinkdb/net/Response.java index 2dfc2d06..3dabe17b 100644 --- a/src/main/java/com/rethinkdb/net/Response.java +++ b/src/main/java/com/rethinkdb/net/Response.java @@ -21,7 +21,7 @@ import java.util.Objects; import java.util.stream.Collectors; -class Response { +public class Response { private static final Logger logger = LoggerFactory.getLogger(Query.class); public final long token; @@ -86,11 +86,11 @@ public static Response readFrom(ConnectionSocket socket) { .build(); } - static Builder make(long token, ResponseType type) { + public static Builder make(long token, ResponseType type) { return new Builder(token, type); } - ReqlError makeError(Query query) { + public ReqlError makeError(Query query) { String msg = data.size() > 0 ? (String) data.get(0) : "Unknown error message"; @@ -101,30 +101,30 @@ ReqlError makeError(Query query) { .build(); } - boolean isWaitComplete() { + public boolean isWaitComplete() { return type == ResponseType.WAIT_COMPLETE; } /* Whether the response is any kind of feed */ - boolean isFeed() { + public boolean isFeed() { return notes.stream().anyMatch(ResponseNote::isFeed); } /* Whether the response is any kind of error */ - boolean isError() { + public boolean isError() { return type.isError(); } /* What type of success the response contains */ - boolean isAtom() { + public boolean isAtom() { return type == ResponseType.SUCCESS_ATOM; } - boolean isSequence() { + public boolean isSequence() { return type == ResponseType.SUCCESS_SEQUENCE; } - boolean isPartial() { + public boolean isPartial() { return type == ResponseType.SUCCESS_PARTIAL; } @@ -140,7 +140,7 @@ public String toString() { '}'; } - static class Builder { + public static class Builder { long token; ResponseType responseType; List notes = new ArrayList<>(); diff --git a/src/main/java/com/rethinkdb/net/ResponseHandler.java b/src/main/java/com/rethinkdb/net/ResponseHandler.java index a15b6850..a879bd23 100644 --- a/src/main/java/com/rethinkdb/net/ResponseHandler.java +++ b/src/main/java/com/rethinkdb/net/ResponseHandler.java @@ -15,18 +15,18 @@ import java.util.function.Consumer; public class ResponseHandler implements Consumer> { - private final Connection connection; - private final Query query; - private final Response firstRes; - private final TypeReference typeRef; - private final Converter.FormatOptions fmt; + protected final Connection connection; + protected final Query query; + protected final Response firstRes; + protected final TypeReference typeRef; + protected final Converter.FormatOptions fmt; // This gets used if it's a partial request. - private final Semaphore requesting = new Semaphore(1); - private final Semaphore emitting = new Semaphore(1); - private final AtomicLong requestCount = new AtomicLong(); - private final AtomicReference currentResponse = new AtomicReference<>(); - private final AtomicReference> sink = new AtomicReference<>(); + protected final Semaphore requesting = new Semaphore(1); + protected final Semaphore emitting = new Semaphore(1); + protected final AtomicLong requestCount = new AtomicLong(); + protected final AtomicReference currentResponse = new AtomicReference<>(); + protected final AtomicReference> sink = new AtomicReference<>(); public ResponseHandler(Connection connection, Query query, Response firstRes, TypeReference typeRef) { this.connection = connection; @@ -77,7 +77,7 @@ public void accept(final FluxSink sink) { sink.error(firstRes.makeError(query)); } - private void onRequest(FluxSink sink, long amount) { + protected void onRequest(FluxSink sink, long amount) { final Response lastRes = currentResponse.get(); if (lastRes.isPartial() && requestCount.addAndGet(amount) > 0 && requesting.tryAcquire()) { // great, we should make a CONTINUE request. @@ -117,7 +117,7 @@ private void onRequest(FluxSink sink, long amount) { } } - void onConnectionClosed() throws InterruptedException { + protected void onConnectionClosed() throws InterruptedException { // This will spin wait for a bit until it is not null while (sink.compareAndSet(null, null)) Thread.yield(); FluxSink sink = this.sink.get(); @@ -130,10 +130,10 @@ void onConnectionClosed() throws InterruptedException { } @SuppressWarnings("unchecked") - private int emitData(final FluxSink sink) { + protected int emitData(final FluxSink sink) { List objects = (List) Converter.convertPseudotypes(firstRes.data, fmt); for (Object each : objects) { - if (firstRes.isAtom() && each instanceof List) { + if (connection.unwrapLists && firstRes.isAtom() && each instanceof List) { for (Object o : ((List) each)) { sink.next(Util.convertToPojo(o, typeRef)); } diff --git a/src/test/java/com/rethinkdb/AuthTest.java b/src/test/java/com/rethinkdb/AuthTest.java index bb4044d3..f85ede78 100644 --- a/src/test/java/com/rethinkdb/AuthTest.java +++ b/src/test/java/com/rethinkdb/AuthTest.java @@ -35,14 +35,14 @@ public static void oneTimeTearDown() throws Exception { @Test public void testConnectWithNonAdminUser() throws Exception { - Connection bogusConn = TestingFramework.defaultConnectionBuilder().clone() + Connection bogusConn = TestingFramework.defaultConnectionBuilder().copyOf() .user(bogusUsername, bogusPassword).connect(); bogusConn.close(); } @Test (expected=ReqlDriverError.class) public void testConnectWithBothAuthKeyAndUsername() throws Exception { - Connection bogusConn = TestingFramework.defaultConnectionBuilder().clone() + Connection bogusConn = TestingFramework.defaultConnectionBuilder().copyOf() .user(bogusUsername, bogusPassword).authKey("test").connect(); } } diff --git a/src/test/java/com/rethinkdb/RethinkDBTest.java b/src/test/java/com/rethinkdb/RethinkDBTest.java index 5915a936..64980127 100644 --- a/src/test/java/com/rethinkdb/RethinkDBTest.java +++ b/src/test/java/com/rethinkdb/RethinkDBTest.java @@ -8,12 +8,11 @@ import net.jodah.concurrentunit.Waiter; import org.junit.*; import org.junit.rules.ExpectedException; +import reactor.core.publisher.Flux; import java.lang.reflect.Field; import java.time.OffsetDateTime; -import java.util.Arrays; -import java.util.List; -import java.util.Map; +import java.util.*; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicInteger; @@ -67,62 +66,62 @@ public void tearDown() throws Exception { @Test public void testBooleans() throws Exception { - Boolean t = r.expr(true).run(conn); + Boolean t = r.expr(true).run(conn).cast(Boolean.class).blockFirst(); Assert.assertEquals(true, t.booleanValue()); - Boolean f = r.expr(false).run(conn); + Boolean f = r.expr(false).run(conn).cast(Boolean.class).blockFirst(); Assert.assertEquals(false, f.booleanValue()); - String trueType = r.expr(true).typeOf().run(conn); + String trueType = r.expr(true).typeOf().run(conn).cast(String.class).blockFirst(); Assert.assertEquals("BOOL", trueType); - String falseString = r.expr(false).coerceTo("string").run(conn); + String falseString = r.expr(false).coerceTo("string").run(conn).cast(String.class).blockFirst(); Assert.assertEquals("false", falseString); - Boolean boolCoerce = r.expr(true).coerceTo("bool").run(conn); + Boolean boolCoerce = r.expr(true).coerceTo("bool").run(conn).cast(Boolean.class).blockFirst(); Assert.assertEquals(true, boolCoerce.booleanValue()); } @Test public void testNull() { - Object o = r.expr(null).run(conn); + Object o = r.expr(null).run(conn).blockFirst(); Assert.assertEquals(null, o); - String nullType = r.expr(null).typeOf().run(conn); + String nullType = r.expr(null).typeOf().run(conn).cast(String.class).blockFirst(); Assert.assertEquals("NULL", nullType); - String nullCoerce = r.expr(null).coerceTo("string").run(conn); + String nullCoerce = r.expr(null).coerceTo("string").run(conn).cast(String.class).blockFirst(); Assert.assertEquals("null", nullCoerce); - Object n = r.expr(null).coerceTo("null").run(conn); + Object n = r.expr(null).coerceTo("null").run(conn).blockFirst(); Assert.assertEquals(null, n); } @Test public void testString() { - String str = r.expr("str").run(conn); + String str = r.expr("str").run(conn).cast(String.class).blockFirst(); Assert.assertEquals("str", str); - String unicode = r.expr("こんにちは").run(conn); + String unicode = r.expr("こんにちは").run(conn).cast(String.class).blockFirst(); Assert.assertEquals("こんにちは", unicode); - String strType = r.expr("foo").typeOf().run(conn); + String strType = r.expr("foo").typeOf().run(conn).cast(String.class).blockFirst(); Assert.assertEquals("STRING", strType); - String strCoerce = r.expr("foo").coerceTo("string").run(conn); + String strCoerce = r.expr("foo").coerceTo("string").run(conn).cast(String.class).blockFirst(); Assert.assertEquals("foo", strCoerce); - Number nmb12 = r.expr("-1.2").coerceTo("NUMBER").run(conn); + Number nmb12 = r.expr("-1.2").coerceTo("NUMBER").run(conn).cast(Number.class).blockFirst(); Assert.assertEquals(-1.2, nmb12); - Long nmb10 = r.expr("0xa").coerceTo("NUMBER").run(conn); + Long nmb10 = r.expr("0xa").coerceTo("NUMBER").run(conn).cast(Long.class).blockFirst(); Assert.assertEquals(10L, nmb10.longValue()); } @Test public void testDate() { OffsetDateTime date = OffsetDateTime.now(); - OffsetDateTime result = r.expr(date).run(conn); + OffsetDateTime result = r.expr(date).run(conn).cast(OffsetDateTime.class).blockFirst(); Assert.assertEquals(date, result); } @@ -147,62 +146,67 @@ public void testCoerceFailureInfinity() { r.expr("inf").coerceTo("NUMBER").run(conn); } + @SuppressWarnings("unchecked") @Test public void testSplitEdgeCases() { - List emptySplitNothing = r.expr("").split().run(conn); + List emptySplitNothing = r.expr("").split().run(conn).cast(List.class).blockFirst(); Assert.assertEquals(emptySplitNothing, Arrays.asList()); - List nullSplit = r.expr("").split(null).run(conn); + List nullSplit = r.expr("").split(null).run(conn).cast(List.class).blockFirst(); Assert.assertEquals(nullSplit, Arrays.asList()); - List emptySplitSpace = r.expr("").split(" ").run(conn); + List emptySplitSpace = r.expr("").split(" ").run(conn).cast(List.class).blockFirst(); Assert.assertEquals(Arrays.asList(""), emptySplitSpace); - List emptySplitEmpty = r.expr("").split("").run(conn); + List emptySplitEmpty = r.expr("").split("").run(conn).cast(List.class).blockFirst(); assertEquals(Arrays.asList(), emptySplitEmpty); - List emptySplitNull5 = r.expr("").split(null, 5).run(conn); + List emptySplitNull5 = r.expr("").split(null, 5).run(conn).cast(List.class).blockFirst(); assertEquals(Arrays.asList(), emptySplitNull5); - List emptySplitSpace5 = r.expr("").split(" ", 5).run(conn); + List emptySplitSpace5 = r.expr("").split(" ", 5).run(conn).cast(List.class).blockFirst(); assertEquals(Arrays.asList(""), emptySplitSpace5); - List emptySplitEmpty5 = r.expr("").split("", 5).run(conn); + List emptySplitEmpty5 = r.expr("").split("", 5).run(conn).cast(List.class).blockFirst(); assertEquals(Arrays.asList(), emptySplitEmpty5); } + @SuppressWarnings("unchecked") @Test public void testSplitWithNullOrWhitespace() { - List extraWhitespace = r.expr("aaaa bbbb cccc ").split().run(conn); + List extraWhitespace = r.expr("aaaa bbbb cccc ").split().run(conn).cast(List.class).blockFirst(); assertEquals(Arrays.asList("aaaa", "bbbb", "cccc"), extraWhitespace); - List extraWhitespaceNull = r.expr("aaaa bbbb cccc ").split(null).run(conn); + List extraWhitespaceNull = r.expr("aaaa bbbb cccc ").split(null).run(conn).cast(List.class).blockFirst(); assertEquals(Arrays.asList("aaaa", "bbbb", "cccc"), extraWhitespaceNull); - List extraWhitespaceSpace = r.expr("aaaa bbbb cccc ").split(" ").run(conn); + List extraWhitespaceSpace = r.expr("aaaa bbbb cccc ").split(" ").run(conn).cast(List.class).blockFirst(); assertEquals(Arrays.asList("aaaa", "bbbb", "", "cccc", ""), extraWhitespaceSpace); - List extraWhitespaceEmpty = r.expr("aaaa bbbb cccc ").split("").run(conn); + List extraWhitespaceEmpty = r.expr("aaaa bbbb cccc ").split("").run(conn).cast(List.class).blockFirst(); assertEquals(Arrays.asList("a", "a", "a", "a", " ", "b", "b", "b", "b", " ", " ", "c", "c", "c", "c", " "), extraWhitespaceEmpty); } + @SuppressWarnings("unchecked") @Test public void testSplitWithString() { - List b = r.expr("aaaa bbbb cccc ").split("b").run(conn); + List b = r.expr("aaaa bbbb cccc ").split("b").run(conn).cast(List.class).blockFirst(); assertEquals(Arrays.asList("aaaa ", "", "", "", " cccc "), b); } + @SuppressWarnings("unchecked") @Test public void testTableInsert(){ MapObject foo = new MapObject() .with("hi", "There") .with("yes", 7) .with("no", null ); - Map result = r.db(dbName).table(tableName).insert(foo).run(conn); + Map result = r.db(dbName).table(tableName).insert(foo).run(conn).cast(Map.class).blockFirst(); assertEquals(1L, result.get("inserted")); } + @SuppressWarnings("unchecked") @Test public void testDbGlobalArgInserted() { final String tblName = "test_global_optargs"; @@ -221,22 +225,22 @@ public void testDbGlobalArgInserted() { try { // no optarg set, no default db conn.use(null); - Map x1 = r.table(tblName).get(1).run(conn); + Map x1 = r.table(tblName).get(1).run(conn).cast(Map.class).blockFirst(); assertEquals("test", x1.get("dbName")); // no optarg set, default db set conn.use("conn_default"); - Map x2 = r.table(tblName).get(1).run(conn); + Map x2 = r.table(tblName).get(1).run(conn).cast(Map.class).blockFirst(); assertEquals("conn_default", x2.get("dbName")); // optarg set, no default db conn.use(null); - Map x3 = r.table(tblName).get(1).run(conn, OptArgs.of("db", "optargs")); + Map x3 = r.table(tblName).get(1).run(conn, OptArgs.of("db", "optargs")).cast(Map.class).blockFirst(); assertEquals("optargs", x3.get("dbName")); // optarg set, default db conn.use("conn_default"); - Map x4 = r.table(tblName).get(1).run(conn, OptArgs.of("db", "optargs")); + Map x4 = r.table(tblName).get(1).run(conn, OptArgs.of("db", "optargs")).cast(Map.class).blockFirst(); assertEquals("optargs", x4.get("dbName")); } finally { @@ -251,62 +255,42 @@ public void testFilter() { r.db(dbName).table(tableName).insert(new MapObject().with("field", "123")).run(conn); r.db(dbName).table(tableName).insert(new MapObject().with("field", "456")).run(conn); - Cursor> allEntries = r.db(dbName).table(tableName).run(conn); - assertEquals(2, allEntries.toList().size()); + Flux allEntries = r.db(dbName).table(tableName).run(conn); + assertEquals(2, allEntries.count().block().longValue()); // The following won't work, because r.row is not implemented in the Java driver. Use lambda syntax instead // Cursor> oneEntryRow = r.db(dbName).table(tableName).filter(r.row("field").eq("456")).run(conn); // assertEquals(1, oneEntryRow.toList().size()); - Cursor> oneEntryLambda = r.db(dbName).table(tableName).filter(table -> table.getField("field").eq("456")).run(conn); - assertEquals(1, oneEntryLambda.toList().size()); - } - - @Test - public void testCursorTryWithResources() { - r.db(dbName).table(tableName).insert(new MapObject().with("field", "123")).run(conn); - r.db(dbName).table(tableName).insert(new MapObject().with("field", "456")).run(conn); - - try(Cursor> allEntries = r.db(dbName).table(tableName).run(conn)) { - assertEquals(2, allEntries.toList().size()); - } + Flux oneEntryLambda = r.db(dbName).table(tableName).filter(table -> table.getField("field").eq("456")).run(conn); + assertEquals(1, oneEntryLambda.count().block().intValue()); } @Test public void testTableSelectOfPojo() { TestPojo pojo = new TestPojo("foo", new TestPojoInner(42L, true)); - Map pojoResult = r.db(dbName).table(tableName).insert(pojo).run(conn); + Map pojoResult = r.db(dbName).table(tableName).insert(pojo).run(conn).cast(Map.class).blockFirst(); assertEquals(1L, pojoResult.get("inserted")); String key = (String) ((List) pojoResult.get("generated_keys")).get(0); - TestPojo result = r.db(dbName).table(tableName).get(key).run(conn, TestPojo.class); + TestPojo result = r.db(dbName).table(tableName).get(key).run(conn, TestPojo.class).blockFirst(); assertEquals("foo", result.getStringProperty()); assertTrue(42L == result.getPojoProperty().getLongProperty()); assertEquals(true, result.getPojoProperty().getBooleanProperty()); } - @Test(expected = ClassCastException.class) - public void testTableSelectOfPojo_withNoPojoClass_throwsException() { - TestPojo pojo = new TestPojo("foo", new TestPojoInner(42L, true)); - Map pojoResult = r.db(dbName).table(tableName).insert(pojo).run(conn); - assertEquals(1L, pojoResult.get("inserted")); - - String key = (String) ((List) pojoResult.get("generated_keys")).get(0); - TestPojo result = r.db(dbName).table(tableName).get(key).run(conn /* TestPojo.class is not specified */); - } - @Test public void testTableSelectOfPojoCursor() { TestPojo pojoOne = new TestPojo("foo", new TestPojoInner(42L, true)); TestPojo pojoTwo = new TestPojo("bar", new TestPojoInner(53L, false)); - Map pojoOneResult = r.db(dbName).table(tableName).insert(pojoOne).run(conn); - Map pojoTwoResult = r.db(dbName).table(tableName).insert(pojoTwo).run(conn); + Map pojoOneResult = r.db(dbName).table(tableName).insert(pojoOne).run(conn).cast(Map.class).blockFirst(); + Map pojoTwoResult = r.db(dbName).table(tableName).insert(pojoTwo).run(conn).cast(Map.class).blockFirst(); assertEquals(1L, pojoOneResult.get("inserted")); assertEquals(1L, pojoTwoResult.get("inserted")); - Cursor cursor = r.db(dbName).table(tableName).run(conn, TestPojo.class); - List result = cursor.toList(); + Flux cursor = r.db(dbName).table(tableName).run(conn, TestPojo.class); + List result = Objects.requireNonNull(cursor.collectList().block()); assertEquals(2L, result.size()); TestPojo pojoOneSelected = "foo".equals(result.get(0).getStringProperty()) ? result.get(0) : result.get(1); @@ -321,21 +305,6 @@ public void testTableSelectOfPojoCursor() { assertEquals(false, pojoTwoSelected.getPojoProperty().getBooleanProperty()); } - @Test(expected = ClassCastException.class) - public void testTableSelectOfPojoCursor_withNoPojoClass_throwsException() { - TestPojo pojoOne = new TestPojo("foo", new TestPojoInner(42L, true)); - TestPojo pojoTwo = new TestPojo("bar", new TestPojoInner(53L, false)); - Map pojoOneResult = r.db(dbName).table(tableName).insert(pojoOne).run(conn); - Map pojoTwoResult = r.db(dbName).table(tableName).insert(pojoTwo).run(conn); - assertEquals(1L, pojoOneResult.get("inserted")); - assertEquals(1L, pojoTwoResult.get("inserted")); - - Cursor cursor = r.db(dbName).table(tableName).run(conn /* TestPojo.class is not specified */); - List result = cursor.toList(); - - TestPojo pojoSelected = result.get(0); - } - @Test(timeout=40000) public void testConcurrentWrites() throws TimeoutException, InterruptedException { final int total = 500; @@ -344,7 +313,7 @@ public void testConcurrentWrites() throws TimeoutException, InterruptedException for (int i = 0; i < total; i++) new Thread(() -> { final TestPojo pojo = new TestPojo("writezz", new TestPojoInner(10L, true)); - final Map result = r.db(dbName).table(tableName).insert(pojo).run(conn); + final Map result = r.db(dbName).table(tableName).insert(pojo).run(conn).cast(Map.class).blockFirst(); waiter.assertEquals(1L, result.get("inserted")); writeCounter.getAndIncrement(); waiter.resume(); @@ -362,17 +331,17 @@ public void testConcurrentReads() throws TimeoutException { // write to the database and retrieve the id final TestPojo pojo = new TestPojo("readzz", new TestPojoInner(10L, true)); - final Map result = r.db(dbName).table(tableName).insert(pojo).optArg("return_changes", true).run(conn); + final Map result = r.db(dbName).table(tableName).insert(pojo).optArg("return_changes", true).run(conn).cast(Map.class).blockFirst(); final String id = ((List) result.get("generated_keys")).get(0).toString(); final Waiter waiter = new Waiter(); for (int i = 0; i < total; i++) new Thread(() -> { // make sure there's only one - final Cursor cursor = r.db(dbName).table(tableName).run(conn, TestPojo.class); - assertEquals(1, cursor.toList().size()); + final Flux cursor = r.db(dbName).table(tableName).run(conn, TestPojo.class); + assertEquals(1, cursor.count().block().intValue()); // read that one - final TestPojo readPojo = r.db(dbName).table(tableName).get(id).run(conn, TestPojo.class); + final TestPojo readPojo = r.db(dbName).table(tableName).get(id).run(conn, TestPojo.class).blockFirst(); waiter.assertNotNull(readPojo); // assert inserted values waiter.assertEquals("readzz", readPojo.getStringProperty()); @@ -394,15 +363,15 @@ public void testConcurrentCursor() throws TimeoutException, InterruptedException for (int i = 0; i < total; i++) new Thread(() -> { final TestPojo pojo = new TestPojo("writezz", new TestPojoInner(10L, true)); - final Map result = r.db(dbName).table(tableName).insert(pojo).run(conn); + final Map result = r.db(dbName).table(tableName).insert(pojo).run(conn).cast(Map.class).blockFirst(); waiter.assertEquals(1L, result.get("inserted")); waiter.resume(); }).start(); waiter.await(2500, total); - final Cursor all = r.db(dbName).table(tableName).run(conn); - assertEquals(total, all.toList().size()); + final Flux all = r.db(dbName).table(tableName).run(conn, TestPojo.class); + assertEquals(total, all.count().block().intValue()); } @Test @@ -412,19 +381,13 @@ public void testNoreply() throws Exception { @Test public void test_Changefeeds_Cursor_Close_cause_new_cursor_cause_memory_leak() throws Exception { - Field f_cursorCache = Connection.class.getDeclaredField("cursorCache"); + Field f_cursorCache = Connection.class.getDeclaredField("tracked"); f_cursorCache.setAccessible(true); - Map cursorCache = (Map) f_cursorCache.get(conn); + Set cursorCache = (Set) f_cursorCache.get(conn); assertEquals(0, cursorCache.size()); - Cursor c = r.db(dbName).table(tableName).changes().run(conn); - - try { - c.next(1000); - } catch (TimeoutException ex) { - } - c.close(); + Flux f = r.db(dbName).table(tableName).changes().run(conn); assertEquals(0, cursorCache.size()); } diff --git a/src/test/java/com/rethinkdb/TestingCommon.java b/src/test/java/com/rethinkdb/TestingCommon.java index d6d9ffbc..69e222ec 100644 --- a/src/test/java/com/rethinkdb/TestingCommon.java +++ b/src/test/java/com/rethinkdb/TestingCommon.java @@ -340,22 +340,24 @@ public static ErrRegex err_regex(String classname, String message_rgx, Object du } public static ArrayList fetch(Object cursor_, long limit) throws Exception { - if(limit < 0) { - limit = Long.MAX_VALUE; - } - Cursor cursor = (Cursor) cursor_; - long total = 0; - ArrayList result = new ArrayList((int) limit); - for(long i = 0; i < limit; i++) { - if(!cursor.hasNext()){ - break; - } - result.add(cursor.next(500)); - } - return result; - } - - public static ArrayList fetch(Cursor cursor) throws Exception { + System.out.println(cursor_); + return new ArrayList(); +// if(limit < 0) { +// limit = Long.MAX_VALUE; +// } +// Cursor cursor = (Cursor) cursor_; +// long total = 0; +// ArrayList result = new ArrayList((int) limit); +// for(long i = 0; i < limit; i++) { +// if(!cursor.hasNext()){ +// break; +// } +// result.add(cursor.next(500)); +// } +// return result; + } + + public static ArrayList fetch(Object cursor) throws Exception { return fetch(cursor, -1); } @@ -367,14 +369,7 @@ public static Object runOrCatch(Object query, OptArgs runopts, Connection conn) return query; } try { - Object res = ((ReqlAst)query).run(conn, runopts); - if(res instanceof Cursor) { - ArrayList ret = new ArrayList(); - ((Cursor) res).forEachRemaining(ret::add); - return ret; - }else{ - return res; - } + return ((TestingFramework.TestingConnection) conn).internalRun(((ReqlAst)query), runopts).block(); } catch (Exception e) { return e; } diff --git a/src/test/java/com/rethinkdb/TestingFramework.java b/src/test/java/com/rethinkdb/TestingFramework.java index 220dda9f..e650b8e7 100644 --- a/src/test/java/com/rethinkdb/TestingFramework.java +++ b/src/test/java/com/rethinkdb/TestingFramework.java @@ -1,8 +1,20 @@ package com.rethinkdb; -import com.rethinkdb.net.Connection; +import com.fasterxml.jackson.core.type.TypeReference; +import com.rethinkdb.ast.Query; +import com.rethinkdb.ast.ReqlAst; +import com.rethinkdb.gen.exc.ReqlDriverError; +import com.rethinkdb.model.OptArgs; +import com.rethinkdb.net.*; +import reactor.core.publisher.Flux; +import reactor.core.publisher.FluxSink; +import reactor.core.publisher.Mono; -import java.io.*; +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.util.List; import java.util.Properties; /** @@ -83,7 +95,31 @@ public static Connection createConnection() throws Exception { * i.e. connection secured with SSL. */ public static Connection createConnection(Connection.Builder builder) throws Exception { - return builder.connect(); + return new TestingConnection(builder).connect(); } + /** + * This injects a method that complies with what the test framework is awaiting. + */ + public static class TestingConnection extends Connection { + public TestingConnection(Builder c) { + super(c); + } + + public Mono internalRun(ReqlAst term, OptArgs optArgs) { + handleOptArgs(optArgs); + Query q = Query.start(nextToken.incrementAndGet(), term, optArgs); + if (optArgs.containsKey("noreply")) { + throw new ReqlDriverError("Don't provide the noreply option as an optarg. Use `.runNoReply` instead of `.run`"); + } + return sendQuery(q).onErrorMap(ReqlDriverError::new).flatMap(res -> { + Flux flux = Flux.create(new ResponseHandler<>(this, q, res, null)); + if (res.isAtom()) { + return flux.next(); + } else { + return flux.collectList(); + } + }); + } + } } From 995bf4865498ac4b274919d757aa57908609548c Mon Sep 17 00:00:00 2001 From: Adrian Todt Date: Sat, 15 Feb 2020 21:50:50 -0300 Subject: [PATCH 7/7] Fixed Connection, ResponsePump and Handshake issues --- src/main/java/com/rethinkdb/net/Connection.java | 7 +++++-- .../java/com/rethinkdb/net/DefaultConnectionFactory.java | 2 +- src/main/java/com/rethinkdb/net/HandshakeProtocol.java | 2 +- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/main/java/com/rethinkdb/net/Connection.java b/src/main/java/com/rethinkdb/net/Connection.java index 6427bf05..b779d558 100644 --- a/src/main/java/com/rethinkdb/net/Connection.java +++ b/src/main/java/com/rethinkdb/net/Connection.java @@ -132,7 +132,9 @@ public void close() { public void close(boolean shouldNoreplyWait) { // disconnect try { - noreplyWait(); + if (shouldNoreplyWait) { + noreplyWait(); + } } finally { // reset token nextToken.set(0); @@ -231,7 +233,8 @@ protected void runQueryNoreply(Query query) { protected Flux runQuery(Query query, @Nullable TypeReference typeRef) { return sendQuery(query).onErrorMap(ReqlDriverError::new) - .flatMapMany(res -> Flux.create(new ResponseHandler<>(this, query, res, typeRef))); + .flatMapMany(res -> Flux.create(new ResponseHandler<>(this, query, res, typeRef))) + .cache(); } protected void handleOptArgs(OptArgs optArgs) { diff --git a/src/main/java/com/rethinkdb/net/DefaultConnectionFactory.java b/src/main/java/com/rethinkdb/net/DefaultConnectionFactory.java index d2bbb896..bdce06a4 100644 --- a/src/main/java/com/rethinkdb/net/DefaultConnectionFactory.java +++ b/src/main/java/com/rethinkdb/net/DefaultConnectionFactory.java @@ -221,7 +221,7 @@ public ThreadResponsePump(ConnectionSocket socket) { } } }, "RethinkDB-" + socket + "-ResponsePump"); - + thread.start(); } @Override diff --git a/src/main/java/com/rethinkdb/net/HandshakeProtocol.java b/src/main/java/com/rethinkdb/net/HandshakeProtocol.java index e2b0915f..b8491b65 100644 --- a/src/main/java/com/rethinkdb/net/HandshakeProtocol.java +++ b/src/main/java/com/rethinkdb/net/HandshakeProtocol.java @@ -27,7 +27,7 @@ abstract class HandshakeProtocol { public static final String SERVER_KEY = "Server Key"; public static HandshakeProtocol start(String username, String password) { - return new InitialState(username, password); + return new InitialState(username, password).nextState(null); } private HandshakeProtocol() {