Skip to content

Commit

Permalink
Add support for --profile for selecting IAM credentials (#131)
Browse files Browse the repository at this point in the history
  • Loading branch information
stshan authored and sanathkr committed Sep 20, 2017
1 parent 8624aae commit b3101e5
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 20 deletions.
14 changes: 8 additions & 6 deletions env.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ import (
This priority also applies to AWS_* system variables
*/

func getEnvironmentVariables(logicalID string, function *cloudformation.AWSServerlessFunction, overrideFile string) map[string]string {
func getEnvironmentVariables(logicalID string, function *cloudformation.AWSServerlessFunction, overrideFile string, profile string) map[string]string {

env := getEnvDefaults(function)
env := getEnvDefaults(function, profile)
osenv := getEnvFromOS()
overrides := getEnvOverrides(logicalID, overrideFile)

Expand Down Expand Up @@ -69,9 +69,9 @@ func getEnvironmentVariables(logicalID string, function *cloudformation.AWSServe

}

func getEnvDefaults(function *cloudformation.AWSServerlessFunction) map[string]string {
func getEnvDefaults(function *cloudformation.AWSServerlessFunction, profile string) map[string]string {

creds := getSessionOrDefaultCreds()
creds := getSessionOrDefaultCreds(profile)

// Variables available in Lambda execution environment for all functions (AWS_* variables)
env := map[string]string{
Expand Down Expand Up @@ -123,7 +123,7 @@ func getEnvOverrides(logicalID string, filename string) map[string]string {

}

func getSessionOrDefaultCreds() map[string]string {
func getSessionOrDefaultCreds(profile string) map[string]string {

region := "us-east-1"
key := "defaultkey"
Expand All @@ -135,8 +135,10 @@ func getSessionOrDefaultCreds() map[string]string {
"secret": secret,
}

opts := session.Options{}
opts.Profile = profile
// Obtain AWS credentials and pass them through to the container runtime via env variables
if sess, err := session.NewSession(); err == nil {
if sess, err := session.NewSessionWithOptions(opts); err == nil {
if creds, err := sess.Config.Credentials.Get(); err == nil {
if *sess.Config.Region != "" {
result["region"] = *sess.Config.Region
Expand Down
12 changes: 6 additions & 6 deletions env_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ var _ = Describe("Environment Variables", func() {
It("return defaults with those defined in the template", func() {

for name, function := range functions {
variables := getEnvironmentVariables(name, &function, "")
variables := getEnvironmentVariables(name, &function, "", "")
Expect(variables).To(HaveLen(9))
Expect(variables).To(HaveKey("AWS_SAM_LOCAL"))
Expect(variables).To(HaveKey("AWS_REGION"))
Expand All @@ -46,7 +46,7 @@ var _ = Describe("Environment Variables", func() {
os.Setenv("AWS_SECRET_ACCESS_KEY", "secret")
os.Unsetenv("AWS_SESSION_TOKEN")

variables := getEnvironmentVariables(name, &function, "")
variables := getEnvironmentVariables(name, &function, "", "")
Expect(variables).To(HaveLen(9))
Expect(variables).To(HaveKey("AWS_SAM_LOCAL"))
Expect(variables).To(HaveKey("AWS_REGION"))
Expand All @@ -72,7 +72,7 @@ var _ = Describe("Environment Variables", func() {
os.Setenv("AWS_SECRET_ACCESS_KEY", "secret")
os.Setenv("AWS_SESSION_TOKEN", "token")

variables := getEnvironmentVariables(name, &function, "")
variables := getEnvironmentVariables(name, &function, "", "")
Expect(variables).To(HaveLen(10))
Expect(variables).To(HaveKey("AWS_SAM_LOCAL"))
Expect(variables).To(HaveKey("AWS_REGION"))
Expand All @@ -96,19 +96,19 @@ var _ = Describe("Environment Variables", func() {

It("overides template with environment variables", func() {
for name, function := range functions {
variables := getEnvironmentVariables(name, &function, "")
variables := getEnvironmentVariables(name, &function, "", "")
Expect(variables["TABLE_NAME"]).To(Equal(""))

os.Setenv("TABLE_NAME", "ENV_TABLE")
variables = getEnvironmentVariables(name, &function, "")
variables = getEnvironmentVariables(name, &function, "", "")
Expect(variables["TABLE_NAME"]).To(Equal("ENV_TABLE"))
os.Unsetenv("TABLE_NAME")
}
})

It("overrides template and environment with customer overrides", func() {
for name, function := range functions {
variables := getEnvironmentVariables(name, &function, "test/environment-overrides.json")
variables := getEnvironmentVariables(name, &function, "test/environment-overrides.json", "")
Expect(variables["TABLE_NAME"]).To(Equal("OVERRIDE_TABLE"))
}
os.Unsetenv("TABLE_NAME")
Expand Down
2 changes: 1 addition & 1 deletion invoke.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func invoke(c *cli.Context) {
event = string(pb)
}

stdoutTxt, stderrTxt, err := runt.Invoke(event)
stdoutTxt, stderrTxt, err := runt.Invoke(event, c.String("profile"))
if err != nil {
log.Fatalf("Could not invoke function: %s\n", err)
}
Expand Down
8 changes: 8 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,10 @@ func main() {
Usage: "Optional. Specify whether SAM should skip pulling down the latest Docker image. Default is false.",
EnvVar: "SAM_SKIP_PULL_IMAGE",
},
cli.StringFlag{
Name: "profile",
Usage: "Optional. Specify which AWS credentials profile to use.",
},
},
},
cli.Command{
Expand Down Expand Up @@ -171,6 +175,10 @@ func main() {
Usage: "Optional. Specify whether SAM should skip pulling down the latest Docker image. Default is false.",
EnvVar: "SAM_SKIP_PULL_IMAGE",
},
cli.StringFlag{
Name: "profile",
Usage: "Optional. Specify which AWS credentials profile to use.",
},
},
},
cli.Command{
Expand Down
12 changes: 6 additions & 6 deletions runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ import (

// Invoker is a simple interface to help with testing runtimes
type Invoker interface {
Invoke(string) (io.Reader, io.Reader, error)
InvokeHTTP() func(http.ResponseWriter, *http.Request)
Invoke(string, string) (io.Reader, io.Reader, error)
InvokeHTTP(string) func(http.ResponseWriter, *http.Request)
CleanUp()
}

Expand Down Expand Up @@ -253,7 +253,7 @@ func (r *Runtime) getHostConfig() (*container.HostConfig, error) {
// Invoke runs a Lambda function within the runtime with the provided event
// payload and returns a pair of io.Readers for it's stdout (callback results)
// and stderr (runtime logs).
func (r *Runtime) Invoke(event string) (io.Reader, io.Reader, error) {
func (r *Runtime) Invoke(event string, profile string) (io.Reader, io.Reader, error) {

log.Printf("Invoking %s (%s)\n", r.Function.Handler, r.Name)

Expand Down Expand Up @@ -293,7 +293,7 @@ func (r *Runtime) Invoke(event string) (io.Reader, io.Reader, error) {
Cmd: []string{r.Function.Handler, event},
Env: func() []string {
result := []string{}
for k, v := range getEnvironmentVariables(r.LogicalID, &r.Function, r.EnvOverrideFile) {
for k, v := range getEnvironmentVariables(r.LogicalID, &r.Function, r.EnvOverrideFile, profile) {
result = append(result, k+"="+v)
}
return result
Expand Down Expand Up @@ -477,7 +477,7 @@ func (r *Runtime) CleanUp() {

// InvokeHTTP invokes a Lambda function, and implements the Go http.HandlerFunc interface
// so it can be connected straight into most HTTP packages/frameworks etc.
func (r *Runtime) InvokeHTTP() func(http.ResponseWriter, *http.Request) {
func (r *Runtime) InvokeHTTP(profile string) func(http.ResponseWriter, *http.Request) {

return func(w http.ResponseWriter, req *http.Request) {
var wg sync.WaitGroup
Expand All @@ -501,7 +501,7 @@ func (r *Runtime) InvokeHTTP() func(http.ResponseWriter, *http.Request) {
return
}

stdoutTxt, stderrTxt, err := r.Invoke(eventJSON)
stdoutTxt, stderrTxt, err := r.Invoke(eventJSON, profile)
if err != nil {
msg := fmt.Sprintf("Error invoking %s runtime: %s", r.Function.Runtime, err)
log.Println(msg)
Expand Down
2 changes: 1 addition & 1 deletion start.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func start(c *cli.Context) {
}

// Add this AWS::Serverless::Function to the HTTP router
if err := mux.AddFunction(&function, runt.InvokeHTTP()); err != nil {
if err := mux.AddFunction(&function, runt.InvokeHTTP(c.String("profile"))); err != nil {
if err == router.ErrNoEventsFound {
log.Printf("Ignoring %s (%s) as no API event sources are defined", name, function.Handler)
}
Expand Down

0 comments on commit b3101e5

Please sign in to comment.