diff --git a/utils.py b/utils.py index d1d1707..4cf48bf 100644 --- a/utils.py +++ b/utils.py @@ -329,6 +329,7 @@ def train(dataloader, pde_loss.backward() if use_dist: average_gradients(model) + torch.nn.utils.clip_grad_norm_(model.parameters(), 10) optimizer.step() total_examples_finished += len(data) print(f"{device}: loss: {pde_loss.item():>7f} [{total_examples_finished:>5d}/{size:>5d}]")