Skip to content

Commit

Permalink
Merge branch 'develop' of https://github.com/ispras/GNN-AID into add_…
Browse files Browse the repository at this point in the history
…attack_defense_metrics
  • Loading branch information
LukyanovKirillML committed Dec 3, 2024
2 parents 0df1ec1 + f4b2298 commit 5e1520e
Show file tree
Hide file tree
Showing 12 changed files with 742 additions and 227 deletions.
1 change: 1 addition & 0 deletions data/multiple-graphs/custom/example/raw/.info
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
{
"name": "example",
"count": 3,
"nodes": [3, 4, 5],
"directed": false,
Expand Down
12 changes: 4 additions & 8 deletions data/multiple-graphs/custom/small/raw/small.node_attributes/a
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
"1": 0,
"2": 1,
"3": 1,
"4": 1,
"5": 0
"4": 1
},
{
"0": 1,
Expand All @@ -17,8 +16,7 @@
"0": 0,
"1": 1,
"2": 1,
"3": 1,
"4": 0
"3": 1
},
{
"0": 0,
Expand All @@ -45,8 +43,7 @@
"3": 1,
"4": 1,
"5": 1,
"6": 1,
"7": 0
"6": 1
},
{
"0": 1,
Expand All @@ -55,8 +52,7 @@
"3": 1,
"4": 1,
"5": 1,
"6": 1,
"7": 0
"6": 1
},
{
"0": 0,
Expand Down
12 changes: 4 additions & 8 deletions data/multiple-graphs/custom/small/raw/small.node_attributes/b
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
"1": 1,
"2": 1,
"3": 0,
"4": 0,
"5": 1
"4": 0
},
{
"0": 0,
Expand All @@ -17,8 +16,7 @@
"0": 0,
"1": 0,
"2": 0,
"3": 0,
"4": 1
"3": 0
},
{
"0": 1,
Expand All @@ -45,8 +43,7 @@
"3": 1,
"4": 0,
"5": 1,
"6": 1,
"7": 1
"6": 1
},
{
"0": 0,
Expand All @@ -55,8 +52,7 @@
"3": 0,
"4": 0,
"5": 1,
"6": 0,
"7": 1
"6": 0
},
{
"0": 1,
Expand Down
1 change: 1 addition & 0 deletions data/single-graph/custom/example/raw/.info
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
{
"name": "example",
"count": 1,
"directed": false,
"nodes": [8],
Expand Down
12 changes: 1 addition & 11 deletions metainfo/torch_geom_index.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@
"pytorch-geometric-other": [
"Actor",
"BAShapes",
"Flickr",
"KarateClub",
"Reddit2"
"KarateClub"
],
"Planetoid": [
"CiteSeer",
Expand All @@ -32,7 +30,6 @@
"AIDS",
"BZR",
"BZR_MD",
"COIL-DEL",
"COX2",
"COX2_MD",
"Cuneiform",
Expand Down Expand Up @@ -128,13 +125,6 @@
"salicylic_acid",
"toluene",
"uracil"
],
"MoleculeNet": [
"MUV",
"PCBA",
"SIDER",
"Tox21",
"ToxCast"
]
}
}
Expand Down
31 changes: 31 additions & 0 deletions src/aux/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,34 @@ def all_subclasses(
) -> set:
return set(cls.__subclasses__()).union(
[s for c in cls.__subclasses__() for s in all_subclasses(c)])


class tmp_dir():
"""
Temporary create a directory near the given path. Remove it on exit.
"""
def __init__(
self,
path: Path
):
self.path = path
from time import time
self.tmp_dir = self.path.parent / (self.path.name + str(time()))

def __enter__(
self
) -> Path:
self.tmp_dir.mkdir(parents=True)
return self.tmp_dir

def __exit__(
self,
exception_type,
exception_value,
exception_traceback
) -> None:
import shutil
try:
shutil.rmtree(self.tmp_dir)
except FileNotFoundError:
pass
82 changes: 81 additions & 1 deletion src/base/custom_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(
"""
super().__init__(dataset_config)

assert self.labels_dir.exists()
# assert self.labels_dir.exists()
self.info = DatasetInfo.read(self.info_path)
self.node_map = None # Optional nodes mapping: node_map[i] = original id of node i
self.edge_index = None
Expand Down Expand Up @@ -68,6 +68,86 @@ def edge_index_path(
""" Path to dir with labels. """
return self.root_dir / 'raw' / (self.name + '.edge_index')

def check_validity(
self
):
""" Check that dataset files (graph and attributes) are valid and consistent with .info.
"""
# Assuming info is OK
count = self.info.count
# Check edges
if self.is_multi():
with open(self.edges_path, 'r') as f:
num_edges = sum(1 for _ in f)
with open(self.edge_index_path, 'r') as f:
edge_index = json.load(f)
assert all(i <= num_edges for i in edge_index)
assert num_edges == edge_index[-1]
assert count == len(edge_index)

# Check nodes
all_nodes = [set() for _ in range(count)] # sets of nodes
if self.is_multi():
with open(self.edges_path, 'r') as f:
start = 0
for ix, end in enumerate(edge_index):
for _ in range(end-start):
all_nodes[ix].update(map(int, f.readline().split()))
if self.info.remap:
assert len(all_nodes[ix]) == self.info.nodes[ix]
else:
assert all_nodes[ix] == set(range(self.info.nodes[ix]))
start = end
else:
with open(self.edges_path, 'r') as f:
for line in f.readlines():
all_nodes[0].update(map(int, line.split()))
if self.info.remap:
assert len(all_nodes[0]) == self.info.nodes[0]
else:
assert all_nodes[0] == set(range(self.info.nodes[0]))

# Check node attributes
for ix, attr in enumerate(self.info.node_attributes["names"]):
with open(self.node_attributes_dir / attr, 'r') as f:
node_attributes = json.load(f)
if not self.is_multi():
node_attributes = [node_attributes]
for i, attributes in enumerate(node_attributes):
assert all_nodes[i] == set(map(int, attributes.keys()))
if self.info.node_attributes["types"][ix] == "continuous":
v_min, v_max = self.info.node_attributes["values"][ix]
assert all(isinstance(v, (int, float, complex)) for v in attributes.values())
assert min(attributes.values()) >= v_min
assert max(attributes.values()) <= v_max
elif self.info.node_attributes["types"][ix] == "categorical":
assert set(attributes.values()).issubset(set(self.info.node_attributes["values"][ix]))

# Check edge attributes
for ix, attr in enumerate(self.info.edge_attributes["names"]):
with open(self.edge_attributes_dir / attr, 'r') as f:
edge_attributes = json.load(f)
if not self.is_multi():
edge_attributes = [edge_attributes]
for i, attributes in enumerate(edge_attributes):
# TODO check edges
if self.info.edge_attributes["types"][ix] == "continuous":
v_min, v_max = self.info.edge_attributes["values"][ix]
assert all(isinstance(v, (int, float, complex)) for v in attributes.values())
assert min(attributes.values()) >= v_min
assert max(attributes.values()) <= v_max
elif self.info.edge_attributes["types"][ix] == "categorical":
assert set(attributes.values()).issubset(set(self.info.edge_attributes["values"][ix]))

# Check labels
for labelling, n_classes in self.info.labelings.items():
with open(self.labels_dir / labelling, 'r') as f:
labels = json.load(f)
if self.is_multi(): # graph labels
assert set(range(count)) == set(map(int, labels.keys()))
else: # nodes labels
assert all_nodes[0] == set(map(int, labels.keys()))

def build(
self,
dataset_var_config: Union[ConfigPattern, DatasetVarConfig]
Expand Down
Loading

0 comments on commit 5e1520e

Please sign in to comment.