From 2bec5a22a2b9ec7cf79c08960b89daf50b672473 Mon Sep 17 00:00:00 2001 From: Nicholas Lubbers Date: Wed, 30 Oct 2024 15:39:37 -0600 Subject: [PATCH] improve ensembling robustness and documentation --- hippynn/graphs/ensemble.py | 164 ++++++++++++++++++++++++++++++++----- 1 file changed, 142 insertions(+), 22 deletions(-) diff --git a/hippynn/graphs/ensemble.py b/hippynn/graphs/ensemble.py index c9fc68c3..16e6a778 100644 --- a/hippynn/graphs/ensemble.py +++ b/hippynn/graphs/ensemble.py @@ -16,11 +16,42 @@ from typing import List, Dict, Union, Tuple -def make_ensemble(models, *, targets: List[str] = "auto", inputs: List[str] = "auto", - prefix: str = "ensemble_", quiet=False, - ) -> Tuple[GraphModule, Tuple[Dict[str, int], Dict[str, int]]]: +def make_ensemble( + models, + *, + targets: List[str] = "auto", + inputs: List[str] = "auto", + prefix: str = "ensemble_", + quiet=False, +) -> Tuple[GraphModule, Tuple[Dict[str, int], Dict[str, int]]]: """ + + Make an ensemble out of a set of models. The ensemble graph can then be used with a predictor, + ase graph, etc. + + The selected nodes to ensemble are classed by the db_name associated with the nodes. + + When using "auto" mode for inputs and outputs: + + - The input to the ensemble will be the combined inputs for all models in the ensemble. + - The output of the ensemble will be the combined outputs for all models. + + Otherwise, the set of ensemble inputs is explicitly specified, and errors + may occur if the requested set of inputs and outputs is not available. + + The result ensemble graph has several outputs, each of which has .mean, .std, and .all + attributes which reflect the statistics of the models in the ensemble. + + Note that it is not required that all models have the same sets of inputs and outputs. + + If a desired node is not automatically ensembled, it probably does not have a db_name. 
+ A remedy for this is to load the graphs with hippynn.graphs.ensemble.get_graphs, then + find the requested nodes in the graphs and assign them the db_name. Then pass these + graphs to make_ensemble. + + For more information on the `models` parameter, see the :func:`~hippynn.graphs.ensemble.get_graphs` function. + :param models: list containing str, node, or graphmodule, or str to glob for model directories. :param targets: list of db_name strings or the string 'auto', which will attempt to infer. :param inputs: list of db_name strings of the string 'auto', which will attempt to infer. @@ -34,12 +65,12 @@ def make_ensemble(models, *, targets: List[str] = "auto", inputs: List[str] = "a # Phase 1: Figure out what the ensemble will look like. if inputs == "auto": - inputs = identify_inputs(graphs) + inputs: set[str] = identify_inputs(graphs) if not quiet: print("Identified input quantities:", inputs) if targets == "auto": - targets = identify_targets(graphs) + targets: set[str] = identify_targets(graphs) if not quiet: print("Identified output quantities:", targets) @@ -49,7 +80,7 @@ def make_ensemble(models, *, targets: List[str] = "auto", inputs: List[str] = "a ensemble_info = make_ensemble_info(input_classes, target_classes, quiet=quiet) # Phase 2 build ensemble graph and GraphModule. - ensemble_outputs: List[EnsembleTarget] = construct_outputs(target_classes, prefix=prefix) + ensemble_outputs: Dict[str, EnsembleTarget] = construct_outputs(target_classes, prefix=prefix) ensemble_inputs: List[_BaseNode] = replace_inputs(input_classes) merged_inputs: List[_BaseNode] = merge_children_recursive(ensemble_inputs) @@ -67,6 +98,17 @@ def make_ensemble(models, *, targets: List[str] = "auto", inputs: List[str] = "a # TODO ; It seems possible that someone might want to load several models without ensembling them. 
def get_graphs(models: Union[List[Union[str, GraphModule, _BaseNode]], str]) -> List[GraphModule]: """ + Take a simple spec for modeled variables (glob for model directories, list of graphs, list of output nodes) + and convert this into a list of graph modules. + + + Models can be a list with entries that are one of the following: + + - str: directory to use with :func:`~hippynn.experiment.serialization.load_model_from_cwd` + - GraphModule: already built model + - node: an output target, which will be converted to a GraphModule with automatically defined inputs. + + or a string, which is used with glob to specify the list of strings. :param models: :return: @@ -89,6 +131,7 @@ def get_graphs(models: Union[List[Union[str, GraphModule, _BaseNode]], str]) -> model = load_model_from_cwd(map_location=device) except FileNotFoundError: import warnings + warnings.warn(f"Model not found in directory: {model}") else: graphs.append(model) @@ -106,6 +149,14 @@ def get_graphs(models: Union[List[Union[str, GraphModule, _BaseNode]], str]) -> def identify_targets(models: List[GraphModule]) -> set[str]: + """ + Internal function for ensembling. + + Identify targets types to ensemble. + + :param models: + :return: + """ targets: set[str] = set() @@ -118,6 +169,14 @@ def identify_targets(models: List[GraphModule]) -> set[str]: def identify_inputs(models: list[GraphModule]) -> set[str]: + """ + Internal function for ensembling. + + Find all required inputs for ensemble. + + :param models: + :return: + """ inputs: set[str] = set() @@ -128,8 +187,11 @@ def identify_inputs(models: list[GraphModule]) -> set[str]: return inputs -def collate_inputs(models: list[GraphModule], inputs: List[str]) -> Dict[str, List[GraphModule]]: +def collate_inputs(models: list[GraphModule], inputs: List[str]) -> Dict[str, List[_BaseNode]]: """ + Internal function for ensembling. + + Identify input classes for ensemble. 
+ :param models: :param inputs: @@ -148,6 +210,15 @@ def collate_inputs(models: list[GraphModule], inputs: List[str]) -> Dict[str, Li def collate_targets(models: List[GraphModule], targets: List[str]) -> Dict[str, List[_BaseNode]]: + """ + Internal function for ensembling. + + Identify targets to ensemble. + + :param models: + :param targets: + :return: + """ target_classes = collections.defaultdict(list) for m in models: @@ -164,7 +235,17 @@ def collate_targets(models: List[GraphModule], targets: List[str]) -> Dict[str, return target_classes -def make_ensemble_info(input_classes: Dict[str, List[GraphModule]], output_classes: Dict[str, List[GraphModule]], quiet=False): +def make_ensemble_info(input_classes: Dict[str, List[_BaseNode]], output_classes: Dict[str, List[_BaseNode]], quiet=False): + """ + Internal function for ensembling. + + Count up and print the ensemble variables identified. + + :param input_classes: + :param output_classes: + :param quiet: + :return: + """ input_info = {k: len(v) for k, v in input_classes.items()} output_info = {k: len(v) for k, v in output_classes.items()} @@ -182,23 +263,45 @@ def make_ensemble_info(input_classes: Dict[str, List[GraphModule]], output_class return ensemble_info -def construct_outputs(output_classes: Dict[str, List[GraphModule]], prefix: str) -> List[EnsembleTarget]: - ensemble_outputs = {} +def construct_outputs(output_classes: Dict[str, List[_BaseNode]], prefix: str) -> Dict[str, EnsembleTarget]: + """ + Internal function for ensembling. + + Build the EnsembleNodes for classes of outputs. Note that + this function attempts to produce consistent index state versions + for both a database-type index state and a reduced index state. + + :param output_classes: Dictionary giving list of nodes to ensemble. + :param prefix: name to prepend to each ensembled variable. + :return: dictionary of ensembled node names to nodes. 
+ """ + + # To facilitate conversion of index states of ensembled nodes, we will build + # an ensemble target for both the db_form and the reduced form for each node. + # The ensemble will return the db_form when they differ, + # but the index cache will still register the reduced form (when it is different) + # Note!: We want to run these before linking the separate models together, + # because the automation algorithms of hippynn currently handle cases + # where there is a unique type for some nodes in the graph, e.g. one pair indexer + # or one padding indexer. + # Therefore, in the first loop, we generate all index states before any + # of the ensemble nodes have been constructed. for db_name, parents in sorted(output_classes.items(), key=lambda x: x[0]): + reduced_index_state = get_reduced_index_state(*parents) + db_index_state = db_state_of(reduced_index_state) + db_state_parents = [index_type_coercion(p, db_index_state) for p in parents] + reduced_parents = [index_type_coercion(p, reduced_index_state) for p in parents] - # To facilitate conversion of index states of ensembled nodes, we will build - # an ensemble target for both the db_form and the reduced form for each node. - # The ensemble will return the db_form when they differ, - # but the index cache will still register the reduced form (when it is different) + # Now that the index states for each node have been generated and cached, + # we can loop through again and build the ensembled target nodes. + + ensemble_outputs = {} + for db_name, parents in sorted(output_classes.items(), key=lambda x: x[0]): reduced_index_state = get_reduced_index_state(*parents) db_index_state = db_state_of(reduced_index_state) - # Note: We want to run these before linking the separate models together, - # because the automation algorithms of hippynn currently handle cases - # where there is a unique type for some nodes in the graph, e.g. one pair indexer - # or one padding indexer. 
db_state_parents = [index_type_coercion(p, db_index_state) for p in parents] reduced_parents = [index_type_coercion(p, reduced_index_state) for p in parents] @@ -217,7 +320,16 @@ def construct_outputs(output_classes: Dict[str, List[GraphModule]], prefix: str) return ensemble_outputs -def replace_inputs(input_classes: Dict[str, List[GraphModule]]) -> List[InputNode]: +def replace_inputs(input_classes: Dict[str, List[_BaseNode]]) -> List[InputNode]: + """ + Internal function for ensembling. + + Replace all input nodes with the first representative from the class, thereby + merging inputs. + + :param input_classes: + :return: list of new input nodes. + """ ensemble_inputs = [] @@ -231,10 +343,18 @@ def replace_inputs(input_classes: Dict[str, List[GraphModule]]) -> List[InputNod return ensemble_inputs -def make_ensemble_graph(ensemble_inputs: List[InputNode], ensemble_outputs: List[EnsembleTarget]) -> GraphModule: +def make_ensemble_graph(ensemble_inputs: List[InputNode], ensemble_outputs: Dict[str, EnsembleTarget]) -> GraphModule: + """ + Internal function for ensembling. + + Put together the ensemble graph. - ensemble_output_list = [c for k,out in ensemble_outputs.items() for c in out.children] + :param ensemble_inputs: + :param ensemble_outputs: + :return: + """ + + ensemble_output_list = [c for k, out in ensemble_outputs.items() for c in out.children] ensemble_graph = GraphModule(ensemble_inputs, ensemble_output_list) return ensemble_graph -