From 53d4c8741dc9ea3ea98f2e096aaf9e0bcf5bef08 Mon Sep 17 00:00:00 2001
From: Chris Wetherill <89031899+cwetherill-ps@users.noreply.github.com>
Date: Fri, 8 Apr 2022 14:26:40 -0400
Subject: [PATCH] feat(datasets): allow multipart uploads for large datasets (#384)

This falls back to a multipart upload strategy with presigned URLs when a
file in the dataset is larger than 500MB.
---
 gradient/commands/datasets.py | 118 +++++++++++++++++++++++++++++++---
 1 file changed, 108 insertions(+), 10 deletions(-)

diff --git a/gradient/commands/datasets.py b/gradient/commands/datasets.py
index afeda34a..245e819d 100644
--- a/gradient/commands/datasets.py
+++ b/gradient/commands/datasets.py
@@ -12,6 +12,9 @@
 import Queue as queue
 from xml.etree import ElementTree
 from urllib.parse import urlparse
+from ..api_sdk.clients import http_client
+from ..api_sdk.config import config
+from ..cli_constants import CLI_PS_CLIENT_NAME
 
 import halo
 import requests
@@ -557,24 +560,114 @@ def update_status():
 
 
 class PutDatasetFilesCommand(BaseDatasetFilesCommand):
 
-    @classmethod
-    def _put(cls, path, url, content_type):
+    # @classmethod
+    def _put(self, path, url, content_type, dataset_version_id=None, key=None):
         size = os.path.getsize(path)
         with requests.Session() as session:
             headers = {'Content-Type': content_type}
             try:
-                if size > 0:
+                if size <= 0:
+                    headers.update({'Content-Size': '0'})
+                    r = session.put(url, data='', headers=headers, timeout=5)
+                # for files under half a GB
+                elif size <= (10e8) / 2:
                     with open(path, 'rb') as f:
                         r = session.put(
                             url, data=f, headers=headers, timeout=5)
+                # for chonky files, use a multipart upload
                 else:
-                    headers.update({'Content-Size': '0'})
-                    r = session.put(url, data='', headers=headers, timeout=5)
-
-                cls.validate_s3_response(r)
+                    # Chunks need to be at least 5MB or AWS throws an
+                    # EntityTooSmall error; we'll arbitrarily choose a
+                    # 15MB chunk size
+                    #
+                    # Note also that AWS limits the max number of chunks
+                    # in a multipart upload to 10000, so this setting
+                    # currently enforces a hard limit of 150GB per file.
+                    #
+                    # We can dynamically assign a larger part size if needed,
+                    # but for the majority of use cases we should be fine
+                    # as-is
+                    part_minsize = int(15e6)
+                    dataset_id, _, version = dataset_version_id.partition(":")
+                    mpu_url = f'/datasets/{dataset_id}/versions/{version}/s3/preSignedUrls'
+
+                    api_client = http_client.API(
+                        api_url=config.CONFIG_HOST,
+                        api_key=self.api_key,
+                        ps_client_name=CLI_PS_CLIENT_NAME
+                    )
+
+                    mpu_create_res = api_client.post(
+                        url=mpu_url,
+                        json={
+                            'datasetId': dataset_id,
+                            'version': version,
+                            'calls': [{
+                                'method': 'createMultipartUpload',
+                                'params': {'Key': key}
+                            }]
+                        }
+                    )
+                    mpu_data = json.loads(mpu_create_res.text)[0]['url']
+
+                    parts = []
+                    with open(path, 'rb') as f:
+                        # we +2 the number of parts since we're doing floor
+                        # division, which will cut off any trailing part
+                        # less than the part_minsize, AND we want to 1-index
+                        # our range to match what AWS expects for part
+                        # numbers
+                        for part in range(1, (size // part_minsize) + 2):
+                            presigned_url_res = api_client.post(
+                                url=mpu_url,
+                                json={
+                                    'datasetId': dataset_id,
+                                    'version': version,
+                                    'calls': [{
+                                        'method': 'uploadPart',
+                                        'params': {
+                                            'Key': key,
+                                            'UploadId': mpu_data['UploadId'],
+                                            'PartNumber': part
+                                        }
+                                    }]
+                                }
+                            )
+
+                            presigned_url = json.loads(
+                                presigned_url_res.text
+                            )[0]['url']
+
+                            chunk = f.read(part_minsize)
+                            part_res = session.put(
+                                presigned_url,
+                                data=chunk,
+                                timeout=5)
+                            etag = part_res.headers['ETag'].replace('"', '')
+                            parts.append({'ETag': etag, 'PartNumber': part})
+
+                    r = api_client.post(
+                        url=mpu_url,
+                        json={
+                            'datasetId': dataset_id,
+                            'version': version,
+                            'calls': [{
+                                'method': 'completeMultipartUpload',
+                                'params': {
+                                    'Key': key,
+                                    'UploadId': mpu_data['UploadId'],
+                                    'MultipartUpload': {'Parts': parts}
+                                }
+                            }]
+                        }
+                    )
+
+                self.validate_s3_response(r)
             except requests.exceptions.ConnectionError as e:
-                return cls.report_connection_error(e)
+                return self.report_connection_error(e)
+            except Exception as e:
+                return e
 
     @staticmethod
     def _list_files(source_path):
@@ -599,8 +692,13 @@ def _sign_and_put(self, dataset_version_id, pool, results, update_status):
 
         for pre_signed, result in zip(pre_signeds, results):
             update_status()
-            pool.put(self._put, url=pre_signed.url,
-                     path=result['path'], content_type=result['mimetype'])
+            pool.put(
+                self._put,
+                url=pre_signed.url,
+                path=result['path'],
+                content_type=result['mimetype'],
+                dataset_version_id=dataset_version_id,
+                key=result['key'])
 
     def execute(self, dataset_version_id, source_paths, target_path):
         self.assert_supported(dataset_version_id)
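
Reviewer note (not part of the patch): the sketch below isolates the three-step
S3 multipart sequence the new else-branch follows -- createMultipartUpload,
uploadPart per chunk, completeMultipartUpload. It is illustrative only:
call_backend() is a hypothetical stand-in for the api_client.post() calls
against the /datasets/{id}/versions/{version}/s3/preSignedUrls endpoint above
(assumed to return the [0]['url'] field of the response, exactly as the patch
uses it), and PART_SIZE mirrors part_minsize.

    # Illustrative sketch, assuming a call_backend(method, params) helper that
    # wraps the presigned-URL endpoint: for createMultipartUpload it returns a
    # dict carrying the UploadId, for uploadPart a presigned PUT URL, and for
    # completeMultipartUpload the backend's completion response. Empty files
    # are assumed to be handled earlier, as in the patch's size <= 0 branch.
    import itertools

    import requests

    PART_SIZE = int(15e6)  # 15MB parts; S3 requires >= 5MB for all but the last part


    def multipart_upload(path, key, call_backend):
        with requests.Session() as session, open(path, 'rb') as f:
            # 1. createMultipartUpload -> UploadId for the whole object
            upload = call_backend('createMultipartUpload', {'Key': key})
            upload_id = upload['UploadId']

            # 2. uploadPart per chunk; S3 part numbers are 1-indexed and each
            #    part's ETag must be collected for the completion call
            parts = []
            for part_number in itertools.count(1):
                chunk = f.read(PART_SIZE)
                if not chunk:
                    break
                put_url = call_backend('uploadPart', {
                    'Key': key,
                    'UploadId': upload_id,
                    'PartNumber': part_number,
                })
                res = session.put(put_url, data=chunk)
                parts.append({'ETag': res.headers['ETag'].strip('"'),
                              'PartNumber': part_number})

            # 3. completeMultipartUpload stitches the parts together server-side
            return call_backend('completeMultipartUpload', {
                'Key': key,
                'UploadId': upload_id,
                'MultipartUpload': {'Parts': parts},
            })

Reading until f.read() returns nothing keeps memory bounded at one part and
also sidesteps the edge case in the patch's range(1, (size // part_minsize) + 2)
loop, where a size that is an exact multiple of part_minsize yields one trailing
empty part. The 150GB ceiling mentioned in the comments is just the arithmetic
of a 15MB part size against S3's 10,000-part cap.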