diff --git a/src/backend/CosmosTodoItemAITrigger.cs b/src/backend/CosmosTodoItemAITrigger.cs new file mode 100644 index 0000000..22c4d44 --- /dev/null +++ b/src/backend/CosmosTodoItemAITrigger.cs @@ -0,0 +1,108 @@ +using System.Collections.Generic; +using Microsoft.Azure.WebJobs; +using backend.Models; +using System.Threading.Tasks; +using Microsoft.Azure.WebJobs.Extensions.DurableTask; +using Azure.AI.OpenAI; +using Azure; +using System; +using System.Linq; + +namespace backend +{ + public class CosmosTodoItemAITrigger + { + private readonly OpenAIClient _openAIClient; + + public CosmosTodoItemAITrigger(OpenAIClient openAIClient) + { + _openAIClient = openAIClient; + } + + [FunctionName("CosmosTodoItemAITrigger")] + public async Task Run( + [CosmosDBTrigger( + databaseName: "%CosmosDatabaseName%", + containerName: "TodoItem", + Connection = "CosmosConnectionOptions", + LeaseContainerName = "Leases", + StartFromBeginning = true, + LeaseContainerPrefix = "ai", + CreateLeaseContainerIfNotExists = false)] IReadOnlyList input, + [CosmosDB( + databaseName: "%CosmosDatabaseName%", + containerName: "TodoItem", + Connection = "CosmosConnectionOptions")] + IAsyncCollector output, + [DurableClient] IDurableEntityClient durableEntityClient) + { + var exceptions = new List(); + + if (input != null && input.Count > 0) + { + foreach (var item in input) + { + try + { + var piratized = await Piratize(output, item); + + if (piratized) + { + await UpdateDbDescription(output, item); + } + } + catch (Exception e) + { + // We need to keep processing the rest of the batch - capture this exception and continue. + // Also, consider capturing details of the message that failed processing so it can be processed again later. + exceptions.Add(e); + } + } + } + + // Once processing of the batch is complete, if any messages in the batch failed processing throw an exception so that there is a record of the failure. + + if (exceptions.Count > 1) + throw new AggregateException(exceptions); + + if (exceptions.Count == 1) + throw exceptions.Single(); + } + + private async Task Piratize(IAsyncCollector output, TodoItem item) + { + if (!string.IsNullOrWhiteSpace(item.Description) && item.Description.StartsWith("[ASSISTANT]")) + { + return false; + } + + var input = string.IsNullOrWhiteSpace(item.Description) ? item.Name : item.Description; + + var chatCompletionsOptions = new ChatCompletionsOptions() + { + DeploymentName = "gpt-35-turbo-16k", + Messages = + { + new ChatRequestSystemMessage("You are a helpful assistant. You will talk like a pirate. Rephrase, fix grammer, and create a todo task."), + + new ChatRequestUserMessage(input), + } + }; + + Response response = await _openAIClient.GetChatCompletionsAsync(chatCompletionsOptions); + ChatResponseMessage responseMessage = response.Value.Choices[0].Message; + + item.Description = $"[ASSISTANT]: {responseMessage.Content}"; + + return true; + } + + private async Task UpdateDbDescription(IAsyncCollector output, TodoItem item) + { + item.UpdatedDate = DateTimeOffset.UtcNow.DateTime; + await output.AddAsync(item); + + return true; + } + } +} diff --git a/src/backend/Startup.cs b/src/backend/Startup.cs index b6229ca..6d700e7 100644 --- a/src/backend/Startup.cs +++ b/src/backend/Startup.cs @@ -1,6 +1,8 @@ using System; using System.Net.Http; using System.Threading.Tasks; +using Azure; +using Azure.AI.OpenAI; using Azure.Identity; using Azure.Security.KeyVault.Secrets; using Microsoft.Azure.Cosmos; @@ -44,6 +46,24 @@ public override void Configure(IFunctionsHostBuilder builder) return new CosmosClientBuilder(accountEndpoint, defaultAzureCredentials).Build(); }) + .AddSingleton(serviceProvider => + { + var azureOpenAIEndpoint = Environment.GetEnvironmentVariable("AzureOpenAIEndpoint"); + Uri azureOpenAIResourceUri = new(azureOpenAIEndpoint); + + var keyVaultUri = new Uri(Environment.GetEnvironmentVariable("KeyVaultEndpoint")); + var httpClient = serviceProvider.GetService(); + var defaultAzureCredentials = serviceProvider.GetService(); + + // Retrieve apiKey from KeyVault + var secretClient = new SecretClient(keyVaultUri, defaultAzureCredentials); + var azureResponseKeyVaultSecret = new Lazy>>(async () => await secretClient.GetSecretAsync("AZURE-OPEN-AI-KEY")); + var openAIApiKey = azureResponseKeyVaultSecret.Value.Result.Value.Value; + + AzureKeyCredential azureOpenAIApiKey = new(openAIApiKey); + + return new OpenAIClient(azureOpenAIResourceUri, azureOpenAIApiKey); + }) .AddHealthChecks(); } } diff --git a/src/backend/Todo.Backend.csproj b/src/backend/Todo.Backend.csproj index 869bde1..9fc19d6 100644 --- a/src/backend/Todo.Backend.csproj +++ b/src/backend/Todo.Backend.csproj @@ -4,6 +4,7 @@ v4 + diff --git a/src/backend/local.settings.json b/src/backend/local.settings.json index 5b374a7..640e00a 100644 --- a/src/backend/local.settings.json +++ b/src/backend/local.settings.json @@ -8,6 +8,7 @@ "EventHubRequestsName": "", "EventHubRequestsConnectionOptions__fullyQualifiedNamespace": "", "EventHubRequestsConsumerGroup": "", - "KeyVaultEndpoint": "" + "KeyVaultEndpoint": "", + "AzureOpenAIEndpoint": "" } } \ No newline at end of file