Skip to content

Commit

Permalink
refactor!: move utils that use slack WebClient into dedicated SlackCo…
Browse files Browse the repository at this point in the history
…ntext class

BREAKING CHANGE: SlackContext is now required as a first argument for all handlers
  • Loading branch information
meetbryce committed Oct 12, 2024
1 parent 66a365c commit c0343be
Show file tree
Hide file tree
Showing 12 changed files with 563 additions and 605 deletions.
32 changes: 16 additions & 16 deletions ossai/decorators/catch_error_dm_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from ossai.logging_config import logger
from ossai.utils import get_bot_id, get_direct_message_channel_id
from ossai.slack_context import SlackContext


class SlackPayload(BaseModel):
Expand All @@ -22,8 +22,9 @@ def get_channel_id(self) -> str:

def catch_errors_dm_user(func):
@wraps(func)
async def wrapper(client: WebClient, *args, **kwargs):
assert isinstance(client, WebClient), "client must be a Slack WebClient"
async def wrapper(slack_context: SlackContext, *args, **kwargs):
assert isinstance(slack_context, SlackContext), "slack_context must be a SlackContext"
assert isinstance(slack_context.client, WebClient), "slack_context.client must be a Slack WebClient"

payload = None
if args:
Expand All @@ -38,17 +39,17 @@ async def wrapper(client: WebClient, *args, **kwargs):
# Continue execution even if validation fails

try:
return await func(client, *args, **kwargs)
return await func(slack_context, *args, **kwargs)
except SlackApiError as e:
await _handle_slack_api_error(client, payload, payload_dict, e)
await _handle_slack_api_error(slack_context, payload, payload_dict, e)
except Exception as e:
await _handle_unknown_error(client, payload, payload_dict, e)
await _handle_unknown_error(slack_context, payload, payload_dict, e)

return wrapper


async def _handle_slack_api_error(
client: WebClient,
slack_context: SlackContext,
payload: Optional[SlackPayload],
payload_dict: dict,
error: SlackApiError,
Expand All @@ -57,29 +58,29 @@ async def _handle_slack_api_error(
if error.response["error"] in ("not_in_channel", "channel_not_found"):
user_id = _get_user_id(payload, payload_dict)
channel_id, error_type, error_message = await _handle_channel_error(
client, user_id
slack_context, user_id
)
else:
channel_id = _get_channel_id(payload, payload_dict)
error_type = "Slack API"
error_message = f"Sorry, an unexpected error occurred. `{error.response['error']}`\n\n```{str(error)}```"

user_id = _get_user_id(payload, payload_dict)
await _send_error_message(client, channel_id, user_id, error_type, error_message)
await _send_error_message(slack_context.client, channel_id, user_id, error_type, error_message)


async def _handle_channel_error(client: WebClient, user_id: str):
channel_id = await get_direct_message_channel_id(client, user_id)
async def _handle_channel_error(slack_context: SlackContext, user_id: str):
channel_id = await slack_context.get_direct_message_channel_id(user_id)
error_type = "Not in channel"
bot_id = await get_bot_id(client)
bot_info = client.bots_info(bot=bot_id)
bot_id = await slack_context.get_bot_id()
bot_info = slack_context.client.bots_info(bot=bot_id)
bot_name = bot_info["bot"]["name"]
error_message = f"Sorry, couldn't find the channel. Have you added `@{bot_name}` to the channel?"
return channel_id, error_type, error_message


async def _handle_unknown_error(
client: WebClient,
slack_context: SlackContext,
payload: Optional[SlackPayload],
payload_dict: dict,
error: Exception,
Expand All @@ -89,15 +90,14 @@ async def _handle_unknown_error(
logger.error(f"[Unknown error] {error}.", exc_info=True)
channel_id = _get_channel_id(payload, payload_dict)
user_id = _get_user_id(payload, payload_dict)
await _send_error_message(client, channel_id, user_id, error_type, error_message)
await _send_error_message(slack_context.client, channel_id, user_id, error_type, error_message)


async def _send_error_message(client, channel_id, user_id, error_type, error_message):
logger.debug(
f"running _send_error_message() with {channel_id=} {user_id=} {error_type=} {error_message=}"
)
try:
# ? is this sometime async other times not?
await client.chat_postEphemeral(
channel=channel_id, user=user_id, text=error_message
)
Expand Down
75 changes: 37 additions & 38 deletions ossai/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,14 @@
from ossai.summarizer import Summarizer
from ossai.topic_analysis import analyze_topics_of_history
from ossai.utils import (
get_direct_message_channel_id,
get_workspace_name,
get_channel_history,
get_parsed_messages,
get_user_context,
get_is_private_and_channel_name,
get_text_and_blocks_for_say,
get_since_timeframe_presets,
)
from ossai.slack_context import SlackContext

_custom_prompt_cache = {}

# FIXME: basically, i need to have all handlers take `slack_context` not `client`

def handler_feedback(body):
"""
Expand Down Expand Up @@ -54,22 +50,23 @@ def handler_feedback(body):

@catch_errors_dm_user
async def handler_shortcuts(
client: WebClient, is_private: bool, payload, say, user_id: str
slack_context: SlackContext, is_private: bool, payload, say, user_id: str
):
client = slack_context.client
channel_id = (
payload["channel"]["id"] if payload["channel"]["id"] else payload["channel_id"]
)
dm_channel_id = await get_direct_message_channel_id(client, user_id)
dm_channel_id = await slack_context.get_direct_message_channel_id(user_id)
channel_id_for_say = dm_channel_id if is_private else channel_id
await say(channel=channel_id_for_say, text="...")

response = client.conversations_replies(
response = slack_context.client.conversations_replies(
channel=channel_id, ts=payload["message_ts"]
)
if response["ok"]:
messages = response["messages"]
original_message = messages[0]["text"]
workspace_name = get_workspace_name(client)
workspace_name = slack_context.get_workspace_name()
link = f"https://{workspace_name}.slack.com/archives/{channel_id}/p{payload['message_ts'].replace('.', '')}"

original_message = original_message.split("\n")
Expand All @@ -83,10 +80,10 @@ async def handler_shortcuts(
)

title = f'*Summary of <{link}|{"thread" if len(messages) > 1 else "message"}>:*\n>{thread_hint}\n'
user = await get_user_context(client, user_id)
summarizer = Summarizer()
user = await slack_context.get_user_context(user_id)
summarizer = Summarizer(slack_context)
summary, run_id = summarizer.summarize_slack_messages(
client, messages, channel_id, feature_name="summarize_thread", user=user
messages, channel_id, feature_name="summarize_thread", user=user
)
text, blocks = get_text_and_blocks_for_say(
title=title, run_id=run_id, messages=summary
Expand All @@ -101,24 +98,23 @@ async def handler_shortcuts(

@catch_errors_dm_user
async def handler_tldr_extended_slash_command(
client: WebClient, ack, payload, say, user_id: str
slack_context: SlackContext, ack, payload, say, user_id: str
):
await ack()
client = slack_context.client
channel_name = payload["channel_name"]
channel_id = payload["channel_id"]
dm_channel_id = None

dm_channel_id = await get_direct_message_channel_id(client, user_id)
dm_channel_id = await slack_context.get_direct_message_channel_id(user_id)
await say(channel=dm_channel_id, text="...")

history = await get_channel_history(client, channel_id)
history = await slack_context.get_channel_history(channel_id)
history.reverse()
user = await get_user_context(client, user_id)
user = await slack_context.get_user_context(user_id)
title = f"*Summary of #{channel_name}* (last {len(history)} messages)\n"
custom_prompt = payload.get("text", None)
summarizer = Summarizer(custom_prompt=custom_prompt)
summarizer = Summarizer(slack_context, custom_prompt=custom_prompt)
summary, run_id = summarizer.summarize_slack_messages(
client,
history,
channel_id,
feature_name="summarize_channel_messages",
Expand All @@ -132,19 +128,20 @@ async def handler_tldr_extended_slash_command(

@catch_errors_dm_user
async def handler_topics_slash_command(
client: WebClient, ack, payload, say, user_id: str
slack_context: SlackContext, ack, payload, say, user_id: str
):
await ack()
client = slack_context.client
channel_id = payload["channel_id"]
dm_channel_id = await get_direct_message_channel_id(client, user_id)
dm_channel_id = await slack_context.get_direct_message_channel_id(user_id)
await say(channel=dm_channel_id, text="...")

history = await get_channel_history(client, channel_id)
history = await slack_context.get_channel_history(channel_id)
history.reverse()

messages = get_parsed_messages(client, history, with_names=False)
user = await get_user_context(client, user_id)
is_private, channel_name = get_is_private_and_channel_name(client, channel_id)
messages = slack_context.get_parsed_messages(history, with_names=False)
user = await slack_context.get_user_context(user_id)
is_private, channel_name = slack_context.get_is_private_and_channel_name(channel_id)
custom_prompt = payload.get("text", None)
if custom_prompt:
# todo: add support for custom prompts to /tldr
Expand All @@ -164,10 +161,11 @@ async def handler_topics_slash_command(


@catch_errors_dm_user
async def handler_tldr_since_slash_command(client: WebClient, ack, payload, say):
async def handler_tldr_since_slash_command(slack_context: SlackContext, ack, payload, say):
await ack()
client = slack_context.client
title = "Choose your summary timeframe."
dm_channel_id = await get_direct_message_channel_id(client, payload["user_id"])
dm_channel_id = await slack_context.get_direct_message_channel_id(payload["user_id"])

custom_prompt = payload.get("text", None)

Expand Down Expand Up @@ -207,11 +205,12 @@ async def handler_tldr_since_slash_command(client: WebClient, ack, payload, say)


@catch_errors_dm_user
async def handler_action_summarize_since_date(client: WebClient, ack, body):
async def handler_action_summarize_since_date(slack_context: SlackContext, ack, body):
"""
Provide a message summary of the channel since a given date.
"""
await ack()
client = slack_context.client
channel_name = body["channel"]["name"]
channel_id = body["channel"]["id"]
user_id = body["user"]["id"]
Expand All @@ -227,22 +226,22 @@ async def handler_action_summarize_since_date(client: WebClient, ack, body):
since_date = body["actions"][0]["selected_date"]
since_datetime: datetime = datetime.strptime(since_date, "%Y-%m-%d").date()

dm_channel_id = await get_direct_message_channel_id(client, user_id)
dm_channel_id = await slack_context.get_direct_message_channel_id(user_id)
client.chat_postMessage(channel=dm_channel_id, text="...")

async with ClientSession() as session:
await session.post(body["response_url"], json={"delete_original": "true"})

history = await get_channel_history(client, channel_id, since=since_datetime)
history = await slack_context.get_channel_history(channel_id, since=since_datetime)
history.reverse()
user = await get_user_context(client, user_id)
user = await slack_context.get_user_context(user_id)
custom_prompt = None
if "container" in body and "message_ts" in body["container"]:
key = f"{body['container']['message_ts']}__{user_id}"
custom_prompt = _custom_prompt_cache.get(key, None)
summarizer = Summarizer(custom_prompt=custom_prompt)
summarizer = Summarizer(slack_context, custom_prompt=custom_prompt)
summary, run_id = summarizer.summarize_slack_messages(
client, history, channel_id, feature_name=feature_name, user=user
history, channel_id, feature_name=feature_name, user=user
)
text, blocks = get_text_and_blocks_for_say(
title=f'*Summary of #{channel_name}* since {since_datetime.strftime("%A %b %-d, %Y")} ({len(history)} messages)\n',
Expand All @@ -257,15 +256,15 @@ async def handler_action_summarize_since_date(client: WebClient, ack, body):

@catch_errors_dm_user
async def handler_sandbox_slash_command(
client: WebClient, ack, payload, say, user_id: str
slack_context: SlackContext, ack, payload, say, user_id: str
):
logger.debug(f"Handling /sandbox command")
await ack()
client = slack_context.client
channel_id = payload["channel_id"]
custom_prompt = payload.get("text", None)
summarizer = Summarizer(custom_prompt=custom_prompt)
summarizer = Summarizer(slack_context, custom_prompt=custom_prompt)
summary, run_id = summarizer.summarize_slack_messages(
client,
[
{"text": "bacon", "user": user_id},
{"text": "eggs", "user": user_id},
Expand All @@ -281,4 +280,4 @@ async def handler_sandbox_slash_command(
text, blocks = get_text_and_blocks_for_say(
title=title, run_id=run_id, messages=summary, custom_prompt=custom_prompt
)
return await say(text=text, blocks=blocks)
return await say(text=text, blocks=blocks)
Loading

0 comments on commit c0343be

Please sign in to comment.