Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
Signed-off-by: Weichen Xu <[email protected]>
  • Loading branch information
WeichenXu123 committed Oct 11, 2024
1 parent b0a373b commit ef0f032
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import scala.collection.immutable
import org.apache.spark.SparkIllegalArgumentException
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.linalg.VectorUDT
import org.apache.spark.sql.catalyst.util.AttributeNameParser
import org.apache.spark.sql.catalyst.util.{AttributeNameParser, QuotingUtils}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -209,6 +209,10 @@ private[spark] object SchemaUtils {
checkColumnTypes(schema, colName, typeCandidates)
}

def toSQLId(parts: String): String = {
AttributeNameParser.parseAttributeName(parts).map(QuotingUtils.quoteIdentifier).mkString(".")
}

/**
* Get schema field.
* @param schema input schema
Expand All @@ -220,10 +224,9 @@ private[spark] object SchemaUtils {
if (fieldOpt.isEmpty) {
throw new SparkIllegalArgumentException(
errorClass = "FIELD_NOT_FOUND",
messageParameters = immutable.Map(
"fieldName" -> colName,
"fields" -> schema.simpleString
)
messageParameters = Map(
"fieldName" -> toSQLId(colName),
"fields" -> schema.fields.map(f => toSQLId(f.name)).mkString(", "))
)
}
fieldOpt.get._2
Expand Down

0 comments on commit ef0f032

Please sign in to comment.