diff --git a/pkg/annotations.go b/pkg/annotations/constants.go similarity index 50% rename from pkg/annotations.go rename to pkg/annotations/constants.go index 51c39a9df..6f8092145 100644 --- a/pkg/annotations.go +++ b/pkg/annotations/constants.go @@ -1,18 +1,18 @@ /* - Copyright 2010 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Copyright 2010 Amazon.com, Inc. or its affiliates. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"). - You may not use this file except in compliance with the License. - A copy of the License is located at +Licensed under the Apache License, Version 2.0 (the "License"). +You may not use this file except in compliance with the License. +A copy of the License is located at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 - or in the "license" file accompanying this file. This file is distributed - on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - express or implied. See the License for the specific language governing - permissions and limitations under the License. +or in the "license" file accompanying this file. This file is distributed +on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +express or implied. See the License for the specific language governing +permissions and limitations under the License. */ -package pkg +package annotations const ( // The audience annotation diff --git a/pkg/annotations/parser.go b/pkg/annotations/parser.go new file mode 100644 index 000000000..e68f9e9dc --- /dev/null +++ b/pkg/annotations/parser.go @@ -0,0 +1,94 @@ +/* +Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"). +You may not use this file except in compliance with the License. +A copy of the License is located at + + http://www.apache.org/licenses/LICENSE-2.0 + +or in the "license" file accompanying this file. This file is distributed +on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +express or implied. See the License for the specific language governing +permissions and limitations under the License. +*/ + +package annotations + +import ( + "encoding/csv" + "strconv" + "strings" + + "github.com/aws/amazon-eks-pod-identity-webhook/pkg" + corev1 "k8s.io/api/core/v1" + "k8s.io/klog/v2" +) + +type PodAnnotations struct { + tokenExpiration *int64 + containersToSkip map[string]bool +} + +func (a *PodAnnotations) GetContainersToSkip() map[string]bool { + return a.containersToSkip +} + +func (a *PodAnnotations) GetTokenExpiration(fallback int64) int64 { + if a.tokenExpiration == nil { + return fallback + } else { + return *a.tokenExpiration + } +} + +// parsePodAnnotations parses the pod annotations that can influence mutation: +// - tokenExpiration. Overrides the given service account annotation/flag-level +// setting. +// - containersToSkip. A Pod specific setting since certain containers within a +// specific pod might need to be opted-out of mutation +func ParsePodAnnotations(pod *corev1.Pod, annotationDomain string) *PodAnnotations { + return &PodAnnotations{ + tokenExpiration: parseTokenExpiration(annotationDomain, pod), + containersToSkip: parseContainersToSkip(annotationDomain, pod), + } +} + +// parseContainersToSkip returns the containers of a pod to skip mutating +func parseContainersToSkip(annotationDomain string, pod *corev1.Pod) map[string]bool { + skippedNames := map[string]bool{} + skipContainersKey := annotationDomain + "/" + SkipContainersAnnotation + + value, ok := pod.Annotations[skipContainersKey] + if !ok { + return nil + } + r := csv.NewReader(strings.NewReader(value)) + // error means we don't skip any + podNames, err := r.Read() + if err != nil { + klog.Infof("Could not parse skip containers annotation on pod %s/%s: %v", pod.Namespace, pod.Name, err) + return nil + } + for _, name := range podNames { + skippedNames[name] = true + } + return skippedNames +} + +func parseTokenExpiration(annotationDomain string, pod *corev1.Pod) *int64 { + expirationKey := annotationDomain + "/" + TokenExpirationAnnotation + expirationStr, ok := pod.Annotations[expirationKey] + if !ok { + return nil + } + + expiration, err := strconv.ParseInt(expirationStr, 10, 64) + if err != nil { + klog.V(4).Infof("Found invalid value for token expiration on the pod annotation: %s, falling back to the default: %v", expirationStr, err) + return nil + } + + val := pkg.ValidateMinTokenExpiration(expiration) + return &val +} diff --git a/pkg/annotations/parser_test.go b/pkg/annotations/parser_test.go new file mode 100644 index 000000000..1bd50a555 --- /dev/null +++ b/pkg/annotations/parser_test.go @@ -0,0 +1,99 @@ +/* +Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"). +You may not use this file except in compliance with the License. +A copy of the License is located at + + http://www.apache.org/licenses/LICENSE-2.0 + +or in the "license" file accompanying this file. This file is distributed +on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +express or implied. See the License for the specific language governing +permissions and limitations under the License. +*/ + +package annotations + +import ( + "testing" + + "github.com/stretchr/testify/assert" + corev1 "k8s.io/api/core/v1" + "sigs.k8s.io/yaml" +) + +func TestParsePodAnnotations(t *testing.T) { + podNoAnnotations := ` + apiVersion: v1 + kind: Pod + metadata: + name: balajilovesoreos` + testcases := []struct { + name string + pod string + + expectedContainersToSkip map[string]bool + + fallbackExpiration int64 + expectedExpiration int64 + }{ + { + name: "sidecar-containers", + pod: ` + apiVersion: v1 + kind: Pod + metadata: + name: balajilovesoreos + annotations: + testing.eks.amazonaws.com/skip-containers: "sidecar,foo" + `, + expectedContainersToSkip: map[string]bool{"sidecar": true, "foo": true}, + }, + { + name: "token-expiration", + pod: ` + apiVersion: v1 + kind: Pod + metadata: + name: balajilovesoreos + annotations: + testing.eks.amazonaws.com/token-expiration: "1234" + `, + fallbackExpiration: 4567, + expectedExpiration: 1234, + }, + { + name: "token-expiration fallback", + pod: podNoAnnotations, + fallbackExpiration: 4567, + expectedExpiration: 4567, + }, + { + name: "token-expiration round up to min value", + pod: ` + apiVersion: v1 + kind: Pod + metadata: + name: balajilovesoreos + annotations: + testing.eks.amazonaws.com/token-expiration: "0" + `, + fallbackExpiration: 4567, + expectedExpiration: 600, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + var pod *corev1.Pod + + err := yaml.Unmarshal([]byte(tc.pod), &pod) + assert.NoError(t, err) + + actual := ParsePodAnnotations(pod, "testing.eks.amazonaws.com") + assert.Equal(t, tc.expectedContainersToSkip, actual.GetContainersToSkip()) + assert.Equal(t, tc.expectedExpiration, actual.GetTokenExpiration(tc.fallbackExpiration)) + }) + } +} diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go index 55b5885f7..bae7b87c3 100644 --- a/pkg/cache/cache.go +++ b/pkg/cache/cache.go @@ -24,6 +24,7 @@ import ( "sync" "github.com/aws/amazon-eks-pod-identity-webhook/pkg" + "github.com/aws/amazon-eks-pod-identity-webhook/pkg/annotations" "github.com/prometheus/client_golang/prometheus" v1 "k8s.io/api/core/v1" utilruntime "k8s.io/apimachinery/pkg/util/runtime" @@ -212,7 +213,7 @@ func (c *serviceAccountCache) ToJSON() string { func (c *serviceAccountCache) addSA(sa *v1.ServiceAccount) { entry := &Entry{} - arn, ok := sa.Annotations[c.annotationPrefix+"/"+pkg.RoleARNAnnotation] + arn, ok := sa.Annotations[c.annotationPrefix+"/"+annotations.RoleARNAnnotation] if ok { if !strings.Contains(arn, "arn:") && c.composeRoleArn.Enabled { arn = fmt.Sprintf("arn:%s:iam::%s:role/%s", c.composeRoleArn.Partition, c.composeRoleArn.AccountID, arn) @@ -228,12 +229,12 @@ func (c *serviceAccountCache) addSA(sa *v1.ServiceAccount) { } entry.Audience = c.defaultAudience - if audience, ok := sa.Annotations[c.annotationPrefix+"/"+pkg.AudienceAnnotation]; ok { + if audience, ok := sa.Annotations[c.annotationPrefix+"/"+annotations.AudienceAnnotation]; ok { entry.Audience = audience } entry.UseRegionalSTS = c.defaultRegionalSTS - if useRegionalStr, ok := sa.Annotations[c.annotationPrefix+"/"+pkg.UseRegionalSTSAnnotation]; ok { + if useRegionalStr, ok := sa.Annotations[c.annotationPrefix+"/"+annotations.UseRegionalSTSAnnotation]; ok { useRegional, err := strconv.ParseBool(useRegionalStr) if err != nil { klog.V(4).Infof("Ignoring service account %s/%s invalid value for disable-regional-sts annotation", sa.Namespace, sa.Name) @@ -243,7 +244,7 @@ func (c *serviceAccountCache) addSA(sa *v1.ServiceAccount) { } entry.TokenExpiration = c.defaultTokenExpiration - if tokenExpirationStr, ok := sa.Annotations[c.annotationPrefix+"/"+pkg.TokenExpirationAnnotation]; ok { + if tokenExpirationStr, ok := sa.Annotations[c.annotationPrefix+"/"+annotations.TokenExpirationAnnotation]; ok { if tokenExpiration, err := strconv.ParseInt(tokenExpirationStr, 10, 64); err != nil { klog.V(4).Infof("Found invalid value for token expiration, using %d seconds as default: %v", entry.TokenExpiration, err) } else { diff --git a/pkg/handler/handler.go b/pkg/handler/handler.go index 31f852654..31f3f9c95 100644 --- a/pkg/handler/handler.go +++ b/pkg/handler/handler.go @@ -16,16 +16,15 @@ package handler import ( - "encoding/csv" "encoding/json" "fmt" "io/ioutil" "net/http" "path/filepath" - "strconv" "strings" "time" + "github.com/aws/amazon-eks-pod-identity-webhook/pkg/annotations" "github.com/aws/amazon-eks-pod-identity-webhook/pkg/containercredentials" "github.com/aws/amazon-eks-pod-identity-webhook/pkg" @@ -143,25 +142,6 @@ func logContext(podName, podGenerateName, serviceAccountName, namespace string) "Namespace=%s", name, serviceAccountName, namespace) } -// getContainersToSkip returns the containers of a pod to skip mutating -func getContainersToSkip(annotationDomain string, pod *corev1.Pod) map[string]bool { - skippedNames := map[string]bool{} - skipContainersKey := annotationDomain + "/" + pkg.SkipContainersAnnotation - if value, ok := pod.Annotations[skipContainersKey]; ok { - r := csv.NewReader(strings.NewReader(value)) - // error means we don't skip any - podNames, err := r.Read() - if err != nil { - klog.Infof("Could not parse skip containers annotation on pod %s/%s: %v", pod.Namespace, pod.Name, err) - return skippedNames - } - for _, name := range podNames { - skippedNames[name] = true - } - } - return skippedNames -} - func (m *Modifier) addEnvToContainer(container *corev1.Container, tokenFilePath string, patchConfig *podPatchConfig) bool { var ( webIdentityKeysDefined bool @@ -277,29 +257,6 @@ func (m *Modifier) addEnvToContainer(container *corev1.Container, tokenFilePath return changed } -// parsePodAnnotations parses the pod annotations that can influence mutation: -// - tokenExpiration. Overrides the given service account annotation/flag-level -// setting. -// - containersToSkip. A Pod specific setting since certain containers within a -// specific pod might need to be opted-out of mutation -func (m *Modifier) parsePodAnnotations(pod *corev1.Pod, serviceAccountTokenExpiration int64) (int64, map[string]bool) { - // override serviceaccount annotation/flag token expiration with pod - // annotation if present - tokenExpiration := serviceAccountTokenExpiration - expirationKey := m.AnnotationDomain + "/" + pkg.TokenExpirationAnnotation - if expirationStr, ok := pod.Annotations[expirationKey]; ok { - if expiration, err := strconv.ParseInt(expirationStr, 10, 64); err != nil { - klog.V(4).Infof("Found invalid value for token expiration, using %d seconds as default: %v", serviceAccountTokenExpiration, err) - } else { - tokenExpiration = pkg.ValidateMinTokenExpiration(expiration) - } - } - - containersToSkip := getContainersToSkip(m.AnnotationDomain, pod) - - return tokenExpiration, containersToSkip -} - // getPodSpecPatch gets the patch operation to be applied to the given Pod func (m *Modifier) getPodSpecPatch(pod *corev1.Pod, patchConfig *podPatchConfig) ([]patchOperation, bool) { tokenFilePath := filepath.Join(patchConfig.MountPath, patchConfig.TokenPath) @@ -413,15 +370,16 @@ func (m *Modifier) getPodSpecPatch(pod *corev1.Pod, patchConfig *podPatchConfig) func (m *Modifier) buildPodPatchConfig(pod *corev1.Pod) *podPatchConfig { // Container credentials method takes precedence containerCredentialsPatchConfig := m.ContainerCredentialsConfig.Get(pod.Namespace, pod.Spec.ServiceAccountName) + + podAnnotations := annotations.ParsePodAnnotations(pod, m.AnnotationDomain) if containerCredentialsPatchConfig != nil { - regionalSTS, tokenExpiration := m.Cache.GetCommonConfigurations(pod.Spec.ServiceAccountName, pod.Namespace) - tokenExpiration, containersToSkip := m.parsePodAnnotations(pod, tokenExpiration) + regionalSTS, tokenExpirationFallback := m.Cache.GetCommonConfigurations(pod.Spec.ServiceAccountName, pod.Namespace) webhookPodCount.WithLabelValues("container_credentials").Inc() return &podPatchConfig{ - ContainersToSkip: containersToSkip, - TokenExpiration: tokenExpiration, + ContainersToSkip: podAnnotations.GetContainersToSkip(), + TokenExpiration: podAnnotations.GetTokenExpiration(tokenExpirationFallback), UseRegionalSTS: regionalSTS, Audience: containerCredentialsPatchConfig.Audience, MountPath: containerCredentialsPatchConfig.MountPath, @@ -452,13 +410,12 @@ func (m *Modifier) buildPodPatchConfig(pod *corev1.Pod) *podPatchConfig { } klog.V(5).Infof("Value of roleArn after after cache retrieval for service account %s: %s", request.CacheKey(), response.RoleARN) if response.RoleARN != "" { - tokenExpiration, containersToSkip := m.parsePodAnnotations(pod, response.TokenExpiration) webhookPodCount.WithLabelValues("sts_web_identity").Inc() return &podPatchConfig{ - ContainersToSkip: containersToSkip, - TokenExpiration: tokenExpiration, + ContainersToSkip: podAnnotations.GetContainersToSkip(), + TokenExpiration: podAnnotations.GetTokenExpiration(response.TokenExpiration), UseRegionalSTS: response.UseRegionalSTS, Audience: response.Audience, MountPath: m.MountPath, diff --git a/pkg/handler/handler_pod_test.go b/pkg/handler/handler_pod_test.go index 2296b2661..2c0d6c6c6 100644 --- a/pkg/handler/handler_pod_test.go +++ b/pkg/handler/handler_pod_test.go @@ -19,14 +19,15 @@ import ( "bytes" "encoding/json" "fmt" - "github.com/aws/amazon-eks-pod-identity-webhook/pkg/cache" - "github.com/aws/amazon-eks-pod-identity-webhook/pkg/containercredentials" "os" "path/filepath" "strconv" "strings" "testing" + "github.com/aws/amazon-eks-pod-identity-webhook/pkg/cache" + "github.com/aws/amazon-eks-pod-identity-webhook/pkg/containercredentials" + corev1 "k8s.io/api/core/v1" "sigs.k8s.io/yaml" ) @@ -140,42 +141,42 @@ func TestUpdatePodSpec(t *testing.T) { if info.IsDir() { return nil } - if strings.HasSuffix(info.Name(), ".yaml") || strings.HasSuffix(info.Name(), ".yml") { - pod, err := parseFile(filepath.Join("./", path)) + if !strings.HasSuffix(info.Name(), ".yaml") && !strings.HasSuffix(info.Name(), ".yml") { + return nil + } + pod, err := parseFile(filepath.Join("./", path)) + if err != nil { + t.Errorf("Error while parsing file %s: %v", info.Name(), err) + return err + } + if skipStr, ok := pod.Annotations[skipAnnotation]; ok { + skip, _ := strconv.ParseBool(skipStr) + if skip { + return nil + } + } + + pod.Namespace = "default" + pod.Spec.ServiceAccountName = "default" + + t.Run(fmt.Sprintf("Pod %s in file %s", pod.Name, path), func(t *testing.T) { + modifier := buildModifierFromPod(pod) + patchConfig := modifier.buildPodPatchConfig(pod) + patch, _ := modifier.getPodSpecPatch(pod, patchConfig) + patchBytes, err := json.Marshal(patch) if err != nil { - t.Errorf("Error while parsing file %s: %v", info.Name(), err) - return err + t.Errorf("Unexpected error: %v", err) } - if skipStr, ok := pod.Annotations[skipAnnotation]; ok { - skip, _ := strconv.ParseBool(skipStr) - if skip { - return nil - } + expectedPatchStr, ok := pod.Annotations[expectedPatchAnnotation] + if !ok && (len(patchBytes) == 0 || patchBytes == nil) { + return } - pod.Namespace = "default" - pod.Spec.ServiceAccountName = "default" - - t.Run(fmt.Sprintf("Pod %s in file %s", pod.Name, path), func(t *testing.T) { - modifier := buildModifierFromPod(pod) - patchConfig := modifier.buildPodPatchConfig(pod) - patch, _ := modifier.getPodSpecPatch(pod, patchConfig) - patchBytes, err := json.Marshal(patch) - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - expectedPatchStr, ok := pod.Annotations[expectedPatchAnnotation] - if !ok && (len(patchBytes) == 0 || patchBytes == nil) { - return - } - - if bytes.Compare(patchBytes, []byte(expectedPatchStr)) != 0 { - t.Errorf("Expected patch didn't match: \nGot\n\t%v\nWanted:\n\t%v\n", string(patchBytes), expectedPatchStr) - } - - }) - return nil - } + if bytes.Compare(patchBytes, []byte(expectedPatchStr)) != 0 { + t.Errorf("Expected patch didn't match: \nGot\n\t%v\nWanted:\n\t%v\n", string(patchBytes), expectedPatchStr) + } + + }) return nil }) if err != nil {