diff --git a/cmd/main/main.go b/cmd/main/main.go index a9f79c6e..b0056b2e 100644 --- a/cmd/main/main.go +++ b/cmd/main/main.go @@ -232,13 +232,13 @@ func main() { defer timeseries.Close() // Initialize Minio client - minioClient, err := miniox.NewMinioClientAndInitBucket(ctx, &config.Config.Minio, logger, config.MetadataExpiryRules...) + minioClient, err := miniox.NewMinioClientAndInitBucket(ctx, &config.Config.Minio, logger, service.MetadataExpiryRules...) if err != nil { logger.Fatal("failed to create minio client", zap.Error(err)) } serv := service.NewService(repo, timeseries.WriteAPI(), mgmtPublicServiceClient, mgmtPrivateServiceClient, - artifactPrivateServiceClient, redisClient, temporalClient, rayService, &aclClient, minioClient, + artifactPrivateServiceClient, redisClient, temporalClient, rayService, &aclClient, minioClient, nil, config.Config.Server.InstillCoreHost) modelpb.RegisterModelPublicServiceServer( diff --git a/config/config.go b/config/config.go index bd80626d..b77d1cbe 100644 --- a/config/config.go +++ b/config/config.go @@ -232,14 +232,3 @@ func ParseConfigFlag() string { return *configPath } - -const ( - DefaultExpiryTag = "default-expiry" -) - -var MetadataExpiryRules = []miniox.ExpiryRule{ - { - Tag: DefaultExpiryTag, - ExpirationDays: 3, - }, -} diff --git a/pkg/service/metadataretention.go b/pkg/service/metadataretention.go new file mode 100644 index 00000000..984ed23e --- /dev/null +++ b/pkg/service/metadataretention.go @@ -0,0 +1,34 @@ +package service + +import ( + "context" + + "github.com/gofrs/uuid" + + miniox "github.com/instill-ai/x/minio" +) + +type MetadataRetentionHandler interface { + GetExpiryTagBySubscriptionPlan(ctx context.Context, requesterUID uuid.UUID) (string, error) +} + +type metadataRetentionHandler struct{} + +func NewRetentionHandler() MetadataRetentionHandler { + return &metadataRetentionHandler{} +} + +func (h metadataRetentionHandler) GetExpiryTagBySubscriptionPlan(ctx context.Context, requesterUID uuid.UUID) (string, error) { + return defaultExpiryTag, nil +} + +const ( + defaultExpiryTag = "default-expiry" +) + +var MetadataExpiryRules = []miniox.ExpiryRule{ + { + Tag: defaultExpiryTag, + ExpirationDays: 3, + }, +} diff --git a/pkg/service/service.go b/pkg/service/service.go index 500ede5e..59092f4d 100644 --- a/pkg/service/service.go +++ b/pkg/service/service.go @@ -104,8 +104,6 @@ type Service interface { UpdateModelRunWithError(ctx context.Context, runLog *datamodel.ModelRun, err error) *datamodel.ModelRun ListModelRuns(ctx context.Context, req *modelpb.ListModelRunsRequest, filter filtering.Filter) (*modelpb.ListModelRunsResponse, error) ListModelRunsByRequester(ctx context.Context, req *modelpb.ListModelRunsByRequesterRequest) (*modelpb.ListModelRunsByRequesterResponse, error) - - GetExpiryTagBySubscriptionPlan(context.Context, uuid.UUID) (string, error) } type service struct { @@ -119,6 +117,7 @@ type service struct { ray ray.Ray aclClient acl.ACLClientInterface minioClient miniox.MinioI + retentionHandler MetadataRetentionHandler instillCoreHost string } @@ -134,7 +133,11 @@ func NewService( ra ray.Ray, a acl.ACLClientInterface, minioClient miniox.MinioI, + retentionHandler MetadataRetentionHandler, h string) Service { + if retentionHandler == nil { + retentionHandler = NewRetentionHandler() + } return &service{ repository: r, influxDBWriteClient: i, @@ -146,6 +149,7 @@ func NewService( temporalClient: tc, aclClient: a, minioClient: minioClient, + retentionHandler: retentionHandler, instillCoreHost: h, } } @@ -190,7 +194,7 @@ func (s *service) CreateModelRun(ctx context.Context, triggerUID uuid.UUID, mode requesterUID, userUID := resourcex.GetRequesterUIDAndUserUID(ctx) requesterUUID := uuid.FromStringOrNil(requesterUID) - expiryRuleTag, err := s.GetExpiryTagBySubscriptionPlan(ctx, requesterUUID) + expiryRuleTag, err := s.retentionHandler.GetExpiryTagBySubscriptionPlan(ctx, requesterUUID) if err != nil { return nil, err } @@ -553,7 +557,7 @@ func (s *service) TriggerNamespaceModelByID(ctx context.Context, ns resource.Nam }, } - expiryRuleTag, err := s.GetExpiryTagBySubscriptionPlan(ctx, runLog.RequesterUID) + expiryRuleTag, err := s.retentionHandler.GetExpiryTagBySubscriptionPlan(ctx, runLog.RequesterUID) if err != nil { return nil, err } @@ -666,7 +670,7 @@ func (s *service) TriggerAsyncNamespaceModelByID(ctx context.Context, ns resourc userUID := uuid.FromStringOrNil(resource.GetRequestSingleHeader(ctx, constant.HeaderUserUIDKey)) - expiryRuleTag, err := s.GetExpiryTagBySubscriptionPlan(ctx, runLog.RequesterUID) + expiryRuleTag, err := s.retentionHandler.GetExpiryTagBySubscriptionPlan(ctx, runLog.RequesterUID) if err != nil { return nil, err } @@ -1370,7 +1374,3 @@ func (s *service) GetModelVersionAdmin(ctx context.Context, modelUID uuid.UUID, func (s *service) CreateModelVersionAdmin(ctx context.Context, version *datamodel.ModelVersion) error { return s.repository.CreateModelVersion(ctx, "", version) } - -func (s *service) GetExpiryTagBySubscriptionPlan(context.Context, uuid.UUID) (string, error) { - return config.DefaultExpiryTag, nil -} diff --git a/pkg/service/service_test.go b/pkg/service/service_test.go index fc26e991..ceac1df6 100644 --- a/pkg/service/service_test.go +++ b/pkg/service/service_test.go @@ -446,7 +446,7 @@ func TestGetModelDefinition(t *testing.T) { t.Run("TestGetModelDefinition", func(t *testing.T) { mockRepository := mock.NewRepositoryMock(mc) mockRepository.GetModelDefinitionMock.Times(1).Expect("github").Return(&datamodel.ModelDefinition{}, nil) - s := service.NewService(mockRepository, nil, nil, nil, nil, nil, nil, nil, nil, nil, "") + s := service.NewService(mockRepository, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, "") _, err := s.GetModelDefinition(context.Background(), "github") assert.NoError(t, err) @@ -455,7 +455,7 @@ func TestGetModelDefinition(t *testing.T) { t.Run("GetModelDefinitionByUID", func(t *testing.T) { mockRepository := mock.NewRepositoryMock(mc) mockRepository.GetModelDefinitionByUIDMock.Times(1).Expect(ModelDefinition).Return(&datamodel.ModelDefinition{}, nil) - s := service.NewService(mockRepository, nil, nil, nil, nil, nil, nil, nil, nil, nil, "") + s := service.NewService(mockRepository, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, "") _, err := s.GetModelDefinitionByUID(context.Background(), ModelDefinition) assert.NoError(t, err) @@ -469,7 +469,7 @@ func TestListModelDefinitions(t *testing.T) { mockRepository := mock.NewRepositoryMock(mc) mockRepository.ListModelDefinitionsMock.Times(1).Expect(modelPB.View_VIEW_FULL, 100, ""). Return([]*datamodel.ModelDefinition{}, "", 100, nil) - s := service.NewService(mockRepository, nil, nil, nil, nil, nil, nil, nil, nil, nil, "") + s := service.NewService(mockRepository, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, "") _, _, _, err := s.ListModelDefinitions(context.Background(), modelPB.View_VIEW_FULL, int32(100), "") assert.NoError(t, err)