Skip to content

Commit

Permalink
[FIX] Critical fixes for correct evaluation of stagger derivatives //…
Browse files Browse the repository at this point in the history
… Tested against Bifrost.
  • Loading branch information
M1kol4j committed Jan 26, 2023
1 parent c6b46c6 commit c3dae9e
Showing 1 changed file with 36 additions and 42 deletions.
78 changes: 36 additions & 42 deletions helita/sim/stagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,85 +74,79 @@ def do(var, operation='xup', diff=None, pad_mode=None):
return func(out, out_diff, up=up, derivative=derivative)


@njit(parallel=False)
@jit(parallel=True,nopython=True)
def _xshift(var, diff, up=True, derivative=False):
if up:
sign = 1
else:
sign = -1
grdshf = 1 if up else 0
if derivative:
pm = -1
c = (-1 + (3**5 - 3) / (3**3 - 3)) / (5**5 - 5 - 5 * (3**5 - 3))
b = (-1 - 120*c) / 24
a = (1 - 3*b - 5*c)
c = (-1. + (3.**5 - 3.) / (3.**3 - 3.)) / (5.**5 - 5. - 5. * (3.**5 - 3))
b = (-1. - 120.*c) / 24.
a = (1. - 3.*b - 5.*c)
else:
pm = 1
c = 3.0 / 256.0
b = -25.0 / 256.0
a = 0.5 - b - c
start = int(2.5 - sign*0.5)
end = - int(2.5 + sign*0.5)
start = int(3. - grdshf)
end = - int(2. + grdshf)
nx, ny, nz = var.shape
out = np.zeros((nx, ny, nz))
for k in prange(nz):
for j in prange(ny):
for i in prange(start, nx + end):
var[i, j, k] = diff[i] * (a * (var[i, j, k] + pm * var[i + sign, j, k]) +
b * (var[i - sign*1, j, k] + pm * var[i + sign*2, j, k]) +
c * (var[i - sign*2, j, k] + pm * var[i + sign*3, j, k]))
return var[start:end]
out[i, j, k] = diff[i] * (a * (var[i + grdshf, j, k] + pm * var[i - 1 + grdshf, j, k]) +
b * (var[i + 1 + grdshf, j, k] + pm * var[i - 2 + grdshf, j, k]) +
c * (var[i + 2 + grdshf, j, k] + pm * var[i - 3 + grdshf, j, k]))
return out[start:end,:,:]


@njit(parallel=False)
@jit(parallel=True,nopython=True)
def _yshift(var, diff, up=True, derivative=False):
if up:
sign = 1
else:
sign = -1
grdshf = 1 if up else 0
if derivative:
pm = -1
c = (-1 + (3**5 - 3) / (3**3 - 3)) / (5**5 - 5 - 5 * (3**5 - 3))
b = (-1 - 120*c) / 24
a = (1 - 3*b - 5*c)
c = (-1. + (3.**5 - 3.) / (3.**3 - 3.)) / (5.**5 - 5. - 5. * (3.**5 - 3))
b = (-1. - 120.*c) / 24.
a = (1. - 3.*b - 5.*c)
else:
pm = 1
c = 3.0 / 256.0
b = -25.0 / 256.0
a = 0.5 - b - c
start = int(2.5 - sign*0.5)
end = - int(2.5 + sign*0.5)
start = int(3. - grdshf)
end = - int(2. + grdshf)
nx, ny, nz = var.shape
out = np.zeros((nx, ny, nz))
for k in prange(nz):
for j in prange(start, ny + end):
for i in prange(nx):
var[i, j, k] = diff[j] * (a * (var[i, j, k] + pm * var[i, j + sign, k]) +
b * (var[i, j - sign*1, k] + pm * var[i, j + sign*2, k]) +
c * (var[i, j - sign*2, k] + pm * var[i, j + sign*3, k]))
return var[:, start:end]
out[i, j, k] = diff[j] * (a * (var[i, j + grdshf, k] + pm * var[i, j - 1 + grdshf, k]) +
b * (var[i, j + 1 + grdshf, k] + pm * var[i, j - 2 + grdshf, k]) +
c * (var[i, j + 2 + grdshf, k] + pm * var[i, j - 3 + grdshf, k]))
return out[:,start:end,:]


@njit(parallel=False)
@jit(parallel=True,nopython=True)
def _zshift(var, diff, up=True, derivative=False):
if up:
sign = 1
else:
sign = -1
grdshf = 1 if up else 0
start = int(3. - grdshf)
end = - int(2. + grdshf)
if derivative:
pm = -1
c = (-1 + (3**5 - 3) / (3**3 - 3)) / (5**5 - 5 - 5 * (3**5 - 3))
b = (-1 - 120*c) / 24
a = (1 - 3*b - 5*c)
c = (-1. + (3.**5 - 3.) / (3.**3 - 3.)) / (5.**5 - 5. - 5. * (3.**5 - 3))
b = (-1. - 120.*c) / 24.
a = (1. - 3.*b - 5.*c)
else:
pm = 1
c = 3.0 / 256.0
b = -25.0 / 256.0
a = 0.5 - b - c
start = int(2.5 - sign*0.5)
end = - int(2.5 + sign*0.5)
nx, ny, nz = var.shape
out = np.zeros((nx, ny, nz))
for k in prange(start, nz + end):
for j in prange(ny):
for i in prange(nx):
var[i, j, k] = diff[k] * (a * (var[i, j, k] + pm * var[i, j, k + sign]) +
b * (var[i, j, k - sign*1] + pm * var[i, j, k + sign*2]) +
c * (var[i, j, k - sign*2] + pm * var[i, j, k + sign*3]))
return var[..., start:end]
out[i, j, k] = diff[k] * (a * (var[i, j, k + grdshf] + pm * var[i, j, k - 1 + grdshf]) +
b * (var[i, j, k + 1 + grdshf] + pm * var[i, j, k - 2 + grdshf]) +
c * (var[i, j, k + 2 + grdshf] + pm * var[i, j, k - 3 + grdshf]))
return out[:,:,start:end]

0 comments on commit c3dae9e

Please sign in to comment.