Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
Signed-off-by: Weichen Xu <[email protected]>
  • Loading branch information
WeichenXu123 committed Jul 12, 2024
1 parent 835bb3e commit f91fedf
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.sql.{Column, DataFrame, Dataset, Encoder, Encoders, Row}
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{If, Literal}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions._
Expand Down Expand Up @@ -136,7 +137,7 @@ private[feature] trait StringIndexerBase extends Params with HasHandleInvalid wi
}

protected def extractInputDataType(schema: StructType, inputColName: String): Option[DataType] = {
val inputSplits = inputColName.split("\\.")
val inputSplits = UnresolvedAttribute.parseAttributeName(inputColName)
var dtype: Option[DataType] = Some(schema)
var i = 0
while (i < inputSplits.length && dtype.isDefined) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,12 +150,12 @@ class StringIndexerSuite extends MLTest with DefaultReadWriteTest {
(5, 1.0)
).toDF("id", "labelIndex")

testTransformerByGlobalCheckFunc[(Int, String)](df, indexerModel, "id", "labelIndex") { rows =>
val attr = Attribute.fromStructField(rows.head.schema("labelIndex"))
.asInstanceOf[NominalAttribute]
assert(attr.values.get === Array("a", "c", "b"))
assert(rows === expected.collect().toSeq)
}
val dfOutput = indexerModel.transform(df)
val outputs = dfOutput.select("id", "labelIndex").collect().toSeq
val attr = Attribute.fromStructField(outputs.head.schema("labelIndex"))
.asInstanceOf[NominalAttribute]
assert(attr.values.get === Array("a", "c", "b"))
assert(outputs === expected.collect().toSeq)
}

test("StringIndexerUnseen") {
Expand Down

0 comments on commit f91fedf

Please sign in to comment.