From d4484b3bfcc07ce862c9cfa595d00e4b1a83145d Mon Sep 17 00:00:00 2001 From: Tharindu N Date: Mon, 21 Feb 2022 06:58:47 +0530 Subject: [PATCH] raise OperationalError if table not exists --- ReallySimpleDB/manager.py | 126 +++++++++++++++++++++----------------- tests/test_manager.py | 27 ++++++-- 2 files changed, 94 insertions(+), 59 deletions(-) diff --git a/ReallySimpleDB/manager.py b/ReallySimpleDB/manager.py index 7309d96..bc818b3 100644 --- a/ReallySimpleDB/manager.py +++ b/ReallySimpleDB/manager.py @@ -8,6 +8,9 @@ def __init__(self) -> None: self._add_columns_cmd = "" self.connection = "" + def _not_table(self, table:str): + raise sqlite3.OperationalError("no such table: {}".format(table)) + def clean(self): self._add_columns_cmd = "" @@ -15,7 +18,7 @@ def create_connection(self, database): self.connection = sqlite3.connect(database) def create_db(self, dbpath:str="", replace:bool=False): - if self.connection == "" and not len(dbpath): + if self.connection == "" and not dbpath: raise TypeError("create_db() missing 1 required positional argument: 'dbpath'") if replace: @@ -65,10 +68,10 @@ def add_columns(self, return True def create_table(self, table_name:str, database:str=""): - if self.connection == "" and not len(database): + if self.connection == "" and not database: raise TypeError("create_table() missing 1 required positional argument: 'database'") - if len(database): + if database: self.create_connection(database) if self._add_columns_cmd == "": @@ -80,10 +83,10 @@ def create_table(self, table_name:str, database:str=""): return True def all_tables(self, database:str=""): - if self.connection == "" and not len(database): + if self.connection == "" and not database: raise TypeError("all_tables() missing 1 required positional argument: 'database'") - if len(database): + if database: self.create_connection(database) cursor = self.connection.cursor() @@ -91,10 +94,10 @@ def all_tables(self, database:str=""): return [student[0] for student in cursor.execute(sql_cmd)] def is_table(self, table_name:str, database:str=""): - if self.connection == "" and not len(database): + if self.connection == "" and not database: raise TypeError("is_table() missing 1 required positional argument: 'database'") - if len(database): + if database: self.create_connection(database) if table_name in self.all_tables(database): @@ -102,29 +105,31 @@ def is_table(self, table_name:str, database:str=""): return False def delete_table(self, table:str, database:str=""): - if self.connection == "" and not len(database): + if self.connection == "" and not database: raise TypeError("delete_table() missing 1 required positional argument: 'database'") - if len(database): + if database: self.create_connection(database) if self.is_table(table_name=table): cursor = self.connection.cursor() sql_cmd = "DROP TABLE {};".format(table) cursor.execute(sql_cmd) + return True - return False + self._not_table(table=table) def get_all_column_types(self, table:str, database:str=""): - if self.connection == "" and not len(database): + if self.connection == "" and not database: raise TypeError("get_all_column_types() missing 1 required positional argument: 'database'") - if len(database): + if database: self.create_connection(database) if self.is_table(table_name=table, database=database): cursor = self.connection.cursor() + sql_cmd = "PRAGMA TABLE_INFO({});".format(table) fetch = cursor.execute(sql_cmd) @@ -133,19 +138,20 @@ def get_all_column_types(self, table:str, database:str=""): data_dict[data[1]] = data[2] return data_dict - return False + + self._not_table(table=table) def get_column_type(self, table:str, column:str, database:str=""): all_data = self.get_all_column_types(table=table, database=database) - if type(all_data) != bool and column in all_data: + if (not isinstance(all_data, bool)) and (column in all_data): return all_data[column] return False def get_columns(self, table:str, database:str=""): - if self.connection == "" and not len(database): + if self.connection == "" and not database: raise TypeError("get_columns() missing 1 required positional argument: 'database'") - if len(database): + if database: self.create_connection(database) column_types = self.get_all_column_types(table=table, database=database) @@ -157,10 +163,10 @@ def get_columns(self, table:str, database:str=""): return columns def get_primary_key(self, table:str, database:str=""): - if self.connection == "" and not len(database): + if self.connection == "" and not database: raise TypeError("get_primary_key() missing 1 required positional argument: 'database'") - if len(database): + if database: self.create_connection(database) if self.is_table(table_name=table, database=database): @@ -171,53 +177,55 @@ def get_primary_key(self, table:str, database:str=""): return fetch.fetchall()[0][1] - return False + self._not_table(table=table) def add_record(self, table:str, record, database:str=""): - if self.connection == "" and not len(database): + if self.connection == "" and not database: raise TypeError("add_record() missing 1 required positional argument: 'database'") - if len(database): + if database: self.create_connection(database) if self.is_table(table_name=table, database=database): cursor = self.connection.cursor() - tmp_all_columns = self.get_all_column_types(table=table, database=database) - all_columns = {} - for column in tmp_all_columns: - all_columns[column] = "" - - fields = [] - sql_cmd = "INSERT INTO {} VALUES(".format(table) - if isinstance(record, dict): - for field in record: - if field not in all_columns: - raise NameError("'{}' column is not in the table".format(field)) - else: + tmp_all_columns = self.get_all_column_types(table=table, database=database) + all_columns = {} + for column in tmp_all_columns: + all_columns[column] = "" + + fields = [] + sql_cmd = "INSERT INTO {} VALUES(".format(table) + if isinstance(record, dict): + for field in record: + if field not in all_columns: + raise NameError("'{}' column is not in the table".format(field)) + if DATA_TYPES[tmp_all_columns[field]] == type(record[field]): all_columns[field] = record[field] else: raise TypeError("The '{}' field requires the '{}' type but got the '{}' type".format(field, DATA_TYPES[tmp_all_columns[field]], type(record[field]))) - for field in all_columns: - fields.append(all_columns[field]) - sql_cmd+= "?," + for field in all_columns: + fields.append(all_columns[field]) + sql_cmd+= "?," - sql_cmd = sql_cmd[:-1] + ");" + sql_cmd = sql_cmd[:-1] + ");" - cursor.execute(sql_cmd, fields) - self.connection.commit() - else: - raise TypeError("'record' must be dict") + cursor.execute(sql_cmd, fields) + self.connection.commit() + else: + raise TypeError("'record' must be dict") - return True + return True + + self._not_table(table=table) def get_record(self, table:str, primary_key, database:str=""): - if self.connection == "" and not len(database): + if self.connection == "" and not database: raise TypeError("get_record() missing 1 required positional argument: 'database'") - if len(database): + if database: self.create_connection(database) if self.is_table(table_name=table, database=database): @@ -236,12 +244,14 @@ def get_record(self, table:str, primary_key, database:str=""): return {} return record - + + self._not_table(table=table) + def get_all_records(self, table:str, database:str=""): - if self.connection == "" and not len(database): - raise TypeError("get_record() missing 1 required positional argument: 'database'") + if self.connection == "" and not database: + raise TypeError("get_all_records() missing 1 required positional argument: 'database'") - if len(database): + if database: self.create_connection(database) if self.is_table(table_name=table, database=database): @@ -263,11 +273,13 @@ def get_all_records(self, table:str, database:str=""): return records + self._not_table(table=table) + def delete_record(self, table:str, primary_key, database:str=""): - if self.connection == "" and not len(database): + if self.connection == "" and not database: raise TypeError("delete_record() missing 1 required positional argument: 'database'") - if len(database): + if database: self.create_connection(database) if self.is_table(table_name=table, database=database): @@ -276,13 +288,15 @@ def delete_record(self, table:str, primary_key, database:str=""): cursor.execute(sql, (primary_key,)) self.connection.commit() - return True - + return True + + self._not_table(table=table) + def filter_records(self, table:str, values:dict, database:str=""): - if self.connection == "" and not len(database): - raise TypeError("delete_record() missing 1 required positional argument: 'database'") + if self.connection == "" and not database: + raise TypeError("filter_records() missing 1 required positional argument: 'database'") - if len(database): + if database: self.create_connection(database) if self.is_table(table_name=table, database=database): @@ -313,6 +327,8 @@ def filter_records(self, table:str, values:dict, database:str=""): return records + self._not_table(table=table) + def close_connection(self): self.connection.close() return True diff --git a/tests/test_manager.py b/tests/test_manager.py index c3dca6e..1becc88 100644 --- a/tests/test_manager.py +++ b/tests/test_manager.py @@ -1,4 +1,5 @@ import os +from sqlite3 import OperationalError from ReallySimpleDB import dbmanager _dbmanager = dbmanager() @@ -58,10 +59,16 @@ def test_is_table_4(): assert not _dbmanager.is_table(table_name="NON") def test_delete_table_1(): - assert _dbmanager.delete_table(table="EMPLOYEES") + try: + _dbmanager.delete_table(table="EMPLOYEES") + except OperationalError: + assert True def test_delete_table_2(): - assert not _dbmanager.delete_table(table="EMPLOYEES") + try: + _dbmanager.delete_table(table="EMPLOYEES") + except OperationalError: + assert True def test_get_all_column_types_1(): assert _dbmanager.get_all_column_types(table="STUDENTS") \ @@ -78,13 +85,19 @@ def test_get_column_type_2(): assert _dbmanager.get_column_type(table="STUDENTS", column="address") == False def test_get_column_type_3(): - assert _dbmanager.get_column_type(table="EMPLOYEES", column="emp_id") == False + try: + _dbmanager.get_column_type(table="EMPLOYEES", column="emp_id") + except OperationalError: + assert True def test_get_columns_1(): assert _dbmanager.get_columns(table="STUDENTS") == ["student_id", "name", "mark", "year"] def test_get_columns_2(): - assert _dbmanager.get_columns(table="EMPLOYEES") == [] + try: + _dbmanager.get_columns(table="EMPLOYEES") + except OperationalError: + assert True def test_get_primary_key_1(): assert _dbmanager.get_primary_key(table="STUDENTS") == "student_id" @@ -125,6 +138,12 @@ def test_filter_record_2(): def test_delete_record_1(): assert _dbmanager.delete_record(table="STUDENTS", primary_key="1010") +def test_delete_record_2(): + try: + _dbmanager.delete_record(table="STUDENTSS", primary_key="1010") + except OperationalError: + assert True + def test_finally(): delete_db()