diff --git a/go.mod b/go.mod index 61f75d37..53bc6ee0 100644 --- a/go.mod +++ b/go.mod @@ -121,7 +121,7 @@ require ( go.uber.org/multierr v1.11.0 // indirect golang.org/x/crypto v0.26.0 // indirect golang.org/x/sys v0.24.0 // indirect - golang.org/x/text v0.17.0 // indirect + golang.org/x/text v0.17.0 golang.org/x/time v0.6.0 // indirect google.golang.org/genproto v0.0.0-20240808171019-573a1156607a // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240808171019-573a1156607a diff --git a/pkg/repository/transpiler.go b/pkg/repository/transpiler.go index 263f9ae4..133630bc 100644 --- a/pkg/repository/transpiler.go +++ b/pkg/repository/transpiler.go @@ -2,6 +2,7 @@ package repository import ( "fmt" + "strings" "time" "github.com/iancoleman/strcase" @@ -196,7 +197,7 @@ func (t *Transpiler) transpileComparisonCallExpr(e *expr.Expr, op any) (*clause. vars = append(vars, con.Vars[0], fmt.Sprintf("%%%s%%", con.Vars[0])) case "tag": sql = "model_tag.tag_name = ?" - vars = append(vars, con.Vars...) + vars = append(vars, strings.ToLower(con.Vars[0].(string))) default: sql = fmt.Sprintf("%s = ?", ident.SQL) vars = append(vars, con.Vars...) diff --git a/pkg/service/convert.go b/pkg/service/convert.go index 185d731c..566c78a5 100644 --- a/pkg/service/convert.go +++ b/pkg/service/convert.go @@ -10,12 +10,15 @@ import ( "image" "image/jpeg" "image/png" + "slices" "strings" "time" "github.com/gabriel-vasile/mimetype" "github.com/gofrs/uuid" "golang.org/x/image/draw" + "golang.org/x/text/cases" + "golang.org/x/text/language" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/encoding/protojson" @@ -189,8 +192,18 @@ func (s *service) DBToPBModel(ctx context.Context, modelDef *datamodel.ModelDefi License: &dbModel.License.String, Readme: &dbModel.Readme.String, ProfileImage: &profileImage, - Tags: dbModel.TagNames(), - Versions: dbModel.VersionNames(), + Tags: func() []string { + tagNames := make([]string, len(dbModel.TagNames())) + for i, tag := range dbModel.TagNames() { + if slices.Contains(preserveTags, tag) { + tagNames[i] = cases.Title(language.English).String(tag) + } else { + tagNames[i] = tag + } + } + return tagNames + }(), + Versions: dbModel.VersionNames(), Stats: &modelpb.Model_Stats{ NumberOfRuns: int32(dbModel.NumberOfRuns), LastRunTime: timestamppb.New(dbModel.LastRunTime), diff --git a/pkg/service/service.go b/pkg/service/service.go index 24167131..149ec16e 100644 --- a/pkg/service/service.go +++ b/pkg/service/service.go @@ -393,6 +393,21 @@ func (s *service) CreateNamespaceModel(ctx context.Context, ns resource.Namespac if err := s.aclClient.SetOwner(ctx, "model_", dbCreatedModel.UID, ownerType, ownerUID); err != nil { return err } + toCreatedTags := model.GetTags() + toBeCreatedTagNames := make([]string, 0, len(toCreatedTags)) + for _, tag := range toCreatedTags { + tag = strings.ToLower(tag) + if !slices.Contains(preserveTags, tag) { + toBeCreatedTagNames = append(toBeCreatedTagNames, tag) + } + } + + if len(toBeCreatedTagNames) > 0 { + err = s.repository.CreateModelTags(ctx, dbCreatedModel.UID, toBeCreatedTagNames) + if err != nil { + return err + } + } if dbCreatedModel.Visibility == datamodel.ModelVisibility(modelpb.Model_VISIBILITY_PUBLIC) { if err := s.aclClient.SetPublicModelPermission(ctx, dbCreatedModel.UID); err != nil { @@ -1172,8 +1187,14 @@ func (s *service) UpdateNamespaceModelByID(ctx context.Context, ns resource.Name } toUpdTags := toUpdateModel.GetTags() - + for i := range toUpdTags { + toUpdTags[i] = strings.ToLower(toUpdTags[i]) + } currentTags := dbModel.TagNames() + for i := range currentTags { + currentTags[i] = strings.ToLower(currentTags[i]) + } + toBeCreatedTagNames := make([]string, 0, len(toUpdTags)) for _, tag := range toUpdTags { if !slices.Contains(currentTags, tag) && !slices.Contains(preserveTags, tag) {