From de7e12161a1cb3e6e03cd2b9a0db75e872fccd60 Mon Sep 17 00:00:00 2001 From: Daniel Vigovszky Date: Wed, 19 Aug 2020 18:39:48 +0200 Subject: [PATCH 1/7] Detect custom pagination cases --- .../generator/OperationCollector.scala | 122 ++++++++++-------- 1 file changed, 67 insertions(+), 55 deletions(-) diff --git a/zio-aws-codegen/src/main/scala/io/github/vigoo/zioaws/codegen/generator/OperationCollector.scala b/zio-aws-codegen/src/main/scala/io/github/vigoo/zioaws/codegen/generator/OperationCollector.scala index 7275704e..bac38714 100644 --- a/zio-aws-codegen/src/main/scala/io/github/vigoo/zioaws/codegen/generator/OperationCollector.scala +++ b/zio-aws-codegen/src/main/scala/io/github/vigoo/zioaws/codegen/generator/OperationCollector.scala @@ -52,65 +52,77 @@ object OperationCollector { Option(op.getOutput).flatMap(output => Option(models.serviceModel().getShape(output.getShape))).exists(hasEventStreamMember(models, _)) def get(opName: String, op: Operation): ZIO[GeneratorContext, GeneratorFailure, OperationMethodType] = { - getModels.flatMap { models => - val inputIsStreaming = inputIsStreamingOf(models, op) - val outputIsStreaming = outputIsStreamingOf(models, op) + getService.flatMap { id => + getModels.flatMap { models => + val inputIsStreaming = inputIsStreamingOf(models, op) + val outputIsStreaming = outputIsStreamingOf(models, op) - val inputIsEventStream = inputIsEventStreamOf(models, op) - val outputIsEventStream = outputIsEventStreamOf(models, op) + val inputIsEventStream = inputIsEventStreamOf(models, op) + val outputIsEventStream = outputIsEventStreamOf(models, op) + + if (inputIsStreaming && outputIsStreaming) { + ZIO.succeed(StreamedInputOutput) + } else if (inputIsStreaming) { + ZIO.succeed(StreamedInput) + } else if (outputIsStreaming) { + ZIO.succeed(StreamedOutput) + } else if (inputIsEventStream && outputIsEventStream) { + ZIO.succeed(EventStreamInputOutput) + } else if (inputIsEventStream) { + ZIO.succeed(EventStreamInput) + } else if (outputIsEventStream) { + ZIO.succeed(EventStreamOutput) + } else { + Option(models.paginatorsModel().getPaginatorDefinition(opName)) match { + case Some(paginator) if paginator.isValid => + Option(paginator.getResultKey).flatMap(_.asScala.headOption) match { + case Some(key) => + val outputShape = models.serviceModel().getShape(op.getOutput.getShape) + outputShape.getMembers.asScala.get(key) match { + case Some(outputListMember) => + val listShape = models.serviceModel().getShape(outputListMember.getShape) + Option(listShape.getListMember) match { + case Some(itemMember) => + for { + itemModel <- context.get(itemMember.getShape) + itemType <- toJavaType(itemModel) + wrappedTypeRo <- toWrappedTypeReadOnly(itemModel) + } yield RequestResponse(pagination = Some(PaginationDefinition( + name = key, + model = itemModel, + itemType = itemType, + wrappedTypeRo = wrappedTypeRo + ))) + case None => + ZIO.succeed(RequestResponse(pagination = None)) + } + case None => + ZIO.succeed(RequestResponse(pagination = None)) + } + case None => + ZIO.succeed(RequestResponse(pagination = None)) + } + case _ => + if (op.getOutput == null && op.getInput == null) { + ZIO.succeed(UnitToUnit) + } else if (op.getOutput == null) { + ZIO.succeed(RequestToUnit) + } else if (op.getInput == null) { + ZIO.succeed(UnitToResponse) + } else { - if (inputIsStreaming && outputIsStreaming) { - ZIO.succeed(StreamedInputOutput) - } else if (inputIsStreaming) { - ZIO.succeed(StreamedInput) - } else if (outputIsStreaming) { - ZIO.succeed(StreamedOutput) - } else if (inputIsEventStream && outputIsEventStream) { - ZIO.succeed(EventStreamInputOutput) - } else if (inputIsEventStream) { - ZIO.succeed(EventStreamInput) - } else if (outputIsEventStream) { - ZIO.succeed(EventStreamOutput) - } else { - Option(models.paginatorsModel().getPaginatorDefinition(opName)) match { - case Some(paginator) if paginator.isValid => - Option(paginator.getResultKey).flatMap(_.asScala.headOption) match { - case Some(key) => val outputShape = models.serviceModel().getShape(op.getOutput.getShape) - outputShape.getMembers.asScala.get(key) match { - case Some(outputListMember) => - val listShape = models.serviceModel().getShape(outputListMember.getShape) - Option(listShape.getListMember) match { - case Some(itemMember) => - for { - itemModel <- context.get(itemMember.getShape) - itemType <- toJavaType(itemModel) - wrappedTypeRo <- toWrappedTypeReadOnly(itemModel) - } yield RequestResponse(pagination = Some(PaginationDefinition( - name = key, - model = itemModel, - itemType = itemType, - wrappedTypeRo = wrappedTypeRo - ))) - case None => - ZIO.succeed(RequestResponse(pagination = None)) - } - case None => - ZIO.succeed(RequestResponse(pagination = None)) + val inputShape = models.serviceModel().getShape(op.getInput.getShape) + + if (outputShape.getMembers.containsKey("NextToken") && + inputShape.getMembers.containsKey("NextToken")) { + // TODO: custom pagination + ZIO.succeed(RequestResponse(pagination = None)) + } else { + ZIO.succeed(RequestResponse(pagination = None)) } - case None => - ZIO.succeed(RequestResponse(pagination = None)) - } - case _ => - if (op.getOutput == null && op.getInput == null) { - ZIO.succeed(UnitToUnit) - } else if (op.getOutput == null) { - ZIO.succeed(RequestToUnit) - } else if (op.getInput == null) { - ZIO.succeed(UnitToResponse) - } else { - ZIO.succeed(RequestResponse(pagination = None)) - } + } + } } } } From aef14fc330d83f4a7bb1a204b73c306453912432 Mon Sep 17 00:00:00 2001 From: Daniel Vigovszky Date: Fri, 21 Aug 2020 18:11:40 +0200 Subject: [PATCH 2/7] Streaming support for paginated operations without a Java SDK paginator --- .../codegen/generator/GeneratorBase.scala | 29 +- .../codegen/generator/GeneratorFailure.scala | 3 +- .../generator/OperationCollector.scala | 255 +++++++++++++----- .../generator/PaginationDefinition.scala | 24 +- .../generator/ServiceInterfaceGenerator.scala | 207 +++++++++++--- .../generator/ServiceModelGenerator.scala | 24 -- .../vigoo/zioaws/core/AwsServiceBase.scala | 93 ++++++- .../zioaws/core/StreamingOutputResult.scala | 9 +- .../ZStreamAsyncResponseTransformer.scala | 4 +- .../zioaws/core/AwsServiceBaseSpec.scala | 36 +-- .../SimulatedAsyncResponseTransformer.scala | 12 +- 11 files changed, 523 insertions(+), 173 deletions(-) diff --git a/zio-aws-codegen/src/main/scala/io/github/vigoo/zioaws/codegen/generator/GeneratorBase.scala b/zio-aws-codegen/src/main/scala/io/github/vigoo/zioaws/codegen/generator/GeneratorBase.scala index a4e4141a..37e9e334 100644 --- a/zio-aws-codegen/src/main/scala/io/github/vigoo/zioaws/codegen/generator/GeneratorBase.scala +++ b/zio-aws-codegen/src/main/scala/io/github/vigoo/zioaws/codegen/generator/GeneratorBase.scala @@ -4,8 +4,11 @@ import java.nio.charset.StandardCharsets import java.nio.file.{Files, Path} import io.github.vigoo.zioaws.codegen.generator.context._ +import io.github.vigoo.zioaws.codegen.generator.syntax._ +import software.amazon.awssdk.codegen.model.config.customization.ShapeModifier import zio.{Chunk, ZIO} +import scala.jdk.CollectionConverters._ import scala.meta._ trait GeneratorBase { @@ -99,7 +102,31 @@ trait GeneratorBase { ZIO.succeed(q"""$term : ${Type.Name(model.name)}""") } - def writeIfDifferent(path: Path, contents: String): ZIO[Any, GeneratorFailure, Unit] = + protected def propertyName(model: Model, fieldModel: Model, name: String): ZIO[GeneratorContext, Nothing, String] = { + getNamingStrategy.flatMap { namingStrategy => + getModels.map { models => + val shapeModifiers = Option(models.customizationConfig().getShapeModifiers).map(_.asScala).getOrElse(Map.empty[String, ShapeModifier]) + shapeModifiers.get(model.shapeName).flatMap { shapeModifier => + val modifies = Option(shapeModifier.getModify).map(_.asScala).getOrElse(List.empty) + val matchingModifiers = modifies.flatMap { modifiesMap => + modifiesMap.asScala.map { case (key, value) => (key.toLowerCase, value) }.get(name.toLowerCase) + }.toList + + matchingModifiers + .map(modifier => Option(modifier.getEmitPropertyName)) + .find(_.isDefined) + .flatten.map(_.uncapitalize) + }.getOrElse { + val getterMethod = namingStrategy.getFluentGetterMethodName(name, model.shape, fieldModel.shape) + getterMethod + .stripSuffix("AsString") + .stripSuffix("AsStrings") + } + } + } + } + + protected def writeIfDifferent(path: Path, contents: String): ZIO[Any, GeneratorFailure, Unit] = ZIO(Files.exists(path)).mapError(FailedToReadFile).flatMap { exists => ZIO { if (exists) { diff --git a/zio-aws-codegen/src/main/scala/io/github/vigoo/zioaws/codegen/generator/GeneratorFailure.scala b/zio-aws-codegen/src/main/scala/io/github/vigoo/zioaws/codegen/generator/GeneratorFailure.scala index c493c656..a8de0877 100644 --- a/zio-aws-codegen/src/main/scala/io/github/vigoo/zioaws/codegen/generator/GeneratorFailure.scala +++ b/zio-aws-codegen/src/main/scala/io/github/vigoo/zioaws/codegen/generator/GeneratorFailure.scala @@ -8,4 +8,5 @@ case class FailedToCopy(reason: Throwable) extends GeneratorFailure case class FailedToDelete(reason: Throwable) extends GeneratorFailure case class CannotFindEventStreamInShape(service: String, name: String) extends GeneratorFailure case class UnknownShapeReference(service: String, name: String) extends GeneratorFailure -case class UnknownType(service: String, typ: String) extends GeneratorFailure \ No newline at end of file +case class UnknownType(service: String, typ: String) extends GeneratorFailure +case class InvalidPaginatedOperation(service: String, name: String) extends GeneratorFailure diff --git a/zio-aws-codegen/src/main/scala/io/github/vigoo/zioaws/codegen/generator/OperationCollector.scala b/zio-aws-codegen/src/main/scala/io/github/vigoo/zioaws/codegen/generator/OperationCollector.scala index bac38714..2f5ee722 100644 --- a/zio-aws-codegen/src/main/scala/io/github/vigoo/zioaws/codegen/generator/OperationCollector.scala +++ b/zio-aws-codegen/src/main/scala/io/github/vigoo/zioaws/codegen/generator/OperationCollector.scala @@ -2,6 +2,7 @@ package io.github.vigoo.zioaws.codegen.generator import io.github.vigoo.zioaws.codegen.generator.TypeMapping.{toJavaType, toWrappedTypeReadOnly} import io.github.vigoo.zioaws.codegen.generator.context._ +import io.github.vigoo.zioaws.codegen.loader import software.amazon.awssdk.codegen.C2jModels import software.amazon.awssdk.codegen.model.config.customization.CustomizationConfig import software.amazon.awssdk.codegen.model.service.{Operation, Shape} @@ -10,10 +11,33 @@ import zio.ZIO import scala.jdk.CollectionConverters._ object OperationCollector { - private def isExcluded(customizationConfig: CustomizationConfig, opName: String): Boolean = - Option(customizationConfig.getOperationModifiers) - .flatMap(_.asScala.get(opName)) - .exists(_.isExclude) + val overrides: Set[PaginationOverride] = Set( + PaginationNotSupported(loader.ModelId("greengrass", None), "GetDeviceDefinitionVersion"), + PaginationNotSupported(loader.ModelId("greengrass", None), "GetSubscriptionDefinitionVersion"), + PaginationNotSupported(loader.ModelId("greengrass", None), "GetFunctionDefinitionVersion"), + PaginationNotSupported(loader.ModelId("greengrass", None), "GetConnectorDefinitionVersion"), + PaginationNotSupported(loader.ModelId("budgets", None), "DescribeBudgetPerformanceHistory"), + + PaginationNotSupported(loader.ModelId("athena", None), "GetQueryResults"), + PaginationNotSupported(loader.ModelId("guardduty", None), "GetUsageStatistics"), + + SelectPaginatedStringMember(loader.ModelId("fms", None), "GetProtectionStatus", "Data"), + + SelectPaginatedListMember(loader.ModelId("cloudformation", None), "DescribeChangeSet", "Changes"), + SelectPaginatedListMember(loader.ModelId("ec2", None), "DescribeVpcEndpointServices", "ServiceDetails"), + SelectPaginatedListMember(loader.ModelId("pi", None), "DescribeDimensionKeys", "Keys"), + SelectPaginatedListMember(loader.ModelId("cognitosync", None), "ListRecords", "Records"), + SelectPaginatedListMember(loader.ModelId("textract", None), "GetDocumentAnalysis", "Blocks"), + SelectPaginatedListMember(loader.ModelId("textract", None), "GetDocumentTextDetection", "Blocks"), + + SelectPaginatedListMember(loader.ModelId("resourcegroups", None), "ListGroups", "Groups"), + SelectPaginatedListMember(loader.ModelId("resourcegroups", None), "SearchResources", "ResourceIdentifiers"), + SelectPaginatedListMember(loader.ModelId("resourcegroups", None), "ListGroupResources", "ResourceIdentifiers"), + ) + + case class OverrideKey(id: loader.ModelId, opName: String) + + val overrideMap: Map[OverrideKey, PaginationOverride] = overrides.map(o => o.toKey -> o).toMap def getFilteredOperations(models: C2jModels): Map[String, Operation] = models.serviceModel().getOperations.asScala @@ -21,24 +45,6 @@ object OperationCollector { .filter { case (_, op) => !op.isDeprecated } .filter { case (opName, _) => !isExcluded(models.customizationConfig(), opName) } - private def hasStreamingMember(models: C2jModels, shape: Shape, alreadyChecked: Set[Shape] = Set.empty): Boolean = - if (alreadyChecked(shape)) { - false - } else { - shape.isStreaming || shape.getMembers.asScala.values.exists { member => - member.isStreaming || hasStreamingMember(models, models.serviceModel().getShape(member.getShape), alreadyChecked + shape) - } - } - - private def hasEventStreamMember(models: C2jModels, shape: Shape, alreadyChecked: Set[Shape] = Set.empty): Boolean = - if (alreadyChecked(shape)) { - false - } else { - shape.isEventStream || shape.getMembers.asScala.values.exists { member => - hasEventStreamMember(models, models.serviceModel().getShape(member.getShape), alreadyChecked + shape) - } - } - def inputIsStreamingOf(models: C2jModels, op: Operation): Boolean = Option(op.getInput).flatMap(input => Option(models.serviceModel().getShape(input.getShape))).exists(hasStreamingMember(models, _)) @@ -73,58 +79,171 @@ object OperationCollector { } else if (outputIsEventStream) { ZIO.succeed(EventStreamOutput) } else { - Option(models.paginatorsModel().getPaginatorDefinition(opName)) match { - case Some(paginator) if paginator.isValid => - Option(paginator.getResultKey).flatMap(_.asScala.headOption) match { - case Some(key) => - val outputShape = models.serviceModel().getShape(op.getOutput.getShape) - outputShape.getMembers.asScala.get(key) match { - case Some(outputListMember) => - val listShape = models.serviceModel().getShape(outputListMember.getShape) - Option(listShape.getListMember) match { - case Some(itemMember) => - for { - itemModel <- context.get(itemMember.getShape) - itemType <- toJavaType(itemModel) - wrappedTypeRo <- toWrappedTypeReadOnly(itemModel) - } yield RequestResponse(pagination = Some(PaginationDefinition( - name = key, - model = itemModel, - itemType = itemType, - wrappedTypeRo = wrappedTypeRo - ))) - case None => - ZIO.succeed(RequestResponse(pagination = None)) - } - case None => - ZIO.succeed(RequestResponse(pagination = None)) - } - case None => - ZIO.succeed(RequestResponse(pagination = None)) + if (op.getOutput == null && op.getInput == null) { + ZIO.succeed(UnitToUnit) + } else if (op.getOutput == null) { + ZIO.succeed(RequestToUnit) + } else if (op.getInput == null) { + ZIO.succeed(UnitToResponse) + } else { + + val outputShape = models.serviceModel().getShape(op.getOutput.getShape) + val inputShape = models.serviceModel().getShape(op.getInput.getShape) + + if (outputShape.getMembers.containsKey("NextToken") && + inputShape.getMembers.containsKey("NextToken")) { + + getPaginationDefinition(opName, op).map { paginationDefinition => + RequestResponse(paginationDefinition) } - case _ => - if (op.getOutput == null && op.getInput == null) { - ZIO.succeed(UnitToUnit) - } else if (op.getOutput == null) { - ZIO.succeed(RequestToUnit) - } else if (op.getInput == null) { - ZIO.succeed(UnitToResponse) - } else { - - val outputShape = models.serviceModel().getShape(op.getOutput.getShape) - val inputShape = models.serviceModel().getShape(op.getInput.getShape) - - if (outputShape.getMembers.containsKey("NextToken") && - inputShape.getMembers.containsKey("NextToken")) { - // TODO: custom pagination - ZIO.succeed(RequestResponse(pagination = None)) - } else { + } else { + getJavaSdkPaginatorDefinition(opName, op, models) match { + case Some(createDef) => + // Special paginator with Java SDK support + for { + paginatorDef <- createDef + } yield RequestResponse(pagination = Some(paginatorDef)) + case None => ZIO.succeed(RequestResponse(pagination = None)) - } } + } } } } } } + + private def getPaginationDefinition(opName: String, op: Operation): ZIO[GeneratorContext, GeneratorFailure, Option[PaginationDefinition]] = { + getService.flatMap { id => + getModels.flatMap { models => + val outputShape = models.serviceModel().getShape(op.getOutput.getShape) + + overrideMap.get(OverrideKey(id, opName)) match { + case Some(PaginationNotSupported(_, _)) => + ZIO.none + case Some(SelectPaginatedListMember(_, _, memberName)) => + val listShapeName = outputShape.getMembers.get(memberName).getShape + for { + listModel <- context.get(listShapeName) + listShape = listModel.shape + itemShapeName = listShape.getListMember.getShape + itemModel <- context.get(itemShapeName) + } yield Some(ListPaginationDefinition(memberName, listModel, itemModel, isSimple = false)) + case Some(SelectPaginatedStringMember(_, _, memberName)) => + val stringShapeName = outputShape.getMembers.get(memberName).getShape + for { + stringModel <- context.get(stringShapeName) + } yield Some(StringPaginationDefinition(memberName, stringModel, isSimple = false)) + case None => + val otherOutputMembers = outputShape.getMembers.asScala.toMap - "NextToken" + val outputMembersWithListType = otherOutputMembers.filter { case (name, member) => + models.serviceModel().getShape(member.getShape).getType == "list" + } + val outputMembersWithMapType = otherOutputMembers.filter { case (name, member) => + models.serviceModel().getShape(member.getShape).getType == "map" + } + val outputMembersWithStringType = otherOutputMembers.filter { case (name, member) => + val shape = models.serviceModel().getShape(member.getShape) + shape.getType == "string" && Option(shape.getEnumValues).map(_.asScala).getOrElse(List.empty).isEmpty + } + + val isSimple = otherOutputMembers.size == 1 + + if (outputMembersWithListType.size == 1) { + val memberName = outputMembersWithListType.keys.head + val listShapeName = outputMembersWithListType.values.head.getShape + for { + listModel <- context.get(listShapeName) + listShape = listModel.shape + itemShapeName = listShape.getListMember.getShape + itemModel <- context.get(itemShapeName) + } yield Some(ListPaginationDefinition(memberName, listModel, itemModel, isSimple)) + } else if (outputMembersWithMapType.size == 1) { + val memberName = outputMembersWithMapType.keys.head + val mapShapeName = outputMembersWithMapType.values.head.getShape + for { + mapModel <- context.get(mapShapeName) + mapShape = mapModel.shape + keyModel <- context.get(mapShape.getMapKeyType.getShape) + valueModel <- context.get(mapShape.getMapValueType.getShape) + } yield Some(MapPaginationDefinition(memberName, mapModel, keyModel, valueModel, isSimple)) + } else if (outputMembersWithStringType.size == 1) { + val memberName = outputMembersWithStringType.keys.head + val stringShapeName = outputMembersWithStringType.values.head.getShape + for { + stringModel <- context.get(stringShapeName) + } yield Some(StringPaginationDefinition(memberName, stringModel, isSimple)) + } else { + // Fall back to Java SDK paginator if possible + getJavaSdkPaginatorDefinition(opName, op, models) match { + case Some(definition) => + definition.map(Some(_)) + case None => + ZIO.fail(InvalidPaginatedOperation(id.toString, opName)) + } + } + } + } + } + } + + private def getJavaSdkPaginatorDefinition(opName: String, op: Operation, models: C2jModels) = { + for { + paginator <- Option(models.paginatorsModel().getPaginatorDefinition(opName)) + if paginator.isValid + key <- Option(paginator.getResultKey).flatMap(_.asScala.headOption) + outputShape = models.serviceModel().getShape(op.getOutput.getShape) + outputListMember <- outputShape.getMembers.asScala.get(key) + listShape = models.serviceModel().getShape(outputListMember.getShape) + itemMember <- Option(listShape.getListMember) + } yield for { + itemModel <- context.get(itemMember.getShape) + itemType <- toJavaType(itemModel) + wrappedTypeRo <- toWrappedTypeReadOnly(itemModel) + } yield JavaSdkPaginationDefinition( + name = key, + model = itemModel, + itemType = itemType, + wrappedTypeRo = wrappedTypeRo) + } + + private def isExcluded(customizationConfig: CustomizationConfig, opName: String): Boolean = + Option(customizationConfig.getOperationModifiers) + .flatMap(_.asScala.get(opName)) + .exists(_.isExclude) + + private def hasStreamingMember(models: C2jModels, shape: Shape, alreadyChecked: Set[Shape] = Set.empty): Boolean = + if (alreadyChecked(shape)) { + false + } else { + shape.isStreaming || shape.getMembers.asScala.values.exists { member => + member.isStreaming || hasStreamingMember(models, models.serviceModel().getShape(member.getShape), alreadyChecked + shape) + } + } + + private def hasEventStreamMember(models: C2jModels, shape: Shape, alreadyChecked: Set[Shape] = Set.empty): Boolean = + if (alreadyChecked(shape)) { + false + } else { + shape.isEventStream || shape.getMembers.asScala.values.exists { member => + hasEventStreamMember(models, models.serviceModel().getShape(member.getShape), alreadyChecked + shape) + } + } + + sealed trait PaginationOverride { + def toKey: OverrideKey + } + + case class PaginationNotSupported(id: loader.ModelId, opName: String) extends PaginationOverride { + override def toKey: OverrideKey = OverrideKey(id, opName) + } + + case class SelectPaginatedListMember(id: loader.ModelId, opName: String, memberName: String) extends PaginationOverride { + override def toKey: OverrideKey = OverrideKey(id, opName) + } + + case class SelectPaginatedStringMember(id: loader.ModelId, opName: String, memberName: String) extends PaginationOverride { + override def toKey: OverrideKey = OverrideKey(id, opName) + } + } diff --git a/zio-aws-codegen/src/main/scala/io/github/vigoo/zioaws/codegen/generator/PaginationDefinition.scala b/zio-aws-codegen/src/main/scala/io/github/vigoo/zioaws/codegen/generator/PaginationDefinition.scala index a637ba5a..0bd3991a 100644 --- a/zio-aws-codegen/src/main/scala/io/github/vigoo/zioaws/codegen/generator/PaginationDefinition.scala +++ b/zio-aws-codegen/src/main/scala/io/github/vigoo/zioaws/codegen/generator/PaginationDefinition.scala @@ -2,8 +2,24 @@ package io.github.vigoo.zioaws.codegen.generator import scala.meta.Type -case class PaginationDefinition(name: String, - model: Model, - itemType: Type, - wrappedTypeRo: Type) +sealed trait PaginationDefinition +case class JavaSdkPaginationDefinition(name: String, + model: Model, + itemType: Type, + wrappedTypeRo: Type) extends PaginationDefinition + +case class ListPaginationDefinition(memberName: String, + listModel: Model, + itemModel: Model, + isSimple: Boolean) extends PaginationDefinition + +case class MapPaginationDefinition(memberName: String, + mapModel: Model, + keyModel: Model, + valueModel: Model, + isSimple: Boolean) extends PaginationDefinition + +case class StringPaginationDefinition(memberName: String, + stringModel: Model, + isSimple: Boolean) extends PaginationDefinition diff --git a/zio-aws-codegen/src/main/scala/io/github/vigoo/zioaws/codegen/generator/ServiceInterfaceGenerator.scala b/zio-aws-codegen/src/main/scala/io/github/vigoo/zioaws/codegen/generator/ServiceInterfaceGenerator.scala index 8e1fec9c..5e625a91 100644 --- a/zio-aws-codegen/src/main/scala/io/github/vigoo/zioaws/codegen/generator/ServiceInterfaceGenerator.scala +++ b/zio-aws-codegen/src/main/scala/io/github/vigoo/zioaws/codegen/generator/ServiceInterfaceGenerator.scala @@ -180,12 +180,12 @@ trait ServiceInterfaceGenerator { ZIO.succeed(ServiceMethods( ServiceMethod( interface = - q"""def $methodName(request: $requestName, body: zio.stream.ZStream[Any, AwsError, Byte]): IO[AwsError, StreamingOutputResult[$responseTypeRo]]""", + q"""def $methodName(request: $requestName, body: zio.stream.ZStream[Any, AwsError, Byte]): IO[AwsError, StreamingOutputResult[$responseTypeRo, Byte]]""", implementation = - q"""def $methodName(request: $requestName, body: zio.stream.ZStream[Any, AwsError, Byte]): IO[AwsError, StreamingOutputResult[$responseTypeRo]] = - asyncRequestInputOutputStream[$modelPkg.$requestName, $modelPkg.$responseName](api.$methodName[zio.Task[StreamingOutputResult[$modelPkg.$responseName]]])(request.buildAwsValue(), body).map(_.map($responseNameTerm.wrap))""", + q"""def $methodName(request: $requestName, body: zio.stream.ZStream[Any, AwsError, Byte]): IO[AwsError, StreamingOutputResult[$responseTypeRo, Byte]] = + asyncRequestInputOutputStream[$modelPkg.$requestName, $modelPkg.$responseName](api.$methodName[zio.Task[StreamingOutputResult[$modelPkg.$responseName, Byte]]])(request.buildAwsValue(), body).map(_.mapResponse($responseNameTerm.wrap))""", accessor = - q"""def $methodName(request: $requestName, body: zio.stream.ZStream[Any, AwsError, Byte]): ZIO[$serviceNameT, AwsError, StreamingOutputResult[$responseTypeRo]] = + q"""def $methodName(request: $requestName, body: zio.stream.ZStream[Any, AwsError, Byte]): ZIO[$serviceNameT, AwsError, StreamingOutputResult[$responseTypeRo, Byte]] = ZIO.accessM(_.get.$methodName(request, body))""" ))) } @@ -194,12 +194,12 @@ trait ServiceInterfaceGenerator { ZIO.succeed(ServiceMethods( ServiceMethod( interface = - q"""def $methodName(request: $requestName): IO[AwsError, StreamingOutputResult[$responseTypeRo]]""", + q"""def $methodName(request: $requestName): IO[AwsError, StreamingOutputResult[$responseTypeRo, Byte]]""", implementation = - q"""def $methodName(request: $requestName): IO[AwsError, StreamingOutputResult[$responseTypeRo]] = - asyncRequestOutputStream[$modelPkg.$requestName, $modelPkg.$responseName](api.$methodName[zio.Task[StreamingOutputResult[$modelPkg.$responseName]]])(request.buildAwsValue()).map(_.map($responseNameTerm.wrap))""", + q"""def $methodName(request: $requestName): IO[AwsError, StreamingOutputResult[$responseTypeRo, Byte]] = + asyncRequestOutputStream[$modelPkg.$requestName, $modelPkg.$responseName](api.$methodName[zio.Task[StreamingOutputResult[$modelPkg.$responseName, Byte]]])(request.buildAwsValue()).map(_.mapResponse($responseNameTerm.wrap))""", accessor = - q"""def $methodName(request: $requestName): ZIO[$serviceNameT, AwsError, StreamingOutputResult[$responseTypeRo]] = + q"""def $methodName(request: $requestName): ZIO[$serviceNameT, AwsError, StreamingOutputResult[$responseTypeRo, Byte]] = ZIO.accessM(_.get.$methodName(request))""" ))) } @@ -219,39 +219,178 @@ trait ServiceInterfaceGenerator { } private def generateRequestToResponse(opName: String, serviceNameT: Type.Name, methodName: Term.Name, requestName: Type.Name, responseName: Type.Name, responseNameTerm: Term.Name, responseTypeRo: Type.Select, modelPkg: Term.Ref, pagination: Option[PaginationDefinition]) = { - val raw = ServiceMethod( - interface = - q"""def $methodName(request: $requestName): IO[AwsError, $responseTypeRo]""", - implementation = - q"""def $methodName(request: $requestName): IO[AwsError, $responseTypeRo] = - asyncRequestResponse[$modelPkg.$requestName, $modelPkg.$responseName](api.$methodName)(request.buildAwsValue()).map($responseNameTerm.wrap)""", - accessor = - q"""def $methodName(request: $requestName): ZIO[$serviceNameT, AwsError, $responseTypeRo] = - ZIO.accessM(_.get.$methodName(request))""" - ) pagination match { - case Some(pagination) => + case Some(JavaSdkPaginationDefinition(paginationName, itemModel, _, wrappedItemType)) => for { paginatorPkg <- getPaginatorPkg - methodNameStream = Term.Name(methodName.value + "Stream") - wrappedItemType = pagination.wrappedTypeRo paginatorMethodName = Term.Name(methodName.value + "Paginator") publisherType = Type.Name(opName + "Publisher") - wrappedItem <- wrapSdkValue(pagination.model, Term.Name("item")) - awsItemType <- TypeMapping.toJavaType(pagination.model) + wrappedItem <- wrapSdkValue(itemModel, Term.Name("item")) + awsItemType <- TypeMapping.toJavaType(itemModel) streamedPaginator = ServiceMethod( interface = - q"""def $methodNameStream(request: $requestName): zio.stream.ZStream[Any, AwsError, $wrappedItemType]""", + q"""def $methodName(request: $requestName): zio.stream.ZStream[Any, AwsError, $wrappedItemType]""", implementation = - q"""def $methodNameStream(request: $requestName): zio.stream.ZStream[Any, AwsError, $wrappedItemType] = - asyncPaginatedRequest[$modelPkg.$requestName, $awsItemType, $paginatorPkg.$publisherType](api.$paginatorMethodName, _.${Term.Name(pagination.name.uncapitalize)}())(request.buildAwsValue()).map(item => $wrappedItem)""", + q"""def $methodName(request: $requestName): zio.stream.ZStream[Any, AwsError, $wrappedItemType] = + asyncJavaPaginatedRequest[$modelPkg.$requestName, $awsItemType, $paginatorPkg.$publisherType](api.$paginatorMethodName, _.${Term.Name(paginationName.uncapitalize)}())(request.buildAwsValue()).map(item => $wrappedItem)""", accessor = - q"""def $methodNameStream(request: $requestName): zio.stream.ZStream[$serviceNameT, AwsError, $wrappedItemType] = - ZStream.accessStream(_.get.$methodNameStream(request))""" + q"""def $methodName(request: $requestName): zio.stream.ZStream[$serviceNameT, AwsError, $wrappedItemType] = + ZStream.accessStream(_.get.$methodName(request))""" ) - } yield ServiceMethods(raw, streamedPaginator) + } yield ServiceMethods(streamedPaginator) + case Some(ListPaginationDefinition(memberName, listModel, itemModel, isSimple)) => + for { + wrappedItem <- wrapSdkValue(itemModel, Term.Name("item")) + itemTypeRo <- TypeMapping.toWrappedTypeReadOnly(itemModel) + awsItemType <- TypeMapping.toJavaType(itemModel) + responseModel <- context.get(responseName.value) + property <- propertyName(responseModel, listModel, memberName) + propertyNameTerm = Term.Name(property) + streamedPaginator = if (isSimple) { + ServiceMethod( + interface = + q"""def $methodName(request: $requestName): ZStream[Any, AwsError, $itemTypeRo]""", + implementation = + q"""def $methodName(request: $requestName): ZStream[Any, AwsError, $itemTypeRo] = + asyncSimplePaginatedRequest[$modelPkg.$requestName, $modelPkg.$responseName, $awsItemType]( + api.$methodName, + (r, token) => r.toBuilder().nextToken(token).build(), + r => Option(r.nextToken()), + r => Chunk.fromIterable(r.$propertyNameTerm().asScala) + )(request.buildAwsValue()).map(item => $wrappedItem) + """, + accessor = + q"""def $methodName(request: $requestName): ZStream[$serviceNameT, AwsError, $itemTypeRo] = + ZStream.accessStream(_.get.$methodName(request)) + """ + ) + } else { + ServiceMethod( + interface = + q"""def $methodName(request: $requestName): ZIO[Any, AwsError, StreamingOutputResult[$responseTypeRo, $itemTypeRo]]""", + implementation = + q"""def $methodName(request: $requestName): ZIO[Any, AwsError, StreamingOutputResult[$responseTypeRo, $itemTypeRo]] = + asyncPaginatedRequest[$modelPkg.$requestName, $modelPkg.$responseName, $awsItemType]( + api.$methodName, + (r, token) => r.toBuilder().nextToken(token).build(), + r => Option(r.nextToken()), + r => Chunk.fromIterable(r.$propertyNameTerm().asScala) + )(request.buildAwsValue()).map(result => result.mapResponse($responseNameTerm.wrap).mapOutput(_.map(item => $wrappedItem))) + """, + accessor = + q"""def $methodName(request: $requestName): ZIO[$serviceNameT, AwsError, StreamingOutputResult[$responseTypeRo, $itemTypeRo]] = + ZIO.accessM(_.get.$methodName(request)) + """ + ) + } + } yield ServiceMethods(streamedPaginator) + case Some(MapPaginationDefinition(memberName, mapModel, keyModel, valueModel, isSimple)) => + for { + wrappedKey <- wrapSdkValue(keyModel, Term.Name("key")) + keyTypeRo <- TypeMapping.toWrappedTypeReadOnly(keyModel) + awsKeyType <- TypeMapping.toJavaType(keyModel) + wrappedValue <- wrapSdkValue(valueModel, Term.Name("value")) + valueTypeRo <- TypeMapping.toWrappedTypeReadOnly(valueModel) + awsValueType <- TypeMapping.toJavaType(valueModel) + + responseModel <- context.get(responseName.value) + property <- propertyName(responseModel, mapModel, memberName) + propertyNameTerm = Term.Name(property) + streamedPaginator = if (isSimple) { + ServiceMethod( + interface = + q"""def $methodName(request: $requestName): ZStream[Any, AwsError, ($keyTypeRo, $valueTypeRo)]""", + implementation = + q"""def $methodName(request: $requestName): ZStream[Any, AwsError, ($keyTypeRo, $valueTypeRo)] = + asyncSimplePaginatedRequest[$modelPkg.$requestName, $modelPkg.$responseName, ($awsKeyType, $awsValueType)]( + api.$methodName, + (r, token) => r.toBuilder().nextToken(token).build(), + r => Option(r.nextToken()), + r => Chunk.fromIterable(r.$propertyNameTerm().asScala) + )(request.buildAwsValue()).map { case (key, value) => $wrappedKey -> $wrappedValue } + """, + accessor = + q"""def $methodName(request: $requestName): ZStream[$serviceNameT, AwsError, ($keyTypeRo, $valueTypeRo)] = + ZStream.accessStream(_.get.$methodName(request)) + """ + ) + } else { + ServiceMethod( + interface = + q"""def $methodName(request: $requestName): ZIO[Any, AwsError, StreamingOutputResult[$responseTypeRo, ($keyTypeRo, $valueTypeRo)]]""", + implementation = + q"""def $methodName(request: $requestName): ZIO[Any, AwsError, StreamingOutputResult[$responseTypeRo, ($keyTypeRo, $valueTypeRo)]] = + asyncPaginatedRequest[$modelPkg.$requestName, $modelPkg.$responseName, ($awsKeyType, $awsValueType)]( + api.$methodName, + (r, token) => r.toBuilder().nextToken(token).build(), + r => Option(r.nextToken()), + r => Chunk.fromIterable(r.$propertyNameTerm().asScala) + )(request.buildAwsValue()).map(result => result.mapResponse($responseNameTerm.wrap).mapOutput(_.map { case (key, value) => $wrappedKey -> $wrappedValue })) + """, + accessor = + q"""def $methodName(request: $requestName): ZIO[$serviceNameT, AwsError, StreamingOutputResult[$responseTypeRo, ($keyTypeRo, $valueTypeRo)]] = + ZIO.accessM(_.get.$methodName(request)) + """ + ) + } + } yield ServiceMethods(streamedPaginator) + case Some(StringPaginationDefinition(memberName, stringModel, isSimple)) => + for { + wrappedItem <- wrapSdkValue(stringModel, Term.Name("item")) + itemTypeRo <- TypeMapping.toWrappedTypeReadOnly(stringModel) + awsItemType <- TypeMapping.toJavaType(stringModel) + responseModel <- context.get(responseName.value) + property <- propertyName(responseModel, stringModel, memberName) + propertyNameTerm = Term.Name(property) + streamedPaginator = if (isSimple) { + ServiceMethod( + interface = + q"""def $methodName(request: $requestName): ZStream[Any, AwsError, $itemTypeRo]""", + implementation = + q"""def $methodName(request: $requestName): ZStream[Any, AwsError, $itemTypeRo] = + asyncSimplePaginatedRequest[$modelPkg.$requestName, $modelPkg.$responseName, $awsItemType]( + api.$methodName, + (r, token) => r.toBuilder().nextToken(token).build(), + r => Option(r.nextToken()), + r => Chunk(r.$propertyNameTerm()) + )(request.buildAwsValue()).map(item => $wrappedItem) + """, + accessor = + q"""def $methodName(request: $requestName): ZStream[$serviceNameT, AwsError, $itemTypeRo] = + ZStream.accessStream(_.get.$methodName(request)) + """ + ) + } else { + ServiceMethod( + interface = + q"""def $methodName(request: $requestName): ZIO[Any, AwsError, StreamingOutputResult[$responseTypeRo, $itemTypeRo]]""", + implementation = + q"""def $methodName(request: $requestName): ZIO[Any, AwsError, StreamingOutputResult[$responseTypeRo, $itemTypeRo]] = + asyncPaginatedRequest[$modelPkg.$requestName, $modelPkg.$responseName, $awsItemType]( + api.$methodName, + (r, token) => r.toBuilder().nextToken(token).build(), + r => Option(r.nextToken()), + r => Chunk(r.$propertyNameTerm()) + )(request.buildAwsValue()).map(result => result.mapResponse($responseNameTerm.wrap).mapOutput(_.map(item => $wrappedItem))) + """, + accessor = + q"""def $methodName(request: $requestName): ZIO[$serviceNameT, AwsError, StreamingOutputResult[$responseTypeRo, $itemTypeRo]] = + ZIO.accessM(_.get.$methodName(request)) + """ + ) + } + } yield ServiceMethods(streamedPaginator) case None => - ZIO.succeed(ServiceMethods(raw)) + ZIO.succeed(ServiceMethods(ServiceMethod( + interface = + q"""def $methodName(request: $requestName): IO[AwsError, $responseTypeRo]""", + implementation = + q"""def $methodName(request: $requestName): IO[AwsError, $responseTypeRo] = + asyncRequestResponse[$modelPkg.$requestName, $modelPkg.$responseName](api.$methodName)(request.buildAwsValue()).map($responseNameTerm.wrap)""", + accessor = + q"""def $methodName(request: $requestName): ZIO[$serviceNameT, AwsError, $responseTypeRo] = + ZIO.accessM(_.get.$methodName(request))""" + ))) } } @@ -315,13 +454,13 @@ trait ServiceInterfaceGenerator { serviceMethods <- generateServiceMethods() ops <- awsModel.getOperations - paginations <- ZIO.foreach(ops) { case (opName, op) => + javaSdkPaginations <- ZIO.foreach(ops) { case (opName, op) => OperationCollector.get(opName, op).map { - case RequestResponse(Some(_)) => true + case RequestResponse(Some(JavaSdkPaginationDefinition(_, _, _, _))) => true case _ => false } } - hasPaginators = paginations.exists(identity) + usesJavaSdkPaginators = javaSdkPaginations.exists(identity) serviceMethodIfaces = serviceMethods.flatMap(_.methods.map(_.interface)) serviceMethodImpls = serviceMethods.flatMap(_.methods.map(_.implementation)) @@ -330,7 +469,7 @@ trait ServiceInterfaceGenerator { imports = List( Some(q"""import io.github.vigoo.zioaws.core._"""), Some(q"""import io.github.vigoo.zioaws.core.config.AwsConfig"""), - if (hasPaginators) + if (usesJavaSdkPaginators) Some(Import(List(Importer(paginatorPackage, List(Importee.Wildcard()))))) else None, Some(id.subModuleName match { diff --git a/zio-aws-codegen/src/main/scala/io/github/vigoo/zioaws/codegen/generator/ServiceModelGenerator.scala b/zio-aws-codegen/src/main/scala/io/github/vigoo/zioaws/codegen/generator/ServiceModelGenerator.scala index 1f4b6b7d..28fbb1d4 100644 --- a/zio-aws-codegen/src/main/scala/io/github/vigoo/zioaws/codegen/generator/ServiceModelGenerator.scala +++ b/zio-aws-codegen/src/main/scala/io/github/vigoo/zioaws/codegen/generator/ServiceModelGenerator.scala @@ -57,30 +57,6 @@ trait ServiceModelGenerator { } } - private def propertyName(model: Model, fieldModel: Model, name: String): ZIO[GeneratorContext, Nothing, String] = { - getNamingStrategy.flatMap { namingStrategy => - getModels.map { models => - val shapeModifiers = Option(models.customizationConfig().getShapeModifiers).map(_.asScala).getOrElse(Map.empty[String, ShapeModifier]) - shapeModifiers.get(model.shapeName).flatMap { shapeModifier => - val modifies = Option(shapeModifier.getModify).map(_.asScala).getOrElse(List.empty) - val matchingModifiers = modifies.flatMap { modifiesMap => - modifiesMap.asScala.map { case (key, value) => (key.toLowerCase, value) }.get(name.toLowerCase) - }.toList - - matchingModifiers - .map(modifier => Option(modifier.getEmitPropertyName)) - .find(_.isDefined) - .flatten.map(_.uncapitalize) - }.getOrElse { - val getterMethod = namingStrategy.getFluentGetterMethodName(name, model.shape, fieldModel.shape) - getterMethod - .stripSuffix("AsString") - .stripSuffix("AsStrings") - } - } - } - } - private def applyEnumModifiers(model: Model, enumValueList: List[String]): ZIO[GeneratorContext, Nothing, List[String]] = { getModels.map { models => val shapeModifiers = Option(models.customizationConfig().getShapeModifiers).map(_.asScala).getOrElse(Map.empty[String, ShapeModifier]) diff --git a/zio-aws-core/src/main/scala/io/github/vigoo/zioaws/core/AwsServiceBase.scala b/zio-aws-core/src/main/scala/io/github/vigoo/zioaws/core/AwsServiceBase.scala index 8d6e456a..d05995ab 100644 --- a/zio-aws-core/src/main/scala/io/github/vigoo/zioaws/core/AwsServiceBase.scala +++ b/zio-aws-core/src/main/scala/io/github/vigoo/zioaws/core/AwsServiceBase.scala @@ -13,16 +13,85 @@ import zio.stream.ZStream.TerminationStrategy import scala.reflect.ClassTag trait AwsServiceBase { - final def asyncRequestResponse[Request, Response](impl: Request => CompletableFuture[Response])(request: Request): IO[AwsError, Response] = + final protected def asyncRequestResponse[Request, Response](impl: Request => CompletableFuture[Response])(request: Request): IO[AwsError, Response] = ZIO.fromCompletionStage(impl(request)).mapError(AwsError.fromThrowable) - final def asyncPaginatedRequest[Request, Item, Response](impl: Request => Response, selector: Response => Publisher[Item])(request: Request): ZStream[Any, AwsError, Item] = + final protected def asyncJavaPaginatedRequest[Request, Item, Response](impl: Request => Response, selector: Response => Publisher[Item])(request: Request): ZStream[Any, AwsError, Item] = ZStream.unwrap { ZIO(selector(impl(request)).toStream().mapError(AwsError.fromThrowable)).mapError(AwsError.fromThrowable) } - final def asyncRequestOutputStream[Request, Response](impl: (Request, AsyncResponseTransformer[Response, Task[StreamingOutputResult[Response]]]) => CompletableFuture[Task[StreamingOutputResult[Response]]]) - (request: Request): IO[AwsError, StreamingOutputResult[Response]] = { + final protected def asyncSimplePaginatedRequest[Request, Response, Item](impl: Request => CompletableFuture[Response], + setNextToken: (Request, String) => Request, + getNextToken: Response => Option[String], + getItems: Response => Chunk[Item]) + (request: Request): ZStream[Any, AwsError, Item] = + ZStream.unwrap { + ZIO.fromCompletionStage(impl(request)).mapError(AwsError.fromThrowable).flatMap { response => + getNextToken(response) match { + case Some(nextToken) => + val stream = ZStream { + for { + nextTokenRef <- Ref.make[Option[String]](Some(nextToken)).toManaged_ + pull = for { + token <- nextTokenRef.get + chunk <- token match { + case Some(t) => + for { + nextRequest <- ZIO.effect(setNextToken(request, t)).mapError(t => Some(GenericAwsError(t))) + rsp <- ZIO.fromCompletionStage(impl(nextRequest)).mapError(t => Some(GenericAwsError(t))) + _ <- nextTokenRef.set(getNextToken(rsp)) + } yield getItems(rsp) + case None => + IO.fail(None) + } + } yield chunk + } yield pull + } + ZIO.succeed(stream) + case None => + // No pagination + ZIO.succeed(ZStream.fromChunk(getItems(response))) + } + } + } + + final protected def asyncPaginatedRequest[Request, Response, Item](impl: Request => CompletableFuture[Response], + setNextToken: (Request, String) => Request, + getNextToken: Response => Option[String], + getItems: Response => Chunk[Item]) + (request: Request): IO[AwsError, StreamingOutputResult[Response, Item]] = { + ZIO.fromCompletionStage(impl(request)).mapError(AwsError.fromThrowable).flatMap { response => + getNextToken(response) match { + case Some(nextToken) => + val stream = ZStream { + for { + nextTokenRef <- Ref.make[Option[String]](Some(nextToken)).toManaged_ + pull = for { + token <- nextTokenRef.get + chunk <- token match { + case Some(t) => + for { + nextRequest <- ZIO.effect(setNextToken(request, t)).mapError(t => Some(GenericAwsError(t))) + rsp <- ZIO.fromCompletionStage(impl(nextRequest)).mapError(t => Some(GenericAwsError(t))) + _ <- nextTokenRef.set(getNextToken(rsp)) + } yield getItems(rsp) + case None => + IO.fail(None) + } + } yield chunk + } yield pull + } + ZIO.succeed(StreamingOutputResult(response, stream)) + case None => + // No pagination + ZIO.succeed(StreamingOutputResult(response, ZStream.fromChunk(getItems(response)))) + } + } + } + + final protected def asyncRequestOutputStream[Request, Response](impl: (Request, AsyncResponseTransformer[Response, Task[StreamingOutputResult[Response, Byte]]]) => CompletableFuture[Task[StreamingOutputResult[Response, Byte]]]) + (request: Request): IO[AwsError, StreamingOutputResult[Response, Byte]] = { for { transformer <- ZStreamAsyncResponseTransformer[Response]() streamingOutputResultTask <- ZIO.fromCompletionStage(impl(request, transformer)).mapError(AwsError.fromThrowable) @@ -30,14 +99,14 @@ trait AwsServiceBase { } yield streamingOutputResult } - final def asyncRequestInputStream[Request, Response](impl: (Request, AsyncRequestBody) => CompletableFuture[Response]) - (request: Request, body: ZStream[Any, AwsError, Byte]): IO[AwsError, Response] = + final protected def asyncRequestInputStream[Request, Response](impl: (Request, AsyncRequestBody) => CompletableFuture[Response]) + (request: Request, body: ZStream[Any, AwsError, Byte]): IO[AwsError, Response] = ZIO.runtime.flatMap { implicit runtime: Runtime[Any] => ZIO.fromCompletionStage(impl(request, new ZStreamAsyncRequestBody(body))).mapError(AwsError.fromThrowable) } - final def asyncRequestInputOutputStream[Request, Response](impl: (Request, AsyncRequestBody, AsyncResponseTransformer[Response, Task[StreamingOutputResult[Response]]]) => CompletableFuture[Task[StreamingOutputResult[Response]]]) - (request: Request, body: ZStream[Any, AwsError, Byte]): IO[AwsError, StreamingOutputResult[Response]] = { + final protected def asyncRequestInputOutputStream[Request, Response](impl: (Request, AsyncRequestBody, AsyncResponseTransformer[Response, Task[StreamingOutputResult[Response, Byte]]]) => CompletableFuture[Task[StreamingOutputResult[Response, Byte]]]) + (request: Request, body: ZStream[Any, AwsError, Byte]): IO[AwsError, StreamingOutputResult[Response, Byte]] = { ZIO.runtime.flatMap { implicit runtime: Runtime[Any] => for { transformer <- ZStreamAsyncResponseTransformer[Response]() @@ -47,7 +116,7 @@ trait AwsServiceBase { } } - final def asyncRequestEventOutputStream[ + final protected def asyncRequestEventOutputStream[ Request, Response, ResponseHandler <: EventStreamResponseHandler[Response, EventI], @@ -88,14 +157,14 @@ trait AwsServiceBase { } } - final def asyncRequestEventInputStream[Request, Response, Event](impl: (Request, Publisher[Event]) => CompletableFuture[Response]) - (request: Request, input: ZStream[Any, AwsError, Event]): IO[AwsError, Response] = + final protected def asyncRequestEventInputStream[Request, Response, Event](impl: (Request, Publisher[Event]) => CompletableFuture[Response]) + (request: Request, input: ZStream[Any, AwsError, Event]): IO[AwsError, Response] = for { publisher <- input.mapError(_.toThrowable).toPublisher response <- ZIO.fromCompletionStage(impl(request, publisher)).mapError(AwsError.fromThrowable) } yield response - final def asyncRequestEventInputOutputStream[ + final protected def asyncRequestEventInputOutputStream[ Request, Response, InEvent, diff --git a/zio-aws-core/src/main/scala/io/github/vigoo/zioaws/core/StreamingOutputResult.scala b/zio-aws-core/src/main/scala/io/github/vigoo/zioaws/core/StreamingOutputResult.scala index 582fca8a..bbb46de0 100644 --- a/zio-aws-core/src/main/scala/io/github/vigoo/zioaws/core/StreamingOutputResult.scala +++ b/zio-aws-core/src/main/scala/io/github/vigoo/zioaws/core/StreamingOutputResult.scala @@ -2,8 +2,11 @@ package io.github.vigoo.zioaws.core import zio.stream.ZStream -case class StreamingOutputResult[Response](response: Response, - output: ZStream[Any, AwsError, Byte]) { - def map[R](f: Response => R): StreamingOutputResult[R] = +case class StreamingOutputResult[Response, Item](response: Response, + output: ZStream[Any, AwsError, Item]) { + def mapResponse[R](f: Response => R): StreamingOutputResult[R, Item] = copy(response = f(response)) + + def mapOutput[I](f: ZStream[Any, AwsError, Item] => ZStream[Any, AwsError, I]): StreamingOutputResult[Response, I] = + copy(output = f(output)) } diff --git a/zio-aws-core/src/main/scala/io/github/vigoo/zioaws/core/ZStreamAsyncResponseTransformer.scala b/zio-aws-core/src/main/scala/io/github/vigoo/zioaws/core/ZStreamAsyncResponseTransformer.scala index e4fcf7f9..f034348f 100644 --- a/zio-aws-core/src/main/scala/io/github/vigoo/zioaws/core/ZStreamAsyncResponseTransformer.scala +++ b/zio-aws-core/src/main/scala/io/github/vigoo/zioaws/core/ZStreamAsyncResponseTransformer.scala @@ -12,9 +12,9 @@ class ZStreamAsyncResponseTransformer[Response](resultStreamPromise: Promise[Thr responsePromise: Promise[Throwable, Response], errorPromise: Promise[Throwable, Unit]) (implicit runtime: Runtime[Any]) - extends AsyncResponseTransformer[Response, Task[StreamingOutputResult[Response]]] { + extends AsyncResponseTransformer[Response, Task[StreamingOutputResult[Response, Byte]]] { - override def prepare(): CompletableFuture[Task[StreamingOutputResult[Response]]] = + override def prepare(): CompletableFuture[Task[StreamingOutputResult[Response, Byte]]] = CompletableFuture.completedFuture { for { response <- responsePromise.await diff --git a/zio-aws-core/src/test/scala/io/github/vigoo/zioaws/core/AwsServiceBaseSpec.scala b/zio-aws-core/src/test/scala/io/github/vigoo/zioaws/core/AwsServiceBaseSpec.scala index 9cbbbb39..e7db2c78 100644 --- a/zio-aws-core/src/test/scala/io/github/vigoo/zioaws/core/AwsServiceBaseSpec.scala +++ b/zio-aws-core/src/test/scala/io/github/vigoo/zioaws/core/AwsServiceBaseSpec.scala @@ -49,33 +49,33 @@ object AwsServiceBaseSpec extends DefaultRunnableSpec with AwsServiceBase { assertM(call.run)(fails(equalTo(GenericAwsError(SimulatedException)))) } ), - suite("asyncPaginatedRequest")( + suite("asyncJavaPaginatedRequest")( testM("success")( assertM( - runAsyncPaginatedRequest(SimulatedPublisher.correctSequence) + runAsyncJavaPaginatedRequest(SimulatedPublisher.correctSequence) )(equalTo(Chunk('h', 'e', 'l', 'l', 'o'))) ), testM("fail before subscribe")( - assertM(runAsyncPaginatedRequest( + assertM(runAsyncJavaPaginatedRequest( in => SimulatedPublisher.Error(SimulatedException) :: SimulatedPublisher.correctSequence(in) ).run)(isAwsFailure)), testM("fail during emit")( - assertM(runAsyncPaginatedRequest( + assertM(runAsyncJavaPaginatedRequest( in => SimulatedPublisher.correctSequence(in).splitAt(3) match { case (a, b) => a ::: List(SimulatedPublisher.Error(SimulatedException)) ::: b } ).run)(isAwsFailure)), testM("fail before complete")( - assertM(runAsyncPaginatedRequest( + assertM(runAsyncJavaPaginatedRequest( in => SimulatedPublisher.correctSequence(in).init ::: List(SimulatedPublisher.Error(SimulatedException), SimulatedPublisher.Complete) ).run)(isAwsFailure)), testM("fail with no complete after")( - assertM(runAsyncPaginatedRequest( + assertM(runAsyncJavaPaginatedRequest( in => SimulatedPublisher.correctSequence(in).init ::: List(SimulatedPublisher.Error(SimulatedException)) ).run)(isAwsFailure)), testM("complete before subscribe is empty result")( - assertM(runAsyncPaginatedRequest( + assertM(runAsyncJavaPaginatedRequest( in => SimulatedPublisher.Complete :: SimulatedPublisher.correctSequence(in) ).run)(equalTo(Exit.Success(Chunk.empty)))) ), @@ -88,8 +88,8 @@ object AwsServiceBaseSpec extends DefaultRunnableSpec with AwsServiceBase { case Exit.Success(value) => Some(Exit.Success(value)) case _ => None }, - hasField[Exit.Success[(StreamingOutputResult[Int], Vector[Byte])], Int]("1", _.value._1.response, equalTo(5)) && - hasField[Exit.Success[(StreamingOutputResult[Int], Vector[Byte])], Vector[Byte]]("2", _.value._2, equalTo("hello".getBytes(StandardCharsets.US_ASCII).toVector))))), + hasField[Exit.Success[(StreamingOutputResult[Int, Byte], Vector[Byte])], Int]("1", _.value._1.response, equalTo(5)) && + hasField[Exit.Success[(StreamingOutputResult[Int, Byte], Vector[Byte])], Vector[Byte]]("2", _.value._2, equalTo("hello".getBytes(StandardCharsets.US_ASCII).toVector))))), testM("future fails before prepare")( assertM(runAsyncRequestOutput(SimulatedAsyncResponseTransformer.FailureSpec(failBeforePrepare = Some(SimulatedException))))( isAwsFailure @@ -135,8 +135,8 @@ object AwsServiceBaseSpec extends DefaultRunnableSpec with AwsServiceBase { suite("asyncRequestInputOutputStream")( testM("success")( assertM(runAsyncRequestInputOutputRequest())( - hasField[(StreamingOutputResult[Int], Vector[Byte]), Int]("1", _._1.response, equalTo(5)) && - hasField[(StreamingOutputResult[Int], Vector[Byte]), Vector[Byte]]("2", _._2, equalTo("hheelllloo".getBytes(StandardCharsets.US_ASCII).toVector)))), + hasField[(StreamingOutputResult[Int, Byte], Vector[Byte]), Int]("1", _._1.response, equalTo(5)) && + hasField[(StreamingOutputResult[Int, Byte], Vector[Byte]), Vector[Byte]]("2", _._2, equalTo("hheelllloo".getBytes(StandardCharsets.US_ASCII).toVector)))), testM("failure on input stream")( assertM(runAsyncRequestInputOutputRequest(failOnInput = true).run)( isAwsFailure @@ -361,10 +361,10 @@ object AwsServiceBaseSpec extends DefaultRunnableSpec with AwsServiceBase { private def runAsyncRequestInputOutputRequest(failureSpec: SimulatedAsyncResponseTransformer.FailureSpec = SimulatedAsyncResponseTransformer.FailureSpec(), failOnInput: Boolean = false, - failOnStream: Option[Throwable] = None): ZIO[Any, AwsError, (StreamingOutputResult[Int], Vector[Byte])] = { - val fakeAwsCall: (Int, AsyncRequestBody, AsyncResponseTransformer[Int, Task[StreamingOutputResult[Int]]]) => CompletableFuture[Task[StreamingOutputResult[Int]]] = + failOnStream: Option[Throwable] = None): ZIO[Any, AwsError, (StreamingOutputResult[Int, Byte], Vector[Byte])] = { + val fakeAwsCall: (Int, AsyncRequestBody, AsyncResponseTransformer[Int, Task[StreamingOutputResult[Int, Byte]]]) => CompletableFuture[Task[StreamingOutputResult[Int, Byte]]] = (multipler, asyncBody, transformer) => - SimulatedAsyncBodyReceiver.useAsyncBody[Task[StreamingOutputResult[Int]]]( + SimulatedAsyncBodyReceiver.useAsyncBody[Task[StreamingOutputResult[Int, Byte]]]( (in, cf, buffer) => { SimulatedAsyncResponseTransformer.useAsyncResponseTransformerImpl[Int, ArrayBuffer[Byte]]( buffer, @@ -395,18 +395,18 @@ object AwsServiceBaseSpec extends DefaultRunnableSpec with AwsServiceBase { } yield (result, streamResult) } - private def runAsyncPaginatedRequest(simulation: Chunk[Char] => List[SimulatedPublisher.Action]): ZIO[Any, AwsError, Chunk[Char]] = { + private def runAsyncJavaPaginatedRequest(simulation: Chunk[Char] => List[SimulatedPublisher.Action]): ZIO[Any, AwsError, Chunk[Char]] = { val fakeAwsCall: String => Publisher[Char] = { in => SimulatedPublisher.createCharPublisher(in, simulation) } - asyncPaginatedRequest[String, Char, Publisher[Char]](fakeAwsCall, identity)("hello") + asyncJavaPaginatedRequest[String, Char, Publisher[Char]](fakeAwsCall, identity)("hello") .runCollect } private def runAsyncRequestOutput(failureSpec: SimulatedAsyncResponseTransformer.FailureSpec = SimulatedAsyncResponseTransformer.FailureSpec(), - failOnStream: Option[Throwable] = None): ZIO[Any, AwsError, Exit[AwsError, (StreamingOutputResult[Int], Vector[Byte])]] = { - val fakeAwsCall = (in: String, transformer: AsyncResponseTransformer[Int, Task[StreamingOutputResult[Int]]]) => + failOnStream: Option[Throwable] = None): ZIO[Any, AwsError, Exit[AwsError, (StreamingOutputResult[Int, Byte], Vector[Byte])]] = { + val fakeAwsCall = (in: String, transformer: AsyncResponseTransformer[Int, Task[StreamingOutputResult[Int, Byte]]]) => SimulatedAsyncResponseTransformer.useAsyncResponseTransformer[String, Int]( in, transformer, diff --git a/zio-aws-core/src/test/scala/io/github/vigoo/zioaws/core/sim/SimulatedAsyncResponseTransformer.scala b/zio-aws-core/src/test/scala/io/github/vigoo/zioaws/core/sim/SimulatedAsyncResponseTransformer.scala index d6c5a802..b423e21e 100644 --- a/zio-aws-core/src/test/scala/io/github/vigoo/zioaws/core/sim/SimulatedAsyncResponseTransformer.scala +++ b/zio-aws-core/src/test/scala/io/github/vigoo/zioaws/core/sim/SimulatedAsyncResponseTransformer.scala @@ -17,12 +17,12 @@ object SimulatedAsyncResponseTransformer { def useAsyncResponseTransformer[In, Out](in: In, - transformer: AsyncResponseTransformer[Out, Task[StreamingOutputResult[Out]]], + transformer: AsyncResponseTransformer[Out, Task[StreamingOutputResult[Out, Byte]]], toResult: In => Out, toPublisher: In => SdkPublisher[ByteBuffer], failureSpec: FailureSpec) - (implicit threadPool: ExecutorService): CompletableFuture[Task[StreamingOutputResult[Out]]] = { - val cf = new CompletableFuture[Task[StreamingOutputResult[Out]]]() + (implicit threadPool: ExecutorService): CompletableFuture[Task[StreamingOutputResult[Out, Byte]]] = { + val cf = new CompletableFuture[Task[StreamingOutputResult[Out, Byte]]]() threadPool.submit(new Runnable { override def run(): Unit = { useAsyncResponseTransformerImpl(in, transformer, toResult, toPublisher, failureSpec, cf) @@ -32,17 +32,17 @@ object SimulatedAsyncResponseTransformer { } def useAsyncResponseTransformerImpl[Out, In](in: In, - transformer: AsyncResponseTransformer[Out, Task[StreamingOutputResult[Out]]], + transformer: AsyncResponseTransformer[Out, Task[StreamingOutputResult[Out, Byte]]], toResult: In => Out, toPublisher: In => SdkPublisher[ByteBuffer], failureSpec: FailureSpec, - cf: CompletableFuture[Task[StreamingOutputResult[Out]]]): Unit = { + cf: CompletableFuture[Task[StreamingOutputResult[Out, Byte]]]): Unit = { failureSpec.failBeforePrepare match { case Some(throwable) => cf.completeExceptionally(throwable) case None => slowDown() - transformer.prepare().thenApply[Boolean] { (result: Task[StreamingOutputResult[Out]]) => + transformer.prepare().thenApply[Boolean] { (result: Task[StreamingOutputResult[Out, Byte]]) => slowDown() transformer.onResponse(toResult(in)) From 61522c7471b1893e86bf7be12928d245301dcc3c Mon Sep 17 00:00:00 2001 From: Daniel Vigovszky Date: Mon, 24 Aug 2020 10:11:36 +0200 Subject: [PATCH 3/7] Pagination tests and fixes --- .../vigoo/zioaws/core/AwsServiceBase.scala | 4 +- .../zioaws/core/AwsServiceBaseSpec.scala | 69 ++++++++++++++++++- .../zioaws/core/sim/SimulatedPagination.scala | 39 +++++++++++ 3 files changed, 109 insertions(+), 3 deletions(-) create mode 100644 zio-aws-core/src/test/scala/io/github/vigoo/zioaws/core/sim/SimulatedPagination.scala diff --git a/zio-aws-core/src/main/scala/io/github/vigoo/zioaws/core/AwsServiceBase.scala b/zio-aws-core/src/main/scala/io/github/vigoo/zioaws/core/AwsServiceBase.scala index d05995ab..e5f05e88 100644 --- a/zio-aws-core/src/main/scala/io/github/vigoo/zioaws/core/AwsServiceBase.scala +++ b/zio-aws-core/src/main/scala/io/github/vigoo/zioaws/core/AwsServiceBase.scala @@ -48,7 +48,7 @@ trait AwsServiceBase { } yield chunk } yield pull } - ZIO.succeed(stream) + ZIO.succeed(ZStream.fromChunk(getItems(response)).concat(stream)) case None => // No pagination ZIO.succeed(ZStream.fromChunk(getItems(response))) @@ -82,7 +82,7 @@ trait AwsServiceBase { } yield chunk } yield pull } - ZIO.succeed(StreamingOutputResult(response, stream)) + ZIO.succeed(StreamingOutputResult(response, ZStream.fromChunk(getItems(response)).concat(stream))) case None => // No pagination ZIO.succeed(StreamingOutputResult(response, ZStream.fromChunk(getItems(response)))) diff --git a/zio-aws-core/src/test/scala/io/github/vigoo/zioaws/core/AwsServiceBaseSpec.scala b/zio-aws-core/src/test/scala/io/github/vigoo/zioaws/core/AwsServiceBaseSpec.scala index e7db2c78..5d123127 100644 --- a/zio-aws-core/src/test/scala/io/github/vigoo/zioaws/core/AwsServiceBaseSpec.scala +++ b/zio-aws-core/src/test/scala/io/github/vigoo/zioaws/core/AwsServiceBaseSpec.scala @@ -3,7 +3,8 @@ package io.github.vigoo.zioaws.core import java.nio.charset.StandardCharsets import java.util.concurrent.{CompletableFuture, ExecutorService, Executors} -import io.github.vigoo.zioaws.core.sim.{SimulatedAsyncBodyReceiver, SimulatedAsyncResponseTransformer, SimulatedEventStreamResponseHandlerReceiver, SimulatedPublisher} +import io.github.vigoo.zioaws.core.sim.SimulatedPagination.PaginatedRequest +import io.github.vigoo.zioaws.core.sim.{SimulatedAsyncBodyReceiver, SimulatedAsyncResponseTransformer, SimulatedEventStreamResponseHandlerReceiver, SimulatedPagination, SimulatedPublisher} import org.reactivestreams.{Publisher, Subscriber, Subscription} import software.amazon.awssdk.awscore.eventstream.EventStreamResponseHandler import software.amazon.awssdk.core.async.{AsyncRequestBody, AsyncResponseTransformer} @@ -80,6 +81,52 @@ object AwsServiceBaseSpec extends DefaultRunnableSpec with AwsServiceBase { ).run)(equalTo(Exit.Success(Chunk.empty)))) ), + suite("asyncSimplePaginatedRequest")( + testM("success")( + assertM( + runAsyncSimplePaginatedRequest("hello") + )(equalTo(Chunk('h', 'e', 'l', 'l', 'o'))) + ), + testM("success in single-page case")( + assertM( + runAsyncSimplePaginatedRequest("x") + )(equalTo(Chunk('x'))) + ), + testM("fail on first page")( + assertM( + runAsyncSimplePaginatedRequest("hello", failAfter = Some(0)).run + )(isAwsFailure) + ), + testM("fail on other page")( + assertM( + runAsyncSimplePaginatedRequest("hello", failAfter = Some(3)).run + )(isAwsFailure) + ) + ), + + suite("asyncPaginatedRequest")( + testM("success")( + assertM( + runAsyncPaginatedRequest("hello") + )(equalTo(Chunk('h', 'e', 'l', 'l', 'o'))) + ), + testM("success in single-page case")( + assertM( + runAsyncPaginatedRequest("x") + )(equalTo(Chunk('x'))) + ), + testM("fail on first page")( + assertM( + runAsyncPaginatedRequest("hello", failAfter = Some(0)).run + )(isAwsFailure) + ), + testM("fail on other page")( + assertM( + runAsyncPaginatedRequest("hello", failAfter = Some(3)).run + )(isAwsFailure) + ) + ), + suite("asyncRequestOutputStream")( testM("success")( assertM(runAsyncRequestOutput())( @@ -404,6 +451,26 @@ object AwsServiceBaseSpec extends DefaultRunnableSpec with AwsServiceBase { .runCollect } + private def runAsyncSimplePaginatedRequest(test: String, failAfter: Option[Int] = None): ZIO[Any, AwsError, Chunk[Char]] = { + asyncSimplePaginatedRequest[SimulatedPagination.PaginatedRequest, SimulatedPagination.PaginatedResult, Char]( + SimulatedPagination.simplePagination(failAfter, SimulatedException), + (req, token) => req.copy(token = Some(token)), + _.next, + rsp => Chunk.fromIterable(rsp.output), + )(PaginatedRequest(test, None)).runCollect + } + + private def runAsyncPaginatedRequest(test: String, failAfter: Option[Int] = None): ZIO[Any, AwsError, Chunk[Char]] = + for { + response <- asyncPaginatedRequest[SimulatedPagination.PaginatedRequest, SimulatedPagination.PaginatedResult, Char]( + SimulatedPagination.simplePagination(failAfter, SimulatedException), + (req, token) => req.copy(token = Some(token)), + _.next, + rsp => Chunk.fromIterable(rsp.output), + )(PaginatedRequest(test, None)) + streamResult <- response.output.runCollect + } yield streamResult + private def runAsyncRequestOutput(failureSpec: SimulatedAsyncResponseTransformer.FailureSpec = SimulatedAsyncResponseTransformer.FailureSpec(), failOnStream: Option[Throwable] = None): ZIO[Any, AwsError, Exit[AwsError, (StreamingOutputResult[Int, Byte], Vector[Byte])]] = { val fakeAwsCall = (in: String, transformer: AsyncResponseTransformer[Int, Task[StreamingOutputResult[Int, Byte]]]) => diff --git a/zio-aws-core/src/test/scala/io/github/vigoo/zioaws/core/sim/SimulatedPagination.scala b/zio-aws-core/src/test/scala/io/github/vigoo/zioaws/core/sim/SimulatedPagination.scala new file mode 100644 index 00000000..06fa90f9 --- /dev/null +++ b/zio-aws-core/src/test/scala/io/github/vigoo/zioaws/core/sim/SimulatedPagination.scala @@ -0,0 +1,39 @@ +package io.github.vigoo.zioaws.core.sim + +import java.util.concurrent.{CompletableFuture, ExecutorService} + +object SimulatedPagination { + + case class PaginatedRequest(input: String, token: Option[String]) + case class PaginatedResult(output: List[Char], next: Option[String]) + + def simplePagination(failAfter: Option[Int], failure: Throwable)(request: PaginatedRequest)(implicit threadPool: ExecutorService): CompletableFuture[PaginatedResult] = { + val cf = new CompletableFuture[PaginatedResult]() + + threadPool.submit(new Runnable { + override def run(): Unit = { + val startIdx = request.token match { + case Some(value) => + value.toInt + case None => + 0 + } + + if (startIdx >= failAfter.getOrElse(Int.MaxValue)) { + cf.completeExceptionally(failure) + } else { + val chunk = request.input.substring(startIdx, Math.min(request.input.length, startIdx + 2)).toList + val nextIndex = startIdx + 2 + val nextToken = if (nextIndex >= request.input.length) None else Some(nextIndex.toString) + + cf.complete(PaginatedResult( + chunk, + nextToken + )) + } + } + }) + + cf + } +} From a3ae5aacb72d2974c08dd94c74bf5a8089e9cf29 Mon Sep 17 00:00:00 2001 From: Daniel Vigovszky Date: Mon, 24 Aug 2020 11:16:34 +0200 Subject: [PATCH 4/7] Updated example --- README.md | 9 +++---- build.sbt | 4 +-- examples/build.sbt | 4 +-- examples/example1/src/main/scala/Main.scala | 29 ++++++++------------- 4 files changed, 19 insertions(+), 27 deletions(-) diff --git a/README.md b/README.md index 8d2a78c5..f9354f0b 100644 --- a/README.md +++ b/README.md @@ -168,10 +168,9 @@ object Main extends App { applicationName <- appDescription.applicationName _ <- console.putStrLn(s"Got application description for $applicationName") - envsResult <- elasticbeanstalk.describeEnvironments(DescribeEnvironmentsRequest(applicationName = Some(applicationName))) - envs <- envsResult.environments + envStream = elasticbeanstalk.describeEnvironments(DescribeEnvironmentsRequest(applicationName = Some(applicationName))) - _ <- ZIO.foreach(envs) { env => + _ <- envStream.run(Sink.foreach { env => env.environmentName.flatMap { environmentName => (for { environmentId <- env.environmentId @@ -184,7 +183,7 @@ object Main extends App { instanceIds <- ZIO.foreach(instances)(_.id) _ <- console.putStrLn(s"Instance IDs are ${instanceIds.mkString(", ")}") - reservationsStream <- ec2.describeInstancesStream(DescribeInstancesRequest(instanceIds = Some(instanceIds))) + reservationsStream = ec2.describeInstances(DescribeInstancesRequest(instanceIds = Some(instanceIds))) _ <- reservationsStream.run(Sink.foreach { reservation => reservation.instances.flatMap { instances => @@ -204,7 +203,7 @@ object Main extends App { console.putStrLnErr(s"Failed to get info for $environmentName: $error") } } - } + }) } yield () case None => ZIO.unit diff --git a/build.sbt b/build.sbt index 09614fb7..445abf20 100644 --- a/build.sbt +++ b/build.sbt @@ -10,8 +10,8 @@ val awsSubVersion = awsVersion.drop(awsVersion.indexOf('.') + 1) val http4sVersion = "0.21.7" val fs2Version = "2.2.2" -val majorVersion = "1" -val minorVersion = "1" +val majorVersion = "2" +val minorVersion = "0" val zioAwsVersion = s"$majorVersion.$awsSubVersion.$minorVersion" val generateAll = taskKey[Unit]("Generates all AWS client libraries") diff --git a/examples/build.sbt b/examples/build.sbt index a49300e7..86a004ce 100644 --- a/examples/build.sbt +++ b/examples/build.sbt @@ -1,9 +1,9 @@ val commonSettings = Seq( - scalaVersion := "2.13.2" + scalaVersion := "2.13.3" ) -val zioAwsVersion = "1.13.69.1" +val zioAwsVersion = "2.14.2.0" lazy val example1 = Project("example1", file("example1")).settings(commonSettings).settings( libraryDependencies ++= Seq( diff --git a/examples/example1/src/main/scala/Main.scala b/examples/example1/src/main/scala/Main.scala index 422befdb..da06dd85 100644 --- a/examples/example1/src/main/scala/Main.scala +++ b/examples/example1/src/main/scala/Main.scala @@ -1,18 +1,12 @@ -import scala.jdk.CollectionConverters._ -import zio._ -import zio.console -import zio.console._ -import zio.stream._ -import io.github.vigoo.zioaws.core -import io.github.vigoo.zioaws.http4s -import io.github.vigoo.zioaws.netty -import io.github.vigoo.zioaws.ec2 -import io.github.vigoo.zioaws.ec2.model._ +import io.github.vigoo.zioaws.core.AwsError import io.github.vigoo.zioaws.ec2.Ec2 -import io.github.vigoo.zioaws.elasticbeanstalk -import io.github.vigoo.zioaws.elasticbeanstalk.model._ +import io.github.vigoo.zioaws.ec2.model._ import io.github.vigoo.zioaws.elasticbeanstalk.ElasticBeanstalk -import io.github.vigoo.zioaws.core.AwsError +import io.github.vigoo.zioaws.elasticbeanstalk.model._ +import io.github.vigoo.zioaws.{core, ec2, elasticbeanstalk, http4s} +import zio.{console, _} +import zio.console._ +import zio.stream._ object Main extends App { val program: ZIO[Console with Ec2 with ElasticBeanstalk, AwsError, Unit] = @@ -25,10 +19,9 @@ object Main extends App { applicationName <- appDescription.applicationName _ <- console.putStrLn(s"Got application description for $applicationName") - envsResult <- elasticbeanstalk.describeEnvironments(DescribeEnvironmentsRequest(applicationName = Some(applicationName))) - envs <- envsResult.environments + envStream = elasticbeanstalk.describeEnvironments(DescribeEnvironmentsRequest(applicationName = Some(applicationName))) - _ <- ZIO.foreach(envs) { env => + _ <- envStream.run(Sink.foreach { env => env.environmentName.flatMap { environmentName => (for { environmentId <- env.environmentId @@ -41,7 +34,7 @@ object Main extends App { instanceIds <- ZIO.foreach(instances)(_.id) _ <- console.putStrLn(s"Instance IDs are ${instanceIds.mkString(", ")}") - reservationsStream <- ec2.describeInstancesStream(DescribeInstancesRequest(instanceIds = Some(instanceIds))) + reservationsStream = ec2.describeInstances(DescribeInstancesRequest(instanceIds = Some(instanceIds))) _ <- reservationsStream.run(Sink.foreach { reservation => reservation.instances.flatMap { instances => @@ -61,7 +54,7 @@ object Main extends App { console.putStrLnErr(s"Failed to get info for $environmentName: $error") } } - } + }) } yield () case None => ZIO.unit From f209d9bf3fb71c727a7aabc41d53292b8db234f4 Mon Sep 17 00:00:00 2001 From: Daniel Vigovszky Date: Mon, 24 Aug 2020 18:42:23 +0200 Subject: [PATCH 5/7] Added some integration tests using localstack --- integtests/build.sbt | 23 +++ integtests/src/test/resources/log4j2.xml | 18 +++ .../zioaws/integtests/DynamoDbTests.scala | 142 ++++++++++++++++++ 3 files changed, 183 insertions(+) create mode 100644 integtests/build.sbt create mode 100644 integtests/src/test/resources/log4j2.xml create mode 100644 integtests/src/test/scala/io/github/vigoo/zioaws/integtests/DynamoDbTests.scala diff --git a/integtests/build.sbt b/integtests/build.sbt new file mode 100644 index 00000000..910d686f --- /dev/null +++ b/integtests/build.sbt @@ -0,0 +1,23 @@ + +scalaVersion := "2.13.3" + +val zioAwsVersion = "2.14.2.1" + +libraryDependencies ++= Seq( + "io.github.vigoo" %% "zio-aws-core" % "2.14.2.2", + "io.github.vigoo" %% "zio-aws-http4s" % "2.14.2.2", + "io.github.vigoo" %% "zio-aws-netty" % zioAwsVersion, + "io.github.vigoo" %% "zio-aws-s3" % zioAwsVersion, + "io.github.vigoo" %% "zio-aws-dynamodb" % zioAwsVersion, + + "dev.zio" %% "zio" % "1.0.1", + "dev.zio" %% "zio-test" % "1.0.1", + "dev.zio" %% "zio-test-sbt" % "1.0.1", + + "org.apache.logging.log4j" % "log4j-1.2-api" % "2.13.3", + "org.apache.logging.log4j" % "log4j-core" % "2.13.3", + "org.apache.logging.log4j" % "log4j-api" % "2.13.3", + "org.apache.logging.log4j" % "log4j-slf4j-impl" % "2.13.3", +) + +testFrameworks += new TestFramework("zio.test.sbt.ZTestFramework") diff --git a/integtests/src/test/resources/log4j2.xml b/integtests/src/test/resources/log4j2.xml new file mode 100644 index 00000000..2d428121 --- /dev/null +++ b/integtests/src/test/resources/log4j2.xml @@ -0,0 +1,18 @@ + + + + + + + + + + + + + + + + + + diff --git a/integtests/src/test/scala/io/github/vigoo/zioaws/integtests/DynamoDbTests.scala b/integtests/src/test/scala/io/github/vigoo/zioaws/integtests/DynamoDbTests.scala new file mode 100644 index 00000000..d0ada8be --- /dev/null +++ b/integtests/src/test/scala/io/github/vigoo/zioaws/integtests/DynamoDbTests.scala @@ -0,0 +1,142 @@ +package io.github.vigoo.zioaws.integtests + +import java.net.URI + +import io.github.vigoo.zioaws.core._ +import io.github.vigoo.zioaws.{dynamodb, _} +import io.github.vigoo.zioaws.dynamodb.model._ +import software.amazon.awssdk.auth.credentials.{AwsBasicCredentials, StaticCredentialsProvider} +import zio._ +import zio.test.Assertion._ +import zio.test._ +import zio.test.environment.TestRandom + +object DynamoDbTests extends DefaultRunnableSpec { + + val nettyClient = netty.client() + val http4sClient = http4s.client() + val awsConfig = config.default + val dynamoDb = dynamodb.customized( + _.credentialsProvider(StaticCredentialsProvider.create(AwsBasicCredentials.create("dummy", "key"))) + .endpointOverride(new URI("http://localhost:4566")) + ) + + private def testTable = { + for { + _ <- TestRandom.setSeed(scala.util.Random.nextLong()) + tableName <- generateName + env <- ZIO.environment[dynamodb.DynamoDb] + } yield ZManaged.make( + for { + tableData <- dynamodb.createTable(CreateTableRequest( + tableName = tableName, + attributeDefinitions = List( + AttributeDefinition("key", ScalarAttributeType.S) + ), + keySchema = List( + KeySchemaElement("key", KeyType.HASH) + ), + provisionedThroughput = Some(ProvisionedThroughput( + readCapacityUnits = 16L, + writeCapacityUnits = 16L + )) + )) + tableDesc <- tableData.tableDescription + } yield tableDesc + )(tableDescription => + tableDescription.tableName.flatMap { tableName => + dynamodb.deleteTable(DeleteTableRequest(tableName)) + }.provide(env) + .catchAll(error => ZIO.die(error.toThrowable)) + .unit) + } + + def tests = Seq( + testM("can create and delete a table") { + // simple request/response calls + val steps = for { + table <- testTable + _ <- table.use { _ => + ZIO.unit + } + } yield () + + assertM(steps.run)(succeeds(isUnit)) + }, + testM("scan") { + // java paginator based streaming + + val N = 100 + val steps = for { + table <- testTable + result <- table.use { tableDescription => + val put = + for { + tableName <- tableDescription.tableName + randomKey <- random.nextString(10) + randomValue <- random.nextInt + _ <- dynamodb.putItem(PutItemRequest( + tableName = tableName, + item = Map( + "key" -> AttributeValue(s = Some(randomKey)), + "value" -> AttributeValue(n = Some(randomValue.toString)) + ) + )) + } yield () + + for { + tableName <- tableDescription.tableName + _ <- put.repeatN(N - 1) + stream = dynamodb.scan(ScanRequest( + tableName = tableName, + limit = Some(10) + )) + streamResult <- stream.runCollect + } yield streamResult + } + } yield result.length + + assertM(steps)(equalTo(N)) + }, + testM("listTagsOfResource") { + // simple paginated streaming + val N = 1000 + val steps = for { + table <- testTable + result <- table.use { tableDescription => + for { + arn <- tableDescription.tableArn + _ <- dynamodb.tagResource(TagResourceRequest( + resourceArn = arn, + tags = (0 until N).map(i => dynamodb.model.Tag(s"tag$i", i.toString)).toList + )) + + tagStream = dynamodb.listTagsOfResource(ListTagsOfResourceRequest( + resourceArn = arn + )) + tags <- tagStream.runCollect + } yield tags + } + } yield result.length + + assertM(steps)(equalTo(N)) + } + ) + + private def generateName = + ZIO.foreach((0 to 8).toList) { _ => + random.nextIntBetween('a'.toInt, 'z'.toInt).map(_.toChar) + }.map(_.mkString) + + + override def spec = { + suite("DynamoDB")( + suite("with Netty")( + tests: _* + ).provideCustomLayer((nettyClient >>> awsConfig >>> dynamoDb).mapError(TestFailure.die)), + suite("with http4s")( + tests: _* + ).provideCustomLayer((http4sClient >>> awsConfig >>> dynamoDb).mapError(TestFailure.die)), + ) + } +} \ No newline at end of file From 46514457415bd64cdeb28b4791acdc9e55e2e6f8 Mon Sep 17 00:00:00 2001 From: Daniel Vigovszky Date: Mon, 24 Aug 2020 18:42:33 +0200 Subject: [PATCH 6/7] Fix issue in ZStreamAsyncRequestBody --- .../io/github/vigoo/zioaws/core/ZStreamAsyncRequestBody.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/zio-aws-core/src/main/scala/io/github/vigoo/zioaws/core/ZStreamAsyncRequestBody.scala b/zio-aws-core/src/main/scala/io/github/vigoo/zioaws/core/ZStreamAsyncRequestBody.scala index 3f45e85c..366b533a 100644 --- a/zio-aws-core/src/main/scala/io/github/vigoo/zioaws/core/ZStreamAsyncRequestBody.scala +++ b/zio-aws-core/src/main/scala/io/github/vigoo/zioaws/core/ZStreamAsyncRequestBody.scala @@ -22,6 +22,7 @@ class ZStreamAsyncRequestBody(stream: ZStream[Any, AwsError, Byte])(implicit run .mapChunks(chunk => Chunk(ByteBuffer.wrap(chunk.toArray))) .run(sink) .catchAll(errorP.fail) + .forkDaemon } yield () } } From c34abb1ed44277e51c932a551914124195ba144e Mon Sep 17 00:00:00 2001 From: Daniel Vigovszky Date: Mon, 24 Aug 2020 18:52:52 +0200 Subject: [PATCH 7/7] Integration tests using S3 --- .../vigoo/zioaws/integtests/S3Tests.scala | 108 ++++++++++++++++++ 1 file changed, 108 insertions(+) create mode 100644 integtests/src/test/scala/io/github/vigoo/zioaws/integtests/S3Tests.scala diff --git a/integtests/src/test/scala/io/github/vigoo/zioaws/integtests/S3Tests.scala b/integtests/src/test/scala/io/github/vigoo/zioaws/integtests/S3Tests.scala new file mode 100644 index 00000000..fb65c51d --- /dev/null +++ b/integtests/src/test/scala/io/github/vigoo/zioaws/integtests/S3Tests.scala @@ -0,0 +1,108 @@ +package io.github.vigoo.zioaws.integtests + +import java.net.URI + +import io.github.vigoo.zioaws.core.config +import io.github.vigoo.zioaws.s3.model._ +import io.github.vigoo.zioaws.{http4s, netty, s3} +import software.amazon.awssdk.auth.credentials.{AwsBasicCredentials, StaticCredentialsProvider} +import zio.stream.ZStream +import zio.test.Assertion._ +import zio.test.environment.TestRandom +import zio.test._ +import zio.{ZIO, ZManaged, console, random} + +object S3Tests extends DefaultRunnableSpec { + val nettyClient = netty.client() + val http4sClient = http4s.client() + val awsConfig = config.default + val s3Client = s3.customized( + _.credentialsProvider(StaticCredentialsProvider.create(AwsBasicCredentials.create("dummy", "key"))) + .endpointOverride(new URI("http://localhost:4566")) + ) + + private def testBucket = { + for { + _ <- TestRandom.setSeed(scala.util.Random.nextLong()) + bucketName <- generateName + env <- ZIO.environment[s3.S3] + } yield ZManaged.make( + for { + _ <- s3.createBucket(CreateBucketRequest( + bucket = bucketName, + )) + } yield bucketName + )(bucketName => + s3.deleteBucket(DeleteBucketRequest(bucketName)) + .provide(env) + .catchAll(error => ZIO.die(error.toThrowable)) + .unit) + } + + def tests = Seq( + testM("can create and delete a bucket") { + // simple request/response calls + val steps = for { + bucket <- testBucket + _ <- bucket.use { bucketName => + ZIO.unit + } + } yield () + + assertM(steps.run)(succeeds(isUnit)) + }, + testM("can upload and download items as byte streams") { + // streaming input and streaming output calls + val steps = for { + testData <- random.nextBytes(4096) + bucket <- testBucket + key <- generateName + receivedData <- bucket.use { bucketName => + for { + _ <- console.putStrLn(s"Uploading $key to $bucketName") + _ <- s3.putObject(PutObjectRequest( + bucket = bucketName, + key = key, + ), ZStream + .fromIterable(testData) + .tap(_ => ZIO.succeed(console.putStr("."))) + .chunkN(1024)) + _ <- console.putStrLn("Downloading") + getResponse <- s3.getObject(GetObjectRequest( + bucket = bucketName, + key = key + )) + getStream = getResponse.output + result <- getStream.runCollect + + _ <- console.putStrLn("Deleting") + _ <- s3.deleteObject(DeleteObjectRequest( + bucket = bucketName, + key = key + )) + + } yield result + } + } yield testData == receivedData + + assertM(steps)(isTrue) + } + ) + + private def generateName = + ZIO.foreach((0 to 8).toList) { _ => + random.nextIntBetween('a'.toInt, 'z'.toInt).map(_.toChar) + }.map(_.mkString) + + + override def spec = { + suite("S3")( + suite("with Netty")( + tests: _* + ).provideCustomLayer((nettyClient >>> awsConfig >>> s3Client).mapError(TestFailure.die)), + suite("with http4s")( + tests: _* + ).provideCustomLayer((http4sClient >>> awsConfig >>> s3Client).mapError(TestFailure.die)), + ) + } +}