diff --git a/api/handler/batch.go b/api/handler/batch.go index 7f90ef0d..786012e9 100644 --- a/api/handler/batch.go +++ b/api/handler/batch.go @@ -2,14 +2,12 @@ package handler import ( "github.com/gin-gonic/gin" - "github.com/juju/errors" "github.com/loopfz/gadgeto/zesty" "github.com/ovh/utask" "github.com/ovh/utask/models/task" - "github.com/ovh/utask/models/tasktemplate" + "github.com/ovh/utask/pkg/batch" "github.com/ovh/utask/pkg/metadata" - "github.com/ovh/utask/pkg/taskutils" "github.com/ovh/utask/pkg/utils" ) @@ -34,11 +32,6 @@ func CreateBatch(c *gin.Context, in *createBatchIn) (*task.Batch, error) { metadata.AddActionMetadata(c, metadata.TemplateName, in.TemplateName) - tt, err := tasktemplate.LoadFromName(dbp, in.TemplateName) - if err != nil { - return nil, err - } - if err := utils.ValidateTags(in.Tags); err != nil { return nil, err } @@ -49,45 +42,30 @@ func CreateBatch(c *gin.Context, in *createBatchIn) (*task.Batch, error) { b, err := task.CreateBatch(dbp) if err != nil { - dbp.Rollback() + _ = dbp.Rollback() return nil, err } metadata.AddActionMetadata(c, metadata.BatchID, b.PublicID) - for _, inp := range in.Inputs { - input, err := conjMap(in.CommonInput, inp) - if err != nil { - dbp.Rollback() - return nil, err - } - - _, err = taskutils.CreateTask(c, dbp, tt, in.WatcherUsernames, in.WatcherGroups, []string{}, []string{}, input, b, in.Comment, nil, in.Tags) - if err != nil { - dbp.Rollback() - return nil, err - } + _, err = batch.Populate(c, b, dbp, batch.TaskArgs{ + TemplateName: in.TemplateName, + Inputs: in.Inputs, + CommonInput: in.CommonInput, + Comment: in.Comment, + WatcherUsernames: in.WatcherUsernames, + WatcherGroups: in.WatcherGroups, + Tags: in.Tags, + }) + if err != nil { + _ = dbp.Rollback() + return nil, err } if err := dbp.Commit(); err != nil { - dbp.Rollback() + _ = dbp.Rollback() return nil, err } return b, nil } - -func conjMap(common, particular map[string]interface{}) (map[string]interface{}, error) { - conj := 
make(map[string]interface{}) - for key, value := range particular { - conj[key] = value - } - - for key, value := range common { - if _, ok := conj[key]; ok { - return nil, errors.NewBadRequest(nil, "Conflicting keys in input maps") - } - conj[key] = value - } - return conj, nil -} diff --git a/engine/engine.go b/engine/engine.go index cde9515e..55f23129 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -29,6 +29,7 @@ import ( "github.com/ovh/utask/pkg/jsonschema" "github.com/ovh/utask/pkg/metadata" "github.com/ovh/utask/pkg/now" + pluginbatch "github.com/ovh/utask/pkg/plugins/builtin/batch" "github.com/ovh/utask/pkg/taskutils" "github.com/ovh/utask/pkg/utils" ) @@ -524,7 +525,9 @@ forLoop: if mapStatus[status] { if status == resolution.StateWaiting && recheckWaiting { for name, s := range res.Steps { - if s.State == step.StateWaiting { + // Steps using the batch plugin shouldn't be run again when WAITING. Running them second time + // may lead to a race condition when the last task of a sub-batch tries to resume its parent + if s.State == step.StateWaiting && s.Action.Type != pluginbatch.Plugin.PluginName() { delete(executedSteps, name) } } diff --git a/engine/engine_test.go b/engine/engine_test.go index 55146dc7..ca2795b4 100644 --- a/engine/engine_test.go +++ b/engine/engine_test.go @@ -7,10 +7,12 @@ import ( "fmt" "os" "path/filepath" + "strings" "sync" "testing" "time" + "github.com/Masterminds/squirrel" "github.com/juju/errors" "github.com/loopfz/gadgeto/zesty" "github.com/maxatome/go-testdeep/td" @@ -23,6 +25,7 @@ import ( "github.com/ovh/utask/api" "github.com/ovh/utask/db" "github.com/ovh/utask/db/pgjuju" + "github.com/ovh/utask/db/sqlgenerator" "github.com/ovh/utask/engine" "github.com/ovh/utask/engine/functions" functionrunner "github.com/ovh/utask/engine/functions/runner" @@ -36,6 +39,7 @@ import ( compress "github.com/ovh/utask/pkg/compress/init" "github.com/ovh/utask/pkg/now" "github.com/ovh/utask/pkg/plugins" + pluginbatch 
"github.com/ovh/utask/pkg/plugins/builtin/batch" plugincallback "github.com/ovh/utask/pkg/plugins/builtin/callback" "github.com/ovh/utask/pkg/plugins/builtin/echo" "github.com/ovh/utask/pkg/plugins/builtin/script" @@ -91,6 +95,7 @@ func TestMain(m *testing.M) { step.RegisterRunner(echo.Plugin.PluginName(), echo.Plugin) step.RegisterRunner(script.Plugin.PluginName(), script.Plugin) step.RegisterRunner(pluginsubtask.Plugin.PluginName(), pluginsubtask.Plugin) + step.RegisterRunner(pluginbatch.Plugin.PluginName(), pluginbatch.Plugin) step.RegisterRunner(plugincallback.Plugin.PluginName(), plugincallback.Plugin) os.Exit(m.Run()) @@ -194,6 +199,21 @@ func templateFromYAML(dbp zesty.DBProvider, filename string) (*tasktemplate.Task return tasktemplate.LoadFromName(dbp, tmpl.Name) } +func listBatchTasks(dbp zesty.DBProvider, batchID int64) ([]string, error) { + query, params, err := sqlgenerator.PGsql. + Select("public_id"). + From("task"). + Where(squirrel.Eq{"id_batch": batchID}). + ToSql() + if err != nil { + return nil, err + } + + var taskIDs []string + _, err = dbp.DB().Select(&taskIDs, query, params...) 
+ return taskIDs, err +} + func TestSimpleTemplate(t *testing.T) { input := map[string]interface{}{ "foo": "bar", @@ -1370,3 +1390,106 @@ func TestB64RawEncodeDecode(t *testing.T) { assert.Equal(t, "cmF3IG1lc3NhZ2U", output["a"]) assert.Equal(t, "raw message", output["b"]) } + +func TestBatch(t *testing.T) { + dbp, err := zesty.NewDBProvider(utask.DBName) + require.Nil(t, err) + + _, err = templateFromYAML(dbp, "batchedTask.yaml") + require.Nil(t, err) + + _, err = templateFromYAML(dbp, "batch.yaml") + require.Nil(t, err) + + res, err := createResolution("batch.yaml", map[string]interface{}{}, nil) + require.Nil(t, err, "failed to create resolution: %s", err) + + res, err = runResolution(res) + require.Nil(t, err) + require.NotNil(t, res) + assert.Equal(t, resolution.StateWaiting, res.State) + + for _, batchStepName := range []string{"batchJsonInputs", "batchYamlInputs"} { + batchStepMetadataRaw, ok := res.Steps[batchStepName].Metadata.(string) + assert.True(t, ok, "wrong type of metadata for step '%s'", batchStepName) + + assert.Nil(t, res.Steps[batchStepName].Output, "output nil for step '%s'", batchStepName) + + // The plugin formats Metadata in a special way that we need to revert before unmarshalling them + batchStepMetadataRaw = strings.ReplaceAll(batchStepMetadataRaw, `\"`, `"`) + var batchStepMetadata map[string]any + err := json.Unmarshal([]byte(batchStepMetadataRaw), &batchStepMetadata) + require.Nil(t, err, "metadata unmarshalling of step '%s'", batchStepName) + + batchPublicID := batchStepMetadata["batch_id"].(string) + assert.NotEqual(t, "", batchPublicID, "wrong batch ID '%s'", batchPublicID) + + b, err := task.LoadBatchFromPublicID(dbp, batchPublicID) + require.Nil(t, err) + + taskIDs, err := listBatchTasks(dbp, b.ID) + require.Nil(t, err) + assert.Len(t, taskIDs, 2) + + for i, publicID := range taskIDs { + child, err := task.LoadFromPublicID(dbp, publicID) + require.Nil(t, err) + assert.Equal(t, task.StateTODO, child.State) + + childResolution, err 
:= resolution.Create(dbp, child, nil, "", false, nil) + require.Nil(t, err) + + childResolution, err = runResolution(childResolution) + require.Nil(t, err) + assert.Equal(t, resolution.StateDone, childResolution.State) + + for k, v := range childResolution.Steps { + assert.Equal(t, step.StateDone, v.State, "not valid state for step %s", k) + } + + child, err = task.LoadFromPublicID(dbp, child.PublicID) + require.Nil(t, err) + assert.Equal(t, task.StateDone, child.State) + + parentTaskToResume, err := taskutils.ShouldResumeParentTask(dbp, child) + require.Nil(t, err) + if i == len(taskIDs)-1 { + // Only the last child task should resume the parent + require.NotNil(t, parentTaskToResume) + assert.Equal(t, res.TaskID, parentTaskToResume.ID) + } else { + require.Nil(t, parentTaskToResume) + } + } + } + + // checking if the parent task is picked up after that the subtask is resolved. + // need to sleep a bit because the parent task is resumed asynchronously + ti := time.Second + i := time.Duration(0) + for i < ti { + res, err = resolution.LoadFromPublicID(dbp, res.PublicID) + require.Nil(t, err) + if res.State != resolution.StateWaiting { + break + } + + time.Sleep(time.Millisecond * 10) + i += time.Millisecond * 10 + } + + ti = time.Second + i = time.Duration(0) + for i < ti { + res, err = resolution.LoadFromPublicID(dbp, res.PublicID) + require.Nil(t, err) + if res.State != resolution.StateRunning { + break + } + + time.Sleep(time.Millisecond * 10) + i += time.Millisecond * 10 + + } + assert.Equal(t, resolution.StateDone, res.State) +} diff --git a/engine/templates_tests/batch.yaml b/engine/templates_tests/batch.yaml new file mode 100644 index 00000000..1cae43fb --- /dev/null +++ b/engine/templates_tests/batch.yaml @@ -0,0 +1,26 @@ +name: batchTemplate +description: Template to test the batch plugin +title_format: "[test] batch template test" + +steps: + batchJsonInputs: + description: Batching tasks JSON + action: + type: batch + configuration: + template_name: 
batchedtasktemplate + json_inputs: '[{"specific_string": "specific-1"}, {"specific_string": "specific-2"}]' + common_json_inputs: '{"common_string": "common"}' + sub_batch_size: 2 + batchYamlInputs: + description: Batching tasks YAML + action: + type: batch + configuration: + template_name: batchedtasktemplate + inputs: + - specific_string: specific-1 + - specific_string: specific-2 + common_inputs: + common_string: common + sub_batch_size: 2 diff --git a/engine/templates_tests/batchedTask.yaml b/engine/templates_tests/batchedTask.yaml new file mode 100644 index 00000000..d25b65d1 --- /dev/null +++ b/engine/templates_tests/batchedTask.yaml @@ -0,0 +1,23 @@ +name: batchedTaskTemplate +description: Template made to be spawned by the testing batch plugin +title_format: "[test] batched task template" + +inputs: + - name: specific_string + description: A string specific to this task + type: string + - name: common_string + description: A string common to all tasks in the same batch + type: string + +steps: + simpleStep: + description: Simple step + action: + type: echo + configuration: + output: >- + { + "specific": "{{.input.specific_string}}", + "common": "{{.input.common_string}}" + } \ No newline at end of file diff --git a/pkg/batch/batch.go b/pkg/batch/batch.go new file mode 100644 index 00000000..b878bae9 --- /dev/null +++ b/pkg/batch/batch.go @@ -0,0 +1,77 @@ +package batch + +import ( + "context" + + "github.com/juju/errors" + "github.com/loopfz/gadgeto/zesty" + + "github.com/ovh/utask/models/task" + "github.com/ovh/utask/models/tasktemplate" + "github.com/ovh/utask/pkg/taskutils" +) + +// TaskArgs holds arguments needed to create tasks in a batch +type TaskArgs struct { + TemplateName string // Mandatory + Inputs []map[string]interface{} // Mandatory + CommonInput map[string]interface{} // Optional + Comment string // Optional + WatcherUsernames []string // Optional + WatcherGroups []string // Optional + Tags map[string]string // Optional +} + +// Populate 
creates and adds new tasks to a given batch. +// All tasks share a common batchID which can be used as a listing filter. +// The [constants.SubtaskTagParentTaskID] tag can be set in the Tags to link the newly created tasks to another +// existing task, making it the parent of the batch. A parent task is resumed everytime a child task finishes. +func Populate(ctx context.Context, batch *task.Batch, dbp zesty.DBProvider, args TaskArgs) ([]string, error) { + tt, err := tasktemplate.LoadFromName(dbp, args.TemplateName) + if err != nil { + return nil, err + } + + taskIDs := make([]string, 0, len(args.Inputs)) + for _, inp := range args.Inputs { + input, err := mergeMaps(args.CommonInput, inp) + if err != nil { + return nil, err + } + + t, err := taskutils.CreateTask( + ctx, + dbp, + tt, + args.WatcherUsernames, + args.WatcherGroups, + []string{}, + []string{}, + input, + batch, + args.Comment, + nil, + args.Tags, + ) + if err != nil { + return nil, err + } + taskIDs = append(taskIDs, t.PublicID) + } + return taskIDs, nil +} + +func mergeMaps(common, particular map[string]interface{}) (map[string]interface{}, error) { + merged := make(map[string]interface{}, len(common)+len(particular)) + for key, value := range particular { + merged[key] = value + } + + for key, value := range common { + if _, ok := merged[key]; ok { + return nil, errors.NewBadRequest(nil, "Conflicting keys in input maps") + } + merged[key] = value + } + return merged, nil +} diff --git a/pkg/batch/batch_test.go b/pkg/batch/batch_test.go new file mode 100644 index 00000000..230ec31d --- /dev/null +++ b/pkg/batch/batch_test.go @@ -0,0 +1,97 @@ +package batch + +import ( + "context" + "encoding/json" + "testing" + + "github.com/juju/errors" + "github.com/loopfz/gadgeto/zesty" + "github.com/ovh/configstore" + "github.com/stretchr/testify/assert" + + "github.com/ovh/utask" + "github.com/ovh/utask/db" + "github.com/ovh/utask/engine/input" + "github.com/ovh/utask/engine/step" + 
"github.com/ovh/utask/engine/step/executor" + "github.com/ovh/utask/models/task" + "github.com/ovh/utask/models/tasktemplate" +) + +func TestPopulate(t *testing.T) { + store := configstore.DefaultStore + store.InitFromEnvironment() + + if err := db.Init(store); err != nil { + panic(err) + } + + dbp, err := zesty.NewDBProvider(utask.DBName) + if err != nil { + t.Fatal(err) + } + + tmpl, err := tasktemplate.LoadFromName(dbp, dummyTemplate.Name) + if err != nil { + if !errors.IsNotFound(err) { + t.Fatal(err) + } + tmpl = &dummyTemplate + if err := dbp.DB().Insert(tmpl); err != nil { + t.Fatal(err) + } + } + + b, err := task.CreateBatch(dbp) + if err != nil { + t.Fatal(err) + } + + batchArgs := TaskArgs{ + TemplateName: tmpl.Name, + Inputs: []map[string]any{{"id": "dummyID-1"}, {"id": "dummyID-2"}, {"id": "dummyID-3"}}, + } + + taskIDs, err := Populate(context.Background(), b, dbp, batchArgs) + if err != nil { + t.Fatal(err) + } + + // Making sure we returned as many IDs as tasks we created + assert.Len(t, taskIDs, len(batchArgs.Inputs)) + + tasks, err := task.ListTasks(dbp, task.ListFilter{Batch: b}) + if err != nil { + t.Fatal(err) + } + + // Making sure the right number of tasks was created in the batch + assert.Len(t, taskIDs, len(batchArgs.Inputs)) + + for i, childTask := range tasks { + assert.Equal(t, batchArgs.Inputs[i]["id"], childTask.Title) + } + +} + +var dummyTemplate = tasktemplate.TaskTemplate{ + Name: "dummy-template", + Description: "does nothing", + TitleFormat: "this task does nothing at all", + Inputs: []input.Input{ + { + Name: "id", + }, + }, + Steps: map[string]*step.Step{ + "step": { + Action: executor.Executor{ + Type: "echo", + Configuration: json.RawMessage(`{ + "output": {"foo":"bar"} + }`), + }, + }, + }, +} diff --git a/pkg/batchutils/batchutils.go b/pkg/batchutils/batchutils.go new file mode 100644 index 00000000..0838e40d --- /dev/null +++ b/pkg/batchutils/batchutils.go @@ -0,0 +1,28 @@ +package batchutils + +import ( + 
"github.com/Masterminds/squirrel" + "github.com/loopfz/gadgeto/zesty" + + "github.com/ovh/utask/db/sqlgenerator" + "github.com/ovh/utask/models/task" +) + +// FinalStates hold the states in which a task won't ever be run again +var FinalStates = []string{task.StateDone, task.StateCancelled, task.StateWontfix} + +// RunningTasks returns the amount of running tasks sharing the same given batchId. +func RunningTasks(dbp zesty.DBProvider, batchID int64) (int64, error) { + query, params, err := sqlgenerator.PGsql. + Select("count (*)"). + From("task t"). + Join("batch b on b.id = t.id_batch"). + Where(squirrel.Eq{"b.id": batchID}). + Where(squirrel.NotEq{"t.state": FinalStates}). + ToSql() + if err != nil { + return -1, err + } + + return dbp.DB().SelectInt(query, params...) +} diff --git a/pkg/batchutils/batchutils_test.go b/pkg/batchutils/batchutils_test.go new file mode 100644 index 00000000..b5ce91da --- /dev/null +++ b/pkg/batchutils/batchutils_test.go @@ -0,0 +1,125 @@ +package batchutils_test + +import ( + "encoding/json" + "fmt" + "testing" + + "github.com/juju/errors" + "github.com/loopfz/gadgeto/zesty" + "github.com/ovh/configstore" + "github.com/stretchr/testify/assert" + + "github.com/ovh/utask" + "github.com/ovh/utask/db" + "github.com/ovh/utask/engine/input" + "github.com/ovh/utask/engine/step" + "github.com/ovh/utask/engine/step/executor" + "github.com/ovh/utask/models/task" + "github.com/ovh/utask/models/tasktemplate" + "github.com/ovh/utask/pkg/batchutils" +) + +func TestRunningTasks(t *testing.T) { + store := configstore.DefaultStore + store.InitFromEnvironment() + + if err := db.Init(store); err != nil { + panic(err) + } + + dbp, err := zesty.NewDBProvider(utask.DBName) + if err != nil { + t.Fatal(err) + } + + const batchSize int = 10 + batchID, tasks := createBatch(t, batchSize, dbp) + + // Making sure that created tasks running + running, err := batchutils.RunningTasks(dbp, batchID) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, 
int64(len(tasks)), running) + + // Setting a final state to some tasks in the batch (one per final state) + for i, state := range batchutils.FinalStates { + tasks[i].SetState(state) + if err := tasks[i].Update(dbp, false, false); err != nil { + t.Fatal(err) + } + } + + // Making sure that tasks in final states aren't counted + running, err = batchutils.RunningTasks(dbp, batchID) + if err != nil { + t.Fatal(err) + } + expectedRunning := int64(len(tasks) - len(batchutils.FinalStates)) + assert.Equal(t, expectedRunning, running) +} + +func createBatch(t *testing.T, amount int, dbp zesty.DBProvider) (int64, []*task.Task) { + tmpl, err := tasktemplate.LoadFromName(dbp, dummyTemplate.Name) + if err != nil { + if !errors.IsNotFound(err) { + t.Fatal(err) + } + tmpl = &dummyTemplate + if err := dbp.DB().Insert(tmpl); err != nil { + t.Fatal(err) + } + } + + b, err := task.CreateBatch(dbp) + if err != nil { + t.Fatal(err) + } + + tasks := make([]*task.Task, 0, amount) + for i := 0; i < amount; i++ { + // Manually populating the batch to prevent cyclic imports + newTask, err := task.Create( + dbp, + tmpl, + "", + nil, + nil, + nil, + nil, + nil, + map[string]any{"id": fmt.Sprintf("dummyID-%d", i)}, + nil, + b, + false, + ) + if err != nil { + t.Fatal(err) + } + tasks = append(tasks, newTask) + } + + return b.ID, tasks +} + +var dummyTemplate = tasktemplate.TaskTemplate{ + Name: "dummy-template", + Description: "does nothing", + TitleFormat: "this task does nothing at all", + Inputs: []input.Input{ + { + Name: "id", + }, + }, + Steps: map[string]*step.Step{ + "step": { + Action: executor.Executor{ + Type: "echo", + Configuration: json.RawMessage(`{ + "output": {"foo":"bar"} + }`), + }, + }, + }, +} diff --git a/pkg/plugins/builtin/batch/README.md b/pkg/plugins/builtin/batch/README.md new file mode 100644 index 00000000..594dd6ab --- /dev/null +++ b/pkg/plugins/builtin/batch/README.md @@ -0,0 +1,77 @@ +# `batch` Plugin + +This plugin creates a batch of tasks based on the same 
template and waits for it to complete. It acts like the `subtask` combined with a `foreach`, but doesn't modify the resolution by adding new steps dynamically. As it makes fewer calls to the underlying database, this plugin is suited for large batches of tasks, where the `subtask` / `foreach` combination would usually struggle, especially by bloating the database. +Tasks belonging to the same batch share a common `BatchID` as well as a tag holding their parent's ID. + +##### Remarks: +The output of child tasks is not made available in this plugin's output. This feature will come later. + +## Configuration + +| Fields | Description | +|----------------------|-------------------------------------------------------------------------------------------------------------------| +| `template_name` | the name of a task template, as accepted through µTask's API | +| `inputs` | a list of mapped key/value, as accepted on µTask's API. Each element represents the input of an individual task | +| `json_inputs` | same as `inputs`, but as a JSON string. If specified, it overrides `inputs` | +| `common_inputs` | a map of named values, as accepted on µTask's API, given to all tasks in the batch by combining it with each input | +| `common_json_inputs` | same as `common_inputs` but as a JSON string. If specified, it overrides `common_inputs` | +| `tags` | a map of named strings added as tags when creating child tasks | +| `sub_batch_size` | the number of tasks to create and run at once. `0` for infinity (i.e.: all tasks are created at once and waited for) (default). 
Higher values reduce the number of calls made to the database, but increase sensitivity to database unavailability (if a task creation fails, the whole sub batch must be created again) | +| `comment` | a string set as `comment` when creating child tasks | +| `resolver_usernames` | a string containing a JSON array of additional resolver users for child tasks | +| `resolver_groups` | a string containing a JSON array of additional resolver groups for child tasks | +| `watcher_usernames` | a string containing a JSON array of additional watcher users for child tasks | +| `watcher_groups` | a string containing a JSON array of additional watcher groups for child tasks | + +## Example + +An action of type `batch` requires the following kind of configuration: + +```yaml +action: + type: batch + configuration: + # [Required] + # A template that must already be registered on this instance of µTask + template_name: some-task-template + # Valid inputs, as defined by the referred template, here requiring 3 inputs: foo, otherFoo and fooCommon + inputs: + - foo: bar-1 + otherFoo: otherBar-1 + - foo: bar-2 + otherFoo: otherBar-2 + - foo: bar-3 + otherFoo: otherBar-3 + # [Optional] + common_inputs: + fooCommon: barCommon + # Some tags added to all child tasks + tags: + fooTag: value-of-foo-tag + barTag: value-of-bar-tag + # The amount of tasks to run at once + sub_batch_size: 2 + # A list of users which are authorized to resolve this specific task + resolver_usernames: '["authorizedUser"]' + resolver_groups: '["authorizedGroup"]' + watcher_usernames: '["authorizedUser"]' + watcher_groups: '["authorizedGroup"]' +``` + +## Requirements + +None. + +## Return + +### Output + +None. 
+ +### Metadata + +| Name | Description | +|----------------------|-------------------------------------------| +| `batch_id` | The public identifier of the batch | +| `remaining_tasks` | How many tasks still need to complete | +| `tasks_started` | How many tasks were started so far | diff --git a/pkg/plugins/builtin/batch/batch.go b/pkg/plugins/builtin/batch/batch.go new file mode 100644 index 00000000..16194cfb --- /dev/null +++ b/pkg/plugins/builtin/batch/batch.go @@ -0,0 +1,368 @@ +package pluginbatch + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + jujuErrors "github.com/juju/errors" + "github.com/loopfz/gadgeto/zesty" + "github.com/sirupsen/logrus" + + "github.com/ovh/utask" + "github.com/ovh/utask/models/resolution" + "github.com/ovh/utask/models/task" + "github.com/ovh/utask/models/tasktemplate" + "github.com/ovh/utask/pkg/auth" + "github.com/ovh/utask/pkg/batch" + "github.com/ovh/utask/pkg/batchutils" + "github.com/ovh/utask/pkg/constants" + "github.com/ovh/utask/pkg/plugins/taskplugin" + "github.com/ovh/utask/pkg/templateimport" + "github.com/ovh/utask/pkg/utils" +) + +// The batch plugin spawns X new µTask tasks, given a template and inputs, and waits for them to be completed. 
+// Resolver usernames can be dynamically set for the task +var Plugin = taskplugin.New( + "batch", + "0.1", + exec, + taskplugin.WithConfig(validateConfigBatch, BatchConfig{}), + taskplugin.WithContextFunc(ctxBatch), +) + +// BatchConfig is the necessary configuration to spawn a new task +type BatchConfig struct { + TemplateName string `json:"template_name" binding:"required"` + CommonInputs map[string]interface{} `json:"common_inputs"` + CommonJSONInputs string `json:"common_json_inputs"` + Inputs []map[string]interface{} `json:"inputs"` + JSONInputs string `json:"json_inputs"` + Comment string `json:"comment"` + WatcherUsernames []string `json:"watcher_usernames"` + WatcherGroups []string `json:"watcher_groups"` + Tags map[string]string `json:"tags"` + ResolverUsernames string `json:"resolver_usernames"` + ResolverGroups string `json:"resolver_groups"` + // How many tasks will run concurrently. 0 for infinity (default) + SubBatchSize int `json:"sub_batch_size"` +} + +// quotedString is a string with doubly escaped quotes, so the string stays simply escaped after being processed +// as the plugin's context (see ctxBatch). +type quotedString string + +// BatchContext holds data about the parent task execution as well as the metadata of previous runs, if any. +type BatchContext struct { + ParentTaskID string `json:"parent_task_id"` + RequesterUsername string `json:"requester_username"` + RequesterGroups string `json:"requester_groups"` + // RawMetadata of the previous run. Metadata are used to communicate batch progress between runs. It's returned + // "as is" in case something goes wrong in a subsequent run, to know what the batch's progress was when the + // error occured. + RawMetadata quotedString `json:"metadata"` + // Unmarshalled version of the metadata + metadata BatchMetadata + StepName string `json:"step_name"` +} + +// BatchMetadata holds batch-progress data, communicated between each run of the plugin. 
+type BatchMetadata struct {
+	// BatchID is the public identifier of the batch spawned by the step.
+	BatchID string `json:"batch_id"`
+	// RemainingTasks is how many tasks of the batch still need to complete.
+	RemainingTasks int64 `json:"remaining_tasks"`
+	// TasksStarted is how many tasks were started so far.
+	TasksStarted int64 `json:"tasks_started"`
+}
+
+// ctxBatch builds the context of the step named stepName. Every field except StepName is a template expression
+// pointing either at the parent task or at the metadata stored for this step by its previous run.
+func ctxBatch(stepName string) interface{} {
+	return &BatchContext{
+		ParentTaskID:      "{{ .task.task_id }}",
+		RequesterUsername: "{{.task.requester_username}}",
+		RequesterGroups:   "{{ if .task.requester_groups }}{{ .task.requester_groups }}{{ end }}",
+		RawMetadata: quotedString(fmt.Sprintf(
+			"{{ if (index .step `%s` ) }}{{ if (index .step `%s` `metadata`) }}{{ index .step `%s` `metadata` }}{{ end }}{{ end }}",
+			stepName,
+			stepName,
+			stepName,
+		)),
+		StepName: stepName,
+	}
+}
+
+// validateConfigBatch checks the tags of the configuration and makes sure the template it refers to exists, either in
+// the database or among the templates returned by templateimport.GetTemplates.
+func validateConfigBatch(config any) error {
+	conf := config.(*BatchConfig)
+
+	if err := utils.ValidateTags(conf.Tags); err != nil {
+		return err
+	}
+
+	dbp, err := zesty.NewDBProvider(utask.DBName)
+	if err != nil {
+		return fmt.Errorf("can't retrieve connection to DB: %s", err)
+	}
+
+	_, err = tasktemplate.LoadFromName(dbp, conf.TemplateName)
+	if err != nil {
+		if !jujuErrors.IsNotFound(err) {
+			return fmt.Errorf("can't load template from name: %s", err)
+		}
+
+		// The template isn't in the database: searching into currently imported templates
+		templates := templateimport.GetTemplates()
+		for _, template := range templates {
+			if template == conf.TemplateName {
+				return nil
+			}
+		}
+
+		return jujuErrors.NotFoundf("batch template %q", conf.TemplateName)
+	}
+
+	return nil
+}
+
+// exec runs the batch step. The first run creates the batch along with its first tasks; subsequent runs either add
+// new tasks to the batch or check whether all of them are done. The step has no output: the second returned value is
+// the formatted batch metadata, or the metadata of the previous run when an error occurs after the batch was started.
+func exec(stepName string, config any, ictx any) (any, any, error) {
+	var metadata BatchMetadata
+	var stepError error
+
+	conf := config.(*BatchConfig)
+	batchCtx := ictx.(*BatchContext)
+	if err := parseInputs(conf, batchCtx); err != nil {
+		return nil, batchCtx.RawMetadata.Format(), err
+	}
+
+	// Every child task is tagged with the ID of the task running this step
+	if conf.Tags == nil {
+		conf.Tags = make(map[string]string)
+	}
+	conf.Tags[constants.SubtaskTagParentTaskID] = batchCtx.ParentTaskID
+
+	ctx := auth.WithIdentity(context.Background(), batchCtx.RequesterUsername)
+	requesterGroups := strings.Split(batchCtx.RequesterGroups, utask.GroupsSeparator)
+	ctx = auth.WithGroups(ctx, requesterGroups)
+
+	dbp, err := zesty.NewDBProvider(utask.DBName)
+	if err != nil {
+		return nil, batchCtx.RawMetadata.Format(), err
+	}
+
+	// From here on, every error path must roll the transaction back before returning
+	if err := dbp.Tx(); err != nil {
+		return nil, batchCtx.RawMetadata.Format(), err
+	}
+
+	if batchCtx.metadata.BatchID == "" {
+		// The batch needs to be started
+		metadata, err = startBatch(ctx, dbp, conf, batchCtx)
+		if err != nil {
+			_ = dbp.Rollback()
+			return nil, nil, err
+		}
+
+		// A step returning a NotAssigned error is set to WAITING by the engine
+		stepError = jujuErrors.NewNotAssigned(fmt.Errorf("tasks from batch %q will start shortly", metadata.BatchID), "")
+	} else {
+		// Batch already started, we either need to start new tasks or check whether they're all done
+		metadata, err = runBatch(ctx, conf, batchCtx, dbp)
+		if err != nil {
+			_ = dbp.Rollback()
+			return nil, batchCtx.RawMetadata.Format(), err
+		}
+
+		if metadata.RemainingTasks != 0 {
+			// A step returning a NotAssigned error is set to WAITING by the engine
+			stepError = jujuErrors.NewNotAssigned(fmt.Errorf("batch %q is currently RUNNING", metadata.BatchID), "")
+		} else {
+			// The batch is done.
+			// We increase the resolution's maximum amount of retries to compensate for the amount of runs consumed
+			// by child tasks waking up the parent when they're done.
+			if err := increaseRunMax(dbp, batchCtx.ParentTaskID, batchCtx.StepName); err != nil {
+				// The transaction is still open at this point: release it like every other error path does
+				_ = dbp.Rollback()
+				return nil, batchCtx.RawMetadata.Format(), err
+			}
+		}
+	}
+
+	formattedMetadata, err := formatOutput(metadata)
+	if err != nil {
+		_ = dbp.Rollback()
+		return nil, batchCtx.RawMetadata.Format(), err
+	}
+
+	if err := dbp.Commit(); err != nil {
+		_ = dbp.Rollback()
+		return nil, batchCtx.RawMetadata.Format(), err
+	}
+	return nil, formattedMetadata, stepError
+}
+
+// startBatch creates a new batch and populates it with its first tasks, as described by the given conf.
+func startBatch(
+	ctx context.Context,
+	dbp zesty.DBProvider,
+	conf *BatchConfig,
+	batchCtx *BatchContext,
+) (BatchMetadata, error) {
+	b, err := task.CreateBatch(dbp)
+	if err != nil {
+		return BatchMetadata{}, err
+	}
+
+	taskIDs, err := populateBatch(ctx, b, dbp, conf, batchCtx)
+	if err != nil {
+		return BatchMetadata{}, err
+	}
+
+	// Every input counts as remaining at this point, including those whose task hasn't been created yet
+	return BatchMetadata{
+		BatchID:        b.PublicID,
+		RemainingTasks: int64(len(conf.Inputs)),
+		TasksStarted:   int64(len(taskIDs)),
+	}, nil
+}
+
+// populateBatch spawns new tasks in the batch and returns their public identifier.
+// It resumes from the inputs that weren't consumed by previous runs (batchCtx.metadata.TasksStarted) and starts at
+// most as many tasks as there are free slots in the sub-batch. A SubBatchSize of 0 means no limit: all the remaining
+// tasks are started at once.
+func populateBatch(
+	ctx context.Context,
+	b *task.Batch,
+	dbp zesty.DBProvider,
+	conf *BatchConfig,
+	batchCtx *BatchContext,
+) ([]string, error) {
+	tasksStarted := batchCtx.metadata.TasksStarted
+	running, err := batchutils.RunningTasks(dbp, b.ID)
+	if err != nil {
+		return []string{}, err
+	}
+
+	// Computing how many tasks to start
+	remaining := int64(len(conf.Inputs)) - tasksStarted
+	toStart := int64(conf.SubBatchSize) - running // How many tasks can be started
+	if conf.SubBatchSize == 0 || remaining < toStart {
+		// Either the sub-batch size is unlimited, or there are fewer tasks left to start than available running slots
+		toStart = remaining
+	}
+	if toStart < 0 {
+		// More tasks are running than the sub-batch allows: start nothing rather than slicing out of range below
+		toStart = 0
+	}
+
+	args := batch.TaskArgs{
+		TemplateName:     conf.TemplateName,
+		CommonInput:      conf.CommonInputs,
+		Inputs:           conf.Inputs[tasksStarted : tasksStarted+toStart],
+		Comment:          conf.Comment,
+		WatcherGroups:    conf.WatcherGroups,
+		WatcherUsernames: conf.WatcherUsernames,
+		Tags:             conf.Tags,
+	}
+
+	taskIDs, err := batch.Populate(ctx, b, dbp, args)
+	if err != nil {
+		return []string{}, err
+	}
+
+	return taskIDs, nil
+}
+
+// runBatch runs a batch, spawning new tasks if needed and checking whether they're all done.
+func runBatch(
+	ctx context.Context,
+	conf *BatchConfig,
+	batchCtx *BatchContext,
+	dbp zesty.DBProvider,
+) (BatchMetadata, error) {
+	// Progress is computed from the metadata left by the previous run
+	metadata := batchCtx.metadata
+
+	b, err := task.LoadBatchFromPublicID(dbp, metadata.BatchID)
+	if err != nil {
+		if jujuErrors.IsNotFound(err) {
+			// The batch has been collected (deleted in DB) because no remaining task referenced it. There's
+			// nothing more to do: report that no task remains, otherwise the step would be set to WAITING again
+			// with no child task left to wake the parent up.
+			metadata.RemainingTasks = 0
+			return metadata, nil
+		}
+		return metadata, err
+	}
+
+	if metadata.TasksStarted < int64(len(conf.Inputs)) {
+		// New tasks still need to be added to the batch
+
+		taskIDs, err := populateBatch(ctx, b, dbp, conf, batchCtx)
+		if err != nil {
+			return metadata, err
+		}
+
+		started := int64(len(taskIDs))
+		metadata.TasksStarted += started
+		metadata.RemainingTasks -= started // Starting X tasks means that X tasks became DONE
+		return metadata, nil
+	}
+	// else, all tasks are started, we need to wait for the last ones to become DONE
+
+	running, err := batchutils.RunningTasks(dbp, b.ID)
+	if err != nil {
+		return metadata, err
+	}
+	metadata.RemainingTasks = running
+	return metadata, nil
+}
+
+// increaseRunMax increases the maximum amount of runs of the resolution matching the given parentTaskID by the run
+// count of the given batchStepName.
+// Since child tasks wake their parent up when they're done, the resolution's RunCount gets incremented every time. We
+// compensate this by increasing the RunMax property once the batch is done.
+func increaseRunMax(dbp zesty.DBProvider, parentTaskID string, batchStepName string) error { + t, err := task.LoadFromPublicID(dbp, parentTaskID) + if err != nil { + return err + } + + if t.Resolution != nil { + return fmt.Errorf("resolution not found for step '%s' of task '%s'", batchStepName, parentTaskID) + } + + res, err := resolution.LoadLockedFromPublicID(dbp, *t.Resolution) + if err != nil { + return err + } + + step, ok := res.Steps[batchStepName] + if !ok { + return fmt.Errorf("step '%s' not found in resolution", batchStepName) + } + + res.ExtendRunMax(step.TryCount) + return res.Update(dbp) +} + +// parseInputs parses the step's inputs as well as metadata from the previous run (if it exists). +func parseInputs(conf *BatchConfig, batchCtx *BatchContext) error { + if batchCtx.RawMetadata != "" { + // Metadata from a previous run is available + if err := json.Unmarshal([]byte(batchCtx.RawMetadata), &batchCtx.metadata); err != nil { + return jujuErrors.NewBadRequest(err, "metadata unmarshalling failure") + } + } + + if conf.CommonJSONInputs != "" { + if err := json.Unmarshal([]byte(conf.CommonJSONInputs), &conf.CommonInputs); err != nil { + return jujuErrors.NewBadRequest(err, "JSON common input unmarshalling failure") + } + } + + if conf.JSONInputs != "" { + if err := json.Unmarshal([]byte(conf.JSONInputs), &conf.Inputs); err != nil { + return jujuErrors.NewBadRequest(err, "JSON inputs unmarshalling failure") + } + } + return nil +} + +// Format formats the utaskString to make sure it's parsable by subsequent runs of the plugin (i.e.: escaping +// double quotes). +func (rm quotedString) Format() string { + return strings.ReplaceAll(string(rm), `"`, `\"`) +} + +// formatOutput formats an output (plugin output or metadata) as a uTask-friendly output. 
+func formatOutput(result any) (string, error) { + marshalled, err := json.Marshal(result) + if err != nil { + logrus.WithError(err).Error("Couldn't marshal batch metadata") + return "", err + } + return quotedString(marshalled).Format(), nil +} diff --git a/pkg/plugins/builtin/builtin.go b/pkg/plugins/builtin/builtin.go index 0585967f..dbe2ae1f 100644 --- a/pkg/plugins/builtin/builtin.go +++ b/pkg/plugins/builtin/builtin.go @@ -4,6 +4,7 @@ import ( "github.com/ovh/utask/engine/step" "github.com/ovh/utask/pkg/plugins" pluginapiovh "github.com/ovh/utask/pkg/plugins/builtin/apiovh" + pluginbatch "github.com/ovh/utask/pkg/plugins/builtin/batch" plugincallback "github.com/ovh/utask/pkg/plugins/builtin/callback" pluginecho "github.com/ovh/utask/pkg/plugins/builtin/echo" pluginemail "github.com/ovh/utask/pkg/plugins/builtin/email" @@ -43,6 +44,7 @@ func Register() error { pluginscript.Plugin, plugintag.Plugin, plugincallback.Plugin, + pluginbatch.Plugin, } { if err := step.RegisterRunner(p.PluginName(), p); err != nil { return err diff --git a/pkg/taskutils/taskutils.go b/pkg/taskutils/taskutils.go index 57d20e3d..6540762c 100644 --- a/pkg/taskutils/taskutils.go +++ b/pkg/taskutils/taskutils.go @@ -11,6 +11,7 @@ import ( "github.com/ovh/utask/models/task" "github.com/ovh/utask/models/tasktemplate" "github.com/ovh/utask/pkg/auth" + "github.com/ovh/utask/pkg/batchutils" "github.com/ovh/utask/pkg/constants" ) @@ -88,6 +89,24 @@ func ShouldResumeParentTask(dbp zesty.DBProvider, t *task.Task) (*task.Task, err return nil, nil } + if t.BatchID != nil { + // The task belongs to a batch. If all sibling tasks are done, the parent can be awaken. + + // Note on race conditions: + // When two sibling tasks complete, they either complete at the very same time or with a delay. In the former + // case, two attempts to resume the parent may be triggered, but a DB lock already prevents the parent + // from being run twice at the same time. 
In the latter case, no race condition exists since one finished before + // the other. + running, err := batchutils.RunningTasks(dbp, *t.BatchID) + if err != nil { + return nil, err + } + if running != 0 { + // Some sibling tasks are still running, no need to resume the parent yet + return nil, nil + } + } + parentTask, err := task.LoadFromPublicID(dbp, parentTaskID) if err != nil { return nil, err