Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ec2 provider #587

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
Prev Previous commit
Next Next commit
Adding network wait for ssh
jonhenrik13 committed Jan 23, 2019
commit 424eb0a3c3bda64775ced438d2a81fb9ed9b470c
90 changes: 81 additions & 9 deletions backend/ec2.go
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@ package backend
import (
"fmt"
"io"
"net"
"net/url"
"strconv"
"strings"
@@ -58,6 +59,8 @@ func init() {
"UPLOAD_RETRIES": fmt.Sprintf("number of times to attempt to upload script before erroring (default %d)", defaultEC2UploadRetries),
"UPLOAD_RETRY_SLEEP": fmt.Sprintf("sleep interval between script upload attempts (default %v)", defaultEC2UploadRetrySleep),
"SECURITY_GROUPS": "Security groups to assign",
"PUBLIC_IP": "boot job instances with a public ip, disable this for NAT (default true)",
"PUBLIC_IP_CONNECT": "connect to the public ip of the instance instead of the internal, only takes effect if PUBLIC_IP is true (default true)",
}, newEC2Provider)
}

@@ -74,6 +77,9 @@ type ec2Provider struct {
uploadRetries uint64
uploadRetrySleep time.Duration
sshDialTimeout time.Duration
publicIP bool
publicIPConnect bool
subnetID string
}

func newEC2Provider(cfg *config.ProviderConfig) (Provider, error) {
@@ -139,6 +145,11 @@ func newEC2Provider(cfg *config.ProviderConfig) (Provider, error) {
instanceType = cfg.Get("INSTANCE_TYPE")
}

subnetID := ""
if cfg.IsSet("SUBNET_ID") {
subnetID = cfg.Get("SUBNET_ID")
}

defaultImage := defaultEC2Image
if cfg.IsSet("DEFAULT_IMAGE") {
defaultImage = cfg.Get("DEFAULT_IMAGE")
@@ -180,6 +191,16 @@ func newEC2Provider(cfg *config.ProviderConfig) (Provider, error) {
uploadRetrySleep = si
}

publicIP := true
if cfg.IsSet("PUBLIC_IP") {
publicIP = asBool(cfg.Get("PUBLIC_IP"))
}

publicIPConnect := true
if cfg.IsSet("PUBLIC_IP_CONNECT") {
publicIPConnect = asBool(cfg.Get("PUBLIC_IP_CONNECT"))
}

return &ec2Provider{
cfg: cfg,
sshDialTimeout: sshDialTimeout,
@@ -193,6 +214,9 @@ func newEC2Provider(cfg *config.ProviderConfig) (Provider, error) {
diskSize: diskSize,
uploadRetries: uploadRetries,
uploadRetrySleep: uploadRetrySleep,
publicIP: publicIP,
publicIPConnect: publicIPConnect,
subnetID: subnetID,
}, nil
}

@@ -273,12 +297,12 @@ func (p *ec2Provider) Start(ctx gocontext.Context, startAttributes *StartAttribu
}

runOpts := &ec2.RunInstancesInput{
ImageId: aws.String(imageID),
InstanceType: aws.String(p.instanceType),
MaxCount: aws.Int64(1),
MinCount: aws.Int64(1),
KeyName: keyResp.KeyName,
SecurityGroupIds: securityGroups,
ImageId: aws.String(imageID),
InstanceType: aws.String(p.instanceType),
MaxCount: aws.Int64(1),
MinCount: aws.Int64(1),
KeyName: keyResp.KeyName,

CreditSpecification: &ec2.CreditSpecificationRequest{
CpuCredits: aws.String("unlimited"), // TODO:
},
@@ -296,6 +320,22 @@ func (p *ec2Provider) Start(ctx gocontext.Context, startAttributes *StartAttribu
},
},
}

if p.subnetID != "" && p.publicIP {
runOpts.NetworkInterfaces = []*ec2.InstanceNetworkInterfaceSpecification{
{
DeviceIndex: aws.Int64(0),
AssociatePublicIpAddress: &p.publicIP,
SubnetId: aws.String(p.subnetID),
Groups: securityGroups,
DeleteOnTermination: aws.Bool(true),
},
}
} else {
runOpts.SubnetId = aws.String(p.subnetID)
runOpts.SecurityGroupIds = securityGroups
}

reservation, err := svc.RunInstances(runOpts)

if err != nil {
@@ -317,7 +357,10 @@ func (p *ec2Provider) Start(ctx gocontext.Context, startAttributes *StartAttribu
}
instance = instances.Reservations[0].Instances[0]
if instances != nil {
address := *instance.PublicDnsName
address := *instance.PrivateDnsName
if p.publicIPConnect {
address = *instance.PublicDnsName
}
if address != "" {
break
}
@@ -394,6 +437,28 @@ func (i *ec2Instance) UploadScript(ctx gocontext.Context, script []byte) error {
//return i.uploadScriptAttempt(ctx, script)
}

func (i *ec2Instance) waitForSSH(port, timeout int) error {

host := *i.instance.PrivateIpAddress
if i.provider.publicIPConnect {
host = *i.instance.PublicIpAddress
}

iter := 0
for {
_, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%d", host, port), 1*time.Second)
if err == nil {
break
}
iter = iter + 1
if iter > timeout {
return err
}
time.Sleep(500 * time.Millisecond)
}
return nil
}

func (i *ec2Instance) uploadScriptAttempt(ctx gocontext.Context, script []byte) error {
return i.uploadScriptSCP(ctx, script)
}
@@ -416,7 +481,11 @@ func (i *ec2Instance) uploadScriptSCP(ctx gocontext.Context, script []byte) erro
}

func (i *ec2Instance) sshConnection(ctx gocontext.Context) (ssh.Connection, error) {
return i.sshDialer.Dial(fmt.Sprintf("%s:22", *i.instance.PublicDnsName), defaultEC2SSHUserName, i.provider.sshDialTimeout)
ip := *i.instance.PrivateIpAddress
if i.provider.publicIPConnect {
ip = *i.instance.PublicIpAddress
}
return i.sshDialer.Dial(fmt.Sprintf("%s:22", ip), defaultEC2SSHUserName, i.provider.sshDialTimeout)
}

func (i *ec2Instance) RunScript(ctx gocontext.Context, output io.Writer) (*RunResult, error) {
@@ -483,7 +552,10 @@ func (i *ec2Instance) Warmed() bool {
}

func (i *ec2Instance) ID() string {
return *i.instance.PublicDnsName
if i.provider.publicIP {
return *i.instance.PublicDnsName
}
return *i.instance.PrivateDnsName
}

func (i *ec2Instance) ImageName() string {