Fetch attributes using Run UUIDs instead of sys/id #95

Merged: 2 commits, Nov 5, 2024
46 changes: 21 additions & 25 deletions src/neptune_fetcher/read_only_project.py
@@ -41,9 +41,7 @@
     LeaderboardEntry,
     NextPage,
 )
-from neptune.api.searching_entries import find_attribute
 from neptune.envs import PROJECT_ENV_NAME
-from neptune.exceptions import NeptuneException
 from neptune.internal.backends.api_model import Project
 from neptune.internal.backends.hosted_neptune_backend import HostedNeptuneBackend
 from neptune.internal.backends.nql import (
@@ -494,17 +492,17 @@ def _fetch_df(
         # We use this list to maintain the sorting order as returned by the backend during the
         # initial filtering of runs.
         # This is because the request for field values always sorts the result by (run_id, path).
-        all_run_ids = []
+        all_run_uuids = []
 
         # Workers fetching attributes in parallel
         futures = []
 
         value_count = 0
         with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
-            for run_ids in _batch_run_ids(runs_generator, batch_size=FETCH_RUNS_BATCH_SIZE):
-                all_run_ids.extend(run_ids)
+            for run_uuids in _batch_run_ids(runs_generator, batch_size=FETCH_RUNS_BATCH_SIZE):
+                all_run_uuids.extend(run_uuids)
 
-                if len(all_run_ids) > limit:
+                if len(all_run_uuids) > limit:
                     raise ValueError(
                         f"The number of runs returned exceeds the limit of {limit}. "
                         "Please narrow down your query or provide a smaller 'limit' "
@@ -513,7 +511,7 @@ def _fetch_df(
 
                 # Scatter
                 futures.append(
-                    executor.submit(self._fetch_columns_batch, run_ids, columns=columns, columns_regex=columns_regex)
+                    executor.submit(self._fetch_columns_batch, run_uuids, columns=columns, columns_regex=columns_regex)
                 )
 
             # Gather
@@ -527,39 +525,39 @@ def _fetch_df(
             for run_id, attrs in values.items():
                 acc[run_id].update(attrs)
 
-        df = _to_pandas_df(all_run_ids, acc, columns)
+        df = _to_pandas_df(all_run_uuids, acc, columns)
 
         return df
 
     def _fetch_columns_batch(
-        self, run_ids: List[str], columns: Optional[Iterable[str]] = None, columns_regex: Optional[str] = None
+        self, run_uuids: List[str], columns: Optional[Iterable[str]] = None, columns_regex: Optional[str] = None
     ) -> Tuple[int, Dict[str, Dict[str, Any]]]:
         """
         Called as a worker function concurrently.
 
         Fetch a batch of columns for the given runs. Return a tuple of the number of
-        values fetched, and a dictionary mapping run_id -> (attr_path -> value).
+        values fetched, and a dictionary mapping run UUID -> (attr_path -> value).
         """
 
         acc = collections.defaultdict(dict)
         count = 0
 
         for run_id, attr in self._stream_attributes(
-            run_ids, columns=columns, columns_regex=columns_regex, batch_size=FETCH_COLUMNS_BATCH_SIZE
+            run_uuids, columns=columns, columns_regex=columns_regex, batch_size=FETCH_COLUMNS_BATCH_SIZE
        ):
             acc[run_id][attr.path] = _extract_value(attr)
             count += 1
 
         return count, acc
 
     def _stream_attributes(
-        self, run_ids, *, columns=None, columns_regex=None, limit: Optional[int] = None, batch_size=10_000
+        self, run_uuids, *, columns=None, columns_regex=None, limit: Optional[int] = None, batch_size=10_000
     ) -> Generator[Tuple[str, Field], None, None]:
         """
         Download attributes that match the given criteria, for the given runs. Attributes are downloaded
         in batches of `batch_size` values per HTTP request.
 
-        The returned generator yields tuples of (run_id, attribute) for each attribute returned, until
+        The returned generator yields tuples of (run UUID, attribute) for each attribute returned, until
         there is no more data, or the provided `limit` is reached. Limit is calculated as the number of
         non-null data cells returned.
         """
@@ -574,7 +572,7 @@ def _stream_attributes(
 
             response = self._backend.query_fields_within_project(
                 project_id=self._project,
-                experiment_ids_filter=run_ids,
+                experiment_ids_filter=run_uuids,
                 field_names_filter=columns,
                 field_name_regex=columns_regex,
                 next_page=next_page,
@@ -584,7 +582,7 @@ def _stream_attributes(
             # We're assuming that the backend does not return more entries than requested
             for entry in response.entries:
                 for attr in entry.fields:
-                    yield entry.object_key, attr
+                    yield entry.object_id, attr
 
                 remaining -= len(entry.fields)
                 if remaining <= 0:
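The streaming loop around this hunk pages through `query_fields_within_project` until the backend stops returning data or the cell limit is reached. Since the surrounding loop is elided from the diff, here is a minimal sketch of the general pattern; `fetch_page` and `Page` are hypothetical stand-ins for the backend call and its response, not the library's API:

```python
from dataclasses import dataclass
from typing import Callable, Generator, List, Optional, Tuple


@dataclass
class Page:
    # (run_uuid, [attribute paths]) pairs plus an opaque continuation token.
    entries: List[Tuple[str, List[str]]]
    next_token: Optional[str]


def stream_attributes(
    fetch_page: Callable[[Optional[str]], Page], limit: int = 10_000
) -> Generator[Tuple[str, str], None, None]:
    # Yield (run UUID, attribute) tuples page by page, stopping once
    # `limit` values have been produced or the token runs out.
    remaining, token = limit, None
    while True:
        page = fetch_page(token)
        for run_uuid, attrs in page.entries:
            yield from ((run_uuid, attr) for attr in attrs)
            remaining -= len(attrs)
            if remaining <= 0:
                return
        token = page.next_token
        if token is None:
            return
```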
@@ -624,9 +622,9 @@ def _extract_value(attr: Field):
     return attr.value
 
 
-def _to_pandas_df(order: List, items: Dict[str, Any], ensure_columns=None) -> DataFrame:
+def _to_pandas_df(run_uuids: List[str], items: Dict[str, Any], ensure_columns=None) -> DataFrame:
     """
-    Convert the provided items into a pandas DataFrame, ensuring the order of columns as specified.
+    Convert the provided items into a pandas DataFrame, ensuring the order of rows is the same as run_uuids.
     Any columns passed in `ensure_columns` will be present in the result as NA, even if not returned by the backend.
 
     System and monitoring columns will be sorted to the front.
@@ -640,7 +638,7 @@ def sort_key(field: str) -> Tuple[int, str]:
             return 2, field
         return 1, field
 
-    df = DataFrame(items[x] for x in order)
+    df = DataFrame(items[x] for x in run_uuids)
 
     if ensure_columns:
         for col in ensure_columns:
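The reason `_to_pandas_df` takes `run_uuids` explicitly: the accumulated `items` dict reflects worker completion order, so the frame is rebuilt by iterating the UUID list that preserved the backend's sort. A small self-contained illustration (the UUIDs and values below are made up):

```python
from pandas import DataFrame

# Results gathered from parallel workers, keyed by run UUID;
# dict ordering here reflects worker completion, not query order.
items = {
    "uuid-b": {"sys/id": "RUN-2", "fields/int": 7},
    "uuid-a": {"sys/id": "RUN-1", "fields/int": 3},
}

# The order the backend returned the runs in during filtering.
run_uuids = ["uuid-a", "uuid-b"]

# Building the frame from the ordered list restores that order.
df = DataFrame(items[x] for x in run_uuids)
print(df["sys/id"].tolist())  # ['RUN-1', 'RUN-2']
```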
@@ -730,18 +728,16 @@ def _stream_runs(
     )
 
 
-def _batch_run_ids(runs: Generator, *, batch_size: int) -> Generator[List[str], None, None]:
+def _batch_run_ids(
+    runs: Generator[LeaderboardEntry, None, None], *, batch_size: int
+) -> Generator[List[str], None, None]:
     """
-    Consumes the `runs` generator, yields lists of short ids, of the given max size.
+    Consumes the `runs` generator, yielding lists of Run UUIDs. The length of a single list is limited by `batch_size`.
     """
 
     batch = []
     for run in runs:
-        run_id = find_attribute(entry=run, path="sys/id")
-        if run_id is None:
-            raise NeptuneException("Experiment id missing in server response")
-
-        batch.append(run_id.value)
+        batch.append(run.object_id)
         if len(batch) == batch_size:
             yield batch
             batch = []
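The net effect of this file's changes: runs are keyed by the backend's stable object UUID (`entry.object_id`) end to end, instead of looking up the `sys/id` attribute on each leaderboard entry. A toy walk-through of the new batching with a stand-in entry type; the trailing flush at the end is an assumption about the unchanged remainder of the function, which the diff elides:

```python
from dataclasses import dataclass
from typing import Generator, Iterable, List


@dataclass
class Entry:
    # Stand-in for LeaderboardEntry; only object_id matters here.
    object_id: str


def batch_run_ids(runs: Iterable[Entry], *, batch_size: int) -> Generator[List[str], None, None]:
    # Mirrors _batch_run_ids() after this PR: no sys/id lookup and no
    # missing-attribute error path, since the UUID is always present
    # on the entry itself.
    batch: List[str] = []
    for run in runs:
        batch.append(run.object_id)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:  # assumed trailing flush, not shown in the diff
        yield batch


entries = (Entry(f"uuid-{i}") for i in range(5))
print(list(batch_run_ids(entries, batch_size=2)))
# [['uuid-0', 'uuid-1'], ['uuid-2', 'uuid-3'], ['uuid-4']]
```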
26 changes: 14 additions & 12 deletions tests/integration/conftest.py
@@ -59,11 +59,11 @@ def api_token() -> str:
     return base64.b64encode(json.dumps({"api_address": ""}).encode()).decode()
 
 
-def create_leaderboard_entry(sys_id, custom_run_id, name: Optional[str] = None, columns=None):
+def create_leaderboard_entry(sys_id, run_uuid, custom_run_id, name: Optional[str] = None, columns=None):
     name = name if name is not None else ""
 
     return LeaderboardEntry(
-        object_id=sys_id,
+        object_id=run_uuid,
         fields=list(
             filter(
                 lambda field: columns is None or field.path in columns,
@@ -104,11 +104,13 @@ def search_leaderboard_entries(columns, query, *args, **kwargs):
     complex_query_exp = '((`sys/trashed`:bool = false) AND (`sys/name`:string != "") AND (`fields/int`:int > 5))'
     query_all_exps = '((`sys/trashed`:bool = false) AND (`sys/name`:string != ""))'
 
-    run1 = create_leaderboard_entry("RUN-1", "alternative_tesla", columns=columns)
-    run2 = create_leaderboard_entry("RUN-2", "nostalgic_stallman", columns=columns)
+    run1 = create_leaderboard_entry("RUN-1", "RUN-UUID-1", "alternative_tesla", columns=columns)
+    run2 = create_leaderboard_entry("RUN-2", "RUN-UUID-2", "nostalgic_stallman", columns=columns)
 
-    exp1 = create_leaderboard_entry("EXP-1", "custom_experiment_id", name="powerful-sun-2", columns=columns)
-    exp2 = create_leaderboard_entry("EXP-2", "nostalgic_stallman", name="lazy-moon-2", columns=columns)
+    exp1 = create_leaderboard_entry(
+        "EXP-1", "EXP-UUID-1", "custom_experiment_id", name="powerful-sun-2", columns=columns
+    )
+    exp2 = create_leaderboard_entry("EXP-2", "EXP-UUID-2", "nostalgic_stallman", name="lazy-moon-2", columns=columns)
 
     if str(query) == query_run1 or str(query) == complex_query_run:
         output = [run1]
@@ -198,10 +200,10 @@ def get_float_series_values(*args, **kwargs):
 # )
 
 
-def make_query_fields_entry(sys_id, custom_run_id):
+def make_query_fields_entry(sys_id, run_uuid, custom_run_id):
     return QueryFieldsExperimentResult(
         object_key=sys_id,
-        object_id=custom_run_id,
+        object_id=run_uuid,
         fields=[
             StringField(path="sys/id", value=sys_id),
             StringField(path="sys/custom_run_id", value=custom_run_id),
@@ -212,10 +214,10 @@ def make_query_fields_entry(sys_id, custom_run_id):
 def query_fields_within_project(*args, **kwargs) -> QueryFieldsResult:
     return QueryFieldsResult(
         entries=[
-            make_query_fields_entry("RUN-1", "alternative_tesla"),
-            make_query_fields_entry("RUN-2", "nostalgic_stallman"),
-            make_query_fields_entry("EXP-1", "custom_experiment_id"),
-            make_query_fields_entry("EXP-2", "nostalgic_stallman"),
+            make_query_fields_entry("RUN-1", "RUN-UUID-1", "alternative_tesla"),
+            make_query_fields_entry("RUN-2", "RUN-UUID-2", "nostalgic_stallman"),
+            make_query_fields_entry("EXP-1", "EXP-UUID-1", "custom_experiment_id"),
+            make_query_fields_entry("EXP-2", "EXP-UUID-2", "nostalgic_stallman"),
         ],
         next_page=NextPage(next_page_token=None, limit=None),
     )
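For reviewers of the fixtures: each mocked entry now carries both a `sys/id` field and a distinct `object_id` UUID, which lets tests verify that grouping happens by UUID while `sys/id` remains an ordinary column. A quick sanity check one could run against these mocks, mirroring the accumulation in `_fetch_columns_batch`:

```python
import collections

# Assumes the query_fields_within_project() mock defined above.
result = query_fields_within_project()

acc = collections.defaultdict(dict)
for entry in result.entries:
    for field in entry.fields:
        acc[entry.object_id][field.path] = field.value

assert set(acc) == {"RUN-UUID-1", "RUN-UUID-2", "EXP-UUID-1", "EXP-UUID-2"}
assert acc["RUN-UUID-1"]["sys/id"] == "RUN-1"
assert acc["EXP-UUID-2"]["sys/custom_run_id"] == "nostalgic_stallman"
```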