From aeca0977f1a808d1970380432415544eb0c01b42 Mon Sep 17 00:00:00 2001 From: Enkidu93 Date: Tue, 7 Nov 2023 12:18:20 -0500 Subject: [PATCH] Fixes #202 --- src/Serval.Client/Client.g.cs | 30 ++++++++++++++++++ .../Protos/serval/translation/v1/engine.proto | 10 +++--- .../Contracts/TrainingCorpusConfigDto.cs | 7 +++++ .../Contracts/TrainingCorpusDto.cs | 8 +++++ .../Contracts/TranslationBuildConfigDto.cs | 1 + .../Contracts/TranslationBuildDto.cs | 1 + .../TranslationEnginesController.cs | 31 +++++++++++++++++-- src/Serval.Translation/Models/Build.cs | 1 + .../Models/TrainingCorpus.cs | 7 +++++ .../Services/EngineService.cs | 7 +++++ tests/Serval.E2ETests/ServalApiTests.cs | 9 ++++-- tests/Serval.E2ETests/ServalClientHelper.cs | 4 ++- 12 files changed, 106 insertions(+), 10 deletions(-) create mode 100644 src/Serval.Translation/Contracts/TrainingCorpusConfigDto.cs create mode 100644 src/Serval.Translation/Contracts/TrainingCorpusDto.cs create mode 100644 src/Serval.Translation/Models/TrainingCorpus.cs diff --git a/src/Serval.Client/Client.g.cs b/src/Serval.Client/Client.g.cs index e1c408ee..029b275d 100644 --- a/src/Serval.Client/Client.g.cs +++ b/src/Serval.Client/Client.g.cs @@ -4426,6 +4426,9 @@ public partial class TranslationBuild [System.ComponentModel.DataAnnotations.Required] public ResourceLink Engine { get; set; } = new ResourceLink(); + [Newtonsoft.Json.JsonProperty("trainOn", Required = Newtonsoft.Json.Required.Default, NullValueHandling = Newtonsoft.Json.NullValueHandling.Ignore)] + public System.Collections.Generic.IList? TrainOn { get; set; } = default!; + [Newtonsoft.Json.JsonProperty("pretranslate", Required = Newtonsoft.Json.Required.Default, NullValueHandling = Newtonsoft.Json.NullValueHandling.Ignore)] public System.Collections.Generic.IList? Pretranslate { get; set; } = default!; @@ -4457,6 +4460,18 @@ public partial class TranslationBuild } + [System.CodeDom.Compiler.GeneratedCode("NJsonSchema", "13.18.2.0 (NJsonSchema v10.8.0.0 (Newtonsoft.Json v13.0.0.0))")] + public partial class TrainingCorpus + { + [Newtonsoft.Json.JsonProperty("corpus", Required = Newtonsoft.Json.Required.Always)] + [System.ComponentModel.DataAnnotations.Required] + public ResourceLink Corpus { get; set; } = new ResourceLink(); + + [Newtonsoft.Json.JsonProperty("textIds", Required = Newtonsoft.Json.Required.Default, NullValueHandling = Newtonsoft.Json.NullValueHandling.Ignore)] + public System.Collections.Generic.IList? TextIds { get; set; } = default!; + + } + [System.CodeDom.Compiler.GeneratedCode("NJsonSchema", "13.18.2.0 (NJsonSchema v10.8.0.0 (Newtonsoft.Json v13.0.0.0))")] public partial class PretranslateCorpus { @@ -4496,6 +4511,9 @@ public partial class TranslationBuildConfig [Newtonsoft.Json.JsonProperty("name", Required = Newtonsoft.Json.Required.Default, NullValueHandling = Newtonsoft.Json.NullValueHandling.Ignore)] public string? Name { get; set; } = default!; + [Newtonsoft.Json.JsonProperty("trainOn", Required = Newtonsoft.Json.Required.Default, NullValueHandling = Newtonsoft.Json.NullValueHandling.Ignore)] + public System.Collections.Generic.IList? TrainOn { get; set; } = default!; + [Newtonsoft.Json.JsonProperty("pretranslate", Required = Newtonsoft.Json.Required.Default, NullValueHandling = Newtonsoft.Json.NullValueHandling.Ignore)] public System.Collections.Generic.IList? Pretranslate { get; set; } = default!; @@ -4504,6 +4522,18 @@ public partial class TranslationBuildConfig } + [System.CodeDom.Compiler.GeneratedCode("NJsonSchema", "13.18.2.0 (NJsonSchema v10.8.0.0 (Newtonsoft.Json v13.0.0.0))")] + public partial class TrainingCorpusConfig + { + [Newtonsoft.Json.JsonProperty("corpusId", Required = Newtonsoft.Json.Required.Always)] + [System.ComponentModel.DataAnnotations.Required(AllowEmptyStrings = true)] + public string CorpusId { get; set; } = default!; + + [Newtonsoft.Json.JsonProperty("textIds", Required = Newtonsoft.Json.Required.Default, NullValueHandling = Newtonsoft.Json.NullValueHandling.Ignore)] + public System.Collections.Generic.IList? TextIds { get; set; } = default!; + + } + [System.CodeDom.Compiler.GeneratedCode("NJsonSchema", "13.18.2.0 (NJsonSchema v10.8.0.0 (Newtonsoft.Json v13.0.0.0))")] public partial class PretranslateCorpusConfig { diff --git a/src/Serval.Grpc/Protos/serval/translation/v1/engine.proto b/src/Serval.Grpc/Protos/serval/translation/v1/engine.proto index f98478ca..89802f2a 100644 --- a/src/Serval.Grpc/Protos/serval/translation/v1/engine.proto +++ b/src/Serval.Grpc/Protos/serval/translation/v1/engine.proto @@ -126,10 +126,12 @@ message Corpus { string id = 1; string source_language = 2; string target_language = 3; - bool pretranslate_all = 4; - repeated string pretranslate_text_ids = 5; - repeated CorpusFile source_files = 6; - repeated CorpusFile target_files = 7; + bool train_on_all = 4; + bool pretranslate_all = 5; + repeated string train_on_text_ids = 6; + repeated string pretranslate_text_ids = 7; + repeated CorpusFile source_files = 8; + repeated CorpusFile target_files = 9; } message CorpusFile { diff --git a/src/Serval.Translation/Contracts/TrainingCorpusConfigDto.cs b/src/Serval.Translation/Contracts/TrainingCorpusConfigDto.cs new file mode 100644 index 00000000..a0071edc --- /dev/null +++ b/src/Serval.Translation/Contracts/TrainingCorpusConfigDto.cs @@ -0,0 +1,7 @@ +namespace Serval.Translation.Contracts; + +public class TrainingCorpusConfigDto +{ + public string CorpusId { get; set; } = default!; + public IList? TextIds { get; set; } +} diff --git a/src/Serval.Translation/Contracts/TrainingCorpusDto.cs b/src/Serval.Translation/Contracts/TrainingCorpusDto.cs new file mode 100644 index 00000000..a6c3b05d --- /dev/null +++ b/src/Serval.Translation/Contracts/TrainingCorpusDto.cs @@ -0,0 +1,8 @@ +namespace Serval.Translation.Contracts; + +public class TrainingCorpusDto +{ + public ResourceLinkDto Corpus { get; set; } = default!; + + public IList? TextIds { get; set; } +} diff --git a/src/Serval.Translation/Contracts/TranslationBuildConfigDto.cs b/src/Serval.Translation/Contracts/TranslationBuildConfigDto.cs index 51297a9e..00265820 100644 --- a/src/Serval.Translation/Contracts/TranslationBuildConfigDto.cs +++ b/src/Serval.Translation/Contracts/TranslationBuildConfigDto.cs @@ -3,6 +3,7 @@ public class TranslationBuildConfigDto { public string? Name { get; set; } + public IList? TrainOn { get; set; } public IList? Pretranslate { get; set; } /// diff --git a/src/Serval.Translation/Contracts/TranslationBuildDto.cs b/src/Serval.Translation/Contracts/TranslationBuildDto.cs index 9d9fa190..24f7cd34 100644 --- a/src/Serval.Translation/Contracts/TranslationBuildDto.cs +++ b/src/Serval.Translation/Contracts/TranslationBuildDto.cs @@ -7,6 +7,7 @@ public class TranslationBuildDto public int Revision { get; set; } public string? Name { get; set; } public ResourceLinkDto Engine { get; set; } = default!; + public IList? TrainOn { get; set; } public IList? Pretranslate { get; set; } public int Step { get; set; } public double? PercentCompleted { get; set; } diff --git a/src/Serval.Translation/Controllers/TranslationEnginesController.cs b/src/Serval.Translation/Controllers/TranslationEnginesController.cs index 28f0c510..6547ffa3 100644 --- a/src/Serval.Translation/Controllers/TranslationEnginesController.cs +++ b/src/Serval.Translation/Controllers/TranslationEnginesController.cs @@ -1,4 +1,6 @@ -namespace Serval.Translation.Controllers; +using System.Net.Sockets; + +namespace Serval.Translation.Controllers; [ApiVersion(1.0)] [Route("api/v{version:apiVersion}/translation/engines")] @@ -998,10 +1000,10 @@ private Engine Map(TranslationEngineConfigDto source) private static Build Map(Engine engine, TranslationBuildConfigDto source) { var build = new Build { EngineRef = engine.Id, Name = source.Name }; + var corpusIds = new HashSet(engine.Corpora.Select(c => c.Id)); if (source.Pretranslate != null) { var pretranslateCorpora = new List(); - var corpusIds = new HashSet(engine.Corpora.Select(c => c.Id)); foreach (PretranslateCorpusConfigDto ptcc in source.Pretranslate) { if (!corpusIds.Contains(ptcc.CorpusId)) @@ -1013,6 +1015,17 @@ private static Build Map(Engine engine, TranslationBuildConfigDto source) } build.Pretranslate = pretranslateCorpora; } + if (source.TrainOn != null) + { + var trainOnCorpora = new List(); + foreach (TrainingCorpusConfigDto tcc in source.TrainOn) + { + if (!corpusIds.Contains(tcc.CorpusId)) + throw new InvalidOperationException($"The corpus {tcc.CorpusId} is not valid."); + trainOnCorpora.Add(new TrainingCorpus { CorpusRef = tcc.CorpusId, TextIds = tcc.TextIds?.ToList() }); + } + build.TrainOn = trainOnCorpora; + } try { var jsonSerializerOptions = new JsonSerializerOptions(); @@ -1061,6 +1074,7 @@ private TranslationBuildDto Map(Build source) Id = source.EngineRef, Url = _urlService.GetUrl("GetTranslationEngine", new { id = source.EngineRef }) }, + TrainOn = source.TrainOn?.Select(s => Map(source.EngineRef, s)).ToList(), Pretranslate = source.Pretranslate?.Select(s => Map(source.EngineRef, s)).ToList(), Step = source.Step, PercentCompleted = source.PercentCompleted, @@ -1085,6 +1099,19 @@ private PretranslateCorpusDto Map(string engineId, PretranslateCorpus source) }; } + private TrainingCorpusDto Map(string engineId, TrainingCorpus source) + { + return new TrainingCorpusDto + { + Corpus = new ResourceLinkDto + { + Id = source.CorpusRef, + Url = _urlService.GetUrl("GetTranslationCorpus", new { id = engineId, corpusId = source.CorpusRef }) + }, + TextIds = source.TextIds + }; + } + private TranslationResultDto Map(TranslationResult source) { return new TranslationResultDto diff --git a/src/Serval.Translation/Models/Build.cs b/src/Serval.Translation/Models/Build.cs index fbfe2b3f..0a4f6a8c 100644 --- a/src/Serval.Translation/Models/Build.cs +++ b/src/Serval.Translation/Models/Build.cs @@ -6,6 +6,7 @@ public class Build : IEntity public int Revision { get; set; } = 1; public string? Name { get; set; } public string EngineRef { get; set; } = default!; + public IList? TrainOn { get; set; } public List? Pretranslate { get; set; } public int Step { get; set; } public double? PercentCompleted { get; set; } diff --git a/src/Serval.Translation/Models/TrainingCorpus.cs b/src/Serval.Translation/Models/TrainingCorpus.cs new file mode 100644 index 00000000..cc5dc71c --- /dev/null +++ b/src/Serval.Translation/Models/TrainingCorpus.cs @@ -0,0 +1,7 @@ +namespace Serval.Translation.Models; + +public class TrainingCorpus +{ + public string CorpusRef { get; set; } = default!; + public IList? TextIds { get; set; } +} diff --git a/src/Serval.Translation/Services/EngineService.cs b/src/Serval.Translation/Services/EngineService.cs index be39dbed..6888ac3a 100644 --- a/src/Serval.Translation/Services/EngineService.cs +++ b/src/Serval.Translation/Services/EngineService.cs @@ -191,6 +191,7 @@ public async Task StartBuildAsync(Build build, CancellationToken cancellat try { Dictionary? pretranslate = build.Pretranslate?.ToDictionary(c => c.CorpusRef); + Dictionary? trainOn = build.TrainOn?.ToDictionary(c => c.CorpusRef); var client = _grpcClientFactory.CreateClient(engine.Type); var request = new StartBuildRequest { @@ -210,6 +211,12 @@ public async Task StartBuildAsync(Build build, CancellationToken cancellat if (pretranslateCorpus.TextIds is not null) corpus.PretranslateTextIds.Add(pretranslateCorpus.TextIds); } + if (trainOn?.TryGetValue(c.Id, out TrainingCorpus? trainingCorpus) ?? false) + { + corpus.TrainOnAll = trainingCorpus.TextIds is null || trainingCorpus.TextIds.Count == 0; + if (trainingCorpus.TextIds is not null) + corpus.TrainOnTextIds.Add(trainingCorpus.TextIds); + } return corpus; }) } diff --git a/tests/Serval.E2ETests/ServalApiTests.cs b/tests/Serval.E2ETests/ServalApiTests.cs index 833c8c83..8f28dc3c 100644 --- a/tests/Serval.E2ETests/ServalApiTests.cs +++ b/tests/Serval.E2ETests/ServalApiTests.cs @@ -110,13 +110,16 @@ public async Task NmtBatch() await _helperClient!.ClearEngines(); string engineId = await _helperClient.CreateNewEngine("Nmt", "es", "en", "NMT1"); var books = new string[] { "MAT.txt", "1JN.txt", "2JN.txt" }; - await _helperClient.AddTextCorpusToEngine(engineId, books, "es", "en", false); - var cId = await _helperClient.AddTextCorpusToEngine(engineId, new string[] { "3JN.txt" }, "es", "en", true); + var cId1 = await _helperClient.AddTextCorpusToEngine(engineId, books, "es", "en", false); + _helperClient.TranslationBuildConfig.TrainOn!.Add( + new TrainingCorpusConfig { CorpusId = cId1, TextIds = new string[] { "1JN.txt" } } + ); + var cId2 = await _helperClient.AddTextCorpusToEngine(engineId, new string[] { "3JN.txt" }, "es", "en", true); await _helperClient.BuildEngine(engineId); await Task.Delay(1000); IList lTrans = await _helperClient.translationEnginesClient.GetAllPretranslationsAsync( engineId, - cId + cId2 ); Assert.IsTrue(lTrans[0].Translation.Contains("dearly beloved Gaius")); } diff --git a/tests/Serval.E2ETests/ServalClientHelper.cs b/tests/Serval.E2ETests/ServalClientHelper.cs index df1b209f..bc983641 100644 --- a/tests/Serval.E2ETests/ServalClientHelper.cs +++ b/tests/Serval.E2ETests/ServalClientHelper.cs @@ -32,7 +32,8 @@ public ServalClientHelper(string audience, string prefix = "SCE_", bool ignoreSS TranslationBuildConfig = new TranslationBuildConfig { Pretranslate = new List(), - Options = "{\"max_steps\":10}" + Options = "{\"max_steps\":10}", + TrainOn = new List() }; } @@ -86,6 +87,7 @@ public async Task ClearEngines(string name = "") } } TranslationBuildConfig.Pretranslate = new List(); + TranslationBuildConfig.TrainOn = new List(); EnginePerUser.Clear(); }