From 72210eadf3dda3887ac25ec77b4970e5fe44f460 Mon Sep 17 00:00:00 2001 From: Sunny Sun <38218185+sunnyosun@users.noreply.github.com> Date: Thu, 28 Nov 2024 11:37:29 +0100 Subject: [PATCH 1/6] =?UTF-8?q?=F0=9F=8E=A8=20Separate=20labels=20fetching?= =?UTF-8?q?=20from=20printing?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lamindb/core/_label_manager.py | 67 +++++++++++++++++++++------------- 1 file changed, 42 insertions(+), 25 deletions(-) diff --git a/lamindb/core/_label_manager.py b/lamindb/core/_label_manager.py index 00e651054..19edc2e26 100644 --- a/lamindb/core/_label_manager.py +++ b/lamindb/core/_label_manager.py @@ -31,7 +31,7 @@ def get_labels_as_dict( self: Artifact | Collection, links: bool = False, instance: str | None = None -) -> dict: +) -> dict[str, tuple[str, QuerySet]]: labels = {} # type: ignore if self.id is None: return labels @@ -46,21 +46,27 @@ def get_labels_as_dict( return labels +def _get_labels_postgres( + self: Artifact | Collection, m2m_data: dict | None = None +) -> dict: + if m2m_data is None: + artifact_meta = get_artifact_with_related(self, include_m2m=True) + m2m_data = artifact_meta.get("related_data", {}).get("m2m", {}) + return m2m_data + + def _print_labels_postgres( - self: Artifact | Collection, m2m_data: dict | None = None, print_types: bool = False + self: Artifact | Collection, m2m_data: dict, print_types: bool = False ) -> str: + m2m_data = _get_labels_postgres(self, m2m_data) labels_msg = "" - if not m2m_data: - artifact_meta = get_artifact_with_related(self, include_m2m=True) - m2m_data = artifact_meta.get("related_data", {}).get("m2m", {}) - if m2m_data: - for related_name, labels in m2m_data.items(): - if not labels or related_name == "feature_sets": - continue - related_model = get_related_model(self, related_name) - print_values = _print_values(labels.values(), n=10) - type_str = f": {related_model}" if print_types else "" - labels_msg += f" .{related_name}{type_str} = {print_values}\n" + for related_name, labels in m2m_data.items(): + if not labels or related_name == "feature_sets": + continue + related_model = get_related_model(self, related_name) + print_values = _print_values(labels.values(), n=10) + type_str = f": {related_model}" if print_types else "" + labels_msg += f" .{related_name}{type_str} = {print_values}\n" return labels_msg @@ -68,7 +74,16 @@ def print_labels( self: Artifact | Collection, m2m_data: dict | None = None, print_types: bool = False, -): +) -> str: + """Print labels associated with an artifact or collection. + + Args: + m2m_data: A dictionary of m2m data. If not provided, it will be fetched. + print_types: Whether to print the types of the related models. + + Returns: + A string representation of the labels associated with the artifact or collection. + """ if not self._state.adding and connections[self._state.db].vendor == "postgresql": labels_msg = _print_labels_postgres(self, m2m_data, print_types) else: @@ -90,9 +105,10 @@ def print_labels( return msg -# Alex: is this a label transfer function? -def validate_labels(labels: QuerySet | list | dict): - def validate_labels_registry( +def save_validated_records(labels: QuerySet | list | dict) -> tuple[list, list]: + """Save validated labels from public based on ontology_id_fields.""" + + def save_records_from_ontology_ids( labels: QuerySet | list | dict, ) -> tuple[list[str], list[str]]: if len(labels) == 0: @@ -131,9 +147,10 @@ def validate_labels_registry( if isinstance(labels, dict): result = {} for registry, labels_registry in labels.items(): - result[registry] = validate_labels_registry(labels_registry) + result[registry] = save_records_from_ontology_ids(labels_registry) + return result # type: ignore else: - return validate_labels_registry(labels) + return save_records_from_ontology_ids(labels) class LabelManager: @@ -144,7 +161,7 @@ class LabelManager: with features. """ - def __init__(self, host: Artifact | Collection): + def __init__(self, host: Artifact | Collection) -> None: self._host = host def __repr__(self) -> str: @@ -211,7 +228,7 @@ def add_from(self, data: Artifact | Collection, transfer_logs: dict = None) -> N data_name_lower = data.__class__.__name__.lower() labels_by_features = defaultdict(list) features = set() - _, new_labels = validate_labels(labels) + _, new_labels = save_validated_records(labels) if len(new_labels) > 0: transfer_fk_to_default_db_bulk( new_labels, using_key, transfer_logs=transfer_logs @@ -241,7 +258,7 @@ def add_from(self, data: Artifact | Collection, transfer_logs: dict = None) -> N label = label_returned labels_by_features[key].append(label) # treat features - _, new_features = validate_labels(list(features)) + _, new_features = save_validated_records(list(features)) if len(new_features) > 0: transfer_fk_to_default_db_bulk( new_features, using_key, transfer_logs=transfer_logs @@ -255,16 +272,16 @@ def add_from(self, data: Artifact | Collection, transfer_logs: dict = None) -> N ) save(new_features) if hasattr(self._host, related_name): - for feature_name, labels in labels_by_features.items(): + for feature_name, feature_labels in labels_by_features.items(): if feature_name is not None: feature_id = Feature.get(name=feature_name).id else: feature_id = None getattr(self._host, related_name).add( - *labels, through_defaults={"feature_id": feature_id} + *feature_labels, through_defaults={"feature_id": feature_id} ) - def make_external(self, label: Record): + def make_external(self, label: Record) -> None: """Make a label external, aka dissociate label from internal features. Args: From 5fab1de6dcfe27dbd721263103ec90eacaa0a996 Mon Sep 17 00:00:00 2001 From: Sunny Sun <38218185+sunnyosun@users.noreply.github.com> Date: Thu, 28 Nov 2024 13:48:11 +0100 Subject: [PATCH 2/6] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Simplify?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lamindb/core/_django.py | 4 +- lamindb/core/_feature_manager.py | 6 +- lamindb/core/_label_manager.py | 97 ++++++++++++++++++++------------ 3 files changed, 66 insertions(+), 41 deletions(-) diff --git a/lamindb/core/_django.py b/lamindb/core/_django.py index 29c6f8b2e..8a5d79d9a 100644 --- a/lamindb/core/_django.py +++ b/lamindb/core/_django.py @@ -35,7 +35,7 @@ def get_artifact_with_related( """Fetch an artifact with its related data.""" from lamindb._can_curate import get_name_field - from ._label_manager import LABELS_EXCLUDE_SET + from ._label_manager import EXCLUDE_LABELS model = artifact.__class__ schema_modules = get_schemas_modules(artifact._state.db) @@ -54,7 +54,7 @@ def get_artifact_with_related( for v in dict_related_model_to_related_name( model, instance=artifact._state.db ).values() - if not v.startswith("_") and v not in LABELS_EXCLUDE_SET + if not v.startswith("_") and v not in EXCLUDE_LABELS ] ) link_tables = ( diff --git a/lamindb/core/_feature_manager.py b/lamindb/core/_feature_manager.py index d3fcff0b0..96c4ba7c0 100644 --- a/lamindb/core/_feature_manager.py +++ b/lamindb/core/_feature_manager.py @@ -46,7 +46,7 @@ from lamindb.core.storage import LocalPathClasses from ._django import get_artifact_with_related -from ._label_manager import get_labels_as_dict +from ._label_manager import _get_labels from ._settings import settings from .schema import ( dict_related_model_to_related_name, @@ -199,9 +199,7 @@ def _get_categoricals( return {} result = defaultdict(set) - for _, (_, links) in get_labels_as_dict( - self, links=True, instance=self._state.db - ).items(): + for _, (_, links) in _get_labels(self, links=True, instance=self._state.db).items(): for link in links: if hasattr(link, "feature_id") and link.feature_id is not None: feature = Feature.objects.using(self._state.db).get(id=link.feature_id) diff --git a/lamindb/core/_label_manager.py b/lamindb/core/_label_manager.py index 19edc2e26..a4d2a57ea 100644 --- a/lamindb/core/_label_manager.py +++ b/lamindb/core/_label_manager.py @@ -26,39 +26,44 @@ from lamindb._query_set import QuerySet -LABELS_EXCLUDE_SET = {"feature_sets"} - - -def get_labels_as_dict( - self: Artifact | Collection, links: bool = False, instance: str | None = None -) -> dict[str, tuple[str, QuerySet]]: - labels = {} # type: ignore - if self.id is None: - return labels - for related_model_name, related_name in dict_related_model_to_related_name( - self.__class__, links=links, instance=instance - ).items(): - if related_name not in LABELS_EXCLUDE_SET and not related_name.startswith("_"): - labels[related_name] = ( - related_model_name, - getattr(self, related_name).all(), - ) +EXCLUDE_LABELS = {"feature_sets"} + + +def _get_labels( + obj, links: bool = False, instance: str | None = None +) -> dict[str, QuerySet]: + """Get all labels associated with an object as a dictionary. + + This is a generic approach that uses django orm. + """ + if obj.id is None: + return {} + + labels = {} + related_models = dict_related_model_to_related_name( + obj.__class__, links=links, instance=instance + ) + + for _, related_name in related_models.items(): + if related_name not in EXCLUDE_LABELS and not related_name.startswith("_"): + labels[related_name] = getattr(obj, related_name).all() return labels def _get_labels_postgres( self: Artifact | Collection, m2m_data: dict | None = None -) -> dict: +) -> dict[str, list]: + """Get all labels associated with an artifact or collection as a dictionary. + + This is a postgres-specific approach that uses django Subquery. + """ if m2m_data is None: artifact_meta = get_artifact_with_related(self, include_m2m=True) m2m_data = artifact_meta.get("related_data", {}).get("m2m", {}) return m2m_data -def _print_labels_postgres( - self: Artifact | Collection, m2m_data: dict, print_types: bool = False -) -> str: - m2m_data = _get_labels_postgres(self, m2m_data) +def _print_labels(self, m2m_data: dict, print_types: bool = False): labels_msg = "" for related_name, labels in m2m_data.items(): if not labels or related_name == "feature_sets": @@ -70,6 +75,21 @@ def _print_labels_postgres( return labels_msg +# def _print_labels_postgres( +# self: Artifact | Collection, m2m_data: dict, print_types: bool = False +# ) -> str: +# m2m_data = _get_labels_postgres(self, m2m_data) +# labels_msg = "" +# for related_name, labels in m2m_data.items(): +# if not labels or related_name == "feature_sets": +# continue +# related_model = get_related_model(self, related_name) +# print_values = _print_values(labels.values(), n=10) +# type_str = f": {related_model}" if print_types else "" +# labels_msg += f" .{related_name}{type_str} = {print_values}\n" +# return labels_msg + + def print_labels( self: Artifact | Collection, m2m_data: dict | None = None, @@ -84,19 +104,26 @@ def print_labels( Returns: A string representation of the labels associated with the artifact or collection. """ + # if not self._state.adding and connections[self._state.db].vendor == "postgresql": + # labels_msg = _print_labels_postgres(self, m2m_data, print_types) + # else: + # labels_msg = "" + # for related_name, (related_model, labels) in _get_labels( + # self, instance=self._state.db + # ).items(): + # field = get_name_field(labels) + # labels_list = list(labels.values_list(field, flat=True)) + # if len(labels_list) > 0: + # print_values = _print_values(labels_list, n=10) + # type_str = f": {related_model}" if print_types else "" + # labels_msg += f" .{related_name}{type_str} = {print_values}\n" + if not self._state.adding and connections[self._state.db].vendor == "postgresql": - labels_msg = _print_labels_postgres(self, m2m_data, print_types) - else: - labels_msg = "" - for related_name, (related_model, labels) in get_labels_as_dict( - self, instance=self._state.db - ).items(): - field = get_name_field(labels) - labels_list = list(labels.values_list(field, flat=True)) - if len(labels_list) > 0: - print_values = _print_values(labels_list, n=10) - type_str = f": {related_model}" if print_types else "" - labels_msg += f" .{related_name}{type_str} = {print_values}\n" + m2m_data = _get_labels_postgres(self, m2m_data) + if m2m_data is None: + m2m_data = _get_labels(self, instance=self._state.db) + + labels_msg = _print_labels(self, m2m_data, print_types) msg = "" if labels_msg: @@ -218,7 +245,7 @@ def add_from(self, data: Artifact | Collection, transfer_logs: dict = None) -> N if transfer_logs is None: transfer_logs = {"mapped": [], "transferred": [], "run": None} using_key = settings._using_key - for related_name, (_, labels) in get_labels_as_dict( + for related_name, (_, labels) in _get_labels( data, instance=data._state.db ).items(): labels = labels.all() From 8161d2a4d7f14ae45d9a4db2b9a925bb69942574 Mon Sep 17 00:00:00 2001 From: Sunny Sun <38218185+sunnyosun@users.noreply.github.com> Date: Thu, 28 Nov 2024 14:18:17 +0100 Subject: [PATCH 3/6] =?UTF-8?q?=F0=9F=8E=A8=20Fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lamindb/core/_label_manager.py | 125 +++++++++++++++++------------ tests/core/test_feature_manager.py | 2 +- 2 files changed, 76 insertions(+), 51 deletions(-) diff --git a/lamindb/core/_label_manager.py b/lamindb/core/_label_manager.py index a4d2a57ea..74d5fd6bb 100644 --- a/lamindb/core/_label_manager.py +++ b/lamindb/core/_label_manager.py @@ -3,7 +3,6 @@ from collections import defaultdict from typing import TYPE_CHECKING -import numpy as np from django.db import connections from lamin_utils import colors, logger from lnschema_core.models import CanCurate, Feature @@ -69,9 +68,16 @@ def _print_labels(self, m2m_data: dict, print_types: bool = False): if not labels or related_name == "feature_sets": continue related_model = get_related_model(self, related_name) - print_values = _print_values(labels.values(), n=10) - type_str = f": {related_model}" if print_types else "" - labels_msg += f" .{related_name}{type_str} = {print_values}\n" + if isinstance(labels, dict): + print_values = _print_values(labels.values(), n=10) + else: # labels are a QuerySet + field = get_name_field(labels) + labels_list = list(labels.values_list(field, flat=True)) + if len(labels_list) > 0: + print_values = _print_values(labels_list, n=10) + if print_values: + type_str = f": {related_model}" if print_types else "" + labels_msg += f" .{related_name}{type_str} = {print_values}\n" return labels_msg @@ -132,52 +138,73 @@ def print_labels( return msg -def save_validated_records(labels: QuerySet | list | dict) -> tuple[list, list]: - """Save validated labels from public based on ontology_id_fields.""" - - def save_records_from_ontology_ids( - labels: QuerySet | list | dict, - ) -> tuple[list[str], list[str]]: - if len(labels) == 0: - return [], [] - registry = labels[0].__class__ - field = REGISTRY_UNIQUE_FIELD.get(registry.__name__.lower(), "uid") - if hasattr(registry, "_ontology_id_field"): - field = registry._ontology_id_field - # if the field value is None, use uid field - label_uids = np.array( - [getattr(label, field) for label in labels if label is not None] - ) - # save labels from ontology_ids - if hasattr(registry, "_ontology_id_field") and len(label_uids) > 0: - try: - labels_records = registry.from_values(label_uids, field=field) - save([r for r in labels_records if r._state.adding]) - except Exception: # noqa S110 - pass - field = "uid" - label_uids = np.array( - [getattr(label, field) for label in labels if label is not None] - ) - if issubclass(registry, CanCurate): - validated = registry.validate(label_uids, field=field, mute=True) - validated_uids = label_uids[validated] - validated_labels = registry.filter( - **{f"{field}__in": validated_uids} - ).list() - new_labels = [labels[int(i)] for i in np.argwhere(~validated).flatten()] - else: - validated_labels = [] - new_labels = list(labels) +def _save_validated_records( + labels: QuerySet | list | dict, +) -> tuple[list[str], list[str]]: + if not labels: + return [], [] + registry = labels[0].__class__ + field = ( + REGISTRY_UNIQUE_FIELD.get(registry.__name__.lower(), "uid") + if not hasattr(registry, "_ontology_id_field") + else registry._ontology_id_field + ) + # if the field value is None, use uid field + # label_uids = np.array( + # [getattr(label, field) for label in labels if label is not None] + # ) + label_uids = [getattr(label, field) for label in labels if label is not None] + # save labels from ontology_ids + # if hasattr(registry, "_ontology_id_field") and len(label_uids) > 0: + # try: + # labels_records = registry.from_values(label_uids, field=field) + # save([r for r in labels_records if r._state.adding]) + # except Exception: # S110 + # pass + # field = "uid" + # label_uids = np.array( + # [getattr(label, field) for label in labels if label is not None] + # ) + # if issubclass(registry, CanCurate): + # validated = registry.validate(label_uids, field=field, mute=True) + # validated_uids = label_uids[validated] + # validated_labels = registry.filter(**{f"{field}__in": validated_uids}).list() + # new_labels = [labels[int(i)] for i in np.argwhere(~validated).flatten()] + # else: + # validated_labels = [] + # new_labels = list(labels) + if hasattr(registry, "_ontology_id_field") and label_uids: + try: + records = registry.from_values(label_uids, field=field) + save([r for r in records if r._state.adding]) + except Exception: # noqa: S110 + pass + field = "uid" + label_uids = [label.uid for label in labels if label is not None] + + if issubclass(registry, CanCurate): + validated = registry.validate(label_uids, field=field, mute=True) + validated_uids = [ + uid for uid, is_valid in zip(label_uids, validated) if is_valid + ] + validated_labels = registry.filter(**{f"{field}__in": validated_uids}).list() + new_labels = [ + label for label, is_valid in zip(labels, validated) if not is_valid + ] return validated_labels, new_labels + return [], list(labels) + +def save_validated_records( + labels: QuerySet | list | dict, +) -> tuple[list[str], list[str]] | dict[str, tuple[list[str], list[str]]]: + """Save validated labels from public based on ontology_id_fields.""" if isinstance(labels, dict): - result = {} - for registry, labels_registry in labels.items(): - result[registry] = save_records_from_ontology_ids(labels_registry) - return result # type: ignore - else: - return save_records_from_ontology_ids(labels) + return { + registry: _save_validated_records(registry_labels) + for registry, registry_labels in labels.items() + } + return _save_validated_records(labels) class LabelManager: @@ -245,9 +272,7 @@ def add_from(self, data: Artifact | Collection, transfer_logs: dict = None) -> N if transfer_logs is None: transfer_logs = {"mapped": [], "transferred": [], "run": None} using_key = settings._using_key - for related_name, (_, labels) in _get_labels( - data, instance=data._state.db - ).items(): + for related_name, labels in _get_labels(data, instance=data._state.db).items(): labels = labels.all() if not labels.exists(): continue diff --git a/tests/core/test_feature_manager.py b/tests/core/test_feature_manager.py index a68045719..d00069775 100644 --- a/tests/core/test_feature_manager.py +++ b/tests/core/test_feature_manager.py @@ -379,7 +379,7 @@ def test_labels_add(adata): adata2 = adata.copy() adata2.uns["mutated"] = True artifact2 = ln.Artifact(adata2, description="My new artifact").save() - from lamindb.core._label_manager import get_labels_as_dict + from lamindb.core._label_manager import _get_labels artifact2.labels.add_from(artifact) experiments = artifact2.labels.get(experiment) From 657b59b990741ef2b45319d7766dfee935a02aaf Mon Sep 17 00:00:00 2001 From: Sunny Sun <38218185+sunnyosun@users.noreply.github.com> Date: Thu, 28 Nov 2024 15:39:31 +0100 Subject: [PATCH 4/6] =?UTF-8?q?=F0=9F=8E=A8=20Simplify?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lamindb/core/_feature_manager.py | 2 +- lamindb/core/_label_manager.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lamindb/core/_feature_manager.py b/lamindb/core/_feature_manager.py index 96c4ba7c0..a26643335 100644 --- a/lamindb/core/_feature_manager.py +++ b/lamindb/core/_feature_manager.py @@ -199,7 +199,7 @@ def _get_categoricals( return {} result = defaultdict(set) - for _, (_, links) in _get_labels(self, links=True, instance=self._state.db).items(): + for _, links in _get_labels(self, links=True, instance=self._state.db).items(): for link in links: if hasattr(link, "feature_id") and link.feature_id is not None: feature = Feature.objects.using(self._state.db).get(id=link.feature_id) diff --git a/lamindb/core/_label_manager.py b/lamindb/core/_label_manager.py index 74d5fd6bb..4ef1cfe41 100644 --- a/lamindb/core/_label_manager.py +++ b/lamindb/core/_label_manager.py @@ -76,7 +76,7 @@ def _print_labels(self, m2m_data: dict, print_types: bool = False): if len(labels_list) > 0: print_values = _print_values(labels_list, n=10) if print_values: - type_str = f": {related_model}" if print_types else "" + type_str = f": {related_model.__name__}" if print_types else "" labels_msg += f" .{related_name}{type_str} = {print_values}\n" return labels_msg @@ -126,7 +126,7 @@ def print_labels( if not self._state.adding and connections[self._state.db].vendor == "postgresql": m2m_data = _get_labels_postgres(self, m2m_data) - if m2m_data is None: + if not m2m_data: m2m_data = _get_labels(self, instance=self._state.db) labels_msg = _print_labels(self, m2m_data, print_types) From 8d88f86ecf1a30373b53f571f19357e4fbced28a Mon Sep 17 00:00:00 2001 From: Sunny Sun <38218185+sunnyosun@users.noreply.github.com> Date: Thu, 28 Nov 2024 16:11:25 +0100 Subject: [PATCH 5/6] =?UTF-8?q?=F0=9F=94=A5=20Clean=20up?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lamindb/core/_label_manager.py | 85 ++++++---------------------------- 1 file changed, 15 insertions(+), 70 deletions(-) diff --git a/lamindb/core/_label_manager.py b/lamindb/core/_label_manager.py index 4ef1cfe41..b30d60bc0 100644 --- a/lamindb/core/_label_manager.py +++ b/lamindb/core/_label_manager.py @@ -62,40 +62,6 @@ def _get_labels_postgres( return m2m_data -def _print_labels(self, m2m_data: dict, print_types: bool = False): - labels_msg = "" - for related_name, labels in m2m_data.items(): - if not labels or related_name == "feature_sets": - continue - related_model = get_related_model(self, related_name) - if isinstance(labels, dict): - print_values = _print_values(labels.values(), n=10) - else: # labels are a QuerySet - field = get_name_field(labels) - labels_list = list(labels.values_list(field, flat=True)) - if len(labels_list) > 0: - print_values = _print_values(labels_list, n=10) - if print_values: - type_str = f": {related_model.__name__}" if print_types else "" - labels_msg += f" .{related_name}{type_str} = {print_values}\n" - return labels_msg - - -# def _print_labels_postgres( -# self: Artifact | Collection, m2m_data: dict, print_types: bool = False -# ) -> str: -# m2m_data = _get_labels_postgres(self, m2m_data) -# labels_msg = "" -# for related_name, labels in m2m_data.items(): -# if not labels or related_name == "feature_sets": -# continue -# related_model = get_related_model(self, related_name) -# print_values = _print_values(labels.values(), n=10) -# type_str = f": {related_model}" if print_types else "" -# labels_msg += f" .{related_name}{type_str} = {print_values}\n" -# return labels_msg - - def print_labels( self: Artifact | Collection, m2m_data: dict | None = None, @@ -110,26 +76,26 @@ def print_labels( Returns: A string representation of the labels associated with the artifact or collection. """ - # if not self._state.adding and connections[self._state.db].vendor == "postgresql": - # labels_msg = _print_labels_postgres(self, m2m_data, print_types) - # else: - # labels_msg = "" - # for related_name, (related_model, labels) in _get_labels( - # self, instance=self._state.db - # ).items(): - # field = get_name_field(labels) - # labels_list = list(labels.values_list(field, flat=True)) - # if len(labels_list) > 0: - # print_values = _print_values(labels_list, n=10) - # type_str = f": {related_model}" if print_types else "" - # labels_msg += f" .{related_name}{type_str} = {print_values}\n" - if not self._state.adding and connections[self._state.db].vendor == "postgresql": m2m_data = _get_labels_postgres(self, m2m_data) if not m2m_data: m2m_data = _get_labels(self, instance=self._state.db) - labels_msg = _print_labels(self, m2m_data, print_types) + labels_msg = "" + for related_name, labels in m2m_data.items(): + if not labels or related_name == "feature_sets": + continue + if isinstance(labels, dict): + print_values = _print_values(labels.values(), n=10) + else: # labels are a QuerySet + field = get_name_field(labels) + labels_list = list(labels.values_list(field, flat=True)) + if len(labels_list) > 0: + print_values = _print_values(labels_list, n=10) + if print_values: + related_model = get_related_model(self, related_name) + type_str = f": {related_model.__name__}" if print_types else "" + labels_msg += f" .{related_name}{type_str} = {print_values}\n" msg = "" if labels_msg: @@ -150,29 +116,8 @@ def _save_validated_records( else registry._ontology_id_field ) # if the field value is None, use uid field - # label_uids = np.array( - # [getattr(label, field) for label in labels if label is not None] - # ) label_uids = [getattr(label, field) for label in labels if label is not None] # save labels from ontology_ids - # if hasattr(registry, "_ontology_id_field") and len(label_uids) > 0: - # try: - # labels_records = registry.from_values(label_uids, field=field) - # save([r for r in labels_records if r._state.adding]) - # except Exception: # S110 - # pass - # field = "uid" - # label_uids = np.array( - # [getattr(label, field) for label in labels if label is not None] - # ) - # if issubclass(registry, CanCurate): - # validated = registry.validate(label_uids, field=field, mute=True) - # validated_uids = label_uids[validated] - # validated_labels = registry.filter(**{f"{field}__in": validated_uids}).list() - # new_labels = [labels[int(i)] for i in np.argwhere(~validated).flatten()] - # else: - # validated_labels = [] - # new_labels = list(labels) if hasattr(registry, "_ontology_id_field") and label_uids: try: records = registry.from_values(label_uids, field=field) From 2cf108e8eb63cc8df9d679b9e6c310406d684e48 Mon Sep 17 00:00:00 2001 From: Sunny Sun <38218185+sunnyosun@users.noreply.github.com> Date: Thu, 28 Nov 2024 16:32:45 +0100 Subject: [PATCH 6/6] =?UTF-8?q?=F0=9F=94=A5=20Further=20cleanup?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lamindb/core/_label_manager.py | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/lamindb/core/_label_manager.py b/lamindb/core/_label_manager.py index b30d60bc0..ba689bcec 100644 --- a/lamindb/core/_label_manager.py +++ b/lamindb/core/_label_manager.py @@ -51,7 +51,7 @@ def _get_labels( def _get_labels_postgres( self: Artifact | Collection, m2m_data: dict | None = None -) -> dict[str, list]: +) -> dict[str, dict[int, str]]: """Get all labels associated with an artifact or collection as a dictionary. This is a postgres-specific approach that uses django Subquery. @@ -85,13 +85,11 @@ def print_labels( for related_name, labels in m2m_data.items(): if not labels or related_name == "feature_sets": continue - if isinstance(labels, dict): + if isinstance(labels, dict): # postgres, labels are a dict[id, name] print_values = _print_values(labels.values(), n=10) else: # labels are a QuerySet field = get_name_field(labels) - labels_list = list(labels.values_list(field, flat=True)) - if len(labels_list) > 0: - print_values = _print_values(labels_list, n=10) + print_values = _print_values(labels.values_list(field, flat=True), n=10) if print_values: related_model = get_related_model(self, related_name) type_str = f": {related_model.__name__}" if print_types else "" @@ -106,9 +104,9 @@ def print_labels( def _save_validated_records( labels: QuerySet | list | dict, -) -> tuple[list[str], list[str]]: +) -> list[str]: if not labels: - return [], [] + return [] registry = labels[0].__class__ field = ( REGISTRY_UNIQUE_FIELD.get(registry.__name__.lower(), "uid") @@ -129,20 +127,16 @@ def _save_validated_records( if issubclass(registry, CanCurate): validated = registry.validate(label_uids, field=field, mute=True) - validated_uids = [ - uid for uid, is_valid in zip(label_uids, validated) if is_valid - ] - validated_labels = registry.filter(**{f"{field}__in": validated_uids}).list() new_labels = [ label for label, is_valid in zip(labels, validated) if not is_valid ] - return validated_labels, new_labels - return [], list(labels) + return new_labels + return list(labels) def save_validated_records( labels: QuerySet | list | dict, -) -> tuple[list[str], list[str]] | dict[str, tuple[list[str], list[str]]]: +) -> list[str] | dict[str, list[str]]: """Save validated labels from public based on ontology_id_fields.""" if isinstance(labels, dict): return { @@ -225,7 +219,7 @@ def add_from(self, data: Artifact | Collection, transfer_logs: dict = None) -> N data_name_lower = data.__class__.__name__.lower() labels_by_features = defaultdict(list) features = set() - _, new_labels = save_validated_records(labels) + new_labels = save_validated_records(labels) if len(new_labels) > 0: transfer_fk_to_default_db_bulk( new_labels, using_key, transfer_logs=transfer_logs @@ -255,7 +249,7 @@ def add_from(self, data: Artifact | Collection, transfer_logs: dict = None) -> N label = label_returned labels_by_features[key].append(label) # treat features - _, new_features = save_validated_records(list(features)) + new_features = save_validated_records(list(features)) if len(new_features) > 0: transfer_fk_to_default_db_bulk( new_features, using_key, transfer_logs=transfer_logs