Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added a simpler one-time chat api #83

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/IOllamaApiClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,17 @@ public interface IOllamaApiClient
/// To implement a fully interactive chat, you should make use of the Chat class with "new Chat(...)"
/// </remarks>
IAsyncEnumerable<ChatResponseStream?> Chat(ChatRequest request, CancellationToken cancellationToken = default);


/// <summary>
/// Sends a non-streaming request to the /api/chat endpoint and returns the response.
/// </summary>
/// <remarks>
/// Use this for simple one-shot request/response interactions; for streaming, use <see cref="Chat"/>.
/// </remarks>
/// <param name="request">The request to send to Ollama</param>
/// <param name="cancellationToken">The token to cancel the operation with</param>
/// <returns>
/// A ChatResponse which contains the Message from the /api/chat endpoint.
/// </returns>
Task<ChatResponse> ChatAsync(ChatRequest request, CancellationToken cancellationToken = default);

/// <summary>
/// Sends a request to the /api/copy endpoint to copy a model
Expand Down
67 changes: 67 additions & 0 deletions src/Models/Chat/ChatResponse.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
using System;
using System.Text.Json.Serialization;

namespace OllamaSharp.Models.Chat;

/// <summary>
/// The non-streaming response returned by the /api/chat endpoint.
/// </summary>
public class ChatResponse
{
	/// <summary>
	/// The model that generated the response
	/// </summary>
	[JsonPropertyName("model")]
	public string Model { get; set; } = null!;

	/// <summary>
	/// The time the response was generated
	/// </summary>
	[JsonPropertyName("created_at")]
	public string CreatedAt { get; set; } = null!;

	/// <summary>
	/// The message returned by the model
	/// </summary>
	[JsonPropertyName("message")]
	public Message Message { get; set; } = null!;

	/// <summary>
	/// Whether the response is complete
	/// </summary>
	[JsonPropertyName("done")]
	public bool Done { get; set; }

	/// <summary>
	/// Total time spent generating the response, in nanoseconds.
	/// Declared as long: nanosecond durations overflow Int32 after ~2.1 seconds.
	/// </summary>
	[JsonPropertyName("total_duration")]
	public long TotalDuration { get; set; }

	/// <summary>
	/// Time spent loading the model, in nanoseconds
	/// </summary>
	[JsonPropertyName("load_duration")]
	public long LoadDuration { get; set; }

	/// <summary>
	/// Number of tokens in the prompt
	/// </summary>
	[JsonPropertyName("prompt_eval_count")]
	public int PromptEvalCount { get; set; }

	/// <summary>
	/// Time spent evaluating the prompt, in nanoseconds
	/// </summary>
	[JsonPropertyName("prompt_eval_duration")]
	public long PromptEvalDuration { get; set; }

	/// <summary>
	/// Number of tokens in the response
	/// </summary>
	[JsonPropertyName("eval_count")]
	public int EvalCount { get; set; }

	/// <summary>
	/// Time spent generating the response, in nanoseconds
	/// </summary>
	[JsonPropertyName("eval_duration")]
	public long EvalDuration { get; set; }
}
22 changes: 22 additions & 0 deletions src/OllamaApiClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,22 @@ public Task<EmbedResponse> Embed(EmbedRequest request, CancellationToken cancell
yield return result;
}

/// <summary>
/// Sends a non-streaming request to the /api/chat endpoint and returns the complete response.
/// </summary>
/// <param name="request">The request to send to Ollama; its Stream flag is forced to false.</param>
/// <param name="cancellationToken">The token to cancel the operation with</param>
/// <returns>The deserialized ChatResponse from the /api/chat endpoint.</returns>
/// <exception cref="InvalidOperationException">Thrown when the response body cannot be deserialized.</exception>
public async Task<ChatResponse> ChatAsync(ChatRequest request, CancellationToken cancellationToken = default)
{
	// A single complete answer is wanted, not a token stream.
	request.Stream = false;

	// Dispose the request message once the call completes.
	using var requestMessage = new HttpRequestMessage(HttpMethod.Post, "api/chat")
	{
		Content = new StringContent(JsonSerializer.Serialize(request, OutgoingJsonSerializerOptions), Encoding.UTF8, "application/json")
	};

	// The whole body is required before deserialization, so buffer the content.
	using var response = await SendToOllamaAsync(requestMessage, request, HttpCompletionOption.ResponseContentRead, cancellationToken).ConfigureAwait(false);

	// ProcessChatResponseAsync may yield null for an empty/invalid body;
	// surface that explicitly instead of returning null from a non-nullable Task<ChatResponse>.
	return await ProcessChatResponseAsync(response, cancellationToken).ConfigureAwait(false)
		?? throw new InvalidOperationException("The response from the /api/chat endpoint could not be deserialized.");
}

/// <inheritdoc />
public async Task<bool> IsRunning(CancellationToken cancellationToken = default)
{
Expand Down Expand Up @@ -308,6 +324,12 @@ private async Task<TResponse> PostAsync<TRequest, TResponse>(string endpoint, TR
}
}

/// <summary>
/// Reads the buffered response body and deserializes it into a <see cref="ChatResponse"/>.
/// </summary>
/// <param name="response">The completed HTTP response from the /api/chat endpoint.</param>
/// <param name="cancellationToken">The token to cancel the operation with</param>
/// <returns>The deserialized response, or null when the body deserializes to null.</returns>
private async Task<ChatResponse?> ProcessChatResponseAsync(HttpResponseMessage response, CancellationToken cancellationToken)
{
	// NOTE(review): the parameterless ReadAsStringAsync overload cannot observe
	// cancellationToken; a token-taking overload exists on .NET 5+ — confirm target framework.
	var json = await response.Content.ReadAsStringAsync().ConfigureAwait(false);
	return JsonSerializer.Deserialize<ChatResponse>(json, IncomingJsonSerializerOptions);
}

/// <summary>
/// Sends a http request message to the Ollama API.
/// </summary>
Expand Down
13 changes: 13 additions & 0 deletions test/ChatTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,19 @@ public async Task Sends_Messages_As_User()
chat.Messages.First().Role.Should().Be(ChatRole.User);
chat.Messages.First().Content.Should().Be("henlo");
}

[Test]
public async Task Sends_Message_One_Reply()
{
	// Arrange: configure the fake client to answer with a fixed assistant message.
	var pong = CreateMessage(ChatRole.Assistant, "Pong!");
	_ollama.SetExpectedChatResponse(new ChatResponse { Message = pong });

	// Act: send a single user message through the one-shot API.
	var request = new ChatRequest { Messages = new[] { CreateMessage(ChatRole.User, "Ping!") } };
	var answer = await _ollama.ChatAsync(request);

	// Assert: the canned reply comes back unchanged.
	answer.Message.Content.Should().Be(pong.Content);

	answer.Message.Role.Should().Be(ChatRole.Assistant);
}
}

public class SendAsMethod : ChatTests
Expand Down
11 changes: 11 additions & 0 deletions test/TestOllamaApiClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ public class TestOllamaApiClient : IOllamaApiClient
{
private ChatResponseStream[] _expectedChatResponses = [];
private GenerateResponseStream[] _expectedGenerateResponses = [];
// null! silences CS8618: the field is assigned via SetExpectedChatResponse before ChatAsync is called.
private ChatResponse _expectedChatResponse = null!;

public string SelectedModel { get; set; } = string.Empty;

Expand All @@ -24,6 +25,11 @@ internal void SetExpectedGenerateResponses(params GenerateResponseStream[] respo
_expectedGenerateResponses = responses;
}

// Stores the canned response that ChatAsync will hand back to the test.
internal void SetExpectedChatResponse(ChatResponse chatResponse) => _expectedChatResponse = chatResponse;

public async IAsyncEnumerable<ChatResponseStream?> Chat(ChatRequest request, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
foreach (var response in _expectedChatResponses)
Expand All @@ -33,6 +39,11 @@ internal void SetExpectedGenerateResponses(params GenerateResponseStream[] respo
}
}

// Test stub: ignores the request and returns the response configured via SetExpectedChatResponse.
public Task<ChatResponse> ChatAsync(ChatRequest request, CancellationToken cancellationToken = default) => Task.FromResult(_expectedChatResponse);

public Task CopyModel(CopyModelRequest request, CancellationToken cancellationToken = default)
{
throw new NotImplementedException();
Expand Down
Loading