Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
Signed-off-by: Weichen Xu <[email protected]>
  • Loading branch information
WeichenXu123 committed Nov 4, 2024
1 parent ab3a6a4 commit a62f50f
Show file tree
Hide file tree
Showing 8 changed files with 48 additions and 25 deletions.
4 changes: 3 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,9 @@ class DCT @Since("1.5.0") (@Since("1.5.0") override val uid: String)
override def transformSchema(schema: StructType): StructType = {
var outputSchema = super.transformSchema(schema)
if ($(inputCol).nonEmpty && $(outputCol).nonEmpty) {
val size = AttributeGroup.fromStructField(schema($(inputCol))).size
val size = AttributeGroup.fromStructField(
SchemaUtils.getSchemaField(schema, $(inputCol))
).size
if (size >= 0) {
outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
$(outputCol), size)
Expand Down
33 changes: 20 additions & 13 deletions mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,8 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)

// First we index each string column referenced by the input terms.
val indexed = terms.zipWithIndex.map { case (term, i) =>
dataset.schema(term).dataType match {
val termField = SchemaUtils.getSchemaField(dataset.schema, term)
termField.dataType match {
case _: StringType =>
val indexCol = tmpColumn("stridx")
encoderStages += new StringIndexer()
Expand All @@ -231,7 +232,7 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
prefixesToRewrite(indexCol + "_") = term + "_"
(term, indexCol)
case _: VectorUDT =>
val group = AttributeGroup.fromStructField(dataset.schema(term))
val group = AttributeGroup.fromStructField(termField)
val size = if (group.size < 0) {
firstRow.getAs[Vector](i).size
} else {
Expand All @@ -250,7 +251,7 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
// Then we handle one-hot encoding and interactions between terms.
var keepReferenceCategory = false
val encodedTerms = resolvedFormula.terms.map {
case Seq(term) if dataset.schema(term).dataType == StringType =>
case Seq(term) if SchemaUtils.getSchemaFieldType(dataset.schema, term) == StringType =>
val encodedCol = tmpColumn("onehot")
// Formula w/o intercept, one of the categories in the first category feature is
// being used as reference category, we will not drop any category for that feature.
Expand Down Expand Up @@ -292,7 +293,8 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
encoderStages += new ColumnPruner(tempColumns.toSet)

if ((dataset.schema.fieldNames.contains(resolvedFormula.label) &&
dataset.schema(resolvedFormula.label).dataType == StringType) || $(forceIndexLabel)) {
SchemaUtils.getSchemaFieldType(
dataset.schema, resolvedFormula.label) == StringType) || $(forceIndexLabel)) {
encoderStages += new StringIndexer()
.setInputCol(resolvedFormula.label)
.setOutputCol($(labelCol))
Expand Down Expand Up @@ -359,8 +361,8 @@ class RFormulaModel private[feature](
val withFeatures = pipelineModel.transformSchema(schema)
if (resolvedFormula.label.isEmpty || hasLabelCol(withFeatures)) {
withFeatures
} else if (schema.exists(_.name == resolvedFormula.label)) {
val nullable = schema(resolvedFormula.label).dataType match {
} else if (SchemaUtils.checkSchemaFieldExist(schema, resolvedFormula.label)) {
val nullable = SchemaUtils.getSchemaFieldType(schema, resolvedFormula.label) match {
case _: NumericType | BooleanType => false
case _ => true
}
Expand All @@ -387,8 +389,8 @@ class RFormulaModel private[feature](
val labelName = resolvedFormula.label
if (labelName.isEmpty || hasLabelCol(dataset.schema)) {
dataset.toDF()
} else if (dataset.schema.exists(_.name == labelName)) {
dataset.schema(labelName).dataType match {
} else if (SchemaUtils.checkSchemaFieldExist(dataset.schema, labelName)) {
SchemaUtils.getSchemaFieldType(dataset.schema, labelName) match {
case _: NumericType | BooleanType =>
dataset.withColumn($(labelCol), dataset(labelName).cast(DoubleType))
case other =>
Expand All @@ -402,10 +404,12 @@ class RFormulaModel private[feature](
}

private def checkCanTransform(schema: StructType): Unit = {
val columnNames = schema.map(_.name)
require(!columnNames.contains($(featuresCol)), "Features column already exists.")
require(
!columnNames.contains($(labelCol)) || schema($(labelCol)).dataType.isInstanceOf[NumericType],
!SchemaUtils.checkSchemaFieldExist(schema, $(featuresCol)), "Features column already exists."
)
require(
!SchemaUtils.checkSchemaFieldExist(schema, $(labelCol))
|| SchemaUtils.getSchemaFieldType(schema, $(labelCol)).isInstanceOf[NumericType],
s"Label column already exists and is not of type ${NumericType.simpleString}.")
}

Expand Down Expand Up @@ -550,7 +554,9 @@ private class VectorAttributeRewriter(

override def transform(dataset: Dataset[_]): DataFrame = {
val metadata = {
val group = AttributeGroup.fromStructField(dataset.schema(vectorCol))
val group = AttributeGroup.fromStructField(
SchemaUtils.getSchemaField(dataset.schema, vectorCol)
)
val attrs = group.attributes.get.map { attr =>
if (attr.name.isDefined) {
val name = prefixesToRewrite.foldLeft(attr.name.get) { case (curName, (from, to)) =>
Expand All @@ -563,7 +569,8 @@ private class VectorAttributeRewriter(
}
new AttributeGroup(vectorCol, attrs).toMetadata()
}
val otherCols = dataset.columns.filter(_ != vectorCol).map(dataset.col)
val vectorColFieldName = SchemaUtils.getSchemaField(dataset.schema, vectorCol).name
val otherCols = dataset.columns.filter(_ != vectorColFieldName).map(dataset.col)
val rewrittenCol = dataset.col(vectorCol).as(vectorCol, metadata)
import org.apache.spark.util.ArrayImplicits._
dataset.select((otherCols :+ rewrittenCol).toImmutableArraySeq : _*)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,9 @@ private[feature] object SelectorModel {
featuresCol: String,
isNumericAttribute: Boolean): StructField = {
val selector = selectedFeatures.toSet
val origAttrGroup = AttributeGroup.fromStructField(schema(featuresCol))
val origAttrGroup = AttributeGroup.fromStructField(
SchemaUtils.getSchemaField(schema, featuresCol)
)
val featureAttributes: Array[Attribute] = if (origAttrGroup.attributes.nonEmpty) {
origAttrGroup.attributes.get.zipWithIndex.filter(x => selector.contains(x._2)).map(_._1)
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -564,7 +564,7 @@ class IndexToString @Since("2.2.0") (@Since("1.5.0") override val uid: String)
@Since("1.5.0")
override def transformSchema(schema: StructType): StructType = {
val inputColName = $(inputCol)
val inputDataType = schema(inputColName).dataType
val inputDataType = SchemaUtils.getSchemaFieldType(schema, inputColName)
require(inputDataType.isInstanceOf[NumericType],
s"The input column $inputColName must be a numeric type, " +
s"but got $inputDataType.")
Expand All @@ -579,7 +579,7 @@ class IndexToString @Since("2.2.0") (@Since("1.5.0") override val uid: String)
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
val inputColSchema = dataset.schema($(inputCol))
val inputColSchema = SchemaUtils.getSchemaField(dataset.schema, $(inputCol))
// If the labels array is empty use column metadata
val values = if (!isDefined(labels) || $(labels).isEmpty) {
Attribute.fromStructField(inputColSchema)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,9 @@ object UnivariateFeatureSelectorModel extends MLReadable[UnivariateFeatureSelect
featuresCol: String,
isNumericAttribute: Boolean): StructField = {
val selector = selectedFeatures.toSet
val origAttrGroup = AttributeGroup.fromStructField(schema(featuresCol))
val origAttrGroup = AttributeGroup.fromStructField(
SchemaUtils.getSchemaField(schema, featuresCol)
)
val featureAttributes: Array[Attribute] = if (origAttrGroup.attributes.nonEmpty) {
origAttrGroup.attributes.get.zipWithIndex.filter(x => selector.contains(x._2)).map(_._1)
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,9 @@ class VectorIndexerModel private[ml] (
SchemaUtils.checkColumnType(schema, $(inputCol), dataType)

// If the input metadata specifies numFeatures, compare with expected numFeatures.
val origAttrGroup = AttributeGroup.fromStructField(schema($(inputCol)))
val origAttrGroup = AttributeGroup.fromStructField(
SchemaUtils.getSchemaField(schema, $(inputCol))
)
val origNumFeatures: Option[Int] = if (origAttrGroup.attributes.nonEmpty) {
Some(origAttrGroup.attributes.get.length)
} else {
Expand All @@ -466,7 +468,9 @@ class VectorIndexerModel private[ml] (
* @return Output column field. This field does not contain non-ML metadata.
*/
private def prepOutputField(schema: StructType): StructField = {
val origAttrGroup = AttributeGroup.fromStructField(schema($(inputCol)))
val origAttrGroup = AttributeGroup.fromStructField(
SchemaUtils.getSchemaField(schema, $(inputCol))
)
val featureAttributes: Array[Attribute] = if (origAttrGroup.attributes.nonEmpty) {
// Convert original attributes to modified attributes
val origAttrs: Array[Attribute] = origAttrGroup.attributes.get
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.linalg.VectorUDT
import org.apache.spark.ml.param.{IntParam, Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{StringType, StructType}
Expand Down Expand Up @@ -98,7 +98,9 @@ class VectorSizeHint @Since("2.3.0") (@Since("2.3.0") override val uid: String)
val localSize = getSize
val localHandleInvalid = getHandleInvalid

val group = AttributeGroup.fromStructField(dataset.schema(localInputCol))
val group = AttributeGroup.fromStructField(
SchemaUtils.getSchemaField(dataset.schema, localInputCol)
)
val newGroup = validateSchemaAndSize(dataset.schema, group)
if (localHandleInvalid == VectorSizeHint.OPTIMISTIC_INVALID && group.size == localSize) {
dataset.toDF()
Expand Down Expand Up @@ -139,7 +141,7 @@ class VectorSizeHint @Since("2.3.0") (@Since("2.3.0") override val uid: String)
val localSize = getSize
val localInputCol = getInputCol

val inputColType = schema(getInputCol).dataType
val inputColType = SchemaUtils.getSchemaFieldType(schema, getInputCol)
require(
inputColType.isInstanceOf[VectorUDT],
s"Input column, $getInputCol must be of Vector type, got $inputColType"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ final class VectorSlicer @Since("1.5.0") (@Since("1.5.0") override val uid: Stri
override def transform(dataset: Dataset[_]): DataFrame = {
// Validity checks
transformSchema(dataset.schema)
val inputAttr = AttributeGroup.fromStructField(dataset.schema($(inputCol)))
val inputAttr = AttributeGroup.fromStructField(
SchemaUtils.getSchemaField(dataset.schema, $(inputCol))
)
if ($(indices).nonEmpty) {
val size = inputAttr.size
if (size >= 0) {
Expand Down Expand Up @@ -130,7 +132,9 @@ final class VectorSlicer @Since("1.5.0") (@Since("1.5.0") override val uid: Stri

/** Get the feature indices in order: indices, names */
private def getSelectedFeatureIndices(schema: StructType): Array[Int] = {
val nameFeatures = MetadataUtils.getFeatureIndicesFromNames(schema($(inputCol)), $(names))
val nameFeatures = MetadataUtils.getFeatureIndicesFromNames(
SchemaUtils.getSchemaField(schema, $(inputCol)), $(names)
)
val indFeatures = $(indices)
val numDistinctFeatures = (nameFeatures ++ indFeatures).distinct.length
lazy val errMsg = "VectorSlicer requires indices and names to be disjoint" +
Expand Down

0 comments on commit a62f50f

Please sign in to comment.