diff --git a/code-gen/src/main/scala/scalapb/zio_grpc/ZioCodeGenerator.scala b/code-gen/src/main/scala/scalapb/zio_grpc/ZioCodeGenerator.scala index eb8da799..5136eba5 100644 --- a/code-gen/src/main/scala/scalapb/zio_grpc/ZioCodeGenerator.scala +++ b/code-gen/src/main/scala/scalapb/zio_grpc/ZioCodeGenerator.scala @@ -71,6 +71,7 @@ class ZioFilePrinter( val ZManagedChannel = "scalapb.zio_grpc.ZManagedChannel" val ZChannel = "scalapb.zio_grpc.ZChannel" val GTransform = "scalapb.zio_grpc.GTransform" + val RTransform = "scalapb.zio_grpc.RTransform" val Transform = "scalapb.zio_grpc.Transform" val Nanos = "java.util.concurrent.TimeUnit.NANOSECONDS" val serverServiceDef = "_root_.io.grpc.ServerServiceDefinition" @@ -246,48 +247,24 @@ class ZioFilePrinter( ) } - def printClientWithResponseMetadataTransform( + def printServerRTransform( fp: FunctionalPrinter, method: MethodDescriptor ): FunctionalPrinter = { val delegate = s"self.${method.name}" - val newImpl = method.streamType match { - case StreamType.Unary => - s"f.effect($delegate(request))" - case StreamType.ServerStreaming => - s"f.stream($delegate(request))" - case StreamType.ClientStreaming => - s"f.effect($delegate(request))" - case StreamType.Bidirectional => - s"f.stream($delegate(request))" - } - fp.add( - clientWithResponseMetadataSignature( - method, - "Any" - ) + " = " + newImpl - ) - } + val reqType = methodInType(method, StatusException) - def printClientTransform( - fp: FunctionalPrinter, - method: MethodDescriptor - ): FunctionalPrinter = { - val delegate = s"self.${method.name}" - val newImpl = method.streamType match { - case StreamType.Unary => - s"f.effect($delegate(request))" - case StreamType.ServerStreaming => - s"f.stream($delegate(request))" - case StreamType.ClientStreaming => - s"f.effect($delegate(request))" - case StreamType.Bidirectional => - s"f.stream($delegate(request))" + val newImpl = method.streamType match { + case StreamType.Unary | StreamType.ClientStreaming => + s"f.effect((req: $reqType, ctx) => $delegate(req, ctx))(request, context)" + case StreamType.ServerStreaming | StreamType.Bidirectional => + s"f.stream((req: $reqType, ctx) => $delegate(req, ctx))(request, context)" } fp.add( - clientMethodSignature( + methodSignature( method, - contextType = "Any" + contextType = Some("Context1"), + errorType = Some("Error1") ) + " = " + newImpl ) } @@ -301,6 +278,15 @@ class ZioFilePrinter( ) ).add("}") + def printRTransformMethod(fp: FunctionalPrinter): FunctionalPrinter = + fp.add( + s"def transform[Context1, Error1](f: $RTransform[Context, Error, Context1, Error1]): ${gtraitName.fullName}[Context1, Error1] = new ${gtraitName.fullName}[Context1, Error1] {" + ).indented( + _.print(service.getMethods().asScala.toVector)( + printServerRTransform + ) + ).add("}") + def print(fp: FunctionalPrinter): FunctionalPrinter = fp.add( s"trait ${gtraitName.name}[-Context, +Error] extends scalapb.zio_grpc.GenericGeneratedService[Context, Error, ${gtraitName.name}] {" @@ -314,6 +300,8 @@ class ZioFilePrinter( ) .add("") .call(printGTransformMethod) + .add("") + .call(printRTransformMethod) ).add("}") .add("") .add( diff --git a/core/src/main/scala/scalapb/zio_grpc/transforms.scala b/core/src/main/scala/scalapb/zio_grpc/transforms.scala index 190f6418..40d72bfd 100644 --- a/core/src/main/scala/scalapb/zio_grpc/transforms.scala +++ b/core/src/main/scala/scalapb/zio_grpc/transforms.scala @@ -131,6 +131,88 @@ object GTransform { } } +trait RTransform[+ContextIn, -ErrorIn, -ContextOut, +ErrorOut] { + self => + def effect[Req, Resp]( + io: (Req, ContextIn) => ZIO[Any, ErrorIn, Resp] + ): (Req, ContextOut) => ZIO[Any, ErrorOut, Resp] + + def stream[Req, Resp]( + io: (Req, ContextIn) => ZStream[Any, ErrorIn, Resp] + ): (Req, ContextOut) => ZStream[Any, ErrorOut, Resp] + + def andThen[ContextIn2 <: ContextOut, ErrorIn2 >: ErrorOut, ContextOut2, ErrorOut2]( + other: RTransform[ContextIn2, ErrorIn2, ContextOut2, ErrorOut2] + ): RTransform[ContextIn, ErrorIn, ContextOut2, ErrorOut2] = + new RTransform[ContextIn, ErrorIn, ContextOut2, ErrorOut2] { + def effect[Req, Resp]( + io: (Req, ContextIn) => ZIO[Any, ErrorIn, Resp] + ): (Req, ContextOut2) => ZIO[Any, ErrorOut2, Resp] = + other.effect(self.effect(io)) + + def stream[Req, Resp]( + io: (Req, ContextIn) => ZStream[Any, ErrorIn, Resp] + ): (Req, ContextOut2) => ZStream[Any, ErrorOut2, Resp] = + other.stream(self.stream(io)) + } + + def compose[ContextIn2, ErrorIn2, ContextOut2 >: ContextIn, ErrorOut2 <: ErrorIn]( + other: RTransform[ContextIn2, ErrorIn2, ContextOut2, ErrorOut2] + ): RTransform[ContextIn2, ErrorIn2, ContextOut, ErrorOut] = other.andThen(self) +} + +object RTransform { + + def identity[C, E]: RTransform[C, E, C, E] = new RTransform[C, E, C, E] { + def effect[Req, Resp](io: (Req, C) => ZIO[Any, E, Resp]): (Req, C) => ZIO[Any, E, Resp] = io + def stream[Req, Resp](io: (Req, C) => ZStream[Any, E, Resp]): (Req, C) => ZStream[Any, E, Resp] = io + } + + // Returns a GTransform that effectfully transforms the context parameter + def apply[ContextIn, Error, ContextOut]( + f: ContextOut => ZIO[Any, Error, ContextIn] + ): RTransform[ContextIn, Error, ContextOut, Error] = + new RTransform[ContextIn, Error, ContextOut, Error] { + def effect[Req, Resp]( + io: (Req, ContextIn) => ZIO[Any, Error, Resp] + ): (Req, ContextOut) => ZIO[Any, Error, Resp] = { (req, ctx) => + f(ctx).flatMap(newCtx => io(req, newCtx)) + } + + def stream[Req, Resp]( + io: (Req, ContextIn) => ZStream[Any, Error, Resp] + ): (Req, ContextOut) => ZStream[Any, Error, Resp] = { (req, ctx) => + ZStream.fromZIO(f(ctx)).flatMap(newCtx => io(req, newCtx)) + } + } + + // Returns a GTransform that maps the error parameter. + def mapError[C, E1, E2](f: E1 => E2): RTransform[C, E1, C, E2] = new RTransform[C, E1, C, E2] { + def effect[Req, Resp](io: (Req, C) => zio.ZIO[Any, E1, Resp]): (Req, C) => zio.ZIO[Any, E2, Resp] = { (req, ctx) => + io(req, ctx).mapError(f) + } + + def stream[Req, Resp]( + io: (Req, C) => zio.stream.ZStream[Any, E1, Resp] + ): (Req, C) => zio.stream.ZStream[Any, E2, Resp] = { (req, ctx) => + io(req, ctx).mapError(f) + } + } + + // Returns a GTransform that effectfully maps the error parameter. + def mapErrorZIO[C, E1, E2](f: E1 => zio.UIO[E2]): RTransform[C, E1, C, E2] = new RTransform[C, E1, C, E2] { + def effect[Req, Resp](io: (Req, C) => zio.ZIO[Any, E1, Resp]): (Req, C) => zio.ZIO[Any, E2, Resp] = { (req, ctx) => + io(req, ctx).flatMapError(f) + } + + def stream[Req, Resp]( + io: (Req, C) => zio.stream.ZStream[Any, E1, Resp] + ): (Req, C) => zio.stream.ZStream[Any, E2, Resp] = { (req, ctx) => + io(req, ctx).catchAll(e => ZStream.fromZIO(f(e).flatMap(ZIO.fail(_)))) + } + } +} + object ZTransform { def apply[ContextIn, ContextOut]( f: ContextOut => ZIO[Any, StatusException, ContextIn]