diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index ff5f841d..14d1b826 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -52,17 +52,23 @@ just test > [!NOTE] > This integration is experimental and we don't currently have integration tests for it. -Developers can use locally hosted Langfuse tracing by applying the custom `observe_wrapper` decorator defined in `packages/exchange/src/langfuse_wrapper.py` to functions for automatic integration with Langfuse. +Developers can use locally hosted Langfuse tracing by applying the custom `observe_wrapper` decorator defined in `packages/exchange/src/exchange/observers` to functions for automatic integration with Langfuse, and potentially other observability providers in the future. +- Add an `observers` array to your profile containing `langfuse`. - Run `just langfuse-server` to start your local Langfuse server. It requires Docker. - Go to http://localhost:3000 and log in with the default email/password output by the shell script (values can also be found in the `.env.langfuse.local` file). - Run Goose with the --tracing flag enabled i.e., `goose session start --tracing` - View your traces at http://localhost:3000 -To extend tracing to additional functions, import `from exchange.langfuse_wrapper import observe_wrapper` and use the `observe_wrapper()` decorator on functions you wish to enable tracing for. `observe_wrapper` functions the same way as Langfuse's observe decorator. +To extend tracing to additional functions, import `from exchange.observers import observe_wrapper` and use the `observe_wrapper()` decorator on functions you wish to enable tracing for. `observe_wrapper` functions the same way as Langfuse's observe decorator. Read more about Langfuse's decorator-based tracing [here](https://langfuse.com/docs/sdk/python/decorators). 
+### Other observability plugins + +If locally hosted Langfuse doesn't fit your needs, you can use other `observer` telemetry plugins to ingest data with the same interface as the Langfuse integration. +To do so, subclass `packages/exchange/src/exchange/observers/base.py:Observer` and register the new plugin's path under the `exchange.observer` entry-point group in `exchange`'s `pyproject.toml`. + ## Exchange The lower level generation behind goose is powered by the [`exchange`][ai-exchange] package, also in this repo. diff --git a/README.md b/README.md index d0322907..553e5d55 100644 --- a/README.md +++ b/README.md @@ -193,7 +193,7 @@ Read more about local Langfuse deployments [here](https://langfuse.com/docs/depl #### Exchange and Goose integration -Import `from exchange.langfuse_wrapper import observe_wrapper` and use the `observe_wrapper()` decorator on functions you wish to enable tracing for. `observe_wrapper` functions the same way as Langfuse's observe decorator. +Import `from exchange.observers import observe_wrapper`, include `langfuse` in the `observers` list of your profile, and use the `observe_wrapper()` decorator on functions you wish to enable tracing for. `observe_wrapper` functions the same way as Langfuse's observe decorator. Read more about Langfuse's decorator-based tracing [here](https://langfuse.com/docs/sdk/python/decorators). 
diff --git a/packages/exchange/pyproject.toml b/packages/exchange/pyproject.toml index e7c48f81..539be4e5 100644 --- a/packages/exchange/pyproject.toml +++ b/packages/exchange/pyproject.toml @@ -42,6 +42,9 @@ passive = "exchange.moderators.passive:PassiveModerator" truncate = "exchange.moderators.truncate:ContextTruncate" summarize = "exchange.moderators.summarizer:ContextSummarizer" +[project.entry-points."exchange.observer"] +langfuse = "exchange.observers.langfuse:LangfuseObserver" + [project.entry-points."metadata.plugins"] ai-exchange = "exchange:module_name" diff --git a/packages/exchange/src/exchange/exchange.py b/packages/exchange/src/exchange/exchange.py index 1974463d..599b6eb4 100644 --- a/packages/exchange/src/exchange/exchange.py +++ b/packages/exchange/src/exchange/exchange.py @@ -8,10 +8,10 @@ from exchange.checkpoint import Checkpoint, CheckpointData from exchange.content import Text, ToolResult, ToolUse -from exchange.langfuse_wrapper import observe_wrapper from exchange.message import Message from exchange.moderators import Moderator from exchange.moderators.truncate import ContextTruncate +from exchange.observers import observe_wrapper from exchange.providers import Provider, Usage from exchange.token_usage_collector import _token_usage_collector from exchange.tool import Tool diff --git a/packages/exchange/src/exchange/langfuse_wrapper.py b/packages/exchange/src/exchange/langfuse_wrapper.py deleted file mode 100644 index 9788bf46..00000000 --- a/packages/exchange/src/exchange/langfuse_wrapper.py +++ /dev/null @@ -1,73 +0,0 @@ -""" -Langfuse Integration Module - -This module provides integration with Langfuse, a tool for monitoring and tracing LLM applications. - -Usage: - Import this module to enable Langfuse integration. - It automatically checks for Langfuse credentials in the .env.langfuse file and for a running Langfuse server. - If these are found, it will set up the necessary client and context for tracing. 
- -Note: - Run setup_langfuse.sh which automates the steps for running local Langfuse. -""" - -import os -from typing import Callable -from langfuse.decorators import langfuse_context -import sys -from io import StringIO -from functools import cache, wraps - -## These are the default configurations for local Langfuse server -## Please refer to .env.langfuse.local file for local langfuse server setup configurations -DEFAULT_LOCAL_LANGFUSE_HOST = "http://localhost:3000" -DEFAULT_LOCAL_LANGFUSE_PUBLIC_KEY = "publickey-local" -DEFAULT_LOCAL_LANGFUSE_SECRET_KEY = "secretkey-local" - - -@cache -def auth_check() -> bool: - # Temporarily redirect stdout and stderr to suppress print statements from Langfuse - temp_stderr = StringIO() - sys.stderr = temp_stderr - - # Set environment variables if not specified - os.environ.setdefault("LANGFUSE_PUBLIC_KEY", DEFAULT_LOCAL_LANGFUSE_PUBLIC_KEY) - os.environ.setdefault("LANGFUSE_SECRET_KEY", DEFAULT_LOCAL_LANGFUSE_SECRET_KEY) - os.environ.setdefault("LANGFUSE_HOST", DEFAULT_LOCAL_LANGFUSE_HOST) - - auth_val = langfuse_context.auth_check() - - # Restore stderr - sys.stderr = sys.__stderr__ - return auth_val - - -def observe_wrapper(*args, **kwargs) -> Callable: # noqa - """ - A decorator that wraps a function with Langfuse context observation if credentials are available. - - If Langfuse credentials were found, the function will be wrapped with Langfuse's observe method. - Otherwise, the function will be returned as-is. - - Args: - *args: Positional arguments to pass to langfuse_context.observe. - **kwargs: Keyword arguments to pass to langfuse_context.observe. - - Returns: - Callable: The wrapped function if credentials are available, otherwise the original function. 
- """ - - def _wrapper(fn: Callable) -> Callable: - if auth_check(): - - @wraps(fn) - def wrapped_fn(*fargs, **fkwargs): # noqa - return langfuse_context.observe(*args, **kwargs)(fn)(*fargs, **fkwargs) - - return wrapped_fn - else: - return fn - - return _wrapper diff --git a/packages/exchange/src/exchange/observers/__init__.py b/packages/exchange/src/exchange/observers/__init__.py new file mode 100644 index 00000000..fd70ae86 --- /dev/null +++ b/packages/exchange/src/exchange/observers/__init__.py @@ -0,0 +1,20 @@ +from functools import wraps +from typing import Callable + +from exchange.observers.base import ObserverManager + + +def observe_wrapper(*args, **kwargs) -> Callable: # noqa: ANN002, ANN003 + """Decorator to wrap a function with all registered observer plugins, dynamically fetched.""" + + def wrapper(func: Callable) -> Callable: + @wraps(func) + def dynamic_wrapped(*func_args, **func_kwargs) -> Callable: # noqa: ANN002, ANN003 + wrapped = func + for observer in ObserverManager.get_instance()._observers: + wrapped = observer.observe_wrapper(*args, **kwargs)(wrapped) + return wrapped(*func_args, **func_kwargs) + + return dynamic_wrapped + + return wrapper diff --git a/packages/exchange/src/exchange/observers/base.py b/packages/exchange/src/exchange/observers/base.py new file mode 100644 index 00000000..5adfb417 --- /dev/null +++ b/packages/exchange/src/exchange/observers/base.py @@ -0,0 +1,43 @@ +from abc import ABC, abstractmethod +from typing import Callable, Type + + +class Observer(ABC): + @abstractmethod + def initialize(self) -> None: + pass + + @abstractmethod + def observe_wrapper(*args, **kwargs) -> Callable: # noqa: ANN002, ANN003 + pass + + @abstractmethod + def finalize(self) -> None: + pass + + +class ObserverManager: + _instance = None + _observers: list[Observer] = [] + + @classmethod + def get_instance(cls: Type["ObserverManager"]) -> "ObserverManager": + if cls._instance is None: + cls._instance = cls() + return cls._instance + + def 
initialize(self, tracing: bool, observers: list[Observer]) -> None: + from exchange.observers.langfuse import LangfuseObserver + + self._observers = observers + for observer in self._observers: + # LangfuseObserver has special behavior when tracing is _dis_abled. + # Consider refactoring to make this less special-casey if that's common. + if isinstance(observer, LangfuseObserver) and not tracing: + observer.initialize_with_disabled_tracing() + elif tracing: + observer.initialize() + + def finalize(self) -> None: + for observer in self._observers: + observer.finalize() diff --git a/packages/exchange/src/exchange/observers/langfuse.py b/packages/exchange/src/exchange/observers/langfuse.py new file mode 100644 index 00000000..028a5636 --- /dev/null +++ b/packages/exchange/src/exchange/observers/langfuse.py @@ -0,0 +1,98 @@ +""" +Langfuse Observer + +This observer provides integration with Langfuse, a tool for monitoring and tracing LLM applications. + +Usage: + Include "langfuse" in your profile's list of observers to enable Langfuse integration. + It automatically checks for Langfuse credentials in the .env.langfuse file and for a running Langfuse server. + If these are found, it will set up the necessary client and context for tracing. + +Note: + Run setup_langfuse.sh which automates the steps for running local Langfuse. 
+""" + +import logging +import os +import sys +from functools import cache, wraps +from io import StringIO +from typing import Callable + +from langfuse.decorators import langfuse_context + +from exchange.observers.base import Observer + +## These are the default configurations for local Langfuse server +## Please refer to .env.langfuse.local file for local langfuse server setup configurations +DEFAULT_LOCAL_LANGFUSE_HOST = "http://localhost:3000" +DEFAULT_LOCAL_LANGFUSE_PUBLIC_KEY = "publickey-local" +DEFAULT_LOCAL_LANGFUSE_SECRET_KEY = "secretkey-local" + + +@cache +def auth_check() -> bool: + # Temporarily redirect stdout and stderr to suppress print statements from Langfuse + temp_stderr = StringIO() + sys.stderr = temp_stderr + + # Set environment variables if not specified + os.environ.setdefault("LANGFUSE_PUBLIC_KEY", DEFAULT_LOCAL_LANGFUSE_PUBLIC_KEY) + os.environ.setdefault("LANGFUSE_SECRET_KEY", DEFAULT_LOCAL_LANGFUSE_SECRET_KEY) + os.environ.setdefault("LANGFUSE_HOST", DEFAULT_LOCAL_LANGFUSE_HOST) + + auth_val = langfuse_context.auth_check() + + # Restore stderr + sys.stderr = sys.__stderr__ + return auth_val + + +class LangfuseObserver(Observer): + def initialize(self) -> None: + langfuse_auth = auth_check() + if langfuse_auth: + print("Local Langfuse initialized. View your traces at http://localhost:3000") + else: + raise RuntimeError( + "You passed --tracing, but a Langfuse object was not found in the current context. " + "Please initialize the local Langfuse server and restart Goose." 
+ ) + + langfuse_context.configure(enabled=True) + self.tracing = True + + def initialize_with_disabled_tracing(self) -> None: + logging.getLogger("langfuse").setLevel(logging.ERROR) + langfuse_context.configure(enabled=False) + self.tracing = False + + def session_id_wrapper(self, func: Callable, session_id) -> Callable: + @wraps(func) # This will preserve the metadata of 'func' + def wrapper(*args, **kwargs): + langfuse_context.update_current_trace(session_id=session_id) + return func(*args, **kwargs) + return wrapper + + def observe_wrapper(self, *args, **kwargs) -> Callable: # noqa: ANN002, ANN003 + def _wrapper(fn: Callable) -> Callable: + if self.tracing and auth_check(): + + @wraps(fn) + def wrapped_fn(*fargs, **fkwargs) -> Callable: # noqa: ANN002, ANN003 + # group all traces under the same session + if "session_id" in kwargs: + session_id_function = kwargs.pop("session_id") + session_id_value = session_id_function(fargs[0]) + modified_fn = self.session_id_wrapper(fn, session_id_value) + return langfuse_context.observe(*args, **kwargs)(modified_fn)(*fargs, **fkwargs) + else: + return langfuse_context.observe(*args, **kwargs)(fn)(*fargs, **fkwargs) + return wrapped_fn + else: + return fn + + return _wrapper + + def finalize(self) -> None: + langfuse_context.flush() diff --git a/packages/exchange/src/exchange/providers/anthropic.py b/packages/exchange/src/exchange/providers/anthropic.py index 9f4b72d7..59474c72 100644 --- a/packages/exchange/src/exchange/providers/anthropic.py +++ b/packages/exchange/src/exchange/providers/anthropic.py @@ -7,7 +7,7 @@ from exchange.providers.base import Provider, Usage from tenacity import retry, wait_fixed, stop_after_attempt from exchange.providers.utils import retry_if_status, raise_for_status -from exchange.langfuse_wrapper import observe_wrapper +from exchange.observers import observe_wrapper ANTHROPIC_HOST = "https://api.anthropic.com/v1/messages" diff --git a/packages/exchange/src/exchange/providers/bedrock.py 
b/packages/exchange/src/exchange/providers/bedrock.py index cdc0c29c..f766c8d1 100644 --- a/packages/exchange/src/exchange/providers/bedrock.py +++ b/packages/exchange/src/exchange/providers/bedrock.py @@ -15,7 +15,7 @@ from tenacity import retry, wait_fixed, stop_after_attempt from exchange.providers.utils import raise_for_status, retry_if_status from exchange.tool import Tool -from exchange.langfuse_wrapper import observe_wrapper +from exchange.observers import observe_wrapper SERVICE = "bedrock-runtime" UTC = timezone.utc diff --git a/packages/exchange/src/exchange/providers/databricks.py b/packages/exchange/src/exchange/providers/databricks.py index 517ccee6..09564000 100644 --- a/packages/exchange/src/exchange/providers/databricks.py +++ b/packages/exchange/src/exchange/providers/databricks.py @@ -11,7 +11,7 @@ tools_to_openai_spec, ) from exchange.tool import Tool -from exchange.langfuse_wrapper import observe_wrapper +from exchange.observers import observe_wrapper retry_procedure = retry( wait=wait_fixed(2), diff --git a/packages/exchange/src/exchange/providers/google.py b/packages/exchange/src/exchange/providers/google.py index bfb1faf0..ad57396d 100644 --- a/packages/exchange/src/exchange/providers/google.py +++ b/packages/exchange/src/exchange/providers/google.py @@ -7,7 +7,7 @@ from exchange.providers.base import Provider, Usage from tenacity import retry, wait_fixed, stop_after_attempt from exchange.providers.utils import raise_for_status, retry_if_status, encode_image -from exchange.langfuse_wrapper import observe_wrapper +from exchange.observers import observe_wrapper GOOGLE_HOST = "https://generativelanguage.googleapis.com/v1beta" diff --git a/packages/exchange/src/exchange/providers/groq.py b/packages/exchange/src/exchange/providers/groq.py index 0f6472f8..1ca61810 100644 --- a/packages/exchange/src/exchange/providers/groq.py +++ b/packages/exchange/src/exchange/providers/groq.py @@ -1,6 +1,6 @@ import os -from exchange.langfuse_wrapper import 
observe_wrapper +from exchange.observers import observe_wrapper import httpx from exchange.message import Message diff --git a/packages/exchange/src/exchange/providers/openai.py b/packages/exchange/src/exchange/providers/openai.py index 8701e542..61a92ffc 100644 --- a/packages/exchange/src/exchange/providers/openai.py +++ b/packages/exchange/src/exchange/providers/openai.py @@ -14,7 +14,7 @@ from exchange.tool import Tool from tenacity import retry, wait_fixed, stop_after_attempt from exchange.providers.utils import retry_if_status -from exchange.langfuse_wrapper import observe_wrapper +from exchange.observers import observe_wrapper OPENAI_HOST = "https://api.openai.com/" diff --git a/packages/exchange/tests/test_langfuse_wrapper.py b/packages/exchange/tests/test_langfuse_wrapper.py deleted file mode 100644 index 9ac304c5..00000000 --- a/packages/exchange/tests/test_langfuse_wrapper.py +++ /dev/null @@ -1,48 +0,0 @@ -import pytest -from unittest.mock import patch, MagicMock -from exchange.langfuse_wrapper import observe_wrapper - - -@pytest.fixture -def mock_langfuse_context(): - with patch("exchange.langfuse_wrapper.langfuse_context") as mock: - yield mock - - -@patch("exchange.langfuse_wrapper.auth_check") -def test_function_is_wrapped(mock_auth_check, mock_langfuse_context): - mock_observe = MagicMock(side_effect=lambda *args, **kwargs: lambda fn: fn) - mock_auth_check.return_value = True - mock_langfuse_context.observe = mock_observe - - def original_function(x: int, y: int) -> int: - return x + y - - # test function before we decorate it with - # @observe_wrapper("arg1", kwarg1="kwarg1") - assert not hasattr(original_function, "__wrapped__") - - # ensure we args get passed along (e.g. 
@observe(capture_input=False, capture_output=False)) - decorated_function = observe_wrapper("arg1", kwarg1="kwarg1")(original_function) - assert hasattr(decorated_function, "__wrapped__") - assert decorated_function.__wrapped__ is original_function, "Function is not properly wrapped" - - assert decorated_function(2, 3) == 5 - mock_observe.assert_called_once() - mock_observe.assert_called_with("arg1", kwarg1="kwarg1") - - -@patch("exchange.langfuse_wrapper.auth_check") -def test_function_is_not_wrapped(mock_auth_check, mock_langfuse_context): - mock_observe = MagicMock(return_value=lambda f: f) - mock_auth_check.return_value = False - mock_langfuse_context.observe = mock_observe - - @observe_wrapper("arg1", kwarg1="kwarg1") - def hello() -> str: - return "Hello" - - assert not hasattr(hello, "__wrapped__") - assert hello() == "Hello" - - mock_observe.assert_not_called() diff --git a/packages/exchange/tests/test_observer.py b/packages/exchange/tests/test_observer.py new file mode 100644 index 00000000..b0bbb915 --- /dev/null +++ b/packages/exchange/tests/test_observer.py @@ -0,0 +1,61 @@ +from exchange.observers import ObserverManager, observe_wrapper +from exchange.observers.base import Observer + + +class MockObserver(Observer): + def __init__(self): + self.initialized = False + self.args = None + self.kwargs = None + self.finalized = False + + def initialize(self): + pass + + def observe_wrapper(self, *args, **kwargs): + def wrapper(func): + self.args = args + self.kwargs = kwargs + return func + + return wrapper + + def finalize(self): + pass + + +def test_wrapper_is_invoked(): + manager = ObserverManager.get_instance() + mock_observer = MockObserver() + manager.initialize(True, [mock_observer]) + + @observe_wrapper("arg0", arg1="arg2") + def wrapped(x: int, y: int) -> int: + return x + y + + # code in decorator hasn't run yet + assert mock_observer.args is None + assert mock_observer.kwargs is None + + ret_val = wrapped(2, 3) + assert ret_val == 5 + + # 
decorator has been run since `wrapped` was called + assert mock_observer.args == ("arg0",) + assert mock_observer.kwargs == {"arg1": "arg2"} + + +def test_multiple_wrappers(): + manager = ObserverManager.get_instance() + mock_observer_1 = MockObserver() + mock_observer_2 = MockObserver() + manager.initialize(True, [mock_observer_1, mock_observer_2]) + + @observe_wrapper("arg0") + def wrapped(x: int, y: int) -> int: + return x + y + + wrapped(2, 3) + + assert mock_observer_1.args == ("arg0",) + assert mock_observer_2.args == ("arg0",) diff --git a/src/goose/cli/session.py b/src/goose/cli/session.py index e4173d9a..f4cc0ff5 100644 --- a/src/goose/cli/session.py +++ b/src/goose/cli/session.py @@ -1,12 +1,10 @@ -from datetime import datetime -import logging import traceback +from datetime import datetime from pathlib import Path from typing import Optional from exchange import Message, Text, ToolResult, ToolUse -from exchange.langfuse_wrapper import auth_check, observe_wrapper -from langfuse.decorators import langfuse_context +from exchange.observers import ObserverManager, observe_wrapper from rich import print from rich.markdown import Markdown from rich.panel import Panel @@ -79,23 +77,17 @@ def __init__( self.notifier = SessionNotifier(self.status_indicator) self.has_plan = plan is not None self.tracing = tracing - if not tracing: - logging.getLogger("langfuse").setLevel(logging.ERROR) - else: - langfuse_auth = auth_check() - if langfuse_auth: - print("Local Langfuse initialized. View your traces at http://localhost:3000") - else: - raise RuntimeError( - "You passed --tracing, but a Langfuse object was not found in the current context. " - "Please initialize the local Langfuse server and restart Goose." 
- ) - if self.tracing: - langfuse_context.configure(enabled=tracing) self.exchange = create_exchange(profile=load_profile(profile), notifier=self.notifier) setup_logging(log_file_directory=LOG_PATH, log_level=log_level) + all_observers = load_plugins(group="exchange.observer") + profile_observer_names = load_profile(profile).observers + observers_to_init = [all_observers[o.name]() for o in profile_observer_names if o.name in all_observers] + + self.observer_manager = ObserverManager.get_instance() + self.observer_manager.initialize(tracing=tracing, observers=observers_to_init) + self.exchange.messages.extend(self._get_initial_messages()) if len(self.exchange.messages) == 0 and plan: @@ -103,6 +95,10 @@ def __init__( self.prompt_session = GoosePromptSession() + def __del__(self) -> None: + if hasattr(self, "observer_manager"): + self.observer_manager.finalize() + def _get_initial_messages(self) -> list[Message]: messages = self.load_session() @@ -211,12 +207,9 @@ def run(self, new_session: bool = True) -> None: time_end = datetime.now() self._log_cost(start_time=time_start, end_time=time_end) - @observe_wrapper() + @observe_wrapper(session_id=lambda instance: instance.name) def reply(self) -> None: """Reply to the last user message, calling tools as needed""" - # group all traces under the same session - langfuse_context.update_current_trace(session_id=self.name) - # These are the *raw* messages, before the moderator rewrites things committed = [self.exchange.messages[-1]] diff --git a/src/goose/profile.py b/src/goose/profile.py index e6759ccb..343278ea 100644 --- a/src/goose/profile.py +++ b/src/goose/profile.py @@ -13,6 +13,13 @@ class ToolkitSpec: requires: Mapping[str, str] = field(factory=dict) +@define +class ObserverSpec: + """Configuration for an Observer (telemetry plugin)""" + + name: str + + @define class Profile: """The configuration for a run of goose""" @@ -22,6 +29,7 @@ class Profile: accelerator: str moderator: str toolkits: list[ToolkitSpec] = 
field(factory=list, converter=ensure_list(ToolkitSpec)) + observers: list[ObserverSpec] = field(factory=list, converter=ensure_list(ObserverSpec)) @toolkits.validator def check_toolkit_requirements(self, _: type["ToolkitSpec"], toolkits: list[ToolkitSpec]) -> None: @@ -40,8 +48,13 @@ def to_dict(self) -> dict[str, any]: return asdict(self) def profile_info(self) -> str: - tookit_names = [toolkit.name for toolkit in self.toolkits] - return f"provider:{self.provider}, processor:{self.processor} toolkits: {', '.join(tookit_names)}" + toolkit_names = [toolkit.name for toolkit in self.toolkits] + observer_names = [observer.name for observer in self.observers] + return ( + f"provider:{self.provider}, processor:{self.processor} " + f"toolkits: {', '.join(toolkit_names)} " + f"observers: {', '.join(observer_names)}" + ) def default_profile(provider: str, processor: str, accelerator: str, **kwargs: dict[str, any]) -> Profile: @@ -55,4 +68,5 @@ def default_profile(provider: str, processor: str, accelerator: str, **kwargs: d accelerator=accelerator, moderator="synopsis", toolkits=[ToolkitSpec("synopsis")], + observers=[ObserverSpec("langfuse")], ) diff --git a/tests/cli/test_session.py b/tests/cli/test_session.py index 6a4db544..d6cfb758 100644 --- a/tests/cli/test_session.py +++ b/tests/cli/test_session.py @@ -1,10 +1,11 @@ -from datetime import datetime import os +from datetime import datetime from typing import Union from unittest.mock import MagicMock, mock_open, patch import pytest from exchange import Message, ToolResult, ToolUse +from exchange.observers import ObserverManager from goose.cli.prompt.goose_prompt_session import GoosePromptSession from goose.cli.prompt.overwrite_session_prompt import OverwriteSessionPrompt from goose.cli.prompt.user_input import PromptAction, UserInput @@ -260,3 +261,22 @@ def check_overwrite_behavior(choice: str, expected_messages: list[Message]) -> N choice="r", expected_messages=[Message.user(text="duck duck"), 
Message.user(text="goose")], ) + + +def test_observer_plugin_called(create_session_with_mock_configs): + observer_mock = MagicMock() + observe_wrapper_mock = MagicMock() + observer_mock.observe_wrapper = observe_wrapper_mock + + observer_manager_mock = MagicMock(spec=ObserverManager) + observer_manager_mock._observers = [observer_mock] + + with patch("exchange.observers.ObserverManager.get_instance", return_value=observer_manager_mock), patch( + "exchange.Exchange.generate", return_value=Message.assistant("test response") + ): + session = create_session_with_mock_configs({"name": SESSION_NAME}) + + session.exchange.messages.append(Message.user("hi")) + session.reply() + + observe_wrapper_mock.assert_called_once() diff --git a/tests/test_profile.py b/tests/test_profile.py index 3c022f74..b063937c 100644 --- a/tests/test_profile.py +++ b/tests/test_profile.py @@ -1,4 +1,4 @@ -from goose.profile import ToolkitSpec +from goose.profile import ToolkitSpec, ObserverSpec def test_profile_info(profile_factory): @@ -7,6 +7,10 @@ def test_profile_info(profile_factory): "provider": "provider", "processor": "processor", "toolkits": [ToolkitSpec("developer"), ToolkitSpec("github")], + "observers": [ObserverSpec(name="test.plugin")], } ) - assert profile.profile_info() == "provider:provider, processor:processor toolkits: developer, github" + assert ( + profile.profile_info() + == "provider:provider, processor:processor toolkits: developer, github observers: test.plugin" + )