Edit
The issue was that I was using a Python list where I should have been using PyTorch's nn.ModuleList. The Python list was also causing other problems in PyTorch itself when I tried to use the GPU, so as far as I'm concerned this isn't much of a concern for torchview.
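A minimal sketch of the difference between the two containers, using only PyTorch. The class names `ListModule` and `FixedModule` and the attribute name `layers` are illustrative and do not come from the original code. The attribute is deliberately not called `modules`, because `nn.Module` already defines a `modules()` method and an `nn.ModuleList` stored under that name would be hidden behind it.

```python
import torch
import torch.nn as nn


class ListModule(nn.Module):
    """Keeps its layer in a plain Python list (the pattern from the report)."""

    def __init__(self):
        super().__init__()
        # A plain list is not registered: PyTorch does not see this Conv2d
        # as a submodule of ListModule.
        self.layers = [nn.Conv2d(3, 4, 3)]

    def forward(self, x):
        return self.layers[0](x)


class FixedModule(nn.Module):
    """Same network, but the layer is held in an nn.ModuleList."""

    def __init__(self):
        super().__init__()
        # nn.ModuleList registers every contained module as a submodule.
        self.layers = nn.ModuleList([nn.Conv2d(3, 4, 3)])

    def forward(self, x):
        return self.layers[0](x)


list_net = ListModule()
fixed_net = FixedModule()

# Count the parameters each network exposes to PyTorch.
n_list_params = len(list(list_net.parameters()))
n_fixed_params = len(list(fixed_net.parameters()))  # Conv2d weight and bias

# Both networks still run a forward pass on the CPU.
out = fixed_net(torch.zeros(1, 3, 100, 100))
out_shape = tuple(out.shape)

print(n_list_params, n_fixed_params, out_shape)
```

Because the layer in `ListModule` is never registered, calls such as `.parameters()` or `.to(device)` do not reach it, which would be consistent with the GPU problems mentioned above. Whether this is also what triggers the KeyError in torchview is not confirmed here.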
Describe the bug
I was attempting to implement a U-Net and visualize it with torchview. The network itself seemed to work: I could pass a tensor to it and get some output back. But when I tried to use torchview, I received a KeyError. The following is the shortest code I could come up with that shows the error:
import torch
import torch.nn as nn
from torchview import draw_graph

class BuggyModule(nn.Module):
    def __init__(self):
        super(BuggyModule, self).__init__()
        self.modules = [nn.Conv2d(3, 4, 3)]

    def forward(self, x):
        return self.modules[0](x)

net = BuggyModule()
draw_graph(net, input_size=(1, 3, 100, 100), device="meta")
It seems to have something to do with the list of submodules; however, I was able to get my code working while still using such a list. I have uploaded the faulty code to Google Colab.
I have also uploaded my original code that was displaying the bug here and the corrected code here. Note the only difference is in the forward method of ContractionModule. Also note these contain a bunch of probably unrelated code, unlike the first one which is only what is needed to reproduce the issue.
To Reproduce
Steps to reproduce the behavior:
1. Run the above code.
2. There will be an error on the draw_graph() line.
Expected behavior
A graph of the network should be displayed on my screen.
Screenshots / Text
Provided above
Additional context