Skip to content

Commit

Permalink
add simplified plain gcn api
Browse files Browse the repository at this point in the history
  • Loading branch information
guochengqian committed Feb 16, 2022
1 parent fc83d08 commit 751382a
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 18 deletions.
4 changes: 2 additions & 2 deletions examples/modelnet_cls/architecture.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def forward(self, inputs):
# dilated knn
parser.add_argument('--use_dilation', default=True, type=bool, help='use dilated knn or not')
parser.add_argument('--epsilon', default=0.2, type=float, help='stochastic epsilon for gcn')
parser.add_argument('--stochastic', default=True, type=bool, help='stochastic for gcn, True or False')
parser.add_argument('--use_stochastic', default=True, type=bool, help='stochastic for gcn, True or False')

args = parser.parse_args()
args.device = torch.device('cuda')
Expand All @@ -112,5 +112,5 @@ def forward(self, inputs):
print('Input size {}'.format(feats.size()))
net = DeepGCN(args).to(args.device)
out = net(feats)

print(net)
print('Output size {}'.format(out.size()))
17 changes: 9 additions & 8 deletions gcn_lib/dense/torch_vertex.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,9 @@ def __init__(self, in_channels, out_channels, kernel_size=9, dilation=1, conv='e
else:
self.dilated_knn_graph = DilatedKnnGraph(kernel_size, dilation, stochastic, epsilon)

def forward(self, x):
edge_index = self.dilated_knn_graph(x)
def forward(self, x, edge_index=None):
if edge_index is None:
edge_index = self.dilated_knn_graph(x)
return super(DynConv2d, self).forward(x, edge_index)


Expand All @@ -81,8 +82,8 @@ def __init__(self, in_channels, kernel_size=9, dilation=1, conv='edge', act='rel
self.body = DynConv2d(in_channels, in_channels, kernel_size, dilation, conv,
act, norm, bias, stochastic, epsilon, knn)

def forward(self, x):
return self.body(x)
def forward(self, x, edge_index=None):
return self.body(x, edge_index)


class ResDynBlock2d(nn.Module):
Expand All @@ -96,8 +97,8 @@ def __init__(self, in_channels, kernel_size=9, dilation=1, conv='edge', act='rel
act, norm, bias, stochastic, epsilon, knn)
self.res_scale = res_scale

def forward(self, x):
return self.body(x) + x*self.res_scale
def forward(self, x, edge_index=None):
return self.body(x, edge_index) + x*self.res_scale


class DenseDynBlock2d(nn.Module):
Expand All @@ -110,6 +111,6 @@ def __init__(self, in_channels, out_channels=64, kernel_size=9, dilation=1, con
self.body = DynConv2d(in_channels, out_channels, kernel_size, dilation, conv,
act, norm, bias, stochastic, epsilon, knn)

def forward(self, x):
dense = self.body(x)
def forward(self, x, edge_index=None):
dense = self.body(x, edge_index)
return torch.cat((x, dense), 1)
17 changes: 9 additions & 8 deletions gcn_lib/sparse/torch_vertex.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,8 +275,9 @@ def __init__(self, in_channels, out_channels, kernel_size=9, dilation=1, conv='e
self.d = dilation
self.dilated_knn_graph = DilatedKnnGraph(kernel_size, dilation, **kwargs)

def forward(self, x, batch=None):
edge_index = self.dilated_knn_graph(x, batch)
def forward(self, x, batch=None, edge_index=None):
if edge_index is None:
edge_index = self.dilated_knn_graph(x, batch)
return super(DynConv, self).forward(x, edge_index)


Expand All @@ -291,8 +292,8 @@ def __init__(self, channels, kernel_size=9, dilation=1, conv='edge', act='relu'
act, norm, bias, **kwargs)
self.res_scale = res_scale

def forward(self, x, batch=None):
return self.body(x, batch), batch
def forward(self, x, batch=None, edge_index=None):
return self.body(x, batch, edge_index), batch


class ResDynBlock(nn.Module):
Expand All @@ -306,8 +307,8 @@ def __init__(self, channels, kernel_size=9, dilation=1, conv='edge', act='relu'
act, norm, bias, **kwargs)
self.res_scale = res_scale

def forward(self, x, batch=None):
return self.body(x, batch) + x*self.res_scale, batch
def forward(self, x, batch=None, edge_index=None):
return self.body(x, batch, edge_index) + x*self.res_scale, batch


class DenseDynBlock(nn.Module):
Expand All @@ -319,8 +320,8 @@ def __init__(self, in_channels, out_channels=64, kernel_size=9, dilation=1, conv
self.body = DynConv(in_channels, out_channels, kernel_size, dilation, conv,
act, norm, bias, **kwargs)

def forward(self, x, batch=None):
dense = self.body(x, batch)
def forward(self, x, batch=None, edge_index=None):
dense = self.body(x, batch, edge_index)
return torch.cat((x, dense), 1), batch


Expand Down

0 comments on commit 751382a

Please sign in to comment.