Skip to content

Commit

Permalink
improve ensembling robustness and documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
lubbersnick committed Oct 30, 2024
1 parent a0b5ca7 commit 2bec5a2
Showing 1 changed file with 142 additions and 22 deletions.
164 changes: 142 additions & 22 deletions hippynn/graphs/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,42 @@
from typing import List, Dict, Union, Tuple


def make_ensemble(models, *, targets: List[str] = "auto", inputs: List[str] = "auto",
prefix: str = "ensemble_", quiet=False,
) -> Tuple[GraphModule, Tuple[Dict[str, int], Dict[str, int]]]:
def make_ensemble(
models,
*,
targets: List[str] = "auto",
inputs: List[str] = "auto",
prefix: str = "ensemble_",
quiet=False,
) -> Tuple[GraphModule, Tuple[Dict[str, int], Dict[str, int]]]:

"""
Make an ensemble out of a set of models. The ensemble graph can then be used with a predictor,
ase graph, etc.
The selected nodes to ensemble are classed by the db_name associated with the nodes.
When using "auto" mode for inputs and outputs:
- The input to the ensemble will be the combined inputs for all models in the ensemble.
- The output of the ensemble will be the combined outputs for all models.
Otherwise, the sets of ensemble inputs and outputs are explicitly specified, and errors
may occur if the requested set of inputs and outputs is not available.
The result ensemble graph has several outputs, each of which has .mean, .std, and .all
attributes which reflect the statistics of the models in the ensemble.
Note that it is not required that all models have the same sets of inputs and outputs.
If a desired node is not automatically ensembled, it probably does not have a db_name.
A remedy for this is to load the graphs with hippynn.graphs.ensemble.get_graphs, then
find the requested nodes in the graphs and assign them the db_name. Then pass these
graphs to make_ensemble.
For more information on the `models` parameter, see the :func:`~hippynn.graphs.ensemble.get_graphs` function.
:param models: list containing str, node, or graphmodule, or str to glob for model directories.
:param targets: list of db_name strings or the string 'auto', which will attempt to infer.
:param inputs: list of db_name strings or the string 'auto', which will attempt to infer.
Expand All @@ -34,12 +65,12 @@ def make_ensemble(models, *, targets: List[str] = "auto", inputs: List[str] = "a

# Phase 1: Figure out what the ensemble will look like.
if inputs == "auto":
inputs = identify_inputs(graphs)
inputs: set[str] = identify_inputs(graphs)
if not quiet:
print("Identified input quantities:", inputs)

if targets == "auto":
targets = identify_targets(graphs)
targets: set[str] = identify_targets(graphs)
if not quiet:
print("Identified output quantities:", targets)

Expand All @@ -49,7 +80,7 @@ def make_ensemble(models, *, targets: List[str] = "auto", inputs: List[str] = "a
ensemble_info = make_ensemble_info(input_classes, target_classes, quiet=quiet)

# Phase 2 build ensemble graph and GraphModule.
ensemble_outputs: List[EnsembleTarget] = construct_outputs(target_classes, prefix=prefix)
ensemble_outputs: Dict[str, EnsembleTarget] = construct_outputs(target_classes, prefix=prefix)
ensemble_inputs: List[_BaseNode] = replace_inputs(input_classes)
merged_inputs: List[_BaseNode] = merge_children_recursive(ensemble_inputs)

Expand All @@ -67,6 +98,17 @@ def make_ensemble(models, *, targets: List[str] = "auto", inputs: List[str] = "a
# TODO ; It seems possible that someone might want to load several models without ensembling them.
def get_graphs(models: Union[List[Union[str, GraphModule, _BaseNode]], str]) -> List[GraphModule]:
"""
Take a simple spec for modeled variables (glob for model directories, list of graphs, list of output nodes)
and convert this into a list of graph modules.
Models can be a list with entries that are one of the following:
- str: directory to use with :func:`~hippynn.experiment.serialization.load_model_from_cwd`
- GraphModule: already built model
- node: an output target, which will be converted to a GraphModule with automatically defined inputs.
or a string, which is used with glob to specify the list of strings.
:param models:
:return:
Expand All @@ -89,6 +131,7 @@ def get_graphs(models: Union[List[Union[str, GraphModule, _BaseNode]], str]) ->
model = load_model_from_cwd(map_location=device)
except FileNotFoundError:
import warnings

warnings.warn(f"Model not found in directory: {model}")
else:
graphs.append(model)
Expand All @@ -106,6 +149,14 @@ def get_graphs(models: Union[List[Union[str, GraphModule, _BaseNode]], str]) ->


def identify_targets(models: List[GraphModule]) -> set[str]:
"""
Internal function for ensembling.
Identify target types to ensemble.
:param models:
:return:
"""

targets: set[str] = set()

Expand All @@ -118,6 +169,14 @@ def identify_targets(models: List[GraphModule]) -> set[str]:


def identify_inputs(models: list[GraphModule]) -> set[str]:
"""
Internal function for ensembling.
Find all required inputs for ensemble.
:param models:
:return:
"""

inputs: set[str] = set()

Expand All @@ -128,8 +187,11 @@ def identify_inputs(models: list[GraphModule]) -> set[str]:
return inputs


def collate_inputs(models: list[GraphModule], inputs: List[str]) -> Dict[str, List[GraphModule]]:
def collate_inputs(models: list[GraphModule], inputs: List[str]) -> Dict[str, List[_BaseNode]]:
"""
Internal function for ensembling.
Identify input classes for ensemble.
:param models:
:param inputs:
Expand All @@ -148,6 +210,15 @@ def collate_inputs(models: list[GraphModule], inputs: List[str]) -> Dict[str, Li


def collate_targets(models: List[GraphModule], targets: List[str]) -> Dict[str, List[_BaseNode]]:
"""
Internal function for ensembling.
Identify targets to ensemble.
:param models:
:param targets:
:return:
"""
target_classes = collections.defaultdict(list)

for m in models:
Expand All @@ -164,7 +235,17 @@ def collate_targets(models: List[GraphModule], targets: List[str]) -> Dict[str,
return target_classes


def make_ensemble_info(input_classes: Dict[str, List[GraphModule]], output_classes: Dict[str, List[GraphModule]], quiet=False):
def make_ensemble_info(input_classes: Dict[str, List[_BaseNode]], output_classes: Dict[str, List[_BaseNode]], quiet=False):
"""
Internal function for ensembling.
Count up and print the ensemble variables identified.
:param input_classes:
:param output_classes:
:param quiet:
:return:
"""

input_info = {k: len(v) for k, v in input_classes.items()}
output_info = {k: len(v) for k, v in output_classes.items()}
Expand All @@ -182,23 +263,45 @@ def make_ensemble_info(input_classes: Dict[str, List[GraphModule]], output_class
return ensemble_info


def construct_outputs(output_classes: Dict[str, List[GraphModule]], prefix: str) -> List[EnsembleTarget]:
ensemble_outputs = {}
def construct_outputs(output_classes: Dict[str, List[_BaseNode]], prefix: str) -> Dict[str, EnsembleTarget]:
"""
Internal function for ensembling.
Build the EnsembleNodes for classes of outputs. Note that
this function attempts to produce consistent index state versions
for both a database-type index state and a reduced index state.
:param output_classes: Dictionary giving list of nodes to ensemble.
:param prefix: name to prepend to each ensembled variable.
:return: dictionary of ensembled node names to nodes.
"""

# To facilitate conversion of index states of ensembled nodes, we will build
# an ensemble target for both the db_form and the reduced form for each node.
# The ensemble will return the db_form when they differ,
# but the index cache will still register the reduced form (when it is different)
# Note!: We want to run these before linking the separate models together,
# because the automation algorithms of hippynn currently handle cases
# where there is a unique type for some nodes in the graph, e.g. one pair indexer
# or one padding indexer.
# Therefore, in the first loop, we generate all index states before any
# of the ensemble nodes have been constructed.

for db_name, parents in sorted(output_classes.items(), key=lambda x: x[0]):
reduced_index_state = get_reduced_index_state(*parents)
db_index_state = db_state_of(reduced_index_state)
db_state_parents = [index_type_coercion(p, db_index_state) for p in parents]
reduced_parents = [index_type_coercion(p, reduced_index_state) for p in parents]

# To facilitate conversion of index states of ensembled nodes, we will build
# an ensemble target for both the db_form and the reduced form for each node.
# The ensemble will return the db_form when they differ,
# but the index cache will still register the reduced form (when it is different)
# Now that the index states for each node have been generated and cached,
# we can loop through again and build the ensembled target nodes.

ensemble_outputs = {}
for db_name, parents in sorted(output_classes.items(), key=lambda x: x[0]):

reduced_index_state = get_reduced_index_state(*parents)
db_index_state = db_state_of(reduced_index_state)

# Note: We want to run these before linking the separate models together,
# because the automation algorithms of hippynn currently handle cases
# where there is a unique type for some nodes in the graph, e.g. one pair indexer
# or one padding indexer.
db_state_parents = [index_type_coercion(p, db_index_state) for p in parents]
reduced_parents = [index_type_coercion(p, reduced_index_state) for p in parents]

Expand All @@ -217,7 +320,16 @@ def construct_outputs(output_classes: Dict[str, List[GraphModule]], prefix: str)
return ensemble_outputs


def replace_inputs(input_classes: Dict[str, List[GraphModule]]) -> List[InputNode]:
def replace_inputs(input_classes: Dict[str, List[_BaseNode]]) -> List[InputNode]:
"""
Internal function for ensembling.
Replace all input nodes with the first representative from the class, thereby
merging inputs.
:param input_classes:
:return: list of new input nodes.
"""

ensemble_inputs = []

Expand All @@ -231,10 +343,18 @@ def replace_inputs(input_classes: Dict[str, List[GraphModule]]) -> List[InputNod
return ensemble_inputs


def make_ensemble_graph(ensemble_inputs: List[InputNode], ensemble_outputs: List[EnsembleTarget]) -> GraphModule:
def make_ensemble_graph(ensemble_inputs: List[InputNode], ensemble_outputs: Dict[str, EnsembleTarget]) -> GraphModule:
"""
Internal function for ensembling.
Put together the ensemble graph.
ensemble_output_list = [c for k,out in ensemble_outputs.items() for c in out.children]
:param ensemble_inputs:
:param ensemble_outputs:
:return:
"""

ensemble_output_list = [c for k, out in ensemble_outputs.items() for c in out.children]
ensemble_graph = GraphModule(ensemble_inputs, ensemble_output_list)

return ensemble_graph

0 comments on commit 2bec5a2

Please sign in to comment.