Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🎨 Separate labels fetching from printing #2225

Merged
merged 6 commits into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions lamindb/core/_django.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def get_artifact_with_related(
"""Fetch an artifact with its related data."""
from lamindb._can_curate import get_name_field

from ._label_manager import LABELS_EXCLUDE_SET
from ._label_manager import EXCLUDE_LABELS

model = artifact.__class__
schema_modules = get_schemas_modules(artifact._state.db)
Expand All @@ -54,7 +54,7 @@ def get_artifact_with_related(
for v in dict_related_model_to_related_name(
model, instance=artifact._state.db
).values()
if not v.startswith("_") and v not in LABELS_EXCLUDE_SET
if not v.startswith("_") and v not in EXCLUDE_LABELS
]
)
link_tables = (
Expand Down
6 changes: 2 additions & 4 deletions lamindb/core/_feature_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from lamindb.core.storage import LocalPathClasses

from ._django import get_artifact_with_related
from ._label_manager import get_labels_as_dict
from ._label_manager import _get_labels
from ._settings import settings
from .schema import (
dict_related_model_to_related_name,
Expand Down Expand Up @@ -199,9 +199,7 @@ def _get_categoricals(
return {}

result = defaultdict(set)
for _, (_, links) in get_labels_as_dict(
self, links=True, instance=self._state.db
).items():
for _, links in _get_labels(self, links=True, instance=self._state.db).items():
for link in links:
if hasattr(link, "feature_id") and link.feature_id is not None:
feature = Feature.objects.using(self._state.db).get(id=link.feature_id)
Expand Down
200 changes: 104 additions & 96 deletions lamindb/core/_label_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from collections import defaultdict
from typing import TYPE_CHECKING

import numpy as np
from django.db import connections
from lamin_utils import colors, logger
from lnschema_core.models import CanCurate, Feature
Expand All @@ -26,62 +25,75 @@

from lamindb._query_set import QuerySet

LABELS_EXCLUDE_SET = {"feature_sets"}


def get_labels_as_dict(
self: Artifact | Collection, links: bool = False, instance: str | None = None
) -> dict:
labels = {} # type: ignore
if self.id is None:
return labels
for related_model_name, related_name in dict_related_model_to_related_name(
self.__class__, links=links, instance=instance
).items():
if related_name not in LABELS_EXCLUDE_SET and not related_name.startswith("_"):
labels[related_name] = (
related_model_name,
getattr(self, related_name).all(),
)
EXCLUDE_LABELS = {"feature_sets"}


def _get_labels(
obj, links: bool = False, instance: str | None = None
) -> dict[str, QuerySet]:
"""Get all labels associated with an object as a dictionary.

This is a generic approach that uses django orm.
"""
if obj.id is None:
return {}

labels = {}
related_models = dict_related_model_to_related_name(
obj.__class__, links=links, instance=instance
)

for _, related_name in related_models.items():
if related_name not in EXCLUDE_LABELS and not related_name.startswith("_"):
labels[related_name] = getattr(obj, related_name).all()
return labels


def _print_labels_postgres(
self: Artifact | Collection, m2m_data: dict | None = None, print_types: bool = False
) -> str:
labels_msg = ""
if not m2m_data:
def _get_labels_postgres(
self: Artifact | Collection, m2m_data: dict | None = None
) -> dict[str, dict[int, str]]:
"""Get all labels associated with an artifact or collection as a dictionary.

This is a postgres-specific approach that uses django Subquery.
"""
if m2m_data is None:
artifact_meta = get_artifact_with_related(self, include_m2m=True)
m2m_data = artifact_meta.get("related_data", {}).get("m2m", {})
if m2m_data:
for related_name, labels in m2m_data.items():
if not labels or related_name == "feature_sets":
continue
related_model = get_related_model(self, related_name)
print_values = _print_values(labels.values(), n=10)
type_str = f": {related_model}" if print_types else ""
labels_msg += f" .{related_name}{type_str} = {print_values}\n"
return labels_msg
return m2m_data


def print_labels(
self: Artifact | Collection,
m2m_data: dict | None = None,
print_types: bool = False,
):
) -> str:
"""Print labels associated with an artifact or collection.

Args:
m2m_data: A dictionary of m2m data. If not provided, it will be fetched.
print_types: Whether to print the types of the related models.

Returns:
A string representation of the labels associated with the artifact or collection.
"""
if not self._state.adding and connections[self._state.db].vendor == "postgresql":
labels_msg = _print_labels_postgres(self, m2m_data, print_types)
else:
labels_msg = ""
for related_name, (related_model, labels) in get_labels_as_dict(
self, instance=self._state.db
).items():
m2m_data = _get_labels_postgres(self, m2m_data)
if not m2m_data:
m2m_data = _get_labels(self, instance=self._state.db)

labels_msg = ""
for related_name, labels in m2m_data.items():
if not labels or related_name == "feature_sets":
continue
if isinstance(labels, dict): # postgres, labels are a dict[id, name]
print_values = _print_values(labels.values(), n=10)
else: # labels are a QuerySet
field = get_name_field(labels)
labels_list = list(labels.values_list(field, flat=True))
if len(labels_list) > 0:
print_values = _print_values(labels_list, n=10)
type_str = f": {related_model}" if print_types else ""
labels_msg += f" .{related_name}{type_str} = {print_values}\n"
print_values = _print_values(labels.values_list(field, flat=True), n=10)
if print_values:
related_model = get_related_model(self, related_name)
type_str = f": {related_model.__name__}" if print_types else ""
labels_msg += f" .{related_name}{type_str} = {print_values}\n"

msg = ""
if labels_msg:
Expand All @@ -90,50 +102,48 @@
return msg


# Alex: is this a label transfer function?
def validate_labels(labels: QuerySet | list | dict):
def validate_labels_registry(
labels: QuerySet | list | dict,
) -> tuple[list[str], list[str]]:
if len(labels) == 0:
return [], []
registry = labels[0].__class__
field = REGISTRY_UNIQUE_FIELD.get(registry.__name__.lower(), "uid")
if hasattr(registry, "_ontology_id_field"):
field = registry._ontology_id_field
# if the field value is None, use uid field
label_uids = np.array(
[getattr(label, field) for label in labels if label is not None]
)
# save labels from ontology_ids
if hasattr(registry, "_ontology_id_field") and len(label_uids) > 0:
try:
labels_records = registry.from_values(label_uids, field=field)
save([r for r in labels_records if r._state.adding])
except Exception: # noqa S110
pass
field = "uid"
label_uids = np.array(
[getattr(label, field) for label in labels if label is not None]
)
if issubclass(registry, CanCurate):
validated = registry.validate(label_uids, field=field, mute=True)
validated_uids = label_uids[validated]
validated_labels = registry.filter(
**{f"{field}__in": validated_uids}
).list()
new_labels = [labels[int(i)] for i in np.argwhere(~validated).flatten()]
else:
validated_labels = []
new_labels = list(labels)
return validated_labels, new_labels

def _save_validated_records(
labels: QuerySet | list | dict,
) -> list[str]:
if not labels:
return []
registry = labels[0].__class__
field = (
REGISTRY_UNIQUE_FIELD.get(registry.__name__.lower(), "uid")
if not hasattr(registry, "_ontology_id_field")
else registry._ontology_id_field
)
# if the field value is None, use uid field
label_uids = [getattr(label, field) for label in labels if label is not None]
# save labels from ontology_ids
if hasattr(registry, "_ontology_id_field") and label_uids:
try:
records = registry.from_values(label_uids, field=field)
save([r for r in records if r._state.adding])
except Exception: # noqa: S110
pass

Check warning on line 124 in lamindb/core/_label_manager.py

View check run for this annotation

Codecov / codecov/patch

lamindb/core/_label_manager.py#L123-L124

Added lines #L123 - L124 were not covered by tests
field = "uid"
label_uids = [label.uid for label in labels if label is not None]

if issubclass(registry, CanCurate):
validated = registry.validate(label_uids, field=field, mute=True)
new_labels = [
label for label, is_valid in zip(labels, validated) if not is_valid
]
return new_labels
return list(labels)

Check warning on line 134 in lamindb/core/_label_manager.py

View check run for this annotation

Codecov / codecov/patch

lamindb/core/_label_manager.py#L134

Added line #L134 was not covered by tests


def save_validated_records(
labels: QuerySet | list | dict,
) -> list[str] | dict[str, list[str]]:
"""Save validated labels from public based on ontology_id_fields."""
if isinstance(labels, dict):
result = {}
for registry, labels_registry in labels.items():
result[registry] = validate_labels_registry(labels_registry)
else:
return validate_labels_registry(labels)
return {

Check warning on line 142 in lamindb/core/_label_manager.py

View check run for this annotation

Codecov / codecov/patch

lamindb/core/_label_manager.py#L142

Added line #L142 was not covered by tests
registry: _save_validated_records(registry_labels)
for registry, registry_labels in labels.items()
}
return _save_validated_records(labels)


class LabelManager:
Expand All @@ -144,7 +154,7 @@
with features.
"""

def __init__(self, host: Artifact | Collection):
def __init__(self, host: Artifact | Collection) -> None:
self._host = host

def __repr__(self) -> str:
Expand Down Expand Up @@ -201,17 +211,15 @@
if transfer_logs is None:
transfer_logs = {"mapped": [], "transferred": [], "run": None}
using_key = settings._using_key
for related_name, (_, labels) in get_labels_as_dict(
data, instance=data._state.db
).items():
for related_name, labels in _get_labels(data, instance=data._state.db).items():
labels = labels.all()
if not labels.exists():
continue
# look for features
data_name_lower = data.__class__.__name__.lower()
labels_by_features = defaultdict(list)
features = set()
_, new_labels = validate_labels(labels)
new_labels = save_validated_records(labels)
if len(new_labels) > 0:
transfer_fk_to_default_db_bulk(
new_labels, using_key, transfer_logs=transfer_logs
Expand Down Expand Up @@ -241,7 +249,7 @@
label = label_returned
labels_by_features[key].append(label)
# treat features
_, new_features = validate_labels(list(features))
new_features = save_validated_records(list(features))
if len(new_features) > 0:
transfer_fk_to_default_db_bulk(
new_features, using_key, transfer_logs=transfer_logs
Expand All @@ -255,16 +263,16 @@
)
save(new_features)
if hasattr(self._host, related_name):
for feature_name, labels in labels_by_features.items():
for feature_name, feature_labels in labels_by_features.items():
if feature_name is not None:
feature_id = Feature.get(name=feature_name).id
else:
feature_id = None
getattr(self._host, related_name).add(
*labels, through_defaults={"feature_id": feature_id}
*feature_labels, through_defaults={"feature_id": feature_id}
)

def make_external(self, label: Record):
def make_external(self, label: Record) -> None:
"""Make a label external, aka dissociate label from internal features.

Args:
Expand Down
2 changes: 1 addition & 1 deletion tests/core/test_feature_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ def test_labels_add(adata):
adata2 = adata.copy()
adata2.uns["mutated"] = True
artifact2 = ln.Artifact(adata2, description="My new artifact").save()
from lamindb.core._label_manager import get_labels_as_dict
from lamindb.core._label_manager import _get_labels

artifact2.labels.add_from(artifact)
experiments = artifact2.labels.get(experiment)
Expand Down