Skip to content

Commit

Permalink
merge
Browse files Browse the repository at this point in the history
  • Loading branch information
lehner committed Sep 26, 2024
2 parents 4c1b127 + e5bc6f8 commit 4a473d5
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 21 deletions.
72 changes: 51 additions & 21 deletions applications/hmc/dwf/ensembleK.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,10 +221,11 @@ def mk_slv_s(solution_space, N, grid = None):
css = [
([], 8, F_grid_eo),
([], 16, F_grid_eo),
([], 8, F_grid)
([], 4, F_grid)
]

def store_css():
global css
ckp_css = g.checkpointer(f"{dst}/checkpoint3")
ckp_css.grid = U[0].grid

Expand All @@ -235,23 +236,44 @@ def store_css():
g.message("with grid", str(cc[i].grid))
ckp_css.save(cc[i])

# g.save(f"{dst}/css", css)

if os.path.exists(f"{dst}/checkpoint3"):
ckp_css = g.checkpointer(f"{dst}/checkpoint3")
def load_css():
global css

if os.path.exists(f"{dst}/checkpoint3"):
ckp_css = g.checkpointer(f"{dst}/checkpoint3")
ckp_css.grid = U[0].grid

for cc, ncc, cgrid in css:
params = [0.0]
if not ckp_css.load(params):
g.message("No more fields to load")
break
g.message(f"Loading {int(params[0])} solution fields")
for i in range(int(params[0])):
nn = g.vspincolor(cgrid)
if ckp_css.load(nn):
cc.append(nn)

def store_cfields(tag, flds):
ckp_css = g.checkpointer(f"{dst}/checkpoint.{tag}")
ckp_css.grid = U[0].grid

for cc, ncc, cgrid in css:
params = [0.0]
if not ckp_css.load(params):
g.message("No more fields to load")
break
g.message(f"Loading {int(params[0])} solution fields")
for i in range(int(params[0])):
nn = g.vspincolor(cgrid)
if ckp_css.load(nn):
g.message("Field norm2", g.norm2(nn))
cc.append(nn)
for i in range(len(flds)):
flds.save(flds[i])

def load_cfields(tag, flds):
if not os.path.exists(f"{dst}/checkpoint.{tag}"):
ckp_css = g.checkpointer(f"{dst}/checkpoint.{tag}")
ckp_css.grid = U[0].grid

for i in range(len(flds)):
if not ckp_css.load(flds[i]):
return False
g.message(f"Successfully restored {tag}")
return True


load_css()

hasenbusch_ratios = [ # Nf=2+1
(0.65, 1.0, None, two_flavor_ratio, mk_chron(cg_e, *css[0]), mk_chron(cg_s, *css[0]), light),
Expand Down Expand Up @@ -380,12 +402,17 @@ def fermion_force():
for i in range(len(hasenbusch_ratios)):
g.message(f"Hasenbusch ratio {hasenbusch_ratios[i][0]}/{hasenbusch_ratios[i][1]}")
forces[i] = action_fermions_s[i].gradient(fields[i], fields[i][0 : len(U)])
g.message("Ratio complete")

g.message("Log Time")
log.time()

g.message("Add and log forces")
for i in range(len(hasenbusch_ratios)):
log.gradient(forces[i], f"{hasenbusch_ratios[i][0]}/{hasenbusch_ratios[i][1]} {i}")
for j in range(len(x)):
x[j] += forces[i][j]
g.message("Done")

g.message(f"Fermion force done")
return x
Expand Down Expand Up @@ -430,7 +457,7 @@ def log_det_force_sp():
ip_log_det_fg = sympl.update_p_force_gradient(U, iq, U_mom, ip_log_det, ip_log_det_sp)

mdint = sympl.OMF2_force_gradient(
4, ip_fermion,
1, ip_fermion,
sympl.OMF2_force_gradient(2, ip_log_det,
sympl.OMF2_force_gradient(2, ip_gauge, iq, ip_gauge_fg),
ip_log_det_fg),
Expand All @@ -457,13 +484,16 @@ def hmc(tau):
else:
h0, s0 = params[-2:]
g.message("After H(true)",h0,s0)
for its in range(10):
g.message(f"tau-iteration: {its} -> {tau/10*its}")
for its in range(40):
g.message(f"tau-iteration: {its} -> {tau/40*its}")
if not ckp.load(params + U):
mdint(tau / 10)
mdint(tau / 40)
ckp.save(params + U)
if it % 2 == 1:
if it % 2 == 0:
store_css()

store_cfields(f"{its}", params + U)
# sys.exit(0)
g.message("After mdint(tau)")
h1, s1 = hamiltonian(False)
g.message("After H(false)")
Expand Down
27 changes: 27 additions & 0 deletions lib/gpt/qcd/pseudofermion/action/exact_one_flavor.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,21 +118,48 @@ def gradient(self, fields, dfields):

frc = self._allocate_force(U)

g.barrier()
g.message("checkmark 0")
g.barrier()

Check failure on line 124 in lib/gpt/qcd/pseudofermion/action/exact_one_flavor.py

View workflow job for this annotation

GitHub Actions / lint

W293:blank line contains whitespace
inv_M12_adj = self.inverter(M12_adj)

g.barrier()
g.message("checkmark 1")
g.barrier()

inv_M11_adj = self.inverter(M11_adj)

g.barrier()
g.message("checkmark 2")
g.barrier()


m1 = self.m1
m2 = self.m2

w_plus = g(inv_M12_adj * M12.R * g.gamma[5] * M12.ImportUnphysicalFermion * Pplus * phi)

g.barrier()
g.message("checkmark 3")
g.barrier()

w_minus = g(inv_M11_adj * M11.R * g.gamma[5] * M11.ImportUnphysicalFermion * Pminus * phi)

g.barrier()
g.message("checkmark 4")
g.barrier()

w2_plus = g(g.gamma[5] * M12.R * M12.Dminus.adj() * w_plus)
w3_plus = g(Pplus * phi)

w2_minus = g(g.gamma[5] * M11.R * M11.Dminus.adj() * w_minus)
w3_minus = g(Pminus * phi)

g.barrier()
g.message("checkmark 5")
g.barrier()

self._accumulate(frc, M12.M_projected_gradient(w_plus, w2_plus), m1 - m2)
self._accumulate(
frc,
Expand Down

0 comments on commit 4a473d5

Please sign in to comment.