Skip to content

Commit

Permalink
[typing] fix remaining web admonitions of mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
lukasjuhrich committed Sep 5, 2023
1 parent 579e2b0 commit f89cf8f
Show file tree
Hide file tree
Showing 12 changed files with 128 additions and 88 deletions.
6 changes: 3 additions & 3 deletions pycroft/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from datetime import timezone, tzinfo

import psycopg2.extensions
from sqlalchemy import create_engine as sqa_create_engine
from sqlalchemy import create_engine as sqa_create_engine, Connection
from sqlalchemy.future import Engine

from . import _all
Expand Down Expand Up @@ -61,7 +61,7 @@ def create_engine(connection_string, **kwargs) -> Engine:
return sqa_create_engine(connection_string, **kwargs)


def create_db_model(bind):
def create_db_model(bind: Connection) -> None:
"""Create all models in the database.
"""
# skip objects marked with "is_view"
Expand All @@ -70,7 +70,7 @@ def create_db_model(bind):
base.ModelBase.metadata.create_all(bind, tables=tables)


def drop_db_model(bind):
def drop_db_model(bind: Connection) -> None:
"""Drop all models from the database.
"""
# skip objects marked with "is_view"
Expand Down
5 changes: 5 additions & 0 deletions pycroft/model/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from typing import overload, TypeVar, Callable, Any, TYPE_CHECKING

from sqlalchemy.orm import scoped_session
from sqlalchemy.sql.functions import AnsiFunction
from werkzeug.local import LocalProxy
import wrapt

Expand Down Expand Up @@ -77,3 +78,7 @@ def with_transaction(wrapped, instance, args, kwargs):

def utcnow() -> DateTimeTz:
return session.query(func.current_timestamp()).scalar()


def current_timestamp() -> AnsiFunction[DateTimeTz]:
return t.cast(AnsiFunction[DateTimeTz], func.current_timestamp())
5 changes: 3 additions & 2 deletions pycroft/model/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,14 +185,15 @@ def validate_passwd_hash(self, _, value):
"not correct!"
return value

def check_password(self, plaintext_password):
def check_password(self, plaintext_password: str) -> bool:
"""verify a given plaintext password against the users passwd hash.
"""
return verify_password(plaintext_password, self.passwd_hash)

@property
def password(self):
# actually `NoReturn`, but mismatch to `setter` confuses mypy
def password(self) -> str:
"""Store a hash of a given plaintext passwd for the user."""
raise RuntimeError("Password can not be read, only set")

Expand Down
2 changes: 1 addition & 1 deletion stubs/flask_restful/reqparse.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class RequestParser:
trim: Incomplete
bundle_errors: Incomplete
def __init__(self, argument_class=..., namespace_class=..., trim: bool = ..., bundle_errors: bool = ...) -> None: ...
def add_argument(self, *args, **kwargs): ...
def add_argument(self, *args, **kwargs) -> RequestParser: ...
def parse_args(self, req: Incomplete | None = ..., strict: bool = ..., http_error_code: int = ...): ...
def copy(self): ...
def replace_argument(self, name, *args, **kwargs): ...
Expand Down
9 changes: 4 additions & 5 deletions web/api/v0/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from flask_restful import Api, Resource as FlaskRestfulResource, abort, \
reqparse, inputs
from sqlalchemy.exc import IntegrityError
from sqlalchemy import select, func
from sqlalchemy import select
from sqlalchemy.orm import joinedload, selectinload

from pycroft.helpers import utc
Expand All @@ -33,6 +33,7 @@
from pycroft.model.facilities import Room
from pycroft.model.finance import Account, Split
from pycroft.model.host import IP, Interface, Host
from pycroft.model.session import current_timestamp
from pycroft.model.types import IPAddress, InvalidMACAddressException
from pycroft.model.user import User, IllegalEmailError, IllegalLoginError
from web.api.v0.helpers import parse_iso_date
Expand Down Expand Up @@ -105,10 +106,8 @@ def generate_user_data(user: User) -> Response:
step = timedelta(days=1)
traffic_history = func_traffic_history(
user.id,
# TODO what is the emitted sql statement?
# it seems to me that this expression returns `timestamp`, and not `timestamptz`
func.current_timestamp() - interval + step,
func.current_timestamp(),
current_timestamp() - interval + step,
current_timestamp(),
)

class _Entry(t.TypedDict):
Expand Down
2 changes: 1 addition & 1 deletion web/blueprints/finance/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,7 +736,7 @@ def balance_json(account_id: int) -> ResponseReturnValue:

sum_exp: ColumnElement[int] = t.cast(
Over[int],
func.sum(Split.amount).over(order_by=Transaction.valid_on),
func.sum(Split.amount).over(order_by=Transaction.valid_on), # type: ignore[no-untyped-call]
)

if invert:
Expand Down
1 change: 0 additions & 1 deletion web/blueprints/user/tables.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import typing as t
import typing

from flask import url_for
from pydantic import BaseModel
Expand Down
13 changes: 8 additions & 5 deletions web/commands.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) 2021. The Pycroft Authors. See the AUTHORS file.
# This file is part of the Pycroft project and licensed under the terms of
# the Apache License, Version 2.0. See the LICENSE file for details
import typing as t
import os

import click
Expand All @@ -10,17 +11,19 @@
from pycroft.model import create_engine, drop_db_model


def register_commands(app: Flask):
def register_commands(app: Flask) -> None:
"""Register custom commands executable via `flask $command_name`."""

@app.cli.command('create-model', help="Create the database model.")
def create_model():
cli = t.cast(click.Group, app.cli)

@cli.command("create-model", help="Create the database model.")
def create_model() -> None:
engine = create_engine(os.getenv('PYCROFT_DB_URI'))
with engine.begin() as connection:
create_db_model(bind=connection)

@app.cli.command('drop-model', help="Drop the database model.")
def drop_model():
@cli.command("drop-model", help="Drop the database model.")
def drop_model() -> None:
engine = create_engine(os.getenv('PYCROFT_DB_URI'))
click.confirm(f'This will drop the whole database schema associated to {engine!r}.'
' Are you absolutely sure?', abort=True)
Expand Down
7 changes: 5 additions & 2 deletions web/form/widgets.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import typing as t


import wtforms_widgets
from wtforms import ValidationError, Form

Expand All @@ -8,10 +11,10 @@
class UserIDField(wtforms_widgets.fields.core.StringField):
"""A User-ID Field """

def __init__(self, *args, **kwargs) -> None:
def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
super().__init__(*args, **kwargs)

def __call__(self, **kwargs) -> None:
def __call__(self, **kwargs: t.Any) -> t.Any:
return super().__call__(
**kwargs
)
Expand Down
84 changes: 48 additions & 36 deletions web/template_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,31 @@
:copyright: (c) 2012 by AG DSN.
"""
import typing as t
import pathlib
from cmath import log
from datetime import datetime, date
from decimal import Decimal
from itertools import chain
from re import sub

import flask_babel
from flask import current_app, json, url_for
from flask import current_app, json, url_for, Flask
from jinja2 import pass_context
from jinja2.runtime import Context

from pycroft.helpers.i18n import localized, gettext
from pycroft.helpers.utc import ensure_tz
from pycroft.model import session

_filter_registry = {}


def template_filter(name):
def decorator(fn):
_TF = t.TypeVar("_TF", bound=t.Callable[..., t.Any])


def template_filter(name: str) -> t.Callable[[_TF], _TF]:
def decorator(fn: _TF) -> _TF:
_filter_registry[name] = fn
return fn
return decorator
Expand All @@ -48,7 +55,7 @@ class AssetNotFound(Exception):
# noinspection PyUnusedLocal
@template_filter("require")
@pass_context
def require(ctx: Context, asset: str, **kwargs) -> str:
def require(ctx: Context, asset: str, **kwargs: t.Any) -> str:
"""
Build an URL for an asset generated by webpack.
Expand All @@ -74,7 +81,7 @@ def require(ctx: Context, asset: str, **kwargs) -> str:
with path.open() as f:
asset_map = json.load(f)

def has_changed():
def has_changed() -> bool:
try:
return path.stat().st_mtime != mtime
except OSError:
Expand All @@ -86,47 +93,45 @@ def has_changed():
filename = asset_map[asset]
except KeyError:
raise AssetNotFound(f"Asset {asset} not found") from None
kwargs['filename'] = filename
return url_for('static', **kwargs)
kwargs["filename"] = filename
return url_for("static", **kwargs)


@template_filter("pretty_category")
def pretty_category_filter(category):
def pretty_category_filter(category: str) -> str:
"""Make pretty category names for flash messages, etc
"""
return _category_map.get(category, "Hinweis")


@template_filter("date")
def date_filter(dt, format=None):
def date_filter(dt: datetime | date | None, format: str | None = None) -> str:
"""Format date or datetime objects using Flask-Babel
:param datetime|date|None dt: a datetime object or None
:param str format: format as understood by Flask-Babel's format_datetime
:rtype: unicode
:param dt: a datetime object or None
:param format: format as understood by Flask-Babel's format_datetime
"""
if dt is None:
return "k/A"
return flask_babel.format_date(dt, format)
return t.cast(str, flask_babel.format_date(dt, format))


@template_filter("datetime")
def datetime_filter(dt, format=None):
def datetime_filter(dt: datetime | None, format: str | None = None) -> str:
"""Format datetime objects using Flask-Babel
:param datetime|None dt: a datetime object or None
:param str format: format as understood by Flask-Babel's format_datetime
:rtype: unicode
:param dt: a datetime object or None
:param format: format as understood by Flask-Babel's format_datetime
"""
if dt is None:
return "k/A"
if isinstance(dt, str):
return dt
return flask_babel.format_datetime(dt, format)
return t.cast(str, flask_babel.format_datetime(dt, format))


@template_filter("timesince")
def timesince_filter(dt, default="just now"):
def timesince_filter(dt: datetime | None, default: str = "just now") -> str:
"""
Returns string representing "time since" e.g.
3 days ago, 5 hours ago etc.
Expand All @@ -138,7 +143,7 @@ def timesince_filter(dt, default="just now"):
return "k/A"

now = session.utcnow()
diff = now - dt
diff = now - ensure_tz(dt)

periods = (
(diff.days / 365, "Jahr", "Jahre"),
Expand All @@ -149,16 +154,18 @@ def timesince_filter(dt, default="just now"):
(diff.seconds / 60, "Minute", "Minuten"),
(diff.seconds, "Sekunde", "Sekunden"),
)

for period, singular, plural in periods:

if period:
return f"vor {period:d} {singular if period == 1 else plural}"

return default
return next(
(
f"vor {period:d} {singular if period == 1 else plural}"
for period, singular, plural in periods
),
default,
)


def prefix_unit_filter(value, unit, factor, prefixes):
def prefix_unit_filter(
value: float | Decimal, unit: str, factor: int, prefixes: t.Iterable[str]
) -> str:
units = list(chain(unit, (p + unit for p in prefixes)))
if value > 0:
n = min(int(log(value, factor).real), len(units)-1)
Expand All @@ -169,17 +176,17 @@ def prefix_unit_filter(value, unit, factor, prefixes):


@template_filter("byte_size")
def byte_size_filter(value):
def byte_size_filter(value: float | Decimal) -> str:
return prefix_unit_filter(value, 'B', 1024, ['Ki', 'Mi', 'Gi', 'Ti'])


@template_filter("money")
def money_filter(amount):
def money_filter(amount: float | Decimal) -> str:
return (f"{amount:.2f}\u202f€").replace('.', ',')


@template_filter("icon")
def icon_filter(icon_class: str):
def icon_filter(icon_class: str) -> str:
if len(tokens := icon_class.split(maxsplit=1)) == 2:
prefix, icon = tokens
else:
Expand All @@ -189,7 +196,7 @@ def icon_filter(icon_class: str):


@template_filter("account_type")
def account_type_filter(account_type):
def account_type_filter(account_type: str) -> str:
types = {
"USER_ASSET": gettext("User account (asset)"),
"BANK_ASSET": gettext("Bank account (asset)"),
Expand All @@ -204,9 +211,14 @@ def account_type_filter(account_type):


@template_filter("transaction_type")
def transaction_type_filter(credit_debit_type):
def replacer(types):
return types and tuple(sub(r'[A-Z]+_(?=ASSET)', r'', t) for t in types)
def transaction_type_filter(credit_debit_type: tuple[str, str]) -> str:
def remove_prefix(account_type_name: str) -> str:
return sub(r"[A-Z]+_(?=ASSET)", r"", account_type_name)

def replacer(types: tuple[str, str]) -> tuple[str, str] | None:
if not types:
return None
return (remove_prefix(types[0]), remove_prefix(types[1]))

types = {
("ASSET", "LIABILITY"): gettext("Balance sheet extension"),
Expand Down Expand Up @@ -246,6 +258,6 @@ def host_traffic_filter(host):
"""


def register_filters(app):
def register_filters(app: Flask) -> None:
for name in _filter_registry:
app.jinja_env.filters[name] = _filter_registry[name]
Loading

0 comments on commit f89cf8f

Please sign in to comment.