diff --git a/packages/exchange/src/exchange/observers/__init__.py b/packages/exchange/src/exchange/observers/__init__.py index deb917a0..fd70ae86 100644 --- a/packages/exchange/src/exchange/observers/__init__.py +++ b/packages/exchange/src/exchange/observers/__init__.py @@ -1,16 +1,20 @@ -from typing import Callable from functools import wraps +from typing import Callable from exchange.observers.base import ObserverManager -def observe_wrapper(*args, **kwargs) -> Callable: + +def observe_wrapper(*args, **kwargs) -> Callable: # noqa: ANN002, ANN003 """Decorator to wrap a function with all registered observer plugins, dynamically fetched.""" - def wrapper(func): + + def wrapper(func: Callable) -> Callable: @wraps(func) - def dynamic_wrapped(*func_args, **func_kwargs): + def dynamic_wrapped(*func_args, **func_kwargs) -> Callable: # noqa: ANN002, ANN003 wrapped = func for observer in ObserverManager.get_instance()._observers: wrapped = observer.observe_wrapper(*args, **kwargs)(wrapped) return wrapped(*func_args, **func_kwargs) + return dynamic_wrapped - return wrapper \ No newline at end of file + + return wrapper diff --git a/packages/exchange/src/exchange/observers/base.py b/packages/exchange/src/exchange/observers/base.py index f987b5b9..5adfb417 100644 --- a/packages/exchange/src/exchange/observers/base.py +++ b/packages/exchange/src/exchange/observers/base.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod -from typing import Callable +from typing import Callable, Type + class Observer(ABC): @abstractmethod @@ -7,19 +8,20 @@ def initialize(self) -> None: pass @abstractmethod - def observe_wrapper(*args, **kwargs) -> Callable: + def observe_wrapper(*args, **kwargs) -> Callable: # noqa: ANN002, ANN003 pass @abstractmethod def finalize(self) -> None: pass + class ObserverManager: _instance = None _observers: list[Observer] = [] @classmethod - def get_instance(cls): + def get_instance(cls: Type["ObserverManager"]) -> "ObserverManager": if cls._instance is None: cls._instance = cls() return cls._instance diff --git a/packages/exchange/src/exchange/observers/langfuse.py b/packages/exchange/src/exchange/observers/langfuse.py index 90119268..f87d480e 100644 --- a/packages/exchange/src/exchange/observers/langfuse.py +++ b/packages/exchange/src/exchange/observers/langfuse.py @@ -67,12 +67,12 @@ def initialize_with_disabled_tracing(self) -> None: langfuse_context.configure(enabled=False) self.tracing = False - def observe_wrapper(self, *args, **kwargs) -> Callable: + def observe_wrapper(self, *args, **kwargs) -> Callable: # noqa: ANN002, ANN003 def _wrapper(fn: Callable) -> Callable: if self.tracing and auth_check(): @wraps(fn) - def wrapped_fn(*fargs, **fkwargs) -> Callable: + def wrapped_fn(*fargs, **fkwargs) -> Callable: # noqa: ANN002, ANN003 # group all traces under the same session if fn.__name__ == "reply": langfuse_context.update_current_trace(session_id=fargs[0].name)