diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
index 78b1da496f254..f4e597e4ea9e7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
@@ -22,7 +22,7 @@ import scala.collection.immutable
 
 import org.apache.spark.SparkIllegalArgumentException
 import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.linalg.VectorUDT
-import org.apache.spark.sql.catalyst.util.AttributeNameParser
+import org.apache.spark.sql.catalyst.util.{AttributeNameParser, QuotingUtils}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -209,6 +209,10 @@ private[spark] object SchemaUtils {
     checkColumnTypes(schema, colName, typeCandidates)
   }
 
+  def toSQLId(parts: String): String = {
+    AttributeNameParser.parseAttributeName(parts).map(QuotingUtils.quoteIdentifier).mkString(".")
+  }
+
   /**
    * Get schema field.
    * @param schema input schema
@@ -220,10 +224,9 @@ private[spark] object SchemaUtils {
     if (fieldOpt.isEmpty) {
       throw new SparkIllegalArgumentException(
         errorClass = "FIELD_NOT_FOUND",
-        messageParameters = immutable.Map(
-          "fieldName" -> colName,
-          "fields" -> schema.simpleString
-        )
+        messageParameters = Map(
+          "fieldName" -> toSQLId(colName),
+          "fields" -> schema.fields.map(f => toSQLId(f.name)).mkString(", "))
       )
     }
     fieldOpt.get._2
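
As an illustration for reviewers, not part of the patch: assuming AttributeNameParser.parseAttributeName splits a column name on unquoted dots while honoring backtick quoting, and QuotingUtils.quoteIdentifier wraps each part in backticks, the new helper should render names roughly like this:

  SchemaUtils.toSQLId("features")      // `features`
  SchemaUtils.toSQLId("user.age")      // `user`.`age`
  SchemaUtils.toSQLId("`dotted.col`")  // `dotted.col` (a quoted dot is not a separator)

The net effect is that the FIELD_NOT_FOUND error now quotes both the missing field name and the list of available fields in standard SQL-identifier style, instead of echoing the raw column string and the schema.simpleString dump.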