From 79340a1fc6c5e095e8d80f372cf9860b2b4d665b Mon Sep 17 00:00:00 2001 From: Wilielmus <88447902+WilliamKarolDiCioccio@users.noreply.github.com> Date: Mon, 26 Aug 2024 14:28:45 +0200 Subject: [PATCH] Refactor chat provider to use message streaming --- app/lib/backend/providers/chat.dart | 133 ++++++++++++---------------- 1 file changed, 57 insertions(+), 76 deletions(-) diff --git a/app/lib/backend/providers/chat.dart b/app/lib/backend/providers/chat.dart index b971ba3..a3a7316 100644 --- a/app/lib/backend/providers/chat.dart +++ b/app/lib/backend/providers/chat.dart @@ -308,24 +308,41 @@ class ChatProvider extends ChangeNotifier { /// If the session is not selected, the function returns the newly created [ChatModelMessageWrapper] without adding it to the memory or the database. /// /// Returns the newly created [ChatModelMessageWrapper]. - ChatModelMessageWrapper addModelMessage(String message, String senderName) { + StreamSubscription<String> addModelMessage( + Stream<String> messageStream, + String senderName, + ) { + final StringBuffer messageBuffer = StringBuffer(); + final DateTime timestamp = DateTime.now(); + final String messageId = const Uuid().v4(); + final chatMessage = ChatModelMessageWrapper( - message, - DateTime.now(), - const Uuid().v4(), + '', + timestamp, + messageId, senderName, ); - if (!isSessionSelected) return chatMessage; - _session!.messages.add(chatMessage); - _session!.memory.chatHistory.addAIChatMessage(message); - ChatSessionsDatabase.updateSession(_session!); + final StreamSubscription<String> subscription = messageStream.listen( + (message) { + messageBuffer.write(message); - notifyListeners(); + _session!.messages.last.text = messageBuffer.toString(); + }, + onDone: () { + if (isSessionSelected) { + _session!.memory.chatHistory.addAIChatMessage( + messageBuffer.toString(), + ); + ChatSessionsDatabase.updateSession(_session!); + notifyListeners(); + } + }, + ); - return _session!.messages.last as ChatModelMessageWrapper; + return subscription; } /// Adds a chat message of type user to the current session and to the model's memory and updates the session in the database. @@ -465,6 +482,34 @@ class ChatProvider extends ChangeNotifier { return prompt; } + /// Processes the chat chain with the given prompt and streams the response as a string. + Stream<String> _processChain(ChatMessage prompt) async* { + final chain = await _buildChain(); + + await for (final response in chain.stream([prompt])) { + final result = response as ChatResult; + + _computePerformanceStatistics(result); + + yield response.outputAsString; + + notifyListeners(); + + // If the session is aborted, remove the last message from memory and break the loop + + if (_session!.status == ChatSessionStatus.aborting) { + _session!.status = ChatSessionStatus.idle; + _session!.memory.chatHistory.removeLast(); + + _computePerformanceStatistics(result); + + notifyListeners(); + + break; + } + } + } + /// Sends a message to the chat model and processes the response. /// /// The function first checks if a session is selected and creates a new one if not. @@ -504,43 +549,9 @@ class ChatProvider extends ChangeNotifier { addUserMessage(text, imageBytes); - final chain = await _buildChain(); - final prompt = _buildPrompt(text, imageBytes: imageBytes); - addModelMessage('', _modelName); - - await for (final response in chain.stream([prompt])) { - ChatResult result = response as ChatResult; - - // If the session is aborted, remove the last message from memory and break the loop - - if (_session!.status == ChatSessionStatus.aborting) { - _session!.status = ChatSessionStatus.idle; - _session!.memory.chatHistory.removeLast(); - - _computePerformanceStatistics(result); - - notifyListeners(); - - break; - } - - final lastMessage = _session!.messages.last; - lastMessage.text += result.outputAsString; - - _computePerformanceStatistics(result); - - notifyListeners(); - } - - // Save the generated message, remove and add it back to force a memory update - - final generatedText = _session!.messages.last.text; - - removeLastMessage(); - - addModelMessage(generatedText, _modelName); + addModelMessage(_processChain(prompt), _modelName); _session!.status = ChatSessionStatus.idle; @@ -638,42 +649,12 @@ class ChatProvider extends ChangeNotifier { notifyListeners(); - final chain = await _buildChain(); - final prompt = _buildPrompt( userMessage.text, imageBytes: userMessage.imageBytes, ); - addModelMessage('', _modelName); - - await for (final response in chain.stream([prompt])) { - ChatResult result = response as ChatResult; - - if (_session!.status == ChatSessionStatus.aborting) { - _session!.status = ChatSessionStatus.idle; - _session!.memory.chatHistory.removeLast(); - - _computePerformanceStatistics(result); - - notifyListeners(); - - break; - } - - final lastMessage = _session!.messages.last; - lastMessage.text += result.outputAsString; - - _computePerformanceStatistics(result); - - notifyListeners(); - } - - final generatedText = _session!.messages.last.text; - - removeLastMessage(); - - addModelMessage(generatedText, _modelName); + addModelMessage(_processChain(prompt), _modelName); _session!.status = ChatSessionStatus.idle;