Skip to content

Commit

Permalink
Add get_bytes_range() function (#196)
Browse files Browse the repository at this point in the history
* Add `get_bytes_range()` function

* CHANGELOG.md

* Add tests, resort to default behavior for HTTP
  • Loading branch information
epwalsh authored Oct 6, 2023
1 parent 40b0216 commit f435b48
Show file tree
Hide file tree
Showing 12 changed files with 240 additions and 13 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Added

- Added `get_bytes_range()` function.

## [v1.4.0](https://github.com/allenai/cached_path/releases/tag/v1.4.0) - 2023-08-02

### Added
Expand Down
2 changes: 2 additions & 0 deletions cached_path/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
"""

from ._cached_path import cached_path
from .bytes_range import get_bytes_range
from .common import get_cache_dir, set_cache_dir
from .progress import get_download_progress
from .schemes import SchemeClient, add_scheme_client
Expand All @@ -24,6 +25,7 @@

__all__ = [
"cached_path",
"get_bytes_range",
"get_cache_dir",
"set_cache_dir",
"get_download_progress",
Expand Down
14 changes: 11 additions & 3 deletions cached_path/_cached_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@
from .common import PathOrStr, get_cache_dir
from .file_lock import FileLock
from .meta import Meta
from .schemes import get_scheme_client, get_supported_schemes, hf_get_from_cache
from .schemes import (
SchemeClient,
get_scheme_client,
get_supported_schemes,
hf_get_from_cache,
)
from .util import (
_lock_file_path,
_meta_file_path,
Expand Down Expand Up @@ -269,6 +274,8 @@ def get_from_cache(
cache_dir: Optional[PathOrStr] = None,
quiet: bool = False,
progress: Optional["Progress"] = None,
no_downloads: bool = False,
_client: Optional[SchemeClient] = None,
) -> Tuple[Path, Optional[str]]:
"""
Given a URL, look for the corresponding dataset in the local cache.
Expand All @@ -279,8 +286,7 @@ def get_from_cache(

cache_dir = Path(cache_dir if cache_dir else get_cache_dir()).expanduser()
cache_dir.mkdir(parents=True, exist_ok=True)

client = get_scheme_client(url)
client = _client or get_scheme_client(url)

# Get eTag to add to filename, if it exists.
try:
Expand Down Expand Up @@ -327,6 +333,8 @@ def get_from_cache(
with FileLock(_lock_file_path(cache_path), read_only_ok=True):
if os.path.exists(cache_path):
logger.info("cache of %s is up-to-date", url)
elif no_downloads:
raise FileNotFoundError(cache_path)
else:
size = client.get_size()
with CacheFile(cache_path) as cache_file:
Expand Down
140 changes: 140 additions & 0 deletions cached_path/bytes_range.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
from typing import TYPE_CHECKING, Optional
from urllib.parse import urlparse

from ._cached_path import cached_path, get_from_cache
from .common import PathOrStr
from .schemes import get_scheme_client, get_supported_schemes

if TYPE_CHECKING:
from rich.progress import Progress


def get_bytes_range(
    url_or_filename: PathOrStr,
    index: int,
    length: int,
    cache_dir: Optional[PathOrStr] = None,
    extract_archive: bool = False,
    force_extract: bool = False,
    quiet: bool = False,
    progress: Optional["Progress"] = None,
) -> bytes:
    """
    Get a range of up to ``length`` bytes starting at ``index``.

    In some cases the entire file may need to be downloaded, such as when the server
    does not support a range download or when you're trying to get a bytes range from
    a file within an archive.

    .. caution::
        You may get less than ``length`` bytes sometimes, such as when fetching a range
        from an HTTP resource starting at 0 since headers will be omitted in the bytes
        returned.

    Parameters
    ----------
    url_or_filename :
        A URL or path to parse and possibly download.
    index :
        The index of the byte to start at.
    length :
        The number of bytes to read.
    cache_dir :
        The directory to cache downloads. If not specified, the global default cache
        directory will be used (``~/.cache/cached_path``). This can be set to something
        else with :func:`set_cache_dir()`.
        This is only relevant when the bytes range cannot be obtained directly from the
        resource.
    extract_archive :
        Set this to ``True`` when you want to get a bytes range from a file within an
        archive. In this case the ``url_or_filename`` must contain an "!" followed by
        the relative path of the file within the archive, e.g.
        "s3://my-archive.tar.gz!my-file.txt".
        Note that the entire archive has to be downloaded in this case.
    force_extract :
        If ``True`` and the resource is a file within an archive (when the path contains
        an "!" and ``extract_archive=True``), it will be extracted regardless of whether
        or not the extracted directory already exists.

        .. caution::
            Use this flag with caution! This can lead to race conditions if used
            from multiple processes on the same file.
    quiet :
        If ``True``, progress displays won't be printed.
        This is only relevant when the bytes range cannot be obtained directly from the
        resource.
    progress :
        A custom progress display to use. If not set and ``quiet=False``, a default
        display from :func:`~cached_path.get_download_progress()` will be used.
        This is only relevant when the bytes range cannot be obtained directly from the
        resource.
    """
    resource = url_or_filename if isinstance(url_or_filename, str) else str(url_or_filename)

    # Handle the "/a/b/foo.zip!c/d/file.txt" syntax: resolve (and extract) the archive
    # first, then read the range straight out of the extracted member file.
    archive, bang, member = resource.partition("!")
    if extract_archive and bang:
        # Resolve the archive itself to a local extracted directory.
        extracted_dir = cached_path(
            archive,
            cache_dir=cache_dir,
            extract_archive=True,
            force_extract=force_extract,
            quiet=quiet,
            progress=progress,
        )
        if not extracted_dir.is_dir():
            raise ValueError(
                f"{resource} uses the ! syntax, but does not specify an archive file."
            )

        member_path = extracted_dir / member
        if not member_path.exists():
            raise FileNotFoundError(f"'{member}' not found within '{archive}'")

        return _bytes_range_from_file(member_path, index, length)

    if urlparse(resource).scheme in get_supported_schemes():
        # This is a URL with a supported scheme, so use the scheme client.
        client = get_scheme_client(resource)

        # Fast path: if the file is already in the local cache, read from there.
        # `no_downloads=True` makes `get_from_cache` raise instead of downloading.
        try:
            cache_path, _ = get_from_cache(
                resource,
                cache_dir=cache_dir,
                quiet=quiet,
                progress=progress,
                no_downloads=True,
                _client=client,
            )
            return _bytes_range_from_file(cache_path, index, length)
        except FileNotFoundError:
            pass

        # Not cached: try streaming just the requested range from the resource.
        try:
            return client.get_bytes_range(index, length)
        except NotImplementedError:
            # The scheme can't stream ranges directly — fall through to a full download.
            pass

    # Last resort: download (and cache) the whole file, then slice out the range.
    whole_file = cached_path(resource, cache_dir=cache_dir, quiet=quiet, progress=progress)
    return _bytes_range_from_file(whole_file, index, length)


def _bytes_range_from_file(path: PathOrStr, index: int, length: int) -> bytes:
with open(path, "rb") as f:
f.seek(index)
return f.read(length)
7 changes: 6 additions & 1 deletion cached_path/schemes/gs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from google.api_core.exceptions import NotFound
from google.auth.exceptions import DefaultCredentialsError
from google.cloud import storage
from google.cloud.storage.blob import Blob
from google.cloud.storage.retry import DEFAULT_RETRY

from ..common import _split_cloud_path
Expand Down Expand Up @@ -42,12 +43,16 @@ def get_resource(self, temp_file: io.BufferedWriter) -> None:
self.load()
self.blob.download_to_file(temp_file, checksum="md5", retry=DEFAULT_RETRY)

def get_bytes_range(self, index: int, length: int) -> bytes:
    """
    Download ``length`` bytes of the blob starting at byte ``index`` using a ranged
    GCS download, without fetching the whole object.
    """
    # Load the blob first, mirroring get_resource() above — presumably this
    # populates/validates self.blob (TODO confirm against load()'s definition).
    self.load()
    # download_as_bytes treats 'end' as an inclusive byte offset, hence the -1.
    return self.blob.download_as_bytes(start=index, end=index + length - 1)

@staticmethod
def split_gcs_path(resource: str) -> Tuple[str, str]:
return _split_cloud_path(resource, "gs")

@staticmethod
def get_gcs_blob(resource: str) -> storage.blob.Blob:
def get_gcs_blob(resource: str) -> Blob:
try:
gcs_resource = storage.Client()
except DefaultCredentialsError:
Expand Down
18 changes: 18 additions & 0 deletions cached_path/schemes/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,24 @@ def get_resource(self, temp_file: io.BufferedWriter) -> None:
if chunk: # filter out keep-alive new chunks
temp_file.write(chunk)

# TODO (epwalsh): There may be a better way to do this, but...
# HTTP range requests don't necessarily match our expectation in this context. For example, the range might
# implicitly include header data, but we usually don't care about that. The server might also
# interpret the range relative to an encoding of the data, not the underlying data itself.
# So to avoid unexpected behavior we resort to the default behavior of downloading the whole file
# and returning the desired bytes range from the cached content.
# def get_bytes_range(self, index: int, length: int) -> bytes:
# with session_with_backoff() as session:
# try:
# response = session.get(
# self.resource, headers={"Range": f"bytes={index}-{index+length-1}"}
# )
# except MaxRetryError as e:
# raise RecoverableServerError(e.reason)
# self.validate_response(response)
# # 'content' might contain the full file if the server doesn't support the "Range" header.
# return response.content[:length]

def validate_response(self, response):
if response.status_code == 404:
raise FileNotFoundError(self.resource)
Expand Down
12 changes: 9 additions & 3 deletions cached_path/schemes/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,18 @@
import io
from typing import Optional, Tuple

import boto3
import botocore
import boto3.session
import botocore.client
import botocore.exceptions

from ..common import _split_cloud_path
from .scheme_client import SchemeClient


class S3Client(SchemeClient):
recoverable_errors = SchemeClient.recoverable_errors + (
botocore.exceptions.EndpointConnectionError,
botocore.exceptions.HTTPClientError,
botocore.exceptions.ConnectionError,
)
scheme = "s3"

Expand Down Expand Up @@ -55,6 +57,10 @@ def get_resource(self, temp_file: io.BufferedWriter) -> None:
self.load()
self.s3_object.download_fileobj(temp_file)

def get_bytes_range(self, index: int, length: int) -> bytes:
    """
    Download ``length`` bytes of the S3 object starting at byte ``index`` using a
    ranged GET, without fetching the whole object.
    """
    # Load the object first, mirroring get_resource() above — presumably this
    # populates/validates self.s3_object (TODO confirm against load()'s definition).
    self.load()
    # HTTP Range headers use inclusive end offsets ("bytes=first-last"), hence the -1.
    return self.s3_object.get(Range=f"bytes={index}-{index + length - 1}")["Body"].read()

@staticmethod
def split_s3_path(url: str) -> Tuple[str, str]:
return _split_cloud_path(url, "s3")
16 changes: 13 additions & 3 deletions cached_path/schemes/scheme_client.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import io
from abc import abstractmethod
from abc import ABC, abstractmethod
from typing import ClassVar, Optional, Tuple, Type, Union

import requests


class SchemeClient:
class SchemeClient(ABC):
"""
A client used for caching remote resources corresponding to URLs with a particular scheme.
Expand Down Expand Up @@ -109,6 +109,16 @@ def get_resource(self, temp_file: io.BufferedWriter) -> None:
``Other errors``
Any other error type can be raised. These errors will be treated non-recoverable
and will be propogated immediately by ``cached_path()``.
and will be propagated immediately by ``cached_path()``.
"""
raise NotImplementedError

def get_bytes_range(self, index: int, length: int) -> bytes:
    """
    Get a sequence of ``length`` bytes from the resource, starting at ``index`` bytes.

    If a scheme provides a direct way of downloading a bytes range, the scheme client
    should implement that. Otherwise the entire file has to be downloaded.

    Raises
    ------
    ``NotImplementedError``
        Always, in this base implementation. Callers should treat this as a signal
        to fall back to downloading the whole file.
    """
    # Arguments are unused in the base implementation; deleted to keep linters quiet.
    del index, length
    raise NotImplementedError
4 changes: 4 additions & 0 deletions docs/source/api/get_bytes_range.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
``get_bytes_range()``
=====================

.. autofunction:: cached_path.get_bytes_range
1 change: 1 addition & 0 deletions docs/source/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ faq
:caption: API Docs
api/cached_path
api/get_bytes_range
api/util
```

Expand Down
9 changes: 8 additions & 1 deletion docs/source/overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Overview
The main functionality of **cached-path** is provided by the function {func}`~cached_path.cached_path()`.

```{testsetup}
>>> from cached_path import cached_path, set_cache_dir
>>> from cached_path import cached_path, set_cache_dir, get_bytes_range
>>> set_cache_dir(cache_dir)
>>>
```
Expand Down Expand Up @@ -40,6 +40,13 @@ and the local path returned from `cached_path()` would point to the newly downlo
There are multiple ways to [change the cache directory](#overriding-the-default-cache-directory).
```

You can also get a range of bytes directly using {func}`~cached_path.get_bytes_range()`. For example:

```python
>>> get_bytes_range("https://raw.githubusercontent.com/allenai/cached_path/main/README.md", 0, 41)
b'# [cached-path](https://cached-path.readt'
```

## Supported URL schemes

In addition to `http` and `https`, {func}`~cached_path.cached_path()` supports several other schemes such as `s3` (AWS S3), `gs` (Google Cloud Storage),
Expand Down
Loading

0 comments on commit f435b48

Please sign in to comment.