Skip to content

Commit

Permalink
feat(c/driver/postgresql): Implement Foreign Key information for GetO…
Browse files Browse the repository at this point in the history
…bjects (#757)
  • Loading branch information
WillAyd authored Jun 12, 2023
1 parent 53752b3 commit eabc5b7
Show file tree
Hide file tree
Showing 2 changed files with 295 additions and 21 deletions.
129 changes: 108 additions & 21 deletions c/driver/postgresql/connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,10 @@ class PqGetObjectsHelper {

constraint_column_usages_col_ = table_constraints_items_->children[3];
constraint_column_usage_items_ = constraint_column_usages_col_->children[0];
fk_catalog_col_ = constraint_column_usage_items_->children[0];
fk_db_schema_col_ = constraint_column_usage_items_->children[1];
fk_table_col_ = constraint_column_usage_items_->children[2];
fk_column_name_col_ = constraint_column_usage_items_->children[3];

RAISE_ADBC(AppendCatalogs());
RAISE_ADBC(FinishArrowArray());
Expand Down Expand Up @@ -480,31 +484,89 @@ class PqGetObjectsHelper {

AdbcStatusCode AppendConstraints(std::string schema_name, std::string table_name) {
struct StringBuilder query = {0};
if (StringBuilderInit(&query, /*initial_size*/ 512)) {
if (StringBuilderInit(&query, /*initial_size*/ 4096)) {
return ADBC_STATUS_INTERNAL;
}

std::vector<std::string> params = {schema_name, table_name};
const char* stmt =
"SELECT con.conname, CASE con.contype WHEN 'c' THEN 'CHECK' WHEN 'u' THEN "
"'UNIQUE' WHEN 'p' THEN 'PRIMARY KEY' WHEN 'f' THEN 'FOREIGN KEY' "
"END AS contype, ARRAY(SELECT attr.attname) AS colnames, con.confkey "
"FROM pg_catalog.pg_constraint AS con "
"CROSS JOIN UNNEST(conkey) AS conkeys "
"INNER JOIN pg_catalog.pg_class AS cls ON cls.oid = con.conrelid "
"INNER JOIN pg_catalog.pg_namespace AS nsp ON nsp.oid = cls.relnamespace "
"INNER JOIN pg_catalog.pg_attribute AS attr ON attr.attnum = conkeys "
"AND cls.oid = attr.attrelid "
"WHERE con.contype IN ('c', 'u', 'p', 'f') AND nsp.nspname LIKE $1 "
"AND cls.relname LIKE $2";
"WITH fk_unnest AS ( "
" SELECT "
" con.conname, "
" 'FOREIGN KEY' AS contype, "
" conrelid, "
" UNNEST(con.conkey) AS conkey, "
" confrelid, "
" UNNEST(con.confkey) AS confkey "
" FROM pg_catalog.pg_constraint AS con "
" INNER JOIN pg_catalog.pg_class AS cls ON cls.oid = conrelid "
" INNER JOIN pg_catalog.pg_namespace AS nsp ON nsp.oid = cls.relnamespace "
" WHERE con.contype = 'f' AND nsp.nspname LIKE $1 "
" AND cls.relname LIKE $2 "
"), "
"fk_names AS ( "
" SELECT "
" fk_unnest.conname, "
" fk_unnest.contype, "
" attr.attname, "
" fnsp.nspname AS fschema, "
" fcls.relname AS ftable, "
" fattr.attname AS fattname "
" FROM fk_unnest "
" INNER JOIN pg_catalog.pg_class AS cls ON cls.oid = fk_unnest.conrelid "
" INNER JOIN pg_catalog.pg_class AS fcls ON fcls.oid = fk_unnest.confrelid "
" INNER JOIN pg_catalog.pg_namespace AS fnsp ON fnsp.oid = fcls.relnamespace"
" INNER JOIN pg_catalog.pg_attribute AS attr ON attr.attnum = "
"fk_unnest.conkey "
" AND attr.attrelid = fk_unnest.conrelid "
" LEFT JOIN pg_catalog.pg_attribute AS fattr ON fattr.attnum = "
"fk_unnest.confkey "
" AND fattr.attrelid = fk_unnest.confrelid "
"), "
"fkeys AS ( "
" SELECT "
" conname, "
" contype, "
" ARRAY_AGG(attname) AS colnames, "
" fschema, "
" ftable, "
" ARRAY_AGG(fattname) AS fcolnames "
" FROM fk_names "
" GROUP BY "
" conname, "
" contype, "
" fschema, "
" ftable "
"), "
"other_constraints AS ( "
" SELECT con.conname, CASE con.contype WHEN 'c' THEN 'CHECK' WHEN 'u' THEN "
" 'UNIQUE' WHEN 'p' THEN 'PRIMARY KEY' END AS contype, "
" ARRAY_AGG(attr.attname) AS colnames "
" FROM pg_catalog.pg_constraint AS con "
" CROSS JOIN UNNEST(conkey) AS conkeys "
" INNER JOIN pg_catalog.pg_class AS cls ON cls.oid = con.conrelid "
" INNER JOIN pg_catalog.pg_namespace AS nsp ON nsp.oid = cls.relnamespace "
" INNER JOIN pg_catalog.pg_attribute AS attr ON attr.attnum = conkeys "
" AND cls.oid = attr.attrelid "
" WHERE con.contype IN ('c', 'u', 'p') AND nsp.nspname LIKE $1 "
" AND cls.relname LIKE $2 "
" GROUP BY conname, contype "
") "
"SELECT "
" conname, contype, colnames, fschema, ftable, fcolnames "
"FROM fkeys "
"UNION ALL "
"SELECT "
" conname, contype, colnames, NULL, NULL, NULL "
"FROM other_constraints";

if (StringBuilderAppend(&query, "%s", stmt)) {
StringBuilderReset(&query);
return ADBC_STATUS_INTERNAL;
}

if (column_name_ != NULL) {
if (StringBuilderAppend(&query, "%s", " AND con.conname LIKE $3")) {
if (StringBuilderAppend(&query, "%s", " WHERE conname LIKE $3")) {
StringBuilderReset(&query);
return ADBC_STATUS_INTERNAL;
}
Expand Down Expand Up @@ -541,16 +603,37 @@ class PqGetObjectsHelper {
}
CHECK_NA(INTERNAL, ArrowArrayFinishElement(constraint_column_names_col_), error_);

if (row[3].is_null) {
CHECK_NA(INTERNAL, ArrowArrayAppendNull(constraint_column_usage_items_, 1),
error_);
} else {
// TODO: some kind of for loop here over each usage
// need to unpack binary data from libpq
return ADBC_STATUS_NOT_IMPLEMENTED;
if (!strcmp(constraint_type, "FOREIGN KEY")) {
assert(!row[3].is_null);
assert(!row[4].is_null);
assert(!row[5].is_null);

const char* constraint_ftable_schema = row[3].data;
const char* constraint_ftable_name = row[4].data;
auto constraint_fcolumn_names = PqTextArrayToVector(std::string(row[5].data));
for (const auto& constraint_fcolumn_name : constraint_fcolumn_names) {
CHECK_NA(
INTERNAL,
ArrowArrayAppendString(fk_catalog_col_, ArrowCharView(current_db_.c_str())),
error_);
CHECK_NA(INTERNAL,
ArrowArrayAppendString(fk_db_schema_col_,
ArrowCharView(constraint_ftable_schema)),
error_);
CHECK_NA(INTERNAL,
ArrowArrayAppendString(fk_table_col_,
ArrowCharView(constraint_ftable_name)),
error_);
CHECK_NA(INTERNAL,
ArrowArrayAppendString(fk_column_name_col_,
ArrowCharView(constraint_fcolumn_name.c_str())),
error_);

CHECK_NA(INTERNAL, ArrowArrayFinishElement(constraint_column_usage_items_),
error_);
}
}
CHECK_NA(INTERNAL, ArrowArrayFinishElement(constraint_column_usages_col_), error_);

CHECK_NA(INTERNAL, ArrowArrayFinishElement(table_constraints_items_), error_);
}

Expand Down Expand Up @@ -598,6 +681,10 @@ class PqGetObjectsHelper {
struct ArrowArray* constraint_column_name_col_;
struct ArrowArray* constraint_column_usages_col_;
struct ArrowArray* constraint_column_usage_items_;
struct ArrowArray* fk_catalog_col_;
struct ArrowArray* fk_db_schema_col_;
struct ArrowArray* fk_table_col_;
struct ArrowArray* fk_column_name_col_;
};

} // namespace
Expand Down
187 changes: 187 additions & 0 deletions c/driver/postgresql/postgresql_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,193 @@ TEST_F(PostgresConnectionTest, GetObjectsGetAllFindsPrimaryKey) {
ASSERT_TRUE(seen_primary_key) << "could not find primary key for adbc_pkey_test";
}

// Creates `adbc_fkey_test_base` with a composite primary key (id1, id2) and
// `adbc_fkey_test` with a composite foreign key (fid1, fid2) referencing it,
// then walks the result of AdbcConnectionGetObjects(ADBC_OBJECT_DEPTH_ALL) and
// checks that the FOREIGN KEY constraint on `adbc_fkey_test` lists both
// referenced columns (`id1` and `id2`) in its constraint_column_usage entries.
TEST_F(PostgresConnectionTest, GetObjectsGetAllFindsForeignKey) {
  ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error));
  ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error));

  if (!quirks()->supports_get_objects()) {
    GTEST_SKIP();
  }

  // The referencing table is dropped before the table it references.
  ASSERT_THAT(quirks()->DropTable(&connection, "adbc_fkey_test", &error),
              IsOkStatus(&error));
  ASSERT_THAT(quirks()->DropTable(&connection, "adbc_fkey_test_base", &error),
              IsOkStatus(&error));

  struct AdbcStatement statement;
  ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));
  {
    ASSERT_THAT(
        AdbcStatementSetSqlQuery(&statement,
                                 "CREATE TABLE adbc_fkey_test_base (id1 INT, id2 INT, "
                                 "PRIMARY KEY (id1, id2))",
                                 &error),
        IsOkStatus(&error));
    adbc_validation::StreamReader reader;
    ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value,
                                          &reader.rows_affected, &error),
                IsOkStatus(&error));
    ASSERT_EQ(reader.rows_affected, 0);
    ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
    ASSERT_NO_FATAL_FAILURE(reader.Next());
    ASSERT_EQ(reader.array->release, nullptr);
  }
  // Release the first statement before the handle is re-initialized below;
  // calling AdbcStatementNew again on a live handle would leak it.
  ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error));

  ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));
  {
    ASSERT_THAT(AdbcStatementSetSqlQuery(
                    &statement,
                    "CREATE TABLE adbc_fkey_test (fid1 INT, fid2 INT, "
                    "FOREIGN KEY (fid1, fid2) REFERENCES adbc_fkey_test_base(id1, id2))",
                    &error),
                IsOkStatus(&error));
    adbc_validation::StreamReader reader;
    ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value,
                                          &reader.rows_affected, &error),
                IsOkStatus(&error));
    ASSERT_EQ(reader.rows_affected, 0);
    ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
    ASSERT_NO_FATAL_FAILURE(reader.Next());
    ASSERT_EQ(reader.array->release, nullptr);
  }
  // The statement is not used past this point, so release it here as well.
  ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error));

  adbc_validation::StreamReader reader;
  ASSERT_THAT(
      AdbcConnectionGetObjects(&connection, ADBC_OBJECT_DEPTH_ALL, nullptr, nullptr,
                               nullptr, nullptr, nullptr, &reader.stream.value, &error),
      IsOkStatus(&error));
  ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
  ASSERT_NO_FATAL_FAILURE(reader.Next());
  ASSERT_NE(nullptr, reader.array->release);
  ASSERT_GT(reader.array->length, 0);

  bool seen_fid1 = false;
  bool seen_fid2 = false;

  // Resolve the nested children of the GetObjects result once, up front:
  // catalog -> db schemas -> tables -> constraints -> column usages.
  // The table columns (tables child 2) and the constraint column names
  // (constraints child 2) are not inspected by this test.
  struct ArrowArrayView* catalog_name_col = reader.array_view->children[0];
  struct ArrowArrayView* catalog_db_schemas_list = reader.array_view->children[1];
  struct ArrowArrayView* catalog_db_schemas_items = catalog_db_schemas_list->children[0];
  struct ArrowArrayView* db_schema_name_col = catalog_db_schemas_items->children[0];
  struct ArrowArrayView* db_schema_tables_col = catalog_db_schemas_items->children[1];

  struct ArrowArrayView* schema_table_items = db_schema_tables_col->children[0];
  struct ArrowArrayView* table_name_col = schema_table_items->children[0];
  struct ArrowArrayView* table_constraints_col = schema_table_items->children[3];

  struct ArrowArrayView* table_constraints_items = table_constraints_col->children[0];
  struct ArrowArrayView* constraint_type_col = table_constraints_items->children[1];

  struct ArrowArrayView* constraint_column_usages_col =
      table_constraints_items->children[3];
  struct ArrowArrayView* constraint_column_usage_items =
      constraint_column_usages_col->children[0];
  struct ArrowArrayView* fk_catalog_col = constraint_column_usage_items->children[0];
  struct ArrowArrayView* fk_db_schema_col = constraint_column_usage_items->children[1];
  struct ArrowArrayView* fk_table_col = constraint_column_usage_items->children[2];
  struct ArrowArrayView* fk_column_name_col = constraint_column_usage_items->children[3];

  // Copies the string element at `index` of `view` into an owning std::string.
  const auto get_string = [](struct ArrowArrayView* view, int64_t index) {
    const ArrowStringView value = ArrowArrayViewGetStringUnsafe(view, index);
    return std::string(value.data, value.size_bytes);
  };

  // Consume every batch of the stream; each level below skips entries that
  // are not on the path postgres -> public -> adbc_fkey_test -> FOREIGN KEY.
  do {
    for (int64_t catalog_idx = 0; catalog_idx < reader.array->length; catalog_idx++) {
      if (get_string(catalog_name_col, catalog_idx) != "postgres") continue;

      const int64_t schemas_begin =
          ArrowArrayViewListChildOffset(catalog_db_schemas_list, catalog_idx);
      const int64_t schemas_end =
          ArrowArrayViewListChildOffset(catalog_db_schemas_list, catalog_idx + 1);
      for (int64_t schema_idx = schemas_begin; schema_idx < schemas_end; schema_idx++) {
        if (get_string(db_schema_name_col, schema_idx) != "public") continue;

        const int64_t tables_begin =
            ArrowArrayViewListChildOffset(db_schema_tables_col, schema_idx);
        const int64_t tables_end =
            ArrowArrayViewListChildOffset(db_schema_tables_col, schema_idx + 1);
        for (int64_t table_idx = tables_begin; table_idx < tables_end; table_idx++) {
          if (get_string(table_name_col, table_idx) != "adbc_fkey_test") continue;

          const int64_t constraints_begin =
              ArrowArrayViewListChildOffset(table_constraints_col, table_idx);
          const int64_t constraints_end =
              ArrowArrayViewListChildOffset(table_constraints_col, table_idx + 1);
          for (int64_t constraint_idx = constraints_begin;
               constraint_idx < constraints_end; constraint_idx++) {
            if (get_string(constraint_type_col, constraint_idx) != "FOREIGN KEY") {
              continue;
            }

            const int64_t usages_begin = ArrowArrayViewListChildOffset(
                constraint_column_usages_col, constraint_idx);
            const int64_t usages_end = ArrowArrayViewListChildOffset(
                constraint_column_usages_col, constraint_idx + 1);
            for (int64_t usage_idx = usages_begin; usage_idx < usages_end;
                 usage_idx++) {
              const bool references_base_table =
                  get_string(fk_catalog_col, usage_idx) == "postgres" &&
                  get_string(fk_db_schema_col, usage_idx) == "public" &&
                  get_string(fk_table_col, usage_idx) == "adbc_fkey_test_base";
              if (!references_base_table) continue;

              // TODO: the current implementation makes it so the length of
              // constraint_column_names is not the same as the length of
              // constraint_column_usage as the latter applies only to foreign
              // keys; should these be the same? If so can pairwise iterate
              // and check source column -> foreign table column mapping
              const std::string fk_column_name =
                  get_string(fk_column_name_col, usage_idx);
              if (fk_column_name == "id1") {
                seen_fid1 = true;
              } else if (fk_column_name == "id2") {
                seen_fid2 = true;
              }
            }
          }
        }
      }
    }
    ASSERT_NO_FATAL_FAILURE(reader.Next());
  } while (reader.array->release);

  ASSERT_TRUE(seen_fid1)
      << "could not find foreign key relationship for 'fid1' on adbc_fkey_test";
  ASSERT_TRUE(seen_fid2)
      << "could not find foreign key relationship for 'fid2' on adbc_fkey_test";
}

TEST_F(PostgresConnectionTest, MetadataGetTableSchemaInjection) {
if (!quirks()->supports_bulk_ingest()) {
GTEST_SKIP();
Expand Down

0 comments on commit eabc5b7

Please sign in to comment.