From 28e2b78b2e4f5633bfebbf7f13edeae3b5f8ca15 Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Wed, 27 Nov 2024 08:10:07 -0800 Subject: [PATCH 001/133] Fix search dropdown (#3269) * validate dropdown * validate * update organization * move to utils --- web/src/components/Dropdown.tsx | 151 ++++++++++++++++---------------- web/src/components/Modal.tsx | 16 +++- web/src/lib/dropdown.ts | 49 +++++++++++ 3 files changed, 139 insertions(+), 77 deletions(-) create mode 100644 web/src/lib/dropdown.ts diff --git a/web/src/components/Dropdown.tsx b/web/src/components/Dropdown.tsx index e822632b4eb..2c5c4719efb 100644 --- a/web/src/components/Dropdown.tsx +++ b/web/src/components/Dropdown.tsx @@ -10,6 +10,8 @@ import { import { ChevronDownIcon } from "./icons/icons"; import { FiCheck, FiChevronDown } from "react-icons/fi"; import { Popover } from "./popover/Popover"; +import { createPortal } from "react-dom"; +import { useDropdownPosition } from "@/lib/dropdown"; export interface Option { name: string; @@ -60,6 +62,7 @@ export function SearchMultiSelectDropdown({ const [isOpen, setIsOpen] = useState(false); const [searchTerm, setSearchTerm] = useState(""); const dropdownRef = useRef(null); + const dropdownMenuRef = useRef(null); const handleSelect = (option: StringOrNumberOption) => { onSelect(option); @@ -75,7 +78,9 @@ export function SearchMultiSelectDropdown({ const handleClickOutside = (event: MouseEvent) => { if ( dropdownRef.current && - !dropdownRef.current.contains(event.target as Node) + !dropdownRef.current.contains(event.target as Node) && + dropdownMenuRef.current && + !dropdownMenuRef.current.contains(event.target as Node) ) { setIsOpen(false); } @@ -87,105 +92,103 @@ export function SearchMultiSelectDropdown({ }; }, []); + useDropdownPosition({ isOpen, dropdownRef, dropdownMenuRef }); + return ( -
+
) => { - if (!searchTerm) { + setSearchTerm(e.target.value); + if (e.target.value) { setIsOpen(true); - } - if (!e.target.value) { + } else { setIsOpen(false); } - setSearchTerm(e.target.value); }} onFocus={() => setIsOpen(true)} className={`inline-flex - justify-between - w-full - px-4 - py-2 - text-sm - bg-background - border - border-border - rounded-md - shadow-sm - `} - onClick={(e) => e.stopPropagation()} + justify-between + w-full + px-4 + py-2 + text-sm + bg-background + border + border-border + rounded-md + shadow-sm + `} />
- {isOpen && ( -
+ {isOpen && + createPortal(
- {filteredOptions.length ? ( - filteredOptions.map((option, index) => - itemComponent ? ( -
{ - setIsOpen(false); - handleSelect(option); - }} - > - {itemComponent({ option })} -
- ) : ( - +
+ {filteredOptions.length ? ( + filteredOptions.map((option, index) => + itemComponent ? ( +
{ + handleSelect(option); + }} + > + {itemComponent({ option })} +
+ ) : ( + + ) ) - ) - ) : ( - - )} -
-
- )} + ) : ( + + )} +
+
, + document.body + )}
); } diff --git a/web/src/components/Modal.tsx b/web/src/components/Modal.tsx index 4582ed8a558..05886975088 100644 --- a/web/src/components/Modal.tsx +++ b/web/src/components/Modal.tsx @@ -66,11 +66,21 @@ export function Modal({ e.stopPropagation(); } }} - className={`bg-background text-emphasis rounded shadow-2xl - transform transition-all duration-300 ease-in-out + className={` + bg-background + text-emphasis + rounded + shadow-2xl + transform + transition-all + duration-300 + ease-in-out + relative + overflow-visible ${width ?? "w-11/12 max-w-4xl"} ${noPadding ? "" : "p-10"} - ${className || ""}`} + ${className || ""} + `} > {onOutsideClick && !hideCloseButton && (
diff --git a/web/src/lib/dropdown.ts b/web/src/lib/dropdown.ts new file mode 100644 index 00000000000..b4fcf42d68e --- /dev/null +++ b/web/src/lib/dropdown.ts @@ -0,0 +1,49 @@ +import { RefObject, useCallback, useEffect } from "react"; + +interface DropdownPositionProps { + isOpen: boolean; + dropdownRef: RefObject; + dropdownMenuRef: RefObject; +} + +// This hook manages the positioning of a dropdown menu relative to its trigger element. +// It ensures the menu is positioned correctly, adjusting for viewport boundaries and scroll position. +// Also adds event listeners for window resize and scroll to update the position dynamically. +export const useDropdownPosition = ({ + isOpen, + dropdownRef, + dropdownMenuRef, +}: DropdownPositionProps) => { + const updateMenuPosition = useCallback(() => { + if (isOpen && dropdownRef.current && dropdownMenuRef.current) { + const rect = dropdownRef.current.getBoundingClientRect(); + const menuRect = dropdownMenuRef.current.getBoundingClientRect(); + const viewportHeight = window.innerHeight; + + let top = rect.bottom + window.scrollY; + + if (top + menuRect.height > viewportHeight) { + top = rect.top + window.scrollY - menuRect.height; + } + + dropdownMenuRef.current.style.position = "absolute"; + dropdownMenuRef.current.style.top = `${top}px`; + dropdownMenuRef.current.style.left = `${rect.left + window.scrollX}px`; + dropdownMenuRef.current.style.width = `${rect.width}px`; + dropdownMenuRef.current.style.zIndex = "10000"; + } + }, [isOpen, dropdownRef, dropdownMenuRef]); + + useEffect(() => { + updateMenuPosition(); + window.addEventListener("resize", updateMenuPosition); + window.addEventListener("scroll", updateMenuPosition); + + return () => { + window.removeEventListener("resize", updateMenuPosition); + window.removeEventListener("scroll", updateMenuPosition); + }; + }, [isOpen, updateMenuPosition]); + + return updateMenuPosition; +}; From 07dfde2209374bb90ee22b6074ab51c0c23235d7 Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Wed, 27 Nov 2024 10:25:38 -0800 Subject: [PATCH 002/133] add continue in danswer button to slack bot responses (#3239) * all done except routing * fixed initial changes * added backend endpoint for duplicating a chat session from Slack * got chat duplication routing done * got login routing working * improved answer handling * finished all checks * finished all! * made sure it works with google oauth * dont remove that lol * fixed weird thing * bad comments --- ...1b118_add_web_ui_option_to_slack_config.py | 35 +++ backend/danswer/danswerbot/slack/blocks.py | 231 +++++++++++++++--- backend/danswer/danswerbot/slack/constants.py | 1 + .../slack/handlers/handle_buttons.py | 4 +- .../slack/handlers/handle_message.py | 4 +- .../slack/handlers/handle_regular_answer.py | 68 +----- backend/danswer/danswerbot/slack/utils.py | 13 +- backend/danswer/db/chat.py | 106 ++++++++ backend/danswer/db/models.py | 1 + backend/danswer/db/persona.py | 25 ++ backend/danswer/main.py | 3 +- .../one_shot_answer/answer_question.py | 14 +- backend/danswer/one_shot_answer/qa_utils.py | 28 +++ backend/danswer/server/manage/models.py | 1 + backend/danswer/server/manage/slack_bot.py | 4 + .../server/query_and_chat/chat_backend.py | 34 +++ .../[bot-id]/SlackChannelConfigsTable.tsx | 22 +- .../SlackChannelConfigCreationForm.tsx | 14 +- web/src/app/admin/bots/[bot-id]/lib.ts | 2 + web/src/app/admin/bots/[bot-id]/page.tsx | 1 - web/src/app/auth/login/EmailPasswordForm.tsx | 4 +- web/src/app/auth/login/page.tsx | 21 +- web/src/app/auth/signup/page.tsx | 18 +- web/src/app/chat/ChatPage.tsx | 46 +++- web/src/app/chat/lib.tsx | 4 +- web/src/lib/types.ts | 1 + web/src/lib/userSS.ts | 31 ++- 27 files changed, 590 insertions(+), 146 deletions(-) create mode 100644 backend/alembic/versions/93560ba1b118_add_web_ui_option_to_slack_config.py diff --git a/backend/alembic/versions/93560ba1b118_add_web_ui_option_to_slack_config.py b/backend/alembic/versions/93560ba1b118_add_web_ui_option_to_slack_config.py new file mode 100644 index 00000000000..ab084aee314 --- /dev/null +++ b/backend/alembic/versions/93560ba1b118_add_web_ui_option_to_slack_config.py @@ -0,0 +1,35 @@ +"""add web ui option to slack config + +Revision ID: 93560ba1b118 +Revises: 6d562f86c78b +Create Date: 2024-11-24 06:36:17.490612 + +""" +from alembic import op + +# revision identifiers, used by Alembic. +revision = "93560ba1b118" +down_revision = "6d562f86c78b" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # Add show_continue_in_web_ui with default False to all existing channel_configs + op.execute( + """ + UPDATE slack_channel_config + SET channel_config = channel_config || '{"show_continue_in_web_ui": false}'::jsonb + WHERE NOT channel_config ? 'show_continue_in_web_ui' + """ + ) + + +def downgrade() -> None: + # Remove show_continue_in_web_ui from all channel_configs + op.execute( + """ + UPDATE slack_channel_config + SET channel_config = channel_config - 'show_continue_in_web_ui' + """ + ) diff --git a/backend/danswer/danswerbot/slack/blocks.py b/backend/danswer/danswerbot/slack/blocks.py index 1f689157452..a5e6868fd37 100644 --- a/backend/danswer/danswerbot/slack/blocks.py +++ b/backend/danswer/danswerbot/slack/blocks.py @@ -18,20 +18,30 @@ from danswer.chat.models import DanswerQuote from danswer.configs.app_configs import DISABLE_GENERATIVE_AI +from danswer.configs.app_configs import WEB_DOMAIN from danswer.configs.constants import DocumentSource from danswer.configs.constants import SearchFeedbackType from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_DOCS_TO_DISPLAY from danswer.context.search.models import SavedSearchDoc +from danswer.danswerbot.slack.constants import CONTINUE_IN_WEB_UI_ACTION_ID from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_ACTION_ID from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_RESOLVED_ACTION_ID from danswer.danswerbot.slack.constants import IMMEDIATE_RESOLVED_BUTTON_ACTION_ID from danswer.danswerbot.slack.constants import LIKE_BLOCK_ACTION_ID +from danswer.danswerbot.slack.formatting import format_slack_message from danswer.danswerbot.slack.icons import source_to_github_img_link +from danswer.danswerbot.slack.models import SlackMessageInfo +from danswer.danswerbot.slack.utils import build_continue_in_web_ui_id from danswer.danswerbot.slack.utils import build_feedback_id from danswer.danswerbot.slack.utils import remove_slack_text_interactions from danswer.danswerbot.slack.utils import translate_vespa_highlight_to_slack +from danswer.db.chat import get_chat_session_by_message_id +from danswer.db.engine import get_session_with_tenant +from danswer.db.models import ChannelConfig +from danswer.db.models import Persona +from danswer.one_shot_answer.models import OneShotQAResponse from danswer.utils.text_processing import decode_escapes from danswer.utils.text_processing import replace_whitespaces_w_space @@ -101,12 +111,12 @@ def _split_text(text: str, limit: int = 3000) -> list[str]: return chunks -def clean_markdown_link_text(text: str) -> str: +def _clean_markdown_link_text(text: str) -> str: # Remove any newlines within the text return text.replace("\n", " ").strip() -def build_qa_feedback_block( +def _build_qa_feedback_block( message_id: int, feedback_reminder_id: str | None = None ) -> Block: return ActionsBlock( @@ -115,7 +125,6 @@ def build_qa_feedback_block( ButtonElement( action_id=LIKE_BLOCK_ACTION_ID, text="👍 Helpful", - style="primary", value=feedback_reminder_id, ), ButtonElement( @@ -155,7 +164,7 @@ def get_document_feedback_blocks() -> Block: ) -def build_doc_feedback_block( +def _build_doc_feedback_block( message_id: int, document_id: str, document_rank: int, @@ -182,7 +191,7 @@ def get_restate_blocks( ] -def build_documents_blocks( +def _build_documents_blocks( documents: list[SavedSearchDoc], message_id: int | None, num_docs_to_display: int = DANSWER_BOT_NUM_DOCS_TO_DISPLAY, @@ -223,7 +232,7 @@ def build_documents_blocks( feedback: ButtonElement | dict = {} if message_id is not None: - feedback = build_doc_feedback_block( + feedback = _build_doc_feedback_block( message_id=message_id, document_id=d.document_id, document_rank=rank, @@ -241,7 +250,7 @@ def build_documents_blocks( return section_blocks -def build_sources_blocks( +def _build_sources_blocks( cited_documents: list[tuple[int, SavedSearchDoc]], num_docs_to_display: int = DANSWER_BOT_NUM_DOCS_TO_DISPLAY, ) -> list[Block]: @@ -286,7 +295,7 @@ def build_sources_blocks( + ([days_ago_str] if days_ago_str else []) ) - document_title = clean_markdown_link_text(doc_sem_id) + document_title = _clean_markdown_link_text(doc_sem_id) img_link = source_to_github_img_link(d.source_type) section_blocks.append( @@ -317,7 +326,50 @@ def build_sources_blocks( return section_blocks -def build_quotes_block( +def _priority_ordered_documents_blocks( + answer: OneShotQAResponse, +) -> list[Block]: + docs_response = answer.docs if answer.docs else None + top_docs = docs_response.top_documents if docs_response else [] + llm_doc_inds = answer.llm_selected_doc_indices or [] + llm_docs = [top_docs[i] for i in llm_doc_inds] + remaining_docs = [ + doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds + ] + priority_ordered_docs = llm_docs + remaining_docs + if not priority_ordered_docs: + return [] + + document_blocks = _build_documents_blocks( + documents=priority_ordered_docs, + message_id=answer.chat_message_id, + ) + if document_blocks: + document_blocks = [DividerBlock()] + document_blocks + return document_blocks + + +def _build_citations_blocks( + answer: OneShotQAResponse, +) -> list[Block]: + docs_response = answer.docs if answer.docs else None + top_docs = docs_response.top_documents if docs_response else [] + citations = answer.citations or [] + cited_docs = [] + for citation in citations: + matching_doc = next( + (d for d in top_docs if d.document_id == citation.document_id), + None, + ) + if matching_doc: + cited_docs.append((citation.citation_num, matching_doc)) + + cited_docs.sort() + citations_block = _build_sources_blocks(cited_documents=cited_docs) + return citations_block + + +def _build_quotes_block( quotes: list[DanswerQuote], ) -> list[Block]: quote_lines: list[str] = [] @@ -359,58 +411,70 @@ def build_quotes_block( return [SectionBlock(text="*Relevant Snippets*\n" + "\n".join(quote_lines))] -def build_qa_response_blocks( - message_id: int | None, - answer: str | None, - quotes: list[DanswerQuote] | None, - source_filters: list[DocumentSource] | None, - time_cutoff: datetime | None, - favor_recent: bool, +def _build_qa_response_blocks( + answer: OneShotQAResponse, skip_quotes: bool = False, process_message_for_citations: bool = False, - skip_ai_feedback: bool = False, - feedback_reminder_id: str | None = None, ) -> list[Block]: + retrieval_info = answer.docs + if not retrieval_info: + # This should not happen, even with no docs retrieved, there is still info returned + raise RuntimeError("Failed to retrieve docs, cannot answer question.") + + formatted_answer = format_slack_message(answer.answer) if answer.answer else None + quotes = answer.quotes.quotes if answer.quotes else None + if DISABLE_GENERATIVE_AI: return [] quotes_blocks: list[Block] = [] filter_block: Block | None = None - if time_cutoff or favor_recent or source_filters: + if ( + retrieval_info.applied_time_cutoff + or retrieval_info.recency_bias_multiplier > 1 + or retrieval_info.applied_source_filters + ): filter_text = "Filters: " - if source_filters: - sources_str = ", ".join([s.value for s in source_filters]) + if retrieval_info.applied_source_filters: + sources_str = ", ".join( + [s.value for s in retrieval_info.applied_source_filters] + ) filter_text += f"`Sources in [{sources_str}]`" - if time_cutoff or favor_recent: + if ( + retrieval_info.applied_time_cutoff + or retrieval_info.recency_bias_multiplier > 1 + ): filter_text += " and " - if time_cutoff is not None: - time_str = time_cutoff.strftime("%b %d, %Y") + if retrieval_info.applied_time_cutoff is not None: + time_str = retrieval_info.applied_time_cutoff.strftime("%b %d, %Y") filter_text += f"`Docs Updated >= {time_str}` " - if favor_recent: - if time_cutoff is not None: + if retrieval_info.recency_bias_multiplier > 1: + if retrieval_info.applied_time_cutoff is not None: filter_text += "+ " filter_text += "`Prioritize Recently Updated Docs`" filter_block = SectionBlock(text=f"_{filter_text}_") - if not answer: + if not formatted_answer: answer_blocks = [ SectionBlock( text="Sorry, I was unable to find an answer, but I did find some potentially relevant docs 🤓" ) ] else: - answer_processed = decode_escapes(remove_slack_text_interactions(answer)) + answer_processed = decode_escapes( + remove_slack_text_interactions(formatted_answer) + ) if process_message_for_citations: answer_processed = _process_citations_for_slack(answer_processed) answer_blocks = [ SectionBlock(text=text) for text in _split_text(answer_processed) ] if quotes: - quotes_blocks = build_quotes_block(quotes) + quotes_blocks = _build_quotes_block(quotes) - # if no quotes OR `build_quotes_block()` did not give back any blocks + # if no quotes OR `_build_quotes_block()` did not give back any blocks if not quotes_blocks: quotes_blocks = [ SectionBlock( @@ -425,20 +489,37 @@ def build_qa_response_blocks( response_blocks.extend(answer_blocks) - if message_id is not None and not skip_ai_feedback: - response_blocks.append( - build_qa_feedback_block( - message_id=message_id, feedback_reminder_id=feedback_reminder_id - ) - ) - if not skip_quotes: response_blocks.extend(quotes_blocks) return response_blocks -def build_follow_up_block(message_id: int | None) -> ActionsBlock: +def _build_continue_in_web_ui_block( + tenant_id: str | None, + message_id: int | None, +) -> Block: + if message_id is None: + raise ValueError("No message id provided to build continue in web ui block") + with get_session_with_tenant(tenant_id) as db_session: + chat_session = get_chat_session_by_message_id( + db_session=db_session, + message_id=message_id, + ) + return ActionsBlock( + block_id=build_continue_in_web_ui_id(message_id), + elements=[ + ButtonElement( + action_id=CONTINUE_IN_WEB_UI_ACTION_ID, + text="Continue Chat in Danswer!", + style="primary", + url=f"{WEB_DOMAIN}/chat?slackChatId={chat_session.id}", + ), + ], + ) + + +def _build_follow_up_block(message_id: int | None) -> ActionsBlock: return ActionsBlock( block_id=build_feedback_id(message_id) if message_id is not None else None, elements=[ @@ -483,3 +564,77 @@ def build_follow_up_resolved_blocks( ] ) return [text_block, button_block] + + +def build_slack_response_blocks( + tenant_id: str | None, + message_info: SlackMessageInfo, + answer: OneShotQAResponse, + persona: Persona | None, + channel_conf: ChannelConfig | None, + use_citations: bool, + feedback_reminder_id: str | None, + skip_ai_feedback: bool = False, +) -> list[Block]: + """ + This function is a top level function that builds all the blocks for the Slack response. + It also handles combining all the blocks together. + """ + # If called with the DanswerBot slash command, the question is lost so we have to reshow it + restate_question_block = get_restate_blocks( + message_info.thread_messages[-1].message, message_info.is_bot_msg + ) + + answer_blocks = _build_qa_response_blocks( + answer=answer, + skip_quotes=persona is not None or use_citations, + process_message_for_citations=use_citations, + ) + + web_follow_up_block = [] + if channel_conf and channel_conf.get("show_continue_in_web_ui"): + web_follow_up_block.append( + _build_continue_in_web_ui_block( + tenant_id=tenant_id, + message_id=answer.chat_message_id, + ) + ) + + follow_up_block = [] + if channel_conf and channel_conf.get("follow_up_tags") is not None: + follow_up_block.append( + _build_follow_up_block(message_id=answer.chat_message_id) + ) + + ai_feedback_block = [] + if answer.chat_message_id is not None and not skip_ai_feedback: + ai_feedback_block.append( + _build_qa_feedback_block( + message_id=answer.chat_message_id, + feedback_reminder_id=feedback_reminder_id, + ) + ) + + citations_blocks = [] + document_blocks = [] + if use_citations: + # if citations are enabled, only show cited documents + citations_blocks = _build_citations_blocks(answer) + else: + document_blocks = _priority_ordered_documents_blocks(answer) + + citations_divider = [DividerBlock()] if citations_blocks else [] + buttons_divider = [DividerBlock()] if web_follow_up_block or follow_up_block else [] + + all_blocks = ( + restate_question_block + + answer_blocks + + ai_feedback_block + + citations_divider + + citations_blocks + + document_blocks + + buttons_divider + + web_follow_up_block + + follow_up_block + ) + return all_blocks diff --git a/backend/danswer/danswerbot/slack/constants.py b/backend/danswer/danswerbot/slack/constants.py index cf2b38032c3..6a5b3ed43ed 100644 --- a/backend/danswer/danswerbot/slack/constants.py +++ b/backend/danswer/danswerbot/slack/constants.py @@ -2,6 +2,7 @@ LIKE_BLOCK_ACTION_ID = "feedback-like" DISLIKE_BLOCK_ACTION_ID = "feedback-dislike" +CONTINUE_IN_WEB_UI_ACTION_ID = "continue-in-web-ui" FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID = "feedback-doc-button" IMMEDIATE_RESOLVED_BUTTON_ACTION_ID = "immediate-resolved-button" FOLLOWUP_BUTTON_ACTION_ID = "followup-button" diff --git a/backend/danswer/danswerbot/slack/handlers/handle_buttons.py b/backend/danswer/danswerbot/slack/handlers/handle_buttons.py index ec423979941..9335b96874f 100644 --- a/backend/danswer/danswerbot/slack/handlers/handle_buttons.py +++ b/backend/danswer/danswerbot/slack/handlers/handle_buttons.py @@ -28,7 +28,7 @@ from danswer.danswerbot.slack.utils import build_feedback_id from danswer.danswerbot.slack.utils import decompose_action_id from danswer.danswerbot.slack.utils import fetch_group_ids_from_names -from danswer.danswerbot.slack.utils import fetch_user_ids_from_emails +from danswer.danswerbot.slack.utils import fetch_slack_user_ids_from_emails from danswer.danswerbot.slack.utils import get_channel_name_from_id from danswer.danswerbot.slack.utils import get_feedback_visibility from danswer.danswerbot.slack.utils import read_slack_thread @@ -267,7 +267,7 @@ def handle_followup_button( tag_names = slack_channel_config.channel_config.get("follow_up_tags") remaining = None if tag_names: - tag_ids, remaining = fetch_user_ids_from_emails( + tag_ids, remaining = fetch_slack_user_ids_from_emails( tag_names, client.web_client ) if remaining: diff --git a/backend/danswer/danswerbot/slack/handlers/handle_message.py b/backend/danswer/danswerbot/slack/handlers/handle_message.py index 6bec83def4b..1f19d0a70a6 100644 --- a/backend/danswer/danswerbot/slack/handlers/handle_message.py +++ b/backend/danswer/danswerbot/slack/handlers/handle_message.py @@ -13,7 +13,7 @@ handle_standard_answers, ) from danswer.danswerbot.slack.models import SlackMessageInfo -from danswer.danswerbot.slack.utils import fetch_user_ids_from_emails +from danswer.danswerbot.slack.utils import fetch_slack_user_ids_from_emails from danswer.danswerbot.slack.utils import fetch_user_ids_from_groups from danswer.danswerbot.slack.utils import respond_in_thread from danswer.danswerbot.slack.utils import slack_usage_report @@ -184,7 +184,7 @@ def handle_message( send_to: list[str] | None = None missing_users: list[str] | None = None if respond_member_group_list: - send_to, missing_ids = fetch_user_ids_from_emails( + send_to, missing_ids = fetch_slack_user_ids_from_emails( respond_member_group_list, client ) diff --git a/backend/danswer/danswerbot/slack/handlers/handle_regular_answer.py b/backend/danswer/danswerbot/slack/handlers/handle_regular_answer.py index 3d5f013dca8..926fd858243 100644 --- a/backend/danswer/danswerbot/slack/handlers/handle_regular_answer.py +++ b/backend/danswer/danswerbot/slack/handlers/handle_regular_answer.py @@ -7,7 +7,6 @@ from retry import retry from slack_sdk import WebClient -from slack_sdk.models.blocks import DividerBlock from slack_sdk.models.blocks import SectionBlock from danswer.configs.app_configs import DISABLE_GENERATIVE_AI @@ -25,12 +24,7 @@ from danswer.context.search.models import BaseFilters from danswer.context.search.models import RerankingDetails from danswer.context.search.models import RetrievalDetails -from danswer.danswerbot.slack.blocks import build_documents_blocks -from danswer.danswerbot.slack.blocks import build_follow_up_block -from danswer.danswerbot.slack.blocks import build_qa_response_blocks -from danswer.danswerbot.slack.blocks import build_sources_blocks -from danswer.danswerbot.slack.blocks import get_restate_blocks -from danswer.danswerbot.slack.formatting import format_slack_message +from danswer.danswerbot.slack.blocks import build_slack_response_blocks from danswer.danswerbot.slack.handlers.utils import send_team_member_message from danswer.danswerbot.slack.models import SlackMessageInfo from danswer.danswerbot.slack.utils import respond_in_thread @@ -411,62 +405,16 @@ def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse | Non ) return True - # If called with the DanswerBot slash command, the question is lost so we have to reshow it - restate_question_block = get_restate_blocks(messages[-1].message, is_bot_msg) - formatted_answer = format_slack_message(answer.answer) if answer.answer else None - - answer_blocks = build_qa_response_blocks( - message_id=answer.chat_message_id, - answer=formatted_answer, - quotes=answer.quotes.quotes if answer.quotes else None, - source_filters=retrieval_info.applied_source_filters, - time_cutoff=retrieval_info.applied_time_cutoff, - favor_recent=retrieval_info.recency_bias_multiplier > 1, - # currently Personas don't support quotes - # if citations are enabled, also don't use quotes - skip_quotes=persona is not None or use_citations, - process_message_for_citations=use_citations, + all_blocks = build_slack_response_blocks( + tenant_id=tenant_id, + message_info=message_info, + answer=answer, + persona=persona, + channel_conf=channel_conf, + use_citations=use_citations, feedback_reminder_id=feedback_reminder_id, ) - # Get the chunks fed to the LLM only, then fill with other docs - llm_doc_inds = answer.llm_selected_doc_indices or [] - llm_docs = [top_docs[i] for i in llm_doc_inds] - remaining_docs = [ - doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds - ] - priority_ordered_docs = llm_docs + remaining_docs - - document_blocks = [] - citations_block = [] - # if citations are enabled, only show cited documents - if use_citations: - citations = answer.citations or [] - cited_docs = [] - for citation in citations: - matching_doc = next( - (d for d in top_docs if d.document_id == citation.document_id), - None, - ) - if matching_doc: - cited_docs.append((citation.citation_num, matching_doc)) - - cited_docs.sort() - citations_block = build_sources_blocks(cited_documents=cited_docs) - elif priority_ordered_docs: - document_blocks = build_documents_blocks( - documents=priority_ordered_docs, - message_id=answer.chat_message_id, - ) - document_blocks = [DividerBlock()] + document_blocks - - all_blocks = ( - restate_question_block + answer_blocks + citations_block + document_blocks - ) - - if channel_conf and channel_conf.get("follow_up_tags") is not None: - all_blocks.append(build_follow_up_block(message_id=answer.chat_message_id)) - try: respond_in_thread( client=client, diff --git a/backend/danswer/danswerbot/slack/utils.py b/backend/danswer/danswerbot/slack/utils.py index e19ce8b688c..cf6f1e1bfc8 100644 --- a/backend/danswer/danswerbot/slack/utils.py +++ b/backend/danswer/danswerbot/slack/utils.py @@ -3,9 +3,9 @@ import re import string import time +import uuid from typing import Any from typing import cast -from typing import Optional from retry import retry from slack_sdk import WebClient @@ -216,6 +216,13 @@ def build_feedback_id( return unique_prefix + ID_SEPARATOR + feedback_id +def build_continue_in_web_ui_id( + message_id: int, +) -> str: + unique_prefix = str(uuid.uuid4())[:10] + return unique_prefix + ID_SEPARATOR + str(message_id) + + def decompose_action_id(feedback_id: str) -> tuple[int, str | None, int | None]: """Decompose into query_id, document_id, document_rank, see above function""" try: @@ -313,7 +320,7 @@ def get_channel_name_from_id( raise e -def fetch_user_ids_from_emails( +def fetch_slack_user_ids_from_emails( user_emails: list[str], client: WebClient ) -> tuple[list[str], list[str]]: user_ids: list[str] = [] @@ -522,7 +529,7 @@ def refill(self) -> None: self.last_reset_time = time.time() def notify( - self, client: WebClient, channel: str, position: int, thread_ts: Optional[str] + self, client: WebClient, channel: str, position: int, thread_ts: str | None ) -> None: respond_in_thread( client=client, diff --git a/backend/danswer/db/chat.py b/backend/danswer/db/chat.py index a76fcccdd8d..73d0a886f45 100644 --- a/backend/danswer/db/chat.py +++ b/backend/danswer/db/chat.py @@ -3,6 +3,7 @@ from datetime import timedelta from uuid import UUID +from fastapi import HTTPException from sqlalchemy import delete from sqlalchemy import desc from sqlalchemy import func @@ -30,6 +31,7 @@ from danswer.db.models import SearchDoc as DBSearchDoc from danswer.db.models import ToolCall from danswer.db.models import User +from danswer.db.persona import get_best_persona_id_for_user from danswer.db.pg_file_store import delete_lobj_by_name from danswer.file_store.models import FileDescriptor from danswer.llm.override_models import LLMOverride @@ -250,6 +252,50 @@ def create_chat_session( return chat_session +def duplicate_chat_session_for_user_from_slack( + db_session: Session, + user: User | None, + chat_session_id: UUID, +) -> ChatSession: + """ + This takes a chat session id for a session in Slack and: + - Creates a new chat session in the DB + - Tries to copy the persona from the original chat session + (if it is available to the user clicking the button) + - Sets the user to the given user (if provided) + """ + chat_session = get_chat_session_by_id( + chat_session_id=chat_session_id, + user_id=None, # Ignore user permissions for this + db_session=db_session, + ) + if not chat_session: + raise HTTPException(status_code=400, detail="Invalid Chat Session ID provided") + + # This enforces permissions and sets a default + new_persona_id = get_best_persona_id_for_user( + db_session=db_session, + user=user, + persona_id=chat_session.persona_id, + ) + + return create_chat_session( + db_session=db_session, + user_id=user.id if user else None, + persona_id=new_persona_id, + # Set this to empty string so the frontend will force a rename + description="", + llm_override=chat_session.llm_override, + prompt_override=chat_session.prompt_override, + # Chat sessions from Slack should put people in the chat UI, not the search + one_shot=False, + # Chat is in UI now so this is false + danswerbot_flow=False, + # Maybe we want this in the future to track if it was created from Slack + slack_thread_id=None, + ) + + def update_chat_session( db_session: Session, user_id: UUID | None, @@ -336,6 +382,28 @@ def get_chat_message( return chat_message +def get_chat_session_by_message_id( + db_session: Session, + message_id: int, +) -> ChatSession: + """ + Should only be used for Slack + Get the chat session associated with a specific message ID + Note: this ignores permission checks. + """ + stmt = select(ChatMessage).where(ChatMessage.id == message_id) + + result = db_session.execute(stmt) + chat_message = result.scalar_one_or_none() + + if chat_message is None: + raise ValueError( + f"Unable to find chat session associated with message ID: {message_id}" + ) + + return chat_message.chat_session + + def get_chat_messages_by_sessions( chat_session_ids: list[UUID], user_id: UUID | None, @@ -355,6 +423,44 @@ def get_chat_messages_by_sessions( return db_session.execute(stmt).scalars().all() +def add_chats_to_session_from_slack_thread( + db_session: Session, + slack_chat_session_id: UUID, + new_chat_session_id: UUID, +) -> None: + new_root_message = get_or_create_root_message( + chat_session_id=new_chat_session_id, + db_session=db_session, + ) + + for chat_message in get_chat_messages_by_sessions( + chat_session_ids=[slack_chat_session_id], + user_id=None, # Ignore user permissions for this + db_session=db_session, + skip_permission_check=True, + ): + if chat_message.message_type == MessageType.SYSTEM: + continue + # Duplicate the message + new_root_message = create_new_chat_message( + db_session=db_session, + chat_session_id=new_chat_session_id, + parent_message=new_root_message, + message=chat_message.message, + files=chat_message.files, + rephrased_query=chat_message.rephrased_query, + error=chat_message.error, + citations=chat_message.citations, + reference_docs=chat_message.search_docs, + tool_call=chat_message.tool_call, + prompt_id=chat_message.prompt_id, + token_count=chat_message.token_count, + message_type=chat_message.message_type, + alternate_assistant_id=chat_message.alternate_assistant_id, + overridden_model=chat_message.overridden_model, + ) + + def get_search_docs_for_chat_message( chat_message_id: int, db_session: Session ) -> list[SearchDoc]: diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index 76e70c2d2d9..4e1970a7bd2 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -1480,6 +1480,7 @@ class ChannelConfig(TypedDict): # If None then no follow up # If empty list, follow up with no tags follow_up_tags: NotRequired[list[str]] + show_continue_in_web_ui: NotRequired[bool] # defaults to False class SlackBotResponseType(str, PyEnum): diff --git a/backend/danswer/db/persona.py b/backend/danswer/db/persona.py index 98a50d50e9d..b71df22181e 100644 --- a/backend/danswer/db/persona.py +++ b/backend/danswer/db/persona.py @@ -113,6 +113,31 @@ def fetch_persona_by_id( return persona +def get_best_persona_id_for_user( + db_session: Session, user: User | None, persona_id: int | None = None +) -> int | None: + if persona_id is not None: + stmt = select(Persona).where(Persona.id == persona_id).distinct() + stmt = _add_user_filters( + stmt=stmt, + user=user, + # We don't want to filter by editable here, we just want to see if the + # persona is usable by the user + get_editable=False, + ) + persona = db_session.scalars(stmt).one_or_none() + if persona: + return persona.id + + # If the persona is not found, or the slack bot is using doc sets instead of personas, + # we need to find the best persona for the user + # This is the persona with the highest display priority that the user has access to + stmt = select(Persona).order_by(Persona.display_priority.desc()).distinct() + stmt = _add_user_filters(stmt=stmt, user=user, get_editable=True) + persona = db_session.scalars(stmt).one_or_none() + return persona.id if persona else None + + def _get_persona_by_name( persona_name: str, user: User | None, db_session: Session ) -> Persona | None: diff --git a/backend/danswer/main.py b/backend/danswer/main.py index 3fd7072bb9a..a8fe531f7d5 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -26,6 +26,7 @@ from danswer.auth.schemas import UserUpdate from danswer.auth.users import auth_backend from danswer.auth.users import BasicAuthenticationError +from danswer.auth.users import create_danswer_oauth_router from danswer.auth.users import fastapi_users from danswer.configs.app_configs import APP_API_PREFIX from danswer.configs.app_configs import APP_HOST @@ -323,7 +324,7 @@ def get_application() -> FastAPI: oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET) include_router_with_global_prefix_prepended( application, - fastapi_users.get_oauth_router( + create_danswer_oauth_router( oauth_client, auth_backend, USER_AUTH_SECRET, diff --git a/backend/danswer/one_shot_answer/answer_question.py b/backend/danswer/one_shot_answer/answer_question.py index 9f8ce99231b..826673acb0d 100644 --- a/backend/danswer/one_shot_answer/answer_question.py +++ b/backend/danswer/one_shot_answer/answer_question.py @@ -47,6 +47,7 @@ from danswer.one_shot_answer.models import OneShotQAResponse from danswer.one_shot_answer.models import QueryRephrase from danswer.one_shot_answer.qa_utils import combine_message_thread +from danswer.one_shot_answer.qa_utils import slackify_message_thread from danswer.secondary_llm_flows.answer_validation import get_answer_validity from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephrase from danswer.server.query_and_chat.models import ChatMessageDetail @@ -194,13 +195,22 @@ def stream_answer_objects( ) prompt = persona.prompts[0] + user_message_str = query_msg.message + # For this endpoint, we only save one user message to the chat session + # However, for slackbot, we want to include the history of the entire thread + if danswerbot_flow: + # Right now, we only support bringing over citations and search docs + # from the last message in the thread, not the entire thread + # in the future, we may want to retrieve the entire thread + user_message_str = slackify_message_thread(query_req.messages) + # Create the first User query message new_user_message = create_new_chat_message( chat_session_id=chat_session.id, parent_message=root_message, prompt_id=query_req.prompt_id, - message=query_msg.message, - token_count=len(llm_tokenizer.encode(query_msg.message)), + message=user_message_str, + token_count=len(llm_tokenizer.encode(user_message_str)), message_type=MessageType.USER, db_session=db_session, commit=True, diff --git a/backend/danswer/one_shot_answer/qa_utils.py b/backend/danswer/one_shot_answer/qa_utils.py index 6fbad99eff1..8770a3b1413 100644 --- a/backend/danswer/one_shot_answer/qa_utils.py +++ b/backend/danswer/one_shot_answer/qa_utils.py @@ -51,3 +51,31 @@ def combine_message_thread( total_token_count += message_token_count return "\n\n".join(message_strs) + + +def slackify_message(message: ThreadMessage) -> str: + if message.role != MessageType.USER: + return message.message + + return f"{message.sender or 'Unknown User'} said in Slack:\n{message.message}" + + +def slackify_message_thread(messages: list[ThreadMessage]) -> str: + if not messages: + return "" + + message_strs: list[str] = [] + for message in messages: + if message.role == MessageType.USER: + message_text = ( + f"{message.sender or 'Unknown User'} said in Slack:\n{message.message}" + ) + elif message.role == MessageType.ASSISTANT: + message_text = f"DanswerBot said in Slack:\n{message.message}" + else: + message_text = ( + f"{message.role.value.upper()} said in Slack:\n{message.message}" + ) + message_strs.append(message_text) + + return "\n\n".join(message_strs) diff --git a/backend/danswer/server/manage/models.py b/backend/danswer/server/manage/models.py index 74a3a774e21..9c2960741f3 100644 --- a/backend/danswer/server/manage/models.py +++ b/backend/danswer/server/manage/models.py @@ -156,6 +156,7 @@ class SlackChannelConfigCreationRequest(BaseModel): channel_name: str respond_tag_only: bool = False respond_to_bots: bool = False + show_continue_in_web_ui: bool = False enable_auto_filters: bool = False # If no team members, assume respond in the channel to everyone respond_member_group_list: list[str] = Field(default_factory=list) diff --git a/backend/danswer/server/manage/slack_bot.py b/backend/danswer/server/manage/slack_bot.py index 036f2fca0dd..60a7edaaed0 100644 --- a/backend/danswer/server/manage/slack_bot.py +++ b/backend/danswer/server/manage/slack_bot.py @@ -80,6 +80,10 @@ def _form_channel_config( if follow_up_tags is not None: channel_config["follow_up_tags"] = follow_up_tags + channel_config[ + "show_continue_in_web_ui" + ] = slack_channel_config_creation_request.show_continue_in_web_ui + channel_config[ "respond_to_bots" ] = slack_channel_config_creation_request.respond_to_bots diff --git a/backend/danswer/server/query_and_chat/chat_backend.py b/backend/danswer/server/query_and_chat/chat_backend.py index c4728336c86..954728c32a3 100644 --- a/backend/danswer/server/query_and_chat/chat_backend.py +++ b/backend/danswer/server/query_and_chat/chat_backend.py @@ -27,9 +27,11 @@ from danswer.configs.constants import FileOrigin from danswer.configs.constants import MessageType from danswer.configs.model_configs import LITELLM_PASS_THROUGH_HEADERS +from danswer.db.chat import add_chats_to_session_from_slack_thread from danswer.db.chat import create_chat_session from danswer.db.chat import create_new_chat_message from danswer.db.chat import delete_chat_session +from danswer.db.chat import duplicate_chat_session_for_user_from_slack from danswer.db.chat import get_chat_message from danswer.db.chat import get_chat_messages_by_session from danswer.db.chat import get_chat_session_by_id @@ -532,6 +534,38 @@ def seed_chat( ) +class SeedChatFromSlackRequest(BaseModel): + chat_session_id: UUID + + +class SeedChatFromSlackResponse(BaseModel): + redirect_url: str + + +@router.post("/seed-chat-session-from-slack") +def seed_chat_from_slack( + chat_seed_request: SeedChatFromSlackRequest, + user: User | None = Depends(current_user), + db_session: Session = Depends(get_session), +) -> SeedChatFromSlackResponse: + slack_chat_session_id = chat_seed_request.chat_session_id + new_chat_session = duplicate_chat_session_for_user_from_slack( + db_session=db_session, + user=user, + chat_session_id=slack_chat_session_id, + ) + + add_chats_to_session_from_slack_thread( + db_session=db_session, + slack_chat_session_id=slack_chat_session_id, + new_chat_session_id=new_chat_session.id, + ) + + return SeedChatFromSlackResponse( + redirect_url=f"{WEB_DOMAIN}/chat?chatId={new_chat_session.id}" + ) + + """File upload""" diff --git a/web/src/app/admin/bots/[bot-id]/SlackChannelConfigsTable.tsx b/web/src/app/admin/bots/[bot-id]/SlackChannelConfigsTable.tsx index 632e41aa375..1f99b7ca214 100644 --- a/web/src/app/admin/bots/[bot-id]/SlackChannelConfigsTable.tsx +++ b/web/src/app/admin/bots/[bot-id]/SlackChannelConfigsTable.tsx @@ -60,21 +60,24 @@ export function SlackChannelConfigsTable({ .slice(numToDisplay * (page - 1), numToDisplay * page) .map((slackChannelConfig) => { return ( - + { + window.location.href = `/admin/bots/${slackBotId}/channels/${slackChannelConfig.id}`; + }} + >
- +
- +
{"#" + slackChannelConfig.channel_config.channel_name}
- + e.stopPropagation()}> {slackChannelConfig.persona && !isPersonaASlackBotPersona(slackChannelConfig.persona) ? ( - + e.stopPropagation()}>
{ + onClick={async (e) => { + e.stopPropagation(); const response = await deleteSlackChannelConfig( slackChannelConfig.id ); diff --git a/web/src/app/admin/bots/[bot-id]/channels/SlackChannelConfigCreationForm.tsx b/web/src/app/admin/bots/[bot-id]/channels/SlackChannelConfigCreationForm.tsx index 9a8caad2ad5..5b51e8cf61d 100644 --- a/web/src/app/admin/bots/[bot-id]/channels/SlackChannelConfigCreationForm.tsx +++ b/web/src/app/admin/bots/[bot-id]/channels/SlackChannelConfigCreationForm.tsx @@ -81,6 +81,11 @@ export const SlackChannelConfigCreationForm = ({ respond_to_bots: existingSlackChannelConfig?.channel_config?.respond_to_bots || false, + show_continue_in_web_ui: + // If we're updating, we want to keep the existing value + // Otherwise, we want to default to true + existingSlackChannelConfig?.channel_config + ?.show_continue_in_web_ui ?? !isUpdate, enable_auto_filters: existingSlackChannelConfig?.enable_auto_filters || false, respond_member_group_list: @@ -119,6 +124,7 @@ export const SlackChannelConfigCreationForm = ({ questionmark_prefilter_enabled: Yup.boolean().required(), respond_tag_only: Yup.boolean().required(), respond_to_bots: Yup.boolean().required(), + show_continue_in_web_ui: Yup.boolean().required(), enable_auto_filters: Yup.boolean().required(), respond_member_group_list: Yup.array().of(Yup.string()).required(), still_need_help_enabled: Yup.boolean().required(), @@ -270,7 +276,13 @@ export const SlackChannelConfigCreationForm = ({ {showAdvancedOptions && (
-
+ +
{ const searchParams = await props.searchParams; const autoRedirectDisabled = searchParams?.disableAutoRedirect === "true"; + const nextUrl = Array.isArray(searchParams?.next) + ? searchParams?.next[0] + : searchParams?.next || null; // catch cases where the backend is completely unreachable here // without try / catch, will just raise an exception and the page @@ -37,10 +40,6 @@ const Page = async (props: { console.log(`Some fetch failed for the login page - ${e}`); } - const nextUrl = Array.isArray(searchParams?.next) - ? searchParams?.next[0] - : searchParams?.next || null; - // simply take the user to the home page if Auth is disabled if (authTypeMetadata?.authType === "disabled") { return redirect("/"); @@ -100,12 +99,15 @@ const Page = async (props: { or
- +
Don't have an account?{" "} - + Create an account @@ -120,11 +122,14 @@ const Page = async (props: {
- +
Don't have an account?{" "} - + Create an account diff --git a/web/src/app/auth/signup/page.tsx b/web/src/app/auth/signup/page.tsx index 223faff331d..94a7d1967bb 100644 --- a/web/src/app/auth/signup/page.tsx +++ b/web/src/app/auth/signup/page.tsx @@ -15,7 +15,14 @@ import AuthFlowContainer from "@/components/auth/AuthFlowContainer"; import ReferralSourceSelector from "./ReferralSourceSelector"; import { Separator } from "@/components/ui/separator"; -const Page = async () => { +const Page = async (props: { + searchParams?: Promise<{ [key: string]: string | string[] | undefined }>; +}) => { + const searchParams = await props.searchParams; + const nextUrl = Array.isArray(searchParams?.next) + ? searchParams?.next[0] + : searchParams?.next || null; + // catch cases where the backend is completely unreachable here // without try / catch, will just raise an exception and the page // will not render @@ -86,12 +93,19 @@ const Page = async () => {
Already have an account?{" "} - + Log In diff --git a/web/src/app/chat/ChatPage.tsx b/web/src/app/chat/ChatPage.tsx index 0bb3ebfa965..94f336ba885 100644 --- a/web/src/app/chat/ChatPage.tsx +++ b/web/src/app/chat/ChatPage.tsx @@ -161,6 +161,8 @@ export function ChatPage({ const { user, isAdmin, isLoadingUser, refreshUser } = useUser(); + const slackChatId = searchParams.get("slackChatId"); + const existingChatIdRaw = searchParams.get("chatId"); const [sendOnLoad, setSendOnLoad] = useState( searchParams.get(SEARCH_PARAM_NAMES.SEND_ON_LOAD) @@ -403,6 +405,7 @@ export function ChatPage({ } return; } + setIsReady(true); const shouldScrollToBottom = visibleRange.get(existingChatSessionId) === undefined || visibleRange.get(existingChatSessionId)?.end == 0; @@ -468,9 +471,12 @@ export function ChatPage({ }); // force re-name if the chat session doesn't have one if (!chatSession.description) { - await nameChatSession(existingChatSessionId, seededMessage); + await nameChatSession(existingChatSessionId); refreshChatSessions(); } + } else if (newMessageHistory.length === 2 && !chatSession.description) { + await nameChatSession(existingChatSessionId); + refreshChatSessions(); } } @@ -1428,7 +1434,7 @@ export function ChatPage({ if (!searchParamBasedChatSessionName) { await new Promise((resolve) => setTimeout(resolve, 200)); - await nameChatSession(currChatSessionId, currMessage); + await nameChatSession(currChatSessionId); refreshChatSessions(); } @@ -1810,6 +1816,42 @@ export function ChatPage({ }; } + useEffect(() => { + const handleSlackChatRedirect = async () => { + if (!slackChatId) return; + + // Set isReady to false before starting retrieval to display loading text + setIsReady(false); + + try { + const response = await fetch("/api/chat/seed-chat-session-from-slack", { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + chat_session_id: slackChatId, + }), + }); + + if (!response.ok) { + throw new Error("Failed to seed chat from Slack"); + } + + const data = await response.json(); + router.push(data.redirect_url); + } catch (error) { + console.error("Error seeding chat from Slack:", error); + setPopup({ + message: "Failed to load chat from Slack", + type: "error", + }); + } + }; + + handleSlackChatRedirect(); + }, [searchParams, router]); + return ( <> diff --git a/web/src/app/chat/lib.tsx b/web/src/app/chat/lib.tsx index a64c605a095..00529776407 100644 --- a/web/src/app/chat/lib.tsx +++ b/web/src/app/chat/lib.tsx @@ -203,7 +203,7 @@ export async function* sendMessage({ yield* handleSSEStream(response); } -export async function nameChatSession(chatSessionId: string, message: string) { +export async function nameChatSession(chatSessionId: string) { const response = await fetch("/api/chat/rename-chat-session", { method: "PUT", headers: { @@ -212,7 +212,6 @@ export async function nameChatSession(chatSessionId: string, message: string) { body: JSON.stringify({ chat_session_id: chatSessionId, name: null, - first_message: message, }), }); return response; @@ -263,7 +262,6 @@ export async function renameChatSession( body: JSON.stringify({ chat_session_id: chatSessionId, name: newName, - first_message: null, }), }); return response; diff --git a/web/src/lib/types.ts b/web/src/lib/types.ts index 8ea6047dd1a..7fe1402c5f7 100644 --- a/web/src/lib/types.ts +++ b/web/src/lib/types.ts @@ -208,6 +208,7 @@ export interface ChannelConfig { channel_name: string; respond_tag_only?: boolean; respond_to_bots?: boolean; + show_continue_in_web_ui?: boolean; respond_member_group_list?: string[]; answer_filters?: AnswerFilterOption[]; follow_up_tags?: string[]; diff --git a/web/src/lib/userSS.ts b/web/src/lib/userSS.ts index 906f23fa8b2..b0c9609391f 100644 --- a/web/src/lib/userSS.ts +++ b/web/src/lib/userSS.ts @@ -62,12 +62,17 @@ const getOIDCAuthUrlSS = async (nextUrl: string | null): Promise => { return data.authorization_url; }; -const getGoogleOAuthUrlSS = async (): Promise => { - const res = await fetch(buildUrl(`/auth/oauth/authorize`), { - headers: { - cookie: processCookies(await cookies()), - }, - }); +const getGoogleOAuthUrlSS = async (nextUrl: string | null): Promise => { + const res = await fetch( + buildUrl( + `/auth/oauth/authorize${nextUrl ? `?next=${encodeURIComponent(nextUrl)}` : ""}` + ), + { + headers: { + cookie: processCookies(await cookies()), + }, + } + ); if (!res.ok) { throw new Error("Failed to fetch data"); } @@ -76,8 +81,12 @@ const getGoogleOAuthUrlSS = async (): Promise => { return data.authorization_url; }; -const getSAMLAuthUrlSS = async (): Promise => { - const res = await fetch(buildUrl("/auth/saml/authorize")); +const getSAMLAuthUrlSS = async (nextUrl: string | null): Promise => { + const res = await fetch( + buildUrl( + `/auth/saml/authorize${nextUrl ? `?next=${encodeURIComponent(nextUrl)}` : ""}` + ) + ); if (!res.ok) { throw new Error("Failed to fetch data"); } @@ -97,13 +106,13 @@ export const getAuthUrlSS = async ( case "basic": return ""; case "google_oauth": { - return await getGoogleOAuthUrlSS(); + return await getGoogleOAuthUrlSS(nextUrl); } case "cloud": { - return await getGoogleOAuthUrlSS(); + return await getGoogleOAuthUrlSS(nextUrl); } case "saml": { - return await getSAMLAuthUrlSS(); + return await getSAMLAuthUrlSS(nextUrl); } case "oidc": { return await getOIDCAuthUrlSS(nextUrl); From 9c0cc94f15532624d1b843c4595cd0845f582344 Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Wed, 27 Nov 2024 11:11:58 -0800 Subject: [PATCH 003/133] refresh router -> refresh assistants (#3271) --- web/src/app/admin/assistants/AssistantEditor.tsx | 2 +- web/src/app/admin/assistants/PersonaTable.tsx | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/web/src/app/admin/assistants/AssistantEditor.tsx b/web/src/app/admin/assistants/AssistantEditor.tsx index 650c4d199a8..f2fb85d1b79 100644 --- a/web/src/app/admin/assistants/AssistantEditor.tsx +++ b/web/src/app/admin/assistants/AssistantEditor.tsx @@ -405,7 +405,7 @@ export function AssistantEditor({ message: `"${assistant.name}" has been added to your list.`, type: "success", }); - router.refresh(); + await refreshAssistants(); } else { setPopup({ message: `"${assistant.name}" could not be added to your list.`, diff --git a/web/src/app/admin/assistants/PersonaTable.tsx b/web/src/app/admin/assistants/PersonaTable.tsx index c5dcfc2690e..e451a519b49 100644 --- a/web/src/app/admin/assistants/PersonaTable.tsx +++ b/web/src/app/admin/assistants/PersonaTable.tsx @@ -90,7 +90,7 @@ export function PersonasTable() { message: `Failed to update persona order - ${await response.text()}`, }); setFinalPersonas(assistants); - router.refresh(); + await refreshAssistants(); return; } @@ -151,7 +151,7 @@ export function PersonasTable() { persona.is_visible ); if (response.ok) { - router.refresh(); + await refreshAssistants(); } else { setPopup({ type: "error", @@ -183,7 +183,7 @@ export function PersonasTable() { onClick={async () => { const response = await deletePersona(persona.id); if (response.ok) { - router.refresh(); + await refreshAssistants(); } else { alert( `Failed to delete persona - ${await response.text()}` From 09d3e47c03b758a8c93dc922432b1c11d7d42a56 Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Wed, 27 Nov 2024 12:04:15 -0800 Subject: [PATCH 004/133] Perm sync behavior change (#3262) * Change external permissions behavior * fixed behavior * added error handling * LLM the goat * comment * simplify * fixed * done * limits increased * added a ton of logging * uhhhh --- .../tasks/doc_permission_syncing/tasks.py | 6 +- .../tasks/external_group_syncing/tasks.py | 2 +- .../connectors/confluence/connector.py | 9 +- .../connectors/confluence/onyx_confluence.py | 2 +- .../danswer/db/connector_credential_pair.py | 16 ++- backend/ee/danswer/db/user_group.py | 12 +- .../confluence/doc_sync.py | 14 +- .../confluence/group_sync.py | 5 +- .../external_permissions/sync_params.py | 6 +- .../slack/test_permission_sync.py | 122 ++++++++++++++++++ 10 files changed, 170 insertions(+), 24 deletions(-) diff --git a/backend/danswer/background/celery/tasks/doc_permission_syncing/tasks.py b/backend/danswer/background/celery/tasks/doc_permission_syncing/tasks.py index 6a5761a7428..eef14e980ca 100644 --- a/backend/danswer/background/celery/tasks/doc_permission_syncing/tasks.py +++ b/backend/danswer/background/celery/tasks/doc_permission_syncing/tasks.py @@ -241,9 +241,11 @@ def connector_permission_sync_generator_task( doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type) if doc_sync_func is None: - raise ValueError(f"No doc sync func found for {source_type}") + raise ValueError( + f"No doc sync func found for {source_type} with cc_pair={cc_pair_id}" + ) - logger.info(f"Syncing docs for {source_type}") + logger.info(f"Syncing docs for {source_type} with cc_pair={cc_pair_id}") payload = RedisConnectorPermissionSyncData( started=datetime.now(timezone.utc), diff --git a/backend/danswer/background/celery/tasks/external_group_syncing/tasks.py b/backend/danswer/background/celery/tasks/external_group_syncing/tasks.py index 61ceae4e463..8381656ee17 100644 --- a/backend/danswer/background/celery/tasks/external_group_syncing/tasks.py +++ b/backend/danswer/background/celery/tasks/external_group_syncing/tasks.py @@ -49,7 +49,7 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool: if cc_pair.access_type != AccessType.SYNC: return False - # skip pruning if not active + # skip external group sync if not active if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE: return False diff --git a/backend/danswer/connectors/confluence/connector.py b/backend/danswer/connectors/confluence/connector.py index 0e09a4aed61..e30c85922ce 100644 --- a/backend/danswer/connectors/confluence/connector.py +++ b/backend/danswer/connectors/confluence/connector.py @@ -51,7 +51,7 @@ "restrictions.read.restrictions.group", ] -_SLIM_DOC_BATCH_SIZE = 1000 +_SLIM_DOC_BATCH_SIZE = 5000 class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector): @@ -301,5 +301,8 @@ def retrieve_all_slim_documents( perm_sync_data=perm_sync_data, ) ) - yield doc_metadata_list - doc_metadata_list = [] + if len(doc_metadata_list) > _SLIM_DOC_BATCH_SIZE: + yield doc_metadata_list[:_SLIM_DOC_BATCH_SIZE] + doc_metadata_list = doc_metadata_list[_SLIM_DOC_BATCH_SIZE:] + + yield doc_metadata_list diff --git a/backend/danswer/connectors/confluence/onyx_confluence.py b/backend/danswer/connectors/confluence/onyx_confluence.py index 8b4ec81ef8b..e1542109c42 100644 --- a/backend/danswer/connectors/confluence/onyx_confluence.py +++ b/backend/danswer/connectors/confluence/onyx_confluence.py @@ -120,7 +120,7 @@ def wrapped_call(*args: list[Any], **kwargs: Any) -> Any: return cast(F, wrapped_call) -_DEFAULT_PAGINATION_LIMIT = 100 +_DEFAULT_PAGINATION_LIMIT = 1000 class OnyxConfluence(Confluence): diff --git a/backend/danswer/db/connector_credential_pair.py b/backend/danswer/db/connector_credential_pair.py index 2cc96f6fa63..26730d1178f 100644 --- a/backend/danswer/db/connector_credential_pair.py +++ b/backend/danswer/db/connector_credential_pair.py @@ -324,8 +324,11 @@ def associate_default_cc_pair(db_session: Session) -> None: def _relate_groups_to_cc_pair__no_commit( db_session: Session, cc_pair_id: int, - user_group_ids: list[int], + user_group_ids: list[int] | None = None, ) -> None: + if not user_group_ids: + return + for group_id in user_group_ids: db_session.add( UserGroup__ConnectorCredentialPair( @@ -402,12 +405,11 @@ def add_credential_to_connector( db_session.flush() # make sure the association has an id db_session.refresh(association) - if groups and access_type != AccessType.SYNC: - _relate_groups_to_cc_pair__no_commit( - db_session=db_session, - cc_pair_id=association.id, - user_group_ids=groups, - ) + _relate_groups_to_cc_pair__no_commit( + db_session=db_session, + cc_pair_id=association.id, + user_group_ids=groups, + ) db_session.commit() diff --git a/backend/ee/danswer/db/user_group.py b/backend/ee/danswer/db/user_group.py index ba9e3440497..187f7c7b901 100644 --- a/backend/ee/danswer/db/user_group.py +++ b/backend/ee/danswer/db/user_group.py @@ -11,6 +11,7 @@ from sqlalchemy.orm import Session from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id +from danswer.db.enums import AccessType from danswer.db.enums import ConnectorCredentialPairStatus from danswer.db.models import ConnectorCredentialPair from danswer.db.models import Credential__UserGroup @@ -298,6 +299,11 @@ def fetch_user_groups_for_documents( db_session: Session, document_ids: list[str], ) -> Sequence[tuple[str, list[str]]]: + """ + Fetches all user groups that have access to the given documents. + + NOTE: this doesn't include groups if the cc_pair is access type SYNC + """ stmt = ( select(Document.id, func.array_agg(UserGroup.name)) .join( @@ -306,7 +312,11 @@ def fetch_user_groups_for_documents( ) .join( ConnectorCredentialPair, - ConnectorCredentialPair.id == UserGroup__ConnectorCredentialPair.cc_pair_id, + and_( + ConnectorCredentialPair.id + == UserGroup__ConnectorCredentialPair.cc_pair_id, + ConnectorCredentialPair.access_type != AccessType.SYNC, + ), ) .join( DocumentByConnectorCredentialPair, diff --git a/backend/ee/danswer/external_permissions/confluence/doc_sync.py b/backend/ee/danswer/external_permissions/confluence/doc_sync.py index be6b2f76151..d83da900d2c 100644 --- a/backend/ee/danswer/external_permissions/confluence/doc_sync.py +++ b/backend/ee/danswer/external_permissions/confluence/doc_sync.py @@ -97,6 +97,7 @@ def _get_space_permissions( confluence_client: OnyxConfluence, is_cloud: bool, ) -> dict[str, ExternalAccess]: + logger.debug("Getting space permissions") # Gets all the spaces in the Confluence instance all_space_keys = [] start = 0 @@ -113,6 +114,7 @@ def _get_space_permissions( start += len(spaces_batch.get("results", [])) # Gets the permissions for each space + logger.debug(f"Got {len(all_space_keys)} spaces from confluence") space_permissions_by_space_key: dict[str, ExternalAccess] = {} for space_key in all_space_keys: if is_cloud: @@ -242,6 +244,7 @@ def _fetch_all_page_restrictions_for_space( logger.warning(f"No permissions found for document {slim_doc.id}") + logger.debug("Finished fetching all page restrictions for space") return document_restrictions @@ -254,27 +257,28 @@ def confluence_doc_sync( it in postgres so that when it gets created later, the permissions are already populated """ + logger.debug("Starting confluence doc sync") confluence_connector = ConfluenceConnector( **cc_pair.connector.connector_specific_config ) confluence_connector.load_credentials(cc_pair.credential.credential_json) - if confluence_connector.confluence_client is None: - raise ValueError("Failed to load credentials") - confluence_client = confluence_connector.confluence_client is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False) space_permissions_by_space_key = _get_space_permissions( - confluence_client=confluence_client, + confluence_client=confluence_connector.confluence_client, is_cloud=is_cloud, ) slim_docs = [] + logger.debug("Fetching all slim documents from confluence") for doc_batch in confluence_connector.retrieve_all_slim_documents(): + logger.debug(f"Got {len(doc_batch)} slim documents from confluence") slim_docs.extend(doc_batch) + logger.debug("Fetching all page restrictions for space") return _fetch_all_page_restrictions_for_space( - confluence_client=confluence_client, + confluence_client=confluence_connector.confluence_client, slim_docs=slim_docs, space_permissions_by_space_key=space_permissions_by_space_key, ) diff --git a/backend/ee/danswer/external_permissions/confluence/group_sync.py b/backend/ee/danswer/external_permissions/confluence/group_sync.py index f2f53e589b1..8f3f3e43fc6 100644 --- a/backend/ee/danswer/external_permissions/confluence/group_sync.py +++ b/backend/ee/danswer/external_permissions/confluence/group_sync.py @@ -14,7 +14,10 @@ def _build_group_member_email_map( ) -> dict[str, set[str]]: group_member_emails: dict[str, set[str]] = {} for user_result in confluence_client.paginated_cql_user_retrieval(): - user = user_result["user"] + user = user_result.get("user", {}) + if not user: + logger.warning(f"user result missing user field: {user_result}") + continue email = user.get("email") if not email: # This field is only present in Confluence Server diff --git a/backend/ee/danswer/external_permissions/sync_params.py b/backend/ee/danswer/external_permissions/sync_params.py index c00090d748d..43c8a78122c 100644 --- a/backend/ee/danswer/external_permissions/sync_params.py +++ b/backend/ee/danswer/external_permissions/sync_params.py @@ -57,9 +57,9 @@ # If nothing is specified here, we run the doc_sync every time the celery beat runs EXTERNAL_GROUP_SYNC_PERIODS: dict[DocumentSource, int] = { - # Polling is not supported so we fetch all group permissions every 60 seconds - DocumentSource.GOOGLE_DRIVE: 60, - DocumentSource.CONFLUENCE: 60, + # Polling is not supported so we fetch all group permissions every 5 minutes + DocumentSource.GOOGLE_DRIVE: 5 * 60, + DocumentSource.CONFLUENCE: 5 * 60, } diff --git a/backend/tests/integration/connector_job_tests/slack/test_permission_sync.py b/backend/tests/integration/connector_job_tests/slack/test_permission_sync.py index 6c0c5908cd1..3c37332547d 100644 --- a/backend/tests/integration/connector_job_tests/slack/test_permission_sync.py +++ b/backend/tests/integration/connector_job_tests/slack/test_permission_sync.py @@ -14,6 +14,7 @@ ) from tests.integration.common_utils.managers.llm_provider import LLMProviderManager from tests.integration.common_utils.managers.user import UserManager +from tests.integration.common_utils.managers.user_group import UserGroupManager from tests.integration.common_utils.test_models import DATestCCPair from tests.integration.common_utils.test_models import DATestConnector from tests.integration.common_utils.test_models import DATestCredential @@ -215,3 +216,124 @@ def test_slack_permission_sync( # Ensure test_user_1 can only see messages from the public channel assert public_message in danswer_doc_message_strings assert private_message not in danswer_doc_message_strings + + +def test_slack_group_permission_sync( + reset: None, + vespa_client: vespa_fixture, + slack_test_setup: tuple[dict[str, Any], dict[str, Any]], +) -> None: + """ + This test ensures that permission sync overrides danswer group access. + """ + public_channel, private_channel = slack_test_setup + + # Creating an admin user (first user created is automatically an admin) + admin_user: DATestUser = UserManager.create( + email="admin@onyx-test.com", + ) + + # Creating a non-admin user + test_user_1: DATestUser = UserManager.create( + email="test_user_1@onyx-test.com", + ) + + # Create a user group and adding the non-admin user to it + user_group = UserGroupManager.create( + name="test_group", + user_ids=[test_user_1.id], + cc_pair_ids=[], + user_performing_action=admin_user, + ) + UserGroupManager.wait_for_sync( + user_groups_to_check=[user_group], + user_performing_action=admin_user, + ) + + slack_client = SlackManager.get_slack_client(os.environ["SLACK_BOT_TOKEN"]) + email_id_map = SlackManager.build_slack_user_email_id_map(slack_client) + admin_user_id = email_id_map[admin_user.email] + + LLMProviderManager.create(user_performing_action=admin_user) + + # Add only admin to the private channel + SlackManager.set_channel_members( + slack_client=slack_client, + admin_user_id=admin_user_id, + channel=private_channel, + user_ids=[admin_user_id], + ) + + before = datetime.now(timezone.utc) + credential = CredentialManager.create( + source=DocumentSource.SLACK, + credential_json={ + "slack_bot_token": os.environ["SLACK_BOT_TOKEN"], + }, + user_performing_action=admin_user, + ) + + # Create connector with sync access and assign it to the user group + connector = ConnectorManager.create( + name="Slack", + input_type=InputType.POLL, + source=DocumentSource.SLACK, + connector_specific_config={ + "workspace": "onyx-test-workspace", + "channels": [private_channel["name"]], + }, + access_type=AccessType.SYNC, + groups=[user_group.id], + user_performing_action=admin_user, + ) + + cc_pair = CCPairManager.create( + credential_id=credential.id, + connector_id=connector.id, + access_type=AccessType.SYNC, + user_performing_action=admin_user, + groups=[user_group.id], + ) + + # Add a test message to the private channel + private_message = "This is a secret message: 987654" + SlackManager.add_message_to_channel( + slack_client=slack_client, + channel=private_channel, + message=private_message, + ) + + # Run indexing + CCPairManager.run_once(cc_pair, admin_user) + CCPairManager.wait_for_indexing( + cc_pair=cc_pair, + after=before, + user_performing_action=admin_user, + ) + + # Run permission sync + CCPairManager.sync( + cc_pair=cc_pair, + user_performing_action=admin_user, + ) + CCPairManager.wait_for_sync( + cc_pair=cc_pair, + after=before, + number_of_updated_docs=1, + user_performing_action=admin_user, + ) + + # Verify admin can see the message + admin_docs = DocumentSearchManager.search_documents( + query="secret message", + user_performing_action=admin_user, + ) + assert private_message in admin_docs + + # Verify test_user_1 cannot see the message despite being in the group + # (Slack permissions should take precedence) + user_1_docs = DocumentSearchManager.search_documents( + query="secret message", + user_performing_action=test_user_1, + ) + assert private_message not in user_1_docs From 634a0b9398e746273275cdaf49b0c8b69fa5095c Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Wed, 27 Nov 2024 12:58:21 -0800 Subject: [PATCH 005/133] no stack by default (#3278) --- .../performance/usage/QueryPerformanceChart.tsx | 1 + web/src/components/ui/areaChart.tsx | 16 +++------------- 2 files changed, 4 insertions(+), 13 deletions(-) diff --git a/web/src/app/ee/admin/performance/usage/QueryPerformanceChart.tsx b/web/src/app/ee/admin/performance/usage/QueryPerformanceChart.tsx index ffbf4e8c93e..f9ed3f7986d 100644 --- a/web/src/app/ee/admin/performance/usage/QueryPerformanceChart.tsx +++ b/web/src/app/ee/admin/performance/usage/QueryPerformanceChart.tsx @@ -62,6 +62,7 @@ export function QueryPerformanceChart({ chart = ( { const queryAnalyticsForDate = dateToQueryAnalytics.get(dateStr); const userAnalyticsForDate = dateToUserAnalytics.get(dateStr); diff --git a/web/src/components/ui/areaChart.tsx b/web/src/components/ui/areaChart.tsx index 71f593b4f7b..c0baeae93c4 100644 --- a/web/src/components/ui/areaChart.tsx +++ b/web/src/components/ui/areaChart.tsx @@ -24,18 +24,12 @@ interface AreaChartProps { categories?: string[]; index?: string; colors?: string[]; - startEndOnly?: boolean; showXAxis?: boolean; showYAxis?: boolean; yAxisWidth?: number; showAnimation?: boolean; showTooltip?: boolean; - showLegend?: boolean; showGridLines?: boolean; - showGradient?: boolean; - autoMinValue?: boolean; - minValue?: number; - maxValue?: number; connectNulls?: boolean; allowDecimals?: boolean; className?: string; @@ -43,6 +37,7 @@ interface AreaChartProps { description?: string; xAxisFormatter?: (value: any) => string; yAxisFormatter?: (value: any) => string; + stacked?: boolean; } export function AreaChartDisplay({ @@ -50,18 +45,12 @@ export function AreaChartDisplay({ categories = [], index, colors = ["indigo", "fuchsia"], - startEndOnly = false, showXAxis = true, showYAxis = true, yAxisWidth = 56, showAnimation = true, showTooltip = true, - showLegend = false, showGridLines = true, - showGradient = true, - autoMinValue = false, - minValue, - maxValue, connectNulls = false, allowDecimals = true, className, @@ -69,6 +58,7 @@ export function AreaChartDisplay({ description, xAxisFormatter = (dateStr: string) => dateStr, yAxisFormatter = (number: number) => number.toString(), + stacked = false, }: AreaChartProps) { return ( @@ -113,7 +103,7 @@ export function AreaChartDisplay({ key={category} type="monotone" dataKey={category} - stackId="1" + stackId={stacked ? "1" : category} stroke={colors[ind % colors.length]} fill={colors[ind % colors.length]} fillOpacity={0.3} From ac448956e94a8e233813970e83cd124d63fe80a4 Mon Sep 17 00:00:00 2001 From: Chris Weaver <25087905+Weves@users.noreply.github.com> Date: Wed, 27 Nov 2024 14:22:15 -0800 Subject: [PATCH 006/133] Add handling for rate limiting (#3280) --- .../tasks/external_group_syncing/tasks.py | 10 ++-- .../natural_language_processing/exceptions.py | 4 ++ .../search_nlp_models.py | 53 +++++++++++++------ backend/model_server/encoders.py | 44 +++++++-------- .../tests/daily/embedding/test_embeddings.py | 40 ++++++++++++++ 5 files changed, 110 insertions(+), 41 deletions(-) create mode 100644 backend/danswer/natural_language_processing/exceptions.py diff --git a/backend/danswer/background/celery/tasks/external_group_syncing/tasks.py b/backend/danswer/background/celery/tasks/external_group_syncing/tasks.py index 8381656ee17..c3f0f6c6f15 100644 --- a/backend/danswer/background/celery/tasks/external_group_syncing/tasks.py +++ b/backend/danswer/background/celery/tasks/external_group_syncing/tasks.py @@ -195,7 +195,7 @@ def connector_external_group_sync_generator_task( tenant_id: str | None, ) -> None: """ - Permission sync task that handles document permission syncing for a given connector credential pair + Permission sync task that handles external group syncing for a given connector credential pair This task assumes that the task has already been properly fenced """ @@ -228,9 +228,13 @@ def connector_external_group_sync_generator_task( ext_group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type) if ext_group_sync_func is None: - raise ValueError(f"No external group sync func found for {source_type}") + raise ValueError( + f"No external group sync func found for {source_type} for cc_pair: {cc_pair_id}" + ) - logger.info(f"Syncing docs for {source_type}") + logger.info( + f"Syncing external groups for {source_type} for cc_pair: {cc_pair_id}" + ) external_user_groups: list[ExternalUserGroup] = ext_group_sync_func(cc_pair) diff --git a/backend/danswer/natural_language_processing/exceptions.py b/backend/danswer/natural_language_processing/exceptions.py new file mode 100644 index 00000000000..5ca112f64ea --- /dev/null +++ b/backend/danswer/natural_language_processing/exceptions.py @@ -0,0 +1,4 @@ +class ModelServerRateLimitError(Exception): + """ + Exception raised for rate limiting errors from the model server. + """ diff --git a/backend/danswer/natural_language_processing/search_nlp_models.py b/backend/danswer/natural_language_processing/search_nlp_models.py index ee80292de63..9fed0d489e7 100644 --- a/backend/danswer/natural_language_processing/search_nlp_models.py +++ b/backend/danswer/natural_language_processing/search_nlp_models.py @@ -6,6 +6,9 @@ import requests from httpx import HTTPError +from requests import JSONDecodeError +from requests import RequestException +from requests import Response from retry import retry from danswer.configs.app_configs import LARGE_CHUNK_RATIO @@ -16,6 +19,9 @@ from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE from danswer.db.models import SearchSettings from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface +from danswer.natural_language_processing.exceptions import ( + ModelServerRateLimitError, +) from danswer.natural_language_processing.utils import get_tokenizer from danswer.natural_language_processing.utils import tokenizer_trim_content from danswer.utils.logger import setup_logger @@ -99,28 +105,43 @@ def __init__( self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed" def _make_model_server_request(self, embed_request: EmbedRequest) -> EmbedResponse: - def _make_request() -> EmbedResponse: + def _make_request() -> Response: response = requests.post( self.embed_server_endpoint, json=embed_request.model_dump() ) - try: - response.raise_for_status() - except requests.HTTPError as e: - try: - error_detail = response.json().get("detail", str(e)) - except Exception: - error_detail = response.text - raise HTTPError(f"HTTP error occurred: {error_detail}") from e - except requests.RequestException as e: - raise HTTPError(f"Request failed: {str(e)}") from e + # signify that this is a rate limit error + if response.status_code == 429: + raise ModelServerRateLimitError(response.text) - return EmbedResponse(**response.json()) + response.raise_for_status() + return response + + final_make_request_func = _make_request - # only perform retries for the non-realtime embedding of passages (e.g. for indexing) + # if the text type is a passage, add some default + # retries + handling for rate limiting if embed_request.text_type == EmbedTextType.PASSAGE: - return retry(tries=3, delay=5)(_make_request)() - else: - return _make_request() + final_make_request_func = retry( + tries=3, + delay=5, + exceptions=(RequestException, ValueError, JSONDecodeError), + )(final_make_request_func) + # use 10 second delay as per Azure suggestion + final_make_request_func = retry( + tries=10, delay=10, exceptions=ModelServerRateLimitError + )(final_make_request_func) + + try: + response = final_make_request_func() + return EmbedResponse(**response.json()) + except requests.HTTPError as e: + try: + error_detail = response.json().get("detail", str(e)) + except Exception: + error_detail = response.text + raise HTTPError(f"HTTP error occurred: {error_detail}") from e + except requests.RequestException as e: + raise HTTPError(f"Request failed: {str(e)}") from e def _batch_encode_texts( self, diff --git a/backend/model_server/encoders.py b/backend/model_server/encoders.py index 003953cb29a..c72be9e4ac3 100644 --- a/backend/model_server/encoders.py +++ b/backend/model_server/encoders.py @@ -11,6 +11,7 @@ from fastapi import HTTPException from google.oauth2 import service_account # type: ignore from litellm import embedding +from litellm.exceptions import RateLimitError from retry import retry from sentence_transformers import CrossEncoder # type: ignore from sentence_transformers import SentenceTransformer # type: ignore @@ -205,28 +206,22 @@ def embed( model_name: str | None = None, deployment_name: str | None = None, ) -> list[Embedding]: - try: - if self.provider == EmbeddingProvider.OPENAI: - return self._embed_openai(texts, model_name) - elif self.provider == EmbeddingProvider.AZURE: - return self._embed_azure(texts, f"azure/{deployment_name}") - elif self.provider == EmbeddingProvider.LITELLM: - return self._embed_litellm_proxy(texts, model_name) - - embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type) - if self.provider == EmbeddingProvider.COHERE: - return self._embed_cohere(texts, model_name, embedding_type) - elif self.provider == EmbeddingProvider.VOYAGE: - return self._embed_voyage(texts, model_name, embedding_type) - elif self.provider == EmbeddingProvider.GOOGLE: - return self._embed_vertex(texts, model_name, embedding_type) - else: - raise ValueError(f"Unsupported provider: {self.provider}") - except Exception as e: - raise HTTPException( - status_code=500, - detail=f"Error embedding text with {self.provider}: {str(e)}", - ) + if self.provider == EmbeddingProvider.OPENAI: + return self._embed_openai(texts, model_name) + elif self.provider == EmbeddingProvider.AZURE: + return self._embed_azure(texts, f"azure/{deployment_name}") + elif self.provider == EmbeddingProvider.LITELLM: + return self._embed_litellm_proxy(texts, model_name) + + embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type) + if self.provider == EmbeddingProvider.COHERE: + return self._embed_cohere(texts, model_name, embedding_type) + elif self.provider == EmbeddingProvider.VOYAGE: + return self._embed_voyage(texts, model_name, embedding_type) + elif self.provider == EmbeddingProvider.GOOGLE: + return self._embed_vertex(texts, model_name, embedding_type) + else: + raise ValueError(f"Unsupported provider: {self.provider}") @staticmethod def create( @@ -430,6 +425,11 @@ async def process_embed_request( prefix=prefix, ) return EmbedResponse(embeddings=embeddings) + except RateLimitError as e: + raise HTTPException( + status_code=429, + detail=str(e), + ) except Exception as e: exception_detail = f"Error during embedding process:\n{str(e)}" logger.exception(exception_detail) diff --git a/backend/tests/daily/embedding/test_embeddings.py b/backend/tests/daily/embedding/test_embeddings.py index 10a1dd850f6..7182510214f 100644 --- a/backend/tests/daily/embedding/test_embeddings.py +++ b/backend/tests/daily/embedding/test_embeddings.py @@ -7,6 +7,7 @@ from shared_configs.model_server_models import EmbeddingProvider VALID_SAMPLE = ["hi", "hello my name is bob", "woah there!!!. 😃"] +VALID_LONG_SAMPLE = ["hi " * 999] # openai limit is 2048, cohere is supposed to be 96 but in practice that doesn't # seem to be true TOO_LONG_SAMPLE = ["a"] * 2500 @@ -99,3 +100,42 @@ def local_nomic_embedding_model() -> EmbeddingModel: def test_local_nomic_embedding(local_nomic_embedding_model: EmbeddingModel) -> None: _run_embeddings(VALID_SAMPLE, local_nomic_embedding_model, 768) _run_embeddings(TOO_LONG_SAMPLE, local_nomic_embedding_model, 768) + + +@pytest.fixture +def azure_embedding_model() -> EmbeddingModel: + return EmbeddingModel( + server_host="localhost", + server_port=9000, + model_name="text-embedding-3-large", + normalize=True, + query_prefix=None, + passage_prefix=None, + api_key=os.getenv("AZURE_API_KEY"), + provider_type=EmbeddingProvider.AZURE, + api_url=os.getenv("AZURE_API_URL"), + ) + + +# NOTE (chris): this test doesn't work, and I do not know why +# def test_azure_embedding_model_rate_limit(azure_embedding_model: EmbeddingModel): +# """NOTE: this test relies on a very low rate limit for the Azure API + +# this test only being run once in a 1 minute window""" +# # VALID_LONG_SAMPLE is 999 tokens, so the second call should run into rate +# # limits assuming the limit is 1000 tokens per minute +# result = azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.QUERY) +# assert len(result) == 1 +# assert len(result[0]) == 1536 + +# # this should fail +# with pytest.raises(ModelServerRateLimitError): +# azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.QUERY) +# azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.QUERY) +# azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.QUERY) + +# # this should succeed, since passage requests retry up to 10 times +# start = time.time() +# result = azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.PASSAGE) +# assert len(result) == 1 +# assert len(result[0]) == 1536 +# assert time.time() - start > 30 # make sure we waited, even though we hit rate limits From eb8708f7708b2081efaa25db16e0386ffbb17c14 Mon Sep 17 00:00:00 2001 From: "Richard Kuo (Danswer)" Date: Tue, 26 Nov 2024 22:03:53 -0800 Subject: [PATCH 007/133] the word "error" might be throwing off sentry --- backend/danswer/natural_language_processing/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/danswer/natural_language_processing/utils.py b/backend/danswer/natural_language_processing/utils.py index b46757728f9..35f5629e06f 100644 --- a/backend/danswer/natural_language_processing/utils.py +++ b/backend/danswer/natural_language_processing/utils.py @@ -131,7 +131,7 @@ def _try_initialize_tokenizer( return tokenizer except Exception as hf_error: logger.warning( - f"Error initializing HuggingFaceTokenizer for {model_name}: {hf_error}" + f"Failed to initialize HuggingFaceTokenizer for {model_name}: {hf_error}" ) # If both initializations fail, return None From 212353ed4a8a69e2b2ea43451f8235cb3c84ca4e Mon Sep 17 00:00:00 2001 From: Matthew Holland Date: Mon, 25 Nov 2024 19:29:54 -0800 Subject: [PATCH 008/133] Fixed default feedback options --- web/src/app/chat/modal/FeedbackModal.tsx | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/web/src/app/chat/modal/FeedbackModal.tsx b/web/src/app/chat/modal/FeedbackModal.tsx index 39c3253b76a..886a761acb2 100644 --- a/web/src/app/chat/modal/FeedbackModal.tsx +++ b/web/src/app/chat/modal/FeedbackModal.tsx @@ -5,15 +5,19 @@ import { FeedbackType } from "../types"; import { Modal } from "@/components/Modal"; import { FilledLikeIcon } from "@/components/icons/icons"; -const predefinedPositiveFeedbackOptions = - process.env.NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS?.split(",") || - []; -const predefinedNegativeFeedbackOptions = - process.env.NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS?.split(",") || [ - "Retrieved documents were not relevant", - "AI misread the documents", - "Cited source had incorrect information", - ]; +const predefinedPositiveFeedbackOptions = process.env + .NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS + ? process.env.NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS.split(",") + : []; + +const predefinedNegativeFeedbackOptions = process.env + .NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS + ? process.env.NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS.split(",") + : [ + "Retrieved documents were not relevant", + "AI misread the documents", + "Cited source had incorrect information", + ]; interface FeedbackModalProps { feedbackType: FeedbackType; From 36941ae663a646d36d10bec8cc5825dd58c504fd Mon Sep 17 00:00:00 2001 From: Subash-Mohan Date: Sat, 23 Nov 2024 16:05:22 +0530 Subject: [PATCH 009/133] fix: Cannot configure API keys #3191 --- backend/danswer/main.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/backend/danswer/main.py b/backend/danswer/main.py index a8fe531f7d5..e6e71e5b80c 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -91,6 +91,7 @@ from danswer.server.token_rate_limits.api import ( router as token_rate_limit_settings_router, ) +from danswer.server.api_key.api import router as api_key_router from danswer.setup import setup_danswer from danswer.setup import setup_multitenant_danswer from danswer.utils.logger import setup_logger @@ -281,6 +282,7 @@ def get_application() -> FastAPI: application, get_full_openai_assistants_api_router() ) include_router_with_global_prefix_prepended(application, long_term_logs_router) + include_router_with_global_prefix_prepended(application, api_key_router) if AUTH_TYPE == AuthType.DISABLED: # Server logs this during auth setup verification step From fd84b7a768455d6cc8fc0e8f4a9f6d6e63ab81a7 Mon Sep 17 00:00:00 2001 From: Weves Date: Wed, 27 Nov 2024 16:30:25 -0800 Subject: [PATCH 010/133] Remove duplicate API key router --- backend/ee/danswer/main.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/backend/ee/danswer/main.py b/backend/ee/danswer/main.py index 96655af2acd..198f945b8da 100644 --- a/backend/ee/danswer/main.py +++ b/backend/ee/danswer/main.py @@ -13,7 +13,6 @@ from danswer.configs.constants import AuthType from danswer.main import get_application as get_application_base from danswer.main import include_router_with_global_prefix_prepended -from danswer.server.api_key.api import router as api_key_router from danswer.utils.logger import setup_logger from danswer.utils.variable_functionality import global_version from ee.danswer.configs.app_configs import OPENID_CONFIG_URL @@ -116,8 +115,6 @@ def get_application() -> FastAPI: # Analytics endpoints include_router_with_global_prefix_prepended(application, analytics_router) include_router_with_global_prefix_prepended(application, query_history_router) - # Api key management - include_router_with_global_prefix_prepended(application, api_key_router) # EE only backend APIs include_router_with_global_prefix_prepended(application, query_router) include_router_with_global_prefix_prepended(application, chat_router) From 5be7d27285bb246a17731e39a4589a40ac001cb2 Mon Sep 17 00:00:00 2001 From: rkuo-danswer Date: Wed, 27 Nov 2024 17:34:34 -0800 Subject: [PATCH 011/133] use indexing flag in db for manually triggering indexing (#3264) * use indexing flag in db for manually trigger indexing * add comment. * only try to release the lock if we actually succeeded with the lock * ensure we don't trigger manual indexing on anything but the primary search settings * comment usage of primary search settings * run check for indexing immediately after indexing triggers are set * reorder fix --- ...78b8217_add_indexing_trigger_to_cc_pair.py | 30 +++++++++ .../background/celery/tasks/indexing/tasks.py | 54 +++++++++++++--- .../background/celery/versioned_apps/beat.py | 4 +- .../celery/versioned_apps/primary.py | 4 +- backend/danswer/db/connector.py | 23 +++++++ backend/danswer/db/enums.py | 5 ++ backend/danswer/db/models.py | 6 +- backend/danswer/main.py | 2 +- backend/danswer/server/documents/connector.py | 64 ++++++++----------- 9 files changed, 139 insertions(+), 53 deletions(-) create mode 100644 backend/alembic/versions/abe7378b8217_add_indexing_trigger_to_cc_pair.py diff --git a/backend/alembic/versions/abe7378b8217_add_indexing_trigger_to_cc_pair.py b/backend/alembic/versions/abe7378b8217_add_indexing_trigger_to_cc_pair.py new file mode 100644 index 00000000000..cc947eef0e0 --- /dev/null +++ b/backend/alembic/versions/abe7378b8217_add_indexing_trigger_to_cc_pair.py @@ -0,0 +1,30 @@ +"""add indexing trigger to cc_pair + +Revision ID: abe7378b8217 +Revises: 6d562f86c78b +Create Date: 2024-11-26 19:09:53.481171 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = "abe7378b8217" +down_revision = "93560ba1b118" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "connector_credential_pair", + sa.Column( + "indexing_trigger", + sa.Enum("UPDATE", "REINDEX", name="indexingmode", native_enum=False), + nullable=True, + ), + ) + + +def downgrade() -> None: + op.drop_column("connector_credential_pair", "indexing_trigger") diff --git a/backend/danswer/background/celery/tasks/indexing/tasks.py b/backend/danswer/background/celery/tasks/indexing/tasks.py index 73b2b20a4e0..9ebab40d0af 100644 --- a/backend/danswer/background/celery/tasks/indexing/tasks.py +++ b/backend/danswer/background/celery/tasks/indexing/tasks.py @@ -25,11 +25,13 @@ from danswer.configs.constants import DanswerCeleryQueues from danswer.configs.constants import DanswerRedisLocks from danswer.configs.constants import DocumentSource +from danswer.db.connector import mark_ccpair_with_indexing_trigger from danswer.db.connector_credential_pair import fetch_connector_credential_pairs from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id from danswer.db.engine import get_db_current_time from danswer.db.engine import get_session_with_tenant from danswer.db.enums import ConnectorCredentialPairStatus +from danswer.db.enums import IndexingMode from danswer.db.enums import IndexingStatus from danswer.db.enums import IndexModelStatus from danswer.db.index_attempt import create_index_attempt @@ -159,7 +161,7 @@ def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[ ) def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None: tasks_created = 0 - + locked = False r = get_redis_client(tenant_id=tenant_id) lock_beat: RedisLock = r.lock( @@ -172,6 +174,8 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None: if not lock_beat.acquire(blocking=False): return None + locked = True + # check for search settings swap with get_session_with_tenant(tenant_id=tenant_id) as db_session: old_search_settings = check_index_swap(db_session=db_session) @@ -231,22 +235,46 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None: last_attempt = get_last_attempt_for_cc_pair( cc_pair.id, search_settings_instance.id, db_session ) + + search_settings_primary = False + if search_settings_instance.id == primary_search_settings.id: + search_settings_primary = True + if not _should_index( cc_pair=cc_pair, last_index=last_attempt, search_settings_instance=search_settings_instance, + search_settings_primary=search_settings_primary, secondary_index_building=len(search_settings) > 1, db_session=db_session, ): continue + reindex = False + if search_settings_instance.id == primary_search_settings.id: + # the indexing trigger is only checked and cleared with the primary search settings + if cc_pair.indexing_trigger is not None: + if cc_pair.indexing_trigger == IndexingMode.REINDEX: + reindex = True + + task_logger.info( + f"Connector indexing manual trigger detected: " + f"cc_pair={cc_pair.id} " + f"search_settings={search_settings_instance.id} " + f"indexing_mode={cc_pair.indexing_trigger}" + ) + + mark_ccpair_with_indexing_trigger( + cc_pair.id, None, db_session + ) + # using a task queue and only allowing one task per cc_pair/search_setting # prevents us from starving out certain attempts attempt_id = try_creating_indexing_task( self.app, cc_pair, search_settings_instance, - False, + reindex, db_session, r, tenant_id, @@ -281,7 +309,6 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None: mark_attempt_failed( attempt.id, db_session, failure_reason=failure_reason ) - except SoftTimeLimitExceeded: task_logger.info( "Soft time limit exceeded, task is being terminated gracefully." @@ -289,13 +316,14 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None: except Exception: task_logger.exception(f"Unexpected exception: tenant={tenant_id}") finally: - if lock_beat.owned(): - lock_beat.release() - else: - task_logger.error( - "check_for_indexing - Lock not owned on completion: " - f"tenant={tenant_id}" - ) + if locked: + if lock_beat.owned(): + lock_beat.release() + else: + task_logger.error( + "check_for_indexing - Lock not owned on completion: " + f"tenant={tenant_id}" + ) return tasks_created @@ -304,6 +332,7 @@ def _should_index( cc_pair: ConnectorCredentialPair, last_index: IndexAttempt | None, search_settings_instance: SearchSettings, + search_settings_primary: bool, secondary_index_building: bool, db_session: Session, ) -> bool: @@ -368,6 +397,11 @@ def _should_index( ): return False + if search_settings_primary: + if cc_pair.indexing_trigger is not None: + # if a manual indexing trigger is on the cc pair, honor it for primary search settings + return True + # if no attempt has ever occurred, we should index regardless of refresh_freq if not last_index: return True diff --git a/backend/danswer/background/celery/versioned_apps/beat.py b/backend/danswer/background/celery/versioned_apps/beat.py index af407f93c64..64bc1112ed3 100644 --- a/backend/danswer/background/celery/versioned_apps/beat.py +++ b/backend/danswer/background/celery/versioned_apps/beat.py @@ -1,6 +1,8 @@ """Factory stub for running celery worker / celery beat.""" +from celery import Celery + from danswer.background.celery.apps.beat import celery_app from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable set_is_ee_based_on_env_variable() -app = celery_app +app: Celery = celery_app diff --git a/backend/danswer/background/celery/versioned_apps/primary.py b/backend/danswer/background/celery/versioned_apps/primary.py index 2d97caa3da5..f07a63b2e1a 100644 --- a/backend/danswer/background/celery/versioned_apps/primary.py +++ b/backend/danswer/background/celery/versioned_apps/primary.py @@ -1,8 +1,10 @@ """Factory stub for running celery worker / celery beat.""" +from celery import Celery + from danswer.utils.variable_functionality import fetch_versioned_implementation from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable set_is_ee_based_on_env_variable() -app = fetch_versioned_implementation( +app: Celery = fetch_versioned_implementation( "danswer.background.celery.apps.primary", "celery_app" ) diff --git a/backend/danswer/db/connector.py b/backend/danswer/db/connector.py index 767a722eec4..1bcfe75e4c1 100644 --- a/backend/danswer/db/connector.py +++ b/backend/danswer/db/connector.py @@ -12,6 +12,7 @@ from danswer.configs.app_configs import DEFAULT_PRUNING_FREQ from danswer.configs.constants import DocumentSource from danswer.connectors.models import InputType +from danswer.db.enums import IndexingMode from danswer.db.models import Connector from danswer.db.models import ConnectorCredentialPair from danswer.db.models import IndexAttempt @@ -311,3 +312,25 @@ def mark_cc_pair_as_external_group_synced(db_session: Session, cc_pair_id: int) # If this changes, we need to update this function. cc_pair.last_time_external_group_sync = datetime.now(timezone.utc) db_session.commit() + + +def mark_ccpair_with_indexing_trigger( + cc_pair_id: int, indexing_mode: IndexingMode | None, db_session: Session +) -> None: + """indexing_mode sets a field which will be picked up by a background task + to trigger indexing. Set to None to disable the trigger.""" + try: + cc_pair = db_session.execute( + select(ConnectorCredentialPair) + .where(ConnectorCredentialPair.id == cc_pair_id) + .with_for_update() + ).scalar_one() + + if cc_pair is None: + raise ValueError(f"No cc_pair with ID: {cc_pair_id}") + + cc_pair.indexing_trigger = indexing_mode + db_session.commit() + except Exception: + db_session.rollback() + raise diff --git a/backend/danswer/db/enums.py b/backend/danswer/db/enums.py index b1905d4e785..0ccb1470ca7 100644 --- a/backend/danswer/db/enums.py +++ b/backend/danswer/db/enums.py @@ -19,6 +19,11 @@ def is_terminal(self) -> bool: return self in terminal_states +class IndexingMode(str, PyEnum): + UPDATE = "update" + REINDEX = "reindex" + + # these may differ in the future, which is why we're okay with this duplication class DeletionStatus(str, PyEnum): NOT_STARTED = "not_started" diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index 4e1970a7bd2..2163aa5aaf4 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -42,7 +42,7 @@ from danswer.configs.constants import DocumentSource from danswer.configs.constants import FileOrigin from danswer.configs.constants import MessageType -from danswer.db.enums import AccessType +from danswer.db.enums import AccessType, IndexingMode from danswer.configs.constants import NotificationType from danswer.configs.constants import SearchFeedbackType from danswer.configs.constants import TokenRateLimitScope @@ -438,6 +438,10 @@ class ConnectorCredentialPair(Base): total_docs_indexed: Mapped[int] = mapped_column(Integer, default=0) + indexing_trigger: Mapped[IndexingMode | None] = mapped_column( + Enum(IndexingMode, native_enum=False), nullable=True + ) + connector: Mapped["Connector"] = relationship( "Connector", back_populates="credentials" ) diff --git a/backend/danswer/main.py b/backend/danswer/main.py index e6e71e5b80c..a9094399702 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -45,6 +45,7 @@ from danswer.configs.constants import POSTGRES_WEB_APP_NAME from danswer.db.engine import SqlEngine from danswer.db.engine import warm_up_connections +from danswer.server.api_key.api import router as api_key_router from danswer.server.auth_check import check_router_auth from danswer.server.danswer_api.ingestion import router as danswer_api_router from danswer.server.documents.cc_pair import router as cc_pair_router @@ -91,7 +92,6 @@ from danswer.server.token_rate_limits.api import ( router as token_rate_limit_settings_router, ) -from danswer.server.api_key.api import router as api_key_router from danswer.setup import setup_danswer from danswer.setup import setup_multitenant_danswer from danswer.utils.logger import setup_logger diff --git a/backend/danswer/server/documents/connector.py b/backend/danswer/server/documents/connector.py index 9b9da834e05..cdeb4ed16c6 100644 --- a/backend/danswer/server/documents/connector.py +++ b/backend/danswer/server/documents/connector.py @@ -17,9 +17,9 @@ from danswer.auth.users import current_curator_or_admin_user from danswer.auth.users import current_user from danswer.background.celery.celery_utils import get_deletion_attempt_snapshot -from danswer.background.celery.tasks.indexing.tasks import try_creating_indexing_task from danswer.background.celery.versioned_apps.primary import app as primary_app from danswer.configs.app_configs import ENABLED_CONNECTOR_TYPES +from danswer.configs.constants import DanswerCeleryPriority from danswer.configs.constants import DocumentSource from danswer.configs.constants import FileOrigin from danswer.connectors.google_utils.google_auth import ( @@ -59,6 +59,7 @@ from danswer.db.connector import fetch_connector_by_id from danswer.db.connector import fetch_connectors from danswer.db.connector import get_connector_credential_ids +from danswer.db.connector import mark_ccpair_with_indexing_trigger from danswer.db.connector import update_connector from danswer.db.connector_credential_pair import add_credential_to_connector from danswer.db.connector_credential_pair import get_cc_pair_groups_for_ids @@ -74,6 +75,7 @@ from danswer.db.engine import get_current_tenant_id from danswer.db.engine import get_session from danswer.db.enums import AccessType +from danswer.db.enums import IndexingMode from danswer.db.index_attempt import get_index_attempts_for_cc_pair from danswer.db.index_attempt import get_latest_index_attempt_for_cc_pair_id from danswer.db.index_attempt import get_latest_index_attempts @@ -86,7 +88,6 @@ from danswer.file_store.file_store import get_default_file_store from danswer.key_value_store.interface import KvKeyNotFoundError from danswer.redis.redis_connector import RedisConnector -from danswer.redis.redis_pool import get_redis_client from danswer.server.documents.models import AuthStatus from danswer.server.documents.models import AuthUrl from danswer.server.documents.models import ConnectorCredentialPairIdentifier @@ -792,12 +793,10 @@ def connector_run_once( _: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), tenant_id: str = Depends(get_current_tenant_id), -) -> StatusResponse[list[int]]: +) -> StatusResponse[int]: """Used to trigger indexing on a set of cc_pairs associated with a single connector.""" - r = get_redis_client(tenant_id=tenant_id) - connector_id = run_info.connector_id specified_credential_ids = run_info.credential_ids @@ -843,54 +842,41 @@ def connector_run_once( ) ] - search_settings = get_current_search_settings(db_session) - connector_credential_pairs = [ get_connector_credential_pair(connector_id, credential_id, db_session) for credential_id in credential_ids if credential_id not in skipped_credentials ] - index_attempt_ids = [] + num_triggers = 0 for cc_pair in connector_credential_pairs: if cc_pair is not None: - attempt_id = try_creating_indexing_task( - primary_app, - cc_pair, - search_settings, - run_info.from_beginning, - db_session, - r, - tenant_id, + indexing_mode = IndexingMode.UPDATE + if run_info.from_beginning: + indexing_mode = IndexingMode.REINDEX + + mark_ccpair_with_indexing_trigger(cc_pair.id, indexing_mode, db_session) + num_triggers += 1 + + logger.info( + f"connector_run_once - marking cc_pair with indexing trigger: " + f"connector={run_info.connector_id} " + f"cc_pair={cc_pair.id} " + f"indexing_trigger={indexing_mode}" ) - if attempt_id: - logger.info( - f"connector_run_once - try_creating_indexing_task succeeded: " - f"connector={run_info.connector_id} " - f"cc_pair={cc_pair.id} " - f"attempt={attempt_id} " - ) - index_attempt_ids.append(attempt_id) - else: - logger.info( - f"connector_run_once - try_creating_indexing_task failed: " - f"connector={run_info.connector_id} " - f"cc_pair={cc_pair.id}" - ) - if not index_attempt_ids: - msg = "No new indexing attempts created, indexing jobs are queued or running." - logger.info(msg) - raise HTTPException( - status_code=400, - detail=msg, - ) + # run the beat task to pick up the triggers immediately + primary_app.send_task( + "check_for_indexing", + priority=DanswerCeleryPriority.HIGH, + kwargs={"tenant_id": tenant_id}, + ) - msg = f"Successfully created {len(index_attempt_ids)} index attempts. {index_attempt_ids}" + msg = f"Marked {num_triggers} index attempts with indexing triggers." return StatusResponse( success=True, message=msg, - data=index_attempt_ids, + data=num_triggers, ) From 7f1e4a02bf85cff719b236f155dca3757d953da9 Mon Sep 17 00:00:00 2001 From: rkuo-danswer Date: Wed, 27 Nov 2024 21:32:45 -0800 Subject: [PATCH 012/133] Feature/kill indexing (#3213) * checkpoint * add celery termination of the task * rename to RedisConnectorPermissionSyncPayload, add RedisLock to more places, add get_active_search_settings * rename payload * pretty sure these weren't named correctly * testing in progress * cleanup * remove space * merge fix * three dots animation on Pausing * improve messaging when connector is stopped or killed and animate buttons --------- Co-authored-by: Richard Kuo --- .../celery/tasks/connector_deletion/tasks.py | 6 +- .../tasks/doc_permission_syncing/tasks.py | 19 ++-- .../tasks/external_group_syncing/tasks.py | 24 ++-- .../background/celery/tasks/indexing/tasks.py | 59 +++++++--- .../background/celery/tasks/vespa/tasks.py | 26 +++-- .../background/indexing/run_indexing.py | 62 +++++++---- backend/danswer/db/search_settings.py | 19 ++++ backend/danswer/redis/redis_connector.py | 41 +++++++ .../redis/redis_connector_doc_perm_sync.py | 15 ++- .../redis/redis_connector_ext_group_sync.py | 30 ++++- .../danswer/redis/redis_connector_index.py | 15 +++ backend/danswer/server/documents/cc_pair.py | 103 +++++++++++++++++- .../common_utils/managers/cc_pair.py | 80 +++++++++++++- .../slack/test_permission_sync.py | 6 +- .../connector_job_tests/slack/test_prune.py | 4 +- .../connector/test_connector_creation.py | 49 ++++++++- .../integration/tests/pruning/test_pruning.py | 2 +- .../configuration/search/UpgradingPage.tsx | 4 +- .../[ccPairId]/ModifyStatusButtonCluster.tsx | 98 +++++++++++------ .../connector/[ccPairId]/ReIndexButton.tsx | 2 +- .../app/admin/connector/[ccPairId]/page.tsx | 1 + 21 files changed, 539 insertions(+), 126 deletions(-) diff --git a/backend/danswer/background/celery/tasks/connector_deletion/tasks.py b/backend/danswer/background/celery/tasks/connector_deletion/tasks.py index 9413dd97854..caae8be301b 100644 --- a/backend/danswer/background/celery/tasks/connector_deletion/tasks.py +++ b/backend/danswer/background/celery/tasks/connector_deletion/tasks.py @@ -5,7 +5,6 @@ from celery import shared_task from celery import Task from celery.exceptions import SoftTimeLimitExceeded -from redis import Redis from redis.lock import Lock as RedisLock from sqlalchemy.orm import Session @@ -37,7 +36,7 @@ class TaskDependencyError(RuntimeError): def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> None: r = get_redis_client(tenant_id=tenant_id) - lock_beat = r.lock( + lock_beat: RedisLock = r.lock( DanswerRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK, timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT, ) @@ -60,7 +59,7 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N redis_connector = RedisConnector(tenant_id, cc_pair_id) try: try_generate_document_cc_pair_cleanup_tasks( - self.app, cc_pair_id, db_session, r, lock_beat, tenant_id + self.app, cc_pair_id, db_session, lock_beat, tenant_id ) except TaskDependencyError as e: # this means we wanted to start deleting but dependent tasks were running @@ -86,7 +85,6 @@ def try_generate_document_cc_pair_cleanup_tasks( app: Celery, cc_pair_id: int, db_session: Session, - r: Redis, lock_beat: RedisLock, tenant_id: str | None, ) -> int | None: diff --git a/backend/danswer/background/celery/tasks/doc_permission_syncing/tasks.py b/backend/danswer/background/celery/tasks/doc_permission_syncing/tasks.py index eef14e980ca..babf9b69b6f 100644 --- a/backend/danswer/background/celery/tasks/doc_permission_syncing/tasks.py +++ b/backend/danswer/background/celery/tasks/doc_permission_syncing/tasks.py @@ -8,6 +8,7 @@ from celery import Task from celery.exceptions import SoftTimeLimitExceeded from redis import Redis +from redis.lock import Lock as RedisLock from danswer.access.models import DocExternalAccess from danswer.background.celery.apps.app_base import task_logger @@ -27,7 +28,7 @@ from danswer.db.users import batch_add_ext_perm_user_if_not_exists from danswer.redis.redis_connector import RedisConnector from danswer.redis.redis_connector_doc_perm_sync import ( - RedisConnectorPermissionSyncData, + RedisConnectorPermissionSyncPayload, ) from danswer.redis.redis_pool import get_redis_client from danswer.utils.logger import doc_permission_sync_ctx @@ -138,7 +139,7 @@ def try_creating_permissions_sync_task( LOCK_TIMEOUT = 30 - lock = r.lock( + lock: RedisLock = r.lock( DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_permissions_sync_tasks", timeout=LOCK_TIMEOUT, ) @@ -162,7 +163,7 @@ def try_creating_permissions_sync_task( custom_task_id = f"{redis_connector.permissions.generator_task_key}_{uuid4()}" - app.send_task( + result = app.send_task( "connector_permission_sync_generator_task", kwargs=dict( cc_pair_id=cc_pair_id, @@ -174,8 +175,8 @@ def try_creating_permissions_sync_task( ) # set a basic fence to start - payload = RedisConnectorPermissionSyncData( - started=None, + payload = RedisConnectorPermissionSyncPayload( + started=None, celery_task_id=result.id ) redis_connector.permissions.set_fence(payload) @@ -247,9 +248,11 @@ def connector_permission_sync_generator_task( logger.info(f"Syncing docs for {source_type} with cc_pair={cc_pair_id}") - payload = RedisConnectorPermissionSyncData( - started=datetime.now(timezone.utc), - ) + payload = redis_connector.permissions.payload + if not payload: + raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}") + + payload.started = datetime.now(timezone.utc) redis_connector.permissions.set_fence(payload) document_external_accesses: list[DocExternalAccess] = doc_sync_func(cc_pair) diff --git a/backend/danswer/background/celery/tasks/external_group_syncing/tasks.py b/backend/danswer/background/celery/tasks/external_group_syncing/tasks.py index c3f0f6c6f15..d80b2b518ee 100644 --- a/backend/danswer/background/celery/tasks/external_group_syncing/tasks.py +++ b/backend/danswer/background/celery/tasks/external_group_syncing/tasks.py @@ -8,6 +8,7 @@ from celery import Task from celery.exceptions import SoftTimeLimitExceeded from redis import Redis +from redis.lock import Lock as RedisLock from danswer.background.celery.apps.app_base import task_logger from danswer.configs.app_configs import JOB_TIMEOUT @@ -24,6 +25,9 @@ from danswer.db.enums import ConnectorCredentialPairStatus from danswer.db.models import ConnectorCredentialPair from danswer.redis.redis_connector import RedisConnector +from danswer.redis.redis_connector_ext_group_sync import ( + RedisConnectorExternalGroupSyncPayload, +) from danswer.redis.redis_pool import get_redis_client from danswer.utils.logger import setup_logger from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs @@ -107,7 +111,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None: cc_pair_ids_to_sync.append(cc_pair.id) for cc_pair_id in cc_pair_ids_to_sync: - tasks_created = try_creating_permissions_sync_task( + tasks_created = try_creating_external_group_sync_task( self.app, cc_pair_id, r, tenant_id ) if not tasks_created: @@ -125,7 +129,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None: lock_beat.release() -def try_creating_permissions_sync_task( +def try_creating_external_group_sync_task( app: Celery, cc_pair_id: int, r: Redis, @@ -156,7 +160,7 @@ def try_creating_permissions_sync_task( custom_task_id = f"{redis_connector.external_group_sync.taskset_key}_{uuid4()}" - _ = app.send_task( + result = app.send_task( "connector_external_group_sync_generator_task", kwargs=dict( cc_pair_id=cc_pair_id, @@ -166,8 +170,13 @@ def try_creating_permissions_sync_task( task_id=custom_task_id, priority=DanswerCeleryPriority.HIGH, ) - # set a basic fence to start - redis_connector.external_group_sync.set_fence(True) + + payload = RedisConnectorExternalGroupSyncPayload( + started=datetime.now(timezone.utc), + celery_task_id=result.id, + ) + + redis_connector.external_group_sync.set_fence(payload) except Exception: task_logger.exception( @@ -203,7 +212,7 @@ def connector_external_group_sync_generator_task( r = get_redis_client(tenant_id=tenant_id) - lock = r.lock( + lock: RedisLock = r.lock( DanswerRedisLocks.CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX + f"_{redis_connector.id}", timeout=CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT, @@ -253,7 +262,6 @@ def connector_external_group_sync_generator_task( ) mark_cc_pair_as_external_group_synced(db_session, cc_pair.id) - except Exception as e: task_logger.exception( f"Failed to run external group sync: cc_pair={cc_pair_id}" @@ -264,6 +272,6 @@ def connector_external_group_sync_generator_task( raise e finally: # we always want to clear the fence after the task is done or failed so it doesn't get stuck - redis_connector.external_group_sync.set_fence(False) + redis_connector.external_group_sync.set_fence(None) if lock.owned(): lock.release() diff --git a/backend/danswer/background/celery/tasks/indexing/tasks.py b/backend/danswer/background/celery/tasks/indexing/tasks.py index 9ebab40d0af..4525b1e9425 100644 --- a/backend/danswer/background/celery/tasks/indexing/tasks.py +++ b/backend/danswer/background/celery/tasks/indexing/tasks.py @@ -39,12 +39,13 @@ from danswer.db.index_attempt import get_all_index_attempts_by_status from danswer.db.index_attempt import get_index_attempt from danswer.db.index_attempt import get_last_attempt_for_cc_pair +from danswer.db.index_attempt import mark_attempt_canceled from danswer.db.index_attempt import mark_attempt_failed from danswer.db.models import ConnectorCredentialPair from danswer.db.models import IndexAttempt from danswer.db.models import SearchSettings +from danswer.db.search_settings import get_active_search_settings from danswer.db.search_settings import get_current_search_settings -from danswer.db.search_settings import get_secondary_search_settings from danswer.db.swap_index import check_index_swap from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface from danswer.natural_language_processing.search_nlp_models import EmbeddingModel @@ -209,17 +210,10 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None: redis_connector = RedisConnector(tenant_id, cc_pair_id) with get_session_with_tenant(tenant_id) as db_session: - # Get the primary search settings - primary_search_settings = get_current_search_settings(db_session) - search_settings = [primary_search_settings] - - # Check for secondary search settings - secondary_search_settings = get_secondary_search_settings(db_session) - if secondary_search_settings is not None: - # If secondary settings exist, add them to the list - search_settings.append(secondary_search_settings) - - for search_settings_instance in search_settings: + search_settings_list: list[SearchSettings] = get_active_search_settings( + db_session + ) + for search_settings_instance in search_settings_list: redis_connector_index = redis_connector.new_index( search_settings_instance.id ) @@ -237,7 +231,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None: ) search_settings_primary = False - if search_settings_instance.id == primary_search_settings.id: + if search_settings_instance.id == search_settings_list[0].id: search_settings_primary = True if not _should_index( @@ -245,13 +239,13 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None: last_index=last_attempt, search_settings_instance=search_settings_instance, search_settings_primary=search_settings_primary, - secondary_index_building=len(search_settings) > 1, + secondary_index_building=len(search_settings_list) > 1, db_session=db_session, ): continue reindex = False - if search_settings_instance.id == primary_search_settings.id: + if search_settings_instance.id == search_settings_list[0].id: # the indexing trigger is only checked and cleared with the primary search settings if cc_pair.indexing_trigger is not None: if cc_pair.indexing_trigger == IndexingMode.REINDEX: @@ -284,7 +278,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None: f"Connector indexing queued: " f"index_attempt={attempt_id} " f"cc_pair={cc_pair.id} " - f"search_settings={search_settings_instance.id} " + f"search_settings={search_settings_instance.id}" ) tasks_created += 1 @@ -529,8 +523,11 @@ def try_creating_indexing_task( return index_attempt_id -@shared_task(name="connector_indexing_proxy_task", acks_late=False, track_started=True) +@shared_task( + name="connector_indexing_proxy_task", bind=True, acks_late=False, track_started=True +) def connector_indexing_proxy_task( + self: Task, index_attempt_id: int, cc_pair_id: int, search_settings_id: int, @@ -543,6 +540,10 @@ def connector_indexing_proxy_task( f"cc_pair={cc_pair_id} " f"search_settings={search_settings_id}" ) + + if not self.request.id: + task_logger.error("self.request.id is None!") + client = SimpleJobClient() job = client.submit( @@ -571,8 +572,30 @@ def connector_indexing_proxy_task( f"search_settings={search_settings_id}" ) + redis_connector = RedisConnector(tenant_id, cc_pair_id) + redis_connector_index = redis_connector.new_index(search_settings_id) + while True: - sleep(10) + sleep(5) + + if self.request.id and redis_connector_index.terminating(self.request.id): + task_logger.warning( + "Indexing proxy - termination signal detected: " + f"attempt={index_attempt_id} " + f"tenant={tenant_id} " + f"cc_pair={cc_pair_id} " + f"search_settings={search_settings_id}" + ) + + with get_session_with_tenant(tenant_id) as db_session: + mark_attempt_canceled( + index_attempt_id, + db_session, + "Connector termination signal detected", + ) + + job.cancel() + break # do nothing for ongoing jobs that haven't been stopped if not job.done(): diff --git a/backend/danswer/background/celery/tasks/vespa/tasks.py b/backend/danswer/background/celery/tasks/vespa/tasks.py index ec7f52bc03c..f491ff27b23 100644 --- a/backend/danswer/background/celery/tasks/vespa/tasks.py +++ b/backend/danswer/background/celery/tasks/vespa/tasks.py @@ -46,6 +46,7 @@ from danswer.db.document_set import get_document_set_by_id from danswer.db.document_set import mark_document_set_as_synced from danswer.db.engine import get_session_with_tenant +from danswer.db.enums import IndexingStatus from danswer.db.index_attempt import delete_index_attempts from danswer.db.index_attempt import get_index_attempt from danswer.db.index_attempt import mark_attempt_failed @@ -58,7 +59,7 @@ from danswer.redis.redis_connector_delete import RedisConnectorDelete from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync from danswer.redis.redis_connector_doc_perm_sync import ( - RedisConnectorPermissionSyncData, + RedisConnectorPermissionSyncPayload, ) from danswer.redis.redis_connector_index import RedisConnectorIndex from danswer.redis.redis_connector_prune import RedisConnectorPrune @@ -588,7 +589,7 @@ def monitor_ccpair_permissions_taskset( if remaining > 0: return - payload: RedisConnectorPermissionSyncData | None = ( + payload: RedisConnectorPermissionSyncPayload | None = ( redis_connector.permissions.payload ) start_time: datetime | None = payload.started if payload else None @@ -596,9 +597,7 @@ def monitor_ccpair_permissions_taskset( mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), start_time) task_logger.info(f"Successfully synced permissions for cc_pair={cc_pair_id}") - redis_connector.permissions.taskset_clear() - redis_connector.permissions.generator_clear() - redis_connector.permissions.set_fence(None) + redis_connector.permissions.reset() def monitor_ccpair_indexing_taskset( @@ -678,11 +677,15 @@ def monitor_ccpair_indexing_taskset( index_attempt = get_index_attempt(db_session, payload.index_attempt_id) if index_attempt: - mark_attempt_failed( - index_attempt_id=payload.index_attempt_id, - db_session=db_session, - failure_reason=msg, - ) + if ( + index_attempt.status != IndexingStatus.CANCELED + and index_attempt.status != IndexingStatus.FAILED + ): + mark_attempt_failed( + index_attempt_id=payload.index_attempt_id, + db_session=db_session, + failure_reason=msg, + ) redis_connector_index.reset() return @@ -692,6 +695,7 @@ def monitor_ccpair_indexing_taskset( task_logger.info( f"Connector indexing finished: cc_pair={cc_pair_id} " f"search_settings={search_settings_id} " + f"progress={progress} " f"status={status_enum.name} " f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}" ) @@ -724,7 +728,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool: # print current queue lengths r_celery = self.app.broker_connection().channel().client # type: ignore - n_celery = celery_get_queue_length("celery", r) + n_celery = celery_get_queue_length("celery", r_celery) n_indexing = celery_get_queue_length( DanswerCeleryQueues.CONNECTOR_INDEXING, r_celery ) diff --git a/backend/danswer/background/indexing/run_indexing.py b/backend/danswer/background/indexing/run_indexing.py index 699e4682caa..40ed778f033 100644 --- a/backend/danswer/background/indexing/run_indexing.py +++ b/backend/danswer/background/indexing/run_indexing.py @@ -19,6 +19,7 @@ from danswer.db.connector_credential_pair import update_connector_credential_pair from danswer.db.engine import get_session_with_tenant from danswer.db.enums import ConnectorCredentialPairStatus +from danswer.db.index_attempt import mark_attempt_canceled from danswer.db.index_attempt import mark_attempt_failed from danswer.db.index_attempt import mark_attempt_partially_succeeded from danswer.db.index_attempt import mark_attempt_succeeded @@ -87,6 +88,10 @@ def _get_connector_runner( ) +class ConnectorStopSignal(Exception): + """A custom exception used to signal a stop in processing.""" + + def _run_indexing( db_session: Session, index_attempt: IndexAttempt, @@ -208,9 +213,7 @@ def _run_indexing( # contents still need to be initially pulled. if callback: if callback.should_stop(): - raise RuntimeError( - "_run_indexing: Connector stop signal detected" - ) + raise ConnectorStopSignal("Connector stop signal detected") # TODO: should we move this into the above callback instead? db_session.refresh(db_cc_pair) @@ -304,26 +307,16 @@ def _run_indexing( ) except Exception as e: logger.exception( - f"Connector run ran into exception after elapsed time: {time.time() - start_time} seconds" + f"Connector run exceptioned after elapsed time: {time.time() - start_time} seconds" ) - # Only mark the attempt as a complete failure if this is the first indexing window. - # Otherwise, some progress was made - the next run will not start from the beginning. - # In this case, it is not accurate to mark it as a failure. When the next run begins, - # if that fails immediately, it will be marked as a failure. - # - # NOTE: if the connector is manually disabled, we should mark it as a failure regardless - # to give better clarity in the UI, as the next run will never happen. - if ( - ind == 0 - or not db_cc_pair.status.is_active() - or index_attempt.status != IndexingStatus.IN_PROGRESS - ): - mark_attempt_failed( + + if isinstance(e, ConnectorStopSignal): + mark_attempt_canceled( index_attempt.id, db_session, - failure_reason=str(e), - full_exception_trace=traceback.format_exc(), + reason=str(e), ) + if is_primary: update_connector_credential_pair( db_session=db_session, @@ -335,6 +328,37 @@ def _run_indexing( if INDEXING_TRACER_INTERVAL > 0: tracer.stop() raise e + else: + # Only mark the attempt as a complete failure if this is the first indexing window. + # Otherwise, some progress was made - the next run will not start from the beginning. + # In this case, it is not accurate to mark it as a failure. When the next run begins, + # if that fails immediately, it will be marked as a failure. + # + # NOTE: if the connector is manually disabled, we should mark it as a failure regardless + # to give better clarity in the UI, as the next run will never happen. + if ( + ind == 0 + or not db_cc_pair.status.is_active() + or index_attempt.status != IndexingStatus.IN_PROGRESS + ): + mark_attempt_failed( + index_attempt.id, + db_session, + failure_reason=str(e), + full_exception_trace=traceback.format_exc(), + ) + + if is_primary: + update_connector_credential_pair( + db_session=db_session, + connector_id=db_connector.id, + credential_id=db_credential.id, + net_docs=net_doc_change, + ) + + if INDEXING_TRACER_INTERVAL > 0: + tracer.stop() + raise e # break => similar to success case. As mentioned above, if the next run fails for the same # reason it will then be marked as a failure diff --git a/backend/danswer/db/search_settings.py b/backend/danswer/db/search_settings.py index 4f437eaae53..1134b326a76 100644 --- a/backend/danswer/db/search_settings.py +++ b/backend/danswer/db/search_settings.py @@ -143,6 +143,25 @@ def get_secondary_search_settings(db_session: Session) -> SearchSettings | None: return latest_settings +def get_active_search_settings(db_session: Session) -> list[SearchSettings]: + """Returns active search settings. The first entry will always be the current search + settings. If there are new search settings that are being migrated to, those will be + the second entry.""" + search_settings_list: list[SearchSettings] = [] + + # Get the primary search settings + primary_search_settings = get_current_search_settings(db_session) + search_settings_list.append(primary_search_settings) + + # Check for secondary search settings + secondary_search_settings = get_secondary_search_settings(db_session) + if secondary_search_settings is not None: + # If secondary settings exist, add them to the list + search_settings_list.append(secondary_search_settings) + + return search_settings_list + + def get_all_search_settings(db_session: Session) -> list[SearchSettings]: query = select(SearchSettings).order_by(SearchSettings.id.desc()) result = db_session.execute(query) diff --git a/backend/danswer/redis/redis_connector.py b/backend/danswer/redis/redis_connector.py index 8b52a2fd811..8d82fc11943 100644 --- a/backend/danswer/redis/redis_connector.py +++ b/backend/danswer/redis/redis_connector.py @@ -1,5 +1,8 @@ +import time + import redis +from danswer.db.models import SearchSettings from danswer.redis.redis_connector_delete import RedisConnectorDelete from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync from danswer.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync @@ -31,6 +34,44 @@ def new_index(self, search_settings_id: int) -> RedisConnectorIndex: self.tenant_id, self.id, search_settings_id, self.redis ) + def wait_for_indexing_termination( + self, + search_settings_list: list[SearchSettings], + timeout: float = 15.0, + ) -> bool: + """ + Returns True if all indexing for the given redis connector is finished within the given timeout. + Returns False if the timeout is exceeded + + This check does not guarantee that current indexings being terminated + won't get restarted midflight + """ + + finished = False + + start = time.monotonic() + + while True: + still_indexing = False + for search_settings in search_settings_list: + redis_connector_index = self.new_index(search_settings.id) + if redis_connector_index.fenced: + still_indexing = True + break + + if not still_indexing: + finished = True + break + + now = time.monotonic() + if now - start > timeout: + break + + time.sleep(1) + continue + + return finished + @staticmethod def get_id_from_fence_key(key: str) -> str | None: """ diff --git a/backend/danswer/redis/redis_connector_doc_perm_sync.py b/backend/danswer/redis/redis_connector_doc_perm_sync.py index d9c3cd814ff..7b3748fcc2d 100644 --- a/backend/danswer/redis/redis_connector_doc_perm_sync.py +++ b/backend/danswer/redis/redis_connector_doc_perm_sync.py @@ -14,8 +14,9 @@ from danswer.configs.constants import DanswerCeleryQueues -class RedisConnectorPermissionSyncData(BaseModel): +class RedisConnectorPermissionSyncPayload(BaseModel): started: datetime | None + celery_task_id: str | None class RedisConnectorPermissionSync: @@ -78,14 +79,14 @@ def fenced(self) -> bool: return False @property - def payload(self) -> RedisConnectorPermissionSyncData | None: + def payload(self) -> RedisConnectorPermissionSyncPayload | None: # read related data and evaluate/print task progress fence_bytes = cast(bytes, self.redis.get(self.fence_key)) if fence_bytes is None: return None fence_str = fence_bytes.decode("utf-8") - payload = RedisConnectorPermissionSyncData.model_validate_json( + payload = RedisConnectorPermissionSyncPayload.model_validate_json( cast(str, fence_str) ) @@ -93,7 +94,7 @@ def payload(self) -> RedisConnectorPermissionSyncData | None: def set_fence( self, - payload: RedisConnectorPermissionSyncData | None, + payload: RedisConnectorPermissionSyncPayload | None, ) -> None: if not payload: self.redis.delete(self.fence_key) @@ -162,6 +163,12 @@ def generate_tasks( return len(async_results) + def reset(self) -> None: + self.redis.delete(self.generator_progress_key) + self.redis.delete(self.generator_complete_key) + self.redis.delete(self.taskset_key) + self.redis.delete(self.fence_key) + @staticmethod def remove_from_taskset(id: int, task_id: str, r: redis.Redis) -> None: taskset_key = f"{RedisConnectorPermissionSync.TASKSET_PREFIX}_{id}" diff --git a/backend/danswer/redis/redis_connector_ext_group_sync.py b/backend/danswer/redis/redis_connector_ext_group_sync.py index 631845648c3..bbe539c3954 100644 --- a/backend/danswer/redis/redis_connector_ext_group_sync.py +++ b/backend/danswer/redis/redis_connector_ext_group_sync.py @@ -1,11 +1,18 @@ +from datetime import datetime from typing import cast import redis from celery import Celery +from pydantic import BaseModel from redis.lock import Lock as RedisLock from sqlalchemy.orm import Session +class RedisConnectorExternalGroupSyncPayload(BaseModel): + started: datetime | None + celery_task_id: str | None + + class RedisConnectorExternalGroupSync: """Manages interactions with redis for external group syncing tasks. Should only be accessed through RedisConnector.""" @@ -68,12 +75,29 @@ def fenced(self) -> bool: return False - def set_fence(self, value: bool) -> None: - if not value: + @property + def payload(self) -> RedisConnectorExternalGroupSyncPayload | None: + # read related data and evaluate/print task progress + fence_bytes = cast(bytes, self.redis.get(self.fence_key)) + if fence_bytes is None: + return None + + fence_str = fence_bytes.decode("utf-8") + payload = RedisConnectorExternalGroupSyncPayload.model_validate_json( + cast(str, fence_str) + ) + + return payload + + def set_fence( + self, + payload: RedisConnectorExternalGroupSyncPayload | None, + ) -> None: + if not payload: self.redis.delete(self.fence_key) return - self.redis.set(self.fence_key, 0) + self.redis.set(self.fence_key, payload.model_dump_json()) @property def generator_complete(self) -> int | None: diff --git a/backend/danswer/redis/redis_connector_index.py b/backend/danswer/redis/redis_connector_index.py index 10fd3667fda..40b194af03e 100644 --- a/backend/danswer/redis/redis_connector_index.py +++ b/backend/danswer/redis/redis_connector_index.py @@ -29,6 +29,8 @@ class RedisConnectorIndex: GENERATOR_LOCK_PREFIX = "da_lock:indexing" + TERMINATE_PREFIX = PREFIX + "_terminate" # connectorindexing_terminate + def __init__( self, tenant_id: str | None, @@ -51,6 +53,7 @@ def __init__( self.generator_lock_key = ( f"{self.GENERATOR_LOCK_PREFIX}_{id}/{search_settings_id}" ) + self.terminate_key = f"{self.TERMINATE_PREFIX}_{id}/{search_settings_id}" @classmethod def fence_key_with_ids(cls, cc_pair_id: int, search_settings_id: int) -> str: @@ -92,6 +95,18 @@ def set_fence( self.redis.set(self.fence_key, payload.model_dump_json()) + def terminating(self, celery_task_id: str) -> bool: + if self.redis.exists(f"{self.terminate_key}_{celery_task_id}"): + return True + + return False + + def set_terminate(self, celery_task_id: str) -> None: + """This sets a signal. It does not block!""" + # We shouldn't need very long to terminate the spawned task. + # 10 minute TTL is good. + self.redis.set(f"{self.terminate_key}_{celery_task_id}", 0, ex=600) + def set_generator_complete(self, payload: int | None) -> None: if not payload: self.redis.delete(self.generator_complete_key) diff --git a/backend/danswer/server/documents/cc_pair.py b/backend/danswer/server/documents/cc_pair.py index 55808ebcee7..46bdb2078c4 100644 --- a/backend/danswer/server/documents/cc_pair.py +++ b/backend/danswer/server/documents/cc_pair.py @@ -6,6 +6,7 @@ from fastapi import Depends from fastapi import HTTPException from fastapi import Query +from fastapi.responses import JSONResponse from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session @@ -37,7 +38,9 @@ from danswer.db.index_attempt import count_index_attempts_for_connector from danswer.db.index_attempt import get_latest_index_attempt_for_cc_pair_id from danswer.db.index_attempt import get_paginated_index_attempts_for_cc_pair_id +from danswer.db.models import SearchSettings from danswer.db.models import User +from danswer.db.search_settings import get_active_search_settings from danswer.db.search_settings import get_current_search_settings from danswer.redis.redis_connector import RedisConnector from danswer.redis.redis_pool import get_redis_client @@ -158,7 +161,19 @@ def update_cc_pair_status( status_update_request: CCStatusUpdateRequest, user: User | None = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), -) -> None: + tenant_id: str | None = Depends(get_current_tenant_id), +) -> JSONResponse: + """This method may wait up to 30 seconds if pausing the connector due to the need to + terminate tasks in progress. Tasks are not guaranteed to terminate within the + timeout. + + Returns HTTPStatus.OK if everything finished. + Returns HTTPStatus.ACCEPTED if the connector is being paused, but background tasks + did not finish within the timeout. + """ + WAIT_TIMEOUT = 15.0 + still_terminating = False + cc_pair = get_connector_credential_pair_from_id( cc_pair_id=cc_pair_id, db_session=db_session, @@ -173,10 +188,76 @@ def update_cc_pair_status( ) if status_update_request.status == ConnectorCredentialPairStatus.PAUSED: - cancel_indexing_attempts_for_ccpair(cc_pair_id, db_session) + search_settings_list: list[SearchSettings] = get_active_search_settings( + db_session + ) + cancel_indexing_attempts_for_ccpair(cc_pair_id, db_session) cancel_indexing_attempts_past_model(db_session) + redis_connector = RedisConnector(tenant_id, cc_pair_id) + + try: + redis_connector.stop.set_fence(True) + while True: + logger.debug( + f"Wait for indexing soft termination starting: cc_pair={cc_pair_id}" + ) + wait_succeeded = redis_connector.wait_for_indexing_termination( + search_settings_list, WAIT_TIMEOUT + ) + if wait_succeeded: + logger.debug( + f"Wait for indexing soft termination succeeded: cc_pair={cc_pair_id}" + ) + break + + logger.debug( + "Wait for indexing soft termination timed out. " + f"Moving to hard termination: cc_pair={cc_pair_id} timeout={WAIT_TIMEOUT:.2f}" + ) + + for search_settings in search_settings_list: + redis_connector_index = redis_connector.new_index( + search_settings.id + ) + if not redis_connector_index.fenced: + continue + + index_payload = redis_connector_index.payload + if not index_payload: + continue + + if not index_payload.celery_task_id: + continue + + # Revoke the task to prevent it from running + primary_app.control.revoke(index_payload.celery_task_id) + + # If it is running, then signaling for termination will get the + # watchdog thread to kill the spawned task + redis_connector_index.set_terminate(index_payload.celery_task_id) + + logger.debug( + f"Wait for indexing hard termination starting: cc_pair={cc_pair_id}" + ) + wait_succeeded = redis_connector.wait_for_indexing_termination( + search_settings_list, WAIT_TIMEOUT + ) + if wait_succeeded: + logger.debug( + f"Wait for indexing hard termination succeeded: cc_pair={cc_pair_id}" + ) + break + + logger.debug( + f"Wait for indexing hard termination timed out: cc_pair={cc_pair_id}" + ) + still_terminating = True + break + finally: + redis_connector.stop.set_fence(False) + update_connector_credential_pair_from_id( db_session=db_session, cc_pair_id=cc_pair_id, @@ -185,6 +266,18 @@ def update_cc_pair_status( db_session.commit() + if still_terminating: + return JSONResponse( + status_code=HTTPStatus.ACCEPTED, + content={ + "message": "Request accepted, background task termination still in progress" + }, + ) + + return JSONResponse( + status_code=HTTPStatus.OK, content={"message": str(HTTPStatus.OK)} + ) + @router.put("/admin/cc-pair/{cc_pair_id}/name") def update_cc_pair_name( @@ -267,9 +360,9 @@ def prune_cc_pair( ) logger.info( - f"Pruning cc_pair: cc_pair_id={cc_pair_id} " - f"connector_id={cc_pair.connector_id} " - f"credential_id={cc_pair.credential_id} " + f"Pruning cc_pair: cc_pair={cc_pair_id} " + f"connector={cc_pair.connector_id} " + f"credential={cc_pair.credential_id} " f"{cc_pair.connector.name} connector." ) tasks_created = try_creating_prune_generator_task( diff --git a/backend/tests/integration/common_utils/managers/cc_pair.py b/backend/tests/integration/common_utils/managers/cc_pair.py index b37822d3496..d32e100563b 100644 --- a/backend/tests/integration/common_utils/managers/cc_pair.py +++ b/backend/tests/integration/common_utils/managers/cc_pair.py @@ -240,7 +240,85 @@ def run_once( result.raise_for_status() @staticmethod - def wait_for_indexing( + def wait_for_indexing_inactive( + cc_pair: DATestCCPair, + timeout: float = MAX_DELAY, + user_performing_action: DATestUser | None = None, + ) -> None: + """wait for the number of docs to be indexed on the connector. + This is used to test pausing a connector in the middle of indexing and + terminating that indexing.""" + print(f"Indexing wait for inactive starting: cc_pair={cc_pair.id}") + start = time.monotonic() + while True: + fetched_cc_pairs = CCPairManager.get_indexing_statuses( + user_performing_action + ) + for fetched_cc_pair in fetched_cc_pairs: + if fetched_cc_pair.cc_pair_id != cc_pair.id: + continue + + if fetched_cc_pair.in_progress: + continue + + print(f"Indexing is inactive: cc_pair={cc_pair.id}") + return + + elapsed = time.monotonic() - start + if elapsed > timeout: + raise TimeoutError( + f"Indexing wait for inactive timed out: cc_pair={cc_pair.id} timeout={timeout}s" + ) + + print( + f"Indexing wait for inactive still waiting: cc_pair={cc_pair.id} elapsed={elapsed:.2f} timeout={timeout}s" + ) + time.sleep(5) + + @staticmethod + def wait_for_indexing_in_progress( + cc_pair: DATestCCPair, + timeout: float = MAX_DELAY, + num_docs: int = 16, + user_performing_action: DATestUser | None = None, + ) -> None: + """wait for the number of docs to be indexed on the connector. + This is used to test pausing a connector in the middle of indexing and + terminating that indexing.""" + start = time.monotonic() + while True: + fetched_cc_pairs = CCPairManager.get_indexing_statuses( + user_performing_action + ) + for fetched_cc_pair in fetched_cc_pairs: + if fetched_cc_pair.cc_pair_id != cc_pair.id: + continue + + if not fetched_cc_pair.in_progress: + continue + + if fetched_cc_pair.docs_indexed >= num_docs: + print( + "Indexed at least the requested number of docs: " + f"cc_pair={cc_pair.id} " + f"docs_indexed={fetched_cc_pair.docs_indexed} " + f"num_docs={num_docs}" + ) + return + + elapsed = time.monotonic() - start + if elapsed > timeout: + raise TimeoutError( + f"Indexing in progress wait timed out: cc_pair={cc_pair.id} timeout={timeout}s" + ) + + print( + f"Indexing in progress waiting: cc_pair={cc_pair.id} elapsed={elapsed:.2f} timeout={timeout}s" + ) + time.sleep(5) + + @staticmethod + def wait_for_indexing_completion( cc_pair: DATestCCPair, after: datetime, timeout: float = MAX_DELAY, diff --git a/backend/tests/integration/connector_job_tests/slack/test_permission_sync.py b/backend/tests/integration/connector_job_tests/slack/test_permission_sync.py index 3c37332547d..8045501ce27 100644 --- a/backend/tests/integration/connector_job_tests/slack/test_permission_sync.py +++ b/backend/tests/integration/connector_job_tests/slack/test_permission_sync.py @@ -78,7 +78,7 @@ def test_slack_permission_sync( access_type=AccessType.SYNC, user_performing_action=admin_user, ) - CCPairManager.wait_for_indexing( + CCPairManager.wait_for_indexing_completion( cc_pair=cc_pair, after=before, user_performing_action=admin_user, @@ -113,7 +113,7 @@ def test_slack_permission_sync( # Run indexing before = datetime.now(timezone.utc) CCPairManager.run_once(cc_pair, admin_user) - CCPairManager.wait_for_indexing( + CCPairManager.wait_for_indexing_completion( cc_pair=cc_pair, after=before, user_performing_action=admin_user, @@ -305,7 +305,7 @@ def test_slack_group_permission_sync( # Run indexing CCPairManager.run_once(cc_pair, admin_user) - CCPairManager.wait_for_indexing( + CCPairManager.wait_for_indexing_completion( cc_pair=cc_pair, after=before, user_performing_action=admin_user, diff --git a/backend/tests/integration/connector_job_tests/slack/test_prune.py b/backend/tests/integration/connector_job_tests/slack/test_prune.py index 2dfc3d0750f..b2decb6584b 100644 --- a/backend/tests/integration/connector_job_tests/slack/test_prune.py +++ b/backend/tests/integration/connector_job_tests/slack/test_prune.py @@ -74,7 +74,7 @@ def test_slack_prune( access_type=AccessType.SYNC, user_performing_action=admin_user, ) - CCPairManager.wait_for_indexing( + CCPairManager.wait_for_indexing_completion( cc_pair=cc_pair, after=before, user_performing_action=admin_user, @@ -113,7 +113,7 @@ def test_slack_prune( # Run indexing before = datetime.now(timezone.utc) CCPairManager.run_once(cc_pair, admin_user) - CCPairManager.wait_for_indexing( + CCPairManager.wait_for_indexing_completion( cc_pair=cc_pair, after=before, user_performing_action=admin_user, diff --git a/backend/tests/integration/tests/connector/test_connector_creation.py b/backend/tests/integration/tests/connector/test_connector_creation.py index acfafe9436d..61085c5a5d2 100644 --- a/backend/tests/integration/tests/connector/test_connector_creation.py +++ b/backend/tests/integration/tests/connector/test_connector_creation.py @@ -58,7 +58,7 @@ def test_overlapping_connector_creation(reset: None) -> None: user_performing_action=admin_user, ) - CCPairManager.wait_for_indexing( + CCPairManager.wait_for_indexing_completion( cc_pair_1, now, timeout=120, user_performing_action=admin_user ) @@ -71,7 +71,7 @@ def test_overlapping_connector_creation(reset: None) -> None: user_performing_action=admin_user, ) - CCPairManager.wait_for_indexing( + CCPairManager.wait_for_indexing_completion( cc_pair_2, now, timeout=120, user_performing_action=admin_user ) @@ -82,3 +82,48 @@ def test_overlapping_connector_creation(reset: None) -> None: assert info_2 assert info_1.num_docs_indexed == info_2.num_docs_indexed + + +def test_connector_pause_while_indexing(reset: None) -> None: + """Tests that we can pause a connector while indexing is in progress and that + tasks end early or abort as a result. + + TODO: This does not specifically test for soft or hard termination code paths. + Design specific tests for those use cases. + """ + admin_user: DATestUser = UserManager.create(name="admin_user") + + config = { + "wiki_base": os.environ["CONFLUENCE_TEST_SPACE_URL"], + "space": "", + "is_cloud": True, + "page_id": "", + } + + credential = { + "confluence_username": os.environ["CONFLUENCE_USER_NAME"], + "confluence_access_token": os.environ["CONFLUENCE_ACCESS_TOKEN"], + } + + # store the time before we create the connector so that we know after + # when the indexing should have started + datetime.now(timezone.utc) + + # create connector + cc_pair_1 = CCPairManager.create_from_scratch( + source=DocumentSource.CONFLUENCE, + connector_specific_config=config, + credential_json=credential, + user_performing_action=admin_user, + ) + + CCPairManager.wait_for_indexing_in_progress( + cc_pair_1, timeout=60, num_docs=16, user_performing_action=admin_user + ) + + CCPairManager.pause_cc_pair(cc_pair_1, user_performing_action=admin_user) + + CCPairManager.wait_for_indexing_inactive( + cc_pair_1, timeout=60, user_performing_action=admin_user + ) + return diff --git a/backend/tests/integration/tests/pruning/test_pruning.py b/backend/tests/integration/tests/pruning/test_pruning.py index 9d9a41c7069..beb1e8efbe9 100644 --- a/backend/tests/integration/tests/pruning/test_pruning.py +++ b/backend/tests/integration/tests/pruning/test_pruning.py @@ -135,7 +135,7 @@ def test_web_pruning(reset: None, vespa_client: vespa_fixture) -> None: user_performing_action=admin_user, ) - CCPairManager.wait_for_indexing( + CCPairManager.wait_for_indexing_completion( cc_pair_1, now, timeout=60, user_performing_action=admin_user ) diff --git a/web/src/app/admin/configuration/search/UpgradingPage.tsx b/web/src/app/admin/configuration/search/UpgradingPage.tsx index ecd7f87316e..98653c4aa42 100644 --- a/web/src/app/admin/configuration/search/UpgradingPage.tsx +++ b/web/src/app/admin/configuration/search/UpgradingPage.tsx @@ -161,7 +161,7 @@ export default function UpgradingPage({ reindexingProgress={sortedReindexingProgress} /> ) : ( - + )} ) : ( @@ -171,7 +171,7 @@ export default function UpgradingPage({

You're currently switching embedding models, but there - are no connectors to re-index. This means the transition will + are no connectors to reindex. This means the transition will be quick and seamless!

diff --git a/web/src/app/admin/connector/[ccPairId]/ModifyStatusButtonCluster.tsx b/web/src/app/admin/connector/[ccPairId]/ModifyStatusButtonCluster.tsx index 71d26a8eb47..b5b4e7ecbf2 100644 --- a/web/src/app/admin/connector/[ccPairId]/ModifyStatusButtonCluster.tsx +++ b/web/src/app/admin/connector/[ccPairId]/ModifyStatusButtonCluster.tsx @@ -6,6 +6,8 @@ import { usePopup } from "@/components/admin/connectors/Popup"; import { mutate } from "swr"; import { buildCCPairInfoUrl } from "./lib"; import { setCCPairStatus } from "@/lib/ccPair"; +import { useState } from "react"; +import { LoadingAnimation } from "@/components/Loading"; export function ModifyStatusButtonCluster({ ccPair, @@ -13,44 +15,72 @@ export function ModifyStatusButtonCluster({ ccPair: CCPairFullInfo; }) { const { popup, setPopup } = usePopup(); + const [isUpdating, setIsUpdating] = useState(false); + + const handleStatusChange = async ( + newStatus: ConnectorCredentialPairStatus + ) => { + if (isUpdating) return; // Prevent double-clicks or multiple requests + setIsUpdating(true); + + try { + // Call the backend to update the status + await setCCPairStatus(ccPair.id, newStatus, setPopup); + + // Use mutate to revalidate the status on the backend + await mutate(buildCCPairInfoUrl(ccPair.id)); + } catch (error) { + console.error("Failed to update status", error); + } finally { + // Reset local updating state and button text after mutation + setIsUpdating(false); + } + }; + + // Compute the button text based on current state and backend status + const buttonText = + ccPair.status === ConnectorCredentialPairStatus.PAUSED + ? "Re-Enable" + : "Pause"; + + const tooltip = + ccPair.status === ConnectorCredentialPairStatus.PAUSED + ? "Click to start indexing again!" + : "When paused, the connector's documents will still be visible. However, no new documents will be indexed."; return ( <> {popup} - {ccPair.status === ConnectorCredentialPairStatus.PAUSED ? ( - - ) : ( - - )} + ); } diff --git a/web/src/app/admin/connector/[ccPairId]/ReIndexButton.tsx b/web/src/app/admin/connector/[ccPairId]/ReIndexButton.tsx index af0e2a8f4aa..962339e9fe8 100644 --- a/web/src/app/admin/connector/[ccPairId]/ReIndexButton.tsx +++ b/web/src/app/admin/connector/[ccPairId]/ReIndexButton.tsx @@ -121,7 +121,7 @@ export function ReIndexButton({ {popup}

{!documentSet.is_up_to_date && ( - +
Cannot update while syncing! Wait for the sync to finish, then diff --git a/web/src/app/admin/settings/SettingsForm.tsx b/web/src/app/admin/settings/SettingsForm.tsx index f00e4d978b5..5e2eb00335a 100644 --- a/web/src/app/admin/settings/SettingsForm.tsx +++ b/web/src/app/admin/settings/SettingsForm.tsx @@ -175,29 +175,6 @@ export function SettingsForm() { { fieldName, newValue: checked }, ]; - // If we're disabling a page, check if we need to update the default page - if ( - !checked && - (fieldName === "search_page_enabled" || fieldName === "chat_page_enabled") - ) { - const otherPageField = - fieldName === "search_page_enabled" - ? "chat_page_enabled" - : "search_page_enabled"; - const otherPageEnabled = settings && settings[otherPageField]; - - if ( - otherPageEnabled && - settings?.default_page === - (fieldName === "search_page_enabled" ? "search" : "chat") - ) { - updates.push({ - fieldName: "default_page", - newValue: fieldName === "search_page_enabled" ? "chat" : "search", - }); - } - } - updateSettingField(updates); } @@ -218,42 +195,17 @@ export function SettingsForm() { return (
{popup} - Page Visibility + Workspace Settings - handleToggleSettingsField("search_page_enabled", e.target.checked) + handleToggleSettingsField("auto_scroll", e.target.checked) } /> - - handleToggleSettingsField("chat_page_enabled", e.target.checked) - } - /> - - { - value && - updateSettingField([ - { fieldName: "default_page", newValue: value }, - ]); - }} - /> - {isEnterpriseEnabled && ( <> Chat Settings diff --git a/web/src/app/admin/settings/interfaces.ts b/web/src/app/admin/settings/interfaces.ts index 38959fc8cd2..32ce1d01067 100644 --- a/web/src/app/admin/settings/interfaces.ts +++ b/web/src/app/admin/settings/interfaces.ts @@ -5,14 +5,12 @@ export enum GatingType { } export interface Settings { - chat_page_enabled: boolean; - search_page_enabled: boolean; - default_page: "search" | "chat"; maximum_chat_retention_days: number | null; notifications: Notification[]; needs_reindexing: boolean; gpu_enabled: boolean; product_gating: GatingType; + auto_scroll: boolean; } export enum NotificationType { @@ -54,6 +52,7 @@ export interface EnterpriseSettings { custom_popup_header: string | null; custom_popup_content: string | null; enable_consent_screen: boolean | null; + auto_scroll: boolean; } export interface CombinedSettings { diff --git a/web/src/app/chat/ChatPage.tsx b/web/src/app/chat/ChatPage.tsx index 94f336ba885..634dc0624b8 100644 --- a/web/src/app/chat/ChatPage.tsx +++ b/web/src/app/chat/ChatPage.tsx @@ -8,7 +8,6 @@ import { ChatFileType, ChatSession, ChatSessionSharedStatus, - DocumentsResponse, FileDescriptor, FileChatDisplay, Message, @@ -60,7 +59,7 @@ import { useDocumentSelection } from "./useDocumentSelection"; import { LlmOverride, useFilters, useLlmOverride } from "@/lib/hooks"; import { computeAvailableFilters } from "@/lib/filters"; import { ChatState, FeedbackType, RegenerationState } from "./types"; -import { DocumentSidebar } from "./documentSidebar/DocumentSidebar"; +import { ChatFilters } from "./documentSidebar/ChatFilters"; import { DanswerInitializingLoader } from "@/components/DanswerInitializingLoader"; import { FeedbackModal } from "./modal/FeedbackModal"; import { ShareChatSessionModal } from "./modal/ShareChatSessionModal"; @@ -71,6 +70,7 @@ import { StarterMessages } from "../../components/assistants/StarterMessage"; import { AnswerPiecePacket, DanswerDocument, + FinalContextDocs, StreamStopInfo, StreamStopReason, } from "@/lib/search/interfaces"; @@ -105,14 +105,9 @@ import BlurBackground from "./shared_chat_search/BlurBackground"; import { NoAssistantModal } from "@/components/modals/NoAssistantModal"; import { useAssistants } from "@/components/context/AssistantsContext"; import { Separator } from "@/components/ui/separator"; -import { - Card, - CardContent, - CardDescription, - CardHeader, -} from "@/components/ui/card"; -import { AssistantIcon } from "@/components/assistants/AssistantIcon"; import AssistantBanner from "../../components/assistants/AssistantBanner"; +import AssistantSelector from "@/components/chat_search/AssistantSelector"; +import { Modal } from "@/components/Modal"; const TEMP_USER_MESSAGE_ID = -1; const TEMP_ASSISTANT_MESSAGE_ID = -2; @@ -132,8 +127,9 @@ export function ChatPage({ const { chatSessions, - availableSources, - availableDocumentSets, + ccPairs, + tags, + documentSets, llmProviders, folders, openedFolders, @@ -142,6 +138,36 @@ export function ChatPage({ shouldShowWelcomeModal, refreshChatSessions, } = useChatContext(); + function useScreenSize() { + const [screenSize, setScreenSize] = useState({ + width: typeof window !== "undefined" ? window.innerWidth : 0, + height: typeof window !== "undefined" ? window.innerHeight : 0, + }); + + useEffect(() => { + const handleResize = () => { + setScreenSize({ + width: window.innerWidth, + height: window.innerHeight, + }); + }; + + window.addEventListener("resize", handleResize); + return () => window.removeEventListener("resize", handleResize); + }, []); + + return screenSize; + } + + const { height: screenHeight } = useScreenSize(); + + const getContainerHeight = () => { + if (autoScrollEnabled) return undefined; + + if (screenHeight < 600) return "20vh"; + if (screenHeight < 1200) return "30vh"; + return "40vh"; + }; // handle redirect if chat page is disabled // NOTE: this must be done here, in a client component since @@ -149,9 +175,11 @@ export function ChatPage({ // available in server-side components const settings = useContext(SettingsContext); const enterpriseSettings = settings?.enterpriseSettings; - if (settings?.settings?.chat_page_enabled === false) { - router.push("/search"); - } + + const [documentSidebarToggled, setDocumentSidebarToggled] = useState(false); + const [filtersToggled, setFiltersToggled] = useState(false); + + const [userSettingsToggled, setUserSettingsToggled] = useState(false); const { assistants: availableAssistants, finalAssistants } = useAssistants(); @@ -159,16 +187,13 @@ export function ChatPage({ !shouldShowWelcomeModal ); - const { user, isAdmin, isLoadingUser, refreshUser } = useUser(); - + const { user, isAdmin, isLoadingUser } = useUser(); const slackChatId = searchParams.get("slackChatId"); - const existingChatIdRaw = searchParams.get("chatId"); const [sendOnLoad, setSendOnLoad] = useState( searchParams.get(SEARCH_PARAM_NAMES.SEND_ON_LOAD) ); - const currentPersonaId = searchParams.get(SEARCH_PARAM_NAMES.PERSONA_ID); const modelVersionFromSearchParams = searchParams.get( SEARCH_PARAM_NAMES.STRUCTURED_MODEL ); @@ -261,7 +286,7 @@ export function ChatPage({ refreshRecentAssistants, } = useAssistants(); - const liveAssistant = + const liveAssistant: Persona | undefined = alternativeAssistant || selectedAssistant || recentAssistants[0] || @@ -269,8 +294,20 @@ export function ChatPage({ availableAssistants[0]; const noAssistants = liveAssistant == null || liveAssistant == undefined; + + const availableSources = ccPairs.map((ccPair) => ccPair.source); + const [finalAvailableSources, finalAvailableDocumentSets] = + computeAvailableFilters({ + selectedPersona: availableAssistants.find( + (assistant) => assistant.id === liveAssistant?.id + ), + availableSources: availableSources, + availableDocumentSets: documentSets, + }); + // always set the model override for the chat session, when an assistant, llm provider, or user preference exists useEffect(() => { + if (noAssistants) return; const personaDefault = getLLMProviderOverrideForPersona( liveAssistant, llmProviders @@ -357,9 +394,7 @@ export function ChatPage({ textAreaRef.current?.focus(); // only clear things if we're going from one chat session to another - const isChatSessionSwitch = - chatSessionIdRef.current !== null && - existingChatSessionId !== priorChatSessionId; + const isChatSessionSwitch = existingChatSessionId !== priorChatSessionId; if (isChatSessionSwitch) { // de-select documents clearSelectedDocuments(); @@ -449,9 +484,9 @@ export function ChatPage({ } if (shouldScrollToBottom) { - if (!hasPerformedInitialScroll) { + if (!hasPerformedInitialScroll && autoScrollEnabled) { clientScrollToBottom(); - } else if (isChatSessionSwitch) { + } else if (isChatSessionSwitch && autoScrollEnabled) { clientScrollToBottom(true); } } @@ -759,7 +794,7 @@ export function ChatPage({ useEffect(() => { async function fetchMaxTokens() { const response = await fetch( - `/api/chat/max-selected-document-tokens?persona_id=${liveAssistant.id}` + `/api/chat/max-selected-document-tokens?persona_id=${liveAssistant?.id}` ); if (response.ok) { const maxTokens = (await response.json()).max_tokens as number; @@ -833,11 +868,13 @@ export function ChatPage({ 0 )}px`; - scrollableDivRef?.current.scrollBy({ - left: 0, - top: Math.max(heightDifference, 0), - behavior: "smooth", - }); + if (autoScrollEnabled) { + scrollableDivRef?.current.scrollBy({ + left: 0, + top: Math.max(heightDifference, 0), + behavior: "smooth", + }); + } } previousHeight.current = newHeight; } @@ -884,6 +921,7 @@ export function ChatPage({ endDivRef.current.scrollIntoView({ behavior: fast ? "auto" : "smooth", }); + setHasPerformedInitialScroll(true); } }, 50); @@ -1035,7 +1073,9 @@ export function ChatPage({ } setAlternativeGeneratingAssistant(alternativeAssistantOverride); + clientScrollToBottom(); + let currChatSessionId: string; const isNewSession = chatSessionIdRef.current === null; const searchParamBasedChatSessionName = @@ -1281,8 +1321,8 @@ export function ChatPage({ if (Object.hasOwn(packet, "answer_piece")) { answer += (packet as AnswerPiecePacket).answer_piece; - } else if (Object.hasOwn(packet, "top_documents")) { - documents = (packet as DocumentsResponse).top_documents; + } else if (Object.hasOwn(packet, "final_context_docs")) { + documents = (packet as FinalContextDocs).final_context_docs; retrievalType = RetrievalType.Search; if (documents && documents.length > 0) { // point to the latest message (we don't know the messageId yet, which is why @@ -1379,8 +1419,7 @@ export function ChatPage({ type: error ? "error" : "assistant", retrievalType, query: finalMessage?.rephrased_query || query, - documents: - finalMessage?.context_docs?.top_documents || documents, + documents: documents, citations: finalMessage?.citations || {}, files: finalMessage?.files || aiMessageImages || [], toolCall: finalMessage?.tool_call || toolCall, @@ -1599,6 +1638,11 @@ export function ChatPage({ mobile: settings?.isMobile, }); + const autoScrollEnabled = + user?.preferences?.auto_scroll == null + ? settings?.enterpriseSettings?.auto_scroll || false + : user?.preferences?.auto_scroll!; + useScrollonStream({ chatState: currentSessionChatState, scrollableDivRef, @@ -1607,6 +1651,7 @@ export function ChatPage({ debounceNumber, waitForScrollRef, mobile: settings?.isMobile, + enableAutoScroll: autoScrollEnabled, }); // Virtualization + Scrolling related effects and functions @@ -1756,6 +1801,13 @@ export function ChatPage({ liveAssistant ); }); + + useEffect(() => { + if (!retrievalEnabled) { + setDocumentSidebarToggled(false); + } + }, [retrievalEnabled]); + const [stackTraceModalContent, setStackTraceModalContent] = useState< string | null >(null); @@ -1764,7 +1816,41 @@ export function ChatPage({ const [settingsToggled, setSettingsToggled] = useState(false); const currentPersona = alternativeAssistant || liveAssistant; + useEffect(() => { + const handleSlackChatRedirect = async () => { + if (!slackChatId) return; + + // Set isReady to false before starting retrieval to display loading text + setIsReady(false); + + try { + const response = await fetch("/api/chat/seed-chat-session-from-slack", { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + chat_session_id: slackChatId, + }), + }); + + if (!response.ok) { + throw new Error("Failed to seed chat from Slack"); + } + + const data = await response.json(); + router.push(data.redirect_url); + } catch (error) { + console.error("Error seeding chat from Slack:", error); + setPopup({ + message: "Failed to load chat from Slack", + type: "error", + }); + } + }; + handleSlackChatRedirect(); + }, [searchParams, router]); useEffect(() => { const handleKeyDown = (event: KeyboardEvent) => { if (event.metaKey || event.ctrlKey) { @@ -1795,9 +1881,30 @@ export function ChatPage({ setSharedChatSession(chatSession); }; const [documentSelection, setDocumentSelection] = useState(false); - const toggleDocumentSelectionAspects = () => { - setDocumentSelection((documentSelection) => !documentSelection); - setShowDocSidebar(false); + // const toggleDocumentSelectionAspects = () => { + // setDocumentSelection((documentSelection) => !documentSelection); + // setShowDocSidebar(false); + // }; + + const toggleDocumentSidebar = () => { + if (!documentSidebarToggled) { + setFiltersToggled(false); + setDocumentSidebarToggled(true); + } else if (!filtersToggled) { + setDocumentSidebarToggled(false); + } else { + setFiltersToggled(false); + } + }; + const toggleFilters = () => { + if (!documentSidebarToggled) { + setFiltersToggled(true); + setDocumentSidebarToggled(true); + } else if (filtersToggled) { + setDocumentSidebarToggled(false); + } else { + setFiltersToggled(true); + } }; interface RegenerationRequest { @@ -1815,54 +1922,23 @@ export function ChatPage({ }); }; } - - useEffect(() => { - const handleSlackChatRedirect = async () => { - if (!slackChatId) return; - - // Set isReady to false before starting retrieval to display loading text - setIsReady(false); - - try { - const response = await fetch("/api/chat/seed-chat-session-from-slack", { - method: "POST", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify({ - chat_session_id: slackChatId, - }), - }); - - if (!response.ok) { - throw new Error("Failed to seed chat from Slack"); - } - - const data = await response.json(); - router.push(data.redirect_url); - } catch (error) { - console.error("Error seeding chat from Slack:", error); - setPopup({ - message: "Failed to load chat from Slack", - type: "error", - }); - } - }; - - handleSlackChatRedirect(); - }, [searchParams, router]); + if (noAssistants) + return ( + <> + + + + ); return ( <> - {showApiKeyModal && !shouldShowWelcomeModal ? ( + {showApiKeyModal && !shouldShowWelcomeModal && ( setShowApiKeyModal(false)} setPopup={setPopup} /> - ) : ( - noAssistants && )} {/* ChatPopup is a custom popup that displays a admin-specified message on initial user visit. @@ -1886,16 +1962,46 @@ export function ChatPage({ /> )} - {settingsToggled && ( + {(settingsToggled || userSettingsToggled) && ( setSettingsToggled(false)} + onClose={() => { + setUserSettingsToggled(false); + setSettingsToggled(false); + }} /> )} + {retrievalEnabled && documentSidebarToggled && settings?.isMobile && ( +
+ + { + setDocumentSidebarToggled(false); + }} + selectedMessage={aiMessage} + selectedDocuments={selectedDocuments} + toggleDocumentSelection={toggleDocumentSelection} + clearSelectedDocuments={clearSelectedDocuments} + selectedDocumentTokens={selectedDocumentTokens} + maxTokens={maxTokens} + initialWidth={400} + isOpen={true} + /> + +
+ )} + {deletingChatSession && (
+ {!settings?.isMobile && retrievalEnabled && ( +
+ setDocumentSidebarToggled(false)} + selectedMessage={aiMessage} + selectedDocuments={selectedDocuments} + toggleDocumentSelection={toggleDocumentSelection} + clearSelectedDocuments={clearSelectedDocuments} + selectedDocumentTokens={selectedDocumentTokens} + maxTokens={maxTokens} + initialWidth={400} + isOpen={documentSidebarToggled} + /> +
+ )} -
+
{liveAssistant && ( setUserSettingsToggled(true)} + liveAssistant={liveAssistant} + onAssistantChange={onAssistantChange} sidebarToggled={toggledSidebar} reset={() => setMessage("")} page="chat" @@ -2018,6 +2171,8 @@ export function ChatPage({ } toggleSidebar={toggleSidebar} currentChatSession={selectedChatSession} + documentSidebarToggled={documentSidebarToggled} + llmOverrideManager={llmOverrideManager} /> )} @@ -2039,7 +2194,7 @@ export function ChatPage({ duration-300 ease-in-out h-full - ${toggledSidebar ? "w-[250px]" : "w-[0px]"} + ${toggledSidebar ? "w-[200px]" : "w-[0px]"} `} >
)} @@ -2049,9 +2204,55 @@ export function ChatPage({ {...getRootProps()} >
+ {liveAssistant && onAssistantChange && ( +
+ {!settings?.isMobile && ( +
+ )} + + + {!settings?.isMobile && ( +
+ )} +
+ )} + {/* ChatBanner is a custom banner that displays a admin-specified message at the top of the chat page. Oly used in the EE version of the app. */} @@ -2059,7 +2260,7 @@ export function ChatPage({ !isFetchingChatMessages && currentSessionChatState == "input" && !loadingError && ( -
+
{ + if ( + !documentSidebarToggled || + (documentSidebarToggled && + selectedMessageForDocDisplay === + message.messageId) + ) { + toggleDocumentSidebar(); + } + setSelectedMessageForDocDisplay( + message.messageId + ); + }} docs={message.documents} currentPersona={liveAssistant} alternativeAssistant={ @@ -2268,7 +2488,6 @@ export function ChatPage({ } messageId={message.messageId} content={message.message} - // content={message.message} files={message.files} query={ messageHistory[i]?.query || undefined @@ -2454,6 +2673,15 @@ export function ChatPage({ />
)} + {messageHistory.length > 0 && ( +
+ )} {/* Some padding at the bottom so the search bar has space at the bottom to not cover the last message*/}
@@ -2477,6 +2705,15 @@ export function ChatPage({
)} { + clearSelectedDocuments(); + }} + removeFilters={() => { + filterManager.setSelectedSources([]); + filterManager.setSelectedTags([]); + filterManager.setSelectedDocumentSets([]); + setDocumentSidebarToggled(false); + }} showConfigureAPIKey={() => setShowApiKeyModal(true) } @@ -2499,6 +2736,9 @@ export function ChatPage({ llmOverrideManager={llmOverrideManager} files={currentMessageFiles} setFiles={setCurrentMessageFiles} + toggleFilters={ + retrievalEnabled ? toggleFilters : undefined + } handleFileUpload={handleImageUpload} textAreaRef={textAreaRef} chatSessionId={chatSessionIdRef.current!} @@ -2529,6 +2769,23 @@ export function ChatPage({
+ {!settings?.isMobile && ( +
+ )}
)} @@ -2537,7 +2794,11 @@ export function ChatPage({
@@ -2548,20 +2809,8 @@ export function ChatPage({
+ {/* Right Sidebar - DocumentSidebar */}
- setDocumentSelection(false)} - selectedMessage={aiMessage} - selectedDocuments={selectedDocuments} - toggleDocumentSelection={toggleDocumentSelection} - clearSelectedDocuments={clearSelectedDocuments} - selectedDocumentTokens={selectedDocumentTokens} - maxTokens={maxTokens} - isLoading={isFetchingChatMessages} - isOpen={documentSelection} - /> ); } diff --git a/web/src/app/chat/documentSidebar/ChatDocumentDisplay.tsx b/web/src/app/chat/documentSidebar/ChatDocumentDisplay.tsx index 85ac429c497..5f61e4b9db8 100644 --- a/web/src/app/chat/documentSidebar/ChatDocumentDisplay.tsx +++ b/web/src/app/chat/documentSidebar/ChatDocumentDisplay.tsx @@ -1,133 +1,117 @@ -import { HoverPopup } from "@/components/HoverPopup"; import { SourceIcon } from "@/components/SourceIcon"; -import { PopupSpec } from "@/components/admin/connectors/Popup"; import { DanswerDocument } from "@/lib/search/interfaces"; -import { FiInfo, FiRadio } from "react-icons/fi"; +import { FiTag } from "react-icons/fi"; import { DocumentSelector } from "./DocumentSelector"; -import { - DocumentMetadataBlock, - buildDocumentSummaryDisplay, -} from "@/components/search/DocumentDisplay"; -import { InternetSearchIcon } from "@/components/InternetSearchIcon"; +import { buildDocumentSummaryDisplay } from "@/components/search/DocumentDisplay"; +import { DocumentUpdatedAtBadge } from "@/components/search/DocumentUpdatedAtBadge"; +import { MetadataBadge } from "@/components/MetadataBadge"; +import { WebResultIcon } from "@/components/WebResultIcon"; interface DocumentDisplayProps { document: DanswerDocument; - queryEventId: number | null; - isAIPick: boolean; + modal?: boolean; isSelected: boolean; handleSelect: (documentId: string) => void; - setPopup: (popupSpec: PopupSpec | null) => void; tokenLimitReached: boolean; } +export function DocumentMetadataBlock({ + modal, + document, +}: { + modal?: boolean; + document: DanswerDocument; +}) { + const MAX_METADATA_ITEMS = 3; + const metadataEntries = Object.entries(document.metadata); + + return ( +
+ {document.updated_at && ( + + )} + + {metadataEntries.length > 0 && ( + <> +
+
+ {metadataEntries + .slice(0, MAX_METADATA_ITEMS) + .map(([key, value], index) => ( + + ))} + {metadataEntries.length > MAX_METADATA_ITEMS && ( + ... + )} +
+ + )} +
+ ); +} + export function ChatDocumentDisplay({ document, - queryEventId, - isAIPick, + modal, isSelected, handleSelect, - setPopup, tokenLimitReached, }: DocumentDisplayProps) { const isInternet = document.is_internet; - // Consider reintroducing null scored docs in the future if (document.score === null) { return null; } return ( -
-
+
+
- {isInternet ? ( - - ) : ( - - )} -

- {document.semantic_identifier || document.document_id} -

-
- {document.score !== null && ( -
- {isAIPick && ( -
- } - popupContent={ -
-
-
- -
-
The AI liked this doc!
-
-
- } - direction="bottom" - style="dark" - /> -
+
+ {document.is_internet || document.source_type === "web" ? ( + + ) : ( + )} -
- {Math.abs(document.score).toFixed(2)} +
+ {(document.semantic_identifier || document.document_id).length > + (modal ? 30 : 40) + ? `${(document.semantic_identifier || document.document_id) + .slice(0, modal ? 30 : 40) + .trim()}...` + : document.semantic_identifier || document.document_id}
- )} - - {!isInternet && ( - handleSelect(document.document_id)} - isDisabled={tokenLimitReached && !isSelected} - /> - )} -
-
-
- -
-
-

- {buildDocumentSummaryDisplay(document.match_highlights, document.blurb)} - test -

-
- {/* - // TODO: find a way to include this - {queryEventId && ( - - )} */} + +
+ {buildDocumentSummaryDisplay( + document.match_highlights, + document.blurb + )} +
+
+ {!isInternet && ( + handleSelect(document.document_id)} + isDisabled={tokenLimitReached && !isSelected} + /> + )} +
+
); diff --git a/web/src/app/chat/documentSidebar/ChatFilters.tsx b/web/src/app/chat/documentSidebar/ChatFilters.tsx new file mode 100644 index 00000000000..616595abfc1 --- /dev/null +++ b/web/src/app/chat/documentSidebar/ChatFilters.tsx @@ -0,0 +1,186 @@ +import { DanswerDocument } from "@/lib/search/interfaces"; +import { ChatDocumentDisplay } from "./ChatDocumentDisplay"; +import { usePopup } from "@/components/admin/connectors/Popup"; +import { removeDuplicateDocs } from "@/lib/documentUtils"; +import { Message } from "../interfaces"; +import { ForwardedRef, forwardRef, useEffect, useState } from "react"; +import { FilterManager } from "@/lib/hooks"; +import { CCPairBasicInfo, DocumentSet, Tag } from "@/lib/types"; +import { SourceSelector } from "../shared_chat_search/SearchFilters"; +import { XIcon } from "@/components/icons/icons"; + +interface ChatFiltersProps { + filterManager: FilterManager; + closeSidebar: () => void; + selectedMessage: Message | null; + selectedDocuments: DanswerDocument[] | null; + toggleDocumentSelection: (document: DanswerDocument) => void; + clearSelectedDocuments: () => void; + selectedDocumentTokens: number; + maxTokens: number; + initialWidth: number; + isOpen: boolean; + modal: boolean; + ccPairs: CCPairBasicInfo[]; + tags: Tag[]; + documentSets: DocumentSet[]; + showFilters: boolean; +} + +export const ChatFilters = forwardRef( + ( + { + closeSidebar, + modal, + selectedMessage, + selectedDocuments, + filterManager, + toggleDocumentSelection, + clearSelectedDocuments, + selectedDocumentTokens, + maxTokens, + initialWidth, + isOpen, + ccPairs, + tags, + documentSets, + showFilters, + }, + ref: ForwardedRef + ) => { + const { popup, setPopup } = usePopup(); + const [delayedSelectedDocumentCount, setDelayedSelectedDocumentCount] = + useState(0); + + useEffect(() => { + const timer = setTimeout( + () => { + setDelayedSelectedDocumentCount(selectedDocuments?.length || 0); + }, + selectedDocuments?.length == 0 ? 1000 : 0 + ); + + return () => clearTimeout(timer); + }, [selectedDocuments]); + + const selectedDocumentIds = + selectedDocuments?.map((document) => document.document_id) || []; + + const currentDocuments = selectedMessage?.documents || null; + const dedupedDocuments = removeDuplicateDocs(currentDocuments || []); + + const tokenLimitReached = selectedDocumentTokens > maxTokens - 75; + + const hasSelectedDocuments = selectedDocumentIds.length > 0; + + return ( +
{ + if (e.target === e.currentTarget) { + closeSidebar(); + } + }} + > +
+
+ {popup} +
+

+ {showFilters ? "Filters" : "Sources"} +

+ +
+
+
+ {showFilters ? ( + ccPair.source)} + availableTags={tags} + /> + ) : ( + <> + {dedupedDocuments.length > 0 ? ( + dedupedDocuments.map((document, ind) => ( +
+ { + toggleDocumentSelection( + dedupedDocuments.find( + (doc) => doc.document_id === documentId + )! + ); + }} + tokenLimitReached={tokenLimitReached} + /> +
+ )) + ) : ( +
+ )} + + )} +
+
+ {!showFilters && ( +
+ +
+ )} +
+
+ ); + } +); + +ChatFilters.displayName = "ChatFilters"; diff --git a/web/src/app/chat/documentSidebar/DocumentSidebar.tsx b/web/src/app/chat/documentSidebar/DocumentSidebar.tsx deleted file mode 100644 index 021c2398157..00000000000 --- a/web/src/app/chat/documentSidebar/DocumentSidebar.tsx +++ /dev/null @@ -1,168 +0,0 @@ -import { DanswerDocument } from "@/lib/search/interfaces"; -import Text from "@/components/ui/text"; -import { ChatDocumentDisplay } from "./ChatDocumentDisplay"; -import { usePopup } from "@/components/admin/connectors/Popup"; -import { removeDuplicateDocs } from "@/lib/documentUtils"; -import { Message } from "../interfaces"; -import { ForwardedRef, forwardRef } from "react"; -import { Separator } from "@/components/ui/separator"; - -interface DocumentSidebarProps { - closeSidebar: () => void; - selectedMessage: Message | null; - selectedDocuments: DanswerDocument[] | null; - toggleDocumentSelection: (document: DanswerDocument) => void; - clearSelectedDocuments: () => void; - selectedDocumentTokens: number; - maxTokens: number; - isLoading: boolean; - initialWidth: number; - isOpen: boolean; -} - -export const DocumentSidebar = forwardRef( - ( - { - closeSidebar, - selectedMessage, - selectedDocuments, - toggleDocumentSelection, - clearSelectedDocuments, - selectedDocumentTokens, - maxTokens, - isLoading, - initialWidth, - isOpen, - }, - ref: ForwardedRef - ) => { - const { popup, setPopup } = usePopup(); - - const selectedDocumentIds = - selectedDocuments?.map((document) => document.document_id) || []; - - const currentDocuments = selectedMessage?.documents || null; - const dedupedDocuments = removeDuplicateDocs(currentDocuments || []); - - // NOTE: do not allow selection if less than 75 tokens are left - // this is to prevent the case where they are able to select the doc - // but it basically is unused since it's truncated right at the very - // start of the document (since title + metadata + misc overhead) takes up - // space - const tokenLimitReached = selectedDocumentTokens > maxTokens - 75; - - return ( -
{ - if (e.target === e.currentTarget) { - closeSidebar(); - } - }} - > -
-
- {popup} -
- {dedupedDocuments.length} Document - {dedupedDocuments.length > 1 ? "s" : ""} -

- Select to add to continuous context - - Learn more - -

-
- - - - {currentDocuments ? ( -
- {dedupedDocuments.length > 0 ? ( - dedupedDocuments.map((document, ind) => ( -
- { - toggleDocumentSelection( - dedupedDocuments.find( - (document) => document.document_id === documentId - )! - ); - }} - tokenLimitReached={tokenLimitReached} - /> -
- )) - ) : ( -
- No documents found for the query. -
- )} -
- ) : ( - !isLoading && ( -
- - When you run ask a question, the retrieved documents will - show up here! - -
- ) - )} -
- -
-
- - - -
-
-
- ); - } -); - -DocumentSidebar.displayName = "DocumentSidebar"; diff --git a/web/src/app/chat/input/ChatInputBar.tsx b/web/src/app/chat/input/ChatInputBar.tsx index 9dd3d5274c4..5d786a1784a 100644 --- a/web/src/app/chat/input/ChatInputBar.tsx +++ b/web/src/app/chat/input/ChatInputBar.tsx @@ -1,13 +1,9 @@ import React, { useContext, useEffect, useRef, useState } from "react"; -import { FiPlusCircle, FiPlus, FiInfo, FiX } from "react-icons/fi"; +import { FiPlusCircle, FiPlus, FiInfo, FiX, FiSearch } from "react-icons/fi"; import { ChatInputOption } from "./ChatInputOption"; import { Persona } from "@/app/admin/assistants/interfaces"; import { InputPrompt } from "@/app/admin/prompt-library/interfaces"; -import { - FilterManager, - getDisplayNameForModel, - LlmOverrideManager, -} from "@/lib/hooks"; +import { FilterManager, LlmOverrideManager } from "@/lib/hooks"; import { SelectedFilterDisplay } from "./SelectedFilterDisplay"; import { useChatContext } from "@/components/context/ChatContext"; import { getFinalLLM } from "@/lib/llm/utils"; @@ -18,15 +14,10 @@ import { } from "../files/InputBarPreview"; import { AssistantsIconSkeleton, - CpuIconSkeleton, FileIcon, SendIcon, StopGeneratingIcon, } from "@/components/icons/icons"; -import { IconType } from "react-icons"; -import Popup from "../../../components/popup/Popup"; -import { LlmTab } from "../modal/configuration/LlmTab"; -import { AssistantsTab } from "../modal/configuration/AssistantsTab"; import { DanswerDocument } from "@/lib/search/interfaces"; import { AssistantIcon } from "@/components/assistants/AssistantIcon"; import { @@ -40,10 +31,18 @@ import { SettingsContext } from "@/components/settings/SettingsProvider"; import { ChatState } from "../types"; import UnconfiguredProviderText from "@/components/chat_search/UnconfiguredProviderText"; import { useAssistants } from "@/components/context/AssistantsContext"; +import AnimatedToggle from "@/components/search/SearchBar"; +import { Popup } from "@/components/admin/connectors/Popup"; +import { AssistantsTab } from "../modal/configuration/AssistantsTab"; +import { IconType } from "react-icons"; +import { LlmTab } from "../modal/configuration/LlmTab"; +import { XIcon } from "lucide-react"; const MAX_INPUT_HEIGHT = 200; export function ChatInputBar({ + removeFilters, + removeDocs, openModelSettings, showDocs, showConfigureAPIKey, @@ -68,7 +67,10 @@ export function ChatInputBar({ alternativeAssistant, chatSessionId, inputPrompts, + toggleFilters, }: { + removeFilters: () => void; + removeDocs: () => void; showConfigureAPIKey: () => void; openModelSettings: () => void; chatState: ChatState; @@ -90,6 +92,7 @@ export function ChatInputBar({ handleFileUpload: (files: File[]) => void; textAreaRef: React.RefObject; chatSessionId?: string; + toggleFilters?: () => void; }) { useEffect(() => { const textarea = textAreaRef.current; @@ -370,9 +373,9 @@ export function ChatInputBar({
)} -
+ {/*
-
+
*/} @@ -429,16 +432,21 @@ export function ChatInputBar({ )} {(selectedDocuments.length > 0 || files.length > 0) && (
-
+
{selectedDocuments.length > 0 && ( )} {files.map((file) => ( @@ -529,72 +537,6 @@ export function ChatInputBar({ suppressContentEditableWarning={true} />
- ( - { - setSelectedAssistant(assistant); - close(); - }} - /> - )} - flexPriority="shrink" - position="top" - mobilePosition="top-right" - > - - - ( - - )} - position="top" - > - - - + {toggleFilters && ( + + )}
diff --git a/web/src/app/chat/lib.tsx b/web/src/app/chat/lib.tsx index 00529776407..b5264ba1c54 100644 --- a/web/src/app/chat/lib.tsx +++ b/web/src/app/chat/lib.tsx @@ -2,6 +2,7 @@ import { AnswerPiecePacket, DanswerDocument, Filters, + FinalContextDocs, StreamStopInfo, } from "@/lib/search/interfaces"; import { handleSSEStream } from "@/lib/search/streamingUtils"; @@ -102,6 +103,7 @@ export type PacketType = | ToolCallMetadata | BackendMessage | AnswerPiecePacket + | FinalContextDocs | DocumentsResponse | FileChatDisplay | StreamingError @@ -147,7 +149,6 @@ export async function* sendMessage({ }): AsyncGenerator { const documentsAreSelected = selectedDocumentIds && selectedDocumentIds.length > 0; - const body = JSON.stringify({ alternate_assistant_id: alternateAssistantId, chat_session_id: chatSessionId, @@ -639,6 +640,7 @@ export async function useScrollonStream({ endDivRef, debounceNumber, mobile, + enableAutoScroll, }: { chatState: ChatState; scrollableDivRef: RefObject; @@ -647,6 +649,7 @@ export async function useScrollonStream({ endDivRef: RefObject; debounceNumber: number; mobile?: boolean; + enableAutoScroll?: boolean; }) { const mobileDistance = 900; // distance that should "engage" the scroll const desktopDistance = 500; // distance that should "engage" the scroll @@ -659,6 +662,10 @@ export async function useScrollonStream({ const previousScroll = useRef(0); useEffect(() => { + if (!enableAutoScroll) { + return; + } + if (chatState != "input" && scrollableDivRef && scrollableDivRef.current) { const newHeight: number = scrollableDivRef.current?.scrollTop!; const heightDifference = newHeight - previousScroll.current; @@ -716,7 +723,7 @@ export async function useScrollonStream({ // scroll on end of stream if within distance useEffect(() => { - if (scrollableDivRef?.current && chatState == "input") { + if (scrollableDivRef?.current && chatState == "input" && enableAutoScroll) { if (scrollDist.current < distance - 50) { scrollableDivRef?.current?.scrollBy({ left: 0, diff --git a/web/src/app/chat/message/MemoizedTextComponents.tsx b/web/src/app/chat/message/MemoizedTextComponents.tsx index 9ab0e28e3ca..7c8144e8ced 100644 --- a/web/src/app/chat/message/MemoizedTextComponents.tsx +++ b/web/src/app/chat/message/MemoizedTextComponents.tsx @@ -1,8 +1,50 @@ import { Citation } from "@/components/search/results/Citation"; +import { WebResultIcon } from "@/components/WebResultIcon"; +import { LoadedDanswerDocument } from "@/lib/search/interfaces"; +import { getSourceMetadata } from "@/lib/sources"; +import { ValidSources } from "@/lib/types"; import React, { memo } from "react"; +import isEqual from "lodash/isEqual"; + +export const MemoizedAnchor = memo(({ docs, children }: any) => { + console.log(children); + const value = children?.toString(); + if (value?.startsWith("[") && value?.endsWith("]")) { + const match = value.match(/\[(\d+)\]/); + if (match) { + const index = parseInt(match[1], 10) - 1; + const associatedDoc = docs && docs[index]; + + const url = associatedDoc?.link + ? new URL(associatedDoc.link).origin + "/favicon.ico" + : ""; + + const getIcon = (sourceType: ValidSources, link: string) => { + return getSourceMetadata(sourceType).icon({ size: 18 }); + }; + + const icon = + associatedDoc?.source_type === "web" ? ( + + ) : ( + getIcon( + associatedDoc?.source_type || "web", + associatedDoc?.link || "" + ) + ); + + return ( + + {children} + + ); + } + } + return {children}; +}); export const MemoizedLink = memo((props: any) => { - const { node, ...rest } = props; + const { node, document, ...rest } = props; const value = rest.children; if (value?.toString().startsWith("*")) { @@ -10,7 +52,16 @@ export const MemoizedLink = memo((props: any) => {
); } else if (value?.toString().startsWith("[")) { - return {rest.children}; + return ( + + {rest.children} + + ); } else { return ( { } }); -export const MemoizedParagraph = memo(({ ...props }: any) => { - return

; -}); +export const MemoizedParagraph = memo( + function MemoizedParagraph({ children }: any) { + return

{children}

; + }, + (prevProps, nextProps) => { + const areEqual = isEqual(prevProps.children, nextProps.children); + return areEqual; + } +); +MemoizedAnchor.displayName = "MemoizedAnchor"; MemoizedLink.displayName = "MemoizedLink"; MemoizedParagraph.displayName = "MemoizedParagraph"; diff --git a/web/src/app/chat/message/Messages.tsx b/web/src/app/chat/message/Messages.tsx index cc4f9c9cac8..0aa9ba82683 100644 --- a/web/src/app/chat/message/Messages.tsx +++ b/web/src/app/chat/message/Messages.tsx @@ -8,14 +8,22 @@ import { FiGlobe, } from "react-icons/fi"; import { FeedbackType } from "../types"; -import React, { useContext, useEffect, useMemo, useRef, useState } from "react"; +import React, { + memo, + useCallback, + useContext, + useEffect, + useMemo, + useRef, + useState, +} from "react"; import ReactMarkdown from "react-markdown"; import { DanswerDocument, FilteredDanswerDocument, } from "@/lib/search/interfaces"; import { SearchSummary } from "./SearchSummary"; -import { SourceIcon } from "@/components/SourceIcon"; + import { SkippedSearch } from "./SkippedSearch"; import remarkGfm from "remark-gfm"; import { CopyButton } from "@/components/CopyButton"; @@ -36,8 +44,6 @@ import "prismjs/themes/prism-tomorrow.css"; import "./custom-code-styles.css"; import { Persona } from "@/app/admin/assistants/interfaces"; import { AssistantIcon } from "@/components/assistants/AssistantIcon"; -import { Citation } from "@/components/search/results/Citation"; -import { DocumentMetadataBlock } from "@/components/search/DocumentDisplay"; import { LikeFeedback, DislikeFeedback } from "@/components/icons/icons"; import { @@ -52,16 +58,18 @@ import { TooltipTrigger, } from "@/components/ui/tooltip"; import { useMouseTracking } from "./hooks"; -import { InternetSearchIcon } from "@/components/InternetSearchIcon"; import { SettingsContext } from "@/components/settings/SettingsProvider"; import GeneratingImageDisplay from "../tools/GeneratingImageDisplay"; import RegenerateOption from "../RegenerateOption"; import { LlmOverride } from "@/lib/hooks"; import { ContinueGenerating } from "./ContinueMessage"; -import { MemoizedLink, MemoizedParagraph } from "./MemoizedTextComponents"; +import { MemoizedAnchor, MemoizedParagraph } from "./MemoizedTextComponents"; import { extractCodeText } from "./codeUtils"; import ToolResult from "../../../components/tools/ToolResult"; import CsvContent from "../../../components/tools/CSVContent"; +import SourceCard, { + SeeMoreBlock, +} from "@/components/chat_search/sources/SourceCard"; const TOOLS_WITH_CUSTOM_HANDLING = [ SEARCH_TOOL_NAME, @@ -155,6 +163,7 @@ function FileDisplay({ export const AIMessage = ({ regenerate, overriddenModel, + selectedMessageForDocDisplay, continueGenerating, shared, isActive, @@ -162,6 +171,7 @@ export const AIMessage = ({ alternativeAssistant, docs, messageId, + documentSelectionToggled, content, files, selectedDocuments, @@ -178,7 +188,10 @@ export const AIMessage = ({ currentPersona, otherMessagesCanSwitchTo, onMessageSelection, + index, }: { + index?: number; + selectedMessageForDocDisplay?: number | null; shared?: boolean; isActive?: boolean; continueGenerating?: () => void; @@ -191,6 +204,7 @@ export const AIMessage = ({ currentPersona: Persona; messageId: number | null; content: string | JSX.Element; + documentSelectionToggled?: boolean; files?: FileDescriptor[]; query?: string; citedDocuments?: [string, DanswerDocument][] | null; @@ -287,18 +301,31 @@ export const AIMessage = ({ }); } + const paragraphCallback = useCallback( + (props: any) => {props.children}, + [] + ); + + const anchorCallback = useCallback( + (props: any) => ( + {props.children} + ), + [docs] + ); + const currentMessageInd = messageId ? otherMessagesCanSwitchTo?.indexOf(messageId) : undefined; + const uniqueSources: ValidSources[] = Array.from( new Set((docs || []).map((doc) => doc.source_type)) ).slice(0, 3); const markdownComponents = useMemo( () => ({ - a: MemoizedLink, - p: MemoizedParagraph, - code: ({ node, className, children, ...props }: any) => { + a: anchorCallback, + p: paragraphCallback, + code: ({ node, className, children }: any) => { const codeText = extractCodeText( node, finalContent as string, @@ -312,7 +339,7 @@ export const AIMessage = ({ ); }, }), - [finalContent] + [anchorCallback, paragraphCallback, finalContent] ); const renderedMarkdown = useMemo(() => { @@ -333,12 +360,11 @@ export const AIMessage = ({ onMessageSelection && otherMessagesCanSwitchTo && otherMessagesCanSwitchTo.length > 1; - return (
)} + {docs && docs.length > 0 && ( +
+
+
+ {!settings?.isMobile && + docs.length > 0 && + docs + .slice(0, 2) + .map((doc, ind) => ( + + ))} + +
+
+
+ )} + {content || files ? ( <> @@ -438,81 +490,6 @@ export const AIMessage = ({ ) : isComplete ? null : ( <> )} - {isComplete && docs && docs.length > 0 && ( -
-
-
- {!settings?.isMobile && - filteredDocs.length > 0 && - filteredDocs.slice(0, 2).map((doc, ind) => ( - - ))} -
{ - if (messageId) { - onMessageSelection?.(messageId); - } - toggleDocumentSelection?.(); - }} - key={-1} - className="cursor-pointer w-[200px] rounded-lg flex-none transition-all duration-500 hover:bg-background-125 bg-text-100 px-4 py-2 border-b" - > -
-

See context

-
- {uniqueSources.map((sourceType, ind) => { - return ( -
- -
- ); - })} -
-
-
- See more -
-
-
-
-
- )}
{handleFeedback && diff --git a/web/src/app/chat/message/SearchSummary.tsx b/web/src/app/chat/message/SearchSummary.tsx index f86212fd290..7349ec6ca35 100644 --- a/web/src/app/chat/message/SearchSummary.tsx +++ b/web/src/app/chat/message/SearchSummary.tsx @@ -41,6 +41,7 @@ export function ShowHideDocsButton({ } export function SearchSummary({ + index, query, hasDocs, finished, @@ -48,6 +49,7 @@ export function SearchSummary({ handleShowRetrieved, handleSearchQueryEdit, }: { + index: number; finished: boolean; query: string; hasDocs: boolean; @@ -98,7 +100,14 @@ export function SearchSummary({ !text-sm !line-clamp-1 !break-all px-0.5`} ref={searchingForRef} > - {finished ? "Searched" : "Searching"} for: {finalQuery} + {finished ? "Searched" : "Searching"} for:{" "} + + {index === 1 + ? finalQuery.length > 50 + ? `${finalQuery.slice(0, 50)}...` + : finalQuery + : finalQuery} +
); diff --git a/web/src/app/chat/modal/FeedbackModal.tsx b/web/src/app/chat/modal/FeedbackModal.tsx index 886a761acb2..e050dcc62af 100644 --- a/web/src/app/chat/modal/FeedbackModal.tsx +++ b/web/src/app/chat/modal/FeedbackModal.tsx @@ -53,7 +53,7 @@ export const FeedbackModal = ({ : predefinedNegativeFeedbackOptions; return ( - + <>

diff --git a/web/src/app/chat/modal/SetDefaultModelModal.tsx b/web/src/app/chat/modal/SetDefaultModelModal.tsx index 27696c46916..22e7f60adfe 100644 --- a/web/src/app/chat/modal/SetDefaultModelModal.tsx +++ b/web/src/app/chat/modal/SetDefaultModelModal.tsx @@ -1,4 +1,4 @@ -import { Dispatch, SetStateAction, useEffect, useRef } from "react"; +import { Dispatch, SetStateAction, useContext, useEffect, useRef } from "react"; import { Modal } from "@/components/Modal"; import Text from "@/components/ui/text"; import { getDisplayNameForModel, LlmOverride } from "@/lib/hooks"; @@ -9,6 +9,10 @@ import { setUserDefaultModel } from "@/lib/users/UserSettings"; import { useRouter } from "next/navigation"; import { PopupSpec } from "@/components/admin/connectors/Popup"; import { useUser } from "@/components/user/UserProvider"; +import { Separator } from "@/components/ui/separator"; +import { Switch } from "@/components/ui/switch"; +import { Label } from "@/components/admin/connectors/Field"; +import { SettingsContext } from "@/components/settings/SettingsProvider"; export function SetDefaultModelModal({ setPopup, @@ -23,7 +27,7 @@ export function SetDefaultModelModal({ onClose: () => void; defaultModel: string | null; }) { - const { refreshUser } = useUser(); + const { refreshUser, user, updateUserAutoScroll } = useUser(); const containerRef = useRef(null); const messageRef = useRef(null); @@ -121,16 +125,41 @@ export function SetDefaultModelModal({ const defaultProvider = llmProviders.find( (llmProvider) => llmProvider.is_default_provider ); + const settings = useContext(SettingsContext); + const autoScroll = settings?.enterpriseSettings?.auto_scroll; + + const checked = + user?.preferences?.auto_scroll === null + ? autoScroll + : user?.preferences?.auto_scroll; return ( <>

- Set Default Model + User settings

+
+
+ { + updateUserAutoScroll(checked); + }} + /> + +
+
+ + + +

+ Default model for assistants +

+ Choose a Large Language Model (LLM) to serve as the default for assistants that don't have a default model assigned. diff --git a/web/src/app/chat/page.tsx b/web/src/app/chat/page.tsx index 7894ce651fc..274f362d4f4 100644 --- a/web/src/app/chat/page.tsx +++ b/web/src/app/chat/page.tsx @@ -32,6 +32,7 @@ export default async function Page(props: { defaultAssistantId, shouldShowWelcomeModal, userInputPrompts, + ccPairs, } = data; return ( @@ -44,6 +45,9 @@ export default async function Page(props: { value={{ chatSessions, availableSources, + ccPairs, + documentSets, + tags, availableDocumentSets: documentSets, availableTags: tags, llmProviders, diff --git a/web/src/app/chat/sessionSidebar/HistorySidebar.tsx b/web/src/app/chat/sessionSidebar/HistorySidebar.tsx index d2ebf1c07e2..88002ee005a 100644 --- a/web/src/app/chat/sessionSidebar/HistorySidebar.tsx +++ b/web/src/app/chat/sessionSidebar/HistorySidebar.tsx @@ -113,7 +113,7 @@ export const HistorySidebar = forwardRef( {page == "chat" && (
( +
{children}
+); + +export interface SourceSelectorProps { + timeRange: DateRangePickerValue | null; + setTimeRange: React.Dispatch< + React.SetStateAction + >; + showDocSidebar?: boolean; + selectedSources: SourceMetadata[]; + setSelectedSources: React.Dispatch>; + selectedDocumentSets: string[]; + setSelectedDocumentSets: React.Dispatch>; + selectedTags: Tag[]; + setSelectedTags: React.Dispatch>; + availableDocumentSets: DocumentSet[]; + existingSources: ValidSources[]; + availableTags: Tag[]; + toggleFilters: () => void; + filtersUntoggled: boolean; + tagsOnLeft: boolean; +} + +export function SourceSelector({ + timeRange, + setTimeRange, + selectedSources, + setSelectedSources, + selectedDocumentSets, + setSelectedDocumentSets, + selectedTags, + setSelectedTags, + availableDocumentSets, + existingSources, + availableTags, + showDocSidebar, + toggleFilters, + filtersUntoggled, + tagsOnLeft, +}: SourceSelectorProps) { + const handleSelect = (source: SourceMetadata) => { + setSelectedSources((prev: SourceMetadata[]) => { + if ( + prev.map((source) => source.internalName).includes(source.internalName) + ) { + return prev.filter((s) => s.internalName !== source.internalName); + } else { + return [...prev, source]; + } + }); + }; + + const handleDocumentSetSelect = (documentSetName: string) => { + setSelectedDocumentSets((prev: string[]) => { + if (prev.includes(documentSetName)) { + return prev.filter((s) => s !== documentSetName); + } else { + return [...prev, documentSetName]; + } + }); + }; + + let allSourcesSelected = selectedSources.length > 0; + + const toggleAllSources = () => { + if (allSourcesSelected) { + setSelectedSources([]); + } else { + const allSources = listSourceMetadata().filter((source) => + existingSources.includes(source.internalName) + ); + setSelectedSources(allSources); + } + }; + + return ( +
+ + {!filtersUntoggled && ( + <> + + + +
+
+ Time Range + {true && ( + + )} +
+

+ {getTimeAgoString(timeRange?.from!) || "Select a time range"} +

+
+
+ + { + const initialDate = daterange?.from || new Date(); + const endDate = daterange?.to || new Date(); + setTimeRange({ + from: initialDate, + to: endDate, + selectValue: timeRange?.selectValue || "", + }); + }} + className="rounded-md " + /> + +
+ + {availableTags.length > 0 && ( + <> +
+ Tags +
+ + + )} + + {existingSources.length > 0 && ( +
+
+
+

Sources

+ +
+
+
+ {listSourceMetadata() + .filter((source) => + existingSources.includes(source.internalName) + ) + .map((source) => ( +
source.internalName) + .includes(source.internalName) + ? "bg-hover" + : "hover:bg-hover-light") + } + onClick={() => handleSelect(source)} + > + + + {source.displayName} + +
+ ))} +
+
+ )} + + {availableDocumentSets.length > 0 && ( + <> +
+ Knowledge Sets +
+
+ {availableDocumentSets.map((documentSet) => ( +
+
handleDocumentSetSelect(documentSet.name)} + > + + +
+ } + popupContent={ +
+
Description
+
+ {documentSet.description} +
+
+ } + classNameModifications="-ml-2" + /> + {documentSet.name} +
+
+ ))} +
+ + )} + + )} +
+ ); +} + +export function SelectedBubble({ + children, + onClick, +}: { + children: string | JSX.Element; + onClick: () => void; +}) { + return ( +
+ {children} + +
+ ); +} + +export function HorizontalFilters({ + timeRange, + setTimeRange, + selectedSources, + setSelectedSources, + selectedDocumentSets, + setSelectedDocumentSets, + availableDocumentSets, + existingSources, +}: SourceSelectorProps) { + const handleSourceSelect = (source: SourceMetadata) => { + setSelectedSources((prev: SourceMetadata[]) => { + const prevSourceNames = prev.map((source) => source.internalName); + if (prevSourceNames.includes(source.internalName)) { + return prev.filter((s) => s.internalName !== source.internalName); + } else { + return [...prev, source]; + } + }); + }; + + const handleDocumentSetSelect = (documentSetName: string) => { + setSelectedDocumentSets((prev: string[]) => { + if (prev.includes(documentSetName)) { + return prev.filter((s) => s !== documentSetName); + } else { + return [...prev, documentSetName]; + } + }); + }; + + const allSources = listSourceMetadata(); + const availableSources = allSources.filter((source) => + existingSources.includes(source.internalName) + ); + + return ( +
+
+
+ +
+ + { + return { + key: source.displayName, + display: ( + <> + + {source.displayName} + + ), + }; + })} + selected={selectedSources.map((source) => source.displayName)} + handleSelect={(option) => + handleSourceSelect( + allSources.find((source) => source.displayName === option.key)! + ) + } + icon={ +
+ +
+ } + defaultDisplay="All Sources" + /> + + { + return { + key: documentSet.name, + display: ( + <> +
+ +
+ {documentSet.name} + + ), + }; + })} + selected={selectedDocumentSets} + handleSelect={(option) => handleDocumentSetSelect(option.key)} + icon={ +
+ +
+ } + defaultDisplay="All Document Sets" + /> +
+ +
+
+ {timeRange && timeRange.selectValue && ( + setTimeRange(null)}> +
{timeRange.selectValue}
+
+ )} + {existingSources.length > 0 && + selectedSources.map((source) => ( + handleSourceSelect(source)} + > + <> + + {source.displayName} + + + ))} + {selectedDocumentSets.length > 0 && + selectedDocumentSets.map((documentSetName) => ( + handleDocumentSetSelect(documentSetName)} + > + <> +
+ +
+ {documentSetName} + +
+ ))} +
+
+
+ ); +} + +export function HorizontalSourceSelector({ + timeRange, + setTimeRange, + selectedSources, + setSelectedSources, + selectedDocumentSets, + setSelectedDocumentSets, + selectedTags, + setSelectedTags, + availableDocumentSets, + existingSources, + availableTags, +}: SourceSelectorProps) { + const handleSourceSelect = (source: SourceMetadata) => { + setSelectedSources((prev: SourceMetadata[]) => { + if (prev.map((s) => s.internalName).includes(source.internalName)) { + return prev.filter((s) => s.internalName !== source.internalName); + } else { + return [...prev, source]; + } + }); + }; + + const handleDocumentSetSelect = (documentSetName: string) => { + setSelectedDocumentSets((prev: string[]) => { + if (prev.includes(documentSetName)) { + return prev.filter((s) => s !== documentSetName); + } else { + return [...prev, documentSetName]; + } + }); + }; + + const handleTagSelect = (tag: Tag) => { + setSelectedTags((prev: Tag[]) => { + if ( + prev.some( + (t) => t.tag_key === tag.tag_key && t.tag_value === tag.tag_value + ) + ) { + return prev.filter( + (t) => !(t.tag_key === tag.tag_key && t.tag_value === tag.tag_value) + ); + } else { + return [...prev, tag]; + } + }); + }; + + const resetSources = () => { + setSelectedSources([]); + }; + const resetDocuments = () => { + setSelectedDocumentSets([]); + }; + + const resetTags = () => { + setSelectedTags([]); + }; + + return ( +
+ + +
+ + + {timeRange?.from ? getTimeAgoString(timeRange.from) : "Since"} +
+
+ + { + const initialDate = daterange?.from || new Date(); + const endDate = daterange?.to || new Date(); + setTimeRange({ + from: initialDate, + to: endDate, + selectValue: timeRange?.selectValue || "", + }); + }} + className="rounded-md" + /> + +
+ + {existingSources.length > 0 && ( + existingSources.includes(source.internalName)) + .map((source) => ({ + key: source.internalName, + display: ( + <> + + {source.displayName} + + ), + }))} + selected={selectedSources.map((source) => source.internalName)} + handleSelect={(option) => + handleSourceSelect( + listSourceMetadata().find((s) => s.internalName === option.key)! + ) + } + icon={} + defaultDisplay="Sources" + dropdownColor="bg-background-search-filter-dropdown" + width="w-fit ellipsis truncate" + resetValues={resetSources} + dropdownWidth="w-40" + optionClassName="truncate w-full break-all ellipsis" + /> + )} + + {availableDocumentSets.length > 0 && ( + ({ + key: documentSet.name, + display: <>{documentSet.name}, + }))} + selected={selectedDocumentSets} + handleSelect={(option) => handleDocumentSetSelect(option.key)} + icon={} + defaultDisplay="Sets" + resetValues={resetDocuments} + width="w-fit max-w-24 text-ellipsis truncate" + dropdownColor="bg-background-search-filter-dropdown" + dropdownWidth="max-w-36 w-fit" + optionClassName="truncate w-full break-all" + /> + )} + + {availableTags.length > 0 && ( + ({ + key: `${tag.tag_key}=${tag.tag_value}`, + display: ( + + {tag.tag_key} + = + {tag.tag_value} + + ), + }))} + selected={selectedTags.map( + (tag) => `${tag.tag_key}=${tag.tag_value}` + )} + handleSelect={(option) => { + const [tag_key, tag_value] = option.key.split("="); + const selectedTag = availableTags.find( + (tag) => tag.tag_key === tag_key && tag.tag_value === tag_value + ); + if (selectedTag) { + handleTagSelect(selectedTag); + } + }} + icon={} + defaultDisplay="Tags" + resetValues={resetTags} + dropdownColor="bg-background-search-filter-dropdown" + width="w-fit max-w-24 ellipsis truncate" + dropdownWidth="max-w-80 w-fit" + optionClassName="truncate w-full break-all ellipsis" + /> + )} +
+ ); +} diff --git a/web/src/app/chat/shared_chat_search/FixedLogo.tsx b/web/src/app/chat/shared_chat_search/FixedLogo.tsx index 5bf48c27794..c7e7a7d2afc 100644 --- a/web/src/app/chat/shared_chat_search/FixedLogo.tsx +++ b/web/src/app/chat/shared_chat_search/FixedLogo.tsx @@ -21,9 +21,7 @@ export default function FixedLogo({ return ( <>
@@ -49,7 +47,7 @@ export default function FixedLogo({
- + {/* */}
); diff --git a/web/src/app/chat/shared_chat_search/FunctionalWrapper.tsx b/web/src/app/chat/shared_chat_search/FunctionalWrapper.tsx index e8c377dc57f..8a58c639136 100644 --- a/web/src/app/chat/shared_chat_search/FunctionalWrapper.tsx +++ b/web/src/app/chat/shared_chat_search/FunctionalWrapper.tsx @@ -1,90 +1,7 @@ "use client"; -import React, { ReactNode, useContext, useEffect, useState } from "react"; -import { usePathname, useRouter } from "next/navigation"; -import { ChatIcon, SearchIcon } from "@/components/icons/icons"; -import { SettingsContext } from "@/components/settings/SettingsProvider"; -import KeyboardSymbol from "@/lib/browserUtilities"; - -const ToggleSwitch = () => { - const commandSymbol = KeyboardSymbol(); - const pathname = usePathname(); - const router = useRouter(); - const settings = useContext(SettingsContext); - - const [activeTab, setActiveTab] = useState(() => { - return pathname == "/search" ? "search" : "chat"; - }); - - const [isInitialLoad, setIsInitialLoad] = useState(true); - - useEffect(() => { - const newTab = pathname === "/search" ? "search" : "chat"; - setActiveTab(newTab); - localStorage.setItem("activeTab", newTab); - setIsInitialLoad(false); - }, [pathname]); - - const handleTabChange = (tab: string) => { - setActiveTab(tab); - localStorage.setItem("activeTab", tab); - if (settings?.isMobile && window) { - window.location.href = tab; - } else { - router.push(tab === "search" ? "/search" : "/chat"); - } - }; - - return ( -
-
- - -
- ); -}; +import React, { ReactNode, useEffect, useState } from "react"; +import { useRouter } from "next/navigation"; export default function FunctionalWrapper({ initiallyToggled, @@ -128,12 +45,6 @@ export default function FunctionalWrapper({ window.removeEventListener("keydown", handleKeyDown); }; }, [router]); - const combinedSettings = useContext(SettingsContext); - const settings = combinedSettings?.settings; - const chatBannerPresent = - combinedSettings?.enterpriseSettings?.custom_header_content; - const twoLines = - combinedSettings?.enterpriseSettings?.two_lines_for_chat_header; const [toggledSidebar, setToggledSidebar] = useState(initiallyToggled); @@ -145,24 +56,7 @@ export default function FunctionalWrapper({ return ( <> - {(!settings || - (settings.search_page_enabled && settings.chat_page_enabled)) && ( -
-
-
- -
-
- )} - + {" "}
{content(toggledSidebar, toggle)}
diff --git a/web/src/app/chat/shared_chat_search/SearchFilters.tsx b/web/src/app/chat/shared_chat_search/SearchFilters.tsx new file mode 100644 index 00000000000..46ceda9a71c --- /dev/null +++ b/web/src/app/chat/shared_chat_search/SearchFilters.tsx @@ -0,0 +1,294 @@ +import { DocumentSet, Tag, ValidSources } from "@/lib/types"; +import { SourceMetadata } from "@/lib/search/interfaces"; +import { InfoIcon, defaultTailwindCSS } from "@/components/icons/icons"; +import { HoverPopup } from "@/components/HoverPopup"; +import { DateRangePickerValue } from "@/app/ee/admin/performance/DateRangeSelector"; +import { SourceIcon } from "@/components/SourceIcon"; +import { Checkbox } from "@/components/ui/checkbox"; +import { TagFilter } from "@/components/search/filtering/TagFilter"; +import { CardContent } from "@/components/ui/card"; +import { useEffect } from "react"; +import { useState } from "react"; +import { listSourceMetadata } from "@/lib/sources"; +import { Calendar } from "@/components/ui/calendar"; +import { getDateRangeString } from "@/lib/dateUtils"; +import { Button } from "@/components/ui/button"; +import { + Popover, + PopoverContent, + PopoverTrigger, +} from "@/components/ui/popover"; +import { ToolTipDetails } from "@/components/admin/connectors/Field"; + +import { + Tooltip, + TooltipContent, + TooltipTrigger, +} from "@/components/ui/tooltip"; +import { TooltipProvider } from "@radix-ui/react-tooltip"; + +const SectionTitle = ({ + children, + modal, +}: { + children: string; + modal?: boolean; +}) => ( +
+

{children}

+
+); + +export interface SourceSelectorProps { + timeRange: DateRangePickerValue | null; + setTimeRange: React.Dispatch< + React.SetStateAction + >; + showDocSidebar?: boolean; + selectedSources: SourceMetadata[]; + setSelectedSources: React.Dispatch>; + selectedDocumentSets: string[]; + setSelectedDocumentSets: React.Dispatch>; + selectedTags: Tag[]; + setSelectedTags: React.Dispatch>; + availableDocumentSets: DocumentSet[]; + existingSources: ValidSources[]; + availableTags: Tag[]; + filtersUntoggled: boolean; + modal?: boolean; + tagsOnLeft: boolean; +} + +export function SourceSelector({ + timeRange, + filtersUntoggled, + setTimeRange, + selectedSources, + setSelectedSources, + selectedDocumentSets, + setSelectedDocumentSets, + selectedTags, + setSelectedTags, + availableDocumentSets, + existingSources, + modal, + availableTags, +}: SourceSelectorProps) { + const handleSelect = (source: SourceMetadata) => { + setSelectedSources((prev: SourceMetadata[]) => { + if ( + prev.map((source) => source.internalName).includes(source.internalName) + ) { + return prev.filter((s) => s.internalName !== source.internalName); + } else { + return [...prev, source]; + } + }); + }; + + const handleDocumentSetSelect = (documentSetName: string) => { + setSelectedDocumentSets((prev: string[]) => { + if (prev.includes(documentSetName)) { + return prev.filter((s) => s !== documentSetName); + } else { + return [...prev, documentSetName]; + } + }); + }; + + let allSourcesSelected = selectedSources.length > 0; + + const toggleAllSources = () => { + if (allSourcesSelected) { + setSelectedSources([]); + } else { + const allSources = listSourceMetadata().filter((source) => + existingSources.includes(source.internalName) + ); + setSelectedSources(allSources); + } + }; + + const [isCalendarOpen, setIsCalendarOpen] = useState(false); + + useEffect(() => { + const handleClickOutside = (event: MouseEvent) => { + const calendar = document.querySelector(".rdp"); + if (calendar && !calendar.contains(event.target as Node)) { + setIsCalendarOpen(false); + } + }; + + document.addEventListener("mousedown", handleClickOutside); + return () => { + document.removeEventListener("mousedown", handleClickOutside); + }; + }, []); + + return ( +
+ {!filtersUntoggled && ( + +
+
+

Time Range

+ {timeRange && ( + + )} +
+ + + + + + { + const today = new Date(); + const initialDate = daterange?.from + ? new Date( + Math.min(daterange.from.getTime(), today.getTime()) + ) + : today; + const endDate = daterange?.to + ? new Date( + Math.min(daterange.to.getTime(), today.getTime()) + ) + : today; + setTimeRange({ + from: initialDate, + to: endDate, + selectValue: timeRange?.selectValue || "", + }); + }} + className="rounded-md" + /> + + +
+ + {availableTags.length > 0 && ( +
+ Tags + +
+ )} + + {existingSources.length > 0 && ( +
+ Sources + +
+ {existingSources.length > 1 && ( +
+ + + +
+ )} + {listSourceMetadata() + .filter((source) => + existingSources.includes(source.internalName) + ) + .map((source) => ( +
handleSelect(source)} + > + s.internalName) + .includes(source.internalName)} + /> + + {source.displayName} +
+ ))} +
+
+ )} + + {availableDocumentSets.length > 0 && ( +
+ Knowledge Sets +
+ {availableDocumentSets.map((documentSet) => ( +
handleDocumentSetSelect(documentSet.name)} + > + + + + + + + +
+
Description
+
+ {documentSet.description} +
+
+
+
+
+ {documentSet.name} +
+ ))} +
+
+ )} +
+ )} +
+ ); +} diff --git a/web/src/app/ee/admin/whitelabeling/WhitelabelingForm.tsx b/web/src/app/ee/admin/whitelabeling/WhitelabelingForm.tsx index 475c689441a..cd977d44c1c 100644 --- a/web/src/app/ee/admin/whitelabeling/WhitelabelingForm.tsx +++ b/web/src/app/ee/admin/whitelabeling/WhitelabelingForm.tsx @@ -55,6 +55,7 @@ export function WhitelabelingForm() {
( - - )} - /> - ); -} diff --git a/web/src/app/search/page.tsx b/web/src/app/search/page.tsx deleted file mode 100644 index 3572d7cbe40..00000000000 --- a/web/src/app/search/page.tsx +++ /dev/null @@ -1,213 +0,0 @@ -import { - AuthTypeMetadata, - getAuthTypeMetadataSS, - getCurrentUserSS, -} from "@/lib/userSS"; -import { redirect } from "next/navigation"; -import { HealthCheckBanner } from "@/components/health/healthcheck"; -import { fetchSS } from "@/lib/utilsSS"; -import { CCPairBasicInfo, DocumentSet, Tag, User } from "@/lib/types"; -import { cookies } from "next/headers"; -import { SearchType } from "@/lib/search/interfaces"; -import { Persona } from "../admin/assistants/interfaces"; -import { unstable_noStore as noStore } from "next/cache"; -import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh"; -import { personaComparator } from "../admin/assistants/lib"; -import { FullEmbeddingModelResponse } from "@/components/embedding/interfaces"; -import { ChatPopup } from "../chat/ChatPopup"; -import { - FetchAssistantsResponse, - fetchAssistantsSS, -} from "@/lib/assistants/fetchAssistantsSS"; -import { ChatSession } from "../chat/interfaces"; -import { SIDEBAR_TOGGLED_COOKIE_NAME } from "@/components/resizable/constants"; -import { - AGENTIC_SEARCH_TYPE_COOKIE_NAME, - NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN, - DISABLE_LLM_DOC_RELEVANCE, -} from "@/lib/constants"; -import WrappedSearch from "./WrappedSearch"; -import { SearchProvider } from "@/components/context/SearchContext"; -import { fetchLLMProvidersSS } from "@/lib/llm/fetchLLMs"; -import { LLMProviderDescriptor } from "../admin/configuration/llm/interfaces"; -import { headers } from "next/headers"; -import { - hasCompletedWelcomeFlowSS, - WelcomeModal, -} from "@/components/initialSetup/welcome/WelcomeModalWrapper"; - -export default async function Home(props: { - searchParams: Promise<{ [key: string]: string | string[] | undefined }>; -}) { - const searchParams = await props.searchParams; - // Disable caching so we always get the up to date connector / document set / persona info - // importantly, this prevents users from adding a connector, going back to the main page, - // and then getting hit with a "No Connectors" popup - noStore(); - const requestCookies = await cookies(); - const tasks = [ - getAuthTypeMetadataSS(), - getCurrentUserSS(), - fetchSS("/manage/indexing-status"), - fetchSS("/manage/document-set"), - fetchAssistantsSS(), - fetchSS("/query/valid-tags"), - fetchSS("/query/user-searches"), - fetchLLMProvidersSS(), - ]; - - // catch cases where the backend is completely unreachable here - // without try / catch, will just raise an exception and the page - // will not render - let results: ( - | User - | Response - | AuthTypeMetadata - | FullEmbeddingModelResponse - | FetchAssistantsResponse - | LLMProviderDescriptor[] - | null - )[] = [null, null, null, null, null, null, null, null]; - try { - results = await Promise.all(tasks); - } catch (e) { - console.log(`Some fetch failed for the main search page - ${e}`); - } - const authTypeMetadata = results[0] as AuthTypeMetadata | null; - const user = results[1] as User | null; - const ccPairsResponse = results[2] as Response | null; - const documentSetsResponse = results[3] as Response | null; - const [initialAssistantsList, assistantsFetchError] = - results[4] as FetchAssistantsResponse; - const tagsResponse = results[5] as Response | null; - const queryResponse = results[6] as Response | null; - const llmProviders = (results[7] || []) as LLMProviderDescriptor[]; - - const authDisabled = authTypeMetadata?.authType === "disabled"; - - if (!authDisabled && !user) { - const headersList = await headers(); - const fullUrl = headersList.get("x-url") || "/search"; - const searchParamsString = new URLSearchParams( - searchParams as unknown as Record - ).toString(); - const redirectUrl = searchParamsString - ? `${fullUrl}?${searchParamsString}` - : fullUrl; - return redirect(`/auth/login?next=${encodeURIComponent(redirectUrl)}`); - } - - if (user && !user.is_verified && authTypeMetadata?.requiresVerification) { - return redirect("/auth/waiting-on-verification"); - } - - let ccPairs: CCPairBasicInfo[] = []; - if (ccPairsResponse?.ok) { - ccPairs = await ccPairsResponse.json(); - } else { - console.log(`Failed to fetch connectors - ${ccPairsResponse?.status}`); - } - - let documentSets: DocumentSet[] = []; - if (documentSetsResponse?.ok) { - documentSets = await documentSetsResponse.json(); - } else { - console.log( - `Failed to fetch document sets - ${documentSetsResponse?.status}` - ); - } - - let querySessions: ChatSession[] = []; - if (queryResponse?.ok) { - querySessions = (await queryResponse.json()).sessions; - } else { - console.log(`Failed to fetch chat sessions - ${queryResponse?.text()}`); - } - - let assistants: Persona[] = initialAssistantsList; - if (assistantsFetchError) { - console.log(`Failed to fetch assistants - ${assistantsFetchError}`); - } else { - // remove those marked as hidden by an admin - assistants = assistants.filter((assistant) => assistant.is_visible); - // hide personas with no retrieval - assistants = assistants.filter((assistant) => assistant.num_chunks !== 0); - // sort them in priority order - assistants.sort(personaComparator); - } - - let tags: Tag[] = []; - if (tagsResponse?.ok) { - tags = (await tagsResponse.json()).tags; - } else { - console.log(`Failed to fetch tags - ${tagsResponse?.status}`); - } - - // needs to be done in a non-client side component due to nextjs - const storedSearchType = requestCookies.get("searchType")?.value as - | string - | undefined; - const searchTypeDefault: SearchType = - storedSearchType !== undefined && - SearchType.hasOwnProperty(storedSearchType) - ? (storedSearchType as SearchType) - : SearchType.SEMANTIC; // default to semantic - - const hasAnyConnectors = ccPairs.length > 0; - - const shouldShowWelcomeModal = - !llmProviders.length && - !hasCompletedWelcomeFlowSS(requestCookies) && - !hasAnyConnectors && - (!user || user.role === "admin"); - - const shouldDisplayNoSourcesModal = - (!user || user.role === "admin") && - ccPairs.length === 0 && - !shouldShowWelcomeModal; - - const sidebarToggled = requestCookies.get(SIDEBAR_TOGGLED_COOKIE_NAME); - const agenticSearchToggle = requestCookies.get( - AGENTIC_SEARCH_TYPE_COOKIE_NAME - ); - - const toggleSidebar = sidebarToggled - ? sidebarToggled.value.toLocaleLowerCase() == "true" || false - : NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN; - - const agenticSearchEnabled = agenticSearchToggle - ? agenticSearchToggle.value.toLocaleLowerCase() == "true" || false - : false; - - return ( - <> - - - {shouldShowWelcomeModal && ( - - )} - {/* ChatPopup is a custom popup that displays a admin-specified message on initial user visit. - Only used in the EE version of the app. */} - - - - - - ); -} diff --git a/web/src/components/InternetSearchIcon.tsx b/web/src/components/InternetSearchIcon.tsx deleted file mode 100644 index e21218da9c5..00000000000 --- a/web/src/components/InternetSearchIcon.tsx +++ /dev/null @@ -1,9 +0,0 @@ -export function InternetSearchIcon({ url }: { url: string }) { - return ( - favicon - ); -} diff --git a/web/src/components/MetadataBadge.tsx b/web/src/components/MetadataBadge.tsx index cfd94d0a879..f06429a92b1 100644 --- a/web/src/components/MetadataBadge.tsx +++ b/web/src/components/MetadataBadge.tsx @@ -1,9 +1,11 @@ export function MetadataBadge({ icon, value, + flexNone, }: { icon?: React.FC<{ size?: number; className?: string }>; value: string | JSX.Element; + flexNone?: boolean; }) { return (
- {icon && icon({ size: 12, className: "mr-0.5 my-auto" })} + {icon && + icon({ + size: 12, + className: flexNone ? "flex-none" : "mr-0.5 my-auto", + })}
{value}
); diff --git a/web/src/components/Modal.tsx b/web/src/components/Modal.tsx index 05886975088..7175d46cbb1 100644 --- a/web/src/components/Modal.tsx +++ b/web/src/components/Modal.tsx @@ -1,11 +1,11 @@ "use client"; import { Separator } from "@/components/ui/separator"; -import { FiX } from "react-icons/fi"; import { IconProps, XIcon } from "./icons/icons"; import { useRef } from "react"; import { isEventWithinRef } from "@/lib/contains"; import ReactDOM from "react-dom"; import { useEffect, useState } from "react"; +import { cn } from "@/lib/utils"; interface ModalProps { icon?: ({ size, className }: IconProps) => JSX.Element; @@ -18,6 +18,8 @@ interface ModalProps { hideDividerForTitle?: boolean; hideCloseButton?: boolean; noPadding?: boolean; + height?: string; + noScroll?: boolean; } export function Modal({ @@ -28,9 +30,11 @@ export function Modal({ width, titleSize, hideDividerForTitle, + height, noPadding, icon, hideCloseButton, + noScroll, }: ModalProps) { const modalRef = useRef(null); const [isMounted, setIsMounted] = useState(false); @@ -56,8 +60,10 @@ export function Modal({ const modalContent = (
)} - -
+
{title && ( <>
@@ -110,7 +115,14 @@ export function Modal({ {!hideDividerForTitle && } )} -
{children}
+
+ {children} +
diff --git a/web/src/components/SearchResultIcon.tsx b/web/src/components/SearchResultIcon.tsx new file mode 100644 index 00000000000..28aee05783f --- /dev/null +++ b/web/src/components/SearchResultIcon.tsx @@ -0,0 +1,65 @@ +import { useState, useEffect } from "react"; +import faviconFetch from "favicon-fetch"; +import { SourceIcon } from "./SourceIcon"; + +const CACHE_DURATION = 24 * 60 * 60 * 1000; + +export async function getFaviconUrl(url: string): Promise { + const getCachedFavicon = () => { + const cachedData = localStorage.getItem(`favicon_${url}`); + if (cachedData) { + const { favicon, timestamp } = JSON.parse(cachedData); + if (Date.now() - timestamp < CACHE_DURATION) { + return favicon; + } + } + return null; + }; + + const cachedFavicon = getCachedFavicon(); + if (cachedFavicon) { + return cachedFavicon; + } + + const newFaviconUrl = await faviconFetch({ uri: url }); + if (newFaviconUrl) { + localStorage.setItem( + `favicon_${url}`, + JSON.stringify({ favicon: newFaviconUrl, timestamp: Date.now() }) + ); + return newFaviconUrl; + } + + return null; +} + +export function SearchResultIcon({ url }: { url: string }) { + const [faviconUrl, setFaviconUrl] = useState(null); + + useEffect(() => { + getFaviconUrl(url).then((favicon) => { + if (favicon) { + setFaviconUrl(favicon); + } + }); + }, [url]); + + if (!faviconUrl) { + return ; + } + + return ( +
+ favicon { + e.currentTarget.onerror = null; + }} + /> +
+ ); +} diff --git a/web/src/components/UserDropdown.tsx b/web/src/components/UserDropdown.tsx index 8a1503410c4..cc43980b8c2 100644 --- a/web/src/components/UserDropdown.tsx +++ b/web/src/components/UserDropdown.tsx @@ -9,7 +9,7 @@ import { checkUserIsNoAuthUser, logout } from "@/lib/user"; import { Popover } from "./popover/Popover"; import { LOGOUT_DISABLED } from "@/lib/constants"; import { SettingsContext } from "./settings/SettingsProvider"; -import { BellIcon, LightSettingsIcon } from "./icons/icons"; +import { BellIcon, LightSettingsIcon, UserIcon } from "./icons/icons"; import { pageType } from "@/app/chat/sessionSidebar/types"; import { NavigationItem, Notification } from "@/app/admin/settings/interfaces"; import DynamicFaIcon, { preloadIcons } from "./icons/DynamicFaIcon"; @@ -56,7 +56,13 @@ const DropdownOption: React.FC = ({ } }; -export function UserDropdown({ page }: { page?: pageType }) { +export function UserDropdown({ + page, + toggleUserSettings, +}: { + page?: pageType; + toggleUserSettings?: () => void; +}) { const { user, isCurator } = useUser(); const [userInfoVisible, setUserInfoVisible] = useState(false); const userInfoRef = useRef(null); @@ -238,6 +244,13 @@ export function UserDropdown({ page }: { page?: pageType }) { ) )} + {toggleUserSettings && ( + } + label="User Settings" + /> + )} { setUserInfoVisible(true); diff --git a/web/src/components/WebResultIcon.tsx b/web/src/components/WebResultIcon.tsx new file mode 100644 index 00000000000..27e5e91ee4e --- /dev/null +++ b/web/src/components/WebResultIcon.tsx @@ -0,0 +1,16 @@ +import { SourceIcon } from "./SourceIcon"; + +export function WebResultIcon({ url }: { url: string }) { + const hostname = new URL(url).hostname; + return hostname == "https://docs.danswer.dev" ? ( + favicon + ) : ( + + ); +} diff --git a/web/src/components/admin/connectors/AdminSidebar.tsx b/web/src/components/admin/connectors/AdminSidebar.tsx index 26f0694a6a1..be1ee3de933 100644 --- a/web/src/components/admin/connectors/AdminSidebar.tsx +++ b/web/src/components/admin/connectors/AdminSidebar.tsx @@ -40,14 +40,7 @@ export function AdminSidebar({ collections }: { collections: Collection[] }) {