Skip to content

Commit

Permalink
#patch Adding support for SdkBindingData[scala.Option[_]] (#308)
Browse files Browse the repository at this point in the history
* Adding support for SdkBindingData[scala.Option[_]]

Signed-off-by: Jonathan Schuchart <[email protected]>

* Fixing handling of product element names in scala 2.12

Signed-off-by: Jonathan Schuchart <[email protected]>

---------

Signed-off-by: Jonathan Schuchart <[email protected]>
  • Loading branch information
jschuchart-spot authored Aug 2, 2024
1 parent 806d894 commit e647d77
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ class SdkScalaTypeTest {
datetime: SdkBindingData[Instant],
duration: SdkBindingData[Duration],
blob: SdkBindingData[Blob],
generic: SdkBindingData[ScalarNested]
generic: SdkBindingData[ScalarNested],
none: SdkBindingData[Option[String]],
some: SdkBindingData[Option[String]]
)

case class CollectionInput(
Expand All @@ -105,7 +107,8 @@ class SdkScalaTypeTest {
booleans: SdkBindingData[List[Boolean]],
datetimes: SdkBindingData[List[Instant]],
durations: SdkBindingData[List[Duration]],
generics: SdkBindingData[List[ScalarNested]]
generics: SdkBindingData[List[ScalarNested]],
options: SdkBindingData[List[Option[String]]]
)

case class MapInput(
Expand All @@ -115,7 +118,8 @@ class SdkScalaTypeTest {
booleanMap: SdkBindingData[Map[String, Boolean]],
datetimeMap: SdkBindingData[Map[String, Instant]],
durationMap: SdkBindingData[Map[String, Duration]],
genericMap: SdkBindingData[Map[String, ScalarNested]]
genericMap: SdkBindingData[Map[String, ScalarNested]],
optionMap: SdkBindingData[Map[String, Option[String]]]
)

case class ComplexInput(
Expand Down Expand Up @@ -196,7 +200,9 @@ class SdkScalaTypeTest {
.literalType(LiteralType.ofBlobType(BlobType.DEFAULT))
.description("")
.build(),
"generic" -> createVar(SimpleType.STRUCT)
"generic" -> createVar(SimpleType.STRUCT),
"none" -> createVar(SimpleType.STRUCT),
"some" -> createVar(SimpleType.STRUCT)
).asJava

val output = SdkScalaType[ScalarInput].getVariableMap
Expand Down Expand Up @@ -274,6 +280,16 @@ class SdkScalaTypeTest {
).asJava
)
)
),
"none" -> Literal.ofScalar(
Scalar.ofGeneric(
Struct.of(Map.empty[String, Struct.Value].asJava)
)
),
"some" -> Literal.ofScalar(
Scalar.ofGeneric(
Struct.of(Map("value" -> Struct.Value.ofStringValue("hello")).asJava)
)
)
).asJava

Expand All @@ -295,6 +311,14 @@ class SdkScalaTypeTest {
List(ScalarNestedNested("foo", Some("bar"))),
Map("foo" -> ScalarNestedNested("foo", Some("bar")))
)
),
none = SdkBindingDataFactory.of(
SdkLiteralTypes.generics[Option[String]](),
Option(null)
),
some = SdkBindingDataFactory.of(
SdkLiteralTypes.generics[Option[String]](),
Option("hello")
)
)

Expand Down Expand Up @@ -323,7 +347,11 @@ class SdkScalaTypeTest {
List(ScalarNestedNested("foo", Some("bar"))),
Map("foo" -> ScalarNestedNested("foo", Some("bar")))
)
)
),
none =
SdkBindingDataFactory.of(SdkLiteralTypes.generics(), Option(null)),
some =
SdkBindingDataFactory.of(SdkLiteralTypes.generics(), Option("hello"))
)

val expected = Map(
Expand Down Expand Up @@ -399,6 +427,23 @@ class SdkScalaTypeTest {
).asJava
)
)
),
"none" -> Literal.ofScalar(
Scalar.ofGeneric(
Struct.of(
Map(__TYPE -> Struct.Value.ofStringValue("scala.None$")).asJava
)
)
),
"some" -> Literal.ofScalar(
Scalar.ofGeneric(
Struct.of(
Map(
"value" -> Struct.Value.ofStringValue("hello"),
__TYPE -> Struct.Value.ofStringValue("scala.Some")
).asJava
)
)
)
).asJava

Expand All @@ -416,7 +461,8 @@ class SdkScalaTypeTest {
"booleans" -> createCollectionVar(SimpleType.BOOLEAN),
"datetimes" -> createCollectionVar(SimpleType.DATETIME),
"durations" -> createCollectionVar(SimpleType.DURATION),
"generics" -> createCollectionVar(SimpleType.STRUCT)
"generics" -> createCollectionVar(SimpleType.STRUCT),
"options" -> createCollectionVar(SimpleType.STRUCT)
).asJava

val output = SdkScalaType[CollectionInput].getVariableMap
Expand All @@ -443,6 +489,14 @@ class SdkScalaTypeTest {
List(ScalarNestedNested("foo", Some("bar"))),
Map("foo" -> ScalarNestedNested("foo", Some("bar")))
)
),
none = SdkBindingDataFactory.of(
SdkLiteralTypes.generics[Option[String]](),
Option(null)
),
some = SdkBindingDataFactory.of(
SdkLiteralTypes.generics[Option[String]](),
Option("hello")
)
)

Expand All @@ -465,6 +519,14 @@ class SdkScalaTypeTest {
List(ScalarNestedNested("foo", Some("bar"))),
Map("foo" -> ScalarNestedNested("foo", Some("bar")))
)
),
"none" -> SdkBindingDataFactory.of(
SdkLiteralTypes.generics[Option[String]](),
Option(null)
),
"some" -> SdkBindingDataFactory.of(
SdkLiteralTypes.generics[Option[String]](),
Option("hello")
)
).asJava

Expand Down Expand Up @@ -531,6 +593,10 @@ class SdkScalaTypeTest {
Map("foo2" -> ScalarNestedNested("foo2", Some("bar2")))
)
)
),
options = SdkBindingDataFactory.of(
SdkLiteralTypes.generics[Option[String]](),
List(Option("hello"), Option(null))
)
)

Expand All @@ -550,7 +616,8 @@ class SdkScalaTypeTest {
"booleanMap" -> createMapVar(SimpleType.BOOLEAN),
"datetimeMap" -> createMapVar(SimpleType.DATETIME),
"durationMap" -> createMapVar(SimpleType.DURATION),
"genericMap" -> createMapVar(SimpleType.STRUCT)
"genericMap" -> createMapVar(SimpleType.STRUCT),
"optionMap" -> createMapVar(SimpleType.STRUCT)
).asJava

val output = SdkScalaType[MapInput].getVariableMap
Expand Down Expand Up @@ -598,6 +665,10 @@ class SdkScalaTypeTest {
Map("foo2" -> ScalarNestedNested("foo2", Some("bar2")))
)
)
),
optionMap = SdkBindingDataFactory.of(
SdkLiteralTypes.generics[Option[String]](),
Map("none" -> Option(null), "some" -> Option("hello"))
)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import scala.reflect.api.{Mirror, TypeCreator, Universe}
import scala.reflect.runtime.universe
import scala.reflect.{ClassTag, classTag}
import scala.reflect.runtime.universe.{
ClassSymbol,
NoPrefix,
Symbol,
Type,
Expand Down Expand Up @@ -72,7 +73,7 @@ object SdkLiteralTypes {
blobs(BlobType.DEFAULT).asInstanceOf[SdkLiteralType[T]]
case t if t =:= typeOf[Binary] =>
binary().asInstanceOf[SdkLiteralType[T]]
case t if t <:< typeOf[Product] && !(t =:= typeOf[Option[_]]) =>
case t if t <:< typeOf[Product] =>
generics().asInstanceOf[SdkLiteralType[T]]

case t if t =:= typeOf[List[Long]] =>
Expand Down Expand Up @@ -391,24 +392,37 @@ object SdkLiteralTypes {
)
}

val clazz = typeOf[S].typeSymbol.asClass
val classMirror = mirror.reflectClass(clazz)
val constructor = typeOf[S].decl(termNames.CONSTRUCTOR).asMethod
val constructorMirror = classMirror.reflectConstructor(constructor)

val constructorArgs =
constructor.paramLists.flatten.map((param: Symbol) => {
val paramName = param.name.toString
val value = map.getOrElse(
paramName,
throw new IllegalArgumentException(
s"Map is missing required parameter named $paramName"
def instantiateViaConstructor(cls: ClassSymbol): S = {
val classMirror = mirror.reflectClass(cls)
val constructor = typeOf[S].decl(termNames.CONSTRUCTOR).asMethod
val constructorMirror = classMirror.reflectConstructor(constructor)

val constructorArgs =
constructor.paramLists.flatten.map((param: Symbol) => {
val paramName = param.name.toString
val value = map.getOrElse(
paramName,
throw new IllegalArgumentException(
s"Map is missing required parameter named $paramName"
)
)
)
valueToParamValue(value, param.typeSignature.dealias)
})
valueToParamValue(value, param.typeSignature.dealias)
})

constructorMirror(constructorArgs: _*).asInstanceOf[S]
}

val clazz = typeOf[S].typeSymbol.asClass
// special handling of scala.Option as it is a Product, but can't be instantiated like common
// case classes
if (clazz.name.toString == "Option")
map
.get("value")
.map(valueToParamValue(_, typeOf[S].typeArgs.head))
.asInstanceOf[S]
else
instantiateViaConstructor(clazz)

constructorMirror(constructorArgs: _*).asInstanceOf[S]
}

def structValueToAny(value: Struct.Value): Any = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -232,11 +232,8 @@ object SdkScalaType {
implicit def durationLiteralType: SdkScalaLiteralType[Duration] =
DelegateLiteralType(SdkLiteralTypes.durations())

// more specific matching to fail the usage of SdkBindingData[Option[_]]
implicit def optionLiteralType: SdkScalaLiteralType[Option[_]] = ???

// fixme: using Product is just an approximation for case class because Product
// is also super class of, for example, Option and Tuple
// is also super class of, for example, Either or Try
implicit def productLiteralType[T <: Product: TypeTag: ClassTag]
: SdkScalaLiteralType[T] =
DelegateLiteralType(SdkLiteralTypes.generics())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@ package object flytekitscala {
} catch {
case _: Throwable =>
// fall back to java's way, less reliable and with limitations
product.getClass.getDeclaredFields.map(_.getName).toList
val methodNames = product.getClass.getDeclaredMethods.map(_.getName)
product.getClass.getDeclaredFields
.map(_.getName)
.filter(methodNames.contains)
.toList
}
}
}

0 comments on commit e647d77

Please sign in to comment.