diff --git a/experiments/interpretation_metrics_test.py b/experiments/interpretation_metrics_test.py
index 8b01d9e..4886299 100644
--- a/experiments/interpretation_metrics_test.py
+++ b/experiments/interpretation_metrics_test.py
@@ -1,29 +1,162 @@
+import json
+import os
 import random
 import warnings
 
 import torch
 
 from aux.custom_decorators import timing_decorator
-from aux.utils import EXPLAINERS_LOCAL_RUN_PARAMETERS_PATH, EXPLAINERS_INIT_PARAMETERS_PATH
+from aux.utils import EXPLAINERS_LOCAL_RUN_PARAMETERS_PATH, EXPLAINERS_INIT_PARAMETERS_PATH, root_dir, \
+    EVASION_DEFENSE_PARAMETERS_PATH, EVASION_ATTACK_PARAMETERS_PATH, POISON_ATTACK_PARAMETERS_PATH
 from explainers.explainers_manager import FrameworkExplainersManager
+from models_builder.gnn_constructor import FrameworkGNNConstructor
 from models_builder.gnn_models import FrameworkGNNModelManager, Metric
-from src.aux.configs import ModelModificationConfig, ConfigPattern
+from src.aux.configs import ModelModificationConfig, ConfigPattern, ModelConfig
+from src.aux.utils import POISON_DEFENSE_PARAMETERS_PATH
 from src.base.datasets_processing import DatasetManager
 from src.models_builder.models_zoo import model_configs_zoo
+from defense.JaccardDefense import jaccard_def
+from attacks.metattack import meta_gradient_attack
+from defense.GNNGuard import gnnguard
+
+
+def load_result_dict(path):
+    if os.path.exists(path):
+        with open(path, "r") as file:
+            try:
+                data = json.load(file)
+            except json.JSONDecodeError:
+                data = {}
+    else:
+        data = {}
+    return data
+
+
+def save_result_dict(path, data):
+    os.makedirs(os.path.dirname(path), exist_ok=True)
+    with open(path, "w") as file:
+        json.dump(data, file)
+
+
+def get_model_by_name(model_name, dataset):
+    return model_configs_zoo(dataset=dataset, model_name=model_name)
+
+
+def explainer_run_config_for_node(explainer_name, node_ind, explainer_kwargs=None):
+    if explainer_kwargs is None:
+        explainer_kwargs = {}
+    # Copy so that per-node configs do not share (and repeatedly overwrite
+    # "element_idx" in) a single kwargs dict.
+    explainer_kwargs = dict(explainer_kwargs)
+    explainer_kwargs["element_idx"] = node_ind
+    return ConfigPattern(
+        _config_class="ExplainerRunConfig",
+        _config_kwargs={
+            "mode": "local",
+            "kwargs": {
+                "_class_name": explainer_name,
+                "_import_path": EXPLAINERS_LOCAL_RUN_PARAMETERS_PATH,
+                "_config_class": "Config",
+                "_config_kwargs": explainer_kwargs
+            }
+        }
+    )
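+
+# Illustrative shape of the checkpoint JSON handled by load_result_dict /
+# save_result_dict (values made up for orientation; per-experiment metric
+# entries are appended by run_interpretation_test as each one completes):
+# {
+#     "num_nodes": 1,
+#     "nodes": [633],
+#     "metrics_params": {"stability_graph_perturbations_nums": 1, ...},
+#     "explainer_kwargs": {},
+#     "Unprotected": {"sparsity": {"mean": 0.9, "var": 0.0, "data": [0.9]}, ...}
+# }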
+
+
 @timing_decorator
-def run_interpretation_test():
-    full_name = ("single-graph", "Planetoid", 'Cora')
+def run_interpretation_test(explainer_name, dataset_full_name, model_name):
     steps_epochs = 10
-    save_model_flag = False
-    my_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    num_explaining_nodes = 1
+    explaining_metrics_params = {
+        "stability_graph_perturbations_nums": 1,
+        "stability_feature_change_percent": 0.05,
+        "stability_node_removal_percent": 0.05,
+        "consistency_num_explanation_runs": 1,
+    }
+    # steps_epochs = 200
+    # num_explaining_nodes = 30
+    # explaining_metrics_params = {
+    #     "stability_graph_perturbations_nums": 5,
+    #     "stability_feature_change_percent": 0.05,
+    #     "stability_node_removal_percent": 0.05,
+    #     "consistency_num_explanation_runs": 5
+    # }
+    explainer_kwargs_by_explainer_name = {
+        'GNNExplainer(torch-geom)': {},
+        'SubgraphX': {"max_nodes": 5},
+        'Zorro': {},
+    }
+    dataset_key_name = "_".join(dataset_full_name)
+    metrics_path = root_dir / "experiments" / "explainers_metrics"
+    dataset_metrics_path = metrics_path / f"{model_name}_{dataset_key_name}_{explainer_name}_metrics.json"
 
     dataset, data, results_dataset_path = DatasetManager.get_by_full_name(
-        full_name=full_name,
+        full_name=dataset_full_name,
         dataset_ver_ind=0
     )
 
-    gnn = model_configs_zoo(dataset=dataset, model_name='gcn_gcn')
+    explainer_kwargs = explainer_kwargs_by_explainer_name[explainer_name]
+
+    restart_experiment = False
+    if restart_experiment:
+        node_indices = random.sample(range(dataset.data.x.shape[0]), num_explaining_nodes)
+        result_dict = {
+            "num_nodes": num_explaining_nodes,
+            "nodes": list(node_indices),
+            "metrics_params": explaining_metrics_params,
+            "explainer_kwargs": explainer_kwargs
+        }
+        # save_result_dict(dataset_metrics_path, result_dict)
+    else:
+        result_dict = load_result_dict(dataset_metrics_path)
+        if "nodes" not in result_dict:
+            node_indices = random.sample(range(dataset.data.x.shape[0]), num_explaining_nodes)
+            result_dict["nodes"] = list(node_indices)
+            result_dict["metrics_params"] = explaining_metrics_params
+            result_dict["num_nodes"] = num_explaining_nodes
+            result_dict["explainer_kwargs"] = explainer_kwargs
+            save_result_dict(dataset_metrics_path, result_dict)
+        node_indices = result_dict["nodes"]
+        explaining_metrics_params = result_dict["metrics_params"]
+
+    node_id_to_explainer_run_config = \
+        {node_id: explainer_run_config_for_node(explainer_name, node_id, explainer_kwargs) for node_id in node_indices}
+
+    experiment_name_to_experiment = [
+        ("Unprotected", calculate_unprotected_metrics),
+        ("Jaccard_defence", calculate_jaccard_defence_metrics),
+        ("AdvTraining_defence", calculate_adversarial_defence_metrics),
+        ("GNNGuard_defence", calculate_gnnguard_defence_metrics),
+    ]
+
+    for experiment_name, calculate_fn in experiment_name_to_experiment:
+        if experiment_name not in result_dict:
+            print(f"Calculation of explanation metrics with defence: {experiment_name} started.")
+            explaining_metrics_params["experiment_name"] = experiment_name
+            metrics = calculate_fn(
+                explainer_name,
+                steps_epochs,
+                explaining_metrics_params,
+                dataset,
+                node_id_to_explainer_run_config,
+                model_name
+            )
+            result_dict[experiment_name] = metrics
+            print(f"Calculation of explanation metrics with defence: {experiment_name} completed. Metrics:\n{metrics}")
+            save_result_dict(dataset_metrics_path, result_dict)
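+
+# Each calculate_*_metrics function below follows the same pipeline, differing
+# only in the defender attached between model construction and training:
+#   1. build the GNN and its FrameworkGNNModelManager;
+#   2. optionally attach a poison/evasion defender;
+#   3. load a saved model executor, or train and save one;
+#   4. report test F1, then build the explainer and run evaluate_metrics
+#      over the per-node run configs.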
Metrics:\n{metrics}") + save_result_dict(dataset_metrics_path, result_dict) + + +@timing_decorator +def calculate_unprotected_metrics( + explainer_name, + steps_epochs, + explaining_metrics_params, + dataset, + node_id_to_explainer_run_config, + model_name +): + save_model_flag = True + device = torch.device('cpu') + + data, results_dataset_path = dataset.data, dataset.results_dir + manager_config = ConfigPattern( _config_class="ModelManagerConfig", _config_kwargs={ @@ -37,19 +170,25 @@ def run_interpretation_test(): } } ) + + gnn = get_model_by_name(model_name, dataset) + gnn_model_manager = FrameworkGNNModelManager( gnn=gnn, dataset_path=results_dataset_path, manager_config=manager_config, modification=ModelModificationConfig(model_ver_ind=0, epochs=steps_epochs) ) - gnn_model_manager.gnn.to(my_device) + gnn_model_manager.gnn.to(device) data.x = data.x.float() - data = data.to(my_device) + # data.y = data.y.float() + data = data.to(device) warnings.warn("Start training") try: - raise FileNotFoundError() + print("Loading model executor") + gnn_model_manager.load_model_executor() + print("Loaded model") except FileNotFoundError: gnn_model_manager.epochs = gnn_model_manager.modification.epochs = 0 train_test_split_path = gnn_model_manager.train_model(gen_dataset=dataset, steps=steps_epochs, @@ -69,51 +208,323 @@ def run_interpretation_test(): print(metric_loc) explainer_init_config = ConfigPattern( - _class_name="GNNExplainer(torch-geom)", + _class_name=explainer_name, _import_path=EXPLAINERS_INIT_PARAMETERS_PATH, _config_class="ExplainerInitConfig", + _config_kwargs={} + ) + + explainer = FrameworkExplainersManager( + init_config=explainer_init_config, + dataset=dataset, gnn_manager=gnn_model_manager, + explainer_name=explainer_name, + ) + + explanation_metrics = explainer.evaluate_metrics(node_id_to_explainer_run_config, explaining_metrics_params) + return explanation_metrics + + +@timing_decorator +def calculate_jaccard_defence_metrics( + explainer_name, + steps_epochs, + explaining_metrics_params, + dataset, + node_id_to_explainer_run_config, + model_name +): + save_model_flag = True + device = torch.device('cpu') + + data, results_dataset_path = dataset.data, dataset.results_dir + + gnn = get_model_by_name(model_name, dataset) + manager_config = ConfigPattern( + _config_class="ModelManagerConfig", _config_kwargs={ - "epochs": 10 + "mask_features": [], + "optimizer": { + # "_config_class": "Config", + "_class_name": "Adam", + # "_import_path": OPTIMIZERS_PARAMETERS_PATH, + # "_class_import_info": ["torch.optim"], + "_config_kwargs": {}, + } } ) - explainer_metrics_run_config = ConfigPattern( - _config_class="ExplainerRunConfig", + gnn_model_manager = FrameworkGNNModelManager( + gnn=gnn, + dataset_path=results_dataset_path, + manager_config=manager_config, + modification=ModelModificationConfig(model_ver_ind=0, epochs=steps_epochs) + ) + gnn_model_manager.gnn.to(device) + data.x = data.x.float() + data = data.to(device) + + poison_defense_config = ConfigPattern( + _class_name="JaccardDefender", + _import_path=POISON_DEFENSE_PARAMETERS_PATH, + _config_class="PoisonDefenseConfig", _config_kwargs={ - "mode": "local", - "kwargs": { - "_class_name": "GNNExplainer(torch-geom)", - "_import_path": EXPLAINERS_LOCAL_RUN_PARAMETERS_PATH, - "_config_class": "Config", - "_config_kwargs": { - "stability_graph_perturbations_nums": 10, - "stability_feature_change_percent": 0.05, - "stability_node_removal_percent": 0.05, - "consistency_num_explanation_runs": 10 - }, + } + ) + + 
+
+
+@timing_decorator
+def calculate_jaccard_defence_metrics(
+        explainer_name,
+        steps_epochs,
+        explaining_metrics_params,
+        dataset,
+        node_id_to_explainer_run_config,
+        model_name
+):
+    save_model_flag = True
+    device = torch.device('cpu')
+
+    data, results_dataset_path = dataset.data, dataset.results_dir
+
+    gnn = get_model_by_name(model_name, dataset)
+    manager_config = ConfigPattern(
+        _config_class="ModelManagerConfig",
         _config_kwargs={
-            "epochs": 10
+            "mask_features": [],
+            "optimizer": {
+                # "_config_class": "Config",
+                "_class_name": "Adam",
+                # "_import_path": OPTIMIZERS_PARAMETERS_PATH,
+                # "_class_import_info": ["torch.optim"],
+                "_config_kwargs": {},
+            }
         }
     )
-    explainer_metrics_run_config = ConfigPattern(
-        _config_class="ExplainerRunConfig",
+    gnn_model_manager = FrameworkGNNModelManager(
+        gnn=gnn,
+        dataset_path=results_dataset_path,
+        manager_config=manager_config,
+        modification=ModelModificationConfig(model_ver_ind=0, epochs=steps_epochs)
+    )
+    gnn_model_manager.gnn.to(device)
+    data.x = data.x.float()
+    data = data.to(device)
+
+    poison_defense_config = ConfigPattern(
+        _class_name="JaccardDefender",
+        _import_path=POISON_DEFENSE_PARAMETERS_PATH,
+        _config_class="PoisonDefenseConfig",
         _config_kwargs={
-            "mode": "local",
-            "kwargs": {
-                "_class_name": "GNNExplainer(torch-geom)",
-                "_import_path": EXPLAINERS_LOCAL_RUN_PARAMETERS_PATH,
-                "_config_class": "Config",
-                "_config_kwargs": {
-                    "stability_graph_perturbations_nums": 10,
-                    "stability_feature_change_percent": 0.05,
-                    "stability_node_removal_percent": 0.05,
-                    "consistency_num_explanation_runs": 10
-                },
+        }
+    )
+
+    gnn_model_manager.set_poison_defender(poison_defense_config=poison_defense_config)
+    warnings.warn("Start training")
+    try:
+        print("Loading model executor")
+        gnn_model_manager.load_model_executor()
+        print("Loaded model")
+    except FileNotFoundError:
+        print("Training started.")
+        gnn_model_manager.epochs = gnn_model_manager.modification.epochs = 0
+        train_test_split_path = gnn_model_manager.train_model(gen_dataset=dataset, steps=steps_epochs,
+                                                              save_model_flag=save_model_flag,
+                                                              metrics=[Metric("F1", mask='train', average=None)])
+
+        if train_test_split_path is not None:
+            dataset.save_train_test_mask(train_test_split_path)
+            train_mask, val_mask, test_mask, train_test_sizes = torch.load(train_test_split_path / 'train_test_split')[:]
+            dataset.train_mask, dataset.val_mask, dataset.test_mask = train_mask, val_mask, test_mask
+            data.percent_train_class, data.percent_test_class = train_test_sizes
+        warnings.warn("Training was successful")
+
+    metric_loc = gnn_model_manager.evaluate_model(
+        gen_dataset=dataset, metrics=[Metric("F1", mask='test', average='macro')])
+    print(metric_loc)
+
+    explainer_init_config = ConfigPattern(
+        _class_name=explainer_name,
+        _import_path=EXPLAINERS_INIT_PARAMETERS_PATH,
+        _config_class="ExplainerInitConfig",
+        _config_kwargs={}
+    )
+
+    explainer = FrameworkExplainersManager(
+        init_config=explainer_init_config,
+        dataset=dataset, gnn_manager=gnn_model_manager,
+        explainer_name=explainer_name,
+    )
+
+    explanation_metrics = explainer.evaluate_metrics(node_id_to_explainer_run_config, explaining_metrics_params)
+    return explanation_metrics
+
+
+@timing_decorator
+def calculate_adversarial_defence_metrics(
+        explainer_name,
+        steps_epochs,
+        explaining_metrics_params,
+        dataset,
+        node_id_to_explainer_run_config,
+        model_name
+):
+    save_model_flag = True
+    device = torch.device('cpu')
+
+    data, results_dataset_path = dataset.data, dataset.results_dir
+
+    gnn = get_model_by_name(model_name, dataset)
+    manager_config = ConfigPattern(
+        _config_class="ModelManagerConfig",
+        _config_kwargs={
+            "mask_features": [],
+            "optimizer": {
+                # "_config_class": "Config",
+                "_class_name": "Adam",
+                # "_import_path": OPTIMIZERS_PARAMETERS_PATH,
+                # "_class_import_info": ["torch.optim"],
+                "_config_kwargs": {},
+            }
+        }
+    )
+    gnn_model_manager = FrameworkGNNModelManager(
+        gnn=gnn,
+        dataset_path=results_dataset_path,
+        manager_config=manager_config,
+        modification=ModelModificationConfig(model_ver_ind=0, epochs=steps_epochs)
+    )
+    gnn_model_manager.gnn.to(device)
+    data.x = data.x.float()
+    data = data.to(device)
 
-    explainer_GNNExpl = FrameworkExplainersManager(
+    fgsm_evasion_attack_config0 = ConfigPattern(
+        _class_name="FGSM",
+        _import_path=EVASION_ATTACK_PARAMETERS_PATH,
+        _config_class="EvasionAttackConfig",
+        _config_kwargs={
+            "epsilon": 0.1 * 1,
+        }
+    )
+    at_evasion_defense_config = ConfigPattern(
+        _class_name="AdvTraining",
+        _import_path=EVASION_DEFENSE_PARAMETERS_PATH,
+        _config_class="EvasionDefenseConfig",
+        _config_kwargs={
+            "attack_name": None,
+            "attack_config": fgsm_evasion_attack_config0  # evasion_attack_config
+        }
+    )
+
+    # Debug aid: list the evasion defenders registered in the framework.
+    from defense.evasion_defense import EvasionDefender
+    from src.aux.utils import all_subclasses
+    print([e.name for e in all_subclasses(EvasionDefender)])
+    gnn_model_manager.set_evasion_defender(evasion_defense_config=at_evasion_defense_config)
+
+    warnings.warn("Start training")
+    try:
+        print("Loading model executor")
+        gnn_model_manager.load_model_executor()
+        print("Loaded model")
+    except FileNotFoundError:
+        gnn_model_manager.epochs = gnn_model_manager.modification.epochs = 0
+        train_test_split_path = gnn_model_manager.train_model(gen_dataset=dataset, steps=steps_epochs,
+                                                              save_model_flag=save_model_flag,
+                                                              metrics=[Metric("F1", mask='train', average=None)])
+
+        if train_test_split_path is not None:
+            dataset.save_train_test_mask(train_test_split_path)
+            train_mask, val_mask, test_mask, train_test_sizes = torch.load(train_test_split_path / 'train_test_split')[:]
+            dataset.train_mask, dataset.val_mask, dataset.test_mask = train_mask, val_mask, test_mask
+            data.percent_train_class, data.percent_test_class = train_test_sizes
+        warnings.warn("Training was successful")
+
+    metric_loc = gnn_model_manager.evaluate_model(
+        gen_dataset=dataset, metrics=[Metric("F1", mask='test', average='macro')])
+    print(metric_loc)
+
+    explainer_init_config = ConfigPattern(
+        _class_name=explainer_name,
+        _import_path=EXPLAINERS_INIT_PARAMETERS_PATH,
+        _config_class="ExplainerInitConfig",
+        _config_kwargs={}
+    )
+
+    explainer = FrameworkExplainersManager(
         init_config=explainer_init_config,
         dataset=dataset, gnn_manager=gnn_model_manager,
-        explainer_name='GNNExplainer(torch-geom)',
+        explainer_name=explainer_name,
+    )
+
+    explanation_metrics = explainer.evaluate_metrics(node_id_to_explainer_run_config, explaining_metrics_params)
+    return explanation_metrics
+
+
+@timing_decorator
+def calculate_gnnguard_defence_metrics(
+        explainer_name,
+        steps_epochs,
+        explaining_metrics_params,
+        dataset,
+        node_id_to_explainer_run_config,
+        model_name
+):
+    save_model_flag = True
+    device = torch.device('cpu')
+
+    data, results_dataset_path = dataset.data, dataset.results_dir
+
+    gnn = get_model_by_name(model_name, dataset)
+    manager_config = ConfigPattern(
+        _config_class="ModelManagerConfig",
+        _config_kwargs={
+            "mask_features": [],
+            "optimizer": {
+                # "_config_class": "Config",
+                "_class_name": "Adam",
+                # "_import_path": OPTIMIZERS_PARAMETERS_PATH,
+                # "_class_import_info": ["torch.optim"],
+                "_config_kwargs": {},
+            }
+        }
+    )
+    gnn_model_manager = FrameworkGNNModelManager(
+        gnn=gnn,
+        dataset_path=results_dataset_path,
+        manager_config=manager_config,
+        modification=ModelModificationConfig(model_ver_ind=0, epochs=steps_epochs)
+    )
+    gnn_model_manager.gnn.to(device)
+    data.x = data.x.float()
+    data = data.to(device)
 
-    num_explaining_nodes = 10
-    node_indices = random.sample(range(dataset.data.x.shape[0]), num_explaining_nodes)
+    gnnguard_poison_defense_config = ConfigPattern(
+        _class_name="GNNGuard",
+        _import_path=POISON_DEFENSE_PARAMETERS_PATH,
+        _config_class="PoisonDefenseConfig",
+        _config_kwargs={
+            "lr": 0.01,
+            "train_iters": 100,
+            # "model": gnn_model_manager.gnn
+        }
+    )
 
-    # explainer_GNNExpl.explainer.pbar = ProgressBar(socket, "er", desc=f'{explainer_GNNExpl.explainer.name} explaining')
-    # explanation_metric = NodesExplainerMetric(
-    #     model=explainer_GNNExpl.gnn,
-    #     graph=explainer_GNNExpl.gen_dataset.data,
-    #     explainer=explainer_GNNExpl.explainer
-    # )
-    # res = explanation_metric.evaluate(node_indices)
-    explanation_metrics = explainer_GNNExpl.evaluate_metrics(node_indices, explainer_metrics_run_config)
-    print(explanation_metrics)
+    gnn_model_manager.set_poison_defender(poison_defense_config=gnnguard_poison_defense_config)
+
+    warnings.warn("Start training")
+    try:
+        print("Loading model executor")
+        gnn_model_manager.load_model_executor()
+        print("Loaded model")
+    except FileNotFoundError:
+        gnn_model_manager.epochs = gnn_model_manager.modification.epochs = 0
+        train_test_split_path = gnn_model_manager.train_model(gen_dataset=dataset, steps=steps_epochs,
+                                                              save_model_flag=save_model_flag,
+                                                              metrics=[Metric("F1", mask='train', average=None)])
+
+        if train_test_split_path is not None:
+            dataset.save_train_test_mask(train_test_split_path)
+            train_mask, val_mask, test_mask, train_test_sizes = torch.load(train_test_split_path / 'train_test_split')[:]
+            dataset.train_mask, dataset.val_mask, dataset.test_mask = train_mask, val_mask, test_mask
+            data.percent_train_class, data.percent_test_class = train_test_sizes
+        warnings.warn("Training was successful")
+
+    metric_loc = gnn_model_manager.evaluate_model(
+        gen_dataset=dataset, metrics=[Metric("F1", mask='test', average='macro')])
+    print(metric_loc)
+
+    explainer_init_config = ConfigPattern(
+        _class_name=explainer_name,
+        _import_path=EXPLAINERS_INIT_PARAMETERS_PATH,
+        _config_class="ExplainerInitConfig",
+        _config_kwargs={}
+    )
+
+    explainer = FrameworkExplainersManager(
+        init_config=explainer_init_config,
+        dataset=dataset, gnn_manager=gnn_model_manager,
+        explainer_name=explainer_name,
+    )
+
+    explanation_metrics = explainer.evaluate_metrics(node_id_to_explainer_run_config, explaining_metrics_params)
+    return explanation_metrics
 
 
 if __name__ == '__main__':
-    random.seed(11)
-    run_interpretation_test()
+    # random.seed(777)
+
+    explainers = [
+        # 'GNNExplainer(torch-geom)',
+        # 'SubgraphX',
+        "Zorro",
+    ]
+
+    models = [
+        # 'gcn_gcn',
+        'gat_gat',
+        # 'sage_sage',
+    ]
+    datasets = [
+        ("single-graph", "Planetoid", 'Cora'),
+        # ("single-graph", "Amazon", 'Photo'),
+    ]
+    for explainer in explainers:
+        for dataset_full_name in datasets:
+            for model_name in models:
+                run_interpretation_test(explainer, dataset_full_name, model_name)
+    # dataset_full_name = ("single-graph", "Amazon", 'Photo')
+    # run_interpretation_test(dataset_full_name)
diff --git a/src/explainers/GNNExplainer/torch_geom_our/out.py b/src/explainers/GNNExplainer/torch_geom_our/out.py
index 5595d5a..3af5e76 100644
--- a/src/explainers/GNNExplainer/torch_geom_our/out.py
+++ b/src/explainers/GNNExplainer/torch_geom_our/out.py
@@ -102,17 +102,6 @@ def run(self, mode, kwargs, finalize=True):
             self.raw_explanation = self.explainer(self.x, self.edge_index, index=self.node_idx)
             self.pbar.close()
 
-    @finalize_decorator
-    def evaluate_tensor_graph(self, x, edge_index, node_idx, **kwargs):
-        self._run_mode = "local"
-        self.node_idx = node_idx
-        self.x = x
-        self.edge_index = edge_index
-        self.pbar.reset(total=self.epochs, mode=self._run_mode)
-        self.explainer.algorithm.pbar = self.pbar
-        self.raw_explanation = self.explainer(self.x, self.edge_index, index=self.node_idx, **kwargs)
-        self.pbar.close()
-
     def _finalize(self):
         mode = self._run_mode
         assert mode == "local"
diff --git a/src/explainers/explainer_metrics.py b/src/explainers/explainer_metrics.py
index 40d127b..36f4118 100644
--- a/src/explainers/explainer_metrics.py
+++ b/src/explainers/explainer_metrics.py
@@ -1,21 +1,29 @@
-from typing import Type
+import copy
+import json
+import os
+from pathlib import Path
+from typing import Union, Type
 
 import numpy as np
 import torch
-from torch_geometric.utils import subgraph
+from torch_geometric.utils import subgraph, k_hop_subgraph
+
+from aux.configs import ConfigPattern, ExplainerRunConfig
+from aux.custom_decorators import timing_decorator
+from base.datasets_processing import GeneralDataset
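+
+# Note: NodesExplainerMetric below persists intermediate results (perturbed
+# explanations and per-metric values) to explanation_metrics.json next to the
+# explainer output, so an interrupted run can resume without recomputing
+# explanations that are already finished.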
 
 
 class NodesExplainerMetric:
-    def __init__(
-            self,
-            model: Type,
-            graph,
-            explainer,
-            kwargs_dict: dict
-    ):
-        self.model = model
-        self.explainer = explainer
-        self.graph = graph
+    def __init__(self, explainers_manager: Type, explaining_metrics_params=None):
+        self.node_id_to_explainer_run_config = None
+        self.explanation_metrics_path = None
+        if explaining_metrics_params is None:
+            explaining_metrics_params = {}
+        self.explainers_manager = explainers_manager
+        self.model = explainers_manager.gnn
+        self.gen_dataset = explainers_manager.gen_dataset
+        self.explainer = explainers_manager.explainer
+        self.graph = explainers_manager.gen_dataset.data
         self.x = self.graph.x
         self.edge_index = self.graph.edge_index
         self.kwargs_dict = {
@@ -24,133 +32,189 @@ def __init__(
             "stability_node_removal_percent": 0.05,
             "consistency_num_explanation_runs": 10
         }
-        self.kwargs_dict.update(kwargs_dict)
-        self.nodes_explanations = {}  # explanations cache. node_ind -> explanation
+        self.kwargs_dict.update(explaining_metrics_params)
+        self.dictionary = {
+            "explaining_metrics_params": self.kwargs_dict,
+            "perturbed_explanations": {}
         }
+        print(f"NodesExplainerMetric initialized with kwargs:\n{self.kwargs_dict}")
 
-    def evaluate(
-            self,
-            target_nodes_indices: list
-    ) -> dict:
-        num_targets = len(target_nodes_indices)
-        sparsity = 0
-        stability = 0
-        consistency = 0
+    def get_explanation_path(self, run_config: Union[ConfigPattern, ExplainerRunConfig]) -> Path:
+        self.explainers_manager.explanation_result_path(run_config)
+        explainer_result_file_path, files_paths = \
+            self.explainers_manager.explainer_result_file_path, self.explainers_manager.files_paths
+        self.explanation_metrics_path = files_paths[0].parent / Path('explanation_metrics.json')
+        return explainer_result_file_path
+
+    def save_dictionary(self) -> None:
+        with open(self.explanation_metrics_path, "w") as f:
+            json.dump(self.dictionary, f, indent=2)
+
+    def evaluate(self, node_id_to_explainer_run_config: dict) -> dict:
+        self.node_id_to_explainer_run_config = node_id_to_explainer_run_config
+        target_nodes_indices = sorted(node_id_to_explainer_run_config.keys())
+
+        # This first call also resolves self.explanation_metrics_path, which the
+        # cache reload below and save_dictionary() rely on.
+        self.get_explanations(target_nodes_indices[0])
+        if os.path.exists(self.explanation_metrics_path):
+            with open(self.explanation_metrics_path, "r") as f:
+                self.dictionary = json.load(f)
+
+        sparsity = []
+        stability = []
+        consistency = []
         for node_ind in target_nodes_indices:
-            self.get_explanation(node_ind)
-            sparsity += self.calculate_sparsity(node_ind)
-            stability += self.calculate_stability(
+            print(f"Processing explanation metrics calculation for node id {node_ind}.")
+            self.get_explanations(node_ind)
+            sparsity += [self.calculate_sparsity(node_ind)]
+            stability += [self.calculate_stability(
                 node_ind,
                 graph_perturbations_nums=self.kwargs_dict["stability_graph_perturbations_nums"],
                 feature_change_percent=self.kwargs_dict["stability_feature_change_percent"],
                 node_removal_percent=self.kwargs_dict["stability_node_removal_percent"]
-            )
-            consistency += self.calculate_consistency(
+            )]
+            consistency += [self.calculate_consistency(
                 node_ind,
                 num_explanation_runs=self.kwargs_dict["consistency_num_explanation_runs"]
-            )
+            )]
         fidelity = self.calculate_fidelity(target_nodes_indices)
-        self.dictionary["sparsity"] = sparsity / num_targets
-        self.dictionary["stability"] = stability / num_targets
-        self.dictionary["consistency"] = consistency / num_targets
-        self.dictionary["fidelity"] = fidelity
+        self.dictionary["sparsity"] = process_metric(sparsity)
+        self.dictionary["stability"] = process_metric(stability)
+        self.dictionary["consistency"] = process_metric(consistency)
+        self.dictionary["fidelity"] = process_metric(fidelity)
+        self.save_dictionary()
         return self.dictionary
 
+    @timing_decorator
     def calculate_fidelity(
             self,
             target_nodes_indices: list
-    ) -> float:
+    ) -> list[int]:
         original_answer = self.model.get_answer(self.x, self.edge_index)
-        same_answers_count = 0
+        same_answers_count = []
         for node_ind in target_nodes_indices:
-            node_explanation = self.get_explanation(node_ind)
+            node_explanation = self.get_explanations(node_ind)[0]
             new_x, new_edge_index, new_target_node = self.filter_graph_by_explanation(
                 self.x, self.edge_index, node_explanation, node_ind
             )
             filtered_answer = self.model.get_answer(new_x, new_edge_index)
             matched = filtered_answer[new_target_node] == original_answer[node_ind]
             print(f"Processed fidelity calculation for node id {node_ind}. Matched: {matched}")
-            if matched:
-                same_answers_count += 1
-        fidelity = same_answers_count / len(target_nodes_indices)
-        return fidelity
+            same_answers_count.append(int(matched))
+        return same_answers_count
 
+    @timing_decorator
     def calculate_sparsity(
             self,
             node_ind: int
     ) -> float:
-        explanation = self.get_explanation(node_ind)
-        sparsity = 1 - (len(explanation["data"]["nodes"]) + len(explanation["data"]["edges"])) / (
-                len(self.x) + len(self.edge_index))
+        explanation = self.get_explanations(node_ind)[0]
+        if "data" not in explanation:
+            raise Exception(f"Invalid explanation. No data. Explanation: {explanation}")
+        explanation_data = explanation["data"]
+        num_hops = self.model.get_num_hops()
+        local_subset, local_edge_index, _, _ = k_hop_subgraph(node_ind, num_hops, self.edge_index, relabel_nodes=False)
+        num = 0
+        den = 0
+        if explanation_data["nodes"] and len(explanation_data["nodes"]) != 0:
+            num += len(explanation_data["nodes"])
+            den += local_subset.shape[0]
+        if explanation_data["edges"] and len(explanation_data["edges"]) != 0:
+            num += len(explanation_data["edges"])
+            den += local_edge_index.shape[1]
+        if explanation_data["features"] and len(explanation_data["features"]) != 0:
+            num += len(explanation_data["features"])
+            den += self.x.shape[1]
+        if den == 0:
+            raise Exception(f"Invalid explanation. No data. Explanation: {explanation}")
+        sparsity = 1 - num / den
+        print(f"Sparsity calculation for node id {node_ind} completed.")
         return sparsity
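+    # Worked example (hypothetical numbers): an explanation covering 3 nodes
+    # and 4 edges of a k-hop neighbourhood with 10 nodes and 20 edges gives
+    # num = 3 + 4 = 7, den = 10 + 20 = 30, sparsity = 1 - 7/30 ≈ 0.77.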
Explanation: {explanation}") + sparsity = 1 - num / den + print(f"Sparsity calculation for node id {node_ind} completed.") return sparsity + @timing_decorator def calculate_stability( self, node_ind: int, graph_perturbations_nums: int = 10, feature_change_percent: float = 0.05, node_removal_percent: float = 0.05 - ) -> float: - base_explanation = self.get_explanation(node_ind) - stability = 0 - for _ in range(graph_perturbations_nums): - new_x, new_edge_index = self.perturb_graph( - self.x, self.edge_index, node_ind, feature_change_percent, node_removal_percent - ) - perturbed_explanation = self.calculate_explanation(new_x, new_edge_index, node_ind) + ) -> list[float]: + print(f"Stability calculation for node id {node_ind} started.") + base_explanation = self.get_explanations(node_ind)[0] + run_config = self.node_id_to_explainer_run_config[node_ind] + stability = [] + if node_ind not in self.dictionary["perturbed_explanations"]: + self.dictionary["perturbed_explanations"][node_ind] = [] + + for i in range(graph_perturbations_nums): + if i < len(self.dictionary["perturbed_explanations"][node_ind]): + perturbed_explanation = self.dictionary["perturbed_explanations"][node_ind][i] + else: + new_dataset = self.perturb_graph( + self.gen_dataset, node_ind, feature_change_percent, node_removal_percent + ) + perturbed_explanation = self.calculate_explanation(run_config, new_dataset) + self.dictionary["perturbed_explanations"][node_ind] += [perturbed_explanation] + self.save_dictionary() + base_explanation_vector, perturbed_explanation_vector = \ NodesExplainerMetric.calculate_explanation_vectors(base_explanation, perturbed_explanation) - stability += euclidean_distance(base_explanation_vector, perturbed_explanation_vector) + stability += [euclidean_distance(base_explanation_vector, perturbed_explanation_vector)] - stability = stability / graph_perturbations_nums + # stability = stability / graph_perturbations_nums + print(f"Stability calculation for node id {node_ind} completed.") return stability + @timing_decorator def calculate_consistency( self, node_ind: int, num_explanation_runs: int = 10 - ) -> float: - explanation = self.get_explanation(node_ind) - consistency = 0 - for _ in range(num_explanation_runs): - perturbed_explanation = self.calculate_explanation(self.x, self.edge_index, node_ind) + ) -> list[float]: + print(f"Consistency calculation for node id {node_ind} started.") + explanations = self.get_explanations(node_ind, num_explanations=num_explanation_runs + 1) + explanation = explanations[0] + consistency = [] + for ind in range(num_explanation_runs): + perturbed_explanation = explanations[ind + 1] base_explanation_vector, perturbed_explanation_vector = \ NodesExplainerMetric.calculate_explanation_vectors(explanation, perturbed_explanation) - consistency += cosine_similarity(base_explanation_vector, perturbed_explanation_vector) + consistency += [cosine_similarity(base_explanation_vector, perturbed_explanation_vector)] explanation = perturbed_explanation - consistency = consistency / num_explanation_runs + # consistency = consistency / num_explanation_runs + print(f"Consistency calculation for node id {node_ind} completed.") return consistency - def calculate_explanation( - self, - x: torch.Tensor, - edge_index: torch.Tensor, - node_idx: int, - **kwargs - ): - print(f"Processing explanation calculation for node id {node_idx}.") - self.explainer.evaluate_tensor_graph(x, edge_index, node_idx, **kwargs) - print(f"Explanation calculation for node id {node_idx} completed.") - return 
 
     @staticmethod
-    def parse_explanation(
-            explanation: dict
-    ) -> [dict, dict]:
+    def parse_explanation(explanation: dict) -> [dict, dict]:
         important_nodes = {
             int(node): float(weight) for node, weight in explanation["data"]["nodes"].items()
         }
@@ -185,9 +249,9 @@ def filter_graph_by_explanation(
 
     @staticmethod
     def calculate_explanation_vectors(
-            base_explanation,
-            perturbed_explanation
-    ):
+            base_explanation: dict,
+            perturbed_explanation: dict
+    ) -> [np.ndarray, np.ndarray]:
         base_important_nodes, base_important_edges = NodesExplainerMetric.parse_explanation(
             base_explanation
         )
@@ -212,12 +276,14 @@ def calculate_explanation_vectors(
 
     @staticmethod
     def perturb_graph(
-            x: torch.Tensor,
-            edge_index: torch.Tensor,
+            gen_dataset: GeneralDataset,
             node_ind: int,
             feature_change_percent: float,
             node_removal_percent: float
-    ) -> [torch.Tensor, torch.Tensor]:
+    ) -> GeneralDataset:
+        # Work on a deep copy so the original dataset is never mutated.
+        new_dataset = copy.deepcopy(gen_dataset)
+        x = new_dataset.data.x
+        edge_index = new_dataset.data.edge_index
         new_x = x.clone()
         num_nodes = x.shape[0]
         num_features = x.shape[1]
@@ -232,18 +298,27 @@ def perturb_graph(
             nodes_to_remove = neighbors[
                 torch.randperm(neighbors.size(0), device=edge_index.device)[:num_nodes_to_remove]
             ]
-            mask = ~((edge_index[0] == node_ind).unsqueeze(1) & (edge_index[1].unsqueeze(0) == nodes_to_remove).any(
-                dim=0))
+            mask = ~((edge_index[0] == node_ind) & torch.isin(edge_index[1], nodes_to_remove))
             new_edge_index = edge_index[:, mask]
         else:
             new_edge_index = edge_index
 
-        return new_x, new_edge_index
+        new_dataset.data.x = new_x
+        new_dataset.data.edge_index = new_edge_index
+        return new_dataset
 
 
+def process_metric(data: list[float]) -> dict:
+    np_data = np.array(data)
+    return {
+        "mean": np.mean(np_data),
+        "var": np.var(np_data),
+        "data": data
+    }
+
+
-def cosine_similarity(a, b):
+def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
     return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
 
 
-def euclidean_distance(a, b):
+def euclidean_distance(a: np.ndarray, b: np.ndarray) -> float:
     return np.linalg.norm(a - b)
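+
+
+# Worked example: process_metric([1.0, 0.0, 1.0]) returns
+# {"mean": 0.666..., "var": 0.222..., "data": [1.0, 0.0, 1.0]},
+# so each metric keeps its raw per-run values alongside the summary stats.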
diff --git a/src/explainers/explainers_manager.py b/src/explainers/explainers_manager.py
index cfe1034..96a6e24 100644
--- a/src/explainers/explainers_manager.py
+++ b/src/explainers/explainers_manager.py
@@ -196,20 +196,42 @@ def conduct_experiment(
 
         return result
 
+    def conduct_experiment_by_dataset(
+            self,
+            run_config: Union[ConfigPattern, ExplainerRunConfig],
+            dataset: GeneralDataset,
+            socket: SocketIO = None,
+            save_explanation_flag=False
+    ) -> dict:
+        init_kwargs = getattr(self.init_config, CONFIG_OBJ).to_dict()
+        if self.explainer_name not in FrameworkExplainersManager.supported_explainers:
+            raise ValueError(
+                f"Explainer {self.explainer_name} is not supported. Choose one of "
+                f"{FrameworkExplainersManager.supported_explainers}")
+        print("Creating explainer")
+        # Rebuild the explainer against the given (possibly perturbed) dataset.
+        name_klass = {e.name: e for e in Explainer.__subclasses__()}
+        klass = name_klass[self.explainer_name]
+        self.explainer = klass(
+            dataset, model=self.gnn,
+            device=self.device,
+            # device=device("cpu"),
+            **init_kwargs
+        )
+        old_save_explanation_flag = self.save_explanation_flag
+        self.save_explanation_flag = save_explanation_flag
+        result = self.conduct_experiment(run_config, socket)
+        self.save_explanation_flag = old_save_explanation_flag
+        return result
+
     def evaluate_metrics(
             self,
-            target_nodes_indices: list,
-            run_config: Union[ConfigPattern, ExplainerRunConfig, None] = None,
+            node_id_to_explainer_run_config: dict[int, ConfigPattern],
+            explaining_metrics_params: Union[dict, None] = None,
             socket: SocketIO = None
    ) -> dict:
         """
         Evaluates explanation metrics between given node indices
         """
-        # TODO: Refactor this method for framework design
-        if run_config:
-            params = getattr(getattr(run_config, CONFIG_OBJ).kwargs, CONFIG_OBJ).to_dict()
-        else:
-            params = {}
         self.explainer.pbar = ProgressBar(
             socket, "er", desc=f'{self.explainer.name} explaining metrics calculation'
         )  # progress bar
@@ -218,13 +240,12 @@ def evaluate_metrics(
         if self.gen_dataset.is_multi():
             raise NotImplementedError("Explanation metrics for graph classification")
         else:
             explanation_metrics_calculator = NodesExplainerMetric(
-                model=self.gnn,
-                graph=self.gen_dataset.data,
-                explainer=self.explainer,
-                kwargs_dict=params
+                self,
+                explaining_metrics_params
             )
-            result = explanation_metrics_calculator.evaluate(target_nodes_indices)
+            result = explanation_metrics_calculator.evaluate(node_id_to_explainer_run_config)
 
         print("Explanation metrics are ready")
         if socket:
diff --git a/src/models_builder/models_zoo.py b/src/models_builder/models_zoo.py
index 3cf351a..6597e74 100644
--- a/src/models_builder/models_zoo.py
+++ b/src/models_builder/models_zoo.py
@@ -114,6 +114,53 @@ def model_configs_zoo(
         )
     )
 
+    sage_sage = FrameworkGNNConstructor(
+        model_config=ModelConfig(
+            structure=ModelStructureConfig(
+                [
+                    {
+                        'label': 'n',
+                        'layer': {
+                            'layer_name': 'SAGEConv',
+                            'layer_kwargs': {
+                                'in_channels': dataset.num_node_features,
+                                'out_channels': 16,
+                                # no 'heads' kwarg: SAGEConv, unlike GATConv,
+                                # does not accept one
+                            },
+                        },
+                        'batchNorm': {
+                            'batchNorm_name': 'BatchNorm1d',
+                            'batchNorm_kwargs': {
+                                'num_features': 16,
+                                'eps': 1e-05,
+                            }
+                        },
+                        'activation': {
+                            'activation_name': 'ReLU',
+                            'activation_kwargs': None,
+                        },
+                    },
+
+                    {
+                        'label': 'n',
+                        'layer': {
+                            'layer_name': 'SAGEConv',
+                            'layer_kwargs': {
+                                'in_channels': 16,
+                                'out_channels': dataset.num_classes,
+                            },
+                        },
+                        'activation': {
+                            'activation_name': 'LogSoftmax',
+                            'activation_kwargs': None,
+                        },
+                    },
+                ]
+            )
+        )
+    )
+
     gat_gat = FrameworkGNNConstructor(
         model_config=ModelConfig(
             structure=ModelStructureConfig(