From b185b7de0e54402eafe2c8002a2d509383dd8630 Mon Sep 17 00:00:00 2001 From: Samuel Marks <807580+SamuelMarks@users.noreply.github.com> Date: Sun, 17 Sep 2023 23:10:10 -0400 Subject: [PATCH] [cdd/compound/openapi/gen_routes.py] Ensure `Module` body is actually a list ; [cdd/sqlalchemy/parse.py] Correctly unwrap a `Module` when provided --- cdd/compound/openapi/gen_routes.py | 2 +- cdd/sqlalchemy/parse.py | 10 +++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/cdd/compound/openapi/gen_routes.py b/cdd/compound/openapi/gen_routes.py index 8b2b7251..31d170e8 100644 --- a/cdd/compound/openapi/gen_routes.py +++ b/cdd/compound/openapi/gen_routes.py @@ -66,7 +66,7 @@ def gen_routes(app, model_path, model_name, crud, route): None, ) sqlalchemy_ir = cdd.sqlalchemy.parse.sqlalchemy( - Module(body=sqlalchemy_node, stmt=None, type_ignores=[]) + Module(body=[sqlalchemy_node], stmt=None, type_ignores=[]) ) primary_key = next( map( diff --git a/cdd/sqlalchemy/parse.py b/cdd/sqlalchemy/parse.py index 82557195..4d60ca91 100644 --- a/cdd/sqlalchemy/parse.py +++ b/cdd/sqlalchemy/parse.py @@ -10,7 +10,7 @@ """ import ast -from ast import AnnAssign, Assign, Call, ClassDef +from ast import AnnAssign, Assign, Call, ClassDef, Module from collections import OrderedDict from inspect import getsource @@ -19,7 +19,7 @@ from cdd.docstring.parse import docstring from cdd.shared.ast_utils import get_value from cdd.shared.defaults_utils import extract_default -from cdd.shared.pure_utils import assert_equal +from cdd.shared.pure_utils import assert_equal, rpartial from cdd.sqlalchemy.utils.parse_utils import column_call_to_param @@ -125,7 +125,11 @@ def sqlalchemy(class_def, parse_original_whitespace=False): """ if not isinstance(class_def, ClassDef): - class_def = ast.parse(getsource(class_def)).body[0] + class_def = ( + next(filter(rpartial(isinstance, ClassDef), class_def.body)) + if isinstance(class_def, Module) + else ast.parse(getsource(class_def)).body[0] + ) assert isinstance(class_def, ClassDef), "Expected `ClassDef` got `{!r}`".format( type(class_def).__name__ )