From fa88646301693276c6ab61cf5fba6507436a4afa Mon Sep 17 00:00:00 2001 From: Xiaoyun Zhang Date: Tue, 23 Jul 2024 09:59:05 -0700 Subject: [PATCH] [.Net] Add a constructor which takes ChatCompletionOptions for OpenAIChatAgent (#3170) * accept ChatCompletionOptions in constrcutor * fix comment --- .../AutoGen.OpenAI/Agent/OpenAIChatAgent.cs | 96 ++++++++++++++----- .../OpenAIChatAgentTest.cs | 64 ++++++++++--- 2 files changed, 123 insertions(+), 37 deletions(-) diff --git a/dotnet/src/AutoGen.OpenAI/Agent/OpenAIChatAgent.cs b/dotnet/src/AutoGen.OpenAI/Agent/OpenAIChatAgent.cs index b192cde1024b..4608a416feda 100644 --- a/dotnet/src/AutoGen.OpenAI/Agent/OpenAIChatAgent.cs +++ b/dotnet/src/AutoGen.OpenAI/Agent/OpenAIChatAgent.cs @@ -5,6 +5,7 @@ using System.Collections.Generic; using System.Linq; using System.Runtime.CompilerServices; +using System.Text.Json; using System.Threading; using System.Threading.Tasks; using AutoGen.OpenAI.Extension; @@ -32,13 +33,8 @@ namespace AutoGen.OpenAI; public class OpenAIChatAgent : IStreamingAgent { private readonly OpenAIClient openAIClient; - private readonly string modelName; - private readonly float _temperature; - private readonly int _maxTokens = 1024; - private readonly IEnumerable? _functions; - private readonly string _systemMessage; - private readonly ChatCompletionsResponseFormat? _responseFormat; - private readonly int? _seed; + private readonly ChatCompletionsOptions options; + private readonly string systemMessage; /// /// Create a new instance of . @@ -62,16 +58,36 @@ public OpenAIChatAgent( int? seed = null, ChatCompletionsResponseFormat? responseFormat = null, IEnumerable? functions = null) + : this( + openAIClient: openAIClient, + name: name, + options: CreateChatCompletionOptions(modelName, temperature, maxTokens, seed, responseFormat, functions), + systemMessage: systemMessage) { + } + + /// + /// Create a new instance of . + /// + /// openai client + /// agent name + /// system message + /// chat completion option. The option can't contain messages + public OpenAIChatAgent( + OpenAIClient openAIClient, + string name, + ChatCompletionsOptions options, + string systemMessage = "You are a helpful AI assistant") + { + if (options.Messages is { Count: > 0 }) + { + throw new ArgumentException("Messages should not be provided in options"); + } + this.openAIClient = openAIClient; - this.modelName = modelName; this.Name = name; - _temperature = temperature; - _maxTokens = maxTokens; - _functions = functions; - _systemMessage = systemMessage; - _responseFormat = responseFormat; - _seed = seed; + this.options = options; + this.systemMessage = systemMessage; } public string Name { get; } @@ -116,22 +132,25 @@ private ChatCompletionsOptions CreateChatCompletionsOptions(GenerateReplyOptions // add system message if there's no system message in messages if (!oaiMessages.Any(m => m is ChatRequestSystemMessage)) { - oaiMessages = new[] { new ChatRequestSystemMessage(_systemMessage) }.Concat(oaiMessages); + oaiMessages = new[] { new ChatRequestSystemMessage(systemMessage) }.Concat(oaiMessages); } - var settings = new ChatCompletionsOptions(this.modelName, oaiMessages) + // clone the options by serializing and deserializing + var json = JsonSerializer.Serialize(this.options); + var settings = JsonSerializer.Deserialize(json) ?? throw new InvalidOperationException("Failed to clone options"); + + foreach (var m in oaiMessages) { - MaxTokens = options?.MaxToken ?? _maxTokens, - Temperature = options?.Temperature ?? _temperature, - ResponseFormat = _responseFormat, - Seed = _seed, - }; + settings.Messages.Add(m); + } + + settings.Temperature = options?.Temperature ?? settings.Temperature; + settings.MaxTokens = options?.MaxToken ?? settings.MaxTokens; - var openAIFunctionDefinitions = options?.Functions?.Select(f => f.ToOpenAIFunctionDefinition()); - var functions = openAIFunctionDefinitions ?? _functions; - if (functions is not null && functions.Count() > 0) + var openAIFunctionDefinitions = options?.Functions?.Select(f => f.ToOpenAIFunctionDefinition()).ToList(); + if (openAIFunctionDefinitions is { Count: > 0 }) { - foreach (var f in functions) + foreach (var f in openAIFunctionDefinitions) { settings.Tools.Add(new ChatCompletionsFunctionToolDefinition(f)); } @@ -147,4 +166,31 @@ private ChatCompletionsOptions CreateChatCompletionsOptions(GenerateReplyOptions return settings; } + + private static ChatCompletionsOptions CreateChatCompletionOptions( + string modelName, + float temperature = 0.7f, + int maxTokens = 1024, + int? seed = null, + ChatCompletionsResponseFormat? responseFormat = null, + IEnumerable? functions = null) + { + var options = new ChatCompletionsOptions(modelName, []) + { + Temperature = temperature, + MaxTokens = maxTokens, + Seed = seed, + ResponseFormat = responseFormat, + }; + + if (functions is not null) + { + foreach (var f in functions) + { + options.Tools.Add(new ChatCompletionsFunctionToolDefinition(f)); + } + } + + return options; + } } diff --git a/dotnet/test/AutoGen.OpenAI.Tests/OpenAIChatAgentTest.cs b/dotnet/test/AutoGen.OpenAI.Tests/OpenAIChatAgentTest.cs index 8ff66f5c86bf..85f898547b00 100644 --- a/dotnet/test/AutoGen.OpenAI.Tests/OpenAIChatAgentTest.cs +++ b/dotnet/test/AutoGen.OpenAI.Tests/OpenAIChatAgentTest.cs @@ -28,10 +28,8 @@ public async Task GetWeatherAsync(string location) [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")] public async Task BasicConversationTestAsync() { - var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable."); - var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new Exception("Please set AZURE_OPENAI_API_KEY environment variable."); var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new Exception("Please set AZURE_OPENAI_DEPLOY_NAME environment variable."); - var openaiClient = new OpenAIClient(new Uri(endpoint), new Azure.AzureKeyCredential(key)); + var openaiClient = CreateOpenAIClientFromAzureOpenAI(); var openAIChatAgent = new OpenAIChatAgent( openAIClient: openaiClient, name: "assistant", @@ -60,10 +58,8 @@ public async Task BasicConversationTestAsync() [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")] public async Task OpenAIChatMessageContentConnectorTestAsync() { - var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable."); - var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new Exception("Please set AZURE_OPENAI_API_KEY environment variable."); var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new Exception("Please set AZURE_OPENAI_DEPLOY_NAME environment variable."); - var openaiClient = new OpenAIClient(new Uri(endpoint), new Azure.AzureKeyCredential(key)); + var openaiClient = CreateOpenAIClientFromAzureOpenAI(); var openAIChatAgent = new OpenAIChatAgent( openAIClient: openaiClient, name: "assistant", @@ -107,10 +103,8 @@ public async Task OpenAIChatMessageContentConnectorTestAsync() [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")] public async Task OpenAIChatAgentToolCallTestAsync() { - var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable."); - var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new Exception("Please set AZURE_OPENAI_API_KEY environment variable."); var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new Exception("Please set AZURE_OPENAI_DEPLOY_NAME environment variable."); - var openaiClient = new OpenAIClient(new Uri(endpoint), new Azure.AzureKeyCredential(key)); + var openaiClient = CreateOpenAIClientFromAzureOpenAI(); var openAIChatAgent = new OpenAIChatAgent( openAIClient: openaiClient, name: "assistant", @@ -176,10 +170,8 @@ public async Task OpenAIChatAgentToolCallTestAsync() [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")] public async Task OpenAIChatAgentToolCallInvokingTestAsync() { - var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable."); - var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new Exception("Please set AZURE_OPENAI_API_KEY environment variable."); var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new Exception("Please set AZURE_OPENAI_DEPLOY_NAME environment variable."); - var openaiClient = new OpenAIClient(new Uri(endpoint), new Azure.AzureKeyCredential(key)); + var openaiClient = CreateOpenAIClientFromAzureOpenAI(); var openAIChatAgent = new OpenAIChatAgent( openAIClient: openaiClient, name: "assistant", @@ -236,4 +228,52 @@ public async Task OpenAIChatAgentToolCallInvokingTestAsync() } } } + + [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")] + public async Task ItCreateOpenAIChatAgentWithChatCompletionOptionAsync() + { + var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new Exception("Please set AZURE_OPENAI_DEPLOY_NAME environment variable."); + var openaiClient = CreateOpenAIClientFromAzureOpenAI(); + var options = new ChatCompletionsOptions(deployName, []) + { + Temperature = 0.7f, + MaxTokens = 1, + }; + + var openAIChatAgent = new OpenAIChatAgent( + openAIClient: openaiClient, + name: "assistant", + options: options) + .RegisterMessageConnector(); + + var respond = await openAIChatAgent.SendAsync("hello"); + respond.GetContent()?.Should().NotBeNullOrEmpty(); + } + + [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")] + public async Task ItThrowExceptionWhenChatCompletionOptionContainsMessages() + { + var deployName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new Exception("Please set AZURE_OPENAI_DEPLOY_NAME environment variable."); + var openaiClient = CreateOpenAIClientFromAzureOpenAI(); + var options = new ChatCompletionsOptions(deployName, [new ChatRequestUserMessage("hi")]) + { + Temperature = 0.7f, + MaxTokens = 1, + }; + + var action = () => new OpenAIChatAgent( + openAIClient: openaiClient, + name: "assistant", + options: options) + .RegisterMessageConnector(); + + action.Should().ThrowExactly().WithMessage("Messages should not be provided in options"); + } + + private OpenAIClient CreateOpenAIClientFromAzureOpenAI() + { + var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable."); + var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new Exception("Please set AZURE_OPENAI_API_KEY environment variable."); + return new OpenAIClient(new Uri(endpoint), new Azure.AzureKeyCredential(key)); + } }