Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Optionally specify a max cost that will halt exchange once breached #225

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[project]
name = "goose-ai"
description = "a programming agent that runs on your machine"
version = "0.9.11"
version = "0.9.12"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
Expand Down
23 changes: 17 additions & 6 deletions src/goose/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,14 @@ def get_session_files() -> dict[str, Path]:
@click.option("--plan", type=click.Path(exists=True))
@click.option("--log-level", type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]), default="INFO")
@click.option("--tracing", is_flag=True, required=False)
@click.option("--max-cost", type=int, help="Maximum cost in cents (e.g., 100 = $1.00)", default=None)
def session_start(
name: Optional[str], profile: str, log_level: str, plan: Optional[str] = None, tracing: bool = False
name: Optional[str],
profile: str,
log_level: str,
plan: Optional[str] = None,
tracing: bool = False,
max_cost: Optional[int] = None,
) -> None:
"""Start a new goose session"""
if plan:
Expand All @@ -172,7 +178,9 @@ def session_start(
_plan = None

try:
session = Session(name=name, profile=profile, plan=_plan, log_level=log_level, tracing=tracing)
session = Session(
name=name, profile=profile, plan=_plan, log_level=log_level, tracing=tracing, max_cost=max_cost
)
session.run()
except RuntimeError as e:
print(f"[red]Error: {e}")
Expand Down Expand Up @@ -204,7 +212,8 @@ def session_planned(plan: str, log_level: str, args: Optional[dict[str, str]]) -
@click.argument("name", required=False, shell_complete=autocomplete_session_files)
@click.option("--profile")
@click.option("--log-level", type=LOG_CHOICE, default="INFO")
def session_resume(name: Optional[str], profile: str, log_level: str) -> None:
@click.option("--max-cost", type=int, help="Maximum cost in cents (e.g., 100 = $1.00)", default=None)
def session_resume(name: Optional[str], profile: str, log_level: str, max_cost: Optional[int] = None) -> None:
"""Resume an existing goose session"""
session_files = get_session_files()
if name is None:
Expand All @@ -219,7 +228,7 @@ def session_resume(name: Optional[str], profile: str, log_level: str) -> None:
print(f"Resuming session: {name}")
else:
print(f"Creating new session: {name}")
session = Session(name=name, profile=profile, log_level=log_level)
session = Session(name=name, profile=profile, log_level=log_level, max_cost=max_cost)
session.run(new_session=False)


Expand All @@ -229,12 +238,14 @@ def session_resume(name: Optional[str], profile: str, log_level: str) -> None:
@click.option("--log-level", type=LOG_CHOICE, default="INFO")
@click.option("--resume-session", is_flag=True, help="Resume the last session if available")
@click.option("--tracing", is_flag=True, required=False)
@click.option("--max-cost", type=int, help="Maximum cost in cents (e.g., 100 = $1.00)", default=None)
def run(
message_file: Optional[str],
profile: str,
log_level: str,
resume_session: bool = False,
tracing: bool = False,
max_cost: Optional[int] = None,
) -> None:
"""Run a single-pass session with a message from a markdown input file"""
if message_file:
Expand All @@ -247,9 +258,9 @@ def run(
session_files = get_session_files()
if session_files:
name = list(session_files.keys())[0]
session = Session(name=name, profile=profile, log_level=log_level, tracing=tracing)
session = Session(name=name, profile=profile, log_level=log_level, tracing=tracing, max_cost=max_cost)
else:
session = Session(profile=profile, log_level=log_level, tracing=tracing)
session = Session(profile=profile, log_level=log_level, tracing=tracing, max_cost=max_cost)
session.single_pass(initial_message=initial_message)


Expand Down
39 changes: 38 additions & 1 deletion src/goose/cli/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from goose.cli.session_notifier import SessionNotifier
from goose.profile import Profile
from goose.utils import droid, load_plugins
from goose.utils._cost_calculator import get_total_cost_message
from goose.utils._cost_calculator import calculate_cost, get_total_cost_message
from goose.utils._create_exchange import create_exchange
from goose.utils.session_file import is_empty_session, is_existing_session, log_messages, read_or_create_file

Expand Down Expand Up @@ -67,6 +67,7 @@ def __init__(
plan: Optional[dict] = None,
log_level: Optional[str] = "INFO",
tracing: bool = False,
max_cost: Optional[int] = None,
**kwargs: dict[str, any],
) -> None:
if name is None:
Expand All @@ -92,6 +93,8 @@ def __init__(
if self.tracing:
langfuse_context.configure(enabled=tracing)

self.max_cost = max_cost

self.exchange = create_exchange(profile=load_profile(profile), notifier=self.notifier)
setup_logging(log_file_directory=LOG_PATH, log_level=log_level)

Expand Down Expand Up @@ -213,6 +216,8 @@ def reply(self) -> None:
committed = [self.exchange.messages[-1]]

try:
self._check_cost_not_exceeded()

self.status_indicator.update("processing request")
response = self.exchange.generate()
self.status_indicator.update("got response, processing")
Expand All @@ -229,6 +234,9 @@ def reply(self) -> None:
message = Message(role="user", content=content)
committed.append(message)
self.exchange.add(message)

self._check_cost_not_exceeded()

self.status_indicator.update("processing tool results")
response = self.exchange.generate()
committed.append(response)
Expand All @@ -239,6 +247,11 @@ def reply(self) -> None:
# The interrupt reply modifies the message history,
# and we sync those changes to committed
self.interrupt_reply(committed)
except CostExceededError:
print(
f"[red]The session cost has exceeded the maximum allowed cost of ${self.max_cost/100:.2f}.\n"
+ "To continue, exit and resume the session with a different maximum allowed cost.[/]"
)

# we log the committed messages only once the reply completes
# this prevents messages related to uncaught errors from being recorded
Expand Down Expand Up @@ -286,6 +299,24 @@ def session_file_path(self) -> Path:
def load_session(self) -> list[Message]:
return read_or_create_file(self.session_file_path)

def _check_cost_not_exceeded(self) -> None:
    """Enforce the optional ``max_cost`` budget for this session.

    Adds up ``calculate_cost`` for every model reported by
    ``self.exchange.get_token_usage()``, converts the total to whole cents
    (``max_cost`` is expressed in cents) and compares it with the limit.

    Does nothing when ``self.max_cost`` is ``None``.

    Raises:
        RuntimeError: if ``calculate_cost`` returns ``None`` for a model,
            i.e. no pricing is available, so the budget cannot be enforced.
        CostExceededError: if the rounded cost in cents is greater than or
            equal to ``self.max_cost``.
    """
    # No budget configured: nothing to enforce.
    if self.max_cost is None:
        return

    total_cost = 0
    usage_by_model = self.exchange.get_token_usage()
    for model, usage in usage_by_model.items():
        model_cost = calculate_cost(model, usage)
        if model_cost is None:
            # Without a price we cannot tell whether the budget is respected.
            raise RuntimeError(
                f"Pricing for model {model} not available. Incompatible with --max-cost parameter."
            )
        total_cost += model_cost

    # The limit is an integer number of cents, so compare in integer cents.
    spent_cents = int(round(total_cost * 100, 0))
    if spent_cents < self.max_cost:
        return

    raise CostExceededError(
        f"Session cost ${total_cost:.2f} exceeds maximum allowed cost ${self.max_cost/100:.2f}"
    )

def _log_cost(self, start_time: datetime, end_time: datetime) -> None:
get_logger().info(get_total_cost_message(self.exchange.get_token_usage(), self.name, start_time, end_time))
print(f"[dim]you can view the cost and token usage in the log directory {LOG_PATH}[/]")
Expand Down Expand Up @@ -338,5 +369,11 @@ def _remove_empty_session(self) -> bool:
return False


class CostExceededError(Exception):
    """Signal that a session's accumulated cost has hit its configured maximum.

    Carries no extra state; the message passed to the constructor describes
    the cost and the limit.
    """


if __name__ == "__main__":
session = Session()
4 changes: 2 additions & 2 deletions src/goose/utils/_cost_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
}


def _calculate_cost(model: str, token_usage: Usage) -> Optional[float]:
def calculate_cost(model: str, token_usage: Usage) -> Optional[float]:
model_name = model.lower()
if model_name in PRICES:
input_token_price, output_token_price = PRICES[model_name]
Expand All @@ -50,7 +50,7 @@ def get_total_cost_message(
message = ""
session_name_prefix = f"Session name: {session_name}"
for model, token_usage in token_usages.items():
cost = _calculate_cost(model, token_usage)
cost = calculate_cost(model, token_usage)
if cost is not None:
message += f"{session_name_prefix} | Cost for model {model} {str(token_usage)}: ${cost:.2f}\n"
total_cost += cost
Expand Down
10 changes: 5 additions & 5 deletions tests/cli/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,16 @@ def test_session_start_command_with_session_name(mock_session):
runner = CliRunner()
runner.invoke(goose_cli, ["session", "start", "session1", "--profile", "default"])
mock_session_class.assert_called_once_with(
name="session1", profile="default", plan=None, log_level="INFO", tracing=False
name="session1", profile="default", plan=None, log_level="INFO", tracing=False, max_cost=None
)
mock_session_instance.run.assert_called_once()


def test_session_resume_command_with_session_name(mock_session):
mock_session_class, mock_session_instance = mock_session
runner = CliRunner()
runner.invoke(goose_cli, ["session", "resume", "session1", "--profile", "default"])
mock_session_class.assert_called_once_with(name="session1", profile="default", log_level="INFO")
runner.invoke(goose_cli, ["session", "resume", "session1", "--profile", "default", "--max-cost", "100"])
mock_session_class.assert_called_once_with(name="session1", profile="default", log_level="INFO", max_cost=100)
mock_session_instance.run.assert_called_once()


Expand All @@ -69,7 +69,7 @@ def test_session_resume_command_without_session_name_use_latest_session(

second_file_path = mock_session_files_path / "second.jsonl"
mock_print.assert_called_once_with(f"Resuming most recent session: second from {second_file_path}")
mock_session_class.assert_called_once_with(name="second", profile="default", log_level="INFO")
mock_session_class.assert_called_once_with(name="second", profile="default", log_level="INFO", max_cost=None)
mock_session_instance.run.assert_called_once()


Expand Down Expand Up @@ -131,7 +131,7 @@ def test_combined_group_commands(mock_session):
mock_session_class, mock_session_instance = mock_session
runner = CliRunner()
runner.invoke(cli, ["session", "resume", "session1", "--profile", "default"])
mock_session_class.assert_called_once_with(name="session1", profile="default", log_level="INFO")
mock_session_class.assert_called_once_with(name="session1", profile="default", log_level="INFO", max_cost=None)
mock_session_instance.run.assert_called_once()


Expand Down
15 changes: 15 additions & 0 deletions tests/cli/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import pytest
from exchange import Message, ToolResult, ToolUse
from exchange.providers.base import Usage
from goose.cli.prompt.goose_prompt_session import GoosePromptSession
from goose.cli.prompt.overwrite_session_prompt import OverwriteSessionPrompt
from goose.cli.prompt.user_input import PromptAction, UserInput
Expand Down Expand Up @@ -157,6 +158,20 @@ def test_process_first_message_return_last_exchange_message(create_session_with_
assert len(session.exchange.messages) == 0


def test_reply_does_not_call_exchange_generate_when_cost_exceeded(create_session_with_mock_configs):
    """reply() must not ask the exchange to generate once the budget is already spent."""
    # $2.50 * 1000000 / 1000000 = $2.50, which is above the 200-cent budget set below
    over_budget_usage = Usage(input_tokens=1000000, output_tokens=0, total_tokens=1000000)

    exchange_double = MagicMock()
    exchange_double.get_token_usage.return_value = {"gpt-4o": over_budget_usage}

    session = create_session_with_mock_configs({"max_cost": 200})
    session.exchange = exchange_double

    # log_messages is patched out to avoid serialization issues with MagicMock
    with patch("goose.cli.session.log_messages"):
        session.reply()

    exchange_double.generate.assert_not_called()


def test_log_log_cost(create_session_with_mock_configs):
session = create_session_with_mock_configs()
mock_logger = MagicMock()
Expand Down
4 changes: 2 additions & 2 deletions tests/utils/test_cost_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import pytest
from exchange.providers.base import Usage
from goose.utils._cost_calculator import _calculate_cost, get_total_cost_message
from goose.utils._cost_calculator import calculate_cost, get_total_cost_message

SESSION_NAME = "test_session"
START_TIME = datetime(2024, 10, 20, 1, 2, 3, tzinfo=timezone.utc)
Expand Down Expand Up @@ -32,7 +32,7 @@ def mock_prices():


def test_calculate_cost(mock_prices):
cost = _calculate_cost("gpt-4o", Usage(input_tokens=10000, output_tokens=600, total_tokens=10600))
cost = calculate_cost("gpt-4o", Usage(input_tokens=10000, output_tokens=600, total_tokens=10600))
assert cost == 0.059


Expand Down