From a62f50f906b3de24289e8df4d80f07ca98359d65 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Mon, 4 Nov 2024 20:01:02 +0800 Subject: [PATCH] update Signed-off-by: Weichen Xu --- .../org/apache/spark/ml/feature/DCT.scala | 4 ++- .../apache/spark/ml/feature/RFormula.scala | 33 +++++++++++-------- .../apache/spark/ml/feature/Selector.scala | 4 ++- .../spark/ml/feature/StringIndexer.scala | 4 +-- .../feature/UnivariateFeatureSelector.scala | 4 ++- .../spark/ml/feature/VectorIndexer.scala | 8 +++-- .../spark/ml/feature/VectorSizeHint.scala | 8 +++-- .../spark/ml/feature/VectorSlicer.scala | 8 +++-- 8 files changed, 48 insertions(+), 25 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala index d057e5a62e507..9a8bfb195666b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala @@ -79,7 +79,9 @@ class DCT @Since("1.5.0") (@Since("1.5.0") override val uid: String) override def transformSchema(schema: StructType): StructType = { var outputSchema = super.transformSchema(schema) if ($(inputCol).nonEmpty && $(outputCol).nonEmpty) { - val size = AttributeGroup.fromStructField(schema($(inputCol))).size + val size = AttributeGroup.fromStructField( + SchemaUtils.getSchemaField(schema, $(inputCol)) + ).size if (size >= 0) { outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema, $(outputCol), size) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 77bd18423ef1b..221d70c18d5aa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -220,7 +220,8 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) // First we index each string column referenced by the input terms. val indexed = terms.zipWithIndex.map { case (term, i) => - dataset.schema(term).dataType match { + val termField = SchemaUtils.getSchemaField(dataset.schema, term) + termField.dataType match { case _: StringType => val indexCol = tmpColumn("stridx") encoderStages += new StringIndexer() @@ -231,7 +232,7 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) prefixesToRewrite(indexCol + "_") = term + "_" (term, indexCol) case _: VectorUDT => - val group = AttributeGroup.fromStructField(dataset.schema(term)) + val group = AttributeGroup.fromStructField(termField) val size = if (group.size < 0) { firstRow.getAs[Vector](i).size } else { @@ -250,7 +251,7 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) // Then we handle one-hot encoding and interactions between terms. var keepReferenceCategory = false val encodedTerms = resolvedFormula.terms.map { - case Seq(term) if dataset.schema(term).dataType == StringType => + case Seq(term) if SchemaUtils.getSchemaFieldType(dataset.schema, term) == StringType => val encodedCol = tmpColumn("onehot") // Formula w/o intercept, one of the categories in the first category feature is // being used as reference category, we will not drop any category for that feature. @@ -292,7 +293,8 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) encoderStages += new ColumnPruner(tempColumns.toSet) if ((dataset.schema.fieldNames.contains(resolvedFormula.label) && - dataset.schema(resolvedFormula.label).dataType == StringType) || $(forceIndexLabel)) { + SchemaUtils.getSchemaFieldType( + dataset.schema, resolvedFormula.label) == StringType) || $(forceIndexLabel)) { encoderStages += new StringIndexer() .setInputCol(resolvedFormula.label) .setOutputCol($(labelCol)) @@ -359,8 +361,8 @@ class RFormulaModel private[feature]( val withFeatures = pipelineModel.transformSchema(schema) if (resolvedFormula.label.isEmpty || hasLabelCol(withFeatures)) { withFeatures - } else if (schema.exists(_.name == resolvedFormula.label)) { - val nullable = schema(resolvedFormula.label).dataType match { + } else if (SchemaUtils.checkSchemaFieldExist(schema, resolvedFormula.label)) { + val nullable = SchemaUtils.getSchemaFieldType(schema, resolvedFormula.label) match { case _: NumericType | BooleanType => false case _ => true } @@ -387,8 +389,8 @@ class RFormulaModel private[feature]( val labelName = resolvedFormula.label if (labelName.isEmpty || hasLabelCol(dataset.schema)) { dataset.toDF() - } else if (dataset.schema.exists(_.name == labelName)) { - dataset.schema(labelName).dataType match { + } else if (SchemaUtils.checkSchemaFieldExist(dataset.schema, labelName)) { + SchemaUtils.getSchemaFieldType(dataset.schema, labelName) match { case _: NumericType | BooleanType => dataset.withColumn($(labelCol), dataset(labelName).cast(DoubleType)) case other => @@ -402,10 +404,12 @@ class RFormulaModel private[feature]( } private def checkCanTransform(schema: StructType): Unit = { - val columnNames = schema.map(_.name) - require(!columnNames.contains($(featuresCol)), "Features column already exists.") require( - !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType.isInstanceOf[NumericType], + !SchemaUtils.checkSchemaFieldExist(schema, $(featuresCol)), "Features column already exists." + ) + require( + !SchemaUtils.checkSchemaFieldExist(schema, $(labelCol)) + || SchemaUtils.getSchemaFieldType(schema, $(labelCol)).isInstanceOf[NumericType], s"Label column already exists and is not of type ${NumericType.simpleString}.") } @@ -550,7 +554,9 @@ private class VectorAttributeRewriter( override def transform(dataset: Dataset[_]): DataFrame = { val metadata = { - val group = AttributeGroup.fromStructField(dataset.schema(vectorCol)) + val group = AttributeGroup.fromStructField( + SchemaUtils.getSchemaField(dataset.schema, vectorCol) + ) val attrs = group.attributes.get.map { attr => if (attr.name.isDefined) { val name = prefixesToRewrite.foldLeft(attr.name.get) { case (curName, (from, to)) => @@ -563,7 +569,8 @@ private class VectorAttributeRewriter( } new AttributeGroup(vectorCol, attrs).toMetadata() } - val otherCols = dataset.columns.filter(_ != vectorCol).map(dataset.col) + val vectorColFieldName = SchemaUtils.getSchemaField(dataset.schema, vectorCol).name + val otherCols = dataset.columns.filter(_ != vectorColFieldName).map(dataset.col) val rewrittenCol = dataset.col(vectorCol).as(vectorCol, metadata) import org.apache.spark.util.ArrayImplicits._ dataset.select((otherCols :+ rewrittenCol).toImmutableArraySeq : _*) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Selector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Selector.scala index 8ff880b7b8aaf..dde1068c5b924 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Selector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Selector.scala @@ -341,7 +341,9 @@ private[feature] object SelectorModel { featuresCol: String, isNumericAttribute: Boolean): StructField = { val selector = selectedFeatures.toSet - val origAttrGroup = AttributeGroup.fromStructField(schema(featuresCol)) + val origAttrGroup = AttributeGroup.fromStructField( + SchemaUtils.getSchemaField(schema, featuresCol) + ) val featureAttributes: Array[Attribute] = if (origAttrGroup.attributes.nonEmpty) { origAttrGroup.attributes.get.zipWithIndex.filter(x => selector.contains(x._2)).map(_._1) } else { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 20b03edf23c4a..1acffa471e9a4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -564,7 +564,7 @@ class IndexToString @Since("2.2.0") (@Since("1.5.0") override val uid: String) @Since("1.5.0") override def transformSchema(schema: StructType): StructType = { val inputColName = $(inputCol) - val inputDataType = schema(inputColName).dataType + val inputDataType = SchemaUtils.getSchemaFieldType(schema, inputColName) require(inputDataType.isInstanceOf[NumericType], s"The input column $inputColName must be a numeric type, " + s"but got $inputDataType.") @@ -579,7 +579,7 @@ class IndexToString @Since("2.2.0") (@Since("1.5.0") override val uid: String) @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) - val inputColSchema = dataset.schema($(inputCol)) + val inputColSchema = SchemaUtils.getSchemaField(dataset.schema, $(inputCol)) // If the labels array is empty use column metadata val values = if (!isDefined(labels) || $(labels).isEmpty) { Attribute.fromStructField(inputColSchema) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala index 9c2033c28430e..ea1a8c6438c8d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala @@ -410,7 +410,9 @@ object UnivariateFeatureSelectorModel extends MLReadable[UnivariateFeatureSelect featuresCol: String, isNumericAttribute: Boolean): StructField = { val selector = selectedFeatures.toSet - val origAttrGroup = AttributeGroup.fromStructField(schema(featuresCol)) + val origAttrGroup = AttributeGroup.fromStructField( + SchemaUtils.getSchemaField(schema, featuresCol) + ) val featureAttributes: Array[Attribute] = if (origAttrGroup.attributes.nonEmpty) { origAttrGroup.attributes.get.zipWithIndex.filter(x => selector.contains(x._2)).map(_._1) } else { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index ff89dee68ea38..b2323d2b706f4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -445,7 +445,9 @@ class VectorIndexerModel private[ml] ( SchemaUtils.checkColumnType(schema, $(inputCol), dataType) // If the input metadata specifies numFeatures, compare with expected numFeatures. - val origAttrGroup = AttributeGroup.fromStructField(schema($(inputCol))) + val origAttrGroup = AttributeGroup.fromStructField( + SchemaUtils.getSchemaField(schema, $(inputCol)) + ) val origNumFeatures: Option[Int] = if (origAttrGroup.attributes.nonEmpty) { Some(origAttrGroup.attributes.get.length) } else { @@ -466,7 +468,9 @@ class VectorIndexerModel private[ml] ( * @return Output column field. This field does not contain non-ML metadata. */ private def prepOutputField(schema: StructType): StructField = { - val origAttrGroup = AttributeGroup.fromStructField(schema($(inputCol))) + val origAttrGroup = AttributeGroup.fromStructField( + SchemaUtils.getSchemaField(schema, $(inputCol)) + ) val featureAttributes: Array[Attribute] = if (origAttrGroup.attributes.nonEmpty) { // Convert original attributes to modified attributes val origAttrs: Array[Attribute] = origAttrGroup.attributes.get diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala index 5c96d07e0ca94..4abb607733e35 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala @@ -23,7 +23,7 @@ import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.linalg.VectorUDT import org.apache.spark.ml.param.{IntParam, Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol} -import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable} +import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils} import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{StringType, StructType} @@ -98,7 +98,9 @@ class VectorSizeHint @Since("2.3.0") (@Since("2.3.0") override val uid: String) val localSize = getSize val localHandleInvalid = getHandleInvalid - val group = AttributeGroup.fromStructField(dataset.schema(localInputCol)) + val group = AttributeGroup.fromStructField( + SchemaUtils.getSchemaField(dataset.schema, localInputCol) + ) val newGroup = validateSchemaAndSize(dataset.schema, group) if (localHandleInvalid == VectorSizeHint.OPTIMISTIC_INVALID && group.size == localSize) { dataset.toDF() @@ -139,7 +141,7 @@ class VectorSizeHint @Since("2.3.0") (@Since("2.3.0") override val uid: String) val localSize = getSize val localInputCol = getInputCol - val inputColType = schema(getInputCol).dataType + val inputColType = SchemaUtils.getSchemaFieldType(schema, getInputCol) require( inputColType.isInstanceOf[VectorUDT], s"Input column, $getInputCol must be of Vector type, got $inputColType" diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala index 5687ba878634a..58a44a41f0e84 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala @@ -99,7 +99,9 @@ final class VectorSlicer @Since("1.5.0") (@Since("1.5.0") override val uid: Stri override def transform(dataset: Dataset[_]): DataFrame = { // Validity checks transformSchema(dataset.schema) - val inputAttr = AttributeGroup.fromStructField(dataset.schema($(inputCol))) + val inputAttr = AttributeGroup.fromStructField( + SchemaUtils.getSchemaField(dataset.schema, $(inputCol)) + ) if ($(indices).nonEmpty) { val size = inputAttr.size if (size >= 0) { @@ -130,7 +132,9 @@ final class VectorSlicer @Since("1.5.0") (@Since("1.5.0") override val uid: Stri /** Get the feature indices in order: indices, names */ private def getSelectedFeatureIndices(schema: StructType): Array[Int] = { - val nameFeatures = MetadataUtils.getFeatureIndicesFromNames(schema($(inputCol)), $(names)) + val nameFeatures = MetadataUtils.getFeatureIndicesFromNames( + SchemaUtils.getSchemaField(schema, $(inputCol)), $(names) + ) val indFeatures = $(indices) val numDistinctFeatures = (nameFeatures ++ indFeatures).distinct.length lazy val errMsg = "VectorSlicer requires indices and names to be disjoint" +