From 9420122d0fea19d03e9899664b43df2cebccc15e Mon Sep 17 00:00:00 2001 From: Gabriel Ciuloaica Date: Tue, 22 Aug 2023 14:07:33 +0300 Subject: [PATCH] fixed enum support in avro codec (#578) * fixed enum support in avro codec * fixed tests * linted * fixed README * fixed Scala 3.x compilation * fixed enum support in avro codec * fixed tests * linted * fixed Scala 3.x compilation * fix: fixed README * fix: fixed README --------- Co-authored-by: Daniel Vigovszky --- README.md | 10 +- .../scala/zio/schema/codec/AvroCodec.scala | 114 +++++++++++++----- .../zio/schema/codec/AvroSchemaCodec.scala | 4 +- .../zio/schema/codec/AvroCodecSpec.scala | 14 ++- .../schema/codec/AvroSchemaCodecSpec.scala | 2 +- 5 files changed, 101 insertions(+), 43 deletions(-) diff --git a/README.md b/README.md index 243558c74..2e33a70fe 100644 --- a/README.md +++ b/README.md @@ -30,13 +30,13 @@ _ZIO Schema_ is used by a growing number of ZIO libraries, including _ZIO Flow_, In order to use this library, we need to add the following lines in our `build.sbt` file: ```scala -libraryDependencies += "dev.zio" %% "zio-schema" % "0.4.12" -libraryDependencies += "dev.zio" %% "zio-schema-bson" % "0.4.12" -libraryDependencies += "dev.zio" %% "zio-schema-json" % "0.4.12" -libraryDependencies += "dev.zio" %% "zio-schema-protobuf" % "0.4.12" +libraryDependencies += "dev.zio" %% "zio-schema" % "0.4.13" +libraryDependencies += "dev.zio" %% "zio-schema-bson" % "0.4.13" +libraryDependencies += "dev.zio" %% "zio-schema-json" % "0.4.13" +libraryDependencies += "dev.zio" %% "zio-schema-protobuf" % "0.4.13" // Required for automatic generic derivation of schemas -libraryDependencies += "dev.zio" %% "zio-schema-derivation" % "0.4.12", +libraryDependencies += "dev.zio" %% "zio-schema-derivation" % "0.4.13", libraryDependencies += "org.scala-lang" % "scala-reflect" % scalaVersion.value % "provided" ``` diff --git a/zio-schema-avro/shared/src/main/scala/zio/schema/codec/AvroCodec.scala 
b/zio-schema-avro/shared/src/main/scala/zio/schema/codec/AvroCodec.scala index 97a499e7d..6caec830c 100644 --- a/zio-schema-avro/shared/src/main/scala/zio/schema/codec/AvroCodec.scala +++ b/zio-schema-avro/shared/src/main/scala/zio/schema/codec/AvroCodec.scala @@ -209,16 +209,31 @@ object AvroCodec { private def decodeCaseClass1[A, Z](raw: Any, schema: Schema.CaseClass1[A, Z]) = decodeValue(raw, schema.field.schema).map(schema.defaultConstruct) - private def decodeEnum[Z](raw: Any, cases: Schema.Case[Z, _]*): Either[DecodeError, Any] = { - val generic = raw.asInstanceOf[GenericData.Record] - val enumCaseName = generic.getSchema.getFullName - val enumCaseValue = generic.get("value") + private def decodeEnum[Z](raw: Any, cases: Schema.Case[Z, _]*): Either[DecodeError, Any] = + raw match { + case enums: GenericData.EnumSymbol => + decodeGenericEnum(enums.toString, None, cases: _*) + case gr: GenericData.Record => + val enumCaseName = gr.getSchema.getFullName + if (gr.hasField("value")) { + val enumCaseValue = gr.get("value") + decodeGenericEnum[Z](enumCaseName, Some(enumCaseValue), cases: _*) + } else { + decodeGenericEnum[Z](enumCaseName, None, cases: _*) + } + case _ => Left(DecodeError.MalformedFieldWithPath(Chunk.single("Error"), s"Unknown enum: $raw")) + } + + private def decodeGenericEnum[Z]( + enumCaseName: String, + enumCaseValue: Option[AnyRef], + cases: Schema.Case[Z, _]* + ): Either[DecodeError, Any] = cases .find(_.id == enumCaseName) - .map(s => decodeValue(enumCaseValue, s.schema)) + .map(s => decodeValue(enumCaseValue.getOrElse(s), s.schema)) .toRight(DecodeError.MalformedFieldWithPath(Chunk.single("Error"), s"Unknown enum value: $enumCaseName")) .flatMap(identity) - } private def decodeRecord[A](value: A, schema: Schema.Record[_]) = { val record = value.asInstanceOf[GenericRecord] @@ -454,41 +469,41 @@ object AvroCodec { else decodeValue(value, schema).map(Some(_)) private def encodeValue[A](a: A, schema: Schema[A]): Any = schema match { - case 
Schema.Enum1(_, c1, _) => encodeEnum(a, c1) - case Schema.Enum2(_, c1, c2, _) => encodeEnum(a, c1, c2) - case Schema.Enum3(_, c1, c2, c3, _) => encodeEnum(a, c1, c2, c3) - case Schema.Enum4(_, c1, c2, c3, c4, _) => encodeEnum(a, c1, c2, c3, c4) - case Schema.Enum5(_, c1, c2, c3, c4, c5, _) => encodeEnum(a, c1, c2, c3, c4, c5) - case Schema.Enum6(_, c1, c2, c3, c4, c5, c6, _) => encodeEnum(a, c1, c2, c3, c4, c5, c6) + case Schema.Enum1(_, c1, _) => encodeEnum(schema, a, c1) + case Schema.Enum2(_, c1, c2, _) => encodeEnum(schema, a, c1, c2) + case Schema.Enum3(_, c1, c2, c3, _) => encodeEnum(schema, a, c1, c2, c3) + case Schema.Enum4(_, c1, c2, c3, c4, _) => encodeEnum(schema, a, c1, c2, c3, c4) + case Schema.Enum5(_, c1, c2, c3, c4, c5, _) => encodeEnum(schema, a, c1, c2, c3, c4, c5) + case Schema.Enum6(_, c1, c2, c3, c4, c5, c6, _) => encodeEnum(schema, a, c1, c2, c3, c4, c5, c6) case Schema.Enum7(_, c1, c2, c3, c4, c5, c6, c7, _) => - encodeEnum(a, c1, c2, c3, c4, c5, c6, c7) + encodeEnum(schema, a, c1, c2, c3, c4, c5, c6, c7) case Schema.Enum8(_, c1, c2, c3, c4, c5, c6, c7, c8, _) => - encodeEnum(a, c1, c2, c3, c4, c5, c6, c7, c8) + encodeEnum(schema, a, c1, c2, c3, c4, c5, c6, c7, c8) case Schema.Enum9(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, _) => - encodeEnum(a, c1, c2, c3, c4, c5, c6, c7, c8, c9) + encodeEnum(schema, a, c1, c2, c3, c4, c5, c6, c7, c8, c9) case Schema.Enum10(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, _) => - encodeEnum(a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10) + encodeEnum(schema, a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10) case Schema.Enum11(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, _) => - encodeEnum(a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11) + encodeEnum(schema, a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11) case Schema.Enum12(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, _) => - encodeEnum(a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12) + encodeEnum(schema, a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12) 
case Schema.Enum13(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, _) => - encodeEnum(a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13) + encodeEnum(schema, a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13) case Schema.Enum14(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, _) => - encodeEnum(a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14) + encodeEnum(schema, a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14) case Schema.Enum15(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, _) => - encodeEnum(a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15) + encodeEnum(schema, a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15) case Schema.Enum16(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, _) => - encodeEnum(a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16) + encodeEnum(schema, a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16) case Schema.Enum17(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, _) => - encodeEnum(a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17) + encodeEnum(schema, a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17) case Schema.Enum18(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, c18, _) => - encodeEnum(a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, c18) + encodeEnum(schema, a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, c18) case Schema.Enum19(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, c18, c19, _) => - encodeEnum(a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, c18, c19) + encodeEnum(schema, a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, c18, c19) case Schema .Enum20(_, c1, c2, 
c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, c18, c19, c20, _) => - encodeEnum(a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, c18, c19, c20) + encodeEnum(schema, a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, c18, c19, c20) case Schema.Enum21( _, c1, @@ -514,7 +529,31 @@ object AvroCodec { c21, _ ) => - encodeEnum(a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, c18, c19, c20, c21) + encodeEnum( + schema, + a, + c1, + c2, + c3, + c4, + c5, + c6, + c7, + c8, + c9, + c10, + c11, + c12, + c13, + c14, + c15, + c16, + c17, + c18, + c19, + c20, + c21 + ) case Schema.Enum22( _, c1, @@ -542,6 +581,7 @@ object AvroCodec { _ ) => encodeEnum( + schema, a, c1, c2, @@ -580,9 +620,10 @@ object AvroCodec { case Schema.Optional(schema, _) => encodeOption(schema, a) case Schema.Tuple2(left, right, _) => encodeTuple2(left.asInstanceOf[Schema[Any]], right.asInstanceOf[Schema[Any]], a) - case Schema.Either(left, right, _) => encodeEither(left, right, a) - case Schema.Lazy(schema0) => encodeValue(a, schema0()) - case Schema.CaseClass0(_, _, _) => encodePrimitive((), StandardType.UnitType) + case Schema.Either(left, right, _) => encodeEither(left, right, a) + case Schema.Lazy(schema0) => encodeValue(a, schema0()) + case Schema.CaseClass0(_, _, _) => + encodeCaseClass(schema, a, Seq.empty: _*) //encodePrimitive((), StandardType.UnitType) case Schema.CaseClass1(_, f, _, _) => encodeCaseClass(schema, a, f) case Schema.CaseClass2(_, f0, f1, _, _) => encodeCaseClass(schema, a, f0, f1) case Schema.CaseClass3(_, f0, f1, f2, _, _) => encodeCaseClass(schema, a, f0, f1, f2) @@ -926,11 +967,20 @@ object AvroCodec { record } - private def encodeEnum[Z](value: Z, cases: Schema.Case[Z, _]*): Any = { + private def encodeEnum[Z](schemaRaw: Schema[Z], value: Z, cases: Schema.Case[Z, _]*): Any = { + val schema = AvroSchemaCodec + .encodeToApacheAvro(schemaRaw) + .getOrElse(throw new 
Exception("Avro schema could not be generated for Enum.")) val fieldIndex = cases.indexWhere(c => c.deconstructOption(value).isDefined) if (fieldIndex >= 0) { val subtypeCase = cases(fieldIndex) - encodeValue(subtypeCase.deconstruct(value), subtypeCase.schema.asInstanceOf[Schema[Any]]) + if (schema.getType == SchemaAvro.Type.ENUM) { + GenericData.get.createEnum(schema.getEnumSymbols.get(fieldIndex), schema) + } else { + + encodeValue(subtypeCase.deconstruct(value), subtypeCase.schema.asInstanceOf[Schema[Any]]) + + } } else { throw new Exception("Could not find matching case for enum value.") } diff --git a/zio-schema-avro/shared/src/main/scala/zio/schema/codec/AvroSchemaCodec.scala b/zio-schema-avro/shared/src/main/scala/zio/schema/codec/AvroSchemaCodec.scala index 6630e6177..bb64aeac7 100644 --- a/zio-schema-avro/shared/src/main/scala/zio/schema/codec/AvroSchemaCodec.scala +++ b/zio-schema-avro/shared/src/main/scala/zio/schema/codec/AvroSchemaCodec.scala @@ -517,8 +517,8 @@ object AvroSchemaCodec extends AvroSchemaCodec { } def hasAvroEnumAnnotation(annotations: Chunk[Any]): Boolean = annotations.exists { - case AvroAnnotations.avroEnum => true - case _ => false + case AvroAnnotations.avroEnum() => true + case _ => false } def wrapAvro(schemaAvro: SchemaAvro, name: String, marker: AvroPropMarker): SchemaAvro = { diff --git a/zio-schema-avro/shared/src/test/scala-2/zio/schema/codec/AvroCodecSpec.scala b/zio-schema-avro/shared/src/test/scala-2/zio/schema/codec/AvroCodecSpec.scala index 5790916c8..d0796b28b 100644 --- a/zio-schema-avro/shared/src/test/scala-2/zio/schema/codec/AvroCodecSpec.scala +++ b/zio-schema-avro/shared/src/test/scala-2/zio/schema/codec/AvroCodecSpec.scala @@ -21,9 +21,9 @@ import java.time.{ import java.util.UUID import zio._ +import zio.schema.codec.AvroAnnotations.avroEnum import zio.schema.{ DeriveSchema, Schema } import zio.stream.ZStream -import zio.test.TestAspect.failing import zio.test._ object AvroCodecSpec extends ZIOSpecDefault { @@ 
-106,10 +106,12 @@ object AvroCodecSpec extends ZIOSpecDefault { case class BooleanValue(value: Boolean) extends OneOf + case object NullValue extends OneOf + implicit val schemaOneOf: Schema[OneOf] = DeriveSchema.gen[OneOf] } - sealed trait Enums + @avroEnum() sealed trait Enums object Enums { case object A extends Enums @@ -649,12 +651,18 @@ object AvroCodecSpec extends ZIOSpecDefault { val result = codec.decode(bytes) assertTrue(result == Right(OneOf.BooleanValue(true))) }, + test("Decode Enum3 - case object") { + val codec = AvroCodec.schemaBasedBinaryCodec[OneOf] + val bytes = codec.encode(OneOf.NullValue) + val result = codec.decode(bytes) + assertTrue(result == Right(OneOf.NullValue)) + }, test("Decode Enum5") { val codec = AvroCodec.schemaBasedBinaryCodec[Enums] val bytes = codec.encode(Enums.A) val result = codec.decode(bytes) assertTrue(result == Right(Enums.A)) - } @@ failing, // TODO: the case object from a sealed trait are not properly encoded and decoded. + }, test("Decode Person") { val codec = AvroCodec.schemaBasedBinaryCodec[Person] val bytes = codec.encode(Person("John", 42)) diff --git a/zio-schema-avro/shared/src/test/scala-2/zio/schema/codec/AvroSchemaCodecSpec.scala b/zio-schema-avro/shared/src/test/scala-2/zio/schema/codec/AvroSchemaCodecSpec.scala index 31060202f..25895f44d 100644 --- a/zio-schema-avro/shared/src/test/scala-2/zio/schema/codec/AvroSchemaCodecSpec.scala +++ b/zio-schema-avro/shared/src/test/scala-2/zio/schema/codec/AvroSchemaCodecSpec.scala @@ -60,7 +60,7 @@ object AvroSchemaCodecSpec extends ZIOSpecDefault { }, test("encodes sealed trait objects only as enum when avroEnum annotation is present") { - val schema = DeriveSchema.gen[SpecTestData.CaseObjectsOnlyAdt].annotate(AvroAnnotations.avroEnum) + val schema = DeriveSchema.gen[SpecTestData.CaseObjectsOnlyAdt].annotate(AvroAnnotations.avroEnum()) val result = AvroSchemaCodec.encode(schema) val expected = """{"type":"enum","name":"MyEnum","symbols":["A","B","MyC"]}"""