Skip to content

Commit

Permalink
Merge pull request #48 from boutproject/xarray
Browse files Browse the repository at this point in the history
Add basic xarray support
  • Loading branch information
ZedThree authored Apr 26, 2023
2 parents 618616b + 7ba5147 commit 0554ff6
Show file tree
Hide file tree
Showing 17 changed files with 10 additions and 31 deletions.
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,17 @@
# pre-commit autoupdate
repos:
- repo: https://github.com/psf/black
rev: 22.3.0
rev: 23.1.0
hooks:
- id: black

- repo: https://gitlab.com/pycqa/flake8
rev: 3.9.2
- repo: https://github.com/pycqa/flake8
rev: 6.0.0
hooks:
- id: flake8

- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.1.0
rev: v4.4.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
Expand All @@ -24,6 +24,6 @@ repos:
args: [--prose-wrap=always, --print-width=88]

- repo: https://github.com/PyCQA/isort
rev: 5.10.1
rev: 5.12.0
hooks:
- id: isort
3 changes: 0 additions & 3 deletions boututils/View3D.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,6 @@ def View3D(g, path=None, gb=None):
# q=i surfaces

for i in range(np.shape(x)[0]):

s = mlab.pipeline.streamline(field)
s.streamline_type = "line"
# s.seed.widget = s.seed.widget_list[0]
Expand Down Expand Up @@ -328,7 +327,6 @@ def View3D(g, path=None, gb=None):


def magnetic_field(g, X, Y, Z, rmin, rmax, zmin, zmax, Br, Bz, Btrz):

rho = np.sqrt(X**2 + Y**2)
phi = np.arctan2(Y, X)

Expand Down Expand Up @@ -371,7 +369,6 @@ def magnetic_field(g, X, Y, Z, rmin, rmax, zmin, zmax, Br, Bz, Btrz):


def psi_field(g, X, Y, Z, rmin, rmax, zmin, zmax):

rho = np.sqrt(X**2 + Y**2)

psi = np.zeros(np.shape(X))
Expand Down
2 changes: 0 additions & 2 deletions boututils/analyse_equil_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,6 @@ def analyse_equil(F, R, Z):
print("")

if n_xpoint > 0:

# Find the primary separatrix

# First remove non-monotonic separatrices
Expand Down Expand Up @@ -265,7 +264,6 @@ def analyse_equil(F, R, Z):
inner_sep = 0

else:

# No x-points. Pick mid-point in f

xpt_f = 0.5 * (numpy.max(F) + numpy.min(F))
Expand Down
1 change: 0 additions & 1 deletion boututils/anim.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ def anim(s, d, *args, **kwargs):


if __name__ == "__main__":

path = "../../../examples/elm-pb/data"

data = collect("P", path=path)
Expand Down
1 change: 0 additions & 1 deletion boututils/boutgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ def aligned_points(grid, nz=1, period=1.0, maxshift=0.4):

start = 0
for y in range(ny):

end = start + nx * nz

phi = zshift[:, y] + phi0[:, None]
Expand Down
1 change: 0 additions & 1 deletion boututils/closest_line.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

# Find the closest contour line to a given point
def closest_line(n, x, y, ri, zi, mind=None):

mind = numpy.min((x[0] - ri) ** 2 + (y[0] - zi) ** 2)
ind = 0

Expand Down
1 change: 0 additions & 1 deletion boututils/crosslines.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,6 @@ def meshgrid_as_strided(x, y, mask=None):
nans_[:, :] = np.isnan(ua)

if not np.ma.any(nans):

# remove duplicate cases where intersection happens on an endpoint
# ignore[np.ma.where((ua[:, :-1] == 1) & (ua[:, 1:] == 0))] = True
# ignore[np.ma.where((ub[:-1, :] == 1) & (ub[1:, :] == 0))] = True
Expand Down
8 changes: 5 additions & 3 deletions boututils/datafile.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,8 @@ def _bout_dimensions_from_var(self, data):
try:
bout_type = data.attributes["bout_type"]
except AttributeError:
if hasattr(data, "dims"):
return data.dims
defdims_list = [
(),
("t",),
Expand All @@ -533,7 +535,6 @@ def _bout_dimensions_from_var(self, data):
return BoutArray.dims_from_type(bout_type)

def write(self, name, data, info=False):

if not self.writeable:
raise Exception("File not writeable. Open with write=True keyword")

Expand All @@ -542,6 +543,9 @@ def write(self, name, data, info=False):
# Get the variable type
t = type(data).__name__

if t == "DataArray":
t = data.dtype.str

if t == "NoneType":
print("DataFile: None passed as data to write. Ignoring")
return
Expand Down Expand Up @@ -911,7 +915,6 @@ def size(self, varname):
return var.shape

def write(self, name, data, info=False):

if not self.writeable:
raise Exception("File not writeable. Open with write=True keyword")

Expand Down Expand Up @@ -973,7 +976,6 @@ def list_file_attributes(self):
return self.handle.attrs.keys()

def attributes(self, varname):

try:
return self._attributes_cache[varname]
except KeyError:
Expand Down
4 changes: 0 additions & 4 deletions boututils/efit_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@


def View2D(g, option=0):

# plot and check the field
fig = figure(num=2, figsize=(16, 6))
# fig.suptitle('Efit Analysis', fontsize=20)
Expand Down Expand Up @@ -350,7 +349,6 @@ def View2D(g, option=0):


def surface(cs, i, f, opt_ri, opt_zi, style, iplot=0):

# contour_lines( F, np.arange(nx).astype(float), np.arange(ny).astype(float),
# levels=[start_f])
# cs=contour( g.r, g.z, g.psi, levels=[f])
Expand Down Expand Up @@ -400,7 +398,6 @@ def surface(cs, i, f, opt_ri, opt_zi, style, iplot=0):
# y=yy
#
if iplot == 0:

# plot the start_f line
zc = cs.collections[i]
setp(zc, linewidth=4, linestyle=style[i])
Expand All @@ -417,7 +414,6 @@ def surface(cs, i, f, opt_ri, opt_zi, style, iplot=0):


if __name__ == "__main__":

path = "../../tokamak_grids/pyGridGen/"

g = read_geqdsk(path + "g118898.03400")
Expand Down
1 change: 0 additions & 1 deletion boututils/fft_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ def fft_integrate(y, loop=None):


def test_integrate():

n = 10
dx = 2.0 * np.pi / np.float(n)
x = dx * np.arange(n)
Expand Down
1 change: 0 additions & 1 deletion boututils/int_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ def int_func(xin, fin=None, simple=None):
g[i] = g[i - 1] + 0.5 * (x[i] - x[i - 1]) * (f[i] + f[i - 1])

else:

n2 = numpy.int(old_div(n, 2))

g[0] = 0.0
Expand Down
3 changes: 0 additions & 3 deletions boututils/mode_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

# interpolates a 1D periodic function
def zinterp(v, zind):

v = numpy.ravel(v)

nz = numpy.size(v)
Expand Down Expand Up @@ -70,7 +69,6 @@ def mode_structure(
pmodes=None,
_extra=None,
):

# ON_ERROR, 2
#
# period = 1 ; default = full torus
Expand Down Expand Up @@ -362,7 +360,6 @@ def mode_structure(
#
#
if subset is not None:

# get number of modes larger than 5% of the maximum
count = numpy.size(numpy.where(fmax > 0.10 * numpy.max(fmax)))

Expand Down
1 change: 0 additions & 1 deletion boututils/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ class BOUTOptions(object):
"""

def __init__(self, inp_path=None):

self._sections = ["root"]

for section in self._sections:
Expand Down
1 change: 0 additions & 1 deletion boututils/plotpolslice.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,6 @@ def plotpolslice(var3d, gridfile, period=1, zangle=0.0, rz=1, fig=0):
z[x, ypos] = zxy[x, ny - 1]

if fig == 1:

f = mlab.figure(size=(600, 600))
# Tell visual to use this as the viewer.
visual.set_viewer(f)
Expand Down
1 change: 0 additions & 1 deletion boututils/radial_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
def radial_grid(
n, pin, pout, include_in, include_out, seps, sep_factor, in_dp=None, out_dp=None
):

if n == 1:
return [0.5 * (pin + pout)]

Expand Down
1 change: 0 additions & 1 deletion boututils/read_geqdsk.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@


def read_geqdsk(file):

data = Geqdsk()

data.openFile(file)
Expand Down
1 change: 0 additions & 1 deletion boututils/showdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,6 @@ def showdata(
clevels = []

for i in range(0, Nvar):

dummymax.append([])
dummymin.append([])
for j in range(0, Nlines[i]):
Expand Down

0 comments on commit 0554ff6

Please sign in to comment.