From 49c9d64624d23999ff0f71fd62ab929239e9ba30 Mon Sep 17 00:00:00 2001 From: HavenDV Date: Sat, 24 Aug 2024 18:09:35 +0400 Subject: [PATCH] feat: Updated to support easy image generation for FluxPro. --- src/helpers/FixOpenApiSpec/Program.cs | 118 +++++++++++++++++- .../Properties/launchSettings.json | 8 ++ .../Generated/JsonSerializerContextTypes.g.cs | 90 ++++++++----- .../Replicate.Models.PredictionResponse.g.cs | 95 ++++++++++++++ ...licate.Models.PredictionResponseInput.g.cs | 59 +++++++++ ...cate.Models.PredictionResponseMetrics.g.cs | 35 ++++++ ...plicate.Models.PredictionResponseUrls.g.cs | 29 +++++ ....ReplicateApi.ModelsPredictionsCreate.g.cs | 36 +++++- ...Replicate.ReplicateApi.PredictionsGet.g.cs | 32 ++++- .../Replicate/PredictionResponseExtensions.cs | 48 +++++++ src/libs/Replicate/openapi.yaml | 88 ++++++++++++- .../Tests.CreatePrediction.cs | 28 ++++- 12 files changed, 613 insertions(+), 53 deletions(-) create mode 100644 src/helpers/FixOpenApiSpec/Properties/launchSettings.json create mode 100644 src/libs/Replicate/Generated/Replicate.Models.PredictionResponse.g.cs create mode 100644 src/libs/Replicate/Generated/Replicate.Models.PredictionResponseInput.g.cs create mode 100644 src/libs/Replicate/Generated/Replicate.Models.PredictionResponseMetrics.g.cs create mode 100644 src/libs/Replicate/Generated/Replicate.Models.PredictionResponseUrls.g.cs create mode 100644 src/libs/Replicate/PredictionResponseExtensions.cs diff --git a/src/helpers/FixOpenApiSpec/Program.cs b/src/helpers/FixOpenApiSpec/Program.cs index ad9a6ea..8e6875b 100644 --- a/src/helpers/FixOpenApiSpec/Program.cs +++ b/src/helpers/FixOpenApiSpec/Program.cs @@ -1,19 +1,90 @@ +using System.Diagnostics.CodeAnalysis; using Microsoft.OpenApi; -using Microsoft.OpenApi.Any; -using Microsoft.OpenApi.Extensions; +using Microsoft.OpenApi.Extensions;using Microsoft.OpenApi.Models; using Microsoft.OpenApi.Readers; var path = args[0]; var text = await File.ReadAllTextAsync(path); text = text - .Replace("\"openapi\":\"3.1.0\"", "\"openapi\":\"3.0.1\"") + .Replace("\"openapi\":\"3.1.0\"", "\"openapi\":\"3.0.1\"") ; var openApiDocument = new OpenApiStringReader().Read(text, out var diagnostics); -//openApiDocument.Components.Schemas["CreateEmbeddingRequest"]!.Properties["dimensions"].Nullable = true; - +openApiDocument.Components.Schemas["prediction_response"] = FromJson( + /* language=json */ + """ + { + "completed_at": "2024-08-24T11:52:04.150854Z", + "created_at": "2024-08-24T11:51:46.577000Z", + "data_removed": false, + "error": null, + "id": "0ppyrp3aj5rge0chggxb4szz48", + "input": { + "seed": 321972, + "steps": 25, + "prompt": "a female, slavian, young adult, fit body, wavy acid orange hair, wearing open swimsuit, sea in the background.", + "guidance": 3.5, + "interval": 3, + "aspect_ratio": "16:9", + "safety_tolerance": 5 + }, + "logs": "Using seed: 321972\nRunning prediction... \nGenerating image...", + "metrics": { + "image_count": 1, + "predict_time": 17.565933287, + "total_time": 17.573854 + }, + "output": "https://replicate.delivery/czjl/UVvZ7pAzOk7zLlZhKB2nUx9veCCVSDk4VlfwJ7KxaDmkt3VTA/output.webp", + "started_at": "2024-08-24T11:51:46.584921Z", + "status": "succeeded", + "urls": { + "get": "https://api.replicate.com/v1/predictions/0ppyrp3aj5rge0chggxb4szz48", + "cancel": "https://api.replicate.com/v1/predictions/0ppyrp3aj5rge0chggxb4szz48/cancel" + }, + "version": "d848511ad960c3a099e2a5b04f783ebf8e8a44c625b8fa7d8f03b72ebee1c954" + } + """); +openApiDocument.Paths["/models/{model_owner}/{model_name}/predictions"] + .Operations[OperationType.Post].Responses["200"] = new OpenApiResponse +{ + Description = "Successful", + Content = new Dictionary + { + ["application/json"] = new() + { + Schema = new OpenApiSchema + { + Reference = new OpenApiReference + { + Type = ReferenceType.Schema, + Id = "prediction_response", + }, + }, + }, + }, +}; +openApiDocument.Paths["/predictions/{prediction_id}"] + .Operations[OperationType.Get].Responses["200"] = new OpenApiResponse +{ + Description = "Successful", + Content = new Dictionary + { + ["application/json"] = new() + { + Schema = new OpenApiSchema + { + Reference = new OpenApiReference + { + Type = ReferenceType.Schema, + Id = "prediction_response", + }, + }, + }, + }, +}; + text = openApiDocument.SerializeAsYaml(OpenApiSpecVersion.OpenApi3_0); _ = new OpenApiStringReader().Read(text, out diagnostics); @@ -23,8 +94,43 @@ { Console.WriteLine(error.Message); } + // Return Exit code 1 Environment.Exit(1); } -await File.WriteAllTextAsync(path, text); \ No newline at end of file +await File.WriteAllTextAsync(path, text); + +static OpenApiSchema FromJson([StringSyntax(StringSyntaxAttribute.Json)] string json) +{ + var schema = new OpenApiSchema(); + + var element = System.Text.Json.JsonDocument.Parse(json).RootElement; + schema.Type = element.ValueKind switch + { + System.Text.Json.JsonValueKind.Array => "array", + System.Text.Json.JsonValueKind.False => "boolean", + System.Text.Json.JsonValueKind.True => "boolean", + System.Text.Json.JsonValueKind.Number => "number", + System.Text.Json.JsonValueKind.String => "string", + System.Text.Json.JsonValueKind.Object => "object", + System.Text.Json.JsonValueKind.Null => "null", + _ => throw new NotSupportedException(), + }; + schema.Format = element.ValueKind switch + { + System.Text.Json.JsonValueKind.Number => element.TryGetInt64(out var _) ? "int64" : "number", + System.Text.Json.JsonValueKind.String => "string", + _ => null, + }; + schema.Properties = new Dictionary(); + if (element.ValueKind == System.Text.Json.JsonValueKind.Object) + { + foreach (var property in element.EnumerateObject()) + { + schema.Properties[property.Name] = FromJson(property.Value.GetRawText()); + } + } + + return schema; +} \ No newline at end of file diff --git a/src/helpers/FixOpenApiSpec/Properties/launchSettings.json b/src/helpers/FixOpenApiSpec/Properties/launchSettings.json new file mode 100644 index 0000000..eda77e1 --- /dev/null +++ b/src/helpers/FixOpenApiSpec/Properties/launchSettings.json @@ -0,0 +1,8 @@ +{ + "profiles": { + "Localizer": { + "commandName": "Project", + "commandLineArgs": "../../../../../../src/libs/Replicate/openapi.yaml" + } + } +} \ No newline at end of file diff --git a/src/libs/Replicate/Generated/JsonSerializerContextTypes.g.cs b/src/libs/Replicate/Generated/JsonSerializerContextTypes.g.cs index b12d695..ba45761 100644 --- a/src/libs/Replicate/Generated/JsonSerializerContextTypes.g.cs +++ b/src/libs/Replicate/Generated/JsonSerializerContextTypes.g.cs @@ -69,134 +69,158 @@ public sealed partial class JsonSerializerContextTypes /// /// /// - public global::Replicate.DeploymentsCreateRequest? Type14 { get; set; } + public global::Replicate.PredictionResponse? Type14 { get; set; } /// /// /// - public int? Type15 { get; set; } + public object? Type15 { get; set; } /// /// /// - public global::Replicate.DeploymentsUpdateRequest? Type16 { get; set; } + public global::Replicate.PredictionResponseInput? Type16 { get; set; } /// /// /// - public global::Replicate.ModelsCreateRequest? Type17 { get; set; } + public double? Type17 { get; set; } /// /// /// - public global::Replicate.ModelsCreateRequestVisibility? Type18 { get; set; } + public global::Replicate.PredictionResponseMetrics? Type18 { get; set; } /// /// /// - public global::Replicate.AccountGetResponse? Type19 { get; set; } + public global::Replicate.PredictionResponseUrls? Type19 { get; set; } /// /// /// - public global::Replicate.AccountGetResponseType? Type20 { get; set; } + public global::Replicate.DeploymentsCreateRequest? Type20 { get; set; } /// /// /// - public global::Replicate.DeploymentsListResponse? Type21 { get; set; } + public int? Type21 { get; set; } /// /// /// - public global::System.Collections.Generic.IList? Type22 { get; set; } + public global::Replicate.DeploymentsUpdateRequest? Type22 { get; set; } /// /// /// - public global::Replicate.DeploymentsListResponseResult? Type23 { get; set; } + public global::Replicate.ModelsCreateRequest? Type23 { get; set; } /// /// /// - public global::Replicate.DeploymentsListResponseResultCurrentRelease? Type24 { get; set; } + public global::Replicate.ModelsCreateRequestVisibility? Type24 { get; set; } /// /// /// - public global::Replicate.DeploymentsListResponseResultCurrentReleaseConfiguration? Type25 { get; set; } + public global::Replicate.AccountGetResponse? Type25 { get; set; } /// /// /// - public global::System.DateTime? Type26 { get; set; } + public global::Replicate.AccountGetResponseType? Type26 { get; set; } /// /// /// - public global::Replicate.DeploymentsListResponseResultCurrentReleaseCreatedBy? Type27 { get; set; } + public global::Replicate.DeploymentsListResponse? Type27 { get; set; } /// /// /// - public global::Replicate.DeploymentsListResponseResultCurrentReleaseCreatedByType? Type28 { get; set; } + public global::System.Collections.Generic.IList? Type28 { get; set; } /// /// /// - public global::Replicate.DeploymentsCreateResponse? Type29 { get; set; } + public global::Replicate.DeploymentsListResponseResult? Type29 { get; set; } /// /// /// - public global::Replicate.DeploymentsCreateResponseCurrentRelease? Type30 { get; set; } + public global::Replicate.DeploymentsListResponseResultCurrentRelease? Type30 { get; set; } /// /// /// - public global::Replicate.DeploymentsCreateResponseCurrentReleaseConfiguration? Type31 { get; set; } + public global::Replicate.DeploymentsListResponseResultCurrentReleaseConfiguration? Type31 { get; set; } /// /// /// - public global::Replicate.DeploymentsCreateResponseCurrentReleaseCreatedBy? Type32 { get; set; } + public global::System.DateTime? Type32 { get; set; } /// /// /// - public global::Replicate.DeploymentsCreateResponseCurrentReleaseCreatedByType? Type33 { get; set; } + public global::Replicate.DeploymentsListResponseResultCurrentReleaseCreatedBy? Type33 { get; set; } /// /// /// - public global::Replicate.DeploymentsGetResponse? Type34 { get; set; } + public global::Replicate.DeploymentsListResponseResultCurrentReleaseCreatedByType? Type34 { get; set; } /// /// /// - public global::Replicate.DeploymentsGetResponseCurrentRelease? Type35 { get; set; } + public global::Replicate.DeploymentsCreateResponse? Type35 { get; set; } /// /// /// - public global::Replicate.DeploymentsGetResponseCurrentReleaseConfiguration? Type36 { get; set; } + public global::Replicate.DeploymentsCreateResponseCurrentRelease? Type36 { get; set; } /// /// /// - public global::Replicate.DeploymentsGetResponseCurrentReleaseCreatedBy? Type37 { get; set; } + public global::Replicate.DeploymentsCreateResponseCurrentReleaseConfiguration? Type37 { get; set; } /// /// /// - public global::Replicate.DeploymentsGetResponseCurrentReleaseCreatedByType? Type38 { get; set; } + public global::Replicate.DeploymentsCreateResponseCurrentReleaseCreatedBy? Type38 { get; set; } /// /// /// - public global::Replicate.DeploymentsUpdateResponse? Type39 { get; set; } + public global::Replicate.DeploymentsCreateResponseCurrentReleaseCreatedByType? Type39 { get; set; } /// /// /// - public global::Replicate.DeploymentsUpdateResponseCurrentRelease? Type40 { get; set; } + public global::Replicate.DeploymentsGetResponse? Type40 { get; set; } /// /// /// - public global::Replicate.DeploymentsUpdateResponseCurrentReleaseConfiguration? Type41 { get; set; } + public global::Replicate.DeploymentsGetResponseCurrentRelease? Type41 { get; set; } /// /// /// - public global::Replicate.DeploymentsUpdateResponseCurrentReleaseCreatedBy? Type42 { get; set; } + public global::Replicate.DeploymentsGetResponseCurrentReleaseConfiguration? Type42 { get; set; } /// /// /// - public global::Replicate.DeploymentsUpdateResponseCurrentReleaseCreatedByType? Type43 { get; set; } + public global::Replicate.DeploymentsGetResponseCurrentReleaseCreatedBy? Type43 { get; set; } /// /// /// - public global::System.Collections.Generic.IList? Type44 { get; set; } + public global::Replicate.DeploymentsGetResponseCurrentReleaseCreatedByType? Type44 { get; set; } /// /// /// - public global::Replicate.HardwareListResponseItem? Type45 { get; set; } + public global::Replicate.DeploymentsUpdateResponse? Type45 { get; set; } /// /// /// - public global::Replicate.WebhooksDefaultSecretGetResponse? Type46 { get; set; } + public global::Replicate.DeploymentsUpdateResponseCurrentRelease? Type46 { get; set; } + /// + /// + /// + public global::Replicate.DeploymentsUpdateResponseCurrentReleaseConfiguration? Type47 { get; set; } + /// + /// + /// + public global::Replicate.DeploymentsUpdateResponseCurrentReleaseCreatedBy? Type48 { get; set; } + /// + /// + /// + public global::Replicate.DeploymentsUpdateResponseCurrentReleaseCreatedByType? Type49 { get; set; } + /// + /// + /// + public global::System.Collections.Generic.IList? Type50 { get; set; } + /// + /// + /// + public global::Replicate.HardwareListResponseItem? Type51 { get; set; } + /// + /// + /// + public global::Replicate.WebhooksDefaultSecretGetResponse? Type52 { get; set; } } } \ No newline at end of file diff --git a/src/libs/Replicate/Generated/Replicate.Models.PredictionResponse.g.cs b/src/libs/Replicate/Generated/Replicate.Models.PredictionResponse.g.cs new file mode 100644 index 0000000..f4edfc2 --- /dev/null +++ b/src/libs/Replicate/Generated/Replicate.Models.PredictionResponse.g.cs @@ -0,0 +1,95 @@ + +#nullable enable + +namespace Replicate +{ + /// + /// + /// + public sealed partial class PredictionResponse + { + /// + /// + /// + [global::System.Text.Json.Serialization.JsonPropertyName("completed_at")] + public string? CompletedAt { get; set; } + + /// + /// + /// + [global::System.Text.Json.Serialization.JsonPropertyName("created_at")] + public string? CreatedAt { get; set; } + + /// + /// + /// + [global::System.Text.Json.Serialization.JsonPropertyName("data_removed")] + public bool DataRemoved { get; set; } + + /// + /// + /// + [global::System.Text.Json.Serialization.JsonPropertyName("error")] + public object? Error { get; set; } + + /// + /// + /// + [global::System.Text.Json.Serialization.JsonPropertyName("id")] + public string? Id { get; set; } + + /// + /// + /// + [global::System.Text.Json.Serialization.JsonPropertyName("input")] + public global::Replicate.PredictionResponseInput? Input { get; set; } + + /// + /// + /// + [global::System.Text.Json.Serialization.JsonPropertyName("logs")] + public string? Logs { get; set; } + + /// + /// + /// + [global::System.Text.Json.Serialization.JsonPropertyName("metrics")] + public global::Replicate.PredictionResponseMetrics? Metrics { get; set; } + + /// + /// + /// + [global::System.Text.Json.Serialization.JsonPropertyName("output")] + public string? Output { get; set; } + + /// + /// + /// + [global::System.Text.Json.Serialization.JsonPropertyName("started_at")] + public string? StartedAt { get; set; } + + /// + /// + /// + [global::System.Text.Json.Serialization.JsonPropertyName("status")] + public string? Status { get; set; } + + /// + /// + /// + [global::System.Text.Json.Serialization.JsonPropertyName("urls")] + public global::Replicate.PredictionResponseUrls? Urls { get; set; } + + /// + /// + /// + [global::System.Text.Json.Serialization.JsonPropertyName("version")] + public string? Version { get; set; } + + /// + /// Additional properties that are not explicitly defined in the schema + /// + [global::System.Text.Json.Serialization.JsonExtensionData] + public global::System.Collections.Generic.IDictionary AdditionalProperties { get; set; } = new global::System.Collections.Generic.Dictionary(); + } +} \ No newline at end of file diff --git a/src/libs/Replicate/Generated/Replicate.Models.PredictionResponseInput.g.cs b/src/libs/Replicate/Generated/Replicate.Models.PredictionResponseInput.g.cs new file mode 100644 index 0000000..47e38b7 --- /dev/null +++ b/src/libs/Replicate/Generated/Replicate.Models.PredictionResponseInput.g.cs @@ -0,0 +1,59 @@ + +#nullable enable + +namespace Replicate +{ + /// + /// + /// + public sealed partial class PredictionResponseInput + { + /// + /// + /// + [global::System.Text.Json.Serialization.JsonPropertyName("seed")] + public double Seed { get; set; } + + /// + /// + /// + [global::System.Text.Json.Serialization.JsonPropertyName("steps")] + public double Steps { get; set; } + + /// + /// + /// + [global::System.Text.Json.Serialization.JsonPropertyName("prompt")] + public string? Prompt { get; set; } + + /// + /// + /// + [global::System.Text.Json.Serialization.JsonPropertyName("guidance")] + public double Guidance { get; set; } + + /// + /// + /// + [global::System.Text.Json.Serialization.JsonPropertyName("interval")] + public double Interval { get; set; } + + /// + /// + /// + [global::System.Text.Json.Serialization.JsonPropertyName("aspect_ratio")] + public string? AspectRatio { get; set; } + + /// + /// + /// + [global::System.Text.Json.Serialization.JsonPropertyName("safety_tolerance")] + public double SafetyTolerance { get; set; } + + /// + /// Additional properties that are not explicitly defined in the schema + /// + [global::System.Text.Json.Serialization.JsonExtensionData] + public global::System.Collections.Generic.IDictionary AdditionalProperties { get; set; } = new global::System.Collections.Generic.Dictionary(); + } +} \ No newline at end of file diff --git a/src/libs/Replicate/Generated/Replicate.Models.PredictionResponseMetrics.g.cs b/src/libs/Replicate/Generated/Replicate.Models.PredictionResponseMetrics.g.cs new file mode 100644 index 0000000..1316f5b --- /dev/null +++ b/src/libs/Replicate/Generated/Replicate.Models.PredictionResponseMetrics.g.cs @@ -0,0 +1,35 @@ + +#nullable enable + +namespace Replicate +{ + /// + /// + /// + public sealed partial class PredictionResponseMetrics + { + /// + /// + /// + [global::System.Text.Json.Serialization.JsonPropertyName("image_count")] + public double ImageCount { get; set; } + + /// + /// + /// + [global::System.Text.Json.Serialization.JsonPropertyName("predict_time")] + public double PredictTime { get; set; } + + /// + /// + /// + [global::System.Text.Json.Serialization.JsonPropertyName("total_time")] + public double TotalTime { get; set; } + + /// + /// Additional properties that are not explicitly defined in the schema + /// + [global::System.Text.Json.Serialization.JsonExtensionData] + public global::System.Collections.Generic.IDictionary AdditionalProperties { get; set; } = new global::System.Collections.Generic.Dictionary(); + } +} \ No newline at end of file diff --git a/src/libs/Replicate/Generated/Replicate.Models.PredictionResponseUrls.g.cs b/src/libs/Replicate/Generated/Replicate.Models.PredictionResponseUrls.g.cs new file mode 100644 index 0000000..dc945d6 --- /dev/null +++ b/src/libs/Replicate/Generated/Replicate.Models.PredictionResponseUrls.g.cs @@ -0,0 +1,29 @@ + +#nullable enable + +namespace Replicate +{ + /// + /// + /// + public sealed partial class PredictionResponseUrls + { + /// + /// + /// + [global::System.Text.Json.Serialization.JsonPropertyName("get")] + public string? Get { get; set; } + + /// + /// + /// + [global::System.Text.Json.Serialization.JsonPropertyName("cancel")] + public string? Cancel { get; set; } + + /// + /// Additional properties that are not explicitly defined in the schema + /// + [global::System.Text.Json.Serialization.JsonExtensionData] + public global::System.Collections.Generic.IDictionary AdditionalProperties { get; set; } = new global::System.Collections.Generic.Dictionary(); + } +} \ No newline at end of file diff --git a/src/libs/Replicate/Generated/Replicate.ReplicateApi.ModelsPredictionsCreate.g.cs b/src/libs/Replicate/Generated/Replicate.ReplicateApi.ModelsPredictionsCreate.g.cs index 4af29a9..80c22ae 100644 --- a/src/libs/Replicate/Generated/Replicate.ReplicateApi.ModelsPredictionsCreate.g.cs +++ b/src/libs/Replicate/Generated/Replicate.ReplicateApi.ModelsPredictionsCreate.g.cs @@ -20,6 +20,11 @@ partial void ProcessModelsPredictionsCreateResponse( global::System.Net.Http.HttpClient httpClient, global::System.Net.Http.HttpResponseMessage httpResponseMessage); + partial void ProcessModelsPredictionsCreateResponseContent( + global::System.Net.Http.HttpClient httpClient, + global::System.Net.Http.HttpResponseMessage httpResponseMessage, + ref string content); + /// /// Create a prediction using an official model
/// Start a new prediction for an official model using the inputs you provide.
@@ -67,7 +72,7 @@ partial void ProcessModelsPredictionsCreateResponse( /// /// The token to cancel the operation with /// - public async global::System.Threading.Tasks.Task ModelsPredictionsCreateAsync( + public async global::System.Threading.Tasks.Task ModelsPredictionsCreateAsync( string modelOwner, string modelName, global::Replicate.PredictionRequest request, @@ -114,7 +119,30 @@ partial void ProcessModelsPredictionsCreateResponse( ProcessModelsPredictionsCreateResponse( httpClient: _httpClient, httpResponseMessage: response); - response.EnsureSuccessStatusCode(); + + var __content = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false); + + ProcessResponseContent( + client: _httpClient, + response: response, + content: ref __content); + ProcessModelsPredictionsCreateResponseContent( + httpClient: _httpClient, + httpResponseMessage: response, + content: ref __content); + + try + { + response.EnsureSuccessStatusCode(); + } + catch (global::System.Net.Http.HttpRequestException ex) + { + throw new global::System.InvalidOperationException(__content, ex); + } + + return + global::System.Text.Json.JsonSerializer.Deserialize(__content, global::Replicate.SourceGenerationContext.Default.PredictionResponse) ?? + throw new global::System.InvalidOperationException($"Response deserialization failed for \"{__content}\" "); } /// @@ -201,7 +229,7 @@ partial void ProcessModelsPredictionsCreateResponse( /// /// The token to cancel the operation with /// - public async global::System.Threading.Tasks.Task ModelsPredictionsCreateAsync( + public async global::System.Threading.Tasks.Task ModelsPredictionsCreateAsync( string modelOwner, string modelName, global::Replicate.PredictionRequestInput input, @@ -218,7 +246,7 @@ partial void ProcessModelsPredictionsCreateResponse( WebhookEventsFilter = webhookEventsFilter, }; - await ModelsPredictionsCreateAsync( + return await ModelsPredictionsCreateAsync( modelOwner: modelOwner, modelName: modelName, request: request, diff --git a/src/libs/Replicate/Generated/Replicate.ReplicateApi.PredictionsGet.g.cs b/src/libs/Replicate/Generated/Replicate.ReplicateApi.PredictionsGet.g.cs index 72bad17..b85e7d7 100644 --- a/src/libs/Replicate/Generated/Replicate.ReplicateApi.PredictionsGet.g.cs +++ b/src/libs/Replicate/Generated/Replicate.ReplicateApi.PredictionsGet.g.cs @@ -16,6 +16,11 @@ partial void ProcessPredictionsGetResponse( global::System.Net.Http.HttpClient httpClient, global::System.Net.Http.HttpResponseMessage httpResponseMessage); + partial void ProcessPredictionsGetResponseContent( + global::System.Net.Http.HttpClient httpClient, + global::System.Net.Http.HttpResponseMessage httpResponseMessage, + ref string content); + /// /// Get a prediction
/// Get the current state of a prediction.
@@ -67,7 +72,7 @@ partial void ProcessPredictionsGetResponse( /// /// The token to cancel the operation with /// - public async global::System.Threading.Tasks.Task PredictionsGetAsync( + public async global::System.Threading.Tasks.Task PredictionsGetAsync( string predictionId, global::System.Threading.CancellationToken cancellationToken = default) { @@ -100,7 +105,30 @@ partial void ProcessPredictionsGetResponse( ProcessPredictionsGetResponse( httpClient: _httpClient, httpResponseMessage: response); - response.EnsureSuccessStatusCode(); + + var __content = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false); + + ProcessResponseContent( + client: _httpClient, + response: response, + content: ref __content); + ProcessPredictionsGetResponseContent( + httpClient: _httpClient, + httpResponseMessage: response, + content: ref __content); + + try + { + response.EnsureSuccessStatusCode(); + } + catch (global::System.Net.Http.HttpRequestException ex) + { + throw new global::System.InvalidOperationException(__content, ex); + } + + return + global::System.Text.Json.JsonSerializer.Deserialize(__content, global::Replicate.SourceGenerationContext.Default.PredictionResponse) ?? + throw new global::System.InvalidOperationException($"Response deserialization failed for \"{__content}\" "); } } } \ No newline at end of file diff --git a/src/libs/Replicate/PredictionResponseExtensions.cs b/src/libs/Replicate/PredictionResponseExtensions.cs new file mode 100644 index 0000000..f6d8d70 --- /dev/null +++ b/src/libs/Replicate/PredictionResponseExtensions.cs @@ -0,0 +1,48 @@ +namespace Replicate; + +/// +/// +/// +public static class PredictionResponseExtensions +{ + /// + /// + /// + /// + /// + /// + public static bool IsSuccessful(this PredictionResponse response) + { + response = response ?? throw new ArgumentNullException(nameof(response)); + + return response.Status == "succeeded"; + } + + /// + /// + /// + /// + /// + /// + /// + /// + /// + public static async Task WaitUntilSuccessfulAsync( + this PredictionResponse response, + ReplicateApi api, + CancellationToken cancellationToken = default) + { + response = response ?? throw new ArgumentNullException(nameof(response)); + api = api ?? throw new ArgumentNullException(nameof(api)); + var id = response.Id ?? throw new ArgumentException(nameof(response.Id)); + + while (!response.IsSuccessful()) + { + await Task.Delay(1000, cancellationToken).ConfigureAwait(false); + + response = await api.PredictionsGetAsync(id, cancellationToken).ConfigureAwait(false); + } + + return response; + } +} \ No newline at end of file diff --git a/src/libs/Replicate/openapi.yaml b/src/libs/Replicate/openapi.yaml index 5e46d31..eadb7a6 100644 --- a/src/libs/Replicate/openapi.yaml +++ b/src/libs/Replicate/openapi.yaml @@ -611,7 +611,11 @@ paths: $ref: '#/components/schemas/prediction_request' responses: '200': - description: Success + description: Successful + content: + application/json: + schema: + $ref: '#/components/schemas/prediction_response' '/models/{model_owner}/{model_name}/versions': get: summary: List model versions @@ -752,7 +756,11 @@ paths: type: string responses: '200': - description: Success + description: Successful + content: + application/json: + schema: + $ref: '#/components/schemas/prediction_response' '/predictions/{prediction_id}/cancel': post: summary: Cancel a prediction @@ -901,6 +909,82 @@ components: type: string description: "By default, we will send requests to your webhook URL whenever there are new outputs or the prediction has finished. You can change which events trigger webhook requests by specifying `webhook_events_filter` in the prediction request:\n\n- `start`: immediately on prediction start\n- `output`: each time a prediction generates an output (note that predictions can generate multiple outputs)\n- `logs`: each time log output is generated by a prediction\n- `completed`: when the prediction reaches a terminal state (succeeded/canceled/failed)\n\nFor example, if you only wanted requests to be sent at the start and end of the prediction, you would provide:\n\n```json\n{\n \"version\": \"5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa\",\n \"input\": {\n \"text\": \"Alice\"\n },\n \"webhook\": \"https://example.com/my-webhook\",\n \"webhook_events_filter\": [\"start\", \"completed\"]\n}\n```\n\nRequests for event types `output` and `logs` will be sent at most once every 500ms. If you request `start` and `completed` webhooks, then they'll always be sent regardless of throttling.\n" additionalProperties: false + prediction_response: + type: object + properties: + completed_at: + type: string + format: string + created_at: + type: string + format: string + data_removed: + type: boolean + error: + type: 'null' + id: + type: string + format: string + input: + type: object + properties: + seed: + type: number + format: int64 + steps: + type: number + format: int64 + prompt: + type: string + format: string + guidance: + type: number + format: number + interval: + type: number + format: int64 + aspect_ratio: + type: string + format: string + safety_tolerance: + type: number + format: int64 + logs: + type: string + format: string + metrics: + type: object + properties: + image_count: + type: number + format: int64 + predict_time: + type: number + format: number + total_time: + type: number + format: number + output: + type: string + format: string + started_at: + type: string + format: string + status: + type: string + format: string + urls: + type: object + properties: + get: + type: string + format: string + cancel: + type: string + format: string + version: + type: string + format: string securitySchemes: bearerAuth: type: http diff --git a/src/tests/Replicate.IntegrationTests/Tests.CreatePrediction.cs b/src/tests/Replicate.IntegrationTests/Tests.CreatePrediction.cs index 76b3199..ff926eb 100755 --- a/src/tests/Replicate.IntegrationTests/Tests.CreatePrediction.cs +++ b/src/tests/Replicate.IntegrationTests/Tests.CreatePrediction.cs @@ -3,19 +3,35 @@ namespace Replicate.IntegrationTests; public partial class Tests { [TestMethod] - public async Task CreatePrediction() + public async Task CreatePredictionForFluxPro() { using var api = GetAuthorizedApi(); - await api.PredictionsCreateAsync( - input: new VersionPredictionRequestInput + var response = await api.ModelsPredictionsCreateAsync( + input: new PredictionRequestInput { AdditionalProperties = new Dictionary { - ["prompt"] = "I forgot how to kill a process in Linux, can you help?", - ["assistant"] = "Sure! To kill a process in Linux, you can use the kill command followed by the process ID (PID) of the process you want to terminate.", + ["seed"] = Random.Shared.Next(0, 1000000), + ["steps"] = 25, + ["prompt"] = "a female, european, young adult, fit body, wavy acid orange hair, wearing open swimsuit, sea in the background.", + ["guidance"] = 3.5, + ["interval"] = 3, + ["aspect_ratio"] = "16:9", + ["safety_tolerance"] = 5, }, }, - version: "b063023ee937f28e922982abdbf97b041ffe34ad3b35a53d33e1d74bb19b36c4"); + modelOwner: "black-forest-labs", + modelName: "flux-pro", + stream: false, + webhook: "https://hook.eu2.make.com/h6tawmpxsmxb7ut4edfmje4g3xw8y8rf", + webhookEventsFilter: null); + response.Should().NotBeNull(); + response.Id.Should().NotBeNull(); + + var endResponse = await response.WaitUntilSuccessfulAsync(api); + + Console.WriteLine($@"Seed: {endResponse.Input?.Seed}. +Image available at:\n{endResponse.Output}"); } } \ No newline at end of file