diff --git a/RSocket.Rpc.Protobuf/src/csharp_generator.cc b/RSocket.Rpc.Protobuf/src/csharp_generator.cc index 0d89852..dcb2236 100644 --- a/RSocket.Rpc.Protobuf/src/csharp_generator.cc +++ b/RSocket.Rpc.Protobuf/src/csharp_generator.cc @@ -132,6 +132,14 @@ void GenerateDocCommentMethod(google::protobuf::io::Printer* printer, } } +inline string CapitalizeFirstLetter(string s) { + if (s.empty()) { + return s; + } + s[0] = ::toupper(s[0]); + return s; +} + std::string GetServiceClassName(const ServiceDescriptor* service) { return service->name(); } @@ -151,7 +159,7 @@ std::string GetServerClassName(const ServiceDescriptor* service) { std::string GetServiceFieldName() { return "__Service"; } std::string GetMethodFieldName(const MethodDescriptor* method) { - return "__Method_" + method->name(); + return "__Method_" + CapitalizeFirstLetter(method->name()); } // Gets vector of all messages used as input or output types. @@ -176,8 +184,8 @@ std::vector GetUsedMessages( void GenerateStaticMethodField(Printer* out, const MethodDescriptor* method) { out->Print("public const string $methodfield$ = \"$methodname$\";\n", - "methodfield", GetMethodFieldName(method), "methodname", - method->name()); + "methodfield", GetMethodFieldName(method), + "methodname", CapitalizeFirstLetter(method->name())); } void GenerateServiceDescriptorProperty(Printer* out, @@ -216,16 +224,20 @@ void GenerateInterface(Printer* out, const ServiceDescriptor* service) { if (server_streaming) { out->Print("IAsyncEnumerable<$output_type$> $method_name$", - "output_type", GetClassName(method->output_type()), "method_name", method->name()); + "output_type", GetClassName(method->output_type()), + "method_name", CapitalizeFirstLetter(method->name())); } else if (client_streaming) { out->Print("Task<$output_type$> $method_name$", - "output_type", GetClassName(method->output_type()), "method_name", method->name()); + "output_type", GetClassName(method->output_type()), + "method_name", CapitalizeFirstLetter(method->name())); } else { if (options.fire_and_forget()) { - out->Print("Task $method_name$", "method_name", method->name()); + out->Print("Task $method_name$", + "method_name", CapitalizeFirstLetter(method->name())); } else { out->Print("Task<$output_type$> $method_name$", - "output_type", GetClassName(method->output_type()), "method_name", method->name()); + "output_type", GetClassName(method->output_type()), + "method_name", CapitalizeFirstLetter(method->name())); } } @@ -266,16 +278,20 @@ void GenerateClientClass(Printer* out, const ServiceDescriptor* service) { if (server_streaming) { out->Print("public IAsyncEnumerable<$output_type$> $method_name$", - "output_type", GetClassName(method->output_type()), "method_name", method->name()); + "output_type", GetClassName(method->output_type()), + "method_name", CapitalizeFirstLetter(method->name())); } else if (client_streaming) { out->Print("public Task<$output_type$> $method_name$", - "output_type", GetClassName(method->output_type()), "method_name", method->name()); + "output_type", GetClassName(method->output_type()), + "method_name", CapitalizeFirstLetter(method->name())); } else { if (options.fire_and_forget()) { - out->Print("public Task $method_name$", "method_name", method->name()); + out->Print("public Task $method_name$", + "method_name", CapitalizeFirstLetter(method->name())); } else { out->Print("public Task<$output_type$> $method_name$", - "output_type", GetClassName(method->output_type()), "method_name", method->name()); + "output_type", GetClassName(method->output_type()), + "method_name", CapitalizeFirstLetter(method->name())); } } @@ -298,7 +314,11 @@ void GenerateClientClass(Printer* out, const ServiceDescriptor* service) { "servicefield", GetServiceFieldName(), "methodfield", GetMethodFieldName(method)); } else { - //TODO + out->Print("__RequestChannel(messages, $intransform$, $outtransform$, metadata, service: $servicefield$, method: $methodfield$).SingleAsync().AsTask();\n", + "intransform", "Google.Protobuf.MessageExtensions.ToByteArray", + "outtransform", GetClassName(method->output_type()) + ".Parser.ParseFrom", + "servicefield", GetServiceFieldName(), + "methodfield", GetMethodFieldName(method)); } } else if (server_streaming) { out->Print("__RequestStream(message, $intransform$, $outtransform$, metadata, service: $servicefield$, method: $methodfield$);\n", @@ -366,16 +386,20 @@ void GenerateServerClass(Printer* out, const ServiceDescriptor* service) { if (server_streaming) { out->Print("public abstract IAsyncEnumerable<$output_type$> $method_name$", - "output_type", GetClassName(method->output_type()), "method_name", method->name()); + "output_type", GetClassName(method->output_type()), + "method_name", CapitalizeFirstLetter(method->name())); } else if (client_streaming) { out->Print("public abstract Task<$output_type$> $method_name$", - "output_type", GetClassName(method->output_type()), "method_name", method->name()); + "output_type", GetClassName(method->output_type()), + "method_name", CapitalizeFirstLetter(method->name())); } else { if (options.fire_and_forget()) { - out->Print("public abstract Task $method_name$", "method_name", method->name()); + out->Print("public abstract Task $method_name$", + "method_name", CapitalizeFirstLetter(method->name())); } else { out->Print("public abstract Task<$output_type$> $method_name$", - "output_type", GetClassName(method->output_type()), "method_name", method->name()); + "output_type", GetClassName(method->output_type()), + "method_name", CapitalizeFirstLetter(method->name())); } } @@ -406,18 +430,33 @@ void GenerateServerClass(Printer* out, const ServiceDescriptor* service) { bool server_streaming = method->server_streaming(); if (client_streaming) { - out->Print("case $methodfield$: return from result in service.$method_name$(from message in messages select $input_type$.Parser.ParseFrom(data.ToArray()), metadata) select Google.Protobuf.MessageExtensions.ToByteArray(result);\n", - "methodfield", GetMethodFieldName(method), "method_name", method->name(), "input_type", GetClassName(method->input_type())); + if (server_streaming) { + out->Print("case $methodfield$: return from result in service.$method_name$(from message in messages select $input_type$.Parser.ParseFrom(data.ToArray()), metadata) select Google.Protobuf.MessageExtensions.ToByteArray(result);\n", + "methodfield", GetMethodFieldName(method), + "method_name", CapitalizeFirstLetter(method->name()), + "input_type", GetClassName(method->input_type())); + } else { + out->Print("case $methodfield$: return from result in service.$method_name$(from message in messages select $input_type$.Parser.ParseFrom(data.ToArray()), metadata).ToAsyncEnumerable() select Google.Protobuf.MessageExtensions.ToByteArray(result);\n", + "methodfield", GetMethodFieldName(method), + "method_name", CapitalizeFirstLetter(method->name()), + "input_type", GetClassName(method->input_type())); + } } else if (server_streaming) { out->Print("case $methodfield$: return from result in service.$method_name$($input_type$.Parser.ParseFrom(data.ToArray()), metadata) select Google.Protobuf.MessageExtensions.ToByteArray(result);\n", - "methodfield", GetMethodFieldName(method), "method_name", method->name(), "input_type", GetClassName(method->input_type())); + "methodfield", GetMethodFieldName(method), + "method_name", CapitalizeFirstLetter(method->name()), + "input_type", GetClassName(method->input_type())); } else { if (options.fire_and_forget()) { - out->Print("case $methodfield$: return AsyncEnumerable.Empty();\n", - "methodfield", GetMethodFieldName(method)); + out->Print("case $methodfield$: service.$method_name$($input_type$.Parser.ParseFrom(data.ToArray()), metadata); return AsyncEnumerable.Empty();\n", + "methodfield", GetMethodFieldName(method), + "method_name", CapitalizeFirstLetter(method->name()), + "input_type", GetClassName(method->input_type())); } else { out->Print("case $methodfield$: return from result in service.$method_name$($input_type$.Parser.ParseFrom(data.ToArray()), metadata).ToAsyncEnumerable() select Google.Protobuf.MessageExtensions.ToByteArray(result);\n", - "methodfield", GetMethodFieldName(method), "method_name", method->name(), "input_type", GetClassName(method->input_type())); + "methodfield", GetMethodFieldName(method), + "method_name", CapitalizeFirstLetter(method->name()), + "input_type", GetClassName(method->input_type())); } } } diff --git a/RSocket.Rpc.Sample/EchoService.cs b/RSocket.Rpc.Sample/EchoService.cs index bfba2c3..17fd956 100644 --- a/RSocket.Rpc.Sample/EchoService.cs +++ b/RSocket.Rpc.Sample/EchoService.cs @@ -113,7 +113,7 @@ static IAsyncEnumerable Dispatch(IEchoService service, ReadOnlySequence< { switch (method) { - case Method_fireAndForget: return AsyncEnumerable.Empty(); + case Method_fireAndForget: service.FireAndForget(Google.Protobuf.WellKnownTypes.Value.Parser.ParseFrom(data.ToArray()), metadata); return AsyncEnumerable.Empty(); case Method_requestResponse: return from result in service.RequestResponse(Google.Protobuf.WellKnownTypes.Value.Parser.ParseFrom(data.ToArray()), metadata).ToAsyncEnumerable() select Google.Protobuf.MessageExtensions.ToByteArray(result); case Method_requestStream: return from result in service.RequestStream(Google.Protobuf.WellKnownTypes.Value.Parser.ParseFrom(data.ToArray()), metadata) select Google.Protobuf.MessageExtensions.ToByteArray(result); case Method_requestChannel: return from result in service.RequestChannel(from message in messages select Google.Protobuf.WellKnownTypes.Value.Parser.ParseFrom(message), metadata) select Google.Protobuf.MessageExtensions.ToByteArray(result);