Skip to content

Commit

Permalink
fix: update oss client caching logic
Browse files Browse the repository at this point in the history
  • Loading branch information
Ayushi Sharma committed Dec 12, 2024
1 parent b9b3857 commit 8e443b9
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 22 deletions.
2 changes: 2 additions & 0 deletions core/provider/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,6 @@ var (
ErrAppealValidationInvalidDurationValue = errors.New("invalid duration value")
ErrAppealValidationMissingRequiredParameter = errors.New("missing required parameter")
ErrAppealValidationMissingRequiredQuestion = errors.New("missing required question")

ErrGrantAlreadyExists = errors.New("grant already exists")
)
77 changes: 55 additions & 22 deletions plugins/providers/oss/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,17 @@ type Policy struct {
Statement []PolicyStatement `json:"Statement"`
}

type OSSClient struct {
client *oss.Client
stsClientExist bool
}
type provider struct {
pv.UnimplementedClient
pv.PermissionManager
typeName string
encryptor encryptor

ossClients map[string]*oss.Client
ossClients map[string]OSSClient
sts *sts.Sts

mu sync.Mutex
Expand All @@ -51,7 +55,7 @@ func NewProvider(typeName string, encryptor encryptor) *provider {
return &provider{
typeName: typeName,
encryptor: encryptor,
ossClients: make(map[string]*oss.Client),
ossClients: make(map[string]OSSClient),
sts: sts.NewSTS(),
}
}
Expand Down Expand Up @@ -151,6 +155,9 @@ func (p *provider) GrantAccess(ctx context.Context, pc *domain.ProviderConfig, g

bucketPolicy, err := updatePolicyToGrantPermissions(existingPolicy, g)
if err != nil {
if errors.Is(err, pv.ErrGrantAlreadyExists) {
return nil
}
return err
}

Expand Down Expand Up @@ -324,7 +331,7 @@ func updatePolicyToGrantPermissions(policy string, g domain.Grant) (string, erro
foundStatementToUpdate := false
for _, statement := range matchingStatements {
if slices.Contains(statement.Principal, principalAccountID) {
return "", fmt.Errorf("access already granted for role: %s", g.Role)
return "", pv.ErrGrantAlreadyExists
}

if !foundStatementToUpdate {
Expand Down Expand Up @@ -375,12 +382,6 @@ func (p *provider) getCreds(pc *domain.ProviderConfig) (*Credentials, error) {
}

func (p *provider) getOSSClient(pc *domain.ProviderConfig, ramRole string) (*oss.Client, error) {
if existingClient, ok := p.ossClients[ramRole]; ok {
if p.sts.IsSTSTokenValid(ramRole) {
return existingClient, nil
}
}

creds, err := p.getCreds(pc)
if err != nil {
return nil, err
Expand All @@ -390,29 +391,61 @@ func (p *provider) getOSSClient(pc *domain.ProviderConfig, ramRole string) (*oss
ramRole = creds.RAMRole
}

stsClient, err := p.sts.GetSTSClient(ramRole, creds.AccessKeyID, creds.AccessKeySecret, creds.RegionID)
if err != nil {
return nil, err
stsClientID := "oss-" + ramRole
if ossClient, ok := p.getCachedOSSClient(ramRole, stsClientID, pc.URN); ok {
return ossClient, nil
}

clientConfig, err := sts.AssumeRole(stsClient, creds.AccessKeyID, ramRole, pc.URN)
if err != nil {
return nil, err
}

clientOpts := oss.SecurityToken(*clientConfig.SecurityToken)
endpoint := fmt.Sprintf("https://oss-%s.aliyuncs.com", creds.RegionID)
client, err := oss.New(endpoint, *clientConfig.AccessKeyId, *clientConfig.AccessKeySecret, clientOpts)
if err != nil {
return nil, fmt.Errorf("failed to initialize oss client: %w", err)
var client *oss.Client
if ramRole != "" {
stsClient, err := p.sts.GetSTSClient(stsClientID, creds.AccessKeyID, creds.AccessKeySecret, creds.RegionID)
if err != nil {
return nil, err
}

clientConfig, err := sts.AssumeRole(stsClient, creds.RAMRole, pc.URN, creds.RegionID)
if err != nil {
return nil, err
}

clientOpts := oss.SecurityToken(*clientConfig.SecurityToken)
client, err = oss.New(endpoint, *clientConfig.AccessKeyId, *clientConfig.AccessKeySecret, clientOpts)
if err != nil {
return nil, fmt.Errorf("failed to initialize oss client: %w", err)
}
} else {
client, err = oss.New(endpoint, creds.AccessKeyID, creds.AccessKeySecret)
if err != nil {
return nil, fmt.Errorf("failed to initialize oss client: %w", err)
}
}

p.mu.Lock()
p.ossClients[ramRole] = client
if ramRole != "" {
p.ossClients[ramRole] = OSSClient{client: client, stsClientExist: true}
} else {
p.ossClients[pc.URN] = OSSClient{client: client}
}
p.mu.Unlock()
return client, nil
}

func (p *provider) getCachedOSSClient(ramRole, stsClientID, urn string) (*oss.Client, bool) {
if c, ok := p.ossClients[ramRole]; ok {
if c.stsClientExist && p.sts.IsSTSTokenValid(stsClientID) {
return c.client, true
}
return c.client, true
}

if c, ok := p.ossClients[urn]; ok {
return c.client, true
}

return nil, false
}

func getRAMRole(g domain.Grant) (string, error) {
resourceAccountID, err := getAccountIDFromResource(g.Resource)
if err != nil {
Expand Down

0 comments on commit 8e443b9

Please sign in to comment.