Processing the forecasting task in a step by step manner #2

Open
AbbyWood98 opened this issue Oct 17, 2023 · 1 comment

Comments

@AbbyWood98

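Hello,

I have a question about seq2seq forecasting. Instead of computing a gradient between every two steps as the code does, is it possible to pass the initial value h0 and the time span t to be predicted directly into the ODE solver?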

pred_zs = odeint(self.func, z0, torch.mean(t, dim=0).squeeze()).permute(1, 0, 2)

If I want to use a CDE, is it correct to think of it as using the CDE solution as the initial value and then decoding with an ODE?

z_T = cdeint(X=X, z0=z0, func=self.func, t=X.grid_points, rtol=self.args.rtol, atol=self.args.atol)
h0 = F.avg_pool1d(z_T.transpose(1, 2), kernel_size=z_T.size(1)).squeeze(2)
pred_zs = odeint(self.func_ode, h0, t_fore).permute(1, 0, 2)

Thanks!

@Saltsmart
Owner

Saltsmart commented Oct 18, 2023

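> Is it possible to pass the initial value h0 and the time span t to be predicted directly into the ODE solver?

That works fine. I'm not sure whether you mean implementing a separate decoder.

> computing a gradient between every two steps as the code does

If "step" means a prediction step, that is not how seq2seq is implemented here.
It produces all prediction steps at the times in y_time in a single pass, and the gradient is computed over them together, as in the following code: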

def forward(self, y_time, x_data, x_time, x_mask=None):
    # Run the sequence forecast
    if x_mask is not None:
        x = x_data * x_mask
    else:
        x = x_data

    # Promote a scalar (0-dim) y_time to a 1-D tensor
    if len(y_time.shape) < 1:
        y_time = y_time.unsqueeze(0)

    # encoder
    hs = self.encoder.run_to_last_point(x, x_time, return_latents=True)

    decoder_begin_hi = hs[:, -1, :]  # start the decoder from the last point of x_time
    y_time = torch.cat((x_time[-1:], y_time))

    # decoder
    # Keep solving the ODE forward; the [:, 1:, :] slice drops the initial point at x_time[-1]
    y_pred = self.decoder(decoder_begin_hi, y_time)[:, 1:, :]

    return y_pred
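
To illustrate the pattern in the code above, here is a minimal sketch of decoding every forecast time in one solve and backpropagating once. It only assumes PyTorch. `euler_odeint`, `LatentDynamics` and `readout` are hypothetical stand-ins written for this example (a fixed-step Euler loop in place of a real solver such as `odeint`), not the repository's encoder or decoder.

```python
import torch
import torch.nn as nn


def euler_odeint(func, z0, t):
    """Hypothetical fixed-step Euler stand-in for an ODE solver.

    z0: (batch, hidden); t: (len_t,), increasing.
    Returns (len_t, batch, hidden); the first slice equals z0.
    """
    zs = [z0]
    for i in range(len(t) - 1):
        dt = t[i + 1] - t[i]
        zs.append(zs[-1] + dt * func(t[i], zs[-1]))
    return torch.stack(zs, dim=0)


class LatentDynamics(nn.Module):
    """Hypothetical vector field dz/dt = net(z)."""

    def __init__(self, hidden):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(hidden, hidden), nn.Tanh(), nn.Linear(hidden, hidden)
        )

    def forward(self, t, z):
        return self.net(z)


torch.manual_seed(0)
batch, hidden, out_dim = 4, 8, 2
func = LatentDynamics(hidden)
readout = nn.Linear(hidden, out_dim)  # maps latent states to predictions

x_time = torch.linspace(0.0, 1.0, 5)   # observed times
y_time = torch.linspace(1.25, 2.0, 4)  # times to forecast
h_last = torch.randn(batch, hidden)    # placeholder for the encoder state at x_time[-1]
y_true = torch.randn(batch, len(y_time), out_dim)  # placeholder targets

# Start at the last observed time, then solve over all forecast times at once.
t_dec = torch.cat((x_time[-1:], y_time))
zs = euler_odeint(func, h_last, t_dec).permute(1, 0, 2)  # (batch, len_t, hidden)
y_pred = readout(zs[:, 1:, :])  # drop the initial point at x_time[-1]

# One loss over all forecast steps, one backward pass.
loss = torch.mean((y_pred - y_true) ** 2)
loss.backward()
```

Prepending `x_time[-1:]` makes the solver's initial condition line up with the encoder's final state, which is why the first output slice is discarded.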

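> If I want to use a CDE, is it correct to think of it as using the CDE solution as the initial value and then decoding with an ODE?

Yes. The three lines of code you posted implement a simple seq2seq model in which the encoder is a CDE and the decoder is an ODE.
Note that in those three lines, z_T holds the cdeint solution at every time point (see the reference link). Check whether, when computing h0, you should take its entry at the last time index, which is the solution at the final time.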
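
The difference between the two choices of h0 can be sketched as follows. This assumes z_T has the layout (batch, time, hidden), which is what the `transpose(1, 2)` and `kernel_size=z_T.size(1)` in the question imply; the tensor here is random placeholder data rather than an actual cdeint output.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, n_times, hidden = 3, 6, 4
# Placeholder for the CDE solution at all time points, assumed (batch, time, hidden).
z_T = torch.randn(batch, n_times, hidden)

# As in the question: average the hidden state over the whole time axis.
h0_avg = F.avg_pool1d(z_T.transpose(1, 2), kernel_size=z_T.size(1)).squeeze(2)

# Alternative: use only the solution at the final time point.
h0_last = z_T[:, -1, :]

# The pooling above is the same as a plain mean over the time axis.
same_as_mean = torch.allclose(h0_avg, z_T.mean(dim=1), atol=1e-6)
```

Both results have shape (batch, hidden), so either can be passed to the ODE decoder as its initial value; they differ in whether h0 summarizes the whole trajectory or continues from the encoder's final state.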
