From 9ba3fc2d08bd1905f8b1c38f46f5222936574158 Mon Sep 17 00:00:00 2001 From: Will Dumm Date: Fri, 22 Dec 2023 14:07:14 -0700 Subject: [PATCH] rewrite relabel method --- historydag/dag.py | 74 +++++++++++++++------------- historydag/mutation_annotated_dag.py | 16 +++++- 2 files changed, 53 insertions(+), 37 deletions(-) diff --git a/historydag/dag.py b/historydag/dag.py index d283fdd..5a2270c 100644 --- a/historydag/dag.py +++ b/historydag/dag.py @@ -551,6 +551,10 @@ def _check_valid(self) -> bool: for node in po: if not node.is_ua_node(): + # *** Clades are pairwise disjoint: + if not node.is_leaf(): + if len(node.clade_union()) != sum(len(clade) for clade in node.clades): + raise ValueError("Found a node whose child clades are not pairwise disjoint") for clade, eset in node.clades.items(): for child in eset.targets: # ***Parent clade equals child clade union for all edges: @@ -810,49 +814,49 @@ def relabel( appropriate for that node. The relabel_func should return a consistent NamedTuple type with name Label. That is, all returned labels should have matching `_fields` attribute. - No two leaf nodes may be mapped to the same new label. + Method is only guaranteed to work when no two leaf nodes are + mapped to the same new label. If this is not the case, this method may + raise a warning or error, or may fail silently, returning an invalid + HistoryDag. relax_type: Whether to require the returned HistoryDag to be of the same subclass as self. If True, the returned HistoryDag will be of the abstract type `HistoryDag` """ - # TODO: Sometimes this fails because multiple children are identical - # after relabeling. Rewrite using - # mutation_annotated_dag.load_MAD_protobuf as template. - - leaf_label_dict = {leaf.label: relabel_func(leaf) for leaf in self.get_leaves()} - if len(leaf_label_dict) != len(set(leaf_label_dict.values())): - raise RuntimeError( - "relabeling function maps multiple leaf nodes to the same new label" - ) - - def relabel_clade(old_clade): - return frozenset(leaf_label_dict[old_label] for old_label in old_clade) - - def relabel_node(old_node): + old_node_to_node_d = dict() + node_to_node_d = dict() + + for old_node in self.postorder(): + new_children = [old_node_to_node_d[old_child] for old_child in old_node.children()] + child_clades = frozenset({child.clade_union() for child in new_children}) + if len(child_clades) != len(old_node.clades): + warnings.warn(f"relabel_func {relabel_func.__name__} maps multiple" + " leaf nodes to the same label. This is not supported" + " and may fail with an error or silently. If you ignore" + " this warning, at least run _check_valid() on the result") if old_node.is_ua_node(): - return UANode( - EdgeSet( - [relabel_node(old_child) for old_child in old_node.children()] - ) - ) + new_node = UANode(EdgeSet()) else: - clades = { - relabel_clade(old_clade): EdgeSet( - [relabel_node(old_child) for old_child in old_eset.targets], - weights=old_eset.weights, - probs=old_eset.probs, - ) - for old_clade, old_eset in old_node.clades.items() - } - return HistoryDagNode(relabel_func(old_node), clades, old_node.attr) + new_node = HistoryDagNode( + relabel_func(old_node), + {clade: EdgeSet() for clade in child_clades}, + old_node.attr, + ) + new_node = node_to_node_d.get(new_node, new_node) + node_to_node_d[new_node] = new_node + + old_node_to_node_d[old_node] = new_node + for child in new_children: + new_node.add_edge(child, weight=1, prob=1, prob_norm=False) + + # Last node in postorder should be UA node + assert new_node.is_ua_node() + dag = HistoryDag(new_node) if relax_type: - newdag = HistoryDag(relabel_node(self.dagroot)) + return dag else: - newdag = self.__class__(relabel_node(self.dagroot)) - # do any necessary collapsing - newdag = newdag.sample() | newdag - return newdag + return type(self).from_history_dag(dag) + def add_label_fields(self, new_field_names=[], new_field_values=lambda n: []): """Returns a copy of the DAG in which each node's label is extended to include the new fields listed in `new_field_names`. @@ -874,7 +878,7 @@ def add_fields(node): return self.relabel(add_fields) def remove_label_fields(self, fields_to_remove=[]): - """Returns a oopy of the DAG with the list of `fields_to_remove` + """Returns a copy of the DAG with the list of `fields_to_remove` dropped from each node's label. Args: diff --git a/historydag/mutation_annotated_dag.py b/historydag/mutation_annotated_dag.py index df2817e..541da99 100644 --- a/historydag/mutation_annotated_dag.py +++ b/historydag/mutation_annotated_dag.py @@ -116,12 +116,24 @@ def to_protobuf(self, leaf_data_func=None): leaf_data_func: a function taking a DAG node and returning a string to store in the protobuf node_name field `condensed_leaves` of leaf nodes. On leaf nodes, this data is appended after the unique leaf ID. + + Note that internal node IDs will be reassigned, even if internal nodes have node IDs + in their label data. """ - #TODO fix when node ids aren't available for internal nodes refseq = next(self.preorder(skip_ua_node=True)).label.compact_genome.reference empty_cg = CompactGenome(dict(), refseq) + # Create unique leaf IDs if the node_id field isn't available + if "node_id" in self.get_label_type()._fields: + def get_leaf_id(node): + return node.label.node_id + else: + leaf_id_map = {n: f"s{idx}" for idx, n in enumerate(self.get_leaves())} + + def get_leaf_id(node): + return leaf_id_map[node] + def mut_func(pnode, cnode): if pnode.is_ua_node(): parent_seq = empty_cg @@ -142,7 +154,7 @@ def key_func(cladeitem): node_name = data.node_names.add() node_name.node_id = idx if node.is_leaf(): - node_name.condensed_leaves.append(node.label.node_id) + node_name.condensed_leaves.append(get_leaf_id(node)) if leaf_data_func is not None: node_name.condensed_leaves.append(leaf_data_func(node))