Commit

simplify
Signed-off-by: Weichen Xu <[email protected]>
WeichenXu123 committed Aug 13, 2024
1 parent 063decb commit 06a66f3
Showing 7 changed files with 23 additions and 53 deletions.
16 changes: 4 additions & 12 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
@@ -17,11 +17,9 @@

package org.apache.spark.ml.feature

-import java.util.ArrayList
-
import scala.collection.mutable.ArrayBuilder

-import org.apache.spark.SparkException
+import org.apache.spark.{SparkException, SparkIllegalArgumentException}
import org.apache.spark.annotation.Since
import org.apache.spark.internal.{LogKeys, MDC}
import org.apache.spark.ml.Transformer
@@ -119,12 +117,8 @@ final class Binarizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)
(Seq($(inputCol)), Seq($(outputCol)), Seq($(threshold)))
}

-val sparkSession = SparkSession.getDefaultSession.get
-val transformDataset = sparkSession.createDataFrame(
-new ArrayList[Row](), schema = dataset.schema
-)
val mappedOutputCols = inputColNames.zip(tds).map { case (colName, td) =>
-transformDataset.col(colName).expr.dataType match {
+dataset.col(colName).expr.dataType match {
case DoubleType =>
when(!col(colName).isNaN && col(colName) > td, lit(1.0))
.otherwise(lit(0.0))
@@ -203,16 +197,14 @@ final class Binarizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)
}

var outputFields = schema.fields
-val sparkSession = SparkSession.getDefaultSession.get
-val transformDataset = sparkSession.createDataFrame(new ArrayList[Row](), schema = schema)
inputColNames.zip(outputColNames).foreach { case (inputColName, outputColName) =>
require(!schema.fieldNames.contains(outputColName),
s"Output column $outputColName already exists.")

val inputType = try {
-transformDataset.col(inputColName).expr.dataType
+SchemaUtils.getSchemaFieldType(schema, inputColName)
} catch {
-case _: AnalysisException =>
+case e: SparkIllegalArgumentException if e.getErrorClass == "FIELD_NOT_FOUND" =>
throw new SparkException(s"Input column $inputColName does not exist.")
case e: Exception =>
throw e
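The change above drops the empty helper DataFrame that existed only so Catalyst could resolve a column's data type, and reads the type from the schema instead. Below is a minimal sketch of that idea, assuming only the public Spark SQL types API; resolveFieldType is a hypothetical stand-in for the internal SchemaUtils.getSchemaFieldType and handles only top-level (non-nested) columns.

import org.apache.spark.sql.types.{DataType, DoubleType, StructField, StructType}

object SchemaLookupSketch {
  // Resolve a top-level column's type straight from the StructType, with no
  // SparkSession, empty DataFrame, or Catalyst analysis involved.
  def resolveFieldType(schema: StructType, colName: String): DataType = {
    require(schema.fieldNames.contains(colName), s"Input column $colName does not exist.")
    schema(colName).dataType
  }

  def main(args: Array[String]): Unit = {
    val schema = StructType(Seq(StructField("feature", DoubleType)))
    println(resolveFieldType(schema, "feature")) // DoubleType
  }
}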
mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -201,7 +201,7 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String

var transformedSchema = schema
$(inputCols).zip($(outputCols)).zipWithIndex.foreach { case ((inputCol, outputCol), idx) =>
-SchemaUtils.checkNumericType(schema, inputCol)
+SchemaUtils.checkNumericType(transformedSchema, inputCol)
transformedSchema = SchemaUtils.appendColumn(transformedSchema,
prepOutputField($(splitsArray)(idx), outputCol))
}
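This one-line fix validates each input column against transformedSchema, the schema accumulated inside the loop, rather than the original schema. A small sketch of that accumulate-then-validate pattern, using only the public StructType API (StructType(fields :+ field) stands in for the internal SchemaUtils.appendColumn):

import org.apache.spark.sql.types.{DoubleType, NumericType, StructField, StructType}

object AccumulatingSchemaSketch {
  def transformSchema(schema: StructType, inOutCols: Seq[(String, String)]): StructType = {
    var transformedSchema = schema
    inOutCols.foreach { case (inputCol, outputCol) =>
      // Check against the schema built up so far, not the original one.
      val inputType = transformedSchema(inputCol).dataType
      require(inputType.isInstanceOf[NumericType],
        s"Column $inputCol must be of a numeric type but was ${inputType.catalogString}.")
      transformedSchema = StructType(transformedSchema.fields :+ StructField(outputCol, DoubleType))
    }
    transformedSchema
  }
}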
mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
@@ -17,15 +17,15 @@

package org.apache.spark.ml.feature

-import java.util.{ArrayList, Locale}
+import java.util.Locale

import org.apache.spark.annotation.Since
import org.apache.spark.internal.{LogKeys, MDC}
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasInputCols, HasOutputCol, HasOutputCols}
import org.apache.spark.ml.util._
-import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType}
@@ -194,14 +194,10 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String

val (inputColNames, outputColNames) = getInOutCols()

-val sparkSession = SparkSession.getDefaultSession.get
-val transformDataset = sparkSession.createDataFrame(
-new ArrayList[Row](), schema = schema
-)
val newCols = inputColNames.zip(outputColNames).map { case (inputColName, outputColName) =>
require(!schema.fieldNames.contains(outputColName),
s"Output Column $outputColName already exists.")
-val inputType = transformDataset.col(inputColName).expr.dataType
+val inputType = SchemaUtils.getSchemaFieldType(schema, inputColName)
require(DataTypeUtils.sameType(inputType, ArrayType(StringType)), "Input type must be " +
s"${ArrayType(StringType).catalogString} but got ${inputType.catalogString}.")
StructField(
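As in Binarizer, the input type is now read from the schema rather than from a throwaway DataFrame, then checked against array<string>. A sketch of that check under the same assumptions (public types API only; the real code uses the internal SchemaUtils and DataTypeUtils helpers):

import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType}

object StopWordsSchemaCheckSketch {
  def validatedOutputField(schema: StructType, inputCol: String, outputCol: String): StructField = {
    require(!schema.fieldNames.contains(outputCol), s"Output Column $outputCol already exists.")
    val inputField = schema(inputCol)
    inputField.dataType match {
      case ArrayType(StringType, _) => // ok: array<string>, element nullability ignored here
      case other =>
        throw new IllegalArgumentException(
          s"Input type must be array<string> but got ${other.catalogString}.")
    }
    StructField(outputCol, inputField.dataType, inputField.nullable)
  }
}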
mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -17,19 +17,17 @@

package org.apache.spark.ml.feature

-import java.util.ArrayList
-
import org.apache.hadoop.fs.Path

-import org.apache.spark.SparkException
+import org.apache.spark.{SparkException, SparkIllegalArgumentException}
import org.apache.spark.annotation.Since
import org.apache.spark.internal.{LogKeys, MDC}
import org.apache.spark.ml.{Estimator, Model, Transformer}
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
-import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Dataset, Encoder, Encoders, Row, SparkSession}
+import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Dataset, Encoder, Encoders, Row}
import org.apache.spark.sql.catalyst.expressions.{If, Literal}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions._
@@ -124,17 +122,15 @@ private[feature] trait StringIndexerBase extends Params with HasHandleInvalid wi
require(outputColNames.distinct.length == outputColNames.length,
s"Output columns should not be duplicate.")

-val sparkSession = SparkSession.getDefaultSession.get
-val transformDataset = sparkSession.createDataFrame(new ArrayList[Row](), schema = schema)
val outputFields = inputColNames.zip(outputColNames).flatMap {
case (inputColName, outputColName) =>
try {
-val dtype = transformDataset.col(inputColName).expr.dataType
+val dtype = SchemaUtils.getSchemaFieldType(schema, inputColName)
Some(
validateAndTransformField(schema, inputColName, dtype, outputColName)
)
} catch {
-case _: AnalysisException =>
+case e: SparkIllegalArgumentException if e.getErrorClass == "FIELD_NOT_FOUND" =>
if (skipNonExistsCol) {
None
} else {
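Here the lookup can fail for a missing column, so the catch clause switches from AnalysisException to the SparkIllegalArgumentException (error class FIELD_NOT_FOUND) raised by the schema-based helper, preserving the skipNonExistsCol behaviour. A sketch of the same skip-or-fail decision, returning an Option instead of catching an exception; the helper name and the plain IllegalArgumentException are illustrative assumptions, and only top-level columns are handled:

import org.apache.spark.sql.types.{DataType, StructType}

object SkipMissingColumnSketch {
  def inputTypeOrSkip(schema: StructType, inputCol: String, skipNonExistsCol: Boolean): Option[DataType] = {
    schema.fields.find(_.name == inputCol) match {
      case Some(field) => Some(field.dataType)
      case None if skipNonExistsCol => None // silently skip the missing input column
      case None =>
        throw new IllegalArgumentException(s"Input column $inputCol does not exist.")
    }
  }
}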
mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -17,7 +17,7 @@

package org.apache.spark.ml.feature

-import java.util.{ArrayList, NoSuchElementException}
+import java.util.NoSuchElementException

import scala.collection.mutable

@@ -29,7 +29,7 @@ import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
-import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.util.ArrayImplicits._
@@ -160,12 +160,8 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
override def transformSchema(schema: StructType): StructType = {
val inputColNames = $(inputCols)
val outputColName = $(outputCol)
-val sparkSession = SparkSession.getDefaultSession.get
-val transformDataset = sparkSession.createDataFrame(
-new ArrayList[Row](), schema = schema
-)
val incorrectColumns = inputColNames.flatMap { name =>
-transformDataset.col(name).expr.dataType match {
+SchemaUtils.getSchemaFieldType(schema, name) match {
case _: NumericType | BooleanType => None
case t if t.isInstanceOf[VectorUDT] => None
case other => Some(s"Data type ${other.catalogString} of column $name is not supported.")
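The same schema-based lookup feeds the existing type filter: every input column must be numeric, boolean, or an ML vector. A sketch of that filter, assuming the public SQLDataTypes.VectorType constant in place of the VectorUDT instance check used in the diff:

import org.apache.spark.ml.linalg.SQLDataTypes
import org.apache.spark.sql.types.{BooleanType, NumericType, StructType}

object AssemblerTypeCheckSketch {
  // Collect a message for every input column whose type the assembler cannot handle.
  def incorrectColumns(schema: StructType, inputCols: Seq[String]): Seq[String] =
    inputCols.flatMap { name =>
      schema(name).dataType match {
        case _: NumericType | BooleanType => None
        case t if t == SQLDataTypes.VectorType => None
        case other => Some(s"Data type ${other.catalogString} of column $name is not supported.")
      }
    }
}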
14 changes: 2 additions & 12 deletions mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
@@ -72,22 +72,12 @@ private[spark] object SchemaUtils {
def checkNumericType(
schema: StructType,
colName: String,
-msg: String): Unit = {
+msg: String = ""): Unit = {
val actualDataType = getSchemaFieldType(schema, colName)
val message = if (msg != null && msg.trim.length > 0) " " + msg else ""
require(actualDataType.isInstanceOf[NumericType],
s"Column $colName must be of type ${NumericType.simpleString} but was actually of type " +
s"${actualDataType.catalogString}.$message")
}

/**
* Check whether the given schema contains a column of the numeric data type.
* @param colName column name
*/
def checkNumericType(
schema: StructType,
colName: String): Unit = {
checkNumericType(schema, colName, "")
s"${actualDataType.catalogString}.$message")
}

/**
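The two-parameter checkNumericType overload is folded into the three-parameter one by giving msg a default value of "". A tiny standalone illustration (not Spark code) of why the overload becomes redundant:

object DefaultParamSketch {
  // With a default for `msg`, callers of the old two-argument overload compile
  // unchanged: they simply pick up msg = "".
  def checkNumericType(colName: String, msg: String = ""): Unit = {
    val message = if (msg != null && msg.trim.nonEmpty) " " + msg else ""
    println(s"Column $colName must be of a numeric type.$message")
  }

  def main(args: Array[String]): Unit = {
    checkNumericType("age")                   // same call shape as the removed overload
    checkNumericType("age", "Extra context.") // explicit message
  }
}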
mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
@@ -389,11 +389,11 @@ class BucketizerSuite extends MLTest with DefaultReadWriteTest {
.collect()

resultForSingleCol.zip(resultForMultiCols).foreach {
-case (rowForSingle, rowForMultiCols) =>
-assert(rowForSingle.getDouble(0) == rowForMultiCols.getDouble(0) &&
-rowForSingle.getDouble(1) == rowForMultiCols.getDouble(1) &&
-rowForSingle.getDouble(2) == rowForMultiCols.getDouble(2) &&
-rowForSingle.getDouble(3) == rowForMultiCols.getDouble(3))
+case (rowForSingle, rowForMultiCols) =>
+assert(rowForSingle.getDouble(0) == rowForMultiCols.getDouble(0) &&
+rowForSingle.getDouble(1) == rowForMultiCols.getDouble(1) &&
+rowForSingle.getDouble(2) == rowForMultiCols.getDouble(2) &&
+rowForSingle.getDouble(3) == rowForMultiCols.getDouble(3))
}
}

@@ -420,7 +420,7 @@ class BucketizerSuite extends MLTest with DefaultReadWriteTest {
("outputCols", Array("result1", "result2")))
}

test("Bucket nested input column") {
test("Bucketizer nested input column") {
// Check a set of valid feature values.
val splits = Array(-0.5, 0.0, 0.5)
val validData = Array(-0.5, -0.3, 0.0, 0.2)
