Using the same architecture and training parameters in PyTorch, the model fails to converge. #201
Moving across frameworks can be very challenging. You'll basically have to double-check that every detail matches, e.g. make sure the initialization is handled the same way. You might want to train simplified TensorFlow models with our code alongside your PyTorch version to isolate specific components.
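To make the initialization point concrete, here is a minimal sketch (an illustration, not code from this repo; `model` is assumed to be the PyTorch re-implementation shown further down). Keras conv and dense layers default to glorot_uniform weights and zero biases, while PyTorch conv layers default to a Kaiming-uniform scheme:

```python
# Minimal sketch: re-initialize a PyTorch model to match Keras defaults.
# Assumption: `model` is the PyTorch SeqNN re-implementation below.
import torch.nn as nn

def init_like_keras(module):
    if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Linear)):
        nn.init.xavier_uniform_(module.weight)  # Keras 'glorot_uniform'
        if module.bias is not None:
            nn.init.zeros_(module.bias)         # Keras 'zeros'

model.apply(init_like_keras)
```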
Thank you very much for your reply. I have verified that the parameters are consistent at each layer and have used the same initialization method as in TensorFlow. The dataset is also the same as the original Akita dataset. Initially I suspected the Adam optimizer, but after switching to SGD the problem persisted.
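One concrete Adam difference worth ruling out (an aside, not something established in this thread): the two frameworks' default hyperparameters are not identical. A sketch, with `model` and the learning rate as placeholders:

```python
# tf.keras.optimizers.Adam defaults to epsilon=1e-7, while torch.optim.Adam
# defaults to eps=1e-8. Pinning eps explicitly removes one silent difference.
import torch

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-4,              # placeholder; use the rate from the Akita params file
    betas=(0.9, 0.999),   # both frameworks' default betas
    eps=1e-7,             # match the TensorFlow/Keras default
)
```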
hi @bioczsun -- did you try loading the trained model into your PyTorch re-implementation rather than re-training? If so, I wonder whether it gave similar predictions on test-set sequences.
Hi @gfudenberg, do you mean loading the pre-trained parameters from TensorFlow into the PyTorch model?
hi @bioczsun -- yes, exactly.
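A sketch of what that weight transfer could look like (illustrative only; `keras_conv`, `keras_bn`, `pt_conv`, and `pt_bn` are hypothetical handles to corresponding layers in the two models). Keras stores Conv1D kernels as (width, in_ch, out_ch), while PyTorch Conv1d expects (out_ch, in_ch, width), so each kernel must be transposed before copying:

```python
# Minimal sketch of copying trained Keras weights into matching PyTorch layers.
import numpy as np
import torch

def copy_conv1d(keras_conv, pt_conv):
    # Keras kernel (width, in_ch, out_ch) -> PyTorch weight (out_ch, in_ch, width)
    kernel = keras_conv.get_weights()[0]
    pt_conv.weight.data = torch.from_numpy(np.transpose(kernel, (2, 1, 0)).copy())

def copy_batchnorm(keras_bn, pt_bn):
    # Keras BatchNormalization weights: [gamma, beta, moving_mean, moving_var]
    gamma, beta, moving_mean, moving_var = keras_bn.get_weights()
    pt_bn.weight.data = torch.from_numpy(gamma)
    pt_bn.bias.data = torch.from_numpy(beta)
    pt_bn.running_mean.data = torch.from_numpy(moving_mean)
    pt_bn.running_var.data = torch.from_numpy(moving_var)
```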
Hello, I have re-implemented Akita in PyTorch to make it easier to debug. I am using the same training parameters as your TensorFlow version, but the model does not converge. On the same dataset, your TensorFlow version of Akita converges.
```
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
                ... (earlier layers omitted) ...
DilatedResidual1D-109 [-1, 96, 512] 0
Conv1d-110 [-1, 64, 512] 30,720
BatchNorm1d-111 [-1, 64, 512] 128
ReLU-112 [-1, 64, 512] 0
OneToTwo-113 [-1, 64, 512, 512] 0
ConcatDist2D-114 [-1, 65, 512, 512] 0
Conv2d-115 [-1, 48, 512, 512] 28,080
BatchNorm2d-116 [-1, 48, 512, 512] 96
ReLU-117 [-1, 48, 512, 512] 0
Symmetrize2D-118 [-1, 48, 512, 512] 0
Conv2d-119 [-1, 24, 512, 512] 10,368
BatchNorm2d-120 [-1, 24, 512, 512] 48
ReLU-121 [-1, 24, 512, 512] 0
Conv2d-122 [-1, 48, 512, 512] 1,152
BatchNorm2d-123 [-1, 48, 512, 512] 96
ReLU-124 [-1, 48, 512, 512] 0
Dropout-125 [-1, 48, 512, 512] 0
Residual-126 [-1, 48, 512, 512] 0
Symmetrize2D-127 [-1, 48, 512, 512] 0
Conv2d-128 [-1, 24, 512, 512] 10,368
BatchNorm2d-129 [-1, 24, 512, 512] 48
ReLU-130 [-1, 24, 512, 512] 0
Conv2d-131 [-1, 48, 512, 512] 1,152
BatchNorm2d-132 [-1, 48, 512, 512] 96
ReLU-133 [-1, 48, 512, 512] 0
Dropout-134 [-1, 48, 512, 512] 0
Residual-135 [-1, 48, 512, 512] 0
Symmetrize2D-136 [-1, 48, 512, 512] 0
Conv2d-137 [-1, 24, 512, 512] 10,368
BatchNorm2d-138 [-1, 24, 512, 512] 48
ReLU-139 [-1, 24, 512, 512] 0
Conv2d-140 [-1, 48, 512, 512] 1,152
BatchNorm2d-141 [-1, 48, 512, 512] 96
ReLU-142 [-1, 48, 512, 512] 0
Dropout-143 [-1, 48, 512, 512] 0
Residual-144 [-1, 48, 512, 512] 0
Symmetrize2D-145 [-1, 48, 512, 512] 0
Conv2d-146 [-1, 24, 512, 512] 10,368
BatchNorm2d-147 [-1, 24, 512, 512] 48
ReLU-148 [-1, 24, 512, 512] 0
Conv2d-149 [-1, 48, 512, 512] 1,152
BatchNorm2d-150 [-1, 48, 512, 512] 96
ReLU-151 [-1, 48, 512, 512] 0
Dropout-152 [-1, 48, 512, 512] 0
Residual-153 [-1, 48, 512, 512] 0
Symmetrize2D-154 [-1, 48, 512, 512] 0
Conv2d-155 [-1, 24, 512, 512] 10,368
BatchNorm2d-156 [-1, 24, 512, 512] 48
ReLU-157 [-1, 24, 512, 512] 0
Conv2d-158 [-1, 48, 512, 512] 1,152
BatchNorm2d-159 [-1, 48, 512, 512] 96
ReLU-160 [-1, 48, 512, 512] 0
Dropout-161 [-1, 48, 512, 512] 0
Residual-162 [-1, 48, 512, 512] 0
Symmetrize2D-163 [-1, 48, 512, 512] 0
Conv2d-164 [-1, 24, 512, 512] 10,368
BatchNorm2d-165 [-1, 24, 512, 512] 48
ReLU-166 [-1, 24, 512, 512] 0
Conv2d-167 [-1, 48, 512, 512] 1,152
BatchNorm2d-168 [-1, 48, 512, 512] 96
ReLU-169 [-1, 48, 512, 512] 0
Dropout-170 [-1, 48, 512, 512] 0
Residual-171 [-1, 48, 512, 512] 0
DilatedResidual2D-172 [-1, 48, 512, 512] 0
Cropping2D-173 [-1, 48, 448, 448] 0
UpperTri-174 [-1, 48, 99681] 0
Linear-175 [-1, 99681, 5] 245
Final-176 [-1, 99681, 5] 0
================================================================
Total params: 746,149
Trainable params: 746,149
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 16.00
Forward/backward pass size (MB): 10473.61
Params size (MB): 2.85
Estimated Total Size (MB): 10492.46
----------------------------------------------------------------

SeqNN(
(feature_extractor_1d): Sequential(
(0): Sequential(
(0): Conv1d(4, 96, kernel_size=(11,), stride=(1,), padding=(5,), bias=False)
(1): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(1): Sequential(
(0): Conv1d(96, 96, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
(1): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(2): Sequential(
(0): Conv1d(96, 96, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
(1): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(3): Sequential(
(0): Conv1d(96, 96, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
(1): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(4): Sequential(
(0): Conv1d(96, 96, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
(1): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(5): Sequential(
(0): Conv1d(96, 96, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
(1): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(6): Sequential(
(0): Conv1d(96, 96, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
(1): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(7): Sequential(
(0): Conv1d(96, 96, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
(1): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(8): Sequential(
(0): Conv1d(96, 96, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
(1): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(9): Sequential(
(0): Conv1d(96, 96, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
(1): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(10): Sequential(
(0): Conv1d(96, 96, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
(1): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(11): DilatedResidual1D(
(layers): Sequential(
(0): Residual(
(fn): Sequential(
(0): Conv1d(96, 48, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
(1): BatchNorm1d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv1d(48, 96, kernel_size=(1,), stride=(1,), bias=False)
(4): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(5): ReLU()
(6): Dropout(p=0.4, inplace=False)
)
)
(1): Residual(
(fn): Sequential(
(0): Conv1d(96, 48, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,), bias=False)
(1): BatchNorm1d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv1d(48, 96, kernel_size=(1,), stride=(1,), bias=False)
(4): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(5): ReLU()
(6): Dropout(p=0.4, inplace=False)
)
)
(2): Residual(
(fn): Sequential(
(0): Conv1d(96, 48, kernel_size=(3,), stride=(1,), padding=(3,), dilation=(3,), bias=False)
(1): BatchNorm1d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv1d(48, 96, kernel_size=(1,), stride=(1,), bias=False)
(4): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(5): ReLU()
(6): Dropout(p=0.4, inplace=False)
)
)
(3): Residual(
(fn): Sequential(
(0): Conv1d(96, 48, kernel_size=(3,), stride=(1,), padding=(5,), dilation=(5,), bias=False)
(1): BatchNorm1d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv1d(48, 96, kernel_size=(1,), stride=(1,), bias=False)
(4): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(5): ReLU()
(6): Dropout(p=0.4, inplace=False)
)
)
(4): Residual(
(fn): Sequential(
(0): Conv1d(96, 48, kernel_size=(3,), stride=(1,), padding=(9,), dilation=(9,), bias=False)
(1): BatchNorm1d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv1d(48, 96, kernel_size=(1,), stride=(1,), bias=False)
(4): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(5): ReLU()
(6): Dropout(p=0.4, inplace=False)
)
)
(5): Residual(
(fn): Sequential(
(0): Conv1d(96, 48, kernel_size=(3,), stride=(1,), padding=(16,), dilation=(16,), bias=False)
(1): BatchNorm1d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv1d(48, 96, kernel_size=(1,), stride=(1,), bias=False)
(4): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(5): ReLU()
(6): Dropout(p=0.4, inplace=False)
)
)
(6): Residual(
(fn): Sequential(
(0): Conv1d(96, 48, kernel_size=(3,), stride=(1,), padding=(29,), dilation=(29,), bias=False)
(1): BatchNorm1d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv1d(48, 96, kernel_size=(1,), stride=(1,), bias=False)
(4): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(5): ReLU()
(6): Dropout(p=0.4, inplace=False)
)
)
(7): Residual(
(fn): Sequential(
(0): Conv1d(96, 48, kernel_size=(3,), stride=(1,), padding=(50,), dilation=(50,), bias=False)
(1): BatchNorm1d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv1d(48, 96, kernel_size=(1,), stride=(1,), bias=False)
(4): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(5): ReLU()
(6): Dropout(p=0.4, inplace=False)
)
)
)
)
(12): Sequential(
(0): Conv1d(96, 64, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
(1): BatchNorm1d(64, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
)
)
(feature_extractor_2d): Sequential(
(0): Conv2d(65, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): DilatedResidual2D(
(layers): Sequential(
(0): Symmetrize2D()
(1): Residual(
(fn): Sequential(
(0): Conv2d(48, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(24, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv2d(24, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
(4): BatchNorm2d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(5): ReLU()
(6): Dropout(p=0.1, inplace=False)
)
)
(2): Symmetrize2D()
(3): Residual(
(fn): Sequential(
(0): Conv2d(48, 24, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
(1): BatchNorm2d(24, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv2d(24, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
(4): BatchNorm2d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(5): ReLU()
(6): Dropout(p=0.1, inplace=False)
)
)
(4): Symmetrize2D()
(5): Residual(
(fn): Sequential(
(0): Conv2d(48, 24, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), dilation=(3, 3), bias=False)
(1): BatchNorm2d(24, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv2d(24, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
(4): BatchNorm2d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(5): ReLU()
(6): Dropout(p=0.1, inplace=False)
)
)
(6): Symmetrize2D()
(7): Residual(
(fn): Sequential(
(0): Conv2d(48, 24, kernel_size=(3, 3), stride=(1, 1), padding=(5, 5), dilation=(5, 5), bias=False)
(1): BatchNorm2d(24, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv2d(24, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
(4): BatchNorm2d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(5): ReLU()
(6): Dropout(p=0.1, inplace=False)
)
)
(8): Symmetrize2D()
(9): Residual(
(fn): Sequential(
(0): Conv2d(48, 24, kernel_size=(3, 3), stride=(1, 1), padding=(9, 9), dilation=(9, 9), bias=False)
(1): BatchNorm2d(24, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv2d(24, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
(4): BatchNorm2d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(5): ReLU()
(6): Dropout(p=0.1, inplace=False)
)
)
(10): Symmetrize2D()
(11): Residual(
(fn): Sequential(
(0): Conv2d(48, 24, kernel_size=(3, 3), stride=(1, 1), padding=(16, 16), dilation=(16, 16), bias=False)
(1): BatchNorm2d(24, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv2d(24, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
(4): BatchNorm2d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(5): ReLU()
(6): Dropout(p=0.1, inplace=False)
)
)
)
)
)
(oneto_two): OneToTwo()
(concat_dist_2d): ConcatDist2D()
(crop_2d): Cropping2D()
(uppertri): UpperTri()
(final): Final(
(dense): Linear(in_features=48, out_features=5, bias=True)
)
)
```
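One detail visible in the printout above: the BatchNorm momentum of 0.0735 is the Keras value 0.9265 converted between conventions, since Keras updates running statistics as running = momentum * running + (1 - momentum) * batch, whereas PyTorch uses running = (1 - momentum) * running + momentum * batch. A toy check of the equivalence (illustrative only, not code from this thread):

```python
# The two frameworks define BatchNorm momentum with opposite conventions,
# so Keras momentum 0.9265 corresponds to PyTorch momentum 0.0735.
keras_momentum = 0.9265
pytorch_momentum = 1.0 - keras_momentum  # = 0.0735, as in the printout above

running, batch_stat = 1.0, 0.0  # toy running statistic and batch statistic
keras_update = keras_momentum * running + (1.0 - keras_momentum) * batch_stat
pytorch_update = (1.0 - pytorch_momentum) * running + pytorch_momentum * batch_stat
assert abs(keras_update - pytorch_update) < 1e-12
```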
I checked the difference between the model's predicted and true values: early in training the predictions and targets had roughly the same mean, but the gap in variance was large, and the R² did not improve over the course of training.
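For reference, a minimal sketch of that check (assuming hypothetical flattened 1-D tensors `preds` and `targets` holding predictions and Hi-C targets for the same sequences):

```python
# Compare the mean, variance, and R^2 of predictions vs. targets.
import torch

def summarize(preds: torch.Tensor, targets: torch.Tensor) -> None:
    print(f"pred: mean={preds.mean():.4f} var={preds.var():.4f}")
    print(f"true: mean={targets.mean():.4f} var={targets.var():.4f}")
    p = preds - preds.mean()
    t = targets - targets.mean()
    r = (p * t).sum() / (p.norm() * t.norm())  # Pearson correlation
    print(f"R^2 = {r.item() ** 2:.4f}")
```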
I would be grateful for any advice you can give!