From a2049119320088801fe3d7f3bd0cc46b7d6ac6a3 Mon Sep 17 00:00:00 2001
From: Weichen Xu
Date: Wed, 10 Jul 2024 23:37:30 +0800
Subject: [PATCH] init

Signed-off-by: Weichen Xu
---
 .../spark/ml/feature/StringIndexer.scala      | 30 ++++++++++++--
 .../spark/ml/feature/StringIndexerSuite.scala | 42 +++++++++++++++++++-
 2 files changed, 67 insertions(+), 5 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 60dc4d0240716..9e33c38cdaf4f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -103,8 +103,8 @@ private[feature] trait StringIndexerBase extends Params with HasHandleInvalid wi
   private def validateAndTransformField(
       schema: StructType,
       inputColName: String,
+      inputDataType: DataType,
       outputColName: String): StructField = {
-    val inputDataType = schema(inputColName).dataType
     require(inputDataType == StringType || inputDataType.isInstanceOf[NumericType],
       s"The input column $inputColName must be either string type or numeric type, " +
         s"but got $inputDataType.")
@@ -122,11 +122,33 @@ private[feature] trait StringIndexerBase extends Params with HasHandleInvalid wi
     require(outputColNames.distinct.length == outputColNames.length,
       s"Output columns should not be duplicate.")
 
+    // Resolve a possibly nested column name such as "a.b.c" against the schema,
+    // returning the leaf field's data type if every path segment exists.
+    def extractInputDataType(inputColName: String): Option[DataType] = {
+      val inputSplits = inputColName.split("\\.")
+      var dtype: Option[DataType] = Some(schema)
+      var i = 0
+      while (i < inputSplits.length && dtype.isDefined) {
+        val s = inputSplits(i)
+        dtype = if (dtype.get.isInstanceOf[StructType]) {
+          val struct = dtype.get.asInstanceOf[StructType]
+          if (struct.fieldNames.contains(s)) {
+            Some(struct(s).dataType)
+          } else None
+        } else None
+        i += 1
+      }
+
+      dtype
+    }
+
     val outputFields = inputColNames.zip(outputColNames).flatMap {
       case (inputColName, outputColName) =>
-        schema.fieldNames.contains(inputColName) match {
-          case true => Some(validateAndTransformField(schema, inputColName, outputColName))
-          case false if skipNonExistsCol => None
+        extractInputDataType(inputColName) match {
+          case Some(dtype) => Some(
+            validateAndTransformField(schema, inputColName, dtype, outputColName)
+          )
+          case None if skipNonExistsCol => None
           case _ => throw new SparkException(s"Input column $inputColName does not exist.")
         }
     }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index 99f12eab7d690..0e64516d3aa49 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -21,7 +21,8 @@ import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.catalyst.parser.DataTypeParser
+import org.apache.spark.sql.functions.{col, struct}
 import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}
 
 class StringIndexerSuite extends MLTest with DefaultReadWriteTest {
@@ -113,6 +114,45 @@ class StringIndexerSuite extends MLTest with DefaultReadWriteTest {
     assert(outSchema("output2").dataType === DoubleType)
   }
 
+  test("StringIndexer.transformSchema nested col") {
+    val outputCols = Array("output", "output2", "output3", "output4", "output5")
+    val idxToStr = new StringIndexer().setInputCols(
+      Array("input1.a.f1", "input1.a.f2", "input2.b1", "input2.b2", "input3")
+    ).setOutputCols(outputCols)
+
+    val inSchema = DataTypeParser.parseTableSchema(
+      "input1 struct<a: struct<f1: string, f2: string>>, " +
+        "input2 struct<b1: string, b2: string>, input3 string"
+    )
+    val outSchema = idxToStr.transformSchema(inSchema)
+
+    for (outputCol <- outputCols) {
+      assert(outSchema(outputCol).dataType === DoubleType)
+    }
+  }
+
+  test("StringIndexer nested input cols") {
+    val data = Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c"))
+    val df = data.toDF("id", "label")
+      .select(col("id"), struct(col("label")).alias("c1"))
+    val indexer = new StringIndexer()
+      .setInputCol("c1.label")
+      .setOutputCol("labelIndex")
+    val indexerModel = indexer.fit(df)
+    MLTestingUtils.checkCopyAndUids(indexer, indexerModel)
+    // Indexed by descending frequency: a -> 0, b -> 2, c -> 1
+    val expected = Seq(
+      (0, 0.0),
+      (1, 2.0),
+      (2, 1.0),
+      (3, 0.0),
+      (4, 0.0),
+      (5, 1.0)
+    ).toDF("id", "labelIndex")
+    val actual = indexerModel.transform(df).select("id", "labelIndex")
+    assert(actual.collect().toSet === expected.collect().toSet)
+  }
+
   test("StringIndexerUnseen") {
     val data = Seq((0, "a"), (1, "b"), (4, "b"))
     val data2 = Seq((0, "a"), (1, "b"), (2, "c"), (3, "d"))
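
--
Usage sketch (illustrative, not part of the diff): with this change, StringIndexer
resolves dot-separated paths into struct columns during schema validation, so a
nested field can be indexed directly. A minimal example, assuming a SparkSession
`spark` and its implicits in scope; column and field names here are made up for
illustration:

  import org.apache.spark.ml.feature.StringIndexer
  import org.apache.spark.sql.functions.{col, struct}
  import spark.implicits._

  // Wrap the label in a struct column "c1", then index the nested field.
  val df = Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"))
    .toDF("id", "label")
    .select(col("id"), struct(col("label")).alias("c1"))

  val model = new StringIndexer()
    .setInputCol("c1.label") // nested path, resolved via extractInputDataType
    .setOutputCol("labelIndex")
    .fit(df)

  model.transform(df).select("id", "labelIndex").show()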