Skip to content

Commit

Permalink
Merge pull request #233 from allenai/shanea/fix-r2-profile
Browse files Browse the repository at this point in the history
Fix R2 profile and errors
  • Loading branch information
2015aroras authored Jun 10, 2024
2 parents 7008167 + 45ead9f commit a45c796
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 9 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed

- Updated dependencies

- Fix authentication with AWS profile for R2
- Make R2 throw FileNotFoundError instead of botocore.client.ClientError when object does not exist.

## [v1.6.0](https://github.com/allenai/cached_path/releases/tag/v1.6.0) - 2024-02-22

Expand Down
2 changes: 1 addition & 1 deletion cached_path/schemes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from .gs import GsClient
from .hf import hf_get_from_cache
from .http import HttpClient
from .s3 import S3Client
from .r2 import R2Client
from .s3 import S3Client
from .scheme_client import SchemeClient

__all__ = ["GsClient", "HttpClient", "S3Client", "R2Client", "SchemeClient", "hf_get_from_cache"]
Expand Down
23 changes: 16 additions & 7 deletions cached_path/schemes/r2.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
import os
from typing import Optional

import boto3.dynamodb
import boto3.session
from botocore.config import Config
import botocore.exceptions
from botocore.config import Config

from .scheme_client import SchemeClient
from ..common import _split_cloud_path
from .scheme_client import SchemeClient


class R2Client(SchemeClient):
Expand All @@ -36,30 +37,36 @@ def __init__(self, resource: str) -> None:
access_key_id = os.environ.get("R2_ACCESS_KEY_ID")
secret_access_key = os.environ.get("R2_SECRET_ACCESS_KEY")
if access_key_id is not None and secret_access_key is not None:
client_kwargs = {
session_kwargs = {
"aws_access_key_id": access_key_id,
"aws_secret_access_key": secret_access_key,
}
elif profile_name is not None:
client_kwargs = {"profile_name": profile_name}
session_kwargs = {"profile_name": profile_name}
else:
raise ValueError(
"To authenticate for R2, you either have to set the 'R2_PROFILE' env var and set up this profile, "
"or set R2_ACCESS_KEY_ID and R2_SECRET_ACCESS_KEY."
)

self.s3 = boto3.client(
s3_session = boto3.session.Session(**session_kwargs)

self.s3 = s3_session.client(
service_name="s3",
endpoint_url=endpoint_url,
region_name="auto",
config=Config(retries={"max_attempts": 10, "mode": "standard"}),
**client_kwargs,
)
self.object_info = None

def _ensure_object_info(self):
if self.object_info is None:
self.object_info = self.s3.head_object(Bucket=self.bucket_name, Key=self.path)
try:
self.object_info = self.s3.head_object(Bucket=self.bucket_name, Key=self.path)
except botocore.exceptions.ClientError as e:
if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
raise FileNotFoundError(f"File {self.resource} not found") from e
raise

def get_etag(self) -> Optional[str]:
self._ensure_object_info()
Expand All @@ -72,9 +79,11 @@ def get_size(self) -> Optional[int]:
return self.object_info.get("ContentLength")

def get_resource(self, temp_file: io.BufferedWriter) -> None:
self._ensure_object_info()
self.s3.download_fileobj(Fileobj=temp_file, Bucket=self.bucket_name, Key=self.path)

def get_bytes_range(self, index: int, length: int) -> bytes:
self._ensure_object_info()
response = self.s3.get_object(
Bucket=self.bucket_name, Key=self.path, Range=f"bytes={index}-{index+length-1}"
)
Expand Down

0 comments on commit a45c796

Please sign in to comment.