
Using the same architecture and training parameters in PyTorch, the model fails to converge. #201

Open
bioczsun opened this issue Oct 31, 2024 · 5 comments


bioczsun commented Oct 31, 2024

Hello, I have rewritten Akita in PyTorch to make it easier to debug. I am using the same training parameters as your TensorFlow version, but I found that the model does not converge. However, on the same dataset, I reused your TensorFlow version of Akita and the model converged.

```
----------------------------------------------------------------
    Layer (type)               Output Shape         Param #
================================================================
        Conv1d-1          [-1, 96, 1048576]           4,224
   BatchNorm1d-2          [-1, 96, 1048576]             192
          ReLU-3          [-1, 96, 1048576]               0
     MaxPool1d-4           [-1, 96, 524288]               0
        Conv1d-5           [-1, 96, 524288]          46,080
   BatchNorm1d-6           [-1, 96, 524288]             192
          ReLU-7           [-1, 96, 524288]               0
     MaxPool1d-8           [-1, 96, 262144]               0
        Conv1d-9           [-1, 96, 262144]          46,080
  BatchNorm1d-10           [-1, 96, 262144]             192
         ReLU-11           [-1, 96, 262144]               0
    MaxPool1d-12           [-1, 96, 131072]               0
       Conv1d-13           [-1, 96, 131072]          46,080
  BatchNorm1d-14           [-1, 96, 131072]             192
         ReLU-15           [-1, 96, 131072]               0
    MaxPool1d-16            [-1, 96, 65536]               0
       Conv1d-17            [-1, 96, 65536]          46,080
  BatchNorm1d-18            [-1, 96, 65536]             192
         ReLU-19            [-1, 96, 65536]               0
    MaxPool1d-20            [-1, 96, 32768]               0
       Conv1d-21            [-1, 96, 32768]          46,080
  BatchNorm1d-22            [-1, 96, 32768]             192
         ReLU-23            [-1, 96, 32768]               0
    MaxPool1d-24            [-1, 96, 16384]               0
       Conv1d-25            [-1, 96, 16384]          46,080
  BatchNorm1d-26            [-1, 96, 16384]             192
         ReLU-27            [-1, 96, 16384]               0
    MaxPool1d-28             [-1, 96, 8192]               0
       Conv1d-29             [-1, 96, 8192]          46,080
  BatchNorm1d-30             [-1, 96, 8192]             192
         ReLU-31             [-1, 96, 8192]               0
    MaxPool1d-32             [-1, 96, 4096]               0
       Conv1d-33             [-1, 96, 4096]          46,080
  BatchNorm1d-34             [-1, 96, 4096]             192
         ReLU-35             [-1, 96, 4096]               0
    MaxPool1d-36             [-1, 96, 2048]               0
       Conv1d-37             [-1, 96, 2048]          46,080
  BatchNorm1d-38             [-1, 96, 2048]             192
         ReLU-39             [-1, 96, 2048]               0
    MaxPool1d-40             [-1, 96, 1024]               0
       Conv1d-41             [-1, 96, 1024]          46,080
  BatchNorm1d-42             [-1, 96, 1024]             192
         ReLU-43             [-1, 96, 1024]               0
    MaxPool1d-44              [-1, 96, 512]               0
       Conv1d-45              [-1, 48, 512]          13,824
  BatchNorm1d-46              [-1, 48, 512]              96
         ReLU-47              [-1, 48, 512]               0
       Conv1d-48              [-1, 96, 512]           4,608
  BatchNorm1d-49              [-1, 96, 512]             192
         ReLU-50              [-1, 96, 512]               0
      Dropout-51              [-1, 96, 512]               0
     Residual-52              [-1, 96, 512]               0
       Conv1d-53              [-1, 48, 512]          13,824
  BatchNorm1d-54              [-1, 48, 512]              96
         ReLU-55              [-1, 48, 512]               0
       Conv1d-56              [-1, 96, 512]           4,608
  BatchNorm1d-57              [-1, 96, 512]             192
         ReLU-58              [-1, 96, 512]               0
      Dropout-59              [-1, 96, 512]               0
     Residual-60              [-1, 96, 512]               0
       Conv1d-61              [-1, 48, 512]          13,824
  BatchNorm1d-62              [-1, 48, 512]              96
         ReLU-63              [-1, 48, 512]               0
       Conv1d-64              [-1, 96, 512]           4,608
  BatchNorm1d-65              [-1, 96, 512]             192
         ReLU-66              [-1, 96, 512]               0
      Dropout-67              [-1, 96, 512]               0
     Residual-68              [-1, 96, 512]               0
       Conv1d-69              [-1, 48, 512]          13,824
  BatchNorm1d-70              [-1, 48, 512]              96
         ReLU-71              [-1, 48, 512]               0
       Conv1d-72              [-1, 96, 512]           4,608
  BatchNorm1d-73              [-1, 96, 512]             192
         ReLU-74              [-1, 96, 512]               0
      Dropout-75              [-1, 96, 512]               0
     Residual-76              [-1, 96, 512]               0
       Conv1d-77              [-1, 48, 512]          13,824
  BatchNorm1d-78              [-1, 48, 512]              96
         ReLU-79              [-1, 48, 512]               0
       Conv1d-80              [-1, 96, 512]           4,608
  BatchNorm1d-81              [-1, 96, 512]             192
         ReLU-82              [-1, 96, 512]               0
      Dropout-83              [-1, 96, 512]               0
     Residual-84              [-1, 96, 512]               0
       Conv1d-85              [-1, 48, 512]          13,824
  BatchNorm1d-86              [-1, 48, 512]              96
         ReLU-87              [-1, 48, 512]               0
       Conv1d-88              [-1, 96, 512]           4,608
  BatchNorm1d-89              [-1, 96, 512]             192
         ReLU-90              [-1, 96, 512]               0
      Dropout-91              [-1, 96, 512]               0
     Residual-92              [-1, 96, 512]               0
       Conv1d-93              [-1, 48, 512]          13,824
  BatchNorm1d-94              [-1, 48, 512]              96
         ReLU-95              [-1, 48, 512]               0
       Conv1d-96              [-1, 96, 512]           4,608
  BatchNorm1d-97              [-1, 96, 512]             192
         ReLU-98              [-1, 96, 512]               0
      Dropout-99              [-1, 96, 512]               0
    Residual-100              [-1, 96, 512]               0
      Conv1d-101              [-1, 48, 512]          13,824
 BatchNorm1d-102              [-1, 48, 512]              96
        ReLU-103              [-1, 48, 512]               0
      Conv1d-104              [-1, 96, 512]           4,608
 BatchNorm1d-105              [-1, 96, 512]             192
        ReLU-106              [-1, 96, 512]               0
     Dropout-107              [-1, 96, 512]               0
    Residual-108              [-1, 96, 512]               0

DilatedResidual1D-109            [-1, 96, 512]               0
      Conv1d-110              [-1, 64, 512]          30,720
 BatchNorm1d-111              [-1, 64, 512]             128
        ReLU-112              [-1, 64, 512]               0
    OneToTwo-113         [-1, 64, 512, 512]               0
ConcatDist2D-114         [-1, 65, 512, 512]               0
      Conv2d-115         [-1, 48, 512, 512]          28,080
 BatchNorm2d-116         [-1, 48, 512, 512]              96
        ReLU-117         [-1, 48, 512, 512]               0
Symmetrize2D-118         [-1, 48, 512, 512]               0
      Conv2d-119         [-1, 24, 512, 512]          10,368
 BatchNorm2d-120         [-1, 24, 512, 512]              48
        ReLU-121         [-1, 24, 512, 512]               0
      Conv2d-122         [-1, 48, 512, 512]           1,152
 BatchNorm2d-123         [-1, 48, 512, 512]              96
        ReLU-124         [-1, 48, 512, 512]               0
     Dropout-125         [-1, 48, 512, 512]               0
    Residual-126         [-1, 48, 512, 512]               0
Symmetrize2D-127         [-1, 48, 512, 512]               0
      Conv2d-128         [-1, 24, 512, 512]          10,368
 BatchNorm2d-129         [-1, 24, 512, 512]              48
        ReLU-130         [-1, 24, 512, 512]               0
      Conv2d-131         [-1, 48, 512, 512]           1,152
 BatchNorm2d-132         [-1, 48, 512, 512]              96
        ReLU-133         [-1, 48, 512, 512]               0
     Dropout-134         [-1, 48, 512, 512]               0
    Residual-135         [-1, 48, 512, 512]               0
Symmetrize2D-136         [-1, 48, 512, 512]               0
      Conv2d-137         [-1, 24, 512, 512]          10,368
 BatchNorm2d-138         [-1, 24, 512, 512]              48
        ReLU-139         [-1, 24, 512, 512]               0
      Conv2d-140         [-1, 48, 512, 512]           1,152
 BatchNorm2d-141         [-1, 48, 512, 512]              96
        ReLU-142         [-1, 48, 512, 512]               0
     Dropout-143         [-1, 48, 512, 512]               0
    Residual-144         [-1, 48, 512, 512]               0
Symmetrize2D-145         [-1, 48, 512, 512]               0
      Conv2d-146         [-1, 24, 512, 512]          10,368
 BatchNorm2d-147         [-1, 24, 512, 512]              48
        ReLU-148         [-1, 24, 512, 512]               0
      Conv2d-149         [-1, 48, 512, 512]           1,152
 BatchNorm2d-150         [-1, 48, 512, 512]              96
        ReLU-151         [-1, 48, 512, 512]               0
     Dropout-152         [-1, 48, 512, 512]               0
    Residual-153         [-1, 48, 512, 512]               0
Symmetrize2D-154         [-1, 48, 512, 512]               0
      Conv2d-155         [-1, 24, 512, 512]          10,368
 BatchNorm2d-156         [-1, 24, 512, 512]              48
        ReLU-157         [-1, 24, 512, 512]               0
      Conv2d-158         [-1, 48, 512, 512]           1,152
 BatchNorm2d-159         [-1, 48, 512, 512]              96
        ReLU-160         [-1, 48, 512, 512]               0
     Dropout-161         [-1, 48, 512, 512]               0
    Residual-162         [-1, 48, 512, 512]               0
Symmetrize2D-163         [-1, 48, 512, 512]               0
      Conv2d-164         [-1, 24, 512, 512]          10,368
 BatchNorm2d-165         [-1, 24, 512, 512]              48
        ReLU-166         [-1, 24, 512, 512]               0
      Conv2d-167         [-1, 48, 512, 512]           1,152
 BatchNorm2d-168         [-1, 48, 512, 512]              96
        ReLU-169         [-1, 48, 512, 512]               0
     Dropout-170         [-1, 48, 512, 512]               0
    Residual-171         [-1, 48, 512, 512]               0
DilatedResidual2D-172    [-1, 48, 512, 512]               0
  Cropping2D-173         [-1, 48, 448, 448]               0
    UpperTri-174            [-1, 48, 99681]               0
      Linear-175             [-1, 99681, 5]             245
       Final-176             [-1, 99681, 5]               0

Total params: 746,149
Trainable params: 746,149
Non-trainable params: 0

Input size (MB): 16.00
Forward/backward pass size (MB): 10473.61
Params size (MB): 2.85
Estimated Total Size (MB): 10492.46

SeqNN(
(feature_extractor_1d): Sequential(
(0): Sequential(
(0): Conv1d(4, 96, kernel_size=(11,), stride=(1,), padding=(5,), bias=False)
(1): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(1): Sequential(
(0): Conv1d(96, 96, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
(1): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(2): Sequential(
(0): Conv1d(96, 96, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
(1): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(3): Sequential(
(0): Conv1d(96, 96, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
(1): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(4): Sequential(
(0): Conv1d(96, 96, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
(1): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(5): Sequential(
(0): Conv1d(96, 96, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
(1): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(6): Sequential(
(0): Conv1d(96, 96, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
(1): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(7): Sequential(
(0): Conv1d(96, 96, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
(1): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(8): Sequential(
(0): Conv1d(96, 96, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
(1): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(9): Sequential(
(0): Conv1d(96, 96, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
(1): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(10): Sequential(
(0): Conv1d(96, 96, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
(1): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(11): DilatedResidual1D(
(layers): Sequential(
(0): Residual(
(fn): Sequential(
(0): Conv1d(96, 48, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
(1): BatchNorm1d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv1d(48, 96, kernel_size=(1,), stride=(1,), bias=False)
(4): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(5): ReLU()
(6): Dropout(p=0.4, inplace=False)
)
)
(1): Residual(
(fn): Sequential(
(0): Conv1d(96, 48, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,), bias=False)
(1): BatchNorm1d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv1d(48, 96, kernel_size=(1,), stride=(1,), bias=False)
(4): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(5): ReLU()
(6): Dropout(p=0.4, inplace=False)
)
)
(2): Residual(
(fn): Sequential(
(0): Conv1d(96, 48, kernel_size=(3,), stride=(1,), padding=(3,), dilation=(3,), bias=False)
(1): BatchNorm1d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv1d(48, 96, kernel_size=(1,), stride=(1,), bias=False)
(4): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(5): ReLU()
(6): Dropout(p=0.4, inplace=False)
)
)
(3): Residual(
(fn): Sequential(
(0): Conv1d(96, 48, kernel_size=(3,), stride=(1,), padding=(5,), dilation=(5,), bias=False)
(1): BatchNorm1d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv1d(48, 96, kernel_size=(1,), stride=(1,), bias=False)
(4): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(5): ReLU()
(6): Dropout(p=0.4, inplace=False)
)
)
(4): Residual(
(fn): Sequential(
(0): Conv1d(96, 48, kernel_size=(3,), stride=(1,), padding=(9,), dilation=(9,), bias=False)
(1): BatchNorm1d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv1d(48, 96, kernel_size=(1,), stride=(1,), bias=False)
(4): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(5): ReLU()
(6): Dropout(p=0.4, inplace=False)
)
)
(5): Residual(
(fn): Sequential(
(0): Conv1d(96, 48, kernel_size=(3,), stride=(1,), padding=(16,), dilation=(16,), bias=False)
(1): BatchNorm1d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv1d(48, 96, kernel_size=(1,), stride=(1,), bias=False)
(4): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(5): ReLU()
(6): Dropout(p=0.4, inplace=False)
)
)
(6): Residual(
(fn): Sequential(
(0): Conv1d(96, 48, kernel_size=(3,), stride=(1,), padding=(29,), dilation=(29,), bias=False)
(1): BatchNorm1d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv1d(48, 96, kernel_size=(1,), stride=(1,), bias=False)
(4): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(5): ReLU()
(6): Dropout(p=0.4, inplace=False)
)
)
(7): Residual(
(fn): Sequential(
(0): Conv1d(96, 48, kernel_size=(3,), stride=(1,), padding=(50,), dilation=(50,), bias=False)
(1): BatchNorm1d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv1d(48, 96, kernel_size=(1,), stride=(1,), bias=False)
(4): BatchNorm1d(96, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(5): ReLU()
(6): Dropout(p=0.4, inplace=False)
)
)
)
)
(12): Sequential(
(0): Conv1d(96, 64, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
(1): BatchNorm1d(64, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
)
)
(feature_extractor_2d): Sequential(
(0): Conv2d(65, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): DilatedResidual2D(
(layers): Sequential(
(0): Symmetrize2D()
(1): Residual(
(fn): Sequential(
(0): Conv2d(48, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(24, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv2d(24, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
(4): BatchNorm2d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(5): ReLU()
(6): Dropout(p=0.1, inplace=False)
)
)
(2): Symmetrize2D()
(3): Residual(
(fn): Sequential(
(0): Conv2d(48, 24, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
(1): BatchNorm2d(24, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv2d(24, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
(4): BatchNorm2d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(5): ReLU()
(6): Dropout(p=0.1, inplace=False)
)
)
(4): Symmetrize2D()
(5): Residual(
(fn): Sequential(
(0): Conv2d(48, 24, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), dilation=(3, 3), bias=False)
(1): BatchNorm2d(24, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv2d(24, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
(4): BatchNorm2d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(5): ReLU()
(6): Dropout(p=0.1, inplace=False)
)
)
(6): Symmetrize2D()
(7): Residual(
(fn): Sequential(
(0): Conv2d(48, 24, kernel_size=(3, 3), stride=(1, 1), padding=(5, 5), dilation=(5, 5), bias=False)
(1): BatchNorm2d(24, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv2d(24, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
(4): BatchNorm2d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(5): ReLU()
(6): Dropout(p=0.1, inplace=False)
)
)
(8): Symmetrize2D()
(9): Residual(
(fn): Sequential(
(0): Conv2d(48, 24, kernel_size=(3, 3), stride=(1, 1), padding=(9, 9), dilation=(9, 9), bias=False)
(1): BatchNorm2d(24, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv2d(24, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
(4): BatchNorm2d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(5): ReLU()
(6): Dropout(p=0.1, inplace=False)
)
)
(10): Symmetrize2D()
(11): Residual(
(fn): Sequential(
(0): Conv2d(48, 24, kernel_size=(3, 3), stride=(1, 1), padding=(16, 16), dilation=(16, 16), bias=False)
(1): BatchNorm2d(24, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv2d(24, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
(4): BatchNorm2d(48, eps=0.001, momentum=0.0735, affine=True, track_running_stats=True)
(5): ReLU()
(6): Dropout(p=0.1, inplace=False)
)
)
)
)
)
(oneto_two): OneToTwo()
(concat_dist_2d): ConcatDist2D()
(crop_2d): Cropping2D()
(uppertri): UpperTri()
(final): Final(
(dense): Linear(in_features=48, out_features=5, bias=True)
)
)
```

I checked the difference between the model's predicted and true values. In the early phase of training, the predicted and true values had roughly the same mean, but the gap in variance was large, and the R² remained consistently elevated throughout training.

[image attached]

I would be grateful for any advice you can give!

davek44 (Contributor) commented Nov 1, 2024

Moving across frameworks can be very challenging. You'll basically have to double-check that every detail matches, e.g. make sure the initialization is handled the same way. You might want to train simplified TensorFlow models with our code alongside your PyTorch version to focus on specific components.
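
One concrete detail in this vein: torch.nn conv and linear layers default to Kaiming-uniform initialization, while Keras uses glorot_uniform kernels and zero biases by default. A minimal sketch (not code from this repository; the model name is hypothetical) of imposing Keras-style defaults on a PyTorch module:

```python
import torch.nn as nn

def keras_style_init_(module: nn.Module) -> None:
    """Re-initialize Conv/Linear/BatchNorm layers to match Keras defaults."""
    for m in module.modules():
        if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            nn.init.xavier_uniform_(m.weight)   # Keras glorot_uniform
            if m.bias is not None:
                nn.init.zeros_(m.bias)          # Keras zeros bias initializer
        elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
            nn.init.ones_(m.weight)             # gamma = 1 in both frameworks
            nn.init.zeros_(m.bias)              # beta = 0 in both frameworks

# usage (hypothetical model name):
# model = SeqNN(...)
# keras_style_init_(model)
```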

bioczsun (Author) commented Nov 1, 2024

> Moving across frameworks can be very challenging. You'll basically have to double-check that every detail matches, e.g. make sure the initialization is handled the same way. You might want to train simplified TensorFlow models with our code alongside your PyTorch version to focus on specific components.

Thank you very much for your reply. I have ensured that the parameters are consistent at each layer and have used the same initialization method as in TensorFlow. The dataset is also the same as the original Akita dataset. Initially, I suspected that the issue might be with the Adam optimizer, but after switching to SGD, the problem still persists.
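
Along the lines of checking individual components, one rough, self-contained sanity check is to copy a single layer's weights from tf.keras into torch and confirm that both frameworks produce the same output on the same input. This is only a sketch; the sizes below are borrowed from the first Akita conv block for illustration:

```python
import numpy as np
import tensorflow as tf
import torch

x = np.random.rand(1, 1024, 4).astype(np.float32)         # Keras layout: (batch, length, channels)

k_conv = tf.keras.layers.Conv1D(96, 11, padding="same", use_bias=False)
k_out = k_conv(x).numpy()                                  # builds the layer and runs it

t_conv = torch.nn.Conv1d(4, 96, kernel_size=11, padding=5, bias=False)
with torch.no_grad():
    # Keras kernel layout is (width, in, out); torch expects (out, in, width)
    t_conv.weight.copy_(torch.from_numpy(k_conv.kernel.numpy().transpose(2, 1, 0)))
    t_out = t_conv(torch.from_numpy(x).permute(0, 2, 1)).permute(0, 2, 1).numpy()

print(np.abs(k_out - t_out).max())                         # should be ~1e-6 if the layers match
```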

gfudenberg (Contributor) commented

Hi @bioczsun -- did you try loading the trained model into your PyTorch re-implementation rather than re-training? If so, I wonder whether it gave similar predictions on test set sequences.

bioczsun (Author) commented Nov 2, 2024

> Hi @bioczsun -- did you try loading the trained model into your PyTorch re-implementation rather than re-training? If so, I wonder whether it gave similar predictions on test set sequences.

Hi @gfudenberg, do you mean loading the pre-trained parameters from TensorFlow into the PyTorch model?

gfudenberg (Contributor) commented

Hi @bioczsun, yes, exactly.
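
For anyone attempting that weight transfer, the per-layer tensor-layout conversions involved look roughly like the sketch below. The pairing of Keras layers with PyTorch modules is left hypothetical; only the shape handling is shown:

```python
import torch

def copy_conv1d(k_layer, t_conv):
    # Keras Conv1D kernel: (width, in, out) -> torch Conv1d weight: (out, in, width)
    t_conv.weight.data = torch.from_numpy(k_layer.get_weights()[0].transpose(2, 1, 0))

def copy_conv2d(k_layer, t_conv):
    # Keras Conv2D kernel: (kh, kw, in, out) -> torch Conv2d weight: (out, in, kh, kw)
    t_conv.weight.data = torch.from_numpy(k_layer.get_weights()[0].transpose(3, 2, 0, 1))

def copy_batchnorm(k_layer, t_bn):
    # Keras BatchNormalization weights: [gamma, beta, moving_mean, moving_variance]
    gamma, beta, mean, var = k_layer.get_weights()
    t_bn.weight.data = torch.from_numpy(gamma)
    t_bn.bias.data = torch.from_numpy(beta)
    t_bn.running_mean.data = torch.from_numpy(mean)
    t_bn.running_var.data = torch.from_numpy(var)

def copy_dense(k_layer, t_linear):
    w, b = k_layer.get_weights()               # Keras Dense kernel: (in, out)
    t_linear.weight.data = torch.from_numpy(w.T)
    t_linear.bias.data = torch.from_numpy(b)

# After copying every layer, put the torch model in eval() mode and compare its
# predictions with the TensorFlow model on the same one-hot test batch.
```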
