Skip to content

Commit

Permalink
Some chat use cases need full control over the chat history. Modified…
Browse files Browse the repository at this point in the history
… the OllamaApiClient to only return the chat response with metrics. It will not return a modified history.

The Chat class continues to maintain an internal chat history.

Enable chat to obtain metrics.
Enable more control over conversation details.
  • Loading branch information
MarkWard0110 committed Jun 2, 2024
1 parent 500658e commit 00f16f1
Show file tree
Hide file tree
Showing 6 changed files with 376 additions and 383 deletions.
7 changes: 5 additions & 2 deletions src/Chat.cs
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,11 @@ public async Task<IEnumerable<Message>> SendAs(ChatRole role, string message, IE
Stream = true
};

var answer = await Client.SendChat(request, Streamer, cancellationToken);
_messages = answer.ToList();
var answer = await Client.SendChat(request, Streamer, cancellationToken);
var messages = request.Messages.ToList();
messages.Add(answer.Response);

_messages = messages;
return Messages;
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/IOllamaApiClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ public interface IOllamaApiClient
/// Can be used to update the user interface while the answer is still being generated.
/// </param>
/// <param name="cancellationToken">The token to cancel the operation with</param>
/// <returns>List of the returned messages including the previous context</returns>
Task<IEnumerable<Message>> SendChat(ChatRequest chatRequest, IResponseStreamer<ChatResponseStream> streamer, CancellationToken cancellationToken = default);
/// <returns>The chat response message together with any metrics reported when the stream completes</returns>
Task<ConversationResponse> SendChat(ChatRequest chatRequest, IResponseStreamer<ChatResponseStream> streamer, CancellationToken cancellationToken = default);

/// <summary>
/// Sends a request to the /api/copy endpoint to copy a model
Expand Down
110 changes: 56 additions & 54 deletions src/OllamaApiClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -108,23 +108,23 @@ public async Task<ConversationContextWithResponse> GetCompletion(GenerateComplet
var result = await GenerateCompletion(request, new ActionResponseStreamer<GenerateCompletionResponseStream>(status => builder.Append(status.Response)), cancellationToken);
return new ConversationContextWithResponse(builder.ToString(), result.Context, result.Metadata);
}

// NOTE(review): old-side diff lines — this overload (returning the full message
// history) was replaced by the ConversationResponse-returning version below.
public async Task<IEnumerable<Message>> SendChat(ChatRequest chatRequest, IResponseStreamer<ChatResponseStream> streamer, CancellationToken cancellationToken = default)
{
// Serialize the chat request as JSON for the /api/chat endpoint.
var request = new HttpRequestMessage(HttpMethod.Post, "api/chat")
{
Content = new StringContent(JsonSerializer.Serialize(chatRequest), Encoding.UTF8, "application/json")
};

// When streaming, start reading as soon as headers arrive instead of buffering the whole body.
var completion = chatRequest.Stream ? HttpCompletionOption.ResponseHeadersRead : HttpCompletionOption.ResponseContentRead;

using var response = await _client.SendAsync(request, completion, cancellationToken);
response.EnsureSuccessStatusCode();

return await ProcessStreamedChatResponseAsync(chatRequest, response, streamer, cancellationToken);
}

private async Task<ConversationContext> GenerateCompletion(GenerateCompletionRequest generateRequest, IResponseStreamer<GenerateCompletionResponseStream> streamer, CancellationToken cancellationToken)

/// <summary>
/// Posts a chat request to the /api/chat endpoint, forwards each streamed chunk
/// to the given streamer, and returns the final response with its metrics.
/// </summary>
/// <param name="chatRequest">The request describing the model, messages and streaming mode</param>
/// <param name="streamer">Receives each response chunk as it arrives</param>
/// <param name="cancellationToken">The token to cancel the operation with</param>
public async Task<ConversationResponse> SendChat(ChatRequest chatRequest, IResponseStreamer<ChatResponseStream> streamer, CancellationToken cancellationToken = default)
{
	var json = JsonSerializer.Serialize(chatRequest);
	var request = new HttpRequestMessage(HttpMethod.Post, "api/chat")
	{
		Content = new StringContent(json, Encoding.UTF8, "application/json")
	};

	// For streamed requests, begin processing once the headers are available
	// rather than waiting for the complete body.
	var completionOption = chatRequest.Stream
		? HttpCompletionOption.ResponseHeadersRead
		: HttpCompletionOption.ResponseContentRead;

	using var response = await _client.SendAsync(request, completionOption, cancellationToken);
	response.EnsureSuccessStatusCode();

	return await ProcessStreamedChatResponseAsync(chatRequest, response, streamer, cancellationToken);
}

private async Task<ConversationContext> GenerateCompletion(GenerateCompletionRequest generateRequest, IResponseStreamer<GenerateCompletionResponseStream> streamer, CancellationToken cancellationToken)
{
var request = new HttpRequestMessage(HttpMethod.Post, "api/generate")
{
Expand Down Expand Up @@ -212,44 +212,46 @@ private static async Task<ConversationContext> ProcessStreamedCompletionResponse
}

return new ConversationContext(Array.Empty<long>());
}

// NOTE(review): old-side diff lines — replaced by the ConversationResponse-returning
// version; kept byte-identical here as the removed half of the diff.
private static async Task<IEnumerable<Message>> ProcessStreamedChatResponseAsync(ChatRequest chatRequest, HttpResponseMessage response, IResponseStreamer<ChatResponseStream> streamer, CancellationToken cancellationToken)
{
using var stream = await response.Content.ReadAsStreamAsync();
using var reader = new StreamReader(stream);

ChatRole? responseRole = null;
var responseContent = new StringBuilder();

// Read one JSON chunk per line until the stream ends or cancellation is requested.
while (!reader.EndOfStream && !cancellationToken.IsCancellationRequested)
{
string line = await reader.ReadLineAsync();

var streamedResponse = JsonSerializer.Deserialize<ChatResponseStream>(line);

// keep the streamed content to build the last message
// to return the list of messages
responseRole ??= streamedResponse?.Message?.Role;
responseContent.Append(streamedResponse?.Message?.Content ?? "");

streamer.Stream(streamedResponse);

if (streamedResponse?.Done ?? false)
{
// NOTE(review): doneResponse is deserialized but never used in this old version.
var doneResponse = JsonSerializer.Deserialize<ChatDoneResponseStream>(line);
var messages = chatRequest.Messages.ToList();
messages.Add(new Message(responseRole, responseContent.ToString()));
return messages;
}
}

// Stream ended without a done marker (or cancellation was requested).
return Array.Empty<Message>();
}
}
}

/// <summary>
/// Reads the streamed chat response line by line, forwarding each chunk to the
/// streamer and accumulating the assistant message until the done chunk arrives.
/// </summary>
/// <param name="chatRequest">NOTE(review): no longer read here — the response no longer echoes the request history; kept to avoid changing the call site</param>
/// <param name="response">The HTTP response whose body is a stream of JSON lines</param>
/// <param name="streamer">Receives each deserialized chunk as it is read</param>
/// <param name="cancellationToken">Stops reading further chunks when cancelled</param>
/// <returns>The assembled response message with metrics, or a response with a null message if the stream ended without a done marker</returns>
private static async Task<ConversationResponse> ProcessStreamedChatResponseAsync(ChatRequest chatRequest, HttpResponseMessage response, IResponseStreamer<ChatResponseStream> streamer, CancellationToken cancellationToken)
{
	using var stream = await response.Content.ReadAsStreamAsync();
	using var reader = new StreamReader(stream);

	ChatRole? responseRole = null;
	var responseContent = new StringBuilder();

	while (!reader.EndOfStream && !cancellationToken.IsCancellationRequested)
	{
		string line = await reader.ReadLineAsync();

		// Skip blank lines: JsonSerializer.Deserialize throws on a null/empty input.
		if (string.IsNullOrEmpty(line))
			continue;

		var streamedResponse = JsonSerializer.Deserialize<ChatResponseStream>(line);

		// Keep the streamed role and content so the final message can be
		// assembled once the server signals completion.
		responseRole ??= streamedResponse?.Message?.Role;
		responseContent.Append(streamedResponse?.Message?.Content ?? "");

		streamer.Stream(streamedResponse);

		if (streamedResponse?.Done ?? false)
		{
			// The final chunk carries the timing and token-count metrics.
			var doneResponse = JsonSerializer.Deserialize<ChatDoneResponseStream>(line);
			var metadata = doneResponse is null
				? null
				: new ResponseMetadata(doneResponse.TotalDuration, doneResponse.LoadDuration, doneResponse.PromptEvalCount, doneResponse.PromptEvalDuration, doneResponse.EvalCount, doneResponse.EvalDuration);
			return new ConversationResponse(new Message(responseRole, responseContent.ToString()), metadata);
		}
	}

	// Stream ended without a done marker, or cancellation was requested:
	// return a response without a message rather than throwing.
	return new ConversationResponse(null);
}
}

/// <summary>Timing and token-count metrics reported by the API once a response completes.</summary>
public record ResponseMetadata(long TotalDuration, long LoadDuration, int PromptEvalCount, long PromptEvalDuration, int EvalCount, long EvalDuration);
// NOTE(review): the two ConversationContext declarations below are the old/new sides of
// a diff view; only the nullable-Metadata form should remain in the actual file.
public record ConversationContext(long[] Context, ResponseMetadata Metadata = null);
/// <summary>Context token ids returned by the API, plus optional completion metrics.</summary>
public record ConversationContext(long[] Context, ResponseMetadata? Metadata = null);

// NOTE(review): old/new sides of a diff view; only the nullable-Metadata form should remain.
public record ConversationContextWithResponse(string Response, long[] Context, ResponseMetadata Metadata = null ) : ConversationContext(Context, Metadata);
/// <summary>A completion response text together with its context token ids and optional metrics.</summary>
public record ConversationContextWithResponse(string Response, long[] Context, ResponseMetadata? Metadata = null ) : ConversationContext(Context, Metadata);

/// <summary>
/// The final chat response message and optional metrics. Response is nullable because
/// ProcessStreamedChatResponseAsync constructs this with null when the stream ends
/// without a done marker.
/// </summary>
public record ConversationResponse(Message? Response, ResponseMetadata? Metadata = null);
}
4 changes: 2 additions & 2 deletions src/OllamaApiClientExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ public static Chat Chat(this IOllamaApiClient client, IResponseStreamer<ChatResp
/// Can be used to update the user interface while the answer is still being generated.
/// </param>
/// <param name="cancellationToken">The token to cancel the operation with</param>
/// <returns>List of the returned messages including the previous context</returns>
public static async Task<IEnumerable<Message>> SendChat(this IOllamaApiClient client, ChatRequest chatRequest, Action<ChatResponseStream> streamer, CancellationToken cancellationToken = default)
/// <returns>The final chat response, including any metrics reported by the API</returns>
public static async Task<ConversationResponse> SendChat(this IOllamaApiClient client, ChatRequest chatRequest, Action<ChatResponseStream> streamer, CancellationToken cancellationToken = default)
{
	// Wrap the plain callback in the streamer interface the client expects.
	var wrappedStreamer = new ActionResponseStreamer<ChatResponseStream>(streamer);
	return await client.SendChat(chatRequest, wrappedStreamer, cancellationToken);
}
Expand Down
Loading

0 comments on commit 00f16f1

Please sign in to comment.