global_optimization_fwdgrad.py
import time

import hydra
import torch
import torch.func as fc


@hydra.main(config_path="./configs/", config_name="global_optim_config.yaml")
def main(cfg):
    torch.manual_seed(cfg.seed)
    # Random starting point for the two optimization variables
    params = torch.rand(2)
    t_total = 0
    for iteration in range(cfg.iterations):
        t0 = time.perf_counter()
        # Sample perturbation (tangent) vector
        v_params = torch.randn_like(params)
        # Forward AD: evaluate the objective and its directional derivative
        # along v_params in a single forward pass
        func_value, jvp = fc.jvp(
            hydra.utils.call(cfg.function),
            (params,),
            (v_params,),
        )
        # Forward gradient + parameter update (SGD): (∇f·v) v is an
        # unbiased estimate of the gradient when v ~ N(0, I)
        params.sub_(cfg.learning_rate * jvp * v_params)
        t1 = time.perf_counter()
        t_total += t1 - t0
        if iteration % 199 == 0:
            print(
                f"Iteration [{iteration + 1}/{cfg.iterations}], "
                f"Loss: {func_value.item():.4f}, Time (s): {t1 - t0:.4f}"
            )
    print(f"Total time: {t_total:.4f}")
    print(f"Parameters value:\n\tx: {params[0]}\n\ty: {params[1]}")


if __name__ == "__main__":
    main()