Skip to content

Commit

Permalink
[ALS-7539] - Get rid of super complex self refreshing client
Browse files Browse the repository at this point in the history
  • Loading branch information
Luke Sikina authored and Luke-Sikina committed Nov 28, 2024
1 parent 0b7e349 commit daf9d4a
Show file tree
Hide file tree
Showing 10 changed files with 277 additions and 177 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package edu.harvard.dbmi.avillach.dataupload.aws;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Profile;
import org.springframework.stereotype.Service;
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.http.SdkHttpClient;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.S3ClientBuilder;
import software.amazon.awssdk.services.sts.model.AssumeRoleRequest;
import software.amazon.awssdk.services.sts.model.AssumeRoleResponse;
import software.amazon.awssdk.services.sts.model.Credentials;

import java.util.Map;
import java.util.Optional;

@Profile("!dev")
@Service
public class AWSClientBuilder {

private static final Logger log = LoggerFactory.getLogger(AWSClientBuilder.class);

private final Map<String, SiteAWSInfo> sites;
private final StsClientProvider stsClientProvider;
private final S3ClientBuilder s3ClientBuilder;
private final SdkHttpClient sdkHttpClient;

@Autowired
public AWSClientBuilder(
Map<String, SiteAWSInfo> sites,
StsClientProvider stsClientProvider,
S3ClientBuilder s3ClientBuilder,
@Autowired(required = false) SdkHttpClient sdkHttpClient
) {
this.sites = sites;
this.stsClientProvider = stsClientProvider;
this.s3ClientBuilder = s3ClientBuilder;
this.sdkHttpClient = sdkHttpClient;
}

public Optional<S3Client> buildClientForSite(String siteName) {
log.info("Building client for site {}", siteName);
if (!sites.containsKey(siteName)) {
log.warn("Could not find site {}", siteName);
return Optional.empty();
}

log.info("Found site, making assume role request");
SiteAWSInfo site = sites.get(siteName);
AssumeRoleRequest roleRequest = AssumeRoleRequest.builder()
.roleArn(site.roleARN())
.roleSessionName("test_session" + System.nanoTime())
.externalId(site.externalId())
.durationSeconds(60*60) // 1 hour
.build();
Optional<Credentials> assumeRoleResponse = stsClientProvider.createClient()
.map(c -> c.assumeRole(roleRequest))
.map(AssumeRoleResponse::credentials);
if (assumeRoleResponse.isEmpty() ) {
log.error("Error assuming role {} , no credentials returned", site.roleARN());
return Optional.empty();
}
log.info("Successfully assumed role {} for site {}", site.roleARN(), site.siteName());

log.info("Building S3 client for site {}", site.siteName());
// Use the credentials from the role to create the S3 client
Credentials credentials = assumeRoleResponse.get();
AwsSessionCredentials sessionCredentials = AwsSessionCredentials.builder()
.accessKeyId(credentials.accessKeyId())
.secretAccessKey(credentials.secretAccessKey())
.sessionToken(credentials.sessionToken())
.expirationTime(credentials.expiration())
.build();
StaticCredentialsProvider provider = StaticCredentialsProvider.create(sessionCredentials);
return Optional.of(buildFromProvider(provider));
}

private S3Client buildFromProvider(StaticCredentialsProvider provider) {
if (sdkHttpClient == null) {
return s3ClientBuilder.credentialsProvider(provider).build();
}
log.info("Http proxy detected and added to S3 client");
return s3ClientBuilder
.credentialsProvider(provider)
.httpClient(sdkHttpClient)
.build();

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@
import org.springframework.context.annotation.Configuration;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.util.StringUtils;
import org.springframework.web.context.annotation.RequestScope;
import software.amazon.awssdk.auth.credentials.*;
import software.amazon.awssdk.http.SdkHttpClient;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.S3ClientBuilder;
import software.amazon.awssdk.services.sts.StsClient;
import software.amazon.awssdk.services.sts.StsClientBuilder;
import software.amazon.encryption.s3.S3EncryptionClient;
Expand Down Expand Up @@ -82,4 +85,15 @@ StsClientBuilder stsClientBuilder() {
// This is a bean for mocking purposes
return StsClient.builder();
}

@Bean
S3ClientBuilder s3ClientBuilder() {
return S3Client.builder();
}

@Bean
@RequestScope
StsClient getStsClient() {
return StsClient.builder().region(Region.US_EAST_1).build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public class S3StateVerifier {
private Map<String, SiteAWSInfo> sites;

@Autowired
private SelfRefreshingS3Client client;
private AWSClientBuilder clientBuilder;

@PostConstruct
private void verifyS3Status() {
Expand All @@ -39,7 +39,7 @@ private void verifyS3Status() {
private void asyncVerify(SiteAWSInfo institution) {
LOG.info("Checking S3 connection to {} ...", institution.siteName());
createTempFileWithText(institution)
.map(p -> uploadFileFromPath(p, institution))
.flatMap(p -> uploadFileFromPath(p, institution))
.map(this::waitABit)
.flatMap(s1 -> deleteFileFromBucket(s1, institution))
.orElseThrow();
Expand All @@ -49,8 +49,10 @@ private void asyncVerify(SiteAWSInfo institution) {
private Optional<String> deleteFileFromBucket(String s, SiteAWSInfo info) {
LOG.info("Verifying delete capabilities");
DeleteObjectRequest request = DeleteObjectRequest.builder().bucket(info.bucket()).key(s).build();
DeleteObjectResponse deleteObjectResponse = client.getS3Client(info.siteName()).deleteObject(request);
return deleteObjectResponse.deleteMarker() ? Optional.of(s) : Optional.empty();
return clientBuilder.buildClientForSite(info.siteName())
.map(c -> c.deleteObject(request))
.map(DeleteObjectResponse::deleteMarker)
.map((ignored) -> s);
}

private String waitABit(String s) {
Expand All @@ -62,7 +64,7 @@ private String waitABit(String s) {
return s;
}

private String uploadFileFromPath(Path p, SiteAWSInfo info) {
private Optional<String> uploadFileFromPath(Path p, SiteAWSInfo info) {
LOG.info("Verifying upload capabilities");
RequestBody body = RequestBody.fromFile(p.toFile());
PutObjectRequest request = PutObjectRequest.builder()
Expand All @@ -71,8 +73,9 @@ private String uploadFileFromPath(Path p, SiteAWSInfo info) {
.ssekmsKeyId(info.kmsKeyID())
.key(p.getFileName().toString())
.build();
client.getS3Client(info.siteName()).putObject(request, body);
return p.getFileName().toString();
return clientBuilder.buildClientForSite(info.siteName())
.map(client -> client.putObject(request, body))
.map(resp -> p.getFileName().toString());
}

private Optional<Path> createTempFileWithText(SiteAWSInfo info) {
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package edu.harvard.dbmi.avillach.dataupload.aws;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Service;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.sts.StsClient;

import java.util.Optional;

@Service
public class StsClientProvider {

private static final Logger log = LoggerFactory.getLogger(StsClientProvider.class);

public Optional<StsClient> createClient() {
StsClient client = StsClient.builder().region(Region.US_EAST_1).build();
return Optional.of(client);
}
}
Loading

0 comments on commit daf9d4a

Please sign in to comment.