
Computational graph with gradients? #88

Open
oseledets opened this issue Mar 26, 2023 · 0 comments
Thanks for the library!
Currently the code takes an nn.Module and visualizes the forward pass;
internally, it explicitly wraps the call in torch.no_grad.
However, if the forward pass of a module calls autograd.grad, the library fails.
Is it possible to modify the library to allow for such use cases?

A simple example is attached.

import torch
import torch.nn as nn
from torchview import draw_graph  # draw_graph as used below

# Wrapper module whose forward pass computes gradients with autograd.grad
class visGrad(nn.Module):

    def __init__(self, model):
        super().__init__()
        self.backbone = model

    def forward(self, x, y):
        pred = self.backbone(x)
        er = pred - y
        loss = torch.sum(er**2)
        # This call fails when the forward pass runs under torch.no_grad()
        grd = torch.autograd.grad(loss, self.backbone.parameters())
        return loss


input_size = 1
hidden_size = 10
output_size = 1
model = nn.Sequential(
    nn.Linear(input_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, output_size),)
device = 'cpu'  # define the device before using it
model.to(device)

x = torch.linspace(-2, 2, 256)
x = x[:, None]
y = x**2


model1 = visGrad(model)
model_graph = draw_graph(model1, input_data = (x, y), device='cpu')
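For reference, a possible workaround on the user side (a sketch, not a fix inside the library) is to re-enable grad mode within forward via torch.enable_grad(), which overrides an enclosing torch.no_grad() context. Whether this plays well with the library's tracing is untested; the snippet below only shows that autograd.grad then succeeds even when the wrapper is called under no_grad:

```python
import torch
import torch.nn as nn

class visGrad(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.backbone = model

    def forward(self, x, y):
        # torch.enable_grad() restores graph construction even when the
        # caller (e.g. a visualizer) wraps this forward in torch.no_grad()
        with torch.enable_grad():
            pred = self.backbone(x)
            loss = torch.sum((pred - y) ** 2)
            # Succeeds here because the graph was built inside enable_grad
            grd = torch.autograd.grad(loss, self.backbone.parameters())
        return loss
```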