Commit
update
Signed-off-by: Weichen Xu <[email protected]>
WeichenXu123 committed Aug 12, 2024
1 parent 556c92c commit c11f169
Showing 2 changed files with 39 additions and 6 deletions.
mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
@@ -70,7 +70,7 @@ class Interaction @Since("1.6.0") (@Since("1.6.0") override val uid: String) ext
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
-    val inputFeatures = $(inputCols).map(c => dataset.schema(c))
+    val inputFeatures = $(inputCols).map(c => SchemaUtils.getSchemaField(dataset.schema, c))
     val featureEncoders = getFeatureEncoders(inputFeatures.toImmutableArraySeq)
     val featureAttrs = getFeatureAttrs(inputFeatures.toImmutableArraySeq)

@@ -102,11 +102,11 @@ class Interaction @Since("1.6.0") (@Since("1.6.0") override val uid: String) ext
       Vectors.sparse(size, indices.result(), values.result()).compressed
     }
 
-    val featureCols = inputFeatures.map { f =>
+    val featureCols = inputFeatures.zip($(inputCols)).map { case (f, inputCol) =>
       f.dataType match {
-        case DoubleType => dataset(f.name)
-        case _: VectorUDT => dataset(f.name)
-        case _: NumericType | BooleanType => dataset(f.name).cast(DoubleType)
+        case DoubleType => dataset(inputCol)
+        case _: VectorUDT => dataset(inputCol)
+        case _: NumericType | BooleanType => dataset(inputCol).cast(DoubleType)
       }
     }
     import org.apache.spark.util.ArrayImplicits._
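Here `dataset.schema(c)` can only resolve top-level column names, so a dotted reference such as `nest.a` would fail, while `SchemaUtils.getSchemaField` resolves the path into nested structs. Purely as intuition, a rough sketch of that kind of lookup (illustrative only, not the actual `SchemaUtils` implementation, which may differ in error handling, quoting, and case sensitivity):

```scala
import org.apache.spark.sql.types.{StructField, StructType}

// Illustrative helper: walk a dot-separated path down nested StructTypes
// and return the leaf StructField, e.g. resolveNestedField(schema, "nest.a").
def resolveNestedField(schema: StructType, path: String): StructField = {
  val parts = path.split('.')
  val parent = parts.init.foldLeft(schema) { (st, name) =>
    st(name).dataType match {
      case nested: StructType => nested
      case other => throw new IllegalArgumentException(
        s"Field '$name' in path '$path' is not a struct (found $other)")
    }
  }
  parent(parts.last)
}
```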
mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.functions.{col, struct}
 
 class InteractionSuite extends MLTest with DefaultReadWriteTest {

@@ -169,4 +169,37 @@ class InteractionSuite extends MLTest with DefaultReadWriteTest {
       .setOutputCol("myOutputCol")
     testDefaultReadWrite(t)
   }
+
+  test("nested input columns") {
+    val data = Seq(
+      (2, Vectors.dense(3.0, 4.0), Vectors.dense(6.0, 8.0)),
+      (1, Vectors.dense(1.0, 5.0), Vectors.dense(1.0, 5.0))
+    ).toDF("a", "b", "expected")
+    val groupAttr = new AttributeGroup(
+      "b",
+      Array[Attribute](
+        NumericAttribute.defaultAttr.withName("foo"),
+        NumericAttribute.defaultAttr.withName("bar")))
+    val df = data.select(
+      col("a").as("a", NumericAttribute.defaultAttr.toMetadata()),
+      col("b").as("b", groupAttr.toMetadata()),
+      col("expected"))
+      .select(struct("a", "b").alias("nest"), col("expected"))
+    val trans = new Interaction().setInputCols(Array("nest.a", "nest.b")).setOutputCol("features")
+
+    trans.transform(df).select("features", "expected").collect().foreach {
+      case Row(features: Vector, expected: Vector) =>
+        assert(features === expected)
+    }
+
+    val res = trans.transform(df)
+    val attrs = AttributeGroup.fromStructField(res.schema("features"))
+    val expectedAttrs = new AttributeGroup(
+      "features",
+      Array[Attribute](
+        new NumericAttribute(Some("a:b_foo"), Some(1)),
+        new NumericAttribute(Some("a:b_bar"), Some(2))))
+    assert(attrs === expectedAttrs)
+  }
+
 }
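For context, a minimal usage sketch of what the new test exercises, written as a spark-shell snippet (it assumes a Spark build containing this change; the `spark` session and its implicits come from the shell, and names like `nest` are only illustrative):

```scala
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute}
import org.apache.spark.ml.feature.Interaction
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.functions.{col, struct}
import spark.implicits._

// A numeric column "a" and a vector column "b".
val raw = Seq(
  (2, Vectors.dense(3.0, 4.0)),
  (1, Vectors.dense(1.0, 5.0))
).toDF("a", "b")

// Vector inputs need ML attribute metadata so Interaction knows their width.
val groupAttr = new AttributeGroup(
  "b",
  Array[Attribute](
    NumericAttribute.defaultAttr.withName("foo"),
    NumericAttribute.defaultAttr.withName("bar")))

// Attach the metadata, then wrap both columns into a single struct column "nest".
val df = raw.select(
    col("a").as("a", NumericAttribute.defaultAttr.toMetadata()),
    col("b").as("b", groupAttr.toMetadata()))
  .select(struct("a", "b").alias("nest"))

// With this change, dotted references into the struct are accepted as input columns.
val interaction = new Interaction()
  .setInputCols(Array("nest.a", "nest.b"))
  .setOutputCol("features")

// Each element of "b" is multiplied by "a": (2, [3.0, 4.0]) -> [6.0, 8.0].
interaction.transform(df).show(truncate = false)
```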
