Skip to content

Commit

Permalink
raise OperationalError if table not exists
Browse files Browse the repository at this point in the history
  • Loading branch information
truethari committed Feb 21, 2022
1 parent 35c0869 commit d4484b3
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 59 deletions.
126 changes: 71 additions & 55 deletions ReallySimpleDB/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,17 @@ def __init__(self) -> None:
self._add_columns_cmd = ""
self.connection = ""

def _not_table(self, table:str):
raise sqlite3.OperationalError("no such table: {}".format(table))

def clean(self):
self._add_columns_cmd = ""

def create_connection(self, database):
self.connection = sqlite3.connect(database)

def create_db(self, dbpath:str="", replace:bool=False):
if self.connection == "" and not len(dbpath):
if self.connection == "" and not dbpath:
raise TypeError("create_db() missing 1 required positional argument: 'dbpath'")

if replace:
Expand Down Expand Up @@ -65,10 +68,10 @@ def add_columns(self,
return True

def create_table(self, table_name:str, database:str=""):
if self.connection == "" and not len(database):
if self.connection == "" and not database:
raise TypeError("create_table() missing 1 required positional argument: 'database'")

if len(database):
if database:
self.create_connection(database)

if self._add_columns_cmd == "":
Expand All @@ -80,51 +83,53 @@ def create_table(self, table_name:str, database:str=""):
return True

def all_tables(self, database:str=""):
if self.connection == "" and not len(database):
if self.connection == "" and not database:
raise TypeError("all_tables() missing 1 required positional argument: 'database'")

if len(database):
if database:
self.create_connection(database)

cursor = self.connection.cursor()
sql_cmd = "SELECT name FROM sqlite_master WHERE type='table';"
return [student[0] for student in cursor.execute(sql_cmd)]

def is_table(self, table_name:str, database:str=""):
if self.connection == "" and not len(database):
if self.connection == "" and not database:
raise TypeError("is_table() missing 1 required positional argument: 'database'")

if len(database):
if database:
self.create_connection(database)

if table_name in self.all_tables(database):
return True
return False

def delete_table(self, table:str, database:str=""):
if self.connection == "" and not len(database):
if self.connection == "" and not database:
raise TypeError("delete_table() missing 1 required positional argument: 'database'")

if len(database):
if database:
self.create_connection(database)

if self.is_table(table_name=table):
cursor = self.connection.cursor()
sql_cmd = "DROP TABLE {};".format(table)
cursor.execute(sql_cmd)

return True

return False
self._not_table(table=table)

def get_all_column_types(self, table:str, database:str=""):
if self.connection == "" and not len(database):
if self.connection == "" and not database:
raise TypeError("get_all_column_types() missing 1 required positional argument: 'database'")

if len(database):
if database:
self.create_connection(database)

if self.is_table(table_name=table, database=database):
cursor = self.connection.cursor()

sql_cmd = "PRAGMA TABLE_INFO({});".format(table)
fetch = cursor.execute(sql_cmd)

Expand All @@ -133,19 +138,20 @@ def get_all_column_types(self, table:str, database:str=""):
data_dict[data[1]] = data[2]

return data_dict
return False

self._not_table(table=table)

def get_column_type(self, table:str, column:str, database:str=""):
all_data = self.get_all_column_types(table=table, database=database)
if type(all_data) != bool and column in all_data:
if (not isinstance(all_data, bool)) and (column in all_data):
return all_data[column]
return False

def get_columns(self, table:str, database:str=""):
if self.connection == "" and not len(database):
if self.connection == "" and not database:
raise TypeError("get_columns() missing 1 required positional argument: 'database'")

if len(database):
if database:
self.create_connection(database)

column_types = self.get_all_column_types(table=table, database=database)
Expand All @@ -157,10 +163,10 @@ def get_columns(self, table:str, database:str=""):
return columns

def get_primary_key(self, table:str, database:str=""):
if self.connection == "" and not len(database):
if self.connection == "" and not database:
raise TypeError("get_primary_key() missing 1 required positional argument: 'database'")

if len(database):
if database:
self.create_connection(database)

if self.is_table(table_name=table, database=database):
Expand All @@ -171,53 +177,55 @@ def get_primary_key(self, table:str, database:str=""):

return fetch.fetchall()[0][1]

return False
self._not_table(table=table)

def add_record(self, table:str, record, database:str=""):
if self.connection == "" and not len(database):
if self.connection == "" and not database:
raise TypeError("add_record() missing 1 required positional argument: 'database'")

if len(database):
if database:
self.create_connection(database)

if self.is_table(table_name=table, database=database):
cursor = self.connection.cursor()

tmp_all_columns = self.get_all_column_types(table=table, database=database)
all_columns = {}
for column in tmp_all_columns:
all_columns[column] = ""

fields = []
sql_cmd = "INSERT INTO {} VALUES(".format(table)
if isinstance(record, dict):
for field in record:
if field not in all_columns:
raise NameError("'{}' column is not in the table".format(field))
else:
tmp_all_columns = self.get_all_column_types(table=table, database=database)
all_columns = {}
for column in tmp_all_columns:
all_columns[column] = ""

fields = []
sql_cmd = "INSERT INTO {} VALUES(".format(table)
if isinstance(record, dict):
for field in record:
if field not in all_columns:
raise NameError("'{}' column is not in the table".format(field))

if DATA_TYPES[tmp_all_columns[field]] == type(record[field]):
all_columns[field] = record[field]
else:
raise TypeError("The '{}' field requires the '{}' type but got the '{}' type".format(field, DATA_TYPES[tmp_all_columns[field]], type(record[field])))

for field in all_columns:
fields.append(all_columns[field])
sql_cmd+= "?,"
for field in all_columns:
fields.append(all_columns[field])
sql_cmd+= "?,"

sql_cmd = sql_cmd[:-1] + ");"
sql_cmd = sql_cmd[:-1] + ");"

cursor.execute(sql_cmd, fields)
self.connection.commit()
else:
raise TypeError("'record' must be dict")
cursor.execute(sql_cmd, fields)
self.connection.commit()
else:
raise TypeError("'record' must be dict")

return True
return True

self._not_table(table=table)

def get_record(self, table:str, primary_key, database:str=""):
if self.connection == "" and not len(database):
if self.connection == "" and not database:
raise TypeError("get_record() missing 1 required positional argument: 'database'")

if len(database):
if database:
self.create_connection(database)

if self.is_table(table_name=table, database=database):
Expand All @@ -236,12 +244,14 @@ def get_record(self, table:str, primary_key, database:str=""):
return {}

return record


self._not_table(table=table)

def get_all_records(self, table:str, database:str=""):
if self.connection == "" and not len(database):
raise TypeError("get_record() missing 1 required positional argument: 'database'")
if self.connection == "" and not database:
raise TypeError("get_all_records() missing 1 required positional argument: 'database'")

if len(database):
if database:
self.create_connection(database)

if self.is_table(table_name=table, database=database):
Expand All @@ -263,11 +273,13 @@ def get_all_records(self, table:str, database:str=""):

return records

self._not_table(table=table)

def delete_record(self, table:str, primary_key, database:str=""):
if self.connection == "" and not len(database):
if self.connection == "" and not database:
raise TypeError("delete_record() missing 1 required positional argument: 'database'")

if len(database):
if database:
self.create_connection(database)

if self.is_table(table_name=table, database=database):
Expand All @@ -276,13 +288,15 @@ def delete_record(self, table:str, primary_key, database:str=""):
cursor.execute(sql, (primary_key,))
self.connection.commit()

return True

return True

self._not_table(table=table)

def filter_records(self, table:str, values:dict, database:str=""):
if self.connection == "" and not len(database):
raise TypeError("delete_record() missing 1 required positional argument: 'database'")
if self.connection == "" and not database:
raise TypeError("filter_records() missing 1 required positional argument: 'database'")

if len(database):
if database:
self.create_connection(database)

if self.is_table(table_name=table, database=database):
Expand Down Expand Up @@ -313,6 +327,8 @@ def filter_records(self, table:str, values:dict, database:str=""):

return records

self._not_table(table=table)

def close_connection(self):
self.connection.close()
return True
27 changes: 23 additions & 4 deletions tests/test_manager.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from sqlite3 import OperationalError
from ReallySimpleDB import dbmanager

_dbmanager = dbmanager()
Expand Down Expand Up @@ -58,10 +59,16 @@ def test_is_table_4():
assert not _dbmanager.is_table(table_name="NON")

def test_delete_table_1():
assert _dbmanager.delete_table(table="EMPLOYEES")
try:
_dbmanager.delete_table(table="EMPLOYEES")
except OperationalError:
assert True

def test_delete_table_2():
assert not _dbmanager.delete_table(table="EMPLOYEES")
try:
_dbmanager.delete_table(table="EMPLOYEES")
except OperationalError:
assert True

def test_get_all_column_types_1():
assert _dbmanager.get_all_column_types(table="STUDENTS") \
Expand All @@ -78,13 +85,19 @@ def test_get_column_type_2():
assert _dbmanager.get_column_type(table="STUDENTS", column="address") == False

def test_get_column_type_3():
assert _dbmanager.get_column_type(table="EMPLOYEES", column="emp_id") == False
try:
_dbmanager.get_column_type(table="EMPLOYEES", column="emp_id")
except OperationalError:
assert True

def test_get_columns_1():
assert _dbmanager.get_columns(table="STUDENTS") == ["student_id", "name", "mark", "year"]

def test_get_columns_2():
assert _dbmanager.get_columns(table="EMPLOYEES") == []
try:
_dbmanager.get_columns(table="EMPLOYEES")
except OperationalError:
assert True

def test_get_primary_key_1():
assert _dbmanager.get_primary_key(table="STUDENTS") == "student_id"
Expand Down Expand Up @@ -125,6 +138,12 @@ def test_filter_record_2():
def test_delete_record_1():
assert _dbmanager.delete_record(table="STUDENTS", primary_key="1010")

def test_delete_record_2():
try:
_dbmanager.delete_record(table="STUDENTSS", primary_key="1010")
except OperationalError:
assert True

def test_finally():
delete_db()

Expand Down

0 comments on commit d4484b3

Please sign in to comment.