Skip to content

Commit

Permalink
Merge pull request #3 from truethari/alpha
Browse files Browse the repository at this point in the history
version 1.2
Fixed: Possible SQL injection vector through string-based query construction
  • Loading branch information
truethari authored Feb 28, 2022
2 parents 16c9ec1 + a78a296 commit 16b9401
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 76 deletions.
76 changes: 36 additions & 40 deletions ReallySimpleDB/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,38 +4,37 @@
from .utils import DATA_TYPES

class ReallySimpleDB:
""" ReallySimpleDB class
"""
ReallySimpleDB class.
ReallySimpleDB objects are the ones responsible of creating DBs, connecting
with them, creating tables, adding records, geting records, among tasks. In
more cases these should be one per database.
"""

def __init__(self) -> None:
"""
create a object
"""
"""Create a object."""
self._add_columns_cmd = ""
self.connection = ""

def clean(self):
"""
cleans add_columns data
Clean add_columns data.
why? _add_columns_cmd variable is for define SQL command. when using add_column,
Why? _add_columns_cmd variable is for define SQL command. when using add_column,
it sets up a string here. but when it is finished this is not clean and the data
continues to exist. when use add_column again and again, it will be processed
along with the existing data. this should be used to prevent it.
"""
self._add_columns_cmd = ""

def create_connection(self, database):
"""opens a connection to the SQLite database file"""
"""Open a connection to the SQLite database file."""
self.connection = sqlite3.connect(database)
return True

def create_db(self, dbpath:str="", replace:bool=False):
"""creates a new database in a given path"""
"""Create a new database in a given path."""
if self.connection == "" and not dbpath:
raise TypeError("create_db() missing 1 required positional argument: 'dbpath'")

Expand All @@ -61,14 +60,13 @@ def add_columns(self,
database:str="",
table:str=""):
"""
add columns to an existing table / define columns before creating a table
Add columns to an existing table / define columns before creating a table.
if use for create new table: sqlite cannot create table without columns.
If use for create new table: sqlite cannot create table without columns.
so user must first define the columns and create a table.
important: user have to close connection here. if not, code returns error.
because it tries to add column to existing table.
"""

# checks if the user is trying to add unsupported data type
if datatype.upper() not in DATA_TYPES:
raise TypeError("datatype not supported, '{}'".format(datatype))
Expand All @@ -81,7 +79,7 @@ def add_columns(self,
# column to an existing table.
self.create_connection(database=database)
cursor = self.connection.cursor()
sql_cmd = "ALTER TABLE {} ADD COLUMN {} {}".format(table, column_name, datatype)
sql_cmd = "ALTER TABLE " + table + " ADD COLUMN " + column_name + " " + datatype
if not_null:
sql_cmd += " NOT NULL"
if primary_key:
Expand All @@ -92,7 +90,7 @@ def add_columns(self,
# if table is not defines, it means that the user is trying to add / define
# a column to a new table. so the following code add SQL syntax globally for
# use when creating new table
self._add_columns_cmd += (",{} {}".format(column_name, datatype))
self._add_columns_cmd += "," + column_name + " " + datatype

if primary_key:
self._add_columns_cmd += " PRIMARY KEY"
Expand All @@ -103,7 +101,7 @@ def add_columns(self,
return True

def create_table(self, table_name:str, database:str=""):
"""creates new table in database"""
"""Create new table in database."""
if self.connection == "" and not database:
raise TypeError("create_table() missing 1 required positional argument: 'database'")

Expand All @@ -116,13 +114,13 @@ def create_table(self, table_name:str, database:str=""):
if self._add_columns_cmd == "":
raise NotImplementedError("call 'add_columns' function before create table")

sql_cmd = "CREATE TABLE {} ({})".format(table_name, self._add_columns_cmd[1:])
sql_cmd = "CREATE TABLE " + table_name + " (" + self._add_columns_cmd[1:] + ")"

self.connection.execute(sql_cmd)
return True

def all_tables(self, database:str=""):
"""get a list of all the tables in the database"""
"""Get a list of all the tables in the database."""
if self.connection == "" and not database:
raise TypeError("all_tables() missing 1 required positional argument: 'database'")

Expand All @@ -134,7 +132,7 @@ def all_tables(self, database:str=""):
return [tables[0] for tables in cursor.execute(sql_cmd)]

def is_table(self, table_name:str, database:str=""):
"""checks if the given table is exists in the database"""
"""Check if the given table is exists in the database."""
if self.connection == "" and not database:
raise TypeError("is_table() missing 1 required positional argument: 'database'")

Expand All @@ -146,7 +144,7 @@ def is_table(self, table_name:str, database:str=""):
return False

def delete_table(self, table:str, database:str=""):
"""delete a table from the database"""
"""Delete a table from the database."""
if self.connection == "" and not database:
raise TypeError("delete_table() missing 1 required positional argument: 'database'")

Expand All @@ -155,7 +153,7 @@ def delete_table(self, table:str, database:str=""):

if self.is_table(table_name=table):
cursor = self.connection.cursor()
sql_cmd = "DROP TABLE {};".format(table)
sql_cmd = "DROP TABLE " + table + ";"
cursor.execute(sql_cmd)

return True
Expand All @@ -164,7 +162,7 @@ def delete_table(self, table:str, database:str=""):
raise sqlite3.OperationalError("no such table: {}".format(table))

def get_all_column_types(self, table:str, database:str=""):
"""get all the column names with the data types in a table"""
"""Get all the column names with the data types in a table."""
if self.connection == "" and not database:
raise TypeError(
"get_all_column_types() missing 1 required positional argument: 'database'")
Expand All @@ -175,7 +173,7 @@ def get_all_column_types(self, table:str, database:str=""):
if self.is_table(table_name=table, database=database):
cursor = self.connection.cursor()

sql_cmd = "PRAGMA TABLE_INFO({});".format(table)
sql_cmd = "PRAGMA TABLE_INFO(" + table + ");"
fetch = cursor.execute(sql_cmd)

data_dict = {}
Expand All @@ -188,7 +186,7 @@ def get_all_column_types(self, table:str, database:str=""):
raise sqlite3.OperationalError("no such table: {}".format(table))

def get_column_type(self, table:str, column:str, database:str=""):
"""get data type of a column in a table"""
"""Get data type of a column in a table."""
all_data = self.get_all_column_types(table=table, database=database)

# if columns exists in the table and given column in the table
Expand All @@ -198,7 +196,7 @@ def get_column_type(self, table:str, column:str, database:str=""):
raise sqlite3.OperationalError("no such column: {}".format(column))

def get_columns(self, table:str, database:str=""):
"""get all the column names list in a table"""
"""Get all the column names list in a table."""
if self.connection == "" and not database:
raise TypeError("get_columns() missing 1 required positional argument: 'database'")

Expand All @@ -215,7 +213,7 @@ def get_columns(self, table:str, database:str=""):
return columns

def get_primary_key(self, table:str, database:str=""):
"""find and get primary key of a table"""
"""Find and get primary key of a table."""
if self.connection == "" and not database:
raise TypeError("get_primary_key() missing 1 required positional argument: 'database'")

Expand All @@ -225,15 +223,15 @@ def get_primary_key(self, table:str, database:str=""):
if self.is_table(table_name=table, database=database):
cursor = self.connection.cursor()

sql_cmd = "SELECT * FROM pragma_table_info('{}') WHERE pk;".format(table)
fetch = cursor.execute(sql_cmd)
sql_cmd = "SELECT * FROM pragma_table_info(?) WHERE pk;"
fetch = cursor.execute(sql_cmd, (table,))
return fetch.fetchall()[0][1]

# raise OperationalError if the given table not exists
raise sqlite3.OperationalError("no such table: {}".format(table))

def add_record(self, table:str, record, database:str=""):
"""add a new record to a table"""
"""Add a new record to a table."""
if self.connection == "" and not database:
raise TypeError("add_record() missing 1 required positional argument: 'database'")

Expand All @@ -252,7 +250,7 @@ def add_record(self, table:str, record, database:str=""):
all_columns[column] = ""

fields = []
sql_cmd = "INSERT INTO {} VALUES(".format(table)
sql_cmd = "INSERT INTO " + table + " VALUES("

# if record is dict type,..
if isinstance(record, dict):
Expand Down Expand Up @@ -288,7 +286,7 @@ def add_record(self, table:str, record, database:str=""):
raise sqlite3.OperationalError("no such table: {}".format(table))

def get_record(self, table:str, primary_key, database:str=""):
"""get row data / record from a table using the primary key"""
"""Get row data / record from a table using the primary key."""
if self.connection == "" and not database:
raise TypeError("get_record() missing 1 required positional argument: 'database'")

Expand All @@ -298,8 +296,7 @@ def get_record(self, table:str, primary_key, database:str=""):
if self.is_table(table_name=table, database=database):
cursor = self.connection.cursor()

sql_cmd = "SELECT * FROM {} WHERE {}=?;".format(
table, self.get_primary_key(table=table, database=database))
sql_cmd = "SELECT * FROM " + table + " WHERE " + self.get_primary_key(table=table, database=database) + "=?;"
fetch = cursor.execute(sql_cmd, (primary_key,))

# get columns list using get_columns
Expand All @@ -321,7 +318,7 @@ def get_record(self, table:str, primary_key, database:str=""):
raise sqlite3.OperationalError("no such table: {}".format(table))

def get_all_records(self, table:str, database:str=""):
"""get all data / records of a table"""
"""Get all data / records of a table."""
if self.connection == "" and not database:
raise TypeError("get_all_records() missing 1 required positional argument: 'database'")

Expand All @@ -331,7 +328,7 @@ def get_all_records(self, table:str, database:str=""):
if self.is_table(table_name=table, database=database):
cursor = self.connection.cursor()

sql_cmd = "SELECT * FROM {}".format(table)
sql_cmd = "SELECT * FROM " + table
cursor.execute(sql_cmd)
rows = cursor.fetchall()

Expand All @@ -353,7 +350,7 @@ def get_all_records(self, table:str, database:str=""):
raise sqlite3.OperationalError("no such table: {}".format(table))

def delete_record(self, table:str, primary_key, database:str=""):
"""delete record from a table"""
"""Delete record from a table."""
if self.connection == "" and not database:
raise TypeError("delete_record() missing 1 required positional argument: 'database'")

Expand All @@ -362,8 +359,7 @@ def delete_record(self, table:str, primary_key, database:str=""):

if self.is_table(table_name=table, database=database):
cursor = self.connection.cursor()
sql = "DELETE FROM {} WHERE {}=?".format(
table, self.get_primary_key(table=table, database=database))
sql = "DELETE FROM " + table + " WHERE " + self.get_primary_key(table=table, database=database) + "=?"
cursor.execute(sql, (primary_key,))
self.connection.commit()

Expand All @@ -374,9 +370,9 @@ def delete_record(self, table:str, primary_key, database:str=""):

def filter_records(self, table:str, values:dict, database:str=""):
"""
get filtered record list from a table
Get filtered record list from a table.
this will return one or more records by checking the values.
This will return one or more records by checking the values.
"""
if self.connection == "" and not database:
raise TypeError("filter_records() missing 1 required positional argument: 'database'")
Expand All @@ -389,7 +385,7 @@ def filter_records(self, table:str, values:dict, database:str=""):

operators = [">", "<", "!", "="]

sql = "SELECT * FROM {} WHERE ".format(table)
sql = "SELECT * FROM " + table + " WHERE "

for value in values:
try:
Expand Down Expand Up @@ -427,6 +423,6 @@ def filter_records(self, table:str, values:dict, database:str=""):
raise sqlite3.OperationalError("no such table: {}".format(table))

def close_connection(self):
"""close the connection with the SQLite database file"""
"""Close the connection with the SQLite database file."""
self.connection.close()
return True
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

setup(
name="ReallySimpleDB",
version="1.1",
version="1.2",
description="A tool for easily manage databases with Python",
long_description=long_description,
long_description_content_type="text/markdown",
Expand Down
Loading

0 comments on commit 16b9401

Please sign in to comment.