Skip to content

Commit

Permalink
fix: resolve bug with bot DM routing
Browse files Browse the repository at this point in the history
  • Loading branch information
meetbryce committed May 30, 2024
1 parent 8174830 commit 9ca294f
Show file tree
Hide file tree
Showing 7 changed files with 19 additions and 22 deletions.
1 change: 0 additions & 1 deletion .github/workflows/python-app.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ jobs:
DEBUG: True
TEMPERATURE: 0.35
SLACK_BOT_TOKEN: "xoxb-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
SLACK_USER_TOKEN: "xoxp-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
SLACK_SIGNING_SECRET: "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
SLACK_BOT_USER_ID: "XXXXXXXXXXX"
OPEN_AI_TOKEN: "sk-XXXXXXXXXXXXXXXXXXX"
Expand Down
1 change: 0 additions & 1 deletion example.env
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
DEBUG=True
TEMPERATURE=0.35
SLACK_BOT_TOKEN="xoxb-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
SLACK_USER_TOKEN="xoxp-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
SLACK_SIGNING_SECRET="XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
SLACK_BOT_USER_ID="XXXXXXXXXXX"
OPEN_AI_TOKEN="sk-XXXXXXXXXXXXXXXXXXX"
Expand Down
12 changes: 6 additions & 6 deletions hackathon_2023/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
get_parsed_messages


async def handler_shortcuts(client: WebClient, is_private: bool, payload, say):
async def handler_shortcuts(client: WebClient, is_private: bool, payload, say, user_id: str):
channel_id = payload['channel']['id'] if payload['channel']['id'] else payload['channel_id']
dm_channel_id = await get_direct_message_channel_id(client)
dm_channel_id = await get_direct_message_channel_id(client, user_id)
channel_id_for_say = dm_channel_id if is_private else channel_id
await say(channel=channel_id_for_say, text='...')

Expand Down Expand Up @@ -43,7 +43,7 @@ async def handler_shortcuts(client: WebClient, is_private: bool, payload, say):
return await say(channel=dm_channel_id, text=f"Encountered an error: {e.response['error']}")


async def handler_tldr_slash_command(client: WebClient, ack, payload, say):
async def handler_tldr_slash_command(client: WebClient, ack, payload, say, user_id: str):
await ack() # fixme: this seemingly does nothing
text = payload.get("text", None)
channel_name = payload["channel_name"]
Expand All @@ -53,7 +53,7 @@ async def handler_tldr_slash_command(client: WebClient, ack, payload, say):
if text == 'public':
await say('...') # hack to get the bot to not show an error message but works fine
else:
dm_channel_id = await get_direct_message_channel_id(client)
dm_channel_id = await get_direct_message_channel_id(client, user_id)
await say(channel=dm_channel_id, text='...') # hack to get the bot to not show an error message but works fine

if text and text != 'public':
Expand All @@ -74,10 +74,10 @@ async def handler_tldr_slash_command(client: WebClient, ack, payload, say):
return await say(channel=dm_channel_id, text=f"Encountered an error: {e.response['error']}")


async def handler_topics_slash_command(client: WebClient, ack, payload, say):
async def handler_topics_slash_command(client: WebClient, ack, payload, say, user_id: str):
# START boilerplate
await ack()
dm_channel_id = await get_direct_message_channel_id(client)
dm_channel_id = await get_direct_message_channel_id(client, user_id)
await say(channel=dm_channel_id, text='...')

history = await get_channel_history(client, payload["channel_id"])
Expand Down
8 changes: 4 additions & 4 deletions hackathon_2023/slack_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,24 +49,24 @@ async def shutdown_event():

@async_app.command("/tldr")
async def handle_tldr_slash_command(ack, payload, say):
return await handler_tldr_slash_command(client, ack, payload, say)
return await handler_tldr_slash_command(client, ack, payload, say, user_id=payload['user_id'])


@async_app.command("/tldr_topics")
async def temp__handle_slash_command_topics(ack, payload, say):
return await handler_topics_slash_command(client, ack, payload, say)
return await handler_topics_slash_command(client, ack, payload, say, user_id=payload['user_id'])


@async_app.shortcut("thread")
async def handle_thread_shortcut(ack, payload, say):
await ack()
await handler_shortcuts(client, False, payload, say)
await handler_shortcuts(client, False, payload, say, user_id=payload['user']['id'])


@async_app.shortcut("thread_private")
async def handle_thread_private_shortcut(ack, payload, say):
await ack()
await handler_shortcuts(client, True, payload, say)
await handler_shortcuts(client, True, payload, say, user_id=payload['user']['id'])


@app.event("message")
Expand Down
5 changes: 2 additions & 3 deletions hackathon_2023/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,14 @@ async def get_channel_history(client: WebClient, channel_id: str) -> list:
return [msg for msg in response["messages"] if msg.get("bot_id") != bot_id]


async def get_direct_message_channel_id(client: WebClient) -> str:
async def get_direct_message_channel_id(client: WebClient, user_id: str) -> str:
"""
Get the direct message channel ID for the bot, so you can say() via direct message.
:return str:
"""
# todo: cache this sucker too
try:
user_client = WebClient(token=os.environ["SLACK_USER_TOKEN"])
response = client.conversations_open(users=user_client.auth_test()['user_id'])
response = client.conversations_open(users=user_id)
return response["channel"]["id"]
except SlackApiError as e:
print(f"Error fetching bot DM channel ID: {e.response['error']}")
Expand Down
10 changes: 5 additions & 5 deletions tests/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ async def test_handler_shortcuts(
get_direct_message_channel_id_mock, client, payload, say
):
get_direct_message_channel_id_mock.return_value = "dm_channel_id"
await handler_shortcuts(client, True, payload, say)
await handler_shortcuts(client, True, payload, say, user_id="foo123")
say.assert_called()


Expand All @@ -47,7 +47,7 @@ async def test_handler_tldr_slash_command_channel_history_error(
get_direct_message_channel_id_mock, client, payload, say
):
get_direct_message_channel_id_mock.return_value = "dm_channel_id"
await handler_tldr_slash_command(client, AsyncMock(), payload, say)
await handler_tldr_slash_command(client, AsyncMock(), payload, say, user_id="foo123")
say.assert_called()


Expand All @@ -69,7 +69,7 @@ async def test_handler_topics_slash_command(
get_channel_history_mock.return_value = ["message1", "message2", "message3"]
get_parsed_messages_mock.return_value = "parsed_messages"
analyze_topics_of_history_mock.return_value = "topic_overview"
await handler_topics_slash_command(client, AsyncMock(), payload, say)
await handler_topics_slash_command(client, AsyncMock(), payload, say, user_id="foo123")
say.assert_called()


Expand Down Expand Up @@ -97,7 +97,7 @@ async def test_handler_shortcuts(
summarize_slack_messages_mock.return_value = ["summary"]

# Act
await handler_shortcuts(client, True, payload, say)
await handler_shortcuts(client, True, payload, say, user_id="foo123")

# Assert
say.assert_called_with(channel="dm_channel_id", text="\n".join(["summary"]))
Expand All @@ -114,5 +114,5 @@ async def test_handler_tldr_slash_command_public(
"channel_name": "channel_name",
"channel_id": "channel_id",
}
await handler_tldr_slash_command(client, AsyncMock(), payload, say)
await handler_tldr_slash_command(client, AsyncMock(), payload, say, user_id="foo123")
say.assert_called()
4 changes: 2 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,14 @@ async def test_get_channel_history(mock_client):
@pytest.mark.asyncio
async def test_get_direct_message_channel_id(mock_client, mock_user_client):
mock_client.conversations_open.return_value = {"channel": {"id": "C123"}}
assert await utils.get_direct_message_channel_id(mock_client) == "C123"
assert await utils.get_direct_message_channel_id(mock_client, "U123") == "C123"


@pytest.mark.asyncio
async def test_get_direct_message_channel_id_with_exception(mock_client):
mock_client.conversations_open.side_effect = SlackApiError("error", {"error": "error"})
with pytest.raises(SlackApiError) as e_info:
await utils.get_direct_message_channel_id(mock_client)
await utils.get_direct_message_channel_id(mock_client, "U123")
assert True


Expand Down

0 comments on commit 9ca294f

Please sign in to comment.