Skip to content

Commit

Permalink
add the direction loss, from a paper out of Wuhan China for accelerat…
Browse files Browse the repository at this point in the history
…ing DiT training
  • Loading branch information
lucidrains committed Nov 5, 2024
1 parent 291b3a4 commit 1b3af2e
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 3 deletions.
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,12 @@ sampled_actions = model(vision, commands, joint_state, trajectory_length = 32) #
url = {https://api.semanticscholar.org/CorpusID:273532030}
}
```

```bibtex
@inproceedings{Yao2024FasterDiTTF,
title = {FasterDiT: Towards Faster Diffusion Transformers Training without Architecture Modification},
author = {Jingfeng Yao and Wang Cheng and Wenyu Liu and Xinggang Wang},
year = {2024},
url = {https://api.semanticscholar.org/CorpusID:273346237}
}
```
27 changes: 25 additions & 2 deletions pi_zero_pytorch/pi_zero.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ def softclamp(t, value):

return (t / value).tanh() * value

# losses

def direction_loss(pred, target, dim = -1):
    """Cosine-direction loss from FasterDiT (Yao et al., 2024).

    Maps cosine similarity along ``dim`` into a loss in [0, 1]:
    0 when ``pred`` and ``target`` point the same way, 1 when opposite.
    """
    cos_sim = F.cosine_similarity(pred, target, dim = dim)
    return (1. - cos_sim) * 0.5

# attention

class Attention(Module):
Expand Down Expand Up @@ -265,6 +270,7 @@ def __init__(
ff_kwargs: dict = dict(),
lm_loss_weight = 1.,
flow_loss_weight = 1.,
direction_loss_weight = 0.,
odeint_kwargs: dict = dict(
atol = 1e-5,
rtol = 1e-5,
Expand Down Expand Up @@ -340,10 +346,15 @@ def __init__(
self.lm_loss_weight = lm_loss_weight
self.flow_loss_weight = flow_loss_weight

self.has_direction_loss = direction_loss_weight > 0.
self.direction_loss_weight = direction_loss_weight

# sampling related

self.odeint_fn = partial(odeint, **odeint_kwargs)

self.register_buffer('zero', torch.tensor(0.), persistent = False)

@property
def device(self):
return next(self.parameters()).device
Expand Down Expand Up @@ -541,6 +552,13 @@ def forward(

flow_loss = F.mse_loss(flow, pred_actions_flow)

# maybe direction loss

dir_loss = self.zero

if self.has_direction_loss:
dir_loss = direction_loss(flow, pred_actions_flow)

# language cross entropy loss

language_logits = self.state_to_logits(tokens)
Expand All @@ -550,14 +568,19 @@ def forward(
labels
)

# loss breakdown

loss_breakdown = (language_loss, flow_loss, dir_loss)

# total loss and return breakdown

total_loss = (
language_loss * self.lm_loss_weight +
flow_loss * self.flow_loss_weight
flow_loss * self.flow_loss_weight +
dir_loss * self.direction_loss_weight
)

return total_loss, (language_loss, flow_loss)
return total_loss, loss_breakdown

# fun

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "pi-zero-pytorch"
version = "0.0.1"
version = "0.0.2"
description = "π0 in Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
Expand Down

0 comments on commit 1b3af2e

Please sign in to comment.