diff --git a/api/handler/v1/runtime.go b/api/handler/v1/runtime.go index 14123a943c..67d8c1785c 100644 --- a/api/handler/v1/runtime.go +++ b/api/handler/v1/runtime.go @@ -94,13 +94,13 @@ func (sv *RuntimeServiceServer) DeployJobSpecification(req *pb.DeployJobSpecific startTime := time.Now() projectRepo := sv.projectRepoFactory.New() - projSpec, err := projectRepo.GetByName(req.GetProjectName()) + projSpec, err := projectRepo.GetByName(respStream.Context(), req.GetProjectName()) if err != nil { return status.Errorf(codes.NotFound, "%s: project %s not found", err.Error(), req.GetProjectName()) } namespaceRepo := sv.namespaceRepoFactory.New(projSpec) - namespaceSpec, err := namespaceRepo.GetByName(req.GetNamespace()) + namespaceSpec, err := namespaceRepo.GetByName(respStream.Context(), req.GetNamespace()) if err != nil { return status.Errorf(codes.NotFound, "%s: namespace %s not found", err.Error(), req.GetNamespace()) } @@ -112,7 +112,7 @@ func (sv *RuntimeServiceServer) DeployJobSpecification(req *pb.DeployJobSpecific return status.Errorf(codes.Internal, "%s: cannot adapt job %s", err.Error(), reqJob.GetName()) } - err = sv.jobSvc.Create(namespaceSpec, adaptJob) + err = sv.jobSvc.Create(respStream.Context(), namespaceSpec, adaptJob) if err != nil { return status.Errorf(codes.Internal, "%s: failed to save %s", err.Error(), adaptJob.Name) } @@ -128,7 +128,7 @@ func (sv *RuntimeServiceServer) DeployJobSpecification(req *pb.DeployJobSpecific }) // delete specs not sent for deployment from internal repository - if err := sv.jobSvc.KeepOnly(namespaceSpec, jobsToKeep, observers); err != nil { + if err := sv.jobSvc.KeepOnly(respStream.Context(), namespaceSpec, jobsToKeep, observers); err != nil { return status.Errorf(codes.Internal, "failed to delete jobs: \n%s", err.Error()) } @@ -142,18 +142,18 @@ func (sv *RuntimeServiceServer) DeployJobSpecification(req *pb.DeployJobSpecific func (sv *RuntimeServiceServer) ListJobSpecification(ctx context.Context, req *pb.ListJobSpecificationRequest) (*pb.ListJobSpecificationResponse, error) { projectRepo := sv.projectRepoFactory.New() - projSpec, err := projectRepo.GetByName(req.GetProjectName()) + projSpec, err := projectRepo.GetByName(ctx, req.GetProjectName()) if err != nil { return nil, status.Errorf(codes.NotFound, "%s: project %s not found", err.Error(), req.GetProjectName()) } namespaceRepo := sv.namespaceRepoFactory.New(projSpec) - namespaceSpec, err := namespaceRepo.GetByName(req.GetNamespace()) + namespaceSpec, err := namespaceRepo.GetByName(ctx, req.GetNamespace()) if err != nil { return nil, status.Errorf(codes.NotFound, "%s: namespace %s not found", err.Error(), req.GetNamespace()) } - jobSpecs, err := sv.jobSvc.GetAll(namespaceSpec) + jobSpecs, err := sv.jobSvc.GetAll(ctx, namespaceSpec) if err != nil { return nil, status.Errorf(codes.Internal, "%s: failed to retrieve jobs for project %s", err.Error(), req.GetProjectName()) } @@ -177,13 +177,13 @@ func (sv *RuntimeServiceServer) DumpJobSpecification(ctx context.Context, req *p func (sv *RuntimeServiceServer) CheckJobSpecification(ctx context.Context, req *pb.CheckJobSpecificationRequest) (*pb.CheckJobSpecificationResponse, error) { projectRepo := sv.projectRepoFactory.New() - projSpec, err := projectRepo.GetByName(req.GetProjectName()) + projSpec, err := projectRepo.GetByName(ctx, req.GetProjectName()) if err != nil { return nil, status.Errorf(codes.NotFound, "%s: project %s not found", err.Error(), req.GetProjectName()) } namespaceRepo := sv.namespaceRepoFactory.New(projSpec) - namespaceSpec, err := namespaceRepo.GetByName(req.GetNamespace()) + namespaceSpec, err := namespaceRepo.GetByName(ctx, req.GetNamespace()) if err != nil { return nil, status.Errorf(codes.NotFound, "%s: namespace %s not found", err.Error(), req.GetNamespace()) } @@ -202,13 +202,13 @@ func (sv *RuntimeServiceServer) CheckJobSpecification(ctx context.Context, req * func (sv *RuntimeServiceServer) CheckJobSpecifications(req *pb.CheckJobSpecificationsRequest, respStream pb.RuntimeService_CheckJobSpecificationsServer) error { projectRepo := sv.projectRepoFactory.New() - projSpec, err := projectRepo.GetByName(req.GetProjectName()) + projSpec, err := projectRepo.GetByName(respStream.Context(), req.GetProjectName()) if err != nil { return status.Errorf(codes.NotFound, "%s: project %s not found", err.Error(), req.GetProjectName()) } namespaceRepo := sv.namespaceRepoFactory.New(projSpec) - namespaceSpec, err := namespaceRepo.GetByName(req.GetNamespace()) + namespaceSpec, err := namespaceRepo.GetByName(respStream.Context(), req.GetNamespace()) if err != nil { return status.Errorf(codes.NotFound, "%s: namespace %s not found", err.Error(), req.GetNamespace()) } @@ -240,12 +240,12 @@ func (sv *RuntimeServiceServer) RegisterProject(ctx context.Context, req *pb.Reg projectRepo := sv.projectRepoFactory.New() projectSpec := sv.adapter.FromProjectProto(req.GetProject()) - if err := projectRepo.Save(projectSpec); err != nil { + if err := projectRepo.Save(ctx, projectSpec); err != nil { return nil, status.Errorf(codes.Internal, "%s: failed to save project %s", err.Error(), req.GetProject().GetName()) } if req.GetNamespace() != nil { - savedProjectSpec, err := projectRepo.GetByName(projectSpec.Name) + savedProjectSpec, err := projectRepo.GetByName(ctx, projectSpec.Name) if err != nil { return nil, status.Errorf(codes.NotFound, "%s: failed to find project %s", err.Error(), req.GetProject().GetName()) @@ -253,7 +253,7 @@ func (sv *RuntimeServiceServer) RegisterProject(ctx context.Context, req *pb.Reg namespaceRepo := sv.namespaceRepoFactory.New(savedProjectSpec) namespaceSpec := sv.adapter.FromNamespaceProto(req.GetNamespace()) - if err = namespaceRepo.Save(namespaceSpec); err != nil { + if err = namespaceRepo.Save(ctx, namespaceSpec); err != nil { return nil, status.Errorf(codes.Internal, "%s: failed to save project %s with namespace %s", err.Error(), req.GetProject().GetName(), req.GetNamespace().GetName()) } @@ -267,14 +267,14 @@ func (sv *RuntimeServiceServer) RegisterProject(ctx context.Context, req *pb.Reg func (sv *RuntimeServiceServer) RegisterProjectNamespace(ctx context.Context, req *pb.RegisterProjectNamespaceRequest) (*pb.RegisterProjectNamespaceResponse, error) { projectRepo := sv.projectRepoFactory.New() - projSpec, err := projectRepo.GetByName(req.GetProjectName()) + projSpec, err := projectRepo.GetByName(ctx, req.GetProjectName()) if err != nil { return nil, status.Errorf(codes.NotFound, "%s: project %s not found", err.Error(), req.GetProjectName()) } namespaceSpec := sv.adapter.FromNamespaceProto(req.GetNamespace()) namespaceRepo := sv.namespaceRepoFactory.New(projSpec) - if err = namespaceRepo.Save(namespaceSpec); err != nil { + if err = namespaceRepo.Save(ctx, namespaceSpec); err != nil { return nil, status.Errorf(codes.Internal, "%s: failed to save namespace %s for project %s", err.Error(), namespaceSpec.Name, projSpec.Name) } @@ -286,13 +286,13 @@ func (sv *RuntimeServiceServer) RegisterProjectNamespace(ctx context.Context, re func (sv *RuntimeServiceServer) CreateJobSpecification(ctx context.Context, req *pb.CreateJobSpecificationRequest) (*pb.CreateJobSpecificationResponse, error) { projectRepo := sv.projectRepoFactory.New() - projSpec, err := projectRepo.GetByName(req.GetProjectName()) + projSpec, err := projectRepo.GetByName(ctx, req.GetProjectName()) if err != nil { return nil, status.Errorf(codes.NotFound, "%s: project %s not found", err.Error(), req.GetProjectName()) } namespaceRepo := sv.namespaceRepoFactory.New(projSpec) - namespaceSpec, err := namespaceRepo.GetByName(req.GetNamespace()) + namespaceSpec, err := namespaceRepo.GetByName(ctx, req.GetNamespace()) if err != nil { return nil, status.Errorf(codes.NotFound, "%s: namespace %s not found. Is it registered?", err.Error(), req.GetNamespace()) } @@ -307,7 +307,7 @@ func (sv *RuntimeServiceServer) CreateJobSpecification(ctx context.Context, req return nil, status.Errorf(codes.Internal, "spec validation failed\n%s", err.Error()) } - err = sv.jobSvc.Create(namespaceSpec, jobSpec) + err = sv.jobSvc.Create(ctx, namespaceSpec, jobSpec) if err != nil { return nil, status.Errorf(codes.Internal, "%s: failed to save job %s", err.Error(), jobSpec.Name) } @@ -324,18 +324,18 @@ func (sv *RuntimeServiceServer) CreateJobSpecification(ctx context.Context, req func (sv *RuntimeServiceServer) ReadJobSpecification(ctx context.Context, req *pb.ReadJobSpecificationRequest) (*pb.ReadJobSpecificationResponse, error) { projectRepo := sv.projectRepoFactory.New() - projSpec, err := projectRepo.GetByName(req.GetProjectName()) + projSpec, err := projectRepo.GetByName(ctx, req.GetProjectName()) if err != nil { return nil, status.Errorf(codes.NotFound, "%s: project %s not found", err.Error(), req.GetProjectName()) } namespaceRepo := sv.namespaceRepoFactory.New(projSpec) - namespaceSpec, err := namespaceRepo.GetByName(req.GetNamespace()) + namespaceSpec, err := namespaceRepo.GetByName(ctx, req.GetNamespace()) if err != nil { return nil, status.Errorf(codes.NotFound, "%s: namespace %s not found. Is it registered?", err.Error(), req.GetNamespace()) } - jobSpec, err := sv.jobSvc.GetByName(req.GetJobName(), namespaceSpec) + jobSpec, err := sv.jobSvc.GetByName(ctx, req.GetJobName(), namespaceSpec) if err != nil { return nil, status.Errorf(codes.NotFound, "%s: error while finding the job %s", err.Error(), req.GetJobName()) } @@ -352,18 +352,18 @@ func (sv *RuntimeServiceServer) ReadJobSpecification(ctx context.Context, req *p func (sv *RuntimeServiceServer) DeleteJobSpecification(ctx context.Context, req *pb.DeleteJobSpecificationRequest) (*pb.DeleteJobSpecificationResponse, error) { projectRepo := sv.projectRepoFactory.New() - projSpec, err := projectRepo.GetByName(req.GetProjectName()) + projSpec, err := projectRepo.GetByName(ctx, req.GetProjectName()) if err != nil { return nil, status.Errorf(codes.NotFound, "%s: project %s not found", err.Error(), req.GetProjectName()) } namespaceRepo := sv.namespaceRepoFactory.New(projSpec) - namespaceSpec, err := namespaceRepo.GetByName(req.GetNamespace()) + namespaceSpec, err := namespaceRepo.GetByName(ctx, req.GetNamespace()) if err != nil { return nil, status.Errorf(codes.NotFound, "%s: namespace %s not found. Is it registered?", err.Error(), req.GetNamespace()) } - jobSpecToDelete, err := sv.jobSvc.GetByName(req.GetJobName(), namespaceSpec) + jobSpecToDelete, err := sv.jobSvc.GetByName(ctx, req.GetJobName(), namespaceSpec) if err != nil { return nil, status.Errorf(codes.NotFound, "%s: job %s does not exist", err.Error(), req.GetJobName()) } @@ -380,7 +380,7 @@ func (sv *RuntimeServiceServer) DeleteJobSpecification(ctx context.Context, req func (sv *RuntimeServiceServer) ListProjects(ctx context.Context, req *pb.ListProjectsRequest) (*pb.ListProjectsResponse, error) { projectRepo := sv.projectRepoFactory.New() - projects, err := projectRepo.GetAll() + projects, err := projectRepo.GetAll(ctx) if err != nil { return nil, status.Errorf(codes.NotFound, "failed to retrieve saved projects: \n%s", err.Error()) } @@ -397,13 +397,13 @@ func (sv *RuntimeServiceServer) ListProjects(ctx context.Context, req *pb.ListPr func (sv *RuntimeServiceServer) ListProjectNamespaces(ctx context.Context, req *pb.ListProjectNamespacesRequest) (*pb.ListProjectNamespacesResponse, error) { projectRepo := sv.projectRepoFactory.New() - projSpec, err := projectRepo.GetByName(req.GetProjectName()) + projSpec, err := projectRepo.GetByName(ctx, req.GetProjectName()) if err != nil { return nil, status.Errorf(codes.NotFound, "%s: project %s not found", err.Error(), req.GetProjectName()) } namespaceRepo := sv.namespaceRepoFactory.New(projSpec) - namespaceSpecs, err := namespaceRepo.GetAll() + namespaceSpecs, err := namespaceRepo.GetAll(ctx) if err != nil { return nil, status.Errorf(codes.Internal, "error while fetching namespaces: \n%s", err.Error()) } @@ -420,7 +420,7 @@ func (sv *RuntimeServiceServer) ListProjectNamespaces(ctx context.Context, req * func (sv *RuntimeServiceServer) RegisterInstance(ctx context.Context, req *pb.RegisterInstanceRequest) (*pb.RegisterInstanceResponse, error) { projectRepo := sv.projectRepoFactory.New() - projSpec, err := projectRepo.GetByName(req.GetProjectName()) + projSpec, err := projectRepo.GetByName(ctx, req.GetProjectName()) if err != nil { return nil, status.Errorf(codes.NotFound, "%s: project %s not found", err.Error(), req.GetProjectName()) } @@ -435,12 +435,12 @@ func (sv *RuntimeServiceServer) RegisterInstance(ctx context.Context, req *pb.Re if req.JobrunId == "" { var jobSpec models.JobSpec // a scheduled trigger instance, extract job run id if already present or create a new run - jobSpec, namespaceSpec, err = sv.jobSvc.GetByNameForProject(req.GetJobName(), projSpec) + jobSpec, namespaceSpec, err = sv.jobSvc.GetByNameForProject(ctx, req.GetJobName(), projSpec) if err != nil { return nil, status.Errorf(codes.NotFound, "%s: job %s not found", err.Error(), req.GetJobName()) } - jobRun, err = sv.runSvc.GetScheduledRun(namespaceSpec, jobSpec, req.GetScheduledAt().AsTime()) + jobRun, err = sv.runSvc.GetScheduledRun(ctx, namespaceSpec, jobSpec, req.GetScheduledAt().AsTime()) if err != nil { return nil, status.Errorf(codes.Internal, "%s: failed to initialize scheduled run of job %s", err.Error(), req.GetJobName()) } @@ -450,17 +450,17 @@ func (sv *RuntimeServiceServer) RegisterInstance(ctx context.Context, req *pb.Re if err != nil { return nil, status.Errorf(codes.InvalidArgument, "%s: failed to parse uuid of job %s", err.Error(), req.JobrunId) } - jobRun, namespaceSpec, err = sv.runSvc.GetByID(jobRunID) + jobRun, namespaceSpec, err = sv.runSvc.GetByID(ctx, jobRunID) if err != nil { return nil, status.Errorf(codes.NotFound, "%s: failed to find scheduled run of job %s", err.Error(), req.JobrunId) } } - instance, err := sv.runSvc.Register(namespaceSpec, jobRun, instanceType, req.GetInstanceName()) + instance, err := sv.runSvc.Register(ctx, namespaceSpec, jobRun, instanceType, req.GetInstanceName()) if err != nil { return nil, status.Errorf(codes.Internal, "%s: failed to register instance of jobrun %s", err.Error(), jobRun) } - envMap, fileMap, err := sv.runSvc.Compile(namespaceSpec, jobRun, instance) + envMap, fileMap, err := sv.runSvc.Compile(ctx, namespaceSpec, jobRun, instance) if err != nil { return nil, status.Errorf(codes.Internal, "%s: failed to compile instance of job %s", err.Error(), req.GetJobName()) } @@ -487,12 +487,12 @@ func (sv *RuntimeServiceServer) RegisterInstance(ctx context.Context, req *pb.Re func (sv *RuntimeServiceServer) JobStatus(ctx context.Context, req *pb.JobStatusRequest) (*pb.JobStatusResponse, error) { projectRepo := sv.projectRepoFactory.New() - projSpec, err := projectRepo.GetByName(req.GetProjectName()) + projSpec, err := projectRepo.GetByName(ctx, req.GetProjectName()) if err != nil { return nil, status.Errorf(codes.NotFound, "%s: project %s not found", err.Error(), req.GetProjectName()) } - _, _, err = sv.jobSvc.GetByNameForProject(req.GetJobName(), projSpec) + _, _, err = sv.jobSvc.GetByNameForProject(ctx, req.GetJobName(), projSpec) if err != nil { return nil, status.Errorf(codes.NotFound, "%s\nfailed to find the job %s for project %s", err.Error(), req.GetJobName(), req.GetProjectName()) @@ -519,18 +519,18 @@ func (sv *RuntimeServiceServer) JobStatus(ctx context.Context, req *pb.JobStatus func (sv *RuntimeServiceServer) RegisterJobEvent(ctx context.Context, req *pb.RegisterJobEventRequest) (*pb.RegisterJobEventResponse, error) { projectRepo := sv.projectRepoFactory.New() - projSpec, err := projectRepo.GetByName(req.GetProjectName()) + projSpec, err := projectRepo.GetByName(ctx, req.GetProjectName()) if err != nil { return nil, status.Errorf(codes.NotFound, "%s: project %s not found", err.Error(), req.GetProjectName()) } namespaceRepo := sv.namespaceRepoFactory.New(projSpec) - namespaceSpec, err := namespaceRepo.GetByName(req.GetNamespace()) + namespaceSpec, err := namespaceRepo.GetByName(ctx, req.GetNamespace()) if err != nil { return nil, status.Errorf(codes.NotFound, "%s: namespace %s not found", err.Error(), req.GetNamespace()) } - jobSpec, err := sv.jobSvc.GetByName(req.GetJobName(), namespaceSpec) + jobSpec, err := sv.jobSvc.GetByName(ctx, req.GetJobName(), namespaceSpec) if err != nil { return nil, status.Errorf(codes.NotFound, "%s: failed to find the job %s for namespace %s", err.Error(), req.GetJobName(), req.GetNamespace()) @@ -590,13 +590,13 @@ func (sv *RuntimeServiceServer) RegisterSecret(ctx context.Context, req *pb.Regi } projectRepo := sv.projectRepoFactory.New() - projSpec, err := projectRepo.GetByName(req.GetProjectName()) + projSpec, err := projectRepo.GetByName(ctx, req.GetProjectName()) if err != nil { return nil, status.Errorf(codes.NotFound, "%s: project %s not found", err.Error(), req.GetProjectName()) } secretRepo := sv.secretRepoFactory.New(projSpec) - if err := secretRepo.Save(models.ProjectSecretItem{ + if err := secretRepo.Save(ctx, models.ProjectSecretItem{ Name: req.GetSecretName(), Value: string(base64Decoded), }); err != nil { @@ -610,13 +610,13 @@ func (sv *RuntimeServiceServer) RegisterSecret(ctx context.Context, req *pb.Regi func (sv *RuntimeServiceServer) CreateResource(ctx context.Context, req *pb.CreateResourceRequest) (*pb.CreateResourceResponse, error) { projectRepo := sv.projectRepoFactory.New() - projSpec, err := projectRepo.GetByName(req.GetProjectName()) + projSpec, err := projectRepo.GetByName(ctx, req.GetProjectName()) if err != nil { return nil, status.Errorf(codes.NotFound, "%s: project %s not found", err.Error(), req.GetProjectName()) } namespaceRepo := sv.namespaceRepoFactory.New(projSpec) - namespaceSpec, err := namespaceRepo.GetByName(req.GetNamespace()) + namespaceSpec, err := namespaceRepo.GetByName(ctx, req.GetNamespace()) if err != nil { return nil, status.Errorf(codes.NotFound, "%s: namespace %s not found", err.Error(), req.GetNamespace()) } @@ -636,13 +636,13 @@ func (sv *RuntimeServiceServer) CreateResource(ctx context.Context, req *pb.Crea func (sv *RuntimeServiceServer) UpdateResource(ctx context.Context, req *pb.UpdateResourceRequest) (*pb.UpdateResourceResponse, error) { projectRepo := sv.projectRepoFactory.New() - projSpec, err := projectRepo.GetByName(req.GetProjectName()) + projSpec, err := projectRepo.GetByName(ctx, req.GetProjectName()) if err != nil { return nil, status.Errorf(codes.NotFound, "%s: project %s not found", err.Error(), req.GetProjectName()) } namespaceRepo := sv.namespaceRepoFactory.New(projSpec) - namespaceSpec, err := namespaceRepo.GetByName(req.GetNamespace()) + namespaceSpec, err := namespaceRepo.GetByName(ctx, req.GetNamespace()) if err != nil { return nil, status.Errorf(codes.NotFound, "%s: namespace %s not found", err.Error(), req.GetNamespace()) } @@ -662,13 +662,13 @@ func (sv *RuntimeServiceServer) UpdateResource(ctx context.Context, req *pb.Upda func (sv *RuntimeServiceServer) ReadResource(ctx context.Context, req *pb.ReadResourceRequest) (*pb.ReadResourceResponse, error) { projectRepo := sv.projectRepoFactory.New() - projSpec, err := projectRepo.GetByName(req.GetProjectName()) + projSpec, err := projectRepo.GetByName(ctx, req.GetProjectName()) if err != nil { return nil, status.Errorf(codes.NotFound, "%s: project %s not found", err.Error(), req.GetProjectName()) } namespaceRepo := sv.namespaceRepoFactory.New(projSpec) - namespaceSpec, err := namespaceRepo.GetByName(req.GetNamespace()) + namespaceSpec, err := namespaceRepo.GetByName(ctx, req.GetNamespace()) if err != nil { return nil, status.Errorf(codes.NotFound, "%s: namespace %s not found", err.Error(), req.GetNamespace()) } @@ -693,13 +693,13 @@ func (sv *RuntimeServiceServer) DeployResourceSpecification(req *pb.DeployResour startTime := time.Now() projectRepo := sv.projectRepoFactory.New() - projSpec, err := projectRepo.GetByName(req.GetProjectName()) + projSpec, err := projectRepo.GetByName(respStream.Context(), req.GetProjectName()) if err != nil { return status.Errorf(codes.NotFound, "%s: project %s not found", err.Error(), req.GetProjectName()) } namespaceRepo := sv.namespaceRepoFactory.New(projSpec) - namespaceSpec, err := namespaceRepo.GetByName(req.GetNamespace()) + namespaceSpec, err := namespaceRepo.GetByName(respStream.Context(), req.GetNamespace()) if err != nil { return status.Errorf(codes.NotFound, "%s: namespace %s not found", err.Error(), req.GetNamespace()) } @@ -730,18 +730,18 @@ func (sv *RuntimeServiceServer) DeployResourceSpecification(req *pb.DeployResour func (sv *RuntimeServiceServer) ListResourceSpecification(ctx context.Context, req *pb.ListResourceSpecificationRequest) (*pb.ListResourceSpecificationResponse, error) { projectRepo := sv.projectRepoFactory.New() - projSpec, err := projectRepo.GetByName(req.GetProjectName()) + projSpec, err := projectRepo.GetByName(ctx, req.GetProjectName()) if err != nil { return nil, status.Errorf(codes.NotFound, "%s: project %s not found", err.Error(), req.GetProjectName()) } namespaceRepo := sv.namespaceRepoFactory.New(projSpec) - namespaceSpec, err := namespaceRepo.GetByName(req.GetNamespace()) + namespaceSpec, err := namespaceRepo.GetByName(ctx, req.GetNamespace()) if err != nil { return nil, status.Errorf(codes.NotFound, "%s: namespace %s not found", err.Error(), req.GetNamespace()) } - resourceSpecs, err := sv.resourceSvc.GetAll(namespaceSpec, req.DatastoreName) + resourceSpecs, err := sv.resourceSvc.GetAll(ctx, namespaceSpec, req.DatastoreName) if err != nil { return nil, status.Errorf(codes.Internal, "%s: failed to retrieve jobs for project %s", err.Error(), req.GetProjectName()) } @@ -760,7 +760,7 @@ func (sv *RuntimeServiceServer) ListResourceSpecification(ctx context.Context, r } func (sv *RuntimeServiceServer) ReplayDryRun(ctx context.Context, req *pb.ReplayDryRunRequest) (*pb.ReplayDryRunResponse, error) { - replayRequest, err := sv.parseReplayRequest(req.ProjectName, req.Namespace, req.JobName, req.StartDate, req.EndDate, false) + replayRequest, err := sv.parseReplayRequest(ctx, req.ProjectName, req.Namespace, req.JobName, req.StartDate, req.EndDate, false) if err != nil { return nil, err } @@ -781,7 +781,7 @@ func (sv *RuntimeServiceServer) ReplayDryRun(ctx context.Context, req *pb.Replay } func (sv *RuntimeServiceServer) Replay(ctx context.Context, req *pb.ReplayRequest) (*pb.ReplayResponse, error) { - replayWorkerRequest, err := sv.parseReplayRequest(req.ProjectName, req.Namespace, req.JobName, req.StartDate, req.EndDate, req.Force) + replayWorkerRequest, err := sv.parseReplayRequest(ctx, req.ProjectName, req.Namespace, req.JobName, req.StartDate, req.EndDate, req.Force) if err != nil { return nil, err } @@ -802,7 +802,7 @@ func (sv *RuntimeServiceServer) Replay(ctx context.Context, req *pb.ReplayReques } func (sv *RuntimeServiceServer) GetReplayStatus(ctx context.Context, req *pb.GetReplayStatusRequest) (*pb.GetReplayStatusResponse, error) { - replayRequest, err := sv.parseReplayStatusRequest(req) + replayRequest, err := sv.parseReplayStatusRequest(ctx, req) if err != nil { return nil, err } @@ -823,8 +823,8 @@ func (sv *RuntimeServiceServer) GetReplayStatus(ctx context.Context, req *pb.Get }, nil } -func (sv *RuntimeServiceServer) parseReplayStatusRequest(req *pb.GetReplayStatusRequest) (models.ReplayRequest, error) { - projSpec, err := sv.getProjectSpec(req.ProjectName) +func (sv *RuntimeServiceServer) parseReplayStatusRequest(ctx context.Context, req *pb.GetReplayStatusRequest) (models.ReplayRequest, error) { + projSpec, err := sv.getProjectSpec(ctx, req.ProjectName) if err != nil { return models.ReplayRequest{}, err } @@ -842,12 +842,12 @@ func (sv *RuntimeServiceServer) parseReplayStatusRequest(req *pb.GetReplayStatus } func (sv *RuntimeServiceServer) ListReplays(ctx context.Context, req *pb.ListReplaysRequest) (*pb.ListReplaysResponse, error) { - projSpec, err := sv.getProjectSpec(req.ProjectName) + projSpec, err := sv.getProjectSpec(ctx, req.ProjectName) if err != nil { return nil, err } - replays, err := sv.jobSvc.GetReplayList(projSpec.ID) + replays, err := sv.jobSvc.GetReplayList(ctx, projSpec.ID) if err != nil { return nil, status.Errorf(codes.Internal, "error while getting replay list: %v", err) } @@ -869,14 +869,14 @@ func (sv *RuntimeServiceServer) ListReplays(ctx context.Context, req *pb.ListRep }, nil } -func (sv *RuntimeServiceServer) parseReplayRequest(projectName string, namespace string, jobName string, startDate string, +func (sv *RuntimeServiceServer) parseReplayRequest(ctx context.Context, projectName string, namespace string, jobName string, startDate string, endDate string, forceFlag bool) (models.ReplayRequest, error) { - projSpec, err := sv.getProjectSpec(projectName) + projSpec, err := sv.getProjectSpec(ctx, projectName) if err != nil { return models.ReplayRequest{}, err } - jobSpec, err := sv.getJobSpec(projSpec, namespace, jobName) + jobSpec, err := sv.getJobSpec(ctx, projSpec, namespace, jobName) if err != nil { return models.ReplayRequest{}, err } @@ -905,23 +905,23 @@ func (sv *RuntimeServiceServer) parseReplayRequest(projectName string, namespace return replayRequest, nil } -func (sv *RuntimeServiceServer) getProjectSpec(projectName string) (models.ProjectSpec, error) { +func (sv *RuntimeServiceServer) getProjectSpec(ctx context.Context, projectName string) (models.ProjectSpec, error) { projectRepo := sv.projectRepoFactory.New() - projSpec, err := projectRepo.GetByName(projectName) + projSpec, err := projectRepo.GetByName(ctx, projectName) if err != nil { return models.ProjectSpec{}, status.Errorf(codes.NotFound, "%s: project %s not found", err.Error(), projectName) } return projSpec, nil } -func (sv *RuntimeServiceServer) getJobSpec(projSpec models.ProjectSpec, namespace string, jobName string) (models.JobSpec, error) { +func (sv *RuntimeServiceServer) getJobSpec(ctx context.Context, projSpec models.ProjectSpec, namespace string, jobName string) (models.JobSpec, error) { namespaceRepo := sv.namespaceRepoFactory.New(projSpec) - namespaceSpec, err := namespaceRepo.GetByName(namespace) + namespaceSpec, err := namespaceRepo.GetByName(ctx, namespace) if err != nil { return models.JobSpec{}, status.Errorf(codes.NotFound, "%s: namespace %s not found", err.Error(), namespace) } - jobSpec, err := sv.jobSvc.GetByName(jobName, namespaceSpec) + jobSpec, err := sv.jobSvc.GetByName(ctx, jobName, namespaceSpec) if err != nil { return models.JobSpec{}, status.Errorf(codes.NotFound, "%s: failed to find the job %s for namespace %s", err.Error(), jobName, namespace) @@ -930,13 +930,13 @@ func (sv *RuntimeServiceServer) getJobSpec(projSpec models.ProjectSpec, namespac } func (sv *RuntimeServiceServer) BackupDryRun(ctx context.Context, req *pb.BackupDryRunRequest) (*pb.BackupDryRunResponse, error) { - projectSpec, err := sv.getProjectSpec(req.ProjectName) + projectSpec, err := sv.getProjectSpec(ctx, req.ProjectName) if err != nil { return nil, err } namespaceRepo := sv.namespaceRepoFactory.New(projectSpec) - namespaceSpec, err := namespaceRepo.GetByName(req.Namespace) + namespaceSpec, err := namespaceRepo.GetByName(ctx, req.Namespace) if err != nil { return nil, status.Errorf(codes.NotFound, "%s: namespace %s not found", err.Error(), req.Namespace) } @@ -947,7 +947,7 @@ func (sv *RuntimeServiceServer) BackupDryRun(ctx context.Context, req *pb.Backup } var jobSpecs []models.JobSpec - jobSpec, err := sv.jobSvc.GetByDestination(projectSpec, resourceSpec.URN) + jobSpec, err := sv.jobSvc.GetByDestination(ctx, projectSpec, resourceSpec.URN) if err != nil { return nil, status.Errorf(codes.Internal, "error while getting job: %v", err) } @@ -981,13 +981,13 @@ func (sv *RuntimeServiceServer) BackupDryRun(ctx context.Context, req *pb.Backup } func (sv *RuntimeServiceServer) Backup(ctx context.Context, req *pb.BackupRequest) (*pb.BackupResponse, error) { - projectSpec, err := sv.getProjectSpec(req.ProjectName) + projectSpec, err := sv.getProjectSpec(ctx, req.ProjectName) if err != nil { return nil, err } namespaceRepo := sv.namespaceRepoFactory.New(projectSpec) - namespaceSpec, err := namespaceRepo.GetByName(req.Namespace) + namespaceSpec, err := namespaceRepo.GetByName(ctx, req.Namespace) if err != nil { return nil, status.Errorf(codes.NotFound, "%s: namespace %s not found", err.Error(), req.Namespace) } @@ -998,7 +998,7 @@ func (sv *RuntimeServiceServer) Backup(ctx context.Context, req *pb.BackupReques } var jobSpecs []models.JobSpec - jobSpec, err := sv.jobSvc.GetByDestination(projectSpec, resourceSpec.URN) + jobSpec, err := sv.jobSvc.GetByDestination(ctx, projectSpec, resourceSpec.URN) if err != nil { return nil, status.Errorf(codes.Internal, "error while getting job: %v", err) } @@ -1032,12 +1032,12 @@ func (sv *RuntimeServiceServer) Backup(ctx context.Context, req *pb.BackupReques } func (sv *RuntimeServiceServer) ListBackups(ctx context.Context, req *pb.ListBackupsRequest) (*pb.ListBackupsResponse, error) { - projectSpec, err := sv.getProjectSpec(req.ProjectName) + projectSpec, err := sv.getProjectSpec(ctx, req.ProjectName) if err != nil { return nil, err } - results, err := sv.resourceSvc.ListBackupResources(projectSpec, req.DatastoreName) + results, err := sv.resourceSvc.ListBackupResources(ctx, projectSpec, req.DatastoreName) if err != nil { return nil, status.Errorf(codes.Internal, "error while getting backup list: %v", err) } @@ -1058,12 +1058,12 @@ func (sv *RuntimeServiceServer) ListBackups(ctx context.Context, req *pb.ListBac func (sv *RuntimeServiceServer) RunJob(ctx context.Context, req *pb.RunJobRequest) (*pb.RunJobResponse, error) { // create job run in db - projSpec, err := sv.projectRepoFactory.New().GetByName(req.ProjectName) + projSpec, err := sv.projectRepoFactory.New().GetByName(ctx, req.ProjectName) if err != nil { return nil, status.Errorf(codes.NotFound, "%s: project %s not found", err.Error(), req.ProjectName) } - namespaceSpec, err := sv.namespaceRepoFactory.New(projSpec).GetByName(req.Namespace) + namespaceSpec, err := sv.namespaceRepoFactory.New(projSpec).GetByName(ctx, req.Namespace) if err != nil { return nil, status.Errorf(codes.NotFound, "%s: namespace %s not found", err.Error(), req.Namespace) } diff --git a/api/handler/v1/runtime_test.go b/api/handler/v1/runtime_test.go index 587e8f013c..827c481a0e 100644 --- a/api/handler/v1/runtime_test.go +++ b/api/handler/v1/runtime_test.go @@ -21,7 +21,6 @@ import ( "github.com/odpf/optimus/core/tree" - "github.com/golang/protobuf/ptypes" "github.com/google/uuid" v1 "github.com/odpf/optimus/api/handler/v1" pb "github.com/odpf/optimus/api/proto/odpf/optimus" @@ -150,7 +149,7 @@ func TestRuntimeServiceServer(t *testing.T) { } t.Run("should register a new job instance with run for scheduled triggers", func(t *testing.T) { projectRepository := new(mock.ProjectRepository) - projectRepository.On("GetByName", projectName).Return(projectSpec, nil) + projectRepository.On("GetByName", ctx, projectName).Return(projectSpec, nil) defer projectRepository.AssertExpectations(t) projectRepoFactory := new(mock.ProjectRepoFactory) @@ -158,13 +157,13 @@ func TestRuntimeServiceServer(t *testing.T) { defer projectRepoFactory.AssertExpectations(t) jobService := new(mock.JobService) - jobService.On("GetByNameForProject", jobName, projectSpec).Return(jobSpec, namespaceSpec, nil) + jobService.On("GetByNameForProject", ctx, jobName, projectSpec).Return(jobSpec, namespaceSpec, nil) defer jobService.AssertExpectations(t) instanceService := new(mock.RunService) - instanceService.On("GetScheduledRun", namespaceSpec, jobSpec, scheduledAt).Return(jobRun, nil) - instanceService.On("Register", namespaceSpec, jobRun, instanceSpec.Type, instanceSpec.Name).Return(instanceSpec, nil) - instanceService.On("Compile", namespaceSpec, jobRun, instanceSpec).Return( + instanceService.On("GetScheduledRun", ctx, namespaceSpec, jobSpec, scheduledAt).Return(jobRun, nil) + instanceService.On("Register", ctx, namespaceSpec, jobRun, instanceSpec.Type, instanceSpec.Name).Return(instanceSpec, nil) + instanceService.On("Compile", ctx, namespaceSpec, jobRun, instanceSpec).Return( map[string]string{ run.ConfigKeyExecutionTime: mockedTimeNow.Format(models.InstanceScheduledAtTimeLayout), run.ConfigKeyDstart: jobSpec.Task.Window.GetStart(scheduledAt).Format(models.InstanceScheduledAtTimeLayout), @@ -221,7 +220,7 @@ func TestRuntimeServiceServer(t *testing.T) { }) t.Run("should find the existing job run if manually triggered", func(t *testing.T) { projectRepository := new(mock.ProjectRepository) - projectRepository.On("GetByName", projectName).Return(projectSpec, nil) + projectRepository.On("GetByName", ctx, projectName).Return(projectSpec, nil) defer projectRepository.AssertExpectations(t) projectRepoFactory := new(mock.ProjectRepoFactory) @@ -229,9 +228,9 @@ func TestRuntimeServiceServer(t *testing.T) { defer projectRepoFactory.AssertExpectations(t) instanceService := new(mock.RunService) - instanceService.On("GetByID", jobRun.ID).Return(jobRun, namespaceSpec, nil) - instanceService.On("Register", namespaceSpec, jobRun, instanceSpec.Type, instanceSpec.Name).Return(instanceSpec, nil) - instanceService.On("Compile", namespaceSpec, jobRun, instanceSpec).Return( + instanceService.On("GetByID", ctx, jobRun.ID).Return(jobRun, namespaceSpec, nil) + instanceService.On("Register", ctx, namespaceSpec, jobRun, instanceSpec.Type, instanceSpec.Name).Return(instanceSpec, nil) + instanceService.On("Compile", ctx, namespaceSpec, jobRun, instanceSpec).Return( map[string]string{ run.ConfigKeyExecutionTime: mockedTimeNow.Format(models.InstanceScheduledAtTimeLayout), run.ConfigKeyDstart: jobSpec.Task.Window.GetStart(scheduledAt).Format(models.InstanceScheduledAtTimeLayout), @@ -303,7 +302,7 @@ func TestRuntimeServiceServer(t *testing.T) { adapter := v1.NewAdapter(nil, nil) projectRepository := new(mock.ProjectRepository) - projectRepository.On("Save", projectSpec).Return(errors.New("a random error")) + projectRepository.On("Save", ctx, projectSpec).Return(errors.New("a random error")) defer projectRepository.AssertExpectations(t) projectRepoFactory := new(mock.ProjectRepoFactory) @@ -344,7 +343,7 @@ func TestRuntimeServiceServer(t *testing.T) { adapter := v1.NewAdapter(nil, nil) projectRepository := new(mock.ProjectRepository) - projectRepository.On("Save", projectSpec).Return(nil) + projectRepository.On("Save", ctx, projectSpec).Return(nil) defer projectRepository.AssertExpectations(t) projectRepoFactory := new(mock.ProjectRepoFactory) @@ -394,8 +393,8 @@ func TestRuntimeServiceServer(t *testing.T) { adapter := v1.NewAdapter(nil, nil) projectRepository := new(mock.ProjectRepository) - projectRepository.On("Save", projectSpec).Return(nil) - projectRepository.On("GetByName", projectSpec.Name).Return(projectSpec, nil) + projectRepository.On("Save", ctx, projectSpec).Return(nil) + projectRepository.On("GetByName", ctx, projectSpec.Name).Return(projectSpec, nil) defer projectRepository.AssertExpectations(t) projectRepoFactory := new(mock.ProjectRepoFactory) @@ -406,7 +405,7 @@ func TestRuntimeServiceServer(t *testing.T) { defer jobSvc.AssertExpectations(t) namespaceRepository := new(mock.NamespaceRepository) - namespaceRepository.On("Save", namespaceSpec).Return(nil) + namespaceRepository.On("Save", ctx, namespaceSpec).Return(nil) defer namespaceRepository.AssertExpectations(t) namespaceRepoFact := new(mock.NamespaceRepoFactory) @@ -459,7 +458,7 @@ func TestRuntimeServiceServer(t *testing.T) { adapter := v1.NewAdapter(nil, nil) projectRepository := new(mock.ProjectRepository) - projectRepository.On("GetByName", projectSpec.Name).Return(projectSpec, nil) + projectRepository.On("GetByName", ctx, projectSpec.Name).Return(projectSpec, nil) defer projectRepository.AssertExpectations(t) projectRepoFactory := new(mock.ProjectRepoFactory) @@ -470,7 +469,7 @@ func TestRuntimeServiceServer(t *testing.T) { defer jobSvc.AssertExpectations(t) namespaceRepository := new(mock.NamespaceRepository) - namespaceRepository.On("Save", namespaceSpec).Return(nil) + namespaceRepository.On("Save", ctx, namespaceSpec).Return(nil) defer namespaceRepository.AssertExpectations(t) namespaceRepoFact := new(mock.NamespaceRepoFactory) @@ -519,7 +518,7 @@ func TestRuntimeServiceServer(t *testing.T) { adapter := v1.NewAdapter(nil, nil) projectRepository := new(mock.ProjectRepository) - projectRepository.On("GetByName", projectSpec.Name).Return(projectSpec, errors.New("project does not exist")) + projectRepository.On("GetByName", ctx, projectSpec.Name).Return(projectSpec, errors.New("project does not exist")) defer projectRepository.AssertExpectations(t) projectRepoFactory := new(mock.ProjectRepoFactory) @@ -619,7 +618,7 @@ func TestRuntimeServiceServer(t *testing.T) { } projectRepository := new(mock.ProjectRepository) - projectRepository.On("GetByName", projectSpec.Name).Return(projectSpec, nil) + projectRepository.On("GetByName", ctx, projectSpec.Name).Return(projectSpec, nil) defer projectRepository.AssertExpectations(t) projectRepoFactory := new(mock.ProjectRepoFactory) @@ -627,13 +626,13 @@ func TestRuntimeServiceServer(t *testing.T) { defer projectRepoFactory.AssertExpectations(t) jobSvc := new(mock.JobService) - jobSvc.On("Create", jobSpec, namespaceSpec).Return(nil) + jobSvc.On("Create", ctx, jobSpec, namespaceSpec).Return(nil) jobSvc.On("Check", ctx, namespaceSpec, []models.JobSpec{jobSpec}, mock2.Anything).Return(nil) jobSvc.On("Sync", mock2.Anything, namespaceSpec, mock2.Anything).Return(nil) defer jobSvc.AssertExpectations(t) namespaceRepository := new(mock.NamespaceRepository) - namespaceRepository.On("GetByName", namespaceSpec.Name).Return(namespaceSpec, nil) + namespaceRepository.On("GetByName", ctx, namespaceSpec.Name).Return(namespaceSpec, nil) defer namespaceRepository.AssertExpectations(t) namespaceRepoFact := new(mock.NamespaceRepoFactory) @@ -682,7 +681,7 @@ func TestRuntimeServiceServer(t *testing.T) { adapter := v1.NewAdapter(nil, nil) projectRepository := new(mock.ProjectRepository) - projectRepository.On("GetByName", projectSpec.Name).Return(projectSpec, nil) + projectRepository.On("GetByName", ctx, projectSpec.Name).Return(projectSpec, nil) defer projectRepository.AssertExpectations(t) projectRepoFactory := new(mock.ProjectRepoFactory) @@ -695,7 +694,7 @@ func TestRuntimeServiceServer(t *testing.T) { } projectSecretRepository := new(mock.ProjectSecretRepository) - projectSecretRepository.On("Save", sec).Return(nil) + projectSecretRepository.On("Save", ctx, sec).Return(nil) defer projectSecretRepository.AssertExpectations(t) projectSecretRepoFactory := new(mock.ProjectSecretRepoFactory) @@ -741,7 +740,7 @@ func TestRuntimeServiceServer(t *testing.T) { adapter := v1.NewAdapter(nil, nil) projectRepository := new(mock.ProjectRepository) - projectRepository.On("GetByName", projectSpec.Name).Return(projectSpec, nil) + projectRepository.On("GetByName", ctx, projectSpec.Name).Return(projectSpec, nil) defer projectRepository.AssertExpectations(t) projectRepoFactory := new(mock.ProjectRepoFactory) @@ -754,7 +753,7 @@ func TestRuntimeServiceServer(t *testing.T) { } projectSecretRepository := new(mock.ProjectSecretRepository) - projectSecretRepository.On("Save", sec).Return(errors.New("random error")) + projectSecretRepository.On("Save", ctx, sec).Return(errors.New("random error")) defer projectSecretRepository.AssertExpectations(t) projectSecretRepoFactory := new(mock.ProjectSecretRepoFactory) @@ -844,7 +843,7 @@ func TestRuntimeServiceServer(t *testing.T) { } projectRepository := new(mock.ProjectRepository) - projectRepository.On("GetByName", projectName).Return(projectSpec, nil) + projectRepository.On("GetByName", ctx, projectName).Return(projectSpec, nil) defer projectRepository.AssertExpectations(t) projectRepoFactory := new(mock.ProjectRepoFactory) @@ -864,7 +863,7 @@ func TestRuntimeServiceServer(t *testing.T) { adapter := v1.NewAdapter(pluginRepo, nil) namespaceRepository := new(mock.NamespaceRepository) - namespaceRepository.On("GetByName", namespaceSpec.Name).Return(namespaceSpec, nil) + namespaceRepository.On("GetByName", ctx, namespaceSpec.Name).Return(namespaceSpec, nil) defer namespaceRepository.AssertExpectations(t) namespaceRepoFact := new(mock.NamespaceRepoFactory) @@ -878,8 +877,8 @@ func TestRuntimeServiceServer(t *testing.T) { defer projectJobSpecRepoFactory.AssertExpectations(t) jobService := new(mock.JobService) - jobService.On("Create", mock2.Anything, namespaceSpec).Return(nil) - jobService.On("KeepOnly", namespaceSpec, mock2.Anything, mock2.Anything).Return(nil) + jobService.On("Create", ctx, mock2.Anything, namespaceSpec).Return(nil) + jobService.On("KeepOnly", ctx, namespaceSpec, mock2.Anything, mock2.Anything).Return(nil) jobService.On("Sync", mock2.Anything, namespaceSpec, mock2.Anything).Return(nil) defer jobService.AssertExpectations(t) @@ -968,7 +967,7 @@ func TestRuntimeServiceServer(t *testing.T) { } projectRepository := new(mock.ProjectRepository) - projectRepository.On("GetByName", projectName).Return(projectSpec, nil) + projectRepository.On("GetByName", ctx, projectName).Return(projectSpec, nil) defer projectRepository.AssertExpectations(t) projectRepoFactory := new(mock.ProjectRepoFactory) @@ -980,7 +979,7 @@ func TestRuntimeServiceServer(t *testing.T) { adapter := v1.NewAdapter(allTasksRepo, nil) namespaceRepository := new(mock.NamespaceRepository) - namespaceRepository.On("GetByName", namespaceSpec.Name).Return(namespaceSpec, nil) + namespaceRepository.On("GetByName", ctx, namespaceSpec.Name).Return(namespaceSpec, nil) defer namespaceRepository.AssertExpectations(t) namespaceRepoFact := new(mock.NamespaceRepoFactory) @@ -988,7 +987,7 @@ func TestRuntimeServiceServer(t *testing.T) { defer namespaceRepoFact.AssertExpectations(t) jobService := new(mock.JobService) - jobService.On("GetByName", jobSpecs[0].Name, namespaceSpec).Return(jobSpecs[0], nil) + jobService.On("GetByName", ctx, jobSpecs[0].Name, namespaceSpec).Return(jobSpecs[0], nil) defer jobService.AssertExpectations(t) runtimeServiceServer := v1.NewRuntimeServiceServer( @@ -1037,7 +1036,7 @@ func TestRuntimeServiceServer(t *testing.T) { } projectRepository := new(mock.ProjectRepository) - projectRepository.On("GetByName", projectName).Return(projectSpec, nil) + projectRepository.On("GetByName", ctx, projectName).Return(projectSpec, nil) defer projectRepository.AssertExpectations(t) projectRepoFactory := new(mock.ProjectRepoFactory) @@ -1047,7 +1046,7 @@ func TestRuntimeServiceServer(t *testing.T) { adapter := v1.NewAdapter(nil, nil) namespaceRepository := new(mock.NamespaceRepository) - namespaceRepository.On("GetAll").Return([]models.NamespaceSpec{namespaceSpec}, nil) + namespaceRepository.On("GetAll", ctx).Return([]models.NamespaceSpec{namespaceSpec}, nil) defer namespaceRepository.AssertExpectations(t) namespaceRepoFact := new(mock.NamespaceRepoFactory) @@ -1132,7 +1131,7 @@ func TestRuntimeServiceServer(t *testing.T) { } projectRepository := new(mock.ProjectRepository) - projectRepository.On("GetByName", projectName).Return(projectSpec, nil) + projectRepository.On("GetByName", ctx, projectName).Return(projectSpec, nil) defer projectRepository.AssertExpectations(t) projectRepoFactory := new(mock.ProjectRepoFactory) @@ -1146,7 +1145,7 @@ func TestRuntimeServiceServer(t *testing.T) { adapter := v1.NewAdapter(pluginRepo, nil) namespaceRepository := new(mock.NamespaceRepository) - namespaceRepository.On("GetByName", namespaceSpec.Name).Return(namespaceSpec, nil) + namespaceRepository.On("GetByName", ctx, namespaceSpec.Name).Return(namespaceSpec, nil) defer namespaceRepository.AssertExpectations(t) namespaceRepoFact := new(mock.NamespaceRepoFactory) @@ -1156,7 +1155,7 @@ func TestRuntimeServiceServer(t *testing.T) { jobSpec := jobSpecs[0] jobService := new(mock.JobService) - jobService.On("GetByName", jobSpecs[0].Name, namespaceSpec).Return(jobSpecs[0], nil) + jobService.On("GetByName", ctx, jobSpecs[0].Name, namespaceSpec).Return(jobSpecs[0], nil) jobService.On("Delete", mock2.Anything, namespaceSpec, jobSpec).Return(nil) defer jobService.AssertExpectations(t) @@ -1201,7 +1200,7 @@ func TestRuntimeServiceServer(t *testing.T) { } projectRepository := new(mock.ProjectRepository) - projectRepository.On("GetByName", projectSpec.Name).Return(projectSpec, nil) + projectRepository.On("GetByName", ctx, projectSpec.Name).Return(projectSpec, nil) defer projectRepository.AssertExpectations(t) projectRepoFactory := new(mock.ProjectRepoFactory) @@ -1211,7 +1210,7 @@ func TestRuntimeServiceServer(t *testing.T) { adapter := v1.NewAdapter(nil, nil) jobService := new(mock.JobService) - jobService.On("GetByNameForProject", jobSpec.Name, projectSpec).Return(jobSpec, namespaceSpec, nil) + jobService.On("GetByNameForProject", ctx, jobSpec.Name, projectSpec).Return(jobSpec, namespaceSpec, nil) defer jobService.AssertExpectations(t) jobStatuses := []models.JobStatus{ @@ -1225,7 +1224,7 @@ func TestRuntimeServiceServer(t *testing.T) { }, } scheduler := new(mock.Scheduler) - scheduler.On("GetJobStatus", context.Background(), projectSpec, jobSpec.Name).Return(jobStatuses, nil) + scheduler.On("GetJobStatus", ctx, projectSpec, jobSpec.Name).Return(jobStatuses, nil) defer scheduler.AssertExpectations(t) runtimeServiceServer := v1.NewRuntimeServiceServer( @@ -1286,7 +1285,7 @@ func TestRuntimeServiceServer(t *testing.T) { } projectRepository := new(mock.ProjectRepository) - projectRepository.On("GetByName", projectSpec.Name).Return(projectSpec, nil) + projectRepository.On("GetByName", ctx, projectSpec.Name).Return(projectSpec, nil) defer projectRepository.AssertExpectations(t) projectRepoFactory := new(mock.ProjectRepoFactory) @@ -1296,7 +1295,7 @@ func TestRuntimeServiceServer(t *testing.T) { adapter := v1.NewAdapter(nil, nil) namespaceRepository := new(mock.NamespaceRepository) - namespaceRepository.On("GetByName", namespaceSpec.Name).Return(namespaceSpec, nil) + namespaceRepository.On("GetByName", ctx, namespaceSpec.Name).Return(namespaceSpec, nil) defer namespaceRepository.AssertExpectations(t) namespaceRepoFact := new(mock.NamespaceRepoFactory) @@ -1304,7 +1303,7 @@ func TestRuntimeServiceServer(t *testing.T) { defer namespaceRepoFact.AssertExpectations(t) jobService := new(mock.JobService) - jobService.On("GetByName", jobSpecs[0].Name, namespaceSpec).Return(jobSpecs[0], nil) + jobService.On("GetByName", ctx, jobSpecs[0].Name, namespaceSpec).Return(jobSpecs[0], nil) defer jobService.AssertExpectations(t) eventValues, _ := structpb.NewStruct( @@ -1313,7 +1312,7 @@ func TestRuntimeServiceServer(t *testing.T) { }, ) eventSvc := new(mock.EventService) - eventSvc.On("Register", context.Background(), namespaceSpec, jobSpecs[0], models.JobEvent{ + eventSvc.On("Register", ctx, namespaceSpec, jobSpecs[0], models.JobEvent{ Type: models.JobEventTypeFailure, Value: eventValues.GetFields(), }).Return(nil) @@ -1372,8 +1371,8 @@ func TestRuntimeServiceServer(t *testing.T) { resp, err := runtimeServiceServer.GetWindow(context.Background(), &req) assert.Nil(t, err) - assert.Equal(t, "2020-11-11T00:00:00Z", ptypes.TimestampString(resp.GetStart())) - assert.Equal(t, "2020-11-12T00:00:00Z", ptypes.TimestampString(resp.GetEnd())) + assert.Equal(t, "2020-11-11T00:00:00Z", resp.GetStart().AsTime().Format(time.RFC3339)) + assert.Equal(t, "2020-11-12T00:00:00Z", resp.GetEnd().AsTime().Format(time.RFC3339)) }) t.Run("should return error if any of the required fields in request is missing", func(t *testing.T) { Version := "1.0.1" @@ -1463,7 +1462,7 @@ func TestRuntimeServiceServer(t *testing.T) { } projectRepository := new(mock.ProjectRepository) - projectRepository.On("GetByName", projectName).Return(projectSpec, nil) + projectRepository.On("GetByName", ctx, projectName).Return(projectSpec, nil) defer projectRepository.AssertExpectations(t) projectRepoFactory := new(mock.ProjectRepoFactory) @@ -1475,7 +1474,7 @@ func TestRuntimeServiceServer(t *testing.T) { defer resourceSvc.AssertExpectations(t) namespaceRepository := new(mock.NamespaceRepository) - namespaceRepository.On("GetByName", namespaceSpec.Name).Return(namespaceSpec, nil) + namespaceRepository.On("GetByName", ctx, namespaceSpec.Name).Return(namespaceSpec, nil) defer namespaceRepository.AssertExpectations(t) namespaceRepoFact := new(mock.NamespaceRepoFactory) @@ -1562,7 +1561,7 @@ func TestRuntimeServiceServer(t *testing.T) { } projectRepository := new(mock.ProjectRepository) - projectRepository.On("GetByName", projectName).Return(projectSpec, nil) + projectRepository.On("GetByName", ctx, projectName).Return(projectSpec, nil) defer projectRepository.AssertExpectations(t) projectRepoFactory := new(mock.ProjectRepoFactory) @@ -1574,7 +1573,7 @@ func TestRuntimeServiceServer(t *testing.T) { defer resourceSvc.AssertExpectations(t) namespaceRepository := new(mock.NamespaceRepository) - namespaceRepository.On("GetByName", namespaceSpec.Name).Return(namespaceSpec, nil) + namespaceRepository.On("GetByName", ctx, namespaceSpec.Name).Return(namespaceSpec, nil) defer namespaceRepository.AssertExpectations(t) namespaceRepoFact := new(mock.NamespaceRepoFactory) @@ -1655,12 +1654,12 @@ func TestRuntimeServiceServer(t *testing.T) { dagNode.Runs.Add(time.Date(2020, 11, 28, 2, 0, 0, 0, time.UTC)) jobService := new(mock.JobService) - jobService.On("GetByName", jobName, namespaceSpec).Return(jobSpec, nil) - jobService.On("ReplayDryRun", replayWorkerRequest).Return(dagNode, nil) + jobService.On("GetByName", ctx, jobName, namespaceSpec).Return(jobSpec, nil) + jobService.On("ReplayDryRun", ctx, replayWorkerRequest).Return(dagNode, nil) defer jobService.AssertExpectations(t) projectRepository := new(mock.ProjectRepository) - projectRepository.On("GetByName", projectName).Return(projectSpec, nil) + projectRepository.On("GetByName", ctx, projectName).Return(projectSpec, nil) defer projectRepository.AssertExpectations(t) projectRepoFactory := new(mock.ProjectRepoFactory) @@ -1668,7 +1667,7 @@ func TestRuntimeServiceServer(t *testing.T) { defer projectRepoFactory.AssertExpectations(t) namespaceRepository := new(mock.NamespaceRepository) - namespaceRepository.On("GetByName", namespaceSpec.Name).Return(namespaceSpec, nil) + namespaceRepository.On("GetByName", ctx, namespaceSpec.Name).Return(namespaceSpec, nil) defer namespaceRepository.AssertExpectations(t) namespaceRepoFact := new(mock.NamespaceRepoFactory) @@ -1709,11 +1708,11 @@ func TestRuntimeServiceServer(t *testing.T) { endDate := time.Date(2020, 11, 24, 0, 0, 0, 0, time.UTC) jobService := new(mock.JobService) - jobService.On("GetByName", jobName, namespaceSpec).Return(jobSpec, nil) + jobService.On("GetByName", ctx, jobName, namespaceSpec).Return(jobSpec, nil) defer jobService.AssertExpectations(t) projectRepository := new(mock.ProjectRepository) - projectRepository.On("GetByName", projectName).Return(projectSpec, nil) + projectRepository.On("GetByName", ctx, projectName).Return(projectSpec, nil) defer projectRepository.AssertExpectations(t) projectRepoFactory := new(mock.ProjectRepoFactory) @@ -1721,7 +1720,7 @@ func TestRuntimeServiceServer(t *testing.T) { defer projectRepoFactory.AssertExpectations(t) namespaceRepository := new(mock.NamespaceRepository) - namespaceRepository.On("GetByName", namespaceSpec.Name).Return(namespaceSpec, nil) + namespaceRepository.On("GetByName", ctx, namespaceSpec.Name).Return(namespaceSpec, nil) defer namespaceRepository.AssertExpectations(t) namespaceRepoFact := new(mock.NamespaceRepoFactory) @@ -1765,12 +1764,12 @@ func TestRuntimeServiceServer(t *testing.T) { dagNode := tree.NewTreeNode(jobSpec) jobService := new(mock.JobService) - jobService.On("GetByName", jobName, namespaceSpec).Return(jobSpec, nil) - jobService.On("ReplayDryRun", replayWorkerRequest).Return(dagNode, errors.New("populating jobs spec failed")) + jobService.On("GetByName", ctx, jobName, namespaceSpec).Return(jobSpec, nil) + jobService.On("ReplayDryRun", ctx, replayWorkerRequest).Return(dagNode, errors.New("populating jobs spec failed")) defer jobService.AssertExpectations(t) projectRepository := new(mock.ProjectRepository) - projectRepository.On("GetByName", projectName).Return(projectSpec, nil) + projectRepository.On("GetByName", ctx, projectName).Return(projectSpec, nil) defer projectRepository.AssertExpectations(t) projectRepoFactory := new(mock.ProjectRepoFactory) @@ -1778,7 +1777,7 @@ func TestRuntimeServiceServer(t *testing.T) { defer projectRepoFactory.AssertExpectations(t) namespaceRepository := new(mock.NamespaceRepository) - namespaceRepository.On("GetByName", namespaceSpec.Name).Return(namespaceSpec, nil) + namespaceRepository.On("GetByName", ctx, namespaceSpec.Name).Return(namespaceSpec, nil) defer namespaceRepository.AssertExpectations(t) namespaceRepoFact := new(mock.NamespaceRepoFactory) @@ -1862,7 +1861,7 @@ func TestRuntimeServiceServer(t *testing.T) { randomUUID := "random-uuid" projectRepository := new(mock.ProjectRepository) - projectRepository.On("GetByName", projectName).Return(projectSpec, nil) + projectRepository.On("GetByName", ctx, projectName).Return(projectSpec, nil) defer projectRepository.AssertExpectations(t) projectRepoFactory := new(mock.ProjectRepoFactory) @@ -1870,12 +1869,12 @@ func TestRuntimeServiceServer(t *testing.T) { defer projectRepoFactory.AssertExpectations(t) jobService := new(mock.JobService) - jobService.On("GetByName", jobName, namespaceSpec).Return(jobSpec, nil) - jobService.On("Replay", context.TODO(), replayWorkerRequest).Return(randomUUID, nil) + jobService.On("GetByName", ctx, jobName, namespaceSpec).Return(jobSpec, nil) + jobService.On("Replay", ctx, replayWorkerRequest).Return(randomUUID, nil) defer jobService.AssertExpectations(t) namespaceRepository := new(mock.NamespaceRepository) - namespaceRepository.On("GetByName", namespaceSpec.Name).Return(namespaceSpec, nil) + namespaceRepository.On("GetByName", ctx, namespaceSpec.Name).Return(namespaceSpec, nil) defer namespaceRepository.AssertExpectations(t) namespaceRepoFact := new(mock.NamespaceRepoFactory) @@ -1909,7 +1908,7 @@ func TestRuntimeServiceServer(t *testing.T) { }) t.Run("should failed when replay request is invalid", func(t *testing.T) { projectRepository := new(mock.ProjectRepository) - projectRepository.On("GetByName", projectName).Return(projectSpec, nil) + projectRepository.On("GetByName", ctx, projectName).Return(projectSpec, nil) defer projectRepository.AssertExpectations(t) projectRepoFactory := new(mock.ProjectRepoFactory) @@ -1917,7 +1916,7 @@ func TestRuntimeServiceServer(t *testing.T) { defer projectRepoFactory.AssertExpectations(t) namespaceRepository := new(mock.NamespaceRepository) - namespaceRepository.On("GetByName", namespaceSpec.Name).Return(models.NamespaceSpec{}, errors.New("Namespace not found")) + namespaceRepository.On("GetByName", ctx, namespaceSpec.Name).Return(models.NamespaceSpec{}, errors.New("Namespace not found")) defer namespaceRepository.AssertExpectations(t) namespaceRepoFact := new(mock.NamespaceRepoFactory) @@ -1960,7 +1959,7 @@ func TestRuntimeServiceServer(t *testing.T) { errMessage := "internal error" projectRepository := new(mock.ProjectRepository) - projectRepository.On("GetByName", projectName).Return(projectSpec, nil) + projectRepository.On("GetByName", ctx, projectName).Return(projectSpec, nil) defer projectRepository.AssertExpectations(t) projectRepoFactory := new(mock.ProjectRepoFactory) @@ -1968,12 +1967,12 @@ func TestRuntimeServiceServer(t *testing.T) { defer projectRepoFactory.AssertExpectations(t) jobService := new(mock.JobService) - jobService.On("GetByName", jobName, namespaceSpec).Return(jobSpec, nil) - jobService.On("Replay", context.TODO(), replayWorkerRequest).Return(emptyUUID, errors.New(errMessage)) + jobService.On("GetByName", ctx, jobName, namespaceSpec).Return(jobSpec, nil) + jobService.On("Replay", ctx, replayWorkerRequest).Return(emptyUUID, errors.New(errMessage)) defer jobService.AssertExpectations(t) namespaceRepository := new(mock.NamespaceRepository) - namespaceRepository.On("GetByName", namespaceSpec.Name).Return(namespaceSpec, nil) + namespaceRepository.On("GetByName", ctx, namespaceSpec.Name).Return(namespaceSpec, nil) defer namespaceRepository.AssertExpectations(t) namespaceRepoFact := new(mock.NamespaceRepoFactory) @@ -2010,7 +2009,7 @@ func TestRuntimeServiceServer(t *testing.T) { t.Run("should failed when project is not found", func(t *testing.T) { errMessage := "project not found" projectRepository := new(mock.ProjectRepository) - projectRepository.On("GetByName", projectName).Return(models.ProjectSpec{}, errors.New(errMessage)) + projectRepository.On("GetByName", ctx, projectName).Return(models.ProjectSpec{}, errors.New(errMessage)) defer projectRepository.AssertExpectations(t) projectRepoFactory := new(mock.ProjectRepoFactory) @@ -2048,7 +2047,7 @@ func TestRuntimeServiceServer(t *testing.T) { errMessage := "job not found in namespace" projectRepository := new(mock.ProjectRepository) - projectRepository.On("GetByName", projectName).Return(projectSpec, nil) + projectRepository.On("GetByName", ctx, projectName).Return(projectSpec, nil) defer projectRepository.AssertExpectations(t) projectRepoFactory := new(mock.ProjectRepoFactory) @@ -2056,11 +2055,11 @@ func TestRuntimeServiceServer(t *testing.T) { defer projectRepoFactory.AssertExpectations(t) namespaceRepository := new(mock.NamespaceRepository) - namespaceRepository.On("GetByName", namespaceSpec.Name).Return(namespaceSpec, nil) + namespaceRepository.On("GetByName", ctx, namespaceSpec.Name).Return(namespaceSpec, nil) defer namespaceRepository.AssertExpectations(t) jobService := new(mock.JobService) - jobService.On("GetByName", jobName, namespaceSpec).Return(models.JobSpec{}, errors.New(errMessage)) + jobService.On("GetByName", ctx, jobName, namespaceSpec).Return(models.JobSpec{}, errors.New(errMessage)) defer jobService.AssertExpectations(t) namespaceRepoFact := new(mock.NamespaceRepoFactory) @@ -2103,7 +2102,7 @@ func TestRuntimeServiceServer(t *testing.T) { emptyUUID := "" projectRepository := new(mock.ProjectRepository) - projectRepository.On("GetByName", projectName).Return(projectSpec, nil) + projectRepository.On("GetByName", ctx, projectName).Return(projectSpec, nil) defer projectRepository.AssertExpectations(t) projectRepoFactory := new(mock.ProjectRepoFactory) @@ -2111,12 +2110,12 @@ func TestRuntimeServiceServer(t *testing.T) { defer projectRepoFactory.AssertExpectations(t) jobService := new(mock.JobService) - jobService.On("GetByName", jobName, namespaceSpec).Return(jobSpec, nil) - jobService.On("Replay", context.TODO(), replayWorkerRequest).Return(emptyUUID, job.ErrConflictedJobRun) + jobService.On("GetByName", ctx, jobName, namespaceSpec).Return(jobSpec, nil) + jobService.On("Replay", ctx, replayWorkerRequest).Return(emptyUUID, job.ErrConflictedJobRun) defer jobService.AssertExpectations(t) namespaceRepository := new(mock.NamespaceRepository) - namespaceRepository.On("GetByName", namespaceSpec.Name).Return(namespaceSpec, nil) + namespaceRepository.On("GetByName", ctx, namespaceSpec.Name).Return(namespaceSpec, nil) defer namespaceRepository.AssertExpectations(t) namespaceRepoFact := new(mock.NamespaceRepoFactory) @@ -2160,7 +2159,7 @@ func TestRuntimeServiceServer(t *testing.T) { emptyUUID := "" projectRepository := new(mock.ProjectRepository) - projectRepository.On("GetByName", projectName).Return(projectSpec, nil) + projectRepository.On("GetByName", ctx, projectName).Return(projectSpec, nil) defer projectRepository.AssertExpectations(t) projectRepoFactory := new(mock.ProjectRepoFactory) @@ -2168,12 +2167,12 @@ func TestRuntimeServiceServer(t *testing.T) { defer projectRepoFactory.AssertExpectations(t) jobService := new(mock.JobService) - jobService.On("GetByName", jobName, namespaceSpec).Return(jobSpec, nil) - jobService.On("Replay", context.TODO(), replayWorkerRequest).Return(emptyUUID, job.ErrRequestQueueFull) + jobService.On("GetByName", ctx, jobName, namespaceSpec).Return(jobSpec, nil) + jobService.On("Replay", ctx, replayWorkerRequest).Return(emptyUUID, job.ErrRequestQueueFull) defer jobService.AssertExpectations(t) namespaceRepository := new(mock.NamespaceRepository) - namespaceRepository.On("GetByName", namespaceSpec.Name).Return(namespaceSpec, nil) + namespaceRepository.On("GetByName", ctx, namespaceSpec.Name).Return(namespaceSpec, nil) defer namespaceRepository.AssertExpectations(t) namespaceRepoFact := new(mock.NamespaceRepoFactory) @@ -2264,7 +2263,7 @@ func TestRuntimeServiceServer(t *testing.T) { } projectRepository := new(mock.ProjectRepository) - projectRepository.On("GetByName", projectName).Return(projectSpec, nil) + projectRepository.On("GetByName", ctx, projectName).Return(projectSpec, nil) defer projectRepository.AssertExpectations(t) projectRepoFactory := new(mock.ProjectRepoFactory) @@ -2306,7 +2305,7 @@ func TestRuntimeServiceServer(t *testing.T) { }) t.Run("should failed when unable to get status of a replay", func(t *testing.T) { projectRepository := new(mock.ProjectRepository) - projectRepository.On("GetByName", projectName).Return(projectSpec, nil) + projectRepository.On("GetByName", ctx, projectName).Return(projectSpec, nil) defer projectRepository.AssertExpectations(t) projectRepoFactory := new(mock.ProjectRepoFactory) @@ -2415,7 +2414,7 @@ func TestRuntimeServiceServer(t *testing.T) { } projectRepository := new(mock.ProjectRepository) - projectRepository.On("GetByName", projectName).Return(projectSpec, nil) + projectRepository.On("GetByName", ctx, projectName).Return(projectSpec, nil) defer projectRepository.AssertExpectations(t) projectRepoFactory := new(mock.ProjectRepoFactory) @@ -2424,7 +2423,7 @@ func TestRuntimeServiceServer(t *testing.T) { jobService := new(mock.JobService) defer jobService.AssertExpectations(t) - jobService.On("GetReplayList", projectSpec.ID).Return(replaySpecs, nil) + jobService.On("GetReplayList", ctx, projectSpec.ID).Return(replaySpecs, nil) adapter := v1.NewAdapter(nil, nil) @@ -2452,7 +2451,7 @@ func TestRuntimeServiceServer(t *testing.T) { }) t.Run("should failed when unable to get status of a replay", func(t *testing.T) { projectRepository := new(mock.ProjectRepository) - projectRepository.On("GetByName", projectName).Return(projectSpec, nil) + projectRepository.On("GetByName", ctx, projectName).Return(projectSpec, nil) defer projectRepository.AssertExpectations(t) projectRepoFactory := new(mock.ProjectRepoFactory) @@ -2462,7 +2461,7 @@ func TestRuntimeServiceServer(t *testing.T) { errMessage := "internal error" jobService := new(mock.JobService) defer jobService.AssertExpectations(t) - jobService.On("GetReplayList", projectSpec.ID).Return([]models.ReplaySpec{}, errors.New(errMessage)) + jobService.On("GetReplayList", ctx, projectSpec.ID).Return([]models.ReplaySpec{}, errors.New(errMessage)) adapter := v1.NewAdapter(nil, nil) @@ -2483,7 +2482,7 @@ func TestRuntimeServiceServer(t *testing.T) { replayRequestPb := pb.ListReplaysRequest{ ProjectName: projectName, } - replayListResponse, err := runtimeServiceServer.ListReplays(context.Background(), &replayRequestPb) + replayListResponse, err := runtimeServiceServer.ListReplays(ctx, &replayRequestPb) assert.Contains(t, err.Error(), errMessage) assert.Nil(t, replayListResponse) @@ -2531,7 +2530,7 @@ func TestRuntimeServiceServer(t *testing.T) { projectRepository := new(mock.ProjectRepository) defer projectRepository.AssertExpectations(t) - projectRepository.On("GetByName", projectName).Return(projectSpec, nil) + projectRepository.On("GetByName", ctx, projectName).Return(projectSpec, nil) projectRepoFactory := new(mock.ProjectRepoFactory) defer projectRepoFactory.AssertExpectations(t) @@ -2539,7 +2538,7 @@ func TestRuntimeServiceServer(t *testing.T) { namespaceRepository := new(mock.NamespaceRepository) defer namespaceRepository.AssertExpectations(t) - namespaceRepository.On("GetByName", namespaceSpec.Name).Return(namespaceSpec, nil) + namespaceRepository.On("GetByName", ctx, namespaceSpec.Name).Return(namespaceSpec, nil) namespaceRepoFact := new(mock.NamespaceRepoFactory) defer namespaceRepoFact.AssertExpectations(t) @@ -2627,7 +2626,7 @@ func TestRuntimeServiceServer(t *testing.T) { projectRepository := new(mock.ProjectRepository) defer projectRepository.AssertExpectations(t) - projectRepository.On("GetByName", projectName).Return(projectSpec, nil) + projectRepository.On("GetByName", ctx, projectName).Return(projectSpec, nil) projectRepoFactory := new(mock.ProjectRepoFactory) defer projectRepoFactory.AssertExpectations(t) @@ -2635,7 +2634,7 @@ func TestRuntimeServiceServer(t *testing.T) { namespaceRepository := new(mock.NamespaceRepository) defer namespaceRepository.AssertExpectations(t) - namespaceRepository.On("GetByName", namespaceSpec.Name).Return(namespaceSpec, nil) + namespaceRepository.On("GetByName", ctx, namespaceSpec.Name).Return(namespaceSpec, nil) namespaceRepoFact := new(mock.NamespaceRepoFactory) defer namespaceRepoFact.AssertExpectations(t) @@ -2696,7 +2695,7 @@ func TestRuntimeServiceServer(t *testing.T) { projectRepository := new(mock.ProjectRepository) defer projectRepository.AssertExpectations(t) errorMsg := "unable to fetch project" - projectRepository.On("GetByName", projectName).Return(models.ProjectSpec{}, errors.New(errorMsg)) + projectRepository.On("GetByName", ctx, projectName).Return(models.ProjectSpec{}, errors.New(errorMsg)) projectRepoFactory := new(mock.ProjectRepoFactory) defer projectRepoFactory.AssertExpectations(t) @@ -2738,7 +2737,7 @@ func TestRuntimeServiceServer(t *testing.T) { t.Run("should return error when namespace is not found", func(t *testing.T) { projectRepository := new(mock.ProjectRepository) defer projectRepository.AssertExpectations(t) - projectRepository.On("GetByName", projectName).Return(projectSpec, nil) + projectRepository.On("GetByName", ctx, projectName).Return(projectSpec, nil) projectRepoFactory := new(mock.ProjectRepoFactory) defer projectRepoFactory.AssertExpectations(t) @@ -2747,7 +2746,7 @@ func TestRuntimeServiceServer(t *testing.T) { namespaceRepository := new(mock.NamespaceRepository) defer namespaceRepository.AssertExpectations(t) errorMsg := "unable to get namespace" - namespaceRepository.On("GetByName", namespaceSpec.Name).Return(models.NamespaceSpec{}, errors.New(errorMsg)) + namespaceRepository.On("GetByName", ctx, namespaceSpec.Name).Return(models.NamespaceSpec{}, errors.New(errorMsg)) namespaceRepoFact := new(mock.NamespaceRepoFactory) defer namespaceRepoFact.AssertExpectations(t) @@ -2782,7 +2781,7 @@ func TestRuntimeServiceServer(t *testing.T) { t.Run("should return error when unable to read resource", func(t *testing.T) { projectRepository := new(mock.ProjectRepository) defer projectRepository.AssertExpectations(t) - projectRepository.On("GetByName", projectName).Return(projectSpec, nil) + projectRepository.On("GetByName", ctx, projectName).Return(projectSpec, nil) projectRepoFactory := new(mock.ProjectRepoFactory) defer projectRepoFactory.AssertExpectations(t) @@ -2790,7 +2789,7 @@ func TestRuntimeServiceServer(t *testing.T) { namespaceRepository := new(mock.NamespaceRepository) defer namespaceRepository.AssertExpectations(t) - namespaceRepository.On("GetByName", namespaceSpec.Name).Return(namespaceSpec, nil) + namespaceRepository.On("GetByName", ctx, namespaceSpec.Name).Return(namespaceSpec, nil) namespaceRepoFact := new(mock.NamespaceRepoFactory) defer namespaceRepoFact.AssertExpectations(t) @@ -2835,7 +2834,7 @@ func TestRuntimeServiceServer(t *testing.T) { t.Run("should return error when unable to get jobSpec", func(t *testing.T) { projectRepository := new(mock.ProjectRepository) defer projectRepository.AssertExpectations(t) - projectRepository.On("GetByName", projectName).Return(projectSpec, nil) + projectRepository.On("GetByName", ctx, projectName).Return(projectSpec, nil) projectRepoFactory := new(mock.ProjectRepoFactory) defer projectRepoFactory.AssertExpectations(t) @@ -2843,7 +2842,7 @@ func TestRuntimeServiceServer(t *testing.T) { namespaceRepository := new(mock.NamespaceRepository) defer namespaceRepository.AssertExpectations(t) - namespaceRepository.On("GetByName", namespaceSpec.Name).Return(namespaceSpec, nil) + namespaceRepository.On("GetByName", ctx, namespaceSpec.Name).Return(namespaceSpec, nil) namespaceRepoFact := new(mock.NamespaceRepoFactory) defer namespaceRepoFact.AssertExpectations(t) @@ -2912,7 +2911,7 @@ func TestRuntimeServiceServer(t *testing.T) { projectRepository := new(mock.ProjectRepository) defer projectRepository.AssertExpectations(t) - projectRepository.On("GetByName", projectName).Return(projectSpec, nil) + projectRepository.On("GetByName", ctx, projectName).Return(projectSpec, nil) projectRepoFactory := new(mock.ProjectRepoFactory) defer projectRepoFactory.AssertExpectations(t) @@ -2920,7 +2919,7 @@ func TestRuntimeServiceServer(t *testing.T) { namespaceRepository := new(mock.NamespaceRepository) defer namespaceRepository.AssertExpectations(t) - namespaceRepository.On("GetByName", namespaceSpec.Name).Return(namespaceSpec, nil) + namespaceRepository.On("GetByName", ctx, namespaceSpec.Name).Return(namespaceSpec, nil) namespaceRepoFact := new(mock.NamespaceRepoFactory) defer namespaceRepoFact.AssertExpectations(t) @@ -2990,7 +2989,7 @@ func TestRuntimeServiceServer(t *testing.T) { projectRepository := new(mock.ProjectRepository) defer projectRepository.AssertExpectations(t) - projectRepository.On("GetByName", projectName).Return(projectSpec, nil) + projectRepository.On("GetByName", ctx, projectName).Return(projectSpec, nil) projectRepoFactory := new(mock.ProjectRepoFactory) defer projectRepoFactory.AssertExpectations(t) @@ -2998,7 +2997,7 @@ func TestRuntimeServiceServer(t *testing.T) { namespaceRepository := new(mock.NamespaceRepository) defer namespaceRepository.AssertExpectations(t) - namespaceRepository.On("GetByName", namespaceSpec.Name).Return(namespaceSpec, nil) + namespaceRepository.On("GetByName", ctx, namespaceSpec.Name).Return(namespaceSpec, nil) namespaceRepoFact := new(mock.NamespaceRepoFactory) defer namespaceRepoFact.AssertExpectations(t) @@ -3139,10 +3138,10 @@ func TestRuntimeServiceServer(t *testing.T) { Urn: []string{backupUrn}, } - projectRepository.On("GetByName", projectName).Return(projectSpec, nil) + projectRepository.On("GetByName", ctx, projectName).Return(projectSpec, nil) projectRepoFactory.On("New").Return(projectRepository) - namespaceRepository.On("GetByName", namespaceSpec.Name).Return(namespaceSpec, nil) + namespaceRepository.On("GetByName", ctx, namespaceSpec.Name).Return(namespaceSpec, nil) namespaceRepoFact.On("New", projectSpec).Return(namespaceRepository) resourceSvc.On("ReadResource", context.Background(), namespaceSpec, models.DestinationTypeBigquery.String(), resourceName).Return(resourceSpec, nil) @@ -3244,10 +3243,10 @@ func TestRuntimeServiceServer(t *testing.T) { } backupResults := []string{backupUrn, backupDownstream1Urn, backupDownstream2Urn} - projectRepository.On("GetByName", projectName).Return(projectSpec, nil) + projectRepository.On("GetByName", ctx, projectName).Return(projectSpec, nil) projectRepoFactory.On("New").Return(projectRepository) - namespaceRepository.On("GetByName", namespaceSpec.Name).Return(namespaceSpec, nil) + namespaceRepository.On("GetByName", ctx, namespaceSpec.Name).Return(namespaceSpec, nil) namespaceRepoFact.On("New", projectSpec).Return(namespaceRepository) resourceSvc.On("ReadResource", context.Background(), namespaceSpec, models.DestinationTypeBigquery.String(), resourceName).Return(resourceSpec, nil) @@ -3299,7 +3298,7 @@ func TestRuntimeServiceServer(t *testing.T) { projectRepoFactory.On("New").Return(projectRepository) errorMsg := "unable to fetch project" - projectRepository.On("GetByName", projectName).Return(models.ProjectSpec{}, errors.New(errorMsg)) + projectRepository.On("GetByName", ctx, projectName).Return(models.ProjectSpec{}, errors.New(errorMsg)) runtimeServiceServer := v1.NewRuntimeServiceServer( log, @@ -3339,10 +3338,10 @@ func TestRuntimeServiceServer(t *testing.T) { Namespace: namespaceSpec.Name, } - projectRepository.On("GetByName", projectName).Return(projectSpec, nil) + projectRepository.On("GetByName", ctx, projectName).Return(projectSpec, nil) projectRepoFactory.On("New").Return(projectRepository) errorMsg := "unable to get namespace" - namespaceRepository.On("GetByName", namespaceSpec.Name).Return(models.NamespaceSpec{}, errors.New(errorMsg)) + namespaceRepository.On("GetByName", ctx, namespaceSpec.Name).Return(models.NamespaceSpec{}, errors.New(errorMsg)) namespaceRepoFact.On("New", projectSpec).Return(namespaceRepository) runtimeServiceServer := v1.NewRuntimeServiceServer( @@ -3391,9 +3390,9 @@ func TestRuntimeServiceServer(t *testing.T) { Namespace: namespaceSpec.Name, } - projectRepository.On("GetByName", projectName).Return(projectSpec, nil) + projectRepository.On("GetByName", ctx, projectName).Return(projectSpec, nil) projectRepoFactory.On("New").Return(projectRepository) - namespaceRepository.On("GetByName", namespaceSpec.Name).Return(namespaceSpec, nil) + namespaceRepository.On("GetByName", ctx, namespaceSpec.Name).Return(namespaceSpec, nil) namespaceRepoFact.On("New", projectSpec).Return(namespaceRepository) errorMsg := "unable to read resource" resourceSvc.On("ReadResource", context.Background(), namespaceSpec, @@ -3447,9 +3446,9 @@ func TestRuntimeServiceServer(t *testing.T) { Namespace: namespaceSpec.Name, } - projectRepository.On("GetByName", projectName).Return(projectSpec, nil) + projectRepository.On("GetByName", ctx, projectName).Return(projectSpec, nil) projectRepoFactory.On("New").Return(projectRepository) - namespaceRepository.On("GetByName", namespaceSpec.Name).Return(namespaceSpec, nil) + namespaceRepository.On("GetByName", ctx, namespaceSpec.Name).Return(namespaceSpec, nil) namespaceRepoFact.On("New", projectSpec).Return(namespaceRepository) resourceSvc.On("ReadResource", context.Background(), namespaceSpec, models.DestinationTypeBigquery.String(), resourceName).Return(resourceSpec, nil) errorMsg := "unable to get jobspec" @@ -3523,9 +3522,9 @@ func TestRuntimeServiceServer(t *testing.T) { Namespace: namespaceSpec.Name, } - projectRepository.On("GetByName", projectName).Return(projectSpec, nil) + projectRepository.On("GetByName", ctx, projectName).Return(projectSpec, nil) projectRepoFactory.On("New").Return(projectRepository) - namespaceRepository.On("GetByName", namespaceSpec.Name).Return(namespaceSpec, nil) + namespaceRepository.On("GetByName", ctx, namespaceSpec.Name).Return(namespaceSpec, nil) namespaceRepoFact.On("New", projectSpec).Return(namespaceRepository) resourceSvc.On("ReadResource", context.Background(), namespaceSpec, models.DestinationTypeBigquery.String(), resourceName).Return(resourceSpec, nil) jobService.On("GetByDestination", projectSpec, resourceUrn).Return(jobSpec, nil) @@ -3613,9 +3612,9 @@ func TestRuntimeServiceServer(t *testing.T) { }, } - projectRepository.On("GetByName", projectName).Return(projectSpec, nil) + projectRepository.On("GetByName", ctx, projectName).Return(projectSpec, nil) projectRepoFactory.On("New").Return(projectRepository) - namespaceRepository.On("GetByName", namespaceSpec.Name).Return(namespaceSpec, nil) + namespaceRepository.On("GetByName", ctx, namespaceSpec.Name).Return(namespaceSpec, nil) namespaceRepoFact.On("New", projectSpec).Return(namespaceRepository) resourceSvc.On("ReadResource", context.Background(), namespaceSpec, models.DestinationTypeBigquery.String(), resourceName).Return(resourceSpec, nil) jobService.On("GetByDestination", projectSpec, resourceUrn).Return(jobSpec, nil) @@ -3709,8 +3708,8 @@ func TestRuntimeServiceServer(t *testing.T) { } projectRepoFactory.On("New").Return(projectRepository) - projectRepository.On("GetByName", projectName).Return(projectSpec, nil) - resourceSvc.On("ListBackupResources", projectSpec, datastoreName).Return(backupSpecs, nil) + projectRepository.On("GetByName", ctx, projectName).Return(projectSpec, nil) + resourceSvc.On("ListBackupResources", ctx, projectSpec, datastoreName).Return(backupSpecs, nil) runtimeServiceServer := v1.NewRuntimeServiceServer( log, @@ -3742,7 +3741,7 @@ func TestRuntimeServiceServer(t *testing.T) { projectRepoFactory.On("New").Return(projectRepository) errorMsg := "unable to get project spec" - projectRepository.On("GetByName", projectName).Return(models.ProjectSpec{}, + projectRepository.On("GetByName", ctx, projectName).Return(models.ProjectSpec{}, errors.New(errorMsg)) runtimeServiceServer := v1.NewRuntimeServiceServer( @@ -3774,9 +3773,9 @@ func TestRuntimeServiceServer(t *testing.T) { defer resourceSvc.AssertExpectations(t) projectRepoFactory.On("New").Return(projectRepository) - projectRepository.On("GetByName", projectName).Return(projectSpec, nil) + projectRepository.On("GetByName", ctx, projectName).Return(projectSpec, nil) errorMsg := "unable to get list of backups" - resourceSvc.On("ListBackupResources", projectSpec, datastoreName).Return([]models.BackupSpec{}, errors.New(errorMsg)) + resourceSvc.On("ListBackupResources", ctx, projectSpec, datastoreName).Return([]models.BackupSpec{}, errors.New(errorMsg)) runtimeServiceServer := v1.NewRuntimeServiceServer( log, diff --git a/cmd/create.go b/cmd/create.go index 662f892af1..ec712674c6 100644 --- a/cmd/create.go +++ b/cmd/create.go @@ -555,7 +555,7 @@ func createResourceSubCommand(l log.Logger, datastoreSpecFs map[string]afero.Fs, func IsResourceNameUnique(repository store.ResourceSpecRepository) survey.Validator { return func(val interface{}) error { if str, ok := val.(string); ok { - if _, err := repository.GetByName(str); err == nil { + if _, err := repository.GetByName(context.Background(), str); err == nil { return fmt.Errorf("resource with the provided name already exists") } else if err != models.ErrNoSuchSpec && err != models.ErrNoResources { return err diff --git a/cmd/deploy.go b/cmd/deploy.go index 1bd0615088..3e67c421d4 100644 --- a/cmd/deploy.go +++ b/cmd/deploy.go @@ -113,7 +113,7 @@ func postDeploymentRequest(l log.Logger, projectName string, namespace string, j return fmt.Errorf("unsupported datastore: %s\n", storeName) } resourceSpecRepo := local.NewResourceSpecRepository(repoFS, ds) - resourceSpecs, err := resourceSpecRepo.GetAll() + resourceSpecs, err := resourceSpecRepo.GetAll(context.Background()) if err == models.ErrNoResources { l.Info(coloredNotice("no resource specifications found")) continue diff --git a/cmd/server/server.go b/cmd/server/server.go index c337650aeb..5c37a60c56 100644 --- a/cmd/server/server.go +++ b/cmd/server/server.go @@ -16,7 +16,6 @@ import ( grpctags "github.com/grpc-ecosystem/go-grpc-middleware/tags" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/hashicorp/go-multierror" - "github.com/jinzhu/gorm" v1 "github.com/odpf/optimus/api/handler/v1" v1handler "github.com/odpf/optimus/api/handler/v1" pb "github.com/odpf/optimus/api/proto/odpf/optimus" @@ -54,6 +53,7 @@ import ( "golang.org/x/oauth2/google" "google.golang.org/grpc" "google.golang.org/grpc/reflection" + "gorm.io/gorm" ) var ( @@ -338,7 +338,7 @@ func Initialize(l log.Logger, conf config.Provider) error { hash: appHash, } if !conf.GetScheduler().SkipInit { - registeredProjects, err := projectRepoFac.New().GetAll() + registeredProjects, err := projectRepoFac.New().GetAll(context.Background()) if err != nil { return errors.Wrap(err, "projectRepoFactory.GetAll()") } @@ -598,6 +598,14 @@ func Initialize(l log.Logger, conf config.Provider) error { clusterPlanner.Close() clusterServer.Shutdown() + sqlConn, err := dbConn.DB() + if err != nil { + terminalError = multierror.Append(terminalError, errors.Wrap(err, "dbConn.DB")) + } + if err := sqlConn.Close(); err != nil { + terminalError = multierror.Append(terminalError, errors.Wrap(err, "sqlConn.Close")) + } + l.Info("bye") return terminalError } diff --git a/datastore/service.go b/datastore/service.go index 4e36bb708c..28277e64ec 100644 --- a/datastore/service.go +++ b/datastore/service.go @@ -41,12 +41,12 @@ type Service struct { uuidProvider utils.UUIDProvider } -func (srv Service) GetAll(namespace models.NamespaceSpec, datastoreName string) ([]models.ResourceSpec, error) { +func (srv Service) GetAll(ctx context.Context, namespace models.NamespaceSpec, datastoreName string) ([]models.ResourceSpec, error) { ds, err := srv.dsRepo.GetByName(datastoreName) if err != nil { return nil, err } - return srv.resourceRepoFactory.New(namespace, ds).GetAll() + return srv.resourceRepoFactory.New(namespace, ds).GetAll(ctx) } func (srv Service) CreateResource(ctx context.Context, namespace models.NamespaceSpec, resourceSpecs []models.ResourceSpec, obs progress.Observer) error { @@ -55,7 +55,7 @@ func (srv Service) CreateResource(ctx context.Context, namespace models.Namespac currentSpec := resourceSpec repo := srv.resourceRepoFactory.New(namespace, currentSpec.Datastore) runner.Add(func() (interface{}, error) { - if err := repo.Save(currentSpec); err != nil { + if err := repo.Save(ctx, currentSpec); err != nil { return nil, err } @@ -86,7 +86,7 @@ func (srv Service) UpdateResource(ctx context.Context, namespace models.Namespac currentSpec := resourceSpec repo := srv.resourceRepoFactory.New(namespace, currentSpec.Datastore) runner.Add(func() (interface{}, error) { - if err := repo.Save(currentSpec); err != nil { + if err := repo.Save(ctx, currentSpec); err != nil { return nil, err } @@ -117,7 +117,7 @@ func (srv Service) ReadResource(ctx context.Context, namespace models.NamespaceS return models.ResourceSpec{}, err } repo := srv.resourceRepoFactory.New(namespace, ds) - dbSpec, err := repo.GetByName(name) + dbSpec, err := repo.GetByName(ctx, name) if err != nil { return models.ResourceSpec{}, err } @@ -138,7 +138,7 @@ func (srv Service) DeleteResource(ctx context.Context, namespace models.Namespac return err } repo := srv.resourceRepoFactory.New(namespace, ds) - resourceSpec, err := repo.GetByName(name) + resourceSpec, err := repo.GetByName(ctx, name) if err != nil { return err } @@ -151,7 +151,7 @@ func (srv Service) DeleteResource(ctx context.Context, namespace models.Namespac return err } - return repo.Delete(name) + return repo.Delete(ctx, name) } func generateResourceDestination(ctx context.Context, jobSpec models.JobSpec) (*models.GenerateDestinationResponse, error) { @@ -161,9 +161,9 @@ func generateResourceDestination(ctx context.Context, jobSpec models.JobSpec) (* }) } -func (srv Service) getResourceSpec(datastorer models.Datastorer, namespace models.NamespaceSpec, destinationURN string) (models.ResourceSpec, error) { +func (srv Service) getResourceSpec(ctx context.Context, datastorer models.Datastorer, namespace models.NamespaceSpec, destinationURN string) (models.ResourceSpec, error) { repo := srv.resourceRepoFactory.New(namespace, datastorer) - return repo.GetByURN(destinationURN) + return repo.GetByURN(ctx, destinationURN) } func (srv Service) BackupResourceDryRun(ctx context.Context, backupRequest models.BackupRequest, jobSpecs []models.JobSpec) ([]string, error) { @@ -179,7 +179,7 @@ func (srv Service) BackupResourceDryRun(ctx context.Context, backupRequest model return nil, err } - resourceSpec, err := srv.getResourceSpec(datastorer, backupRequest.Namespace, destination.URN()) + resourceSpec, err := srv.getResourceSpec(ctx, datastorer, backupRequest.Namespace, destination.URN()) if err != nil { if err == store.ErrResourceNotFound { continue @@ -223,7 +223,7 @@ func (srv Service) BackupResource(ctx context.Context, backupRequest models.Back return nil, err } - resourceSpec, err := srv.getResourceSpec(datastorer, backupRequest.Namespace, destination.URN()) + resourceSpec, err := srv.getResourceSpec(ctx, datastorer, backupRequest.Namespace, destination.URN()) if err != nil { if err == store.ErrResourceNotFound { continue @@ -258,21 +258,21 @@ func (srv Service) BackupResource(ctx context.Context, backupRequest models.Back //save the backup backupRepo := srv.backupRepoFactory.New(backupRequest.Project, backupSpec.Resource.Datastore) - if err := backupRepo.Save(backupSpec); err != nil { + if err := backupRepo.Save(ctx, backupSpec); err != nil { return nil, err } return backupResult, nil } -func (srv Service) ListBackupResources(projectSpec models.ProjectSpec, datastoreName string) ([]models.BackupSpec, error) { +func (srv Service) ListBackupResources(ctx context.Context, projectSpec models.ProjectSpec, datastoreName string) ([]models.BackupSpec, error) { datastorer, err := srv.dsRepo.GetByName(datastoreName) if err != nil { return []models.BackupSpec{}, err } backupRepo := srv.backupRepoFactory.New(projectSpec, datastorer) - backupSpecs, err := backupRepo.GetAll() + backupSpecs, err := backupRepo.GetAll(ctx) if err != nil { if err == store.ErrResourceNotFound { return []models.BackupSpec{}, nil diff --git a/datastore/service_test.go b/datastore/service_test.go index e46be09f4c..133fa81f3c 100644 --- a/datastore/service_test.go +++ b/datastore/service_test.go @@ -26,6 +26,7 @@ func TestService(t *testing.T) { "bucket": "gs://some_folder", }, } + ctx := context.Background() namespaceSpec := models.NamespaceSpec{ ID: uuid.Must(uuid.NewRandom()), @@ -50,7 +51,7 @@ func TestService(t *testing.T) { } resourceRepo := new(mock.ResourceSpecRepository) - resourceRepo.On("GetAll").Return([]models.ResourceSpec{resourceSpec1}, nil) + resourceRepo.On("GetAll", ctx).Return([]models.ResourceSpec{resourceSpec1}, nil) defer resourceRepo.AssertExpectations(t) resourceRepoFac := new(mock.ResourceSpecRepoFactory) @@ -61,7 +62,7 @@ func TestService(t *testing.T) { defer projectResourceRepoFac.AssertExpectations(t) service := datastore.NewService(resourceRepoFac, dsRepo, nil, nil) - res, err := service.GetAll(namespaceSpec, "bq") + res, err := service.GetAll(ctx, namespaceSpec, "bq") assert.Nil(t, err) assert.Equal(t, []models.ResourceSpec{resourceSpec1}, res) }) @@ -87,18 +88,18 @@ func TestService(t *testing.T) { Type: models.ResourceTypeDataset, Datastore: datastorer, } - datastorer.On("CreateResource", context.TODO(), models.CreateResourceRequest{ + datastorer.On("CreateResource", ctx, models.CreateResourceRequest{ Project: projectSpec, Resource: resourceSpec1, }).Return(nil) - datastorer.On("CreateResource", context.TODO(), models.CreateResourceRequest{ + datastorer.On("CreateResource", ctx, models.CreateResourceRequest{ Project: projectSpec, Resource: resourceSpec2, }).Return(nil) resourceRepo := new(mock.ResourceSpecRepository) - resourceRepo.On("Save", resourceSpec1).Return(nil) - resourceRepo.On("Save", resourceSpec2).Return(nil) + resourceRepo.On("Save", ctx, resourceSpec1).Return(nil) + resourceRepo.On("Save", ctx, resourceSpec2).Return(nil) defer resourceRepo.AssertExpectations(t) resourceRepoFac := new(mock.ResourceSpecRepoFactory) @@ -109,7 +110,7 @@ func TestService(t *testing.T) { defer projectResourceRepoFac.AssertExpectations(t) service := datastore.NewService(resourceRepoFac, dsRepo, nil, nil) - err := service.CreateResource(context.TODO(), namespaceSpec, []models.ResourceSpec{resourceSpec1, resourceSpec2}, nil) + err := service.CreateResource(ctx, namespaceSpec, []models.ResourceSpec{resourceSpec1, resourceSpec2}, nil) assert.Nil(t, err) }) t.Run("should not call create in datastore if failed to save in repository", func(t *testing.T) { @@ -131,14 +132,14 @@ func TestService(t *testing.T) { Type: models.ResourceTypeDataset, Datastore: datastorer, } - datastorer.On("CreateResource", context.TODO(), models.CreateResourceRequest{ + datastorer.On("CreateResource", ctx, models.CreateResourceRequest{ Project: projectSpec, Resource: resourceSpec2, }).Return(nil) resourceRepo := new(mock.ResourceSpecRepository) - resourceRepo.On("Save", resourceSpec1).Return(errors.New("cant save, too busy")) - resourceRepo.On("Save", resourceSpec2).Return(nil) + resourceRepo.On("Save", ctx, resourceSpec1).Return(errors.New("cant save, too busy")) + resourceRepo.On("Save", ctx, resourceSpec2).Return(nil) defer resourceRepo.AssertExpectations(t) resourceRepoFac := new(mock.ResourceSpecRepoFactory) @@ -149,7 +150,7 @@ func TestService(t *testing.T) { defer projectResourceRepoFac.AssertExpectations(t) service := datastore.NewService(resourceRepoFac, dsRepo, nil, nil) - err := service.CreateResource(context.TODO(), namespaceSpec, []models.ResourceSpec{resourceSpec1, resourceSpec2}, nil) + err := service.CreateResource(ctx, namespaceSpec, []models.ResourceSpec{resourceSpec1, resourceSpec2}, nil) assert.NotNil(t, err) }) }) @@ -173,18 +174,18 @@ func TestService(t *testing.T) { Type: models.ResourceTypeDataset, Datastore: datastorer, } - datastorer.On("UpdateResource", context.TODO(), models.UpdateResourceRequest{ + datastorer.On("UpdateResource", ctx, models.UpdateResourceRequest{ Project: projectSpec, Resource: resourceSpec1, }).Return(nil) - datastorer.On("UpdateResource", context.TODO(), models.UpdateResourceRequest{ + datastorer.On("UpdateResource", ctx, models.UpdateResourceRequest{ Project: projectSpec, Resource: resourceSpec2, }).Return(nil) resourceRepo := new(mock.ResourceSpecRepository) - resourceRepo.On("Save", resourceSpec1).Return(nil) - resourceRepo.On("Save", resourceSpec2).Return(nil) + resourceRepo.On("Save", ctx, resourceSpec1).Return(nil) + resourceRepo.On("Save", ctx, resourceSpec2).Return(nil) defer resourceRepo.AssertExpectations(t) resourceRepoFac := new(mock.ResourceSpecRepoFactory) @@ -195,7 +196,7 @@ func TestService(t *testing.T) { defer projectResourceRepoFac.AssertExpectations(t) service := datastore.NewService(resourceRepoFac, dsRepo, nil, nil) - err := service.UpdateResource(context.TODO(), namespaceSpec, []models.ResourceSpec{resourceSpec1, resourceSpec2}, nil) + err := service.UpdateResource(ctx, namespaceSpec, []models.ResourceSpec{resourceSpec1, resourceSpec2}, nil) assert.Nil(t, err) }) t.Run("should not call update in datastore if failed to save in repository", func(t *testing.T) { @@ -217,14 +218,14 @@ func TestService(t *testing.T) { Type: models.ResourceTypeDataset, Datastore: datastorer, } - datastorer.On("UpdateResource", context.TODO(), models.UpdateResourceRequest{ + datastorer.On("UpdateResource", ctx, models.UpdateResourceRequest{ Project: projectSpec, Resource: resourceSpec2, }).Return(nil) resourceRepo := new(mock.ResourceSpecRepository) - resourceRepo.On("Save", resourceSpec1).Return(errors.New("cant save, too busy")) - resourceRepo.On("Save", resourceSpec2).Return(nil) + resourceRepo.On("Save", ctx, resourceSpec1).Return(errors.New("cant save, too busy")) + resourceRepo.On("Save", ctx, resourceSpec2).Return(nil) defer resourceRepo.AssertExpectations(t) resourceRepoFac := new(mock.ResourceSpecRepoFactory) @@ -235,7 +236,7 @@ func TestService(t *testing.T) { defer projectResourceRepoFac.AssertExpectations(t) service := datastore.NewService(resourceRepoFac, dsRepo, nil, nil) - err := service.UpdateResource(context.TODO(), namespaceSpec, []models.ResourceSpec{resourceSpec1, resourceSpec2}, nil) + err := service.UpdateResource(ctx, namespaceSpec, []models.ResourceSpec{resourceSpec1, resourceSpec2}, nil) assert.NotNil(t, err) }) }) @@ -254,13 +255,13 @@ func TestService(t *testing.T) { Type: models.ResourceTypeDataset, Datastore: datastorer, } - datastorer.On("ReadResource", context.TODO(), models.ReadResourceRequest{ + datastorer.On("ReadResource", ctx, models.ReadResourceRequest{ Project: projectSpec, Resource: resourceSpec1, }).Return(models.ReadResourceResponse{Resource: resourceSpec1}, nil) resourceRepo := new(mock.ResourceSpecRepository) - resourceRepo.On("GetByName", resourceSpec1.Name).Return(resourceSpec1, nil) + resourceRepo.On("GetByName", ctx, resourceSpec1.Name).Return(resourceSpec1, nil) defer resourceRepo.AssertExpectations(t) resourceRepoFac := new(mock.ResourceSpecRepoFactory) @@ -271,7 +272,7 @@ func TestService(t *testing.T) { defer projectResourceRepoFac.AssertExpectations(t) service := datastore.NewService(resourceRepoFac, dsRepo, nil, nil) - resp, err := service.ReadResource(context.TODO(), namespaceSpec, "bq", resourceSpec1.Name) + resp, err := service.ReadResource(ctx, namespaceSpec, "bq", resourceSpec1.Name) assert.Nil(t, err) assert.Equal(t, resourceSpec1, resp) }) @@ -291,7 +292,7 @@ func TestService(t *testing.T) { } resourceRepo := new(mock.ResourceSpecRepository) - resourceRepo.On("GetByName", resourceSpec1.Name).Return(resourceSpec1, errors.New("not found")) + resourceRepo.On("GetByName", ctx, resourceSpec1.Name).Return(resourceSpec1, errors.New("not found")) defer resourceRepo.AssertExpectations(t) resourceRepoFac := new(mock.ResourceSpecRepoFactory) @@ -302,7 +303,7 @@ func TestService(t *testing.T) { defer projectResourceRepoFac.AssertExpectations(t) service := datastore.NewService(resourceRepoFac, dsRepo, nil, nil) - _, err := service.ReadResource(context.TODO(), namespaceSpec, "bq", resourceSpec1.Name) + _, err := service.ReadResource(ctx, namespaceSpec, "bq", resourceSpec1.Name) assert.NotNil(t, err) }) }) @@ -321,14 +322,14 @@ func TestService(t *testing.T) { Type: models.ResourceTypeDataset, Datastore: datastorer, } - datastorer.On("DeleteResource", context.TODO(), models.DeleteResourceRequest{ + datastorer.On("DeleteResource", ctx, models.DeleteResourceRequest{ Project: projectSpec, Resource: resourceSpec1, }).Return(nil) resourceRepo := new(mock.ResourceSpecRepository) - resourceRepo.On("GetByName", resourceSpec1.Name).Return(resourceSpec1, nil) - resourceRepo.On("Delete", resourceSpec1.Name).Return(nil) + resourceRepo.On("GetByName", ctx, resourceSpec1.Name).Return(resourceSpec1, nil) + resourceRepo.On("Delete", ctx, resourceSpec1.Name).Return(nil) defer resourceRepo.AssertExpectations(t) resourceRepoFac := new(mock.ResourceSpecRepoFactory) @@ -339,7 +340,7 @@ func TestService(t *testing.T) { defer projectResourceRepoFac.AssertExpectations(t) service := datastore.NewService(resourceRepoFac, dsRepo, nil, nil) - err := service.DeleteResource(context.TODO(), namespaceSpec, "bq", resourceSpec1.Name) + err := service.DeleteResource(ctx, namespaceSpec, "bq", resourceSpec1.Name) assert.Nil(t, err) }) t.Run("should not call delete in datastore if failed to delete from repository", func(t *testing.T) { @@ -356,13 +357,13 @@ func TestService(t *testing.T) { Type: models.ResourceTypeDataset, Datastore: datastorer, } - datastorer.On("DeleteResource", context.TODO(), models.DeleteResourceRequest{ + datastorer.On("DeleteResource", ctx, models.DeleteResourceRequest{ Project: projectSpec, Resource: resourceSpec1, }).Return(errors.New("failed to delete")) resourceRepo := new(mock.ResourceSpecRepository) - resourceRepo.On("GetByName", resourceSpec1.Name).Return(resourceSpec1, nil) + resourceRepo.On("GetByName", ctx, resourceSpec1.Name).Return(resourceSpec1, nil) defer resourceRepo.AssertExpectations(t) resourceRepoFac := new(mock.ResourceSpecRepoFactory) @@ -373,7 +374,7 @@ func TestService(t *testing.T) { defer projectResourceRepoFac.AssertExpectations(t) service := datastore.NewService(resourceRepoFac, dsRepo, nil, nil) - err := service.DeleteResource(context.TODO(), namespaceSpec, "bq", resourceSpec1.Name) + err := service.DeleteResource(ctx, namespaceSpec, "bq", resourceSpec1.Name) assert.NotNil(t, err) }) }) @@ -452,14 +453,14 @@ func TestService(t *testing.T) { BackupSpec: backupReq, } - depMod.On("GenerateDestination", context.TODO(), unitData).Return(destination, nil) + depMod.On("GenerateDestination", ctx, unitData).Return(destination, nil) dsRepo.On("GetByName", models.DestinationTypeBigquery.String()).Return(datastorer, nil) resourceRepoFac.On("New", namespaceSpec, datastorer).Return(resourceRepo) - resourceRepo.On("GetByURN", destination.URN()).Return(resourceSpec, nil) - datastorer.On("BackupResource", context.TODO(), backupResourceReq).Return(models.BackupResourceResponse{}, nil) + resourceRepo.On("GetByURN", ctx, destination.URN()).Return(resourceSpec, nil) + datastorer.On("BackupResource", ctx, backupResourceReq).Return(models.BackupResourceResponse{}, nil) service := datastore.NewService(resourceRepoFac, dsRepo, nil, nil) - resp, err := service.BackupResourceDryRun(context.TODO(), backupReq, []models.JobSpec{jobSpec}) + resp, err := service.BackupResourceDryRun(ctx, backupReq, []models.JobSpec{jobSpec}) assert.Nil(t, err) assert.Equal(t, []string{destination.Destination}, resp) }) @@ -548,18 +549,18 @@ func TestService(t *testing.T) { dsRepo.On("GetByName", models.DestinationTypeBigquery.String()).Return(datastorer, nil) - depMod.On("GenerateDestination", context.TODO(), unitRoot).Return(destinationRoot, nil).Once() - resourceRepo.On("GetByURN", destinationRoot.URN()).Return(resourceRoot, nil).Once() - datastorer.On("BackupResource", context.TODO(), backupResourceReqRoot).Return(models.BackupResourceResponse{}, nil).Once() + depMod.On("GenerateDestination", ctx, unitRoot).Return(destinationRoot, nil).Once() + resourceRepo.On("GetByURN", ctx, destinationRoot.URN()).Return(resourceRoot, nil).Once() + datastorer.On("BackupResource", ctx, backupResourceReqRoot).Return(models.BackupResourceResponse{}, nil).Once() - depMod.On("GenerateDestination", context.TODO(), unitDownstream).Return(destinationDownstream, nil).Once() - resourceRepo.On("GetByURN", destinationDownstream.URN()).Return(resourceDownstream, nil).Once() - datastorer.On("BackupResource", context.TODO(), backupResourceReqDownstream).Return(models.BackupResourceResponse{}, nil).Once() + depMod.On("GenerateDestination", ctx, unitDownstream).Return(destinationDownstream, nil).Once() + resourceRepo.On("GetByURN", ctx, destinationDownstream.URN()).Return(resourceDownstream, nil).Once() + datastorer.On("BackupResource", ctx, backupResourceReqDownstream).Return(models.BackupResourceResponse{}, nil).Once() resourceRepoFac.On("New", namespaceSpec, datastorer).Return(resourceRepo) service := datastore.NewService(resourceRepoFac, dsRepo, nil, nil) - resp, err := service.BackupResourceDryRun(context.TODO(), backupReq, []models.JobSpec{jobRoot, jobDownstream}) + resp, err := service.BackupResourceDryRun(ctx, backupReq, []models.JobSpec{jobRoot, jobDownstream}) assert.Nil(t, err) assert.Equal(t, []string{destinationRoot.Destination, destinationDownstream.Destination}, resp) @@ -593,10 +594,10 @@ func TestService(t *testing.T) { } errorMsg := "unable to generate destination" - depMod.On("GenerateDestination", context.TODO(), unitData).Return(&models.GenerateDestinationResponse{}, errors.New(errorMsg)) + depMod.On("GenerateDestination", ctx, unitData).Return(&models.GenerateDestinationResponse{}, errors.New(errorMsg)) service := datastore.NewService(nil, dsRepo, nil, nil) - resp, err := service.BackupResourceDryRun(context.TODO(), backupReq, []models.JobSpec{jobSpec}) + resp, err := service.BackupResourceDryRun(ctx, backupReq, []models.JobSpec{jobSpec}) assert.Contains(t, err.Error(), errorMsg) assert.Nil(t, resp) @@ -632,13 +633,13 @@ func TestService(t *testing.T) { DryRun: true, } - depMod.On("GenerateDestination", context.TODO(), unitData).Return(destination, nil) + depMod.On("GenerateDestination", ctx, unitData).Return(destination, nil) errorMsg := "unable to get datastorer" dsRepo.On("GetByName", destination.Type.String()).Return(datastorer, errors.New(errorMsg)) service := datastore.NewService(nil, dsRepo, nil, nil) - resp, err := service.BackupResourceDryRun(context.TODO(), backupReq, []models.JobSpec{jobSpec}) + resp, err := service.BackupResourceDryRun(ctx, backupReq, []models.JobSpec{jobSpec}) assert.Contains(t, err.Error(), errorMsg) assert.Nil(t, resp) @@ -691,16 +692,16 @@ func TestService(t *testing.T) { BackupSpec: backupReq, } - depMod.On("GenerateDestination", context.TODO(), unitData).Return(destination, nil) + depMod.On("GenerateDestination", ctx, unitData).Return(destination, nil) dsRepo.On("GetByName", models.DestinationTypeBigquery.String()).Return(datastorer, nil) - resourceRepo.On("GetByURN", destination.URN()).Return(resourceSpec, nil) + resourceRepo.On("GetByURN", ctx, destination.URN()).Return(resourceSpec, nil) resourceRepoFac.On("New", namespaceSpec, datastorer).Return(resourceRepo) errorMsg := "unable to do backup dry run" - datastorer.On("BackupResource", context.TODO(), backupResourceReq).Return(models.BackupResourceResponse{}, errors.New(errorMsg)) + datastorer.On("BackupResource", ctx, backupResourceReq).Return(models.BackupResourceResponse{}, errors.New(errorMsg)) service := datastore.NewService(resourceRepoFac, dsRepo, nil, nil) - resp, err := service.BackupResourceDryRun(context.TODO(), backupReq, []models.JobSpec{jobSpec}) + resp, err := service.BackupResourceDryRun(ctx, backupReq, []models.JobSpec{jobSpec}) assert.Equal(t, errorMsg, err.Error()) assert.Nil(t, resp) @@ -742,15 +743,15 @@ func TestService(t *testing.T) { IgnoreDownstream: false, } - depMod.On("GenerateDestination", context.TODO(), unitData).Return(destination, nil) + depMod.On("GenerateDestination", ctx, unitData).Return(destination, nil) dsRepo.On("GetByName", models.DestinationTypeBigquery.String()).Return(datastorer, nil) resourceRepoFac.On("New", namespaceSpec, datastorer).Return(resourceRepo) errorMsg := "unable to get resource" - resourceRepo.On("GetByURN", destination.URN()).Return(models.ResourceSpec{}, errors.New(errorMsg)) + resourceRepo.On("GetByURN", ctx, destination.URN()).Return(models.ResourceSpec{}, errors.New(errorMsg)) service := datastore.NewService(resourceRepoFac, dsRepo, nil, nil) - resp, err := service.BackupResourceDryRun(context.TODO(), backupReq, []models.JobSpec{jobSpec}) + resp, err := service.BackupResourceDryRun(ctx, backupReq, []models.JobSpec{jobSpec}) assert.Equal(t, errorMsg, err.Error()) assert.Nil(t, resp) @@ -822,17 +823,17 @@ func TestService(t *testing.T) { Assets: models.PluginAssets{}.FromJobSpec(jobDownstream.Assets), } - depMod.On("GenerateDestination", context.TODO(), unitRoot).Return(destinationRoot, nil).Once() + depMod.On("GenerateDestination", ctx, unitRoot).Return(destinationRoot, nil).Once() dsRepo.On("GetByName", models.DestinationTypeBigquery.String()).Return(datastorer, nil) resourceRepoFac.On("New", namespaceSpec, datastorer).Return(resourceRepo) - resourceRepo.On("GetByURN", destinationRoot.URN()).Return(resourceRoot, nil).Once() - datastorer.On("BackupResource", context.TODO(), backupResourceReqRoot).Return(models.BackupResourceResponse{}, nil).Once() + resourceRepo.On("GetByURN", ctx, destinationRoot.URN()).Return(resourceRoot, nil).Once() + datastorer.On("BackupResource", ctx, backupResourceReqRoot).Return(models.BackupResourceResponse{}, nil).Once() errorMsg := "unable to generate destination" - depMod.On("GenerateDestination", context.TODO(), unitDownstream).Return(&models.GenerateDestinationResponse{}, errors.New(errorMsg)).Once() + depMod.On("GenerateDestination", ctx, unitDownstream).Return(&models.GenerateDestinationResponse{}, errors.New(errorMsg)).Once() service := datastore.NewService(resourceRepoFac, dsRepo, nil, nil) - resp, err := service.BackupResourceDryRun(context.TODO(), backupReq, []models.JobSpec{jobRoot, jobDownstream}) + resp, err := service.BackupResourceDryRun(ctx, backupReq, []models.JobSpec{jobRoot, jobDownstream}) assert.Equal(t, errorMsg, err.Error()) assert.Nil(t, resp) @@ -911,16 +912,16 @@ func TestService(t *testing.T) { Type: models.DestinationTypeBigquery, } - depMod.On("GenerateDestination", context.TODO(), unitRoot).Return(destinationRoot, nil).Once() + depMod.On("GenerateDestination", ctx, unitRoot).Return(destinationRoot, nil).Once() resourceRepoFac.On("New", namespaceSpec, datastorer).Return(resourceRepo) - resourceRepo.On("GetByURN", destinationRoot.URN()).Return(resourceRoot, nil).Once() - datastorer.On("BackupResource", context.TODO(), backupResourceReqRoot).Return(models.BackupResourceResponse{}, nil).Once() + resourceRepo.On("GetByURN", ctx, destinationRoot.URN()).Return(resourceRoot, nil).Once() + datastorer.On("BackupResource", ctx, backupResourceReqRoot).Return(models.BackupResourceResponse{}, nil).Once() - depMod.On("GenerateDestination", context.TODO(), unitDownstream).Return(destinationDownstream, nil).Once() - resourceRepo.On("GetByURN", destinationDownstream.URN()).Return(models.ResourceSpec{}, store.ErrResourceNotFound).Once() + depMod.On("GenerateDestination", ctx, unitDownstream).Return(destinationDownstream, nil).Once() + resourceRepo.On("GetByURN", ctx, destinationDownstream.URN()).Return(models.ResourceSpec{}, store.ErrResourceNotFound).Once() service := datastore.NewService(resourceRepoFac, dsRepo, nil, nil) - resp, err := service.BackupResourceDryRun(context.TODO(), backupReq, []models.JobSpec{jobRoot, jobDownstream}) + resp, err := service.BackupResourceDryRun(ctx, backupReq, []models.JobSpec{jobRoot, jobDownstream}) assert.Nil(t, err) assert.Equal(t, []string{destinationRoot.Destination}, resp) @@ -1007,18 +1008,18 @@ func TestService(t *testing.T) { BackupSpec: backupReq, } - depMod.On("GenerateDestination", context.TODO(), unitRoot).Return(destinationRoot, nil).Once() + depMod.On("GenerateDestination", ctx, unitRoot).Return(destinationRoot, nil).Once() resourceRepoFac.On("New", namespaceSpec, datastorer).Return(resourceRepo) - resourceRepo.On("GetByURN", destinationRoot.URN()).Return(resourceRoot, nil).Once() + resourceRepo.On("GetByURN", ctx, destinationRoot.URN()).Return(resourceRoot, nil).Once() dsRepo.On("GetByName", models.DestinationTypeBigquery.String()).Return(datastorer, nil) - datastorer.On("BackupResource", context.TODO(), backupResourceReqRoot).Return(models.BackupResourceResponse{}, nil).Once() + datastorer.On("BackupResource", ctx, backupResourceReqRoot).Return(models.BackupResourceResponse{}, nil).Once() - depMod.On("GenerateDestination", context.TODO(), unitDownstream).Return(destinationDownstream, nil).Once() - resourceRepo.On("GetByURN", destinationDownstream.URN()).Return(resourceDownstream, nil).Once() - datastorer.On("BackupResource", context.TODO(), backupResourceReqDownstream).Return(models.BackupResourceResponse{}, models.ErrUnsupportedResource).Once() + depMod.On("GenerateDestination", ctx, unitDownstream).Return(destinationDownstream, nil).Once() + resourceRepo.On("GetByURN", ctx, destinationDownstream.URN()).Return(resourceDownstream, nil).Once() + datastorer.On("BackupResource", ctx, backupResourceReqDownstream).Return(models.BackupResourceResponse{}, models.ErrUnsupportedResource).Once() service := datastore.NewService(resourceRepoFac, dsRepo, nil, nil) - resp, err := service.BackupResourceDryRun(context.TODO(), backupReq, []models.JobSpec{jobRoot, jobDownstream}) + resp, err := service.BackupResourceDryRun(ctx, backupReq, []models.JobSpec{jobRoot, jobDownstream}) assert.Nil(t, err) assert.Equal(t, []string{destinationRoot.Destination}, resp) @@ -1122,18 +1123,18 @@ func TestService(t *testing.T) { Description: "", } - depMod.On("GenerateDestination", context.TODO(), unitData).Return(destination, nil) + depMod.On("GenerateDestination", ctx, unitData).Return(destination, nil) dsRepo.On("GetByName", models.DestinationTypeBigquery.String()).Return(datastorer, nil) resourceRepoFac.On("New", namespaceSpec, datastorer).Return(resourceRepo) - resourceRepo.On("GetByURN", destination.URN()).Return(resourceSpec, nil) - datastorer.On("BackupResource", context.TODO(), backupResourceReq). + resourceRepo.On("GetByURN", ctx, destination.URN()).Return(resourceSpec, nil) + datastorer.On("BackupResource", ctx, backupResourceReq). Return(models.BackupResourceResponse{ResultURN: resultURN, ResultSpec: resultSpec}, nil) uuidProvider.On("NewUUID").Return(backupUUID, nil) backupRepoFac.On("New", projectSpec, datastorer).Return(backupRepo) - backupRepo.On("Save", backupSpec).Return(nil) + backupRepo.On("Save", ctx, backupSpec).Return(nil) service := datastore.NewService(resourceRepoFac, dsRepo, uuidProvider, backupRepoFac) - resp, err := service.BackupResource(context.TODO(), backupReq, []models.JobSpec{jobSpec}) + resp, err := service.BackupResource(ctx, backupReq, []models.JobSpec{jobSpec}) assert.Nil(t, err) assert.Equal(t, []string{resultURN}, resp) }) @@ -1261,24 +1262,24 @@ func TestService(t *testing.T) { dsRepo.On("GetByName", models.DestinationTypeBigquery.String()).Return(datastorer, nil) - depMod.On("GenerateDestination", context.TODO(), unitRoot).Return(destinationRoot, nil).Once() - resourceRepo.On("GetByURN", destinationRoot.URN()).Return(resourceRoot, nil).Once() - datastorer.On("BackupResource", context.TODO(), backupResourceReqRoot). + depMod.On("GenerateDestination", ctx, unitRoot).Return(destinationRoot, nil).Once() + resourceRepo.On("GetByURN", ctx, destinationRoot.URN()).Return(resourceRoot, nil).Once() + datastorer.On("BackupResource", ctx, backupResourceReqRoot). Return(models.BackupResourceResponse{ResultURN: resultURNRoot, ResultSpec: resultSpecRoot}, nil).Once() - depMod.On("GenerateDestination", context.TODO(), unitDownstream).Return(destinationDownstream, nil).Once() - resourceRepo.On("GetByURN", destinationDownstream.URN()).Return(resourceDownstream, nil).Once() - datastorer.On("BackupResource", context.TODO(), backupResourceReqDownstream). + depMod.On("GenerateDestination", ctx, unitDownstream).Return(destinationDownstream, nil).Once() + resourceRepo.On("GetByURN", ctx, destinationDownstream.URN()).Return(resourceDownstream, nil).Once() + datastorer.On("BackupResource", ctx, backupResourceReqDownstream). Return(models.BackupResourceResponse{ResultURN: resultURNDownstream, ResultSpec: resultSpecDownstream}, nil).Once() resourceRepoFac.On("New", namespaceSpec, datastorer).Return(resourceRepo) uuidProvider.On("NewUUID").Return(backupUUID, nil) backupRepoFac.On("New", projectSpec, datastorer).Return(backupRepo) - backupRepo.On("Save", backupSpec).Return(nil) + backupRepo.On("Save", ctx, backupSpec).Return(nil) service := datastore.NewService(resourceRepoFac, dsRepo, uuidProvider, backupRepoFac) - resp, err := service.BackupResource(context.TODO(), backupReq, []models.JobSpec{jobRoot, jobDownstream}) + resp, err := service.BackupResource(ctx, backupReq, []models.JobSpec{jobRoot, jobDownstream}) assert.Nil(t, err) assert.Equal(t, []string{resultURNRoot, resultURNDownstream}, resp) @@ -1317,10 +1318,10 @@ func TestService(t *testing.T) { uuidProvider.On("NewUUID").Return(backupUUID, nil) errorMsg := "unable to generate destination" - depMod.On("GenerateDestination", context.TODO(), unitData).Return(&models.GenerateDestinationResponse{}, errors.New(errorMsg)) + depMod.On("GenerateDestination", ctx, unitData).Return(&models.GenerateDestinationResponse{}, errors.New(errorMsg)) service := datastore.NewService(nil, dsRepo, uuidProvider, nil) - resp, err := service.BackupResource(context.TODO(), backupReq, []models.JobSpec{jobSpec}) + resp, err := service.BackupResource(ctx, backupReq, []models.JobSpec{jobSpec}) assert.Contains(t, err.Error(), errorMsg) assert.Nil(t, resp) @@ -1360,13 +1361,13 @@ func TestService(t *testing.T) { } uuidProvider.On("NewUUID").Return(backupUUID, nil) - depMod.On("GenerateDestination", context.TODO(), unitData).Return(destination, nil) + depMod.On("GenerateDestination", ctx, unitData).Return(destination, nil) errorMsg := "unable to get datastorer" dsRepo.On("GetByName", destination.Type.String()).Return(datastorer, errors.New(errorMsg)) service := datastore.NewService(nil, dsRepo, uuidProvider, nil) - resp, err := service.BackupResource(context.TODO(), backupReq, []models.JobSpec{jobSpec}) + resp, err := service.BackupResource(ctx, backupReq, []models.JobSpec{jobSpec}) assert.Contains(t, err.Error(), errorMsg) assert.Nil(t, resp) @@ -1412,15 +1413,15 @@ func TestService(t *testing.T) { } uuidProvider.On("NewUUID").Return(backupUUID, nil) - depMod.On("GenerateDestination", context.TODO(), unitData).Return(destination, nil) + depMod.On("GenerateDestination", ctx, unitData).Return(destination, nil) dsRepo.On("GetByName", models.DestinationTypeBigquery.String()).Return(datastorer, nil) resourceRepoFac.On("New", namespaceSpec, datastorer).Return(resourceRepo) errorMsg := "unable to get resource" - resourceRepo.On("GetByURN", destination.URN()).Return(models.ResourceSpec{}, errors.New(errorMsg)) + resourceRepo.On("GetByURN", ctx, destination.URN()).Return(models.ResourceSpec{}, errors.New(errorMsg)) service := datastore.NewService(resourceRepoFac, dsRepo, uuidProvider, nil) - resp, err := service.BackupResource(context.TODO(), backupReq, []models.JobSpec{jobSpec}) + resp, err := service.BackupResource(ctx, backupReq, []models.JobSpec{jobSpec}) assert.Equal(t, errorMsg, err.Error()) assert.Nil(t, resp) @@ -1478,16 +1479,16 @@ func TestService(t *testing.T) { } uuidProvider.On("NewUUID").Return(backupUUID, nil) - depMod.On("GenerateDestination", context.TODO(), unitData).Return(destination, nil) + depMod.On("GenerateDestination", ctx, unitData).Return(destination, nil) dsRepo.On("GetByName", models.DestinationTypeBigquery.String()).Return(datastorer, nil) - resourceRepo.On("GetByURN", destination.URN()).Return(resourceSpec, nil) + resourceRepo.On("GetByURN", ctx, destination.URN()).Return(resourceSpec, nil) resourceRepoFac.On("New", namespaceSpec, datastorer).Return(resourceRepo) errorMsg := "unable to do backup" - datastorer.On("BackupResource", context.TODO(), backupResourceReq).Return(models.BackupResourceResponse{}, errors.New(errorMsg)) + datastorer.On("BackupResource", ctx, backupResourceReq).Return(models.BackupResourceResponse{}, errors.New(errorMsg)) service := datastore.NewService(resourceRepoFac, dsRepo, uuidProvider, nil) - resp, err := service.BackupResource(context.TODO(), backupReq, []models.JobSpec{jobSpec}) + resp, err := service.BackupResource(ctx, backupReq, []models.JobSpec{jobSpec}) assert.Equal(t, errorMsg, err.Error()) assert.Nil(t, resp) @@ -1564,17 +1565,17 @@ func TestService(t *testing.T) { } uuidProvider.On("NewUUID").Return(backupUUID, nil) - depMod.On("GenerateDestination", context.TODO(), unitRoot).Return(destinationRoot, nil).Once() + depMod.On("GenerateDestination", ctx, unitRoot).Return(destinationRoot, nil).Once() dsRepo.On("GetByName", models.DestinationTypeBigquery.String()).Return(datastorer, nil) resourceRepoFac.On("New", namespaceSpec, datastorer).Return(resourceRepo) - resourceRepo.On("GetByURN", destinationRoot.URN()).Return(resourceRoot, nil).Once() - datastorer.On("BackupResource", context.TODO(), backupResourceReqRoot).Return(models.BackupResourceResponse{}, nil).Once() + resourceRepo.On("GetByURN", ctx, destinationRoot.URN()).Return(resourceRoot, nil).Once() + datastorer.On("BackupResource", ctx, backupResourceReqRoot).Return(models.BackupResourceResponse{}, nil).Once() errorMsg := "unable to generate destination" - depMod.On("GenerateDestination", context.TODO(), unitDownstream).Return(&models.GenerateDestinationResponse{}, errors.New(errorMsg)).Once() + depMod.On("GenerateDestination", ctx, unitDownstream).Return(&models.GenerateDestinationResponse{}, errors.New(errorMsg)).Once() service := datastore.NewService(resourceRepoFac, dsRepo, uuidProvider, nil) - resp, err := service.BackupResource(context.TODO(), backupReq, []models.JobSpec{jobRoot, jobDownstream}) + resp, err := service.BackupResource(ctx, backupReq, []models.JobSpec{jobRoot, jobDownstream}) assert.Equal(t, errorMsg, err.Error()) assert.Nil(t, resp) @@ -1681,21 +1682,21 @@ func TestService(t *testing.T) { } uuidProvider.On("NewUUID").Return(backupUUID, nil) - depMod.On("GenerateDestination", context.TODO(), unitRoot).Return(destinationRoot, nil).Once() + depMod.On("GenerateDestination", ctx, unitRoot).Return(destinationRoot, nil).Once() resourceRepoFac.On("New", namespaceSpec, datastorer).Return(resourceRepo) - resourceRepo.On("GetByURN", destinationRoot.URN()).Return(resourceRoot, nil).Once() + resourceRepo.On("GetByURN", ctx, destinationRoot.URN()).Return(resourceRoot, nil).Once() dsRepo.On("GetByName", models.DestinationTypeBigquery.String()).Return(datastorer, nil) - datastorer.On("BackupResource", context.TODO(), backupResourceReqRoot). + datastorer.On("BackupResource", ctx, backupResourceReqRoot). Return(models.BackupResourceResponse{ResultURN: resultURNRoot, ResultSpec: resultSpecRoot}, nil).Once() - depMod.On("GenerateDestination", context.TODO(), unitDownstream).Return(destinationDownstream, nil).Once() - resourceRepo.On("GetByURN", destinationDownstream.URN()).Return(models.ResourceSpec{}, store.ErrResourceNotFound).Once() + depMod.On("GenerateDestination", ctx, unitDownstream).Return(destinationDownstream, nil).Once() + resourceRepo.On("GetByURN", ctx, destinationDownstream.URN()).Return(models.ResourceSpec{}, store.ErrResourceNotFound).Once() backupRepoFac.On("New", projectSpec, datastorer).Return(backupRepo) - backupRepo.On("Save", backupSpec).Return(nil) + backupRepo.On("Save", ctx, backupSpec).Return(nil) service := datastore.NewService(resourceRepoFac, dsRepo, uuidProvider, backupRepoFac) - resp, err := service.BackupResource(context.TODO(), backupReq, []models.JobSpec{jobRoot, jobDownstream}) + resp, err := service.BackupResource(ctx, backupReq, []models.JobSpec{jobRoot, jobDownstream}) assert.Nil(t, err) assert.Equal(t, []string{resultURNRoot}, resp) @@ -1811,22 +1812,22 @@ func TestService(t *testing.T) { } uuidProvider.On("NewUUID").Return(backupUUID, nil) - depMod.On("GenerateDestination", context.TODO(), unitRoot).Return(destinationRoot, nil).Once() + depMod.On("GenerateDestination", ctx, unitRoot).Return(destinationRoot, nil).Once() resourceRepoFac.On("New", namespaceSpec, datastorer).Return(resourceRepo) - resourceRepo.On("GetByURN", destinationRoot.URN()).Return(resourceRoot, nil).Once() + resourceRepo.On("GetByURN", ctx, destinationRoot.URN()).Return(resourceRoot, nil).Once() dsRepo.On("GetByName", models.DestinationTypeBigquery.String()).Return(datastorer, nil) - datastorer.On("BackupResource", context.TODO(), backupResourceReqRoot). + datastorer.On("BackupResource", ctx, backupResourceReqRoot). Return(models.BackupResourceResponse{ResultURN: resultURNRoot, ResultSpec: resultSpecRoot}, nil).Once() - depMod.On("GenerateDestination", context.TODO(), unitDownstream).Return(destinationDownstream, nil).Once() - resourceRepo.On("GetByURN", destinationDownstream.URN()).Return(resourceDownstream, nil).Once() - datastorer.On("BackupResource", context.TODO(), backupResourceReqDownstream).Return(models.BackupResourceResponse{}, models.ErrUnsupportedResource).Once() + depMod.On("GenerateDestination", ctx, unitDownstream).Return(destinationDownstream, nil).Once() + resourceRepo.On("GetByURN", ctx, destinationDownstream.URN()).Return(resourceDownstream, nil).Once() + datastorer.On("BackupResource", ctx, backupResourceReqDownstream).Return(models.BackupResourceResponse{}, models.ErrUnsupportedResource).Once() backupRepoFac.On("New", projectSpec, datastorer).Return(backupRepo) - backupRepo.On("Save", backupSpec).Return(nil) + backupRepo.On("Save", ctx, backupSpec).Return(nil) service := datastore.NewService(resourceRepoFac, dsRepo, uuidProvider, backupRepoFac) - resp, err := service.BackupResource(context.TODO(), backupReq, []models.JobSpec{jobRoot, jobDownstream}) + resp, err := service.BackupResource(ctx, backupReq, []models.JobSpec{jobRoot, jobDownstream}) assert.Nil(t, err) assert.Equal(t, []string{resultURNRoot}, resp) @@ -1863,10 +1864,10 @@ func TestService(t *testing.T) { dsRepo.On("GetByName", datastoreName).Return(datastorer, nil) backupRepoFac.On("New", projectSpec, datastorer).Return(backupRepo) - backupRepo.On("GetAll").Return(backupSpecs, nil) + backupRepo.On("GetAll", ctx).Return(backupSpecs, nil) service := datastore.NewService(nil, dsRepo, nil, backupRepoFac) - resp, err := service.ListBackupResources(projectSpec, datastoreName) + resp, err := service.ListBackupResources(ctx, projectSpec, datastoreName) assert.Nil(t, err) assert.Equal(t, []models.BackupSpec{backupSpecs[0], backupSpecs[1]}, resp) @@ -1882,7 +1883,7 @@ func TestService(t *testing.T) { dsRepo.On("GetByName", datastoreName).Return(datastorer, errors.New(errorMsg)) service := datastore.NewService(nil, dsRepo, nil, nil) - resp, err := service.ListBackupResources(projectSpec, datastoreName) + resp, err := service.ListBackupResources(ctx, projectSpec, datastoreName) assert.Equal(t, errorMsg, err.Error()) assert.Equal(t, []models.BackupSpec{}, resp) @@ -1904,10 +1905,10 @@ func TestService(t *testing.T) { backupRepoFac.On("New", projectSpec, datastorer).Return(backupRepo) errorMsg := "unable to get backups" - backupRepo.On("GetAll").Return([]models.BackupSpec{}, errors.New(errorMsg)) + backupRepo.On("GetAll", ctx).Return([]models.BackupSpec{}, errors.New(errorMsg)) service := datastore.NewService(nil, dsRepo, nil, backupRepoFac) - resp, err := service.ListBackupResources(projectSpec, datastoreName) + resp, err := service.ListBackupResources(ctx, projectSpec, datastoreName) assert.Equal(t, errorMsg, err.Error()) assert.Equal(t, []models.BackupSpec{}, resp) @@ -1927,10 +1928,10 @@ func TestService(t *testing.T) { dsRepo.On("GetByName", datastoreName).Return(datastorer, nil) backupRepoFac.On("New", projectSpec, datastorer).Return(backupRepo) - backupRepo.On("GetAll").Return([]models.BackupSpec{}, store.ErrResourceNotFound) + backupRepo.On("GetAll", ctx).Return([]models.BackupSpec{}, store.ErrResourceNotFound) service := datastore.NewService(nil, dsRepo, nil, backupRepoFac) - resp, err := service.ListBackupResources(projectSpec, datastoreName) + resp, err := service.ListBackupResources(ctx, projectSpec, datastoreName) assert.Nil(t, err) assert.Equal(t, []models.BackupSpec{}, resp) @@ -1950,10 +1951,10 @@ func TestService(t *testing.T) { dsRepo.On("GetByName", datastoreName).Return(datastorer, nil) backupRepoFac.On("New", projectSpec, datastorer).Return(backupRepo) - backupRepo.On("GetAll").Return([]models.BackupSpec{backupSpecs[2]}, nil) + backupRepo.On("GetAll", ctx).Return([]models.BackupSpec{backupSpecs[2]}, nil) service := datastore.NewService(nil, dsRepo, nil, backupRepoFac) - resp, err := service.ListBackupResources(projectSpec, datastoreName) + resp, err := service.ListBackupResources(ctx, projectSpec, datastoreName) assert.Nil(t, err) assert.Equal(t, 0, len(resp)) diff --git a/ext/scheduler/prime/planner.go b/ext/scheduler/prime/planner.go index c442840b9e..6f96a5c6dd 100644 --- a/ext/scheduler/prime/planner.go +++ b/ext/scheduler/prime/planner.go @@ -84,7 +84,7 @@ func (p *Planner) leaderJobAllocation(ctx context.Context) { continue } - allocNodeID, allocRunIDs, err := p.getJobAllocations() + allocNodeID, allocRunIDs, err := p.getJobAllocations(ctx) if err != nil { p.errChan <- err return @@ -116,7 +116,7 @@ func (p *Planner) leaderJobAllocation(ctx context.Context) { // once the command is committed to raft log, we need to update the job state // from pending to accepted for _, runID := range allocRunIDs { - if err := p.jobRunRepoFac.New().UpdateStatus(runID, models.RunStateAccepted); err != nil { + if err := p.jobRunRepoFac.New().UpdateStatus(ctx, runID, models.RunStateAccepted); err != nil { p.errChan <- err return } @@ -140,8 +140,8 @@ func (p *Planner) leaderJobAllocation(ctx context.Context) { // and we will not scale down the cluster once its scaled up. This is a // temporary approach and ideally we should timeout jobs which are assigned // to jobs which went down and move them back to the pending state list. -func (p *Planner) getJobAllocations() (mostCapNodeID string, runIDs []uuid.UUID, err error) { - pendingJobRuns, err := p.jobRunRepoFac.New().GetByTrigger(models.TriggerManual, models.RunStatePending) +func (p *Planner) getJobAllocations(ctx context.Context) (mostCapNodeID string, runIDs []uuid.UUID, err error) { + pendingJobRuns, err := p.jobRunRepoFac.New().GetByTrigger(ctx, models.TriggerManual, models.RunStatePending) if err != nil { return } @@ -199,7 +199,7 @@ func (p *Planner) leaderJobReconcile(ctx context.Context) { runRepo := p.jobRunRepoFac.New() // check for non assignment, non terminating states - waitingJobs, err := runRepo.GetByTrigger(models.TriggerManual, models.RunStateAccepted, models.RunStateRunning) + waitingJobs, err := runRepo.GetByTrigger(ctx, models.TriggerManual, models.RunStateAccepted, models.RunStateRunning) if err != nil { p.errChan <- err continue @@ -223,7 +223,7 @@ func (p *Planner) leaderJobReconcile(ctx context.Context) { } if allocatedNode == "" { // move it back to assignment - if err := runRepo.Clear(currentRun.ID); err != nil { + if err := runRepo.Clear(ctx, currentRun.ID); err != nil { p.errChan <- err } p.l.Debug("cleared orphaned run for reassignment", "run id", currentRun.ID) @@ -269,7 +269,7 @@ func (p *Planner) leaderJobReconcile(ctx context.Context) { p.errChan <- err continue } - if err := runRepo.UpdateStatus(currentRun.ID, finalState); err != nil { + if err := runRepo.UpdateStatus(ctx, currentRun.ID, finalState); err != nil { p.errChan <- err continue } @@ -307,7 +307,7 @@ func (p *Planner) peerJobExecution(ctx context.Context) { p.errChan <- err return } - jobRun, namespaceSpec, err := runRepo.GetByID(runUUID) + jobRun, namespaceSpec, err := runRepo.GetByID(ctx, runUUID) if err != nil { p.errChan <- err return @@ -359,7 +359,7 @@ func (p *Planner) executeRun(ctx context.Context, namespace models.NamespaceSpec p.l.Warn("found a zombie instance", "job name", jobRun.Spec.Name, "task name", jobRun.Spec.Task.Unit.Info().Name) // cancel task and move back state to accepted - if err := instanceRepo.UpdateStatus(instance.ID, models.RunStateAccepted); err != nil { + if err := instanceRepo.UpdateStatus(ctx, instance.ID, models.RunStateAccepted); err != nil { return err } @@ -382,7 +382,7 @@ func (p *Planner) executeRun(ctx context.Context, namespace models.NamespaceSpec Type: models.InstanceTypeTask, Status: models.RunStateAccepted, } - if err := p.jobRunRepoFac.New().AddInstance(namespace, jobRun, newInstance); err != nil { + if err := p.jobRunRepoFac.New().AddInstance(ctx, namespace, jobRun, newInstance); err != nil { return err } @@ -396,7 +396,7 @@ func (p *Planner) executeRun(ctx context.Context, namespace models.NamespaceSpec if err != nil { return err } - if err := p.instanceRepoFac.New().UpdateStatus(instanceID, models.RunStateRunning); err != nil { + if err := p.instanceRepoFac.New().UpdateStatus(ctx, instanceID, models.RunStateRunning); err != nil { return err } @@ -410,13 +410,13 @@ func (p *Planner) executeRun(ctx context.Context, namespace models.NamespaceSpec p.l.Warn("job finished with non zero code", "code", finishCode, "job name", jobRun.Spec.Name) // mark instance failed - if err := instanceRepo.UpdateStatus(newInstance.ID, models.RunStateFailed); err != nil { + if err := instanceRepo.UpdateStatus(ctx, newInstance.ID, models.RunStateFailed); err != nil { return err } } // mark instance success - if err := instanceRepo.UpdateStatus(newInstance.ID, models.RunStateSuccess); err != nil { + if err := instanceRepo.UpdateStatus(ctx, newInstance.ID, models.RunStateSuccess); err != nil { return err } p.l.Info("finished executing job spec", "job name", jobRun.Spec.Name) diff --git a/ext/scheduler/prime/scheduler.go b/ext/scheduler/prime/scheduler.go index 6878dfd779..9338f17bbb 100644 --- a/ext/scheduler/prime/scheduler.go +++ b/ext/scheduler/prime/scheduler.go @@ -48,7 +48,7 @@ func (s *Scheduler) DeployJobs(ctx context.Context, namespace models.NamespaceSp repo := s.jobRunRepoFac.New() for _, runs := range jobRuns { - if err := repo.Save(namespace, runs); err != nil { + if err := repo.Save(ctx, namespace, runs); err != nil { return err } } diff --git a/go.mod b/go.mod index 31b4f6ae9d..210da3a727 100644 --- a/go.mod +++ b/go.mod @@ -27,10 +27,8 @@ require ( github.com/hashicorp/serf v0.8.2 github.com/huandu/xstrings v1.3.2 // indirect github.com/jhump/protoreflect v1.9.1-0.20210817181203-db1a327a393e // indirect - github.com/jinzhu/gorm v1.9.16 github.com/knadh/koanf v1.1.0 github.com/kushsharma/parallel v0.2.1 - github.com/lib/pq v1.10.2 github.com/mattn/go-sqlite3 v2.0.1+incompatible // indirect github.com/odpf/salt v0.0.0-20210919015538-3fd8ab22acea github.com/olekukonko/tablewriter v0.0.5 @@ -56,6 +54,8 @@ require ( gopkg.in/yaml.v2 v2.4.0 gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b gorm.io/datatypes v1.0.0 + gorm.io/driver/postgres v1.0.5 + gorm.io/gorm v1.21.16 ) go 1.16 diff --git a/go.sum b/go.sum index 60d38269b8..5d02871841 100644 --- a/go.sum +++ b/go.sum @@ -122,7 +122,6 @@ github.com/Microsoft/go-winio v0.4.15-0.20190919025122-fc70bd9a86b5/go.mod h1:tT github.com/Netflix/go-expect v0.0.0-20180615182759-c93bf25de8e8 h1:xzYJEypr/85nBpB11F9br+3HUrpgb+fcm5iADzXXYEw= github.com/Netflix/go-expect v0.0.0-20180615182759-c93bf25de8e8/go.mod h1:oX5x61PbNXchhh0oikYAH+4Pcfw5LKv21+Jnpr6r6Pc= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= -github.com/PuerkitoBio/goquery v1.5.1/go.mod h1:GsLWisAFVj4WgDibEWF4pvYnkVQBpKBKeU+7zCJoLcc= github.com/alecthomas/assert v0.0.0-20170929043011-405dbfeb8e38/go.mod h1:r7bzyVFMNntcxPZXK3/+KdruV1H5KSlyVY0gc+NgInI= github.com/alecthomas/chroma v0.8.2/go.mod h1:sko8vR34/90zvl5QdcUdvzL3J8NKjAUx9va9jPuFNoM= github.com/alecthomas/colour v0.0.0-20160524082231-60882d9e2721/go.mod h1:QO9JBoKquHd+jz9nshCh40fOfO+JzsoXy8qTHF68zU0= @@ -130,7 +129,6 @@ github.com/alecthomas/kong v0.2.4/go.mod h1:kQOmtJgV+Lb4aj+I2LEn40cbtawdWJ9Y8QLq github.com/alecthomas/repr v0.0.0-20180818092828-117648cd9897/go.mod h1:xTS7Pm1pD1mvyM075QCDSRqH6qRLXylzS24ZTpRiSzQ= github.com/andres-erbsen/clock v0.0.0-20160526145045-9e14626cd129 h1:MzBOUgng9orim59UnfUTLRjMpd09C5uEVQ6RPGeCaVI= github.com/andres-erbsen/clock v0.0.0-20160526145045-9e14626cd129/go.mod h1:rFgpPQZYZ8vdbc+48xibu8ALc3yeyd64IhHS+PU6Yyg= -github.com/andybalholm/cascadia v1.1.0/go.mod h1:GsXiBklL0woXo1j/WYWtSYYC4ouU9PqHO0sqidkEA4Y= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/apache/arrow/go/arrow v0.0.0-20200601151325-b2287a20f230/go.mod h1:QNYViu/X0HXDHw7m3KXzWSVXIbfUvJqBFe6Gj8/pYA0= github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o= @@ -192,6 +190,7 @@ github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGX github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/cncf/xds/go v0.0.0-20210312221358-fbca930ec8ed/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= +github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I= github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= github.com/cockroachdb/cockroach-go v0.0.0-20190925194419-606b3d062051/go.mod h1:XGLbWH/ujMcbPbhZq52Nv6UrCghb1yGn//133kEsvDk= github.com/containerd/containerd v1.4.0/go.mod h1:bC6axHOhabU15QhwfG7w5PipXdVtMXFTttgp+kVtyUA= @@ -209,7 +208,6 @@ github.com/danwakefield/fnmatch v0.0.0-20160403171240-cbb64ac3d964/go.mod h1:Xd9 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/denisenkom/go-mssqldb v0.0.0-20191124224453-732737034ffd/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU= github.com/denisenkom/go-mssqldb v0.0.0-20200428022330-06a60b6afbbc/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU= github.com/denisenkom/go-mssqldb v0.0.0-20200620013148-b91950f658ec/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU= github.com/denisenkom/go-mssqldb v0.9.0 h1:RSohk2RsiZqLZ0zCjtfn3S4Gp4exhpBWHyQ7D0yGjAk= @@ -244,8 +242,6 @@ github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.m github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= github.com/envoyproxy/go-control-plane v0.9.9-0.20210512163311-63b5d3c536b0/go.mod h1:hliV/p42l8fGbc6Y9bQ70uLwIvmJyVE5k4iMKlh8wCQ= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= -github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5 h1:Yzb9+7DPaBjB8zlTR87/ElzFsnQfuHnVUVqpZZIcV5Y= -github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5/go.mod h1:a2zkGnVExMxdzMo3M0Hi/3sEU+cWnZpSni0O6/Yb/P0= github.com/fatih/color v1.7.0 h1:DkWD4oS2D8LGGgTQ6IvwJJXSL5Vp2ffcQg58nFV38Ys= github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M= @@ -286,6 +282,7 @@ github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6Wezm github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM= github.com/gocql/gocql v0.0.0-20190301043612-f6df8288f9b4/go.mod h1:4Fw1eo5iaEhDUs8XyuhSVCVy52Jq3L+/3GJgYkwc+/0= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/gofrs/uuid v3.2.0+incompatible h1:y12jRkkFxsd7GpqdSZ+/KCs/fJbqpEXSGd4+jfEaewE= github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/gogo/protobuf v1.2.1/go.mod h1:hp+jE20tsWTFYpLwKvXlhS1hjn+gTNwPg2I6zVXpSg4= github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o= @@ -511,6 +508,7 @@ github.com/jackc/pgconn v1.7.0 h1:pwjzcYyfmz/HQOQlENvG1OcDqauTGaqlVahq934F0/U= github.com/jackc/pgconn v1.7.0/go.mod h1:sF/lPpNEMEOp+IYhyQGdAvrG20gWf6A1tKlr0v7JMeA= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= +github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 h1:JVX6jT/XfzNqIjye4717ITLaNwV9mWbJx0dLCpcRzdA= github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= @@ -551,13 +549,11 @@ github.com/jeremywohl/flatten v1.0.1/go.mod h1:4AmD/VxjWcI5SRB0n6szE2A6s2fsNHDLO github.com/jhump/protoreflect v1.6.0/go.mod h1:eaTn3RZAmMBcV0fifFvlm6VHNz3wSkYyXYWUh7ymB74= github.com/jhump/protoreflect v1.9.1-0.20210817181203-db1a327a393e h1:Yb4fEGk+GtBSNuvy5rs0ZJt/jtopc/z9azQaj3xbies= github.com/jhump/protoreflect v1.9.1-0.20210817181203-db1a327a393e/go.mod h1:7GcYQDdMU/O/BBrl/cX6PNHpXh6cenjd8pneu5yW7Tg= -github.com/jinzhu/gorm v1.9.16 h1:+IyIjPEABKRpsu/F8OvDPy9fyQlgsg2luMV2ZIH5i5o= -github.com/jinzhu/gorm v1.9.16/go.mod h1:G3LB3wezTOWM2ITLzPxEXgSkOXAntiLHS7UdBefADcs= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= -github.com/jinzhu/now v1.0.1/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= -github.com/jinzhu/now v1.1.1 h1:g39TucaRWyV3dwDO++eEc6qf8TVIQ/Da48WmqjZ3i7E= github.com/jinzhu/now v1.1.1/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/jinzhu/now v1.1.2 h1:eVKgfIdy9b6zbWBMgFpfDPoAMifwSZagU9HmEU6zgiI= +github.com/jinzhu/now v1.1.2/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jmespath/go-jmespath v0.0.0-20160202185014-0b12d6b521d8/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= @@ -605,7 +601,6 @@ github.com/kushsharma/parallel v0.2.1/go.mod h1:6JCy2+DRCUfZ0VFBUg6HG8IdDTDKuVL0 github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= -github.com/lib/pq v1.1.1/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.3.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.8.0/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= @@ -637,7 +632,6 @@ github.com/mattn/go-runewidth v0.0.13 h1:lTGmDsbAYt5DmK6OnoV7EuIF1wEIFAcxld6ypU4 github.com/mattn/go-runewidth v0.0.13/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/go-sqlite3 v1.9.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= github.com/mattn/go-sqlite3 v1.10.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= -github.com/mattn/go-sqlite3 v1.14.0/go.mod h1:JIl7NbARA7phWnGvh0LKTyg7S9BA+6gx71ShQilpsus= github.com/mattn/go-sqlite3 v1.14.3/go.mod h1:WVKg1VTActs4Qso6iwGbiFih2UIHo0ENGwNd0Lj+XmI= github.com/mattn/go-sqlite3 v2.0.1+incompatible h1:xQ15muvnzGBHpIpdrNi1DA5x0+TcBZzsIDwmw9uTHzw= github.com/mattn/go-sqlite3 v2.0.1+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= @@ -858,7 +852,6 @@ golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20191205180655-e7c4368fe9dd/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200414173820-0848c9571904/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200510223506-06a226fb4e37/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= @@ -1373,8 +1366,9 @@ gorm.io/driver/sqlserver v1.0.5/go.mod h1:WI/bfZ+s9TigYXe3hb3XjNaUP0TqmTdXl11pEC gorm.io/gorm v1.20.1/go.mod h1:0HFTzE/SqkGTzK6TlDPPQbAYCluiVvhzoA1+aVyzenw= gorm.io/gorm v1.20.2/go.mod h1:0HFTzE/SqkGTzK6TlDPPQbAYCluiVvhzoA1+aVyzenw= gorm.io/gorm v1.20.4/go.mod h1:0HFTzE/SqkGTzK6TlDPPQbAYCluiVvhzoA1+aVyzenw= -gorm.io/gorm v1.20.5 h1:g3tpSF9kggASzReK+Z3dYei1IJODLqNUbOjSuCczY8g= gorm.io/gorm v1.20.5/go.mod h1:0HFTzE/SqkGTzK6TlDPPQbAYCluiVvhzoA1+aVyzenw= +gorm.io/gorm v1.21.16 h1:YBIQLtP5PLfZQz59qfrq7xbrK7KWQ+JsXXCH/THlMqs= +gorm.io/gorm v1.21.16/go.mod h1:F+OptMscr0P2F2qU97WT1WimdH9GaQPoDW7AYd5i2Y0= gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/job/dependency_resolver.go b/job/dependency_resolver.go index a6390ea856..44dfcc6ae9 100644 --- a/job/dependency_resolver.go +++ b/job/dependency_resolver.go @@ -38,7 +38,7 @@ func (r *dependencyResolver) Resolve(ctx context.Context, projectSpec models.Pro } // resolve statically defined dependencies - jobSpec, err = r.resolveStaticDependencies(jobSpec, projectSpec, projectJobSpecRepo) + jobSpec, err = r.resolveStaticDependencies(ctx, jobSpec, projectSpec, projectJobSpecRepo) if err != nil { return models.JobSpec{}, err } @@ -70,7 +70,7 @@ func (r *dependencyResolver) resolveInferredDependencies(ctx context.Context, jo // get job spec of these destinations and append to current jobSpec for _, depDestination := range jobDependencies { - depSpec, depProj, err := projectJobSpecRepo.GetByDestination(depDestination) + depSpec, depProj, err := projectJobSpecRepo.GetByDestination(ctx, depDestination) if err != nil { if err == store.ErrResourceNotFound { // should not fail for unknown dependency @@ -98,7 +98,7 @@ func (r *dependencyResolver) getJobSpecDependencyType(dependency models.JobSpecD // update named (explicit/static) dependencies if unresolved with its spec model // this can normally happen when reading specs from a store[local/postgres] -func (r *dependencyResolver) resolveStaticDependencies(jobSpec models.JobSpec, projectSpec models.ProjectSpec, +func (r *dependencyResolver) resolveStaticDependencies(ctx context.Context, jobSpec models.JobSpec, projectSpec models.ProjectSpec, projectJobSpecRepo store.ProjectJobSpecRepository) (models.JobSpec, error) { // update static dependencies if unresolved with its spec model for depName, depSpec := range jobSpec.Dependencies { @@ -106,7 +106,7 @@ func (r *dependencyResolver) resolveStaticDependencies(jobSpec models.JobSpec, p switch depSpec.Type { case models.JobSpecDependencyTypeIntra: { - job, _, err := projectJobSpecRepo.GetByName(depName) + job, _, err := projectJobSpecRepo.GetByName(ctx, depName) if err != nil { return models.JobSpec{}, errors.Wrapf(err, "%s for job %s", ErrUnknownLocalDependency, depName) } @@ -123,7 +123,7 @@ func (r *dependencyResolver) resolveStaticDependencies(jobSpec models.JobSpec, p } projectName := depParts[0] jobName := depParts[1] - job, proj, err := projectJobSpecRepo.GetByNameForProject(projectName, jobName) + job, proj, err := projectJobSpecRepo.GetByNameForProject(ctx, projectName, jobName) if err != nil { return models.JobSpec{}, errors.Wrapf(err, "%s for job %s", ErrUnknownCrossProjectDependency, depName) } diff --git a/job/dependency_resolver_test.go b/job/dependency_resolver_test.go index 11eb858cdf..f5683817d6 100644 --- a/job/dependency_resolver_test.go +++ b/job/dependency_resolver_test.go @@ -92,7 +92,7 @@ func TestDependencyResolver(t *testing.T) { } jobSpecRepository := new(mock.ProjectJobSpecRepository) - jobSpecRepository.On("GetByDestination", "project.dataset.table2_destination").Return(jobSpec2, projectSpec, nil) + jobSpecRepository.On("GetByDestination", ctx, "project.dataset.table2_destination").Return(jobSpec2, projectSpec, nil) defer jobSpecRepository.AssertExpectations(t) projectJobSpecRepoFactory := new(mock.ProjectJobSpecRepoFactory) @@ -109,8 +109,8 @@ func TestDependencyResolver(t *testing.T) { } // task dependencies - execUnit1.On("GenerateDependencies", context.TODO(), unitData).Return(&models.GenerateDependenciesResponse{Dependencies: []string{"project.dataset.table2_destination"}}, nil) - execUnit1.On("GenerateDependencies", context.TODO(), unitData2).Return(&models.GenerateDependenciesResponse{}, nil) + execUnit1.On("GenerateDependencies", ctx, unitData).Return(&models.GenerateDependenciesResponse{Dependencies: []string{"project.dataset.table2_destination"}}, nil) + execUnit1.On("GenerateDependencies", ctx, unitData2).Return(&models.GenerateDependenciesResponse{}, nil) // hook dependency hookUnit1.On("PluginInfo").Return(&models.PluginInfoResponse{ @@ -196,7 +196,7 @@ func TestDependencyResolver(t *testing.T) { } jobSpecRepository := new(mock.ProjectJobSpecRepository) - jobSpecRepository.On("GetByDestination", "project.dataset.table2_destination").Return(jobSpec2, projectSpec, nil) + jobSpecRepository.On("GetByDestination", ctx, "project.dataset.table2_destination").Return(jobSpec2, projectSpec, nil) defer jobSpecRepository.AssertExpectations(t) projectJobSpecRepoFactory := new(mock.ProjectJobSpecRepoFactory) @@ -212,10 +212,10 @@ func TestDependencyResolver(t *testing.T) { Project: projectSpec, } - execUnit.On("GenerateDependencies", context.TODO(), unitData).Return(&models.GenerateDependenciesResponse{ + execUnit.On("GenerateDependencies", ctx, unitData).Return(&models.GenerateDependenciesResponse{ Dependencies: []string{"project.dataset.table2_destination"}, }, nil) - execUnit.On("GenerateDependencies", context.TODO(), unitData2).Return(&models.GenerateDependenciesResponse{}, nil) + execUnit.On("GenerateDependencies", ctx, unitData2).Return(&models.GenerateDependenciesResponse{}, nil) resolver := job.NewDependencyResolver(projectJobSpecRepoFactory) resolvedJobSpec1, err := resolver.Resolve(ctx, projectSpec, jobSpec1, nil) @@ -274,7 +274,7 @@ func TestDependencyResolver(t *testing.T) { } jobSpecRepository := new(mock.ProjectJobSpecRepository) - jobSpecRepository.On("GetByDestination", "project.dataset.table2_destination").Return(jobSpec2, projectSpec, errors.New("random error")) + jobSpecRepository.On("GetByDestination", ctx, "project.dataset.table2_destination").Return(jobSpec2, projectSpec, errors.New("random error")) defer jobSpecRepository.AssertExpectations(t) projectJobSpecRepoFactory := new(mock.ProjectJobSpecRepoFactory) @@ -359,7 +359,7 @@ func TestDependencyResolver(t *testing.T) { } jobSpecRepository := new(mock.ProjectJobSpecRepository) - jobSpecRepository.On("GetByDestination", "project.dataset.table3_destination").Return(nil, nil, errors.New("spec not found")) + jobSpecRepository.On("GetByDestination", ctx, "project.dataset.table3_destination").Return(nil, nil, errors.New("spec not found")) defer jobSpecRepository.AssertExpectations(t) projectJobSpecRepoFactory := new(mock.ProjectJobSpecRepoFactory) @@ -421,15 +421,15 @@ func TestDependencyResolver(t *testing.T) { } jobSpecRepository := new(mock.ProjectJobSpecRepository) - jobSpecRepository.On("GetByDestination", "project.dataset.table1_destination").Return(jobSpec1, projectSpec, nil) - jobSpecRepository.On("GetByName", "static_dep").Return(nil, errors.New("spec not found")) + jobSpecRepository.On("GetByDestination", ctx, "project.dataset.table1_destination").Return(jobSpec1, projectSpec, nil) + jobSpecRepository.On("GetByName", ctx, "static_dep").Return(nil, errors.New("spec not found")) defer jobSpecRepository.AssertExpectations(t) projectJobSpecRepoFactory := new(mock.ProjectJobSpecRepoFactory) projectJobSpecRepoFactory.On("New", projectSpec).Return(jobSpecRepository) defer projectJobSpecRepoFactory.AssertExpectations(t) unitData2 := models.GenerateDependenciesRequest{Config: models.PluginConfigs{}.FromJobSpec(jobSpec2.Task.Config), Assets: models.PluginAssets{}.FromJobSpec(jobSpec2.Assets), Project: projectSpec} - execUnit.On("GenerateDependencies", context.Background(), unitData2).Return(&models.GenerateDependenciesResponse{ + execUnit.On("GenerateDependencies", ctx, unitData2).Return(&models.GenerateDependenciesResponse{ Dependencies: []string{"project.dataset.table1_destination"}, }, nil) @@ -482,7 +482,7 @@ func TestDependencyResolver(t *testing.T) { } jobSpecRepository := new(mock.ProjectJobSpecRepository) - jobSpecRepository.On("GetByDestination", "project.dataset.table1_destination").Return(jobSpec1, projectSpec, nil) + jobSpecRepository.On("GetByDestination", ctx, "project.dataset.table1_destination").Return(jobSpec1, projectSpec, nil) defer jobSpecRepository.AssertExpectations(t) projectJobSpecRepoFactory := new(mock.ProjectJobSpecRepoFactory) projectJobSpecRepoFactory.On("New", projectSpec).Return(jobSpecRepository) @@ -562,8 +562,8 @@ func TestDependencyResolver(t *testing.T) { } jobSpecRepository := new(mock.ProjectJobSpecRepository) - jobSpecRepository.On("GetByDestination", "project.dataset.table2_destination").Return(jobSpec2, projectSpec, nil) - jobSpecRepository.On("GetByName", "test3").Return(jobSpec3, namespaceSpec, nil) + jobSpecRepository.On("GetByDestination", ctx, "project.dataset.table2_destination").Return(jobSpec2, projectSpec, nil) + jobSpecRepository.On("GetByName", ctx, "test3").Return(jobSpec3, namespaceSpec, nil) defer jobSpecRepository.AssertExpectations(t) projectJobSpecRepoFactory := new(mock.ProjectJobSpecRepoFactory) @@ -579,10 +579,10 @@ func TestDependencyResolver(t *testing.T) { Project: projectSpec, } - execUnit.On("GenerateDependencies", context.Background(), unitData).Return(&models.GenerateDependenciesResponse{ + execUnit.On("GenerateDependencies", ctx, unitData).Return(&models.GenerateDependenciesResponse{ Dependencies: []string{"project.dataset.table2_destination"}, }, nil) - execUnit.On("GenerateDependencies", context.Background(), unitData2).Return(&models.GenerateDependenciesResponse{}, nil) + execUnit.On("GenerateDependencies", ctx, unitData2).Return(&models.GenerateDependenciesResponse{}, nil) resolver := job.NewDependencyResolver(projectJobSpecRepoFactory) resolvedJobSpec1, err := resolver.Resolve(ctx, projectSpec, jobSpec1, nil) @@ -693,9 +693,9 @@ func TestDependencyResolver(t *testing.T) { } jobSpecRepository := new(mock.ProjectJobSpecRepository) - jobSpecRepository.On("GetByDestination", "project.dataset.table2_destination").Return(jobSpec2, projectSpec, nil) - jobSpecRepository.On("GetByDestination", "project.dataset.table2_external_destination").Return(jobSpecExternal, externalProjectSpec, nil) - jobSpecRepository.On("GetByNameForProject", externalProjectName, "test3").Return(jobSpec3, externalProjectSpec, nil) + jobSpecRepository.On("GetByDestination", ctx, "project.dataset.table2_destination").Return(jobSpec2, projectSpec, nil) + jobSpecRepository.On("GetByDestination", ctx, "project.dataset.table2_external_destination").Return(jobSpecExternal, externalProjectSpec, nil) + jobSpecRepository.On("GetByNameForProject", ctx, externalProjectName, "test3").Return(jobSpec3, externalProjectSpec, nil) defer jobSpecRepository.AssertExpectations(t) projectJobSpecRepoFactory := new(mock.ProjectJobSpecRepoFactory) diff --git a/job/job.go b/job/job.go index d430a4e689..ce5060b1bc 100644 --- a/job/job.go +++ b/job/job.go @@ -1,11 +1,15 @@ package job -import "github.com/odpf/optimus/models" +import ( + "context" + + "github.com/odpf/optimus/models" +) // SpecRepository represents a storage interface for Job specifications at a namespace level type SpecRepository interface { - Save(models.JobSpec) error - GetByName(string) (models.JobSpec, error) - GetAll() ([]models.JobSpec, error) - Delete(string) error + Save(context.Context, models.JobSpec) error + GetByName(context.Context, string) (models.JobSpec, error) + GetAll(context.Context) ([]models.JobSpec, error) + Delete(context.Context, string) error } diff --git a/job/replay.go b/job/replay.go index 31ca938696..309240666e 100644 --- a/job/replay.go +++ b/job/replay.go @@ -162,7 +162,7 @@ func getRunsBetweenDates(start time.Time, end time.Time, schedule string) ([]tim func (srv *Service) GetReplayStatus(ctx context.Context, replayRequest models.ReplayRequest) (models.ReplayState, error) { // Get replay - replaySpec, err := srv.replayManager.GetReplay(replayRequest.ID) + replaySpec, err := srv.replayManager.GetReplay(ctx, replayRequest.ID) if err != nil { return models.ReplayState{}, err } @@ -225,6 +225,6 @@ func (srv *Service) populateDownstreamRunsWithStatus(ctx context.Context, projec return parentNode, nil } -func (srv *Service) GetReplayList(projectUUID uuid.UUID) ([]models.ReplaySpec, error) { - return srv.replayManager.GetReplayList(projectUUID) +func (srv *Service) GetReplayList(ctx context.Context, projectUUID uuid.UUID) ([]models.ReplaySpec, error) { + return srv.replayManager.GetReplayList(ctx, projectUUID) } diff --git a/job/replay_manager.go b/job/replay_manager.go index 12f99759f9..034b59da7d 100644 --- a/job/replay_manager.go +++ b/job/replay_manager.go @@ -106,7 +106,7 @@ func (m *Manager) Replay(ctx context.Context, reqInput models.ReplayRequest) (st } // could get cancelled later if queue is full - if err = replaySpecRepo.Insert(&replay); err != nil { + if err = replaySpecRepo.Insert(ctx, &replay); err != nil { return "", err } @@ -115,7 +115,7 @@ func (m *Manager) Replay(ctx context.Context, reqInput models.ReplayRequest) (st return reqInput.ID.String(), nil default: // all workers busy, mark the inserted request as cancelled - _ = replaySpecRepo.UpdateStatus(reqInput.ID, models.ReplayStatusCancelled, models.ReplayMessage{ + _ = replaySpecRepo.UpdateStatus(ctx, reqInput.ID, models.ReplayStatusCancelled, models.ReplayMessage{ Type: models.ReplayStatusCancelled, Message: ErrRequestQueueFull.Error(), }) @@ -153,13 +153,13 @@ func (m *Manager) SchedulerSyncer() { } // GetReplay using UUID -func (m *Manager) GetReplay(replayUUID uuid.UUID) (models.ReplaySpec, error) { - return m.replaySpecRepoFac.New().GetByID(replayUUID) +func (m *Manager) GetReplay(ctx context.Context, replayUUID uuid.UUID) (models.ReplaySpec, error) { + return m.replaySpecRepoFac.New().GetByID(ctx, replayUUID) } // GetReplayList using Project ID -func (m *Manager) GetReplayList(projectUUID uuid.UUID) ([]models.ReplaySpec, error) { - replays, err := m.replaySpecRepoFac.New().GetByProjectID(projectUUID) +func (m *Manager) GetReplayList(ctx context.Context, projectUUID uuid.UUID) ([]models.ReplaySpec, error) { + replays, err := m.replaySpecRepoFac.New().GetByProjectID(ctx, projectUUID) if err != nil { if err == store.ErrResourceNotFound { return []models.ReplaySpec{}, nil diff --git a/job/replay_manager_test.go b/job/replay_manager_test.go index 678f2b5072..77b799cc30 100644 --- a/job/replay_manager_test.go +++ b/job/replay_manager_test.go @@ -129,7 +129,7 @@ func TestReplayManager(t *testing.T) { EndDate: endDate, Status: models.ReplayStatusAccepted, } - replayRepository.On("Insert", toInsertReplaySpec).Return(errors.New(errMessage)) + replayRepository.On("Insert", ctx, toInsertReplaySpec).Return(errors.New(errMessage)) worker := mock.NewReplayWorker() replayWorkerFact := new(mock.ReplayWorkerFactory) @@ -195,7 +195,7 @@ func TestReplayManager(t *testing.T) { EndDate: endDate, Status: models.ReplayStatusAccepted, } - replayRepository.On("Insert", toInsertReplaySpec).Return(nil) + replayRepository.On("Insert", ctx, toInsertReplaySpec).Return(nil) worker := mock.NewReplayWorker() replayRequestToProcess := replayRequest @@ -245,7 +245,7 @@ func TestReplayManager(t *testing.T) { // Status: models.ReplayStatusAccepted, // } // - // replayRepository.On("Insert", toInsertReplaySpec).Return(nil).Times(4) + // replayRepository.On("Insert", ctx, toInsertReplaySpec).Return(nil).Times(4) // // // other workers should not be closed before encounter full state. // // replay will be cancelled when workers are full. @@ -315,14 +315,14 @@ func TestReplayManager(t *testing.T) { replayRepository := new(mock.ReplayRepository) defer replayRepository.AssertExpectations(t) - replayRepository.On("GetByID", replayUUID).Return(replaySpec, nil) + replayRepository.On("GetByID", ctx, replayUUID).Return(replaySpec, nil) replaySpecRepoFac := new(mock.ReplaySpecRepoFactory) defer replaySpecRepoFac.AssertExpectations(t) replaySpecRepoFac.On("New").Return(replayRepository) replayManager := job.NewManager(log, nil, replaySpecRepoFac, nil, job.ReplayManagerConfig{}, nil, nil, nil) - replayResult, err := replayManager.GetReplay(replayUUID) + replayResult, err := replayManager.GetReplay(ctx, replayUUID) assert.Nil(t, err) assert.Equal(t, replaySpec, replayResult) @@ -335,14 +335,14 @@ func TestReplayManager(t *testing.T) { replayRepository := new(mock.ReplayRepository) defer replayRepository.AssertExpectations(t) - replayRepository.On("GetByID", replayUUID).Return(models.ReplaySpec{}, store.ErrResourceNotFound) + replayRepository.On("GetByID", ctx, replayUUID).Return(models.ReplaySpec{}, store.ErrResourceNotFound) replaySpecRepoFac := new(mock.ReplaySpecRepoFactory) defer replaySpecRepoFac.AssertExpectations(t) replaySpecRepoFac.On("New").Return(replayRepository) replayManager := job.NewManager(log, nil, replaySpecRepoFac, nil, job.ReplayManagerConfig{}, nil, nil, nil) - replayResult, err := replayManager.GetReplay(replayUUID) + replayResult, err := replayManager.GetReplay(ctx, replayUUID) assert.Equal(t, err, store.ErrResourceNotFound) assert.Equal(t, models.ReplaySpec{}, replayResult) @@ -372,14 +372,14 @@ func TestReplayManager(t *testing.T) { replayRepository := new(mock.ReplayRepository) defer replayRepository.AssertExpectations(t) - replayRepository.On("GetByProjectID", projectUUID).Return(replaySpecs, nil) + replayRepository.On("GetByProjectID", ctx, projectUUID).Return(replaySpecs, nil) replaySpecRepoFac := new(mock.ReplaySpecRepoFactory) defer replaySpecRepoFac.AssertExpectations(t) replaySpecRepoFac.On("New").Return(replayRepository) replayManager := job.NewManager(log, nil, replaySpecRepoFac, nil, job.ReplayManagerConfig{}, nil, nil, nil) - replayListResult, err := replayManager.GetReplayList(projectUUID) + replayListResult, err := replayManager.GetReplayList(ctx, projectUUID) assert.Nil(t, err) assert.Equal(t, replaySpecs, replayListResult) @@ -415,14 +415,14 @@ func TestReplayManager(t *testing.T) { replayRepository := new(mock.ReplayRepository) defer replayRepository.AssertExpectations(t) - replayRepository.On("GetByProjectID", projectUUID).Return(replaySpecs, nil) + replayRepository.On("GetByProjectID", ctx, projectUUID).Return(replaySpecs, nil) replaySpecRepoFac := new(mock.ReplaySpecRepoFactory) defer replaySpecRepoFac.AssertExpectations(t) replaySpecRepoFac.On("New").Return(replayRepository) replayManager := job.NewManager(log, nil, replaySpecRepoFac, nil, job.ReplayManagerConfig{}, nil, nil, nil) - replayListResult, err := replayManager.GetReplayList(projectUUID) + replayListResult, err := replayManager.GetReplayList(ctx, projectUUID) expectedReplaySpecs := []models.ReplaySpec{replaySpecs[0]} assert.Nil(t, err) @@ -436,14 +436,14 @@ func TestReplayManager(t *testing.T) { replayRepository := new(mock.ReplayRepository) defer replayRepository.AssertExpectations(t) - replayRepository.On("GetByProjectID", projectUUID).Return([]models.ReplaySpec{}, store.ErrResourceNotFound) + replayRepository.On("GetByProjectID", ctx, projectUUID).Return([]models.ReplaySpec{}, store.ErrResourceNotFound) replaySpecRepoFac := new(mock.ReplaySpecRepoFactory) defer replaySpecRepoFac.AssertExpectations(t) replaySpecRepoFac.On("New").Return(replayRepository) replayManager := job.NewManager(log, nil, replaySpecRepoFac, nil, job.ReplayManagerConfig{}, nil, nil, nil) - replayResult, err := replayManager.GetReplayList(projectUUID) + replayResult, err := replayManager.GetReplayList(ctx, projectUUID) assert.Nil(t, err) assert.Equal(t, []models.ReplaySpec{}, replayResult) @@ -457,14 +457,14 @@ func TestReplayManager(t *testing.T) { replayRepository := new(mock.ReplayRepository) defer replayRepository.AssertExpectations(t) errorMsg := "unable to get list of replays" - replayRepository.On("GetByProjectID", projectUUID).Return([]models.ReplaySpec{}, errors.New(errorMsg)) + replayRepository.On("GetByProjectID", ctx, projectUUID).Return([]models.ReplaySpec{}, errors.New(errorMsg)) replaySpecRepoFac := new(mock.ReplaySpecRepoFactory) defer replaySpecRepoFac.AssertExpectations(t) replaySpecRepoFac.On("New").Return(replayRepository) replayManager := job.NewManager(log, nil, replaySpecRepoFac, nil, job.ReplayManagerConfig{}, nil, nil, nil) - replayResult, err := replayManager.GetReplayList(projectUUID) + replayResult, err := replayManager.GetReplayList(ctx, projectUUID) assert.Equal(t, errorMsg, err.Error()) assert.Equal(t, []models.ReplaySpec{}, replayResult) diff --git a/job/replay_syncer.go b/job/replay_syncer.go index 6c449a7095..bd775eea9b 100644 --- a/job/replay_syncer.go +++ b/job/replay_syncer.go @@ -36,15 +36,15 @@ func NewReplaySyncer(log log.Logger, replaySpecFactory ReplaySpecRepoFactory, pr } } -func (s Syncer) Sync(context context.Context, runTimeout time.Duration) error { +func (s Syncer) Sync(ctx context.Context, runTimeout time.Duration) error { replaySpecRepo := s.replaySpecFactory.New() - projectSpecs, err := s.projectRepoFactory.New().GetAll() + projectSpecs, err := s.projectRepoFactory.New().GetAll(ctx) if err != nil { return err } for _, projectSpec := range projectSpecs { - replaySpecs, err := replaySpecRepo.GetByProjectIDAndStatus(projectSpec.ID, ReplayStatusToSynced) + replaySpecs, err := replaySpecRepo.GetByProjectIDAndStatus(ctx, projectSpec.ID, ReplayStatusToSynced) if err != nil { if err == store.ErrResourceNotFound { return nil @@ -55,14 +55,14 @@ func (s Syncer) Sync(context context.Context, runTimeout time.Duration) error { for _, replaySpec := range replaySpecs { // sync end state of replayed replays if replaySpec.Status == models.ReplayStatusReplayed { - if err := s.syncRunningReplay(context, projectSpec, replaySpec, replaySpecRepo); err != nil { + if err := s.syncRunningReplay(ctx, projectSpec, replaySpec, replaySpecRepo); err != nil { return err } continue } // sync timed out replays for accepted and in progress replays - if err := s.syncTimedOutReplay(replaySpecRepo, replaySpec, runTimeout); err != nil { + if err := s.syncTimedOutReplay(ctx, replaySpecRepo, replaySpec, runTimeout); err != nil { return err } } @@ -70,10 +70,10 @@ func (s Syncer) Sync(context context.Context, runTimeout time.Duration) error { return nil } -func (s Syncer) syncTimedOutReplay(replaySpecRepo store.ReplaySpecRepository, replaySpec models.ReplaySpec, runTimeout time.Duration) error { +func (s Syncer) syncTimedOutReplay(ctx context.Context, replaySpecRepo store.ReplaySpecRepository, replaySpec models.ReplaySpec, runTimeout time.Duration) error { runningTime := s.Now().Sub(replaySpec.CreatedAt) if runningTime > runTimeout { - if updateStatusErr := replaySpecRepo.UpdateStatus(replaySpec.ID, models.ReplayStatusFailed, models.ReplayMessage{ + if updateStatusErr := replaySpecRepo.UpdateStatus(ctx, replaySpec.ID, models.ReplayStatusFailed, models.ReplayMessage{ Type: ReplayRunTimeout, Message: fmt.Sprintf("replay has been running since %s", replaySpec.CreatedAt.UTC().Format(TimestampLogFormat)), }); updateStatusErr != nil { @@ -84,13 +84,13 @@ func (s Syncer) syncTimedOutReplay(replaySpecRepo store.ReplaySpecRepository, re return nil } -func (s Syncer) syncRunningReplay(context context.Context, projectSpec models.ProjectSpec, replaySpec models.ReplaySpec, replaySpecRepo store.ReplaySpecRepository) error { - stateSummary, err := s.checkInstanceState(context, projectSpec, replaySpec) +func (s Syncer) syncRunningReplay(ctx context.Context, projectSpec models.ProjectSpec, replaySpec models.ReplaySpec, replaySpecRepo store.ReplaySpecRepository) error { + stateSummary, err := s.checkInstanceState(ctx, projectSpec, replaySpec) if err != nil { return err } - return updateCompletedReplays(s.l, stateSummary, replaySpecRepo, replaySpec.ID) + return updateCompletedReplays(ctx, s.l, stateSummary, replaySpecRepo, replaySpec.ID) } func (s Syncer) checkInstanceState(ctx context.Context, projectSpec models.ProjectSpec, replaySpec models.ReplaySpec) (map[models.JobRunState]int, error) { @@ -112,9 +112,9 @@ func (s Syncer) checkInstanceState(ctx context.Context, projectSpec models.Proje return stateSummary, nil } -func updateCompletedReplays(l log.Logger, stateSummary map[models.JobRunState]int, replaySpecRepo store.ReplaySpecRepository, replayID uuid.UUID) error { +func updateCompletedReplays(ctx context.Context, l log.Logger, stateSummary map[models.JobRunState]int, replaySpecRepo store.ReplaySpecRepository, replayID uuid.UUID) error { if stateSummary[models.RunStateRunning] == 0 && stateSummary[models.RunStateFailed] > 0 { - if updateStatusErr := replaySpecRepo.UpdateStatus(replayID, models.ReplayStatusFailed, models.ReplayMessage{ + if updateStatusErr := replaySpecRepo.UpdateStatus(ctx, replayID, models.ReplayStatusFailed, models.ReplayMessage{ Type: models.ReplayStatusFailed, Message: ReplayMessageFailed, }); updateStatusErr != nil { @@ -122,7 +122,7 @@ func updateCompletedReplays(l log.Logger, stateSummary map[models.JobRunState]in return updateStatusErr } } else if stateSummary[models.RunStateRunning] == 0 && stateSummary[models.RunStateFailed] == 0 && stateSummary[models.RunStateSuccess] > 0 { - if updateStatusErr := replaySpecRepo.UpdateStatus(replayID, models.ReplayStatusSuccess, models.ReplayMessage{ + if updateStatusErr := replaySpecRepo.UpdateStatus(ctx, replayID, models.ReplayStatusSuccess, models.ReplayMessage{ Type: models.ReplayStatusSuccess, Message: ReplayMessageSuccess, }); updateStatusErr != nil { diff --git a/job/replay_syncer_test.go b/job/replay_syncer_test.go index eb97f6a83f..c93fdc6ef3 100644 --- a/job/replay_syncer_test.go +++ b/job/replay_syncer_test.go @@ -81,7 +81,7 @@ func TestReplaySyncer(t *testing.T) { t.Run("Sync", func(t *testing.T) { t.Run("should not return error when no replay with sync criteria found", func(t *testing.T) { projectRepository := new(mock.ProjectRepository) - projectRepository.On("GetAll").Return(projectSpecs, nil) + projectRepository.On("GetAll", ctx).Return(projectSpecs, nil) defer projectRepository.AssertExpectations(t) projectRepoFactory := new(mock.ProjectRepoFactory) @@ -90,7 +90,7 @@ func TestReplaySyncer(t *testing.T) { replayRepository := new(mock.ReplayRepository) defer replayRepository.AssertExpectations(t) - replayRepository.On("GetByProjectIDAndStatus", projectSpecs[0].ID, job.ReplayStatusToSynced).Return([]models.ReplaySpec{}, store.ErrResourceNotFound) + replayRepository.On("GetByProjectIDAndStatus", ctx, projectSpecs[0].ID, job.ReplayStatusToSynced).Return([]models.ReplaySpec{}, store.ErrResourceNotFound) replaySpecRepoFac := new(mock.ReplaySpecRepoFactory) defer replaySpecRepoFac.AssertExpectations(t) @@ -103,7 +103,7 @@ func TestReplaySyncer(t *testing.T) { }) t.Run("should return error when fetching replays failed", func(t *testing.T) { projectRepository := new(mock.ProjectRepository) - projectRepository.On("GetAll").Return(projectSpecs, nil) + projectRepository.On("GetAll", ctx).Return(projectSpecs, nil) defer projectRepository.AssertExpectations(t) projectRepoFactory := new(mock.ProjectRepoFactory) @@ -113,7 +113,7 @@ func TestReplaySyncer(t *testing.T) { replayRepository := new(mock.ReplayRepository) defer replayRepository.AssertExpectations(t) errorMsg := "fetching replay error" - replayRepository.On("GetByProjectIDAndStatus", projectSpecs[0].ID, job.ReplayStatusToSynced).Return([]models.ReplaySpec{}, errors.New(errorMsg)) + replayRepository.On("GetByProjectIDAndStatus", ctx, projectSpecs[0].ID, job.ReplayStatusToSynced).Return([]models.ReplaySpec{}, errors.New(errorMsg)) replaySpecRepoFac := new(mock.ReplaySpecRepoFactory) defer replaySpecRepoFac.AssertExpectations(t) @@ -126,7 +126,7 @@ func TestReplaySyncer(t *testing.T) { }) t.Run("should mark state of running replay to success if all instances are success", func(t *testing.T) { projectRepository := new(mock.ProjectRepository) - projectRepository.On("GetAll").Return(projectSpecs, nil) + projectRepository.On("GetAll", ctx).Return(projectSpecs, nil) defer projectRepository.AssertExpectations(t) projectRepoFactory := new(mock.ProjectRepoFactory) @@ -135,7 +135,7 @@ func TestReplaySyncer(t *testing.T) { replayRepository := new(mock.ReplayRepository) defer replayRepository.AssertExpectations(t) - replayRepository.On("GetByProjectIDAndStatus", projectSpecs[0].ID, job.ReplayStatusToSynced).Return(activeReplaySpec, nil) + replayRepository.On("GetByProjectIDAndStatus", ctx, projectSpecs[0].ID, job.ReplayStatusToSynced).Return(activeReplaySpec, nil) replaySpecRepoFac := new(mock.ReplaySpecRepoFactory) defer replaySpecRepoFac.AssertExpectations(t) @@ -160,7 +160,7 @@ func TestReplaySyncer(t *testing.T) { Type: models.ReplayStatusSuccess, Message: job.ReplayMessageSuccess, } - replayRepository.On("UpdateStatus", activeReplayUUID, models.ReplayStatusSuccess, successReplayMessage).Return(nil) + replayRepository.On("UpdateStatus", ctx, activeReplayUUID, models.ReplayStatusSuccess, successReplayMessage).Return(nil) replaySyncer := job.NewReplaySyncer(log, replaySpecRepoFac, projectRepoFactory, scheduler, time.Now) err := replaySyncer.Sync(context.TODO(), runTimeout) @@ -169,7 +169,7 @@ func TestReplaySyncer(t *testing.T) { }) t.Run("should mark state of running replay to failed if no longer running instance and one of instances is failed", func(t *testing.T) { projectRepository := new(mock.ProjectRepository) - projectRepository.On("GetAll").Return(projectSpecs, nil) + projectRepository.On("GetAll", ctx).Return(projectSpecs, nil) defer projectRepository.AssertExpectations(t) projectRepoFactory := new(mock.ProjectRepoFactory) @@ -178,7 +178,7 @@ func TestReplaySyncer(t *testing.T) { replayRepository := new(mock.ReplayRepository) defer replayRepository.AssertExpectations(t) - replayRepository.On("GetByProjectIDAndStatus", projectSpecs[0].ID, job.ReplayStatusToSynced).Return(activeReplaySpec, nil) + replayRepository.On("GetByProjectIDAndStatus", ctx, projectSpecs[0].ID, job.ReplayStatusToSynced).Return(activeReplaySpec, nil) replaySpecRepoFac := new(mock.ReplaySpecRepoFactory) defer replaySpecRepoFac.AssertExpectations(t) @@ -203,7 +203,7 @@ func TestReplaySyncer(t *testing.T) { Type: models.ReplayStatusFailed, Message: job.ReplayMessageFailed, } - replayRepository.On("UpdateStatus", activeReplayUUID, models.ReplayStatusFailed, failedReplayMessage).Return(nil) + replayRepository.On("UpdateStatus", ctx, activeReplayUUID, models.ReplayStatusFailed, failedReplayMessage).Return(nil) replaySyncer := job.NewReplaySyncer(log, replaySpecRepoFac, projectRepoFactory, scheduler, time.Now) err := replaySyncer.Sync(context.TODO(), runTimeout) @@ -212,7 +212,7 @@ func TestReplaySyncer(t *testing.T) { }) t.Run("should not update replay status if instances are still running", func(t *testing.T) { projectRepository := new(mock.ProjectRepository) - projectRepository.On("GetAll").Return(projectSpecs, nil) + projectRepository.On("GetAll", ctx).Return(projectSpecs, nil) defer projectRepository.AssertExpectations(t) projectRepoFactory := new(mock.ProjectRepoFactory) @@ -221,7 +221,7 @@ func TestReplaySyncer(t *testing.T) { replayRepository := new(mock.ReplayRepository) defer replayRepository.AssertExpectations(t) - replayRepository.On("GetByProjectIDAndStatus", projectSpecs[0].ID, job.ReplayStatusToSynced).Return(activeReplaySpec, nil) + replayRepository.On("GetByProjectIDAndStatus", ctx, projectSpecs[0].ID, job.ReplayStatusToSynced).Return(activeReplaySpec, nil) replaySpecRepoFac := new(mock.ReplaySpecRepoFactory) defer replaySpecRepoFac.AssertExpectations(t) @@ -249,7 +249,7 @@ func TestReplaySyncer(t *testing.T) { }) t.Run("should mark timeout replay as failed", func(t *testing.T) { projectRepository := new(mock.ProjectRepository) - projectRepository.On("GetAll").Return(projectSpecs, nil) + projectRepository.On("GetAll", ctx).Return(projectSpecs, nil) defer projectRepository.AssertExpectations(t) projectRepoFactory := new(mock.ProjectRepoFactory) @@ -270,7 +270,7 @@ func TestReplaySyncer(t *testing.T) { replayRepository := new(mock.ReplayRepository) defer replayRepository.AssertExpectations(t) - replayRepository.On("GetByProjectIDAndStatus", projectSpecs[0].ID, job.ReplayStatusToSynced).Return(replaySpec, nil) + replayRepository.On("GetByProjectIDAndStatus", ctx, projectSpecs[0].ID, job.ReplayStatusToSynced).Return(replaySpec, nil) replaySpecRepoFac := new(mock.ReplaySpecRepoFactory) defer replaySpecRepoFac.AssertExpectations(t) @@ -280,7 +280,7 @@ func TestReplaySyncer(t *testing.T) { Type: job.ReplayRunTimeout, Message: fmt.Sprintf("replay has been running since %s", replayCreatedAt.UTC().Format(job.TimestampLogFormat)), } - replayRepository.On("UpdateStatus", activeReplayUUID, models.ReplayStatusFailed, failedReplayMessage).Return(nil) + replayRepository.On("UpdateStatus", ctx, activeReplayUUID, models.ReplayStatusFailed, failedReplayMessage).Return(nil) replaySyncer := job.NewReplaySyncer(log, replaySpecRepoFac, projectRepoFactory, nil, time.Now) err := replaySyncer.Sync(context.TODO(), runTimeout) @@ -289,7 +289,7 @@ func TestReplaySyncer(t *testing.T) { }) t.Run("should return error when unable to get dag run status from batchScheduler", func(t *testing.T) { projectRepository := new(mock.ProjectRepository) - projectRepository.On("GetAll").Return(projectSpecs, nil) + projectRepository.On("GetAll", ctx).Return(projectSpecs, nil) defer projectRepository.AssertExpectations(t) projectRepoFactory := new(mock.ProjectRepoFactory) @@ -298,7 +298,7 @@ func TestReplaySyncer(t *testing.T) { replayRepository := new(mock.ReplayRepository) defer replayRepository.AssertExpectations(t) - replayRepository.On("GetByProjectIDAndStatus", projectSpecs[0].ID, job.ReplayStatusToSynced).Return(activeReplaySpec, nil) + replayRepository.On("GetByProjectIDAndStatus", ctx, projectSpecs[0].ID, job.ReplayStatusToSynced).Return(activeReplaySpec, nil) replaySpecRepoFac := new(mock.ReplaySpecRepoFactory) defer replaySpecRepoFac.AssertExpectations(t) diff --git a/job/replay_test.go b/job/replay_test.go index 5cbd7d1d5a..d7deaa83db 100644 --- a/job/replay_test.go +++ b/job/replay_test.go @@ -97,7 +97,7 @@ func TestReplay(t *testing.T) { t.Run("ReplayDryRun", func(t *testing.T) { t.Run("should fail if unable to fetch jobSpecs from project jobSpecRepo", func(t *testing.T) { projectJobSpecRepo := new(mock.ProjectJobSpecRepository) - projectJobSpecRepo.On("GetAll").Return(nil, errors.New("error while getting all dags")) + projectJobSpecRepo.On("GetAll", ctx).Return(nil, errors.New("error while getting all dags")) defer projectJobSpecRepo.AssertExpectations(t) projJobSpecRepoFac := new(mock.ProjectJobSpecRepoFactory) @@ -121,7 +121,7 @@ func TestReplay(t *testing.T) { t.Run("should fail if unable to resolve jobs using dependency resolver", func(t *testing.T) { projectJobSpecRepo := new(mock.ProjectJobSpecRepository) - projectJobSpecRepo.On("GetAll").Return(dagSpec, nil) + projectJobSpecRepo.On("GetAll", ctx).Return(dagSpec, nil) defer projectJobSpecRepo.AssertExpectations(t) projJobSpecRepoFac := new(mock.ProjectJobSpecRepoFactory) @@ -168,7 +168,7 @@ func TestReplay(t *testing.T) { cyclicDagSpec = append(cyclicDagSpec, cyclicDag1, cyclicDag2) projectJobSpecRepo := new(mock.ProjectJobSpecRepository) - projectJobSpecRepo.On("GetAll").Return(cyclicDagSpec, nil) + projectJobSpecRepo.On("GetAll", ctx).Return(cyclicDagSpec, nil) defer projectJobSpecRepo.AssertExpectations(t) projJobSpecRepoFac := new(mock.ProjectJobSpecRepoFactory) @@ -199,7 +199,7 @@ func TestReplay(t *testing.T) { t.Run("resolve create replay tree for a dag with three day task window and mentioned dependencies", func(t *testing.T) { projectJobSpecRepo := new(mock.ProjectJobSpecRepository) - projectJobSpecRepo.On("GetAll").Return(dagSpec, nil) + projectJobSpecRepo.On("GetAll", ctx).Return(dagSpec, nil) defer projectJobSpecRepo.AssertExpectations(t) projJobSpecRepoFac := new(mock.ProjectJobSpecRepoFactory) @@ -248,7 +248,7 @@ func TestReplay(t *testing.T) { t.Run("resolve create replay tree for a dag with three day task window and mentioned dependencies", func(t *testing.T) { projectJobSpecRepo := new(mock.ProjectJobSpecRepository) - projectJobSpecRepo.On("GetAll").Return(dagSpec, nil) + projectJobSpecRepo.On("GetAll", ctx).Return(dagSpec, nil) defer projectJobSpecRepo.AssertExpectations(t) projJobSpecRepoFac := new(mock.ProjectJobSpecRepoFactory) @@ -301,7 +301,7 @@ func TestReplay(t *testing.T) { t.Run("Replay", func(t *testing.T) { t.Run("should fail if unable to fetch jobSpecs from project jobSpecRepo", func(t *testing.T) { projectJobSpecRepo := new(mock.ProjectJobSpecRepository) - projectJobSpecRepo.On("GetAll").Return(nil, errors.New("error while getting all dags")) + projectJobSpecRepo.On("GetAll", ctx).Return(nil, errors.New("error while getting all dags")) defer projectJobSpecRepo.AssertExpectations(t) projJobSpecRepoFac := new(mock.ProjectJobSpecRepoFactory) @@ -325,7 +325,7 @@ func TestReplay(t *testing.T) { t.Run("should fail if replay manager throws an error", func(t *testing.T) { projectJobSpecRepo := new(mock.ProjectJobSpecRepository) - projectJobSpecRepo.On("GetAll").Return(dagSpec, nil) + projectJobSpecRepo.On("GetAll", ctx).Return(dagSpec, nil) defer projectJobSpecRepo.AssertExpectations(t) projJobSpecRepoFac := new(mock.ProjectJobSpecRepoFactory) @@ -365,7 +365,7 @@ func TestReplay(t *testing.T) { t.Run("should succeed if replay manager successfully processes request", func(t *testing.T) { projectJobSpecRepo := new(mock.ProjectJobSpecRepository) - projectJobSpecRepo.On("GetAll").Return(dagSpec, nil) + projectJobSpecRepo.On("GetAll", ctx).Return(dagSpec, nil) defer projectJobSpecRepo.AssertExpectations(t) projJobSpecRepoFac := new(mock.ProjectJobSpecRepoFactory) @@ -436,7 +436,7 @@ func TestReplay(t *testing.T) { replayManager := new(mock.ReplayManager) defer replayManager.AssertExpectations(t) errorMsg := "unable to fetch replay" - replayManager.On("GetReplay", replayID).Return(models.ReplaySpec{}, errors.New(errorMsg)) + replayManager.On("GetReplay", ctx, replayID).Return(models.ReplaySpec{}, errors.New(errorMsg)) replayRequest := models.ReplayRequest{ ID: replayID, Project: projSpec, @@ -464,7 +464,7 @@ func TestReplay(t *testing.T) { replayManager := new(mock.ReplayManager) defer replayManager.AssertExpectations(t) - replayManager.On("GetReplay", replayID).Return(replaySpec, nil) + replayManager.On("GetReplay", ctx, replayID).Return(replaySpec, nil) errorMsg := "unable to get status of a job run" replayManager.On("GetRunStatus", ctx, projSpec, startDate, endDate, specs[spec1].Name). Return([]models.JobStatus{}, errors.New(errorMsg)) @@ -531,7 +531,7 @@ func TestReplay(t *testing.T) { replayManager := new(mock.ReplayManager) defer replayManager.AssertExpectations(t) - replayManager.On("GetReplay", replayID).Return(replaySpec, nil) + replayManager.On("GetReplay", ctx, replayID).Return(replaySpec, nil) replayManager.On("GetRunStatus", ctx, projSpec, replaySpec.StartDate, replaySpec.EndDate, jobSpec0.Name).Return([]models.JobStatus{jobStatusList[0], jobStatusList[1], jobStatusList[2]}, nil) errorMsg := "unable to get status of a run" replayManager.On("GetRunStatus", ctx, projSpec, replaySpec.StartDate, replaySpec.EndDate, jobSpec1.Name).Return([]models.JobStatus{}, errors.New(errorMsg)) @@ -597,7 +597,7 @@ func TestReplay(t *testing.T) { replayManager := new(mock.ReplayManager) defer replayManager.AssertExpectations(t) - replayManager.On("GetReplay", replayID).Return(replaySpec, nil) + replayManager.On("GetReplay", ctx, replayID).Return(replaySpec, nil) replayManager.On("GetRunStatus", ctx, projSpec, replaySpec.StartDate, replaySpec.EndDate, jobSpec0.Name).Return([]models.JobStatus{jobStatusList[0], jobStatusList[1], jobStatusList[2]}, nil) replayManager.On("GetRunStatus", ctx, projSpec, replaySpec.StartDate, replaySpec.EndDate, jobSpec1.Name).Return([]models.JobStatus{jobStatusList[0], jobStatusList[1], jobStatusList[2]}, nil) errorMsg := "unable to get status of a run" @@ -635,10 +635,10 @@ func TestReplay(t *testing.T) { replayManager := new(mock.ReplayManager) defer replayManager.AssertExpectations(t) - replayManager.On("GetReplayList", projSpec.ID).Return(replaySpecs, nil) + replayManager.On("GetReplayList", ctx, projSpec.ID).Return(replaySpecs, nil) jobSvc := job.NewService(nil, nil, nil, dumpAssets, nil, nil, nil, nil, replayManager) - replayList, err := jobSvc.GetReplayList(projSpec.ID) + replayList, err := jobSvc.GetReplayList(ctx, projSpec.ID) assert.Nil(t, err) assert.Equal(t, replaySpecs, replayList) @@ -647,10 +647,10 @@ func TestReplay(t *testing.T) { replayManager := new(mock.ReplayManager) defer replayManager.AssertExpectations(t) errorMsg := "unable to get replay list" - replayManager.On("GetReplayList", projSpec.ID).Return([]models.ReplaySpec{}, errors.New(errorMsg)) + replayManager.On("GetReplayList", ctx, projSpec.ID).Return([]models.ReplaySpec{}, errors.New(errorMsg)) jobSvc := job.NewService(nil, nil, nil, dumpAssets, nil, nil, nil, nil, replayManager) - replayList, err := jobSvc.GetReplayList(projSpec.ID) + replayList, err := jobSvc.GetReplayList(ctx, projSpec.ID) assert.Equal(t, errorMsg, err.Error()) assert.Equal(t, []models.ReplaySpec{}, replayList) diff --git a/job/replay_validator.go b/job/replay_validator.go index 78a3bde421..a082c5cb3f 100644 --- a/job/replay_validator.go +++ b/job/replay_validator.go @@ -34,7 +34,7 @@ func (v *Validator) Validate(ctx context.Context, replaySpecRepo store.ReplaySpe } //check another replay active for this dag - activeReplaySpecs, err := replaySpecRepo.GetByStatus(ReplayStatusToValidate) + activeReplaySpecs, err := replaySpecRepo.GetByStatus(ctx, ReplayStatusToValidate) if err != nil { if err == store.ErrResourceNotFound { return nil @@ -44,11 +44,11 @@ func (v *Validator) Validate(ctx context.Context, replaySpecRepo store.ReplaySpe return validateReplayJobsConflict(activeReplaySpecs, reqInput, reqReplayNodes) } //check and cancel if found conflicted replays for same job ID - return cancelConflictedReplays(replaySpecRepo, reqInput) + return cancelConflictedReplays(ctx, replaySpecRepo, reqInput) } -func cancelConflictedReplays(replaySpecRepo store.ReplaySpecRepository, reqInput models.ReplayRequest) error { - duplicatedReplaySpecs, err := replaySpecRepo.GetByJobIDAndStatus(reqInput.Job.ID, ReplayStatusToValidate) +func cancelConflictedReplays(ctx context.Context, replaySpecRepo store.ReplaySpecRepository, reqInput models.ReplayRequest) error { + duplicatedReplaySpecs, err := replaySpecRepo.GetByJobIDAndStatus(ctx, reqInput.Job.ID, ReplayStatusToValidate) if err != nil { if err == store.ErrResourceNotFound { return nil @@ -56,7 +56,7 @@ func cancelConflictedReplays(replaySpecRepo store.ReplaySpecRepository, reqInput return err } for _, replaySpec := range duplicatedReplaySpecs { - if err := replaySpecRepo.UpdateStatus(replaySpec.ID, models.ReplayStatusCancelled, models.ReplayMessage{ + if err := replaySpecRepo.UpdateStatus(ctx, replaySpec.ID, models.ReplayStatusCancelled, models.ReplayMessage{ Type: ErrConflictedJobRun.Error(), Message: fmt.Sprintf("force started replay with ID: %s", reqInput.ID), }); err != nil { diff --git a/job/replay_validator_test.go b/job/replay_validator_test.go index e811715676..b86b06442b 100644 --- a/job/replay_validator_test.go +++ b/job/replay_validator_test.go @@ -62,7 +62,7 @@ func TestReplayValidator(t *testing.T) { replayRepository := new(mock.ReplayRepository) defer replayRepository.AssertExpectations(t) errMessage := "error checking other replays" - replayRepository.On("GetByStatus", job.ReplayStatusToValidate).Return([]models.ReplaySpec{}, errors.New(errMessage)) + replayRepository.On("GetByStatus", ctx, job.ReplayStatusToValidate).Return([]models.ReplaySpec{}, errors.New(errMessage)) scheduler := new(mock.Scheduler) defer scheduler.AssertExpectations(t) @@ -91,7 +91,7 @@ func TestReplayValidator(t *testing.T) { replayRepository := new(mock.ReplayRepository) defer replayRepository.AssertExpectations(t) - replayRepository.On("GetByStatus", job.ReplayStatusToValidate).Return(activeReplaySpec, nil) + replayRepository.On("GetByStatus", ctx, job.ReplayStatusToValidate).Return(activeReplaySpec, nil) scheduler := new(mock.Scheduler) defer scheduler.AssertExpectations(t) @@ -122,7 +122,7 @@ func TestReplayValidator(t *testing.T) { replayRepository := new(mock.ReplayRepository) defer replayRepository.AssertExpectations(t) - replayRepository.On("GetByStatus", job.ReplayStatusToValidate).Return(activeReplaySpec, nil) + replayRepository.On("GetByStatus", ctx, job.ReplayStatusToValidate).Return(activeReplaySpec, nil) scheduler := new(mock.Scheduler) defer scheduler.AssertExpectations(t) @@ -148,7 +148,7 @@ func TestReplayValidator(t *testing.T) { replayRepository := new(mock.ReplayRepository) defer replayRepository.AssertExpectations(t) - replayRepository.On("GetByStatus", job.ReplayStatusToValidate).Return(activeReplaySpec, nil) + replayRepository.On("GetByStatus", ctx, job.ReplayStatusToValidate).Return(activeReplaySpec, nil) scheduler := new(mock.Scheduler) defer scheduler.AssertExpectations(t) @@ -175,7 +175,7 @@ func TestReplayValidator(t *testing.T) { replayRepository := new(mock.ReplayRepository) defer replayRepository.AssertExpectations(t) - replayRepository.On("GetByStatus", job.ReplayStatusToValidate).Return(activeReplaySpec, nil) + replayRepository.On("GetByStatus", ctx, job.ReplayStatusToValidate).Return(activeReplaySpec, nil) scheduler := new(mock.Scheduler) defer scheduler.AssertExpectations(t) @@ -243,7 +243,7 @@ func TestReplayValidator(t *testing.T) { replayRepository := new(mock.ReplayRepository) defer replayRepository.AssertExpectations(t) - replayRepository.On("GetByStatus", job.ReplayStatusToValidate).Return(activeReplaySpec, nil) + replayRepository.On("GetByStatus", ctx, job.ReplayStatusToValidate).Return(activeReplaySpec, nil) scheduler := new(mock.Scheduler) defer scheduler.AssertExpectations(t) @@ -274,13 +274,13 @@ func TestReplayValidator(t *testing.T) { replayRepository := new(mock.ReplayRepository) defer replayRepository.AssertExpectations(t) - replayRepository.On("GetByJobIDAndStatus", activeReplaySpec[0].Job.ID, job.ReplayStatusToValidate).Return(activeReplaySpec, nil) + replayRepository.On("GetByJobIDAndStatus", ctx, activeReplaySpec[0].Job.ID, job.ReplayStatusToValidate).Return(activeReplaySpec, nil) cancelledReplayMessage := models.ReplayMessage{ Type: job.ErrConflictedJobRun.Error(), Message: fmt.Sprintf("force started replay with ID: %s", replayRequest.ID), } - replayRepository.On("UpdateStatus", activeReplayUUID, models.ReplayStatusCancelled, cancelledReplayMessage).Return(nil) + replayRepository.On("UpdateStatus", ctx, activeReplayUUID, models.ReplayStatusCancelled, cancelledReplayMessage).Return(nil) scheduler := new(mock.Scheduler) defer scheduler.AssertExpectations(t) diff --git a/job/replay_worker.go b/job/replay_worker.go index bf1d6cef20..924a9b6cb2 100644 --- a/job/replay_worker.go +++ b/job/replay_worker.go @@ -26,11 +26,11 @@ type replayWorker struct { func (w *replayWorker) Process(ctx context.Context, input models.ReplayRequest) (err error) { replaySpecRepo := w.replaySpecRepoFac.New() // mark replay request in progress - if inProgressErr := replaySpecRepo.UpdateStatus(input.ID, models.ReplayStatusInProgress, models.ReplayMessage{}); inProgressErr != nil { + if inProgressErr := replaySpecRepo.UpdateStatus(ctx, input.ID, models.ReplayStatusInProgress, models.ReplayMessage{}); inProgressErr != nil { return inProgressErr } - replaySpec, err := replaySpecRepo.GetByID(input.ID) + replaySpec, err := replaySpecRepo.GetByID(ctx, input.ID) if err != nil { return err } @@ -43,7 +43,7 @@ func (w *replayWorker) Process(ctx context.Context, input models.ReplayRequest) if err = w.scheduler.Clear(ctx, input.Project, treeNode.GetName(), startTime, endTime); err != nil { err = errors.Wrapf(err, "error while clearing dag runs for job %s", treeNode.GetName()) w.log.Warn("error while running replay", "replay id", input.ID.String(), "error", err.Error()) - if updateStatusErr := replaySpecRepo.UpdateStatus(input.ID, models.ReplayStatusFailed, models.ReplayMessage{ + if updateStatusErr := replaySpecRepo.UpdateStatus(ctx, input.ID, models.ReplayStatusFailed, models.ReplayMessage{ Type: AirflowClearDagRunFailed, Message: err.Error(), }); updateStatusErr != nil { @@ -53,7 +53,7 @@ func (w *replayWorker) Process(ctx context.Context, input models.ReplayRequest) } } - if err = replaySpecRepo.UpdateStatus(input.ID, models.ReplayStatusReplayed, models.ReplayMessage{}); err != nil { + if err = replaySpecRepo.UpdateStatus(ctx, input.ID, models.ReplayStatusReplayed, models.ReplayMessage{}); err != nil { return err } w.log.Info("successfully cleared instances during replay", "replay id", input.ID.String()) diff --git a/job/replay_worker_test.go b/job/replay_worker_test.go index 8685a8a445..1ccefda778 100644 --- a/job/replay_worker_test.go +++ b/job/replay_worker_test.go @@ -63,7 +63,7 @@ func TestReplayWorker(t *testing.T) { replayRepository := new(mock.ReplayRepository) defer replayRepository.AssertExpectations(t) errMessage := "replay repo error" - replayRepository.On("UpdateStatus", currUUID, models.ReplayStatusInProgress, models.ReplayMessage{}).Return(errors.New(errMessage)) + replayRepository.On("UpdateStatus", ctx, currUUID, models.ReplayStatusInProgress, models.ReplayMessage{}).Return(errors.New(errMessage)) replaySpecRepoFac := new(mock.ReplaySpecRepoFactory) defer replaySpecRepoFac.AssertExpectations(t) @@ -78,18 +78,18 @@ func TestReplayWorker(t *testing.T) { ctx := context.Background() replayRepository := new(mock.ReplayRepository) defer replayRepository.AssertExpectations(t) - replayRepository.On("UpdateStatus", currUUID, models.ReplayStatusInProgress, models.ReplayMessage{}).Return(nil) + replayRepository.On("UpdateStatus", ctx, currUUID, models.ReplayStatusInProgress, models.ReplayMessage{}).Return(nil) errMessage := "error while clearing dag runs for job job-name: batchScheduler clear error" failedReplayMessage := models.ReplayMessage{ Type: job.AirflowClearDagRunFailed, Message: errMessage, } - replayRepository.On("UpdateStatus", currUUID, models.ReplayStatusFailed, failedReplayMessage).Return(nil) + replayRepository.On("UpdateStatus", ctx, currUUID, models.ReplayStatusFailed, failedReplayMessage).Return(nil) replaySpecRepoFac := new(mock.ReplaySpecRepoFactory) defer replaySpecRepoFac.AssertExpectations(t) replaySpecRepoFac.On("New").Return(replayRepository) - replayRepository.On("GetByID", currUUID).Return(replaySpec, nil) + replayRepository.On("GetByID", ctx, currUUID).Return(replaySpec, nil) scheduler := new(mock.Scheduler) defer scheduler.AssertExpectations(t) @@ -105,19 +105,19 @@ func TestReplayWorker(t *testing.T) { ctx := context.Background() replayRepository := new(mock.ReplayRepository) defer replayRepository.AssertExpectations(t) - replayRepository.On("UpdateStatus", currUUID, models.ReplayStatusInProgress, models.ReplayMessage{}).Return(nil) + replayRepository.On("UpdateStatus", ctx, currUUID, models.ReplayStatusInProgress, models.ReplayMessage{}).Return(nil) errMessage := "error while clearing dag runs for job job-name: batchScheduler clear error" failedReplayMessage := models.ReplayMessage{ Type: job.AirflowClearDagRunFailed, Message: errMessage, } updateStatusErr := errors.New("error while updating status to failed") - replayRepository.On("UpdateStatus", currUUID, models.ReplayStatusFailed, failedReplayMessage).Return(updateStatusErr) + replayRepository.On("UpdateStatus", ctx, currUUID, models.ReplayStatusFailed, failedReplayMessage).Return(updateStatusErr) replaySpecRepoFac := new(mock.ReplaySpecRepoFactory) defer replaySpecRepoFac.AssertExpectations(t) replaySpecRepoFac.On("New").Return(replayRepository) - replayRepository.On("GetByID", currUUID).Return(replaySpec, nil) + replayRepository.On("GetByID", ctx, currUUID).Return(replaySpec, nil) scheduler := new(mock.Scheduler) defer scheduler.AssertExpectations(t) @@ -133,14 +133,14 @@ func TestReplayWorker(t *testing.T) { ctx := context.Background() replayRepository := new(mock.ReplayRepository) defer replayRepository.AssertExpectations(t) - replayRepository.On("UpdateStatus", currUUID, models.ReplayStatusInProgress, models.ReplayMessage{}).Return(nil) + replayRepository.On("UpdateStatus", ctx, currUUID, models.ReplayStatusInProgress, models.ReplayMessage{}).Return(nil) updateSuccessStatusErr := errors.New("error while updating replay request") - replayRepository.On("UpdateStatus", currUUID, models.ReplayStatusReplayed, models.ReplayMessage{}).Return(updateSuccessStatusErr) + replayRepository.On("UpdateStatus", ctx, currUUID, models.ReplayStatusReplayed, models.ReplayMessage{}).Return(updateSuccessStatusErr) replaySpecRepoFac := new(mock.ReplaySpecRepoFactory) defer replaySpecRepoFac.AssertExpectations(t) replaySpecRepoFac.On("New").Return(replayRepository) - replayRepository.On("GetByID", currUUID).Return(replaySpec, nil) + replayRepository.On("GetByID", ctx, currUUID).Return(replaySpec, nil) scheduler := new(mock.Scheduler) defer scheduler.AssertExpectations(t) @@ -154,13 +154,13 @@ func TestReplayWorker(t *testing.T) { t.Run("should update replay status if successful", func(t *testing.T) { ctx := context.Background() replayRepository := new(mock.ReplayRepository) - replayRepository.On("UpdateStatus", currUUID, models.ReplayStatusInProgress, models.ReplayMessage{}).Return(nil) - replayRepository.On("UpdateStatus", currUUID, models.ReplayStatusReplayed, models.ReplayMessage{}).Return(nil) + replayRepository.On("UpdateStatus", ctx, currUUID, models.ReplayStatusInProgress, models.ReplayMessage{}).Return(nil) + replayRepository.On("UpdateStatus", ctx, currUUID, models.ReplayStatusReplayed, models.ReplayMessage{}).Return(nil) replaySpecRepoFac := new(mock.ReplaySpecRepoFactory) defer replaySpecRepoFac.AssertExpectations(t) replaySpecRepoFac.On("New").Return(replayRepository) - replayRepository.On("GetByID", currUUID).Return(replaySpec, nil) + replayRepository.On("GetByID", ctx, currUUID).Return(replaySpec, nil) scheduler := new(mock.Scheduler) defer scheduler.AssertExpectations(t) @@ -174,13 +174,13 @@ func TestReplayWorker(t *testing.T) { ctx := context.Background() replayRepository := new(mock.ReplayRepository) defer replayRepository.AssertExpectations(t) - replayRepository.On("UpdateStatus", currUUID, models.ReplayStatusInProgress, models.ReplayMessage{}).Return(nil) + replayRepository.On("UpdateStatus", ctx, currUUID, models.ReplayStatusInProgress, models.ReplayMessage{}).Return(nil) replaySpecRepoFac := new(mock.ReplaySpecRepoFactory) defer replaySpecRepoFac.AssertExpectations(t) replaySpecRepoFac.On("New").Return(replayRepository) errMessage := "fetch replay failed" - replayRepository.On("GetByID", currUUID).Return(models.ReplaySpec{}, errors.New(errMessage)) + replayRepository.On("GetByID", ctx, currUUID).Return(models.ReplaySpec{}, errors.New(errMessage)) scheduler := new(mock.Scheduler) defer scheduler.AssertExpectations(t) diff --git a/job/service.go b/job/service.go index 97ba30f369..cf26cbf015 100644 --- a/job/service.go +++ b/job/service.go @@ -61,8 +61,8 @@ type ProjectRepoFactory interface { type ReplayManager interface { Init() Replay(context.Context, models.ReplayRequest) (string, error) - GetReplay(uuid.UUID) (models.ReplaySpec, error) - GetReplayList(projectID uuid.UUID) ([]models.ReplaySpec, error) + GetReplay(context.Context, uuid.UUID) (models.ReplaySpec, error) + GetReplayList(ctx context.Context, projectID uuid.UUID) ([]models.ReplaySpec, error) GetRunStatus(ctx context.Context, projectSpec models.ProjectSpec, startDate time.Time, endDate time.Time, jobName string) ([]models.JobStatus, error) } @@ -89,17 +89,17 @@ type Service struct { } // Create constructs a Job for a namespace and commits it to the store -func (srv *Service) Create(namespace models.NamespaceSpec, spec models.JobSpec) error { +func (srv *Service) Create(ctx context.Context, namespace models.NamespaceSpec, spec models.JobSpec) error { jobRepo := srv.jobSpecRepoFactory.New(namespace) - if err := jobRepo.Save(spec); err != nil { + if err := jobRepo.Save(ctx, spec); err != nil { return errors.Wrapf(err, "failed to save job: %s", spec.Name) } return nil } // GetByName fetches a Job by name for a specific namespace -func (srv *Service) GetByName(name string, namespace models.NamespaceSpec) (models.JobSpec, error) { - jobSpec, err := srv.jobSpecRepoFactory.New(namespace).GetByName(name) +func (srv *Service) GetByName(ctx context.Context, name string, namespace models.NamespaceSpec) (models.JobSpec, error) { + jobSpec, err := srv.jobSpecRepoFactory.New(namespace).GetByName(ctx, name) if err != nil { return models.JobSpec{}, errors.Wrapf(err, "failed to retrieve job") } @@ -107,16 +107,16 @@ func (srv *Service) GetByName(name string, namespace models.NamespaceSpec) (mode } // GetByNameForProject fetches a Job by name for a specific project -func (srv *Service) GetByNameForProject(name string, proj models.ProjectSpec) (models.JobSpec, models.NamespaceSpec, error) { - jobSpec, namespace, err := srv.projectJobSpecRepoFactory.New(proj).GetByName(name) +func (srv *Service) GetByNameForProject(ctx context.Context, name string, proj models.ProjectSpec) (models.JobSpec, models.NamespaceSpec, error) { + jobSpec, namespace, err := srv.projectJobSpecRepoFactory.New(proj).GetByName(ctx, name) if err != nil { return models.JobSpec{}, models.NamespaceSpec{}, errors.Wrapf(err, "failed to retrieve job") } return jobSpec, namespace, nil } -func (srv *Service) GetAll(namespace models.NamespaceSpec) ([]models.JobSpec, error) { - jobSpecs, err := srv.jobSpecRepoFactory.New(namespace).GetAll() +func (srv *Service) GetAll(ctx context.Context, namespace models.NamespaceSpec) ([]models.JobSpec, error) { + jobSpecs, err := srv.jobSpecRepoFactory.New(namespace).GetAll(ctx) if err != nil { return nil, errors.Wrapf(err, "failed to retrieve jobs") } @@ -187,7 +187,7 @@ func (srv *Service) Delete(ctx context.Context, namespace models.NamespaceSpec, jobSpecRepo := srv.jobSpecRepoFactory.New(namespace) // delete from internal store - if err := jobSpecRepo.Delete(jobSpec.Name); err != nil { + if err := jobSpecRepo.Delete(ctx, jobSpec.Name); err != nil { return errors.Wrapf(err, "failed to delete spec: %s", jobSpec.Name) } @@ -214,7 +214,7 @@ func (srv *Service) Sync(ctx context.Context, namespace models.NamespaceSpec, pr } srv.notifyProgress(progressObserver, &EventJobPriorityWeightAssign{}) - jobSpecs, err = srv.filterJobSpecForNamespace(jobSpecs, namespace) + jobSpecs, err = srv.filterJobSpecForNamespace(ctx, jobSpecs, namespace) if err != nil { return err } @@ -253,9 +253,9 @@ func (srv *Service) Sync(ctx context.Context, namespace models.NamespaceSpec, pr } // KeepOnly only keeps the provided jobSpecs in argument and deletes rest from spec repository -func (srv *Service) KeepOnly(namespace models.NamespaceSpec, specsToKeep []models.JobSpec, progressObserver progress.Observer) error { +func (srv *Service) KeepOnly(ctx context.Context, namespace models.NamespaceSpec, specsToKeep []models.JobSpec, progressObserver progress.Observer) error { jobSpecRepo := srv.jobSpecRepoFactory.New(namespace) - jobSpecs, err := jobSpecRepo.GetAll() + jobSpecs, err := jobSpecRepo.GetAll(ctx) if err != nil { return errors.Wrapf(err, "failed to fetch specs for namespace %s", namespace.Name) } @@ -275,7 +275,7 @@ func (srv *Service) KeepOnly(namespace models.NamespaceSpec, specsToKeep []model for _, jobName := range jobsToDelete { // delete raw spec - if err := jobSpecRepo.Delete(jobName); err != nil { + if err := jobSpecRepo.Delete(ctx, jobName); err != nil { return errors.Wrapf(err, "failed to delete spec: %s", jobName) } srv.notifyProgress(progressObserver, &EventSavedJobDelete{jobName}) @@ -284,9 +284,9 @@ func (srv *Service) KeepOnly(namespace models.NamespaceSpec, specsToKeep []model } // filterJobSpecForNamespace returns only job specs of a given namespace -func (srv *Service) filterJobSpecForNamespace(jobSpecs []models.JobSpec, namespace models.NamespaceSpec) ([]models.JobSpec, error) { +func (srv *Service) filterJobSpecForNamespace(ctx context.Context, jobSpecs []models.JobSpec, namespace models.NamespaceSpec) ([]models.JobSpec, error) { jobSpecRepo := srv.jobSpecRepoFactory.New(namespace) - namespaceJobSpecs, err := jobSpecRepo.GetAll() + namespaceJobSpecs, err := jobSpecRepo.GetAll(ctx) if err != nil { return nil, err } @@ -307,7 +307,7 @@ func (srv *Service) filterJobSpecForNamespace(jobSpecs []models.JobSpec, namespa func (srv *Service) GetDependencyResolvedSpecs(ctx context.Context, proj models.ProjectSpec, projectJobSpecRepo store.ProjectJobSpecRepository, progressObserver progress.Observer) (resolvedSpecs []models.JobSpec, resolvedErrors error) { // fetch all jobs since dependency resolution happens for all jobs in a project, not just for a namespace - jobSpecs, err := projectJobSpecRepo.GetAll() + jobSpecs, err := projectJobSpecRepo.GetAll(ctx) if err != nil { return nil, errors.Wrapf(err, "failed to retrieve jobs") } @@ -378,10 +378,10 @@ func (srv *Service) isJobDeletable(ctx context.Context, projectSpec models.Proje return nil } -func (srv *Service) GetByDestination(projectSpec models.ProjectSpec, destination string) (models.JobSpec, error) { +func (srv *Service) GetByDestination(ctx context.Context, projectSpec models.ProjectSpec, destination string) (models.JobSpec, error) { // generate job spec using datastore destination. if a destination can be owned by multiple jobs, need to change to list projectJobSpecRepo := srv.projectJobSpecRepoFactory.New(projectSpec) - jobSpec, _, err := projectJobSpecRepo.GetByDestination(destination) + jobSpec, _, err := projectJobSpecRepo.GetByDestination(ctx, destination) if err != nil { return models.JobSpec{}, err } diff --git a/job/service_test.go b/job/service_test.go index e683992ac8..99db3b61f8 100644 --- a/job/service_test.go +++ b/job/service_test.go @@ -41,7 +41,7 @@ func TestService(t *testing.T) { } repo := new(mock.JobSpecRepository) - repo.On("Save", jobSpec).Return(nil) + repo.On("Save", ctx, jobSpec).Return(nil) defer repo.AssertExpectations(t) repoFac := new(mock.JobSpecRepoFactory) @@ -52,7 +52,7 @@ func TestService(t *testing.T) { defer projJobSpecRepoFac.AssertExpectations(t) svc := job.NewService(repoFac, nil, nil, dumpAssets, nil, nil, nil, projJobSpecRepoFac, nil) - err := svc.Create(namespaceSpec, jobSpec) + err := svc.Create(ctx, namespaceSpec, jobSpec) assert.Nil(t, err) }) @@ -76,7 +76,7 @@ func TestService(t *testing.T) { } repo := new(mock.JobSpecRepository) - repo.On("Save", jobSpec).Return(errors.New("unknown error")) + repo.On("Save", ctx, jobSpec).Return(errors.New("unknown error")) defer repo.AssertExpectations(t) repoFac := new(mock.JobSpecRepoFactory) @@ -84,7 +84,7 @@ func TestService(t *testing.T) { defer repoFac.AssertExpectations(t) svc := job.NewService(repoFac, nil, nil, dumpAssets, nil, nil, nil, nil, nil) - err := svc.Create(namespaceSpec, jobSpec) + err := svc.Create(ctx, namespaceSpec, jobSpec) assert.NotNil(t, err) }) }) @@ -214,7 +214,7 @@ func TestService(t *testing.T) { } jobSpecRepo := new(mock.JobSpecRepository) - jobSpecRepo.On("GetAll").Return(jobSpecsBase, nil) + jobSpecRepo.On("GetAll", ctx).Return(jobSpecsBase, nil) defer jobSpecRepo.AssertExpectations(t) jobSpecRepoFac := new(mock.JobSpecRepoFactory) @@ -222,7 +222,7 @@ func TestService(t *testing.T) { defer jobSpecRepoFac.AssertExpectations(t) projectJobSpecRepo := new(mock.ProjectJobSpecRepository) - projectJobSpecRepo.On("GetAll").Return(jobSpecsBase, nil) + projectJobSpecRepo.On("GetAll", ctx).Return(jobSpecsBase, nil) defer projectJobSpecRepo.AssertExpectations(t) projJobSpecRepoFac := new(mock.ProjectJobSpecRepoFactory) @@ -300,7 +300,7 @@ func TestService(t *testing.T) { // used to store raw job specs jobSpecRepo := new(mock.JobSpecRepository) - jobSpecRepo.On("GetAll").Return(jobSpecsBase, nil) + jobSpecRepo.On("GetAll", ctx).Return(jobSpecsBase, nil) defer jobSpecRepo.AssertExpectations(t) jobSpecRepoFac := new(mock.JobSpecRepoFactory) @@ -308,7 +308,7 @@ func TestService(t *testing.T) { defer jobSpecRepoFac.AssertExpectations(t) projectJobSpecRepo := new(mock.ProjectJobSpecRepository) - projectJobSpecRepo.On("GetAll").Return(jobSpecsBase, nil) + projectJobSpecRepo.On("GetAll", ctx).Return(jobSpecsBase, nil) defer projectJobSpecRepo.AssertExpectations(t) projJobSpecRepoFac := new(mock.ProjectJobSpecRepoFactory) @@ -326,7 +326,7 @@ func TestService(t *testing.T) { defer priorityResolver.AssertExpectations(t) // fetch currently stored - projectJobSpecRepo.On("GetAll").Return(jobSpecsBase, nil) + projectJobSpecRepo.On("GetAll", ctx).Return(jobSpecsBase, nil) batchScheduler := new(mock.Scheduler) batchScheduler.On("DeployJobs", ctx, namespaceSpec, jobSpecsAfterPriorityResolve, nil).Return(nil) @@ -368,7 +368,7 @@ func TestService(t *testing.T) { defer jobSpecRepoFac.AssertExpectations(t) projectJobSpecRepo := new(mock.ProjectJobSpecRepository) - projectJobSpecRepo.On("GetAll").Return(jobSpecsBase, nil) + projectJobSpecRepo.On("GetAll", ctx).Return(jobSpecsBase, nil) defer projectJobSpecRepo.AssertExpectations(t) projJobSpecRepoFac := new(mock.ProjectJobSpecRepoFactory) @@ -437,7 +437,7 @@ func TestService(t *testing.T) { } jobSpecRepo := new(mock.JobSpecRepository) - jobSpecRepo.On("GetAll").Return(jobSpecsBase, nil) + jobSpecRepo.On("GetAll", ctx).Return(jobSpecsBase, nil) defer jobSpecRepo.AssertExpectations(t) jobSpecRepoFac := new(mock.JobSpecRepoFactory) @@ -445,7 +445,7 @@ func TestService(t *testing.T) { defer jobSpecRepoFac.AssertExpectations(t) projectJobSpecRepo := new(mock.ProjectJobSpecRepository) - projectJobSpecRepo.On("GetAll").Return(jobSpecsBase, nil) + projectJobSpecRepo.On("GetAll", ctx).Return(jobSpecsBase, nil) defer projectJobSpecRepo.AssertExpectations(t) projJobSpecRepoFac := new(mock.ProjectJobSpecRepoFactory) @@ -531,7 +531,7 @@ func TestService(t *testing.T) { // used to store raw job specs jobSpecRepo := new(mock.JobSpecRepository) - jobSpecRepo.On("GetAll").Return(jobSpecsBase, nil) + jobSpecRepo.On("GetAll", ctx).Return(jobSpecsBase, nil) defer jobSpecRepo.AssertExpectations(t) jobSpecRepoFac := new(mock.JobSpecRepoFactory) @@ -545,12 +545,12 @@ func TestService(t *testing.T) { defer projJobSpecRepoFac.AssertExpectations(t) // fetch currently stored - jobSpecRepo.On("GetAll").Return(jobSpecsBase, nil) + jobSpecRepo.On("GetAll", ctx).Return(jobSpecsBase, nil) // delete unwanted - jobSpecRepo.On("Delete", jobSpecsBase[0].Name).Return(nil) + jobSpecRepo.On("Delete", ctx, jobSpecsBase[0].Name).Return(nil) svc := job.NewService(jobSpecRepoFac, nil, nil, dumpAssets, nil, nil, nil, projJobSpecRepoFac, nil) - err := svc.KeepOnly(namespaceSpec, toKeep, nil) + err := svc.KeepOnly(ctx, namespaceSpec, toKeep, nil) assert.Nil(t, err) }) }) @@ -599,7 +599,7 @@ func TestService(t *testing.T) { } jobSpecRepo := new(mock.JobSpecRepository) - jobSpecRepo.On("Delete", "test").Return(nil) + jobSpecRepo.On("Delete", ctx, "test").Return(nil) defer jobSpecRepo.AssertExpectations(t) jobSpecRepoFac := new(mock.JobSpecRepoFactory) @@ -607,7 +607,7 @@ func TestService(t *testing.T) { defer jobSpecRepoFac.AssertExpectations(t) projectJobSpecRepo := new(mock.ProjectJobSpecRepository) - projectJobSpecRepo.On("GetAll").Return(jobSpecsBase, nil) + projectJobSpecRepo.On("GetAll", ctx).Return(jobSpecsBase, nil) defer projectJobSpecRepo.AssertExpectations(t) projJobSpecRepoFac := new(mock.ProjectJobSpecRepoFactory) @@ -685,7 +685,7 @@ func TestService(t *testing.T) { defer jobSpecRepoFac.AssertExpectations(t) projectJobSpecRepo := new(mock.ProjectJobSpecRepository) - projectJobSpecRepo.On("GetAll").Return(jobSpecsBase, nil) + projectJobSpecRepo.On("GetAll", ctx).Return(jobSpecsBase, nil) defer projectJobSpecRepo.AssertExpectations(t) projJobSpecRepoFac := new(mock.ProjectJobSpecRepoFactory) @@ -718,14 +718,14 @@ func TestService(t *testing.T) { projectJobSpecRepo := new(mock.ProjectJobSpecRepository) defer projectJobSpecRepo.AssertExpectations(t) - projectJobSpecRepo.On("GetByDestination", destination).Return(jobSpec1, models.ProjectSpec{}, nil) + projectJobSpecRepo.On("GetByDestination", ctx, destination).Return(jobSpec1, models.ProjectSpec{}, nil) projJobSpecRepoFac := new(mock.ProjectJobSpecRepoFactory) defer projJobSpecRepoFac.AssertExpectations(t) projJobSpecRepoFac.On("New", projSpec).Return(projectJobSpecRepo) svc := job.NewService(nil, nil, nil, dumpAssets, nil, nil, nil, projJobSpecRepoFac, nil) - jobSpecsResult, err := svc.GetByDestination(projSpec, destination) + jobSpecsResult, err := svc.GetByDestination(ctx, projSpec, destination) assert.Nil(t, err) assert.Equal(t, jobSpec1, jobSpecsResult) }) @@ -738,14 +738,14 @@ func TestService(t *testing.T) { projectJobSpecRepo := new(mock.ProjectJobSpecRepository) defer projectJobSpecRepo.AssertExpectations(t) errorMsg := "unable to fetch jobspec" - projectJobSpecRepo.On("GetByDestination", destination).Return(models.JobSpec{}, models.ProjectSpec{}, errors.New(errorMsg)) + projectJobSpecRepo.On("GetByDestination", ctx, destination).Return(models.JobSpec{}, models.ProjectSpec{}, errors.New(errorMsg)) projJobSpecRepoFac := new(mock.ProjectJobSpecRepoFactory) defer projJobSpecRepoFac.AssertExpectations(t) projJobSpecRepoFac.On("New", projSpec).Return(projectJobSpecRepo) svc := job.NewService(nil, nil, nil, dumpAssets, nil, nil, nil, projJobSpecRepoFac, nil) - jobSpecsResult, err := svc.GetByDestination(projSpec, destination) + jobSpecsResult, err := svc.GetByDestination(ctx, projSpec, destination) assert.Contains(t, err.Error(), errorMsg) assert.Equal(t, models.JobSpec{}, jobSpecsResult) }) @@ -763,7 +763,7 @@ func TestService(t *testing.T) { projectJobSpecRepo := new(mock.ProjectJobSpecRepository) defer projectJobSpecRepo.AssertExpectations(t) - projectJobSpecRepo.On("GetAll").Return(jobSpecs, nil) + projectJobSpecRepo.On("GetAll", ctx).Return(jobSpecs, nil) projJobSpecRepoFac := new(mock.ProjectJobSpecRepoFactory) defer projJobSpecRepoFac.AssertExpectations(t) @@ -793,7 +793,7 @@ func TestService(t *testing.T) { projectJobSpecRepo := new(mock.ProjectJobSpecRepository) defer projectJobSpecRepo.AssertExpectations(t) errorMsg := "unable to get all job specs of a project" - projectJobSpecRepo.On("GetAll").Return([]models.JobSpec{}, errors.New(errorMsg)) + projectJobSpecRepo.On("GetAll", ctx).Return([]models.JobSpec{}, errors.New(errorMsg)) projJobSpecRepoFac := new(mock.ProjectJobSpecRepoFactory) defer projJobSpecRepoFac.AssertExpectations(t) @@ -818,7 +818,7 @@ func TestService(t *testing.T) { projectJobSpecRepo := new(mock.ProjectJobSpecRepository) defer projectJobSpecRepo.AssertExpectations(t) - projectJobSpecRepo.On("GetAll").Return(jobSpecs, nil) + projectJobSpecRepo.On("GetAll", ctx).Return(jobSpecs, nil) projJobSpecRepoFac := new(mock.ProjectJobSpecRepoFactory) defer projJobSpecRepoFac.AssertExpectations(t) diff --git a/mock/backup.go b/mock/backup.go index 1e3a2adee2..7a26629070 100644 --- a/mock/backup.go +++ b/mock/backup.go @@ -1,6 +1,8 @@ package mock import ( + "context" + "github.com/odpf/optimus/models" "github.com/odpf/optimus/store" "github.com/stretchr/testify/mock" @@ -10,12 +12,12 @@ type BackupRepo struct { mock.Mock } -func (repo *BackupRepo) Save(spec models.BackupSpec) error { - return repo.Called(spec).Error(0) +func (repo *BackupRepo) Save(ctx context.Context, spec models.BackupSpec) error { + return repo.Called(ctx, spec).Error(0) } -func (repo *BackupRepo) GetAll() ([]models.BackupSpec, error) { - args := repo.Called() +func (repo *BackupRepo) GetAll(ctx context.Context) ([]models.BackupSpec, error) { + args := repo.Called(ctx) return args.Get(0).([]models.BackupSpec), args.Error(1) } diff --git a/mock/datastore.go b/mock/datastore.go index 268ff332d1..b6ae20e290 100644 --- a/mock/datastore.go +++ b/mock/datastore.go @@ -94,8 +94,8 @@ type DatastoreService struct { mock.Mock } -func (d *DatastoreService) GetAll(spec models.NamespaceSpec, datastoreName string) ([]models.ResourceSpec, error) { - args := d.Called(spec, datastoreName) +func (d *DatastoreService) GetAll(ctx context.Context, spec models.NamespaceSpec, datastoreName string) ([]models.ResourceSpec, error) { + args := d.Called(ctx, spec, datastoreName) return args.Get(0).([]models.ResourceSpec), args.Error(1) } @@ -126,8 +126,8 @@ func (d *DatastoreService) BackupResource(ctx context.Context, req models.Backup return args.Get(0).([]string), args.Error(1) } -func (d *DatastoreService) ListBackupResources(projectSpec models.ProjectSpec, datastoreName string) ([]models.BackupSpec, error) { - args := d.Called(projectSpec, datastoreName) +func (d *DatastoreService) ListBackupResources(ctx context.Context, projectSpec models.ProjectSpec, datastoreName string) ([]models.BackupSpec, error) { + args := d.Called(ctx, projectSpec, datastoreName) return args.Get(0).([]models.BackupSpec), args.Error(1) } @@ -161,27 +161,27 @@ type ResourceSpecRepository struct { mock.Mock } -func (r *ResourceSpecRepository) Save(spec models.ResourceSpec) error { - return r.Called(spec).Error(0) +func (r *ResourceSpecRepository) Save(ctx context.Context, spec models.ResourceSpec) error { + return r.Called(ctx, spec).Error(0) } -func (r *ResourceSpecRepository) GetByName(s string) (models.ResourceSpec, error) { - args := r.Called(s) +func (r *ResourceSpecRepository) GetByName(ctx context.Context, s string) (models.ResourceSpec, error) { + args := r.Called(ctx, s) return args.Get(0).(models.ResourceSpec), args.Error(1) } -func (r *ResourceSpecRepository) GetByURN(s string) (models.ResourceSpec, error) { - args := r.Called(s) +func (r *ResourceSpecRepository) GetByURN(ctx context.Context, s string) (models.ResourceSpec, error) { + args := r.Called(ctx, s) return args.Get(0).(models.ResourceSpec), args.Error(1) } -func (r *ResourceSpecRepository) GetAll() ([]models.ResourceSpec, error) { - args := r.Called() +func (r *ResourceSpecRepository) GetAll(ctx context.Context) ([]models.ResourceSpec, error) { + args := r.Called(ctx) return args.Get(0).([]models.ResourceSpec), args.Error(1) } -func (r *ResourceSpecRepository) Delete(s string) error { - return r.Called(s).Error(0) +func (r *ResourceSpecRepository) Delete(ctx context.Context, s string) error { + return r.Called(ctx, s).Error(0) } type ProjectResourceSpecRepoFactory struct { @@ -196,12 +196,12 @@ type ProjectResourceSpecRepository struct { mock.Mock } -func (r *ProjectResourceSpecRepository) GetByName(s string) (models.ResourceSpec, error) { - args := r.Called(s) +func (r *ProjectResourceSpecRepository) GetByName(ctx context.Context, s string) (models.ResourceSpec, error) { + args := r.Called(ctx, s) return args.Get(0).(models.ResourceSpec), args.Error(1) } -func (r *ProjectResourceSpecRepository) GetAll() ([]models.ResourceSpec, error) { - args := r.Called() +func (r *ProjectResourceSpecRepository) GetAll(ctx context.Context) ([]models.ResourceSpec, error) { + args := r.Called(ctx) return args.Get(0).([]models.ResourceSpec), args.Error(1) } diff --git a/mock/instance.go b/mock/instance.go index 67b268dae4..05e1fe7d3b 100644 --- a/mock/instance.go +++ b/mock/instance.go @@ -1,6 +1,7 @@ package mock import ( + "context" "time" "github.com/odpf/optimus/store" @@ -24,58 +25,58 @@ type JobRunRepository struct { mock.Mock } -func (r *JobRunRepository) Save(spec models.NamespaceSpec, run models.JobRun) error { - args := r.Called(spec, run) +func (r *JobRunRepository) Save(ctx context.Context, spec models.NamespaceSpec, run models.JobRun) error { + args := r.Called(ctx, spec, run) return args.Error(0) } -func (r *JobRunRepository) GetByScheduledAt(jobID uuid.UUID, scheduledAt time.Time) (models.JobRun, models.NamespaceSpec, error) { - args := r.Called(jobID, scheduledAt) +func (r *JobRunRepository) GetByScheduledAt(ctx context.Context, jobID uuid.UUID, scheduledAt time.Time) (models.JobRun, models.NamespaceSpec, error) { + args := r.Called(ctx, jobID, scheduledAt) return args.Get(0).(models.JobRun), args.Get(1).(models.NamespaceSpec), args.Error(2) } -func (r *JobRunRepository) GetByID(u uuid.UUID) (models.JobRun, models.NamespaceSpec, error) { - args := r.Called(u) +func (r *JobRunRepository) GetByID(ctx context.Context, u uuid.UUID) (models.JobRun, models.NamespaceSpec, error) { + args := r.Called(ctx, u) return args.Get(0).(models.JobRun), args.Get(1).(models.NamespaceSpec), args.Error(2) } -func (r *JobRunRepository) UpdateStatus(u uuid.UUID, state models.JobRunState) error { - args := r.Called(u, state) +func (r *JobRunRepository) UpdateStatus(ctx context.Context, u uuid.UUID, state models.JobRunState) error { + args := r.Called(ctx, u, state) return args.Error(0) } -func (r *JobRunRepository) GetByStatus(state ...models.JobRunState) ([]models.JobRun, error) { - args := r.Called(state) +func (r *JobRunRepository) GetByStatus(ctx context.Context, state ...models.JobRunState) ([]models.JobRun, error) { + args := r.Called(ctx, state) return args.Get(0).([]models.JobRun), args.Error(1) } -func (r *JobRunRepository) GetByTrigger(trig models.JobRunTrigger, state ...models.JobRunState) ([]models.JobRun, error) { - args := r.Called(trig, state) +func (r *JobRunRepository) GetByTrigger(ctx context.Context, trig models.JobRunTrigger, state ...models.JobRunState) ([]models.JobRun, error) { + args := r.Called(ctx, trig, state) return args.Get(0).([]models.JobRun), args.Error(1) } -func (r *JobRunRepository) Delete(u uuid.UUID) error { - args := r.Called(u) +func (r *JobRunRepository) Delete(ctx context.Context, u uuid.UUID) error { + args := r.Called(ctx, u) return args.Error(0) } -func (r *JobRunRepository) AddInstance(namespace models.NamespaceSpec, run models.JobRun, spec models.InstanceSpec) error { - args := r.Called(namespace, run, spec) +func (r *JobRunRepository) AddInstance(ctx context.Context, namespace models.NamespaceSpec, run models.JobRun, spec models.InstanceSpec) error { + args := r.Called(ctx, namespace, run, spec) return args.Error(0) } -func (r *JobRunRepository) Clear(runID uuid.UUID) error { - args := r.Called(runID) +func (r *JobRunRepository) Clear(ctx context.Context, runID uuid.UUID) error { + args := r.Called(ctx, runID) return args.Error(0) } -func (r *JobRunRepository) ClearInstance(runID uuid.UUID, instanceType models.InstanceType, instanceName string) error { - args := r.Called(runID, instanceType, instanceName) +func (r *JobRunRepository) ClearInstance(ctx context.Context, runID uuid.UUID, instanceType models.InstanceType, instanceName string) error { + args := r.Called(ctx, runID, instanceType, instanceName) return args.Error(0) } -func (r *JobRunRepository) ClearInstances(jobID uuid.UUID, scheduled time.Time) error { - args := r.Called(jobID, scheduled) +func (r *JobRunRepository) ClearInstances(ctx context.Context, jobID uuid.UUID, scheduled time.Time) error { + args := r.Called(ctx, jobID, scheduled) return args.Error(0) } @@ -93,42 +94,42 @@ type InstanceSpecRepository struct { mock.Mock } -func (repo *InstanceSpecRepository) Save(t models.InstanceSpec) error { - return repo.Called(t).Error(0) +func (repo *InstanceSpecRepository) Save(ctx context.Context, t models.InstanceSpec) error { + return repo.Called(ctx, t).Error(0) } -func (repo *InstanceSpecRepository) GetByScheduledAt(st time.Time) (models.InstanceSpec, error) { - args := repo.Called(st) +func (repo *InstanceSpecRepository) GetByScheduledAt(ctx context.Context, st time.Time) (models.InstanceSpec, error) { + args := repo.Called(ctx, st) if args.Get(0) != nil { return args.Get(0).(models.InstanceSpec), args.Error(1) } return models.InstanceSpec{}, args.Error(1) } -func (repo *InstanceSpecRepository) Clear(st time.Time) error { - return repo.Called(st).Error(0) +func (repo *InstanceSpecRepository) Clear(ctx context.Context, st time.Time) error { + return repo.Called(ctx, st).Error(0) } type RunService struct { mock.Mock } -func (s *RunService) GetScheduledRun(namespaceSpec models.NamespaceSpec, JobID models.JobSpec, scheduledAt time.Time) (models.JobRun, error) { - args := s.Called(namespaceSpec, JobID, scheduledAt) +func (s *RunService) GetScheduledRun(ctx context.Context, namespaceSpec models.NamespaceSpec, JobID models.JobSpec, scheduledAt time.Time) (models.JobRun, error) { + args := s.Called(ctx, namespaceSpec, JobID, scheduledAt) return args.Get(0).(models.JobRun), args.Error(1) } -func (s *RunService) GetByID(JobRunID uuid.UUID) (models.JobRun, models.NamespaceSpec, error) { - args := s.Called(JobRunID) +func (s *RunService) GetByID(ctx context.Context, JobRunID uuid.UUID) (models.JobRun, models.NamespaceSpec, error) { + args := s.Called(ctx, JobRunID) return args.Get(0).(models.JobRun), args.Get(1).(models.NamespaceSpec), args.Error(2) } -func (s *RunService) Register(namespace models.NamespaceSpec, jobRun models.JobRun, instanceType models.InstanceType, instanceName string) (models.InstanceSpec, error) { - args := s.Called(namespace, jobRun, instanceType, instanceName) +func (s *RunService) Register(ctx context.Context, namespace models.NamespaceSpec, jobRun models.JobRun, instanceType models.InstanceType, instanceName string) (models.InstanceSpec, error) { + args := s.Called(ctx, namespace, jobRun, instanceType, instanceName) return args.Get(0).(models.InstanceSpec), args.Error(1) } -func (s *RunService) Compile(namespaceSpec models.NamespaceSpec, jobRun models.JobRun, instanceSpec models.InstanceSpec) (envMap map[string]string, fileMap map[string]string, err error) { - args := s.Called(namespaceSpec, jobRun, instanceSpec) +func (s *RunService) Compile(ctx context.Context, namespaceSpec models.NamespaceSpec, jobRun models.JobRun, instanceSpec models.InstanceSpec) (envMap map[string]string, fileMap map[string]string, err error) { + args := s.Called(ctx, namespaceSpec, jobRun, instanceSpec) return args.Get(0).(map[string]string), args.Get(1).(map[string]string), args.Error(2) } diff --git a/mock/job.go b/mock/job.go index a68202c981..ed07de11bc 100644 --- a/mock/job.go +++ b/mock/job.go @@ -30,32 +30,32 @@ type ProjectJobSpecRepository struct { mock.Mock } -func (repo *ProjectJobSpecRepository) GetByName(name string) (models.JobSpec, models.NamespaceSpec, error) { - args := repo.Called(name) +func (repo *ProjectJobSpecRepository) GetByName(ctx context.Context, name string) (models.JobSpec, models.NamespaceSpec, error) { + args := repo.Called(ctx, name) if args.Get(0) != nil { return args.Get(0).(models.JobSpec), args.Get(1).(models.NamespaceSpec), args.Error(2) } return models.JobSpec{}, models.NamespaceSpec{}, args.Error(1) } -func (repo *ProjectJobSpecRepository) GetByNameForProject(job, project string) (models.JobSpec, models.ProjectSpec, error) { - args := repo.Called(job, project) +func (repo *ProjectJobSpecRepository) GetByNameForProject(ctx context.Context, job, project string) (models.JobSpec, models.ProjectSpec, error) { + args := repo.Called(ctx, job, project) if args.Get(0) != nil { return args.Get(0).(models.JobSpec), args.Get(1).(models.ProjectSpec), args.Error(2) } return models.JobSpec{}, models.ProjectSpec{}, args.Error(1) } -func (repo *ProjectJobSpecRepository) GetAll() ([]models.JobSpec, error) { - args := repo.Called() +func (repo *ProjectJobSpecRepository) GetAll(ctx context.Context) ([]models.JobSpec, error) { + args := repo.Called(ctx) if args.Get(0) != nil { return args.Get(0).([]models.JobSpec), args.Error(1) } return []models.JobSpec{}, args.Error(1) } -func (repo *ProjectJobSpecRepository) GetByDestination(dest string) (models.JobSpec, models.ProjectSpec, error) { - args := repo.Called(dest) +func (repo *ProjectJobSpecRepository) GetByDestination(ctx context.Context, dest string) (models.JobSpec, models.ProjectSpec, error) { + args := repo.Called(ctx, dest) if args.Get(0) != nil { return args.Get(0).(models.JobSpec), args.Get(1).(models.ProjectSpec), args.Error(2) } @@ -76,32 +76,32 @@ type JobSpecRepository struct { mock.Mock } -func (repo *JobSpecRepository) Save(t models.JobSpec) error { - return repo.Called(t).Error(0) +func (repo *JobSpecRepository) Save(ctx context.Context, t models.JobSpec) error { + return repo.Called(ctx, t).Error(0) } -func (repo *JobSpecRepository) GetByName(name string) (models.JobSpec, error) { - args := repo.Called(name) +func (repo *JobSpecRepository) GetByName(ctx context.Context, name string) (models.JobSpec, error) { + args := repo.Called(ctx, name) if args.Get(0) != nil { return args.Get(0).(models.JobSpec), args.Error(1) } return models.JobSpec{}, args.Error(1) } -func (repo *JobSpecRepository) Delete(name string) error { - return repo.Called(name).Error(0) +func (repo *JobSpecRepository) Delete(ctx context.Context, name string) error { + return repo.Called(ctx, name).Error(0) } -func (repo *JobSpecRepository) GetAll() ([]models.JobSpec, error) { - args := repo.Called() +func (repo *JobSpecRepository) GetAll(ctx context.Context) ([]models.JobSpec, error) { + args := repo.Called(ctx) if args.Get(0) != nil { return args.Get(0).([]models.JobSpec), args.Error(1) } return []models.JobSpec{}, args.Error(1) } -func (repo *JobSpecRepository) GetByDestination(dest string) (models.JobSpec, models.ProjectSpec, error) { - args := repo.Called(dest) +func (repo *JobSpecRepository) GetByDestination(ctx context.Context, dest string) (models.JobSpec, models.ProjectSpec, error) { + args := repo.Called(ctx, dest) if args.Get(0) != nil { return args.Get(0).(models.JobSpec), args.Get(1).(models.ProjectSpec), args.Error(2) } @@ -121,28 +121,28 @@ type JobService struct { mock.Mock } -func (srv *JobService) Create(spec2 models.NamespaceSpec, spec models.JobSpec) error { - args := srv.Called(spec, spec2) +func (srv *JobService) Create(ctx context.Context, spec2 models.NamespaceSpec, spec models.JobSpec) error { + args := srv.Called(ctx, spec, spec2) return args.Error(0) } -func (srv *JobService) GetByName(s string, spec models.NamespaceSpec) (models.JobSpec, error) { - args := srv.Called(s, spec) +func (srv *JobService) GetByName(ctx context.Context, s string, spec models.NamespaceSpec) (models.JobSpec, error) { + args := srv.Called(ctx, s, spec) return args.Get(0).(models.JobSpec), args.Error(1) } -func (srv *JobService) KeepOnly(spec models.NamespaceSpec, specs []models.JobSpec, observer progress.Observer) error { - args := srv.Called(spec, specs) +func (srv *JobService) KeepOnly(ctx context.Context, spec models.NamespaceSpec, specs []models.JobSpec, observer progress.Observer) error { + args := srv.Called(ctx, spec, specs) return args.Error(0) } -func (srv *JobService) GetAll(spec models.NamespaceSpec) ([]models.JobSpec, error) { - args := srv.Called(spec) +func (srv *JobService) GetAll(ctx context.Context, spec models.NamespaceSpec) ([]models.JobSpec, error) { + args := srv.Called(ctx, spec) return args.Get(0).([]models.JobSpec), args.Error(1) } -func (srv *JobService) GetByNameForProject(s string, spec models.ProjectSpec) (models.JobSpec, models.NamespaceSpec, error) { - args := srv.Called(s, spec) +func (srv *JobService) GetByNameForProject(ctx context.Context, s string, spec models.ProjectSpec) (models.JobSpec, models.NamespaceSpec, error) { + args := srv.Called(ctx, s, spec) return args.Get(0).(models.JobSpec), args.Get(1).(models.NamespaceSpec), args.Error(2) } @@ -162,7 +162,7 @@ func (j *JobService) Delete(ctx context.Context, c models.NamespaceSpec, job mod } func (j *JobService) ReplayDryRun(ctx context.Context, replayRequest models.ReplayRequest) (*tree.TreeNode, error) { - args := j.Called(replayRequest) + args := j.Called(ctx, replayRequest) return args.Get(0).(*tree.TreeNode), args.Error(1) } @@ -176,8 +176,8 @@ func (j *JobService) GetReplayStatus(ctx context.Context, replayRequest models.R return args.Get(0).(models.ReplayState), args.Error(1) } -func (j *JobService) GetReplayList(projectUUID uuid.UUID) ([]models.ReplaySpec, error) { - args := j.Called(projectUUID) +func (j *JobService) GetReplayList(ctx context.Context, projectUUID uuid.UUID) ([]models.ReplaySpec, error) { + args := j.Called(ctx, projectUUID) return args.Get(0).([]models.ReplaySpec), args.Error(1) } @@ -186,7 +186,7 @@ func (j *JobService) Run(ctx context.Context, ns models.NamespaceSpec, js []mode return args.Error(0) } -func (j *JobService) GetByDestination(projectSpec models.ProjectSpec, destination string) (models.JobSpec, error) { +func (j *JobService) GetByDestination(ctx context.Context, projectSpec models.ProjectSpec, destination string) (models.JobSpec, error) { args := j.Called(projectSpec, destination) return args.Get(0).(models.JobSpec), args.Error(1) } diff --git a/mock/namespace.go b/mock/namespace.go index c5def47490..bd98816b4d 100644 --- a/mock/namespace.go +++ b/mock/namespace.go @@ -1,6 +1,8 @@ package mock import ( + "context" + "github.com/odpf/optimus/models" "github.com/odpf/optimus/store" "github.com/stretchr/testify/mock" @@ -10,17 +12,17 @@ type NamespaceRepository struct { mock.Mock } -func (pr *NamespaceRepository) Save(spec models.NamespaceSpec) error { - return pr.Called(spec).Error(0) +func (pr *NamespaceRepository) Save(ctx context.Context, spec models.NamespaceSpec) error { + return pr.Called(ctx, spec).Error(0) } -func (pr *NamespaceRepository) GetByName(name string) (models.NamespaceSpec, error) { - args := pr.Called(name) +func (pr *NamespaceRepository) GetByName(ctx context.Context, name string) (models.NamespaceSpec, error) { + args := pr.Called(ctx, name) return args.Get(0).(models.NamespaceSpec), args.Error(1) } -func (pr *NamespaceRepository) GetAll() ([]models.NamespaceSpec, error) { - args := pr.Called() +func (pr *NamespaceRepository) GetAll(ctx context.Context) ([]models.NamespaceSpec, error) { + args := pr.Called(ctx) return args.Get(0).([]models.NamespaceSpec), args.Error(1) } diff --git a/mock/project.go b/mock/project.go index 03d48f861e..75266ed169 100644 --- a/mock/project.go +++ b/mock/project.go @@ -1,6 +1,8 @@ package mock import ( + "context" + "github.com/odpf/optimus/core/progress" "github.com/odpf/optimus/models" "github.com/odpf/optimus/store" @@ -11,17 +13,17 @@ type ProjectRepository struct { mock.Mock } -func (pr *ProjectRepository) Save(spec models.ProjectSpec) error { - return pr.Called(spec).Error(0) +func (pr *ProjectRepository) Save(ctx context.Context, spec models.ProjectSpec) error { + return pr.Called(ctx, spec).Error(0) } -func (pr *ProjectRepository) GetByName(name string) (models.ProjectSpec, error) { - args := pr.Called(name) +func (pr *ProjectRepository) GetByName(ctx context.Context, name string) (models.ProjectSpec, error) { + args := pr.Called(ctx, name) return args.Get(0).(models.ProjectSpec), args.Error(1) } -func (pr *ProjectRepository) GetAll() ([]models.ProjectSpec, error) { - args := pr.Called() +func (pr *ProjectRepository) GetAll(ctx context.Context) ([]models.ProjectSpec, error) { + args := pr.Called(ctx) return args.Get(0).([]models.ProjectSpec), args.Error(1) } @@ -47,17 +49,17 @@ type ProjectSecretRepository struct { mock.Mock } -func (pr *ProjectSecretRepository) Save(spec models.ProjectSecretItem) error { - return pr.Called(spec).Error(0) +func (pr *ProjectSecretRepository) Save(ctx context.Context, spec models.ProjectSecretItem) error { + return pr.Called(ctx, spec).Error(0) } -func (pr *ProjectSecretRepository) GetByName(name string) (models.ProjectSecretItem, error) { - args := pr.Called(name) +func (pr *ProjectSecretRepository) GetByName(ctx context.Context, name string) (models.ProjectSecretItem, error) { + args := pr.Called(ctx, name) return args.Get(0).(models.ProjectSecretItem), args.Error(1) } -func (pr *ProjectSecretRepository) GetAll() ([]models.ProjectSecretItem, error) { - args := pr.Called() +func (pr *ProjectSecretRepository) GetAll(ctx context.Context) ([]models.ProjectSecretItem, error) { + args := pr.Called(ctx) return args.Get(0).([]models.ProjectSecretItem), args.Error(1) } diff --git a/mock/replay.go b/mock/replay.go index b6f24eada4..30b2b7326f 100644 --- a/mock/replay.go +++ b/mock/replay.go @@ -18,13 +18,13 @@ type ReplayRepository struct { mock.Mock } -func (repo *ReplayRepository) GetByID(id uuid.UUID) (models.ReplaySpec, error) { - args := repo.Called(id) +func (repo *ReplayRepository) GetByID(ctx context.Context, id uuid.UUID) (models.ReplaySpec, error) { + args := repo.Called(ctx, id) return args.Get(0).(models.ReplaySpec), args.Error(1) } -func (repo *ReplayRepository) Insert(replay *models.ReplaySpec) error { - return repo.Called(&models.ReplaySpec{ +func (repo *ReplayRepository) Insert(ctx context.Context, replay *models.ReplaySpec) error { + return repo.Called(ctx, &models.ReplaySpec{ ID: replay.ID, Job: replay.Job, StartDate: replay.StartDate, @@ -35,27 +35,27 @@ func (repo *ReplayRepository) Insert(replay *models.ReplaySpec) error { }).Error(0) } -func (repo *ReplayRepository) UpdateStatus(replayID uuid.UUID, status string, message models.ReplayMessage) error { - return repo.Called(replayID, status, message).Error(0) +func (repo *ReplayRepository) UpdateStatus(ctx context.Context, replayID uuid.UUID, status string, message models.ReplayMessage) error { + return repo.Called(ctx, replayID, status, message).Error(0) } -func (repo *ReplayRepository) GetByStatus(status []string) ([]models.ReplaySpec, error) { - args := repo.Called(status) +func (repo *ReplayRepository) GetByStatus(ctx context.Context, status []string) ([]models.ReplaySpec, error) { + args := repo.Called(ctx, status) return args.Get(0).([]models.ReplaySpec), args.Error(1) } -func (repo *ReplayRepository) GetByJobIDAndStatus(jobID uuid.UUID, status []string) ([]models.ReplaySpec, error) { - args := repo.Called(jobID, status) +func (repo *ReplayRepository) GetByJobIDAndStatus(ctx context.Context, jobID uuid.UUID, status []string) ([]models.ReplaySpec, error) { + args := repo.Called(ctx, jobID, status) return args.Get(0).([]models.ReplaySpec), args.Error(1) } -func (repo *ReplayRepository) GetByProjectIDAndStatus(projectID uuid.UUID, status []string) ([]models.ReplaySpec, error) { - args := repo.Called(projectID, status) +func (repo *ReplayRepository) GetByProjectIDAndStatus(ctx context.Context, projectID uuid.UUID, status []string) ([]models.ReplaySpec, error) { + args := repo.Called(ctx, projectID, status) return args.Get(0).([]models.ReplaySpec), args.Error(1) } -func (repo *ReplayRepository) GetByProjectID(projectID uuid.UUID) ([]models.ReplaySpec, error) { - args := repo.Called(projectID) +func (repo *ReplayRepository) GetByProjectID(ctx context.Context, projectID uuid.UUID) ([]models.ReplaySpec, error) { + args := repo.Called(ctx, projectID) return args.Get(0).([]models.ReplaySpec), args.Error(1) } @@ -81,13 +81,13 @@ func (rm *ReplayManager) Init() { return } -func (rm *ReplayManager) GetReplay(uuid uuid.UUID) (models.ReplaySpec, error) { - args := rm.Called(uuid) +func (rm *ReplayManager) GetReplay(ctx context.Context, uuid uuid.UUID) (models.ReplaySpec, error) { + args := rm.Called(ctx, uuid) return args.Get(0).(models.ReplaySpec), args.Error(1) } -func (rm *ReplayManager) GetReplayList(projectUUID uuid.UUID) ([]models.ReplaySpec, error) { - args := rm.Called(projectUUID) +func (rm *ReplayManager) GetReplayList(ctx context.Context, projectUUID uuid.UUID) ([]models.ReplaySpec, error) { + args := rm.Called(ctx, projectUUID) return args.Get(0).([]models.ReplaySpec), args.Error(1) } diff --git a/models/datastore.go b/models/datastore.go index 80906b8763..ab2cd691bf 100644 --- a/models/datastore.go +++ b/models/datastore.go @@ -178,7 +178,7 @@ func (s *supportedDatastore) Add(newUnit Datastorer) error { type DatastoreService interface { // does not really fetch resource metadata, just the user provided spec - GetAll(namespace NamespaceSpec, datastoreName string) ([]ResourceSpec, error) + GetAll(ctx context.Context, namespace NamespaceSpec, datastoreName string) ([]ResourceSpec, error) CreateResource(ctx context.Context, namespace NamespaceSpec, resourceSpecs []ResourceSpec, obs progress.Observer) error UpdateResource(ctx context.Context, namespace NamespaceSpec, resourceSpecs []ResourceSpec, obs progress.Observer) error @@ -186,5 +186,5 @@ type DatastoreService interface { DeleteResource(ctx context.Context, namespace NamespaceSpec, datastoreName, name string) error BackupResourceDryRun(ctx context.Context, backupRequest BackupRequest, jobSpecs []JobSpec) ([]string, error) BackupResource(ctx context.Context, backupRequest BackupRequest, jobSpecs []JobSpec) ([]string, error) - ListBackupResources(projectSpec ProjectSpec, datastoreName string) ([]BackupSpec, error) + ListBackupResources(ctx context.Context, projectSpec ProjectSpec, datastoreName string) ([]BackupSpec, error) } diff --git a/models/instance.go b/models/instance.go index 027e676936..67357ec81a 100644 --- a/models/instance.go +++ b/models/instance.go @@ -1,6 +1,7 @@ package models import ( + "context" "encoding/json" "fmt" "strings" @@ -130,16 +131,16 @@ func (j *InstanceSpec) DataToJSON() ([]byte, error) { type RunService interface { // GetScheduledRun find if already present or create a new scheduled run - GetScheduledRun(namespace NamespaceSpec, JobID JobSpec, scheduledAt time.Time) (JobRun, error) + GetScheduledRun(ctx context.Context, namespace NamespaceSpec, JobID JobSpec, scheduledAt time.Time) (JobRun, error) // GetByID returns job run, normally gets requested for manual runs - GetByID(JobRunID uuid.UUID) (JobRun, NamespaceSpec, error) + GetByID(ctx context.Context, JobRunID uuid.UUID) (JobRun, NamespaceSpec, error) // Register creates a new instance in provided job run - Register(namespace NamespaceSpec, jobRun JobRun, instanceType InstanceType, instanceName string) (InstanceSpec, error) + Register(ctx context.Context, namespace NamespaceSpec, jobRun JobRun, instanceType InstanceType, instanceName string) (InstanceSpec, error) // Compile prepares instance execution context environment - Compile(namespaceSpec NamespaceSpec, jobRun JobRun, instanceSpec InstanceSpec) (envMap map[string]string, + Compile(ctx context.Context, namespaceSpec NamespaceSpec, jobRun JobRun, instanceSpec InstanceSpec) (envMap map[string]string, fileMap map[string]string, err error) } diff --git a/models/job.go b/models/job.go index 63ba084895..332dea775f 100644 --- a/models/job.go +++ b/models/job.go @@ -296,13 +296,13 @@ type JobSpecDependency struct { // JobService provides a high-level operations on DAGs type JobService interface { // Create constructs a Job and commits it to a storage - Create(NamespaceSpec, JobSpec) error + Create(context.Context, NamespaceSpec, JobSpec) error // GetByName fetches a Job by name for a specific namespace - GetByName(string, NamespaceSpec) (JobSpec, error) + GetByName(context.Context, string, NamespaceSpec) (JobSpec, error) // KeepOnly deletes all jobs except the ones provided for a namespace - KeepOnly(NamespaceSpec, []JobSpec, progress.Observer) error + KeepOnly(context.Context, NamespaceSpec, []JobSpec, progress.Observer) error // GetAll reads all job specifications of the given namespace - GetAll(NamespaceSpec) ([]JobSpec, error) + GetAll(context.Context, NamespaceSpec) ([]JobSpec, error) // Delete deletes a job spec from all repos Delete(context.Context, NamespaceSpec, JobSpec) error @@ -311,7 +311,7 @@ type JobService interface { Run(context.Context, NamespaceSpec, []JobSpec, progress.Observer) error // GetByNameForProject fetches a Job by name for a specific project - GetByNameForProject(string, ProjectSpec) (JobSpec, NamespaceSpec, error) + GetByNameForProject(context.Context, string, ProjectSpec) (JobSpec, NamespaceSpec, error) Sync(context.Context, NamespaceSpec, progress.Observer) error Check(context.Context, NamespaceSpec, []JobSpec, progress.Observer) error // ReplayDryRun returns the execution tree of jobSpec and its dependencies between start and endDate @@ -321,9 +321,9 @@ type JobService interface { // GetReplayStatus of a replay using its ID GetReplayStatus(context.Context, ReplayRequest) (ReplayState, error) //GetReplayList of a project - GetReplayList(projectID uuid.UUID) ([]ReplaySpec, error) + GetReplayList(ctx context.Context, projectID uuid.UUID) ([]ReplaySpec, error) // GetByDestination fetches a Job by destination for a specific project - GetByDestination(projectSpec ProjectSpec, destination string) (JobSpec, error) + GetByDestination(ctx context.Context, projectSpec ProjectSpec, destination string) (JobSpec, error) // GetDownstream fetches downstream jobspecs GetDownstream(ctx context.Context, projectSpec ProjectSpec, jobName string) ([]JobSpec, error) } diff --git a/run/service.go b/run/service.go index cd32279ebc..c4b0276836 100644 --- a/run/service.go +++ b/run/service.go @@ -36,43 +36,41 @@ type Service struct { templateEngine models.TemplateEngine } -func (s *Service) Compile(namespace models.NamespaceSpec, jobRun models.JobRun, instanceSpec models.InstanceSpec) ( +func (s *Service) Compile(ctx context.Context, namespace models.NamespaceSpec, jobRun models.JobRun, instanceSpec models.InstanceSpec) ( envMap map[string]string, fileMap map[string]string, err error) { return NewContextManager(namespace, jobRun, s.templateEngine).Generate(instanceSpec) } -func (s *Service) GetScheduledRun(namespace models.NamespaceSpec, jobSpec models.JobSpec, +func (s *Service) GetScheduledRun(ctx context.Context, namespace models.NamespaceSpec, jobSpec models.JobSpec, scheduledAt time.Time) (models.JobRun, error) { newJobRun := models.JobRun{ Spec: jobSpec, Trigger: models.TriggerSchedule, Status: models.RunStatePending, ScheduledAt: scheduledAt, - Instances: nil, } repo := s.repoFac.New() - jobRun, _, err := repo.GetByScheduledAt(jobSpec.ID, scheduledAt) + jobRun, _, err := repo.GetByScheduledAt(ctx, jobSpec.ID, scheduledAt) if err == nil || err == store.ErrResourceNotFound { // create a new instance if it does not already exists if err == nil { // if already exists, use the same id for in place update // because job spec might have changed by now, status needs to be reset newJobRun.ID = jobRun.ID - newJobRun.Instances = jobRun.Instances } - if err := repo.Save(namespace, newJobRun); err != nil { + if err := repo.Save(ctx, namespace, newJobRun); err != nil { return models.JobRun{}, err } } else { return models.JobRun{}, err } - jobRun, _, err = repo.GetByScheduledAt(jobSpec.ID, scheduledAt) + jobRun, _, err = repo.GetByScheduledAt(ctx, jobSpec.ID, scheduledAt) return jobRun, err } -func (s *Service) Register(namespace models.NamespaceSpec, jobRun models.JobRun, +func (s *Service) Register(ctx context.Context, namespace models.NamespaceSpec, jobRun models.JobRun, instanceType models.InstanceType, instanceName string) (models.InstanceSpec, error) { executedAt := s.Now() if len(jobRun.Instances) > 0 { @@ -92,10 +90,10 @@ func (s *Service) Register(namespace models.NamespaceSpec, jobRun models.JobRun, switch instanceType { case models.InstanceTypeTask: // clear and save fresh - if err := jobRunRepo.ClearInstance(jobRun.ID, instanceType, instanceName); err != nil && !errors.Is(err, store.ErrResourceNotFound) { + if err := jobRunRepo.ClearInstance(ctx, jobRun.ID, instanceType, instanceName); err != nil && !errors.Is(err, store.ErrResourceNotFound) { return models.InstanceSpec{}, errors.Wrapf(err, "Register: failed to clear instance of job %s", jobRun) } - if err := jobRunRepo.AddInstance(namespace, jobRun, instanceToSave); err != nil { + if err := jobRunRepo.AddInstance(ctx, namespace, jobRun, instanceToSave); err != nil { return models.InstanceSpec{}, err } case models.InstanceTypeHook: @@ -108,7 +106,7 @@ func (s *Service) Register(namespace models.NamespaceSpec, jobRun models.JobRun, } } if !exists { - if err := jobRunRepo.AddInstance(namespace, jobRun, instanceToSave); err != nil { + if err := jobRunRepo.AddInstance(ctx, namespace, jobRun, instanceToSave); err != nil { return models.InstanceSpec{}, err } } @@ -117,7 +115,7 @@ func (s *Service) Register(namespace models.NamespaceSpec, jobRun models.JobRun, } // get whatever is saved, querying again ensures it was saved correctly - if jobRun, _, err = jobRunRepo.GetByID(jobRun.ID); err != nil { + if jobRun, _, err = jobRunRepo.GetByID(ctx, jobRun.ID); err != nil { return models.InstanceSpec{}, errors.Wrapf(err, "failed to save instance for %s of %s:%s", jobRun, instanceName, instanceType) } @@ -169,8 +167,8 @@ func (s *Service) prepInstance(jobRun models.JobRun, instanceType models.Instanc }, nil } -func (s *Service) GetByID(JobRunID uuid.UUID) (models.JobRun, models.NamespaceSpec, error) { - return s.repoFac.New().GetByID(JobRunID) +func (s *Service) GetByID(ctx context.Context, JobRunID uuid.UUID) (models.JobRun, models.NamespaceSpec, error) { + return s.repoFac.New().GetByID(ctx, JobRunID) } func NewService(repoFac SpecRepoFactory, timeFunc func() time.Time, te models.TemplateEngine) *Service { diff --git a/run/service_test.go b/run/service_test.go index 368e075077..2647a4fe4b 100644 --- a/run/service_test.go +++ b/run/service_test.go @@ -18,10 +18,11 @@ import ( ) func TestService(t *testing.T) { + ctx := context.Background() execUnit := new(mock.BasePlugin) execUnit.On("PluginInfo").Return(&models.PluginInfoResponse{Name: "bq"}, nil) depMod := new(mock.DependencyResolverMod) - depMod.On("GenerateDestination", context.TODO(), mock2.AnythingOfType("models.GenerateDestinationRequest")).Return( + depMod.On("GenerateDestination", ctx, mock2.AnythingOfType("models.GenerateDestinationRequest")).Return( &models.GenerateDestinationResponse{Destination: "proj.data.tab"}, nil) jobSpec := models.JobSpec{ Name: "foo", @@ -96,12 +97,12 @@ func TestService(t *testing.T) { } runRepo := new(mock.JobRunRepository) - runRepo.On("ClearInstance", jobRun.ID, instanceSpec.Type, instanceSpec.Name).Return(nil) - runRepo.On("AddInstance", namespaceSpec, jobRun, instanceSpec).Return(nil) + runRepo.On("ClearInstance", ctx, jobRun.ID, instanceSpec.Type, instanceSpec.Name).Return(nil) + runRepo.On("AddInstance", ctx, namespaceSpec, jobRun, instanceSpec).Return(nil) localRun := jobRun localRun.Instances = append(jobRun.Instances, instanceSpec) - runRepo.On("GetByID", jobRun.ID).Return(localRun, namespaceSpec, nil) + runRepo.On("GetByID", ctx, jobRun.ID).Return(localRun, namespaceSpec, nil) defer runRepo.AssertExpectations(t) jobRunSpecRep := new(mock.JobRunRepoFactory) @@ -109,7 +110,7 @@ func TestService(t *testing.T) { defer jobRunSpecRep.AssertExpectations(t) runService := run.NewService(jobRunSpecRep, mockedTimeFunc, nil) - returnedInstanceSpec, err := runService.Register(namespaceSpec, jobRun, models.InstanceTypeTask, "bq") + returnedInstanceSpec, err := runService.Register(ctx, namespaceSpec, jobRun, models.InstanceTypeTask, "bq") assert.Nil(t, err) assert.Equal(t, instanceSpec, returnedInstanceSpec) }) @@ -144,10 +145,10 @@ func TestService(t *testing.T) { } runRepo := new(mock.JobRunRepository) - runRepo.On("AddInstance", namespaceSpec, jobRun, instanceSpec).Return(nil) + runRepo.On("AddInstance", ctx, namespaceSpec, jobRun, instanceSpec).Return(nil) localRun := jobRun localRun.Instances = append(jobRun.Instances, instanceSpec) - runRepo.On("GetByID", jobRun.ID).Return(localRun, namespaceSpec, nil) + runRepo.On("GetByID", ctx, jobRun.ID).Return(localRun, namespaceSpec, nil) defer runRepo.AssertExpectations(t) jobRunSpecRep := new(mock.JobRunRepoFactory) @@ -156,7 +157,7 @@ func TestService(t *testing.T) { runService := run.NewService(jobRunSpecRep, mockedTimeFunc, nil) - returnedInstanceSpec, err := runService.Register(namespaceSpec, jobRun, instanceSpec.Type, instanceSpec.Name) + returnedInstanceSpec, err := runService.Register(ctx, namespaceSpec, jobRun, instanceSpec.Type, instanceSpec.Name) assert.Nil(t, err) assert.Equal(t, returnedInstanceSpec, instanceSpec) }) @@ -193,7 +194,7 @@ func TestService(t *testing.T) { runRepo := new(mock.JobRunRepository) localRun := jobRun localRun.Instances = append(jobRun.Instances, instanceSpec) - runRepo.On("GetByID", jobRun.ID).Return(localRun, namespaceSpec, nil) + runRepo.On("GetByID", ctx, jobRun.ID).Return(localRun, namespaceSpec, nil) defer runRepo.AssertExpectations(t) jobRunSpecRep := new(mock.JobRunRepoFactory) @@ -202,7 +203,7 @@ func TestService(t *testing.T) { runService := run.NewService(jobRunSpecRep, mockedTimeFunc, nil) - returnedInstanceSpec, err := runService.Register(namespaceSpec, localRun, instanceSpec.Type, instanceSpec.Name) + returnedInstanceSpec, err := runService.Register(ctx, namespaceSpec, localRun, instanceSpec.Type, instanceSpec.Name) assert.Nil(t, err) assert.Equal(t, returnedInstanceSpec, instanceSpec) }) @@ -239,7 +240,7 @@ func TestService(t *testing.T) { runRepo := new(mock.JobRunRepository) localRun := jobRun localRun.Instances = append(jobRun.Instances, instanceSpec) - runRepo.On("GetByID", jobRun.ID).Return(localRun, namespaceSpec, nil) + runRepo.On("GetByID", ctx, jobRun.ID).Return(localRun, namespaceSpec, nil) defer runRepo.AssertExpectations(t) jobRunSpecRep := new(mock.JobRunRepoFactory) @@ -247,7 +248,7 @@ func TestService(t *testing.T) { defer jobRunSpecRep.AssertExpectations(t) runService := run.NewService(jobRunSpecRep, time.Now().UTC, nil) - returnedInstanceSpec, err := runService.Register(namespaceSpec, localRun, instanceSpec.Type, instanceSpec.Name) + returnedInstanceSpec, err := runService.Register(ctx, namespaceSpec, localRun, instanceSpec.Type, instanceSpec.Name) assert.Nil(t, err) assert.Equal(t, returnedInstanceSpec, instanceSpec) }) @@ -282,8 +283,8 @@ func TestService(t *testing.T) { } runRepo := new(mock.JobRunRepository) - runRepo.On("ClearInstance", jobRun.ID, instanceSpec.Type, instanceSpec.Name).Return(nil) - runRepo.On("AddInstance", namespaceSpec, jobRun, instanceSpec).Return(errors.New("a random error")) + runRepo.On("ClearInstance", ctx, jobRun.ID, instanceSpec.Type, instanceSpec.Name).Return(nil) + runRepo.On("AddInstance", ctx, namespaceSpec, jobRun, instanceSpec).Return(errors.New("a random error")) defer runRepo.AssertExpectations(t) jobRunSpecRep := new(mock.JobRunRepoFactory) @@ -292,7 +293,7 @@ func TestService(t *testing.T) { runService := run.NewService(jobRunSpecRep, mockedTimeFunc, nil) - returnedInstanceSpec, err := runService.Register(namespaceSpec, jobRun, instanceSpec.Type, instanceSpec.Name) + returnedInstanceSpec, err := runService.Register(ctx, namespaceSpec, jobRun, instanceSpec.Type, instanceSpec.Name) assert.Equal(t, "a random error", err.Error()) assert.Equal(t, models.InstanceSpec{}, returnedInstanceSpec) }) @@ -300,8 +301,8 @@ func TestService(t *testing.T) { t.Run("GetScheduledRun", func(t *testing.T) { t.Run("should update job run even if already exists", func(t *testing.T) { runRepo := new(mock.JobRunRepository) - runRepo.On("GetByScheduledAt", jobSpec.ID, scheduledAt).Return(jobRun, namespaceSpec, nil) - runRepo.On("Save", namespaceSpec, models.JobRun{ + runRepo.On("GetByScheduledAt", ctx, jobSpec.ID, scheduledAt).Return(jobRun, namespaceSpec, nil) + runRepo.On("Save", ctx, namespaceSpec, models.JobRun{ ID: jobRun.ID, Spec: jobSpec, Trigger: models.TriggerSchedule, @@ -315,14 +316,14 @@ func TestService(t *testing.T) { defer jobRunSpecRep.AssertExpectations(t) runService := run.NewService(jobRunSpecRep, mockedTimeFunc, nil) - returnedSpec, err := runService.GetScheduledRun(namespaceSpec, jobSpec, scheduledAt) + returnedSpec, err := runService.GetScheduledRun(ctx, namespaceSpec, jobSpec, scheduledAt) assert.Nil(t, err) assert.Equal(t, jobRun, returnedSpec) }) t.Run("should save a new job run if doesn't exists", func(t *testing.T) { runRepo := new(mock.JobRunRepository) - runRepo.On("GetByScheduledAt", jobSpec.ID, scheduledAt).Return(models.JobRun{}, models.NamespaceSpec{}, store.ErrResourceNotFound) - runRepo.On("Save", namespaceSpec, models.JobRun{ + runRepo.On("GetByScheduledAt", ctx, jobSpec.ID, scheduledAt).Return(models.JobRun{}, models.NamespaceSpec{}, store.ErrResourceNotFound) + runRepo.On("Save", ctx, namespaceSpec, models.JobRun{ Spec: jobSpec, Trigger: models.TriggerSchedule, Status: models.RunStatePending, @@ -335,11 +336,11 @@ func TestService(t *testing.T) { defer jobRunSpecRep.AssertExpectations(t) runService := run.NewService(jobRunSpecRep, mockedTimeFunc, nil) - _, _ = runService.GetScheduledRun(namespaceSpec, jobSpec, scheduledAt) + _, _ = runService.GetScheduledRun(ctx, namespaceSpec, jobSpec, scheduledAt) }) t.Run("should return empty RunSpec if GetByScheduledAt returns an error", func(t *testing.T) { runRepo := new(mock.JobRunRepository) - runRepo.On("GetByScheduledAt", jobSpec.ID, scheduledAt).Return(models.JobRun{}, models.NamespaceSpec{}, errors.New("a random error")) + runRepo.On("GetByScheduledAt", ctx, jobSpec.ID, scheduledAt).Return(models.JobRun{}, models.NamespaceSpec{}, errors.New("a random error")) defer runRepo.AssertExpectations(t) jobRunSpecRep := new(mock.JobRunRepoFactory) @@ -347,7 +348,7 @@ func TestService(t *testing.T) { defer jobRunSpecRep.AssertExpectations(t) runService := run.NewService(jobRunSpecRep, mockedTimeFunc, nil) - returnedSpec, err := runService.GetScheduledRun(namespaceSpec, jobSpec, scheduledAt) + returnedSpec, err := runService.GetScheduledRun(ctx, namespaceSpec, jobSpec, scheduledAt) assert.Equal(t, "a random error", err.Error()) assert.Equal(t, models.JobRun{}, returnedSpec) }) diff --git a/store/local/resource_spec_repository.go b/store/local/resource_spec_repository.go index d9af039b57..1d7d0c7716 100644 --- a/store/local/resource_spec_repository.go +++ b/store/local/resource_spec_repository.go @@ -1,6 +1,7 @@ package local import ( + "context" "fmt" "io/ioutil" "os" @@ -75,7 +76,7 @@ func (repo *resourceRepository) SaveAt(resourceSpec models.ResourceSpec, rootDir return nil } -func (repo *resourceRepository) Save(resourceSpec models.ResourceSpec) error { +func (repo *resourceRepository) Save(ctx context.Context, resourceSpec models.ResourceSpec) error { if resourceSpec.Name == "" { return errors.New("invalid job name") } @@ -96,7 +97,7 @@ func (repo *resourceRepository) Save(resourceSpec models.ResourceSpec) error { } // GetAll finds all the resources recursively in current and sub directory -func (repo *resourceRepository) GetAll() ([]models.ResourceSpec, error) { +func (repo *resourceRepository) GetAll(ctx context.Context) ([]models.ResourceSpec, error) { var resourceSpecs []models.ResourceSpec if repo.cache.dirty { if err := repo.refreshCache(); err != nil { @@ -114,7 +115,7 @@ func (repo *resourceRepository) GetAll() ([]models.ResourceSpec, error) { } // GetByName returns a job requested by the name -func (repo *resourceRepository) GetByName(jobName string) (models.ResourceSpec, error) { +func (repo *resourceRepository) GetByName(ctx context.Context, jobName string) (models.ResourceSpec, error) { if strings.TrimSpace(jobName) == "" { return models.ResourceSpec{}, errors.Errorf("resource name cannot be an empty string") } @@ -135,7 +136,7 @@ func (repo *resourceRepository) GetByName(jobName string) (models.ResourceSpec, } // GetByURN returns a job requested by URN -func (repo *resourceRepository) GetByURN(urn string) (models.ResourceSpec, error) { +func (repo *resourceRepository) GetByURN(ctx context.Context, urn string) (models.ResourceSpec, error) { if strings.TrimSpace(urn) == "" { return models.ResourceSpec{}, errors.Errorf("resource urn cannot be an empty string") } @@ -156,7 +157,7 @@ func (repo *resourceRepository) GetByURN(urn string) (models.ResourceSpec, error } // Delete deletes a requested job by name -func (repo *resourceRepository) Delete(jobName string) error { +func (repo *resourceRepository) Delete(ctx context.Context, jobName string) error { panic("unimplemented") } diff --git a/store/local/resource_spec_repository_test.go b/store/local/resource_spec_repository_test.go index c39e6011af..5349463064 100644 --- a/store/local/resource_spec_repository_test.go +++ b/store/local/resource_spec_repository_test.go @@ -1,6 +1,7 @@ package local_test import ( + "context" "path/filepath" "sort" "testing" @@ -28,6 +29,7 @@ spec: ` func TestResourceSpecRepository(t *testing.T) { + ctx := context.Background() // prepare mocked datastore dsTypeTableAdapter := new(mock.DatastoreTypeAdapter) @@ -71,7 +73,7 @@ func TestResourceSpecRepository(t *testing.T) { appFS := afero.NewMemMapFs() repo := local.NewResourceSpecRepository(appFS, datastorer) - err := repo.Save(specTable) + err := repo.Save(ctx, specTable) assert.Nil(t, err) buf, err := afero.ReadFile(appFS, filepath.Join(specTable.Name, local.ResourceSpecFileName)) @@ -90,7 +92,7 @@ func TestResourceSpecRepository(t *testing.T) { }) t.Run("should return error if name is empty", func(t *testing.T) { repo := local.NewResourceSpecRepository(nil, datastorer) - err := repo.Save(models.ResourceSpec{}) + err := repo.Save(ctx, models.ResourceSpec{}) assert.NotNil(t, err) }) }) @@ -104,7 +106,7 @@ func TestResourceSpecRepository(t *testing.T) { afero.WriteFile(appFS, filepath.Join(specTable.Name, "query.sql"), []byte(specTable.Assets["query.sql"]), 0644) repo := local.NewResourceSpecRepository(appFS, datastorer) - returnedSpec, err := repo.GetByName(specTable.Name) + returnedSpec, err := repo.GetByName(ctx, specTable.Name) assert.Nil(t, err) assert.Equal(t, specTable, returnedSpec) }) @@ -116,20 +118,20 @@ func TestResourceSpecRepository(t *testing.T) { afero.WriteFile(appFS, filepath.Join(specTable.Name, "query.sql"), []byte(specTable.Assets["query.sql"]), 0644) repo := local.NewResourceSpecRepository(appFS, datastorer) - returnedSpec, err := repo.GetByName(specTable.Name) + returnedSpec, err := repo.GetByName(ctx, specTable.Name) assert.Nil(t, err) assert.Equal(t, specTable, returnedSpec) // delete all specs assert.Nil(t, appFS.RemoveAll(specTable.Name)) - returnedSpecAgain, err := repo.GetByName(specTable.Name) + returnedSpecAgain, err := repo.GetByName(ctx, specTable.Name) assert.Nil(t, err) assert.Equal(t, specTable, returnedSpecAgain) }) t.Run("should return ErrNoSuchSpec in case no job folder exist", func(t *testing.T) { repo := local.NewResourceSpecRepository(afero.NewMemMapFs(), datastorer) - _, err := repo.GetByName(specTable.Name) + _, err := repo.GetByName(ctx, specTable.Name) assert.Equal(t, models.ErrNoSuchSpec, err) }) t.Run("should return ErrNoSuchSpec in case the folder exist but no resource file exist", func(t *testing.T) { @@ -137,12 +139,12 @@ func TestResourceSpecRepository(t *testing.T) { appFS.MkdirAll(specTable.Name, 0755) repo := local.NewResourceSpecRepository(appFS, datastorer) - _, err := repo.GetByName(specTable.Name) + _, err := repo.GetByName(ctx, specTable.Name) assert.Equal(t, models.ErrNoSuchSpec, err) }) t.Run("should return an error if name is empty", func(t *testing.T) { repo := local.NewResourceSpecRepository(afero.NewMemMapFs(), nil) - _, err := repo.GetByName("") + _, err := repo.GetByName(ctx, "") assert.NotNil(t, err) }) t.Run("should return error if yaml source is incorrect and failed to validate", func(t *testing.T) { @@ -153,7 +155,7 @@ func TestResourceSpecRepository(t *testing.T) { afero.WriteFile(appFS, filepath.Join(specTable.Name, "query.sql"), []byte(specTable.Assets["query.sql"]), 0644) repo := local.NewResourceSpecRepository(appFS, datastorer) - _, err := repo.GetByName(specTable.Name) + _, err := repo.GetByName(ctx, specTable.Name) assert.NotNil(t, err) }) }) @@ -212,7 +214,7 @@ spec: } repo := local.NewResourceSpecRepository(appFS, datastorer) - result, err := repo.GetAll() + result, err := repo.GetAll(ctx) assert.Nil(t, err) assert.Equal(t, len(resSpecs), len(result)) @@ -222,7 +224,7 @@ spec: }) t.Run("should return ErrNoResources if the root directory does not exist", func(t *testing.T) { repo := local.NewResourceSpecRepository(afero.NewMemMapFs(), datastorer) - _, err := repo.GetAll() + _, err := repo.GetAll(ctx) assert.Equal(t, models.ErrNoResources, err) }) t.Run("should return ErrNoResources if the root directory has no files", func(t *testing.T) { @@ -230,7 +232,7 @@ spec: appFS.MkdirAll("test", 0755) repo := local.NewResourceSpecRepository(appFS, datastorer) - _, err := repo.GetAll() + _, err := repo.GetAll(ctx) assert.Equal(t, models.ErrNoResources, err) }) t.Run("should use cache to return specs if called more than once", func(t *testing.T) { @@ -242,7 +244,7 @@ spec: } repo := local.NewResourceSpecRepository(appFS, datastorer) - result, err := repo.GetAll() + result, err := repo.GetAll(ctx) sort.Slice(result, func(i, j int) bool { return result[i].Name > result[j].Name }) assert.Nil(t, err) assert.Equal(t, resSpecs, result) @@ -250,7 +252,7 @@ spec: // clear fs assert.Nil(t, appFS.RemoveAll(".")) - resultAgain, err := repo.GetAll() + resultAgain, err := repo.GetAll(ctx) assert.Nil(t, err) assert.Equal(t, len(result), len(resultAgain)) }) diff --git a/store/postgres/backup_repository.go b/store/postgres/backup_repository.go index 635a5ae5fe..6f60601171 100644 --- a/store/postgres/backup_repository.go +++ b/store/postgres/backup_repository.go @@ -1,15 +1,16 @@ package postgres import ( + "context" "encoding/json" "time" "gorm.io/datatypes" "github.com/google/uuid" - "github.com/jinzhu/gorm" "github.com/odpf/optimus/models" "github.com/pkg/errors" + "gorm.io/gorm" ) type BackupDetail struct { @@ -59,7 +60,7 @@ func (b Backup) FromSpec(backupSpec models.BackupSpec) (Backup, error) { }, nil } -func (repo *backupRepository) Save(spec models.BackupSpec) error { +func (repo *backupRepository) Save(ctx context.Context, spec models.BackupSpec) error { if len(spec.Resource.ID) == 0 { return errors.New("resource cannot be empty") } @@ -67,7 +68,7 @@ func (repo *backupRepository) Save(spec models.BackupSpec) error { if err != nil { return err } - return repo.db.Create(&p).Error + return repo.db.WithContext(ctx).Create(&p).Error } func (b Backup) ToSpec(ds models.Datastorer) (models.BackupSpec, error) { @@ -91,10 +92,10 @@ func (b Backup) ToSpec(ds models.Datastorer) (models.BackupSpec, error) { }, nil } -func (repo *backupRepository) GetAll() ([]models.BackupSpec, error) { +func (repo *backupRepository) GetAll(ctx context.Context) ([]models.BackupSpec, error) { var specs []models.BackupSpec var backups []Backup - if err := repo.db.Preload("Resource").Joins("JOIN resource ON backup.resource_id = resource.id"). + if err := repo.db.WithContext(ctx).Preload("Resource").Joins("JOIN resource ON backup.resource_id = resource.id"). Where("resource.project_id = ?", repo.project.ID).Find(&backups).Error; err != nil { return specs, err } diff --git a/store/postgres/backup_repository_test.go b/store/postgres/backup_repository_test.go index b4a0abf51c..48f5d8732e 100644 --- a/store/postgres/backup_repository_test.go +++ b/store/postgres/backup_repository_test.go @@ -3,6 +3,7 @@ package postgres import ( + "context" "fmt" "os" "testing" @@ -11,9 +12,9 @@ import ( testMock "github.com/stretchr/testify/mock" "github.com/google/uuid" - "github.com/jinzhu/gorm" "github.com/odpf/optimus/models" "github.com/stretchr/testify/assert" + "gorm.io/gorm" ) func TestBackupRepository(t *testing.T) { @@ -25,6 +26,7 @@ func TestBackupRepository(t *testing.T) { }, } hash, _ := models.NewApplicationSecret("32charshtesthashtesthashtesthash") + ctx := context.Background() // prepare mocked datastore dsTypeTableAdapter := new(mock.DatastoreTypeAdapter) @@ -58,7 +60,7 @@ func TestBackupRepository(t *testing.T) { } projRepo := NewProjectRepository(dbConn, hash) - assert.Nil(t, projRepo.Save(projectSpec)) + assert.Nil(t, projRepo.Save(ctx, projectSpec)) return dbConn } @@ -70,7 +72,8 @@ func TestBackupRepository(t *testing.T) { t.Run("Save", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() resourceSpec := models.ResourceSpec{ ID: uuid.Must(uuid.NewRandom()), @@ -112,14 +115,14 @@ func TestBackupRepository(t *testing.T) { projectResourceSpecRepo := NewProjectResourceSpecRepository(db, projectSpec, datastorer) resourceRepo := NewResourceSpecRepository(db, namespaceSpec, datastorer, projectResourceSpecRepo) - err := resourceRepo.Insert(resourceSpec) + err := resourceRepo.Insert(ctx, resourceSpec) assert.Nil(t, err) backupRepo := NewBackupRepository(db, projectSpec, datastorer) - err = backupRepo.Save(backupSpec) + err = backupRepo.Save(ctx, backupSpec) assert.Nil(t, err) - backups, err := backupRepo.GetAll() + backups, err := backupRepo.GetAll(ctx) assert.Nil(t, err) assert.Equal(t, backupSpec.ID, backups[0].ID) diff --git a/store/postgres/instance_repository.go b/store/postgres/instance_repository.go index 0a6ce1cdd6..6514107ba8 100644 --- a/store/postgres/instance_repository.go +++ b/store/postgres/instance_repository.go @@ -1,16 +1,17 @@ package postgres import ( + "context" "encoding/json" "time" "github.com/odpf/optimus/store" "github.com/google/uuid" - "github.com/jinzhu/gorm" "github.com/odpf/optimus/models" "github.com/pkg/errors" "gorm.io/datatypes" + "gorm.io/gorm" ) type Instance struct { @@ -82,18 +83,18 @@ type InstanceRepository struct { Now func() } -func (repo *InstanceRepository) Insert(run models.JobRun, spec models.InstanceSpec) error { +func (repo *InstanceRepository) Insert(ctx context.Context, run models.JobRun, spec models.InstanceSpec) error { resource, err := Instance{}.FromSpec(spec, run.ID) if err != nil { return err } - return repo.db.Omit("JobRun").Create(&resource).Error + return repo.db.WithContext(ctx).Omit("JobRun").Create(&resource).Error } -func (repo *InstanceRepository) Save(run models.JobRun, spec models.InstanceSpec) error { - existingResource, err := repo.GetByName(run.ID, spec.Name, spec.Type.String()) +func (repo *InstanceRepository) Save(ctx context.Context, run models.JobRun, spec models.InstanceSpec) error { + existingResource, err := repo.GetByName(ctx, run.ID, spec.Name, spec.Type.String()) if errors.Is(err, store.ErrResourceNotFound) { - return repo.Insert(run, spec) + return repo.Insert(ctx, run, spec) } else if err != nil { return errors.Wrap(err, "unable to find instance by schedule") } @@ -103,21 +104,21 @@ func (repo *InstanceRepository) Save(run models.JobRun, spec models.InstanceSpec return err } resource.ID = existingResource.ID - return repo.db.Debug().Omit("JobRun").Model(&resource).Updates(&resource).Error + return repo.db.WithContext(ctx).Debug().Omit("JobRun").Model(&resource).Updates(&resource).Error } -func (repo *InstanceRepository) UpdateStatus(id uuid.UUID, status models.JobRunState) error { +func (repo *InstanceRepository) UpdateStatus(ctx context.Context, id uuid.UUID, status models.JobRunState) error { var r Instance - if err := repo.db.Where("id = ?", id).Find(&r).Error; err != nil { + if err := repo.db.WithContext(ctx).Where("id = ?", id).Find(&r).Error; err != nil { return err } r.Status = status.String() - return repo.db.Omit("JobRun").Save(&r).Error + return repo.db.WithContext(ctx).Omit("JobRun").Save(&r).Error } -func (repo *InstanceRepository) GetByName(runID uuid.UUID, instanceName, instanceType string) (models.InstanceSpec, error) { +func (repo *InstanceRepository) GetByName(ctx context.Context, runID uuid.UUID, instanceName, instanceType string) (models.InstanceSpec, error) { var r Instance - if err := repo.db.Preload("JobRun").Where("job_run_id = ? AND instance_name = ? AND instance_type = ?", runID, instanceName, instanceType).Find(&r).Error; err != nil { + if err := repo.db.WithContext(ctx).Preload("JobRun").Where("job_run_id = ? AND instance_name = ? AND instance_type = ?", runID, instanceName, instanceType).First(&r).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return models.InstanceSpec{}, store.ErrResourceNotFound } @@ -126,9 +127,9 @@ func (repo *InstanceRepository) GetByName(runID uuid.UUID, instanceName, instanc return r.ToSpec() } -func (repo *InstanceRepository) GetByID(id uuid.UUID) (models.InstanceSpec, error) { +func (repo *InstanceRepository) GetByID(ctx context.Context, id uuid.UUID) (models.InstanceSpec, error) { var r Instance - if err := repo.db.Preload("JobRun").Where("id = ?", id).Find(&r).Error; err != nil { + if err := repo.db.WithContext(ctx).Preload("JobRun").Where("id = ?", id).First(&r).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return models.InstanceSpec{}, store.ErrResourceNotFound } @@ -137,12 +138,20 @@ func (repo *InstanceRepository) GetByID(id uuid.UUID) (models.InstanceSpec, erro return r.ToSpec() } -func (repo *InstanceRepository) Delete(id uuid.UUID) error { - return repo.db.Where("id = ?", id).Delete(&Instance{}).Error +func (repo *InstanceRepository) Delete(ctx context.Context, id uuid.UUID) error { + return repo.db.WithContext(ctx).Where("id = ?", id).Delete(&Instance{}).Error } -func (repo *InstanceRepository) DeleteByJobRun(runID uuid.UUID) error { - return repo.db.Where("job_run_id = ?", runID).Delete(&Instance{}).Error +func (repo *InstanceRepository) DeleteByJobRun(ctx context.Context, runID uuid.UUID) error { + return repo.db.WithContext(ctx).Where("job_run_id = ?", runID).Delete(&Instance{}).Error +} + +func (repo *InstanceRepository) GetByJobRun(ctx context.Context, runID uuid.UUID) ([]Instance, error) { + var r []Instance + if err := repo.db.WithContext(ctx).Where("job_run_id = ?", runID).Find(&r).Error; err != nil { + return nil, err + } + return r, nil } func NewInstanceRepository(db *gorm.DB, jobAdapter *JobSpecAdapter) *InstanceRepository { diff --git a/store/postgres/instance_repository_test.go b/store/postgres/instance_repository_test.go index be652691d6..56d16ea6ff 100644 --- a/store/postgres/instance_repository_test.go +++ b/store/postgres/instance_repository_test.go @@ -9,10 +9,10 @@ import ( "time" "github.com/google/uuid" - "github.com/jinzhu/gorm" "github.com/odpf/optimus/mock" "github.com/odpf/optimus/models" "github.com/stretchr/testify/assert" + "gorm.io/gorm" ) func TestInstanceRepository(t *testing.T) { @@ -28,6 +28,7 @@ func TestInstanceRepository(t *testing.T) { Name: "dev-team-1", ProjectSpec: projectSpec, } + ctx := context.Background() gTask := "g-task" tTask := "t-task" @@ -125,15 +126,15 @@ func TestInstanceRepository(t *testing.T) { hash, _ := models.NewApplicationSecret("32charshtesthashtesthashtesthash") prepo := NewProjectRepository(dbConn, hash) - assert.Nil(t, prepo.Save(projectSpec)) + assert.Nil(t, prepo.Save(ctx, projectSpec)) projectJobSpecRepo := NewProjectJobSpecRepository(dbConn, projectSpec, adapter) jrepo := NewJobSpecRepository(dbConn, namespaceSpec, projectJobSpecRepo, adapter) - assert.Nil(t, jrepo.Save(jobConfigs[0])) - assert.Equal(t, "task unit cannot be empty", jrepo.Save(jobConfigs[1]).Error()) + assert.Nil(t, jrepo.Save(ctx, jobConfigs[0])) + assert.Equal(t, "task unit cannot be empty", jrepo.Save(ctx, jobConfigs[1]).Error()) jobRunRepo := NewJobRunRepository(dbConn, adapter) - err = jobRunRepo.Save(namespaceSpec, jobRuns[0]) + err = jobRunRepo.Save(ctx, namespaceSpec, jobRuns[0]) assert.Nil(t, err) return dbConn } @@ -156,69 +157,72 @@ func TestInstanceRepository(t *testing.T) { t.Run("Insert", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() var testModels []models.InstanceSpec testModels = append(testModels, testSpecs...) repo := NewInstanceRepository(db, adapter) - err := repo.Insert(jobRuns[0], testModels[0]) + err := repo.Insert(ctx, jobRuns[0], testModels[0]) assert.Nil(t, err) - checkModel, err := repo.GetByID(testModels[0].ID) + checkModel, err := repo.GetByID(ctx, testModels[0].ID) assert.Nil(t, err) assert.Equal(t, testModels[0].Name, checkModel.Name) assert.Equal(t, testModels[0].Data, checkModel.Data) }) t.Run("Save", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() testModels := []models.InstanceSpec{} testModels = append(testModels, testSpecs...) repo := NewInstanceRepository(db, adapter) - err := repo.Insert(jobRuns[0], testModels[0]) + err := repo.Insert(ctx, jobRuns[0], testModels[0]) assert.Nil(t, err) - checkModel, err := repo.GetByID(testModels[0].ID) + checkModel, err := repo.GetByID(ctx, testModels[0].ID) assert.Nil(t, err) assert.Equal(t, testModels[0].Name, checkModel.Name) assert.Equal(t, testModels[0].Data, checkModel.Data) - err = repo.Delete(testModels[0].ID) + err = repo.Delete(ctx, testModels[0].ID) assert.Nil(t, err) testModels[0].Name = "updated-name" - err = repo.Save(jobRuns[0], testModels[0]) + err = repo.Save(ctx, jobRuns[0], testModels[0]) assert.Nil(t, err) - checkModel, err = repo.GetByID(testModels[0].ID) + checkModel, err = repo.GetByID(ctx, testModels[0].ID) assert.Nil(t, err) assert.Equal(t, "updated-name", checkModel.Name) assert.Equal(t, testModels[0].Data, checkModel.Data) }) t.Run("UpdateStatus", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() testModels := []models.InstanceSpec{} testModels = append(testModels, testSpecs...) repo := NewInstanceRepository(db, adapter) - err := repo.Save(jobRuns[0], testModels[0]) + err := repo.Save(ctx, jobRuns[0], testModels[0]) assert.Nil(t, err) - checkModel, err := repo.GetByID(testModels[0].ID) + checkModel, err := repo.GetByID(ctx, testModels[0].ID) assert.Nil(t, err) assert.Equal(t, testModels[0].Name, checkModel.Name) assert.Equal(t, testModels[0].Data, checkModel.Data) - err = repo.UpdateStatus(testModels[0].ID, models.RunStateFailed) + err = repo.UpdateStatus(ctx, testModels[0].ID, models.RunStateFailed) assert.Nil(t, err) - checkModel, err = repo.GetByID(testModels[0].ID) + checkModel, err = repo.GetByID(ctx, testModels[0].ID) assert.Nil(t, err) assert.Equal(t, models.RunStateFailed, checkModel.Status) }) diff --git a/store/postgres/job_spec_adapter.go b/store/postgres/job_spec_adapter.go index 9b2c70e6bb..18d43c7da6 100644 --- a/store/postgres/job_spec_adapter.go +++ b/store/postgres/job_spec_adapter.go @@ -5,6 +5,8 @@ import ( "encoding/json" "time" + "gorm.io/gorm" + "github.com/google/uuid" "github.com/odpf/optimus/models" "github.com/pkg/errors" @@ -44,7 +46,7 @@ type Job struct { CreatedAt time.Time `gorm:"not null" json:"created_at"` UpdatedAt time.Time `gorm:"not null" json:"updated_at"` - DeletedAt *time.Time + DeletedAt gorm.DeletedAt } type JobBehavior struct { @@ -373,10 +375,6 @@ func (adapt JobSpecAdapter) FromSpecWithNamespace(spec models.JobSpec, namespace type JobRun struct { ID uuid.UUID `gorm:"primary_key;type:uuid;"` - // TODO: I think we can delete this field, its kinda useless - // could be null for manual/adhoc jobs - JobID uuid.UUID `gorm:"type:uuid;"` - // job spec for which this run was created, spec should contain a valid // uuid if it belongs to a saved job and not an adhoc job Spec datatypes.JSON `gorm:"column:specification;"` @@ -388,7 +386,7 @@ type JobRun struct { Status string ScheduledAt time.Time - Instances []Instance `gorm:"polymorphic:JobRun;"` + Instances []Instance CreatedAt time.Time `gorm:"not null" json:"created_at"` UpdatedAt time.Time `gorm:"not null" json:"updated_at"` @@ -420,9 +418,8 @@ func (adapt JobSpecAdapter) FromJobRun(jr models.JobRun, nsSpec models.Namespace } return JobRun{ - ID: jr.ID, - JobID: jr.Spec.ID, - Spec: specBytes, + ID: jr.ID, + Spec: specBytes, NamespaceID: adaptNamespace.ID, Namespace: adaptNamespace, diff --git a/store/postgres/job_spec_repository.go b/store/postgres/job_spec_repository.go index 14e85b504e..f03503017a 100644 --- a/store/postgres/job_spec_repository.go +++ b/store/postgres/job_spec_repository.go @@ -1,13 +1,14 @@ package postgres import ( + "context" "fmt" "github.com/google/uuid" - "github.com/jinzhu/gorm" "github.com/odpf/optimus/models" "github.com/odpf/optimus/store" "github.com/pkg/errors" + "gorm.io/gorm" ) type ProjectJobSpecRepository struct { @@ -24,9 +25,9 @@ func NewProjectJobSpecRepository(db *gorm.DB, project models.ProjectSpec, adapte } } -func (repo *ProjectJobSpecRepository) GetByName(name string) (models.JobSpec, models.NamespaceSpec, error) { +func (repo *ProjectJobSpecRepository) GetByName(ctx context.Context, name string) (models.JobSpec, models.NamespaceSpec, error) { var r Job - if err := repo.db.Preload("Namespace").Where("project_id = ? AND name = ?", repo.project.ID, name).Find(&r).Error; err != nil { + if err := repo.db.WithContext(ctx).Preload("Namespace").Where("project_id = ? AND name = ?", repo.project.ID, name).First(&r).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return models.JobSpec{}, models.NamespaceSpec{}, store.ErrResourceNotFound } @@ -46,10 +47,10 @@ func (repo *ProjectJobSpecRepository) GetByName(name string) (models.JobSpec, mo return jobSpec, namespaceSpec, nil } -func (repo *ProjectJobSpecRepository) GetAll() ([]models.JobSpec, error) { +func (repo *ProjectJobSpecRepository) GetAll(ctx context.Context) ([]models.JobSpec, error) { specs := []models.JobSpec{} jobs := []Job{} - if err := repo.db.Where("project_id = ?", repo.project.ID).Find(&jobs).Error; err != nil { + if err := repo.db.WithContext(ctx).Where("project_id = ?", repo.project.ID).Find(&jobs).Error; err != nil { return specs, err } @@ -63,16 +64,16 @@ func (repo *ProjectJobSpecRepository) GetAll() ([]models.JobSpec, error) { return specs, nil } -func (repo *ProjectJobSpecRepository) GetByNameForProject(projName string, jobName string) (models.JobSpec, models.ProjectSpec, error) { +func (repo *ProjectJobSpecRepository) GetByNameForProject(ctx context.Context, projName string, jobName string) (models.JobSpec, models.ProjectSpec, error) { var r Job var p Project - if err := repo.db.Where("name = ?", projName).Find(&p).Error; err != nil { + if err := repo.db.WithContext(ctx).Where("name = ?", projName).First(&p).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return models.JobSpec{}, models.ProjectSpec{}, errors.Wrap(store.ErrResourceNotFound, "project not found") } return models.JobSpec{}, models.ProjectSpec{}, err } - if err := repo.db.Where("project_id = ? AND name = ?", p.ID, jobName).Find(&r).Error; err != nil { + if err := repo.db.WithContext(ctx).Where("project_id = ? AND name = ?", p.ID, jobName).First(&r).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return models.JobSpec{}, models.ProjectSpec{}, errors.Wrap(store.ErrResourceNotFound, "job spec not found") } @@ -92,9 +93,9 @@ func (repo *ProjectJobSpecRepository) GetByNameForProject(projName string, jobNa return jSpec, pSpec, err } -func (repo *ProjectJobSpecRepository) GetByDestination(destination string) (models.JobSpec, models.ProjectSpec, error) { +func (repo *ProjectJobSpecRepository) GetByDestination(ctx context.Context, destination string) (models.JobSpec, models.ProjectSpec, error) { var r Job - if err := repo.db.Preload("Project").Where("destination = ?", destination).Find(&r).Error; err != nil { + if err := repo.db.WithContext(ctx).Preload("Project").Where("destination = ?", destination).First(&r).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return models.JobSpec{}, models.ProjectSpec{}, store.ErrResourceNotFound } @@ -121,7 +122,7 @@ type JobSpecRepository struct { adapter *JobSpecAdapter } -func (repo *JobSpecRepository) Insert(spec models.JobSpec) error { +func (repo *JobSpecRepository) Insert(ctx context.Context, spec models.JobSpec) error { resource, err := repo.adapter.FromSpecWithNamespace(spec, repo.namespace) if err != nil { return err @@ -130,17 +131,17 @@ func (repo *JobSpecRepository) Insert(spec models.JobSpec) error { return errors.New("name cannot be empty") } // if soft deleted earlier - if err := repo.HardDelete(spec.Name); err != nil { + if err := repo.HardDelete(ctx, spec.Name); err != nil { return err } - return repo.db.Create(&resource).Error + return repo.db.WithContext(ctx).Create(&resource).Error } -func (repo *JobSpecRepository) Save(spec models.JobSpec) error { +func (repo *JobSpecRepository) Save(ctx context.Context, spec models.JobSpec) error { // while saving a JobSpec, we need to ensure that it's name is unique for a project - existingJobSpec, namespaceSpec, err := repo.projectJobSpecRepo.GetByName(spec.Name) + existingJobSpec, namespaceSpec, err := repo.projectJobSpecRepo.GetByName(ctx, spec.Name) if errors.Is(err, store.ErrResourceNotFound) { - return repo.Insert(spec) + return repo.Insert(ctx, spec) } else if err != nil { return errors.Wrap(err, "unable to retrieve spec by name") } @@ -154,12 +155,12 @@ func (repo *JobSpecRepository) Save(spec models.JobSpec) error { return err } resource.ID = existingJobSpec.ID - return repo.db.Model(&resource).Updates(&resource).Error + return repo.db.WithContext(ctx).Model(&resource).Updates(&resource).Error } -func (repo *JobSpecRepository) GetByID(id uuid.UUID) (models.JobSpec, error) { +func (repo *JobSpecRepository) GetByID(ctx context.Context, id uuid.UUID) (models.JobSpec, error) { var r Job - if err := repo.db.Where("namespace_id = ? AND id = ?", repo.namespace.ID, id).Find(&r).Error; err != nil { + if err := repo.db.WithContext(ctx).Where("namespace_id = ? AND id = ?", repo.namespace.ID, id).First(&r).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return models.JobSpec{}, store.ErrResourceNotFound } @@ -169,9 +170,9 @@ func (repo *JobSpecRepository) GetByID(id uuid.UUID) (models.JobSpec, error) { return repo.adapter.ToSpec(r) } -func (repo *JobSpecRepository) GetByName(name string) (models.JobSpec, error) { +func (repo *JobSpecRepository) GetByName(ctx context.Context, name string) (models.JobSpec, error) { var r Job - if err := repo.db.Where("namespace_id = ? AND name = ?", repo.namespace.ID, name).Find(&r).Error; err != nil { + if err := repo.db.WithContext(ctx).Where("namespace_id = ? AND name = ?", repo.namespace.ID, name).First(&r).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return models.JobSpec{}, store.ErrResourceNotFound } @@ -181,26 +182,26 @@ func (repo *JobSpecRepository) GetByName(name string) (models.JobSpec, error) { return repo.adapter.ToSpec(r) } -func (repo *JobSpecRepository) Delete(name string) error { - return repo.db.Where("namespace_id = ? AND name = ?", repo.namespace.ID, name).Delete(&Job{}).Error +func (repo *JobSpecRepository) Delete(ctx context.Context, name string) error { + return repo.db.WithContext(ctx).Where("namespace_id = ? AND name = ?", repo.namespace.ID, name).Delete(&Job{}).Error } -func (repo *JobSpecRepository) HardDelete(name string) error { - //find the base job +func (repo *JobSpecRepository) HardDelete(ctx context.Context, name string) error { + // find the base job var r Job - if err := repo.db.Unscoped().Where("project_id = ? AND name = ?", repo.namespace.ProjectSpec.ID, name).Find(&r).Error; err == gorm.ErrRecordNotFound { + if err := repo.db.WithContext(ctx).Unscoped().Where("project_id = ? AND name = ?", repo.namespace.ProjectSpec.ID, name).Find(&r).Error; err == gorm.ErrRecordNotFound { // no job exists, inserting for the first time return nil } else if err != nil { return errors.Wrap(err, "failed to fetch soft deleted resource") } - return repo.db.Unscoped().Where("id = ?", r.ID).Delete(&Job{}).Error + return repo.db.WithContext(ctx).Unscoped().Where("id = ?", r.ID).Delete(&Job{}).Error } -func (repo *JobSpecRepository) GetAll() ([]models.JobSpec, error) { - specs := []models.JobSpec{} - jobs := []Job{} - if err := repo.db.Where("namespace_id = ?", repo.namespace.ID).Find(&jobs).Error; err != nil { +func (repo *JobSpecRepository) GetAll(ctx context.Context) ([]models.JobSpec, error) { + var specs []models.JobSpec + var jobs []Job + if err := repo.db.WithContext(ctx).Where("namespace_id = ?", repo.namespace.ID).Find(&jobs).Error; err != nil { return specs, err } diff --git a/store/postgres/job_spec_repository_test.go b/store/postgres/job_spec_repository_test.go index 64d2b0872e..ac4ea83dab 100644 --- a/store/postgres/job_spec_repository_test.go +++ b/store/postgres/job_spec_repository_test.go @@ -9,10 +9,10 @@ import ( "time" "github.com/google/uuid" - "github.com/jinzhu/gorm" "github.com/odpf/optimus/mock" "github.com/odpf/optimus/models" "github.com/stretchr/testify/assert" + "gorm.io/gorm" ) func TestJobRepository(t *testing.T) { @@ -37,6 +37,7 @@ func TestJobRepository(t *testing.T) { } return dbConn } + ctx := context.Background() projectSpec := models.ProjectSpec{ ID: uuid.Must(uuid.NewRandom()), @@ -159,7 +160,8 @@ func TestJobRepository(t *testing.T) { t.Run("Insert", func(t *testing.T) { t.Run("insert with hooks and assets should return adapted hooks and assets", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() unitData1 := models.GenerateDestinationRequest{Config: models.PluginConfigs{}.FromJobSpec(testConfigs[0].Task.Config), Assets: models.PluginAssets{}.FromJobSpec(testConfigs[0].Assets)} depMod1.On("GenerateDestination", context.TODO(), unitData1).Return(&models.GenerateDestinationResponse{Destination: destination}, nil) @@ -175,13 +177,13 @@ func TestJobRepository(t *testing.T) { repo := NewJobSpecRepository(db, namespaceSpec, projectJobSpecRepo, adapter) - err := repo.Insert(testModels[0]) + err := repo.Insert(ctx, testModels[0]) assert.Nil(t, err) - err = repo.Insert(testModels[1]) + err = repo.Insert(ctx, testModels[1]) assert.NotNil(t, err) - checkModel, err := repo.GetByID(testModels[0].ID) + checkModel, err := repo.GetByID(ctx, testModels[0].ID) assert.Nil(t, err) assert.Equal(t, "g-optimus-id", checkModel.Name) taskSchema := checkModel.Task.Unit.Info() @@ -199,7 +201,8 @@ func TestJobRepository(t *testing.T) { }) t.Run("insert when previously soft deleted should hard delete first along with foreign key cascade", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() unitData1 := models.GenerateDestinationRequest{ Config: models.PluginConfigs{}.FromJobSpec(testConfigs[0].Task.Config), @@ -220,22 +223,22 @@ func TestJobRepository(t *testing.T) { repo := NewJobSpecRepository(db, namespaceSpec, projectJobSpecRepo, adapter) // first insert - err := repo.Insert(testModels[0]) + err := repo.Insert(ctx, testModels[0]) assert.Nil(t, err) - checkModel, err := repo.GetByID(testModels[0].ID) + checkModel, err := repo.GetByID(ctx, testModels[0].ID) assert.Nil(t, err) assert.Equal(t, "g-optimus-id", checkModel.Name) // soft delete - err = repo.Delete(testModels[0].Name) + err = repo.Delete(ctx, testModels[0].Name) assert.Nil(t, err) // insert back again - err = repo.Insert(testModels[0]) + err = repo.Insert(ctx, testModels[0]) assert.Nil(t, err) - checkModel, err = repo.GetByID(testModels[0].ID) + checkModel, err = repo.GetByID(ctx, testModels[0].ID) assert.Nil(t, err) assert.Equal(t, "g-optimus-id", checkModel.Name) }) @@ -243,7 +246,8 @@ func TestJobRepository(t *testing.T) { t.Run("Upsert", func(t *testing.T) { t.Run("insert different resource should insert two", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() testModelA := testConfigs[0] testModelB := testConfigs[2] @@ -264,20 +268,20 @@ func TestJobRepository(t *testing.T) { repo := NewJobSpecRepository(db, namespaceSpec, projectJobSpecRepo, adapter) //try for create - err := repo.Save(testModelA) + err := repo.Save(ctx, testModelA) assert.Nil(t, err) - checkModel, err := repo.GetByID(testModelA.ID) + checkModel, err := repo.GetByID(ctx, testModelA.ID) assert.Nil(t, err) assert.Equal(t, "g-optimus-id", checkModel.Name) taskSchema := checkModel.Task.Unit.Info() assert.Equal(t, gTask, taskSchema.Name) //try for update - err = repo.Save(testModelB) + err = repo.Save(ctx, testModelB) assert.Nil(t, err) - checkModel, err = repo.GetByID(testModelB.ID) + checkModel, err = repo.GetByID(ctx, testModelB.ID) assert.Nil(t, err) assert.Equal(t, "t-optimus-id", checkModel.Name) taskSchema = checkModel.Task.Unit.Info() @@ -285,7 +289,8 @@ func TestJobRepository(t *testing.T) { }) t.Run("insert same resource twice should overwrite existing", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() testModelA := testConfigs[0] unitData1 := models.GenerateDestinationRequest{Config: models.PluginConfigs{}.FromJobSpec(testConfigs[0].Task.Config), Assets: models.PluginAssets{}.FromJobSpec(testConfigs[0].Assets)} @@ -306,10 +311,10 @@ func TestJobRepository(t *testing.T) { //try for create testModelA.Task.Unit = &models.Plugin{Base: execUnit1, DependencyMod: depMod1} - err := repo.Save(testModelA) + err := repo.Save(ctx, testModelA) assert.Nil(t, err) - checkModel, err := repo.GetByID(testModelA.ID) + checkModel, err := repo.GetByID(ctx, testModelA.ID) assert.Nil(t, err) assert.Equal(t, "g-optimus-id", checkModel.Name) taskSchema := checkModel.Task.Unit.Info() @@ -320,10 +325,10 @@ func TestJobRepository(t *testing.T) { //try for update testModelA.Task.Unit = &models.Plugin{Base: execUnit2, DependencyMod: depMod2} - err = repo.Save(testModelA) + err = repo.Save(ctx, testModelA) assert.Nil(t, err) - checkModel, err = repo.GetByID(testModelA.ID) + checkModel, err = repo.GetByID(ctx, testModelA.ID) assert.Nil(t, err) taskSchema = checkModel.Task.Unit.Info() assert.Equal(t, tTask, taskSchema.Name) @@ -332,7 +337,8 @@ func TestJobRepository(t *testing.T) { }) t.Run("upsert without ID should auto generate it", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() testModelA := testConfigs[0] testModelA.ID = uuid.Nil @@ -340,16 +346,17 @@ func TestJobRepository(t *testing.T) { repo := NewJobSpecRepository(db, namespaceSpec, projectJobSpecRepo, adapter) //try for create - err := repo.Save(testModelA) + err := repo.Save(ctx, testModelA) assert.Nil(t, err) - checkModel, err := repo.GetByName(testModelA.Name) + checkModel, err := repo.GetByName(ctx, testModelA.Name) assert.Nil(t, err) assert.Equal(t, "g-optimus-id", checkModel.Name) }) t.Run("should update same job with hooks when provided separately", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() testModel := testConfigs[2] testModel.Task.Unit.DependencyMod = nil execUnit2.On("PluginInfo").Return(&models.PluginInfoResponse{ @@ -360,9 +367,9 @@ func TestJobRepository(t *testing.T) { projectJobSpecRepo := NewProjectJobSpecRepository(db, projectSpec, adapter) repo := NewJobSpecRepository(db, namespaceSpec, projectJobSpecRepo, adapter) - err := repo.Insert(testModel) + err := repo.Insert(ctx, testModel) assert.Nil(t, err) - checkModel, err := repo.GetByID(testModel.ID) + checkModel, err := repo.GetByID(ctx, testModel.ID) assert.Nil(t, err) assert.Equal(t, "t-optimus-id", checkModel.Name) taskSchema := checkModel.Task.Unit.Info() @@ -382,9 +389,9 @@ func TestJobRepository(t *testing.T) { Unit: &models.Plugin{Base: hookUnit1}, }, } - err = repo.Save(testModel) + err = repo.Save(ctx, testModel) assert.Nil(t, err) - checkModel, err = repo.GetByID(testModel.ID) + checkModel, err = repo.GetByID(ctx, testModel.ID) assert.Nil(t, err) assert.Equal(t, "t-optimus-id", checkModel.Name) taskSchema = checkModel.Task.Unit.Info() @@ -414,9 +421,9 @@ func TestJobRepository(t *testing.T) { }, Unit: &models.Plugin{Base: hookUnit2}, }) - err = repo.Save(testModel) + err = repo.Save(ctx, testModel) assert.Nil(t, err) - checkModel, err = repo.GetByID(testModel.ID) + checkModel, err = repo.GetByID(ctx, testModel.ID) assert.Nil(t, err) assert.Equal(t, "t-optimus-id", checkModel.Name) taskSchema = checkModel.Task.Unit.Info() @@ -443,7 +450,8 @@ func TestJobRepository(t *testing.T) { }) t.Run("should fail if job is already registered for a project with different namespace", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() testModelA := testConfigs[0] unitData1 := models.GenerateDestinationRequest{Config: models.PluginConfigs{}.FromJobSpec(testConfigs[0].Task.Config), Assets: models.PluginAssets{}.FromJobSpec(testConfigs[0].Assets)} @@ -456,10 +464,10 @@ func TestJobRepository(t *testing.T) { jobRepoNamespace2 := NewJobSpecRepository(db, namespaceSpec2, projectJobSpecRepo, adapter) // try to create with first namespace - err := jobRepoNamespace1.Save(testModelA) + err := jobRepoNamespace1.Save(ctx, testModelA) assert.Nil(t, err) - checkJob, checkNamespace, err := projectJobSpecRepo.GetByName(testModelA.Name) + checkJob, checkNamespace, err := projectJobSpecRepo.GetByName(ctx, testModelA.Name) assert.Nil(t, err) assert.Equal(t, "g-optimus-id", checkJob.Name) schema := checkJob.Task.Unit.Info() @@ -468,13 +476,14 @@ func TestJobRepository(t *testing.T) { assert.Equal(t, namespaceSpec.ProjectSpec.ID, checkNamespace.ProjectSpec.ID) // try to create same job with second namespace and it should fail. - err = jobRepoNamespace2.Save(testModelA) + err = jobRepoNamespace2.Save(ctx, testModelA) assert.NotNil(t, err) assert.Equal(t, "job g-optimus-id already exists for the project t-optimus-id", err.Error()) }) t.Run("should properly insert spec behavior, reading and writing", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() testModelA := testConfigs[0] unitData1 := models.GenerateDestinationRequest{Config: models.PluginConfigs{}.FromJobSpec(testConfigs[0].Task.Config), Assets: models.PluginAssets{}.FromJobSpec(testConfigs[0].Assets)} @@ -486,10 +495,10 @@ func TestJobRepository(t *testing.T) { repo := NewJobSpecRepository(db, namespaceSpec, projectJobSpecRepo, adapter) //try for create - err := repo.Save(testModelA) + err := repo.Save(ctx, testModelA) assert.Nil(t, err) - checkModel, err := repo.GetByID(testModelA.ID) + checkModel, err := repo.GetByID(ctx, testModelA.ID) assert.Nil(t, err) assert.Equal(t, "g-optimus-id", checkModel.Name) assert.Equal(t, true, checkModel.Behavior.CatchUp) @@ -501,10 +510,10 @@ func TestJobRepository(t *testing.T) { //try for update testModelA.Behavior.CatchUp = false testModelA.Behavior.DependsOnPast = true - err = repo.Save(testModelA) + err = repo.Save(ctx, testModelA) assert.Nil(t, err) - checkModel, err = repo.GetByID(testModelA.ID) + checkModel, err = repo.GetByID(ctx, testModelA.ID) assert.Nil(t, err) assert.Equal(t, false, checkModel.Behavior.CatchUp) assert.Equal(t, true, checkModel.Behavior.DependsOnPast) @@ -513,17 +522,18 @@ func TestJobRepository(t *testing.T) { t.Run("GetByName", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() testModels := []models.JobSpec{} testModels = append(testModels, testConfigs...) projectJobSpecRepo := NewProjectJobSpecRepository(db, projectSpec, adapter) repo := NewJobSpecRepository(db, namespaceSpec, projectJobSpecRepo, adapter) - err := repo.Insert(testModels[0]) + err := repo.Insert(ctx, testModels[0]) assert.Nil(t, err) - checkModel, err := repo.GetByName(testModels[0].Name) + checkModel, err := repo.GetByName(ctx, testModels[0].Name) assert.Nil(t, err) assert.Equal(t, "g-optimus-id", checkModel.Name) assert.Equal(t, "this", checkModel.Task.Config[0].Value) @@ -531,19 +541,20 @@ func TestJobRepository(t *testing.T) { t.Run("GetAll", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() testModels := []models.JobSpec{} testModels = append(testModels, testConfigs...) projectJobSpecRepo := NewProjectJobSpecRepository(db, projectSpec, adapter) repo := NewJobSpecRepository(db, namespaceSpec, projectJobSpecRepo, adapter) - err := repo.Insert(testModels[0]) + err := repo.Insert(ctx, testModels[0]) assert.Nil(t, err) - err = repo.Insert(testModels[2]) + err = repo.Insert(ctx, testModels[2]) assert.Nil(t, err) - checkModels, err := repo.GetAll() + checkModels, err := repo.GetAll(ctx) assert.Nil(t, err) assert.Equal(t, 2, len(checkModels)) }) @@ -571,6 +582,7 @@ func TestProjectJobRepository(t *testing.T) { } return dbConn } + ctx := context.Background() projectSpec := models.ProjectSpec{ ID: uuid.Must(uuid.NewRandom()), @@ -679,7 +691,8 @@ func TestProjectJobRepository(t *testing.T) { t.Run("GetByName", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() testModels := []models.JobSpec{} testModels = append(testModels, testConfigs...) @@ -693,10 +706,10 @@ func TestProjectJobRepository(t *testing.T) { projectJobSpecRepo := NewProjectJobSpecRepository(db, projectSpec, adapter) repo := NewJobSpecRepository(db, namespaceSpec, projectJobSpecRepo, adapter) - err := repo.Insert(testModels[0]) + err := repo.Insert(ctx, testModels[0]) assert.Nil(t, err) - checkJob, checkNamespace, err := projectJobSpecRepo.GetByName(testModels[0].Name) + checkJob, checkNamespace, err := projectJobSpecRepo.GetByName(ctx, testModels[0].Name) assert.Nil(t, err) assert.Equal(t, "g-optimus-id", checkJob.Name) assert.Equal(t, "this", checkJob.Task.Config[0].Value) @@ -705,7 +718,8 @@ func TestProjectJobRepository(t *testing.T) { t.Run("GetAll", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() testModels := []models.JobSpec{} testModels = append(testModels, testConfigs...) @@ -726,19 +740,20 @@ func TestProjectJobRepository(t *testing.T) { projectJobSpecRepo := NewProjectJobSpecRepository(db, projectSpec, adapter) repo := NewJobSpecRepository(db, namespaceSpec, projectJobSpecRepo, adapter) - err := repo.Insert(testModels[0]) + err := repo.Insert(ctx, testModels[0]) assert.Nil(t, err) - err = repo.Insert(testModels[2]) + err = repo.Insert(ctx, testModels[2]) assert.Nil(t, err) - checkModels, err := projectJobSpecRepo.GetAll() + checkModels, err := projectJobSpecRepo.GetAll(ctx) assert.Nil(t, err) assert.Equal(t, 2, len(checkModels)) }) t.Run("GetByDestination", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() unitData1 := models.GenerateDestinationRequest{ Config: models.PluginConfigs{}.FromJobSpec(testConfigs[0].Task.Config), @@ -755,17 +770,18 @@ func TestProjectJobRepository(t *testing.T) { projectJobSpecRepo := NewProjectJobSpecRepository(db, projectSpec, adapter) jobRepo := NewJobSpecRepository(db, namespaceSpec, projectJobSpecRepo, adapter) - err := jobRepo.Insert(testModels[0]) + err := jobRepo.Insert(ctx, testModels[0]) assert.Nil(t, err) - j, p, err := projectJobSpecRepo.GetByDestination(destinationUrn) + j, p, err := projectJobSpecRepo.GetByDestination(ctx, destinationUrn) assert.Nil(t, err) assert.Equal(t, testConfigs[0].Name, j.Name) assert.Equal(t, projectSpec.Name, p.Name) }) t.Run("GetByNameForProject", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() unitData1 := models.GenerateDestinationRequest{ Config: models.PluginConfigs{}.FromJobSpec(testConfigs[0].Task.Config), @@ -780,14 +796,14 @@ func TestProjectJobRepository(t *testing.T) { testModels := []models.JobSpec{} testModels = append(testModels, testConfigs...) - assert.Nil(t, NewProjectRepository(db, hash).Save(projectSpec)) + assert.Nil(t, NewProjectRepository(db, hash).Save(ctx, projectSpec)) projectJobSpecRepo := NewProjectJobSpecRepository(db, projectSpec, adapter) jobRepo := NewJobSpecRepository(db, namespaceSpec, projectJobSpecRepo, adapter) - err := jobRepo.Insert(testModels[0]) + err := jobRepo.Insert(ctx, testModels[0]) assert.Nil(t, err) - j, p, err := projectJobSpecRepo.GetByNameForProject(projectSpec.Name, testModels[0].Name) + j, p, err := projectJobSpecRepo.GetByNameForProject(ctx, projectSpec.Name, testModels[0].Name) assert.Nil(t, err) assert.Equal(t, testConfigs[0].Name, j.Name) assert.Equal(t, projectSpec.Name, p.Name) diff --git a/store/postgres/jobrun_repository.go b/store/postgres/jobrun_repository.go index 164e1ac3c7..d67c7d911b 100644 --- a/store/postgres/jobrun_repository.go +++ b/store/postgres/jobrun_repository.go @@ -1,13 +1,14 @@ package postgres import ( + "context" "time" "github.com/google/uuid" - "github.com/jinzhu/gorm" "github.com/odpf/optimus/models" "github.com/odpf/optimus/store" "github.com/pkg/errors" + "gorm.io/gorm" ) type JobRunRepository struct { @@ -16,23 +17,23 @@ type JobRunRepository struct { instanceRepo *InstanceRepository } -func (repo *JobRunRepository) Insert(namespace models.NamespaceSpec, spec models.JobRun) error { +func (repo *JobRunRepository) Insert(ctx context.Context, namespace models.NamespaceSpec, spec models.JobRun) error { resource, err := repo.adapter.FromJobRun(spec, namespace) if err != nil { return err } - return repo.db.Omit("Namespace").Create(&resource).Error + return repo.db.WithContext(ctx).Omit("Namespace", "Instances").Create(&resource).Error } -func (repo *JobRunRepository) Save(namespace models.NamespaceSpec, spec models.JobRun) error { +func (repo *JobRunRepository) Save(ctx context.Context, namespace models.NamespaceSpec, spec models.JobRun) error { if spec.Status == "" { // mark default state pending spec.Status = models.RunStatePending } - existingResource, _, err := repo.GetByID(spec.ID) + existingResource, _, err := repo.GetByID(ctx, spec.ID) if errors.Is(err, store.ErrResourceNotFound) { - return repo.Insert(namespace, spec) + return repo.Insert(ctx, namespace, spec) } else if err != nil { return errors.Wrap(err, "unable to find jobrun by id") } @@ -42,73 +43,61 @@ func (repo *JobRunRepository) Save(namespace models.NamespaceSpec, spec models.J return err } resource.ID = existingResource.ID - return repo.db.Omit("Namespace").Model(&resource).Updates(&resource).Error + return repo.db.WithContext(ctx).Omit("Namespace", "Instances").Model(&resource).Updates(&resource).Error } -func (repo *JobRunRepository) GetByID(id uuid.UUID) (models.JobRun, models.NamespaceSpec, error) { +func (repo *JobRunRepository) GetByID(ctx context.Context, id uuid.UUID) (models.JobRun, models.NamespaceSpec, error) { var r JobRun - if err := repo.db.Preload("Namespace").Preload("Instances").Where("id = ?", id).Find(&r).Error; err != nil { + if err := repo.db.WithContext(ctx).Preload("Namespace").Where("id = ?", id).First(&r).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return models.JobRun{}, models.NamespaceSpec{}, store.ErrResourceNotFound } return models.JobRun{}, models.NamespaceSpec{}, err } + if instances, err := repo.instanceRepo.GetByJobRun(ctx, r.ID); err == nil { + r.Instances = instances + } return repo.adapter.ToJobRun(r) } -func (repo *JobRunRepository) GetByScheduledAt(jobID uuid.UUID, scheduledAt time.Time) (models.JobRun, models.NamespaceSpec, error) { +func (repo *JobRunRepository) GetByScheduledAt(ctx context.Context, jobID uuid.UUID, scheduledAt time.Time) (models.JobRun, models.NamespaceSpec, error) { var r JobRun - if err := repo.db.Preload("Namespace").Preload("Instances").Where("job_id = ? AND scheduled_at = ?", jobID, scheduledAt).Find(&r).Error; err != nil { + if err := repo.db.WithContext(ctx).Preload("Namespace").Where("job_id = ? AND scheduled_at = ?", jobID, scheduledAt).First(&r).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return models.JobRun{}, models.NamespaceSpec{}, store.ErrResourceNotFound } return models.JobRun{}, models.NamespaceSpec{}, err } + if instances, err := repo.instanceRepo.GetByJobRun(ctx, r.ID); err == nil { + r.Instances = instances + } return repo.adapter.ToJobRun(r) } // AddInstance associate instance details -func (repo *JobRunRepository) AddInstance(namespaceSpec models.NamespaceSpec, run models.JobRun, spec models.InstanceSpec) error { - for idx, instance := range run.Instances { - if instance.Name == spec.Name && instance.Type == spec.Type { - // delete if associated before - if err := repo.instanceRepo.Delete(instance.ID); err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { - return err - } - - // delete this index - run.Instances[idx] = run.Instances[len(run.Instances)-1] - run.Instances = run.Instances[:len(run.Instances)-1] - break - } +func (repo *JobRunRepository) AddInstance(ctx context.Context, namespaceSpec models.NamespaceSpec, run models.JobRun, spec models.InstanceSpec) error { + instance, err := repo.instanceRepo.GetByName(ctx, run.ID, spec.Name, spec.Type.String()) + if err != nil && !errors.Is(err, store.ErrResourceNotFound) { + return err } - run.Instances = append(run.Instances, spec) - return repo.Save(namespaceSpec, run) -} - -// ClearInstances deletes all associated instance details -func (repo *JobRunRepository) ClearInstances(jobID uuid.UUID, scheduled time.Time) error { - var r JobRun - if err := repo.db.Where("job_id = ? AND scheduled_at = ?", jobID, scheduled).Find(&r).Error; err != nil { - if !errors.Is(err, gorm.ErrRecordNotFound) { + if instance.ID.String() != "" { + // delete if associated before + if err := repo.instanceRepo.Delete(ctx, instance.ID); err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { return err } } - if err := repo.instanceRepo.DeleteByJobRun(r.ID); err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { - return err - } - return repo.db.Model(&r).Update(map[string]interface{}{"data": nil, "status": models.RunStatePending}).Error + return repo.instanceRepo.Save(ctx, run, spec) } // ClearInstance deletes associated instance details -func (repo *JobRunRepository) ClearInstance(runID uuid.UUID, instanceType models.InstanceType, instanceName string) error { - r, _, err := repo.GetByID(runID) +func (repo *JobRunRepository) ClearInstance(ctx context.Context, runID uuid.UUID, instanceType models.InstanceType, instanceName string) error { + r, _, err := repo.GetByID(ctx, runID) if err != nil { return err } for _, instance := range r.Instances { if instance.Name == instanceName && instance.Type == instanceType { - if err := repo.instanceRepo.Delete(instance.ID); err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + if err := repo.instanceRepo.Delete(ctx, instance.ID); err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { return err } break @@ -118,25 +107,21 @@ func (repo *JobRunRepository) ClearInstance(runID uuid.UUID, instanceType models } // Clear prepares job run for fresh start -func (repo *JobRunRepository) Clear(runID uuid.UUID) error { - r, _, err := repo.GetByID(runID) - if err != nil { +func (repo *JobRunRepository) Clear(ctx context.Context, runID uuid.UUID) error { + if err := repo.instanceRepo.DeleteByJobRun(ctx, runID); err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { return err } - if err := repo.instanceRepo.DeleteByJobRun(runID); err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { - return err - } - return repo.db.Model(&r).Update(map[string]interface{}{"data": nil, "status": models.RunStatePending}).Error + return repo.db.WithContext(ctx).Model(&JobRun{ID: runID}).Updates(JobRun{Status: models.RunStatePending.String()}).Error } -func (repo *JobRunRepository) Delete(id uuid.UUID) error { - if err := repo.instanceRepo.DeleteByJobRun(id); err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { +func (repo *JobRunRepository) Delete(ctx context.Context, id uuid.UUID) error { + if err := repo.instanceRepo.DeleteByJobRun(ctx, id); err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { return err } - return repo.db.Where("id = ?", id).Delete(&JobRun{}).Error + return repo.db.WithContext(ctx).Where("id = ?", id).Delete(&JobRun{}).Error } -func (repo *JobRunRepository) UpdateStatus(id uuid.UUID, status models.JobRunState) error { +func (repo *JobRunRepository) UpdateStatus(ctx context.Context, id uuid.UUID, status models.JobRunState) error { var jr JobRun if err := repo.db.Where("id = ?", id).Find(&jr).Error; err != nil { return err @@ -145,14 +130,17 @@ func (repo *JobRunRepository) UpdateStatus(id uuid.UUID, status models.JobRunSta return repo.db.Omit("Namespace").Save(jr).Error } -func (repo *JobRunRepository) GetByStatus(statuses ...models.JobRunState) ([]models.JobRun, error) { +func (repo *JobRunRepository) GetByStatus(ctx context.Context, statuses ...models.JobRunState) ([]models.JobRun, error) { var specs []models.JobRun var runs []JobRun - if err := repo.db.Preload("Instances").Where("status IN (?)", statuses).Find(&runs).Error; err != nil { + if err := repo.db.WithContext(ctx).Where("status IN (?)", statuses).Find(&runs).Error; err != nil { return specs, err } for _, run := range runs { + if instances, err := repo.instanceRepo.GetByJobRun(ctx, run.ID); err == nil { + run.Instances = instances + } adapt, _, err := repo.adapter.ToJobRun(run) if err != nil { return specs, err @@ -162,20 +150,23 @@ func (repo *JobRunRepository) GetByStatus(statuses ...models.JobRunState) ([]mod return specs, nil } -func (repo *JobRunRepository) GetByTrigger(trigger models.JobRunTrigger, statuses ...models.JobRunState) ([]models.JobRun, error) { +func (repo *JobRunRepository) GetByTrigger(ctx context.Context, trigger models.JobRunTrigger, statuses ...models.JobRunState) ([]models.JobRun, error) { var specs []models.JobRun var runs []JobRun if len(statuses) > 0 { - if err := repo.db.Preload("Instances").Where("trigger = ? and status IN (?)", trigger, statuses).Find(&runs).Error; err != nil { + if err := repo.db.WithContext(ctx).Where("trigger = ? and status IN (?)", trigger, statuses).Find(&runs).Error; err != nil { return specs, err } } else { - if err := repo.db.Preload("Instances").Where("trigger = ?", trigger).Find(&runs).Error; err != nil { + if err := repo.db.WithContext(ctx).Where("trigger = ?", trigger).Find(&runs).Error; err != nil { return specs, err } } for _, run := range runs { + if instances, err := repo.instanceRepo.GetByJobRun(ctx, run.ID); err == nil { + run.Instances = instances + } adapt, _, err := repo.adapter.ToJobRun(run) if err != nil { return specs, err diff --git a/store/postgres/jobrun_repository_test.go b/store/postgres/jobrun_repository_test.go index 1d4b2ad5cd..a1f7868ce4 100644 --- a/store/postgres/jobrun_repository_test.go +++ b/store/postgres/jobrun_repository_test.go @@ -5,18 +5,18 @@ package postgres import ( "context" "os" - "sort" "testing" "time" "github.com/google/uuid" - "github.com/jinzhu/gorm" "github.com/odpf/optimus/mock" "github.com/odpf/optimus/models" "github.com/stretchr/testify/assert" + "gorm.io/gorm" ) func TestJobRunRepository(t *testing.T) { + ctx := context.Background() projectSpec := models.ProjectSpec{ ID: uuid.Must(uuid.NewRandom()), Name: "t-optimus-id", @@ -116,12 +116,12 @@ func TestJobRunRepository(t *testing.T) { hash, _ := models.NewApplicationSecret("32charshtesthashtesthashtesthash") prepo := NewProjectRepository(dbConn, hash) - assert.Nil(t, prepo.Save(projectSpec)) + assert.Nil(t, prepo.Save(ctx, projectSpec)) projectJobSpecRepo := NewProjectJobSpecRepository(dbConn, projectSpec, adapter) jrepo := NewJobSpecRepository(dbConn, namespaceSpec, projectJobSpecRepo, adapter) - assert.Nil(t, jrepo.Save(jobConfigs[0])) - assert.Equal(t, "task unit cannot be empty", jrepo.Save(jobConfigs[1]).Error()) + assert.Nil(t, jrepo.Save(ctx, jobConfigs[0])) + assert.Equal(t, "task unit cannot be empty", jrepo.Save(ctx, jobConfigs[1]).Error()) return dbConn } @@ -163,78 +163,74 @@ func TestJobRunRepository(t *testing.T) { t.Run("Insert", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() var testModels []models.JobRun testModels = append(testModels, testSpecs...) repo := NewJobRunRepository(db, adapter) - err := repo.Insert(namespaceSpec, testModels[1]) + err := repo.Insert(ctx, namespaceSpec, testModels[1]) assert.Nil(t, err) - checkModel, ns, err := repo.GetByID(testModels[1].ID) + checkModel, ns, err := repo.GetByID(ctx, testModels[1].ID) assert.Nil(t, err) assert.Equal(t, testModels[1].Spec.Name, checkModel.Spec.Name) assert.Equal(t, testModels[1].ScheduledAt.Unix(), checkModel.ScheduledAt.Unix()) assert.Equal(t, namespaceSpec.ID, ns.ID) - err = repo.Insert(namespaceSpec, testModels[0]) + err = repo.Insert(ctx, namespaceSpec, testModels[0]) assert.Nil(t, err) - checkModel, ns, err = repo.GetByID(testModels[0].ID) + checkModel, ns, err = repo.GetByID(ctx, testModels[0].ID) assert.Nil(t, err) assert.Equal(t, testModels[0].Spec.Name, checkModel.Spec.Name) assert.Equal(t, testModels[0].ScheduledAt.Unix(), checkModel.ScheduledAt.Unix()) assert.Equal(t, namespaceSpec.ID, ns.ID) - assert.Equal(t, len(testModels[0].Instances), len(checkModel.Instances)) - sort.Slice(testModels[0].Instances, func(i, j int) bool { - return testModels[0].Instances[i].ID.String() < testModels[0].Instances[j].ID.String() - }) - sort.Slice(checkModel.Instances, func(i, j int) bool { - return checkModel.Instances[i].ID.String() < checkModel.Instances[j].ID.String() - }) - assert.EqualValues(t, testModels[0].Instances[0].ID, checkModel.Instances[0].ID) + assert.Equal(t, 0, len(checkModel.Instances)) }) t.Run("Save", func(t *testing.T) { t.Run("should save and delete fresh runs correctly", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() testModels := []models.JobRun{} testModels = append(testModels, testSpecs...) repo := NewJobRunRepository(db, adapter) - err := repo.Save(namespaceSpec, testModels[0]) + err := repo.Save(ctx, namespaceSpec, testModels[0]) assert.Nil(t, err) - checkModel, _, err := repo.GetByID(testModels[0].ID) + checkModel, _, err := repo.GetByID(ctx, testModels[0].ID) assert.Nil(t, err) assert.Equal(t, testModels[0].Spec.Name, checkModel.Spec.Name) assert.Equal(t, testModels[0].ScheduledAt.Unix(), checkModel.ScheduledAt.Unix()) - err = repo.Delete(testModels[0].ID) + err = repo.Delete(ctx, testModels[0].ID) assert.Nil(t, err) - err = repo.Save(namespaceSpec, testModels[0]) + err = repo.Save(ctx, namespaceSpec, testModels[0]) assert.Nil(t, err) - checkModel, _, err = repo.GetByID(testModels[0].ID) + checkModel, _, err = repo.GetByID(ctx, testModels[0].ID) assert.Nil(t, err) assert.Equal(t, testModels[0].Spec.Name, checkModel.Spec.Name) assert.Equal(t, testModels[0].ScheduledAt.Unix(), checkModel.ScheduledAt.Unix()) }) t.Run("should upsert existing runs correctly", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() testModels := []models.JobRun{} testModels = append(testModels, testSpecs...) repo := NewJobRunRepository(db, adapter) - err := repo.Save(namespaceSpec, testModels[0]) + err := repo.Save(ctx, namespaceSpec, testModels[0]) assert.Nil(t, err) - checkModel, _, err := repo.GetByID(testModels[0].ID) + checkModel, _, err := repo.GetByID(ctx, testModels[0].ID) assert.Nil(t, err) assert.Equal(t, testModels[0].Spec.Name, checkModel.Spec.Name) assert.Equal(t, testModels[0].ScheduledAt.Unix(), checkModel.ScheduledAt.Unix()) @@ -242,99 +238,89 @@ func TestJobRunRepository(t *testing.T) { // update resource testModels[0].ScheduledAt = testModels[0].ScheduledAt.Add(time.Nanosecond) - err = repo.Save(namespaceSpec, testModels[0]) + err = repo.Save(ctx, namespaceSpec, testModels[0]) assert.Nil(t, err) - checkModel, _, err = repo.GetByID(testModels[0].ID) + checkModel, _, err = repo.GetByID(ctx, testModels[0].ID) assert.Nil(t, err) assert.Equal(t, testModels[0].Spec.Name, checkModel.Spec.Name) assert.Equal(t, testModels[0].ScheduledAt.Add(time.Nanosecond).Unix(), checkModel.ScheduledAt.Unix()) }) }) - t.Run("ClearInstances", func(t *testing.T) { - db := DBSetup() - defer db.Close() - - var testModels []models.JobRun - testModels = append(testModels, testSpecs...) - - repo := NewJobRunRepository(db, adapter) - err := repo.Insert(namespaceSpec, testModels[0]) - assert.Nil(t, err) - - err = repo.ClearInstances(testModels[0].Spec.ID, testModels[0].ScheduledAt) - assert.Nil(t, err) - - checkModel, _, err := repo.GetByID(testModels[0].ID) - assert.Nil(t, err) - assert.Equal(t, 0, len(checkModel.Instances)) - }) t.Run("ClearInstance", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() var testModels []models.JobRun testModels = append(testModels, testSpecs...) repo := NewJobRunRepository(db, adapter) - err := repo.Insert(namespaceSpec, testModels[0]) + err := repo.Insert(ctx, namespaceSpec, testModels[0]) assert.Nil(t, err) + assert.Nil(t, repo.AddInstance(ctx, namespaceSpec, testModels[0], testModels[0].Instances[0])) + assert.Nil(t, repo.AddInstance(ctx, namespaceSpec, testModels[0], testModels[0].Instances[1])) - err = repo.ClearInstance(testModels[0].ID, models.InstanceTypeTask, "do-this") + err = repo.ClearInstance(ctx, testModels[0].ID, models.InstanceTypeTask, "do-this") assert.Nil(t, err) - checkModel, _, err := repo.GetByID(testModels[0].ID) + checkModel, _, err := repo.GetByID(ctx, testModels[0].ID) assert.Nil(t, err) assert.Equal(t, 1, len(checkModel.Instances)) }) t.Run("GetByStatus", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() var testModels []models.JobRun testModels = append(testModels, testSpecs...) repo := NewJobRunRepository(db, adapter) - err := repo.Insert(namespaceSpec, testModels[0]) + err := repo.Insert(ctx, namespaceSpec, testModels[0]) assert.Nil(t, err) - runs, err := repo.GetByStatus(models.RunStateRunning) + runs, err := repo.GetByStatus(ctx, models.RunStateRunning) assert.Nil(t, err) assert.Equal(t, 1, len(runs)) }) t.Run("AddInstance", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() var testModels []models.JobRun testModels = append(testModels, testSpecs...) repo := NewJobRunRepository(db, adapter) - err := repo.Insert(namespaceSpec, testModels[1]) + err := repo.Insert(ctx, namespaceSpec, testModels[1]) assert.Nil(t, err) - err = repo.AddInstance(namespaceSpec, testModels[1], testInstanceSpecs[0]) + err = repo.AddInstance(ctx, namespaceSpec, testModels[1], testInstanceSpecs[0]) assert.Nil(t, err) - jr, _, err := repo.GetByID(testModels[1].ID) + jr, _, err := repo.GetByID(ctx, testModels[1].ID) assert.Nil(t, err) assert.Equal(t, 1, len(jr.Instances)) }) t.Run("Clear", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() var testModels []models.JobRun testModels = append(testModels, testSpecs...) repo := NewJobRunRepository(db, adapter) - err := repo.Insert(namespaceSpec, testModels[0]) + err := repo.Insert(ctx, namespaceSpec, testModels[0]) assert.Nil(t, err) + assert.Nil(t, repo.AddInstance(ctx, namespaceSpec, testModels[0], testModels[0].Instances[0])) + assert.Nil(t, repo.AddInstance(ctx, namespaceSpec, testModels[0], testModels[0].Instances[1])) - err = repo.Clear(testModels[0].ID) + err = repo.Clear(ctx, testModels[0].ID) assert.Nil(t, err) - jr, _, err := repo.GetByID(testModels[0].ID) + jr, _, err := repo.GetByID(ctx, testModels[0].ID) assert.Nil(t, err) assert.Equal(t, 0, len(jr.Instances)) assert.Equal(t, models.RunStatePending, jr.Status) diff --git a/store/postgres/namespace_repository.go b/store/postgres/namespace_repository.go index 5c02cf5ab7..dc154b7cd2 100644 --- a/store/postgres/namespace_repository.go +++ b/store/postgres/namespace_repository.go @@ -1,16 +1,17 @@ package postgres import ( + "context" "encoding/json" "time" "github.com/odpf/optimus/store" "github.com/google/uuid" - "github.com/jinzhu/gorm" "github.com/odpf/optimus/models" "github.com/pkg/errors" "gorm.io/datatypes" + "gorm.io/gorm" ) type Namespace struct { @@ -23,7 +24,7 @@ type Namespace struct { CreatedAt time.Time `gorm:"not null" json:"created_at"` UpdatedAt time.Time `gorm:"not null" json:"updated_at"` - DeletedAt *time.Time + DeletedAt gorm.DeletedAt } func (p Namespace) FromSpec(spec models.NamespaceSpec) (Namespace, error) { @@ -94,7 +95,7 @@ type namespaceRepository struct { hash models.ApplicationKey } -func (repo *namespaceRepository) Insert(resource models.NamespaceSpec) error { +func (repo *namespaceRepository) Insert(ctx context.Context, resource models.NamespaceSpec) error { c, err := Namespace{}.FromSpecWithProject(resource, repo.project) if err != nil { return err @@ -102,13 +103,13 @@ func (repo *namespaceRepository) Insert(resource models.NamespaceSpec) error { if len(c.Name) == 0 { return errors.New("name cannot be empty") } - return repo.db.Create(&c).Error + return repo.db.WithContext(ctx).Create(&c).Error } -func (repo *namespaceRepository) Save(spec models.NamespaceSpec) error { - existingResource, err := repo.GetByName(spec.Name) +func (repo *namespaceRepository) Save(ctx context.Context, spec models.NamespaceSpec) error { + existingResource, err := repo.GetByName(ctx, spec.Name) if errors.Is(err, store.ErrResourceNotFound) { - return repo.Insert(spec) + return repo.Insert(ctx, spec) } else if err != nil { return errors.Wrap(err, "unable to find namespace by name") } @@ -117,12 +118,12 @@ func (repo *namespaceRepository) Save(spec models.NamespaceSpec) error { return err } resource.ID = existingResource.ID - return repo.db.Model(resource).Updates(resource).Error + return repo.db.WithContext(ctx).Model(resource).Updates(resource).Error } -func (repo *namespaceRepository) GetByName(name string) (models.NamespaceSpec, error) { +func (repo *namespaceRepository) GetByName(ctx context.Context, name string) (models.NamespaceSpec, error) { var r Namespace - if err := repo.db.Preload("Project").Preload("Project.Secrets").Where("name = ? AND project_id = ?", name, repo.project.ID).Find(&r).Error; err != nil { + if err := repo.db.WithContext(ctx).Preload("Project").Preload("Project.Secrets").Where("name = ? AND project_id = ?", name, repo.project.ID).First(&r).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return models.NamespaceSpec{}, store.ErrResourceNotFound } @@ -131,10 +132,10 @@ func (repo *namespaceRepository) GetByName(name string) (models.NamespaceSpec, e return r.ToSpecWithProjectSecrets(repo.hash) } -func (repo *namespaceRepository) GetAll() ([]models.NamespaceSpec, error) { - specs := []models.NamespaceSpec{} - namespaces := []Namespace{} - if err := repo.db.Preload("Project").Preload("Project.Secrets").Where("project_id = ?", repo.project.ID).Find(&namespaces).Error; err != nil { +func (repo *namespaceRepository) GetAll(ctx context.Context) ([]models.NamespaceSpec, error) { + var specs []models.NamespaceSpec + var namespaces []Namespace + if err := repo.db.WithContext(ctx).Preload("Project").Preload("Project.Secrets").Where("project_id = ?", repo.project.ID).Find(&namespaces).Error; err != nil { return specs, err } diff --git a/store/postgres/namespace_repository_test.go b/store/postgres/namespace_repository_test.go index 9739b2ca4f..52f14a9365 100644 --- a/store/postgres/namespace_repository_test.go +++ b/store/postgres/namespace_repository_test.go @@ -3,13 +3,14 @@ package postgres import ( + "context" "os" "testing" "github.com/google/uuid" - "github.com/jinzhu/gorm" "github.com/odpf/optimus/models" "github.com/stretchr/testify/assert" + "gorm.io/gorm" ) func TestNamespaceRepository(t *testing.T) { @@ -35,6 +36,7 @@ func TestNamespaceRepository(t *testing.T) { return dbConn } + ctx := context.Background() transporterKafkaBrokerKey := "KAFKA_BROKERS" hash, _ := models.NewApplicationSecret("32charshtesthashtesthashtesthash") @@ -87,30 +89,31 @@ func TestNamespaceRepository(t *testing.T) { t.Run("Insert", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() testModels := []models.NamespaceSpec{} testModels = append(testModels, namespaceSpecs...) // save project projRepo := NewProjectRepository(db, hash) - err := projRepo.Save(projectSpec) + err := projRepo.Save(ctx, projectSpec) assert.Nil(t, err) secretRepo := NewSecretRepository(db, projectSpec, hash) - err = secretRepo.Insert(secrets[0]) + err = secretRepo.Insert(ctx, secrets[0]) assert.Nil(t, err) - err = secretRepo.Insert(secrets[1]) + err = secretRepo.Insert(ctx, secrets[1]) assert.Nil(t, err) repo := NewNamespaceRepository(db, projectSpec, hash) - err = repo.Insert(testModels[0]) + err = repo.Insert(ctx, testModels[0]) assert.Nil(t, err) - err = repo.Insert(testModels[1]) + err = repo.Insert(ctx, testModels[1]) assert.NotNil(t, err) - checkModel, err := repo.GetByName(testModels[0].Name) + checkModel, err := repo.GetByName(ctx, testModels[0].Name) assert.Nil(t, err) assert.Equal(t, "g-optimus", checkModel.Name) assert.Equal(t, projectSpec.Name, checkModel.ProjectSpec.Name) @@ -120,67 +123,70 @@ func TestNamespaceRepository(t *testing.T) { t.Run("Upsert", func(t *testing.T) { t.Run("insert different resource should insert two", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() testModelA := namespaceSpecs[0] testModelB := namespaceSpecs[2] repo := NewNamespaceRepository(db, projectSpec, hash) //try for create - err := repo.Save(testModelA) + err := repo.Save(ctx, testModelA) assert.Nil(t, err) - checkModel, err := repo.GetByName(testModelA.Name) + checkModel, err := repo.GetByName(ctx, testModelA.Name) assert.Nil(t, err) assert.Equal(t, "g-optimus", checkModel.Name) //try for update - err = repo.Save(testModelB) + err = repo.Save(ctx, testModelB) assert.Nil(t, err) - checkModel, err = repo.GetByName(testModelB.Name) + checkModel, err = repo.GetByName(ctx, testModelB.Name) assert.Nil(t, err) assert.Equal(t, "t-optimus", checkModel.Name) assert.Equal(t, "10.12.12.12:6668,10.12.12.13:6668", checkModel.Config[transporterKafkaBrokerKey]) }) t.Run("insert same resource twice should overwrite existing", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() testModelA := namespaceSpecs[2] repo := NewNamespaceRepository(db, projectSpec, hash) //try for create testModelA.Config["bucket"] = "gs://some_folder" - err := repo.Save(testModelA) + err := repo.Save(ctx, testModelA) assert.Nil(t, err) - checkModel, err := repo.GetByName(testModelA.Name) + checkModel, err := repo.GetByName(ctx, testModelA.Name) assert.Nil(t, err) assert.Equal(t, "t-optimus", checkModel.Name) //try for update testModelA.Config["bucket"] = "gs://another_folder" - err = repo.Save(testModelA) + err = repo.Save(ctx, testModelA) assert.Nil(t, err) - checkModel, err = repo.GetByName(testModelA.Name) + checkModel, err = repo.GetByName(ctx, testModelA.Name) assert.Nil(t, err) assert.Equal(t, "gs://another_folder", checkModel.Config["bucket"]) }) t.Run("upsert without ID should auto generate it", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() testModelA := namespaceSpecs[0] testModelA.ID = uuid.Nil repo := NewNamespaceRepository(db, projectSpec, hash) //try for create - err := repo.Save(testModelA) + err := repo.Save(ctx, testModelA) assert.Nil(t, err) - checkModel, err := repo.GetByName(testModelA.Name) + checkModel, err := repo.GetByName(ctx, testModelA.Name) assert.Nil(t, err) assert.Equal(t, "g-optimus", checkModel.Name) assert.Equal(t, 36, len(checkModel.ID.String())) @@ -189,34 +195,36 @@ func TestNamespaceRepository(t *testing.T) { t.Run("GetByName", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() testModels := []models.NamespaceSpec{} testModels = append(testModels, namespaceSpecs...) repo := NewNamespaceRepository(db, projectSpec, hash) - err := repo.Insert(testModels[0]) + err := repo.Insert(ctx, testModels[0]) assert.Nil(t, err) - checkModel, err := repo.GetByName(testModels[0].Name) + checkModel, err := repo.GetByName(ctx, testModels[0].Name) assert.Nil(t, err) assert.Equal(t, "g-optimus", checkModel.Name) }) t.Run("GetAll", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() testModels := []models.NamespaceSpec{} testModels = append(testModels, namespaceSpecs...) repo := NewNamespaceRepository(db, projectSpec, hash) - err := repo.Insert(testModels[0]) + err := repo.Insert(ctx, testModels[0]) assert.Nil(t, err) - err = repo.Insert(testModels[2]) + err = repo.Insert(ctx, testModels[2]) assert.Nil(t, err) - checkModel, err := repo.GetAll() + checkModel, err := repo.GetAll(ctx) assert.Nil(t, err) assert.Equal(t, 2, len(checkModel)) }) diff --git a/store/postgres/postgres.go b/store/postgres/postgres.go index 1c2c2baf28..e3de2dea22 100644 --- a/store/postgres/postgres.go +++ b/store/postgres/postgres.go @@ -12,14 +12,12 @@ import ( "net/http" "github.com/golang-migrate/migrate/v4" + _ "github.com/golang-migrate/migrate/v4/database/postgres" // required for postgres migrate driver "github.com/golang-migrate/migrate/v4/source/httpfs" - "github.com/jinzhu/gorm" "github.com/pkg/errors" - - _ "embed" - - _ "github.com/golang-migrate/migrate/v4/database/postgres" // required for postgres migrate driver - _ "github.com/lib/pq" + "gorm.io/driver/postgres" + "gorm.io/gorm" + "gorm.io/gorm/schema" ) //go:embed migrations @@ -40,16 +38,20 @@ func NewHTTPFSMigrator(DBConnURL string) (*migrate.Migrate, error) { // Connect connect to the DB with custom configuration. func Connect(connURL string, maxIdleConnections, maxOpenConnections int) (*gorm.DB, error) { - var db *gorm.DB - var err error - - if db, err = gorm.Open("postgres", connURL); err != nil { + db, err := gorm.Open(postgres.Open(connURL), &gorm.Config{ + NamingStrategy: schema.NamingStrategy{ + SingularTable: true, + }, + }) + if err != nil { + return nil, errors.Wrap(err, "failed to initialize postgres db connection") + } + sqlDB, err := db.DB() + if err != nil { return nil, err } - - db.DB().SetMaxIdleConns(maxIdleConnections) - db.DB().SetMaxOpenConns(maxOpenConnections) - db.SingularTable(true) + sqlDB.SetMaxIdleConns(maxIdleConnections) + sqlDB.SetMaxOpenConns(maxOpenConnections) return db, nil } diff --git a/store/postgres/project_repository.go b/store/postgres/project_repository.go index b92d89ff17..82722a0089 100644 --- a/store/postgres/project_repository.go +++ b/store/postgres/project_repository.go @@ -1,16 +1,17 @@ package postgres import ( + "context" "encoding/json" "time" "github.com/odpf/optimus/store" "github.com/google/uuid" - "github.com/jinzhu/gorm" "github.com/odpf/optimus/models" "github.com/pkg/errors" "gorm.io/datatypes" + "gorm.io/gorm" ) type Project struct { @@ -23,7 +24,7 @@ type Project struct { CreatedAt time.Time `gorm:"not null" json:"created_at"` UpdatedAt time.Time `gorm:"not null" json:"updated_at"` - DeletedAt *time.Time + DeletedAt gorm.DeletedAt } func (p Project) FromSpec(spec models.ProjectSpec) (Project, error) { @@ -76,7 +77,7 @@ type ProjectRepository struct { hash models.ApplicationKey } -func (repo *ProjectRepository) Insert(resource models.ProjectSpec) error { +func (repo *ProjectRepository) Insert(ctx context.Context, resource models.ProjectSpec) error { p, err := Project{}.FromSpec(resource) if err != nil { return err @@ -84,13 +85,13 @@ func (repo *ProjectRepository) Insert(resource models.ProjectSpec) error { if len(p.Name) == 0 { return errors.New("name cannot be empty") } - return repo.db.Create(&p).Error + return repo.db.WithContext(ctx).Create(&p).Error } -func (repo *ProjectRepository) Save(spec models.ProjectSpec) error { - existingResource, err := repo.GetByName(spec.Name) +func (repo *ProjectRepository) Save(ctx context.Context, spec models.ProjectSpec) error { + existingResource, err := repo.GetByName(ctx, spec.Name) if errors.Is(err, store.ErrResourceNotFound) { - return repo.Insert(spec) + return repo.Insert(ctx, spec) } else if err != nil { return errors.Wrap(err, "unable to find project by name") } @@ -99,12 +100,12 @@ func (repo *ProjectRepository) Save(spec models.ProjectSpec) error { return err } project.ID = existingResource.ID - return repo.db.Model(&project).Updates(&project).Error + return repo.db.WithContext(ctx).Omit("Secrets").Model(&project).Update("Config", project.Config).Error } -func (repo *ProjectRepository) GetByName(name string) (models.ProjectSpec, error) { +func (repo *ProjectRepository) GetByName(ctx context.Context, name string) (models.ProjectSpec, error) { var r Project - if err := repo.db.Preload("Secrets").Where("name = ?", name).Find(&r).Error; err != nil { + if err := repo.db.WithContext(ctx).Preload("Secrets").Where("name = ?", name).First(&r).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return models.ProjectSpec{}, store.ErrResourceNotFound } @@ -113,9 +114,9 @@ func (repo *ProjectRepository) GetByName(name string) (models.ProjectSpec, error return r.ToSpecWithSecrets(repo.hash) } -func (repo *ProjectRepository) GetByID(id uuid.UUID) (models.ProjectSpec, error) { +func (repo *ProjectRepository) GetByID(ctx context.Context, id uuid.UUID) (models.ProjectSpec, error) { var r Project - if err := repo.db.Preload("Secrets").Where("id = ?", id).Find(&r).Error; err != nil { + if err := repo.db.WithContext(ctx).Preload("Secrets").Where("id = ?", id).First(&r).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return models.ProjectSpec{}, store.ErrResourceNotFound } @@ -124,10 +125,10 @@ func (repo *ProjectRepository) GetByID(id uuid.UUID) (models.ProjectSpec, error) return r.ToSpecWithSecrets(repo.hash) } -func (repo *ProjectRepository) GetAll() ([]models.ProjectSpec, error) { - specs := []models.ProjectSpec{} - projs := []Project{} - if err := repo.db.Preload("Secrets").Find(&projs).Error; err != nil { +func (repo *ProjectRepository) GetAll(ctx context.Context) ([]models.ProjectSpec, error) { + var specs []models.ProjectSpec + var projs []Project + if err := repo.db.WithContext(ctx).Preload("Secrets").Find(&projs).Error; err != nil { return specs, err } for _, proj := range projs { diff --git a/store/postgres/project_repository_test.go b/store/postgres/project_repository_test.go index c1206677c7..1046462bc2 100644 --- a/store/postgres/project_repository_test.go +++ b/store/postgres/project_repository_test.go @@ -3,14 +3,15 @@ package postgres import ( + "context" "os" "sort" "testing" "github.com/google/uuid" - "github.com/jinzhu/gorm" "github.com/odpf/optimus/models" "github.com/stretchr/testify/assert" + "gorm.io/gorm" ) func TestProjectRepository(t *testing.T) { @@ -36,6 +37,7 @@ func TestProjectRepository(t *testing.T) { return dbConn } + ctx := context.Background() hash, _ := models.NewApplicationSecret("32charshtesthashtesthashtesthash") @@ -68,108 +70,113 @@ func TestProjectRepository(t *testing.T) { t.Run("Insert", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() testModels := []models.ProjectSpec{} testModels = append(testModels, testConfigs...) repo := NewProjectRepository(db, hash) - err := repo.Insert(testModels[0]) + err := repo.Insert(ctx, testModels[0]) assert.Nil(t, err) - err = repo.Insert(testModels[1]) + err = repo.Insert(ctx, testModels[1]) assert.NotNil(t, err) - checkModel, err := repo.GetByID(testModels[0].ID) + checkModel, err := repo.GetByID(ctx, testModels[0].ID) assert.Nil(t, err) assert.Equal(t, "g-optimus", checkModel.Name) }) t.Run("Upsert", func(t *testing.T) { t.Run("insert different resource should insert two", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() testModelA := testConfigs[0] testModelB := testConfigs[2] repo := NewProjectRepository(db, hash) //try for create - err := repo.Save(testModelA) + err := repo.Save(ctx, testModelA) assert.Nil(t, err) - checkModel, err := repo.GetByID(testModelA.ID) + checkModel, err := repo.GetByID(ctx, testModelA.ID) assert.Nil(t, err) assert.Equal(t, "g-optimus", checkModel.Name) //try for update - err = repo.Save(testModelB) + err = repo.Save(ctx, testModelB) assert.Nil(t, err) - checkModel, err = repo.GetByID(testModelB.ID) + checkModel, err = repo.GetByID(ctx, testModelB.ID) assert.Nil(t, err) assert.Equal(t, "t-optimus", checkModel.Name) assert.Equal(t, "10.12.12.12:6668,10.12.12.13:6668", checkModel.Config[transporterKafkaBrokerKey]) }) t.Run("insert same resource twice should overwrite existing", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() testModelA := testConfigs[2] repo := NewProjectRepository(db, hash) //try for create testModelA.Config["bucket"] = "gs://some_folder" - err := repo.Save(testModelA) + err := repo.Save(ctx, testModelA) assert.Nil(t, err) - checkModel, err := repo.GetByID(testModelA.ID) + checkModel, err := repo.GetByID(ctx, testModelA.ID) assert.Nil(t, err) assert.Equal(t, "t-optimus", checkModel.Name) //try for update testModelA.Config["bucket"] = "gs://another_folder" - err = repo.Save(testModelA) + err = repo.Save(ctx, testModelA) assert.Nil(t, err) - checkModel, err = repo.GetByID(testModelA.ID) + checkModel, err = repo.GetByID(ctx, testModelA.ID) assert.Nil(t, err) assert.Equal(t, "gs://another_folder", checkModel.Config["bucket"]) }) t.Run("upsert without ID should auto generate it", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() testModelA := testConfigs[0] testModelA.ID = uuid.Nil repo := NewProjectRepository(db, hash) //try for create - err := repo.Save(testModelA) + err := repo.Save(ctx, testModelA) assert.Nil(t, err) - checkModel, err := repo.GetByName(testModelA.Name) + checkModel, err := repo.GetByName(ctx, testModelA.Name) assert.Nil(t, err) assert.Equal(t, "g-optimus", checkModel.Name) }) }) t.Run("GetByName", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() testModels := []models.ProjectSpec{} testModels = append(testModels, testConfigs...) repo := NewProjectRepository(db, hash) - err := repo.Insert(testModels[0]) + err := repo.Insert(ctx, testModels[0]) assert.Nil(t, err) - err = NewSecretRepository(db, testModels[0], hash).Save(models.ProjectSecretItem{ + err = NewSecretRepository(db, testModels[0], hash).Save(ctx, models.ProjectSecretItem{ Name: "t1", Value: "v1", }) assert.Nil(t, err) - checkModel, err := repo.GetByName(testModels[0].Name) + checkModel, err := repo.GetByName(ctx, testModels[0].Name) assert.Nil(t, err) assert.Equal(t, "g-optimus", checkModel.Name) @@ -178,27 +185,28 @@ func TestProjectRepository(t *testing.T) { }) t.Run("GetAll", func(t *testing.T) { db := DBSetup() - defer db.Close() - testModels := []models.ProjectSpec{} + sqlDB, _ := db.DB() + defer sqlDB.Close() + var testModels []models.ProjectSpec testModels = append(testModels, testConfigs...) repo := NewProjectRepository(db, hash) - assert.Nil(t, repo.Insert(testModels[2])) - assert.Nil(t, repo.Insert(testModels[3])) + assert.Nil(t, repo.Insert(ctx, testModels[2])) + assert.Nil(t, repo.Insert(ctx, testModels[3])) - err := NewSecretRepository(db, testModels[2], hash).Save(models.ProjectSecretItem{ + err := NewSecretRepository(db, testModels[2], hash).Save(ctx, models.ProjectSecretItem{ Name: "t1", Value: "v1", }) assert.Nil(t, err) - err = NewSecretRepository(db, testModels[3], hash).Save(models.ProjectSecretItem{ + err = NewSecretRepository(db, testModels[3], hash).Save(ctx, models.ProjectSecretItem{ Name: "t2", Value: "v2", }) assert.Nil(t, err) - checkModels, err := repo.GetAll() + checkModels, err := repo.GetAll(ctx) assert.Nil(t, err) sort.Slice(checkModels, func(i, j int) bool { return checkModels[i].Name < checkModels[j].Name diff --git a/store/postgres/replay_repository.go b/store/postgres/replay_repository.go index 153192d112..a60a624009 100644 --- a/store/postgres/replay_repository.go +++ b/store/postgres/replay_repository.go @@ -1,6 +1,7 @@ package postgres import ( + "context" "encoding/json" "errors" "time" @@ -10,9 +11,9 @@ import ( "gorm.io/datatypes" "github.com/google/uuid" - "github.com/jinzhu/gorm" "github.com/odpf/optimus/models" "github.com/odpf/optimus/store" + "gorm.io/gorm" ) type Replay struct { @@ -155,17 +156,17 @@ func NewReplayRepository(db *gorm.DB, jobAdapter *JobSpecAdapter) *replayReposit } } -func (repo *replayRepository) Insert(replay *models.ReplaySpec) error { +func (repo *replayRepository) Insert(ctx context.Context, replay *models.ReplaySpec) error { r, err := Replay{}.FromSpec(replay) if err != nil { return err } - return repo.DB.Create(&r).Error + return repo.DB.WithContext(ctx).Create(&r).Error } -func (repo *replayRepository) GetByID(id uuid.UUID) (models.ReplaySpec, error) { +func (repo *replayRepository) GetByID(ctx context.Context, id uuid.UUID) (models.ReplaySpec, error) { var r Replay - if err := repo.DB.Where("id = ?", id).Preload("Job").Find(&r).Error; err != nil { + if err := repo.DB.WithContext(ctx).Where("id = ?", id).Preload("Job").First(&r).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return models.ReplaySpec{}, store.ErrResourceNotFound } @@ -178,9 +179,9 @@ func (repo *replayRepository) GetByID(id uuid.UUID) (models.ReplaySpec, error) { return r.ToSpec(jobSpec) } -func (repo *replayRepository) UpdateStatus(replayID uuid.UUID, status string, message models.ReplayMessage) error { +func (repo *replayRepository) UpdateStatus(ctx context.Context, replayID uuid.UUID, status string, message models.ReplayMessage) error { var r Replay - if err := repo.DB.Where("id = ?", replayID).Find(&r).Error; err != nil { + if err := repo.DB.WithContext(ctx).Where("id = ?", replayID).Find(&r).Error; err != nil { return errors.New("could not update non-existing replay") } jsonBytes, err := json.Marshal(message) @@ -189,12 +190,12 @@ func (repo *replayRepository) UpdateStatus(replayID uuid.UUID, status string, me } r.Status = status r.Message = jsonBytes - return repo.DB.Save(&r).Error + return repo.DB.WithContext(ctx).Save(&r).Error } -func (repo *replayRepository) GetByStatus(status []string) ([]models.ReplaySpec, error) { +func (repo *replayRepository) GetByStatus(ctx context.Context, status []string) ([]models.ReplaySpec, error) { var replays []Replay - if err := repo.DB.Where("status in (?)", status).Preload("Job").Find(&replays).Error; err != nil { + if err := repo.DB.WithContext(ctx).Where("status in (?)", status).Preload("Job").Find(&replays).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return []models.ReplaySpec{}, store.ErrResourceNotFound } @@ -217,7 +218,7 @@ func (repo *replayRepository) GetByStatus(status []string) ([]models.ReplaySpec, return replaySpecs, nil } -func (repo *replayRepository) GetByJobIDAndStatus(jobID uuid.UUID, status []string) ([]models.ReplaySpec, error) { +func (repo *replayRepository) GetByJobIDAndStatus(ctx context.Context, jobID uuid.UUID, status []string) ([]models.ReplaySpec, error) { var replays []Replay if err := repo.DB.Where("job_id = ? and status in (?)", jobID, status).Preload("Job").Find(&replays).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { @@ -241,9 +242,9 @@ func (repo *replayRepository) GetByJobIDAndStatus(jobID uuid.UUID, status []stri return replaySpecs, nil } -func (repo *replayRepository) GetByProjectIDAndStatus(projectID uuid.UUID, status []string) ([]models.ReplaySpec, error) { +func (repo *replayRepository) GetByProjectIDAndStatus(ctx context.Context, projectID uuid.UUID, status []string) ([]models.ReplaySpec, error) { var replays []Replay - if err := repo.DB.Preload("Job").Joins("JOIN job ON replay.job_id = job.id"). + if err := repo.DB.WithContext(ctx).Preload("Job").Joins("JOIN job ON replay.job_id = job.id"). Where("job.project_id = ? and status in (?)", projectID, status).Find(&replays).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return []models.ReplaySpec{}, store.ErrResourceNotFound @@ -266,9 +267,9 @@ func (repo *replayRepository) GetByProjectIDAndStatus(projectID uuid.UUID, statu return replaySpecs, nil } -func (repo *replayRepository) GetByProjectID(projectID uuid.UUID) ([]models.ReplaySpec, error) { +func (repo *replayRepository) GetByProjectID(ctx context.Context, projectID uuid.UUID) ([]models.ReplaySpec, error) { var replays []Replay - if err := repo.DB.Preload("Job").Joins("JOIN job ON replay.job_id = job.id"). + if err := repo.DB.WithContext(ctx).Preload("Job").Joins("JOIN job ON replay.job_id = job.id"). Where("job.project_id = ?", projectID).Order("created_at DESC").Find(&replays).Error; err != nil { return []models.ReplaySpec{}, err } diff --git a/store/postgres/replay_repository_test.go b/store/postgres/replay_repository_test.go index 211061d1ab..21a52e8abb 100644 --- a/store/postgres/replay_repository_test.go +++ b/store/postgres/replay_repository_test.go @@ -13,10 +13,10 @@ import ( "github.com/odpf/optimus/core/tree" "github.com/google/uuid" - "github.com/jinzhu/gorm" "github.com/odpf/optimus/mock" "github.com/odpf/optimus/models" "github.com/stretchr/testify/assert" + "gorm.io/gorm" ) func treeIsEqual(treeNode *tree.TreeNode, treeNodeComparator *tree.TreeNode) bool { @@ -37,6 +37,7 @@ func treeIsEqual(treeNode *tree.TreeNode, treeNodeComparator *tree.TreeNode) boo } func TestReplayRepository(t *testing.T) { + ctx := context.Background() projectSpec := models.ProjectSpec{ ID: uuid.Must(uuid.NewRandom()), Name: "t-optimus-id", @@ -155,7 +156,8 @@ func TestReplayRepository(t *testing.T) { t.Run("Insert and GetByID", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() execUnit1 := new(mock.BasePlugin) defer execUnit1.AssertExpectations(t) @@ -179,14 +181,14 @@ func TestReplayRepository(t *testing.T) { projectJobSpecRepo := NewProjectJobSpecRepository(db, projectSpec, adapter) jobRepo := NewJobSpecRepository(db, namespaceSpec, projectJobSpecRepo, adapter) - err := jobRepo.Insert(jobConfigs[0]) + err := jobRepo.Insert(ctx, jobConfigs[0]) assert.Nil(t, err) repo := NewReplayRepository(db, adapter) - err = repo.Insert(testModels[0]) + err = repo.Insert(ctx, testModels[0]) assert.Nil(t, err) - checkModel, err := repo.GetByID(testModels[0].ID) + checkModel, err := repo.GetByID(ctx, testModels[0].ID) assert.Nil(t, err) assert.Equal(t, testModels[0].ID, checkModel.ID) assert.True(t, treeIsEqual(testModels[0].ExecutionTree, checkModel.ExecutionTree)) @@ -194,7 +196,8 @@ func TestReplayRepository(t *testing.T) { t.Run("UpdateStatus", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() var testModels []*models.ReplaySpec testModels = append(testModels, testConfigs...) @@ -216,11 +219,11 @@ func TestReplayRepository(t *testing.T) { projectJobSpecRepo := NewProjectJobSpecRepository(db, projectSpec, adapter) jobRepo := NewJobSpecRepository(db, namespaceSpec, projectJobSpecRepo, adapter) - err := jobRepo.Insert(jobConfigs[0]) + err := jobRepo.Insert(ctx, jobConfigs[0]) assert.Nil(t, err) repo := NewReplayRepository(db, adapter) - err = repo.Insert(testModels[0]) + err = repo.Insert(ctx, testModels[0]) assert.Nil(t, err) errMessage := "failed to execute" @@ -228,10 +231,10 @@ func TestReplayRepository(t *testing.T) { Type: "test failure", Message: errMessage, } - err = repo.UpdateStatus(testModels[0].ID, models.ReplayStatusFailed, replayMessage) + err = repo.UpdateStatus(ctx, testModels[0].ID, models.ReplayStatusFailed, replayMessage) assert.Nil(t, err) - checkModel, err := repo.GetByID(testModels[0].ID) + checkModel, err := repo.GetByID(ctx, testModels[0].ID) assert.Nil(t, err) assert.Equal(t, models.ReplayStatusFailed, checkModel.Status) assert.Equal(t, errMessage, checkModel.Message.Message) @@ -240,7 +243,8 @@ func TestReplayRepository(t *testing.T) { t.Run("GetByStatus", func(t *testing.T) { t.Run("should return list of job specs given list of status", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() var testModels []*models.ReplaySpec testModels = append(testModels, testConfigs...) @@ -271,23 +275,23 @@ func TestReplayRepository(t *testing.T) { projectJobSpecRepo := NewProjectJobSpecRepository(db, projectSpec, adapter) jobRepo := NewJobSpecRepository(db, namespaceSpec, projectJobSpecRepo, adapter) - err := jobRepo.Insert(testModels[0].Job) + err := jobRepo.Insert(ctx, testModels[0].Job) assert.Nil(t, err) - err = jobRepo.Insert(testModels[1].Job) + err = jobRepo.Insert(ctx, testModels[1].Job) assert.Nil(t, err) - err = jobRepo.Insert(testModels[2].Job) + err = jobRepo.Insert(ctx, testModels[2].Job) assert.Nil(t, err) repo := NewReplayRepository(db, adapter) - err = repo.Insert(testModels[0]) + err = repo.Insert(ctx, testModels[0]) assert.Nil(t, err) - err = repo.Insert(testModels[1]) + err = repo.Insert(ctx, testModels[1]) assert.Nil(t, err) - err = repo.Insert(testModels[2]) + err = repo.Insert(ctx, testModels[2]) assert.Nil(t, err) statusList := []string{models.ReplayStatusAccepted, models.ReplayStatusInProgress} - replays, err := repo.GetByStatus(statusList) + replays, err := repo.GetByStatus(ctx, statusList) assert.Nil(t, err) assert.Equal(t, jobConfigs[0].ID, replays[0].Job.ID) assert.Equal(t, jobConfigs[2].ID, replays[1].Job.ID) @@ -297,7 +301,8 @@ func TestReplayRepository(t *testing.T) { t.Run("GetByJobIDAndStatus", func(t *testing.T) { t.Run("should return list of replay specs given job_id and list of status", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() var testModels []*models.ReplaySpec testModels = append(testModels, testConfigs...) @@ -326,23 +331,23 @@ func TestReplayRepository(t *testing.T) { projectJobSpecRepo := NewProjectJobSpecRepository(db, projectSpec, adapter) jobRepo := NewJobSpecRepository(db, namespaceSpec, projectJobSpecRepo, adapter) - err := jobRepo.Insert(testModels[0].Job) + err := jobRepo.Insert(ctx, testModels[0].Job) assert.Nil(t, err) - err = jobRepo.Insert(testModels[1].Job) + err = jobRepo.Insert(ctx, testModels[1].Job) assert.Nil(t, err) - err = jobRepo.Insert(testModels[2].Job) + err = jobRepo.Insert(ctx, testModels[2].Job) assert.Nil(t, err) repo := NewReplayRepository(db, adapter) - err = repo.Insert(testModels[0]) + err = repo.Insert(ctx, testModels[0]) assert.Nil(t, err) - err = repo.Insert(testModels[1]) + err = repo.Insert(ctx, testModels[1]) assert.Nil(t, err) - err = repo.Insert(testModels[2]) + err = repo.Insert(ctx, testModels[2]) assert.Nil(t, err) statusList := []string{models.ReplayStatusAccepted, models.ReplayStatusInProgress} - replays, err := repo.GetByJobIDAndStatus(testModels[2].Job.ID, statusList) + replays, err := repo.GetByJobIDAndStatus(ctx, testModels[2].Job.ID, statusList) assert.Nil(t, err) assert.Equal(t, jobConfigs[2].ID, replays[0].Job.ID) }) @@ -350,7 +355,8 @@ func TestReplayRepository(t *testing.T) { t.Run("GetByProjectIDAndStatus", func(t *testing.T) { t.Run("should return list of replay specs given project_id and list of status", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() var testModels []*models.ReplaySpec testModels = append(testModels, testConfigs...) @@ -381,25 +387,25 @@ func TestReplayRepository(t *testing.T) { jobRepo := NewJobSpecRepository(db, namespaceSpec, projectJobSpecRepo, adapter) projectRepo := NewProjectRepository(db, hash) - err := projectRepo.Insert(projectSpec) + err := projectRepo.Insert(ctx, projectSpec) assert.Nil(t, err) - err = jobRepo.Insert(testModels[0].Job) + err = jobRepo.Insert(ctx, testModels[0].Job) assert.Nil(t, err) - err = jobRepo.Insert(testModels[1].Job) + err = jobRepo.Insert(ctx, testModels[1].Job) assert.Nil(t, err) - err = jobRepo.Insert(testModels[2].Job) + err = jobRepo.Insert(ctx, testModels[2].Job) assert.Nil(t, err) repo := NewReplayRepository(db, adapter) - err = repo.Insert(testModels[0]) + err = repo.Insert(ctx, testModels[0]) assert.Nil(t, err) - err = repo.Insert(testModels[1]) + err = repo.Insert(ctx, testModels[1]) assert.Nil(t, err) - err = repo.Insert(testModels[2]) + err = repo.Insert(ctx, testModels[2]) assert.Nil(t, err) statusList := []string{models.ReplayStatusAccepted, models.ReplayStatusInProgress} - replays, err := repo.GetByProjectIDAndStatus(projectSpec.ID, statusList) + replays, err := repo.GetByProjectIDAndStatus(ctx, projectSpec.ID, statusList) assert.Nil(t, err) assert.ElementsMatch(t, []uuid.UUID{testModels[0].ID, testModels[2].ID}, []uuid.UUID{replays[0].ID, replays[1].ID}) }) @@ -407,7 +413,8 @@ func TestReplayRepository(t *testing.T) { t.Run("GetByProjectID", func(t *testing.T) { t.Run("should return list of replay specs given project_id", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() var testModels []*models.ReplaySpec testModels = append(testModels, testConfigs...) expectedUUIDs := []uuid.UUID{testModels[0].ID, testModels[1].ID, testModels[2].ID} @@ -439,30 +446,31 @@ func TestReplayRepository(t *testing.T) { jobRepo := NewJobSpecRepository(db, namespaceSpec, projectJobSpecRepo, adapter) projectRepo := NewProjectRepository(db, hash) - err := projectRepo.Insert(projectSpec) + err := projectRepo.Insert(ctx, projectSpec) assert.Nil(t, err) - err = jobRepo.Insert(testModels[0].Job) + err = jobRepo.Insert(ctx, testModels[0].Job) assert.Nil(t, err) - err = jobRepo.Insert(testModels[1].Job) + err = jobRepo.Insert(ctx, testModels[1].Job) assert.Nil(t, err) - err = jobRepo.Insert(testModels[2].Job) + err = jobRepo.Insert(ctx, testModels[2].Job) assert.Nil(t, err) repo := NewReplayRepository(db, adapter) - err = repo.Insert(testModels[0]) + err = repo.Insert(ctx, testModels[0]) assert.Nil(t, err) - err = repo.Insert(testModels[1]) + err = repo.Insert(ctx, testModels[1]) assert.Nil(t, err) - err = repo.Insert(testModels[2]) + err = repo.Insert(ctx, testModels[2]) assert.Nil(t, err) - replays, err := repo.GetByProjectID(projectSpec.ID) + replays, err := repo.GetByProjectID(ctx, projectSpec.ID) assert.Nil(t, err) assert.ElementsMatch(t, expectedUUIDs, []uuid.UUID{replays[0].ID, replays[1].ID, replays[2].ID}) }) t.Run("should return not found if no recent replay is found", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() var testModels []*models.ReplaySpec testModels = append(testModels, testConfigs...) @@ -492,24 +500,24 @@ func TestReplayRepository(t *testing.T) { jobRepo := NewJobSpecRepository(db, namespaceSpec, projectJobSpecRepo, adapter) projectRepo := NewProjectRepository(db, hash) - err := projectRepo.Insert(projectSpec) + err := projectRepo.Insert(ctx, projectSpec) assert.Nil(t, err) - err = jobRepo.Insert(testModels[0].Job) + err = jobRepo.Insert(ctx, testModels[0].Job) assert.Nil(t, err) - err = jobRepo.Insert(testModels[1].Job) + err = jobRepo.Insert(ctx, testModels[1].Job) assert.Nil(t, err) - err = jobRepo.Insert(testModels[2].Job) + err = jobRepo.Insert(ctx, testModels[2].Job) assert.Nil(t, err) repo := NewReplayRepository(db, adapter) - err = repo.Insert(testModels[0]) + err = repo.Insert(ctx, testModels[0]) assert.Nil(t, err) - err = repo.Insert(testModels[1]) + err = repo.Insert(ctx, testModels[1]) assert.Nil(t, err) - err = repo.Insert(testModels[2]) + err = repo.Insert(ctx, testModels[2]) assert.Nil(t, err) - replays, err := repo.GetByProjectID(uuid.Must(uuid.NewRandom())) + replays, err := repo.GetByProjectID(ctx, uuid.Must(uuid.NewRandom())) assert.Equal(t, store.ErrResourceNotFound, err) assert.Equal(t, []models.ReplaySpec{}, replays) }) diff --git a/store/postgres/resource_spec_repository.go b/store/postgres/resource_spec_repository.go index fc64393864..67c04beeb3 100644 --- a/store/postgres/resource_spec_repository.go +++ b/store/postgres/resource_spec_repository.go @@ -1,6 +1,7 @@ package postgres import ( + "context" "encoding/json" "fmt" "time" @@ -10,13 +11,13 @@ import ( "github.com/odpf/optimus/store" "github.com/google/uuid" - "github.com/jinzhu/gorm" "github.com/odpf/optimus/models" "github.com/pkg/errors" + "gorm.io/gorm" ) type Resource struct { - ID uuid.UUID `gorm:"primary_key;type:uuid"` + ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4()"` ProjectID uuid.UUID Project Project `gorm:"foreignKey:ProjectID"` @@ -36,7 +37,7 @@ type Resource struct { CreatedAt time.Time `gorm:"not null" json:"created_at"` UpdatedAt time.Time `gorm:"not null" json:"updated_at"` - DeletedAt *time.Time + DeletedAt gorm.DeletedAt } func (r Resource) FromSpec(resourceSpec models.ResourceSpec) (Resource, error) { @@ -147,9 +148,10 @@ type projectResourceSpecRepository struct { datastore models.Datastorer } -func (repo *projectResourceSpecRepository) GetByName(name string) (models.ResourceSpec, models.NamespaceSpec, error) { +func (repo *projectResourceSpecRepository) GetByName(ctx context.Context, name string) (models.ResourceSpec, models.NamespaceSpec, error) { var r Resource - if err := repo.db.Preload("Namespace").Where("project_id = ? AND datastore = ? AND name = ?", repo.project.ID, repo.datastore.Name(), name).Find(&r).Error; err != nil { + if err := repo.db.WithContext(ctx).Preload("Namespace").Where("project_id = ? AND datastore = ? AND name = ?", + repo.project.ID, repo.datastore.Name(), name).First(&r).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return models.ResourceSpec{}, models.NamespaceSpec{}, store.ErrResourceNotFound } @@ -169,10 +171,10 @@ func (repo *projectResourceSpecRepository) GetByName(name string) (models.Resour return resourceSpec, namespaceSpec, nil } -func (repo *projectResourceSpecRepository) GetAll() ([]models.ResourceSpec, error) { +func (repo *projectResourceSpecRepository) GetAll(ctx context.Context) ([]models.ResourceSpec, error) { specs := []models.ResourceSpec{} resources := []Resource{} - if err := repo.db.Where("project_id = ? AND datastore = ?", repo.project.ID, repo.datastore.Name()).Find(&resources).Error; err != nil { + if err := repo.db.WithContext(ctx).Where("project_id = ? AND datastore = ?", repo.project.ID, repo.datastore.Name()).Find(&resources).Error; err != nil { return specs, err } for _, r := range resources { @@ -200,7 +202,7 @@ type resourceSpecRepository struct { projectResourceSpecRepo store.ProjectResourceSpecRepository } -func (repo *resourceSpecRepository) Insert(resource models.ResourceSpec) error { +func (repo *resourceSpecRepository) Insert(ctx context.Context, resource models.ResourceSpec) error { if len(resource.Name) == 0 { return errors.New("name cannot be empty") } @@ -209,16 +211,16 @@ func (repo *resourceSpecRepository) Insert(resource models.ResourceSpec) error { return err } // if soft deleted earlier - if err := repo.HardDelete(resource.Name); err != nil { + if err := repo.HardDelete(ctx, resource.Name); err != nil { return err } - return repo.db.Create(&p).Error + return repo.db.WithContext(ctx).Create(&p).Error } -func (repo *resourceSpecRepository) Save(spec models.ResourceSpec) error { - existingResource, namespaceSpec, err := repo.projectResourceSpecRepo.GetByName(spec.Name) +func (repo *resourceSpecRepository) Save(ctx context.Context, spec models.ResourceSpec) error { + existingResource, namespaceSpec, err := repo.projectResourceSpecRepo.GetByName(ctx, spec.Name) if errors.Is(err, store.ErrResourceNotFound) { - return repo.Insert(spec) + return repo.Insert(ctx, spec) } else if err != nil { return errors.Wrap(err, "unable to find resource by name") } @@ -233,12 +235,13 @@ func (repo *resourceSpecRepository) Save(spec models.ResourceSpec) error { } resource.ID = existingResource.ID - return repo.db.Model(&resource).Updates(&resource).Error + return repo.db.WithContext(ctx).Model(&resource).Updates(&resource).Error } -func (repo *resourceSpecRepository) GetByName(name string) (models.ResourceSpec, error) { +func (repo *resourceSpecRepository) GetByName(ctx context.Context, name string) (models.ResourceSpec, error) { var r Resource - if err := repo.db.Where("namespace_id = ? AND datastore = ? AND name = ?", repo.namespace.ID, repo.datastore.Name(), name).Find(&r).Error; err != nil { + if err := repo.db.WithContext(ctx).Where("namespace_id = ? AND datastore = ? AND name = ?", + repo.namespace.ID, repo.datastore.Name(), name).First(&r).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return models.ResourceSpec{}, store.ErrResourceNotFound } @@ -247,9 +250,10 @@ func (repo *resourceSpecRepository) GetByName(name string) (models.ResourceSpec, return r.ToSpec(repo.datastore) } -func (repo *resourceSpecRepository) GetByID(id uuid.UUID) (models.ResourceSpec, error) { +func (repo *resourceSpecRepository) GetByID(ctx context.Context, id uuid.UUID) (models.ResourceSpec, error) { var r Resource - if err := repo.db.Where("namespace_id = ? AND id = ?", repo.namespace.ID, id).Find(&r).Error; err != nil { + if err := repo.db.WithContext(ctx).Where("namespace_id = ? AND id = ?", + repo.namespace.ID, id).First(&r).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return models.ResourceSpec{}, store.ErrResourceNotFound } @@ -258,9 +262,10 @@ func (repo *resourceSpecRepository) GetByID(id uuid.UUID) (models.ResourceSpec, return r.ToSpec(repo.datastore) } -func (repo *resourceSpecRepository) GetByURN(urn string) (models.ResourceSpec, error) { +func (repo *resourceSpecRepository) GetByURN(ctx context.Context, urn string) (models.ResourceSpec, error) { var r Resource - if err := repo.db.Where("namespace_id = ? AND datastore = ? AND urn = ?", repo.namespace.ID, repo.datastore.Name(), urn).Find(&r).Error; err != nil { + if err := repo.db.WithContext(ctx).Where("namespace_id = ? AND datastore = ? AND urn = ?", + repo.namespace.ID, repo.datastore.Name(), urn).First(&r).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return models.ResourceSpec{}, store.ErrResourceNotFound } @@ -269,10 +274,10 @@ func (repo *resourceSpecRepository) GetByURN(urn string) (models.ResourceSpec, e return r.ToSpec(repo.datastore) } -func (repo *resourceSpecRepository) GetAll() ([]models.ResourceSpec, error) { +func (repo *resourceSpecRepository) GetAll(ctx context.Context) ([]models.ResourceSpec, error) { specs := []models.ResourceSpec{} resources := []Resource{} - if err := repo.db.Where("namespace_id = ? AND datastore = ?", repo.namespace.ID, repo.datastore.Name()).Find(&resources).Error; err != nil { + if err := repo.db.WithContext(ctx).Where("namespace_id = ? AND datastore = ?", repo.namespace.ID, repo.datastore.Name()).Find(&resources).Error; err != nil { return specs, err } for _, r := range resources { @@ -285,12 +290,12 @@ func (repo *resourceSpecRepository) GetAll() ([]models.ResourceSpec, error) { return specs, nil } -func (repo *resourceSpecRepository) Delete(name string) error { - return repo.db.Where("namespace_id = ? AND datastore = ? AND name = ? ", repo.namespace.ID, repo.datastore.Name(), name).Delete(&Resource{}).Error +func (repo *resourceSpecRepository) Delete(ctx context.Context, name string) error { + return repo.db.WithContext(ctx).Where("namespace_id = ? AND datastore = ? AND name = ? ", repo.namespace.ID, repo.datastore.Name(), name).Delete(&Resource{}).Error } -func (repo *resourceSpecRepository) HardDelete(name string) error { - return repo.db.Unscoped().Where("namespace_id = ? AND datastore = ? AND name = ? ", repo.namespace.ID, repo.datastore.Name(), name).Delete(&Resource{}).Error +func (repo *resourceSpecRepository) HardDelete(ctx context.Context, name string) error { + return repo.db.WithContext(ctx).Unscoped().Where("namespace_id = ? AND datastore = ? AND name = ? ", repo.namespace.ID, repo.datastore.Name(), name).Delete(&Resource{}).Error } func NewResourceSpecRepository(db *gorm.DB, namespace models.NamespaceSpec, ds models.Datastorer, projectResourceSpecRepo store.ProjectResourceSpecRepository) *resourceSpecRepository { diff --git a/store/postgres/resource_spec_repository_test.go b/store/postgres/resource_spec_repository_test.go index 373c6e106e..3a8e045e11 100644 --- a/store/postgres/resource_spec_repository_test.go +++ b/store/postgres/resource_spec_repository_test.go @@ -3,19 +3,21 @@ package postgres import ( + "context" "os" "testing" "github.com/odpf/optimus/mock" "github.com/google/uuid" - "github.com/jinzhu/gorm" "github.com/odpf/optimus/models" "github.com/stretchr/testify/assert" testMock "github.com/stretchr/testify/mock" + "gorm.io/gorm" ) func TestResourceSpecRepository(t *testing.T) { + ctx := context.Background() projectSpec := models.ProjectSpec{ ID: uuid.Must(uuid.NewRandom()), Name: "t-optimus-project", @@ -59,7 +61,7 @@ func TestResourceSpecRepository(t *testing.T) { } projRepo := NewProjectRepository(dbConn, hash) - assert.Nil(t, projRepo.Save(projectSpec)) + assert.Nil(t, projRepo.Save(ctx, projectSpec)) return dbConn } testConfigs := []models.ResourceSpec{ @@ -119,7 +121,8 @@ func TestResourceSpecRepository(t *testing.T) { t.Run("Insert", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() testModels := []models.ResourceSpec{} testModels = append(testModels, testConfigs...) @@ -129,13 +132,13 @@ func TestResourceSpecRepository(t *testing.T) { projectResourceSpecRepo := NewProjectResourceSpecRepository(db, projectSpec, datastorer) repo := NewResourceSpecRepository(db, namespaceSpec, datastorer, projectResourceSpecRepo) - err := repo.Insert(testModels[0]) + err := repo.Insert(ctx, testModels[0]) assert.Nil(t, err) - err = repo.Insert(testModels[1]) + err = repo.Insert(ctx, testModels[1]) assert.NotNil(t, err) - checkModel, err := repo.GetByID(testModels[0].ID) + checkModel, err := repo.GetByID(ctx, testModels[0].ID) assert.Nil(t, err) assert.Equal(t, "proj.datas.test", checkModel.Name) }) @@ -143,7 +146,8 @@ func TestResourceSpecRepository(t *testing.T) { t.Run("Upsert", func(t *testing.T) { t.Run("insert different resource should insert two", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() testModelA := testConfigs[0] testModelB := testConfigs[2] @@ -153,25 +157,26 @@ func TestResourceSpecRepository(t *testing.T) { dsTypeTableController.On("GenerateURN", testMock.Anything).Return(testModelA.URN, nil).Once() //try for create - err := repo.Save(testModelA) + err := repo.Save(ctx, testModelA) assert.Nil(t, err) - checkModel, err := repo.GetByID(testModelA.ID) + checkModel, err := repo.GetByID(ctx, testModelA.ID) assert.Nil(t, err) assert.Equal(t, "proj.datas.test", checkModel.Name) //try for create - err = repo.Save(testModelB) + err = repo.Save(ctx, testModelB) assert.Nil(t, err) - checkModel, err = repo.GetByID(testModelB.ID) + checkModel, err = repo.GetByID(ctx, testModelB.ID) assert.Nil(t, err) assert.Equal(t, "proj.ttt.test2", checkModel.Name) assert.Equal(t, "table", checkModel.Type.String()) }) t.Run("insert same resource twice should overwrite existing", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() testModelA := testConfigs[2] projectResourceSpecRepo := NewProjectResourceSpecRepository(db, projectSpec, datastorer) @@ -180,10 +185,10 @@ func TestResourceSpecRepository(t *testing.T) { dsTypeTableController.On("GenerateURN", testMock.Anything).Return(testModelA.URN, nil).Twice() //try for create - err := repo.Save(testModelA) + err := repo.Save(ctx, testModelA) assert.Nil(t, err) - checkModel, err := repo.GetByID(testModelA.ID) + checkModel, err := repo.GetByID(ctx, testModelA.ID) assert.Nil(t, err) assert.Equal(t, "proj.ttt.test2", checkModel.Name) @@ -192,16 +197,17 @@ func TestResourceSpecRepository(t *testing.T) { dsTypeTableAdapter.On("ToYaml", testModelA).Return([]byte("some binary data testModelA"), nil) dsTypeTableAdapter.On("FromYaml", []byte("some binary data testModelA")).Return(testModelA, nil) - err = repo.Save(testModelA) + err = repo.Save(ctx, testModelA) assert.Nil(t, err) - checkModel, err = repo.GetByID(testModelA.ID) + checkModel, err = repo.GetByID(ctx, testModelA.ID) assert.Nil(t, err) assert.Equal(t, 6, checkModel.Version) }) t.Run("upsert without ID should auto generate it", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() resourceSpecWithEmptyUUID := testConfigWithoutAssets[0] resourceSpecWithEmptyUUID.ID = uuid.Nil @@ -226,17 +232,18 @@ func TestResourceSpecRepository(t *testing.T) { dsTypeTableControllerLocal.On("GenerateURN", testMock.Anything).Return(resourceSpecWithEmptyUUID.URN, nil).Once() - //try for create - err := repo.Save(resourceSpecWithEmptyUUID) + // try for create + err := repo.Save(ctx, resourceSpecWithEmptyUUID) assert.Nil(t, err) - checkModel, err := repo.GetByName(resourceSpecWithEmptyUUID.Name) + checkModel, err := repo.GetByName(ctx, resourceSpecWithEmptyUUID.Name) assert.Nil(t, err) assert.Equal(t, "proj.datas.test", checkModel.Name) }) t.Run("should fail if resource is already registered for a project with different namespace", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() testModelA := testConfigs[2] projectResourceSpecRepo := NewProjectResourceSpecRepository(db, projectSpec, datastorer) @@ -246,17 +253,17 @@ func TestResourceSpecRepository(t *testing.T) { dsTypeTableController.On("GenerateURN", testMock.Anything).Return(testModelA.URN, nil).Twice() //try for create - err := resourceSpecNamespace1.Save(testModelA) + err := resourceSpecNamespace1.Save(ctx, testModelA) assert.Nil(t, err) - checkModel, checkNamespace, err := projectResourceSpecRepo.GetByName(testModelA.Name) + checkModel, checkNamespace, err := projectResourceSpecRepo.GetByName(ctx, testModelA.Name) assert.Nil(t, err) assert.Equal(t, "proj.ttt.test2", checkModel.Name) assert.Equal(t, namespaceSpec.ID, checkNamespace.ID) assert.Equal(t, namespaceSpec.ProjectSpec.ID, checkNamespace.ProjectSpec.ID) // try to create same resource with second client and it should fail. - err = resourceSpecNamespace2.Save(testModelA) + err = resourceSpecNamespace2.Save(ctx, testModelA) assert.NotNil(t, err) assert.Equal(t, "resource proj.ttt.test2 already exists for the project t-optimus-project", err.Error()) }) @@ -264,23 +271,25 @@ func TestResourceSpecRepository(t *testing.T) { t.Run("GetByName", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() testModels := []models.ResourceSpec{} testModels = append(testModels, testConfigs...) projectResourceSpecRepo := NewProjectResourceSpecRepository(db, projectSpec, datastorer) repo := NewResourceSpecRepository(db, namespaceSpec, datastorer, projectResourceSpecRepo) - err := repo.Insert(testModels[0]) + err := repo.Insert(ctx, testModels[0]) assert.Nil(t, err) - checkModel, err := repo.GetByName(testModels[0].Name) + checkModel, err := repo.GetByName(ctx, testModels[0].Name) assert.Nil(t, err) assert.Equal(t, "proj.datas.test", checkModel.Name) }) } func TestProjectResourceSpecRepository(t *testing.T) { + ctx := context.Background() projectSpec := models.ProjectSpec{ ID: uuid.Must(uuid.NewRandom()), Name: "t-optimus-project", @@ -324,7 +333,7 @@ func TestProjectResourceSpecRepository(t *testing.T) { } projRepo := NewProjectRepository(dbConn, hash) - assert.Nil(t, projRepo.Save(projectSpec)) + assert.Nil(t, projRepo.Save(ctx, projectSpec)) return dbConn } testConfigs := []models.ResourceSpec{ @@ -378,7 +387,8 @@ func TestProjectResourceSpecRepository(t *testing.T) { t.Run("GetByName", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() testModels := []models.ResourceSpec{} testModels = append(testModels, testConfigs...) @@ -387,24 +397,25 @@ func TestProjectResourceSpecRepository(t *testing.T) { dsTypeTableController.On("GenerateURN", testMock.Anything).Return(testModels[0].URN, nil).Once() - err := repo.Insert(testModels[0]) + err := repo.Insert(ctx, testModels[0]) assert.Nil(t, err) // validate at project level - checkModel, checkClient, err := projectResourceSpecRepo.GetByName(testModels[0].Name) + checkModel, checkClient, err := projectResourceSpecRepo.GetByName(ctx, testModels[0].Name) assert.Nil(t, err) assert.Equal(t, "proj.datas.test", checkModel.Name) assert.Equal(t, namespaceSpec.Name, checkClient.Name) // validate at client level - checkModel, err = repo.GetByName(testModels[0].Name) + checkModel, err = repo.GetByName(ctx, testModels[0].Name) assert.Nil(t, err) assert.Equal(t, "proj.datas.test", checkModel.Name) }) t.Run("GetByURN", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() var testModels []models.ResourceSpec testModels = append(testModels, testConfigs...) @@ -413,17 +424,18 @@ func TestProjectResourceSpecRepository(t *testing.T) { dsTypeTableController.On("GenerateURN", testMock.Anything).Return(testModels[0].URN, nil).Once() - err := repo.Insert(testModels[0]) + err := repo.Insert(ctx, testModels[0]) assert.Nil(t, err) - checkModel, err := repo.GetByURN(testModels[0].URN) + checkModel, err := repo.GetByURN(ctx, testModels[0].URN) assert.Nil(t, err) assert.Equal(t, "proj.datas.test", checkModel.Name) }) t.Run("GetAll", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() testModels := []models.ResourceSpec{} testModels = append(testModels, testConfigs...) @@ -432,17 +444,17 @@ func TestProjectResourceSpecRepository(t *testing.T) { dsTypeTableController.On("GenerateURN", testMock.Anything).Return(testModels[0].URN, nil).Once() - err := repo.Insert(testModels[0]) + err := repo.Insert(ctx, testModels[0]) assert.Nil(t, err) // validate at project level - checkModels, err := projectResourceSpecRepo.GetAll() + checkModels, err := projectResourceSpecRepo.GetAll(ctx) assert.Nil(t, err) assert.Equal(t, "proj.datas.test", checkModels[0].Name) assert.Equal(t, 1, len(checkModels)) // validate at client level - checkModels, err = repo.GetAll() + checkModels, err = repo.GetAll(ctx) assert.Nil(t, err) assert.Equal(t, "proj.datas.test", checkModels[0].Name) assert.Equal(t, 1, len(checkModels)) diff --git a/store/postgres/secret_repository.go b/store/postgres/secret_repository.go index c57715878d..e04ab34ca1 100644 --- a/store/postgres/secret_repository.go +++ b/store/postgres/secret_repository.go @@ -1,6 +1,7 @@ package postgres import ( + "context" "encoding/base64" "time" @@ -8,13 +9,13 @@ import ( "github.com/google/uuid" "github.com/gtank/cryptopasta" - "github.com/jinzhu/gorm" "github.com/odpf/optimus/models" "github.com/pkg/errors" + "gorm.io/gorm" ) type Secret struct { - ID uuid.UUID `gorm:"primary_key;type:uuid"` + ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4()"` ProjectID uuid.UUID Project Project `gorm:"foreignKey:ProjectID"` @@ -23,7 +24,7 @@ type Secret struct { CreatedAt time.Time `gorm:"not null" json:"created_at"` UpdatedAt time.Time `gorm:"not null" json:"updated_at"` - DeletedAt *time.Time + DeletedAt gorm.DeletedAt } func (p Secret) FromSpec(spec models.ProjectSecretItem, proj models.ProjectSpec, hash models.ApplicationKey) (Secret, error) { @@ -71,7 +72,7 @@ type secretRepository struct { hash models.ApplicationKey } -func (repo *secretRepository) Insert(resource models.ProjectSecretItem) error { +func (repo *secretRepository) Insert(ctx context.Context, resource models.ProjectSecretItem) error { p, err := Secret{}.FromSpec(resource, repo.project, repo.hash) if err != nil { return err @@ -79,13 +80,13 @@ func (repo *secretRepository) Insert(resource models.ProjectSecretItem) error { if len(p.Name) == 0 { return errors.New("name cannot be empty") } - return repo.db.Create(&p).Error + return repo.db.WithContext(ctx).Save(&p).Error } -func (repo *secretRepository) Save(spec models.ProjectSecretItem) error { - existingResource, err := repo.GetByName(spec.Name) +func (repo *secretRepository) Save(ctx context.Context, spec models.ProjectSecretItem) error { + existingResource, err := repo.GetByName(ctx, spec.Name) if errors.Is(err, store.ErrResourceNotFound) { - return repo.Insert(spec) + return repo.Insert(ctx, spec) } else if err != nil { return errors.Wrap(err, "unable to find secret by name") } @@ -96,12 +97,12 @@ func (repo *secretRepository) Save(spec models.ProjectSecretItem) error { if err == nil { resource.ID = existingResource.ID } - return repo.db.Model(&resource).Updates(&resource).Error + return repo.db.WithContext(ctx).Model(&resource).Updates(&resource).Error } -func (repo *secretRepository) GetByName(name string) (models.ProjectSecretItem, error) { +func (repo *secretRepository) GetByName(ctx context.Context, name string) (models.ProjectSecretItem, error) { var r Secret - if err := repo.db.Where("name = ? AND project_id = ?", name, repo.project.ID).Find(&r).Error; err != nil { + if err := repo.db.WithContext(ctx).Where("name = ? AND project_id = ?", name, repo.project.ID).First(&r).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return models.ProjectSecretItem{}, store.ErrResourceNotFound } @@ -110,9 +111,9 @@ func (repo *secretRepository) GetByName(name string) (models.ProjectSecretItem, return r.ToSpec(repo.hash) } -func (repo *secretRepository) GetByID(id uuid.UUID) (models.ProjectSecretItem, error) { +func (repo *secretRepository) GetByID(ctx context.Context, id uuid.UUID) (models.ProjectSecretItem, error) { var r Secret - if err := repo.db.Where("id = ?", id).Find(&r).Error; err != nil { + if err := repo.db.WithContext(ctx).Where("id = ?", id).First(&r).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return models.ProjectSecretItem{}, store.ErrResourceNotFound } @@ -121,10 +122,10 @@ func (repo *secretRepository) GetByID(id uuid.UUID) (models.ProjectSecretItem, e return r.ToSpec(repo.hash) } -func (repo *secretRepository) GetAll() ([]models.ProjectSecretItem, error) { +func (repo *secretRepository) GetAll(ctx context.Context) ([]models.ProjectSecretItem, error) { var specs []models.ProjectSecretItem var resources []Secret - if err := repo.db.Find(&resources).Error; err != nil { + if err := repo.db.WithContext(ctx).Find(&resources).Error; err != nil { return specs, err } for _, res := range resources { diff --git a/store/postgres/secret_repository_test.go b/store/postgres/secret_repository_test.go index 8388302ccb..939a17911b 100644 --- a/store/postgres/secret_repository_test.go +++ b/store/postgres/secret_repository_test.go @@ -3,16 +3,18 @@ package postgres import ( + "context" "os" "testing" "github.com/google/uuid" - "github.com/jinzhu/gorm" "github.com/odpf/optimus/models" "github.com/stretchr/testify/assert" + "gorm.io/gorm" ) func TestSecretRepository(t *testing.T) { + ctx := context.Background() projectSpec := models.ProjectSpec{ ID: uuid.Must(uuid.NewRandom()), Name: "t-optimus-project", @@ -43,7 +45,7 @@ func TestSecretRepository(t *testing.T) { } projRepo := NewProjectRepository(dbConn, hash) - assert.Nil(t, projRepo.Save(projectSpec)) + assert.Nil(t, projRepo.Save(ctx, projectSpec)) return dbConn } @@ -65,102 +67,107 @@ func TestSecretRepository(t *testing.T) { t.Run("Insert", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() testModels := []models.ProjectSecretItem{} testModels = append(testModels, testConfigs...) repo := NewSecretRepository(db, projectSpec, hash) - err := repo.Insert(testModels[0]) + err := repo.Insert(ctx, testModels[0]) assert.Nil(t, err) - err = repo.Insert(testModels[1]) + err = repo.Insert(ctx, testModels[1]) assert.NotNil(t, err) - checkModel, err := repo.GetByID(testModels[0].ID) + checkModel, err := repo.GetByID(ctx, testModels[0].ID) assert.Nil(t, err) assert.Equal(t, "g-optimus", checkModel.Name) }) t.Run("Upsert", func(t *testing.T) { t.Run("insert different resource should insert two", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() testModelA := testConfigs[0] testModelB := testConfigs[2] repo := NewSecretRepository(db, projectSpec, hash) //try for create - err := repo.Save(testModelA) + err := repo.Save(ctx, testModelA) assert.Nil(t, err) - checkModel, err := repo.GetByID(testModelA.ID) + checkModel, err := repo.GetByID(ctx, testModelA.ID) assert.Nil(t, err) assert.Equal(t, "g-optimus", checkModel.Name) //try for update - err = repo.Save(testModelB) + err = repo.Save(ctx, testModelB) assert.Nil(t, err) - checkModel, err = repo.GetByID(testModelB.ID) + checkModel, err = repo.GetByID(ctx, testModelB.ID) assert.Nil(t, err) assert.Equal(t, "t-optimus", checkModel.Name) assert.Equal(t, "super-secret", checkModel.Value) }) t.Run("insert same resource twice should overwrite existing", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() testModelA := testConfigs[2] repo := NewSecretRepository(db, projectSpec, hash) //try for create testModelA.Value = "gs://some_folder" - err := repo.Save(testModelA) + err := repo.Save(ctx, testModelA) assert.Nil(t, err) - checkModel, err := repo.GetByID(testModelA.ID) + checkModel, err := repo.GetByID(ctx, testModelA.ID) assert.Nil(t, err) assert.Equal(t, "t-optimus", checkModel.Name) //try for update testModelA.Value = "gs://another_folder" - err = repo.Save(testModelA) + err = repo.Save(ctx, testModelA) assert.Nil(t, err) - checkModel, err = repo.GetByID(testModelA.ID) + checkModel, err = repo.GetByID(ctx, testModelA.ID) assert.Nil(t, err) assert.Equal(t, "gs://another_folder", checkModel.Value) }) t.Run("upsert without ID should auto generate it", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() testModelA := testConfigs[0] testModelA.ID = uuid.Nil repo := NewSecretRepository(db, projectSpec, hash) //try for create - err := repo.Save(testModelA) + err := repo.Save(ctx, testModelA) assert.Nil(t, err) - checkModel, err := repo.GetByName(testModelA.Name) + checkModel, err := repo.GetByName(ctx, testModelA.Name) assert.Nil(t, err) assert.Equal(t, "g-optimus", checkModel.Name) }) }) t.Run("GetByName", func(t *testing.T) { db := DBSetup() - defer db.Close() + sqlDB, _ := db.DB() + defer sqlDB.Close() testModels := []models.ProjectSecretItem{} testModels = append(testModels, testConfigs...) repo := NewSecretRepository(db, projectSpec, hash) - err := repo.Insert(testModels[0]) + err := repo.Insert(ctx, testModels[0]) assert.Nil(t, err) - checkModel, err := repo.GetByName(testModels[0].Name) + checkModel, err := repo.GetByName(ctx, testModels[0].Name) assert.Nil(t, err) assert.Equal(t, "g-optimus", checkModel.Name) }) diff --git a/store/store.go b/store/store.go index 603939053c..84a80c0306 100644 --- a/store/store.go +++ b/store/store.go @@ -1,6 +1,7 @@ package store import ( + "context" "errors" "time" @@ -15,91 +16,93 @@ var ( // ProjectJobSpecRepository represents a storage interface for Job specifications at a project level type ProjectJobSpecRepository interface { - GetByName(string) (models.JobSpec, models.NamespaceSpec, error) - GetByNameForProject(projectName, jobName string) (models.JobSpec, models.ProjectSpec, error) - GetAll() ([]models.JobSpec, error) - GetByDestination(string) (models.JobSpec, models.ProjectSpec, error) + GetByName(context.Context, string) (models.JobSpec, models.NamespaceSpec, error) + GetByNameForProject(ctx context.Context, projectName, jobName string) (models.JobSpec, models.ProjectSpec, error) + GetAll(context.Context) ([]models.JobSpec, error) + GetByDestination(context.Context, string) (models.JobSpec, models.ProjectSpec, error) } // ProjectRepository represents a storage interface for registered projects type ProjectRepository interface { - Save(models.ProjectSpec) error - GetByName(string) (models.ProjectSpec, error) - GetAll() ([]models.ProjectSpec, error) + Save(context.Context, models.ProjectSpec) error + GetByName(context.Context, string) (models.ProjectSpec, error) + GetAll(context.Context) ([]models.ProjectSpec, error) } // ProjectSecretRepository stores secrets attached to projects type ProjectSecretRepository interface { - Save(item models.ProjectSecretItem) error - GetByName(string) (models.ProjectSecretItem, error) - GetAll() ([]models.ProjectSecretItem, error) + Save(ctx context.Context, item models.ProjectSecretItem) error + GetByName(context.Context, string) (models.ProjectSecretItem, error) + GetAll(context.Context) ([]models.ProjectSecretItem, error) } // NamespaceRepository represents a storage interface for registered namespaces type NamespaceRepository interface { - Save(models.NamespaceSpec) error - GetByName(string) (models.NamespaceSpec, error) - GetAll() ([]models.NamespaceSpec, error) + Save(context.Context, models.NamespaceSpec) error + GetByName(context.Context, string) (models.NamespaceSpec, error) + GetAll(context.Context) ([]models.NamespaceSpec, error) } // JobRunSpecRepository represents a storage interface for Job runs generated to // represent a job in running state type JobRunRepository interface { - // Save updates the instance in place if it can else insert new - Save(models.NamespaceSpec, models.JobRun) error - GetByScheduledAt(jobID uuid.UUID, scheduledAt time.Time) (models.JobRun, models.NamespaceSpec, error) - GetByID(uuid.UUID) (models.JobRun, models.NamespaceSpec, error) - UpdateStatus(uuid.UUID, models.JobRunState) error - GetByStatus(state ...models.JobRunState) ([]models.JobRun, error) - GetByTrigger(trigger models.JobRunTrigger, state ...models.JobRunState) ([]models.JobRun, error) - Delete(uuid.UUID) error + // Save updates the run in place if it can else insert new + // Note: it doesn't insert the instances attached to job run in db + Save(context.Context, models.NamespaceSpec, models.JobRun) error - AddInstance(namespace models.NamespaceSpec, run models.JobRun, spec models.InstanceSpec) error + GetByScheduledAt(ctx context.Context, jobID uuid.UUID, scheduledAt time.Time) (models.JobRun, models.NamespaceSpec, error) + GetByID(context.Context, uuid.UUID) (models.JobRun, models.NamespaceSpec, error) + UpdateStatus(context.Context, uuid.UUID, models.JobRunState) error + GetByStatus(ctx context.Context, state ...models.JobRunState) ([]models.JobRun, error) + GetByTrigger(ctx context.Context, trigger models.JobRunTrigger, state ...models.JobRunState) ([]models.JobRun, error) + Delete(context.Context, uuid.UUID) error + + AddInstance(ctx context.Context, namespace models.NamespaceSpec, run models.JobRun, spec models.InstanceSpec) error // Clear will not delete the record but will reset all the run details // for fresh start - Clear(runID uuid.UUID) error - ClearInstance(runID uuid.UUID, instanceType models.InstanceType, instanceName string) error - ClearInstances(jobID uuid.UUID, scheduled time.Time) error + Clear(ctx context.Context, runID uuid.UUID) error + ClearInstance(ctx context.Context, runID uuid.UUID, instanceType models.InstanceType, instanceName string) error } // JobRunSpecRepository represents a storage interface for Job run instances created // during execution type InstanceRepository interface { - Save(run models.JobRun, spec models.InstanceSpec) error - UpdateStatus(id uuid.UUID, status models.JobRunState) error - GetByName(runID uuid.UUID, instanceName, instanceType string) (models.InstanceSpec, error) - Delete(id uuid.UUID) error + Save(ctx context.Context, run models.JobRun, spec models.InstanceSpec) error + UpdateStatus(ctx context.Context, id uuid.UUID, status models.JobRunState) error + GetByName(ctx context.Context, runID uuid.UUID, instanceName, instanceType string) (models.InstanceSpec, error) + + DeleteByJobRun(ctx context.Context, id uuid.UUID) error } // ProjectResourceSpecRepository represents a storage interface for Resource specifications at project level type ProjectResourceSpecRepository interface { - GetByName(string) (models.ResourceSpec, models.NamespaceSpec, error) - GetAll() ([]models.ResourceSpec, error) + GetByName(context.Context, string) (models.ResourceSpec, models.NamespaceSpec, error) + GetAll(context.Context) ([]models.ResourceSpec, error) } // ResourceSpecRepository represents a storage interface for Resource specifications at namespace level type ResourceSpecRepository interface { - Save(models.ResourceSpec) error - GetByName(string) (models.ResourceSpec, error) - GetByURN(string) (models.ResourceSpec, error) - GetAll() ([]models.ResourceSpec, error) - Delete(string) error + Save(context.Context, models.ResourceSpec) error + GetByName(context.Context, string) (models.ResourceSpec, error) + GetByURN(context.Context, string) (models.ResourceSpec, error) + GetAll(context.Context) ([]models.ResourceSpec, error) + Delete(context.Context, string) error } // ReplaySpecRepository represents a storage interface for replay objects type ReplaySpecRepository interface { - Insert(replay *models.ReplaySpec) error - GetByID(id uuid.UUID) (models.ReplaySpec, error) - UpdateStatus(replayID uuid.UUID, status string, message models.ReplayMessage) error - GetByStatus(status []string) ([]models.ReplaySpec, error) - GetByJobIDAndStatus(jobID uuid.UUID, status []string) ([]models.ReplaySpec, error) - GetByProjectIDAndStatus(projectID uuid.UUID, status []string) ([]models.ReplaySpec, error) - GetByProjectID(projectID uuid.UUID) ([]models.ReplaySpec, error) + Insert(ctx context.Context, replay *models.ReplaySpec) error + GetByID(ctx context.Context, id uuid.UUID) (models.ReplaySpec, error) + UpdateStatus(ctx context.Context, replayID uuid.UUID, status string, message models.ReplayMessage) error + GetByStatus(ctx context.Context, status []string) ([]models.ReplaySpec, error) + GetByJobIDAndStatus(ctx context.Context, jobID uuid.UUID, status []string) ([]models.ReplaySpec, error) + GetByProjectIDAndStatus(ctx context.Context, projectID uuid.UUID, status []string) ([]models.ReplaySpec, error) + GetByProjectID(ctx context.Context, projectID uuid.UUID) ([]models.ReplaySpec, error) } // BackupRepository represents a storage interface for backup objects type BackupRepository interface { - Save(spec models.BackupSpec) error - GetAll() ([]models.BackupSpec, error) + Save(ctx context.Context, spec models.BackupSpec) error + GetAll(context.Context) ([]models.BackupSpec, error) }