From c3dae9ee0de9d9dd441728047f64ca5d757c32ce Mon Sep 17 00:00:00 2001 From: M1kol4j Date: Thu, 26 Jan 2023 12:03:14 +0100 Subject: [PATCH] [FIX] Critical fixes for correct evaluation of stagger derivatives // Tested against Bifrost. --- helita/sim/stagger.py | 78 ++++++++++++++++++++----------------------- 1 file changed, 36 insertions(+), 42 deletions(-) diff --git a/helita/sim/stagger.py b/helita/sim/stagger.py index 1dff6785..9718d8c8 100644 --- a/helita/sim/stagger.py +++ b/helita/sim/stagger.py @@ -74,85 +74,79 @@ def do(var, operation='xup', diff=None, pad_mode=None): return func(out, out_diff, up=up, derivative=derivative) -@njit(parallel=False) +@jit(parallel=True,nopython=True) def _xshift(var, diff, up=True, derivative=False): - if up: - sign = 1 - else: - sign = -1 + grdshf = 1 if up else 0 if derivative: pm = -1 - c = (-1 + (3**5 - 3) / (3**3 - 3)) / (5**5 - 5 - 5 * (3**5 - 3)) - b = (-1 - 120*c) / 24 - a = (1 - 3*b - 5*c) + c = (-1. + (3.**5 - 3.) / (3.**3 - 3.)) / (5.**5 - 5. - 5. * (3.**5 - 3)) + b = (-1. - 120.*c) / 24. + a = (1. - 3.*b - 5.*c) else: pm = 1 c = 3.0 / 256.0 b = -25.0 / 256.0 a = 0.5 - b - c - start = int(2.5 - sign*0.5) - end = - int(2.5 + sign*0.5) + start = int(3. - grdshf) + end = - int(2. + grdshf) nx, ny, nz = var.shape + out = np.zeros((nx, ny, nz)) for k in prange(nz): for j in prange(ny): for i in prange(start, nx + end): - var[i, j, k] = diff[i] * (a * (var[i, j, k] + pm * var[i + sign, j, k]) + - b * (var[i - sign*1, j, k] + pm * var[i + sign*2, j, k]) + - c * (var[i - sign*2, j, k] + pm * var[i + sign*3, j, k])) - return var[start:end] + out[i, j, k] = diff[i] * (a * (var[i + grdshf, j, k] + pm * var[i - 1 + grdshf, j, k]) + + b * (var[i + 1 + grdshf, j, k] + pm * var[i - 2 + grdshf, j, k]) + + c * (var[i + 2 + grdshf, j, k] + pm * var[i - 3 + grdshf, j, k])) + return out[start:end,:,:] -@njit(parallel=False) +@jit(parallel=True,nopython=True) def _yshift(var, diff, up=True, derivative=False): - if up: - sign = 1 - else: - sign = -1 + grdshf = 1 if up else 0 if derivative: pm = -1 - c = (-1 + (3**5 - 3) / (3**3 - 3)) / (5**5 - 5 - 5 * (3**5 - 3)) - b = (-1 - 120*c) / 24 - a = (1 - 3*b - 5*c) + c = (-1. + (3.**5 - 3.) / (3.**3 - 3.)) / (5.**5 - 5. - 5. * (3.**5 - 3)) + b = (-1. - 120.*c) / 24. + a = (1. - 3.*b - 5.*c) else: pm = 1 c = 3.0 / 256.0 b = -25.0 / 256.0 a = 0.5 - b - c - start = int(2.5 - sign*0.5) - end = - int(2.5 + sign*0.5) + start = int(3. - grdshf) + end = - int(2. + grdshf) nx, ny, nz = var.shape + out = np.zeros((nx, ny, nz)) for k in prange(nz): for j in prange(start, ny + end): for i in prange(nx): - var[i, j, k] = diff[j] * (a * (var[i, j, k] + pm * var[i, j + sign, k]) + - b * (var[i, j - sign*1, k] + pm * var[i, j + sign*2, k]) + - c * (var[i, j - sign*2, k] + pm * var[i, j + sign*3, k])) - return var[:, start:end] + out[i, j, k] = diff[j] * (a * (var[i, j + grdshf, k] + pm * var[i, j - 1 + grdshf, k]) + + b * (var[i, j + 1 + grdshf, k] + pm * var[i, j - 2 + grdshf, k]) + + c * (var[i, j + 2 + grdshf, k] + pm * var[i, j - 3 + grdshf, k])) + return out[:,start:end,:] -@njit(parallel=False) +@jit(parallel=True,nopython=True) def _zshift(var, diff, up=True, derivative=False): - if up: - sign = 1 - else: - sign = -1 + grdshf = 1 if up else 0 + start = int(3. - grdshf) + end = - int(2. + grdshf) if derivative: pm = -1 - c = (-1 + (3**5 - 3) / (3**3 - 3)) / (5**5 - 5 - 5 * (3**5 - 3)) - b = (-1 - 120*c) / 24 - a = (1 - 3*b - 5*c) + c = (-1. + (3.**5 - 3.) / (3.**3 - 3.)) / (5.**5 - 5. - 5. * (3.**5 - 3)) + b = (-1. - 120.*c) / 24. + a = (1. - 3.*b - 5.*c) else: pm = 1 c = 3.0 / 256.0 b = -25.0 / 256.0 a = 0.5 - b - c - start = int(2.5 - sign*0.5) - end = - int(2.5 + sign*0.5) nx, ny, nz = var.shape + out = np.zeros((nx, ny, nz)) for k in prange(start, nz + end): for j in prange(ny): for i in prange(nx): - var[i, j, k] = diff[k] * (a * (var[i, j, k] + pm * var[i, j, k + sign]) + - b * (var[i, j, k - sign*1] + pm * var[i, j, k + sign*2]) + - c * (var[i, j, k - sign*2] + pm * var[i, j, k + sign*3])) - return var[..., start:end] \ No newline at end of file + out[i, j, k] = diff[k] * (a * (var[i, j, k + grdshf] + pm * var[i, j, k - 1 + grdshf]) + + b * (var[i, j, k + 1 + grdshf] + pm * var[i, j, k - 2 + grdshf]) + + c * (var[i, j, k + 2 + grdshf] + pm * var[i, j, k - 3 + grdshf])) + return out[:,:,start:end] \ No newline at end of file