diff --git a/d2-test-api/src/main/java/com/linkedin/d2/balancer/clients/TestClient.java b/d2-test-api/src/main/java/com/linkedin/d2/balancer/clients/TestClient.java index bdcdb66f7a..5de8a01a16 100644 --- a/d2-test-api/src/main/java/com/linkedin/d2/balancer/clients/TestClient.java +++ b/d2-test-api/src/main/java/com/linkedin/d2/balancer/clients/TestClient.java @@ -102,6 +102,24 @@ public void restRequest(RestRequest request, callback.onResponse(TransportResponseImpl.success(response)); } + @Override + public void restRequestStreamResponse(RestRequest request, RequestContext requestContext, + Map wireAttrs, TransportCallback callback) { + restRequest = request; + restRequestContext = requestContext; + restWireAttrs = wireAttrs; + streamCallback = callback; + StreamResponseBuilder builder = new StreamResponseBuilder(); + StreamResponse response = _emptyResponse ? builder.build(EntityStreams.emptyStream()) + : builder.build(EntityStreams.newEntityStream(new ByteStringWriter(ByteString.copy("This is not empty".getBytes())))); + if (_deferCallback) + { + scheduleTimeout(requestContext, callback); + return; + } + callback.onResponse(TransportResponseImpl.success(response, wireAttrs)); + } + @Override public void streamRequest(StreamRequest request, RequestContext requestContext, diff --git a/d2/src/main/java/com/linkedin/d2/balancer/D2ClientDelegator.java b/d2/src/main/java/com/linkedin/d2/balancer/D2ClientDelegator.java index 19e4706beb..5211e8e208 100644 --- a/d2/src/main/java/com/linkedin/d2/balancer/D2ClientDelegator.java +++ b/d2/src/main/java/com/linkedin/d2/balancer/D2ClientDelegator.java @@ -93,6 +93,17 @@ public void streamRequest(StreamRequest request, RequestContext requestContext, _d2Client.streamRequest(request, requestContext, callback); } + @Override + public void restRequestStreamResponse(RestRequest request, Callback callback) { + _d2Client.restRequestStreamResponse(request, callback); + } + + @Override + public void restRequestStreamResponse(RestRequest request, RequestContext requestContext, + Callback callback) { + _d2Client.restRequestStreamResponse(request, requestContext, callback); + } + @Override public void getMetadata(URI uri, Callback> callback) { diff --git a/d2/src/main/java/com/linkedin/d2/balancer/clients/BackupRequestsClient.java b/d2/src/main/java/com/linkedin/d2/balancer/clients/BackupRequestsClient.java index 655e9a6fcb..105f67c789 100644 --- a/d2/src/main/java/com/linkedin/d2/balancer/clients/BackupRequestsClient.java +++ b/d2/src/main/java/com/linkedin/d2/balancer/clients/BackupRequestsClient.java @@ -390,6 +390,24 @@ public void streamRequest(StreamRequest request, RequestContext requestContext, decorateCallbackSync(request, requestContext, _d2Client::streamRequest, callback)); } + @Override + public void restRequestStreamResponse(RestRequest request, Callback callback) { + restRequestStreamResponse(request, new RequestContext(), callback); + } + + @Override + public void restRequestStreamResponse(RestRequest request, RequestContext requestContext, + Callback callback) { + if (_isD2Async) + { + requestAsync(request, requestContext, _d2Client::restRequestStreamResponse, callback); + return; + } + + _d2Client.restRequestStreamResponse(request, requestContext, + decorateCallbackSync(request, requestContext, _d2Client::restRequestStreamResponse, callback)); + } + private Callback decorateCallbackSync(R request, RequestContext requestContext, DecoratorClient client, Callback callback) { diff --git a/d2/src/main/java/com/linkedin/d2/balancer/clients/DynamicClient.java b/d2/src/main/java/com/linkedin/d2/balancer/clients/DynamicClient.java index 8b7864808a..1b91dba1b9 100644 --- a/d2/src/main/java/com/linkedin/d2/balancer/clients/DynamicClient.java +++ b/d2/src/main/java/com/linkedin/d2/balancer/clients/DynamicClient.java @@ -107,7 +107,16 @@ public void streamRequest(StreamRequest request, _balancer.getClient(request, requestContext, getClientCallback(request, requestContext, true, callback, client -> client.streamRequest(request, requestContext, loggerCallback)) ); + } + + @Override + public void restRequestStreamResponse(RestRequest request, RequestContext requestContext, + Callback callback) { + Callback loggerCallback = decorateLoggingCallback(callback, request, "stream"); + _balancer.getClient(request, requestContext, + getClientCallback(request, requestContext, true, callback, client -> client.restRequestStreamResponse(request, requestContext, loggerCallback)) + ); } private Callback getClientCallback(Request request, RequestContext requestContext, final boolean restOverStream, Callback callback, SuccessCallback clientSuccessCallback) diff --git a/d2/src/main/java/com/linkedin/d2/balancer/clients/LazyClient.java b/d2/src/main/java/com/linkedin/d2/balancer/clients/LazyClient.java index 6e502ff826..36fbac2080 100644 --- a/d2/src/main/java/com/linkedin/d2/balancer/clients/LazyClient.java +++ b/d2/src/main/java/com/linkedin/d2/balancer/clients/LazyClient.java @@ -60,6 +60,12 @@ public void restRequest(RestRequest request, getWrappedClient().restRequest(request, requestContext, wireAttrs, callback); } + @Override + public void restRequestStreamResponse(RestRequest request, RequestContext requestContext, + Map wireAttrs, TransportCallback callback) { + getWrappedClient().restRequestStreamResponse(request, requestContext, wireAttrs, callback); + } + @Override public void streamRequest(StreamRequest request, RequestContext requestContext, diff --git a/d2/src/main/java/com/linkedin/d2/balancer/clients/RequestTimeoutClient.java b/d2/src/main/java/com/linkedin/d2/balancer/clients/RequestTimeoutClient.java index f63071796b..bca6a62a53 100644 --- a/d2/src/main/java/com/linkedin/d2/balancer/clients/RequestTimeoutClient.java +++ b/d2/src/main/java/com/linkedin/d2/balancer/clients/RequestTimeoutClient.java @@ -120,6 +120,20 @@ public void streamRequest(StreamRequest request, RequestContext requestContext, _d2Client.streamRequest(request, requestContext, transportCallback); } + @Override + public void restRequestStreamResponse(RestRequest request, Callback callback) { + restRequestStreamResponse(request, new RequestContext(), callback); + } + + @Override + public void restRequestStreamResponse(RestRequest request, RequestContext requestContext, + Callback callback) { + final Callback transportCallback = + decorateCallbackWithRequestTimeout(callback, request, requestContext); + + _d2Client.restRequestStreamResponse(request, requestContext, transportCallback); + } + /** * Enforces the user timeout to the layer below if necessary. * diff --git a/d2/src/main/java/com/linkedin/d2/balancer/clients/RetryClient.java b/d2/src/main/java/com/linkedin/d2/balancer/clients/RetryClient.java index f464af53a7..241f50431b 100644 --- a/d2/src/main/java/com/linkedin/d2/balancer/clients/RetryClient.java +++ b/d2/src/main/java/com/linkedin/d2/balancer/clients/RetryClient.java @@ -186,6 +186,29 @@ public void streamRequest(StreamRequest request, RequestContext requestContext, } } + @Override + public void restRequestStreamResponse(RestRequest request, Callback callback) { + restRequestStreamResponse(request, new RequestContext(), callback); + } + + @Override + public void restRequestStreamResponse(RestRequest request, RequestContext requestContext, + Callback callback) { + if (_restRetryEnabled) + { + RestRequest newRequest = request.builder() + .setHeader(HttpConstants.HEADER_NUMBER_OF_RETRY_ATTEMPTS, "0") + .build(); + ClientRetryTracker retryTracker = updateRetryTracker(newRequest.getURI(), false); + final Callback transportCallback = new ResponseOnlyStreamRetryRequestCallback(newRequest, requestContext, callback, retryTracker); + _d2Client.restRequestStreamResponse(newRequest, requestContext, transportCallback); + } + else + { + _d2Client.restRequestStreamResponse(request, requestContext, callback); + } + } + private ClientRetryTracker updateRetryTracker(URI uri, boolean isRetry) { String serviceName = LoadBalancerUtil.getServiceNameFromUri(uri); @@ -276,6 +299,28 @@ public boolean doRetryRequest(RestRequest request, RequestContext context, int n } } + /** + * Callback implementation for Retry {@link RestRequest} and {@link StreamResponse} + */ + private class ResponseOnlyStreamRetryRequestCallback extends RetryRequestCallback + { + public ResponseOnlyStreamRetryRequestCallback(RestRequest request, RequestContext context, Callback callback, ClientRetryTracker retryTracker) + { + super(request, context, callback, retryTracker); + } + + @Override + public boolean doRetryRequest(RestRequest request, RequestContext context, int numberOfRetryAttempts) + { + RestRequest newRequest = request.builder() + .setHeader(HttpConstants.HEADER_NUMBER_OF_RETRY_ATTEMPTS, Integer.toString(numberOfRetryAttempts)) + .build(); + updateRetryTracker(request.getURI(), true); + _d2Client.restRequestStreamResponse(newRequest, context, this); + return true; + } + } + /** * Abstract callback implementation of retry requests. * diff --git a/d2/src/main/java/com/linkedin/d2/balancer/clients/RewriteClient.java b/d2/src/main/java/com/linkedin/d2/balancer/clients/RewriteClient.java index 81787b6a71..baf055a226 100644 --- a/d2/src/main/java/com/linkedin/d2/balancer/clients/RewriteClient.java +++ b/d2/src/main/java/com/linkedin/d2/balancer/clients/RewriteClient.java @@ -60,6 +60,14 @@ public void restRequest(RestRequest request, RequestContext requestContext, Map< _transportClient.restRequest(rewriteRequest(request), requestContext, wireAttrs, callback); } + @Override + public void restRequestStreamResponse(RestRequest request, + RequestContext requestContext, + Map wireAttrs, + TransportCallback callback) { + _transportClient.restRequestStreamResponse(rewriteRequest(request), requestContext, wireAttrs, callback); + } + /** * Asynchronously issues the given request. The given callback is invoked when the response is * received. diff --git a/d2/src/main/java/com/linkedin/d2/balancer/clients/RewriteLoadBalancerClient.java b/d2/src/main/java/com/linkedin/d2/balancer/clients/RewriteLoadBalancerClient.java index dde59d37ff..a41e2e6d7c 100644 --- a/d2/src/main/java/com/linkedin/d2/balancer/clients/RewriteLoadBalancerClient.java +++ b/d2/src/main/java/com/linkedin/d2/balancer/clients/RewriteLoadBalancerClient.java @@ -74,6 +74,13 @@ public void streamRequest(StreamRequest request, _client.streamRequest(request, requestContext, wireAttrs, callback); } + @Override + public void restRequestStreamResponse(RestRequest request, RequestContext requestContext, + Map wireAttrs, TransportCallback callback) { + assert _serviceName.equals(LoadBalancerUtil.getServiceNameFromUri(request.getURI())); + _client.restRequestStreamResponse(request, requestContext, wireAttrs, callback); + } + @Override public void shutdown(Callback callback) { diff --git a/d2/src/main/java/com/linkedin/d2/balancer/clients/TrackerClientImpl.java b/d2/src/main/java/com/linkedin/d2/balancer/clients/TrackerClientImpl.java index f8f2d5bfe4..9b2b33f48b 100644 --- a/d2/src/main/java/com/linkedin/d2/balancer/clients/TrackerClientImpl.java +++ b/d2/src/main/java/com/linkedin/d2/balancer/clients/TrackerClientImpl.java @@ -132,6 +132,12 @@ public void restRequest(RestRequest request, _transportClient.restRequest(request, requestContext, wireAttrs, new TrackerClientRestCallback(callback, _callTracker.startCall())); } + @Override + public void restRequestStreamResponse(RestRequest request, RequestContext requestContext, + Map wireAttrs, TransportCallback callback) { + _transportClient.restRequestStreamResponse(request, requestContext, wireAttrs, new TrackerClientStreamCallback(callback, _callTracker.startCall())); + } + @Override public void streamRequest(StreamRequest request, RequestContext requestContext, diff --git a/d2/src/main/java/com/linkedin/d2/balancer/simple/ClusterAwareTransportClient.java b/d2/src/main/java/com/linkedin/d2/balancer/simple/ClusterAwareTransportClient.java index 7d6b837fa3..37966eef6a 100644 --- a/d2/src/main/java/com/linkedin/d2/balancer/simple/ClusterAwareTransportClient.java +++ b/d2/src/main/java/com/linkedin/d2/balancer/simple/ClusterAwareTransportClient.java @@ -72,6 +72,13 @@ public void restRequest(RestRequest request, getWrappedClient().restRequest(request, requestContext, wireAttrs, callback); } + @Override + public void restRequestStreamResponse(RestRequest request, RequestContext requestContext, + Map wireAttrs, TransportCallback callback) { + updateRequestContext(requestContext); + getWrappedClient().restRequestStreamResponse(request, requestContext, wireAttrs, callback); + } + @Override public void streamRequest(StreamRequest request, RequestContext requestContext, diff --git a/d2/src/test/java/com/linkedin/d2/balancer/clients/LazyClientTest.java b/d2/src/test/java/com/linkedin/d2/balancer/clients/LazyClientTest.java index 99b31ce3d0..b5d23f6907 100644 --- a/d2/src/test/java/com/linkedin/d2/balancer/clients/LazyClientTest.java +++ b/d2/src/test/java/com/linkedin/d2/balancer/clients/LazyClientTest.java @@ -134,6 +134,12 @@ public void restRequest(RestRequest request, ++requestCount; } + @Override + public void restRequestStreamResponse(RestRequest request, RequestContext requestContext, + Map wireAttrs, TransportCallback callback) { + ++requestCount; + } + @Override public void shutdown(Callback callback) { diff --git a/d2/src/test/java/com/linkedin/d2/balancer/clients/RetryTrackerClient.java b/d2/src/test/java/com/linkedin/d2/balancer/clients/RetryTrackerClient.java index 4fad531cde..83aee52928 100644 --- a/d2/src/test/java/com/linkedin/d2/balancer/clients/RetryTrackerClient.java +++ b/d2/src/test/java/com/linkedin/d2/balancer/clients/RetryTrackerClient.java @@ -74,6 +74,14 @@ public void streamRequest(StreamRequest request, () -> new StreamResponseBuilder().build(EntityStreams.emptyStream())); } + @Override + public void restRequestStreamResponse(RestRequest request, RequestContext requestContext, + Map wireAttrs, TransportCallback callback) { + handleRequest(request, wireAttrs, callback, + r -> {}, + () -> new StreamResponseBuilder().build(EntityStreams.emptyStream())); + } + @Override public URI getUri() { diff --git a/d2/src/test/java/com/linkedin/d2/balancer/clients/TrackerClientTest.java b/d2/src/test/java/com/linkedin/d2/balancer/clients/TrackerClientTest.java index c110f2befd..f3d8bcbe6f 100644 --- a/d2/src/test/java/com/linkedin/d2/balancer/clients/TrackerClientTest.java +++ b/d2/src/test/java/com/linkedin/d2/balancer/clients/TrackerClientTest.java @@ -293,7 +293,110 @@ public void shutdown(Callback callback) {} Assert.assertEquals(stats.getErrorCountTotal(), 3); Assert.assertEquals(degraderControl.getCurrentComputedDropRate(), 0.2, 0.001); } + @Test + public void testCallTrackingRestRequestStreamResponse() throws Exception + { + URI uri = URI.create("http://test.qa.com:1234/foo"); + SettableClock clock = new SettableClock(); + AtomicInteger action = new AtomicInteger(0); + TransportClient tc = new TransportClient() { + @Override + public void restRequest(RestRequest request, RequestContext requestContext, Map wireAttrs, TransportCallback callback) { + } + + @Override + public void restRequestStreamResponse(RestRequest request, + RequestContext requestContext, + Map wireAttrs, + TransportCallback callback) { + clock.addDuration(5); + switch (action.get()) + { + // success + case 0: callback.onResponse(TransportResponseImpl.success(new StreamResponseBuilder().build(EntityStreams.emptyStream()))); + break; + // fail with stream exception + case 1: callback.onResponse(TransportResponseImpl.error( + new StreamException(new StreamResponseBuilder().setStatus(500).build(EntityStreams.emptyStream())))); + break; + // fail with timeout exception + case 2: callback.onResponse(TransportResponseImpl.error(new RemoteInvocationException(new TimeoutException()))); + break; + // fail with other exception + default: callback.onResponse(TransportResponseImpl.error(new RuntimeException())); + break; + } + } + @Override + public void shutdown(Callback callback) {} + }; + + DegraderTrackerClientImpl client = (DegraderTrackerClientImpl) createTrackerClient(tc, clock, uri); + CallTracker callTracker = client.getCallTracker(); + CallTracker.CallStats stats; + DegraderControl degraderControl = client.getDegraderControl(DefaultPartitionAccessor.DEFAULT_PARTITION_ID); + DelayConsumeCallback delayConsumeCallback = new DelayConsumeCallback(); + client.restRequestStreamResponse(new RestRequestBuilder(uri).build(), new RequestContext(), new HashMap<>(), delayConsumeCallback); + clock.addDuration(5); + // we only recorded the time when stream response arrives, but callcompletion.endcall hasn't been called yet. + Assert.assertEquals(callTracker.getCurrentCallCountTotal(), 0); + Assert.assertEquals(callTracker.getCurrentErrorCountTotal(), 0); + + // delay + clock.addDuration(100); + delayConsumeCallback.consume(); + clock.addDuration(5000); + // now that we consumed the entity stream, callcompletion.endcall has been called. + stats = callTracker.getCallStats(); + Assert.assertEquals(stats.getCallCount(), 1); + Assert.assertEquals(stats.getErrorCount(), 0); + Assert.assertEquals(stats.getCallCountTotal(), 1); + Assert.assertEquals(stats.getErrorCountTotal(), 0); + Assert.assertEquals(degraderControl.getCurrentComputedDropRate(), 0.0, 0.001); + + action.set(1); + client.restRequestStreamResponse(new RestRequestBuilder(uri).build(), new RequestContext(), new HashMap<>(), delayConsumeCallback); + clock.addDuration(5); + // we endcall with error immediately for stream exception, even before the entity is consumed + Assert.assertEquals(callTracker.getCurrentCallCountTotal(), 2); + Assert.assertEquals(callTracker.getCurrentErrorCountTotal(), 1); + delayConsumeCallback.consume(); + clock.addDuration(5000); + // no change in tracking after entity is consumed + stats = callTracker.getCallStats(); + Assert.assertEquals(stats.getCallCount(), 1); + Assert.assertEquals(stats.getErrorCount(), 1); + Assert.assertEquals(stats.getCallCountTotal(), 2); + Assert.assertEquals(stats.getErrorCountTotal(), 1); + Assert.assertEquals(degraderControl.getCurrentComputedDropRate(), 0.2, 0.001); + + action.set(2); + client.restRequestStreamResponse(new RestRequestBuilder(uri).build(), new RequestContext(), new HashMap<>(), new TestTransportCallback<>()); + clock.addDuration(5); + Assert.assertEquals(callTracker.getCurrentCallCountTotal(), 3); + Assert.assertEquals(callTracker.getCurrentErrorCountTotal(), 2); + clock.addDuration(5000); + stats = callTracker.getCallStats(); + Assert.assertEquals(stats.getCallCount(), 1); + Assert.assertEquals(stats.getErrorCount(), 1); + Assert.assertEquals(stats.getCallCountTotal(), 3); + Assert.assertEquals(stats.getErrorCountTotal(), 2); + Assert.assertEquals(degraderControl.getCurrentComputedDropRate(), 0.4, 0.001); + + action.set(3); + client.restRequestStreamResponse(new RestRequestBuilder(uri).build(), new RequestContext(), new HashMap<>(), new TestTransportCallback<>()); + clock.addDuration(5); + Assert.assertEquals(callTracker.getCurrentCallCountTotal(), 4); + Assert.assertEquals(callTracker.getCurrentErrorCountTotal(), 3); + clock.addDuration(5000); + stats = callTracker.getCallStats(); + Assert.assertEquals(stats.getCallCount(), 1); + Assert.assertEquals(stats.getErrorCount(), 1); + Assert.assertEquals(stats.getCallCountTotal(), 4); + Assert.assertEquals(stats.getErrorCountTotal(), 3); + Assert.assertEquals(degraderControl.getCurrentComputedDropRate(), 0.2, 0.001); + } @Test public void testDoNotSlowStartWhenTrue() { @@ -406,17 +509,15 @@ public void restRequest(RestRequest request, callback.onResponse(TransportResponseImpl.success(response)); } - @Override - public void streamRequest(StreamRequest request, - RequestContext requestContext, - Map wireAttrs, - TransportCallback callback) + public void restRequestStreamResponse(RestRequest request, + RequestContext requestContext, + Map wireAttrs, + TransportCallback callback) { - streamRequest = request; + restRequest = request; restRequestContext = requestContext; restWireAttrs = wireAttrs; streamCallback = callback; - StreamResponseBuilder builder = new StreamResponseBuilder(); StreamResponse response = _emptyResponse ? builder.build(EntityStreams.emptyStream()) : builder.build(EntityStreams.newEntityStream(new ByteStringWriter(ByteString.copy("This is not empty".getBytes())))); @@ -428,6 +529,26 @@ public void streamRequest(StreamRequest request, callback.onResponse(TransportResponseImpl.success(response, wireAttrs)); } + @Override + public void streamRequest(StreamRequest request, + RequestContext requestContext, + Map wireAttrs, + TransportCallback callback) { + streamRequest = request; + restRequestContext = requestContext; + restWireAttrs = wireAttrs; + streamCallback = callback; + + StreamResponseBuilder builder = new StreamResponseBuilder(); + StreamResponse response = _emptyResponse ? builder.build(EntityStreams.emptyStream()) : builder.build( + EntityStreams.newEntityStream(new ByteStringWriter(ByteString.copy("This is not empty".getBytes())))); + if (_dontCallCallback) { + scheduleTimeout(requestContext, callback); + return; + } + callback.onResponse(TransportResponseImpl.success(response, wireAttrs)); + } + private void scheduleTimeout(RequestContext requestContext, TransportCallback callback) { Integer requestTimeout = (Integer) requestContext.getLocalAttr(R2Constants.REQUEST_TIMEOUT); diff --git a/d2/src/test/java/com/linkedin/d2/balancer/simple/LoadBalancerSimulator.java b/d2/src/test/java/com/linkedin/d2/balancer/simple/LoadBalancerSimulator.java index 9f0fe6f2bb..0ec604dda3 100644 --- a/d2/src/test/java/com/linkedin/d2/balancer/simple/LoadBalancerSimulator.java +++ b/d2/src/test/java/com/linkedin/d2/balancer/simple/LoadBalancerSimulator.java @@ -467,6 +467,12 @@ public void streamRequest(StreamRequest request, throw new IllegalArgumentException("StreamRequest is not supported yet"); } + @Override + public void restRequestStreamResponse(RestRequest request, RequestContext requestContext, + Map wireAttrs, TransportCallback callback) { + throw new IllegalArgumentException("RestRequestStreamResponse is not supported yet"); + } + @Override public void restRequest(RestRequest request, RequestContext requestContext, diff --git a/d2/src/test/java/com/linkedin/d2/balancer/simple/SimpleLoadBalancerTest.java b/d2/src/test/java/com/linkedin/d2/balancer/simple/SimpleLoadBalancerTest.java index 1be5daf6b9..186d2a283d 100644 --- a/d2/src/test/java/com/linkedin/d2/balancer/simple/SimpleLoadBalancerTest.java +++ b/d2/src/test/java/com/linkedin/d2/balancer/simple/SimpleLoadBalancerTest.java @@ -1647,6 +1647,14 @@ public void restRequest(RestRequest request, { } + @Override + public void restRequestStreamResponse(RestRequest request, + RequestContext requestContext, + Map wireAttrs, + TransportCallback callback) + { + } + @Override public void shutdown(Callback callback) { diff --git a/d2/src/test/java/com/linkedin/d2/balancer/strategies/degrader/DegraderLoadBalancerTest.java b/d2/src/test/java/com/linkedin/d2/balancer/strategies/degrader/DegraderLoadBalancerTest.java index 4299238fb1..00a44903fd 100644 --- a/d2/src/test/java/com/linkedin/d2/balancer/strategies/degrader/DegraderLoadBalancerTest.java +++ b/d2/src/test/java/com/linkedin/d2/balancer/strategies/degrader/DegraderLoadBalancerTest.java @@ -3783,6 +3783,12 @@ public void streamRequest(StreamRequest request, captureValues(request, requestContext, wireAttrs, callback); } + @Override + public void restRequestStreamResponse(RestRequest request, RequestContext requestContext, + Map wireAttrs, TransportCallback callback) { + captureValues(request, requestContext, wireAttrs, callback); + } + @Override public void restRequest(RestRequest request, RequestContext requestContext, diff --git a/r2-core/src/main/java/com/linkedin/r2/filter/transport/FilterChainClient.java b/r2-core/src/main/java/com/linkedin/r2/filter/transport/FilterChainClient.java index 11e4e15840..dbc97e3b31 100644 --- a/r2-core/src/main/java/com/linkedin/r2/filter/transport/FilterChainClient.java +++ b/r2-core/src/main/java/com/linkedin/r2/filter/transport/FilterChainClient.java @@ -79,6 +79,14 @@ public void restRequest(RestRequest request, _filters.onRestRequest(request, requestContext, wireAttrs); } + @Override + public void restRequestStreamResponse(RestRequest request, RequestContext requestContext, + Map wireAttrs, TransportCallback callback) { + ResponseFilter.registerCallback(createWrappedClientTimingCallback(requestContext, callback), requestContext); + markOnRequestTimings(requestContext); + _filters.onRestRequest(request, requestContext, wireAttrs); + } + @Override public void streamRequest(StreamRequest request, RequestContext requestContext, diff --git a/r2-core/src/main/java/com/linkedin/r2/transport/common/AbstractClient.java b/r2-core/src/main/java/com/linkedin/r2/transport/common/AbstractClient.java index db6c930351..a40c8301b0 100644 --- a/r2-core/src/main/java/com/linkedin/r2/transport/common/AbstractClient.java +++ b/r2-core/src/main/java/com/linkedin/r2/transport/common/AbstractClient.java @@ -81,6 +81,11 @@ public void streamRequest(StreamRequest request, Callback callba streamRequest(request, new RequestContext(), callback); } + @Override + public void restRequestStreamResponse(RestRequest request, Callback callback) { + restRequestStreamResponse(request, new RequestContext(), callback); + } + @Override public void restRequest(RestRequest request, RequestContext requestContext, Callback callback) { diff --git a/r2-core/src/main/java/com/linkedin/r2/transport/common/Client.java b/r2-core/src/main/java/com/linkedin/r2/transport/common/Client.java index 7ff463f82f..19e05540c5 100644 --- a/r2-core/src/main/java/com/linkedin/r2/transport/common/Client.java +++ b/r2-core/src/main/java/com/linkedin/r2/transport/common/Client.java @@ -85,7 +85,7 @@ void restRequest(RestRequest request, RequestContext requestContext, * Asynchronously issues the given request. The given callback is invoked when the response is * received. * - * Any implementation that wants to support streaming MUST override this method. + * Any implementation that wants to support bidirectional streaming MUST override this method. * * @param request the request to issue * @param callback the callback to invoke with the response @@ -99,7 +99,7 @@ default void streamRequest(StreamRequest request, Callback callb * Asynchronously issues the given request. The given callback is invoked when the response is * received. * - * Any implementation that wants to support streaming MUST override this method. + * Any implementation that wants to support bidirectional streaming MUST override this method. * * @param request the request to issue * @param requestContext context for the request @@ -110,6 +110,35 @@ default void streamRequest(StreamRequest request, RequestContext requestContext, throw new UnsupportedOperationException("Please use an implementation that supports streaming."); } + /** + * Asynchronously issues the given request. The given callback is invoked when the response is + * received. + * + * Any implementation that wants to support response-only streaming MUST override this method. + * + * @param request the request to issue + * @param callback the callback to invoke with the response + */ + default void restRequestStreamResponse(RestRequest request, Callback callback) + { + throw new UnsupportedOperationException("Please use an implementation that supports response-only streaming."); + } + + /** + * Asynchronously issues the given request. The given callback is invoked when the response is + * received. + * + * Any implementation that wants to support response-only streaming MUST override this method. + * + * @param request the request to issue + * @param requestContext context for the request + * @param callback the callback to invoke with the response + */ + default void restRequestStreamResponse(RestRequest request, RequestContext requestContext, Callback callback) + { + throw new UnsupportedOperationException("Please use an implementation that supports response-only streaming."); + } + /** * Initiates asynchronous shutdown of the client. This method should block minimally, if at all. * diff --git a/r2-core/src/main/java/com/linkedin/r2/transport/common/ClientDelegator.java b/r2-core/src/main/java/com/linkedin/r2/transport/common/ClientDelegator.java index 4b8e94c658..65e61bd372 100644 --- a/r2-core/src/main/java/com/linkedin/r2/transport/common/ClientDelegator.java +++ b/r2-core/src/main/java/com/linkedin/r2/transport/common/ClientDelegator.java @@ -76,6 +76,17 @@ public void streamRequest(StreamRequest request, RequestContext requestContext, _client.streamRequest(request, requestContext, callback); } + @Override + public void restRequestStreamResponse(RestRequest request, Callback callback) { + _client.restRequestStreamResponse(request, callback); + } + + @Override + public void restRequestStreamResponse(RestRequest request, RequestContext requestContext, + Callback callback) { + _client.restRequestStreamResponse(request, requestContext, callback); + } + @Override public void shutdown(Callback callback) { diff --git a/r2-core/src/main/java/com/linkedin/r2/transport/common/bridge/client/TransportClient.java b/r2-core/src/main/java/com/linkedin/r2/transport/common/bridge/client/TransportClient.java index 82249836a5..1928682640 100644 --- a/r2-core/src/main/java/com/linkedin/r2/transport/common/bridge/client/TransportClient.java +++ b/r2-core/src/main/java/com/linkedin/r2/transport/common/bridge/client/TransportClient.java @@ -54,7 +54,25 @@ void restRequest(RestRequest request, * Asynchronously issues the given request. The given callback is invoked when the response is * received. * - * Any implementation that wants to support streaming MUST override this method. + * This method allows the response-only streaming, compared to streamRequest(). + * + * @param request the request to issue + * @param requestContext context for the request + * @param wireAttrs attributes that should be sent over the wire to the server + * @param callback the callback to invoke with the response + */ + default void restRequestStreamResponse(RestRequest request, + RequestContext requestContext, + Map wireAttrs, + TransportCallback callback) { + throw new UnsupportedOperationException("Please use an implementation that supports response-only streaming."); + } + + /** + * Asynchronously issues the given request. The given callback is invoked when the response is + * received. + * + * Any implementation that wants to support bidirectional streaming MUST override this method. * * @param request the request to issue * @param requestContext context for the request diff --git a/r2-core/src/main/java/com/linkedin/r2/transport/common/bridge/client/TransportClientAdapter.java b/r2-core/src/main/java/com/linkedin/r2/transport/common/bridge/client/TransportClientAdapter.java index 571b1cc6e0..5d8b46dcb8 100644 --- a/r2-core/src/main/java/com/linkedin/r2/transport/common/bridge/client/TransportClientAdapter.java +++ b/r2-core/src/main/java/com/linkedin/r2/transport/common/bridge/client/TransportClientAdapter.java @@ -87,6 +87,12 @@ public void restRequest(RestRequest request, RequestContext requestContext, Call } } + @Override + public void restRequestStreamResponse(RestRequest request, RequestContext requestContext, Callback callback) { + final Map wireAttrs = new HashMap(); + _client.restRequestStreamResponse(request, requestContext, wireAttrs, new TransportCallbackAdapter(callback)); + } + @Override public void shutdown(Callback callback) { diff --git a/r2-netty/src/main/java/com/linkedin/r2/netty/client/HttpNettyClient.java b/r2-netty/src/main/java/com/linkedin/r2/netty/client/HttpNettyClient.java index d458a2f1a1..ef039c3ab8 100644 --- a/r2-netty/src/main/java/com/linkedin/r2/netty/client/HttpNettyClient.java +++ b/r2-netty/src/main/java/com/linkedin/r2/netty/client/HttpNettyClient.java @@ -177,6 +177,14 @@ public void restRequest(RestRequest request, sendRequest(request, requestContext, wireAttrs, Messages.toStreamTransportCallback(callback)); } + @Override + public void restRequestStreamResponse(RestRequest request, + RequestContext requestContext, + Map wireAttrs, + TransportCallback callback) { + sendRequest(request, requestContext, wireAttrs, callback); + } + @Override public void streamRequest(StreamRequest request, RequestContext requestContext, diff --git a/r2-netty/src/main/java/com/linkedin/r2/transport/http/client/HttpClientFactory.java b/r2-netty/src/main/java/com/linkedin/r2/transport/http/client/HttpClientFactory.java index 2be59979f3..3d52be6cec 100644 --- a/r2-netty/src/main/java/com/linkedin/r2/transport/http/client/HttpClientFactory.java +++ b/r2-netty/src/main/java/com/linkedin/r2/transport/http/client/HttpClientFactory.java @@ -49,7 +49,9 @@ import com.linkedin.r2.transport.http.client.common.ConnectionSharingChannelPoolManagerFactory; import com.linkedin.r2.transport.http.client.common.EventAwareChannelPoolManagerFactory; import com.linkedin.r2.transport.http.client.rest.HttpNettyClient; +import com.linkedin.r2.transport.http.client.stream.http.HttpNettyResponseOnlyStreamClient; import com.linkedin.r2.transport.http.client.stream.http.HttpNettyStreamClient; +import com.linkedin.r2.transport.http.client.stream.http2.Http2NettyResponseOnlyStreamClient; import com.linkedin.r2.transport.http.client.stream.http2.Http2NettyStreamClient; import com.linkedin.r2.transport.http.common.HttpProtocolVersion; import com.linkedin.r2.util.ConfigValueExtractor; @@ -1412,16 +1414,23 @@ TransportClient getRawClient(Map properties, } TransportClient streamClient; + TransportClient responseOnlyStreamClient; switch (httpProtocolVersion) { case HTTP_1_1: streamClient = new HttpNettyStreamClient(_eventLoopGroup, _executor, requestTimeout, shutdownTimeout, _callbackExecutorGroup, _jmxManager, _channelPoolManagerFactory.buildStream(key), _channelPoolManagerFactory.buildStream(sslKey)); + responseOnlyStreamClient = new HttpNettyResponseOnlyStreamClient(_eventLoopGroup, _executor, requestTimeout, shutdownTimeout, + _callbackExecutorGroup, _jmxManager, _channelPoolManagerFactory.buildStream(key), + _channelPoolManagerFactory.buildStream(sslKey)); break; case HTTP_2: streamClient = new Http2NettyStreamClient(_eventLoopGroup, _executor, requestTimeout, shutdownTimeout, _callbackExecutorGroup, _jmxManager, _channelPoolManagerFactory.buildHttp2Stream(key), _channelPoolManagerFactory.buildHttp2Stream(sslKey)); + responseOnlyStreamClient = new Http2NettyResponseOnlyStreamClient(_eventLoopGroup, _executor, requestTimeout, shutdownTimeout, + _callbackExecutorGroup, _jmxManager, _channelPoolManagerFactory.buildHttp2Stream(key), + _channelPoolManagerFactory.buildHttp2Stream(sslKey)); break; default: throw new IllegalArgumentException("Unrecognized HTTP protocol version " + httpProtocolVersion); @@ -1431,7 +1440,7 @@ TransportClient getRawClient(Map properties, new HttpNettyClient(_eventLoopGroup, _executor, requestTimeout, shutdownTimeout, _callbackExecutorGroup, _jmxManager, _channelPoolManagerFactory.buildRest(key), _channelPoolManagerFactory.buildRest(sslKey)); - return new MixedClient(legacyClient, streamClient); + return new MixedClient(legacyClient, streamClient, responseOnlyStreamClient); } /** @@ -1627,6 +1636,14 @@ public void restRequest(RestRequest request, _client.restRequest(request, requestContext, wireAttrs, callback); } + @Override + public void restRequestStreamResponse(RestRequest request, + RequestContext requestContext, + Map wireAttrs, + TransportCallback callback) { + _client.restRequestStreamResponse(request, requestContext, wireAttrs, callback); + } + @Override public void streamRequest(StreamRequest request, RequestContext requestContext, Map wireAttrs, @@ -1687,11 +1704,13 @@ static class MixedClient implements TransportClient { private final TransportClient _legacyClient; private final TransportClient _streamClient; + private final TransportClient _responseOnlyStreamClient; - MixedClient(TransportClient legacyClient, TransportClient streamClient) + MixedClient(TransportClient legacyClient, TransportClient streamClient, TransportClient responseOnlyStreamClient) { _legacyClient = legacyClient; _streamClient = streamClient; + _responseOnlyStreamClient = responseOnlyStreamClient; } @Override @@ -1703,6 +1722,12 @@ public void restRequest(RestRequest request, _legacyClient.restRequest(request, requestContext, wireAttrs, callback); } + @Override + public void restRequestStreamResponse(RestRequest request, RequestContext requestContext, + Map wireAttrs, TransportCallback callback) { + _responseOnlyStreamClient.restRequestStreamResponse(request, requestContext, wireAttrs, callback); + } + @Override public void streamRequest(StreamRequest request, RequestContext requestContext, @@ -1715,9 +1740,10 @@ public void streamRequest(StreamRequest request, @Override public void shutdown(final Callback callback) { - Callback multiCallback = new MultiCallback(callback, 2); + Callback multiCallback = new MultiCallback(callback, 3); _legacyClient.shutdown(multiCallback); _streamClient.shutdown(multiCallback); + _responseOnlyStreamClient.shutdown(multiCallback); } } } diff --git a/r2-netty/src/main/java/com/linkedin/r2/transport/http/client/common/AbstractNettyClient.java b/r2-netty/src/main/java/com/linkedin/r2/transport/http/client/common/AbstractNettyClient.java index d08dbc1e8a..11ef1e9008 100644 --- a/r2-netty/src/main/java/com/linkedin/r2/transport/http/client/common/AbstractNettyClient.java +++ b/r2-netty/src/main/java/com/linkedin/r2/transport/http/client/common/AbstractNettyClient.java @@ -159,6 +159,14 @@ public void restRequest(RestRequest request, RequestContext requestContext, Map< writeRequest((Req) request, requestContext, wireAttrs, (TransportCallback) HttpBridge.restToHttpCallback(callback, request)); } + @Override + @SuppressWarnings("unchecked") + public void restRequestStreamResponse(RestRequest request, RequestContext requestContext, + Map wireAttrs, TransportCallback callback) { + MessageType.setMessageType(MessageType.Type.REST, wireAttrs); + writeRequest((Req) request, requestContext, wireAttrs, (TransportCallback) HttpBridge.streamToHttpCallback(callback, request)); + } + @Override @SuppressWarnings("unchecked") public void streamRequest(StreamRequest request, RequestContext requestContext, Map wireAttrs, diff --git a/r2-netty/src/main/java/com/linkedin/r2/transport/http/client/rest/HttpNettyClient.java b/r2-netty/src/main/java/com/linkedin/r2/transport/http/client/rest/HttpNettyClient.java index 7edb0b2277..016312a31a 100644 --- a/r2-netty/src/main/java/com/linkedin/r2/transport/http/client/rest/HttpNettyClient.java +++ b/r2-netty/src/main/java/com/linkedin/r2/transport/http/client/rest/HttpNettyClient.java @@ -99,6 +99,12 @@ public void streamRequest(StreamRequest request, RequestContext requestContext, throw new UnsupportedOperationException("Stream is not supported."); } + @Override + public void restRequestStreamResponse(RestRequest request, RequestContext requestContext, + Map wireAttrs, TransportCallback callback) { + throw new UnsupportedOperationException("Stream is not supported."); + } + @Override protected TransportCallback getExecutionCallback(TransportCallback callback) { diff --git a/r2-netty/src/main/java/com/linkedin/r2/transport/http/client/stream/AbstractNettyStreamClient.java b/r2-netty/src/main/java/com/linkedin/r2/transport/http/client/stream/AbstractNettyStreamClient.java index d44b7dd369..64933181a1 100644 --- a/r2-netty/src/main/java/com/linkedin/r2/transport/http/client/stream/AbstractNettyStreamClient.java +++ b/r2-netty/src/main/java/com/linkedin/r2/transport/http/client/stream/AbstractNettyStreamClient.java @@ -97,6 +97,12 @@ public void restRequest(RestRequest request, RequestContext requestContext, Map< throw new UnsupportedOperationException("Rest is not supported."); } + @Override + public void restRequestStreamResponse(RestRequest request, RequestContext requestContext, + Map wireAttrs, TransportCallback callback) { + throw new UnsupportedOperationException("Response-only Stream is not supported."); + } + @Override protected TransportCallback getExecutionCallback(TransportCallback callback) { diff --git a/restli-client/src/test/java/com/linkedin/restli/client/MockClient.java b/restli-client/src/test/java/com/linkedin/restli/client/MockClient.java index d7ab45f2b8..992f54711d 100644 --- a/restli-client/src/test/java/com/linkedin/restli/client/MockClient.java +++ b/restli-client/src/test/java/com/linkedin/restli/client/MockClient.java @@ -24,6 +24,7 @@ import com.linkedin.common.util.None; import com.linkedin.r2.message.RequestContext; import com.linkedin.r2.message.Messages; +import com.linkedin.r2.message.rest.RestRequest; import com.linkedin.r2.message.rest.RestResponse; import com.linkedin.r2.message.rest.RestResponseBuilder; import com.linkedin.r2.message.stream.StreamRequest; @@ -70,6 +71,20 @@ public void streamRequest(StreamRequest request, RequestContext requestContext, adapter.onResponse(TransportResponseImpl.success(Messages.toStreamResponse(response))); } + @Override + public void restRequestStreamResponse(RestRequest request, RequestContext requestContext, + Callback callback) { + TransportCallback adapter = HttpBridge.streamToHttpCallback(new TransportCallbackAdapter(callback), request); + + RestResponse response = new RestResponseBuilder() + .setStatus(status()) + .setHeaders(headers()) + .setEntity(body()) + .build(); + + adapter.onResponse(TransportResponseImpl.success(Messages.toStreamResponse(response))); + } + @Override public void shutdown(Callback callback) {