From 076df2efb1ba1286269b9769c6002598033bea71 Mon Sep 17 00:00:00 2001
From: Weichen Xu
Date: Mon, 15 Jul 2024 12:28:20 +0800
Subject: [PATCH] rename

Signed-off-by: Weichen Xu
---
 .../org/apache/spark/ml/feature/StringIndexer.scala  | 12 ++++++------
 .../apache/spark/ml/feature/StringIndexerSuite.scala |  6 +++---
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 9db3d44fb9abc..281f50a4773e0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -43,7 +43,7 @@ import org.apache.spark.util.collection.OpenHashMap
 private[feature] trait StringIndexerBase extends Params with HasHandleInvalid
   with HasInputCol with HasOutputCol with HasInputCols with HasOutputCols {
 
-  @transient private[ml] var _transformDataset: Dataset[_] = _
+  @transient private[ml] var transformDataset: Dataset[_] = _
 
   /**
    * Param for how to handle invalid data (unseen labels or NULL values).
@@ -127,7 +127,7 @@ private[feature] trait StringIndexerBase extends Params with HasHandleInvalid wi
     val outputFields = inputColNames.zip(outputColNames).flatMap {
       case (inputColName, outputColName) =>
         try {
-          val dtype = _transformDataset.col(inputColName).expr.dataType
+          val dtype = transformDataset.col(inputColName).expr.dataType
           Some(
             validateAndTransformField(schema, inputColName, dtype, outputColName)
           )
@@ -246,9 +246,9 @@ class StringIndexer @Since("1.4.0") (
 
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): StringIndexerModel = {
-    _transformDataset = dataset
+    transformDataset = dataset
     transformSchema(dataset.schema, logging = true)
-    _transformDataset = null
+    transformDataset = null
 
     // In case of equal frequency when frequencyDesc/Asc, the strings are further sorted
     // alphabetically.
@@ -425,9 +425,9 @@ class StringIndexerModel (
 
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
-    _transformDataset = dataset
+    transformDataset = dataset
     transformSchema(dataset.schema, logging = true)
-    _transformDataset = null
+    transformDataset = null
 
     val (inputColNames, outputColNames) = getInOutCols()
     val outputColumns = new Array[Column](outputColNames.length)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index 71ffe37e58305..47ff7e8a917a3 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -102,7 +102,7 @@ class StringIndexerSuite extends MLTest with DefaultReadWriteTest {
   test("StringIndexer.transformSchema") {
     val idxToStr = new StringIndexer().setInputCol("input").setOutputCol("output")
     val inSchema = StructType(Seq(StructField("input", StringType)))
-    idxToStr._transformDataset = spark.createDataFrame(List(Row("a")).asJava, schema = inSchema)
+    idxToStr.transformDataset = spark.createDataFrame(List(Row("a")).asJava, schema = inSchema)
     val outSchema = idxToStr.transformSchema(inSchema)
     assert(outSchema("output").dataType === DoubleType)
   }
@@ -112,7 +112,7 @@ class StringIndexerSuite extends MLTest with DefaultReadWriteTest {
       .setOutputCols(Array("output", "output2"))
     val inSchema = StructType(Seq(StructField("input", StringType),
       StructField("input2", StringType)))
-    idxToStr._transformDataset = spark.createDataFrame(
+    idxToStr.transformDataset = spark.createDataFrame(
       List(Row("a", "b")).asJava, schema = inSchema
     )
@@ -131,7 +131,7 @@ class StringIndexerSuite extends MLTest with DefaultReadWriteTest {
       "input1 struct<key1: struct<x: string, y: string>>, " +
        "input2 struct<x: string, y: string>, input3 string"
    )
-    idxToStr._transformDataset = spark.createDataFrame(
+    idxToStr.transformDataset = spark.createDataFrame(
      List(Row(Row(Row("a", "b")), Row("c", "d"), "e")).asJava, schema = inSchema
    )
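
For context, a minimal usage sketch (not part of the patch) showing how the renamed transient hook is exercised through the public API: per the diff, fit() and transform() stash the incoming Dataset in transformDataset so that transformSchema can resolve each input column's data type via Dataset.col(...), then clear the reference. The local session setup, object name, and column names below are illustrative assumptions, not code from the patch.

// Illustrative sketch only -- assumes spark-mllib on the classpath and
// the behavior of the renamed hook as shown in the diff above.
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.sql.SparkSession

object TransformDatasetSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    import spark.implicits._

    val df = Seq("a", "b", "c", "a").toDF("input")
    val indexer = new StringIndexer()
      .setInputCol("input")
      .setOutputCol("output")

    // Per the diff, fit() sets `transformDataset = dataset`, calls
    // transformSchema(dataset.schema, logging = true) -- which reads the
    // input column's dataType through transformDataset.col(...) -- and
    // then resets transformDataset to null before computing labels.
    val model = indexer.fit(df)
    model.transform(df).show()

    spark.stop()
  }
}

Note the field is declared @transient and nulled out immediately after transformSchema returns, so the estimator and model stay serializable and never hold a live Dataset reference beyond the schema-validation step.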