Skip to content

Commit

Permalink
Merge pull request #853 from MolSSI/dbcheck
Browse files Browse the repository at this point in the history
Reduce number of SQLAlchemySocket instances created on startup
  • Loading branch information
bennybp authored Oct 24, 2024
2 parents 0557078 + 5fa330b commit 1e9bab8
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 17 deletions.
22 changes: 17 additions & 5 deletions qcfractal/qcfractal/db_socket/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from functools import lru_cache
from typing import TYPE_CHECKING

from sqlalchemy import create_engine, exc, event, inspect, select, union, MetaData
from sqlalchemy import create_engine, exc, event, inspect, select, union, MetaData, Engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import NullPool

Expand Down Expand Up @@ -85,7 +85,7 @@ def checkout(dbapi_connection, connection_record, connection_proxy):
)

# Check to see if the db is up-to-date
self.check_db_revision()
self._check_db_revision(self.engine)

self.Session = sessionmaker(bind=self.engine, future=True)

Expand Down Expand Up @@ -231,9 +231,10 @@ def upgrade_database(db_config: DatabaseConfig, revision: str = "head") -> None:
alembic_cfg = SQLAlchemySocket.get_alembic_config(db_config)
command.upgrade(alembic_cfg, revision)

def check_db_revision(self):
@staticmethod
def _check_db_revision(engine: Engine):
"""
Checks to make sure the database is up-to-date
Checks to make sure the database is up-to-date, given an engine
Will raise an exception if it is not up-to-date
"""
Expand All @@ -245,7 +246,7 @@ def check_db_revision(self):
script = ScriptDirectory(script_dir)
heads = script.get_heads()

conn = self.engine.connect()
conn = engine.connect()
context = MigrationContext.configure(connection=conn)
current_rev = context.get_current_revision()

Expand All @@ -260,6 +261,17 @@ def check_db_revision(self):
finally:
conn.close()

@staticmethod
def check_db_revision(db_config: DatabaseConfig):
"""
Checks to make sure the database is up-to-date, given a configuration
Will raise an exception if it is not up-to-date
"""

engine = create_engine(db_config.sqlalchemy_url, echo=False, poolclass=NullPool, future=True)
SQLAlchemySocket._check_db_revision(engine)

@contextmanager
def session_scope(self, read_only: bool = False):
"""Provide a session as a context manager
Expand Down
30 changes: 18 additions & 12 deletions qcfractal/qcfractal/qcfractal_server_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from .postgres_harness import PostgresHarness

if TYPE_CHECKING:
from typing import Tuple
from logging import Logger


Expand Down Expand Up @@ -73,7 +72,7 @@ def dump_config(qcf_config: FractalConfig, indent: int = 0) -> str:
return s


def start_database(config: FractalConfig) -> Tuple[PostgresHarness, SQLAlchemySocket]:
def start_database(config: FractalConfig, check_revision: bool) -> PostgresHarness:
"""
Obtain a storage socket to a running postgres server
Expand All @@ -100,10 +99,11 @@ def start_database(config: FractalConfig) -> Tuple[PostgresHarness, SQLAlchemySo
if not pg_harness.can_connect():
raise RuntimeError(f"Database at {config.database.safe_uri} does not exist?")

# Start up a socket. The main thing is to see if it can connect, and also
# to check if the database needs to be upgraded
# We then no longer need the socket (everything else uses their own)
return pg_harness, SQLAlchemySocket(config)
# Check that the database is up to date
if check_revision:
SQLAlchemySocket.check_db_revision(config.database)

return pg_harness


def parse_args() -> argparse.Namespace:
Expand Down Expand Up @@ -411,7 +411,7 @@ def server_start(config):
logger = logging.getLogger(__name__)

# Ensure that the database is alive, optionally starting it
start_database(config)
start_database(config, check_revision=True)

# Set up a queue for logging. All child process will send logs
# to this queue, and a separate thread will handle them
Expand Down Expand Up @@ -510,7 +510,7 @@ def server_start_job_runner(config):

# Ensure that the database is alive. This also handles checking stuff,
# even if we don't own the db (which we shouldn't)
start_database(config)
start_database(config, check_revision=True)

# Now just run the job runner directly
job_runner = FractalJobRunner(config)
Expand Down Expand Up @@ -539,7 +539,7 @@ def server_start_api(config):

# Ensure that the database is alive. This also handles checking stuff,
# even if we don't own the db (which we shouldn't)
start_database(config)
start_database(config, check_revision=True)

# Now just run the api process
api = FractalWaitressApp(config)
Expand Down Expand Up @@ -635,7 +635,10 @@ def server_upgrade_config(config_path):

def server_user(args: argparse.Namespace, config: FractalConfig):
user_command = args.user_command
pg_harness, storage = start_database(config)

# Don't check revision here - it will be done in the SQLAlchemySocket constructor
start_database(config, check_revision=False)
storage = SQLAlchemySocket(config)

def print_user_info(u: UserInfo):
enabled = "True" if u.enabled else "False"
Expand Down Expand Up @@ -742,7 +745,10 @@ def print_user_info(u: UserInfo):

def server_role(args: argparse.Namespace, config: FractalConfig):
role_command = args.role_command
pg_harness, storage = start_database(config)

# Don't check revision here - it will be done in the SQLAlchemySocket constructor
start_database(config, check_revision=False)
storage = SQLAlchemySocket(config)

def print_role_info(r: RoleInfo):
print("-" * 80)
Expand Down Expand Up @@ -772,7 +778,7 @@ def print_role_info(r: RoleInfo):


def server_backup(args: argparse.Namespace, config: FractalConfig):
pg_harness, _ = start_database(config)
pg_harness = start_database(config, check_revision=True)

db_size = pg_harness.database_size()
pretty_size = pretty_bytes(db_size)
Expand Down

0 comments on commit 1e9bab8

Please sign in to comment.