Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ALS-7539] - Get rid of super complex self refreshing client #44

Merged
merged 1 commit into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package edu.harvard.dbmi.avillach.dataupload.aws;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Profile;
import org.springframework.stereotype.Service;
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.http.SdkHttpClient;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.S3ClientBuilder;
import software.amazon.awssdk.services.sts.model.AssumeRoleRequest;
import software.amazon.awssdk.services.sts.model.AssumeRoleResponse;
import software.amazon.awssdk.services.sts.model.Credentials;

import java.util.Map;
import java.util.Optional;

@Profile("!dev")
@Service
public class AWSClientBuilder {

private static final Logger log = LoggerFactory.getLogger(AWSClientBuilder.class);

private final Map<String, SiteAWSInfo> sites;
private final StsClientProvider stsClientProvider;
private final S3ClientBuilder s3ClientBuilder;
private final SdkHttpClient sdkHttpClient;

@Autowired
public AWSClientBuilder(
Map<String, SiteAWSInfo> sites,
StsClientProvider stsClientProvider,
S3ClientBuilder s3ClientBuilder,
@Autowired(required = false) SdkHttpClient sdkHttpClient
) {
this.sites = sites;
this.stsClientProvider = stsClientProvider;
this.s3ClientBuilder = s3ClientBuilder;
this.sdkHttpClient = sdkHttpClient;
}

public Optional<S3Client> buildClientForSite(String siteName) {
log.info("Building client for site {}", siteName);
if (!sites.containsKey(siteName)) {
log.warn("Could not find site {}", siteName);
return Optional.empty();
}

log.info("Found site, making assume role request");
SiteAWSInfo site = sites.get(siteName);
AssumeRoleRequest roleRequest = AssumeRoleRequest.builder()
.roleArn(site.roleARN())
.roleSessionName("test_session" + System.nanoTime())
.externalId(site.externalId())
.durationSeconds(60*60) // 1 hour
.build();
Optional<Credentials> assumeRoleResponse = stsClientProvider.createClient()
.map(c -> c.assumeRole(roleRequest))
.map(AssumeRoleResponse::credentials);
if (assumeRoleResponse.isEmpty() ) {
log.error("Error assuming role {} , no credentials returned", site.roleARN());
return Optional.empty();
}
log.info("Successfully assumed role {} for site {}", site.roleARN(), site.siteName());

log.info("Building S3 client for site {}", site.siteName());
// Use the credentials from the role to create the S3 client
Credentials credentials = assumeRoleResponse.get();
AwsSessionCredentials sessionCredentials = AwsSessionCredentials.builder()
.accessKeyId(credentials.accessKeyId())
.secretAccessKey(credentials.secretAccessKey())
.sessionToken(credentials.sessionToken())
.expirationTime(credentials.expiration())
.build();
StaticCredentialsProvider provider = StaticCredentialsProvider.create(sessionCredentials);
return Optional.of(buildFromProvider(provider));
}

private S3Client buildFromProvider(StaticCredentialsProvider provider) {
if (sdkHttpClient == null) {
return s3ClientBuilder.credentialsProvider(provider).build();
}
log.info("Http proxy detected and added to S3 client");
return s3ClientBuilder
.credentialsProvider(provider)
.httpClient(sdkHttpClient)
.build();

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@
import org.springframework.context.annotation.Configuration;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.util.StringUtils;
import org.springframework.web.context.annotation.RequestScope;
import software.amazon.awssdk.auth.credentials.*;
import software.amazon.awssdk.http.SdkHttpClient;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.S3ClientBuilder;
import software.amazon.awssdk.services.sts.StsClient;
import software.amazon.awssdk.services.sts.StsClientBuilder;
import software.amazon.encryption.s3.S3EncryptionClient;
Expand Down Expand Up @@ -82,4 +85,15 @@ StsClientBuilder stsClientBuilder() {
// This is a bean for mocking purposes
return StsClient.builder();
}

@Bean
S3ClientBuilder s3ClientBuilder() {
return S3Client.builder();
}

@Bean
@RequestScope
StsClient getStsClient() {
return StsClient.builder().region(Region.US_EAST_1).build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public class S3StateVerifier {
private Map<String, SiteAWSInfo> sites;

@Autowired
private SelfRefreshingS3Client client;
private AWSClientBuilder clientBuilder;

@PostConstruct
private void verifyS3Status() {
Expand All @@ -39,7 +39,7 @@ private void verifyS3Status() {
private void asyncVerify(SiteAWSInfo institution) {
LOG.info("Checking S3 connection to {} ...", institution.siteName());
createTempFileWithText(institution)
.map(p -> uploadFileFromPath(p, institution))
.flatMap(p -> uploadFileFromPath(p, institution))
.map(this::waitABit)
.flatMap(s1 -> deleteFileFromBucket(s1, institution))
.orElseThrow();
Expand All @@ -49,8 +49,10 @@ private void asyncVerify(SiteAWSInfo institution) {
private Optional<String> deleteFileFromBucket(String s, SiteAWSInfo info) {
LOG.info("Verifying delete capabilities");
DeleteObjectRequest request = DeleteObjectRequest.builder().bucket(info.bucket()).key(s).build();
DeleteObjectResponse deleteObjectResponse = client.getS3Client(info.siteName()).deleteObject(request);
return deleteObjectResponse.deleteMarker() ? Optional.of(s) : Optional.empty();
return clientBuilder.buildClientForSite(info.siteName())
.map(c -> c.deleteObject(request))
.map(DeleteObjectResponse::deleteMarker)
.map((ignored) -> s);
}

private String waitABit(String s) {
Expand All @@ -62,7 +64,7 @@ private String waitABit(String s) {
return s;
}

private String uploadFileFromPath(Path p, SiteAWSInfo info) {
private Optional<String> uploadFileFromPath(Path p, SiteAWSInfo info) {
LOG.info("Verifying upload capabilities");
RequestBody body = RequestBody.fromFile(p.toFile());
PutObjectRequest request = PutObjectRequest.builder()
Expand All @@ -71,8 +73,9 @@ private String uploadFileFromPath(Path p, SiteAWSInfo info) {
.ssekmsKeyId(info.kmsKeyID())
.key(p.getFileName().toString())
.build();
client.getS3Client(info.siteName()).putObject(request, body);
return p.getFileName().toString();
return clientBuilder.buildClientForSite(info.siteName())
.map(client -> client.putObject(request, body))
.map(resp -> p.getFileName().toString());
}

private Optional<Path> createTempFileWithText(SiteAWSInfo info) {
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package edu.harvard.dbmi.avillach.dataupload.aws;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Service;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.sts.StsClient;

import java.util.Optional;

@Service
public class StsClientProvider {

private static final Logger log = LoggerFactory.getLogger(StsClientProvider.class);

public Optional<StsClient> createClient() {
StsClient client = StsClient.builder().region(Region.US_EAST_1).build();
return Optional.of(client);
}
}
Loading