From c11f169dacaac64b511e987805b4621683e08dfb Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Mon, 12 Aug 2024 18:32:38 +0800 Subject: [PATCH] update Signed-off-by: Weichen Xu --- .../apache/spark/ml/feature/Interaction.scala | 10 +++--- .../spark/ml/feature/InteractionSuite.scala | 35 ++++++++++++++++++- 2 files changed, 39 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala index bd2d08c0d79ed..3311231e6d830 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala @@ -70,7 +70,7 @@ class Interaction @Since("1.6.0") (@Since("1.6.0") override val uid: String) ext @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) - val inputFeatures = $(inputCols).map(c => dataset.schema(c)) + val inputFeatures = $(inputCols).map(c => SchemaUtils.getSchemaField(dataset.schema, c)) val featureEncoders = getFeatureEncoders(inputFeatures.toImmutableArraySeq) val featureAttrs = getFeatureAttrs(inputFeatures.toImmutableArraySeq) @@ -102,11 +102,11 @@ class Interaction @Since("1.6.0") (@Since("1.6.0") override val uid: String) ext Vectors.sparse(size, indices.result(), values.result()).compressed } - val featureCols = inputFeatures.map { f => + val featureCols = inputFeatures.zip($(inputCols)).map { case (f, inputCol) => f.dataType match { - case DoubleType => dataset(f.name) - case _: VectorUDT => dataset(f.name) - case _: NumericType | BooleanType => dataset(f.name).cast(DoubleType) + case DoubleType => dataset(inputCol) + case _: VectorUDT => dataset(inputCol) + case _: NumericType | BooleanType => dataset(inputCol).cast(DoubleType) } } import org.apache.spark.util.ArrayImplicits._ diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala index 90038d8cc3797..d8d3128d521c7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.sql.Row -import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.{col, struct} class InteractionSuite extends MLTest with DefaultReadWriteTest { @@ -169,4 +169,37 @@ class InteractionSuite extends MLTest with DefaultReadWriteTest { .setOutputCol("myOutputCol") testDefaultReadWrite(t) } + + test("nested input columns") { /* Inputs are struct fields addressed by dotted path. */ + val data = Seq( /* in both rows, expected is vector b scaled by a */ + (2, Vectors.dense(3.0, 4.0), Vectors.dense(6.0, 8.0)), + (1, Vectors.dense(1.0, 5.0), Vectors.dense(1.0, 5.0)) + ).toDF("a", "b", "expected") + val groupAttr = new AttributeGroup( /* names the attributes of "b": foo, bar */ + "b", + Array[Attribute]( + NumericAttribute.defaultAttr.withName("foo"), + NumericAttribute.defaultAttr.withName("bar"))) + val df = data.select( /* tag a and b with ML metadata, then nest them in "nest" */ + col("a").as("a", NumericAttribute.defaultAttr.toMetadata()), + col("b").as("b", groupAttr.toMetadata()), + col("expected")) + .select(struct("a", "b").alias("nest"), col("expected")) + val trans = new Interaction().setInputCols(Array("nest.a", "nest.b")).setOutputCol("features") + + trans.transform(df).select("features", "expected").collect().foreach { + case Row(features: Vector, expected: Vector) => + assert(features === expected) /* output vector equals the expected column */ + } + + val res = trans.transform(df) /* transformed again to inspect the output schema */ + val attrs = AttributeGroup.fromStructField(res.schema("features")) + val expectedAttrs = new AttributeGroup( /* names use a and b, without the "nest." prefix */ + "features", + Array[Attribute]( + new NumericAttribute(Some("a:b_foo"), Some(1)), + new NumericAttribute(Some("a:b_bar"), Some(2)))) + assert(attrs === expectedAttrs) + } + }