From 0f8e3c52059d6baa42f185bca266fb8f98db1a59 Mon Sep 17 00:00:00 2001 From: David Li Date: Fri, 12 May 2023 11:20:19 -0400 Subject: [PATCH] feat(format): introduce AdbcDriver110 Fixes #317. --- adbc.h | 24 +- c/driver_manager/adbc_driver_manager.cc | 269 +++++++++++-------- c/driver_manager/adbc_driver_manager_test.cc | 30 +++ 3 files changed, 213 insertions(+), 110 deletions(-) diff --git a/adbc.h b/adbc.h index 154e881255..27450faf57 100644 --- a/adbc.h +++ b/adbc.h @@ -279,6 +279,12 @@ struct ADBC_EXPORT AdbcError { /// point to an AdbcDriver. #define ADBC_VERSION_1_0_0 1000000 +/// \brief ADBC revision 1.1.0. +/// +/// When passed to an AdbcDriverInitFunc(), the driver parameter must +/// point to an AdbcDriver110. +#define ADBC_VERSION_1_1_0 1001000 + /// \brief Canonical option value for enabling an option. /// /// For use as the value in SetOption calls. @@ -479,7 +485,7 @@ struct ADBC_EXPORT AdbcDatabase { void* private_data; /// \brief The associated driver (used by the driver manager to help /// track state). - struct AdbcDriver* private_driver; + void* private_driver; }; /// @} @@ -502,7 +508,7 @@ struct ADBC_EXPORT AdbcConnection { void* private_data; /// \brief The associated driver (used by the driver manager to help /// track state). - struct AdbcDriver* private_driver; + void* private_driver; }; /// @} @@ -541,7 +547,7 @@ struct ADBC_EXPORT AdbcStatement { /// \brief The associated driver (used by the driver manager to help /// track state). - struct AdbcDriver* private_driver; + void* private_driver; }; /// \defgroup adbc-statement-partition Partitioned Results @@ -595,7 +601,7 @@ struct AdbcPartitions { /// driver and the driver manager. /// @{ -/// \brief An instance of an initialized database driver. +/// \brief An instance of an initialized database driver (API 1.0.0). /// /// This provides a common interface for vendor-specific driver /// initialization routines. Drivers should populate this struct, and @@ -669,6 +675,16 @@ struct ADBC_EXPORT AdbcDriver { size_t, struct AdbcError*); }; +/// \brief An instance of an initialized database driver (API 1.1.0). +/// +/// This provides a common interface for vendor-specific driver +/// initialization routines. Drivers should populate this struct, and +/// applications can call ADBC functions through this struct, without +/// worrying about multiple definitions of the same symbol. +struct ADBC_EXPORT AdbcDriver110 { + struct AdbcDriver base; +}; + /// @} /// \addtogroup adbc-database diff --git a/c/driver_manager/adbc_driver_manager.cc b/c/driver_manager/adbc_driver_manager.cc index c63560a40e..afe44a908a 100644 --- a/c/driver_manager/adbc_driver_manager.cc +++ b/c/driver_manager/adbc_driver_manager.cc @@ -243,7 +243,8 @@ AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* key, const char* value, struct AdbcError* error) { if (database->private_driver) { - return database->private_driver->DatabaseSetOption(database, key, value, error); + return static_cast(database->private_driver) + ->base.DatabaseSetOption(database, key, value, error); } TempDatabase* args = reinterpret_cast(database->private_data); @@ -282,48 +283,52 @@ AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* return ADBC_STATUS_INVALID_ARGUMENT; } - database->private_driver = new AdbcDriver; - std::memset(database->private_driver, 0, sizeof(AdbcDriver)); + database->private_driver = new AdbcDriver110; + std::memset(database->private_driver, 0, sizeof(AdbcDriver110)); + + auto* driver110 = static_cast(database->private_driver); + struct AdbcDriver* driver100 = &driver110->base; + AdbcStatusCode status; // So we don't confuse a driver into thinking it's initialized already database->private_data = nullptr; if (args->init_func) { - status = AdbcLoadDriverFromInitFunc(args->init_func, ADBC_VERSION_1_0_0, - database->private_driver, error); + status = + AdbcLoadDriverFromInitFunc(args->init_func, ADBC_VERSION_1_0_0, driver100, error); } else { status = AdbcLoadDriver(args->driver.c_str(), args->entrypoint.c_str(), - ADBC_VERSION_1_0_0, database->private_driver, error); + ADBC_VERSION_1_0_0, driver100, error); } if (status != ADBC_STATUS_OK) { // Restore private_data so it will be released by AdbcDatabaseRelease database->private_data = args; - if (database->private_driver->release) { - database->private_driver->release(database->private_driver, error); + if (driver100->release) { + driver100->release(driver100, error); } - delete database->private_driver; + delete static_cast(database->private_driver); database->private_driver = nullptr; return status; } - status = database->private_driver->DatabaseNew(database, error); + status = driver100->DatabaseNew(database, error); if (status != ADBC_STATUS_OK) { - if (database->private_driver->release) { - database->private_driver->release(database->private_driver, error); + if (driver100->release) { + driver100->release(driver100, error); } - delete database->private_driver; + delete static_cast(database->private_driver); database->private_driver = nullptr; return status; } for (const auto& option : args->options) { - status = database->private_driver->DatabaseSetOption(database, option.first.c_str(), - option.second.c_str(), error); + status = driver100->DatabaseSetOption(database, option.first.c_str(), + option.second.c_str(), error); if (status != ADBC_STATUS_OK) { delete args; // Release the database - std::ignore = database->private_driver->DatabaseRelease(database, error); - if (database->private_driver->release) { - database->private_driver->release(database->private_driver, error); + std::ignore = driver100->DatabaseRelease(database, error); + if (driver100->release) { + driver100->release(driver100, error); } - delete database->private_driver; + delete static_cast(database->private_driver); database->private_driver = nullptr; // Should be redundant, but ensure that AdbcDatabaseRelease // below doesn't think that it contains a TempDatabase @@ -332,7 +337,7 @@ AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* } } delete args; - return database->private_driver->DatabaseInit(database, error); + return driver100->DatabaseInit(database, error); } AdbcStatusCode AdbcDatabaseRelease(struct AdbcDatabase* database, @@ -346,11 +351,13 @@ AdbcStatusCode AdbcDatabaseRelease(struct AdbcDatabase* database, } return ADBC_STATUS_INVALID_STATE; } - auto status = database->private_driver->DatabaseRelease(database, error); - if (database->private_driver->release) { - database->private_driver->release(database->private_driver, error); + auto* driver110 = static_cast(database->private_driver); + struct AdbcDriver* driver100 = &driver110->base; + auto status = driver100->DatabaseRelease(database, error); + if (driver100->release) { + driver100->release(driver100, error); } - delete database->private_driver; + delete static_cast(database->private_driver); database->private_data = nullptr; database->private_driver = nullptr; return status; @@ -361,7 +368,8 @@ AdbcStatusCode AdbcConnectionCommit(struct AdbcConnection* connection, if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } - return connection->private_driver->ConnectionCommit(connection, error); + return static_cast(connection->private_driver) + ->base.ConnectionCommit(connection, error); } AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection, @@ -371,8 +379,8 @@ AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection, if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } - return connection->private_driver->ConnectionGetInfo(connection, info_codes, - info_codes_length, out, error); + return static_cast(connection->private_driver) + ->base.ConnectionGetInfo(connection, info_codes, info_codes_length, out, error); } AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int depth, @@ -384,9 +392,9 @@ AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int d if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } - return connection->private_driver->ConnectionGetObjects( - connection, depth, catalog, db_schema, table_name, table_types, column_name, stream, - error); + return static_cast(connection->private_driver) + ->base.ConnectionGetObjects(connection, depth, catalog, db_schema, table_name, + table_types, column_name, stream, error); } AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection, @@ -397,8 +405,9 @@ AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection, if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } - return connection->private_driver->ConnectionGetTableSchema( - connection, catalog, db_schema, table_name, schema, error); + return static_cast(connection->private_driver) + ->base.ConnectionGetTableSchema(connection, catalog, db_schema, table_name, schema, + error); } AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection, @@ -407,7 +416,8 @@ AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection, if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } - return connection->private_driver->ConnectionGetTableTypes(connection, stream, error); + return static_cast(connection->private_driver) + ->base.ConnectionGetTableTypes(connection, stream, error); } AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection, @@ -425,16 +435,19 @@ AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection, std::unordered_map options = std::move(args->options); delete args; - auto status = database->private_driver->ConnectionNew(connection, error); + auto status = static_cast(database->private_driver) + ->base.ConnectionNew(connection, error); if (status != ADBC_STATUS_OK) return status; connection->private_driver = database->private_driver; for (const auto& option : options) { - status = database->private_driver->ConnectionSetOption( - connection, option.first.c_str(), option.second.c_str(), error); + status = static_cast(database->private_driver) + ->base.ConnectionSetOption(connection, option.first.c_str(), + option.second.c_str(), error); if (status != ADBC_STATUS_OK) return status; } - return connection->private_driver->ConnectionInit(connection, database, error); + return static_cast(connection->private_driver) + ->base.ConnectionInit(connection, database, error); } AdbcStatusCode AdbcConnectionNew(struct AdbcConnection* connection, @@ -455,8 +468,9 @@ AdbcStatusCode AdbcConnectionReadPartition(struct AdbcConnection* connection, if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } - return connection->private_driver->ConnectionReadPartition( - connection, serialized_partition, serialized_length, out, error); + return static_cast(connection->private_driver) + ->base.ConnectionReadPartition(connection, serialized_partition, serialized_length, + out, error); } AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, @@ -470,7 +484,8 @@ AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, } return ADBC_STATUS_INVALID_STATE; } - auto status = connection->private_driver->ConnectionRelease(connection, error); + auto status = static_cast(connection->private_driver) + ->base.ConnectionRelease(connection, error); connection->private_driver = nullptr; return status; } @@ -480,7 +495,8 @@ AdbcStatusCode AdbcConnectionRollback(struct AdbcConnection* connection, if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } - return connection->private_driver->ConnectionRollback(connection, error); + return static_cast(connection->private_driver) + ->base.ConnectionRollback(connection, error); } AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const char* key, @@ -495,7 +511,8 @@ AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const args->options[key] = value; return ADBC_STATUS_OK; } - return connection->private_driver->ConnectionSetOption(connection, key, value, error); + return static_cast(connection->private_driver) + ->base.ConnectionSetOption(connection, key, value, error); } AdbcStatusCode AdbcStatementBind(struct AdbcStatement* statement, @@ -504,7 +521,8 @@ AdbcStatusCode AdbcStatementBind(struct AdbcStatement* statement, if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } - return statement->private_driver->StatementBind(statement, values, schema, error); + return static_cast(statement->private_driver) + ->base.StatementBind(statement, values, schema, error); } AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement* statement, @@ -513,7 +531,8 @@ AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement* statement, if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } - return statement->private_driver->StatementBindStream(statement, stream, error); + return static_cast(statement->private_driver) + ->base.StatementBindStream(statement, stream, error); } // XXX: cpplint gets confused here if declared as 'struct ArrowSchema* schema' @@ -525,8 +544,9 @@ AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement, if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } - return statement->private_driver->StatementExecutePartitions( - statement, schema, partitions, rows_affected, error); + return static_cast(statement->private_driver) + ->base.StatementExecutePartitions(statement, schema, partitions, rows_affected, + error); } AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement, @@ -536,8 +556,8 @@ AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement, if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } - return statement->private_driver->StatementExecuteQuery(statement, out, rows_affected, - error); + return static_cast(statement->private_driver) + ->base.StatementExecuteQuery(statement, out, rows_affected, error); } AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement, @@ -546,7 +566,8 @@ AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement, if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } - return statement->private_driver->StatementGetParameterSchema(statement, schema, error); + return static_cast(statement->private_driver) + ->base.StatementGetParameterSchema(statement, schema, error); } AdbcStatusCode AdbcStatementNew(struct AdbcConnection* connection, @@ -555,7 +576,8 @@ AdbcStatusCode AdbcStatementNew(struct AdbcConnection* connection, if (!connection->private_driver) { return ADBC_STATUS_INVALID_STATE; } - auto status = connection->private_driver->StatementNew(connection, statement, error); + auto status = static_cast(connection->private_driver) + ->base.StatementNew(connection, statement, error); statement->private_driver = connection->private_driver; return status; } @@ -565,7 +587,8 @@ AdbcStatusCode AdbcStatementPrepare(struct AdbcStatement* statement, if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } - return statement->private_driver->StatementPrepare(statement, error); + return static_cast(statement->private_driver) + ->base.StatementPrepare(statement, error); } AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* statement, @@ -573,7 +596,8 @@ AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* statement, if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } - auto status = statement->private_driver->StatementRelease(statement, error); + auto status = static_cast(statement->private_driver) + ->base.StatementRelease(statement, error); statement->private_driver = nullptr; return status; } @@ -583,7 +607,8 @@ AdbcStatusCode AdbcStatementSetOption(struct AdbcStatement* statement, const cha if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } - return statement->private_driver->StatementSetOption(statement, key, value, error); + return static_cast(statement->private_driver) + ->base.StatementSetOption(statement, key, value, error); } AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement, @@ -591,7 +616,8 @@ AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement, if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } - return statement->private_driver->StatementSetSqlQuery(statement, query, error); + return static_cast(statement->private_driver) + ->base.StatementSetSqlQuery(statement, query, error); } AdbcStatusCode AdbcStatementSetSubstraitPlan(struct AdbcStatement* statement, @@ -600,8 +626,8 @@ AdbcStatusCode AdbcStatementSetSubstraitPlan(struct AdbcStatement* statement, if (!statement->private_driver) { return ADBC_STATUS_INVALID_STATE; } - return statement->private_driver->StatementSetSubstraitPlan(statement, plan, length, - error); + return static_cast(statement->private_driver) + ->base.StatementSetSubstraitPlan(statement, plan, length, error); } const char* AdbcStatusCodeMessage(AdbcStatusCode code) { @@ -640,13 +666,11 @@ AdbcStatusCode AdbcLoadDriver(const char* driver_name, const char* entrypoint, AdbcDriverInitFunc init_func; std::string error_message; - if (version != ADBC_VERSION_1_0_0) { - SetError(error, "Only ADBC 1.0.0 is supported"); + if (version != ADBC_VERSION_1_0_0 && version != ADBC_VERSION_1_1_0) { + SetError(error, "Only ADBC 1.0.0 and 1.1.0 are supported"); return ADBC_STATUS_NOT_IMPLEMENTED; } - auto* driver = reinterpret_cast(raw_driver); - if (!entrypoint) { // Default entrypoint (see adbc.h) entrypoint = "AdbcDriverInit"; @@ -730,8 +754,6 @@ AdbcStatusCode AdbcLoadDriver(const char* driver_name, const char* entrypoint, } if (!handle) { SetError(error, error_message); - // AdbcDatabaseInit tries to call this if set - driver->release = nullptr; return ADBC_STATUS_INTERNAL; } @@ -748,29 +770,46 @@ AdbcStatusCode AdbcLoadDriver(const char* driver_name, const char* entrypoint, #endif // defined(_WIN32) - AdbcStatusCode status = AdbcLoadDriverFromInitFunc(init_func, version, driver, error); - if (status == ADBC_STATUS_OK) { - ManagerDriverState* state = new ManagerDriverState; - state->driver_release = driver->release; + AdbcStatusCode status = ADBC_STATUS_OK; + if (version == ADBC_VERSION_1_0_0) { + auto* driver = reinterpret_cast(raw_driver); + status = AdbcLoadDriverFromInitFunc(init_func, version, driver, error); + if (status == ADBC_STATUS_OK) { + ManagerDriverState* state = new ManagerDriverState; + state->driver_release = driver->release; #if defined(_WIN32) - state->handle = handle; + state->handle = handle; #endif // defined(_WIN32) - driver->release = &ReleaseDriver; - driver->private_manager = state; - } else { -#if defined(_WIN32) - if (!FreeLibrary(handle)) { - std::string message = "FreeLibrary() failed: "; - GetWinError(&message); - SetError(error, message); + driver->release = &ReleaseDriver; + driver->private_manager = state; } + } else if (version == ADBC_VERSION_1_1_0) { + auto* driver = reinterpret_cast(raw_driver); + status = AdbcLoadDriverFromInitFunc(init_func, version, driver, error); + if (status == ADBC_STATUS_OK) { + ManagerDriverState* state = new ManagerDriverState; + state->driver_release = driver->base.release; +#if defined(_WIN32) + state->handle = handle; #endif // defined(_WIN32) + driver->base.release = &ReleaseDriver; + driver->base.private_manager = state; + } + } else { + SetError(error, "ADBC version not supported"); + status = ADBC_STATUS_NOT_IMPLEMENTED; + } + +#if defined(_WIN32) + if (status != ADBC_STATUS_OK && !FreeLibrary(handle)) { + std::string message = "FreeLibrary() failed: "; + GetWinError(&message); + SetError(error, message); } +#endif // defined(_WIN32) return status; } -AdbcStatusCode AdbcLoadDriverFromInitFunc(AdbcDriverInitFunc init_func, int version, - void* raw_driver, struct AdbcError* error) { #define FILL_DEFAULT(DRIVER, STUB) \ if (!DRIVER->STUB) { \ DRIVER->STUB = &STUB; \ @@ -781,44 +820,62 @@ AdbcStatusCode AdbcLoadDriverFromInitFunc(AdbcDriverInitFunc init_func, int vers return ADBC_STATUS_INTERNAL; \ } +AdbcStatusCode PolyfillDriver100(AdbcDriver* driver, AdbcError* error) { + CHECK_REQUIRED(driver, DatabaseNew); + CHECK_REQUIRED(driver, DatabaseInit); + CHECK_REQUIRED(driver, DatabaseRelease); + FILL_DEFAULT(driver, DatabaseSetOption); + + CHECK_REQUIRED(driver, ConnectionNew); + CHECK_REQUIRED(driver, ConnectionInit); + CHECK_REQUIRED(driver, ConnectionRelease); + FILL_DEFAULT(driver, ConnectionCommit); + FILL_DEFAULT(driver, ConnectionGetInfo); + FILL_DEFAULT(driver, ConnectionGetObjects); + FILL_DEFAULT(driver, ConnectionGetTableSchema); + FILL_DEFAULT(driver, ConnectionGetTableTypes); + FILL_DEFAULT(driver, ConnectionReadPartition); + FILL_DEFAULT(driver, ConnectionRollback); + FILL_DEFAULT(driver, ConnectionSetOption); + + FILL_DEFAULT(driver, StatementExecutePartitions); + CHECK_REQUIRED(driver, StatementExecuteQuery); + CHECK_REQUIRED(driver, StatementNew); + CHECK_REQUIRED(driver, StatementRelease); + FILL_DEFAULT(driver, StatementBind); + FILL_DEFAULT(driver, StatementGetParameterSchema); + FILL_DEFAULT(driver, StatementPrepare); + FILL_DEFAULT(driver, StatementSetOption); + FILL_DEFAULT(driver, StatementSetSqlQuery); + FILL_DEFAULT(driver, StatementSetSubstraitPlan); + return ADBC_STATUS_OK; +} + +AdbcStatusCode PolyfillDriver110(AdbcDriver110* driver, AdbcError* error) { + // No new functions yet + return PolyfillDriver100(&driver->base, error); +} + +AdbcStatusCode AdbcLoadDriverFromInitFunc(AdbcDriverInitFunc init_func, int version, + void* raw_driver, struct AdbcError* error) { auto result = init_func(version, raw_driver, error); + if (version == ADBC_VERSION_1_1_0 && result == ADBC_STATUS_NOT_IMPLEMENTED) { + auto* driver = reinterpret_cast(raw_driver); + result = init_func(ADBC_VERSION_1_0_0, &driver->base, error); + } if (result != ADBC_STATUS_OK) { return result; } if (version == ADBC_VERSION_1_0_0) { auto* driver = reinterpret_cast(raw_driver); - CHECK_REQUIRED(driver, DatabaseNew); - CHECK_REQUIRED(driver, DatabaseInit); - CHECK_REQUIRED(driver, DatabaseRelease); - FILL_DEFAULT(driver, DatabaseSetOption); - - CHECK_REQUIRED(driver, ConnectionNew); - CHECK_REQUIRED(driver, ConnectionInit); - CHECK_REQUIRED(driver, ConnectionRelease); - FILL_DEFAULT(driver, ConnectionCommit); - FILL_DEFAULT(driver, ConnectionGetInfo); - FILL_DEFAULT(driver, ConnectionGetObjects); - FILL_DEFAULT(driver, ConnectionGetTableSchema); - FILL_DEFAULT(driver, ConnectionGetTableTypes); - FILL_DEFAULT(driver, ConnectionReadPartition); - FILL_DEFAULT(driver, ConnectionRollback); - FILL_DEFAULT(driver, ConnectionSetOption); - - FILL_DEFAULT(driver, StatementExecutePartitions); - CHECK_REQUIRED(driver, StatementExecuteQuery); - CHECK_REQUIRED(driver, StatementNew); - CHECK_REQUIRED(driver, StatementRelease); - FILL_DEFAULT(driver, StatementBind); - FILL_DEFAULT(driver, StatementGetParameterSchema); - FILL_DEFAULT(driver, StatementPrepare); - FILL_DEFAULT(driver, StatementSetOption); - FILL_DEFAULT(driver, StatementSetSqlQuery); - FILL_DEFAULT(driver, StatementSetSubstraitPlan); + return PolyfillDriver100(driver, error); + } else if (version == ADBC_VERSION_1_1_0) { + auto* driver = reinterpret_cast(raw_driver); + return PolyfillDriver110(driver, error); } - return ADBC_STATUS_OK; +} #undef FILL_DEFAULT #undef CHECK_REQUIRED -} diff --git a/c/driver_manager/adbc_driver_manager_test.cc b/c/driver_manager/adbc_driver_manager_test.cc index 99fa477bfa..4882234364 100644 --- a/c/driver_manager/adbc_driver_manager_test.cc +++ b/c/driver_manager/adbc_driver_manager_test.cc @@ -157,6 +157,36 @@ TEST_F(DriverManager, MultiDriverTest) { error->release(&error.value); } +class AdbcVersion : public ::testing::Test { + public: + void SetUp() override { + std::memset(&driver, 0, sizeof(driver)); + std::memset(&error, 0, sizeof(error)); + } + + void TearDown() override { + if (error.release) { + error.release(&error); + } + + if (driver.base.release) { + ASSERT_THAT(driver.base.release(&driver.base, &error), IsOkStatus(&error)); + ASSERT_EQ(driver.base.private_data, nullptr); + ASSERT_EQ(driver.base.private_manager, nullptr); + } + } + + protected: + struct AdbcDriver110 driver = {}; + struct AdbcError error = {}; +}; + +TEST_F(DriverManager, AdbcVersionNotSupported) { + ASSERT_THAT( + AdbcLoadDriver("adbc_driver_sqlite", nullptr, ADBC_VERSION_1_1_0, &driver, &error), + IsOkStatus(&error)); +} + class SqliteQuirks : public adbc_validation::DriverQuirks { public: AdbcStatusCode SetupDatabase(struct AdbcDatabase* database,