From a7aba5357280fffc701c97713e57c7b019702f8a Mon Sep 17 00:00:00 2001
From: Thomas Weiß
Date: Thu, 12 Oct 2023 09:45:15 +0200
Subject: [PATCH] make Slack connector more fault tolerant

... by turning it into a LOAD_STATE connector (while still technically
being a polling connector). Per-channel import state (the oldest and
latest message timestamps seen so far, and whether the initial backfill
has finished) is persisted in the dynamic config store, so the connector
can be interrupted at any time and will continue where it left off. This
is especially important for the first import of large Slack workspaces.
---
 backend/danswer/configs/app_configs.py        |  6 ++
 backend/danswer/connectors/slack/connector.py | 75 +++++++++++++++++--
 2 files changed, 73 insertions(+), 8 deletions(-)

diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py
index 4845c69cc04..111ccd059f8 100644
--- a/backend/danswer/configs/app_configs.py
+++ b/backend/danswer/configs/app_configs.py
@@ -138,6 +138,12 @@
     DIRECTORY_CONNECTOR_MAX_BATCHES = int(os.environ.get("DIRECTORY_CONNECTOR_MAX_BATCHES", "5"))
 except TypeError as e:
     DIRECTORY_CONNECTOR_MAX_BATCHES = 5
+
+try:
+    SLACK_CONNECTOR_MAX_BATCHES = int(os.environ.get("SLACK_CONNECTOR_MAX_BATCHES", "5"))
+except ValueError:
+    SLACK_CONNECTOR_MAX_BATCHES = 5
+
 # TODO these should be available for frontend configuration, via advanced options expandable
 WEB_CONNECTOR_IGNORED_CLASSES = os.environ.get(
     "WEB_CONNECTOR_IGNORED_CLASSES", "sidebar,footer"
diff --git a/backend/danswer/connectors/slack/connector.py b/backend/danswer/connectors/slack/connector.py
index 42d026c2988..35f47a34c76 100644
--- a/backend/danswer/connectors/slack/connector.py
+++ b/backend/danswer/connectors/slack/connector.py
@@ -11,6 +11,7 @@
 from slack_sdk.web import SlackResponse
 
 from danswer.configs.app_configs import INDEX_BATCH_SIZE
+from danswer.configs.app_configs import SLACK_CONNECTOR_MAX_BATCHES
 from danswer.configs.constants import DocumentSource
 from danswer.connectors.interfaces import GenerateDocumentsOutput
 from danswer.connectors.interfaces import LoadConnector
@@ -24,6 +25,9 @@
 from danswer.connectors.slack.utils import make_slack_api_call_paginated
 from danswer.connectors.slack.utils import make_slack_api_rate_limited
 from danswer.connectors.slack.utils import UserIdReplacer
+from danswer.dynamic_configs import get_dynamic_config_store
+from danswer.dynamic_configs.interface import ConfigNotFoundError
+from danswer.dynamic_configs.interface import JSON_ro
 from danswer.utils.logger import setup_logger
 
 logger = setup_logger()
@@ -34,6 +38,8 @@
 # list of messages in a thread
 ThreadType = list[MessageType]
 
+LOAD_STATE_KEY = "slack_connector_state"
+MAX_BATCHES = SLACK_CONNECTOR_MAX_BATCHES
 
 def _make_paginated_slack_api_call(
     call: Callable[..., SlackResponse], **kwargs: Any
@@ -198,9 +204,8 @@ def _filter_channels(
 def get_all_docs(
     client: WebClient,
     workspace: str,
+    state: dict[str, dict[str, Any]],
     channels: list[str] | None = None,
-    oldest: str | None = None,
-    latest: str | None = None,
     msg_filter_func: Callable[[MessageType], bool] = _default_msg_filter,
 ) -> Generator[Document, None, None]:
     """Get all documents in the workspace, channel by channel"""
@@ -211,13 +216,43 @@ def get_all_docs(
 
     for channel in filtered_channels:
         channel_docs = 0
+
+        if channel["id"] in state:
+            channel_state = state[channel["id"]]
+        else:
+            channel_state = {
+                "name": channel["name"],
+                "oldest": None,
+                "latest": None,
+                "initial": True,
+            }
+            state[channel["id"]] = channel_state
+
+        initial = channel_state["initial"]
+
+        # If we're doing an initial import, we go backwards until we have
+        # imported all messages from the channel. Afterwards, we only pull
+        # messages that are newer than the latest message we've seen.
+        if initial:
+            oldest = None
+            latest = channel_state["oldest"]
+            logger.info(f"Running initial import of channel #{channel['name']}: oldest={oldest}, latest={latest}")
+        else:
+            oldest = channel_state["latest"]
+            latest = None
+            logger.info(f"Running incremental import of channel #{channel['name']}: oldest={oldest}, latest={latest}")
+
         channel_message_batches = get_channel_messages(
             client=client, channel=channel, oldest=oldest, latest=latest
         )
+        latest_ts = None
 
         seen_thread_ts: set[str] = set()
         for message_batch in channel_message_batches:
             for message in message_batch:
+                if latest_ts is None:
+                    latest_ts = message["ts"]
+
                 filtered_thread: ThreadType | None = None
                 thread_ts = message.get("thread_ts")
                 if thread_ts:
@@ -237,17 +272,28 @@ def get_all_docs(
 
                 if filtered_thread:
                     channel_docs += 1
-                    yield thread_to_doc(
+                    doc = thread_to_doc(
                         workspace=workspace,
                         channel=channel,
                         thread=filtered_thread,
                         user_id_replacer=user_id_replacer,
                     )
+                    text_length = sum(len(section.text) for section in doc.sections)
+                    if text_length != 0:
+                        yield doc
+                if initial:
+                    channel_state["oldest"] = message["ts"]
+                    if channel_state["latest"] is None:
+                        channel_state["latest"] = message["ts"]
+                else:
+                    channel_state["latest"] = latest_ts
 
         logger.info(
             f"Pulled {channel_docs} documents from slack channel {channel['name']}"
        )
+        channel_state["initial"] = False
+
 
 class SlackLoadConnector(LoadConnector):
     def __init__(
@@ -374,21 +420,34 @@ def poll_source(
         if self.client is None:
             raise ConnectorMissingCredentialError("Slack")
 
+        try:
+            state = cast(dict, get_dynamic_config_store().load(LOAD_STATE_KEY))
+        except ConfigNotFoundError:
+            state = {}
+
+        if "channels" not in state:
+            state["channels"] = {}
+        channels_state = state["channels"]
+
         documents: list[Document] = []
+        num_batches = 0
         for document in get_all_docs(
             client=self.client,
             workspace=self.workspace,
             channels=self.channels,
-            # NOTE: need to impute to `None` instead of using 0.0, since Slack will
-            # throw an error if we use 0.0 on an account without infinite data
-            # retention
-            oldest=str(start) if start else None,
-            latest=str(end),
+            state=channels_state,
         ):
             documents.append(document)
             if len(documents) >= self.batch_size:
+                logger.info(f"Yielding batch {num_batches + 1} of {self.batch_size} documents")
                 yield documents
                 documents = []
+                num_batches += 1
+                if num_batches >= MAX_BATCHES:
+                    logger.info(f"Reached max batches of {MAX_BATCHES}, stopping")
                    break
 
         if documents:
             yield documents
+
+        get_dynamic_config_store().store(LOAD_STATE_KEY, cast(JSON_ro, state))
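
For reference, the state persisted under LOAD_STATE_KEY roughly takes the
shape below. This is a sketch inferred from the hunks above, not output
of the code; the channel ID "C0123ABCD", the channel name, and the
timestamps are illustrative values:

    # Shape of the JSON stored via get_dynamic_config_store().store(LOAD_STATE_KEY, ...)
    # Slack "ts" values are strings of epoch seconds plus a microsecond suffix.
    state = {
        "channels": {
            "C0123ABCD": {                      # keyed by Slack channel ID
                "name": "general",
                "oldest": "1697097600.000100",  # backfill cursor: resume point while "initial" is True
                "latest": "1697101200.000200",  # newest message seen; incremental runs pull from here
                "initial": False,               # flips to False once the channel backfill completes
            },
        },
    }

With the defaults, a single poll_source() run stops after MAX_BATCHES (5)
yielded batches and writes the state back, so a large first import proceeds
in bounded chunks across runs; the cap can be raised through the new
environment variable, e.g. SLACK_CONNECTOR_MAX_BATCHES=20.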