diff --git a/.gitignore b/.gitignore
index 9ed5eb27f..4483dbd16 100644
--- a/.gitignore
+++ b/.gitignore
@@ -28,6 +28,7 @@ Thumbs.db
 *MANIFEST
 *.egg-info
 venv*/
+__pycache__/
 /nav
 /tags
diff --git a/codalab/common.py b/codalab/common.py
index 558f3bd9c..5a0f2dd9b 100644
--- a/codalab/common.py
+++ b/codalab/common.py
@@ -286,7 +286,7 @@ def _get_azure_sas_url(self, path, **kwargs):
             account_name=AZURE_BLOB_ACCOUNT_NAME,
             container_name=AZURE_BLOB_CONTAINER_NAME,
             account_key=AZURE_BLOB_ACCOUNT_KEY,
-            expiry=datetime.datetime.now() + datetime.timedelta(hours=1),
+            expiry=datetime.datetime.now() + datetime.timedelta(hours=10),
             blob_name=blob_name,
         )
         return f"{AZURE_BLOB_HTTP_ENDPOINT}/{AZURE_BLOB_CONTAINER_NAME}/{blob_name}?{sas_token}"
@@ -306,7 +306,7 @@ def _get_gcs_signed_url(self, path, **kwargs):
         blob = bucket.blob(blob_name)
         signed_url = blob.generate_signed_url(
             version="v4",
-            expiration=datetime.timedelta(hours=1),
+            expiration=datetime.timedelta(hours=10),
             method=kwargs.get("method", "GET"),  # HTTP method. eg, GET, PUT
             content_type=kwargs.get("request_content_type", None),
             response_disposition=kwargs.get("content_disposition", None),
diff --git a/codalab/lib/beam/MultiReaderFileStream.py b/codalab/lib/beam/MultiReaderFileStream.py
index bab64cc78..0d96a3572 100644
--- a/codalab/lib/beam/MultiReaderFileStream.py
+++ b/codalab/lib/beam/MultiReaderFileStream.py
@@ -2,7 +2,7 @@
 from threading import Lock
 
 from codalab.worker.un_gzip_stream import BytesBuffer
-
+import threading
 
 class MultiReaderFileStream(BytesIO):
     """
@@ -10,11 +10,16 @@ class MultiReaderFileStream(BytesIO):
     """
     NUM_READERS = 2
+    # Peak memory usage <= MAX_BUF_SIZE + max(num_bytes passed to a single read())
+    MAX_BUF_SIZE = 1024 * 1024 * 1024  # 1 GiB
+
     def __init__(self, fileobj):
         self._bufs = [BytesBuffer() for _ in range(0, self.NUM_READERS)]
         self._pos = [0 for _ in range(0, self.NUM_READERS)]
         self._fileobj = fileobj
         self._lock = Lock()  # lock to ensure one does not concurrently read self._fileobj / write to the buffers.
+        self._current_max_buf_length = 0
+        self._buffer_condition = threading.Condition()
 
     class FileStreamReader(BytesIO):
         def __init__(s, index):
@@ -36,15 +41,40 @@ def _fill_buf_bytes(self, index: int, num_bytes=None):
                 break
             for i in range(0, self.NUM_READERS):
                 self._bufs[i].write(s)
+        self.find_largest_buffer()
+
+    def find_largest_buffer(self):
+        """Recompute the length of the largest reader buffer and, once it is
+        back under MAX_BUF_SIZE, wake any readers waiting in read()."""
+        with self._buffer_condition:
+            self._current_max_buf_length = max(len(buf) for buf in self._bufs)
+            if self._current_max_buf_length <= self.MAX_BUF_SIZE:
+                self._buffer_condition.notify_all()
 
     def read(self, index: int, num_bytes=None):  # type: ignore
         """Read the specified number of bytes from the associated file.
         index: index that specifies which reader is reading.
""" + + # print(f"calling read() in thread {threading.current_thread().name}, num_bytes={num_bytes}") + # busy waiting until t + # while(self._current_max_buf_length > self.MAX_BUF_SIZE and len(self._bufs[index]) < self._current_max_buf_length): + # # only the slowest reader could read + # # print(f"Busy waiting in thread: {threading.current_thread().name}, current max_len = {self._current_max_buf_length}, current_buf_size = {len(self._bufs[index])}") + # pass + + # Wait until the buffer length condition is satisfied + with self._buffer_condition: + while self._current_max_buf_length > self.MAX_BUF_SIZE and len(self._bufs[index]) < self._current_max_buf_length: + # Wait for the condition variable to be notified + self._buffer_condition.wait() + + # If current thread is the slowest reader, continue read. + # If current thread is the slowest reader, and num_bytes > len(self._buf[index]) / num_bytes = None, will continue grow the buffer. + # max memory usage <= MAX_BUF_SIZE + max(num_bytes called in read) self._fill_buf_bytes(index, num_bytes) + assert self._current_max_buf_length <= 2 * self.MAX_BUF_SIZE if num_bytes is None: num_bytes = len(self._bufs[index]) s = self._bufs[index].read(num_bytes) + self.find_largest_buffer() + # print("Current thread name: ", threading.current_thread().name) + self._pos[index] += len(s) return s diff --git a/codalab/lib/beam/SQLiteIndexedTar.py b/codalab/lib/beam/SQLiteIndexedTar.py index 4b6110d5d..42c231cad 100644 --- a/codalab/lib/beam/SQLiteIndexedTar.py +++ b/codalab/lib/beam/SQLiteIndexedTar.py @@ -736,7 +736,7 @@ def _createIndex( # In that case add that itself to the file index. This won't work when called recursively, # so check stream offset. fileCount = self.sqlConnection.execute('SELECT COUNT(*) FROM "files";').fetchone()[0] - if fileCount == 0: # Jiani: For Codalab, the bundle contains only + if fileCount == 0: # Jiani: For Codalab, the bundle contains only single files # This branch is not used. if self.printDebug >= 3: print(f"Did not find any file in the given TAR: {self.tarFileName}. Assuming a compressed file.") diff --git a/tests/unit/server/upload_download_test.py b/tests/unit/server/upload_download_test.py index 7365a2483..e79d5f324 100644 --- a/tests/unit/server/upload_download_test.py +++ b/tests/unit/server/upload_download_test.py @@ -71,6 +71,7 @@ def test_not_found(self): def check_file_target_contents(self, target): """Checks to make sure that the specified file has the contents 'hello world'.""" + # This can not be checked, Since with self.download_manager.stream_file(target, gzipped=False) as f: self.assertEqual(f.read(), b"hello world")