Skip to content

Commit

Permalink
Refactor chat provider to use message streaming
Browse files Browse the repository at this point in the history
  • Loading branch information
WilliamKarolDiCioccio committed Aug 26, 2024
1 parent 2912452 commit 79340a1
Showing 1 changed file with 57 additions and 76 deletions.
133 changes: 57 additions & 76 deletions app/lib/backend/providers/chat.dart
Original file line number Diff line number Diff line change
Expand Up @@ -308,24 +308,41 @@ class ChatProvider extends ChangeNotifier {
/// If the session is not selected, the function returns the newly created [ChatModelMessageWrapper] without adding it to the memory or the database.
///
/// Returns the newly created [ChatModelMessageWrapper].
ChatModelMessageWrapper addModelMessage(String message, String senderName) {
StreamSubscription<String> addModelMessage(
Stream<String> messageStream,
String senderName,
) {
final StringBuffer messageBuffer = StringBuffer();
final DateTime timestamp = DateTime.now();
final String messageId = const Uuid().v4();

final chatMessage = ChatModelMessageWrapper(
message,
DateTime.now(),
const Uuid().v4(),
'',
timestamp,
messageId,
senderName,
);

if (!isSessionSelected) return chatMessage;

_session!.messages.add(chatMessage);
_session!.memory.chatHistory.addAIChatMessage(message);

ChatSessionsDatabase.updateSession(_session!);
final StreamSubscription<String> subscription = messageStream.listen(
(message) {
messageBuffer.write(message);

notifyListeners();
_session!.messages.last.text = messageBuffer.toString();
},
onDone: () {
if (isSessionSelected) {
_session!.memory.chatHistory.addAIChatMessage(
messageBuffer.toString(),
);
ChatSessionsDatabase.updateSession(_session!);
notifyListeners();
}
},
);

return _session!.messages.last as ChatModelMessageWrapper;
return subscription;
}

/// Adds a chat message of type user to the current session and to the model's memory and updates the session in the database.
Expand Down Expand Up @@ -465,6 +482,34 @@ class ChatProvider extends ChangeNotifier {
return prompt;
}

/// Processes the chat chain with the given prompt and streams the response as a string.
Stream<String> _processChain(ChatMessage prompt) async* {
final chain = await _buildChain();

await for (final response in chain.stream([prompt])) {
final result = response as ChatResult;

_computePerformanceStatistics(result);

yield response.outputAsString;

notifyListeners();

// If the session is aborted, remove the last message from memory and break the loop

if (_session!.status == ChatSessionStatus.aborting) {
_session!.status = ChatSessionStatus.idle;
_session!.memory.chatHistory.removeLast();

_computePerformanceStatistics(result);

notifyListeners();

break;
}
}
}

/// Sends a message to the chat model and processes the response.
///
/// The function first checks if a session is selected and creates a new one if not.
Expand Down Expand Up @@ -504,43 +549,9 @@ class ChatProvider extends ChangeNotifier {

addUserMessage(text, imageBytes);

final chain = await _buildChain();

final prompt = _buildPrompt(text, imageBytes: imageBytes);

addModelMessage('', _modelName);

await for (final response in chain.stream([prompt])) {
ChatResult result = response as ChatResult;

// If the session is aborted, remove the last message from memory and break the loop

if (_session!.status == ChatSessionStatus.aborting) {
_session!.status = ChatSessionStatus.idle;
_session!.memory.chatHistory.removeLast();

_computePerformanceStatistics(result);

notifyListeners();

break;
}

final lastMessage = _session!.messages.last;
lastMessage.text += result.outputAsString;

_computePerformanceStatistics(result);

notifyListeners();
}

// Save the generated message, remove and add it back to force a memory update

final generatedText = _session!.messages.last.text;

removeLastMessage();

addModelMessage(generatedText, _modelName);
addModelMessage(_processChain(prompt), _modelName);

_session!.status = ChatSessionStatus.idle;

Expand Down Expand Up @@ -638,42 +649,12 @@ class ChatProvider extends ChangeNotifier {

notifyListeners();

final chain = await _buildChain();

final prompt = _buildPrompt(
userMessage.text,
imageBytes: userMessage.imageBytes,
);

addModelMessage('', _modelName);

await for (final response in chain.stream([prompt])) {
ChatResult result = response as ChatResult;

if (_session!.status == ChatSessionStatus.aborting) {
_session!.status = ChatSessionStatus.idle;
_session!.memory.chatHistory.removeLast();

_computePerformanceStatistics(result);

notifyListeners();

break;
}

final lastMessage = _session!.messages.last;
lastMessage.text += result.outputAsString;

_computePerformanceStatistics(result);

notifyListeners();
}

final generatedText = _session!.messages.last.text;

removeLastMessage();

addModelMessage(generatedText, _modelName);
addModelMessage(_processChain(prompt), _modelName);

_session!.status = ChatSessionStatus.idle;

Expand Down

0 comments on commit 79340a1

Please sign in to comment.