From 579e2b058c1b5a10bf1602925618addb023493bf Mon Sep 17 00:00:00 2001 From: Lukas Juhrich Date: Tue, 5 Sep 2023 01:49:44 +0200 Subject: [PATCH] [typing] add typing to `web.api` and dependent functions --- pycroft/lib/user.py | 6 +- pycroft/model/finance.py | 12 ++- pycroft/model/traffic.py | 6 +- stubs/flask_restful/__init__.pyi | 54 ++++++++++ stubs/flask_restful/__version__.pyi | 0 stubs/flask_restful/fields.pyi | 62 +++++++++++ stubs/flask_restful/inputs.pyi | 31 ++++++ .../representations/__init__.pyi | 0 stubs/flask_restful/representations/json.pyi | 4 + stubs/flask_restful/reqparse.pyi | 42 ++++++++ stubs/flask_restful/utils/__init__.pyi | 6 ++ stubs/flask_restful/utils/cors.pyi | 3 + stubs/flask_restful/utils/crypto.pyi | 2 + web/api/__init__.py | 4 +- web/api/v0/__init__.py | 100 +++++++++++------- web/api/v0/helpers.py | 4 +- web/blueprints/finance/__init__.py | 2 +- 17 files changed, 286 insertions(+), 52 deletions(-) create mode 100644 stubs/flask_restful/__init__.pyi create mode 100644 stubs/flask_restful/__version__.pyi create mode 100644 stubs/flask_restful/fields.pyi create mode 100644 stubs/flask_restful/inputs.pyi create mode 100644 stubs/flask_restful/representations/__init__.pyi create mode 100644 stubs/flask_restful/representations/json.pyi create mode 100644 stubs/flask_restful/reqparse.pyi create mode 100644 stubs/flask_restful/utils/__init__.pyi create mode 100644 stubs/flask_restful/utils/cors.pyi create mode 100644 stubs/flask_restful/utils/crypto.pyi diff --git a/pycroft/lib/user.py b/pycroft/lib/user.py index e8367cff8..b2d1cd22d 100644 --- a/pycroft/lib/user.py +++ b/pycroft/lib/user.py @@ -17,7 +17,7 @@ from difflib import SequenceMatcher from typing import Iterable -from sqlalchemy import func, select, Boolean, String +from sqlalchemy import func, select, Boolean, String, ColumnElement from pycroft import config, property from pycroft.helpers import user as user_helper, utc @@ -681,7 +681,9 @@ def edit_address( def traffic_history( - user_id: int, start: DateTimeTz, end: DateTimeTz + user_id: int, + start: DateTimeTz | ColumnElement[DateTimeTz], + end: DateTimeTz | ColumnElement[DateTimeTz], ) -> list[TrafficHistoryEntry]: result = session.session.execute( select("*") diff --git a/pycroft/model/finance.py b/pycroft/model/finance.py index 5d236fc89..e1bb95462 100644 --- a/pycroft/model/finance.py +++ b/pycroft/model/finance.py @@ -9,9 +9,10 @@ import datetime import typing as t from datetime import timedelta, date +from decimal import Decimal from math import fabs -from sqlalchemy import ForeignKey, event, func, select, Enum, ColumnElement +from sqlalchemy import ForeignKey, event, func, select, Enum, ColumnElement, Select from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import relationship, object_session, Mapped, mapped_column from sqlalchemy.schema import CheckConstraint, ForeignKeyConstraint, UniqueConstraint @@ -25,6 +26,7 @@ from .exc import PycroftModelException from .type_aliases import str127, str255, datetime_tz_onupdate from ..helpers import utc +from ..helpers.utc import DateTimeTz manager = ddl.DDLManager() @@ -361,14 +363,14 @@ class BankAccount(IntegerIdModel): # /backrefs @hybrid_property - def balance(self): + def _balance(self) -> Decimal: return object_session(self).execute( select(func.coalesce(func.sum(BankAccountActivity.amount), 0)) .where(BankAccountActivity.bank_account_id == self.id) ).scalar() - @balance.expression - def balance(cls): + @_balance.expression + def balance(cls) -> Select[tuple[Decimal]]: return select( [func.coalesce(func.sum(BankAccountActivity.amount), 0)] ).where( @@ -376,7 +378,7 @@ def balance(cls): ).label("balance") @hybrid_property - def last_imported_at(self): + def last_imported_at(self) -> DateTimeTz: return object_session(self).execute( select(func.max(BankAccountActivity.imported_at)) .where(BankAccountActivity.bank_account_id == self.id) diff --git a/pycroft/model/traffic.py b/pycroft/model/traffic.py index eeb473aac..bd6ff1c38 100644 --- a/pycroft/model/traffic.py +++ b/pycroft/model/traffic.py @@ -19,6 +19,7 @@ select, cast, TEXT, + ColumnElement, ) from sqlalchemy.orm import relationship, Query, Mapped, mapped_column from sqlalchemy.sql.selectable import TableValuedAlias @@ -224,7 +225,10 @@ def traffic_history_query(): def traffic_history( - user_id: int, start: utc.DateTimeTz, end: utc.DateTimeTz, name='traffic_history' + user_id: int, + start: utc.DateTimeTz | ColumnElement[utc.DateTimeTz], + end: utc.DateTimeTz | ColumnElement[utc.DateTimeTz], + name="traffic_history", ) -> TableValuedAlias: """A sqlalchemy `func` wrapper for the `evaluate_properties` PSQL function. diff --git a/stubs/flask_restful/__init__.pyi b/stubs/flask_restful/__init__.pyi new file mode 100644 index 000000000..b5ebed45e --- /dev/null +++ b/stubs/flask_restful/__init__.pyi @@ -0,0 +1,54 @@ +from _typeshed import Incomplete + +from flask import Response +from flask.views import MethodView + +def abort(http_status_code, **kwargs) -> None: ... + +class Api: + representations: Incomplete + urls: Incomplete + prefix: Incomplete + default_mediatype: Incomplete + decorators: Incomplete + catch_all_404s: Incomplete + serve_challenge_on_401: Incomplete + url_part_order: Incomplete + errors: Incomplete + blueprint_setup: Incomplete + endpoints: Incomplete + resources: Incomplete + app: Incomplete + blueprint: Incomplete + def __init__(self, app: Incomplete | None = ..., prefix: str = ..., default_mediatype: str = ..., decorators: Incomplete | None = ..., catch_all_404s: bool = ..., serve_challenge_on_401: bool = ..., url_part_order: str = ..., errors: Incomplete | None = ...) -> None: ... + def init_app(self, app) -> None: ... + def owns_endpoint(self, endpoint): ... + def error_router(self, original_handler, e): ... + def handle_error(self, e: Exception) -> Response: ... + def mediatypes_method(self): ... + def add_resource(self, resource, *urls, **kwargs) -> None: ... + def resource(self, *urls, **kwargs): ... + def output(self, resource): ... + def url_for(self, resource, **values): ... + def make_response(self, data, *args, **kwargs): ... + def mediatypes(self): ... + def representation(self, mediatype): ... + def unauthorized(self, response): ... + +class Resource(MethodView): + representations: Incomplete + method_decorators: Incomplete + def dispatch_request(self, *args, **kwargs): ... + +def marshal(data, fields, envelope: Incomplete | None = ...): ... + +class marshal_with: + fields: Incomplete + envelope: Incomplete + def __init__(self, fields, envelope: Incomplete | None = ...) -> None: ... + def __call__(self, f): ... + +class marshal_with_field: + field: Incomplete + def __init__(self, field) -> None: ... + def __call__(self, f): ... diff --git a/stubs/flask_restful/__version__.pyi b/stubs/flask_restful/__version__.pyi new file mode 100644 index 000000000..e69de29bb diff --git a/stubs/flask_restful/fields.pyi b/stubs/flask_restful/fields.pyi new file mode 100644 index 000000000..5bc2a39f7 --- /dev/null +++ b/stubs/flask_restful/fields.pyi @@ -0,0 +1,62 @@ +from _typeshed import Incomplete + +class MarshallingException(Exception): + def __init__(self, underlying_exception) -> None: ... + +class Raw: + attribute: Incomplete + default: Incomplete + def __init__(self, default: Incomplete | None = ..., attribute: Incomplete | None = ...) -> None: ... + def format(self, value): ... + def output(self, key, obj): ... + +class Nested(Raw): + nested: Incomplete + allow_null: Incomplete + def __init__(self, nested, allow_null: bool = ..., **kwargs) -> None: ... + def output(self, key, obj): ... + +class List(Raw): + container: Incomplete + def __init__(self, cls_or_instance, **kwargs) -> None: ... + def format(self, value): ... + def output(self, key, data): ... + +class String(Raw): + def format(self, value): ... + +class Integer(Raw): + def __init__(self, default: int = ..., **kwargs) -> None: ... + def format(self, value): ... + +class Boolean(Raw): + def format(self, value): ... + +class FormattedString(Raw): + src_str: Incomplete + def __init__(self, src_str) -> None: ... + def output(self, key, obj): ... + +class Url(Raw): + endpoint: Incomplete + absolute: Incomplete + scheme: Incomplete + def __init__(self, endpoint: Incomplete | None = ..., absolute: bool = ..., scheme: Incomplete | None = ..., **kwargs) -> None: ... + def output(self, key, obj): ... + +class Float(Raw): + def format(self, value): ... + +class Arbitrary(Raw): + def format(self, value): ... + +class DateTime(Raw): + dt_format: Incomplete + def __init__(self, dt_format: str = ..., **kwargs) -> None: ... + def format(self, value): ... + +class Fixed(Raw): + precision: Incomplete + def __init__(self, decimals: int = ..., **kwargs) -> None: ... + def format(self, value): ... +Price = Fixed diff --git a/stubs/flask_restful/inputs.pyi b/stubs/flask_restful/inputs.pyi new file mode 100644 index 000000000..1700eac47 --- /dev/null +++ b/stubs/flask_restful/inputs.pyi @@ -0,0 +1,31 @@ +from _typeshed import Incomplete +from calendar import timegm as timegm + +START_OF_DAY: Incomplete +END_OF_DAY: Incomplete +url_regex: Incomplete + +def url(value): ... + +class regex: + pattern: Incomplete + re: Incomplete + def __init__(self, pattern, flags: int = ...) -> None: ... + def __call__(self, value): ... + def __deepcopy__(self, memo): ... + +def iso8601interval(value, argument: str = ...): ... +def date(value): ... +def natural(value, argument: str = ...): ... +def positive(value, argument: str = ...): ... + +class int_range: + low: Incomplete + high: Incomplete + argument: Incomplete + def __init__(self, low, high, argument: str = ...) -> None: ... + def __call__(self, value): ... + +def boolean(value): ... +def datetime_from_rfc822(datetime_str): ... +def datetime_from_iso8601(datetime_str): ... diff --git a/stubs/flask_restful/representations/__init__.pyi b/stubs/flask_restful/representations/__init__.pyi new file mode 100644 index 000000000..e69de29bb diff --git a/stubs/flask_restful/representations/json.pyi b/stubs/flask_restful/representations/json.pyi new file mode 100644 index 000000000..9b660eaf5 --- /dev/null +++ b/stubs/flask_restful/representations/json.pyi @@ -0,0 +1,4 @@ +from _typeshed import Incomplete +from flask_restful.utils import PY3 as PY3 + +def output_json(data, code, headers: Incomplete | None = ...): ... diff --git a/stubs/flask_restful/reqparse.pyi b/stubs/flask_restful/reqparse.pyi new file mode 100644 index 000000000..492877e06 --- /dev/null +++ b/stubs/flask_restful/reqparse.pyi @@ -0,0 +1,42 @@ +from _typeshed import Incomplete + +class Namespace(dict): + def __getattr__(self, name): ... + def __setattr__(self, name, value) -> None: ... + +text_type: Incomplete + +class Argument: + name: Incomplete + default: Incomplete + dest: Incomplete + required: Incomplete + ignore: Incomplete + location: Incomplete + type: Incomplete + choices: Incomplete + action: Incomplete + help: Incomplete + case_sensitive: Incomplete + operators: Incomplete + store_missing: Incomplete + trim: Incomplete + nullable: Incomplete + def __init__(self, name, default: Incomplete | None = ..., dest: Incomplete | None = ..., required: bool = ..., ignore: bool = ..., type=..., location=..., choices=..., action: str = ..., help: Incomplete | None = ..., operators=..., case_sensitive: bool = ..., store_missing: bool = ..., trim: bool = ..., nullable: bool = ...) -> None: ... + def source(self, request): ... + def convert(self, value, op): ... + def handle_validation_error(self, error, bundle_errors): ... + def parse(self, request, bundle_errors: bool = ...): ... + +class RequestParser: + args: Incomplete + argument_class: Incomplete + namespace_class: Incomplete + trim: Incomplete + bundle_errors: Incomplete + def __init__(self, argument_class=..., namespace_class=..., trim: bool = ..., bundle_errors: bool = ...) -> None: ... + def add_argument(self, *args, **kwargs): ... + def parse_args(self, req: Incomplete | None = ..., strict: bool = ..., http_error_code: int = ...): ... + def copy(self): ... + def replace_argument(self, name, *args, **kwargs): ... + def remove_argument(self, name): ... diff --git a/stubs/flask_restful/utils/__init__.pyi b/stubs/flask_restful/utils/__init__.pyi new file mode 100644 index 000000000..3b8b4d99a --- /dev/null +++ b/stubs/flask_restful/utils/__init__.pyi @@ -0,0 +1,6 @@ +from _typeshed import Incomplete + +PY3: Incomplete + +def http_status_message(code): ... +def unpack(value): ... diff --git a/stubs/flask_restful/utils/cors.pyi b/stubs/flask_restful/utils/cors.pyi new file mode 100644 index 000000000..3fe55ef28 --- /dev/null +++ b/stubs/flask_restful/utils/cors.pyi @@ -0,0 +1,3 @@ +from _typeshed import Incomplete + +def crossdomain(origin: Incomplete | None = ..., methods: Incomplete | None = ..., headers: Incomplete | None = ..., expose_headers: Incomplete | None = ..., max_age: int = ..., attach_to_all: bool = ..., automatic_options: bool = ..., credentials: bool = ...): ... diff --git a/stubs/flask_restful/utils/crypto.pyi b/stubs/flask_restful/utils/crypto.pyi new file mode 100644 index 000000000..6a70664c2 --- /dev/null +++ b/stubs/flask_restful/utils/crypto.pyi @@ -0,0 +1,2 @@ +def encrypt(plaintext_data, key, seed): ... +def decrypt(encrypted_data, key, seed): ... diff --git a/web/api/__init__.py b/web/api/__init__.py index 4e1900bb7..7e22787b7 100644 --- a/web/api/__init__.py +++ b/web/api/__init__.py @@ -1,4 +1,5 @@ from flask import Blueprint, Flask +from flask.typing import ResponseReturnValue from . import v0 @@ -8,5 +9,6 @@ app_for_sphinx = Flask(__name__) app_for_sphinx.register_blueprint(bp, url_prefix="/api/v0") -def errorpage(e): + +def errorpage(e: Exception) -> ResponseReturnValue: return v0.api.handle_error(e) diff --git a/web/api/v0/__init__.py b/web/api/v0/__init__.py index 209aa4174..c2ae05aa6 100644 --- a/web/api/v0/__init__.py +++ b/web/api/v0/__init__.py @@ -1,7 +1,10 @@ -from datetime import timedelta, datetime +import typing as t +from decimal import Decimal +from datetime import timedelta, datetime, date from functools import wraps -from flask import jsonify, request, current_app +from flask import jsonify, request, current_app, Response +from flask.typing import ResponseReturnValue from flask_restful import Api, Resource as FlaskRestfulResource, abort, \ reqparse, inputs from sqlalchemy.exc import IntegrityError @@ -37,7 +40,7 @@ api = Api() -def parse_authorization_header(value): +def parse_authorization_header(value: str | None) -> str | None: if not value: return None @@ -48,9 +51,13 @@ def parse_authorization_header(value): return None -def authenticate(func): - @wraps(func) - def wrapper(*args, **kwargs): +_P = t.ParamSpec("_P") +_TF = t.Callable[_P, ResponseReturnValue] + + +def authenticate(func: _TF) -> _TF: + @t.cast(t.Callable[[_TF], _TF], wraps(func)) + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> ResponseReturnValue: auth = request.headers.get('authorization') api_key = parse_authorization_header(auth) @@ -69,28 +76,28 @@ class Resource(FlaskRestfulResource): method_decorators = [authenticate] -def get_user_or_404(user_id): - user = User.get(user_id) +def get_user_or_404(user_id: int) -> User: + user = session.session.get(User, user_id) if user is None: abort(404, message=f"User {user_id} does not exist") return user -def get_authenticated_user(user_id, password): +def get_authenticated_user(user_id: int, password: str) -> User: user = get_user_or_404(user_id) if user is None or not user.check_password(password): abort(401, message="Authentication failed") return user -def get_interface_or_404(interface_id): - interface = Interface.get(interface_id) +def get_interface_or_404(interface_id: int) -> Interface: + interface = session.session.get(Interface, interface_id) if interface is None: abort(404, message=f"Interface {interface_id} does not exist") return interface -def generate_user_data(user): +def generate_user_data(user: User) -> Response: props = {prop.property_name for prop in user.current_properties} user_status = status(user) @@ -98,15 +105,26 @@ def generate_user_data(user): step = timedelta(days=1) traffic_history = func_traffic_history( user.id, + # TODO what is the emitted sql statement? + # it seems to me that this expression returns `timestamp`, and not `timestamptz` func.current_timestamp() - interval + step, - func.current_timestamp()) + func.current_timestamp(), + ) - finance_history = [{ - 'valid_on': split.transaction.valid_on, - # Invert amount, to display it from the user's point of view - 'amount': -split.amount, - 'description': Message.from_json(split.transaction.description).localize() - } for split in user.account.splits] + class _Entry(t.TypedDict): + valid_on: date + amount: int | Decimal + description: str + + finance_history: list[_Entry] = [ + { + "valid_on": split.transaction.valid_on, + # Invert amount, to display it from the user's point of view + "amount": -split.amount, + "description": Message.from_json(split.transaction.description).localize(), + } + for split in user.account.splits + ] finance_history = sorted(finance_history, key=lambda e: e['valid_on'], reverse=True) @@ -160,7 +178,7 @@ def generate_user_data(user): class UserResource(Resource): - def get(self, user_id): + def get(self, user_id: int) -> Response: user = get_user_or_404(user_id) return generate_user_data(user) @@ -169,7 +187,7 @@ def get(self, user_id): class ChangeEmailResource(Resource): - def post(self, user_id): + def post(self, user_id: int) -> ResponseReturnValue: parser = reqparse.RequestParser() parser.add_argument('password', type=str, required=True) parser.add_argument('new_email', type=str, required=True) @@ -189,7 +207,7 @@ def post(self, user_id): class ChangePasswordResource(Resource): - def post(self, user_id): + def post(self, user_id: int) -> ResponseReturnValue: parser = reqparse.RequestParser() parser.add_argument('password', dest='old_password', required=True) parser.add_argument('new_password', dest='new_password', required=True) @@ -205,7 +223,7 @@ def post(self, user_id): class FinanceHistoryResource(Resource): - def get(self, user_id): + def get(self, user_id: int) -> ResponseReturnValue: user = get_user_or_404(user_id) return jsonify([ {'valid_on': s.transaction.valid_on.isoformat(), 'amount': s.amount} @@ -218,7 +236,7 @@ def get(self, user_id): class AuthenticationResource(Resource): - def post(self): + def post(self) -> ResponseReturnValue: auth_parser = reqparse.RequestParser() auth_parser.add_argument('login', dest='login', required=True) auth_parser.add_argument('password', dest='password', required=True) @@ -235,7 +253,7 @@ def post(self): class UserByIPResource(Resource): - def get(self): + def get(self) -> ResponseReturnValue: ipv4 = request.args.get('ip', IPAddress) user = session.session.scalars( @@ -272,7 +290,7 @@ def get(self): class UserInterfaceResource(Resource): - def post(self, user_id, interface_id): + def post(self, user_id: int, interface_id: int) -> ResponseReturnValue: parser = reqparse.RequestParser() parser.add_argument('password', dest='password', required=True) parser.add_argument('mac', dest='mac', required=True) @@ -306,7 +324,7 @@ def post(self, user_id, interface_id): class ActivateNetworkAccessResource(Resource): - def post(self, user_id): + def post(self, user_id: int) -> ResponseReturnValue: parser = reqparse.RequestParser() parser.add_argument('password', dest='password', required=True) parser.add_argument('birthdate', dest='birthdate', required=True) @@ -356,7 +374,7 @@ def post(self, user_id): class TerminateMembershipResource(Resource): - def get(self, user_id): + def get(self, user_id: int) -> ResponseReturnValue: """ :param user_id: The ID of the user :return: The estimated balance of the given end_date @@ -375,7 +393,7 @@ def get(self, user_id): return jsonify(estimated_balance=estimated_balance) - def post(self, user_id): + def post(self, user_id: int) -> ResponseReturnValue: """ Terminate the membership on the given date @@ -410,7 +428,7 @@ def post(self, user_id): return "Membership termination scheduled." - def delete(self, user_id): + def delete(self, user_id: int) -> ResponseReturnValue: """ Cancel termination of a membership @@ -440,7 +458,7 @@ def delete(self, user_id): class ResetWifiPasswordResource(Resource): - def patch(self, user_id): + def patch(self, user_id: int) -> ResponseReturnValue: """ Reset the wifi password @@ -461,7 +479,7 @@ def patch(self, user_id): class RegistrationResource(Resource): - def get(self): + def get(self) -> ResponseReturnValue: """ Get the newest tenancy for the supplied user data, or an error 404 if not found. @@ -532,7 +550,7 @@ def get(self): 'room': newest_tenancy.room.level_and_number }) - def post(self): + def post(self) -> int: """ Create a member request """ @@ -554,7 +572,7 @@ def post(self): swdd_person_id = None if args.room_id is not None: - room = Room.get(args.room_id) + room = session.session.get(Room, args.room_id) if room is None: abort(404, message="Invalid room", code="invalid_room") @@ -596,8 +614,10 @@ def post(self): abort(400, message="The move-in date is invalid", code="move_in_date_invalid") else: session.session.commit() - return mr.id + raise AssertionError( + "unreachable" + ) # the `abort`s from `flask_restful` don't return `NoReturn` api.add_resource(RegistrationResource, @@ -605,12 +625,12 @@ def post(self): class EmailConfirmResource(Resource): - def get(self): + def get(self) -> ResponseReturnValue: parser = reqparse.RequestParser() parser.add_argument('user_id', required=True, type=int) args = parser.parse_args() - user = User.get(args.user_id) + user = session.session.get(User, args.user_id) if user is None: abort(404, message='User not found') @@ -621,7 +641,7 @@ def get(self): return jsonify({'success': True}) - def post(self): + def post(self) -> ResponseReturnValue: parser = reqparse.RequestParser() parser.add_argument('key', required=True, type=str) args = parser.parse_args() @@ -640,7 +660,7 @@ def post(self): class ResetPasswordResource(Resource): - def post(self): + def post(self) -> ResponseReturnValue: parser = reqparse.RequestParser() parser.add_argument('ident', required=True, type=str) parser.add_argument('email', required=True, type=str) @@ -660,7 +680,7 @@ def post(self): 'success': True } - def patch(self): + def patch(self) -> ResponseReturnValue: parser = reqparse.RequestParser() parser.add_argument('token', required=True, type=str) parser.add_argument('password', required=True, type=str) diff --git a/web/api/v0/helpers.py b/web/api/v0/helpers.py index 6e7f121ab..88b852379 100644 --- a/web/api/v0/helpers.py +++ b/web/api/v0/helpers.py @@ -1,5 +1,5 @@ -from datetime import datetime +from datetime import datetime, date -def parse_iso_date(date_str: str): +def parse_iso_date(date_str: str) -> date: return datetime.strptime(date_str, '%Y-%m-%d').date() diff --git a/web/blueprints/finance/__init__.py b/web/blueprints/finance/__init__.py index c6a52b960..daff02296 100644 --- a/web/blueprints/finance/__init__.py +++ b/web/blueprints/finance/__init__.py @@ -736,7 +736,7 @@ def balance_json(account_id: int) -> ResponseReturnValue: sum_exp: ColumnElement[int] = t.cast( Over[int], - func.sum(Split.amount).over(order_by=Transaction.valid_on), # type: ignore[no-untyped-call] + func.sum(Split.amount).over(order_by=Transaction.valid_on), ) if invert: