Skip to content

Commit

Permalink
Merge pull request #837 from CitrineInformatics/feature/file-ingest
Browse files Browse the repository at this point in the history
Add files ingest method
  • Loading branch information
kroenlein authored Mar 28, 2023
2 parents e6fdb36 + ce02140 commit 19b7c03
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 60 deletions.
2 changes: 1 addition & 1 deletion src/citrine/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '2.9.0'
__version__ = '2.10.0'
4 changes: 3 additions & 1 deletion src/citrine/_rest/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ def _put_resource_ref(self, subpath: str, uid: Union[UUID, str]):
ref = ResourceRef(uid)
return self.session.put_resource(url, ref.dump(), version=self._api_version)

def _get_path(self, uid: Optional[Union[UUID, str]] = None,
def _get_path(self,
uid: Optional[Union[UUID, str]] = None,
*,
ignore_dataset: Optional[bool] = False) -> str:
"""Construct a url from __base_path__ and, optionally, id."""
subpath = format_escaped_url('/{}', uid) if uid else ''
Expand Down
82 changes: 61 additions & 21 deletions src/citrine/resources/file_link.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pathlib import Path
from enum import Enum
from logging import getLogger
from typing import Optional, Tuple, Union, List, Dict
from typing import Optional, Tuple, Union, List, Dict, Iterable
from urllib.parse import urlparse, quote
from uuid import UUID

Expand Down Expand Up @@ -182,8 +182,8 @@ def __init__(self, project_id: UUID, dataset_id: UUID, session: Session):

def _get_path(self,
uid: Optional[Union[UUID, str]] = None,
ignore_dataset: Optional[bool] = False,
*,
ignore_dataset: Optional[bool] = False,
version: Union[str, UUID] = None,
action: str = None) -> str:
"""Build the path for taking an action with a particular file version."""
Expand Down Expand Up @@ -248,7 +248,7 @@ def get(self,
*,
version: Optional[Union[UUID, str, int]] = None) -> FileLink:
"""
Get an element of the collection by its id.
Retrieve an on-platform FileLink from its filename or file uuid.
Parameters
----------
Expand Down Expand Up @@ -359,7 +359,7 @@ def _make_upload_request(self, file_path: Path, dest_name: str):
aws_session_token, bucket, object_key, & upload_id.
"""
path = self._get_path() + "/uploads"
path = self._get_path(action="uploads")
mime_type = self._mime_type(file_path)
file_size = file_path.stat().st_size
assert isinstance(file_size, int)
Expand Down Expand Up @@ -423,7 +423,7 @@ def _search_by_file_name(self,
All the data needed for a file.
"""
path = self._get_path() + "/search"
path = self._get_path(action="search")

search_json = {
'fileSearchFilter':
Expand Down Expand Up @@ -456,7 +456,7 @@ def _search_by_file_version_id(self,
All the data needed for a file.
"""
path = self._get_path() + "/search"
path = self._get_path(action="search")

search_json = {
'fileSearchFilter': {
Expand Down Expand Up @@ -495,7 +495,7 @@ def _search_by_dataset_file_id(self,
All the data needed for a file.
"""
path = self._get_path() + "/search"
path = self._get_path(action="search")

search_json = {
'fileSearchFilter': {
Expand Down Expand Up @@ -644,14 +644,12 @@ def read(self, *, file_link: Union[str, UUID, FileLink]):

if self._is_external_url(file_link.url): # Pull it from where ever it lives
final_url = file_link.url
elif self._validate_local_url(file_link.url):
else:
# The "/content-link" route returns a pre-signed url to download the file.
content_link = self._get_path_from_file_link(file_link, action='content-link')
content_link_response = self.session.get_resource(content_link)
pre_signed_url = content_link_response['pre_signed_read_link']
final_url = rewrite_s3_links_locally(pre_signed_url, self.session.s3_endpoint_url)
else: # Unrecognized
raise ValueError(f"URL was malformed for a local file resource ({file_link.url}).")

download_response = requests.get(final_url)
return download_response.content
Expand Down Expand Up @@ -690,10 +688,10 @@ def process(self, *, file_link: Union[FileLink, str, UUID],
A JobSubmissionResponse which can be used to poll for the result.
"""
file_link = self._resolve_file_link(file_link)
if not self._validate_local_url(file_link.url):
if self._is_external_url(file_link.url):
raise ValueError(f"Only on-platform resources can be processed. "
f"Passed URL {file_link.url}.")
file_link = self._resolve_file_link(file_link)

params = {"processing_type": processing_type.value}
response = self.session.put_resource(
Expand Down Expand Up @@ -797,6 +795,38 @@ def file_processing_result(self, *,

return results

def ingest(self, files: Iterable[FileLink]):
    """
    [ALPHA] Ingest a set of CSVs and/or Excel Workbooks formatted per the gemd-ingest protocol.

    Parameters
    ----------
    files: List[FileLink]
        A list of files, already on platform, from which GEMD objects should be built

    Returns
    -------
    The platform response to the request to build GEMD objects from the ingestion.

    Raises
    ------
    ValueError
        If any of the passed files resolve to an off-platform (external) URL.

    """
    # Resolve each argument to a full on-platform FileLink where possible.
    targets = [self._resolve_file_link(f) for f in files]
    # Collect offending URLs in a single pass (previously scanned twice:
    # once in any(), once in an identical comprehension).
    externals = [f.url for f in targets if self._is_external_url(f.url)]
    if externals:
        raise ValueError(f"All files must be on-platform to load them. "
                         f"The following are not: {externals}")

    file_infos = [
        {"dataset_file_id": str(f.uid),
         "file_version_uuid": str(f.version)
         }
        for f in targets]
    req = {
        "project_id": str(self.project_id),
        "dataset_id": str(self.dataset_id),
        "files": file_infos
    }
    base_url = format_escaped_url("/projects/{}/ingestions", self.project_id)
    # Two-step flow: create the ingestion, then trigger GEMD object generation for it.
    create_ingestion_resp = self.session.post_resource(path=base_url, json=req)
    ingestion_id = create_ingestion_resp["ingestion_id"]
    job_url = base_url + format_escaped_url("/{}/gemd-objects", ingestion_id)
    return self.session.post_resource(path=job_url, json={})

def delete(self, file_link: FileLink):
"""
Delete the file associated with a given FileLink from the database.
Expand All @@ -817,8 +847,25 @@ def delete(self, file_link: FileLink):

def _resolve_file_link(self, identifier: Union[str, UUID, FileLink]) -> FileLink:
"""Generate the FileLink object referenced by the passed argument."""
if isinstance(identifier, FileLink): # Passthrough for convenience
return identifier
if isinstance(identifier, GEMDFileLink):
if isinstance(identifier, FileLink) and identifier.uid is not None:
# Passthrough since it's as full as it can get
return identifier
if self._is_external_url(identifier.url):
# Up-convert type with existing info
return FileLink(filename=identifier.filename, url=identifier.url)
# Resolve on-platform uid and possibly up-convert
file_id, version_id = self._get_ids_from_url(identifier.url)
if file_id is None:
raise ValueError(f"URL was malformed for local resources; "
f"passed URL {identifier.url}")
platform_link = self.get(uid=file_id, version=version_id)
if platform_link.filename != identifier.filename:
raise ValueError(
f"Name mismatch between link ({identifier.filename}) "
f"and platform ({platform_link.filename})"
)
return platform_link
elif isinstance(identifier, str) and self._is_external_url(identifier):
# Assume it's an absolute URL
filename = urlparse(identifier).path.split('/')[-1]
Expand Down Expand Up @@ -848,10 +895,3 @@ def _is_external_url(self, url: str):
return False

return urlparse(self._get_path()).netloc != parsed.netloc

def _validate_local_url(self, url):
    """Return True if *url* is a well-formed link to an on-platform file resource."""
    # Off-platform (external) URLs can never be valid local resources.
    if self._is_external_url(url):
        return False
    # A missing version id implies the file id failed to parse as well,
    # so checking the version component alone suffices.
    _, version_id = self._get_ids_from_url(url)
    return version_id is not None
83 changes: 46 additions & 37 deletions tests/resources/test_file_link.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,8 +352,10 @@ def test_file_download(collection: FileCollection, session, tmpdir):
"""
# Given
filename = 'diagram.pdf'
url = f"projects/{collection.project_id}/datasets/{collection.dataset_id}/files/{uuid4()}/versions/{uuid4()}"
file = FileLink.build(FileLinkDataFactory(url=url, filename=filename))
file_uid = str(uuid4())
version_uid = str(uuid4())
url = f"projects/{collection.project_id}/datasets/{collection.dataset_id}/files/{file_uid}/versions/{version_uid}"
file = FileLink.build(FileLinkDataFactory(url=url, filename=filename, id=file_uid, version=version_uid))
pre_signed_url = "http://files.citrine.io/secret-codes/jiifema987pjfsda" # arbitrary
session.set_response({
'pre_signed_read_link': pre_signed_url,
Expand Down Expand Up @@ -403,8 +405,10 @@ def test_read(collection: FileCollection, session):
"""
# Given
filename = 'diagram.pdf'
url = f"projects/{collection.project_id}/datasets/{collection.dataset_id}/files/{uuid4()}/versions/{uuid4()}"
file = FileLink.build(FileLinkDataFactory(url=url, filename=filename))
file_uid = str(uuid4())
version_uid = str(uuid4())
url = f"projects/{collection.project_id}/datasets/{collection.dataset_id}/files/{file_uid}/versions/{version_uid}"
file = FileLink.build(FileLinkDataFactory(url=url, filename=filename, id=file_uid, version=version_uid))
pre_signed_url = "http://files.citrine.io/secret-codes/jiifema987pjfsda" # arbitrary
session.set_response({
'pre_signed_read_link': pre_signed_url,
Expand Down Expand Up @@ -499,8 +503,8 @@ def test_process_file(collection: FileCollection, session):
"""Test processing an existing file."""

file_id, version_id = str(uuid4()), str(uuid4())
full_url = 'www.citrine.io/develop/files/{}/versions/{}'.format(file_id, version_id)
file_link = collection.build(FileLinkDataFactory(url=full_url))
full_url = collection._get_path(uid=file_id, version=version_id)
file_link = collection.build(FileLinkDataFactory(url=full_url, id=file_id, version=version_id))

job_id_resp = {
'job_id': str(uuid4())
Expand Down Expand Up @@ -548,8 +552,8 @@ def test_process_file_no_waiting(collection: FileCollection, session):
"""Test processing an existing file without waiting on the result."""

file_id, version_id = str(uuid4()), str(uuid4())
full_url = 'www.citrine.io/develop/files/{}/versions/{}'.format(file_id, version_id)
file_link = collection.build(FileLinkDataFactory(url=full_url))
full_url = collection._get_path(uid=file_id, version=version_id)
file_link = collection.build(FileLinkDataFactory(url=full_url, id=file_id, version=version_id))

job_id_resp = {
'job_id': str(uuid4())
Expand All @@ -566,11 +570,9 @@ def test_process_file_no_waiting(collection: FileCollection, session):

def test_process_file_exceptions(collection: FileCollection, session):
"""Test processing an existing file without waiting on the result."""

file_id, version_id = str(uuid4()), str(uuid4())
full_url = 'https://www.citrine.io/develop/files/{}/versions/{}'.format(file_id, version_id)
full_url = f'http://www.files.com/file.path'
file_link = collection.build(FileLinkDataFactory(url=full_url))

collection._get_path()
# First does a PUT on the /processed endpoint
# then does a GET on the job executions endpoint
with pytest.raises(ValueError, match="on-platform resources"):
Expand All @@ -580,6 +582,22 @@ def test_process_file_exceptions(collection: FileCollection, session):
wait_for_response=False)


def test_ingest(collection: FileCollection, session):
    """Test the on-platform ingest route."""
    # Two well-formed on-platform files, plus one external file that must be rejected.
    on_platform = [
        collection.build({"filename": name, "id": str(uuid4()), "version": str(uuid4())})
        for name in ("good.csv", "also.csv")
    ]
    off_platform = FileLink(filename="bad.csv", url="http://files.com/input.csv")

    ingest_resp = {'ingestion_id': str(uuid4())}
    # Two responses: one for ingestion creation, one for the gemd-objects build.
    session.set_responses(ingest_resp, ingest_resp)
    collection.ingest(on_platform)

    # Including an off-platform file raises, naming the offending URL.
    with pytest.raises(ValueError, match=off_platform.url):
        collection.ingest([on_platform[0], off_platform])


def test_resolve_file_link(collection: FileCollection, session):
# The actual response contains more fields, but these are the only ones we use.
raw_files = [
Expand Down Expand Up @@ -626,17 +644,27 @@ def test_resolve_file_link(collection: FileCollection, session):
session.set_response({
'files': [raw_files[1]]
})
assert collection._resolve_file_link(UUID(raw_files[1]['id'])) == file1, "UUID didn't resolve"

unresolved = FileLink(filename=file1.filename, url=file1.url)
assert collection._resolve_file_link(unresolved) == file1, "FileLink didn't resolve"
assert session.num_calls == 1

unresolved.filename = "Wrong.file"
with pytest.raises(ValueError):
collection._resolve_file_link(unresolved)
assert session.num_calls == 2

assert collection._resolve_file_link(UUID(raw_files[1]['id'])) == file1, "UUID didn't resolve"
assert session.num_calls == 3

session.set_response({
'files': [raw_files[1]]
})
assert collection._resolve_file_link(raw_files[1]['id']) == file1, "String UUID didn't resolve"
assert session.num_calls == 2
assert session.num_calls == 4

assert collection._resolve_file_link(raw_files[1]['version']) == file1, "Version UUID didn't resolve"
assert session.num_calls == 3
assert session.num_calls == 5

abs_link = "https://wwww.website.web/web.pdf"
assert collection._resolve_file_link(abs_link).filename == "web.pdf"
Expand All @@ -646,36 +674,17 @@ def test_resolve_file_link(collection: FileCollection, session):
'files': [raw_files[1]]
})
assert collection._resolve_file_link(file1.url) == file1, "Relative path didn't resolve"
assert session.num_calls == 4
assert session.num_calls == 6

session.set_response({
'files': [raw_files[1]]
})
assert collection._resolve_file_link(file1.filename) == file1, "Filename didn't resolve"
assert session.num_calls == 5
assert session.num_calls == 7

with pytest.raises(TypeError):
collection._resolve_file_link(12345)
assert session.num_calls == 5


def test_validate_filelink_url(collection: FileCollection):
good = [
f"projects/{uuid4()}/datasets/{uuid4()}/files/{uuid4()}/versions/{uuid4()}",
f"/files/{uuid4()}/versions/{uuid4()}"
]
bad = [
f"/projects/{uuid4()}/datasets/{uuid4()}/files/{uuid4()}/versions/{uuid4()}/action",
f"/projects/{uuid4()}/datasets/{uuid4()}/{uuid4()}/versions/{uuid4()}",
f"projects/{uuid4()}/datasets/{uuid4()}/files/{uuid4()}/versions/{uuid4()}?query=param",
f"projects/{uuid4()}/datasets/{uuid4()}/files/{uuid4()}/versions/{uuid4()}?#fragment",
"http://customer.com/data-lake/files/123/versions/456",
"/files/uuid4/versions/uuid4",
]
for x in good:
assert collection._validate_local_url(x)
for x in bad:
assert not collection._validate_local_url(x)
assert session.num_calls == 7


def test_get_ids_from_url(collection: FileCollection):
Expand Down

0 comments on commit 19b7c03

Please sign in to comment.