diff --git a/ml_metadata/metadata_store/mysql_metadata_source.cc b/ml_metadata/metadata_store/mysql_metadata_source.cc index 9daf02500..ba6dfb12c 100644 --- a/ml_metadata/metadata_store/mysql_metadata_source.cc +++ b/ml_metadata/metadata_store/mysql_metadata_source.cc @@ -402,7 +402,7 @@ std::string MySqlMetadataSource::EscapeString(absl::string_view value) const { CHECK(mysql_real_escape_string(db_, buffer, value.data(), value.length()) != -1UL) << "NO_BACKSLASH_ESCAPES SQL mode should not be enabled."; - std::string result = absl::StrCat("'", buffer, "'"); + std::string result(buffer); delete[] buffer; return result; } diff --git a/ml_metadata/metadata_store/postgresql_metadata_source.cc b/ml_metadata/metadata_store/postgresql_metadata_source.cc index 40b79180c..966527b96 100644 --- a/ml_metadata/metadata_store/postgresql_metadata_source.cc +++ b/ml_metadata/metadata_store/postgresql_metadata_source.cc @@ -315,8 +315,12 @@ std::string PostgreSQLMetadataSource::EscapeString( char* escaped_str = PQescapeLiteral(conn_, value.data(), value.size()); std::string result{escaped_str}; + // PQescapeLiteral will wrap the escaped string in '', which is redundant to + // the existing MLMD syntax. Therefore stripping the outer '' from the escaped + // string. + std::string substring = result.substr(1, std::strlen(result.data()) - 2); PQfreemem(escaped_str); - return result; + return substring; } std::string PostgreSQLMetadataSource::EncodeBytes( diff --git a/ml_metadata/metadata_store/postgresql_query_executor.cc b/ml_metadata/metadata_store/postgresql_query_executor.cc index ab0db89d4..501a9251b 100644 --- a/ml_metadata/metadata_store/postgresql_query_executor.cc +++ b/ml_metadata/metadata_store/postgresql_query_executor.cc @@ -374,10 +374,10 @@ absl::Status PostgreSQLQueryExecutor::DowngradeMetadataSource( return absl::OkStatus(); } std::string PostgreSQLQueryExecutor::Bind(const char* value) { - return metadata_source_->EscapeString(value); + return absl::StrCat("'", metadata_source_->EscapeString(value), "'"); } std::string PostgreSQLQueryExecutor::Bind(absl::string_view value) { - return metadata_source_->EscapeString(value); + return absl::StrCat("'", metadata_source_->EscapeString(value), "'"); } std::string PostgreSQLQueryExecutor::Bind(int value) { return std::to_string(value); @@ -390,10 +390,10 @@ std::string PostgreSQLQueryExecutor::Bind(double value) { } std::string PostgreSQLQueryExecutor::Bind(const google::protobuf::Any& value) { return absl::StrCat( - "decode(", + "decode('", metadata_source_->EscapeString( metadata_source_->EncodeBytes(value.SerializeAsString())), - ", 'base64')"); + "', 'base64')"); } std::string PostgreSQLQueryExecutor::Bind(bool value) { return value ? "TRUE" : "FALSE"; diff --git a/ml_metadata/metadata_store/query_config_executor.cc b/ml_metadata/metadata_store/query_config_executor.cc index 094a7abd4..d369a511d 100644 --- a/ml_metadata/metadata_store/query_config_executor.cc +++ b/ml_metadata/metadata_store/query_config_executor.cc @@ -283,11 +283,11 @@ absl::Status QueryConfigExecutor::DowngradeMetadataSource( } std::string QueryConfigExecutor::Bind(const char* value) { - return metadata_source_->EscapeString(value); + return absl::StrCat("'", metadata_source_->EscapeString(value), "'"); } std::string QueryConfigExecutor::Bind(absl::string_view value) { - return metadata_source_->EscapeString(value); + return absl::StrCat("'", metadata_source_->EscapeString(value), "'"); } std::string QueryConfigExecutor::Bind(int value) {