diff --git a/awscmd/ecs.go b/awscmd/ecs.go index 9f95d7f..dc9afa3 100644 --- a/awscmd/ecs.go +++ b/awscmd/ecs.go @@ -9,6 +9,7 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/ecs" + "github.com/aws/aws-sdk-go/service/sts" "github.com/wzshiming/ctc" ) @@ -31,6 +32,12 @@ func EcsDeploy(ctx context.Context, input *InputEcsDeploy, w io.Writer) (*OuputE return nil, err } svc := ecs.New(sess) + identity := sts.New(sess) + + caller, err := identity.GetCallerIdentity(&sts.GetCallerIdentityInput{}) + if err != nil { + return nil, err + } // 1. fetch running task fmt.Fprintf(w, "Fetching service definition\n") @@ -56,49 +63,32 @@ func EcsDeploy(ctx context.Context, input *InputEcsDeploy, w io.Writer) (*OuputE taskDefinition = s.TaskDefinition } - // 2. build new task definition - fmt.Fprintf(w, "Fetching task definition\n") - taskDefinitionOut, err := svc.DescribeTaskDefinitionWithContext(ctx, &ecs.DescribeTaskDefinitionInput{ - TaskDefinition: taskDefinition, - }) - - containerDefinitions := []*ecs.ContainerDefinition{} - for _, container := range taskDefinitionOut.TaskDefinition.ContainerDefinitions { - container.Image = aws.String(input.DockerImage) - containerDefinitions = append(containerDefinitions, container) - } - - // 3. register new task revision - fmt.Fprintf(w, "Registering new task revision with new docker image\n") - registerOut, err := svc.RegisterTaskDefinitionWithContext(ctx, &ecs.RegisterTaskDefinitionInput{ - ContainerDefinitions: containerDefinitions, - Cpu: taskDefinitionOut.TaskDefinition.Cpu, - EphemeralStorage: taskDefinitionOut.TaskDefinition.EphemeralStorage, - ExecutionRoleArn: taskDefinitionOut.TaskDefinition.ExecutionRoleArn, - Family: taskDefinitionOut.TaskDefinition.Family, - InferenceAccelerators: taskDefinitionOut.TaskDefinition.InferenceAccelerators, - IpcMode: taskDefinitionOut.TaskDefinition.IpcMode, - Memory: taskDefinitionOut.TaskDefinition.Memory, - NetworkMode: taskDefinitionOut.TaskDefinition.NetworkMode, - PidMode: taskDefinitionOut.TaskDefinition.PidMode, - PlacementConstraints: taskDefinitionOut.TaskDefinition.PlacementConstraints, - ProxyConfiguration: taskDefinitionOut.TaskDefinition.ProxyConfiguration, - RequiresCompatibilities: taskDefinitionOut.TaskDefinition.RequiresCompatibilities, - RuntimePlatform: taskDefinitionOut.TaskDefinition.RuntimePlatform, - TaskRoleArn: taskDefinitionOut.TaskDefinition.TaskRoleArn, - Volumes: taskDefinitionOut.TaskDefinition.Volumes, - }) + fmt.Fprintf(w, "Registering new task [%s]\n", idOneOff(input.Service, "")) + newTask, err := ecsRegisterTaskDefinition(svc, ctx, &inputRegisterTaskDefinition{ + TaskDefinitionOrARN: *taskDefinition, + DockerImage: input.DockerImage, + }, w) if err != nil { - return nil, fmt.Errorf("Failed to register new task revision: %w", err) + return nil, err + } + for _, cmd := range input.OneOffs { + fmt.Fprintf(w, "Registering new one-off task [%s]\n", idOneOff(input.Service, cmd)) + taskDef := fmt.Sprintf("arn:aws:ecs:%s:%s:task-definition/%s", input.Region, *caller.Account, idOneOff(input.Service, cmd)) + _, err := ecsRegisterTaskDefinition(svc, ctx, &inputRegisterTaskDefinition{ + TaskDefinitionOrARN: taskDef, + DockerImage: input.DockerImage, + }, w) + if err != nil { + return nil, err + } } - arn := registerOut.TaskDefinition.TaskDefinitionArn // 4. update ecs service with new task arn fmt.Fprintf(w, "Updating service\n") updateOut, err := svc.UpdateServiceWithContext(ctx, &ecs.UpdateServiceInput{ Cluster: aws.String(input.Cluster), Service: aws.String(input.Service), - TaskDefinition: arn, + TaskDefinition: aws.String(newTask.ARN), }) if err != nil { return nil, fmt.Errorf("Failed to update service with new task revision: %w", err) @@ -230,3 +220,62 @@ type OuputEcsRunTask struct { ARN string ID string } + +func ecsRegisterTaskDefinition(svc *ecs.ECS, ctx context.Context, input *inputRegisterTaskDefinition, w io.Writer) (*outputRegisterTaskDefinition, error) { + fmt.Fprintf(w, " Fetching task definition\n") + taskDefinitionOut, err := svc.DescribeTaskDefinitionWithContext(ctx, &ecs.DescribeTaskDefinitionInput{ + TaskDefinition: aws.String(input.TaskDefinitionOrARN), + }) + + containerDefinitions := []*ecs.ContainerDefinition{} + for _, container := range taskDefinitionOut.TaskDefinition.ContainerDefinitions { + container.Image = aws.String(input.DockerImage) + containerDefinitions = append(containerDefinitions, container) + } + + // 3. register new task revision + fmt.Fprintf(w, " Registering new task revision with new docker image\n") + registerOut, err := svc.RegisterTaskDefinitionWithContext(ctx, &ecs.RegisterTaskDefinitionInput{ + ContainerDefinitions: containerDefinitions, + Cpu: taskDefinitionOut.TaskDefinition.Cpu, + EphemeralStorage: taskDefinitionOut.TaskDefinition.EphemeralStorage, + ExecutionRoleArn: taskDefinitionOut.TaskDefinition.ExecutionRoleArn, + Family: taskDefinitionOut.TaskDefinition.Family, + InferenceAccelerators: taskDefinitionOut.TaskDefinition.InferenceAccelerators, + IpcMode: taskDefinitionOut.TaskDefinition.IpcMode, + Memory: taskDefinitionOut.TaskDefinition.Memory, + NetworkMode: taskDefinitionOut.TaskDefinition.NetworkMode, + PidMode: taskDefinitionOut.TaskDefinition.PidMode, + PlacementConstraints: taskDefinitionOut.TaskDefinition.PlacementConstraints, + ProxyConfiguration: taskDefinitionOut.TaskDefinition.ProxyConfiguration, + RequiresCompatibilities: taskDefinitionOut.TaskDefinition.RequiresCompatibilities, + RuntimePlatform: taskDefinitionOut.TaskDefinition.RuntimePlatform, + TaskRoleArn: taskDefinitionOut.TaskDefinition.TaskRoleArn, + Volumes: taskDefinitionOut.TaskDefinition.Volumes, + }) + if err != nil { + return nil, fmt.Errorf("Failed to register new task revision: %w", err) + } + + return &outputRegisterTaskDefinition{ + ARN: *registerOut.TaskDefinition.TaskDefinitionArn, + }, nil +} + +type inputRegisterTaskDefinition struct { + TaskDefinitionOrARN string + DockerImage string +} + +type outputRegisterTaskDefinition struct { + ARN string +} + +// convention used in terraform modules +func idOneOff(service string, command string) string { + if command == "" { + return service + } else { + return fmt.Sprint(service, "-", command) + } +} diff --git a/main.go b/main.go index 3c8f6ec..6149651 100644 --- a/main.go +++ b/main.go @@ -152,6 +152,7 @@ func main() { &cli.StringFlag{Name: "cluster", Usage: "ECS cluster ID", Required: true}, &cli.StringFlag{Name: "service", Usage: "ECS service ID", Required: true}, &cli.StringFlag{Name: "docker-image", Usage: "Docker image to replace task definition with", Required: true}, + &cli.StringSliceFlag{Name: "one-off", Usage: "One-off commands (multiple use of flag allowed)", Required: false}, }, Action: func(c *cli.Context) error { input := &awscmd.InputEcsDeploy{ @@ -159,6 +160,7 @@ func main() { Cluster: c.String("cluster"), Service: c.String("service"), DockerImage: c.String("docker-image"), + OneOffs: c.StringSlice("one-off"), } out, err := awscmd.EcsDeploy(context.TODO(), input, c.App.Writer) if out != nil { @@ -203,14 +205,14 @@ func main() { &cli.StringFlag{Name: "region", Usage: "AWS region", Required: true}, &cli.StringFlag{Name: "cluster", Usage: "ECS cluster ID", Required: true}, &cli.StringFlag{Name: "service", Usage: "ECS service ID", Required: true}, - &cli.StringFlag{Name: "command", Usage: "One-off command to run", Required: true}, + &cli.StringFlag{Name: "one-off", Usage: "One-off command to run", Required: true}, }, Action: func(c *cli.Context) error { runTaskInput := &awscmd.InputEcsRunTask{ Region: c.String("region"), Cluster: c.String("cluster"), Service: c.String("service"), - OneOffCommand: c.String("command"), + OneOffCommand: c.String("one-off"), } out, err := awscmd.EcsRunTask(context.TODO(), runTaskInput, c.App.Writer) if err != nil {