Skip to content

Commit

Permalink
Adds more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dtiesling committed Nov 16, 2023
1 parent a414f05 commit aa9b1c3
Show file tree
Hide file tree
Showing 8 changed files with 285 additions and 50 deletions.
2 changes: 2 additions & 0 deletions flask_muck/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .views import MuckApiView
from .callback import MuckCallback
4 changes: 0 additions & 4 deletions flask_muck/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,2 @@
class MuckApiValidationException(Exception):
pass


class MuckImplementationError(Exception):
pass
95 changes: 55 additions & 40 deletions flask_muck/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Optional, Union, Any

from flask import request, Blueprint
from flask.typing import ResponseReturnValue
from flask.views import MethodView
from marshmallow import Schema
from sqlalchemy.exc import IntegrityError
Expand All @@ -19,10 +20,10 @@
)
from webargs import fields
from webargs.flaskparser import parser
from werkzeug.exceptions import MethodNotAllowed, BadRequest, Conflict

from flask_muck.callback import CallbackType
from flask_muck.callback import MuckCallback
from flask_muck.exceptions import MuckApiValidationException
from flask_muck.types import SqlaModelType, JsonDict, ResourceId, SqlaModel
from flask_muck.utils import (
get_url_rule,
Expand Down Expand Up @@ -68,12 +69,18 @@ class MuckApiView(MethodView):
allowed_methods: set[str] = {"GET", "POST", "PUT", "PATCH", "DELETE"}
primary_key_column: str = "id"
primary_key_type: Union[type[int], type[str]] = int
filter_operator_separator: str = "__"
operator_separator: str = "__"

@property
def query(self) -> Query:
return self.session.query(self.Model)

def dispatch_request(self, **kwargs: Any) -> ResponseReturnValue:
"""Overriden to check the list of allowed_methods."""
if request.method.lower() not in [m.lower() for m in self.allowed_methods]:
raise MethodNotAllowed
return super().dispatch_request(**kwargs)

def _execute_callbacks(
self,
resource: SqlaModel,
Expand Down Expand Up @@ -111,7 +118,7 @@ def _get_clean_filter_data(self, filters: Optional[str]) -> JsonDict:
try:
return json.loads(filters)
except JSONDecodeError:
raise MuckApiValidationException(f"Filters [{filters}] is not valid json.")
raise BadRequest(f"Filters [{filters}] is not valid json.")

def _get_kwargs_from_request_payload(self) -> JsonDict:
"""Creates the correct schema based on request method and returns a sanitized dictionary of kwargs from the
Expand Down Expand Up @@ -225,7 +232,7 @@ def post(self) -> tuple[JsonDict, int]:
resource = self._create_resource(kwargs)
except IntegrityError as e:
self.session.rollback()
raise MuckApiValidationException(str(e))
raise Conflict(str(e))
self._execute_callbacks(resource, kwargs, CallbackType.pre)
self.session.commit()
self._execute_callbacks(resource, kwargs, CallbackType.post)
Expand Down Expand Up @@ -275,48 +282,49 @@ def _get_query_filters(
for column_name, value in filters.items():
# Get operator.
operator = None
if self.filter_operator_separator in column_name:
column_name, operator = column_name.split(
self.filter_operator_separator
)
if self.operator_separator in column_name:
column_name, operator = column_name.split(self.operator_separator)

# Handle nested filters.
if "." in column_name:
relationship_name, column_name = column_name.split(".")
field = getattr(self.Model, relationship_name)
field = getattr(self.Model, relationship_name, None)
if not field:
continue
raise BadRequest(
f"{column_name} is not a valid filter field. The relationship does not exist."
)
SqlaModel = field.property.mapper.class_
join_models.add(SqlaModel)
else:
SqlaModel = self.Model

if hasattr(SqlaModel, column_name):
column = getattr(SqlaModel, column_name)
if operator == "gt":
filter = column > value
elif operator == "gte":
filter = column >= value
elif operator == "lt":
filter = column < value
elif operator == "lte":
filter = column <= value
elif operator == "ne":
filter = column != value
elif operator == "in":
filter = column.in_(value)
elif operator == "not_in":
filter = column.not_in(value)
else:
filter = column == value
query_filters.append(filter)
if not (column := getattr(SqlaModel, column_name, None)):
raise BadRequest(f"{column_name} is not a valid filter field.")

if operator == "gt":
filter = column > value
elif operator == "gte":
filter = column >= value
elif operator == "lt":
filter = column < value
elif operator == "lte":
filter = column <= value
elif operator == "ne":
filter = column != value
elif operator == "in":
filter = column.in_(value)
elif operator == "not_in":
filter = column.not_in(value)
else:
filter = column == value
query_filters.append(filter)
return query_filters, join_models

def _get_query_order_by(
self, sort: str
) -> tuple[Optional[UnaryExpression], set[SqlaModelType]]:
if "__" in sort:
column_name, direction = sort.split("__")
if self.operator_separator in sort:
column_name, direction = sort.split(self.operator_separator)
else:
column_name, direction = sort, "asc"

Expand All @@ -340,7 +348,7 @@ def _get_query_order_by(
elif direction == "desc":
order_by = column.desc()
else:
raise MuckApiValidationException(
raise BadRequest(
f"Invalid sort direction: {direction}. Must asc or desc"
)
return order_by, join_models
Expand Down Expand Up @@ -368,24 +376,31 @@ def add_crud_to_blueprint(cls, blueprint: Blueprint) -> None:
"""Adds CRUD endpoints to a blueprint."""
url_rule = get_url_rule(cls, None)
api_view = cls.as_view(f"{cls.api_name}_api")

# In the special case that this API represents a ONE-TO-ONE relationship, use / for all methods.
if cls.one_to_one_api:
blueprint.add_url_rule(
url_rule,
defaults={"resource_id": None},
view_func=api_view,
methods=cls.allowed_methods,
methods={"GET", "PUT", "PATCH", "DELETE"},
)
if "POST" in cls.allowed_methods:

else:
# Create endpoint - POST on /
blueprint.add_url_rule(url_rule, view_func=api_view, methods=["POST"])
if "GET" in cls.allowed_methods:

# List endpoint - GET on /
blueprint.add_url_rule(
url_rule,
defaults={"resource_id": None},
view_func=api_view,
methods=["GET"],
)
blueprint.add_url_rule(
f"{url_rule}/<resource_id>",
view_func=api_view,
methods=cls.allowed_methods.intersection({"GET", "PUT", "PATCH", "DELETE"}),
)

# Detail, Update, Patch, Delete endpoints - GET, PUT, PATCH, DELETE on /<resource_id>
blueprint.add_url_rule(
f"{url_rule}/<resource_id>",
view_func=api_view,
methods={"GET", "PUT", "PATCH", "DELETE"},
)
66 changes: 65 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ sqlalchemy-stubs = "^0.4"
pytest = "^7.4.3"
flask-login = "^0.6.3"
flask-sqlalchemy = "^3.1.1"
coverage = "^7.3.2"

[tool.mypy]
packages = "flask_muck"
Expand Down
6 changes: 5 additions & 1 deletion tests/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
)
from flask_sqlalchemy import SQLAlchemy
from marshmallow import fields as mf
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import DeclarativeBase, Mapped

from flask_muck.views import MuckApiView

Expand All @@ -37,19 +37,23 @@ class GuardianModel(db.Model):
id = db.Column(db.Integer, primary_key=True, autoincrement=True)
name = db.Column(db.String, nullable=False)
age = db.Column(db.Integer, nullable=True)
children: Mapped[list["ChildModel"]] = db.relationship()


class ChildModel(db.Model):
id = db.Column(db.Integer, primary_key=True, autoincrement=True)
name = db.Column(db.String, nullable=False)
age = db.Column(db.Integer, nullable=True)
guardian_id = db.Column(db.Integer, db.ForeignKey(GuardianModel.id))
guardian = db.relationship(GuardianModel, back_populates="children")
toys: Mapped[list["ToyModel"]] = db.relationship()


class ToyModel(db.Model):
id = db.Column(db.Integer, primary_key=True, autoincrement=True)
name = db.Column(db.String, nullable=False)
child_id = db.Column(db.Integer, db.ForeignKey(ChildModel.id))
child = db.relationship(ChildModel, back_populates="toys")


class GuardianSchema(ma.Schema):
Expand Down
6 changes: 5 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,16 @@ def simpson_family(db):
@pytest.fixture
def belcher_family(db):
bob = GuardianModel(name="Bob", age=46)
db.session.add(bob)
db.session.flush()
tina = ChildModel(name="Tina", age=12, guardian_id=bob.id)
louise = ChildModel(name="Louise", age=9, guardian_id=bob.id)
gene = ChildModel(name="Gene", age=11, guardian_id=bob.id)
db.session.add_all([tina, louise, gene])
db.session.flush()
pony = ToyModel(name="Pony", child_id=tina.id)
hat = ToyModel(name="Hat", child_id=louise.id)
keyboard = ToyModel(name="Keyboard", child_id=gene.id)
db.session.add_all([bob, tina, louise, gene, pony, hat, keyboard])
db.session.add_all([pony, hat, keyboard])
db.session.flush()
return bob, tina, louise, gene, pony, hat, keyboard
Loading

0 comments on commit aa9b1c3

Please sign in to comment.