From c91cc7626dcc21d9f86135b05737e913ea69f538 Mon Sep 17 00:00:00 2001
From: Noah Czaplewski
Date: Tue, 17 Sep 2024 08:15:44 +0200
Subject: [PATCH] Added a simpler one-time chat api

---
 src/IOllamaApiClient.cs         | 11 ++++++
 src/Models/Chat/ChatResponse.cs | 67 +++++++++++++++++++++++++++++++++
 src/OllamaApiClient.cs          | 22 +++++++++++
 test/ChatTests.cs               | 13 +++++++
 test/TestOllamaApiClient.cs     | 11 ++++++
 5 files changed, 124 insertions(+)
 create mode 100644 src/Models/Chat/ChatResponse.cs

diff --git a/src/IOllamaApiClient.cs b/src/IOllamaApiClient.cs
index 5f3ccb4..f13dac6 100644
--- a/src/IOllamaApiClient.cs
+++ b/src/IOllamaApiClient.cs
@@ -36,6 +36,17 @@ public interface IOllamaApiClient
 	/// To implement a fully interactive chat, you should make use of the Chat class with "new Chat(...)"
 	/// </summary>
 	IAsyncEnumerable<ChatResponseStream?> Chat(ChatRequest request, [EnumeratorCancellation] CancellationToken cancellationToken = default);
+
+
+	/// <summary>
+	/// Sends a non-streaming request to the /api/chat endpoint and returns the response.
+	/// </summary>
+	/// <param name="request">The request to send to Ollama</param>
+	/// <param name="cancellationToken">The token to cancel the operation with</param>
+	/// <returns>
+	/// A ChatResponse which contains the Message from the /api/chat endpoint.
+	/// </returns>
+	Task<ChatResponse> ChatAsync(ChatRequest request, CancellationToken cancellationToken = default);
 
 	/// <summary>
 	/// Sends a request to the /api/copy endpoint to copy a model
diff --git a/src/Models/Chat/ChatResponse.cs b/src/Models/Chat/ChatResponse.cs
new file mode 100644
index 0000000..310220c
--- /dev/null
+++ b/src/Models/Chat/ChatResponse.cs
@@ -0,0 +1,67 @@
+using System;
+using System.Text.Json.Serialization;
+
+namespace OllamaSharp.Models.Chat;
+
+public class ChatResponse
+{
+	/// <summary>
+	/// The model that generated the response
+	/// </summary>
+	[JsonPropertyName("model")]
+	public string Model { get; set; } = null!;
+
+	/// <summary>
+	/// The time the response was generated
+	/// </summary>
+	[JsonPropertyName("created_at")]
+	public string CreatedAt { get; set; } = null!;
+
+	/// <summary>
+	/// The message returned by the model
+	/// </summary>
+	[JsonPropertyName("message")]
+	public Message Message { get; set; } = null!;
+
+	/// <summary>
+	/// Whether the response is complete
+	/// </summary>
+	[JsonPropertyName("done")]
+	public bool Done { get; set; }
+
+	/// <summary>
+	/// Total duration to process the prompt
+	/// </summary>
+	[JsonPropertyName("total_duration")]
+	public int TotalDuration { get; set; }
+
+	/// <summary>
+	/// Duration to load the model
+	/// </summary>
+	[JsonPropertyName("load_duration")]
+	public int LoadDuration { get; set; }
+
+	/// <summary>
+	/// Prompt evaluation steps
+	/// </summary>
+	[JsonPropertyName("prompt_eval_count")]
+	public int PromptEvalCount { get; set; }
+
+	/// <summary>
+	/// Prompt evaluation duration
+	/// </summary>
+	[JsonPropertyName("prompt_eval_duration")]
+	public int PromptEvalDuration { get; set; }
+
+	/// <summary>
+	/// Evaluation duration
+	/// </summary>
+	[JsonPropertyName("eval_count")]
+	public int EvalCount { get; set; }
+
+	/// <summary>
+	/// Evaluation duration
+	/// </summary>
+	[JsonPropertyName("eval_duration")]
+	public int EvalDuration { get; set; }
+}
\ No newline at end of file
diff --git a/src/OllamaApiClient.cs b/src/OllamaApiClient.cs
index 338f479..ece5beb 100644
--- a/src/OllamaApiClient.cs
+++ b/src/OllamaApiClient.cs
@@ -180,6 +180,22 @@ public Task<EmbedResponse> Embed(EmbedRequest request, CancellationToken cancellationToken = default)
 			yield return result;
 	}
 
+	/// <inheritdoc />
+	public async Task<ChatResponse> ChatAsync(ChatRequest request, CancellationToken cancellationToken = default)
+	{
+		request.Stream = false;
+		var requestMessage = new HttpRequestMessage(HttpMethod.Post, "api/chat")
+		{
+			Content = new StringContent(JsonSerializer.Serialize(request, OutgoingJsonSerializerOptions), Encoding.UTF8, "application/json")
+		};
+
+		var completion = HttpCompletionOption.ResponseContentRead;
+
+		using var response = await SendToOllamaAsync(requestMessage, request, completion, cancellationToken);
+
+		return await ProcessChatResponseAsync(response, cancellationToken);
+	}
+
 	/// <inheritdoc />
 	public async Task<bool> IsRunning(CancellationToken cancellationToken = default)
 	{
@@ -308,6 +324,12 @@ private async Task<TResponse> PostAsync<TRequest, TResponse>(string endpoint, TRequest request, CancellationToken cancellationToken)
 		}
 	}
 
+	private async Task<ChatResponse> ProcessChatResponseAsync(HttpResponseMessage response, CancellationToken cancellationToken)
+	{
+		var stream = await response.Content.ReadAsStringAsync();
+		return JsonSerializer.Deserialize<ChatResponse>(stream, IncomingJsonSerializerOptions);
+	}
+
 	/// <summary>
 	/// Sends a http request message to the Ollama API.
 	/// </summary>
diff --git a/test/ChatTests.cs b/test/ChatTests.cs
index 1e26d20..dd0561f 100644
--- a/test/ChatTests.cs
+++ b/test/ChatTests.cs
@@ -83,6 +83,19 @@ public async Task Sends_Messages_As_User()
 		chat.Messages.First().Role.Should().Be(ChatRole.User);
 		chat.Messages.First().Content.Should().Be("henlo");
 	}
+
+	[Test]
+	public async Task Sends_Message_One_Reply()
+	{
+		var expectedResponse = new ChatResponse { Message = CreateMessage(ChatRole.Assistant, "Pong!") };
+		_ollama.SetExpectedChatResponse(expectedResponse);
+
+		var answer = await _ollama.ChatAsync(new ChatRequest {Messages = new []{CreateMessage(ChatRole.User, "Ping!")}});
+
+		answer.Message.Content.Should().Be(expectedResponse.Message.Content);
+
+		answer.Message.Role.Should().Be(ChatRole.Assistant);
+	}
 }
 
 public class SendAsMethod : ChatTests
diff --git a/test/TestOllamaApiClient.cs b/test/TestOllamaApiClient.cs
index 4fd5b8b..997b0b5 100644
--- a/test/TestOllamaApiClient.cs
+++ b/test/TestOllamaApiClient.cs
@@ -11,6 +11,7 @@ public class TestOllamaApiClient : IOllamaApiClient
 {
 	private ChatResponseStream[] _expectedChatResponses = [];
 	private GenerateResponseStream[] _expectedGenerateResponses = [];
+	private ChatResponse _expectedChatResponse;
 
 	public string SelectedModel { get; set; } = string.Empty;
 
@@ -24,6 +25,11 @@ internal void SetExpectedGenerateResponses(params GenerateResponseStream[] responses)
 		_expectedGenerateResponses = responses;
 	}
 
+	internal void SetExpectedChatResponse(ChatResponse chatResponse)
+	{
+		_expectedChatResponse = chatResponse;
+	}
+
 	public async IAsyncEnumerable<ChatResponseStream?> Chat(ChatRequest request, [EnumeratorCancellation] CancellationToken cancellationToken = default)
 	{
 		foreach (var response in _expectedChatResponses)
@@ -33,6 +39,11 @@ internal void SetExpectedGenerateResponses(params GenerateResponseStream[] responses)
 		{
 			await Task.Yield();
 			yield return response;
 		}
 	}
+
+	public Task<ChatResponse> ChatAsync(ChatRequest request, CancellationToken cancellationToken = default)
+	{
+		return Task.FromResult(_expectedChatResponse);
+	}
+
 	public Task CopyModel(CopyModelRequest request, CancellationToken cancellationToken = default)
 	{
 		throw new NotImplementedException();