diff --git a/qcfractal/qcfractal/db_socket/socket.py b/qcfractal/qcfractal/db_socket/socket.py index 48cec7034..3635439b0 100644 --- a/qcfractal/qcfractal/db_socket/socket.py +++ b/qcfractal/qcfractal/db_socket/socket.py @@ -12,7 +12,7 @@ from functools import lru_cache from typing import TYPE_CHECKING -from sqlalchemy import create_engine, exc, event, inspect, select, union, MetaData +from sqlalchemy import create_engine, exc, event, inspect, select, union, MetaData, Engine from sqlalchemy.orm import sessionmaker from sqlalchemy.pool import NullPool @@ -85,7 +85,7 @@ def checkout(dbapi_connection, connection_record, connection_proxy): ) # Check to see if the db is up-to-date - self.check_db_revision() + self._check_db_revision(self.engine) self.Session = sessionmaker(bind=self.engine, future=True) @@ -231,9 +231,10 @@ def upgrade_database(db_config: DatabaseConfig, revision: str = "head") -> None: alembic_cfg = SQLAlchemySocket.get_alembic_config(db_config) command.upgrade(alembic_cfg, revision) - def check_db_revision(self): + @staticmethod + def _check_db_revision(engine: Engine): """ - Checks to make sure the database is up-to-date + Checks to make sure the database is up-to-date, given an engine Will raise an exception if it is not up-to-date """ @@ -245,7 +246,7 @@ def check_db_revision(self): script = ScriptDirectory(script_dir) heads = script.get_heads() - conn = self.engine.connect() + conn = engine.connect() context = MigrationContext.configure(connection=conn) current_rev = context.get_current_revision() @@ -260,6 +261,17 @@ def check_db_revision(self): finally: conn.close() + @staticmethod + def check_db_revision(db_config: DatabaseConfig): + """ + Checks to make sure the database is up-to-date, given a configuration + + Will raise an exception if it is not up-to-date + """ + + engine = create_engine(db_config.sqlalchemy_url, echo=False, poolclass=NullPool, future=True) + SQLAlchemySocket._check_db_revision(engine) + @contextmanager def session_scope(self, read_only: bool = False): """Provide a session as a context manager diff --git a/qcfractal/qcfractal/qcfractal_server_cli.py b/qcfractal/qcfractal/qcfractal_server_cli.py index 944a5bb22..54c162c5e 100644 --- a/qcfractal/qcfractal/qcfractal_server_cli.py +++ b/qcfractal/qcfractal/qcfractal_server_cli.py @@ -30,7 +30,6 @@ from .postgres_harness import PostgresHarness if TYPE_CHECKING: - from typing import Tuple from logging import Logger @@ -73,7 +72,7 @@ def dump_config(qcf_config: FractalConfig, indent: int = 0) -> str: return s -def start_database(config: FractalConfig) -> Tuple[PostgresHarness, SQLAlchemySocket]: +def start_database(config: FractalConfig, check_revision: bool) -> PostgresHarness: """ Obtain a storage socket to a running postgres server @@ -100,10 +99,11 @@ def start_database(config: FractalConfig) -> Tuple[PostgresHarness, SQLAlchemySo if not pg_harness.can_connect(): raise RuntimeError(f"Database at {config.database.safe_uri} does not exist?") - # Start up a socket. The main thing is to see if it can connect, and also - # to check if the database needs to be upgraded - # We then no longer need the socket (everything else uses their own) - return pg_harness, SQLAlchemySocket(config) + # Check that the database is up to date + if check_revision: + SQLAlchemySocket.check_db_revision(config.database) + + return pg_harness def parse_args() -> argparse.Namespace: @@ -411,7 +411,7 @@ def server_start(config): logger = logging.getLogger(__name__) # Ensure that the database is alive, optionally starting it - start_database(config) + start_database(config, check_revision=True) # Set up a queue for logging. All child process will send logs # to this queue, and a separate thread will handle them @@ -510,7 +510,7 @@ def server_start_job_runner(config): # Ensure that the database is alive. This also handles checking stuff, # even if we don't own the db (which we shouldn't) - start_database(config) + start_database(config, check_revision=True) # Now just run the job runner directly job_runner = FractalJobRunner(config) @@ -539,7 +539,7 @@ def server_start_api(config): # Ensure that the database is alive. This also handles checking stuff, # even if we don't own the db (which we shouldn't) - start_database(config) + start_database(config, check_revision=True) # Now just run the api process api = FractalWaitressApp(config) @@ -635,7 +635,10 @@ def server_upgrade_config(config_path): def server_user(args: argparse.Namespace, config: FractalConfig): user_command = args.user_command - pg_harness, storage = start_database(config) + + # Don't check revision here - it will be done in the SQLAlchemySocket constructor + start_database(config, check_revision=False) + storage = SQLAlchemySocket(config) def print_user_info(u: UserInfo): enabled = "True" if u.enabled else "False" @@ -742,7 +745,10 @@ def print_user_info(u: UserInfo): def server_role(args: argparse.Namespace, config: FractalConfig): role_command = args.role_command - pg_harness, storage = start_database(config) + + # Don't check revision here - it will be done in the SQLAlchemySocket constructor + start_database(config, check_revision=False) + storage = SQLAlchemySocket(config) def print_role_info(r: RoleInfo): print("-" * 80) @@ -772,7 +778,7 @@ def print_role_info(r: RoleInfo): def server_backup(args: argparse.Namespace, config: FractalConfig): - pg_harness, _ = start_database(config) + pg_harness = start_database(config, check_revision=True) db_size = pg_harness.database_size() pretty_size = pretty_bytes(db_size)