Skip to content

Commit

Permalink
rename
Browse files Browse the repository at this point in the history
Signed-off-by: Weichen Xu <[email protected]>
  • Loading branch information
WeichenXu123 committed Jul 15, 2024
1 parent 7a02d02 commit 076df2e
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ import org.apache.spark.util.collection.OpenHashMap
private[feature] trait StringIndexerBase extends Params with HasHandleInvalid with HasInputCol
with HasOutputCol with HasInputCols with HasOutputCols {

@transient private[ml] var _transformDataset: Dataset[_] = _
@transient private[ml] var transformDataset: Dataset[_] = _

/**
* Param for how to handle invalid data (unseen labels or NULL values).
Expand Down Expand Up @@ -127,7 +127,7 @@ private[feature] trait StringIndexerBase extends Params with HasHandleInvalid wi
val outputFields = inputColNames.zip(outputColNames).flatMap {
case (inputColName, outputColName) =>
try {
val dtype = _transformDataset.col(inputColName).expr.dataType
val dtype = transformDataset.col(inputColName).expr.dataType
Some(
validateAndTransformField(schema, inputColName, dtype, outputColName)
)
Expand Down Expand Up @@ -246,9 +246,9 @@ class StringIndexer @Since("1.4.0") (

@Since("2.0.0")
override def fit(dataset: Dataset[_]): StringIndexerModel = {
_transformDataset = dataset
transformDataset = dataset
transformSchema(dataset.schema, logging = true)
_transformDataset = null
transformDataset = null

// In case of equal frequency when frequencyDesc/Asc, the strings are further sorted
// alphabetically.
Expand Down Expand Up @@ -425,9 +425,9 @@ class StringIndexerModel (

@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
_transformDataset = dataset
transformDataset = dataset
transformSchema(dataset.schema, logging = true)
_transformDataset = null
transformDataset = null

val (inputColNames, outputColNames) = getInOutCols()
val outputColumns = new Array[Column](outputColNames.length)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ class StringIndexerSuite extends MLTest with DefaultReadWriteTest {
// Checks that transformSchema, given a schema with a single string column "input",
// produces an "output" field whose data type is DoubleType.
// NOTE(review): the test name contains an unbalanced ')' — it looks like a typo, but it is
// a runtime string (the registered test name), so it is left untouched here.
test("StringIndexer.transformSchema)") {
val idxToStr = new StringIndexer().setInputCol("input").setOutputCol("output")
val inSchema = StructType(Seq(StructField("input", StringType)))
// The dataset field is populated by hand before transformSchema is called directly;
// fit/transform do the same assignment themselves before validating the schema.
// NOTE(review): the two assignments below are the removed/added pair of the rename diff
// (`_transformDataset` -> `transformDataset`); only one of them exists in a given revision.
idxToStr._transformDataset = spark.createDataFrame(List(Row("a")).asJava, schema = inSchema)
idxToStr.transformDataset = spark.createDataFrame(List(Row("a")).asJava, schema = inSchema)
val outSchema = idxToStr.transformSchema(inSchema)
// Only the output column's type is asserted, not its nullability or metadata.
assert(outSchema("output").dataType === DoubleType)
}
Expand All @@ -112,7 +112,7 @@ class StringIndexerSuite extends MLTest with DefaultReadWriteTest {
setOutputCols(Array("output", "output2"))
val inSchema = StructType(Seq(StructField("input", StringType),
StructField("input2", StringType)))
idxToStr._transformDataset = spark.createDataFrame(
idxToStr.transformDataset = spark.createDataFrame(
List(Row("a", "b")).asJava,
schema = inSchema
)
Expand All @@ -131,7 +131,7 @@ class StringIndexerSuite extends MLTest with DefaultReadWriteTest {
"input1 struct<a struct<f1 string, f2 string>>, " +
"input2 struct<b1 string, b2 string>, input3 string"
)
idxToStr._transformDataset = spark.createDataFrame(
idxToStr.transformDataset = spark.createDataFrame(
List(Row(Row(Row("a", "b")), Row("c", "d"), "e")).asJava,
schema = inSchema
)
Expand Down

0 comments on commit 076df2e

Please sign in to comment.