Skip to content

Commit

Permalink
rewrite relabel method
Browse files Browse the repository at this point in the history
  • Loading branch information
willdumm committed Dec 22, 2023
1 parent 4074030 commit 9ba3fc2
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 37 deletions.
74 changes: 39 additions & 35 deletions historydag/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,10 @@ def _check_valid(self) -> bool:

for node in po:
if not node.is_ua_node():
# *** Clades are pairwise disjoint:
if not node.is_leaf():
if len(node.clade_union()) != sum(len(clade) for clade in node.clades):
raise ValueError("Found a node whose child clades are not pairwise disjoint")
for clade, eset in node.clades.items():
for child in eset.targets:
# ***Parent clade equals child clade union for all edges:
Expand Down Expand Up @@ -810,49 +814,49 @@ def relabel(
appropriate for that node. The relabel_func should return a consistent
NamedTuple type with name Label. That is, all returned labels
should have matching `_fields` attribute.
No two leaf nodes may be mapped to the same new label.
Method is only guaranteed to work when no two leaf nodes are
mapped to the same new label. If this is not the case, this method may
raise a warning or error, or may fail silently, returning an invalid
HistoryDag.
relax_type: Whether to relax the requirement that the returned HistoryDag be of the same subclass as self.
If True, the returned HistoryDag will be of the abstract type `HistoryDag`. If False, the result is converted to the type of self.
"""
# TODO: Sometimes this fails because multiple children are identical
# after relabeling. Rewrite using
# mutation_annotated_dag.load_MAD_protobuf as template.

leaf_label_dict = {leaf.label: relabel_func(leaf) for leaf in self.get_leaves()}
if len(leaf_label_dict) != len(set(leaf_label_dict.values())):
raise RuntimeError(
"relabeling function maps multiple leaf nodes to the same new label"
)

def relabel_clade(old_clade):
return frozenset(leaf_label_dict[old_label] for old_label in old_clade)

def relabel_node(old_node):
old_node_to_node_d = dict()
node_to_node_d = dict()

for old_node in self.postorder():
new_children = [old_node_to_node_d[old_child] for old_child in old_node.children()]
child_clades = frozenset({child.clade_union() for child in new_children})
if len(child_clades) != len(old_node.clades):
warnings.warn(f"relabel_func {relabel_func.__name__} maps multiple"
" leaf nodes to the same label. This is not supported"
" and may fail with an error or silently. If you ignore"
" this warning, at least run _check_valid() on the result")
if old_node.is_ua_node():
return UANode(
EdgeSet(
[relabel_node(old_child) for old_child in old_node.children()]
)
)
new_node = UANode(EdgeSet())
else:
clades = {
relabel_clade(old_clade): EdgeSet(
[relabel_node(old_child) for old_child in old_eset.targets],
weights=old_eset.weights,
probs=old_eset.probs,
)
for old_clade, old_eset in old_node.clades.items()
}
return HistoryDagNode(relabel_func(old_node), clades, old_node.attr)
new_node = HistoryDagNode(
relabel_func(old_node),
{clade: EdgeSet() for clade in child_clades},
old_node.attr,
)
new_node = node_to_node_d.get(new_node, new_node)
node_to_node_d[new_node] = new_node

old_node_to_node_d[old_node] = new_node

for child in new_children:
new_node.add_edge(child, weight=1, prob=1, prob_norm=False)

# Last node in postorder should be UA node
assert new_node.is_ua_node()
dag = HistoryDag(new_node)
if relax_type:
newdag = HistoryDag(relabel_node(self.dagroot))
return dag
else:
newdag = self.__class__(relabel_node(self.dagroot))
# do any necessary collapsing
newdag = newdag.sample() | newdag
return newdag
return type(self).from_history_dag(dag)


def add_label_fields(self, new_field_names=[], new_field_values=lambda n: []):
"""Returns a copy of the DAG in which each node's label is extended to
include the new fields listed in `new_field_names`.
Expand All @@ -874,7 +878,7 @@ def add_fields(node):
return self.relabel(add_fields)

def remove_label_fields(self, fields_to_remove=[]):
"""Returns a oopy of the DAG with the list of `fields_to_remove`
"""Returns a copy of the DAG with the list of `fields_to_remove`
dropped from each node's label.
Args:
Expand Down
16 changes: 14 additions & 2 deletions historydag/mutation_annotated_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,24 @@ def to_protobuf(self, leaf_data_func=None):
leaf_data_func: a function taking a DAG node and returning a string to store
in the protobuf node_name field `condensed_leaves` of leaf nodes. On leaf
nodes, this data is appended after the unique leaf ID.
Note that internal node IDs will be reassigned, even if internal nodes have node IDs
in their label data.
"""

#TODO fix when node ids aren't available for internal nodes
refseq = next(self.preorder(skip_ua_node=True)).label.compact_genome.reference
empty_cg = CompactGenome(dict(), refseq)

# Create unique leaf IDs if the node_id field isn't available
if "node_id" in self.get_label_type()._fields:
def get_leaf_id(node):
return node.label.node_id
else:
leaf_id_map = {n: f"s{idx}" for idx, n in enumerate(self.get_leaves())}

def get_leaf_id(node):
return leaf_id_map[node]

def mut_func(pnode, cnode):
if pnode.is_ua_node():
parent_seq = empty_cg
Expand All @@ -142,7 +154,7 @@ def key_func(cladeitem):
node_name = data.node_names.add()
node_name.node_id = idx
if node.is_leaf():
node_name.condensed_leaves.append(node.label.node_id)
node_name.condensed_leaves.append(get_leaf_id(node))
if leaf_data_func is not None:
node_name.condensed_leaves.append(leaf_data_func(node))

Expand Down

0 comments on commit 9ba3fc2

Please sign in to comment.