Skip to content

Commit

Permalink
🎨 Separate labels fetching from printing (#2225)
Browse files Browse the repository at this point in the history
  • Loading branch information
sunnyosun authored Nov 28, 2024
1 parent 7d2ff6c commit ec37354
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 103 deletions.
4 changes: 2 additions & 2 deletions lamindb/core/_django.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def get_artifact_with_related(
"""Fetch an artifact with its related data."""
from lamindb._can_curate import get_name_field

from ._label_manager import LABELS_EXCLUDE_SET
from ._label_manager import EXCLUDE_LABELS

model = artifact.__class__
schema_modules = get_schemas_modules(artifact._state.db)
Expand All @@ -54,7 +54,7 @@ def get_artifact_with_related(
for v in dict_related_model_to_related_name(
model, instance=artifact._state.db
).values()
if not v.startswith("_") and v not in LABELS_EXCLUDE_SET
if not v.startswith("_") and v not in EXCLUDE_LABELS
]
)
link_tables = (
Expand Down
6 changes: 2 additions & 4 deletions lamindb/core/_feature_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from lamindb.core.storage import LocalPathClasses

from ._django import get_artifact_with_related
from ._label_manager import get_labels_as_dict
from ._label_manager import _get_labels
from ._settings import settings
from .schema import (
dict_related_model_to_related_name,
Expand Down Expand Up @@ -199,9 +199,7 @@ def _get_categoricals(
return {}

result = defaultdict(set)
for _, (_, links) in get_labels_as_dict(
self, links=True, instance=self._state.db
).items():
for _, links in _get_labels(self, links=True, instance=self._state.db).items():
for link in links:
if hasattr(link, "feature_id") and link.feature_id is not None:
feature = Feature.objects.using(self._state.db).get(id=link.feature_id)
Expand Down
200 changes: 104 additions & 96 deletions lamindb/core/_label_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from collections import defaultdict
from typing import TYPE_CHECKING

import numpy as np
from django.db import connections
from lamin_utils import colors, logger
from lnschema_core.models import CanCurate, Feature
Expand All @@ -26,62 +25,75 @@

from lamindb._query_set import QuerySet

LABELS_EXCLUDE_SET = {"feature_sets"}


def get_labels_as_dict(
self: Artifact | Collection, links: bool = False, instance: str | None = None
) -> dict:
labels = {} # type: ignore
if self.id is None:
return labels
for related_model_name, related_name in dict_related_model_to_related_name(
self.__class__, links=links, instance=instance
).items():
if related_name not in LABELS_EXCLUDE_SET and not related_name.startswith("_"):
labels[related_name] = (
related_model_name,
getattr(self, related_name).all(),
)
EXCLUDE_LABELS = {"feature_sets"}


def _get_labels(
    obj, links: bool = False, instance: str | None = None
) -> dict[str, QuerySet]:
    """Collect the label relations of an object, keyed by related name.

    This is the generic path that goes through the Django ORM: every value is
    the result of calling ``.all()`` on the corresponding related manager.

    Args:
        obj: The object whose related labels are looked up.
        links: Forwarded to ``dict_related_model_to_related_name``.
        instance: Forwarded to ``dict_related_model_to_related_name``.

    Returns:
        A mapping from related name to ``getattr(obj, related_name).all()``.
        Related names listed in ``EXCLUDE_LABELS`` or starting with an
        underscore are left out. Empty when ``obj.id`` is ``None``.
    """
    # No id means there is nothing to look up.
    if obj.id is None:
        return {}

    name_by_model = dict_related_model_to_related_name(
        obj.__class__, links=links, instance=instance
    )
    # Only the related names are needed; the model-name keys are ignored.
    return {
        name: getattr(obj, name).all()
        for name in name_by_model.values()
        if not (name in EXCLUDE_LABELS or name.startswith("_"))
    }


def _print_labels_postgres(
self: Artifact | Collection, m2m_data: dict | None = None, print_types: bool = False
) -> str:
labels_msg = ""
if not m2m_data:
def _get_labels_postgres(
self: Artifact | Collection, m2m_data: dict | None = None
) -> dict[str, dict[int, str]]:
"""Get all labels associated with an artifact or collection as a dictionary.
This is a postgres-specific approach that uses django Subquery.
"""
if m2m_data is None:
artifact_meta = get_artifact_with_related(self, include_m2m=True)
m2m_data = artifact_meta.get("related_data", {}).get("m2m", {})
if m2m_data:
for related_name, labels in m2m_data.items():
if not labels or related_name == "feature_sets":
continue
related_model = get_related_model(self, related_name)
print_values = _print_values(labels.values(), n=10)
type_str = f": {related_model}" if print_types else ""
labels_msg += f" .{related_name}{type_str} = {print_values}\n"
return labels_msg
return m2m_data


def print_labels(
self: Artifact | Collection,
m2m_data: dict | None = None,
print_types: bool = False,
):
) -> str:
"""Print labels associated with an artifact or collection.
Args:
m2m_data: A dictionary of m2m data. If not provided, it will be fetched.
print_types: Whether to print the types of the related models.
Returns:
A string representation of the labels associated with the artifact or collection.
"""
if not self._state.adding and connections[self._state.db].vendor == "postgresql":
labels_msg = _print_labels_postgres(self, m2m_data, print_types)
else:
labels_msg = ""
for related_name, (related_model, labels) in get_labels_as_dict(
self, instance=self._state.db
).items():
m2m_data = _get_labels_postgres(self, m2m_data)
if not m2m_data:
m2m_data = _get_labels(self, instance=self._state.db)

labels_msg = ""
for related_name, labels in m2m_data.items():
if not labels or related_name == "feature_sets":
continue
if isinstance(labels, dict): # postgres, labels are a dict[id, name]
print_values = _print_values(labels.values(), n=10)
else: # labels are a QuerySet
field = get_name_field(labels)
labels_list = list(labels.values_list(field, flat=True))
if len(labels_list) > 0:
print_values = _print_values(labels_list, n=10)
type_str = f": {related_model}" if print_types else ""
labels_msg += f" .{related_name}{type_str} = {print_values}\n"
print_values = _print_values(labels.values_list(field, flat=True), n=10)
if print_values:
related_model = get_related_model(self, related_name)
type_str = f": {related_model.__name__}" if print_types else ""
labels_msg += f" .{related_name}{type_str} = {print_values}\n"

msg = ""
if labels_msg:
Expand All @@ -90,50 +102,48 @@ def print_labels(
return msg


# Alex: is this a label transfer function?
def validate_labels(labels: QuerySet | list | dict):
def validate_labels_registry(
labels: QuerySet | list | dict,
) -> tuple[list[str], list[str]]:
if len(labels) == 0:
return [], []
registry = labels[0].__class__
field = REGISTRY_UNIQUE_FIELD.get(registry.__name__.lower(), "uid")
if hasattr(registry, "_ontology_id_field"):
field = registry._ontology_id_field
# if the field value is None, use uid field
label_uids = np.array(
[getattr(label, field) for label in labels if label is not None]
)
# save labels from ontology_ids
if hasattr(registry, "_ontology_id_field") and len(label_uids) > 0:
try:
labels_records = registry.from_values(label_uids, field=field)
save([r for r in labels_records if r._state.adding])
except Exception: # noqa S110
pass
field = "uid"
label_uids = np.array(
[getattr(label, field) for label in labels if label is not None]
)
if issubclass(registry, CanCurate):
validated = registry.validate(label_uids, field=field, mute=True)
validated_uids = label_uids[validated]
validated_labels = registry.filter(
**{f"{field}__in": validated_uids}
).list()
new_labels = [labels[int(i)] for i in np.argwhere(~validated).flatten()]
else:
validated_labels = []
new_labels = list(labels)
return validated_labels, new_labels

def _save_validated_records(
    labels: QuerySet | list | dict,
) -> list[str]:
    """Save ontology-backed labels and return the labels that do not validate.

    The registry is taken from the class of ``labels[0]``; all labels are
    treated as belonging to that registry.

    Args:
        labels: Label records of one registry. ``None`` entries are skipped
            when collecting identifiers and when pairing labels with their
            validation result.

    Returns:
        The labels for which ``registry.validate`` reports a falsy result.
        If the registry is not a ``CanCurate`` subclass, all of ``labels`` is
        returned as a list. Empty input yields an empty list.

        NOTE(review): the ``list[str]`` annotation is kept unchanged for
        interface stability, but the returned items are the label objects
        themselves, not strings.
    """
    if not labels:
        return []
    registry = labels[0].__class__
    has_ontology_id = hasattr(registry, "_ontology_id_field")
    if has_ontology_id:
        field = registry._ontology_id_field
    else:
        field = REGISTRY_UNIQUE_FIELD.get(registry.__name__.lower(), "uid")

    # Drop None entries once and reuse the filtered list everywhere below, so
    # identifiers, validation flags and labels stay positionally aligned.
    present = [label for label in labels if label is not None]
    label_uids = [getattr(label, field) for label in present]

    # save labels from ontology_ids
    if has_ontology_id and label_uids:
        try:
            records = registry.from_values(label_uids, field=field)
            save([r for r in records if r._state.adding])
        except Exception:  # noqa: S110
            # Deliberately best-effort: a failure here must not abort the
            # validation that follows.
            pass
        # After the ontology-based save attempt, validation is done by uid.
        field = "uid"
        label_uids = [label.uid for label in present]

    if not issubclass(registry, CanCurate):
        return list(labels)

    validated = registry.validate(label_uids, field=field, mute=True)
    # Pair against the filtered list; the unfiltered `labels` would shift the
    # pairing whenever it contains a None entry.
    return [label for label, is_valid in zip(present, validated) if not is_valid]


def save_validated_records(
labels: QuerySet | list | dict,
) -> list[str] | dict[str, list[str]]:
"""Save validated labels from public based on ontology_id_fields."""
if isinstance(labels, dict):
result = {}
for registry, labels_registry in labels.items():
result[registry] = validate_labels_registry(labels_registry)
else:
return validate_labels_registry(labels)
return {
registry: _save_validated_records(registry_labels)
for registry, registry_labels in labels.items()
}
return _save_validated_records(labels)


class LabelManager:
Expand All @@ -144,7 +154,7 @@ class LabelManager:
with features.
"""

def __init__(self, host: Artifact | Collection):
def __init__(self, host: Artifact | Collection) -> None:
self._host = host

def __repr__(self) -> str:
Expand Down Expand Up @@ -201,17 +211,15 @@ def add_from(self, data: Artifact | Collection, transfer_logs: dict = None) -> N
if transfer_logs is None:
transfer_logs = {"mapped": [], "transferred": [], "run": None}
using_key = settings._using_key
for related_name, (_, labels) in get_labels_as_dict(
data, instance=data._state.db
).items():
for related_name, labels in _get_labels(data, instance=data._state.db).items():
labels = labels.all()
if not labels.exists():
continue
# look for features
data_name_lower = data.__class__.__name__.lower()
labels_by_features = defaultdict(list)
features = set()
_, new_labels = validate_labels(labels)
new_labels = save_validated_records(labels)
if len(new_labels) > 0:
transfer_fk_to_default_db_bulk(
new_labels, using_key, transfer_logs=transfer_logs
Expand Down Expand Up @@ -241,7 +249,7 @@ def add_from(self, data: Artifact | Collection, transfer_logs: dict = None) -> N
label = label_returned
labels_by_features[key].append(label)
# treat features
_, new_features = validate_labels(list(features))
new_features = save_validated_records(list(features))
if len(new_features) > 0:
transfer_fk_to_default_db_bulk(
new_features, using_key, transfer_logs=transfer_logs
Expand All @@ -255,16 +263,16 @@ def add_from(self, data: Artifact | Collection, transfer_logs: dict = None) -> N
)
save(new_features)
if hasattr(self._host, related_name):
for feature_name, labels in labels_by_features.items():
for feature_name, feature_labels in labels_by_features.items():
if feature_name is not None:
feature_id = Feature.get(name=feature_name).id
else:
feature_id = None
getattr(self._host, related_name).add(
*labels, through_defaults={"feature_id": feature_id}
*feature_labels, through_defaults={"feature_id": feature_id}
)

def make_external(self, label: Record):
def make_external(self, label: Record) -> None:
"""Make a label external, aka dissociate label from internal features.
Args:
Expand Down
2 changes: 1 addition & 1 deletion tests/core/test_feature_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ def test_labels_add(adata):
adata2 = adata.copy()
adata2.uns["mutated"] = True
artifact2 = ln.Artifact(adata2, description="My new artifact").save()
from lamindb.core._label_manager import get_labels_as_dict
from lamindb.core._label_manager import _get_labels

artifact2.labels.add_from(artifact)
experiments = artifact2.labels.get(experiment)
Expand Down

0 comments on commit ec37354

Please sign in to comment.