Commit fb080e8

simplify

Signed-off-by: Weichen Xu <[email protected]>
WeichenXu123 committed Aug 12, 2024
1 parent c11f169 commit fb080e8

Showing 5 changed files with 14 additions and 45 deletions.
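The change repeated across the four feature files below replaces an empty-DataFrame round trip (SparkSession.getDefaultSession.get plus createDataFrame over an empty ArrayList, used only to resolve a column's data type) with a direct lookup in the StructType. A minimal standalone sketch of the new approach, with illustrative names (Spark's actual helper lives in the private[spark] SchemaUtils object changed in the last file):

import org.apache.spark.sql.types._

object SchemaCheckSketch {
  // Validate a column's type by looking it up in the StructType directly,
  // instead of materializing an empty DataFrame just to resolve the type.
  def requireNumeric(schema: StructType, colName: String): Unit = {
    val dt = schema(colName).dataType // StructType.apply throws if the column is absent
    require(dt.isInstanceOf[NumericType],
      s"Column $colName must be numeric but was ${dt.catalogString}.")
  }

  def main(args: Array[String]): Unit = {
    val schema = StructType(Seq(
      StructField("age", IntegerType),
      StructField("name", StringType)))
    requireNumeric(schema, "age") // passes
    // requireNumeric(schema, "name") // would throw IllegalArgumentException
  }
}

Besides being simpler, this no longer requires a default SparkSession to be set during schema validation.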
mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala

@@ -192,10 +192,6 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
     ParamValidators.checkSingleVsMultiColumnParams(this, Seq(outputCol, splits),
       Seq(outputCols, splitsArray))
 
-    val sparkSession = SparkSession.getDefaultSession.get
-    val transformDataset = sparkSession.createDataFrame(
-      new ju.ArrayList[Row](), schema = schema
-    )
     if (isSet(inputCols)) {
       require(getInputCols.length == getOutputCols.length &&
         getInputCols.length == getSplitsArray.length, s"Bucketizer $this has mismatched Params " +
@@ -205,15 +201,13 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
 
       var transformedSchema = schema
       $(inputCols).zip($(outputCols)).zipWithIndex.foreach { case ((inputCol, outputCol), idx) =>
-        val colType = transformDataset.col(inputCol).expr.dataType
-        SchemaUtils.checkNumericType(colType, inputCol, "")
+        SchemaUtils.checkNumericType(schema, inputCol)
         transformedSchema = SchemaUtils.appendColumn(transformedSchema,
           prepOutputField($(splitsArray)(idx), outputCol))
       }
       transformedSchema
     } else {
-      val colType = transformDataset.col($(inputCol)).expr.dataType
-      SchemaUtils.checkNumericType(colType, $(inputCol), "")
+      SchemaUtils.checkNumericType(schema, $(inputCol))
       SchemaUtils.appendColumn(schema, prepOutputField($(splits), $(outputCol)))
     }
   }
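A quick usage sketch (local session and toy data assumed; not part of this commit) showing the code path exercised here:

import org.apache.spark.ml.feature.Bucketizer
import org.apache.spark.sql.SparkSession

object BucketizerSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("buckets").getOrCreate()
    val df = spark.createDataFrame(Seq(Tuple1(-0.5), Tuple1(0.3), Tuple1(999.0))).toDF("features")
    val bucketizer = new Bucketizer()
      .setInputCol("features")
      .setOutputCol("bucketed")
      .setSplits(Array(Double.NegativeInfinity, 0.0, 0.5, Double.PositiveInfinity))
    // transform() runs transformSchema() first, hitting the checkNumericType
    // call changed above.
    bucketizer.transform(df).show()
    spark.stop()
  }
}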
mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala

@@ -91,7 +91,7 @@ private[feature] trait ImputerParams extends Params with HasInputCol with HasInputCols
         s" and outputCols(${outputColNames.length}) should have the same length")
       val outputFields = inputColNames.zip(outputColNames).map { case (inputCol, outputCol) =>
         val inputField = SchemaUtils.getSchemaField(schema, inputCol)
-        SchemaUtils.checkNumericType(inputField.dataType, inputCol, "")
+        SchemaUtils.checkNumericType(schema, inputCol)
         StructField(outputCol, inputField.dataType, inputField.nullable)
       }
       StructType(schema ++ outputFields)
@@ -179,8 +179,8 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String)
       val quantileDataset = dataset.select(inputColumns.zipWithIndex.map {
         case (colName, index) => col(colName).alias(quantileColNames(index))
       }.toImmutableArraySeq: _*)
-      quantileDataset.select(cols.toImmutableArraySeq: _*)
-        .stat.approxQuantile(inputColumns, Array(0.5), $(relativeError))
+      quantileDataset
+        .stat.approxQuantile(quantileColNames, Array(0.5), $(relativeError))
         .map(_.headOption.getOrElse(Double.NaN))
 
     case Imputer.mode =>
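This hunk also drops a redundant select and computes quantiles over the aliased columns directly. A sketch of the underlying API (toy data; names illustrative): DataFrameStatFunctions.approxQuantile returns one Array of quantiles per input column, which is why the code maps headOption over the result:

import org.apache.spark.sql.SparkSession

object MedianSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("median").getOrCreate()
    val df = spark.createDataFrame(Seq((1.0, 10.0), (2.0, 20.0), (3.0, 30.0))).toDF("a", "b")
    // One Array per column, each holding the requested quantiles (here the median).
    val medians: Array[Array[Double]] =
      df.stat.approxQuantile(Array("a", "b"), Array(0.5), 0.001)
    // Mirror the diff: fall back to NaN if a column produced no quantiles
    // (approxQuantile ignores null/NaN values, so a fully-null column yields none).
    val result = medians.map(_.headOption.getOrElse(Double.NaN))
    result.foreach(println)
    spark.stop()
  }
}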
mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala

@@ -17,8 +17,6 @@
 
 package org.apache.spark.ml.feature
 
-import java.{util => ju}
-
 import org.apache.hadoop.fs.Path
 
 import org.apache.spark.SparkException
@@ -29,7 +27,7 @@ import org.apache.spark.ml.linalg.Vectors
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasInputCols, HasOutputCol, HasOutputCols}
 import org.apache.spark.ml.util._
-import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
+import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.expressions.UserDefinedFunction
 import org.apache.spark.sql.functions.{col, lit, udf}
 import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
@@ -90,14 +88,9 @@ private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid
       s"The number of input columns ${inputColNames.length} must be the same as the number of " +
         s"output columns ${outputColNames.length}.")
 
-    val sparkSession = SparkSession.getDefaultSession.get
-    val transformDataset = sparkSession.createDataFrame(
-      new ju.ArrayList[Row](), schema = schema
-    )
     // Input columns must be NumericType.
     inputColNames.foreach { colName =>
-      val dataType = transformDataset.col(colName).expr.dataType
-      SchemaUtils.checkNumericType(dataType, colName, "")
+      SchemaUtils.checkNumericType(schema, colName)
     }
 
     // Prepares output columns with proper attributes by examining input columns.
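Usage sketch for the multi-column validation path above (local session and toy data assumed):

import org.apache.spark.ml.feature.OneHotEncoder
import org.apache.spark.sql.SparkSession

object OneHotSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("ohe").getOrCreate()
    val df = spark.createDataFrame(Seq((0.0, 1.0), (1.0, 0.0), (2.0, 1.0)))
      .toDF("categoryIndex1", "categoryIndex2")
    val encoder = new OneHotEncoder()
      .setInputCols(Array("categoryIndex1", "categoryIndex2"))
      .setOutputCols(Array("categoryVec1", "categoryVec2"))
    // fit() validates the schema first; each input column's numeric type is
    // now checked straight from df.schema rather than via an empty DataFrame.
    encoder.fit(df).transform(df).show()
    spark.stop()
  }
}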
mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala

@@ -17,16 +17,14 @@
 
 package org.apache.spark.ml.feature
 
-import java.{util => ju}
-
 import org.apache.spark.annotation.Since
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml._
 import org.apache.spark.ml.attribute.NominalAttribute
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
-import org.apache.spark.sql.{Dataset, Row, SparkSession}
+import org.apache.spark.sql.Dataset
 import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.ArrayImplicits._
@@ -190,13 +188,8 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val uid: String)
 
     var outputFields = schema.fields
 
-    val sparkSession = SparkSession.getDefaultSession.get
-    val transformDataset = sparkSession.createDataFrame(
-      new ju.ArrayList[Row](), schema = schema
-    )
     inputColNames.zip(outputColNames).foreach { case (inputColName, outputColName) =>
-      val colType = transformDataset.col(inputColName).expr.dataType
-      SchemaUtils.checkNumericType(colType, inputColName, "")
+      SchemaUtils.checkNumericType(schema, inputColName)
       require(!schema.fieldNames.contains(outputColName),
         s"Output column $outputColName already exists.")
       val attr = NominalAttribute.defaultAttr.withName(outputColName)
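And the equivalent path for QuantileDiscretizer (toy data assumed):

import org.apache.spark.ml.feature.QuantileDiscretizer
import org.apache.spark.sql.SparkSession

object DiscretizerSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("qd").getOrCreate()
    val df = spark.createDataFrame(Seq((0, 18.0), (1, 19.0), (2, 8.0), (3, 5.0), (4, 2.2)))
      .toDF("id", "hour")
    val discretizer = new QuantileDiscretizer()
      .setInputCol("hour")
      .setOutputCol("result")
      .setNumBuckets(3)
    // fit() calls transformSchema(), which now checks that "hour" is numeric
    // using the schema lookup introduced above.
    discretizer.fit(df).transform(df).show()
    spark.stop()
  }
}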
mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala (5 additions, 16 deletions)

@@ -73,8 +73,11 @@ private[spark] object SchemaUtils {
       schema: StructType,
       colName: String,
       msg: String): Unit = {
-    val actualDataType = schema(colName).dataType
-    checkNumericType(actualDataType, colName, msg)
+    val actualDataType = getSchemaFieldType(schema, colName)
+    val message = if (msg != null && msg.trim.length > 0) " " + msg else ""
+    require(actualDataType.isInstanceOf[NumericType],
+      s"Column $colName must be of type ${NumericType.simpleString} but was actually of type " +
+        s"${actualDataType.catalogString}.$message")
   }
 
   /**
@@ -87,20 +90,6 @@
     checkNumericType(schema, colName, "")
   }
 
-  /**
-   * Check whether the given actual data type is the numeric data type.
-   * @param actualDataType actual data type of the column
-   */
-  def checkNumericType(
-      actualDataType: DataType,
-      colName: String,
-      msg: String): Unit = {
-    val message = if (msg != null && msg.trim.length > 0) " " + msg else ""
-    require(actualDataType.isInstanceOf[NumericType],
-      s"Column $colName must be of type ${NumericType.simpleString} but was actually of type " +
-        s"${actualDataType.catalogString}.$message")
-  }
-
   /**
    * Appends a new column to the input schema. This fails if the given output column already exists.
    * @param schema input schema
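A standalone sketch of the consolidated overload's message handling (SchemaUtils is private[spark], so this re-implements it for illustration, substituting a plain "numeric" for the internal NumericType.simpleString):

import org.apache.spark.sql.types._
import scala.util.Try

object MessageSketch {
  def checkNumericType(schema: StructType, colName: String, msg: String): Unit = {
    val actualDataType = schema(colName).dataType
    // A non-blank msg is appended after a single space; null or blank leaves
    // the base message untouched.
    val message = if (msg != null && msg.trim.length > 0) " " + msg else ""
    require(actualDataType.isInstanceOf[NumericType],
      s"Column $colName must be of type numeric but was actually of type " +
        s"${actualDataType.catalogString}.$message")
  }

  def main(args: Array[String]): Unit = {
    val schema = StructType(Seq(StructField("name", StringType)))
    println(Try(checkNumericType(schema, "name", "Consider casting it first.")))
    println(Try(checkNumericType(schema, "name", "")))
  }
}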