Skip to content

Commit

Permalink
fix bug not checking upstream update time and move buffer clean up logic
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinsantana11 committed Sep 29, 2024
1 parent 74f40f5 commit e09508a
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 50 deletions.
88 changes: 41 additions & 47 deletions clouddrift/adapters/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from typing import Callable, Sequence

import requests
from requests import Response
from tenacity import (
RetryCallState,
WrappedFn,
Expand Down Expand Up @@ -66,11 +65,10 @@ def download_with_progress(
else:
retry_protocol = custom_retry_protocol # type: ignore

buffer: BufferedIOBase | BufferedWriter
executor = concurrent.futures.ThreadPoolExecutor()
futures: dict[
concurrent.futures.Future,
tuple[str, BufferedIOBase | str, BufferedIOBase | BufferedWriter],
concurrent.futures.Future[None],
tuple[str, BufferedIOBase | str],
] = dict()
bar = None

Expand All @@ -81,20 +79,15 @@ def download_with_progress(
src, dst = request
exp_size = None

if isinstance(dst, (str,)):
buffer = open(dst, "wb")
else:
buffer = dst

futures[
executor.submit(
retry_protocol(_download_with_progress),
src,
buffer,
dst,
exp_size,
not show_list_progress,
)
] = (src, dst, buffer)
] = (src, dst)

try:
if show_list_progress:
Expand All @@ -106,11 +99,7 @@ def download_with_progress(
)

for fut in concurrent.futures.as_completed(futures):
src, dst, buffer = futures[fut]

if isinstance(dst, (str,)):
buffer.close()

src, dst = futures[fut]
ex = fut.exception(0)
if ex is None:
_logger.debug(f"Finished download job: ({src}, {dst})")
Expand All @@ -124,7 +113,7 @@ def download_with_progress(
any created resources."
)
for x in futures.keys():
src, dst, buffer = futures[x]
src, dst = futures[x]

if not x.done():
x.cancel()
Expand All @@ -140,7 +129,7 @@ def download_with_progress(

def _download_with_progress(
url: str,
output: BufferedIOBase | BufferedWriter,
output: str | BufferedIOBase,
expected_size: float | None,
show_progress: bool,
):
Expand All @@ -165,39 +154,44 @@ def _download_with_progress(
"Cannot determine if the file has been updated on the remote source. "
+ "'Last-Modified' header not present in server response."
)
_logger.debug(f"Downloading from {url} to {output}...")

response: Response | None = None
_logger.debug(f"Downloading from {url} to {output}...")
bar = None

try:
response = requests.get(url, timeout=5, stream=True)

if (content_length := response.headers.get("Content-Length")) is not None:
expected_size = float(content_length)

if show_progress:
bar = tqdm(
desc=url,
total=expected_size,
unit="B",
unit_scale=True,
unit_divisor=_CHUNK_SIZE,
nrows=2,
disable=_DISABLE_SHOW_PROGRESS,
)

for chunk in response.iter_content(_CHUNK_SIZE):
if not chunk:
break
output.write(chunk)
with requests.get(url, timeout=5, stream=True) as response:
buffer: BufferedWriter | BufferedIOBase | None = None
try:
if isinstance(output, (str,)):
buffer = open(output, "wb")
else:
buffer = output

if (content_length := response.headers.get("Content-Length")) is not None:
expected_size = float(content_length)

if show_progress:
bar = tqdm(
desc=url,
total=expected_size,
unit="B",
unit_scale=True,
unit_divisor=_CHUNK_SIZE,
nrows=2,
disable=_DISABLE_SHOW_PROGRESS,
)
for chunk in response.iter_content(_CHUNK_SIZE):
if not chunk:
break

Check warning on line 184 in clouddrift/adapters/utils.py

View check run for this annotation

Codecov / codecov/patch

clouddrift/adapters/utils.py#L184

Added line #L184 was not covered by tests
buffer.write(chunk)
if bar is not None:
bar.update(len(chunk))
finally:
if response is not None:
response.close()
if bar is not None:
bar.update(len(chunk))
finally:
if response is not None:
response.close()
if bar is not None:
bar.close()
bar.close()
if buffer is not None and isinstance(output, (str,)):
buffer.close()


__all__ = ["download_with_progress"]
8 changes: 5 additions & 3 deletions tests/adapters/utils_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ def setUp(self) -> None:
self.get_response_mock = Mock()
self.get_response_mock.headers = dict()
self.get_response_mock.iter_content = Mock(return_value=["a", "b", "c"])
self.get_response_mock.__enter__ = Mock(return_value=self.get_response_mock)
self.get_response_mock.__exit__ = Mock()

self.requests_mock = Mock()
self.requests_mock.head = Mock(return_value=self.head_response_mock)
Expand Down Expand Up @@ -92,9 +94,9 @@ def test_progress_mechanism_disabled_files(self):
"""
mocked_futures = [self.gen_future_mock() for _ in range(0, 3)]
download_requests = [
("src0", "dst", None),
("src1", "dst", None),
("src2", "dst", None),
("src0", "dst"),
("src1", "dst"),
("src2", "dst"),
]

tpe_mock = Mock()
Expand Down
1 change: 1 addition & 0 deletions tests/sphere_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ def test_bearing(self):
self.assertTrue(np.isclose(bearing(0, 0, 0, -0.1), -np.pi / 2))
self.assertTrue(np.isclose(bearing(0, 0, 0.1, -0.1), -np.pi / 4))


class position_from_distance_and_bearing_tests(unittest.TestCase):
def test_position_from_distance_and_bearing_one_degree(self):
self.assertTrue(
Expand Down

0 comments on commit e09508a

Please sign in to comment.