Skip to content

Commit

Permalink
Fix segfault on s3 filesystem (#1829)
Browse files Browse the repository at this point in the history
* fix error

* apply clang-format
  • Loading branch information
jeongukjae authored Aug 19, 2023
1 parent 8b87c3b commit 744fd06
Showing 1 changed file with 16 additions and 8 deletions.
24 changes: 16 additions & 8 deletions tensorflow_io/core/filesystems/s3/s3_filesystem.cc
Original file line number Diff line number Diff line change
Expand Up @@ -143,13 +143,15 @@ static Aws::Client::ClientConfiguration& GetDefaultClientConfig() {
}
// if these timeouts are low, you may see an error when
// uploading/downloading large files: Unable to connect to endpoint
const char *connect_timeout = getenv("S3_CONNECT_TIMEOUT_MSEC"),
*request_timeout = getenv("S3_REQUEST_TIMEOUT_MSEC");
int64_t timeout;
cfg.connectTimeoutMs =
absl::SimpleAtoi(getenv("S3_CONNECT_TIMEOUT_MSEC"), &timeout)
connect_timeout && absl::SimpleAtoi(connect_timeout, &timeout)
? timeout
: kS3TimeoutMsec;
cfg.requestTimeoutMs =
absl::SimpleAtoi(getenv("S3_REQUEST_TIMEOUT_MSEC"), &timeout)
request_timeout && absl::SimpleAtoi(request_timeout, &timeout)
? timeout
: kS3TimeoutMsec;
const char* ca_file = getenv("S3_CA_FILE");
Expand Down Expand Up @@ -189,8 +191,11 @@ static void GetS3Client(tf_s3_filesystem::S3File* s3_file) {
}
});

const char* disable_multi_part_download =
getenv("S3_DISABLE_MULTI_PART_DOWNLOAD");
int temp_value;
if (absl::SimpleAtoi(getenv("S3_DISABLE_MULTI_PART_DOWNLOAD"), &temp_value))
if (disable_multi_part_download &&
absl::SimpleAtoi(disable_multi_part_download, &temp_value))
s3_file->use_multi_part_download = (temp_value != 1);

const char* endpoint = getenv("S3_ENDPOINT");
Expand Down Expand Up @@ -220,12 +225,15 @@ static void GetTransferManager(
if (s3_file->transfer_managers.count(direction) == 0) {
uint64_t temp_value;
if (direction == Aws::Transfer::TransferDirection::UPLOAD) {
if (!absl::SimpleAtoi(getenv("S3_MULTI_PART_UPLOAD_CHUNK_SIZE"),
&temp_value))
const char* upload_chunk_size = getenv("S3_MULTI_PART_UPLOAD_CHUNK_SIZE");
if (upload_chunk_size == nullptr ||
!absl::SimpleAtoi(upload_chunk_size, &temp_value))
temp_value = kS3MultiPartUploadChunkSize;
} else if (direction == Aws::Transfer::TransferDirection::DOWNLOAD) {
if (!absl::SimpleAtoi(getenv("S3_MULTI_PART_DOWNLOAD_CHUNK_SIZE"),
&temp_value))
const char* download_chunk_size =
getenv("S3_MULTI_PART_DOWNLOAD_CHUNK_SIZE");
if (download_chunk_size == nullptr ||
!absl::SimpleAtoi(download_chunk_size, &temp_value))
temp_value = kS3MultiPartDownloadChunkSize;
}
s3_file->multi_part_chunk_sizes.emplace(direction, temp_value);
Expand Down Expand Up @@ -1015,7 +1023,7 @@ void CreateDir(const TF_Filesystem* filesystem, const char* path,

PathExists(filesystem, dir_path.c_str(), status);
if (TF_GetCode(status) == TF_OK) {
std::unique_ptr<TF_WritableFile, void (*)(TF_WritableFile * file)> file(
std::unique_ptr<TF_WritableFile, void (*)(TF_WritableFile* file)> file(
new TF_WritableFile, [](TF_WritableFile* file) {
if (file != nullptr) {
if (file->plugin_file != nullptr) tf_writable_file::Cleanup(file);
Expand Down

0 comments on commit 744fd06

Please sign in to comment.