Skip to content

Commit

Permalink
fix: Fix gradients plot
Browse files Browse the repository at this point in the history
  • Loading branch information
andrea-pasquale committed Feb 14, 2024
1 parent 79a38c4 commit c51272b
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions plotscripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,9 @@ def plot_gradients(
"""
grads = dict(np.load(path / f"{GRADS_FILE + '.npz'}"))
config = json.loads((path / OPTIMIZATION_FILE).read_text())

ave_grads = []

print(len(grads["0"]))
# print([i for i in ])
for epoch in grads:
for grads_list in grads[epoch]:
ave_grads.append(np.mean(np.abs(grads_list)))
Expand All @@ -162,7 +162,7 @@ def plot_gradients(
label=r"$\langle |\partial_{\theta_i}\text{L}| \rangle_i$",
)
for b in range(config["nboost"] - 1):
boost_x = config["boost_frequency"] * (b + 1)
boost_x = len(grads[str(b)]) * (b + 1)
if b == 0:
plt.plot(
(boost_x, boost_x + 1),
Expand Down

0 comments on commit c51272b

Please sign in to comment.