diff --git a/ossai/handlers.py b/ossai/handlers.py index 06a94ac..f24e0c4 100644 --- a/ossai/handlers.py +++ b/ossai/handlers.py @@ -22,6 +22,7 @@ get_since_timeframe_presets, ) +_custom_prompt_cache = {} def handler_feedback(body): """ @@ -102,7 +103,6 @@ async def handler_tldr_extended_slash_command( client: WebClient, ack, payload, say, user_id: str ): await ack() - text = payload.get("text", None) channel_name = payload["channel_name"] channel_id = payload["channel_id"] dm_channel_id = None @@ -110,14 +110,12 @@ async def handler_tldr_extended_slash_command( dm_channel_id = await get_direct_message_channel_id(client, user_id) await say(channel=dm_channel_id, text="...") - if text: - return await say("ERROR: custom prompt support coming soon!") - history = await get_channel_history(client, channel_id) history.reverse() user = await get_user_context(client, user_id) title = f"*Summary of #{channel_name}* (last {len(history)} messages)\n" - summarizer = Summarizer() + custom_prompt = payload.get("text", None) + summarizer = Summarizer(custom_prompt=custom_prompt) summary, run_id = summarizer.summarize_slack_messages( client, history, @@ -126,7 +124,7 @@ async def handler_tldr_extended_slash_command( user=user, ) text, blocks = get_text_and_blocks_for_say( - title=title, run_id=run_id, messages=summary + title=title, run_id=run_id, messages=summary, custom_prompt=custom_prompt ) return await say(channel=dm_channel_id, text=text, blocks=blocks) @@ -136,9 +134,6 @@ async def handler_topics_slash_command( client: WebClient, ack, payload, say, user_id: str ): await ack() - text = payload.get("text", None) - if text: - return await say("ERROR: custom prompt support coming soon!") channel_id = payload["channel_id"] dm_channel_id = await get_direct_message_channel_id(client, user_id) await say(channel=dm_channel_id, text="...") @@ -149,6 +144,11 @@ async def handler_topics_slash_command( messages = get_parsed_messages(client, history, with_names=False) user = await get_user_context(client, user_id) is_private, channel_name = get_is_private_and_channel_name(client, channel_id) + custom_prompt = payload.get("text", None) + if custom_prompt: + # todo: add support for custom prompts to /tldr + await say(channel=dm_channel_id, text="Sorry, this command doesn't support custom prompts yet so I'm processing your request without it.") + topic_overview, run_id = await analyze_topics_of_history( channel_name, messages, user=user, is_private=is_private ) @@ -164,11 +164,10 @@ async def handler_tldr_since_slash_command(client: WebClient, ack, payload, say) await ack() title = "Choose your summary timeframe." dm_channel_id = await get_direct_message_channel_id(client, payload["user_id"]) - text = payload.get("text", None) - if text: - return await say("ERROR: custom prompt support coming soon!") + + custom_prompt = payload.get("text", None) - client.chat_postEphemeral( + result = client.chat_postEphemeral( channel=payload["channel_id"], user=payload["user_id"], text=title, @@ -184,17 +183,23 @@ async def handler_tldr_since_slash_command(client: WebClient, ack, payload, say) "text": "Select a date", "emoji": True, }, - "action_id": "summarize_since", + "action_id": f"summarize_since", }, ], } ], ) + # get `custom_prompt` into handler_action_summarize_since_date() + key = f"{result['message_ts']}__{payload['user_id']}" + _custom_prompt_cache[key] = custom_prompt + logger.debug(f"Storing `custom_prompt` at {key}: {custom_prompt}") + await say( channel=dm_channel_id, text=f'In #{payload["channel_name"]}, choose a date or timeframe to get your summary', ) + return @catch_errors_dm_user @@ -227,7 +232,11 @@ async def handler_action_summarize_since_date(client: WebClient, ack, body): history = await get_channel_history(client, channel_id, since=since_datetime) history.reverse() user = await get_user_context(client, user_id) - summarizer = Summarizer() + custom_prompt = None + if 'container' in body and 'message_ts' in body['container']: + key = f"{body['container']['message_ts']}__{user_id}" + custom_prompt = _custom_prompt_cache.get(key, None) + summarizer = Summarizer(custom_prompt=custom_prompt) summary, run_id = summarizer.summarize_slack_messages( client, history, channel_id, feature_name=feature_name, user=user ) @@ -235,6 +244,7 @@ async def handler_action_summarize_since_date(client: WebClient, ack, body): title=f'*Summary of #{channel_name}* since {since_datetime.strftime("%A %b %-d, %Y")} ({len(history)} messages)\n', run_id=run_id, messages=summary, + custom_prompt=custom_prompt, ) # todo: somehow add date/preset choice to langsmith metadata # feature_name: str -> feature: str || Tuple[str, List(Tuple[str, str])] @@ -245,17 +255,26 @@ async def handler_action_summarize_since_date(client: WebClient, ack, body): async def handler_sandbox_slash_command( client: WebClient, ack, payload, say, user_id: str ): - text = payload.get("text", None) - if text: - return await say("ERROR: custom prompt support coming soon!") logger.debug(f"Handling /sandbox command") await ack() - run_id = str(uuid.uuid4()) - run_id = None - text = """-- Better error handling coming soon! Useful summary of content goes here -- (no run id)""" - lines = text.strip().split("\n") + channel_id = payload["channel_id"] + custom_prompt = payload.get("text", None) + summarizer = Summarizer(custom_prompt=custom_prompt) + summary, run_id = summarizer.summarize_slack_messages( + client, + [ + {"text": "bacon", "user": user_id}, + {"text": "eggs", "user": user_id}, + {"text": "spam", "user": user_id}, + {"text": "orange juice", "user": user_id}, + {"text": "coffee", "user": user_id}, + ], + channel_id=channel_id, + feature_name="sandbox", + user=user_id, + ) title = "This is a test of the /sandbox command." text, blocks = get_text_and_blocks_for_say( - title=title, run_id=run_id, messages=lines + title=title, run_id=run_id, messages=summary, custom_prompt=custom_prompt ) return await say(text=text, blocks=blocks) diff --git a/ossai/summarizer.py b/ossai/summarizer.py index 030ce16..4d9af10 100644 --- a/ossai/summarizer.py +++ b/ossai/summarizer.py @@ -20,10 +20,12 @@ class Summarizer: - def __init__(self): + def __init__(self, custom_prompt: str | None = None): + # todo: apply pydantic model self.config = get_llm_config() self.model = ChatOpenAI(model=self.config["chat_model"], temperature=self.config["temperature"]) self.parser = StrOutputParser() + self.custom_prompt = custom_prompt def summarize( self, @@ -62,18 +64,21 @@ def summarize( So, The assistant needs to speak in {language}. """ - human_msg = """\ + base_human_msg = """\ Please summarize the following chat log to a flat markdown formatted bullet list. Do not write a line by line summary. Instead, summarize the overall conversation. Do not include greeting/salutation/polite expressions in summary. Make the summary easy to read while maintaining a conversational tone and retaining meaning. Write in conversational English. + {custom_instructions} {text} """ + # todo: guard against prompt injection + prompt_template = ChatPromptTemplate.from_messages( - [("system", system_msg), ("user", human_msg)] + [("system", system_msg), ("user", base_human_msg)] ) chain = prompt_template | self.model | self.parser @@ -87,7 +92,12 @@ def summarize( ) logger.info(f"{langsmith_config=}") result = chain.invoke( - {"text": text, "language": self.config["language"]}, config=langsmith_config + { + "text": text, + "language": self.config["language"], + "custom_instructions": f"\n\nAdditionally, please follow these specific instructions for this summary:\n{self.custom_prompt}" if self.custom_prompt else "", + }, + config=langsmith_config ) return result, langsmith_config["run_id"] diff --git a/ossai/utils.py b/ossai/utils.py index 49a848d..86fbbe8 100644 --- a/ossai/utils.py +++ b/ossai/utils.py @@ -203,7 +203,7 @@ def parse_message(msg): def get_text_and_blocks_for_say( - title: str, run_id: Union[uuid.UUID, None], messages: list + title: str, run_id: Union[uuid.UUID, None], messages: list, custom_prompt: str = None ) -> tuple[str, list]: CHAR_LIMIT = 3000 text = "\n".join(messages) @@ -260,6 +260,18 @@ def get_text_and_blocks_for_say( } ) + if custom_prompt: + blocks.append({ + "type": "context", + "elements": [ + { + "type": "plain_text", + "text": f"Custom Prompt: {custom_prompt}", + "emoji": True + } + ] + }) + return text.split("\n")[0], blocks diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 4cbcf0b..4d4e400 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -1,4 +1,4 @@ -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import uuid import pytest from slack_sdk import WebClient @@ -380,7 +380,7 @@ async def test_handler_tldr_extended_slash_command_non_public( @patch("ossai.handlers.Summarizer") @patch("ossai.handlers.get_text_and_blocks_for_say") @patch("ossai.handlers.get_direct_message_channel_id") -@patch("aiohttp.ClientSession.post", new_callable=AsyncMock) # Change this line +@patch("aiohttp.ClientSession.post", new_callable=AsyncMock) async def test_handler_action_summarize_since_date( mock_post, get_direct_message_channel_id_mock, @@ -440,6 +440,7 @@ async def test_handler_action_summarize_since_date( title="*Summary of #general* since Tuesday Feb 21, 2023 (2 messages)\n", run_id="run_id", messages="summary", + custom_prompt=None, ) client.chat_postMessage.assert_called_with( channel="DM123", text="text", blocks="blocks" @@ -456,7 +457,7 @@ async def test_handler_tldr_since_slash_command_happy_path( ): # Setup client = AsyncMock(spec=WebClient) - client.chat_postEphemeral = AsyncMock() + client.chat_postEphemeral = MagicMock() say = AsyncMock() payload = {"user_id": "U123", "channel_id": "C123", "channel_name": "general"} get_since_timeframe_presets_mock.return_value = {"foo": "bar"} @@ -646,10 +647,6 @@ async def test_handler_sandbox_slash_command_happy_path(): await handler_sandbox_slash_command(client, ack, payload, say, user_id="foo123") say.assert_called_once() - assert any( - "Useful summary of content goes here" in str(block) - for block in say.call_args[1]["blocks"] - ) assert any( "This is a test of the /sandbox command." in str(block) for block in say.call_args[1]["blocks"] diff --git a/tests/test_utils.py b/tests/test_utils.py index cf27d3d..20d833b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -304,5 +304,8 @@ def test_get_text_and_blocks_for_say_block_size(): assert len(blocks[-1]['elements']) == 3 # Three buttons +# todo: test get_text_and_blocks_for_say with custom prompt + + def test_main_as_script(): utils.main()