Skip to content

Commit

Permalink
feat: client sends retry cookie back to server
Browse files Browse the repository at this point in the history
  • Loading branch information
mutianf committed Aug 16, 2023
1 parent decfe98 commit bfafcd3
Show file tree
Hide file tree
Showing 7 changed files with 353 additions and 4 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Copyright 2023 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.cloud.bigtable.data.v2.stub;

import static com.google.cloud.bigtable.data.v2.stub.CookiesHolder.COOKIES_HOLDER_KEY;
import static com.google.cloud.bigtable.data.v2.stub.CookiesHolder.ROUTING_COOKIE_KEY;
import static com.google.cloud.bigtable.data.v2.stub.CookiesHolder.ROUTING_COOKIE_METADATA_KEY;

import com.google.protobuf.ByteString;
import com.google.rpc.ErrorInfo;
import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ClientCall;
import io.grpc.ClientInterceptor;
import io.grpc.ForwardingClientCall;
import io.grpc.ForwardingClientCallListener;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.Status;
import io.grpc.protobuf.ProtoUtils;

/**
* A cookie interceptor that checks the cookie value from returned ErrorInfo, updates the cookie
* holder, and inject it in the header of the next request.
*/
class CookieInterceptor implements ClientInterceptor {

static final Metadata.Key<ErrorInfo> ERROR_INFO_KEY =
ProtoUtils.keyForProto(ErrorInfo.getDefaultInstance());

@Override
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
MethodDescriptor<ReqT, RespT> methodDescriptor, CallOptions callOptions, Channel channel) {
return new ForwardingClientCall.SimpleForwardingClientCall<ReqT, RespT>(
channel.newCall(methodDescriptor, callOptions)) {
@Override
public void start(Listener<RespT> responseListener, Metadata headers) {
CookiesHolder cookie = callOptions.getOption(COOKIES_HOLDER_KEY);
if (cookie != null && cookie.getRoutingCookie() != null) {
headers.put(ROUTING_COOKIE_METADATA_KEY, cookie.getRoutingCookie().toByteArray());
}
super.start(new UpdateCookieListener<>(responseListener, callOptions), headers);
}
};
}

static class UpdateCookieListener<RespT>
extends ForwardingClientCallListener.SimpleForwardingClientCallListener<RespT> {

private final CallOptions callOptions;

UpdateCookieListener(ClientCall.Listener<RespT> delegate, CallOptions callOptions) {
super(delegate);
this.callOptions = callOptions;
}

@Override
public void onClose(Status status, Metadata trailers) {
if (status != Status.OK && trailers != null) {
ErrorInfo errorInfo = trailers.get(ERROR_INFO_KEY);
if (errorInfo != null) {
CookiesHolder cookieHolder = callOptions.getOption(COOKIES_HOLDER_KEY);
cookieHolder.setRoutingCookie(
ByteString.copyFromUtf8(errorInfo.getMetadataMap().get(ROUTING_COOKIE_KEY)));
}
}
super.onClose(status, trailers);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Copyright 2023 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.cloud.bigtable.data.v2.stub;

import com.google.protobuf.ByteString;
import io.grpc.CallOptions;
import io.grpc.Metadata;
import javax.annotation.Nullable;

/** A cookie that holds information for the retry */
class CookiesHolder {

static final CallOptions.Key<CookiesHolder> COOKIES_HOLDER_KEY =
CallOptions.Key.create("bigtable-cookies");

static final String ROUTING_COOKIE_KEY = "bigtable-routing-cookie";

static final Metadata.Key<byte[]> ROUTING_COOKIE_METADATA_KEY =
Metadata.Key.of("bigtable-routing-cookie-bin", Metadata.BINARY_BYTE_MARSHALLER);

@Nullable private ByteString routingCookie;

void setRoutingCookie(@Nullable ByteString routingCookie) {
this.routingCookie = routingCookie;
}

ByteString getRoutingCookie() {
return this.routingCookie;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Copyright 2023 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.cloud.bigtable.data.v2.stub;

import static com.google.cloud.bigtable.data.v2.stub.CookiesHolder.COOKIES_HOLDER_KEY;

import com.google.api.gax.grpc.GrpcCallContext;
import com.google.api.gax.rpc.ApiCallContext;
import com.google.api.gax.rpc.ResponseObserver;
import com.google.api.gax.rpc.ServerStreamingCallable;

public class CookiesServerStreamingCallable<RequestT, ResponseT>
extends ServerStreamingCallable<RequestT, ResponseT> {

private final ServerStreamingCallable<RequestT, ResponseT> callable;

CookiesServerStreamingCallable(ServerStreamingCallable<RequestT, ResponseT> innerCallable) {
this.callable = innerCallable;
}

@Override
public void call(
RequestT request, ResponseObserver<ResponseT> responseObserver, ApiCallContext context) {
GrpcCallContext grpcCallContext = (GrpcCallContext) context;
callable.call(
request,
responseObserver,
grpcCallContext.withCallOptions(
grpcCallContext.getCallOptions().withOption(COOKIES_HOLDER_KEY, new CookiesHolder())));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Copyright 2023 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.cloud.bigtable.data.v2.stub;

import static com.google.cloud.bigtable.data.v2.stub.CookiesHolder.COOKIES_HOLDER_KEY;

import com.google.api.core.ApiFuture;
import com.google.api.gax.grpc.GrpcCallContext;
import com.google.api.gax.rpc.ApiCallContext;
import com.google.api.gax.rpc.UnaryCallable;

/** Cookie callable injects a placeholder for bigtable retry cookie. */
class CookiesUnaryCallable<RequestT, ResponseT> extends UnaryCallable<RequestT, ResponseT> {
private UnaryCallable<RequestT, ResponseT> innerCallable;

CookiesUnaryCallable(UnaryCallable<RequestT, ResponseT> callable) {
this.innerCallable = callable;
}

@Override
public ApiFuture<ResponseT> futureCall(RequestT request, ApiCallContext context) {
GrpcCallContext grpcCallContext = (GrpcCallContext) context;
return innerCallable.futureCall(
request,
grpcCallContext.withCallOptions(
grpcCallContext.getCallOptions().withOption(COOKIES_HOLDER_KEY, new CookiesHolder())));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,9 @@ public <RowT> ServerStreamingCallable<Query, RowT> createReadRowsCallable(
new TracedServerStreamingCallable<>(
readRowsUserCallable, clientContext.getTracerFactory(), span);

return traced.withDefaultCallContext(clientContext.getDefaultCallContext());
ServerStreamingCallable<Query, RowT> withCookie = new CookiesServerStreamingCallable<>(traced);

return withCookie.withDefaultCallContext(clientContext.getDefaultCallContext());
}

/**
Expand Down Expand Up @@ -401,7 +403,9 @@ public <RowT> UnaryCallable<Query, RowT> createReadRowCallable(RowAdapter<RowT>
new TracedUnaryCallable<>(
firstRow, clientContext.getTracerFactory(), getSpanName("ReadRow"));

return traced.withDefaultCallContext(clientContext.getDefaultCallContext());
UnaryCallable<Query, RowT> withCookie = new CookiesUnaryCallable<>(traced);

return withCookie.withDefaultCallContext(clientContext.getDefaultCallContext());
}

/**
Expand Down Expand Up @@ -1013,7 +1017,9 @@ private <RequestT, ResponseT> UnaryCallable<RequestT, ResponseT> createUserFacin
UnaryCallable<RequestT, ResponseT> traced =
new TracedUnaryCallable<>(inner, clientContext.getTracerFactory(), getSpanName(methodName));

return traced.withDefaultCallContext(clientContext.getDefaultCallContext());
UnaryCallable<RequestT, ResponseT> withCookie = new CookiesUnaryCallable<>(traced);

return withCookie.withDefaultCallContext(clientContext.getDefaultCallContext());
}

private UnaryCallable<PingAndWarmRequest, PingAndWarmResponse> createPingAndWarmCallable() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,8 @@ public static InstantiatingGrpcChannelProvider.Builder defaultGrpcTransportProvi
Duration.ofSeconds(10)) // wait this long before considering the connection dead
// Attempts direct access to CBT service over gRPC to improve throughput,
// whether the attempt is allowed is totally controlled by service owner.
.setAttemptDirectPath(true);
.setAttemptDirectPath(true)
.setInterceptorProvider(() -> ImmutableList.of(new CookieInterceptor()));
}

@SuppressWarnings("WeakerAccess")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
/*
* Copyright 2023 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.cloud.bigtable.data.v2.stub;

import static com.google.cloud.bigtable.data.v2.stub.CookieInterceptor.ERROR_INFO_KEY;
import static com.google.cloud.bigtable.data.v2.stub.CookiesHolder.ROUTING_COOKIE_KEY;
import static com.google.cloud.bigtable.data.v2.stub.CookiesHolder.ROUTING_COOKIE_METADATA_KEY;
import static com.google.common.truth.Truth.assertThat;

import com.google.api.gax.grpc.InstantiatingGrpcChannelProvider;
import com.google.bigtable.v2.BigtableGrpc;
import com.google.bigtable.v2.MutateRowRequest;
import com.google.bigtable.v2.MutateRowResponse;
import com.google.cloud.bigtable.data.v2.BigtableDataClient;
import com.google.cloud.bigtable.data.v2.BigtableDataSettings;
import com.google.cloud.bigtable.data.v2.FakeServiceBuilder;
import com.google.cloud.bigtable.data.v2.models.RowMutation;
import com.google.common.collect.ImmutableList;
import com.google.rpc.ErrorInfo;
import io.grpc.Metadata;
import io.grpc.Server;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import io.grpc.stub.StreamObserver;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

@RunWith(JUnit4.class)
public class CookieHolderTest {
private Server server;
private FakeService fakeService = new FakeService();

private BigtableDataClient client;

private List<Metadata> serverMetadata = new ArrayList<>();

@Before
public void setup() throws Exception {
ServerInterceptor serverInterceptor =
new ServerInterceptor() {
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
ServerCall<ReqT, RespT> serverCall,
Metadata metadata,
ServerCallHandler<ReqT, RespT> serverCallHandler) {
serverMetadata.add(metadata);
return serverCallHandler.startCall(serverCall, metadata);
}
};

server = FakeServiceBuilder.create(fakeService).intercept(serverInterceptor).start();

BigtableDataSettings.Builder settings =
BigtableDataSettings.newBuilderForEmulator(server.getPort())
.setProjectId("fake-project")
.setInstanceId("fake-instance");

InstantiatingGrpcChannelProvider channelProvider =
(InstantiatingGrpcChannelProvider) settings.stubSettings().getTransportChannelProvider();
settings
.stubSettings()
.setTransportChannelProvider(
channelProvider
.toBuilder()
.setInterceptorProvider(() -> ImmutableList.of(new CookieInterceptor()))
.build());

client = BigtableDataClient.create(settings.build());
}

@Test
public void testRetryCookieIsForwarded() {
client.mutateRow(RowMutation.create("fake-table", "fake-row").setCell("cf", "q", "v"));

assertThat(serverMetadata.size()).isEqualTo(fakeService.count.get());
byte[] bytes = serverMetadata.get(1).get(ROUTING_COOKIE_METADATA_KEY);
assertThat(new String(bytes, StandardCharsets.UTF_8)).isEqualTo("test-routing-cookie");

serverMetadata.clear();
}

@After
public void tearDown() throws Exception {
client.close();
server.shutdown();
}

class FakeService extends BigtableGrpc.BigtableImplBase {

private AtomicInteger count = new AtomicInteger();

@Override
public void mutateRow(
MutateRowRequest request, StreamObserver<MutateRowResponse> responseObserver) {
if (count.getAndIncrement() < 1) {
Metadata trailers = new Metadata();
ErrorInfo errorInfo =
ErrorInfo.newBuilder().putMetadata(ROUTING_COOKIE_KEY, "test-routing-cookie").build();
trailers.put(ERROR_INFO_KEY, errorInfo);
StatusRuntimeException exception = new StatusRuntimeException(Status.UNAVAILABLE, trailers);
responseObserver.onError(exception);
return;
}
responseObserver.onNext(MutateRowResponse.getDefaultInstance());
responseObserver.onCompleted();
}
}
}

0 comments on commit bfafcd3

Please sign in to comment.