diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 63a3a67e90cba..48872a564202a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -28,6 +28,7 @@ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.sql.{Column, DataFrame, Dataset, Encoder, Encoders, Row} +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions.{If, Literal} import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.functions._ @@ -136,7 +137,7 @@ private[feature] trait StringIndexerBase extends Params with HasHandleInvalid wi } protected def extractInputDataType(schema: StructType, inputColName: String): Option[DataType] = { - val inputSplits = inputColName.split("\\.") + val inputSplits = UnresolvedAttribute.parseAttributeName(inputColName) var dtype: Option[DataType] = Some(schema) var i = 0 while (i < inputSplits.length && dtype.isDefined) { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index fc3d2d349ab06..8f3750959d2be 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -150,12 +150,12 @@ class StringIndexerSuite extends MLTest with DefaultReadWriteTest { (5, 1.0) ).toDF("id", "labelIndex") - testTransformerByGlobalCheckFunc[(Int, String)](df, indexerModel, "id", "labelIndex") { rows => - val attr = Attribute.fromStructField(rows.head.schema("labelIndex")) - .asInstanceOf[NominalAttribute] - assert(attr.values.get === Array("a", "c", "b")) - assert(rows === expected.collect().toSeq) - } + val dfOutput = indexerModel.transform(df) + val outputs = dfOutput.select("id", "labelIndex").collect().toSeq + val attr = Attribute.fromStructField(outputs.head.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attr.values.get === Array("a", "c", "b")) + assert(outputs === expected.collect().toSeq) } test("StringIndexerUnseen") {