From ef76b33080d9aac0c1024ca7acb79320994ce155 Mon Sep 17 00:00:00 2001 From: Casper van der Wel Date: Tue, 12 Sep 2023 14:49:38 +0200 Subject: [PATCH] Add SQLDatabase.truncate_tables (#14) --- clean_python/sql/sql_provider.py | 5 +++++ integration_tests/test_sql_database.py | 14 +++++++++++--- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/clean_python/sql/sql_provider.py b/clean_python/sql/sql_provider.py index 19adde6..4df16bc 100644 --- a/clean_python/sql/sql_provider.py +++ b/clean_python/sql/sql_provider.py @@ -6,6 +6,7 @@ from typing import Dict from typing import List from typing import Optional +from typing import Sequence from sqlalchemy import text from sqlalchemy.exc import DBAPIError @@ -83,6 +84,10 @@ async def create_extension(self, name: str) -> None: async def drop_database(self, name: str) -> None: await self._execute_autocommit(text(f"DROP DATABASE IF EXISTS {name}")) + async def truncate_tables(self, names: Sequence[str]) -> None: + quoted = [f'"{x}"' for x in names] + await self._execute_autocommit(text(f"TRUNCATE TABLE {', '.join(quoted)}")) + class SQLTransaction(SQLProvider): def __init__(self, connection: AsyncConnection): diff --git a/integration_tests/test_sql_database.py b/integration_tests/test_sql_database.py index e5cd0ed..307f206 100644 --- a/integration_tests/test_sql_database.py +++ b/integration_tests/test_sql_database.py @@ -59,10 +59,10 @@ async def database(postgres_url): @pytest.fixture -async def database_with_cleanup(database): - await database.execute(text("DELETE FROM test_model WHERE TRUE RETURNING id")) +async def database_with_cleanup(database: SQLDatabase): + await database.truncate_tables(["test_model"]) yield database - await database.execute(text("DELETE FROM test_model WHERE TRUE RETURNING id")) + await database.truncate_tables(["test_model"]) @pytest.fixture @@ -367,3 +367,11 @@ async def test_count(filters, expected, sql_gateway, obj_in_db, obj2_in_db): async def test_exists(filters, expected, sql_gateway, obj_in_db, obj2_in_db): actual = await sql_gateway.exists(filters) assert actual == expected + + +async def test_truncate(database: SQLDatabase, obj): + gateway = TstSQLGateway(database) + await gateway.add(obj) + assert await gateway.exists([]) + await database.truncate_tables(["test_model"]) + assert not await gateway.exists([])