Skip to content

Commit

Permalink
Handle channel requests returning a single item
Browse files Browse the repository at this point in the history
  • Loading branch information
rdegnan committed May 28, 2019
1 parent 91251b6 commit eaca23e
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 23 deletions.
83 changes: 61 additions & 22 deletions RSocket.Rpc.Protobuf/src/csharp_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,14 @@ void GenerateDocCommentMethod(google::protobuf::io::Printer* printer,
}
}

inline string CapitalizeFirstLetter(string s) {
if (s.empty()) {
return s;
}
s[0] = ::toupper(s[0]);
return s;
}

std::string GetServiceClassName(const ServiceDescriptor* service) {
return service->name();
}
Expand All @@ -151,7 +159,7 @@ std::string GetServerClassName(const ServiceDescriptor* service) {
std::string GetServiceFieldName() { return "__Service"; }

std::string GetMethodFieldName(const MethodDescriptor* method) {
return "__Method_" + method->name();
return "__Method_" + CapitalizeFirstLetter(method->name());
}

// Gets vector of all messages used as input or output types.
Expand All @@ -176,8 +184,8 @@ std::vector<const Descriptor*> GetUsedMessages(

void GenerateStaticMethodField(Printer* out, const MethodDescriptor* method) {
out->Print("public const string $methodfield$ = \"$methodname$\";\n",
"methodfield", GetMethodFieldName(method), "methodname",
method->name());
"methodfield", GetMethodFieldName(method),
"methodname", CapitalizeFirstLetter(method->name()));
}

void GenerateServiceDescriptorProperty(Printer* out,
Expand Down Expand Up @@ -216,16 +224,20 @@ void GenerateInterface(Printer* out, const ServiceDescriptor* service) {

if (server_streaming) {
out->Print("IAsyncEnumerable<$output_type$> $method_name$",
"output_type", GetClassName(method->output_type()), "method_name", method->name());
"output_type", GetClassName(method->output_type()),
"method_name", CapitalizeFirstLetter(method->name()));
} else if (client_streaming) {
out->Print("Task<$output_type$> $method_name$",
"output_type", GetClassName(method->output_type()), "method_name", method->name());
"output_type", GetClassName(method->output_type()),
"method_name", CapitalizeFirstLetter(method->name()));
} else {
if (options.fire_and_forget()) {
out->Print("Task $method_name$", "method_name", method->name());
out->Print("Task $method_name$",
"method_name", CapitalizeFirstLetter(method->name()));
} else {
out->Print("Task<$output_type$> $method_name$",
"output_type", GetClassName(method->output_type()), "method_name", method->name());
"output_type", GetClassName(method->output_type()),
"method_name", CapitalizeFirstLetter(method->name()));
}
}

Expand Down Expand Up @@ -266,16 +278,20 @@ void GenerateClientClass(Printer* out, const ServiceDescriptor* service) {

if (server_streaming) {
out->Print("public IAsyncEnumerable<$output_type$> $method_name$",
"output_type", GetClassName(method->output_type()), "method_name", method->name());
"output_type", GetClassName(method->output_type()),
"method_name", CapitalizeFirstLetter(method->name()));
} else if (client_streaming) {
out->Print("public Task<$output_type$> $method_name$",
"output_type", GetClassName(method->output_type()), "method_name", method->name());
"output_type", GetClassName(method->output_type()),
"method_name", CapitalizeFirstLetter(method->name()));
} else {
if (options.fire_and_forget()) {
out->Print("public Task $method_name$", "method_name", method->name());
out->Print("public Task $method_name$",
"method_name", CapitalizeFirstLetter(method->name()));
} else {
out->Print("public Task<$output_type$> $method_name$",
"output_type", GetClassName(method->output_type()), "method_name", method->name());
"output_type", GetClassName(method->output_type()),
"method_name", CapitalizeFirstLetter(method->name()));
}
}

Expand All @@ -298,7 +314,11 @@ void GenerateClientClass(Printer* out, const ServiceDescriptor* service) {
"servicefield", GetServiceFieldName(),
"methodfield", GetMethodFieldName(method));
} else {
//TODO
out->Print("__RequestChannel(messages, $intransform$, $outtransform$, metadata, service: $servicefield$, method: $methodfield$).SingleAsync().AsTask();\n",
"intransform", "Google.Protobuf.MessageExtensions.ToByteArray",
"outtransform", GetClassName(method->output_type()) + ".Parser.ParseFrom",
"servicefield", GetServiceFieldName(),
"methodfield", GetMethodFieldName(method));
}
} else if (server_streaming) {
out->Print("__RequestStream(message, $intransform$, $outtransform$, metadata, service: $servicefield$, method: $methodfield$);\n",
Expand Down Expand Up @@ -366,16 +386,20 @@ void GenerateServerClass(Printer* out, const ServiceDescriptor* service) {

if (server_streaming) {
out->Print("public abstract IAsyncEnumerable<$output_type$> $method_name$",
"output_type", GetClassName(method->output_type()), "method_name", method->name());
"output_type", GetClassName(method->output_type()),
"method_name", CapitalizeFirstLetter(method->name()));
} else if (client_streaming) {
out->Print("public abstract Task<$output_type$> $method_name$",
"output_type", GetClassName(method->output_type()), "method_name", method->name());
"output_type", GetClassName(method->output_type()),
"method_name", CapitalizeFirstLetter(method->name()));
} else {
if (options.fire_and_forget()) {
out->Print("public abstract Task $method_name$", "method_name", method->name());
out->Print("public abstract Task $method_name$",
"method_name", CapitalizeFirstLetter(method->name()));
} else {
out->Print("public abstract Task<$output_type$> $method_name$",
"output_type", GetClassName(method->output_type()), "method_name", method->name());
"output_type", GetClassName(method->output_type()),
"method_name", CapitalizeFirstLetter(method->name()));
}
}

Expand Down Expand Up @@ -406,18 +430,33 @@ void GenerateServerClass(Printer* out, const ServiceDescriptor* service) {
bool server_streaming = method->server_streaming();

if (client_streaming) {
out->Print("case $methodfield$: return from result in service.$method_name$(from message in messages select $input_type$.Parser.ParseFrom(data.ToArray()), metadata) select Google.Protobuf.MessageExtensions.ToByteArray(result);\n",
"methodfield", GetMethodFieldName(method), "method_name", method->name(), "input_type", GetClassName(method->input_type()));
if (server_streaming) {
out->Print("case $methodfield$: return from result in service.$method_name$(from message in messages select $input_type$.Parser.ParseFrom(data.ToArray()), metadata) select Google.Protobuf.MessageExtensions.ToByteArray(result);\n",
"methodfield", GetMethodFieldName(method),
"method_name", CapitalizeFirstLetter(method->name()),
"input_type", GetClassName(method->input_type()));
} else {
out->Print("case $methodfield$: return from result in service.$method_name$(from message in messages select $input_type$.Parser.ParseFrom(data.ToArray()), metadata).ToAsyncEnumerable() select Google.Protobuf.MessageExtensions.ToByteArray(result);\n",
"methodfield", GetMethodFieldName(method),
"method_name", CapitalizeFirstLetter(method->name()),
"input_type", GetClassName(method->input_type()));
}
} else if (server_streaming) {
out->Print("case $methodfield$: return from result in service.$method_name$($input_type$.Parser.ParseFrom(data.ToArray()), metadata) select Google.Protobuf.MessageExtensions.ToByteArray(result);\n",
"methodfield", GetMethodFieldName(method), "method_name", method->name(), "input_type", GetClassName(method->input_type()));
"methodfield", GetMethodFieldName(method),
"method_name", CapitalizeFirstLetter(method->name()),
"input_type", GetClassName(method->input_type()));
} else {
if (options.fire_and_forget()) {
out->Print("case $methodfield$: return AsyncEnumerable.Empty<byte[]>();\n",
"methodfield", GetMethodFieldName(method));
out->Print("case $methodfield$: service.$method_name$($input_type$.Parser.ParseFrom(data.ToArray()), metadata); return AsyncEnumerable.Empty<byte[]>();\n",
"methodfield", GetMethodFieldName(method),
"method_name", CapitalizeFirstLetter(method->name()),
"input_type", GetClassName(method->input_type()));
} else {
out->Print("case $methodfield$: return from result in service.$method_name$($input_type$.Parser.ParseFrom(data.ToArray()), metadata).ToAsyncEnumerable() select Google.Protobuf.MessageExtensions.ToByteArray(result);\n",
"methodfield", GetMethodFieldName(method), "method_name", method->name(), "input_type", GetClassName(method->input_type()));
"methodfield", GetMethodFieldName(method),
"method_name", CapitalizeFirstLetter(method->name()),
"input_type", GetClassName(method->input_type()));
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion RSocket.Rpc.Sample/EchoService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ static IAsyncEnumerable<byte[]> Dispatch(IEchoService service, ReadOnlySequence<
{
switch (method)
{
case Method_fireAndForget: return AsyncEnumerable.Empty<byte[]>();
case Method_fireAndForget: service.FireAndForget(Google.Protobuf.WellKnownTypes.Value.Parser.ParseFrom(data.ToArray()), metadata); return AsyncEnumerable.Empty<byte[]>();
case Method_requestResponse: return from result in service.RequestResponse(Google.Protobuf.WellKnownTypes.Value.Parser.ParseFrom(data.ToArray()), metadata).ToAsyncEnumerable() select Google.Protobuf.MessageExtensions.ToByteArray(result);
case Method_requestStream: return from result in service.RequestStream(Google.Protobuf.WellKnownTypes.Value.Parser.ParseFrom(data.ToArray()), metadata) select Google.Protobuf.MessageExtensions.ToByteArray(result);
case Method_requestChannel: return from result in service.RequestChannel(from message in messages select Google.Protobuf.WellKnownTypes.Value.Parser.ParseFrom(message), metadata) select Google.Protobuf.MessageExtensions.ToByteArray(result);
Expand Down

0 comments on commit eaca23e

Please sign in to comment.