Commit eba0a8a: Updated notations, aligned with the latest revision
juannat7 committed Jun 14, 2023 (1 parent: 7892710)
Showing 227 changed files with 109 additions and 87 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -78,7 +78,7 @@ instance/
docs/_build/

# PyBuilder
-target/
+# target/

# Jupyter Notebook
.ipynb_checkpoints
2 changes: 1 addition & 1 deletion README.md
@@ -26,5 +26,5 @@ by learning key features in data-abundant regions and adapt them to fluxes in da
![Meta inference](https://github.com/juannat7/metaflux/blob/main/docs/gpp_infer.jpeg)

3. (experimental) `01c_with_encoder_pipeline`: adding context encoder to current classic metalearning model

![Encoder workflow](https://github.com/juannat7/metaflux/blob/main/docs/encoder_workflow.png)
![Meta inference with context encoder](https://github.com/juannat7/metaflux/blob/main/docs/gpp_encoder_infer.jpeg)
Binary file modified docs/GPP_CV.png
Binary file modified docs/GPP_CV_fluxcom.png
Binary file modified docs/GPP_ensemble_mean.png
Binary file modified docs/GPP_ensemble_std.png
Binary file modified docs/GPP_trend.pdf
Binary file modified docs/RECO_CV.png
Binary file modified docs/RECO_CV_fluxcom.png
Binary file modified docs/RECO_ensemble_mean.png
Binary file modified docs/RECO_ensemble_std.png
Binary file modified docs/RECO_trend.pdf
Binary file added docs/encoder_workflow.png
Binary file modified docs/meta_robust_GPP_NT_VUT_REF.pdf
Binary file modified docs/meta_robust_RECO.pdf
206 files renamed without changes.
4 changes: 2 additions & 2 deletions metaflux/dataloader.py
@@ -20,15 +20,15 @@ def __init__(self, root, mode, x_columns, y_column, context_columns=None, time_c
Parameters:
-----------
root <str>: root path of dataset (before <mode>) in CSV, in the following structure: <class>/<mode>/<filenames>.csv
-mode <str>: train or test (ie. metatrain or metatest)
+mode <str>: base or target tasks
x_columns <list>: list containing the column names of input features
y_column <str>: the name of column for target variable
context_columns <list>: list containing the column names of contextual features
time_column <str>: the name of column indicating time (must be of DateTime object)
time_agg <str>: how to aggregate time across observations, defaults to 1H (must be compatible with DateTime object)
time_window <int>: the size of a time window
"""
-modes = ["train", "test"]
+modes = ["base", "target"]
assert mode in modes

self.mode = mode
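Under the revised notation, one loader is constructed per task pool. A minimal sketch of the new call, mirroring the notebook change further below; the root path and column names here are placeholders, not values from this commit:

    import metaflux

    # Placeholder paths/columns for illustration; the notebooks read these
    # from hyper_args instead.
    root_dir = "../metaflux/data/sample/"
    x_columns = ["TA", "SW_IN", "VPD"]   # hypothetical input features
    y_column = "GPP_NT_VUT_REF"

    # mode="base" was mode="train" before this commit; "target" was "test".
    fluxnet_base = metaflux.dataloader.Fluxmetanet(
        root=root_dir, mode="base",
        x_columns=x_columns, y_column=y_column,
        time_column=None, time_window=1,   # time_window=1 for non-temporal data
    )
    fluxnet_target = metaflux.dataloader.Fluxmetanet(
        root=root_dir, mode="target",
        x_columns=x_columns, y_column=y_column,
        time_column=None, time_window=1,
    )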
120 changes: 60 additions & 60 deletions metaflux/learner.py
@@ -27,8 +27,8 @@ def __init__(
input_size: int,
hidden_size: int,
model_type: str,
-fluxnet_train,
-fluxnet_test,
+fluxnet_base,
+fluxnet_target,
update_lr,
meta_lr,
batch_size,
@@ -43,8 +43,8 @@ def __init__(
self.input_size = input_size
self.hidden_size = hidden_size
self.model_type = model_type
-self.fluxnet_support = fluxnet_train
-self.fluxnet_query = fluxnet_test
+self.fluxnet_base = fluxnet_base
+self.fluxnet_target = fluxnet_target
self.update_lr = update_lr
self.meta_lr = meta_lr
self.batch_size = batch_size
@@ -53,7 +53,7 @@ def __init__(
self.encoder_hidden_size = encoder_hidden_size
self.with_context = with_context
self.with_baseline = with_baseline
-self.encoder_input_size = next(iter(fluxnet_train))[0].shape[-1] - self.input_size
+self.encoder_input_size = next(iter(fluxnet_base))[0].shape[-1] - self.input_size
self.loss = nn.MSELoss(reduction="mean")

self.meta_loss_metric = dict()
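A detail worth noting in this hunk: encoder_input_size is never passed in. It is inferred from a sample batch, on the assumption that context columns are concatenated after the regular input features. Schematically (fluxnet_base and input_size as in the surrounding code):

    x_sample, _ = next(iter(fluxnet_base))        # one sample: (..., n_features)
    n_context = x_sample.shape[-1] - input_size   # context features assumed appended last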
@@ -83,17 +83,17 @@ def train_meta(
schedule = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)

# processing the data
-support_train_sz = int(len(self.fluxnet_support) * (1 - self.finetune_size))
-query_train_sz = int(len(self.fluxnet_query) * (1 - self.finetune_size))
+base_train_sz = int(len(self.fluxnet_base) * (1 - self.finetune_size))
+target_train_sz = int(len(self.fluxnet_target) * (1 - self.finetune_size))

# split both datasets to meta-train and meta-test
-support_train, support_test = torch.utils.data.random_split(self.fluxnet_support, [support_train_sz, len(self.fluxnet_support) - support_train_sz])
-query_train, query_test = torch.utils.data.random_split(self.fluxnet_query, [query_train_sz, len(self.fluxnet_query) - query_train_sz])
+base_train, base_test = torch.utils.data.random_split(self.fluxnet_base, [base_train_sz, len(self.fluxnet_base) - base_train_sz])
+target_train, target_test = torch.utils.data.random_split(self.fluxnet_target, [target_train_sz, len(self.fluxnet_target) - target_train_sz])

-support_train_dl = DataLoader(support_train, batch_size=self.batch_size, shuffle=True)
-support_test_dl = DataLoader(support_test, batch_size=self.batch_size, shuffle=True)
-query_train_dl = DataLoader(query_train, batch_size=self.batch_size, shuffle=True)
-query_test_dl = DataLoader(query_test, batch_size=self.batch_size, shuffle=True)
+base_train_dl = DataLoader(base_train, batch_size=self.batch_size, shuffle=True)
+base_test_dl = DataLoader(base_test, batch_size=self.batch_size, shuffle=True)
+target_train_dl = DataLoader(target_train, batch_size=self.batch_size, shuffle=True)
+target_test_dl = DataLoader(target_test, batch_size=self.batch_size, shuffle=True)

if self.with_context:
self.encoder = Encoder(
@@ -116,49 +116,49 @@
train_error, val_error, outer_error = 0.0, 0.0, 0.0

# Get k-shot batches
-support_train_k = self._get_k_shot(support_train_dl, self.max_meta_step)
-support_test_k = self._get_k_shot(support_test_dl, self.max_meta_step)
-query_train_k = self._get_k_shot(query_train_dl, self.max_meta_step)
-query_test_k = self._get_k_shot(query_test_dl, self.max_meta_step)
+base_train_k = self._get_k_shot(base_train_dl, self.max_meta_step)
+base_test_k = self._get_k_shot(base_test_dl, self.max_meta_step)
+target_train_k = self._get_k_shot(target_train_dl, self.max_meta_step)
+target_test_k = self._get_k_shot(target_test_dl, self.max_meta_step)

# Main loop
with torch.backends.cudnn.flags(enabled=False):
learner = self.maml.clone().double()

-# Propose phi using support sets
-for task, (s_train, s_test) in enumerate(zip(support_train_k, support_test_k)):
-s_train_x, s_train_y = s_train
-s_test_x, s_test_y = s_test
-s_train_x, s_train_y = s_train_x.to(device), s_train_y.to(device)
-s_test_x, s_test_y = s_test_x.to(device), s_test_y.to(device)
+# Propose phi using base sets
+for task, (b_train, b_test) in enumerate(zip(base_train_k, base_test_k)):
+b_train_x, b_train_y = b_train
+b_test_x, b_test_y = b_test
+b_train_x, b_train_y = b_train_x.to(device), b_train_y.to(device)
+b_test_x, b_test_y = b_test_x.to(device), b_test_y.to(device)

## Inner-loop to propose phi using meta-training dataset
-pred = self._get_pred(s_train_x, learner)
-error = self.loss(pred, s_train_y)
+pred = self._get_pred(b_train_x, learner)
+error = self.loss(pred, b_train_y)
learner.adapt(error)

# Inner-loop evaluation (no gradient step) using meta-testing dataset
-pred = self._get_pred(s_test_x, learner)
-error = self.loss(pred, s_test_y)
+pred = self._get_pred(b_test_x, learner)
+error = self.loss(pred, b_test_y)
train_error += error.item()

train_epoch.append(train_error/(task + 1))

-# Outer-loop using query sets
-for task, (q_train, q_test) in enumerate(zip(query_train_k, query_test_k)):
-q_train_x, q_train_y = q_train
-q_test_x, q_test_y = q_test
-q_train_x, q_train_y = q_train_x.to(device), q_train_y.to(device)
-q_test_x, q_test_y = q_test_x.to(device), q_test_y.to(device)
+# Outer-loop using target sets
+for task, (t_train, t_test) in enumerate(zip(target_train_k, target_test_k)):
+t_train_x, t_train_y = t_train
+t_test_x, t_test_y = t_test
+t_train_x, t_train_y = t_train_x.to(device), t_train_y.to(device)
+t_test_x, t_test_y = t_test_x.to(device), t_test_y.to(device)

## accumulate inner-loop gradients given proposed phi
-pred = self._get_pred(q_train_x, learner)
-error = self.loss(pred, q_train_y)
+pred = self._get_pred(t_train_x, learner)
+error = self.loss(pred, t_train_y)
outer_error += error

# Adaptation evaluation (no gradient step) using meta-testing dataset
-pred = self._get_pred(q_test_x, learner)
-error = self.loss(pred, q_test_y)
+pred = self._get_pred(t_test_x, learner)
+error = self.loss(pred, t_test_y)
val_error += error.item()

val_epoch.append(val_error/(task + 1))
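Only the names change in this method; the bi-level structure is untouched: the inner loop adapts a cloned learner on base tasks to propose phi, and the outer loop scores phi on target tasks to update the meta-parameters. A condensed sketch of that pattern, assuming the learn2learn MAML wrapper that the clone()/adapt() calls above point to (model, batches, and learning rates are placeholders):

    import torch
    import torch.nn as nn
    import learn2learn as l2l

    model = nn.Linear(5, 1)                              # placeholder model
    maml = l2l.algorithms.MAML(model, lr=1e-2)           # lr plays the role of update_lr
    opt = torch.optim.Adam(maml.parameters(), lr=1e-3)   # meta_lr
    loss_fn = nn.MSELoss(reduction="mean")

    base_batches = [(torch.randn(8, 5), torch.randn(8, 1))]    # synthetic base task
    target_batches = [(torch.randn(8, 5), torch.randn(8, 1))]  # synthetic target task

    learner = maml.clone()                        # phi starts at theta
    for b_x, b_y in base_batches:                 # inner loop: propose phi on base tasks
        learner.adapt(loss_fn(learner(b_x), b_y))

    outer_error = sum(loss_fn(learner(t_x), t_y)  # outer loss: score phi on target tasks
                      for t_x, t_y in target_batches)
    opt.zero_grad()
    outer_error.backward()                        # gradients flow back to theta
    opt.step()                                    # meta-update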
@@ -198,17 +198,17 @@ def _train_base(
schedule = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)

# processing the data
-support_train_sz = int(len(self.fluxnet_support) * (1 - self.finetune_size))
-query_train_sz = int(len(self.fluxnet_query) * (1 - self.finetune_size))
+base_train_sz = int(len(self.fluxnet_base) * (1 - self.finetune_size))
+target_train_sz = int(len(self.fluxnet_target) * (1 - self.finetune_size))

# split both datasets to train and test
-support_train, support_test = torch.utils.data.random_split(self.fluxnet_support, [support_train_sz, len(self.fluxnet_support) - support_train_sz])
-query_train, query_test = torch.utils.data.random_split(self.fluxnet_query, [query_train_sz, len(self.fluxnet_query) - query_train_sz])
+base_train, base_test = torch.utils.data.random_split(self.fluxnet_base, [base_train_sz, len(self.fluxnet_base) - base_train_sz])
+target_train, target_test = torch.utils.data.random_split(self.fluxnet_target, [target_train_sz, len(self.fluxnet_target) - target_train_sz])

-support_train_dl = DataLoader(support_train, batch_size=self.batch_size, shuffle=True)
-support_test_dl = DataLoader(support_test, batch_size=self.batch_size, shuffle=True)
-query_train_dl = DataLoader(query_train, batch_size=self.batch_size, shuffle=True)
-query_test_dl = DataLoader(query_test, batch_size=self.batch_size, shuffle=True)
+base_train_dl = DataLoader(base_train, batch_size=self.batch_size, shuffle=True)
+base_test_dl = DataLoader(base_test, batch_size=self.batch_size, shuffle=True)
+target_train_dl = DataLoader(target_train, batch_size=self.batch_size, shuffle=True)
+target_test_dl = DataLoader(target_test, batch_size=self.batch_size, shuffle=True)

train_epoch, val_epoch = list(), list()

@@ -217,14 +217,14 @@ def _train_base(
train_error, val_error = 0.0, 0.0

# Get k-shot batches
-support_train_k = self._get_k_shot(support_train_dl, self.max_meta_step)
-support_test_k = self._get_k_shot(support_test_dl, self.max_meta_step)
-query_train_k = self._get_k_shot(query_train_dl, self.max_meta_step)
-query_test_k = self._get_k_shot(query_test_dl, self.max_meta_step)
+base_train_k = self._get_k_shot(base_train_dl, self.max_meta_step)
+base_test_k = self._get_k_shot(base_test_dl, self.max_meta_step)
+target_train_k = self._get_k_shot(target_train_dl, self.max_meta_step)
+target_test_k = self._get_k_shot(target_test_dl, self.max_meta_step)

with torch.backends.cudnn.flags(enabled=False):
# Baseline learning + evaluation
-for task, (s_train, s_test) in enumerate(zip(support_train_k, support_test_k)):
+for task, (s_train, s_test) in enumerate(zip(base_train_k, base_test_k)):
s_train_x, s_train_y = s_train
s_test_x, s_test_y = s_test
s_train_x, s_train_y = s_train_x.to(device), s_train_y.to(device)
Expand All @@ -242,20 +242,20 @@ def _train_base(
train_epoch.append(train_error/(task + 1))

# Baseline-equivalent to meta-adaptation + validation
-for task, (q_train, q_test) in enumerate(zip(query_train_k, query_test_k)):
-q_train_x, q_train_y = q_train
-q_test_x, q_test_y = q_test
-q_train_x, q_train_y = q_train_x.to(device), q_train_y.to(device)
-q_test_x, q_test_y = q_test_x.to(device), q_test_y.to(device)
-
-pred = self.base_model(q_train_x[:,:,:self.input_size])
-error = self.loss(pred, q_train_y)
+for task, (t_train, t_test) in enumerate(zip(target_train_k, target_test_k)):
+t_train_x, t_train_y = t_train
+t_test_x, t_test_y = t_test
+t_train_x, t_train_y = t_train_x.to(device), t_train_y.to(device)
+t_test_x, t_test_y = t_test_x.to(device), t_test_y.to(device)
+
+pred = self.base_model(t_train_x[:,:,:self.input_size])
+error = self.loss(pred, t_train_y)
error.backward()
opt.step()

## Note: no gradient step
-val_error += self.loss(self.base_model(q_test_x[:,:,:self.input_size]),
-q_test_y).item()
+val_error += self.loss(self.base_model(t_test_x[:,:,:self.input_size]),
+t_test_y).item()

schedule.step()

26 changes: 19 additions & 7 deletions metaflux/utils.py
@@ -138,7 +138,7 @@ def _get_robustness_data(x, y, factor):

return meta_c_loss_dict, meta_loss_dict, base_loss_dict

-def extreme_analysis(maml, base, hyper_args, data_dir, factor=1., is_plot=True):
+def extreme_analysis(maml, base, hyper_args, data_dir, model_type="mlp", factor=1., is_plot=True):
def _get_pred(learner, encoder, x):
with_context = False if encoder == None else True
"Subroutine to perform prediction on input x"
@@ -150,8 +150,17 @@ def _get_pred(learner, encoder, x):

def _get_extreme(x, y, y_norm, factor=factor):
xtr_mask = y_norm > factor
-xtr_x = x[xtr_mask.squeeze()]
-xtr_y = y[xtr_mask.squeeze()]
+xtr_index = [index for index, value in enumerate(xtr_mask.squeeze()) if value]
+xtr_x = list()
+if model_type == "mlp":
+xtr_x = x[xtr_mask.squeeze()]
+xtr_y = y[xtr_mask.squeeze()]
+else:
+for xtr_idx in xtr_index[30:]:
+xtr_x.append(x[xtr_idx - 30 : xtr_idx, :])
+xtr_y = y[xtr_mask.squeeze()]
+xtr_y = xtr_y[30:]

return torch.tensor(xtr_x).to(device), torch.tensor(xtr_y).to(device)
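For sequence models, each extreme sample now carries its trailing 30 timesteps of inputs rather than a single row. A small numpy illustration of the windowing; note that the guard here simply skips extremes with fewer than 30 preceding steps, which is slightly different bookkeeping from the xtr_index[30:] slice above, and the data are synthetic:

    import numpy as np

    W = 30                                         # history length used above
    x = np.random.randn(100, 5)                    # (timesteps, features), synthetic
    y_norm = np.random.randn(100, 1)
    idxs = np.flatnonzero(y_norm.squeeze() > 1.0)  # indices of extreme events

    # one (W, n_features) slice per extreme; assumes at least one extreme
    # has a full W-step history
    windows = np.stack([x[i - W:i, :] for i in idxs if i >= W])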

learner = maml.clone().double()
@@ -166,10 +175,13 @@ def _get_extreme(x, y, y_norm, factor=factor):
for i, station in all_df.groupby("Site"):
x, y, y_norm = station[hyper_args['xcolumns']].to_numpy(), station[hyper_args['ycolumn']].to_numpy(), station[[f"{hyper_args['ycolumn'][0]}_norm"]].to_numpy()
x, y = _get_extreme(x, y, y_norm)
-maml_pred = _get_pred(learner, encoder=None, x=x)
-base_pred = base(x)

try:
+maml_pred = _get_pred(learner, encoder=None, x=x)
+base_pred = base(x)
+if model_type != "mlp":
+maml_pred = maml_pred[:,-1:]
+base_pred = base_pred[:,-1:]

extreme_under_d = {
"climate": station["Climate"].iloc[0],
"lon": station["Lon"].iloc[0],
@@ -191,7 +203,7 @@ def _get_extreme(x, y, y_norm, factor=factor):
"maml_over_mean": abs(maml_pred[maml_pred > y] - y[maml_pred > y]).mean().item(),
}

-except:
+except Exception as e:
continue

extreme_under_df.append(extreme_under_d)
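The new model_type keyword defaults to "mlp", so existing calls keep their behavior and a temporal model opts in explicitly. A hypothetical invocation; the argument values, the "lstm" tag, and the returned name are placeholders, with maml and base_model being the attributes a trained Learner exposes:

    results = extreme_analysis(
        maml=learner.maml,
        base=learner.base_model,
        hyper_args=hyper_args,
        data_dir="../metaflux/data/sample/",
        model_type="lstm",   # any non-"mlp" value triggers the 30-step windowing
        factor=1.0,
    )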
12 changes: 8 additions & 4 deletions notebooks/01a_non_temporal_pipeline.ipynb
@@ -46,6 +46,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
@@ -95,6 +96,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
@@ -110,11 +112,12 @@
"root_dir = '../metaflux/data/sample/'\n",
"\n",
"# Note that the inputs are normalized here. For non-temporal data, we specify time_window = 1\n",
"fluxnet_train = metaflux.dataloader.Fluxmetanet(root=root_dir, mode=\"train\", x_columns=hyper_args[\"xcolumns\"], y_column=ycolumn, context_columns=hyper_args[\"contextcolumns\"] , time_column=None, time_window=1)\n",
"fluxnet_test = metaflux.dataloader.Fluxmetanet(root=root_dir, mode=\"test\", x_columns=hyper_args[\"xcolumns\"], y_column=ycolumn, context_columns=hyper_args[\"contextcolumns\"], time_column=None, time_window=1)"
"fluxnet_base = metaflux.dataloader.Fluxmetanet(root=root_dir, mode=\"base\", x_columns=hyper_args[\"xcolumns\"], y_column=ycolumn, context_columns=hyper_args[\"contextcolumns\"] , time_column=None, time_window=1)\n",
"fluxnet_target = metaflux.dataloader.Fluxmetanet(root=root_dir, mode=\"target\", x_columns=hyper_args[\"xcolumns\"], y_column=ycolumn, context_columns=hyper_args[\"contextcolumns\"], time_column=None, time_window=1)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
@@ -135,8 +138,8 @@
" input_size=hyper_args[\"input_size\"], \n",
" hidden_size=hyper_args[\"hidden_size\"], \n",
" model_type=model_type, \n",
" fluxnet_train=fluxnet_train,\n",
" fluxnet_test=fluxnet_test,\n",
" fluxnet_base=fluxnet_base,\n",
" fluxnet_target=fluxnet_target,\n",
" update_lr=hyper_args[\"update_lr\"],\n",
" meta_lr=hyper_args[\"meta_lr\"],\n",
" batch_size=hyper_args[\"batch_size\"],\n",
@@ -149,6 +152,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
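Taken together, the notation changes in this commit reduce to a small rename map:

    "train" mode            -> "base"
    "test" mode             -> "target"
    fluxnet_train kwarg     -> fluxnet_base
    fluxnet_test kwarg      -> fluxnet_target
    support_* / s_* locals  -> base_* / b_*
    query_* / q_* locals    -> target_* / t_*
    extreme_analysis()      gains a model_type="mlp" keyword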