Commit

Merge pull request #462 from DagsHub/data-engine/to-hf-dataset
Data engine: Add an `as_hf_dataset` to QueryResult
kbolashev authored Apr 9, 2024
2 parents cfd3118 + d729d3c commit c764525
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 8 deletions.
13 changes: 11 additions & 2 deletions dagshub/data_engine/model/datapoint.py
@@ -1,7 +1,7 @@
from dataclasses import dataclass
from os import PathLike
from pathlib import Path
from typing import Optional, Union, List, Dict, Any, Callable, TYPE_CHECKING
from typing import Optional, Union, List, Dict, Any, Callable, TYPE_CHECKING, Literal

from dagshub.common.download import download_files
from dagshub.common.helpers import http_request
@@ -229,7 +229,12 @@ def blob_url(self, sha):


def _get_blob(
url: Optional[str], cache_path: Optional[Path], auth, cache_on_disk: bool, return_blob: bool
url: Optional[str],
cache_path: Optional[Path],
auth,
cache_on_disk: bool,
return_blob: bool,
path_format: Literal["str", "path"] = "path",
) -> Optional[Union[Path, str, bytes]]:
"""
Args:
@@ -248,6 +253,8 @@ def _get_blob(
with cache_path.open("rb") as f:
return f.read()
else:
if path_format == "str":
cache_path = str(cache_path)
return cache_path

try:
@@ -266,4 +273,6 @@ def _get_blob(
if return_blob:
return content
else:
if path_format == "str":
cache_path = str(cache_path)
return cache_path
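
A minimal illustration (not part of the library) of what the new ``path_format`` argument changes: when ``_get_blob`` resolves a cached blob, it now returns the cache location as a plain string if ``path_format="str"`` is passed, and keeps returning a ``Path`` object by default. The helper and the cache location below are hypothetical.

from pathlib import Path
from typing import Literal, Union

def format_cache_path(cache_path: Path, path_format: Literal["str", "path"] = "path") -> Union[Path, str]:
    # Mirrors the branch added above: "str" converts the Path to a string,
    # "path" (the default) preserves the previous behaviour.
    return str(cache_path) if path_format == "str" else cache_path

# Hypothetical cache location under the datasource's .metadata_blobs directory.
blob = Path.home() / "dagshub" / "datasets" / "repo" / "1" / ".metadata_blobs" / "ab" / "abc123"
assert isinstance(format_cache_path(blob), Path)
assert isinstance(format_cache_path(blob, "str"), str)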
71 changes: 65 additions & 6 deletions dagshub/data_engine/model/query_result.py
@@ -19,7 +19,7 @@
add_ls_annotations,
)
from dagshub.data_engine.client.models import DatasourceType
from dagshub.data_engine.model.datapoint import Datapoint, _get_blob
from dagshub.data_engine.model.datapoint import Datapoint, _get_blob, _generated_fields
from dagshub.data_engine.client.loaders.base import DagsHubDataset
from dagshub.data_engine.voxel_plugin_server.utils import set_voxel_envvars
from dagshub.data_engine.dtypes import MetadataFieldType
@@ -28,10 +28,12 @@
from dagshub.data_engine.model.datasource import Datasource
import fiftyone as fo
import dagshub.data_engine.voxel_plugin_server.server as plugin_server_module
import datasets as hf_ds
else:
plugin_server_module = lazy_load("dagshub.data_engine.voxel_plugin_server.server")
fo = lazy_load("fiftyone")
tf = lazy_load("tensorflow")
hf_ds = lazy_load("datasets")

logger = logging.getLogger(__name__)

@@ -227,6 +229,49 @@ def keypairs(keys):
else:
raise ValueError("supported flavors are torch|tensorflow|<torch.utils.data.Dataset>|<tf.data.Dataset>")

def as_hf_dataset(
self, target_dir: Optional[Union[str, PathLike]] = None, download_datapoints=True, download_blobs=True
):
"""
Loads this QueryResult as a Hugging Face dataset.
The download paths are set to local paths in the filesystem, so they can be used with
``cast_column()`` later.
Args:
target_dir: Where to download the datapoints. The metadata is still downloaded into the global cache.
download_datapoints: If set to ``True`` (default), downloads the datapoint files and sets the path column\
to the path of the datapoint in the filesystem
download_blobs: If set to ``True`` (default), downloads all blob fields and sets the respective column\
to the path of the file in the filesystem.
"""
if download_blobs:
# Download blobs as paths, so later a user can apply ds.cast_column on the blobs
self.get_blob_fields(load_into_memory=False, path_format="str")

df = self.dataframe

if download_datapoints:
# Do the same for the actual datapoint files, changing the path

if target_dir is None:
target_dir = self.datasource.default_dataset_location
elif isinstance(target_dir, str):
target_dir = Path(target_dir).absolute()
new_paths = []
self.download_files(target_dir=target_dir)
for dp in df["path"]:
new_paths.append(str(target_dir / self.datasource.source.source_prefix / dp))
df["path"] = new_paths

# Drop the generated fields
for f in _generated_fields.keys():
if f == "path":
continue
df.drop(f, axis=1, inplace=True)

return hf_ds.Dataset.from_pandas(df)

def __getitem__(self, item: Union[str, int, slice]):
"""
Gets datapoint by its path (string) or by its index in the result (or slice)
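
For context, here is an illustrative usage sketch (not from the PR) of the ``as_hf_dataset()`` method added above, combined with ``datasets.cast_column()``. The repo and datasource names are hypothetical, and casting the ``path`` column to ``datasets.Image()`` assumes the datapoints are image files.

from dagshub.data_engine import datasources
import datasets as hf_ds

# Hypothetical repo and datasource names.
ds = datasources.get_datasource("<user>/<repo>", "my-datasource")
query_result = ds.head()

# Downloads datapoints and blob fields, then builds a Hugging Face dataset
# whose "path" column contains local filesystem paths as plain strings.
hf_dataset = query_result.as_hf_dataset()

# Because the paths are strings, they can be cast to richer feature types,
# e.g. decoding each file as an image (assumes image datapoints).
hf_dataset = hf_dataset.cast_column("path", hf_ds.Image())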
@@ -244,7 +289,12 @@ def __getitem__(self, item: Union[str, int, slice]):
)

def get_blob_fields(
self, *fields: str, load_into_memory=False, cache_on_disk=True, num_proc: int = config.download_threads
self,
*fields: str,
load_into_memory=False,
cache_on_disk=True,
num_proc: int = config.download_threads,
path_format: Literal["str", "path"] = "path",
) -> "QueryResult":
"""
Downloads data from blob fields
@@ -260,6 +310,8 @@
cache_on_disk: Whether to cache the blobs on disk or not (valid only if load_into_memory is set to True)
Cache location is ``~/dagshub/datasets/<repo>/<datasource_id>/.metadata_blobs/``
num_proc: number of download threads
path_format: How the paths to the files should be represented.
``path`` returns a ``Path`` object, and ``str`` returns the path as a string.
"""
send_analytics_event("Client_DataEngine_downloadBlobs", repo=self.datasource.source.repoApi)
if not load_into_memory:
@@ -301,8 +353,8 @@ def get_blob_fields(
auth = self.datasource.source.repoApi.auth

def _get_blob_fn(dp: Datapoint, field: str, url: str, blob_path: Path):
blob_or_path = _get_blob(url, blob_path, auth, cache_on_disk, load_into_memory)
if isinstance(blob_or_path, str):
blob_or_path = _get_blob(url, blob_path, auth, cache_on_disk, load_into_memory, path_format)
if isinstance(blob_or_path, str) and path_format != "str":
logger.warning(f"Error while downloading blob for field {field} in datapoint {dp.path}:{blob_or_path}")
dp.metadata[field] = blob_or_path

@@ -318,15 +370,22 @@ def _get_blob_fn(dp: Datapoint, field: str, url: str, blob_path: Path):
return self

def download_binary_columns(
self, *columns: str, load_into_memory=True, cache_on_disk=True, num_proc: int = 32
self,
*columns: str,
load_into_memory=True,
cache_on_disk=True,
num_proc: int = 32,
) -> "QueryResult":
"""
deprecated: Use get_blob_fields instead.
:meta private:
"""
return self.get_blob_fields(
*columns, load_into_memory=load_into_memory, cache_on_disk=cache_on_disk, num_proc=num_proc
*columns,
load_into_memory=load_into_memory,
cache_on_disk=cache_on_disk,
num_proc=num_proc,
)

def download_files(
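
As a usage note (illustrative, not from the PR): the new ``path_format`` parameter on ``get_blob_fields`` lets callers keep blob columns as plain string paths, which is what ``as_hf_dataset()`` relies on internally before calling ``Dataset.from_pandas()``. A hedged sketch, assuming a blob field named "annotation" exists on a hypothetical datasource:

from dagshub.data_engine import datasources

# Hypothetical repo/datasource; "annotation" is an assumed blob field name.
qr = datasources.get_datasource("<user>/<repo>", "my-datasource").head()

# Download blob contents into the on-disk cache and store the cache location
# as a plain string (path_format="str") instead of a pathlib.Path object.
qr.get_blob_fields("annotation", load_into_memory=False, path_format="str")

print(qr.dataframe["annotation"].iloc[0])  # a str pointing into .metadata_blobs/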
1 change: 1 addition & 0 deletions requirements-dev.txt
@@ -5,3 +5,4 @@ pytest-git==1.7.0
pytest-env==1.1.3
pytest-mock==3.14.0
fiftyone==0.23.7
datasets==2.18.0
