Skip to content

Commit

Permalink
Merge pull request #36 from betalgo/feature/edits
Browse files Browse the repository at this point in the history
support for Edit API and clean up and bug fixes
  • Loading branch information
kayhantolga authored Nov 30, 2022
2 parents ae650f8 + 48fdf39 commit 3de77df
Show file tree
Hide file tree
Showing 26 changed files with 271 additions and 55 deletions.
3 changes: 2 additions & 1 deletion OpenAI.Playground/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@
var sdk = serviceProvider.GetRequiredService<IOpenAIService>();

//await ModelTestHelper.FetchModelsTest(sdk);
await EditTestHelper.RunSimpleEditCreateTest(sdk);
//await ImageTestHelper.RunSimpleCreateImageTest(sdk);
//await ImageTestHelper.RunSimpleCreateImageEditTest(sdk);
await ImageTestHelper.RunSimpleCreateImageVariationTest(sdk);
//await ImageTestHelper.RunSimpleCreateImageVariationTest(sdk);
//await ModerationTestHelper.CreateModerationTest(sdk);
//await CompletionTestHelper.RunSimpleCompletionTest(sdk);
//await EmbeddingTestHelper.RunSimpleEmbeddingTest(sdk);
Expand Down
43 changes: 43 additions & 0 deletions OpenAI.Playground/TestHelpers/EditTestHelper.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
using OpenAI.GPT3.Interfaces;
using OpenAI.GPT3.ObjectModels;
using OpenAI.GPT3.ObjectModels.RequestModels;

namespace OpenAI.Playground.TestHelpers
{
internal static class EditTestHelper
{
public static async Task RunSimpleEditCreateTest(IOpenAIService sdk)
{
ConsoleExtensions.WriteLine("Edit Create Testing is starting:", ConsoleColor.Cyan);

try
{
ConsoleExtensions.WriteLine("Edit Create Test:", ConsoleColor.DarkCyan);
var completionResult = await sdk.Edit.CreateEdit(new EditCreateRequest()
{
Input = "What day of the wek is it?",
Instruction = "Fix the spelling mistakes"
}, Models.TextEditDavinciV1);

if (completionResult.Successful)
{
Console.WriteLine(completionResult.Choices.FirstOrDefault());
}
else
{
if (completionResult.Error == null)
{
throw new Exception("Unknown Error");
}

Console.WriteLine($"{completionResult.Error.Code}: {completionResult.Error.Message}");
}
}
catch (Exception e)
{
Console.WriteLine(e);
throw;
}
}
}
}
14 changes: 7 additions & 7 deletions OpenAI.Playground/TestHelpers/ImageTestHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,19 +42,19 @@ public static async Task RunSimpleCreateImageTest(IOpenAIService sdk)
Console.WriteLine(e);
throw;
}
}
}

public static async Task RunSimpleCreateImageEditTest(IOpenAIService sdk)
{
ConsoleExtensions.WriteLine("Image Edit Create Testing is starting:", ConsoleColor.Cyan);
const string maskFileName = "image_edit_mask.png";
const string originalFileName = "image_edit_original.png";

// Images should be in png format with ARGB. I got help from this website to generate sample mask
// https://www.online-image-editor.com/
var maskFile = await File.ReadAllBytesAsync($"SampleData/{maskFileName}");
var originalFile = await File.ReadAllBytesAsync($"SampleData/{originalFileName}");

try
{
ConsoleExtensions.WriteLine("Image Edit Create Test:", ConsoleColor.DarkCyan);
Expand Down Expand Up @@ -93,13 +93,13 @@ public static async Task RunSimpleCreateImageEditTest(IOpenAIService sdk)
}
}

public static async Task RunSimpleCreateImageVariationTest(IOpenAIService sdk)
public static async Task RunSimpleCreateImageVariationTest(IOpenAIService sdk)
{
ConsoleExtensions.WriteLine("Image Variation Create Testing is starting:", ConsoleColor.Cyan);
const string originalFileName = "image_edit_original.png";

var originalFile = await File.ReadAllBytesAsync($"SampleData/{originalFileName}");

try
{
ConsoleExtensions.WriteLine("Image Variation Create Test:", ConsoleColor.DarkCyan);
Expand Down
30 changes: 30 additions & 0 deletions OpenAI.SDK/Interfaces/IEditService.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
using OpenAI.GPT3.ObjectModels;
using OpenAI.GPT3.ObjectModels.RequestModels;
using OpenAI.GPT3.ObjectModels.ResponseModels;

namespace OpenAI.GPT3.Interfaces;

/// <summary>
/// Given a prompt and an instruction, the model will return an edited version of the prompt.
/// </summary>
public interface IEditService
{
/// <summary>
/// Creates a new edit for the provided input, instruction, and parameters
/// </summary>
/// <param name="editCreate"></param>
/// <param name="engineId">The ID of the engine to use for this request</param>
/// <returns></returns>
Task<EditCreateResponse> CreateEdit(EditCreateRequest editCreate, string? engineId = null);

/// <summary>
/// Creates a new edit for the provided input, instruction, and parameters
/// </summary>
/// <param name="editCreate"></param>
/// <param name="engineId">The ID of the engine to use for this request</param>
/// <returns></returns>
Task<EditCreateResponse> Edit(EditCreateRequest editCreate, Models.Model engineId)
{
return CreateEdit(editCreate, engineId.EnumToString());
}
}
9 changes: 7 additions & 2 deletions OpenAI.SDK/Interfaces/IOpenAIService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,13 @@ public interface IOpenAIService
/// <summary>
/// Given a prompt and/or an input image, the model will generate a new image.
/// </summary>
public IImageService Image { get; }

public IImageService Image { get; }

/// <summary>
/// Creates a new edit for the provided input, instruction, and parameters
/// </summary>
public IEditService Edit { get; }


/// <summary>
/// Set default engine
Expand Down
22 changes: 22 additions & 0 deletions OpenAI.SDK/Managers/OpenAIEdits.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
using OpenAI.GPT3.Extensions;
using OpenAI.GPT3.Interfaces;
using OpenAI.GPT3.ObjectModels.RequestModels;
using OpenAI.GPT3.ObjectModels.ResponseModels;

namespace OpenAI.GPT3.Managers;

public partial class OpenAIService : IEditService
{
public async Task<EditCreateResponse> CreateEdit(EditCreateRequest editCreate, string? engineId = null)
{
if (editCreate.Model != null && engineId != null)
{
throw new ArgumentException("You cannot specify both a model and an engineId");
}
else if (editCreate.Model == null && engineId != null)
{
editCreate.Model = ProcessEngineId(engineId);
}
return await _httpClient.PostAndReadAsAsync<EditCreateResponse>(_endpointProvider.EditCreate(), editCreate);
}
}
16 changes: 8 additions & 8 deletions OpenAI.SDK/Managers/OpenAIImage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public async Task<ImageCreateResponse> CreateImage(ImageCreateRequest imageCreat
}

/// <summary>
/// Creates an edited or extended image given an original image and a prompt.
/// Creates an edited or extended image given an original image and a prompt.
/// </summary>
/// <param name="imageEditCreateRequest"></param>
/// <returns></returns>
Expand All @@ -27,9 +27,9 @@ public async Task<ImageCreateResponse> CreateImageEdit(ImageEditCreateRequest im
var multipartContent = new MultipartFormDataContent();
if (imageEditCreateRequest.User != null) multipartContent.Add(new StringContent(imageEditCreateRequest.User), "user");
if (imageEditCreateRequest.ResponseFormat != null) multipartContent.Add(new StringContent(imageEditCreateRequest.ResponseFormat), "response_format");
if (imageEditCreateRequest.Size!= null) multipartContent.Add(new StringContent(imageEditCreateRequest.Size), "size");
if (imageEditCreateRequest.N!= null) multipartContent.Add(new StringContent(imageEditCreateRequest.N.ToString()!), "n");
if (imageEditCreateRequest.Size != null) multipartContent.Add(new StringContent(imageEditCreateRequest.Size), "size");
if (imageEditCreateRequest.N != null) multipartContent.Add(new StringContent(imageEditCreateRequest.N.ToString()!), "n");

multipartContent.Add(new StringContent(imageEditCreateRequest.Prompt), "prompt");
multipartContent.Add(new ByteArrayContent(imageEditCreateRequest.Image), "image", imageEditCreateRequest.ImageName);
multipartContent.Add(new ByteArrayContent(imageEditCreateRequest.Mask), "mask", imageEditCreateRequest.MaskName);
Expand All @@ -38,7 +38,7 @@ public async Task<ImageCreateResponse> CreateImageEdit(ImageEditCreateRequest im
}

/// <summary>
/// Creates a variation of a given image.
/// Creates a variation of a given image.
/// </summary>
/// <param name="imageEditCreateRequest"></param>
/// <returns></returns>
Expand All @@ -47,9 +47,9 @@ public async Task<ImageCreateResponse> CreateImageVariation(ImageVariationCreate
var multipartContent = new MultipartFormDataContent();
if (imageEditCreateRequest.User != null) multipartContent.Add(new StringContent(imageEditCreateRequest.User), "user");
if (imageEditCreateRequest.ResponseFormat != null) multipartContent.Add(new StringContent(imageEditCreateRequest.ResponseFormat), "response_format");
if (imageEditCreateRequest.Size!= null) multipartContent.Add(new StringContent(imageEditCreateRequest.Size), "size");
if (imageEditCreateRequest.N!= null) multipartContent.Add(new StringContent(imageEditCreateRequest.N.ToString()!), "n");
if (imageEditCreateRequest.Size != null) multipartContent.Add(new StringContent(imageEditCreateRequest.Size), "size");
if (imageEditCreateRequest.N != null) multipartContent.Add(new StringContent(imageEditCreateRequest.N.ToString()!), "n");

multipartContent.Add(new ByteArrayContent(imageEditCreateRequest.Image), "image", imageEditCreateRequest.ImageName);

return await _httpClient.PostFileAndReadAsAsync<ImageCreateResponse>(_endpointProvider.ImageVariationCreate(), multipartContent);
Expand Down
1 change: 1 addition & 0 deletions OpenAI.SDK/Managers/OpenAIService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ public OpenAIService(OpenAiOptions settings, HttpClient? httpClient = null)
public IFineTuneService FineTunes => this;
public IModerationService Moderation => this;
public IImageService Image => this;
public IEditService Edit => this;

public void SetDefaultEngineId(string engineId)
{
Expand Down
17 changes: 16 additions & 1 deletion OpenAI.SDK/ObjectModels/Models.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ public enum Model
TextDavinciV1,

TextDavinciV2,
TextDavinciV3,

CurieInstructBeta,
DavinciInstructBeta,
Expand All @@ -45,6 +46,9 @@ public enum Model
TextSearchCurieQueryV1,
TextSearchDavinciQueryV1,

TextEditDavinciV1,
CodeEditDavinciV1,

CodeSearchAdaCodeV1,
CodeSearchBabbageCodeV1,

Expand All @@ -67,7 +71,9 @@ public enum Subject
TextSearchQuery,
CodeSearchCode,
CodeSearchText,
Code
Code,
CodeEdit,
Edit
}

public static string Ada => "ada";
Expand All @@ -80,6 +86,7 @@ public enum Subject

public static string TextDavinciV1 => ModelNameBuilder(BaseEngine.Davinci, Subject.Text, "001");
public static string TextDavinciV2 => ModelNameBuilder(BaseEngine.Davinci, Subject.Text, "002");
public static string TextDavinciV3 => ModelNameBuilder(BaseEngine.Davinci, Subject.Text, "003");
public static string TextAdaV1 => ModelNameBuilder(BaseEngine.Ada, Subject.Text, "001");
public static string TextBabbageV1 => ModelNameBuilder(BaseEngine.Babbage, Subject.Text, "001");
public static string TextCurieV1 => ModelNameBuilder(BaseEngine.Curie, Subject.Text, "001");
Expand All @@ -104,6 +111,9 @@ public enum Subject
public static string TextSearchCurieQueryV1 => ModelNameBuilder(BaseEngine.Curie, Subject.TextSearchQuery, "001");
public static string TextSearchDavinciQueryV1 => ModelNameBuilder(BaseEngine.Davinci, Subject.TextSearchQuery, "001");

public static string TextEditDavinciV1 => ModelNameBuilder(BaseEngine.Davinci, Subject.Edit, "001");
public static string CodeEditDavinciV1 => ModelNameBuilder(BaseEngine.Davinci, Subject.CodeEdit, "001");

public static string CodeSearchAdaCodeV1 => ModelNameBuilder(BaseEngine.Ada, Subject.CodeSearchCode, "001");
public static string CodeSearchBabbageCodeV1 => ModelNameBuilder(BaseEngine.Babbage, Subject.CodeSearchCode, "001");
public static string CodeSearchAdaTextV1 => ModelNameBuilder(BaseEngine.Ada, Subject.CodeSearchText, "001");
Expand Down Expand Up @@ -147,6 +157,7 @@ public static string EnumToString(this Model engine)
Model.DavinciInstructBeta => DavinciInstructBeta,
Model.TextDavinciV1 => TextDavinciV1,
Model.TextDavinciV2 => TextDavinciV2,
Model.TextDavinciV3 => TextDavinciV3,
Model.TextAdaV1 => TextAdaV1,
Model.TextBabbageV1 => TextBabbageV1,
Model.TextCurieV1 => TextCurieV1,
Expand All @@ -170,6 +181,8 @@ public static string EnumToString(this Model engine)
Model.CodeSearchBabbageCodeV1 => CodeSearchBabbageCodeV1,
Model.CodeSearchAdaTextV1 => CodeSearchAdaTextV1,
Model.CodeSearchBabbageTextV1 => CodeSearchBabbageTextV1,
Model.TextEditDavinciV1 => TextEditDavinciV1,
Model.CodeEditDavinciV1 => CodeEditDavinciV1,
_ => throw new ArgumentOutOfRangeException(nameof(engine), engine, null)
};
}
Expand Down Expand Up @@ -201,6 +214,8 @@ public static string EnumToString(this Subject subject, string baseEngine)
Subject.CodeSearchCode => "code-search-{0}-code",
Subject.CodeSearchText => "code-search-{0}-text",
Subject.Code => "code-{0}",
Subject.CodeEdit => "code-{0}-edit",
Subject.Edit => "text-{0}-edit",
_ => throw new ArgumentOutOfRangeException(nameof(subject), subject, null)
}, baseEngine);
}
Expand Down
12 changes: 6 additions & 6 deletions OpenAI.SDK/ObjectModels/RequestModels/CompletionCreateRequest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,6 @@ public record CompletionCreateRequest : IModelValidate, IOpenAiModels.ITemperatu
[JsonPropertyName("logit_bias")]
public object? LogitBias { get; set; }

/// <summary>
/// A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse.
/// </summary>
[JsonPropertyName("user")]
public string? User { get; set; }

/// <summary>
/// Include the log probabilities on the logprobs most likely tokens, as well the chosen tokens. For example, if
/// logprobs is 5, the API will return a list of the 5 most likely tokens. The API will always return the logprob of
Expand Down Expand Up @@ -153,5 +147,11 @@ public IEnumerable<ValidationResult> Validate()
/// <see cref="https://beta.openai.com/docs/api-reference/completions/create#completions/create-temperature" />
[JsonPropertyName("temperature")]
public float? Temperature { get; set; }

/// <summary>
/// A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse.
/// </summary>
[JsonPropertyName("user")]
public string? User { get; set; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

namespace OpenAI.GPT3.ObjectModels.RequestModels
{
public class CreateModerationRequest : IOpenAiModels.IModel
public record CreateModerationRequest : IOpenAiModels.IModel
{
/// <summary>
/// The input text to classify
Expand Down
64 changes: 64 additions & 0 deletions OpenAI.SDK/ObjectModels/RequestModels/EditCreateRequest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
using System.ComponentModel.DataAnnotations;
using System.Text.Json.Serialization;
using OpenAI.GPT3.Interfaces;
using OpenAI.GPT3.ObjectModels.SharedModels;

namespace OpenAI.GPT3.ObjectModels.RequestModels
{
//TODO add model validation
//TODO check what is string or array for prompt,..
/// <summary>
/// Create Edit Request Model
/// </summary>
public record EditCreateRequest : IModelValidate, IOpenAiModels.ITemperature, IOpenAiModels.IModel
{
/// <summary>
/// The input text to use as a starting point for the edit.
/// </summary>
[JsonPropertyName("input")]
public string? Input { get; set; }

/// <summary>
/// The instruction that tells the model how to edit the prompt.
/// </summary>
[JsonPropertyName("instruction")]
public string Instruction { get; set; }

/// <summary>
/// Defaults to 1
/// How many completions to generate for each prompt.
/// Note: Because this parameter generates many completions, it can quickly consume your token quota.Use carefully and
/// ensure that you have reasonable settings for max_tokens and stop.
/// </summary>
[JsonPropertyName("n")]
public int? N { get; set; }

/// <summary>
/// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the
/// tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are
/// considered.
/// We generally recommend altering this or temperature but not both.
/// </summary>
[JsonPropertyName("top_p")]
public float? TopP { get; set; }


[JsonPropertyName("model")] public string? Model { get; set; }

public IEnumerable<ValidationResult> Validate()
{
throw new NotImplementedException();
}

/// <summary>
/// What
/// <a href="https://towardsdatascience.com/how-to-sample-from-language-models-682bceb97277">sampling temperature</a>
/// to use. Higher values means the model will take more risks. Try 0.9 for more creative
/// applications, and 0 (argmax sampling) for ones with a well-defined answer.
/// We generally recommend altering this or top_p but not both.
/// </summary>
/// <see cref="https://beta.openai.com/docs/api-reference/completions/create#completions/create-temperature" />
[JsonPropertyName("temperature")]
public float? Temperature { get; set; }
}
}
Loading

0 comments on commit 3de77df

Please sign in to comment.