Skip to content

Commit

Permalink
[protobuf] Add map support (#988)
Browse files Browse the repository at this point in the history
  • Loading branch information
RustedBones authored Jul 23, 2024
1 parent 4918b13 commit 5cbdfe2
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 54 deletions.
58 changes: 29 additions & 29 deletions docs/mapping.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,38 +2,37 @@

| Scala | Avro | BigQuery | Bigtable<sup>7</sup> | Datastore | Parquet | Protobuf | TensorFlow |
|-----------------------------------|------------------------------|------------------------|---------------------------------|-----------------------|-----------------------------------|-------------------------|---------------------|
| `Unit` | `NULL` | x | x | `Null` | x | x | x |
| `Boolean` | `BOOLEAN` | `BOOL` | `Byte` | `Boolean` | `BOOLEAN` | `Boolean` | `INT64`<sup>3</sup> |
| `Char` | `INT`<sup>3</sup> | `INT64`<sup>3</sup2> | `Char` | `Integer`<sup>3</sup> | `INT32`<sup>3</sup> | `Int`<sup>3</sup> | `INT64`<sup>3</sup> |
| `Byte` | `INT`<sup>3</sup> | `INT64`<sup>3</sup2> | `Byte` | `Integer`<sup>3</sup> | `INT32`<sup>9</sup> | `Int`<sup>3</sup> | `INT64`<sup>3</sup> |
| `Short` | `INT`<sup>3</sup> | `INT64`<sup>3</sup2> | `Short` | `Integer`<sup>3</sup> | `INT32`<sup>9</sup> | `Int`<sup>3</sup> | `INT64`<sup>3</sup> |
| `Int` | `INT` | `INT64`<sup>3</sup2> | `Int` | `Integer`<sup>3</sup> | `INT32`<sup>9</sup> | `Int` | `INT64`<sup>3</sup> |
| `Long` | `LONG` | `INT64` | `Long` | `Integer` | `INT64`<sup>9</sup> | `Long` | `INT64` |
| `Float` | `FLOAT` | `FLOAT64`<sup>3</sup2> | `Float` | `Double`<sup>3</sup> | `FLOAT` | `Float` | `FLOAT` |
| `Double` | `DOUBLE` | `FLOAT64` | `Double` | `Double` | `DOUBLE` | `Double` | `FLOAT`<sup>3</sup> |
| `CharSequence` | `STRING` | x | x | x | x | x | x |
| `String` | `STRING` | `STRING` | `String` | `String` | `BINARY` | `String` | `BYTES`<sup>3</sup> |
| `Array[Byte]` | `BYTES` | `BYTES` | `ByteString` | `Blob` | `BINARY` | `ByteString` | `BYTES` |
| `Unit` | `null` | x | x | `Null` | x | x | x |
| `Boolean` | `boolean` | `BOOL` | `Byte` | `Boolean` | `BOOLEAN` | `Boolean` | `INT64`<sup>3</sup> |
| `Char` | `int`<sup>3</sup> | `INT64`<sup>3</sup2> | `Char` | `Integer`<sup>3</sup> | `INT32`<sup>3</sup> | `Int`<sup>3</sup> | `INT64`<sup>3</sup> |
| `Byte` | `int`<sup>3</sup> | `INT64`<sup>3</sup2> | `Byte` | `Integer`<sup>3</sup> | `INT32`<sup>9</sup> | `Int`<sup>3</sup> | `INT64`<sup>3</sup> |
| `Short` | `int`<sup>3</sup> | `INT64`<sup>3</sup2> | `Short` | `Integer`<sup>3</sup> | `INT32`<sup>9</sup> | `Int`<sup>3</sup> | `INT64`<sup>3</sup> |
| `Int` | `int` | `INT64`<sup>3</sup2> | `Int` | `Integer`<sup>3</sup> | `INT32`<sup>9</sup> | `Int` | `INT64`<sup>3</sup> |
| `Long` | `long` | `INT64` | `Long` | `Integer` | `INT64`<sup>9</sup> | `Long` | `INT64` |
| `Float` | `float` | `FLOAT64`<sup>3</sup2> | `Float` | `Double`<sup>3</sup> | `FLOAT` | `Float` | `FLOAT` |
| `Double` | `double` | `FLOAT64` | `Double` | `Double` | `DOUBLE` | `Double` | `FLOAT`<sup>3</sup> |
| `CharSequence` | `string` | x | x | x | x | x | x |
| `String` | `string` | `STRING` | `String` | `String` | `BINARY` | `String` | `BYTES`<sup>3</sup> |
| `Array[Byte]` | `bytes` | `BYTES` | `ByteString` | `Blob` | `BINARY` | `ByteString` | `BYTES` |
| `ByteString` | x | x | `ByteString` | `Blob` | x | `ByteString` | `BYTES` |
| `ByteBuffer` | `BYTES` | x | x | | x | x | x |
| Enum<sup>1</sup> | `ENUM` | `STRING`<sup>3</sup2> | `String` | `String`<sup>3</sup> | `BINARY`/`ENUM`<sup>9</sup> | Enum | `BYTES`<sup>3</sup> |
| `ByteBuffer` | `bytes` | x | x | | x | x | x |
| Enum<sup>1</sup> | `enum` | `STRING`<sup>3</sup2> | `String` | `String`<sup>3</sup> | `BINARY`/`ENUM`<sup>9</sup> | Enum | `BYTES`<sup>3</sup> |
| `BigInt` | x | x | `BigInt` | x | x | x | x |
| `BigDecimal` | `BYTES`<sup>4</sup> | `NUMERIC`<sup>6</sup2> | `Int` scale + unscaled `BigInt` | x | `LOGICAL[DECIMAL]`<sup>9,14</sup> | x | x |
| `Option[T]` | `UNION[NULL, T]`<sup>5</sup> | `NULLABLE` | Empty as `None` | Absent as `None` | `OPTIONAL` | `optional`<sup>10</sup> | Size <= 1 |
| `Iterable[T]`<sup>2</sup> | `ARRAY` | `REPEATED` | x | `Array` | `REPEATED`<sup>13</sup> | `repeated` | Size >= 0 |
| Nested | `RECORD` | `STRUCT` | Flat<sup>8</sup> | `Entity` | Group | `Message` | Flat<sup>8</sup> |
| `Map[CharSequence, T]` | `MAP[STRING, T]` | x | x | x | x | x | |
| `Map[String, T]` | `MAP[STRING, T]` | x | x | x | x | x | x |
| `java.time.Instant` | `LONG`<sup>11</sup> | `TIMESTAMP` | x | `Timestamp` | `LOGICAL[TIMESTAMP]`<sup>9</sup> | x | x |
| `java.time.LocalDateTime` | `LONG`<sup>11</sup> | `DATETIME` | x | x | `LOGICAL[TIMESTAMP]`<sup>9</sup> | x | x |
| `BigDecimal` | `bytes`<sup>4</sup> | `NUMERIC`<sup>6</sup2> | `Int` scale + unscaled `BigInt` | x | `LOGICAL[DECIMAL]`<sup>9,14</sup> | x | x |
| `Option[T]` | `union[null, T]`<sup>5</sup> | `NULLABLE` | Empty as `None` | Absent as `None` | `OPTIONAL` | `optional`<sup>10</sup> | Size <= 1 |
| `Iterable[T]`<sup>2</sup> | `array[T]` | `REPEATED` | x | `Array` | `REPEATED`<sup>13</sup> | `repeated` | Size >= 0 |
| Nested | `record` | `STRUCT` | Flat<sup>8</sup> | `Entity` | Group | `Message` | Flat<sup>8</sup> |
| `Map[K, V]` | `map[V]`<sup>15</sup> | x | x | x | x | `map<K, V>` | x |
| `java.time.Instant` | `long`<sup>11</sup> | `TIMESTAMP` | x | `Timestamp` | `LOGICAL[TIMESTAMP]`<sup>9</sup> | x | x |
| `java.time.LocalDateTime` | `long`<sup>11</sup> | `DATETIME` | x | x | `LOGICAL[TIMESTAMP]`<sup>9</sup> | x | x |
| `java.time.OffsetTime` | x | x | x | x | `LOGICAL[TIME]`<sup>9</sup> | x | x |
| `java.time.LocalTime` | `LONG`<sup>11</sup> | `TIME` | x | x | `LOGICAL[TIME]`<sup>9</sup> | x | x |
| `java.time.LocalDate` | `INT`<sup>11</sup> | `DATE` | x | x | `LOGICAL[DATE]`<sup>9</sup> | x | x |
| `org.joda.time.LocalDate` | `INT`<sup>11</sup> | x | x | x | x | x | x |
| `org.joda.time.DateTime` | `INT`<sup>11</sup> | x | x | x | x | x | x |
| `org.joda.time.LocalTime` | `INT`<sup>11</sup> | x | x | x | x | x | x |
| `java.util.UUID` | `STRING`<sup>4</sup> | x | ByteString (16 bytes) | x | `FIXED[16]` | x | x |
| `(Long, Long, Long)`<sup>12</sup> | `FIXED[12]` | x | x | x | x | x | x |
| `java.time.LocalTime` | `long`<sup>11</sup> | `TIME` | x | x | `LOGICAL[TIME]`<sup>9</sup> | x | x |
| `java.time.LocalDate` | `int`<sup>11</sup> | `DATE` | x | x | `LOGICAL[DATE]`<sup>9</sup> | x | x |
| `org.joda.time.LocalDate` | `int`<sup>11</sup> | x | x | x | x | x | x |
| `org.joda.time.DateTime` | `int`<sup>11</sup> | x | x | x | x | x | x |
| `org.joda.time.LocalTime` | `int`<sup>11</sup> | x | x | x | x | x | x |
| `java.util.UUID` | `string`<sup>4</sup> | x | ByteString (16 bytes) | x | `FIXED[16]` | x | x |
| `(Long, Long, Long)`<sup>12</sup> | `fixed[12]` | x | x | x | x | x | x |

1. Those wrapped in`UnsafeEnum` are encoded as strings,
see [enums.md](https://github.com/spotify/magnolify/blob/master/docs/enums.md) for more
Expand All @@ -59,3 +58,4 @@
format: `required group $FIELDNAME (LIST) { repeated $FIELDTYPE array ($FIELDSCHEMA); }`.
14. Parquet's Decimal logical format supports multiple representations, and are not implicitly scoped by default. Import
one of: `magnolify.parquet.ParquetField.{decimal32, decimal64, decimalFixed, decimalBinary}`.
15. Map key type in avro is fixed to string. Scala Map key type must be either `String` or `CharSequence`.
66 changes: 50 additions & 16 deletions protobuf/src/main/scala/magnolify/protobuf/ProtobufType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package magnolify.protobuf
import java.lang.reflect.Method
import java.util as ju
import com.google.protobuf.Descriptors.{Descriptor, EnumValueDescriptor, FieldDescriptor}
import com.google.protobuf.{ByteString, Message, ProtocolMessageEnum}
import com.google.protobuf.{ByteString, MapEntry, Message, ProtocolMessageEnum}
import magnolia1.*
import magnolify.shared.*
import magnolify.shims.FactoryCompat
Expand Down Expand Up @@ -54,17 +54,13 @@ object ProtobufType {
r.checkDefaults(descriptor)(cm)
}

@transient private var _newBuilder: Method = _
private def newBuilder: Message.Builder = {
if (_newBuilder == null) {
_newBuilder = ct.runtimeClass.getMethod("newBuilder")
}
@transient private lazy val _newBuilder: Method = ct.runtimeClass.getMethod("newBuilder")
private def newBuilder(): Message.Builder =
_newBuilder.invoke(null).asInstanceOf[Message.Builder]
}

private val caseMapper: CaseMapper = cm
override def from(v: MsgT): T = r.from(v)(caseMapper)
override def to(v: T): MsgT = r.to(v, newBuilder)(caseMapper).asInstanceOf[MsgT]
override def to(v: T): MsgT = r.to(v, newBuilder())(caseMapper).asInstanceOf[MsgT]
}
case _ =>
throw new IllegalArgumentException(s"ProtobufType can only be created from Record. Got $f")
Expand Down Expand Up @@ -130,6 +126,10 @@ object ProtobufField {
}
)

private def newFieldBuilder(b: Message.Builder)(f: FieldDescriptor): Message.Builder =
if (f.getType != FieldDescriptor.Type.MESSAGE) null
else b.newBuilderForField(f)

override def checkDefaults(descriptor: Descriptor)(cm: CaseMapper): Unit = {
val fields = getFields(descriptor)(cm)
caseClass.parameters.foreach { p =>
Expand Down Expand Up @@ -169,17 +169,11 @@ object ProtobufField {

override def to(v: T, bu: Message.Builder)(cm: CaseMapper): Message = {
val fields = getFields(bu.getDescriptorForType)(cm)

caseClass.parameters
.foldLeft(bu) { (b, p) =>
val field = fields(p.index)
val value = if (field.getType == FieldDescriptor.Type.MESSAGE) {
// nested records
p.typeclass.to(p.dereference(v), b.newBuilderForField(field))(cm)
} else {
// non-nested
p.typeclass.to(p.dereference(v), null)(cm)
}
val builder = newFieldBuilder(bu)(field)
val value = p.typeclass.to(p.dereference(v), builder)(cm)
if (value == null) b else b.setField(field, value)
}
.build()
Expand Down Expand Up @@ -284,4 +278,44 @@ object ProtobufField {
override def to(v: C[T], b: Message.Builder)(cm: CaseMapper): ju.List[f.ToT] =
if (v.isEmpty) null else v.iterator.map(f.to(_, b)(cm)).toList.asJava
}

implicit def pfMap[K, V](implicit
kf: ProtobufField[K],
vf: ProtobufField[V]
): ProtobufField[Map[K, V]] =
new Aux[Map[K, V], ju.List[MapEntry[kf.FromT, vf.FromT]], ju.List[MapEntry[kf.ToT, vf.ToT]]] {

override val default: Option[Map[K, V]] = Some(Map.empty)

override def from(v: ju.List[MapEntry[kf.FromT, vf.FromT]])(cm: CaseMapper): Map[K, V] = {
val b = Map.newBuilder[K, V]
if (v != null) {
b ++= v.asScala.map(me => kf.from(me.getKey)(cm) -> vf.from(me.getValue)(cm))
}
b.result()
}

private def newFieldBuilder(b: Message.Builder)(f: FieldDescriptor): Message.Builder =
if (f.getType != FieldDescriptor.Type.MESSAGE) null
else b.newBuilderForField(f)

override def to(v: Map[K, V], b: Message.Builder)(
cm: CaseMapper
): ju.List[MapEntry[kf.ToT, vf.ToT]] = {
if (v.isEmpty) {
null
} else {
val keyField = b.getDescriptorForType.findFieldByName("key")
val valueField = b.getDescriptorForType.findFieldByName("value")
v.map { case (k, v) =>
b
.setField(keyField, kf.to(k, newFieldBuilder(b)(keyField))(cm))
.setField(valueField, vf.to(v, newFieldBuilder(b)(valueField))(cm))
.build()
.asInstanceOf[MapEntry[kf.ToT, vf.ToT]]
}.toList
.asJava
}
}
}
}
9 changes: 7 additions & 2 deletions protobuf/src/test/protobuf/Proto2.proto
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,24 @@ message NestedP2 {
repeated RequiredP2 l = 6;
}

message CollectionP2 {
message CollectionsP2 {
repeated int32 a = 1;
repeated int32 l = 2;
repeated int32 v = 3;
repeated int32 s = 4;
}

message MoreCollectionP2 {
message MoreCollectionsP2 {
repeated int32 i = 1;
repeated int32 s = 2;
repeated int32 is = 3;
}

message MapsP2 {
map<string, int32> mp = 1;
map<string, NestedP2> mn = 2;
}

message EnumsP2 {
enum JavaEnums {
RED = 0;
Expand Down
Loading

0 comments on commit 5cbdfe2

Please sign in to comment.