Skip to content

Commit

Permalink
Fix mypy and lint errors
Browse files Browse the repository at this point in the history
  • Loading branch information
chriskuehl committed Nov 12, 2024
1 parent c12d4c4 commit ebfef59
Show file tree
Hide file tree
Showing 35 changed files with 313 additions and 256 deletions.
78 changes: 38 additions & 40 deletions baseplate/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

import logging
import os
import random
from collections.abc import Iterator
from contextlib import contextmanager
from types import TracebackType
from typing import Any, Callable, NamedTuple, Optional
from typing import Any, Callable, NamedTuple, Optional, Tuple

import gevent.monkey
from pkg_resources import DistributionNotFound, get_distribution
Expand All @@ -24,7 +26,7 @@
class BaseplateObserver:
"""Interface for an observer that watches Baseplate."""

def on_server_span_created(self, context: "RequestContext", server_span: "ServerSpan") -> None:
def on_server_span_created(self, context: RequestContext, server_span: ServerSpan) -> None:
"""Do something when a server span is created.
:py:class:`Baseplate` calls this when a new request begins.
Expand All @@ -37,7 +39,7 @@ def on_server_span_created(self, context: "RequestContext", server_span: "Server
raise NotImplementedError


_ExcInfo = tuple[Optional[type[BaseException]], Optional[BaseException], Optional[TracebackType]]
_ExcInfo = Tuple[Optional[type[BaseException]], Optional[BaseException], Optional[TracebackType]]


class SpanObserver:
Expand All @@ -55,15 +57,15 @@ def on_incr_tag(self, key: str, delta: float) -> None:
def on_log(self, name: str, payload: Any) -> None:
"""Do something when a log entry is added to the span."""

def on_finish(self, exc_info: Optional[_ExcInfo]) -> None:
def on_finish(self, exc_info: _ExcInfo | None) -> None:
"""Do something when the observed span is finished.
:param exc_info: If the span ended because of an exception, the
exception info. Otherwise, :py:data:`None`.
"""

def on_child_span_created(self, span: "Span") -> None:
def on_child_span_created(self, span: Span) -> None:
"""Do something when a child span is created.
:py:class:`SpanObserver` objects call this when a new child span is
Expand Down Expand Up @@ -91,19 +93,19 @@ class TraceInfo(NamedTuple):
trace_id: str

#: The ID of the parent span, or None if this is the root span.
parent_id: Optional[str]
parent_id: str | None

#: The ID of the current span. Should be unique within a trace.
span_id: str

#: True if this trace was selected for sampling. Will be propagated to child spans.
sampled: Optional[bool]
sampled: bool | None

#: A bit field of extra flags about this trace.
flags: Optional[int]
flags: int | None

@classmethod
def new(cls) -> "TraceInfo":
def new(cls) -> TraceInfo:
"""Generate IDs for a new initial server span.
This span has no parent and has a random ID. It cannot be correlated
Expand All @@ -117,11 +119,11 @@ def new(cls) -> "TraceInfo":
def from_upstream(
cls,
trace_id: str,
parent_id: Optional[str],
parent_id: str | None,
span_id: str,
sampled: Optional[bool],
flags: Optional[int],
) -> "TraceInfo":
sampled: bool | None,
flags: int | None,
) -> TraceInfo:
"""Build a TraceInfo from individual headers.
:param trace_id: The ID of the trace.
Expand Down Expand Up @@ -169,9 +171,9 @@ class RequestContext:
def __init__(
self,
context_config: dict[str, Any],
prefix: Optional[str] = None,
span: Optional["Span"] = None,
wrapped: Optional["RequestContext"] = None,
prefix: str | None = None,
span: Span | None = None,
wrapped: RequestContext | None = None,
):
self.__context_config = context_config
self.__prefix = prefix
Expand Down Expand Up @@ -216,7 +218,7 @@ def __getattr__(self, name: str) -> Any:
def __setattr__(self, name: str, value: Any) -> None:
super().__setattr__(name, value)

def clone(self) -> "RequestContext":
def clone(self) -> RequestContext:
return RequestContext(
context_config=self.__context_config,
prefix=self.__prefix,
Expand All @@ -241,7 +243,7 @@ class Baseplate:
"""

def __init__(self, app_config: Optional[config.RawConfig] = None) -> None:
def __init__(self, app_config: config.RawConfig | None = None) -> None:
"""Initialize the core observability framework.
:param app_config: The raw configuration dictionary for your
Expand All @@ -266,7 +268,7 @@ def __init__(self, app_config: Optional[config.RawConfig] = None) -> None:
"""
self.observers: list[BaseplateObserver] = []
self._metrics_client: Optional[metrics.Client] = None
self._metrics_client: metrics.Client | None = None
self._context_config: dict[str, Any] = {}
self._app_config = app_config or {}

Expand Down Expand Up @@ -432,8 +434,8 @@ def make_context_object(self) -> RequestContext:
return RequestContext(self._context_config)

def make_server_span(
self, context: RequestContext, name: str, trace_info: Optional[TraceInfo] = None
) -> "ServerSpan":
self, context: RequestContext, name: str, trace_info: TraceInfo | None = None
) -> ServerSpan:
"""Return a server span representing the request we are handling.
In a server, a server span represents the time spent on a single
Expand Down Expand Up @@ -500,7 +502,7 @@ def server_context(self, name: str) -> Iterator[RequestContext]:
yield context

def get_runtime_metric_reporters(self) -> dict[str, Callable[[Any], None]]:
specs: list[tuple[Optional[str], dict[str, Any]]] = [(None, self._context_config)]
specs: list[tuple[str | None, dict[str, Any]]] = [(None, self._context_config)]
result = {}
while specs:
prefix, spec = specs.pop(0)
Expand All @@ -523,13 +525,13 @@ class Span:
def __init__(
self,
trace_id: str,
parent_id: Optional[str],
parent_id: str | None,
span_id: str,
sampled: Optional[bool],
flags: Optional[int],
sampled: bool | None,
flags: int | None,
name: str,
context: RequestContext,
baseplate: Optional[Baseplate] = None,
baseplate: Baseplate | None = None,
):
self.trace_id = trace_id
self.parent_id = parent_id
Expand All @@ -539,7 +541,7 @@ def __init__(
self.name = name
self.context = context
self.baseplate = baseplate
self.component_name: Optional[str] = None
self.component_name: str | None = None
self.observers: list[SpanObserver] = []

def register(self, observer: SpanObserver) -> None:
Expand Down Expand Up @@ -592,7 +594,7 @@ def incr_tag(self, key: str, delta: float = 1) -> None:
for observer in self.observers:
observer.on_incr_tag(key, delta)

def log(self, name: str, payload: Optional[Any] = None) -> None:
def log(self, name: str, payload: Any | None = None) -> None:
"""Add a log entry to the span.
Log entries are timestamped events recording notable moments in the
Expand All @@ -606,7 +608,7 @@ def log(self, name: str, payload: Optional[Any] = None) -> None:
for observer in self.observers:
observer.on_log(name, payload)

def finish(self, exc_info: Optional[_ExcInfo] = None) -> None:
def finish(self, exc_info: _ExcInfo | None = None) -> None:
"""Record the end of the span.
:param exc_info: If the span ended because of an exception, this is
Expand All @@ -624,28 +626,26 @@ def finish(self, exc_info: Optional[_ExcInfo] = None) -> None:
self.context = None # type: ignore
self.observers.clear()

def __enter__(self) -> "Span":
def __enter__(self) -> Span:
self.start()
return self

def __exit__(
self,
exc_type: Optional[type[BaseException]],
value: Optional[BaseException],
traceback: Optional[TracebackType],
exc_type: type[BaseException] | None,
value: BaseException | None,
traceback: TracebackType | None,
) -> None:
if exc_type is not None:
self.finish(exc_info=(exc_type, value, traceback))
else:
self.finish()

def make_child(
self, name: str, local: bool = False, component_name: Optional[str] = None
) -> "Span":
def make_child(self, name: str, local: bool = False, component_name: str | None = None) -> Span:
"""Return a child Span whose parent is this Span."""
raise NotImplementedError

def with_tags(self, tags: dict[str, Any]) -> "Span":
def with_tags(self, tags: dict[str, Any]) -> Span:
"""Declare a set of tags to be added to a span before starting it in the context manager.
Can be used as follow:
Expand All @@ -668,9 +668,7 @@ def __init__(self) -> None:


class LocalSpan(Span):
def make_child(
self, name: str, local: bool = False, component_name: Optional[str] = None
) -> "Span":
def make_child(self, name: str, local: bool = False, component_name: str | None = None) -> Span:
"""Return a child Span whose parent is this Span.
The child span can either be a local span representing an in-request
Expand Down
38 changes: 20 additions & 18 deletions baseplate/clients/cassandra.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import logging
import time
from collections.abc import Mapping, Sequence
Expand All @@ -6,8 +8,8 @@
TYPE_CHECKING,
Any,
Callable,
List,
NamedTuple,
Optional,
Union,
)

Expand Down Expand Up @@ -67,9 +69,9 @@ class CassandraPrometheusLabels(NamedTuple):

def cluster_from_config(
app_config: config.RawConfig,
secrets: Optional[SecretsStore] = None,
secrets: SecretsStore | None = None,
prefix: str = "cassandra.",
execution_profiles: Optional[dict[str, ExecutionProfile]] = None,
execution_profiles: dict[str, ExecutionProfile] | None = None,
**kwargs: Any,
) -> Cluster:
"""Make a Cluster from a configuration dictionary.
Expand Down Expand Up @@ -138,7 +140,7 @@ def __init__(self, keyspace: str, client_name: str = "", **kwargs: Any):
self.kwargs = kwargs
self.client_name = client_name

def parse(self, key_path: str, raw_config: config.RawConfig) -> "CassandraContextFactory":
def parse(self, key_path: str, raw_config: config.RawConfig) -> CassandraContextFactory:
cluster = cluster_from_config(raw_config, prefix=f"{key_path}.", **self.kwargs)
session = cluster.connect(keyspace=self.keyspace)

Expand Down Expand Up @@ -166,15 +168,15 @@ class CassandraContextFactory(ContextFactory):
def __init__(
self,
session: Session,
prometheus_client_name: Optional[str] = None,
prometheus_cluster_name: Optional[str] = None,
prometheus_client_name: str | None = None,
prometheus_cluster_name: str | None = None,
):
self.session = session
self.prepared_statements: dict[str, PreparedStatement] = {}
self.prometheus_client_name = prometheus_client_name
self.prometheus_cluster_name = prometheus_cluster_name

def make_object_for_context(self, name: str, span: Span) -> "CassandraSessionAdapter":
def make_object_for_context(self, name: str, span: Span) -> CassandraSessionAdapter:
return CassandraSessionAdapter(
name,
span,
Expand All @@ -201,7 +203,7 @@ def __init__(self, keyspace: str, **kwargs: Any):
self.keyspace = keyspace
self.kwargs = kwargs

def parse(self, key_path: str, raw_config: config.RawConfig) -> "CQLMapperContextFactory":
def parse(self, key_path: str, raw_config: config.RawConfig) -> CQLMapperContextFactory:
cluster = cluster_from_config(raw_config, prefix=f"{key_path}.", **self.kwargs)
session = cluster.connect(keyspace=self.keyspace)
return CQLMapperContextFactory(session)
Expand All @@ -221,7 +223,7 @@ class CQLMapperContextFactory(CassandraContextFactory):
"""

def make_object_for_context(self, name: str, span: Span) -> "cqlmapper.connection.Connection":
def make_object_for_context(self, name: str, span: Span) -> cqlmapper.connection.Connection:
# Import inline so you can still use the regular Cassandra integration
# without installing cqlmapper
# pylint: disable=redefined-outer-name
Expand Down Expand Up @@ -317,7 +319,7 @@ def _on_execute_failed(exc: BaseException, args: CassandraCallbackArgs, event: E
event.set()


RowFactory = Callable[[list[str], list[tuple]], Any]
RowFactory = Callable[[List[str], List[tuple]], Any]
Query = Union[str, SimpleStatement, PreparedStatement, BoundStatement]
Parameters = Union[Sequence[Any], Mapping[str, Any]]

Expand All @@ -329,8 +331,8 @@ def __init__(
server_span: Span,
session: Session,
prepared_statements: dict[str, PreparedStatement],
prometheus_client_name: Optional[str] = None,
prometheus_cluster_name: Optional[str] = None,
prometheus_client_name: str | None = None,
prometheus_cluster_name: str | None = None,
):
self.context_name = context_name
self.server_span = server_span
Expand All @@ -345,9 +347,9 @@ def __getattr__(self, name: str) -> Any:
def execute(
self,
query: Query,
parameters: Optional[Parameters] = None,
timeout: Union[float, object] = _NOT_SET,
query_name: Optional[str] = None,
parameters: Parameters | None = None,
timeout: float | object = _NOT_SET,
query_name: str | None = None,
**kwargs: Any,
) -> Any:
return self.execute_async(
Expand All @@ -357,9 +359,9 @@ def execute(
def execute_async(
self,
query: Query,
parameters: Optional[Parameters] = None,
timeout: Union[float, object] = _NOT_SET,
query_name: Optional[str] = None,
parameters: Parameters | None = None,
timeout: float | object = _NOT_SET,
query_name: str | None = None,
**kwargs: Any,
) -> ResponseFuture:
prom_labels = CassandraPrometheusLabels(
Expand Down
Loading

0 comments on commit ebfef59

Please sign in to comment.