Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
Signed-off-by: Weichen Xu <[email protected]>
  • Loading branch information
WeichenXu123 committed Jul 10, 2024
1 parent fdbacdf commit a204911
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ private[feature] trait StringIndexerBase extends Params with HasHandleInvalid wi
private def validateAndTransformField(
schema: StructType,
inputColName: String,
inputDataType: DataType,
outputColName: String): StructField = {
val inputDataType = schema(inputColName).dataType
require(inputDataType == StringType || inputDataType.isInstanceOf[NumericType],
s"The input column $inputColName must be either string type or numeric type, " +
s"but got $inputDataType.")
Expand All @@ -122,11 +122,31 @@ private[feature] trait StringIndexerBase extends Params with HasHandleInvalid wi
require(outputColNames.distinct.length == outputColNames.length,
s"Output columns should not be duplicate.")

// Resolves the data type of a possibly nested input column name such as
// "a.b.c" by walking the dot-separated path through nested StructTypes.
// Returns None as soon as a path segment is missing, or when a non-struct
// type is reached before the path is exhausted.
def extractInputDataType(inputColName: String): Option[DataType] = {
  inputColName.split("\\.").foldLeft(Option[DataType](schema)) {
    case (Some(struct: StructType), fieldName) if struct.fieldNames.contains(fieldName) =>
      Some(struct(fieldName).dataType)
    case _ =>
      // Missing field or non-struct intermediate type: the path cannot be
      // resolved. None propagates through the remaining fold steps, matching
      // the short-circuit of the original while loop.
      None
  }
}

val outputFields = inputColNames.zip(outputColNames).flatMap {
case (inputColName, outputColName) =>
schema.fieldNames.contains(inputColName) match {
case true => Some(validateAndTransformField(schema, inputColName, outputColName))
case false if skipNonExistsCol => None
extractInputDataType(inputColName) match {
case Some(dtype) => Some(
validateAndTransformField(schema, inputColName, dtype, outputColName)
)
case None if skipNonExistsCol => None
case _ => throw new SparkException(s"Input column $inputColName does not exist.")
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.catalyst.parser.DataTypeParser
import org.apache.spark.sql.functions.{col, struct}
import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}

class StringIndexerSuite extends MLTest with DefaultReadWriteTest {
Expand Down Expand Up @@ -113,6 +114,43 @@ class StringIndexerSuite extends MLTest with DefaultReadWriteTest {
assert(outSchema("output2").dataType === DoubleType)
}

test("StringIndexer.transformSchema nested col") {
  val outputCols = Array("output", "output2", "output3", "output4", "output5")
  // Mix of nested (dot-separated) and top-level input column names.
  // Bug fix: "input.a.f2" did not exist in the schema below (typo for
  // "input1.a.f2"); transformSchema would throw
  // "Input column input.a.f2 does not exist." and the test could never pass.
  val idxToStr = new StringIndexer().setInputCols(
    Array("input1.a.f1", "input1.a.f2", "input2.b1", "input2.b2", "input3")
  ).setOutputCols(outputCols)

  val inSchema = DataTypeParser.parseTableSchema(
    "input1 struct<a struct<f1 string, f2 string>>, " +
    "input2 struct<b1 string, b2 string>, input3 string"
  )
  val outSchema = idxToStr.transformSchema(inSchema)

  // Every string input column, nested or not, must map to a double index column.
  for (outputCol <- outputCols) {
    assert(outSchema(outputCol).dataType === DoubleType)
  }
}

test("StringIndexer nested input cols") {
  val data = Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c"))
  // Wrap the label column in a struct so the indexer must resolve "c1.label".
  val df = data.toDF("id", "label")
    .select(col("id"), struct(col("label")).alias("c1"))
  val indexer = new StringIndexer()
    .setInputCol("c1.label")
    .setOutputCol("labelIndex")
  val indexerModel = indexer.fit(df)
  MLTestingUtils.checkCopyAndUids(indexer, indexerModel)
  // Default frequencyDesc ordering: a (3 occurrences) -> 0.0,
  // c (2 occurrences) -> 1.0, b (1 occurrence) -> 2.0.
  val expected = Seq(
    (0, 0.0),
    (1, 2.0),
    (2, 1.0),
    (3, 0.0),
    (4, 0.0),
    (5, 1.0)
  ).toDF("id", "labelIndex")
  // Bug fix: `expected` was built but never checked, so the test asserted
  // nothing beyond fit() not throwing. Compare the actual transform output
  // against it (as sets, since row order is not guaranteed).
  val actual = indexerModel.transform(df).select("id", "labelIndex")
  assert(actual.collect().toSet === expected.collect().toSet)
}

test("StringIndexerUnseen") {
val data = Seq((0, "a"), (1, "b"), (4, "b"))
val data2 = Seq((0, "a"), (1, "b"), (2, "c"), (3, "d"))
Expand Down

0 comments on commit a204911

Please sign in to comment.