Skip to content

Commit

Permalink
Add SQLDatabase.truncate_tables (#14)
Browse files Browse the repository at this point in the history
  • Loading branch information
caspervdw authored Sep 12, 2023
1 parent 73adcd7 commit ef76b33
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
5 changes: 5 additions & 0 deletions clean_python/sql/sql_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Dict
from typing import List
from typing import Optional
from typing import Sequence

from sqlalchemy import text
from sqlalchemy.exc import DBAPIError
Expand Down Expand Up @@ -83,6 +84,10 @@ async def create_extension(self, name: str) -> None:
async def drop_database(self, name: str) -> None:
await self._execute_autocommit(text(f"DROP DATABASE IF EXISTS {name}"))

async def truncate_tables(self, names: Sequence[str]) -> None:
quoted = [f'"{x}"' for x in names]
await self._execute_autocommit(text(f"TRUNCATE TABLE {', '.join(quoted)}"))


class SQLTransaction(SQLProvider):
def __init__(self, connection: AsyncConnection):
Expand Down
14 changes: 11 additions & 3 deletions integration_tests/test_sql_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,10 @@ async def database(postgres_url):


@pytest.fixture
async def database_with_cleanup(database):
await database.execute(text("DELETE FROM test_model WHERE TRUE RETURNING id"))
async def database_with_cleanup(database: SQLDatabase):
await database.truncate_tables(["test_model"])
yield database
await database.execute(text("DELETE FROM test_model WHERE TRUE RETURNING id"))
await database.truncate_tables(["test_model"])


@pytest.fixture
Expand Down Expand Up @@ -367,3 +367,11 @@ async def test_count(filters, expected, sql_gateway, obj_in_db, obj2_in_db):
async def test_exists(filters, expected, sql_gateway, obj_in_db, obj2_in_db):
actual = await sql_gateway.exists(filters)
assert actual == expected


async def test_truncate(database: SQLDatabase, obj):
gateway = TstSQLGateway(database)
await gateway.add(obj)
assert await gateway.exists([])
await database.truncate_tables(["test_model"])
assert not await gateway.exists([])

0 comments on commit ef76b33

Please sign in to comment.