From f9da3978823f38f4cfd4146ab3a2d4d3b7f93a9f Mon Sep 17 00:00:00 2001 From: Pete Miller Date: Tue, 10 Dec 2024 16:47:21 -0800 Subject: [PATCH] Uplift AI Chat Full Page Storage feature MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - [AIChat conversation data storage by petemill · Pull Request \#25876 · brave/brave-core](https://github.com/brave/brave-core/pull/25876) - [AI Chat conversations should persist browser restart · Issue \#42800 · brave/brave-browser](https://github.com/brave/brave-browser/issues/42800) `QA/Yes` - [\[AIChat\]: Add URL based routing for different chats by fallaciousreasoning · Pull Request \#26050 · brave/brave-core](https://github.com/brave/brave-core/pull/26050) - [\[AI Chat\]: URLs for all conversations · Issue \#42055 · brave/brave-browser](https://github.com/brave/brave-browser/issues/42055) `QA/Yes` - test plan: url and header variations specified - [AI Chat fullpage UI notices and polish by petemill · Pull Request \#26678 · brave/brave-core](https://github.com/brave/brave-core/pull/26678) - [AI Chat conversation list should prompt to enable history storage if it's disabled · Issue \#42576 · brave/brave-browser](https://github.com/brave/brave-browser/issues/42576) `QA/Yes` - test plan: specified - [AI Chat notice for conversation storage · Issue \#42360 · brave/brave-browser](https://github.com/brave/brave-browser/issues/42360) `QA/Yes` - test plan: specified - [AI Chat FullPage shouldn't show in SidePanel mode when restored at startup · Issue \#42413 · brave/brave-browser](https://github.com/brave/brave-browser/issues/42413) `QA/Yes` - test plan: specified - [AI Chat becomes a trusted WebUI with an UntrustedWebUI frame for LLM-generated responses by petemill · Pull Request \#26855 · brave/brave-core](https://github.com/brave/brave-core/pull/26855) - [Change AI Chat url to chrome://leo-ai · Issue \#42817 · brave/brave-browser](https://github.com/brave/brave-browser/issues/42817) `QA/Yes` - [AI Chat conversation entries should be isolated in an untrusted frame · Issue \#42818 · brave/brave-browser](https://github.com/brave/brave-browser/issues/42818) `QA/No` - [\[AI Chat\]: Update copy button label by fallaciousreasoning · Pull Request \#26422 · brave/brave-core](https://github.com/brave/brave-core/pull/26422) - [Change copy button name to text if text in code block · Issue \#42117 · brave/brave-browser](https://github.com/brave/brave-browser/issues/42117) `QA/Yes` - [\[AI Chat\]: Conversation starter pack, for unassociated content by fallaciousreasoning · Pull Request \#26379 · brave/brave-core](https://github.com/brave/brave-core/pull/26379) - [\[AI Chat\]: Add support for static conversation starters · Issue \#42106 · brave/brave-browser](https://github.com/brave/brave-browser/issues/42106) `QA/Yes` - test plan: needs improvement - [Css tweaks and icon change for sidebar by aguscruiz · Pull Request \#26450 · brave/brave-core](https://github.com/brave/brave-core/pull/26450) - [\[Leo full page\] - Tweak to expand/collapse sidebar icons · Issue \#42068 · brave/brave-browser](https://github.com/brave/brave-browser/issues/42068) `QA/Yes` - test plan: specified (check icons are correct) - [\[AI Chat\]: Update styling on suggestions by fallaciousreasoning · Pull Request \#26565 · brave/brave-core](https://github.com/brave/brave-core/pull/26565) - [\[AI Chat\]: Update suggestions style to match new design · Issue \#42107 · brave/brave-browser](https://github.com/brave/brave-browser/issues/42107) `QA/Yes` - test plan: specified (check suggestions have new style) - [\[AI Chat\]: Don't show starter suggestions on non-empty chats by fallaciousreasoning · Pull Request \#26677 · brave/brave-core](https://github.com/brave/brave-core/pull/26677) - [AI Chat static conversation starters shouldn't show if conversation has chat history · Issue \#42412 · brave/brave-browser](https://github.com/brave/brave-browser/issues/42412) `QA/Yes` - test plan: specified (verify conversation starters only show when applicable) - [fix hit area for leo conversations list by aguscruiz · Pull Request \#26793 · brave/brave-core](https://github.com/brave/brave-core/pull/26793) - [Leo full page - Conversation list - Make the whole item clickable instead of excluding the padding · Issue \#42552 · brave/brave-browser](https://github.com/brave/brave-browser/issues/42552) `QA/No` - [\[Nala / @brave/leo\] update dependency by petemill · Pull Request \#26767 · brave/brave-core](https://github.com/brave/brave-core/pull/26767) - https://github.com/brave/brave-browser/issues/42545 `QA/No` --- app/brave_settings_strings.grdp | 26 +- browser/ai_chat/BUILD.gn | 3 + browser/ai_chat/ai_chat_service_factory.cc | 3 +- browser/ai_chat/ai_chat_throttle_unittest.cc | 36 +- browser/ai_chat/ai_chat_urls.cc | 29 + browser/ai_chat/ai_chat_urls.h | 29 + .../ai_chat/android/ai_chat_utils_android.cc | 2 +- browser/brave_content_browser_client.cc | 5 + .../brave_browsing_data_remover_delegate.cc | 32 +- .../brave_browsing_data_remover_delegate.h | 1 - browser/browsing_data/sources.gni | 2 + .../api/settings_private/brave_prefs_util.cc | 2 + ...rave_clear_browsing_data_on_exit_page.html | 4 +- .../brave_clear_browsing_data_on_exit_page.ts | 2 +- .../brave_leo_assistant_page.html | 11 +- .../brave_leo_assistant_page.ts | 37 +- .../clear_browsing_data_dialog.ts | 4 +- .../settings/brave_overrides/settings_menu.ts | 2 +- browser/resources/settings/brave_routes.ts | 2 +- browser/resources/settings/sources.gni | 1 + browser/ui/BUILD.gn | 4 +- browser/ui/brave_pages.cc | 2 +- .../side_panel/brave_side_panel_utils.cc | 4 +- browser/ui/webui/ai_chat/ai_chat_ui.cc | 122 +- browser/ui/webui/ai_chat/ai_chat_ui.h | 15 +- .../webui/ai_chat/ai_chat_ui_page_handler.cc | 23 +- .../webui/ai_chat/ai_chat_ui_page_handler.h | 7 +- .../ai_chat_untrusted_conversation_ui.cc | 215 ++++ .../ai_chat_untrusted_conversation_ui.h | 53 + .../brave_settings_leo_assistant_handler.cc | 22 +- ...ave_settings_localized_strings_provider.cc | 11 +- .../chrome_autocomplete_provider_client.cc | 61 +- .../webui/chrome_untrusted_web_ui_configs.cc | 4 +- .../browser/ui/webui/chrome_web_ui_configs.cc | 6 + components/ai_chat/content/browser/DEPS | 2 + .../browser/ai_chat_brave_search_throttle.cc | 14 +- .../browser/ai_chat_brave_search_throttle.h | 10 + .../content/browser/ai_chat_tab_helper.cc | 36 +- .../content/browser/ai_chat_tab_helper.h | 24 + .../content/browser/ai_chat_throttle.cc | 36 +- .../content/browser/ai_chat_throttle.h | 4 + .../content/browser/model_service_factory.cc | 1 + .../content/browser/model_service_factory.h | 5 + .../content/browser/page_content_fetcher.cc | 32 +- .../content/browser/page_content_fetcher.h | 2 + .../ai_chat/content/browser/pdf_utils.cc | 10 + components/ai_chat/core/browser/BUILD.gn | 15 + components/ai_chat/core/browser/DEPS | 8 + .../browser/ai_chat_credential_manager.cc | 21 +- .../core/browser/ai_chat_credential_manager.h | 7 + .../ai_chat_credential_manager_unittest.cc | 15 +- .../ai_chat/core/browser/ai_chat_database.cc | 1051 +++++++++++++++++ .../ai_chat/core/browser/ai_chat_database.h | 121 ++ .../core/browser/ai_chat_database_unittest.cc | 442 +++++++ .../core/browser/ai_chat_feedback_api.cc | 6 +- .../core/browser/ai_chat_feedback_api.h | 5 + .../ai_chat/core/browser/ai_chat_metrics.cc | 26 +- .../ai_chat/core/browser/ai_chat_metrics.h | 5 + .../core/browser/ai_chat_metrics_unittest.cc | 10 +- .../ai_chat/core/browser/ai_chat_service.cc | 731 ++++++++++-- .../ai_chat/core/browser/ai_chat_service.h | 115 +- .../core/browser/ai_chat_service_unittest.cc | 460 ++++++-- .../browser/associated_archive_content.cc | 19 + .../core/browser/associated_archive_content.h | 11 + .../core/browser/associated_content_driver.cc | 47 +- .../core/browser/associated_content_driver.h | 32 +- .../associated_content_driver_unittest.cc | 10 +- components/ai_chat/core/browser/constants.cc | 20 +- components/ai_chat/core/browser/constants.h | 13 +- .../core/browser/conversation_handler.cc | 551 +++++++-- .../core/browser/conversation_handler.h | 112 +- .../browser/conversation_handler_unittest.cc | 392 +++--- .../browser/engine/conversation_api_client.cc | 18 +- .../browser/engine/conversation_api_client.h | 14 + .../conversation_api_client_unittest.cc | 25 +- .../core/browser/engine/engine_consumer.cc | 4 + .../core/browser/engine/engine_consumer.h | 5 + .../browser/engine/engine_consumer_claude.cc | 16 +- .../browser/engine/engine_consumer_claude.h | 10 + .../engine/engine_consumer_claude_unittest.cc | 27 +- .../engine_consumer_conversation_api.cc | 11 +- .../engine/engine_consumer_conversation_api.h | 8 + ...gine_consumer_conversation_api_unittest.cc | 53 +- .../browser/engine/engine_consumer_llama.cc | 20 +- .../browser/engine/engine_consumer_llama.h | 10 + .../engine/engine_consumer_llama_unittest.cc | 28 +- .../browser/engine/engine_consumer_oai.cc | 12 +- .../core/browser/engine/engine_consumer_oai.h | 5 + .../engine/engine_consumer_oai_unittest.cc | 33 +- .../browser/engine/mock_engine_consumer.h | 5 + .../engine/mock_remote_completion_client.cc | 5 +- .../core/browser/engine/oai_api_client.cc | 15 +- .../core/browser/engine/oai_api_client.h | 10 + .../browser/engine/oai_api_client_unittest.cc | 11 +- .../engine/remote_completion_client.cc | 15 +- .../browser/engine/remote_completion_client.h | 11 + .../ai_chat/core/browser/engine/test_utils.cc | 21 +- .../core/browser/local_models_updater.cc | 11 + .../core/browser/local_models_updater.h | 10 + .../browser/local_models_updater_unittest.cc | 4 +- .../mock_conversation_handler_observer.cc | 18 + .../mock_conversation_handler_observer.h | 66 ++ .../ai_chat/core/browser/model_service.cc | 18 + .../ai_chat/core/browser/model_service.h | 12 +- .../core/browser/model_service_unittest.cc | 7 +- .../ai_chat/core/browser/model_validator.cc | 4 +- .../ai_chat/core/browser/model_validator.h | 14 +- .../core/browser/model_validator_unittest.cc | 7 +- components/ai_chat/core/browser/test_utils.cc | 228 ++++ components/ai_chat/core/browser/test_utils.h | 50 + .../ai_chat/core/browser/text_embedder.cc | 21 +- .../ai_chat/core/browser/text_embedder.h | 13 + .../core/browser/text_embedder_unittest.cc | 8 + components/ai_chat/core/browser/utils.cc | 24 +- components/ai_chat/core/browser/utils.h | 6 + .../ai_chat/core/browser/utils_unittest.cc | 3 + components/ai_chat/core/common/features.cc | 3 + components/ai_chat/core/common/features.h | 2 + components/ai_chat/core/common/mojom/BUILD.gn | 1 + .../ai_chat/core/common/mojom/ai_chat.mojom | 137 ++- .../core/common/mojom/untrusted_frame.mojom | 31 + components/ai_chat/core/common/pref_names.cc | 7 +- components/ai_chat/core/common/pref_names.h | 3 + .../core/common/pref_names_unittest.cc | 2 - components/ai_chat/core/common/utils.cc | 3 + .../ai_chat/core/common/utils_unittest.cc | 3 + components/ai_chat/renderer/DEPS | 2 + .../renderer/ai_chat_resource_sniffer.cc | 20 +- .../renderer/ai_chat_resource_sniffer.h | 10 + ...chat_resource_sniffer_throttle_unittest.cc | 32 +- .../renderer/page_content_extractor.cc | 28 +- .../ai_chat/renderer/page_content_extractor.h | 14 + .../ai_chat/renderer/page_text_distilling.cc | 20 +- .../ai_chat/renderer/page_text_distilling.h | 1 + components/ai_chat/renderer/yt_util.cc | 4 + components/ai_chat/renderer/yt_util.h | 2 + .../ai_chat/renderer/yt_util_unittest.cc | 1 - .../ai_chat/resources/{page => }/BUILD.gn | 19 +- .../resources/ai_chat_ui_resources.grdp | 9 + components/ai_chat/resources/common/api.ts | 36 + .../components/action_type_label/index.tsx | 40 +- .../action_type_label/style.module.scss | 0 components/ai_chat/resources/common/mojom.ts | 7 + .../ai_chat/resources/common/useAPIState.ts | 39 + .../ai_chat/resources/page/ai_chat_ui.html | 7 +- .../resources/page/ai_chat_ui_resources.grdp | 5 - .../ai_chat/resources/page/api/index.ts | 167 ++- components/ai_chat/resources/page/chat_ui.tsx | 161 +-- .../components/alerts/error_connection.tsx | 1 - .../alerts/error_conversation_end.tsx | 11 +- .../components/alerts/error_rate_limit.tsx | 1 - .../alerts/long_conversation_info.tsx | 6 +- .../alerts/warning_premium_disconnected.tsx | 1 - .../context_menu_assistant/index.tsx | 205 ---- .../components/conversations_list/index.tsx | 33 +- .../conversations_list/style.module.scss | 35 +- .../components/feature_button_menu/index.tsx | 9 +- .../page/components/feedback_form/index.tsx | 18 +- .../page/components/full_page/index.tsx | 20 +- .../components/full_page/style.module.scss | 11 +- .../page/components/header/index.tsx | 36 +- .../page/components/header/style.module.scss | 3 +- .../page/components/input_box/index.tsx | 10 +- .../components/input_box/style.module.scss | 7 +- .../page/components/loading/index.tsx | 16 + .../components/loading/loading.module.scss | 24 + .../resources/page/components/main/index.tsx | 219 ++-- .../page/components/main/style.module.scss | 68 +- .../page/components/model_intro/index.tsx | 9 +- .../notices/conversation_storage.svg | 31 + .../notices/notice_conversation_storage.tsx | 73 ++ .../components/notices/notices.module.scss | 90 ++ .../components/suggested_question/index.tsx | 66 ++ .../suggested_question/style.module.scss | 56 + .../components/tools_button_menu/index.tsx | 9 +- .../ai_chat/resources/page/model_utils.ts | 4 +- .../page/state/active_chat_context.tsx | 122 ++ .../resources/page/state/ai_chat_context.tsx | 164 +-- .../page/state/conversation_context.tsx | 128 +- .../resources/page/state/useSendFeedback.ts | 121 ++ .../page/stories/components_panel.tsx | 283 +++-- .../story_utils/ConversationEntries.tsx | 24 + .../page/stories/{ => story_utils}/actions.ts | 12 +- .../page/stories/{ => story_utils}/locale.ts | 14 +- components/ai_chat/resources/page/styles.css | 7 + .../components/assistant_response/index.tsx | 13 +- .../assistant_response/style.module.scss | 0 .../components/code_block/index.tsx | 8 +- .../components/code_block/style.module.scss | 0 .../context_menu_assistant/index.tsx | 102 ++ .../context_menu_assistant/style.module.scss | 0 .../components/conversation_entries/index.tsx | 175 +-- .../conversation_entries/style.module.scss | 30 - .../components/copy_button/index.tsx | 0 .../components/copy_button/style.module.scss | 0 .../components/edit_button/index.tsx | 0 .../components/edit_button/style.module.scss | 0 .../components/edit_indicator/index.tsx | 0 .../edit_indicator/style.module.scss | 0 .../components/edit_input/index.tsx | 0 .../components/edit_input/style.module.scss | 0 .../components/markdown_renderer/index.tsx | 0 .../markdown_renderer/style.module.scss | 0 .../page_context_message}/long_page_info.tsx | 23 +- .../page_context_message/style.module.scss | 14 + .../components/quote/index.tsx | 0 .../components/quote/style.module.scss | 0 .../components/svg/caret.tsx | 0 .../untrusted_conversation_frame/styles.css | 20 + .../tsconfig.json | 9 + .../untrusted_conversation_context.tsx | 49 + .../untrusted_conversation_frame.html | 27 + .../untrusted_conversation_frame.tsx | 29 + .../untrusted_conversation_frame_api.ts | 89 ++ .../panel/components/error-panel/index.tsx | 1 - .../panel/components/main-panel/index.tsx | 2 - .../purchase-failed-panel/index.tsx | 1 - .../desktop/wallet-banner/index.tsx | 1 - .../common/tx_warnings.tsx | 1 - .../extension/edit-gas/edit-gas.tsx | 3 +- .../common/speed_up_alert.tsx | 1 - .../add-custom-token-form/add-nft-form.tsx | 1 - .../nft/components/nft-details/nft-screen.tsx | 1 - components/common/useRoute.ts | 90 ++ components/constants/webui_url_constants.h | 9 +- .../browser/resources/playerEventSink.ts | 1 - components/resources/BUILD.gn | 4 +- components/resources/ai_chat_prompts.grdp | 55 + components/resources/ai_chat_ui_strings.grdp | 21 + .../resources/brave_components_resources.grd | 2 +- .../AIChat/Components/AIChatView.swift | 9 +- .../Messages/AIChatResponseMessageView.swift | 1 + .../AIChat/ModelView/AIChatViewModel.swift | 12 +- ios/browser/api/ai_chat/BUILD.gn | 12 +- ios/browser/api/ai_chat/ai_chat.h | 9 +- ios/browser/api/ai_chat/ai_chat.mm | 19 +- ios/browser/api/ai_chat/ai_chat_delegate.h | 1 + .../api/ai_chat/ai_chat_service_factory.mm | 4 +- ios/browser/api/ai_chat/conversation_client.h | 23 +- .../api/ai_chat/conversation_client.mm | 18 +- ios/browser/api/ai_chat/headers.gni | 1 + package-lock.json | 15 +- package.json | 2 +- ui/webui/resources/BUILD.gn | 2 + 244 files changed, 7768 insertions(+), 2024 deletions(-) create mode 100644 browser/ai_chat/ai_chat_urls.cc create mode 100644 browser/ai_chat/ai_chat_urls.h create mode 100644 browser/ui/webui/ai_chat/ai_chat_untrusted_conversation_ui.cc create mode 100644 browser/ui/webui/ai_chat/ai_chat_untrusted_conversation_ui.h create mode 100644 components/ai_chat/core/browser/ai_chat_database.cc create mode 100644 components/ai_chat/core/browser/ai_chat_database.h create mode 100644 components/ai_chat/core/browser/ai_chat_database_unittest.cc create mode 100644 components/ai_chat/core/browser/mock_conversation_handler_observer.cc create mode 100644 components/ai_chat/core/browser/mock_conversation_handler_observer.h create mode 100644 components/ai_chat/core/browser/test_utils.cc create mode 100644 components/ai_chat/core/browser/test_utils.h create mode 100644 components/ai_chat/core/common/mojom/untrusted_frame.mojom rename components/ai_chat/resources/{page => }/BUILD.gn (67%) create mode 100644 components/ai_chat/resources/ai_chat_ui_resources.grdp create mode 100644 components/ai_chat/resources/common/api.ts rename components/ai_chat/resources/{page => common}/components/action_type_label/index.tsx (81%) rename components/ai_chat/resources/{page => common}/components/action_type_label/style.module.scss (100%) create mode 100644 components/ai_chat/resources/common/mojom.ts create mode 100644 components/ai_chat/resources/common/useAPIState.ts delete mode 100644 components/ai_chat/resources/page/ai_chat_ui_resources.grdp delete mode 100644 components/ai_chat/resources/page/components/context_menu_assistant/index.tsx create mode 100644 components/ai_chat/resources/page/components/loading/index.tsx create mode 100644 components/ai_chat/resources/page/components/loading/loading.module.scss create mode 100644 components/ai_chat/resources/page/components/notices/conversation_storage.svg create mode 100644 components/ai_chat/resources/page/components/notices/notice_conversation_storage.tsx create mode 100644 components/ai_chat/resources/page/components/notices/notices.module.scss create mode 100644 components/ai_chat/resources/page/components/suggested_question/index.tsx create mode 100644 components/ai_chat/resources/page/components/suggested_question/style.module.scss create mode 100644 components/ai_chat/resources/page/state/active_chat_context.tsx create mode 100644 components/ai_chat/resources/page/state/useSendFeedback.ts create mode 100644 components/ai_chat/resources/page/stories/story_utils/ConversationEntries.tsx rename components/ai_chat/resources/page/stories/{ => story_utils}/actions.ts (72%) rename components/ai_chat/resources/page/stories/{ => story_utils}/locale.ts (84%) rename components/ai_chat/resources/{page => untrusted_conversation_frame}/components/assistant_response/index.tsx (89%) rename components/ai_chat/resources/{page => untrusted_conversation_frame}/components/assistant_response/style.module.scss (100%) rename components/ai_chat/resources/{page => untrusted_conversation_frame}/components/code_block/index.tsx (95%) rename components/ai_chat/resources/{page => untrusted_conversation_frame}/components/code_block/style.module.scss (100%) create mode 100644 components/ai_chat/resources/untrusted_conversation_frame/components/context_menu_assistant/index.tsx rename components/ai_chat/resources/{page => untrusted_conversation_frame}/components/context_menu_assistant/style.module.scss (100%) rename components/ai_chat/resources/{page => untrusted_conversation_frame}/components/conversation_entries/index.tsx (52%) rename components/ai_chat/resources/{page => untrusted_conversation_frame}/components/conversation_entries/style.module.scss (79%) rename components/ai_chat/resources/{page => untrusted_conversation_frame}/components/copy_button/index.tsx (100%) rename components/ai_chat/resources/{page => untrusted_conversation_frame}/components/copy_button/style.module.scss (100%) rename components/ai_chat/resources/{page => untrusted_conversation_frame}/components/edit_button/index.tsx (100%) rename components/ai_chat/resources/{page => untrusted_conversation_frame}/components/edit_button/style.module.scss (100%) rename components/ai_chat/resources/{page => untrusted_conversation_frame}/components/edit_indicator/index.tsx (100%) rename components/ai_chat/resources/{page => untrusted_conversation_frame}/components/edit_indicator/style.module.scss (100%) rename components/ai_chat/resources/{page => untrusted_conversation_frame}/components/edit_input/index.tsx (100%) rename components/ai_chat/resources/{page => untrusted_conversation_frame}/components/edit_input/style.module.scss (100%) rename components/ai_chat/resources/{page => untrusted_conversation_frame}/components/markdown_renderer/index.tsx (100%) rename components/ai_chat/resources/{page => untrusted_conversation_frame}/components/markdown_renderer/style.module.scss (100%) rename components/ai_chat/resources/{page/components/alerts => untrusted_conversation_frame/components/page_context_message}/long_page_info.tsx (58%) create mode 100644 components/ai_chat/resources/untrusted_conversation_frame/components/page_context_message/style.module.scss rename components/ai_chat/resources/{page => untrusted_conversation_frame}/components/quote/index.tsx (100%) rename components/ai_chat/resources/{page => untrusted_conversation_frame}/components/quote/style.module.scss (100%) rename components/ai_chat/resources/{page => untrusted_conversation_frame}/components/svg/caret.tsx (100%) create mode 100644 components/ai_chat/resources/untrusted_conversation_frame/styles.css create mode 100644 components/ai_chat/resources/untrusted_conversation_frame/tsconfig.json create mode 100644 components/ai_chat/resources/untrusted_conversation_frame/untrusted_conversation_context.tsx create mode 100644 components/ai_chat/resources/untrusted_conversation_frame/untrusted_conversation_frame.html create mode 100644 components/ai_chat/resources/untrusted_conversation_frame/untrusted_conversation_frame.tsx create mode 100644 components/ai_chat/resources/untrusted_conversation_frame/untrusted_conversation_frame_api.ts create mode 100644 components/common/useRoute.ts diff --git a/app/brave_settings_strings.grdp b/app/brave_settings_strings.grdp index 0f203fc81152..0b2e115fdae0 100644 --- a/app/brave_settings_strings.grdp +++ b/app/brave_settings_strings.grdp @@ -96,7 +96,7 @@ Position - Expand vertical tabs independently per window + Expand vertical tabs independently per window Show scrollbar @@ -1332,23 +1332,29 @@ Show suggested prompts in the conversation - - Clear Leo data + + Store my conversation history + + + Disabling conversation storage will permanently erase all previously stored conversations. This action can't be undone. + + + Delete all Leo AI conversation data - Resetting the Leo assistant will require you to opt-in to use Brave -Leo in the future and will also clear your chat history. Clearing your + Resetting Leo AI will require you to opt-in to use Brave +Leo AI in the future and will also clear your chat history. Clearing your chat history will delete all your previous conversations with Brave -Leo. This action can't be undone. +Leo AI. This action can't be undone. Adjust autocomplete suggestions - - Leo + + Leo AI - - Data and chat history + + Chat history Default model for new conversations diff --git a/browser/ai_chat/BUILD.gn b/browser/ai_chat/BUILD.gn index d6cfffee0c94..f9af82632195 100644 --- a/browser/ai_chat/BUILD.gn +++ b/browser/ai_chat/BUILD.gn @@ -11,6 +11,8 @@ static_library("ai_chat") { "ai_chat_service_factory.h", "ai_chat_settings_helper.cc", "ai_chat_settings_helper.h", + "ai_chat_urls.cc", + "ai_chat_urls.h", "ai_chat_utils.cc", "ai_chat_utils.h", ] @@ -24,6 +26,7 @@ static_library("ai_chat") { "//brave/components/ai_chat/core/browser", "//brave/components/ai_chat/core/common", "//brave/components/ai_chat/core/common/mojom", + "//brave/components/constants", "//brave/components/resources:strings_grit", "//brave/net/base:utils", "//chrome/browser:browser_process", diff --git a/browser/ai_chat/ai_chat_service_factory.cc b/browser/ai_chat/ai_chat_service_factory.cc index 03f7829b4d09..c7a904132aac 100644 --- a/browser/ai_chat/ai_chat_service_factory.cc +++ b/browser/ai_chat/ai_chat_service_factory.cc @@ -77,9 +77,10 @@ AIChatServiceFactory::BuildServiceInstanceForBrowserContext( (g_brave_browser_process->process_misc_metrics()) ? g_brave_browser_process->process_misc_metrics()->ai_chat_metrics() : nullptr, + g_browser_process->os_crypt_async(), context->GetDefaultStoragePartition() ->GetURLLoaderFactoryForBrowserProcess(), - version_info::GetChannelString(chrome::GetChannel())); + version_info::GetChannelString(chrome::GetChannel()), context->GetPath()); } } // namespace ai_chat diff --git a/browser/ai_chat/ai_chat_throttle_unittest.cc b/browser/ai_chat/ai_chat_throttle_unittest.cc index aa7dbe8e1412..4cf514656885 100644 --- a/browser/ai_chat/ai_chat_throttle_unittest.cc +++ b/browser/ai_chat/ai_chat_throttle_unittest.cc @@ -3,11 +3,13 @@ * License, v. 2.0. If a copy of the MPL was not distributed with this file, * You can obtain one at https://mozilla.org/MPL/2.0/. */ +#include "brave/components/ai_chat/content/browser/ai_chat_throttle.h" + #include #include "base/test/scoped_feature_list.h" -#include "brave/components/ai_chat/content/browser/ai_chat_throttle.h" #include "brave/components/ai_chat/core/common/features.h" +#include "brave/components/constants/webui_url_constants.h" #include "chrome/test/base/testing_browser_process.h" #include "chrome/test/base/testing_profile.h" #include "chrome/test/base/testing_profile_manager.h" @@ -20,7 +22,9 @@ namespace ai_chat { namespace { + constexpr char kTestProfileName[] = "TestProfile"; + } // namespace class AiChatThrottleUnitTest : public testing::Test, @@ -67,14 +71,13 @@ INSTANTIATE_TEST_SUITE_P( AiChatThrottleUnitTest, ::testing::Bool(), [](const testing::TestParamInfo& info) { - return base::StringPrintf("History%s", - info.param ? "Enabled" : "Disabled"); + return base::StrCat({"History", info.param ? "Enabled" : "Disabled"}); }); TEST_P(AiChatThrottleUnitTest, CancelNavigationFromTab) { content::MockNavigationHandle test_handle(web_contents()); - test_handle.set_url(GURL("chrome-untrusted://chat")); + test_handle.set_url(GURL(kAIChatUIURL)); #if BUILDFLAG(IS_ANDROID) ui::PageTransition transition = ui::PageTransitionFromInt( @@ -99,10 +102,33 @@ TEST_P(AiChatThrottleUnitTest, CancelNavigationFromTab) { } } +TEST_P(AiChatThrottleUnitTest, CancelNavigationToFrame) { + content::MockNavigationHandle test_handle(web_contents()); + + test_handle.set_url(GURL(kAIChatUntrustedConversationUIURL)); + +#if BUILDFLAG(IS_ANDROID) + ui::PageTransition transition = ui::PageTransitionFromInt( + ui::PageTransition::PAGE_TRANSITION_FROM_ADDRESS_BAR); +#else + ui::PageTransition transition = ui::PageTransitionFromInt( + ui::PageTransition::PAGE_TRANSITION_FROM_ADDRESS_BAR | + ui::PageTransition::PAGE_TRANSITION_TYPED); +#endif + + test_handle.set_page_transition(transition); + + std::unique_ptr throttle = + AiChatThrottle::MaybeCreateThrottleFor(&test_handle); + + EXPECT_EQ(content::NavigationThrottle::CANCEL_AND_IGNORE, + throttle->WillStartRequest().action()); +} + TEST_P(AiChatThrottleUnitTest, AllowNavigationFromPanel) { content::MockNavigationHandle test_handle(web_contents()); - test_handle.set_url(GURL("chrome-untrusted://chat")); + test_handle.set_url(GURL(kAIChatUIURL)); #if BUILDFLAG(IS_ANDROID) ui::PageTransition transition = diff --git a/browser/ai_chat/ai_chat_urls.cc b/browser/ai_chat/ai_chat_urls.cc new file mode 100644 index 000000000000..ed9b4e66399c --- /dev/null +++ b/browser/ai_chat/ai_chat_urls.cc @@ -0,0 +1,29 @@ +// Copyright (c) 2024 The Brave Authors. All rights reserved. +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. + +#include "brave/browser/ai_chat/ai_chat_urls.h" + +#include + +#include "base/strings/strcat.h" +#include "base/strings/string_util.h" +#include "brave/components/constants/webui_url_constants.h" +#include "url/gurl.h" + +namespace ai_chat { + +GURL TabAssociatedConversationUrl() { + return GURL(base::StrCat({kAIChatUIURL, "tab"})); +} + +GURL ConversationUrl(std::string_view conversation_uuid) { + return GURL(base::StrCat({kAIChatUIURL, conversation_uuid})); +} + +std::string_view ConversationUUIDFromURL(const GURL& url) { + return base::TrimString(url.path_piece(), "/", base::TrimPositions::TRIM_ALL); +} + +} // namespace ai_chat diff --git a/browser/ai_chat/ai_chat_urls.h b/browser/ai_chat/ai_chat_urls.h new file mode 100644 index 000000000000..1cf4c970e34e --- /dev/null +++ b/browser/ai_chat/ai_chat_urls.h @@ -0,0 +1,29 @@ +// Copyright (c) 2024 The Brave Authors. All rights reserved. +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. + +#ifndef BRAVE_BROWSER_AI_CHAT_AI_CHAT_URLS_H_ +#define BRAVE_BROWSER_AI_CHAT_AI_CHAT_URLS_H_ + +#include + +#include "url/gurl.h" + +namespace ai_chat { + +// UI that will open a conversation associated with the active Tab in the same +// browser window. The conversation will change when that Tab navigates. +GURL TabAssociatedConversationUrl(); + +// UI that will open to a specific conversation. The conversation will not +// change upon any navigation. +GURL ConversationUrl(std::string_view conversation_uuid); + +// Extracts the conversation UUID from a conversation URL or a conversation +// entries iframe +std::string_view ConversationUUIDFromURL(const GURL& url); + +} // namespace ai_chat + +#endif // BRAVE_BROWSER_AI_CHAT_AI_CHAT_URLS_H_ diff --git a/browser/ai_chat/android/ai_chat_utils_android.cc b/browser/ai_chat/android/ai_chat_utils_android.cc index 03ded3c5b2f0..de44b68399f1 100644 --- a/browser/ai_chat/android/ai_chat_utils_android.cc +++ b/browser/ai_chat/android/ai_chat_utils_android.cc @@ -46,7 +46,7 @@ static void JNI_BraveLeoUtils_OpenLeoQuery( // Send the query conversation->MaybeUnlinkAssociatedContent(); mojom::ConversationTurnPtr turn = mojom::ConversationTurn::New( - mojom::CharacterType::HUMAN, mojom::ActionType::QUERY, + std::nullopt, mojom::CharacterType::HUMAN, mojom::ActionType::QUERY, mojom::ConversationTurnVisibility::VISIBLE, base::android::ConvertJavaStringToUTF8(query), std::nullopt, std::nullopt, base::Time::Now(), std::nullopt, false); diff --git a/browser/brave_content_browser_client.cc b/browser/brave_content_browser_client.cc index a415b4bd5311..eeccb4ae68ba 100644 --- a/browser/brave_content_browser_client.cc +++ b/browser/brave_content_browser_client.cc @@ -38,6 +38,7 @@ #include "brave/browser/skus/skus_service_factory.h" #include "brave/browser/ui/brave_ui_features.h" #include "brave/browser/ui/webui/ai_chat/ai_chat_ui.h" +#include "brave/browser/ui/webui/ai_chat/ai_chat_untrusted_conversation_ui.h" #include "brave/browser/ui/webui/brave_rewards/rewards_page_ui.h" #include "brave/browser/ui/webui/skus_internals_ui.h" #include "brave/browser/url_sanitizer/url_sanitizer_service_factory.h" @@ -49,6 +50,7 @@ #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" #include "brave/components/ai_chat/core/common/mojom/page_content_extractor.mojom.h" #include "brave/components/ai_chat/core/common/mojom/settings_helper.mojom.h" +#include "brave/components/ai_chat/core/common/mojom/untrusted_frame.mojom.h" #include "brave/components/ai_rewriter/common/buildflags/buildflags.h" #include "brave/components/body_sniffer/body_sniffer_throttle.h" #include "brave/components/brave_federated/features.h" @@ -620,6 +622,9 @@ void BraveContentBrowserClient::RegisterWebUIInterfaceBrokers( registry.ForWebUI() .Add() .Add(); + registry.ForWebUI() + .Add() + .Add(); } #if BUILDFLAG(ENABLE_AI_REWRITER) diff --git a/browser/browsing_data/brave_browsing_data_remover_delegate.cc b/browser/browsing_data/brave_browsing_data_remover_delegate.cc index d96808085fe2..10ee2fcbec27 100644 --- a/browser/browsing_data/brave_browsing_data_remover_delegate.cc +++ b/browser/browsing_data/brave_browsing_data_remover_delegate.cc @@ -8,7 +8,9 @@ #include #include +#include "brave/browser/ai_chat/ai_chat_service_factory.h" #include "brave/browser/brave_news/brave_news_controller_factory.h" +#include "brave/components/ai_chat/core/browser/ai_chat_service.h" #include "brave/components/ai_chat/core/browser/utils.h" #include "brave/components/ai_chat/core/common/features.h" #include "brave/components/brave_news/browser/brave_news_controller.h" @@ -19,7 +21,9 @@ #include "chrome/browser/content_settings/host_content_settings_map_factory.h" #include "chrome/browser/profiles/profile.h" #include "chrome/common/buildflags.h" +#include "components/browsing_data/content/browsing_data_helper.h" #include "components/content_settings/core/browser/host_content_settings_map.h" +#include "content/public/browser/browsing_data_remover.h" BraveBrowsingDataRemoverDelegate::BraveBrowsingDataRemoverDelegate( content::BrowserContext* browser_context) @@ -48,19 +52,31 @@ void BraveBrowsingDataRemoverDelegate::RemoveEmbedderData( ClearShieldsSettings(delete_begin, delete_end); } - // Brave News feed cache if (remove_mask & chrome_browsing_data_remover::DATA_TYPE_HISTORY) { + // Brave News feed cache if (auto* brave_news_controller = brave_news::BraveNewsControllerFactory::GetForBrowserContext( profile_)) { brave_news_controller->ClearHistory(); } + // AI Chat history but only associated content, not neccessary if we + // are also deleting entire AI Chat history. + if (!(remove_mask & + chrome_browsing_data_remover::DATA_TYPE_BRAVE_LEO_HISTORY)) { + ai_chat::AIChatService* ai_chat_service = + ai_chat::AIChatServiceFactory::GetForBrowserContext(profile_); + if (ai_chat_service) { + ai_chat_service->DeleteAssociatedWebContent(delete_begin, delete_end); + } + } } - if (remove_mask & chrome_browsing_data_remover::DATA_TYPE_BRAVE_LEO_HISTORY && - ai_chat::IsAIChatEnabled(profile_->GetPrefs()) && - ai_chat::features::IsAIChatHistoryEnabled()) { - ClearAiChatHistory(delete_begin, delete_end); + if (remove_mask & chrome_browsing_data_remover::DATA_TYPE_BRAVE_LEO_HISTORY) { + ai_chat::AIChatService* ai_chat_service = + ai_chat::AIChatServiceFactory::GetForBrowserContext(profile_); + if (ai_chat_service) { + ai_chat_service->DeleteConversations(delete_begin, delete_end); + } } } @@ -92,9 +108,3 @@ void BraveBrowsingDataRemoverDelegate::ClearShieldsSettings( } } } - -void BraveBrowsingDataRemoverDelegate::ClearAiChatHistory(base::Time begin_time, - base::Time end_time) { - // Handler for the Brave Leo History clearing. - // It is prepared for future implementation. -} diff --git a/browser/browsing_data/brave_browsing_data_remover_delegate.h b/browser/browsing_data/brave_browsing_data_remover_delegate.h index 9b2b383eba92..595dd7b71120 100644 --- a/browser/browsing_data/brave_browsing_data_remover_delegate.h +++ b/browser/browsing_data/brave_browsing_data_remover_delegate.h @@ -43,7 +43,6 @@ class BraveBrowsingDataRemoverDelegate base::OnceCallback callback) override; void ClearShieldsSettings(base::Time begin_time, base::Time end_time); - void ClearAiChatHistory(base::Time begin_time, base::Time end_time); raw_ptr profile_ = nullptr; base::WeakPtrFactory weak_ptr_factory_{ diff --git a/browser/browsing_data/sources.gni b/browser/browsing_data/sources.gni index fb8c2d5f35ef..21506e893c3a 100644 --- a/browser/browsing_data/sources.gni +++ b/browser/browsing_data/sources.gni @@ -14,6 +14,8 @@ brave_browser_browsing_data_sources = [ brave_browser_browsing_data_deps = [ "//base", + "//brave/browser/ai_chat", + "//brave/components/ai_chat/core/browser", "//chrome/browser:browser_process", "//chrome/browser/browsing_data:constants", "//chrome/browser/profiles:profile", diff --git a/browser/extensions/api/settings_private/brave_prefs_util.cc b/browser/extensions/api/settings_private/brave_prefs_util.cc index 0b7f57212267..bb19a15548fd 100644 --- a/browser/extensions/api/settings_private/brave_prefs_util.cc +++ b/browser/extensions/api/settings_private/brave_prefs_util.cc @@ -242,6 +242,8 @@ const PrefsUtil::TypedPrefMap& BravePrefsUtil::GetAllowlistedKeys() { settings_api::PrefType::kBoolean; // Leo Assistant pref + (*s_brave_allowlist)[ai_chat::prefs::kStorageEnabled] = + settings_api::PrefType::kBoolean; (*s_brave_allowlist)[ai_chat::prefs::kBraveChatAutocompleteProviderEnabled] = settings_api::PrefType::kBoolean; (*s_brave_allowlist)[ai_chat::prefs::kBraveAIChatContextMenuEnabled] = diff --git a/browser/resources/settings/brave_clear_browsing_data_dialog/brave_clear_browsing_data_on_exit_page.html b/browser/resources/settings/brave_clear_browsing_data_dialog/brave_clear_browsing_data_on_exit_page.html index fac819e1c81f..0e686ce59b29 100644 --- a/browser/resources/settings/brave_clear_browsing_data_dialog/brave_clear_browsing_data_on_exit_page.html +++ b/browser/resources/settings/brave_clear_browsing_data_dialog/brave_clear_browsing_data_on_exit_page.html @@ -30,8 +30,8 @@ diff --git a/browser/resources/settings/brave_clear_browsing_data_dialog/brave_clear_browsing_data_on_exit_page.ts b/browser/resources/settings/brave_clear_browsing_data_dialog/brave_clear_browsing_data_on_exit_page.ts index dc427c1ad1d7..59880cb95495 100644 --- a/browser/resources/settings/brave_clear_browsing_data_dialog/brave_clear_browsing_data_on_exit_page.ts +++ b/browser/resources/settings/brave_clear_browsing_data_dialog/brave_clear_browsing_data_on_exit_page.ts @@ -43,7 +43,7 @@ Polymer({ }, }, - isLeoAssistantAndHistoryAllowed_: { + isAIChatAssistantAndHistoryAllowed_: { type: Boolean, value: function() { return loadTimeData.getBoolean('isLeoAssistantAllowed') diff --git a/browser/resources/settings/brave_leo_assistant_page/brave_leo_assistant_page.html b/browser/resources/settings/brave_leo_assistant_page/brave_leo_assistant_page.html index 8aa45c5d8fd2..e61dbc03e965 100644 --- a/browser/resources/settings/brave_leo_assistant_page/brave_leo_assistant_page.html +++ b/browser/resources/settings/brave_leo_assistant_page/brave_leo_assistant_page.html @@ -115,8 +115,15 @@ -
+ diff --git a/browser/resources/settings/brave_leo_assistant_page/brave_leo_assistant_page.ts b/browser/resources/settings/brave_leo_assistant_page/brave_leo_assistant_page.ts index ad211efbae64..236662a75dbd 100644 --- a/browser/resources/settings/brave_leo_assistant_page/brave_leo_assistant_page.ts +++ b/browser/resources/settings/brave_leo_assistant_page/brave_leo_assistant_page.ts @@ -4,16 +4,19 @@ // You can obtain one at https://mozilla.org/MPL/2.0/. import '//resources/cr_elements/md_select.css.js' -import {PolymerElement} from 'chrome://resources/polymer/v3_0/polymer/polymer_bundled.min.js'; -import {WebUiListenerMixin} from 'chrome://resources/cr_elements/web_ui_listener_mixin.js'; -import {PrefsMixin} from '/shared/settings/prefs/prefs_mixin.js'; -import {I18nMixin} from 'chrome://resources/cr_elements/i18n_mixin.js'; +import 'chrome://resources/brave/leo.bundle.js' +import {assert} from 'chrome://resources/js/assert.js'; +import {I18nMixin} from 'chrome://resources/cr_elements/i18n_mixin.js' +import {PolymerElement} from 'chrome://resources/polymer/v3_0/polymer/polymer_bundled.min.js' +import {WebUiListenerMixin} from 'chrome://resources/cr_elements/web_ui_listener_mixin.js' +import {PrefsMixin} from '/shared/settings/prefs/prefs_mixin.js' +import {SettingsToggleButtonElement} from '../controls/settings_toggle_button.js' +import {Router} from '../router.js' +import {loadTimeData} from '../i18n_setup.js' +import {routes} from '../route.js'; import {getTemplate} from './brave_leo_assistant_page.html.js' import {BraveLeoAssistantBrowserProxy, BraveLeoAssistantBrowserProxyImpl, PremiumStatus, ModelWithSubtitle, PremiumInfo, ModelAccess, Model} from './brave_leo_assistant_browser_proxy.js' -import 'chrome://resources/brave/leo.bundle.js' -import { Router } from '../router.js'; -import {routes} from '../route.js'; const BraveLeoAssistantPageBase = WebUiListenerMixin(I18nMixin(PrefsMixin(PolymerElement))) @@ -33,6 +36,10 @@ class BraveLeoAssistantPageElement extends BraveLeoAssistantPageBase { static get properties() { return { + prefs: { + type: Object, + notify: true, + }, leoAssistantShowOnToolbarPref_: { type: Boolean, value: false, @@ -51,6 +58,10 @@ class BraveLeoAssistantPageElement extends BraveLeoAssistantPageBase { } private isPremiumUser_: boolean + + isHistoryFeatureEnabled_: boolean = + loadTimeData.getBoolean('isLeoAssistantHistoryAllowed') + leoAssistantShowOnToolbarPref_: boolean defaultModelKeyPrefValue_: string models_: ModelWithSubtitle[] @@ -180,6 +191,18 @@ class BraveLeoAssistantPageElement extends BraveLeoAssistantPageBase { openManageAccountPage_() { window.open(this.manageUrl_, "_self", "noopener noreferrer") } + + private onStorageEnabledChange_(event: Event) { + const target = event.target + assert(target instanceof SettingsToggleButtonElement); + // Confirm that the user knows conversation history will be permanently + // deleted. + if (!target?.checked) { + if (!confirm(this.i18n('braveLeoAssistantHistoryPreferenceConfirm'))) { + target.checked = !target.checked + } + } + } } customElements.define( diff --git a/browser/resources/settings/brave_overrides/clear_browsing_data_dialog.ts b/browser/resources/settings/brave_overrides/clear_browsing_data_dialog.ts index 8d298bc3ceb8..63352e217a00 100644 --- a/browser/resources/settings/brave_overrides/clear_browsing_data_dialog.ts +++ b/browser/resources/settings/brave_overrides/clear_browsing_data_dialog.ts @@ -133,8 +133,8 @@ RegisterPolymerTemplateModifications({ `) diff --git a/browser/resources/settings/brave_overrides/settings_menu.ts b/browser/resources/settings/brave_overrides/settings_menu.ts index 215e787ca80b..d2e3f5d37c25 100644 --- a/browser/resources/settings/brave_overrides/settings_menu.ts +++ b/browser/resources/settings/brave_overrides/settings_menu.ts @@ -266,7 +266,7 @@ RegisterPolymerTemplateModifications({ // Add leo item const leoAssistantEl = createMenuElement( loadTimeData.getString('leoAssistant'), - '/leo-assistant', + '/leo-ai', 'product-brave-leo', 'leoAssistant', ) diff --git a/browser/resources/settings/brave_routes.ts b/browser/resources/settings/brave_routes.ts index bce4d58943d9..93d1dd16f96e 100644 --- a/browser/resources/settings/brave_routes.ts +++ b/browser/resources/settings/brave_routes.ts @@ -53,7 +53,7 @@ export default function addBraveRoutes(r: Partial) { if (pageVisibility.leoAssistant) { r.BRAVE_LEO_ASSISTANT = - r.BASIC.createSection('/leo-assistant', 'leoAssistant') + r.BASIC.createSection('/leo-ai', 'leoAssistant') } if (pageVisibility.content) { r.BRAVE_CONTENT = r.BASIC.createSection('/braveContent', 'content') diff --git a/browser/resources/settings/sources.gni b/browser/resources/settings/sources.gni index 5c3449fbefd0..9f143ef87d0f 100644 --- a/browser/resources/settings/sources.gni +++ b/browser/resources/settings/sources.gni @@ -156,6 +156,7 @@ brave_settings_ts_extra_deps = brave_settings_mojo_files = [ "$root_gen_dir/brave/components/ai_chat/core/common/mojom/settings_helper.mojom-webui.ts", "$root_gen_dir/brave/components/ai_chat/core/common/mojom/ai_chat.mojom-webui.ts", + "$root_gen_dir/brave/components/ai_chat/core/common/mojom/untrusted_frame.mojom-webui.ts", ] brave_settings_mojo_files_deps = diff --git a/browser/ui/BUILD.gn b/browser/ui/BUILD.gn index 89cec2e2ac15..4ebcc8f0943b 100644 --- a/browser/ui/BUILD.gn +++ b/browser/ui/BUILD.gn @@ -66,6 +66,8 @@ source_set("ui") { "webui/ai_chat/ai_chat_ui.h", "webui/ai_chat/ai_chat_ui_page_handler.cc", "webui/ai_chat/ai_chat_ui_page_handler.h", + "webui/ai_chat/ai_chat_untrusted_conversation_ui.cc", + "webui/ai_chat/ai_chat_untrusted_conversation_ui.h", "webui/brave_adblock_internals_ui.cc", "webui/brave_adblock_internals_ui.h", "webui/brave_adblock_ui.cc", @@ -763,7 +765,7 @@ source_set("ui") { "//brave/components/ai_chat/core/browser", "//brave/components/ai_chat/core/common", "//brave/components/ai_chat/core/common/mojom", - "//brave/components/ai_chat/resources/page:generated_resources", + "//brave/components/ai_chat/resources", "//brave/components/ai_rewriter/common/buildflags", "//brave/components/brave_adaptive_captcha", "//brave/components/brave_adblock_ui:generated_resources", diff --git a/browser/ui/brave_pages.cc b/browser/ui/brave_pages.cc index f5e8351be9fb..279169832115 100644 --- a/browser/ui/brave_pages.cc +++ b/browser/ui/brave_pages.cc @@ -50,7 +50,7 @@ void ShowFullpageChat(Browser* browser) { if (!ai_chat::features::IsAIChatHistoryEnabled()) { return; } - ShowSingletonTabOverwritingNTP(browser, GURL(kChatUIURL)); + ShowSingletonTabOverwritingNTP(browser, GURL(kAIChatUIURL)); } void ShowWebcompatReporter(Browser* browser) { diff --git a/browser/ui/views/side_panel/brave_side_panel_utils.cc b/browser/ui/views/side_panel/brave_side_panel_utils.cc index 9057767d2250..9be0e0c85032 100644 --- a/browser/ui/views/side_panel/brave_side_panel_utils.cc +++ b/browser/ui/views/side_panel/brave_side_panel_utils.cc @@ -3,9 +3,9 @@ * License, v. 2.0. If a copy of the MPL was not distributed with this file, * You can obtain one at https://mozilla.org/MPL/2.0/. */ +#include "brave/browser/ai_chat/ai_chat_urls.h" #include "brave/browser/ui/webui/ai_chat/ai_chat_ui.h" #include "brave/components/ai_chat/core/browser/utils.h" -#include "brave/components/constants/webui_url_constants.h" #include "chrome/browser/profiles/profile.h" #include "chrome/browser/ui/views/side_panel/side_panel_registry.h" #include "chrome/browser/ui/views/side_panel/side_panel_web_ui_view.h" @@ -28,7 +28,7 @@ std::unique_ptr CreateAIChatSidePanelWebView( auto web_view = std::make_unique>( scope, base::RepeatingClosure(), base::RepeatingClosure(), std::make_unique>( - GURL(kChatUIURL), profile.get(), + ai_chat::TabAssociatedConversationUrl(), profile.get(), IDS_SIDEBAR_CHAT_SUMMARIZER_ITEM_TITLE, /*esc_closes_ui=*/false)); web_view->ShowUI(); diff --git a/browser/ui/webui/ai_chat/ai_chat_ui.cc b/browser/ui/webui/ai_chat/ai_chat_ui.cc index a6e48ebe79ac..0c3d40894c54 100644 --- a/browser/ui/webui/ai_chat/ai_chat_ui.cc +++ b/browser/ui/webui/ai_chat/ai_chat_ui.cc @@ -17,17 +17,22 @@ #include "brave/components/ai_chat/core/common/features.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" #include "brave/components/ai_chat/core/common/pref_names.h" -#include "brave/components/ai_chat/resources/page/grit/ai_chat_ui_generated_map.h" +#include "brave/components/ai_chat/resources/grit/ai_chat_ui_generated_map.h" #include "brave/components/constants/webui_url_constants.h" #include "brave/components/l10n/common/localization_util.h" #include "chrome/browser/profiles/profile.h" +#include "chrome/browser/ui/tabs/tab_model.h" +#include "chrome/browser/ui/webui/favicon_source.h" #include "chrome/browser/ui/webui/webui_util.h" +#include "components/favicon_base/favicon_url_parser.h" #include "components/grit/brave_components_resources.h" #include "components/prefs/pref_service.h" #include "components/user_prefs/user_prefs.h" #include "content/public/browser/web_contents.h" +#include "content/public/browser/web_ui_controller.h" #include "content/public/browser/web_ui_data_source.h" #include "content/public/common/url_constants.h" +#include "ui/webui/mojo_web_ui_controller.h" #if !BUILDFLAG(IS_ANDROID) #include "chrome/browser/ui/browser.h" @@ -57,65 +62,51 @@ content::WebContents* GetActiveWebContents(content::BrowserContext* context) { #endif AIChatUI::AIChatUI(content::WebUI* web_ui) - : ui::UntrustedWebUIController(web_ui), - profile_(Profile::FromWebUI(web_ui)) { + : ui::MojoWebUIController(web_ui), profile_(Profile::FromWebUI(web_ui)) { DCHECK(profile_); DCHECK(profile_->IsRegularProfile()); // Create a URLDataSource and add resources. - content::WebUIDataSource* untrusted_source = - content::WebUIDataSource::CreateAndAdd( - web_ui->GetWebContents()->GetBrowserContext(), kChatUIURL); + content::WebUIDataSource* source = content::WebUIDataSource::CreateAndAdd( + web_ui->GetWebContents()->GetBrowserContext(), kAIChatUIHost); - webui::SetupWebUIDataSource( - untrusted_source, - UNSAFE_TODO(base::make_span(kAiChatUiGenerated, kAiChatUiGeneratedSize)), - IDR_CHAT_UI_HTML); + webui::SetupWebUIDataSource(source, kAiChatUiGenerated, IDR_AI_CHAT_UI_HTML); - untrusted_source->AddResourcePath("styles.css", IDR_CHAT_UI_CSS); + source->AddResourcePath("styles.css", IDR_AI_CHAT_UI_CSS); for (const auto& str : ai_chat::GetLocalizedStrings()) { - untrusted_source->AddString( - str.name, brave_l10n::GetLocalizedResourceUTF16String(str.id)); + source->AddString(str.name, + brave_l10n::GetLocalizedResourceUTF16String(str.id)); } - base::Time last_accepted_disclaimer = - profile_->GetOriginalProfile()->GetPrefs()->GetTime( - ai_chat::prefs::kLastAcceptedDisclaimer); + constexpr bool kIsMobile = BUILDFLAG(IS_ANDROID) || BUILDFLAG(IS_IOS); + source->AddBoolean("isMobile", kIsMobile); + source->AddBoolean("isHistoryEnabled", + ai_chat::features::IsAIChatHistoryEnabled()); - untrusted_source->AddBoolean("hasAcceptedAgreement", - !last_accepted_disclaimer.is_null()); - -#if BUILDFLAG(IS_ANDROID) || BUILDFLAG(IS_IOS) - constexpr bool kIsMobile = true; -#else - constexpr bool kIsMobile = false; -#endif - - untrusted_source->AddBoolean("isMobile", kIsMobile); - untrusted_source->AddBoolean("isHistoryEnabled", - ai_chat::features::IsAIChatHistoryEnabled()); - - untrusted_source->AddBoolean( - "hasUserDismissedPremiumPrompt", - profile_->GetOriginalProfile()->GetPrefs()->GetBoolean( - ai_chat::prefs::kUserDismissedPremiumPrompt)); - - untrusted_source->OverrideContentSecurityPolicy( + web_ui->AddRequestableScheme(content::kChromeUIUntrustedScheme); + source->OverrideContentSecurityPolicy( network::mojom::CSPDirectiveName::ScriptSrc, - "script-src 'self' chrome-untrusted://resources;"); - untrusted_source->OverrideContentSecurityPolicy( + "script-src 'self' chrome://resources;"); + source->OverrideContentSecurityPolicy( network::mojom::CSPDirectiveName::StyleSrc, - "style-src 'self' 'unsafe-inline' chrome-untrusted://resources;"); - untrusted_source->OverrideContentSecurityPolicy( + "style-src 'self' 'unsafe-inline' chrome://resources;"); + source->OverrideContentSecurityPolicy( network::mojom::CSPDirectiveName::ImgSrc, - "img-src 'self' blob: chrome-untrusted://resources;"); - untrusted_source->OverrideContentSecurityPolicy( + "img-src 'self' blob: chrome://resources chrome://favicon2;"); + source->OverrideContentSecurityPolicy( network::mojom::CSPDirectiveName::FontSrc, - "font-src 'self' data: chrome-untrusted://resources;"); + "font-src 'self' chrome://resources;"); + source->OverrideContentSecurityPolicy( + network::mojom::CSPDirectiveName::ChildSrc, + base::StringPrintf("child-src %s;", kAIChatUntrustedConversationUIURL)); - untrusted_source->OverrideContentSecurityPolicy( + source->OverrideContentSecurityPolicy( network::mojom::CSPDirectiveName::TrustedTypes, "trusted-types default;"); + + content::URLDataSource::Add( + profile_, std::make_unique( + profile_, chrome::FaviconUrlFormat::kFavicon2)); } AIChatUI::~AIChatUI() = default; @@ -128,21 +119,26 @@ void AIChatUI::BindInterface( if (embedder_) { embedder_->ShowUI(); } - + // Get the WebContents which SidePanel mode should be associated with content::WebContents* web_contents = nullptr; #if !BUILDFLAG(IS_ANDROID) Browser* browser = ai_chat::GetBrowserForWebContents(web_ui()->GetWebContents()); - if (!browser) { - return; + if (browser) { + TabStripModel* tab_strip_model = browser->tab_strip_model(); + if (tab_strip_model) { + // If this WebUI is a main tab, we never want to be associated with + // the active tab + if (tab_strip_model->GetIndexOfWebContents(web_ui()->GetWebContents()) == + TabStripModel::kNoTab) { + web_contents = tab_strip_model->GetActiveWebContents(); + } + } } - - TabStripModel* tab_strip_model = browser->tab_strip_model(); - DCHECK(tab_strip_model); - web_contents = tab_strip_model->GetActiveWebContents(); #else web_contents = GetActiveWebContents(profile_); #endif + // Don't associate with the WebUI's WebContents if (web_contents == web_ui()->GetWebContents()) { web_contents = nullptr; } @@ -156,28 +152,34 @@ void AIChatUI::BindInterface( std::move(receiver)); } -bool UntrustedChatUIConfig::IsWebUIEnabled( - content::BrowserContext* browser_context) { +void AIChatUI::BindInterface( + mojo::PendingReceiver + parent_ui_frame_receiver) { + CHECK(page_handler_); + page_handler_->BindParentUIFrameFromChildFrame( + std::move(parent_ui_frame_receiver)); +} + +bool AIChatUIConfig::IsWebUIEnabled(content::BrowserContext* browser_context) { return ai_chat::IsAIChatEnabled( user_prefs::UserPrefs::Get(browser_context)) && Profile::FromBrowserContext(browser_context)->IsRegularProfile(); } #if BUILDFLAG(IS_ANDROID) -std::unique_ptr -UntrustedChatUIConfig::CreateWebUIController(content::WebUI* web_ui, - const GURL& url) { +std::unique_ptr AIChatUIConfig::CreateWebUIController( + content::WebUI* web_ui, + const GURL& url) { return std::make_unique(web_ui); } #endif // #if BUILDFLAG(IS_ANDROID) #if !BUILDFLAG(IS_ANDROID) -UntrustedChatUIConfig::UntrustedChatUIConfig() - : DefaultTopChromeWebUIConfig(content::kChromeUIUntrustedScheme, - kChatUIHost) {} +AIChatUIConfig::AIChatUIConfig() + : DefaultTopChromeWebUIConfig(content::kChromeUIScheme, kAIChatUIHost) {} #else -UntrustedChatUIConfig::UntrustedChatUIConfig() - : WebUIConfig(content::kChromeUIUntrustedScheme, kChatUIHost) {} +AIChatUIConfig::AIChatUIConfig() + : WebUIConfig(content::kChromeUIScheme, kAIChatUIHost) {} #endif // #if !BUILDFLAG(IS_ANDROID) WEB_UI_CONTROLLER_TYPE_IMPL(AIChatUI) diff --git a/browser/ui/webui/ai_chat/ai_chat_ui.h b/browser/ui/webui/ai_chat/ai_chat_ui.h index 3dfab0af575b..67c26ed7403a 100644 --- a/browser/ui/webui/ai_chat/ai_chat_ui.h +++ b/browser/ui/webui/ai_chat/ai_chat_ui.h @@ -9,6 +9,7 @@ #include #include +#include "brave/browser/ui/webui/ai_chat/ai_chat_ui_page_handler.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" #include "chrome/browser/ui/webui/top_chrome/top_chrome_web_ui_controller.h" #include "content/public/browser/web_ui_controller.h" @@ -27,7 +28,7 @@ class BrowserContext; class Profile; -class AIChatUI : public ui::UntrustedWebUIController { +class AIChatUI : public ui::MojoWebUIController { public: explicit AIChatUI(content::WebUI* web_ui); AIChatUI(const AIChatUI&) = delete; @@ -37,6 +38,8 @@ class AIChatUI : public ui::UntrustedWebUIController { void BindInterface( mojo::PendingReceiver receiver); void BindInterface(mojo::PendingReceiver receiver); + void BindInterface(mojo::PendingReceiver + parent_ui_frame_receiver); // Set by WebUIContentsWrapperT. TopChromeWebUIController provides default // implementation for this but we don't use it. @@ -48,7 +51,7 @@ class AIChatUI : public ui::UntrustedWebUIController { static constexpr std::string GetWebUIName() { return "AIChatPanel"; } private: - std::unique_ptr page_handler_; + std::unique_ptr page_handler_; base::WeakPtr embedder_; raw_ptr profile_ = nullptr; @@ -57,13 +60,13 @@ class AIChatUI : public ui::UntrustedWebUIController { }; #if !BUILDFLAG(IS_ANDROID) -class UntrustedChatUIConfig : public DefaultTopChromeWebUIConfig { +class AIChatUIConfig : public DefaultTopChromeWebUIConfig { #else -class UntrustedChatUIConfig : public content::WebUIConfig { +class AIChatUIConfig : public content::WebUIConfig { #endif // #if !BUILDFLAG(IS_ANDROID) public: - UntrustedChatUIConfig(); - ~UntrustedChatUIConfig() override = default; + AIChatUIConfig(); + ~AIChatUIConfig() override = default; bool IsWebUIEnabled(content::BrowserContext* browser_context) override; diff --git a/browser/ui/webui/ai_chat/ai_chat_ui_page_handler.cc b/browser/ui/webui/ai_chat/ai_chat_ui_page_handler.cc index 0d059503168e..9a9585d93109 100644 --- a/browser/ui/webui/ai_chat/ai_chat_ui_page_handler.cc +++ b/browser/ui/webui/ai_chat/ai_chat_ui_page_handler.cc @@ -11,6 +11,7 @@ #include #include "brave/browser/ai_chat/ai_chat_service_factory.h" +#include "brave/browser/ai_chat/ai_chat_urls.h" #include "brave/browser/ui/side_panel/ai_chat/ai_chat_side_panel_utils.h" #include "brave/components/ai_chat/core/browser/ai_chat_service.h" #include "brave/components/ai_chat/core/browser/constants.h" @@ -38,8 +39,7 @@ namespace { constexpr uint32_t kDesiredFaviconSizePixels = 32; constexpr char kURLRefreshPremiumSession[] = "https://account.brave.com/?intent=recover&product=leo"; -constexpr char kURLLearnMoreBraveSearchLeo[] = - "https://support.brave.com/hc/en-us/categories/20990938292237-Brave-Leo"; + #if !BUILDFLAG(IS_ANDROID) constexpr char kURLGoPremium[] = "https://account.brave.com/account/?intent=checkout&product=leo"; @@ -96,7 +96,7 @@ void AIChatUIPageHandler::OpenAIChatSettings() { (active_chat_tab_helper_) ? active_chat_tab_helper_->web_contents() : owner_web_contents_.get(); #if !BUILDFLAG(IS_ANDROID) - const GURL url("brave://settings/leo-assistant"); + const GURL url("brave://settings/leo-ai"); if (auto* browser = chrome::FindBrowserWithTab(contents_to_navigate)) { ShowSingletonTab(browser, url); } else { @@ -116,7 +116,7 @@ void AIChatUIPageHandler::OpenConversationFullPage( CHECK(active_chat_tab_helper_); active_chat_tab_helper_->web_contents()->OpenURL( { - GURL(kChatUIURL).Resolve(conversation_uuid), + ConversationUrl(conversation_uuid), content::Referrer(), WindowOpenDisposition::NEW_FOREGROUND_TAB, ui::PAGE_TRANSITION_TYPED, @@ -172,10 +172,6 @@ void AIChatUIPageHandler::ManagePremium() { #endif } -void AIChatUIPageHandler::OpenLearnMoreAboutBraveSearchWithLeo() { - OpenURL(GURL(kURLLearnMoreBraveSearchLeo)); -} - void AIChatUIPageHandler::OpenModelSupportUrl() { OpenURL(GURL(kLeoModelSupportUrl)); } @@ -205,10 +201,10 @@ void AIChatUIPageHandler::CloseUI() { #endif } -void AIChatUIPageHandler::SetChatUI( - mojo::PendingRemote chat_ui) { +void AIChatUIPageHandler::SetChatUI(mojo::PendingRemote chat_ui, + SetChatUICallback callback) { chat_ui_.Bind(std::move(chat_ui)); - chat_ui_->SetInitialData(active_chat_tab_helper_ == nullptr); + std::move(callback).Run(active_chat_tab_helper_ == nullptr); } void AIChatUIPageHandler::BindRelatedConversation( @@ -264,6 +260,11 @@ void AIChatUIPageHandler::GetFaviconImageData( weak_ptr_factory_.GetWeakPtr(), std::move(callback))); } +void AIChatUIPageHandler::BindParentUIFrameFromChildFrame( + mojo::PendingReceiver receiver) { + chat_ui_->OnChildFrameBound(std::move(receiver)); +} + void AIChatUIPageHandler::GetFaviconImageDataForAssociatedContent( GetFaviconImageDataCallback callback, mojom::SiteInfoPtr content_info, diff --git a/browser/ui/webui/ai_chat/ai_chat_ui_page_handler.h b/browser/ui/webui/ai_chat/ai_chat_ui_page_handler.h index 7fc3c89cbe93..81e4f6dfc4ed 100644 --- a/browser/ui/webui/ai_chat/ai_chat_ui_page_handler.h +++ b/browser/ui/webui/ai_chat/ai_chat_ui_page_handler.h @@ -47,14 +47,14 @@ class AIChatUIPageHandler : public mojom::AIChatUIHandler, void OpenAIChatSettings() override; void OpenConversationFullPage(const std::string& conversation_uuid) override; void OpenURL(const GURL& url) override; - void OpenLearnMoreAboutBraveSearchWithLeo() override; void OpenModelSupportUrl() override; void GoPremium() override; void RefreshPremiumSession() override; void ManagePremium() override; void HandleVoiceRecognition(const std::string& conversation_uuid) override; void CloseUI() override; - void SetChatUI(mojo::PendingRemote chat_ui) override; + void SetChatUI(mojo::PendingRemote chat_ui, + SetChatUICallback callback) override; void BindRelatedConversation( mojo::PendingReceiver receiver, mojo::PendingRemote conversation_ui_handler) @@ -66,6 +66,9 @@ class AIChatUIPageHandler : public mojom::AIChatUIHandler, void GetFaviconImageData(const std::string& conversation_id, GetFaviconImageDataCallback callback) override; + void BindParentUIFrameFromChildFrame( + mojo::PendingReceiver receiver); + private: class ChatContextObserver : public content::WebContentsObserver { public: diff --git a/browser/ui/webui/ai_chat/ai_chat_untrusted_conversation_ui.cc b/browser/ui/webui/ai_chat/ai_chat_untrusted_conversation_ui.cc new file mode 100644 index 000000000000..adaa05832dd5 --- /dev/null +++ b/browser/ui/webui/ai_chat/ai_chat_untrusted_conversation_ui.cc @@ -0,0 +1,215 @@ +// Copyright (c) 2024 The Brave Authors. All rights reserved. +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. + +#include "brave/browser/ui/webui/ai_chat/ai_chat_untrusted_conversation_ui.h" + +#include +#include + +#include "base/strings/escape.h" +#include "brave/browser/ai_chat/ai_chat_service_factory.h" +#include "brave/browser/ai_chat/ai_chat_urls.h" +#include "brave/browser/ui/side_panel/ai_chat/ai_chat_side_panel_utils.h" +#include "brave/browser/ui/webui/ai_chat/ai_chat_ui.h" +#include "brave/components/ai_chat/core/browser/ai_chat_service.h" +#include "brave/components/ai_chat/core/browser/constants.h" +#include "brave/components/ai_chat/core/browser/conversation_handler.h" +#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" +#include "brave/components/ai_chat/core/common/mojom/untrusted_frame.mojom.h" +#include "brave/components/ai_chat/resources/grit/ai_chat_ui_generated_map.h" +#include "brave/components/constants/webui_url_constants.h" +#include "brave/components/l10n/common/localization_util.h" +#include "chrome/browser/ui/webui/webui_util.h" +#include "components/grit/brave_components_resources.h" +#include "content/public/browser/render_frame_host.h" +#include "content/public/browser/web_contents.h" +#include "content/public/browser/web_ui.h" +#include "content/public/browser/web_ui_data_source.h" +#include "content/public/common/url_constants.h" + +#if BUILDFLAG(IS_ANDROID) +#include "brave/browser/ui/android/ai_chat/brave_leo_settings_launcher_helper.h" +#else +#include "chrome/browser/ui/browser.h" +#endif + +namespace { +constexpr char kURLLearnMoreBraveSearchLeo[] = + "https://support.brave.com/hc/en-us/categories/20990938292237-Brave-Leo"; + +// Implments the interface to calls from the UI to the browser +class UIHandler : public ai_chat::mojom::UntrustedUIHandler { + public: + UIHandler(content::WebUI* web_ui, + mojo::PendingReceiver receiver) + : web_ui_(web_ui), receiver_(this, std::move(receiver)) {} + UIHandler(const UIHandler&) = delete; + UIHandler& operator=(const UIHandler&) = delete; + + ~UIHandler() override = default; + + // ai_chat::mojom::UntrustedConversationUIHandler + void OpenLearnMoreAboutBraveSearchWithLeo() override { + if (!web_ui_->GetRenderFrameHost()->HasTransientUserActivation()) { + return; + } + OpenURL(GURL(kURLLearnMoreBraveSearchLeo)); + } + + void OpenSearchURL(const std::string& search_query) override { + if (!web_ui_->GetRenderFrameHost()->HasTransientUserActivation()) { + return; + } + OpenURL(GURL("https://search.brave.com/search?q=" + + base::EscapeQueryParamValue(search_query, true))); + } + + void BindParentPage(mojo::PendingReceiver + parent_ui_frame_receiver) override { + // Route the receiver to the parent frame + auto* rfh = web_ui_->GetWebContents()->GetPrimaryMainFrame(); + if (!rfh) { + return; + } + + // We should not be embedded on a non-WebUI page + CHECK(rfh->GetWebUI()); + + AIChatUI* ai_chat_ui_controller = + rfh->GetWebUI()->GetController()->GetAs(); + // We should not be embedded on any non AIChatUI page + CHECK(ai_chat_ui_controller); + + ai_chat_ui_controller->BindInterface(std::move(parent_ui_frame_receiver)); + } + + private: + void OpenURL(GURL url) { + if (!url.SchemeIs(url::kHttpsScheme)) { + return; + } + +#if !BUILDFLAG(IS_ANDROID) + Browser* browser = + ai_chat::GetBrowserForWebContents(web_ui_->GetWebContents()); + browser->OpenURL( + {url, content::Referrer(), WindowOpenDisposition::NEW_FOREGROUND_TAB, + ui::PAGE_TRANSITION_LINK, false}, + /*navigation_handle_callback=*/{}); +#else + // We handle open link different on Android as we need to close the chat + // window because it's always full screen + ai_chat::OpenURL(url.spec()); +#endif + } + + raw_ptr web_ui_ = nullptr; + mojo::Receiver receiver_; +}; + +} // namespace + +bool AIChatUntrustedConversationUIConfig::IsWebUIEnabled( + content::BrowserContext* browser_context) { + // Only enabled if we have a valid service + return (ai_chat::AIChatServiceFactory::GetForBrowserContext( + browser_context) != nullptr); +} + +std::unique_ptr +AIChatUntrustedConversationUIConfig::CreateWebUIController( + content::WebUI* web_ui, + const GURL& url) { + return std::make_unique(web_ui); +} + +AIChatUntrustedConversationUIConfig::AIChatUntrustedConversationUIConfig() + : WebUIConfig(content::kChromeUIUntrustedScheme, + kAIChatUntrustedConversationUIHost) {} + +AIChatUntrustedConversationUIConfig::~AIChatUntrustedConversationUIConfig() = + default; + +AIChatUntrustedConversationUI::AIChatUntrustedConversationUI( + content::WebUI* web_ui) + : ui::MojoWebUIController(web_ui) { + // Create a URLDataSource and add resources. + content::WebUIDataSource* source = content::WebUIDataSource::CreateAndAdd( + web_ui->GetWebContents()->GetBrowserContext(), + kAIChatUntrustedConversationUIURL); + webui::SetupWebUIDataSource(source, kAiChatUiGenerated, + IDR_AI_CHAT_UNTRUSTED_CONVERSATION_UI_HTML); + + for (const auto& str : ai_chat::GetLocalizedStrings()) { + source->AddString(str.name, + brave_l10n::GetLocalizedResourceUTF16String(str.id)); + } + + constexpr bool kIsMobile = BUILDFLAG(IS_ANDROID) || BUILDFLAG(IS_IOS); + source->AddBoolean("isMobile", kIsMobile); + + source->OverrideContentSecurityPolicy( + network::mojom::CSPDirectiveName::ScriptSrc, + "script-src 'self' chrome-untrusted://resources;"); + source->OverrideContentSecurityPolicy( + network::mojom::CSPDirectiveName::StyleSrc, + "style-src 'self' 'unsafe-inline' chrome-untrusted://resources;"); + source->OverrideContentSecurityPolicy( + network::mojom::CSPDirectiveName::ImgSrc, + "img-src 'self' blob: chrome-untrusted://resources;"); + source->OverrideContentSecurityPolicy( + network::mojom::CSPDirectiveName::FontSrc, + "font-src 'self' chrome-untrusted://resources;"); + source->OverrideContentSecurityPolicy( + network::mojom::CSPDirectiveName::FrameAncestors, + base::StringPrintf("frame-ancestors %s;", kAIChatUIURL)); + source->OverrideContentSecurityPolicy( + network::mojom::CSPDirectiveName::TrustedTypes, "trusted-types default;"); +} + +AIChatUntrustedConversationUI::~AIChatUntrustedConversationUI() = default; + +void AIChatUntrustedConversationUI::BindInterface( + mojo::PendingReceiver receiver) { + ui_handler_ = std::make_unique(web_ui(), std::move(receiver)); +} + +void AIChatUntrustedConversationUI::BindInterface( + mojo::PendingReceiver + receiver) { + // Get conversation from URL + std::string_view conversation_uuid = ai_chat::ConversationUUIDFromURL( + web_ui()->GetRenderFrameHost()->GetLastCommittedURL()); + DVLOG(2) << "Binding conversation frame for conversation uuid:" + << conversation_uuid; + if (conversation_uuid.empty()) { + return; + } + + ai_chat::AIChatService* service = + ai_chat::AIChatServiceFactory::GetForBrowserContext( + web_ui()->GetWebContents()->GetBrowserContext()); + + if (!service) { + return; + } + + service->GetConversation( + conversation_uuid, + base::BindOnce( + [](mojo::PendingReceiver + receiver, + ai_chat::ConversationHandler* conversation_handler) { + if (!conversation_handler) { + DVLOG(0) << "Failed to get conversation handler for conversation " + "entries frame"; + return; + } + conversation_handler->Bind(std::move(receiver)); + }, + std::move(receiver))); +} + +WEB_UI_CONTROLLER_TYPE_IMPL(AIChatUntrustedConversationUI) diff --git a/browser/ui/webui/ai_chat/ai_chat_untrusted_conversation_ui.h b/browser/ui/webui/ai_chat/ai_chat_untrusted_conversation_ui.h new file mode 100644 index 000000000000..b633d4968375 --- /dev/null +++ b/browser/ui/webui/ai_chat/ai_chat_untrusted_conversation_ui.h @@ -0,0 +1,53 @@ +// Copyright (c) 2024 The Brave Authors. All rights reserved. +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. + +#ifndef BRAVE_BROWSER_UI_WEBUI_AI_CHAT_AI_CHAT_UNTRUSTED_CONVERSATION_UI_H_ +#define BRAVE_BROWSER_UI_WEBUI_AI_CHAT_AI_CHAT_UNTRUSTED_CONVERSATION_UI_H_ + +#include + +#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-forward.h" +#include "brave/components/ai_chat/core/common/mojom/untrusted_frame.mojom-forward.h" +#include "content/public/browser/webui_config.h" +#include "ui/webui/mojo_web_ui_controller.h" + +// Determines in which context the untrusted conversation UI should be enabled +class AIChatUntrustedConversationUIConfig : public content::WebUIConfig { + public: + AIChatUntrustedConversationUIConfig(); + ~AIChatUntrustedConversationUIConfig() override; + + // content::WebUIConfig: + bool IsWebUIEnabled(content::BrowserContext* browser_context) override; + std::unique_ptr CreateWebUIController( + content::WebUI* web_ui, + const GURL& url) override; +}; + +// This Untrusted WebUI hosts the UI to display conversation entries, including +// ones generated by an LLM. It should not be granted more permissions than +// required to display the conversation entries. Anything requiring access +// to browser features should be done in the trusted UI. +class AIChatUntrustedConversationUI : public ui::MojoWebUIController { + public: + explicit AIChatUntrustedConversationUI(content::WebUI* web_ui); + AIChatUntrustedConversationUI(const AIChatUntrustedConversationUI&) = delete; + AIChatUntrustedConversationUI& operator=( + const AIChatUntrustedConversationUI&) = delete; + ~AIChatUntrustedConversationUI() override; + + void BindInterface( + mojo::PendingReceiver receiver); + void BindInterface( + mojo::PendingReceiver + receiver); + + private: + std::unique_ptr ui_handler_; + + WEB_UI_CONTROLLER_TYPE_DECL(); +}; + +#endif // BRAVE_BROWSER_UI_WEBUI_AI_CHAT_AI_CHAT_UNTRUSTED_CONVERSATION_UI_H_ diff --git a/browser/ui/webui/settings/brave_settings_leo_assistant_handler.cc b/browser/ui/webui/settings/brave_settings_leo_assistant_handler.cc index 0697b5a369a7..318f8b3c8d13 100644 --- a/browser/ui/webui/settings/brave_settings_leo_assistant_handler.cc +++ b/browser/ui/webui/settings/brave_settings_leo_assistant_handler.cc @@ -10,10 +10,13 @@ #include #include "base/containers/contains.h" +#include "brave/browser/ai_chat/ai_chat_service_factory.h" #include "brave/browser/brave_browser_process.h" #include "brave/browser/misc_metrics/process_misc_metrics.h" #include "brave/browser/ui/sidebar/sidebar_service_factory.h" #include "brave/components/ai_chat/core/browser/ai_chat_metrics.h" +#include "brave/components/ai_chat/core/browser/ai_chat_service.h" +#include "brave/components/ai_chat/core/browser/utils.h" #include "brave/components/ai_chat/core/common/pref_names.h" #include "brave/components/sidebar/browser/sidebar_item.h" #include "brave/components/sidebar/browser/sidebar_service.h" @@ -141,13 +144,20 @@ void BraveLeoAssistantHandler::HandleGetLeoIconVisibility( void BraveLeoAssistantHandler::HandleResetLeoData( const base::Value::List& args) { - auto* service = sidebar::SidebarServiceFactory::GetForProfile(profile_); + auto* sidebar_service = + sidebar::SidebarServiceFactory::GetForProfile(profile_); + + ShowLeoAssistantIconVisibleIfNot(sidebar_service); - ShowLeoAssistantIconVisibleIfNot(service); - profile_->GetPrefs()->ClearPref(ai_chat::prefs::kLastAcceptedDisclaimer); - g_brave_browser_process->process_misc_metrics() - ->ai_chat_metrics() - ->RecordReset(); + ai_chat::AIChatService* service = + ai_chat::AIChatServiceFactory::GetForBrowserContext(profile_); + if (!service) { + return; + } + service->DeleteConversations(); + if (profile_) { + ai_chat::SetUserOptedIn(profile_->GetPrefs(), false); + } AllowJavascript(); } diff --git a/browser/ui/webui/settings/brave_settings_localized_strings_provider.cc b/browser/ui/webui/settings/brave_settings_localized_strings_provider.cc index e3c2fdf5d762..f78b6e99cbdf 100644 --- a/browser/ui/webui/settings/brave_settings_localized_strings_provider.cc +++ b/browser/ui/webui/settings/brave_settings_localized_strings_provider.cc @@ -422,16 +422,19 @@ void BraveAddCommonStrings(content::WebUIDataSource* html_source, IDS_SETTINGS_LEO_ASSISTANT_SHOW_IN_CONTEXT_MENU_LABEL}, {"braveLeoAssistantShowInContextMenuDesc", IDS_SETTINGS_LEO_ASSISTANT_SHOW_IN_CONTEXT_MENU_DESC}, + {"braveLeoAssistantHistoryPreferenceLabel", + IDS_SETTINGS_LEO_ASSISTANT_HISTORY_PREFERENCE_LABEL}, + {"braveLeoAssistantHistoryPreferenceConfirm", + IDS_SETTINGS_LEO_ASSISTANT_HISTORY_PREFERENCE_CONFIRM}, {"braveLeoAssistantResetAndClearDataLabel", IDS_SETTINGS_LEO_ASSISTANT_RESET_AND_CLEAR_DATA_LABEL}, {"braveLeoAssistantResetAndClearDataConfirmationText", IDS_SETTINGS_LEO_ASSISTANT_RESET_AND_CLEAR_DATA_CONFIRMATION_LABEL}, {"braveLeoAssistantAutocompleteLink", IDS_SETTINGS_LEO_ASSISTANT_AUTOCOMPLETE_LINK}, - {"leoClearHistoryData", - IDS_SETTINGS_LEO_ASSISTANT_CLEAR_HISTORY_DATA_LABEL}, - {"leoClearHistoryDataSubLabel", - IDS_SETTINGS_LEO_ASSISTANT_CLEAR_HISTORY_DATA_SUBLABEL}, + {"aiChatClearHistoryData", IDS_SETTINGS_AI_CHAT_CLEAR_HISTORY_DATA_LABEL}, + {"aiChatClearHistoryDataSubLabel", + IDS_SETTINGS_AI_CHAT_CLEAR_HISTORY_DATA_SUBLABEL}, {"braveLeoPremiumLabelNonPremium", IDS_CHAT_UI_MODEL_PREMIUM_LABEL_NON_PREMIUM}, {"braveLeoAssistantModelSelectionLabel", diff --git a/chromium_src/chrome/browser/autocomplete/chrome_autocomplete_provider_client.cc b/chromium_src/chrome/browser/autocomplete/chrome_autocomplete_provider_client.cc index daffe4578adf..a1ba9fb5acec 100644 --- a/chromium_src/chrome/browser/autocomplete/chrome_autocomplete_provider_client.cc +++ b/chromium_src/chrome/browser/autocomplete/chrome_autocomplete_provider_client.cc @@ -6,13 +6,19 @@ #include "src/chrome/browser/autocomplete/chrome_autocomplete_provider_client.cc" #include "brave/browser/ai_chat/ai_chat_service_factory.h" +#include "brave/browser/ai_chat/ai_chat_urls.h" #include "brave/components/ai_chat/content/browser/ai_chat_tab_helper.h" #include "brave/components/ai_chat/core/browser/ai_chat_metrics.h" #include "brave/components/ai_chat/core/browser/ai_chat_service.h" +#include "brave/components/ai_chat/core/browser/conversation_handler.h" +#include "brave/components/ai_chat/core/common/features.h" #include "brave/components/ai_chat/core/common/pref_names.h" #include "brave/components/commander/common/buildflags/buildflags.h" #include "build/build_config.h" #include "chrome/browser/profiles/profile.h" +#include "content/public/browser/web_contents.h" +#include "ui/base/page_transition_types.h" +#include "ui/base/window_open_disposition.h" #if !BUILDFLAG(IS_ANDROID) #include "brave/browser/brave_browser_process.h" @@ -57,30 +63,47 @@ void ChromeAutocompleteProviderClient::OpenLeo(const std::u16string& query) { return; } - auto* chat_tab_helper = ai_chat::AIChatTabHelper::FromWebContents( - browser->tab_strip_model()->GetActiveWebContents()); - DCHECK(chat_tab_helper); - - auto* conversation_handler = - ai_chat_service->GetOrCreateConversationHandlerForContent( - chat_tab_helper->GetContentId(), chat_tab_helper->GetWeakPtr()); - CHECK(conversation_handler); - - // Before trying to activate the panel, unlink page content if needed. - // This needs to be called before activating the panel to check against the - // current state. - conversation_handler->MaybeUnlinkAssociatedContent(); + ai_chat::ConversationHandler* conversation_handler; + + if (ai_chat_service->IsAIChatHistoryEnabled() && + ai_chat::features::kOmniboxOpensFullPage.Get()) { + conversation_handler = ai_chat_service->CreateConversation(); + browser->OpenURL({ai_chat::ConversationUrl( + conversation_handler->get_conversation_uuid()), + content::Referrer(), WindowOpenDisposition::CURRENT_TAB, + ui::PageTransition::PAGE_TRANSITION_GENERATED, false}, + {}); + } else { + auto* chat_tab_helper = ai_chat::AIChatTabHelper::FromWebContents( + browser->tab_strip_model()->GetActiveWebContents()); + DCHECK(chat_tab_helper); + conversation_handler = + ai_chat_service->GetOrCreateConversationHandlerForContent( + chat_tab_helper->GetContentId(), chat_tab_helper->GetWeakPtr()); + if (!conversation_handler) { + return; + } + + // Before trying to activate the panel, unlink page content if needed. + // This needs to be called before activating the panel to check against the + // current state. + conversation_handler->MaybeUnlinkAssociatedContent(); + + // Activate the panel. + auto* sidebar_controller = + static_cast(browser)->sidebar_controller(); + sidebar_controller->ActivatePanelItem( + sidebar::SidebarItem::BuiltInItemType::kChatUI); + } - // Activate the panel. - auto* sidebar_controller = - static_cast(browser)->sidebar_controller(); - sidebar_controller->ActivatePanelItem( - sidebar::SidebarItem::BuiltInItemType::kChatUI); + if (!conversation_handler) { + return; + } // Send the query to the AIChat's backend. ai_chat::mojom::ConversationTurnPtr turn = ai_chat::mojom::ConversationTurn::New( - ai_chat::mojom::CharacterType::HUMAN, + std::nullopt, ai_chat::mojom::CharacterType::HUMAN, ai_chat::mojom::ActionType::QUERY, ai_chat::mojom::ConversationTurnVisibility::VISIBLE, base::UTF16ToUTF8(query) /* text */, std::nullopt /* selected_text */, diff --git a/chromium_src/chrome/browser/ui/webui/chrome_untrusted_web_ui_configs.cc b/chromium_src/chrome/browser/ui/webui/chrome_untrusted_web_ui_configs.cc index 6c2a2a6824be..b0b5a6593d8b 100644 --- a/chromium_src/chrome/browser/ui/webui/chrome_untrusted_web_ui_configs.cc +++ b/chromium_src/chrome/browser/ui/webui/chrome_untrusted_web_ui_configs.cc @@ -6,7 +6,7 @@ #include "chrome/browser/ui/webui/chrome_untrusted_web_ui_configs.h" #include "base/feature_list.h" -#include "brave/browser/ui/webui/ai_chat/ai_chat_ui.h" +#include "brave/browser/ui/webui/ai_chat/ai_chat_untrusted_conversation_ui.h" #include "brave/browser/ui/webui/brave_wallet/ledger/ledger_ui.h" #include "brave/browser/ui/webui/brave_wallet/line_chart/line_chart_ui.h" #include "brave/browser/ui/webui/brave_wallet/market/market_ui.h" @@ -68,6 +68,6 @@ void RegisterChromeUntrustedWebUIConfigs() { if (ai_chat::features::IsAIChatEnabled()) { content::WebUIConfigMap::GetInstance().AddUntrustedWebUIConfig( - std::make_unique()); + std::make_unique()); } } diff --git a/chromium_src/chrome/browser/ui/webui/chrome_web_ui_configs.cc b/chromium_src/chrome/browser/ui/webui/chrome_web_ui_configs.cc index d590d2ed34e7..91e43134b001 100644 --- a/chromium_src/chrome/browser/ui/webui/chrome_web_ui_configs.cc +++ b/chromium_src/chrome/browser/ui/webui/chrome_web_ui_configs.cc @@ -5,6 +5,8 @@ #include "chrome/browser/ui/webui/chrome_web_ui_configs.h" +#include "brave/browser/ui/webui/ai_chat/ai_chat_ui.h" +#include "brave/components/ai_chat/core/common/features.h" #include "content/public/browser/webui_config_map.h" #define RegisterChromeWebUIConfigs RegisterChromeWebUIConfigs_ChromiumImpl @@ -79,4 +81,8 @@ void RegisterChromeWebUIConfigs() { #endif // !BUILDFLAG(IS_ANDROID) map.AddWebUIConfig(std::make_unique()); map.AddWebUIConfig(std::make_unique()); + + if (ai_chat::features::IsAIChatEnabled()) { + map.AddWebUIConfig(std::make_unique()); + } } diff --git a/components/ai_chat/content/browser/DEPS b/components/ai_chat/content/browser/DEPS index 45f85fce9932..2a8c66b98131 100644 --- a/components/ai_chat/content/browser/DEPS +++ b/components/ai_chat/content/browser/DEPS @@ -9,4 +9,6 @@ include_rules = [ "+services/data_decoder/public", "+services/network/public", "+services/service_manager/public", + "+third_party/blink/public/common", + "+third_party/blink/public/mojom/permissions" ] diff --git a/components/ai_chat/content/browser/ai_chat_brave_search_throttle.cc b/components/ai_chat/content/browser/ai_chat_brave_search_throttle.cc index 921344850b89..8dae895a1951 100644 --- a/components/ai_chat/content/browser/ai_chat_brave_search_throttle.cc +++ b/components/ai_chat/content/browser/ai_chat_brave_search_throttle.cc @@ -5,14 +5,14 @@ #include "brave/components/ai_chat/content/browser/ai_chat_brave_search_throttle.h" -#include +#include +#include #include #include "base/check.h" #include "base/functional/bind.h" #include "brave/components/ai_chat/content/browser/ai_chat_tab_helper.h" -#include "brave/components/ai_chat/content/browser/page_content_fetcher.h" -#include "brave/components/ai_chat/core/browser/conversation_handler.h" +#include "brave/components/ai_chat/core/browser/ai_chat_service.h" #include "brave/components/ai_chat/core/browser/utils.h" #include "brave/components/ai_chat/core/common/features.h" #include "brave/components/ai_chat/core/common/utils.h" @@ -20,7 +20,15 @@ #include "content/public/browser/navigation_handle.h" #include "content/public/browser/permission_controller.h" #include "content/public/browser/permission_request_description.h" +#include "content/public/browser/permission_result.h" #include "content/public/browser/web_contents.h" +#include "third_party/blink/public/common/permissions/permission_utils.h" +#include "third_party/blink/public/mojom/permissions/permission_status.mojom-shared.h" +#include "url/gurl.h" + +namespace content { +class RenderFrameHost; +} // namespace content namespace ai_chat { diff --git a/components/ai_chat/content/browser/ai_chat_brave_search_throttle.h b/components/ai_chat/content/browser/ai_chat_brave_search_throttle.h index d9183e2d5aac..ddae1d4882e6 100644 --- a/components/ai_chat/content/browser/ai_chat_brave_search_throttle.h +++ b/components/ai_chat/content/browser/ai_chat_brave_search_throttle.h @@ -7,21 +7,31 @@ #define BRAVE_COMPONENTS_AI_CHAT_CONTENT_BROWSER_AI_CHAT_BRAVE_SEARCH_THROTTLE_H_ #include +#include #include +#include "base/functional/callback.h" #include "base/memory/raw_ptr.h" #include "base/memory/weak_ptr.h" #include "brave/components/ai_chat/core/browser/ai_chat_service.h" #include "content/public/browser/navigation_throttle.h" #include "content/public/browser/permission_result.h" +namespace blink { +namespace mojom { +enum class PermissionStatus : int32_t; +} // namespace mojom +} // namespace blink + namespace content { class WebContents; +class NavigationHandle; } class PrefService; namespace ai_chat { +class AIChatService; // A network throttle which intercepts Brave Search requests. // Currently the only use case is to intercept requests to open Leo AI chat, so diff --git a/components/ai_chat/content/browser/ai_chat_tab_helper.cc b/components/ai_chat/content/browser/ai_chat_tab_helper.cc index 12a143243567..d79992c0e213 100644 --- a/components/ai_chat/content/browser/ai_chat_tab_helper.cc +++ b/components/ai_chat/content/browser/ai_chat_tab_helper.cc @@ -5,16 +5,32 @@ #include "brave/components/ai_chat/content/browser/ai_chat_tab_helper.h" +#include #include +#include +#include #include +#include +#include #include -#include +#include +#include +#include +#include "base/check.h" +#include "base/containers/contains.h" #include "base/containers/fixed_flat_set.h" #include "base/functional/bind.h" +#include "base/functional/callback_forward.h" +#include "base/location.h" +#include "base/logging.h" +#include "base/memory/scoped_refptr.h" #include "base/memory/weak_ptr.h" -#include "base/ranges/algorithm.h" +#include "base/numerics/clamped_math.h" #include "base/strings/string_util.h" +#include "base/strings/utf_ostream_operators.h" +#include "base/task/sequenced_task_runner.h" +#include "base/time/time.h" #include "brave/components/ai_chat/content/browser/page_content_fetcher.h" #include "brave/components/ai_chat/content/browser/pdf_utils.h" #include "brave/components/ai_chat/core/browser/associated_content_driver.h" @@ -28,16 +44,29 @@ #include "content/public/browser/navigation_details.h" #include "content/public/browser/navigation_entry.h" #include "content/public/browser/permission_controller.h" -#include "content/public/browser/permission_request_description.h" #include "content/public/browser/permission_result.h" +#include "content/public/browser/render_frame_host.h" #include "content/public/browser/scoped_accessibility_mode.h" #include "content/public/browser/storage_partition.h" #include "content/public/browser/web_contents.h" +#include "mojo/public/cpp/bindings/pending_associated_receiver.h" #include "pdf/buildflags.h" +#include "third_party/blink/public/common/permissions/permission_utils.h" +#include "third_party/blink/public/mojom/permissions/permission_status.mojom-shared.h" +#include "ui/accessibility/ax_enums.mojom-shared.h" #include "ui/accessibility/ax_mode.h" +#include "ui/accessibility/ax_node_data.h" +#include "ui/accessibility/ax_tree_update.h" #include "ui/accessibility/ax_updates_and_events.h" #include "ui/base/l10n/l10n_util.h" +namespace favicon { +class FaviconDriver; +} // namespace favicon +namespace gfx { +class Image; +} // namespace gfx + namespace ai_chat { AIChatTabHelper::PDFA11yInfoLoadObserver::PDFA11yInfoLoadObserver( @@ -178,6 +207,7 @@ void AIChatTabHelper::TitleWasSet(content::NavigationEntry* entry) { << " title=" << entry->GetTitle(); MaybeSameDocumentIsNewPage(); previous_page_title_ = GetPageTitle(); + OnTitleChanged(); } void AIChatTabHelper::InnerWebContentsAttached( diff --git a/components/ai_chat/content/browser/ai_chat_tab_helper.h b/components/ai_chat/content/browser/ai_chat_tab_helper.h index 46a6147421e4..6caed8266a9e 100644 --- a/components/ai_chat/content/browser/ai_chat_tab_helper.h +++ b/components/ai_chat/content/browser/ai_chat_tab_helper.h @@ -6,12 +6,16 @@ #ifndef BRAVE_COMPONENTS_AI_CHAT_CONTENT_BROWSER_AI_CHAT_TAB_HELPER_H_ #define BRAVE_COMPONENTS_AI_CHAT_CONTENT_BROWSER_AI_CHAT_TAB_HELPER_H_ +#include #include #include +#include #include +#include "base/functional/callback.h" #include "base/functional/callback_forward.h" #include "base/memory/raw_ptr.h" +#include "base/memory/weak_ptr.h" #include "brave/components/ai_chat/core/browser/associated_content_driver.h" #include "brave/components/ai_chat/core/browser/conversation_handler.h" #include "brave/components/ai_chat/core/common/mojom/page_content_extractor.mojom.h" @@ -23,12 +27,32 @@ #include "content/public/browser/web_contents_user_data.h" #include "mojo/public/cpp/bindings/associated_receiver.h" #include "mojo/public/cpp/bindings/pending_associated_receiver.h" +#include "url/gurl.h" + +namespace favicon { +class FaviconDriver; +} // namespace favicon +namespace gfx { +class Image; +} // namespace gfx +namespace mojo { +template +class PendingAssociatedReceiver; +} // namespace mojo +namespace ui { +struct AXUpdatesAndEvents; +} // namespace ui namespace content { class ScopedAccessibilityMode; +class NavigationEntry; +class RenderFrameHost; +class WebContents; +struct LoadCommittedDetails; } class AIChatUIBrowserTest; + namespace ai_chat { class AIChatMetrics; diff --git a/components/ai_chat/content/browser/ai_chat_throttle.cc b/components/ai_chat/content/browser/ai_chat_throttle.cc index 2667b9264068..d0cef44790f5 100644 --- a/components/ai_chat/content/browser/ai_chat_throttle.cc +++ b/components/ai_chat/content/browser/ai_chat_throttle.cc @@ -5,43 +5,57 @@ #include "brave/components/ai_chat/content/browser/ai_chat_throttle.h" -#include +#include +#include #include "brave/components/ai_chat/core/browser/utils.h" #include "brave/components/ai_chat/core/common/features.h" #include "brave/components/constants/webui_url_constants.h" +#include "build/build_config.h" #include "components/user_prefs/user_prefs.h" #include "content/public/browser/browser_context.h" #include "content/public/browser/navigation_handle.h" #include "content/public/browser/web_contents.h" #include "content/public/common/url_constants.h" +#include "ui/base/page_transition_types.h" +#include "url/gurl.h" namespace ai_chat { // static std::unique_ptr AiChatThrottle::MaybeCreateThrottleFor( content::NavigationHandle* navigation_handle) { - // The AI Chat WebUI won't be enabled if the feature is disabled + // The throttle's only purpose is to deny navigation in a Tab. + + // The AI Chat WebUI won't be enabled if the feature or policy is disabled + // (this is not checking a user preference). if (!ai_chat::IsAIChatEnabled(user_prefs::UserPrefs::Get( navigation_handle->GetWebContents()->GetBrowserContext()))) { return nullptr; } - // We don't need this throttle if the full-page feature is enabled via proxy - // of the AIChatHistory feature flag. - if (features::IsAIChatHistoryEnabled()) { + const GURL& url = navigation_handle->GetURL(); + + bool is_main_page_url = url.SchemeIs(content::kChromeUIScheme) && + url.host_piece() == kAIChatUIHost; + + // We allow main page navigation only if the full-page feature is enabled + // via the AIChatHistory feature flag. + if (is_main_page_url && features::IsAIChatHistoryEnabled()) { return nullptr; } - // We need this throttle to work only for chrome-untrusted://chat page - if (!navigation_handle->GetURL().SchemeIs( - content::kChromeUIUntrustedScheme) || - navigation_handle->GetURL().host_piece() != kChatUIHost) { + bool is_ai_chat_frame = + url.SchemeIs(content::kChromeUIUntrustedScheme) && + url.host_piece() == kAIChatUntrustedConversationUIHost; + + // We need this throttle to work only for AI Chat related URLs + if (!is_main_page_url && !is_ai_chat_frame) { return nullptr; } - // Purpose of this throttle is to forbid loading of chrome-untrusted://chat - // in tab. + // Purpose of this throttle is to forbid loading of chrome://leo-ai and + // related urls in tab. // Parameters check is made different for Android and Desktop because // there are different flags: // --------+---------------------------------+------------------------------ diff --git a/components/ai_chat/content/browser/ai_chat_throttle.h b/components/ai_chat/content/browser/ai_chat_throttle.h index 113270fd2e05..c07d94b5cad3 100644 --- a/components/ai_chat/content/browser/ai_chat_throttle.h +++ b/components/ai_chat/content/browser/ai_chat_throttle.h @@ -10,6 +10,10 @@ #include "content/public/browser/navigation_throttle.h" +namespace content { +class NavigationHandle; +} // namespace content + namespace ai_chat { // Prevents navigation to certain AI Chat URLs diff --git a/components/ai_chat/content/browser/model_service_factory.cc b/components/ai_chat/content/browser/model_service_factory.cc index 84067513b890..57006ccc46c2 100644 --- a/components/ai_chat/content/browser/model_service_factory.cc +++ b/components/ai_chat/content/browser/model_service_factory.cc @@ -5,6 +5,7 @@ #include "brave/components/ai_chat/content/browser/model_service_factory.h" +#include "base/check.h" #include "base/no_destructor.h" #include "brave/components/ai_chat/core/browser/model_service.h" #include "brave/components/ai_chat/core/common/features.h" diff --git a/components/ai_chat/content/browser/model_service_factory.h b/components/ai_chat/content/browser/model_service_factory.h index 34ae0b19d389..04f41805ae8d 100644 --- a/components/ai_chat/content/browser/model_service_factory.h +++ b/components/ai_chat/content/browser/model_service_factory.h @@ -9,6 +9,11 @@ #include #include "components/keyed_service/content/browser_context_keyed_service_factory.h" +#include "components/keyed_service/core/keyed_service.h" + +namespace content { +class BrowserContext; +} // namespace content namespace base { diff --git a/components/ai_chat/content/browser/page_content_fetcher.cc b/components/ai_chat/content/browser/page_content_fetcher.cc index 7ceb61f9a8c0..63b3fd5493f2 100644 --- a/components/ai_chat/content/browser/page_content_fetcher.cc +++ b/components/ai_chat/content/browser/page_content_fetcher.cc @@ -5,37 +5,61 @@ #include "brave/components/ai_chat/content/browser/page_content_fetcher.h" +#include +#include #include -#include +#include +#include #include #include +#include #include #include +#include "base/check.h" +#include "base/containers/checked_iterators.h" #include "base/containers/contains.h" #include "base/containers/fixed_flat_set.h" #include "base/functional/bind.h" +#include "base/functional/callback.h" +#include "base/logging.h" +#include "base/memory/weak_ptr.h" #include "base/strings/string_split.h" +#include "base/types/expected.h" +#include "base/values.h" #include "brave/components/ai_chat/content/browser/ai_chat_tab_helper.h" #include "brave/components/ai_chat/content/browser/pdf_utils.h" -#include "brave/components/ai_chat/core/browser/utils.h" #include "brave/components/ai_chat/core/common/mojom/page_content_extractor.mojom.h" #include "brave/components/text_recognition/common/buildflags/buildflags.h" #include "content/public/browser/browser_context.h" -#include "content/public/browser/render_process_host.h" -#include "content/public/browser/render_widget_host_view.h" +#include "content/public/browser/render_frame_host.h" #include "content/public/browser/storage_partition.h" #include "content/public/browser/web_contents.h" #include "mojo/public/cpp/bindings/remote.h" +#include "mojo/public/cpp/bindings/struct_ptr.h" #include "net/base/load_flags.h" +#include "net/base/net_errors.h" +#include "net/cookies/site_for_cookies.h" #include "net/http/http_request_headers.h" +#include "net/http/http_response_headers.h" +#include "net/traffic_annotation/network_traffic_annotation.h" #include "services/data_decoder/public/cpp/data_decoder.h" #include "services/data_decoder/public/cpp/safe_xml_parser.h" +#include "services/data_decoder/public/mojom/xml_parser.mojom.h" #include "services/network/public/cpp/resource_request.h" +#include "services/network/public/cpp/shared_url_loader_factory.h" #include "services/network/public/cpp/simple_url_loader.h" +#include "services/network/public/mojom/fetch_api.mojom-shared.h" #include "services/network/public/mojom/url_response_head.mojom.h" #include "services/service_manager/public/cpp/interface_provider.h" #include "url/gurl.h" +#include "url/origin.h" +#include "url/url_constants.h" + +#if BUILDFLAG(ENABLE_TEXT_RECOGNITION) +#include "brave/components/ai_chat/core/browser/utils.h" +#include "content/public/browser/render_widget_host_view.h" +#endif // BUILDFLAG(ENABLE_TEXT_RECOGNITION) namespace ai_chat { diff --git a/components/ai_chat/content/browser/page_content_fetcher.h b/components/ai_chat/content/browser/page_content_fetcher.h index 1825a328e5d8..31c855ec8e44 100644 --- a/components/ai_chat/content/browser/page_content_fetcher.h +++ b/components/ai_chat/content/browser/page_content_fetcher.h @@ -7,8 +7,10 @@ #define BRAVE_COMPONENTS_AI_CHAT_CONTENT_BROWSER_PAGE_CONTENT_FETCHER_H_ #include +#include #include "base/functional/callback_forward.h" +#include "base/memory/raw_ptr.h" #include "base/memory/scoped_refptr.h" #include "brave/components/ai_chat/content/browser/ai_chat_tab_helper.h" #include "brave/components/ai_chat/core/common/mojom/page_content_extractor.mojom.h" diff --git a/components/ai_chat/content/browser/pdf_utils.cc b/components/ai_chat/content/browser/pdf_utils.cc index 94c28018416e..89f5b748639f 100644 --- a/components/ai_chat/content/browser/pdf_utils.cc +++ b/components/ai_chat/content/browser/pdf_utils.cc @@ -5,13 +5,23 @@ #include "brave/components/ai_chat/content/browser/pdf_utils.h" +#include +#include +#include +#include + +#include "base/functional/function_ref.h" +#include "base/memory/raw_ptr.h" #include "base/strings/strcat.h" #include "components/strings/grit/components_strings.h" +#include "content/public/browser/render_frame_host.h" #include "content/public/browser/render_process_host.h" #include "content/public/browser/web_contents.h" #include "pdf/buildflags.h" #include "services/strings/grit/services_strings.h" +#include "ui/accessibility/ax_enums.mojom-shared.h" #include "ui/accessibility/ax_node.h" +#include "ui/accessibility/ax_tree_id.h" #include "ui/accessibility/ax_tree_manager.h" #include "ui/base/l10n/l10n_util.h" diff --git a/components/ai_chat/core/browser/BUILD.gn b/components/ai_chat/core/browser/BUILD.gn index a57ac802f6a3..26657a9a258a 100644 --- a/components/ai_chat/core/browser/BUILD.gn +++ b/components/ai_chat/core/browser/BUILD.gn @@ -10,6 +10,8 @@ static_library("browser") { sources = [ "ai_chat_credential_manager.cc", "ai_chat_credential_manager.h", + "ai_chat_database.cc", + "ai_chat_database.h", "ai_chat_feedback_api.cc", "ai_chat_feedback_api.h", "ai_chat_metrics.cc", @@ -85,13 +87,17 @@ static_library("browser") { "//components/component_updater", "//components/component_updater:component_updater_paths", "//components/keyed_service/core", + "//components/os_crypt/async/browser", + "//components/os_crypt/async/common", "//components/os_crypt/sync:os_crypt", "//components/prefs", + "//components/update_client", "//components/user_prefs", "//net/traffic_annotation", "//services/data_decoder/public/cpp", "//services/network/public/cpp", "//services/service_manager/public/cpp", + "//sql", "//third_party/abseil-cpp:absl", "//third_party/re2", "//ui/base", @@ -124,6 +130,7 @@ if (!is_ios) { testonly = true sources = [ "ai_chat_credential_manager_unittest.cc", + "ai_chat_database_unittest.cc", "ai_chat_metrics_unittest.cc", "ai_chat_service_unittest.cc", "associated_content_driver_unittest.cc", @@ -159,6 +166,7 @@ if (!is_ios) { "//brave/components/skus/common:mojom", "//components/component_updater:component_updater_paths", "//components/component_updater:test_support", + "//components/os_crypt/async/browser:test_support", "//components/os_crypt/sync:test_support", "//components/prefs:test_support", "//components/sync_preferences:test_support", @@ -167,6 +175,7 @@ if (!is_ios) { "//services/data_decoder/public/cpp:test_support", "//services/network:test_support", "//services/network/public/cpp:cpp", + "//sql:test_support", "//testing/gtest:gtest", ] } @@ -179,12 +188,18 @@ source_set("test_support") { "engine/mock_engine_consumer.h", "engine/mock_remote_completion_client.cc", "engine/mock_remote_completion_client.h", + "mock_conversation_handler_observer.cc", + "mock_conversation_handler_observer.h", + "test_utils.cc", + "test_utils.h", ] deps = [ "//brave/components/ai_chat/core/browser", + "//brave/components/ai_chat/core/common/mojom", "//services/network/public/cpp", "//testing/gmock", + "//testing/gtest", ] } diff --git a/components/ai_chat/core/browser/DEPS b/components/ai_chat/core/browser/DEPS index 108a66fbb918..c80cb5a16d41 100644 --- a/components/ai_chat/core/browser/DEPS +++ b/components/ai_chat/core/browser/DEPS @@ -1,7 +1,15 @@ include_rules = [ + "+cc/port", + "+cc/task/core", + "+cc/task/text", + "+components/os_crypt/async", "+services/data_decoder/public", "+services/network/public", "+services/network/test", + "+sql", + "+tensorflow_lite_support/cc/task/core/proto", + "+tensorflow_lite_support/cc/task/processor/proto", + "+tensorflow_lite_support/cc/task/text/proto", "+third_party/skia/include", "+third_party/re2/src/re2", "+third_party/tflite", diff --git a/components/ai_chat/core/browser/ai_chat_credential_manager.cc b/components/ai_chat/core/browser/ai_chat_credential_manager.cc index 8f1d776267e9..27fa7d86fcf6 100644 --- a/components/ai_chat/core/browser/ai_chat_credential_manager.cc +++ b/components/ai_chat/core/browser/ai_chat_credential_manager.cc @@ -5,28 +5,41 @@ #include "brave/components/ai_chat/core/browser/ai_chat_credential_manager.h" +#include #include +#include +#include +#include #include #include -#include "base/base64.h" -#include "base/i18n/time_formatting.h" +#include "base/check.h" +#include "base/functional/bind.h" #include "base/json/json_reader.h" -#include "base/json/json_writer.h" #include "base/json/values_util.h" +#include "base/numerics/clamped_math.h" #include "base/strings/string_util.h" #include "base/strings/utf_string_conversions.h" #include "base/time/time.h" +#include "base/value_iterators.h" +#include "base/values.h" #include "brave/brave_domains/service_domains.h" -#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-shared.h" #include "brave/components/ai_chat/core/common/pref_names.h" #include "components/prefs/pref_service.h" #include "components/prefs/scoped_user_pref_update.h" +#include "mojo/public/cpp/bindings/pending_remote.h" +#include "mojo/public/cpp/bindings/struct_ptr.h" #include "net/cookies/cookie_inclusion_status.h" #include "net/cookies/cookie_util.h" #include "net/cookies/parsed_cookie.h" +#include "url/url_canon.h" #include "url/url_util.h" +#if BUILDFLAG(IS_ANDROID) +#include "base/base64.h" +#include "base/json/json_writer.h" +#endif // BUILDFLAG(IS_ANDROID) + namespace { constexpr char kLeoSkuHostnamePart[] = "leo"; diff --git a/components/ai_chat/core/browser/ai_chat_credential_manager.h b/components/ai_chat/core/browser/ai_chat_credential_manager.h index 0cad013b1ba9..ec132b32248f 100644 --- a/components/ai_chat/core/browser/ai_chat_credential_manager.h +++ b/components/ai_chat/core/browser/ai_chat_credential_manager.h @@ -10,14 +10,21 @@ #include #include "base/functional/callback.h" +#include "base/memory/raw_ptr.h" #include "base/memory/weak_ptr.h" #include "base/time/time.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" #include "brave/components/skus/common/skus_sdk.mojom.h" +#include "build/build_config.h" #include "mojo/public/cpp/bindings/pending_remote.h" +#include "mojo/public/cpp/bindings/remote.h" #include "mojo/public/cpp/bindings/remote_set.h" class PrefService; +namespace mojo { +template +class PendingRemote; +} // namespace mojo namespace ai_chat { diff --git a/components/ai_chat/core/browser/ai_chat_credential_manager_unittest.cc b/components/ai_chat/core/browser/ai_chat_credential_manager_unittest.cc index cfb02fa730b6..001dcb0064e9 100644 --- a/components/ai_chat/core/browser/ai_chat_credential_manager_unittest.cc +++ b/components/ai_chat/core/browser/ai_chat_credential_manager_unittest.cc @@ -7,28 +7,37 @@ #include #include +#include #include #include -#include "base/i18n/time_formatting.h" +#include "base/functional/bind.h" #include "base/json/values_util.h" +#include "base/memory/scoped_refptr.h" +#include "base/numerics/clamped_math.h" +#include "base/run_loop.h" +#include "base/strings/string_util.h" #include "base/strings/stringprintf.h" #include "base/test/bind.h" #include "base/test/scoped_feature_list.h" #include "base/time/time.h" +#include "base/values.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" #include "brave/components/ai_chat/core/common/pref_names.h" #include "brave/components/skus/browser/pref_names.h" -#include "brave/components/skus/browser/skus_context_impl.h" #include "brave/components/skus/browser/skus_service_impl.h" #include "brave/components/skus/browser/skus_utils.h" #include "brave/components/skus/common/features.h" -#include "brave/components/skus/common/skus_sdk.mojom.h" +#include "components/prefs/pref_service.h" #include "components/prefs/testing_pref_service.h" #include "content/public/test/browser_task_environment.h" +#include "mojo/public/cpp/bindings/pending_remote.h" +#include "mojo/public/cpp/bindings/struct_ptr.h" +#include "services/network/public/cpp/shared_url_loader_factory.h" #include "services/network/public/cpp/weak_wrapper_shared_url_loader_factory.h" #include "services/network/test/test_url_loader_factory.h" #include "testing/gtest/include/gtest/gtest.h" +#include "third_party/abseil-cpp/absl/strings/str_format.h" namespace { diff --git a/components/ai_chat/core/browser/ai_chat_database.cc b/components/ai_chat/core/browser/ai_chat_database.cc new file mode 100644 index 000000000000..941588cb8efa --- /dev/null +++ b/components/ai_chat/core/browser/ai_chat_database.cc @@ -0,0 +1,1051 @@ +/* Copyright (c) 2023 The Brave Authors. All rights reserved. + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at https://mozilla.org/MPL/2.0/. */ + +#include "brave/components/ai_chat/core/browser/ai_chat_database.h" + +#include +#include +#include +#include + +#include "base/check.h" +#include "base/strings/string_split.h" +#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" +#include "components/os_crypt/async/common/encryptor.h" +#include "sql/init_status.h" +#include "sql/meta_table.h" +#include "sql/statement.h" +#include "sql/transaction.h" + +namespace { + +// These database versions should roll together unless we develop migrations. +// Lowest version we support migrations from - existing database will be deleted +// if lower. +constexpr int kLowestSupportedDatabaseVersion = 1; +// Current version of the database. Increase if breaking changes are made. +constexpr int kCurrentDatabaseVersion = 1; + +constexpr char kSearchQueriesSeparator[] = "|||"; + +std::optional GetOptionalString(sql::Statement& statement, + int index) { + if (statement.GetColumnType(index) == sql::ColumnType::kNull) { + return std::nullopt; + } + return std::make_optional(statement.ColumnString(index)); +} + +void BindOptionalString(sql::Statement& statement, + int index, + const std::optional& value) { + if (value.has_value() && !value.value().empty()) { + statement.BindString(index, value.value()); + } else { + statement.BindNull(index); + } +} + +} // namespace + +namespace ai_chat { + +AIChatDatabase::AIChatDatabase(const base::FilePath& db_file_path, + os_crypt_async::Encryptor encryptor) + : db_file_path_(db_file_path), + db_({.page_size = 4096, .cache_size = 1000}), + encryptor_(std::move(encryptor)) {} + +AIChatDatabase::~AIChatDatabase() = default; + +bool AIChatDatabase::LazyInit(bool re_init) { + if (!db_init_status_.has_value() || re_init) { + db_init_status_ = InitInternal(); + } + + return *db_init_status_ == sql::InitStatus::INIT_OK; +} + +sql::InitStatus AIChatDatabase::InitInternal() { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + if (!GetDB().is_open() && !GetDB().Open(db_file_path_)) { + return sql::InitStatus::INIT_FAILURE; + } + + if (sql::MetaTable::RazeIfIncompatible( + &GetDB(), kLowestSupportedDatabaseVersion, kCurrentDatabaseVersion) == + sql::RazeIfIncompatibleResult::kFailed) { + return sql::InitStatus::INIT_FAILURE; + } + + sql::Transaction transaction(&GetDB()); + if (!transaction.Begin()) { + return sql::InitStatus::INIT_FAILURE; + } + + sql::MetaTable meta_table; + if (!meta_table.Init(&GetDB(), kCurrentDatabaseVersion, + /*compatible_version=*/kCurrentDatabaseVersion)) { + DVLOG(0) << "Failed to init meta table"; + return sql::InitStatus::INIT_FAILURE; + } + + if (meta_table.GetCompatibleVersionNumber() > kCurrentDatabaseVersion) { + LOG(ERROR) << "AIChat database version is too new."; + return sql::InitStatus::INIT_TOO_NEW; + } + + if (!CreateSchema()) { + DVLOG(0) << "Failure to create tables"; + return sql::InitStatus::INIT_FAILURE; + } + + if (!transaction.Commit()) { + return sql::InitStatus::INIT_FAILURE; + } + + return sql::InitStatus::INIT_OK; +} + +std::vector AIChatDatabase::GetAllConversations() { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + if (!LazyInit()) { + return {}; + } + // All conversation metadata, associated content and most + // and most recent entry date. 1 row for each associated content. + static constexpr char kQuery[] = + "SELECT conversation.uuid, conversation.title, conversation.model_key," + " last_activity_date.date," + " associated_content.uuid, associated_content.title," + " associated_content.url, associated_content.content_type," + " associated_content.content_used_percentage," + " associated_content.is_content_refined" + " FROM conversation" + " LEFT JOIN associated_content" + " ON conversation.uuid = associated_content.conversation_uuid" + " LEFT JOIN (" + " SELECT conversation_entry.date AS date, " + " conversation_entry.conversation_uuid AS conversation_uuid " + " FROM conversation_entry" + " GROUP BY conversation_entry.conversation_uuid" + " ORDER BY conversation_entry.date desc) " + " AS last_activity_date" + " ON last_activity_date.conversation_uuid = conversation.uuid" + " ORDER BY conversation.uuid ASC"; + sql::Statement statement(GetDB().GetCachedStatement(SQL_FROM_HERE, kQuery)); + CHECK(statement.is_valid()); + + std::vector conversation_list; + // This/last row's conversation + mojom::ConversationPtr conversation; + + while (statement.Step()) { + DVLOG(1) << __func__ << " got a result"; + std::string uuid = statement.ColumnString(0); + if (conversation) { + if (conversation->uuid == uuid) { + // TODO(petemill): Support multiple associated content + continue; + } else { + conversation_list.emplace_back(std::move(conversation)); + } + } + auto index = 1; + conversation = mojom::Conversation::New(); + conversation->uuid = uuid; + conversation->title = + DecryptOptionalColumnToString(statement, index++).value_or(""); + conversation->model_key = GetOptionalString(statement, index++); + conversation->updated_time = statement.ColumnTime(index++); + conversation->has_content = true; + + conversation->associated_content = mojom::SiteInfo::New(); + + if (statement.GetColumnType(index) != sql::ColumnType::kNull) { + DVLOG(1) << __func__ << " got associated content"; + + conversation->associated_content->uuid = statement.ColumnString(index++); + conversation->associated_content->title = + DecryptOptionalColumnToString(statement, index++); + auto url_raw = DecryptOptionalColumnToString(statement, index++); + if (url_raw.has_value()) { + conversation->associated_content->url = GURL(url_raw.value()); + } + conversation->associated_content->content_type = + static_cast(statement.ColumnInt(index++)); + conversation->associated_content->content_used_percentage = + statement.ColumnInt(index++); + conversation->associated_content->is_content_refined = + statement.ColumnBool(index++); + conversation->associated_content->is_content_association_possible = true; + } else { + conversation->associated_content->is_content_association_possible = false; + } + } + + // Final row's conversation + if (conversation) { + conversation_list.emplace_back(std::move(conversation)); + } + + return conversation_list; +} + +mojom::ConversationArchivePtr AIChatDatabase::GetConversationData( + std::string_view conversation_uuid) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + if (!LazyInit()) { + return nullptr; + } + + return mojom::ConversationArchive::New( + GetConversationEntries(conversation_uuid), + GetArchiveContentsForConversation(conversation_uuid)); +} + +std::vector AIChatDatabase::GetConversationEntries( + std::string_view conversation_uuid) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + + static constexpr char kEntriesQuery[] = + "SELECT uuid, date, entry_text, character_type, editing_entry_uuid, " + "action_type, selected_text" + " FROM conversation_entry" + " WHERE conversation_uuid=?" + " ORDER BY date ASC"; + sql::Statement statement( + GetDB().GetCachedStatement(SQL_FROM_HERE, kEntriesQuery)); + CHECK(statement.is_valid()); + + statement.BindString(0, conversation_uuid); + + DVLOG(4) << __func__ << " for " << conversation_uuid; + + std::vector history; + // Map of editing entry id to the edit entry + std::map> edits; + + while (statement.Step()) { + // basic metadata + std::string entry_uuid = statement.ColumnString(0); + DVLOG(4) << "Found entry row for conversation " << conversation_uuid + << " with id " << entry_uuid; + auto date = statement.ColumnTime(1); + auto text = DecryptOptionalColumnToString(statement, 2).value_or(""); + auto character_type = + static_cast(statement.ColumnInt(3)); + auto editing_entry_id = GetOptionalString(statement, 4); + auto action_type = static_cast(statement.ColumnInt(5)); + auto selected_text = DecryptOptionalColumnToString(statement, 6); + + auto entry = mojom::ConversationTurn::New( + entry_uuid, character_type, action_type, + mojom::ConversationTurnVisibility::VISIBLE, text, selected_text, + std::nullopt, date, std::nullopt, false); + + // events + struct Event { + int event_order; + mojom::ConversationEntryEventPtr event; + }; + std::vector events; + + // Completion events + { + sql::Statement event_statement( + GetDB().GetCachedStatement(SQL_FROM_HERE, + "SELECT event_order, text" + " FROM conversation_entry_event_completion" + " WHERE conversation_entry_uuid=?" + " ORDER BY event_order ASC")); + event_statement.BindString(0, entry_uuid); + + while (event_statement.Step()) { + int event_order = event_statement.ColumnInt(0); + std::string completion = DecryptColumnToString(event_statement, 1); + events.emplace_back(Event{ + event_order, mojom::ConversationEntryEvent::NewCompletionEvent( + mojom::CompletionEvent::New(completion))}); + } + } + + // Search Query events + { + sql::Statement event_statement(GetDB().GetUniqueStatement( + "SELECT event_order, queries" + " FROM conversation_entry_event_search_queries" + " WHERE conversation_entry_uuid=?" + " ORDER BY event_order ASC")); + event_statement.BindString(0, entry_uuid); + + while (event_statement.Step()) { + int event_order = event_statement.ColumnInt(0); + auto queries_data = DecryptColumnToString(event_statement, 1); + std::vector queries = + base::SplitString(queries_data, kSearchQueriesSeparator, + base::WhitespaceHandling::TRIM_WHITESPACE, + base::SplitResult::SPLIT_WANT_NONEMPTY); + events.emplace_back(Event{ + event_order, mojom::ConversationEntryEvent::NewSearchQueriesEvent( + mojom::SearchQueriesEvent::New(queries))}); + } + } + + // insert events in order + if (!events.empty()) { + base::ranges::sort(events, [](const Event& a, const Event& b) { + return a.event_order < b.event_order; + }); + entry->events = std::vector{}; + for (auto& event : events) { + entry->events->emplace_back(std::move(event.event)); + } + } + + // root entry or edited entry + if (editing_entry_id.has_value()) { + DVLOG(4) << "Collected edit entry for " << editing_entry_id.value() + << " with id " << entry_uuid; + edits[editing_entry_id.value()].emplace_back(std::move(entry)); + } else { + DVLOG(4) << "Collected entry for " << entry_uuid; + history.emplace_back(std::move(entry)); + } + } + + // Reconstruct edits + for (auto& entry : history) { + CHECK(entry->uuid.has_value()); + auto id = entry->uuid.value(); + if (edits.count(id)) { + entry->edits = std::vector{}; + for (auto& edit : edits[id]) { + entry->edits->emplace_back(std::move(edit)); + } + } + } + + return history; +} + +std::vector +AIChatDatabase::GetArchiveContentsForConversation( + std::string_view conversation_uuid) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + + static constexpr char kQuery[] = + "SELECT uuid, last_contents" + " FROM associated_content" + " WHERE conversation_uuid=?" + " AND last_contents IS NOT NULL" + " ORDER BY uuid ASC"; + sql::Statement statement(GetDB().GetCachedStatement(SQL_FROM_HERE, kQuery)); + CHECK(statement.is_valid()); + statement.BindString(0, conversation_uuid); + std::vector archive_contents; + // We only support a single entry until ConversationHandler supports multiple + // associated contents. + if (statement.Step()) { + auto content = mojom::ContentArchive::New( + statement.ColumnString(0), DecryptColumnToString(statement, 1)); + archive_contents.emplace_back(std::move(content)); + } + return archive_contents; +} + +bool AIChatDatabase::AddConversation(mojom::ConversationPtr conversation, + std::optional contents, + mojom::ConversationTurnPtr first_entry) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + CHECK(!conversation->uuid.empty()); + CHECK(first_entry); + if (!LazyInit()) { + return false; + } + + sql::Transaction transaction(&GetDB()); + CHECK(GetDB().is_open()); + if (!transaction.Begin()) { + DVLOG(0) << "Transaction cannot begin"; + return false; + } + + static constexpr char kInsertConversationQuery[] = + "INSERT INTO conversation(uuid, title, model_key) " + "VALUES(?, ?, ?)"; + sql::Statement statement( + GetDB().GetUniqueStatement(kInsertConversationQuery)); + CHECK(statement.is_valid()); + + statement.BindString(0, conversation->uuid); + + BindAndEncryptOptionalString(statement, 1, conversation->title); + BindOptionalString(statement, 2, conversation->model_key); + + if (!statement.Run()) { + DVLOG(0) << "Failed to execute 'conversation' insert statement: " + << db_.GetErrorMessage(); + return false; + } + + if (conversation->associated_content->is_content_association_possible) { + DVLOG(2) << "Adding associated content for conversation " + << conversation->uuid << " with url " + << conversation->associated_content->url->spec(); + if (!AddOrUpdateAssociatedContent( + conversation->uuid, std::move(conversation->associated_content), + contents)) { + return false; + } + } + + if (!AddConversationEntry(conversation->uuid, std::move(first_entry))) { + return false; + } + + if (!transaction.Commit()) { + DVLOG(0) << "Transaction commit failed with reason: " + << db_.GetErrorMessage(); + return false; + } + + return true; +} + +bool AIChatDatabase::AddOrUpdateAssociatedContent( + std::string_view conversation_uuid, + mojom::SiteInfoPtr associated_content, + std::optional contents) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + if (!LazyInit()) { + return false; + } + + // TODO(petemill): handle multiple associated content per conversation + CHECK(!conversation_uuid.empty()); + CHECK(associated_content->uuid.has_value()); + + // Check if we already have persisted this content + static constexpr char kSelectExistingAssociatedContentId[] = + "SELECT uuid FROM associated_content WHERE conversation_uuid=?" + " AND uuid=?"; + sql::Statement select_statement(GetDB().GetCachedStatement( + SQL_FROM_HERE, kSelectExistingAssociatedContentId)); + CHECK(select_statement.is_valid()); + select_statement.BindString(0, conversation_uuid); + select_statement.BindString(1, associated_content->uuid.value()); + + sql::Statement statement; + if (select_statement.Step()) { + DVLOG(4) << "Updating associated content for conversation " + << conversation_uuid << " with id " + << associated_content->uuid.value(); + static constexpr char kUpdateAssociatedContentQuery[] = + "UPDATE associated_content" + " SET title = ?," + " url = ?," + " content_type = ?," + " last_contents = ?," + " content_used_percentage = ?," + " is_content_refined = ?" + " WHERE uuid=? and conversation_uuid=?"; + statement.Assign(GetDB().GetUniqueStatement(kUpdateAssociatedContentQuery)); + } else { + DVLOG(4) << "Inserting associated content for conversation " + << conversation_uuid; + static constexpr char kInsertAssociatedContentQuery[] = + "INSERT INTO associated_content(title, url," + " content_type, last_contents, content_used_percentage," + " is_content_refined, uuid, conversation_uuid)" + " VALUES(?, ?, ?, ?, ?, ?, ?, ?) "; + statement.Assign(GetDB().GetUniqueStatement(kInsertAssociatedContentQuery)); + } + CHECK(statement.is_valid()); + int index = 0; + BindAndEncryptOptionalString(statement, index++, associated_content->title); + BindAndEncryptOptionalString(statement, index++, + associated_content->url->spec()); + statement.BindInt(index++, + base::to_underlying(associated_content->content_type)); + BindAndEncryptOptionalString(statement, index++, contents); + statement.BindInt(index++, associated_content->content_used_percentage); + statement.BindBool(index++, associated_content->is_content_refined); + statement.BindString(index++, associated_content->uuid.value()); + statement.BindString(index, conversation_uuid); + + if (!statement.Run()) { + DVLOG(0) + << "Failed to execute 'associated_content' insert or update statement: " + << db_.GetErrorMessage(); + return false; + } + + return true; +} + +bool AIChatDatabase::AddConversationEntry( + std::string_view conversation_uuid, + mojom::ConversationTurnPtr entry, + std::optional model_key, + std::optional editing_id) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + CHECK(!conversation_uuid.empty()); + CHECK(entry->uuid.has_value() && !entry->uuid->empty()); + if (!LazyInit()) { + return false; + } + + // Verify the conversation exists and get existing model key. We don't + // want to add orphan conversation entries when the conversation doesn't + // exist. + static constexpr char kGetConversationIdQuery[] = + "SELECT model_key FROM conversation WHERE uuid=?"; + sql::Statement get_conversation_model_statement( + GetDB().GetCachedStatement(SQL_FROM_HERE, kGetConversationIdQuery)); + CHECK(get_conversation_model_statement.is_valid()); + get_conversation_model_statement.BindString(0, conversation_uuid); + if (!get_conversation_model_statement.Step()) { + DVLOG(0) << "ID not found in 'conversation' table"; + return false; + } + auto existing_model_key = + GetOptionalString(get_conversation_model_statement, 0); + + sql::Transaction transaction(&GetDB()); + CHECK(GetDB().is_open()); + if (!transaction.Begin()) { + DVLOG(0) << "Transaction cannot begin"; + return false; + } + + bool has_valid_new_model_key = !model_key.value_or("").empty(); + bool should_update_model = ( + // Clear existing + (!has_valid_new_model_key && existing_model_key.has_value()) || + // Change or add existing + (has_valid_new_model_key && + (existing_model_key.value_or("") != model_key.value()))); + if (should_update_model) { + // Update model key if neccessary + static constexpr char kUpdateModelKeyQuery[] = + "UPDATE conversation SET model_key=? WHERE uuid=?"; + sql::Statement update_model_key_statement( + GetDB().GetCachedStatement(SQL_FROM_HERE, kUpdateModelKeyQuery)); + update_model_key_statement.BindString(1, conversation_uuid); + if (has_valid_new_model_key) { + update_model_key_statement.BindString(0, model_key.value()); + } else { + update_model_key_statement.BindNull(0); + } + update_model_key_statement.Run(); + } + + sql::Statement insert_conversation_entry_statement; + + if (editing_id.has_value()) { + static constexpr char kInsertEditingConversationEntryQuery[] = + "INSERT INTO conversation_entry(editing_entry_uuid, uuid," + " conversation_uuid, date, entry_text," + " character_type, action_type, selected_text)" + " VALUES(?, ?, ?, ?, ?, ?, ?, ?)"; + insert_conversation_entry_statement.Assign( + GetDB().GetUniqueStatement(kInsertEditingConversationEntryQuery)); + } else { + static constexpr char kInsertConversationEntryQuery[] = + "INSERT INTO conversation_entry(uuid, conversation_uuid, date," + " entry_text, character_type, action_type, selected_text)" + " VALUES(?, ?, ?, ?, ?, ?, ?)"; + insert_conversation_entry_statement.Assign( + GetDB().GetUniqueStatement(kInsertConversationEntryQuery)); + } + CHECK(insert_conversation_entry_statement.is_valid()); + + int index = 0; + if (editing_id.has_value()) { + insert_conversation_entry_statement.BindString(index++, editing_id.value()); + } + insert_conversation_entry_statement.BindString(index++, entry->uuid.value()); + insert_conversation_entry_statement.BindString(index++, conversation_uuid); + insert_conversation_entry_statement.BindTime(index++, entry->created_time); + BindAndEncryptOptionalString(insert_conversation_entry_statement, index++, + entry->text); + insert_conversation_entry_statement.BindInt( + index++, base::to_underlying(entry->character_type)); + insert_conversation_entry_statement.BindInt( + index++, base::to_underlying(entry->action_type)); + BindAndEncryptOptionalString(insert_conversation_entry_statement, index++, + entry->selected_text); + + if (!insert_conversation_entry_statement.Run()) { + DVLOG(0) << "Failed to execute 'conversation_entry' insert statement: " + << db_.GetErrorMessage(); + return false; + } + + if (entry->events.has_value()) { + for (size_t i = 0; i < entry->events->size(); i++) { + const mojom::ConversationEntryEventPtr& event = entry->events->at(i); + switch (event->which()) { + case mojom::ConversationEntryEvent::Tag::kCompletionEvent: { + sql::Statement event_statement(GetDB().GetCachedStatement( + SQL_FROM_HERE, + "INSERT INTO conversation_entry_event_completion" + " (event_order, text, conversation_entry_uuid)" + " VALUES(?, ?, ?)")); + CHECK(event_statement.is_valid()); + event_statement.BindInt(0, static_cast(i)); + if (!BindAndEncryptString( + event_statement, 1, + event->get_completion_event()->completion)) { + return false; + } + event_statement.BindString(2, entry->uuid.value()); + event_statement.Run(); + break; + } + case mojom::ConversationEntryEvent::Tag::kSearchQueriesEvent: { + sql::Statement event_statement(GetDB().GetCachedStatement( + SQL_FROM_HERE, + "INSERT INTO conversation_entry_event_search_queries" + " (event_order, queries, conversation_entry_uuid)" + " VALUES(?, ?, ?)")); + CHECK(event_statement.is_valid()); + + std::string queries_data = base::JoinString( + event->get_search_queries_event()->search_queries, + kSearchQueriesSeparator); + + event_statement.BindInt(0, static_cast(i)); + if (!BindAndEncryptString(event_statement, 1, queries_data)) { + return false; + } + event_statement.BindString(2, entry->uuid.value()); + event_statement.Run(); + break; + } + default: { + break; + } + } + } + } + + if (entry->edits.has_value()) { + for (auto& edit : entry->edits.value()) { + if (!AddConversationEntry(conversation_uuid, std::move(edit), model_key, + entry->uuid.value())) { + return false; + } + } + } + + if (!transaction.Commit()) { + DVLOG(0) << "Transaction commit failed with reason: " + << db_.GetErrorMessage(); + return false; + } + + return true; +} + +bool AIChatDatabase::UpdateConversationTitle(std::string_view conversation_uuid, + std::string_view title) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + if (!LazyInit()) { + return false; + } + + static constexpr char kUpdateConversationTitleQuery[] = + "UPDATE conversation SET title=? WHERE uuid=?"; + sql::Statement statement( + GetDB().GetCachedStatement(SQL_FROM_HERE, kUpdateConversationTitleQuery)); + CHECK(statement.is_valid()); + + if (!BindAndEncryptString(statement, 0, title)) { + return false; + } + statement.BindString(1, conversation_uuid); + + return statement.Run(); +} + +bool AIChatDatabase::DeleteConversation(std::string_view conversation_uuid) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + if (!LazyInit()) { + return false; + } + + sql::Transaction transaction(&db_); + if (!transaction.Begin()) { + DVLOG(0) << "Transaction cannot begin\n"; + return false; + } + + // Delete all conversation entries + static constexpr char kSelectConversationEntryQuery[] = + "SELECT uuid FROM conversation_entry WHERE conversation_uuid=?"; + sql::Statement select_conversation_entry_statement( + GetDB().GetUniqueStatement(kSelectConversationEntryQuery)); + CHECK(select_conversation_entry_statement.is_valid()); + select_conversation_entry_statement.BindString(0, conversation_uuid); + + // Delete all conversation entry events + while (select_conversation_entry_statement.Step()) { + std::string conversation_entry_uuid = + select_conversation_entry_statement.ColumnString(0); + static constexpr char kDeleteCompletionEventQuery[] = + "DELETE FROM conversation_entry_event_completion" + " WHERE conversation_entry_uuid=?"; + sql::Statement delete_completion_event_statement( + GetDB().GetUniqueStatement(kDeleteCompletionEventQuery)); + CHECK(delete_completion_event_statement.is_valid()); + delete_completion_event_statement.BindString(0, conversation_entry_uuid); + if (!delete_completion_event_statement.Run()) { + return false; + } + + static constexpr char kDeleteSearchQueriesEventQuery[] = + "DELETE FROM conversation_entry_event_search_queries " + " WHERE conversation_entry_uuid=?"; + sql::Statement delete_queries_event_statement( + GetDB().GetUniqueStatement(kDeleteSearchQueriesEventQuery)); + CHECK(delete_queries_event_statement.is_valid()); + delete_queries_event_statement.BindString(0, conversation_entry_uuid); + if (!delete_queries_event_statement.Run()) { + return false; + } + + static constexpr char kDeleteEntryQuery[] = + "DELETE FROM conversation_entry WHERE uuid=?"; + sql::Statement delete_conversation_entry_statement( + GetDB().GetUniqueStatement(kDeleteEntryQuery)); + CHECK(delete_conversation_entry_statement.is_valid()); + delete_conversation_entry_statement.BindString(0, conversation_entry_uuid); + if (!delete_conversation_entry_statement.Run()) { + return false; + } + } + + // Delete the conversation metadata + static constexpr char kDeleteAssociatedContentQuery[] = + "DELETE FROM associated_content WHERE conversation_uuid=?"; + sql::Statement delete_associated_content_statement( + GetDB().GetUniqueStatement(kDeleteAssociatedContentQuery)); + CHECK(delete_associated_content_statement.is_valid()); + delete_associated_content_statement.BindString(0, conversation_uuid); + if (!delete_associated_content_statement.Run()) { + return false; + } + + static constexpr char kDeleteConversationQuery[] = + "DELETE FROM conversation WHERE uuid=?"; + sql::Statement delete_conversation_statement( + GetDB().GetUniqueStatement(kDeleteConversationQuery)); + CHECK(delete_conversation_statement.is_valid()); + delete_conversation_statement.BindString(0, conversation_uuid); + if (!delete_conversation_statement.Run()) { + return false; + } + + if (!transaction.Commit()) { + DVLOG(0) << "Transaction commit failed with reason: " + << db_.GetErrorMessage(); + return false; + } + return true; +} + +bool AIChatDatabase::DeleteConversationEntry( + std::string_view conversation_entry_uuid) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + if (!LazyInit()) { + return false; + } + + sql::Transaction transaction(&db_); + CHECK(!conversation_entry_uuid.empty()); + if (!transaction.Begin()) { + DVLOG(0) << "Transaction cannot begin\n"; + return false; + } + // Delete from conversation_entry_event_completion + { + sql::Statement delete_statement(GetDB().GetUniqueStatement( + "DELETE FROM conversation_entry_event_completion WHERE " + "conversation_entry_uuid=?")); + delete_statement.BindString(0, conversation_entry_uuid); + if (!delete_statement.Run()) { + DLOG(ERROR) + << "Failed to delete from conversation_entry_event_completion " + "for id: " + << conversation_entry_uuid; + return false; + } + } + + // Delete from conversation_entry_event_search_queries + { + static constexpr char kQuery[] = + "DELETE FROM conversation_entry_event_search_queries WHERE " + "conversation_entry_uuid=?"; + sql::Statement delete_statement(GetDB().GetUniqueStatement(kQuery)); + CHECK(delete_statement.is_valid()); + delete_statement.BindString(0, conversation_entry_uuid); + if (!delete_statement.Run()) { + DLOG(ERROR) << "Failed to delete from " + "conversation_entry_event_search_queries for conversation " + "entry uuid: " + << conversation_entry_uuid; + return false; + } + } + + // Delete edits + { + static constexpr char kQuery[] = + "DELETE FROM conversation_entry WHERE editing_entry_uuid = ?"; + sql::Statement delete_statement(GetDB().GetUniqueStatement(kQuery)); + CHECK(delete_statement.is_valid()); + delete_statement.BindString(0, conversation_entry_uuid); + if (!delete_statement.Run()) { + DLOG(ERROR) << "Failed to delete from conversation_entry for " + "conversation entry uuid: " + << conversation_entry_uuid; + return false; + } + } + + // Delete from conversation_entry + { + static constexpr char kQuery[] = + "DELETE FROM conversation_entry WHERE uuid=?"; + sql::Statement delete_statement(GetDB().GetUniqueStatement(kQuery)); + CHECK(delete_statement.is_valid()); + delete_statement.BindString(0, conversation_entry_uuid); + if (!delete_statement.Run()) { + LOG(ERROR) << "Failed to delete from conversation_entry for id: " + << conversation_entry_uuid; + return false; + } + } + + if (!transaction.Commit()) { + DVLOG(0) << "Transaction commit failed with reason: " + << db_.GetErrorMessage(); + return false; + } + return true; +} + +bool AIChatDatabase::DeleteAllData() { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + // Ignore init failure when deletion. We only need the database to be open. + LazyInit(); + + if (!GetDB().is_open()) { + return false; + } + + // Delete everything + if (!GetDB().Raze()) { + return false; + } + + // Re-init the database + return LazyInit(true); +} + +bool AIChatDatabase::DeleteAssociatedWebContent( + std::optional begin_time, + std::optional end_time) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + if (!LazyInit()) { + return false; + } + DVLOG(4) << "Deleting associated web content for time range " + << begin_time.value_or(base::Time()) << " to " + << end_time.value_or(base::Time::Max()); + + // Set any associated content url, title and content to NULL where + // conversation had any entry between begin_time and end_time. + static constexpr char kQuery[] = + "UPDATE associated_content" + " SET url=NULL, title=NULL, last_contents=NULL" + " WHERE conversation_uuid IN (" + " SELECT conversation_uuid" + " FROM conversation_entry" + " WHERE date >= ? AND date <= ?)"; + sql::Statement statement(GetDB().GetUniqueStatement(kQuery)); + CHECK(statement.is_valid()); + statement.BindTime(0, begin_time.value_or(base::Time())); + statement.BindTime(1, end_time.value_or(base::Time::Max())); + if (!statement.Run()) { + DVLOG(0) << "Failed to execute 'associated_content' update statement for " + "DeleteAssociatedWebContent: " + << db_.GetErrorMessage(); + return false; + } + return true; +} + +sql::Database& AIChatDatabase::GetDB() { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + return db_; +} + +std::string AIChatDatabase::DecryptColumnToString(sql::Statement& statement, + int index) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + auto decrypted_value = encryptor_.DecryptData(statement.ColumnBlob(index)); + if (!decrypted_value) { + DVLOG(0) << "Failed to decrypt value"; + return ""; + } + return *decrypted_value; +} + +std::optional AIChatDatabase::DecryptOptionalColumnToString( + sql::Statement& statement, + int index) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + // Don't allow non-BLOB types + if (statement.GetColumnType(index) != sql::ColumnType::kBlob) { + return std::nullopt; + } + auto decrypted_value = encryptor_.DecryptData(statement.ColumnBlob(index)); + if (!decrypted_value) { + DVLOG(0) << "Failed to decrypt value"; + return std::nullopt; + } + return *decrypted_value; +} + +void AIChatDatabase::BindAndEncryptOptionalString( + sql::Statement& statement, + int index, + std::optional value) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + if (value.has_value() && !value.value().empty()) { + auto encrypted_value = encryptor_.EncryptString(std::string(value.value())); + if (!encrypted_value) { + DVLOG(0) << "Failed to encrypt value"; + statement.BindNull(index); + return; + } + statement.BindBlob(index, *encrypted_value); + } else { + statement.BindNull(index); + } +} + +bool AIChatDatabase::BindAndEncryptString(sql::Statement& statement, + int index, + std::string_view value) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + auto encrypted_value = encryptor_.EncryptString(std::string(value)); + if (!encrypted_value) { + DVLOG(0) << "Failed to encrypt value"; + return false; + } + statement.BindBlob(index, *encrypted_value); + return true; +} + +bool AIChatDatabase::CreateSchema() { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + static constexpr char kCreateConversationTableQuery[] = + "CREATE TABLE IF NOT EXISTS conversation(" + "uuid TEXT PRIMARY KEY NOT NULL," + // Encrypted conversation title string + "title BLOB," + "model_key TEXT)"; + CHECK(GetDB().IsSQLValid(kCreateConversationTableQuery)); + if (!GetDB().Execute(kCreateConversationTableQuery)) { + return false; + } + + // AssociatedContent is 1:many with Conversation for future-proofing when + // we support multiple associated contents per conversation. + static constexpr char kCreateAssociatedContentTableQuery[] = + "CREATE TABLE IF NOT EXISTS associated_content(" + "uuid TEXT PRIMARY KEY NOT NULL," + "conversation_uuid TEXT NOT NULL," + // Encrypted associated content title string + "title BLOB," + // Encrypted url string + "url BLOB," + // Stores SiteInfo.IsVideo. Future-proofed for multiple content types + // 0 for regular content + // 1 for video. + "content_type INTEGER NOT NULL," + // Encrypted string value of the content, so that conversations can be + // continued. + "last_contents BLOB," + // Don't need REAL for content_used_percentage since + // we're never using decimal values. + // UI expects 0 - 100 values. + "content_used_percentage INTEGER NOT NULL," + "is_content_refined INTEGER NOT NULL)"; + CHECK(GetDB().IsSQLValid(kCreateAssociatedContentTableQuery)); + if (!GetDB().Execute(kCreateAssociatedContentTableQuery)) { + return false; + } + + // AKA ConversationTurn in mojom + static constexpr char kCreateConversationEntryTableQuery[] = + "CREATE TABLE IF NOT EXISTS conversation_entry(" + "uuid TEXT PRIMARY KEY NOT NULL," + "conversation_uuid STRING NOT NULL," + "date INTEGER NOT NULL," + // Encrypted text string + // TODO(petemill): move to event only + "entry_text BLOB," + "character_type INTEGER NOT NULL," + // editing_entry points to the ConversationEntry row that is being edited. + // Edits can be sorted by date. + "editing_entry_uuid TEXT," + "action_type INTEGER," + // Encrypted selected text + "selected_text BLOB)"; + // TODO(petemill): Forking can be achieved by associating each + // ConversationEntry with a parent ConversationEntry. + // TODO(petemill): Store a model name with each entry to know when + // a model was changed for a conversation, or for forking-by-model features. + CHECK(GetDB().IsSQLValid(kCreateConversationEntryTableQuery)); + if (!GetDB().Execute(kCreateConversationEntryTableQuery)) { + return false; + } + + static constexpr char kCreateConversationEntryTextTableQuery[] = + "CREATE TABLE IF NOT EXISTS conversation_entry_event_completion(" + "conversation_entry_uuid INTEGER NOT NULL," + "event_order INTEGER NOT NULL," + // encrypted event text string + "text BLOB NOT NULL," + "PRIMARY KEY(conversation_entry_uuid, event_order)" + ")"; + CHECK(GetDB().IsSQLValid(kCreateConversationEntryTextTableQuery)); + if (!GetDB().Execute(kCreateConversationEntryTextTableQuery)) { + return false; + } + + static constexpr char kCreateSearchQueriesTableQuery[] = + "CREATE TABLE IF NOT EXISTS conversation_entry_event_search_queries(" + "conversation_entry_uuid INTEGER NOT NULL," + "event_order INTEGER NOT NULL," + // encrypted delimited search query strings + "queries BLOB NOT NULL," + "PRIMARY KEY(conversation_entry_uuid, event_order)" + ")"; + CHECK(GetDB().IsSQLValid(kCreateSearchQueriesTableQuery)); + if (!GetDB().Execute(kCreateSearchQueriesTableQuery)) { + return false; + } + + return true; +} + +} // namespace ai_chat diff --git a/components/ai_chat/core/browser/ai_chat_database.h b/components/ai_chat/core/browser/ai_chat_database.h new file mode 100644 index 000000000000..5a8b681094bf --- /dev/null +++ b/components/ai_chat/core/browser/ai_chat_database.h @@ -0,0 +1,121 @@ +/* Copyright (c) 2023 The Brave Authors. All rights reserved. + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at https://mozilla.org/MPL/2.0/. */ + +#ifndef BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_AI_CHAT_DATABASE_H_ +#define BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_AI_CHAT_DATABASE_H_ + +#include +#include +#include + +#include "base/sequence_checker.h" +#include "base/thread_annotations.h" +#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-forward.h" +#include "components/os_crypt/async/common/encryptor.h" +#include "sql/database.h" +#include "sql/init_status.h" + +namespace ai_chat { + +// Persists AI Chat conversations and associated content. Conversations are +// mainly formed of their conversation entries. Edits to conversation entries +// should be handled with removal and re-adding so that other classes can make +// decisions about how it affects the rest of history. +// All data should be stored encrypted. +class AIChatDatabase { + public: + AIChatDatabase(const base::FilePath& db_file_path, + os_crypt_async::Encryptor encryptor); + AIChatDatabase(const AIChatDatabase&) = delete; + AIChatDatabase& operator=(const AIChatDatabase&) = delete; + ~AIChatDatabase(); + + // Gets lightweight metadata for all conversations. No high-memory-consuming + // data is returned. + std::vector GetAllConversations(); + + // Gets all data needed to rehydrate a conversation + mojom::ConversationArchivePtr GetConversationData( + std::string_view conversation_uuid); + + // Returns new ID for the provided entry and any provided associated content + bool AddConversation(mojom::ConversationPtr conversation, + std::optional contents, + mojom::ConversationTurnPtr first_entry); + + // Update any properties of associated content metadata or full-text content + bool AddOrUpdateAssociatedContent(std::string_view conversation_uuid, + mojom::SiteInfoPtr associated_content, + std::optional content); + + // Adds a new conversation entry to the conversation with the provided UUID + bool AddConversationEntry( + std::string_view conversation_uuid, + mojom::ConversationTurnPtr entry, + std::optional model_key = std::nullopt, + std::optional editing_id = std::nullopt); + + // Updates the title of the conversation with the provided UUID + bool UpdateConversationTitle(std::string_view conversation_uuid, + std::string_view title); + + // Deletes the conversation with the provided UUID + bool DeleteConversation(std::string_view conversation_uuid); + + // Deletes the conversation entry with the provided ID and all associated + // edits and events. + bool DeleteConversationEntry(std::string_view conversation_entry_uuid); + + // Drops all data and tables in the database, and re-creates empty tables + bool DeleteAllData(); + + bool DeleteAssociatedWebContent(std::optional begin_time, + std::optional end_time); + + private: + friend class AIChatDatabaseTest; + + sql::Database& GetDB(); + + // Initializes the database if it hasn't been initialized yet. If |re_init| + // is true, it will forget previous intiialization state and attempt to + // re-initialize the database (e.g. after a table deletion). + bool LazyInit(bool re_init = false); + sql::InitStatus InitInternal(); + + std::vector GetConversationEntries( + std::string_view conversation_id); + std::vector GetArchiveContentsForConversation( + std::string_view conversation_uuid); + + std::string DecryptColumnToString(sql::Statement& statement, int index); + std::optional DecryptOptionalColumnToString( + sql::Statement& statement, + int index); + void BindAndEncryptOptionalString(sql::Statement& statement, + int index, + std::optional value); + bool BindAndEncryptString(sql::Statement& statement, + int index, + std::string_view value); + + bool CreateSchema(); + + // The directory storing the database. + const base::FilePath db_file_path_; + + // The underlying SQL database + sql::Database db_ GUARDED_BY_CONTEXT(sequence_checker_); + os_crypt_async::Encryptor encryptor_ GUARDED_BY_CONTEXT(sequence_checker_); + // The initialization status of the database. It's not set if never attempted. + std::optional db_init_status_ = std::nullopt; + + // Verifies that all operations happen on the same sequence. + SEQUENCE_CHECKER(sequence_checker_); +}; + +} // namespace ai_chat + +#endif // BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_AI_CHAT_DATABASE_H_ diff --git a/components/ai_chat/core/browser/ai_chat_database_unittest.cc b/components/ai_chat/core/browser/ai_chat_database_unittest.cc new file mode 100644 index 000000000000..b14343299231 --- /dev/null +++ b/components/ai_chat/core/browser/ai_chat_database_unittest.cc @@ -0,0 +1,442 @@ +// Copyright (c) 2024 The Brave Authors. All rights reserved. +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. + +#include "brave/components/ai_chat/core/browser/ai_chat_database.h" + +#include + +#include +#include +#include +#include +#include + +#include "base/containers/flat_tree.h" +#include "base/files/file_path.h" +#include "base/files/file_util.h" +#include "base/files/scoped_temp_dir.h" +#include "base/path_service.h" +#include "base/run_loop.h" +#include "base/strings/strcat.h" +#include "base/strings/stringprintf.h" +#include "base/test/bind.h" +#include "base/test/task_environment.h" +#include "base/time/time.h" +#include "base/uuid.h" +#include "brave/components/ai_chat/core/browser/test_utils.h" +#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-forward.h" +#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-shared.h" +#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" +#include "components/os_crypt/async/browser/test_utils.h" +#include "sql/init_status.h" +#include "sql/test/test_helpers.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace ai_chat { +class AIChatDatabaseTest : public testing::Test, + public testing::WithParamInterface { + public: + AIChatDatabaseTest() = default; + + void SetUp() override { + CHECK(temp_directory_.CreateUniqueTempDir()); + os_crypt_ = os_crypt_async::GetTestOSCryptAsyncForTesting( + /*is_sync_for_unittests=*/true); + + // Create database when os_crypt is ready + base::RunLoop run_loop; + encryptor_ready_subscription_ = + os_crypt_->GetInstance(base::BindLambdaForTesting( + [&](os_crypt_async::Encryptor encryptor, bool success) { + ASSERT_TRUE(success); + db_ = std::make_unique(db_file_path(), + std::move(encryptor)); + run_loop.Quit(); + })); + run_loop.Run(); + + if (GetParam()) { + db_->DeleteAllData(); + } + } + + void TearDown() override { + // Verify that the db was init successfully and not using default return + // values. + EXPECT_TRUE(IsInitOk()); + + db_.reset(); + CHECK(temp_directory_.Delete()); + } + + bool IsInitOk() { + return (db_->db_init_status_.has_value() && + db_->db_init_status_.value() == sql::InitStatus::INIT_OK); + } + + base::FilePath db_file_path() { + return temp_directory_.GetPath().AppendASCII("ai_chat"); + } + + protected: + base::test::TaskEnvironment task_environment_{ + base::test::TaskEnvironment::TimeSource::MOCK_TIME}; + base::ScopedTempDir temp_directory_; + std::unique_ptr os_crypt_; + base::CallbackListSubscription encryptor_ready_subscription_; + std::unique_ptr db_; + base::FilePath path_; +}; + +INSTANTIATE_TEST_SUITE_P( + , + AIChatDatabaseTest, + // Run all tests with the initial schema and the schema created + // after calling DeleteAllData, to verify the schemas are the same + // and no tables are missing or different. + ::testing::Bool(), + [](const testing::TestParamInfo& info) { + return base::StringPrintf("DropTablesFirst_%s", + info.param ? "Yes" : "No"); + }); + +// Functions tested: +// - AddConversation +// - GetAllConversations +// - GetConversationData +// - AddConversationEntry +// - DeleteConversationEntry +// - DeleteConversation +TEST_P(AIChatDatabaseTest, AddAndGetConversationAndEntries) { + auto now = base::Time::Now(); + + // Do this for both associated content and non-associated content + for (bool has_content : {true, false}) { + SCOPED_TRACE(testing::Message() << (has_content ? "With" : "Without") + << " associated content"); + const std::string uuid = has_content ? "first" : "second"; + const std::string content_uuid = "content"; + // Create the conversation metadata which gets persisted + // when the first entry is asked to be persisted. + // Put an incorrect time value to show that the time from the + // mojom::Conversation is not persisted and instead is taken from the most + // recent entry. + const GURL page_url = GURL("https://example.com/page"); + const std::string expected_contents = "Page contents"; + mojom::SiteInfoPtr associated_content = + has_content + ? mojom::SiteInfo::New( + content_uuid, mojom::ContentType::PageContent, "page title", + page_url.host(), page_url, 62, true, true) + : mojom::SiteInfo::New( + std::nullopt, mojom::ContentType::PageContent, std::nullopt, + std::nullopt, std::nullopt, 0, false, false); + const mojom::ConversationPtr metadata = + mojom::Conversation::New(uuid, "title", now - base::Hours(2), true, + std::nullopt, std::move(associated_content)); + + // Persist the first entry (and get the response ready) + auto history = CreateSampleChatHistory(1u); + + EXPECT_TRUE(db_->AddConversation( + metadata->Clone(), + has_content ? std::make_optional(expected_contents) : std::nullopt, + history[0]->Clone())); + + // Test getting the conversation metadata + std::vector conversations = + db_->GetAllConversations(); + EXPECT_EQ(conversations.size(), has_content ? 1u : 2u); + auto& conversation = has_content ? conversations[0] : conversations[1]; + ExpectConversationEquals(FROM_HERE, conversation, metadata); + EXPECT_EQ(conversation->updated_time, history.front()->created_time); + + // Persist the response entry + EXPECT_TRUE(db_->AddConversationEntry(uuid, history[1]->Clone())); + + // Test getting the conversation entries + mojom::ConversationArchivePtr result = db_->GetConversationData(uuid); + ExpectConversationHistoryEquals(FROM_HERE, result->entries, history); + EXPECT_EQ(result->associated_content.size(), has_content ? 1u : 0u); + if (has_content) { + EXPECT_EQ(result->associated_content[0]->content_uuid, content_uuid); + EXPECT_EQ(result->associated_content[0]->content, expected_contents); + } + + // Add another pair of entries + auto next_history = CreateSampleChatHistory(1u, 1); + // Change the model this time + std::string new_model_key = "model-2"; + EXPECT_TRUE(db_->AddConversationEntry(uuid, next_history[0]->Clone(), + new_model_key)); + EXPECT_TRUE(db_->AddConversationEntry(uuid, next_history[1]->Clone())); + + // Verify all entries are returned + mojom::ConversationArchivePtr result_2 = db_->GetConversationData(uuid); + for (auto& entry : next_history) { + history.push_back(std::move(entry)); + } + ExpectConversationHistoryEquals(FROM_HERE, result_2->entries, history); + + // Verify metadata now has new model key + conversations = db_->GetAllConversations(); + EXPECT_EQ(conversations.size(), has_content ? 1u : 2u); + ExpectConversationEquals( + FROM_HERE, has_content ? conversations[0] : conversations[1], metadata); + + // Edits (delete, re-add and check edit re-construction) + + // Delete the last response + EXPECT_TRUE( + db_->DeleteConversationEntry(result_2->entries.back()->uuid.value())); + + // Verify the last entry is gone + history.pop_back(); + mojom::ConversationArchivePtr result_3 = db_->GetConversationData(uuid); + ExpectConversationHistoryEquals(FROM_HERE, result_3->entries, history); + + // Add an edit to the last query + { + auto& last_query = result_3->entries.back(); + last_query->edits = std::vector{}; + last_query->edits->emplace_back(mojom::ConversationTurn::New( + base::Uuid::GenerateRandomV4().AsLowercaseString(), + mojom::CharacterType::HUMAN, mojom::ActionType::QUERY, + mojom::ConversationTurnVisibility::VISIBLE, "edited query 1", + std::nullopt, std::nullopt, base::Time::Now() + base::Minutes(121), + std::nullopt, false)); + EXPECT_TRUE(db_->DeleteConversationEntry(last_query->uuid.value())); + EXPECT_TRUE(db_->AddConversationEntry(uuid, last_query->Clone())); + } + // Verify the edit is persisted + mojom::ConversationArchivePtr result_4 = db_->GetConversationData(uuid); + ExpectConversationHistoryEquals(FROM_HERE, result_4->entries, + result_3->entries); + + // Add another edit to test multiple edits for the same turn + { + auto& last_query = result_4->entries.back(); + last_query->edits->emplace_back(mojom::ConversationTurn::New( + base::Uuid::GenerateRandomV4().AsLowercaseString(), + mojom::CharacterType::HUMAN, mojom::ActionType::QUERY, + mojom::ConversationTurnVisibility::VISIBLE, "edited query 2", + std::nullopt, std::nullopt, base::Time::Now() + base::Minutes(122), + std::nullopt, false)); + EXPECT_TRUE(db_->DeleteConversationEntry(last_query->uuid.value())); + EXPECT_TRUE(db_->AddConversationEntry(uuid, last_query->Clone())); + } + // Verify multiple edits are persisted + mojom::ConversationArchivePtr result_5 = db_->GetConversationData(uuid); + ExpectConversationHistoryEquals(FROM_HERE, result_5->entries, + result_4->entries); + } + + // Test deleting conversation (after loop so that we can test conversation + // entry selection with multiple conversations in the database). + EXPECT_TRUE(db_->DeleteConversation("second")); + // Verify no data for deleted conversation + mojom::ConversationArchivePtr conversation_data = + db_->GetConversationData("second"); + EXPECT_EQ(conversation_data->entries.size(), 0u); + EXPECT_EQ(conversation_data->associated_content.size(), 0u); + // Verify deleted conversation metadata not returned + std::vector conversations = + db_->GetAllConversations(); + EXPECT_EQ(conversations.size(), 1u); + EXPECT_EQ(conversations[0]->uuid, "first"); + // Verify there's still data for other conversations + mojom::ConversationArchivePtr conversation_data_2 = + db_->GetConversationData("first"); + EXPECT_GT(conversation_data_2->entries.size(), 0u); + EXPECT_EQ(conversation_data_2->associated_content.size(), 1u); + // Delete last conversation + EXPECT_TRUE(db_->DeleteConversation("first")); + conversations = db_->GetAllConversations(); + EXPECT_EQ(conversations.size(), 0u); +} + +TEST_P(AIChatDatabaseTest, UpdateConversationTitle) { + const std::vector initial_titles = {"first title", ""}; + for (const auto& initial_title : initial_titles) { + const std::string uuid = + base::StrCat({"for_conversation_title_", initial_title}); + const std::string updated_title = "updated title"; + mojom::ConversationPtr metadata = mojom::Conversation::New( + uuid, initial_title, base::Time::Now(), true, std::nullopt, + mojom::SiteInfo::New(std::nullopt, mojom::ContentType::PageContent, + std::nullopt, std::nullopt, std::nullopt, 0, false, + false)); + + // Persist the first entry (and get the response ready) + const auto history = CreateSampleChatHistory(1u); + + EXPECT_TRUE(db_->AddConversation(metadata->Clone(), std::nullopt, + history[0]->Clone())); + + // Verify initial title + std::vector conversations = + db_->GetAllConversations(); + // get this conversation + auto* conversation = GetConversation(FROM_HERE, conversations, uuid); + EXPECT_EQ(conversation->title, initial_title); + + // Update title + EXPECT_TRUE(db_->UpdateConversationTitle(uuid, updated_title)); + // Verify + conversations = db_->GetAllConversations(); + conversation = GetConversation(FROM_HERE, conversations, uuid); + EXPECT_EQ(conversation->title, updated_title); + } +} + +TEST_P(AIChatDatabaseTest, AddOrUpdateAssociatedContent) { + const std::string uuid = "for_associated_content"; + const std::string content_uuid = "content_uuid"; + const GURL page_url = GURL("https://example.com/page"); + mojom::ConversationPtr metadata = mojom::Conversation::New( + uuid, "title", base::Time::Now() - base::Hours(2), true, std::nullopt, + mojom::SiteInfo::New(content_uuid, mojom::ContentType::PageContent, + "page title", page_url.host(), page_url, 62, true, + true)); + + auto history = CreateSampleChatHistory(1u); + + std::string expected_contents = "First contents"; + EXPECT_TRUE(db_->AddConversation(metadata->Clone(), + std::make_optional(expected_contents), + history[0]->Clone())); + + // Verify data is persisted + mojom::ConversationArchivePtr result = db_->GetConversationData(uuid); + EXPECT_EQ(result->associated_content.size(), 1u); + EXPECT_EQ(result->associated_content[0]->content_uuid, content_uuid); + EXPECT_EQ(result->associated_content[0]->content, expected_contents); + auto conversations = db_->GetAllConversations(); + EXPECT_EQ(conversations.size(), 1u); + ExpectConversationEquals(FROM_HERE, conversations[0], metadata); + + // Change data and call AddOrUpdateAssociatedContent + expected_contents = "Second contents"; + metadata->associated_content->content_used_percentage = 50; + metadata->associated_content->is_content_refined = false; + EXPECT_TRUE(db_->AddOrUpdateAssociatedContent( + uuid, metadata->associated_content->Clone(), + std::make_optional(expected_contents))); + // Verify data is changed + result = db_->GetConversationData(uuid); + EXPECT_EQ(result->associated_content.size(), 1u); + EXPECT_EQ(result->associated_content[0]->content_uuid, + metadata->associated_content->uuid.value()); + EXPECT_EQ(result->associated_content[0]->content, expected_contents); + conversations = db_->GetAllConversations(); + EXPECT_EQ(conversations.size(), 1u); + ExpectConversationEquals(FROM_HERE, conversations[0], metadata); +} + +TEST_P(AIChatDatabaseTest, DeleteAllData) { + const std::string uuid = "first"; + const GURL page_url = GURL("https://example.com/page"); + mojom::ConversationPtr metadata = mojom::Conversation::New( + uuid, "title", base::Time::Now() - base::Hours(2), true, std::nullopt, + mojom::SiteInfo::New(std::nullopt, mojom::ContentType::PageContent, + std::nullopt, std::nullopt, std::nullopt, 0, false, + false)); + + auto history = CreateSampleChatHistory(1u); + + EXPECT_TRUE(db_->AddConversation(metadata->Clone(), std::nullopt, + history[0]->Clone())); + + // Verify data is persisted + { + mojom::ConversationArchivePtr result = db_->GetConversationData(uuid); + + ExpectConversationEntryEquals(FROM_HERE, result->entries[0], history[0]); + auto conversations = db_->GetAllConversations(); + EXPECT_EQ(conversations.size(), 1u); + ExpectConversationEquals(FROM_HERE, conversations[0], metadata); + } + + // Delete all data + db_->DeleteAllData(); + + // Verify no data + { + auto conversations = db_->GetAllConversations(); + EXPECT_EQ(conversations.size(), 0u); + mojom::ConversationArchivePtr result = db_->GetConversationData(uuid); + EXPECT_EQ(result->entries.size(), 0u); + } +} + +TEST_P(AIChatDatabaseTest, DeleteAssociatedWebContent) { + GURL page_url = GURL("https://example.com/page"); + std::string expected_contents = "First contents"; + + // The times in the Conversation are irrelevant, only the times of the entries + // are persisted. + mojom::ConversationPtr metadata_first = mojom::Conversation::New( + "first", "title", base::Time::Now() - base::Hours(2), true, std::nullopt, + mojom::SiteInfo::New("first-content", mojom::ContentType::PageContent, + "page title", page_url.host(), page_url, 62, true, + true)); + mojom::ConversationPtr metadata_second = mojom::Conversation::New( + "second", "title", base::Time::Now() - base::Hours(1), true, "model-2", + mojom::SiteInfo::New("second-content", mojom::ContentType::PageContent, + "page title", page_url.host(), page_url, 62, true, + true)); + + auto history_first = CreateSampleChatHistory(1u, -2); + auto history_second = CreateSampleChatHistory(1u, -1); + + EXPECT_TRUE(db_->AddConversation(metadata_first->Clone(), + std::make_optional(expected_contents), + history_first[0]->Clone())); + + EXPECT_TRUE(db_->AddConversation(metadata_second->Clone(), + std::make_optional(expected_contents), + history_second[0]->Clone())); + + // Verify data is persisted + auto conversations = db_->GetAllConversations(); + EXPECT_EQ(conversations.size(), 2u); + ExpectConversationEquals(FROM_HERE, conversations[0], metadata_first); + ExpectConversationEquals(FROM_HERE, conversations[1], metadata_second); + + mojom::ConversationArchivePtr archive_result = + db_->GetConversationData("first"); + EXPECT_EQ(archive_result->associated_content.size(), 1u); + EXPECT_EQ(archive_result->associated_content[0]->content_uuid, + "first-content"); + EXPECT_EQ(archive_result->associated_content[0]->content, expected_contents); + archive_result = db_->GetConversationData("second"); + EXPECT_EQ(archive_result->associated_content.size(), 1u); + EXPECT_EQ(archive_result->associated_content[0]->content_uuid, + "second-content"); + EXPECT_EQ(archive_result->associated_content[0]->content, expected_contents); + + // Delete associated content to only consider the second conversation + EXPECT_TRUE(db_->DeleteAssociatedWebContent( + base::Time::Now() + base::Minutes(-61), std::nullopt)); + + // Verify only url, title and content was deleted and only from the second + // conversation + conversations = db_->GetAllConversations(); + EXPECT_EQ(conversations.size(), 2u); + ExpectConversationEquals(FROM_HERE, conversations[0], metadata_first); + metadata_second->associated_content->url = std::nullopt; + metadata_second->associated_content->title = std::nullopt; + ExpectConversationEquals(FROM_HERE, conversations[1], metadata_second); + + archive_result = db_->GetConversationData("second"); + EXPECT_TRUE(archive_result->associated_content.empty()); + archive_result = db_->GetConversationData("first"); + EXPECT_EQ(archive_result->associated_content.size(), 1u); + EXPECT_EQ(archive_result->associated_content[0]->content_uuid, + "first-content"); + EXPECT_EQ(archive_result->associated_content[0]->content, expected_contents); +} + +} // namespace ai_chat diff --git a/components/ai_chat/core/browser/ai_chat_feedback_api.cc b/components/ai_chat/core/browser/ai_chat_feedback_api.cc index aeb113bcc0d9..25945ae64e1c 100644 --- a/components/ai_chat/core/browser/ai_chat_feedback_api.cc +++ b/components/ai_chat/core/browser/ai_chat_feedback_api.cc @@ -5,19 +5,23 @@ #include "brave/components/ai_chat/core/browser/ai_chat_feedback_api.h" +#include #include +#include "base/containers/checked_iterators.h" #include "base/containers/flat_map.h" #include "base/json/json_writer.h" +#include "base/memory/scoped_refptr.h" #include "base/no_destructor.h" #include "base/strings/strcat.h" +#include "base/time/time.h" #include "base/values.h" #include "brave/brave_domains/service_domains.h" #include "brave/components/brave_stats/browser/brave_stats_updater_util.h" #include "brave/components/l10n/common/locale_util.h" +#include "mojo/public/cpp/bindings/struct_ptr.h" #include "net/traffic_annotation/network_traffic_annotation.h" #include "services/network/public/cpp/shared_url_loader_factory.h" -#include "services/network/public/cpp/simple_url_loader.h" #include "url/gurl.h" #include "url/url_constants.h" diff --git a/components/ai_chat/core/browser/ai_chat_feedback_api.h b/components/ai_chat/core/browser/ai_chat_feedback_api.h index 883f55952f76..f2ac6f798a05 100644 --- a/components/ai_chat/core/browser/ai_chat_feedback_api.h +++ b/components/ai_chat/core/browser/ai_chat_feedback_api.h @@ -6,11 +6,16 @@ #ifndef BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_AI_CHAT_FEEDBACK_API_H_ #define BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_AI_CHAT_FEEDBACK_API_H_ +#include #include +#include "base/containers/span.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" #include "brave/components/api_request_helper/api_request_helper.h" +template +class scoped_refptr; + namespace network { class SharedURLLoaderFactory; } // namespace network diff --git a/components/ai_chat/core/browser/ai_chat_metrics.cc b/components/ai_chat/core/browser/ai_chat_metrics.cc index c4ec714c9a9b..08c9f1547084 100644 --- a/components/ai_chat/core/browser/ai_chat_metrics.cc +++ b/components/ai_chat/core/browser/ai_chat_metrics.cc @@ -5,14 +5,28 @@ #include "brave/components/ai_chat/core/browser/ai_chat_metrics.h" +#include + #include +#include +#include +#include +#include +#include #include +#include +#include #include +#include "base/check.h" #include "base/containers/fixed_flat_map.h" -#include "base/metrics/histogram_functions.h" +#include "base/functional/bind.h" +#include "base/location.h" +#include "base/metrics/histogram_base.h" +#include "base/metrics/histogram_functions_internal_overloads.h" #include "base/metrics/histogram_macros.h" -#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-shared.h" +#include "base/numerics/clamped_math.h" +#include "base/time/time.h" #include "brave/components/ai_chat/core/common/pref_names.h" #include "brave/components/p3a_utils/bucket.h" #include "brave/components/p3a_utils/feature_usage.h" @@ -28,15 +42,15 @@ using sidebar::features::SidebarDefaultMode; constexpr base::TimeDelta kReportInterval = base::Hours(24); constexpr base::TimeDelta kReportDebounceDelay = base::Seconds(3); -const int kChatCountBuckets[] = {1, 5, 10, 20, 50}; -const int kAvgPromptCountBuckets[] = {2, 5, 10, 20}; +constexpr int kChatCountBuckets[] = {1, 5, 10, 20, 50}; +constexpr int kAvgPromptCountBuckets[] = {2, 5, 10, 20}; constexpr base::TimeDelta kPremiumCheckInterval = base::Days(1); #if !BUILDFLAG(IS_ANDROID) && !BUILDFLAG(IS_IOS) // Value -1 is added to buckets to add padding for the "less than 1% option" -const int kOmniboxOpenBuckets[] = {-1, 0, 3, 5, 10, 25}; -const int kContextMenuUsageBuckets[] = {0, 1, 2, 5, 10, 20, 50}; +constexpr int kOmniboxOpenBuckets[] = {-1, 0, 3, 5, 10, 25}; +constexpr int kContextMenuUsageBuckets[] = {0, 1, 2, 5, 10, 20, 50}; constexpr char kSummarizeActionKey[] = "summarize"; constexpr char kExplainActionKey[] = "explain"; diff --git a/components/ai_chat/core/browser/ai_chat_metrics.h b/components/ai_chat/core/browser/ai_chat_metrics.h index 20bd0d11cda4..096851b35057 100644 --- a/components/ai_chat/core/browser/ai_chat_metrics.h +++ b/components/ai_chat/core/browser/ai_chat_metrics.h @@ -10,10 +10,15 @@ #include #include "base/containers/flat_map.h" +#include "base/functional/callback.h" +#include "base/memory/raw_ptr.h" #include "base/memory/weak_ptr.h" +#include "base/timer/timer.h" #include "base/timer/wall_clock_timer.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" +#include "brave/components/time_period_storage/time_period_storage.h" #include "brave/components/time_period_storage/weekly_storage.h" +#include "build/build_config.h" class PrefRegistrySimple; class PrefService; diff --git a/components/ai_chat/core/browser/ai_chat_metrics_unittest.cc b/components/ai_chat/core/browser/ai_chat_metrics_unittest.cc index 759f71d7c3da..2935e978cefb 100644 --- a/components/ai_chat/core/browser/ai_chat_metrics_unittest.cc +++ b/components/ai_chat/core/browser/ai_chat_metrics_unittest.cc @@ -5,15 +5,19 @@ #include "brave/components/ai_chat/core/browser/ai_chat_metrics.h" +#include + #include -#include +#include +#include +#include #include -#include "base/memory/raw_ptr.h" +#include "base/numerics/clamped_math.h" #include "base/test/bind.h" #include "base/test/metrics/histogram_tester.h" +#include "base/test/task_environment.h" #include "base/time/time.h" -#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-shared.h" #include "components/prefs/testing_pref_service.h" #include "content/public/test/browser_task_environment.h" #include "testing/gtest/include/gtest/gtest.h" diff --git a/components/ai_chat/core/browser/ai_chat_service.cc b/components/ai_chat/core/browser/ai_chat_service.cc index 4f5c5954b8fa..b0d0884d9006 100644 --- a/components/ai_chat/core/browser/ai_chat_service.cc +++ b/components/ai_chat/core/browser/ai_chat_service.cc @@ -5,31 +5,87 @@ #include "brave/components/ai_chat/core/browser/ai_chat_service.h" +#include +#include #include -#include +#include #include +#include +#include #include #include +#include "base/check.h" +#include "base/containers/adapters.h" +#include "base/containers/contains.h" +#include "base/containers/fixed_flat_set.h" +#include "base/functional/bind.h" +#include "base/functional/callback_helpers.h" +#include "base/logging.h" +#include "base/notreached.h" +#include "base/numerics/clamped_math.h" +#include "base/ranges/algorithm.h" +#include "base/task/task_traits.h" +#include "base/task/thread_pool.h" #include "base/time/time.h" +#include "base/types/strong_alias.h" #include "base/uuid.h" #include "brave/components/ai_chat/core/browser/ai_chat_credential_manager.h" +#include "brave/components/ai_chat/core/browser/ai_chat_database.h" +#include "brave/components/ai_chat/core/browser/ai_chat_metrics.h" #include "brave/components/ai_chat/core/browser/constants.h" #include "brave/components/ai_chat/core/browser/conversation_handler.h" #include "brave/components/ai_chat/core/browser/model_service.h" #include "brave/components/ai_chat/core/browser/utils.h" #include "brave/components/ai_chat/core/common/features.h" +#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-forward.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-shared.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" #include "brave/components/ai_chat/core/common/pref_names.h" +#include "build/build_config.h" +#include "components/os_crypt/async/browser/os_crypt_async.h" #include "components/prefs/pref_service.h" +#include "mojo/public/cpp/bindings/pending_receiver.h" +#include "mojo/public/cpp/bindings/struct_ptr.h" +#include "services/network/public/cpp/shared_url_loader_factory.h" +#include "url/gurl.h" +#include "url/url_constants.h" namespace ai_chat { namespace { +constexpr base::FilePath::StringPieceType kDBFileName = + FILE_PATH_LITERAL("AIChat"); + constexpr auto kAllowedSchemes = base::MakeFixedFlatSet( {url::kHttpsScheme, url::kHttpScheme, url::kFileScheme, url::kDataScheme}); +std::vector FilterVisibleConversations( + std::map& conversations_map) { + std::vector conversations; + for (const auto& kv : conversations_map) { + auto& conversation = kv.second; + // Conversations are only visible if they have content + if (!conversation->has_content) { + continue; + } + conversations.push_back(conversation.get()); + } + base::ranges::sort(conversations, std::greater<>(), + &mojom::Conversation::updated_time); + return conversations; +} + +bool IsConversationUpdatedTimeWithinRange( + std::optional begin_time, + std::optional end_time, + mojom::ConversationPtr& conversation) { + return ((!begin_time.has_value() || begin_time->is_null() || + conversation->updated_time >= begin_time) && + (!end_time.has_value() || end_time->is_null() || end_time->is_max() || + conversation->updated_time <= end_time)); +} + } // namespace AIChatService::AIChatService( @@ -37,22 +93,36 @@ AIChatService::AIChatService( std::unique_ptr ai_chat_credential_manager, PrefService* profile_prefs, AIChatMetrics* ai_chat_metrics, + os_crypt_async::OSCryptAsync* os_crypt_async, scoped_refptr url_loader_factory, - std::string_view channel_string) + std::string_view channel_string, + base::FilePath profile_path) : model_service_(model_service), profile_prefs_(profile_prefs), ai_chat_metrics_(ai_chat_metrics), + os_crypt_async_(os_crypt_async), url_loader_factory_(url_loader_factory), feedback_api_( std::make_unique(url_loader_factory_, std::string(channel_string))), - credential_manager_(std::move(ai_chat_credential_manager)) { + credential_manager_(std::move(ai_chat_credential_manager)), + profile_path_(profile_path) { DCHECK(profile_prefs_); pref_change_registrar_.Init(profile_prefs_); pref_change_registrar_.Add( prefs::kLastAcceptedDisclaimer, base::BindRepeating(&AIChatService::OnUserOptedIn, weak_ptr_factory_.GetWeakPtr())); + pref_change_registrar_.Add( + prefs::kStorageEnabled, + base::BindRepeating(&AIChatService::MaybeInitStorage, + weak_ptr_factory_.GetWeakPtr())); + pref_change_registrar_.Add( + prefs::kUserDismissedPremiumPrompt, + base::BindRepeating(&AIChatService::OnStateChanged, + weak_ptr_factory_.GetWeakPtr())); + + MaybeInitStorage(); } AIChatService::~AIChatService() = default; @@ -70,6 +140,10 @@ void AIChatService::Bind(mojo::PendingReceiver receiver) { void AIChatService::Shutdown() { // Disconnect remotes receivers_.ClearWithReason(0, "Shutting down"); + weak_ptr_factory_.InvalidateWeakPtrs(); + if (ai_chat_db_) { + ai_chat_db_.Reset(); + } } ConversationHandler* AIChatService::CreateConversation() { @@ -78,7 +152,10 @@ ConversationHandler* AIChatService::CreateConversation() { // Create the conversation metadata { mojom::ConversationPtr conversation = mojom::Conversation::New( - conversation_uuid, "", base::Time::Now(), false); + conversation_uuid, "", base::Time::Now(), false, std::nullopt, + mojom::SiteInfo::New(base::Uuid::GenerateRandomV4().AsLowercaseString(), + mojom::ContentType::PageContent, std::nullopt, + std::nullopt, std::nullopt, 0, false, false)); conversations_.insert_or_assign(conversation_uuid, std::move(conversation)); } mojom::Conversation* conversation = @@ -107,14 +184,74 @@ ConversationHandler* AIChatService::CreateConversation() { } ConversationHandler* AIChatService::GetConversation( - const std::string& conversation_uuid) { - auto conversation_handler_it = conversation_handlers_.find(conversation_uuid); + std::string_view conversation_uuid) { + auto conversation_handler_it = + conversation_handlers_.find(conversation_uuid.data()); if (conversation_handler_it == conversation_handlers_.end()) { return nullptr; } return conversation_handler_it->second.get(); } +void AIChatService::GetConversation( + std::string_view conversation_uuid, + base::OnceCallback callback) { + if (ConversationHandler* cached_conversation = + GetConversation(conversation_uuid)) { + DVLOG(4) << __func__ << " found cached conversation for " + << conversation_uuid; + std::move(callback).Run(cached_conversation); + return; + } + // Load from database + if (!ai_chat_db_) { + std::move(callback).Run(nullptr); + return; + } + LoadConversationsLazy(base::BindOnce( + [](base::WeakPtr instance, std::string conversation_uuid, + base::OnceCallback callback, + ConversationMap& conversations) { + auto conversation_it = conversations.find(conversation_uuid); + if (conversation_it == conversations.end()) { + std::move(callback).Run(nullptr); + return; + } + mojom::ConversationPtr& metadata = conversation_it->second; + // Get archive content and conversation entries + instance->ai_chat_db_.AsyncCall(&AIChatDatabase::GetConversationData) + .WithArgs(metadata->uuid) + .Then(base::BindOnce(&AIChatService::OnConversationDataReceived, + std::move(instance), metadata->uuid, + std::move(callback))); + }, + weak_ptr_factory_.GetWeakPtr(), std::string(conversation_uuid), + std::move(callback))); +} + +void AIChatService::OnConversationDataReceived( + std::string conversation_uuid, + base::OnceCallback callback, + mojom::ConversationArchivePtr data) { + DVLOG(4) << __func__ << " for " << conversation_uuid + << " with data: " << data->entries.size() << " entries and " + << data->associated_content.size() << " contents"; + auto conversation_it = conversations_.find(conversation_uuid); + if (conversation_it == conversations_.end()) { + std::move(callback).Run(nullptr); + return; + } + mojom::Conversation* conversation = conversation_it->second.get(); + std::unique_ptr conversation_handler = + std::make_unique( + conversation, this, model_service_, credential_manager_.get(), + feedback_api_.get(), url_loader_factory_, std::move(data)); + conversation_observations_.AddObservation(conversation_handler.get()); + conversation_handlers_.insert_or_assign(conversation_uuid, + std::move(conversation_handler)); + std::move(callback).Run(GetConversation(conversation_uuid)); +} + ConversationHandler* AIChatService::GetOrCreateConversationHandlerForContent( int associated_content_id, base::WeakPtr @@ -152,6 +289,249 @@ ConversationHandler* AIChatService::CreateConversationHandlerForContent( return conversation; } +void AIChatService::DeleteConversations(std::optional begin_time, + std::optional end_time) { + if (!begin_time.has_value() && !end_time.has_value()) { + // Delete all conversations + // Delete in-memory data + conversation_observations_.RemoveAllObservations(); + conversation_handlers_.clear(); + conversations_.clear(); + content_conversations_.clear(); + + // Delete database data + if (ai_chat_db_) { + ai_chat_db_.AsyncCall(base::IgnoreResult(&AIChatDatabase::DeleteAllData)); + ReloadConversations(); + } + if (ai_chat_metrics_ != nullptr) { + ai_chat_metrics_->RecordReset(); + } + OnConversationListChanged(); + return; + } + + // Get all keys from conversations_ + std::vector conversation_keys; + + for (auto& [uuid, conversation] : conversations_) { + if (IsConversationUpdatedTimeWithinRange(begin_time, end_time, + conversation)) { + conversation_keys.push_back(uuid); + } + } + + for (const auto& uuid : conversation_keys) { + DeleteConversation(uuid); + } + if (!conversation_keys.empty()) { + OnConversationListChanged(); + } +} + +void AIChatService::DeleteAssociatedWebContent( + std::optional begin_time, + std::optional end_time, + base::OnceCallback callback) { + if (!ai_chat_db_) { + std::move(callback).Run(true); + return; + } + + ai_chat_db_.AsyncCall(&AIChatDatabase::DeleteAssociatedWebContent) + .WithArgs(begin_time, end_time) + .Then(std::move(callback)); + + // Update local data + ReloadConversations(); +} + +void AIChatService::MaybeInitStorage() { + if (IsAIChatHistoryEnabled()) { + if (!ai_chat_db_) { + DVLOG(0) << "Initializing OS Crypt Async"; + encryptor_ready_subscription_ = os_crypt_async_->GetInstance( + base::BindOnce(&AIChatService::OnOsCryptAsyncReady, + weak_ptr_factory_.GetWeakPtr())); + // Don't init DB until oscrypt is ready - we don't want to use the DB + // if we can't use encryption. + } + } else { + // Delete all stored data from database + if (ai_chat_db_) { + DVLOG(0) << "Unloading AI Chat database due to pref change"; + base::SequenceBound ai_chat_db = std::move(ai_chat_db_); + ai_chat_db.AsyncCall(&AIChatDatabase::DeleteAllData) + .Then(base::BindOnce(&AIChatService::OnDataDeletedForDisabledStorage, + weak_ptr_factory_.GetWeakPtr())); + } + } + OnStateChanged(); +} + +void AIChatService::OnOsCryptAsyncReady(os_crypt_async::Encryptor encryptor, + bool success) { + CHECK(features::IsAIChatHistoryEnabled()); + if (!success) { + LOG(ERROR) << "Failed to initialize AIChat DB due to OSCrypt failure"; + return; + } + // Pref might have changed since we started this process + if (!profile_prefs_->GetBoolean(prefs::kStorageEnabled)) { + return; + } + ai_chat_db_ = base::SequenceBound( + base::ThreadPool::CreateSequencedTaskRunner( + {base::MayBlock(), base::WithBaseSyncPrimitives(), + base::TaskPriority::BEST_EFFORT, + base::TaskShutdownBehavior::BLOCK_SHUTDOWN}), + profile_path_.Append(kDBFileName), std::move(encryptor)); +} + +void AIChatService::OnDataDeletedForDisabledStorage(bool success) { + // Remove any conversations from in-memory that aren't connected to UI. + // This is done now, in the callback from DeleteAllData, in case there + // was any in-progress operations that would have resulted in adding data + // back to conversations_ whilst waiting for DeleteAllData to complete. + std::vector all_conversation_handlers; + for (auto& [_, conversation_handler] : conversation_handlers_) { + all_conversation_handlers.push_back(conversation_handler.get()); + } + for (auto* conversation_handler : all_conversation_handlers) { + MaybeUnloadConversation(conversation_handler); + } + // Remove any conversation metadata that isn't connected to a still-alive + // handler. + for (auto it = conversations_.begin(); it != conversations_.end();) { + if (!conversation_handlers_.contains(it->first)) { + it = conversations_.erase(it); + } else { + ++it; + } + } + OnConversationListChanged(); + // Re-check the preference since it could have been re-enabled + // whilst the database operation was in progress. If so, we can re-use + // the same database instance (post data deletion). + if (!IsAIChatHistoryEnabled()) { + // If there is a LoadConversationsLazy in progress, it will get cancelled + // on destruction of ai_chat_db_ so call the callbacks. + if (on_conversations_loaded_callbacks_.has_value() && + !on_conversations_loaded_callbacks_->empty()) { + for (auto& callback : on_conversations_loaded_callbacks_.value()) { + std::move(callback).Run(conversations_); + } + } + + ai_chat_db_.Reset(); + cancel_conversation_load_callback_ = base::NullCallback(); + on_conversations_loaded_callbacks_ = std::nullopt; + } +} + +void AIChatService::LoadConversationsLazy(ConversationMapCallback callback) { + // Send immediately if we have finished loading from storage + if (!ai_chat_db_ || (on_conversations_loaded_callbacks_.has_value() && + on_conversations_loaded_callbacks_->empty())) { + std::move(callback).Run(conversations_); + return; + } + if (on_conversations_loaded_callbacks_.has_value()) { + on_conversations_loaded_callbacks_->push_back(std::move(callback)); + return; + } + + on_conversations_loaded_callbacks_ = std::vector(); + on_conversations_loaded_callbacks_->push_back(std::move(callback)); + ai_chat_db_.AsyncCall(&AIChatDatabase::GetAllConversations) + .Then(base::BindOnce(&AIChatService::OnLoadConversationsLazyData, + weak_ptr_factory_.GetWeakPtr())); +} + +void AIChatService::OnLoadConversationsLazyData( + std::vector conversations) { + if (!cancel_conversation_load_callback_.is_null()) { + std::move(cancel_conversation_load_callback_).Run(); + cancel_conversation_load_callback_ = base::NullCallback(); + return; + } + DVLOG(1) << "Loaded " << conversations.size() << " conversations."; + for (auto& conversation : conversations) { + std::string uuid = conversation->uuid; + DVLOG(2) << "Loaded conversation " << conversation->uuid + << " with details: " << "\n has content: " + << conversation->has_content + << "\n last updated: " << conversation->updated_time + << "\n title: " << conversation->title; + // It's ok to overwrite existing metadata - some operations may modify + // the database data and we want to keep the in-memory data synchronised. + auto existing_conversation_it = conversations_.find(uuid); + if (existing_conversation_it != conversations_.end()) { + auto& existing_conversation = existing_conversation_it->second; + existing_conversation->title = conversation->title; + existing_conversation->updated_time = conversation->updated_time; + existing_conversation->has_content = conversation->has_content; + existing_conversation->model_key = conversation->model_key; + existing_conversation->associated_content = + std::move(conversation->associated_content); + } else { + conversations_.emplace(uuid, std::move(conversation)); + } + auto handler_it = conversation_handlers_.find(uuid); + if (handler_it != conversation_handlers_.end()) { + // Notify the handler that metadata is possibly changed + ConversationHandler* handler = handler_it->second.get(); + handler->OnConversationMetadataUpdated(); + // If a reload was asked for, then we should also update the deeper + // conversation data from the database, since the reload was likely due + // to underlying data changing. + ai_chat_db_.AsyncCall(&AIChatDatabase::GetConversationData) + .WithArgs(uuid) + .Then(base::BindOnce( + [](base::WeakPtr handler, + mojom::ConversationArchivePtr updated_data) { + if (!handler) { + return; + } + handler->OnArchiveContentUpdated(std::move(updated_data)); + }, + handler->GetWeakPtr())); + } + } + if (on_conversations_loaded_callbacks_.has_value()) { + for (auto& callback : on_conversations_loaded_callbacks_.value()) { + std::move(callback).Run(conversations_); + } + on_conversations_loaded_callbacks_->clear(); + } + OnConversationListChanged(); +} + +void AIChatService::ReloadConversations(bool from_cancel) { + // If in the middle of a conversation load, then make sure data is ignored, + // and ask again when current load is complete. + if (!from_cancel && on_conversations_loaded_callbacks_.has_value() && + !on_conversations_loaded_callbacks_->empty()) { + cancel_conversation_load_callback_ = + base::BindOnce(&AIChatService::ReloadConversations, + weak_ptr_factory_.GetWeakPtr(), true); + return; + } + + // Collect any previous callbacks and force conversations to load again + std::vector previous_callbacks; + if (on_conversations_loaded_callbacks_.has_value()) { + on_conversations_loaded_callbacks_->swap(previous_callbacks); + } + on_conversations_loaded_callbacks_ = std::nullopt; + LoadConversationsLazy(base::DoNothing()); + + // Re-queue any previous callbacks + for (auto& callback : previous_callbacks) { + LoadConversationsLazy(std::move(callback)); + } +} + void AIChatService::MaybeAssociateContentWithConversation( ConversationHandler* conversation, int associated_content_id, @@ -173,6 +553,18 @@ void AIChatService::MarkAgreementAccepted() { SetUserOptedIn(profile_prefs_, true); } +void AIChatService::EnableStoragePref() { + profile_prefs_->SetBoolean(prefs::kStorageEnabled, true); +} + +void AIChatService::DismissStorageNotice() { + profile_prefs_->SetBoolean(prefs::kUserDismissedStorageNotice, true); +} + +void AIChatService::DismissPremiumPrompt() { + profile_prefs_->SetBoolean(prefs::kUserDismissedPremiumPrompt, true); +} + void AIChatService::GetActionMenuList(GetActionMenuListCallback callback) { std::move(callback).Run(ai_chat::GetActionMenuList()); } @@ -183,49 +575,12 @@ void AIChatService::GetPremiumStatus(GetPremiumStatusCallback callback) { weak_ptr_factory_.GetWeakPtr(), std::move(callback))); } -void AIChatService::GetCanShowPremiumPrompt( - GetCanShowPremiumPromptCallback callback) { - bool has_user_dismissed_prompt = - profile_prefs_->GetBoolean(prefs::kUserDismissedPremiumPrompt); - - if (has_user_dismissed_prompt) { - std::move(callback).Run(false); - return; - } - - base::Time last_accepted_disclaimer = - profile_prefs_->GetTime(prefs::kLastAcceptedDisclaimer); - - // Can't show if we haven't accepted disclaimer yet - if (last_accepted_disclaimer.is_null()) { - std::move(callback).Run(false); - return; - } - - base::Time time_1_day_ago = base::Time::Now() - base::Days(1); - bool is_more_than_24h_since_last_seen = - last_accepted_disclaimer < time_1_day_ago; - - if (is_more_than_24h_since_last_seen) { - std::move(callback).Run(true); - return; - } - - std::move(callback).Run(false); -} - -void AIChatService::DismissPremiumPrompt() { - profile_prefs_->SetBoolean(prefs::kUserDismissedPremiumPrompt, true); -} - void AIChatService::DeleteConversation(const std::string& id) { - ConversationHandler* conversation_handler = - conversation_handlers_.at(id).get(); - if (!conversation_handler) { - return; + auto handler_it = conversation_handlers_.find(id); + if (handler_it != conversation_handlers_.end()) { + conversation_observations_.RemoveObservation(handler_it->second.get()); + conversation_handlers_.erase(id); } - conversation_observations_.RemoveObservation(conversation_handler); - conversation_handlers_.erase(id); conversations_.erase(id); DVLOG(1) << "Erased conversation due to deletion request (" << id << "). Now have " << conversations_.size() @@ -233,6 +588,12 @@ void AIChatService::DeleteConversation(const std::string& id) { << conversation_handlers_.size() << " ConversationHandler instances."; OnConversationListChanged(); + // Update database + if (ai_chat_db_) { + ai_chat_db_ + .AsyncCall(base::IgnoreResult(&AIChatDatabase::DeleteConversation)) + .WithArgs(id); + } } void AIChatService::RenameConversation(const std::string& id, @@ -261,79 +622,186 @@ void AIChatService::OnPremiumStatusReceived(GetPremiumStatusCallback callback, last_premium_status_ = status; if (ai_chat::HasUserOptedIn(profile_prefs_) && ai_chat_metrics_ != nullptr) { - ai_chat_metrics_->OnPremiumStatusUpdated(false, status, std::move(info)); + ai_chat_metrics_->OnPremiumStatusUpdated(false, status, info.Clone()); } model_service_->OnPremiumStatus(status); std::move(callback).Run(status, std::move(info)); } -void AIChatService::MaybeEraseConversation( +void AIChatService::MaybeUnloadConversation( ConversationHandler* conversation_handler) { - // Don't unload if there is active UI for the conversation - if (conversation_handler->IsAnyClientConnected()) { - return; - } - - bool has_history = conversation_handler->HasAnyHistory(); - - // We can keep a conversation with history in memory until there is no active - // content. - // TODO(petemill): With the history feature enabled, we should unload (if - // there is no request in progress). However, we can only do this when - // GetOrCreateConversationHandlerForContent allows a callback so that it - // can provide an answer after loading the conversation content from storage. - if (conversation_handler->IsAssociatedContentAlive() && has_history) { - return; - } - - // AIChatHistory feature doesn't yet have persistant storage, so keep - // handlers and data around if it's enabled. - if (!features::IsAIChatHistoryEnabled() || !has_history) { - // Can erase because no active UI and no history, so it's - // not a real / persistable conversation + if (!conversation_handler->IsAnyClientConnected() && + !conversation_handler->IsRequestInProgress()) { + // Can erase handler because no active UI + bool has_history = conversation_handler->HasAnyHistory(); auto uuid = conversation_handler->get_conversation_uuid(); conversation_observations_.RemoveObservation(conversation_handler); conversation_handlers_.erase(uuid); - conversations_.erase(uuid); - std::erase_if(content_conversations_, - [&uuid](const auto& kv) { return kv.second == uuid; }); - DVLOG(1) << "Erased conversation (" << uuid << "). Now have " + DVLOG(1) << "Unloaded conversation (" << uuid << ") from memory. Now have " << conversations_.size() << " Conversation metadata items and " << conversation_handlers_.size() << " ConversationHandler instances."; - OnConversationListChanged(); + if (!IsAIChatHistoryEnabled() || !has_history) { + // Can erase because no active UI and no history, so it's + // not a real / persistable conversation + conversations_.erase(uuid); + std::erase_if(content_conversations_, + [&uuid](const auto& kv) { return kv.second == uuid; }); + DVLOG(1) << "Erased conversation (" << uuid << "). Now have " + << conversations_.size() << " Conversation metadata items and " + << conversation_handlers_.size() + << " ConversationHandler instances."; + OnConversationListChanged(); + } + } else { + DVLOG(4) << "Not unloading conversation (" + << conversation_handler->get_conversation_uuid() + << ") from memory. Has active clients: " + << (conversation_handler->IsAnyClientConnected() ? "yes" : "no") + << " Request is in progress: " + << (conversation_handler->IsRequestInProgress() ? "yes" : "no"); } } -void AIChatService::OnConversationEntriesChanged( +mojom::ServiceStatePtr AIChatService::BuildState() { + bool has_user_dismissed_storage_notice = + profile_prefs_->GetBoolean(prefs::kUserDismissedStorageNotice); + base::Time last_accepted_disclaimer = + profile_prefs_->GetTime(ai_chat::prefs::kLastAcceptedDisclaimer); + + bool is_user_opted_in = !last_accepted_disclaimer.is_null(); + + // Premium prompt is only shown conditionally (e.g. the user hasn't dismissed + // it and it's been some time since the user started using the feature). + bool can_show_premium_prompt = + !profile_prefs_->GetBoolean(prefs::kUserDismissedPremiumPrompt) && + !last_accepted_disclaimer.is_null() && + last_accepted_disclaimer < base::Time::Now() - base::Days(1); + + bool is_storage_enabled = profile_prefs_->GetBoolean(prefs::kStorageEnabled); + + mojom::ServiceStatePtr state = mojom::ServiceState::New(); + state->has_accepted_agreement = is_user_opted_in; + state->is_storage_pref_enabled = is_storage_enabled; + state->is_storage_notice_dismissed = has_user_dismissed_storage_notice; + state->can_show_premium_prompt = can_show_premium_prompt; + return state; +} + +void AIChatService::OnStateChanged() { + mojom::ServiceStatePtr state = BuildState(); + for (auto& remote : observer_remotes_) { + remote->OnStateChanged(state.Clone()); + } +} + +bool AIChatService::IsAIChatHistoryEnabled() { + return (features::IsAIChatHistoryEnabled() && + profile_prefs_->GetBoolean(prefs::kStorageEnabled)); +} + +void AIChatService::OnRequestInProgressChanged(ConversationHandler* handler, + bool in_progress) { + // We don't unload a conversation if it has a request in progress, so check + // again when that changes. + if (!in_progress) { + MaybeUnloadConversation(handler); + } +} + +void AIChatService::OnConversationEntryAdded( ConversationHandler* handler, - std::vector entries) { + mojom::ConversationTurnPtr& entry, + std::optional associated_content_value) { auto conversation_it = conversations_.find(handler->get_conversation_uuid()); CHECK(conversation_it != conversations_.end()); - auto& conversation = conversation_it->second; - if (!entries.empty()) { - // This conversation is visible once the first response begins - conversation->has_content = true; - if (ai_chat_metrics_ != nullptr) { - // Each time the user starts a conversation - if (entries.size() == 1) { - ai_chat_metrics_->RecordNewChat(); - } - // Each time the user submits an entry - if (entries.back()->character_type == mojom::CharacterType::HUMAN) { - ai_chat_metrics_->RecordNewPrompt(); - } + mojom::ConversationPtr& conversation = conversation_it->second; + + if (!conversation->has_content) { + HandleFirstEntry(handler, entry, associated_content_value, conversation); + } else { + HandleNewEntry(handler, entry, associated_content_value, conversation); + } + + conversation->has_content = true; + conversation->updated_time = entry->created_time; + OnConversationListChanged(); +} + +void AIChatService::HandleFirstEntry( + ConversationHandler* handler, + mojom::ConversationTurnPtr& entry, + std::optional associated_content_value, + mojom::ConversationPtr& conversation) { + DVLOG(1) << __func__ << " Conversation " << conversation->uuid + << " being persisted for first time."; + CHECK(entry->uuid.has_value()); + CHECK(conversation->associated_content->uuid.has_value()); + // We can persist the conversation metadata for the first time as well as the + // entry. + if (ai_chat_db_) { + ai_chat_db_.AsyncCall(base::IgnoreResult(&AIChatDatabase::AddConversation)) + .WithArgs(conversation->Clone(), + std::optional(associated_content_value), + entry->Clone()); + } + // Record metrics + if (ai_chat_metrics_ != nullptr) { + if (handler->GetConversationHistory().size() == 1) { + ai_chat_metrics_->RecordNewChat(); } - OnConversationListChanged(); - // TODO(petemill): Persist the entries, but consider receiving finer grained - // entry update events. + } +} + +void AIChatService::HandleNewEntry( + ConversationHandler* handler, + mojom::ConversationTurnPtr& entry, + std::optional associated_content_value, + mojom::ConversationPtr& conversation) { + CHECK(entry->uuid.has_value()); + DVLOG(1) << __func__ << " Conversation " << conversation->uuid + << " persisting new entry. Count of entries: " + << handler->GetConversationHistory().size(); + + // Persist the new entry and update the associated content data, if present + if (ai_chat_db_) { + ai_chat_db_ + .AsyncCall(base::IgnoreResult(&AIChatDatabase::AddConversationEntry)) + .WithArgs(handler->get_conversation_uuid(), entry.Clone(), + conversation->model_key, std::nullopt); + + if (associated_content_value.has_value() && + conversation->associated_content->is_content_association_possible) { + ai_chat_db_ + .AsyncCall( + base::IgnoreResult(&AIChatDatabase::AddOrUpdateAssociatedContent)) + .WithArgs(conversation->uuid, + conversation->associated_content->Clone(), + std::optional(associated_content_value)); + } + } + + // Record metrics + if (ai_chat_metrics_ != nullptr && + entry->character_type == mojom::CharacterType::HUMAN) { + ai_chat_metrics_->RecordNewPrompt(); + } +} + +void AIChatService::OnConversationEntryRemoved(ConversationHandler* handler, + std::string entry_uuid) { + // Persist the removal + if (ai_chat_db_) { + ai_chat_db_ + .AsyncCall(base::IgnoreResult(&AIChatDatabase::DeleteConversationEntry)) + .WithArgs(entry_uuid); } } void AIChatService::OnClientConnectionChanged(ConversationHandler* handler) { DVLOG(4) << "Client connection changed for conversation " << handler->get_conversation_uuid(); - MaybeEraseConversation(handler); + MaybeUnloadConversation(handler); } void AIChatService::OnConversationTitleChanged(ConversationHandler* handler, @@ -342,39 +810,57 @@ void AIChatService::OnConversationTitleChanged(ConversationHandler* handler, CHECK(conversation_it != conversations_.end()); auto& conversation = conversation_it->second; conversation->title = title; + OnConversationListChanged(); -} -void AIChatService::OnAssociatedContentDestroyed(ConversationHandler* handler, - int content_id) { - content_conversations_.erase(content_id); - MaybeEraseConversation(handler); + // Persist the change + if (ai_chat_db_) { + ai_chat_db_ + .AsyncCall(base::IgnoreResult(&AIChatDatabase::UpdateConversationTitle)) + .WithArgs(handler->get_conversation_uuid(), std::move(title)); + } } void AIChatService::GetVisibleConversations( GetVisibleConversationsCallback callback) { - std::vector conversations; - for (const auto& conversation : FilterVisibleConversations()) { - conversations.push_back(conversation->Clone()); - } - std::move(callback).Run(std::move(conversations)); + LoadConversationsLazy(base::BindOnce( + [](GetVisibleConversationsCallback callback, + ConversationMap& conversations_map) { + std::vector conversations; + for (const auto& conversation : + FilterVisibleConversations(conversations_map)) { + conversations.push_back(conversation->Clone()); + } + std::move(callback).Run(std::move(conversations)); + }, + std::move(callback))); } void AIChatService::BindConversation( const std::string& uuid, mojo::PendingReceiver receiver, mojo::PendingRemote conversation_ui_handler) { - ConversationHandler* conversation = GetConversation(uuid); - if (!conversation) { - return; - } - CHECK(conversation) << "Asked to bind a conversation which doesn't exist"; - conversation->Bind(std::move(receiver), std::move(conversation_ui_handler)); + GetConversation( + std::move(uuid), + base::BindOnce( + [](mojo::PendingReceiver receiver, + mojo::PendingRemote conversation_ui_handler, + ConversationHandler* handler) { + if (!handler) { + DVLOG(0) << "Failed to get conversation for binding"; + return; + } + handler->Bind(std::move(receiver), + std::move(conversation_ui_handler)); + }, + std::move(receiver), std::move(conversation_ui_handler))); } void AIChatService::BindObserver( - mojo::PendingRemote observer) { + mojo::PendingRemote observer, + BindObserverCallback callback) { observer_remotes_.Add(std::move(observer)); + std::move(callback).Run(BuildState()); } bool AIChatService::HasUserOptedIn() { @@ -396,6 +882,7 @@ size_t AIChatService::GetInMemoryConversationCountForTesting() { } void AIChatService::OnUserOptedIn() { + OnStateChanged(); bool is_opted_in = HasUserOptedIn(); if (!is_opted_in) { return; @@ -403,31 +890,13 @@ void AIChatService::OnUserOptedIn() { for (auto& kv : conversation_handlers_) { kv.second->OnUserOptedIn(); } - for (auto& remote : observer_remotes_) { - remote->OnAgreementAccepted(); - } if (ai_chat_metrics_ != nullptr) { ai_chat_metrics_->RecordEnabled(true, true, {}); } } -std::vector AIChatService::FilterVisibleConversations() { - std::vector conversations; - for (const auto& kv : conversations_) { - auto& conversation = kv.second; - // Conversations are only visible if they have content - if (!conversation->has_content) { - continue; - } - conversations.push_back(conversation.get()); - } - base::ranges::sort(conversations, std::greater<>(), - &mojom::Conversation::created_time); - return conversations; -} - void AIChatService::OnConversationListChanged() { - auto conversations = FilterVisibleConversations(); + auto conversations = FilterVisibleConversations(conversations_); for (auto& remote : observer_remotes_) { std::vector client_conversations; for (const auto& conversation : conversations) { diff --git a/components/ai_chat/core/browser/ai_chat_service.h b/components/ai_chat/core/browser/ai_chat_service.h index 1653046f988e..7fadde1ed709 100644 --- a/components/ai_chat/core/browser/ai_chat_service.h +++ b/components/ai_chat/core/browser/ai_chat_service.h @@ -6,31 +6,54 @@ #ifndef BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_AI_CHAT_SERVICE_H_ #define BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_AI_CHAT_SERVICE_H_ +#include + #include #include +#include #include +#include #include +#include "base/callback_list.h" +#include "base/functional/callback.h" +#include "base/functional/callback_helpers.h" +#include "base/memory/raw_ptr.h" #include "base/memory/scoped_refptr.h" #include "base/memory/weak_ptr.h" #include "base/scoped_multi_source_observation.h" +#include "base/task/sequenced_task_runner.h" +#include "base/threading/sequence_bound.h" #include "brave/components/ai_chat/core/browser/ai_chat_credential_manager.h" +#include "brave/components/ai_chat/core/browser/ai_chat_database.h" #include "brave/components/ai_chat/core/browser/ai_chat_feedback_api.h" #include "brave/components/ai_chat/core/browser/ai_chat_metrics.h" #include "brave/components/ai_chat/core/browser/conversation_handler.h" #include "brave/components/ai_chat/core/browser/engine/engine_consumer.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-forward.h" -#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-shared.h" #include "brave/components/skus/common/skus_sdk.mojom.h" #include "components/keyed_service/core/keyed_service.h" #include "components/prefs/pref_change_registrar.h" #include "mojo/public/cpp/bindings/pending_receiver.h" +#include "mojo/public/cpp/bindings/pending_remote.h" #include "mojo/public/cpp/bindings/receiver_set.h" +#include "mojo/public/cpp/bindings/remote_set.h" #include "services/network/public/cpp/shared_url_loader_factory.h" +namespace os_crypt_async { +class Encryptor; +class OSCryptAsync; +} // namespace os_crypt_async + +class PrefService; +namespace network { +class SharedURLLoaderFactory; +} // namespace network + namespace ai_chat { class ModelService; +class AIChatMetrics; // Main entry point for creating and consuming AI Chat conversations class AIChatService : public KeyedService, @@ -45,8 +68,10 @@ class AIChatService : public KeyedService, std::unique_ptr ai_chat_credential_manager, PrefService* profile_prefs, AIChatMetrics* ai_chat_metrics, + os_crypt_async::OSCryptAsync* os_crypt_async, scoped_refptr url_loader_factory, - std::string_view channel_string); + std::string_view channel_string, + base::FilePath profile_path); ~AIChatService() override; AIChatService(const AIChatService&) = delete; @@ -59,19 +84,24 @@ class AIChatService : public KeyedService, void Shutdown() override; // ConversationHandler::Observer - void OnConversationEntriesChanged( + void OnRequestInProgressChanged(ConversationHandler* handler, + bool in_progress) override; + void OnConversationEntryAdded( ConversationHandler* handler, - std::vector entries) override; + mojom::ConversationTurnPtr& entry, + std::optional associated_content_value) override; + void OnConversationEntryRemoved(ConversationHandler* handler, + std::string entry_uuid) override; void OnClientConnectionChanged(ConversationHandler* handler) override; void OnConversationTitleChanged(ConversationHandler* handler, std::string title) override; - void OnAssociatedContentDestroyed(ConversationHandler* handler, - int content_id) override; // Adds new conversation and returns the handler ConversationHandler* CreateConversation(); - ConversationHandler* GetConversation(const std::string& uuid); + ConversationHandler* GetConversation(std::string_view uuid); + void GetConversation(std::string_view conversation_uuid, + base::OnceCallback); // Creates and owns a ConversationHandler if one hasn't been made for the // associated_content_id yet. |associated_content_id| should not be stored. It @@ -89,6 +119,16 @@ class AIChatService : public KeyedService, base::WeakPtr associated_content); + // Removes all in-memory and persisted data for all conversations + void DeleteConversations(std::optional begin_time = std::nullopt, + std::optional end_time = std::nullopt); + + // Remove only web-content data from conversations + void DeleteAssociatedWebContent( + std::optional begin_time = std::nullopt, + std::optional end_time = std::nullopt, + base::OnceCallback callback = base::DoNothing()); + void OpenConversationWithStagedEntries( base::WeakPtr associated_content, @@ -96,13 +136,13 @@ class AIChatService : public KeyedService, // mojom::Service void MarkAgreementAccepted() override; + void EnableStoragePref() override; + void DismissStorageNotice() override; + void DismissPremiumPrompt() override; void GetVisibleConversations( GetVisibleConversationsCallback callback) override; void GetActionMenuList(GetActionMenuListCallback callback) override; void GetPremiumStatus(GetPremiumStatusCallback callback) override; - void GetCanShowPremiumPrompt( - GetCanShowPremiumPromptCallback callback) override; - void DismissPremiumPrompt() override; void DeleteConversation(const std::string& id) override; void RenameConversation(const std::string& id, const std::string& new_name) override; @@ -112,11 +152,15 @@ class AIChatService : public KeyedService, mojo::PendingReceiver receiver, mojo::PendingRemote conversation_ui_handler) override; - void BindObserver(mojo::PendingRemote ui) override; + void BindObserver(mojo::PendingRemote ui, + BindObserverCallback callback) override; bool HasUserOptedIn(); bool IsPremiumStatus(); + // Whether the feature and user preference for history storage is enabled + bool IsAIChatHistoryEnabled(); + std::unique_ptr GetDefaultAIEngine(); AIChatCredentialManager* GetCredentialManagerForTesting() { @@ -127,33 +171,72 @@ class AIChatService : public KeyedService, size_t GetInMemoryConversationCountForTesting(); private: + // Key is uuid + using ConversationMap = std::map; + using ConversationMapCallback = base::OnceCallback; + + void MaybeInitStorage(); + // Called when the database encryptor is ready. + void OnOsCryptAsyncReady(os_crypt_async::Encryptor encryptor, bool success); + void LoadConversationsLazy(ConversationMapCallback callback); + void OnLoadConversationsLazyData( + std::vector conversations); + void ReloadConversations(bool from_cancel = false); + void OnConversationDataReceived( + std::string conversation_uuid, + base::OnceCallback callback, + mojom::ConversationArchivePtr data); + void MaybeAssociateContentWithConversation( ConversationHandler* conversation, int associated_content_id, base::WeakPtr associated_content); + void MaybeUnloadConversation(ConversationHandler* conversation); + void HandleFirstEntry( + ConversationHandler* handler, + mojom::ConversationTurnPtr& entry, + std::optional associated_content_value, + mojom::ConversationPtr& conversation); + void HandleNewEntry(ConversationHandler* handler, + mojom::ConversationTurnPtr& entry, + std::optional associated_content_value, + mojom::ConversationPtr& conversation); + void OnUserOptedIn(); void OnSkusServiceReceived( SkusServiceGetter getter, mojo::PendingRemote service); - std::vector FilterVisibleConversations(); void OnConversationListChanged(); void OnPremiumStatusReceived(GetPremiumStatusCallback callback, mojom::PremiumStatus status, mojom::PremiumInfoPtr info); - void MaybeEraseConversation(ConversationHandler* conversation); + void OnDataDeletedForDisabledStorage(bool success); + mojom::ServiceStatePtr BuildState(); + void OnStateChanged(); raw_ptr model_service_; raw_ptr profile_prefs_; raw_ptr ai_chat_metrics_; + raw_ptr os_crypt_async_; scoped_refptr url_loader_factory_; PrefChangeRegistrar pref_change_registrar_; std::unique_ptr feedback_api_; std::unique_ptr credential_manager_; - // All conversation metadata. Mainly just titles and uuids. Key is uuid - std::map conversations_; + base::FilePath profile_path_; + + // Storage for conversations + base::SequenceBound ai_chat_db_; + + // nullopt if haven't started fetching, empty if done fetching + std::optional> + on_conversations_loaded_callbacks_; + base::OnceClosure cancel_conversation_load_callback_ = base::NullCallback(); + + // All conversation metadata. Mainly just titles and uuids. + ConversationMap conversations_; // Only keep ConversationHandlers around that are being // actively used. Any metadata that needs to stay in-memory @@ -179,6 +262,8 @@ class AIChatService : public KeyedService, // subscription status changes. So we cache it and fetch latest fairly // often (whenever UI is focused). mojom::PremiumStatus last_premium_status_ = mojom::PremiumStatus::Unknown; + // Maintains the subscription for `OSCryptAsync` and cancels upon destruction. + base::CallbackListSubscription encryptor_ready_subscription_; base::WeakPtrFactory weak_ptr_factory_{this}; }; diff --git a/components/ai_chat/core/browser/ai_chat_service_unittest.cc b/components/ai_chat/core/browser/ai_chat_service_unittest.cc index 55a872698c77..b1e3cc300ea4 100644 --- a/components/ai_chat/core/browser/ai_chat_service_unittest.cc +++ b/components/ai_chat/core/browser/ai_chat_service_unittest.cc @@ -14,8 +14,10 @@ #include #include +#include "base/files/scoped_temp_dir.h" #include "base/functional/bind.h" #include "base/functional/callback.h" +#include "base/functional/callback_helpers.h" #include "base/functional/overloaded.h" #include "base/memory/scoped_refptr.h" #include "base/run_loop.h" @@ -31,10 +33,14 @@ #include "base/time/time.h" #include "brave/components/ai_chat/core/browser/ai_chat_credential_manager.h" #include "brave/components/ai_chat/core/browser/conversation_handler.h" +#include "brave/components/ai_chat/core/browser/mock_conversation_handler_observer.h" +#include "brave/components/ai_chat/core/browser/test_utils.h" #include "brave/components/ai_chat/core/browser/utils.h" #include "brave/components/ai_chat/core/common/features.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" #include "brave/components/ai_chat/core/common/pref_names.h" +#include "components/os_crypt/async/browser/os_crypt_async.h" +#include "components/os_crypt/async/browser/test_utils.h" #include "components/sync_preferences/testing_pref_service_syncable.h" #include "mojo/public/cpp/bindings/receiver.h" #include "services/data_decoder/public/cpp/test_support/in_process_data_decoder.h" @@ -55,40 +61,6 @@ namespace ai_chat { namespace { -std::vector CreateSampleHistory() { - auto created_time1 = base::Time::Now(); - std::vector history; - history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::HUMAN, mojom::ActionType::QUERY, - mojom::ConversationTurnVisibility::VISIBLE, "prompt1", std::nullopt, - std::nullopt, created_time1, std::nullopt, false)); - history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE, - mojom::ConversationTurnVisibility::VISIBLE, "answer1", std::nullopt, - std::nullopt, base::Time::Now(), std::nullopt, false)); - return history; -} - -class MockConversationHandlerObserver : public ConversationHandler::Observer { - public: - MockConversationHandlerObserver() = default; - ~MockConversationHandlerObserver() override = default; - - void Observe(ConversationHandler* conversation) { - conversation_observations_.AddObservation(conversation); - } - - MOCK_METHOD(void, - OnClientConnectionChanged, - (ConversationHandler*), - (override)); - - private: - base::ScopedMultiSourceObservation - conversation_observations_{this}; -}; - class MockAIChatCredentialManager : public AIChatCredentialManager { public: using AIChatCredentialManager::AIChatCredentialManager; @@ -101,8 +73,8 @@ class MockAIChatCredentialManager : public AIChatCredentialManager { class MockServiceClient : public mojom::ServiceObserver { public: explicit MockServiceClient(AIChatService* service) { - service->BindObserver( - service_observer_receiver_.BindNewPipeAndPassRemote()); + service->BindObserver(service_observer_receiver_.BindNewPipeAndPassRemote(), + base::DoNothing()); service->Bind(service_remote_.BindNewPipeAndPassReceiver()); } @@ -120,7 +92,7 @@ class MockServiceClient : public mojom::ServiceObserver { (std::vector), (override)); - MOCK_METHOD(void, OnAgreementAccepted, (), (override)); + MOCK_METHOD(void, OnStateChanged, (mojom::ServiceStatePtr), (override)); private: mojo::Receiver service_observer_receiver_{this}; @@ -132,6 +104,7 @@ class MockConversationHandlerClient : public mojom::ConversationUI { explicit MockConversationHandlerClient(ConversationHandler* driver) { driver->Bind(conversation_handler_remote_.BindNewPipeAndPassReceiver(), conversation_ui_receiver_.BindNewPipeAndPassRemote()); + conversation_handler_ = driver; } ~MockConversationHandlerClient() override = default; @@ -141,6 +114,9 @@ class MockConversationHandlerClient : public mojom::ConversationUI { conversation_ui_receiver_.reset(); } + ConversationHandler* GetConversationHandler() { + return conversation_handler_; + } MOCK_METHOD(void, OnConversationHistoryUpdate, (), (override)); @@ -171,6 +147,7 @@ class MockConversationHandlerClient : public mojom::ConversationUI { private: mojo::Receiver conversation_ui_receiver_{this}; mojo::Remote conversation_handler_remote_; + raw_ptr conversation_handler_; }; class MockAssociatedContent @@ -183,15 +160,21 @@ class MockAssociatedContent void SetContentId(int id) { content_id_ = id; } + void GetContent( + ConversationHandler::GetPageContentCallback callback) override { + std::move(callback).Run(GetTextContent(), GetCachedIsVideo(), ""); + } + + std::string_view GetCachedTextContent() override { + cached_text_content_ = GetTextContent(); + return cached_text_content_; + } + MOCK_METHOD(GURL, GetURL, (), (const, override)); MOCK_METHOD(std::u16string, GetTitle, (), (const, override)); - MOCK_METHOD(std::string_view, GetCachedTextContent, (), (override)); MOCK_METHOD(bool, GetCachedIsVideo, (), (override)); - MOCK_METHOD(void, - GetContent, - (ConversationHandler::GetPageContentCallback), - (override)); + MOCK_METHOD(std::string, GetTextContent, (), ()); MOCK_METHOD(void, GetStagedEntriesFromContent, @@ -210,12 +193,7 @@ class MockAssociatedContent void DisassociateWithConversations(std::string archived_text_content, bool archived_is_video) { - std::vector> related_conversations; for (auto& conversation : related_conversations_) { - related_conversations.push_back(conversation->GetWeakPtr()); - } - - for (auto& conversation : related_conversations) { if (conversation) { conversation->OnAssociatedContentDestroyed(archived_text_content, archived_is_video); @@ -231,6 +209,7 @@ class MockAssociatedContent base::WeakPtrFactory weak_ptr_factory_{this}; int content_id_ = 0; + std::string cached_text_content_; std::set> related_conversations_; }; @@ -245,13 +224,39 @@ class AIChatServiceUnitTest : public testing::Test, } void SetUp() override { + CHECK(temp_directory_.CreateUniqueTempDir()); + DVLOG(0) << "Temp directory: " << temp_directory_.GetPath().value(); prefs::RegisterProfilePrefs(prefs_.registry()); prefs::RegisterLocalStatePrefs(local_state_.registry()); ModelService::RegisterProfilePrefs(prefs_.registry()); + os_crypt_ = os_crypt_async::GetTestOSCryptAsyncForTesting( + /*is_sync_for_unittests=*/true); + shared_url_loader_factory_ = base::MakeRefCounted( &url_loader_factory_); + + model_service_ = std::make_unique(&prefs_); + + CreateService(); + + if (is_opted_in_) { + EmulateUserOptedIn(); + } else { + EmulateUserOptedOut(); + } + } + + void TearDown() override { + ai_chat_service_.reset(); + // Allow handles on the db to be released, otherwise for very quick + // tests, we get crashes on temp_directory_.Delete(). + task_environment_.RunUntilIdle(); + CHECK(temp_directory_.Delete()); + } + + void CreateService() { std::unique_ptr credential_manager = std::make_unique(base::NullCallback(), &local_state_); @@ -263,20 +268,19 @@ class AIChatServiceUnitTest : public testing::Test, std::move(premium_info)); }); - model_service_ = std::make_unique(&prefs_); - ai_chat_service_ = std::make_unique( model_service_.get(), std::move(credential_manager), &prefs_, nullptr, - shared_url_loader_factory_, ""); + os_crypt_.get(), shared_url_loader_factory_, "", + temp_directory_.GetPath()); client_ = std::make_unique>(ai_chat_service_.get()); + } - if (is_opted_in_) { - EmulateUserOptedIn(); - } else { - EmulateUserOptedOut(); - } + void ResetService() { + ai_chat_service_.reset(); + task_environment_.RunUntilIdle(); + CreateService(); } void ExpectVisibleConversationsSize(base::Location location, size_t size) { @@ -327,8 +331,6 @@ class AIChatServiceUnitTest : public testing::Test, void EmulateUserOptedOut() { ::ai_chat::SetUserOptedIn(&prefs_, false); } - void TearDown() override { ai_chat_service_.reset(); } - protected: base::test::TaskEnvironment task_environment_; std::unique_ptr ai_chat_service_; @@ -336,6 +338,7 @@ class AIChatServiceUnitTest : public testing::Test, std::unique_ptr> client_; sync_preferences::TestingPrefServiceSyncable prefs_; sync_preferences::TestingPrefServiceSyncable local_state_; + std::unique_ptr os_crypt_; network::TestURLLoaderFactory url_loader_factory_; scoped_refptr shared_url_loader_factory_; data_decoder::test::InProcessDataDecoder in_process_data_decoder_; @@ -343,6 +346,7 @@ class AIChatServiceUnitTest : public testing::Test, private: base::test::ScopedFeatureList scoped_feature_list_; + base::ScopedTempDir temp_directory_; }; INSTANTIATE_TEST_SUITE_P( @@ -394,10 +398,10 @@ TEST_P(AIChatServiceUnitTest, ConversationLifecycle_WithMessages) { .Times(testing::AtLeast(1)); ConversationHandler* conversation_handler1 = CreateConversation(); - conversation_handler1->SetChatHistoryForTesting(CreateSampleHistory()); + conversation_handler1->SetChatHistoryForTesting(CreateSampleChatHistory(1u)); ConversationHandler* conversation_handler2 = CreateConversation(); - conversation_handler2->SetChatHistoryForTesting(CreateSampleHistory()); + conversation_handler2->SetChatHistoryForTesting(CreateSampleChatHistory(1u)); ExpectVisibleConversationsSize(FROM_HERE, 2u); @@ -407,17 +411,16 @@ TEST_P(AIChatServiceUnitTest, ConversationLifecycle_WithMessages) { // Connect a client then disconnect auto client1 = CreateConversationClient(conversation_handler1); DisconnectConversationClient(client1.get()); - // Only 1 should be deleted, or none if we're preserving history - EXPECT_EQ(ai_chat_service_->GetInMemoryConversationCountForTesting(), - IsAIChatHistoryEnabled() ? 2u : 1u); + // Only 1 should be deleted, whether we preserve history or not (is preserved + // in the database). + EXPECT_EQ(ai_chat_service_->GetInMemoryConversationCountForTesting(), 1u); ExpectVisibleConversationsSize(FROM_HERE, IsAIChatHistoryEnabled() ? 2u : 1u); // Connect a client then disconnect auto client2 = CreateConversationClient(conversation_handler2); DisconnectConversationClient(client2.get()); - EXPECT_EQ(ai_chat_service_->GetInMemoryConversationCountForTesting(), - IsAIChatHistoryEnabled() ? 2u : 0u); + EXPECT_EQ(ai_chat_service_->GetInMemoryConversationCountForTesting(), 0u); ExpectVisibleConversationsSize(FROM_HERE, IsAIChatHistoryEnabled() ? 2u : 0u); @@ -425,55 +428,6 @@ TEST_P(AIChatServiceUnitTest, ConversationLifecycle_WithMessages) { task_environment_.RunUntilIdle(); } -TEST_P(AIChatServiceUnitTest, ConversationLifecycle_WithContent) { - NiceMock associated_content{}; - ON_CALL(associated_content, GetURL()) - .WillByDefault(testing::Return(GURL("https://example.com"))); - associated_content.SetContentId(1); - ConversationHandler* conversation_with_content_no_messages = - ai_chat_service_->GetOrCreateConversationHandlerForContent( - associated_content.GetContentId(), associated_content.GetWeakPtr()); - EXPECT_TRUE(conversation_with_content_no_messages); - // Asking again for same content ID gets same conversation - EXPECT_EQ( - conversation_with_content_no_messages, - ai_chat_service_->GetOrCreateConversationHandlerForContent( - associated_content.GetContentId(), associated_content.GetWeakPtr())); - // Shouldn't be visible without messages - ExpectVisibleConversationsSize(FROM_HERE, 0u); - EXPECT_EQ(ai_chat_service_->GetInMemoryConversationCountForTesting(), 1u); - // Disconnecting the client should unload the handler and delete the - // conversation. - auto client1 = - CreateConversationClient(conversation_with_content_no_messages); - DisconnectConversationClient(client1.get()); - EXPECT_EQ(ai_chat_service_->GetInMemoryConversationCountForTesting(), 0u); - ExpectVisibleConversationsSize(FROM_HERE, 0u); - - // Create a new conversation for same content, with messages this time - ConversationHandler* conversation_with_content = - ai_chat_service_->GetOrCreateConversationHandlerForContent( - associated_content.GetContentId(), associated_content.GetWeakPtr()); - conversation_with_content->SetChatHistoryForTesting(CreateSampleHistory()); - ExpectVisibleConversationsSize(FROM_HERE, 1u); - EXPECT_EQ(ai_chat_service_->GetInMemoryConversationCountForTesting(), 1u); - auto client2 = CreateConversationClient(conversation_with_content); - DisconnectConversationClient(client2.get()); - // Disconnecting all clients should keep the handler in memory until - // the content is destroyed. - EXPECT_EQ(ai_chat_service_->GetInMemoryConversationCountForTesting(), 1u); - ExpectVisibleConversationsSize(FROM_HERE, 1u); - associated_content.DisassociateWithConversations("", false); - - if (IsAIChatHistoryEnabled()) { - EXPECT_EQ(ai_chat_service_->GetInMemoryConversationCountForTesting(), 1u); - ExpectVisibleConversationsSize(FROM_HERE, 1u); - } else { - EXPECT_EQ(ai_chat_service_->GetInMemoryConversationCountForTesting(), 0u); - ExpectVisibleConversationsSize(FROM_HERE, 0u); - } -} - TEST_P(AIChatServiceUnitTest, GetOrCreateConversationHandlerForContent) { ConversationHandler* conversation_without_content = CreateConversation(); @@ -509,24 +463,27 @@ TEST_P(AIChatServiceUnitTest, GetOrCreateConversationHandlerForContent) { // Creating a second conversation with the same associated content should // make the second conversation the default for that content, but leave // the first still associated with the content. - ConversationHandler* conversation2 = + ConversationHandler* conversation_with_content2 = ai_chat_service_->CreateConversationHandlerForContent( associated_content.GetContentId(), associated_content.GetWeakPtr()); - EXPECT_NE(conversation_with_content, conversation2); + EXPECT_NE(conversation_with_content, conversation_with_content2); EXPECT_NE(conversation_with_content->get_conversation_uuid(), - conversation2->get_conversation_uuid()); - EXPECT_EQ(conversation2->GetAssociatedContentDelegateForTesting(), - &associated_content); - EXPECT_EQ(conversation_with_content->GetAssociatedContentDelegateForTesting(), - conversation2->GetAssociatedContentDelegateForTesting()); + conversation_with_content2->get_conversation_uuid()); + EXPECT_EQ( + conversation_with_content2->GetAssociatedContentDelegateForTesting(), + &associated_content); + EXPECT_EQ( + conversation_with_content->GetAssociatedContentDelegateForTesting(), + conversation_with_content2->GetAssociatedContentDelegateForTesting()); // Check the second conversation is the default for that content ID EXPECT_EQ( ai_chat_service_->GetOrCreateConversationHandlerForContent( associated_content.GetContentId(), associated_content.GetWeakPtr()), - conversation2); + conversation_with_content2); // Let the conversation be deleted - std::string conversation2_uuid = conversation2->get_conversation_uuid(); - auto client1 = CreateConversationClient(conversation2); + std::string conversation2_uuid = + conversation_with_content2->get_conversation_uuid(); + auto client1 = CreateConversationClient(conversation_with_content2); DisconnectConversationClient(client1.get()); ConversationHandler* conversation_with_content3 = ai_chat_service_->GetOrCreateConversationHandlerForContent( @@ -563,6 +520,92 @@ TEST_P(AIChatServiceUnitTest, run_loop.Run(); } +TEST_P(AIChatServiceUnitTest, GetConversation_AfterRestart) { + auto history = CreateSampleChatHistory(1u); + std::string uuid; + { + ConversationHandler* conversation_handler = CreateConversation(); + uuid = conversation_handler->get_conversation_uuid(); + auto client = CreateConversationClient(conversation_handler); + conversation_handler->SetChatHistoryForTesting(CloneHistory(history)); + ExpectVisibleConversationsSize(FROM_HERE, 1); + DisconnectConversationClient(client.get()); + } + ExpectVisibleConversationsSize(FROM_HERE, IsAIChatHistoryEnabled() ? 1 : 0); + + // Allow entries to finish being persisted before restarting service + task_environment_.RunUntilIdle(); + DVLOG(0) << "Restarting service"; + ResetService(); + + if (IsAIChatHistoryEnabled()) { + EXPECT_CALL(*client_, OnConversationListChanged(testing::SizeIs(1))) + .Times(testing::AtLeast(1)); + } else { + EXPECT_CALL(*client_, OnConversationListChanged).Times(0); + } + // Can get conversation data + if (IsAIChatHistoryEnabled()) { + base::RunLoop run_loop; + ai_chat_service_->GetConversation( + uuid, base::BindLambdaForTesting( + [&](ConversationHandler* conversation_handler) { + EXPECT_TRUE(conversation_handler); + ExpectConversationHistoryEquals( + FROM_HERE, + conversation_handler->GetConversationHistory(), + history); + run_loop.Quit(); + })); + run_loop.Run(); + } +} + +TEST_P(AIChatServiceUnitTest, MaybeInitStorage_DisableStoragePref) { + // This test is only relevant when history feature is enabled initially + if (!IsAIChatHistoryEnabled()) { + return; + } + // Create history, verify it's persisted, then disable storage and verify + // no history is returned, even in-memory (unless a client is connected). + ConversationHandler* conversation_handler1 = CreateConversation(); + auto client1 = CreateConversationClient(conversation_handler1); + conversation_handler1->SetChatHistoryForTesting(CreateSampleChatHistory(1u)); + + ConversationHandler* conversation_handler2 = CreateConversation(); + auto client2 = CreateConversationClient(conversation_handler2); + conversation_handler2->SetChatHistoryForTesting(CreateSampleChatHistory(1u)); + + ConversationHandler* conversation_handler3 = CreateConversation(); + auto client3 = CreateConversationClient(conversation_handler3); + conversation_handler3->SetChatHistoryForTesting(CreateSampleChatHistory(1u)); + + DisconnectConversationClient(client2.get()); + ExpectVisibleConversationsSize(FROM_HERE, 3); + + // Disable storage + prefs_.SetBoolean(prefs::kStorageEnabled, false); + // Wait for OnConversationListChanged which indicates data has been removed + task_environment_.RunUntilIdle(); + + // Conversation with no client was erased from memory + ExpectVisibleConversationsSize(FROM_HERE, 2); + + // Disconnecting conversations should erase them fom memory + DisconnectConversationClient(client1.get()); + DisconnectConversationClient(client3.get()); + ExpectVisibleConversationsSize(FROM_HERE, 0); + + // Restart service and verify still doesn't load from storage + ResetService(); + ExpectVisibleConversationsSize(FROM_HERE, 0); + + // Re-enable storage preference + prefs_.SetBoolean(prefs::kStorageEnabled, true); + // Conversations are no longer in persistant storage + ExpectVisibleConversationsSize(FROM_HERE, 0); +} + TEST_P(AIChatServiceUnitTest, OpenConversationWithStagedEntries_NoPermission) { NiceMock associated_content{}; ConversationHandler* conversation = @@ -620,4 +663,175 @@ TEST_P(AIChatServiceUnitTest, OpenConversationWithStagedEntries) { testing::Mock::VerifyAndClearExpectations(&associated_content); } +TEST_P(AIChatServiceUnitTest, DeleteConversations) { + // Create conversations, call DeleteConversations and verify all conversations + // are deleted, whether a client is connected or not. + ConversationHandler* conversation_handler1 = CreateConversation(); + auto client1 = CreateConversationClient(conversation_handler1); + conversation_handler1->SetChatHistoryForTesting(CreateSampleChatHistory(1u)); + + ConversationHandler* conversation_handler2 = CreateConversation(); + auto client2 = CreateConversationClient(conversation_handler2); + conversation_handler2->SetChatHistoryForTesting(CreateSampleChatHistory(1u)); + + ConversationHandler* conversation_handler3 = CreateConversation(); + auto client3 = CreateConversationClient(conversation_handler3); + conversation_handler3->SetChatHistoryForTesting(CreateSampleChatHistory(1u)); + + ExpectVisibleConversationsSize(FROM_HERE, 3); + + ai_chat_service_->DeleteConversations(); + + ExpectVisibleConversationsSize(FROM_HERE, 0); + + // Verify deleted from database + ResetService(); + ExpectVisibleConversationsSize(FROM_HERE, 0); +} + +TEST_P(AIChatServiceUnitTest, DeleteConversations_TimeRange) { + // Create conversations, call DeleteConversations and verify all conversations + // are deleted, whether a client is connected or not. + ConversationHandler* conversation_handler1 = CreateConversation(); + auto client1 = CreateConversationClient(conversation_handler1); + // This conversation 3 hours in the past + conversation_handler1->SetChatHistoryForTesting( + CreateSampleChatHistory(1u, -3)); + + ConversationHandler* conversation_handler2 = CreateConversation(); + auto client2 = CreateConversationClient(conversation_handler2); + // This conversation 2 hours in the past + conversation_handler2->SetChatHistoryForTesting( + CreateSampleChatHistory(1u, -2)); + + ConversationHandler* conversation_handler3 = CreateConversation(); + auto client3 = CreateConversationClient(conversation_handler3); + // This conversation 1 hour in the past + conversation_handler3->SetChatHistoryForTesting( + CreateSampleChatHistory(1u, -1)); + + ExpectVisibleConversationsSize(FROM_HERE, 3); + + ai_chat_service_->DeleteConversations(base::Time::Now() - base::Minutes(245), + base::Time::Now() - base::Minutes(110)); + + ExpectVisibleConversationsSize(FROM_HERE, 1); + + // Verify deleted from database + ResetService(); + ExpectVisibleConversationsSize(FROM_HERE, IsAIChatHistoryEnabled() ? 1 : 0); +} + +TEST_P(AIChatServiceUnitTest, DeleteAssociatedWebContent) { + // Only valid when history is enabled + if (!IsAIChatHistoryEnabled()) { + return; + } + + const GURL content_url("https://example.com"); + const std::u16string page_title = u"page title"; + const std::string page_content = "page content"; + + struct Data { + NiceMock associated_content; + raw_ptr conversation_handler; + std::unique_ptr client; + }; + std::array data; + + // First conversation and its content should stay alive and still report + // actual content info even though it falls in the deletion time range. + // Second conversation should have its content archived and should report + // empty content info since it falls in the deletion time range. + // Third conversation should have its content archived but should report + // actual content info since it does not fall in the deletion time range. + + for (int i = 0; i < 3; i++) { + ON_CALL(data[i].associated_content, GetURL()) + .WillByDefault(testing::Return(content_url)); + ON_CALL(data[i].associated_content, GetTitle()) + .WillByDefault(testing::Return(page_title)); + ON_CALL(data[i].associated_content, GetTextContent) + .WillByDefault(testing::Return(page_content)); + data[i].associated_content.SetContentId(i); + + data[i].conversation_handler = + ai_chat_service_->GetOrCreateConversationHandlerForContent( + data[i].associated_content.GetContentId(), + data[i].associated_content.GetWeakPtr()); + EXPECT_TRUE(data[i].conversation_handler); + data[i].client = CreateConversationClient(data[i].conversation_handler); + data[i].conversation_handler->SetChatHistoryForTesting( + CreateSampleChatHistory(1u, -3 + i)); + + // Verify associated are initially correct + base::RunLoop run_loop; + data[i].conversation_handler->GetAssociatedContentInfo( + base::BindLambdaForTesting([&](mojom::SiteInfoPtr site_info, + bool should_send_page_contents) { + SCOPED_TRACE(testing::Message() << "data index: " << i); + EXPECT_TRUE(site_info->is_content_association_possible); + EXPECT_TRUE(site_info->url.has_value()); + EXPECT_EQ(site_info->url.value(), content_url); + EXPECT_EQ(site_info->title.value(), base::UTF16ToUTF8(page_title)); + run_loop.Quit(); + })); + run_loop.Run(); + } + + // Archive content for conversations 2 and 3 + data[1].associated_content.DisassociateWithConversations(page_content, false); + data[2].associated_content.DisassociateWithConversations(page_content, false); + + // Delete associated content from conversations between 1 hours ago and 3 + // hours ago. + base::RunLoop deletion_run_loop; + ai_chat_service_->DeleteAssociatedWebContent( + base::Time::Now() - base::Minutes(182), + base::Time::Now() - base::Minutes(70), + base::BindLambdaForTesting([&](bool success) { + EXPECT_TRUE(success); + deletion_run_loop.Quit(); + })); + deletion_run_loop.Run(); + + ExpectVisibleConversationsSize(FROM_HERE, 3); + + task_environment_.RunUntilIdle(); + + for (int i = 0; i < 3; i++) { + base::RunLoop run_loop; + data[i].conversation_handler->GetAssociatedContentInfo( + base::BindLambdaForTesting([&](mojom::SiteInfoPtr site_info, + bool should_send_page_contents) { + SCOPED_TRACE(testing::Message() << "data index: " << i); + EXPECT_TRUE(site_info->is_content_association_possible); + EXPECT_TRUE(site_info->url.has_value()); + EXPECT_TRUE(site_info->title.has_value()); + if (i == 1) { + EXPECT_TRUE(site_info->url->is_empty()); + EXPECT_TRUE(site_info->title->empty()); + } else { + EXPECT_EQ(site_info->url.value(), content_url); + EXPECT_EQ(site_info->title.value(), base::UTF16ToUTF8(page_title)); + } + run_loop.Quit(); + })); + run_loop.Run(); + + base::RunLoop run_loop_2; + data[i].conversation_handler->GeneratePageContent( + base::BindLambdaForTesting([&](std::string content, bool is_video, + std::string invalidation_token) { + if (i == 1) { + EXPECT_TRUE(content.empty()); + } else { + EXPECT_EQ(content, page_content); + } + run_loop_2.Quit(); + })); + run_loop_2.Run(); + } +} + } // namespace ai_chat diff --git a/components/ai_chat/core/browser/associated_archive_content.cc b/components/ai_chat/core/browser/associated_archive_content.cc index a48f34caba3d..1545bb5d16ae 100644 --- a/components/ai_chat/core/browser/associated_archive_content.cc +++ b/components/ai_chat/core/browser/associated_archive_content.cc @@ -5,9 +5,16 @@ #include "brave/components/ai_chat/core/browser/associated_archive_content.h" +#include +#include +#include +#include #include +#include "base/functional/callback.h" +#include "base/logging.h" #include "base/memory/weak_ptr.h" +#include "base/strings/utf_ostream_operators.h" #include "brave/components/ai_chat/core/browser/conversation_handler.h" namespace ai_chat { @@ -26,6 +33,18 @@ AssociatedArchiveContent::AssociatedArchiveContent(GURL url, AssociatedArchiveContent::~AssociatedArchiveContent() = default; +void AssociatedArchiveContent::SetMetadata(GURL url, + std::u16string title, + bool is_video) { + url_ = url; + title_ = title; + is_video_ = is_video; +} + +void AssociatedArchiveContent::SetContent(std::string text_content) { + text_content_ = text_content; +} + int AssociatedArchiveContent::GetContentId() const { return -1; } diff --git a/components/ai_chat/core/browser/associated_archive_content.h b/components/ai_chat/core/browser/associated_archive_content.h index bec7599abdf5..6af9cc9a3a1d 100644 --- a/components/ai_chat/core/browser/associated_archive_content.h +++ b/components/ai_chat/core/browser/associated_archive_content.h @@ -7,6 +7,7 @@ #define BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_ASSOCIATED_ARCHIVE_CONTENT_H_ #include +#include #include "base/memory/weak_ptr.h" #include "brave/components/ai_chat/core/browser/conversation_handler.h" @@ -21,6 +22,11 @@ namespace ai_chat { // content. // Similarly, if a conversation is loaded from storage, and the conversation // was associated with content, this class is used to represent that content. +// +// If this class is used to represent archive content that can be shared by +// multiple conversations, consider changing owner to the AIChatService and +// having it subclass AssociatedContentDriver for related conversation +// management. class AssociatedArchiveContent : public ConversationHandler::AssociatedContentDelegate { public: @@ -32,6 +38,11 @@ class AssociatedArchiveContent AssociatedArchiveContent(const AssociatedArchiveContent&) = delete; AssociatedArchiveContent& operator=(const AssociatedArchiveContent&) = delete; + // Occassionally even an archive is updated, such as when content is deleted + // for privacy reasons. + void SetMetadata(GURL url, std::u16string title, bool is_video); + void SetContent(std::string text_content); + int GetContentId() const override; GURL GetURL() const override; std::u16string GetTitle() const override; diff --git a/components/ai_chat/core/browser/associated_content_driver.cc b/components/ai_chat/core/browser/associated_content_driver.cc index c27cc4b09636..53448c3a7eec 100644 --- a/components/ai_chat/core/browser/associated_content_driver.cc +++ b/components/ai_chat/core/browser/associated_content_driver.cc @@ -5,27 +5,33 @@ #include "brave/components/ai_chat/core/browser/associated_content_driver.h" +#include #include +#include #include #include +#include #include #include -#include "base/containers/contains.h" -#include "base/containers/fixed_flat_set.h" +#include "base/check.h" +#include "base/containers/flat_map.h" #include "base/functional/bind.h" +#include "base/location.h" +#include "base/logging.h" #include "base/memory/weak_ptr.h" #include "base/one_shot_event.h" -#include "base/ranges/algorithm.h" -#include "base/strings/string_util.h" +#include "base/strings/strcat.h" #include "brave/brave_domains/service_domains.h" #include "brave/components/ai_chat/core/browser/brave_search_responses.h" #include "brave/components/ai_chat/core/browser/conversation_handler.h" #include "brave/components/ai_chat/core/browser/utils.h" #include "brave/components/ai_chat/core/common/constants.h" +#include "brave/components/api_request_helper/api_request_helper.h" #include "net/base/url_util.h" #include "net/traffic_annotation/network_traffic_annotation.h" #include "services/network/public/cpp/shared_url_loader_factory.h" +#include "url/url_constants.h" namespace ai_chat { @@ -61,7 +67,12 @@ AssociatedContentDriver::AssociatedContentDriver( : url_loader_factory_(url_loader_factory) {} AssociatedContentDriver::~AssociatedContentDriver() { - DisassociateWithConversations(); + for (auto& conversation : associated_conversations_) { + if (conversation) { + conversation->OnAssociatedContentDestroyed(cached_text_content_, + is_video_); + } + } } void AssociatedContentDriver::AddRelatedConversation( @@ -267,11 +278,17 @@ void AssociatedContentDriver::OnFaviconImageDataChanged() { } } -void AssociatedContentDriver::OnNewPage(int64_t navigation_id) { - // This instance will now be used for different content so existing - // conversations need to be disassociated. - DisassociateWithConversations(); +void AssociatedContentDriver::OnTitleChanged() { + for (auto& conversation : associated_conversations_) { + conversation->OnAssociatedContentTitleChanged(); + } +} +void AssociatedContentDriver::OnNewPage(int64_t navigation_id) { + // Tell the associated_conversations_ that we're breaking up + for (auto& conversation : associated_conversations_) { + conversation->OnAssociatedContentDestroyed(cached_text_content_, is_video_); + } // Tell the observer how to find the next conversation for (auto& observer : observers_) { observer.OnAssociatedContentNavigated(navigation_id); @@ -287,16 +304,4 @@ void AssociatedContentDriver::OnNewPage(int64_t navigation_id) { ConversationHandler::AssociatedContentDelegate::OnNewPage(navigation_id); } -void AssociatedContentDriver::DisassociateWithConversations() { - // Iterator might be invalidated by destruction, so copy the items - std::vector conversations{ - associated_conversations_.begin(), associated_conversations_.end()}; - for (auto& conversation : conversations) { - if (conversation) { - conversation->OnAssociatedContentDestroyed(cached_text_content_, - is_video_); - } - } -} - } // namespace ai_chat diff --git a/components/ai_chat/core/browser/associated_content_driver.h b/components/ai_chat/core/browser/associated_content_driver.h index a36343623a94..01426499b770 100644 --- a/components/ai_chat/core/browser/associated_content_driver.h +++ b/components/ai_chat/core/browser/associated_content_driver.h @@ -6,20 +6,43 @@ #ifndef BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_ASSOCIATED_CONTENT_DRIVER_H_ #define BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_ASSOCIATED_CONTENT_DRIVER_H_ +#include #include +#include #include #include #include #include +#include "base/functional/callback.h" #include "base/gtest_prod_util.h" #include "base/memory/raw_ptr.h" +#include "base/memory/scoped_refptr.h" +#include "base/memory/weak_ptr.h" +#include "base/observer_list.h" +#include "base/observer_list_types.h" #include "base/one_shot_event.h" #include "brave/components/ai_chat/core/browser/conversation_handler.h" #include "brave/components/ai_chat/core/browser/model_service.h" +#include "brave/components/ai_chat/core/browser/types.h" #include "brave/components/api_request_helper/api_request_helper.h" +#include "url/gurl.h" FORWARD_DECLARE_TEST(AIChatUIBrowserTest, PrintPreviewFallback); +class AIChatUIBrowserTest; +namespace api_request_helper { +class APIRequestHelper; +class APIRequestResult; +} // namespace api_request_helper +namespace base { +class Value; +} // namespace base +namespace network { +class SharedURLLoaderFactory; +} // namespace network +namespace base { +class OneShotEvent; +} // namespace base namespace ai_chat { @@ -87,6 +110,9 @@ class AssociatedContentDriver // Implementer should call this when the favicon for the content changes void OnFaviconImageDataChanged(); + // Implementer should call this when the title is updated + void OnTitleChanged(); + // Implementer should call this when the content is updated in a way that // will not be detected by the on-demand techniques used by GetPageContent. // For example for sites where GetPageContent does not read the live DOM but @@ -122,12 +148,6 @@ class AssociatedContentDriver ConversationHandler::GetStagedEntriesCallback callback, int64_t navigation_id, api_request_helper::APIRequestResult result); - - // Let all conversations using this content know that the content - // has been destroyed or changed to represent different content (e.g. a - // navigation). - void DisassociateWithConversations(); - static std::optional> ParseSearchQuerySummaryResponse(const base::Value& value); diff --git a/components/ai_chat/core/browser/associated_content_driver_unittest.cc b/components/ai_chat/core/browser/associated_content_driver_unittest.cc index e364fd7b8e68..183515766e3e 100644 --- a/components/ai_chat/core/browser/associated_content_driver_unittest.cc +++ b/components/ai_chat/core/browser/associated_content_driver_unittest.cc @@ -5,22 +5,26 @@ #include "brave/components/ai_chat/core/browser/associated_content_driver.h" -#include #include +#include +#include +#include #include +#include "base/functional/bind.h" +#include "base/location.h" +#include "base/run_loop.h" #include "base/task/sequenced_task_runner.h" -#include "base/task/single_thread_task_runner.h" #include "base/test/bind.h" #include "base/test/gmock_callback_support.h" #include "base/test/mock_callback.h" #include "base/test/task_environment.h" #include "base/test/values_test_util.h" -#include "base/threading/thread.h" #include "brave/components/ai_chat/core/browser/conversation_handler.h" #include "brave/components/ai_chat/core/browser/types.h" #include "brave/components/ai_chat/core/common/mojom/page_content_extractor.mojom.h" #include "services/data_decoder/public/cpp/test_support/in_process_data_decoder.h" +#include "services/network/public/cpp/resource_request.h" #include "services/network/public/cpp/shared_url_loader_factory.h" #include "services/network/public/cpp/weak_wrapper_shared_url_loader_factory.h" #include "services/network/test/test_url_loader_factory.h" diff --git a/components/ai_chat/core/browser/constants.cc b/components/ai_chat/core/browser/constants.cc index 056f6ffd9d5c..29c2b3e10c44 100644 --- a/components/ai_chat/core/browser/constants.cc +++ b/components/ai_chat/core/browser/constants.cc @@ -6,9 +6,14 @@ #include "brave/components/ai_chat/core/browser/constants.h" #include +#include +#include #include -#include "base/strings/strcat.h" +#include "base/containers/flat_tree.h" +#include "base/strings/string_util.h" +#include "components/grit/brave_components_strings.h" +#include "mojo/public/cpp/bindings/struct_ptr.h" #include "ui/base/l10n/l10n_util.h" namespace ai_chat { @@ -120,6 +125,17 @@ base::span GetLocalizedStrings() { {"searchInProgress", IDS_CHAT_UI_SEARCH_IN_PROGRESS}, {"searchQueries", IDS_CHAT_UI_SEARCH_QUERIES}, {"learnMore", IDS_CHAT_UI_LEARN_MORE}, + {"closeNotice", IDS_CHAT_UI_CLOSE_NOTICE}, + {"noticeConversationHistoryBody", + IDS_CHAT_UI_NOTICE_CONVERSATION_HISTORY_BODY}, + {"noticeConversationHistoryEmpty", + IDS_CHAT_UI_NOTICE_CONVERSATION_HISTORY_EMPTY}, + {"noticeConversationHistoryTitleDisabledPref", + IDS_CHAT_UI_NOTICE_CONVERSATION_HISTORY_TITLE_DISABLED_PREF}, + {"noticeConversationHistoryDisabledPref", + IDS_CHAT_UI_NOTICE_CONVERSATION_HISTORY_DISABLED_PREF}, + {"noticeConversationHistoryDisabledPrefButton", + IDS_CHAT_UI_NOTICE_CONVERSATION_HISTORY_DISABLED_PREF_BUTTON}, {"leoSettingsTooltipLabel", IDS_CHAT_UI_LEO_SETTINGS_TOOLTIP_LABEL}, {"summarizePageButtonLabel", IDS_CHAT_UI_SUMMARIZE_PAGE}, {"welcomeGuideTitle", IDS_CHAT_UI_WELCOME_GUIDE_TITLE}, @@ -163,6 +179,8 @@ base::span GetLocalizedStrings() { {"useMicButtonLabel", IDS_AI_CHAT_USE_MICROPHONE_BUTTON_LABEL}, {"menuTitleCustomModels", IDS_AI_CHAT_MENU_TITLE_CUSTOM_MODELS}, {"startConversationLabel", IDS_AI_CHAT_START_CONVERSATION_LABEL}, + {"goBackToActiveConversationButton", + IDS_AI_CHAT_GO_BACK_TO_ACTIVE_CONVERSATION_BUTTON}, {"conversationListUntitled", IDS_AI_CHAT_CONVERSATION_LIST_UNTITLED}}); return kLocalizedStrings; diff --git a/components/ai_chat/core/browser/constants.h b/components/ai_chat/core/browser/constants.h index 4296b65d611d..698343cb172c 100644 --- a/components/ai_chat/core/browser/constants.h +++ b/components/ai_chat/core/browser/constants.h @@ -6,10 +6,15 @@ #ifndef BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_CONSTANTS_H_ #define BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_CONSTANTS_H_ +#include + +#include #include +#include #include #include "base/containers/fixed_flat_set.h" +#include "base/containers/span.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" #include "components/grit/brave_components_strings.h" #include "ui/base/webui/web_ui_util.h" @@ -35,10 +40,10 @@ inline constexpr char kLeoModelSupportUrl[] = // chars per token. When no context size has been provided, we will default to a // conservative 4k tokens based on common models like Phi 3 Mini and Llama 2 // (both have 4k token context limits). -constexpr size_t kDefaultCharsPerToken = 4; -constexpr float kMaxContentLengthThreshold = 0.6f; -constexpr size_t kReservedTokensForPrompt = 300; -constexpr size_t kReservedTokensForMaxNewTokens = 400; +inline constexpr size_t kDefaultCharsPerToken = 4; +inline constexpr float kMaxContentLengthThreshold = 0.6f; +inline constexpr size_t kReservedTokensForPrompt = 300; +inline constexpr size_t kReservedTokensForMaxNewTokens = 400; } // namespace ai_chat diff --git a/components/ai_chat/core/browser/conversation_handler.cc b/components/ai_chat/core/browser/conversation_handler.cc index 9e721ec775b1..28d24503d261 100644 --- a/components/ai_chat/core/browser/conversation_handler.cc +++ b/components/ai_chat/core/browser/conversation_handler.cc @@ -5,35 +5,68 @@ #include "brave/components/ai_chat/core/browser/conversation_handler.h" +#include + #include +#include +#include +#include #include -#include +#include +#include +#include #include +#include "base/check.h" +#include "base/containers/flat_tree.h" +#include "base/containers/span.h" +#include "base/debug/crash_logging.h" +#include "base/debug/dump_without_crashing.h" #include "base/files/file_path.h" +#include "base/functional/bind.h" +#include "base/logging.h" #include "base/memory/weak_ptr.h" +#include "base/metrics/field_trial_params.h" +#include "base/notreached.h" +#include "base/numerics/safe_math.h" +#include "base/rand_util.h" +#include "base/ranges/algorithm.h" +#include "base/strings/strcat.h" +#include "base/strings/string_util.h" #include "base/strings/utf_string_conversions.h" #include "base/task/sequenced_task_runner.h" +#include "base/task/task_traits.h" #include "base/task/thread_pool.h" +#include "base/time/time.h" #include "base/types/expected.h" -#include "brave/components/ai_chat/core/browser/ai_chat_credential_manager.h" +#include "base/types/strong_alias.h" +#include "base/uuid.h" +#include "base/values.h" #include "brave/components/ai_chat/core/browser/ai_chat_feedback_api.h" #include "brave/components/ai_chat/core/browser/ai_chat_service.h" #include "brave/components/ai_chat/core/browser/associated_archive_content.h" -#include "brave/components/ai_chat/core/browser/constants.h" #include "brave/components/ai_chat/core/browser/local_models_updater.h" #include "brave/components/ai_chat/core/browser/model_service.h" +#include "brave/components/ai_chat/core/browser/model_validator.h" #include "brave/components/ai_chat/core/browser/types.h" #include "brave/components/ai_chat/core/browser/utils.h" #include "brave/components/ai_chat/core/common/features.h" -#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-forward.h" -#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-shared.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" +#include "brave/components/api_request_helper/api_request_helper.h" +#include "components/grit/brave_components_strings.h" #include "mojo/public/cpp/bindings/pending_receiver.h" +#include "mojo/public/cpp/bindings/pending_remote.h" +#include "mojo/public/cpp/bindings/remote.h" +#include "mojo/public/cpp/bindings/struct_ptr.h" #include "services/network/public/cpp/shared_url_loader_factory.h" #include "ui/base/l10n/l10n_util.h" +#define STARTER_PROMPT(TYPE) \ + l10n_util::GetStringUTF8(IDS_AI_CHAT_STATIC_STARTER_TITLE_##TYPE), \ + l10n_util::GetStringUTF8(IDS_AI_CHAT_STATIC_STARTER_PROMPT_##TYPE) + namespace ai_chat { +class AIChatCredentialManager; namespace { @@ -43,6 +76,8 @@ using ai_chat::mojom::ConversationTurn; using AssociatedContentDelegate = ConversationHandler::AssociatedContentDelegate; +constexpr size_t kDefaultSuggestionsCount = 4; + } // namespace AssociatedContentDelegate::AssociatedContentDelegate() @@ -131,13 +166,39 @@ void AssociatedContentDelegate::OnTextEmbedderInitialized(bool initialized) { pending_top_similarity_requests_.clear(); } +ConversationHandler::Suggestion::Suggestion(std::string title) + : title(std::move(title)) {} +ConversationHandler::Suggestion::Suggestion(std::string title, + std::string prompt) + : title(std::move(title)), prompt(std::move(prompt)) {} +ConversationHandler::Suggestion::Suggestion(Suggestion&&) = default; +ConversationHandler::Suggestion& ConversationHandler::Suggestion::operator=( + Suggestion&&) = default; +ConversationHandler::Suggestion::~Suggestion() = default; + ConversationHandler::ConversationHandler( - const mojom::Conversation* conversation, + mojom::Conversation* conversation, AIChatService* ai_chat_service, ModelService* model_service, AIChatCredentialManager* credential_manager, AIChatFeedbackAPI* feedback_api, scoped_refptr url_loader_factory) + : ConversationHandler(conversation, + ai_chat_service, + model_service, + credential_manager, + feedback_api, + url_loader_factory, + std::nullopt) {} + +ConversationHandler::ConversationHandler( + mojom::Conversation* conversation, + AIChatService* ai_chat_service, + ModelService* model_service, + AIChatCredentialManager* credential_manager, + AIChatFeedbackAPI* feedback_api, + scoped_refptr url_loader_factory, + std::optional initial_state) : metadata_(conversation), ai_chat_service_(ai_chat_service), model_service_(model_service), @@ -152,8 +213,31 @@ ConversationHandler::ConversationHandler( &ConversationHandler::OnConversationUIConnectionChanged, weak_ptr_factory_.GetWeakPtr())); models_observer_.Observe(model_service_.get()); - // TODO(petemill): differ based on premium status, if different - ChangeModel(model_service->GetDefaultModelKey()); + + ChangeModel(conversation->model_key.value_or("").empty() + ? model_service->GetDefaultModelKey() + : conversation->model_key.value()); + + if (initial_state.has_value() && !initial_state.value()->entries.empty()) { + // We only support single associated content for now + mojom::ConversationArchivePtr conversation_data = + std::move(initial_state.value()); + if (!conversation_data->associated_content.empty()) { + CHECK(metadata_->associated_content->uuid.has_value()); + CHECK_EQ(conversation_data->associated_content[0]->content_uuid, + metadata_->associated_content->uuid.value()); + bool is_video = (metadata_->associated_content->content_type == + mojom::ContentType::VideoTranscript); + SetArchiveContent(conversation_data->associated_content[0]->content, + is_video); + } + DVLOG(1) << "Restoring associated content for conversation " + << metadata_->uuid << " with " + << conversation_data->entries.size(); + chat_history_ = std::move(conversation_data->entries); + } + + MaybeSeedOrClearSuggestions(); } ConversationHandler::~ConversationHandler() { @@ -184,6 +268,52 @@ void ConversationHandler::Bind( Bind(std::move(conversation_ui_handler)); } +void ConversationHandler::Bind( + mojo::PendingReceiver receiver) { + untrusted_receivers_.Add(this, std::move(receiver)); +} + +void ConversationHandler::BindUntrustedConversationUI( + mojo::PendingRemote + untrusted_conversation_ui_handler, + BindUntrustedConversationUICallback callback) { + untrusted_conversation_ui_handlers_.Add( + std::move(untrusted_conversation_ui_handler)); + std::move(callback).Run(GetStateForConversationEntries()); +} + +void ConversationHandler::OnConversationMetadataUpdated() { + // Pass the updated data to archive content + if (archive_content_) { + archive_content_->SetMetadata( + metadata_->associated_content->url.value_or(GURL()), + base::UTF8ToUTF16(metadata_->associated_content->title.value_or("")), + metadata_->associated_content->content_type == + mojom::ContentType::VideoTranscript); + } + // Notify UI. If we have live content then the metadata will be updated + // again from that live data. + OnAssociatedContentInfoChanged(); +} + +void ConversationHandler::OnArchiveContentUpdated( + mojom::ConversationArchivePtr conversation_data) { + // We don't need to update text content if it's not archive since live + // content owns the text content and is re-fetched on demand. + if (archive_content_) { + // Only supports a single associated content for now + std::string text_content; + if (!conversation_data->associated_content.empty() && + conversation_data->associated_content[0]->content_uuid == + metadata_->associated_content->uuid) { + text_content = conversation_data->associated_content[0]->content; + } else { + text_content = ""; + } + archive_content_->SetContent(std::move(text_content)); + } +} + bool ConversationHandler::IsAnyClientConnected() { return !receivers_.empty() || !conversation_ui_handlers_.empty(); } @@ -199,8 +329,8 @@ bool ConversationHandler::HasAnyHistory() { }); } -bool ConversationHandler::IsAssociatedContentAlive() { - return associated_content_delegate_ && !archive_content_; +bool ConversationHandler::IsRequestInProgress() { + return is_request_in_progress_; } void ConversationHandler::OnConversationDeleted() { @@ -236,6 +366,13 @@ void ConversationHandler::InitEngine() { // no longer exists). model_key_ = model->key; + // Update Conversation metadata's model key + if (model_key_ != model_service_->GetDefaultModelKey()) { + metadata_->model_key = model_key_; + } else { + metadata_->model_key = std::nullopt; + } + engine_ = model_service_->GetEngineForModel(model_key_, url_loader_factory_, credential_manager_); @@ -244,9 +381,7 @@ void ConversationHandler::InitEngine() { if (is_request_in_progress_) { // Pending requests have been deleted along with the model engine is_request_in_progress_ = false; - for (auto& client : conversation_ui_handlers_) { - client->OnAPIRequestInProgress(is_request_in_progress_); - } + OnAPIRequestInProgressChanged(); } // When the model changes, the content truncation might be different, @@ -260,30 +395,35 @@ void ConversationHandler::InitEngine() { void ConversationHandler::OnAssociatedContentDestroyed( std::string last_text_content, bool is_video) { - // The associated content delegate is already or about to be destroyed. - auto content_id = associated_content_delegate_ - ? associated_content_delegate_->GetContentId() - : -1; - DisassociateContentDelegate(); + // The associated content delegate is destroyed, so we should not try to + // fetch. It may be populated later, e.g. through back navigation. + // If this conversation is allowed to be associated with content, we can keep + // using our current cached content. + associated_content_delegate_ = nullptr; if (!chat_history_.empty() && should_send_page_contents_ && - associated_content_info_ && associated_content_info_->url.has_value()) { + metadata_->associated_content && + metadata_->associated_content->is_content_association_possible) { // Get the latest version of article text and // associated_content_info_ if this chat has history and was connected to - // the associated conversation, then construct a "content archive" - // implementation of AssociatedContentDelegate with a duplicate of the - // article text. - auto archive_content = std::make_unique( - associated_content_info_->url.value_or(GURL()), last_text_content, - base::UTF8ToUTF16(associated_content_info_->title.value_or("")), - is_video); - associated_content_delegate_ = archive_content->GetWeakPtr(); - archive_content_ = std::move(archive_content); + // the associated conversation, then store the content so the conversation + // can continue. + SetArchiveContent(std::move(last_text_content), is_video); } OnAssociatedContentInfoChanged(); - // Notify observers - for (auto& observer : observers_) { - observer.OnAssociatedContentDestroyed(this, content_id); - } +} + +void ConversationHandler::SetArchiveContent(std::string text_content, + bool is_video) { + // Construct a "content archive" implementation of AssociatedContentDelegate + // with a duplicate of the article text. + auto archive_content = std::make_unique( + metadata_->associated_content->url.value_or(GURL()), + std::move(text_content), + base::UTF8ToUTF16(metadata_->associated_content->title.value_or("")), + is_video); + associated_content_delegate_ = archive_content->GetWeakPtr(); + archive_content_ = std::move(archive_content); + should_send_page_contents_ = true; } void ConversationHandler::SetAssociatedContentDelegate( @@ -360,18 +500,22 @@ void ConversationHandler::GetState(GetStateCallback callback) { BuildAssociatedContentInfo(); + std::vector suggestions; + std::ranges::transform(suggestions_, std::back_inserter(suggestions), + [](const auto& s) { return s.title; }); mojom::ConversationStatePtr state = mojom::ConversationState::New( metadata_->uuid, is_request_in_progress_, std::move(models_copy), - model_key, suggestions_, suggestion_generation_status_, - associated_content_info_->Clone(), should_send_page_contents_, + model_key, std::move(suggestions), suggestion_generation_status_, + metadata_->associated_content->Clone(), should_send_page_contents_, current_error_); std::move(callback).Run(std::move(state)); } void ConversationHandler::RateMessage(bool is_liked, - uint32_t turn_id, + const std::string& turn_uuid, RateMessageCallback callback) { + DVLOG(2) << __func__ << ": " << is_liked << ", " << turn_uuid; auto& model = GetCurrentModel(); // We only allow Leo models to be rated. @@ -379,33 +523,39 @@ void ConversationHandler::RateMessage(bool is_liked, const std::vector& history = chat_history_; - auto on_complete = base::BindOnce( - [](RateMessageCallback callback, APIRequestResult result) { - if (result.Is2XXResponseCode() && result.value_body().is_dict()) { - std::string id = *result.value_body().GetDict().FindString("id"); - std::move(callback).Run(id); - return; - } - std::move(callback).Run(std::nullopt); - }, - std::move(callback)); - - // TODO(petemill): Something more robust than relying on message index, - // and probably a message uuid. - uint32_t current_turn_id = turn_id + 1; - - if (current_turn_id <= history.size()) { - base::span history_slice = - base::make_span(history).first(current_turn_id); - - feedback_api_->SendRating( - is_liked, ai_chat_service_->IsPremiumStatus(), history_slice, - model.options->get_leo_model_options()->name, std::move(on_complete)); + auto entry_it = + base::ranges::find(history, turn_uuid, &mojom::ConversationTurn::uuid); + if (entry_it == history.end()) { + std::move(callback).Run(std::nullopt); return; } - std::move(callback).Run(std::nullopt); + const size_t count = std::distance(history.begin(), entry_it) + 1; + + base::span history_slice = + base::span(history).first(count); + + feedback_api_->SendRating( + is_liked, ai_chat_service_->IsPremiumStatus(), history_slice, + model.options->get_leo_model_options()->name, + base::BindOnce( + [](RateMessageCallback callback, APIRequestResult result) { + if (result.Is2XXResponseCode() && result.value_body().is_dict()) { + const std::string* id_result = + result.value_body().GetDict().FindString("id"); + if (id_result) { + std::move(callback).Run(*id_result); + } else { + DLOG(ERROR) << "Failed to get rating ID"; + std::move(callback).Run(std::nullopt); + } + return; + } + DLOG(ERROR) << "Failed to send rating: " << result.response_code(); + std::move(callback).Run(std::nullopt); + }, + std::move(callback))); } void ConversationHandler::SendFeedback(const std::string& category, @@ -413,13 +563,15 @@ void ConversationHandler::SendFeedback(const std::string& category, const std::string& rating_id, bool send_hostname, SendFeedbackCallback callback) { + DVLOG(2) << __func__ << ": " << rating_id << ", " << send_hostname << ", " + << category << ", " << feedback; auto on_complete = base::BindOnce( [](SendFeedbackCallback callback, APIRequestResult result) { if (result.Is2XXResponseCode()) { std::move(callback).Run(true); return; } - + DLOG(ERROR) << "Failed to send feedback: " << result.response_code(); std::move(callback).Run(false); }, std::move(callback)); @@ -474,7 +626,7 @@ void ConversationHandler::SubmitHumanConversationEntry( << "than a single human conversation turn at a time."; mojom::ConversationTurnPtr turn = mojom::ConversationTurn::New( - CharacterType::HUMAN, mojom::ActionType::UNSPECIFIED, + std::nullopt, CharacterType::HUMAN, mojom::ActionType::UNSPECIFIED, mojom::ConversationTurnVisibility::VISIBLE, input, std::nullopt, std::nullopt, base::Time::Now(), std::nullopt, false); SubmitHumanConversationEntry(std::move(turn)); @@ -513,13 +665,7 @@ void ConversationHandler::SubmitHumanConversationEntry( DCHECK(latest_turn->character_type == mojom::CharacterType::HUMAN); is_request_in_progress_ = true; OnAPIRequestInProgressChanged(); - // If it's a suggested question, remove it - auto found_question_iter = - base::ranges::find(suggestions_, latest_turn->text); - if (found_question_iter != suggestions_.end()) { - suggestions_.erase(found_question_iter); - OnSuggestedQuestionsChanged(); - } + // Directly modify Entry's text to remove engine-breaking substrings if (!has_edits) { // Edits are already sanitized. engine_->SanitizeInput(latest_turn->text); @@ -533,7 +679,15 @@ void ConversationHandler::SubmitHumanConversationEntry( // callers of SubmitHumanConversationEntry mojo API currently don't have // action_type specified. std::string question_part = latest_turn->text; - if (latest_turn->action_type == mojom::ActionType::UNSPECIFIED) { + // If it's a suggested question, remove it + auto found_question_iter = + base::ranges::find(suggestions_, latest_turn->text, &Suggestion::title); + if (found_question_iter != suggestions_.end()) { + question_part = + found_question_iter->prompt.value_or(found_question_iter->title); + suggestions_.erase(found_question_iter); + OnSuggestedQuestionsChanged(); + } else if (latest_turn->action_type == mojom::ActionType::UNSPECIFIED) { if (latest_turn->text == l10n_util::GetStringUTF8(IDS_CHAT_UI_SUMMARIZE_PAGE)) { latest_turn->action_type = mojom::ActionType::SUMMARIZE_PAGE; @@ -618,6 +772,7 @@ void ConversationHandler::ModifyConversation(uint32_t turn_index, } auto edited_turn = mojom::ConversationTurn::New( + base::Uuid::GenerateRandomV4().AsLowercaseString(), turn->character_type, turn->action_type, turn->visibility, trimmed_input, std::nullopt /* selected_text */, std::move(events), base::Time::Now(), std::nullopt /* edits */, false); @@ -630,7 +785,9 @@ void ConversationHandler::ModifyConversation(uint32_t turn_index, } turn->edits->emplace_back(std::move(edited_turn)); - OnHistoryUpdate(); + OnConversationEntryRemoved(turn->uuid); + OnConversationEntryAdded(turn); + return; } @@ -649,18 +806,26 @@ void ConversationHandler::ModifyConversation(uint32_t turn_index, // editable human turns in our current implementation, just use std::nullopt // here directly to be more explicit and avoid confusion. auto edited_turn = mojom::ConversationTurn::New( - turn->character_type, turn->action_type, turn->visibility, - sanitized_input, std::nullopt /* selected_text */, - std::nullopt /* events */, base::Time::Now(), std::nullopt /* edits */, - false); + base::Uuid::GenerateRandomV4().AsLowercaseString(), turn->character_type, + turn->action_type, turn->visibility, sanitized_input, + std::nullopt /* selected_text */, std::nullopt /* events */, + base::Time::Now(), std::nullopt /* edits */, false); if (!turn->edits) { turn->edits.emplace(); } + // Erase all turns after the edited turn and notify observers + std::vector> erased_turn_ids; + base::ranges::transform( + chat_history_.begin() + turn_index, chat_history_.end(), + std::back_inserter(erased_turn_ids), + [](mojom::ConversationTurnPtr& turn) { return turn->uuid; }); turn->edits->emplace_back(std::move(edited_turn)); - auto new_turn = std::move(chat_history_.at(turn_index)); chat_history_.erase(chat_history_.begin() + turn_index, chat_history_.end()); - OnHistoryUpdate(); + + for (auto& uuid : erased_turn_ids) { + OnConversationEntryRemoved(uuid); + } SubmitHumanConversationEntry(std::move(new_turn)); } @@ -672,16 +837,24 @@ void ConversationHandler::SubmitSummarizationRequest() { << "This conversation request should send page contents"; mojom::ConversationTurnPtr turn = mojom::ConversationTurn::New( - CharacterType::HUMAN, mojom::ActionType::SUMMARIZE_PAGE, + std::nullopt, CharacterType::HUMAN, mojom::ActionType::SUMMARIZE_PAGE, mojom::ConversationTurnVisibility::VISIBLE, l10n_util::GetStringUTF8(IDS_CHAT_UI_SUMMARIZE_PAGE), std::nullopt, std::nullopt, base::Time::Now(), std::nullopt, false); SubmitHumanConversationEntry(std::move(turn)); } -void ConversationHandler::GetSuggestedQuestions( - GetSuggestedQuestionsCallback callback) { - std::move(callback).Run(suggestions_, suggestion_generation_status_); +std::vector ConversationHandler::GetSuggestedQuestionsForTest() { + std::vector suggestions; + base::ranges::transform(suggestions_, std::back_inserter(suggestions), + [](const auto& s) { return s.title; }); + return suggestions; +} + +void ConversationHandler::SetSuggestedQuestionForTest(std::string title, + std::string prompt) { + suggestions_.clear(); + suggestions_.emplace_back(title, prompt); } void ConversationHandler::GenerateQuestions() { @@ -757,7 +930,7 @@ void ConversationHandler::DisassociateContentDelegate() { void ConversationHandler::GetAssociatedContentInfo( GetAssociatedContentInfoCallback callback) { BuildAssociatedContentInfo(); - std::move(callback).Run(associated_content_info_->Clone(), + std::move(callback).Run(metadata_->associated_content->Clone(), should_send_page_contents_); } @@ -806,7 +979,7 @@ void ConversationHandler::ClearErrorAndGetFailedMessage( mojom::ConversationTurnPtr turn = std::move(chat_history_.back()); chat_history_.pop_back(); - OnHistoryUpdate(); + OnConversationEntryRemoved(turn->uuid); std::move(callback).Run(std::move(turn)); } @@ -822,7 +995,7 @@ void ConversationHandler::SubmitSelectedTextWithQuestion( const std::string& question, mojom::ActionType action_type) { mojom::ConversationTurnPtr turn = mojom::ConversationTurn::New( - CharacterType::HUMAN, action_type, + std::nullopt, CharacterType::HUMAN, action_type, mojom::ConversationTurnVisibility::VISIBLE, question, selected_text, std::nullopt, base::Time::Now(), std::nullopt, false); @@ -861,7 +1034,7 @@ void ConversationHandler::AddSubmitSelectedTextError( } const std::string& question = GetActionTypeQuestion(action_type); mojom::ConversationTurnPtr turn = mojom::ConversationTurn::New( - CharacterType::HUMAN, action_type, + std::nullopt, CharacterType::HUMAN, action_type, mojom::ConversationTurnVisibility::VISIBLE, question, selected_text, std::nullopt, base::Time::Now(), std::nullopt, false); AddToConversationHistory(std::move(turn)); @@ -875,6 +1048,10 @@ void ConversationHandler::OnFaviconImageDataChanged() { } } +void ConversationHandler::OnAssociatedContentTitleChanged() { + OnAssociatedContentInfoChanged(); +} + void ConversationHandler::OnUserOptedIn() { MaybePopPendingRequests(); MaybeFetchOrClearContentStagedConversation(); @@ -886,10 +1063,13 @@ void ConversationHandler::AddToConversationHistory( return; } + if (!turn->uuid.has_value()) { + turn->uuid = base::Uuid::GenerateRandomV4().AsLowercaseString(); + } + chat_history_.push_back(std::move(turn)); - OnHistoryUpdate(); - OnConversationEntriesChanged(); + OnConversationEntryAdded(chat_history_.back()); } void ConversationHandler::PerformAssistantGeneration( @@ -957,6 +1137,7 @@ void ConversationHandler::UpdateOrCreateLastAssistantEntry( if (chat_history_.empty() || chat_history_.back()->character_type != CharacterType::ASSISTANT) { mojom::ConversationTurnPtr entry = mojom::ConversationTurn::New( + base::Uuid::GenerateRandomV4().AsLowercaseString(), CharacterType::ASSISTANT, mojom::ActionType::RESPONSE, mojom::ConversationTurnVisibility::VISIBLE, "", std::nullopt, std::vector{}, base::Time::Now(), @@ -1020,14 +1201,42 @@ void ConversationHandler::MaybeSeedOrClearSuggestions() { const bool is_page_associated = IsContentAssociationPossible() && should_send_page_contents_; - if (!is_page_associated && !suggestions_.empty()) { + if (!is_page_associated) { suggestions_.clear(); suggestion_generation_status_ = mojom::SuggestionGenerationStatus::None; + if (!chat_history_.empty()) { + return; + } + + suggestions_.emplace_back(STARTER_PROMPT(MEMO)); + suggestions_.emplace_back(STARTER_PROMPT(INTERVIEW)); + suggestions_.emplace_back(STARTER_PROMPT(STUDY_PLAN)); + suggestions_.emplace_back(STARTER_PROMPT(PROJECT_TIMELINE)); + suggestions_.emplace_back(STARTER_PROMPT(MARKETING_STRATEGY)); + suggestions_.emplace_back(STARTER_PROMPT(PRESENTATION_OUTLINE)); + suggestions_.emplace_back(STARTER_PROMPT(BRAINSTORM)); + suggestions_.emplace_back(STARTER_PROMPT(PROFESSIONAL_EMAIL)); + suggestions_.emplace_back(STARTER_PROMPT(BUSINESS_PROPOSAL)); + + // We don't have an external list of all the available suggestions, so we + // generate all of them and remove random ones until we have the required + // number and then shuffle the result. + while (suggestions_.size() > kDefaultSuggestionsCount) { + auto remove_at = base::RandInt(0, suggestions_.size() - 1); + suggestions_.erase(suggestions_.begin() + remove_at); + } + base::RandomShuffle(suggestions_.begin(), suggestions_.end()); OnSuggestedQuestionsChanged(); return; } - if (is_page_associated && suggestions_.empty() && + // This means we have the default suggestions + if (suggestion_generation_status_ == + mojom::SuggestionGenerationStatus::None) { + suggestions_.clear(); + } + + if (suggestions_.empty() && suggestion_generation_status_ != mojom::SuggestionGenerationStatus::IsGenerating && suggestion_generation_status_ != @@ -1094,22 +1303,26 @@ void ConversationHandler::OnGetStagedEntriesFromContent( return turn->from_brave_search_SERP; }); - // Add the query & summary pairs to the conversation history and call - // OnHistoryUpdate to update UI. + // Add the query & summary pairs to the conversation history and notify + // observers. for (const auto& entry : *entries) { chat_history_.push_back(mojom::ConversationTurn::New( + base::Uuid::GenerateRandomV4().AsLowercaseString(), CharacterType::HUMAN, mojom::ActionType::QUERY, mojom::ConversationTurnVisibility::VISIBLE, entry.query, std::nullopt, std::nullopt, base::Time::Now(), std::nullopt, true)); + OnConversationEntryAdded(chat_history_.back()); + std::vector events; events.push_back(mojom::ConversationEntryEvent::NewCompletionEvent( mojom::CompletionEvent::New(entry.summary))); chat_history_.push_back(mojom::ConversationTurn::New( + base::Uuid::GenerateRandomV4().AsLowercaseString(), CharacterType::ASSISTANT, mojom::ActionType::RESPONSE, mojom::ConversationTurnVisibility::VISIBLE, entry.summary, std::nullopt, std::move(events), base::Time::Now(), std::nullopt, true)); + OnConversationEntryAdded(chat_history_.back()); } - OnHistoryUpdate(); } void ConversationHandler::GeneratePageContent(GetPageContentCallback callback) { @@ -1124,21 +1337,35 @@ void ConversationHandler::GeneratePageContent(GetPageContentCallback callback) { DCHECK(ai_chat_service_->HasUserOptedIn()) << "UI shouldn't allow operations before user has accepted agreement"; + // Keep hold of the current content so we can check if it changed + std::string current_content = + std::string(associated_content_delegate_->GetCachedTextContent()); associated_content_delegate_->GetContent( base::BindOnce(&ConversationHandler::OnGeneratePageContentComplete, - weak_ptr_factory_.GetWeakPtr(), std::move(callback))); + weak_ptr_factory_.GetWeakPtr(), std::move(callback), + std::move(current_content))); } void ConversationHandler::OnGeneratePageContentComplete( GetPageContentCallback callback, + std::string previous_content, std::string contents_text, bool is_video, std::string invalidation_token) { engine_->SanitizeInput(contents_text); + // Keep is_content_different_ as true if it's the initial state + is_content_different_ = + is_content_different_ || contents_text != previous_content; + + metadata_->associated_content->content_type = + is_video ? mojom::ContentType::VideoTranscript + : mojom::ContentType::PageContent; + std::move(callback).Run(contents_text, is_video, invalidation_token); - // Content-used percentage might have changed + // Content-used percentage and is_video might have changed in addition to + // content_type. OnAssociatedContentInfoChanged(); } @@ -1193,11 +1420,13 @@ void ConversationHandler::OnEngineCompletionComplete( UpdateOrCreateLastAssistantEntry( mojom::ConversationEntryEvent::NewCompletionEvent( mojom::CompletionEvent::New(*result))); - OnConversationEntriesChanged(); + OnConversationEntryAdded(chat_history_.back()); } else { auto& last_entry = chat_history_.back(); if (last_entry->character_type != mojom::CharacterType::ASSISTANT) { SetAPIError(mojom::APIError::ConnectionIssue); + } else { + OnConversationEntryAdded(chat_history_.back()); } } MaybePopPendingRequests(); @@ -1224,18 +1453,20 @@ void ConversationHandler::OnEngineCompletionComplete( void ConversationHandler::OnSuggestedQuestionsResponse( EngineConsumer::SuggestedQuestionResult result) { if (result.has_value()) { - suggestions_.insert(suggestions_.end(), result->begin(), result->end()); + std::ranges::transform(result.value(), std::back_inserter(suggestions_), + [](const auto& s) { return Suggestion(s); }); suggestion_generation_status_ = mojom::SuggestionGenerationStatus::HasGenerated; + DVLOG(2) << "Got questions:" << base::JoinString(result.value(), "\n"); } else { // TODO(nullhook): Set a specialized error state generated questions suggestion_generation_status_ = mojom::SuggestionGenerationStatus::CanGenerate; + DVLOG(2) << "Got no questions"; } // Notify observers OnSuggestedQuestionsChanged(); - DVLOG(2) << "Got questions:" << base::JoinString(suggestions_, "\n"); } void ConversationHandler::OnModelListUpdated() { @@ -1280,6 +1511,7 @@ void ConversationHandler::OnModelDataChanged() { [](auto& model) { return model.Clone(); }); client->OnModelDataChanged(model_key_, std::move(models_copy)); } + OnStateForConversationEntriesChanged(); } void ConversationHandler::OnHistoryUpdate() { @@ -1288,20 +1520,57 @@ void ConversationHandler::OnHistoryUpdate() { for (auto& client : conversation_ui_handlers_) { client->OnConversationHistoryUpdate(); } - OnConversationEntriesChanged(); + for (auto& client : untrusted_conversation_ui_handlers_) { + client->OnConversationHistoryUpdate(); + } } -void ConversationHandler::OnConversationEntriesChanged() { +void ConversationHandler::OnConversationEntryRemoved( + std::optional entry_uuid) { + OnHistoryUpdate(); + if (!entry_uuid.has_value()) { + return; + } for (auto& observer : observers_) { - // TODO(petemill): only tell observers about complete turns. This is - // expensive to do for every event generated by in-progress turns, - // and consumers likely only need complete ones (e.g. database save). - std::vector history; - for (const auto& turn : chat_history_) { - history.emplace_back(turn->Clone()); + observer.OnConversationEntryRemoved(this, entry_uuid.value()); + } +} + +void ConversationHandler::OnConversationEntryAdded( + mojom::ConversationTurnPtr& entry) { + // Only notify about staged entries once we have the first staged entry + if (entry->from_brave_search_SERP) { + OnHistoryUpdate(); + return; + } + std::optional associated_content_value; + if (is_content_different_ && associated_content_delegate_) { + associated_content_value = + associated_content_delegate_->GetCachedTextContent(); + is_content_different_ = false; + } + // If this is the first entry that isn't staged, notify about all previous + // staged entries + if (!entry->from_brave_search_SERP && + base::ranges::all_of(chat_history_, + [&entry](mojom::ConversationTurnPtr& history_entry) { + return history_entry == entry || + history_entry->from_brave_search_SERP; + })) { + // Notify every item in chat history + for (auto& observer : observers_) { + for (auto& history_entry : chat_history_) { + observer.OnConversationEntryAdded(this, history_entry, + associated_content_value); + } } - observer.OnConversationEntriesChanged(this, std::move(history)); + OnHistoryUpdate(); + return; + } + for (auto& observer : observers_) { + observer.OnConversationEntryAdded(this, entry, associated_content_value); } + OnHistoryUpdate(); } int ConversationHandler::GetContentUsedPercentage() { @@ -1330,32 +1599,56 @@ bool ConversationHandler::IsContentAssociationPossible() { } void ConversationHandler::BuildAssociatedContentInfo() { - // Save in class instance so that we have a cache for when live - // AssociatedContentDelegate disconnects. Only modify in this function. - associated_content_info_ = mojom::SiteInfo::New(); + // Only modify associated content metadata here if (associated_content_delegate_) { - associated_content_info_->title = + metadata_->associated_content->title = base::UTF16ToUTF8(associated_content_delegate_->GetTitle()); const GURL url = associated_content_delegate_->GetURL(); - if (url.SchemeIsHTTPOrHTTPS()) { - associated_content_info_->hostname = url.host(); - associated_content_info_->url = url; - } - associated_content_info_->content_used_percentage = + metadata_->associated_content->hostname = url.host(); + metadata_->associated_content->url = url; + metadata_->associated_content->content_used_percentage = GetContentUsedPercentage(); - associated_content_info_->is_content_refined = is_content_refined_; - associated_content_info_->is_content_association_possible = true; + metadata_->associated_content->is_content_refined = is_content_refined_; + metadata_->associated_content->is_content_association_possible = true; } else { - associated_content_info_->is_content_association_possible = false; + metadata_->associated_content->title = std::nullopt; + metadata_->associated_content->hostname = std::nullopt; + metadata_->associated_content->url = std::nullopt; + metadata_->associated_content->is_content_association_possible = false; } } +mojom::ConversationEntriesStatePtr +ConversationHandler::GetStateForConversationEntries() { + auto& model = GetCurrentModel(); + bool is_leo_model = model.options->is_leo_model_options(); + + mojom::ConversationEntriesStatePtr entries_state = + mojom::ConversationEntriesState::New(); + entries_state->is_generating = IsRequestInProgress(); + entries_state->is_content_refined = is_content_refined_; + entries_state->is_leo_model = is_leo_model; + entries_state->content_used_percentage = + metadata_->associated_content->is_content_association_possible + ? std::make_optional( + metadata_->associated_content->content_used_percentage) + : std::nullopt; + // Can't submit if not a premium user and the model is premium-only + entries_state->can_submit_user_entries = + !IsRequestInProgress() && + (ai_chat_service_->IsPremiumStatus() || !is_leo_model || + model.options->get_leo_model_options()->access != + mojom::ModelAccess::PREMIUM); + return entries_state; +} + void ConversationHandler::OnAssociatedContentInfoChanged() { BuildAssociatedContentInfo(); for (auto& client : conversation_ui_handlers_) { - client->OnAssociatedContentInfoChanged(associated_content_info_->Clone(), - should_send_page_contents_); + client->OnAssociatedContentInfoChanged( + metadata_->associated_content->Clone(), should_send_page_contents_); } + OnStateForConversationEntriesChanged(); } void ConversationHandler::OnClientConnectionChanged() { @@ -1387,19 +1680,39 @@ void ConversationHandler::OnAssociatedContentFaviconImageDataChanged() { for (auto& client : conversation_ui_handlers_) { client->OnFaviconImageDataChanged(); } + for (auto& client : untrusted_conversation_ui_handlers_) { + client->OnFaviconImageDataChanged(); + } } void ConversationHandler::OnSuggestedQuestionsChanged() { + std::vector suggestions; + std::ranges::transform(suggestions_, std::back_inserter(suggestions), + [](const auto& s) { return s.title; }); + for (auto& client : conversation_ui_handlers_) { - client->OnSuggestedQuestionsChanged(suggestions_, + client->OnSuggestedQuestionsChanged(suggestions, suggestion_generation_status_); } } void ConversationHandler::OnAPIRequestInProgressChanged() { + OnStateForConversationEntriesChanged(); for (auto& client : conversation_ui_handlers_) { client->OnAPIRequestInProgress(is_request_in_progress_); } + for (auto& observer : observers_) { + observer.OnRequestInProgressChanged(this, is_request_in_progress_); + } +} + +void ConversationHandler::OnStateForConversationEntriesChanged() { + auto entries_state = GetStateForConversationEntries(); + for (auto& client : untrusted_conversation_ui_handlers_) { + client->OnEntriesUIStateChanged(entries_state->Clone()); + } } } // namespace ai_chat + +#undef STARTER_PROMPT diff --git a/components/ai_chat/core/browser/conversation_handler.h b/components/ai_chat/core/browser/conversation_handler.h index e57e6d0a4244..3cfe6ec46c74 100644 --- a/components/ai_chat/core/browser/conversation_handler.h +++ b/components/ai_chat/core/browser/conversation_handler.h @@ -6,21 +6,35 @@ #ifndef BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_CONVERSATION_HANDLER_H_ #define BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_CONVERSATION_HANDLER_H_ +#include +#include #include +#include +#include #include +#include #include #include #include +#include "base/functional/callback.h" #include "base/functional/callback_forward.h" +#include "base/gtest_prod_util.h" +#include "base/memory/raw_ptr.h" +#include "base/memory/scoped_refptr.h" #include "base/memory/weak_ptr.h" +#include "base/observer_list.h" +#include "base/observer_list_types.h" #include "base/scoped_observation.h" +#include "base/task/sequenced_task_runner.h" +#include "base/types/expected.h" #include "brave/components/ai_chat/core/browser/ai_chat_credential_manager.h" #include "brave/components/ai_chat/core/browser/engine/engine_consumer.h" #include "brave/components/ai_chat/core/browser/model_service.h" #include "brave/components/ai_chat/core/browser/text_embedder.h" #include "brave/components/ai_chat/core/browser/types.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-forward.h" +#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-shared.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" #include "mojo/public/cpp/bindings/pending_receiver.h" #include "mojo/public/cpp/bindings/receiver_set.h" @@ -28,6 +42,12 @@ #include "url/gurl.h" class AIChatUIBrowserTest; +namespace mojo { +template +class PendingRemote; +template +class PendingReceiver; +} // namespace mojo namespace network { class SharedURLLoaderFactory; @@ -38,11 +58,13 @@ namespace ai_chat { class AIChatFeedbackAPI; class AIChatService; class AssociatedArchiveContent; +class AIChatCredentialManager; // Performs all conversation-related operations, responsible for sending // messages to the conversation engine, handling the responses, and owning // the in-memory conversation history. class ConversationHandler : public mojom::ConversationHandler, + public mojom::UntrustedConversationHandler, public ModelService::Observer { public: // |invalidation_token| is an optional parameter that will be passed back on @@ -131,9 +153,16 @@ class ConversationHandler : public mojom::ConversationHandler, ~Observer() override {} // Called when the conversation history changess - virtual void OnConversationEntriesChanged( + virtual void OnRequestInProgressChanged(ConversationHandler* handler, + bool in_progress) {} + virtual void OnConversationEntryAdded( ConversationHandler* handler, - std::vector entries) {} + mojom::ConversationTurnPtr& entry, + std::optional associated_content_value) {} + virtual void OnConversationEntryRemoved(ConversationHandler* handler, + std::string turn_uuid) {} + virtual void OnConversationEntryUpdated(ConversationHandler* handler, + mojom::ConversationTurnPtr entry) {} // Called when a mojo client connects or disconnects virtual void OnClientConnectionChanged(ConversationHandler* handler) {} @@ -142,18 +171,25 @@ class ConversationHandler : public mojom::ConversationHandler, virtual void OnSelectedLanguageChanged( ConversationHandler* handler, const std::string& selected_language) {} - virtual void OnAssociatedContentDestroyed(ConversationHandler* handler, - int content_id) {} }; ConversationHandler( - const mojom::Conversation* conversation, + mojom::Conversation* conversation, AIChatService* ai_chat_service, ModelService* model_service, AIChatCredentialManager* credential_manager, AIChatFeedbackAPI* feedback_api, scoped_refptr url_loader_factory); + ConversationHandler( + mojom::Conversation* conversation, + AIChatService* ai_chat_service, + ModelService* model_service, + AIChatCredentialManager* credential_manager, + AIChatFeedbackAPI* feedback_api, + scoped_refptr url_loader_factory, + std::optional initial_state); + ~ConversationHandler() override; ConversationHandler(const ConversationHandler&) = delete; ConversationHandler& operator=(const ConversationHandler&) = delete; @@ -161,16 +197,23 @@ class ConversationHandler : public mojom::ConversationHandler, void Bind(mojo::PendingRemote conversation_ui_handler); void Bind(mojo::PendingReceiver receiver, mojo::PendingRemote conversation_ui_handler); + void Bind( + mojo::PendingReceiver receiver); + void BindUntrustedConversationUI( + mojo::PendingRemote + untrusted_conversation_ui_handler, + BindUntrustedConversationUICallback callback) override; void AddObserver(Observer* observer); void RemoveObserver(Observer* observer); + // Called when the provided Conversation data is updated + void OnConversationMetadataUpdated(); + void OnArchiveContentUpdated(mojom::ConversationArchivePtr conversation_data); + bool IsAnyClientConnected(); bool HasAnyHistory(); - void OnConversationDeleted(); - - // Returns true if the conversation has associated content that is non-archive - bool IsAssociatedContentAlive(); + bool IsRequestInProgress(); // Called when the associated content is destroyed or navigated away. If // it's a navigation, the AssociatedContentDelegate will set itself to a new @@ -189,7 +232,7 @@ class ConversationHandler : public mojom::ConversationHandler, void GetState(GetStateCallback callback) override; void GetConversationHistory(GetConversationHistoryCallback callback) override; void RateMessage(bool is_liked, - uint32_t turn_id, + const std::string& turn_uuid, RateMessageCallback callback) override; void SendFeedback(const std::string& category, const std::string& feedback, @@ -208,7 +251,8 @@ class ConversationHandler : public mojom::ConversationHandler, void ModifyConversation(uint32_t turn_index, const std::string& new_text) override; void SubmitSummarizationRequest() override; - void GetSuggestedQuestions(GetSuggestedQuestionsCallback callback) override; + std::vector GetSuggestedQuestionsForTest(); + void SetSuggestedQuestionForTest(std::string title, std::string prompt); void GenerateQuestions() override; void GetAssociatedContentInfo( GetAssociatedContentInfoCallback callback) override; @@ -228,6 +272,7 @@ class ConversationHandler : public mojom::ConversationHandler, void AddSubmitSelectedTextError(const std::string& selected_text, mojom::ActionType action_type, mojom::APIError error); + void OnAssociatedContentTitleChanged(); void OnFaviconImageDataChanged(); void OnUserOptedIn(); @@ -249,7 +294,9 @@ class ConversationHandler : public mojom::ConversationHandler, void SetChatHistoryForTesting( std::vector history) { chat_history_ = std::move(history); - OnHistoryUpdate(); + for (auto& entry : chat_history_) { + OnConversationEntryAdded(entry); + } } AssociatedContentDelegate* GetAssociatedContentDelegateForTesting() { @@ -269,6 +316,7 @@ class ConversationHandler : public mojom::ConversationHandler, private: friend class ::AIChatUIBrowserTest; + FRIEND_TEST_ALL_PREFIXES(AIChatServiceUnitTest, DeleteAssociatedWebContent); FRIEND_TEST_ALL_PREFIXES(ConversationHandlerUnitTest, UpdateOrCreateLastAssistantEntry_Delta); FRIEND_TEST_ALL_PREFIXES(ConversationHandlerUnitTest, @@ -285,12 +333,23 @@ class ConversationHandler : public mojom::ConversationHandler, FRIEND_TEST_ALL_PREFIXES(PageContentRefineTest, LocalModelsUpdater); FRIEND_TEST_ALL_PREFIXES(PageContentRefineTest, TextEmbedder); FRIEND_TEST_ALL_PREFIXES(PageContentRefineTest, TextEmbedderInitialized); - FRIEND_TEST_ALL_PREFIXES(PageContentRefineTest, LeoLocalModelsUpdater); - FRIEND_TEST_ALL_PREFIXES(PageContentRefineTest, TextEmbedder); - FRIEND_TEST_ALL_PREFIXES(PageContentRefineTest, TextEmbedderInitialized); + + struct Suggestion { + std::string title; + std::optional prompt; + + explicit Suggestion(std::string title); + Suggestion(std::string title, std::string prompt); + Suggestion(const Suggestion&) = delete; + Suggestion& operator=(const Suggestion&) = delete; + Suggestion(Suggestion&&); + Suggestion& operator=(Suggestion&&); + ~Suggestion(); + }; void InitEngine(); void BuildAssociatedContentInfo(); + mojom::ConversationEntriesStatePtr GetStateForConversationEntries(); bool IsContentAssociationPossible(); int GetContentUsedPercentage(); void AddToConversationHistory(mojom::ConversationTurnPtr turn); @@ -315,11 +374,11 @@ class ConversationHandler : public mojom::ConversationHandler, const std::optional>& entries); void GeneratePageContent(GetPageContentCallback callback); - void SetPageContent(std::string contents_text, - bool is_video, - std::string invalidation_token); + + void SetArchiveContent(std::string text_content, bool is_video); void OnGeneratePageContentComplete(GetPageContentCallback callback, + std::string previous_content, std::string contents_text, bool is_video, std::string invalidation_token); @@ -336,16 +395,19 @@ class ConversationHandler : public mojom::ConversationHandler, EngineConsumer::SuggestedQuestionResult result); void OnModelDataChanged(); + void OnConversationDeleted(); void OnHistoryUpdate(); + void OnConversationEntryAdded(mojom::ConversationTurnPtr& entry); + void OnConversationEntryRemoved(std::optional turn_id); void OnSuggestedQuestionsChanged(); void OnAssociatedContentInfoChanged(); - void OnConversationEntriesChanged(); void OnClientConnectionChanged(); void OnConversationTitleChanged(std::string title); void OnConversationUIConnectionChanged(mojo::RemoteSetElementId id); void OnSelectedLanguageChanged(const std::string& selected_language); void OnAssociatedContentFaviconImageDataChanged(); void OnAPIRequestInProgressChanged(); + void OnStateForConversationEntriesChanged(); base::WeakPtr associated_content_delegate_; std::unique_ptr archive_content_; @@ -355,12 +417,11 @@ class ConversationHandler : public mojom::ConversationHandler, std::vector chat_history_; mojom::ConversationTurnPtr pending_conversation_entry_; // Any previously-generated suggested questions - std::vector suggestions_; + std::vector suggestions_; std::string selected_language_; // Is a conversation engine request in progress (does not include // non-conversation engine requests. bool is_request_in_progress_ = false; - mojom::SiteInfoPtr associated_content_info_ = nullptr; // TODO(petemill): Tracking whether the UI is open // for a conversation might not be neccessary anymore as there @@ -385,6 +446,9 @@ class ConversationHandler : public mojom::ConversationHandler, // different pages. bool should_send_page_contents_ = false; bool is_content_refined_ = false; + // When this is true, the most recent content retrieval was different to the + // previous one. + bool is_content_different_ = true; bool is_print_preview_fallback_requested_ = false; @@ -392,7 +456,7 @@ class ConversationHandler : public mojom::ConversationHandler, mojom::APIError current_error_ = mojom::APIError::None; // Data store UUID for conversation - raw_ptr metadata_; + raw_ptr metadata_; raw_ptr ai_chat_service_; raw_ptr model_service_; raw_ptr credential_manager_; @@ -404,8 +468,10 @@ class ConversationHandler : public mojom::ConversationHandler, base::ObserverList observers_; mojo::ReceiverSet receivers_; - // TODO(petemill): Rename to ConversationUIHandler + mojo::ReceiverSet untrusted_receivers_; mojo::RemoteSet conversation_ui_handlers_; + mojo::RemoteSet + untrusted_conversation_ui_handlers_; base::WeakPtrFactory weak_ptr_factory_{this}; }; diff --git a/components/ai_chat/core/browser/conversation_handler_unittest.cc b/components/ai_chat/core/browser/conversation_handler_unittest.cc index 4f621e362128..d90d053db228 100644 --- a/components/ai_chat/core/browser/conversation_handler_unittest.cc +++ b/components/ai_chat/core/browser/conversation_handler_unittest.cc @@ -12,11 +12,13 @@ #include #include +#include "base/files/scoped_temp_dir.h" #include "base/functional/bind.h" #include "base/functional/callback.h" #include "base/functional/overloaded.h" #include "base/memory/scoped_refptr.h" #include "base/ranges/algorithm.h" +#include "base/run_loop.h" #include "base/scoped_observation.h" #include "base/task/sequenced_task_runner.h" #include "base/task/thread_pool.h" @@ -30,6 +32,8 @@ #include "brave/components/ai_chat/core/browser/ai_chat_credential_manager.h" #include "brave/components/ai_chat/core/browser/ai_chat_service.h" #include "brave/components/ai_chat/core/browser/engine/mock_engine_consumer.h" +#include "brave/components/ai_chat/core/browser/mock_conversation_handler_observer.h" +#include "brave/components/ai_chat/core/browser/test_utils.h" #include "brave/components/ai_chat/core/browser/text_embedder.h" #include "brave/components/ai_chat/core/browser/types.h" #include "brave/components/ai_chat/core/browser/utils.h" @@ -38,6 +42,8 @@ #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" #include "brave/components/ai_chat/core/common/pref_names.h" #include "components/grit/brave_components_strings.h" +#include "components/os_crypt/async/browser/os_crypt_async.h" +#include "components/os_crypt/async/browser/test_utils.h" #include "components/sync_preferences/testing_pref_service_syncable.h" #include "mojo/public/cpp/bindings/receiver.h" #include "services/data_decoder/public/cpp/test_support/in_process_data_decoder.h" @@ -60,17 +66,6 @@ namespace ai_chat { namespace { -bool CompareConversationTurn(const mojom::ConversationTurnPtr& a, - const mojom::ConversationTurnPtr& b) { - if (!a || !b) { - return a == b; // Both should be null or neither - } - return a->action_type == b->action_type && - a->character_type == b->character_type && - a->selected_text == b->selected_text && a->text == b->text && - a->visibility == b->visibility; -} - class MockAIChatCredentialManager : public AIChatCredentialManager { public: using AIChatCredentialManager::AIChatCredentialManager; @@ -183,10 +178,14 @@ class MockTextEmbedder : public TextEmbedder { class ConversationHandlerUnitTest : public testing::Test { public: void SetUp() override { + ASSERT_TRUE(temp_directory_.CreateUniqueTempDir()); prefs::RegisterProfilePrefs(prefs_.registry()); prefs::RegisterLocalStatePrefs(local_state_.registry()); ModelService::RegisterProfilePrefs(prefs_.registry()); + os_crypt_ = os_crypt_async::GetTestOSCryptAsyncForTesting( + /*is_sync_for_unittests=*/true); + shared_url_loader_factory_ = base::MakeRefCounted( &url_loader_factory_); @@ -205,10 +204,15 @@ class ConversationHandlerUnitTest : public testing::Test { ai_chat_service_ = std::make_unique( model_service_.get(), std::move(credential_manager), &prefs_, nullptr, - shared_url_loader_factory_, ""); + os_crypt_.get(), shared_url_loader_factory_, "", + temp_directory_.GetPath()); + mojom::SiteInfoPtr non_content = mojom::SiteInfo::New( + std::nullopt, mojom::ContentType::PageContent, std::nullopt, + std::nullopt, std::nullopt, 0, false, false); conversation_ = - mojom::Conversation::New("uuid", "title", base::Time::Now(), false); + mojom::Conversation::New("uuid", "title", base::Time::Now(), false, + std::nullopt, std::move(non_content)); conversation_handler_ = std::make_unique( conversation_.get(), ai_chat_service_.get(), model_service_.get(), @@ -282,6 +286,7 @@ class ConversationHandlerUnitTest : public testing::Test { } auto entry = mojom::ConversationTurn::New( + std::nullopt, is_human ? mojom::CharacterType::HUMAN : mojom::CharacterType::ASSISTANT, is_human ? mojom::ActionType::QUERY : mojom::ActionType::RESPONSE, @@ -300,6 +305,7 @@ class ConversationHandlerUnitTest : public testing::Test { std::unique_ptr model_service_; sync_preferences::TestingPrefServiceSyncable prefs_; sync_preferences::TestingPrefServiceSyncable local_state_; + std::unique_ptr os_crypt_; network::TestURLLoaderFactory url_loader_factory_; scoped_refptr shared_url_loader_factory_; data_decoder::test::InProcessDataDecoder in_process_data_decoder_; @@ -308,6 +314,9 @@ class ConversationHandlerUnitTest : public testing::Test { std::unique_ptr> associated_content_; bool is_opted_in_ = true; bool has_associated_content_ = true; + + private: + base::ScopedTempDir temp_directory_; }; class ConversationHandlerUnitTest_OptedOut @@ -355,7 +364,7 @@ TEST_F(ConversationHandlerUnitTest, GetState) { testing::ElementsAre(l10n_util::GetStringUTF8( IDS_CHAT_UI_SUMMARIZE_PAGE))); } else { - EXPECT_TRUE(state->suggested_questions.empty()); + EXPECT_EQ(4u, state->suggested_questions.size()); } EXPECT_EQ(state->suggestion_status, should_send_content @@ -378,6 +387,7 @@ TEST_F(ConversationHandlerUnitTest, SubmitSelectedText) { std::string selected_text = "I have spoken."; std::string expected_turn_text = l10n_util::GetStringUTF8(IDS_AI_CHAT_QUESTION_SUMMARIZE_SELECTED_TEXT); + const std::string expected_response = "This is the way."; // Expect the ConversationHandler to call the engine with the selected text // and the action's expanded text. @@ -389,7 +399,7 @@ TEST_F(ConversationHandlerUnitTest, SubmitSelectedText) { .WillOnce(::testing::DoAll( base::test::RunOnceCallback<5>( mojom::ConversationEntryEvent::NewCompletionEvent( - mojom::CompletionEvent::New("This is the way."))), + mojom::CompletionEvent::New(expected_response))), base::test::RunOnceCallback<6>(base::ok("")))); EXPECT_FALSE(conversation_handler_->HasAnyHistory()); @@ -410,8 +420,8 @@ TEST_F(ConversationHandlerUnitTest, SubmitSelectedText) { NiceMock client(conversation_handler_.get()); EXPECT_CALL(client, OnAPIRequestInProgress(true)).Times(1); - // Human and AI entries. - EXPECT_CALL(client, OnConversationHistoryUpdate()).Times(2); + // Human, AI entries and content event for AI response. + EXPECT_CALL(client, OnConversationHistoryUpdate()).Times(3); // Fired from OnEngineCompletionComplete. EXPECT_CALL(client, OnAPIRequestInProgress(false)).Times(1); // Ensure everything is sanitized @@ -426,7 +436,7 @@ TEST_F(ConversationHandlerUnitTest, SubmitSelectedText) { .Times(1); conversation_handler_->SubmitSelectedText( - "I have spoken.", mojom::ActionType::SUMMARIZE_SELECTED_TEXT); + selected_text, mojom::ActionType::SUMMARIZE_SELECTED_TEXT); task_environment_.RunUntilIdle(); testing::Mock::VerifyAndClearExpectations(&client); @@ -440,28 +450,27 @@ TEST_F(ConversationHandlerUnitTest, SubmitSelectedText) { // once conversation history is committed. EXPECT_FALSE(site_info->is_content_association_possible); })); - conversation_handler_->GetSuggestedQuestions( - base::BindLambdaForTesting([&](const std::vector& questions, - mojom::SuggestionGenerationStatus status) { - EXPECT_TRUE(questions.empty()); - })); + EXPECT_TRUE(conversation_handler_->GetSuggestedQuestionsForTest().empty()); EXPECT_TRUE(conversation_handler_->HasAnyHistory()); const auto& history = conversation_handler_->GetConversationHistory(); std::vector expected_history; + expected_history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::HUMAN, mojom::ActionType::SUMMARIZE_SELECTED_TEXT, - mojom::ConversationTurnVisibility::VISIBLE, - l10n_util::GetStringUTF8(IDS_AI_CHAT_QUESTION_SUMMARIZE_SELECTED_TEXT), - "I have spoken.", std::nullopt, base::Time::Now(), std::nullopt, false)); + std::nullopt, mojom::CharacterType::HUMAN, + mojom::ActionType::SUMMARIZE_SELECTED_TEXT, + mojom::ConversationTurnVisibility::VISIBLE, expected_turn_text, + selected_text, std::nullopt, base::Time::Now(), std::nullopt, false)); + + std::vector response_events; + response_events.push_back(mojom::ConversationEntryEvent::NewCompletionEvent( + mojom::CompletionEvent::New(expected_response))); expected_history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE, - mojom::ConversationTurnVisibility::VISIBLE, "This is the way.", - std::nullopt, std::nullopt, base::Time::Now(), std::nullopt, false)); - EXPECT_EQ(history.size(), expected_history.size()); - for (size_t i = 0; i < history.size(); i++) { - EXPECT_TRUE(CompareConversationTurn(history[i], expected_history[i])); - } + std::nullopt, mojom::CharacterType::ASSISTANT, + mojom::ActionType::RESPONSE, mojom::ConversationTurnVisibility::VISIBLE, + expected_response, std::nullopt, std::move(response_events), + base::Time::Now(), std::nullopt, false)); + ExpectConversationHistoryEquals(FROM_HERE, history, expected_history, false); } TEST_F(ConversationHandlerUnitTest, SubmitSelectedText_WithAssociatedContent) { @@ -475,6 +484,7 @@ TEST_F(ConversationHandlerUnitTest, SubmitSelectedText_WithAssociatedContent) { std::string selected_text = "I have spoken again."; std::string expected_turn_text = l10n_util::GetStringUTF8(IDS_AI_CHAT_QUESTION_SUMMARIZE_SELECTED_TEXT); + std::string expected_response = "This is the way."; EXPECT_CALL(*engine, GenerateAssistantResponse( false, StrEq(page_content), LastTurnHasSelectedText(selected_text), @@ -483,7 +493,7 @@ TEST_F(ConversationHandlerUnitTest, SubmitSelectedText_WithAssociatedContent) { .WillOnce(::testing::DoAll( base::test::RunOnceCallback<5>( mojom::ConversationEntryEvent::NewCompletionEvent( - mojom::CompletionEvent::New("This is the way."))), + mojom::CompletionEvent::New(expected_response))), base::test::RunOnceCallback<6>(base::ok("")))); ON_CALL(*associated_content_, GetURL) @@ -500,8 +510,8 @@ TEST_F(ConversationHandlerUnitTest, SubmitSelectedText_WithAssociatedContent) { NiceMock client(conversation_handler_.get()); EXPECT_CALL(client, OnAPIRequestInProgress(true)).Times(1); - // Human and AI entries. - EXPECT_CALL(client, OnConversationHistoryUpdate()).Times(2); + // Human and AI entries, and content event for AI response. + EXPECT_CALL(client, OnConversationHistoryUpdate()).Times(3); // Fired from OnEngineCompletionComplete. EXPECT_CALL(client, OnAPIRequestInProgress(false)).Times(1); // Ensure everything is sanitized @@ -512,7 +522,7 @@ TEST_F(ConversationHandlerUnitTest, SubmitSelectedText_WithAssociatedContent) { EXPECT_CALL(*engine, GenerateQuestionSuggestions).Times(0); conversation_handler_->SubmitSelectedText( - "I have spoken again.", mojom::ActionType::SUMMARIZE_SELECTED_TEXT); + selected_text, mojom::ActionType::SUMMARIZE_SELECTED_TEXT); task_environment_.RunUntilIdle(); testing::Mock::VerifyAndClearExpectations(&client); @@ -527,28 +537,28 @@ TEST_F(ConversationHandlerUnitTest, SubmitSelectedText_WithAssociatedContent) { // Should not be any LLM-generated suggested questions yet because they // weren't asked for - conversation_handler_->GetSuggestedQuestions( - base::BindLambdaForTesting([&](const std::vector& questions, - mojom::SuggestionGenerationStatus status) { - EXPECT_EQ(1u, questions.size()); - EXPECT_EQ(questions[0], "Summarize this page"); - })); + const auto questions = conversation_handler_->GetSuggestedQuestionsForTest(); + EXPECT_EQ(1u, questions.size()); + EXPECT_EQ(questions[0], "Summarize this page"); const auto& history2 = conversation_handler_->GetConversationHistory(); std::vector expected_history2; expected_history2.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::HUMAN, mojom::ActionType::SUMMARIZE_SELECTED_TEXT, + std::nullopt, mojom::CharacterType::HUMAN, + mojom::ActionType::SUMMARIZE_SELECTED_TEXT, mojom::ConversationTurnVisibility::VISIBLE, expected_turn_text, - "I have spoken again.", std::nullopt, base::Time::Now(), std::nullopt, - false)); + selected_text, std::nullopt, base::Time::Now(), std::nullopt, false)); + + std::vector response_events; + response_events.push_back(mojom::ConversationEntryEvent::NewCompletionEvent( + mojom::CompletionEvent::New(expected_response))); expected_history2.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE, - mojom::ConversationTurnVisibility::VISIBLE, "This is the way.", - std::nullopt, std::nullopt, base::Time::Now(), std::nullopt, false)); - EXPECT_EQ(history2.size(), expected_history2.size()); - for (size_t i = 0; i < history2.size(); i++) { - EXPECT_TRUE(CompareConversationTurn(history2[i], expected_history2[i])); - } + std::nullopt, mojom::CharacterType::ASSISTANT, + mojom::ActionType::RESPONSE, mojom::ConversationTurnVisibility::VISIBLE, + expected_response, std::nullopt, std::move(response_events), + base::Time::Now(), std::nullopt, false)); + ExpectConversationHistoryEquals(FROM_HERE, history2, expected_history2, + false); } TEST_F(ConversationHandlerUnitTest, UpdateOrCreateLastAssistantEntry_Delta) { @@ -795,111 +805,119 @@ TEST_F(ConversationHandlerUnitTest, ModifyConversation) { MockEngineConsumer* engine = static_cast( conversation_handler_->GetEngineForTesting()); - // Setup history for testing. - auto created_time1 = base::Time::Now(); - std::vector history; - history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::HUMAN, mojom::ActionType::QUERY, - mojom::ConversationTurnVisibility::VISIBLE, "prompt1", std::nullopt, - std::nullopt, created_time1, std::nullopt, false)); - history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE, - mojom::ConversationTurnVisibility::VISIBLE, "answer1", std::nullopt, - std::nullopt, base::Time::Now(), std::nullopt, false)); - conversation_handler_->SetChatHistoryForTesting(std::move(history)); - + // Setup history for testing. Items have IDs so we can test removal + // notifications to an observer. + std::vector history = CreateSampleChatHistory(1); + EXPECT_FALSE(history[0]->edits); + conversation_handler_->SetChatHistoryForTesting(CloneHistory(history)); + mojom::ConversationEntryEventPtr expected_new_completion_event = + mojom::ConversationEntryEvent::NewCompletionEvent( + mojom::CompletionEvent::New("new answer")); // Modify an entry for the first time. EXPECT_CALL(*engine, GenerateAssistantResponse( false, StrEq(""), LastTurnHasText("prompt2"), StrEq("prompt2"), StrEq(""), _, _)) // Mock the response from the engine - .WillOnce(::testing::DoAll( - base::test::RunOnceCallback<5>( - mojom::ConversationEntryEvent::NewCompletionEvent( - mojom::CompletionEvent::New("new answer"))), - base::test::RunOnceCallback<6>(base::ok("")))); + .WillOnce(::testing::DoAll(base::test::RunOnceCallback<5>( + expected_new_completion_event->Clone()), + base::test::RunOnceCallback<6>(base::ok("")))); + testing::NiceMock observer; + // Verify both entries are removed + EXPECT_CALL(observer, OnConversationEntryRemoved(conversation_handler_.get(), + history[0]->uuid.value())) + .Times(1); + EXPECT_CALL(observer, OnConversationEntryRemoved(conversation_handler_.get(), + history[1]->uuid.value())) + .Times(1); + // Verify edited entry is added as well as the new response + EXPECT_CALL(observer, + OnConversationEntryAdded(conversation_handler_.get(), _, _)) + .Times(2); + observer.Observe(conversation_handler_.get()); + + // Make a first edit conversation_handler_->ModifyConversation(0, "prompt2"); + testing::Mock::VerifyAndClearExpectations(&observer); + + // Create the entries events in the way we're expecting to look + // post-modification. + auto first_edit_expected_history = CloneHistory(history); + auto first_edit = history[0]->Clone(); + first_edit->uuid = "ignore_me"; + first_edit->selected_text = std::nullopt; + first_edit->text = "prompt2"; + + first_edit_expected_history[0]->edits.emplace(); + first_edit_expected_history[0]->edits->push_back(first_edit->Clone()); + + first_edit_expected_history[1]->text = "new answer"; + first_edit_expected_history[1]->events.emplace(); + first_edit_expected_history[1]->events->emplace_back( + expected_new_completion_event->Clone()); + + // Verify the first entry still has original details const auto& conversation_history = conversation_handler_->GetConversationHistory(); - ASSERT_EQ(conversation_history.size(), 2u); - EXPECT_EQ(conversation_history[0]->text, "prompt1"); - EXPECT_EQ(conversation_history[0]->created_time, created_time1); - - ASSERT_TRUE(conversation_history[0]->edits); - ASSERT_EQ(conversation_history[0]->edits->size(), 1u); - EXPECT_EQ(conversation_history[0]->edits->at(0)->text, "prompt2"); - EXPECT_NE(conversation_history[0]->edits->at(0)->created_time, created_time1); - EXPECT_FALSE(conversation_history[0]->edits->at(0)->edits); - EXPECT_EQ(conversation_history[1]->text, "new answer"); + ExpectConversationHistoryEquals(FROM_HERE, conversation_history, + first_edit_expected_history, false); + // Create time shouldn't be changed + EXPECT_EQ(conversation_history[0]->created_time, history[0]->created_time); auto created_time2 = conversation_history[0]->edits->at(0)->created_time; + // New edit should have a different created time + EXPECT_NE(created_time2, history[0]->created_time); // Modify the same entry again. EXPECT_CALL(*engine, GenerateAssistantResponse( false, StrEq(""), LastTurnHasText("prompt3"), StrEq("prompt3"), StrEq(""), _, _)) // Mock the response from the engine - .WillOnce(::testing::DoAll( - base::test::RunOnceCallback<5>( - mojom::ConversationEntryEvent::NewCompletionEvent( - mojom::CompletionEvent::New("new answer"))), - base::test::RunOnceCallback<6>(base::ok("")))); + .WillOnce(::testing::DoAll(base::test::RunOnceCallback<5>( + expected_new_completion_event->Clone()), + base::test::RunOnceCallback<6>(base::ok("")))); conversation_handler_->ModifyConversation(0, "prompt3"); - ASSERT_EQ(conversation_history.size(), 2u); - EXPECT_EQ(conversation_history[0]->text, "prompt1"); - EXPECT_EQ(conversation_history[0]->created_time, created_time1); - ASSERT_TRUE(conversation_history[0]->edits); - ASSERT_EQ(conversation_history[0]->edits->size(), 2u); - EXPECT_EQ(conversation_history[0]->edits->at(0)->text, "prompt2"); - EXPECT_EQ(conversation_history[0]->edits->at(0)->created_time, created_time2); - EXPECT_FALSE(conversation_history[0]->edits->at(0)->edits); + auto second_edit_expected_history = CloneHistory(first_edit_expected_history); + auto second_edit = first_edit->Clone(); + second_edit->text = "prompt3"; + second_edit_expected_history[0]->edits->emplace_back(second_edit->Clone()); - EXPECT_EQ(conversation_history[0]->edits->at(1)->text, "prompt3"); - EXPECT_NE(conversation_history[0]->edits->at(1)->created_time, created_time1); + ExpectConversationHistoryEquals(FROM_HERE, conversation_history, + second_edit_expected_history, false); + // Create time shouldn't be changed + EXPECT_EQ(conversation_history[0]->created_time, history[0]->created_time); + // New edit should have a different create time + EXPECT_EQ(conversation_history[0]->edits->at(0)->created_time, created_time2); + EXPECT_NE(conversation_history[0]->edits->at(1)->created_time, + conversation_history[0]->created_time); EXPECT_NE(conversation_history[0]->edits->at(1)->created_time, created_time2); - EXPECT_FALSE(conversation_history[0]->edits->at(1)->edits); - - EXPECT_EQ(conversation_history[1]->text, "new answer"); - // Modify server response should have text and completion event updated in + // Modifying server response should have text and completion event updated in // the entry of edits. // Engine should not be called for an assistant edit EXPECT_CALL(*engine, GenerateAssistantResponse(_, _, _, _, _, _, _)).Times(0); conversation_handler_->ModifyConversation(1, " answer2 "); - ASSERT_EQ(conversation_history.size(), 2u); - EXPECT_EQ(conversation_history[1]->text, "new answer"); - ASSERT_TRUE(conversation_history[1]->edits); - ASSERT_EQ(conversation_history[1]->edits->size(), 1u); - EXPECT_EQ(conversation_history[1]->edits->at(0)->text, "answer2"); + auto third_edit_expected_history = CloneHistory(second_edit_expected_history); + + auto response_edit = third_edit_expected_history[1]->Clone(); + response_edit->uuid = "ignore_me"; + response_edit->text = "answer2"; // trimmed + response_edit->events->at(0) = + mojom::ConversationEntryEvent::NewCompletionEvent( + mojom::CompletionEvent::New("answer2")); + + third_edit_expected_history[1]->edits.emplace(); + third_edit_expected_history[1]->edits->emplace_back(response_edit->Clone()); + + ExpectConversationHistoryEquals(FROM_HERE, conversation_history, + third_edit_expected_history, false); + + // Edit time should be set differently EXPECT_NE(conversation_history[1]->edits->at(0)->created_time, conversation_history[1]->created_time); - - ASSERT_TRUE(conversation_history[1]->events); - ASSERT_EQ(conversation_history[1]->events->size(), 1u); - // Verify the original is left unchanged - ASSERT_TRUE(conversation_history[1]->events->at(0)->is_completion_event()); - EXPECT_EQ(conversation_history[1] - ->events->at(0) - ->get_completion_event() - ->completion, - "new answer"); - - ASSERT_TRUE(conversation_history[1]->edits->at(0)->events); - ASSERT_EQ(conversation_history[1]->edits->at(0)->events->size(), 1u); - ASSERT_TRUE(conversation_history[1] - ->edits->at(0) - ->events->at(0) - ->is_completion_event()); - EXPECT_EQ(conversation_history[1] - ->edits->at(0) - ->events->at(0) - ->get_completion_event() - ->completion, - "answer2"); } TEST_F(ConversationHandlerUnitTest, @@ -907,11 +925,20 @@ TEST_F(ConversationHandlerUnitTest, // Fetch with result should update the conversation history and call // OnConversationHistoryUpdate on observers. SetAssociatedContentStagedEntries(/*empty=*/false); + + // Shouldn't get any notification of real entries added + NiceMock observer; + observer.Observe(conversation_handler_.get()); + EXPECT_CALL(observer, OnConversationEntryAdded).Times(0); + // Client connecting will trigger content staging EXPECT_CALL(*associated_content_, GetStagedEntriesFromContent).Times(1); NiceMock client(conversation_handler_.get()); - EXPECT_CALL(client, OnConversationHistoryUpdate()).Times(1); EXPECT_TRUE(conversation_handler_->IsAnyClientConnected()); + + // History update notification once for each entry + EXPECT_CALL(client, OnConversationHistoryUpdate()).Times(2); + conversation_handler_->GetAssociatedContentInfo(base::BindLambdaForTesting( [&](mojom::SiteInfoPtr site_info, bool should_send_page_contents) { EXPECT_TRUE(should_send_page_contents); @@ -919,25 +946,28 @@ TEST_F(ConversationHandlerUnitTest, task_environment_.RunUntilIdle(); testing::Mock::VerifyAndClearExpectations(associated_content_.get()); + testing::Mock::VerifyAndClearExpectations(&observer); testing::Mock::VerifyAndClearExpectations(&client); auto& history = conversation_handler_->GetConversationHistory(); std::vector expected_history; expected_history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::HUMAN, mojom::ActionType::QUERY, + std::nullopt, mojom::CharacterType::HUMAN, mojom::ActionType::QUERY, mojom::ConversationTurnVisibility::VISIBLE, "query", std::nullopt, std::nullopt, base::Time::Now(), std::nullopt, true)); std::vector events; events.push_back(mojom::ConversationEntryEvent::NewCompletionEvent( mojom::CompletionEvent::New("summary"))); expected_history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE, - mojom::ConversationTurnVisibility::VISIBLE, "summary", std::nullopt, - std::move(events), base::Time::Now(), std::nullopt, true)); + std::nullopt, mojom::CharacterType::ASSISTANT, + mojom::ActionType::RESPONSE, mojom::ConversationTurnVisibility::VISIBLE, + "summary", std::nullopt, std::move(events), base::Time::Now(), + std::nullopt, true)); ASSERT_EQ(history.size(), expected_history.size()); for (size_t i = 0; i < history.size(); i++) { expected_history[i]->created_time = history[i]->created_time; - EXPECT_EQ(history[i], expected_history[i]); + ExpectConversationEntryEquals(FROM_HERE, history[i], expected_history[i], + false); } // HasAnyHistory should still return false since all entries are staged EXPECT_FALSE(conversation_handler_->HasAnyHistory()); @@ -964,9 +994,12 @@ TEST_F(ConversationHandlerUnitTest, // OnConversationHistoryUpdate on observers. SetAssociatedContentStagedEntries(/*empty=*/false, /*multi=*/true); // Client connecting will trigger content staging + testing::NiceMock observer; + observer.Observe(conversation_handler_.get()); + EXPECT_CALL(observer, OnConversationEntryAdded).Times(0); EXPECT_CALL(*associated_content_, GetStagedEntriesFromContent).Times(1); NiceMock client(conversation_handler_.get()); - EXPECT_CALL(client, OnConversationHistoryUpdate()).Times(1); + EXPECT_CALL(client, OnConversationHistoryUpdate()).Times(4); EXPECT_TRUE(conversation_handler_->IsAnyClientConnected()); conversation_handler_->GetAssociatedContentInfo(base::BindLambdaForTesting( [&](mojom::SiteInfoPtr site_info, bool should_send_page_contents) { @@ -975,56 +1008,72 @@ TEST_F(ConversationHandlerUnitTest, task_environment_.RunUntilIdle(); testing::Mock::VerifyAndClearExpectations(associated_content_.get()); + testing::Mock::VerifyAndClearExpectations(&observer); testing::Mock::VerifyAndClearExpectations(&client); auto& history = conversation_handler_->GetConversationHistory(); std::vector expected_history; expected_history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::HUMAN, mojom::ActionType::QUERY, + std::nullopt, mojom::CharacterType::HUMAN, mojom::ActionType::QUERY, mojom::ConversationTurnVisibility::VISIBLE, "query", std::nullopt, std::nullopt, base::Time::Now(), std::nullopt, true)); std::vector events; events.push_back(mojom::ConversationEntryEvent::NewCompletionEvent( mojom::CompletionEvent::New("summary"))); expected_history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE, - mojom::ConversationTurnVisibility::VISIBLE, "summary", std::nullopt, - std::move(events), base::Time::Now(), std::nullopt, true)); + std::nullopt, mojom::CharacterType::ASSISTANT, + mojom::ActionType::RESPONSE, mojom::ConversationTurnVisibility::VISIBLE, + "summary", std::nullopt, std::move(events), base::Time::Now(), + std::nullopt, true)); expected_history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::HUMAN, mojom::ActionType::QUERY, + std::nullopt, mojom::CharacterType::HUMAN, mojom::ActionType::QUERY, mojom::ConversationTurnVisibility::VISIBLE, "query2", std::nullopt, std::nullopt, base::Time::Now(), std::nullopt, true)); std::vector events2; events2.push_back(mojom::ConversationEntryEvent::NewCompletionEvent( mojom::CompletionEvent::New("summary2"))); expected_history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE, - mojom::ConversationTurnVisibility::VISIBLE, "summary2", std::nullopt, - std::move(events2), base::Time::Now(), std::nullopt, true)); + std::nullopt, mojom::CharacterType::ASSISTANT, + mojom::ActionType::RESPONSE, mojom::ConversationTurnVisibility::VISIBLE, + "summary2", std::nullopt, std::move(events2), base::Time::Now(), + std::nullopt, true)); ASSERT_EQ(history.size(), expected_history.size()); for (size_t i = 0; i < history.size(); i++) { expected_history[i]->created_time = history[i]->created_time; - EXPECT_EQ(history[i], expected_history[i]); + ExpectConversationEntryEquals(FROM_HERE, history[i], expected_history[i], + false); } // HasAnyHistory should still return false since all entries are staged EXPECT_FALSE(conversation_handler_->HasAnyHistory()); - // Verify turning off content association clears the conversation history. - EXPECT_CALL(client, OnConversationHistoryUpdate()).Times(1); - // Shouldn't ask for staged entries if user doesn't want to be associated - // with content. This verifies that even with existing staged entries, - // MaybeFetchOrClearContentStagedConversation will always early return. - EXPECT_CALL(*associated_content_, GetStagedEntriesFromContent).Times(0); + // Verify adding an actual conversation entry causes all entries to be + // notified and HasAnyHistory to return true. + // Modify an entry for the first time. + MockEngineConsumer* engine = static_cast( + conversation_handler_->GetEngineForTesting()); + EXPECT_CALL(*associated_content_, GetContent) + .WillOnce(base::test::RunOnceCallback<0>("page content", false, "")); + EXPECT_CALL(*engine, GenerateAssistantResponse) + // Mock the response from the engine + .WillOnce(::testing::DoAll( + base::test::RunOnceCallback<5>( + mojom::ConversationEntryEvent::NewCompletionEvent( + mojom::CompletionEvent::New("new answer"))), + base::test::RunOnceCallback<6>(base::ok("")))); - conversation_handler_->SetShouldSendPageContents(false); + EXPECT_CALL(observer, OnConversationEntryAdded).Times(6); + EXPECT_CALL(client, OnConversationHistoryUpdate()).Times(3); + + conversation_handler_->SubmitHumanConversationEntry("query3"); task_environment_.RunUntilIdle(); testing::Mock::VerifyAndClearExpectations(&client); + testing::Mock::VerifyAndClearExpectations(&observer); testing::Mock::VerifyAndClearExpectations(associated_content_.get()); - EXPECT_TRUE(conversation_handler_->GetConversationHistory().empty()); + EXPECT_TRUE(conversation_handler_->HasAnyHistory()); } TEST_F(ConversationHandlerUnitTest, @@ -1055,9 +1104,9 @@ TEST_F( // MaybeFetchOrClearContentStagedConversation should clear old staged entries // and fetch new ones. EXPECT_CALL(*associated_content_, GetStagedEntriesFromContent).Times(1); - // One from SetupHistory and one from removing old entries and adding + // 4 from SetupHistory and 4 from adding // new entries in OnGetStagedEntriesFromContent. - EXPECT_CALL(client, OnConversationHistoryUpdate()).Times(2); + EXPECT_CALL(client, OnConversationHistoryUpdate()).Times(8); // Fill history with staged and non-staged entries. SetupHistory({{"old query" /* text */, true /*from_brave_search_SERP */}, @@ -1112,7 +1161,7 @@ TEST_F(ConversationHandlerUnitTest, OnGetStagedEntriesFromContent) { NiceMock client(conversation_handler_.get()); ASSERT_TRUE(conversation_handler_->IsAnyClientConnected()); - EXPECT_CALL(client, OnConversationHistoryUpdate()).Times(2); + EXPECT_CALL(client, OnConversationHistoryUpdate()).Times(8); // Fill history with staged and non-staged entries. SetupHistory({{"q1" /* text */, true /*from_brave_search_SERP */}, {"s1", "true"}, @@ -1317,6 +1366,37 @@ TEST_F(ConversationHandlerUnitTest_NoAssociatedContent, GenerateQuestions) { testing::Mock::VerifyAndClearExpectations(engine); } +TEST_F(ConversationHandlerUnitTest_NoAssociatedContent, + GeneratesQuestionsByDefault) { + EXPECT_EQ(4u, conversation_handler_->GetSuggestedQuestionsForTest().size()); +} + +TEST_F(ConversationHandlerUnitTest_NoAssociatedContent, + SelectingDefaultQuestionSendsPrompt) { + conversation_handler_->SetSuggestedQuestionForTest("the thing", + "do the thing!"); + auto suggestions = conversation_handler_->GetSuggestedQuestionsForTest(); + EXPECT_EQ(1u, suggestions.size()); + + // Mock engine response + MockEngineConsumer* engine = static_cast( + conversation_handler_->GetEngineForTesting()); + + base::RunLoop loop; + // The prompt should be submitted to the engine, not the title. + EXPECT_CALL(*engine, + GenerateAssistantResponse(false, StrEq(""), _, "do the thing!", + StrEq(""), _, _)) + .WillOnce(testing::InvokeWithoutArgs(&loop, &base::RunLoop::Quit)); + + conversation_handler_->SubmitHumanConversationEntry("the thing"); + loop.Run(); + testing::Mock::VerifyAndClearExpectations(engine); + + // Suggestion should be removed + EXPECT_EQ(0u, conversation_handler_->GetSuggestedQuestionsForTest().size()); +} + TEST_F(ConversationHandlerUnitTest, SelectedLanguage) { MockEngineConsumer* engine = static_cast( conversation_handler_->GetEngineForTesting()); diff --git a/components/ai_chat/core/browser/engine/conversation_api_client.cc b/components/ai_chat/core/browser/engine/conversation_api_client.cc index babb0bb00698..39bb31e05995 100644 --- a/components/ai_chat/core/browser/engine/conversation_api_client.cc +++ b/components/ai_chat/core/browser/engine/conversation_api_client.cc @@ -5,36 +5,48 @@ #include "brave/components/ai_chat/core/browser/engine/conversation_api_client.h" -#include +#include #include -#include +#include #include +#include +#include #include #include +#include "base/check.h" #include "base/command_line.h" +#include "base/containers/checked_iterators.h" +#include "base/containers/flat_map.h" #include "base/functional/bind.h" #include "base/json/json_writer.h" +#include "base/logging.h" #include "base/memory/raw_ptr.h" +#include "base/memory/scoped_refptr.h" #include "base/memory/weak_ptr.h" +#include "base/metrics/field_trial_params.h" #include "base/no_destructor.h" #include "base/notreached.h" +#include "base/numerics/clamped_math.h" #include "base/strings/strcat.h" #include "base/strings/string_util.h" #include "base/types/expected.h" +#include "base/values.h" #include "brave/brave_domains/service_domains.h" +#include "brave/components/ai_chat/core/browser/ai_chat_credential_manager.h" #include "brave/components/ai_chat/core/common/buildflags/buildflags.h" #include "brave/components/ai_chat/core/common/features.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" #include "brave/components/brave_service_keys/brave_service_key_utils.h" #include "brave/components/constants/brave_services_key.h" #include "brave/components/l10n/common/locale_util.h" +#include "mojo/public/cpp/bindings/struct_ptr.h" #include "net/http/http_request_headers.h" #include "net/http/http_status_code.h" #include "net/traffic_annotation/network_traffic_annotation.h" #include "services/network/public/cpp/shared_url_loader_factory.h" -#include "services/network/public/cpp/simple_url_loader.h" #include "url/gurl.h" +#include "url/url_constants.h" namespace ai_chat { diff --git a/components/ai_chat/core/browser/engine/conversation_api_client.h b/components/ai_chat/core/browser/engine/conversation_api_client.h index 9fe7f545a9f2..aff93fb18927 100644 --- a/components/ai_chat/core/browser/engine/conversation_api_client.h +++ b/components/ai_chat/core/browser/engine/conversation_api_client.h @@ -7,13 +7,25 @@ #define BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_ENGINE_CONVERSATION_API_CLIENT_H_ #include +#include #include #include #include +#include "base/functional/callback.h" +#include "base/memory/raw_ptr.h" +#include "base/memory/scoped_refptr.h" +#include "base/memory/weak_ptr.h" +#include "base/types/expected.h" #include "brave/components/ai_chat/core/browser/ai_chat_credential_manager.h" #include "brave/components/ai_chat/core/browser/engine/engine_consumer.h" +#include "brave/components/ai_chat/core/browser/engine/remote_completion_client.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-forward.h" +#include "brave/components/api_request_helper/api_request_helper.h" + +namespace base { +class Value; +} // namespace base namespace api_request_helper { class APIRequestResult; @@ -24,6 +36,8 @@ class SharedURLLoaderFactory; } // namespace network namespace ai_chat { +class AIChatCredentialManager; +struct CredentialCacheEntry; // Performs remote request to the remote HTTP Brave Conversation API. class ConversationAPIClient { diff --git a/components/ai_chat/core/browser/engine/conversation_api_client_unittest.cc b/components/ai_chat/core/browser/engine/conversation_api_client_unittest.cc index ef112991b430..9f91122d238a 100644 --- a/components/ai_chat/core/browser/engine/conversation_api_client_unittest.cc +++ b/components/ai_chat/core/browser/engine/conversation_api_client_unittest.cc @@ -5,23 +5,26 @@ #include "brave/components/ai_chat/core/browser/engine/conversation_api_client.h" -#include +#include #include #include +#include +#include #include #include +#include "base/containers/flat_map.h" #include "base/functional/bind.h" #include "base/functional/callback_helpers.h" #include "base/json/json_reader.h" #include "base/json/json_writer.h" +#include "base/memory/scoped_refptr.h" +#include "base/numerics/clamped_math.h" #include "base/run_loop.h" -#include "base/strings/strcat.h" -#include "base/strings/string_util.h" -#include "base/test/bind.h" #include "base/test/task_environment.h" #include "base/time/time.h" #include "base/types/expected.h" +#include "base/values.h" #include "brave/components/ai_chat/core/browser/ai_chat_credential_manager.h" #include "brave/components/ai_chat/core/browser/engine/engine_consumer.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-forward.h" @@ -29,6 +32,7 @@ #include "brave/components/api_request_helper/api_request_helper.h" #include "brave/components/l10n/common/test/scoped_default_locale.h" #include "components/prefs/testing_pref_service.h" +#include "mojo/public/cpp/bindings/struct_ptr.h" #include "net/base/net_errors.h" #include "net/http/http_request_headers.h" #include "net/traffic_annotation/network_traffic_annotation.h" @@ -36,6 +40,19 @@ #include "services/network/public/cpp/shared_url_loader_factory.h" #include "testing/gmock/include/gmock/gmock.h" #include "testing/gtest/include/gtest/gtest.h" +#include "url/gurl.h" +#include "url/url_constants.h" + +class PrefService; +namespace mojo { +template +class PendingRemote; +} // namespace mojo +namespace skus { +namespace mojom { +class SkusService; +} // namespace mojom +} // namespace skus using ConversationHistory = std::vector; using ::testing::_; diff --git a/components/ai_chat/core/browser/engine/engine_consumer.cc b/components/ai_chat/core/browser/engine/engine_consumer.cc index 42a4dba03b2d..7c02ab881aee 100644 --- a/components/ai_chat/core/browser/engine/engine_consumer.cc +++ b/components/ai_chat/core/browser/engine/engine_consumer.cc @@ -5,6 +5,10 @@ #include "brave/components/ai_chat/core/browser/engine/engine_consumer.h" +#include + +#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" + namespace ai_chat { EngineConsumer::EngineConsumer() = default; diff --git a/components/ai_chat/core/browser/engine/engine_consumer.h b/components/ai_chat/core/browser/engine/engine_consumer.h index 1acfc739ac3f..874cd3cfa0db 100644 --- a/components/ai_chat/core/browser/engine/engine_consumer.h +++ b/components/ai_chat/core/browser/engine/engine_consumer.h @@ -6,17 +6,22 @@ #ifndef BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_ENGINE_ENGINE_CONSUMER_H_ #define BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_ENGINE_ENGINE_CONSUMER_H_ +#include #include #include #include #include +#include "base/functional/callback.h" #include "base/functional/callback_forward.h" #include "base/types/expected.h" #include "brave/components/ai_chat/core/browser/engine/remote_completion_client.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-forward.h" namespace ai_chat { +namespace mojom { +class ModelOptions; +} // namespace mojom // Abstract class for using AI completion engines to generate various specific // styles of completion. The engines could be local (invoked directly via a diff --git a/components/ai_chat/core/browser/engine/engine_consumer_claude.cc b/components/ai_chat/core/browser/engine/engine_consumer_claude.cc index 6688e9a10d17..10b60d3a0ab7 100644 --- a/components/ai_chat/core/browser/engine/engine_consumer_claude.cc +++ b/components/ai_chat/core/browser/engine/engine_consumer_claude.cc @@ -5,35 +5,43 @@ #include "brave/components/ai_chat/core/browser/engine/engine_consumer_claude.h" -#include +#include +#include #include +#include #include #include -#include +#include #include +#include "base/check.h" #include "base/containers/fixed_flat_set.h" #include "base/containers/flat_set.h" +#include "base/containers/span.h" #include "base/functional/bind.h" +#include "base/functional/callback.h" #include "base/i18n/time_formatting.h" -#include "base/memory/raw_ptr.h" +#include "base/logging.h" +#include "base/memory/scoped_refptr.h" #include "base/memory/weak_ptr.h" +#include "base/numerics/clamped_math.h" #include "base/strings/pattern.h" #include "base/strings/strcat.h" #include "base/strings/string_split.h" #include "base/strings/string_util.h" #include "base/strings/utf_string_conversions.h" +#include "base/time/time.h" #include "base/types/expected.h" #include "brave/components/ai_chat/core/browser/engine/engine_consumer.h" #include "brave/components/ai_chat/core/browser/engine/remote_completion_client.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-forward.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" #include "components/grit/brave_components_strings.h" -#include "net/http/http_status_code.h" #include "services/network/public/cpp/shared_url_loader_factory.h" #include "ui/base/l10n/l10n_util.h" namespace ai_chat { +class AIChatCredentialManager; namespace { diff --git a/components/ai_chat/core/browser/engine/engine_consumer_claude.h b/components/ai_chat/core/browser/engine/engine_consumer_claude.h index a70b6bd3678a..2fd4b2d98385 100644 --- a/components/ai_chat/core/browser/engine/engine_consumer_claude.h +++ b/components/ai_chat/core/browser/engine/engine_consumer_claude.h @@ -10,10 +10,15 @@ #include #include +#include "base/memory/weak_ptr.h" #include "brave/components/ai_chat/core/browser/ai_chat_credential_manager.h" #include "brave/components/ai_chat/core/browser/engine/engine_consumer.h" +#include "brave/components/ai_chat/core/browser/engine/remote_completion_client.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-forward.h" +template +class scoped_refptr; + namespace api_request_helper { class APIRequestResult; } // namespace api_request_helper @@ -23,6 +28,11 @@ class SharedURLLoaderFactory; } // namespace network namespace ai_chat { +class AIChatCredentialManager; +namespace mojom { +class LeoModelOptions; +class ModelOptions; +} // namespace mojom using api_request_helper::APIRequestResult; diff --git a/components/ai_chat/core/browser/engine/engine_consumer_claude_unittest.cc b/components/ai_chat/core/browser/engine/engine_consumer_claude_unittest.cc index 1e64160fa896..d08e3f8fc2ee 100644 --- a/components/ai_chat/core/browser/engine/engine_consumer_claude_unittest.cc +++ b/components/ai_chat/core/browser/engine/engine_consumer_claude_unittest.cc @@ -5,24 +5,27 @@ #include "brave/components/ai_chat/core/browser/engine/engine_consumer_claude.h" -#include #include #include -#include +#include +#include #include +#include "base/functional/bind.h" +#include "base/functional/callback.h" #include "base/functional/callback_helpers.h" +#include "base/numerics/clamped_math.h" #include "base/run_loop.h" #include "base/strings/string_util.h" #include "base/test/bind.h" #include "base/test/task_environment.h" #include "base/time/time.h" +#include "base/types/expected.h" #include "brave/components/ai_chat/core/browser/engine/engine_consumer.h" #include "brave/components/ai_chat/core/browser/engine/mock_remote_completion_client.h" #include "brave/components/ai_chat/core/browser/engine/test_utils.h" #include "brave/components/ai_chat/core/browser/model_service.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-forward.h" -#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-shared.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" #include "services/network/public/cpp/shared_url_loader_factory.h" #include "testing/gmock/include/gmock/gmock.h" @@ -71,14 +74,16 @@ class EngineConsumerClaudeUnitTest : public testing::Test { TEST_F(EngineConsumerClaudeUnitTest, TestGenerateAssistantResponse) { std::vector history; history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::HUMAN, mojom::ActionType::SUMMARIZE_SELECTED_TEXT, + std::nullopt, mojom::CharacterType::HUMAN, + mojom::ActionType::SUMMARIZE_SELECTED_TEXT, mojom::ConversationTurnVisibility::VISIBLE, "Which show is this catchphrase from?", "I have spoken.", std::nullopt, base::Time::Now(), std::nullopt, false)); history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE, - mojom::ConversationTurnVisibility::VISIBLE, "The Mandalorian.", - std::nullopt, std::nullopt, base::Time::Now(), std::nullopt, false)); + std::nullopt, mojom::CharacterType::ASSISTANT, + mojom::ActionType::RESPONSE, mojom::ConversationTurnVisibility::VISIBLE, + "The Mandalorian.", std::nullopt, std::nullopt, base::Time::Now(), + std::nullopt, false)); auto* mock_remote_completion_client = GetMockRemoteCompletionClient(); std::string prompt_before_time_and_date = "\n\nHuman: Here is the text of a web page in tags:\n\nThis " @@ -247,10 +252,10 @@ TEST_F(EngineConsumerClaudeUnitTest, GenerateAssistantResponseEarlyReturn) { testing::Mock::VerifyAndClearExpectations(mock_remote_completion_client); mojom::ConversationTurnPtr entry = mojom::ConversationTurn::New( - mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE, - mojom::ConversationTurnVisibility::VISIBLE, "", std::nullopt, - std::vector{}, base::Time::Now(), - std::nullopt, false); + std::nullopt, mojom::CharacterType::ASSISTANT, + mojom::ActionType::RESPONSE, mojom::ConversationTurnVisibility::VISIBLE, + "", std::nullopt, std::vector{}, + base::Time::Now(), std::nullopt, false); entry->events->push_back(mojom::ConversationEntryEvent::NewCompletionEvent( mojom::CompletionEvent::New("Me"))); history.push_back(std::move(entry)); diff --git a/components/ai_chat/core/browser/engine/engine_consumer_conversation_api.cc b/components/ai_chat/core/browser/engine/engine_consumer_conversation_api.cc index 9b1f1f5fee98..6ffd87394fa0 100644 --- a/components/ai_chat/core/browser/engine/engine_consumer_conversation_api.cc +++ b/components/ai_chat/core/browser/engine/engine_consumer_conversation_api.cc @@ -5,16 +5,21 @@ #include "brave/components/ai_chat/core/browser/engine/engine_consumer_conversation_api.h" +#include #include -#include +#include +#include #include +#include "base/check.h" #include "base/functional/bind.h" +#include "base/functional/callback.h" #include "base/functional/callback_helpers.h" +#include "base/memory/scoped_refptr.h" #include "base/memory/weak_ptr.h" -#include "base/notreached.h" +#include "base/numerics/clamped_math.h" #include "base/strings/string_split.h" -#include "base/strings/string_util.h" +#include "base/time/time.h" #include "base/types/expected.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" #include "services/network/public/cpp/shared_url_loader_factory.h" diff --git a/components/ai_chat/core/browser/engine/engine_consumer_conversation_api.h b/components/ai_chat/core/browser/engine/engine_consumer_conversation_api.h index 941299984f55..4f07c3806cea 100644 --- a/components/ai_chat/core/browser/engine/engine_consumer_conversation_api.h +++ b/components/ai_chat/core/browser/engine/engine_consumer_conversation_api.h @@ -11,10 +11,14 @@ #include #include +#include "base/memory/weak_ptr.h" #include "brave/components/ai_chat/core/browser/engine/conversation_api_client.h" #include "brave/components/ai_chat/core/browser/engine/engine_consumer.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" +template +class scoped_refptr; + namespace network { class SharedURLLoaderFactory; } // namespace network @@ -22,6 +26,10 @@ class SharedURLLoaderFactory; namespace ai_chat { class AIChatCredentialManager; +namespace mojom { +class LeoModelOptions; +class ModelOptions; +} // namespace mojom // An AI Chat engine consumer that uses the remote HTTP Brave Conversation API. // Converts between AI Chat's Conversation actions and data model diff --git a/components/ai_chat/core/browser/engine/engine_consumer_conversation_api_unittest.cc b/components/ai_chat/core/browser/engine/engine_consumer_conversation_api_unittest.cc index 6f73c5255d49..0871357a6d0f 100644 --- a/components/ai_chat/core/browser/engine/engine_consumer_conversation_api_unittest.cc +++ b/components/ai_chat/core/browser/engine/engine_consumer_conversation_api_unittest.cc @@ -5,25 +5,26 @@ #include "brave/components/ai_chat/core/browser/engine/engine_consumer_conversation_api.h" -#include #include #include -#include +#include +#include #include +#include "base/functional/callback.h" #include "base/functional/callback_helpers.h" #include "base/json/json_reader.h" #include "base/json/json_writer.h" +#include "base/memory/scoped_refptr.h" +#include "base/numerics/clamped_math.h" #include "base/run_loop.h" -#include "base/strings/strcat.h" -#include "base/strings/string_util.h" #include "base/test/bind.h" #include "base/test/task_environment.h" #include "base/time/time.h" +#include "base/types/expected.h" +#include "base/values.h" #include "brave/components/ai_chat/core/browser/engine/conversation_api_client.h" #include "brave/components/ai_chat/core/browser/engine/engine_consumer.h" -#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-forward.h" -#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-shared.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" #include "services/network/public/cpp/shared_url_loader_factory.h" #include "testing/gmock/include/gmock/gmock.h" @@ -34,7 +35,7 @@ using ::testing::_; namespace ai_chat { namespace { -const int kTestingMaxAssociatedContentLength = 100; +constexpr int kTestingMaxAssociatedContentLength = 100; } using ConversationEvent = ConversationAPIClient::ConversationEvent; @@ -214,16 +215,17 @@ TEST_F(EngineConsumerConversationAPIUnitTest, // selected text but with page association. EngineConsumer::ConversationHistory history; history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::HUMAN, mojom::ActionType::QUERY, + std::nullopt, mojom::CharacterType::HUMAN, mojom::ActionType::QUERY, mojom::ConversationTurnVisibility::VISIBLE, "Which show is this catchphrase from?", "I have spoken.", std::nullopt, base::Time::Now(), std::nullopt, false)); history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE, - mojom::ConversationTurnVisibility::VISIBLE, "The Mandalorian.", - std::nullopt, std::nullopt, base::Time::Now(), std::nullopt, false)); + std::nullopt, mojom::CharacterType::ASSISTANT, + mojom::ActionType::RESPONSE, mojom::ConversationTurnVisibility::VISIBLE, + "The Mandalorian.", std::nullopt, std::nullopt, base::Time::Now(), + std::nullopt, false)); history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::HUMAN, mojom::ActionType::RESPONSE, + std::nullopt, mojom::CharacterType::HUMAN, mojom::ActionType::RESPONSE, mojom::ConversationTurnVisibility::VISIBLE, "Is it related to a broader series?", std::nullopt, std::nullopt, base::Time::Now(), std::nullopt, false)); @@ -293,7 +295,7 @@ TEST_F(EngineConsumerConversationAPIUnitTest, GenerateEvents_ModifyReply) { // Tests events building from history with modified agent reply. EngineConsumer::ConversationHistory history; history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::HUMAN, mojom::ActionType::QUERY, + std::nullopt, mojom::CharacterType::HUMAN, mojom::ActionType::QUERY, mojom::ConversationTurnVisibility::VISIBLE, "Which show is 'This is the way' from?", std::nullopt, std::nullopt, base::Time::Now(), std::nullopt, false)); @@ -314,18 +316,19 @@ TEST_F(EngineConsumerConversationAPIUnitTest, GenerateEvents_ModifyReply) { modified_events.push_back(modified_completion_event.Clone()); auto edit = mojom::ConversationTurn::New( - mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE, - mojom::ConversationTurnVisibility::VISIBLE, "The Mandalorian.", - std::nullopt, std::move(modified_events), base::Time::Now(), std::nullopt, - false); + std::nullopt, mojom::CharacterType::ASSISTANT, + mojom::ActionType::RESPONSE, mojom::ConversationTurnVisibility::VISIBLE, + "The Mandalorian.", std::nullopt, std::move(modified_events), + base::Time::Now(), std::nullopt, false); std::vector edits; edits.push_back(std::move(edit)); history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE, - mojom::ConversationTurnVisibility::VISIBLE, "Mandalorian.", std::nullopt, - std::move(events), base::Time::Now(), std::move(edits), false)); + std::nullopt, mojom::CharacterType::ASSISTANT, + mojom::ActionType::RESPONSE, mojom::ConversationTurnVisibility::VISIBLE, + "Mandalorian.", std::nullopt, std::move(events), base::Time::Now(), + std::move(edits), false)); history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::HUMAN, mojom::ActionType::QUERY, + std::nullopt, mojom::CharacterType::HUMAN, mojom::ActionType::QUERY, mojom::ConversationTurnVisibility::VISIBLE, "Is it related to a broader series?", std::nullopt, std::nullopt, base::Time::Now(), std::nullopt, false)); @@ -380,10 +383,10 @@ TEST_F(EngineConsumerConversationAPIUnitTest, GenerateEvents_EarlyReturn) { testing::Mock::VerifyAndClearExpectations(mock_api_client); mojom::ConversationTurnPtr entry = mojom::ConversationTurn::New( - mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE, - mojom::ConversationTurnVisibility::VISIBLE, "", std::nullopt, - std::vector{}, base::Time::Now(), - std::nullopt, false); + std::nullopt, mojom::CharacterType::ASSISTANT, + mojom::ActionType::RESPONSE, mojom::ConversationTurnVisibility::VISIBLE, + "", std::nullopt, std::vector{}, + base::Time::Now(), std::nullopt, false); entry->events->push_back(mojom::ConversationEntryEvent::NewCompletionEvent( mojom::CompletionEvent::New("Me"))); history.push_back(std::move(entry)); diff --git a/components/ai_chat/core/browser/engine/engine_consumer_llama.cc b/components/ai_chat/core/browser/engine/engine_consumer_llama.cc index 5660d1e57031..f8e611918efd 100644 --- a/components/ai_chat/core/browser/engine/engine_consumer_llama.cc +++ b/components/ai_chat/core/browser/engine/engine_consumer_llama.cc @@ -5,31 +5,43 @@ #include "brave/components/ai_chat/core/browser/engine/engine_consumer_llama.h" -#include +#include + +#include +#include #include +#include #include #include -#include +#include #include +#include "base/check.h" #include "base/containers/fixed_flat_set.h" #include "base/containers/flat_set.h" #include "base/functional/bind.h" +#include "base/functional/callback.h" #include "base/i18n/time_formatting.h" -#include "base/memory/raw_ptr.h" +#include "base/logging.h" +#include "base/memory/scoped_refptr.h" #include "base/memory/weak_ptr.h" #include "base/strings/strcat.h" #include "base/strings/string_split.h" #include "base/strings/string_util.h" #include "base/strings/utf_string_conversions.h" #include "base/time/time.h" +#include "base/types/expected.h" #include "brave/components/ai_chat/core/browser/engine/engine_consumer.h" #include "brave/components/ai_chat/core/browser/engine/remote_completion_client.h" -#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" +#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-forward.h" #include "components/grit/brave_components_strings.h" #include "services/network/public/cpp/shared_url_loader_factory.h" #include "ui/base/l10n/l10n_util.h" +namespace ai_chat { +class AIChatCredentialManager; +} // namespace ai_chat + namespace { using ai_chat::mojom::ConversationTurnPtr; diff --git a/components/ai_chat/core/browser/engine/engine_consumer_llama.h b/components/ai_chat/core/browser/engine/engine_consumer_llama.h index fac6abdad35b..3ebfb81d4fd6 100644 --- a/components/ai_chat/core/browser/engine/engine_consumer_llama.h +++ b/components/ai_chat/core/browser/engine/engine_consumer_llama.h @@ -10,10 +10,15 @@ #include #include +#include "base/memory/weak_ptr.h" #include "brave/components/ai_chat/core/browser/ai_chat_credential_manager.h" #include "brave/components/ai_chat/core/browser/engine/engine_consumer.h" +#include "brave/components/ai_chat/core/browser/engine/remote_completion_client.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-forward.h" +template +class scoped_refptr; + namespace api_request_helper { class APIRequestResult; } // namespace api_request_helper @@ -23,6 +28,11 @@ class SharedURLLoaderFactory; } // namespace network namespace ai_chat { +class AIChatCredentialManager; +namespace mojom { +class LeoModelOptions; +class ModelOptions; +} // namespace mojom using api_request_helper::APIRequestResult; diff --git a/components/ai_chat/core/browser/engine/engine_consumer_llama_unittest.cc b/components/ai_chat/core/browser/engine/engine_consumer_llama_unittest.cc index f853582f295e..f6be05002b8b 100644 --- a/components/ai_chat/core/browser/engine/engine_consumer_llama_unittest.cc +++ b/components/ai_chat/core/browser/engine/engine_consumer_llama_unittest.cc @@ -5,18 +5,21 @@ #include "brave/components/ai_chat/core/browser/engine/engine_consumer_llama.h" -#include #include #include -#include +#include +#include #include +#include "base/functional/bind.h" +#include "base/functional/callback.h" #include "base/functional/callback_helpers.h" +#include "base/numerics/clamped_math.h" #include "base/run_loop.h" -#include "base/strings/string_util.h" #include "base/test/bind.h" #include "base/test/task_environment.h" #include "base/time/time.h" +#include "base/types/expected.h" #include "brave/components/ai_chat/core/browser/engine/engine_consumer.h" #include "brave/components/ai_chat/core/browser/engine/mock_remote_completion_client.h" #include "brave/components/ai_chat/core/browser/engine/test_utils.h" @@ -26,6 +29,7 @@ #include "services/network/public/cpp/shared_url_loader_factory.h" #include "testing/gmock/include/gmock/gmock.h" #include "testing/gtest/include/gtest/gtest.h" + using ::testing::_; using ::testing::Sequence; @@ -69,14 +73,16 @@ class EngineConsumerLlamaUnitTest : public testing::Test { TEST_F(EngineConsumerLlamaUnitTest, TestGenerateAssistantResponse) { EngineConsumer::ConversationHistory history; history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::HUMAN, mojom::ActionType::SUMMARIZE_SELECTED_TEXT, + std::nullopt, mojom::CharacterType::HUMAN, + mojom::ActionType::SUMMARIZE_SELECTED_TEXT, mojom::ConversationTurnVisibility::VISIBLE, "Which show is this catchphrase from?", "This is the way.", std::nullopt, base::Time::Now(), std::nullopt, false)); history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE, - mojom::ConversationTurnVisibility::VISIBLE, "The Mandalorian.", - std::nullopt, std::nullopt, base::Time::Now(), std::nullopt, false)); + std::nullopt, mojom::CharacterType::ASSISTANT, + mojom::ActionType::RESPONSE, mojom::ConversationTurnVisibility::VISIBLE, + "The Mandalorian.", std::nullopt, std::nullopt, base::Time::Now(), + std::nullopt, false)); auto* mock_remote_completion_client = static_cast(engine_->GetAPIForTesting()); std::string prompt_before_time_and_date = @@ -252,10 +258,10 @@ TEST_F(EngineConsumerLlamaUnitTest, GenerateAssistantResponseEarlyReturn) { testing::Mock::VerifyAndClearExpectations(mock_remote_completion_client); mojom::ConversationTurnPtr entry = mojom::ConversationTurn::New( - mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE, - mojom::ConversationTurnVisibility::VISIBLE, "", std::nullopt, - std::vector{}, base::Time::Now(), - std::nullopt, false); + std::nullopt, mojom::CharacterType::ASSISTANT, + mojom::ActionType::RESPONSE, mojom::ConversationTurnVisibility::VISIBLE, + "", std::nullopt, std::vector{}, + base::Time::Now(), std::nullopt, false); entry->events->push_back(mojom::ConversationEntryEvent::NewCompletionEvent( mojom::CompletionEvent::New("Me"))); history.push_back(std::move(entry)); diff --git a/components/ai_chat/core/browser/engine/engine_consumer_oai.cc b/components/ai_chat/core/browser/engine/engine_consumer_oai.cc index 317b2a92bde1..c22690550eeb 100644 --- a/components/ai_chat/core/browser/engine/engine_consumer_oai.cc +++ b/components/ai_chat/core/browser/engine/engine_consumer_oai.cc @@ -5,26 +5,26 @@ #include "brave/components/ai_chat/core/browser/engine/engine_consumer_oai.h" -#include #include #include #include -#include +#include #include -#include "base/containers/fixed_flat_set.h" #include "base/functional/bind.h" +#include "base/functional/callback.h" +#include "base/functional/callback_helpers.h" #include "base/i18n/time_formatting.h" +#include "base/memory/scoped_refptr.h" #include "base/memory/weak_ptr.h" #include "base/strings/strcat.h" #include "base/strings/string_tokenizer.h" #include "base/strings/string_util.h" #include "base/strings/utf_string_conversions.h" +#include "base/time/time.h" #include "base/types/expected.h" -#include "brave/components/ai_chat/core/browser/constants.h" +#include "base/values.h" #include "brave/components/ai_chat/core/browser/engine/engine_consumer.h" -#include "brave/components/ai_chat/core/browser/engine/remote_completion_client.h" -#include "brave/components/ai_chat/core/browser/utils.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-forward.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" #include "components/grit/brave_components_strings.h" diff --git a/components/ai_chat/core/browser/engine/engine_consumer_oai.h b/components/ai_chat/core/browser/engine/engine_consumer_oai.h index 1b305817c5c9..8586fd045996 100644 --- a/components/ai_chat/core/browser/engine/engine_consumer_oai.h +++ b/components/ai_chat/core/browser/engine/engine_consumer_oai.h @@ -10,10 +10,15 @@ #include #include +#include "base/memory/weak_ptr.h" #include "brave/components/ai_chat/core/browser/ai_chat_credential_manager.h" #include "brave/components/ai_chat/core/browser/engine/engine_consumer.h" #include "brave/components/ai_chat/core/browser/engine/oai_api_client.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-forward.h" +#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" + +template +class scoped_refptr; namespace api_request_helper { class APIRequestResult; diff --git a/components/ai_chat/core/browser/engine/engine_consumer_oai_unittest.cc b/components/ai_chat/core/browser/engine/engine_consumer_oai_unittest.cc index c2f3c2be5399..6d044412cc51 100644 --- a/components/ai_chat/core/browser/engine/engine_consumer_oai_unittest.cc +++ b/components/ai_chat/core/browser/engine/engine_consumer_oai_unittest.cc @@ -5,26 +5,34 @@ #include "brave/components/ai_chat/core/browser/engine/engine_consumer_oai.h" -#include +#include #include -#include +#include +#include #include +#include "base/containers/checked_iterators.h" +#include "base/functional/callback.h" #include "base/functional/callback_helpers.h" #include "base/i18n/time_formatting.h" +#include "base/memory/scoped_refptr.h" +#include "base/numerics/clamped_math.h" #include "base/run_loop.h" #include "base/strings/string_util.h" +#include "base/strings/utf_string_conversions.h" #include "base/test/bind.h" #include "base/test/task_environment.h" #include "base/time/time.h" +#include "base/values.h" #include "brave/components/ai_chat/core/browser/engine/engine_consumer.h" #include "brave/components/ai_chat/core/browser/engine/test_utils.h" -#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" +#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-forward.h" #include "components/grit/brave_components_strings.h" #include "services/network/public/cpp/shared_url_loader_factory.h" #include "testing/gmock/include/gmock/gmock.h" #include "testing/gtest/include/gtest/gtest.h" #include "ui/base/l10n/l10n_util.h" +#include "url/gurl.h" using ::testing::_; using ::testing::Sequence; @@ -238,6 +246,7 @@ TEST_F(EngineConsumerOAIUnitTest, // Push a single user turn into the history. history.push_back(mojom::ConversationTurn::New( + std::nullopt, mojom::CharacterType::HUMAN, // Author is the user mojom::ActionType::UNSPECIFIED, // No specific action mojom::ConversationTurnVisibility::VISIBLE, // Visible to the user @@ -298,14 +307,16 @@ TEST_F(EngineConsumerOAIUnitTest, std::string expected_system_message = "This is a custom system prompt."; history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::HUMAN, mojom::ActionType::SUMMARIZE_SELECTED_TEXT, + std::nullopt, mojom::CharacterType::HUMAN, + mojom::ActionType::SUMMARIZE_SELECTED_TEXT, mojom::ConversationTurnVisibility::VISIBLE, human_input, selected_text, std::nullopt, base::Time::Now(), std::nullopt, false)); history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE, - mojom::ConversationTurnVisibility::VISIBLE, assistant_input, std::nullopt, - std::nullopt, base::Time::Now(), std::nullopt, false)); + std::nullopt, mojom::CharacterType::ASSISTANT, + mojom::ActionType::RESPONSE, mojom::ConversationTurnVisibility::VISIBLE, + assistant_input, std::nullopt, std::nullopt, base::Time::Now(), + std::nullopt, false)); auto* client = GetClient(); auto run_loop = std::make_unique(); @@ -410,10 +421,10 @@ TEST_F(EngineConsumerOAIUnitTest, GenerateAssistantResponseEarlyReturn) { testing::Mock::VerifyAndClearExpectations(client); mojom::ConversationTurnPtr entry = mojom::ConversationTurn::New( - mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE, - mojom::ConversationTurnVisibility::VISIBLE, "", std::nullopt, - std::vector{}, base::Time::Now(), - std::nullopt, false); + std::nullopt, mojom::CharacterType::ASSISTANT, + mojom::ActionType::RESPONSE, mojom::ConversationTurnVisibility::VISIBLE, + "", std::nullopt, std::vector{}, + base::Time::Now(), std::nullopt, false); entry->events->push_back(mojom::ConversationEntryEvent::NewCompletionEvent( mojom::CompletionEvent::New("Me"))); history.push_back(std::move(entry)); diff --git a/components/ai_chat/core/browser/engine/mock_engine_consumer.h b/components/ai_chat/core/browser/engine/mock_engine_consumer.h index 245d92711f4a..747bb90def33 100644 --- a/components/ai_chat/core/browser/engine/mock_engine_consumer.h +++ b/components/ai_chat/core/browser/engine/mock_engine_consumer.h @@ -6,10 +6,15 @@ #ifndef BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_ENGINE_MOCK_ENGINE_CONSUMER_H_ #define BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_ENGINE_MOCK_ENGINE_CONSUMER_H_ +#include + #include "brave/components/ai_chat/core/browser/engine/engine_consumer.h" #include "testing/gmock/include/gmock/gmock.h" namespace ai_chat { +namespace mojom { +class ModelOptions; +} // namespace mojom class MockEngineConsumer : public EngineConsumer { public: diff --git a/components/ai_chat/core/browser/engine/mock_remote_completion_client.cc b/components/ai_chat/core/browser/engine/mock_remote_completion_client.cc index 0726193d04f2..bffe902c5cd4 100644 --- a/components/ai_chat/core/browser/engine/mock_remote_completion_client.cc +++ b/components/ai_chat/core/browser/engine/mock_remote_completion_client.cc @@ -4,7 +4,10 @@ * You can obtain one at https://mozilla.org/MPL/2.0/. */ #include "brave/components/ai_chat/core/browser/engine/mock_remote_completion_client.h" -#include + +#include + +#include "base/memory/scoped_refptr.h" #include "services/network/public/cpp/shared_url_loader_factory.h" namespace ai_chat { diff --git a/components/ai_chat/core/browser/engine/oai_api_client.cc b/components/ai_chat/core/browser/engine/oai_api_client.cc index 6faabe568fbb..143e68327f50 100644 --- a/components/ai_chat/core/browser/engine/oai_api_client.cc +++ b/components/ai_chat/core/browser/engine/oai_api_client.cc @@ -5,20 +5,28 @@ #include "brave/components/ai_chat/core/browser/engine/oai_api_client.h" +#include +#include +#include +#include +#include +#include + +#include "base/containers/flat_map.h" #include "base/functional/bind.h" #include "base/json/json_writer.h" +#include "base/logging.h" +#include "base/memory/scoped_refptr.h" #include "base/memory/weak_ptr.h" +#include "base/metrics/field_trial_params.h" #include "base/strings/strcat.h" -#include "base/strings/string_util.h" #include "base/types/expected.h" #include "brave/components/ai_chat/core/common/features.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" #include "brave/components/constants/brave_services_key.h" #include "net/http/http_request_headers.h" -#include "net/http/http_status_code.h" #include "net/traffic_annotation/network_traffic_annotation.h" #include "services/network/public/cpp/shared_url_loader_factory.h" -#include "services/network/public/cpp/simple_url_loader.h" #include "url/gurl.h" namespace ai_chat { @@ -171,6 +179,7 @@ void OAIAPIClient::OnQueryDataReceived( if (choices->front().is_dict()) { const base::Value::Dict* delta = choices->front().GetDict().FindDict("delta"); + const std::string* content = delta->FindString("content"); if (content) { diff --git a/components/ai_chat/core/browser/engine/oai_api_client.h b/components/ai_chat/core/browser/engine/oai_api_client.h index 0a55ec671842..def1b5d726ad 100644 --- a/components/ai_chat/core/browser/engine/oai_api_client.h +++ b/components/ai_chat/core/browser/engine/oai_api_client.h @@ -10,8 +10,15 @@ #include #include +#include "base/functional/callback.h" +#include "base/memory/scoped_refptr.h" +#include "base/memory/weak_ptr.h" +#include "base/types/expected.h" +#include "base/values.h" #include "brave/components/ai_chat/core/browser/engine/engine_consumer.h" +#include "brave/components/ai_chat/core/browser/engine/remote_completion_client.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-forward.h" +#include "brave/components/api_request_helper/api_request_helper.h" namespace api_request_helper { class APIRequestResult; @@ -22,6 +29,9 @@ class SharedURLLoaderFactory; } // namespace network namespace ai_chat { +namespace mojom { +class CustomModelOptions; +} // namespace mojom // Performs remote request to the OAI format APIs. class OAIAPIClient { diff --git a/components/ai_chat/core/browser/engine/oai_api_client_unittest.cc b/components/ai_chat/core/browser/engine/oai_api_client_unittest.cc index 8e873d83edce..7c5db050bcf9 100644 --- a/components/ai_chat/core/browser/engine/oai_api_client_unittest.cc +++ b/components/ai_chat/core/browser/engine/oai_api_client_unittest.cc @@ -5,23 +5,25 @@ #include "brave/components/ai_chat/core/browser/engine/oai_api_client.h" -#include +#include #include #include -#include +#include +#include #include +#include "base/containers/flat_map.h" #include "base/functional/bind.h" -#include "base/functional/callback_helpers.h" #include "base/json/json_reader.h" #include "base/json/json_writer.h" +#include "base/memory/scoped_refptr.h" #include "base/run_loop.h" -#include "base/strings/string_util.h" #include "base/test/task_environment.h" #include "base/types/expected.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-forward.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" #include "brave/components/api_request_helper/api_request_helper.h" +#include "mojo/public/cpp/bindings/struct_ptr.h" #include "net/base/net_errors.h" #include "net/http/http_request_headers.h" #include "net/traffic_annotation/network_traffic_annotation.h" @@ -29,6 +31,7 @@ #include "services/network/public/cpp/shared_url_loader_factory.h" #include "testing/gmock/include/gmock/gmock.h" #include "testing/gtest/include/gtest/gtest.h" +#include "url/gurl.h" using ConversationHistory = std::vector; using ::testing::_; diff --git a/components/ai_chat/core/browser/engine/remote_completion_client.cc b/components/ai_chat/core/browser/engine/remote_completion_client.cc index 1161c9e6ca5d..7191c7360e64 100644 --- a/components/ai_chat/core/browser/engine/remote_completion_client.cc +++ b/components/ai_chat/core/browser/engine/remote_completion_client.cc @@ -7,19 +7,26 @@ #include +#include #include +#include #include +#include #include +#include "base/check.h" #include "base/containers/flat_set.h" #include "base/functional/bind.h" -#include "base/functional/callback_helpers.h" #include "base/json/json_writer.h" -#include "base/no_destructor.h" +#include "base/logging.h" +#include "base/memory/scoped_refptr.h" +#include "base/metrics/field_trial_params.h" +#include "base/numerics/clamped_math.h" #include "base/strings/strcat.h" +#include "base/strings/string_util.h" #include "base/values.h" #include "brave/brave_domains/service_domains.h" -#include "brave/components/ai_chat/core/browser/constants.h" +#include "brave/components/ai_chat/core/browser/ai_chat_credential_manager.h" #include "brave/components/ai_chat/core/common/buildflags/buildflags.h" #include "brave/components/ai_chat/core/common/features.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" @@ -28,8 +35,8 @@ #include "net/http/http_status_code.h" #include "net/traffic_annotation/network_traffic_annotation.h" #include "services/network/public/cpp/shared_url_loader_factory.h" -#include "services/network/public/cpp/simple_url_loader.h" #include "url/gurl.h" +#include "url/url_constants.h" namespace ai_chat { namespace { diff --git a/components/ai_chat/core/browser/engine/remote_completion_client.h b/components/ai_chat/core/browser/engine/remote_completion_client.h index 9f150423fc6a..876babd910f7 100644 --- a/components/ai_chat/core/browser/engine/remote_completion_client.h +++ b/components/ai_chat/core/browser/engine/remote_completion_client.h @@ -13,18 +13,29 @@ #include #include "base/containers/flat_set.h" +#include "base/functional/callback.h" #include "base/functional/callback_forward.h" +#include "base/functional/callback_helpers.h" +#include "base/memory/raw_ptr.h" #include "base/memory/weak_ptr.h" #include "base/types/expected.h" #include "brave/components/ai_chat/core/browser/ai_chat_credential_manager.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" #include "brave/components/api_request_helper/api_request_helper.h" +namespace base { +class Value; +} // namespace base +template +class scoped_refptr; + namespace network { class SharedURLLoaderFactory; } // namespace network namespace ai_chat { +class AIChatCredentialManager; +struct CredentialCacheEntry; using api_request_helper::APIRequestResult; diff --git a/components/ai_chat/core/browser/engine/test_utils.cc b/components/ai_chat/core/browser/engine/test_utils.cc index a1ace67692b0..58c15e6b0c26 100644 --- a/components/ai_chat/core/browser/engine/test_utils.cc +++ b/components/ai_chat/core/browser/engine/test_utils.cc @@ -8,15 +8,15 @@ #include #include +#include "base/numerics/clamped_math.h" #include "base/time/time.h" -#include "testing/gtest/include/gtest/gtest.h" namespace ai_chat { std::vector GetHistoryWithModifiedReply() { std::vector history; history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::HUMAN, mojom::ActionType::QUERY, + std::nullopt, mojom::CharacterType::HUMAN, mojom::ActionType::QUERY, mojom::ConversationTurnVisibility::VISIBLE, "Which show is 'This is the way' from?", std::nullopt, std::nullopt, base::Time::Now(), std::nullopt, false)); @@ -34,18 +34,19 @@ std::vector GetHistoryWithModifiedReply() { mojom::CompletionEvent::New("The Mandalorian"))); auto edit = mojom::ConversationTurn::New( - mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE, - mojom::ConversationTurnVisibility::VISIBLE, "The Mandalorian.", - std::nullopt, std::move(modified_events), base::Time::Now(), std::nullopt, - false); + std::nullopt, mojom::CharacterType::ASSISTANT, + mojom::ActionType::RESPONSE, mojom::ConversationTurnVisibility::VISIBLE, + "The Mandalorian.", std::nullopt, std::move(modified_events), + base::Time::Now(), std::nullopt, false); std::vector edits; edits.push_back(std::move(edit)); history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE, - mojom::ConversationTurnVisibility::VISIBLE, "Mandalorian.", std::nullopt, - std::move(events), base::Time::Now(), std::move(edits), false)); + std::nullopt, mojom::CharacterType::ASSISTANT, + mojom::ActionType::RESPONSE, mojom::ConversationTurnVisibility::VISIBLE, + "Mandalorian.", std::nullopt, std::move(events), base::Time::Now(), + std::move(edits), false)); history.push_back(mojom::ConversationTurn::New( - mojom::CharacterType::HUMAN, mojom::ActionType::QUERY, + std::nullopt, mojom::CharacterType::HUMAN, mojom::ActionType::QUERY, mojom::ConversationTurnVisibility::VISIBLE, "Is it related to a broader series?", std::nullopt, std::nullopt, base::Time::Now(), std::nullopt, false)); diff --git a/components/ai_chat/core/browser/local_models_updater.cc b/components/ai_chat/core/browser/local_models_updater.cc index 32695fb3cb3f..3addd8616297 100644 --- a/components/ai_chat/core/browser/local_models_updater.cc +++ b/components/ai_chat/core/browser/local_models_updater.cc @@ -5,19 +5,30 @@ #include "brave/components/ai_chat/core/browser/local_models_updater.h" +#include #include +#include +#include #include #include "base/check_is_test.h" +#include "base/compiler_specific.h" #include "base/files/file_path.h" #include "base/files/file_util.h" +#include "base/functional/bind.h" +#include "base/memory/scoped_refptr.h" #include "base/no_destructor.h" #include "base/path_service.h" #include "brave/components/ai_chat/core/common/features.h" #include "brave/components/brave_component_updater/browser/brave_on_demand_updater.h" #include "components/component_updater/component_updater_paths.h" +#include "components/update_client/update_client_errors.h" #include "crypto/sha2.h" +namespace base { +class Version; +} // namespace base + namespace ai_chat { namespace { diff --git a/components/ai_chat/core/browser/local_models_updater.h b/components/ai_chat/core/browser/local_models_updater.h index 90a1298bb957..c281e8231e2f 100644 --- a/components/ai_chat/core/browser/local_models_updater.h +++ b/components/ai_chat/core/browser/local_models_updater.h @@ -6,11 +6,21 @@ #ifndef BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_LOCAL_MODELS_UPDATER_H_ #define BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_LOCAL_MODELS_UPDATER_H_ +#include #include #include +#include "base/files/file_path.h" #include "base/no_destructor.h" +#include "base/values.h" #include "components/component_updater/component_installer.h" +#include "components/update_client/update_client.h" + +namespace base { +class Version; +template +class NoDestructor; +} // namespace base namespace component_updater { class ComponentUpdateService; diff --git a/components/ai_chat/core/browser/local_models_updater_unittest.cc b/components/ai_chat/core/browser/local_models_updater_unittest.cc index 182035e95c6d..a915db0c6265 100644 --- a/components/ai_chat/core/browser/local_models_updater_unittest.cc +++ b/components/ai_chat/core/browser/local_models_updater_unittest.cc @@ -6,8 +6,10 @@ #include "brave/components/ai_chat/core/browser/local_models_updater.h" #include -#include +#include +#include +#include "base/feature_list.h" #include "base/files/file_path.h" #include "base/files/file_util.h" #include "base/path_service.h" diff --git a/components/ai_chat/core/browser/mock_conversation_handler_observer.cc b/components/ai_chat/core/browser/mock_conversation_handler_observer.cc new file mode 100644 index 000000000000..2b0577906eea --- /dev/null +++ b/components/ai_chat/core/browser/mock_conversation_handler_observer.cc @@ -0,0 +1,18 @@ +// Copyright (c) 2024 The Brave Authors. All rights reserved. +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. + +#include "brave/components/ai_chat/core/browser/mock_conversation_handler_observer.h" + +namespace ai_chat { + +MockConversationHandlerObserver::MockConversationHandlerObserver() = default; +MockConversationHandlerObserver::~MockConversationHandlerObserver() = default; + +void MockConversationHandlerObserver::Observe( + ConversationHandler* conversation) { + conversation_observations_.AddObservation(conversation); +} + +} // namespace ai_chat diff --git a/components/ai_chat/core/browser/mock_conversation_handler_observer.h b/components/ai_chat/core/browser/mock_conversation_handler_observer.h new file mode 100644 index 000000000000..b944d74382f3 --- /dev/null +++ b/components/ai_chat/core/browser/mock_conversation_handler_observer.h @@ -0,0 +1,66 @@ +// Copyright (c) 2024 The Brave Authors. All rights reserved. +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. + +#ifndef BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_MOCK_CONVERSATION_HANDLER_OBSERVER_H_ +#define BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_MOCK_CONVERSATION_HANDLER_OBSERVER_H_ + +#include +#include + +#include "base/scoped_multi_source_observation.h" +#include "brave/components/ai_chat/core/browser/conversation_handler.h" +#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" +#include "testing/gmock/include/gmock/gmock.h" + +namespace ai_chat { + +class MockConversationHandlerObserver : public ConversationHandler::Observer { + public: + MockConversationHandlerObserver(); + ~MockConversationHandlerObserver() override; + + void Observe(ConversationHandler* conversation); + + MOCK_METHOD(void, + OnRequestInProgressChanged, + (ConversationHandler * handler, bool in_progress), + (override)); + + MOCK_METHOD(void, + OnConversationEntryAdded, + (ConversationHandler * handler, + mojom::ConversationTurnPtr& entry, + std::optional associated_content_value), + (override)); + + MOCK_METHOD(void, + OnConversationEntryRemoved, + (ConversationHandler * handler, std::string turn_uuid), + (override)); + + MOCK_METHOD(void, + OnConversationEntryUpdated, + (ConversationHandler * handler, mojom::ConversationTurnPtr entry), + (override)); + + MOCK_METHOD(void, + OnClientConnectionChanged, + (ConversationHandler*), + (override)); + + MOCK_METHOD(void, + OnConversationTitleChanged, + (ConversationHandler * handler, std::string title), + (override)); + + private: + base::ScopedMultiSourceObservation + conversation_observations_{this}; +}; + +} // namespace ai_chat + +#endif // BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_MOCK_CONVERSATION_HANDLER_OBSERVER_H_ diff --git a/components/ai_chat/core/browser/model_service.cc b/components/ai_chat/core/browser/model_service.cc index 26ebdfbeb2ab..580e10c083b4 100644 --- a/components/ai_chat/core/browser/model_service.cc +++ b/components/ai_chat/core/browser/model_service.cc @@ -6,16 +6,30 @@ #include "brave/components/ai_chat/core/browser/model_service.h" #include +#include +#include #include +#include +#include +#include +#include #include #include "base/base64.h" +#include "base/check.h" +#include "base/containers/checked_iterators.h" #include "base/containers/contains.h" +#include "base/logging.h" +#include "base/memory/scoped_refptr.h" +#include "base/metrics/field_trial_params.h" #include "base/no_destructor.h" +#include "base/numerics/safe_math.h" #include "base/strings/strcat.h" +#include "base/strings/string_util.h" #include "base/uuid.h" #include "base/values.h" #include "brave/components/ai_chat/core/browser/constants.h" +#include "brave/components/ai_chat/core/browser/engine/engine_consumer.h" #include "brave/components/ai_chat/core/browser/engine/engine_consumer_claude.h" #include "brave/components/ai_chat/core/browser/engine/engine_consumer_conversation_api.h" #include "brave/components/ai_chat/core/browser/engine/engine_consumer_llama.h" @@ -26,11 +40,15 @@ #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-shared.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" #include "brave/components/ai_chat/core/common/pref_names.h" +#include "cc/task/core/task_utils.h" #include "components/os_crypt/sync/os_crypt.h" +#include "components/prefs/pref_registry_simple.h" #include "components/prefs/pref_service.h" #include "services/network/public/cpp/shared_url_loader_factory.h" +#include "url/gurl.h" namespace ai_chat { +class AIChatCredentialManager; namespace { constexpr char kDefaultModelKey[] = "brave.ai_chat.default_model_key"; diff --git a/components/ai_chat/core/browser/model_service.h b/components/ai_chat/core/browser/model_service.h index f791c9e123ea..d5637476b266 100644 --- a/components/ai_chat/core/browser/model_service.h +++ b/components/ai_chat/core/browser/model_service.h @@ -6,24 +6,34 @@ #ifndef BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_MODEL_SERVICE_H_ #define BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_MODEL_SERVICE_H_ +#include + +#include #include #include #include #include +#include "base/memory/raw_ptr.h" +#include "base/memory/scoped_refptr.h" +#include "base/memory/weak_ptr.h" #include "base/observer_list.h" +#include "base/observer_list_types.h" #include "brave/components/ai_chat/core/browser/ai_chat_credential_manager.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-forward.h" -#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-shared.h" #include "components/keyed_service/core/keyed_service.h" #include "components/prefs/pref_registry_simple.h" +class PrefRegistrySimple; +class PrefService; + namespace network { class SharedURLLoaderFactory; } namespace ai_chat { class EngineConsumer; +class AIChatCredentialManager; class ModelService : public KeyedService { public: diff --git a/components/ai_chat/core/browser/model_service_unittest.cc b/components/ai_chat/core/browser/model_service_unittest.cc index 09ced41f80d0..4e8f6384478b 100644 --- a/components/ai_chat/core/browser/model_service_unittest.cc +++ b/components/ai_chat/core/browser/model_service_unittest.cc @@ -5,21 +5,24 @@ #include "brave/components/ai_chat/core/browser/model_service.h" -#include #include +#include #include +#include "base/metrics/field_trial_params.h" +#include "base/numerics/safe_math.h" #include "base/scoped_observation.h" #include "base/test/scoped_feature_list.h" #include "brave/components/ai_chat/core/browser/constants.h" #include "brave/components/ai_chat/core/browser/model_validator.h" #include "brave/components/ai_chat/core/common/features.h" -#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" +#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-shared.h" #include "brave/components/ai_chat/core/common/pref_names.h" #include "components/os_crypt/sync/os_crypt_mocker.h" #include "components/prefs/testing_pref_service.h" #include "testing/gmock/include/gmock/gmock.h" #include "testing/gtest/include/gtest/gtest.h" +#include "url/gurl.h" namespace ai_chat { diff --git a/components/ai_chat/core/browser/model_validator.cc b/components/ai_chat/core/browser/model_validator.cc index 69bfac162c41..d4829968f481 100644 --- a/components/ai_chat/core/browser/model_validator.cc +++ b/components/ai_chat/core/browser/model_validator.cc @@ -5,10 +5,12 @@ #include "brave/components/ai_chat/core/browser/model_validator.h" -#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-forward.h" +#include "base/numerics/safe_math.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" #include "brave/net/base/url_util.h" +class GURL; + namespace ai_chat { // Static diff --git a/components/ai_chat/core/browser/model_validator.h b/components/ai_chat/core/browser/model_validator.h index 06d3a29a10aa..1e8be56e48d3 100644 --- a/components/ai_chat/core/browser/model_validator.h +++ b/components/ai_chat/core/browser/model_validator.h @@ -6,6 +6,9 @@ #ifndef BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_MODEL_VALIDATOR_H_ #define BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_MODEL_VALIDATOR_H_ +#include + +#include #include #include "brave/components/ai_chat/core/browser/constants.h" @@ -13,14 +16,19 @@ #include "brave/net/base/url_util.h" #include "url/url_constants.h" +class GURL; + namespace ai_chat { +namespace mojom { +class CustomModelOptions; +} // namespace mojom // The declared context size needs to be large enough to accommodate expected // reserves (i.e., prompt tokens and max new tokens) -constexpr size_t kMinCustomModelContextSize = +inline constexpr size_t kMinCustomModelContextSize = kReservedTokensForMaxNewTokens + kReservedTokensForPrompt; -constexpr size_t kMaxCustomModelContextSize = 2'000'000; -constexpr size_t kDefaultCustomModelContextSize = 4000; +inline constexpr size_t kMaxCustomModelContextSize = 2'000'000; +inline constexpr size_t kDefaultCustomModelContextSize = 4000; enum class ModelValidationResult { kSuccess, diff --git a/components/ai_chat/core/browser/model_validator_unittest.cc b/components/ai_chat/core/browser/model_validator_unittest.cc index 259457d2d88b..62ff294aff1f 100644 --- a/components/ai_chat/core/browser/model_validator_unittest.cc +++ b/components/ai_chat/core/browser/model_validator_unittest.cc @@ -5,13 +5,16 @@ #include "brave/components/ai_chat/core/browser/model_validator.h" -#include #include #include +#include #include +#include "base/numerics/checked_math.h" +#include "base/strings/string_number_conversions.h" +#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-forward.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" -#include "testing/gmock/include/gmock/gmock.h" +#include "mojo/public/cpp/bindings/struct_ptr.h" #include "testing/gtest/include/gtest/gtest.h" #include "url/gurl.h" diff --git a/components/ai_chat/core/browser/test_utils.cc b/components/ai_chat/core/browser/test_utils.cc new file mode 100644 index 000000000000..e37d712597ef --- /dev/null +++ b/components/ai_chat/core/browser/test_utils.cc @@ -0,0 +1,228 @@ +// Copyright (c) 2024 The Brave Authors. All rights reserved. +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. + +#include "brave/components/ai_chat/core/browser/test_utils.h" + +#include +#include + +#include "base/strings/strcat.h" +#include "base/time/time.h" +#include "base/uuid.h" +#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-forward.h" +#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace ai_chat { + +namespace { + +std::string MessageConversationEntryEvents( + const mojom::ConversationTurnPtr& entry) { + std::string message = "Entry has the following events:"; + if (!entry->events.has_value()) { + message = base::StrCat({message, "\nNo events"}); + return message; + } + for (const auto& event : entry->events.value()) { + switch (event->which()) { + case mojom::ConversationEntryEvent::Tag::kCompletionEvent: { + message = base::StrCat({message, "\n - completion: ", + event->get_completion_event()->completion}); + break; + } + case mojom::ConversationEntryEvent::Tag::kSearchQueriesEvent: { + message = base::StrCat({message, "\n - search event"}); + break; + } + case mojom::ConversationEntryEvent::Tag::kConversationTitleEvent: { + message = base::StrCat({message, "\n - title: ", + event->get_conversation_title_event()->title}); + break; + } + case mojom::ConversationEntryEvent::Tag::kPageContentRefineEvent: { + message = base::StrCat({message, "\n - content refine event"}); + break; + } + default: + message = base::StrCat({message, "\n - unknown event"}); + } + } + return message; +} + +} // namespace + +void ExpectConversationEquals(base::Location location, + const mojom::ConversationPtr& a, + const mojom::ConversationPtr& b) { + SCOPED_TRACE(testing::Message() << location.ToString()); + if (!a || !b) { + EXPECT_EQ(a, b); // Both should be null or neither + return; + } + EXPECT_EQ(a->uuid, b->uuid); + EXPECT_EQ(a->title, b->title); + EXPECT_EQ(a->has_content, b->has_content); + + // associated content + ExpectAssociatedContentEquals(FROM_HERE, a->associated_content, + b->associated_content); +} + +void ExpectAssociatedContentEquals(base::Location location, + const mojom::SiteInfoPtr& a, + const mojom::SiteInfoPtr& b) { + SCOPED_TRACE(testing::Message() << location.ToString()); + if (!a || !b) { + EXPECT_EQ(a, b); // Both should be null or neither + return; + } + EXPECT_EQ(a->uuid, b->uuid); + EXPECT_EQ(a->title, b->title); + EXPECT_EQ(a->url, b->url); + EXPECT_EQ(a->content_type, b->content_type); + EXPECT_EQ(a->content_used_percentage, b->content_used_percentage); + EXPECT_EQ(a->is_content_refined, b->is_content_refined); + EXPECT_EQ(a->is_content_association_possible, + b->is_content_association_possible); +} + +void ExpectConversationHistoryEquals( + base::Location location, + const std::vector& a, + const std::vector& b, + bool compare_uuid) { + SCOPED_TRACE(testing::Message() << location.ToString()); + EXPECT_EQ(a.size(), b.size()); + for (auto i = 0u; i < a.size(); i++) { + SCOPED_TRACE(testing::Message() << "Comparing entries at index " << i); + ExpectConversationEntryEquals(FROM_HERE, a.at(i), b.at(i), compare_uuid); + } +} + +void ExpectConversationEntryEquals(base::Location location, + const mojom::ConversationTurnPtr& a, + const mojom::ConversationTurnPtr& b, + bool compare_uuid) { + SCOPED_TRACE(testing::Message() << location.ToString()); + if (!a || !b) { + EXPECT_EQ(a, b); // Both should be null or neither + return; + } + + if (compare_uuid) { + EXPECT_EQ(a->uuid.value_or("default"), b->uuid.value_or("default")); + } + + EXPECT_EQ(a->action_type, b->action_type); + EXPECT_EQ(a->character_type, b->character_type); + EXPECT_EQ(a->selected_text, b->selected_text); + EXPECT_EQ(a->text, b->text); + EXPECT_EQ(a->visibility, b->visibility); + + // compare events + EXPECT_EQ(a->events.has_value(), b->events.has_value()); + if (a->events.has_value()) { + EXPECT_EQ(a->events->size(), b->events->size()) + << "\nEvents for a. " << MessageConversationEntryEvents(a) + << "\nEvents for b. " << MessageConversationEntryEvents(b); + for (auto i = 0u; i < a->events->size(); i++) { + SCOPED_TRACE(testing::Message() << "Comparing events at index " << i); + auto& a_event = a->events->at(i); + auto& b_event = b->events->at(i); + EXPECT_EQ(a_event->which(), b_event->which()); + switch (a_event->which()) { + case mojom::ConversationEntryEvent::Tag::kCompletionEvent: { + EXPECT_EQ(a_event->get_completion_event()->completion, + b_event->get_completion_event()->completion); + break; + } + case mojom::ConversationEntryEvent::Tag::kSearchQueriesEvent: { + EXPECT_EQ(a_event->get_search_queries_event()->search_queries, + b_event->get_search_queries_event()->search_queries); + break; + } + default: + NOTREACHED() + << "Unexpected event type for comparison. Only know about " + "event types which are not discarded."; + } + } + } + + // compare edits + EXPECT_EQ(a->edits.has_value(), b->edits.has_value()); + if (a->edits.has_value()) { + EXPECT_EQ(a->edits->size(), b->edits->size()); + for (auto i = 0u; i < a->edits->size(); i++) { + SCOPED_TRACE(testing::Message() << "Comparing edits at index " << i); + auto& a_edit = a->edits->at(i); + auto& b_edit = b->edits->at(i); + ExpectConversationEntryEquals(FROM_HERE, a_edit, b_edit, compare_uuid); + } + } +} + +mojom::Conversation* GetConversation( + base::Location location, + const std::vector& conversations, + std::string uuid) { + SCOPED_TRACE(testing::Message() << location.ToString()); + auto it = std::find_if(conversations.begin(), conversations.end(), + [&uuid](const mojom::ConversationPtr& conversation) { + return conversation->uuid == uuid; + }); + EXPECT_NE(it, conversations.end()); + return it->get(); +} + +std::vector CreateSampleChatHistory( + size_t num_query_pairs, + int32_t future_hours) { + std::vector history; + base::Time now = base::Time::Now(); + for (size_t i = 0; i < num_query_pairs; i++) { + // query + history.push_back(mojom::ConversationTurn::New( + base::Uuid::GenerateRandomV4().AsLowercaseString(), + mojom::CharacterType::HUMAN, mojom::ActionType::QUERY, + mojom::ConversationTurnVisibility::VISIBLE, + base::StrCat({"query", base::NumberToString(i)}), std::nullopt, + std::nullopt, now + base::Seconds(i * 60) + base::Hours(future_hours), + std::nullopt, false)); + // response + std::vector events; + events.emplace_back(mojom::ConversationEntryEvent::NewCompletionEvent( + mojom::CompletionEvent::New(base::StrCat( + {"This is a generated response ", base::NumberToString(i)})))); + events.emplace_back(mojom::ConversationEntryEvent::NewCompletionEvent( + mojom::CompletionEvent::New(base::StrCat( + {"and this is more response", base::NumberToString(i)})))); + events.emplace_back(mojom::ConversationEntryEvent::NewSearchQueriesEvent( + mojom::SearchQueriesEvent::New(std::vector{ + base::StrCat({"Something to search for", base::NumberToString(i)}), + base::StrCat({"Another search query", base::NumberToString(i)})}))); + history.push_back(mojom::ConversationTurn::New( + base::Uuid::GenerateRandomV4().AsLowercaseString(), + mojom::CharacterType::ASSISTANT, mojom::ActionType::RESPONSE, + mojom::ConversationTurnVisibility::VISIBLE, "", std::nullopt, + std::move(events), + now + base::Seconds((i * 60) + 30) + base::Hours(future_hours), + std::nullopt, false)); + } + return history; +} + +std::vector CloneHistory( + std::vector& history) { + std::vector cloned_history; + for (const auto& turn : history) { + cloned_history.push_back(turn->Clone()); + } + return cloned_history; +} + +} // namespace ai_chat diff --git a/components/ai_chat/core/browser/test_utils.h b/components/ai_chat/core/browser/test_utils.h new file mode 100644 index 000000000000..a56716508e91 --- /dev/null +++ b/components/ai_chat/core/browser/test_utils.h @@ -0,0 +1,50 @@ +// Copyright (c) 2024 The Brave Authors. All rights reserved. +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. + +#ifndef BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_TEST_UTILS_H_ +#define BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_TEST_UTILS_H_ + +#include +#include + +#include "base/location.h" +#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-forward.h" + +namespace ai_chat { + +void ExpectConversationEquals(base::Location location, + const mojom::ConversationPtr& a, + const mojom::ConversationPtr& b); + +void ExpectAssociatedContentEquals(base::Location location, + const mojom::SiteInfoPtr& a, + const mojom::SiteInfoPtr& b); + +void ExpectConversationEntryEquals(base::Location location, + const mojom::ConversationTurnPtr& a, + const mojom::ConversationTurnPtr& b, + bool compare_uuid = true); + +void ExpectConversationHistoryEquals( + base::Location location, + const std::vector& a, + const std::vector& b, + bool compare_uuid = true); + +mojom::Conversation* GetConversation( + base::Location location, + const std::vector& conversations, + std::string uuid); + +std::vector CreateSampleChatHistory( + size_t num_query_pairs, + int32_t future_hours = 0); + +std::vector CloneHistory( + std::vector& history); + +} // namespace ai_chat + +#endif // BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_TEST_UTILS_H_ diff --git a/components/ai_chat/core/browser/text_embedder.cc b/components/ai_chat/core/browser/text_embedder.cc index 04f5c40b1f0a..74bae07ee4af 100644 --- a/components/ai_chat/core/browser/text_embedder.cc +++ b/components/ai_chat/core/browser/text_embedder.cc @@ -6,18 +6,31 @@ #include "brave/components/ai_chat/core/browser/text_embedder.h" #include +#include +#include +#include +#include +#include +#include +#include "base/check.h" #include "base/check_is_test.h" +#include "base/containers/span.h" #include "base/files/file_util.h" +#include "base/functional/bind.h" #include "base/hash/hash.h" +#include "base/location.h" #include "base/logging.h" -#include "base/memory/ptr_util.h" -#include "base/metrics/histogram_functions.h" +#include "base/metrics/histogram_functions_internal_overloads.h" #include "base/strings/strcat.h" #include "base/strings/string_split.h" #include "base/task/bind_post_task.h" #include "base/timer/elapsed_timer.h" -#include "third_party/tflite/src/tensorflow/lite/core/api/op_resolver.h" +#include "cc/port/statusor.h" +#include "cc/task/text/text_embedder.h" +#include "tensorflow_lite_support/cc/task/core/proto/base_options.pb.h" +#include "tensorflow_lite_support/cc/task/core/proto/external_file.pb.h" +#include "tensorflow_lite_support/cc/task/text/proto/text_embedder_options.pb.h" #include "third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/utils/text_op_resolver.h" using TFLiteTextEmbedder = tflite::task::text::TextEmbedder; @@ -141,7 +154,7 @@ void TextEmbedder::GetTopSimilarityWithPromptTilContextLimitInternal( base::unexpected("TextEmbedder is not initialized.")); return; } - auto text_hash = base::FastHash(base::as_bytes(base::make_span(text))); + auto text_hash = base::FastHash(base::as_byte_span(text)); if (text_hash != text_hash_) { text_hash_ = text_hash; segments_ = SplitSegments(text); diff --git a/components/ai_chat/core/browser/text_embedder.h b/components/ai_chat/core/browser/text_embedder.h index 854dba200418..62afa22569d6 100644 --- a/components/ai_chat/core/browser/text_embedder.h +++ b/components/ai_chat/core/browser/text_embedder.h @@ -6,6 +6,9 @@ #ifndef BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_TEXT_EMBEDDER_H_ #define BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_TEXT_EMBEDDER_H_ +#include + +#include #include #include #include @@ -18,8 +21,18 @@ #include "base/synchronization/lock.h" #include "base/task/sequenced_task_runner.h" #include "base/types/expected.h" +#include "tensorflow_lite_support/cc/task/processor/proto/embedding.pb.h" +#include "third_party/abseil-cpp/absl/status/status.h" #include "third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/text_embedder.h" +namespace tflite { +namespace task { +namespace text { +class TextEmbedder; +} // namespace text +} // namespace task +} // namespace tflite + namespace base { class SequencedTaskRunner; } // namespace base diff --git a/components/ai_chat/core/browser/text_embedder_unittest.cc b/components/ai_chat/core/browser/text_embedder_unittest.cc index 35d2c9e301d2..f85bb278589c 100644 --- a/components/ai_chat/core/browser/text_embedder_unittest.cc +++ b/components/ai_chat/core/browser/text_embedder_unittest.cc @@ -11,17 +11,25 @@ #include "brave/components/ai_chat/core/browser/text_embedder.h" +#include + +#include +#include +#include + #include "base/files/file_path.h" #include "base/location.h" #include "base/path_service.h" #include "base/run_loop.h" #include "base/strings/string_number_conversions.h" #include "base/task/sequenced_task_runner.h" +#include "base/task/task_traits.h" #include "base/task/thread_pool.h" #include "base/test/bind.h" #include "base/test/task_environment.h" #include "brave/components/ai_chat/core/browser/local_models_updater.h" #include "brave/components/constants/brave_paths.h" +#include "build/build_config.h" #include "testing/gtest/include/gtest/gtest.h" namespace ai_chat { diff --git a/components/ai_chat/core/browser/utils.cc b/components/ai_chat/core/browser/utils.cc index 1b1191aa6fe8..074f334732ef 100644 --- a/components/ai_chat/core/browser/utils.cc +++ b/components/ai_chat/core/browser/utils.cc @@ -6,29 +6,39 @@ #include "brave/components/ai_chat/core/browser/utils.h" #include +#include #include #include -#include "base/containers/fixed_flat_set.h" +#include "base/check.h" +#include "base/containers/flat_map.h" #include "base/functional/bind.h" #include "base/no_destructor.h" #include "base/strings/string_util.h" -#include "base/task/bind_post_task.h" -#include "base/task/thread_pool.h" #include "base/time/time.h" #include "brave/brave_domains/service_domains.h" -#include "brave/components/ai_chat/core/browser/constants.h" #include "brave/components/ai_chat/core/common/constants.h" #include "brave/components/ai_chat/core/common/features.h" -#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" +#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-forward.h" #include "brave/components/ai_chat/core/common/pref_names.h" -#include "brave/components/l10n/common/locale_util.h" +#include "components/grit/brave_components_strings.h" #include "components/prefs/pref_service.h" -#include "components/user_prefs/user_prefs.h" +#include "mojo/public/cpp/bindings/struct_ptr.h" #include "third_party/re2/src/re2/re2.h" #include "ui/base/l10n/l10n_util.h" +#include "url/gurl.h" #include "url/url_constants.h" +#if BUILDFLAG(IS_MAC) || BUILDFLAG(IS_WIN) +#include "base/task/task_traits.h" +#include "base/task/thread_pool.h" +#endif // BUILDFLAG(IS_MAC) || BUILDFLAG(IS_WIN) + +#if BUILDFLAG(IS_WIN) +#include "base/task/bind_post_task.h" +#include "brave/components/l10n/common/locale_util.h" +#endif // BUILDFLAG(IS_WIN) + #if BUILDFLAG(ENABLE_TEXT_RECOGNITION) #include "brave/components/text_recognition/browser/text_recognition.h" #endif diff --git a/components/ai_chat/core/browser/utils.h b/components/ai_chat/core/browser/utils.h index 87c99dc5779d..aa3e8de5c93c 100644 --- a/components/ai_chat/core/browser/utils.h +++ b/components/ai_chat/core/browser/utils.h @@ -6,6 +6,8 @@ #ifndef BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_UTILS_H_ #define BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_UTILS_H_ +#include + #include "base/functional/callback_forward.h" #include "brave/components/ai_chat/core/browser/conversation_handler.h" #include "brave/components/ai_chat/core/browser/engine/engine_consumer.h" @@ -15,8 +17,12 @@ #include "url/gurl.h" class PrefService; +class GURL; namespace ai_chat { +namespace mojom { +enum class ActionType : int32_t; +} // namespace mojom // Check both policy and feature flag to determine if AI Chat is enabled. bool IsAIChatEnabled(PrefService* prefs); diff --git a/components/ai_chat/core/browser/utils_unittest.cc b/components/ai_chat/core/browser/utils_unittest.cc index 8460551e6b37..4533cc9003d1 100644 --- a/components/ai_chat/core/browser/utils_unittest.cc +++ b/components/ai_chat/core/browser/utils_unittest.cc @@ -5,6 +5,9 @@ #include "brave/components/ai_chat/core/browser/utils.h" +#include +#include + #include "testing/gtest/include/gtest/gtest.h" #include "url/gurl.h" diff --git a/components/ai_chat/core/common/features.cc b/components/ai_chat/core/common/features.cc index 60b211f5d112..f839410236aa 100644 --- a/components/ai_chat/core/common/features.cc +++ b/components/ai_chat/core/common/features.cc @@ -9,6 +9,7 @@ #include "base/feature_list.h" #include "base/metrics/field_trial_params.h" +#include "build/build_config.h" namespace ai_chat::features { @@ -24,6 +25,8 @@ const base::FeatureParam kFreemiumAvailable(&kAIChat, "is_freemium_available", true); const base::FeatureParam kAIChatSSE{&kAIChat, "ai_chat_sse", true}; +const base::FeatureParam kOmniboxOpensFullPage{ + &kAIChat, "omnibox_opens_full_page", true}; const base::FeatureParam kConversationAPIEnabled{ &kAIChat, "conversation_api", true}; const base::FeatureParam kAITemperature{&kAIChat, "temperature", 0.2}; diff --git a/components/ai_chat/core/common/features.h b/components/ai_chat/core/common/features.h index ca454658d29b..a368185d3845 100644 --- a/components/ai_chat/core/common/features.h +++ b/components/ai_chat/core/common/features.h @@ -28,6 +28,8 @@ extern const base::FeatureParam kFreemiumAvailable; COMPONENT_EXPORT(AI_CHAT_COMMON) extern const base::FeatureParam kAIChatSSE; COMPONENT_EXPORT(AI_CHAT_COMMON) +extern const base::FeatureParam kOmniboxOpensFullPage; +COMPONENT_EXPORT(AI_CHAT_COMMON) extern const base::FeatureParam kConversationAPIEnabled; COMPONENT_EXPORT(AI_CHAT_COMMON) extern const base::FeatureParam kAITemperature; diff --git a/components/ai_chat/core/common/mojom/BUILD.gn b/components/ai_chat/core/common/mojom/BUILD.gn index b3e5bd6535d4..646ac80617bc 100644 --- a/components/ai_chat/core/common/mojom/BUILD.gn +++ b/components/ai_chat/core/common/mojom/BUILD.gn @@ -17,6 +17,7 @@ mojom_component("mojom") { "ai_chat.mojom", "page_content_extractor.mojom", "settings_helper.mojom", + "untrusted_frame.mojom", ] deps = [ diff --git a/components/ai_chat/core/common/mojom/ai_chat.mojom b/components/ai_chat/core/common/mojom/ai_chat.mojom index 29fd611aa4f6..d5b88bcf98e8 100644 --- a/components/ai_chat/core/common/mojom/ai_chat.mojom +++ b/components/ai_chat/core/common/mojom/ai_chat.mojom @@ -5,6 +5,7 @@ module ai_chat.mojom; +import "brave/components/ai_chat/core/common/mojom/untrusted_frame.mojom"; import "mojo/public/mojom/base/time.mojom"; import "url/mojom/url.mojom"; @@ -69,19 +70,59 @@ enum SuggestionGenerationStatus { HasGenerated, }; +// Type of content that is extracted +enum ContentType { + PageContent, + VideoTranscript, +}; + +// A piece of content associated with a conversation. +// TODO(petemill): Rename to AssociatedContent and have hostname/url/title be +// part of a detail property (which would have different structure possibilities +// depending on content type. struct SiteInfo { - // Title is present if the site has associated page content + string? uuid; + ContentType content_type; + // Web page specific fields, if available and allowed string? title; - // Indicates if the URL scheme is of a type that allows for content association. - bool is_content_association_possible; - // Current tab's url and hostname, if http(s) string? hostname; url.mojom.Url? url; + // Percentage of the content that has been utilized by remote LLM (0-100) int32 content_used_percentage; // Indicates content has been refined through extracting most relevant parts // from long page content bool is_content_refined; + + // Indicates if the URL scheme is of a type that allows for + // content association. + // TODO(petemill): SiteInfo should just not exist if content-associated is + // not possible. + bool is_content_association_possible; +}; + +struct Conversation { + string uuid; + // Set by the LLM or the user + string title; + // Time used for ordering purposes + mojo_base.mojom.Time updated_time; + // If there are entries and the conversation should be selectable + bool has_content; + // Model key, if different than default + string? model_key; + + SiteInfo associated_content; +}; + +struct ContentArchive { + string content_uuid; + string content; +}; + +struct ConversationArchive { + array entries; + array associated_content; }; enum ActionType { @@ -154,6 +195,9 @@ union ConversationEntryEvent { // The selected_text attribute contains what user selects in the page when // calling from the context menu. struct ConversationTurn { + // Populated if owned by a conversation + string? uuid; + CharacterType character_type; ActionType action_type; ConversationTurnVisibility visibility; @@ -253,19 +297,21 @@ struct ActionGroup { array entries; }; -struct Conversation { - string uuid; - // Set by the LLM or the user - string title; - // Time used for ordering purposes - mojo_base.mojom.Time created_time; - // If there are entries and the conversation should be selectable - bool has_content; +// This does not cover more specific data that the Service owns, such as the +// conversation list, but does cover status of preferences and notices. +struct ServiceState { + bool has_accepted_agreement; + bool is_storage_pref_enabled; + bool is_storage_notice_dismissed; + bool can_show_premium_prompt; }; interface Service { - // User opts-in to the feature at a profile level + // Profile-level acknowledgements MarkAgreementAccepted(); + EnableStoragePref(); + DismissStorageNotice(); + DismissPremiumPrompt(); // Get metadata for non-archived conversations GetVisibleConversations() => (array conversations); @@ -276,16 +322,11 @@ interface Service { // Current status of subscription GetPremiumStatus() => (PremiumStatus status, PremiumInfo? info); - // Premium prompt is only shown conditionally (e.g. the user hasn't dismissed - // it and it's been some time since the user started using the feature). - GetCanShowPremiumPrompt() => (bool can_show); - DismissPremiumPrompt(); - DeleteConversation(string id); RenameConversation(string id, string new_name); - // Send events to the UI - BindObserver(pending_remote ui); + // Bind ability to send events to the UI and receive current state + BindObserver(pending_remote ui) => (ServiceState state); // Bind specified Conversation for 2-way communication BindConversation( @@ -296,7 +337,7 @@ interface Service { interface ServiceObserver { OnConversationListChanged(array conversations); - OnAgreementAccepted(); + OnStateChanged(ServiceState state); }; // Browser-side handler for general AI Chat UI functions, implemented @@ -312,7 +353,7 @@ interface AIChatUIHandler { OpenConversationFullPage(string conversation_uuid); OpenURL(url.mojom.Url url); - OpenLearnMoreAboutBraveSearchWithLeo(); + OpenModelSupportUrl(); GoPremium(); RefreshPremiumSession(); @@ -322,7 +363,9 @@ interface AIChatUIHandler { // This might be a no-op if the UI isn't closeable CloseUI(); - SetChatUI(pending_remote chat_ui); + // Provide a reference of the UI to the UI handler and get some + // initial constant state + SetChatUI(pending_remote chat_ui) => (bool is_standalone); // Bind 2-way communication to the conversation related to the open page in // the current browser window. No binding will occur if this isn't a @@ -342,13 +385,12 @@ interface AIChatUIHandler { array? favicon_image_data); }; -// UI-side handler for whole AI Chat UI +// UI-side handler for messages from the browser WebUI interface ChatUI { - // Initial Data - SetInitialData(bool is_standalone); // Notifies that the default conversation for the // panel has changed. e.g. Tab navigation with AIChat open in sidebar. OnNewDefaultConversation(); + OnChildFrameBound(pending_receiver receiver); }; struct ConversationState { @@ -365,6 +407,22 @@ struct ConversationState { // `OnConversationHistoryUpdate` is more intelligent (see TOOD in definition). }; +// State required to show the conversations entries UI block +struct ConversationEntriesState { + // Whether an answer generation is in progress + bool is_generating; + // Whether the current model is a built-in Leo model + bool is_leo_model; + // How much of the content has been used by the AI engine, percentage,or null + // if no content is associated. + uint32? content_used_percentage; + // Whether the content has been refined + bool is_content_refined; + // Whether the UI should represent that the user cannot submit new messages + // or edits to the conversation. + bool can_submit_user_entries; +}; + // Browser-side handler for a Conversation interface ConversationHandler { GetState() => (ConversationState conversation_state); @@ -383,10 +441,6 @@ interface ConversationHandler { // Get all visible history entries, including in-progress responses GetConversationHistory() => (array conversation_history); - // List of all suggested questions - GetSuggestedQuestions() => ( - array questions, SuggestionGenerationStatus suggestion_status); - // The browser should generate some questions and fire an event when they // are ready. GenerateQuestions(); @@ -414,7 +468,7 @@ interface ConversationHandler { // Send a user-rating for a chat // message. |turn_id| is the index of the message in the // specified conversation. - RateMessage(bool is_liked, uint32 turn_id) + RateMessage(bool is_liked, string turn_uuid) => (string? rating_id); SendFeedback( string category, @@ -422,9 +476,28 @@ interface ConversationHandler { string rating_id, bool send_hostname) => (bool is_success); }; -interface ConversationEntriesHandler { +// Browser-side handler for a Conversation's UI responsible for displaying +// untrusted content (e.g. content generated by the AI engine). +interface UntrustedConversationHandler { + BindUntrustedConversationUI( + pending_remote untrusted_ui) + => (ConversationEntriesState conversation_entries_state); + // Get all visible history entries, including in-progress responses GetConversationHistory() => (array conversation_history); + + ModifyConversation(uint32 turn_index, string new_text); +}; + +// Untrusted-UI-side handler for a Conversation, responsible for displaying +// content generated by the AI engine. +interface UntrustedConversationUI { + // TODO(petemill): Provide single entry that's been updated so that we don't + // need to fetch (and clone) all conversation entries each time text is added + // to the most recent entry. + OnConversationHistoryUpdate(); + OnEntriesUIStateChanged(ConversationEntriesState state); + OnFaviconImageDataChanged(); }; interface ConversationUI { diff --git a/components/ai_chat/core/common/mojom/untrusted_frame.mojom b/components/ai_chat/core/common/mojom/untrusted_frame.mojom new file mode 100644 index 000000000000..264d5e2c3c2c --- /dev/null +++ b/components/ai_chat/core/common/mojom/untrusted_frame.mojom @@ -0,0 +1,31 @@ +// Copyright (c) 2024 The Brave Authors. All rights reserved. +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. + +module ai_chat.mojom; + +// Interfaces for communication between the untrusted content frame and both +// the Browser and the parent trusted UI frame. + +// Trusted WebUI-side handler for messages from the untrusted child frame +interface ParentUIFrame { + //