
Commit

updated notations
juannat7 committed Jun 14, 2023
1 parent eba0a8a commit 6a9a329
Showing 4 changed files with 16 additions and 13 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -26,5 +26,6 @@ by learning key features in data-abundant regions and adapt them to fluxes in da
![Meta inference](https://github.com/juannat7/metaflux/blob/main/docs/gpp_infer.jpeg)

3. (experimental) `01c_with_encoder_pipeline`: adds a context encoder to the classic metalearning model (a minimal sketch follows the workflow images below)

![Encoder workflow](https://github.com/juannat7/metaflux/blob/main/docs/encoder_workflow.png)
![Meta inference with context encoder](https://github.com/juannat7/metaflux/blob/main/docs/gpp_encoder_infer.jpeg)
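The encoder variant in item 3 conditions the metalearning base model on an encoded context vector. Below is a minimal sketch of that idea; the module names (`ContextEncoder`, `EncoderAugmentedModel`), dimensions, and the concatenate-per-time-step scheme are illustrative assumptions, not metaflux's actual API:

```python
import torch
import torch.nn as nn

class ContextEncoder(nn.Module):
    """Illustrative context encoder: maps auxiliary inputs to a fixed-size embedding."""
    def __init__(self, context_size: int, hidden_size: int):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(context_size, hidden_size), nn.ReLU())

    def forward(self, context: torch.Tensor) -> torch.Tensor:
        return self.net(context)

class EncoderAugmentedModel(nn.Module):
    """Concatenates the context embedding to each time step before the base learner."""
    def __init__(self, base_model: nn.Module, encoder: ContextEncoder):
        super().__init__()
        self.base_model = base_model
        self.encoder = encoder

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        # x: (batch, time, features); context: (batch, context_size)
        z = self.encoder(context)                     # (batch, hidden)
        z = z.unsqueeze(1).expand(-1, x.size(1), -1)  # broadcast across time steps
        return self.base_model(torch.cat([x, z], dim=-1))
```

In this sketch the base learner simply sees a wider feature vector at every time step; how metaflux actually fuses the context is defined in the repository, not here.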
24 changes: 13 additions & 11 deletions metaflux/learner.py
@@ -224,20 +224,20 @@ def _train_base(

with torch.backends.cudnn.flags(enabled=False):
    # Baseline learning + evaluation
-   for task, (s_train, s_test) in enumerate(zip(base_train_k, base_test_k)):
-       s_train_x, s_train_y = s_train
-       s_test_x, s_test_y = s_test
-       s_train_x, s_train_y = s_train_x.to(device), s_train_y.to(device)
-       s_test_x, s_test_y = s_test_x.to(device), s_test_y.to(device)
-
-       pred = self.base_model(s_train_x[:,:,:self.input_size])
-       error = self.loss(pred, s_train_y)
+   for task, (b_train, b_test) in enumerate(zip(base_train_k, base_test_k)):
+       b_train_x, b_train_y = b_train
+       b_test_x, b_test_y = b_test
+       b_train_x, b_train_y = b_train_x.to(device), b_train_y.to(device)
+       b_test_x, b_test_y = b_test_x.to(device), b_test_y.to(device)
+
+       pred = self.base_model(b_train_x[:,:,:self.input_size])
+       error = self.loss(pred, b_train_y)
        error.backward()
        opt.step()

        ## Note: no gradient step
-       train_error += self.loss(self.base_model(s_test_x[:,:,:self.input_size]),
-                                s_test_y).item()
+       train_error += self.loss(self.base_model(b_test_x[:,:,:self.input_size]),
+                                b_test_y).item()

    train_epoch.append(train_error/(task + 1))

@@ -293,7 +293,8 @@ def _grad_step(
        encoder_schedule
    ) -> None:
        "Subroutine to perform gradient step"

+       # normalize loss with the number of task
        for _, p in self.maml.named_parameters():
            p.grad.data.mul_(1.0/(task + 1))
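The in-place `mul_` above rescales the gradients accumulated over the task loop, so the meta-update uses the mean rather than the sum of per-task gradients. A standalone sketch of the same pattern; the function and argument names here are illustrative, not part of metaflux:

```python
import torch

def meta_update(maml: torch.nn.Module, meta_opt: torch.optim.Optimizer, task_losses) -> None:
    """Average gradients accumulated across tasks, then take one optimizer step."""
    meta_opt.zero_grad()
    for loss in task_losses:
        loss.backward()  # per-task gradients accumulate (sum) into .grad
    for _, p in maml.named_parameters():
        if p.grad is not None:
            p.grad.data.mul_(1.0 / len(task_losses))  # sum -> mean over tasks
    meta_opt.step()
```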

@@ -311,5 +311,6 @@ def _get_k_shot(
        dataloader,
        k
    ):
+       "Get k-shot samples from each batch"
        selected_dataloader = itertools.islice(iter(dataloader), 0, min(len(dataloader), k))
        return selected_dataloader
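The `islice` call above simply caps how many batches are drawn from a dataloader. A quick, self-contained illustration of that behavior; the dataset shapes and names below are made up for the example:

```python
import itertools

import torch
from torch.utils.data import DataLoader, TensorDataset

loader = DataLoader(TensorDataset(torch.randn(100, 8), torch.randn(100, 1)), batch_size=10)

k = 3  # keep only the first k batches (the "k-shot" subset)
for x, y in itertools.islice(iter(loader), 0, min(len(loader), k)):
    print(x.shape, y.shape)  # torch.Size([10, 8]) torch.Size([10, 1]), printed three times
```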
2 changes: 1 addition & 1 deletion notebooks/01b_temporal_pipeline.ipynb
@@ -172,7 +172,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 7,
"metadata": {},
"outputs": [
{
2 changes: 1 addition & 1 deletion notebooks/01c_with_encoder_pipeline.ipynb
@@ -175,7 +175,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 7,
"metadata": {},
"outputs": [
{
