clearer notation and update README for preprint
juannat7 committed Apr 2, 2023
1 parent 0910915 commit 7892710
Showing 9 changed files with 40 additions and 61 deletions.
14 changes: 9 additions & 5 deletions README.md
@@ -16,11 +16,15 @@ git clone https://github.com/juannat7/metaflux.git
pip install -r requirements.txt
```

![Meta inference](https://github.com/juannat7/metaflux/blob/main/docs/gpp_infer.jpeg)

![Meta inference with context encoder](https://github.com/juannat7/metaflux/blob/main/docs/gpp_encoder_infer.jpeg)
## Sample notebooks
These sample notebooks demonstrate the application of meta-learning to spatiotemporal domain adaptation. In particular, we infer gross primary production (GPP) from key meteorological and remote sensing variables by learning key features in data-abundant regions and adapting them to fluxes in data-sparse areas. We demonstrate meta-learning in non-temporal, temporal, and spatial-context settings. Feel free to apply the algorithms presented in the notebooks to your specific use cases:

## Sample Notebooks
1. `01a_non_temporal_pipeline`: for non-temporal datasets and models (e.g., MLP)
2. `01b_temporal_pipeline`: for temporal datasets and models (e.g., LSTM, BiLSTM)
3. (experimental) `01c_with_encoder_pipeline`: adds a context encoder to the classic meta-learning model

![Meta inference](https://github.com/juannat7/metaflux/blob/main/docs/gpp_infer.jpeg)

3. (experimental) `01c_with_encoder_pipeline`: adds a context encoder to the classic meta-learning model

![Meta inference with context encoder](https://github.com/juannat7/metaflux/blob/main/docs/gpp_encoder_infer.jpeg)
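
For readers of this README section, below is a minimal, self-contained sketch of the support/query setup it describes: "support" stands in for data-abundant source regions and "query" for data-sparse target regions. The tensors are synthetic placeholders rather than real FLUXNET data, and the `finetune_size` split mirrors the value in the configs changed by this commit.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader, random_split

def make_region(n_samples=512, input_size=4):
    """Synthetic stand-in for one region's FLUXNET-style data."""
    x = torch.randn(n_samples, input_size)   # meteorological / remote-sensing predictors
    y = torch.randn(n_samples, 1)            # GPP target
    return TensorDataset(x, y)

support = make_region()   # data-abundant source regions
query = make_region()     # data-sparse target regions

finetune_size = 0.2       # matches the configs touched by this commit
n_train = int(len(support) * (1 - finetune_size))
support_train, support_test = random_split(support, [n_train, len(support) - n_train])

support_train_dl = DataLoader(support_train, batch_size=256, shuffle=True)
support_test_dl = DataLoader(support_test, batch_size=256, shuffle=True)
```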
Binary file modified docs/gpp_infer.jpeg
2 changes: 0 additions & 2 deletions metaflux/configs/hyperparams_1a.yaml
@@ -4,8 +4,6 @@ batch_size: 256
input_size: 4
meta_lr: 1.0e-03
update_lr: 1.0e-4
meta_time_lr: 1.0e-02
update_time_lr: 1.0e-3
num_lstm_layers: 1
max_meta_step: 2
finetune_size: 0.2
2 changes: 0 additions & 2 deletions metaflux/configs/hyperparams_1b.yaml
@@ -4,8 +4,6 @@ batch_size: 256
input_size: 4
meta_lr: 1.0e-03
update_lr: 1.0e-4
meta_time_lr: 1.0e-02
update_time_lr: 1.0e-3
num_lstm_layers: 1
max_meta_step: 2
finetune_size: 0.2
2 changes: 0 additions & 2 deletions metaflux/configs/hyperparams_1c.yaml
@@ -4,8 +4,6 @@ batch_size: 256
input_size: 4
meta_lr: 1.0e-03
update_lr: 1.0e-4
meta_time_lr: 1.0e-02
update_time_lr: 1.0e-3
num_lstm_layers: 1
max_meta_step: 2
finetune_size: 0.2
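
For context, here is a hedged sketch of how one of the trimmed configs above might feed a temporal base model. The YAML is inlined (only the lines visible in this diff) so the snippet is self-contained, and `hidden_size` is an assumed value that is not part of the shown config.

```python
import yaml
import torch.nn as nn

config_text = """
batch_size: 256
input_size: 4
meta_lr: 1.0e-03
update_lr: 1.0e-4
num_lstm_layers: 1
max_meta_step: 2
finetune_size: 0.2
"""  # normally read from metaflux/configs/hyperparams_1b.yaml
hp = yaml.safe_load(config_text)

base_model = nn.LSTM(
    input_size=hp["input_size"],       # 4 predictors, as in the configs
    hidden_size=64,                    # assumed; not part of the shown config
    num_layers=hp["num_lstm_layers"],
    batch_first=True,
)
```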
44 changes: 19 additions & 25 deletions metaflux/learner.py
@@ -43,8 +43,8 @@ def __init__(
self.input_size = input_size
self.hidden_size = hidden_size
self.model_type = model_type
self.fluxnet_train = fluxnet_train
self.fluxnet_test = fluxnet_test
self.fluxnet_support = fluxnet_train
self.fluxnet_query = fluxnet_test
self.update_lr = update_lr
self.meta_lr = meta_lr
self.batch_size = batch_size
@@ -83,12 +83,12 @@ def train_meta(
schedule = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)

# processing the data
support_train_sz = int(len(self.fluxnet_train) * (1 - self.finetune_size))
query_train_sz = int(len(self.fluxnet_test) * (1 - self.finetune_size))
support_train_sz = int(len(self.fluxnet_support) * (1 - self.finetune_size))
query_train_sz = int(len(self.fluxnet_query) * (1 - self.finetune_size))

# split both datasets to train and test
support_train, support_test = torch.utils.data.random_split(self.fluxnet_train, [support_train_sz, len(self.fluxnet_train) - support_train_sz])
query_train, query_test = torch.utils.data.random_split(self.fluxnet_test, [query_train_sz, len(self.fluxnet_test) - query_train_sz])
# split both datasets to meta-train and meta-test
support_train, support_test = torch.utils.data.random_split(self.fluxnet_support, [support_train_sz, len(self.fluxnet_support) - support_train_sz])
query_train, query_test = torch.utils.data.random_split(self.fluxnet_query, [query_train_sz, len(self.fluxnet_query) - query_train_sz])

support_train_dl = DataLoader(support_train, batch_size=self.batch_size, shuffle=True)
support_test_dl = DataLoader(support_test, batch_size=self.batch_size, shuffle=True)
@@ -125,40 +125,38 @@ def train_meta(
with torch.backends.cudnn.flags(enabled=False):
learner = self.maml.clone().double()

# Propose phi using support sets
for task, (s_train, s_test) in enumerate(zip(support_train_k, support_test_k)):
s_train_x, s_train_y = s_train
s_test_x, s_test_y = s_test

# transfer to GPU device
s_train_x, s_train_y = s_train_x.to(device), s_train_y.to(device)
s_test_x, s_test_y = s_test_x.to(device), s_test_y.to(device)

# Meta-training
## Inner-loop to propose phi using meta-training dataset
pred = self._get_pred(s_train_x, learner)
error = self.loss(pred, s_train_y)
learner.adapt(error)

# Meta-learning evaluation (no gradient step)
# Inner-loop evaluation (no gradient step) using meta-testing dataset
pred = self._get_pred(s_test_x, learner)
error = self.loss(pred, s_test_y)
train_error += error.item()

train_epoch.append(train_error/(task + 1))

# Outer-loop using query sets
for task, (q_train, q_test) in enumerate(zip(query_train_k, query_test_k)):
q_train_x, q_train_y = q_train
q_test_x, q_test_y = q_test

# transfer to GPU device
q_train_x, q_train_y = q_train_x.to(device), q_train_y.to(device)
q_test_x, q_test_y = q_test_x.to(device), q_test_y.to(device)

# Update (accumulate inner-loop gradients)
## accumulate inner-loop gradients given proposed phi
pred = self._get_pred(q_train_x, learner)
error = self.loss(pred, q_train_y)
outer_error += error

# Adaptation evaluation (no gradient step)
# Adaptation evaluation (no gradient step) using meta-testing dataset
pred = self._get_pred(q_test_x, learner)
error = self.loss(pred, q_test_y)
val_error += error.item()
@@ -200,12 +198,12 @@ def _train_base(
schedule = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)

# processing the data
support_train_sz = int(len(self.fluxnet_train) * (1 - self.finetune_size))
query_train_sz = int(len(self.fluxnet_test) * (1 - self.finetune_size))
support_train_sz = int(len(self.fluxnet_support) * (1 - self.finetune_size))
query_train_sz = int(len(self.fluxnet_query) * (1 - self.finetune_size))

# split both datasets to train and test
support_train, support_test = torch.utils.data.random_split(self.fluxnet_train, [support_train_sz, len(self.fluxnet_train) - support_train_sz])
query_train, query_test = torch.utils.data.random_split(self.fluxnet_test, [query_train_sz, len(self.fluxnet_test) - query_train_sz])
support_train, support_test = torch.utils.data.random_split(self.fluxnet_support, [support_train_sz, len(self.fluxnet_support) - support_train_sz])
query_train, query_test = torch.utils.data.random_split(self.fluxnet_query, [query_train_sz, len(self.fluxnet_query) - query_train_sz])

support_train_dl = DataLoader(support_train, batch_size=self.batch_size, shuffle=True)
support_test_dl = DataLoader(support_test, batch_size=self.batch_size, shuffle=True)
@@ -225,15 +223,13 @@ def _train_base(
query_test_k = self._get_k_shot(query_test_dl, self.max_meta_step)

with torch.backends.cudnn.flags(enabled=False):
# Baseline learning + evaluation
for task, (s_train, s_test) in enumerate(zip(support_train_k, support_test_k)):
s_train_x, s_train_y = s_train
s_test_x, s_test_y = s_test

# transfer to GPU device
s_train_x, s_train_y = s_train_x.to(device), s_train_y.to(device)
s_test_x, s_test_y = s_test_x.to(device), s_test_y.to(device)

# Baseline learning + evaluation
pred = self.base_model(s_train_x[:,:,:self.input_size])
error = self.loss(pred, s_train_y)
error.backward()
Expand All @@ -245,15 +241,13 @@ def _train_base(

train_epoch.append(train_error/(task + 1))

# Baseline-equivalent to meta-adaptation + validation
for task, (q_train, q_test) in enumerate(zip(query_train_k, query_test_k)):
q_train_x, q_train_y = q_train
q_test_x, q_test_y = q_test

# transfer to GPU device
q_train_x, q_train_y = q_train_x.to(device), q_train_y.to(device)
q_test_x, q_test_y = q_test_x.to(device), q_test_y.to(device)

# Baseline adaptation + validation
pred = self.base_model(q_train_x[:,:,:self.input_size])
error = self.loss(pred, q_train_y)
error.backward()
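
The `train_meta` changes above follow the standard MAML inner/outer-loop pattern. The sketch below shows that pattern with the learn2learn library (the apparent source of `maml.clone()` and `learner.adapt()`), using an illustrative model and the learning rates from the configs in this commit rather than the repository's actual `Learner` class.

```python
import torch
import learn2learn as l2l

model = torch.nn.Sequential(torch.nn.Linear(4, 64), torch.nn.ReLU(), torch.nn.Linear(64, 1))
maml = l2l.algorithms.MAML(model, lr=1e-4)           # inner-loop rate (update_lr)
opt = torch.optim.Adam(maml.parameters(), lr=1e-3)   # outer-loop rate (meta_lr)
loss_fn = torch.nn.MSELoss()

def meta_step(support_batch, query_batch):
    s_x, s_y = support_batch
    q_x, q_y = query_batch

    learner = maml.clone()                           # task-specific copy (proposes phi)
    learner.adapt(loss_fn(learner(s_x), s_y))        # inner-loop step on the support set

    outer_loss = loss_fn(learner(q_x), q_y)          # evaluate adapted phi on the query set
    opt.zero_grad()
    outer_loss.backward()                            # gradients reach the meta-parameters
    opt.step()
    return outer_loss.item()

# usage with synthetic batches:
batch = lambda: (torch.randn(32, 4), torch.randn(32, 1))
print(meta_step(batch(), batch()))
```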
10 changes: 4 additions & 6 deletions notebooks/01a_non_temporal_pipeline.ipynb

Large diffs are not rendered by default.

23 changes: 7 additions & 16 deletions notebooks/01b_temporal_pipeline.ipynb

Large diffs are not rendered by default.

4 changes: 1 addition & 3 deletions notebooks/01c_with_encoder_pipeline.ipynb
@@ -67,8 +67,6 @@
" 'input_size': 5,\n",
" 'meta_lr': 0.001,\n",
" 'update_lr': 0.0001,\n",
" 'meta_time_lr': 0.01,\n",
" 'update_time_lr': 0.001,\n",
" 'num_lstm_layers': 1,\n",
" 'max_meta_step': 2,\n",
" 'finetune_size': 0.2,\n",
@@ -175,7 +173,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 7,
"metadata": {},
"outputs": [
{
