diff --git a/pycroft/model/ddl.py b/pycroft/model/ddl.py index ae9b34690..ba1b57265 100644 --- a/pycroft/model/ddl.py +++ b/pycroft/model/ddl.py @@ -6,19 +6,23 @@ ~~~~~~~~~~~~~~~~~ """ import inspect +import typing as t from collections import OrderedDict from collections.abc import Iterable from functools import partial, cached_property -from sqlalchemy import event as sqla_event, schema, table +from sqlalchemy import event as sqla_event, schema, table, Constraint, Table from sqlalchemy.dialects import postgresql +from sqlalchemy.engine import Dialect from sqlalchemy.ext.compiler import compiles -from sqlalchemy.sql import ClauseElement, Selectable +from sqlalchemy.sql import ClauseElement, Selectable, ColumnCollection +from sqlalchemy.sql.compiler import Compiled +from sqlalchemy.sql.selectable import SelectBase from pycroft.model.session import with_transaction, session -def _join_tokens(*tokens): +def _join_tokens(*tokens: str | None) -> str: """ Join all elements that are not None :param tokens: @@ -27,10 +31,13 @@ def _join_tokens(*tokens): return ' '.join(token for token in tokens if token is not None) -def compile_if_clause(compiler, clause): +def compile_if_clause(compiler: Compiled, clause: t.Any) -> t.Any: if isinstance(clause, ClauseElement): - return str(clause.compile(compile_kwargs={'literal_binds': True}, - dialect=compiler.dialect)) + return str( + clause.compile( + compile_kwargs={"literal_binds": True}, dialect=compiler.dialect + ) + ) return compiler.sql_compiler.process(clause, literal_binds=True) return clause @@ -39,7 +46,14 @@ class DropConstraint(schema.DropConstraint): """ Extends SQLALchemy's DropConstraint with support for IF EXISTS """ - def __init__(self, element, if_exists=False, cascade=False, **kw): + + def __init__( + self, + element: Constraint, + if_exists: bool = False, + cascade: bool = False, + **kw: t.Any, + ): super().__init__(element, cascade, **kw) self.element = element self.if_exists = if_exists @@ -47,11 +61,11 @@ def __init__(self, element, if_exists=False, cascade=False, **kw): # noinspection PyUnusedLocal -@compiles(DropConstraint, 'postgresql') -def visit_drop_constraint(drop_constraint, compiler, **kw): +@compiles(DropConstraint, "postgresql") +def visit_drop_constraint(drop_constraint: DropConstraint, compiler: Compiled, **kw): constraint = drop_constraint.element - opt_if_exists = 'IF EXISTS' if drop_constraint.if_exists else None - opt_drop_behavior = 'CASCADE' if drop_constraint.cascade else None + opt_if_exists = "IF EXISTS" if drop_constraint.if_exists else None + opt_drop_behavior = "CASCADE" if drop_constraint.cascade else None table_name = compiler.preparer.format_table(constraint.table) constraint_name = compiler.preparer.quote(constraint.name) return _join_tokens( @@ -60,24 +74,33 @@ def visit_drop_constraint(drop_constraint, compiler, **kw): class Function(schema.DDLElement): - on = 'postgresql' - - def __init__(self, name, arguments, rtype, definition: str | Selectable, - volatility='volatile', - strict=False, leakproof=False, language='sql', quote_tag=''): + on = "postgresql" + + def __init__( + self, + name: str, + arguments: t.Iterable[str], + rtype: str, + definition: str | Selectable, + volatility: t.Literal["volatile", "stable", "immutable"] = "volatile", + strict: bool = False, + leakproof: bool = False, + language: str = "sql", + quote_tag: str = "", + ): """ Represents PostgreSQL function - :param str name: Name of the function (excluding arguments). - :param list arguments: Arguments of the function. A function + :param name: Name of the function (excluding arguments). + :param arguments: Arguments of the function. A function identifier of ``new_function(integer, integer)`` would result in ``arguments=['integer', 'integer']``. - :param str rtype: Return type + :param rtype: Return type :param definition: Definition - :param str volatility: Either 'volatile', 'stable', or + :param volatility: Either 'volatile', 'stable', or 'immutable' - :param bool strict: Function should be declared STRICT - :param bool leakproof: Function should be declared LEAKPROOF + :param strict: Function should be declared STRICT + :param leakproof: Function should be declared LEAKPROOF :param str language: Language the function is defined in :param str quote_tag: Dollar quote tag to enclose the function definition @@ -96,26 +119,29 @@ def __init__(self, name, arguments, rtype, definition: str | Selectable, self.quote_tag = quote_tag @cached_property - def definition(self): + def definition(self) -> str: if isinstance(self._definition, str): return inspect.cleandoc(self._definition) if isinstance(self._definition, Selectable): - return str(self._definition.compile( - dialect=postgresql.dialect(), - compile_kwargs={'literal_binds': True} - )) + return str( + self._definition.compile( + dialect=t.cast(type[Dialect], postgresql.dialect)(), + compile_kwargs={"literal_binds": True}, + ) + ) - raise ValueError(f"definition must be str or Selectable, not {type(self._definition)}") + raise ValueError( + f"definition must be str or Selectable, not {type(self._definition)}" + ) - def build_quoted_identifier(self, quoter): + def build_quoted_identifier(self, quoter: t.Callable[[str], str]) -> str: """Compile the function identifier from name and arguments. :param quoter: A callable that quotes the function name :returns: The compiled string, like ``"my_function_name"(integer, account_type)`` - :rtype: str """ return "{name}({args})".format( name=quoter(self.name), @@ -123,14 +149,14 @@ def build_quoted_identifier(self, quoter): ) - class CreateFunction(schema.DDLElement): """ Represents a CREATE FUNCTION DDL statement """ - on = 'postgresql' - def __init__(self, func, or_replace=False): + on = "postgresql" + + def __init__(self, func: Function, or_replace: bool = False): self.function = func self.or_replace = or_replace @@ -139,17 +165,20 @@ class DropFunction(schema.DDLElement): """ Represents a DROP FUNCTION DDL statement """ - on = 'postgresql' - def __init__(self, func, if_exists=False, cascade=False): + on = "postgresql" + + def __init__(self, func: Function, if_exists: bool = False, cascade: bool = False): self.function = func self.if_exists = if_exists self.cascade = cascade # noinspection PyUnusedLocal -@compiles(CreateFunction, 'postgresql') -def visit_create_function(element, compiler, **kw): +@compiles(CreateFunction, "postgresql") +def visit_create_function( + element: CreateFunction, compiler: Compiled, **kw: t.Any +) -> str: """ Compile a CREATE FUNCTION DDL statement for PostgreSQL """ @@ -170,8 +199,8 @@ def visit_create_function(element, compiler, **kw): # noinspection PyUnusedLocal -@compiles(DropFunction, 'postgresql') -def visit_drop_function(element, compiler, **kw): +@compiles(DropFunction, "postgresql") +def visit_drop_function(element: DropFunction, compiler: Compiled, **kw: t.Any) -> str: """ Compile a DROP FUNCTION DDL statement for PostgreSQL """ @@ -183,17 +212,26 @@ def visit_drop_function(element, compiler, **kw): class Rule(schema.DDLElement): - on = 'postgresql' - - def __init__(self, name, table, event, command_or_commands, - condition=None, do_instead=False): + on = "postgresql" + + def __init__( + self, + name: str, + table: Table, + event: str, + command_or_commands: str | t.Sequence[str], + condition: str | None = None, + do_instead: bool = False, + ) -> None: self.name = name self.table = table self.event = event self.condition = condition self.do_instead = do_instead - if (isinstance(command_or_commands, Iterable) and - not isinstance(command_or_commands, str)): + self.commands: tuple[str, ...] + if isinstance(command_or_commands, Iterable) and not isinstance( + command_or_commands, str + ): self.commands = tuple(command_or_commands) else: self.commands = (command_or_commands,) @@ -203,9 +241,10 @@ class CreateRule(schema.DDLElement): """ Represents a CREATE RULE DDL statement """ - on = 'postgresql' - def __init__(self, rule, or_replace=False): + on = "postgresql" + + def __init__(self, rule: Rule, or_replace: bool = False) -> None: self.rule = rule self.or_replace = or_replace @@ -214,9 +253,12 @@ class DropRule(schema.DDLElement): """ Represents a DROP RULE DDL statement """ - on = 'postgresql' - def __init__(self, rule, if_exists=False, cascade=False): + on = "postgresql" + + def __init__( + self, rule: Rule, if_exists: bool = False, cascade: bool = False + ) -> None: """ :param rule: :param if_exists: @@ -228,8 +270,8 @@ def __init__(self, rule, if_exists=False, cascade=False): # noinspection PyUnusedLocal -@compiles(CreateRule, 'postgresql') -def visit_create_rule(element, compiler, **kw): +@compiles(CreateRule, "postgresql") +def visit_create_rule(element: CreateRule, compiler: Compiled, **kw: t.Any) -> str: """ Compile a CREATE RULE DDL statement for PostgreSQL. """ @@ -252,8 +294,8 @@ def visit_create_rule(element, compiler, **kw): # noinspection PyUnusedLocal -@compiles(DropRule, 'postgresql') -def visit_drop_rule(element, compiler, **kw): +@compiles(DropRule, "postgresql") +def visit_drop_rule(element: DropRule, compiler: Compiled, **kw: t.Any) -> str: """ Compile a DROP RULE DDL statement for PostgreSQL """ @@ -267,15 +309,23 @@ def visit_drop_rule(element, compiler, **kw): opt_drop_behavior) +# TODO add type hints class Trigger(schema.DDLElement): - def __init__(self, name, table, events, function_call, when="AFTER"): + def __init__( + self, + name: str, + table: Table, + events: t.Sequence[str], + function_call: str, + when: t.Literal["BEFORE", "AFTER", "INSTEAD OF"] = "AFTER", + ) -> None: """Construct a trigger - :param str name: Name of the trigger + :param name: Name of the trigger :param table: Table the trigger is for - :param iterable[str] events: list of events (INSERT, UPDATE, DELETE) - :param str function_call: call of the trigger function - :param str when: Mode of execution. Must be one of ``BEFORE``, ``AFTER``, ``INSTEAD OF`` + :param events: list of events (INSERT, UPDATE, DELETE) + :param function_call: call of the trigger function + :param when: Mode of execution """ self.name = name self.table = table @@ -287,7 +337,13 @@ def __init__(self, name, table, events, function_call, when="AFTER"): class ConstraintTrigger(Trigger): - def __init__(self, *args, deferrable=False, initially_deferred=False, **kwargs): + def __init__( + self, + *args: t.Any, + deferrable: bool = False, + initially_deferred: bool = False, + **kwargs: t.Any, + ) -> None: """Construct a Constraint Trigger :param deferrable: Constraint can be deferred @@ -302,9 +358,9 @@ def __init__(self, *args, deferrable=False, initially_deferred=False, **kwargs): class CreateTrigger(schema.DDLElement): - on = 'postgresql' + on = "postgresql" - def __init__(self, trigger): + def __init__(self, trigger: Trigger) -> None: self.trigger = trigger @@ -312,9 +368,10 @@ class CreateConstraintTrigger(schema.DDLElement): """ Represents a CREATE CONSTRAINT TRIGGER DDL statement """ - on = 'postgresql' - def __init__(self, constraint_trigger): + on = "postgresql" + + def __init__(self, constraint_trigger: ConstraintTrigger) -> None: self.constraint_trigger = constraint_trigger @@ -322,17 +379,22 @@ class DropTrigger(schema.DDLElement): """ Represents a DROP TRIGGER DDL statement. """ - on = 'postgresql' - def __init__(self, trigger, if_exists=False, cascade=False): + on = "postgresql" + + def __init__( + self, trigger: Trigger, if_exists: bool = False, cascade: bool = False + ) -> None: self.trigger = trigger self.if_exists = if_exists self.cascade = cascade # noinspection PyUnusedLocal -@compiles(CreateConstraintTrigger, 'postgresql') -def create_add_constraint_trigger(element, compiler, **kw): +@compiles(CreateConstraintTrigger, "postgresql") +def create_add_constraint_trigger( + element: CreateConstraintTrigger, compiler: Compiled, **kw: t.Any +) -> str: """ Compile a CREATE CONSTRAINT TRIGGER DDL statement for PostgreSQL """ @@ -350,8 +412,8 @@ def create_add_constraint_trigger(element, compiler, **kw): # noinspection PyUnusedLocal -@compiles(CreateTrigger, 'postgresql') -def create_add_trigger(element, compiler, **kw): +@compiles(CreateTrigger, "postgresql") +def create_add_trigger(element: CreateTrigger, compiler: Compiled, **kw: t.Any) -> str: """ Compile a CREATE CONSTRAINT TRIGGER DDL statement for PostgreSQL """ @@ -365,8 +427,8 @@ def create_add_trigger(element, compiler, **kw): # noinspection PyUnusedLocal -@compiles(DropTrigger, 'postgresql') -def visit_drop_trigger(element, compiler, **kw): +@compiles(DropTrigger, "postgresql") +def visit_drop_trigger(element: DropTrigger, compiler: Compiled, **kw: t.Any) -> str: """ Compile a DROP TRIGGER DDL statement for PostgreSQL """ @@ -381,12 +443,16 @@ def visit_drop_trigger(element, compiler, **kw): class View(schema.DDLElement): - def __init__(self, name, query, - column_names=None, - temporary=False, - view_options=None, - check_option=None, - materialized=False): + def __init__( + self, + name: str, + query: SelectBase, + column_names: t.Sequence[str] = None, + temporary: bool = False, + view_options: t.Mapping[str, t.Any] = None, + check_option: t.Literal["local", "cascaded"] | None = None, + materialized: bool = False, + ) -> None: """DDL Element representing a VIEW :param name: The name of the view @@ -428,12 +494,11 @@ def _init_table_columns(self): "The given column_names must coincide with the implicit columns of the query:" f" {my_column_names!r} != {query_column_names!r}" ) - for c in self.query.selected_columns: + for c in t.cast(ColumnCollection, self.query.selected_columns): # _make_proxy doesn't attach the column to the selectable (`self.table`) anymore # since sqla commit:aceefb508ccd0911f52ff0e50324b3fefeaa3f16 (before 1.4.0) key, col = c._make_proxy(self.table) - self.table._columns.add(col, key=key) - + self.table._columns.add(col, key=key) # type: ignore @with_transaction def refresh(self, concurrently=False): @@ -442,29 +507,35 @@ def refresh(self, concurrently=False): if not self.materialized: raise ValueError("Cannot refresh a non-materialized view") - _con = 'CONCURRENTLY ' if concurrently else '' - session.execute('REFRESH MATERIALIZED VIEW ' + _con + self.name) + _con = "CONCURRENTLY " if concurrently else "" + session.execute("REFRESH MATERIALIZED VIEW " + _con + self.name) # type: ignore class CreateView(schema.DDLElement): - def __init__(self, view, or_replace=False, if_not_exists=False): + def __init__( + self, view: View, or_replace: bool = False, if_not_exists: bool = False + ) -> None: self.view = view self.or_replace = or_replace self.if_not_exists = if_not_exists class DropView(schema.DDLElement): - def __init__(self, view, if_exists=False, cascade=False): + def __init__( + self, view: View, if_exists: bool = False, cascade: bool = False + ) -> None: self.view = view self.if_exists = if_exists self.cascade = cascade # noinspection PyUnusedLocal -@compiles(CreateView, 'postgresql') -def visit_create_view(element: CreateView, compiler, **kw): +@compiles(CreateView, "postgresql") +def visit_create_view(element: CreateView, compiler: Compiled, **kw: t.Any) -> str: view = element.view - opt_or_replace = "OR REPLACE" if element.or_replace and not view.materialized else None + opt_or_replace = ( + "OR REPLACE" if element.or_replace and not view.materialized else None + ) opt_temporary = "TEMPORARY" if view.temporary else None if view.column_names is not None: quoted_column_names = map(compiler.preparer.quote, view.column_names) @@ -496,8 +567,8 @@ def visit_create_view(element: CreateView, compiler, **kw): # noinspection PyUnusedLocal -@compiles(DropView, 'postgresql') -def visit_drop_view(element, compiler, **kw): +@compiles(DropView, "postgresql") +def visit_drop_view(element: DropView, compiler: Compiled, **kw: t.Any) -> str: view = element.view opt_if_exists = "IF EXISTS" if element.if_exists else None opt_drop_behavior = "CASCADE" if element.cascade else None @@ -527,39 +598,89 @@ class DDLManager: """ def __init__(self): - self.objects = [] - - def add(self, target, create_ddl, drop_ddl, dialect=None): + self.objects: list[tuple[object, schema.DDLElement, schema.DDLElement]] = [] + + def add( + self, + target: Table, + create_ddl: schema.DDLElement, + drop_ddl: schema.DDLElement, + dialect: str | None = None, + ): if dialect: - create_ddl = create_ddl.execute_if(dialect=dialect) - drop_ddl = drop_ddl.execute_if(dialect=dialect) + create_ddl = t.cast( + schema.DDLElement, create_ddl.execute_if(dialect=dialect) + ) + drop_ddl = t.cast(schema.DDLElement, drop_ddl.execute_if(dialect=dialect)) self.objects.append((target, create_ddl, drop_ddl)) - def add_constraint(self, table, constraint, dialect=None): - self.add(table, schema.AddConstraint(constraint), - DropConstraint(constraint, if_exists=True), dialect=dialect) + def add_constraint( + self, table: Table, constraint: Constraint, dialect: str | None = None + ) -> None: + self.add( + table, + schema.AddConstraint(constraint), + DropConstraint(constraint, if_exists=True), + dialect=dialect, + ) - def add_function(self, table, func, dialect=None): - self.add(table, CreateFunction(func, or_replace=True), - DropFunction(func, if_exists=True), dialect=dialect) + def add_function( + self, table: Table, func: Function, dialect: str | None = None + ) -> None: + self.add( + table, + CreateFunction(func, or_replace=True), + DropFunction(func, if_exists=True), + dialect=dialect, + ) - def add_rule(self, table, rule, dialect=None): - self.add(table, CreateRule(rule, or_replace=True), - DropRule(rule, if_exists=True), dialect=dialect) + def add_rule(self, table: Table, rule: Rule, dialect: str | None = None) -> None: + self.add( + table, + CreateRule(rule, or_replace=True), + DropRule(rule, if_exists=True), + dialect=dialect, + ) - def add_trigger(self, table, trigger, dialect=None): - self.add(table, CreateTrigger(trigger), - DropTrigger(trigger, if_exists=True), dialect=dialect) + def add_trigger( + self, table: Table, trigger: Trigger, dialect: str | None = None + ) -> None: + self.add( + table, + CreateTrigger(trigger), + DropTrigger(trigger, if_exists=True), + dialect=dialect, + ) - def add_constraint_trigger(self, table, constraint_trigger, dialect=None): - self.add(table, CreateConstraintTrigger(constraint_trigger), - DropTrigger(constraint_trigger, if_exists=True), dialect=dialect) + def add_constraint_trigger( + self, + table: Table, + constraint_trigger: ConstraintTrigger, + dialect: str | None = None, + ) -> None: + self.add( + table, + CreateConstraintTrigger(constraint_trigger), + DropTrigger(constraint_trigger, if_exists=True), + dialect=dialect, + ) - def add_view(self, table, view, dialect=None, or_replace=True, if_not_exists=True): - self.add(table, CreateView(view, or_replace=or_replace, if_not_exists=if_not_exists), - DropView(view, if_exists=True), dialect=dialect) + def add_view( + self, + table: Table, + view: View, + dialect: str | None = None, + or_replace: bool = True, + if_not_exists: bool = True, + ) -> None: + self.add( + table, + CreateView(view, or_replace=or_replace, if_not_exists=if_not_exists), + DropView(view, if_exists=True), + dialect=dialect, + ) - def register(self): + def register(self) -> None: for target, create_ddl, _drop_ddl in self.objects: sqla_event.listen(target, 'after_create', create_ddl) for target, _create_ddl, drop_ddl in reversed(self.objects):