From 10213db23ef90e27d0f5cfea944d39c34a2438bb Mon Sep 17 00:00:00 2001 From: neogopher Date: Fri, 9 Feb 2024 17:33:06 +0530 Subject: [PATCH 1/2] Update hostpathmapper to work with multi namespace mode. - When multiNamespace mode is detected, the localmanager cache's listwatch will watch all namespaces instead of just the target namespace. - In mapHostPaths, only pods belonging to namespaces managed by the current vcluster will be allowed to be processed. --- .github/workflows/lint.yaml | 4 +- Dockerfile | 2 +- cmd/hostpaths/hostpaths.go | 195 +++++++++++------- go.mod | 2 +- go.sum | 18 ++ .../resources/namespaces/syncer.go | 125 +++++++++++ .../resources/namespaces/translate.go | 41 ++++ vendor/modules.txt | 1 + 8 files changed, 304 insertions(+), 84 deletions(-) create mode 100644 vendor/github.com/loft-sh/vcluster/pkg/controllers/resources/namespaces/syncer.go create mode 100644 vendor/github.com/loft-sh/vcluster/pkg/controllers/resources/namespaces/translate.go diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index c90887f2..77a8d252 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -16,11 +16,11 @@ jobs: name: lint runs-on: ubuntu-latest steps: + - uses: actions/checkout@v4 - uses: actions/setup-go@v4 with: - go-version: "1.20" + go-version-file: "go.mod" cache: false - - uses: actions/checkout@v3 - name: Run golangci-lint uses: golangci/golangci-lint-action@v3 with: diff --git a/Dockerfile b/Dockerfile index 1e88ea03..1f21362f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ # Build the manager binary -FROM golang:1.20 as builder +FROM golang:1.22 as builder WORKDIR /vcluster-hpm-dev ARG TARGETOS diff --git a/cmd/hostpaths/hostpaths.go b/cmd/hostpaths/hostpaths.go index a81debd3..498a234d 100644 --- a/cmd/hostpaths/hostpaths.go +++ b/cmd/hostpaths/hostpaths.go @@ -9,6 +9,7 @@ import ( "strings" "time" + "github.com/loft-sh/vcluster/pkg/controllers/resources/namespaces" podtranslate "github.com/loft-sh/vcluster/pkg/controllers/resources/pods/translate" "github.com/loft-sh/vcluster/pkg/util/clienthelper" @@ -18,13 +19,12 @@ import ( "github.com/loft-sh/vcluster/pkg/util/translate" "github.com/pkg/errors" "github.com/spf13/cobra" - appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" kerrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/fields" + "k8s.io/apimachinery/pkg/labels" "k8s.io/apimachinery/pkg/runtime" - "k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/util/wait" "k8s.io/client-go/kubernetes" "k8s.io/client-go/rest" @@ -148,7 +148,7 @@ func Start(ctx context.Context, options *context2.VirtualClusterOptions, init bo kubeClient, err := kubernetes.NewForConfig(virtualClusterConfig) if err != nil { - return false, errors.Wrap(err, "create kube client") + return false, fmt.Errorf("create kube client: %w", err) } _, err = kubeClient.Discovery().ServerVersion() @@ -168,13 +168,17 @@ func Start(ctx context.Context, options *context2.VirtualClusterOptions, init bo return err } - localManager, err := ctrl.NewManager(inClusterConfig, ctrl.Options{ - Scheme: scheme, - MetricsBindAddress: "0", - LeaderElection: false, - Namespace: options.TargetNamespace, - NewClient: pluginhookclient.NewPhysicalPluginClientFactory(blockingcacheclient.NewCacheClient), - }) + kubeClient, err := kubernetes.NewForConfig(inClusterConfig) + if err != nil { + return fmt.Errorf("create kube client: %w", err) + } + + err = findVclusterModeAndSetDefaultTranslation(ctx, kubeClient, options) + if err != nil { + return fmt.Errorf("find vcluster mode: %w", err) + } + + localManager, err := ctrl.NewManager(inClusterConfig, localManagerCtrlOptions(options)) if err != nil { return err } @@ -193,11 +197,6 @@ func Start(ctx context.Context, options *context2.VirtualClusterOptions, init bo startManagers(ctx, localManager, virtualClusterManager) - err = findVclusterModeAndSetDefaultTranslation(ctx, localManager, options) - if err != nil { - return err - } - if init { klog.Info("is init container mode") defer ctx.Done() @@ -216,49 +215,46 @@ func Start(ctx context.Context, options *context2.VirtualClusterOptions, init bo return nil } -func getVclusterObject(ctx context.Context, localManager manager.Manager, vclusterName, vclusterNamespace string, object client.Object) error { - err := localManager.GetClient().Get(ctx, types.NamespacedName{ - Name: vclusterName, - Namespace: vclusterNamespace, - }, object) - if err != nil { - return err - } - - return nil -} - -func getSyncerPodSpec(ctx context.Context, localManager manager.Manager, vclusterName, vclusterNamespace string) (*corev1.PodSpec, error) { +func getSyncerPodSpec(ctx context.Context, kubeClient kubernetes.Interface, vclusterName, vclusterNamespace string) (*corev1.PodSpec, error) { // try looking for the stateful set first - vclusterSts := &appsv1.StatefulSet{} - err := getVclusterObject(ctx, localManager, vclusterName, vclusterNamespace, vclusterSts) - if err != nil { + vclusterSts, err := kubeClient.AppsV1().StatefulSets(vclusterNamespace).Get(ctx, vclusterName, metav1.GetOptions{}) + if kerrors.IsNotFound(err) { + // try looking for deployment - in case of eks/k8s + vclusterDeploy, err := kubeClient.AppsV1().Deployments(vclusterNamespace).Get(ctx, vclusterName, metav1.GetOptions{}) if kerrors.IsNotFound(err) { - // try looking for deployment - in case of eks/k8s - vclusterDeploy := &appsv1.Deployment{} - err := getVclusterObject(ctx, localManager, vclusterName, vclusterNamespace, vclusterDeploy) - if err != nil { - if kerrors.IsNotFound(err) { - klog.Errorf("could not find vcluster either in statefulset or deployment: %v", err) - return nil, err - } - - klog.Errorf("error looking for vcluster deployment: %v", err) - return nil, err - } - - return &vclusterDeploy.Spec.Template.Spec, nil + klog.Errorf("could not find vcluster either in statefulset or deployment: %v", err) + return nil, err + } else if err != nil { + klog.Errorf("error looking for vcluster deployment: %v", err) + return nil, err } + return &vclusterDeploy.Spec.Template.Spec, nil + } else if err != nil { return nil, err } return &vclusterSts.Spec.Template.Spec, nil } -func findVclusterModeAndSetDefaultTranslation(ctx context.Context, localManager manager.Manager, options *context2.VirtualClusterOptions) error { - vclusterPodSpec, err := getSyncerPodSpec(ctx, localManager, options.Name, options.TargetNamespace) +func localManagerCtrlOptions(options *context2.VirtualClusterOptions) manager.Options { + controllerOptions := ctrl.Options{ + Scheme: scheme, + MetricsBindAddress: "0", + LeaderElection: false, + NewClient: pluginhookclient.NewPhysicalPluginClientFactory(blockingcacheclient.NewCacheClient), + } + + if !options.MultiNamespaceMode { + controllerOptions.Cache.Namespaces = []string{options.TargetNamespace} + } + + return controllerOptions +} + +func findVclusterModeAndSetDefaultTranslation(ctx context.Context, kubeClient kubernetes.Interface, options *context2.VirtualClusterOptions) error { + vclusterPodSpec, err := getSyncerPodSpec(ctx, kubeClient, options.Name, options.TargetNamespace) if err != nil { return err } @@ -268,6 +264,7 @@ func findVclusterModeAndSetDefaultTranslation(ctx context.Context, localManager // iterate over command args for _, arg := range container.Args { if strings.Contains(arg, MultiNamespaceMode) { + options.MultiNamespaceMode = true translate.Default = translate.NewMultiNamespaceTranslator(options.TargetNamespace) return nil } @@ -338,22 +335,12 @@ func mapHostPaths(ctx context.Context, pManager, vManager manager.Manager) { options := ctx.Value(optionsKey).(*context2.VirtualClusterOptions) wait.Forever(func() { - podList := &corev1.PodList{} - err := pManager.GetClient().List(ctx, podList, &client.ListOptions{ - Namespace: options.TargetNamespace, - FieldSelector: fields.SelectorFromSet(fields.Set{ - NodeIndexName: os.Getenv(HostpathMapperSelfNodeNameEnvVar), - }), - }) + podMappings, err := getPhysicalPodMap(ctx, options, pManager) if err != nil { - klog.Errorf("unable to list pods: %v", err) + klog.Errorf("unable to get physical pod mapping: %v", err) return } - podMappings := make(PhysicalPodMap) - - fillUpPodMapping(ctx, podList, podMappings) - vPodList := &corev1.PodList{} err = vManager.GetClient().List(ctx, vPodList, &client.ListOptions{ FieldSelector: fields.SelectorFromSet(fields.Set{ @@ -420,6 +407,75 @@ func mapHostPaths(ctx context.Context, pManager, vManager manager.Manager) { }, time.Second*5) } +func getPhysicalPodMap(ctx context.Context, options *context2.VirtualClusterOptions, pManager manager.Manager) (PhysicalPodMap, error) { + podListOptions := &client.ListOptions{ + FieldSelector: fields.SelectorFromSet(fields.Set{ + NodeIndexName: os.Getenv(HostpathMapperSelfNodeNameEnvVar), + }), + } + + if !options.MultiNamespaceMode { + podListOptions.Namespace = options.TargetNamespace + } + + podList := &corev1.PodList{} + err := pManager.GetClient().List(ctx, podList, podListOptions) + if err != nil { + return nil, fmt.Errorf("unable to list pods: %w", err) + } + + var pods []corev1.Pod + if options.MultiNamespaceMode { + // find namespaces managed by the current vcluster + nsList := &corev1.NamespaceList{} + err = pManager.GetClient().List(ctx, nsList, &client.ListOptions{ + LabelSelector: labels.SelectorFromSet(labels.Set{ + namespaces.VclusterNamespaceAnnotation: options.TargetNamespace, + }), + }) + if err != nil { + return nil, fmt.Errorf("unable to list namespaces: %w", err) + } + + vclusterNamespaces := make(map[string]struct{}, len(nsList.Items)) + for _, ns := range nsList.Items { + vclusterNamespaces[ns.Name] = struct{}{} + } + + // Limit Pods + pods = make([]corev1.Pod, 0, len(podList.Items)) + for _, pod := range podList.Items { + if _, ok := vclusterNamespaces[pod.Namespace]; ok { + pods = append(pods, pod) + } + } + } else { + pods = podList.Items + } + + podMappings := make(PhysicalPodMap, len(pods)) + for _, pPod := range pods { + lookupName := fmt.Sprintf("%s_%s_%s", pPod.Namespace, pPod.Name, pPod.UID) + + ok, err := checkIfPathExists(lookupName) + if err != nil { + klog.Errorf("error checking existence for path %s: %v", lookupName, err) + } + + if ok { + // check entry in podMapping + if _, ok := podMappings[pPod.Name]; !ok { + podMappings[pPod.Name] = &PodDetail{ + Target: lookupName, + PhysicalPod: pPod, + } + } + } + } + + return podMappings, nil +} + func cleanupOldContainerPaths(ctx context.Context, existingVPodsWithNS map[string]bool) error { options := ctx.Value(optionsKey).(*context2.VirtualClusterOptions) @@ -591,27 +647,6 @@ func getPhysicalLogFilename(ctx context.Context, physicalContainerFileName strin return fileName, nil } -func fillUpPodMapping(ctx context.Context, pPodList *corev1.PodList, podMappings PhysicalPodMap) { - for _, pPod := range pPodList.Items { - lookupName := fmt.Sprintf("%s_%s_%s", pPod.Namespace, pPod.Name, pPod.UID) - - ok, err := checkIfPathExists(lookupName) - if err != nil { - klog.Errorf("error checking existence for path %s: %v", lookupName, err) - } - - if ok { - // check entry in podMapping - if _, ok := podMappings[pPod.Name]; !ok { - podMappings[pPod.Name] = &PodDetail{ - Target: lookupName, - PhysicalPod: pPod, - } - } - } - } -} - // check if folder exists func checkIfPathExists(path string) (bool, error) { fullPath := filepath.Join(PodLogsMountPath, path) diff --git a/go.mod b/go.mod index aa29b6ae..ef99cb1e 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/loft-sh/vcluster-hostpath-mapper -go 1.20 +go 1.22.0 require ( github.com/go-openapi/loads v0.21.2 diff --git a/go.sum b/go.sum index 3f8c673e..0dcc0ba8 100644 --- a/go.sum +++ b/go.sum @@ -5,6 +5,7 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03 github.com/MakeNowJust/heredoc v1.0.0 h1:cXCdzVdstXyiTqTvfqk9SDHpKNjxuom+DOlyEeQ4pzQ= github.com/MakeNowJust/heredoc v1.0.0/go.mod h1:mG5amYoWBHf8vpLOuehzbGGw0EHxpZZ6lCpQ4fNJ8LE= github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio= +github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs= github.com/asaskevich/govalidator v0.0.0-20200907205600-7a23bdc65eef h1:46PFijGLmAjMPwCCCo7Jf0W6f9slllCkkv7vyc1yOSg= github.com/asaskevich/govalidator v0.0.0-20200907205600-7a23bdc65eef/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= @@ -23,6 +24,7 @@ github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDk github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY= +github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -46,12 +48,14 @@ github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbS github.com/fvbommel/sortorder v1.0.1 h1:dSnXLt4mJYH25uDDGa3biZNQsozaUWDSWeKJ0qqFfzE= github.com/fvbommel/sortorder v1.0.1/go.mod h1:uk88iVf1ovNn1iLfgUVU2F9o5eO30ui720w+kxuqRs0= github.com/ghodss/yaml v1.0.0 h1:wQHKEahhL6wmXdzwWG11gIVCkOv05bNOh+Rxn0yngAk= +github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA= github.com/go-errors/errors v1.4.2/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og= github.com/go-logr/logr v1.2.0/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/zapr v1.2.4 h1:QHVo+6stLbfJmYGkQ7uGHUCu5hnAFAj6mDe6Ea0SeOo= +github.com/go-logr/zapr v1.2.4/go.mod h1:FyHWQIzQORZ0QVE1BtVHv3cKtNLuXsbNLtpuhNapBOA= github.com/go-openapi/analysis v0.21.4 h1:ZDFLvSNxpDaomuCueM0BlSXxpANBlFYiBvr+GXrvIHc= github.com/go-openapi/analysis v0.21.4/go.mod h1:4zQ35W4neeZTqh3ol0rv/O8JBbka9QyAgQRPp9y3pfo= github.com/go-openapi/errors v0.20.2 h1:dxy7PGTqEh94zj2E3h1cUmQQWiM1+aeCROfAr02EmK8= @@ -75,6 +79,7 @@ github.com/go-openapi/swag v0.21.1/go.mod h1:QYRuS/SOXUCsnplDa677K7+DxSOj6IPNl/e github.com/go-openapi/swag v0.22.3 h1:yMBqmnQ0gyZvEb/+KzuWZOXgllrXT4SADYbvDaXHv/g= github.com/go-openapi/swag v0.22.3/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+VcZ0yl14= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= +github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= @@ -111,6 +116,7 @@ github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/ github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1 h1:K6RDEckDVWvDI9JAJYCmNdQXq6neHJOYx3V6jnqNEec= +github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= @@ -137,11 +143,13 @@ github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORN github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kubernetes-csi/external-snapshotter/client/v4 v4.2.0 h1:nHHjmvjitIiyPlUHk/ofpgvBcNcawJLtf4PYHORLjAA= +github.com/kubernetes-csi/external-snapshotter/client/v4 v4.2.0/go.mod h1:YBCo4DoEeDndqvAn6eeu0vWM7QdXmHEeI9cFWplmBys= github.com/liggitt/tabwriter v0.0.0-20181228230101-89fcab3d43de h1:9TO3cAIGXtEhnIaL+V+BEER86oLrvS+kWobKpbJuye0= github.com/liggitt/tabwriter v0.0.0-20181228230101-89fcab3d43de/go.mod h1:zAbeS9B/r2mtpb6U+EI2rYA5OAXxsYw6wTamcNW+zcE= github.com/loft-sh/vcluster v0.15.2 h1:ipTthjYi0a/X7bksMHCJ532C+RTr4LHT3NeJehw/j9w= @@ -179,7 +187,9 @@ github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4= github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= github.com/onsi/ginkgo/v2 v2.9.7 h1:06xGQy5www2oN160RtEZoTvnP2sPhEfePYmCDc2szss= +github.com/onsi/ginkgo/v2 v2.9.7/go.mod h1:cxrmXWykAwTwhQsJOPfdIDiJ+l2RYq7U8hFU+M/1uw0= github.com/onsi/gomega v1.27.7 h1:fVih9JD6ogIiHUN6ePK7HJidyEDpWGVB5mzM7cWNXoU= +github.com/onsi/gomega v1.27.7/go.mod h1:1p8OOlwo2iUUDsHnOrjE5UKYJ+e3W8eQ3qSlRahPmr4= github.com/peterbourgon/diskv v2.0.1+incompatible h1:UBdAOUP5p4RWqPBg048CAvpKN+vxiaj6gdUUzhl4XmI= github.com/peterbourgon/diskv v2.0.1+incompatible/go.mod h1:uqqh8zWWbv1HBMNONnaR/tNboyR3/BZd58JJSHlUSCU= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -197,9 +207,11 @@ github.com/prometheus/common v0.44.0/go.mod h1:ofAIvZbQ1e/nugmZGz4/qCb9Ap1VoSTIO github.com/prometheus/procfs v0.9.0 h1:wzCHvIvM5SxWqYvwgVL7yJY8Lz3PKn49KQtpgMYJfhI= github.com/prometheus/procfs v0.9.0/go.mod h1:+pB4zwohETzFnmlpe6yd2lSc+0/46IYZRB/chUwxUZY= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= +github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/sergi/go-diff v1.1.0 h1:we8PVUC3FE2uYfodKH/nBHMSetSfHDR6scGdBi+erh0= +github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I= github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= @@ -234,8 +246,11 @@ go.starlark.net v0.0.0-20200306205701-8dd3e2ee1dd5/go.mod h1:nmDLcffg48OtT/PSW0H go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= go.uber.org/goleak v1.2.1 h1:NBol2c7O1ZokfZ0LEU9K6Whx/KnwvepVetCUhtKja4A= +go.uber.org/goleak v1.2.1/go.mod h1:qlT2yGI9QafXHhZZLxlSuNsMw3FFLxBr+tBRlmO1xH4= go.uber.org/multierr v1.6.0 h1:y6IPFStTAIT5Ytl7/XYmHvzXQ7S3g/IeZW9hyZ5thw4= +go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= go.uber.org/zap v1.24.0 h1:FiJd5l1UOLj0wCgbSE0rwwXHzEdAZS6hiiSnxJN/D60= +go.uber.org/zap v1.24.0/go.mod h1:2kMP+WWQ8aoFoedH3T2sq6iJ2yDWpHbP0f6MQbS9Gkg= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= @@ -303,6 +318,7 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.9.1 h1:8WMNJAz3zrtPmnYC7ISf5dEn3MT0gY7jBJfw27yrrLo= +golang.org/x/tools v0.9.1/go.mod h1:owI94Op576fPu3cIGQeHs3joujW/2Oc6MtlxbF5dfNc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -357,7 +373,9 @@ gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gotest.tools v2.2.0+incompatible h1:VsBPFP1AI068pPrMxtb/S8Zkgf9xEmTLJjfM+P5UIEo= +gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw= gotest.tools/v3 v3.4.0 h1:ZazjZUfuVeZGLAmlKKuyv3IKP5orXcwtOwDQH6YVr6o= +gotest.tools/v3 v3.4.0/go.mod h1:CtbdzLSsqVhDgMtKsx03ird5YTGB3ar27v0u/yKBW5g= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= k8s.io/api v0.27.4 h1:0pCo/AN9hONazBKlNUdhQymmnfLRbSZjd5H5H3f0bSs= diff --git a/vendor/github.com/loft-sh/vcluster/pkg/controllers/resources/namespaces/syncer.go b/vendor/github.com/loft-sh/vcluster/pkg/controllers/resources/namespaces/syncer.go new file mode 100644 index 00000000..1412a4ad --- /dev/null +++ b/vendor/github.com/loft-sh/vcluster/pkg/controllers/resources/namespaces/syncer.go @@ -0,0 +1,125 @@ +package namespaces + +import ( + "fmt" + "strings" + + "github.com/loft-sh/vcluster/pkg/constants" + "github.com/loft-sh/vcluster/pkg/controllers/syncer" + synccontext "github.com/loft-sh/vcluster/pkg/controllers/syncer/context" + "github.com/loft-sh/vcluster/pkg/controllers/syncer/translator" + "github.com/loft-sh/vcluster/pkg/util/translate" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/apis/meta/v1/validation" + "k8s.io/apimachinery/pkg/util/validation/field" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" +) + +// Unsafe annotations based on the docs here: +// https://kubernetes.io/docs/reference/labels-annotations-taints/ +var excludedAnnotations = []string{ + "scheduler.alpha.kubernetes.io/node-selector", + "scheduler.alpha.kubernetes.io/defaultTolerations", +} + +const ( + VclusterNameAnnotation = "vcluster.loft.sh/vcluster-name" + VclusterNamespaceAnnotation = "vcluster.loft.sh/vcluster-namespace" +) + +func New(ctx *synccontext.RegisterContext) (syncer.Object, error) { + namespaceLabels, err := parseNamespaceLabels(ctx.Options.NamespaceLabels) + if err != nil { + return nil, fmt.Errorf("invalid value of the namespace-labels flag: %v", err) + } + + namespaceLabels[VclusterNameAnnotation] = ctx.Options.Name + namespaceLabels[VclusterNamespaceAnnotation] = ctx.CurrentNamespace + + return &namespaceSyncer{ + Translator: translator.NewClusterTranslator(ctx, "namespace", &corev1.Namespace{}, NamespaceNameTranslator, excludedAnnotations...), + workloadServiceAccountName: ctx.Options.ServiceAccount, + namespaceLabels: namespaceLabels, + }, nil +} + +type namespaceSyncer struct { + translator.Translator + workloadServiceAccountName string + namespaceLabels map[string]string +} + +var _ syncer.IndicesRegisterer = &namespaceSyncer{} + +func (s *namespaceSyncer) RegisterIndices(ctx *synccontext.RegisterContext) error { + return ctx.VirtualManager.GetFieldIndexer().IndexField(ctx.Context, &corev1.Namespace{}, constants.IndexByPhysicalName, func(rawObj client.Object) []string { + return []string{NamespaceNameTranslator(rawObj.GetName(), rawObj)} + }) +} + +var _ syncer.Syncer = &namespaceSyncer{} + +func (s *namespaceSyncer) SyncDown(ctx *synccontext.SyncContext, vObj client.Object) (ctrl.Result, error) { + newNamespace := s.translate(ctx.Context, vObj.(*corev1.Namespace)) + ctx.Log.Infof("create physical namespace %s", newNamespace.Name) + err := ctx.PhysicalClient.Create(ctx.Context, newNamespace) + if err != nil { + ctx.Log.Infof("error syncing %s to physical cluster: %v", vObj.GetName(), err) + return ctrl.Result{}, err + } + + return ctrl.Result{}, s.EnsureWorkloadServiceAccount(ctx, newNamespace.Name) +} + +func (s *namespaceSyncer) Sync(ctx *synccontext.SyncContext, pObj client.Object, vObj client.Object) (ctrl.Result, error) { + updated := s.translateUpdate(ctx.Context, pObj.(*corev1.Namespace), vObj.(*corev1.Namespace)) + if updated != nil { + ctx.Log.Infof("updating physical namespace %s, because virtual namespace has changed", updated.Name) + translator.PrintChanges(pObj, updated, ctx.Log) + err := ctx.PhysicalClient.Update(ctx.Context, updated) + if err != nil { + return ctrl.Result{}, err + } + } + + return ctrl.Result{}, s.EnsureWorkloadServiceAccount(ctx, pObj.GetName()) +} + +func (s *namespaceSyncer) EnsureWorkloadServiceAccount(ctx *synccontext.SyncContext, pNamespace string) error { + if s.workloadServiceAccountName == "" { + return nil + } + + svc := &corev1.ServiceAccount{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: pNamespace, + Name: s.workloadServiceAccountName, + }, + } + _, err := controllerutil.CreateOrPatch(ctx.Context, ctx.PhysicalClient, svc, func() error { return nil }) + return err +} + +func NamespaceNameTranslator(vName string, _ client.Object) string { + return translate.Default.PhysicalNamespace(vName) +} + +func parseNamespaceLabels(labels []string) (map[string]string, error) { + out := map[string]string{} + for _, v := range labels { + parts := strings.SplitN(v, "=", 2) + if len(parts) != 2 { + return nil, fmt.Errorf("incorrect format, expected: key=value got: %s", v) + } + out[parts[0]] = parts[1] + } + errs := validation.ValidateLabels(out, field.NewPath("namespace-labels")) + if len(errs) != 0 { + return nil, fmt.Errorf("invalid labels: %v", errs) + } + + return out, nil +} diff --git a/vendor/github.com/loft-sh/vcluster/pkg/controllers/resources/namespaces/translate.go b/vendor/github.com/loft-sh/vcluster/pkg/controllers/resources/namespaces/translate.go new file mode 100644 index 00000000..4ab8c56c --- /dev/null +++ b/vendor/github.com/loft-sh/vcluster/pkg/controllers/resources/namespaces/translate.go @@ -0,0 +1,41 @@ +package namespaces + +import ( + "context" + + "github.com/loft-sh/vcluster/pkg/controllers/syncer/translator" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/equality" + "sigs.k8s.io/controller-runtime/pkg/client" +) + +func (s *namespaceSyncer) translate(ctx context.Context, vObj client.Object) *corev1.Namespace { + newNamespace := s.TranslateMetadata(ctx, vObj).(*corev1.Namespace) + + // add user defined namespace labels + for k, v := range s.namespaceLabels { + newNamespace.Labels[k] = v + } + + return newNamespace +} + +func (s *namespaceSyncer) translateUpdate(ctx context.Context, pObj, vObj *corev1.Namespace) *corev1.Namespace { + var updated *corev1.Namespace + + _, updatedAnnotations, updatedLabels := s.TranslateMetadataUpdate(ctx, vObj, pObj) + // add user defined namespace labels + for k, v := range s.namespaceLabels { + updatedLabels[k] = v + } + // set the kubernetes.io/metadata.name label + updatedLabels[corev1.LabelMetadataName] = pObj.Name + // check if any labels or annotations changed + if !equality.Semantic.DeepEqual(updatedAnnotations, pObj.GetAnnotations()) || !equality.Semantic.DeepEqual(updatedLabels, pObj.GetLabels()) { + updated = translator.NewIfNil(updated, pObj) + updated.Annotations = updatedAnnotations + updated.Labels = updatedLabels + } + + return updated +} diff --git a/vendor/modules.txt b/vendor/modules.txt index 6f6047a1..c519d007 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -156,6 +156,7 @@ github.com/liggitt/tabwriter github.com/loft-sh/vcluster/cmd/vcluster/context github.com/loft-sh/vcluster/pkg/constants github.com/loft-sh/vcluster/pkg/controllers/resources/configmaps +github.com/loft-sh/vcluster/pkg/controllers/resources/namespaces github.com/loft-sh/vcluster/pkg/controllers/resources/pods/translate github.com/loft-sh/vcluster/pkg/controllers/resources/priorityclasses github.com/loft-sh/vcluster/pkg/controllers/syncer From 335934eb6cd625ea935df7127e34d063687ed7df Mon Sep 17 00:00:00 2001 From: neogopher Date: Wed, 21 Feb 2024 19:33:36 +0530 Subject: [PATCH 2/2] Extract out the filtering function and add unit tests for the same --- .github/workflows/unit-tests.yaml | 25 ++ cmd/hostpaths/hostpaths.go | 18 +- cmd/hostpaths/hostpaths_test.go | 83 ++++ go.mod | 1 + vendor/gotest.tools/LICENSE | 13 + vendor/gotest.tools/assert/assert.go | 311 +++++++++++++ vendor/gotest.tools/assert/cmp/compare.go | 356 +++++++++++++++ vendor/gotest.tools/assert/cmp/result.go | 94 ++++ vendor/gotest.tools/assert/result.go | 106 +++++ vendor/gotest.tools/internal/difflib/LICENSE | 27 ++ .../gotest.tools/internal/difflib/difflib.go | 423 ++++++++++++++++++ vendor/gotest.tools/internal/format/diff.go | 161 +++++++ vendor/gotest.tools/internal/format/format.go | 27 ++ vendor/gotest.tools/internal/source/defers.go | 53 +++ vendor/gotest.tools/internal/source/source.go | 166 +++++++ vendor/modules.txt | 7 + 16 files changed, 1865 insertions(+), 6 deletions(-) create mode 100644 .github/workflows/unit-tests.yaml create mode 100644 cmd/hostpaths/hostpaths_test.go create mode 100644 vendor/gotest.tools/LICENSE create mode 100644 vendor/gotest.tools/assert/assert.go create mode 100644 vendor/gotest.tools/assert/cmp/compare.go create mode 100644 vendor/gotest.tools/assert/cmp/result.go create mode 100644 vendor/gotest.tools/assert/result.go create mode 100644 vendor/gotest.tools/internal/difflib/LICENSE create mode 100644 vendor/gotest.tools/internal/difflib/difflib.go create mode 100644 vendor/gotest.tools/internal/format/diff.go create mode 100644 vendor/gotest.tools/internal/format/format.go create mode 100644 vendor/gotest.tools/internal/source/defers.go create mode 100644 vendor/gotest.tools/internal/source/source.go diff --git a/.github/workflows/unit-tests.yaml b/.github/workflows/unit-tests.yaml new file mode 100644 index 00000000..12373909 --- /dev/null +++ b/.github/workflows/unit-tests.yaml @@ -0,0 +1,25 @@ +name: Unit tests + +on: + workflow_dispatch: + pull_request: + branches: + - main + paths: + - "**.go" + - ".github/workflows/unit-tests.yaml" + +jobs: + unit-test: + name: Execute all tests + runs-on: ubuntu-latest + steps: + - name: Check out code + uses: actions/checkout@v4 + - name: Set up Go + uses: actions/setup-go@v4 + with: + go-version-file: "go.mod" + cache: false + - name: Execute unit tests + run: go test -v ./... \ No newline at end of file diff --git a/cmd/hostpaths/hostpaths.go b/cmd/hostpaths/hostpaths.go index 498a234d..d24f48bc 100644 --- a/cmd/hostpaths/hostpaths.go +++ b/cmd/hostpaths/hostpaths.go @@ -443,12 +443,7 @@ func getPhysicalPodMap(ctx context.Context, options *context2.VirtualClusterOpti } // Limit Pods - pods = make([]corev1.Pod, 0, len(podList.Items)) - for _, pod := range podList.Items { - if _, ok := vclusterNamespaces[pod.Namespace]; ok { - pods = append(pods, pod) - } - } + pods = filter(ctx, podList.Items, vclusterNamespaces) } else { pods = podList.Items } @@ -476,6 +471,17 @@ func getPhysicalPodMap(ctx context.Context, options *context2.VirtualClusterOpti return podMappings, nil } +func filter(ctx context.Context, podList []corev1.Pod, vclusterNamespaces map[string]struct{}) []corev1.Pod { + pods := make([]corev1.Pod, 0, len(podList)) + for _, pod := range podList { + if _, ok := vclusterNamespaces[pod.Namespace]; ok { + pods = append(pods, pod) + } + } + + return pods +} + func cleanupOldContainerPaths(ctx context.Context, existingVPodsWithNS map[string]bool) error { options := ctx.Value(optionsKey).(*context2.VirtualClusterOptions) diff --git a/cmd/hostpaths/hostpaths_test.go b/cmd/hostpaths/hostpaths_test.go new file mode 100644 index 00000000..00ca106a --- /dev/null +++ b/cmd/hostpaths/hostpaths_test.go @@ -0,0 +1,83 @@ +package hostpaths + +import ( + "context" + "testing" + + "gotest.tools/assert" + "gotest.tools/assert/cmp" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +func Test_filter(t *testing.T) { + testPodList := []corev1.Pod{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "test-pod1", + Namespace: "test-ns1", + }, + Spec: corev1.PodSpec{}, + }, + { + ObjectMeta: metav1.ObjectMeta{ + Name: "test-pod2", + Namespace: "test-ns2", + }, + Spec: corev1.PodSpec{}, + }, + } + + testCases := []struct { + name string + podList []corev1.Pod + vclusterNamespaces map[string]struct{} + expected []corev1.Pod + }{ + { + name: "None of the pods belong to namespace(s) managed by the current vCluster", + podList: testPodList, + vclusterNamespaces: map[string]struct{}{ + "test-ns3": {}, + "test-ns4": {}, + }, + expected: []corev1.Pod{}, + }, + { + name: "Some of the pods belong to namespace(s) managed by the current vCluster", + podList: testPodList, + vclusterNamespaces: map[string]struct{}{ + "test-ns1": {}, + "test-ns4": {}, + }, + expected: []corev1.Pod{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "test-pod1", + Namespace: "test-ns1", + }, + Spec: corev1.PodSpec{}, + }, + }, + }, + { + name: "All of the pods belong to namespace(s) managed by the current vCluster", + podList: testPodList, + vclusterNamespaces: map[string]struct{}{ + "test-ns1": {}, + "test-ns2": {}, + }, + expected: testPodList, + }, + } + + for _, testCase := range testCases { + actual := filter(context.Background(), testCase.podList, testCase.vclusterNamespaces) + + assert.Assert(t, + cmp.DeepEqual(actual, testCase.expected), + "Unexpected result in test case %s", + testCase.name, + ) + } +} diff --git a/go.mod b/go.mod index ef99cb1e..861978d8 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/loft-sh/vcluster v0.15.2 github.com/pkg/errors v0.9.1 github.com/spf13/cobra v1.7.0 + gotest.tools v2.2.0+incompatible k8s.io/api v0.27.4 k8s.io/apimachinery v0.27.4 k8s.io/client-go v0.27.4 diff --git a/vendor/gotest.tools/LICENSE b/vendor/gotest.tools/LICENSE new file mode 100644 index 00000000..aeaa2fac --- /dev/null +++ b/vendor/gotest.tools/LICENSE @@ -0,0 +1,13 @@ +Copyright 2018 gotest.tools authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/vendor/gotest.tools/assert/assert.go b/vendor/gotest.tools/assert/assert.go new file mode 100644 index 00000000..05d66354 --- /dev/null +++ b/vendor/gotest.tools/assert/assert.go @@ -0,0 +1,311 @@ +/*Package assert provides assertions for comparing expected values to actual +values. When an assertion fails a helpful error message is printed. + +Assert and Check + +Assert() and Check() both accept a Comparison, and fail the test when the +comparison fails. The one difference is that Assert() will end the test execution +immediately (using t.FailNow()) whereas Check() will fail the test (using t.Fail()), +return the value of the comparison, then proceed with the rest of the test case. + +Example usage + +The example below shows assert used with some common types. + + + import ( + "testing" + + "gotest.tools/assert" + is "gotest.tools/assert/cmp" + ) + + func TestEverything(t *testing.T) { + // booleans + assert.Assert(t, ok) + assert.Assert(t, !missing) + + // primitives + assert.Equal(t, count, 1) + assert.Equal(t, msg, "the message") + assert.Assert(t, total != 10) // NotEqual + + // errors + assert.NilError(t, closer.Close()) + assert.Error(t, err, "the exact error message") + assert.ErrorContains(t, err, "includes this") + assert.ErrorType(t, err, os.IsNotExist) + + // complex types + assert.DeepEqual(t, result, myStruct{Name: "title"}) + assert.Assert(t, is.Len(items, 3)) + assert.Assert(t, len(sequence) != 0) // NotEmpty + assert.Assert(t, is.Contains(mapping, "key")) + + // pointers and interface + assert.Assert(t, is.Nil(ref)) + assert.Assert(t, ref != nil) // NotNil + } + +Comparisons + +Package https://godoc.org/gotest.tools/assert/cmp provides +many common comparisons. Additional comparisons can be written to compare +values in other ways. See the example Assert (CustomComparison). + +Automated migration from testify + +gty-migrate-from-testify is a binary which can update source code which uses +testify assertions to use the assertions provided by this package. + +See http://bit.do/cmd-gty-migrate-from-testify. + + +*/ +package assert // import "gotest.tools/assert" + +import ( + "fmt" + "go/ast" + "go/token" + + gocmp "github.com/google/go-cmp/cmp" + "gotest.tools/assert/cmp" + "gotest.tools/internal/format" + "gotest.tools/internal/source" +) + +// BoolOrComparison can be a bool, or cmp.Comparison. See Assert() for usage. +type BoolOrComparison interface{} + +// TestingT is the subset of testing.T used by the assert package. +type TestingT interface { + FailNow() + Fail() + Log(args ...interface{}) +} + +type helperT interface { + Helper() +} + +const failureMessage = "assertion failed: " + +// nolint: gocyclo +func assert( + t TestingT, + failer func(), + argSelector argSelector, + comparison BoolOrComparison, + msgAndArgs ...interface{}, +) bool { + if ht, ok := t.(helperT); ok { + ht.Helper() + } + var success bool + switch check := comparison.(type) { + case bool: + if check { + return true + } + logFailureFromBool(t, msgAndArgs...) + + // Undocumented legacy comparison without Result type + case func() (success bool, message string): + success = runCompareFunc(t, check, msgAndArgs...) + + case nil: + return true + + case error: + msg := "error is not nil: " + t.Log(format.WithCustomMessage(failureMessage+msg+check.Error(), msgAndArgs...)) + + case cmp.Comparison: + success = runComparison(t, argSelector, check, msgAndArgs...) + + case func() cmp.Result: + success = runComparison(t, argSelector, check, msgAndArgs...) + + default: + t.Log(fmt.Sprintf("invalid Comparison: %v (%T)", check, check)) + } + + if success { + return true + } + failer() + return false +} + +func runCompareFunc( + t TestingT, + f func() (success bool, message string), + msgAndArgs ...interface{}, +) bool { + if ht, ok := t.(helperT); ok { + ht.Helper() + } + if success, message := f(); !success { + t.Log(format.WithCustomMessage(failureMessage+message, msgAndArgs...)) + return false + } + return true +} + +func logFailureFromBool(t TestingT, msgAndArgs ...interface{}) { + if ht, ok := t.(helperT); ok { + ht.Helper() + } + const stackIndex = 3 // Assert()/Check(), assert(), formatFailureFromBool() + const comparisonArgPos = 1 + args, err := source.CallExprArgs(stackIndex) + if err != nil { + t.Log(err.Error()) + return + } + + msg, err := boolFailureMessage(args[comparisonArgPos]) + if err != nil { + t.Log(err.Error()) + msg = "expression is false" + } + + t.Log(format.WithCustomMessage(failureMessage+msg, msgAndArgs...)) +} + +func boolFailureMessage(expr ast.Expr) (string, error) { + if binaryExpr, ok := expr.(*ast.BinaryExpr); ok && binaryExpr.Op == token.NEQ { + x, err := source.FormatNode(binaryExpr.X) + if err != nil { + return "", err + } + y, err := source.FormatNode(binaryExpr.Y) + if err != nil { + return "", err + } + return x + " is " + y, nil + } + + if unaryExpr, ok := expr.(*ast.UnaryExpr); ok && unaryExpr.Op == token.NOT { + x, err := source.FormatNode(unaryExpr.X) + if err != nil { + return "", err + } + return x + " is true", nil + } + + formatted, err := source.FormatNode(expr) + if err != nil { + return "", err + } + return "expression is false: " + formatted, nil +} + +// Assert performs a comparison. If the comparison fails the test is marked as +// failed, a failure message is logged, and execution is stopped immediately. +// +// The comparison argument may be one of three types: bool, cmp.Comparison or +// error. +// When called with a bool the failure message will contain the literal source +// code of the expression. +// When called with a cmp.Comparison the comparison is responsible for producing +// a helpful failure message. +// When called with an error a nil value is considered success. A non-nil error +// is a failure, and Error() is used as the failure message. +func Assert(t TestingT, comparison BoolOrComparison, msgAndArgs ...interface{}) { + if ht, ok := t.(helperT); ok { + ht.Helper() + } + assert(t, t.FailNow, argsFromComparisonCall, comparison, msgAndArgs...) +} + +// Check performs a comparison. If the comparison fails the test is marked as +// failed, a failure message is logged, and Check returns false. Otherwise returns +// true. +// +// See Assert for details about the comparison arg and failure messages. +func Check(t TestingT, comparison BoolOrComparison, msgAndArgs ...interface{}) bool { + if ht, ok := t.(helperT); ok { + ht.Helper() + } + return assert(t, t.Fail, argsFromComparisonCall, comparison, msgAndArgs...) +} + +// NilError fails the test immediately if err is not nil. +// This is equivalent to Assert(t, err) +func NilError(t TestingT, err error, msgAndArgs ...interface{}) { + if ht, ok := t.(helperT); ok { + ht.Helper() + } + assert(t, t.FailNow, argsAfterT, err, msgAndArgs...) +} + +// Equal uses the == operator to assert two values are equal and fails the test +// if they are not equal. +// +// If the comparison fails Equal will use the variable names for x and y as part +// of the failure message to identify the actual and expected values. +// +// If either x or y are a multi-line string the failure message will include a +// unified diff of the two values. If the values only differ by whitespace +// the unified diff will be augmented by replacing whitespace characters with +// visible characters to identify the whitespace difference. +// +// This is equivalent to Assert(t, cmp.Equal(x, y)). +func Equal(t TestingT, x, y interface{}, msgAndArgs ...interface{}) { + if ht, ok := t.(helperT); ok { + ht.Helper() + } + assert(t, t.FailNow, argsAfterT, cmp.Equal(x, y), msgAndArgs...) +} + +// DeepEqual uses google/go-cmp (http://bit.do/go-cmp) to assert two values are +// equal and fails the test if they are not equal. +// +// Package https://godoc.org/gotest.tools/assert/opt provides some additional +// commonly used Options. +// +// This is equivalent to Assert(t, cmp.DeepEqual(x, y)). +func DeepEqual(t TestingT, x, y interface{}, opts ...gocmp.Option) { + if ht, ok := t.(helperT); ok { + ht.Helper() + } + assert(t, t.FailNow, argsAfterT, cmp.DeepEqual(x, y, opts...)) +} + +// Error fails the test if err is nil, or the error message is not the expected +// message. +// Equivalent to Assert(t, cmp.Error(err, message)). +func Error(t TestingT, err error, message string, msgAndArgs ...interface{}) { + if ht, ok := t.(helperT); ok { + ht.Helper() + } + assert(t, t.FailNow, argsAfterT, cmp.Error(err, message), msgAndArgs...) +} + +// ErrorContains fails the test if err is nil, or the error message does not +// contain the expected substring. +// Equivalent to Assert(t, cmp.ErrorContains(err, substring)). +func ErrorContains(t TestingT, err error, substring string, msgAndArgs ...interface{}) { + if ht, ok := t.(helperT); ok { + ht.Helper() + } + assert(t, t.FailNow, argsAfterT, cmp.ErrorContains(err, substring), msgAndArgs...) +} + +// ErrorType fails the test if err is nil, or err is not the expected type. +// +// Expected can be one of: +// a func(error) bool which returns true if the error is the expected type, +// an instance of (or a pointer to) a struct of the expected type, +// a pointer to an interface the error is expected to implement, +// a reflect.Type of the expected struct or interface. +// +// Equivalent to Assert(t, cmp.ErrorType(err, expected)). +func ErrorType(t TestingT, err error, expected interface{}, msgAndArgs ...interface{}) { + if ht, ok := t.(helperT); ok { + ht.Helper() + } + assert(t, t.FailNow, argsAfterT, cmp.ErrorType(err, expected), msgAndArgs...) +} diff --git a/vendor/gotest.tools/assert/cmp/compare.go b/vendor/gotest.tools/assert/cmp/compare.go new file mode 100644 index 00000000..cf48d887 --- /dev/null +++ b/vendor/gotest.tools/assert/cmp/compare.go @@ -0,0 +1,356 @@ +/*Package cmp provides Comparisons for Assert and Check*/ +package cmp // import "gotest.tools/assert/cmp" + +import ( + "fmt" + "reflect" + "regexp" + "strings" + + "github.com/google/go-cmp/cmp" + "gotest.tools/internal/format" +) + +// Comparison is a function which compares values and returns ResultSuccess if +// the actual value matches the expected value. If the values do not match the +// Result will contain a message about why it failed. +type Comparison func() Result + +// DeepEqual compares two values using google/go-cmp (http://bit.do/go-cmp) +// and succeeds if the values are equal. +// +// The comparison can be customized using comparison Options. +// Package https://godoc.org/gotest.tools/assert/opt provides some additional +// commonly used Options. +func DeepEqual(x, y interface{}, opts ...cmp.Option) Comparison { + return func() (result Result) { + defer func() { + if panicmsg, handled := handleCmpPanic(recover()); handled { + result = ResultFailure(panicmsg) + } + }() + diff := cmp.Diff(x, y, opts...) + if diff == "" { + return ResultSuccess + } + return multiLineDiffResult(diff) + } +} + +func handleCmpPanic(r interface{}) (string, bool) { + if r == nil { + return "", false + } + panicmsg, ok := r.(string) + if !ok { + panic(r) + } + switch { + case strings.HasPrefix(panicmsg, "cannot handle unexported field"): + return panicmsg, true + } + panic(r) +} + +func toResult(success bool, msg string) Result { + if success { + return ResultSuccess + } + return ResultFailure(msg) +} + +// RegexOrPattern may be either a *regexp.Regexp or a string that is a valid +// regexp pattern. +type RegexOrPattern interface{} + +// Regexp succeeds if value v matches regular expression re. +// +// Example: +// assert.Assert(t, cmp.Regexp("^[0-9a-f]{32}$", str)) +// r := regexp.MustCompile("^[0-9a-f]{32}$") +// assert.Assert(t, cmp.Regexp(r, str)) +func Regexp(re RegexOrPattern, v string) Comparison { + match := func(re *regexp.Regexp) Result { + return toResult( + re.MatchString(v), + fmt.Sprintf("value %q does not match regexp %q", v, re.String())) + } + + return func() Result { + switch regex := re.(type) { + case *regexp.Regexp: + return match(regex) + case string: + re, err := regexp.Compile(regex) + if err != nil { + return ResultFailure(err.Error()) + } + return match(re) + default: + return ResultFailure(fmt.Sprintf("invalid type %T for regex pattern", regex)) + } + } +} + +// Equal succeeds if x == y. See assert.Equal for full documentation. +func Equal(x, y interface{}) Comparison { + return func() Result { + switch { + case x == y: + return ResultSuccess + case isMultiLineStringCompare(x, y): + diff := format.UnifiedDiff(format.DiffConfig{A: x.(string), B: y.(string)}) + return multiLineDiffResult(diff) + } + return ResultFailureTemplate(` + {{- .Data.x}} ( + {{- with callArg 0 }}{{ formatNode . }} {{end -}} + {{- printf "%T" .Data.x -}} + ) != {{ .Data.y}} ( + {{- with callArg 1 }}{{ formatNode . }} {{end -}} + {{- printf "%T" .Data.y -}} + )`, + map[string]interface{}{"x": x, "y": y}) + } +} + +func isMultiLineStringCompare(x, y interface{}) bool { + strX, ok := x.(string) + if !ok { + return false + } + strY, ok := y.(string) + if !ok { + return false + } + return strings.Contains(strX, "\n") || strings.Contains(strY, "\n") +} + +func multiLineDiffResult(diff string) Result { + return ResultFailureTemplate(` +--- {{ with callArg 0 }}{{ formatNode . }}{{else}}←{{end}} ++++ {{ with callArg 1 }}{{ formatNode . }}{{else}}→{{end}} +{{ .Data.diff }}`, + map[string]interface{}{"diff": diff}) +} + +// Len succeeds if the sequence has the expected length. +func Len(seq interface{}, expected int) Comparison { + return func() (result Result) { + defer func() { + if e := recover(); e != nil { + result = ResultFailure(fmt.Sprintf("type %T does not have a length", seq)) + } + }() + value := reflect.ValueOf(seq) + length := value.Len() + if length == expected { + return ResultSuccess + } + msg := fmt.Sprintf("expected %s (length %d) to have length %d", seq, length, expected) + return ResultFailure(msg) + } +} + +// Contains succeeds if item is in collection. Collection may be a string, map, +// slice, or array. +// +// If collection is a string, item must also be a string, and is compared using +// strings.Contains(). +// If collection is a Map, contains will succeed if item is a key in the map. +// If collection is a slice or array, item is compared to each item in the +// sequence using reflect.DeepEqual(). +func Contains(collection interface{}, item interface{}) Comparison { + return func() Result { + colValue := reflect.ValueOf(collection) + if !colValue.IsValid() { + return ResultFailure(fmt.Sprintf("nil does not contain items")) + } + msg := fmt.Sprintf("%v does not contain %v", collection, item) + + itemValue := reflect.ValueOf(item) + switch colValue.Type().Kind() { + case reflect.String: + if itemValue.Type().Kind() != reflect.String { + return ResultFailure("string may only contain strings") + } + return toResult( + strings.Contains(colValue.String(), itemValue.String()), + fmt.Sprintf("string %q does not contain %q", collection, item)) + + case reflect.Map: + if itemValue.Type() != colValue.Type().Key() { + return ResultFailure(fmt.Sprintf( + "%v can not contain a %v key", colValue.Type(), itemValue.Type())) + } + return toResult(colValue.MapIndex(itemValue).IsValid(), msg) + + case reflect.Slice, reflect.Array: + for i := 0; i < colValue.Len(); i++ { + if reflect.DeepEqual(colValue.Index(i).Interface(), item) { + return ResultSuccess + } + } + return ResultFailure(msg) + default: + return ResultFailure(fmt.Sprintf("type %T does not contain items", collection)) + } + } +} + +// Panics succeeds if f() panics. +func Panics(f func()) Comparison { + return func() (result Result) { + defer func() { + if err := recover(); err != nil { + result = ResultSuccess + } + }() + f() + return ResultFailure("did not panic") + } +} + +// Error succeeds if err is a non-nil error, and the error message equals the +// expected message. +func Error(err error, message string) Comparison { + return func() Result { + switch { + case err == nil: + return ResultFailure("expected an error, got nil") + case err.Error() != message: + return ResultFailure(fmt.Sprintf( + "expected error %q, got %s", message, formatErrorMessage(err))) + } + return ResultSuccess + } +} + +// ErrorContains succeeds if err is a non-nil error, and the error message contains +// the expected substring. +func ErrorContains(err error, substring string) Comparison { + return func() Result { + switch { + case err == nil: + return ResultFailure("expected an error, got nil") + case !strings.Contains(err.Error(), substring): + return ResultFailure(fmt.Sprintf( + "expected error to contain %q, got %s", substring, formatErrorMessage(err))) + } + return ResultSuccess + } +} + +func formatErrorMessage(err error) string { + if _, ok := err.(interface { + Cause() error + }); ok { + return fmt.Sprintf("%q\n%+v", err, err) + } + // This error was not wrapped with github.com/pkg/errors + return fmt.Sprintf("%q", err) +} + +// Nil succeeds if obj is a nil interface, pointer, or function. +// +// Use NilError() for comparing errors. Use Len(obj, 0) for comparing slices, +// maps, and channels. +func Nil(obj interface{}) Comparison { + msgFunc := func(value reflect.Value) string { + return fmt.Sprintf("%v (type %s) is not nil", reflect.Indirect(value), value.Type()) + } + return isNil(obj, msgFunc) +} + +func isNil(obj interface{}, msgFunc func(reflect.Value) string) Comparison { + return func() Result { + if obj == nil { + return ResultSuccess + } + value := reflect.ValueOf(obj) + kind := value.Type().Kind() + if kind >= reflect.Chan && kind <= reflect.Slice { + if value.IsNil() { + return ResultSuccess + } + return ResultFailure(msgFunc(value)) + } + + return ResultFailure(fmt.Sprintf("%v (type %s) can not be nil", value, value.Type())) + } +} + +// ErrorType succeeds if err is not nil and is of the expected type. +// +// Expected can be one of: +// a func(error) bool which returns true if the error is the expected type, +// an instance of (or a pointer to) a struct of the expected type, +// a pointer to an interface the error is expected to implement, +// a reflect.Type of the expected struct or interface. +func ErrorType(err error, expected interface{}) Comparison { + return func() Result { + switch expectedType := expected.(type) { + case func(error) bool: + return cmpErrorTypeFunc(err, expectedType) + case reflect.Type: + if expectedType.Kind() == reflect.Interface { + return cmpErrorTypeImplementsType(err, expectedType) + } + return cmpErrorTypeEqualType(err, expectedType) + case nil: + return ResultFailure(fmt.Sprintf("invalid type for expected: nil")) + } + + expectedType := reflect.TypeOf(expected) + switch { + case expectedType.Kind() == reflect.Struct, isPtrToStruct(expectedType): + return cmpErrorTypeEqualType(err, expectedType) + case isPtrToInterface(expectedType): + return cmpErrorTypeImplementsType(err, expectedType.Elem()) + } + return ResultFailure(fmt.Sprintf("invalid type for expected: %T", expected)) + } +} + +func cmpErrorTypeFunc(err error, f func(error) bool) Result { + if f(err) { + return ResultSuccess + } + actual := "nil" + if err != nil { + actual = fmt.Sprintf("%s (%T)", err, err) + } + return ResultFailureTemplate(`error is {{ .Data.actual }} + {{- with callArg 1 }}, not {{ formatNode . }}{{end -}}`, + map[string]interface{}{"actual": actual}) +} + +func cmpErrorTypeEqualType(err error, expectedType reflect.Type) Result { + if err == nil { + return ResultFailure(fmt.Sprintf("error is nil, not %s", expectedType)) + } + errValue := reflect.ValueOf(err) + if errValue.Type() == expectedType { + return ResultSuccess + } + return ResultFailure(fmt.Sprintf("error is %s (%T), not %s", err, err, expectedType)) +} + +func cmpErrorTypeImplementsType(err error, expectedType reflect.Type) Result { + if err == nil { + return ResultFailure(fmt.Sprintf("error is nil, not %s", expectedType)) + } + errValue := reflect.ValueOf(err) + if errValue.Type().Implements(expectedType) { + return ResultSuccess + } + return ResultFailure(fmt.Sprintf("error is %s (%T), not %s", err, err, expectedType)) +} + +func isPtrToInterface(typ reflect.Type) bool { + return typ.Kind() == reflect.Ptr && typ.Elem().Kind() == reflect.Interface +} + +func isPtrToStruct(typ reflect.Type) bool { + return typ.Kind() == reflect.Ptr && typ.Elem().Kind() == reflect.Struct +} diff --git a/vendor/gotest.tools/assert/cmp/result.go b/vendor/gotest.tools/assert/cmp/result.go new file mode 100644 index 00000000..7c3c37dd --- /dev/null +++ b/vendor/gotest.tools/assert/cmp/result.go @@ -0,0 +1,94 @@ +package cmp + +import ( + "bytes" + "fmt" + "go/ast" + "text/template" + + "gotest.tools/internal/source" +) + +// Result of a Comparison. +type Result interface { + Success() bool +} + +type result struct { + success bool + message string +} + +func (r result) Success() bool { + return r.success +} + +func (r result) FailureMessage() string { + return r.message +} + +// ResultSuccess is a constant which is returned by a ComparisonWithResult to +// indicate success. +var ResultSuccess = result{success: true} + +// ResultFailure returns a failed Result with a failure message. +func ResultFailure(message string) Result { + return result{message: message} +} + +// ResultFromError returns ResultSuccess if err is nil. Otherwise ResultFailure +// is returned with the error message as the failure message. +func ResultFromError(err error) Result { + if err == nil { + return ResultSuccess + } + return ResultFailure(err.Error()) +} + +type templatedResult struct { + success bool + template string + data map[string]interface{} +} + +func (r templatedResult) Success() bool { + return r.success +} + +func (r templatedResult) FailureMessage(args []ast.Expr) string { + msg, err := renderMessage(r, args) + if err != nil { + return fmt.Sprintf("failed to render failure message: %s", err) + } + return msg +} + +// ResultFailureTemplate returns a Result with a template string and data which +// can be used to format a failure message. The template may access data from .Data, +// the comparison args with the callArg function, and the formatNode function may +// be used to format the call args. +func ResultFailureTemplate(template string, data map[string]interface{}) Result { + return templatedResult{template: template, data: data} +} + +func renderMessage(result templatedResult, args []ast.Expr) (string, error) { + tmpl := template.New("failure").Funcs(template.FuncMap{ + "formatNode": source.FormatNode, + "callArg": func(index int) ast.Expr { + if index >= len(args) { + return nil + } + return args[index] + }, + }) + var err error + tmpl, err = tmpl.Parse(result.template) + if err != nil { + return "", err + } + buf := new(bytes.Buffer) + err = tmpl.Execute(buf, map[string]interface{}{ + "Data": result.data, + }) + return buf.String(), err +} diff --git a/vendor/gotest.tools/assert/result.go b/vendor/gotest.tools/assert/result.go new file mode 100644 index 00000000..949d9396 --- /dev/null +++ b/vendor/gotest.tools/assert/result.go @@ -0,0 +1,106 @@ +package assert + +import ( + "fmt" + "go/ast" + + "gotest.tools/assert/cmp" + "gotest.tools/internal/format" + "gotest.tools/internal/source" +) + +func runComparison( + t TestingT, + argSelector argSelector, + f cmp.Comparison, + msgAndArgs ...interface{}, +) bool { + if ht, ok := t.(helperT); ok { + ht.Helper() + } + result := f() + if result.Success() { + return true + } + + var message string + switch typed := result.(type) { + case resultWithComparisonArgs: + const stackIndex = 3 // Assert/Check, assert, runComparison + args, err := source.CallExprArgs(stackIndex) + if err != nil { + t.Log(err.Error()) + } + message = typed.FailureMessage(filterPrintableExpr(argSelector(args))) + case resultBasic: + message = typed.FailureMessage() + default: + message = fmt.Sprintf("comparison returned invalid Result type: %T", result) + } + + t.Log(format.WithCustomMessage(failureMessage+message, msgAndArgs...)) + return false +} + +type resultWithComparisonArgs interface { + FailureMessage(args []ast.Expr) string +} + +type resultBasic interface { + FailureMessage() string +} + +// filterPrintableExpr filters the ast.Expr slice to only include Expr that are +// easy to read when printed and contain relevant information to an assertion. +// +// Ident and SelectorExpr are included because they print nicely and the variable +// names may provide additional context to their values. +// BasicLit and CompositeLit are excluded because their source is equivalent to +// their value, which is already available. +// Other types are ignored for now, but could be added if they are relevant. +func filterPrintableExpr(args []ast.Expr) []ast.Expr { + result := make([]ast.Expr, len(args)) + for i, arg := range args { + if isShortPrintableExpr(arg) { + result[i] = arg + continue + } + + if starExpr, ok := arg.(*ast.StarExpr); ok { + result[i] = starExpr.X + continue + } + } + return result +} + +func isShortPrintableExpr(expr ast.Expr) bool { + switch expr.(type) { + case *ast.Ident, *ast.SelectorExpr, *ast.IndexExpr, *ast.SliceExpr: + return true + case *ast.BinaryExpr, *ast.UnaryExpr: + return true + default: + // CallExpr, ParenExpr, TypeAssertExpr, KeyValueExpr, StarExpr + return false + } +} + +type argSelector func([]ast.Expr) []ast.Expr + +func argsAfterT(args []ast.Expr) []ast.Expr { + if len(args) < 1 { + return nil + } + return args[1:] +} + +func argsFromComparisonCall(args []ast.Expr) []ast.Expr { + if len(args) < 1 { + return nil + } + if callExpr, ok := args[1].(*ast.CallExpr); ok { + return callExpr.Args + } + return nil +} diff --git a/vendor/gotest.tools/internal/difflib/LICENSE b/vendor/gotest.tools/internal/difflib/LICENSE new file mode 100644 index 00000000..c67dad61 --- /dev/null +++ b/vendor/gotest.tools/internal/difflib/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2013, Patrick Mezard +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in the +documentation and/or other materials provided with the distribution. + The names of its contributors may not be used to endorse or promote +products derived from this software without specific prior written +permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS +IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED +TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED +TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/gotest.tools/internal/difflib/difflib.go b/vendor/gotest.tools/internal/difflib/difflib.go new file mode 100644 index 00000000..b6f486b9 --- /dev/null +++ b/vendor/gotest.tools/internal/difflib/difflib.go @@ -0,0 +1,423 @@ +/*Package difflib is a partial port of Python difflib module. + +Original source: https://github.com/pmezard/go-difflib + +This file is trimmed to only the parts used by this repository. +*/ +package difflib // import "gotest.tools/internal/difflib" + +func min(a, b int) int { + if a < b { + return a + } + return b +} + +func max(a, b int) int { + if a > b { + return a + } + return b +} + +// Match stores line numbers of size of match +type Match struct { + A int + B int + Size int +} + +// OpCode identifies the type of diff +type OpCode struct { + Tag byte + I1 int + I2 int + J1 int + J2 int +} + +// SequenceMatcher compares sequence of strings. The basic +// algorithm predates, and is a little fancier than, an algorithm +// published in the late 1980's by Ratcliff and Obershelp under the +// hyperbolic name "gestalt pattern matching". The basic idea is to find +// the longest contiguous matching subsequence that contains no "junk" +// elements (R-O doesn't address junk). The same idea is then applied +// recursively to the pieces of the sequences to the left and to the right +// of the matching subsequence. This does not yield minimal edit +// sequences, but does tend to yield matches that "look right" to people. +// +// SequenceMatcher tries to compute a "human-friendly diff" between two +// sequences. Unlike e.g. UNIX(tm) diff, the fundamental notion is the +// longest *contiguous* & junk-free matching subsequence. That's what +// catches peoples' eyes. The Windows(tm) windiff has another interesting +// notion, pairing up elements that appear uniquely in each sequence. +// That, and the method here, appear to yield more intuitive difference +// reports than does diff. This method appears to be the least vulnerable +// to synching up on blocks of "junk lines", though (like blank lines in +// ordinary text files, or maybe "

" lines in HTML files). That may be +// because this is the only method of the 3 that has a *concept* of +// "junk" . +// +// Timing: Basic R-O is cubic time worst case and quadratic time expected +// case. SequenceMatcher is quadratic time for the worst case and has +// expected-case behavior dependent in a complicated way on how many +// elements the sequences have in common; best case time is linear. +type SequenceMatcher struct { + a []string + b []string + b2j map[string][]int + IsJunk func(string) bool + autoJunk bool + bJunk map[string]struct{} + matchingBlocks []Match + fullBCount map[string]int + bPopular map[string]struct{} + opCodes []OpCode +} + +// NewMatcher returns a new SequenceMatcher +func NewMatcher(a, b []string) *SequenceMatcher { + m := SequenceMatcher{autoJunk: true} + m.SetSeqs(a, b) + return &m +} + +// SetSeqs sets two sequences to be compared. +func (m *SequenceMatcher) SetSeqs(a, b []string) { + m.SetSeq1(a) + m.SetSeq2(b) +} + +// SetSeq1 sets the first sequence to be compared. The second sequence to be compared is +// not changed. +// +// SequenceMatcher computes and caches detailed information about the second +// sequence, so if you want to compare one sequence S against many sequences, +// use .SetSeq2(s) once and call .SetSeq1(x) repeatedly for each of the other +// sequences. +// +// See also SetSeqs() and SetSeq2(). +func (m *SequenceMatcher) SetSeq1(a []string) { + if &a == &m.a { + return + } + m.a = a + m.matchingBlocks = nil + m.opCodes = nil +} + +// SetSeq2 sets the second sequence to be compared. The first sequence to be compared is +// not changed. +func (m *SequenceMatcher) SetSeq2(b []string) { + if &b == &m.b { + return + } + m.b = b + m.matchingBlocks = nil + m.opCodes = nil + m.fullBCount = nil + m.chainB() +} + +func (m *SequenceMatcher) chainB() { + // Populate line -> index mapping + b2j := map[string][]int{} + for i, s := range m.b { + indices := b2j[s] + indices = append(indices, i) + b2j[s] = indices + } + + // Purge junk elements + m.bJunk = map[string]struct{}{} + if m.IsJunk != nil { + junk := m.bJunk + for s := range b2j { + if m.IsJunk(s) { + junk[s] = struct{}{} + } + } + for s := range junk { + delete(b2j, s) + } + } + + // Purge remaining popular elements + popular := map[string]struct{}{} + n := len(m.b) + if m.autoJunk && n >= 200 { + ntest := n/100 + 1 + for s, indices := range b2j { + if len(indices) > ntest { + popular[s] = struct{}{} + } + } + for s := range popular { + delete(b2j, s) + } + } + m.bPopular = popular + m.b2j = b2j +} + +func (m *SequenceMatcher) isBJunk(s string) bool { + _, ok := m.bJunk[s] + return ok +} + +// Find longest matching block in a[alo:ahi] and b[blo:bhi]. +// +// If IsJunk is not defined: +// +// Return (i,j,k) such that a[i:i+k] is equal to b[j:j+k], where +// alo <= i <= i+k <= ahi +// blo <= j <= j+k <= bhi +// and for all (i',j',k') meeting those conditions, +// k >= k' +// i <= i' +// and if i == i', j <= j' +// +// In other words, of all maximal matching blocks, return one that +// starts earliest in a, and of all those maximal matching blocks that +// start earliest in a, return the one that starts earliest in b. +// +// If IsJunk is defined, first the longest matching block is +// determined as above, but with the additional restriction that no +// junk element appears in the block. Then that block is extended as +// far as possible by matching (only) junk elements on both sides. So +// the resulting block never matches on junk except as identical junk +// happens to be adjacent to an "interesting" match. +// +// If no blocks match, return (alo, blo, 0). +func (m *SequenceMatcher) findLongestMatch(alo, ahi, blo, bhi int) Match { + // CAUTION: stripping common prefix or suffix would be incorrect. + // E.g., + // ab + // acab + // Longest matching block is "ab", but if common prefix is + // stripped, it's "a" (tied with "b"). UNIX(tm) diff does so + // strip, so ends up claiming that ab is changed to acab by + // inserting "ca" in the middle. That's minimal but unintuitive: + // "it's obvious" that someone inserted "ac" at the front. + // Windiff ends up at the same place as diff, but by pairing up + // the unique 'b's and then matching the first two 'a's. + besti, bestj, bestsize := alo, blo, 0 + + // find longest junk-free match + // during an iteration of the loop, j2len[j] = length of longest + // junk-free match ending with a[i-1] and b[j] + j2len := map[int]int{} + for i := alo; i != ahi; i++ { + // look at all instances of a[i] in b; note that because + // b2j has no junk keys, the loop is skipped if a[i] is junk + newj2len := map[int]int{} + for _, j := range m.b2j[m.a[i]] { + // a[i] matches b[j] + if j < blo { + continue + } + if j >= bhi { + break + } + k := j2len[j-1] + 1 + newj2len[j] = k + if k > bestsize { + besti, bestj, bestsize = i-k+1, j-k+1, k + } + } + j2len = newj2len + } + + // Extend the best by non-junk elements on each end. In particular, + // "popular" non-junk elements aren't in b2j, which greatly speeds + // the inner loop above, but also means "the best" match so far + // doesn't contain any junk *or* popular non-junk elements. + for besti > alo && bestj > blo && !m.isBJunk(m.b[bestj-1]) && + m.a[besti-1] == m.b[bestj-1] { + besti, bestj, bestsize = besti-1, bestj-1, bestsize+1 + } + for besti+bestsize < ahi && bestj+bestsize < bhi && + !m.isBJunk(m.b[bestj+bestsize]) && + m.a[besti+bestsize] == m.b[bestj+bestsize] { + bestsize += 1 + } + + // Now that we have a wholly interesting match (albeit possibly + // empty!), we may as well suck up the matching junk on each + // side of it too. Can't think of a good reason not to, and it + // saves post-processing the (possibly considerable) expense of + // figuring out what to do with it. In the case of an empty + // interesting match, this is clearly the right thing to do, + // because no other kind of match is possible in the regions. + for besti > alo && bestj > blo && m.isBJunk(m.b[bestj-1]) && + m.a[besti-1] == m.b[bestj-1] { + besti, bestj, bestsize = besti-1, bestj-1, bestsize+1 + } + for besti+bestsize < ahi && bestj+bestsize < bhi && + m.isBJunk(m.b[bestj+bestsize]) && + m.a[besti+bestsize] == m.b[bestj+bestsize] { + bestsize += 1 + } + + return Match{A: besti, B: bestj, Size: bestsize} +} + +// GetMatchingBlocks returns a list of triples describing matching subsequences. +// +// Each triple is of the form (i, j, n), and means that +// a[i:i+n] == b[j:j+n]. The triples are monotonically increasing in +// i and in j. It's also guaranteed that if (i, j, n) and (i', j', n') are +// adjacent triples in the list, and the second is not the last triple in the +// list, then i+n != i' or j+n != j'. IOW, adjacent triples never describe +// adjacent equal blocks. +// +// The last triple is a dummy, (len(a), len(b), 0), and is the only +// triple with n==0. +func (m *SequenceMatcher) GetMatchingBlocks() []Match { + if m.matchingBlocks != nil { + return m.matchingBlocks + } + + var matchBlocks func(alo, ahi, blo, bhi int, matched []Match) []Match + matchBlocks = func(alo, ahi, blo, bhi int, matched []Match) []Match { + match := m.findLongestMatch(alo, ahi, blo, bhi) + i, j, k := match.A, match.B, match.Size + if match.Size > 0 { + if alo < i && blo < j { + matched = matchBlocks(alo, i, blo, j, matched) + } + matched = append(matched, match) + if i+k < ahi && j+k < bhi { + matched = matchBlocks(i+k, ahi, j+k, bhi, matched) + } + } + return matched + } + matched := matchBlocks(0, len(m.a), 0, len(m.b), nil) + + // It's possible that we have adjacent equal blocks in the + // matching_blocks list now. + nonAdjacent := []Match{} + i1, j1, k1 := 0, 0, 0 + for _, b := range matched { + // Is this block adjacent to i1, j1, k1? + i2, j2, k2 := b.A, b.B, b.Size + if i1+k1 == i2 && j1+k1 == j2 { + // Yes, so collapse them -- this just increases the length of + // the first block by the length of the second, and the first + // block so lengthened remains the block to compare against. + k1 += k2 + } else { + // Not adjacent. Remember the first block (k1==0 means it's + // the dummy we started with), and make the second block the + // new block to compare against. + if k1 > 0 { + nonAdjacent = append(nonAdjacent, Match{i1, j1, k1}) + } + i1, j1, k1 = i2, j2, k2 + } + } + if k1 > 0 { + nonAdjacent = append(nonAdjacent, Match{i1, j1, k1}) + } + + nonAdjacent = append(nonAdjacent, Match{len(m.a), len(m.b), 0}) + m.matchingBlocks = nonAdjacent + return m.matchingBlocks +} + +// GetOpCodes returns a list of 5-tuples describing how to turn a into b. +// +// Each tuple is of the form (tag, i1, i2, j1, j2). The first tuple +// has i1 == j1 == 0, and remaining tuples have i1 == the i2 from the +// tuple preceding it, and likewise for j1 == the previous j2. +// +// The tags are characters, with these meanings: +// +// 'r' (replace): a[i1:i2] should be replaced by b[j1:j2] +// +// 'd' (delete): a[i1:i2] should be deleted, j1==j2 in this case. +// +// 'i' (insert): b[j1:j2] should be inserted at a[i1:i1], i1==i2 in this case. +// +// 'e' (equal): a[i1:i2] == b[j1:j2] +func (m *SequenceMatcher) GetOpCodes() []OpCode { + if m.opCodes != nil { + return m.opCodes + } + i, j := 0, 0 + matching := m.GetMatchingBlocks() + opCodes := make([]OpCode, 0, len(matching)) + for _, m := range matching { + // invariant: we've pumped out correct diffs to change + // a[:i] into b[:j], and the next matching block is + // a[ai:ai+size] == b[bj:bj+size]. So we need to pump + // out a diff to change a[i:ai] into b[j:bj], pump out + // the matching block, and move (i,j) beyond the match + ai, bj, size := m.A, m.B, m.Size + tag := byte(0) + if i < ai && j < bj { + tag = 'r' + } else if i < ai { + tag = 'd' + } else if j < bj { + tag = 'i' + } + if tag > 0 { + opCodes = append(opCodes, OpCode{tag, i, ai, j, bj}) + } + i, j = ai+size, bj+size + // the list of matching blocks is terminated by a + // sentinel with size 0 + if size > 0 { + opCodes = append(opCodes, OpCode{'e', ai, i, bj, j}) + } + } + m.opCodes = opCodes + return m.opCodes +} + +// GetGroupedOpCodes isolates change clusters by eliminating ranges with no changes. +// +// Return a generator of groups with up to n lines of context. +// Each group is in the same format as returned by GetOpCodes(). +func (m *SequenceMatcher) GetGroupedOpCodes(n int) [][]OpCode { + if n < 0 { + n = 3 + } + codes := m.GetOpCodes() + if len(codes) == 0 { + codes = []OpCode{{'e', 0, 1, 0, 1}} + } + // Fixup leading and trailing groups if they show no changes. + if codes[0].Tag == 'e' { + c := codes[0] + i1, i2, j1, j2 := c.I1, c.I2, c.J1, c.J2 + codes[0] = OpCode{c.Tag, max(i1, i2-n), i2, max(j1, j2-n), j2} + } + if codes[len(codes)-1].Tag == 'e' { + c := codes[len(codes)-1] + i1, i2, j1, j2 := c.I1, c.I2, c.J1, c.J2 + codes[len(codes)-1] = OpCode{c.Tag, i1, min(i2, i1+n), j1, min(j2, j1+n)} + } + nn := n + n + groups := [][]OpCode{} + group := []OpCode{} + for _, c := range codes { + i1, i2, j1, j2 := c.I1, c.I2, c.J1, c.J2 + // End the current group and start a new one whenever + // there is a large range with no changes. + if c.Tag == 'e' && i2-i1 > nn { + group = append(group, OpCode{c.Tag, i1, min(i2, i1+n), + j1, min(j2, j1+n)}) + groups = append(groups, group) + group = []OpCode{} + i1, j1 = max(i1, i2-n), max(j1, j2-n) + } + group = append(group, OpCode{c.Tag, i1, i2, j1, j2}) + } + if len(group) > 0 && !(len(group) == 1 && group[0].Tag == 'e') { + groups = append(groups, group) + } + return groups +} diff --git a/vendor/gotest.tools/internal/format/diff.go b/vendor/gotest.tools/internal/format/diff.go new file mode 100644 index 00000000..c938c97b --- /dev/null +++ b/vendor/gotest.tools/internal/format/diff.go @@ -0,0 +1,161 @@ +package format + +import ( + "bytes" + "fmt" + "strings" + "unicode" + + "gotest.tools/internal/difflib" +) + +const ( + contextLines = 2 +) + +// DiffConfig for a unified diff +type DiffConfig struct { + A string + B string + From string + To string +} + +// UnifiedDiff is a modified version of difflib.WriteUnifiedDiff with better +// support for showing the whitespace differences. +func UnifiedDiff(conf DiffConfig) string { + a := strings.SplitAfter(conf.A, "\n") + b := strings.SplitAfter(conf.B, "\n") + groups := difflib.NewMatcher(a, b).GetGroupedOpCodes(contextLines) + if len(groups) == 0 { + return "" + } + + buf := new(bytes.Buffer) + writeFormat := func(format string, args ...interface{}) { + buf.WriteString(fmt.Sprintf(format, args...)) + } + writeLine := func(prefix string, s string) { + buf.WriteString(prefix + s) + } + if hasWhitespaceDiffLines(groups, a, b) { + writeLine = visibleWhitespaceLine(writeLine) + } + formatHeader(writeFormat, conf) + for _, group := range groups { + formatRangeLine(writeFormat, group) + for _, opCode := range group { + in, out := a[opCode.I1:opCode.I2], b[opCode.J1:opCode.J2] + switch opCode.Tag { + case 'e': + formatLines(writeLine, " ", in) + case 'r': + formatLines(writeLine, "-", in) + formatLines(writeLine, "+", out) + case 'd': + formatLines(writeLine, "-", in) + case 'i': + formatLines(writeLine, "+", out) + } + } + } + return buf.String() +} + +// hasWhitespaceDiffLines returns true if any diff groups is only different +// because of whitespace characters. +func hasWhitespaceDiffLines(groups [][]difflib.OpCode, a, b []string) bool { + for _, group := range groups { + in, out := new(bytes.Buffer), new(bytes.Buffer) + for _, opCode := range group { + if opCode.Tag == 'e' { + continue + } + for _, line := range a[opCode.I1:opCode.I2] { + in.WriteString(line) + } + for _, line := range b[opCode.J1:opCode.J2] { + out.WriteString(line) + } + } + if removeWhitespace(in.String()) == removeWhitespace(out.String()) { + return true + } + } + return false +} + +func removeWhitespace(s string) string { + var result []rune + for _, r := range s { + if !unicode.IsSpace(r) { + result = append(result, r) + } + } + return string(result) +} + +func visibleWhitespaceLine(ws func(string, string)) func(string, string) { + mapToVisibleSpace := func(r rune) rune { + switch r { + case '\n': + case ' ': + return '·' + case '\t': + return '▷' + case '\v': + return '▽' + case '\r': + return '↵' + case '\f': + return '↓' + default: + if unicode.IsSpace(r) { + return '�' + } + } + return r + } + return func(prefix, s string) { + ws(prefix, strings.Map(mapToVisibleSpace, s)) + } +} + +func formatHeader(wf func(string, ...interface{}), conf DiffConfig) { + if conf.From != "" || conf.To != "" { + wf("--- %s\n", conf.From) + wf("+++ %s\n", conf.To) + } +} + +func formatRangeLine(wf func(string, ...interface{}), group []difflib.OpCode) { + first, last := group[0], group[len(group)-1] + range1 := formatRangeUnified(first.I1, last.I2) + range2 := formatRangeUnified(first.J1, last.J2) + wf("@@ -%s +%s @@\n", range1, range2) +} + +// Convert range to the "ed" format +func formatRangeUnified(start, stop int) string { + // Per the diff spec at http://www.unix.org/single_unix_specification/ + beginning := start + 1 // lines start numbering with one + length := stop - start + if length == 1 { + return fmt.Sprintf("%d", beginning) + } + if length == 0 { + beginning-- // empty ranges begin at line just before the range + } + return fmt.Sprintf("%d,%d", beginning, length) +} + +func formatLines(writeLine func(string, string), prefix string, lines []string) { + for _, line := range lines { + writeLine(prefix, line) + } + // Add a newline if the last line is missing one so that the diff displays + // properly. + if !strings.HasSuffix(lines[len(lines)-1], "\n") { + writeLine("", "\n") + } +} diff --git a/vendor/gotest.tools/internal/format/format.go b/vendor/gotest.tools/internal/format/format.go new file mode 100644 index 00000000..8f6494f9 --- /dev/null +++ b/vendor/gotest.tools/internal/format/format.go @@ -0,0 +1,27 @@ +package format // import "gotest.tools/internal/format" + +import "fmt" + +// Message accepts a msgAndArgs varargs and formats it using fmt.Sprintf +func Message(msgAndArgs ...interface{}) string { + switch len(msgAndArgs) { + case 0: + return "" + case 1: + return fmt.Sprintf("%v", msgAndArgs[0]) + default: + return fmt.Sprintf(msgAndArgs[0].(string), msgAndArgs[1:]...) + } +} + +// WithCustomMessage accepts one or two messages and formats them appropriately +func WithCustomMessage(source string, msgAndArgs ...interface{}) string { + custom := Message(msgAndArgs...) + switch { + case custom == "": + return source + case source == "": + return custom + } + return fmt.Sprintf("%s: %s", source, custom) +} diff --git a/vendor/gotest.tools/internal/source/defers.go b/vendor/gotest.tools/internal/source/defers.go new file mode 100644 index 00000000..66cfafbb --- /dev/null +++ b/vendor/gotest.tools/internal/source/defers.go @@ -0,0 +1,53 @@ +package source + +import ( + "go/ast" + "go/token" + + "github.com/pkg/errors" +) + +func scanToDeferLine(fileset *token.FileSet, node ast.Node, lineNum int) ast.Node { + var matchedNode ast.Node + ast.Inspect(node, func(node ast.Node) bool { + switch { + case node == nil || matchedNode != nil: + return false + case fileset.Position(node.End()).Line == lineNum: + if funcLit, ok := node.(*ast.FuncLit); ok { + matchedNode = funcLit + return false + } + } + return true + }) + debug("defer line node: %s", debugFormatNode{matchedNode}) + return matchedNode +} + +func guessDefer(node ast.Node) (ast.Node, error) { + defers := collectDefers(node) + switch len(defers) { + case 0: + return nil, errors.New("failed to expression in defer") + case 1: + return defers[0].Call, nil + default: + return nil, errors.Errorf( + "ambiguous call expression: multiple (%d) defers in call block", + len(defers)) + } +} + +func collectDefers(node ast.Node) []*ast.DeferStmt { + var defers []*ast.DeferStmt + ast.Inspect(node, func(node ast.Node) bool { + if d, ok := node.(*ast.DeferStmt); ok { + defers = append(defers, d) + debug("defer: %s", debugFormatNode{d}) + return false + } + return true + }) + return defers +} diff --git a/vendor/gotest.tools/internal/source/source.go b/vendor/gotest.tools/internal/source/source.go new file mode 100644 index 00000000..8a5d0e8d --- /dev/null +++ b/vendor/gotest.tools/internal/source/source.go @@ -0,0 +1,166 @@ +package source // import "gotest.tools/internal/source" + +import ( + "bytes" + "fmt" + "go/ast" + "go/format" + "go/parser" + "go/token" + "os" + "runtime" + "strconv" + "strings" + + "github.com/pkg/errors" +) + +const baseStackIndex = 1 + +// FormattedCallExprArg returns the argument from an ast.CallExpr at the +// index in the call stack. The argument is formatted using FormatNode. +func FormattedCallExprArg(stackIndex int, argPos int) (string, error) { + args, err := CallExprArgs(stackIndex + 1) + if err != nil { + return "", err + } + if argPos >= len(args) { + return "", errors.New("failed to find expression") + } + return FormatNode(args[argPos]) +} + +// CallExprArgs returns the ast.Expr slice for the args of an ast.CallExpr at +// the index in the call stack. +func CallExprArgs(stackIndex int) ([]ast.Expr, error) { + _, filename, lineNum, ok := runtime.Caller(baseStackIndex + stackIndex) + if !ok { + return nil, errors.New("failed to get call stack") + } + debug("call stack position: %s:%d", filename, lineNum) + + node, err := getNodeAtLine(filename, lineNum) + if err != nil { + return nil, err + } + debug("found node: %s", debugFormatNode{node}) + + return getCallExprArgs(node) +} + +func getNodeAtLine(filename string, lineNum int) (ast.Node, error) { + fileset := token.NewFileSet() + astFile, err := parser.ParseFile(fileset, filename, nil, parser.AllErrors) + if err != nil { + return nil, errors.Wrapf(err, "failed to parse source file: %s", filename) + } + + if node := scanToLine(fileset, astFile, lineNum); node != nil { + return node, nil + } + if node := scanToDeferLine(fileset, astFile, lineNum); node != nil { + node, err := guessDefer(node) + if err != nil || node != nil { + return node, err + } + } + return nil, errors.Errorf( + "failed to find an expression on line %d in %s", lineNum, filename) +} + +func scanToLine(fileset *token.FileSet, node ast.Node, lineNum int) ast.Node { + var matchedNode ast.Node + ast.Inspect(node, func(node ast.Node) bool { + switch { + case node == nil || matchedNode != nil: + return false + case nodePosition(fileset, node).Line == lineNum: + matchedNode = node + return false + } + return true + }) + return matchedNode +} + +// In golang 1.9 the line number changed from being the line where the statement +// ended to the line where the statement began. +func nodePosition(fileset *token.FileSet, node ast.Node) token.Position { + if goVersionBefore19 { + return fileset.Position(node.End()) + } + return fileset.Position(node.Pos()) +} + +var goVersionBefore19 = func() bool { + version := runtime.Version() + // not a release version + if !strings.HasPrefix(version, "go") { + return false + } + version = strings.TrimPrefix(version, "go") + parts := strings.Split(version, ".") + if len(parts) < 2 { + return false + } + minor, err := strconv.ParseInt(parts[1], 10, 32) + return err == nil && parts[0] == "1" && minor < 9 +}() + +func getCallExprArgs(node ast.Node) ([]ast.Expr, error) { + visitor := &callExprVisitor{} + ast.Walk(visitor, node) + if visitor.expr == nil { + return nil, errors.New("failed to find call expression") + } + debug("callExpr: %s", debugFormatNode{visitor.expr}) + return visitor.expr.Args, nil +} + +type callExprVisitor struct { + expr *ast.CallExpr +} + +func (v *callExprVisitor) Visit(node ast.Node) ast.Visitor { + if v.expr != nil || node == nil { + return nil + } + debug("visit: %s", debugFormatNode{node}) + + switch typed := node.(type) { + case *ast.CallExpr: + v.expr = typed + return nil + case *ast.DeferStmt: + ast.Walk(v, typed.Call.Fun) + return nil + } + return v +} + +// FormatNode using go/format.Node and return the result as a string +func FormatNode(node ast.Node) (string, error) { + buf := new(bytes.Buffer) + err := format.Node(buf, token.NewFileSet(), node) + return buf.String(), err +} + +var debugEnabled = os.Getenv("GOTESTTOOLS_DEBUG") != "" + +func debug(format string, args ...interface{}) { + if debugEnabled { + fmt.Fprintf(os.Stderr, "DEBUG: "+format+"\n", args...) + } +} + +type debugFormatNode struct { + ast.Node +} + +func (n debugFormatNode) String() string { + out, err := FormatNode(n.Node) + if err != nil { + return fmt.Sprintf("failed to format %s: %s", n.Node, err) + } + return fmt.Sprintf("(%T) %s", n.Node, out) +} diff --git a/vendor/modules.txt b/vendor/modules.txt index c519d007..a316e543 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -431,6 +431,13 @@ gopkg.in/yaml.v2 # gopkg.in/yaml.v3 v3.0.1 ## explicit gopkg.in/yaml.v3 +# gotest.tools v2.2.0+incompatible +## explicit +gotest.tools/assert +gotest.tools/assert/cmp +gotest.tools/internal/difflib +gotest.tools/internal/format +gotest.tools/internal/source # k8s.io/api v0.27.4 ## explicit; go 1.20 k8s.io/api/admission/v1