Skip to content

Commit

Permalink
Remove context manager transactions
Browse files Browse the repository at this point in the history
  • Loading branch information
jph00 committed Oct 22, 2024
1 parent f40c1e6 commit 19e6f31
Show file tree
Hide file tree
Showing 36 changed files with 3,869 additions and 103 deletions.
94 changes: 55 additions & 39 deletions nbs/index.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -431,13 +431,8 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"If you have any `db.query()` or `db.execute()` method calls outside a formal transaction, they are committed instantly."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you have any SQL calls outside an explicit transaction, they are committed instantly.\n",
"\n",
"To group 2 or more queries together into 1 transaction, wrap them in a BEGIN and COMMIT, executing ROLLBACK if an exception is caught: "
]
},
Expand All @@ -449,11 +444,7 @@
{
"data": {
"text/plain": [
"[(1, 'Raven', 8, 's3cret'),\n",
" (2, 'Magpie', 5, 'supersecret'),\n",
" (3, 'Crow', 12, 'verysecret'),\n",
" (4, 'Pigeon', 3, 'keptsecret'),\n",
" (5, 'Eagle', 7, 's3cr3t')]"
"{'id': 1, 'name': 'Raven', 'age': 8, 'pwd': 's3cret'}"
]
},
"execution_count": null,
Expand All @@ -462,7 +453,7 @@
}
],
"source": [
"list(db.execute('SELECT * FROM users'))"
"users.get(1)"
]
},
{
Expand All @@ -474,21 +465,26 @@
"name": "stdout",
"output_type": "stream",
"text": [
"[(1, 'Raven', 8, 's3cret'), (2, 'Magpie', 5, 'supersecret'), (3, 'Crow', 12, 'verysecret'), (4, 'Pigeon', 3, 'keptsecret'), (5, 'Eagle', 7, 's3cr3t')]\n",
"near \"FNOOORD\": syntax error\n"
]
}
],
"source": [
"print(list(db.execute('SELECT * FROM users')))\n",
"db.execute('BEGIN')\n",
"db.begin()\n",
"try:\n",
" db.execute('DELETE FROM Users WHERE id = ?', [1])\n",
" users.delete([1])\n",
" db.execute('FNOOORD')\n",
" db.execute('COMMIT')\n",
" db.commit()\n",
"except Exception as e:\n",
" print(e)\n",
" db.execute('ROLLBACK')"
" db.rollback()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because the transaction was rolled back, the user was not deleted:"
]
},
{
Expand All @@ -497,23 +493,38 @@
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[{'id': 1, 'name': 'Raven', 'age': 8, 'pwd': 's3cret'}, {'id': 2, 'name': 'Magpie', 'age': 5, 'pwd': 'supersecret'}, {'id': 3, 'name': 'Crow', 'age': 12, 'pwd': 'verysecret'}, {'id': 4, 'name': 'Pigeon', 'age': 3, 'pwd': 'keptsecret'}, {'id': 5, 'name': 'Eagle', 'age': 7, 'pwd': 's3cr3t'}]\n"
]
"data": {
"text/plain": [
"{'id': 1, 'name': 'Raven', 'age': 8, 'pwd': 's3cret'}"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"print(list(db.query('SELECT * FROM users')))\n",
"db.execute('BEGIN')\n",
"users.get(1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's do it again, but without the DB error, to check the transaction is successful:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"db.begin()\n",
"try:\n",
" db.execute('DELETE FROM Users WHERE id = ?', [2])\n",
" users.insert({'id': 6, 'name': 'Ravens', 'age': 8, 'pwd': 's3cret'})\n",
" db.execute('COMMIT')\n",
"except Exception as e:\n",
" print(e)\n",
" db.execute('ROLLBACK')"
" users.delete([1])\n",
" db.commit()\n",
"except Exception as e: db.rollback()"
]
},
{
Expand All @@ -525,18 +536,23 @@
"name": "stdout",
"output_type": "stream",
"text": [
"{'id': 1, 'name': 'Raven', 'age': 8, 'pwd': 's3cret'}\n",
"{'id': 3, 'name': 'Crow', 'age': 12, 'pwd': 'verysecret'}\n",
"{'id': 4, 'name': 'Pigeon', 'age': 3, 'pwd': 'keptsecret'}\n",
"{'id': 5, 'name': 'Eagle', 'age': 7, 'pwd': 's3cr3t'}\n",
"{'id': 6, 'name': 'Ravens', 'age': 8, 'pwd': 's3cret'}\n"
"Delete succeeded!\n"
]
}
],
"source": [
"for x in list(db.query('SELECT * FROM users')):\n",
" print(x)"
"try:\n",
" users.get(1)\n",
" print(\"Delete failed!\")\n",
"except: print(\"Delete succeeded!\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
129 changes: 65 additions & 64 deletions sqlite_minutils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,18 @@ def table_names(self, fts5: bool = False) -> List[str]:
sql = "select name from sqlite_master where {}".format(" AND ".join(where))
return [r[0] for r in self.execute(sql).fetchall()]

def begin(self):
"Begin a transaction"
return self.execute('BEGIN')

def commit(self):
"Commit a transaction"
return self.execute('COMMIT')

def rollback(self):
"Roll back a transaction"
return self.execute('ROLLBACK')

def view_names(self) -> List[str]:
"List of string view names in this database."
return [
Expand Down Expand Up @@ -593,11 +605,10 @@ def supports_strict(self) -> bool:
"Does this database support STRICT mode?"
try:
table_name = "t{}".format(secrets.token_hex(16))
with self.conn:
self.conn.execute(
"create table {} (name text) strict".format(table_name)
)
self.conn.execute("drop table {}".format(table_name))
self.conn.execute(
"create table {} (name text) strict".format(table_name)
)
self.conn.execute("drop table {}".format(table_name))
return True
except Exception:
return False
Expand Down Expand Up @@ -632,8 +643,7 @@ def disable_wal(self):
self.execute("PRAGMA journal_mode=delete;")

def _ensure_counts_table(self):
with self.conn:
self.execute(_COUNTS_TABLE_CREATE_SQL.format(self._counts_table_name))
self.execute(_COUNTS_TABLE_CREATE_SQL.format(self._counts_table_name))

def enable_counts(self):
"""
Expand Down Expand Up @@ -667,14 +677,13 @@ def cached_counts(self, tables: Optional[Iterable[str]] = None) -> Dict[str, int
def reset_counts(self):
"Re-calculate cached counts for tables."
tables = [table for table in self.tables if table.has_counts_triggers]
with self.conn:
self._ensure_counts_table()
counts_table = self[self._counts_table_name]
counts_table.delete_where()
counts_table.insert_all(
{"table": table.name, "count": table.execute_count()}
for table in tables
)
self._ensure_counts_table()
counts_table = self[self._counts_table_name]
counts_table.delete_where()
counts_table.insert_all(
{"table": table.name, "count": table.execute_count()}
for table in tables
)

def execute_returning_dicts(
self, sql: str, params: Optional[Union[Iterable, dict]] = None
Expand Down Expand Up @@ -1623,24 +1632,23 @@ def create(
:param strict: Apply STRICT mode to table
"""
columns = {name: value for (name, value) in columns.items()}
with self.db.conn:
self.db.create_table(
self.name,
columns,
pk=pk,
foreign_keys=foreign_keys,
column_order=column_order,
not_null=not_null,
defaults=defaults,
hash_id=hash_id,
hash_id_columns=hash_id_columns,
extracts=extracts,
if_not_exists=if_not_exists,
replace=replace,
ignore=ignore,
transform=transform,
strict=strict,
)
self.db.create_table(
self.name,
columns,
pk=pk,
foreign_keys=foreign_keys,
column_order=column_order,
not_null=not_null,
defaults=defaults,
hash_id=hash_id,
hash_id_columns=hash_id_columns,
extracts=extracts,
if_not_exists=if_not_exists,
replace=replace,
ignore=ignore,
transform=transform,
strict=strict,
)
return self

def duplicate(self, new_name: str) -> "Table":
Expand All @@ -1651,12 +1659,11 @@ def duplicate(self, new_name: str) -> "Table":
"""
if not self.exists():
raise NoTable(f"Table {self.name} does not exist")
with self.db.conn:
sql = "CREATE TABLE [{new_table}] AS SELECT * FROM [{table}];".format(
new_table=new_name,
table=self.name,
)
self.db.execute(sql)
sql = "CREATE TABLE [{new_table}] AS SELECT * FROM [{table}];".format(
new_table=new_name,
table=self.name,
)
self.db.execute(sql)
return self.db[new_name]

def transform(
Expand Down Expand Up @@ -1714,12 +1721,11 @@ def transform(
try:
if pragma_foreign_keys_was_on:
self.db.execute("PRAGMA foreign_keys=0;")
with self.db.conn:
for sql in sqls:
self.db.execute(sql)
# Run the foreign_key_check before we commit
if pragma_foreign_keys_was_on:
self.db.execute("PRAGMA foreign_key_check;")
for sql in sqls:
self.db.execute(sql)
# Run the foreign_key_check before we commit
if pragma_foreign_keys_was_on:
self.db.execute("PRAGMA foreign_key_check;")
finally:
if pragma_foreign_keys_was_on:
self.db.execute("PRAGMA foreign_keys=1;")
Expand Down Expand Up @@ -2317,8 +2323,7 @@ def enable_counts(self):
table_quoted=self.db.quote(self.name),
)
)
with self.db.conn:
self.db.conn.executescript(sql)
self.db.conn.executescript(sql)
self.db.use_counts_table = True

@property
Expand Down Expand Up @@ -2460,9 +2465,8 @@ def disable_fts(self) -> "Table":
trigger_names = []
for row in self.db.execute(sql).fetchall():
trigger_names.append(row[0])
with self.db.conn:
for trigger_name in trigger_names:
self.db.execute("DROP TRIGGER IF EXISTS [{}]".format(trigger_name))
for trigger_name in trigger_names:
self.db.execute("DROP TRIGGER IF EXISTS [{}]".format(trigger_name))
return self

def rebuild_fts(self):
Expand Down Expand Up @@ -2655,8 +2659,7 @@ def delete(self, pk_values: Union[list, tuple, str, int, float]) -> "Table":
sql = "delete from [{table}] where {wheres}".format(
table=self.name, wheres=" and ".join(wheres)
)
with self.db.conn:
self.db.execute(sql, pk_values)
self.db.execute(sql, pk_values)
return self

def delete_where(
Expand Down Expand Up @@ -2725,19 +2728,18 @@ def update(
sql = "update [{table}] set {sets} where {wheres}".format(
table=self.name, sets=", ".join(sets), wheres=" and ".join(wheres)
)
with self.db.conn:
try:
try:
rowcount = self.db.execute(sql, args).rowcount
except OperationalError as e:
if alter and (" column" in e.args[0]):
# Attempt to add any missing columns, then try again
self.add_missing_columns([updates])
rowcount = self.db.execute(sql, args).rowcount
except OperationalError as e:
if alter and (" column" in e.args[0]):
# Attempt to add any missing columns, then try again
self.add_missing_columns([updates])
rowcount = self.db.execute(sql, args).rowcount
else:
raise
else:
raise

# TODO: Test this works (rolls back) - use better exception:
# assert rowcount == 1
# TODO: Test this works (rolls back) - use better exception:
# assert rowcount == 1
self.last_pk = pk_values[0] if len(pks) == 1 else pk_values
return self

Expand Down Expand Up @@ -2885,7 +2887,6 @@ def insert_chunk(
ignore,
)

# with self.db.conn:
result = None
for query, params in queries_and_params:
try:
Expand Down
Empty file added tests/__init__.py
Empty file.
Loading

0 comments on commit 19e6f31

Please sign in to comment.