Merge pull request #296 from DRosen766/fix_tilde_in_dataset_path
Fix tilde in dataset path
biphasic authored Dec 15, 2024
2 parents c62db86 + 30d993b commit 82ff84c
Showing 4 changed files with 21 additions and 17 deletions.
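
For context beyond the commit message: a literal ~ in a user-supplied save_to path is not expanded by the filesystem, so joining it with a dataset class name produces a directory literally named "~" inside the current working directory. A minimal sketch of the failure mode and of the os.path.expanduser fix this pull request applies (the "~/datasets" value and the NMNIST name are only illustrative):

    import os

    save_to = "~/datasets"

    # os.path.join does not expand the tilde, so this path would create a literal "~" directory
    naive = os.path.join(save_to, "NMNIST")
    print(naive)   # ~/datasets/NMNIST

    # expanding first resolves "~" to the user's home directory
    fixed = os.path.join(os.path.expanduser(save_to), "NMNIST")
    print(fixed)   # e.g. /home/alice/datasets/NMNIST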
20 changes: 13 additions & 7 deletions test/dataset_utils.py
@@ -2,14 +2,18 @@
 import unittest
 from typing import Any, Dict, Union
 from unittest.mock import patch

 import pytest
 import numpy as np
+import os

+# Location of the files to be saved and extracted by the datasets during testing
+TEST_LOCATION_ON_SYSTEM = "~/../../tmp"
+TEST_LOCATION_ON_SYSTEM = os.path.expanduser(TEST_LOCATION_ON_SYSTEM)


 class DatasetTestCase(unittest.TestCase):
     DATASET_CLASS = None
     FEATURE_TYPES = None

     _CHECK_FUNCTIONS = {"check_md5", "check_integrity", "check_exists"}
     _DOWNLOAD_EXTRACT_FUNCTIONS = {
         "download_url",
@@ -41,8 +45,7 @@ def inject_fake_data(
         )

     def create_dataset(self, inject_fake_data: bool = True, **kwargs: Any):
-        tmpdir = "/tmp/"
-        info = self._inject_fake_data(tmpdir)
+        info = self._inject_fake_data(TEST_LOCATION_ON_SYSTEM)

         if inject_fake_data:
             with patch.object(self.DATASET_CLASS, "_check_exists", return_value=True):
@@ -81,7 +84,7 @@ def test_feature_types(self):
         assert len(data) == len(self.FEATURE_TYPES)
         assert len(target) == len(self.TARGET_TYPES)

-        for (data_piece, feature_type) in zip(data, self.FEATURE_TYPES):
+        for data_piece, feature_type in zip(data, self.FEATURE_TYPES):
             if type(data_piece) == np.ndarray:
                 assert data_piece.dtype == feature_type
             else:
@@ -93,6 +96,9 @@ def test_num_examples(self):

     @classmethod
     def setUpClass(cls):
-        cls.KWARGS.update({"save_to": "/tmp"})
-        shutil.rmtree("/tmp/" + cls.DATASET_CLASS.__name__, ignore_errors=True)
+        cls.KWARGS.update({"save_to": TEST_LOCATION_ON_SYSTEM})
+        shutil.rmtree(
+            f"{TEST_LOCATION_ON_SYSTEM}/" + cls.DATASET_CLASS.__name__,
+            ignore_errors=True,
+        )
         super().setUpClass()
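
A side note on the new test constant (my reading, not stated in the diff): os.path.expanduser only rewrites the leading "~", so "~/../../tmp" expands to the home directory followed by /../../tmp, which on a conventional /home/<user> layout resolves back to /tmp while still exercising the tilde-expansion code path the tests want to cover. A small sketch, assuming a Linux-style home directory:

    import os

    TEST_LOCATION_ON_SYSTEM = os.path.expanduser("~/../../tmp")
    print(TEST_LOCATION_ON_SYSTEM)                     # e.g. /home/alice/../../tmp
    print(os.path.realpath(TEST_LOCATION_ON_SYSTEM))   # typically /tmp once ".." is collapsed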
8 changes: 4 additions & 4 deletions test/torch_requirements.txt
@@ -1,5 +1,5 @@
 --index-url https://download.pytorch.org/whl/cpu
-torch==2.1.0
-torchaudio==2.1.0
-torchvision==0.16.0
-torchdata
+torch==2.3.0
+torchaudio==2.3.0
+torchvision==0.18.0
+torchdata<=0.8.0
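
For reference (assumed usage, not part of the diff): a pinned requirements file like this is typically consumed with pip install -r test/torch_requirements.txt, and the --index-url line at the top points pip at the PyTorch CPU-only wheel index rather than PyPI.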
2 changes: 1 addition & 1 deletion tonic/dataset.py
@@ -18,7 +18,7 @@ def __init__(
         target_transform: Optional[Callable] = None,
         transforms: Optional[Callable] = None,
     ):
-        self.location_on_system = os.path.join(save_to, self.__class__.__name__)
+        self.location_on_system = os.path.join(os.path.expanduser(save_to), self.__class__.__name__)
         self.transform = transform
         self.target_transform = target_transform
         self.transforms = transforms
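A minimal sketch of what this one-line change buys (the class and argument names below are stand-ins, not taken from the diff): any subclass that is handed a tilde path now stores an already-expanded location_on_system, so every later join, download, or extraction operates on a real directory instead of a literal "~/..." path:

    import os

    class Dataset:
        def __init__(self, save_to: str):
            # mirrors the patched line: expand "~" before appending the class name
            self.location_on_system = os.path.join(
                os.path.expanduser(save_to), self.__class__.__name__
            )

    class NMNIST(Dataset):  # illustrative subclass only
        pass

    print(NMNIST("~/data").location_on_system)   # e.g. /home/alice/data/NMNIST
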
8 changes: 3 additions & 5 deletions tonic/download_utils.py
@@ -159,9 +159,9 @@ def _extract_zip(from_path: str, to_path: str, compression: Optional[str]) -> No
     with zipfile.ZipFile(
         from_path,
         "r",
-        compression=_ZIP_COMPRESSION_MAP[compression]
-        if compression
-        else zipfile.ZIP_STORED,
+        compression=(
+            _ZIP_COMPRESSION_MAP[compression] if compression else zipfile.ZIP_STORED
+        ),
     ) as zip:
         zip.extractall(to_path)

@@ -313,12 +313,10 @@ def download_and_extract_archive(
     md5: Optional[str] = None,
     remove_finished: bool = False,
 ) -> None:
-    download_root = os.path.expanduser(download_root)
     if extract_root is None:
         extract_root = download_root
     if not filename:
         filename = os.path.basename(url)

     download_url(url, download_root, filename, md5)

     archive = os.path.join(download_root, filename)
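One follow-on observation (my inference; the rendered diff has lost its add/remove colouring): once Dataset.__init__ expands save_to, the download_root that reaches download_and_extract_archive no longer starts with "~", which is presumably why the expanduser call at the top of that function could be dropped. Removing it is safe either way, since os.path.expanduser leaves an already-expanded path untouched:

    import os

    root = os.path.expanduser("~/data/NMNIST")   # expanded once, e.g. in Dataset.__init__
    assert os.path.expanduser(root) == root      # a second expansion changes nothing
    print(root)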
