diff --git a/core/provider/errors.go b/core/provider/errors.go index 3ed442206..9371dee00 100644 --- a/core/provider/errors.go +++ b/core/provider/errors.go @@ -27,4 +27,6 @@ var ( ErrAppealValidationInvalidDurationValue = errors.New("invalid duration value") ErrAppealValidationMissingRequiredParameter = errors.New("missing required parameter") ErrAppealValidationMissingRequiredQuestion = errors.New("missing required question") + + ErrGrantAlreadyExists = errors.New("grant already exists") ) diff --git a/plugins/providers/oss/provider.go b/plugins/providers/oss/provider.go index 3ec37c5ad..80f49c0e0 100644 --- a/plugins/providers/oss/provider.go +++ b/plugins/providers/oss/provider.go @@ -35,13 +35,17 @@ type Policy struct { Statement []PolicyStatement `json:"Statement"` } +type OSSClient struct { + client *oss.Client + stsClientExist bool +} type provider struct { pv.UnimplementedClient pv.PermissionManager typeName string encryptor encryptor - ossClients map[string]*oss.Client + ossClients map[string]OSSClient sts *sts.Sts mu sync.Mutex @@ -51,7 +55,7 @@ func NewProvider(typeName string, encryptor encryptor) *provider { return &provider{ typeName: typeName, encryptor: encryptor, - ossClients: make(map[string]*oss.Client), + ossClients: make(map[string]OSSClient), sts: sts.NewSTS(), } } @@ -151,6 +155,9 @@ func (p *provider) GrantAccess(ctx context.Context, pc *domain.ProviderConfig, g bucketPolicy, err := updatePolicyToGrantPermissions(existingPolicy, g) if err != nil { + if errors.Is(err, pv.ErrGrantAlreadyExists) { + return nil + } return err } @@ -324,7 +331,7 @@ func updatePolicyToGrantPermissions(policy string, g domain.Grant) (string, erro foundStatementToUpdate := false for _, statement := range matchingStatements { if slices.Contains(statement.Principal, principalAccountID) { - return "", fmt.Errorf("access already granted for role: %s", g.Role) + return "", pv.ErrGrantAlreadyExists } if !foundStatementToUpdate { @@ -375,12 +382,6 @@ func (p *provider) getCreds(pc *domain.ProviderConfig) (*Credentials, error) { } func (p *provider) getOSSClient(pc *domain.ProviderConfig, ramRole string) (*oss.Client, error) { - if existingClient, ok := p.ossClients[ramRole]; ok { - if p.sts.IsSTSTokenValid(ramRole) { - return existingClient, nil - } - } - creds, err := p.getCreds(pc) if err != nil { return nil, err @@ -390,29 +391,61 @@ func (p *provider) getOSSClient(pc *domain.ProviderConfig, ramRole string) (*oss ramRole = creds.RAMRole } - stsClient, err := p.sts.GetSTSClient(ramRole, creds.AccessKeyID, creds.AccessKeySecret, creds.RegionID) - if err != nil { - return nil, err + stsClientID := "oss-" + ramRole + if ossClient, ok := p.getCachedOSSClient(ramRole, stsClientID, pc.URN); ok { + return ossClient, nil } - clientConfig, err := sts.AssumeRole(stsClient, creds.AccessKeyID, ramRole, pc.URN) - if err != nil { - return nil, err - } - - clientOpts := oss.SecurityToken(*clientConfig.SecurityToken) endpoint := fmt.Sprintf("https://oss-%s.aliyuncs.com", creds.RegionID) - client, err := oss.New(endpoint, *clientConfig.AccessKeyId, *clientConfig.AccessKeySecret, clientOpts) - if err != nil { - return nil, fmt.Errorf("failed to initialize oss client: %w", err) + var client *oss.Client + if ramRole != "" { + stsClient, err := p.sts.GetSTSClient(stsClientID, creds.AccessKeyID, creds.AccessKeySecret, creds.RegionID) + if err != nil { + return nil, err + } + + clientConfig, err := sts.AssumeRole(stsClient, creds.RAMRole, pc.URN, creds.RegionID) + if err != nil { + return nil, err + } + + clientOpts := oss.SecurityToken(*clientConfig.SecurityToken) + client, err = oss.New(endpoint, *clientConfig.AccessKeyId, *clientConfig.AccessKeySecret, clientOpts) + if err != nil { + return nil, fmt.Errorf("failed to initialize oss client: %w", err) + } + } else { + client, err = oss.New(endpoint, creds.AccessKeyID, creds.AccessKeySecret) + if err != nil { + return nil, fmt.Errorf("failed to initialize oss client: %w", err) + } } p.mu.Lock() - p.ossClients[ramRole] = client + if ramRole != "" { + p.ossClients[ramRole] = OSSClient{client: client, stsClientExist: true} + } else { + p.ossClients[pc.URN] = OSSClient{client: client} + } p.mu.Unlock() return client, nil } +func (p *provider) getCachedOSSClient(ramRole, stsClientID, urn string) (*oss.Client, bool) { + if c, ok := p.ossClients[ramRole]; ok { + if c.stsClientExist && p.sts.IsSTSTokenValid(stsClientID) { + return c.client, true + } + return c.client, true + } + + if c, ok := p.ossClients[urn]; ok { + return c.client, true + } + + return nil, false +} + func getRAMRole(g domain.Grant) (string, error) { resourceAccountID, err := getAccountIDFromResource(g.Resource) if err != nil {