-
Notifications
You must be signed in to change notification settings - Fork 0
/
GCN0614.py
131 lines (101 loc) · 3.81 KB
/
GCN0614.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import torch
from torch.nn import functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data
from sklearn.model_selection import train_test_split
# Placeholder data — replace with your real tensors before training.
# X: feature matrix, 320 samples x 48 features (matches Net's GCNConv(48, ...)).
X = torch.rand(320, 48) # replace with your actual data
# y: random 0/1 matrix of shape (320, 27458); presumably an adjacency /
# multi-label target — TODO confirm against the real data pipeline.
y = torch.randint(2, (320, 27458)) # replace with your actual data
class Net(torch.nn.Module):
    """Two-layer GCN producing per-node multi-label probabilities.

    The layer sizes were hard-coded (48 -> 128 -> 27458); they are now
    constructor parameters whose defaults reproduce the original model
    exactly, so existing ``Net()`` callers are unaffected while the class
    generalizes to other feature/output dimensions.
    """

    def __init__(self, in_channels=48, hidden_channels=128, out_channels=27458):
        # Zero-argument super() — the file already uses Python-3-only
        # f-strings, so the Python-2-compatible form is unnecessary.
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, data):
        """Apply conv -> ReLU -> dropout -> conv -> sigmoid to ``data``.

        ``data`` is expected to carry ``x`` (node features) and
        ``edge_index`` (graph connectivity), as a PyG Data/Batch does.
        Returns sigmoid scores in [0, 1] per node and output channel.
        """
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        # Dropout (default p=0.5) is active only in training mode.
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return torch.sigmoid(x)
import torch
from torch_geometric.data import Data
from sklearn.model_selection import train_test_split
# Assuming X is your feature tensor and y is your adjacency matrix
# NOTE(review): X and y are RE-created here with different shapes than above
# (X is now (320, 23, 3) rather than (320, 48)), which does not match Net's
# GCNConv(48, ...) input size — confirm which definition is intended.
X = torch.rand(320, 23, 3) # replace with your actual data
# y entries are node indices in [0, 23); -1 is treated as "no edge" by
# to_edge_index below, but randint(23, ...) never generates -1 — TODO confirm
# that real data uses -1 as the padding/sentinel value.
y = torch.randint(23, (320, 27458)) # replace with your actual data
# Split the data into training and testing
# 75/25 split of the 320 sample indices, seeded for reproducibility.
train_index, test_index = train_test_split(range(320), test_size=0.25, random_state=42)
def to_edge_index(y):
    """Convert a tensor of target-node indices into a (2, E) ``edge_index``.

    Entries equal to -1 mark "no edge" and are dropped.

    Bug fix: the original indexed ``y.size(1)``, which raises IndexError on
    the 1-D rows (``y[i]``) this function is actually called with below.
    This version handles both ranks:

    * 1-D ``y`` of length C: one edge (j, y[j]) for each j with y[j] != -1.
    * 2-D ``y`` of shape (R, C): one edge (col, value) for each non--1
      entry — identical to the original 2-D behavior.
    """
    if y.dim() == 1:
        # Treat a single row as one graph: source = position, target = value.
        y = y.unsqueeze(0)
    mask = y != -1
    # Column index of every entry, broadcast over all rows; boolean masking
    # flattens, so the explicit .flatten() calls of the original are redundant.
    src = torch.arange(y.size(1), device=y.device).unsqueeze(0).expand(y.size(0), -1)
    return torch.stack([src[mask], y[mask]], dim=0)
# Create one PyG Data object (graph) per sample, for the train and test
# index sets.  The original file built both lists twice back-to-back with
# byte-identical code; the redundant second copy has been removed.
train_data = [Data(x=X[i], edge_index=to_edge_index(y[i])) for i in train_index]
test_data = [Data(x=X[i], edge_index=to_edge_index(y[i])) for i in test_index]
# Initialize the model and optimizer
# Prefer GPU when available; the model's parameters are moved to `device`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
# Adam over all model parameters with a fixed learning rate of 0.01.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# Training function
def train():
    """Single optimization step over the whole training set (first draft).

    NOTE(review): this draft is superseded by the DataLoader-based ``train``
    later in the file and will fail as written:
      * ``train_data`` is a plain Python list of Data objects, so
        ``model(train_data)`` and ``train_data.edge_index`` will raise;
      * ``edge_index[1]`` holds integer node indices, not the float targets
        in [0, 1] that ``binary_cross_entropy`` requires — confirm what the
        intended target tensor is;
      * the data is never moved to ``device`` while the model is.
    """
    model.train()
    optimizer.zero_grad()
    out = model(train_data)
    loss = F.binary_cross_entropy(out[train_data.edge_index[0]], train_data.edge_index[1])
    loss.backward()
    optimizer.step()
    return loss
# Evaluation function
def evaluate(loader):
    """Run the model over ``loader`` without gradient tracking (first draft).

    NOTE(review): indentation was lost in this paste; the ``return`` is
    formatted here inside the ``for`` loop — consistent with the second
    ``evaluate`` below "fixing" it — which means only the FIRST batch's
    predictions are returned.  Confirm placement against the original.
    """
    model.eval()
    for data in loader:
        data = data.to(device)
        with torch.no_grad():
            pred = model(data)
        # You will want to replace this with your own evaluation metric
        return pred
# Training loop
# NOTE(review): this first loop drives the list-based train()/evaluate()
# drafts above, which will fail as written; the DataLoader-based loop at
# the end of the file is the working variant — keep only one of the two.
for epoch in range(100):
    loss = train()
    print(f'Epoch: {epoch+1}, Loss: {loss:.4f}')
# Evaluation
pred = evaluate(test_data)
# DataLoader moved out of torch_geometric.data in PyG 2.0; importing it from
# there is deprecated (and removed in recent releases) — import it from
# torch_geometric.loader instead.
from torch_geometric.loader import DataLoader
# Create DataLoader for training and testing sets: mini-batches of 32 graphs,
# shuffling only the training set.
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
test_loader = DataLoader(test_data, batch_size=32)
# Training function
def train():
    """Run one epoch over ``train_loader``; return the mean batch loss."""
    model.train()
    total_loss = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data)
        # NOTE(review): binary_cross_entropy expects float targets in [0, 1],
        # but data.edge_index[1] holds integer node indices — this will raise
        # at runtime.  Construct the real target tensor here (e.g. a multi-hot
        # matrix of existing edges) — the intended target cannot be inferred
        # from this file alone.
        loss = F.binary_cross_entropy(out[data.edge_index[0]], data.edge_index[1]) # Update your loss function accordingly
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(train_loader) # Return average loss
# Evaluation function
def evaluate(loader):
    """Collect the model's output for every batch in ``loader``.

    Gradients are disabled for the whole pass; returns a list with one
    prediction tensor per batch.
    """
    model.eval()
    predictions = []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            predictions.append(model(batch))
    return predictions  # Return predictions for all graphs
# Training loop
# 100 epochs with the DataLoader-based train(); prints the average batch loss.
for epoch in range(100):
    loss = train()
    print(f'Epoch: {epoch+1}, Loss: {loss:.4f}')
# Evaluation
# One prediction tensor per test batch; apply your metric of choice to these.
preds = evaluate(test_loader)