diff --git a/ossai/decorators/catch_error_dm_user.py b/ossai/decorators/catch_error_dm_user.py index 2f3fddd..a93b00d 100644 --- a/ossai/decorators/catch_error_dm_user.py +++ b/ossai/decorators/catch_error_dm_user.py @@ -4,7 +4,7 @@ from slack_sdk import WebClient from slack_sdk.errors import SlackApiError from ossai.logging_config import logger -from ossai.utils import get_bot_id, get_direct_message_channel_id +from ossai.slack_context import SlackContext class SlackPayload(BaseModel): @@ -22,8 +22,9 @@ def get_channel_id(self) -> str: def catch_errors_dm_user(func): @wraps(func) - async def wrapper(client: WebClient, *args, **kwargs): - assert isinstance(client, WebClient), "client must be a Slack WebClient" + async def wrapper(slack_context: SlackContext, *args, **kwargs): + assert isinstance(slack_context, SlackContext), "slack_context must be a SlackContext" + assert isinstance(slack_context.client, WebClient), "slack_context.client must be a Slack WebClient" payload = None if args: @@ -38,17 +39,17 @@ async def wrapper(client: WebClient, *args, **kwargs): # Continue execution even if validation fails try: - return await func(client, *args, **kwargs) + return await func(slack_context, *args, **kwargs) except SlackApiError as e: - await _handle_slack_api_error(client, payload, payload_dict, e) + await _handle_slack_api_error(slack_context, payload, payload_dict, e) except Exception as e: - await _handle_unknown_error(client, payload, payload_dict, e) + await _handle_unknown_error(slack_context, payload, payload_dict, e) return wrapper async def _handle_slack_api_error( - client: WebClient, + slack_context: SlackContext, payload: Optional[SlackPayload], payload_dict: dict, error: SlackApiError, @@ -57,7 +58,7 @@ async def _handle_slack_api_error( if error.response["error"] in ("not_in_channel", "channel_not_found"): user_id = _get_user_id(payload, payload_dict) channel_id, error_type, error_message = await _handle_channel_error( - client, user_id + slack_context, user_id ) else: channel_id = _get_channel_id(payload, payload_dict) @@ -65,21 +66,21 @@ async def _handle_slack_api_error( error_message = f"Sorry, an unexpected error occurred. `{error.response['error']}`\n\n```{str(error)}```" user_id = _get_user_id(payload, payload_dict) - await _send_error_message(client, channel_id, user_id, error_type, error_message) + await _send_error_message(slack_context.client, channel_id, user_id, error_type, error_message) -async def _handle_channel_error(client: WebClient, user_id: str): - channel_id = await get_direct_message_channel_id(client, user_id) +async def _handle_channel_error(slack_context: SlackContext, user_id: str): + channel_id = await slack_context.get_direct_message_channel_id(user_id) error_type = "Not in channel" - bot_id = await get_bot_id(client) - bot_info = client.bots_info(bot=bot_id) + bot_id = await slack_context.get_bot_id() + bot_info = slack_context.client.bots_info(bot=bot_id) bot_name = bot_info["bot"]["name"] error_message = f"Sorry, couldn't find the channel. Have you added `@{bot_name}` to the channel?" return channel_id, error_type, error_message async def _handle_unknown_error( - client: WebClient, + slack_context: SlackContext, payload: Optional[SlackPayload], payload_dict: dict, error: Exception, @@ -89,7 +90,7 @@ async def _handle_unknown_error( logger.error(f"[Unknown error] {error}.", exc_info=True) channel_id = _get_channel_id(payload, payload_dict) user_id = _get_user_id(payload, payload_dict) - await _send_error_message(client, channel_id, user_id, error_type, error_message) + await _send_error_message(slack_context.client, channel_id, user_id, error_type, error_message) async def _send_error_message(client, channel_id, user_id, error_type, error_message): @@ -97,7 +98,6 @@ async def _send_error_message(client, channel_id, user_id, error_type, error_mes f"running _send_error_message() with {channel_id=} {user_id=} {error_type=} {error_message=}" ) try: - # ? is this sometime async other times not? await client.chat_postEphemeral( channel=channel_id, user=user_id, text=error_message ) diff --git a/ossai/handlers.py b/ossai/handlers.py index f3f157c..010ba3f 100644 --- a/ossai/handlers.py +++ b/ossai/handlers.py @@ -12,18 +12,14 @@ from ossai.summarizer import Summarizer from ossai.topic_analysis import analyze_topics_of_history from ossai.utils import ( - get_direct_message_channel_id, - get_workspace_name, - get_channel_history, - get_parsed_messages, - get_user_context, - get_is_private_and_channel_name, get_text_and_blocks_for_say, get_since_timeframe_presets, ) +from ossai.slack_context import SlackContext _custom_prompt_cache = {} +# FIXME: basically, i need to have all handlers take `slack_context` not `client` def handler_feedback(body): """ @@ -54,22 +50,23 @@ def handler_feedback(body): @catch_errors_dm_user async def handler_shortcuts( - client: WebClient, is_private: bool, payload, say, user_id: str + slack_context: SlackContext, is_private: bool, payload, say, user_id: str ): + client = slack_context.client channel_id = ( payload["channel"]["id"] if payload["channel"]["id"] else payload["channel_id"] ) - dm_channel_id = await get_direct_message_channel_id(client, user_id) + dm_channel_id = await slack_context.get_direct_message_channel_id(user_id) channel_id_for_say = dm_channel_id if is_private else channel_id await say(channel=channel_id_for_say, text="...") - response = client.conversations_replies( + response = slack_context.client.conversations_replies( channel=channel_id, ts=payload["message_ts"] ) if response["ok"]: messages = response["messages"] original_message = messages[0]["text"] - workspace_name = get_workspace_name(client) + workspace_name = slack_context.get_workspace_name() link = f"https://{workspace_name}.slack.com/archives/{channel_id}/p{payload['message_ts'].replace('.', '')}" original_message = original_message.split("\n") @@ -83,10 +80,10 @@ async def handler_shortcuts( ) title = f'*Summary of <{link}|{"thread" if len(messages) > 1 else "message"}>:*\n>{thread_hint}\n' - user = await get_user_context(client, user_id) - summarizer = Summarizer() + user = await slack_context.get_user_context(user_id) + summarizer = Summarizer(slack_context) summary, run_id = summarizer.summarize_slack_messages( - client, messages, channel_id, feature_name="summarize_thread", user=user + messages, channel_id, feature_name="summarize_thread", user=user ) text, blocks = get_text_and_blocks_for_say( title=title, run_id=run_id, messages=summary @@ -101,24 +98,23 @@ async def handler_shortcuts( @catch_errors_dm_user async def handler_tldr_extended_slash_command( - client: WebClient, ack, payload, say, user_id: str + slack_context: SlackContext, ack, payload, say, user_id: str ): await ack() + client = slack_context.client channel_name = payload["channel_name"] channel_id = payload["channel_id"] - dm_channel_id = None - dm_channel_id = await get_direct_message_channel_id(client, user_id) + dm_channel_id = await slack_context.get_direct_message_channel_id(user_id) await say(channel=dm_channel_id, text="...") - history = await get_channel_history(client, channel_id) + history = await slack_context.get_channel_history(channel_id) history.reverse() - user = await get_user_context(client, user_id) + user = await slack_context.get_user_context(user_id) title = f"*Summary of #{channel_name}* (last {len(history)} messages)\n" custom_prompt = payload.get("text", None) - summarizer = Summarizer(custom_prompt=custom_prompt) + summarizer = Summarizer(slack_context, custom_prompt=custom_prompt) summary, run_id = summarizer.summarize_slack_messages( - client, history, channel_id, feature_name="summarize_channel_messages", @@ -132,19 +128,20 @@ async def handler_tldr_extended_slash_command( @catch_errors_dm_user async def handler_topics_slash_command( - client: WebClient, ack, payload, say, user_id: str + slack_context: SlackContext, ack, payload, say, user_id: str ): await ack() + client = slack_context.client channel_id = payload["channel_id"] - dm_channel_id = await get_direct_message_channel_id(client, user_id) + dm_channel_id = await slack_context.get_direct_message_channel_id(user_id) await say(channel=dm_channel_id, text="...") - history = await get_channel_history(client, channel_id) + history = await slack_context.get_channel_history(channel_id) history.reverse() - messages = get_parsed_messages(client, history, with_names=False) - user = await get_user_context(client, user_id) - is_private, channel_name = get_is_private_and_channel_name(client, channel_id) + messages = slack_context.get_parsed_messages(history, with_names=False) + user = await slack_context.get_user_context(user_id) + is_private, channel_name = slack_context.get_is_private_and_channel_name(channel_id) custom_prompt = payload.get("text", None) if custom_prompt: # todo: add support for custom prompts to /tldr @@ -164,10 +161,11 @@ async def handler_topics_slash_command( @catch_errors_dm_user -async def handler_tldr_since_slash_command(client: WebClient, ack, payload, say): +async def handler_tldr_since_slash_command(slack_context: SlackContext, ack, payload, say): await ack() + client = slack_context.client title = "Choose your summary timeframe." - dm_channel_id = await get_direct_message_channel_id(client, payload["user_id"]) + dm_channel_id = await slack_context.get_direct_message_channel_id(payload["user_id"]) custom_prompt = payload.get("text", None) @@ -207,11 +205,12 @@ async def handler_tldr_since_slash_command(client: WebClient, ack, payload, say) @catch_errors_dm_user -async def handler_action_summarize_since_date(client: WebClient, ack, body): +async def handler_action_summarize_since_date(slack_context: SlackContext, ack, body): """ Provide a message summary of the channel since a given date. """ await ack() + client = slack_context.client channel_name = body["channel"]["name"] channel_id = body["channel"]["id"] user_id = body["user"]["id"] @@ -227,22 +226,22 @@ async def handler_action_summarize_since_date(client: WebClient, ack, body): since_date = body["actions"][0]["selected_date"] since_datetime: datetime = datetime.strptime(since_date, "%Y-%m-%d").date() - dm_channel_id = await get_direct_message_channel_id(client, user_id) + dm_channel_id = await slack_context.get_direct_message_channel_id(user_id) client.chat_postMessage(channel=dm_channel_id, text="...") async with ClientSession() as session: await session.post(body["response_url"], json={"delete_original": "true"}) - history = await get_channel_history(client, channel_id, since=since_datetime) + history = await slack_context.get_channel_history(channel_id, since=since_datetime) history.reverse() - user = await get_user_context(client, user_id) + user = await slack_context.get_user_context(user_id) custom_prompt = None if "container" in body and "message_ts" in body["container"]: key = f"{body['container']['message_ts']}__{user_id}" custom_prompt = _custom_prompt_cache.get(key, None) - summarizer = Summarizer(custom_prompt=custom_prompt) + summarizer = Summarizer(slack_context, custom_prompt=custom_prompt) summary, run_id = summarizer.summarize_slack_messages( - client, history, channel_id, feature_name=feature_name, user=user + history, channel_id, feature_name=feature_name, user=user ) text, blocks = get_text_and_blocks_for_say( title=f'*Summary of #{channel_name}* since {since_datetime.strftime("%A %b %-d, %Y")} ({len(history)} messages)\n', @@ -257,15 +256,15 @@ async def handler_action_summarize_since_date(client: WebClient, ack, body): @catch_errors_dm_user async def handler_sandbox_slash_command( - client: WebClient, ack, payload, say, user_id: str + slack_context: SlackContext, ack, payload, say, user_id: str ): logger.debug(f"Handling /sandbox command") await ack() + client = slack_context.client channel_id = payload["channel_id"] custom_prompt = payload.get("text", None) - summarizer = Summarizer(custom_prompt=custom_prompt) + summarizer = Summarizer(slack_context, custom_prompt=custom_prompt) summary, run_id = summarizer.summarize_slack_messages( - client, [ {"text": "bacon", "user": user_id}, {"text": "eggs", "user": user_id}, @@ -281,4 +280,4 @@ async def handler_sandbox_slash_command( text, blocks = get_text_and_blocks_for_say( title=title, run_id=run_id, messages=summary, custom_prompt=custom_prompt ) - return await say(text=text, blocks=blocks) + return await say(text=text, blocks=blocks) \ No newline at end of file diff --git a/ossai/slack_context.py b/ossai/slack_context.py new file mode 100644 index 0000000..f40a64d --- /dev/null +++ b/ossai/slack_context.py @@ -0,0 +1,146 @@ +import os +import re +from time import mktime +from datetime import date + +from slack_sdk import WebClient +from slack_sdk.errors import SlackApiError + +from ossai.logging_config import logger + +class SlackContext: + def __init__(self, client: WebClient): + self.client = client + self._id_name_cache = {} + self._bot_id = None + self._workspace_name = None + + async def get_bot_id(self) -> str: + # todo: refactor this to be an attribute getter i.e. slack_context.bot_id + if self._bot_id is None: + try: + response = self.client.auth_test() + self._bot_id = response["bot_id"] + except SlackApiError as e: + logger.error(f"Error fetching bot ID: {e.response['error']}") + self._bot_id = "None" + return self._bot_id + + async def get_channel_history( + self, + channel_id: str, + since: date = None, + include_threads: bool = False, + ) -> list: + oldest_timestamp = mktime(since.timetuple()) if since else 0 + response = self.client.conversations_history( + channel=channel_id, limit=1000, oldest=oldest_timestamp + ) + bot_id = await self.get_bot_id() + return [msg for msg in response["messages"] if msg.get("bot_id") != bot_id] + + async def get_direct_message_channel_id(self, user_id: str) -> str: + try: + response = self.client.conversations_open(users=user_id) + return response["channel"]["id"] + except SlackApiError as e: + logger.error(f"Error fetching bot DM channel ID: {e.response['error']}") + raise e + + def get_is_private_and_channel_name(self, channel_id: str) -> tuple[bool, str]: + try: + channel_info = self.client.conversations_info(channel=channel_id) + channel_name = channel_info["channel"]["name"] + is_private = channel_info["channel"]["is_private"] + except Exception as e: + logger.error( + f"Error getting channel info for is_private, defaulting to private: {e}" + ) + channel_name = "unknown" + is_private = True + return is_private, channel_name + + def get_name_from_id(self, user_or_bot_id: str, is_bot=False) -> str: + if user_or_bot_id in self._id_name_cache: + return self._id_name_cache[user_or_bot_id] + + try: + user_response = self.client.users_info(user=user_or_bot_id) + if user_response.get("ok"): + name = user_response["user"].get( + "real_name", user_response["user"]["profile"]["real_name"] + ) + self._id_name_cache[user_or_bot_id] = name + return name + else: + logger.error("user fetch failed") + raise SlackApiError("user fetch failed", user_response) + except SlackApiError as e: + if e.response["error"] == "user_not_found": + try: + bot_response = self.client.bots_info(bot=user_or_bot_id) + if bot_response.get("ok"): + self._id_name_cache[user_or_bot_id] = bot_response["bot"]["name"] + return bot_response["bot"]["name"] + else: + logger.error("bot fetch failed") + raise SlackApiError("bot fetch failed", bot_response) + except SlackApiError as e2: + logger.error( + f"Error fetching name for bot {user_or_bot_id=}: {e2.response['error']}" + ) + logger.error(f"Error fetching name for {user_or_bot_id=} {is_bot=} {e=}") + + return "Someone" + + def get_parsed_messages(self, messages, with_names=True): + def parse_message(msg): + user_id = msg.get("user") + if user_id is None: + bot_id = msg.get("bot_id") + name = self.get_name_from_id(bot_id, is_bot=True) + else: + name = self.get_name_from_id(user_id) + + parsed_message = re.sub( + r"<@[UB]\w+>", + lambda m: self.get_name_from_id(m.group(0)[2:-1]), + msg["text"], + ) + + if not with_names: + return re.sub(r"<@[UB]\w+>", lambda m: "", msg["text"]) + + return f"{name}: {parsed_message}" + + return [parse_message(message) for message in messages] + + async def get_user_context(self, user_id: str) -> dict: + try: + user_info = self.client.users_info(user=user_id) + logger.debug(user_info) + if user_info["ok"]: + name = user_info["user"]["name"] + title = user_info["user"]["profile"]["title"] + return {"name": name, "title": title} + except SlackApiError as e: + logger.error(f"Failed to fetch username: {e}") + return {} + + def get_workspace_name(self): + # todo: refactor this to be an attribute getter i.e. slack_context.workspace_name + if self._workspace_name is None: + try: + response = self.client.team_info() + if response["ok"]: + self._workspace_name = response["team"]["name"] + else: + logger.warning( + f"Error retrieving workspace name: {response['error']}. Falling back to WORKSPACE_NAME_FALLBACK." + ) + self._workspace_name = os.getenv("WORKSPACE_NAME_FALLBACK", "") + except SlackApiError as e: + logger.error(f"Error retrieving workspace name: {e.response['error']}") + self._workspace_name = os.getenv("WORKSPACE_NAME_FALLBACK", "") + return self._workspace_name + diff --git a/ossai/slack_server.py b/ossai/slack_server.py index 0a72f60..2092de1 100644 --- a/ossai/slack_server.py +++ b/ossai/slack_server.py @@ -8,6 +8,7 @@ from slack_bolt.adapter.socket_mode.aiohttp import AsyncSocketModeHandler from slack_bolt.async_app import AsyncApp from slack_sdk import WebClient +from ossai.slack_context import SlackContext load_dotenv(override=True) @@ -88,27 +89,29 @@ async def slack_events(request: Request): @async_app.command("/tldr_extended") async def handle_tldr_extended_slash_command(ack, payload, say): return await handler_tldr_extended_slash_command( - client, ack, payload, say, user_id=payload["user_id"] + SlackContext(client), ack, payload, say, user_id=payload["user_id"] ) @async_app.command("/tldr") async def handle_slash_command_topics(ack, payload, say): return await handler_topics_slash_command( - client, ack, payload, say, user_id=payload["user_id"] + SlackContext(client), ack, payload, say, user_id=payload["user_id"] ) @async_app.command("/sandbox") async def handle_slash_command_sandbox(ack, payload, say): return await handler_sandbox_slash_command( - client, ack, payload, say, user_id=payload["user_id"] + SlackContext(client), ack, payload, say, user_id=payload["user_id"] ) @async_app.command("/tldr_since") async def handle_slash_command_tldr_since(ack, payload, say): - return await handler_tldr_since_slash_command(client, ack, payload, say) + return await handler_tldr_since_slash_command( + SlackContext(client), ack, payload, say + ) # MARK: - ACTIONS @@ -118,7 +121,7 @@ async def handle_slash_command_tldr_since(ack, payload, say): @async_app.action("summarize_since_preset") async def handle_action_summarize_since_date(ack, body, logger): await ack() - await handler_action_summarize_since_date(client, ack, body) + await handler_action_summarize_since_date(SlackContext(client), ack, body) return logger.info(body) @@ -137,13 +140,17 @@ async def handle_feedback(ack, body, logger): @async_app.shortcut("thread") async def handle_thread_shortcut(ack, payload, say): await ack() - await handler_shortcuts(client, False, payload, say, user_id=payload["user"]["id"]) + await handler_shortcuts( + SlackContext(client), False, payload, say, user_id=payload["user"]["id"] + ) @async_app.shortcut("thread_private") async def handle_thread_private_shortcut(ack, payload, say): await ack() - await handler_shortcuts(client, True, payload, say, user_id=payload["user"]["id"]) + await handler_shortcuts( + SlackContext(client), True, payload, say, user_id=payload["user"]["id"] + ) def main(): diff --git a/ossai/summarizer.py b/ossai/summarizer.py index 9e1008b..2b29648 100644 --- a/ossai/summarizer.py +++ b/ossai/summarizer.py @@ -11,18 +11,17 @@ from ossai.logging_config import logger from ossai.utils import ( - get_parsed_messages, get_langsmith_config, get_llm_config, - get_is_private_and_channel_name, ) - +from ossai.slack_context import SlackContext load_dotenv(override=True) class Summarizer: - def __init__(self, custom_prompt: Optional[str] = None): + def __init__(self, slack_context: SlackContext, custom_prompt: Optional[str] = None): # todo: apply pydantic model + self.slack_context = slack_context self.config = get_llm_config() self.model = ChatOpenAI( model=self.config["chat_model"], temperature=self.config["temperature"] @@ -52,7 +51,7 @@ def summarize( tuple[str, UUID]: The summarized chat log in bullet point format and the run ID. Examples: - # >>> summarizer = Summarizer() + # >>> summarizer = Summarizer(slack_context) # >>> summarizer.summarize("Alice: Hi\nBob: Hello\nAlice: How are you?\nBob: I'm doing well, thanks.", "test", "user1", "general") '- Alice greeted Bob.\n- Bob responded with a greeting.\n- Alice asked how Bob was doing. \n- Bob replied that he was doing well.', UUID('...') @@ -148,19 +147,18 @@ def counter(tok): return sum(map(counter, matches)) def split_messages_by_token_count( - self, client, messages: list[dict] + self, messages: list[dict] ) -> list[list[str]]: """ Split a list of strings into sub lists with a maximum token count. Args: - client: The Slack client. messages (list[dict]): A list of Slack messages to be split. Returns: list[list[str]]: A list of sub lists, where each sublist has a token count less than or equal to max_body_tokens """ - parsed_messages = get_parsed_messages(client, messages) + parsed_messages = self.slack_context.get_parsed_messages(messages) body_token_counts = [ self.estimate_openai_chat_token_count(msg) for msg in parsed_messages @@ -183,7 +181,6 @@ def split_messages_by_token_count( def summarize_slack_messages( self, - client, messages: list, channel_id: str, feature_name: str, @@ -197,7 +194,6 @@ def summarize_slack_messages( The summary is returned as a list, with the context message as the first element. Args: - client: The slack client. messages (list): A list of slack messages to be summarized. channel_id (str): The ID of the Slack channel. feature_name (str): The name of the feature being used. @@ -207,9 +203,9 @@ def summarize_slack_messages( tuple[list, UUID]: A list of summary text and the run ID. """ # Determine if the channel is private - is_private, channel_name = get_is_private_and_channel_name(client, channel_id) + is_private, channel_name = self.slack_context.get_is_private_and_channel_name(channel_id) - message_splits = self.split_messages_by_token_count(client, messages) + message_splits = self.split_messages_by_token_count(messages) logger.info(f"{len(message_splits)=}") result_text = [] diff --git a/ossai/utils.py b/ossai/utils.py index 5a408df..ca746d5 100644 --- a/ossai/utils.py +++ b/ossai/utils.py @@ -1,21 +1,15 @@ import os -import re import uuid -from time import mktime, gmtime, strptime +from time import gmtime, strptime import calendar from typing import Union -from datetime import date from dotenv import load_dotenv -from slack_sdk import WebClient -from slack_sdk.errors import SlackApiError from langchain.callbacks.tracers import LangChainTracer from ossai.logging_config import logger load_dotenv(override=True) -_id_name_cache = {} - class CustomLangChainTracer(LangChainTracer): def __init__(self, is_private=False, *args, **kwargs): @@ -30,70 +24,6 @@ def handleText(self, text, runId): logger.info("passing no text") super().handleText("", runId) - -async def get_bot_id(client) -> str: - """ - Retrieves the bot ID using the provided Slack WebClient. - - Returns: - str: The bot ID. - """ - try: - response = client.auth_test() - return response["bot_id"] - except SlackApiError as e: - logger.error(f"Error fetching bot ID: {e.response['error']}") - return "None" - - -async def get_channel_history( - client: WebClient, - channel_id: str, - since: date = None, - include_threads: bool = False, -) -> list: - # todo: if include_threads, recursively get messages from threads - - oldest_timestamp = mktime(since.timetuple()) if since else 0 - response = client.conversations_history( - channel=channel_id, limit=1000, oldest=oldest_timestamp - ) # 1000 is the max limit - bot_id = await get_bot_id(client) - # todo: (optional) excluding all other bots too - # todo: (optional) exclude messages that start with `/` (i.e. slash commands) - return [msg for msg in response["messages"] if msg.get("bot_id") != bot_id] - - -async def get_direct_message_channel_id(client: WebClient, user_id: str) -> str: - """ - Get the direct message channel ID for the bot, so you can say() via direct message. - :return str: - """ - # todo: cache this sucker too - try: - response = client.conversations_open(users=user_id) - return response["channel"]["id"] - except SlackApiError as e: - logger.error(f"Error fetching bot DM channel ID: {e.response['error']}") - raise e - - -def get_is_private_and_channel_name( - client: WebClient, channel_id: str -) -> tuple[bool, str]: - try: - channel_info = client.conversations_info(channel=channel_id) - channel_name = channel_info["channel"]["name"] - is_private = channel_info["channel"]["is_private"] - except Exception as e: - logger.error( - f"Error getting channel info for is_private, defaulting to private: {e}" - ) - channel_name = "unknown" - is_private = True - return is_private, channel_name - - def get_langsmith_config(feature_name: str, user: dict, channel: str, is_private=False): run_id = str(uuid.uuid4()) tracer = CustomLangChainTracer( @@ -112,7 +42,6 @@ def get_langsmith_config(feature_name: str, user: dict, channel: str, is_private "callbacks": [tracer], } - def get_llm_config(): chat_model = os.getenv("CHAT_MODEL", "gpt-3.5-turbo").strip() temperature = float(os.getenv("TEMPERATURE", 0.2)) @@ -132,78 +61,6 @@ def get_llm_config(): "language": language, } - -def get_name_from_id(client: WebClient, user_or_bot_id: str, is_bot=False) -> str: - """ - Retrieves the name associated with a user ID or bot ID. - - Args: - client (WebClient): An instance of the Slack WebClient. - user_or_bot_id (str): The user or bot ID. - is_bot (bool): Whether the ID is a bot ID. - - Returns: - str: The name associated with the ID. - """ - if user_or_bot_id in _id_name_cache: - return _id_name_cache[user_or_bot_id] - - try: - user_response = client.users_info(user=user_or_bot_id) - if user_response.get("ok"): - name = user_response["user"].get( - "real_name", user_response["user"]["profile"]["real_name"] - ) - _id_name_cache[user_or_bot_id] = name - return name - else: - logger.error("user fetch failed") - raise SlackApiError("user fetch failed", user_response) - except SlackApiError as e: - if e.response["error"] == "user_not_found": - try: - bot_response = client.bots_info(bot=user_or_bot_id) - if bot_response.get("ok"): - _id_name_cache[user_or_bot_id] = bot_response["bot"]["name"] - return bot_response["bot"]["name"] - else: - logger.error("bot fetch failed") - raise SlackApiError("bot fetch failed", bot_response) - except SlackApiError as e2: - logger.error( - f"Error fetching name for bot {user_or_bot_id=}: {e2.response['error']}" - ) - logger.error(f"Error fetching name for {user_or_bot_id=} {is_bot=} {e=}") - - return "Someone" - - -def get_parsed_messages(client, messages, with_names=True): - def parse_message(msg): - user_id = msg.get("user") - if user_id is None: - bot_id = msg.get("bot_id") - name = get_name_from_id(client, bot_id, is_bot=True) - else: - name = get_name_from_id(client, user_id) - - # substitute @mentions with names - parsed_message = re.sub( - r"<@[UB]\w+>", - lambda m: get_name_from_id(client, m.group(0)[2:-1]), - msg["text"], - ) - - if not with_names: - return re.sub( - r"<@[UB]\w+>", lambda m: "", msg["text"] - ) # remove @mentions + don't prepend author name - - return f"{name}: {parsed_message}" - - return [parse_message(message) for message in messages] - - def get_text_and_blocks_for_say( title: str, run_id: Union[uuid.UUID, None], @@ -281,48 +138,6 @@ def get_text_and_blocks_for_say( return text.split("\n")[0], blocks - -async def get_user_context(client: WebClient, user_id: str) -> dict: - """ - Get the username and title for the given user ID. - """ - try: - user_info = client.users_info(user=user_id) - logger.debug(user_info) - if user_info["ok"]: - name = user_info["user"]["name"] - title = user_info["user"]["profile"]["title"] - return {"name": name, "title": title} - except SlackApiError as e: - logger.error(f"Failed to fetch username: {e}") - return {} - - -def get_workspace_name(client: WebClient): - """ - Retrieve the workspace name using an instantiated Slack WebClient. - - Args: - - client (WebClient): An instantiated Slack WebClient. - - Returns: - - str: The workspace name if found, otherwise an empty string. - """ - - try: - response = client.team_info() - if response["ok"]: - return response["team"]["name"] - else: - logger.warning( - f"Error retrieving workspace name: {response['error']}. Falling back to WORKSPACE_NAME_FALLBACK." - ) - return os.getenv("WORKSPACE_NAME_FALLBACK", "") - except SlackApiError as e: - logger.error(f"Error retrieving workspace name: {e.response['error']}") - return os.getenv("WORKSPACE_NAME_FALLBACK", "") # None - - def get_since_timeframe_presets(): DAY_OF_SECONDS = 86400 now = gmtime() @@ -372,10 +187,8 @@ def get_since_timeframe_presets(): ], } - def main(): logger.error("DEBUGGING") - if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/tests/decorators/test_catch_errors_and_dm_user.py b/tests/decorators/test_catch_errors_and_dm_user.py index 4c9eb7c..34d50b3 100644 --- a/tests/decorators/test_catch_errors_and_dm_user.py +++ b/tests/decorators/test_catch_errors_and_dm_user.py @@ -2,13 +2,16 @@ from unittest.mock import AsyncMock, MagicMock, patch from slack_sdk import WebClient from slack_sdk.errors import SlackApiError +from ossai import slack_context from ossai.decorators.catch_error_dm_user import catch_errors_dm_user +from ossai.slack_context import SlackContext @pytest.mark.asyncio async def test_catch_errors_dm_user_happy_path(): # Setup - client = AsyncMock(spec=WebClient) + slack_context = AsyncMock(spec=SlackContext) + slack_context.client = AsyncMock(spec=WebClient) mock_func = AsyncMock() mock_func.return_value = "Success" decorated_func = catch_errors_dm_user(mock_func) @@ -19,30 +22,31 @@ async def test_catch_errors_dm_user_happy_path(): # Execute result = await decorated_func( - client, mock_ack, mock_payload, user_id="U123", arg1="test", arg2=123 + slack_context, mock_ack, mock_payload, user_id="U123", arg1="test", arg2=123 ) # Verify assert result == "Success" mock_func.assert_called_once_with( - client, mock_ack, mock_payload, user_id="U123", arg1="test", arg2=123 + slack_context, mock_ack, mock_payload, user_id="U123", arg1="test", arg2=123 ) assert mock_payload["channel_id"] == "C123" - client.chat_postEphemeral.assert_not_called() + slack_context.client.chat_postEphemeral.assert_not_called() @pytest.mark.asyncio @patch("ossai.decorators.catch_error_dm_user.logger") async def test_catch_errors_dm_user_error_handling(mock_logger): # Setup - client = AsyncMock(spec=WebClient) + slack_context = AsyncMock(spec=SlackContext) + slack_context.client = AsyncMock(spec=WebClient) mock_func = AsyncMock() mock_func.side_effect = SlackApiError( message="Pineapple on pizza error", response={"error": "API error"} ) decorated_func = catch_errors_dm_user(mock_func) - client.chat_postEphemeral = AsyncMock() + slack_context.client.chat_postEphemeral = AsyncMock() # Create a mock payload with channel_id and a mock ack function mock_payload = {"channel_id": "C123", "user_id": "U123"} @@ -50,12 +54,12 @@ async def test_catch_errors_dm_user_error_handling(mock_logger): # Execute await decorated_func( - client, mock_ack, mock_payload, user_id="U123", arg1="test", arg2=123 + slack_context, mock_ack, mock_payload, user_id="U123", arg1="test", arg2=123 ) # Verify - client.chat_postEphemeral.assert_called_once() - call_args = client.chat_postEphemeral.call_args + slack_context.client.chat_postEphemeral.assert_called_once() + call_args = slack_context.client.chat_postEphemeral.call_args assert call_args.kwargs["channel"] == "C123" assert call_args.kwargs["user"] == "U123" assert ( diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 4d4e400..2926660 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -4,6 +4,7 @@ from slack_sdk import WebClient from slack_sdk.errors import SlackApiError from datetime import datetime, timezone +from ossai.slack_context import SlackContext from ossai.handlers import ( handler_sandbox_slash_command, @@ -16,6 +17,21 @@ ) +@pytest.fixture +def mock_slack_context(): + mock = MagicMock(spec=SlackContext) + mock.get_bot_id = AsyncMock(return_value="B12345") + mock.get_channel_history = AsyncMock(return_value=[]) + mock.get_direct_message_channel_id = AsyncMock(return_value="D12345") + mock.get_is_private_and_channel_name = MagicMock(return_value=(False, "general")) + mock.get_name_from_id = MagicMock(return_value="John Doe") + mock.get_parsed_messages = MagicMock(return_value=["John: Hello", "Jane: Hi"]) + mock.get_user_context = AsyncMock(return_value={"name": "John", "title": "Developer"}) + mock.get_workspace_name = MagicMock(return_value="My Workspace") + mock.client = AsyncMock(spec=WebClient) + return mock + + @pytest.fixture def client(): return WebClient() @@ -74,76 +90,60 @@ async def custom_say(**kwargs): @pytest.mark.asyncio -@patch("ossai.handlers.get_direct_message_channel_id") async def test_handler_shortcuts( - get_direct_message_channel_id_mock, client, payload, say + mock_slack_context, payload, say ): - get_direct_message_channel_id_mock.return_value = "dm_channel_id" - await handler_shortcuts(client, True, payload, say, user_id="foo123") + mock_slack_context.client.get_direct_message_channel_id.return_value = "dm_channel_id" + await handler_shortcuts(mock_slack_context, True, payload, say, user_id="foo123") say.assert_called() @pytest.mark.asyncio -@patch("ossai.handlers.get_direct_message_channel_id") async def test_handler_tldr_extended_slash_command_channel_history_error( - get_direct_message_channel_id_mock, client, payload, say + mock_slack_context, payload, say ): - get_direct_message_channel_id_mock.return_value = "dm_channel_id" + mock_slack_context.get_direct_message_channel_id.return_value = "dm_channel_id" await handler_tldr_extended_slash_command( - client, AsyncMock(), payload, say, user_id="foo123" + mock_slack_context, AsyncMock(), payload, say, user_id="foo123" ) say.assert_called() @pytest.mark.asyncio -@patch("ossai.handlers.get_direct_message_channel_id") -@patch("ossai.handlers.get_channel_history") -@patch("ossai.handlers.get_parsed_messages") @patch("ossai.handlers.analyze_topics_of_history") async def test_handler_topics_slash_command( analyze_topics_of_history_mock, - get_parsed_messages_mock, - get_channel_history_mock, - get_direct_message_channel_id_mock, - client, + mock_slack_context, payload, say, ): - get_direct_message_channel_id_mock.return_value = "dm_channel_id" - get_channel_history_mock.return_value = ["message1", "message2", "message3"] - get_parsed_messages_mock.return_value = "parsed_messages" + mock_slack_context.get_direct_message_channel_id.return_value = "dm_channel_id" + mock_slack_context.get_channel_history.return_value = ["message1", "message2", "message3"] + mock_slack_context.get_parsed_messages.return_value = "parsed_messages" analyze_topics_of_history_mock.return_value = ("topic_overview", str(uuid.uuid4())) await handler_topics_slash_command( - client, AsyncMock(), payload, say, user_id="foo123" + mock_slack_context, AsyncMock(), payload, say, user_id="foo123" ) say.assert_called() @pytest.mark.asyncio -@patch("ossai.handlers.get_workspace_name") @patch("ossai.handlers.Summarizer") -@patch("slack_sdk.WebClient.conversations_replies") -@patch("ossai.handlers.get_direct_message_channel_id") -@patch("ossai.handlers.get_user_context") async def test_handler_shortcuts( - get_user_context_mock, - get_direct_message_channel_id_mock, - conversations_replies_mock, summarizer_mock, - get_workspace_name_mock, - client, + mock_slack_context, shortcuts_payload, say, ): # Arrange run_id = str(uuid.uuid4()) - get_direct_message_channel_id_mock.return_value = "dm_channel_id" - conversations_replies_mock.return_value = { + mock_slack_context.get_direct_message_channel_id.return_value = "dm_channel_id" + mock_slack_context.client.conversations_replies.return_value = { "ok": True, "messages": [{"text": "test message"}], } - get_workspace_name_mock.return_value = "workspace_name" - get_user_context_mock.return_value = {"user": "info"} + mock_slack_context.get_workspace_name.return_value = "workspace_name" + mock_slack_context.get_user_context.return_value = {"user": "info"} # Mock Summarizer instance summarizer_instance_mock = summarizer_mock.return_value @@ -184,7 +184,7 @@ async def test_handler_shortcuts( ] # Act - await handler_shortcuts(client, True, shortcuts_payload, say, user_id="foo123") + await handler_shortcuts(mock_slack_context, True, shortcuts_payload, say, user_id="foo123") # Assert say.assert_called_with( @@ -192,21 +192,19 @@ async def test_handler_shortcuts( ) summarizer_mock.assert_called_once() summarizer_instance_mock.summarize_slack_messages.assert_called_once_with( - client, [{"text": "test message"}], "channel_id", feature_name="summarize_thread", user={"user": "info"}, ) - get_user_context_mock.assert_called_once_with(client, "foo123") + mock_slack_context.get_user_context.assert_called_once_with("foo123") @pytest.mark.asyncio -@patch("ossai.handlers.get_direct_message_channel_id") async def test_handler_tldr_extended_slash_command_public( - get_direct_message_channel_id_mock, client, say + mock_slack_context, say ): - get_direct_message_channel_id_mock.return_value = "dm_channel_id" + mock_slack_context.get_direct_message_channel_id.return_value = "dm_channel_id" payload = { "text": "public", "channel_name": "channel_name", @@ -214,7 +212,7 @@ async def test_handler_tldr_extended_slash_command_public( "user_id": "user_id", } await handler_tldr_extended_slash_command( - client, AsyncMock(), payload, say, user_id="foo123" + mock_slack_context, AsyncMock(), payload, say, user_id="foo123" ) say.assert_called() @@ -283,24 +281,21 @@ def test_handler_feedback_very_helpful_button(env_get_mock, client_mock): @pytest.mark.asyncio -@patch("ossai.decorators.catch_error_dm_user.get_direct_message_channel_id") -@patch("ossai.utils.get_bot_id") async def test_handler_shortcuts_channel_not_found_error( - get_bot_id_mock, get_direct_message_channel_id_mock + mock_slack_context, ): # Setup - client = AsyncMock(spec=WebClient) - client.bots_info.return_value = {"bot": {"name": "TestBot"}} + mock_slack_context.client.bots_info.return_value = {"bot": {"name": "TestBot"}} say = AsyncMock() - get_direct_message_channel_id_mock.return_value = "DM123" - get_bot_id_mock.return_value = "B123" - client.conversations_replies.side_effect = SlackApiError( + mock_slack_context.get_direct_message_channel_id.return_value = "DM123" + mock_slack_context.get_bot_id.return_value = "B123" + mock_slack_context.client.conversations_replies.side_effect = SlackApiError( message="channel_not_found", response={"error": "channel_not_found"} ) # Execute await handler_shortcuts( - client, + mock_slack_context, False, { "channel": {"id": "C123"}, @@ -312,8 +307,8 @@ async def test_handler_shortcuts_channel_not_found_error( ) # Verify - get_direct_message_channel_id_mock.assert_called_once_with(client, "U123") - client.chat_postEphemeral.assert_called_once_with( + mock_slack_context.get_direct_message_channel_id.assert_called_with("U123") + mock_slack_context.client.chat_postEphemeral.assert_called_once_with( channel="DM123", user="U123", text="Sorry, couldn't find the channel. Have you added `@TestBot` to the channel?", @@ -321,24 +316,18 @@ async def test_handler_shortcuts_channel_not_found_error( @pytest.mark.asyncio -@patch("ossai.handlers.get_channel_history") -@patch("ossai.handlers.get_user_context") @patch("ossai.handlers.Summarizer") @patch("ossai.handlers.get_text_and_blocks_for_say") -@patch("ossai.handlers.get_direct_message_channel_id") async def test_handler_tldr_extended_slash_command_non_public( - get_direct_message_channel_id_mock, get_text_and_blocks_for_say_mock, summarizer_mock, - get_user_context_mock, - get_channel_history_mock, + mock_slack_context, ): # Setup - client = AsyncMock(spec=WebClient) say = AsyncMock() - get_direct_message_channel_id_mock.return_value = "DM123" - get_channel_history_mock.return_value = ["message1", "message2"] - get_user_context_mock.return_value = {"user": "info"} + mock_slack_context.get_direct_message_channel_id.return_value = "DM123" + mock_slack_context.get_channel_history.return_value = ["message1", "message2"] + mock_slack_context.get_user_context.return_value = {"user": "info"} summarizer_instance_mock = summarizer_mock.return_value summarizer_instance_mock.summarize_slack_messages.return_value = ("summary", "run_id") @@ -347,7 +336,7 @@ async def test_handler_tldr_extended_slash_command_non_public( # Execute await handler_tldr_extended_slash_command( - client, + mock_slack_context, AsyncMock(), { "channel_name": "general", @@ -361,10 +350,9 @@ async def test_handler_tldr_extended_slash_command_non_public( # Verify assert say.call_count == 2 say.assert_called_with(channel="DM123", text="text", blocks="blocks") - get_direct_message_channel_id_mock.assert_called_once_with(client, "U123") + mock_slack_context.get_direct_message_channel_id.assert_called_once_with("U123") summarizer_mock.assert_called_once() summarizer_instance_mock.summarize_slack_messages.assert_called_once_with( - client, ["message2", "message1"], "C123", feature_name="summarize_channel_messages", @@ -375,23 +363,17 @@ async def test_handler_tldr_extended_slash_command_non_public( @pytest.mark.asyncio @patch("ossai.handlers.datetime") -@patch("ossai.handlers.get_channel_history") -@patch("ossai.handlers.get_user_context") @patch("ossai.handlers.Summarizer") @patch("ossai.handlers.get_text_and_blocks_for_say") -@patch("ossai.handlers.get_direct_message_channel_id") @patch("aiohttp.ClientSession.post", new_callable=AsyncMock) async def test_handler_action_summarize_since_date( mock_post, - get_direct_message_channel_id_mock, get_text_and_blocks_for_say_mock, summarizer_mock, - get_user_context_mock, - get_channel_history_mock, datetime_mock, + mock_slack_context, ): # Setup - client = AsyncMock(spec=WebClient) ack = AsyncMock() body = { "channel": {"name": "general", "id": "C123"}, @@ -404,9 +386,9 @@ async def test_handler_action_summarize_since_date( ], "response_url": "http://example.com/response", } - get_direct_message_channel_id_mock.return_value = "DM123" - get_channel_history_mock.return_value = ["message1", "message2"] - get_user_context_mock.return_value = {"user": "info"} + mock_slack_context.get_direct_message_channel_id.return_value = "DM123" + mock_slack_context.get_channel_history.return_value = ["message1", "message2"] + mock_slack_context.get_user_context.return_value = {"user": "info"} summarizer_instance_mock = summarizer_mock.return_value summarizer_instance_mock.summarize_slack_messages.return_value = ("summary", "run_id") @@ -418,19 +400,18 @@ async def test_handler_action_summarize_since_date( datetime_mock.fromtimestamp.return_value = mocked_date # Execute - await handler_action_summarize_since_date(client, ack, body) + await handler_action_summarize_since_date(mock_slack_context, ack, body) # Verify ack.assert_called_once() - get_direct_message_channel_id_mock.assert_called_once_with(client, "U123") + mock_slack_context.get_direct_message_channel_id.assert_called_once_with("U123") datetime_mock.fromtimestamp.assert_called_once_with(1676955600) - get_channel_history_mock.assert_called_once_with( - client, "C123", since=mocked_date.date() + mock_slack_context.get_channel_history.assert_called_once_with( + "C123", since=mocked_date.date() ) - get_user_context_mock.assert_called_once_with(client, "U123") + mock_slack_context.get_user_context.assert_called_once_with("U123") summarizer_mock.assert_called_once() summarizer_instance_mock.summarize_slack_messages.assert_called_once_with( - client, ["message2", "message1"], "C123", feature_name="summarize_since_preset", @@ -442,7 +423,7 @@ async def test_handler_action_summarize_since_date( messages="summary", custom_prompt=None, ) - client.chat_postMessage.assert_called_with( + mock_slack_context.client.chat_postMessage.assert_called_with( channel="DM123", text="text", blocks="blocks" ) mock_post.assert_called_once_with( @@ -450,26 +431,24 @@ async def test_handler_action_summarize_since_date( ) @pytest.mark.asyncio -@patch("ossai.handlers.get_direct_message_channel_id") @patch("ossai.handlers.get_since_timeframe_presets") async def test_handler_tldr_since_slash_command_happy_path( - get_since_timeframe_presets_mock, get_direct_message_channel_id_mock + get_since_timeframe_presets_mock, + mock_slack_context, ): # Setup - client = AsyncMock(spec=WebClient) - client.chat_postEphemeral = MagicMock() say = AsyncMock() payload = {"user_id": "U123", "channel_id": "C123", "channel_name": "general"} get_since_timeframe_presets_mock.return_value = {"foo": "bar"} - get_direct_message_channel_id_mock.return_value = "DM123" + mock_slack_context.get_direct_message_channel_id.return_value = "DM123" ack = AsyncMock() # Execute - await handler_tldr_since_slash_command(client, ack, payload, say) + await handler_tldr_since_slash_command(mock_slack_context, ack, payload, say) # Verify ack.assert_called_once() - client.chat_postEphemeral.assert_called_once_with( + mock_slack_context.client.chat_postEphemeral.assert_called_once_with( channel="C123", user="U123", text="Choose your summary timeframe.", @@ -498,28 +477,24 @@ async def test_handler_tldr_since_slash_command_happy_path( @pytest.mark.asyncio -@patch("ossai.decorators.catch_error_dm_user.get_direct_message_channel_id") -@patch("ossai.utils.get_bot_id") async def test_handlers_bot_not_in_channel( - get_bot_id_mock: AsyncMock, - get_direct_message_channel_id_mock: AsyncMock, + mock_slack_context ) -> None: USER_ID = "U123" CHANNEL_ID = "C123" DM_CHANNEL_ID = "DM123" - client = AsyncMock(spec=WebClient) - client.bots_info.return_value = {"bot": {"name": "TestBot"}} + mock_slack_context.client.bots_info.return_value = {"bot": {"name": "TestBot"}} say = AsyncMock() ack = AsyncMock() - get_direct_message_channel_id_mock.return_value = DM_CHANNEL_ID - get_bot_id_mock.return_value = "B123" + mock_slack_context.get_direct_message_channel_id.return_value = DM_CHANNEL_ID + mock_slack_context.get_bot_id.return_value = "B123" handlers_to_test = [ ( handler_shortcuts, ( - client, + mock_slack_context, False, { "channel": {"id": CHANNEL_ID}, @@ -533,7 +508,7 @@ async def test_handlers_bot_not_in_channel( ( handler_tldr_since_slash_command, ( - client, + mock_slack_context, ack, { "user_id": USER_ID, @@ -546,7 +521,7 @@ async def test_handlers_bot_not_in_channel( ( handler_sandbox_slash_command, ( - client, + mock_slack_context, ack, { "user_id": USER_ID, @@ -560,7 +535,7 @@ async def test_handlers_bot_not_in_channel( ( handler_action_summarize_since_date, ( - client, + mock_slack_context, ack, { "channel": {"name": "general", "id": CHANNEL_ID}, @@ -578,7 +553,7 @@ async def test_handlers_bot_not_in_channel( ( handler_topics_slash_command, ( - client, + mock_slack_context, ack, { "user_id": USER_ID, @@ -592,7 +567,7 @@ async def test_handlers_bot_not_in_channel( ( handler_tldr_extended_slash_command, ( - client, + mock_slack_context, ack, { "user_id": USER_ID, @@ -607,17 +582,20 @@ async def test_handlers_bot_not_in_channel( ] for handler, args in handlers_to_test: - client.reset_mock() + mock_slack_context.client.reset_mock() say.reset_mock() ack.reset_mock() - client.conversations_replies.side_effect = SlackApiError( + mock_slack_context.client.conversations_replies.side_effect = SlackApiError( + message="channel_not_found", response={"error": "channel_not_found"} + ) + mock_slack_context.client.chat_postEphemeral.side_effect = SlackApiError( message="channel_not_found", response={"error": "channel_not_found"} ) - client.chat_postEphemeral.side_effect = SlackApiError( + mock_slack_context.client.conversations_history.side_effect = SlackApiError( message="channel_not_found", response={"error": "channel_not_found"} ) - client.conversations_history.side_effect = SlackApiError( + mock_slack_context.client.chat_postMessage.side_effect = SlackApiError( message="channel_not_found", response={"error": "channel_not_found"} ) say.side_effect = SlackApiError( @@ -628,24 +606,23 @@ async def test_handlers_bot_not_in_channel( error_message = "Sorry, couldn't find the channel. Have you added `@TestBot` to the channel?" - client.chat_postEphemeral.assert_called_with( + mock_slack_context.client.chat_postEphemeral.assert_called_with( channel=DM_CHANNEL_ID, user=USER_ID, text=error_message ) - client.conversations_replies.side_effect = None - client.chat_postEphemeral.side_effect = None - client.conversations_history.side_effect = None + mock_slack_context.client.conversations_replies.side_effect = None + mock_slack_context.client.chat_postEphemeral.side_effect = None + mock_slack_context.client.conversations_history.side_effect = None say.side_effect = None @pytest.mark.asyncio -async def test_handler_sandbox_slash_command_happy_path(): +async def test_handler_sandbox_slash_command_happy_path(mock_slack_context): ack = AsyncMock() say = AsyncMock() payload = {"user_id": "U123", "channel_id": "C123", "channel_name": "general"} - client = AsyncMock(spec=WebClient) - await handler_sandbox_slash_command(client, ack, payload, say, user_id="foo123") + await handler_sandbox_slash_command(mock_slack_context, ack, payload, say, user_id="foo123") say.assert_called_once() assert any( "This is a test of the /sandbox command." in str(block) diff --git a/tests/test_slack_context.py b/tests/test_slack_context.py new file mode 100644 index 0000000..20834fb --- /dev/null +++ b/tests/test_slack_context.py @@ -0,0 +1,171 @@ +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from slack_sdk import WebClient +from ossai.slack_context import SlackContext +from slack_sdk.errors import SlackApiError + +@pytest.fixture +def mock_web_client(): + with patch("slack_sdk.WebClient") as mock_client: + def users_info_side_effect(user): + users = { + "U123": { + "ok": True, + "user": { + "real_name": "Ashley Wang", + "name": "ashley.wang", + "profile": {"real_name": "Ashley Wang", "title": "CEO"}, + }, + }, + "U456": { + "ok": True, + "user": { + "real_name": "Taylor Garcia", + "name": "taylor.garcia", + "profile": {"real_name": "Taylor Garcia", "title": "CTO"}, + }, + }, + } + return users.get(user, {"ok": False}) + + mock_client.users_info.side_effect = users_info_side_effect + mock_client.auth_test = MagicMock(return_value={"bot_id": "B123"}) + mock_client.conversations_history = MagicMock(return_value={"messages": [{"bot_id": "B123"}]}) + mock_client.conversations_open = MagicMock(return_value={"channel": {"id": "C123"}}) + mock_client.team_info = MagicMock(return_value={"ok": True, "team": {"name": "Test Workspace"}}) + yield mock_client + +@pytest.fixture +def slack_context(mock_web_client): + return SlackContext(mock_web_client) + + +@pytest.mark.asyncio +async def test_get_bot_id(slack_context): + assert await slack_context.get_bot_id() == "B123" + +@pytest.mark.asyncio +async def test_get_bot_id_with_exception(slack_context): + slack_context.client.auth_test.side_effect = SlackApiError("error", {"error": "error"}) + assert await slack_context.get_bot_id() == "None" + + +@pytest.mark.asyncio +async def test_get_channel_history(slack_context): + slack_context.client.conversations_history.return_value = {"messages": [{"bot_id": "B123"}]} + assert await slack_context.get_channel_history("C123") == [] + + +@pytest.mark.asyncio +async def test_get_direct_message_channel_id(slack_context): + slack_context.client.conversations_open.return_value = {"channel": {"id": "C123"}} + assert await slack_context.get_direct_message_channel_id("U123") == "C123" + + +@pytest.mark.asyncio +async def test_get_direct_message_channel_id_with_exception(slack_context): + slack_context.client.conversations_open.side_effect = SlackApiError( + "error", {"error": "error"} + ) + with pytest.raises(SlackApiError) as e_info: + await slack_context.get_direct_message_channel_id("U123") + assert True + + +def test_get_name_from_id(slack_context): + assert slack_context.get_name_from_id("U123") == "Ashley Wang" + + +def test_get_name_from_id_bot_user(slack_context): + slack_context.client.users_info.side_effect = lambda user: { + "ok": False, + "error": "user_not_found", + } # simulate user not found + slack_context.client.bots_info.side_effect = lambda bot: { + "ok": True, + "bot": {"name": "Bender Bending Rodríguez"}, + } + + assert slack_context.get_name_from_id("B123") == "Bender Bending Rodríguez" + + + +def test_get_name_from_id_bot_user_error(slack_context): + slack_context.client.users_info.side_effect = lambda user: { + "ok": False, + "error": "user_not_found", + } + slack_context.client.bots_info.side_effect = lambda bot: { + "ok": False, + "error": "bot_not_found", + } + + assert slack_context.get_name_from_id("B456") == "Someone" + + +def test_get_name_from_id_bot_user_exception(slack_context): + slack_context.client.users_info.side_effect = lambda user: { + "ok": False, + "error": "user_not_found", + } # simulate user not found + slack_context.client.bots_info.side_effect = SlackApiError( + "bot fetch failed", {"error": "bot_not_found"} + ) + + assert slack_context.get_name_from_id("B456") == "Someone" + + +def test_get_parsed_messages(slack_context): + messages = [ + {"text": "Hello <@U456>", "user": "U123"}, + {"text": "nohello.net!!", "user": "U456"}, + ] + assert slack_context.get_parsed_messages(messages) == [ + "Ashley Wang: Hello Taylor Garcia", # prefix with author's name & replace user ID with user's name + "Taylor Garcia: nohello.net!!", # prefix with author's name + ] + + +def test_get_parsed_messages_without_names(slack_context): + messages = [{"text": "Hello <@U456>", "user": "U123"}] + + # no author's name prefix & remove @mentions + assert slack_context.get_parsed_messages(messages, with_names=False) == [ + "Hello " + ] + + +def test_get_parsed_messages_with_bot(slack_context): + slack_context.client.users_info.side_effect = SlackApiError( + "user fetch failed", {"error": "user_not_found"} + ) # simulate user not found + slack_context.client.bots_info.side_effect = lambda bot: { + "ok": True, + "bot": {"name": "Bender Bending Rodríguez"}, + } + messages = [{"text": "I am <@B123>!", "bot_id": "B123"}] + assert slack_context.get_parsed_messages(messages) == [ + "Bender Bending Rodríguez: I am Bender Bending Rodríguez!", + ] + + +def test_get_workspace_name(slack_context): + slack_context.client.team_info.return_value = {"ok": True, "team": {"name": "Workspace"}} + result = slack_context.get_workspace_name() + slack_context.client.team_info.assert_called_once() + assert result == "Workspace" + + +def test_get_workspace_name_exception(slack_context): + with patch.dict("os.environ", {"WORKSPACE_NAME_FALLBACK": ""}): + slack_context.client.team_info.side_effect = SlackApiError("error", {"error": "error"}) + result = slack_context.get_workspace_name() + assert result == "" + + +def test_get_workspace_name_failure(slack_context): + with patch.dict("os.environ", {"WORKSPACE_NAME_FALLBACK": ""}): + slack_context.client.team_info.return_value = {"ok": False, "error": "team_info error"} + result = slack_context.get_workspace_name() + slack_context.client.team_info.assert_called_once() + assert result == "" diff --git a/tests/test_slack_server.py b/tests/test_slack_server.py index cc4e637..667c9d3 100644 --- a/tests/test_slack_server.py +++ b/tests/test_slack_server.py @@ -6,6 +6,7 @@ from slack_bolt import App from slack_bolt.adapter.socket_mode import SocketModeHandler +from ossai.slack_context import SlackContext from ossai.slack_server import handle_slash_command_sandbox, main as slack_server_main @@ -64,5 +65,6 @@ async def test_handle_slash_command_sandbox( # Assert mock_handler_sandbox_slash_command.assert_called_once_with( - mock_client, mock_ack, mock_payload, mock_say, user_id=mock_user_id + ANY, mock_ack, mock_payload, mock_say, user_id=mock_user_id ) + assert isinstance(mock_handler_sandbox_slash_command.call_args[0][0], SlackContext) diff --git a/tests/test_summarizer.py b/tests/test_summarizer.py index 4bfba2c..63c3f6e 100644 --- a/tests/test_summarizer.py +++ b/tests/test_summarizer.py @@ -1,38 +1,35 @@ import re import runpy -from unittest.mock import patch, MagicMock +from unittest.mock import patch, MagicMock, AsyncMock import pytest from openai import RateLimitError -from ossai import summarizer from ossai.summarizer import Summarizer, main as summarizer_main from ossai.utils import get_llm_config -# def test_summarize(): -# with patch('openai.ChatCompletion.create') as mock_create: -# mock_create.return_value = { -# "choices": [ -# { -# "message": { -# 'content': 'Summarized text' -# } -# } -# ] -# } -# result = summarizer.summarize("Alice: Hi\nBob: Hello\nAlice: How are you?\nBob: I'm doing well, thanks.") -# assert result == 'Summarized text' - - -def test_summarize_langchain(): +@pytest.fixture +def mock_slack_context(): + mock = MagicMock() + mock.get_bot_id = AsyncMock(return_value="B12345") + mock.get_channel_history = AsyncMock(return_value=[]) + mock.get_direct_message_channel_id = AsyncMock(return_value="D12345") + mock.get_is_private_and_channel_name = MagicMock(return_value=(False, "general")) + mock.get_name_from_id = MagicMock(return_value="John Doe") + mock.get_parsed_messages = MagicMock(return_value=["John: Hello", "Jane: Hi"]) + mock.get_user_context = AsyncMock(return_value={"name": "John", "title": "Developer"}) + mock.get_workspace_name = MagicMock(return_value="My Workspace") + return mock + +def test_summarize_langchain(mock_slack_context): text = """\ Bob: How are you? Jane: It's been so long. I've been great. I bought a house, started a business, and sold my left kidney. Bob: Well isn't that just wonderful. Did you mean to sell your kidney? I quite like having 2. Jane: I figured I had a spare and really wanted a Tesla. So, yeah. """ - summarizer = Summarizer() + summarizer = Summarizer(mock_slack_context) result, run_id = summarizer.summarize( text, feature_name="unit_test", user="test_user", channel="test_channel" ) @@ -49,25 +46,23 @@ def test_summarize_langchain(): ) -def test_estimate_openai_chat_token_count(): - summarizer = Summarizer() +def test_estimate_openai_chat_token_count(mock_slack_context): + summarizer = Summarizer(mock_slack_context) result = summarizer.estimate_openai_chat_token_count("Hello, how are you?") assert result == 7 -def test_split_messages_by_token_count(): - with patch( - "ossai.summarizer.get_parsed_messages" - ) as mock_get_parsed_messages, patch.dict("os.environ", {"MAX_BODY_TOKENS": "3"}): - mock_get_parsed_messages.return_value = ["Hello", "how", "are", "you"] +def test_split_messages_by_token_count(mock_slack_context): + with patch.dict("os.environ", {"MAX_BODY_TOKENS": "3"}): + mock_slack_context.get_parsed_messages.return_value = ["Hello", "how", "are", "you"] messages = [ {"text": "Hello"}, {"text": "how"}, {"text": "are"}, {"text": "you"}, ] - summarizer = Summarizer() - result = summarizer.split_messages_by_token_count(None, messages) + summarizer = Summarizer(mock_slack_context) + result = summarizer.split_messages_by_token_count(messages) assert result == [["Hello", "how"], ["are", "you"]] @@ -82,9 +77,8 @@ def test_missing_openai_api_key(): assert str(e.value) == "OPENAI_API_KEY is not set in .env file" -def test_summarize_slack_messages(): +def test_summarize_slack_messages(mock_slack_context): # Mock the client and messages - mock_client = MagicMock() mock_messages = [ {"text": "Hello"}, {"text": "how"}, @@ -93,12 +87,13 @@ def test_summarize_slack_messages(): ] # Mock the conversations_info method to return a fixed response - mock_client.conversations_info.return_value = { + mock_slack_context.client.conversations_info.return_value = { "channel": {"name": "foo", "is_private": False} } + mock_slack_context.get_is_private_and_channel_name.return_value = (False, "foo") # Create a Summarizer instance - summarizer = Summarizer() + summarizer = Summarizer(mock_slack_context) # Mock the split_messages_by_token_count method with patch.object( @@ -113,14 +108,13 @@ def test_summarize_slack_messages(): return_value=("Summarized text", "run_id") ) as mock_summarize: result, run_id = summarizer.summarize_slack_messages( - mock_client, mock_messages, channel_id="C1234567890", feature_name="unit_test", user="test_user", ) # Check that the split_messages_by_token_count method was called with the correct arguments - mock_split.assert_called_once_with(mock_client, mock_messages) + mock_split.assert_called_once_with(mock_messages) # Check that the summarize method was called with the correct arguments mock_summarize.assert_called_with( "\n".join(["Hello", "how", "are", "you"]), @@ -133,9 +127,9 @@ def test_summarize_slack_messages(): assert result == ["Summarized text"] -def test_summarize_slack_messages_private_channel(): +def test_summarize_slack_messages_private_channel(mock_slack_context): # Mock the client and messages - mock_client = MagicMock() + mock_messages = [ {"text": "Hello"}, {"text": "how"}, @@ -144,12 +138,13 @@ def test_summarize_slack_messages_private_channel(): ] # Mock the conversations_info method to return a fixed response - mock_client.conversations_info.return_value = { + mock_slack_context.client.conversations_info.return_value = { "channel": {"name": "foo", "is_private": True} } + mock_slack_context.get_is_private_and_channel_name.return_value = (True, "foo") # Create a Summarizer instance - summarizer = Summarizer() + summarizer = Summarizer(mock_slack_context) # Mock the split_messages_by_token_count method with patch.object( @@ -164,14 +159,13 @@ def test_summarize_slack_messages_private_channel(): return_value=("Summarized text", "run_id") ) as mock_summarize: result, run_id = summarizer.summarize_slack_messages( - mock_client, mock_messages, channel_id="C1234567890", feature_name="unit_test", user="test_user", ) # Check that the split_messages_by_token_count method was called with the correct arguments - mock_split.assert_called_once_with(mock_client, mock_messages) + mock_split.assert_called_once_with(mock_messages) # Check that the summarize method was called with the correct arguments mock_summarize.assert_called_with( "\n".join(["Hello", "how", "are", "you"]), @@ -184,9 +178,8 @@ def test_summarize_slack_messages_private_channel(): assert result == ["Summarized text"] -def test_summarize_slack_messages_rate_limit_error(): - # Mock the client and messages - mock_client = MagicMock() +def test_summarize_slack_messages_rate_limit_error(mock_slack_context): + # Mock the messages mock_messages = [ {"text": "Hello"}, {"text": "how"}, @@ -195,7 +188,7 @@ def test_summarize_slack_messages_rate_limit_error(): ] # Create a Summarizer instance - summarizer = Summarizer() + summarizer = Summarizer(mock_slack_context) # Mock the split_messages_by_token_count method with patch.object( @@ -212,7 +205,6 @@ def test_summarize_slack_messages_rate_limit_error(): ) ) as mock_summarize: result, run_id = summarizer.summarize_slack_messages( - mock_client, mock_messages, channel_id="C1234567890", feature_name="unit_test", diff --git a/tests/test_utils.py b/tests/test_utils.py index 0d88abe..c79c6cf 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,187 +3,38 @@ import time import uuid import pytest -from slack_sdk.errors import SlackApiError from ossai import utils -@pytest.fixture(autouse=True) -def clear_cache(): - utils._id_name_cache.clear() - yield - utils._id_name_cache.clear() - - @pytest.fixture def mock_client(): - with patch("ossai.utils.WebClient") as mock_client: - - def users_info_side_effect(user): - users = { - "U123": { - "ok": True, - "user": { - "real_name": "Ashley Wang", - "name": "ashley.wang", - "profile": {"real_name": "Ashley Wang", "title": "CEO"}, - }, - }, - "U456": { - "ok": True, - "user": { - "real_name": "Taylor Garcia", - "name": "taylor.garcia", - "profile": {"real_name": "Taylor Garcia", "title": "CTO"}, - }, - }, - } - return users.get(user, {"ok": False}) - - mock_client.users_info.side_effect = users_info_side_effect - yield mock_client - - -@pytest.fixture -def mock_user_client(): - with patch( - "slack_sdk.WebClient", - return_value=MagicMock(auth_test=MagicMock(return_value={"user_id": "U123"})), - ) as mock_client: - yield mock_client - - -@pytest.mark.asyncio -async def test_get_bot_id(mock_client): - mock_client.auth_test.return_value = {"bot_id": "B123"} - assert await utils.get_bot_id(mock_client) == "B123" - - -@pytest.mark.asyncio -async def test_get_bot_id_with_exception(mock_client): - mock_client.auth_test.side_effect = SlackApiError("error", {"error": "error"}) - assert await utils.get_bot_id(mock_client) == "None" - - -@pytest.mark.asyncio -async def test_get_channel_history(mock_client): - mock_client.conversations_history.return_value = {"messages": [{"bot_id": "B123"}]} - mock_client.auth_test.return_value = {"bot_id": "B123"} - assert await utils.get_channel_history(mock_client, "C123") == [] - - -@pytest.mark.asyncio -async def test_get_direct_message_channel_id(mock_client, mock_user_client): - mock_client.conversations_open.return_value = {"channel": {"id": "C123"}} - assert await utils.get_direct_message_channel_id(mock_client, "U123") == "C123" - - -@pytest.mark.asyncio -async def test_get_direct_message_channel_id_with_exception(mock_client): - mock_client.conversations_open.side_effect = SlackApiError( - "error", {"error": "error"} - ) - with pytest.raises(SlackApiError) as e_info: - await utils.get_direct_message_channel_id(mock_client, "U123") - assert True - - -def test_get_name_from_id(mock_client): - assert utils.get_name_from_id(mock_client, "U123") == "Ashley Wang" - - -def test_get_name_from_id_bot_user(mock_client): - mock_client.users_info.side_effect = lambda user: { - "ok": False, - "error": "user_not_found", - } # simulate user not found - mock_client.bots_info.side_effect = lambda bot: { - "ok": True, - "bot": {"name": "Bender Bending Rodríguez"}, - } - - assert utils.get_name_from_id(mock_client, "B123") == "Bender Bending Rodríguez" - - -def test_get_name_from_id_bot_user_error(mock_client): - mock_client.users_info.side_effect = lambda user: { - "ok": False, - "error": "user_not_found", - } - mock_client.bots_info.side_effect = lambda bot: { - "ok": False, - "error": "bot_not_found", - } - - assert utils.get_name_from_id(mock_client, "B456") == "Someone" - - -def test_get_name_from_id_bot_user_exception(mock_client): - mock_client.users_info.side_effect = lambda user: { - "ok": False, - "error": "user_not_found", - } # simulate user not found - mock_client.bots_info.side_effect = SlackApiError( - "bot fetch failed", {"error": "bot_not_found"} - ) - - assert utils.get_name_from_id(mock_client, "B456") == "Someone" - - -def test_get_parsed_messages(mock_client): - messages = [ - {"text": "Hello <@U456>", "user": "U123"}, - {"text": "nohello.net!!", "user": "U456"}, - ] - assert utils.get_parsed_messages(mock_client, messages) == [ - "Ashley Wang: Hello Taylor Garcia", # prefix with author's name & replace user ID with user's name - "Taylor Garcia: nohello.net!!", # prefix with author's name - ] - - -def test_get_parsed_messages_without_names(mock_client): - messages = [{"text": "Hello <@U456>", "user": "U123"}] - - # no author's name prefix & remove @mentions - assert utils.get_parsed_messages(mock_client, messages, with_names=False) == [ - "Hello " - ] - - -def test_get_parsed_messages_with_bot(mock_client): - mock_client.users_info.side_effect = SlackApiError( - "user fetch failed", {"error": "user_not_found"} - ) # simulate user not found - mock_client.bots_info.side_effect = lambda bot: { - "ok": True, - "bot": {"name": "Bender Bending Rodríguez"}, - } - messages = [{"text": "I am <@B123>!", "bot_id": "B123"}] - assert utils.get_parsed_messages(mock_client, messages) == [ - "Bender Bending Rodríguez: I am Bender Bending Rodríguez!", - ] - - -def test_get_workspace_name(mock_client): - mock_client.team_info.return_value = {"ok": True, "team": {"name": "Workspace"}} - result = utils.get_workspace_name(mock_client) - mock_client.team_info.assert_called_once() - assert result == "Workspace" - - -def test_get_workspace_name_exception(mock_client): - with patch.dict("os.environ", {"WORKSPACE_NAME_FALLBACK": ""}): - mock_client.team_info.side_effect = SlackApiError("error", {"error": "error"}) - result = utils.get_workspace_name(mock_client) - assert result == "" - - -def test_get_workspace_name_failure(mock_client): - with patch.dict("os.environ", {"WORKSPACE_NAME_FALLBACK": ""}): - mock_client.team_info.return_value = {"ok": False, "error": "team_info error"} - result = utils.get_workspace_name(mock_client) - mock_client.team_info.assert_called_once() - assert result == "" + pass + # with patch("ossai.utils.WebClient") as mock_client: + + # def users_info_side_effect(user): + # users = { + # "U123": { + # "ok": True, + # "user": { + # "real_name": "Ashley Wang", + # "name": "ashley.wang", + # "profile": {"real_name": "Ashley Wang", "title": "CEO"}, + # }, + # }, + # "U456": { + # "ok": True, + # "user": { + # "real_name": "Taylor Garcia", + # "name": "taylor.garcia", + # "profile": {"real_name": "Taylor Garcia", "title": "CTO"}, + # }, + # }, + # } + # return users.get(user, {"ok": False}) + + # mock_client.users_info.side_effect = users_info_side_effect + # yield mock_client def test_main_as_script(capfd): @@ -229,7 +80,7 @@ def test_get_langsmith_config_happy_path(): @pytest.mark.asyncio -async def test_get_user_context_success(mock_client): +async def _test_get_user_context_success(mock_client): result = await utils.get_user_context(mock_client, "U123") mock_client.users_info.assert_called_once_with(user="U123")