Pre-ping connection to external database. Silently restore broken conn… #145

Open · wants to merge 2 commits into master
4 changes: 0 additions & 4 deletions octoprint_filamentmanager/__init__.py
@@ -138,10 +138,6 @@ def notify(pid, channel, payload):
except Exception as e:
self._logger.error("Failed to set temperature offsets: {message}".format(message=str(e)))

def on_shutdown(self):
if self.filamentManager is not None:
self.filamentManager.close()

def on_data_modified(self, data, action):
if action.lower() == "update":
# if either profiles, spools or selections are updated
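Note: removing on_shutdown() is consistent with the rest of this PR. FilamentManager now keeps an Engine rather than a single long-lived Connection, and every operation checks a connection out of the pool and returns it, so there is no persistent connection left to close at shutdown. If explicit teardown were still wanted, disposing the engine would be the closest equivalent. A hypothetical sketch, not part of this PR:

    def on_shutdown(self):
        # Engine.dispose() closes all checked-in pooled connections; nothing
        # else holds a connection open, so this is optional cleanup.
        if self.filamentManager is not None:
            self.filamentManager.db.dispose()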
9 changes: 5 additions & 4 deletions octoprint_filamentmanager/api/__init__.py
@@ -395,10 +395,11 @@ def test_database_connection(self):
return make_response("Configuration does not contain mandatory '{}' field".format(key), 400)

try:
connection = self.filamentManager.connect(config["uri"],
database=config["name"],
username=config["user"],
password=config["password"])
db = self.filamentManager.get_database(config["uri"],
database=config["name"],
username=config["user"],
password=config["password"])
connection = db.connect()
except Exception as e:
return make_response("Failed to connect to the database with the given configuration", 400)
else:
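Note: splitting the old connect() into get_database() plus an explicit connect() is what makes this endpoint a real test. create_engine() is lazy and performs no network I/O, so only db.connect() actually reaches the server and validates host and credentials. A minimal sketch of the new flow (placeholder URI and credentials; closing the probe connection is assumed to happen in the elided else branch):

    engine = self.filamentManager.get_database("postgresql://db.example.com",
                                               database="octoprint",
                                               username="op", password="secret")
    connection = engine.connect()  # first real round-trip; raises if unreachable
    connection.close()             # return the probe connection to the pool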
127 changes: 65 additions & 62 deletions octoprint_filamentmanager/data/__init__.py
@@ -28,10 +28,10 @@ class FilamentManager(object):

def __init__(self, config):
self.notify = None
self.conn = self.connect(config.get("uri", ""),
database=config.get("name", ""),
username=config.get("user", ""),
password=config.get("password", ""))
self.db = self.get_database(config.get("uri", ""),
database=config.get("name", ""),
username=config.get("user", ""),
password=config.get("password", ""))

# QUESTION thread local connection (pool) vs sharing a serialized connection, pro/cons?
# from sqlalchemy.orm import sessionmaker, scoped_session
@@ -41,12 +41,12 @@ def __init__(self, config):

if self.engine_dialect_is(self.DIALECT_SQLITE):
# Enable foreign key constraints
self.conn.execute(text("PRAGMA foreign_keys = ON").execution_options(autocommit=True))
self.db.execute(text("PRAGMA foreign_keys = ON").execution_options(autocommit=True))
elif self.engine_dialect_is(self.DIALECT_POSTGRESQL):
# Create listener thread
self.notify = PGNotify(self.conn.engine.url)
self.notify = PGNotify(self.db.url)

def connect(self, uri, database="", username="", password=""):
def get_database(self, uri, database="", username="", password=""):
uri_parts = urisplit(uri)

if uri_parts.scheme == self.DIALECT_SQLITE:
@@ -58,17 +58,14 @@ def connect(self, uri, database="", username="", password=""):
database=database,
username=username,
password=password)
engine = create_engine(uri)
engine = create_engine(uri, pool_pre_ping=True)
else:
raise ValueError("Engine '{engine}' not supported".format(engine=uri_parts.scheme))

return engine.connect()

def close(self):
self.conn.close()
return engine

def engine_dialect_is(self, dialect):
return self.conn.engine.dialect.name == dialect if self.conn is not None else False
return self.db.dialect.name == dialect if self.db is not None else False
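Note: pool_pre_ping=True is the change behind "silently restore broken conn…" in the title. Before handing out a pooled connection, SQLAlchemy issues a lightweight test statement (essentially SELECT 1); if the server has dropped the connection in the meantime, the stale connection is discarded and a fresh one is opened transparently instead of an OperationalError surfacing mid-request. The option exists since SQLAlchemy 1.2, which is why setup.py bumps the requirement below. A standalone sketch with a placeholder URL:

    from sqlalchemy import create_engine, text

    engine = create_engine("postgresql://user:secret@db.example.com/octoprint",
                           pool_pre_ping=True)

    with engine.connect() as conn:                      # pre-ping runs on checkout
        print(conn.execute(text("SELECT 1")).scalar())  # -> 1
    # If the database restarts now, the next engine.connect() silently replaces
    # the dead pooled connection instead of raising.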

def initialize(self):
metadata = MetaData()
@@ -108,11 +105,11 @@ def initialize(self):

if self.engine_dialect_is(self.DIALECT_POSTGRESQL):
def should_create_function(name):
row = self.conn.execute("select proname from pg_proc where proname = '%s'" % name).scalar()
row = self.db.execute("select proname from pg_proc where proname = '%s'" % name).scalar()
return not bool(row)

def should_create_trigger(name):
row = self.conn.execute("select tgname from pg_trigger where tgname = '%s'" % name).scalar()
row = self.db.execute("select tgname from pg_trigger where tgname = '%s'" % name).scalar()
return not bool(row)

trigger_function = DDL("""
@@ -155,65 +152,69 @@ def should_create_trigger(name):
""".format(name=name, table=table, action=action))
event.listen(metadata, "after_create", trigger)

metadata.create_all(self.conn, checkfirst=True)
metadata.create_all(self.db, checkfirst=True)

def execute_script(self, script):
with self.lock, self.conn.begin():
conn = self.db.connect()
with self.lock, conn.begin():
for stmt in script.split(";"):
self.conn.execute(text(stmt))
conn.execute(text(stmt))
conn.close()
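Note: execute_script() now borrows a dedicated connection, runs the whole script in one transaction, and closes the connection afterwards. The close() is not wrapped in try/finally, so a failing statement would leave the connection checked out until it is garbage-collected. A sketch of a context-manager variant that returns it unconditionally (an alternative, not what this PR does):

    def execute_script(self, script):
        # Connection and Transaction are both context managers: the connection
        # is returned to the pool and the transaction rolled back on error.
        with self.lock, self.db.connect() as conn, conn.begin():
            for stmt in script.split(";"):
                conn.execute(text(stmt))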

# versioning

def get_schema_version(self):
with self.lock, self.conn.begin():
return self.conn.execute(select([func.max(self.versioning.c.schema_id)])).scalar()
with self.lock:
return self.db.execute(select([func.max(self.versioning.c.schema_id)])).scalar()

def set_schema_version(self, version):
with self.lock, self.conn.begin():
self.conn.execute(insert(self.versioning).values((version,)))
self.conn.execute(delete(self.versioning).where(self.versioning.c.schema_id < version))
conn = self.db.connect()
with self.lock, conn.begin():
conn.execute(insert(self.versioning).values((version,)))
conn.execute(delete(self.versioning).where(self.versioning.c.schema_id < version))
conn.close()
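Note: the read paths in this file drop the explicit transaction and execute directly on the engine. In SQLAlchemy 1.x this is "connectionless execution": each self.db.execute(...) checks a connection out of the pool, runs the statement (autocommitting for writes), and checks it straight back in, so nothing stays pinned between requests. A self-contained sketch using an in-memory SQLite database:

    from sqlalchemy import create_engine, text

    engine = create_engine("sqlite://")  # in-memory placeholder database
    engine.execute(text("CREATE TABLE versioning (schema_id INTEGER)"))
    engine.execute(text("INSERT INTO versioning VALUES (1)"))
    print(engine.execute(text("SELECT max(schema_id) FROM versioning")).scalar())  # -> 1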

# profiles

def get_all_profiles(self):
with self.lock, self.conn.begin():
with self.lock:
stmt = select([self.profiles]).order_by(self.profiles.c.material, self.profiles.c.vendor)
result = self.conn.execute(stmt)
result = self.db.execute(stmt)
return self._result_to_dict(result)

def get_profiles_lastmodified(self):
with self.lock, self.conn.begin():
with self.lock:
stmt = select([self.modifications.c.changed_at]).where(self.modifications.c.table_name == "profiles")
return self.conn.execute(stmt).scalar()
return self.db.execute(stmt).scalar()

def get_profile(self, identifier):
with self.lock, self.conn.begin():
with self.lock:
stmt = select([self.profiles]).where(self.profiles.c.id == identifier)\
.order_by(self.profiles.c.material, self.profiles.c.vendor)
result = self.conn.execute(stmt)
result = self.db.execute(stmt)
return self._result_to_dict(result, one=True)

def create_profile(self, data):
with self.lock, self.conn.begin():
with self.lock:
stmt = insert(self.profiles)\
.values(vendor=data["vendor"], material=data["material"], density=data["density"],
diameter=data["diameter"])
result = self.conn.execute(stmt)
data["id"] = result.lastrowid
result = self.db.execute(stmt)
data["id"] = result.inserted_primary_key[0]
return data
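Note: replacing result.lastrowid with result.inserted_primary_key[0] is a portability fix. lastrowid comes straight from the DB-API cursor and is only meaningful on SQLite and MySQL; psycopg2 fills it with the row OID, which is usually 0, so the old code could return a bogus id on PostgreSQL. inserted_primary_key is dialect-aware (it uses RETURNING on PostgreSQL). A self-contained sketch with a placeholder table:

    from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine

    metadata = MetaData()
    profiles = Table("profiles", metadata,
                     Column("id", Integer, primary_key=True),
                     Column("material", String))
    engine = create_engine("sqlite://")  # placeholder engine
    metadata.create_all(engine)

    result = engine.execute(profiles.insert().values(material="PLA"))
    print(result.inserted_primary_key[0])  # -> 1 on any supported dialect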

def update_profile(self, identifier, data):
with self.lock, self.conn.begin():
with self.lock:
stmt = update(self.profiles).where(self.profiles.c.id == identifier)\
.values(vendor=data["vendor"], material=data["material"], density=data["density"],
diameter=data["diameter"])
self.conn.execute(stmt)
self.db.execute(stmt)
return data

def delete_profile(self, identifier):
with self.lock, self.conn.begin():
with self.lock:
stmt = delete(self.profiles).where(self.profiles.c.id == identifier)
self.conn.execute(stmt)
self.db.execute(stmt)

# spools

@@ -228,48 +229,48 @@ def _build_spool_dict(self, row, column_names):
return spool

def get_all_spools(self):
with self.lock, self.conn.begin():
with self.lock:
j = self.spools.join(self.profiles, self.spools.c.profile_id == self.profiles.c.id)
stmt = select([self.spools, self.profiles]).select_from(j).order_by(self.spools.c.name)
result = self.conn.execute(stmt)
result = self.db.execute(stmt)
return [self._build_spool_dict(row, row.keys()) for row in result.fetchall()]

def get_spools_lastmodified(self):
with self.lock, self.conn.begin():
with self.lock:
stmt = select([func.max(self.modifications.c.changed_at)])\
.where(self.modifications.c.table_name.in_(["spools", "profiles"]))
return self.conn.execute(stmt).scalar()
return self.db.execute(stmt).scalar()

def get_spool(self, identifier):
with self.lock, self.conn.begin():
with self.lock:
j = self.spools.join(self.profiles, self.spools.c.profile_id == self.profiles.c.id)
stmt = select([self.spools, self.profiles]).select_from(j)\
.where(self.spools.c.id == identifier).order_by(self.spools.c.name)
result = self.conn.execute(stmt)
result = self.db.execute(stmt)
row = result.fetchone()
return self._build_spool_dict(row, row.keys()) if row is not None else None

def create_spool(self, data):
with self.lock, self.conn.begin():
with self.lock:
stmt = insert(self.spools)\
.values(name=data["name"], cost=data["cost"], weight=data["weight"], used=data["used"],
temp_offset=data["temp_offset"], profile_id=data["profile"]["id"])
result = self.conn.execute(stmt)
data["id"] = result.lastrowid
result = self.db.execute(stmt)
data["id"] = result.inserted_primary_key[0]
return data

def update_spool(self, identifier, data):
with self.lock, self.conn.begin():
with self.lock:
stmt = update(self.spools).where(self.spools.c.id == identifier)\
.values(name=data["name"], cost=data["cost"], weight=data["weight"], used=data["used"],
temp_offset=data["temp_offset"], profile_id=data["profile"]["id"])
self.conn.execute(stmt)
self.db.execute(stmt)
return data

def delete_spool(self, identifier):
with self.lock, self.conn.begin():
with self.lock:
stmt = delete(self.spools).where(self.spools.c.id == identifier)
self.conn.execute(stmt)
self.db.execute(stmt)

# selections

@@ -287,26 +288,26 @@ def _build_selection_dict(self, row, column_names):
return sel

def get_all_selections(self, client_id):
with self.lock, self.conn.begin():
with self.lock:
j1 = self.selections.join(self.spools, self.selections.c.spool_id == self.spools.c.id)
j2 = j1.join(self.profiles, self.spools.c.profile_id == self.profiles.c.id)
stmt = select([self.selections, self.spools, self.profiles]).select_from(j2)\
.where(self.selections.c.client_id == client_id).order_by(self.selections.c.tool)
result = self.conn.execute(stmt)
result = self.db.execute(stmt)
return [self._build_selection_dict(row, row.keys()) for row in result.fetchall()]

def get_selection(self, identifier, client_id):
with self.lock, self.conn.begin():
with self.lock:
j1 = self.selections.join(self.spools, self.selections.c.spool_id == self.spools.c.id)
j2 = j1.join(self.profiles, self.spools.c.profile_id == self.profiles.c.id)
stmt = select([self.selections, self.spools, self.profiles]).select_from(j2)\
.where((self.selections.c.tool == identifier) & (self.selections.c.client_id == client_id))
result = self.conn.execute(stmt)
result = self.db.execute(stmt)
row = result.fetchone()
return self._build_selection_dict(row, row.keys()) if row is not None else dict(tool=identifier, spool=None)

def update_selection(self, identifier, client_id, data):
with self.lock, self.conn.begin():
with self.lock:
values = dict()
if self.engine_dialect_is(self.DIALECT_SQLITE):
stmt = insert(self.selections).prefix_with("OR REPLACE")\
@@ -315,13 +316,13 @@ def update_selection(self, identifier, client_id, data):
stmt = pg_insert(self.selections)\
.values(tool=identifier, client_id=client_id, spool_id=data["spool"]["id"])\
.on_conflict_do_update(constraint="selections_pkey", set_=dict(spool_id=data["spool"]["id"]))
self.conn.execute(stmt)
self.db.execute(stmt)
return self.get_selection(identifier, client_id)

def export_data(self, dirpath):
def to_csv(table):
with self.lock, self.conn.begin():
result = self.conn.execute(select([table]))
with self.lock:
result = self.db.execute(select([table]))
filepath = os.path.join(dirpath, table.name + ".csv")
with io.open(filepath, mode="w", encoding="utf-8") as csv_file:
csv_writer = csv.writer(csv_file)
@@ -338,27 +339,29 @@ def from_csv(table):
with io.open(filepath, mode="r", encoding="utf-8") as csv_file:
csv_reader = csv.reader(csv_file)
header = next(csv_reader)
with self.lock, self.conn.begin():
conn = self.db.connect()
with self.lock, conn.begin():
for row in csv_reader:
values = dict(zip(header, row))

if self.engine_dialect_is(self.DIALECT_SQLITE):
identifier = values[table.c.id.name]
# try to update entry
stmt = update(table).values(values).where(table.c.id == identifier)
if self.conn.execute(stmt).rowcount == 0:
if conn.execute(stmt).rowcount == 0:
# identifier doesn't match any => insert new entry
stmt = insert(table).values(values)
self.conn.execute(stmt)
conn.execute(stmt)
elif self.engine_dialect_is(self.DIALECT_POSTGRESQL):
stmt = pg_insert(table).values(values)\
.on_conflict_do_update(index_elements=[table.c.id], set_=values)
self.conn.execute(stmt)
conn.execute(stmt)

if self.engine_dialect_is(self.DIALECT_POSTGRESQL):
# update sequence
sql = "SELECT setval('{table}_id_seq', max(id)) FROM {table}".format(table=table.name)
self.conn.execute(text(sql))
conn.execute(text(sql))
conn.close()

tables = [self.profiles, self.spools]
for t in tables:
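Note on the setval() call in from_csv(): importing rows with explicit id values bypasses PostgreSQL's SERIAL sequence, so after the import the sequence must be advanced manually or the next auto-generated id would collide with an imported row. The same statement in isolation (placeholder DSN; assumes a profiles table with a serial id column):

    from sqlalchemy import create_engine, text

    engine = create_engine("postgresql://user:secret@db.example.com/octoprint")
    with engine.begin() as conn:  # engine.begin() commits on success
        conn.execute(text("SELECT setval('profiles_id_seq', max(id)) FROM profiles"))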
4 changes: 2 additions & 2 deletions setup.py
@@ -6,15 +6,15 @@
plugin_identifier = "filamentmanager"
plugin_package = "octoprint_filamentmanager"
plugin_name = "OctoPrint-FilamentManager"
plugin_version = "0.5.3"
plugin_version = "0.5.4"
plugin_description = "Manage your spools and keep track of remaining filament on them"
plugin_author = "Sven Lohrmann"
plugin_author_email = "[email protected]"
plugin_url = "https://github.com/malnvenshorn/OctoPrint-FilamentManager"
plugin_license = "AGPLv3"
plugin_requires = ["backports.csv>=1.0.5,<1.1",
"uritools>=2.1,<2.2",
"SQLAlchemy>=1.1.15,<1.2"]
"SQLAlchemy>=1.2"]
plugin_additional_data = []
plugin_additional_packages = []
plugin_ignored_packages = []