Skip to content

Commit

Permalink
add RTransform
Browse files Browse the repository at this point in the history
  • Loading branch information
BaekGeunYoung committed Mar 13, 2024
1 parent 6a5a562 commit ecac24f
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 34 deletions.
56 changes: 22 additions & 34 deletions code-gen/src/main/scala/scalapb/zio_grpc/ZioCodeGenerator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ class ZioFilePrinter(
val ZManagedChannel = "scalapb.zio_grpc.ZManagedChannel"
val ZChannel = "scalapb.zio_grpc.ZChannel"
val GTransform = "scalapb.zio_grpc.GTransform"
val RTransform = "scalapb.zio_grpc.RTransform"
val Transform = "scalapb.zio_grpc.Transform"
val Nanos = "java.util.concurrent.TimeUnit.NANOSECONDS"
val serverServiceDef = "_root_.io.grpc.ServerServiceDefinition"
Expand Down Expand Up @@ -246,48 +247,24 @@ class ZioFilePrinter(
)
}

def printClientWithResponseMetadataTransform(
def printServerRTransform(
fp: FunctionalPrinter,
method: MethodDescriptor
): FunctionalPrinter = {
val delegate = s"self.${method.name}"
val newImpl = method.streamType match {
case StreamType.Unary =>
s"f.effect($delegate(request))"
case StreamType.ServerStreaming =>
s"f.stream($delegate(request))"
case StreamType.ClientStreaming =>
s"f.effect($delegate(request))"
case StreamType.Bidirectional =>
s"f.stream($delegate(request))"
}
fp.add(
clientWithResponseMetadataSignature(
method,
"Any"
) + " = " + newImpl
)
}
val reqType = methodInType(method, StatusException)

def printClientTransform(
fp: FunctionalPrinter,
method: MethodDescriptor
): FunctionalPrinter = {
val delegate = s"self.${method.name}"
val newImpl = method.streamType match {
case StreamType.Unary =>
s"f.effect($delegate(request))"
case StreamType.ServerStreaming =>
s"f.stream($delegate(request))"
case StreamType.ClientStreaming =>
s"f.effect($delegate(request))"
case StreamType.Bidirectional =>
s"f.stream($delegate(request))"
val newImpl = method.streamType match {
case StreamType.Unary | StreamType.ClientStreaming =>
s"f.effect((req: $reqType, ctx) => $delegate(req, ctx))(request, context)"
case StreamType.ServerStreaming | StreamType.Bidirectional =>
s"f.stream((req: $reqType, ctx) => $delegate(req, ctx))(request, context)"
}
fp.add(
clientMethodSignature(
methodSignature(
method,
contextType = "Any"
contextType = Some("Context1"),
errorType = Some("Error1")
) + " = " + newImpl
)
}
Expand All @@ -301,6 +278,15 @@ class ZioFilePrinter(
)
).add("}")

def printRTransformMethod(fp: FunctionalPrinter): FunctionalPrinter =
fp.add(
s"def transform[Context1, Error1](f: $RTransform[Context, Error, Context1, Error1]): ${gtraitName.fullName}[Context1, Error1] = new ${gtraitName.fullName}[Context1, Error1] {"
).indented(
_.print(service.getMethods().asScala.toVector)(
printServerRTransform
)
).add("}")

def print(fp: FunctionalPrinter): FunctionalPrinter =
fp.add(
s"trait ${gtraitName.name}[-Context, +Error] extends scalapb.zio_grpc.GenericGeneratedService[Context, Error, ${gtraitName.name}] {"
Expand All @@ -314,6 +300,8 @@ class ZioFilePrinter(
)
.add("")
.call(printGTransformMethod)
.add("")
.call(printRTransformMethod)
).add("}")
.add("")
.add(
Expand Down
82 changes: 82 additions & 0 deletions core/src/main/scala/scalapb/zio_grpc/transforms.scala
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,88 @@ object GTransform {
}
}

trait RTransform[+ContextIn, -ErrorIn, -ContextOut, +ErrorOut] {
self =>
def effect[Req, Resp](
io: (Req, ContextIn) => ZIO[Any, ErrorIn, Resp]
): (Req, ContextOut) => ZIO[Any, ErrorOut, Resp]

def stream[Req, Resp](
io: (Req, ContextIn) => ZStream[Any, ErrorIn, Resp]
): (Req, ContextOut) => ZStream[Any, ErrorOut, Resp]

def andThen[ContextIn2 <: ContextOut, ErrorIn2 >: ErrorOut, ContextOut2, ErrorOut2](
other: RTransform[ContextIn2, ErrorIn2, ContextOut2, ErrorOut2]
): RTransform[ContextIn, ErrorIn, ContextOut2, ErrorOut2] =
new RTransform[ContextIn, ErrorIn, ContextOut2, ErrorOut2] {
def effect[Req, Resp](
io: (Req, ContextIn) => ZIO[Any, ErrorIn, Resp]
): (Req, ContextOut2) => ZIO[Any, ErrorOut2, Resp] =
other.effect(self.effect(io))

def stream[Req, Resp](
io: (Req, ContextIn) => ZStream[Any, ErrorIn, Resp]
): (Req, ContextOut2) => ZStream[Any, ErrorOut2, Resp] =
other.stream(self.stream(io))
}

def compose[ContextIn2, ErrorIn2, ContextOut2 >: ContextIn, ErrorOut2 <: ErrorIn](
other: RTransform[ContextIn2, ErrorIn2, ContextOut2, ErrorOut2]
): RTransform[ContextIn2, ErrorIn2, ContextOut, ErrorOut] = other.andThen(self)
}

object RTransform {

def identity[C, E]: RTransform[C, E, C, E] = new RTransform[C, E, C, E] {
def effect[Req, Resp](io: (Req, C) => ZIO[Any, E, Resp]): (Req, C) => ZIO[Any, E, Resp] = io
def stream[Req, Resp](io: (Req, C) => ZStream[Any, E, Resp]): (Req, C) => ZStream[Any, E, Resp] = io
}

// Returns a GTransform that effectfully transforms the context parameter
def apply[ContextIn, Error, ContextOut](
f: ContextOut => ZIO[Any, Error, ContextIn]
): RTransform[ContextIn, Error, ContextOut, Error] =
new RTransform[ContextIn, Error, ContextOut, Error] {
def effect[Req, Resp](
io: (Req, ContextIn) => ZIO[Any, Error, Resp]
): (Req, ContextOut) => ZIO[Any, Error, Resp] = { (req, ctx) =>
f(ctx).flatMap(newCtx => io(req, newCtx))
}

def stream[Req, Resp](
io: (Req, ContextIn) => ZStream[Any, Error, Resp]
): (Req, ContextOut) => ZStream[Any, Error, Resp] = { (req, ctx) =>
ZStream.fromZIO(f(ctx)).flatMap(newCtx => io(req, newCtx))
}
}

// Returns a GTransform that maps the error parameter.
def mapError[C, E1, E2](f: E1 => E2): RTransform[C, E1, C, E2] = new RTransform[C, E1, C, E2] {
def effect[Req, Resp](io: (Req, C) => zio.ZIO[Any, E1, Resp]): (Req, C) => zio.ZIO[Any, E2, Resp] = { (req, ctx) =>
io(req, ctx).mapError(f)
}

def stream[Req, Resp](
io: (Req, C) => zio.stream.ZStream[Any, E1, Resp]
): (Req, C) => zio.stream.ZStream[Any, E2, Resp] = { (req, ctx) =>
io(req, ctx).mapError(f)
}
}

// Returns a GTransform that effectfully maps the error parameter.
def mapErrorZIO[C, E1, E2](f: E1 => zio.UIO[E2]): RTransform[C, E1, C, E2] = new RTransform[C, E1, C, E2] {
def effect[Req, Resp](io: (Req, C) => zio.ZIO[Any, E1, Resp]): (Req, C) => zio.ZIO[Any, E2, Resp] = { (req, ctx) =>
io(req, ctx).flatMapError(f)
}

def stream[Req, Resp](
io: (Req, C) => zio.stream.ZStream[Any, E1, Resp]
): (Req, C) => zio.stream.ZStream[Any, E2, Resp] = { (req, ctx) =>
io(req, ctx).catchAll(e => ZStream.fromZIO(f(e).flatMap(ZIO.fail(_))))
}
}
}

object ZTransform {
def apply[ContextIn, ContextOut](
f: ContextOut => ZIO[Any, StatusException, ContextIn]
Expand Down

0 comments on commit ecac24f

Please sign in to comment.