Skip to content

Commit

Permalink
Update download method to use tempfile
Browse files Browse the repository at this point in the history
  • Loading branch information
Paulooh007 committed Oct 9, 2023
1 parent 7716c11 commit 816ddb0
Showing 1 changed file with 9 additions and 10 deletions.
19 changes: 9 additions & 10 deletions laser_encoders/download_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import logging
import os
import sys
import tempfile
from pathlib import Path

import requests
Expand Down Expand Up @@ -48,26 +49,24 @@ def download(self, filename: str):
url = os.path.join(self.base_url, filename)

local_file_path = os.path.join(self.model_dir, filename)
temp_file_path = os.path.join("/tmp", filename)

if os.path.exists(local_file_path):
logger.info(f" - {filename} already downloaded")
else:
logger.info(f" - Downloading {filename}")

if os.path.exists(temp_file_path):
os.remove(temp_file_path)
tf = tempfile.NamedTemporaryFile(delete=False)
temp_file_path = tf.name

response = requests.get(url, stream=True)
total_size = int(response.headers.get("Content-Length", 0))
progress_bar = tqdm(total=total_size, unit_scale=True, unit="B")
with tf:
response = requests.get(url, stream=True)
total_size = int(response.headers.get("Content-Length", 0))
progress_bar = tqdm(total=total_size, unit_scale=True, unit="B")

# Download to /tmp first
with open(temp_file_path, "wb") as f:
for chunk in response.iter_content(chunk_size=1024):
f.write(chunk)
tf.write(chunk)
progress_bar.update(len(chunk))
progress_bar.close()
progress_bar.close()

os.rename(temp_file_path, local_file_path)

Expand Down

0 comments on commit 816ddb0

Please sign in to comment.