diff --git a/docs/examples/contrib/jwt/using_jwt_auth.py b/docs/examples/contrib/jwt/using_jwt_auth.py index 891f0b7878..b15883cd92 100644 --- a/docs/examples/contrib/jwt/using_jwt_auth.py +++ b/docs/examples/contrib/jwt/using_jwt_auth.py @@ -46,12 +46,7 @@ async def retrieve_user_handler(token: Token, connection: "ASGIConnection[Any, A @post("/login") async def login_handler(data: User) -> Response[User]: MOCK_DB[str(data.id)] = data - response = jwt_auth.login(identifier=str(data.id), response_body=data) - - # you can do whatever you want to update the response instance here - # e.g. response.set_cookie(...) - - return response + return jwt_auth.login(identifier=str(data.id), response_body=data) # We also have some other routes, for example: diff --git a/docs/examples/contrib/jwt/using_jwt_cookie_auth.py b/docs/examples/contrib/jwt/using_jwt_cookie_auth.py index 1e37e81561..41c6e1a594 100644 --- a/docs/examples/contrib/jwt/using_jwt_cookie_auth.py +++ b/docs/examples/contrib/jwt/using_jwt_cookie_auth.py @@ -48,12 +48,7 @@ async def retrieve_user_handler(token: "Token", connection: "ASGIConnection[Any, @post("/login") async def login_handler(data: "User") -> "Response[User]": MOCK_DB[str(data.id)] = data - response = jwt_cookie_auth.login(identifier=str(data.id), response_body=data) - - # you can do whatever you want to update the response instance here - # e.g. response.set_cookie(...) - - return response + return jwt_cookie_auth.login(identifier=str(data.id), response_body=data) # We also have some other routes, for example: diff --git a/docs/examples/contrib/jwt/using_oauth2_password_bearer.py b/docs/examples/contrib/jwt/using_oauth2_password_bearer.py index 7916eb4ca1..d48a7ce999 100644 --- a/docs/examples/contrib/jwt/using_oauth2_password_bearer.py +++ b/docs/examples/contrib/jwt/using_oauth2_password_bearer.py @@ -48,26 +48,14 @@ async def retrieve_user_handler(token: "Token", connection: "ASGIConnection[Any, @post("/login") async def login_handler(request: "Request[Any, Any, Any]", data: "User") -> "Response[OAuth2Login]": MOCK_DB[str(data.id)] = data - # if we do not define a response body, the login process will return a standard OAuth2 login response. Note the `Response[OAuth2Login]` return type. - response = oauth2_auth.login(identifier=str(data.id)) - - # you can do whatever you want to update the response instance here - # e.g. response.set_cookie(...) - - return response + return oauth2_auth.login(identifier=str(data.id)) @post("/login_custom") async def login_custom_response_handler(data: "User") -> "Response[User]": MOCK_DB[str(data.id)] = data - # If you'd like to define a custom response body, use the `response_body` parameter. Note the `Response[User]` return type. - response = oauth2_auth.login(identifier=str(data.id), response_body=data) - - # you can do whatever you want to update the response instance here - # e.g. response.set_cookie(...) - - return response + return oauth2_auth.login(identifier=str(data.id), response_body=data) # We also have some other routes, for example: diff --git a/docs/examples/parameters/header_and_cookie_parameters.py b/docs/examples/parameters/header_and_cookie_parameters.py index c1d1393796..97dc05dfa0 100644 --- a/docs/examples/parameters/header_and_cookie_parameters.py +++ b/docs/examples/parameters/header_and_cookie_parameters.py @@ -26,7 +26,7 @@ async def get_user( token: str = Parameter(header="X-API-KEY"), cookie: str = Parameter(cookie="my-cookie-param"), ) -> User: - if not (token == VALID_TOKEN and cookie == VALID_COOKIE_VALUE): + if token != VALID_TOKEN or cookie != VALID_COOKIE_VALUE: raise NotAuthorizedException return User.parse_obj(USER_DB[user_id]) diff --git a/docs/examples/plugins/sqlalchemy_1_plugin/sqlalchemy_async.py b/docs/examples/plugins/sqlalchemy_1_plugin/sqlalchemy_async.py index 6d1eb54658..2b739fe562 100644 --- a/docs/examples/plugins/sqlalchemy_1_plugin/sqlalchemy_async.py +++ b/docs/examples/plugins/sqlalchemy_1_plugin/sqlalchemy_async.py @@ -55,10 +55,10 @@ async def get_company(company_id: int, async_session: AsyncSession) -> Company: If a company with that ID does not exist, return a 404 response """ result = await async_session.scalars(select(Company).where(Company.id == company_id)) - company: Optional[Company] = result.one_or_none() - if not company: + if company := result.one_or_none(): + return company + else: raise HTTPException(detail=f"Company with ID {company_id} not found", status_code=HTTP_404_NOT_FOUND) - return company app = Starlite( diff --git a/docs/examples/plugins/sqlalchemy_1_plugin/sqlalchemy_relationships.py b/docs/examples/plugins/sqlalchemy_1_plugin/sqlalchemy_relationships.py index 82e37f603e..27b675e329 100644 --- a/docs/examples/plugins/sqlalchemy_1_plugin/sqlalchemy_relationships.py +++ b/docs/examples/plugins/sqlalchemy_1_plugin/sqlalchemy_relationships.py @@ -46,10 +46,12 @@ def get_user(user_id: int, db_session: Session) -> User: If a user with that ID does not exist, return a 404 response """ - user: Optional[User] = db_session.scalars(select(User).where(User.id == user_id)).one_or_none() - if not user: + if user := db_session.scalars( + select(User).where(User.id == user_id) + ).one_or_none(): + return user + else: raise HTTPException(detail=f"User with ID {user} not found", status_code=HTTP_404_NOT_FOUND) - return user app = Starlite( diff --git a/docs/examples/plugins/sqlalchemy_1_plugin/sqlalchemy_relationships_to_many.py b/docs/examples/plugins/sqlalchemy_1_plugin/sqlalchemy_relationships_to_many.py index fea058bfd4..9a2cb81460 100644 --- a/docs/examples/plugins/sqlalchemy_1_plugin/sqlalchemy_relationships_to_many.py +++ b/docs/examples/plugins/sqlalchemy_1_plugin/sqlalchemy_relationships_to_many.py @@ -64,10 +64,12 @@ def get_user(user_id: int, db_session: Session) -> UserModel: If a user with that ID does not exist, return a 404 response """ - user: Optional[User] = db_session.scalars(select(User).where(User.id == user_id)).one_or_none() - if not user: + if user := db_session.scalars( + select(User).where(User.id == user_id) + ).one_or_none(): + return UserModel.from_orm(user) + else: raise HTTPException(detail=f"User with ID {user} not found", status_code=HTTP_404_NOT_FOUND) - return UserModel.from_orm(user) app = Starlite( diff --git a/docs/examples/plugins/sqlalchemy_1_plugin/sqlalchemy_sync.py b/docs/examples/plugins/sqlalchemy_1_plugin/sqlalchemy_sync.py index 87654f2f9a..51cdcd96d9 100644 --- a/docs/examples/plugins/sqlalchemy_1_plugin/sqlalchemy_sync.py +++ b/docs/examples/plugins/sqlalchemy_1_plugin/sqlalchemy_sync.py @@ -50,10 +50,12 @@ def get_company(company_id: str, db_session: Session) -> Company: If a company with that ID does not exist, return a 404 response """ - company: Optional[Company] = db_session.scalars(select(Company).where(Company.id == company_id)).one_or_none() - if not company: + if company := db_session.scalars( + select(Company).where(Company.id == company_id) + ).one_or_none(): + return company + else: raise HTTPException(detail=f"Company with ID {company_id} not found", status_code=HTTP_404_NOT_FOUND) - return company app = Starlite( diff --git a/docs/examples/security/using_session_auth.py b/docs/examples/security/using_session_auth.py index 846a927d0a..0066f4461e 100644 --- a/docs/examples/security/using_session_auth.py +++ b/docs/examples/security/using_session_auth.py @@ -56,13 +56,7 @@ class UserLoginPayload(BaseModel): async def retrieve_user_handler( session: Dict[str, Any], connection: "ASGIConnection[Any, Any, Any, Any]" ) -> Optional[User]: - # we retrieve the user instance based on session data - - user_id = session.get("user_id") - if user_id: - return MOCK_DB.get(user_id) - - return None + return MOCK_DB.get(user_id) if (user_id := session.get("user_id")) else None @post("/login") diff --git a/starlite/_asgi/routing_trie/mapping.py b/starlite/_asgi/routing_trie/mapping.py index c0f75ba128..ef9f9f67af 100644 --- a/starlite/_asgi/routing_trie/mapping.py +++ b/starlite/_asgi/routing_trie/mapping.py @@ -85,10 +85,11 @@ def add_route_to_trie( """ current_node = root_node - is_mount = hasattr(route, "route_handler") and getattr(route.route_handler, "is_mount", False) # pyright: ignore has_path_parameters = bool(route.path_parameters) - if is_mount: # pyright: ignore + if is_mount := hasattr(route, "route_handler") and getattr( + route.route_handler, "is_mount", False + ): current_node = add_mount_route( current_node=current_node, mount_routes=mount_routes, diff --git a/starlite/_asgi/routing_trie/traversal.py b/starlite/_asgi/routing_trie/traversal.py index acad6300c0..ae1e5c234c 100644 --- a/starlite/_asgi/routing_trie/traversal.py +++ b/starlite/_asgi/routing_trie/traversal.py @@ -143,7 +143,9 @@ def parse_path_to_route( remaining_path = path[match.end() :] # since we allow regular handlers under static paths, we must validate that the request does not match # any such handler. - if not mount_node.children or not any(sub_route in path for sub_route in mount_node.children): # type: ignore + if not mount_node.children or all( + sub_route not in path for sub_route in mount_node.children + ): # type: ignore asgi_app, handler = parse_node_handlers(node=mount_node, method=method) remaining_path = remaining_path or "/" if not mount_node.is_static: diff --git a/starlite/_asgi/routing_trie/validate.py b/starlite/_asgi/routing_trie/validate.py index 1d5bd00699..4712633cbc 100644 --- a/starlite/_asgi/routing_trie/validate.py +++ b/starlite/_asgi/routing_trie/validate.py @@ -31,8 +31,7 @@ def validate_node(node: RouteTrieNode) -> None: node.is_mount and node.children and any( - v - for v in chain.from_iterable( + chain.from_iterable( list(child.path_parameters.values()) if isinstance(child.path_parameters, dict) else child.path_parameters diff --git a/starlite/_kwargs/kwargs_model.py b/starlite/_kwargs/kwargs_model.py index 25c1805ee7..a25e9f0360 100644 --- a/starlite/_kwargs/kwargs_model.py +++ b/starlite/_kwargs/kwargs_model.py @@ -453,8 +453,11 @@ def _validate_raw_kwargs( f"Make sure to use distinct keys for your dependencies, path parameters and aliased parameters." ) - used_reserved_kwargs = {*parameter_names, *path_parameters, *dependency_keys}.intersection(RESERVED_KWARGS) - if used_reserved_kwargs: + if used_reserved_kwargs := { + *parameter_names, + *path_parameters, + *dependency_keys, + }.intersection(RESERVED_KWARGS): raise ImproperlyConfiguredException( f"Reserved kwargs ({', '.join(RESERVED_KWARGS)}) cannot be used for dependencies and parameter arguments. " f"The following kwargs have been used: {', '.join(used_reserved_kwargs)}" diff --git a/starlite/_kwargs/parameter_definition.py b/starlite/_kwargs/parameter_definition.py index e6ca919ab6..8427948ad9 100644 --- a/starlite/_kwargs/parameter_definition.py +++ b/starlite/_kwargs/parameter_definition.py @@ -60,7 +60,9 @@ def create_parameter_definition( field_alias=field_alias, default_value=default_value, is_required=signature_field.is_required - and (default_value is None and not (signature_field.is_optional or signature_field.is_any)), + and default_value is None + and not signature_field.is_optional + and not signature_field.is_any, is_sequence=signature_field.is_non_string_sequence, ) diff --git a/starlite/_multipart.py b/starlite/_multipart.py index 55c93c44f2..7414564acc 100644 --- a/starlite/_multipart.py +++ b/starlite/_multipart.py @@ -123,7 +123,7 @@ def parse_multipart_form(body: bytes, boundary: bytes, multipart_form_part_limit line_index = line_end_index + 2 colon_index = form_line.index(":") current_idx = colon_index + 2 - form_header_field = form_line[0:colon_index].lower() + form_header_field = form_line[:colon_index].lower() form_header_value, form_parameters = parse_content_header(form_line[current_idx:]) if form_header_field == "content-disposition": diff --git a/starlite/_openapi/schema_generation/constrained_fields.py b/starlite/_openapi/schema_generation/constrained_fields.py index ddcf9f1d69..07d367d962 100644 --- a/starlite/_openapi/schema_generation/constrained_fields.py +++ b/starlite/_openapi/schema_generation/constrained_fields.py @@ -12,7 +12,6 @@ if TYPE_CHECKING: from starlite.plugins import OpenAPISchemaPluginProtocol -if TYPE_CHECKING: from starlite._signature.models import SignatureField try: @@ -158,10 +157,7 @@ def create_collection_constrained_field_schema( create_schema(field=sub_field, generate_examples=False, plugins=plugins, schemas=schemas) for sub_field in children ] - if len(items) > 1: - schema.items = Schema(one_of=items) - else: - schema.items = items[0] + schema.items = Schema(one_of=items) if len(items) > 1 else items[0] else: from starlite._signature.models import SignatureField diff --git a/starlite/_openapi/typescript_converter/schema_parsing.py b/starlite/_openapi/typescript_converter/schema_parsing.py index 4793c9598d..558822849a 100644 --- a/starlite/_openapi/typescript_converter/schema_parsing.py +++ b/starlite/_openapi/typescript_converter/schema_parsing.py @@ -50,12 +50,14 @@ def normalize_typescript_namespace(value: str, allow_quoted: bool) -> str: Returns: A normalized value """ - if not allow_quoted and not (value[0].isalpha() or value[0] in {"_", "$"}): + if ( + not allow_quoted + and not value[0].isalpha() + and value[0] not in {"_", "$"} + ): raise ValueError(f"invalid typescript namespace {value}") if allow_quoted: - if allowed_key_re.fullmatch(value): - return value - return f'"{value}"' + return value if allowed_key_re.fullmatch(value) else f'"{value}"' return invalid_namespace_re.sub("", value) diff --git a/starlite/_openapi/typescript_converter/types.py b/starlite/_openapi/typescript_converter/types.py index bf565598fc..ff265d4da9 100644 --- a/starlite/_openapi/typescript_converter/types.py +++ b/starlite/_openapi/typescript_converter/types.py @@ -24,15 +24,12 @@ def _as_string(value: Any) -> str: if isinstance(value, str): - return '"' + value + '"' + return f'"{value}"' if isinstance(value, bool): return "true" if value else "false" - if value is None: - return "null" - - return str(value) + return "null" if value is None else str(value) class TypeScriptElement(ABC): diff --git a/starlite/_parsers.py b/starlite/_parsers.py index a90bd83c05..4f34acc5d0 100644 --- a/starlite/_parsers.py +++ b/starlite/_parsers.py @@ -48,10 +48,14 @@ def parse_cookie_string(cookie_string: str) -> dict[str, str]: Returns: A string keyed dictionary of values """ - output: dict[str, str] = {} cookies = [cookie.split("=", 1) if "=" in cookie else ("", cookie) for cookie in cookie_string.split(";")] - for k, v in filter(lambda x: x[0] or x[1], ((k.strip(), v.strip()) for k, v in cookies)): - output[k] = unquote(unquote_cookie(v)) + output: dict[str, str] = { + k: unquote(unquote_cookie(v)) + for k, v in filter( + lambda x: x[0] or x[1], + ((k.strip(), v.strip()) for k, v in cookies), + ) + } return output diff --git a/starlite/_signature/models.py b/starlite/_signature/models.py index ca3964d009..80417219d6 100644 --- a/starlite/_signature/models.py +++ b/starlite/_signature/models.py @@ -139,7 +139,11 @@ def is_required(self) -> bool: if isinstance(self.kwarg_model, ParameterKwarg) and self.kwarg_model.required is not None: return self.kwarg_model.required - return not (self.is_optional or self.is_any) and (self.is_empty or self.default_value is None) + return ( + not self.is_optional + and not self.is_any + and (self.is_empty or self.default_value is None) + ) @property def is_literal(self) -> bool: diff --git a/starlite/_signature/parsing.py b/starlite/_signature/parsing.py index b3c91dcf0a..cf428ba102 100644 --- a/starlite/_signature/parsing.py +++ b/starlite/_signature/parsing.py @@ -166,9 +166,7 @@ def parse_fn_signature( ) if isinstance(parameter.default, DependencyKwarg) and parameter.name not in dependency_name_set: - if not parameter.optional and ( - isinstance(parameter.default, DependencyKwarg) and parameter.default.default is Empty - ): + if not parameter.optional and parameter.default.default is Empty: raise ImproperlyConfiguredException( f"Explicit dependency '{parameter.name}' for '{fn_name}' has no default value, " f"or provided dependency." diff --git a/starlite/cli/_utils.py b/starlite/cli/_utils.py index 1c745f9168..cd27d2f7a5 100644 --- a/starlite/cli/_utils.py +++ b/starlite/cli/_utils.py @@ -291,12 +291,10 @@ def _autodiscover_app(cwd: Path) -> LoadedApp: def _format_is_enabled(value: Any) -> str: """Return a coloured string `"Enabled" if ``value`` is truthy, else "Disabled".""" - if value: - return "[green]Enabled[/]" - return "[red]Disabled[/]" + return "[green]Enabled[/]" if value else "[red]Disabled[/]" -def show_app_info(app: Starlite) -> None: # pragma: no cover +def show_app_info(app: Starlite) -> None: # pragma: no cover """Display basic information about the application and its configuration.""" table = Table(show_header=False) @@ -324,12 +322,10 @@ def show_app_info(app: Starlite) -> None: # pragma: no cover if app.static_files_config: static_files_configs = app.static_files_config - static_files_info = [] - for static_files in static_files_configs: - static_files_info.append( - f"path=[yellow]{static_files.path}[/] dirs=[yellow]{', '.join(map(str, static_files.directories))}[/] " - f"html_mode={_format_is_enabled(static_files.html_mode)}", - ) + static_files_info = [ + f"path=[yellow]{static_files.path}[/] dirs=[yellow]{', '.join(map(str, static_files.directories))}[/] html_mode={_format_is_enabled(static_files.html_mode)}" + for static_files in static_files_configs + ] table.add_row("Static files", "\n".join(static_files_info)) if app.serialization_plugins: diff --git a/starlite/connection/base.py b/starlite/connection/base.py index 062bf2fa32..3bee6a7e4a 100644 --- a/starlite/connection/base.py +++ b/starlite/connection/base.py @@ -184,9 +184,7 @@ def cookies(self) -> dict[str, str]: """ if self._cookies is Empty: cookies: dict[str, str] = {} - cookie_header = self.headers.get("cookie") - - if cookie_header: + if cookie_header := self.headers.get("cookie"): cookies = parse_cookie_string(cookie_header) self._cookies = self.scope["_cookies"] = cookies # type: ignore[typeddict-unknown-key] diff --git a/starlite/connection/request.py b/starlite/connection/request.py index 4eafe7054a..0cf89a01d6 100644 --- a/starlite/connection/request.py +++ b/starlite/connection/request.py @@ -115,23 +115,22 @@ async def stream(self) -> AsyncGenerator[bytes, None]: RuntimeError: if the stream is already consumed """ if self._body is Empty: - if self.is_connected: - while event := await self.receive(): - if event["type"] == "http.request": - if event["body"]: - yield event["body"] + if not self.is_connected: + raise InternalServerException("stream consumed") + while event := await self.receive(): + if event["type"] == "http.request": + if event["body"]: + yield event["body"] - if not event.get("more_body", False): - break + if not event.get("more_body", False): + break - if event["type"] == "http.disconnect": - raise InternalServerException("client disconnected prematurely") + if event["type"] == "http.disconnect": + raise InternalServerException("client disconnected prematurely") - self.is_connected = False - yield b"" + self.is_connected = False + yield b"" - else: - raise InternalServerException("stream consumed") else: yield self._body yield b"" @@ -155,22 +154,22 @@ async def form(self) -> FormMultiDict: Returns: A FormMultiDict instance """ - if self._form is Empty: - content_type, options = self.content_type - if content_type == RequestEncodingType.MULTI_PART: - self._form = self.scope["_form"] = form_values = parse_multipart_form( # type: ignore[typeddict-unknown-key] - body=await self.body(), - boundary=options.get("boundary", "").encode(), - multipart_form_part_limit=self.app.multipart_form_part_limit, - ) - return FormMultiDict(form_values) - if content_type == RequestEncodingType.URL_ENCODED: - self._form = self.scope["_form"] = form_values = parse_url_encoded_form_data( # type: ignore[typeddict-unknown-key] - await self.body(), - ) - return FormMultiDict(form_values) - return FormMultiDict() - return FormMultiDict(self._form) + if self._form is not Empty: + return FormMultiDict(self._form) + content_type, options = self.content_type + if content_type == RequestEncodingType.MULTI_PART: + self._form = self.scope["_form"] = form_values = parse_multipart_form( # type: ignore[typeddict-unknown-key] + body=await self.body(), + boundary=options.get("boundary", "").encode(), + multipart_form_part_limit=self.app.multipart_form_part_limit, + ) + return FormMultiDict(form_values) + if content_type == RequestEncodingType.URL_ENCODED: + self._form = self.scope["_form"] = form_values = parse_url_encoded_form_data( # type: ignore[typeddict-unknown-key] + await self.body(), + ) + return FormMultiDict(form_values) + return FormMultiDict() async def send_push_promise(self, path: str) -> None: """Send a push promise. @@ -187,6 +186,8 @@ async def send_push_promise(self, path: str) -> None: if "http.response.push" in extensions: raw_headers = [] for name in SERVER_PUSH_HEADERS: - for value in self.headers.getall(name, []): - raw_headers.append((name.encode("latin-1"), value.encode("latin-1"))) + raw_headers.extend( + (name.encode("latin-1"), value.encode("latin-1")) + for value in self.headers.getall(name, []) + ) await self.send({"type": "http.response.push", "path": path, "headers": raw_headers}) diff --git a/starlite/contrib/htmx/_utils.py b/starlite/contrib/htmx/_utils.py index 5203ee0d3c..42663ac389 100644 --- a/starlite/contrib/htmx/_utils.py +++ b/starlite/contrib/htmx/_utils.py @@ -62,7 +62,6 @@ class HTMXHeaders(str, Enum): def get_trigger_event_headers(trigger_event: TriggerEventType) -> dict[str, Any]: """Return headers for trigger event response.""" - params = trigger_event["params"] or {} after_params: dict[EventAfterType, str] = { "receive": HTMXHeaders.TRIGGER_EVENT.value, "settle": HTMXHeaders.TRIGGER_AFTER_SETTLE.value, @@ -70,6 +69,7 @@ def get_trigger_event_headers(trigger_event: TriggerEventType) -> dict[str, Any] } if trigger_header := after_params.get(trigger_event["after"]): + params = trigger_event["params"] or {} return {trigger_header: encode_json({trigger_event["name"]: params}).decode()} raise ImproperlyConfiguredException( @@ -100,9 +100,7 @@ def get_replace_url_header(url: PushUrlType) -> dict[str, Any]: def get_refresh_header(refresh: bool) -> dict[str, Any]: """Return headers for client refresh response.""" - value = "" - if refresh: - value = "true" + value = "true" if refresh else "" return {HTMXHeaders.REFRESH.value: value} @@ -118,13 +116,10 @@ def get_retarget_header(target: str) -> dict[str, Any]: def get_location_headers(location: LocationType) -> dict[str, Any]: """Return headers for redirect without page-reload response.""" - spec: dict[str, Any] = {} - for key, value in location.items(): - if value: - spec[key] = value - if not spec: + if spec := {key: value for key, value in location.items() if value}: + return {HTMXHeaders.LOCATION.value: encode_json(spec).decode()} + else: raise ValueError("redirect_to is required parameter.") - return {HTMXHeaders.LOCATION.value: encode_json(spec).decode()} def get_headers(hx_headers: HtmxHeaderType) -> dict[str, Any]: @@ -148,9 +143,8 @@ def get_headers(hx_headers: HtmxHeaderType) -> dict[str, Any]: value: Any for key, value in hx_headers.items(): if key in ["redirect", "refresh", "location", "replace_url"]: - response = htmx_headers_dict[key](value) - return response + return htmx_headers_dict[key](value) if value is not None: response = htmx_headers_dict[key](value) - header.update(response) + header |= response return header diff --git a/starlite/contrib/htmx/request.py b/starlite/contrib/htmx/request.py index e5b67c808a..e77f17dd48 100644 --- a/starlite/contrib/htmx/request.py +++ b/starlite/contrib/htmx/request.py @@ -29,7 +29,7 @@ def _get_header_value(self, name: str) -> str | None: Check for uri encoded header and unquotes it in readable format. """ value = self.request.headers.get(name) or None - if value and self.request.headers.get(name + "-URI-AutoEncoded") == "true": + if value and self.request.headers.get(f"{name}-URI-AutoEncoded") == "true": return unquote(value) return value diff --git a/starlite/contrib/repository/testing/generic_mock_repository.py b/starlite/contrib/repository/testing/generic_mock_repository.py index b62665aa58..1b49388d16 100644 --- a/starlite/contrib/repository/testing/generic_mock_repository.py +++ b/starlite/contrib/repository/testing/generic_mock_repository.py @@ -73,8 +73,8 @@ def _update_audit_attributes(self, data: ModelT, now: datetime | None = None, do now = now or self._now() if self._model_has_updated: data.updated = now # type:ignore[attr-defined] - if self._model_has_updated and do_created: - data.created = now # type:ignore[attr-defined] + if do_created: + data.created = now # type:ignore[attr-defined] return data async def add(self, data: ModelT) -> ModelT: diff --git a/starlite/contrib/sqlalchemy/repository.py b/starlite/contrib/sqlalchemy/repository.py index f28901ae01..14868406ed 100644 --- a/starlite/contrib/sqlalchemy/repository.py +++ b/starlite/contrib/sqlalchemy/repository.py @@ -177,7 +177,7 @@ async def exists(self, **kwargs: Any) -> bool: """ existing = await self.count(**kwargs) - return bool(existing > 0) + return existing > 0 async def get(self, item_id: Any, **kwargs: Any) -> ModelT: """Get instance identified by `item_id`. diff --git a/starlite/contrib/sqlalchemy_1/plugin.py b/starlite/contrib/sqlalchemy_1/plugin.py index fa1a8e6d86..bf980108fc 100644 --- a/starlite/contrib/sqlalchemy_1/plugin.py +++ b/starlite/contrib/sqlalchemy_1/plugin.py @@ -129,9 +129,7 @@ def handle_numeric_type(column_type: sqlalchemy_type.Numeric) -> "Type": Returns: An appropriate numerical type """ - if column_type.asdecimal: - return Decimal - return float + return Decimal if column_type.asdecimal else float def handle_list_type(self, column_type: sqlalchemy_type.ARRAY) -> Any: """Handle the SQLAlchemy Array type. diff --git a/starlite/data_extractors.py b/starlite/data_extractors.py index 7d47664987..88638d3da3 100644 --- a/starlite/data_extractors.py +++ b/starlite/data_extractors.py @@ -263,19 +263,19 @@ async def extract_body(self, request: "Request[Any, Any, Any]") -> Any: Returns: Either the parsed request body or the raw byte-string. """ - if request.method != HttpMethod.GET: - if not self.parse_body: - return await request.body() - request_encoding_type = request.content_type[0] - if request_encoding_type == RequestEncodingType.JSON: - return await request.json() - form_data = await request.form() - if request_encoding_type == RequestEncodingType.URL_ENCODED: - return dict(form_data) - return { - key: repr(value) if isinstance(value, UploadFile) else value for key, value in form_data.multi_items() - } - return None + if request.method == HttpMethod.GET: + return None + if not self.parse_body: + return await request.body() + request_encoding_type = request.content_type[0] + if request_encoding_type == RequestEncodingType.JSON: + return await request.json() + form_data = await request.form() + if request_encoding_type == RequestEncodingType.URL_ENCODED: + return dict(form_data) + return { + key: repr(value) if isinstance(value, UploadFile) else value for key, value in form_data.multi_items() + } class ExtractedResponseData(TypedDict, total=False): @@ -401,10 +401,14 @@ def extract_cookies(self, messages: tuple[HTTPResponseStartEvent, HTTPResponseBo Returns: The Response's cookies dict. """ - cookie_string = ";".join( - [x[1].decode("latin-1") for x in filter(lambda x: x[0].lower() == b"set-cookie", messages[0]["headers"])] - ) - if cookie_string: + if cookie_string := ";".join( + [ + x[1].decode("latin-1") + for x in filter( + lambda x: x[0].lower() == b"set-cookie", messages[0]["headers"] + ) + ] + ): parsed_cookies = parse_cookie_string(cookie_string) return _obfuscate(parsed_cookies, self.obfuscate_cookies) if self.obfuscate_cookies else parsed_cookies return {} diff --git a/starlite/datastructures/cookie.py b/starlite/datastructures/cookie.py index f17b011f55..da811a045c 100644 --- a/starlite/datastructures/cookie.py +++ b/starlite/datastructures/cookie.py @@ -59,7 +59,7 @@ def simple_cookie(self) -> SimpleCookie: continue if value is not None: updated_key = key - if key == "max_age": + if updated_key == "max_age": updated_key = "max-age" namespace[updated_key] = value diff --git a/starlite/datastructures/headers.py b/starlite/datastructures/headers.py index f6ed037d4d..3d05846fbb 100644 --- a/starlite/datastructures/headers.py +++ b/starlite/datastructures/headers.py @@ -209,13 +209,12 @@ def __setitem__(self, key: str, value: str) -> None: """Set a header in the scope, overwriting duplicates.""" name_encoded = key.lower().encode("latin-1") value_encoded = value.encode("latin-1") - indices = self._find_indices(key) - if not indices: - self.headers.append((name_encoded, value_encoded)) - else: + if indices := self._find_indices(key): for i in indices[1:]: del self.headers[i] self.headers[indices[0]] = (name_encoded, value_encoded) + else: + self.headers.append((name_encoded, value_encoded)) def __delitem__(self, key: str) -> None: """Delete all headers matching ``name``""" @@ -307,12 +306,15 @@ class CacheControlHeader(Header): def _get_header_value(self) -> str: """Get the header value as string.""" - cc_items = [] - for key, value in self.dict( - exclude_unset=True, exclude_none=True, by_alias=True, exclude={"documentation_only"} - ).items(): - cc_items.append(key if isinstance(value, bool) else f"{key}={value}") - + cc_items = [ + key if isinstance(value, bool) else f"{key}={value}" + for key, value in self.dict( + exclude_unset=True, + exclude_none=True, + by_alias=True, + exclude={"documentation_only"}, + ).items() + ] return ", ".join(cc_items) @classmethod @@ -361,9 +363,7 @@ class ETag(Header): def _get_header_value(self) -> str: value = f'"{self.value}"' - if self.weak: - return f"W/{value}" - return value + return f"W/{value}" if self.weak else value @classmethod def from_header(cls, header_value: str) -> "ETag": diff --git a/starlite/datastructures/upload_file.py b/starlite/datastructures/upload_file.py index 96f368a48e..87b267b345 100644 --- a/starlite/datastructures/upload_file.py +++ b/starlite/datastructures/upload_file.py @@ -119,6 +119,8 @@ def __modify_schema__(cls, field_schema: dict[str, Any], field: ModelField | Non None """ if field: - field_schema.update( - {"type": OpenAPIType.STRING.value, "contentMediaType": "application/octet-stream", "format": "binary"} - ) + field_schema |= { + "type": OpenAPIType.STRING.value, + "contentMediaType": "application/octet-stream", + "format": "binary", + } diff --git a/starlite/datastructures/url.py b/starlite/datastructures/url.py index 10f397ddc3..936860082c 100644 --- a/starlite/datastructures/url.py +++ b/starlite/datastructures/url.py @@ -173,13 +173,14 @@ def from_scope(cls, scope: Scope) -> URL: path = scope.get("root_path", "") + scope["path"] query_string = scope.get("query_string", b"") - # # we use iteration here because it's faster, and headers might not yet be cached - # # in the scope - host = "" - for header_name, header_value in scope.get("headers", []): - if header_name == b"host": - host = header_value.decode("latin-1") - break + host = next( + ( + header_value.decode("latin-1") + for header_name, header_value in scope.get("headers", []) + if header_name == b"host" + ), + "", + ) if server and not host: host, port = server default_port = _DEFAULT_SCHEME_PORTS[scheme] diff --git a/starlite/dto.py b/starlite/dto.py index be6e97cc16..cb2e49d662 100644 --- a/starlite/dto.py +++ b/starlite/dto.py @@ -47,13 +47,10 @@ def get_field_type(model_field: ModelField) -> Any: Type of field. """ outer_type = model_field.outer_type_ - inner_type = model_field.type_ if "ForwardRef" not in repr(outer_type): return outer_type - if model_field.shape == SHAPE_SINGLETON: - return inner_type - # This might be too simplistic - return List[inner_type] # type: ignore + inner_type = model_field.type_ + return inner_type if model_field.shape == SHAPE_SINGLETON else List[inner_type] T = TypeVar("T") diff --git a/starlite/handlers/base.py b/starlite/handlers/base.py index 8a480f13cd..12557c572f 100644 --- a/starlite/handlers/base.py +++ b/starlite/handlers/base.py @@ -121,10 +121,10 @@ def handler_name(self) -> str: Returns: Name of the handler function """ - fn = getattr(self, "fn", None) - if not fn: + if fn := getattr(self, "fn", None): + return get_name(unwrap_partial(self.fn.value)) + else: raise ImproperlyConfiguredException("cannot access handler name before setting the handler function") - return get_name(unwrap_partial(self.fn.value)) @property def dependency_name_set(self) -> set[str]: @@ -160,7 +160,7 @@ def resolve_type_encoders(self) -> TypeEncodersMap: for layer in self.ownership_layers: if type_encoders := getattr(layer, "type_encoders", None): - self._resolved_type_encoders.update(type_encoders) + self._resolved_type_encoders |= type_encoders return cast("TypeEncodersMap", self._resolved_type_encoders) def resolve_layered_parameters(self) -> dict[str, SignatureField]: @@ -169,7 +169,7 @@ def resolve_layered_parameters(self) -> dict[str, SignatureField]: parameter_kwargs: dict[str, ParameterKwarg] = {} for layer in self.ownership_layers: - parameter_kwargs.update(getattr(layer, "parameters", {}) or {}) + parameter_kwargs |= (getattr(layer, "parameters", {}) or {}) self._resolved_layered_parameters = { key: SignatureField.create( @@ -224,7 +224,7 @@ def resolve_exception_handlers(self) -> ExceptionHandlersMap: """ resolved_exception_handlers: dict[int | type[Exception], ExceptionHandler] = {} for layer in self.ownership_layers: - resolved_exception_handlers.update(layer.exception_handlers or {}) + resolved_exception_handlers |= (layer.exception_handlers or {}) return resolved_exception_handlers def resolve_opts(self) -> None: @@ -236,7 +236,7 @@ def resolve_opts(self) -> None: opt: dict[str, Any] = {} for layer in self.ownership_layers: - opt.update(layer.opt or {}) + opt |= (layer.opt or {}) self.opt = opt @@ -249,7 +249,7 @@ def resolve_signature_namespace(self) -> dict[str, Any]: if self._resolved_layered_parameters is Empty: ns: dict[str, Any] = {} for layer in self.ownership_layers: - ns.update(layer.signature_namespace) + ns |= layer.signature_namespace self._resolved_signature_namespace = ns return cast("dict[str, Any]", self._resolved_signature_namespace) diff --git a/starlite/handlers/http_handlers/_utils.py b/starlite/handlers/http_handlers/_utils.py index 89665646e6..c60abb0ad9 100644 --- a/starlite/handlers/http_handlers/_utils.py +++ b/starlite/handlers/http_handlers/_utils.py @@ -93,10 +93,7 @@ async def create_response(data: Any) -> "ASGIApp": ) response.raw_headers = raw_headers - if after_request: - return await after_request(response) # type: ignore - - return response + return await after_request(response) if after_request else response async def handler(data: Any, plugins: list["SerializationPluginProtocol"], **kwargs: Any) -> "ASGIApp": if isawaitable(data): @@ -111,7 +108,11 @@ async def handler(data: Any, plugins: list["SerializationPluginProtocol"], **kwa dto_type(**datum) if isinstance(datum, dict) else dto_type.from_model_instance(datum) for datum in data ] - elif plugins and not (is_dto_annotation or is_dto_iterable_annotation): + elif ( + plugins + and not is_dto_annotation + and not is_dto_iterable_annotation + ): data = await normalize_response_data(data=data, plugins=plugins) return await create_response(data=data) diff --git a/starlite/handlers/http_handlers/base.py b/starlite/handlers/http_handlers/base.py index d91bbbf813..c2a2a33d32 100644 --- a/starlite/handlers/http_handlers/base.py +++ b/starlite/handlers/http_handlers/base.py @@ -294,10 +294,14 @@ def resolve_response_class(self) -> type[Response]: Returns: The default :class:`Response <.response.Response>` class for the route handler. """ - for layer in list(reversed(self.ownership_layers)): - if layer.response_class is not None: - return layer.response_class - return Response + return next( + ( + layer.response_class + for layer in list(reversed(self.ownership_layers)) + if layer.response_class is not None + ), + Response, + ) def resolve_response_headers(self) -> frozenset[ResponseHeader]: """Return all header parameters in the scope of the handler function. @@ -312,14 +316,14 @@ def resolve_response_headers(self) -> frozenset[ResponseHeader]: if isinstance(layer_response_headers, Mapping): # this can't happen unless you manually set response_headers on an instance, which would result in a # type-checking error on everything but the controller. We cover this case nevertheless - resolved_response_headers.update( - {name: ResponseHeader(name=name, value=value) for name, value in layer_response_headers.items()} - ) + resolved_response_headers |= { + name: ResponseHeader(name=name, value=value) + for name, value in layer_response_headers.items() + } else: resolved_response_headers.update({h.name: h for h in layer_response_headers}) for extra_header in ("cache_control", "etag"): - header_model: Header | None = getattr(layer, extra_header, None) - if header_model: + if header_model := getattr(layer, extra_header, None): resolved_response_headers[header_model.HEADER_NAME] = ResponseHeader( name=header_model.HEADER_NAME, value=header_model.to_header(), diff --git a/starlite/middleware/_utils.py b/starlite/middleware/_utils.py index 2782cbc6b2..3562ad3fd0 100644 --- a/starlite/middleware/_utils.py +++ b/starlite/middleware/_utils.py @@ -57,8 +57,11 @@ def should_bypass_middleware( if exclude_opt_key and scope["route_handler"].opt.get(exclude_opt_key): return True - if exclude_path_pattern and exclude_path_pattern.findall( - scope["path"] if not getattr(scope.get("route_handler", {}), "is_mount", False) else scope["raw_path"].decode() - ): - return True - return False + return bool( + exclude_path_pattern + and exclude_path_pattern.findall( + scope["path"] + if not getattr(scope.get("route_handler", {}), "is_mount", False) + else scope["raw_path"].decode() + ) + ) diff --git a/starlite/middleware/allowed_hosts.py b/starlite/middleware/allowed_hosts.py index fb6676571f..4ccfe66400 100644 --- a/starlite/middleware/allowed_hosts.py +++ b/starlite/middleware/allowed_hosts.py @@ -42,10 +42,11 @@ def __init__(self, app: ASGIApp, config: AllowedHostsConfig): self.allowed_hosts_regex = re.compile("|".join(sorted(allowed_hosts))) # pyright: ignore if config.www_redirect: - redirect_domains: set[str] = { - host.replace("www.", "") for host in config.allowed_hosts if host.startswith("www.") - } - if redirect_domains: + if redirect_domains := { + host.replace("www.", "") + for host in config.allowed_hosts + if host.startswith("www.") + }: self.redirect_domains = re.compile("|".join(sorted(redirect_domains))) # pyright: ignore async def __call__(self, scope: "Scope", receive: "Receive", send: "Send") -> None: @@ -64,16 +65,16 @@ async def __call__(self, scope: "Scope", receive: "Receive", send: "Send") -> No return headers = MutableScopeHeaders(scope=scope) - host = headers.get("host", headers.get("x-forwarded-host", "")).split(":")[0] - - if host: + if host := headers.get("host", headers.get("x-forwarded-host", "")).split( + ":" + )[0]: if self.allowed_hosts_regex.fullmatch(host): await self.app(scope, receive, send) return if self.redirect_domains is not None and self.redirect_domains.fullmatch(host): url = URL.from_scope(scope) - redirect_url = url.with_replacements(netloc="www." + url.netloc) + redirect_url = url.with_replacements(netloc=f"www.{url.netloc}") await RedirectResponse(url=str(redirect_url))(scope, receive, send) return diff --git a/starlite/middleware/cors.py b/starlite/middleware/cors.py index 97b761e74e..b786e039ec 100644 --- a/starlite/middleware/cors.py +++ b/starlite/middleware/cors.py @@ -41,12 +41,10 @@ async def __call__(self, scope: "Scope", receive: "Receive", send: "Send") -> No None """ headers = Headers.from_scope(scope=scope) - origin = headers.get("origin") - - if not origin: - await self.app(scope, receive, send) - else: + if origin := headers.get("origin"): await self.app(scope, receive, self.send_wrapper(send=send, origin=origin, has_cookie="cookie" in headers)) + else: + await self.app(scope, receive, send) def send_wrapper(self, send: Send, origin: str, has_cookie: bool) -> Send: """Wrap ``send`` to ensure that state is not disconnected. diff --git a/starlite/middleware/csrf.py b/starlite/middleware/csrf.py index 1bd4619967..2bf7dc65b7 100644 --- a/starlite/middleware/csrf.py +++ b/starlite/middleware/csrf.py @@ -172,10 +172,7 @@ def _decode_csrf_token(self, token: str) -> str | None: token_secret = token[:CSRF_SECRET_LENGTH] existing_hash = token[CSRF_SECRET_LENGTH:] expected_hash = generate_csrf_hash(token=token_secret, secret=self.config.secret) - if not compare_digest(existing_hash, expected_hash): - return None - - return token_secret + return token_secret if compare_digest(existing_hash, expected_hash) else None def _csrf_tokens_match(self, request_csrf_token: str | None, cookie_csrf_token: str | None) -> bool: """Take the CSRF tokens from the request and the cookie and verify both are valid and identical.""" diff --git a/starlite/middleware/exceptions/_debug_response.py b/starlite/middleware/exceptions/_debug_response.py index 4cc98cf022..e927c79680 100644 --- a/starlite/middleware/exceptions/_debug_response.py +++ b/starlite/middleware/exceptions/_debug_response.py @@ -95,10 +95,10 @@ def create_frame_html(frame: FrameInfo, collapsed: bool) -> str: """ frame_tpl = (tpl_dir / "frame.html").read_text() - code_lines: list[str] = [] - for idx, line in enumerate(frame.code_context or []): - code_lines.append(create_line_html(line, frame.lineno, frame.index or 0, idx)) - + code_lines: list[str] = [ + create_line_html(line, frame.lineno, frame.index or 0, idx) + for idx, line in enumerate(frame.code_context or []) + ] data = { "file": escape(frame.filename), "line": frame.lineno, @@ -120,10 +120,10 @@ def create_exception_html(exc: BaseException, line_limit: int) -> str: A string containing HTML representation of the execution frames related to the exception. """ frames = getinnerframes(exc.__traceback__, line_limit) if exc.__traceback__ else [] - result = [] - for idx, frame in enumerate(reversed(frames)): - result.append(create_frame_html(frame=frame, collapsed=idx > 0)) - + result = [ + create_frame_html(frame=frame, collapsed=idx > 0) + for idx, frame in enumerate(reversed(frames)) + ] return "".join(result) diff --git a/starlite/middleware/exceptions/middleware.py b/starlite/middleware/exceptions/middleware.py index 0c1afea2e4..f49bc49195 100644 --- a/starlite/middleware/exceptions/middleware.py +++ b/starlite/middleware/exceptions/middleware.py @@ -57,12 +57,17 @@ def get_exception_handler(exception_handlers: ExceptionHandlersMap, exc: Excepti status_code: int | None = getattr(exc, "status_code", None) if status_code and (exception_handler := exception_handlers.get(status_code)): return exception_handler - for cls in getmro(type(exc)): - if cls in exception_handlers: - return exception_handlers[cast("Type[Exception]", cls)] - if not hasattr(exc, "status_code") and HTTP_500_INTERNAL_SERVER_ERROR in exception_handlers: - return exception_handlers[HTTP_500_INTERNAL_SERVER_ERROR] - return None + return next( + ( + exception_handlers[cast("Type[Exception]", cls)] + for cls in getmro(type(exc)) + if cls in exception_handlers + ), + exception_handlers[HTTP_500_INTERNAL_SERVER_ERROR] + if not hasattr(exc, "status_code") + and HTTP_500_INTERNAL_SERVER_ERROR in exception_handlers + else None, + ) @dataclass diff --git a/starlite/middleware/logging.py b/starlite/middleware/logging.py index d5fbf48d50..24b656a3fe 100644 --- a/starlite/middleware/logging.py +++ b/starlite/middleware/logging.py @@ -161,9 +161,7 @@ def log_message(self, values: dict[str, Any]) -> None: def _serialize_value(self, serializer: Serializer | None, value: Any) -> Any: if not self.is_struct_logger and isinstance(value, (dict, list, tuple, set)): value = encode_json(value, serializer) - if isinstance(value, bytes): - return value.decode("utf-8") - return value + return value.decode("utf-8") if isinstance(value, bytes) else value async def extract_request_data(self, request: "Request") -> dict[str, Any]: """Create a dictionary of values for the message. diff --git a/starlite/middleware/session/client_side.py b/starlite/middleware/session/client_side.py index b755cdc6a8..426b2883a9 100644 --- a/starlite/middleware/session/client_side.py +++ b/starlite/middleware/session/client_side.py @@ -188,8 +188,7 @@ async def load_from_connection(self, connection: "ASGIConnection") -> dict[str, Returns: The session data """ - cookie_keys = self.get_cookie_keys(connection) - if cookie_keys: + if cookie_keys := self.get_cookie_keys(connection): data = [connection.cookies[key].encode("utf-8") for key in cookie_keys] # If these exceptions occur, the session must remain empty so do nothing. with contextlib.suppress(InvalidTag, binascii.Error): diff --git a/starlite/middleware/session/server_side.py b/starlite/middleware/session/server_side.py index 711fd23bb5..6db4b13ba4 100644 --- a/starlite/middleware/session/server_side.py +++ b/starlite/middleware/session/server_side.py @@ -139,8 +139,7 @@ async def load_from_connection(self, connection: "ASGIConnection") -> dict[str, Returns: The current session data """ - session_id = connection.cookies.get(self.config.key) - if session_id: + if session_id := connection.cookies.get(self.config.key): store = self.config.get_store_from_app(connection.scope["app"]) data = await self.get(session_id, store=store) if data is not None: diff --git a/starlite/openapi/controller.py b/starlite/openapi/controller.py index 04f4e1b68c..5d8001e28f 100644 --- a/starlite/openapi/controller.py +++ b/starlite/openapi/controller.py @@ -125,10 +125,7 @@ def should_serve_endpoint(self, request: "Request") -> bool: if request_path == root_path and config.root_schema_site in config.enabled_endpoints: return True - if request_path & config.enabled_endpoints: - return True - - return False + return bool(request_path & config.enabled_endpoints) @property def favicon(self) -> str: diff --git a/starlite/openapi/spec/base.py b/starlite/openapi/spec/base.py index 62a534a304..a25a9710b8 100644 --- a/starlite/openapi/spec/base.py +++ b/starlite/openapi/spec/base.py @@ -15,9 +15,7 @@ def _normalize_key(key: str) -> str: if "_" in key: components = key.split("_") return components[0] + "".join(component.title() for component in components[1:]) - if key == "ref": - return "$ref" - return key + return "$ref" if key == "ref" else key def _normalize_value(value: Any) -> Any: @@ -29,9 +27,7 @@ def _normalize_value(value: Any) -> Any: return {_normalize_value(k): _normalize_value(v) for k, v in value.items() if v is not None} if isinstance(value, list): return [_normalize_value(v) for v in value] - if isinstance(value, Enum): - return value.value - return value + return value.value if isinstance(value, Enum) else value @dataclass diff --git a/starlite/openapi/spec/schema.py b/starlite/openapi/spec/schema.py index 34f3d45f1f..7d5d31ba0e 100644 --- a/starlite/openapi/spec/schema.py +++ b/starlite/openapi/spec/schema.py @@ -34,9 +34,7 @@ def _recursive_hash(value: Hashable | Sequence | Mapping | DataclassProtocol | t return hash_value if is_non_string_sequence(value): return sum(_recursive_hash(v) for v in value) - if isinstance(value, Hashable): - return hash(value) - return 0 + return hash(value) if isinstance(value, Hashable) else 0 @dataclass diff --git a/starlite/partial.py b/starlite/partial.py index 4b4072eeeb..fc38d60870 100644 --- a/starlite/partial.py +++ b/starlite/partial.py @@ -87,15 +87,14 @@ def _create_partial_pydantic_model(cls, item: Type[BaseModel]) -> None: Args: item: A pydantic model class. """ - field_definitions: Dict[str, Tuple[Any, None]] = {} - for field_name, field_type in get_type_hints(item).items(): - if is_classvar(field_type): - continue - if not isinstance(field_type, GenericAlias) or NoneType not in field_type.__args__: - field_definitions[field_name] = (Optional[field_type], None) - else: - field_definitions[field_name] = (field_type, None) - + field_definitions: Dict[str, Tuple[Any, None]] = { + field_name: (Optional[field_type], None) + if not isinstance(field_type, GenericAlias) + or NoneType not in field_type.__args__ + else (field_type, None) + for field_name, field_type in get_type_hints(item).items() + if not is_classvar(field_type) + } cls._models[item] = create_model(cls._create_partial_type_name(item), __base__=item, **field_definitions) # type: ignore @classmethod diff --git a/starlite/response/base.py b/starlite/response/base.py index 8349f22eb5..c991875f6a 100644 --- a/starlite/response/base.py +++ b/starlite/response/base.py @@ -93,10 +93,11 @@ def __init__( ) self.is_head_response = is_head_response self.media_type = get_enum_string_value(media_type) - self.status_allows_body = not ( - status_code in {HTTP_204_NO_CONTENT, HTTP_304_NOT_MODIFIED} or status_code < HTTP_200_OK - ) self.status_code = status_code + self.status_allows_body = ( + status_code not in {HTTP_204_NO_CONTENT, HTTP_304_NOT_MODIFIED} + and status_code >= HTTP_200_OK + ) self._enc_hook = self.get_serializer(type_encoders) if not self.status_allows_body or is_head_response: @@ -119,8 +120,7 @@ def __init__( def get_serializer(cls, type_encoders: TypeEncodersMap | None = None) -> Serializer: """Get the serializer for this response class.""" - type_encoders = {**(cls.type_encoders or {}), **(type_encoders or {})} - if type_encoders: + if type_encoders := {**(cls.type_encoders or {}), **(type_encoders or {})}: return partial(default_serializer, type_encoders={**DEFAULT_TYPE_ENCODERS, **type_encoders}) return default_serializer @@ -241,11 +241,7 @@ def render(self, content: Any) -> bytes: """ try: if self.media_type.startswith("text/"): - if not content: - return b"" - - return content.encode(self.encoding) # type: ignore - + return b"" if not content else content.encode(self.encoding) if self.media_type == MediaType.MESSAGEPACK: return encode_msgpack(content, self._enc_hook) @@ -261,9 +257,7 @@ def content_length(self) -> int: The content length of the body (e.g. for use in a ``Content-Length`` header). If the response does not have a body, this value is ``None`` """ - if self.status_allows_body: - return len(self.body) - return 0 + return len(self.body) if self.status_allows_body else 0 def encode_headers(self) -> list[tuple[bytes, bytes]]: """Encode the response headers as a list of byte tuples. diff --git a/starlite/response/file.py b/starlite/response/file.py index 9c4669096c..1187b933d3 100644 --- a/starlite/response/file.py +++ b/starlite/response/file.py @@ -174,9 +174,7 @@ def content_length(self) -> int: Returns the value of :attr:`stat_result.st_size ` to populate the ``Content-Length`` header. """ - if isinstance(self.file_info, dict): - return self.file_info["size"] - return 0 + return self.file_info["size"] if isinstance(self.file_info, dict) else 0 async def send_body(self, send: "Send", receive: "Receive") -> None: """Emit a stream of events correlating with the response body. diff --git a/starlite/response/template.py b/starlite/response/template.py index 20fb401d6e..873101c97a 100644 --- a/starlite/response/template.py +++ b/starlite/response/template.py @@ -53,7 +53,7 @@ def __init__( if media_type == MediaType.JSON: # we assume this is the default suffixes = PurePath(template_name).suffixes for suffix in suffixes: - if _type := guess_type("name" + suffix)[0]: + if _type := guess_type(f"name{suffix}")[0]: media_type = _type break else: diff --git a/starlite/routes/base.py b/starlite/routes/base.py index 8a8a82d6a7..1e74a892db 100644 --- a/starlite/routes/base.py +++ b/starlite/routes/base.py @@ -161,8 +161,7 @@ def _parse_path(cls, path: str) -> tuple[str, str, list[str | PathParameterDefin components = [component for component in path.split("/") if component] for component in components: - param_match = param_match_regex.fullmatch(component) - if param_match: + if param_match := param_match_regex.fullmatch(component): param = param_match.group(1) cls._validate_path_parameter(param) param_name, param_type = (p.strip() for p in param.split(":")) diff --git a/starlite/routes/http.py b/starlite/routes/http.py index 581644c42a..e621d8a8d2 100644 --- a/starlite/routes/http.py +++ b/starlite/routes/http.py @@ -200,16 +200,15 @@ async def _get_response_data( if cleanup_group: async with cleanup_group: - if route_handler.has_sync_callable: - data = route_handler.fn.value(**parsed_kwargs) - else: - data = await route_handler.fn.value(**parsed_kwargs) - + data = ( + route_handler.fn.value(**parsed_kwargs) + if route_handler.has_sync_callable + else await route_handler.fn.value(**parsed_kwargs) + ) + elif route_handler.has_sync_callable: + data = route_handler.fn.value(**parsed_kwargs) else: - if route_handler.has_sync_callable: - data = route_handler.fn.value(**parsed_kwargs) - else: - data = await route_handler.fn.value(**parsed_kwargs) + data = await route_handler.fn.value(**parsed_kwargs) return data, cleanup_group diff --git a/starlite/stores/file.py b/starlite/stores/file.py index f681a6bbdf..d91adad575 100644 --- a/starlite/stores/file.py +++ b/starlite/stores/file.py @@ -56,7 +56,9 @@ async def _load_from_path(path: Path) -> StorageObject | None: def _write_sync(self, target_file: Path, storage_obj: StorageObject) -> None: try: - tmp_file_fd, tmp_file_name = mkstemp(dir=self.path, prefix=target_file.name + ".tmp") + tmp_file_fd, tmp_file_name = mkstemp( + dir=self.path, prefix=f"{target_file.name}.tmp" + ) renamed = False try: try: diff --git a/starlite/stores/redis.py b/starlite/stores/redis.py index 5fdfb00bfe..883d3de2f8 100644 --- a/starlite/stores/redis.py +++ b/starlite/stores/redis.py @@ -173,6 +173,4 @@ async def expires_in(self, key: str) -> int | None: expiry time was set, return ``None``. """ ttl = await self._redis.ttl(self._make_key(key)) - if ttl == -2: - return None - return ttl + return None if ttl == -2 else ttl diff --git a/starlite/utils/deprecation.py b/starlite/utils/deprecation.py index 472739696f..b9b8ab3b35 100644 --- a/starlite/utils/deprecation.py +++ b/starlite/utils/deprecation.py @@ -50,9 +50,12 @@ def warn_deprecation( else: parts.append(f"{access_type} deprecated {kind} {deprecated_name!r}") - parts.append(f"Deprecated in starlite {version}") - parts.append(f"This {kind} will be removed in {removal_in or 'the next major version'}") - + parts.extend( + ( + f"Deprecated in starlite {version}", + f"This {kind} will be removed in {removal_in or 'the next major version'}", + ) + ) if alternative: parts.append(f"Use {alternative!r} instead") diff --git a/starlite/utils/path.py b/starlite/utils/path.py index 69b03ae4ef..76b43affc9 100644 --- a/starlite/utils/path.py +++ b/starlite/utils/path.py @@ -19,7 +19,7 @@ def normalize_path(path: str) -> str: Path string """ path = path.strip("/") - path = "/" + path + path = f"/{path}" return multi_slash_pattern.sub("/", path) diff --git a/starlite/utils/predicates.py b/starlite/utils/predicates.py index 925d018303..b3b4044c24 100644 --- a/starlite/utils/predicates.py +++ b/starlite/utils/predicates.py @@ -252,7 +252,7 @@ def is_pydantic_model_class(annotation: Any) -> "TypeGuard[Type[BaseModel]]": # return False # pragma: no cover -def is_pydantic_model_instance(annotation: Any) -> "TypeGuard[BaseModel]": # pyright: ignore +def is_pydantic_model_instance(annotation: Any) -> "TypeGuard[BaseModel]": # pyright: ignore """Given a type annotation determine if the annotation is an instance of pydantic's BaseModel. Args: @@ -261,6 +261,4 @@ def is_pydantic_model_instance(annotation: Any) -> "TypeGuard[BaseModel]": # py Returns: A typeguard determining whether the type is :data:`BaseModel pydantic.BaseModel>`. """ - if BaseModel is not Empty: # type: ignore[comparison-overlap] - return isinstance(annotation, BaseModel) - return False # pragma: no cover + return isinstance(annotation, BaseModel) if BaseModel is not Empty else False diff --git a/starlite/utils/sequence.py b/starlite/utils/sequence.py index c536221964..80b5f05b29 100644 --- a/starlite/utils/sequence.py +++ b/starlite/utils/sequence.py @@ -13,10 +13,9 @@ def find_index(target_list: list[T], predicate: Callable[[T], bool]) -> int: List elements can be dicts or classes """ - for i, element in enumerate(target_list): - if predicate(element): - return i - return -1 + return next( + (i for i, element in enumerate(target_list) if predicate(element)), -1 + ) def unique(value: Iterable[T]) -> list[T]: diff --git a/tests/cli/test_core_commands.py b/tests/cli/test_core_commands.py index 7655d615b3..0b7724f927 100644 --- a/tests/cli/test_core_commands.py +++ b/tests/cli/test_core_commands.py @@ -47,7 +47,7 @@ def test_run_command( args = ["run"] if custom_app_file: - args[0:0] = ["--app", f"{custom_app_file.stem}:app"] + args[:0] = ["--app", f"{custom_app_file.stem}:app"] if reload: if set_in_env: @@ -73,14 +73,13 @@ def test_run_command( else: host = "127.0.0.1" - if web_concurrency is not None: - if set_in_env: - monkeypatch.setenv("WEB_CONCURRENCY", str(web_concurrency)) - else: - args.extend(["--web-concurrency", str(web_concurrency)]) - else: + if web_concurrency is None: web_concurrency = 1 + elif set_in_env: + monkeypatch.setenv("WEB_CONCURRENCY", str(web_concurrency)) + else: + args.extend(["--web-concurrency", str(web_concurrency)]) path = create_app_file(custom_app_file or "app.py") result = runner.invoke(cli_command, args) diff --git a/tests/connection/websocket/test_websocket.py b/tests/connection/websocket/test_websocket.py index 8d90cc0904..b83e4aacb3 100644 --- a/tests/connection/websocket/test_websocket.py +++ b/tests/connection/websocket/test_websocket.py @@ -158,7 +158,7 @@ async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: websocket = WebSocket[Any, Any, Any](scope, receive=receive, send=send) await websocket.accept() data = await websocket.receive_text() - await websocket.send_text("Message was: " + data) + await websocket.send_text(f"Message was: {data}") await websocket.close() with TestClient(app).websocket_connect("/") as websocket: diff --git a/tests/contrib/htmx/test_htmx_request.py b/tests/contrib/htmx/test_htmx_request.py index 6c5c99154e..1118bd4126 100644 --- a/tests/contrib/htmx/test_htmx_request.py +++ b/tests/contrib/htmx/test_htmx_request.py @@ -99,7 +99,7 @@ def handler(request: HTMXRequest) -> Response: "/", headers={ HTMXHeaders.CURRENT_URL.value: "https%3A%2F%2Fexample.com%2F%3F", - HTMXHeaders.CURRENT_URL.value + "-URI-AutoEncoded": "true", + f"{HTMXHeaders.CURRENT_URL.value}-URI-AutoEncoded": "true", }, ) assert response.text == '"https://example.com/?"' @@ -283,7 +283,7 @@ def handler(request: HTMXRequest) -> Response: "/", headers={ HTMXHeaders.TRIGGERING_EVENT.value: "%7B%22target%22%3A%20null%7D", - HTMXHeaders.TRIGGERING_EVENT.value + "-uri-autoencoded": "true", + f"{HTMXHeaders.TRIGGERING_EVENT.value}-uri-autoencoded": "true", }, ) assert response.text == '{"target":null}' diff --git a/tests/contrib/sqlalchemy/repository/test_sqlalchemy_aiosqlite.py b/tests/contrib/sqlalchemy/repository/test_sqlalchemy_aiosqlite.py index 8fe2a90f93..cb58afe255 100644 --- a/tests/contrib/sqlalchemy/repository/test_sqlalchemy_aiosqlite.py +++ b/tests/contrib/sqlalchemy/repository/test_sqlalchemy_aiosqlite.py @@ -291,13 +291,12 @@ async def test_repo_delete_many_method(author_repo: AuthorRepository) -> None: Args: author_repo (AuthorRepository): The author mock repository """ - data_to_insert = [] - for chunk in range(0, 1000): - data_to_insert.append( - Author( - name="author name %d" % chunk, - ) + data_to_insert = [ + Author( + name="author name %d" % chunk, ) + for chunk in range(0, 1000) + ] _ = await author_repo.add_many(data_to_insert) all_objs = await author_repo.list() ids_to_delete = [existing_obj.id for existing_obj in all_objs] diff --git a/tests/contrib/sqlalchemy/repository/test_sqlalchemy_asyncpg.py b/tests/contrib/sqlalchemy/repository/test_sqlalchemy_asyncpg.py index 70d79be534..b3658655d7 100644 --- a/tests/contrib/sqlalchemy/repository/test_sqlalchemy_asyncpg.py +++ b/tests/contrib/sqlalchemy/repository/test_sqlalchemy_asyncpg.py @@ -65,7 +65,7 @@ async def wait_until_responsive(check: Callable[..., Awaitable], timeout: float, """ ref = timeit.default_timer() now = ref - while (now - ref) < timeout: + while now - now < timeout: if await check(**kwargs): return await asyncio.sleep(pause) @@ -360,13 +360,12 @@ async def test_repo_delete_many_method(author_repo: AuthorRepository) -> None: Args: author_repo (AuthorRepository): The author mock repository """ - data_to_insert = [] - for chunk in range(0, 1000): - data_to_insert.append( - Author( - name="author name %d" % chunk, - ) + data_to_insert = [ + Author( + name="author name %d" % chunk, ) + for chunk in range(0, 1000) + ] _ = await author_repo.add_many(data_to_insert) all_objs = await author_repo.list() ids_to_delete = [existing_obj.id for existing_obj in all_objs] diff --git a/tests/datastructures/test_state.py b/tests/datastructures/test_state.py index c99ca88df2..31d32c080d 100644 --- a/tests/datastructures/test_state.py +++ b/tests/datastructures/test_state.py @@ -13,7 +13,7 @@ def test_state_immutable_mapping(state_class: Type[ImmutableState]) -> None: assert len(state) == 3 assert "first" in state assert state["first"] == 1 - assert [(k, v) for k, v in state.items()] == [("first", 1), ("second", 2), ("third", 3)] + assert list(state.items()) == [("first", 1), ("second", 2), ("third", 3)] assert state assert isinstance(state.mutable_copy(), State) del state_dict["first"] diff --git a/tests/dependency_injection/test_http_handler_dependency_injection.py b/tests/dependency_injection/test_http_handler_dependency_injection.py index d50516cd1e..02dd6b7e5b 100644 --- a/tests/dependency_injection/test_http_handler_dependency_injection.py +++ b/tests/dependency_injection/test_http_handler_dependency_injection.py @@ -56,7 +56,7 @@ class FirstController(Controller): def test_method(self, first: int, second: dict, third: bool) -> None: assert isinstance(first, int) assert isinstance(second, dict) - assert third is False + assert not third def test_controller_dependency_injection() -> None: @@ -73,15 +73,15 @@ def test_controller_dependency_injection() -> None: def test_function_dependency_injection() -> None: @get( - path=test_path + "/{path_param:str}", - dependencies={ - "first": Provide(local_method_first_dependency), - "third": Provide(local_method_second_dependency), - }, - ) + path=test_path + "/{path_param:str}", + dependencies={ + "first": Provide(local_method_first_dependency), + "third": Provide(local_method_second_dependency), + }, + ) def test_function(first: int, second: bool, third: str) -> None: assert isinstance(first, int) - assert second is False + assert not second assert isinstance(third, str) with create_test_client( diff --git a/tests/dependency_injection/test_websocket_handler_dependency_injection.py b/tests/dependency_injection/test_websocket_handler_dependency_injection.py index 31e2edd981..1db305a278 100644 --- a/tests/dependency_injection/test_websocket_handler_dependency_injection.py +++ b/tests/dependency_injection/test_websocket_handler_dependency_injection.py @@ -60,7 +60,7 @@ async def test_method(self, socket: WebSocket, first: int, second: dict, third: assert socket assert isinstance(first, int) assert isinstance(second, dict) - assert third is False + assert not third await socket.close() @@ -78,19 +78,19 @@ def test_controller_dependency_injection() -> None: def test_function_dependency_injection() -> None: @websocket( - path=test_path + "/{path_param:str}", - dependencies={ - "first": Provide(local_method_first_dependency), - "third": Provide(local_method_second_dependency), - }, - ) + path=test_path + "/{path_param:str}", + dependencies={ + "first": Provide(local_method_first_dependency), + "third": Provide(local_method_second_dependency), + }, + ) async def test_function(socket: WebSocket, first: int, second: bool, third: str) -> None: await socket.accept() assert socket msg = await socket.receive_json() assert msg assert isinstance(first, int) - assert second is False + assert not second assert isinstance(third, str) await socket.close() diff --git a/tests/dto_factory/test_dto_factory_integration.py b/tests/dto_factory/test_dto_factory_integration.py index 0b653a84ad..e3fe08e948 100644 --- a/tests/dto_factory/test_dto_factory_integration.py +++ b/tests/dto_factory/test_dto_factory_integration.py @@ -58,7 +58,7 @@ def test_dto_factory(model: Any, exclude: list, field_mapping: dict, field_defin ) assert issubclass(dto, BaseModel) assert dto.__name__ == "MyDTO" - assert not any(excluded_key in dto.__fields__ for excluded_key in exclude) + assert all(excluded_key not in dto.__fields__ for excluded_key in exclude) assert all(remapped_key in dto.__fields__ for remapped_key in field_mapping.values()) special = dto.__fields__["special"] assert not special.allow_none diff --git a/tests/dto_factory/test_dto_factory_model_conversion.py b/tests/dto_factory/test_dto_factory_model_conversion.py index 33acec71c4..4d5d9fd9ce 100644 --- a/tests/dto_factory/test_dto_factory_model_conversion.py +++ b/tests/dto_factory/test_dto_factory_model_conversion.py @@ -42,11 +42,11 @@ class DTOModelFactory(ModelFactory[MyDTO]): # type: ignore for key in dto_instance.__fields__: # type: ignore if key not in MyDTO.dto_field_mapping: attribute_value = _get_attribute_value(model_instance, key) - assert attribute_value == dto_instance.__getattribute__(key) # type: ignore else: original_key = MyDTO.dto_field_mapping[key] attribute_value = _get_attribute_value(model_instance, original_key) - assert attribute_value == dto_instance.__getattribute__(key) # type: ignore + + assert attribute_value == dto_instance.__getattribute__(key) # type: ignore @pytest.mark.skipif(sys.version_info < (3, 9), reason="dataclasses behave differently in lower versions") diff --git a/tests/kwargs/test_generator_dependencies.py b/tests/kwargs/test_generator_dependencies.py index d29f148b64..8839771981 100644 --- a/tests/kwargs/test_generator_dependencies.py +++ b/tests/kwargs/test_generator_dependencies.py @@ -200,7 +200,7 @@ def test_generator_dependency_nested_error_during_cleanup( async def other_dependency(generator_dep: str) -> AsyncGenerator[str, None]: try: - yield generator_dep + ", world" + yield f"{generator_dep}, world" finally: cleanup_mock_no_raise() diff --git a/tests/kwargs/test_path_params.py b/tests/kwargs/test_path_params.py index 8d3e1f9297..8d1903305d 100644 --- a/tests/kwargs/test_path_params.py +++ b/tests/kwargs/test_path_params.py @@ -151,7 +151,7 @@ def handler(test: param_type_class) -> None: assert test == value with create_test_client(handler) as client: - response = client.get("/some/test/path/" + str(value)) + response = client.get(f"/some/test/path/{str(value)}") assert response.status_code == HTTP_200_OK diff --git a/tests/middleware/test_csrf_middleware.py b/tests/middleware/test_csrf_middleware.py index ceaadc413d..a1fa23e145 100644 --- a/tests/middleware/test_csrf_middleware.py +++ b/tests/middleware/test_csrf_middleware.py @@ -87,15 +87,15 @@ def test_unsafe_method_fails_without_csrf_header(method: str) -> None: def test_invalid_csrf_token() -> None: with create_test_client( - route_handlers=[get_handler, post_handler], csrf_config=CSRFConfig(secret="secret") - ) as client: + route_handlers=[get_handler, post_handler], csrf_config=CSRFConfig(secret="secret") + ) as client: response = client.get("/") assert response.status_code == HTTP_200_OK csrf_token: Optional[str] = response.cookies.get("csrftoken") assert csrf_token is not None - response = client.post("/", headers={"x-csrftoken": csrf_token + "invalid"}) + response = client.post("/", headers={"x-csrftoken": f"{csrf_token}invalid"}) assert response.status_code == HTTP_403_FORBIDDEN assert response.json() == {"detail": "CSRF token verification failed", "status_code": 403} @@ -173,14 +173,14 @@ def form_handler(data: dict = Body(media_type=RequestEncodingType.URL_ENCODED)) return data with create_test_client( - route_handlers=[handler, form_handler], - template_config=TemplateConfig( - directory=template_dir, - engine=engine, - ), - csrf_config=CSRFConfig(secret=str(urandom(10))), - ) as client: - url = str(client.base_url) + "/" + route_handlers=[handler, form_handler], + template_config=TemplateConfig( + directory=template_dir, + engine=engine, + ), + csrf_config=CSRFConfig(secret=str(urandom(10))), + ) as client: + url = f"{str(client.base_url)}/" Path(template_dir / "abc.html").write_text( f'
{template}
' ) diff --git a/tests/template/test_template.py b/tests/template/test_template.py index 6d1080c195..e12d8433f5 100644 --- a/tests/template/test_template.py +++ b/tests/template/test_template.py @@ -92,7 +92,7 @@ def index() -> Template: ) @pytest.mark.skipif(sys.platform == "win32", reason="mimetypes.guess_types is unreliable on windows") def test_media_type_inferred(extension: str, expected_type: MediaType, template_dir: Path) -> None: - tpl_name = "hello" + extension + tpl_name = f"hello{extension}" (template_dir / tpl_name).write_text("hello") @get("/") diff --git a/tests/test_controller.py b/tests/test_controller.py index 5569676918..0653ad95d7 100644 --- a/tests/test_controller.py +++ b/tests/test_controller.py @@ -80,7 +80,7 @@ async def ws(self, socket: WebSocket) -> None: client = create_test_client(route_handlers=MyController) - with client.websocket_connect(test_path + "/socket") as ws: + with client.websocket_connect(f"{test_path}/socket") as ws: ws.send_json({"data": "123"}) data = ws.receive_json() assert data diff --git a/tests/test_plugins.py b/tests/test_plugins.py index 2a62fa6e4b..7e73b3d269 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -29,9 +29,7 @@ def __init__(self, name: str) -> None: self.name = name def __eq__(self, __o: object) -> bool: - if isinstance(__o, type(self)): - return __o.name == self.name - return False + return __o.name == self.name if isinstance(__o, type(self)) else False class APydanticModel(BaseModel): diff --git a/tests/test_response_caching.py b/tests/test_response_caching.py index 5f09acdaef..fbef40cd35 100644 --- a/tests/test_response_caching.py +++ b/tests/test_response_caching.py @@ -91,7 +91,7 @@ async def handler() -> str: @pytest.mark.parametrize("sync_to_thread", (True, False)) async def test_custom_cache_key(sync_to_thread: bool, anyio_backend: str, mock: MagicMock) -> None: def custom_cache_key_builder(request: Request) -> str: - return request.url.path + ":::cached" + return f"{request.url.path}:::cached" @get("/cached", sync_to_thread=sync_to_thread, cache=True, cache_key_builder=custom_cache_key_builder) async def handler() -> str: diff --git a/tests/test_stores.py b/tests/test_stores.py index 93bd3a0f2c..da068ac90c 100644 --- a/tests/test_stores.py +++ b/tests/test_stores.py @@ -122,7 +122,7 @@ async def test_delete_all(store: Store) -> None: await store.delete_all() - assert not any([await store.get(key) for key in keys]) + assert not any(await store.get(key) for key in keys) async def test_expires_in(store: Store) -> None: @@ -132,7 +132,7 @@ async def test_expires_in(store: Store) -> None: assert await store.expires_in("foo") == -1 await store.set("foo", "bar", expires_in=10) - assert math.ceil(await store.expires_in("foo") / 10) * 10 == 10 # type: ignore[operator] + assert math.ceil(await store.expires_in("foo") / 10) == 1 @patch("starlite.stores.redis.Redis") @@ -173,7 +173,7 @@ async def test_redis_delete_all(redis_store: RedisStore) -> None: await redis_store.delete_all() - assert not any([await redis_store.get(key) for key in keys]) + assert not any(await redis_store.get(key) for key in keys) assert await redis_store._redis.get("test_key") == b"test_value" # check it doesn't delete other values @@ -221,7 +221,7 @@ def test_file_with_namespace(file_store: FileStore) -> None: @pytest.mark.parametrize("invalid_char", string.punctuation) def test_file_with_namespace_invalid_namespace_char(file_store: FileStore, invalid_char: str) -> None: with pytest.raises(ValueError): - file_store.with_namespace("foo" + invalid_char) + file_store.with_namespace(f"foo{invalid_char}") @pytest.fixture(params=["redis_store", "file_store"]) @@ -275,8 +275,8 @@ async def test_memory_delete_expired(store_fixture: str, request: FixtureRequest await anyio.sleep(0.002) await store.delete_expired() - assert not any([await store.exists(key) for key in expect_expired]) - assert all([await store.exists(key) for key in expect_not_expired]) + assert not any(await store.exists(key) for key in expect_expired) + assert all(await store.exists(key) for key in expect_not_expired) def test_registry_get(memory_store: MemoryStore) -> None: diff --git a/tests/testing/test_testing.py b/tests/testing/test_testing.py index 6f30fcb8d4..ea5e9aa00b 100644 --- a/tests/testing/test_testing.py +++ b/tests/testing/test_testing.py @@ -25,7 +25,7 @@ def test_request_factory_no_cookie_header() -> None: headers: Dict[str, str] = {} RequestFactory._create_cookie_header(headers) - assert headers == {} + assert not headers def test_request_factory_str_cookie_header() -> None: diff --git a/tools/sphinx_ext/__init__.py b/tools/sphinx_ext/__init__.py index b3450b8213..694a80ad27 100644 --- a/tools/sphinx_ext/__init__.py +++ b/tools/sphinx_ext/__init__.py @@ -10,7 +10,7 @@ def setup(app: Sphinx) -> dict[str, bool]: ext_config = {} - ext_config.update(run_examples.setup(app)) + ext_config |= run_examples.setup(app) ext_config.update(missing_references.setup(app)) ext_config.update(changelog.setup(app)) diff --git a/tools/sphinx_ext/missing_references.py b/tools/sphinx_ext/missing_references.py index b035f17095..dd43dfbdb9 100644 --- a/tools/sphinx_ext/missing_references.py +++ b/tools/sphinx_ext/missing_references.py @@ -54,9 +54,9 @@ def on_warn_missing_reference(app: Sphinx, domain: str, node: Node) -> bool | No attributes = node.attributes # type: ignore[attr-defined] target = attributes["reftarget"] - reference_target_source_obj = attributes.get("py:class", attributes.get("py:meth", attributes.get("py:func"))) - - if reference_target_source_obj: + if reference_target_source_obj := attributes.get( + "py:class", attributes.get("py:meth", attributes.get("py:func")) + ): global_names = get_module_global_imports(attributes["py:module"], reference_target_source_obj) if target in global_names: diff --git a/tools/sphinx_ext/run_examples.py b/tools/sphinx_ext/run_examples.py index ed3b24b806..4c3a771fda 100644 --- a/tools/sphinx_ext/run_examples.py +++ b/tools/sphinx_ext/run_examples.py @@ -117,7 +117,7 @@ def exec_examples(app_file: Path, run_configs: list[list[str]]) -> str: logger.error(f"Example: {app_file}:{args} yielded no results") continue - result = "\n".join(line for line in ("> " + (" ".join(clean_args)), *stdout)) + result = "\n".join(("> " + (" ".join(clean_args)), *stdout)) results.append(result) return "\n".join(results) @@ -141,7 +141,7 @@ def run(self) -> list[Node]: tmp_file = self.env.tmp_examples_path / str(file.relative_to(docs_dir)).replace("/", "_") - self.arguments[0] = "/" + str(tmp_file.relative_to(docs_dir)) + self.arguments[0] = f"/{str(tmp_file.relative_to(docs_dir))}" tmp_file.write_text(clean_content) nodes = super().run()