diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..ae6eb1d2 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,34 @@ +name: CI + +on: + push: + branches: + - 'master' + tags: + - 'v*' + - '!*dev*' + - '!*pre*' + - '!*post*' + pull_request: + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + core: + uses: OpenAstronomy/github-actions-workflows/.github/workflows/tox.yml@main + with: + submodules: false + envs: | + - linux: py310 + + test: + needs: [core] + uses: OpenAstronomy/github-actions-workflows/.github/workflows/tox.yml@main + with: + submodules: false + envs: | + - windows: py39 + - macos: py38 diff --git a/.gitignore b/.gitignore index 55e81011..e6f88f97 100644 --- a/.gitignore +++ b/.gitignore @@ -88,3 +88,15 @@ ENV/ # Rope project settings .ropeproject + +# VSCode +.vscode/ +.history/ + +# Cython Creations +helita/io/anapyio.c +helita/sim/cstagger.c +helita/utils/radtrans.c +helita/utils/utilsfast.c + +*~ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..f53686c6 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,58 @@ +repos: +# The warnings/errors we check for here are: + # E901 - SyntaxError or IndentationError + # E902 - IOError + # F822 - undefined name in __all__ + # F823 - local variable name referenced before assignment + # Others are taken care of by autopep8 + - repo: https://github.com/PyCQA/flake8 + rev: 5.0.4 + hooks: + - id: flake8 + args: + [ + "--count", + "--select", + "E901,E902,F822,F823", + ] + exclude: ".*(.fits|.fts|.fit|.header|.txt|tca.*|extern.*|.rst|.md)$" + - repo: https://github.com/PyCQA/autoflake + rev: v1.7.7 + hooks: + - id: autoflake + args: + [ + "--in-place", + "--remove-all-unused-imports", + "--remove-unused-variable", + ] + exclude: ".*(.fits|.fts|.fit|.header|.txt|tca.*|extern.*|.rst|.md)$" + - repo: https://github.com/PyCQA/isort + rev: 5.10.1 + hooks: + - id: isort + args: ["--sp", "setup.cfg"] + exclude: ".*(.fits|.fts|.fit|.header|.txt|tca.*|extern.*|.rst|.md)$" + - repo: https://github.com/pre-commit/mirrors-autopep8 + rev: v2.0.0 + hooks: + - id: autopep8 + args: ["--in-place","--max-line-length", "200"] + exclude: ".*(.fits|.fts|.fit|.header|.txt|tca.*|extern.*|.rst|.md)$" + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.3.0 + hooks: + - id: check-ast + - id: check-case-conflict + - id: trailing-whitespace + exclude: ".*(.fits|.fts|.fit|.header|.txt)$" + - id: check-yaml + - id: debug-statements + - id: check-added-large-files + args: ['--enforce-all','--maxkb=1054'] + - id: end-of-file-fixer + exclude: ".*(.fits|.fts|.fit|.header|.txt|tca.*|.json)$|^CITATION.rst$" + - id: mixed-line-ending + exclude: ".*(.fits|.fts|.fit|.header|.txt|tca.*)$" +ci: + autofix_prs: false diff --git a/LICENSE b/LICENSE index 41de70be..65cabf56 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ BSD 3-Clause License -Copyright (c) 2017, The Helita developers and Institute of Theoretical Astrophysics. +Copyright (c) 2017-2022, The Helita developers and Institute of Theoretical Astrophysics. All rights reserved. Redistribution and use in source and binary forms, with or without diff --git a/README.md b/README.md index a700e3be..fa935314 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,10 @@ # Helita -Helita is a Python library for solar physics focused on interfacing with code and projects from the [Institute of Theoretical Astrophysics](http://astro.uio.no) (ITA) and the [Rosseland Centre for Solar Physics](https://www.mn.uio.no/rocs/) (RoCS) at the [University of Oslo](https://www.uio.no). The name comes from Helios + ITA. +Helita is a Python library for solar physics focused on interfacing with code and projects from the [Institute of Theoretical Astrophysics](http://astro.uio.no) (ITA) and the [Rosseland Centre for Solar Physics](https://www.mn.uio.no/rocs/) (RoCS) at the [University of Oslo](https://www.uio.no). +The name comes from Helios + ITA. The library is a loose collection of different scripts and classes with varying degrees of portability and usefulness. ## Documentation -For more details including installation instructions, please see the documentation at http://ita-solar.github.io/helita. +For more details including installation instructions, [please see the documentation.](http://ita-solar.github.io/helita) diff --git a/helita/__init__.py b/helita/__init__.py index ded8e713..bac3098f 100644 --- a/helita/__init__.py +++ b/helita/__init__.py @@ -4,7 +4,7 @@ __all__ = ["io", "obs", "sim", "utils"] -from . import io -from . import obs -from . import sim -from . import utils +#from . import io +#from . import obs +#from . import sim +#from . import utils diff --git a/helita/io/anapyio.pyx b/helita/io/anapyio.pyx index d53d5bf8..c164caa6 100644 --- a/helita/io/anapyio.pyx +++ b/helita/io/anapyio.pyx @@ -1,7 +1,11 @@ import os + import numpy as np + cimport numpy as np from stdlib cimport free, malloc + + cdef extern from "stdlib.h": void *memcpy(void *dst, void *src, long n) diff --git a/helita/io/crispex.py b/helita/io/crispex.py index d29a2455..5f17b26d 100644 --- a/helita/io/crispex.py +++ b/helita/io/crispex.py @@ -1,9 +1,9 @@ """ set of tools to deal with crispex data """ -import xarray import numpy as np import scipy.interpolate as interp +import xarray def write_buf(intensity, outfile, wave=None, stokes=False): diff --git a/helita/io/lp.py b/helita/io/lp.py index 2873cabb..395c4627 100644 --- a/helita/io/lp.py +++ b/helita/io/lp.py @@ -1,9 +1,10 @@ """ Set of tools to read and write 'La Palma' cubes """ -import numpy as np import os +import numpy as np + def make_header(image): ''' Creates header for La Palma images. ''' diff --git a/helita/io/sdf.py b/helita/io/sdf.py index 4c3064ad..54938829 100644 --- a/helita/io/sdf.py +++ b/helita/io/sdf.py @@ -11,7 +11,6 @@ def __init__(self, filename, verbose=False): self.verbose = verbose self.query(filename) - def query(self, filename, verbose=False): ''' Queries the file, returning datasets and shapes.''' f = open(filename, 'r') @@ -32,7 +31,6 @@ def query(self, filename, verbose=False): self.header_data(header) return - def header_data(self, header): ''' Breaks header string into variable informationp. ''' self.variables = {} diff --git a/helita/io/src/anacompress.c b/helita/io/src/anacompress.c index d4dfa5b6..7174b6b8 100644 --- a/helita/io/src/anacompress.c +++ b/helita/io/src/anacompress.c @@ -41,7 +41,7 @@ int anacrunchrun8(uint8_t *x,uint8_t *array,int slice,int nx,int ny,int limit,in for (iy=0;iy1) { x[i+1]=y.b[1]; if (nb>2) x[i+2]=y.b[2]; } } - + r1=r1+slice; /* bump r1 pass the fixed part */ i=r1>>3; j=r1 & 7; /* note that r3 is the # of bits required minus 1 */ @@ -174,7 +174,7 @@ int anacrunchrun8(uint8_t *x,uint8_t *array,int slice,int nx,int ny,int limit,in /* we have to put these in a form readable by the Vax (these may be used by fcwrite) */ if(t_endian){ // big endian - bswapi32(&(ch->tsize),1); bswapi32(&(ch->bsize),1); bswapi32(&(ch->nblocks),1); + bswapi32(&(ch->tsize),1); bswapi32(&(ch->bsize),1); bswapi32(&(ch->nblocks),1); } free(dif); return i; /*return # of bytes used */ @@ -454,7 +454,7 @@ int anacrunchrun(uint8_t *x,int16_t *array,int slice,int nx,int ny,int limit,int /* we have to put these in a form readable by the Vax (these may be used by fcwrite) */ if(t_endian){ // big endian - bswapi32(&(ch->tsize),1); bswapi32(&(ch->bsize),1); bswapi32(&(ch->nblocks),1); + bswapi32(&(ch->tsize),1); bswapi32(&(ch->bsize),1); bswapi32(&(ch->nblocks),1); } free(dif); return i; /*return # of bytes used */ @@ -464,7 +464,7 @@ int anacrunchrun(uint8_t *x,int16_t *array,int slice,int nx,int ny,int limit,int int anacrunch(uint8_t *x,int16_t *array,int slice,int nx,int ny,int limit,int t_endian) // compress 16 bit array into x (a byte array) using ny blocks each of size -// nx, bit slice size slice, returns # of bytes in x +// nx, bit slice size slice, returns # of bytes in x { uint8_t bits[8]={1,2,4,8,16,32,64,128}; unsigned register i,j,r1,in; @@ -487,9 +487,9 @@ int anacrunch(uint8_t *x,int16_t *array,int slice,int nx,int ny,int limit,int t_ mask-=1; // no inline expon. in C unsigned nb; // determine the # of bytes to transfer to 32 bit int for fixed portion if(slice==0){ - nb=0; + nb=0; }else{ - if(slice<2){ + if(slice<2){ nb=1; }else{ if(slice<10) nb=2; else nb=3; @@ -602,7 +602,7 @@ int anacrunch(uint8_t *x,int16_t *array,int slice,int nx,int ny,int limit,int t_ if(t_endian){ // we have to put these in a form readable by the Vax (these may be used by fcwrite) bswapi32(&(ch->tsize),1); bswapi32(&(ch->bsize),1); - bswapi32(&(ch->nblocks),1); + bswapi32(&(ch->nblocks),1); } return i; // return # of bytes used } @@ -725,9 +725,8 @@ int anacrunch32(uint8_t *x,int32_t *array,int slice,int nx,int ny,int limit,int /* we have to put these in a form readable by the Vax (these may be used by fcwrite) */ if(t_endian){ // big endian - bswapi32(&(ch->tsize),1); bswapi32(&(ch->bsize),1); bswapi32(&(ch->nblocks),1); + bswapi32(&(ch->tsize),1); bswapi32(&(ch->bsize),1); bswapi32(&(ch->nblocks),1); } /* printf("number of big ones for this I*4 = %d\n", big); */ return i; /*return # of bytes used */ } /* end of routine */ - diff --git a/helita/io/src/anadecompress.c b/helita/io/src/anadecompress.c index 603f26dd..a347613f 100644 --- a/helita/io/src/anadecompress.c +++ b/helita/io/src/anadecompress.c @@ -84,7 +84,7 @@ int anadecrunch32(unsigned char *x,int32_t *array,int r9,int nx,int ny,int littl if ((xq&16) != 0) r0+=5; else { if ((xq&32) != 0) r0+=6; else { if ((xq&64) != 0) r0+=7; else { - if ((xq&128) != 0) r0+=8; }}}}}}} break; } else { r0=r0+8; + if ((xq&128) != 0) r0+=8; }}}}}}} break; } else { r0=r0+8; /* add 8 bits for each all zero byte */ if (r0 > 32) { fprintf(stderr,"DECRUNCH -- bad bit sequence, cannot continue\n"); fprintf(stderr,"i = %d, r1 = %d, ix= %d, iy = %d\n",i,r1,ix,iy); @@ -92,13 +92,13 @@ int anadecrunch32(unsigned char *x,int32_t *array,int r9,int nx,int ny,int littl r1=r1+r0; /* update pointer */ /* r0 even or odd determines sign of difference */ /*printf("r0 = %d\n", r0);*/ - if ((r0&1) != 0) { + if ((r0&1) != 0) { /* positive case */ /*printf("plus case, r0, r2, iq = %d %d %d\n", r0, r2, iq);*/ r0=(r0/2)< 32) { fprintf(stderr,"DECRUNCH -- bad bit sequence, cannot continue\n"); fprintf(stderr,"i = %d, r1 = %d, ix= %d, iy = %d\n",i,r1,ix,iy); return -1; } } } } r1=r1+r0; /* update pointer */ /* r0 even or odd determines sign of difference */ - if ((r0&1) != 0) { + if ((r0&1) != 0) { /* positive case */ r0=(r0/2)< 32) { fprintf(stderr,"DECRUNCH -- bad bit sequence, cannot continue"); return -1; } } } } r1=r1+r0; /* update pointer */ /* r0 even or odd determines sign of difference */ - if ((r0&1) != 0) { + if ((r0&1) != 0) { /* positive case */ r0=(r0/2)< 32) { fprintf(stderr,"DECRUNCH -- bad bit sequence, cannot continue\n"); fprintf(stderr,"i = %d, r1 = %d, iy = %d\n",i,r1,iy); return -1; } } } } r1=r1+r0; /* update pointer */ /* r0 even or odd determines sign of difference */ - if ((r0&1) != 0) { + if ((r0&1) != 0) { /* positive case */ r0=(r0/2)< 32) { fprintf(stderr,"DECRUNCH -- bad bit sequence, cannot continue\n"); fprintf(stderr,"i = %d, r1 = %d, iy = %d\n",i,r1,iy); return -1; } } } } r1=r1+r0; /* update pointer */ /* r0 even or odd determines sign of difference */ - if ((r0&1) != 0) { + if ((r0&1) != 0) { /* positive case */ r0=(r0/2)< -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include +#include #include -#include "types.h" +#include "types.h" #define ANA_VAR_SZ {1,2,4,4,8,8} @@ -19,11 +19,11 @@ #define FLOAT64 4 #define INT64 5 -#define M_TM_INPRO 0 +#define M_TM_INPRO 0 #define M_TM_INFUN -1 -#include "anadecompress.h" -#include "anacompress.h" +#include "anadecompress.h" +#include "anacompress.h" static __inline int min(int a,int b) { @@ -86,7 +86,7 @@ int ck_synch_hd(FILE *fin,struct fzhead *fh,int t_endian) } if(syncpat==t_endian){ fprintf(stderr,"ck_synch_hd: warning: reversed F0 synch pattern\n"); - wwflag=1; + wwflag=1; } if(fh->nhb>1){ // if the header is long, read in the rest now if(fh->nhb>15){ @@ -100,19 +100,19 @@ int ck_synch_hd(FILE *fin,struct fzhead *fh,int t_endian) free(buf); // not very useful? } if(t_endian) bswapi32(fh->dim,fh->ndim); // for big endian machines - return wwflag; + return wwflag; } -char *ana_fzhead(char *file_name) // fzhead subroutine +char *ana_fzhead(char *file_name) // fzhead subroutine { struct stat stat_buf; if(stat(file_name,&stat_buf)<0){ - fprintf(stderr,"ana_fzhead: error: file \"%s\" not found.\n",file_name); + fprintf(stderr,"ana_fzhead: error: file \"%s\" not found.\n",file_name); return 0; } fprintf(stdout,"reading ANA file \"%s\" header\n",file_name); int one=1; - int t_endian=(*(char*)&one==0); // an endian detector, taken from SL's tiff library + int t_endian=(*(char*)&one==0); // an endian detector, taken from SL's tiff library // FILE *fin=fopen(file_name,"r"); if(!fin){ @@ -125,20 +125,20 @@ char *ana_fzhead(char *file_name) // fzhead subroutine char *header=strcpy(malloc(strlen(fh.txt)+1),fh.txt); fclose(fin); - return header; + return header; } -uint8_t *ana_fzread(char *file_name,int **ds,int *nd,char **header,int *type,int *osz) // fzread subroutine +uint8_t *ana_fzread(char *file_name,int **ds,int *nd,char **header,int *type,int *osz) // fzread subroutine { struct stat stat_buf; if(stat(file_name,&stat_buf)<0){ - fprintf(stderr,"ana_fzread: error: file \"%s\" not found.\n",file_name); + fprintf(stderr,"ana_fzread: error: file \"%s\" not found.\n",file_name); return 0; } fprintf(stdout,"reading ANA file \"%s\"\n",file_name); int type_sizes[]=ANA_VAR_SZ; int one=1; - int t_endian=(*(char*)&one==0); // an endian detector, taken from SL's tiff library + int t_endian=(*(char*)&one==0); // an endian detector, taken from SL's tiff library // FILE *fin=fopen(file_name,"r"); if(!fin){ @@ -172,7 +172,7 @@ uint8_t *ana_fzread(char *file_name,int **ds,int *nd,char **header,int *type,int } // read data int size=ch.tsize-14; - // Allocate 4 bytes extra to solve the possible illegal read beyond the + // Allocate 4 bytes extra to solve the possible illegal read beyond the // malloc'ed memory in anadecrunch() functions. uint8_t *buf=malloc(size+4); if(fread(buf,1,size,fin) uint8_t // int08-> int8 // others: stdtypes, include in configure diff --git a/helita/obs/__init__.py b/helita/obs/__init__.py index 77eed036..bff4307b 100644 --- a/helita/obs/__init__.py +++ b/helita/obs/__init__.py @@ -3,4 +3,3 @@ """ __all__ = ["hinode", "iris"] - diff --git a/helita/obs/hinode.py b/helita/obs/hinode.py index c01494ce..00a7d1bf 100644 --- a/helita/obs/hinode.py +++ b/helita/obs/hinode.py @@ -1,7 +1,6 @@ """ set of tools to deal with Hinode observations """ -import os import numpy as np from pkg_resources import resource_filename @@ -32,7 +31,7 @@ def bfi_filter(wave, band='CAH', norm=True): 'BLUE': '4504', 'GREEN': '5550', 'RED': '6684'} if band not in list(filt_names.keys()): msg = "Band name must be one of %s" % ', '.join(filt_names.keys()) - raise(ValueError, "Invalid band. " + msg + ".") + raise (ValueError, "Invalid band. " + msg + ".") cfile = resource_filename('helita', 'data/BFI_filter_%s.txt' % filt_names[band]) wave_filt, filt = np.loadtxt(cfile, unpack=True) diff --git a/helita/obs/iris.py b/helita/obs/iris.py index 119dba0a..984f2a23 100644 --- a/helita/obs/iris.py +++ b/helita/obs/iris.py @@ -102,6 +102,7 @@ def sj_filter(wave, band='IRIS_MGII_CORE', norm=True): """ from scipy import interpolate as interp from scipy.io.idl import readsav + # File with IRIS effective area CFILE = resource_filename('helita', 'data/iris_sra_20130211.geny') ea = readsav(CFILE).p0 @@ -164,8 +165,9 @@ def make_fits_level3_skel(filename, dtype, naxis, times, waves, wsizes, Extra header information. This should be used to write important information such as CDELTx, CRVALx, CPIXx, XCEN, YCEN, DATE_OBS. """ - from astropy.io import fits as pyfits from datetime import datetime + + from astropy.io import fits as pyfits VERSION = '001' FITSBLOCK = 2880 # FITS blocksize in bytes # Consistency checks @@ -344,10 +346,11 @@ def rh_to_fits_level3(filelist, outfile, windows, window_desc, times=None, array. Must be exact match. Useful to combine output files that have common wavelengths. """ - from ..sim import rh15d - from specutils.utils.wcs_utils import air_to_vac - from astropy.io import fits as pyfits from astropy import units as u + from astropy.io import fits as pyfits + from specutils.utils.wcs_utils import air_to_vac + + from ..sim import rh15d nt = len(filelist) robj = rh15d.Rh15dout() robj.read_ray(filelist[0]) @@ -397,7 +400,7 @@ def rh_to_fits_level3(filelist, outfile, windows, window_desc, times=None, "CRVAL3": waves[0], "CRVAL4": times[0], "CDELT1": xres, "CDELT2": xres, "CDELT3": np.median(np.diff(waves)), "CDELT4": tres} - desc = "Calculated from %s" % (robj.ray.params['atmosID']) + "Calculated from %s" % (robj.ray.params['atmosID']) make_fits_level3_skel(outfile, robj.ray.intensity.dtype, (ny, nx), times, waves, nwaves, descw=window_desc, cwaves=cwaves, header_extra=header_extra) diff --git a/helita/obs/iris_util.py b/helita/obs/iris_util.py index 1c6c1737..e7043b4a 100644 --- a/helita/obs/iris_util.py +++ b/helita/obs/iris_util.py @@ -1,18 +1,18 @@ """ Set of utility programs for IRIS. """ +import io import os import re -import io -import numpy as np -import pandas as pd -from datetime import datetime, timedelta from glob import glob - +from datetime import datetime, timedelta +from urllib.error import URLError, HTTPError +from urllib.parse import urljoin, urlparse # pylint: disable=F0401,E0611,E1103 from urllib.request import urlopen -from urllib.parse import urljoin, urlparse -from urllib.error import HTTPError, URLError + +import numpy as np +import pandas as pd def iris_timeline_parse(timeline_file): diff --git a/helita/sim/__init__.py b/helita/sim/__init__.py index a8e9f22e..c4167b99 100644 --- a/helita/sim/__init__.py +++ b/helita/sim/__init__.py @@ -4,10 +4,28 @@ with synthetic spectra. """ -__all__ = ["bifrost", "multi", "multi3d", "muram", "rh", "rh15d", "simtools", - "synobs"] +try: + found = True +except ImportError: + found = False -from . import bifrost -from . import multi -from . import muram -from . import rh + +try: + PYCUDA_INSTALLED = True +except ImportError: + PYCUDA_INSTALLED = False + + +if found: + __all__ = ["bifrost", "multi", "multi3d", "muram", "rh", "rh15d", + "simtools", "synobs", "ebysus", "cipmocct", "laresav", + "pypluto", "matsumotosav"] +else: + __all__ = ["bifrost", "multi", "multi3d", "muram", "rh", "rh15d", + "simtools", "synobs"] + + +from . import bifrost, multi, rh + +if found: + from . import muram diff --git a/helita/sim/aux_compare.py b/helita/sim/aux_compare.py new file mode 100644 index 00000000..9d549ac9 --- /dev/null +++ b/helita/sim/aux_compare.py @@ -0,0 +1,633 @@ +""" +Created by Sam Evans on Apr 24 2021 + +purpose: easily compare values between helita and aux vars from a simulation. + +Highest-level use-case: compare all the aux vars with their helita counterparts! + #<< input: + from helita.sim import aux_compare as axc + from helita.sim import ebysus as eb + dd = eb.EbysusData(...) # you must fill in the ... as appropriate. + c = axc.compare_all(dd) + + #>> output: + >->->->->->->-> initiate comparison for auxvar = etg <-<-<-<-<-<-<-< + + auxvar etg min= 4.000e+03, mean= 4.000e+03, max= 4.000e+03 + helvar tg -1 min= 4.000e+03, mean= 4.000e+03, max= 4.000e+03; mean ratio (aux / helita): 1.000e+00 + ---------------------------------------------------------------------------------------------------------------------- + + comparison_result(N_differ=0, N_total=1, runtime=0.0020618438720703125) + + + >->->->->->->-> initiate comparison for auxvar = mm_cnu <-<-<-<-<-<-<-< + + auxvar mm_cnu ( 1, 1) ( 1, 2) min= 8.280e+05, mean= 8.280e+05, max= 8.280e+05 + helvar nu_ij ( 1, 1) ( 1, 2) min= 8.280e+05, mean= 8.280e+05, max= 8.280e+05; mean ratio (aux / helita): 1.000e+00 + --------------------------------------------------------------------------------------------------------------------------------- + + ... << (more lines of output, which we are not showing you in this file, to save space.) + + #<< more input: + print(c) + + #>> more output: + {'N_compare': 30, 'N_var': 8, 'N_differ': 4, 'N_diffvar': 1, 'N_error': 1, + 'errors': [FileNotFoundError(2, 'No such file or directory')], 'runtime': 1.581925868988037} + +High-level use-case: compare a single aux var with its helita counterpart! + #<< input: + from helita.sim import aux_compare as axc + from helita.sim import ebysus as eb + dd = eb.EbysusData(...) # you must fill in the ... as appropriate. + axc.compare(dd, 'mfr_nu_es') + + #>> output: + auxvar mfr_nu_es ( 1, 1) min= 3.393e+04, mean= 3.393e+04, max= 3.393e+04 + helvar nu_ij -1 ( 1, 1) min= 1.715e+04, mean= 1.715e+04, max= 1.715e+04; mean ratio (aux / helita): 1.978e+00 + WARNING: RATIO DIFFERS FROM 1.000 + ------------------------------------------------------------------------------------------------------------------------------------ + auxvar mfr_nu_es ( 1, 2) min= 1.621e+05, mean= 1.621e+05, max= 1.621e+05 + helvar nu_ij -1 ( 1, 2) min= 1.622e+05, mean= 1.622e+05, max= 1.622e+05; mean ratio (aux / helita): 9.993e-01 + ------------------------------------------------------------------------------------------------------------------------------------ + + #<< more input: + axc.compare(dd, 'mm_cnu') + + #>> more output: + auxvar mm_cnu ( 1, 1) ( 1, 2) min= 8.280e+05, mean= 8.280e+05, max= 8.280e+05 + helvar nu_ij ( 1, 1) ( 1, 2) min= 8.280e+05, mean= 8.280e+05, max= 8.280e+05; mean ratio (aux / helita): 1.000e+00 + --------------------------------------------------------------------------------------------------------------------------------- + auxvar mm_cnu ( 1, 2) ( 1, 1) min= 8.280e+06, mean= 8.280e+06, max= 8.280e+06 + helvar nu_ij ( 1, 2) ( 1, 1) min= 8.280e+06, mean= 8.280e+06, max= 8.280e+06; mean ratio (aux / helita): 1.000e+00 + --------------------------------------------------------------------------------------------------------------------------------- + +# output format notes: +# vartype varname (ispecie, ilevel) (jspecie, jlevel) min mean max +# when ispecies < 0 or jspecie < 0 (i.e. for electrons), they may be shown as "specie" instead of "(ispecie, ilevel)". + + +TODO (maybe): + - allow to put kwargs in auxvar lookup. + - for example, ebysus defines mm_cross = 0 when ispecies is ion, to save space. + meanwhile get_var('cross') in helita will tell same values even if fluids are swapped. + e.g. get_var('mm_cross', ifluid=(1,2), jfluid=(1,1)) == 0 + get_var('cross', ifluid=(1,2), jfluid=(1,1)) == get_var('cross', ifluid=(1,1), jfluid=(1,2)) +""" + +import time +# import built-in +from collections import namedtuple + +# import external public modules +import numpy as np + +# import internal modules +from . import fluid_tools, tools + +# import external private modules +try: + from atom_py.at_tools import fluids as fl +except ImportError: + fl = tools.ImportFailed('at_tools.fluids') + +# set defaults +DEFAULT_TOLERANCE = 0.05 # the max for (1-abs(X/Y)) before we think X != Y + + +''' ----------------------------- lookup helita counterpart to aux var ----------------------------- ''' + +# dict of defaults for converting from auxvar to helita var (aka "helvar"). +AUXVARS = { + # aux var : helita var. if tuple, v[1] tells required fluid. + # v[1] tells jfluid for 2-fluid vars (such as 'nu_ij'); + # ifluid for 1-fluid vars (such as 'tg'). + 'etg': ('tg', -1), # electron temperature + 'mfe_tg': 'tg', # fluid temperature + 'mfr_nu_es': ('nu_ij', -1), # electron-fluid collision frequency + 'mm_cnu': 'nu_ij', # fluid - fluid collision frequency + 'mm_cross': 'cross', # cross section + 'mfr_cross': ('cross', -1), # cross section + 'mfr_tgei': ('tgij', -1), # tg+etg weighted. + 'mfr_p': 'p', # pressure + 'mfe_qcolue': ('qcol_uj', -1), # energy component of the ohmic term from velocity drift + 'mfe_qcolte': ('qcol_tgj', -1), # energy component of the ohmic term from temperature diff + 'mm_qcolt': 'qcol_tgj', # energy component of the coll. term from temperature diff + 'mm_qcolu': 'qcol_uj', # energy component of the coll. term from velocity drift +} +# add each of these, formatted by x=axis, to AUXVARS. +# e.g. {'e{x}': 'ef{x}'} --> {'ex': 'efx', 'ey': 'efy', 'ez': 'efz'}. +AUX_AXIAL_VARS = { + 'e{x}': 'ef{x}', # electric field + 'eu{x}': 'ue{x}', # electron velocity + 'i{x}': 'j{x}', # current density (charge per time per area) + 'bb_bat{x}': 'bat{x}', # "battery" term (contribution to electric field: grad(P_e)/(n_e q_e)) + 'mfp_bb_ddp{x}': 'mombat{x}', # momentum component of the battery term ni*qi*grad(P_e)/(n_e q_e) + 'mfp_ddp{x}': 'gradp{x}', # momentum component of the gradient of pressure + 'mm_cdp{x}dt': 'rij{x}', # momentum transfer rate to ifluid due to collisions with jfluid + 'mfp_cdp{x}dt': 'rijsum{x}', # momentum transfer rate to ifluid due to collisions with all other fluids + 'mfp_ecdp{x}dt': ('rij{x}', -1), # momentum transfer rate to electrons due to collisions with ifluid + 'mfp_ecdp{x}dt_ef': 'momohme{x}', # momentum component of the ohmic term + 'mm_driftu{x}': 'uid{x}', # velocity drifts +} +# add the axial vars to auxvars. +AXES = ['x', 'y', 'z'] + + +def _format(val, *args, **kw): + if isinstance(val, str): + return val.format(*args, **kw) + else: # handle tuples + return (_format(val[0], *args, **kw), *val[1:]) + + +for (aux, hel) in AUX_AXIAL_VARS.items(): + AUXVARS.update({_format(aux, x=x): _format(hel, x=x) for x in AXES}) + + +def get_helita_var(auxvar): + return AUXVARS[auxvar] + + +''' ----------------------------- get_var for helita & aux ----------------------------- ''' + + +def _callsig(helvar): + '''returns dict with keys for getvar for helvar''' + if isinstance(helvar, str): + return dict(var=helvar) + # else: helvar has len 2 or longer + result = dict(var=helvar[0]) + try: + next(iter(helvar[1])) + except TypeError: # helvar[1] is not a list + result.update(dict(mf_ispecies=helvar[1])) + else: # helvar[1] is a list + result.update(dict(ifluid=helvar[1])) + if len(helvar) > 2: # we have info for jfluid as well. + try: + next(iter(helvar[2])) + except TypeError: + result.update(dict(mf_jspecies=helvar[2])) + else: + result.update(dict(jfluid=helvar[2])) + return result + + +def _loop_fluids(obj, callsig): + '''return the fluid kws which need to be looped through. + obj should be EbysusData object. + callsig should be _callsig(helvar). + returns a tuple telling whether to loop through (ifluid, jfluid) for helvar. + ''' + var = callsig['var'] + search = obj.search_vardict(var) + nfluid = search.result['nfluid'] + if nfluid is None: # we do not need to loop through any fluids. + return (False, False) + elif nfluid == 0: # we do not need to loop through any fluids. + assert list(callsig.keys()) == ['var'], "invalid var tuple in AUXVARS for nfluid=0 var '{}'".format(var) + return (False, False) + elif nfluid == 1: # we might need to loop through ifluid. + result = [True, False] + for kw in ['mf_ispecies', 'ifluid']: + if kw in callsig.keys(): + result[0] = False # we do not need to loop through ifluid. + break + return tuple(result) + elif nfluid == 2: # we might need to loop through ifluid and/or jfluid. + result = [True, True] + for kw in ['mf_jspecies', 'jfluid']: + if kw in callsig.keys(): + result[1] = False # we do not need to loop through jfluid. + break + for kw in ['mf_ispecies', 'ifluid']: + if kw in callsig.keys(): + result[0] = False # we do not need to loop through ifluid. + break + return tuple(result) + else: + raise NotImplementedError # we don't know what to do when nfluid is not 0, 1, 2, or None. + + +def _iter_fluids(fluids, loopfluids, **kw__fluid_pairs): + '''returns an iterator which yields pairs of dicts: (daux, dhel) + daux are the fluid kws to call with aux var + dhel are the fluid kws to call with helita var. + + loopfluids == + (False, False) -> yields (dict(), dict()) then stops iteration. + (True, False) -> yields (dict(ifluid=fluid), dict(ifluid=fluid)) for fluid in fluids. + (False, True ) -> yields (dict(ifluid=fluid), dict(jfluid=fluid)) for fluid in fluids. + (True, True) -> yields (x, x) where x is a dict with keys ifluid, jfluid, + and we iterate over pairs of ifluid, jfluid. + **kw__fluid_pairs + only matters if loopfluids == (True, True); + these kwargs go to fluid_tools.fluid_pairs. + ''' + loopi, loopj = loopfluids + if not loopi and not loopj: + x = dict() + yield (x, x) + elif loopi and not loopj: + for fluid in fluids: + x = dict(ifluid=fluid) + yield (x, x) + elif not loopi and loopj: + for fluid in fluids: + yield (dict(ifluid=fluid), dict(jfluid=fluid)) + elif loopi and loopj: + for ifluid, jfluid in fluid_tools.fluid_pairs(fluids, **kw__fluid_pairs): + x = dict(ifluid=ifluid, jfluid=jfluid) + yield (x, x) + + +def _SL_fluids(fluids_dict, f=lambda fluid: fluid): + '''update values in fluids_dict by applying f''' + return {key: f(val) for key, val in fluids_dict.items()} + + +def _setup_fluid_kw(auxvar, callsig, auxfluids, helfluids, f=lambda fluid: fluid): + '''returns ((args, kwargs) to use with auxvar, (args, kwargs) to use with helitavar) + args with be the list [var] + kwargs will be the dict of auxfluids (or helfluids). (species, levels) only. + + f is applied to all values in auxfluids and helfluids. + use f = (lambda fluid: fluid.SL) when fluids are at_tools.fluids.Fluids, + to convert them to (species, level) tuples. + ''' + # convert fluids to SLs via f + auxfluids = _SL_fluids(auxfluids, f=f) + helfluids = _SL_fluids(helfluids, f=f) + # pop var from callsig (we pass var as arg rather than kwarg). + callsigcopy = callsig.copy() # copy to ensure callsig is not altered + helvar = callsigcopy.pop('var') + helfluids.update(callsigcopy) + # make & return output + callaux = ([auxvar], auxfluids) + callhel = ([helvar], helfluids) + return (callaux, callhel) + + +def _get_fluids_and_f(obj, fluids=None, f=lambda fluid: fluid): + '''returns fluids, f. + if fluids is None: + fluids = fl.Fluids(dd=obj) + f = lambda fluid: fluid.SL + if we failed to import at_tools.fluids, try fluids=obj.fluids, before giving up. + ''' + if fluids is None: + def f(fluid): return fluid.SL + if fl is None: + if not obj.hasattr('fluids'): + errmsg = ("{} has no attribute 'fluids', we failed to import at_tools.fluids " + "and you didn't input fluids, so we don't know which fluids to use!") + errmsg = errmsg.format(obj) + raise NameError(errmsg) # choosing NameError type because "fluids" is "not defined". + else: + fluids = obj.fluids + else: + fluids = fl.Fluids(dd=obj) + return (fluids, f) + + +def iter_get_var(obj, auxvar, helvar=None, fluids=None, f=lambda fluid: fluid, + ordered=False, allow_same=False, quick_ratio=False, **kw__get_var): + '''gets values for auxvar and helita var. + + yields dict(vars = dict(aux=auxvar, hel=helita var name), + vals = dict(aux=get_var(auxvar), hel=get_var(helvar)), + fluids = dict(aux=auxfluids_dict, hel=helfluids_dict)), + SLs = dict(aux=auxfluidsSL, hel=helfluidsSL)) , + ) + + obj: EbysusData object + we will do obj.get_var(...) to get the values. + auxvar: str + name of var in aux. e.g. 'mfe_tg' for temperature. + helvar: None (default), or str, or tuple + None -> lookup helvar using helita.sim.aux_compare.AUXVARS. + str -> use this as helvar. Impose no required fluids on helvar. + tuple -> use helvar[0] as helvar. Impose required fluids: + helvar[1] imposes ifluid or mf_ispecies. + helvar[2] imposes jfluid or mf_jspecies (if helvar[2] exists). + fluids: None (default) or list of fluids + None -> use fluids = fl.Fluids(dd=obj). + f: function which converts fluid to (species, level) tuple + if fluids is None, f is ignored, we will instead use f = lambda fluid: fluid.SL + otherwise, we apply f to each fluid in fluids, before putting it into get_var. + Note: auxfluids_dict and helfluids_dict contain fluids before f is applied. + if iterating over fluid pairs, the following kwargs also matter: + ordered: False (default) or True + whether to only yield ordered combinations of fluid pairs (AB but not BA) + allow_same: False (default) or True + whether to also yield pairs of fluids which are the same (AA, BB, etc.) + quick_ratio: False (default) or True + whether to calculate (aux/hel) using means (if True) or full arrays (if False) + + **kw__get_var goes to obj.get_var(). + ''' + if helvar is None: + helvar = get_helita_var(auxvar) + callsig = _callsig(helvar) + loopfluids = _loop_fluids(obj, callsig) + # set fluids if necessary + if loopfluids[0] or loopfluids[1]: + fluids, f = _get_fluids_and_f(obj, fluids, f) + iterfluids = _iter_fluids(fluids, loopfluids, ordered=ordered, allow_same=allow_same) + for auxfluids_dict, helfluids_dict in iterfluids: + auxcall, helcall = _setup_fluid_kw(auxvar, callsig, auxfluids_dict, helfluids_dict, f=f) + auxfluidsSL = auxcall[1].copy() + helfluidsSL = helcall[1].copy() + auxcall[1].update(**kw__get_var) + helcall[1].update(**kw__get_var) + # actually get values by reading data and/or doing calculations + auxval = obj.get_var(*auxcall[0], **auxcall[1]) + helval = obj.get_var(*helcall[0], **helcall[1]) + # format output & yield it + vardict = dict(aux=auxvar, hel=callsig['var']) + valdict = dict(aux=auxval, hel=helval) + fludict = dict(aux=auxfluids_dict, hel=helfluids_dict) + SLsdict = dict(aux=auxfluidsSL, hel=helfluidsSL) + result = dict(vars=vardict, vals=valdict, fluids=fludict, SLs=SLsdict) + if not quick_ratio: + vals_equal = (helval == auxval) # handle "both equal to 0" case. + if np.count_nonzero(vals_equal) > 0: + helval_ = np.copy(helval) + auxval_ = np.copy(auxval) + helval_[vals_equal] = 1 + auxval_[vals_equal] = 1 + else: + helval_ = helval + auxval_ = auxval + result['ratio'] = tools.finite_mean(auxval_ / helval_) + yield result + + +''' ----------------------------- prettyprint comparison ----------------------------- ''' + + +def _stats(arr): + '''return stats for arr. dict with min, mean, max.''' + return dict(min=arr.min(), mean=arr.mean(), max=arr.max()) + + +def _strstats(arr_or_stats, fmt='{: 0.3e}', fmtkey='{:>4s}'): + '''return pretty string for stats. min=__, mean=__, max=__.''' + keys = ['min', 'mean', 'max'] + if isinstance(arr_or_stats, dict): # arr_or_stats is stats + x = arr_or_stats + else: # arr_or_stats is arr + x = _stats(arr_or_stats) + return ', '.join([fmtkey.format(key) + '='+fmt.format(x[key]) for key in keys]) + + +def _strvals(valdict): + '''return dict of pretty str for vals from valdict. keys 'hel', 'aux', 'stats'. + 'stats' contains dict of stats for hel & aux. + ''' + result = dict(stats=dict()) + for aux in valdict.keys(): # for aux in ['aux', 'hel']: + stats = _stats(valdict[aux]) + strstats = _strstats(stats) + result[aux] = strstats + result['stats'][aux] = stats + return result + + +def _strSL(SL, fmtSL='({:2d},{:2d})', fmtS=' {:2d} ', fmtNone=' '*(1+2+1+2+1)): + '''pretty string for (specie, level) SL. (or just specie SL, or None SL)''' + if SL is None: + return fmtNone + try: + next(iter(SL)) # error if SL is not a list. + except TypeError: + return fmtS.format(SL) # SL is just S + else: + return fmtSL.format(*SL) # SL is (S, L) + + +def _strfluids(fludict): + '''return dict of pretty str for fluids from fludict. keys 'hel', 'aux'.''' + N = max(len(fludict['aux']), len(fludict['hel'])) + result = dict() + for aux in fludict.keys(): # for aux in ['aux', 'hel']: + s = '' + if N > 0: + iSL = fludict[aux].get('ifluid', fludict[aux].get('mf_ispecies', None)) + s += _strSL(iSL) + ' ' + if N > 1: + jSL = fludict[aux].get('jfluid', fludict[aux].get('mf_jspecies', None)) + s += _strSL(jSL) + ' ' + result[aux] = s + return result + + +def _strvars(vardict, prefix=True): + '''return dict of pretty str for vars from vardict. keys 'hel', 'aux'. + prefix==True -> include prefix 'helita' or 'auxvar'. + ''' + L = max(len(vardict['aux']), len(vardict['hel'])) + fmt = '{:>'+str(L)+'s}' + result = dict() + for aux in vardict.keys(): # for aux in ['aux', 'hel']: + s = '' + if prefix: + s += dict(aux='auxvar', hel='helvar')[aux] + ' ' + s += fmt.format(vardict[aux]) + ' ' + result[aux] = s + return result + + +def prettyprint_comparison(x, printout=True, prefix=True, underline=True, + rattol=DEFAULT_TOLERANCE, return_warned=False, **kw__None): + '''pretty printing of info in x. x is one output of iter_get_var. + e.g.: for x in iter_get_var(...): prettyprint_comparison(x) + + printout: if False, return string instead of printing. + prefix: whether to include prefix of 'helita' or 'auxvar' at start of each line. + underline: whether to include a line of '------'... at the end. + rattol: if abs(1 - (mean aux / mean helita)) > rattol, print extra warning line. + return_warned: whether to also return whether we made a warning. + **kw__None goes to nowhere. + ''' + # get strings / values: + svars = _strvars(x['vars']) + sfluids = _strfluids(x['SLs']) + svals = _strvals(x['vals']) + meanaux = svals['stats']['aux']['mean'] + meanhel = svals['stats']['hel']['mean'] + if 'ratio' in x: + ratio = x['ratio'] + elif meanaux == 0.0 and meanhel == 0.0: + ratio = 1.0 + else: + ratio = meanaux / meanhel + ratstr = 'mean ratio (aux / helita): {: 0.3e}'.format(ratio) + # combine strings + key = 'aux' + s = ' '.join([svars[key], sfluids[key], svals[key]]) + '\n' + lline = len(s) + key = 'hel' + s += ' '.join([svars[key], sfluids[key], svals[key]]) + '; ' + s += ratstr + if (not np.isfinite(ratio)) or (abs(1 - ratio) > rattol): # then, add warning! + s += '\n' + ' '*(lline) + '>>> WARNING: RATIO DIFFERS FROM 1.000 <<<<' + warned = True + else: + warned = False + if underline: + s += '\n' + '-' * (lline + len(ratstr) + 10) + # print (or return) + result = None + if printout: + print(s) + else: + result = s + if return_warned: + result = (result, warned) + return result + + +''' ----------------------------- high-level comparison interface ----------------------------- ''' + +comparison_result = namedtuple('comparison_result', ('N_differ', 'N_total', 'runtime')) + + +@fluid_tools.maintain_fluids # restore dd.ifluid and dd.jfluid after finishing compare(). +def compare(obj, auxvar, helvar=None, fluids=None, **kwargs): + '''compare values of auxvar with appropriate helita var, for obj. + **kwargs propagate to iter_get_var, obj.get_var, and prettyprint_comparison. + + involves looping through fluids: + none (nfluid=0, e.g. 'ex'), + one (nfluid=1, e.g. 'tg'), or + two (nfluid=2, e.g. 'nu_ij'). + + Parameters + ---------- + helvar: None (default)or str or tuple + helita var corresponding to auxvar. E.g. 'efx' for auxvar='ex'. + it is assumed that helvar and auxvar use the same number of fluids. + For example, 'mfe_tg' and 'tg' each use one fluid. + For some vars, there is a slight hiccup. Example: 'etg' and 'tg'. + 'etg' is equivalent to 'tg' only when mf_ispecies=-1. + To accomodate such cases, we allow a tuple such as ('tg', -1) for helvar. + type of helvar, and explanations below: + None -> use default. (but not all auxvars have an existing default.) + all defaults which exist are hard-coded in helita.sim.aux_compare. + when default fails, use non-None helvar, + or edit helita.sim.aux_compare.AUXVARS to add default. + str -> use var = helvar. This one should do exactly what you expect. + Note that for this case, auxvar and helvar must depend on the + exact same fluids (e.g. both depend on just ifluid). + tuple-> use var = helvar[0]. The remaining items in the tuple + will force fluid kwargs for helvar. ints for mf_species; + tuples force fluid. Example: ('nu_ij', -1) forces mf_ispecies=-1, + and because nu_ij depends on 2 fluids (according to obj.vardict), + we still need to enter one fluid. So we will loop through + fluids, passing each fluid to helvar as jfluid, and auxvar as ifluid. + + fluids: None (default) or iterable (e.g. list) + None -> get fluids using obj. fluids = at_tools.fluids.Fluids(dd=obj) + iterable -> use these fluids. Should be tuples of (specie, level), + example: fluids = [(1,2),(2,2),(2,3)] + See aux_compare.iter_get_var for more documentation. + + Returns + ------- + returns namedtuple (N_differ, N_total, runtime), where: + N_differ = number of times helita and auxvar gave + different mean results (differing by more than rattol). + 0 is good, it means helita & auxvar agreed on everything! :) + N_total = total number of values compared. example: + if we compared 'mfe_tg' and 'tg' for ifluid in + [(1,1),(1,2),(2,3)], we will have N_total==3. + runtime = time it took to run, in seconds. + ''' + now = time.time() + N_warnings = 0 + N_total = 0 + for x in iter_get_var(obj, auxvar, helvar=helvar, fluids=fluids, **kwargs): + N_total += 1 + _, warned = prettyprint_comparison(x, return_warned=True, **kwargs) + if warned: + N_warnings += 1 + runtime = round(time.time() - now, 3) # round because sub-ms times are untrustworthy and ugly. + return comparison_result(N_warnings, N_total, runtime) + + +def _get_aux_vars(obj): + '''returns list of vars in aux based on obj.''' + return obj.params['aux'][obj.snapInd].split() + + +def compare_all(obj, aux=None, verbose=2, **kwargs): + '''compare all aux vars with their corresponding values in helita. + + each comparison involves looping through fluids: + none (nfluid=0, e.g. 'ex'), + one (nfluid=1, e.g. 'tg'), or + two (nfluid=2, e.g. 'nu_ij'). + + Parameters + ---------- + obj: an EbysusData object. + (or any object with get_var method.) + aux: None (default) or a list of strs + the list of aux vars to compare. + None -> get the list from obj (via obj.params['aux']). + verbose: 2 (default), 1, or 0 + 2 -> full print info + 1 -> print some details but set printout=False (unless printout is in kwargs) + 0 -> don't print anything. + **kwargs: + extra kwargs are passed to compare(). (i.e.: helita.sim.aux_compare.compare) + + Returns + ------- + returns dict with contents: + N_compare = number of actual values we compared. + (one for each set of fluids used for each (auxvar, helvar) pair) + N_var = number of vars we tried to run comparisons for. + (one for each (auxvar, helvar) pair) + N_differ = number of times helita and auxvar gave + different mean results (by more than rattol) + N_diffvar = number of vars for which helita and auxvar gave + different mean results at least once. + N_error = number of times compare() crashed due to error. + errors = list of errors raised. + runtime = time it took to run, in seconds. + + A 100% passing test looks like N_differ == N_error == 0. + ''' + now = time.time() + printout = kwargs.pop('printout', (verbose >= 2)) # default = (verbose>=2) + x = dict(N_compare=0, N_var=0, N_differ=0, N_diffvar=0, N_error=0, errors=[]) + auxvars = _get_aux_vars(obj) if aux is None else aux + for aux in auxvars: + if verbose: + banner = ' >->->->->->->-> {} <-<-<-<-<-<-<-<' + if printout: + banner += '\n' + print(banner.format('initiate comparison for auxvar = {}'.format(aux))) + x['N_var'] += 1 + try: + comp = compare(obj, aux, printout=printout, **kwargs) + except Exception as exc: + x['N_error'] += 1 + x['errors'] += [exc] + if verbose >= 1: + print('>>>', repr(exc), '\n') + else: + x['N_compare'] += comp[1] + x['N_differ'] += comp[0] + x['N_diffvar'] += (comp[0] > 0) + if printout: + print() # print a single new line + if verbose >= 1: + print(comp, '\n') + if printout: + print() # print a single new line + x['runtime'] = round(time.time() - now, 3) # round because sub-ms times are untrustworthy and ugly. + return x diff --git a/helita/sim/bifrost.py b/helita/sim/bifrost.py index bad7db21..ea157c6e 100644 --- a/helita/sim/bifrost.py +++ b/helita/sim/bifrost.py @@ -2,17 +2,35 @@ Set of programs to read and interact with output from Bifrost """ +# import builtin modules import os +import ast +import time +import weakref import warnings +import functools +import collections from glob import glob + +# import external public modules import numpy as np -from . import stagger from scipy import interpolate from scipy.ndimage import map_coordinates -from multiprocessing.dummy import Pool as ThreadPool + +from . import document_vars, file_memory, load_fromfile_quantities, stagger, tools, units +from .load_arithmetic_quantities import * +# import internal modules +from .load_quantities import * +from .tools import * + +# defaults +whsp = ' ' +AXES = ('x', 'y', 'z') + +# BifrostData class -class BifrostData(object): +class BifrostData(): """ Reads data from Bifrost simulations in native format. @@ -23,11 +41,9 @@ class BifrostData(object): will be added afterwards, and directory will be added before. snap - integer, optional Snapshot number. If None, will read first snapshot in sequence. - meshfile - string + meshfile - string, optional File name (including full path) for file with mesh. If set - to None (default), will try to read file listed in Bifrost files. - If set to "uniform", will create a uniform mesh. Do not use "uniform" - unless you really know what you're doing! + to None (default), a uniform mesh will be created. fdir - string, optional Directory where simulation files are. Must be a real path. verbose - bool, optional @@ -40,47 +56,97 @@ class BifrostData(object): ghost_analyse - bool, optional If True, will read data from ghost zones when this is saved to files. Default is never to read ghost zones. + do_stagger - bool, optional + whether to correctly account for the stagger mesh when doing operations. + if enabled, operations will take more time but produce more accurate results. + stagger_kind - string, optional + which method to use for performing stagger operations, if do_stagger. + options are 'cstagger', 'numba' (default), 'numpy'. See stagger.py for details. + More options may be defined later. Set stagger_kind='' to see all options. lowbus - bool, optional - Use True only if data is too big to load. It will do stagger + Use True only if data is too big to load. It will do cstagger operations layer by layer using threads (slower). numThreads - integer, optional number of threads for certain operations that use parallelism. + fast - whether to read data "fast", by only reading the requested data. + implemented as a flag, with False as default, for backwards + compatibility; some previous codes may have assumed non-requested + data was read. To avoid issues, just ensure you use get_var() + every time you want to have data, and don't assume things exist + (e.g. self.bx) unless you do get_var for that thing + (e.g. get_var('bx')). + units_output - string, optional + unit system for output. default 'simu' for simulation output. + options are 'simu', 'si', 'cgs'. + Only affects final values from (external calls to) get_var. + if not 'simu', self.got_units_name will store units string from latest get_var. + Do not use at the same time as non-default sel_units. + squeeze_output - bool, optional. default False + whether to apply np.squeeze() before returning the result of get_var. + print_freq - value, default 2. + number of seconds between print statements during get_varTime. + == 0 --> print update at every snapshot during get_varTime. + < 0 --> never print updates during get_varTime. + printing_stats - bool or dict, optional. default False + whether to print stats about values of var upon completing a(n external) call to get_var. + False --> don't print stats. + True --> do print stats. + dict --> do print stats, passing this dictionary as kwargs. Examples -------- This reads snapshot 383 from simulation "cb24bih", whose file root is "cb24bih", and is found at directory /data/cb24bih: - >>> a = BifrostData("cb24bih", snap=383, fdir="/data/cb24bih") + a = BifrostData("cb24bih", snap=383, fdir="/data/cb24bih") Scalar variables do not need de-staggering and are available as memory map (only loaded to memory when needed), e.g.: - >>> a.r.shape - (504, 504, 496) + a.r.shape + (504, 504, 496) Composite variables need to be obtained by get_var(): - >>> vx = a.get_var("ux") + vx = a.get_var("ux") """ - snap = None + ## CREATION ## def __init__(self, file_root, snap=None, meshfile=None, fdir='.', - verbose=True, dtype='f4', big_endian=False, - ghost_analyse=False, lowbus=False, numThreads=1): + fast=False, verbose=True, dtype='f4', big_endian=False, + cstagop=None, do_stagger=True, ghost_analyse=False, lowbus=False, + numThreads=1, params_only=False, sel_units=None, + use_relpath=False, stagger_kind=stagger.DEFAULT_STAGGER_KIND, + units_output='simu', squeeze_output=False, + print_freq=2, printing_stats=False, + iix=None, iiy=None, iiz=None): """ Loads metadata and initialises variables. """ - self.fdir = fdir + # bookkeeping + self.fdir = fdir if use_relpath else os.path.abspath(fdir) self.verbose = verbose + self.do_stagger = do_stagger if (cstagop is None) else cstagop self.lowbus = lowbus self.numThreads = numThreads self.file_root = os.path.join(self.fdir, file_root) self.root_name = file_root self.meshfile = meshfile self.ghost_analyse = ghost_analyse - self.lowbus = lowbus + self.stagger_kind = stagger_kind self.numThreads = numThreads + self.fast = fast + self._fast_skip_flag = False if fast else None # None-> never skip + self.squeeze_output = squeeze_output + self.print_freq = print_freq + self.printing_stats = printing_stats + + # units. Two options for management. Should only use one at a time; leave the other at default value. + self.units_output = units_output # < units.py system of managing units. + self.sel_units = sel_units # < other system of managing units. + + setattr(self, document_vars.LOADING_LEVEL, -1) # tells how deep we are into loading a quantity now. + # endianness and data type if big_endian: self.dtype = '>' + dtype @@ -88,50 +154,161 @@ def __init__(self, file_root, snap=None, meshfile=None, fdir='.', self.dtype = '<' + dtype self.hion = False self.heion = False - self.set_snap(snap) + try: - tmp = find_first_match("%s*%d*.idl" % (file_root, snap), fdir) - if tmp == None: - raise IndexError + tmp = find_first_match("%s*idl" % file_root, fdir) except IndexError: try: - tmp = find_first_match("%s*idl" % file_root, fdir) - if tmp == None: - raise IndexError + tmp = find_first_match("%s*idl.scr" % file_root, fdir) except IndexError: try: - tmp = find_first_match("%s*idl.scr" % file_root, fdir) - if tmp == None: - raise IndexError + tmp = find_first_match("mhd.in", fdir) except IndexError: - try: - tmp = find_first_match("mhd.in", fdir) - if tmp == None: - raise IndexError - except IndexError: - raise ValueError(("(EEE) init: no .idl or mhd.in files " - "found")) - self.uni = Bifrost_units(filename=tmp, fdir=fdir) + raise ValueError(("(EEE) init: no .idl or mhd.in files " + "found")) + self.uni = Bifrost_units(filename=tmp, fdir=fdir, parent=self) + + self.set_snap(snap, True, params_only=params_only) + + self.set_domain_iiaxes(iix=iix, iiy=iiy, iiz=iiz, internal=False) + + self.genvar() + self.transunits = False + self.cross_sect = cross_sect_for_obj(self) + if 'tabinputfile' in self.params.keys(): + tabfile = os.path.join(self.fdir, self.get_param('tabinputfile').strip()) + if os.access(tabfile, os.R_OK): + self.rhoee = Rhoeetab(tabfile=tabfile, fdir=fdir, radtab=True, verbose=self.verbose) + + self.stagger = stagger.StaggerInterface(self) + + document_vars.create_vardict(self) + document_vars.set_vardocs(self) + + ## PROPERTIES ## + help = property(lambda self: self.vardoc) + + shape = property(lambda self: (self.xLength, self.yLength, self.zLength)) + size = property(lambda self: (self.xLength * self.yLength * self.zLength)) + ndim = property(lambda self: 3) + + units_output = units.UNITS_OUTPUT_PROPERTY(internal_name='_units_output') + + @property + def internal_means(self): + '''whether to take means of get_var internally, immediately (for simple vars). + DISABLED by default. + + E.g. if enabled, self.get_var('r') will be single-valued, not an array. + Note this will have many consequences. E.g. derivatives will all be 0. + Original intent: analyzing simulations with just a small perturbation around the mean. + ''' + return getattr(self, '_internal_means', False) + + @internal_means.setter + def internal_means(self, value): + self._internal_means = value + + @property + def printing_stats(self): + '''whether to print stats about values of var upon completing a(n external) call to get_var. + + Options: + False (default) --> don't print stats. + True --> do print stats. + dict --> call print stats with these kwargs + e.g. printing_stats=dict(fmt='{:.3e}') --> self.print_stats(fmt='{:.3e}') + + This is useful especially while investigating just the approximate values for each quantity. + ''' + return getattr(self, '_printing_stats', False) - def _set_snapvars(self): + @printing_stats.setter + def printing_stats(self, value): + self._printing_stats = value + + stagger_kind = stagger.STAGGER_KIND_PROPERTY(internal_name='_stagger_kind') + + @property + def cstagop(self): # cstagop is an alias to do_stagger. Maintained for backwards compatibility. + return self.do_stagger + + @cstagop.setter + def cstagop(self, value): + self.do_stagger = value + + @property + def snap(self): + '''snapshot number, or list of snapshot numbers.''' + return getattr(self, '_snap', None) + + @snap.setter + def snap(self, value): + self.set_snap(value) + + @property + def snaps(self): + '''equivalent to self.snap when it is a list (or other iterable). Otherwise, raise TypeError.''' + snaps = self.snap + try: + iter(snaps) + except TypeError: + raise TypeError(f'self.snap (={self.snap}) is not a list!') from None + return snaps + + @property + def snapname(self): + '''alias for self.root_name. Set by 'snapname' in mhd.in / .idl files.''' + return self.root_name + + kx = property(lambda self: 2*np.pi*np.fft.fftshift(np.fft.fftfreq(self.xLength, self.dx)), + doc='kx coordinates [simulation units] (fftshifted such that 0 is in the middle).') + ky = property(lambda self: 2*np.pi*np.fft.fftshift(np.fft.fftfreq(self.yLength, self.dy)), + doc='ky coordinates [simulation units] (fftshifted such that 0 is in the middle).') + kz = property(lambda self: 2*np.pi*np.fft.fftshift(np.fft.fftfreq(self.zLength, self.dz)), + doc='kz coordinates [simulation units] (fftshifted such that 0 is in the middle).') + # ^ convert k to physical units by dividing by self.uni.usi_l (or u_l for cgs) + + ## SET SNAPSHOT ## + def __getitem__(self, i): + '''sets snap to i then returns self. + + i: string, or anything which can index a list + string --> set snap to int(i) + else --> set snap to self.get_snaps()[i] + + Example usage: + bb = BifrostData(...) + bb['3']('r') + # is equivalent to: bb.set_snap(3); bb.get_var('r') + bb[3]('r') + # is equivalent to: bb.set_snap(bb.get_snaps()[3]); bb.get_var('r') + # if the existing snaps are [0,1,2,3,...], this is equivalent to bb['3']('r') + # if the existing snaps are [4,5,6,7,...], this is equivalent to bb['7']('r') + ''' + if isinstance(i, str): + self.set_snap(int(i)) + else: + self.set_snap(self.get_snaps()[i]) + return self + + def _set_snapvars(self, firstime=False): """ Sets list of avaible variables """ self.snapvars = ['r', 'px', 'py', 'pz', 'e'] - self.auxvars = self.params['aux'][self.snapInd].split() + self.auxvars = self.get_param('aux', error_prop=True).split() if self.do_mhd: self.snapvars += ['bx', 'by', 'bz'] self.hionvars = [] self.heliumvars = [] - if 'do_hion' in self.params: - if self.params['do_hion'][self.snapInd] > 0: - self.hionvars = ['hionne', 'hiontg', 'n1', - 'n2', 'n3', 'n4', 'n5', 'n6', 'nh2'] - self.hion = True - if 'do_helium' in self.params: - if self.params['do_helium'][self.snapInd] > 0: - self.heliumvars = ['nhe1', 'nhe2', 'nhe3'] - self.heion = True + if self.get_param('do_hion', default=0) > 0: + self.hionvars = ['hionne', 'hiontg', 'n1', + 'n2', 'n3', 'n4', 'n5', 'n6', 'nh2'] + self.hion = True + if self.get_param('do_helium', default=0) > 0: + self.heliumvars = ['nhe1', 'nhe2', 'nhe3'] + self.heion = True self.compvars = ['ux', 'uy', 'uz', 's', 'ee'] self.simple_vars = self.snapvars + self.auxvars + self.hionvars + \ self.heliumvars @@ -147,7 +324,7 @@ def _set_snapvars(self): self.auxvars.remove(var) self.vars2d.append(var) - def set_snap(self, snap): + def set_snap(self, snap, firstime=False, params_only=False): """ Reads metadata and sets variable memmap links for a given snapshot number. @@ -166,7 +343,7 @@ def set_snap(self, snap): else: tmp = glob("%s.idl" % self.file_root) snap = 0 - except: + except Exception: try: tmp = sorted(glob("%s*idl.scr" % self.file_root))[0] snap = -1 @@ -177,26 +354,26 @@ def set_snap(self, snap): except IndexError: raise ValueError(("(EEE) set_snap: snapshot not defined " "and no .idl files found")) - self.snap = snap - if np.size(snap) > 1: + + self._snap = snap + if np.shape(self.snap) != (): self.snap_str = [] for num in snap: - self.snap_str.append('_%03i' % int(num)) + self.snap_str.append(_N_to_snapstr(num)) else: - if snap == 0: - self.snap_str = '' - else: - self.snap_str = '_%03i' % snap + self.snap_str = _N_to_snapstr(snap) self.snapInd = 0 - self._read_params() + self._read_params(firstime=firstime) # Read mesh for all snaps because meshfiles could differ - self.__read_mesh(self.meshfile) + self.__read_mesh(self.meshfile, firstime=firstime) # variables: lists and initialisation - self._set_snapvars() - self._init_vars() + self._set_snapvars(firstime=firstime) + # Do not call if params_only requested + if (not params_only): + self._init_vars(firstime=firstime) - def _read_params(self): + def _read_params(self, firstime=False): """ Reads parameter file (.idl) """ @@ -219,7 +396,7 @@ def _read_params(self): filename.append(self.file_root + snap_str[i] + '.idl') for file in filename: - self.paramList.append(read_idl_ascii(file)) + self.paramList.append(read_idl_ascii(file, firstime=firstime, obj=self)) # assign some parameters as attributes for params in self.paramList: @@ -236,34 +413,88 @@ def _read_params(self): raise KeyError(('read_params: could not find ' '%s in idl file!' % p)) try: - if params['boundarychk'] == 1: + if ((params['boundarychk'] == 1) and (params['isnap'] != 0)): self.nzb = self.nz + 2 * self.nb else: self.nzb = self.nz + if ((params['boundarychky'] == 1) and (params['isnap'] != 0)): + self.nyb = self.ny + 2 * self.nb + else: + self.nyb = self.ny + if ((params['boundarychkx'] == 1) and (params['isnap'] != 0)): + self.nxb = self.nx + 2 * self.nb + else: + self.nxb = self.nx except KeyError: self.nzb = self.nz + self.nyb = self.ny + self.nxb = self.nx # check if units are there, if not use defaults and print warning unit_def = {'u_l': 1.e8, 'u_t': 1.e2, 'u_r': 1.e-7, 'u_b': 1.121e3, 'u_ee': 1.e12} for unit in unit_def: if unit not in params: - print(("(WWW) read_params:"" %s not found, using " - "default of %.3e" % (unit, unit_def[unit]))) - params[unit] = unit_def[unit] + default = unit_def[unit] + if hasattr(self, 'uni'): + default = getattr(self.uni, unit, default) + if getattr(self, 'verbose', True): + print("(WWW) read_params:"" %s not found, using " + "default of %.3e" % (unit, default), 2*whsp, + end="\r", flush=True) + params[unit] = default self.params = {} for key in self.paramList[0]: self.params[key] = np.array( - [self.paramList[i][key] for i in range( - 0, len(self.paramList))]) + [self.paramList[i][key] for i in range(0, len(self.paramList)) + if key in self.paramList[i].keys()]) + # the if statement is required in case extra params in + # self.ParmList[0] + self.time = self.params['t'] + if self.sel_units == 'cgs': + self.time *= self.uni.uni['t'] + + def get_param(self, param, default=None, warning=None, error_prop=None): + ''' get param via self.params[param][self.snapInd]. + + if param not in self.params.keys(), then the following kwargs may play a role: + default: None (default) or any value. + return this value (eventually) instead. (check warning and error_prop first.) + warning: None (default) or any Warning or string. + if not None, do warnings.warn(warning). + error_prop: None (default), True, or any Exception object. + None --> ignore this kwarg. + True --> raise the original KeyError caused by trying to get self.params[param]. + else --> raise error_prop from None. + ''' + try: + p = self.params[param] + except KeyError as err_triggered: + if (warning is not None) and (self.verbose): + warnings.warn(warning) + if error_prop is not None: + if isinstance(error_prop, BaseException): + raise error_prop from None # "from None" --> show just this error, not also err_triggered + elif error_prop: + raise err_triggered + return default + else: + p = p[self.snapInd] + return p + + def get_params(self, *params, **kw): + '''return a dict of the values of params in self. + Equivalent to {p: self.get_param(p, **kw) for p in params}. + ''' + return {p: self.get_param(p, **kw) for p in params} - def __read_mesh(self, meshfile): + def __read_mesh(self, meshfile, firstime=False): """ Reads mesh file """ if meshfile is None: meshfile = os.path.join( - self.fdir, self.params['meshfile'][self.snapInd].strip()) + self.fdir, self.get_param('meshfile', error_prop=True).strip()) if os.path.isfile(meshfile): f = open(meshfile, 'r') for p in ['x', 'y', 'z']: @@ -300,16 +531,20 @@ def __read_mesh(self, meshfile): np.repeat(self.dzidzdn[0], self.nb), self.dzidzdn, np.repeat(self.dzidzdn[-1], self.nb))) + self.nx = self.nxb + self.ny = self.nyb self.nz = self.nzb - elif meshfile.lower() == "uniform": + else: # no mesh file if self.dx == 0.0: self.dx = 1.0 if self.dy == 0.0: self.dy = 1.0 if self.dz == 0.0: self.dz = 1.0 - print(('(WWW) Creating uniform grid with [dx,dy,dz] = ' - '[%f,%f,%f]') % (self.dx, self.dy, self.dz)) + if self.verbose and firstime: + warnings.warn(('Mesh file {mf} does not exist. Creating uniform grid ' + 'with (dx,dy,dz)=({dx:.2e},{dy:.2e},{dz:.2e})').format( + mf=repr(meshfile), dx=self.dx, dy=self.dy, dz=self.dz)) # x self.x = np.arange(self.nx) * self.dx self.xdn = self.x - 0.5 * self.dx @@ -322,150 +557,239 @@ def __read_mesh(self, meshfile): self.dyidydn = np.zeros(self.ny) + 1. / self.dy # z if self.ghost_analyse: + self.nx = self.nxb + self.ny = self.nyb self.nz = self.nzb self.z = np.arange(self.nz) * self.dz self.zdn = self.z - 0.5 * self.dz self.dzidzup = np.zeros(self.nz) + 1. / self.dz self.dzidzdn = np.zeros(self.nz) + 1. / self.dz - else: - raise ValueError("No meshfile available. Either file was not found" - " or meshfile was not set to 'uniform'.") - if self.nz > 1: - self.dz1d = np.gradient(self.z) - else: - self.dz1d = np.zeros(self.nz) + for x in ('x', 'y', 'z'): + setattr(self, x, getattr(self, x)[getattr(self, 'ii'+x, slice(None))]) - def _init_vars(self, *args, **kwargs): + for x in ('x', 'y', 'z'): + xcoords = getattr(self, x) + if len(xcoords) > 1: + dx1d = np.gradient(xcoords) + else: + dx1d = np.zeros(len(xcoords)) + setattr(self, 'd'+x+'1d', dx1d) + + if self.sel_units == 'cgs': + self.x *= self.uni.uni['l'] + self.y *= self.uni.uni['l'] + self.z *= self.uni.uni['l'] + self.zdn *= self.uni.uni['l'] + self.dx *= self.uni.uni['l'] + self.dy *= self.uni.uni['l'] + self.dz *= self.uni.uni['l'] + self.dx1d *= self.uni.uni['l'] + self.dy1d *= self.uni.uni['l'] + self.dz1d *= self.uni.uni['l'] + + self.dxidxup /= self.uni.uni['l'] + self.dxidxdn /= self.uni.uni['l'] + self.dyidyup /= self.uni.uni['l'] + self.dyidydn /= self.uni.uni['l'] + self.dzidzup /= self.uni.uni['l'] + self.dzidzdn /= self.uni.uni['l'] + + self.transunits = False + + def _init_vars(self, firstime=False, fast=None, *args, **kwargs): """ Memmaps "simple" variables, and maps them to methods. Also, sets file name[s] from which to read a data + + fast: None, True, or False. + whether to only read density (and not all the other variables). + if None, use self.fast instead. """ + fast = fast if fast is not None else self.fast + if self._fast_skip_flag is True: + return + elif self._fast_skip_flag is False: + self._fast_skip_flag = True # swaps flag to True, then runs the rest of the code (this time around). + # else, fast_skip_flag is None, so the code should never be skipped. + # as long as fast is False, fast_skip_flag should be None. + self.variables = {} for var in self.simple_vars: try: self.variables[var] = self._get_simple_var( var, *args, **kwargs) setattr(self, var, self.variables[var]) - except Exception: + except Exception as err: if self.verbose: - print(('(WWW) init_vars: could not read ' - 'variable %s' % var)) + if firstime: + print('(WWW) init_vars: could not read ' + 'variable {} due to {}'.format(var, err)) for var in self.auxxyvars: try: self.variables[var] = self._get_simple_var_xy(var, *args, **kwargs) setattr(self, var, self.variables[var]) - except Exception: + except Exception as err: if self.verbose: - print(('(WWW) init_vars: could not read ' - 'variable %s' % var)) + if firstime: + print('(WWW) init_vars: could not read ' + 'variable {} due to {}'.format(var, err)) rdt = self.r.dtype - - def get_varTime(self, var, snap=None, iix=None, iiy=None, iiz=None, - *args, **kwargs): - """ - Reads a given variable as a function of time. - - Parameters - ---------- - var - string - Name of the variable to read. Must be a valid Bifrost variable name, - see Bifrost.get_var(). - snap - array of integers - Snapshot numbers to read. - iix -- integer or array of integers, optional - reads yz slices. - iiy -- integer or array of integers, optional - reads xz slices. - iiz -- integer or array of integers, optional - reads xy slices. - """ - self.iix = iix - self.iiy = iiy - self.iiz = iiz - - try: - if snap is not None: - if np.size(snap) == np.size(self.snap): - if any(snap != self.snap): - self.set_snap(snap) - else: - self.set_snap(snap) - except ValueError: - print('WWW: snap has to be a numpy.arrange parameter') - - # lengths for dimensions of return array - self.xLength = 0 - self.yLength = 0 - self.zLength = 0 - - for dim in ('iix', 'iiy', 'iiz'): - if getattr(self, dim) is None: - if dim[2] == 'z': - setattr(self, dim[2] + 'Length', - getattr(self, 'n' + dim[2] + 'b')) - else: - setattr(self, dim[2] + 'Length', - getattr(self, 'n' + dim[2])) - setattr(self, dim, slice(None)) + if self.stagger_kind == 'cstagger': + if (self.nz > 1): + cstagger.init_stagger(self.nz, self.dx, self.dy, self.z.astype(rdt), + self.zdn.astype(rdt), self.dzidzup.astype(rdt), + self.dzidzdn.astype(rdt)) + self.cstagger_exists = True # we can use cstagger methods! else: - indSize = np.size(getattr(self, dim)) - setattr(self, dim[2] + 'Length', indSize) - - snapLen = np.size(self.snap) - value = np.empty([self.xLength, self.yLength, self.zLength, snapLen]) + cstagger.init_stagger_mz1d(self.nz, self.dx, self.dy, self.z.astype(rdt)) + self.cstagger_exists = True # we must avoid using cstagger methods. + else: + self.cstagger_exists = True - for i in range(0, snapLen): - self.snapInd = i - self._set_snapvars() - self._init_vars() + ## GET VARIABLE ## + def __call__(self, var, *args, **kwargs): + '''equivalent to self.get_var(var, *args, **kwargs)''' + __tracebackhide__ = True # hide this func from error traceback stack + return self.get_var(var, *args, **kwargs) - value[..., i] = self.get_var(var, self.snap[i], iix=self.iix, - iiy=self.iiy, iiz=self.iiz) - return value - - def set_domain_iiaxis(self, iinum=slice(None), iiaxis='x'): + def set_domain_iiaxis(self, iinum=None, iiaxis='x'): """ - Sets length of each dimension for get_var based on iix/iiy/iiz + Sets iix=iinum and xLength=len(iinum). (x=iiaxis) + if iinum is a slice, use self.nx (or self.nzb, for x='z') to determine xLength. + + Also, if we end up using a non-None slice, disable stagger. + TODO: maybe we can leave do_stagger=True if stagger_kind != 'cstagger' ? Parameters ---------- - iinum - int, list, or array + iinum - slice, int, list, array, or None (default) Slice to be taken from get_var quantity in that axis (iiaxis) + int --> convert to slice(iinum, iinum+1) (to maintain dimensions of output) + None --> don't change existing self.iix (or iiy or iiz). + if it doesn't exist, set it to slice(None). + To set existing self.iix to slice(None), use iinum=slice(None). iiaxis - string Axis from which the slice will be taken ('x', 'y', or 'z') + + Returns True if any changes were made, else None. """ + iix = 'ii' + iiaxis + if hasattr(self, iix): + # if iinum is None or self.iix == iinum, do nothing and return nothing. + if (iinum is None): + return None + elif np.all(iinum == getattr(self, iix)): + return None + if iinum is None: iinum = slice(None) - dim = 'ii' + iiaxis - setattr(self, dim, iinum) - setattr(self, iiaxis + 'Length', np.size(iinum)) + if not np.array_equal(iinum, slice(None)): + # smash self.variables. Necessary, since we will change the domain size. + self.variables = {} - if np.size(getattr(self, dim)) == 1: - if getattr(self, dim) == slice(None): - if dim[2] == 'z': - setattr(self, dim[2] + 'Length', - getattr(self, 'n' + dim[2] + 'b')) - else: - setattr(self, dim[2] + 'Length', - getattr(self, 'n' + dim[2])) + if isinstance(iinum, (int, np.integer)): # we convert to slice, to maintain dimensions of output. + iinum = slice(iinum, iinum+1) # E.g. [0,1,2][slice(1,2)] --> [1]; [0,1,2][1] --> 1 + + # set self.iix + setattr(self, iix, iinum) + if self.verbose: + # convert iinum to string that wont be super long (in case iinum is a long list) + try: + assert len(iinum) > 20 + except (TypeError, AssertionError): + iinumprint = iinum + else: + iinumprint = 'list with length={:4d}, min={:4d}, max={:4d}, x[1]={:2d}' + iinumprint = iinumprint.format(len(iinum), min(iinum), max(iinum), iinum[1]) + # print info. + print('(set_domain) {}: {}'.format(iix, iinumprint), + whsp*4, end="\r", flush=True) + + # set self.xLength + if isinstance(iinum, slice): + nx = getattr(self, 'n'+iiaxis+'b') + indSize = len(range(*iinum.indices(nx))) + else: + iinum = np.asarray(iinum) + if iinum.dtype == 'bool': + indSize = np.sum(iinum) else: - indSize = np.size(getattr(self, dim)) - setattr(self, dim[2] + 'Length', indSize) - if indSize == 1: - temp = np.asarray(getattr(self, dim)) - setattr(self, dim, temp.item()) + indSize = np.size(iinum) + setattr(self, iiaxis + 'Length', indSize) + + return True + + def set_domain_iiaxes(self, iix=None, iiy=None, iiz=None, internal=False): + '''sets iix, iiy, iiz, xLength, yLength, zLength. + iix: slice, int, list, array, or None (default) + Slice to be taken from get_var quantity in x axis + None --> don't change existing self.iix. + if self.iix doesn't exist, set it to slice(None). + To set existing self.iix to slice(None), use iix=slice(None). + iiy, iiz: similar to iix. + internal: bool (default: False) + if internal and self.do_stagger, don't change slices. + internal=True inside get_var. + + updates x, y, z, dx1d, dy1d, dz1d afterwards, if any domains were changed. + ''' + if internal and self.do_stagger: + # we slice at the end, only. For now, set all to slice(None) + slices = (slice(None), slice(None), slice(None)) else: - indSize = np.size(getattr(self, dim)) - setattr(self, dim[2] + 'Length', indSize) - if indSize == 1: - temp = np.asarray(getattr(self, dim)) - setattr(self, dim, temp.item()) - - def get_var(self, var, snap=None, *args, iix=slice(None), iiy=slice(None), - iiz=slice(None), **kwargs): + slices = (iix, iiy, iiz) + + any_domain_changes = False + for x, iix in zip(AXES, slices): + domain_changed = self.set_domain_iiaxis(iix, x) + any_domain_changes = any_domain_changes or domain_changed + + # update x, y, z, dx1d, dy1d, dz1d appropriately. + if any_domain_changes: + self.__read_mesh(self.meshfile, firstime=False) + + def genvar(self): + ''' + Dictionary of original variables which will allow to convert to cgs. + ''' + self.varn = {} + self.varn['rho'] = 'r' + self.varn['totr'] = 'r' + self.varn['tg'] = 'tg' + self.varn['pg'] = 'p' + self.varn['ux'] = 'ux' + self.varn['uy'] = 'uy' + self.varn['uz'] = 'uz' + self.varn['e'] = 'e' + self.varn['bx'] = 'bx' + self.varn['by'] = 'by' + self.varn['bz'] = 'bz' + + @document_vars.quant_tracking_top_level + def _load_quantity(self, var, cgsunits=1.0, **kwargs): + '''helper function for get_var; actually calls load_quantities for var.''' + __tracebackhide__ = True # hide this func from error traceback stack + # look for var in self.variables + if cgsunits == 1.0: + if var in self.variables: # if var is still in memory, + return self.variables[var] # load from memory instead of re-reading. + # Try to load simple quantities. + val = load_fromfile_quantities.load_fromfile_quantities(self, var, + save_if_composite=True, cgsunits=cgsunits, **kwargs) + + # Try to load "regular" quantities + if val is None: + val = load_quantities(self, var, **kwargs) + # Try to load "arithmetic" quantities. + if val is None: + val = load_arithmetic_quantities(self, var, **kwargs) + + return val + + def get_var(self, var, snap=None, *args, iix=None, iiy=None, iiz=None, printing_stats=None, **kwargs): """ Reads a variable from the relevant files. @@ -477,70 +801,220 @@ def get_var(self, var, snap=None, *args, iix=slice(None), iiy=slice(None), Snapshot number to read. By default reads the loaded snapshot; if a different number is requested, will load that snapshot by running self.set_snap(snap). + + **kwargs go to load_..._quantities functions. """ if self.verbose: - print('(get_var): reading ', var) - - if not hasattr(self, 'iix'): - self.set_domain_iiaxis(iinum=iix, iiaxis='x') - self.set_domain_iiaxis(iinum=iiy, iiaxis='y') - self.set_domain_iiaxis(iinum=iiz, iiaxis='z') - else: - if (iix != slice(None)) and np.any(iix != self.iix): - if self.verbose: - print('(get_var): iix ', iix, self.iix) - self.set_domain_iiaxis(iinum=iix, iiaxis='x') - if (iiy != slice(None)) and np.any(iiy != self.iiy): - if self.verbose: - print('(get_var): iiy ', iiy, self.iiy) - self.set_domain_iiaxis(iinum=iiy, iiaxis='y') - if (iiz != slice(None)) and np.any(iiz != self.iiz): - if self.verbose: - print('(get_var): iiz ', iiz, self.iiz) - self.set_domain_iiaxis(iinum=iiz, iiaxis='z') + print('(get_var): reading ', var, whsp*6, end="\r", flush=True) if var in ['x', 'y', 'z']: return getattr(self, var) if (snap is not None) and np.any(snap != self.snap): if self.verbose: - print('(get_var): setsnap ', snap, self.snap) + print('(get_var): setsnap ', snap, self.snap, whsp*6, + end="\r", flush=True) self.set_snap(snap) + self.variables = {} + + # set iix, iiy, iiz appropriately + slices_names_and_vals = (('iix', iix), ('iiy', iiy), ('iiz', iiz)) + original_slice = [iix if iix is not None else getattr(self, slicename, slice(None)) + for slicename, iix in slices_names_and_vals] + self.set_domain_iiaxes(iix=iix, iiy=iiy, iiz=iiz, internal=True) + + if var in self.varn.keys(): + var = self.varn[var] + + if (self.sel_units == 'cgs'): + varu = var.replace('x', '') + varu = varu.replace('y', '') + varu = varu.replace('z', '') + if varu == 'r': + varu = 'rho' + if (varu in self.uni.uni.keys()): + cgsunits = self.uni.uni[varu] + else: + cgsunits = 1.0 - if var in self.simple_vars: # is variable already loaded? - val = self._get_simple_var(var, *args, **kwargs) - if self.verbose: - print('(get_var): reading simple ', np.shape(val)) - elif var in self.auxxyvars: - val = self._get_simple_var_xy(var, *args, **kwargs) - elif var in self.compvars: # add to variable list - self.variables[var] = self._get_composite_var(var, *args, **kwargs) - setattr(self, var, self.variables[var]) - val = self.variables[var] else: - val = self.get_quantity(var, *args, **kwargs) - - if np.shape(val) != (self.xLength, self.yLength, self.zLength): - # at least one slice has more than one value - if np.size(self.iix) + np.size(self.iiy) + np.size(self.iiz) > 3: - # x axis may be squeezed out, axes for take() - axes = [0, -2, -1] - - for counter, dim in enumerate(['iix', 'iiy', 'iiz']): - if (np.size(getattr(self, dim)) > 1 or - getattr(self, dim) != slice(None)): - # slicing each dimension in turn - val = val.take(getattr(self, dim), axis=axes[counter]) - else: - # all of the slices are only one int or slice(None) - val = val[self.iix, self.iiy, self.iiz] + cgsunits = 1.0 - # ensuring that dimensions of size 1 are retained - val = np.reshape(val, (self.xLength, self.yLength, self.zLength)) + # get value of variable. + val = self._load_quantity(var, cgsunits=cgsunits, **kwargs) + # do post-processing + val = self._get_var_postprocess(val, var=var, original_slice=original_slice, printing_stats=printing_stats) return val - def _get_simple_var(self, var, order='F', mode='r', *args, **kwargs): + def _get_var_postprocess(self, val, var='', printing_stats=None, original_slice=[slice(None) for x in ('x', 'y', 'z')]): + '''does post-processing for get_var. + This includes: + - handle "creating documentation" or "var==''" case + - handle "don't know how to get this var" case + - reshape result as appropriate (based on iix,iiy,iiz) + - take mean if self.internal_means (disabled by default). + - squeeze if self.squeeze_output (disabled by default). + - convert units as appropriate (based on self.units_output.) + - default is to keep result in simulation units, doing no conversions. + - if converting, note that any caching would happen in _load_quantity, + outside this function. The cache will always be in simulation units. + - print stats if printing_stats or ((printing_stats is None) and self.printing_stats). + returns val after the processing is complete. + ''' + # handle documentation case + if document_vars.creating_vardict(self): + return None + elif var == '': + print('Variables from snap or aux files:') + print(self.simple_vars) + print('Variables from xy aux files:') + print(self.auxxyvars) + if hasattr(self, 'vardict'): + self.vardocs() + return None + + # handle "don't know how to get this var" case + if val is None: + errmsg = ('get_var: do not know (yet) how to calculate quantity {}. ' + '(Got None while trying to calculate it.) ' + 'Note that simple_var available variables are: {}. ' + '\nIn addition, get_quantity can read others computed variables; ' + "see e.g. help(self.get_var) or get_var('')) for guidance.") + raise ValueError(errmsg.format(repr(var), repr(self.simple_vars))) + + # set original_slice if do_stagger and we are at the outermost layer. + if self.do_stagger and not self._getting_internal_var(): + self.set_domain_iiaxes(*original_slice, internal=False) + + # reshape if necessary... E.g. if var is a simple var, and iix tells to slice array. + if (np.ndim(val) >= self.ndim) and (np.shape(val) != self.shape): + if all(isinstance(s, slice) for s in (self.iix, self.iiy, self.iiz)): + val = val[self.iix, self.iiy, self.iiz] # we can index all together + else: # we need to index separately due to numpy multidimensional index array rules. + val = val[self.iix, :, :] + val = val[:, self.iiy, :] + val = val[:, :, self.iiz] + + # take mean if self.internal_means (disabled by default) + if self.internal_means: + val = val.mean() + + # handle post-processing steps which we only do for top-level calls to get_var: + if not self._getting_internal_var(): + + # squeeze if self.squeeze_output (disabled by default) + if self.squeeze_output and (np.ndim(val) > 0): + val = val.squeeze() + + # convert units if we are using units_output != 'simu'. + if self.units_output != 'simu': + units_f, units_name = self.get_units(mode=self.units_output, _force_from_simu=True) + self.got_units_name = units_name # << this line is just for reference. Not used internally. + val = val * units_f # can't do *= in case val is a read-only memmap. + + # print stats if self.printing_stats + self.print_stats(val, printing_stats=printing_stats) + + return val + + def _getting_internal_var(self): + '''returns whether we are currently inside of an internal call to _load_quantity. + (_load_quantity is called inside of get_var.) + + Here is an example, with the comments telling self._getting_internal_var() at that line: + # False + get_var('ux') --> + # False + px = get_var('px') --> + # True + returns the value of px + # False + rxdn = get_var('rxdn') --> + # True + r = get_var('r') --> + # True + returns the value of r + # True + returns apply_xdn_to(r) + # False + return px / rxdn + # False + (Of course, this example assumes get_var('ux') was called externally.) + ''' + return getattr(self, document_vars.LOADING_LEVEL) >= 0 + + def trans2comm(self, varname, snap=None, *args, **kwargs): + ''' + Transform the domain into a "common" format. All arrays will be 3D. The 3rd axis + is: + + - for 3D atmospheres: the vertical axis + - for loop type atmospheres: along the loop + - for 1D atmosphere: the unique dimension is the 3rd axis. + At least one extra dimension needs to be created artifically. + + All of them should obey the right hand rule + + In all of them, the vectors (velocity, magnetic field etc) away from the Sun. + + If applies, z=0 near the photosphere. + + Units: everything is in cgs. + + If an array is reverse, do ndarray.copy(), otherwise pytorch will complain. + + ''' + + self.trans2commaxes() + + self.sel_units = 'cgs' + + sign = 1.0 + if varname[-1] in ['x', 'y', 'z']: + varname = varname+'c' + if varname[-2] in ['y', 'z']: + sign = -1.0 + + var = self.get_var(varname, snap=snap, *args, **kwargs) + var = sign * var + + var = var[..., ::-1].copy() + + return var + + def trans2commaxes(self): + if self.transunits == False: + self.transunits = True + if self.sel_units == 'cgs': + cte = 1.0 + else: + cte = self.uni.u_l # not sure if this works, u_l seems to be 1.e8 + self.x = self.x*cte + self.dx = self.dx*cte + self.y = self.y*cte + self.dy = self.dy*cte + self.z = - self.z[::-1].copy()*cte + self.dz = - self.dz1d[::-1].copy()*cte + + def trans2noncommaxes(self): + + if self.transunits == True: + self.transunits = False + if self.sel_units == 'cgs': + cte = 1.0 + else: + cte = self.uni.u_l + self.x = self.x/cte + self.dx = self.dx/cte + self.y = self.y/cte + self.dy = self.dy/cte + self.z = - self.z[::-1].copy()/cte + self.dz = - self.dz1d[::-1].copy()/cte + + @document_vars.quant_tracking_simple('SIMPLE_VARS') + def _get_simple_var(self, var, order='F', mode='r', + panic=False, *args, **kwargs): """ Gets "simple" variable (ie, only memmap, not load into memory). Parameters @@ -557,7 +1031,21 @@ def _get_simple_var(self, var, order='F', mode='r', *args, **kwargs): result - numpy.memmap array Requested variable. """ - if np.size(self.snap) > 1: + if var == '': + _simple_vars_msg = ('Quantities which are stored by the simulation. These are ' + 'loaded as numpy memmaps by reading data files directly.') + document_vars.vars_documenter(self, 'SIMPLE_VARS', None, _simple_vars_msg) + # TODO: << add documentation for bifrost simple vars, here. + return None + + if var not in self.simple_vars: + return None + + if self.verbose: + print('(get_var): reading simple ', var, whsp*5, # TODO: show np.shape(val) info somehow? + end="\r", flush=True) + + if np.shape(self.snap) != (): currSnap = self.snap[self.snapInd] currStr = self.snap_str[self.snapInd] else: @@ -565,7 +1053,10 @@ def _get_simple_var(self, var, order='F', mode='r', *args, **kwargs): currStr = self.snap_str if currSnap < 0: filename = self.file_root - fsuffix_b = '.scr' + if panic: + fsuffix_b = '' + else: + fsuffix_b = '.scr' elif currSnap == 0: filename = self.file_root fsuffix_b = '' @@ -574,7 +1065,10 @@ def _get_simple_var(self, var, order='F', mode='r', *args, **kwargs): fsuffix_b = '' if var in self.snapvars: - fsuffix_a = '.snap' + if panic: + fsuffix_a = '.panic' + else: + fsuffix_a = '.snap' idx = (self.snapvars).index(var) filename += fsuffix_a + fsuffix_b elif var in self.auxvars: @@ -583,34 +1077,41 @@ def _get_simple_var(self, var, order='F', mode='r', *args, **kwargs): filename += fsuffix_a + fsuffix_b elif var in self.hionvars: idx = self.hionvars.index(var) - isnap = self.params['isnap'][self.snapInd] - if isnap <= -1: - filename = filename + '.hion.snap.scr' - elif isnap == 0: - filename = filename + '.hion.snap' - elif isnap > 0: - filename = '%s.hion_%03d.snap' % (self.file_root, isnap) - if not os.path.isfile(filename): - filename = '%s_.hion%s.snap' % (self.file_root, isnap) + isnap = self.get_param('isnap', error_prop=True) + if panic: + filename = filename + '.hion.panic' + else: + if isnap <= -1: + filename = filename + '.hion.snap.scr' + elif isnap == 0: + filename = filename + '.hion.snap' + elif isnap > 0: + filename = '%s.hion_%03d.snap' % (self.file_root, isnap) + if not os.path.isfile(filename): + filename = '%s_.hion%s.snap' % (self.file_root, isnap) elif var in self.heliumvars: idx = self.heliumvars.index(var) - isnap = self.params['isnap'][self.snapInd] - if isnap <= -1: - filename = filename + '.helium.snap.scr' - elif isnap == 0: - filename = filename + '.helium.snap' - elif isnap > 0: - filename = '%s.helium_%s.snap' % (self.file_root, isnap) + isnap = self.get_param('isnap', error_prop=True) + if panic: + filename = filename + '.helium.panic' + else: + if isnap <= -1: + filename = filename + '.helium.snap.scr' + elif isnap == 0: + filename = filename + '.helium.snap' + elif isnap > 0: + filename = '%s.helium_%s.snap' % (self.file_root, isnap) else: raise ValueError(('_get_simple_var: could not find variable ' '%s. Available variables:' % (var) + '\n' + repr(self.simple_vars))) dsize = np.dtype(self.dtype).itemsize if self.ghost_analyse: - offset = self.nx * self.ny * self.nzb * idx * dsize - ss = (self.nx, self.ny, self.nzb) + offset = self.nxb * self.nyb * self.nzb * idx * dsize + ss = (self.nxb, self.nyb, self.nzb) else: - offset = (self.nx * self.ny * + offset = ((self.nxb + (self.nxb - self.nx)) * + (self.nyb + (self.nyb - self.ny)) * (self.nzb + (self.nzb - self.nz) // 2) * idx * dsize) ss = (self.nx, self.ny, self.nz) @@ -621,950 +1122,26 @@ def _get_simple_var(self, var, order='F', mode='r', *args, **kwargs): return np.memmap(filename, dtype=self.dtype, order=order, mode=mode, offset=offset, shape=ss) - def _get_simple_var_xy(self, var, order='F', mode='r'): - """ - Reads a given 2D variable from the _XY.aux file - """ - if var in self.auxxyvars: - fsuffix = '_XY.aux' - idx = self.auxxyvars.index(var) - filename = self.file_root + fsuffix - else: + def _get_simple_var_xy(self, *args, **kwargs): + '''returns load_fromfile_quantities._get_simple_var_xy(self, *args, **kwargs). + raises ValueError if result is None (to match historical behavior of this function). + + included for backwards compatibility purposes, only. + new code should instead use the function from load_fromfile_quantitites. + ''' + val = load_fromfile_quantities._get_simple_var_xy(self, *args, **kwargs) + if val is None: raise ValueError(('_get_simple_var_xy: variable' ' %s not available. Available vars:' % (var) + '\n' + repr(self.auxxyvars))) - # Now memmap the variable - if not os.path.isfile(filename): - raise IOError(('_get_simple_var_xy: variable' - ' %s should be in %s file, not found!' % - (var, filename))) - # size of the data type - dsize = np.dtype(self.dtype).itemsize - offset = self.nx * self.ny * idx * dsize - return np.memmap(filename, dtype=self.dtype, order=order, mode=mode, - offset=offset, shape=(self.nx, self.ny)) - - def _get_composite_var(self, var, *args, **kwargs): - """ - Gets composite variables (will load into memory). - """ - if var in ['ux', 'uy', 'uz']: # velocities - p = self.get_var('p' + var[1], order='F') - if getattr(self, 'n' + var[1]) < 5: - return p / self.get_var('r') # do not recentre for 2D cases - else: # will call xdn, ydn, or zdn to get r at cell faces - return p / stagger.do(self.get_var('r'), var[1] + 'dn') - elif var == 'ee': # internal energy - return self.get_var('e') / self.get_var('r') - elif var == 's': # entropy? - return np.log(self.get_var('p', *args, **kwargs)) - \ - self.params['gamma'][self.snapInd] * np.log( - self.get_var('r', *args, **kwargs)) - - def get_quantity(self, quant, *args, **kwargs): - """ - Calculates a quantity from the simulation quantiables. - - Parameters - ---------- - quant - string - Name of the quantity to calculate (see below for some categories). - - Returns - ------- - array - ndarray - Array with the dimensions of the simulation. - - Notes - ----- - All possible rules for quantity names are described in the dictionary - "self.description", e.g.: - - >>> dd.get_quantity('') - >>> dd.description.keys() - >>> dd.description['DERIV'] - """ - quant = quant.lower() - self.description = {} - DERIV_QUANT = ['dxup', 'dyup', 'dzup', 'dxdn', 'dydn', 'dzdn'] - self.description['DERIV'] = ('Spatial derivative (Bifrost units). ' - 'It must start with d and end with: ' - ', '.join(DERIV_QUANT)) - - CENTRE_QUANT = ['xc', 'yc', 'zc'] - self.description['CENTRE'] = ('Allows to center any vector(Bifrost' - ' units). It must end with ' + - ', '.join(CENTRE_QUANT)) - - MODULE_QUANT = ['mod', 'h'] # This one must be called the last - self.description['MODULE'] = ('Module (starting with mod) or horizontal ' - '(ending with h) component of vectors (Bifrost units)') - - HORVAR_QUANT = ['horvar'] - self.description['HORVAR'] = ('Horizontal average (Bifrost units).' - ' Starting with: ' + ', '.join(HORVAR_QUANT)) - - GRADVECT_QUANT = ['div', 'rot', 'she', 'chkdiv', 'chbdiv', 'chhdiv'] - self.description['GRADVECT'] = ('vectorial derivative opeartions ' - '(Bifrost units). ' - 'The following show divergence, rotational, shear, ratio of the ' - 'divergence with the maximum of the abs of each spatial derivative, ' - 'with the sum of the absolute of each spatial derivative, with ' - 'horizontal averages of the absolute of each spatial derivative ' - 'respectively when starting with: ' + ', '.join(GRADVECT_QUANT)) - - GRADSCAL_QUANT = ['gra'] - self.description['GRADSCAL'] = ('Gradient of a scalar (Bifrost units)' - ' starts with: ' + ', '.join(GRADSCAL_QUANT)) - - SQUARE_QUANT = ['2'] # This one must be called the towards the last - self.description['SQUARE'] = ('Square of a variable (Bifrost units)' - ' ends with: ' + ', '.join(SQUARE_QUANT)) - - RATIO_QUANT = 'rat' - self.description['RATIO'] = ('Ratio of two variables (Bifrost units)' - 'have in between: ' + ', '.join(RATIO_QUANT)) - - EOSTAB_QUANT = ['ne', 'tg', 'pg', 'kr', 'eps', 'opa', 'temt', 'ent'] - self.description['EOSTAB'] = ('Variables from EOS table. All of them ' - 'are in cgs except ne which is in SI. The electron density ' - '[m^-3], temperature [K], pressure [dyn/cm^2], Rosseland opacity ' - '[cm^2/g], scattering probability, opacity, thermal emission and ' - 'entropy are as follows: ' + ', '.join(EOSTAB_QUANT)) - - TAU_QUANT = 'tau' - self.description['TAU'] = ('tau at 500 is: ' + ', '.join(TAU_QUANT)) - - PROJ_QUANT = ['par', 'per'] - self.description['PROJ'] = ('Projected vectors (Bifrost units).' - ' Parallel and perpendicular have in the middle the following: ' + - ', '.join(PROJ_QUANT)) - - CURRENT_QUANT = ['ix', 'iy', 'iz', 'wx', 'wy', 'wz'] - self.description['CURRENT'] = ('Calculates currents (bifrost units) or' - 'rotational components of the velocity as follows ' + - ', '.join(CURRENT_QUANT)) - - FLUX_QUANT = ['pfx', 'pfy', 'pfz', 'pfex', 'pfey', 'pfez', 'pfwx', - 'pfwy', 'pfwz'] - self.description['FLUX'] = ('Poynting flux, Flux emergence, and' - 'Poynting flux from "horizontal" motions: ' + - ', '.join(FLUX_QUANT)) - - PLASMA_QUANT = ['beta', 'va', 'cs', 's', 'ke', 'mn', 'man', 'hp', - 'vax', 'vay', 'vaz', 'hx', 'hy', 'hz', 'kx', 'ky', - 'kz'] - self.description['PLASMA'] = ('Plasma beta, alfven velocity (and its' - 'components), sound speed, entropy, kinetic energy flux' - '(and its components), magnetic and sonic Mach number' - 'pressure scale height, and each component of the total energy' - 'flux (if applicable, Bifrost units): ' + - ', '.join(PLASMA_QUANT)) - - WAVE_QUANT = ['alf', 'fast', 'long'] - self.description['WAVE'] = ('Alfven, fast and longitudinal wave' - 'components (Bifrost units): ' + ', '.join(WAVE_QUANT)) - - CYCL_RES = ['n6nhe2', 'n6nhe3', 'nhe2nhe3'] - self.description['CYCL_RES'] = ('Resonant cyclotron frequencies' - '(only for do_helium) are (SI units): ' + ', '.join(CYCL_RES)) - - elemlist = ['h', 'he', 'c', 'o', 'ne', 'na', 'mg', 'al', 'si', 's', - 'k', 'ca', 'cr', 'fe', 'ni'] - GYROF_QUANT = ['gfe'] + ['gf' + clist for clist in elemlist] - self.description['GYROF'] = ('gyro freqency are (Hz): ' + - ', '.join(GYROF_QUANT)) - - DEBYE_LN_QUANT = ['debye_ln'] - self.description['DEBYE'] = ('Debye length in ... units:', - ', '.join(DEBYE_LN_QUANT)) - - COULOMB_COL_QUANT = ['coucol' + clist for clist in elemlist] - self.description['COULOMB_COL'] = ('Coulomb collision frequency in Hz' - 'units: ' + ', '.join(COULOMB_COL_QUANT)) - - CROSTAB_QUANT = ['h_' + clist for clist in elemlist] - for iel in elemlist: - CROSTAB_QUANT = CROSTAB_QUANT + [ - iel + '_' + clist for clist in elemlist] - self.description['CROSTAB'] = ('Cross section between species' - '(in cgs): ' + ', '.join(CROSTAB_QUANT)) - - COLFRE_QUANT = ['nu' + clist for clist in CROSTAB_QUANT] - self.description['COLFRE'] = ('Collision frequency (elastic and charge' - 'exchange) between different species in (cgs): ' + - ', '.join(COLFRE_QUANT)) - - COLFRI_QUANT = ['nu_ni', 'nu_en', 'nu_ei'] - COLFRI_QUANT = COLFRI_QUANT + \ - ['nu' + clist + '_i' for clist in elemlist] - COLFRI_QUANT = COLFRI_QUANT + \ - ['nu' + clist + '_n' for clist in elemlist] - self.description['COLFRI'] = ('Collision frequency (elastic and charge' - 'exchange) between fluids in (cgs): ' + ', '.join(COLFRI_QUANT)) - - IONP_QUANT = ['n' + clist + '-' for clist in elemlist] - IONP_QUANT = IONP_QUANT + ['r' + clist + '-' for clist in elemlist] - self.description['IONP'] = ('densities for specific ionized species as' - 'follow (in SI): ' + ', '.join(IONP_QUANT)) - - if quant == '': - help(self.get_quantity) - return -1 - - if np.size(self.snap) > 1: - currSnap = self.snap[self.snapInd] - else: - currSnap = self.snap - - if RATIO_QUANT in quant: - # Calculate module of vector quantity - q = quant[:quant.find(RATIO_QUANT)] - if q[0] == 'b': - if not self.do_mhd: - raise ValueError("No magnetic field available.") - result = self.get_var(q) - q = quant[quant.find(RATIO_QUANT) + 3:] - if q[0] == 'b': - if not self.do_mhd: - raise ValueError("No magnetic field available.") - return result / (self.get_var(q) + 1e-19) - - elif quant[0] == 'd' and quant[-4:] in DERIV_QUANT: - # Calculate derivative of quantity - axis = quant[-3] - q = quant[1:-4] # base variable - var = self.get_var(q) - - def deriv_loop(var, quant): - return stagger.do(var, 'd' + quant[0]) - - if getattr(self, 'n' + axis) < 5: # 2D or close - print('(WWW) get_quantity: DERIV_QUANT: ' - 'n%s < 5, derivative set to 0.0' % axis) - return np.zeros_like(var) - else: - if self.numThreads > 1: - if self.verbose: - print('Threading') - quantlist = [quant[-4:] for numb in range(self.numThreads)] - if axis != 'z': - return threadQuantity_z( - deriv_loop, self.numThreads, var, quantlist) - else: - return threadQuantity_y( - deriv_loop, self.numThreads, var, quantlist) - else: - if self.lowbus: - output = np.zeros_like(var) - if axis != 'z': - for iiz in range(self.nz): - output[:, :, iiz] = np.reshape( - stagger.do(var[:, :, iiz].reshape((self.nx, self.ny, 1)), - 'd' + quant[-4:]), - (self.nx, self.ny)) - else: - for iiy in range(self.ny): - output[:, iiy, :] = np.reshape( - stagger.do(var[:, iiy, :].reshape((self.nx, 1, self.nz)), - 'd' + quant[-4:]), - (self.nx, self.nz)) - - return output - else: - return stagger.do(var, 'd' + quant[-4:], getattr(self,'d' + axis +'i' + quant[-4:])) - - elif quant[-2:] in CENTRE_QUANT: - # This brings a given vector quantity to cell centres - axis = quant[-2] - q = quant[:-1] # base variable - if q[:-1] == 'i' or q == 'e': - AXIS_TRANSFORM = {'x': ['yup', 'zup'], - 'y': ['xup', 'zup'], - 'z': ['xup', 'yup']} - else: - AXIS_TRANSFORM = {'x': ['xup'], - 'y': ['yup'], - 'z': ['zup']} - transf = AXIS_TRANSFORM[axis] - - var = self.get_var(q, **kwargs) - - # 2D - if getattr(self, 'n' + axis) < 5: - return var - else: - if len(transf) == 2: - if self.lowbus: - output = np.zeros_like(var) - if transf[0][0] != 'z': - for iiz in range(self.nz): - output[:, :, iiz] = np.reshape(stagger.do( - var[:, :, iiz].reshape( - (self.nx, self.ny, 1)), - transf[0]), (self.nx, self.ny)) - else: - for iiy in range(self.ny): - output[:, iiy, :] = np.reshape(stagger.do( - var[:, iiy, :].reshape( - (self.nx, 1, self.nz)), - transf[0]), (self.nx, self.nz)) - - if transf[1][0] != 'z': - for iiz in range(self.nz): - output[:, :, iiz] = np.reshape(stagger.do( - output[:, :, iiz].reshape( - (self.nx, self.ny, 1)), - transf[1]), (self.nx, self.ny)) - else: - for iiy in range(self.ny): - output[:, iiy, :] = np.reshape(stagger.do( - output[:, iiy, :].reshape( - (self.nx, 1, self.nz)), - transf[1]), (self.nx, self.nz)) - return output - else: - tmp = stagger.do(var, transf[0]) - return stagger.do(tmp, transf[1]) - else: - if self.lowbus: - output = np.zeros_like(var) - if axis != 'z': - for iiz in range(self.nz): - output[:, :, iiz] = np.reshape(stagger.do( - var[:, :, iiz].reshape( - (self.nx, self.ny, 1)), - transf[0]), (self.nx, self.ny)) - else: - for iiy in range(self.ny): - output[:, iiy, :] = np.reshape(stagger.do( - var[:, iiy, :].reshape( - (self.nx, 1, self.nz)), - transf[0]), (self.nx, self.nz)) - return output - else: - return stagger.do(var, transf[0]) - - elif quant[:6] in GRADVECT_QUANT or quant[:3] in GRADVECT_QUANT: - if quant[:3] == 'chk': - q = quant[6:] # base variable - if getattr(self, 'nx') < 5: # 2D or close - varx = np.zeros_like(self.r) - else: - varx = self.get_var('d' + q + 'xdxup') - - if getattr(self, 'ny') > 5: - vary = self.get_var('d' + q + 'ydyup') - else: - vary = np.zeros_like(varx) - - if getattr(self, 'nz') > 5: - varz = self.get_var('d' + q + 'zdzup') - else: - varz = np.zeros_like(varx) - return np.abs(varx + vary + varx) / (np.maximum( - np.abs(varx), np.abs(vary), np.abs(varz)) + 1.0e-20) - - elif quant[:3] == 'chb': - q = quant[6:] # base variable - varx = self.get_var(q + 'x') - vary = self.get_var(q + 'y') - varz = self.get_var(q + 'z') - if getattr(self, 'nx') < 5: # 2D or close - result = np.zeros_like(varx) - else: - result = self.get_var('d' + q + 'xdxup') - - if getattr(self, 'ny') > 5: - result += self.get_var('d' + q + 'ydyup') - - if getattr(self, 'nz') > 5: - result += self.get_var('d' + q + 'zdzup') - - return np.abs(result / (np.sqrt( - varx * varx + vary * vary + varz * varz) + 1.0e-20)) - - elif quant[:3] == 'chh': - q = quant[6:] # base variable - varx = self.get_var(q + 'x') - vary = self.get_var(q + 'y') - varz = self.get_var(q + 'z') - if getattr(self, 'nx') < 5: # 2D or close - result = np.zeros_like(varx) - else: - result = self.get_var('d' + q + 'xdxup') - - if getattr(self, 'ny') > 5: - result += self.get_var('d' + q + 'ydyup') - - if getattr(self, 'nz') > 5: - result += self.get_var('d' + q + 'zdzup') - - for iiz in range(0, self.nz): - result[:, :, iiz] = np.abs(result[:, :, iiz]) / np.mean(( - np.sqrt(varx[:, :, iiz]**2 + vary[:, :, iiz]**2 + - varz[:, :, iiz]**2))) - return result - - elif quant[:3] == 'div': # divergence of vector quantity - q = quant[3:] # base variable - if getattr(self, 'nx') < 5: # 2D or close - result = np.zeros_like(self.r) - else: - result = self.get_var('d' + q + 'xdxup') - if getattr(self, 'ny') > 5: - result += self.get_var('d' + q + 'ydyup') - if getattr(self, 'nz') > 5: - result += self.get_var('d' + q + 'zdzup') - - elif quant[:3] == 'rot' or quant[:3] == 'she': - q = quant[3:-1] # base variable - qaxis = quant[-1] - if qaxis == 'x': - if getattr(self, 'ny') < 5: # 2D or close - result = np.zeros_like(self.r) - else: - result = self.get_var('d' + q + 'zdyup') - if getattr(self, 'nz') > 5: - if quant[:3] == 'rot': - result -= self.get_var('d' + q + 'ydzup') - else: # shear - result += self.get_var('d' + q + 'ydzup') - elif qaxis == 'y': - if getattr(self, 'nz') < 5: # 2D or close - result = np.zeros_like(self.r) - else: - result = self.get_var('d' + q + 'xdzup') - if getattr(self, 'nx') > 5: - if quant[:3] == 'rot': - result -= self.get_var('d' + q + 'zdxup') - else: # shear - result += self.get_var('d' + q + 'zdxup') - elif qaxis == 'z': - if getattr(self, 'nx') < 5: # 2D or close - result = np.zeros_like(self.r) - else: - result = self.get_var('d' + q + 'ydxup') - if getattr(self, 'ny') > 5: - if quant[:3] == 'rot': - result -= self.get_var('d' + q + 'xdyup') - else: # shear - result += self.get_var('d' + q + 'xdyup') - return result - - elif quant[:3] in GRADSCAL_QUANT: - if quant[:3] == 'gra': - q = quant[3:] # base variable - if getattr(self, 'nx') < 5: # 2D or close - result = np.zeros_like(self.r) - else: - result = self.get_var('d' + q + 'dxup') - if getattr(self, 'ny') > 5: - result += self.get_var('d' + q + 'dyup') - if getattr(self, 'nz') > 5: - result += self.get_var('d' + q + 'dzup') - return result - - elif quant[:6] in HORVAR_QUANT: - # Compares the variable with the horizontal mean - if quant[:6] == 'horvar': - result = np.zeros_like(self.r) - result += self.get_var(quant[6:]) # base variable - horv = np.mean(np.mean(result, 0), 0) - for iix in range(0, getattr(self, 'nx')): - for iiy in range(0, getattr(self, 'ny')): - result[iix, iiy, :] = result[iix, iiy, :] / horv[:] - return result - - elif quant in EOSTAB_QUANT: - # unit conversion to SI - # to g/cm^3 - ur = self.params['u_r'][self.snapInd] - ue = self.params['u_ee'][self.snapInd] # to erg/g - if self.hion and quant == 'ne': - return self.get_var('hionne') - rho = self.get_var('r') - rho = rho * ur - ee = self.get_var('ee') - ee = ee * ue - if self.verbose: - print(quant + ' interpolation...') - - fac = 1.0 - # JMS Why SI?? SI seems to work with bifrost_uvotrt. - if quant == 'ne': - fac = 1.e6 # cm^-3 to m^-3 - if quant in ['eps', 'opa', 'temt']: - radtab = True - else: - radtab = False - eostab = Rhoeetab(fdir=self.fdir, radtab=radtab) - return eostab.tab_interp( - rho, ee, order=1, out=quant) * fac - - elif quant[1:4] in PROJ_QUANT: - # projects v1 onto v2 - v1 = quant[0] - v2 = quant[4] - x_a = self.get_var(v1 + 'xc', self.snap) - y_a = self.get_var(v1 + 'yc', self.snap) - z_a = self.get_var(v1 + 'zc', self.snap) - x_b = self.get_var(v2 + 'xc', self.snap) - y_b = self.get_var(v2 + 'yc', self.snap) - z_b = self.get_var(v2 + 'zc', self.snap) - # can be used for threadQuantity() or as is - - def proj_task(x1, y1, z1, x2, y2, z2): - v2Mag = np.sqrt(x2 ** 2 + y2 ** 2 + z2 ** 2) - v2x, v2y, v2z = x2 / v2Mag, y2 / v2Mag, z2 / v2Mag - parScal = x1 * v2x + y1 * v2y + z1 * v2z - parX, parY, parZ = parScal * v2x, parScal * v2y, parScal * v2z - result = np.abs(parScal) - if quant[1:4] == 'per': - perX = x1 - parX - perY = y1 - parY - perZ = z1 - parZ - v1Mag = np.sqrt(perX**2 + perY**2 + perZ**2) - result = v1Mag - return result - - if self.numThreads > 1: - if self.verbose: - print('Threading') - - return threadQuantity(proj_task, self.numThreads, - x_a, y_a, z_a, x_b, y_b, z_b) - else: - return proj_task(x_a, y_a, z_a, x_b, y_b, z_b) - - elif quant in CURRENT_QUANT: - # Calculate derivative of quantity - axis = quant[-1] - if quant[0] == 'i': - q = 'b' - else: - q = 'u' - try: - var = getattr(self, quant) - except AttributeError: - if axis == 'x': - varsn = ['z', 'y'] - derv = ['dydn', 'dzdn'] - elif axis == 'y': - varsn = ['x', 'z'] - derv = ['dzdn', 'dxdn'] - elif axis == 'z': - varsn = ['y', 'x'] - derv = ['dxdn', 'dydn'] - - # 2D or close - if (getattr(self, 'n' + varsn[0]) < 5) or (getattr(self, 'n' + varsn[1]) < 5): - return np.zeros_like(self.r) - else: - return (self.get_var('d' + q + varsn[0] + derv[0]) - - self.get_var('d' + q + varsn[1] + derv[1])) - - elif quant in FLUX_QUANT: - axis = quant[-1] - if axis == 'x': - varsn = ['z', 'y'] - elif axis == 'y': - varsn = ['x', 'z'] - elif axis == 'z': - varsn = ['y', 'x'] - if 'pfw' in quant or len(quant) == 3: - var = - self.get_var('b' + axis + 'c') * ( - self.get_var('u' + varsn[0] + 'c') * - self.get_var('b' + varsn[0] + 'c') + - self.get_var('u' + varsn[1] + 'c') * - self.get_var('b' + varsn[1] + 'c')) - else: - var = np.zeros_like(self.r) - if 'pfe' in quant or len(quant) == 3: - var += self.get_var('u' + axis + 'c') * ( - self.get_var('b' + varsn[0] + 'c')**2 + - self.get_var('b' + varsn[1] + 'c')**2) - return var - - elif quant in PLASMA_QUANT: - if quant in ['hp', 's', 'cs', 'beta']: - var = self.get_var('p') - if quant == 'hp': - if getattr(self, 'nx') < 5: - return np.zeros_like(var) - else: - return 1. / (stagger.do(var, 'ddzup', self.dzidzup) + 1e-12) - elif quant == 'cs': - return np.sqrt(self.params['gamma'][self.snapInd] * - var / self.get_var('r')) - elif quant == 's': - return (np.log(var) - self.params['gamma'][self.snapInd] * - np.log(self.get_var('r'))) - elif quant == 'beta': - return 2 * var / self.get_var('b2') - - if quant in ['mn', 'man']: - var = self.get_var('modu') - if quant == 'mn': - return var / (self.get_var('cs') + 1e-12) - else: - return var / (self.get_var('va') + 1e-12) - - if quant in ['va', 'vax', 'vay', 'vaz']: - var = self.get_var('r') - if len(quant) == 2: - return self.get_var('modb') / np.sqrt(var) - else: - axis = quant[-1] - return np.sqrt(self.get_var('b' + axis + 'c') ** 2 / var) - - if quant in ['hx', 'hy', 'hz', 'kx', 'ky', 'kz']: - axis = quant[-1] - var = self.get_var('p' + axis + 'c') - if quant[0] == 'h': - return ((self.get_var('e') + self.get_var('p')) / - self.get_var('r') * var) - else: - return self.get_var('u2') * var * 0.5 - - if quant in ['ke']: - var = self.get_var('r') - return self.get_var('u2') * var * 0.5 - - elif quant == 'tau': - - return self.calc_tau() - - elif quant in WAVE_QUANT: - bx = self.get_var('bxc') - by = self.get_var('byc') - bz = self.get_var('bzc') - bMag = np.sqrt(bx**2 + by**2 + bz**2) - bx, by, bz = bx / bMag, by / bMag, bz / bMag # b is already centered - # unit vector of b - unitB = np.stack((bx, by, bz)) - - if quant == 'alf': - uperb = self.get_var('uperb') - uperbVect = uperb * unitB - # cross product (uses stagger bc no variable gets uperbVect) - curlX = ( stagger.do(stagger.do(uperbVect[2], 'ddydn', self.dyidydn), 'yup') - -stagger.do(stagger.do(uperbVect[1], 'ddzdn', self.dzidzdn), 'zup')) - curlY = (-stagger.do(stagger.do(uperbVect[2], 'ddxdn', self.dxidxdn), 'xup') - +stagger.do(stagger.do(uperbVect[0], 'ddzdn', self.dzidzdn), 'zup')) - curlZ = ( stagger.do(stagger.do(uperbVect[1], 'ddxdn', self.dxidxdn), 'xup') - -stagger.do(stagger.do(uperbVect[0], 'ddydn', self.dyidydn), 'yup')) - curl = np.stack((curlX, curlY, curlZ)) - # dot product - result = np.abs((unitB * curl).sum(0)) - elif quant == 'fast': - uperb = self.get_var('uperb') - uperbVect = uperb * unitB - - result = np.abs(stagger.do(stagger.do( - uperbVect[0], 'ddxdn', self.dxidxdn), 'xup') + stagger.do(stagger.do( - uperbVect[1], 'ddydn', self.dyidydn), 'yup') + stagger.do(stagger.do( - uperbVect[2], 'ddzdn', self.dzidzdn), 'zup')) - else: - dot1 = self.get_var('uparb') - grad = np.stack((stagger.do( - stagger.do(dot1, 'ddxdn', self.dxidxdn),'xup'), - stagger.do( - stagger.do(dot1, 'ddydn', self.dyidydn),'yup'), - stagger.do( - stagger.do(dot1, 'ddzdn', self.dzidzdn),'zup'))) - result = np.abs((unitB * grad).sum(0)) - return result - - elif quant in CYCL_RES: - if self.hion and self.heion: - posn = ([pos for pos, char in enumerate(quant) if char == 'n']) - q2 = quant[posn[-1]:] - q1 = quant[:posn[-1]] - if self.hion: - nel = self.get_var('hionne') - else: - nel = self.get_var('nel') - var2 = self.get_var(q2) - var1 = self.get_var(q1) - z1 = 1.0 - z2 = float(quant[-1]) - if q1[:3] == 'n6': - omega1 = self.get_var('gfh2') - else: - omega1 = self.get_var('gf'+q1[1:]) - omega2 = self.get_var('gf'+q2[1:]) - return (z1 * var1 * omega2 + z2 * var2 * omega1) / nel - else: - raise ValueError(('get_quantity: This variable is only ' - 'avaiable if do_hion and do_helium is true')) - - elif quant in DEBYE_LN_QUANT: - tg = self.get_var('tg') - part = np.copy(self.get_var('ne')) - # We are assuming a single charge state: - for iele in elemlist: - part += self.get_var('n' + iele + '-2') - if self.heion: - part += 4.0 * self.get_var('nhe3') - # check units of n - return np.sqrt(self.uni.permsi / self.uni.qsi_electron**2 / - (self.uni.ksi_b * tg.astype('float64') * - part.astype('float64') + 1.0e-20)) - - elif ''.join([i for i in quant if not i.isdigit()]) in GYROF_QUANT: - if quant == 'gfe': - return self.get_var('modb') * self.uni.usi_b * \ - self.uni.qsi_electron / (self.uni.msi_e) - else: - ion = float(''.join([i for i in quant if i.isdigit()])) - return self.get_var('modb') * self.uni.usi_b * \ - self.uni.qsi_electron * \ - (ion - 1.0) / \ - (self.uni.weightdic[quant[2:-1]] * self.uni.amusi) - - elif quant in COULOMB_COL_QUANT: - iele = np.where(COULOMB_COL_QUANT == quant) - tg = self.get_var('tg') - nel = np.copy(self.get_var('ne')) - elem = quant.replace('coucol', '') - - const = (self.uni.pi * self.uni.qsi_electron ** 4 / - ((4.0 * self.uni.pi * self.uni.permsi)**2 * - np.sqrt(self.uni.weightdic[elem] * self.uni.amusi * - (2.0 * self.uni.ksi_b) ** 3) + 1.0e-20)) - - return (const * nel.astype('Float64') * - np.log(12.0 * self.uni.pi * nel.astype('Float64') * - self.get_var('debye_ln').astype('Float64') + 1e-50) / - (np.sqrt(tg.astype('Float64')**3) + 1.0e-20)) - - elif quant in CROSTAB_QUANT: - tg = self.get_var('tg') - elem = quant.split('_') - spic1 = ''.join([i for i in elem[0] if not i.isdigit()]) - spic2 = ''.join([i for i in elem[1] if not i.isdigit()]) - cross_tab = '' - crossunits = 2.8e-17 - if spic1 == 'h': - if spic2 == 'h': - cross_tab = 'p-h-elast.txt' - elif spic2 == 'he': - cross_tab = 'p-he.txt' - elif spic2 == 'e': - cross_tab = 'e-h.txt' - crossunits = 1e-16 - else: - cross = self.uni.weightdic[spic2] / self.uni.weightdic['h'] * \ - self.uni.cross_p * np.ones(np.shape(tg)) - elif spic1 == 'he': - if spic2 == 'h': - cross_tab = 'p-h-elast.txt' - elif spic2 == 'he': - cross_tab = 'he-he.txt' - crossunits = 1e-16 - elif spic2 == 'e': - cross_tab = 'e-he.txt' - else: - cross = self.uni.weightdic[spic2] / self.uni.weightdic['he'] * \ - self.uni.cross_he * np.ones(np.shape(tg)) - elif spic1 == 'e': - if spic2 == 'h': - cross_tab = 'e-h.txt' - elif spic2 == 'he': - cross_tab = 'e-he.txt' - if cross_tab != '': - crossobj = Cross_sect(cross_tab=[cross_tab]) - cross = crossunits * crossobj.tab_interp(tg) - else: - cross = self.uni.weightdic[spic2] / self.uni.weightdic['h'] * \ - self.uni.cross_p * np.ones(np.shape(tg)) - try: - return cross - except Exception: - print('(WWW) cross-section: wrong combination of species') - - elif ''.join([i for i in quant if not i.isdigit()]) in COLFRE_QUANT: - - elem = quant.split('_') - spic1 = ''.join([i for i in elem[0] if not i.isdigit()]) - ion1 = ''.join([i for i in elem[0] if i.isdigit()]) - spic2 = ''.join([i for i in elem[1] if not i.isdigit()]) - ion2 = ''.join([i for i in elem[1] if i.isdigit()]) - - spic1 = spic1[2:] - crossarr = self.get_var('%s_%s' % (spic1, spic2)) - nspic2 = self.get_var('n%s-%s' % (spic2, ion2)) - - tg = self.get_var('tg') - awg1 = self.uni.weightdic[spic1] * self.uni.amu - awg2 = self.uni.weightdic[spic2] * self.uni.amu - scr1 = np.sqrt(8.0 * self.uni.kboltzmann * tg / self.uni.pi) - - return crossarr * np.sqrt((awg1 + awg2) / (awg1 * awg2)) *\ - scr1 * nspic2 * (awg1 / (awg1 + awg1)) - - elif ''.join([i for i in quant if not i.isdigit()]) in COLFRI_QUANT: - if quant == 'nu_ni': - result = self.uni.m_h * self.get_var('nh-1') * \ - self.get_var('nuh1_i') + \ - self.uni.m_he * \ - self.get_var('nhe-1') * self.get_var('nuhe1_i') - elif quant == 'nu_ei': - if self.hion: - nel = self.get_var('hionne') - else: - nel = self.get_var('nel') - culblog = 23. + 1.5 * np.log(self.get_var('tg') / 1.e6) - \ - 0.5 * np.log(nel / 1e6) - - result = 3.759 * nel / (self.get_var('tg')**(1.5)) * culblog - - elif quant == 'nu_en': - if self.hion: - nel = self.get_var('hionne') - else: - nel = self.get_var('nel') - culblog = 23. + 1.5 * np.log(self.get_var('tg') / 1.e6) - \ - 0.5 * np.log(nel / 1e6) - scr1 = 3.759 * nel / (self.get_var('tg')**(1.5)) * culblog - scr2 = 0.0 * nel - for ielem in elemlist: - scr2 += self.get_var('n%s-%s' % (ielem, 2)) - if self.heion and quant[-2:] == '_i': - scr2 += self.get_var('%s_%s' % (elem[0], 'he3')) - result = 5.2e-11 * scr2 / nel * self.get_var('tg')**2 / \ - culblog * scr1 - - else: - if quant[-2:] == '_i': - lvl = '2' - else: - lvl = '1' - elem = quant.split('_') - result = np.zeros(np.shape(self.r)) - for ielem in elemlist: - if elem[0][2:] != '%s%s' % (ielem, lvl): - result += self.get_var('%s_%s%s' % - (elem[0], ielem, lvl)) - if self.heion and quant[-2:] == '_i': - result += self.get_var('%s_%s' % (elem[0], 'he3')) - return result - - elif ''.join([i for i in quant if not i.isdigit()]) in IONP_QUANT: - elem = quant.split('_') - spic = ''.join([i for i in elem[0] if not i.isdigit()]) - lvl = ''.join([i for i in elem[0] if i.isdigit()]) - if self.hion and spic[1:-1] == 'h': - if quant[0] == 'n': - mass = 1.0 - else: - mass = self.uni.m_h - if lvl == '1': - - return mass * (self.get_var('n1') + - self.get_var('n2') + self.get_var('n3') + - self.get_var('n4') + self.get_var('n5')) - else: - return mass * self.get_var('n6') - elif self.heion and spic[1:-1] == 'he': - if quant[0] == 'n': - mass = 1.0 - else: - mass = self.uni.m_he - if self.verbose: - print('get_var: reading nhe%s' % lvl) - return mass * self.get_var('nhe%s' % lvl) - - else: - tg = self.get_var('tg') - r = self.get_var('r') - nel = self.get_var('ne') / 1e6 # 1e6 conversion from SI to cgs - - if quant[0] == 'n': - dens = False - else: - dens = True - - return ionpopulation(r, nel, tg, elem=spic[1:-1], lvl=lvl, - dens=dens) - - elif ((quant[:3] in MODULE_QUANT) or ( - quant[-1] in MODULE_QUANT) or ( - quant[-1] in SQUARE_QUANT and not(quant in CYCL_RES))): - # Calculate module of vector quantity - if (quant[:3] in MODULE_QUANT): - q = quant[3:] - else: - q = quant[:-1] - if q == 'b': - if not self.do_mhd: - raise ValueError("No magnetic field available.") - result = self.get_var(q + 'xc') ** 2 - result += self.get_var(q + 'yc') ** 2 - if not(quant[-1] in MODULE_QUANT): - result += self.get_var(q + 'zc') ** 2 - - if (quant[:3] in MODULE_QUANT) or (quant[-1] in MODULE_QUANT): - return np.sqrt(result) - elif quant[-1] in SQUARE_QUANT: - return result - else: - raise ValueError(('get_quantity: do not know (yet) how to ' - 'calculate quantity %s. Note that simple_var ' - 'available variables are: %s.\nIn addition, ' - 'get_quantity can read others computed variables ' - 'see e.g. help(self.get_quantity) for guidance' - '.' % (quant, repr(self.simple_vars)))) - - def calc_tau(self): - """ - Calculates optical depth. - - DEPRECATED, DO NOT USE. - """ - warnings.warn("Use of calc_tau is discouraged. It is model-dependent, " - "inefficient and slow, and will give wrong results in " - "many scenarios. DO NOT USE.") - - if not hasattr(self, 'z'): - print('(WWW) get_tau needs the height (z) in Mm (units code)') - - # grph = 2.38049d-24 uni.GRPH - # bk = 1.38e-16 uni.KBOLTZMANN - # EV_TO_ERG=1.60217733E-12 uni.EV_TO_ERG - if not hasattr(self, 'ne'): - nel = self.get_var('ne') - else: - nel = self.ne - - if not hasattr(self, 'tg'): - tg = self.get_var('tg') - else: - tg = self.tg + def _get_composite_var(self, *args, **kwargs): + '''returns load_fromfile_quantities._get_composite_var(self, *args, **kwargs). - if not hasattr(self, 'r'): - rho = self.get_var('r') * self.uni.u_r - else: - rho = self.r * self.uni.u_r - - tau = np.zeros((self.nx, self.ny, self.nz)) + 1.e-16 - xhmbf = np.zeros((self.nz)) - const = (1.03526e-16 / self.uni.grph) * 2.9256e-17 / 1e6 - for iix in range(self.nx): - for iiy in range(self.ny): - for iiz in range(self.nz): - xhmbf[iiz] = const * nel[iix, iiy, iiz] / \ - tg[iix, iiy, iiz]**1.5 * np.exp(0.754e0 * - self.uni.ev_to_erg / self.uni.kboltzmann / - tg[iix, iiy, iiz]) * rho[iix, iiy, iiz] - - for iiz in range(1, self.nz): - tau[iix, iiy, iiz] = tau[iix, iiy, iiz - 1] + 0.5 *\ - (xhmbf[iiz] + xhmbf[iiz - 1]) *\ - np.abs(self.dz1d[iiz]) * 1.0e8 - return tau + included for backwards compatibility purposes, only. + new code should instead use the function from load_fromfile_quantitites. + ''' + return load_fromfile_quantities._get_composite_var(self, *args, **kwargs) def get_electron_density(self, sx=slice(None), sy=slice(None), sz=slice(None)): """ @@ -1609,21 +1186,20 @@ def get_hydrogen_pops(self, sx=slice(None), sy=slice(None), sz=slice(None)): if slice_size == 0: slice_size = n shape.append(slice_size) - nh = np.empty(shape, dtype='float32') + nh = np.empty(shape, dtype='Float32') for k in range(6): nv = self.get_var('n%i' % (k + 1)) nh[k] = nv[sx, sy, sz] else: rho = self.r[sx, sy, sz] * self.uni.u_r subsfile = os.path.join(self.fdir, 'subs.dat') - tabfile = os.path.join( - self.fdir, self.params['tabinputfile'][self.snapInd].strip()) + tabfile = os.path.join(self.fdir, self.get_param('tabinputfile', error_prop=True).strip()) tabparams = [] if os.access(tabfile, os.R_OK): - tabparams = read_idl_ascii(tabfile) + tabparams = read_idl_ascii(tabfile, obj=self) if 'abund' in tabparams and 'aweight' in tabparams: - abund = tabparams['abund'].astype('f') - aweight = tabparams['aweight'].astype('f') + abund = np.array(tabparams['abund'].split()).astype('f') + aweight = np.array(tabparams['aweight'].split()).astype('f') grph = calc_grph(abund, aweight) elif os.access(subsfile, os.R_OK): grph = subs2grph(subsfile) @@ -1669,7 +1245,7 @@ def write_rh15d(self, outfile, desc=None, append=True, sx=slice(None), # strongly recoment to use Bifrost_units. ul = self.params['u_l'][self.snapInd] / 1.e2 # to metres # to g/cm^3 (for ne_rt_table) - ur = self.params['u_r'][self.snapInd] + self.params['u_r'][self.snapInd] ut = self.params['u_t'][self.snapInd] # to seconds uv = ul / ut ub = self.params['u_b'][self.snapInd] * 1e-4 # to Tesla @@ -1680,9 +1256,12 @@ def write_rh15d(self, outfile, desc=None, append=True, sx=slice(None), rho = self.r[sx, sy, sz] if self.do_mhd: - Bx = stagger.xup(self.bx)[sx, sy, sz] - By = stagger.yup(self.by)[sx, sy, sz] - Bz = stagger.zup(self.bz)[sx, sy, sz] + Bx = do_stagger(self.bx, 'xup', obj=self)[sx, sy, sz] + By = do_stagger(self.by, 'yup', obj=self)[sx, sy, sz] + Bz = do_stagger(self.bz, 'zup', obj=self)[sx, sy, sz] + # Bx = cstagger.xup(self.bx)[sx, sy, sz] + # By = cstagger.yup(self.by)[sx, sy, sz] + # Bz = cstagger.zup(self.bz)[sx, sy, sz] # Change sign of Bz (because of height scale) and By # (to make right-handed system) Bx = Bx * ub @@ -1691,12 +1270,13 @@ def write_rh15d(self, outfile, desc=None, append=True, sx=slice(None), else: Bx = By = Bz = None - vz = stagger.zup(self.pz)[sx, sy, sz] / rho + vz = do_stagger(self.pz, 'zup', obj=self)[sx, sy, sz] / rho + # vz = cstagger.zup(self.pz)[sx, sy, sz] / rho vz *= -uv if write_all_v: - vx = stagger.xup(self.px)[sx, sy, sz] / rho + vx = cstagger.xup(self.px)[sx, sy, sz] / rho vx *= uv - vy = stagger.yup(self.py)[sx, sy, sz] / rho + vy = cstagger.yup(self.py)[sx, sy, sz] / rho vy *= -uv else: vx = None @@ -1728,8 +1308,7 @@ def write_rh15d(self, outfile, desc=None, append=True, sx=slice(None), pbar.update() def write_multi3d(self, outfile, mesh='mesh.dat', desc=None, - sx=slice(None), sy=slice(None), sz=slice(None), - write_magnetic=False): + sx=slice(None), sy=slice(None), sz=slice(None)): """ Writes snapshot in Multi3D format. Parameters @@ -1744,41 +1323,41 @@ def write_multi3d(self, outfile, mesh='mesh.dat', desc=None, Slice objects for x, y, and z dimensions, when not all points are needed. E.g. use slice(None) for all points, slice(0, 100, 2) for every second point up to 100. - write_magnetic - bool, optional - Whether to write a magnetic field file. Default is False. Returns ------- None. """ from .multi3d import Multi3dAtmos - from .multi3d import Multi3dMagnetic + # unit conversion to cgs and km/s ul = self.params['u_l'][self.snapInd] # to cm ur = self.params['u_r'][self.snapInd] # to g/cm^3 (for ne_rt_table) ut = self.params['u_t'][self.snapInd] # to seconds uv = ul / ut / 1e5 # to km/s - ub = self.params['u_b'][self.snapInd] # to G ue = self.params['u_ee'][self.snapInd] # to erg/g nh = None if self.verbose: - print('Slicing and unit conversion...') + print('Slicing and unit conversion...', whsp*4, end="\r", + flush=True) temp = self.tg[sx, sy, sz] rho = self.r[sx, sy, sz] # Change sign of vz (because of height scale) and vy (to make # right-handed system) - vx = stagger.xup(self.px)[sx, sy, sz] / rho + # vx = cstagger.xup(self.px)[sx, sy, sz] / rho + vx = do_stagger(self.px, 'xup', obj=self)[sx, sy, sz] / rho vx *= uv - vy = stagger.yup(self.py)[sx, sy, sz] / rho + vy = do_stagger(self.py, 'yup', obj=self)[sx, sy, sz] / rho vy *= -uv - vz = stagger.zup(self.pz)[sx, sy, sz] / rho + vz = do_stagger(self.pz, 'zup', obj=self)[sx, sy, sz] / rho vz *= -uv rho = rho * ur # to cgs x = self.x[sx] * ul - y = self.y[sy] * (-ul) + y = self.y[sy] * ul z = self.z[sz] * (-ul) ne = self.get_electron_density(sx, sy, sz).to_value('1/cm3') # write to file - print('Write to file...') + if self.verbose: + print('Write to file...', whsp*8, end="\r", flush=True) nx, ny, nz = temp.shape fout = Multi3dAtmos(outfile, nx, ny, nz, mode="w+", read_nh=self.hion) fout.ne[:] = ne @@ -1800,19 +1379,745 @@ def write_multi3d(self, outfile, mesh='mesh.dat', desc=None, fout2.write("\n%i\n" % nz) z.tofile(fout2, sep=" ", format="%11.5e") fout2.close() - if write_magnetic: - Bx = stagger.xup(self.bx)[sx, sy, sz] - By = stagger.yup(self.by)[sx, sy, sz] - Bz = stagger.zup(self.bz)[sx, sy, sz] - # Change sign of Bz (because of height scale) and By - # (to make right-handed system) - Bx = Bx * ub - By = -By * ub - Bz = -Bz * ub - fout3 = Multi3dMagnetic('magnetic.dat', nx, ny, nz, mode='w+') - fout3.Bx[:] = Bx - fout3.By[:] = By - fout3.Bz[:] = Bz + + ## VALUES OVER TIME, and TIME DERIVATIVES ## + + def get_varTime(self, var, snap=None, iix=None, iiy=None, iiz=None, + print_freq=None, printing_stats=None, + *args__get_var, **kw__get_var): + """ + Reads a given variable as a function of time. + + Parameters + ---------- + var - string + Name of the variable to read. Must be a valid Bifrost variable name, + see Bifrost.get_var(). + snap - array of integers + Snapshot numbers to read. + iix -- integer or array of integers, optional + reads yz slices. + iiy -- integer or array of integers, optional + reads xz slices. + iiz -- integer or array of integers, optional + reads xy slices. + print_freq - number, default 2 + print progress update every print_freq seconds. + Use print_freq < 0 to never print update. + Use print_freq ==0 to print all updates. + printing_stats - None, bool, or dict + whether to print stats of result (via self.print_stats). + None --> use value of self.printing_stats. + False --> don't print stats. (This is the default value for self.printing_stats.) + True --> do print stats. + dict --> do print stats, passing this dictionary as kwargs. + + additional *args and **kwargs are passed to get_var. + """ + # set print_freq + if print_freq is None: + print_freq = getattr(self, 'print_freq', 2) # default 2 + else: + setattr(self, 'print_freq', print_freq) + + # set snap + if snap is None: + snap = kw__get_var.pop('snaps', None) # look for 'snaps' kwarg + if snap is None: + snap = self.snap + snap = np.array(snap, copy=False) + if len(snap.shape) == 0: + raise ValueError('Expected snap to be list (in get_varTime) but got snap={}'.format(snap)) + if not np.array_equal(snap, self.snap): + self.set_snap(snap) + self.variables = {} + + # set iix,iiy,iiz. + self.set_domain_iiaxes(iix=iix, iiy=iiy, iiz=iiz, internal=False) + snapLen = np.size(self.snap) + + # bookkeeping - maintain self.snap; handle self.recoverData; don't print stats in the middle; track timing. + remembersnaps = self.snap # remember self.snap (restore later if crash) + if hasattr(self, 'recoverData'): + delattr(self, 'recoverData') # smash any existing saved data + kw__get_var.update(printing_stats=False) # never print_stats in the middle of get_varTime. + timestart = now = time.time() # track timing, so we can make updates. + printed_update = False + + def _print_clearline(N=100): # clear N chars, and move cursor to start of line. + print('\r' + ' '*N + '\r', end='') # troubleshooting: ensure end='' in other prints. + + try: + firstit = True + for it in range(0, snapLen): + self.snapInd = it + # print update if it is time to print update + if (print_freq >= 0) and (time.time() - now > print_freq): + _print_clearline() + print('Getting {:^10s}; at snap={:2d} (snap_it={:2d} out of {:2d}).'.format( + var, snap[it], it, snapLen), end='') + now = time.time() + print(' Total time elapsed = {:.1f} s'.format(now - timestart), end='') + printed_update = True + + # actually get the values here: + if firstit: + # get value at first snap + val0 = self.get_var(var, snap=snap[it], *args__get_var, **kw__get_var) + # figure out dimensions and initialize the output array. + value = np.empty_like(val0, shape=[*np.shape(val0), snapLen]) + value[..., 0] = val0 + firstit = False + else: + value[..., it] = self.get_var(var, snap=snap[it], + *args__get_var, **kw__get_var) + except: # here it is ok to except all errors, because we always raise. + if it > 0: + self.recoverData = value[..., :it] # save data + if self.verbose: + print(('Crashed during get_varTime, but managed to get data from {} ' + 'snaps before crashing. Data was saved and can be recovered ' + 'via self.recoverData.'.format(it))) + raise + finally: + self.set_snap(remembersnaps) # restore snaps + if printed_update: + _print_clearline() + print('Completed in {:.1f} s'.format(time.time() - timestart), end='\r') + + self.print_stats(value, printing_stats=printing_stats) + return value + + @tools.maintain_attrs('snap') + def ddt(self, var, snap=None, *args__get_var, method='centered', printing_stats=None, **kw__get_var): + '''time derivative of var, at current snapshot. + Units are determined by self.units_output (default: [simulation units]). + + snap: None or value + if provided (not None), first self.set_snap(snap). + method: ('forward', 'backward', 'centered') + tells how to take the time derivative. + forward --> (var[snap+1] - var[snap]) / (t[snap+1] - t[snap]) + backward --> (var[snap] - var[snap-1]) / (t[snap] - t[snap-1]) + centered --> (var[snap+1] - var[snap-1]) / (t[snap+1] - t[snap-1]) + ''' + if snap is not None: + self.set_snap(snap) + method = method.lower() + if method == 'forward': + snaps = [self.get_snap_here(), self.get_snap_next()] + elif method == 'backward': + snaps = [self.get_snap_prev(), self.get_snap_here()] + elif method == 'centered': + snaps = [self.get_snap_prev(), self.get_snap_next()] + else: + raise ValueError(f'Unrecognized method in ddt: {repr(method)}') + kw__get_var.update(printing_stats=False) # never print_stats in the middle of ddt. + self.set_snap(snaps[0]) + value0 = self(var, *args__get_var, **kw__get_var) + time0 = self.get_coord('t')[0] + self.set_snap(snaps[1]) + value1 = self(var, *args__get_var, **kw__get_var) + time1 = self.get_coord('t')[0] + result = (value1 - value0) / (time1 - time0) + self.print_stats(result, printing_stats=printing_stats) # print stats iff self.printing_stats. + return result + + def get_dvarTime(self, var, method='numpy', kw__gradient=dict(), printing_stats=None, **kw__get_varTime): + '''time derivative of var, across time. + Units are determined by self.units_output (default: [simulation units]). + + method: ('numpy', 'simple', 'centered') + tells how to take the time derivative: + numpy --> np.gradient(v, axis=-1) / np.gradient(tt, axis=-1) + result will be shape (..., M), + corresponding to times (tt). + simple --> (v[..., 1:] - v[..., :-1]) / (tt[..., 1:] - tt[..., :-1]) + result will be shape (..., M-1), + corresponding to times (tt[..., 1:] + tt[..., :-1]) / 2. + centered --> (v[..., 2:] - v[..., :-2]) / (tt[..., 2:] - tt[..., :-2]) + result will be shape (..., M-2), + corresponding to times (tt[..., 1:-1]) + where, above, v = self.get_varTime(var); + tt=self.get_coord('t'), with dims expanded (np.expand_dims) appropriately. + kw__gradient: dict + if method=='numpy', kw__gradient are passed to np.gradient. + (do not include 'axis' in kw__gradient.) + additional **kwargs are passed to self.get_varTime. + + returns: array of shape (..., M), + where M=len(self.snap) if method=='numpy', or len(self.snap)-1 if method=='simple'. + ''' + KNOWN_METHODS = ('numpy', 'simple', 'centered') + method = method.lower() + assert method in KNOWN_METHODS, f"Unrecognized method for get_dvarTime: {repr(method)}" + v = self.get_varTime(var, printing_stats=False, **kw__get_varTime) + tt = self.get_coord('t') + tt = np.expand_dims(tt, axis=tuple(range(0, v.ndim - tt.ndim))) # e.g. shape (1,1,1,len(self.snaps)) + method = method.lower() + if method == 'numpy': + result = np.gradient(v, **kw__gradient, axis=-1) / np.gradient(tt, axis=-1) + elif method == 'simple': + result = (v[..., 1:] - v[..., :-1]) / (tt[..., 1:] - tt[..., :-1]) + else: # method == 'centered' + result = (v[..., 2:] - v[..., :-2]) / (tt[..., 2:] - tt[..., :-2]) + self.print_stats(result, printing_stats=printing_stats) + return result + + def get_atime(self): + '''get average time, corresponding to times of derivative from get_dvarTime(..., method='simple').''' + tt = self.get_coord('t') + return (tt[..., 1:] + tt[..., :-1]) / 2 + + ## MISC. CONVENIENCE METHODS ## + def print_stats(self, value, *args, printing_stats=True, **kwargs): + '''print stats of value, via tools.print_stats. + printing_stats: None, bool, or dict. + None --> use value of self.printing_stats. + False --> don't print stats. + True --> do print stats. + dict --> do print stats, passing this dictionary as kwargs. + ''' + if printing_stats is None: + printing_stats = self.printing_stats + if printing_stats: + kw__print_stats = printing_stats if isinstance(printing_stats, dict) else dict() + tools.print_stats(value, **kw__print_stats) + + def get_varm(self, *args__get_var, **kwargs__get_var): + '''get_var but returns np.mean() of result. + provided for convenience for quicker debugging. + ''' + return np.mean(self.get_var(*args__get_var, **kwargs__get_var)) + + def get_varu(self, *args__get_var, mode='si', **kwargs__get_var): + '''get_var() then get_units() and return (result * units factor, units name). + e.g. r = self.get_var('r'); units = self.get_units('si'); return (r*units.factor, units.name). + e.g. self.get_varu('r') --> (r * units.factor, 'kg / m^{3}') + ''' + x = self.get_var(*args__get_var, **kwargs__get_var) + u = self.get_units(mode=mode) + return (x * u.factor, u.name) + + def get_varU(self, *args__get_var, mode='si', **kwargs__get_var): + '''get_varm() then get_units and return (result * units factor, units name). + equivalent to: x=self.get_varu(...); return (np.mean(x[0]), x[1]). + ''' + x = self.get_varm(*args__get_var, **kwargs__get_var) + u = self.get_units(mode=mode) + return (x * u.factor, u.name) + + get_varmu = get_varum = get_varU # aliases for get_varU + + def get_varV(self, var, *args__get_var, mode='si', vmode='modhat', **kwargs__get_var): + '''returns get_varU info but for a vector. + Output format depends on vmode: + 'modhat' ---> ((|var|,units), get_unit_vector(var, mean=True)) + 'modangle' -> ((|var|,units), (angle between +x and var, units of angle)) + 'xyz' ------> ([varx, vary, varz], units of var) + ''' + VALIDMODES = ('modhat', 'modangle', 'xyz') + vmode = vmode.lower() + assert vmode in VALIDMODES, 'vmode {} invalid! Expected vmode in {}.'.format(repr(vmode), VALIDMODES) + if vmode in ('modhat', 'modangle'): + mod = self.get_varU('mod'+var, *args__get_var, mode=mode, **kwargs__get_var) + if vmode == 'modhat': + hat = self.get_unit_vector(var, mean=True, **kwargs__get_var) + return (mod, hat) + elif vmode == 'modangle': + angle = self.get_varU(var+'_anglexxy', *args__get_var, mode=mode, **kwargs__get_var) + return (mod, angle) + elif vmode == 'xyz': + varxyz = [self.get_varm(var + x, *args__get_var, **kwargs__get_var) for x in ('x', 'y', 'z')] + units = self.get_units(mode=mode) + return (np.array(varxyz) * units.factor, units.name) + assert False # if we made it to this line it means something is wrong with the code here. + + def _varV_formatter(self, vmode, fmt_values='{: .2e}', fmt_units='{:^7s}'): + '''returns a format function for pretty formatting of the result of get_varV.''' + VALIDMODES = ('modhat', 'modangle', 'xyz') + vmode = vmode.lower() + assert vmode in VALIDMODES, 'vmode {} invalid! Expected vmode in {}.'.format(repr(vmode), VALIDMODES) + if vmode == 'modhat': + def fmt(x): + mag = fmt_values.format(x[0][0]) + units = fmt_units.format(x[0][1]) + hat = ('[ '+fmt_values+', '+fmt_values+', '+fmt_values+' ]').format(*x[1]) + return 'magnitude = {} [{}]; unit vector = {}'.format(mag, units, hat) + elif vmode == 'modangle': + def fmt(x): + mag = fmt_values.format(x[0][0]) + units = fmt_units.format(x[0][1]) + angle = fmt_values.format(x[1][0]) + angle_units = fmt_units.format(x[1][1]) + return 'magnitude = {} [{}]; angle (from +x) = {} [{}]'.format(mag, units, angle, angle_units) + elif vmode == 'xyz': + def fmt(x): + vec = ('[ '+fmt_values+', '+fmt_values+', '+fmt_values+' ]').format(*x[0]) + units = fmt_units.format(x[1]) + return '{} [{}]'.format(vec, units) + fmt.__doc__ = 'formats result of get_varV. I was made by helita.sim.bifrost._varV_formatter.' + return fmt + + def zero(self, **kw__np_zeros): + '''return np.zeros() with shape equal to shape of result of get_var()''' + return np.zeros(self.shape, **kw__np_zeros) + + def get_snap_here(self): + '''return self.snap, or self.snap[0] if self.snap is a list. + This is the snap which get_var() will work at, for the given self.snap value. + ''' + try: + iter(self.snap) + except TypeError: + return self.snap + else: + return self.snap[0] + + def get_snap_at_time(self, t, units='simu'): + '''get snap number which is closest to time t. + + units: 's', 'si', 'cgs', or 'simu' (default). + 's', 'si', 'cgs' --> enter t in seconds; return time at snap in seconds. + 'simu' (default) --> enter t in simulation units; return time at snap in simulation units. + + Return (snap number, time at this snap). + ''' + snaps = self.snap + try: + snaps[0] + except TypeError: + raise TypeError('expected self.snap (={}) to be a list. You can set it via self.set_snap()'.format(snaps)) + units = units.lower() + VALIDUNITS = ('s', 'si', 'cgs', 'simu') + assert units in VALIDUNITS, 'expected units (={}) to be one of {}'.format(repr(units), VALIDUNITS) + if units in ('s', 'si', 'cgs'): + u_t = self.uni.u_t # == self.uni.usi_t. time [simu units] * u_t = time [seconds]. + else: + u_t = 1 + t_get = t / u_t # time [simu units] + idxmin = np.argmin(np.abs(self.time - t_get)) + return snaps[idxmin], self.time[idxmin] * u_t + + def set_snap_time(self, t, units='simu', snaps=None, snap=None): + '''set self.snap to the snap which is closest to time t. + + units: 's', 'si', 'cgs', or 'simu' (default). + 's', 'si', 'cgs' --> enter t in seconds; return time at snap in seconds. + 'simu' (default) --> enter t in simulation units; return time at snap in simulation units. + snaps: None (default) or list of snaps. + None --> use self.snap for list of snaps to choose from. + list --> use snaps for list of snaps to choose from. + self.set_snap_time(t, ..., snaps=SNAPLIST) is equivalent to: + self.set_snap(SNAPLIST); self.set_snap_time(t, ...) + snap: alias for snaps kwarg. (Ignore snap if snaps is also entered, though.) + + Return (snap number, time at this snap). + ''' + + snaps = snaps if (snaps is not None) else snap + if snaps is not None: + self.set_snap(snaps) + try: + result_snap, result_time = self.get_snap_at_time(t, units=units) + except TypeError: + raise TypeError('expected self.snap to be a list, or snaps=list_of_snaps input to function.') + self.set_snap(result_snap) + return (result_snap, result_time) + + def get_lmin(self): + '''return smallest length resolvable for each direction ['x', 'y', 'z']. + result is in [simu. length units]. Multiply by self.uni.usi_l to convert to SI. + + return 1 (instead of 0) for any direction with number of points < 2. + ''' + def _dxmin(x): + dx1d = getattr(self, 'd'+x+'1d') + if len(dx1d) == 1: + return 1 + else: + return dx1d.min() + return np.array([_dxmin(x) for x in AXES]) + + def get_kmax(self): + '''return largest value of each component of wavevector resolvable by self. + I.e. returns [max kx, max ky, max kz]. + result is in [1/ simu. length units]. Divide by self.uni.usi_l to convert to SI. + ''' + return 2 * np.pi / self.get_lmin() + + def get_unit_vector(self, var, mean=False, **kw__get_var): + '''return unit vector of var. [varx, vary, varz]/|var|.''' + varx = self.get_var(var+'x', **kw__get_var) + vary = self.get_var(var+'y', **kw__get_var) + varz = self.get_var(var+'z', **kw__get_var) + varmag = self.get_var('mod'+var, **kw__get_var) + if mean: + varx, vary, varz, varmag = varx.mean(), vary.mean(), varz.mean(), varmag.mean() + return np.array([varx, vary, varz]) / varmag + + def write_mesh_file(self, meshfile='untitled_mesh.mesh', u_l=None): + '''writes mesh to meshfilename. + mesh will be the mesh implied by self, + using values for x, y, z, dx1d, dy1d, dz1d, indexed by iix, iiy, iiz. + + u_l: None, or a number + cgs length units (length [simulation units] * u_l = length [cm]), + for whoever will be reading the meshfile. + None -> use length units of self. + + Returns abspath to generated meshfile. + ''' + if not meshfile.endswith('.mesh'): + meshfile += '.mesh' + if u_l is None: + scaling = 1.0 + else: + scaling = self.uni.u_l / u_l + kw_x = {x: getattr(self, x) * scaling for x in AXES} + kw_dx = {'d'+x: getattr(self, 'd'+x+'1d') / scaling for x in AXES} + kw_nx = {'n'+x: getattr(self, x+'Length') for x in AXES} + kw_mesh = {**kw_x, **kw_nx, **kw_dx} + Create_new_br_files().write_mesh(**kw_mesh, meshfile=meshfile) + return os.path.abspath(meshfile) + + write_meshfile = write_mesh_file # alias + + def get_coords(self, units='si', axes=None, mode=None): + '''returns dict of coords, with keys ['x', 'y', 'z', 't']. + units: + 'si' (default) -> [meters] for x,y,z; [seconds] for t. + 'cgs' -> [cm] for x,y,z; [seconds] for t. + 'simu' -> [simulation units] for all coords. + if axes is not None: + instead of returning a dict, return coords for the axes provided, in the order listed. + axes can be provided in either of these formats: + strings: 'x', 'y', 'z', 't'. + ints: 0 , 1 , 2 , 3 . + For example: + c = self.get_coords() + c['y'], c['t'] == self.get_coords(axes=('y', 'z')) + c['z'], c['x'], c['y'] == self.get_coords(axes='zxy') + mode: alias for units. (for backwards compatibility) + if entered, ignore units kwarg; use mode instead. + ''' + if mode is None: + mode = units + mode = mode.lower() + VALIDMODES = ('si', 'cgs', 'simu') + assert mode in VALIDMODES, "Invalid mode ({})! Expected one of {}".format(repr(mode), VALIDMODES) + if mode == 'si': + u_l = self.uni.usi_l + u_t = self.uni.usi_t + elif mode == 'cgs': + u_l = self.uni.u_l + u_t = self.uni.u_t + else: # mode == 'simu' + u_l = 1 + u_t = 1 + x, y, z = (self_x * u_l for self_x in (self.x, self.y, self.z)) + t = self.time * u_t + result = dict(x=x, y=y, z=z, t=t) + if axes is not None: + AXES_LOOKUP = {'x': 'x', 0: 'x', 'y': 'y', 1: 'y', 'z': 'z', 2: 'z', 't': 't', 3: 't'} + result = tuple(result[AXES_LOOKUP[axis]] for axis in axes) + return result + + def get_coord(self, axis, units=None): + '''gets coord for the given axis, in the given unit system. + axis: string ('x', 'y', 'z', 't') or int (0, 1, 2, 3) + units: None (default) or string ('si', 'cgs', 'simu') ('simu' for 'simulation units') + None --> use self.units_output. + + The result will be an array (possibly with only 1 element). + ''' + if units is None: + units = self.units_output + return self.get_coords(units=units, axes=[axis])[0] + + def coord_grid(self, axes='xyz', units='si', sparse=True, **kw__meshgrid): + '''returns grid of coords for self along the given axes. + + axes: list of strings ('x', 'y', 'z', 't'), or ints (0, 1, 2, 3) + units: string ('si', 'cgs', 'simu') ('simu' for 'simulation units') + sparse: bool. Example: + coord_grid('xyz', sparse=True)[0].shape == (Nx, 1, 1) + coord_grid('xyz', sparse=False)[0].shape == (Nx, Ny, Nz) + + This function basically just calls np.meshgrid, using coords from self.get_coords. + + Example: + xx, yy, zz = self.coord_grid('xyz', sparse=True) + # yy.shape == (1, self.yLength, 1) + # yy[0, i, 0] == self.get_coord('x')[i] + + xx, tt = self.coord_grid('xt', sparse=False) + # xx.shape == (self.xLength, len(self.time)) + # tt.shape == (self.XLength, len(self.time)) + ''' + coords = self.get_coords(axes=axes, units=units) + indexing = kw__meshgrid.pop('indexing', 'ij') # default 'ij' indexing + return np.meshgrid(*coords, sparse=sparse, indexing=indexing, **kw__meshgrid) + + def get_kcoords(self, units='si', axes=None): + '''returns dict of k-space coords, with keys ['kx', 'ky', 'kz'] + coords units are based on mode. + 'si' (default) -> [ 1 / m] + 'cgs' -> [ 1 / cm] + 'simu' -> [ 1 / simulation unit length] + if axes is not None: + instead of returning a dict, return coords for the axes provided, in the order listed. + axes can be provided in either of these formats: + strings: 'x', 'y', 'z' + ints: 0 , 1 , 2 + ''' + # units + units = units.lower() + assert units in ('si', 'cgs', 'simu') + u_l = {'si': self.uni.usi_l, 'cgs': self.uni.u_l, 'simu': 1}[units] + # axes bookkeeping + if axes is None: + axes = AXES + return_dict = True + else: + AXES_LOOKUP = {'x': 'x', 0: 'x', 'y': 'y', 1: 'y', 'z': 'z', 2: 'z'} + axes = [AXES_LOOKUP[x] for x in axes] + return_dict = False + result = {f'k{x}': getattr(self, f'k{x}') for x in axes} # get k + result = {key: val / u_l for key, val in result.items()} # convert units + # return + if return_dict: + return result + else: + return [result[f'k{x}'] for x in axes] + + def get_extent(self, axes, units='si'): + '''use plt.imshow(extent=get_extent()) to make a 2D plot in x,y,z,t coords. + (Be careful if coords are unevenly spaced; imshow assumes even spacing.) + units: 'si' (default), 'cgs', or 'simu' + unit system for result + axes: None, strings (e.g. ('x', 'z') or 'xz'), or list of indices (e.g. (0, 2)) + which axes to get the extent for. + first axis will be the plot's x axis; second will be the plot's y axis. + E.g. axes='yz' means 'y' as the horizontal axis, 'z' as the vertical axis. + ''' + assert len(axes) == 2, f"require exactly 2 axes for get_extent, but got {len(axes)}" + x, y = self.get_coords(units=units, axes=axes) + return tools.extent(x, y) + + def get_kextent(self, axes=None, units='si'): + '''use plt.imshow(extent=get_kextent()) to make a plot in k-space. + units: 'si' (default), 'cgs', or 'simu' + unit system for result + axes: None, strings (e.g. ('x', 'z') or 'xz'), or list of indices (e.g. (0, 2)) + which axes to get the extent for. + if None, use obj._latest_fft_axes (see helita.sim.load_arithmetic_quantities.get_fft_quant) + first axis will be the plot's x axis; second will be the plot's y axis. + E.g. axes='yz' means 'y' as the horizontal axis, 'z' as the vertical axis. + ''' + if axes is None: + try: + axes = self._latest_fft_axes + except AttributeError: + errmsg = "self._latest_fft_axes not set; maybe you meant to get a quant from " +\ + "FFT_QUANT first? Use self.vardoc('FFT_QUANT') to see list of options." + raise AttributeError(errmsg) from None + assert len(axes) == 2, f"require exactly 2 axes for get_kextent, but got {len(axes)}" + kx, ky = self.get_kcoords(units=units, axes=axes) + return tools.extent(kx, ky) + + if file_memory.DEBUG_MEMORY_LEAK: + def __del__(self): + print('deleted {}'.format(self), flush=True) + + +#################### +# LOCATING SNAPS # +#################### + +SnapStuff = collections.namedtuple('SnapStuff', ('snapname', 'snaps')) + + +def get_snapstuff(dd=None): + '''return (get_snapname(), available_snaps()). + dd: None or BifrostData object. + None -> do operations locally. + else -> cd to dd.fdir, first. + ''' + snapname = get_snapname(dd=dd) + snaps = get_snaps(snapname=snapname, dd=dd) + return SnapStuff(snapname=snapname, snaps=snaps) + + +snapstuff = get_snapstuff # alias + + +def get_snapname(dd=None): + '''gets snapname by reading it from local mhd.in, or dd.snapname if dd is provided.''' + if dd is None: + mhdin_ascii = read_idl_ascii('mhd.in') + return mhdin_ascii['snapname'] + else: + return dd.snapname + + +snapname = get_snapname # alias + + +def get_snaps(dd=None, snapname=None): + '''list available snap numbers. + Does look for: snapname_*.idl, snapname.idl (i.e. snap 0) + Doesn't look for: .pan, .scr, .aux files. + snapname: None (default) or str + snapname parameter from mhd.in. If None, get snapname. + if dd is not None, look in dd.fdir. + ''' + with tools.EnterDirectory(_get_dd_fdir(dd)): + snapname = snapname if snapname is not None else get_snapname() + snaps = [_snap_to_N(f, snapname) for f in os.listdir()] + snaps = [s for s in snaps if s is not None] + snaps = sorted(snaps) + return snaps + + +snaps = get_snaps # alias +available_snaps = get_snaps # alias +list_snaps = get_snaps # alias + + +def snaps_info(dd=None, snapname=None, snaps=None): + '''returns string with length of snaps, as well as min and max. + if snaps is None, lookup all available snaps. + ''' + if snaps is None: + snaps = get_snaps(dd=dd, snapname=snapname) + return 'There are {} snaps, from {} (min) to {} (max)'.format(len(snaps), min(snaps), max(snaps)) + + +def get_snap_shifted(dd=None, shift=0, snapname=None, snap=None): + '''returns snap's number for snap at index (current_snap_index + shift). + Must provide dd or snap, so we can figure out current_snap_index. + ''' + snaps = list(get_snaps(dd=dd, snapname=snapname)) + snap_here = snap if snap is not None else dd.get_snap_here() + i_here = snaps.index(snap_here) + i_result = i_here + shift + if i_result < 0: + if shift == -1: + raise ValueError(f'No snap found prior to snap={snap_here}') + else: + raise ValueError(f'No snap found {abs(shift)} prior to snap={snap_here}') + elif i_result >= len(snaps): + if shift == 1: + raise ValueError(f'No snap found after snap={snap_here}') + else: + raise ValueError(f'No snap found {abs(shift)} after snap={snap_here}') + else: + return snaps[i_result] + + +def get_snap_prev(dd=None, snapname=None, snap=None): + '''returns previous available snap's number. TODO: implement more efficiently. + Must provide dd or snap, so we can figure out the snap here, first. + ''' + return get_snap_shifted(dd=dd, shift=-1, snapname=snapname, snap=snap) + + +def get_snap_next(dd=None, snapname=None, snap=None): + '''returns next available snap's number. TODO: implement more efficiently. + Must provide dd or snap, so we can figure out the snap here, first. + ''' + return get_snap_shifted(dd=dd, shift=+1, snapname=snapname, snap=snap) + + +def _get_dd_fdir(dd=None): + '''return dd.fdir if dd is not None, else os.curdir.''' + if dd is not None: + fdir = dd.fdir + else: + fdir = os.curdir + return fdir + + +def _snap_to_N(name, base, sep='_', ext='.idl'): + '''returns N as number given snapname (and basename) if possible, else None. + for all strings in exclude, if name contains string, return None. + E.g. _snap_to_N('s_075.idl', 's') == 75 + E.g. _snap_to_N('s.idl', 's') == 0 + E.g. _snap_to_N('notasnap', 's') == None + ''' + if not name.startswith(base): + return None + namext = os.path.splitext(name) + if namext[1] != ext: + return None + elif namext[0] == base: + return 0 + else: + try: + snapN = int(namext[0][len(base+sep):]) + except ValueError: + return None + else: + return snapN + + +def _N_to_snapstr(N): + '''return string representing snap number N.''' + if N == 0: + return '' + else: + assert tools.is_integer(N), f"snap values must be integers! (snap={N})" + return '_%03i' % N + + +# include methods (and some aliases) for getting snaps in BifrostData object +BifrostData.get_snapstuff = get_snapstuff +BifrostData.get_snapname = get_snapname +BifrostData.available_snaps = available_snaps +BifrostData.get_snaps = get_snaps +BifrostData.get_snap_prev = get_snap_prev +BifrostData.get_snap_next = get_snap_next +BifrostData.snaps_info = snaps_info + + +#################### +# WRITING SNAPS # +#################### + +def write_br_snap(rootname, r, px, py, pz, e, bx, by, bz): + nx, ny, nz = r.shape + data = np.memmap(rootname, dtype='float32', mode='w+', order='f', shape=(nx, ny, nz, 8)) + data[..., 0] = r + data[..., 1] = px + data[..., 2] = py + data[..., 3] = pz + data[..., 4] = e + data[..., 5] = bx + data[..., 6] = by + data[..., 7] = bz + data.flush() + + +def paramfile_br_update(infile, outfile, new_values): + ''' Updates a given number of fields with values on a bifrost.idl file. + These are given in a dictionary: fvalues = {field: value}. + Reads from infile and writes into outfile.''' + out = open(outfile, 'w') + with open(infile) as fin: + for line in fin: + if line[0] == ';': + out.write(line) + elif line.find('=') < 0: + out.write(line) + else: + ss = line.split('=')[0] + ssv = ss.strip().upper() + if ssv in list(new_values.keys()): + out.write('%s= %s\n' % (ss, str(new_values[ssv]))) + else: + out.write(line) + return class Create_new_br_files: @@ -1820,6 +2125,8 @@ def write_mesh(self, x=None, y=None, z=None, nx=None, ny=None, nz=None, dx=None, dy=None, dz=None, meshfile="newmesh.mesh"): """ Writes mesh to ascii file. + + The meshfile units are simulation units for length (or 1/length, for derivatives). """ def __xxdn(f): ''' @@ -1908,203 +2215,116 @@ def __ddxxdn(f, dx=None): f.write(str(getattr(self, 'n' + p)) + "\n") f.write(" ".join(map("{:.5f}".format, getattr(self, p))) + "\n") f.write(" ".join(map("{:.5f}".format, xmdn)) + "\n") - f.write(" ".join(map("{:.5f}".format, 1. / dxidxup)) + "\n") - f.write(" ".join(map("{:.5f}".format, 1. / dxidxdn)) + "\n") + f.write(" ".join(map("{:.5f}".format, 1.0/dxidxup)) + "\n") + f.write(" ".join(map("{:.5f}".format, 1.0/dxidxdn)) + "\n") f.close() -def polar2cartesian(r, t, grid, x, y, order=3): - ''' - Converts polar grid to cartesian grid - ''' - - X, Y = np.meshgrid(x, y) - - new_r = np.sqrt(X * X + Y * Y) - new_t = np.arctan2(X, Y) - - ir = interpolate.interp1d(r, np.arange(len(r)), bounds_error=False) - it = interpolate.interp1d(t, np.arange(len(t))) - - new_ir = ir(new_r.ravel()) - new_it = it(new_t.ravel()) +############ +# UNITS # +############ - new_ir[new_r.ravel() > r.max()] = len(r) - 1 - new_ir[new_r.ravel() < r.min()] = 0 +class BifrostUnits(units.HelitaUnits): + '''stores units as attributes. - return map_coordinates(grid, np.array([new_ir, new_it]), - order=order).reshape(new_r.shape) + units starting with 'u_' are in cgs. starting with 'usi_' are in SI. + Convert to these units by multiplying data by this factor. + Example: + r = obj.get_var('r') # r = mass density / (simulation units) + rcgs = r * obj.uni.u_r # rcgs = mass density / (cgs units, i.e. (g * cm^-3)) + rsi = r * obj.uni.usi_r # rsi = mass density / (si units, i.e. (kg * m^-3)) + all units are uniquely determined by the following minimal set of units: + (length, time, mass density, gamma) -def cartesian2polar(x, y, grid, r, t, order=3): + you can access documentation on the units themselves via: + self.help(). (for BifrostData object obj, do obj.uni.help()) + this documentation is not very detailed, but at least tells you + which physical quantity the units are for. ''' - Converts cartesian grid to polar grid - ''' - - R, T = np.meshgrid(r, t) - - new_x = R * np.cos(T) - new_y = R * np.sin(T) - - ix = interpolate.interp1d(x, np.arange(len(x)), bounds_error=False) - iy = interpolate.interp1d(y, np.arange(len(y)), bounds_error=False) - new_ix = ix(new_x.ravel()) - new_iy = iy(new_y.ravel()) - - new_ix[new_x.ravel() > x.max()] = len(x) - 1 - new_ix[new_x.ravel() < x.min()] = 0 - - new_iy[new_y.ravel() > y.max()] = len(y) - 1 - new_iy[new_y.ravel() < y.min()] = 0 + def __init__(self, filename='mhd.in', fdir='./', verbose=True, base_units=None, **kw__super_init): + '''get units from file (by reading values of u_l, u_t, u_r, gamma). + + filename: str; name of file. Default 'mhd.in' + fdir: str; directory of file. Default './' + verbose: True (default) or False + True -> if we use default value for a base unit because + we can't find its value otherwise, print warning. + base_units: None (default), dict, or list + None -> ignore this keyword. + dict -> if contains any of the keys: u_l, u_t, u_r, gamma, + initialize the corresponding unit to the value found. + if base_units contains ALL of those keys, IGNORE file. + list -> provides value for u_l, u_t, u_r, gamma; in that order. + ''' + DEFAULT_UNITS = dict(u_l=1.0e8, u_t=1.0e2, u_r=1.0e-7, gamma=1.667) + base_to_use = dict() # << here we will put the u_l, u_t, u_r, gamma to actually use. + _n_base_set = 0 # number of base units set (i.e. assigned in base_to_use) - return map_coordinates(grid, np.array([new_ix, new_iy]), - order=order).reshape(new_x.shape) + # setup units from base_units, if applicable + if base_units is not None: + try: + base_units.items() + except AttributeError: # base_units is a list + for i, val in enumerate(base_units): + base_to_use[self.BASE_UNITS[i]] = val + _n_base_set += 1 + else: + for key, val in base_units.items(): + if key in DEFAULT_UNITS.keys(): + base_to_use[key] = val + _n_base_set += 1 + elif verbose: + print(('(WWW) the key {} is not a base unit', + ' so it was ignored').format(key)) + + # setup units from file (or defaults), if still necessary. + if _n_base_set != len(DEFAULT_UNITS): + if filename is None: + file_exists = False + else: + file = os.path.join(fdir, filename) + file_exists = os.path.isfile(file) + if file_exists: + # file exists -> set units using file. + self.params = read_idl_ascii(file, firstime=True) + + def setup_unit(key): + if base_to_use.get(key, None) is not None: + return + # else: + try: + value = self.params[key] + except Exception: + value = DEFAULT_UNITS[key] + if verbose: + printstr = ("(WWW) the file '{file}' does not contain '{unit}'. " + "Default Solar Bifrost {unit}={value} has been selected.") + print(printstr.format(file=file, unit=key, value=value)) + base_to_use[key] = value + + for unit in DEFAULT_UNITS.keys(): + setup_unit(unit) + else: + # file does not exist -> setup default units. + units_to_set = {unit: DEFAULT_UNITS[unit] for unit in DEFAULT_UNITS.keys() + if getattr(self, unit, None) is None} + if verbose: + print("(WWW) selected file '{file}' is not available.".format(file=filename), + "Setting the following Default Solar Bifrost units: ", units_to_set) + for key, value in units_to_set.items(): + base_to_use[key] = value + # initialize using instructions from HelitaUnits (see helita.sim.units.py) + super().__init__(**base_to_use, verbose=verbose, **kw__super_init) -class Bifrost_units(object): - def __init__(self, filename='mhd.in', fdir='./'): - import scipy.constants as const - from astropy import constants as aconst - from astropy import units +Bifrost_units = BifrostUnits # alias (required for historical compatibility) - if os.path.isfile(os.path.join(fdir, filename)): - self.params = read_idl_ascii(os.path.join(fdir, filename)) - try: - self.u_l = self.params['u_l'] - self.u_t = self.params['u_t'] - self.u_r = self.params['u_r'] - # --- ideal gas - self.gamma = self.params['gamma'] - - except: - print('(WWW) the filename does not have u_l, u_t and u_r.' - ' Default Solar Bifrost units has been selected') - self.u_l = 1.0e8 - self.u_t = 1.0e2 - self.u_r = 1.0e-7 - # --- ideal gas - self.gamma = self.params['gamma'] - else: - print('(WWW) selected filename is not available.' - ' Default Solar Bifrost units has been selected') - self.u_l = 1.0e8 - self.u_t = 1.0e2 - self.u_r = 1.0e-7 - # --- ideal gas - self.gamma = 1.667 - - self.u_u = self.u_l / self.u_t - self.u_p = self.u_r * (self.u_l / self.u_t)**2 # Pressure [dyne/cm2] - # Rosseland opacity [cm2/g] - self.u_kr = 1 / (self.u_r * self.u_l) - self.u_ee = self.u_u**2 - self.u_e = self.u_r * self.u_ee - # Box therm. em. [erg/(s ster cm2)] - self.u_te = self.u_e / self.u_t * self.u_l - self.mu = 0.8 - self.u_n = 3.00e+10 # Density number n_0 * 1/cm^3 - # 1.380658E-16 Boltzman's cst. [erg/K] - self.k_b = aconst.k_B.to_value('erg/K') - self.m_h = const.m_n / const.gram # 1.674927471e-24 - self.m_he = 6.65e-24 - self.m_p = self.mu * self.m_h # Mass per particle - self.m_e = aconst.m_e.to_value('g') - self.u_tg = (self.m_h / self.k_b) * self.u_ee - self.u_tge = (self.m_e / self.k_b) * self.u_ee - self.pi = const.pi - self.u_b = self.u_u * np.sqrt(4. * self.pi * self.u_r) - - self.usi_l = self.u_l * const.centi # 1e6 - self.usi_r = self.u_r * const.gram # 1e-4 - self.usi_u = self.usi_l / self.u_t - self.usi_p = self.usi_r * (self.usi_l / self.u_t)**2 # Pressure [N/m2] - # Rosseland opacity [m2/kg] - self.usi_kr = 1 / (self.usi_r * self.usi_l) - self.usi_ee = self.usi_u**2 - self.usi_e = self.usi_r * self.usi_ee - self.usi_te = self.usi_e / self.u_t * \ - self.usi_l # Box therm. em. [J/(s ster m2)] - self.ksi_b = aconst.k_B.to_value( - 'J/K') # Boltzman's cst. [J/K] - self.msi_h = const.m_n # 1.674927471e-27 - self.msi_he = 6.65e-27 - self.msi_p = self.mu * self.msi_h # Mass per particle - self.usi_tg = (self.msi_h / self.ksi_b) * self.usi_ee - self.msi_e = const.m_e # 9.1093897e-31 - self.usi_b = self.u_b * 1e-4 - - # Solar gravity - self.gsun = (aconst.GM_sun / aconst.R_sun ** - 2).cgs.value # solar surface gravity - - # --- physical constants and other useful quantities - self.clight = aconst.c.to_value('cm/s') # Speed of light [cm/s] - self.hplanck = aconst.h.to_value('erg s') # Planck's constant [erg s] - self.kboltzmann = aconst.k_B.to_value( - 'erg/K') # Boltzman's cst. [erg/K] - self.amu = aconst.u.to_value('g') # Atomic mass unit [g] - self.amusi = aconst.u.to_value('kg') # Atomic mass unit [kg] - self.m_electron = aconst.m_e.to_value('g') # Electron mass [g] - self.q_electron = aconst.e.esu.value # Electron charge [esu] - self.qsi_electron = aconst.e.value # Electron charge [C] - self.rbohr = aconst.a0.to_value('cm') # bohr radius [cm] - self.e_rydberg = aconst.Ryd.to_value( - 'erg', equivalencies=units.spectral()) - self.eh2diss = 4.478007 # H2 dissociation energy [eV] - self.pie2_mec = (np.pi * aconst.e.esu ** 2 / - (aconst.m_e * aconst.c)).cgs.value - # 5.670400e-5 Stefan-Boltzmann constant [erg/(cm^2 s K^4)] - self.stefboltz = aconst.sigma_sb.cgs.value - self.mion = self.m_h # Ion mass [g] - self.r_ei = 1.44E-7 # e^2 / kT = 1.44x10^-7 T^-1 cm - - # --- Unit conversions - self.ev_to_erg = units.eV.to('erg') - self.ev_to_j = units.eV.to('J') - self.nm_to_m = const.nano # 1.0e-09 - self.cm_to_m = const.centi # 1.0e-02 - self.km_to_m = const.kilo # 1.0e+03 - self.erg_to_joule = const.erg # 1.0e-07 - self.g_to_kg = const.gram # 1.0e-03 - self.micron_to_nm = units.um.to('nm') - self.megabarn_to_m2 = units.Mbarn.to('m2') - self.atm_to_pa = const.atm # 1.0135e+05 atm to pascal (n/m^2) - self.dyne_cm2_to_pascal = (units.dyne / units.cm**2).to('Pa') - self.k_to_ev = units.K.to( - 'eV', equivalencies=units.temperature_energy()) - self.ev_to_k = 1. / self.k_to_ev - self.ergd2wd = 0.1 - self.grph = 2.27e-24 - self.permsi = aconst.eps0.value # Permitivitty in vacuum (F/m) - self.cross_p = 1.59880e-14 - self.cross_he = 9.10010e-17 - - # Dissociation energy of H2 [eV] from Barklem & Collet (2016) - self.di = self.eh2diss - - self.atomdic = {'h': 1, 'he': 2, 'c': 3, 'n': 4, 'o': 5, 'ne': 6, 'na': 7, - 'mg': 8, 'al': 9, 'si': 10, 's': 11, 'k': 12, 'ca': 13, - 'cr': 14, 'fe': 15, 'ni': 16} - self.abnddic = {'h': 12.0, 'he': 11.0, 'c': 8.55, 'n': 7.93, 'o': 8.77, - 'ne': 8.51, 'na': 6.18, 'mg': 7.48, 'al': 6.4, 'si': 7.55, - 's': 5.21, 'k': 5.05, 'ca': 6.33, 'cr': 5.47, 'fe': 7.5, - 'ni': 5.08} - self.weightdic = {'h': 1.008, 'he': 4.003, 'c': 12.01, 'n': 14.01, - 'o': 16.00, 'ne': 20.18, 'na': 23.00, 'mg': 24.32, - 'al': 26.97, 'si': 28.06, 's': 32.06, 'k': 39.10, - 'ca': 40.08, 'cr': 52.01, 'fe': 55.85, 'ni': 58.69} - self.xidic = {'h': 13.595, 'he': 24.580, 'c': 11.256, 'n': 14.529, - 'o': 13.614, 'ne': 21.559, 'na': 5.138, 'mg': 7.644, - 'al': 5.984, 'si': 8.149, 's': 10.357, 'k': 4.339, - 'ca': 6.111, 'cr': 6.763, 'fe': 7.896, 'ni': 7.633} - self.u0dic = {'h': 2., 'he': 1., 'c': 9.3, 'n': 4., 'o': 8.7, - 'ne': 1., 'na': 2., 'mg': 1., 'al': 5.9, 'si': 9.5, 's': 8.1, - 'k': 2.1, 'ca': 1.2, 'cr': 10.5, 'fe': 26.9, 'ni': 29.5} - self.u1dic = {'h': 1., 'he': 2., 'c': 6., 'n': 9., 'o': 4., 'ne': 5., - 'na': 1., 'mg': 2., 'al': 1., 'si': 5.7, 's': 4.1, 'k': 1., - 'ca': 2.2, 'cr': 7.2, 'fe': 42.7, 'ni': 10.5} +##################### +# CROSS SECTIONS # +##################### class Rhoeetab: @@ -2130,7 +2350,7 @@ def __init__(self, tabfile=None, fdir='.', big_endian=False, dtype='f4', try: tmp = find_first_match("mhd.in", fdir) except IndexError: - tmp = '' + tmp = None print("(WWW) init: no .idl or mhd.in files found." + "Units set to 'standard' Bifrost units.") self.uni = Bifrost_units(filename=tmp, fdir=fdir) @@ -2141,9 +2361,10 @@ def __init__(self, tabfile=None, fdir='.', big_endian=False, dtype='f4', def read_tab_file(self, tabfile): ''' Reads tabparam.in file, populates parameters. ''' - self.params = read_idl_ascii(tabfile) + self.params = read_idl_ascii(tabfile, obj=self) if self.verbose: - print(('*** Read parameters from ' + tabfile)) + print(('*** Read parameters from ' + tabfile), whsp*4, end="\r", + flush=True) p = self.params # construct lnrho array self.lnrho = np.linspace( @@ -2169,7 +2390,8 @@ def load_eos_table(self, eostabfile=None): self.lnkr = table[:, :, 3] self.eosload = True if self.verbose: - print(('*** Read EOS table from ' + eostabfile)) + print('*** Read EOS table from ' + eostabfile, whsp*4, end="\r", + flush=True) def load_ent_table(self, eostabfile=None): ''' @@ -2210,9 +2432,10 @@ def load_rad_table(self, radtabfile=None): self.opatab = table[:, :, :, 2] self.radload = True if self.verbose: - print(('*** Read rad table from ' + radtabfile)) + print('*** Read rad table from ' + radtabfile, whsp*4, end="\r", + flush=True) - def get_table(self, out='ne', bine=None, order=1): + def get_table(self, out='ne', bin=None, order=1): qdict = {'ne': 'lnne', 'tg': 'tgt', 'pg': 'lnpg', 'kr': 'lnkr', 'eps': 'epstab', 'opa': 'opatab', 'temp': 'temtab', @@ -2230,8 +2453,8 @@ def get_table(self, out='ne', bine=None, order=1): quant = getattr(self, qdict[out]) if out in ['opa eps temp'.split()]: if bin is None: - print(("(WWW) tab_interp: radiation bin not set," - " using first bin.")) + print("(WWW) tab_interp: radiation bin not set," + " using first bin.") bin = 0 quant = quant[..., bin] return quant @@ -2327,21 +2550,21 @@ class Opatab: """ def __init__(self, tabname=None, fdir='.', dtype='f4', - verbose=True, lambd=100.0): + verbose=True, lambd=100.0, big_endian=False): self.fdir = fdir self.dtype = dtype self.verbose = verbose self.big_endian = big_endian self.lambd = lambd self.radload = False - self.teinit = 4.0 - self.dte = 0.1 + self.teinit = 3.0 + self.dte = 0.05 + self.nte = 100 + self.ch_tabname = "chianti" # alternatives are e.g. 'mazzotta' and others found in Chianti # read table file and calculate parameters if tabname is None: tabname = os.path.join(fdir, 'ionization.dat') self.tabname = tabname - # load table(s) - self.load_opa_table() def hopac(self): ghi = 0.99 @@ -2385,17 +2608,19 @@ def load_opa_table(self, tabname=None): self.ionhei = table[:, :, 2] self.opaload = True if self.verbose: - print('*** Read EOS table from ' + tabname) + print('*** Read EOS table from ' + tabname, whsp*4, end="\r", + flush=True) def tg_tab_interp(self, order=1): ''' Interpolates the opa table to same format as tg table. ''' self.load_opa1d_table() - rhoeetab = Rhoeetab(fdir=self.fdir) - tgTable = rhoeetab.get_table('tg') + #rhoeetab = Rhoeetab(fdir=self.fdir) + #tgTable = rhoeetab.get_table('tg') + tgTable = np.linspace(self.teinit, self.teinit + self.dte*self.nte, self.nte) # translate to table coordinates - x = (np.log10(tgTable) - self.teinit) / self.dte + x = ((tgTable) - self.teinit) / self.dte # interpolate quantity self.ionh = map_coordinates(self.ionh1d, [x], order=order) self.ionhe = map_coordinates(self.ionhe1d, [x], order=order) @@ -2410,30 +2635,53 @@ def h_he_absorb(self, lambd=None): if lambd is not None: self.lambd = lambd self.tg_tab_interp() - ion_h = self.ionh - ion_he = self.ionhe - ion_hei = self.ionhei - ohi = self.hopac() - ohei = self.heiopac() - oheii = self.heiiopac() - arr = (1 - ion_h) * ohi + rhe * ((1 - ion_he - ion_hei) * - ohei + ion_he * oheii) + arr = (self.ionh) * self.hopac() + rhe * ((1 - self.ionhei - (1-self.ionhei-self.ionhe)) * + self.heiopac() + (self.ionhei) * self.heiiopac()) + #ion_h = self.ionh + #ion_he = self.ionhe + #ion_hei = self.ionhei + #ohi = self.hopac() + #ohei = self.heiopac() + #oheii = self.heiiopac() + # arr = (1 - ion_h) * ohi + rhe * ((1 - ion_he - ion_hei) * + # ohei + ion_he * oheii) arr[arr < 0] = 0 return arr - def load_opa1d_table(self, tabname=None): + def load_opa1d_table(self, tabname='chianti', tgmin=3.0, tgmax=9.0, ntg=121): ''' Loads ionizationstate table. ''' + import ChiantiPy.core as ch if tabname is None: tabname = '%s/%s' % (self.fdir, 'ionization1d.dat') - dtype = ('>' if self.big_endian else '<') + self.dtype - table = np.memmap(tabname, mode='r', shape=(41, 3), dtype=dtype, - order='F') - self.ionh1d = table[:, 0] - self.ionhe1d = table[:, 1] - self.ionhei1d = table[:, 2] - self.opaload = True + if tabname == '%s/%s' % (self.fdir, 'ionization1d.dat'): + dtype = ('>' if self.big_endian else '<') + self.dtype + table = np.memmap(tabname, mode='r', shape=(41, 3), dtype=dtype, + order='F') + self.ionh1d = table[:, 0] + self.ionhe1d = table[:, 1] + self.ionhei1d = table[:, 2] + self.opaload = True + else: # Chianti table + import ChiantiPy.core as ch + if self.verbose: + print('*** Reading Chianti table', whsp*4, end="\r", + flush=True) + h = ch.Ioneq.ioneq(1) + h.load(tabname) + temp = np.linspace(tgmin, tgmax, ntg) + h.calculate(10**temp) + logte = np.log10(h.Temperature) + self.dte = logte[1]-logte[0] + self.teinit = logte[0] + self.nte = np.size(logte) + self.ionh1d = h.Ioneq[0, :] + he = ch.Ioneq.ioneq(2) + he.load(tabname) + self.ionhe1d = he.Ioneq[0, :] + self.ionhei1d = he.Ioneq[1, :] if self.verbose: - print('*** Read OPA table from ' + tabname) + print('*** Read OPA table from ' + tabname, whsp*4, end="\r", + flush=True) class Cross_sect: @@ -2450,52 +2698,69 @@ class Cross_sect: If True, will print out more diagnostic messages dtype - string, optional Data type for reading variables. Default is 32 bit float. + kelvin - bool (default True) + Whether to load data in Kelvin. (uses eV otherwise) Examples -------- - >>> a = cross_sect(['h-h-data2.txt','h-h2-data.txt'], fdir="/data/cb24bih") + a = cross_sect(['h-h-data2.txt','h-h2-data.txt'], fdir="/data/cb24bih") """ - def __init__(self, cross_tab=None, fdir='.', dtype='f4', verbose=True): + def __init__(self, cross_tab=None, fdir=os.curdir, dtype='f4', verbose=None, kelvin=True, obj=None): ''' Loads cross section tables and calculates collision frequencies and ambipolar diffusion. - ''' + parameters: + cross_tab: None or list of strings + None -> use default cross tab list of strings. + else -> treat each string as the name of a cross tab file. + fdir: str (default '.') + directory of files (prepend to each filename in cross_tab). + dtype: default 'f4' + sets self.dtype. aside from that, internally does NOTHING. + verbose: None (default) or bool. + controls verbosity. presently, internally does NOTHING. + if None, use obj.verbose if possible, else use False (default) + kelvin - bool (default True) + Whether to load data in Kelvin. (uses eV otherwise) + obj: None (default) or an object + None -> does nothing; ignore this parameter. + else -> improve time-efficiency by saving data from cross_tab files + into memory of obj (save in obj._memory_read_cross_txt). + ''' self.fdir = fdir self.dtype = dtype + if verbose is None: + verbose = False if obj is None else getattr(obj, 'verbose', False) self.verbose = verbose + self.kelvin = kelvin + self.units = {True: 'K', False: 'eV'}[self.kelvin] + # save pointer to obj. Use weakref to help ensure we don't create a circular reference. + self.obj = (lambda: None) if (obj is None) else weakref.ref(obj) # self.obj() returns obj. # read table file and calculate parameters - cross_txt_list = ['h-h-data2.txt', 'h-h2-data.txt', 'he-he.txt', - 'e-h.txt', 'e-he.txt', 'h2_molecule_bc.txt', - 'h2_molecule_pj.txt', 'p-h-elast.txt', 'p-he.txt', - 'proton-h2-data.txt'] - self.cross_tab_list = {} - counter = 0 if cross_tab is None: - for icross_txt in cross_txt_list: - os.path.isfile('%s/%s' % (fdir, icross_txt)) - self.cross_tab_list[counter] = '%s/%s' % (fdir, icross_txt) - counter += 1 - else: - for icross_txt in cross_tab: - os.path.isfile('%s/%s' % (fdir, icross_txt)) - self.cross_tab_list[counter] = '%s/%s' % (fdir, icross_txt) - counter += 1 + cross_tab = ['h-h-data2.txt', 'h-h2-data.txt', 'he-he.txt', + 'e-h.txt', 'e-he.txt', 'h2_molecule_bc.txt', + 'h2_molecule_pj.txt', 'p-h-elast.txt', 'p-he.txt', + 'proton-h2-data.txt'] + self._cross_tab_strs = cross_tab + self.cross_tab_list = {} + for i, cross_txt in enumerate(cross_tab): + self.cross_tab_list[i] = os.path.join(fdir, cross_txt) + # load table(s) - self.load_cross_tables() + self.load_cross_tables(firstime=True) - def load_cross_tables(self): + def load_cross_tables(self, firstime=False): ''' Collects the information in the cross table files. ''' - uni = Bifrost_units() - self.cross_tab = {} - + self.cross_tab = dict() for itab in range(len(self.cross_tab_list)): - self.cross_tab[itab] = read_cross_txt(self.cross_tab_list[itab]) - self.cross_tab[itab]['tg'] *= uni.ev_to_k + self.cross_tab[itab] = read_cross_txt(self.cross_tab_list[itab], firstime=firstime, + obj=self.obj(), kelvin=self.kelvin) def tab_interp(self, tg, itab=0, out='el', order=1): ''' Interpolates the cross section tables in the simulated domain. @@ -2512,21 +2777,111 @@ def tab_interp(self, tg, itab=0, out='el', order=1): if out in ['se el vi mt'.split()] and not self.load_cross_tables: raise ValueError("(EEE) tab_interp: EOS table not loaded!") - finterp = interpolate.interp1d(self.cross_tab[itab]['tg'], + finterp = interpolate.interp1d(np.log(self.cross_tab[itab]['tg']), self.cross_tab[itab][out]) - tgreg = tg * 1.0 + tgreg = np.array(tg, copy=True) max_temp = np.max(self.cross_tab[itab]['tg']) - tgreg[np.where(tg > max_temp)] = max_temp + tgreg[tg > max_temp] = max_temp min_temp = np.min(self.cross_tab[itab]['tg']) - tgreg[np.where(tg < min_temp)] = min_temp + tgreg[tg < min_temp] = min_temp + + return finterp(np.log(tgreg)) + + def __call__(self, tg, *args, **kwargs): + '''alias for self.tab_interp.''' + return self.tab_interp(tg, *args, **kwargs) + + def __repr__(self): + return '{} == {}'.format(object.__repr__(self), str(self)) + + def __str__(self): + return "Cross_sect(cross_tab={}, fdir='{}')".format(self._cross_tab_strs, self.fdir) + + +def cross_sect_for_obj(obj=None): + '''return function which returns Cross_sect with self.obj=obj. + obj: None (default) or an object + None -> does nothing; ignore this parameter. + else -> improve time-efficiency by saving data from cross_tab files + into memory of obj (save in obj._memory_read_cross_txt). + Also, use fdir=obj.fdir, unless fdir is entered explicitly. + ''' + @functools.wraps(Cross_sect) + def _init_cross_sect(cross_tab=None, fdir=None, *args__Cross_sect, **kw__Cross_sect): + if fdir is None: + fdir = getattr(obj, 'fdir', '.') + return Cross_sect(cross_tab, fdir, *args__Cross_sect, **kw__Cross_sect, obj=obj) + return _init_cross_sect + +## Tools for making cross section table such that colfreq is independent of temperature ## + + +def constant_colfreq_cross(tg0, Q0, tg=range(1000, 400000, 100), T_to_eV=lambda T: T / 11604): + '''makes values for constant collision frequency vs temperature cross section table. + tg0, Q0: + enforce Q(tg0) = Q0. + tg: array of values for temperature. + (recommend: 1000 to 400000, with intervals of 100.) + T_to_eV: function + T_to_eV(T) --> value in eV. + + colfreq = consts * Q(tg) * sqrt(tg). + For constant colfreq: + Q(tg1) sqrt(tg1) = Q(tg0) sqrt(tg0) + + returns dict of arrays. keys: 'E' (for energy in eV), 'T' (for temperature), 'Q' (for cross) + ''' + tg = np.asarray(tg) + E = T_to_eV(tg) + Q = Q0 * np.sqrt(tg0) / np.sqrt(tg) + return dict(E=E, T=tg, Q=Q) + + +def cross_table_str(E, T, Q, comment=''): + '''make a string for the table for cross sections. + put comment at top of file if provided. + ''' + header = '' + if len(comment) > 0: + if not comment.startswith(';'): + comment = ';' + comment + header += comment + '\n' + header += '\n'.join(["", + "; 1 atomic unit of square distance = 2.80e-17 cm^2", + "; 1eV = 11604K", + "", + "2.80e-17", + "", + "", + "; E T Q11 ", + "; (eV) (K) (a.u.)", + "", + "", + ]) + lines = [] + for e, t, q in zip(E, T, Q): + lines.append('{:.6f} {:d} {:.3f}'.format(e, t, q)) + return header + '\n'.join(lines) + + +def constant_colfreq_cross_table_str(tg0, Q0, **kw): + '''make a string for a cross section table which will give constant collision frequency (vs tg).''' + if 'comment' in kw: + comment = kw.pop('comment') + else: + comment = '\n'.join(['; This table provides cross sections such that', + '; the collision frequency will be independent of temperature,', + '; assuming the functional form colfreq proportional to sqrt(T).', + ]) + ccc = constant_colfreq_cross(tg0, Q0, **kw) + result = cross_table_str(**ccc, comment=comment) + return result - return finterp(tgreg) ########### # TOOLS # ########### - def bifrost2d_to_rh15d(snaps, outfile, file_root, meshfile, fdir, writeB=False, sx=slice(None), sz=slice(None), desc=None): """ @@ -2591,154 +2946,79 @@ def bifrost2d_to_rh15d(snaps, outfile, file_root, meshfile, fdir, writeB=False, y = snaps z = data.z[sz] * (-ul) - rdt = data.r.dtype + data.r.dtype + # cstagger.init_stagger(data.nz, data.dx, data.dy, data.z.astype(rdt), + # data.zdn.astype(rdt), data.dzidzup.astype(rdt), + # data.dzidzdn.astype(rdt)) for i, s in enumerate(snaps): data.set_snap(s) tgas[:, i] = np.squeeze(data.tg)[sx, sz] - rho = data.r[sx, sz] - vz[:, i] = np.squeeze(stagger.zup(data.pz)[sx, sz] / rho) * (-uv) + rho = np.squeeze(data.r)[sx, sz] + vz[:, i] = np.squeeze(do_stagger(data.pz, 'zup', obj=data))[sx, sz] / rho * (-uv) if writeB: Bx[:, i] = np.squeeze(data.bx)[sx, sz] * ub By[:, i] = np.squeeze(-data.by)[sx, sz] * ub Bz[:, i] = np.squeeze(-data.bz)[sx, sz] * ub - ne[:, i] = np.squeeze(data.get_electron_density( - sx=sx, sz=sz)).to_value('1/m3') - nH[:, :, i] = np.squeeze( - data.get_hydrogen_pops(sx=sx, sz=sz)).to_value('1/m3') + ne[:, i] = np.squeeze(data.get_electron_density(sx=sx, sz=sz)).to_value('1/m3') + nH[:, :, i] = np.squeeze(data.get_hydrogen_pops(sx=sx, sz=sz)).to_value('1/m3') rh15d.make_xarray_atmos(outfile, tgas, vz, z, nH=nH, ne=ne, x=x, y=y, append=False, Bx=Bx, By=By, Bz=Bz, desc=desc, snap=snaps[0]) -def read_idl_ascii(filename): - ''' Reads IDL-formatted (command style) ascii file into dictionary ''' - li = 0 +@file_memory.remember_and_recall('_memory_read_idl_ascii') +def read_idl_ascii(filename, firstime=False): + ''' Reads IDL-formatted (command style) ascii file into dictionary. + if obj is not None, remember the result and restore it if ever reading the same exact file again. + ''' + li = -1 params = {} # go through the file, add stuff to dictionary with open(filename) as fp: for line in fp: + li += 1 # ignore empty lines and comments - line = line.strip() - if not line: - li += 1 + line, _, comment = line.partition(';') + key, _, value = line.partition('=') + key = key.strip().lower() + value = value.strip() + if len(key) == 0: + continue # this was a blank line. + elif len(value) == 0: + if firstime: + print('(WWW) read_params: line %i is invalid, skipping' % li) continue - if line[0] == ';': - li += 1 - continue - line = line.split(';')[0].split('=') - if len(line) != 2: - print(('(WWW) read_params: line %i is invalid, skipping' % li)) - li += 1 - continue - # force lowercase because IDL is case-insensitive - key = line[0].strip().lower() - value = line[1].strip() - # instead of the insecure 'exec', find out the datatypes - if value.find('"') >= 0: - # string type - value = value.strip('"') - try: - if (value.find(' ') >= 0): - value2 = np.array(value.split()) - if ((value2[0].upper().find('E') >= 0) or ( - value2[0].find('.') >= 0)): - value = value2.astype(np.float) - - except: - value = value - elif (value.find("'") >= 0): - value = value.strip("'") - try: - if (value.find(' ') >= 0): - value2 = np.array(value.split()) - if ((value2[0].upper().find('E') >= 0) or ( - value2[0].find('.') >= 0)): - value = value2.astype(np.float) - except: - value = value - elif (value.lower() in ['.false.', '.true.']): - # bool type + # --- evaluate value --- # + # allow '.false.' or '.true.' for bools + if (value.lower() in ['.false.', '.true.']): value = False if value.lower() == '.false.' else True - elif (value.find('[') >= 0) and (value.find(']') >= 0): - # list type - value = eval(value) - elif (value.upper().find('E') >= 0) or (value.find('.') >= 0): - # float type - value = float(value) else: - # int type + # safely evaluate any other type of value try: - value = int(value) + value = ast.literal_eval(value) except Exception: - print('(WWW) read_idl_ascii: could not find datatype in ' - 'line %i, skipping' % li) - li += 1 - continue + # failed to evaluate. Might be string, or might be int with leading 0's. + try: + value = int(value) + except ValueError: + # failed to convert to int; interpret value as string. + pass # leave value as string without evaluating it. params[key] = value return params -def ionpopulation(rho, nel, tg, elem='h', lvl='1', dens=True): - - print('ionpopulation: reading species %s and level %s' % (elem, lvl)) - fdir = '.' - try: - tmp = find_first_match("*.idl", fdir) - except IndexError: - try: - tmp = find_first_match("*idl.scr", fdir) - except IndexError: - try: - tmp = find_first_match("mhd.in", fdir) - except IndexError: - tmp = '' - print("(WWW) init: no .idl or mhd.in files found." + - "Units set to 'standard' Bifrost units.") - uni = Bifrost_units(filename=tmp) - - totconst = 2.0 * uni.pi * uni.m_electron * uni.k_b / \ - uni.hplanck / uni.hplanck - abnd = np.zeros(len(uni.abnddic)) - count = 0 - - for ibnd in uni.abnddic.keys(): - abnddic = 10**(uni.abnddic[ibnd] - 12.0) - abnd[count] = abnddic * uni.weightdic[ibnd] * uni.amu - count += 1 - - abnd = abnd / np.sum(abnd) - phit = (totconst * tg)**(1.5) * 2.0 / nel - kbtg = uni.ev_to_erg / uni.k_b / tg - n1_n0 = phit * uni.u1dic[elem] / uni.u0dic[elem] * np.exp( - - uni.xidic[elem] * kbtg) - c2 = abnd[uni.atomdic[elem] - 1] * rho - ifracpos = n1_n0 / (1.0 + n1_n0) - - if dens: - if lvl == '1': - return (1.0 - ifracpos) * c2 - else: - return ifracpos * c2 - - else: - if lvl == '1': - return (1.0 - ifracpos) * c2 * (uni.u_r / (uni.weightdic[elem] * - uni.amu)) - else: - return ifracpos * c2 * (uni.u_r / (uni.weightdic[elem] * - uni.amu)) - - -def read_cross_txt(filename): - ''' Reads IDL-formatted (command style) ascii file into dictionary ''' +@file_memory.remember_and_recall('_memory_read_cross_txt', kw_mem=['kelvin']) +def read_cross_txt(filename, firstime=False, kelvin=True): + ''' Reads IDL-formatted (command style) ascii file into dictionary. + tg will be converted to Kelvin, unless kelvin==False. + ''' li = 0 params = {} - count = 0 # go through the file, add stuff to dictionary with open(filename) as fp: for line in fp: @@ -2751,13 +3031,21 @@ def read_cross_txt(filename): li += 1 continue line = line.split(';')[0].split() + if (len(line) == 1): + params['crossunits'] = float(line[0].strip()) + li += 1 + continue + elif not ('crossunits' in params.keys()): + print('(WWW) read_cross: line %i is invalid, missing crossunits, file %s' % (li, filename)) + if (len(line) < 2): - print(('(WWW) read_params: line %i is invalid, skipping' % li)) + if (firstime): + print('(WWW) read_cross: line %i is invalid, skipping, file %s' % (li, filename)) li += 1 continue # force lowercase because IDL is case-insensitive temp = line[0].strip() - cross = line[1].strip() + cross = line[2].strip() # instead of the insecure 'exec', find out the datatypes if ((temp.upper().find('E') >= 0) or (temp.find('.') >= 0)): @@ -2768,11 +3056,12 @@ def read_cross_txt(filename): try: temp = int(temp) except Exception: - print('(WWW) read_idl_ascii: could not find datatype in ' - 'line %i, skipping' % li) + if (firstime): + print('(WWW) read_cross: could not find datatype in ' + 'line %i, skipping' % li) li += 1 continue - if not('tg' in params.keys()): + if not ('tg' in params.keys()): params['tg'] = temp else: params['tg'] = np.append(params['tg'], temp) @@ -2785,11 +3074,12 @@ def read_cross_txt(filename): try: cross = int(cross) except Exception: - print('(WWW) read_idl_ascii: could not find datatype in ' - 'line %i, skipping' % li) + if (firstime): + print('(WWW) read_cross: could not find datatype in ' + 'line %i, skipping' % li) li += 1 continue - if not('el' in params.keys()): + if not ('el' in params.keys()): params['el'] = cross else: params['el'] = np.append(params['el'], cross) @@ -2805,11 +3095,12 @@ def read_cross_txt(filename): try: cross = int(cross) except Exception: - print('(WWW) read_idl_ascii: could not find datatype' - 'in line %i, skipping' % li) + if (firstime): + print('(WWW) read_cross: could not find datatype' + 'in line %i, skipping' % li) li += 1 continue - if not('mt' in params.keys()): + if not ('mt' in params.keys()): params['mt'] = cross else: params['mt'] = np.append(params['mt'], cross) @@ -2825,8 +3116,9 @@ def read_cross_txt(filename): try: cross = int(cross) except Exception: - print('(WWW) read_idl_ascii: could not find datatype' - 'in line %i, skipping' % li) + if (firstime): + print('(WWW) read_cross: could not find datatype' + 'in line %i, skipping' % li) li += 1 continue if not hasattr(params, 'vi'): @@ -2845,8 +3137,9 @@ def read_cross_txt(filename): try: cross = int(cross) except Exception: - print('(WWW) read_idl_ascii: could not find datatype' - 'in line %i, skipping' % li) + if (firstime): + print('(WWW) read_cross: could not find datatype' + 'in line %i, skipping' % li) li += 1 continue if not hasattr(params, 'se'): @@ -2855,6 +3148,10 @@ def read_cross_txt(filename): params['se'] = np.append(params['se'], cross) li += 1 + # convert to kelvin + if kelvin: + params['tg'] *= Bifrost_units(verbose=False).ev_to_k + return params @@ -2906,50 +3203,6 @@ def subs2grph(subsfile): return calc_grph(ab, am) -def threadQuantity(task, numThreads, *args): - # split arg arrays - args = list(args) - - for index in range(np.shape(args)[0]): - args[index] = np.array_split(args[index], numThreads) - - # make threadpool, task = task, with zipped args - pool = ThreadPool(processes=numThreads) - result = np.concatenate(pool.starmap(task, zip(*args))) - return result - - -def threadQuantity_y(task, numThreads, *args): - # split arg arrays - args = list(args) - - for index in range(np.shape(args)[0]): - if len(np.shape(args[index])) == 3: - args[index] = np.array_split(args[index], numThreads, axis=1) - else: - args[index] = np.array_split(args[index], numThreads) - # make threadpool, task = task, with zipped args - pool = ThreadPool(processes=numThreads) - result = np.concatenate(pool.starmap(task, zip(*args)), axis=1) - return result - - -def threadQuantity_z(task, numThreads, *args): - # split arg arrays - args = list(args) - - for index in range(np.shape(args)[0]): - if len(np.shape(args[index])) == 3: - args[index] = np.array_split(args[index], numThreads, axis=2) - else: - args[index] = np.array_split(args[index], numThreads) - - # make threadpool, task = task, with zipped args - pool = ThreadPool(processes=numThreads) - result = np.concatenate(pool.starmap(task, zip(*args)), axis=2) - return result - - def find_first_match(name, path, incl_path=False): ''' This will find the first match, diff --git a/helita/sim/cipmocct.py b/helita/sim/cipmocct.py new file mode 100644 index 00000000..aeefd8ec --- /dev/null +++ b/helita/sim/cipmocct.py @@ -0,0 +1,391 @@ +import os + +import numpy as np +from scipy import interpolate +from scipy.io import readsav as rsav +from scipy.ndimage import rotate + +from . import document_vars +from .load_arithmetic_quantities import * +from .load_noeos_quantities import * +from .load_quantities import * +from .tools import * + + +class Cipmocct: + """ + Class to read cipmocct atmosphere + + Parameters + ---------- + fdir : str, optional + Directory with snapshots. + rootname : str + rootname of the file (wihtout params or vars). + verbose : bool, optional + If True, will print more information. + snap : integer + snapshot number + """ + + def __init__(self, rootname, snap, fdir='./', sel_units='cgs', verbose=True): + + self.rootname = rootname + self.fdir = fdir + self.snap = snap + self.sel_units = sel_units + self.verbose = verbose + self.uni = Cipmocct_units() + + params = rsav(os.path.join(self.fdir, 'params_'+rootname+'.sav')) + if snap == None: + self.x = params['x1'].copy() + self.y = params['x3'].copy() + self.z = params['time'].copy() + self.nx = len(params['x1']) + self.ny = len(params['x3']) + self.nz = len(params['time']) + + if self.sel_units == 'cgs': + self.x *= self.uni.uni['l'] + self.y *= self.uni.uni['l'] + + self.time = params['time'] # No uniform (array) + self.varfile = rsav(os.path.join(self.fdir, 'variables_'+self.rootname+'.sav')) + else: + + self.x = params['x1'].copy() + self.y = params['x3'].copy() + self.z = params['x2'].copy() + + self.nx = len(params['x1']) + self.ny = len(params['x3']) + self.nz = len(params['x2']) + + if self.sel_units == 'cgs': + self.x *= self.uni.uni['l'] + self.y *= self.uni.uni['l'] + self.z *= self.uni.uni['l'] + + self.time = params['time'] # No uniform (array) + + if self.nx > 1: + self.dx1d = np.gradient(self.x) + self.dx = self.dx1d + else: + self.dx1d = np.zeros(self.nx) + self.dx = self.dx1d + + if self.ny > 1: + self.dy1d = np.gradient(self.y) + self.dy = self.dy1d + else: + self.dy1d = np.zeros(self.ny) + self.dy = self.dy1d + + if self.nz > 1: + self.dz1d = np.gradient(self.z) + self.dz = self.dz1d + else: + self.dz1d = np.zeros(self.nz) + self.dz = self.dz1d + + self.transunits = False + + self.cstagop = False # This will not allow to use cstagger from Bifrost in load + self.hion = False # This will not allow to use HION from Bifrost in load + self.genvar() + + document_vars.create_vardict(self) + document_vars.set_vardocs(self) + + def get_var(self, var, *args, snap=None, iix=None, iiy=None, iiz=None, layout=None, **kargs): + ''' + Reads the variables from a snapshot (it). + + Parameters + ---------- + var - string + Name of the variable to read. Must be Bifrost internal names. + snap - integer, optional + Snapshot number to read. By default reads the loaded snapshot; + if a different number is requested, will load that snapshot. + Axes: + ----- + z-axis is along the loop + x and y axes are perperdicular to the loop + + Variable list: + -------------- + ro_cube -- Density (multipy by self.uni['rho'] to get in g/cm^3) + te_cube -- Temperature (multipy by self.uni['tg'] to get in K) + vx_cube -- component x of the velocity (multipy by self.uni['u'] to get in cm/s) + vy_cube -- component y of the velocity (multipy by self.uni['u'] to get in cm/s) + vz_cube -- component z of the velocity (multipy by self.uni['u'] to get in cm/s) + bx_cube -- component x of the magnetic field (multipy by self.uni['b'] to get in G) + by_cube -- component y of the magnetic field (multipy by self.uni['b'] to get in G) + bz_cube -- component z of the magnetic field (multipy by self.uni['b'] to get in G) + ''' + + if snap != None: + self.snap = snap + + if var in self.varn.keys(): + varname = self.varn[var] + else: + varname = var + + try: + + if self.sel_units == 'cgs': + varu = var.replace('x', '') + varu = varu.replace('y', '') + varu = varu.replace('z', '') + if (var in self.varn.keys()) and (varu in self.uni.uni.keys()): + cgsunits = self.uni.uni[varu] + else: + cgsunits = 1.0 + else: + cgsunits = 1.0 + + if self.snap == None: + varfile = self.varfile + self.data = np.transpose(varfile[varname]) * cgsunits + + else: + itname = '{:04d}'.format(self.snap) + varfile = rsav(self.fdir+'vars_'+self.rootname+'_'+itname+'.sav') + self.data = np.transpose(varfile[varname]) * cgsunits + + +# varfile = rsav(os.path.join(self.fdir,self.rootname+'_'+itname+'.sav')) + + except: + # Loading quantities + if self.verbose: + print('Loading composite variable', end="\r", flush=True) + self.data = load_noeos_quantities(self, var, **kargs) + + if np.shape(self.data) == (): + self.data = load_quantities(self, var, PLASMA_QUANT='', CYCL_RES='', + COLFRE_QUANT='', COLFRI_QUANT='', IONP_QUANT='', + EOSTAB_QUANT='', TAU_QUANT='', DEBYE_LN_QUANT='', + CROSTAB_QUANT='', COULOMB_COL_QUANT='', AMB_QUANT='', + HALL_QUANT='', BATTERY_QUANT='', SPITZER_QUANT='', + KAPPA_QUANT='', GYROF_QUANT='', WAVE_QUANT='', + FLUX_QUANT='', CURRENT_QUANT='', COLCOU_QUANT='', + COLCOUMS_QUANT='', COLFREMX_QUANT='', **kargs) + + # Loading arithmetic quantities + if np.shape(self.data) == (): + if self.verbose: + print('Loading arithmetic variable', end="\r", flush=True) + self.data = load_arithmetic_quantities(self, var, **kargs) + + if document_vars.creating_vardict(self): + return None + elif var == '': + print(help(self.get_var)) + print('VARIABLES USING CGS OR GENERIC NOMENCLATURE') + for ii in self.varn: + print('use ', ii, ' for ', self.varn[ii]) + if hasattr(self, 'vardict'): + self.vardocs() + + return None + + return self.data + + def genvar(self): + ''' + Dictionary of original variables which will allow to convert to cgs. + ''' + self.varn = {} + self.varn['rho'] = 'ro_cube' + self.varn['tg'] = 'te_cube' + self.varn['ux'] = 'vx_cube' + self.varn['uy'] = 'vz_cube' + self.varn['uz'] = 'vy_cube' + self.varn['bx'] = 'bx_cube' + self.varn['by'] = 'bz_cube' + self.varn['bz'] = 'by_cube' + + def trans2comm(self, varname, snap=None, angle=0, loop=3): + ''' + Transform the domain into a "common" format. All arrays will be 3D. The 3rd axis + is: + + - for 3D atmospheres: the vertical axis + - for loop type atmospheres: along the loop + - for 1D atmosphere: the unique dimension is the 3rd axis. + At least one extra dimension needs to be created artifically. + + All of them should obey the right hand rule + + In all of them, the vectors (velocity, magnetic field etc) away from the Sun. + + If applies, z=0 near the photosphere. + + Units: everything is in cgs. + + If an array is reverse, do ndarray.copy(), otherwise pytorch will complain. + + INPUT: + varname - string + snap - integer + angle - real (degrees). Any number -90 to 90, default = 45 + ''' + + self.sel_units = 'cgs' + + self.trans2commaxes(loop) + + if angle != 0: + if varname[-1] in ['x']: + varx = self.get_var(varname, snap=snap) + vary = self.get_var(varname[0]+'y', snap=snap) + var = varx * np.cos(angle/90.0*np.pi/2.0) - vary * np.sin(angle/90.0*np.pi/2.0) + elif varname[-1] in ['y']: + vary = self.get_var(varname, snap=snap) + varx = self.get_var(varname[0]+'x', snap=snap) + var = vary * np.cos(angle/90.0*np.pi/2.0) + varx * np.sin(angle/90.0*np.pi/2.0) + else: + var = self.get_var(varname, snap=snap) + var = rotate(var, angle=angle, reshape=False, mode='nearest', axes=(0, 1)) + + else: + var = self.get_var(varname, snap=snap) + + if loop != None: + if varname[-1] in ['x']: + var = self.make_loop(var, loop) + varz = self.get_var(varname[0]+'z', snap=snap) + varz = self.make_loop(varz, loop) + xx, zz = np.meshgrid(self.x, self.z) + aa = np.angle(xx+1j*zz) + for iiy, iy in enumerate(self.y): + var[:, iiy, :] = var[:, iiy, :] * np.cos(aa.T) - varz[:, iiy, :] * np.sin(aa.T) + elif varname[-1] in ['z']: + var = self.make_loop(var, loop) + varx = self.get_var(varname[0]+'x', snap=snap) + varx = self.make_loop(varx, loop) + xx, zz = np.meshgrid(self.x, self.z) + aa = np.angle(xx+1j*zz) + for iiy, iy in enumerate(self.y): + var[:, iiy, :] = var[:, iiy, :] * np.cos(aa.T) + varx[:, iiy, :] * np.sin(aa.T) + else: + var = self.make_loop(var, loop) + + return var + + def make_loop(self, var, loop): + R = np.max(self.z*2)/np.pi/2. + rad = self.x_orig+np.max(self.x_loop)-np.max(self.x_orig)/2 + angl = self.z_orig / R + var_new = np.zeros((self.nx, self.ny, self.nz)) + iiy0 = np.argmin(np.abs(self.y_orig)) + + for iiy, iy in enumerate(self.y): + temp = var[:, iiy*2+iiy0, :] + data = polar2cartesian(rad, angl, temp, self.z, self.x) + var_new[:, iiy, :] = data + return var_new + + def trans2commaxes(self, loop=3): + + if self.transunits == False: + self.x_orig = self.x + self.y_orig = self.y + self.z_orig = self.z + if loop != None: + R = np.max(self.z*2)/np.pi/2. + self.x_loop = np.linspace(R*np.cos([np.pi/loop]), R, + int((R-R*np.cos([np.pi/loop]))/2/np.min(self.dx1d))) + self.z_loop = np.linspace(0, R*np.sin([np.pi/loop]), + int(R*np.sin([np.pi/loop])/2/np.min(self.dx1d))) + + self.x = self.x_loop.squeeze() + self.z = self.z_loop.squeeze() + self.y = self.y[np.argmin(np.abs(self.y))+1::2] + + self.dx1d = np.gradient(self.x) + self.dy1d = np.gradient(self.y) + self.dz1d = np.gradient(self.z) + self.nx = np.size(self.x) + self.ny = np.size(self.y) + self.nz = np.size(self.z) + + self.transunits = True + + def trans2noncommaxes(self): + + if self.transunits == True: + self.x = self.x_orig + self.y = self.y_orig + self.z = self.z_orig + self.dx1d = np.gradient(self.x) + self.dy1d = np.gradient(self.y) + self.dz1d = np.gradient(self.z) + self.nx = np.size(self.x) + self.ny = np.size(self.y) + self.nz = np.size(self.z) + # opposite to the previous function + self.transunits = False + + +class Cipmocct_units(object): + + def __init__(self, verbose=False): + import scipy.constants as const + from astropy import constants as aconst + + ''' + Units and constants in cgs + ''' + self.uni = {} + self.verbose = verbose + self.uni['gamma'] = 5./3. + self.uni['proton'] = 1.67262158e-24 # g + self.uni['tg'] = 1.0e6 # K + self.uni['fact'] = 2 + self.uni['l'] = 1000.*self.uni['fact']*1.0e5 # for having a 2000 km wide loop + self.uni['n'] = 1.0e9 # cm^-3 + + # Units and constants in SI + globalvars(self) + + self.uni['rho'] = self.uni['n'] * self.uni['proton'] / 2. # gr cm^-3 + self.uni['u'] = np.sqrt(2*self.uni['gamma']*self.k_b/self.m_p*self.uni['tg']) # cm/s + + self.uni['b'] = self.uni['u']*np.sqrt(self.uni['rho']) # Gauss + self.uni['j'] = self.uni['b']/self.uni['l'] * aconst.c.to_value('cm/s') # current density + self.uni['t'] = self.uni['l']/self.uni['u'] # seconds + + convertcsgsi(self) + self.unisi['ee'] = self.unisi['u']**2 + self.unisi['e'] = self.unisi['rho'] * self.unisi['ee'] + self.unisi['pg'] = self.unisi['rho'] * (self.unisi['l'] / self.unisi['t'])**2 + self.unisi['u'] = self.uni['u'] * const.centi # m/s + + +def polar2cartesian(r, t, grid, x, y, order=3): + ''' + Converts polar grid to cartesian grid + ''' + from scipy import ndimage + + X, Y = np.meshgrid(x, y) + + new_r = np.sqrt(X * X + Y * Y) + new_t = np.arctan2(X, Y) + + ir = interpolate.interp1d(r, np.arange(len(r)), bounds_error=False, fill_value=0.0) + it = interpolate.interp1d(t, np.arange(len(t)), bounds_error=False, fill_value=0.0) + new_ir = ir(new_r.ravel()) + new_it = it(new_t.ravel()) + + new_ir[new_r.ravel() > r.max()] = len(r) - 1 + new_ir[new_r.ravel() < r.min()] = 0 + + return ndimage.map_coordinates(grid, np.array([new_ir, new_it]), + order=order).reshape(new_r.shape) diff --git a/helita/sim/cstagger.pyx b/helita/sim/cstagger.pyx new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/helita/sim/cstagger.pyx @@ -0,0 +1 @@ + diff --git a/helita/sim/document_vars.py b/helita/sim/document_vars.py new file mode 100644 index 00000000..b61b1dd6 --- /dev/null +++ b/helita/sim/document_vars.py @@ -0,0 +1,919 @@ +""" +Created by Sam Evans on Apr 3 2021 + +Purpose: helper functions for documentation of variables. + +create vardict which looks like: +vardict = { + meta_quant_1 : # example: "mf_quantities" + { + QUANTDOC : 'meta_quant_1 description', + TYPE_QUANT_1 : # example: "GLOBAL_QUANT" + { + QUANTDOC : 'TYPE_QUANT_1 description', + # '_DOC_QUANT' : 'global variables; calculated by looping through species', # example + mq1tq1_var_1 : 'mq1tq1_var_1 description', + # 'nel' : 'electron number density [cm^-3]', # example + mq1tq1_var_2 : 'mq1tq1_var_2 description', + ... + }, + TYPE_QUANT_2 : # example: "PLASMA_QUANT" + { + QUANTDOC : 'TYPE_QUANT_2 description', + mq1tq2_var_1 : 'mq1tq2_var_1 description', + ... + }, + ... + }, + meta_quant_2 : # example "arithmetic_quantities" + { + QUANTDOC : 'meta_quant_2 description', + TYPE_QUANT_1 : + { + QUANTDOC : 'TYPE_QUANT_2 description', + mq2tq1_var_1 : 'mq2tq1_var_1 description', + ... + }, + ... + }, + ... +} + +""" + +import copy # for deepcopy for QuantTree +# import built-ins +import math # for pretty strings +import functools +import collections + +# import internal modules +from . import units # not used heavily; just here for setting defaults, and setting obj.get_units +from . import tools + +VARDICT = 'vardict' # name of attribute (of obj) which should store documentation about vars. +NONEDOC = '(not yet documented)' # default documentation if none is provided. +QUANTDOC = '_DOC_QUANT' # key for dd.vardict[TYPE_QUANT] containing doc for what TYPE_QUANT means. +NFLUID = 'nfluid' # key which stores number of fluids. (e.g. 0 for none; 1 for "uses ifluid but not jfluid". +CREATING_VARDICT = '_creating_vardict' # attribute of obj which tells if we are running get_var('') to create vardict. + +# defaults for quant tracking +# attributes of obj +VARNAME_INPUT = '_varname_input' # stores name of most recent variable which was input to get_var. +QUANT_SELECTED = '_quant_selected' # stores vardict lookup info for the latest quant selected. +QUANTS_SELECTED = '_quants_selected' # stores quant_selected for QUANT_NTRACKING recent quants. +QUANTS_BY_LEVEL = '_quants_by_level' # stores quant_selected by level. +QUANTS_TREE = '_quants_tree' # stores quant_selected as a tree. +QUANT_SELECTION = '_quant_selection' # stores info for latest quant selected; use for hesitant setting. +QUANT_NTRACKING = '_quant_ntracking' # if it exists, sets maxlen for _quants_selected deque. + +# misc +QUANT_TRACKING_N = 1000 # default for number of quant selections to remember. +QUANT_BY_LEVEL_N = 50 # default for number of quant selections to remember at each level. +QUANT_NOT_FOUND = '???' # default for typequant and metaquant when quant is not found. Do not use None. + +# defaults for loading level tracking +# attribute of obj which tells how deep we are into loading a quantity right now. 0 = top level. +LOADING_LEVEL = '_loading_level' + + +HIDE_DECORATOR_TRACEBACKS = True # whether to hide decorators from this file when showing error traceback. + +# global variable which tells which quantity you are setting now. +METAQUANT = None + +''' ----------------------------- create vardict ----------------------------- ''' + + +def set_meta_quant(obj, name, QUANT_DOC=NONEDOC): + '''sets the current "meta_quant". You must use this before starting documentation. + see load_mf_quantities.load_mf_quantities for an example. + + QUANT_DOC is the documentation to put about this metaquant. + for example, in load_mf_quantities.load_mf_quantities, + set_meta_quant('MULTIFLUID_QUANTITIES', 'These are the multiple-fluid quantities.') + + The idea is that the meta_quant will be the same throughout a given load_*_quantities.py file. + ''' + if not hasattr(obj, VARDICT): + setattr(obj, VARDICT, dict()) + vardict = getattr(obj, VARDICT) + + global METAQUANT # allows to edit the value of document_vars.METAQUANT + METAQUANT = name + + if METAQUANT not in vardict.keys(): + vardict[METAQUANT] = dict() + vardict[METAQUANT][QUANTDOC] = QUANT_DOC + + +def vars_documenter(obj, TYPE_QUANT, QUANT_VARS=None, QUANT_DOC=NONEDOC, nfluid=None, rewrite=False, **kw__defaults): + '''function factory; returns function(varname, vardoc, nfluid=None) which writes documentation of var. + The documentation goes to vd['doc'] where vd = obj.vardict[METAQUANT][TYPE_QUANT][varname]. + + Also store vd['nfluid'] = nfluid. + vars_documenter(...,nfluid) -> store as default + f = vars_documenter(); f(var, doc, nfluid) -> store for this var, only. + nfluid = + None -> does not even understand what a "fluid" is. (Use this value outside of load_mf_quantities.py) + Or, if in mf_quantities, None indicates nfluid has not been documented for this var. + 2 -> uses obj.ifluid and obj.jfluid to calculate results. (e.g. 'nu_ij') + 1 -> uses obj.ifluid (and not jfluid) to calculate results. (e.g. 'ux', 'tg') + 0 -> does not use ifluid nor jfluid to calculate results. (e.g. 'bx', 'nel', 'tot_e') + + METAQUANT (i.e. document_vars.METAQUANT) must be set before using vars_documenter; + use document_vars.set_meta_quant() to accomplish this. + Raises ValueError if METAQUANT has not been set. + + if QUANT_VARS is not None: + initialize documentation of all the vars in varnames with vardoc=NONEDOC. + enforce that only vars in QUANT_VARS can be documented (ignore documentation for all vars not in QUANT_DOC). + + if not rewrite, and TYPE_QUANT already in obj.vardict[METAQUANT].keys() (when vars_documenter is called), + instead do nothing and return a function which does nothing. + + kw__defaults become default values for all quants in QUANT_VARS. + Example: + docvar = vars_documenter(obj, typequant, ['var1', 'var2', 'var3'], foo_kwarg='bar') + docvar('var1', 'info about var 1') + docvar('var2', 'info about var 2', foo_kwarg='overwritten') + docvar('var3', 'info about var 3') + Leads to obj.vardict[METAQUANT][typequant] like: + {'var1': {'doc':'info about var 1', 'foo_kwarg':'bar'}, + 'var2': {'doc':'info about var 2', 'foo_kwarg':'overwritten'}, + 'var3': {'doc':'info about var 3', 'foo_kwarg':'bar'}} + + also sets obj.vardict[METAQUANT][TYPE_QUANT][document_vars.QUANTDOC] = QUANT_DOC. + ''' + if METAQUANT is None: + raise ValueError('METAQUANT cannot be None when calling vars_documenter. ' + + 'Use document_vars.set_meta_quant() to set METAQUANT.') + vardict = getattr(obj, VARDICT)[METAQUANT] # sets vardict = obj.vardict[METAQUANT] + write = rewrite + if not TYPE_QUANT in vardict.keys(): + vardict[TYPE_QUANT] = dict() + vardict[TYPE_QUANT][QUANTDOC] = QUANT_DOC + write = True + if write: + # define function (which will be returned) + def document_var(varname, vardoc, nfluid=nfluid, copy=False, **kw__more_info_about_var): + '''puts documentation about var named varname into obj.vardict[TYPE_QUANT]. + copy: bool, default False + False --> ignore this kwarg. + True --> instead of usual behavior, look up vardoc in vardict and copy the info to varname. + Example: + document_var('myvar', 'the documentation for myvar', nfluid=N, **kw_original) + # now, these two options are equivalent: + 1) document_var('myvar_alias', 'the documentation for myvar', nfluid=N, **kw_original) + 2) document_var('myvar_alias', 'myvar', copy=True) + ''' + if (QUANT_VARS is not None) and (varname not in QUANT_VARS): + return + + tqd = vardict[TYPE_QUANT] + if not copy: # case "basic usage" (copy=False) + var_info_dict = {'doc': vardoc, 'nfluid': nfluid, **kw__defaults, **kw__more_info_about_var} + else: # case "aliasing" (copy=True) + var_info_dict = tqd[vardoc] # obj(varname) == obj(args[1]). (but for "basic" case, args[1]==vardoc.) + try: + vd = tqd[varname] # vd = vardict[TYPE_QUANT][varname] (if it already exists) + except KeyError: # else, initialize tqd[varname]: + tqd[varname] = var_info_dict + else: # if vd assignment was successful, set info. + vd.update(var_info_dict) + + # initialize documentation to NONEDOC for var in QUANT_VARS + if QUANT_VARS is not None: + for varname in QUANT_VARS: + document_var(varname, vardoc=NONEDOC, nfluid=nfluid) + + # return document_var function which we defined. + return document_var + else: + # do nothing and return a function which does nothing. + def dont_document_var(varname, vardoc, nfluid=None, **kw): + '''does nothing. + (because obj.vardict[TYPE_QUANT] already existed when vars_documenter was called). + ''' + return + return dont_document_var + + +def create_vardict(obj): + '''call obj.get_var('') but with prints turned off. + Afterwards, obj.vardict will be full of documentation. + + Also, set a few more things (conceptually these belong elsewhere, + but it is convenient to do them here because create_vardict is called in __init__ for all the DataClass objects) : + set obj.gotten_vars() to a function which returns obj._quants_selected. + set obj.got_vars_tree() to a function which returns obj._quants_tree. + set obj.quant_lookup() to a function which returns dict of info about quant, as found in obj.vardict. + ''' + # creat vardict + setattr(obj, CREATING_VARDICT, True) + obj.get_var('') + setattr(obj, CREATING_VARDICT, False) + # set some other useful functions in obj. + + def _make_weak_bound_method(f): + @functools.wraps(f) + def _weak_bound_method(*args, **kwargs): + __tracebackhide__ = HIDE_DECORATOR_TRACEBACKS + return f(obj, *args, **kwargs) # << obj which was passed to create_vardict + return _weak_bound_method + obj.gotten_vars = _make_weak_bound_method(gotten_vars) + obj.got_vars_tree = _make_weak_bound_method(got_vars_tree) + obj.get_quant_info = _make_weak_bound_method(get_quant_info) + obj.get_var_info = obj.get_quant_info # alias + obj.quant_lookup = _make_weak_bound_method(quant_lookup) + obj.get_units = _make_weak_bound_method(units.get_units) + + +def creating_vardict(obj, default=False): + '''return whether obj is currently creating vardict. If unsure, return .''' + return getattr(obj, CREATING_VARDICT, default) + + +''' ----------------------------- search vardict ----------------------------- ''' + + +def _apply_keys(d, keys): + '''result result of successive application of (key for key in in keys) to dict of dicts, d.''' + for key in keys: + d = d[key] + return d + + +search_result = collections.namedtuple('vardict_search_result', ('result', 'type', 'keys')) + + +def search_vardict(vardict, x): + '''search vardict for x. x is the key we are looking for. + + return search_result named tuple. its attributes give (in order): + result: the dict which x is pointing to. + type: None or a string: + None (vardict itself) + 'metaquant' (top-level) # a dict of typequants + 'typequant' (middle-level) # a dict of vars + 'var' (bottom-level) # a dict with keys 'doc' (documentation) and 'nfluid' + keys: the list of keys to apply to vardict to get to result. + when type is None, keys is []; + when type is metaquant, keys is [x] + when type is typequant, keys is [metaquantkey, x] + when type is 'var', keys is [metaquantkey, typequantkey, x] + + return False if failed to find x in vardict. + ''' + v = vardict + if x is None: + return search_result(result=v, type=None, keys=[]) + if x in v.keys(): + return search_result(result=v[x], type='metaquant', keys=[x]) + for metaquant in vardict.keys(): + v = vardict[metaquant] + if not isinstance(v, dict): + continue # skip QUANTDOC + if x in v.keys(): + return search_result(result=v[x], type='typequant', keys=[metaquant, x]) + for metaquant in vardict.keys(): + for typequant in vardict[metaquant].keys(): + v = vardict[metaquant][typequant] + if not isinstance(v, dict): + continue # skip QUANTDOC + if x in v.keys(): + return search_result(result=v[x], type='var', keys=[metaquant, typequant, x]) + return False + + +''' ----------------------------- prettyprint vardict ----------------------------- ''' + +TW = 3 # tabwidth +WS = ' ' # whitespace + + +def _underline(s, underline='-', minlength=0): + '''return underlined s''' + if len(underline.strip()) == 0: + return s + line = underline * math.ceil(max(len(s), minlength)/len(underline)) + return s + '\n' + line + + +def _intro_line(text, length=80): + '''return fancy formatting of text as "intro line".''' + left, right = '(<< ', ' >>)' + length = max(0, (length - len(left) - len(right))) + fmtline = '{:^' + str(length) + '}' # {:^N} makes for line which is N long, and centered. + return (left + fmtline + right).format(text) + + +def _vardocs_var(varname, vd, q=WS*TW*2): + '''docs for vd (var_dict). returns list containing one string, or None (if undocumented)''' + vardoc = vd['doc'] + if vardoc is NONEDOC: + return None + else: + nfluid = vd['nfluid'] + rstr = q + '{:10s}'.format(varname) + ' : ' + if nfluid is not None: + rstr += '(nfluid = {}) '.format(nfluid) + rstr += str(vardoc) + return [rstr] + + +def _vardocs_typequant(typequant_dict, tqd=WS*TW, q=WS*TW*2, ud=WS*TW*3): + '''docs for typequant_dict. returns list of strings, each string is one line.''' + result = [] + if QUANTDOC in typequant_dict.keys(): + s = str(typequant_dict[QUANTDOC]).lstrip().replace('\n', tqd+'\n') + s = s.rstrip() + '\n' # make end have exactly 1 newline. + result += [tqd + s] + undocumented = [] + for varname in (key for key in sorted(typequant_dict.keys()) if key != QUANTDOC): + vd = typequant_dict[varname] + vdv = _vardocs_var(varname, vd, q=q) + if vdv is None: + undocumented += [varname] + else: + result += vdv + if undocumented != []: + result += ['\n' + q + 'existing but undocumented vars:\n' + ud + ', '.join(undocumented)] + return result + + +def _vardocs_metaquant(metaquant_dict, underline='-', + mqd=''*TW, tq=WS*TW, tqd=WS*TW, q=WS*TW*2, ud=WS*TW*3): + '''docs for metaquant_dict. returns list of strings, each string is one line.''' + result = [] + if QUANTDOC in metaquant_dict.keys(): + result += [mqd + str(metaquant_dict[QUANTDOC]).lstrip().replace('\n', mqd+'\n')] + for typequant in (key for key in sorted(metaquant_dict.keys()) if key != QUANTDOC): + result += ['', _underline(tq + typequant, underline)] + typequant_dict = metaquant_dict[typequant] + result += _vardocs_typequant(typequant_dict, tqd=tqd, q=q, ud=ud) + return result + + +def _vardocs_print(result, printout=True): + '''x = '\n'.join(result). if printout, print x. Else, return x.''' + stresult = '\n'.join(result) + if printout: + print(stresult) + else: + return stresult + + +def set_vardocs(obj, printout=True, underline='-', min_mq_underline=80, + mqd=''*TW, tq=WS*TW, tqd=WS*TW, q=WS*TW*2, ud=WS*TW*3): + '''make obj.vardocs be a function which prints vardict in pretty format. + (return string instead if printout is False.) + mqd, tq, tqd are indents for metaquant_doc, typequant, typequant_doc, + q, ud are indents for varname, undocumented vars + + also make obj.vardoc(x) print doc for x, only. + x can be a var, typequant, metaquant, or None (equivalent to vardocs if None). + ''' + def vardocs(printout=True): + '''prettyprint docs. If printout is False, return string instead of printing.''' + result = [ + 'Following is documentation for vars compatible with self.get_var(var).', + _intro_line('Documentation contents available in dictionary form via self.{}'.format(VARDICT)), + _intro_line('Documentation string available via self.vardocs(printout=False)'), + ] + vardict = getattr(obj, VARDICT) + for metaquant in sorted(vardict.keys()): + result += ['', '', _underline(metaquant, underline, minlength=min_mq_underline)] + metaquant_dict = vardict[metaquant] + result += _vardocs_metaquant(metaquant_dict, underline=underline, + mqd=mqd, tq=tq, tqd=tqd, q=q, ud=ud) + return _vardocs_print(result, printout=printout) + + obj.vardocs = vardocs + + def vardoc(x=None, printout=True): + '''prettyprint docs for x. x can be a var, typequant, metaquant, or None. + + default x is None; when x is None, this function is equivalent to vardocs(). + + If printout is False, return string instead of printing. + ''' + search = search_vardict(obj.vardict, x) + if search == False: + result = ["key '{}' does not exist in obj.vardict!".format(x)] + return _vardocs_print(result, printout) + # else: search was successful. + if search.type is None: + return vardocs(printout=printout) + # else: search actually did something nontrivial. + keystr = ''.join(["['{}']".format(key) for key in search.keys]) + result = ['vardoc for {}, accessible via obj.vardict{}'.format(x, keystr)] + if search.type == 'metaquant': + result += _vardocs_metaquant(search.result, underline=underline, + mqd=mqd, tq=tq, tqd=tqd, q=q, ud=ud) + elif search.type == 'typequant': + result += _vardocs_typequant(search.result, tqd=tqd, q=q, ud=ud) + elif search.type == 'var': + vdv = _vardocs_var(x, search.result, q=q) + result += [WS*TW*2 + NONEDOC] if vdv is None else vdv + return _vardocs_print(result, printout) + + obj.vardoc = vardoc + + def _search_vardict(x): + '''searches self.vardict for x. x can be a var, typequant, metaquant, or None.''' + vardict = getattr(obj, VARDICT) + return search_vardict(vardict, x) + + obj.search_vardict = _search_vardict + + +''' ----------------------------- quant tracking ----------------------------- ''' + +QuantInfo = collections.namedtuple('QuantInfo', ('varname', 'quant', 'typequant', 'metaquant', 'level'), + defaults=[None, None, None, None, None]) + + +def setattr_quant_selected(obj, quant, typequant, metaquant=None, varname=None, level=None, delay=False): + '''sets QUANT_SELECTED to QuantInfo(varname, quant, typequant, metaquant, level). + + varname = name of var which was input. + default (if None): getattr(obj, VARNAME_INPUT, None) + quant = name of quant which matches var. + (e.g. '2' maches 'b2'; see get_square from load_arithmetic_quantities.) + typequant = type associated with quant + e.g. 'SQUARE_QUANT' + metaquant = metatype associated with quant + e.g. 'arquantities' + default (if None): METAQUANT (global variable in document_vars module, set in load_..._quantities files). + level = loading_level + i.e. number of layers deep right now in the chain of get_var call(s). + default (if None): getattr(obj, LOADING_LEVEL, 0) + + if metaquant is None, use helita.sim.document_vars.METAQUANT as default. + returns the value in obj._quant_selected. + + if delay, set QUANT_SELECTION instead of QUANT_SELECTED. + (if delay, it is recommended to later call quant_select_selection to update.) + QUANT_SELECTION is maintained by document_vars.quant_tracking_top_level() wrapper + ''' + if varname is None: + varname = getattr(obj, VARNAME_INPUT, None) + if metaquant is None: + metaquant = METAQUANT + if level is None: + level = getattr(obj, LOADING_LEVEL, 0) + info = QuantInfo(varname=varname, quant=quant, typequant=typequant, metaquant=metaquant, level=level) + if delay: + setattr(obj, QUANT_SELECTION, info) + else: + setattr(obj, QUANT_SELECTED, info) + _track_quants_selected(obj, info) + return info + + +def _track_quants_selected(obj, info, maxlen=QUANT_TRACKING_N): + '''updates obj._quants_selected with info. + if _quants_selected attr doesn't exist, make a deque. + + maxlen for deque will be obj._quant_ntracking if it exists; else value of maxlen kwarg. + + Also, updates obj._quants_by_level with info. (same info; different format.) + ''' + # put info into QUANTS_SELECTED + if hasattr(obj, QUANTS_SELECTED): + getattr(obj, QUANTS_SELECTED).appendleft(info) + else: + maxlen = getattr(obj, QUANT_NTRACKING, maxlen) # maxlen kwarg is default value. + setattr(obj, QUANTS_SELECTED, collections.deque([info], maxlen)) + + # put info into QUANTS_BY_LEVEL + loading_level = getattr(obj, LOADING_LEVEL, 0) + if not hasattr(obj, QUANTS_BY_LEVEL): + setattr(obj, QUANTS_BY_LEVEL, + {loading_level: collections.deque([info], maxlen=QUANT_BY_LEVEL_N)}) + else: + qbl_dict = getattr(obj, QUANTS_BY_LEVEL) + try: + qbl_dict[loading_level].appendleft(info) + except KeyError: + qbl_dict[loading_level] = collections.deque([info], maxlen=QUANT_BY_LEVEL_N) + + # put info in QUANTS_TREE + if hasattr(obj, QUANTS_TREE): + getattr(obj, QUANTS_TREE).set_data(info) + + # return QUANTS_SELECTED + return getattr(obj, QUANTS_SELECTED) + + +def select_quant_selection(obj, info_default=None): + '''puts data from QUANT_SELECTION into QUANT_SELECTED. + Also, updates QUANTS_SELECTED with the info. + + Recommended to only use after doing setattr_quant_selected(..., delay=True). + ''' + info = getattr(obj, QUANT_SELECTION, info_default) + setattr(obj, QUANT_SELECTED, info) + _track_quants_selected(obj, info) + return info + + +def quant_tracking_simple(typequant, metaquant=None): + '''returns a function dectorator which turns f(obj, quant, *args, **kwargs) into: + result = f(...) + if result is not None: + obj._quant_selected = QuantInfo(quant, typequant, metaquant) + return result + if metaquant is None, use helita.sim.document_vars.METAQUANT as default. + + Suggested use: + use this wrapper for any get_..._quant functions whose results correspond to + the entire quant. For examples see helita.sim.load_mf_quantities.py (or load_quantities.py). + + do NOT use this wrapper for get_..._quant functions whose results correspond to + only a PART of the quant entered. For example, don't use this for get_square() in + load_arithmetic_quantities, since that function pulls a '2' off the end of quant, + E.g. get_var('b2') --> bx**2 + by**2 + bz**2. + For this case, use settattr_quant_selected. (See load_arithmetic_quantities.py for examples) + ''' + def decorator(f): + @functools.wraps(f) + def f_but_quant_tracking(obj, quant, *args, **kwargs): + __tracebackhide__ = HIDE_DECORATOR_TRACEBACKS + # we need to save original METAQUANT now, because doing f might change METAQUANT. + # (quant_tracking_simple is meant to wrap a function inside a load_..._quantities file, + # so when that function (f) is called, METAQUANT will be the correct value.) + if metaquant is None: + remembered_metaquant = METAQUANT + else: + remembered_metaquant = metaquant + # call f + result = f(obj, quant, *args, **kwargs) + # set quant_selected + if result is not None: + setattr_quant_selected(obj, quant, typequant, remembered_metaquant) + # return result of f + return result + return f_but_quant_tracking + return decorator + + +class QuantTree: + '''use for tree representation of quants. + + Notes: + - level should always be larger for children than for their parents. + ''' + + def __init__(self, data, level=-1): + self.data = data + self.children = [] + self._level = level + self.hide_level = None + + def add_child(self, child, adjusted_level=False): + '''add child to self. + + If child is QuantTree and adjusted_level=True, + instead append a copy of child, with its _level adjusted to self._level + 1 + ''' + if isinstance(child, QuantTree): + if adjusted_level: + child = child.with_adjusted_base_level(self._level + 1) + self.children.append(child) # newest child goes at end of list. + else: + child = QuantTree(child, level=self._level + 1) + self.add_child(child) + return child + + def set_data(self, data): + self.data = data + + def __str__(self): + lvlstr = ' '*self._level + '(L{level}) '.format(level=self._level) + # check hide level. if level >= hide_level, hide. + if self.hide_level is not None: + if self._level >= self.hide_level: + return (lvlstr + '{}').format(repr(self)) + # << if I reach this line it means I am not hiding myself. + # if no children, return string with level and data. + if len(self.children) == 0: + return (lvlstr + '{data}').format(data=self.data) + # else, we have children, so return a string with level, data, and children + + def _child_to_str(child): + return '\n' + child.str(self.hide_level, count_from_here=False) + children_strs = ','.join([_child_to_str(child) for child in self.children]) + return (lvlstr + '{data} : {children}').format(data=self.data, children=children_strs) + + def __repr__(self): + if isinstance(self.data, QuantInfo): + qi_str = "(varname='{}', quant='{}')".format(self.data.varname, self.data.quant) + else: + qi_str = '' + fmtdict = dict(qi_str=qi_str, hexid=hex(id(self)), Nnode=1 + self.count_descendants()) + return '< with {Nnode} nodes.>'.format(**fmtdict) + + def str(self, hide_level=None, count_from_here=True): + '''sets self.hide_level, returns str(self). + restores self.hide_level to its original value before returning result. + + if count_from_here, only hides after getting hide_level more levels deep than we are now. + (e.g. if count_from_here, and self is level 7, and hide_level is 3: hides level 10 and greater.) + ''' + orig_hide = self.hide_level + if count_from_here and (hide_level is not None): + hide_level = hide_level + self._level + self.hide_level = hide_level + try: + return self.__str__() + finally: + self.hide_level = orig_hide + + def get_child(self, i_child=0, oldest_first=True): + '''returns the child determined by index i_child. + + self.get_child() (with no args/kwargs entered) returns the oldest child (the child added first). + + Parameters + ---------- + i_child: int (default 0) + index of child to get. See oldest_first kwarg for ordering convention. + oldest_first: True (default) or False. + Determines the ordering convention for children: + True --> sort from oldest to youngest (0 is the oldest child (added first)). + Equivalent to self.children[i_child] + False -> sort from youngest to oldest (0 is the youngest child (added most-recently)). + Equivalent to self.children[::-1][i_child] + ''' + if oldest_first: + return self.children[i_child] + else: + return self.children[::-1][i_child] + + def set_base_level(self, level): + '''sets self._level to level; also adjusts level of all children appropriately. + + Example: self._level==2, self.children[0]._level==3; + self.set_base_level(0) --> self._level==0, self.children[0]._level==1 + ''' + lsubtract = self._level - level + self._adjust_base_level(lsubtract) + + def _adjust_base_level(self, l_subtract): + '''sets self._level to self._level - l_subtract. + Also decreases level of all children by ldiff, and all childrens' children, etc. + ''' + self._level -= l_subtract + for child in self.children: + child._adjust_base_level(l_subtract) + + def with_adjusted_base_level(self, level): + '''set_base_level(self) except it is nondestructive, i.e. sets level for a deepcopy of self. + returns the copy with the base level set to level. + ''' + result = copy.deepcopy(self) + result.set_base_level(level) + return result + + def count_descendants(self): + '''returns total number of descendants (children, childrens' children, etc) of self.''' + result = len(self.children) + for child in self.children: + result += child.count_descendants() + return result + + +def _get_orig_tree(obj): + '''gets QUANTS_TREE from obj (when LOADING_LEVEL is not -1; else, returns a new QuantTree).''' + loading_level = getattr(obj, LOADING_LEVEL, -1) # get loading_level. Outside of f, the default is -1. + if (loading_level == -1) or (not hasattr(obj, QUANTS_TREE)): + orig_tree = QuantTree(None, level=-1) + else: + orig_tree = getattr(obj, QUANTS_TREE) + return orig_tree + + +def quant_tree_tracking(f): + '''wrapper for f which makes it track quant tree. + + QUANTS_TREE (attr of obj) will be a tree like: + (L(N-1)) None : + (L(N)) QuantInfo(var_current_layer) : + (L(N+1)) QuantInfo(var_1, i.e. 1st var gotten while calculating var_current_layer) : + (L(N+2)) ... (tree for var_1) + (L(N+1)) QuantInfo(var_2, i.e. 2nd var gotten while calcualting var_current_layer) : + (L(N+2)) ... (tree for var_2) + + Another way to write it; it will be a tree like: + QuantTree(data=None, level=N-1, children= \ + [ + QuantTree(data=QuantInfo(var_current_layer), level=N, children= \ + [ + QuantTree(data=QuantInfo(var_1, level=N+1, children=...)), + QuantTree(data=QuantInfo(var_2, level=N+1, children=...)), + ... + ]) + ]) + ''' + @functools.wraps(f) + def f_but_quant_tree_tracking(obj, varname, *args, **kwargs): + __tracebackhide__ = HIDE_DECORATOR_TRACEBACKS + orig_tree = _get_orig_tree(obj) + tree_child = orig_tree.add_child(None) + setattr(obj, QUANTS_TREE, tree_child) + # call f. + result = f(obj, varname, *args, **kwargs) + # retore original tree. (The data is set by f, via _track_quants_selected and quant_tracking_top_level) + setattr(obj, QUANTS_TREE, orig_tree) + # return result of f. + return result + return f_but_quant_tree_tracking + + +def quant_tracking_top_level(f): + '''decorator which improves quant tracking. (decorate _load_quantities using this.)''' + @quant_tree_tracking + @tools.maintain_attrs(LOADING_LEVEL, VARNAME_INPUT, QUANT_SELECTION, QUANT_SELECTED) + @functools.wraps(f) + def f_but_quant_tracking_level(obj, varname, *args, **kwargs): + __tracebackhide__ = HIDE_DECORATOR_TRACEBACKS + setattr(obj, LOADING_LEVEL, getattr(obj, LOADING_LEVEL, -2) + 1) # increment LOADING_LEVEL. + # << if obj didn't have LOADING_LEVEL, its LOADING_LEVEL will now be -1. + setattr(obj, VARNAME_INPUT, varname) # save name of the variable which was input. + setattr(obj, QUANT_SELECTED, QuantInfo(None)) # smash QUANT_SELECTED before doing f. + result = f(obj, varname, *args, **kwargs) + # even if we don't recognize this quant (because we didn't put quant tracking document_vars code for it (yet)), + # we should still set QUANT_SELECTED to the info we do know (i.e. the varname), with blanks for what we dont know. + quant_info = getattr(obj, QUANT_SELECTED, QuantInfo(None)) + if quant_info.varname is None: # f did not set varname for quant_info, so we'll do it now. + setattr_quant_selected(obj, quant=QUANT_NOT_FOUND, typequant=QUANT_NOT_FOUND, metaquant=QUANT_NOT_FOUND, + varname=varname) + return result + return f_but_quant_tracking_level + + +def get_quant_tracking_state(obj, from_internal=False): + '''returns quant tracking state of obj. + The state includes only the quant_tree and the quant_selected. + + from_internal: False (default) or True + True <-> use when we are caching due to a "with Caching(...)" block. + False <-> use when we are caching due to the "@with_caching" wrapper. + ''' + if not hasattr(obj, QUANTS_TREE): # not sure if this ever happens. + quants_tree = QuantTree(None) # if it does, return an empty QuantTree so we don't crash. + elif from_internal: # we are saving state while INSIDE quant_tree_tracking (inside _load_quantity). + # QUANTS_TREE looks like QuantTree(None) but it is the child of the tree which will have + # the data that we are getting from this call of get_var (which we are inside now). + # Thus, when the call to get_var is completed, the data for this tree will be filled + # with the appropriate QuantInfo about the quant we are getting now. + quants_tree = getattr(obj, QUANTS_TREE) + else: # we are saving state while OUTSIDE quant_tree_tracking (outside _load_quantity). + # QUANTS_TREE looks like [None : [..., QuantTree(QuantInfo( v ))]] where + # v is the var we just got with the latest call to _load_quantity. + quants_tree = getattr(obj, QUANTS_TREE).get_child(-1) # get the newest child. + state = dict(quants_tree=quants_tree, + quant_selected=getattr(obj, QUANT_SELECTED, QuantInfo(None)), + _from_internal=from_internal, # not used, but maybe helpful for debugging. + _ever_restored=False # whether we have ever restored this state. + ) + return state + + +def restore_quant_tracking_state(obj, state): + '''restores the quant tracking state of obj.''' + state_tree = state['quants_tree'] + obj_tree = _get_orig_tree(obj) + child_to_add = state_tree # add state tree as child of obj_tree. + if not state['_ever_restored']: + state['_ever_restored'] = True + if isinstance(child_to_add.data, QuantInfo): + # adjust level of top QuantInfo in tree, to indicate it is from cache. + q = child_to_add.data._asdict() + q['level'] = str(q['level']) + ' (FROM CACHE)' + child_to_add.data = QuantInfo(**q) + # add child to obj_tree. + obj_tree.add_child(child_to_add, adjusted_level=True) + setattr(obj, QUANTS_TREE, obj_tree) + # set QUANT_SELECTED. + selected = state.get('quant_selected', QuantInfo(None)) + setattr(obj, QUANT_SELECTED, selected) + + +''' ----------------------------- quant tracking - lookup ----------------------------- ''' + + +def gotten_vars(obj, hide_level=3, hide_interp=True, hide=[], hidef=lambda info: False, + hide_quants=[], hide_typequants=[], hide_metaquants=[]): + '''returns obj._quants_selected, which shows the most recent quants which get_var got. + + It is possible to hide quants from the list using the kwargs of this function. + + hide_level: integer (default 3) + hide all QuantInfo tuples with typequant or metaquant like 'level_n' with n >= hide_level. + case insensitive. Also, 'top_level' is treated like 'level_0'. + (This can only hide quants with otherwise unnamed typequant/metaquant info.) + hide_interp: True (default) or False + whether hide all quants which are one of the following types: + 'INTERP_QUANT', 'CENTER_QUANT' + hide: list. (default []) + hide all QuantInfo tuples with quant, typequant, or metaquant in this list. + hidef: function(info) --> bool. Default: (lambda info: False) + if evaluates to True for a QuantInfo tuple, hide this info. + Such objects are namedtuples, with contents (quant, typequant, metaquant). + hide_quants: list. (default []) + hide all QuantInfo tuples with quant in this list. + hide_typequants: list. (default []) + hide all QuantInfo tuples with typequant in this list. + hide_metaquants: list. (default []) + hide all QuantInfo tuples with metaquant in this list. + ''' + quants_selected = getattr(obj, QUANTS_SELECTED, collections.deque([])) + result = collections.deque(maxlen=quants_selected.maxlen) + for info in quants_selected: + quant, typequant, metaquant = info.quant, info.typequant, info.metaquant + varname, level = info.varname, info.level + # if we should hide this one, continue. + if level is not None: + if level >= hide_level: + continue + if hide_interp: + if typequant in ['INTERP_QUANT', 'CENTER_QUANT']: + continue + if len(hide) > 0: + if (quant in hide) or (typequant in hide) or (metaquant in hide): + continue + if (quant in hide_quants) or (typequant in hide_typequants) or (metaquant in hide_metaquants): + continue + if hidef(info): + continue + # else, add this one to result. + result.append(info) + return result + + +def got_vars_tree(obj, as_data=False, hide_level=None, i_child=0, oldest_first=True): + '''prints QUANTS_TREE for obj. + This tree shows the vars which were gotten during the most recent "level 0" call to get_var. + (Calls to get_var from outside of helita have level == 0.) + + Use as_data=True to return the QuantTree object instead of printing it. + + Use hide_level=N to hide all layers of the tree with L >= N. + + Use i_child to get a child other than children[-1] (default). + Note that if you are calling got_vars_tree externally (not inside any calls to get_var), + then there will be only one child, and it will be the QuantTree for the var passed to get_var. + + Use oldest_first to tell the children ordering convention: + True --> 0 is the oldest child (added first); -1 is the newest child (added most-recently). + False -> the order is reversed, e.g. -1 is the oldest child instead. + ''' + # Get QUANTS_TREE attr. Since this function (got_vars_tree) is optional, and for end-user, + # crash elegantly if obj doesn't have QUANTS_TREE, instead of trying to handle the crash. + quants_tree = getattr(obj, QUANTS_TREE) + # By design, data in top level of QUANTS_TREE is always None + # (except inside the wrapper quant_tree_tracking, which is the code that manages the tree). + # Thus the top level of quants_tree is not useful data, so we go to a child instead. + quants_tree = quants_tree.get_child(i_child, oldest_first) + # if as_data, return. Else, print. + if as_data: + return quants_tree + else: + print(quants_tree.str(hide_level=hide_level)) + + +def quant_lookup(obj, quant_info): + '''returns entry in obj.vardict related to quant_info (a QuantInfo object). + returns vardict[quant_info.metaquant][quant_info.typequant][quant_info.quant] + + if that cannot be found: + if metaquant in obj.VDSEARCH_IF_META return search_vardict(vardict, quant).result if it exists + + default (if we haven't returned anything else): return an empty dict. + ''' + quant_dict = dict() # default value + vardict = getattr(obj, VARDICT, dict()) + try: + metaquant_dict = vardict[quant_info.metaquant] + typequant_dict = metaquant_dict[quant_info.typequant] + quant_dict = typequant_dict[quant_info.quant] + except KeyError: + if quant_info.metaquant in getattr(obj, 'VDSEARCH_IF_META', []): + search = search_vardict(vardict, quant_info.quant) + if search: + quant_dict = search.result + return quant_dict + + +def get_quant_info(obj, lookup_in_vardict=False): + '''returns QuantInfo object for the top-level quant in got_vars_tree. + If lookup_in_vardict, also use obj.quant_lookup to look up that info in obj.vardict. + ''' + quant_info = got_vars_tree(obj, as_data=True, i_child=0).data + if lookup_in_vardict: + return quant_lookup(obj, quant_info) + else: + return quant_info diff --git a/helita/sim/ebysus.py b/helita/sim/ebysus.py index 2d49ddb8..757b96af 100644 --- a/helita/sim/ebysus.py +++ b/helita/sim/ebysus.py @@ -1,27 +1,303 @@ """ Set of programs to read and interact with output from Multifluid/multispecies + + +TODO: + Fix the memory leak... + The following code: + dd = eb.EbysusData(...) + del dd + does not actually free the dd object. It does not run dd.__del__(). + This can be proven by defining EbysusData.__del__() to print something + (which is what happens if you edit file_memory.py to set DEBUG_MEMORY_LEAK=True). + You can also turn off all the file_memory.py caches and memory by + setting a flag when initializing dd: dd = eb.EbysusData(..., _force_disable_memory=True). + + This leak could be caused by an attribute of dd pointing to dd without using weakref. + + It is also possible that there isn't a leak, because Python can collect objects in circular + reference chains as long as none of the objects in the chain have defined a __del__ method. + So it is possible that there is a circular reference which gets collected when __del__ is + not defined (when DEBUG_MEMORY_LEAK=False), but then can't get collected when __del__ is defined... + + A short-term solution is to hope python's default garbage collection routines + will collect the garbage often enough, or to do import gc; and gc.collect() sometimes. + + In the long-term, we should find which attribute of dd points to dd, and fix it. + """ -import numpy as np +# import built-in modules import os -from .bifrost import BifrostData, Rhoeetab, Bifrost_units -from .bifrost import read_idl_ascii, subs2grph -from . import cstagger -from at_tools import atom_tools as at +import time +import shutil +import warnings +import collections + +from . import document_vars, file_memory, fluid_tools, stagger, tools +# import local modules +from .bifrost import ( # for historical reasons / convenience, also import directly: + Bifrost_units, + BifrostData, + Cross_sect, + EnterDir, + EnterDirectory, + Rhoeetab, + _N_to_snapstr, + available_snaps, + get_snapname, + get_snaps, + get_snapstuff, + list_snaps, + read_idl_ascii, + snapname, + snaps, + snaps_info, + snapstuff, + subs2grph, +) +from .load_arithmetic_quantities import load_arithmetic_quantities +from .load_fromfile_quantities import load_fromfile_quantities +from .load_mf_quantities import load_mf_quantities +from .load_quantities import load_quantities +from .units import U_TUPLE, UNI, UNI_rho, UNI_speed, Usym, UsymD + +try: + from . import cstagger +except ImportError: + cstagger = tools.ImportFailed('cstagger', "This module is required to use stagger_kind='cstagger'.") + +# import external public modules +import numpy as np + +try: + import zarr +except ImportError: + zarr = tools.ImportFailed('zarr') + +# import external private modules +try: + from atom_py.at_tools import atom_tools as at + at_tools_exists = True +except: + at_tools_exists = False + at = tools.ImportFailed('atom_py.at_tools.atom_tools') +try: + from atom_py.at_tools import fluids as fl +except ImportError: + fl = tools.ImportFailed('at_tools.fluids') + +# set defaults: +from .load_mf_quantities import MATCH_AUX, MATCH_PHYSICS + +MATCH_TYPE_DEFAULT = MATCH_PHYSICS # can change this one. Tells whether to match physics or aux. +# match physics -> try to return physical value. +# match aux -> try to return value matching aux. + +AXES = ('x', 'y', 'z') + -class EbysusData(BifrostData): +class EbysusData(BifrostData, fluid_tools.Multifluid): """ Class to hold data from Multifluid/multispecies simulations in native format. """ - def __init__(self, *args, **kwargs): - super(EbysusData, self).__init__(*args, **kwargs) + def __init__(self, *args, fast=True, match_type=MATCH_TYPE_DEFAULT, + mesh_location_tracking=stagger.DEFAULT_MESH_LOCATION_TRACKING, + read_mode='io', + N_memmap=200, mm_persnap=True, + do_caching=True, cache_max_MB=10, cache_max_Narr=20, + _force_disable_memory=False, + **kwargs): + ''' initialize EbysusData object. + + mesh_location_tracking: False (default) or True + False --> disable conversion to ArrayOnMesh. (You can safely ignore this feature.) + True --> arrays from get_var will be returned as stagger.ArrayOnMesh objects, + which track the location on the mesh. Note that when doing arithmetic + with multiple ArrayOnMesh objects, it is required for locations to match. + The default may be changed to True in the future, after sufficient testing. + + read_mode: 'io' (default) or 'zc' + Where to read the data from, and how the data is stored. + 'io': 'input/output', the direct output from the ebysus simulation. + 'zc': 'zarr-compressed', the output from the EbysusData.compress() function. + 'zc' mode is generally faster to read and requires less storage space. + But it requires the compress function to have been run separately. + + N_memmap: int (default 0) + keep the N_memmap most-recently-created memmaps stored in self._memory_numpy_memmap. + -1 --> try to never forget any memmaps. + May increase (for this python session) the default maximum number of files + allowed to be open simultaneously. Tries to be conservative about doing so. + See file_memory.py for more details. + 0 --> never remember any memmaps. + Turns off remembering memmaps. + Not recommended; causes major slowdown. + >=1 --> remember up to this many memmaps. + + mm_persnap: True (default) or False + whether to delete all memmaps in self._memory_memmap when we set_snap to a new snap. + + fast: True (default) or False + whether to be fast. + True -> don't create memmaps for all simple variables when snapshot changes. + False -> do create memmaps for all simple variables when snapshot changes. + Not recommended; causes major slowdown. + This option is included in case legacy code assumes values + via self.var, or self.variables[var], instead of self.get_var(var). + As long as you use get_var to get var values, you can safely use fast=True. + + match_type: 0 (default) or 1 + whether to try to match physical answer (0) or aux data (1). + Applicable to terms which can be turned on or off. e.g.: + if do_hall='false': + match_type=0 --> return result as if do_hall is turned on. (matches actual physics) + match_type=1 --> return result as if do_hall is off. (matches aux file data) + Only applies when explicitly implemented in load quantities files, e.g. load_mf_quantities. + + do_caching: True (default) or False + whether to allow any type of caching (maintaining a short list of recent results of get_var). + if False, the with_caching() function will skip caching and self.cache will be ignored. + can be enabled or disabled at any point; does not erase the current cache. + cache_max_MB: 10 (default) or number + maximum number of MB of data which cache is allowed to store at once. + cache_max_Narr: 20 (default) or number + maximum number of arrays which cache is allowed to store at once. + + _force_disable_memory: False (default) or True + if True, disable ALL code from file_memory.py. + Very inefficient; however, it is useful for debugging file_memory.py. + + *args and **kwargs go to helita.sim.bifrost.BifrostData.__init__ + ''' + # set values of some attrs (e.g. from args & kwargs passed to __init__) + self.mesh_location_tracking = mesh_location_tracking + self.match_type = match_type + self.read_mode = read_mode + + setattr(self, file_memory.NMLIM_ATTR, N_memmap) + setattr(self, file_memory.MM_PERSNAP, mm_persnap) + + self.do_caching = do_caching and not _force_disable_memory + self._force_disable_memory = _force_disable_memory + if not _force_disable_memory: + self.cache = file_memory.Cache(obj=self, max_MB=cache_max_MB, max_Narr=cache_max_Narr) + self.caching = lambda: self.do_caching and not self.cache.is_NoneCache() # (used by load_mf_quantities) + setattr(self, document_vars.LOADING_LEVEL, -1) # tells how deep we are into loading a quantity now. + + self.panic = False + + # figure out snapname. If it doesn't agree with snapname (optionally) entered in args, make warning. + with EnterDirectory(kwargs.get('fdir', os.curdir)): + snapname = get_snapname() + if len(args) >= 1: + if args[0] != snapname: + snapname_errmsg = "snapname from args ('{}') disagrees with snapname from mhd.in ('{}')!" + # it will read from arg and won't raise error if mhd.in does not match args. + warnings.warn(snapname_errmsg.format(args[0], snapname)) + snapname = args[0] + + # call BifrostData.__init__ + BifrostData.__init__(self, snapname, *args[1:], fast=fast, **kwargs) + + # call Multifluid.__init__ + fluid_tools.Multifluid.__init__(self, ifluid=kwargs.pop('ifluid', (1, 1)), # default (1,1) + jfluid=kwargs.pop('jfluid', (1, 1))) # default (1,1) + + # set up self.att + self.att = {} + tab_species = self.mf_tabparam['SPECIES'] + self.mf_nspecies = len(tab_species) + self.mf_total_nlevel = 0 + for row in tab_species: + # example row looks like: ['01', 'H', 'H_2.atom'] + mf_ispecies = int(row[0]) + self.att[mf_ispecies] = at.Atom_tools(atom_file=row[2], fdir=self.fdir) + self.mf_total_nlevel += self.att[mf_ispecies].params.nlevel + + # read minimal amounts of data, to finish initializing. + self._init_vars_get(firstime=True) + self._init_coll_keys() + + ## PROPERTIES ## + mesh_location_tracking = stagger.MESH_LOCATION_TRACKING_PROPERTY(internal_name='_mesh_location_tracking') + + @property + def mf_arr_size(self): + '''now deprecated. Previously: number of fluids per variable in the most recently queried memmap file. + + To reinstate this value (if your code actually used it..) take these steps: + - add this line to the end (before return) of _get_simple_var_file_meta: + self.mf_arr_size = mf_arr_size + - delete the mf_arr_size @property (where this documentation appears in the code). + ''' + raise AttributeError('mf_arr_size has been deprecated.') + + @property + def read_mode(self): + '''which data to read. options are 'io', 'zc'. default is 'io' + 'io': 'input/output', the direct output from the ebysus simulation. + 'zc': 'zarr-compressed', the output from the EbysusData.compress() function. + 'zc' mode is generally faster to read and requires less storage space. + But it requires the compress function to have been run separately. + ''' + return getattr(self, '_read_mode', 'io') + + @read_mode.setter + def read_mode(self, value): + value = value.lower() + assert value in ('io', 'zc'), f"Expected read_mode in ('io', 'zc') but got {value}" + self._read_mode = value + + @property + def file_root_with_io_ext(self): + '''returns self.file_root with appropriate extension. E.g. 'snapname.io'. + extension is based on read_mode. '.io' by default. Can be '.zc' if using 'zc' read_mode. + ''' + return f'{self.file_root}.{self.read_mode}' + + ## INITIALIZING ## + def _init_coll_keys(self): + '''initialize self.coll_keys as a dict for better efficiency when looking up collision types. + self.coll_keys will be a dict with keys (ispecies, jspecies) values (collision type). + collision types are: + 'CL' ("coulomb"; whether coulomb collisions are allowed between these species) + 'EL' ("elastic"; previous default in ebysus) + 'MX' ("maxwell"; this one is usable even if we don't have cross section file) + Note that MX and EL are (presently) mutually exclusive. + ''' + _enforce_symmetry_in_collisions = False + # ^^ whether to manually put (B,A):value if (A,B):value is in coll_keys. + # disabled now because presently, ebysus simulation does not enforce + # that symmetry; e.g. it is possible to have (1,2):'EL' and (2,1):'MX', + # though I don't know what that combination would mean... - SE May 26 2021 + + # begin processing: + result = dict() + if 'COLL_KEYS' in self.mf_tabparam: + x = self.mf_tabparam['COLL_KEYS'] + for tokenline in x: # example tokenline: ['01', '02', 'EL'] + ispec, jspec, collkey = tokenline + ispec, jspec = int(ispec), int(jspec) + key = (ispec, jspec) + try: + result[key] += [collkey] + except KeyError: + result[key] = [collkey] + if _enforce_symmetry_in_collisions: + for key in list(result.keys()): # list() because changing size of result + rkey = (key[1], key[0]) # reversed + if rkey not in result.keys(): + result[rkey] = result[key] + + self.coll_keys = result - def _set_snapvars(self): + def _set_snapvars(self, firstime=False): - if os.path.exists('%s.io' % self.file_root): + if os.path.exists(self.file_root_with_io_ext): self.snaprvars = ['r'] self.snappvars = ['px', 'py', 'pz'] else: @@ -38,41 +314,47 @@ def _set_snapvars(self): self.varsmfc = [v for v in self.auxvars if v.startswith('mfc_')] self.varsmf = [v for v in self.auxvars if v.startswith('mf_')] self.varsmm = [v for v in self.auxvars if v.startswith('mm_')] + self.varsmfr = [v for v in self.auxvars if v.startswith('mfr_')] + self.varsmfp = [v for v in self.auxvars if v.startswith('mfp_')] self.varsmfe = [v for v in self.auxvars if v.startswith('mfe_')] if (self.mf_epf): # add internal energy to basic snaps - #self.snapvars.append('e') + # self.snapvars.append('e') # make distiction between different aux variable - self.mf_e_file = self.file_root + '_mf_e' + self.mf_e_file = self.root_name + '_mf_e' else: # one energy for all fluid self.mhdvars.insert(0, 'e') self.snapevars = [] if hasattr(self, 'with_electrons'): if self.with_electrons: - self.mf_e_file = self.file_root + '_mf_e' + self.mf_e_file = self.root_name + '_mf_e' # JMS This must be implemented - self.snapelvars=['r', 'px', 'py', 'pz', 'e'] + self.snapelvars = ['r', 'px', 'py', 'pz', 'e'] for var in ( + self.varsmfr + + self.varsmfp + self.varsmfe + self.varsmfc + self.varsmf + self.varsmm): self.auxvars.remove(var) - #if hasattr(self, 'mf_total_nlevel'): + # if hasattr(self, 'mf_total_nlevel'): # if self.mf_total_nlevel == 1: # self.snapvars.append('e') - if os.path.exists('%s.io' % self.file_root): + if os.path.exists(self.file_root_with_io_ext): self.simple_vars = self.snaprvars + self.snappvars + \ self.snapevars + self.mhdvars + self.auxvars + \ - self.varsmf + self.varsmfe + self.varsmfc + self.varsmm + self.varsmf + self.varsmfr + self.varsmfp + self.varsmfe + \ + self.varsmfc + self.varsmm else: self.simple_vars = self.snapvars + self.snapevars + \ - self.mhdvars + self.auxvars + self.varsmf + self.varsmfe + \ + self.mhdvars + self.auxvars + self.varsmf + \ + self.varsmfr + self.varsmfp + self.varsmfe + \ self.varsmfc + self.varsmm self.auxxyvars = [] @@ -91,256 +373,532 @@ def _set_snapvars(self): if (self.do_mhd): self.compvars = self.compvars + ['bxc', 'byc', 'bzc', 'modb']''' - # def set_snap(self,snap): - # super(EbysusData, self).set_snap(snap) + def set_snap(self, snap, *args__set_snap, **kwargs__set_snap): + '''call set_snap from BifrostData, + but also if mm_persnap, then delete all the memmaps in memory.. + ''' + if getattr(self, file_memory.MM_PERSNAP, False) and np.shape(self.snap) == (): + if hasattr(self, file_memory.MEMORY_MEMMAP): + delattr(self, file_memory.MEMORY_MEMMAP) + super(EbysusData, self).set_snap(snap, *args__set_snap, **kwargs__set_snap) - def _read_params(self): + def _read_params(self, firstime=False): ''' Reads parameter file specific for Multi Fluid Bifrost ''' - super(EbysusData, self)._read_params() + super(EbysusData, self)._read_params(firstime=firstime) self.nspecies_max = 28 self.nlevels_max = 28 + + # get misc. params (these have no default values. Make error if we can't get them). + errmsg = 'read_params: could not find {} in idl file!' + self.mf_epf = self.get_param('mf_epf', error_prop=KeyError(errmsg.format('mf_epf'))) + self.mf_nspecies = self.get_param('mf_nspecies', error_prop=KeyError(errmsg.format('mf_nspecies'))) + self.with_electrons = self.get_param('mf_electrons', error_prop=KeyError(errmsg.format('mf_electrons'))) + self.mf_total_nlevel = self.get_param('mf_total_nlevel', error_prop=KeyError(errmsg.format('mf_total_nlevel'))) + + # get param_file params (these have default values). + # mf_param_file + param_file = self.get_param('mf_param_file', default='mf_params.in', + warning='mf_param_file not found in this idl file; trying to use mf_params.in') + file = os.path.join(self.fdir, param_file.strip()) + self.mf_tabparam = read_mftab_ascii(file, obj=self) + # mf_eparam_file + do_ohm_ecol = self.get_param('do_ohm_ecol', 0) + warning = 'mf_eparam_file parameter not found; trying to use mf_eparams.in' if do_ohm_ecol else None + eparam_file = self.get_param('mf_eparam_file', default='mf_eparams.in', warning=warning) + file = os.path.join(self.fdir, eparam_file.strip()) try: - self.mf_epf = self.params['mf_epf'][self.snapInd] - except KeyError: - raise KeyError('read_params: could not find mf_epf in idl file!') - try: - self.mf_nspecies = self.params['mf_nspecies'][self.snapInd] - except KeyError: - raise KeyError('read_params: could not find mf_nspecies in idl file!') - try: - self.with_electrons = self.params['mf_electrons'][self.snapInd] - except KeyError: - raise KeyError( - 'read_params: could not find with_electrons in idl file!') - try: - self.mf_total_nlevel = self.params['mf_total_nlevel'][self.snapInd] - except KeyError: - print('warning, this idl file does not include mf_total_nlevel') - try: - filename = os.path.join( - self.fdir, self.params['mf_param_file'][self.snapInd].strip()) - self.mf_tabparam = read_mftab_ascii(filename) - except KeyError: - print('warning, this idl file does not include mf_param_file') + self.mf_etabparam = read_mftab_ascii(file, obj=self) + except FileNotFoundError: + # if do_ohm_ecol, crash; otherwise quietly ignore error. + if do_ohm_ecol: + raise - def _init_vars(self, *args, **kwargs): + def _init_vars(self, firstime=False, fast=None, *args__get_simple_var, **kw__get_simple_var): """ - Initialises variable (common for all fluid) + Initialises variables (common for all fluid) + + fast: None, True, or False. + whether to only read density (and not all the other variables). + if None, use self.fast instead. + + *args and **kwargs go to _get_simple_var """ + fast = fast if fast is not None else self.fast + if self._fast_skip_flag is True: + return + elif self._fast_skip_flag is False: + self._fast_skip_flag = True # swaps flag to True, then runs the rest of the code (this time around). + # else, fast_skip_flag is None, so the code should never be skipped. + # as long as fast is False, fast_skip_flag should be None. + self.mf_common_file = (self.root_name + '_mf_common') - if os.path.exists('%s.io' % self.file_root): - self.mfr_file = (self.root_name + '_mfr_%02i_%02i') - self.mfp_file = (self.root_name + '_mfp_%02i_%02i') + if os.path.exists(self.file_root_with_io_ext): + self.mfr_file = (self.root_name + '_mfr_{iS:}_{iL:}') + self.mfp_file = (self.root_name + '_mfp_{iS:}_{iL:}') else: - self.mf_file = (self.root_name + '_mf_%02i_%02i') - self.mfe_file = (self.root_name + '_mfe_%02i_%02i') - self.mfc_file = (self.root_name + '_mfc_%02i_%02i') - self.mm_file = (self.root_name + '_mm_%02i_%02i') + self.mf_file = (self.root_name + '_mf_{iS:}_{iL:}') + self.mfe_file = (self.root_name + '_mfe_{iS:}_{iL:}') + self.mfc_file = (self.root_name + '_mfc_{iS:}_{iL:}') + self.mm_file = (self.root_name + '_mm_{iS:}_{iL:}') self.mf_e_file = (self.root_name + '_mf_e') + self.aux_file = (self.root_name) self.variables = {} self.set_mfi(None, None) self.set_mfj(None, None) - for var in self.simple_vars: + if not firstime: + self._init_vars_get(firstime=False, *args__get_simple_var, **kw__get_simple_var) + + def _init_vars_get(self, firstime=False, *args__get_simple_var, **kw__get_simple_var): + '''get vars for _init_vars.''' + varlist = ['r'] if self.fast else self.simple_vars + for var in varlist: try: - self.variables[var] = self._get_simple_var( - var, self.mf_ispecies, self.mf_ilevel, *args, **kwargs) + # try to get var via _get_simple_var. + self.variables[var] = self._get_simple_var(var, + *args__get_simple_var, **kw__get_simple_var) + except Exception as error: + # if an error occurs, then... + if var == 'r' and firstime: + # RAISE THE ERROR + # Many methods depend on self.r being set. So if we can't get it, the code needs to crash. + raise + elif isinstance(error, ValueError) and (self.mf_ispecies < 0 or self.mf_ilevel < 0): + # SILENTLY HIDE THE ERROR. + # We assume it came from doing something like get_var('r', mf_ispecies=-1), + # which is is _supposed_ to fail. We hope it came from that, at least.... + # To be cautious / help debugging, we will store any such errors in self._hidden_errors. + if not hasattr(self, '_hidden_errors'): + self._hidden_errors = [] + if not hasattr(self, '_hidden_errors_max_len'): + self._hidden_errors_max_len = 100 # don't keep track of more than this many errors. + errmsg = "during _init_vars_get, with var='{}', {}".format(var, self.quick_look()) + errmsg.format(var, self.snap, self.ifluid, self.jfluid) + self._hidden_errors += [(errmsg, error)] + if len(self._hidden_errors) > self._hidden_errors_max_len: + del self._hidden_errors[0] + else: + # MAKE A WARNING but don't crash the code. + # Note: warnings with the same exact contents will only appear once per session, by default. + # You can change this behavior via, e.g.: import warnings; warnings.simplefilter('always') + errmsg = error if (self.verbose or firstime) else type(error).__name__ + warnings.warn("init_vars failed to read variable '{}' due to: {}".format(var, errmsg)) + else: + # if there was no error, then set self.var to the result. + # also set self.variables['metadata'] to self._metadata. + # this ensures we only pull data from self.variables when + # it is the correct snapshot, ifluid, and jfluid. setattr(self, var, self.variables[var]) - except BaseException: - if self.verbose: - if not (self.mf_ilevel == 1 and var in self.varsmfc): - print(('(WWW) init_vars: could not read ' - 'variable %s' % var)) + self.variables['metadata'] = self._metadata() rdt = self.r.dtype - cstagger.init_stagger(self.nz, self.dx, self.dy, self.z.astype(rdt), - self.zdn.astype(rdt), self.dzidzup.astype(rdt), - self.dzidzdn.astype(rdt)) + if self.stagger_kind == 'cstagger': + if (self.nz > 1): + cstagger.init_stagger(self.nz, self.dx, self.dy, self.z.astype(rdt), + self.zdn.astype(rdt), self.dzidzup.astype(rdt), + self.dzidzdn.astype(rdt)) + self.cstagger_exists = True # we can use cstagger methods! + else: + self.cstagger_exists = False + #cstagger.init_stagger_mz1(self.nz, self.dx, self.dy, self.z.astype(rdt)) + # self.cstagger_exists = True # we must avoid using cstagger methods. + else: + self.cstagger_exists = True + + ## INTROSPECTION ## + def _metadata(self, none=None, with_nfluid=2): + '''returns dict of metadata for self. Including snap, ifluid, jfluid, and more. + if self.snap is an array, set result['snaps']=snap and result['snap']=snaps[self.snapInd]. + + none: any value (default None) + metadata attrs which are not yet set will be set to this value. + with_nfluid: 2 (default), 1, or 0. + tells which fluids to include in the result. + 2 -> ifluid and jfluid. 1 -> just ifluid. 0 -> no fluids. + ''' + # METADATA_ATTRS is the list of all the attrs of self which may affect the output of get_var. + # we only read from cache if ALL these attrs agree with those associated to the cached value. + METADATA_ATTRS = ['ifluid', 'jfluid', 'snap', 'iix', 'iiy', 'iiz', 'match_type', 'panic', + 'do_stagger', 'stagger_kind', '_mesh_location_tracking', 'read_mode'] + if with_nfluid < 2: + del METADATA_ATTRS[1] # jfluid + if with_nfluid < 1: + del METADATA_ATTRS[0] # ifluid + # get attrs + result = {attr: getattr(self, attr, none) for attr in METADATA_ATTRS} + # if snap is array, set snaps=snap, and snap=snaps[self.snapInd] + if result['snap'] is not none: + if len(np.shape(result['snap'])) > 0: + result['snaps'] = result['snap'] # snaps is the array of snaps + result['snap'] = result['snap'][self.snapInd] # snap is the single snap + return result + + def quick_look(self): + '''returns string with snap, ifluid, and jfluid.''' + x = self._metadata(none='(not set)') + result = 'ifluid={}, jfluid={}, snap={}'.format(x['ifluid'], x['jfluid'], x['snap']) + snaps = x.get('snaps', None) + if snaps is not None: + result += ', snaps={}'.format(''.format( + np.size(snaps), np.min(snaps), np.max(snaps))) + return result + + def __repr__(self): + '''makes prettier repr of self''' + return '<{} with {}>'.format(object.__repr__(self), self.quick_look()) + + def _metadata_is_consistent(self, alt_metadata, none=None): + '''return whether alt_metadata is consistent with self._metadata(). + They "are consistent" if alt_metadata is a subset of self._metadata(). + i.e. if for all keys in alt_metadata, alt_metadata[key]==self._metadata[key]. + (Even works if contents are numpy arrays. See _dict_is_subset function for details.) + + For developing helita code, it is preferred to use _metadata_matches() instead. + ''' + return file_memory._dict_is_subset(alt_metadata, self._metadata(none=none)) + + def _metadata_matches(self, alt_metadata, none=None): + '''return whether alt_metadata matches self._metadata(). + They "match" if: + for fluid (either ifluid or jfluid) which exists in alt_metadata, + self._metadata()[fluid] must have the same value. + all other keys in each dict are the same and have the same value. + ''' + self_metadata = self._metadata(none=none) + return self._metadata_equals(self_metadata, alt_metadata) + + def _metadata_equals(self, m1, m2): + '''return whether metadata1 == metadata2. + They are "equal" if: + - for fluid ('ifluid' or 'jfluid') appearing in both m1 and m2: m1[fluid]==m2[fluid] + - all other keys in each dict are the same and have the same value. + ''' + for fluid in ['ifluid', 'jfluid']: + SL1 = m1.get(fluid, None) + SL2 = m2.get(fluid, None) + if None not in (SL1, SL2): + if not self.fluids_equal(SL1, SL2): + return False + # else: fluid is missing from m1 and/or m2; at least one of the metadata doesn't care about fluid. + # << if we reached this line, then we know ifluid and jfluid "match" between alt and self. + return file_memory._dict_equals(m1, m2, ignore_keys=['ifluid', 'jfluid']) + + # MATCH TYPE ## # (MATCH_AUX --> match simulation values; MATCH_PHYSICS --> match physical values) + @property + def match_type(self): + '''whether to match aux or physics. see self.match_aux and self.match_physics.''' + return getattr(self, '_match_type', MATCH_TYPE_DEFAULT) + + @match_type.setter + def match_type(self, value): + VALID_MATCH_TYPES = (MATCH_PHYSICS, MATCH_AUX) + assert value in VALID_MATCH_TYPES, 'Invalid match_type {}; not in {}.'.format(m, VALID_MATCH_TYPES) + self._match_type = value + + def match_physics(self): + '''return whether self.match_type == MATCH_PHYSICS''' + return self.match_type == MATCH_PHYSICS + + def match_aux(self): + '''return whether self.match_type == MATCH_AUX''' + return self.match_type == MATCH_AUX + + ## READING DATA / DOING CALCULATIONS ## + def _raw_load_quantity(self, var, panic=False): + '''load_quantity without any of the wrapper functions. + Makes it easier to subclass EbysusData: + _load_quantity in subclasses can be wrapped and call _raw_load_quantity. + ''' + __tracebackhide__ = True # hide this func from error traceback stack + # look for var in self.variables, if metadata is appropriate. + if var in self.variables and self._metadata_matches(self.variables.get('metadata', dict())): + return self.variables[var] + # load quantities. + val = load_fromfile_quantities(self, var, panic=panic, save_if_composite=False) + if val is None: + val = load_quantities(self, var, PLASMA_QUANT='', + CYCL_RES='', COLFRE_QUANT='', COLFRI_QUANT='', + IONP_QUANT='', EOSTAB_QUANT='', TAU_QUANT='', + DEBYE_LN_QUANT='', CROSTAB_QUANT='', + COULOMB_COL_QUANT='', AMB_QUANT='') + if val is None: + val = load_mf_quantities(self, var) + if val is None: + val = load_arithmetic_quantities(self, var) + return val - def set_mfi(self, mf_ispecies=None, mf_ilevel=None): - """ - adds mf_ispecies and mf_ilevel attributes if they don't exist and - changes mf_ispecies and mf_ilevel if needed. It will set defaults to 1 + @tools.maintain_attrs('match_type', 'ifluid', 'jfluid') + @file_memory.with_caching(cache=False, check_cache=True, cache_with_nfluid=None) + @document_vars.quant_tracking_top_level + def _load_quantity(self, var, panic=False): + '''helper function for get_var; actually calls load_quantities for var. + Also, restores self.ifluid and self.jfluid afterwards. + Also, restores self.match_type afterwards. + ''' + __tracebackhide__ = True # hide this func from error traceback stack + return self._raw_load_quantity(var, panic=panic) + + def get_var(self, var, snap=None, iix=None, iiy=None, iiz=None, + mf_ispecies=None, mf_ilevel=None, mf_jspecies=None, mf_jlevel=None, + ifluid=None, jfluid=None, panic=False, + match_type=None, check_cache=True, cache=False, cache_with_nfluid=None, + read_mode=None, printing_stats=None, + *args, **kwargs): """ + Reads a given variable from the relevant files. - if (mf_ispecies is not None): - if (mf_ispecies != self.mf_ispecies): - self.mf_ispecies = mf_ispecies - elif not hasattr(self, 'mf_ispecies'): - self.mf_ispecies = 1 - elif not hasattr(self, 'mf_ispecies'): - self.mf_ispecies = 1 - - if (mf_ilevel is not None): - if (mf_ilevel != self.mf_ilevel): - self.mf_ilevel = mf_ilevel - elif not hasattr(self, 'mf_ilevel'): - self.mf_ilevel = 1 - elif not hasattr(self, 'mf_ilevel'): - self.mf_ilevel = 1 - - def set_mfj(self, mf_jspecies=None, mf_jlevel=None): - """ - adds mf_ispecies and mf_ilevel attributes if they don't exist and - changes mf_ispecies and mf_ilevel if needed. It will set defaults to 1 - """ + Use self.get_var('') for help. + Use self.vardocs() to prettyprint the available variables and what they mean. - if (mf_jspecies is not None): - if (mf_jspecies != self.mf_jspecies): - self.mf_jspecies = mf_jspecies - elif not hasattr(self, 'mf_jspecies'): - self.mf_ispecies = 1 - elif not hasattr(self, 'mf_jspecies'): - self.mf_jspecies = 1 - - if (mf_jlevel is not None): - if (mf_jlevel != self.mf_jlevel): - self.mf_jlevel = mf_jlevel - elif not hasattr(self, 'mf_jlevel'): - self.mf_jlevel = 1 - elif not hasattr(self, 'mf_jlevel'): - self.mf_jlevel = 1 - - def get_var(self, var, snap=None, iix=slice(None), iiy=slice(None), - iiz=slice(None), mf_ispecies=None, mf_ilevel=None, - mf_jspecies=None, mf_jlevel=None, *args, **kwargs): - """ - Reads a given variable from the relevant files. + sets fluid-related attributes (e.g. self.ifluid) based on fluid-related kwargs. + + returns the data for the variable (as a 3D array with axes 0,1,2 <-> x,y,z). Parameters ---------- var - string - Name of the variable to read. Must be Bifrost internal names. - mf_ispecies - integer [1, 28] - Species ID - mf_ilevel - integer - Ionization level + Name of the variable to read. snap - integer, optional Snapshot number to read. By default reads the loaded snapshot; if a different number is requested, will load that snapshot by running self.set_snap(snap). + mf_ispecies - integer, or None (default) + Species ID + if None, set using other fluid kwargs (see ifluid, iSL, iS). + if still None, use self.mf_ispecies + mf_ilevel - integer, or None (default) + Ionization level + if None, set using other fluid kwargs (see ifluid, iSL, iL). + if still None, use self.mf_ilevel + ifluid - tuple of integers, or None (default) + if not None: (mf_ispecies, mf_ilevel) = ifluid + match_type - None (default), 0, or 1. + whether to try to match physics (0) or aux (1) where applicable. + see self.__init__.doc for more help. + cache - False (default) or True + whether to cache (store in memory) the result. + (if result already in memory, bring to "front" of list.) + check_cache - True (default) or False + whether to check cache to see if the result already exists in memory. + When possible, return existing result instead of repeating calculation. + cache_with_nfluid - None (default), 0, 1, or 2 + if not None, cache result and associate it with this many fluids. + 0 -> neither; 1 -> just ifluid; 2 -> both ifluid and jfluid. + read_mode - None (default), 'io', or 'zc' + if not None, first set self.read_mode to the value provided. + **kwargs may contain the following: + iSL - alias for ifluid + jSL - alias for jfluid + iS, iL - alias for ifluid[0], ifluid[1] + jS, jL - alias for jfluid[0], jfluid[1] + extra **kwargs are passed to NOWHERE. + extra *args are passed to NOWHERE. """ + kw__preprocess = dict(snap=snap, iix=iix, iiy=iiy, iiz=iiz, + mf_ispecies=mf_ispecies, mf_ilevel=mf_ilevel, + mf_jspecies=mf_jspecies, mf_jlevel=mf_jlevel, + ifluid=ifluid, jfluid=jfluid, + panic=panic, match_type=match_type, + check_cache=check_cache, cache=cache, + cache_with_nfluid=cache_with_nfluid, + read_mode=read_mode, + internal=True, # we are inside get_var. + **kwargs, + ) + # do pre-processing + kw__load_quantity, kw__postprocess = self._get_var_preprocess(var, **kw__preprocess) + + # actually get the value of var <<<<< + val = self._load_quantity(var, **kw__load_quantity) + + # do post-processing (function is defined in bifrost.py) + val = self._get_var_postprocess(val, var=var, printing_stats=printing_stats, **kw__postprocess) + return val - if var in ['x', 'y', 'z']: + def _get_var_preprocess(self, var, snap=None, iix=None, iiy=None, iiz=None, + mf_ispecies=None, mf_ilevel=None, mf_jspecies=None, mf_jlevel=None, + ifluid=None, jfluid=None, panic=False, internal=False, + match_type=None, check_cache=True, cache=False, cache_with_nfluid=None, + read_mode=None, **kw__fluids): + '''preprocessing for get_var. + returns ((dict of kwargs to pass to _load_quantity), + (dict of kwargs to pass to _get_var_postprocess)) + ''' + if var == '' and not document_vars.creating_vardict(self): + help(self.get_var) + + if var in AXES: return getattr(self, var) - if var in self.varsmfc: - if mf_ilevel is None and self.mf_ilevel == 1: - mf_ilevel = 2 - print("Warning: mfc is only for ionized species," - "Level changed to 2") - if mf_ilevel == 1: - mf_ilevel = 2 - print("Warning: mfc is only for ionized species." - " Level changed to 2") - - if var not in self.snapevars: - if (mf_ispecies is None): - if self.mf_ispecies < 1: - mf_ispecies = 1 - print("Warning: variable is only for electrons, " - "iSpecie changed to 1") - elif (mf_ispecies < 1): - mf_ispecies = 1 - print("Warning: variable is only for electrons, " - "iSpecie changed to 1") - - if not hasattr(self, 'iix'): - self.set_domain_iiaxis(iinum=iix, iiaxis='x') - self.set_domain_iiaxis(iinum=iiy, iiaxis='y') - self.set_domain_iiaxis(iinum=iiz, iiaxis='z') - else: - if (iix != slice(None)) and np.any(iix != self.iix): - if self.verbose: - print('(get_var): iix ', iix, self.iix) - self.set_domain_iiaxis(iinum=iix, iiaxis='x') - if (iiy != slice(None)) and np.any(iiy != self.iiy): - if self.verbose: - print('(get_var): iiy ', iiy, self.iiy) - self.set_domain_iiaxis(iinum=iiy, iiaxis='y') - if (iiz != slice(None)) and np.any(iiz != self.iiz): - if self.verbose: - print('(get_var): iiz ', iiz, self.iiz) - self.set_domain_iiaxis(iinum=iiz, iiaxis='z') - - if self.cstagop and ((self.iix != slice(None)) or - (self.iiy != slice(None)) or - (self.iiz != slice(None))): - self.cstagop = False - print( - 'WARNING: cstagger use has been turned off,', - 'turn it back on with "dd.cstagop = True"') - - if ((snap is not None) and np.any(snap != self.snap)): - self.set_snap(snap) - - if ((mf_ispecies is not None) and (mf_ispecies != self.mf_ispecies)): - self.set_mfi(mf_ispecies, mf_ilevel) - elif (( mf_ilevel is not None) and (mf_ilevel != self.mf_ilevel)): - self.set_mfi(mf_ispecies, mf_ilevel) - - if var in self.varsmm: - if ((mf_jspecies is not None) and (mf_jspecies != self.mf_jspecies)): - self.set_mfj(mf_jspecies, mf_jlevel) - elif (( mf_ilevel is not None) and (mf_jlevel != self.mf_jlevel)): - self.set_mfj(mf_jspecies, mf_jlevel) - - # This should not be here because mf_ispecies < 0 is for electrons. - #assert (self.mf_ispecies > 0 and self.mf_ispecies <= 28) - - # # check if already in memmory - # if var in self.variables: - # return self.variables[var] - if var in self.simple_vars: # is variable already loaded? - val = self._get_simple_var(var, self.mf_ispecies, self.mf_ilevel, - self.mf_jspecies, self.mf_jlevel) - elif var in self.auxxyvars: - val = super(EbysusData, self)._get_simple_var_xy(var) - else: - val = self._get_composite_mf_var(var) + if var in self.varn.keys(): + var = self.varn[var] - if np.shape(val) != (self.xLength, self.yLength, self.zLength): + if match_type is not None: + self.match_type = match_type + if read_mode is not None: + self.read_mode = read_mode - if np.size(self.iix)+np.size(self.iiy)+np.size(self.iiz) > 3: - # at least one slice has more than one value + # set fluids as appropriate to kwargs + kw__fluids = dict(mf_ispecies=mf_ispecies, mf_ilevel=mf_ilevel, ifluid=ifluid, + mf_jspecies=mf_jspecies, mf_jlevel=mf_jlevel, jfluid=jfluid, + **kw__fluids) + self.set_fluids(**kw__fluids) - # x axis may be squeezed out, axes for take() - axes = [0, -2, -1] + # set snapshot as needed + if snap is not None: + if not np.array_equal(snap, self.snap): + self.set_snap(snap) + self.panic = panic - for counter, dim in enumerate(['iix', 'iiy', 'iiz']): - if (np.size(getattr(self, dim)) > 1 or - getattr(self, dim) != slice(None)): - # slicing each dimension in turn - val = val.take(getattr(self, dim), axis=axes[counter]) - else: - # all of the slices are only one int or slice(None) - val = val[self.iix, self.iiy, self.iiz] + # set iix, iiy, iiz appropriately + slices_names_and_vals = (('iix', iix), ('iiy', iiy), ('iiz', iiz)) + original_slice = [iix if iix is not None else getattr(self, slicename, slice(None)) + for slicename, iix in slices_names_and_vals] + self.set_domain_iiaxes(iix=iix, iiy=iiy, iiz=iiz, internal=internal) - # ensuring that dimensions of size 1 are retained - val = np.reshape(val, (self.xLength, self.yLength, self.zLength)) + # set caching kwargs appropriately (see file_memory.with_caching() for details.) + kw__caching = dict(check_cache=check_cache, cache=cache, cache_with_nfluid=cache_with_nfluid) - return val + # setup and return result. + kw__load_quantity = dict(panic=panic, **kw__caching) + kw__postprocess = dict(original_slice=original_slice) + return (kw__load_quantity, kw__postprocess) + + def genvar(self): + ''' + Dictionary of original variables which will allow to convert to cgs. + ''' + self.varn = {} + self.varn['rho'] = 'r' + self.varn['ne'] = 'nel' + self.varn['tg'] = 'tg' + self.varn['pg'] = 'p' + self.varn['ux'] = 'ux' + self.varn['uy'] = 'uy' + self.varn['uz'] = 'uz' + self.varn['e'] = 'e' + self.varn['bx'] = 'bx' + self.varn['by'] = 'by' + self.varn['bz'] = 'bz' + + def simple_trans2comm(self, varname, snap=None, mf_ispecies=None, mf_ilevel=None, *args, **kwargs): + ''' Simple form of trans2com, can select species and ionized level''' + + self.trans2commaxes() + + self.sel_units = 'cgs' + + # Trying cgs + + sign = 1.0 + if varname[-1] in ['x', 'y', 'z']: + + varname = varname+'c' + if varname[-2] in ['y', 'z']: + sign = -1.0 - def _get_simple_var( - self, - var, - mf_ispecies=None, - mf_ilevel=None, - mf_jspecies=None, - mf_jlevel=None, - order='F', - mode='r', - *args, - **kwargs): + var = self.get_var(varname, snap=snap, mf_ispecies=mf_ispecies, mf_ilevel=mf_ilevel, *args, **kwargs) + var = sign * var + + var = var[..., ::-1].copy() + + return var + + def total_trans2comm(self, varname, snap=None, *args, **kwargs): + ''' Trans2comm that sums the selected variable over all species and levels. + For variables that do not change through species simple_trans2comm is used + with the default specie. ''' + + if varname in self.varn.keys(): + varname = self.varn[varname] + + # # # # # Helping dictionaries # # # # # + + # Electron variables + e_variables = {'r': 're', 'ux': 'uex', 'uy': 'uey', 'uz': 'uez', 'tg': 'etg', 'px': 'pex', 'py': 'pey', 'pz': 'pez'} + + # Instead of using ux or similar, uix with the specific fluid is used + i_variables = {'ux': 'uix', 'uy': 'uiy', 'uz': 'uiz', 'px': 'pix', 'py': 'piy', 'pz': 'piz'} + + # Different variables add in different ways + + # # # # # Density # # # # # + # Since it is the same volume, density just adds + if varname == 'r': + var = self.simple_trans2comm(e_variables[varname], snap, *args, **kwargs) + + for fluid in self.fluids.SL: + var += self.simple_trans2comm(varname, snap, mf_ispecies=fluid[0], mf_ilevel=fluid[1], *args, **kwargs) + + return var + + # # # # # Momentum # # # # # + # Momentum just adds. + if varname in ['px', 'py', 'pz', 'pix', 'piy', 'piz']: + # e variables are giving problems for some reason, the next line could be removed if necesary + var = self.simple_trans2comm(e_variables[varname], snap, *args, **kwargs) + + for fluid in self.fluids.SL: + var += self.simple_trans2comm(i_variables[varname], snap, mf_ispecies=fluid[0], mf_ilevel=fluid[1], *args, **kwargs) + + return var + + # # # # # Velocity # # # # # + # Velocity depends on the density and the momentum of each fluid + # Ux = Px/rho + # trying recursivity for rho + if varname in ['ux', 'uy', 'uz', 'uix', 'uiy', 'uiz']: + axis = varname[-1] + + # px = sum_j rho_j*ux_j + + # e contribution to velocity, could be removed + var1 = self.simple_trans2comm(e_variables['r'], snap, *args, **kwargs) * self.simple_trans2comm(e_variables['p'+axis], snap, *args, **kwargs) + + for fluid in self.fluids.SL: + specie_rho = self.simple_trans2comm('r', snap, mf_ispecies=fluid[0], mf_ilevel=fluid[1], *args, **kwargs) + specie_pi = self.simple_trans2comm(i_variables['p'+axis], snap, mf_ispecies=fluid[0], mf_ilevel=fluid[1], *args, **kwargs) + var1 += specie_rho*specie_pi + + # rho, recursive + var2 = self.total_trans2comm('r', snap=None, *args, **kwargs) + + return var1/var2 + + # # # # # Temperature # # # # # + # Temperature depends on density, mass and temperature of each fluid + # T_total = [ sum_j (rho_j/m_j)*tg_j ]/[ sum_j (rho_j/m_j) ] + # = alpha/beta + + if varname in ['tg', 'temperature']: + + n_e = self.simple_trans2comm('er', snap, *args, **kwargs)*self.uni.u_r/self.get_mass(-1, units='cgs') + tgi_e = self.simple_trans2comm('etg', snap, *args, **kwargs) + + alpha = n_e * tgi_e + beta = n_e + + for fluid in self.fluids.SL: + n = self.simple_trans2comm('r', snap, mf_ispecies=fluid[0], mf_ilevel=fluid[1], *args, **kwargs)*self.uni.u_r/self.get_mass((fluid[0], fluid[1]), units='cgs') + tgi = self.simple_trans2comm('tg', snap, mf_ispecies=fluid[0], mf_ilevel=fluid[1], *args, **kwargs) + + alpha += n*tgi + beta += n + + return alpha/beta + + # # # # # All other variables # # # # # + # For variables that do not deppend on the specie + return self.simple_trans2comm(varname, snap, *args, **kwargs) + + @document_vars.quant_tracking_simple('SIMPLE_VARS') + def _get_simple_var(self, var, order='F', mode='r', panic=False, *args, **kwargs): """ Gets "simple" variable (ie, only memmap, not load into memory). - Overloads super class to make a distinction between different - filenames for different variables - Parameters: ----------- var - string @@ -350,484 +908,1550 @@ def _get_simple_var( mode - string, optional numpy.memmap read mode. By default is read only ('r'), but you can use 'r+' to read and write. DO NOT USE 'w+'. + panic - False (default) or True. + whether we are trying to read a '.panic' file. + + *args and **kwargs go to NOWHERE. + + Minor Deprecation Notice: + ------------------------- + Support for entering fluids args/kwargs (mf_ispecies, mf_ilevel, mf_jspecies, mf_jlevel) + directly into _get_simple_var has been deprecated as of July 6, 2021. + As an alternative, use self.set_fluids() (or self.set_mfi() and self.set_mfj()), + before calling self._get_simple_var(). Returns ------- result - numpy.memmap array Requested variable. """ - if (np.size(self.snap) > 1): + # handle documentation for simple_vars + # set documentation for vardict, if var == ''. + if var == '': + _simple_vars_msg = ('Quantities which are stored by the simulation. These are ' + 'loaded as numpy memmaps by reading data files directly.') + docvar = document_vars.vars_documenter(self, 'SIMPLE_VARS', None, _simple_vars_msg) + # TODO (maybe): ^^^ use self.simple_vars, instead of None, for QUANT_VARS (args[2]) + # However, that might not be viable, depending on when self.simple_vars is assigned + for x in AXES: + docvar('b'+x, x+'-component of magnetic field [simu. units]', + nfluid=0, uni=U_TUPLE(UNI.b, UsymD(usi='T', ucgs='G'))) + docvar('r', 'mass density of ifluid [simu. units]', nfluid=1, uni=UNI_rho) + for x in AXES: + docvar('p'+x, x+'-component of momentum density of ifluid [simu. units]', + nfluid=1, uni=UNI_speed * UNI_rho) + units_e = dict(uni_f=UNI.e, usi_name=Usym('J') / Usym('m')**3) # ucgs_name= ??? + docvar('e', 'energy density of ifluid [simu. units]. Use -1 for electrons.', + nfluid=1, **units_e) + return None + + if var not in self.simple_vars: + return None + + # here is where we decide which file and what part of the file to load as a memmap <<<<< + result = self._load_simple_var_from_file(var, order=order, mode=mode, panic=panic, **kwargs) + result = self._assign_simple_var_mesh_location(result, var) # convert to ArrayOnMesh, if mesh_location_tracking is enabled + return result + + def _assign_simple_var_mesh_location(self, arr, var): + '''assigns the mesh location associated with var to arr, returning a stagger.ArrayOnMesh object. + (The ArrayOnMesh behaves just like a numpy array, but also tracks mesh location.) + + if self.mesh_location_tracking is disabled, instead just returns arr. + ''' + if not self.mesh_location_tracking: + return arr + # else + location = self._get_simple_var_mesh_location(var) + result = stagger.ArrayOnMesh(arr, location) + return result + + def _get_simple_var_mesh_location(self, var): + '''returns mesh location of simple var (bx,by,bz,r,px,py,pz,e)''' + if var in ((face_centered_quant + x) for x in AXES for face_centered_quant in ('b', 'p')): + x = var[-1] + return stagger.mesh_location_face(x) + elif var in ('r', 'e'): + return stagger.mesh_location_center() + raise ValueError(f"Mesh location for var={var} unknown. Locations are only known for: (bx,by,bz,r,px,py,pz,e)") + + def _load_simple_var_from_file(self, var, order='F', mode='r', panic=False, **kwargs): + '''loads the var data directly from the appropriate file. returns an array.''' + if self.read_mode == 'io': + assert mode in ('r', 'r+'), f"invalid mode: {mode}. Halted before deleting data. Valid modes are: 'r', 'r+'." + filename, kw__get_mmap = self._get_simple_var_file_info(var, order=order, panic=panic, **kwargs) + result = get_numpy_memmap(filename, mode=mode, **kw__get_mmap) + elif self.read_mode == 'zc': + # << note that 'zc' read_mode ignores order, mode, and **kwargs + filename, array_n = self._get_simple_var_file_meta(var, panic=panic, _meta_as_index=True) + result = load_zarr(filename, array_n) + else: + raise NotImplementedError(f'EbysusData.read_mode = {read_mode}') + return result + + def _get_simple_var_file_info(self, var, panic=False, order='F', **kw__get_memmap): + '''gets file info but does not read memmap; helper function for _get_simple_var. + + returns (filename, kwargs for get_numpy_memmap) corresponding to var. + ''' + filename, meta_info = self._get_simple_var_file_meta(var, panic=panic) + _kw__memmap = self._file_meta_to_memmap_kwargs(*meta_info) + _kw__memmap.update(**kw__get_memmap, order=order) + return filename, _kw__memmap + + def _get_simple_var_file_meta(self, var, panic=False, _meta_as_index=False): + '''returns "meta" info about reading var from file. + + primarily intended as a helper function for _get_simple_var_file_info. + + Each file contains N vars, M fluids per var. For a total of N*M single-fluid arrays. + The meta details tell: + idx - index of var in the file. (scan along N) + jdx - index of fluid for this array. (scan along M) (equals 0 for vars with M==1) + mf_arr_size - number of fluids per var in the file. (== M) + + returns (filename, (idx, mf_arr_size, jdx)) + + if _meta_as_index, instead returns (filename, index of array in file), + where index of array in file = idx * mf_arr_size + jdx. + ''' + # set currSnap, currStr = (current single snap, string for this snap) + if np.shape(self.snap) != (): # self.snap is list; pick snapInd value from list. currSnap = self.snap[self.snapInd] currStr = self.snap_str[self.snapInd] - else: + else: # self.snap is single snap. currSnap = self.snap currStr = self.snap_str - if currSnap < 0: - filename = self.file_root - fsuffix_b = '.scr' + + # check if we are reading .scr (snap < 0), snap0 (snap == 0), or "normal" snap (snap > 1) + if currSnap > 0: # reading "normal" snap + _reading_scr = False + #currStr = currStr + elif currSnap == 0: # reading snap0 + _reading_scr = False currStr = '' - elif currSnap == 0: - filename = self.file_root - fsuffix_b = '' + else: # currSnap < 0 # reading .scr + _reading_scr = True currStr = '' - else: - filename = self.file_root - fsuffix_b = '' - self.mf_arr_size = 1 - if os.path.exists('%s.io' % self.file_root): + mf_arr_size = 1 + iS = str(self.mf_ispecies).zfill(2) # ispecies as str. min 2 digits. (E.g. 3 --> '03') + iL = str(self.mf_ilevel).zfill(2) # ilevel as str. min 2 digits. (E.g. 14 --> '14') + iSL = dict(iS=iS, iL=iL) + + jdx = 0 # counts number of fluids with iSL < jSL. ( (iS < jS) OR ((iS == jS) AND (iL < jL)) ) + + # -------- figure out file name and idx (used to find offset in file). --------- # + if os.path.exists(self.file_root_with_io_ext): + # in this case, we are reading an ebysus-like snapshot. + _reading_ebysuslike_snap = True + + # check if var is a simple var from snaps. + _reading_snap_not_aux = True # whether we are reading '.snap' (not '.aux') if (var in self.mhdvars and self.mf_ispecies > 0) or ( - var in ['bx', 'by', 'bz']): + var in ['bx', 'by', 'bz']): # magnetic field, or a fluid-specific mhd simple variable) idx = self.mhdvars.index(var) - fsuffix_a = '.snap' - dirvars = '%s.io/mf_common/' % self.file_root - filename = self.mf_common_file - elif var in self.snaprvars and self.mf_ispecies > 0: + filename = os.path.join('mf_common', self.mf_common_file) + elif var in self.snaprvars and self.mf_ispecies > 0: # mass density (for non-electron fluid) idx = self.snaprvars.index(var) - fsuffix_a = '.snap' - dirvars = '%s.io/mf_%02i_%02i/mfr/' % (self.file_root, - self.mf_ispecies, self.mf_ilevel) - filename = self.mfr_file % (self.mf_ispecies, self.mf_ilevel) - elif var in self.snappvars and self.mf_ispecies > 0: + filename = os.path.join('mf_{iS:}_{iL:}', 'mfr', self.mfr_file).format(**iSL) + elif var in self.snappvars and self.mf_ispecies > 0: # momentum density (for non-electron fluid) idx = self.snappvars.index(var) - fsuffix_a = '.snap' - dirvars = '%s.io/mf_%02i_%02i/mfp/' % (self.file_root, - self.mf_ispecies, self.mf_ilevel) - filename = self.mfp_file % (self.mf_ispecies, self.mf_ilevel) - elif var in self.snapevars and self.mf_ispecies > 0: + filename = os.path.join('mf_{iS:}_{iL:}', 'mfp', self.mfp_file).format(**iSL) + elif var in self.snapevars and self.mf_ispecies > 0: # energy density (for non-electron fluid) idx = self.snapevars.index(var) - fsuffix_a = '.snap' - dirvars = '%s.io/mf_%02i_%02i/mfe/' % (self.file_root, - self.mf_ispecies, self.mf_ilevel) - filename = self.mfe_file % (self.mf_ispecies, self.mf_ilevel) - elif var in self.snapevars and self.mf_ispecies < 0: + filename = os.path.join('mf_{iS:}_{iL:}', 'mfe', self.mfe_file).format(**iSL) + elif var in self.snapevars and self.mf_ispecies < 0: # energy density (for electrons) idx = self.snapevars.index(var) - filename = self.mf_e_file - dirvars = '%s.io/mf_e/'% self.file_root - fsuffix_a = '.snap' - elif var in self.auxvars: - idx = self.auxvars.index(var) - fsuffix_a = '.aux' - dirvars = '%s.io/mf_common/' % (self.file_root, - self.mf_ispecies, self.mf_ilevel) - filename = self.file_root - elif var in self.varsmf: - idx = self.varsmf.index(var) - fsuffix_a = '.aux' - dirvars = '%s.io/mf_%02i_%02i/mfa/' % (self.file_root, - self.mf_ispecies, self.mf_ilevel) - filename = self.mf_file % (self.mf_ispecies, self.mf_ilevel) - elif var in self.varsmm: - idx = self.varsmm.index(var) - fsuffix_a = '.aux' - dirvars = '%s.io/mf_%02i_%02i/mm/' % (self.file_root, - self.mf_ispecies, self.mf_ilevel) - filename = self.mm_file % (self.mf_ispecies, self.mf_ilevel) - self.mf_arr_size = self.mf_total_nlevel - jdx=0 - for ispecies in range(1,self.mf_nspecies+1): - if (self.mf_nspecies == 1): - aa=at.atom_tools(atom_file=self.mf_tabparam['SPECIES'][2]) - else: - aa=at.atom_tools(atom_file=self.mf_tabparam['SPECIES'][ispecies-1][2]) - nlevels=len(aa.params['lvl']) - for ilevel in range(1,nlevels+1): - if (ispecies < self.mf_jspecies): - jdx += 1 - elif ((ispecies == self.mf_jspecies) and (ilevel < self.mf_jlevel)): - jdx += 1 - elif var in self.varsmfe: - idx = self.varsmfe.index(var) - fsuffix_a = '.aux' - dirvars = '%s.io/mf_%02i_%02i/mfe/' % (self.file_root, - self.mf_ispecies, self.mf_ilevel) - filename = self.mfe_file % (self.mf_ispecies, self.mf_ilevel) - elif var in self.varsmfc: - idx = self.varsmfc.index(var) - fsuffix_a = '.aux' - dirvars = '%s.io/mf_%02i_%02i/mfc/' % (self.file_root, - self.mf_ispecies, self.mf_ilevel) - filename = self.mfc_file % (self.mf_ispecies, self.mf_ilevel) + filename = os.path.join('mf_e', self.mf_e_file) + else: # var is not a simple var from snaps. + # check if var is from aux. + _reading_snap_not_aux = False # we are reading '.aux' (not '.snap') + if var in self.auxvars: # global auxvars + idx = self.auxvars.index(var) + filename = os.path.join('mf_common', self.aux_file) + elif var in self.varsmf: # ?? + idx = self.varsmf.index(var) + filename = os.path.join('mf_{iS:}_{iL:}', 'mfa', self.mf_file).format(**iSL) + elif var in self.varsmfr: # ?? + idx = self.varsmfr.index(var) + filename = os.path.join('mf_{iS:}_{iL:}', 'mfr', self.mfr_file).format(**iSL) + elif var in self.varsmfp: # ?? + idx = self.varsmfp.index(var) + filename = os.path.join('mf_{iS:}_{iL:}', 'mfp', self.mfp_file).format(**iSL) + elif var in self.varsmfe: # ?? + idx = self.varsmfe.index(var) + filename = os.path.join('mf_{iS:}_{iL:}', 'mfe', self.mfe_file).format(**iSL) + elif var in self.varsmfc: # ?? + idx = self.varsmfc.index(var) + filename = os.path.join('mf_{iS:}_{iL:}', 'mfc', self.mfc_file).format(**iSL) + elif var in self.varsmm: # two-fluid auxvars, e.g. mm_cross. + idx = self.varsmm.index(var) + filename = os.path.join('mf_{iS:}_{iL:}', 'mm', self.mm_file).format(**iSL) + # calculate important details for data's offset in file. + mf_arr_size = self.mf_total_nlevel + for ispecies in range(1, self.mf_nspecies+1): + nlevels = self.att[ispecies].params.nlevel + for ilevel in range(1, nlevels+1): + if (ispecies < self.mf_jspecies): + jdx += 1 + elif ((ispecies == self.mf_jspecies) and (ilevel < self.mf_jlevel)): + jdx += 1 + else: + errmsg = "Failed to find '{}' in simple vars for {}. (at point 1 in ebysus.py)" + errmsg = errmsg.format(var, self) + raise ValueError(errmsg) else: - dirvars = '' + # in this case, we are reading a bifrost-like snapshot. (There is NO snapname.io folder.) + _reading_ebysuslike_snap = True + # check if var is a simple var from snaps. + _reading_snap_not_aux = True # whether we are reading '.snap' (not '.aux') if (var in self.mhdvars and self.mf_ispecies > 0) or ( - var in ['bx', 'by', 'bz']): + var in ['bx', 'by', 'bz']): # magnetic field, or a fluid-specific mhd simple variable) idx = self.mhdvars.index(var) - fsuffix_a = '.snap' filename = self.mf_common_file - elif var in self.snapvars and self.mf_ispecies > 0: + elif var in self.snapvars and self.mf_ispecies > 0: # snapvars idx = self.snapvars.index(var) - fsuffix_a = '.snap' - filename = self.mf_file % (self.mf_ispecies, self.mf_ilevel) - elif var in self.snapevars and self.mf_ispecies > 0: + filename = self.mf_file.format(**iSL) + elif var in self.snapevars and self.mf_ispecies > 0: # snapevars (non-electrons) (??) idx = self.snapevars.index(var) - fsuffix_a = '.snap' - filename = self.mfe_file % (self.mf_ispecies, self.mf_ilevel) - elif var in self.snapevars and self.mf_ispecies < 0: + filename = self.mfe_file.format(**iSL) + elif var in self.snapevars and self.mf_ispecies < 0: # snapevars (electrons) (??) idx = self.snapevars.index(var) filename = self.mf_e_file - fsuffix_a = '.snap' - elif var in self.auxvars: - idx = self.auxvars.index(var) - fsuffix_a = '.aux' - filename = self.file_root - elif var in self.varsmf: - idx = self.varsmf.index(var) - fsuffix_a = '.aux' - filename = self.mf_file % (self.mf_ispecies, self.mf_ilevel) - elif var in self.varsmm: - idx = self.varsmm.index(var) - fsuffix_a = '.aux' - filename = self.mm_file % (self.mf_ispecies, self.mf_ilevel) - self.mf_arr_size = self.mf_total_nlevel - jdx=0 - for ispecies in range(1,self.mf_nspecies+1): - if (self.mf_nspecies == 1): - aa=at.atom_tools(atom_file=self.mf_tabparam['SPECIES'][2]) - else: - aa=at.atom_tools(atom_file=self.mf_tabparam['SPECIES'][ispecies-1][2]) - nlevels=len(aa.params['lvl']) - for ilevel in range(1,nlevels+1): - if (ispecies < self.mf_jspecies): - jdx += 1 - elif ((ispecies == self.mf_jspecies) and (ilevel < self.mf_jlevel)): - jdx += 1 - - elif var in self.varsmfe: - idx = self.varsmfe.index(var) - fsuffix_a = '.aux' - filename = self.mfe_file % (self.mf_ispecies, self.mf_ilevel) - elif var in self.varsmfc: - idx = self.varsmfc.index(var) - fsuffix_a = '.aux' - filename = self.mfc_file % (self.mf_ispecies, self.mf_ilevel) - - filename = dirvars + filename + currStr + fsuffix_a + fsuffix_b - - '''if var not in self.mhdvars and not (var in self.snapevars and - self.mf_ispecies < 0) and var not in self.auxvars : - filename = filename % (self.mf_ispecies, self.mf_ilevel)''' + else: # var is not a simple var from snaps. + # check if var is from aux. + _reading_snap_not_aux = False # we are reading '.aux' (not '.snap') + if var in self.auxvars: # global auxvars + idx = self.auxvars.index(var) + filename = self.aux_file + elif var in self.varsmf: # ?? + idx = self.varsmf.index(var) + filename = self.mf_file.format(**iSL) + elif var in self.varsmfr: # ?? + idx = self.varsmfr.index(var) + filename = self.mfr_file.format(**iSL) + elif var in self.varsmfp: # ?? + idx = self.varsmfp.index(var) + filename = self.mfp_file.format(**iSL) + elif var in self.varsmfe: # ?? + idx = self.varsmfe.index(var) + filename = self.mfe_file.format(**iSL) + elif var in self.varsmfc: # ?? + idx = self.varsmfc.index(var) + filename = self.mfc_file.format(**iSL) + elif var in self.varsmm: # two-fluid auxvars, e.g. mm_cross. (??) + idx = self.varsmm.index(var) + filename = self.mm_file.format(**iSL) + # calculate important details for data's offset in file. + mf_arr_size = self.mf_total_nlevel + for ispecies in range(1, self.mf_nspecies+1): + nlevels = self.att[ispecies].params.nlevel + for ilevel in range(1, nlevels+1): + if (ispecies < self.mf_jspecies): + jdx += 1 + elif ((ispecies == self.mf_jspecies) and (ilevel < self.mf_jlevel)): + jdx += 1 + else: + errmsg = "Failed to find '{}' in simple vars for {}. (at point 2 in ebysus.py)" + errmsg = errmsg.format(var, self) + raise ValueError(errmsg) - dsize = np.dtype(self.dtype).itemsize - offset = self.nx * self.ny * self.nzb * idx * dsize * self.mf_arr_size - if (self.mf_arr_size == 1): - return np.memmap( - filename, - dtype=self.dtype, - order=order, - offset=offset, - mode=mode, - shape=(self.nx, self.ny, self.nzb)) + _snapdir = (self.file_root_with_io_ext) if _reading_ebysuslike_snap else '' + filename = os.path.join(_snapdir, filename) # TODO: remove formats above; put .format(**iSL) here. + + if panic: + _suffix_panic = '.panic' if _reading_snap_not_aux else '.aux.panic' + filename = filename + _suffix_panic else: - if var in self.varsmm: - offset += self.nx * self.ny * self.nzb * jdx * dsize - return np.memmap( - filename, - dtype=self.dtype, - order=order, - offset=offset, - mode=mode, - shape=(self.nx, self.ny, self.nzb)) - else: - return np.memmap( - filename, - dtype=self.dtype, - order=order, - offset=offset, - mode=mode, - shape=(self.nx, self.ny, self.nzb, self.mf_arr_size)) - - def _get_composite_mf_var(self, var, order='F', mode='r', *args, **kwargs): - """ - Gets composite variables for multi species fluid. - """ - if var == 'totr': # velocities - for mf_ispecies in range(28): - for mf_ispecies in range(28): - r = self._get_simple_var( - 'e', - mf_ispecies=self.mf_ispecies, - mf_ilevel=self.mf_ilevel, - order=order, - mode=mode) - return r - elif var in self.compvars: - return super(EbysusData, self)._get_composite_var(var) + _suffix_dotsnap = '.snap' if _reading_snap_not_aux else '.aux' + _suffix_dotscr = '.scr' if _reading_scr else '' + filename = filename + currStr + _suffix_dotsnap + _suffix_dotscr + + if _meta_as_index: + return filename, (idx * mf_arr_size + jdx) else: - return super(EbysusData, self).get_quantity(var) + return filename, (idx, mf_arr_size, jdx) - def get_varTime(self, var, snap=None, iix=None, iiy=None, iiz=None, - mf_ispecies=None, mf_ilevel=None, mf_jspecies=None, - mf_jlevel=None,order='F', - mode='r', *args, **kwargs): + def _file_meta_to_memmap_kwargs(self, idx, mf_arr_size=1, jdx=0): + '''convert details about where the array is located in the file, to kwargs for numpy memmap. - self.iix = iix - self.iiy = iiy - self.iiz = iiz + primarily intended as a helper function for _get_simple_var_file_info. - try: - if (snap is not None): - if (np.size(snap) == np.size(self.snap)): - if (any(snap != self.snap)): - self.set_snap(snap) - else: - self.set_snap(snap) - except ValueError: - print('WWW: snap has to be a numpy.arrange parameter') - - if var in self.varsmfc: - if mf_ilevel is None and self.mf_ilevel == 1: - mf_ilevel = 2 - print("Warning: mfc is only for ionized species," - "Level changed to 2") - if mf_ilevel == 1: - mf_ilevel = 2 - print("Warning: mfc is only for ionized species." - "Level changed to 2") - - if var not in self.snapevars: - if (mf_ispecies is None): - if self.mf_ispecies < 1: - mf_ispecies = 1 - print("Warning: variable is only for electrons," - "iSpecie changed to 1") - elif (mf_ispecies < 1): - mf_ispecies = 1 - print("Warning: variable is only for electrons," - "iSpecie changed to 1") - - if (((mf_ispecies is not None) and ( - mf_ispecies != self.mf_ispecies)) or (( - mf_ilevel is not None) and (mf_ilevel != self.mf_ilevel))): - self.set_mfi(mf_ispecies, mf_ilevel) - - # lengths for dimensions of return array - self.xLength = 0 - self.yLength = 0 - self.zLength = 0 - - for dim in ('iix', 'iiy', 'iiz'): - if getattr(self, dim) is None: - if dim[2] == 'z': - setattr(self, dim[2] + 'Length', getattr(self, 'n' + dim[2]+'b')) - else: - setattr(self, dim[2] + 'Length', getattr(self, 'n' + dim[2])) - setattr(self, dim, slice(None)) - else: - indSize = np.size(getattr(self, dim)) - setattr(self, dim[2] + 'Length', indSize) + Each file contains N vars, M fluids per var. For a total of N*M single-fluid arrays. + The meta details tell: + idx - index of var in the file. (scan along N) + jdx - index of fluid for this array. (scan along M. equals 0 when M==1.) + mf_arr_size - number of fluids per var in the file. (== M) - snapLen = np.size(self.snap) - value = np.empty([self.xLength, self.yLength, self.zLength, snapLen]) + returns a dict of the kwargs to pass to get_numpy_memmap. + ''' + # -------- use filename and offset details to pick appropriate kwargs for numpy memmap --------- # - for i in range(0, snapLen): - self.snapInd = 0 - self._set_snapvars() - self._init_vars() - value[:, :, :, i] = self.get_var( - var, snap=snap[i], iix=self.iix, iiy=self.iiy, iiz=self.iiz, - mf_ispecies = self.mf_ispecies, mf_ilevel=self.mf_ilevel) + # calculate info which numpy needs to read file as memmap. + dsize = np.dtype(self.dtype).itemsize + offset = self.nxb * self.nyb * self.nzb * dsize * (idx * mf_arr_size + jdx) + shape = (self.nxb, self.nyb, self.nzb) + obj = self if (self.N_memmap != 0) else None # for memmap memory management; popped before np.memmap(**kw). - try: - if ((snap is not None) and (snap != self.snap)): - self.set_snap(snap) + # kwargs which will be passed to get_numpy_memmap. + kw__get_mmap = dict(dtype=self.dtype, offset=offset, shape=shape, obj=obj) + return kw__get_mmap - except ValueError: - if ((snap is not None) and any(snap != self.snap)): - self.set_snap(snap) + def get_var_if_in_aux(self, var, *args__get_var, **kw__get_var): + """ get_var but only if it appears in aux (i.e. self.params['aux'][self.snapInd]) - return value + if var not in aux, return None. + *args and **kwargs go to get_var. + """ + if var in self.params['aux'][self.snapInd].split(): + return self.get_var(var, *args__get_var, **kw__get_var) + else: + return None + + ## COMPRESSION ALGORITHMS ## + def compress(self, mode='zc', smash_mode=None, warn=True, kw_smash=dict(), skip_existing=False, **kwargs): + '''compress the direct output of an ebysus simulation (the .io folder). + + mode tells what type of compression to use. + 'zc': zarr compression. Use zarr package to read then store data in compressed format. + Currently, there are no other modes available. + + The resulting, compressed data will be stored in a folder with .{mode} at the end of it. + e.g. self.compress(mode='zc') if data is stored in snapname.io will create snapname.zc. + **kwargs go to the compression algorithm for the given mode. + e.g. for mode='zc', kwargs go to zarr.array(**kwargs). + + smash_mode: None (default), 'trash' or one of ('destroy', 'delete', 'rm') + mode for smashing the original folder (containing the non-compressed data). + (Will only be applied after the compression is completed successfully) + None --> do not smash the original folder. + 'trash' --> move original folder to trash (determined by os.environ['TRASH']) + 'destroy', 'delete', or 'rm' --> destroy the original folder permanently. + warn: bool, default True + whether to ask for user confirmation before smashing the old folder. + CAUTION: be very careful about using warn=False! + kw_smash: dict + additional kwargs to pass to smash_folder(). + skip_existing: bool, default False + if True, skip compressing each file for which a compressed version exists + (only checking destination filepath to determine existence.) + + returns name of created folder. + ''' + assert (smash_mode is None) or (not skip_existing), "smash_mode and skip_existing are incompatible." # for safety reasons. + mode = mode.lower() + try: # put the compression algorithms in a try..except block to make a warning if error is encountered. + if mode == 'zc': + result = self._zc_compress(skip_existing=skip_existing, **kwargs) + else: + raise NotImplementedError(f"EbysusData.compress(mode={repr(mode)})") + except: # we specifically want to catch ALL errors here, even BaseException like KeyboardInterrupt. + print('\n', '-'*94, sep='') + print('WARNING: ERROR ENCOUNTERED DURING COMPRESSION. COMPRESSION OUTPUT MAY BE CORRUPT OR INCOMPLETE') + print('-'*94) + raise + else: + if smash_mode is not None: + ORIGINAL = f"{self.get_param('snapname')}.io" + smash_folder(ORIGINAL, mode=smash_mode, warn=warn, **kw_smash) + return result + + def decompress(self, mode='zc', smash_mode=None, warn=True, kw_smash=dict(), **kwargs): + '''convert compressed data back into 'original .io format'. + + mode tells which type of compressed data to use. + 'zc': zarr compressed data. + Currently, there are no other modes available. + + smash_mode: None (default), 'trash' or one of ('destroy', 'delete', 'rm') + mode for smashing the original folder (containing the compressed data). + (Will only be applied after the compression is completed successfully) + None --> do not smash the original folder. + 'trash' --> move original folder to trash (determined by os.environ['TRASH']) + 'destroy', 'delete', or 'rm' --> destroy the original folder permanently. + warn: bool, default True + whether to ask for user confirmation before smashing the old folder. + CAUTION: be very careful about using warn=False! + kw_smash: dict + additional kwargs to pass to smash_folder(). + + The resulting data will be stored in a folder with .io at the end of it, like the original .io folder. + returns name of created folder. + ''' + mode = mode.lower() + # put the compression algorithms in a try..except block to make a warning if error is encountered. + try: + if mode == 'zc': + ORIGINAL = f"{self.get_param('snapname')}.zc" + result = self._zc_decompress(**kwargs) + else: + raise NotImplementedError(f"EbysusData.decompress(mode={repr(mode)})") + except: # we specifically want to catch ALL errors here, even BaseException like KeyboardInterrupt. + print('\n', '-'*98, sep='') + print('WARNING: ERROR ENCOUNTERED DURING DECOMPRESSION. DECOMPRESSION OUTPUT MAY BE CORRUPT OR INCOMPLETE') + print('-'*98) + raise + else: + if smash_mode is not None: + smash_folder(ORIGINAL, mode=smash_mode, warn=warn, **kw_smash) + return result + + def _zc_compress(self, verbose=1, skip_existing=False, **kw__zarr): + '''compress the .io folder into a .zc folder. + Converts data to format readable by zarr. + Testing indicates the .zc data usually takes less space AND is faster to read. + + skip_existing: bool, default False + if True, skip compressing each file for which a compressed version exists + (only checking destination filepath to determine existence.) + + returns (the name of the new folder, the number of bytes originally, the number of bytes after compression) + ''' + # bookkeeping - parameters + SNAPNAME = self.get_param('snapname') + SHAPE = self.shape # (nx, ny, nz). reshape the whole file to shape (nx, ny, nz, -1). + CHUNKS = (*(None for dim in self.shape), 1) # (None, None, None, 1). "1 chunk per array for each var". + ORDER = 'F' # data order. 'F' for 'fortran'. Results are nonsense if the wrong order is used. + DTYPE = ' 0: + new_dir = root.replace(f'{SNAPNAME}.io', f'{SNAPNAME}.zc') + if makedirs: + os.makedirs(new_dir, exist_ok=True) + for base in files: + src = os.path.join(root, base) + dst = os.path.join(new_dir, base) + if skip_existing and os.path.exists(dst): + continue + else: + yield (src, dst) + + # bookkeeping - printing updates + if verbose >= 1: # calculate total number of files and print progress as fraction. + nfiles = sum(1 for src, dst in snapfiles_iter(makedirs=False)) + nfstr = len(str(nfiles)) + start_time = time.time() + + def print_if_verbose(*args, vreq=1, print_time=True, file_n=None, clearline=0, **kw): + if verbose < vreq: + return # without printing anything. + if file_n is not None: + args = (f'({file_n:>{nfstr}d} / {nfiles})', *args) + if print_time: + args = (f'Time elapsed: {time.time() - start_time:.2f} s.', *args) + print(' '*clearline, end='\r') # cover the first characters with empty space. + print(*args, **kw) + + # the actual compression happens in this loop. + file_str_len = 0 + original_bytes_total = 0 + compressed_bytes_total = 0 + for file_n, (src, dst) in enumerate(snapfiles_iter(makedirs=True)): + z = save_filebinary_to_filezarr(src, dst, shape=SHAPE, dtype=DTYPE, order=ORDER, chunks=CHUNKS, **kw__zarr) + original_bytes_total += z.nbytes + compressed_bytes_total += z.nbytes_stored + # printing updates + if verbose: + file_str_len = max(file_str_len, len(dst)) + print_if_verbose(f'{dst}', end='\r', vreq=1, file_n=file_n, clearline=40+file_str_len) + + print_if_verbose('_zc_compress complete!' + + f' Compressed {tools.pretty_nbytes(original_bytes_total)}' + + f' into {tools.pretty_nbytes(compressed_bytes_total)}' + + f' (net compression ratio = {original_bytes_total/(compressed_bytes_total+1e-10):.2f}).', + print_time=True, vreq=1, clearline=40+file_str_len) + return (f'{SNAPNAME}.zc', original_bytes_total, compressed_bytes_total) + + def _zc_decompress(self, verbose=1): + '''use the data from the .zc folder to recreate the original .io folder. + returns the name of the new (.io) folder. + ''' + # notes: + # each zarray (in SNAPNAME.zc) is stored as a directory, + # containing a file named '.zarray', and possibly some data files. + # We assume that all folders containing a '.zarray' file are zarrays. + # (This structure is utilized in the implementation below.) + + # bookkeeping - parameters + SNAPNAME = self.get_param('snapname') + ORDER = 'F' # data order. 'F' for 'fortran'. Results are nonsense if the wrong order is used. + + # bookeeping - printing updates + if verbose >= 1: # calculate total number of files and print progress as fraction. + nfiles = sum(1 for _, _, files in os.walk(f'{SNAPNAME}.zc') if '.zarray' in files) + nfstr = len(str(nfiles)) + start_time = time.time() + + def print_if_verbose(*args, vreq=1, print_time=True, file_n=None, clearline=0, **kw): + if verbose < vreq: + return # without printing anything. + if file_n is not None: + args = (f'({file_n:>{nfstr}d} / {nfiles})', *args) + if print_time: + args = (f'Time elapsed: {time.time() - start_time:.2f} s.', *args) + print(' '*clearline, end='\r') # cover the first characters with empty space. + print(*args, **kw) + + # the actual decompression happens in this loop. + file_n = 0 + file_str_len = 0 + for root, dirs, files in os.walk(f'{SNAPNAME}.zc'): + if '.zarray' in files: # then this root is actually a zarray folder. + file_n += 1 # bookkeeping + src = root + dst = src.replace(f'{SNAPNAME}.zc', f'{SNAPNAME}.io') + os.makedirs(os.path.dirname(dst), exist_ok=True) # make dst dir if necessary. + save_filezarr_to_filebinary(src, dst, order=ORDER) + # printing updates + if verbose: + file_str_len = max(file_str_len, len(dst)) + print_if_verbose(f'{dst}', end='\r', vreq=1, file_n=file_n, clearline=40+file_str_len) + + print_if_verbose('_zc_decompress complete!', print_time=True, vreq=1, clearline=40+file_str_len) + return f'{SNAPNAME}.io' + + ## SNAPSHOT FILES - SELECTING / MOVING ## + def get_snap_files(self, snap=None, include_aux=True): + '''returns the minimal list of filenames for all files specific to this snap. + Directories containing solely files for this snap will be reported as the directory, not all contents. + (e.g. for zarray "files", which are directories, the zarray directory will be reported, not all its contents.) + + include_aux: whether to include aux files in the result. + ''' + snap = snap if snap is not None else self.get_snap_here() + with tools.EnterDirectory(self.fdir): + return get_snap_files(snap, dd=self, include_aux=include_aux) + + def get_snaps_files(self, snaps=None, include_aux=True): + '''returns a dict of {snap number: minimal list of filenames for all files specific to this snap}, + for the snaps provided (or self.snaps if snaps is not provided). + ''' + snaps = snaps if snaps is not None else self.snaps + result = {snap: self.get_snap_files(snap=snap, include_aux=include_aux) for snap in snaps} + return result + + ## CONVENIENCE METHODS ## + def print_values(self, fmtval='{: .1e}', fmtname='{:5s}', fmtgvar='{:s}', _show_units=True, + GLOBAL_VARS=['bx', 'by', 'bz'], FLUID_VARS=['nr', 'uix', 'uiy', 'uiz', 'tg'], + SKIP_FLUIDS=[], display_names={}, as_string=False, skip_errors=True): + '''prints fundamental values for self. + bx, by, bz, AND for each fluid: nr, uix, uiy, uiz, tg + Default behavior is to just print the mean of each value. + TODO: implement "advanced stats" mode where min and max are shown as well. + + fmtval: str + format string for values. + fmtname: str + format string for fluid names. + fmtgvar: str + format string for GLOBAL_VARS (var names, not values). + _show_units: bool, default True + if True, show a string at the top indicating the units system (e.g. 'si', 'simu', 'cgs'). + GLOBAL_VARS: list of strings + global vars to show. "global" --> "no fluids". + insert '' string(s) to put new line(s). Otherwise all will display on one line. + FLUID_VARS: list of strings + fluid vars to show. possibly a different value for each fluid. + SKIP_FLUIDS: list of (species,level) tuples. + skip any SL found in SKIP_FLUIDS. + rename: dict of {var: display_name} + display var name as display_name in the table. + as_string: bool, default False + if True, return result as a string instead of printing it. + skip_errors: bool, default True + if True, hide any values that have errors during get_var, instead of crashing. + ''' + # setup + def fmtv(val): return fmtval.format(val) # format value + def fmtn(name): return fmtname.format(name) # format fluid name + def fmtg(gvar): return fmtgvar.format(gvar) # format global var name + + def getv(var, **kw): + return self.get_var_gracefully(var, **kw) if skip_errors else self.get_var(var, **kw) + for var in (*GLOBAL_VARS, *FLUID_VARS): + if var not in display_names: + display_names[var] = var + lines = [] + # units + if _show_units: + lines.append(f"units = '{self.units_output}'") + # globals + if len(GLOBAL_VARS) > 0: + # solution without allowing for newlines: + #lines.append(' | '.join([f"{display_names[var]} = {fmtv(np.mean(getv(var)))}" for var in GLOBAL_VARS])) + # solution which converts any '' to newline: + i = 0 + gvarline = [] + while i < len(GLOBAL_VARS): + var = GLOBAL_VARS[i] + if var == '': + lines.append(' | '.join(gvarline)) + gvarline = [] + else: + gvarline.append(f"{fmtg(display_names[var])} = {fmtv(np.mean(getv(var)))}") + i += 1 + lines.append(' | '.join(gvarline)) + # put a new line before FLUID_VARS if there are any FLUID_VARS. + if len(FLUID_VARS) > 0: + lines.append('') + # table with fluids + if len(FLUID_VARS) > 0: + # get all the values # + SLs = [SL for SL in self.fluid_SLs() if not SL in SKIP_FLUIDS] + values = {SL: {} for SL in SLs} + for var in FLUID_VARS: + for SL in SLs: + values[SL][var] = np.mean(getv(var, iSL=SL)) + # convert values to strings # + vstrs = {SL: {var: fmtv(values[SL][var]) for var in FLUID_VARS} for SL in SLs} + for SL in SLs: + vstrs[SL]['name'] = fmtn(self.get_fluid_name(SL)) + # calculate string lengths # + vlens = {var: max(*(len(vstrs[SL][var]) for SL in SLs), len(display_names[var])) for var in FLUID_VARS} + vlens['name'] = max(len(vstrs[SL]['name']) for SL in SLs) + # convert strings to appropriate lengths for consistency # + vstrs_pretty = {SL: {var: vstrs[SL][var].rjust(vlens[var]) for var in FLUID_VARS} for SL in SLs} + for SL in SLs: + vstrs_pretty[SL]['name'] = vstrs[SL]['name'].ljust(vlens['name']) + # add header # + header = ' | '.join([' '*vlens['name'], *[display_names[var].center(vlens[var]) for var in FLUID_VARS]]) + lines.append(header) + lines.append('-'*len(header)) + # add rows # + for SL in SLs: + lines.append(' | '.join([vstrs_pretty[SL]['name'], *[vstrs_pretty[SL][var] for var in FLUID_VARS]])) + + # print or return + result = '\n'.join(lines) + if as_string: + return result + else: + print(result) def get_nspecies(self): return len(self.mf_tabparam['SPECIES']) -########### -# TOOLS # -########### - - -def write_mfr(rootname,inputdata,mf_ispecies,mf_ilevel): + def get_var_gracefully(self, var, *args, **kw): + '''returns self.get_var(*args, **kw) or return array of np.nan if get_var crashes.''' + try: + return self.get_var(var, *args, **kw) + except Exception: + return self.zero() + np.nan + + def get_var_nfluid(self, var): + '''returns number of fluids which affect self.get_var(var). + 0 - depends on NEITHER self.ifluid nor self.jfluid. + 1 - depends on self.ifluid but NOT self.jfluid. + 2 - depends on BOTH self.ifluid and self.jfluid. + None - unknown (var is in vardict, but has undocumented nfluid). + + Only works for var in self.vardict; fails for "constructed" vars, e.g. "b_mod". + ''' + search = self.search_vardict(var) + try: + return search.result['nfluid'] + except AttributeError: # var not found. (search is False) + raise ValueError(f"var not documented: '{var}'") from None + + def zero_at_meshloc(self, meshloc=[0, 0, 0], **kw__np_zeros): + '''return array of zeros, associated with the provided mesh location. + if not self.mesh_location_tracking, return self.zero() instead. + ''' + zero = self.zero(**kw__np_zeros) + if self.mesh_location_tracking: + return stagger.ArrayOnMesh(zero, meshloc) + else: + return zero + + def zero_at_mesh_center(self, **kw__np_zeros): + '''return array of zeros, associated with 'center of cell' mesh location. + if not self.mesh_location_tracking, return self.zero() instead. + ''' + return self.zero_at_meshloc(stagger.mesh_location_center(), **kw__np_zeros) + + def zero_at_mesh_face(self, x, **kw__np_zeros): + '''return array of zeros, associated with 'face of cell' mesh location. + Uses x to determine face. Use x='x', 'y', or 'z'. E.g. 'y' --> (0, -0.5, 0). + if not self.mesh_location_tracking, return self.zero() instead. + ''' + return self.zero_at_meshloc(stagger.mesh_location_face(x), **kw__np_zeros) + + def zero_at_mesh_edge(self, x, **kw__np_zeros): + '''return array of zeros, associated with 'edge of cell' mesh location. + Uses x to determine edge. Use x='x', 'y', or 'z'. E.g. 'y' --> (-0.5, 0, -0.5). + if not self.mesh_location_tracking, return self.zero() instead. + ''' + return self.zero_at_meshloc(stagger.mesh_location_edge(x), **kw__np_zeros) + + +############################# +# MAKING INITIAL SNAPSHOT # +############################# + +def write_mf_data(rootname, inputs, mfstr, **kw_ifluid): + '''write density, momentum, or energy for fluid indicated by kw_ifluid. + rootname = (should be set equal to the value of parameter 'snapname' in mhd.in) + inputs = list of arrays, each having shape (nx, ny, nz). This is the data to write. + mfstr = string indicating type of data. 'mfr', 'mfp', or 'mfe'. + **kw_ifluid: kwargs indicating fluid. + ''' + # interpret fluid kwargs + mf_ispecies, mf_ilevel = fluid_tools._interpret_kw_ifluid(**kw_ifluid, None_ok=False) if mf_ispecies < 1: - print('(WWW) species should start with 1') + print('(WWW) species should be 1 or larger when writing fluid data. For electrons use mf_e') if mf_ilevel < 1: - print('(WWW) levels should start with 1') - directory = '%s.io/mf_%02i_%02i/mfr' % (rootname,mf_ispecies,mf_ilevel) - nx, ny, nz = inputdata.shape + print('(WWW) levels should be 1 or larger when writing fluid data. For electrons use mf_e') + # check that all arrays are finite; warn if one is not. + for arr in inputs: + if not np.isfinite(arr).all(): + nonfinite_errmsg = 'at least one non-finite value detected in write_mfr! for iSL={}' + warnings.warn(nonfinite_errmsg.format((mf_ispecies, mf_ilevel))) + # calculate names of directory and saveloc. + directory = os.path.join( + '{}.io'.format(rootname), + 'mf_%02i_%02i' % (mf_ispecies, mf_ilevel), + mfstr + ) + saveloc = os.path.join( + directory, + '%s_%s_%02i_%02i.snap' % (rootname, mfstr, mf_ispecies, mf_ilevel) + ) + # calculate shape for memmap + shape = (*(inputs[0].shape), len(inputs)) # (nx, ny, nz, (1 or 3)) + # save memmap if not os.path.exists(directory): os.makedirs(directory) - data = np.memmap(directory+'/%s_mfr_%02i_%02i.snap' % (rootname,mf_ispecies,mf_ilevel), dtype='float32', mode='w+', order='f',shape=(nx,ny,nz,1)) - data[...,0] = inputdata + data = np.memmap(saveloc, dtype='float32', mode='w+', order='f', shape=shape) + for i, arr in enumerate(inputs): + data[..., i] = arr data.flush() -def write_mfp(rootname,inputdatax,inputdatay,inputdataz,mf_ispecies,mf_ilevel): - if mf_ispecies < 1: - print('(WWW) species should start with 1') - if mf_ilevel < 1: - print('(WWW) levels should start with 1') - directory = '%s.io/mf_%02i_%02i/mfp' % (rootname,mf_ispecies,mf_ilevel) - nx, ny, nz = inputdatax.shape - if not os.path.exists(directory): - os.makedirs(directory) - data = np.memmap(directory+'/%s_mfp_%02i_%02i.snap' % (rootname,mf_ispecies,mf_ilevel), dtype='float32', mode='w+', order='f',shape=(nx,ny,nz,3)) - data[...,0] = inputdatax - data[...,1] = inputdatay - data[...,2] = inputdataz - data.flush() -def write_mfe(rootname,inputdata,mf_ispecies,mf_ilevel): +def write_mfr(rootname, inputdata, mf_ispecies=None, mf_ilevel=None, **kw_ifluid): + '''write density. (Useful when using python to make initial snapshot; e.g. in make_mf_snap.py) + rootname = snapname (should be set equal to the value of parameter 'snapname' in mhd.in) + inputdata = array of shape (nx, ny, nz) + mass density [in ebysus units] of ifluid + ifluid must be entered. If not entered, raise TypeError. ifluid can be entered via one of: + - (mf_ispecies and mf_ilevel) + - **kw_ifluid, via the kwargs (ifluid), (iSL), or (iS and iL) + ''' + return write_mf_data(rootname, [inputdata], 'mfr', + mf_ispecies=mf_ispecies, mf_ilevel=mf_ilevel, **kw_ifluid) + + +def write_mfp(rootname, inputdatax, inputdatay, inputdataz, mf_ispecies=None, mf_ilevel=None, **kw_ifluid): + '''write momentum. (Useful when using python to make initial snapshot; e.g. in make_mf_snap.py) + rootname = snapname (should be set equal to the value of parameter 'snapname' in mhd.in) + inputdata = arrays of shape (nx, ny, nz) + momentum [in ebysus units] of ifluid + inputdatax is x-momentum, px; (px, py, pz) = (inputdatax, inputdatay, inputdataz) + ifluid must be entered. If not entered, raise TypeError. ifluid can be entered via one of: + - (mf_ispecies and mf_ilevel) + - **kw_ifluid, via the kwargs (ifluid), (iSL), or (iS and iL) + ''' + return write_mf_data(rootname, [inputdatax, inputdatay, inputdataz], 'mfp', + mf_ispecies=mf_ispecies, mf_ilevel=mf_ilevel, **kw_ifluid) + + +def write_mfpxyz(rootname, inputdataxyz, mf_ispecies, mf_ilevel, xyz): + '''write component of momentum. (Useful when using python to make initial snapshot; e.g. in make_mf_snap.py) + rootname = snapname (should be set equal to the value of parameter 'snapname' in mhd.in) + inputdataxyz = array of shape (nx, ny, nz) + momentum [in ebysus units] of ifluid, in x, y, OR z direction + (direction determined by parameter xyz) + mf_ispecies, mf_ilevel = int, int + species number and level number for ifluid. + xyz = 0 (for x), 1 (for y), 2 (for z) + determines which axis to write momentum along; e.g. xyz = 0 -> inputdataxyz is written to px. + ''' if mf_ispecies < 1: print('(WWW) species should start with 1') if mf_ilevel < 1: print('(WWW) levels should start with 1') - directory = '%s.io/mf_%02i_%02i/mfe' % (rootname,mf_ispecies,mf_ilevel) - nx, ny, nz = inputdata.shape + directory = '%s.io/mf_%02i_%02i/mfp' % (rootname, mf_ispecies, mf_ilevel) + nx, ny, nz = inputdataxyz.shape if not os.path.exists(directory): os.makedirs(directory) - data = np.memmap(directory+'/%s_mfe_%02i_%02i.snap' % (rootname,mf_ispecies,mf_ilevel), dtype='float32', mode='w+', order='f',shape=(nx,ny,nz,1)) - data[...,0] = inputdata + data = np.memmap(directory+'/%s_mfp_%02i_%02i.snap' % (rootname, mf_ispecies, mf_ilevel), dtype='float32', mode='w+', order='f', shape=(nx, ny, nz, 3)) + data[..., xyz] = inputdataxyz data.flush() -def write_mf_common(rootname,inputdatax,inputdatay,inputdataz,inputdatae=None): + +def write_mfe(rootname, inputdata, mf_ispecies=None, mf_ilevel=None, **kw_ifluid): + '''write energy. (Useful when using python to make initial snapshot; e.g. in make_mf_snap.py) + rootname = snapname (should be set equal to the value of parameter 'snapname' in mhd.in) + inputdata = array of shape (nx, ny, nz) + energy [in ebysus units] of ifluid + ifluid must be entered. If not entered, raise TypeError. ifluid can be entered via one of: + - mf_ispecies and mf_ilevel + - **kw_ifluid, via the kwargs (ifluid), (iSL), or (iS and iL) + ''' + return write_mf_data(rootname, [inputdata], 'mfe', + mf_ispecies=mf_ispecies, mf_ilevel=mf_ilevel, **kw_ifluid) + + +def write_mf_common(rootname, inputdatax, inputdatay, inputdataz, inputdatae=None): + '''write common (?? what is this ??). (Useful when using python to make initial snapshot; e.g. in make_mf_snap.py) + rootname = snapname (should be set equal to the value of parameter 'snapname' in mhd.in) + inputdata = arrays of shape (nx, ny, nz) + data for common. + inputdatax is x-common; (commonx, commony, commonz) = (inputdatax, inputdatay, inputdataz) + inputdatae = array of shape (nx, ny, nz), or None (default) + if non-None, written to common[...,3]. + ''' directory = '%s.io/mf_common' % (rootname) nx, ny, nz = inputdatax.shape if not os.path.exists(directory): os.makedirs(directory) if np.any(inputdatae) == None: - data = np.memmap(directory+'/%s_mf_common.snap' % (rootname), dtype='float32', mode='w+', order='f',shape=(nx,ny,nz,3)) - data[...,0] = inputdatax - data[...,1] = inputdatay - data[...,2] = inputdataz + data = np.memmap(directory+'/%s_mf_common.snap' % (rootname), dtype='float32', mode='w+', order='f', shape=(nx, ny, nz, 3)) + data[..., 0] = inputdatax + data[..., 1] = inputdatay + data[..., 2] = inputdataz else: - data = np.memmap(directory+'/%s_mf_common.snap' % (rootname), dtype='float32', mode='w+', order='f',shape=(nx,ny,nz,4)) - data[...,0] = inputdatae - data[...,1] = inputdatax - data[...,2] = inputdatay - data[...,3] = inputdataz + data = np.memmap(directory+'/%s_mf_common.snap' % (rootname), dtype='float32', mode='w+', order='f', shape=(nx, ny, nz, 4)) + data[..., 0] = inputdatae + data[..., 1] = inputdatax + data[..., 2] = inputdatay + data[..., 3] = inputdataz + data.flush() + + +def write_mf_commonxyz(rootname, inputdataxyz, xyz): + '''write common (?? what is this ??). (Useful when using python to make initial snapshot; e.g. in make_mf_snap.py) + rootname = snapname (should be set equal to the value of parameter 'snapname' in mhd.in) + inputdataxyz = array of shape (nx, ny, nz) + data for common. + (direction determined by parameter xyz) + xyz = 0 (for x), 1 (for y), 2 (for z) + determines which axis to write common along; e.g. xyz = 0 -> inputdataxyz is written to commonx. + ''' + directory = '%s.io/mf_common' % (rootname) + nx, ny, nz = inputdataxyz.shape + if not os.path.exists(directory): + os.makedirs(directory) + data = np.memmap(directory+'/%s_mf_common.snap' % (rootname), dtype='float32', mode='w+', order='f', shape=(nx, ny, nz, 4)) + data[..., xyz] = inputdataxyz data.flush() -def write_mf_e(rootname,inputdata): + +def write_mf_e(rootname, inputdata): + ''' write electron energy. (Useful when using python to make initial snapshot; e.g. in make_mf_snap.py) + rootname = snapname (should be set equal to the value of parameter 'snapname' in mhd.in) + inputdata = array of shape (nx, ny, nz) + energy [in ebysus units] of electrons. + ''' directory = '%s.io/mf_e/' % (rootname) nx, ny, nz = inputdata.shape if not os.path.exists(directory): os.makedirs(directory) - data = np.memmap(directory+'/%s_mf_e.snap' % (rootname), dtype='float32', mode='w+', order='f',shape=(nx,ny,nz,1)) - data[...,0] = inputdata + data = np.memmap(directory+'/%s_mf_e.snap' % (rootname), dtype='float32', mode='w+', order='f', shape=(nx, ny, nz, 1)) + data[..., 0] = inputdata data.flush() -def printi(fdir='./',rootname='',it=1): - dd=EbysusData(rootname,fdir=fdir,verbose=False) - nspecies=len(dd.mf_tabparam['SPECIES']) - for ispecies in range(0,nspecies): - aa=at.atom_tools(atom_file=dd.mf_tabparam['SPECIES'][ispecies][2]) - nlevels=len(aa.params['lvl']) - print('reading %s'%dd.mf_tabparam['SPECIES'][ispecies][2]) - for ilevel in range(1,nlevels+1): - print('ilv = %i'%ilevel) - r=dd.get_var('r',it,mf_ilevel=ilevel,mf_ispecies=ispecies+1) * dd.params['u_r'] - print('dens=%6.2E,%6.2E g/cm3'%(np.min(r),np.max(r))) - ux=dd.get_var('ux',it,mf_ilevel=ilevel,mf_ispecies=ispecies+1) * dd.params['u_u'] / 1e5 - print('ux=%6.2E,%6.2E km/s'%(np.min(ux),np.max(ux))) - uy=dd.get_var('uy',it,mf_ilevel=ilevel,mf_ispecies=ispecies+1) * dd.params['u_u'] / 1e5 - print('uy=%6.2E,%6.2E km/s'%(np.min(uy),np.max(uy))) - uz=dd.get_var('uz',it,mf_ilevel=ilevel,mf_ispecies=ispecies+1) * dd.params['u_u'] / 1e5 - print('uz=%6.2E,%6.2E km/s'%(np.min(uz),np.max(uz))) - tg=dd.get_var('mfe_tg',it,mf_ilevel=ilevel,mf_ispecies=ispecies+1) - print('tg=%6.2E,%6.2E K'%(np.min(tg),np.max(tg))) - ener=dd.get_var('e',it,mf_ilevel=ilevel,mf_ispecies=ispecies+1) * dd.params['u_e'] - print('e=%6.2E,%6.2E erg'%(np.min(ener),np.max(ener))) - - bx=dd.get_var('bx',it) * dd.params['u_b'] - print('bx=%5.2E G'%np.max(bx)) - by=dd.get_var('by',it) * dd.params['u_b'] - print('by=%5.2E G'%np.max(by)) - bz=dd.get_var('bz',it) * dd.params['u_b'] - print('bz=%5.2E G'%np.max(bz)) - + +def calculate_fundamental_writeables(fluids, B, nr, v, tg, tge, uni): + '''calculates the fundamental variables, in ebysus units, ready to be written to snapshot. + + Fluid-dependent results are saved to fluids; others are returned as dict. + Electrons are not included in fluids; they are treated separately. + + Inputs + ------ + fluids: an at_tools.fluids.Fluids object + fluid-dependent results will be saved to attributes of this object. + Also, the information in it is necessary to do the calculations. + B : magnetic field [Gauss]. + a list of [Bx, By, Bz]; Bx, By, Bz can be constants, or arrays. + result['B'] = B + nr: number densities [per meter^3] of fluids + a list of values ---> fluids[i].nr = nr[i] for i in range(len(fluids)) + a single value ---> fluid.nr = nr for fluid in fluids + v: velocities [meter per second] of fluids + a list of vectors --> fluids[i].v = v[i] for i in range(len(fluids)) + a single vector --> fluid.v = v for fluid in fluids + tg: temperature [Kelvin] of fluids + a list of values ---> fluids[i].tg = tg[i] for i in range(len(fluids)) + a single value ---> fluid.tg = tg for fluid in fluids + tge: temperature [Kelvin] of electrons + uni: bifrost.Bifrost_units object + this object is used to convert all results to ebysus units, before saving. + (e.g., for v, really it will be fluids[i].v = v[i] / uni.usi_u) + + Outputs + ------- + Edits fluids attributes, and returns result (a dict). + All outputs (in result, and in fluid attributes) are in [ebysus units]. + Keys of result are: + result['B'] = magnetic field. B[0] = Bx, B[1] = By, B[2] = Bz. + result['ee'] = electron energy density + Attributes of fluids containing fundamental calculated values are: + fluids.rho = mass densities of fluids. + fluids.p = momentum densities of fluids. fluids.p[i][x] is for fluid i, axis x. + fluids.energy = energy densities of fluids. + + Side Effects + ------------ + Additional attributes of fluids which are affected by this function are: + fluids.nr = number densities [cm^-3] of fluids. + fluids.tg = temperatures [K] of fluids. + fluids.v = velocities of fluids. fluids.v[i][x] is for fluid i, axis x. + fluids.px = fluids.p[:, 0, ...]. x-component of momentum densities of fluids. + fluids.py = fluids.p[:, 1, ...]. y-component of momentum densities of fluids. + fluids.pz = fluids.p[:, 2, ...]. z-component of momentum densities of fluids. + + Units for Outputs and Side Effects are [ebysus units] unless otherwise specified. + ''' + orig_stack, orig_stack_axis = getattr(fluids, 'stack', None), getattr(fluids, 'stack_axis', None) + fluids.stack = True + fluids.stack_axis = -1 + # global quantities + B = np.asarray(B)/uni.u_b # [ebysus units] magnetic field + # fluid (and global) quantities + fluids.assign_scalars('nr', (np.asarray(nr) / 1e6)) # [cm^-3] number density of fluids + nre = np.sum(fluids.nr * fluids.ionization, axis=-1) # [cm^-3] number density of electrons + fluids.assign_scalars('tg', tg) # [K] temperature of fluids + tge = tge # [K] temperature of electrons + + def _energy(ndens, tg): # returns energy density [ebysus units] + return (ndens * tg * uni.k_b / (uni.gamma-1)) / uni.u_e + fluids.energy = _energy(fluids.nr, fluids.tg) # [ebysus units] energy density of fluids + energy_electrons = _energy(nre, tge) # [ebysus units] energy density of electrons + # fluid quantities + fluids.rho = (fluids.nr * fluids.atomic_weight * uni.amu) / uni.u_r # [ebysus units] mass density of fluids + fluids.assign_vectors('v', (np.asarray(v) / uni.usi_u)) # [ebysus units] velocity + for fluid in fluids: + fluid_v = fluid.v # want to get to shape (3, Nx, Ny, Nz), or (3, 1, 1, 1) for broadcasting with rho. + fluid_v = np.expand_dims(fluid_v, axis=tuple(range(1, 1+4-np.ndim(fluid_v)))) # (if already 4D, does nothing.) + fluid.p = fluid_v * fluid.rho + # fluids.p = fluids.v * fluids.rho # [ebysus units] momentum density + for x in AXES: + setattr(fluids, 'p'+x, fluids.p[dict(x=0, y=1, z=2)[x]]) # sets px, py, pz + # restore original stack, stack_axis of fluids object. + if orig_stack is not None: + fluids.stack = orig_stack + if orig_stack_axis is not None: + fluids.stack_axis = orig_stack_axis + return dict(B=B, ee=energy_electrons) + + +def write_fundamentals(rootname, fluids, B, ee, zero=0): + '''writes fundamental quantities using write funcs (write_mfr, write_mfp, etc). + Fundamental quantities are: + magnetic field, electron energy, + fluids energy densities, fluids mass densities, fluids momentum densities. + + Inputs + ------ + rootname: string + rootname = snapname (should be set equal to the value of parameter 'snapname' in mhd.in) + fluids: an at_tools.fluids.Fluids object + The following attributes of fluids will be written. They should be in [ebysus units]: + fluids.rho = mass densities of fluids. + fluids.p = momentum densities of fluids. fluids[i].p[x] is for fluid i, axis x. + fluids.energy = energy densities of fluids. + B : magnetic field + ee : electron energy density + zero: a number or array + zero will be added to all data before it is written. + Suggestion: use zero = np.zeros((nx, ny, nz)). + This ensure all data will be the correct shape, and will be reshaped if it is a constant. + + Example Usage: + -------------- + # This is an example which performs the same task as a simple make_mf_snap.py file. + from atom_py.at_tools import fluids as fl + import helita.sim.ebysus as eb + uni = eb.Bifrost_units('mhd.in') # get units + # put code here which sets the values for: + # nx, ny, nz, mf_param_file, snapname # << these are all from 'mhd.in'; suggestion: read via RunTools.loadfiles. + # B, nr, velocities, tg, tge # << these are physical values; you can choose here what they should be. + # once those values are set, we can run the following: + fluids = fl.Fluids(mf_param_file=mf_param_file) # get fluids + # calculate the values of the fundamental quantities, in [ebysus units]: + global_quants = eb.calculate_fundamental_writeables(fluids, B, nr, velocities, tg, tge, uni) + zero = np.zeros((nx,ny,nz)) + # write the values (thus, completing the process of making the initial snapshot): + eb.write_fundamentals(rootname, fluids, **global_quants, zero=zero) + ''' + ## Fluid Densities ## + for fluid in fluids: + write_mfr(rootname, zero+fluid.rho, ifluid=fluid.SL) + ## Fluid Momenta ## + for fluid in fluids: + write_mfp(rootname, zero+fluid.p[0], zero+fluid.p[1], zero+fluid.p[2], ifluid=fluid.SL) + ## Fluid Energies ## + if len(fluids) > 1: + for fluid in fluids: + write_mfe(rootname, zero+fluid.energy, ifluid=fluid.SL) + ## Electron Energy ## + write_mf_e(rootname, zero+ee) + + ## Magnetic Field ## + write_mf_common(rootname, zero+B[0], zero+B[1], zero+B[2]) + else: + write_mf_common(rootname, zero+B[0], zero+B[1], zero+B[2], fluid.energy) + + +def printi(fdir='./', rootname='', it=1): + '''?? print data about snapshot i ?? (seems to not work though; SE checked on Mar 2, 2021).''' + dd = EbysusData(rootname, fdir=fdir, verbose=False) + nspecies = len(dd.mf_tabparam['SPECIES']) + for ispecies in range(0, nspecies): + aa = at.Atom_tools(atom_file=dd.mf_tabparam['SPECIES'][ispecies][2], fdir=fdir) + nlevels = aa.params.nlevel + print('reading %s' % dd.mf_tabparam['SPECIES'][ispecies][2]) + for ilevel in range(1, nlevels+1): + print('ilv = %i' % ilevel) + r = dd.get_var('r', it, mf_ilevel=ilevel, mf_ispecies=ispecies+1) * dd.params['u_r'] + print('dens=%6.2E,%6.2E g/cm3' % (np.min(r), np.max(r))) + r = dd.get_var('nr', it, mf_ilevel=ilevel, mf_ispecies=ispecies+1) + print('ndens=%6.2E,%6.2E 1/cm3' % (np.min(r), np.max(r))) + ux = dd.get_var('ux', it, mf_ilevel=ilevel, mf_ispecies=ispecies+1) * dd.params['u_u'] / 1e5 + print('ux=%6.2E,%6.2E km/s' % (np.min(ux), np.max(ux))) + uy = dd.get_var('uy', it, mf_ilevel=ilevel, mf_ispecies=ispecies+1) * dd.params['u_u'] / 1e5 + print('uy=%6.2E,%6.2E km/s' % (np.min(uy), np.max(uy))) + uz = dd.get_var('uz', it, mf_ilevel=ilevel, mf_ispecies=ispecies+1) * dd.params['u_u'] / 1e5 + print('uz=%6.2E,%6.2E km/s' % (np.min(uz), np.max(uz))) + tg = dd.get_var('mfe_tg', it, mf_ilevel=ilevel, mf_ispecies=ispecies+1) + print('tg=%6.2E,%6.2E K' % (np.min(tg), np.max(tg))) + ener = dd.get_var('e', it, mf_ilevel=ilevel, mf_ispecies=ispecies+1) * dd.params['u_e'] + print('e=%6.2E,%6.2E erg' % (np.min(ener), np.max(ener))) + + bx = dd.get_var('bx', it) * dd.params['u_b'] + print('bx=%6.2E,%6.2E G' % (np.min(bx), np.max(bx))) + by = dd.get_var('by', it) * dd.params['u_b'] + print('by=%6.2E,%6.2E G' % (np.min(by), np.max(by))) + bz = dd.get_var('bz', it) * dd.params['u_b'] + print('bz=%6.2E,%6.2E G' % (np.min(bz), np.max(bz))) + va = dd.get_var('va', it) * dd.params['u_u'] / 1e5 + print('va=%6.2E,%6.2E km/s' % (np.min(va), np.max(va))) + +################### +# READING FILES # +################### + + +@file_memory.manage_memmaps(file_memory.MEMORY_MEMMAP) +@file_memory.remember_and_recall(file_memory.MEMORY_MEMMAP, ORDERED=True) +def get_numpy_memmap(filename, **kw__np_memmap): + '''makes numpy memmap; also remember and recall (i.e. don't re-make memmap for the same file multiple times.)''' + return np.memmap(filename, **kw__np_memmap) + + +def load_zarr(filename, array_n=None): + '''reads zarr from file. if array_n is provided, index by [..., array_n].''' + if not os.path.exists(filename): + raise FileNotFoundError(filename) + # zarr error for non-existing file is confusing and doesn't include filename (as of 02/28/22) + # so we instead do our own check if file exists, and raise a nice error if it doesn't exist. + z = zarr.open(filename, mode='r') # we use 'open' instead of 'load' to ensure we only read the required chunks. + if array_n is None: + result = z[...] + else: + result = z[..., array_n] + return result + + +def save_filebinary_to_filezarr(src, dst, shape, dtype=' 2): - # print(('(WWW) read_params: line %i is invalid, skipping' % li)) - # li += 1 - # continue - if (np.size(line) == 1): - key = line - ii = 0 - # force lowercase because IDL is case-insensitive - if (np.size(line) == 2): - value = line[0].strip() - text = line[1].strip().lower() - try: - value = int(value) - except BaseException: - print('(WWW) read_mftab_ascii: could not find datatype in' - 'line %i, skipping' % li) - li += 1 - continue - if not (key[0] in params): - params[key[0]] = [value, text] - else: - params[key[0]] = np.vstack((params[key[0]], [value, text])) - if (np.size(line) == 3): - value = line[0].strip() - value2 = line[1].strip() - text = line[2].strip() - if key != 'species': - try: - value = int(value) - except BaseException: - print( - '(WWW) read_mftab_ascii: could not find datatype' - 'in line %i, skipping' % li) - else: - try: - value = int(value) - value2 = int(value2) - except BaseException: - print( - '(WWW) read_mftab_ascii: could not find datatype' - 'in line %i, skipping' % li) - li += 1 - continue - if not (key[0] in params): - params[key[0]] = [value, value2, text] - else: - params[key[0]] = np.vstack( - (params[key[0]], [value, value2, text])) - if (np.size(line) > 3): - # int type - try: - arr = [int(numeric_string) for numeric_string in line] - except BaseException: - print('(WWW) read_mftab_ascii: could not find datatype in' - 'line %i, skipping' % li) - li += 1 - continue - if not (key[0] in params): - params[key[0]] = [arr] - else: - params[key[0]] = np.vstack((params[key[0]], [arr])) - li += 1 + elif len(tokens) == 1: + key = tokens[0] + params[key] = [] + for colstart in colstartkeys: + if key.startswith(colstart): + convert_to_ints = True + else: + if convert_to_ints: + tokens = [int(token) for token in tokens] + params[key] += [tokens] + + for key in params.keys(): + params[key] = np.array(params[key]) + return params +read_mf_param_file = read_mftab_ascii # alias + +#################### +# LOCATING SNAPS # +#################### + + +class SnapfileNotFoundError(FileNotFoundError): + '''custom error class for reporting snap not found; subclass of FileNotFoundError.''' + + +def get_snap_files(snap, snapname=None, read_mode=None, dd=None, include_aux=True): + '''returns the minimal list of filenames for all files specific to this snap (a number). + if no data for snap is found, raises SnapfileNotFoundError (subclass of FileNotFoundError). + Directories containing solely files for this snap will be reported as the directory, not all contents. + (e.g. for zarray "files", which are directories, the zarray directory will be reported, not all its contents.) + + read_mode : string (e.g. 'io' or 'zc') or None; see EbysusData.read_mode. Defaults to 'io'. + snapname : snapshot name, or None. + dd : EbysusData object or None. If provided, used to guess snapname & read_mode as necessary. + include_aux: whether to include aux files in the result. + + This method expects to be called with the working directory set to the "main folder" for the run, + i.e. the directory containing the .idl files with parameters and the .io folder with snapshot data. + ''' + result = [] + # snapname, read_mode. set default values if not provided; use dd to help if dd is provided. + if snapname is None: + snapname = get_snapname(dd=dd) + if read_mode is None: + read_mode = 'io' if dd is None else dd.read_mode + # snapdir (e.g. 'snapname.io') + snapdir = f'{snapname}.{read_mode}' + if not os.path.isdir(snapdir): + raise SnapfileNotFoundError(repr(snapdir)) + # snapidl (e.g. 'snapname_072.idl') + Nstr = _N_to_snapstr(snap) + snapidl = f'{snapname}{Nstr}.idl' + if not os.path.isfile(snapidl): + raise SnapfileNotFoundError(repr(snapidl)) + else: + result.append(snapidl) + # snapshot data - checking + + def _is_snapN_data(name): + return _is_Nstr_snapfile(Nstr, name, '.snap') or (include_aux and _is_Nstr_snapfile(Nstr, name, '.aux')) + # looping through files & directories + for dirpath, dirnames, filenames in os.walk(snapdir, topdown=True): # topdown=True required for "skip popped dirnames" behavior. + # check if files are snapshot data files. + for fname in filenames: + if _is_snapN_data(fname): + result.append(os.path.join(dirpath, fname)) + # check if directories are snapshot data files (e.g. this will be the case for "zarray" storage system, read_mode='zc'). + i = 0 + while i < len(dirnames): + dname = dirnames[i] + if dname.endswith(f'.snap') or dname.endswith(f'.aux'): + del dirnames[i] # we don't need to os.walk any further down from a directory ending with '.snap' or '.aux' + if _is_snapN_data(dname): + result.append(os.path.join(dirpath, dname)) + else: + i += 1 + return result + + +def _is_Nstr_snapfile(Nstr, filename, ext='.snap'): + '''returns whether filename is a '.snap' file associated with snap indicated by Nstr. + This is only difficult when Nstr==''; else we just check if filename looks like '{stuff}{Nstr}.snap'. + For '.aux' checking, use ext='.aux' instead of the default '.snap'. + ''' + # pop extension + if filename.endswith(ext): + basename = filename[: -len(ext)] + else: + return False # << proper extension is required, otherwise not a snap file. + # handle "easy" case (Nstr != '') + if Nstr != '': + return basename.endswith(f'{Nstr}') + # else: Nstr == '', so we need to do a lot more work. + # in particular, we recognize snap 0 only in these cases: + # case A) 'stuff{C}.ext' with C non-numeric + # case B) 'stuff{C}_{SS}_{LL}.ext' with C non-numeric, SS numeric, LL numeric + # all other cases are considered to be a snap other than 0, so we ignore them. + if not basename[-1].isdigit(): + return True # << we are in case A. "no-fluid" case. + # consider case B: + stuffC_SS, underscore, LL = basename.rpartition('_') + if not (underscore == '_' and LL.isdigit()): + return False + stuffC, underscore, SS = stuffC_SS.rpartition('_') + if not (underscore == '_' and SS.isdigit()): + return False + if stuffC[-1].isdigit(): + return False + return True # << we are in case B. + + +############################# +# WRITING PARAMETER FILES # +############################# + +def coll_keys_generate(mf_param_file='mf_params.in', as_str=True): + '''generates COLL_KEYS such that all collisions will be turned on. + + COLL_KEYS look like: + II JJ TT + where II is ispecies, JJ is jspecies, TT is ('MX', 'EL', or 'CL'), and this line means: + turn on TT collisions between II ions and JJ (any level). + 'EL' --> "elastic". This should only be used when we have the collisions tables. + 'MX' --> "maxwell". Assume "maxwell molecules" (velocity-independent collision frequency). + 'CL' --> "coulomb". For ion-ion collisions. (Only applies to ion-ion collisions). + + if as_str, return a string which can be copy-pasted into an mf_param_file. + Otherwise, return an 2D array with result[i] = [AAi, BBi, TTi]. + ''' + x = read_mftab_ascii(mf_param_file) + + def levels_ions_neutrals(atomfile): + '''returns (levels of ions in atomfile, levels of neutrals in atomfile)''' + fluids = fl.Fluids([atomfile]) + return (fluids.ions().level_no, fluids.neutrals().level_no) + + species = {iS: levels_ions_neutrals(file) for (iS, elem, file) in x['SPECIES']} + tables = collections.defaultdict(list) + for (neuS, ionS, ionL, file) in x['CROSS_SECTIONS_TABLES']: + tables[(neuS, ionS)].append(ionL) # tables keys (neuS, ionS); vals lists of ionL. + + def table_exists(neuS, ionS, ion_levels): + '''tells whether a table exists between neutralSpecie and ionSpecie, + at at least one of the levels in ion_levels. + ''' + for ionL in tables.get((neuS, ionS), []): + if int(ionL) in ion_levels: # (note that ion_levels are ints). + return True + return False + coll_keys = [] + for (iS, (ilevels_ion, ilevels_neu)) in species.items(): + if len(ilevels_ion) == 0: # if there are no i ions, + continue # continue, because no coll_keys start with iS in this case. + for (jS, (jlevels_ion, jlevels_neu)) in species.items(): + # ion-neutral collisions: + if len(jlevels_neu) >= 1: + if table_exists(jS, iS, ilevels_ion): + coll_keys.append((iS, jS, 'EL')) + else: + coll_keys.append((iS, jS, 'MX')) + # ion-ion collisions: + make_CL = False + if iS == jS: + if len(ilevels_ion) >= 2: # ilevels_ion == jlevels_ion + make_CL = True + else: + if len(jlevels_ion) >= 1: + make_CL = True + if make_CL: + coll_keys.append((iS, jS, 'CL')) + if not as_str: + return np.array(coll_keys) + else: + fmtstr = ' {} {} {}' + result = 'COLL_KEYS\n' + result += '\n'.join([fmtstr.format(*collkey_row) for collkey_row in coll_keys]) + return result + + +def write_idlparamsfile(snapname, mx=1, my=1, mz=1): + '''Write default .idl file''' + default_idl = [ + '; ************************* From params ************************* \n', + ' mx = {} \n'.format(mx), + ' my = {} \n'.format(my), + ' mz = {} \n'.format(mz), + ' mb = 5 \n', + ' nstep = 10 \n', + ' nstepstart = 0 \n', + ' debug = 0 \n', + ' time_lim = -1.000E+00 \n', + ' tstop = -1.00000000E+00 \n', + 'mf_total_nlevel = 5 \n', + ' mf_electrons = 0 \n', + ' mf_epf = 1 \n', + ' mf_nspecies = 2 \n', + ' mf_param_file = "mf_params.in" \n', + '; ************************* From parallel ************************* \n', + ' periodic_x = 1 \n', + ' periodic_y = 1 \n', + ' periodic_z = 0 \n', + ' ndim = 3 \n', + ' reorder = 1 \n', + '; ************************* From units ************************* \n', + ' u_l = 1.000E+08 \n', + ' u_t = 1.000E+02 \n', + ' u_r = 1.000E-07 \n', + ' u_p = 1.000E+05 \n', + ' u_u = 1.000E+06 \n', + ' u_kr = 1.000E-01 \n', + ' u_ee = 1.000E+12 \n', + ' u_e = 1.000E+05 \n', + ' u_te = 1.000E+11 \n', + ' u_tg = 1.212E+04 \n', + ' u_B = 1.121E+03 \n', + '; ************************* From stagger ************************* \n,' + ' meshfile = "{}.mesh" \n'.format(snapname), + ' dx = 1.000E+00 \n', + ' dy = 1.000E+00 \n', + ' dz = 2.993E-02 \n', + '; ************************* From timestep ************************* \n', + ' Cdt = 0.030 \n', + ' dt = 1.e-11 \n', + ' t = 0.0 \n', + ' timestepdebug = 0 \n', + '; ************************* From mhd ************************* \n', + ' nu1 = 0.100 \n', + ' nu2 = 0.300 \n', + ' nu3 = 0.800 \n', + ' nu_r = 0.100 \n', + ' nu_r_z = 9.990E+02 \n', + ' nu_r_mz = 0.100 \n', + ' nu_ee = 0.100 \n', + ' nu_ee_z = 9.990E+02 \n', + ' nu_ee_mz = 0.100 \n', + ' nu_e_ee = 0.000 \n', + ' nu_e_ee_z = 9.990E+02 \n', + ' nu_e_ee_mz = 0.000 \n', + ' symmetric_e = 0 \n', + ' symmetric_b = 0 \n', + ' grav = -2.740 \n', + ' eta3 = 3.000E-01 \n', + ' ca_max = 0.000E+00 \n', + ' mhddebug = 0 \n', + ' do_mhd = 1 \n', + ' mhdclean = -1 \n', + ' mhdclean_ub = 0 \n', + ' mhdclean_lb = 0 \n', + ' mhdclean_ubx = 0 \n', + ' mhdclean_lbx = 0 \n', + ' mhdclean_uby = 0 \n', + ' mhdclean_lby = 0 \n', + ' do_e_joule = 1 \n', + ' do_ion_joule = 1 \n', + ' nue1 = 0.050 \n', + ' nue2 = 0.100 \n', + ' nue3 = 0.050 \n', + ' nue4 = 0.000 \n', + '; ************************* From io ************************* \n', + ' one_file = 0 \n', + ' snapname = "{}" \n'.format(snapname), + ' isnap = 0 \n', + ' large_memory = 1 \n', + ' nsnap = 100000000 \n', + ' nscr = 250 \n', + ' aux = " nel mfe_tg etg " \n', + ' dtsnap = 5.000E-09 \n', + ' newaux = 0 \n', + ' rereadpars = 1000000 \n', + ' dtscr = 1.000E+04 \n', + ' tsnap = 0.0 \n', + ' tscr = 0.00000000E+00 \n', + ' boundarychk = 0 \n', + ' boundarychky = 0 \n', + ' boundarychkx = 0 \n', + ' print_stats = 0 \n', + '; ************************* From math ************************* \n', + ' max_r = 5 \n', + ' smooth_r = 3 \n', + ' divhc_niter = 1000 \n', + ' divhc_cfl = 0.400 \n', + ' divhc_r = 0.180 \n', + ' divhc_vxr = 0.000 \n', + ' divhc_vyr = 0.000 \n', + ' divhc_vzr = 0.950 \n', + ' divhc_tol = 1.000E-05 \n', + '; ************************* From quench ************************* \n', + ' qmax = 8.000 \n', + '; ************************* From eos ************************* \n', + ' gamma = 1.667 \n', + ' eosdebug = 0 \n', + '; ************************* From collisions utils ************* \n', + ' do_col = 0 \n', + ' col_debug = 0 \n', + ' do_qcol = 1 \n', + ' do_ecol = 0 \n', + 'col_calc_nu_in = 1 \n', + 'col_const_nu_in = -1.000E+03 \n', + ' col_cnu_max = 1.000E+03 \n', + ' col_utiny = -1.000E-05 \n', + 'col_trans_tim0 = 0.000E+00 \n', + ' col_trans_dt = 1.000E+00 \n', + 'col_trans_ampl = 1.000E-10 \n', + ' col_tabin = "mf_coltab.in" \n', + '; ************************* From collisions ************* \n', + ' qcol_method = "expl" \n', + 'col_matrix_norm = 0 \n', + '; ************************* From ionrec ************* \n', + ' qri_method = "impl" \n', + '; ************************* From mf_recion (utils) ************* \n', + ' do_recion = 0 \n', + ' recion_debug = 0 \n', + ' calc_freq = 1 \n', + ' three_bdy = 1 \n', + ' const_fion = -1.000E+00 \n', + ' const_frec = -1.000E+00 \n', + ' recion_tabin = "mf_reciontab.in" \n', + 'recion_modname = "atomic" \n', + '; ************************* From hall ************************* \n', + ' do_hall = "false" \n', + ' tstep_hall = "ntsv" \n', + ' eta_hallo = 1.000E+00 \n', + ' eta4_hall = [ 0.100, 0.100, 0.100 ] \n', + 'mts_max_n_hall = 10 \n', + '; ************************* From Bierman ************************* \n', + ' do_battery = 0 \n', + ' bb_bato = 1.000E+00 \n', + 'bb_extdyn_time = -1.000E+00 \n', + ' bb_ext_bb = 0.000E+00 \n', + 'bb_debug_battery = 0 \n', + ' do_qbat = 0 \n', + '; ************************* From ohm_ecol ************* \n', + ' do_ohm_ecol = 0 \n', + ' do_qohm = 1 \n', + 'ec_ohm_ecoll_debug = 0 \n', + ' ec_calc_nu_en = 1 \n', + ' ec_calc_nu_ei = 1 \n', + 'ec_const_nu_en = -1.000E+00 \n', + 'ec_const_nu_ei = -1.000E+00 \n', + ' ec_tabin = "mf_ecoltab.in" \n', + 'mf_eparam_file = "mf_eparams.in" \n', + '; ************************* From spitzer ************************* \n', + ' spitzer = "impl" \n', + ' debug_spitzer = 0 \n', + ' info_spitzer = 0 \n', + ' spitzer_amp = 0.000 \n', + ' theta_mg = 0.900 \n', + ' dtgerr = 1.000E-05 \n', + ' ntest_mg = 1 \n', + ' tgb0 = 0.000E+00 \n', + ' tgb1 = 0.000E+00 \n', + ' tau_tg = 1.000E+00 \n', + ' fix_grad_tg = 1 \n', + ' niter_mg = [ 2, 5, 5, 5, 30 ] \n', + ' bmin = 1.000E-04 \n', + ' kappaq0 = 0.000E+00 \n', + '; ************************* From genrad ************************* \n', + ' do_genrad = 1 \n', + ' genradfile = "qthresh.dat" \n', + ' debug_genrad = 0 \n', + ' incrad_detail = 0 \n', + ' incrad_quad = 3 \n', + ' dtincrad = 1.000E-03 \n', + ' dtincrad_lya = 1.000E-04 \n', + ' debug_incrad = 0 \n', + '; ************************* From ue_electric ************* \n', + 'do_ue_electric = 1 \n', + 'ue_electric_debug = 0 \n', + 'ue_fudge_mass = 1.000E+00 \n', + ' ue_incr = 0.000 \n', + ' ue_dt_inc = -1.000E+00 \n', + ' ue_nu = [ 0.000, 0.000, 0.000, 0.000, 0.000 ] \n', + ' eionsfrz = 1 \n', + '; ************************* From bc_lowerx_magnetic ************* \n', + ' bctypelowerx = "mcccc" \n', + ' bcldebugx = 0 \n', + ' nextrap_bclx = 1 \n', + ' nsmooth_bclx = 0 \n', + 'nsmoothbyz_bcl = 0 \n', + '; ************************* From bc_upperx_magnetic ************* \n', + ' bctypeupperx = "mcccc" \n', + ' bcudebugx = 0 \n', + ' nextrap_bcux = 1 \n', + ' nsmooth_bcux = 0 \n', + 'nsmoothbyz_bcu = 0 \n', + '; ************************* From bc_lowery_magnetic ************* \n', + ' bctypelowery = "mcccc" \n', + ' bcldebugy = 0 \n', + ' nextrap_bcly = 1 \n', + ' nsmooth_bcly = 0 \n', + 'nsmoothbxz_bcl = 0 \n', + '; ************************* From bc_uppery_magnetic ************* \n', + ' bctypeuppery = "mcccc" \n', + ' bcudebugy = 0 \n', + ' nextrap_bcuy = 1 \n', + ' nsmooth_bcuy = 0 \n', + 'nsmoothbxz_bcu = 0 \n', + '; ************************* From bc_lowerz_magnetic ************* \n', + ' bctypelowerz = "mesec" \n', + ' bcldebugz = 0 \n', + ' nextrap_bclz = 1 \n', + ' nsmooth_bclz = 0 \n', + 'nsmoothbxy_bcl = 0 \n', + '; ************************* From bc_upperz_magnetic ************* \n', + ' bctypeupperz = "mesec" \n', + ' bcudebugz = 0 \n', + ' nextrap_bcuz = 1 \n', + ' nsmooth_bcuz = 0 \n', + 'nsmoothbxy_bcu = 0 \n' + ] + out = open('{}.idl'.format(snapname), 'w') + out.writelines(default_idl) + return + + +def keyword_update(inoutfile, new_values): + ''' Updates a given number of fields with values on a snapname.idl file. + These are given in a dictionary: fvalues = {field: value}. + Reads from snapname.idl and writes back into the same file.''' + lines = list() + with open(inoutfile) as f: + for line in f.readlines(): + if line[0] == '#' or line[0] == ';': + continue + elif line.find('=') < 0: + continue + else: + ss = line.split('=')[0] + ssv = ss.strip().lower() + if ssv in list(new_values.keys()): + line = '{} = {} \n'.format(ss, str(new_values[ssv])) + lines.append(line) + + with open(inoutfile, "w") as f: + f.writelines(lines) + + def write_mftab_ascii(filename, NSPECIES_MAX=28, SPECIES=None, EOS_TABLES=None, REC_TABLES=None, ION_TABLES=None, CROSS_SECTIONS_TABLES=None, @@ -863,21 +2487,21 @@ def write_mftab_ascii(filename, NSPECIES_MAX=28, ''' if SPECIES is None: - SPECIES=['H_2.atom', 'He_2.atom'] + SPECIES = ['H_2.atom', 'He_2.atom'] if EOS_TABLES is None: - EOS_TABLES=['H_EOS.dat', 'He_EOS.dat'] + EOS_TABLES = ['H_EOS.dat', 'He_EOS.dat'] if REC_TABLES is None: - REC_TABLES=['h_rec.dat', 'he_rec.dat'] + REC_TABLES = ['h_rec.dat', 'he_rec.dat'] if ION_TABLES is None: - ION_TABLES=['h_ion.dat', 'he_ion.dat'] + ION_TABLES = ['h_ion.dat', 'he_ion.dat'] if CROSS_SECTIONS_TABLES is None: - CROSS_SECTIONS_TABLES=[[1, 1, 'p-H-elast.txt'], - [1, 2, 'p-He.txt'], - [2, 2, 'He-He.txt']] + CROSS_SECTIONS_TABLES = [[1, 1, 'p-H-elast.txt'], + [1, 2, 'p-He.txt'], + [2, 2, 'He-He.txt']] if CROSS_SECTIONS_TABLES_I is None: - CROSS_SECTIONS_TABLES_I=[] + CROSS_SECTIONS_TABLES_I = [] if CROSS_SECTIONS_TABLES_N is None: - CROSS_SECTIONS_TABLES_N=[] + CROSS_SECTIONS_TABLES_N = [] params = [ 'NSPECIES_MAX', @@ -1008,7 +2632,7 @@ def write_mftab_ascii(filename, NSPECIES_MAX=28, for symb in SPECIES: symb = symb.split('_')[0] - if not(symb.lower() in coll_vars_list): + if not (symb.lower() in coll_vars_list): print('write_mftab_ascii: WARNING there may be a mismatch between' 'the atom files and selected species.\n' 'Check for species', symb.lower()) @@ -1113,7 +2737,7 @@ def write_mftab_ascii(filename, NSPECIES_MAX=28, f.write("\t" + "\t".join( [str(int( COLISIONS_MAP[crs][v])).zfill(2) for v in range( - 0, NSPECIES_MAX)]) + "\n") + 0, NSPECIES_MAX)]) + "\n") f.write("\n") if head == 'COLISIONS_MAP_I': f.write("#\t" + "\t".join( @@ -1122,7 +2746,7 @@ def write_mftab_ascii(filename, NSPECIES_MAX=28, for crs in range(0, NSPECIES_MAX): f.write("\t" + "\t".join([str(int( COLISIONS_MAP_I[crs][v])).zfill(2) for v in range( - 0, NSPECIES_MAX)]) + "\n") + 0, NSPECIES_MAX)]) + "\n") f.write("\n") if head == 'COLISIONS_MAP_N': f.write("#\t" + "\t".join( @@ -1131,7 +2755,7 @@ def write_mftab_ascii(filename, NSPECIES_MAX=28, for crs in range(0, NSPECIES_MAX): f.write("\t" + "\t".join([str(int( COLISIONS_MAP_N[crs][v])).zfill(2) for v in range( - 0, NSPECIES_MAX)]) + "\n") + 0, NSPECIES_MAX)]) + "\n") f.write("\n") if head == 'EMASK': f.write("#\t" + "\t".join( @@ -1139,6 +2763,140 @@ def write_mftab_ascii(filename, NSPECIES_MAX=28, 0, NSPECIES_MAX)]) + "\n") f.write("\t" + "\t".join([str( int(EMASK_MAP[v])).zfill(2) for v in range( - 0, NSPECIES_MAX)]) + "\n") + 0, NSPECIES_MAX)]) + "\n") f.write("\n") f.close() + +###################### +# DESTROYING FILES # +###################### + + +def smash_folder(folder, mode='trash', warn=True, _force_no_warn=False): + '''smashes (destroys or moves to trash) folder. + mode: 'trash' (default) or one of ('destroy', 'delete', 'rm') + 'trash' --> move folder to trash (determined by os.environ['TRASH']) + 'destroy', 'delete', or 'rm' --> destroy the folder permanently as per shutil.rmtree. + warn: bool, default True + whether to ask for user confirmation before smashing the folder. + CAUTION: be very careful about using warn=False! + _force_no_warn: bool, default False + if folder has fewer than N components, warn is set to True unless _force_no_warn. + N = 3 for 'trash' mode; N = 4 for 'destroy' mode. + E.g. /Users/You/ has 2 components. /Users/You/Folder/Subfolder has 4 components. + + returns one of: + None (if smashing was aborted, due to user input) + new path to folder, in trash (for mode='trash') + old path to now-destroyed folder (for mode='destroy') + ''' + mode = mode.lower() + VALID_MODES = 'trash', 'destroy', 'delete', 'rm' + assert mode in VALID_MODES, f"mode={repr(mode)} invalid; expected one of: {VALID_MODES}" + if mode == 'trash': + result = _trash_folder(folder, warn=warn, _force_no_warn=_force_no_warn) + if result is not None: + print(f"trashed folder; now located at {repr(result)}") + elif mode in ('destroy', 'delete', 'rm'): + result = _destroy_folder(folder, warn=warn, _force_no_warn=_force_no_warn) + if result is not None: + print(f"permanently destroyed folder {repr(result)}") + else: + raise NotImplementedError(f"mode={repr(mode)}") # << we should never reach this line. + return result + + +def _trash_folder(folder, warn=True, _force_no_warn=False): + '''moves the indicated folder to Trash (which the User can empty at a later time). + Uses os.environ['TRASH'] to determine the location of trash. + (You must set os.environ['TRASH'] before using this function) + E.g. on macOS, the standard Trash is at os.environ['TRASH']='~/.Trash' + + if warn (default True), first ask for user confirmation. + CAUTION: be very careful about using warn=False! + + if folder has fewer than 3 components, warn is set to True unless _force_no_warn. + E.g. /Users/You/ has 2 components. /Users/You/Folder/Subfolder has 4 components. + + returns the new abspath to the folder (in the trash), or None if aborted. + ''' + # preprocessing - folder exists? + folder = os.path.abspath(folder) + if not os.path.exists(folder): + raise FileNotFoundError(folder) + # preprocessing - trash exists? + try: + trash_ = os.environ['TRASH'] + except KeyError: + errmsg = ("_trash_folder() requires os.environ['TRASH'] to be set.\n" + "Set it via os.environ['TRASH']='~/.Trash' (or some other value, as appropriate)") + raise AttributeError(errmsg) from None + trash = os.path.abspath(os.path.expanduser(trash_)) # expanduser handles '~'. + # preprocessing - are we warning the user? + if _force_no_warn: + warn = False + else: + MIN_N_COMPONENTS = 3 + if _count_components(folder) < MIN_N_COMPONENTS: + warn = True # force warn to True for "small" paths. + # possible warning + if warn: + confirm_msg = f'About to move to trash (at {repr(trash)}) the folder:\n {repr(folder)}\n' + \ + "Proceed? ('y' or empty string for yes; 'n' or anything else for no.)\n" + input_ = input(confirm_msg) + input_ = input_.lower() + if input_ not in ('y', ''): + print('Aborted. No files were moved to trash.') + return + # check if folder already in trash; edit name if necessary + folder_basename = os.path.basename(folder) + dst = os.path.join(trash, folder_basename) + if os.path.exists(dst): # append time if necessary to make name unique in trash + dst = dst + ' ' + time.strftime('%I.%M.%S %p') # e.g. '12.07.59 PM' + if os.path.exists(dst): # append date if necessary to make name unique in trash + dst = dst + ' ' + time.strftime('%m_%d_%Y') # e.g. '03_01_2022' + # actually trash the folder + result = shutil.move(folder, dst) + # return the old path of the now-deleted folder. + return result + + +def _destroy_folder(folder, warn=True, _force_no_warn=False): + '''destroys the indicated folder. + if warn (default True), first ask for user confirmation. + CAUTION: be very careful about using warn=False! + + if folder has fewer than 4 components, warn is set to True unless _force_no_warn. + E.g. /Users/You/ has 2 components. /Users/You/Folder/Subfolder has 4 components. + + returns the abspath to the now-deleted folder, or None if aborted. + ''' + # preprocessing + folder = os.path.abspath(folder) + if not os.path.exists(folder): + raise FileNotFoundError(folder) + if _force_no_warn: + warn = False + else: + MIN_N_COMPONENTS = 4 + if _count_components(folder) < MIN_N_COMPONENTS: + warn = True # force warn to True for "small" paths. + # possible warning + if warn: + confirm_msg = f'About to destroy the folder:\n {repr(folder)}\nProceed? ' + \ + "('y' or empty string for yes; 'n' or anything else for no.)\n" + input_ = input(confirm_msg) + input_ = input_.lower() + if input_ not in ('y', ''): + print('Aborted. No files were destroyed.') + return + # actually remove the folder + shutil.rmtree(folder) + # return the old path of the now-deleted folder. + return folder + + +def _count_components(path): + '''counts components in the provided path. + E.g. /Users/You/ has 2 components. /Users/You/Folder/Subfolder has 4 components.''' + return len(os.path.normpath(path).split(path)) diff --git a/helita/sim/fake_ebysus_data.py b/helita/sim/fake_ebysus_data.py new file mode 100644 index 00000000..43674450 --- /dev/null +++ b/helita/sim/fake_ebysus_data.py @@ -0,0 +1,517 @@ +""" +File purpose: + Access the load_..._quantities calculations without needing to write a full snapshot. + + Examples where this module is particularly useful: + - quickly testing values of quantities from helita postprocessing, for small arrays of data. + - check what would happen if a small number of changes are made to existing data. + +[TODO] + - allow FakeEbysusData.set_var to use units other than 'simu' + (probably by looking up var in vardict to determine units). + Note - this is already allowed in set_var_fundamanetal + (which, by default, gets called from set_var for var in FUNDAMENTAL_SETTABLES). +""" + +import os +# import built-ins +import shutil +import warnings +import collections + +# import external public modules +import numpy as np + +# import internal modules +from . import document_vars, ebysus, file_memory, tools, units + +AXES = ('x', 'y', 'z') + + +class FakeEbysusData(ebysus.EbysusData): + '''Behaves like EbysusData but allows to choose artificial values for any quant. + + If a quant is requested (via self(quant) or self.get_var(quant)), + first check if that quant has been set artificially, using the artificial value if found. + If the quant has not been set artificially, try to read it from snapshot as normal. + + No snapshot data are required to utilize FakeEbysusData. + All "supporting materials" are required as normal, though: + - snapname.idl ^(see footnote) + - mf_params.in (filename determined by mf_param_file in snapname.idl) + - any relevant .atom files + - collision tables if collision module(s) are enabled: + - mf_coltab.in, mf_ecoltab.in + - any collision files referenced by those files ^ + ^(footnote) if mhd.in is provided: + if snapname is not entered at __init__, use snapname from mhd.in. + if snapname.idl file does exist, copy mhd.in to a new file named snapname.idl. + + units_input: None, 'simu', 'si', or 'cgs' + units of value input using set_var. + only 'simu' system is implemented right now. + None --> use same mode as units_output. + units_input_fundamental: None, 'simu', 'si', or 'cgs' + units of value input using set_fundamental_var + None --> use same mode as units_input. + ''' + + def __init__(self, *args, verbose=False, units_input=None, units_input_fundamental=None, **kw): + '''initialize self using method from parent. But first: + - if there is no .idl file with appropriate snapname (from mhd.in or args[0]), + make one (by copying mhd.in) + + TODO: non-default options for units_input + ''' + # setup memory for fake data + self.setvars = collections.defaultdict(tools.GenericDict_with_equals(self._metadata_equals)) + self.nset = 0 # nset tracks the number of times set_var has been called. + + # units + self.units_input = units_input + self.units_input_fundamental = units_input_fundamental + + # make snapname.idl if necessary. + snapname = args[0] if len(args) > 0 else ebysus.get_snapname() + idlfilename = f'{snapname}.idl' + if not os.path.isfile(idlfilename): + shutil.copyfile('mhd.in', idlfilename) + if verbose: + print(f"copied 'mhd.in' to '{idlfilename}'") + # initialize self using method from parent. + super(FakeEbysusData, self).__init__(*args, verbose=verbose, **kw) + + @property + def units_input_fundamental(self): + '''units of value input using set_fundamental_var + None --> use same mode as units_input. + ''' + result = getattr(self, '_units_input_fundamental', None) + if result is None: + result = getattr(self, 'units_input', 'simu') + return result + + @units_input_fundamental.setter + def units_input_fundamental(self, value): + if value is not None: + value = value.lower() + units.ASSERT_UNIT_SYSTEM(value) + self._units_input_fundamental = value + + @property + def units_input(self): + '''units of value input using set_var. + only 'simu' system is implemented right now. + None --> use same mode as units_output. + ''' + result = getattr(self, '_units_input', 'simu') + if result is None: + result = getattr(self, 'units_output', 'simu') + return result + + @units_input.setter + def units_input(self, value): + if value is not None: + value = value.lower() + units.ASSERT_UNIT_SYSTEM(value) + if value != 'simu': + raise NotImplementedError(f'units_input = {repr(value)}') + self._units_input = value + + def _init_vars_get(self, *args__None, **kw__None): + '''do nothing and return None. + (overriding the initial data grabbing from EbysusData.) + ''' + + ## SET_VAR ## + def set_var(self, var, value, *args, nfluid=None, units=None, fundamental=None, + _skip_preprocess=False, fundamental_only=False, **kwargs): + '''set var in memory of self. + Use this to set the value for some fake data. + Any time we get var, we will check memory first; + if the value is in memory (with the appropriate metadata, e.g. ifluid,) + use the value from memory. Otherwise try to get it a different way. + + NOTE: set_var also clears self.cache (which otherwise has no way to know the data has changed). + + *args, and **kwargs go to self._get_var_preprocess. + E.g. using set_var(..., ifluid=(1,2)) will first set self.ifluid=(1,2). + + nfluid: None (default), 0, 1, or 2 + number of fluids which this var depends on. + None - try to guess, using documentation about the vars in self.vardict. + This option is good enough for most cases. + But fails for constructed vars which don't appear in vardict directly, e.g. 'b_mod'. + 0; 1; or 2 - depends on neither; just ifluid; or both ifluid and jfluid. + units: None, 'simu', 'si', or 'cgs' + units associated with value input to this function. + (Note that all values will be internally stored in the same units as helita would output, + given units_output='simu'. This usually means values are stored in simulation units.) + None --> use self.units_input. + else --> use the value of this kwarg. + fundamental: None (default), True, or False + None --> check first if var is in self.FUNDAMENTAL_SETTABLES. + if it is, use set_fundamental_var instead. + True --> immediately call set_fundamental_var instead. + False --> do not even consider using set_fundamental_var. + fundamental_only: True (default), or False + if calling set_fundamental_var... + True --> only set value of fundamental quantity corresponding to var. + False --> also set value of var. + Example of why it matters...: + If running the following lines: + (1) obj.set_var('tg', 4321) + (2) obj.set_var('e', obj('e') * 100) + (3) obj.get_var('tg') + with fundamental_only==True: + (1) sets 'e' to the appropriate value such that 'tg' will be 4321 + (2) adjusts the value of 'e' (only) + (3) gives the answer 432100 + with fundamental_only==False: + (1) sets 'tg' to 4321. (AND 'e' appropriately, if fundamental is True or None.) + (2) adjusts the value of 'e' (only) + (3) gives the answer 4321, because it reads the value of 'tg' directly, instead of checking 'e'. + + ''' + if fundamental is None: + if var in self.FUNDAMENTAL_SETTABLES: + fundamental = True + if fundamental: + return self.set_fundamental_var(var, value, *args, units=units, + fundamental_only=fundamental_only, **kwargs) + + self._warning_during_setvar_if_slicing_and_stagger() + + if not _skip_preprocess: + self._get_var_preprocess(var, *args, **kwargs) + + # bookkeeping - nfluid + if nfluid is None: + nfluid = self.get_var_nfluid(var) + if nfluid is None: # if still None, raise instructive error (instead of confusing crash later). + raise ValueError(f"nfluid=None for var='{var}'. Workaround: manually enter nfluid in set_var.") + # bookkeeping - units + units_input = units if units is not None else self.units_input + if units_input != 'simu': + raise NotImplementedError(f'set_var(..., units={repr(units_input)})') + + # save to memory. + meta = self._metadata(with_nfluid=nfluid) + self.setvars[var][meta] = value + + # do any updates that we should do whenever a var is set. + self._signal_set_var(var=var) + + def _signal_set_var(self, var=None): + '''update anything that needs to be updated whenever a var is set. + This code should probably be run any time self.setvars is altered in any way. + ''' + # bookkeeping - increment nset + self.nset += 1 + # clear the cache. + if hasattr(self, 'cache'): + self.cache.clear() + + # tell quant lookup to search vardict for var if metaquant == 'setvars' + VDSEARCH_IF_META = getattr(ebysus.EbysusData, 'VDSEARCH_IF_META', []) + ['setvars'] + + @tools.maintain_attrs('match_type', 'ifluid', 'jfluid') + @file_memory.with_caching(cache=False, check_cache=True, cache_with_nfluid=None) + @document_vars.quant_tracking_top_level + def _load_quantity(self, var, *args, **kwargs): + '''load quantity, but first check if the value is in memory with the appropriate metadata.''' + if var in self.setvars: + meta = self._metadata() + try: + result = self.setvars[var][meta] + document_vars.setattr_quant_selected(self, var, 'SET_VAR', metaquant='setvars', + varname=var, level='(FROM SETVARS)', delay=False) + return result + except KeyError: # var is in memory, but not with appropriate metadata. + pass # e.g. we know some 'nr', but not for the currently-set ifluid. + # else + return self._raw_load_quantity(var, *args, **kwargs) + + FUNDAMENTAL_SETTABLES = ('r', 'nr', 'e', 'tg', *(f'{v}{x}' for x in AXES for v in ('p', 'u', 'ui', 'b'))) + + def set_fundamental_var(self, var, value, *args, fundamental_only=True, units=None, **kwargs): + '''sets fundamental quantity corresponding to var; also sets var (unless fundamental_only). + fundamental quantities, and alternate options for vars will allow to set them, are: + r - nr + e - tg, p + p{x} - u{x}, ui{x} (for {x} in 'x', 'y', 'z') + b{x} – (no alternates.) + + fundamental_only: True (default) or False + True --> only set value of fundamental quantity corresponding to var. + False --> also set value of var. + units: None, 'simu', 'si', or 'cgs' + units associated with value input to this function. + (Note that all values will be internally stored in the same units as helita would output, + given units_output='simu'. This usually means values are stored in simulation units.) + None --> use self.units_input_fundamental. + else --> use the value of this kwarg. + + returns (name of fundamental var which was set, value to which it was set [in self.units_output units]) + ''' + assert var in self.FUNDAMENTAL_SETTABLES, f"I don't know how this var relates to a fundamental var: '{var}'" + self._warning_during_setvar_if_slicing_and_stagger() + + # preprocess + self._get_var_preprocess(var, *args, **kwargs) # sets snap, ifluid, etc. + also_set_var = (not fundamental_only) + # units + units_input = units if units is not None else self.units_input_fundamental + self.units_output + + def u_in2simu(key): + return self.uni(key, 'simu', units_input) + + def get_in_simu(vname): + with tools.MaintainingAttrs(self, 'units_output'): + self.units_output = 'simu' + return self(vname) + # # below, we calculate the value in 'simu' units, then set vars in 'simu' units as appropriate. + # set fundamental var. + # # 'r' - mass density + if var in ['r', 'nr']: + fundvar = 'r' # fundvar = 'the corresponding fundamental var' + if var == 'r': + ukey = 'r' + also_set_var = False + value_simu = value * u_in2simu(ukey) # value_simu = ' (input) in simu units' + fundval_simu = value_simu # fundval_simu = 'value of fundvar in simu units' + elif var == 'nr': # r = nr * m + ukey = 'nr' + value_simu = value * u_in2simu(ukey) + fundval_simu = value_simu * self.get_mass(units='simu') + # # 'e' - energy density + elif var in ['e', 'p', 'tg']: + fundvar = 'e' + if var == 'e': + ukey = 'e' + also_set_var = False + value_simu = value * u_in2simu(ukey) + fundval_simu = value_simu + elif var == 'p': # e = p / (gamma - 1) + ukey = 'e' # 'p' units are same as 'e' units. + value_simu = value * u_in2simu(ukey) + fundval_simu = value_simu / (self.uni.gamma - 1) + elif var == 'tg': # e = T / e_to_tg + ukey = None # T always in [K]. + value_simu = value + e_to_tg_simu = get_in_simu('e_to_tg') + fundval_simu = value_simu / e_to_tg_simu + # # 'p{x}' - momentum density ({x}-component) + elif var in tuple(f'{v}{x}' for x in AXES for v in ('p', 'u', 'ui')): + base, x = var[:-1], var[-1] + fundvar = f'p{x}' + if base == 'p': + ukey = 'pm' + also_set_var = False + value_simu = value * u_in2simu(ukey) + fundval_simu = value_simu + elif base in ['u', 'ui']: # px = ux * rxdn + ukey = 'u' + value_simu = value * u_in2simu(ukey) + r_simu = get_in_simu('r'+f'{x}dn') + fundval_simu = value_simu * r_simu + # # 'b{x}' - magnetic field ({x}-component) + elif var in tuple(f'b{x}' for x in AXES): + base, x = var[:-1], var[-1] + fundvar = f'b{x}' + if base == 'b': + ukey = 'b' + also_set_var = False + value_simu = value * u_in2simu(ukey) + fundval_simu = value_simu + else: + raise NotImplementedError(f'{var} in set_fundamental_var') + + # set fundamental var + self.set_var(fundvar, fundval_simu, *args, **kwargs, + units='simu', # we already handled the units; set_var shouldn't mess with them. + fundamental=False, # we already handled the 'fundamental' possibility. + _skip_preprocess=True, # we already handled preprocessing. + ) + # set var (the one that was entered to this function) + if also_set_var: + self.set_var(var, value_simu, *args, **kwargs, + units='simu', # we already handled the units; set_var shouldn't mess with them. + fundamental=False, # we already handled the 'fundamental' possibility. + _skip_preprocess=True, # we already handled preprocessing. + ) + u_simu2out = (1 if ukey is None else self.uni(ukey, self.units_output, 'simu')) + return (fundvar, fundval_simu * u_simu2out) + + def _warn_if_slicing_and_stagger(self, message): + '''if any slice is not slice(None), and do_stagger=True, warnings.warn(message)''' + if self.do_stagger and any(iiax != slice(None) for iiax in (self.iix, self.iiy, self.iiz)): + warnings.warn(message) + + def _warning_during_setvar_if_slicing_and_stagger(self): + self._warn_if_slicing_and_stagger(( + 'setting var with iix, iiy, or iiz != slice(None) and do_stagger=True' + ' may lead to unexpectedly not using values from setvars. \n\n(Internally,' + ' when do_stagger=True, slices are set to slice(None) while getting vars, and' + ' the original slices are only applied after completing all other calculations.)' + f'\n\nGot slices: iix={self.iix}, iiy={self.iiy}, iiz={self.iiz}' + )) + + ## UNSET_VAR ## + def unset_var(self, var): + '''unset the value of var if it has been set in setvars. + if var hasn't been set, do nothing. + returns whether var was previously set. + ''' + try: + del self.setvars[var] + except KeyError: + return False # var wasn't in self.setvars. + else: + self._signal_set_var(var=var) + return True + + FUNDAMENTAL_VARS = (*(f'b{x}' for x in AXES), 'r', *(f'p{x}' for x in AXES), 'e') + + def unset_extras(self): + '''unsets the values of all non-fundamental vars. + returns list of vars which were previously set but are no longer set. + ''' + desetted = [] + for var in iter(set(self.setvars.keys()) - set(self.FUNDAMENTAL_VARS)): + u = self.unset_var(var) + if u: + desetted.append(var) + return desetted + + unset_nonfundamentals = unset_non_fundamentals = keep_only_fundamentals = unset_extras # aliases + + ## ITER FUNDAMENTALS ## + def iter_fundamentals(self, b=True, r=True, p=True, e=True, AXES=AXES): + '''iterate through fundamental vars: + b (per axis) + r (per non-electron fluid) + p (per axis, per non-electron fluid) + e (per fluid (including electrons)) + during iteration through fluids, set self.ifluid appropriately. + + b, r, p, e: bool + whether to iterate through this fundamental var. + AXES: string or list of strings from ('x', 'y', 'z') + axes to use when iterating through fundamental vars. + + yields var name. (sets ifluid immediately before yielding each value.) + ''' + if b: + for x in AXES: + yield f'b{x}' + if r: + for fluid in self.fluid_SLs(with_electrons=False): + self.ifluid = fluid + yield 'r' + if p: + for fluid in self.fluid_SLs(with_electrons=False): + for x in AXES: + self.ifluid = fluid # in inner loop, to set immediately before yielding + yield f'p{x}' + if e: + for fluid in self.fluid_SLs(with_electrons=True): + self.ifluid = fluid + yield 'e' + + @tools.maintain_attrs('ifluid') + def set_fundamental_means(self): + '''sets all fundamental vars to their mean values. (Method provided for convenience.)''' + for var in self.iter_fundamentals(): + self.set_fundamental_var(var, np.mean(self(var))) + + @tools.maintain_attrs('ifluid') + def set_fundamental_full(self): + '''sets all fundamental vars to their fully-shaped values. + (If their shape is not self.shape, add self.zero()) + ''' + for var in self.iter_fundamentals(): + self.set_fundamental_var(var, self.reshape_if_necessary(self(var))) + + ## WRITE SNAPSHOT ## + def write_snap0(self, warning=True): + '''write data from self to snapshot 0. + if warning, first warn user that snapshot 0 will be overwritten, and request confirmation. + ''' + if not self._confirm_write('Snapshot 0', warning): + return # skip writing unless confirmed. + self.write_mfr(warning=False) + self.write_mfp(warning=False) + self.write_mfe(warning=False) + self.write_mf_common(warning=False) + + def _confirm_write(self, name, warning=True): + '''returns whether user truly wants to write name at self.file_root. + if warning==False, return True (i.e. "yes, overwrite") without asking user. + ''' + if warning: + confirm = input(f'Write {name} at {self.file_root}? (y/n)') + if confirm.lower() not in ('y', 'yes'): + print('Aborted. Nothing was written.') + return False + return True + + @tools.with_attrs(units_output='simu') + def write_mfr(self, warning=True): + '''write mass densities from self to snapshot 0.''' + if not self._confirm_write('Snapshot 0 mass densities', warning): + return # skip writing unless confirmed. + for ifluid in self.iter_fluid_SLs(with_electrons=False): + r_i = self.reshape_if_necessary(self('r', ifluid=ifluid)) + ebysus.write_mfr(self.root_name, r_i, ifluid=ifluid) + + @tools.with_attrs(units_output='simu') + def write_mfp(self, warning=True): + '''write momentum densitites from self to snapshot 0.''' + if not self._confirm_write('Snapshot 0 momentum densities', warning): + return # skip writing unless confirmed. + for ifluid in self.iter_fluid_SLs(with_electrons=False): + self.ifluid = ifluid + p_xyz_i = [self.reshape_if_necessary(self(f'p{x}')) for x in AXES] + ebysus.write_mfp(self.root_name, *p_xyz_i, ifluid=ifluid) + + @tools.with_attrs(units_output='simu') + def write_mfe(self, warning=True): + '''write energy densitites from self to snapshot 0. + Note: if there is only 1 non-electron fluid, this function does nothing and returns None + (because ebysus treats e as 'common' in single fluid case. See also: write_common()). + ''' + non_e_fluids = self.fluid_SLs(with_electrons=False) + if len(non_e_fluids) == 1: + return + if not self._confirm_write('Snapshot 0 energy densities', warning): + return # skip writing unless confirmed. + for ifluid in non_e_fluids: + e_i = self.reshape_if_necessary(self('e', ifluid=ifluid)) + ebysus.write_mfe(self.root_name, e_i, ifluid=ifluid) + e_e = self.reshape_if_necessary(self('e', ifluid=(-1, 0))) + ebysus.write_mf_e(self.root_name, e_e) + + @tools.with_attrs(units_output='simu') + def write_mf_common(self, warning=True): + '''write magnetic field from self to snapshot 0. (Also writes energy density if singlefluid.)''' + b_xyz = [self.reshape_if_necessary(self(f'b{x}')) for x in AXES] + non_e_fluids = self.fluid_SLs(with_electrons=False) + if len(non_e_fluids) == 1: + if not self._confirm_write('Snapshot 0 magnetic field and single fluid energy density', warning): + return # skip writing unless confirmed. + self.ifluid = non_e_fluids[0] + e_singlefluid = self.reshape_if_necessary(self('e')) + ebysus.write_mf_common(self.root_name, *b_xyz, e_singlefluid) + else: + if not self._confirm_write('Snapshot 0 magnetic field', warning): + return # skip writing unless confirmed. + ebysus.write_mf_common(self.root_name, *b_xyz) + + ## CONVENIENCE ## + def reshape_if_necessary(self, val): + '''returns val + self.zero() if shape(val) != self.shape, else val (unchanged)''' + if np.shape(val) != self.shape: + val = val + self.zero() + return val diff --git a/helita/sim/file_memory.py b/helita/sim/file_memory.py new file mode 100644 index 00000000..87047ad8 --- /dev/null +++ b/helita/sim/file_memory.py @@ -0,0 +1,724 @@ +""" +created by Sam Evans on Apr 12 2021 + +purpose: + + - don't re-read files multiple times. (see remember_and_recall()) + - limit number of open memmaps; avoid crash via "too many files open". (see manage_memmaps()) + +TODO: + try to manage_memmaps a bit more intelligently.. + current implementation will delete the oldest-created memmap first. + This leads to non-useful looping behavior esp. if using get_varTime. + whereas we could instead do something intelligent. Options include: + - dedicate the first N memmaps to the first N that we read. + - maintain separate list of just the memmap file names + count how many times we read each file; + keep in memory the memmaps for the files we are reading more often. + - some combination of the above ideas. + + allow for check_cache to propagate downwards throughout all calls to get_var. + E.g. right now get_var(x, check_cache=False) will not check cache for x, + however if it requires to get_var(y) it will still check cache for y. + +""" + +import os +import sys # for debugging 'too many files' crash; will be removed in the future +import time # for time profiling for caching +import weakref # for refering to parent in cache without making circular reference. +# import builtins +import resource +import warnings +import functools +from collections import OrderedDict, namedtuple + +# import local modules +from . import document_vars + +# import external public modules +try: + import numpy as np +except ImportError: + warnings.warn('failed to import numpy; some functions in helita.sim.file_memory may crash') + +# import internal modules +# from .fluid_tools import fluid_equals # can't import this here, due to dependency loop: + # bifrost imports file_memory + # fluid_tools imports at_tools + # at_tools import Bifrost_units from bifrost + +# set defaults +# apparently there is good efficiency improvement even if we only remember the last few memmaps. +# NMLIM_ATTR is the name of the attr which will tell us max number of memmaps to remember. +NMLIM_ATTR = 'N_memmap' +MEMORY_MEMMAP = '_memory_memmap' +MM_PERSNAP = 'mm_persnap' +# hard limit on number of open files = limit set by system; cannot be changed. +_, HARD = resource.getrlimit(resource.RLIMIT_NOFILE) +# soft limit on number of open files = limit observed by programs; can be changed, must always be less than HARD. +SOFT_INCREASE = 1.2 # amount to increase soft limit, when increasing. int -> add; float -> multiply. +MAX_SOFT = int(min(1e6, 0.1 * HARD)) # we will never set the soft limit to a value larger than this. +SOFT_WARNING = 8192 # if soft limit exceeds this value we will warn user every time we increase it. +SOFT_PER_OBJ = 0.1 # limit number of open memmaps in one object to SOFT_PER_OBJ * soft limit. + +HIDE_DECORATOR_TRACEBACKS = True # whether to hide decorators from this file when showing error traceback. + + +DEBUG_MEMORY_LEAK = False # whether to turn on debug messages to tell when Cache and/or EbysusData are deleted. +# There is currently a memory leak which seems unrelated to file_memory.py, +# because even with _force_disable_memory=True, the EbysusData objects +# are not actually being deleted when del is called. - SE June 10, 2021 +# E.g. dd = eb.EbysusData(...); del dd --> dd.__del__() is not being called. +# This could be caused by some attribute of dd containing a pointer to dd. +# Those pointers should be replaced by weakrefs; see e.g. Cache class in this file. + + +''' --------------------- remember_and_recall() --------------------- ''' + + +def remember_and_recall(MEMORYATTR, ORDERED=False, kw_mem=[]): + '''wrapper which returns function but with optional args obj, MEMORYATTR. + default obj=None, MEMORYATTR=MEMORYATTR. + if obj is None, behavior is unchanged; + else, remembers the values from reading files (by saving to the dict obj.MEMORYATTR), + and returns those values instead of rereading files. (This improves efficiency.) + + track modification timestamp for file in memory lookup dict. + this ensures if file is modified we will read the new file data. + ''' + def decorator(f): + @functools.wraps(f) + def f_but_remember_and_recall(filename, *args, obj=None, MEMORYATTR=MEMORYATTR, kw_mem=kw_mem, **kwargs): + '''if obj is None, simply does f(filename, *args, **kwargs). + Else, recall or remember result, as appropriate. + memory location is obj.MEMORYATTR[filename.lower()]. + kw_mem: list of strs (default []) + for key in kw_mem, associate key kwargs[key] with uniqueness of result. + ''' + __tracebackhide__ = HIDE_DECORATOR_TRACEBACKS + if getattr(obj, '_force_disable_memory', False): + obj = None + if obj is not None: + if not hasattr(obj, '_recalled'): + obj._recalled = dict() + if MEMORYATTR not in obj._recalled: + obj._recalled[MEMORYATTR] = 0 + if not hasattr(obj, MEMORYATTR): + setattr(obj, MEMORYATTR, dict()) + memory = getattr(obj, MEMORYATTR) + memory['data'] = dict() + memory['len'] = 0 + memory['recalled'] = 0 + if ORDERED: + memory['order'] = OrderedDict() + else: + memory = getattr(obj, MEMORYATTR) + memdata = memory['data'] + if os.path.exists(filename): + timestamp = os.stat(filename).st_mtime # timestamp of when file was last modified + else: + timestamp = '???' + # set filekey (key string for memory dict with filename keys) + filekey = filename.lower() # this would be enough, but we remove common prefix for readability. + if hasattr(obj, '_memory_filekey_fdir'): + _memory_filekey_fdir = obj._memory_filekey_fdir + elif hasattr(obj, 'fdir'): + _memory_filekey_fdir = os.path.abspath(obj.fdir).lower() + obj._memory_filekey_fdir = _memory_filekey_fdir + else: + _memory_filekey_fdir = os.path.abspath(os.sep) # 'root directory' (plays nicely with relpath, below) + filekey = os.path.relpath(filekey, _memory_filekey_fdir) + # determine whether we have this (filename, timestamp, kwargs) in memory already. + need_to_read = True + existing_mid = None + if filekey in memdata.keys(): + memfile = memdata[filekey] + for mid, memdict in memfile['memdicts'].items(): + # determine if the values of kwargs (which appear in kw_mem) match those in memdict. + kws_match = True + for key in kw_mem: + # if key (in kw_mem) appears in kwargs, it must appear in memdict and have matching value. + if key in kwargs.keys(): + if key not in memdict.keys(): + kws_match = False + break + elif kwargs[key] != memdict[key]: + kws_match = False + break + # if kwargs and timestamp match, we don't need to read; instead use value from memdict. + if kws_match and memdict['file_timestamp'] == timestamp: + need_to_read = False + break + # if we found a memdict matching (filename, timestamp, kwargs), + # we need to read if and only if memdict['value'] has been smashed. + if not need_to_read: + if 'value' not in memdict.keys(): + need_to_read = True + existing_mid = mid # mid is the unique index for this (timestamp, kwargs) + # combination, for this filekey. This allows to have + # a unique dict key which is (filekey, mid); and use + # (filekey, mid) to uniquely specify (file, timestamp, kwargs). + else: + memdata[filekey] = dict(memdicts=dict(), mid_next=1) # mid_next is the next available mid. + memfile = memdata[filekey] + # read file if necessary (and store result to memory) + if need_to_read: + result = f(filename, *args, **kwargs) # here is where we call f, if obj is not None. + if not existing_mid: + mid = memfile['mid_next'] + memfile['mid_next'] = mid + 1 + memdict = dict(value=result, # value is in memdict + file_timestamp=timestamp, mid=mid, recalled=0) # << metadata about file, kwargs, etc + memdict.update(kwargs) # << metadata about file, kwargs, etc + memfile['memdicts'][mid] = memdict # store memdict in memory. + else: + mid = existing_mid + memdict = memfile['memdicts'][mid] + memdict['value'] = result + memory['len'] += 1 # total number of 'value's stored in memory + if ORDERED: + memory['order'][(filekey, mid)] = None + # this is faster than a list due to the re-ordering of memory['order'] + # which occurs if we ever access the elements again. + # Really, a dict is unnecessary, we just need a "hashed" list, + # but we can abuse OrderedDict to get that. + else: + memory['recalled'] += 1 + memdict['recalled'] += 1 + obj._recalled[MEMORYATTR] += 1 + if ORDERED: + # move this memdict to end of order list; order is order of access. + memory['order'].move_to_end((filekey, memdict['mid'])) + # return value from memory + return memdict['value'] + else: # obj is None, so there is no memory, so we just call f and return the result. + return f(filename, *args, **kwargs) # here is where we call f, if obj is None. + return f_but_remember_and_recall + return decorator + + +''' --------------------- manage_memmaps() --------------------- ''' + + +def get_nfiles_soft_limit(): + soft, _ = resource.getrlimit(resource.RLIMIT_NOFILE) + return soft + + +def increase_soft_limit(soft=None, increase=SOFT_INCREASE): + '''increase soft by increase (int -> add; float -> multiply).''' + if soft is None: + soft = get_nfiles_soft_limit() + soft0 = soft + if isinstance(increase, int): + soft += increase + elif isinstance(increase, float): + soft = int(soft * increase) + else: + raise TypeError('invalid increase type! expected int or float but got {}'.format(type(increase))) + increase_str = ' limit on number of simultaneous open files from {} to {}'.format(soft0, soft) + if soft > MAX_SOFT: + raise ValueError('refusing to increase'+increase_str+' because this exceeds MAX_SOFT={}'.format(MAX_SOFT)) + if soft > SOFT_WARNING: + warnings.warn('increasing'+increase_str+'.') + resource.setrlimit(resource.RLIMIT_NOFILE, (soft, HARD)) + + +def manage_memmaps(MEMORYATTR, kw_mem=['dtype', 'order', 'offset', 'shape']): + '''decorator which manages number of memmaps. f should at most add one memmap to memory.''' + def decorator(f): + @functools.wraps(f) + def f_but_forget_memmaps_if_needed(*args, **kwargs): + '''forget one memmap if there are too many in MEMORYATTR. + determine what is "too many" via NMLIM_ATTR. + + Then return f(*args, **kwargs). + ''' + # check if we need to forget a memmap; forget one if needd. + __tracebackhide__ = HIDE_DECORATOR_TRACEBACKS + try: + obj = kwargs['obj'] + except KeyError: + obj = None + if getattr(obj, '_force_disable_memory', False): + obj = None + if obj is not None: + memory = getattr(obj, MEMORYATTR, None) + if memory is not None: + # memory is a dict of {filekey: memdictlist}; each memdict in memdictlist stores one memmap. + soft = get_nfiles_soft_limit() + forget_one = False + val = getattr(obj, NMLIM_ATTR, -1) + if val == -1: # we limit number of open memmaps based on limit for simultaneous open files. + if memory['len'] >= SOFT_PER_OBJ * soft: + try: + increase_soft_limit(soft) + except ValueError: # we are not allowed to increase soft limit any more. + warnings.warn('refusing to increase soft Nfile limit further than {}!'.format(soft)) + forget_one = True + elif val < 0: + raise ValueError('obj.'+NMLIM_ATTR+'must be -1 or 0 or >0 but got {}'.format(val)) + else: # we limit number of open memmaps based on NMLIM_ATTR. + if memory['len'] >= val: + forget_one = True + if forget_one: + # forget oldest memmap. + # ... TODO: possibly add a warning? It may be okay to be silent though. + filekey, mid = next(iter(memory['order'].keys())) + # commented lines for debugging 'too many files' crash; will be removed in the future: + # x = memory[filekey]['memdicts'][mid]['value'] # this is the memmap + #print('there are {} references to the map.'.format(sys.getrefcount(x))) + memdata = memory['data'] + memdict = memdata[filekey]['memdicts'][mid] + del memdict['value'] # this is the memmap + #print('there are {} references to the map (after deleting dict)'.format(sys.getrefcount(x))) + #print('referrers are: ', referrers(x)) + del memory['order'][(filekey, mid)] + memory['len'] -= 1 + # return f(*args, **kwargs) + return f(*args, kw_mem=kw_mem, **kwargs) + + return f_but_forget_memmaps_if_needed + return decorator + +# for debugging 'too many files' crash; will be removed in the future: + + +def namestr(obj, namespace): + return [name for name in namespace if namespace[name] is obj] + +# for debugging 'too many files' crash; will be removed in the future: + + +def referrers(obj): + return [namestr(refe, globals()) for refe in gc.get_referrers(obj)] + + +''' --------------------- cache --------------------- ''' + +CacheEntry = namedtuple('CacheEntry', ['value', 'metadata', 'id', 'nbytes', 'calctime', 'qtracking_state'], + defaults=[None, None, None, None, None, dict()]) +# value: value. +# metadata: additional params which are associated with this value of var. +# id: unique id associated to this var and cache_params for this cache. +# nbytes: number of bytes in value +# calctime: amount of time taken to calculate value. + + +def _fmt_SL(SL, sizing=2): + '''pretty formatting for species,level''' + if SL is None: + totlen = len('(') + sizing + len(', ') + sizing + len(')') + totlen = str(totlen) + fmtstr = '{:^'+totlen+'s}' # e.g. '{:8s}' + return fmtstr.format(str(None)) + else: + sizing = str(sizing) + fmtnum = '{:'+sizing+'d}' # e.g. '{:2d}' + fmtstr = '('+fmtnum+', '+fmtnum+')' + return fmtstr.format(SL[0], SL[1]) + + +def _new_cache_entry_str_(x): + '''new __str__ method for CacheEntry, which shows a much more readable format. + To get the original (namedtuple-style) representation of CacheEntry object x, use repr(x). + ''' + FMT_SNAP = '{:3d}' + FMT_DATA = '{: .3e}' + FMT_META = '{: .2e}' + snap = FMT_SNAP.format(x.metadata.get('snap', None)) + ifluid = _fmt_SL(x.metadata.get('ifluid', None)) + jfluid = _fmt_SL(x.metadata.get('jfluid', None)) + value = x.value + valmin = None if value is None else FMT_DATA.format(np.min(value)) + valmean = None if value is None else FMT_DATA.format(np.mean(value)) + valmax = None if value is None else FMT_DATA.format(np.max(value)) + nbytes = FMT_META.format(x.nbytes) + calctime = FMT_META.format(x.calctime) + result = ('CacheEntryView(snap={snap:}, ifluid={ifluid:}, jfluid={jfluid:}, ' + 'valmin={valmin:}, valmean={valmean:}, valmax={valmax:}, ' + 'nbytes={nbytes:}, calctime={calctime:})' + ) + result = result.format( + snap=snap, ifluid=ifluid, jfluid=jfluid, + valmin=valmin, valmean=valmean, valmax=valmax, + nbytes=nbytes, calctime=calctime) + return result + + +# actually overwrite the __str__ method for CacheEntry: +CacheEntry.__str__ = _new_cache_entry_str_ + + +class Cache(): + '''cache results of get_var. + can contain up to self.max_MB MB of data, and up to self.max_Narr entries. + Deletes oldest entries first when needing to free up space. + + self.performance tells total number of times arrays have been recalled, + and total amount of time saved (estimate based on time it took to read the first time.) + (Note the time saved is usually an overestimate unless you have N_memmap=0.) + + self.contents() shows a human-readable view of cache contents. + ''' + + def __init__(self, obj=None, max_MB=10, max_Narr=20): + '''initialize Cache. + + obj: None or object with _metadata() and _metadata_matches() methods. + Cache remembers this obj and uses these methods, if possible. + obj._metadata() must accept kwarg with_nfluid, and must return a dict. + obj._metadata_matches() must take a single dict as input, and must return a bool. + max_MB: 10 (default) or number + maximum number of MB of data which cache is allowed to store at once. + max_Narr: 20 (default) or number + maximum number of arrays which cache is allowed to store at once. + ''' + # set attrs which dictate max size of cache + self.max_MB = max_MB + self.max_Narr = max_Narr + # set parent, using weakref, to ensure we don't keep parent alive just because Cache points to it. + self.parent = (lambda: None) if (obj is None) else weakref.ref(obj) + # initialize self.performance, which will track the performance of Cache. + self.performance = dict(time_saved_estimate=0, N_recalled=0, N_recalled_unknown_time_savings=0) + # initialize attrs for internal use. + self._content = dict() + self._next_cacheid = 0 # unique id associated to each cache entry (increases by 1 each time) + self._order = [] # list of (var, id) + self._nbytes = 0 # number of bytes of data stored in self. + self.debugging = False # if true, print some helpful debugging statements. + + def get_parent_attr(self, attr, default=None): + '''return getattr(self.parent(), attr, default) + Caution: to ensure weakref is useful and no circular reference is created, + make sure to not save the result of get_parent_attr as an attribute of self. + ''' + return getattr(self.parent(), attr, None) + + def _metadata(self, *args__parent_metadata, **kw__parent_metadata): + '''returns self.parent()._metadata() if it exists; else None.''' + get_metadata_func = self.get_parent_attr('_metadata') + if get_metadata_func is not None: + return get_metadata_func(*args__parent_metadata, **kw__parent_metadata) + else: + return None + + def get_metadata(self, metadata=None, obj=None, with_nfluid=2): + '''returns metadata, given args. + + metadata: None or dict + if not None, return this value immediately. + obj: None or object with _metadata() method which returns dict + if not None, return obj._metadata(with_nfluid=with_nfluid) + with_nfluid: 2, 1, or 0 + if obj is not None, with_nfluid is passed to obj._metadata. + else, with_nfluid is passed to self.parent_get_metadata. + + This method's default behavior (i.e. behavior when called with no arguments) + is to return self.parent_get_metadata(with_nfluid=2). + ''' + if metadata is not None: + return metadata + if obj is not None: + return obj._metadata(with_nfluid=with_nfluid) + parent_metadata = self._metadata(with_nfluid=with_nfluid) + if parent_metadata is not None: + return parent_metadata + raise ValueError('Expected non-None metadata, obj, or self.parent_get_metadata, but all were None.') + + def _metadata_matches(self, cached_metadata, metadata=None, obj=None): + '''return whether metadata matches cached_metadata. + + if self has parent and self.parent() has _metadata_matches method: + return self.parent()._metadata_matches(cached_metadata) + else: + return _dict_equals(cached_metadata, self.get_metadata(metadata, obj)) + ''' + metadata_matches_func = self.get_parent_attr('_metadata_matches') + if metadata_matches_func is not None: + return metadata_matches_func(cached_metadata) + else: + return _dict_equals(cached_metadata, self.get_metadata(metadata=metadata, obj=obj)) + + def get(self, var, metadata=None, obj=None): + '''return entry associated with var and metadata in self, + if such an entry exists. Else, return empty CacheEntry. + + if Cache was initialized with obj, use obj._metadata() to + var: string + metadata: None (default) or dict + check that this agrees with cached metadata before returning result. + obj: None (default) or EbysusData object + if not None, use obj to determine params and fluids. + + if metadata and obj are None, tries to use metadata from self.parent(). + ''' + try: + var_cache_entries = self._content[var] + except KeyError: + if self.debugging >= 2: + print(' > Getting {:15s}; var not found in cache.'.format(var)) + return CacheEntry(None) # var is not in self. + # else (var is in self): + for entry in var_cache_entries: + if self._metadata_matches(entry.metadata, metadata=metadata, obj=obj): + # we found a match! So, return this entry (after doing some bookkeeping). + if self.debugging >= 1: + print(' -> Loaded {:^15s} -> {}'.format(var, entry)) + # update performance tracker. + self._update_performance_tracker(entry) + # update QUANT_SELECTED in self.parent() + parent = self.parent() + if parent is not None: + document_vars.restore_quant_tracking_state(parent, entry.qtracking_state) + return entry + # else (var is in self but not associated with this metadata): + if self.debugging >= 2: + print(' > Getting {:15s}, var in cache but not with this metadata.'.format(var)) + return CacheEntry(None) + + def cache(self, var, val, metadata=None, obj=None, with_nfluid=2, calctime=None, from_internal=False): + '''add var with value val (and associated with cache_params) to self.''' + if self.debugging >= 2: + print(' < Caching {:15s}; with_nfluid={}'.format(var, with_nfluid)) + val = np.array(val, copy=True, subok=True) # copy ensures value in cache isn't altered even if val array changes. + nbytes = val.nbytes + self._nbytes += nbytes + metadata = self.get_metadata(metadata=metadata, obj=obj, with_nfluid=with_nfluid) + quant_tracking_state = document_vars.get_quant_tracking_state(self.parent(), from_internal=from_internal) + entry = CacheEntry(value=val, metadata=metadata, + id=self._take_next_cacheid(), nbytes=nbytes, calctime=calctime, + qtracking_state=quant_tracking_state) + if self.debugging >= 1: + print(' <- Caching {:^15s} <- {}'.format(var, entry)) + if var in self._content.keys(): + self._content[var] += [entry] + else: + self._content[var] = [entry] + self._order += [(var, entry.id)] + self._shrink_cache_as_needed() + + def remove_one_entry(self, id=None): + '''removes the oldest entry in self. returns id of entry removed. + if id is not None, instead removes the entry with id==id. + ''' + if id is None: + oidx = 0 + var, eid = self._order[oidx] + else: + try: + oidx, (var, eid) = next(((i, x) for i, x in enumerate(self._order) if x[1] == id)) + except StopIteration: + raise KeyError('id={} not found in cache {}'.format(id, self)) + var_entries = self._content[var] + i = next((i for i, entry in enumerate(var_entries) if entry.id == eid)) + self._nbytes -= var_entries[i].nbytes + del var_entries[i] + del self._order[oidx] + return eid + + def clear(self): + '''remove all entries from self. + Returns (Original number of entries, Original number of bytes). + ''' + result = (len(self._order), self._nbytes) + while len(self._order) > 0: + self.remove_one_entry() + return result + + def __repr__(self): + '''pretty print of self''' + s = '<{self:} totaling {MB:0.3f} MB, containing {N:} cached values from {k:} vars: {vars:}>' + vars = list(self._content.keys()) + if len(vars) > 20: # then we will show only the first 20. + svars = '[' + ', '.join(vars[:20]) + ', ...]' + else: + svars = '[' + ', '.join(vars) + ']' + return s.format(self=object.__repr__(self), MB=self._nMB(), N=len(self._order), k=len(vars), vars=svars) + + def contents(self): + '''pretty display of contents (as CacheEntryView tuples). + To access the content data directly, use self._content. + ''' + result = dict() + for var, content in self._content.items(): + result[var] = [] + for entry in content: + result[var] += [str(entry)] + return result + + def _update_performance_tracker(self, entry): + '''update self.performance as if we just got entry from cache once.''' + self.performance['N_recalled'] += 1 + savedtime = entry.calctime + if savedtime is None: + self.performance['N_recalled_unknown_time_savings'] += 1 + else: + self.performance['time_saved_estimate'] += savedtime + + def _take_next_cacheid(self): + result = self._next_cacheid + self._next_cacheid += 1 + return result + + def _max_nbytes(self): + return self.max_MB * 1024 * 1024 + + def _nMB(self): + return self._nbytes / (1024 * 1024) + + def _shrink_cache_as_needed(self): + '''shrink cache to stay within limits of number of entries and amount of data.''' + while len(self._order) > self.max_Narr: + self.remove_one_entry() + max_nbytes = self._max_nbytes() + while self._nbytes > max_nbytes: + self.remove_one_entry() + + def is_NoneCache(self): + '''return if self.max_MB <= 0 or self.max_Narr <= 0''' + return (self.max_MB <= 0 or self.max_Narr <= 0) + + if DEBUG_MEMORY_LEAK: + def __del__(self): + print('deleted {}'.format(self)) + + +def with_caching(check_cache=True, cache=False, cache_with_nfluid=None): + '''decorate function so that it does caching things. + + cache, check_cache, and nfluid values passed to with_caching() + will become the _default_ values of these kwargs for the function + which is being decorated (but other values for these kwargs + can still be passed to that function, later). + + cache: whether to store result in obj.cache + check_cache: whether to try to get result from obj.cache if it exists there. + cache_with_nfluid - None (default), 0, 1, or 2 + if not None, cache result and associate it with this many fluids. + 0 -> neither; 1 -> just ifluid; 2 -> both ifluid and jfluid. + ''' + def decorator(f): + @functools.wraps(f) + def f_but_caching(obj, var, *args_f, + check_cache=check_cache, cache=cache, cache_with_nfluid=cache_with_nfluid, + **kwargs_f): + '''do f(obj, *args_f, **kwargs_f) but do caching things as appropriate, + i.e. check cache first (if check_cache) and store result (if cache). + ''' + __tracebackhide__ = HIDE_DECORATOR_TRACEBACKS + if getattr(obj, '_force_disable_memory', False): + cache = check_cache = False + + val = None + if (not getattr(obj, 'do_caching', True)) or (not hasattr(obj, 'cache')) or (obj.cache.is_NoneCache()): + cache = check_cache = False + elif cache_with_nfluid is not None: + cache = True + if (cache or check_cache): + track_timing = True + # check cache for result (if check_cache==True) + if check_cache: + entry = obj.cache.get(var) + val = entry.value + if cache and (val is not None): + # remove entry from cache to prevent duplicates (because we will re-add entry soon) + obj.cache.remove_one_entry(id=entry.id) + # use timing from original entry + track_timing = False + calctime = entry.calctime + if cache and track_timing: + now = time.time() # track timing, so we can estimate how much time cache is saving. + # calculate result (if necessary) + if val is None: + val = f(obj, var, *args_f, **kwargs_f) + # save result to obj.cache (if cache==True) + if cache: + if track_timing: + calctime = time.time() - now + obj.cache.cache(var, val, with_nfluid=cache_with_nfluid, calctime=calctime) + # return result + return val + return f_but_caching + return decorator + + +class Caching(): + '''context manager which lets you do caching by setting self.result to a value.''' + + def __init__(self, obj, nfluid=None): + self.obj = obj + self.nfluid = nfluid + + def __enter__(self): + self.caching = (getattr(self.obj, 'do_caching', True)) \ + and (hasattr(self.obj, 'cache')) \ + and (not self.obj.cache.is_NoneCache()) + if self.caching: + self.start = time.time() + self.metadata = self.obj._metadata(with_nfluid=self.nfluid) + + def _cache_it(var, value, restart_timer=True): + '''save this result to cache.''' + if not self.caching: + return + # else + calctime = time.time() - self.start + self.obj.cache.cache(var, value, metadata=self.metadata, calctime=calctime, from_internal=True) + if restart_timer: + self.start = time.time() # restart the clock. + + return _cache_it + + def __exit__(self, exc_type, exc_value, traceback): + pass + + +def _dict_matches(A, B, subset_ok=True, ignore_keys=[]): + '''returns whether A matches B for dicts A, B. + + A "matches" B if for all keys in A, A[key] == B[key]. + + subset_ok: True (default) or False + if False, additionally "A matches B" requires A.keys() == B.keys() + ignore_keys: list (default []) + these keys are never checked; A[key] need not equal B[key] for key in ignore_keys. + + This function is especially useful when checking dicts which may contain numpy arrays, + because numpy arrays override __equals__ to return an array instead of True or False. + ''' + ignore_keys = set(ignore_keys) + keysA = set(A.keys()) - ignore_keys + keysB = set(B.keys()) - ignore_keys + if not subset_ok: + if not keysA == keysB: + return False + for key in keysA: + eq = (A[key] == B[key]) + if isinstance(eq, np.ndarray): + if not np.all(eq): + return False + elif eq == False: + return False + elif eq == True: + pass # continue on to next key. + else: + raise ValueError("Object equality was not boolean nor np.ndarray. Don't know what to do. " + + "Objects = {:}, {:}; (x == y) = {:}; type((x==y)) = {:}".format( + A[key], B[key], eq, type(eq))) + return True + + +def _dict_equals(A, B, ignore_keys=[]): + '''returns whether A==B for dicts A, B. + Even works if some contents are numpy arrays. + ''' + return _dict_matches(A, B, subset_ok=False, ignore_keys=ignore_keys) + + +def _dict_is_subset(A, B, ignore_keys=[]): + '''returns whether A is a subset of B, i.e. whether for all keys in A, A[key]==B[key]. + Even works if some contents are numpy arrays. + ''' + return _dict_matches(A, B, subset_ok=True, ignore_keys=ignore_keys) diff --git a/helita/sim/fluid_tools.py b/helita/sim/fluid_tools.py new file mode 100644 index 00000000..b8786673 --- /dev/null +++ b/helita/sim/fluid_tools.py @@ -0,0 +1,736 @@ +""" +created by Sam Evans on Apr 19 2021 + +purpose: + - tools for fluids in ebysus.py +""" + +import warnings +# import built-ins +import functools +import itertools + +# import internal modules +from . import tools +from .load_mf_quantities import MATCH_AUX, MATCH_PHYSICS + +# import external private modules +try: + from atom_py.at_tools import fluids as fl +except ImportError: + fl = tools.ImportFailed('at_tools.fluids') + +# set defaults +HIDE_DECORATOR_TRACEBACKS = True # whether to hide decorators from this file when showing error traceback. + +# list of functions from fluid_tools which will be set as methods of the Multifluid class. +# for example, EbysusData inherits from Multifluid, so if Multifluid gets get_mass, then: +# for dd=EbysusData(...), dd.get_mass(*args, **kw) == fluid_tools.get_mass(dd, *args, **kw). +MULTIFLUID_FUNCS = \ + ['set_mf_fluid', 'set_mfi', 'set_mfj', 'set_fluids', + 'get_species_name', 'get_fluid_name', 'get_mass', 'get_charge', + 'get_cross_tab', 'get_cross_sect', 'get_coll_type', + 'i_j_same_fluid', 'fluid_SLs', 'fluid_SLs_and_names', + 'iter_fluid_SLs', 'iter_fluid_SLs_and_names'] + +''' --------------------- setting fluids --------------------- ''' + +# NOTE: these functions are largely obsolete, now. +# Thanks to the "magic" of property(), doing something like obj.ifluid=(1,2) +# will effectively set mf_ispecies and mf_ilevel appropriately. +# And, reading something like obj.ifluid will give values (obj.mf_ispecies, obj.mf_ilevel) +# However, we cannot delete these functions, for historical reasons. +# And, maybe they are still useful thanks to the kwarg interpretation in set_fluids. + + +def set_mf_fluid(obj, species=None, level=None, i='i'): + '''sets obj.mf_{i}species and obj.mf_{i}level. + species, level: None or int + None -> don't change obj.mf_{i}species, mf_{i}level. + ints -> set mf_{i}species=species, mf_{i}level=level. + ''' + setattr(obj, 'mf_'+i+'species', species) + setattr(obj, 'mf_'+i+'level', level) + + +def set_mfi(obj, mf_ispecies=None, mf_ilevel=None): + return obj.set_mf_fluid(mf_ispecies, mf_ilevel, 'i') + + +set_mfi.__doc__ = set_mf_fluid.__doc__.format(i='i') + + +def set_mfj(obj, mf_jspecies=None, mf_jlevel=None): + return obj.set_mf_fluid(mf_jspecies, mf_jlevel, 'j') + + +set_mfj.__doc__ = set_mf_fluid.__doc__.format(i='j') + + +def set_fluids(obj, **kw__fluids): + '''interprets kw__fluids then sets them using set_mfi and set_mfj. + returns (ifluid, jfluid). + ''' + (si, li, sj, lj) = _interpret_kw_fluids(**kw__fluids) + obj.set_mfi(si, li) + obj.set_mfj(sj, lj) + return (obj.ifluid, obj.jfluid) + + +''' --------------------- fluid kwargs --------------------- ''' + + +def _interpret_kw_fluids(mf_ispecies=None, mf_ilevel=None, mf_jspecies=None, mf_jlevel=None, + ifluid=None, jfluid=None, iSL=None, jSL=None, + iS=None, iL=None, jS=None, jL=None, + **kw__None): + '''interpret kwargs entered for fluids. Returns (mf_ispecies, mf_ilevel, mf_jspecies, mf_jlevel). + kwargs are meant to be shorthand notation. If conflicting kwargs are entered, raise ValueError. + **kw__None are ignored; it is part of the function def'n so that it will not break if extra kwargs are entered. + Meanings for non-None kwargs (similar for j, only writing for i here): + mf_ispecies, mf_ilevel = ifluid + mf_ispecies, mf_ilevel = iSL + mf_ispecies, mf_ilevel = iS, iL + Examples: + These all return (1,2,3,4) (they are equivalent): + _interpret_kw_fluids(mf_ispecies=1, mf_ilevel=2, mf_jspecies=3, mf_jlevel=4) + _interpret_kw_fluids(ifluid=(1,2), jfluid=(3,4)) + _interpret_kw_fluids(iSL=(1,2), jSL=(3,4)) + _interpret_kw_fluids(iS=1, iL=2, jS=3, jL=4) + Un-entered fluids will be returned as None: + _interpret_kw_fluids(ifluid=(1,2)) + >> (1,2,None,None) + Conflicting non-None kwargs will cause ValueError: + _interpret_kw_fluids(mf_ispecies=3, ifluid=(1,2)) + >> ValueError('mf_ispecies (==3) was incompatible with ifluid[0] (==1)') + _interpret_kw_fluids(mf_ispecies=1, ifluid=(1,2)) + >> (1,2,None,None) + ''' + si, li = _interpret_kw_fluid(mf_ispecies, mf_ilevel, ifluid, iSL, iS, iL, i='i') + sj, lj = _interpret_kw_fluid(mf_jspecies, mf_jlevel, jfluid, jSL, jS, jL, i='j') + return (si, li, sj, lj) + + +def _interpret_kw_ifluid(mf_ispecies=None, mf_ilevel=None, ifluid=None, iSL=None, iS=None, iL=None, None_ok=True): + '''interpret kwargs entered for ifluid. See _interpret_kw_fluids for more documentation.''' + return _interpret_kw_fluid(mf_ispecies, mf_ilevel, ifluid, iSL, iS, iL, None_ok=None_ok, i='i') + + +def _interpret_kw_jfluid(mf_jspecies=None, mf_jlevel=None, jfluid=None, jSL=None, jS=None, jL=None, None_ok=True): + '''interpret kwargs entered for jfluid. See _interpret_kw_fluids for more documentation.''' + return _interpret_kw_fluid(mf_jspecies, mf_jlevel, jfluid, jSL, jS, jL, None_ok=None_ok, i='j') + + +def _interpret_kw_fluid(mf_species=None, mf_level=None, fluid=None, SL=None, S=None, L=None, i='', None_ok=True): + '''interpret kwargs entered for fluid. Returns (mf_ispecies, mf_ilevel). + See _interpret_kw_fluids for more documentation. + i : 'i', or 'j'; Used to make clearer error messages, if entered. + None_ok: True (default) or False; + whether to allow answer of None or species and/or level. + if False and species and/or level is None, raise TypeError. + ''' + s, l = None, None + kws, kwl = '', '' + errmsg = 'Two incompatible fluid kwargs entered! {oldkw:} and {newkw:} must be equal ' + \ + '(unless one is None), but got {oldkw:}={oldval:} and {newkw:}={newval:}' + + def set_sl(news, newl, newkws, newkwl, olds, oldl, oldkws, oldkwl, i): + newkws, newkwl = newkws.format(i), newkwl.format(i) + if (olds is not None): + if (news is not None): + if (news != olds): + raise ValueError(errmsg.format(newkw=newkws, newval=news, oldkw=oldkws, oldval=olds)) + else: + news = olds + if (oldl is not None): + if (newl is not None): + if (newl != oldl): + raise ValueError(errmsg.format(newkw=newkwl, newval=newl, oldkw=oldkwl, oldval=oldl)) + else: + newl = oldl + return news, newl, newkws, newkwl + + if fluid is None: + fluid = (None, None) + if SL is None: + SL = (None, None) + s, l, kws, kwl = set_sl(mf_species, mf_level, 'mf_{:}species', 'mf_{:}level', s, l, kws, kwl, i) + s, l, kws, kwl = set_sl(fluid[0], fluid[1], '{:}fluid[0]', '{:}fluid[1]', s, l, kws, kwl, i) + s, l, kws, kwl = set_sl(SL[0], SL[1], '{:}SL[0]', '{:}SL[1]', s, l, kws, kwl, i) + s, l, kws, kwl = set_sl(S, L, '{:}S', '{:}L', s, l, kws, kwl, i) + if not None_ok: + if s is None or l is None: + raise TypeError('{0:}species and {0:}level cannot be None, but got: '.format(i) + + 'mf_{0:}species={1:}; mf_{0:}level={2:}.'.format(i, s, l)) + return s, l + + +''' --------------------- fluid SL context managers --------------------- ''' + + +class _MaintainingFluids(): + '''context manager which restores ifluid and jfluid to original values, upon exit. + + Example: + dd = EbysusData(...) + dd.set_mfi(2,3) + print(dd.ifluid) #>> (2,3) + with _MaintainingFluids(dd): + print(dd.ifluid) #>> (2,3) + dd.set_mfi(4,5) + print(dd.ifluid) #>> (4,5) + print(dd.ifluid) #>> (2,3) + ''' + + def __init__(self, obj): + self.obj = obj + self.orig_ifluid = obj.ifluid + self.orig_jfluid = obj.jfluid + + def __enter__(self): + pass + + def __exit__(self, exc_type, exc_value, traceback): + self.obj.set_mfi(*self.orig_ifluid) + self.obj.set_mfj(*self.orig_jfluid) + + +_MaintainFluids = _MaintainingFluids # alias + + +class _UsingFluids(_MaintainingFluids): + '''context manager for using fluids, but ending up with the same ifluid & jfluid at the end. + upon enter, set fluids, based on kw__fluids. + upon exit, restore original fluids. + + Example: + dd = EbysusData(...) + dd.set_mfi(1,1) + print(dd.ifluid) #>> (1,1) + with _UsingFluids(dd, ifluid=(2,3)): + print(dd.ifluid) #>> (2,3) + dd.set_mfi(4,5) + print(dd.ifluid) #>> (4,5) + print(dd.ifluid) #>> (1,1) + ''' + + def __init__(self, obj, **kw__fluids): + _MaintainingFluids.__init__(self, obj) + (si, li, sj, lj) = _interpret_kw_fluids(**kw__fluids) + self.ifluid = (si, li) + self.jfluid = (sj, lj) + + def __enter__(self): + self.obj.set_mfi(*self.ifluid) + self.obj.set_mfj(*self.jfluid) + + # __exit__ is inheritted from MaintainingFluids. + + +_UseFluids = _UsingFluids # alias + + +def maintain_fluids(f): + '''decorator version of _MaintainFluids. first arg of f must be an EbysusData object.''' + return tools.maintain_attrs('ifluid', 'jfluid')(f) + + +def use_fluids(**kw__fluids): + '''returns decorator version of _UseFluids. first arg of f must be an EbysusData object.''' + def decorator(f): + @functools.wraps(f) + def f_but_use_fluids(obj, *args, **kwargs): + __tracebackhide__ = HIDE_DECORATOR_TRACEBACKS + with _UsingFluids(obj, **kw__fluids): + return f(obj, *args, **kwargs) + return f_but_use_fluids + return decorator + + +''' --------------------- iterators over fluids --------------------- ''' + + +def fluid_pairs(fluids, ordered=False, allow_same=False): + '''returns an iterator over fluids of obj. + + ordered: False (default) or True + False -> (A,B) and (B,A) will be yielded separately. + True -> (A,B) will be yielded, but (B,A) will not. + allow_same: False (default) or True + False -> (A,A) will never be yielded. + True -> (A,A) will be yielded. + + This function just returns a combinatoric iterators from itertools. + defaults lead to calling itertools.permutations(fluids) + + Example: + for (ifluid, jfluid) in fluid_pairs([(1,2),(3,4),(5,6)], ordered=True, allow_same=False): + print(ifluid, jfluid, end=' | ') + # >> (1, 2) (3, 4) | (1, 2) (5, 6) | (3, 4) (5, 6) | + ''' + if ordered and allow_same: + return itertools.combinations_with_replacement(fluids, 2) + elif ordered and not allow_same: + return itertools.combinations(fluids, 2) + elif not ordered and not allow_same: + return itertools.permutations(fluids, 2) + elif not ordered and allow_same: + return itertools.product(fluids, repeat=2) + assert False # we should never reach this line... + + +def iter_fluid_SLs(dd, with_electrons=True, iset=False): + '''returns an iterator over the fluids of dd, and electrons. + yields SL pairs; NOT at_tools.fluids.Fluid objects! + example: list(iter_fluids(dd)) = [(-1,0), (1,1), (1,2)]. + + if iset (default False, for now), + also sets dd.ifluid to the SL as we iterate. + + with_electrons: bool + True --> electrons are included, first: (-1,0) + False --> electrons are not included. + ''' + if iset: + for SL in iter_fluid_SLs(dd, with_electrons, iset=False): + dd.ifluid = SL + yield SL + else: + if with_electrons: + yield (-1, 0) + for fluid in dd.fluids: + yield fluid.SL + + +def fluid_SLs(dd, with_electrons=True): + '''returns list of (species, level) pairs for fluids in dd. + See also: iter_fluid_SLs + + with_electrons: bool + True --> electrons are included, first: (-1,0) + False --> electrons are not included. + ''' + return list(iter_fluid_SLs(dd, with_electrons=with_electrons)) + + +def iter_fluid_SLs_and_names(dd, with_electrons=True, iset=False): + '''returns and iterator over the fluids of dd, and electrons. + yields ((species, level), name) + + if iset (default False, for now), + also sets dd.ifluid to the SL as we iterate. + + with_electrons: bool + True --> electrons are included, first: ((-1,0),'e-'). + False --> electrons are not included. + ''' + for SL in dd.iter_fluid_SLs(with_electrons=with_electrons, iset=iset): + yield (SL, dd.get_fluid_name(SL)) + + +def fluid_SLs_and_names(dd, with_electrons=True): + '''returns list of ((species, level), name) for fluids in dd. + See also: iter_fluid_SLs_and_names + + with_electrons: bool + True --> electrons are included, first: ((-1,0),'e-'). + False --> electrons are not included. + ''' + return list(iter_fluid_SLs_and_names(dd, with_electrons=with_electrons)) + + +''' --------------------- compare fluids --------------------- ''' + + +def i_j_same_fluid(obj): + '''returns whether obj.ifluid and obj.jfluid represent the same fluid.''' + return fluid_equals(obj.ifluid, obj.jfluid) + + +def fluid_equals(iSL, jSL): + '''returns whether iSL and jSL represent the same fluid.''' + if iSL[0] < 0 and jSL[0] < 0: + return True + else: + return (iSL == jSL) + + +''' --------------------- small helper functions --------------------- ''' +# for each of these functions, obj should be an EbysusData object. + + +def get_species_name(obj, specie=None): + '''return specie's name: 'e' for electrons; element (atomic symbol) for other fluids. + if specie is None, use obj.mf_ispecies. + ''' + if specie is None: + obj.iS + if specie < 0: + return 'e' + else: + return obj.att[specie].params.element + + +def get_fluid_name(obj, fluid=None): + '''return fluid's name: 'e-' for electrons; element & ionization for other fluids (e.g. 'H II'). + fluid can be at_tools.fluids.Fluid object, (species, level) pair, or -1 (for electrons). + if fluid is None, use obj.ifluid. + ''' + if fluid is None: + fluid = obj.ifluid + try: + return fluid.name + except AttributeError: + try: + specie = fluid[0] + except TypeError: + specie = fluid + if not (specie < 0): + errmsg_badfluid = ('Expected at_tools.fluids.Fluid object or (species, level) for fluid, ' + 'but got fluid = {}'.format(fluid)) + raise TypeError(errmsg_badfluid) + if specie < 0: + return 'e-' + else: + return obj.fluids[fluid].name + + +def get_mass(obj, specie=None, units='amu'): + '''return specie's mass [units]. default units is amu. + units: one of: ['amu', 'g', 'kg', 'cgs', 'si', 'simu']. Default 'amu' + 'amu' -> mass in amu. For these units, mH ~= 1 + 'g' or 'cgs' -> mass in grams. For these units, mH ~= 1.66E-24 + 'kg' or 'si' -> mass in kg. For these units, mH ~= 1.66E-27 + 'simu' -> mass in simulation units. + if specie is None, use specie = obj.mf_ispecies + ''' + if specie is None: + specie = obj.iS + # if specie is actually (spec, level) return get_mass(obj, spec) instead. + try: + specie = next(iter(specie)) + except TypeError: + pass + else: + return get_mass(obj, specie, units=units) + units = units.lower() + VALID_UNITS = ['amu', 'g', 'kg', 'cgs', 'si', 'simu'] + assert units in VALID_UNITS, "Units invalid; got units={}".format(units) + if specie < 0: + # electron + if units == 'amu': + return obj.uni.m_electron / obj.uni.amu + elif units in ['g', 'cgs']: + return obj.uni.m_electron + elif units in ['kg', 'si']: + return obj.uni.msi_e + else: # units == 'simu' + return obj.uni.simu_m_e + else: + # not electron + m_amu = obj.att[specie].params.atomic_weight + if units == 'amu': + return m_amu + elif units in ['g', 'cgs']: + return m_amu * obj.uni.amu + elif units in ['kg', 'si']: + return m_amu * obj.uni.amusi + else: # units == 'simu' + return m_amu * obj.uni.simu_amu + + +def get_charge(obj, SL=None, units='e'): + '''return the charge fluid SL in [units]. default is elementary charge units. + units: one of ['e', 'elementary', 'esu', 'c', 'cgs', 'si', 'simu']. Default 'elementary'. + 'e' or 'elementary' -> charge in elementary charge units. For these units, qH+ ~= 1. + 'c' or 'si' -> charge in SI units (Coulombs). For these units, qH+ ~= 1.6E-19 + 'esu' or 'cgs' -> charge in cgs units (esu). For these units, qH+ ~= 4.8E-10 + 'simu' -> charge in simulation units. + if SL is None, use SL = obj.ifluid. + ''' + if SL is None: + SL = obj.iSL + units = units.lower() + VALID_UNITS = ['e', 'elementary', 'esu', 'c', 'cgs', 'si', 'simu'] + assert units in VALID_UNITS, "Units invalid; got units={}".format(units) + # get charge, in 'elementary charge' units: + if (SL == -1) or (SL[0] < 0): + # electron + charge = -1. + else: + # not electron + charge = obj.fluids[SL].ionization + # convert to proper units and return: + if units in ['e', 'elementary']: + return charge + elif units in ['esu', 'cgs']: + return charge * obj.uni.q_electron + elif units in ['c', 'si']: + return charge * obj.uni.qsi_electron + else: # units=='simu' + return charge * obj.uni.simu_qsi_e + + +def get_cross_tab(obj, iSL=None, jSL=None, **kw__fluids): + '''return (filename of) cross section table for obj.ifluid, obj.jfluid. + use S=-1 for electrons. (e.g. iSL=(-1,1) represents electrons.) + either ifluid or jfluid must be neutral. (charged -> Coulomb collisions.) + iSL, jSL, kw__fluids behavior is the same as in get_var. + ''' + iSL, jSL = obj.set_fluids(iSL=iSL, jSL=jSL, **kw__fluids) + if iSL == jSL: + warnings.warn('Tried to get cross_tab when ifluid==jfluid. (Both equal {})'.format(iSL)) + icharge, jcharge = (get_charge(obj, SL) for SL in (iSL, jSL)) + assert icharge == 0 or jcharge == 0, "cannot get cross_tab for charge-charge interaction." + # force ispecies to be neutral (swap i & j if necessary; cross tab is symmetric). + if icharge != 0: + return get_cross_tab(obj, jSL, iSL) + # now, ispecies is the neutral one. + # now we will actually get the filename. + CTK = 'CROSS_SECTIONS_TABLES' + if (jSL == -1) or (jSL[0] < 0): + # electrons + cross_tab_table = obj.mf_etabparam[CTK] + for row in cross_tab_table: + # example row looks like: ['01', 'e-h-bruno-fits.txt'] + # contents are: [mf_species, filename] + if int(row[0]) == iSL[0]: + return row[1] + else: + # not electrons + cross_tab_table = obj.mf_tabparam[CTK] + for row in cross_tab_table: + # example row looks like: ['01', '02', '01', 'he-h-bruno-fits.txt'] + # contents are: [mf_ispecies, mf_jspecies, mf_jlevel, filename] + if int(row[0]) == iSL[0]: + if int(row[1]) == jSL[0]: + if int(row[2]) == jSL[1]: + return row[3] + # if we reach this line, we couldn't find cross section file, so make the code crash. + errmsg = "Couldn't find cross section file for ifluid={}, jfluid={}. ".format(iSL, jSL) + \ + "(We looked in obj.mf_{}tabparam['{}'].)".format(('e' if jSL[0] < 0 else ''), CTK) + raise ValueError(errmsg) + + +def get_cross_sect(obj, **kw__fluids): + '''returns Cross_sect object containing cross section data for obj.ifluid & obj.jfluid. + equivalent to obj.cross_sect(cross_tab=[get_cross_tab(obj, **kw__fluids)]) + + common use-case: + obj.get_cross_sect().tab_interp(tg_array) + ''' + return obj.cross_sect([obj.get_cross_tab(**kw__fluids)]) + + +def get_coll_type(obj, iSL=None, jSL=None, **kw__fluids): + '''return type of collisions between obj.ifluid, obj.jfluid. + use S=-1 for electrons. (e.g. iSL=(-1,1) represents electrons.) + iSL, jSL, kw__fluids behavior is the same as in get_var. + + if ifluid or jfluid is electrons: + if both are charged: return ('EE', 'CL') + if one is neutral: return ('EE', 'EL') + + if obj.match_type != MATCH_AUX: + enforce that coll_keys=obj.coll_keys[(iSL, jSL)] MUST exist. + an appropriate key ('CL' if both charged; otherwise 'EL' or 'MX') MUST appear in coll_keys. + return the appropriate key if possible, else raise KeyError. + otherwise: (obj.match_type == MATCH_AUX) + return the appropriate key if possible, else return None + + note that 'EL' and 'MX' are mutually exclusive; raise ValueError if both are found in coll_keys. + + keys mean: + 'EL' <--> elastic collisions, + 'MX' <--> maxwell collisions, + 'CL' <--> coulomb collisions. + + returns one of: None, 'EL', 'MX', 'CL', ('EE', 'CL'), or ('EE', 'EL'). + ''' + iSL, jSL = obj.set_fluids(iSL=iSL, jSL=jSL, **kw__fluids) + icharge = obj.get_charge(iSL) + jcharge = obj.get_charge(jSL) + if icharge < 0 or jcharge < 0: + implied_coll_key = 'CL' if (icharge != 0 and jcharge != 0) else 'EL' + return ('EE', implied_coll_key) + matching_aux = (getattr(obj, 'match_type', MATCH_PHYSICS) == MATCH_AUX) + key_errmsg_start = '' # no errmsg. + try: + coll_keys = obj.coll_keys[(iSL[0], jSL[0])] # note: obj.coll_keys only knows about species. + except KeyError: + key_errmsg_start = 'no coll_keys found for (iS, jS) = ({}, {})!'.format(iSL[0], jSL[0]) + if (len(key_errmsg_start)==0): + if (icharge != 0) and (jcharge != 0): + if 'CL' in coll_keys: + return 'CL' + elif matching_aux: # and 'CL' not in coll_keys + return None + else: + key_errmsg_start = "did not find coll key 'CL' for collisions between charged fluids!" + else: + is_EL = 'EL' in coll_keys + is_MX = 'MX' in coll_keys + if is_EL and is_MX: + errmsg = 'got EL and MX in coll_keys for ifluid={}, jfluid={}.' +\ + 'But EL and MX are mutually exclusive. Crashing...' + raise ValueError(errmsg.format(ifluid=iSL, jfluid=jSL)) + elif is_EL: + return 'EL' + elif is_MX: + return 'MX' + elif matching_aux: # and 'EL' not in coll_keys and 'MX' not in coll_keys + return None + else: + key_errmsg_start = "did not find either coll key 'EL' or 'MX' for collisions involving >=1 neutral fluid." + assert (len(key_errmsg_start) != 0), "code above should handle all cases where there is no error..." + key_errmsg = key_errmsg_start + '\n' + \ + 'Most common cause: mistakes or missing keys in COLL KEYS in mf_param_file. ' +\ + 'Alternative option: set obj.match_type = helita.sim.ebysus.MATCH_AUX to skip collisions with missing coll keys.' + + +''' --------------------- MultiFluid class --------------------- ''' + + +def simple_property(internal_name, doc=None, name=None, **kw): + '''return a property with a setter and getter method for internal_name. + if 'default' in kw: + - getter will have a default of kw['default'], if attr has not been set. + - setter will do nothing if value is kw['default']. + ''' + if 'default' in kw: + default = kw['default'] + # define getter method + + def getter(self): + return getattr(self, internal_name, default) + # define setter method + + def setter(self, value): + if value is not default: + setattr(self, internal_name, value) + else: + # define getter method + def getter(self): + return getattr(self, internal_name) + # define setter method + + def setter(self, value): + setattr(self, internal_name, value) + # define deleter method + + def deleter(self): + delattr(self, internal_name) + # bookkeeping + if name is not None: + getter.__name__ = 'get_'+name + setter.__name__ = 'set_'+name + deleter.__name__ = 'del_'+name + # collect and return result. + return property(getter, setter, deleter, doc=doc) + + +def simple_tuple_property(*internal_names, doc=None, name=None, **kw): + '''return a property which refers to a tuple of internal names. + if 'default' in kw: + - getter will have a default of kw['default'], if attr has not been set. + - setter will do nothing if value is kw['default']. + This applies to each name in internal_names, individually. + ''' + if 'default' in kw: + default = kw['default'] + # define getter method + + def getter(self): + return tuple(getattr(self, name, default) for name in internal_names) + # define setter method + + def setter(self, value): + for name, val in zip(internal_names, value): + if val is not default: + setattr(self, name, val) + else: + # define getter method + def getter(self): + return tuple(getattr(self, name) for name in internal_names) + # define setter method + + def setter(self, value): + for name, val in zip(internal_names, value): + setattr(self, name, val) + # define deleter method + + def deleter(self): + for name in internal_names: + delattr(self, name) + # bookkeeping + if name is not None: + getter.__name__ = 'get_'+name + setter.__name__ = 'set_'+name + deleter.__name__ = 'del_'+name + # collect and return result. + return property(getter, setter, deleter, doc=doc) + + +# internal names for properties: +_IS = '_mf_ispecies' +_JS = '_mf_jspecies' +_IL = '_mf_ilevel' +_JL = '_mf_jlevel' + + +class Multifluid(): + '''class which tracks fluids, and contains methods related to fluids.''' + + def __init__(self, **kw): + self.set_fluids(**kw) + + ## PROPERTIES (FLUIDS) ## + ### "ORIGINAL PROPERTIES" ### + mf_ispecies = simple_property(_IS, default=None, name='mf_ispecies') + mf_jspecies = simple_property(_JS, default=None, name='mf_jspecies') + mf_ilevel = simple_property(_IL, default=None, name='mf_ilevel') + mf_jlevel = simple_property(_JL, default=None, name='mf_jlevel') + ### ALIASES - single ### + iS = simple_property(_IS, default=None, name='iS') + jS = simple_property(_JS, default=None, name='jS') + iL = simple_property(_IL, default=None, name='iL') + jL = simple_property(_JL, default=None, name='jL') + ### ALIASES - multiple ### + ifluid = simple_tuple_property(_IS, _IL, default=None, name='ifluid') + iSL = simple_tuple_property(_IS, _IL, default=None, name='iSL') + jfluid = simple_tuple_property(_JS, _JL, default=None, name='jfluid') + jSL = simple_tuple_property(_JS, _JL, default=None, name='jSL') + + ### FLUIDS OBJECT (from at_tools.fluids) ### + @property + def fluids(self): + '''at_tools.fluids.Fluids object describing the fluids in self.''' + if hasattr(self, '_fluids'): + return self._fluids + else: + return fl.Fluids(dd=self) + + ## METHODS ## + def fluids_equal(self, ifluid, jfluid): + '''returns whether ifluid and jfluid represent the same fluid.''' + return fluid_equals(ifluid, jfluid) + + def MaintainingFluids(self): + return _MaintainingFluids(self) + MaintainingFluids.__doc__ = _MaintainingFluids.__doc__.replace( + '_MaintainingFluids(dd', 'dd.MaintainingFluids(') # set docstring + MaintainFluids = MaintainingFluids # alias + + def UsingFluids(self, **kw__fluids): + return _UsingFluids(self, **kw__fluids) + + UsingFluids.__doc__ = _UsingFluids.__doc__.replace( + '_UsingFluids(dd, ', 'dd.UsingFluids(') # set docstring + UseFluids = UsingFluids # alias + + +# include bound versions of methods from this module into the Multifluid class. +for func in MULTIFLUID_FUNCS: + setattr(Multifluid, func, globals().get(func, NotImplementedError)) + +del func # (we don't want func to remain in the fluid_tools.py namespace beyond this point.) diff --git a/helita/sim/laresav.py b/helita/sim/laresav.py new file mode 100644 index 00000000..a88218e3 --- /dev/null +++ b/helita/sim/laresav.py @@ -0,0 +1,300 @@ +import os + +import numpy as np +from scipy.io import readsav as rsav + +from . import document_vars +from .load_arithmetic_quantities import * +from .load_noeos_quantities import * +from .load_quantities import * +from .tools import * + + +class Laresav: + """ + Class to read Lare3D sav file atmosphere + + Parameters + ---------- + fdir : str, optional + Directory with snapshots. + rootname : str, optional + Template for snapshot number. + verbose : bool, optional + If True, will print more information. + """ + + def __init__(self, snap, fdir='.', sel_units='cgs', verbose=True): + + self.fdir = fdir + try: + self.savefile = rsav(os.path.join(self.fdir, '{:03d}'.format(snap)+'.sav')) + self.root_name = '{:03d}'.format(snap) + except: + self.savefile = rsav(os.path.join(self.fdir, '{:04d}'.format(snap)+'.sav')) + self.root_name = '{:04d}'.format(snap) + + self.filename = self.savefile['d']['filename'][0] + + self.snap = snap + self.sel_units = sel_units + self.verbose = verbose + self.uni = Laresav_units() + + self.set_time() + + self.visc_heating = self.savefile['d']['visc_heating'][0].copy() + self.visc3_heating = self.savefile['d']['visc3_heating'][0].copy() + + self.x = self.savefile['d']['x'][0].copy().byteswap('=').newbyteorder('=') + self.y = self.savefile['d']['y'][0].copy().byteswap('=').newbyteorder('=') + self.z = self.savefile['d']['z'][0].copy().byteswap('=').newbyteorder('=') + self.z -= np.min(self.z) + + if self.sel_units == 'cgs': + self.x *= self.uni.uni['l'] + self.y *= self.uni.uni['l'] + self.z *= self.uni.uni['l'] + + # GRID STRUCT -> Array[1] + + self.nx = len(self.x) + self.ny = len(self.y) + self.nz = len(self.z) + + if self.nx > 1: + self.dx1d = np.gradient(self.x) + else: + self.dx1d = np.zeros(self.nx) + + if self.ny > 1: + self.dy1d = np.gradient(self.y) + else: + self.dy1d = np.zeros(self.ny) + + if self.nz > 1: + self.dz1d = np.gradient(self.z) + else: + self.dz1d = np.zeros(self.nz) + + self.transunits = False + + self.cstagop = False # This will not allow to use cstagger from Bifrost in load + self.hion = False # This will not allow to use HION from Bifrost in load + + self.genvar() + + document_vars.create_vardict(self) + document_vars.set_vardocs(self) + + def set_time(self): + + self.time = self.savefile['d']['time'][0].copy() + self.time_prev = self.savefile['d']['time_prev'][0].copy() + self.timestep = self.savefile['d']['timestep'][0].copy() + self.dt = self.savefile['d']['dt'][0].copy() + + def get_var(self, var, *args, snap=None, iix=None, iiy=None, iiz=None, layout=None, **kargs): + ''' + Reads the variables from a snapshot (snap). + + Parameters + ---------- + var - string + Name of the variable to read. Must be Bifrost internal names. + snap - integer, optional + Snapshot number to read. By default reads the loaded snapshot; + if a different number is requested, will load that snapshot. + + Axes: + ----- + x and y axes horizontal plane + z-axis is vertical axis, top corona is last index and positive. + + Variable list: + -------------- + rho -- Density (multipy by self.uni['r'] to get in g/cm^3) [nx, ny, nz] + energy -- Energy (multipy by self.uni['e'] to get in erg) [nx, ny, nz] + temperature -- Temperature (multipy by self.uni['tg'] to get in K) [nx, ny, nz] + vx -- component x of the velocity (multipy by self.uni['u'] to get in cm/s) [nx+1, ny+1, nz+1] + vy -- component y of the velocity (multipy by self.uni['u'] to get in cm/s) [nx+1, ny+1, nz+1] + vz -- component z of the velocity (multipy by self.uni['u'] to get in cm/s) [nx+1, ny+1, nz+1] + bx -- component x of the magnetic field (multipy by self.uni['b'] to get in G) [nx+1, ny, nz] + by -- component y of the magnetic field (multipy by self.uni['b'] to get in G) [nx, ny+1, nz] + bz -- component z of the magnetic field (multipy by self.uni['b'] to get in G) [nx, ny, nz+1] + jx -- component x of the current [nx+1, ny+1, nz+1] + jy -- component x of the current [nx+1, ny+1, nz+1] + jz -- component x of the current [nx+1, ny+1, nz+1] + pressure -- Pressure (multipy by self.uni['pg']) [nx, ny, nz] + eta -- eta (?) [nx, ny, nz] + + ''' + + if var in self.varn.keys(): + varname = self.varn[var] + else: + varname = var + + if snap != None: + self.snap = snap + try: + self.savefile = rsav(os.path.join(self.fdir, '{:03d}'.format(snap)+'.sav')) + except: + self.savefile = rsav(os.path.join(self.fdir, '{:04d}'.format(snap)+'.sav')) + self.set_time() + + try: + + if self.sel_units == 'cgs': + varu = var.replace('x', '') + varu = varu.replace('y', '') + varu = varu.replace('z', '') + if (var in self.varn.keys()) and (varu in self.uni.uni.keys()): + cgsunits = self.uni.uni[varu] + else: + cgsunits = 1.0 + else: + cgsunits = 1.0 + + self.data = (self.savefile['d'][varname][0].T).copy().byteswap('=').newbyteorder('=') * cgsunits + + if (np.shape(self.data)[0] > self.nx): + self.data = ((self.data[1:, :, :] + self.data[:-1, :, :]) / 2).copy() + + if (np.shape(self.data)[1] > self.ny): + self.data = ((self.data[:, 1:, :] + self.data[:, :-1, :]) / 2).copy() + + if (np.shape(self.data)[2] > self.nz): + self.data = ((self.data[:, :, 1:] + self.data[:, :, :-1]) / 2).copy() + + except: + # Loading quantities + if self.verbose: + print('Loading composite variable', end="\r", flush=True) + self.data = load_noeos_quantities(self, var, **kargs) + + if np.shape(self.data) == (): + self.data = load_quantities(self, var, PLASMA_QUANT='', CYCL_RES='', + COLFRE_QUANT='', COLFRI_QUANT='', IONP_QUANT='', + EOSTAB_QUANT='', TAU_QUANT='', DEBYE_LN_QUANT='', + CROSTAB_QUANT='', COULOMB_COL_QUANT='', AMB_QUANT='', + HALL_QUANT='', BATTERY_QUANT='', SPITZER_QUANT='', + KAPPA_QUANT='', GYROF_QUANT='', WAVE_QUANT='', + FLUX_QUANT='', CURRENT_QUANT='', COLCOU_QUANT='', + COLCOUMS_QUANT='', COLFREMX_QUANT='', **kargs) + + # Loading arithmetic quantities + if np.shape(self.data) == (): + if self.verbose: + print('Loading arithmetic variable', end="\r", flush=True) + self.data = load_arithmetic_quantities(self, var, **kargs) + + if document_vars.creating_vardict(self): + return None + elif var == '': + print(help(self.get_var)) + print('VARIABLES USING CGS OR GENERIC NOMENCLATURE') + for ii in self.varn: + print('use ', ii, ' for ', self.varn[ii]) + if hasattr(self, 'vardict'): + self.vardocs() + + return None + + return self.data + + def genvar(self): + ''' + Dictionary of original variables which will allow to convert to cgs. + ''' + self.varn = {} + self.varn['rho'] = 'rho' + self.varn['tg'] = 'temperature' + self.varn['e'] = 'energy' + self.varn['pg'] = 'pressure' + self.varn['ux'] = 'vx' + self.varn['uy'] = 'vy' + self.varn['uz'] = 'vz' + self.varn['bx'] = 'bx' + self.varn['by'] = 'by' + self.varn['bz'] = 'bz' + self.varn['jx'] = 'jx' + self.varn['jy'] = 'jy' + self.varn['jz'] = 'jz' + + def trans2comm(self, varname, snap=None): + ''' + Transform the domain into a "common" format. All arrays will be 3D. The 3rd axis + is: + + - for 3D atmospheres: the vertical axis + - for loop type atmospheres: along the loop + - for 1D atmosphere: the unique dimension is the 3rd axis. + At least one extra dimension needs to be created artifically. + + All of them should obey the right hand rule + + In all of them, the vectors (velocity, magnetic field etc) away from the Sun. + + If applies, z=0 near the photosphere. + + Units: everything is in cgs. + + If an array is reverse, do ndarray.copy(), otherwise pytorch will complain. + + ''' + + self.sel_units = 'cgs' + + self.trans2commaxes + + var = self.get_var(varname, snap=snap).copy() + + #var = transpose(var,(X,X,X)) + # also velocities. + + return var + + def trans2commaxes(self): + + if self.transunits == False: + # self.x = # including units conversion + # self.y = + # self.z = + # self.dx = + # self.dy = + # self.dz = + self.transunits = True + + def trans2noncommaxes(self): + + if self.transunits == True: + # opposite to the previous function + self.transunits = False + + +class Laresav_units(object): + + def __init__(self, verbose=False): + ''' + Units and constants in cgs + ''' + self.uni = {} + self.verbose = verbose + self.uni['b'] = 2.0 # Gauss + self.uni['l'] = 1.0e8 # Mm -> cm + self.uni['gamma'] = 5./3. + self.uni['rho'] = 1.67e-15 # gr cm^-3 + + globalvars(self) + + mu0 = 4.e-7*np.pi + + self.uni['u'] = self.uni['b']*1e-3 / np.sqrt(mu0 * self.uni['rho']*1e3) * 1e2 # cm/s + self.uni['tg'] = (self.uni['u']*1e-2)**2 * self.msi_h / self.ksi_b # K + self.uni['t'] = self.uni['l'] / self.uni['u'] # seconds + + # Units and constants in SI + convertcsgsi(self) + + self.uni['n'] = self.uni['rho'] / self.m_p / 2. # cm^-3 diff --git a/helita/sim/load_arithmetic_quantities.py b/helita/sim/load_arithmetic_quantities.py new file mode 100644 index 00000000..cdd4ed85 --- /dev/null +++ b/helita/sim/load_arithmetic_quantities.py @@ -0,0 +1,1324 @@ +""" +These quantities relate to doing manipulations. +Frequently, they are "added" to regular variable names. +Examples: + - get_var('u2') is roughly equal to get_var('ux')**2 + get_var('uy')**2 + get_var('uz')**2 + - get_var('drdxdn') takes derivative of 'r' and pushes down in x. + - get_var('rxup') pushes 'r' up in x, by interpolating. +In general, these are not hard coded for every variable, but rather you will add to names. +For example, you can do get_var('d'+var+'dxdn') for any var which get_var knows how to get. + +Interpolation guide: + 'up' moves up by 0.5 (i.e. half a grid cell) + 'dn' moves down by 0.5 (i.e. half a grid cell) + scalars are in center of cell, at (0,0,0). + e.g.: density, energy + B, p are on the faces of the cell. Example: + Bx at ( -0.5, 0 , 0 ) + By at ( 0 , -0.5, 0 ) + Bz at ( 0 , 0 , -0.5 ) + B = magnetic field; p = momentum density. + E, i are on the edges of the cell. Example: + Ex at ( 0 , -0.5, -0.5 ) + Ey at ( -0.5, 0 , -0.5 ) + Ez at ( -0.5, -0.5, 0 ) + E = electric field; i = current per unit area. +""" + + +import warnings +# import built-ins +from multiprocessing.dummy import Pool as ThreadPool + +# import internal modules +from . import document_vars, tools + +try: + from . import cstagger +except ImportError: + cstagger = tools.ImportFailed('cstagger', "This module is required to use stagger_kind='cstagger'.") +try: + from . import stagger +except ImportError: + stagger = tools.ImportFailed('stagger') + +# import external public modules +import numpy as np + +# import the relevant things from the internal module "units" +from .units import DIMENSIONLESS, UNI, UNITS_FACTOR_1, UNI_length, Usym + +# set constants +AXES = ('x', 'y', 'z') +YZ_FROM_X = dict(x=('y', 'z'), y=('z', 'x'), z=('x', 'y')) # right-handed coord system x,y,z given x. +EPSILON = 1.0e-20 # small number which is added in denominators of some operations. + + +# we need to convert to float32 before doing cstagger.do. +# not sure why this conversion isnt done in the cstagger method, but it is a bit +# painful to change the method itself (required re-installing helita for me) so we will +# instead just change our calls to the method here. -SE Apr 22 2021 +CSTAGGER_TYPES = ['float32'] # these are the allowed types + + +def do_stagger(arr, operation, default_type=CSTAGGER_TYPES[0], obj=None): + '''does stagger of arr. + For stagger_kind='cstagger', first does some preprocessing: + - ensure arr is the correct type, converting if necessary. + if type conversion is necessary, convert to default_type. + - TODO: check _can_interp here, instead of separately. + For other stagger kinds, + - assert that obj has been provided + - call stagger via obj: obj.stagger.do(arr, operation) + ''' + kind = getattr(obj, 'stagger_kind', stagger.DEFAULT_STAGGER_KIND) + if kind == 'cstagger': # use cstagger routine. + arr = np.array(arr, copy=False) # make numpy array, if necessary. + if arr.dtype not in CSTAGGER_TYPES: # if necessary, + arr = arr.astype(default_type) # convert type + return cstagger.do(arr, operation) # call the original cstagger function + else: # use stagger routine. + assert obj is not None, f'obj must be provided to use stagger, with stagger_kind = {stagger_kind}.' + return obj.stagger.do(arr, operation) + + +do_cstagger = do_stagger # << alias, for backwards compatibility. + + +def _can_interp(obj, axis, warn=True): + '''return whether we can interpolate. Make warning if we can't. + must check before doing any cstagger operation. + pythonic stagger methods (e.g. 'numba', 'numpy') make this check on their own. + ''' + if not obj.do_stagger: # this is True by default; if it is False we assume that someone + return False # intentionally turned off interpolation. So we don't make warning. + kind = getattr(obj, 'stagger_kind', stagger.DEFAULT_STAGGER_KIND) + if kind != 'cstagger': + return True # we can let the pythonic methods check _can_interp on their own. + if not getattr(obj, 'cstagger_exists', False): + if obj.verbose: + warnmsg = 'interpolation requested, but cstagger not initialized, for obj={}! '.format(object.__repr__(obj)) +\ + 'We will skip the interpolation, and instead return the original value.' + warnings.warn(warnmsg) # warn user we will not be interpolating! (cstagger doesn't exist) + return False + if not getattr(obj, 'n'+axis, 0) >= 5: + if obj.verbose: + warnmsg = 'requested interpolation in {x:} but obj.n{x:} < 5. '.format(x=axis) +\ + 'We will skip this interpolation, and instead return the original value.' + warnings.warn(warnmsg) # warn user we will not be interpolating! (dimension is too small) + return False + return True + + +''' --------------------- functions to load quantities --------------------- ''' + + +def load_arithmetic_quantities(obj, quant, *args__None, **kwargs__None): + '''load arithmetic quantities. + *args__None and **kwargs__None go to nowhere. + ''' + __tracebackhide__ = True # hide this func from error traceback stack. + quant = quant.lower() + + document_vars.set_meta_quant(obj, 'arquantities', 'Computes arithmetic quantities') + + # tell which funcs to use for getting things. (funcs will be called in the order listed here) + _getter_funcs = ( + get_center, get_deriv, get_interp, + get_module, get_horizontal_average, + get_gradients_vect, get_gradients_scalar, + get_dot_product, + get_square, get_lg, get_numop, get_ratios, get_parens, + get_projections, get_angle, + get_stat_quant, get_fft_quant, + get_multi_quant, + get_vector_product, # this is intentionally later in the order, so that e.g. "(eftimesb)2" will work. + ) + + val = None + # loop through the function and QUANT pairs, running the functions as appropriate. + for getter in _getter_funcs: + val = getter(obj, quant) + if val is not None: + break + else: # didn't break; val is still None + return None + # << did break; found a non-None val. + document_vars.select_quant_selection(obj) # (bookkeeping for obj.got_vars_tree(), obj.get_units(), etc.) + return val + + +# default +_DERIV_QUANT = ('DERIV_QUANT', ['d'+x+up+one for up in ('up', 'dn') for x in AXES for one in ('', '1')]) +# get value + + +def get_deriv(obj, quant): + ''' + Computes derivative of quantity. + Example: 'drdxup' does dxup for var 'r'. + ''' + if quant == '': + docvar = document_vars.vars_documenter(obj, *_DERIV_QUANT, get_deriv.__doc__, + uni=UNI.quant_child(0) / UNI_length) + for x in AXES: + for up in ('up', 'dn'): + docvar(f'd{x}{up}', f'dvard{x}{up} --> d(var)/d{x}, with half grid {up}, using method implied by stagger_kind.') + docvar(f'd{x}{up}1', f'dvard{x}{up}1 --> d(var)/d{x}, with half grid {up}, using first order gradients method.') + return None + + if quant[-1] == '1': + getq = quant[-5:] # e.g. 'dxup1' + dxup = getq[:-1] # e.g. 'dxup' + order = 1 # 1 --> "use first order. Ignore stagger_kind" + else: + getq = quant[-4:] # e.g. 'dxup' + dxup = getq + order = None # None --> "use whatever order is implied by obj.stagger_kind" + + if not (quant[0] == 'd' and getq[0] == 'd' and getq in _DERIV_QUANT[1]): + return None + + # tell obj the quant we are getting by this function. + document_vars.setattr_quant_selected(obj, getq, _DERIV_QUANT[0], delay=True) + + # interpret quant string + axis = getq[1] # 'x', 'y', or 'z' + q = quant[1: -len(getq)] # base variable + # get value of var (before derivative) + var = obj.get_var(q) + + # handle "cant interpolate" case (e.g. if nx < 5 and trying to interpolate in x.) + if not _can_interp(obj, axis): + if obj.verbose: + warnings.warn("Can't interpolate; using np.gradient to take derivative, instead.") + xidx = dict(x=0, y=1, z=2)[axis] # axis; 0, 1, or 2. + if np.ndim(var) < 3: + if obj.verbose: + warnings.warn(f"Returning 0 for derivative of quant with ndim<3: {repr(quant)}.") + return np.zeros_like(var) + elif var.shape[xidx] <= 1: + return np.zeros_like(var) + dvar = np.gradient(var, axis=xidx) # 3D + dx = getattr(obj, 'd'+axis+'1d') # 1D; needs dims to be added. add dims below. + dx = np.expand_dims(dx, axis=tuple(set((0, 1, 2)) - set([xidx]))) + dvardx = dvar / dx + return dvardx + + # calculate derivative with interpolations + # -- bookkeeping: + threading = (obj.numThreads > 1) + lowbusing = obj.lowbus + # -- default case -- + if not (threading or lowbusing): + if order is None: # default order; order implied by obj.stagger_kind + return do_stagger(var, 'd'+dxup, obj=obj) + elif order == 1: # force first order + with tools.MaintainingAttrs(obj, 'stagger_kind'): # reset obj.stagger_kind after this block + obj.stagger_kind = 'o1_numpy' # stagger kind option which causes to use first order method. + return do_stagger(var, 'd'+dxup, obj=obj) + # -- "using numThreads" case (False by default) -- + if threading: + if obj.verbose: + print('Threading', whsp*8, end="\r", flush=True) + quantlist = [quant[-4:] for numb in range(obj.numThreads)] + + def deriv_loop(var, quant): + return do_stagger(var, 'd' + quant[0], obj=obj) + if axis != 'z': + return threadQuantity_z(deriv_loop, obj.numThreads, var, quantlist) + else: + return threadQuantity_y(deriv_loop, obj.numThreads, var, quantlist) + # -- "using lowbus" case (False by default) -- + else: + if lowbusing: + output = np.zeros_like(var) + if axis != 'z': + for iiz in range(obj.nz): + slicer = np.s_[:, :, iiz:iiz+1] + staggered = do_stagger(var[slicer], 'd' + quant[-4:], obj=obj) + output[slicer] = staggered + else: + for iiy in range(obj.ny): + slicer = np.s_[:, iiy:iiy+1, :] + staggered = do_stagger(var[slicer], 'd' + quant[-4:], obj=obj) + output[slicer] = staggered + + return output + + # if we reach this line, quant is a deriv quant but we did not handle it. + raise NotImplementedError(f'{repr(getq)} in get_deriv') + + +# default +_CENTER_QUANT = ('CENTER_QUANT', [x+'c' for x in AXES] + ['_center']) +# get value + + +def get_center(obj, quant, *args, **kwargs): + ''' + Center the variable in the midle of the grid cells + ''' + if quant == '': + docvar = document_vars.vars_documenter(obj, *_CENTER_QUANT, get_center.__doc__, uni=UNI.quant_child(0)) + docvar('_center', 'quant_center brings quant to center via interpolation. Requires mesh_location_tracking to be enabled') + return None + + getq = quant[-2:] # the quant we are "getting" by this function. E.g. 'xc'. + + if getq in _CENTER_QUANT[1]: + q = quant[:-1] # base variable, including axis. E.g. 'efx'. + elif quant.endswith('_center'): + assert getattr(obj, mesh_location_tracking, False), "mesh location tracking is required for this to be enabled" + q = quant[:-len('_center')] + else: + return None + + # tell obj the quant we are getting by this function. + document_vars.setattr_quant_selected(obj, getq, _CENTER_QUANT[0], delay=True) + + # get the variable (pre-centering). + var = obj.get_var(q, **kwargs) + + # determine which interpolations are necessary. + if stagger.has_mesh_location(var): # << using mesh_location_tracking >> + transf = var.meshloc.steps_to((0, 0, 0)) # e.g. the list: ['xup', 'ydn', 'zdn'] + if len(transf) == 0: + warnings.warn(f'called get_center on an already-centered variable: {q}') + else: # << not using mesh_location_tracking >> + axis = quant[-2] + qvec = q[:-1] # base variable, without axis. E.g. 'ef'. + if qvec in ['i', 'e', 'j', 'ef']: # edge-centered variable. efx is at (0, -0.5, -0.5) + AXIS_TRANSFORM = {'x': ['yup', 'zup'], + 'y': ['xup', 'zup'], + 'z': ['xup', 'yup']} + else: + AXIS_TRANSFORM = {'x': ['xup'], + 'y': ['yup'], + 'z': ['zup']} + transf = AXIS_TRANSFORM[axis] + + # do interpolation + if obj.lowbus: + # do "lowbus" version of interpolation # not sure what is this? -SE Apr21 2021 + output = np.zeros_like(var) + for interp in transf: + axis = interp[0] + if _can_interp(obj, axis): + if axis != 'z': + for iiz in range(obj.nz): + slicer = np.s_[:, :, iiz:iiz+1] + staggered = do_stagger(var[slicer], interp, obj=obj) + output[slicer] = staggered + else: + for iiy in range(obj.ny): + slicer = np.s_[:, iiy:iiy+1, :] + staggered = do_stagger(var[slicer], interp, obj=obj) + output[slicer] = staggered + else: + # do "regular" version of interpolation + for interp in transf: + if _can_interp(obj, interp[0]): + var = do_stagger(var, interp, obj=obj) + return var + + +# default +_INTERP_QUANT = ('INTERP_QUANT', [x+up for up in ('up', 'dn') for x in AXES]) +# get value + + +def get_interp(obj, quant): + '''simple interpolation. var must end in interpolation instructions. + e.g. get_var('rxup') --> do_stagger(get_var('r'), 'xup') + ''' + if quant == '': + docvar = document_vars.vars_documenter(obj, *_INTERP_QUANT, get_interp.__doc__, uni=UNI.quant_child(0)) + for xup in _INTERP_QUANT[1]: + docvar(xup, 'move half grid {up:} in the {x:} axis'.format(up=xup[1:], x=xup[0])) + return None + + # interpret quant string + varname, interp = quant[:-3], quant[-3:] + + if not interp in _INTERP_QUANT[1]: + return None + + # tell obj the quant we are getting by this function. + document_vars.setattr_quant_selected(obj, interp, _INTERP_QUANT[0], delay=True) + + val = obj.get_var(varname) # un-interpolated value + if _can_interp(obj, interp[0]): + val = do_stagger(val, interp, obj=obj) # interpolated value + return val + + +# default +_MODULE_QUANT = ('MODULE_QUANT', ['mod', 'h', '_mod']) +# get value + + +def get_module(obj, quant): + ''' + Module or horizontal component of vectors + ''' + if quant == '': + docvar = document_vars.vars_documenter(obj, *_MODULE_QUANT, get_module.__doc__, uni=UNI.quant_child(0)) + docvar('mod', 'starting with mod computes the module of the vector [simu units]. sqrt(vx^2 + vy^2 + vz^2).') + docvar('_mod', 'ending with mod computes the module of the vector [simu units]. sqrt(vx^2 + vy^2 + vz^2). ' + + "This is an alias for starting with mod. E.g. 'modb' and 'b_mod' mean the same thing.") + docvar('h', 'ending with h computes the horizontal component of the vector [simu units]. sqrt(vx^2 + vy^2).') + return None + + # interpret quant string + if quant.startswith('mod'): + getq = 'mod' + q = quant[len('mod'):] + elif quant.endswith('_mod'): + getq = 'mod' + q = quant[: -len('_mod')] + elif quant.endswith('h'): + getq = 'h' + q = quant[: -len('h')] + else: + return None + + # tell obj the quant we are getting by this function. + document_vars.setattr_quant_selected(obj, getq, _MODULE_QUANT[0], delay=True) + + # actually get the quant: + result = obj.get_var(q + 'xc') ** 2 + result += obj.get_var(q + 'yc') ** 2 + if getq == 'mod': + result += obj.get_var(q + 'zc') ** 2 + + return np.sqrt(result) + + +# default +_HORVAR_QUANT = ('HORVAR_QUANT', ['horvar']) +# get value + + +def get_horizontal_average(obj, quant): + ''' + Computes horizontal average + ''' + + if quant == '': + docvar = document_vars.vars_documenter(obj, *_HORVAR_QUANT, get_horizontal_average.__doc__) + docvar('horvar', 'starting with horvar computes the horizontal average of a variable [simu units]') + + # interpret quant string + getq = quant[:6] + if not getq in _HORVAR_QUANT[1]: + return None + + # tell obj the quant we are getting by this function. + document_vars.setattr_quant_selected(obj, getq, _HORVAR_QUANT[0], delay=True) + + # Compares the variable with the horizontal mean + if getq == 'horvar': + result = np.zeros_like(obj.r) + result += obj.get_var(quant[6:]) # base variable + horv = np.mean(np.mean(result, 0), 0) + for iix in range(0, getattr(obj, 'nx')): + for iiy in range(0, getattr(obj, 'ny')): + result[iix, iiy, :] = result[iix, iiy, :] / horv[:] + return result + else: + # quant is a horizontal_average quant but we did not handle it. + raise NotImplementedError(f'{repr(getq)} in get_horizontal_average') + + +# default +_GRADVECT_QUANT = ('GRADVECT_QUANT', + ['divup', 'divdn', 'div', # note: div must come after divup and divdn, + # since to check which quant to get we are checking .startswith, + # and 'divup' and 'divdn' both start with 'div'. + 'rot', 'she', 'curlcc', 'curvec', + 'chkdiv', 'chbdiv', 'chhdiv'] + ) +# get value + + +def get_gradients_vect(obj, quant): + ''' + Vectorial derivative opeartions + + for rot, she, curlcc, curvec, ensure that quant ends with axis. + e.g. curvecbx gets the x component of curl of b. + ''' + + if quant == '': + docvar = document_vars.vars_documenter(obj, *_GRADVECT_QUANT, get_gradients_vect.__doc__) + for div in ['div', 'divup']: + docvar(div, 'starting with, divergence [simu units], shifting up (e.g. dVARdxup) for derivatives', uni=UNI.quant_child(0)) + docvar('divdn', 'starting with, divergence [simu units], shifting down (e.g. dVARdxdn) for derivatives', uni=UNI.quant_child(0)) + docvar('rot', 'starting with, rotational (a.k.a. curl) [simu units]', uni=UNI.quant_child(0)) + docvar('she', 'starting with, shear [simu units]', uni=UNI.quant_child(0)) + docvar('curlcc', 'starting with, curl but shifted (via interpolation) back to original location on cell [simu units]', uni=UNI.quant_child(0)) + docvar('curvec', 'starting with, curl of face-centered vector (e.g. B, p) [simu units]', uni=UNI.quant_child(0)) + docvar('chkdiv', 'starting with, ratio of the divergence with the maximum of the abs of each spatial derivative [simu units]') + docvar('chbdiv', 'starting with, ratio of the divergence with the sum of the absolute of each spatial derivative [simu units]') + docvar('chhdiv', 'starting with, ratio of the divergence with horizontal averages of the absolute of each spatial derivative [simu units]') + return None + + # interpret quant string + for GVQ in _GRADVECT_QUANT[1]: + if quant.startswith(GVQ): + getq = GVQ # the quant we are "getting" via this function. (e.g. 'rot' or 'div') + q = quant[len(GVQ):] # the "base" quant, i.e. whatever is left after pulling getq. + break + else: # if we did not break, we did not match any GVQ to quant, so we return None. + return None + + # tell obj the quant we are getting by this function. + document_vars.setattr_quant_selected(obj, getq, _GRADVECT_QUANT[0], delay=True) + + # do calculations and return result + if getq == 'chkdiv': + if getattr(obj, 'nx') < 5: # 2D or close + varx = np.zeros_like(obj.r) + else: + varx = obj.get_var('d' + q + 'xdxup') + + if getattr(obj, 'ny') > 5: + vary = obj.get_var('d' + q + 'ydyup') + else: + vary = np.zeros_like(varx) + + if getattr(obj, 'nz') > 5: + varz = obj.get_var('d' + q + 'zdzup') + else: + varz = np.zeros_like(varx) + return np.abs(varx + vary + varx) / (np.maximum( + np.abs(varx), np.abs(vary), np.abs(varz)) + EPSILON) + + elif getq == 'chbdiv': + varx = obj.get_var(q + 'x') + vary = obj.get_var(q + 'y') + varz = obj.get_var(q + 'z') + if getattr(obj, 'nx') < 5: # 2D or close + result = np.zeros_like(varx) + else: + result = obj.get_var('d' + q + 'xdxup') + + if getattr(obj, 'ny') > 5: + result += obj.get_var('d' + q + 'ydyup') + + if getattr(obj, 'nz') > 5: + result += obj.get_var('d' + q + 'zdzup') + + return np.abs(result / (np.sqrt( + varx * varx + vary * vary + varz * varz) + EPSILON)) + + elif getq == 'chhdiv': + varx = obj.get_var(q + 'x') + vary = obj.get_var(q + 'y') + varz = obj.get_var(q + 'z') + if getattr(obj, 'nx') < 5: # 2D or close + result = np.zeros_like(varx) + else: + result = obj.get_var('d' + q + 'xdxup') + + if getattr(obj, 'ny') > 5: + result += obj.get_var('d' + q + 'ydyup') + + if getattr(obj, 'nz') > 5: + result += obj.get_var('d' + q + 'zdzup') + + for iiz in range(0, obj.nz): + result[:, :, iiz] = np.abs(result[:, :, iiz]) / np.mean(( + np.sqrt(varx[:, :, iiz]**2 + vary[:, :, iiz]**2 + + varz[:, :, iiz]**2))) + return result + + elif getq in ['div', 'divup', 'divdn']: # divergence of vector quantity + up = 'dn' if (getq == 'divdn') else 'up' + result = 0 + for xdx in ['xdx', 'ydy', 'zdz']: + result += obj.get_var('d' + q + xdx + up) + return result + + elif getq == 'curlcc': # re-aligned curl + x = q[-1] # axis, 'x', 'y', 'z' + q = q[:-1] # q without axis + y, z = YZ_FROM_X[x] + dqz_dy = obj.get_var('d' + q + z + 'd' + y + 'dn' + y + 'up') + dqy_dz = obj.get_var('d' + q + y + 'd' + z + 'dn' + z + 'up') + return dqz_dy - dqy_dz + + elif getq == 'curvec': # curl of vector which is originally on face of cell + x = q[-1] # axis, 'x', 'y', 'z' + q = q[:-1] # q without axis + y, z = YZ_FROM_X[x] + # interpolation notes: + ## qz is at (0, 0, -0.5); dqzdydn is at (0, -0.5, -0.5) + ## qy is at (0, -0.5, 0); dqydzdn is at (0, -0.5, -0.5) + dqz_dydn = obj.get_var('d' + q + z + 'd' + y + 'dn') + dqy_dzdn = obj.get_var('d' + q + y + 'd' + z + 'dn') + return dqz_dydn - dqy_dzdn + + elif getq in ['rot', 'she']: + q = q[:-1] # base variable + qaxis = quant[-1] + if qaxis == 'x': + if getattr(obj, 'ny') < 5: # 2D or close + result = np.zeros_like(obj.r) + else: + result = obj.get_var('d' + q + 'zdyup') + if getattr(obj, 'nz') > 5: + if quant[:3] == 'rot': + result -= obj.get_var('d' + q + 'ydzup') + else: # shear + result += obj.get_var('d' + q + 'ydzup') + elif qaxis == 'y': + if getattr(obj, 'nz') < 5: # 2D or close + result = np.zeros_like(obj.r) + else: + result = obj.get_var('d' + q + 'xdzup') + if getattr(obj, 'nx') > 5: + if quant[:3] == 'rot': + result -= obj.get_var('d' + q + 'zdxup') + else: # shear + result += obj.get_var('d' + q + 'zdxup') + elif qaxis == 'z': + if getattr(obj, 'nx') < 5: # 2D or close + result = np.zeros_like(obj.r) + else: + result = obj.get_var('d' + q + 'ydxup') + if getattr(obj, 'ny') > 5: + if quant[:3] == 'rot': + result -= obj.get_var('d' + q + 'xdyup') + else: # shear + result += obj.get_var('d' + q + 'xdyup') + return result + + else: + # if we reach this line, quant is a gradients_vect quant but we did not handle it. + raise NotImplementedError(f'{repr(getq)} in get_gradients_vect') + + +# default +_GRADSCAL_QUANT = ('GRADSCAL_QUANT', ['gra']) +# get value + + +def get_gradients_scalar(obj, quant): + ''' + Gradient of a scalar + ''' + if quant == '': + docvar = document_vars.vars_documenter(obj, *_GRADSCAL_QUANT, get_gradients_scalar.__doc__) + docvar('gra', 'starting with, Gradient of a scalar [simu units]. dqdx + dqdy + dqdz.' + + ' Shifting up for derivatives.', uni=UNI.quant_child(0)) + return None + + getq = quant[:3] + + if not getq in _GRADSCAL_QUANT[1]: + return None + + # tell obj the quant we are getting by this function. + document_vars.setattr_quant_selected(obj, getq, _GRADSCAL_QUANT[0], delay=True) + + # do calculations and return result + if getq == 'gra': + q = quant[3:] # base variable + result = obj.get_var('d' + q + 'dxup') + result += obj.get_var('d' + q + 'dyup') + result += obj.get_var('d' + q + 'dzup') + return result + else: + # if we reach this line, quant is a gradients_scalar quant but we did not handle it. + raise NotImplementedError(f'{repr(getq)} in get_gradients_scalar') + + +# default +_SQUARE_QUANT = ('SQUARE_QUANT', ['2']) +# get value + + +def get_square(obj, quant): + '''|vector| squared. Equals got product of vector with itself''' + if quant == '': + docvar = document_vars.vars_documenter(obj, *_SQUARE_QUANT, get_square.__doc__, + uni=UNI.quant_child(0)**2) + docvar('2', 'ending with, Square of a vector [simu units].' + + ' (Dot product of vector with itself.) Example: b2 --> bx^2 + by^2 + bz^2.') + return None + + getq = quant[-1] + if not getq in _SQUARE_QUANT[1]: + return None + + # tell obj the quant we are getting by this function. + document_vars.setattr_quant_selected(obj, getq, _SQUARE_QUANT[0], delay=True) + + # interpret quant string + q = quant[:-1] # vector name + + # do calculations and return result + if getq == '2': + result = obj.get_var(q + 'xc') ** 2 + result += obj.get_var(q + 'yc') ** 2 + result += obj.get_var(q + 'zc') ** 2 + return result + else: + # if we reach this line, quant is a square quant but we did not handle it. + raise NotImplementedError(f'{repr(getq)} in get_square') + + +# default +_LOG_QUANT = ('LOG_QUANT', ['lg', 'log_', 'log10_', 'ln_']) +# get value + + +def get_lg(obj, quant): + '''Logarithm of a variable. E.g. log_r --> log10(r)''' + if quant == '': + docvar = document_vars.vars_documenter(obj, *_LOG_QUANT, get_lg.__doc__, uni=DIMENSIONLESS) + for lg in ['lg', 'log_', 'log10_']: + docvar(lg, 'starting with, logarithm base 10 of a variable expressed in [simu. units].') + docvar('ln_', 'starting with, logarithm base e of a variable expressed in [simu. units].') + return None + + # interpret quant string + for LGQ in _LOG_QUANT[1]: + if quant.startswith(LGQ): + getq = LGQ # the quant we are "getting" via this function. (e.g. 'lg' or 'ln_') + q = quant[len(LGQ):] # the "base" quant, i.e. whatever is left after pulling getq. + break + else: # if we did not break, we did not match any LGQ to quant, so we return None. + return None + + # tell obj the quant we are getting by this function. + document_vars.setattr_quant_selected(obj, getq, _LOG_QUANT[0], delay=True) + + # do calculations and return result + if getq in ['lg', 'log_', 'log10_']: + return np.log10(obj.get_var(q)) + elif getq == 'ln_': + return np.log(obj.get_var(q)) + else: + # if we reach this line, quant is a lg quant but we did not handle it. + raise NotImplementedError(f'{repr(getq)} in get_lg') + + +# default +_NUMOP_QUANT = ('NUMOP_QUANT', ['delta_', 'deltafrac_', 'abs_']) +# get value + + +def get_numop(obj, quant): + '''Some numerical operation on a variable. E.g. delta_var computes (var - var.mean()).''' + if quant == '': + docvar = document_vars.vars_documenter(obj, *_NUMOP_QUANT, get_numop.__doc__) + docvar('delta_', 'starting with, deviation from mean. delta_v --> v - mean(v)', uni=UNI.qc(0)) + docvar('deltafrac_', 'starting with, fractional deviation from mean. deltafrac_v --> v / mean(v) - 1', uni=DIMENSIONLESS) + docvar('abs_', 'starting with, absolute value of a scalar. abs_v --> |v|', uni=UNI.qc(0)) + return None + + # interpret quant string + for start in _NUMOP_QUANT[1]: + if quant.startswith(start): + getq = start # the quant we are "getting" via this function. (e.g. 'lg' or 'ln_') + base = quant[len(getq):] # the "base" quant, i.e. whatever is left after pulling getq. + break + else: # if we did not break, we did not match any start to quant, so we return None. + return None + + # tell obj the quant we are getting by this function. + document_vars.setattr_quant_selected(obj, getq, _NUMOP_QUANT[0], delay=True) + + # do calculations and return result + v = obj.get_var(base) + if getq == 'delta_': + return (v - np.mean(v)) + elif getq == 'deltafrac_': + return (v / np.mean(v)) - 1 + elif getq == 'abs_': + return np.abs(v) + else: + # if we reach this line, quant is a numop quant but we did not handle it. + raise NotImplementedError(f'{repr(getq)} in get_numop') + + +# default +_RATIO_QUANT = ('RATIO_QUANT', ['rat']) +# get value + + +def get_ratios(obj, quant): + '''Ratio of two variables''' + if quant == '': + docvar = document_vars.vars_documenter(obj, *_RATIO_QUANT, get_ratios.__doc__, uni=UNI.qc(0)/UNI.qc(1)) + docvar('rat', 'in between with, ratio of two variables [simu units]. aratb gives a/b.') + return None + + # interpret quant string + for RAT in _RATIO_QUANT[1]: + qA, rat, qB = quant.partition(RAT) + if qB != '': + break + else: # if we did not break, we did not match any RAT to quant, so we return None. + return None + + # tell obj the quant we are getting by this function. + document_vars.setattr_quant_selected(obj, rat, _RATIO_QUANT[0], delay=True) + + # do calculations and return result + qA_val = obj.get_var(qA) + qB_val = obj.get_var(qB) + return qA_val / (qB_val + EPSILON) + + +# default +_PARENS_QUANT = ('PARENS_QUANT', ['()', '()x', '()y', '()z']) +# get value + + +def get_parens(obj, quant): + '''parentheses (in the sense of "order of operations"). + E.g. mean_(b_mod) --> mean |B| + E.g. (mean_b)_mod --> |(mean Bx, mean By, mean Bz)| + E.g. curvec(u_facecrosstoface_b)x --> x component of (curl of (u cross b)). + E.g. (curvecu)_facecrosstoface_bx --> x component of ((curl of u) cross b). + ''' + if quant == '': + docvar = document_vars.vars_documenter(obj, *_PARENS_QUANT, get_parens.__doc__, uni=UNI.qc(0)) + docvar('()', "(s) --> get_var(s).") + for x in AXES: + docvar('()'+x, f"(s){x} --> get_var(sx).") + return None + + # interpret quant string + if quant[0] != '(': + return None + if quant[-1] == ')': + getting = '()' + var = quant[1: -1] + axis = None + elif quant[-2] == ')': + axis = quant[-1] + if axis not in AXES: + return None + getting = '()'+axis + var = quant[1: -2] + axis + else: + return None + + # tell obj the quant we are getting by this function. + document_vars.setattr_quant_selected(obj, getting, _PARENS_QUANT[0], delay=True) + + # do calculations and return result + val = obj(var) + return val + + +# default +_PROJ_QUANT = ('PROJ_QUANT', ['par', 'per']) +# get value + + +def get_projections(obj, quant): + '''Projected vectors''' + if quant == '': + docvar = document_vars.vars_documenter(obj, *_PROJ_QUANT, get_projections.__doc__, uni=UNI.quant_child(0)) + docvar('par', 'in between with, parallel component of the first vector respect to the second vector [simu units]') + docvar('per', 'in between with, perpendicular component of the first vector respect to the second vector [simu units]') + return None + + # interpret quant string + for PAR in _PROJ_QUANT[1]: + v1, par, v2 = quant.partition(PAR) + if v2 != '': + break + else: # if we did not break, we did not match any PAR to quant, so we return None. + return None + + # tell obj the quant we are getting by this function. + document_vars.setattr_quant_selected(obj, par, _PROJ_QUANT[0], delay=True) + + # do calculations and return result v1 onto v2 + x_a = obj.get_var(v1 + 'xc') + y_a = obj.get_var(v1 + 'yc') + z_a = obj.get_var(v1 + 'zc') + x_b = obj.get_var(v2 + 'xc') + y_b = obj.get_var(v2 + 'yc') + z_b = obj.get_var(v2 + 'zc') + + def proj_task(x1, y1, z1, x2, y2, z2): + '''do projecting; can be used in threadQuantity() or as is''' + v2Mag = np.sqrt(x2 ** 2 + y2 ** 2 + z2 ** 2) + v2x, v2y, v2z = x2 / v2Mag, y2 / v2Mag, z2 / v2Mag + parScal = np.sqrt((x1 * v2x)**2 + (y1 * v2y)**2 + (z1 * v2z)**2) + parX, parY, parZ = x1 * v2x, y1 * v2y, z1 * v2z + if par == 'par': + return np.abs(parScal) + elif par == 'per': + perX = x1 - parX + perY = y1 - parY + perZ = z1 - parZ + v1Mag = np.sqrt(perX**2 + perY**2 + perZ**2) + return v1Mag + + if obj.numThreads > 1: + if obj.verbose: + print('Threading', whsp*8, end="\r", flush=True) + + return threadQuantity(proj_task, obj.numThreads, + x_a, y_a, z_a, x_b, y_b, z_b) + else: + return proj_task(x_a, y_a, z_a, x_b, y_b, z_b) + + +# default +_VECTOR_PRODUCT_QUANT = \ + ('VECTOR_PRODUCT_QUANT', + ['times', '_facecross_', '_edgecross_', '_edgefacecross_', + '_facecrosstocenter_', '_facecrosstoface_' + ] + ) +# get value + + +def get_vector_product(obj, quant): + '''cross product between two vectors. + call via . + Example, for the x component of b cross u, you should call get_var('b_facecross_ux'). + ''' + if quant == '': + docvar = document_vars.vars_documenter(obj, *_VECTOR_PRODUCT_QUANT, get_vector_product.__doc__, + uni=UNI.quant_child(0) * UNI.quant_child(1)) + docvar('times', '"naive" cross product between two vectors. (We do not do any interpolation.) [simu units]') + docvar('_facecross_', ('cross product [simu units]. For two face-centered vectors, such as B, u. ' + 'result is edge-centered. E.g. result_x is at ( 0 , -0.5, -0.5).')) + docvar('_edgecross_', ('cross product [simu units]. For two edge-centered vectors, such as E, I. ' + 'result is face-centered. E.g. result_x is at (-0.5, 0 , 0 ).')) + docvar('_edgefacecross_', ('cross product [simu units]. A_edgefacecross_Bx gives x-component of A x B.' + 'A must be edge-centered (such as E, I); B must be face-centered, such as B, u.' + 'result is face-centered. E.g. result_x is at (-0.5, 0 , 0 ).')) + docvar('_facecrosstocenter_', ('cross product for two face-centered vectors such as B, u. ' + 'result is fully centered. E.g. result_x is at ( 0 , 0 , 0 ).' + ' For most cases, it is better to use _facecrosstoface_')) + docvar('_facecrosstoface_', ('cross product for two face-centered vectors such as B, u. ' + 'result is face-centered E.g. result_x is at (-0.5, 0 , 0 ).'), + uni=UNI.quant_child(0)) # quant_child(0) will be _facecrosstocenter_ for this one. + return None + + # interpret quant string + for TIMES in _VECTOR_PRODUCT_QUANT[1]: + A, cross, q = quant.partition(TIMES) + if q != '': + B, x = q[:-1], q[-1] + y, z = YZ_FROM_X[x] + break + else: # if we did not break, we did not match any TIMES to quant, so we return None. + return None + + # tell obj the quant we are getting by this function. + document_vars.setattr_quant_selected(obj, cross, _VECTOR_PRODUCT_QUANT[0], delay=True) + + # at this point, we know quant looked like + + if cross == 'times': + return (obj.get_var(A + y) * obj.get_var(B + z) - + obj.get_var(A + z) * obj.get_var(B + y)) + + elif cross == '_facecross_': + # interpolation notes, for x='x', y='y', z='z': + # resultx will be at (0, -0.5, -0.5) + # Ay, By are at (0, -0.5, 0 ). we must shift by zdn to align with result. + # Az, Bz are at (0, 0 , -0.5). we must shift by ydn to align with result. + ydn, zdn = y+'dn', z+'dn' + Ay = obj.get_var(A+y + zdn) + By = obj.get_var(B+y + zdn) + Az = obj.get_var(A+z + ydn) + Bz = obj.get_var(B+z + ydn) + AxB__x = Ay * Bz - By * Az # x component of A x B. (x='x', 'y', or 'z') + return AxB__x + + elif cross == '_edgecross_': + # interpolation notes, for x='x', y='y', z='z': + # resultx will be at (-0.5, 0, 0) + # Ay, By are at (-0.5, 0 , -0.5). we must shift by zup to align with result. + # Az, Bz are at (-0.5, -0.5, 0 ). we must shift by yup to align with result. + yup, zup = y+'up', z+'up' + Ay = obj.get_var(A+y + zup) + By = obj.get_var(B+y + zup) + Az = obj.get_var(A+z + yup) + Bz = obj.get_var(B+z + yup) + AxB__x = Ay * Bz - By * Az # x component of A x B. (x='x', 'y', or 'z') + return AxB__x + + elif cross == '_edgefacecross_': + # interpolation notes, for x='x', y='y', z='z': + # resultx will be at (-0.5, 0, 0) + # Ay is at (-0.5, 0 , -0.5). we must shift by zup to align with result. + # Az is at (-0.5, -0.5, 0 ). we must shift by yup to align with result. + # By is at ( 0 , -0.5, 0 ). we must shift by xdn yup to align with result. + # Bz is at ( 0 , 0 , -0.5). we must shift by xdn zup to align with result. + xdn, yup, zup = x+'dn', y+'up', z+'up' + Ay = obj.get_var(A+y + zup) + Az = obj.get_var(A+z + yup) + By = obj.get_var(B+y + xdn+yup) + Bz = obj.get_var(B+z + xdn+zup) + AxB__x = Ay * Bz - By * Az # x component of A x B. (x='x', 'y', or 'z') + return AxB__x + + elif cross == '_facecrosstocenter_': + # interpolation notes, for x='x', y='y', z='z': + # resultx will be at (0, 0, 0) + # Ay, By are at (0, -0.5, 0 ). we must shift by yup to align with result. + # Az, Bz are at (0, 0 , -0.5). we must shift by zup to align with result. + yup, zup = y+'up', z+'up' + Ay = obj.get_var(A+y + yup) + By = obj.get_var(B+y + yup) + Az = obj.get_var(A+z + zup) + Bz = obj.get_var(B+z + zup) + AxB__x = Ay * Bz - By * Az # x component of A x B. (x='x', 'y', or 'z') + return AxB__x + + elif cross == '_facecrosstoface_': + # resultx will be at (-0.5, 0, 0). + # '_facecrosstocenter_' gives result at (0, 0, 0) so we shift by xdn to align. + return obj.get_var(A+'_facecrosstocenter_'+B+x + x+'dn') + + else: + # if we reach this line, quant is a vector_product but we did not handle it. + raise NotImplementedError(f'{repr(cross)} in get_vector_product') + + +# default +_DOT_PRODUCT_QUANT = \ + ('DOT_PRODUCT_QUANT', + ['_dot_', '_facedot_', '_edgedot_'] + ) +# get value + + +def get_dot_product(obj, quant): + '''dot product between two vectors. + call via + Example, u dot ue, you should call get_var('u_facedot_ue'). + Result will always be centered on the meshgrid. + ''' + if quant == '': + docvar = document_vars.vars_documenter(obj, *_DOT_PRODUCT_QUANT, get_dot_product.__doc__, + uni=UNI.quant_child(0) * UNI.quant_child(1)) + docvar('_dot_', '"smart" dot product between two vectors. centers all values before dotting.') + docvar('_facedot_', 'dot product between two face-centered vectors, such as B, u.') + docvar('_edgedot_', 'dot product between two edge-centered vectors, such as E, I.') + return None + + # interpret quant string + for TIMES in _DOT_PRODUCT_QUANT[1]: + A, dot, B = quant.partition(TIMES) + if B != '': + break + else: # if we did not break, we did not match any TIMES to quant, so we return None. + return None + + # tell obj the quant we are getting by this function. + document_vars.setattr_quant_selected(obj, dot, _DOT_PRODUCT_QUANT[0], delay=True) + + # at this point, we know quant looked like + + if dot == '_dot_': + return obj(A+'xc') * obj(B+'xc') + \ + obj(A+'yc') * obj(B+'yc') + \ + obj(A+'zc') * obj(B+'zc') + + elif dot == '_facedot_': + return obj(A+'xxup') * obj(B+'xxup') + \ + obj(A+'yyup') * obj(B+'yyup') + \ + obj(A+'zzup') * obj(B+'zzup') + + elif dot == '_edgedot_': + return obj(A+'xyupzup') * obj(B+'xyupzup') + \ + obj(A+'yzupxup') * obj(B+'yzupxup') + \ + obj(A+'zxupyup') * obj(B+'zxupyup') + + else: + # if we reach this line, quant is a dot_product quant but we did not handle it. + raise NotImplementedError(f'{repr(dot)} in get_dot_product') + + +# default +_HATS = ['_hat'+x for x in AXES] +_ANGLES_XXY = ['_angle' + xxy for xxy in ['xxy', 'yyz', 'zzx']] +_ANGLE_QUANT = _HATS + _ANGLES_XXY +_ANGLE_QUANT = ('ANGLE_QUANT', _ANGLE_QUANT) +# get value + + +def get_angle(obj, quant): + '''angles. includes unit vector, and angle off an axis in a plane (xy, yz, or zx). + + Presently not very efficient, due to only being able to return one unit vector component at a time. + + call via . + Example: b_angleyyz --> angle off of the positive y axis in the yz plane, for b (magnetic field). + + TODO: interpolation + ''' + if quant == '': + docvar = document_vars.vars_documenter(obj, *_ANGLE_QUANT, get_angle.__doc__) + for x in AXES: + docvar('_hat'+x, x+'-component of unit vector. Example: b_hat'+x+' is '+x+'-component of unit vector for b.', + uni=DIMENSIONLESS) + for _angle_xxy in _ANGLES_XXY: + x, y = _angle_xxy[-2], _angle_xxy[-1] # _angle_xxy[-3] == _angle_xxy[-1] + docvar(_angle_xxy, 'angle off of the positive '+x+'-axis in the '+x+y+'plane. Result in range [-pi, pi].', + uni_f=UNITS_FACTOR_1, uni_name=Usym('radians')) + return None + + # interpret quant string + var, _, command = quant.rpartition('_') + command = '_' + command + + if command not in _ANGLE_QUANT[1]: + return None + + # tell obj the quant we are getting by this function. + document_vars.setattr_quant_selected(obj, command, _ANGLE_QUANT[0], delay=True) + + # do calculations and return result + if command in _HATS: + x = command[-1] # axis; 'x', 'y', or 'z' + varhatx = obj.get_var(var+x) / obj.get_var('mod'+var) + return varhatx + + elif command in _ANGLES_XXY: + x, y = command[-2], command[-1] # _angle_xxy[-3] == _angle_xxy[-1] + varx = obj.get_var(var + x) + vary = obj.get_var(var + y) + return np.arctan2(vary, varx) + + else: + # if we reach this line, quant is an angle quant but we did not handle it. + raise NotImplementedError(f'{repr(command)} in get_angle') + + +# default +_STAT_QUANT = ('STAT_QUANT', ['mean_', 'variance_', 'std_', 'max_', 'min_', 'abs_']) +# get value + + +def get_stat_quant(obj, quant): + '''statistics such as mean, std. + + The result will be a single value (not a 3D array). + ''' + if quant == '': + docvar = document_vars.vars_documenter(obj, *_STAT_QUANT, get_stat_quant.__doc__) + docvar('mean_', 'mean_v --> np.mean(v)', uni=UNI.qc(0)) + docvar('variance_', 'variance_v --> np.var(v).', uni=UNI.qc(0)**2) + docvar('std_', 'std_v --> np.std(v)', uni=UNI.qc(0)) + docvar('max_', 'max_v --> np.max(v)', uni=UNI.qc(0)) + docvar('min_', 'min_v --> np.min(v)', uni=UNI.qc(0)) + docvar('abs_', 'abs_v --> np.abs(v)', uni=UNI.qc(0)) + return None + + # interpret quant string + command, _, var = quant.partition('_') + command = command + '_' + + if command not in _STAT_QUANT[1]: + return None + + # tell obj the quant we are getting by this function. + document_vars.setattr_quant_selected(obj, command, _STAT_QUANT[0], delay=True) + + # do calculations and return result + val = obj.get_var(var) + if command == 'mean_': + return np.mean(val) + elif command == 'variance_': + return np.var(val) + elif command == 'std_': + return np.std(val) + elif command == 'max_': + return np.max(val) + elif command == 'min_': + return np.min(val) + elif command == 'abs_': + return np.abs(val) + else: + raise NotImplementedError(f'command={repr(command)} in get_stat_quant') + + +# default +_FFT_QUANT = ('FFT_QUANT', ['fft2_', 'fftxy_', 'fftyz_', 'fftxz_']) +# get value + + +def get_fft_quant(obj, quant): + '''Fourier transform, using np.fft.fft2, and shifting using np.fft.fftshift. + + result will be complex-valued. (consider get_var('abs_fft2_quant') to convert to magnitude.) + + See obj.kx, ky, kz for the corresponding coordinates in k-space. + See obj.get_kextent for the extent to use if plotting k-space via imshow. + + Also sets obj._latest_fft_axes = ('x', 'y'), ('x', 'z') or ('y', 'z') as appropriate. + + Note that for plotting with imshow, you will likely want to transpose and use origin='lower'. + Example, making a correctly labeled and aligned plot of FFT(r[:, 0, :]): + dd = BifrostData(...) + val = dd('abs_fftxz_r')[:, 0, :] # == |FFT(dd('r')[:, 0, :])| + extent = dd.get_kextent('xz', units='si') + plt.imshow(val.T, origin='lower', extent=extent) + plt.xlabel('kx [1/m]'); plt.ylabel('kz [1/m]') + plt.xlim([0, None]) # <-- not necessary, however numpy's FFT of real-valued input + # will be symmetric under rotation by 180 degrees, so half the spectrum is redundant. + ''' + if quant == '': + docvar = document_vars.vars_documenter(obj, *_FFT_QUANT, get_fft_quant.__doc__, uni=UNI.qc(0)) + shifted = ' result will be shifted so that the zero-frequency component is in the middle (via np.fft.fftshift).' + docvar('fft2_', '2D fft. requires 2D data (i.e. x, y, or z with length 1). result will be 2D.' + shifted) + docvar('fftxy_', '2D fft in (x, y) plane, at each z. result will be 3D.' + shifted) + docvar('fftyz_', '2D fft in (y, z) plane, at each x. result will be 3D.' + shifted) + docvar('fftxz_', '2D fft in (x, z) plane, at each y. result will be 3D.' + shifted) + return None + + # interpret quant string + command, _, var = quant.partition('_') + command = command + '_' + + if command not in _FFT_QUANT[1]: + return None + + # tell obj the quant we are getting by this function. + document_vars.setattr_quant_selected(obj, command, _FFT_QUANT[0], delay=True) + + # do calculations and return result + val = obj(var) + if command == 'fft2_': + if np.shape(val) != obj.shape: + raise NotImplementedError(f'fft2_ for {repr(var)} with shape {np.shape(val)} not equal to obj.shape {obj.shape}') + if obj.xLength == 1: + return obj(f'fftyz_{var}')[0, :, :] + elif obj.yLength == 1: + return obj(f'fftxz_{var}')[:, 0, :] + elif obj.zLength == 1: + return obj(f'fftxy_{var}')[:, :, 0] + else: + errmsg = f'fft2_ requires x, y, or z to have length 1, but obj.shape = {obj.shape}.' +\ + 'maybe you meant to specify axes, via fftxy_, fftyz, or fftxz_?' + raise ValueError(errmsg) + elif command in ('fftxy_', 'fftyz_', 'fftxz_'): + x, y = command[3:5] + obj._latest_fft_axes = (x, y) # <-- bookkeeping + AX_STR_TO_I = {'x': 0, 'y': 1, 'z': 2} + xi = AX_STR_TO_I[x] + yi = AX_STR_TO_I[y] + fft_unshifted = np.fft.fft2(val, axes=(xi, yi)) + return np.fft.fftshift(fft_unshifted) + else: + raise NotImplementedError(f'command={repr(command)} in get_fft_quant') + + +# default +_MULTI_QUANT = ('MULTI_QUANT', + [fullcommand + for command in ('vec', 'vecxyz', 'vecxy', 'vecyz', 'vecxz') + for fullcommand in ('_'+command, command+'_')] + ) +# get value + + +def get_multi_quant(obj, quant): + '''multiple quantities. (last axis will be the multi.) + E.g. 'b_vec' --> result.shape=(Nx, Ny, Nz, 3), result[...,0]=bx, result[...,1]=by, result[...,2]=bz. + ''' + if quant == '': + docvar = document_vars.vars_documenter(obj, *_MULTI_QUANT, get_multi_quant.__doc__, uni=UNI.qc(0)) + for fmt in '{var}_{command}', '{command}_{var}': + for command in ('vec', 'vecxyz'): # 'vec' and 'vecxyz' are aliases for each other. + docvar(fmt.format(var='', command=command), + "'" + fmt.format(var='var', command=command) + "'" + + " --> (varx, vary, varz) stacked along last axis (shape == (Nx, Ny, Nz, 3).") + for (x, y) in ('xy', 'yz', 'xz'): + docvar(fmt.format(var='', command=f'vec{x}{y}'), + "'" + fmt.format(var='var', command=f'vec{x}{y}') + "'" + + " --> (var{x}, var{y}) stacked along last axis (shape == (Nx, Ny, Nz, 2).") + return None + + # interpret quant string + var, _, command = quant.rpartition('_') + fullcommand = '_' + command + + if fullcommand not in _MULTI_QUANT[1]: # quant doesn't look like 'var_command' + command, _, var = quant.partition('_') + fullcommand = command + '_' + if fullcommand not in _MULTI_QUANT[1]: + return None + # now we have assigned: + # command = command without underscore. e.g. 'vec' + # fullcommand = command with underscore. e.g. '_vec' or 'vec_' + # var = quant without command. + + # tell obj the quant we are getting by this function. + document_vars.setattr_quant_selected(obj, fullcommand, _MULTI_QUANT[0], delay=True) + + # do calculations and return result + if command.startswith('vec'): + if command in ('vec', 'vecxyz'): + axes = 'xyz' + else: # command is 'vecxy', 'vecyz', or 'vecxz' + axes = command[-2:] + components = [obj(var+x) for x in axes] + return np.stack(components, axis=-1) + + else: + raise NotImplementedError(f'command={repr(fullcommand)} in get_multi_quant') + + +''' ------------- End get_quant() functions; Begin helper functions ------------- ''' + + +def threadQuantity(task, numThreads, *args): + # split arg arrays + args = list(args) + + for index in range(np.shape(args)[0]): + args[index] = np.array_split(args[index], numThreads) + + # make threadpool, task = task, with zipped args + pool = ThreadPool(processes=numThreads) + result = np.concatenate(pool.starmap(task, zip(*args))) + return result + + +def threadQuantity_y(task, numThreads, *args): + # split arg arrays + args = list(args) + + for index in range(np.shape(args)[0]): + if len(np.shape(args[index])) == 3: + args[index] = np.array_split(args[index], numThreads, axis=1) + else: + args[index] = np.array_split(args[index], numThreads) + # make threadpool, task = task, with zipped args + pool = ThreadPool(processes=numThreads) + result = np.concatenate(pool.starmap(task, zip(*args)), axis=1) + return result + + +def threadQuantity_z(task, numThreads, *args): + # split arg arrays + args = list(args) + + for index in range(np.shape(args)[0]): + if len(np.shape(args[index])) == 3: + args[index] = np.array_split(args[index], numThreads, axis=2) + else: + args[index] = np.array_split(args[index], numThreads) + + # make threadpool, task = task, with zipped args + pool = ThreadPool(processes=numThreads) + result = np.concatenate(pool.starmap(task, zip(*args)), axis=2) + return result diff --git a/helita/sim/load_fromfile_quantities.py b/helita/sim/load_fromfile_quantities.py new file mode 100644 index 00000000..47edf0c1 --- /dev/null +++ b/helita/sim/load_fromfile_quantities.py @@ -0,0 +1,114 @@ +# import builtins + +# import internal modules +# import external public modules +import numpy as np + +from . import document_vars +# import the relevant things from the internal module "units" +from .units import DIMENSIONLESS, UNI, UNI_speed, Usym + + +def load_fromfile_quantities(obj, quant, order='F', mode='r', panic=False, save_if_composite=False, cgsunits=None, **kwargs): + '''loads quantities which are stored directly inside files. + + save_if_composite: False (default) or True. + use True for bifrost; False for ebysus. + See _get_composite_var() for more details. + + cgsunits: None or value + None --> ignore + value --> multiply val by this value if val was a simple var. + ''' + __tracebackhide__ = True # hide this func from error traceback stack. + + quant = quant.lower() + + document_vars.set_meta_quant(obj, 'fromfile', + ('These are the quantities which are stored directly inside the snapshot files.\n' + 'Their values are "calculated" just by reading the appropriate part of the appropriate file.\n' + '(Except for composite_var, which is included here only because it used to be in bifrost.py.)') + ) + + val = obj._get_simple_var(quant, order=order, mode=mode, panic=panic, **kwargs) # method of obj. + if ((cgsunits is not None) and (val is not None)): + val = val*cgsunits + if val is None: + val = _get_simple_var_xy(obj, quant, order=order, mode=mode) # method defined in this file. + if val is None: + val = _get_composite_var(obj, quant, save_if_composite=save_if_composite, **kwargs) # method defined in this file. + return val + + +@document_vars.quant_tracking_simple('SIMPLE_XY_VAR') +def _get_simple_var_xy(obj, var, order='F', mode='r'): + ''' Reads a given 2D variable from the _XY.aux file ''' + if var == '': + document_vars.vars_documenter(obj, 'SIMPLE_XY_VAR', getattr(obj, 'auxxyvars', []), _get_composite_var.__doc__) + # TODO << fill in the documentation for simple_xy_var quantities here. + return None + + if var not in obj.auxxyvars: + return None + + # determine the file + fsuffix = '_XY.aux' + idx = obj.auxxyvars.index(var) + filename = obj.file_root + fsuffix + + # memmap the variable + if not os.path.isfile(filename): + raise FileNotFoundError('_get_simple_var_xy: variable {} should be in {} file, not found!'.format(var, filename)) + dsize = np.dtype(obj.dtype).itemsize # size of the data type + offset = obj.nx * obj.ny * idx * dsize # offset in the file + return np.memmap(filename, dtype=obj.dtype, order=order, mode=mode, + offset=offset, shape=(obj.nx, obj.ny)) + + +# default +_COMPOSITE_QUANT = ('COMPOSITE_QUANT', ['ux', 'uy', 'uz', 'ee', 's']) +# get value + + +@document_vars.quant_tracking_simple(_COMPOSITE_QUANT[0]) +def _get_composite_var(obj, var, *args, save_if_composite=False, **kwargs): + ''' gets velocities, internal energy ('e' / 'r'), entropy. + + save_if_composite: False (default) or True. + if True, also set obj.variables[var] = result. + (Provided for backwards compatibility with bifrost, which + used to call _get_composite_var then do obj.variables[var] = result.) + (True is NOT safe for ebysus, unless proper care is taken to save _metadata to obj.variables.) + + *args and **kwargs go to get_var. + ''' + if var == '': + docvar = document_vars.vars_documenter(obj, *_COMPOSITE_QUANT, _get_composite_var.__doc__, nfluid=1) + for ux in ['ux', 'uy', 'uz']: + docvar(ux, '{x:}-component of velocity [simu. velocity units]'.format(x=ux[-1]), uni=UNI_speed) + docvar('ee', "internal energy. get_var('e')/get_var('r').", uni_f=UNI.e/UNI.r, usi_name=Usym('J')) + docvar('s', 'entropy (??)', uni=DIMENSIONLESS) + return None + + if var not in _COMPOSITE_QUANT[1]: + return None + + if var in ['ux', 'uy', 'uz']: # velocities + # u = p / r. + # r is on center of grid cell, but p is on face, + # so we need to interpolate. + # r is at (0,0,0), ux and px are at (-0.5, 0, 0) + # --> to align r with px, shift by xdn + x = var[-1] # axis; 'x', 'y', or 'z' + interp = x+'dn' + p = obj.get_var('p' + x) + r = obj.get_var('r' + interp) + return p / r + + elif var == 'ee': # internal energy + return obj.get_var('e') / obj.get_var('r') + + elif var == 's': # entropy? + return np.log(obj.get_var('p', *args, **kwargs)) - \ + obj.params['gamma'][obj.snapInd] * np.log( + obj.get_var('r', *args, **kwargs)) diff --git a/helita/sim/load_mf_quantities.py b/helita/sim/load_mf_quantities.py new file mode 100644 index 00000000..31d5257f --- /dev/null +++ b/helita/sim/load_mf_quantities.py @@ -0,0 +1,2528 @@ +# import builtins +import warnings + +# import external public modules +import numpy as np + +# import internal modules +from . import document_vars +from .file_memory import Caching # never alters results, but caches them for better efficiency. +# use sparingly on "short" calculations; apply liberally to "long" calculations. +# see also cache_with_nfluid and cache kwargs of get_var. +from .load_arithmetic_quantities import do_stagger +# import the relevant things from the internal module "units" +from .units import ( + DIMENSIONLESS, + U_TUPLE, + UCONST, + UNI, + UNITS_FACTOR_1, + UNI_hz, + UNI_length, + UNI_nr, + UNI_rho, + UNI_speed, + UNI_time, + Usym, + UsymD, +) + +# set constants +MATCH_PHYSICS = 0 # don't change this value. # this one is the default (see ebysus.py) +MATCH_AUX = 1 # don't change this value. +AXES = ('x', 'y', 'z') # the axes names. +YZ_FROM_X = dict(x=('y', 'z'), y=('z', 'x'), z=('x', 'y')) # right-handed coord system x,y,z given x. + +# TODO: +# adapt maxwell collisions from load_quantities.py file, to improve maxwell collisions in this file. + +# construct some frequently-used units +units_e = dict(uni_f=UNI.e, usi_name=Usym('J') / Usym('m')**3) # ucgs_name= ??? + + +def load_mf_quantities(obj, quant, *args__None, GLOBAL_QUANT=None, EFIELD_QUANT=None, + ONEFLUID_QUANT=None, ELECTRON_QUANT=None, + CONTINUITY_QUANT=None, MOMENTUM_QUANT=None, HEATING_QUANT=None, + SPITZERTERM_QUANT=None, + COLFRE_QUANT=None, LOGCUL_QUANT=None, CROSTAB_QUANT=None, + DRIFT_QUANT=None, MEAN_QUANT=None, CFL_QUANT=None, PLASMA_QUANT=None, + HYPERDIFFUSIVE_QUANT=None, + WAVE_QUANT=None, FB_INSTAB_QUANT=None, THERMAL_INSTAB_QUANT=None, + **kw__None): + '''load multifluid quantity indicated by quant. + *args__None and **kw__None go nowhere. + ''' + __tracebackhide__ = True # hide this func from error traceback stack. + + quant = quant.lower() + + document_vars.set_meta_quant(obj, 'mf_quantities', + ("These are the multi-fluid quantities; only used by ebysus.\n" + "nfluid means 'number of fluids used to read the quantity'.\n" + " 2 -> uses obj.ifluid and obj.jfluid. (e.g. 'nu_ij')\n" + " 1 -> uses obj.ifluid (but not jfluid). (e.g. 'ux', 'tg')\n" + " 0 -> does not use ifluid nor jfluid. (e.g. 'bx', 'nel', 'tot_e'))\n") + ) + + # tell which getter function is associated with each QUANT. + # (would put this list outside this function if the getter functions were defined there, but they are not.) + _getter_QUANT_pairs = ( + (get_global_var, 'GLOBAL_QUANT'), + (get_efield_var, 'EFIELD_QUANT'), + (get_onefluid_var, 'ONEFLUID_QUANT'), + (get_electron_var, 'ELECTRON_QUANT'), + (get_continuity_quant, 'CONTINUITY_QUANT'), + (get_momentum_quant, 'MOMENTUM_QUANT'), + (get_heating_quant, 'HEATING_QUANT'), + (get_spitzerterm, 'SPITZERTERM_QUANT'), + (get_mf_colf, 'COLFRE_QUANT'), + (get_mf_logcul, 'LOGCUL_QUANT'), + (get_mf_cross, 'CROSTAB_QUANT'), + (get_mf_driftvar, 'DRIFT_QUANT'), + (get_mean_quant, 'MEAN_QUANT'), + (get_cfl_quant, 'CFL_QUANT'), + (get_mf_plasmaparam, 'PLASMA_QUANT'), + (get_hyperdiffusive_quant, 'HYPERDIFFUSIVE_QUANT'), + (get_mf_wavequant, 'WAVE_QUANT'), + (get_fb_instab_quant, 'FB_INSTAB_QUANT'), + (get_thermal_instab_quant, 'THERMAL_INSTAB_QUANT'), + ) + + val = None + # loop through the function and QUANT pairs, running the functions as appropriate. + for getter, QUANT_STR in _getter_QUANT_pairs: + QUANT = locals()[QUANT_STR] # QUANT = value of input parameter named QUANT_STR. + # if QUANT != '': + val = getter(obj, quant, **{QUANT_STR: QUANT}) + if val is not None: + break + return val + + +# default +_GLOBAL_QUANT = ('GLOBAL_QUANT', + ['totr', 'rc', 'rions', 'rneu', + 'tot_e', 'tot_ke', 'e_ef', 'e_b', 'total_energy', + 'tot_px', 'tot_py', 'tot_pz', + 'grph', 'tot_part', 'mu', + 'jx', 'jy', 'jz', 'resistivity' + ] + ) +# get value + + +@document_vars.quant_tracking_simple(_GLOBAL_QUANT[0]) +def get_global_var(obj, var, GLOBAL_QUANT=None): + '''Variables which are calculated by looping through species or levels.''' + if GLOBAL_QUANT is None: + GLOBAL_QUANT = _GLOBAL_QUANT[1] + + if var == '': + docvar = document_vars.vars_documenter(obj, _GLOBAL_QUANT[0], GLOBAL_QUANT, get_global_var.__doc__, nfluid=0) + docvar('totr', 'sum of mass densities of all fluids [simu. mass density units]', uni=UNI_rho) + for rc in ['rc', 'rions']: + docvar(rc, 'sum of mass densities of all ionized fluids [simu. mass density units]', uni=UNI_rho) + docvar('rneu', 'sum of mass densities of all neutral species [simu. mass density units]', uni=UNI_rho) + docvar('tot_e', 'sum of internal energy densities of all fluids [simu. energy density units]', **units_e) + docvar('tot_ke', 'sum of kinetic energy densities of all fluids [simu. energy density units]', **units_e) + docvar('e_ef', 'energy density in electric field [simu. energy density units]', **units_e) + docvar('e_b', 'energy density in magnetic field [simu. energy density units]', **units_e) + docvar('total_energy', 'total energy density. tot_e + tot_ke + e_ef + e_b [simu units].', **units_e) + docvar('resistivity', 'total resistivity of the plasma. sum of partial resistivity.' + + '[(simu. E-field units)/(simu. current per area units)]', + uni_f=UNI.ef / UNI.i, usi_name=(Usym('V'))/(Usym('A')*Usym('m'))) + for axis in AXES: + docvar('tot_p'+axis, 'sum of '+axis+'-momentum densities of all fluids [simu. mom. dens. units] ' + + 'NOTE: does not include "electron momentum" which is assumed to be ~= 0.', + uni=UNI_speed * UNI_rho) + docvar('grph', 'grams per hydrogen atom') + docvar('tot_part', 'total number of particles, including free electrons [cm^-3]') + docvar('mu', 'ratio of total number of particles without free electrong / tot_part') + for axis in AXES: + docvar('j'+axis, 'sum of '+axis+'-component of current per unit area [simu. current per area units]', + uni_f=UNI.i, usi_name=Usym('A')/Usym('m')**2) # ucgs_name= ??? + return None + + if var not in GLOBAL_QUANT: + return None + + output = obj.zero_at_mesh_center() + if var == 'totr': # total density + for ispecies in obj.att: + nlevels = obj.att[ispecies].params.nlevel + for ilevel in range(1, nlevels+1): + output += obj.get_var('r', mf_ispecies=ispecies, mf_ilevel=ilevel) + + return output + elif var in ['rc', 'rions']: # total ionized density + for fluid in obj.fluids.ions(): + output += obj.get_var('r', ifluid=fluid) + return output + elif var == 'rneu': # total neutral density + for ispecies in obj.att: + nlevels = obj.att[ispecies].params.nlevel + for ilevel in range(1, nlevels+1): + if (obj.att[ispecies].params.levels['stage'][ilevel-1] == 1): + output += obj.get_var('r', mf_ispecies=ispecies, mf_ilevel=ilevel) + return output + elif var == 'tot_e': + output += obj.get_var('e', mf_ispecies=-1) # internal energy density of electrons + for fluid in obj.fluids: + output += obj.get_var('e', ifluid=fluid.SL) # internal energy density of fluid + return output + elif var == 'tot_ke': + output = obj.get_var('eke') # kinetic energy density of electrons + for fluid in obj.fluids: + output += obj.get_var('ke', ifluid=fluid.SL) # kinetic energy density of fluid + return output + elif var == 'e_ef': + ef2 = obj.get_var('ef2') # |E|^2 [simu E-field units, squared] + eps0 = obj.uni.permsi # epsilon_0 [SI units] + units = obj.uni.usi_ef**2 / obj.uni.usi_e # convert ef2 * eps0 to [simu energy density units] + return (0.5 * eps0 * units) * ef2 + elif var == 'resistivity': + + ne = obj.get_var('nr', mf_ispecies=-1) # [simu. number density units] + neqe = ne * obj.uni.simu_qsi_e + rhoe = obj.get_var('re') + nu_sum = 0.0 + for fluid in obj.fluids: + nu_sum += obj.get_var('nu_ij', mf_ispecies=-1, jfluid=fluid) + output = nu_sum * rhoe / (neqe)**2 + return output + + elif var == 'e_b': + b2 = obj.get_var('b2') # |B|^2 [simu B-field units, squared] + mu0 = obj.uni.mu0si # mu_0 [SI units] + units = obj.uni.usi_b**2 / obj.uni.usi_e # convert b2 * mu0 to [simu energy density units] + return (0.5 * mu0 * units) * b2 + + elif var == 'total_energy': + with Caching(obj, nfluid=0) as cache: + output = obj.get_var('tot_e') + output += obj.get_var('tot_ke') + output += obj.get_var('e_ef') + output += obj.get_var('e_b') + cache(var, output) + return output + + elif var.startswith('tot_p'): # note: must be tot_px, tot_py, or tot_pz. + axis = var[-1] + for fluid in obj.fluids: + output += obj.get_var('p'+axis, ifluid=fluid.SL) # momentum density of fluid + + return output + elif var == 'grph': + for ispecies in obj.att: + nlevels = obj.att[ispecies].params.nlevel + weight = obj.att[ispecies].params.atomic_weight * \ + obj.uni.amu / obj.uni.u_r + + for ilevel in range(1, nlevels+1): + total_hpart += obj.get_var('r', mf_ispecies=ispecies, + mf_ilevel=ilevel) / weight + + for ispecies in obj.att: + nlevels = obj.att[ispecies].params.nlevel + weight = obj.att[ispecies].params.atomic_weight * \ + obj.uni.amu / obj.uni.u_r + + for ilevel in range(1, nlevels+1): + output += obj.get_var('r', mf_ispecies=ispecies, + mf_ilevel=ilevel) / mf_total_hpart * u_r + return output + elif var == 'tot_part': + for ispecies in obj.att: + nlevels = obj.att[ispecies].params.nlevel + weight = obj.att[ispecies].params.atomic_weight * \ + obj.uni.amu / obj.uni.u_r + for ilevel in range(1, nlevels+1): + output += obj.get_var('r', mf_ispecies=ispecies, + mf_ilevel=ilevel) / weight * (obj.att[ispecies].params.levels[ilevel-1]+1) + return output + elif var == 'mu': + for ispecies in obj.att: + nlevels = obj.att[ispecies].params.nlevel + weight = obj.att[ispecies].params.atomic_weight * \ + obj.uni.amu / obj.uni.u_r + for mf_ilevel in range(1, nlevels+1): + output += obj.get_var('r', mf_ispecies=ispecies, + mf_ilevel=mf_ilevel) / weight + output = output / obj.get_var('tot_part') + return output + elif var in ['jx', 'jy', 'jz']: + # J = curl (B) / mu_0 + x = var[-1] + # imposed current (imposed "additional" current, added artificially to system) + if obj.get_param('do_imposed_current', 0) > 0: + ic_units = obj.get_param('ic_units', 'ebysus') + ic_ix = obj.get_param('ic_i'+x, 0) # ic_ix [ic_units] + if ic_units.strip().lower() == 'si': + ic_ix /= obj.uni.usi_i # ic_ix [simu. units] + elif ic_units.strip().lower() == 'cgs': + ic_ix /= obj.uni.u_i # ic_ix [simu. units] + else: + ic_ix = 0 + # calculated current + curlb_x = obj.get_var('curvec'+'b'+x) * obj.uni.usi_b / obj.uni.usi_l # (curl b)_x [si units] + jx = curlb_x / obj.uni.mu0si # j [si units] + jx = jx / obj.uni.usi_i # j [simu. units] + return ic_ix + jx # j [simu. units] + else: + # if we reach this line, var is a global_var quant but we did not handle it. + raise NotImplementedError(f'{repr(var)} in get_global_var') + + +# default +_EFIELD_QUANT = ('EFIELD_QUANT', + ['efx', 'efy', 'efz', + 'uexbx', 'uexby', 'uexbz', + 'uepxbx', 'uepxby', 'uepxbz', + 'batx', 'baty', 'batz', + 'emomx', 'emomy', 'emomz', + 'efneqex', 'efneqey', 'efneqez'] + ) +# get value + + +@document_vars.quant_tracking_simple(_EFIELD_QUANT[0]) +def get_efield_var(obj, var, EFIELD_QUANT=None): + '''variables related to electric field.''' + if EFIELD_QUANT is None: + EFIELD_QUANT = _EFIELD_QUANT[1] + + if var == '': + docvar = document_vars.vars_documenter(obj, _EFIELD_QUANT[0], EFIELD_QUANT, get_efield_var.__doc__, nfluid=0) + EF_UNITS = dict(uni_f=UNI.ef, usi_name=Usym('V')/Usym('m')) # ucgs_name= ??? + for x in AXES: + docvar('ef'+x, x+'-component of electric field [simu. E-field units] ', **EF_UNITS) + for x in AXES: + docvar('uexb'+x, x+'-component of u_e cross B [simu. E-field units]. Note efx = - uexbx + ...', **EF_UNITS) + for x in AXES: + docvar('uepxb'+x, x+'-component of uep cross B [simu. E-field units]. Note efx = - uexbx + ... . ' + + ' uep is the electron velocity assuming current = 0.', **EF_UNITS) + for x in AXES: + docvar('bat'+x, x+'-component of "battery term" (contribution to electric field) [simu. E-field units]. ' + + '== grad(P_e) / (n_e q_e), where q_e < 0. ', **EF_UNITS) + for x in AXES: + docvar('emom'+x, x+'-component of collisions contribution to electric field [simu. E-field units]. ' + + '== sum_j R_e^(ej) / (n_e q_e)', **EF_UNITS) + for x in AXES: + docvar('efneqe'+x, 'value of n_e * q_e, interpolated to align with the {}-component of E '.format(x) + + '[simu. charge density units]. Note q_e < 0, so efneqe{} < 0.'.format(x), + uni_f=UNI.nq, usi_name=Usym('C')/Usym('m')**3) + return None + + if var not in EFIELD_QUANT: + return None + + x = var[-1] # axis; 'x', 'y', or 'z' + y, z = YZ_FROM_X[x] + base = var[:-1] # var without axis. E.g. 'ef', 'uexb', 'emom'. + + if base == 'ef': # electric field # efx + with Caching(obj, nfluid=0) as cache: + # E = - ue x B + (ne qe)^-1 * ( grad(pressure_e) - (ion & rec terms) - sum_j(R_e^(ej)) ) + # (where the convention used is qe < 0.) + # ----- -ue x B contribution ----- # + # There is a flag, "do_hall", when "false", we don't let the contribution + # from current to ue to enter in to the B x ue for electric field. + if obj.match_aux() and obj.get_param('do_hall', default="false") == "false": + ue = 'uep' # include only the momentum contribution in ue, in our ef calculation. + if obj.verbose: + warnings.warn('do_hall=="false", so we are dropping the j (current) contribution to ef (E-field)') + else: + ue = 'ue' # include the full ue term, in our ef calculation. + B_cross_ue__x = -1 * obj.get_var(ue+'xb'+x) + + # ----- grad Pe contribution ----- # + battery_x = obj.get_var('bat'+x) + # ----- calculate ionization & recombination effects ----- # + if obj.get_param('do_recion', default=False): + if obj.verbose: + warnings.warn('E-field contribution from ionization & recombination have not yet been added.') + # ----- calculate collisional effects ----- # + emom_x = obj.get_var('emom'+x) + # ----- calculate efx ----- # + result = B_cross_ue__x + battery_x + emom_x # [simu. E-field units] + cache(var, result) + + elif base in ('uexb', 'uepxb'): # ue x B # (aligned with efx) + ue = 'ue' if (base == 'uexb') else 'uep' + # interpolation: + # B and ue are face-centered vectors. + # Thus we use _facecross_ from load_arithmetic_quantities. + result = obj.get_var(ue+'_facecross_b'+x) + + elif base == 'bat': # grad(P_e) / (ne qe) + if obj.match_aux() and (not obj.get_param('do_battery', default=False)): + return obj.zero_at_mesh_edge(x) + # interpolation: + # efx is at (0, -1/2, -1/2). + # P is at (0,0,0). + # dpdxup is at (1/2, 0, 0). + # dpdxup xdn ydn zdn is at (0, -1/2, -1/2) --> aligned with efx. + interp = 'xdnydnzdn' + gradPe_x = obj.get_var('dpd'+x+'up'+interp, iS=-1) # [simu. energy density units] + neqe = obj.get_var('efneqe'+x) # ne qe, aligned with efx + result = gradPe_x / neqe + + elif base == 'emom': # -1 * sum_j R_e^(ej) / (ne qe) (aligned with efx) + if obj.match_aux() and (not obj.get_param('do_ohm_ecol', default=False)): + return obj.zero_at_mesh_edge(x) + # interpolation: + ## efx is at (0, -1/2, -1/2) + # rijx is at (-1/2, 0, 0) (same as ux) + # --> to align with efx, we shift rijx by xup ydn zdn + interp = x+'up'+y+'dn'+z+'dn' + sum_rejx = obj.get_var('rijsum'+x + interp, iS=-1) # [simu. momentum density units / simu. time units] + neqe = obj.get_var('efneqe'+x) # ne qe, aligned with efx + result = -1 * sum_rejx / neqe + + elif base == 'efneqe': # ne qe (aligned with efx) + # interpolation: + ## efx is at (0, -1/2, -1/2) + ## ne is at (0, 0, 0) + # to align with efx, we shift ne by ydn zdn + interp = y+'dn'+z+'dn' + result = obj.get_var('nq'+interp, iS=-1) # [simu. charge density units] (Note: 'nq' < 0 for electrons) + + else: + raise NotImplementedError(f'{repr(base)} in get_efield_var') + + return result + + +# default +_ONEFLUID_QUANT = ('ONEFLUID_QUANT', + ['nr', 'nq', 'p', 'pressure', 'tg', 'temperature', 'tgjoule', 'ke', 'vtherm', 'vtherm_simple', + 'ri', 'uix', 'uiy', 'uiz', 'pix', 'piy', 'piz'] + ) +# get value + + +@document_vars.quant_tracking_simple(_ONEFLUID_QUANT[0]) +def get_onefluid_var(obj, var, ONEFLUID_QUANT=None): + '''variables related to information about a single fluid. + Use mf_ispecies= -1 to refer to electrons. + Intended to contain only "simple" physical quantities. + For more complicated "plasma" quantities such as gryofrequncy, see PLASMA_QUANT. + + Quantities with 'i' are "generic" version of that quantity, + meaning it works with electrons or another fluid for ifluid. + For example, obj.get_var('uix') is equivalent to: + obj.get_var('uex') when obj.mf_ispecies < 0 + obj.get_var('ux') otherwise. + ''' + if ONEFLUID_QUANT is None: + ONEFLUID_QUANT = _ONEFLUID_QUANT[1] + + if var == '': + docvar = document_vars.vars_documenter(obj, _ONEFLUID_QUANT[0], ONEFLUID_QUANT, get_onefluid_var.__doc__, nfluid=1) + docvar('nr', 'number density of ifluid [simu. number density units]', uni=UNI_nr) + docvar('nq', 'charge density of ifluid [simu. charge density units]', + uni_f=UNI.q * UNI_nr.f, usi_name=Usym('C') / Usym('m')**3) + for tg in ['tg', 'temperature']: + docvar(tg, 'temperature of ifluid [K]', uni=U_TUPLE(UNITS_FACTOR_1, Usym('K'))) + docvar('tgjoule', 'temperature of ifluid [ebysus energy units]. == tg [K] * k_boltzmann [J/K]', uni=U_TUPLE(UNITS_FACTOR_1, Usym('J'))) + for p in ['p', 'pressure']: + docvar(p, 'pressure of ifluid [simu. energy density units]', uni_f=UNI.e) + docvar('ke', 'kinetic energy density of ifluid [simu. units]', **units_e) + _equivstr = " Equivalent to obj.get_var('{ve:}') when obj.mf_ispecies < 0; obj.get_var('{vf:}'), otherwise." + def equivstr(v): return _equivstr.format(ve=v.replace('i', 'e'), vf=v.replace('i', '')) + docvar('vtherm', 'thermal speed of ifluid [simu. velocity units]. = sqrt (8 * k_b * T_i / (pi * m_i) )', uni=UNI_speed) + docvar('vtherm_simple', '"simple" thermal speed of ifluid [simu. velocity units]. ' + + '= sqrt (k_b * T_i / m_i)', uni=UNI_speed) + docvar('ri', 'mass density of ifluid [simu. mass density units]. '+equivstr('ri'), uni=UNI_rho) + for uix in ['uix', 'uiy', 'uiz']: + docvar(uix, 'velocity of ifluid [simu. velocity units]. '+equivstr(uix), uni=UNI_speed) + for pix in ['pix', 'piy', 'piz']: + docvar(pix, 'momentum density of ifluid [simu. momentum density units]. '+equivstr(pix), uni=UNI_rho * UNI_speed) + return None + + if var not in ONEFLUID_QUANT: + return None + + if var == 'nr': + if obj.mf_ispecies < 0: # electrons + return obj.get_var('nre') + else: # not electrons + mass = obj.get_mass(obj.mf_ispecies, units='simu') # [simu. mass units] + return obj.get_var('r') / mass # [simu number density units] + + elif var == 'nq': + charge = obj.get_charge(obj.ifluid, units='simu') # [simu. charge units] + if charge == 0: + return obj.zero_at_mesh_center() + else: + return charge * obj.get_var('nr') + + elif var in ['p', 'pressure']: + gamma = obj.uni.gamma + return (gamma - 1) * obj.get_var('e') # p = (gamma - 1) * internal energy + + elif var in ['tg', 'temperature']: + p = obj.get_var('p') # [simu units] + nr = obj.get_var('nr') # [simu units] + if getattr(obj, 'debug', False): + raise Exception('boom') + return p / (nr * obj.uni.simu_kB) # [K] # p = n k T + + elif var == 'tgjoule': + return obj.uni.ksi_b * obj('tg') + + elif var == 'ke': + return 0.5 * obj.get_var('ri') * obj.get_var('ui2') + + elif var == 'vtherm': + Ti = obj.get_var('tg') # [K] + mi = obj.get_mass(obj.mf_ispecies, units='si') # [kg] + vtherm = np.sqrt(obj.uni.ksi_b * Ti / mi) # [m / s] + consts = np.sqrt(8 / np.pi) + return consts * vtherm / obj.uni.usi_u # [simu. velocity units] + + elif var == 'vtherm_simple': + Ti = obj('tg') # [K] + mi = obj.get_mass(obj.mf_ispecies, units='si') # [kg] + vtherm = np.sqrt(obj.uni.ksi_b * Ti / mi) # [m / s] + return vtherm / obj.uni.usi_u # [simu. velocity units] + + else: + if var in ['ri', 'uix', 'uiy', 'uiz', 'pix', 'piy', 'piz']: + if obj.mf_ispecies < 0: # electrons + e_var = var.replace('i', 'e') + return obj.get_var(e_var) + else: # not electrons + f_var = var.replace('i', '') + return obj.get_var(f_var) + + else: + raise NotImplementedError(f'{repr(var)} in get_onefluid_var') + + +# default +_ELECTRON_QUANT = ['nel', 'nre', 're', 'eke', 'pe'] +_ELECTRON_QUANT += [ue + x for ue in ['ue', 'pe', 'uej', 'uep'] for x in AXES] +_ELECTRON_QUANT = ('ELECTRON_QUANT', _ELECTRON_QUANT) +# get value + + +@document_vars.quant_tracking_simple(_ELECTRON_QUANT[0]) +def get_electron_var(obj, var, ELECTRON_QUANT=None): + '''variables related to electrons (requires looping over ions to calculate).''' + + if ELECTRON_QUANT is None: + ELECTRON_QUANT = _ELECTRON_QUANT[1] + + if var == '': + docvar = document_vars.vars_documenter(obj, _ELECTRON_QUANT[0], ELECTRON_QUANT, get_electron_var.__doc__, nfluid=0) + docvar('nel', 'electron number density [cm^-3]') + docvar('nre', 'electron number density [simu. number density units]', uni=UNI_nr) + docvar('re', 'mass density of electrons [simu. mass density units]', uni=UNI_rho) + docvar('eke', 'electron kinetic energy density [simu. energy density units]', **units_e) + docvar('pe', 'electron pressure [simu. pressure units]', uni_f=UNI.e) + for x in AXES: + docvar('ue'+x, '{}-component of electron velocity [simu. velocity units]'.format(x), + uni=UNI_speed) + for x in AXES: + docvar('pe'+x, '{}-component of electron momentum density [simu. momentum density units]'.format(x), + uni=UNI_speed * UNI_rho) + for x in AXES: + docvar('uej'+x, '{}-component of current contribution to electron velocity [simu. velocity units]'.format(x), + uni=UNI_speed) + for x in AXES: + docvar('uep'+x, '{}-component of species velocities contribution to electron velocity [simu. velocity units]'.format(x), + uni=UNI_speed) + return None + + if (var not in ELECTRON_QUANT): + return None + + if var == 'nel': # number density of electrons [cm^-3] + return obj.get_var('nre') * obj.uni.u_nr # [cm^-3] + + elif var == 'nre': # number density of electrons [simu. units] + with Caching(obj, nfluid=0) as cache: + ions = obj.fluids.ions() + if len(ions) == 0: + return obj.zero_at_mesh_center() + else: + return sum(obj('nr', ifluid=ion.SL) * ion.ionization for ion in ions) # [simu. number density units] + return output + + elif var == 're': # mass density of electrons [simu. mass density units] + return obj.get_var('nr', mf_ispecies=-1) * obj.uni.simu_m_e + + elif var == 'eke': # electron kinetic energy density [simu. energy density units] + return obj.get_var('ke', mf_ispecies=-1) + + elif var == 'pe': + return (obj.uni.gamma-1) * obj.get_var('e', mf_ispecies=-1) + + elif var in ['uepx', 'uepy', 'uepz']: # electron velocity (contribution from momenta) + # i = sum_j (nj uj qj) + ne qe ue + # --> ue = (i - sum_j(nj uj qj)) / (ne qe) + x = var[-1] # axis; 'x', 'y', or 'z'. + # get component due to velocities: + # r is in center of cells, while u is on faces, so we need to interpolate. + ## r is at (0, 0, 0); ux is at (-0.5, 0, 0) + # ---> to align with ux, we shift r by xdn + interp = x+'dn' + output = obj.zero_at_mesh_face(x) + nqe = obj.zero_at_mesh_face(x) # charge density of electrons. + for fluid in obj.fluids.ions(): + nq = obj.get_var('nq' + interp, ifluid=fluid.SL) # [simu. charge density units] + ux = obj.get_var('u'+x, ifluid=fluid.SL) # [simu. velocity units] + output -= nq * ux # [simu. current per area units] + nqe -= nq # [simu. charge density units] + return output / nqe # [simu velocity units] + + elif var in ['uejx', 'uejy', 'uejz']: # electron velocity (contribution from current) + # i = sum_j (nj uj qj) + ne qe ue + # --> ue = (i - sum_j(nj uj qj)) / (ne qe) + x = var[-1] # axis; 'x', 'y', or 'z'. + # get component due to current: + # i is on edges of cells, while u is on faces, so we need to interpolate. + ## ix is at (0, -0.5, -0.5); ux is at (-0.5, 0, 0) + # ---> to align with ux, we shift ix by xdn yup zup + y, z = tuple(set(AXES) - set((x))) + interpj = x+'dn' + y+'up' + z+'up' + jx = obj.get_var('j'+x + interpj) # [simu current per area units] + # r (nq) is in center of cells, while u is on faces, so we need to interpolate. + ## r is at (0, 0, 0); ux is at (-0.5, 0, 0) + # ---> to align with ux, we shift r by xdn + interpn = x+'dn' + nqe = obj.get_var('nq' + interpn, iS=-1) # [simu charge density units] + return jx / nqe # [simu velocity units] + + elif var in ['uex', 'uey', 'uez']: # electron velocity [simu. velocity units] + with Caching(obj, nfluid=0) as cache: + # i = sum_j (nj uj qj) + ne qe ue + # --> ue = (i - sum_j(nj uj qj)) / (ne qe) + x = var[-1] # axis; 'x', 'y', or 'z'. + # get component due to current: + # i is on edges of cells, while u is on faces, so we need to interpolate. + ## ix is at (0, -0.5, -0.5); ux is at (-0.5, 0, 0) + # ---> to align with ux, we shift ix by xdn yup zup + y, z = tuple(set(AXES) - set((x))) + interp = x+'dn' + y+'up' + z+'up' + output = obj.get_var('j'+x + interp) # [simu current per area units] + # get component due to velocities: + # r is in center of cells, while u is on faces, so we need to interpolate. + ## r is at (0, 0, 0); ux is at (-0.5, 0, 0) + # ---> to align with ux, we shift r by xdn + interp = x+'dn' + nqe = obj.zero_at_mesh_face(x) # charge density of electrons. + for fluid in obj.fluids.ions(): + nq = obj.get_var('nq' + interp, ifluid=fluid.SL) # [simu. charge density units] + ux = obj.get_var('u'+x, ifluid=fluid.SL) # [simu. velocity units] + output -= nq * ux # [simu. current per area units] + nqe -= nq # [simu. charge density units] + output = output / nqe + cache(var, output) + return output # [simu velocity units] + + elif var in ['pex', 'pey', 'pez']: # electron momentum density [simu. momentum density units] + # p = r * u. + # u is on faces of cells, while r is in center, so we need to interpolate. + # px and ux are at (-0.5, 0, 0); r is at (0, 0, 0) + # ---> to align with ux, we shift r by xdn + x = var[-1] # axis; 'x', 'y', or 'z'. + interp = x+'dn' + re = obj.get_var('re'+interp) # [simu. mass density units] + uex = obj.get_var('ue'+x) # [simu. velocity units] + return re * uex # [simu. momentum density units] + + else: + raise NotImplementedError(f'{repr(var)} in get_electron_var') + + +# default +_CONTINUITY_QUANT = ('CONTINUITY_QUANT', + ['ndivu', 'udivn', 'udotgradn', 'flux_nu', 'flux_un', + 'gradnx', 'gradny', 'gradnz'] + ) +# get value + + +@document_vars.quant_tracking_simple(_CONTINUITY_QUANT[0]) +def get_continuity_quant(obj, var, CONTINUITY_QUANT=None): + '''terms related to the continuity equation. + In the simple case (e.g. no ionization), expect dn/dt + flux_un = 0. + ''' + if CONTINUITY_QUANT is None: + CONTINUITY_QUANT = _CONTINUITY_QUANT[1] + + if var == '': + docvar = document_vars.vars_documenter(obj, _CONTINUITY_QUANT[0], CONTINUITY_QUANT, + get_continuity_quant.__doc__, nfluid=1, uni=UNI_nr * UNI_hz) + docvar('ndivu', 'number density times divergence of velocity') + for udivn in ('udivn', 'udotgradn'): + docvar(udivn, 'velocity dotted with gradient of number density') + for x in AXES: + docvar('gradn'+x, x+'-component of grad(nr), face-centered.', nfluid=1, uni=UNI.qc(0)) # qc0 will be e.g. dnrdxdn. + for flux_un in ('flux_un', 'flux_nu'): + docvar(flux_un, 'divergence of (velocity times number density). Calculated via ndivu + udotgradn.') + docvar('flux_p', 'divergence of momentum density') + return None + + if var not in CONTINUITY_QUANT: + return None + + # --- continuity equation terms --- # + if var == 'ndivu': + n = obj('nr') + divu = obj('divupui') # divup(ui). up to align with n. + return n * divu + + elif var in ('gradnx', 'gradny', 'gradnz'): + return obj(f'dnrd{var[-1]}dn') + + elif var in ('udivn', 'udotgradn'): + return obj('ui_facedot_gradn') + + elif var in ('flux_nu', 'flux_un'): + return obj('ndivu') + obj('udivn') + + else: + raise NotImplementedError(f'{repr(var)} in get_momentum_quant') + + +# default +_MQUVECS = [ # momentum quantities that will have 'x', 'y', or 'z' after, and can have 'u' before. + 'rijsum', 'rij', + 'momflorentz', 'momfef', 'momfb', + 'mompg', + 'momdtime', +] +_MQVECS = [ # momentum quantities that will have 'x', 'y', or 'z' after. + 'momohme', 'mombat', 'gradp', + 'momrate', + 'ueq', 'ueqsimple', '_ueq_scr', + *_MQUVECS, + *('u'+v for v in _MQUVECS), +] +_MOMENTUM_QUANT = ('MOMENTUM_QUANT', [v + x for v in _MQVECS for x in AXES]) +# get value + + +@document_vars.quant_tracking_simple(_MOMENTUM_QUANT[0]) +def get_momentum_quant(obj, var, MOMENTUM_QUANT=None): + '''terms related to momentum equations of fluids. + The units are: + - for "momentum density rate of change" quantities -- [simu. momentum density units / simu. time units]. + - for "velocity rate of change" quantities -- [simu. velocity units / simu. time units]. + ''' + if MOMENTUM_QUANT is None: + MOMENTUM_QUANT = _MOMENTUM_QUANT[1] + + if var == '': + docvar = document_vars.vars_documenter(obj, _MOMENTUM_QUANT[0], MOMENTUM_QUANT, get_momentum_quant.__doc__) + # "helper" units + units_dpdt = dict(uni_f=UNI.phz, uni_name=UNI_rho.name * UNI_speed.name / UNI_time.name) + units_dudt = dict(uni=UNI_speed / UNI_time) + # begin documenting + for x in AXES: + # "helper" strings + momratex_i = f'{x}-component of momentum density rate of change of ifluid' + velratex_i = f'{x}-component of velocity rate of change of ifluid' + # documentation of variables + ## Collisions ## + docvar(f'rij{x}', f'{momratex_i} due to collisions with jfluid. = mi ni nu_ij * (u{x}_j - u{x}_i)', nfluid=2, **units_dpdt) + docvar(f'urij{x}', f'{velratex_i} due to collisions with jfluid. = nu_ij * (u{x}_j - u{x}_i)', nfluid=2, **units_dudt) + docvar(f'rijsum{x}', f'{momratex_i} due to collisions with all other fluids. = sum_j rij{x}', nfluid=1, **units_dpdt) + docvar(f'urijsum{x}', f'{velratex_i} due to collisions with all other fluids. = sum_j urij{x}', nfluid=1, **units_dudt) + ## Lorentz force ## + docvar(f'momfef{x}', f'{momratex_i} due to electric field. = ni qi E{x}', nfluid=1, **units_dpdt) + docvar(f'umomfef{x}', f'{velratex_i} due to electric field. = (qi/mi) E{x}', nfluid=1, **units_dudt) + docvar(f'momfb{x}', f'{momratex_i} due to magnetic field. = ni qi (ui x B)_{x}', nfluid=1, **units_dpdt) + docvar(f'umomfb{x}', f'{velratex_i} due to magnetic field. = (qi/mi) (ui x B)_{x}', nfluid=1, **units_dudt) + docvar(f'momflorentz{x}', f'{momratex_i} due to Lorentz force. = ni qi (E + ui x B)_{x}.', nfluid=1, **units_dpdt) + docvar(f'umomflorentz{x}', f'{velratex_i} due to Lorentz force. = (qi/mi) (E + ui x B)_{x}.', nfluid=1, **units_dudt) + ### electric field sub-terms ### + docvar(f'momohme{x}', f'{momratex_i} due the ohmic term in the electric field. = ni qi nu_es (ui-epUx) .', nfluid=1, **units_dpdt) + docvar(f'mombat{x}', f'{momratex_i} due to battery term. = ni qi grad(P_e) / (ne qe).', nfluid=1, **units_dpdt) + ## Pressure ## + docvar(f'gradp{x}', f'{x}-component of grad(Pi), face-centered (mesh location aligns with momentum).', nfluid=1, uni=UNI.qc(0)) + docvar(f'mompg{x}', f'{momratex_i} due to pressure. = -grad(P_i) dot {x}', nfluid=1, **units_dpdt) + docvar(f'umompg{x}', f'{velratex_i} due to pressure. = -grad(P_i) dot {x} / (ni mi)', nfluid=1, **units_dudt) + ## TOTAL d/dt ## + docvar(f'momdtime{x}', f'{momratex_i}, total. = (-gradp + momflorentz + rijsum)_{x}', nfluid=1, **units_dpdt) + docvar(f'umomdtime{x}', f'{velratex_i}, total. = (-gradp + momflorentz + rijsum)_{x} / (mi ni)', nfluid=1, **units_dudt) + docvar(f'momrate{x}', f'momdtime{x}', copy=True) # alias momratex <--> momdtimex, for historical reasons. (prefer: momdtime) + ## "equilibrium" velocities ## + docvar('ueq'+x, x+'-component of equilibrium velocity of ifluid. Ignores derivatives in momentum equation. ' + + '= '+x+'-component of [qs (_ueq_scr x B) + (ms) (sum_{j!=s} nu_sj) (_ueq_scr)] /' + + ' [(qs^2/ms) B^2 + (ms) (sum_{j!=s} nu_sj)^2]. [simu velocity units].', nfluid=1, uni=UNI_speed) + docvar('ueqsimple'+x, x+'-component of "simple" equilibrium velocity of ifluid. ' + + 'Treats these as 0: derivatives in momentum equation, velocity of jfluid, nu_sb for b not jfluid.' + + '= '+x+'-component of [(qs/(ms nu_sj))^2 (E x B) + qs/(ms nu_sj) E] /' + + ' [( (qs/ms) (|B|/nu_sj) )^2 + 1]. [simu. velocity units].', nfluid=2, uni=UNI_speed) + docvar('_ueq_scr'+x, x+'-component of helper term which appears twice in formula for ueq. '+x+'-component of ' + + ' [(qs/ms) E + (sum_{j!=s} nu_sj uj)]. face-centered. [simu velocity units].', nfluid=1, uni=UNI_speed) + return None + + if var not in MOMENTUM_QUANT: + return None + + # --- momentum equation terms --- # + x = var[-1] # axis; x= 'x', 'y', or 'z'. + if x in AXES: + y, z = YZ_FROM_X[x] + base = var[:-1] + else: + base = var + + umom = (base[0] == 'u') # whether we are doign 'umom' version of quant. + ubase = (base[1:] if umom else base) # e.g., 'rijx' given base='urijx' or base='rijx'. + + ## COLLISIONS ## + if ubase == 'rij': + if obj.i_j_same_fluid(): # when ifluid==jfluid, u_j = u_i, so rij = 0. + return obj.zero_at_mesh_face(x) # save time by returning 0 without reading any data. + # rij = mi ni nu_ij * (u_j - u_i) = ri nu_ij * (u_j - u_i) + # (Note: this does NOT equal to nu_ij * (rj u_j - ri u_i)) + # Scalars are at (0,0,0) so we must shift by xdn to align with face-centered u at (-0.5,0,0) + nu_ij = obj.get_var('nu_ij' + x+'dn') + uix = obj.get_var('ui'+x) + with obj.MaintainFluids(): + ujx = obj.get_var('ui'+x, ifluid=obj.jfluid) + if umom: + return nu_ij * (ujx - uix) + else: + ri = obj.get_var('ri' + x+'dn') + return ri * nu_ij * (ujx - uix) + + elif ubase == 'rijsum': + u = 'u' if umom else '' + return sum(obj(f'{u}rij{x}', jfluid=jSL) for jSL in obj.fluid_SLs(with_electrons=True)) + + ## LORENTZ FORCE ## + elif ubase in ('momfef', 'momfb', 'momflorentz'): + # momflorentz = (qi*ni) (E + ui x B) + # umomflorentz = (qi/mi) (E + ui x B) + # all of these quants are proportional to qi; get that first (if 0, return 0 to save time) + qi = obj.get_charge(obj.ifluid, units='simu') + if qi == 0: + return obj.zero_at_mesh_face(x) # no lorentz force for neutrals - save time by just returning 0 here :) + # factor in front. (qi*ni) for 'mom'; (qi/mi) for 'umom' + if umom: + mi = obj.get_mass(obj.ifluid, units='simu') + front = qi / mi + else: + ni = obj('nr'+x+'dn') # n, aligned with velocity + front = qi * ni + ## specific quantities (E, u x B, or E + u x B) ## + if ubase in ('momfef', 'momflorentz'): + # E interpolation notes: + # Ex is at (0, -0.5, -0.5); we shift to align with ux at (-0.5, 0, 0) + Ex = obj('ef'+x + x+'dn' + y+'up' + z+'up', cache_with_nfluid=0) # caching improves speed if calculation is repeated. + if ubase == 'momfef': + return front * Ex # (qi ni) E or (qi/mi) E + if ubase in ('momfb', 'momflorentz'): + # B, ui interpolation notes: + # B and ui are face-centered vectors, and we want a face-centered result to align with u. + # Thus we use ui_facecrosstoface_b (which gives a face-centered result). + uxB__x = obj.get_var('ui_facecrosstoface_b'+x) + if ubase == 'momfb': + return front * uxB__x # (qi ni) u x B or (qi/mi) u x B + if ubase == 'momflorentz': + return front * (Ex + uxB__x) # (qi ni) (E + u x B) or (qi/mi) (E + u x B) + else: + raise NotImplementedError(f"all ubase cases should have been handled, but got ubase={repr(ubase)}") + + ### ELECTRIC FIELD SUB-TERMS ### + elif base == 'momohme': + # momflorentz = ni qi (E + ui x B) + qi = obj.get_charge(obj.ifluid, units='simu') + if qi == 0: + return obj.zero_at_mesh_face(x) # no lorentz force for neutrals - save time by just returning 0 here :) + ni = obj.get_var('nr') + # make sure we get the interpolation correct: + # B and ui are face-centered vectors, and we want a face-centered result to align with p. + # Thus we use ui_facecrosstoface_b (which gives a face-centered result). + # Meanwhile, E is edge-centered, so we must shift all three coords. + # Ex is at (0, -0.5, -0.5), so we shift by xdn, yup, zup + Ex = obj.get_var('emom'+x + x+'dn' + y+'up' + z+'up', cache_with_nfluid=0) + + return ni * qi * Ex + + elif base == 'mombat': + # px is at (-0.5, 0, 0); nq is at (0, 0, 0), so we shift by xdn + interp = x+'dn' + niqi = obj('nq'+interp) + with obj.MaintainFluids(): + obj.iS = -1 + neqe = obj('nq'+interp) + gradPe_x = obj('gradp'+x) # gradp handles the interpolation already. + return (niqi / neqe) * gradPe_x + + ## PRESSURE ## + elif base == 'gradp': + # px is at (-0.5, 0, 0); pressure is at (0, 0, 0), so we do dpdxdn + return obj.get_var('dpd'+x+'dn') + + elif ubase == 'mompg': + gradpx = obj('dpd'+x+'dn') + mompgx = - gradpx + if umom: + ri = obj('ri'+x+'dn') # rho_i, shifted to align with u + return mompgx / ri + else: + return mompgx + + ## TOTAL ## + elif (ubase == 'momdtime') or (base == 'momrate'): + if obj.get_param('do_recion', default=False): + if obj.verbose: + warnings.warn('momentum contribution from ionization & recombination have not yet been added.') + u = 'u' if umom else '' + mompgx = obj(f'{u}mompg{x}') + florentzx = obj(f'{u}momflorentz{x}') + rijsumx = obj(f'{u}rijsum{x}') + return mompgx + florentzx + rijsumx + + # --- "equilibrium velocity" terms --- # + elif base == '_ueq_scr': + qi = obj.get_charge(obj.ifluid, units='simu') + mi = obj.get_mass(obj.ifluid, units='simu') + ifluid_orig = obj.ifluid + # make sure we get the interpolation correct: + # We want a face-centered result to align with u. + # E is edge-centered, so we must shift all three coords. + # Ex is at (0, -0.5, -0.5), so we shift by xdn, yup, zup + # Meanwhile, scalars are at (0,0,0), so we shift those by xdn to align with u. + Ex = obj.get_var('ef'+x + x+'dn' + y+'up' + z+'up', cache_with_nfluid=0) + sum_nu_u = 0 + for jSL in obj.iter_fluid_SLs(): + if jSL != ifluid_orig: + nu_sj = obj.get_var('nu_ij' + x+'dn', ifluid=ifluid_orig, jfluid=jSL) + uj = obj.get_var('ui'+x, ifluid=jSL) + sum_nu_u += nu_sj * uj + return (qi / mi) * Ex + sum_nu_u + + elif base == 'ueq': + qi = obj.get_charge(obj.ifluid, units='simu') + mi = obj.get_mass(obj.ifluid, units='simu') + # make sure we get the interpolation correct: + # We want a face-centered result to align with u. + # B and _ueq_scr are face-centered, so we use _facecrosstoface_ to get a face-centered cross product. + # Meanwhile, E is edge-centered, so we must shift all three coords. + # Finally, scalars are at (0,0,0), so we shift those by xdn to align with u. + B2 = obj.get_var('b2' + x+'dn') + # begin calculations + ueq_scr_x_B__x = obj.get_var('_ueq_scr_facecrosstoface_b'+x) + ueq_scr__x = obj.get_var('_ueq_scr'+x) + sumnu = 0 + for jSL in obj.iter_fluid_SLs(): + if jSL != obj.ifluid: + sumnu += obj.get_var('nu_ij' + x+'dn', jfluid=jSL) + numer = qi * ueq_scr_x_B__x + mi * sumnu * ueq_scr__x + denom = (qi**2/mi) * B2 + mi * sumnu**2 + return numer / denom + + elif base == 'ueqsimple': + qi = obj.get_charge(obj.ifluid, units='simu') + mi = obj.get_mass(obj.ifluid, units='simu') + # make sure we get the interpolation correct: + # B and ui are face-centered vectors, and we want a face-centered result to align with u. + # Thus we use ui_facecrosstoface_b (which gives a face-centered result). + # Meanwhile, E is edge-centered, so we must shift all three coords. + # Ex is at (0, -0.5, -0.5), so we shift by xdn, yup, zup + # Finally, scalars are at (0,0,0), so we shift those by xdn to align with u. + Ex = obj.get_var('ef'+x + x+'dn' + y+'up' + z+'up', cache_with_nfluid=0) + B2 = obj.get_var('b2' + x+'dn') + ExB__x = obj.get_var('ef_edgefacecross_b'+x) + nu_ij = obj.get_var('nu_ij' + x+'dn') + # begin calculations + q_over_m_nu = (qi/mi) / nu_ij + q_over_m_nu__squared = q_over_m_nu**2 + numer = q_over_m_nu__squared * ExB__x + q_over_m_nu * Ex + denom = q_over_m_nu__squared * B2 + 1 + return numer / denom + + else: + raise NotImplementedError(f'{repr(base)} in get_momentum_quant') + + +# default +_HEATING_QUANT = ['qcol_uj', 'qcol_tgj', 'qcol_coeffj', 'qcolj', 'qcol_j', + 'qcol_u', 'qcol_tg', 'qcol', + 'edu', 'edspaceu', + 'edtime', + 'e_to_tg', + 'tg_qcol', # TODO: add tg_qcol_... for as many of the qcol terms as applicable. + 'tg_qcol_uj', 'tg_qcol_u', 'tg_qcol_tgj', 'tg_qcol_tg', 'tg_qcol_j', 'tg_qcolj', + 'qjoulei', + 'tgdu', 'tgdspaceu', + 'tg_rate', 'tgdtime', # use tgdtime instead of tg_rate to avoid ambiguity with "rat" quant. + 'qcol_u_noe', 'qcol_tg_noe', + ] +_TGQCOL_EQUIL = ['tgqcol_equil' + x for x in ('_uj', '_tgj', '_j', '_u', '_tg', '')] +_HEATING_QUANT += _TGQCOL_EQUIL +_HEATING_QUANT = ('HEATING_QUANT', _HEATING_QUANT) +# get value + + +@document_vars.quant_tracking_simple(_HEATING_QUANT[0]) +def get_heating_quant(obj, var, HEATING_QUANT=None): + '''terms related to heating of fluids. + + Note that the code in this section is written for maximum readability, not maximum efficiency. + For most vars in this section the code would run a bit faster if you write use-case specific code. + + For example, qcolj gets qcol_uj + qcol_tgj, however each of those will separately calculate + number density (nr) and collision frequency (nu_ij); so the code will calculate the same value + of nr and nu_ij two separate times. It would be more efficient to calculate these only once. + + As another example, qjoulei will re-calculate the electric field efx, efy, and efz + each time it is called; if you are doing a sum of multiple qjoulei terms it would be more + efficient to calculate each of these only once. + + Thus, if you feel that code in this section is taking too long, you can speed it up by writing + your own code which reduces the number of times calculations are repeated. + (Note, if you had N_memmap = 0 or fast=False, first try using N_memmap >= 200, and fast=True.) + ''' + if HEATING_QUANT is None: + HEATING_QUANT = _HEATING_QUANT[1] + + if var == '': + docvar = document_vars.vars_documenter(obj, _HEATING_QUANT[0], HEATING_QUANT, get_heating_quant.__doc__) + units_qcol = dict(uni_f=UNI.e / UNI.t, usi_name=Usym('J')/(Usym('m')**3 * Usym('s'))) + units_e_to_tg = dict(uni_f=UNITS_FACTOR_1 / UNI.e, usi_name=Usym('K') / (Usym('J') / Usym('m')**3)) + units_tg = dict(uni_f=UNITS_FACTOR_1, uni_name=Usym('K')) + units_dtgdt = dict(uni_f=UNI.hz, uni_name=Usym('K')/Usym('s')) + + # docs for qcol and tg_qcol terms. + qcol_docdict = { + 'qcol_uj': ('{heati} due to collisions with jfluid, and velocity drifts.', dict(nfluid=2)), + 'qcol_u': ('{heati} due to collisions and velocity drifts.', dict(nfluid=1)), + 'qcol_u_noe': ('{heati} due to collisions and velocity drifts without electrons.', dict(nfluid=1)), + 'qcol_tgj': ('{heati} due to collisions with jfluid, and temperature differences.', dict(nfluid=2)), + 'qcol_tg': ('{heati} due to collisions and temperature differences.', dict(nfluid=1)), + 'qcol_tg_noe': ('{heati} due to collisions and temperature differences without electrons.', dict(nfluid=1)), + 'qcolj': ('total {heati} due to collisions with jfluid.', dict(nfluid=2)), + 'qcol': ('total {heati} due to collisions.', dict(nfluid=1)), + } + qcol_docdict['qcol_j'] = qcol_docdict['qcolj'] # alias + + # qcol: heating due to collisions in addition to velocity and/or temperature differences + # qcol tells the energy density change per unit time. + # tg_qcol tells the temperature change per unit time. + for vname, (vdoc, kw_nfluid) in qcol_docdict.items(): + docvar(vname, vdoc.format(heati='heating of ifluid [simu. energy density per time]'), **kw_nfluid, **units_qcol) + docvar('tg_'+vname, vdoc.format(heati='heating of ifluid [Kelvin per simu. time]'), **kw_nfluid, **units_dtgdt) + docvar('qcol_coeffj', 'coefficient common to qcol_uj and qcol_tj terms.' + + ' == (mi / (gamma - 1) (mi + mj)) * ni * nu_ij. [simu units: length^-3 time^-1]', + nfluid=2, **units_qcol) + docvar('e_to_tg', 'conversion factor from energy density to temperature for ifluid. ' + + 'e_ifluid * e_to_tg = tg_ifluid', nfluid=1, **units_e_to_tg) + # the other heating in the heating equation. + # partial e / partial t = - div(u e) - P div(u) + qcol + # partial T / partial t = - div(u T) + (1/3) T div(u) + tgqcol + # we've written it this way because the divergence theorem tells us + # that the mean (or the "integral over the box") of div(u T) must be 0 for a periodic box. + docvar('edu', 'rate of change of ei due to -P * div(u)', **units_qcol) + docvar('tgdu', 'rate of change of Ti due to +1/3 * T * div(u)', **units_dtgdt) + docvar('edspaceu', 'rate of change of ei due to -div(u e)', **units_qcol) + docvar('tgdspaceu', 'rate of change of Ti due to - div(u T)', **units_dtgdt) + docvar('edtime', 'predicted total rate of change of ei, including all contributions.', **units_qcol) + for tg_rate in ('tg_rate', 'tgdtime'): + docvar(tg_rate, 'predicted total rate of change of Ti, including all contributions. ' + + 'use "tgdtime" to avoid ambiguity with "rat" quant.', **units_dtgdt) + + # "simple equilibrium" vars + equili = '"simple equilibrium" temperature [K] of ifluid (setting sum_j Qcol_ij=0 and solving for Ti)' + # note: these all involve setting sum_j Qcol_ij = 0 and solving for Ti. + # Let Cij = qcol_coeffj; Uij = qcol_uj, Tj = temperature of j. Then: + # Ti == ( sum_{s!=i}(Cis Uis + Cis * 2 kB Ts) ) / ( 2 kB sum_{s!=i}(Cis) ) + # so for the "components" terms, we pick out only one term in this sum (in the numerator), e.g.: + # tgqcol_equil_uj == Cij Uij / ( 2 kB sum_{s!=i}(Cis) ) + docvar('tgqcol_equil_uj', equili + ', due only to contribution from velocity drift with jfluid.', nfluid=2, **units_tg) + docvar('tgqcol_equil_tgj', equili + ', due only to contribution from temperature of jfluid.', nfluid=2, **units_tg) + docvar('tgqcol_equil_j', equili + ', due only to contribution from jfluid.', nfluid=2, **units_tg) + docvar('tgqcol_equil_u', equili + ', due only to contribution from velocity drifts with fluids.', nfluid=1, **units_tg) + docvar('tgqcol_equil_tg', equili + ', due only to contribution from temperature of fluids.', nfluid=1, **units_tg) + docvar('tgqcol_equil', equili + '.', nfluid=1, **units_tg) + # "ohmic heating" (obsolete (?) - nonphysical to include this qjoule and the qcol_u term as it appears here.) + docvar('qjoulei', 'qi ni ui dot E. (obsolete, nonphysical to include this term and the qcol_u term)', nfluid=1, **units_qcol) + return None + + if var not in HEATING_QUANT: + return None + + def heating_is_off(): + '''returns whether we should treat heating as if it is turned off.''' + if obj.match_physics(): + return False + if obj.mf_ispecies < 0 or obj.mf_jspecies < 0: # electrons + return not (obj.get_param('do_ohm_ecol', True) and obj.get_param('do_qohm', True)) + else: # not electrons + return not (obj.get_param('do_col', True) and obj.get_param('do_qcol', True)) + + # full rate of change of e or T: + if var == 'edtime': + qcol = obj('qcol') + edu = obj('edu') + edsp = obj('edspaceu') + return qcol + edu + edsp + + if var in ['tg_rate', 'tgdtime']: + tgqcol = obj('tg_qcol') + tgdu = obj('tgdu') + tgd_udivtg = obj('tgdspaceu') + return tgqcol + tgdu + tgd_udivtg + + # qcol terms + elif var == 'qcol_coeffj': + if heating_is_off() or obj.i_j_same_fluid(): + return obj.zero_at_mesh_center() + ni = obj.get_var('nr') # [simu. units] + mi = obj.get_mass(obj.mf_ispecies) # [amu] + mj = obj.get_mass(obj.mf_jspecies) # [amu] + nu_ij = obj.get_var('nu_ij') # [simu. units] + coeff = (mi / (mi + mj)) * ni * nu_ij # [simu units: length^-3 time^-1] + return coeff + + elif var in ['qcol_uj', 'qcol_tgj']: + if heating_is_off() or obj.i_j_same_fluid(): + return obj.zero_at_mesh_center() + coeff = obj.get_var('qcol_coeffj') + if var == 'qcol_uj': + mj_simu = obj.get_mass(obj.mf_jspecies, units='simu') # [simu mass] + energy = mj_simu * obj.get_var('uid2') # [simu energy] + elif var == 'qcol_tgj': + simu_kB = obj.uni.ksi_b * (obj.uni.usi_nr / obj.uni.usi_e) # kB [simu energy / K] + tgi = obj.get_var('tg') # [K] + tgj = obj.get_var('tg', ifluid=obj.jfluid) # [K] + energy = 3. * simu_kB * (tgj - tgi) + return coeff * energy # [simu energy density / time] + + elif var in ['qcolj', 'qcol_j']: + if heating_is_off(): + return obj.zero_at_mesh_center() + return obj.get_var('qcol_uj') + obj.get_var('qcol_tgj') + + elif var in ['qcol_u', 'qcol_tg']: + if heating_is_off(): + return obj.zero_at_mesh_center() + varj = var + 'j' # qcol_uj or qcol_tgj + output = obj.get_var(varj, jS=-1) # get varj for j = electrons + for fluid in obj.fluids: + if fluid.SL != obj.ifluid: # exclude varj for j = i # not necessary but doesn't hurt. + output += obj.get_var(varj, jfluid=fluid) + return output + + elif var in ['qcol_u_noe', 'qcol_tg_noe']: + output = obj.zero_at_mesh_center() + if heating_is_off(): + return obj.zero_at_mesh_center() + varj = var[:-4] + 'j' # qcol_uj or qcol_tgj + for fluid in obj.fluids: + if fluid.SL != obj.ifluid: # exclude varj for j = i # not necessary but doesn't hurt. + output += obj.get_var(varj, jfluid=fluid) + return output + + elif var == 'qcol': + if heating_is_off(): + return obj.zero_at_mesh_center() + return obj.get_var('qcol_u') + obj.get_var('qcol_tg') + + # derivative terms + elif var == 'edu': + p = obj('p') # pressure + divu = obj('divup'+'ui') + return -1 * p * divu + + elif var == 'tgdu': + tg = obj('tg') + divu = obj('divup'+'ui') + return +1/3 * tg * divu + + elif var == 'edspaceu': + return sum(obj.stagger.do(obj('e'+f'{x}dn') * obj(f'ui{x}'), f'dd{x}up') for x in AXES) + + elif var == 'tgdspaceu': + return sum(obj.stagger.do(obj('tg'+f'{x}dn') * obj(f'ui{x}'), f'dd{x}up') for x in AXES) + + # converting to temperature (from energy density) terms + elif var == 'e_to_tg': + simu_kB = obj.uni.ksi_b * (obj.uni.usi_nr / obj.uni.usi_e) # kB [simu energy / K] + return (obj.uni.gamma - 1) / (obj.get_var('nr') * simu_kB) + + elif var.startswith('tg_'): + qcol = var[len('tg_'):] # var looks like tg_qcol + assert qcol in HEATING_QUANT, "qcol must be in heating quant to get tg_qcol. qcol={}".format(repr(qcol)) + qcol_value = obj.get_var(qcol) # [simu energy density / time] + e_to_tg = obj.get_var('e_to_tg') # [K / simu energy density (of ifluid)] + return qcol_value * e_to_tg # [K] + + # "simple equilibrium temperature" terms + elif var in _TGQCOL_EQUIL: + suffix = var.split('_')[-1] # uj, tgj, j, u, tg, or equil + # Let Cij = qcol_coeffj; Uij = qcol_uj, Tj = temperature of j. Then: + # Ti == ( sum_{s!=i}(Cis Uis + Cis * 2 kB Ts) ) / ( 2 kB sum_{s!=i}(Cis) ) + # so for the "components" terms, we pick out only one term in this sum (in the numerator), e.g.: + # tgqcol_equil_uj == Cij Uij / ( 2 kB sum_{s!=i}(Cis) ) + if suffix == 'j': # total contribution (u + tg) from j + return obj.get_var('tgqcol_equil_uj') + obj.get_var('tgqcol_equil_tgj') + elif suffix in ['u', 'tg']: # total contribution (sum over j); total from u or total from tg + result = obj.get_var('tgqcol_equil_'+suffix+'j', jfluid=(-1, 0)) + for fluid in obj.fluids: + result += obj.get_var('tgqcol_equil_'+suffix+'j', jfluid=fluid) + return result + elif suffix == 'equil': # total contribution u + tg, summed over all j. + return obj.get_var('tgqcol_equil_u') + obj.get_var('tgqcol_equil_tg') + else: + # suffix == 'uj' or 'tgj' + with obj.MaintainFluids(): + # denom = sum_{s!=i}(Cis). [(simu length)^-3 (simu time)^-1] + denom = obj.get_var('qcol_coeffj', jS=-1, cache_with_nfluid=2) # coeff for j = electrons + for fluid in obj.fluids: + denom += obj.get_var('qcol_coeffj', jfluid=fluid, cache_with_nfluid=2) + # Based on suffix, return appropriate term. + if suffix == 'uj': + simu_kB = obj.uni.ksi_b * (obj.uni.usi_nr / obj.uni.usi_e) # kB [simu energy / K] + qcol_uj = obj.get_var('qcol_uj') + temperature_contribution = qcol_uj / (2 * simu_kB) # [K (simu length)^-3 (simu time)^-1] + elif suffix == 'tgj': + coeffj = obj.get_var('qcol_coeffj') + tgj = obj.get_var('tg', ifluid=obj.jfluid) + temperature_contribution = coeffj * tgj # [K (simu length)^-3 (simu time)^-1] + return temperature_contribution / denom # [K] + + elif var == 'qjoulei': + # qjoulei = qi * ni * \vec{ui} dot \vec{E} + # ui is on grid cell faces while E is on grid cell edges. + # We must interpolate to align with energy density e, which is at center of grid cells. + # uix is at (-0.5, 0, 0) while Ex is at (0, -0.5, -0.5) + # --> we shift uix by xup, and Ex by yup zup + result = obj.zero_at_mesh_center() + qi = obj.get_charge(obj.ifluid, units='simu') # [simu charge] + if qi == 0: + return result # there is no contribution if qi is 0. + # else + ni = obj.get_var('nr') # [simu number density] + for x, y, z in [('x', 'y', 'z'), ('y', 'z', 'x'), ('z', 'x', 'y')]: + uix = obj.get_var('ui' + x + x+'up') # [simu velocity] + efx = obj.get_var('ef' + x + y+'up' + z+'up') # [simu electric field] + result += uix * efx + # << at this point, result = ui dot ef + return qi * ni * result + + else: + raise NotImplementedError(f'{repr(var)} in get_heating_quant') + + +# default +_SPITZTERM_QUANT = ('SPITZTERM_QUANT', ['kappaq', 'dxTe', 'dyTe', 'dzTe', 'rhs']) +# get value + + +@document_vars.quant_tracking_simple(_SPITZTERM_QUANT[0]) +def get_spitzerterm(obj, var, SPITZERTERM_QUANT=None): + '''spitzer conductivies''' + if SPITZERTERM_QUANT is None: + SPITZERTERM_QUANT = _SPITZTERM_QUANT[1] + + if var == '': + docvar = document_vars.vars_documenter(obj, _SPITZTERM_QUANT[0], SPITZERTERM_QUANT, get_spitzerterm.__doc__, nfluid='???') + docvar('kappaq', 'Electron thermal diffusivity coefficient [Ebysus units], in SI: W.m-1.K-1') + docvar('dxTe', 'Gradient of electron temperature in the x direction [simu.u_te/simu.u_l] in SI: K.m-1', uni=UNI.quant_child(0)) + docvar('dyTe', 'Gradient of electron temperature in the y direction [simu.u_te/simu.u_l] in SI: K.m-1', uni=UNI.quant_child(0)) + docvar('dzTe', 'Gradient of electron temperature in the z direction [simu.u_te/simu.u_l] in SI: K.m-1', uni=UNI.quant_child(0)) + docvar('rhs', 'Anisotropic gradient of electron temperature following magnetic field, i.e., bb.grad(Te), [simu.u_te/simu.u_l] in SI: K.m-1') + return None + + if var not in SPITZERTERM_QUANT: + return None + + if (var == 'kappaq'): + spitzer_amp = 1.0 + kappa_e = 1.1E-25 + kappaq0 = kappa_e * spitzer_amp + te = obj.get_var('tg', mf_ispecies=-1) # obj.get_var('etg') + result = kappaq0*(te)**(5.0/2.0) + + elif (var == 'dxTe'): + gradx_Te = obj.get_var('dtgdxup', iS=-1) + result = gradx_Te + + elif (var == 'dyTe'): + grady_Te = obj.get_var('dtgdyup', iS=-1) + result = grady_Te + + elif (var == 'dzTe'): + gradz_Te = obj.get_var('dtgdzup', iS=-1) + result = gradz_Te + + elif (var == 'rhs'): + bx = obj.get_var('bx') + by = obj.get_var('by') + bz = obj.get_var('bz') + gradx_Te = obj.get_var('dtgdxup', iS=-1) + grady_Te = obj.get_var('dtgdyup', iS=-1) + gradz_Te = obj.get_var('dtgdzup', iS=-1) + + bmin = 1E-5 + + normb = np.sqrt(bx**2+by**2+bz**2) + norm2bmin = bx**2+by**2+bz**2+bmin**2 + + bbx = bx/normb + bby = by/normb + bbz = bz/normb + + (bmin**2)/norm2bmin + + rhs = bbx*gradx_Te + bby*grady_Te + bbz*gradz_Te + result = rhs + + else: + raise NotImplementedError(f'{repr(var)} in get_spitzterm') + + return result + + +# default +_COLFRE_QUANT = ('COLFRE_QUANT', + ['nu_ij', 'nu_sj', # basics: frequencies + 'nu_si', 'nu_sn', 'nu_ei', 'nu_en', 'nu_ssum', # sum of frequencies + 'nu_ij_el', 'nu_ij_mx', 'nu_ij_cl', # colfreq by type + 'nu_ij_res', 'nu_se_spitzcoul', 'nu_ij_capcoul', # alternative colfreq formulae + 'nu_ij_to_ji', 'nu_sj_to_js', # conversion factor nu_ij --> nu_ji + 'c_tot_per_vol', '1dcolslope', # misc. + ] + ) +# get value + + +@document_vars.quant_tracking_simple(_COLFRE_QUANT[0]) +def get_mf_colf(obj, var, COLFRE_QUANT=None): + '''quantities related to collision frequency. + + Note the collision frequencies here are the momentum transer collision frequencies. + These obey the identity m_a n_a nu_ab = m_b n_b nu_ba. + This identity ensures total momentum (sum over all species) does not change due to collisions. + ''' + + if COLFRE_QUANT is None: + COLFRE_QUANT = _COLFRE_QUANT[1] + + if var == '': + docvar = document_vars.vars_documenter(obj, _COLFRE_QUANT[0], COLFRE_QUANT, get_mf_colf.__doc__, uni=UNI_hz) + mtra = 'momentum transfer collision frequency [simu. frequency units] between ifluid & jfluid. ' + for nu_ij in ['nu_ij', 'nu_sj']: + docvar(nu_ij, mtra + 'Use species<0 for electrons.', nfluid=2) + + sstr = 'sum of momentum transfer collision frequencies [simu. frequency units] between {} & {}.' + docvar('nu_si', sstr.format('ifluid', 'ion fluids (excluding ifluid)'), nfluid=1) + docvar('nu_sn', sstr.format('ifluid', 'neutral fluids (excluding ifluid)'), nfluid=1) + docvar('nu_ei', sstr.format('electrons', 'ion fluids'), nfluid=0) + docvar('nu_en', sstr.format('electrons', 'neutral fluids'), nfluid=0) + docvar('nu_ssum', sstr.format('ifluid', 'all other fluids'), nfluid=1) + docvar('nu_ij_el', 'Elastic ' + mtra, nfluid=2) + docvar('nu_ij_mx', 'Maxwell ' + mtra + 'NOTE: assumes maxwell molecules; result independent of temperatures. ' + + 'presently, only properly implemented when ifluid=H or jfluid=H.', nfluid=2) + docvar('nu_ij_cl', 'Coulomb ' + mtra, nfluid=2) + docvar('nu_ij_res', 'resonant collisions between ifluid & jfluid. ' + + 'presently, only properly implemented for ifluid=H+, jfluid=H.', nfluid=2) + docvar('nu_se_spitzcoul', 'coulomb collisions between s & e-, including spitzer correction. ' + + 'Formula in Oppenheim et al 2020 appendix A eq 4. [simu freq]', nfluid=1) + docvar('nu_ij_capcoul', 'coulomb collisions using Capitelli 2013 formulae. [simu freq]', nfluid=2) + docvar('nu_ij_to_ji', 'nu_ij_to_ji * nu_ij = nu_ji. nu_ij_to_ji = m_i * n_i / (m_j * n_j) = r_i / r_j', + nfluid=2, uni=DIMENSIONLESS) + docvar('nu_sj_to_js', 'nu_sj_to_js * nu_sj = nu_js. nu_sj_to_js = m_s * n_s / (m_j * n_j) = r_s / r_j', + nfluid=2, uni=DIMENSIONLESS) + docvar('1dcolslope', '-(nu_ij + nu_ji)', nfluid=2) + docvar('c_tot_per_vol', 'number density of collisions per volume per time ' + '[simu. number density * simu. frequency] between ifluid and jfluid.', nfluid=2, + uni=UNI_nr * UNI_hz) + return None + + if var not in COLFRE_QUANT: + return None + + # collision frequency between ifluid and jfluid + if var in ['nu_ij', 'nu_sj']: + # TODO: also check mf_param_file tables to see if the collision is turned off. + if obj.match_aux(): + # return constant if constant collision frequency is turned on. + i_elec, j_elec = (obj.mf_ispecies < 0, obj.mf_jspecies < 0) + if i_elec or j_elec: + const_nu_en = obj.get_param('ec_const_nu_en', default=-1.0) + const_nu_ei = obj.get_param('ec_const_nu_ei', default=-1.0) + if const_nu_en >= 0 or const_nu_ei >= 0: # at least one constant collision frequency is turned on. + non_elec_fluid = getattr(obj, '{}fluid'.format('j' if i_elec else 'i')) + non_elec_neutral = obj.get_charge(non_elec_fluid) == 0 # whether the non-electrons are neutral. + + def nu_ij(const_nu): + result = obj.zero_at_mesh_center() + const_nu + if i_elec: + return result + else: + return result * obj.get_var('nu_ij_to_ji', ifluid=obj.jfluid, jfluid=obj.ifluid) + if non_elec_neutral and const_nu_en >= 0: + return nu_ij(const_nu_en) + elif (not non_elec_neutral) and const_nu_ei >= 0: + return nu_ij(const_nu_ei) + # << if we reach this line, we don't have to worry about constant electron colfreq. + coll_type = obj.get_coll_type() # gets 'EL', 'MX', 'CL', or None + if coll_type is not None: + if coll_type[0] == 'EE': # electrons --> use "implied" coll type. + coll_type = coll_type[1] # TODO: add coll_keys to mf_eparams.in?? + nu_ij_varname = 'nu_ij_{}'.format(coll_type.lower()) # nu_ij_el, nu_ij_mx, or nu_ij_cl + return obj.get_var(nu_ij_varname) + elif obj.match_aux() and (obj.get_charge(obj.ifluid) > 0) and (obj.get_charge(obj.jfluid) > 0): + # here, we want to match aux, i and j are ions, and coulomb collisions are turned off. + return obj.zero_at_mesh_center() # so we return zero (instead of making a crash) + else: + errmsg = ("Found no valid coll_keys for ifluid={}, jfluid={}. " + "looked for 'CL' for coulomb collisions, or 'EL' or 'MX' for other collisions. " + "You can enter coll_keys in the COLL_KEYS section in mf_param_file='{}'.") + mf_param_file = obj.get_param('mf_param_file', default='mf_params.in') + raise ValueError(errmsg.format(obj.ifluid, obj.jfluid, mf_param_file)) + + # collision frequency - elastic or coulomb + if var in ['nu_ij_el', 'nu_ij_cl']: + with Caching(obj, nfluid=2) as cache: + iSL = obj.ifluid + jSL = obj.jfluid + # get ifluid info + tgi = obj.get_var('tg', ifluid=iSL) # [K] + m_i = obj.get_mass(iSL[0]) # [amu] + # get jfluid info, then restore original iSL & jSL + with obj.MaintainFluids(): + n_j = obj.get_var('nr', ifluid=jSL) * obj.uni.u_nr # [cm^-3] + tgj = obj.get_var('tg', ifluid=jSL) # [K] + m_j = obj.get_mass(jSL[0]) # [amu] + + # compute some values: + m_jfrac = m_j / (m_i + m_j) # [(dimensionless)] + m_ij = m_i * m_jfrac # [amu] + tgij = (m_i * tgj + m_j * tgi) / (m_i + m_j) # [K] + + # coulomb collisions: + if var.endswith('cl'): + icharge = obj.get_charge(iSL) # [elementary charge == 1] + jcharge = obj.get_charge(jSL) # [elementary charge == 1] + m_h = obj.uni.m_h / obj.uni.amu # [amu] + logcul = obj.get_var('logcul') + scalars = 1.7 * 1/20.0 * (m_h/m_i) * (m_ij/m_h)**0.5 * icharge**2 * jcharge**2 / obj.uni.u_hz + result = scalars * logcul * n_j / tgij**1.5 # [ simu frequency units] + + # elastic collisions: + elif var.endswith('el'): + cross = obj.get_var('cross_physical') # [cm^2] + tg_speed = np.sqrt(8 * (obj.uni.kboltzmann/obj.uni.amu) * tgij / (np.pi * m_ij)) # [cm s^-1] + result = 4./3. * n_j * m_jfrac * cross * tg_speed / obj.uni.u_hz # [simu frequency units] + + # cache result, then return: + cache(var, result) # / 1.0233) + return result # / 1.0233 + + # collision frequency - maxwell + elif var == 'nu_ij_mx': + # set constants. for more details, see eq2 in Appendix A of Oppenheim 2020 paper. + CONST_MULT = 1.96 # factor in front. + CONST_ALPHA_N = 6.67e-31 # [m^3] #polarizability for Hydrogen #(should be different for different species) + e_charge = obj.uni.qsi_electron # [C] #elementary charge + eps0 = 8.854187e-12 # [kg^-1 m^-3 s^4 (C^2 s^-2)] #epsilon0, standard definition + CONST_RATIO = (e_charge / obj.uni.amusi) * (e_charge / eps0) * CONST_ALPHA_N # [C^2 kg^-1 [eps0]^-1 m^3] + # units of CONST_RATIO: [C^2 kg^-1 (kg^1 m^3 s^-2 C^-2) m^-3] = [s^-2] + # get variables. + with obj.MaintainFluids(): + n_j = obj.get_var('nr', ifluid=obj.jfluid) * obj.uni.usi_nr # number density [m^-3] + m_i = obj.get_mass(obj.mf_ispecies) # mass [amu] + m_j = obj.get_mass(obj.mf_jspecies) # mass [amu] + # calculate & return nu_ij_test: + return CONST_MULT * n_j * np.sqrt(CONST_RATIO * m_j / (m_i * (m_i + m_j))) / obj.uni.usi_hz + + # sum of collision frequencies: sum_{i in ions} (nu_{ifluid, i}) + elif var == 'nu_si': + ifluid = obj.ifluid + result = obj.zero_at_mesh_center() + for fluid in obj.fluids.ions(): + if fluid.SL != ifluid: + result += obj.get_var('nu_ij', jfluid=fluid.SL) + return result + + # sum of collision frequencies: sum_{n in neutrals} (nu_{ifluid, n}) + elif var == 'nu_sn': + ifluid = obj.ifluid + result = obj.zero_at_mesh_center() + for fluid in obj.fluids.neutrals(): + if fluid.SL != ifluid: + result += obj.get_var('nu_ij', jfluid=fluid.SL) + return result + + elif var == 'nu_ei': + return obj.get_var('nu_si', mf_ispecies=-1) + + elif var == 'nu_en': + return obj.get_var('nu_sn', mf_ispecies=-1) + + # sum of collision frequencies: sum_{s != ifluid} (nu_{ifluid, s}) + elif var == 'nu_ssum': + return sum(obj('nu_ij', jSL=SL) + for SL in obj.fluid_SLs(with_electrons=True) + if not obj.fluids_equal(obj.ifluid, SL)) + + # collision frequency - resonant charge exchange for H, H+ + elif var == 'nu_ij_res': + # formula assumes we are doing nu_{H+, H} collisions. + # it also happens to be valid for nu_{H, H+}, + # because nu_ij_to_ji for H, H+ is the ratio nH / nH+. + with obj.MaintainFluids(): + nH = obj.get_var('nr', ifluid=obj.jfluid) * obj.uni.usi_nr # [m^-3] + tg = 0.5 * (obj.get_var('tg') + obj.get_var('tg', ifluid=obj.jfluid)) # [K] + return 2.65e-16 * nH * np.sqrt(tg) * (1 - 0.083 * np.log10(tg))**2 / obj.uni.usi_hz + + # collision frequency - spitzer coulomb formula + elif var == 'nu_se_spitzcoul': + icharge = obj.get_charge(obj.ifluid) + assert icharge > 0, "ifluid must be ion, but got charge={} (ifluid={})".format(icharge, obj.ifluid) + # nuje = me pi ne e^4 ln(12 pi ne ldebye^3) / ( ms (4 pi eps0)^2 sqrt(ms (2 kb T)^3) ) + ldebye = obj.get_var('ldebye') * obj.uni.usi_l + me = obj.uni.msi_e + tg = obj.get_var('tg') + ms = obj.get_mass(obj.mf_ispecies, units='si') + eps0 = obj.uni.permsi + kb = obj.uni.ksi_b + qe = obj.uni.qsi_electron + ne = obj.get_var('nr', mf_ispecies=-1) * obj.uni.usi_nr # [m^-3] + # combine numbers in a way that will prevent extremely large or small values: + const = (1 / (16 * np.pi)) * (qe / eps0)**2 * (qe / kb) * (qe / np.sqrt(kb)) + mass_ = me / ms * 1 / np.sqrt(ms) + ln_ = np.log(12 * np.pi * ne) + 3 * np.log(ldebye) + nuje0 = (const * ne) * mass_ * ln_ / (2 * tg)**(3/2) + + # try again but with logs. Run this code to confirm that the above code is correct. + run_confirmation_routine = False # change to True to run this code. + if run_confirmation_routine: + ln = np.log + tmp1 = ln(me) + ln(np.pi) + ln(ne) + 4*ln(qe) + ln(ln(12) + ln(np.pi) + ln(ne) + 3*ln(ldebye)) # numerator + tmp2 = ln(ms) + 2*(ln(4) + ln(np.pi) + ln(eps0)) + 0.5*(ln(ms) + 3*(ln(2) + ln(kb) + ln(tg))) # denominator + tmp = tmp1 - tmp2 + nuje1 = np.exp(tmp) + print('we expect these to be approximately equal:', nuje0.mean(), nuje1.mean()) + return nuje0 / obj.uni.usi_hz + + # collision frequency - capitelli coulomb formula + elif var == 'nu_ij_capcoul': + iSL = obj.ifluid + jSL = obj.jfluid + icharge = obj.get_charge(iSL, units='si') # [C] + jcharge = obj.get_charge(jSL, units='si') # [C] + assert icharge != 0 and jcharge != 0, 'we require i & j both charged' +\ + ' but got icharge={}, jcharge={}'.format(icharge, jcharge) + + # get ifluid info + tgi = obj.get_var('tg', ifluid=iSL) # [K] + m_i = obj.get_mass(iSL[0]) # [amu] + # get jfluid info, then restore original iSL & jSL + with obj.MaintainFluids(): + n_j = obj.get_var('nr', ifluid=jSL) * obj.uni.usi_nr # [m^-3] + tgj = obj.get_var('tg', ifluid=jSL) # [K] + m_j = obj.get_mass(jSL[0]) # [amu] + + # compute some values: + m_jfrac = m_j / (m_i + m_j) # [(dimensionless)] + m_ij = m_i * m_jfrac # [amu] # e.g. for H, H+, m_ij = 0.5. + tgij = (m_i * tgj + m_j * tgi) / (m_i + m_j) # [K] + + tg_speed = np.sqrt(8 * (obj.uni.ksi_b/obj.uni.amusi) * tgij / (np.pi * m_ij)) # [m s^-1] + E_alpha = 0.5 * (m_ij * obj.uni.amusi) * tg_speed**2 + + euler_constant = 0.577215 + b_0 = abs(icharge*jcharge)/(4 * np.pi * obj.uni.permsi * E_alpha) # [m] # permsi == epsilon_0 + # b_0 = abs(icharge*jcharge)/(2 * obj.uni.ksi_b*obj.uni.permsi * tgij) # [m] # permsi == epsilon_0 + cross = np.pi*2.0*(b_0**2)*(np.log(2.0*obj.get_var('ldebye')*obj.uni.usi_l/b_0)-2.0*euler_constant) # [m2] + # Before we had np.log(2.0*obj.get_var('ldebye')*obj.uni.usi_l/b_0)-0.5-2.0*euler_constant + # Not sure about the coefficient 0.5 from Capitelli et al. (2000). Should be only euler constant according to Liboff (1959, eq. 4.28) + + # calculate & return nu_ij: + nu_ij = 4./3. * n_j * m_jfrac * cross * tg_speed / obj.uni.u_hz # [simu frequency units] + return nu_ij + + # collision frequency conversion factor: nu_ij to nu_ji + elif var in ['nu_ij_to_ji', 'nu_sj_to_js']: + mi_ni = obj.get_var('ri', ifluid=obj.ifluid) # mi * ni = ri + mj_nj = obj.get_var('ri', ifluid=obj.jfluid) # mj * nj = rj + return mi_ni / mj_nj + + elif var == "c_tot_per_vol": + m_i = obj.get_mass(obj.mf_ispecies) # [amu] + m_j = obj.get_mass(obj.mf_jspecies) # [amu] + return obj.get_var("nr", ifluid=obj.jfluid) * obj.get_var("nu_ij") / (m_j / (m_i + m_j)) + + elif var == "1dcolslope": + if obj.verbose: + warnings.warn(DeprecationWarning('1dcolslope will be removed at some point in the future.')) + return -1 * obj.get_var("nu_ij") * (1 + obj.get_var('nu_ij_to_ji')) + + else: + raise NotImplementedError(f'{repr(var)} in get_mf_colf') + + +# default +_LOGCUL_QUANT = ('LOGCUL_QUANT', ['logcul']) +# get value + + +@document_vars.quant_tracking_simple(_LOGCUL_QUANT[0]) +def get_mf_logcul(obj, var, LOGCUL_QUANT=None): + '''coulomb logarithm''' + if LOGCUL_QUANT is None: + LOGCUL_QUANT = _LOGCUL_QUANT[1] + + if var == '': + docvar = document_vars.vars_documenter(obj, _LOGCUL_QUANT[0], LOGCUL_QUANT, get_mf_logcul.__doc__) + docvar('logcul', 'Coulomb Logarithmic used for Coulomb collisions.', nfluid=0, uni=DIMENSIONLESS) + return None + + if var not in LOGCUL_QUANT: + return None + + if var == "logcul": + etg = obj.get_var('tg', mf_ispecies=-1) + nel = obj.get_var('nel') + return 23. + 1.5 * np.log(etg / 1.e6) - \ + 0.5 * np.log(nel / 1e6) + + else: + raise NotImplementedError(f'{repr(var)} in get_logcul') + + +# default +_CROSTAB_QUANT = ('CROSTAB_QUANT', ['cross', 'cross_physical', 'tgij']) +# get value + + +@document_vars.quant_tracking_simple(_CROSTAB_QUANT[0]) +def get_mf_cross(obj, var, CROSTAB_QUANT=None): + '''cross section between species.''' + if CROSTAB_QUANT is None: + CROSTAB_QUANT = _CROSTAB_QUANT[1] + + if var == '': + docvar = document_vars.vars_documenter(obj, _CROSTAB_QUANT[0], CROSTAB_QUANT, get_mf_cross.__doc__, nfluid=2) + docvar('cross', 'cross section between ifluid and jfluid [cm^2]. Use species < 0 for electrons.', + uni_name=UNI_length.name**2, ucgs_f=UNITS_FACTOR_1, usi_f=UCONST.cm_to_m**2) + docvar('cross_physical', "cross section between ifluid and jfluid [cm^2]. " + + "Always returns physical value, regardless of match_type. (As opposed to 'cross' " + + "which gives 0 for ifluid > jfluid, in order to match aux.)", + uni_name=UNI_length.name**2, ucgs_f=UNITS_FACTOR_1, usi_f=UCONST.cm_to_m**2) + docvar('tgij', 'mass-weighted temperature: (Ti mj + Tj mi) / (mi + mj)', + uni=U_TUPLE(UNITS_FACTOR_1, Usym('K'))) + return None + + if var not in CROSTAB_QUANT: + return None + + if (var == 'cross') and obj.match_aux(): + # return 0 if ifluid > jfluid. (comparing species, then level if species are equal) + # we do this because mm_cross gives 0 if ifluid > jfluid (and jfluid is not electrons)) + if (obj.ifluid > obj.jfluid) and obj.mf_jspecies > 0: + return obj.zero_at_mesh_center() + + # get masses & temperatures, then restore original obj.ifluid and obj.jfluid values. + with obj.MaintainFluids(): + m_i = obj.get_mass(obj.mf_ispecies) + m_j = obj.get_mass(obj.mf_jspecies) + tgi = obj.get_var('tg', ifluid=obj.ifluid) + tgj = obj.get_var('tg', ifluid=obj.jfluid) + + # temperature, weighted by mass of species + tg = (tgi*m_j + tgj*m_i)/(m_i + m_j) + if var == 'tgij': + return tg + else: + # look up cross table and get cross section + #crossunits = 2.8e-17 + try: + crossobj = obj.get_cross_sect(ifluid=obj.ifluid, jfluid=obj.jfluid) + except ValueError: # we failed to get the cross section. + if obj.match_aux(): + if (obj.get_charge(obj.iSL) < 0) or (obj.get_charge(obj.jSL) < 0): # one of them is electrons: + cross = obj.zero_at_mesh_center() # use 0 for cross section if match_aux and there was no defined cross section. + else: + errmsg = "expected this case was handled during get_var('nu_ij'), if getting a collision frequency." + raise NotImplementedError(errmsg) + else: + raise # raise the original error + else: + crossunits = crossobj.cross_tab[0]['crossunits'] + cross = crossunits * crossobj.tab_interp(tg) + + return cross + + +# default +_DRIFT_QUANT = ['ed', 'rd', 'tgd'] +_DRIFT_QUANT += [dq + x for dq in ('ud', 'pd', 'uid') for x in AXES] +_DRIFT_QUANT = ('DRIFT_QUANT', _DRIFT_QUANT) +# get value + + +@document_vars.quant_tracking_simple(_DRIFT_QUANT[0]) +def get_mf_driftvar(obj, var, DRIFT_QUANT=None): + '''var drift between fluids. I.e. var_ifluid - var_jfluid.''' + if DRIFT_QUANT is None: + DRIFT_QUANT = _DRIFT_QUANT[1] + + if var == '': + docvar = document_vars.vars_documenter(obj, _DRIFT_QUANT[0], DRIFT_QUANT, get_mf_driftvar.__doc__, + nfluid=2, uni=UNI.quant_child(0)) + + def doc_start(var): + return '"drift" for quantity "{var}". I.e. ({var} for ifluid) - ({var} for jfluid). '.format(var=var) + for x in AXES: + docvar('ud'+x, doc_start(var='u'+x) + 'u = velocity [simu. units].') + for x in AXES: + docvar('uid'+x, doc_start(var='ui'+x) + 'ui = velocity [simu. units].') + for x in AXES: + docvar('pd'+x, doc_start(var='p'+x) + 'p = momentum density [simu. units].') + docvar('ed', doc_start(var='ed') + 'e = energy (density??) [simu. units].') + docvar('rd', doc_start(var='rd') + 'r = mass density [simu. units].') + docvar('tgd', doc_start(var='tgd') + 'tg = temperature [K].') + return None + + if var not in DRIFT_QUANT: + return None + + else: + if var[-1] == 'd': # scalar drift quant e.g. tgd + quant = var[:-1] # "base quant"; without d. e.g. tg + elif var[-2] == 'd': # vector drift quant e.g. uidx + quant = var[:-2] + var[-1] # "base quant"; without d e.g. uix + + q_i = obj.get_var(quant, ifluid=obj.ifluid) + q_j = obj.get_var(quant, ifluid=obj.jfluid) + return q_i - q_j + + +# default +_MEAN_QUANT = ('MEAN_QUANT', + ['neu_meannr_mass', 'ion_meannr_mass', + ] + ) +# get value + + +@document_vars.quant_tracking_simple(_MEAN_QUANT[0]) +def get_mean_quant(obj, var, MEAN_QUANT=None): + '''weighted means of quantities.''' + if MEAN_QUANT is None: + MEAN_QUANT = _MEAN_QUANT[1] + + if var == '': + docvar = document_vars.vars_documenter(obj, _MEAN_QUANT[0], MEAN_QUANT, get_mean_quant.__doc__) + docvar('neu_meannr_mass', 'number density weighted mean mass of neutrals.' + ' == sum_n(mass_n * nr_n) / sum_n(nr_n). [simu mass units]', + nfluid=0, uni_name=UsymD(usi='kg', ucgs='g'), uni_f=UNI.m) + docvar('ion_meannr_mass', 'number density weighted mean mass of ions.' + ' == sum_i(mass_i * nr_i) / sum_i(nr_i). [simu mass units]', + nfluid=0, uni_name=UsymD(usi='kg', ucgs='g'), uni_f=UNI.m) + return None + + if var not in MEAN_QUANT: + return None + + if var.endswith('_meannr_mass'): + neu = var[:-len('_meannr_mass')] + fluids = obj.fluids + if neu == 'neu': + fluids = fluids.neutrals() + elif neu == 'ion': + fluids = fluids.ions() + else: + raise NotImplementedError('only know _meannr_mass for neu or ion but got {}'.format(neu)) + numer = obj.zero_at_mesh_center() + denom = obj.zero_at_mesh_center() + for fluid in fluids: + r = obj.get_var('r', ifluid=fluid) + m = obj.get_mass(fluid, units='simu') + numer += r + denom += r / m + return numer / denom + + else: + raise NotImplementedError(f'{repr(var)} in get_mean_quant') + + +# default +_CFL_QUANTS = ['ohm'] +_CFL_QUANT = ['cfl_' + q for q in _CFL_QUANTS] +_CFL_QUANT = ('CFL_QUANT', _CFL_QUANT) +# get value + + +@document_vars.quant_tracking_simple(_CFL_QUANT[0]) +def get_cfl_quant(obj, quant, CFL_QUANT=None): + '''CFL quantities. All are in simu. frequency units.''' + if CFL_QUANT is None: + CFL_QUANT = _CFL_QUANT[1] + + if quant == '': + docvar = document_vars.vars_documenter(obj, _CFL_QUANT[0], CFL_QUANT, get_cfl_quant.__doc__) + docvar('cfl_ohm', 'cfl condition for ohmic module. (me / ms) ((qs / qe) + (ne / ns)) nu_es', nfluid=1) + return None + + _, cfl_, quant = quant.partition('cfl_') + if quant == '': + return None + + elif quant == 'ohm': + fluid = obj.ifluid + nrat = obj.get_var('nr', iS=-1) / obj.get_var('nr', ifluid=fluid) # ne / ns + mrat = obj.uni.msi_electron / obj.get_mass(fluid, units='si') # me / ms + qrat = obj.get_charge(fluid) / -1 # qs / qe + nu_es = obj.get_var('nu_ij', iS=-1, jfluid=fluid) # nu_es + return mrat * (qrat + nrat) * nu_es + + else: + raise NotImplementedError(f'{repr(quant)} in get_cfl_quant') + + +# default +_PLASMA_QUANT = ('PLASMA_QUANT', + ['beta', 'beta_ions', 'va', 'va_ions', 'vai', 'cs', 's', 'ke', 'mn', 'man', 'hp', + 'vax', 'vay', 'vaz', 'hx', 'hy', 'hz', 'kx', 'ky', 'kz', + 'sgyrof', 'gyrof', 'skappa', 'kappa', 'ldebye', 'ldebyei', + 'meanfreepath', 'gyroradius', + ] + ) +# get value + + +@document_vars.quant_tracking_simple(_PLASMA_QUANT[0]) +def get_mf_plasmaparam(obj, quant, PLASMA_QUANT=None): + '''plasma parameters, e.g. plasma beta, sound speed, pressure scale height''' + if PLASMA_QUANT is None: + PLASMA_QUANT = _PLASMA_QUANT[1] + + if quant == '': + docvar = document_vars.vars_documenter(obj, _PLASMA_QUANT[0], PLASMA_QUANT, get_mf_plasmaparam.__doc__) + docvar('beta', "plasma beta", nfluid='???', uni=DIMENSIONLESS) # nfluid= 1 if mfe_p = p_ifluid; 0 if mfe_p = sum of pressures. + docvar('beta_ions', "plasma beta using sum of ion pressures. P / (B^2 / (2 mu0)).", nfluid=0, uni=DIMENSIONLESS) + docvar('va', "alfven speed [simu. units]", nfluid=0, uni=UNI_speed) + docvar('va_ions', "alfven speed [simu. units], using density := density of ions.", nfluid=0, uni=UNI_speed) + docvar('vai', "alfven speed [simu. units] of ifluid. Vai = sqrt(B^2 / (mu0 * rho_i))", nfluid=1, uni=UNI_speed) + docvar('cs', "sound speed [simu. units]", nfluid='???', uni=UNI_speed) + docvar('csi', "sound speed [simu. units] of ifluid. Csi = sqrt(gamma * pressure_i / rho_i)", nfluid=1, uni=UNI_speed) + docvar('cfast', "Cfast for ifluid. == (Csi**2 + Vai**2 + Cse**2)?? NEEDS UPDATING.", nfluid=1, uni=UNI_speed) + docvar('s', "entropy [log of quantities in simu. units]", nfluid='???', uni=DIMENSIONLESS) + docvar('mn', "mach number (using sound speed)", nfluid=1, uni=DIMENSIONLESS) + docvar('man', "mach number (using alfven speed)", nfluid=1, uni=DIMENSIONLESS) + docvar('hp', "Pressure scale height", nfluid='???') + for x in AXES: + docvar('va'+x, x+"-component of alfven velocity [simu. units]", nfluid=0, uni=UNI_speed) + for x in AXES: + docvar('k'+x, ("{axis} component of kinetic energy density of ifluid [simu. units]." + + "(0.5 * rho * (get_var(u{axis})**2)").format(axis=x), nfluid=1, **units_e) + docvar('sgyrof', "signed gryofrequency for ifluid. I.e. qi * |B| / mi. [1 / (simu. time units)]", nfluid=1, uni=UNI_hz) + docvar('gyrof', "gryofrequency for ifluid. I.e. abs(qi * |B| / mi). [1 / (simu. time units)]", nfluid=1, uni=UNI_hz) + kappanote = ' "Highly magnetized" when kappa^2 >> 1.' + docvar('skappa', "signed magnetization for ifluid. I.e. sgryof/nu_sn." + kappanote, nfluid=1, uni=DIMENSIONLESS) + docvar('kappa', "magnetization for ifluid. I.e. gyrof/nu_sn." + kappanote, nfluid=1, uni=DIMENSIONLESS) + docvar('ldebyei', "debye length of ifluid [simu. length units]. sqrt(kB eps0 q^-2 Ti / ni)", nfluid=1, uni=UNI_length) + docvar('ldebye', "debye length of plasma [simu. length units]. " + + "sqrt(kB eps0 e^-2 / (ne/Te + sum_j(Zj^2 * nj / Tj)) ); Zj = qj/e" + + "1/sum_j( (1/ldebye_j) for j in fluids and electrons)", nfluid=0, uni=UNI_length) + docvar('meanfreepath', "mean free path of particles of ifluid. = |ui| / sum_j(nu_ij).", nfluid=1, uni=UNI_length) + docvar('gyroradius', "gyroradius for ifluid. I.e. |ui| / abs(qi * |B| / mi)", nfluid=1, uni=UNI_length) + return None + + if quant not in PLASMA_QUANT: + return None + + if quant in ['hp', 's', 'cs', 'beta']: + var = obj.get_var('mfe_p') # is mfe_p pressure for ifluid, or sum of all fluid pressures? - SE Apr 19 2021 + if quant == 'hp': + if getattr(obj, 'nx') < 5: + return obj.zero() + else: + return 1. / (do_stagger(var, 'ddzup', obj=obj) + 1e-12) + elif quant == 'cs': + return np.sqrt(obj.params['gamma'][obj.snapInd] * + var / obj.get_var('totr')) + elif quant == 's': + return (np.log(var) - obj.params['gamma'][obj.snapInd] * + np.log(obj.get_var('totr'))) + else: # quant == 'beta': + return 2 * var / obj.get_var('b2') + + elif quant == 'csi': + p = obj('p') + r = obj('r') + return np.sqrt(obj.uni.gamma * p / r) + + elif quant == 'cfast': + warnings.warn('cfast implementation may be using the wrong formula.') + speeds = [obj('csi')] # sound speed + i_charged = obj.get_charge(obj.ifluid) != 0 + if i_charged: + speeds.append(obj('vai')) # alfven speed + if not obj.fluids_equal(obj.ifluid, (-1, 0)): # if ifluid is not electrons + speeds.append(obj('csi', ifluid=(-1, 0))) # sound speed of electrons + result = sum(speed**2 for speed in speeds) + return result + + elif quant == 'beta_ions': + p = obj.zero() + for fluid in obj.fluids.ions(): + p += obj.get_var('p', ifluid=fluid) + bp = obj.get_var('b2') / 2 # (dd.uni.usi_b**2 / dd.uni.mu0si) == 1 by def'n of b in ebysus. + return p / bp + + elif quant in ['mn', 'man']: + var = obj.get_var('modu') + if quant == 'mn': + return var / (obj.get_var('cs') + 1e-12) + else: + return var / (obj.get_var('va') + 1e-12) + + elif quant in ['va', 'vax', 'vay', 'vaz']: + var = obj.get_var('totr') + if len(quant) == 2: + return obj.get_var('modb') / np.sqrt(var) + else: + axis = quant[-1] + return np.sqrt(obj.get_var('b' + axis + 'c') ** 2 / var) + + elif quant in ['va_ions']: + r = obj.get_var('rions') + return obj.get_var('modb') / np.sqrt(r) + + elif quant == 'vai': + r = obj('r') + b = obj('modb') + return b / np.sqrt(r) # [simu speed units]. note: mu0 = 1 in simu units. + + elif quant in ['hx', 'hy', 'hz', 'kx', 'ky', 'kz']: + axis = quant[-1] + var = obj.get_var('p' + axis + 'c') + if quant[0] == 'h': + # anyone can delete this warning once you have confirmed that get_var('hx') does what you think it should: + warnmsg = ('get_var(hx) (or hy or hz) uses get_var(p), and used it since before get_var(p) was implemented. ' + 'Maybe should be using get_var(mfe_p) instead? ' + 'You should not trust results until you check this. - SE Apr 19 2021.') + if obj.verbose: + warnings.warn(warnmsg) + return ((obj.get_var('e') + obj.get_var('p')) / + obj.get_var('r') * var) + else: + return obj.get_var('u2') * var * 0.5 + + elif quant == 'sgyrof': + B = obj.get_var('modb') # magnitude of B [simu. B-field units] + q = obj.get_charge(obj.ifluid, units='simu') # [simu. charge units] + m = obj.get_mass(obj.mf_ispecies, units='simu') # [simu. mass units] + return q * B / m + + elif quant == 'gyrof': + return np.abs(obj.get_var('sgyrof')) + + elif quant == 'skappa': + gyrof = obj.get_var('sgyrof') # [simu. freq.] + nu_sn = obj.get_var('nu_sn') # [simu. freq.] + return gyrof / nu_sn + + elif quant == 'kappa': + return np.abs(obj.get_var('skappa')) + + elif quant == 'ldebyei': + Zi2 = obj.get_charge(obj.ifluid)**2 + if Zi2 == 0: + return obj.zero_at_mesh_center() + const = obj.uni.permsi * obj.uni.ksi_b / obj.uni.qsi_electron**2 + tg = obj.get_var('tg') # [K] + nr = obj.get_var('nr') * obj.uni.usi_nr # [m^-3] + ldebsi = np.sqrt(const * tg / (nr * Zi2)) # [m] + return ldebsi / obj.uni.usi_l # [simu. length units] + + elif quant == 'ldebye': + # ldebye = 1/sum_j( (1/ldebye_j) for j in fluids and electrons) + ldeb_inv_sum = 1/obj.get_var('ldebyei', mf_ispecies=-1) + for fluid in obj.fluids.ions(): + ldeb_inv_sum += 1/obj.get_var('ldebyei', ifluid=fluid.SL) + return 1/ldeb_inv_sum + + elif quant == 'meanfreepath': + ui = obj('ui_mod') + nu = obj('nu_ssum') + return ui / nu + + elif quant == 'gyroradius': + ui = obj('ui_mod') + omega = obj('gyrof') + return ui / omega + + else: + raise NotImplementedError(f'{repr(quant)} in get_mf_plasmaparam') + + +# default +_FUNDAMENTALS = ('r', 'px', 'py', 'pz', 'e', 'bx', 'by', 'bz') +_HD_Fs = ('part', # part --> only get the internal part. e.g. nu1 * Cfast. + *_FUNDAMENTALS) +_HD_QUANTS = ['hd1_part', 'hd2_part'] # << without the factor of nu1, nu2 +_HD_QUANTS += ['hd1_partnu', 'hd2_partnu'] # << include the factor of nu1, nu2 +_HD_QUANTS += [f'hd3{x}_part' for x in AXES] + [f'hd3{x}_bpart' for x in AXES] # << without the factor of nu3 +_HD_QUANTS += [f'hd3{x}_partnu' for x in AXES] + [f'hd3{x}_bpartnu' for x in AXES] # << include the factor of nu3 +_HD_QUANTS += [f'hd{x}quench_{f}' for x in AXES for f in _FUNDAMENTALS] # Q(∂f/∂x) +_HD_QUANTS += [f'hd{x}coeff_{f}' for x in AXES for f in _FUNDAMENTALS] # nu dx (∂f/∂x) * Q(∂f/∂x) +_HD_QUANTS += [f'{d}hd{n}{x}_{f}' for d in ('', 'd') # E.g. hd1x_r == hd1_part * nu dx (∂r/∂x) * Q(∂r/∂x) + for n in (1, 2, 3) # and dhd1x_r == ∂[hd1_part * nu dx (∂r/∂x) * Q(∂r/∂x)]/∂x + for x in AXES + for f in _FUNDAMENTALS] +_HYPERDIFFUSIVE_QUANT = ('HYPERDIFFUSIVE_QUANT', _HD_QUANTS) +# get value + + +@document_vars.quant_tracking_simple(_HYPERDIFFUSIVE_QUANT[0]) +def get_hyperdiffusive_quant(obj, quant, HYPERDIFFUSIVE_QUANT=None): + '''hyperdiffusive terms. All in simu units.''' + if HYPERDIFFUSIVE_QUANT is None: + HYPERDIFFUSIVE_QUANT = _HYPERDIFFUSIVE_QUANT[1] + + if quant == '': + docvar = document_vars.vars_documenter(obj, _HYPERDIFFUSIVE_QUANT[0], HYPERDIFFUSIVE_QUANT, + get_hyperdiffusive_quant.__doc__, nfluid=1) + docvar('hd1_part', 'Cfast_i', uni=UNI_speed) + docvar('hd1_partnu', 'nu1 * Cfast_i', uni=UNI_speed) + docvar('hd2_part', '|ui|', uni=UNI_speed) + docvar('hd2_partnu', 'nu2 * |ui|', uni=UNI_speed) + for x in AXES: + docvar(f'hd3{x}_part', f'd{x} * grad1{x}(ui{x})', uni=UNI_speed) + docvar(f'hd3{x}_partnu', f'nu3 * d{x} * grad1{x}(ui{x})', uni=UNI_speed) + docvar(f'hd3{x}_bpart', f'd{x} * |grad1_perp_to_b(ui{x})|', uni=UNI_speed) + docvar(f'hd3{x}_bpartnu', f'nu3 * d{x} * |grad1_perp_to_b(ui{x})|', uni=UNI_speed) + for x in AXES: + for f in _FUNDAMENTALS: + docvar(f'hd{x}quench_{f}', f'Q(∂{f}/∂{x})', uni=DIMENSIONLESS) + docvar(f'hd{x}coeff_{f}', f'nu d{x} (∂{f}/∂{x}) * Q(∂{f}/∂{x})', uni=UNI.qc(0)) + for x in AXES: + for n in (1, 2, 3): + for f in _FUNDAMENTALS: + if n == 3 and f.startswith('b'): + docvar(f'hd{n}{x}_{f}', f'nu{n} * hd{n}_bpart * hd{x}coeff_{f}', ) # uni=[TODO] + else: + docvar(f'hd{n}{x}_{f}', f'nu{n} * hd{n}_part * hd{x}coeff_{f}', ) # uni=[TODO] + docvar(f'dhd{n}{x}_{f}', f'∂[hd{n}{x}_{f}]/∂{x}', ) + return None + + if quant not in HYPERDIFFUSIVE_QUANT: + return None + + # nu1 term + if quant == 'hd1_part': + return obj('cfast') + elif quant == 'hd1_partnu': + return obj('hd1_part') * obj.get_param('nu1') + + # nu2 term + elif quant == 'hd2_part': + return obj('ui_mod') + elif quant == 'hd2_partnu': + return obj('hd2_part') * obj.get_param('nu2') + + # nu3 term + elif quant.startswith('hd3') and quant in (f'hd3{x}_part' for x in AXES): + x = quant[len('hd3')+0] # 'x', 'y', or 'z' + # dx * grad1x (uix) + raise NotImplementedError(f'hd3{x}_part') + + elif quant.startswith('hd3') and quant in (f'hd3{x}_bpart' for x in AXES): + x = quant[len('hd3')+0] # 'x', 'y', or 'z' + # dx * |grad1_perp_to_b(ui{x})| + raise NotImplementedError(f'hd3{x}_bpart') + + elif quant.starstwith('hd3') and quant in (f'hd3{x}_{b}partnu' for x in AXES for b in ('', 'b')): + part_without_nu = quant[:-len('nu')] + return obj(part_without_nu) * obj.get_param('nu3') + + # quench term + elif quant.startswith('hd') and quant.partition('_')[0] in (f'hd{x}quench' for x in AXES): + base, _, f = quant.partition('_') + x = base[len('hd')] # 'x', 'y', or 'z' + fval = obj(f) # value of f, e.g. value of 'r' or 'bx' + # Q(∂f/∂x) + raise NotImplementedError(f'hd{x}quench_{f}') + + # coeff term + elif quant.startswith('hd') and quant.partition('_')[0] in (f'hd{x}coeff' for x in AXES): + base, _, f = quant.partition('_') + x = base[len('hd')] # 'x', 'y', or 'z' + nu = NotImplemented # << TODO + dx = obj.dx # << TODO (allow to vary in space) + quench = obj(f'hd{x}quench_{f}') + dfdx = obj(f'd{f}dxdn') + return nu * dx * dfdx * quench + + # full hd term + elif quant.startswith('hd') and quant.partition('_')[0] in (f'hd{n}{x}' for x in AXES for n in (1, 2, 3)): + base, _, f = quant.partition('_') + n, x = base[2:4] + nu = obj.get_param(f'nu{n}') + if n == 3 and f.startswith('b'): + hd_part = obj('hd3_bpart') + else: + hd_part = obj(f'hd{n}_part') + coeff = obj(f'hd{x}coeff_{f}') + return nu * hd_part * coeff + + # full hd term, with derivative + elif quant.startswith('dhd') and quant.partition('_')[0] in (f'dhd{n}{x}' for x in AXES for n in (1, 2, 3)): + quant_str = quant[1:] + return obj('d'+quant_str+'dxdn') + + +# default +_WAVE_QUANT = ('WAVE_QUANT', + ['ci', 'fplasma', 'kmaxx', 'kmaxy', 'kmaxz'] + ) +# get value + + +@document_vars.quant_tracking_simple(_WAVE_QUANT[0]) +def get_mf_wavequant(obj, quant, WAVE_QUANT=None): + '''quantities related most directly to waves in plasmas.''' + if WAVE_QUANT is None: + WAVE_QUANT = _WAVE_QUANT[1] + + if quant == '': + docvar = document_vars.vars_documenter(obj, _WAVE_QUANT[0], WAVE_QUANT, get_mf_wavequant.__doc__) + docvar('ci', "ion acoustic speed for ifluid (must be ionized) [simu. velocity units]", + nfluid=1, uni=UNI_speed) + docvar('fplasma', "('angular') plasma frequency for ifluid (must be charged) [simu. frequency units]. " + + "== sqrt(n_i q_i**2 / (epsilon_0 m_i))", nfluid=1, uni=UNI_hz) + for x in AXES: + docvar('kmax'+x, "maximum resolvable wavevector in "+x+" direction. Determined via 2*pi/obj.d"+x+"1d", + nfluid=0, uni=UNI_length) + return None + + if quant not in _WAVE_QUANT[1]: + return None + + if quant == 'ci': + assert obj.mf_ispecies != -1, "ifluid {} must be ion to get ci, but got electron.".format(obj.ifluid) + fluids = obj.fluids + ion = fluids[obj.ifluid] + assert ion.ionization >= 1, "ifluid {} is not ionized; cannot get ci (==ion acoustic speed).".format(obj.ifluid) + # (we only want to get ion acoustic speed for ions; it doesn't make sense to get it for neutrals.) + tg_i = obj.get_var('tg') # [K] temperature of fluid + tg_e = obj.get_var('tg', mf_ispecies=-1) # [K] temperature of electrons + igamma = obj.uni.gamma # gamma (ratio of specific heats) of fluid + egamma = obj.uni.gamma # gamma (ratio of specific heats) of electrons + m_i = obj.get_mass(ion, units='si') # [kg] mass of ions + ci = np.sqrt(obj.uni.ksi_b * (ion.ionization * igamma * tg_i + egamma * tg_e) / m_i) + ci_sim = ci / obj.uni.usi_u + return ci_sim + + elif quant == 'fplasma': + q = obj.get_charge(obj.ifluid, units='si') + assert q != 0, "ifluid {} must be charged to get fplasma.".format(obj.ifluid) + m = obj.get_mass(obj.ifluid, units='si') + eps0 = obj.uni.permsi + n = obj('nr') + unit = 1 / obj.uni.usi_hz # convert from si frequency to ebysus frequency. + consts = np.sqrt(q**2 / (eps0 * m)) * unit + return consts * np.sqrt(n) # [ebysus frequency units] + + elif quant in ['kmaxx', 'kmaxy', 'kmaxz']: + x = quant[-1] # axis; 'x', 'y', or 'z'. + xidx = dict(x=0, y=1, z=2)[x] # axis; 0, 1, or 2. + dx1d = getattr(obj, 'd'+x+'1d') # 1D; needs dims to be added. add dims below. + dx1d = np.expand_dims(dx1d, axis=tuple(set((0, 1, 2)) - set([xidx]))) + return (2 * np.pi / dx1d) + obj.zero() + + else: + raise NotImplementedError(f'{repr(quant)} in get_mf_wavequant') + + +# default +_FB_INSTAB_QUANT = ['psi0', 'psii', 'vde', 'fb_ssi_vdtrigger', 'fb_ssi_possible', + 'fb_ssi_freq', 'fb_ssi_growth_rate'] +_FB_INSTAB_VECS = ['fb_ssi_freq_max', 'fb_ssi_growth_rate_max', 'fb_ssi_growth_time_min'] +_FB_INSTAB_QUANT += [v+x for v in _FB_INSTAB_VECS for x in AXES] +_FB_INSTAB_QUANT = ('FB_INSTAB_QUANT', _FB_INSTAB_QUANT) +# get value + + +@document_vars.quant_tracking_simple(_FB_INSTAB_QUANT[0]) +def get_fb_instab_quant(obj, quant, FB_INSTAB_QUANT=None): + '''very specific quantities which are related to the Farley-Buneman instability.''' + if FB_INSTAB_QUANT is None: + FB_INSTAB_QUANT = _FB_INSTAB_QUANT[1] + + if quant == '': + docvar = document_vars.vars_documenter(obj, _FB_INSTAB_QUANT[0], FB_INSTAB_QUANT, get_fb_instab_quant.__doc__) + for psi in ['psi0', 'psii']: + docvar(psi, 'psi_i when k_parallel==0. equals to: (kappa_e * kappa_i)^-1.', nfluid=1, uni=DIMENSIONLESS) + docvar('vde', 'electron drift velocity. equals to: |E|/|B|. [simu. velocity units]', nfluid=0, uni=UNI_speed) + docvar('fb_ssi_vdtrigger', 'minimum vde [in simu. velocity units] above which the FB instability can grow, ' + + 'in the case of SSI (single-species-ion). We assume ifluid is the single ion species.', nfluid=1, uni=UNI_speed) + docvar('fb_ssi_possible', 'whether SSI Farley Buneman instability can occur (vde > fb_ssi_vdtrigger). ' + + 'returns an array of booleans, with "True" meaning "can occur at this point".', nfluid=1, uni=DIMENSIONLESS) + docvar('fb_ssi_freq', 'SSI FB instability wave frequency (real part) divided by wavenumber (k). ' + + 'assumes wavevector in E x B direction. == Vd / (1 + psi0). ' + + 'result is in units of [simu. frequency * simu. length].', nfluid=2, uni=UNI_speed) + docvar('fb_ssi_growth_rate', 'SSI FB instability growth rate divided by wavenumber (k) squared. ' + + 'assumes wavevector in E x B direction. == (Vd^2/(1+psi0)^2 - Ci^2)/(nu_in*(1+1/psi0)). ' + + 'result is in units of [simu. frequency * simu. length^2].', nfluid=2, uni=UNI_hz * UNI_length**2) + for x in AXES: + docvar('fb_ssi_freq_max'+x, 'SSI FB instability max frequency in '+x+' direction ' + + '[simu. frequency units]. calculated using fb_ssi_freq * kmax'+x, nfluid=2, uni=UNI_hz) + for x in AXES: + docvar('fb_ssi_growth_rate_max'+x, 'SSI FB instability max growth rate in '+x+' direction ' + + '[simu. frequency units]. calculated using fb_ssi_growth_rate * kmax'+x, nfluid=2, uni=UNI_hz) + for x in AXES: + docvar('fb_ssi_growth_time_min'+x, 'SSI FB instability min growth time in '+x+' direction ' + + '[simu. time units]. This is the amount of time it takes for the wave amplitude for the wave ' + + 'with the largest wave vector to grow by a factor of e. == 1/fb_ssi_growth_rate_max'+x, nfluid=2, uni=UNI_time) + + return None + + if quant not in FB_INSTAB_QUANT: + return None + + elif quant in ['psi0', 'psii']: + kappa_i = obj.get_var('kappa') + kappa_e = obj.get_var('kappa', mf_ispecies=-1) + return 1./(kappa_i * kappa_e) + + elif quant == 'vde': + modE = obj.get_var('modef') # [simu. E-field units] + modB = obj.get_var('modb') # [simu. B-field units] + return modE / modB # [simu. velocity units] + + elif quant == 'fb_ssi_vdtrigger': + icharge = obj.get_charge(obj.ifluid) + assert icharge > 0, "expected ifluid to be an ion but got ifluid charge == {}".format(icharge) + ci = obj.get_var('ci') # [simu. velocity units] + psi0 = obj.get_var('psi0') + return ci * (1 + psi0) # [simu. velocity units] + + elif quant == 'fb_ssi_possible': + return obj.get_var('vde') > obj.get_var('fb_ssi_vdtrigger') + + elif quant == 'fb_ssi_freq': + icharge = obj.get_charge(obj.ifluid) + assert icharge > 0, "expected ifluid to be an ion but got ifluid charge == {}".format(icharge) + jcharge = obj.get_charge(obj.jfluid) + assert jcharge == 0, "expected jfluid to be neutral but got jfluid charge == {}".format(jcharge) + # freq (=real part of omega) = Vd * k_x / (1 + psi0) + Vd = obj.get_var('vde') + psi0 = obj.get_var('psi0') + return Vd / (1 + psi0) + + elif quant == 'fb_ssi_growth_rate': + # growth rate = ((omega_r/k_x)^2 - Ci^2) * (k_x)^2/(nu_in*(1+1/psi0)) + w_r_k = obj.get_var('fb_ssi_freq') # omega_r / k_x + psi0 = obj.get_var('psi0') + Ci = obj.get_var('ci') + nu_in = obj.get_var('nu_ij') + return (w_r_k**2 - Ci**2) / (nu_in * (1 + 1/psi0)) + + elif quant in ['fb_ssi_freq_max'+x for x in AXES]: + x = quant[-1] + freq = obj.get_var('fb_ssi_freq') + kmaxx = obj.get_var('kmax'+x) + return kmaxx**2 * freq + + elif quant in ['fb_ssi_growth_rate_max'+x for x in AXES]: + x = quant[-1] + growth_rate = obj.get_var('fb_ssi_growth_rate') + kmaxx = obj.get_var('kmax'+x) + return kmaxx**2 * growth_rate + + elif quant in ['fb_ssi_growth_time_min'+x for x in AXES]: + x = quant[-1] + return 1/obj.get_var('fb_ssi_growth_rate_max'+x) + + else: + raise NotImplementedError(f'{repr(quant)} in get_fb_instab_quant') + + +# default +_THERMAL_INSTAB_QUANT = ['thermal_growth_rate', + 'thermal_freq', 'thermal_tan2xopt', + 'thermal_xopt', 'thermal_xopt_rad', 'thermal_xopt_deg', + 'ethermal_s0', 'ethermal_tan2xopt', + 'ethermal_xopt', 'ethermal_xopt_rad', 'ethermal_xopt_deg', + ] +_THERMAL_INSTAB_VECS = ['thermal_u0', 'thermal_v0'] +_THERMAL_INSTAB_QUANT += [v+x for v in _THERMAL_INSTAB_VECS for x in AXES] +# add thermal_growth_rate with combinations of terms. +# LEGACY: assumes we are using optimal angle for ion thermal plus FB effects; +# the code implements a formula where the optimal angle is plugged in already. +# NON-LEGACY: allows to plug in the optimal angle. +_LEGACY_THERMAL_GROWRATE_QUANTS = ['legacy_thermal_growth_rate' + x for x in ['', '_fb', '_thermal', '_damping']] +_LEGACY_THERMAL_GROWRATE_QUANTS += [quant+'_max' for quant in _LEGACY_THERMAL_GROWRATE_QUANTS] +_THERMAL_INSTAB_QUANT += _LEGACY_THERMAL_GROWRATE_QUANTS + +_THERMAL_GROWRATE_QUANTS = ['thermal_growth_rate' + x for x in ['', '_fb', '_thermal', '_damping']] +_THERMAL_GROWRATE_QUANTS += [quant+'_max' for quant in _THERMAL_GROWRATE_QUANTS] +_THERMAL_INSTAB_QUANT += _THERMAL_GROWRATE_QUANTS +_ETHERMAL_GROWRATE_QUANTS = ['ethermal_growth_rate' + x for x in ['', '_fb', '_it', '_et', '_damping']] +_ETHERMAL_GROWRATE_QUANTS += [quant+'_max' for quant in _ETHERMAL_GROWRATE_QUANTS] +_THERMAL_INSTAB_QUANT += _ETHERMAL_GROWRATE_QUANTS + +_THERMAL_INSTAB_QUANT = ('THERMAL_INSTAB_QUANT', _THERMAL_INSTAB_QUANT) +# get_value + + +@document_vars.quant_tracking_simple(_THERMAL_INSTAB_QUANT[0]) +def get_thermal_instab_quant(obj, quant, THERMAL_INSTAB_QUANT=None): + '''very specific quantities which are related to the ion thermal and/or electron thermal instabilities. + For source of formulae, see paper by Dimant & Oppenheim, 2004. + + In general, ion ifluid --> calculate for ion thermal instability; electron fluid --> for electron thermal. + Electron thermal is not yet implemented. + + Quantities which depend on two fluids expect ifluid to be ion or electron, and jfluid to be neutral. + ''' + if THERMAL_INSTAB_QUANT is None: + THERMAL_INSTAB_QUANT = _THERMAL_INSTAB_QUANT[1] + + if quant == '': + docvar = document_vars.vars_documenter(obj, _THERMAL_INSTAB_QUANT[0], THERMAL_INSTAB_QUANT, + get_thermal_instab_quant.__doc__, nfluid=1) + # document ion thermal stuff. (Intrinsically coupled to FB and diffusion effects.) + for thermal_xopt_rad in ['thermal_xopt', 'thermal_xopt_rad']: + docvar(thermal_xopt_rad, 'thermal instability optimal angle between k and (Ve - Vi) to maximize growth.' + + 'result will be in radians. Result will be between -pi/4 and pi/4.', nfluid=1, + uni_f=UNITS_FACTOR_1, uni_name='radians') + docvar('thermal_xopt_deg', 'thermal instability optimal angle between k and (Ve - Vi) to maximize growth.' + + 'result will be in degrees. Result will be between -45 and 45.', nfluid=1, + uni_f=UNITS_FACTOR_1, uni_name='degrees') + docvar('thermal_tan2xopt', 'tangent of 2 times thermal_xopt', nfluid=1, uni=DIMENSIONLESS) + for x in AXES: + docvar('thermal_u0'+x, x+'-component of (Ve - Vi). Warning: proper interpolation not yet implemented.', + nfluid=1, uni=UNI_speed) + for x in AXES: + docvar('thermal_v0'+x, x+'-component of E x B / B^2. Warning: proper interpolation not yet implemented.', + nfluid=0, uni=UNI_speed) + # document electron thermal stuff. (Intrinsically coupled to ion thermal, FB, and diffusion effects.) + docvar('ethermal_s0', 'S0 = S / sin(2 * ethermal_xopt). (Used in calculated ethermal effects.)' + + 'ifluid must be ion; jfluid must be neutral.', nfluid=2, uni=DIMENSIONLESS) + docvar('ethermal_tan2xopt', 'tangent of 2 times ethermal_xopt', nfluid=2, uni=DIMENSIONLESS) + docvar('ethermal_xopt', 'ethermal instability optimal angle between k and (Ve - Vi) to maximize growth, ' + + 'when accounting for ion thermal, electron thermal, and Farley-Buneman effects. ' + + 'result will be in radians, and between -pi/4 and pi/4.', nfluid=2, + uni_f=UNITS_FACTOR_1, uni_name='radians') + # document ion thermal growrate terms + for growquant in _LEGACY_THERMAL_GROWRATE_QUANTS + _THERMAL_GROWRATE_QUANTS + _ETHERMAL_GROWRATE_QUANTS: + # build docstring depending on growquant. final will be sGR + sMAX + sINC + sLEG. + # determine if MAX. + q, ismax, m = growquant.partition('_max') + if m != '': + continue # looks like '..._max_moretext'. Unrecognized format. Don't document. + if ismax: + sMAX = 'using wavenumber = 2*pi/(pixel width). result units are [simu. frequency].' + units = UNI_hz + else: + sMAX = 'divided by wavenumber squared. result units are [(simu. frequency) * (simu. length)^2].' + units = UNI_hz * UNI_length**2 + # split (take away the 'thermal_growth_rate' part). + e, _, terms = q.partition('thermal_growth_rate') + # determine if electron thermal is included. + if e == 'e': + sGR = 'Optimal growth rate for [electron thermal plus ion thermal plus Farley-Buneman] instability ' + nfluid = 2 + else: + sGR = 'Optimal growth rate for [ion thermal plus Farley-Buneman] instability ' + nfluid = 1 + # determine which term or terms we are INCluding + if terms == '': + sINC = '' + elif terms == '_fb': + sINC = ' Includes only the FB term. (No thermal, no damping.)' + elif terms == '_damping': + sINC = ' Includes only the damping term. (No FB, no thermal.)' + elif terms == '_thermal': # (only available if electron thermal is NOT included) + sINC = ' Includes only the ion thermal term. (No FB, no damping.)' + elif terms == '_it': # (only available if electron thermal IS included) + sINC = ' Includes only the ion thermal term. (No FB, no damping, no electron thermal.)' + elif terms == '_et': # (only available if electron thermal IS included) + sINC = ' Includes only the electron thermal term. (No FB, no damping, no ion thermal.)' + # determine if LEGacy. + if q.startswith('legacy'): + sLEG = ' Calculated using formula where optimal angle has already been entered.' + else: + sLEG = '' + # actually document. + docvar(growquant, sGR + sMAX + sINC + sLEG, nfluid=nfluid, uni=units) + return None + + # if quant not in THERMAL_INSTAB_QUANT: + # return None + + def check_fluids_ok(nfluid=1): + '''checks that ifluid is ion and jfluid is neutral. Only checks up to nfluid. raise error if bad.''' + if nfluid >= 1: + icharge = obj.get_charge(obj.ifluid) + if icharge <= 0: + raise ValueError('Expected ion ifluid for Thermal Instability quants, but got charge(ifluid)={}.'.format(icharge)) + if nfluid >= 2: + jcharge = obj.get_charge(obj.jfluid) + if jcharge != 0: + raise ValueError('Expected neutral jfluid but got charge(jfluid)={}.'.format(jcharge)) + return True + + if '_max' in quant: + if 'thermal_growth_rate' in quant: + q_without_max = quant.replace('_max', '') + k2 = max(obj.get_kmax())**2 # units [simu. length^-2]. (cancels with the [simu. length^2] from non-"_max" quant.) + result = obj.get_var(q_without_max) + result *= k2 + return result + + elif quant.startswith('thermal_growth_rate') or quant.startswith('ethermal_growth_rate'): + # determine included terms. + if quant == 'ethermal_growth_rate': + include_terms = ['fb', 'it', 'et', 'damping'] + elif quant == 'thermal_growth_rate': + include_terms = ['fb', 'thermal', 'damping'] + else: + include_terms = quant.split('_')[3:] + # do prep work; calculate coefficient which is in front of all terms. + psi = obj.get_var('psi0') + U02 = obj.get_var('thermal_u02') # U_0^2 + nu_in = obj.get_var('nu_sn') # sum_{j for j in neutrals and j!=i} (nu_ij) + front_coeff = psi * U02 / ((1 + psi) * nu_in) # leading coefficient (applies to all terms) + # do prep work; calculate components which appear in multiple terms. + + def any_included(*check): + return any((x in include_terms for x in check)) + if any_included('fb', 'thermal', 'it'): + kappai = obj.get_var('kappa') + if any_included('et', 'damping'): + Cs = obj.get_var('ci') + Cs2_over_U02 = Cs**2 / U02 + if any_included('fb', 'thermal', 'it', 'et'): + s_getX = 'ethermal_xopt' if quant.startswith('e') else 'thermal_xopt' + X = obj.get_var(s_getX) + cosX = np.cos(X) + sinX = np.sin(X) + if any_included('et'): + sin2X = np.sin(2*X) + S0 = obj.get_var('ethermal_s0') + # calculating terms and returning result. + result = obj.zero() + if any_included('fb'): + term_fb = (1 - kappai**2) * cosX**2 / (1 + psi)**2 + result += term_fb + if any_included('thermal', 'it'): + term_it = (2 / 3) * (kappai**2 * cosX**2 - kappai * cosX * sinX) / (1 + psi) + result += term_it + if any_included('et'): + term_et = (2 / 3) * Cs2_over_U02 * (S0**2 * sin2X**2 - (17/5) * S0 * sin2X) + result += term_et + if any_included('damping'): + term_dp = -1 * Cs2_over_U02 + result += term_dp + result *= front_coeff + return result + + elif quant.startswith('legacy_thermal_growth_rate'): + if quant == 'legacy_thermal_growth_rate': + include_terms = ['fb', 'thermal', 'damping'] + else: + include_terms = quant.split('_')[3:] + # prep work + psi = obj.get_var('psi0') + U02 = obj.get_var('thermal_u02') # U_0^2 + nu_in = obj.get_var('nu_sn') # sum_{j for j in neutrals and j!=i} (nu_ij) + front_coeff = psi * U02 / ((1 + psi) * nu_in) # leading coefficient (applies to all terms) + if 'fb' in include_terms or 'thermal' in include_terms: + # if calculating fb or thermal terms, need to know these values: + ki2 = obj.get_var('kappa')**2 # kappa_i^2 + A = (8 + (1 - ki2)**2 + 4 * psi * ki2)**(-1/2) + # calculating terms + result = obj.zero() + if 'fb' in include_terms: + fbterm = (1 - ki2) * (1 + (3 - ki2) * A) / (2 * (1 + psi)**2) + result += fbterm + if 'thermal' in include_terms: + thermterm = ki2 * (1 + (4 - ki2 + psi) * A) / (3 * (1 + psi)) + result += thermterm + if 'damping' in include_terms: + Cs = obj.get_var('ci') + dampterm = -1 * Cs**2 / U02 + result += dampterm + # multiply by leading coefficient + result *= front_coeff + return result + + elif quant in ['thermal_u0'+x for x in AXES]: + check_fluids_ok(nfluid=1) + # TODO: handle interpolation properly. + x = quant[-1] + qi = obj.get_charge(obj.ifluid, units='simu') + efx = obj.get_var('ef'+x) + mi = obj.get_mass(obj.ifluid, units='simu') + nu_in = obj.get_var('nu_sn') + Vix = qi * efx / (mi * nu_in) + V0x = obj.get_var('thermal_v0'+x) + ki2 = obj.get_var('kappa')**2 + return (V0x - Vix) / (1 + ki2) + + elif quant in ['thermal_v0'+x for x in AXES]: + # TODO: handle interpolation properly. + x = quant[-1] + ExB__x = obj.get_var('eftimesb'+x) + B2 = obj.get_var('b2') + return ExB__x/B2 + + elif quant == 'thermal_tan2xopt': + check_fluids_ok(nfluid=1) + ki = obj.get_var('kappa') + psi = obj.get_var('psi0') + return 2 * ki * (1 + psi) / (ki**2 - 3) + + elif quant in ['thermal_xopt', 'thermal_xopt_rad']: + # TODO: think about which results are being dropped because np.arctan is not multi-valued. + return 0.5 * np.arctan(obj.get_var('thermal_tan2xopt')) + + elif quant == 'thermal_xopt_deg': + return np.rad2deg(obj.get_var('thermal_xopt_rad')) + + # begin electron thermal quants + elif quant == 'ethermal_s0': + # S = S0 sin(2 * xopt). + # S0 = (psi / (1+psi)) (gyro_i / (delta_en nu_en)) (V0^2 / C_s^2) + check_fluids_ok(nfluid=2) + psi = obj.get_var('psi0') + gyroi = obj.get_var('gyrof') + with obj.MaintainFluids(): + nu_en = obj.get_var('nu_sn', iS=-1) + m_n = obj.get_mass(obj.mf_jspecies, units='amu') + m_e = obj.get_mass(-1, units='amu') + delta_en = 2 * m_e / (m_e + m_n) + V02 = obj.get_var('thermal_v02') # V0^2 + Cs2 = obj.get_var('ci')**2 # Cs^2 + factor1 = psi / (1 + psi) + factor2 = gyroi / (delta_en * nu_en) + factor3 = V02 / Cs2 + return factor1 * factor2 * factor3 + + elif quant == 'ethermal_tan2xopt': + # optimal tan(2\chi)... see Sam's derivation (available upon request). + # (Assumes |4 * thermal_s0 * sin(2 xopt)| << 34 / 5) + check_fluids_ok(nfluid=1) + S0 = obj.get_var('ethermal_s0') + kappai = obj.get_var('kappa') + psi = obj.get_var('psi0') + U02 = obj.get_var('thermal_u02') # V0^2 + Cs2 = obj.get_var('ci')**2 # Cs^2 + + c1 = - (1 - kappai**2)/(1 + psi)**2 - 2 * kappai**2 / (3 * (1 + psi)) + c2 = - 2 * kappai / (3 * (1 + psi)) - (68 / 15) * (Cs2 / U02) * S0 + return - c2 / c1 + + elif quant in ['ethermal_xopt', 'ethermal_xopt_rad']: + # TODO: think about which results are being dropped because np.arctan is not multi-valued. + return 0.5 * np.arctan(obj.get_var('ethermal_tan2xopt')) diff --git a/helita/sim/load_noeos_quantities.py b/helita/sim/load_noeos_quantities.py new file mode 100644 index 00000000..b95e03f3 --- /dev/null +++ b/helita/sim/load_noeos_quantities.py @@ -0,0 +1,28 @@ +from . import document_vars + + +def load_noeos_quantities(obj, quant, *args, EOSTAB_QUANT=None, **kwargs): + + quant = quant.lower() + + document_vars.set_meta_quant(obj, 'noeosquantities', 'Computes some variables without EOS tables') + + val = get_eosparam(obj, quant, EOSTAB_QUANT=EOSTAB_QUANT) + + return val + + +def get_eosparam(obj, quant, EOSTAB_QUANT=None): + '''Computes some variables without EOS tables ''' + if (EOSTAB_QUANT == None): + EOSTAB_QUANT = ['ne'] + + docvar = document_vars.vars_documenter(obj, 'EOSTAB_QUANT', EOSTAB_QUANT, get_eosparam.__doc__) + docvar('ne', "electron density [cm^-3]") + + if (quant == '') or not quant in EOSTAB_QUANT: + return None + + nh = obj.get_var('rho') / obj.uni.grph + + return nh + 2.*nh*(obj.uni.grph/obj.uni.m_h-1.) # this may need a better adjustment. diff --git a/helita/sim/load_quantities.py b/helita/sim/load_quantities.py new file mode 100644 index 00000000..c6a90286 --- /dev/null +++ b/helita/sim/load_quantities.py @@ -0,0 +1,1967 @@ +# import builtins +import warnings + +# import external public modules +import numpy as np + +# import internal modules +from . import document_vars, tools +from .load_arithmetic_quantities import do_stagger + +# from glob import glob # this is only used for find_first_match which is never called... + + +try: + from numba import jit, njit, prange +except ImportError: + numba = prange = tools.ImportFailed('numba', "This module is required to use stagger_kind='numba'.") + jit = njit = tools.boring_decorator + +# import the potentially-relevant things from the internal module "units" +from .units import UNI_nr + +DEFAULT_ELEMLIST = ['h', 'he', 'c', 'o', 'ne', 'na', 'mg', 'al', 'si', 's', 'k', 'ca', 'cr', 'fe', 'ni'] + +# setup DEFAULT_CROSS_DICT +cross_dict = dict() +cross_dict['h1', 'h2'] = cross_dict['h2', 'h1'] = 'p-h-elast.txt' +cross_dict['h2', 'h22'] = cross_dict['h22', 'h2'] = 'h-h2-data.txt' +cross_dict['h2', 'he1'] = cross_dict['he1', 'h2'] = 'p-he.txt' +cross_dict['e', 'he1'] = cross_dict['he1', 'e'] = 'e-he.txt' +cross_dict['e', 'h1'] = cross_dict['h1', 'e'] = 'e-h.txt' +DEFAULT_CROSS_DICT = cross_dict +del cross_dict + +# set constants + +POLARIZABILITY_DICT = { # polarizability (used in maxwell collisions) + 'h': 6.68E-31, + 'he': 2.05E-31, + 'li': 2.43E-29, + 'be': 5.59E-30, + 'b': 3.04E-30, + 'c': 1.67E-30, + 'n': 1.10E-30, + 'o': 7.85E-31, + 'f': 5.54E-31, + 'ne': 3.94E-31, + 'na': 2.41E-29, + 'mg': 1.06E-29, + 'al': 8.57E-30, + 'si': 5.53E-30, + 'p': 3.70E-30, + 's': 2.87E-30, + 'cl': 2.16E-30, + 'ar': 1.64E-30, + 'k': 4.29E-29, + 'ca': 2.38E-29, + 'sc': 1.44E-29, + 'ti': 1.48E-29, + 'v': 1.29E-29, + 'cr': 1.23E-29, + 'mn': 1.01E-29, + 'fe': 9.19E-30, + 'co': 8.15E-30, + 'ni': 7.26E-30, + 'cu': 6.89E-30, + 'zn': 5.73E-30, + 'ga': 7.41E-30, + 'ge': 5.93E-30, + 'as': 4.45E-30, + 'se': 4.28E-30, + 'br': 3.11E-30, + 'kr': 2.49E-30, + 'rb': 4.74E-29, + 'sr': 2.92E-29, + 'y': 2.40E-29, + 'zr': 1.66E-29, + 'nb': 1.45E-29, + 'mo': 1.29E-29, + 'tc': 1.17E-29, + 'ru': 1.07E-29, + 'rh': 9.78E-30, + 'pd': 3.87E-30, +} + + +whsp = ' ' + + +def set_elemlist_as_needed(obj, elemlist=None, ELEMLIST=None, **kwargs): + ''' set_elemlist if appropriate. Accepts 'elemlist' or 'ELEMLIST' kwargs. ''' + # -- get elemlist. Could be entered as 'elemlist' or 'ELEMLIST' -- # + if elemlist is None: + elemlist = ELEMLIST # ELEMLIST is alias for elemlist. + # -- if obj.ELEMLIST doesn't exist (first time setting ELEMLIST) -- # + if not hasattr(obj, 'ELEMLIST'): + if elemlist is None: + # << if we reach this line it means elemlist wasn't entered as a kwarg. + elemlist = DEFAULT_ELEMLIST # so, use the default. + + if elemlist is None: + # << if we reach this line, elemlist wasn't entered, + # AND obj.ELEMLIST exists (so elemlist has been set previously). + # So, do nothing and return None. + return None + else: + return set_elemlist(obj, elemlist) + + +def set_elemlist(obj, elemlist=DEFAULT_ELEMLIST): + ''' sets all things which depend on elemlist, as attrs of obj. + Also sets obj.set_elemlist to partial(set_elemlist(obj)). + ''' + obj.ELEMLIST = elemlist + obj.CROSTAB_LIST = ['e_'+elem for elem in obj.ELEMLIST] \ + + [elem+'_e' for elem in obj.ELEMLIST] \ + + [e1 + '_' + e2 for e1 in obj.ELEMLIST for e2 in obj.ELEMLIST] + obj.COLFRE_QUANT = ['nu' + clist for clist in obj.CROSTAB_LIST] \ + + ['nu%s_mag' % clist for clist in obj.CROSTAB_LIST] + + obj.COLFREMX_QUANT = ['numx' + clist for clist in obj.CROSTAB_LIST] \ + + ['numx%s_mag' % clist for clist in obj.CROSTAB_LIST] + obj.COLCOU_QUANT = ['nucou' + clist for clist in obj.CROSTAB_LIST] + obj.COLCOUMS_QUANT = ['nucou_ei', 'nucou_ii'] + obj.COLCOUMS_QUANT += ['nucou' + elem + '_i' for elem in obj.ELEMLIST] + obj.COLFRI_QUANT = ['nu_ni', 'numx_ni', 'nu_en', 'nu_ei', 'nu_in', 'nu_ni_mag', 'nu_in_mag'] + obj.COLFRI_QUANT += [nu + elem + '_' + i + for i in ('i', 'i_mag', 'n', 'n_mag') + for nu in ('nu', 'numx') + for elem in obj.ELEMLIST] + + obj.COULOMB_COL_QUANT = ['coucol' + elem for elem in obj.ELEMLIST] + obj.GYROF_QUANT = ['gfe'] + ['gf' + elem for elem in obj.ELEMLIST] + obj.KAPPA_QUANT = ['kappa' + elem for elem in obj.ELEMLIST] + obj.KAPPA_QUANT += ['kappanorm_', 'kappae'] + obj.IONP_QUANT = ['n' + elem + '-' for elem in obj.ELEMLIST] \ + + ['r' + elem + '-' for elem in obj.ELEMLIST] \ + + ['rneu', 'rion', 'nion', 'nneu', 'nelc'] \ + + ['rneu_nomag', 'rion_nomag', 'nion_nomag', 'nneu_nomag'] + + def _set_elemlist(elemlist=DEFAULT_ELEMLIST): + '''sets all things which depend on elemlist, as attrs of self.''' + set_elemlist(obj, elemlist) + obj.set_elemlist = _set_elemlist + + +def set_crossdict_as_needed(obj, **kwargs): + '''sets all things related to cross_dict. + Use None to restore default values. + + e.g. get_var(..., maxwell=None) retores to using default value for maxwell (False). + Defaults: + maxwell: False + cross_tab: None + cross_dict: + cross_dict['h1','h2'] = cross_dict['h2','h1'] = 'p-h-elast.txt' + cross_dict['h2','h22'] = cross_dict['h22','h2'] = 'h-h2-data.txt' + cross_dict['h2','he1'] = cross_dict['he1','h2'] = 'p-he.txt' + cross_dict['e','he1'] = cross_dict['he1','e'] = 'e-he.txt' + cross_dict['e','h1'] = cross_dict['h1','e'] = 'e-h.txt' + ''' + if not hasattr(obj, 'CROSS_SECTION_INFO'): + obj.CROSS_SECTION_INFO = dict() + + CSI = obj.CROSS_SECTION_INFO # alias + + DEFAULTS = dict(cross_tab=None, cross_dict=DEFAULT_CROSS_DICT, maxwell=False) + + for key in ('cross_tab', 'cross_dict', 'maxwell'): + if key in kwargs: + if kwargs[key] is None: + CSI[key] = DEFAULTS[key] + else: + CSI[key] = kwargs[key] + elif key not in CSI: + CSI[key] = DEFAULTS[key] + + +''' ----------------------------- get values of quantities ----------------------------- ''' + + +def load_quantities(obj, quant, *args__None, PLASMA_QUANT=None, CYCL_RES=None, + COLFRE_QUANT=None, COLFRI_QUANT=None, IONP_QUANT=None, + EOSTAB_QUANT=None, TAU_QUANT=None, DEBYE_LN_QUANT=None, + CROSTAB_QUANT=None, COULOMB_COL_QUANT=None, AMB_QUANT=None, + HALL_QUANT=None, BATTERY_QUANT=None, SPITZER_QUANT=None, + KAPPA_QUANT=None, GYROF_QUANT=None, WAVE_QUANT=None, + FLUX_QUANT=None, CURRENT_QUANT=None, COLCOU_QUANT=None, + COLCOUMS_QUANT=None, COLFREMX_QUANT=None, EM_QUANT=None, + POND_QUANT=None, **kwargs): + '''loads or calculates the value of single-fluid quantity quant. + + obj: HelitaData object. (e.g. BifrostData, EbysusData) + use this object for loading / calculating. + if quant depends on another quantity, call obj (or obj's get_var method). + e.g. getting 'beta' --> will call obj('b2') and obj('p'). + quant: string + the name of the quantity to get. + For help on available quantity names, use obj.vardocs(), or obj.get_var(''). + *args__None: + additional non-named arguments are passed to NOWHERE. + kwargs in function call signature (such as PLASMA_QUANT, CYCL_RES): None, list of strings, or ''. + Controls which quants are gettable by the corresponding getter function. + These are mostly intended for internal use, to allow certain HelitaData classes to + utilize load_quantities, but possibly provide different behavior for some variables. + OPTIONS: + None --> "all available quants" + can get any quant the getter function knows about. + list of strings --> "these quants" + can get any quant in this list, if the getter function knows about it. + '' --> "no quants" + the associated getter function will not get any quantities. + EXAMPLES: + - load_quantities(obj, var) + This is the default behavior of load_quantities; + all getter functions will get a chance to do their default behavior. + - load_quantities(obj, var, PLASMA_QUANT='') + In this case, obj will use all the getter funcs in load_quantities except for get_plasmaparam. + This is useful, e.g., if you want to implement a different routine for getting 'beta', + rather than letting get_plasmaparam handle getting 'beta'. + - load_quantities(obj, var, PLASMA_QUANT=['va', 'cs', 'nr']) + In this case, obj will use all the getter funcs in load_quantities, + using default behavior for all getter funcs except get_plasmaparam. + For get_plasmaparam, ONLY 'va', 'cs', and 'nr' will be handled. + This is useful, e.g., if you want to implement a different routine for getting 'beta', + but still want to utilitze get_plasmaparam's routines in case of 'va', 'cs', and 'nr'. + **kwargs: + additional kwargs (not in function call signature) are passed to the getter funcs. + ''' + + # HALL_QUANT=None, SPITZER_QUANT=None, **kwargs): + __tracebackhide__ = True # hide this func from error traceback stack. + + set_elemlist_as_needed(obj, **kwargs) + set_crossdict_as_needed(obj, **kwargs) + + quant = quant.lower() + + document_vars.set_meta_quant(obj, 'quantities', 'These are the single-fluid quantities') + + # tell which getter function is associated with each QUANT. + # (would put this list outside this function if the getter functions were defined there, but they are not.) + _getter_QUANT_pairs = ( + (get_em, 'EM_QUANT'), + (get_coulomb, 'COULOMB_COL_QUANT'), + (get_collision, 'COLFRE_QUANT'), + (get_crossections, 'CROSTAB_QUANT'), + (get_collision_ms, 'COLFRI_QUANT'), + (get_current, 'CURRENT_QUANT'), + (get_flux, 'FLUX_QUANT'), + (get_plasmaparam, 'PLASMA_QUANT'), + (get_wavemode, 'WAVE_QUANT'), + (get_cyclo_res, 'CYCL_RES'), + (get_gyrof, 'GYROF_QUANT'), + (get_kappa, 'KAPPA_QUANT'), + (get_debye_ln, 'DEBYE_LN_QUANT'), + (get_ionpopulations, 'IONP_QUANT'), + (get_ambparam, 'AMB_QUANT'), + (get_hallparam, 'HALL_QUANT'), + (get_batteryparam, 'BATTERY_QUANT'), + (get_spitzerparam, 'SPITZER_QUANT'), + (get_eosparam, 'EOSTAB_QUANT'), + (get_collcoul, 'COLCOU_QUANT'), + (get_collcoul_ms, 'COLCOUMS_QUANT'), + (get_collision_maxw, 'COLFREMX_QUANT'), + (get_ponderomotive, 'POND_QUANT'), + ) + + val = None + # loop through the function and QUANT pairs, running the functions as appropriate. + for getter, QUANT_STR in _getter_QUANT_pairs: + QUANT = locals()[QUANT_STR] # QUANT = value of input parameter named QUANT_STR. + # if QUANT == '', we are skipping this getter function (see docstring of load_quantities for more detail). + if QUANT != '': + val = getter(obj, quant, **{QUANT_STR: QUANT}, **kwargs) + if val is not None: + break + return val + + +# default +_EM_QUANT = ('EM_QUANT', ['emiss']) +# get value + + +@document_vars.quant_tracking_simple(_EM_QUANT[0]) +def get_em(obj, quant, EM_QUANT=None, *args, **kwargs): + """ + Calculates emission messure (EM). + + Parameters + ---------- + Returns + ------- + array - ndarray + Array with the dimensions of the 3D spatial from the simulation + of the emission measure c.g.s units. + """ + if EM_QUANT == '': # by entering EM_QUANT='', we are saying "skip get_em; get nothing." + return None + + if EM_QUANT == None: + EM_QUANT = _EM_QUANT[1] + + unitsnorm = 1e27 + for key, value in kwargs.items(): + if key == 'unitsnorm': + unitsnorm = value + + if quant == '': + docvar = document_vars.vars_documenter(obj, _EM_QUANT[0], EM_QUANT, get_em.__doc__) + docvar('emiss', 'emission messure [cgs]') + + if (quant == '') or not quant in EM_QUANT: + return None + + sel_units = obj.sel_units + obj.sel_units = 'cgs' + + rho = obj.get_var('totr') + en = obj.get_var('ne') + nh = rho / obj.uni.grph + + obj.sel_units = sel_units + + return en * (nh / unitsnorm) + + +# default +_CROSTAB_QUANT0 = ('CROSTAB_QUANT') +# get value + + +@document_vars.quant_tracking_simple(_CROSTAB_QUANT0) +def get_crossections(obj, quant, CROSTAB_QUANT=None, **kwargs): + ''' + Computes cross section between species in cgs + + optional kwarg: cross_dict + (can pass it to get_var. E.g. get_var(..., cross_dict=mycrossdict)) + tells which cross sections to use. + If not entered, use: + cross_dict['h1','h2'] = cross_dict['h2','h1'] = 'p-h-elast.txt' + cross_dict['h2','h22'] = cross_dict['h22','h2'] = 'h-h2-data.txt' + cross_dict['h2','he1'] = cross_dict['he1','h2'] = 'p-he.txt' + cross_dict['e','he1'] = cross_dict['he1','e'] = 'e-he.txt' + cross_dict['e','h1'] = cross_dict['h1','e'] = 'e-h.txt' + ''' + if CROSTAB_QUANT is None: + CROSTAB_QUANT = obj.CROSTAB_LIST + + if quant == '': + document_vars.vars_documenter(obj, _CROSTAB_QUANT0, CROSTAB_QUANT, get_crossections.__doc__) + + quant_elem = ''.join([i for i in quant if not i.isdigit()]) + + if (quant == '') or not quant_elem in CROSTAB_QUANT: + return None + + cross_tab = obj.CROSS_SECTION_INFO['cross_tab'] + cross_dict = obj.CROSS_SECTION_INFO['cross_dict'] + maxwell = obj.CROSS_SECTION_INFO['maxwell'] + + elem = quant.split('_') + spic1 = elem[0] + spic2 = elem[1] + spic1_ele = ''.join([i for i in spic1 if not i.isdigit()]) + spic2_ele = ''.join([i for i in spic2 if not i.isdigit()]) + + # -- try to read cross tab (unless it was entered in kwargs) -- # + if cross_tab is None: + try: + cross_tab = cross_dict[spic1, spic2] + except Exception: + if not (maxwell): + # use a guess. (Might be a bad guess...) + ww = obj.uni.weightdic + if (spic1_ele == 'h'): + cross = ww[spic2_ele] / ww['h'] * obj.uni.cross_p + elif (spic2_ele == 'h'): + cross = ww[spic1_ele] / ww['h'] * obj.uni.cross_p + elif (spic1_ele == 'he'): + cross = ww[spic2_ele] / ww['he'] * obj.uni.cross_he + elif (spic2_ele == 'he'): + cross = ww[spic1_ele] / ww['he'] * obj.uni.cross_he + else: + cross = ww[spic2_ele] / ww['h'] * obj.uni.cross_p / (np.pi*ww[spic2_ele])**2 + # make sure the guess has the right shape. + cross = obj.zero() + cross + + # -- use cross_tab to read cross at tg -- # + if cross_tab is not None: + tg = obj.get_var('tg') + crossobj = obj.cross_sect(cross_tab=[cross_tab]) + cross = crossobj.cross_tab[0]['crossunits'] * crossobj.tab_interp(tg) + + # -- return result -- # + try: + return cross + except Exception: + print('(WWW) cross-section: wrong combination of species', end="\r", + flush=True) + return None + + +# default +_EOSTAB_QUANT = ('EOSTAB_QUANT', ['ne', 'tg', 'pg', 'kr', 'eps', 'opa', 'temt', 'ent']) +# get value + + +@document_vars.quant_tracking_simple(_EOSTAB_QUANT[0]) +def get_eosparam(obj, quant, EOSTAB_QUANT=None, **kwargs): + ''' + Variables from EOS table. All of them + are in cgs except ne which is in SI. + ''' + + if (EOSTAB_QUANT == None): + EOSTAB_QUANT = _EOSTAB_QUANT[1] + + if quant == '': + docvar = document_vars.vars_documenter(obj, _EOSTAB_QUANT[0], EOSTAB_QUANT, get_eosparam.__doc__) + docvar('ne', 'electron density [m^-3]') + if (obj.sel_units == 'cgs'): + docvar('ne', 'electron density [cm^-3]') + docvar('tg', 'Temperature [K]') + docvar('pg', 'gas pressure [dyn/cm^2]') + docvar('kr', 'Rosseland opacity [cm^2/g]') + docvar('eps', 'scattering probability') + docvar('opa', 'opacity') + docvar('temt', 'thermal emission') + docvar('ent', 'entropy') + + if (quant == '') or not quant in EOSTAB_QUANT: + return None + + if quant == 'tau': + return calc_tau(obj) + + else: + # bifrost_uvotrt uses SI! + fac = 1.0 + if (quant == 'ne') and (obj.sel_units != 'cgs'): + fac = 1.e6 # cm^-3 to m^-3 + + if obj.hion and quant == 'ne': + return obj.get_var('hionne') * fac + + sel_units = obj.sel_units + obj.sel_units = 'cgs' + rho = obj.get_var('rho') + ee = obj.get_var('e') / rho + + obj.sel_units = sel_units + + if obj.verbose: + print(quant + ' interpolation...', whsp*7, end="\r", flush=True) + + return obj.rhoee.tab_interp( + rho, ee, order=1, out=quant) * fac + + +# default +_COLFRE_QUANT0 = ('COLFRE_QUANT') +# get value + + +@document_vars.quant_tracking_simple(_COLFRE_QUANT0) +def get_collision(obj, quant, COLFRE_QUANT=None, **kwargs): + ''' + Collision frequency between different species in (cgs) + It will assume Maxwell molecular collisions if crossection + tables does not exist. + ''' + + if COLFRE_QUANT is None: + COLFRE_QUANT = obj.COLFRE_QUANT # _COLFRE_QUANT[1] + + if quant == '': + document_vars.vars_documenter(obj, _COLFRE_QUANT0, COLFRE_QUANT, get_collision.__doc__) + + if (quant == '') or not ''.join([i for i in quant if not i.isdigit()]) in COLFRE_QUANT: + return None + + elem = quant.split('_') + spic1 = ''.join([i for i in elem[0] if not i.isdigit()]) + ion1 = ''.join([i for i in elem[0] if i.isdigit()]) + spic2 = ''.join([i for i in elem[1] if not i.isdigit()]) + ion2 = ''.join([i for i in elem[1] if i.isdigit()]) + spic1 = spic1[2:] + + crossarr = get_crossections(obj, '%s%s_%s%s' % (spic1, ion1, spic2, ion2), **kwargs) + + if np.any(crossarr) == None: + return get_collision_maxw(obj, 'numx'+quant[2:], **kwargs) + else: + + nspic2 = obj.get_var('n%s-%s' % (spic2, ion2)) + if np.size(elem) > 2: + nspic2 *= (1.0-obj.get_var('kappanorm_%s' % spic2)) + + tg = obj.get_var('tg') + if spic1 == 'e': + awg1 = obj.uni.m_electron + else: + awg1 = obj.uni.weightdic[spic1] * obj.uni.amu + if spic1 == 'e': + awg2 = obj.uni.m_electron + else: + awg2 = obj.uni.weightdic[spic2] * obj.uni.amu + scr1 = np.sqrt(8.0 * obj.uni.kboltzmann * tg / obj.uni.pi) + + return crossarr * np.sqrt((awg1 + awg2) / (awg1 * awg2)) *\ + scr1 * nspic2 # * (awg1 / (awg1 + awg1)) + + +# default +_COLFREMX_QUANT0 = ('COLFREMX_QUANT') +# get value + + +@document_vars.quant_tracking_simple(_COLFREMX_QUANT0) +def get_collision_maxw(obj, quant, COLFREMX_QUANT=None, **kwargs): + ''' + Maxwell molecular collision frequency + ''' + if COLFREMX_QUANT is None: + COLFREMX_QUANT = obj.COLFREMX_QUANT + + if quant == '': + document_vars.vars_documenter(obj, _COLFREMX_QUANT0, COLFREMX_QUANT, get_collision_maxw.__doc__) + + if (quant == '') or not ''.join([i for i in quant if not i.isdigit()]) in COLFREMX_QUANT: + return None + + #### ASSUMES ifluid is charged AND jfluid is neutral. #### + # set constants. for more details, see eq2 in Appendix A of Oppenheim 2020 paper. + CONST_MULT = 1.96 # factor in front. + CONST_ALPHA_N = 6.67e-31 # [m^3] #polarizability for Hydrogen + e_charge = 1.602176e-19 # [C] #elementary charge + eps0 = 8.854187e-12 # [F m^-1] #epsilon0, standard + + elem = quant.split('_') + spic1 = ''.join([i for i in elem[0] if not i.isdigit()]) + ion1 = ''.join([i for i in elem[0] if i.isdigit()]) + spic2 = ''.join([i for i in elem[1] if not i.isdigit()]) + ion2 = ''.join([i for i in elem[1] if i.isdigit()]) + spic1 = spic1[4:] + + obj.get_var('tg') + if spic1 == 'e': + awg1 = obj.uni.msi_e + else: + awg1 = obj.uni.weightdic[spic1] * obj.uni.amusi + if spic1 == 'e': + awg2 = obj.uni.msi_e + else: + awg2 = obj.uni.weightdic[spic2] * obj.uni.amusi + + if (ion1 == 0 and ion2 != 0): + CONST_ALPHA_N = POLARIZABILITY_DICT[spic1] + nspic2 = obj.get_var('n%s-%s' % (spic2, ion2)) / (obj.uni.cm_to_m**3) # convert to SI. + if np.size(elem) > 2: + nspic2 *= (1.0-obj.get_var('kappanorm_%s' % spic2)) + return CONST_MULT * nspic2 * np.sqrt(CONST_ALPHA_N * e_charge**2 * awg2 / (eps0 * awg1 * (awg1 + awg2))) + elif (ion2 == 0 and ion1 != 0): + CONST_ALPHA_N = POLARIZABILITY_DICT[spic2] + nspic1 = obj.get_var('n%s-%s' % (spic1, ion1)) / (obj.uni.cm_to_m**3) # convert to SI. + if np.size(elem) > 2: + nspic1 *= (1.0-obj.get_var('kappanorm_%s' % spic2)) + return CONST_MULT * nspic1 * np.sqrt(CONST_ALPHA_N * e_charge**2 * awg1 / (eps0 * awg2 * (awg1 + awg2))) + else: + nspic2 = obj.get_var('n%s-%s' % (spic2, ion2)) / (obj.uni.cm_to_m**3) # convert to SI. + if np.size(elem) > 2: + nspic2 *= (1.0-obj.get_var('kappanorm_%s' % spic2)) + return CONST_MULT * nspic2 * np.sqrt(CONST_ALPHA_N * e_charge**2 * awg2 / (eps0 * awg1 * (awg1 + awg2))) + + +# default +_COLCOU_QUANT0 = ('COLCOU_QUANT') +# get value + + +@document_vars.quant_tracking_simple(_COLCOU_QUANT0) +def get_collcoul(obj, quant, COLCOU_QUANT=None, **kwargs): + ''' + Coulomb Collision frequency between different ionized species (cgs) + (Hansteen et al. 1997) + ''' + if COLCOU_QUANT is None: + COLCOU_QUANT = obj.COLCOU_QUANT + + if quant == '': + document_vars.vars_documenter(obj, _COLCOU_QUANT0, COLCOU_QUANT, get_collcoul.__doc__) + + if (quant == '') or not ''.join([i for i in quant if not i.isdigit()]) in COLCOU_QUANT: + return None + + elem = quant.split('_') + spic1 = ''.join([i for i in elem[0] if not i.isdigit()]) + ''.join([i for i in elem[0] if i.isdigit()]) + spic2 = ''.join([i for i in elem[1] if not i.isdigit()]) + ion2 = ''.join([i for i in elem[1] if i.isdigit()]) + spic1 = spic1[5:] + nspic2 = obj.get_var('n%s-%s' % (spic2, ion2)) # scr2 + + tg = obj.get_var('tg') # scr1 + nel = obj.get_var('ne') / 1e6 # it takes into account NEQ and converts to cgs + + coulog = 23. + 1.5 * np.log(tg/1.e6) - 0.5 * np.log(nel/1e6) # Coulomb logarithm scr4 + + mst = obj.uni.weightdic[spic1] * obj.uni.weightdic[spic2] * obj.uni.amu / \ + (obj.uni.weightdic[spic1] + obj.uni.weightdic[spic2]) + + return 1.7 * coulog/20.0 * (obj.uni.m_h/(obj.uni.weightdic[spic1] * + obj.uni.amu)) * (mst/obj.uni.m_h)**0.5 * \ + nspic2 / tg**1.5 * (int(ion2)-1)**2 + + +# default +_COLCOUMS_QUANT0 = ('COLCOUMS_QUANT') +# get value + + +@document_vars.quant_tracking_simple(_COLCOUMS_QUANT0) +def get_collcoul_ms(obj, quant, COLCOUMS_QUANT=None, **kwargs): + ''' + Coulomb collision between for a specific ionized species (or electron) with + all ionized elements (cgs) + ''' + if (COLCOUMS_QUANT == None): + COLCOUMS_QUANT = obj.COLCOUMS_QUANT + + if quant == '': + document_vars.vars_documenter(obj, _COLCOUMS_QUANT0, COLCOUMS_QUANT, get_collcoul_ms.__doc__) + + if (quant == '') or not ''.join([i for i in quant if not i.isdigit()]) in COLCOUMS_QUANT: + return None + + if (quant == 'nucou_ii'): + result = obj.zero() + for ielem in obj.ELEMLIST: + + result += obj.uni.amu * obj.uni.weightdic[ielem] * \ + obj.get_var('n%s-1' % ielem) * \ + obj.get_var('nucou%s1_i' % (ielem)) + + if obj.heion: + result += obj.uni.amu * obj.uni.weightdic['he'] * obj.get_var('nhe-3') * \ + obj.get_var('nucouhe3_i') + + elif quant[-2:] == '_i': + lvl = '2' + + elem = quant.split('_') + result = obj.zero() + for ielem in obj.ELEMLIST: + if elem[0][5:] != '%s%s' % (ielem, lvl): + result += obj.get_var('%s_%s%s' % + (elem[0], ielem, lvl)) + + return result + + +# default +_COLFRI_QUANT0 = ('COLFRI_QUANT') +# get value + + +@document_vars.quant_tracking_simple(_COLFRI_QUANT0) +def get_collision_ms(obj, quant, COLFRI_QUANT=None, **kwargs): + ''' + Sum of collision frequencies (cgs). + + Formats (with , replaced by elements, e.g. '' --> 'he'): + - nu_n : sum of collision frequencies between A2 and neutrals + nuA2_h1 + nuA2_he1 + ... + - nu_i : sum of collision frequencies between A1 and once-ionized ions + nuA1_h2 + nuA1_he2 + ... + + For more precise control over which collision frequencies are summed, + refer to obj.ELEMLIST, and/or obj.set_elemlist(). + ''' + + if (COLFRI_QUANT == None): + COLFRI_QUANT = obj.COLFRI_QUANT + + if quant == '': + document_vars.vars_documenter(obj, _COLFRI_QUANT0, COLFRI_QUANT, get_collision_ms.__doc__) + return None + + if (quant[0:2] != 'nu') or (not ''.join([i for i in quant if not i.isdigit()]) in COLFRI_QUANT): + return None + + elif quant in ('nu_ni_mag', 'nu_ni', 'numx_ni_mag', 'numx_ni'): + result = obj.zero() + s_nu, _, ni_mag = quant.partition('_') # s_numx = nu or numx + for ielem in obj.ELEMLIST: + if ielem in obj.ELEMLIST[2:] and '_mag' in quant: + const = (1 - obj.get_var('kappanorm_%s' % ielem)) + mag = '_mag' + else: + const = 1.0 + mag = '' + + # S + nelem_1 = 'n{elem}-1'.format(elem=ielem) + nuelem1_imag = '{nu}{elem}_i{mag}'.format(nu=s_nu, elem=ielem, mag=mag) + result += obj.uni.amu * obj.uni.weightdic[ielem] * \ + obj.get_var(nelem_1) * const * \ + obj.get_var(nuelem1_imag, **kwargs) + + if ((ielem in obj.ELEMLIST[2:]) and ('_mag' in quant)): + nelem_2 = 'n{elem}-2'.format(elem=ielem) + nuelem2_imag = '{nu}{elem}_i{mag}'.format(nu=s_nu, elem=ielem, mag=mag) + result += obj.uni.amu * obj.uni.weightdic[ielem] * \ + obj.get_var(nelem_2) * const * \ + obj.get_var(nuelem2_imag, **kwargs) + + elif ((quant == 'nu_in_mag') or (quant == 'nu_in')): + result = obj.zero() + for ielem in obj.ELEMLIST: + if (ielem in obj.ELEMLIST[2:] and '_mag' in quant): + const = (1 - obj.get_var('kappanorm_%s' % ielem)) + mag = '_mag' + else: + const = 1.0 + mag = '' + + result += obj.uni.amu * obj.uni.weightdic[ielem] * const * \ + obj.get_var('n%s-2' % ielem) * obj.get_var('nu%s2_n%s' % (ielem, mag), **kwargs) + if obj.heion: + result += obj.uni.amu * obj.uni.weightdic['he'] * obj.get_var('nhe-3') * \ + obj.get_var('nuhe3_n%s' % mag, **kwargs) + + elif quant == 'nu_ei': + nel = obj.get_var('ne') / 1e6 # NEQ is taken into account and its converted to cgs + culblog = 23. + 1.5 * np.log(obj.get_var('tg') / 1.e6) - \ + 0.5 * np.log(nel / 1e6) + + result = 3.759 * nel / (obj.get_var('tg')**(1.5)) * culblog + + elif quant == 'nu_en': + elem = quant.split('_') + result = obj.zero() + lvl = 1 + for ielem in obj.ELEMLIST: + if ielem in ['h', 'he']: + result += obj.get_var('%s_%s%s' % + ('nue', ielem, lvl), **kwargs) + + elif (quant[0:2] == 'nu' and (quant[-2:] == '_i' or quant[-2:] == '_n' or quant[-6:] == '_i_mag' or quant[-6:] == '_n_mag')): + nu = 'numx' if quant.startswith('numx') else 'nu' + qrem = quant[len(nu):] # string remaining in quant, after taking away nu (either 'nu' or 'numx'). + elem, _, qrem = qrem.partition('_') # e.g. 'h2', '_', 'n_mag' # or, e.g. 'he', '_', 'i' + n, _, mag = qrem.partition('_') # e.g. 'n', '_', 'mag' # or, e.g. 'i', '', '' + if mag != '': + mag = '_' + mag + + if not elem[-1].isdigit(): # Didn't provide level for elem; we infer it to be 1 or 2 based on '_i' or '_n'. + elemlvl = {'n': 2, 'i': 1}[n] # elemlvl is 2 for nu_n; 1 for nu_i. + elem = '{elem}{lvl}'.format(elem=elem, lvl=elemlvl) + jlvl = {'n': 1, 'i': 2}[n] # level of second species will be 1 for nu_n, 2 for nu_i. + + result = obj.zero() + for ielem in obj.ELEMLIST: # e.g. ielem == 'he' + ielem = '{elem}{lvl}'.format(elem=ielem, lvl=jlvl) + if ielem != elem: + getting = '{nu}{elem}_{ielem}{mag}'.format(nu=nu, elem=elem, ielem=ielem, mag=mag) + result += obj.get_var(getting, **kwargs) + + return result + + +# default +_COULOMB_COL_QUANT0 = ('COULOMB_COL_QUANT') +# get value + + +@document_vars.quant_tracking_simple(_COULOMB_COL_QUANT0) +def get_coulomb(obj, quant, COULOMB_COL_QUANT=None, **kwargs): + ''' + Coulomb collision frequency in Hz + ''' + + if COULOMB_COL_QUANT is None: + COULOMB_COL_QUANT = obj.COULOMB_COL_QUANT + + if quant == '': + document_vars.vars_documenter(obj, _COULOMB_COL_QUANT0, COULOMB_COL_QUANT, get_coulomb.__doc__) + + if (quant == '') or not quant in COULOMB_COL_QUANT: + return None + + iele = np.where(COULOMB_COL_QUANT == quant) + tg = obj.get_var('tg') + nel = np.copy(obj.get_var('ne')) # already takes into account NEQ (SI) + elem = quant.replace('coucol', '') + + const = (obj.uni.pi * obj.uni.qsi_electron ** 4 / + ((4.0 * obj.uni.pi * obj.uni.permsi)**2 * + np.sqrt(obj.uni.weightdic[elem] * obj.uni.amusi * + (2.0 * obj.uni.ksi_b) ** 3) + 1.0e-20)) + + return (const * nel.astype('Float64') * + np.log(12.0 * obj.uni.pi * nel.astype('Float64') * + obj.get_var('debye_ln').astype('Float64') + 1e-50) / + (np.sqrt(tg.astype('Float64')**3) + 1.0e-20)) + + +# default +_CURRENT_QUANT = ('CURRENT_QUANT', ['ix', 'iy', 'iz', 'wx', 'wy', 'wz']) +# get value + + +@document_vars.quant_tracking_simple(_CURRENT_QUANT[0]) +def get_current(obj, quant, CURRENT_QUANT=None, **kwargs): + ''' + Calculates currents (bifrost units) or + rotational components of the velocity + ''' + if CURRENT_QUANT is None: + CURRENT_QUANT = _CURRENT_QUANT[1] + + if quant == '': + docvar = document_vars.vars_documenter(obj, _CURRENT_QUANT[0], CURRENT_QUANT, get_current.__doc__) + docvar('ix', 'component x of the current') + docvar('iy', 'component y of the current') + docvar('iz', 'component z of the current') + docvar('wx', 'component x of the rotational of the velocity') + docvar('wy', 'component y of the rotational of the velocity') + docvar('wz', 'component z of the rotational of the velocity') + + if (quant == '') or not quant in CURRENT_QUANT: + return None + + # Calculate derivative of quantity + axis = quant[-1] + if quant[0] == 'i': + q = 'b' + else: + q = 'u' + try: + getattr(obj, quant) + except AttributeError: + if axis == 'x': + varsn = ['z', 'y'] + derv = ['dydn', 'dzdn'] + elif axis == 'y': + varsn = ['x', 'z'] + derv = ['dzdn', 'dxdn'] + elif axis == 'z': + varsn = ['y', 'x'] + derv = ['dxdn', 'dydn'] + + # 2D or close + if (getattr(obj, 'n' + varsn[0]) < 5) or (getattr(obj, 'n' + varsn[1]) < 5): + return obj.zero() + else: + return (obj.get_var('d' + q + varsn[0] + derv[0]) - + obj.get_var('d' + q + varsn[1] + derv[1])) + + +# default +_FLUX_QUANT = ('FLUX_QUANT', + ['pfx', 'pfy', 'pfz', + 'pfex', 'pfey', 'pfez', + 'pfwx', 'pfwy', 'pfwz'] + ) +# get value + + +@document_vars.quant_tracking_simple(_FLUX_QUANT[0]) +def get_flux(obj, quant, FLUX_QUANT=None, **kwargs): + ''' + Computes flux + ''' + if FLUX_QUANT is None: + FLUX_QUANT = _FLUX_QUANT[1] + + if quant == '': + docvar = document_vars.vars_documenter(obj, _FLUX_QUANT[0], FLUX_QUANT, get_flux.__doc__) + docvar('pfx', 'component x of the Poynting flux') + docvar('pfy', 'component y of the Poynting flux') + docvar('pfz', 'component z of the Poynting flux') + docvar('pfex', 'component x of the Flux emergence') + docvar('pfey', 'component y of the Flux emergence') + docvar('pfez', 'component z of the Flux emergence') + docvar('pfwx', 'component x of the Poynting flux from "horizontal" motions') + docvar('pfwy', 'component y of the Poynting flux from "horizontal" motions') + docvar('pfwz', 'component z of the Poynting flux from "horizontal" motions') + + if (quant == '') or not quant in FLUX_QUANT: + return None + + axis = quant[-1] + if axis == 'x': + varsn = ['z', 'y'] + elif axis == 'y': + varsn = ['x', 'z'] + elif axis == 'z': + varsn = ['y', 'x'] + if 'pfw' in quant or len(quant) == 3: + var = - obj.get_var('b' + axis + 'c') * ( + obj.get_var('u' + varsn[0] + 'c') * + obj.get_var('b' + varsn[0] + 'c') + + obj.get_var('u' + varsn[1] + 'c') * + obj.get_var('b' + varsn[1] + 'c')) + else: + var = obj.zero() + if 'pfe' in quant or len(quant) == 3: + var += obj.get_var('u' + axis + 'c') * ( + obj.get_var('b' + varsn[0] + 'c')**2 + + obj.get_var('b' + varsn[1] + 'c')**2) + return var + + +# default +_POND_QUANT = ('POND_QUANT', + ['pond'] + ) +# get value + + +@document_vars.quant_tracking_simple(_POND_QUANT[0]) +def get_ponderomotive(obj, quant, POND_QUANT=None, **kwargs): + ''' + Computes flux + ''' + if POND_QUANT is None: + POND_QUANT = _POND_QUANT[1] + + if quant == '': + docvar = document_vars.vars_documenter(obj, _POND_QUANT[0], POND_QUANT, get_flux.__doc__) + docvar('pond', 'Ponderomotive aceleration along the field lines') + + if (quant == '') or not quant in POND_QUANT: + return None + + bxc = obj.get_var('bxc') + byc = obj.get_var('byc') + bzc = obj.get_var('bzc') + + nx, ny, nz = bxc.shape + + b2 = bxc**2 + byc**2 + bzc**2 + + ubx = obj.get_var('uyc')*bzc - obj.get_var('uzc')*byc + uby = obj.get_var('uxc')*bzc - obj.get_var('uzc')*bxc + ubz = obj.get_var('uxc')*byc - obj.get_var('uyc')*bxc + + xl, yl, zl = calc_field_lines(obj.x[::2], obj.y, obj.z[::2], bxc[::2, :, ::2], byc[::2, :, ::2], bzc[::2, :, ::2], niter=501) + #S = calc_lenghth_lines(xl, yl, zl) + ixc = obj.get_var('ixc') + iyc = obj.get_var('iyc') + izc = obj.get_var('izc') + + dex = - ubx + ixc + dey = - uby + iyc + dez = - ubz + izc + + dpond = (dex**2 + dey**2 + dez**2) / b2 + + ibxc = bxc / (np.sqrt(b2)+1e-30) + ibyc = byc / (np.sqrt(b2)+1e-30) + ibzc = bzc / (np.sqrt(b2)+1e-30) + + return do_stagger(dpond, 'ddxdn', obj=obj)*ibxc +\ + do_stagger(dpond, 'ddydn', obj=obj)*ibyc +\ + do_stagger(dpond, 'ddzdn', obj=obj)*ibzc + + +# default +_PLASMA_QUANT = ('PLASMA_QUANT', + ['beta', 'beta_ion', 'va', 'cs', 's', 'ke', 'mn', 'man', 'hp', 'nr', + 'vax', 'vay', 'vaz', 'hx', 'hy', 'hz', 'kx', 'ky', 'kz', + ] + ) +# get value + + +@document_vars.quant_tracking_simple(_PLASMA_QUANT[0]) +def get_plasmaparam(obj, quant, PLASMA_QUANT=None, **kwargs): + ''' + Adimensional parameters for single fluid + ''' + if PLASMA_QUANT is None: + PLASMA_QUANT = _PLASMA_QUANT[1] + + if quant == '': + docvar = document_vars.vars_documenter(obj, _PLASMA_QUANT[0], PLASMA_QUANT, get_plasmaparam.__doc__) + docvar('beta', "plasma beta: P / (B / (2 mu0)), where P is the single-fluid pressure.") + docvar('beta_ion', "plasma beta: Pi / (B / (2 mu0)), where Pi is the pressure from ions, only.") + docvar('va', "alfven speed [simu. units]") + docvar('cs', "sound speed [simu. units]") + docvar('s', "entropy [log of quantities in simu. units]") + docvar('ke', "kinetic energy density of ifluid [simu. units]") + docvar('mn', "mach number (using sound speed)") + docvar('man', "mach number (using alfven speed)") + docvar('hp', "Pressure scale height") + docvar('nr', "total number density (including neutrals) [simu. units].", uni=UNI_nr) + for var in ['vax', 'vay', 'vaz']: + docvar(var, "{axis} component of alfven velocity [simu. units]".format(axis=var[-1])) + for var in ['kx', 'ky', 'kz']: + docvar(var, ("{axis} component of kinetic energy density of ifluid [simu. units]." + + "(0.5 * rho * (get_var(u{axis})**2)").format(axis=var[-1])) + + if (quant == '') or not quant in PLASMA_QUANT: + return None + + if quant in ['hp', 's', 'cs', 'beta']: + var = obj.get_var('p') + if quant == 'hp': + if getattr(obj, 'nx') < 5: + return obj.zero() + else: + return 1. / (do_stagger(var, 'ddzup', obj=obj) + 1e-12) + elif quant == 'cs': + return np.sqrt(obj.params['gamma'][obj.snapInd] * + var / obj.get_var('r')) + elif quant == 's': + return (np.log(var) - obj.params['gamma'][obj.snapInd] * + np.log(obj.get_var('r'))) + elif quant == 'beta': + return 2 * var / obj.get_var('b2') + + elif quant == 'beta_ion': + ni = obj.get_var('nion') / (obj.uni.cm_to_m**3) # [1/m^3] + kB = obj.uni.ksi_b + Ti = obj.get_var('tg') # [K] + Pi = ni * kB * Ti # [SI pressure units] + B2 = obj.get_var('b2') * obj.uni.usi_b ** 2 # [T^2] + mu0 = obj.uni.mu0si + return Pi / (B2 / (2 * mu0)) # [dimensionless] + + if quant in ['mn', 'man']: + var = obj.get_var('modu') + if quant == 'mn': + return var / (obj.get_var('cs') + 1e-12) + else: + return var / (obj.get_var('va') + 1e-12) + + if quant in ['va', 'vax', 'vay', 'vaz']: + var = obj.get_var('r') + if len(quant) == 2: + return obj.get_var('modb') / np.sqrt(var) + else: + axis = quant[-1] + return np.sqrt(obj.get_var('b' + axis + 'c') ** 2 / var) + + if quant in ['hx', 'hy', 'hz', 'kx', 'ky', 'kz']: + axis = quant[-1] + var = obj.get_var('p' + axis + 'c') + if quant[0] == 'h': + return ((obj.get_var('e') + obj.get_var('p')) / + obj.get_var('r') * var) + else: + return obj.get_var('u2') * var * 0.5 + + if quant in ['ke']: + var = obj.get_var('r') + return obj.get_var('u2') * var * 0.5 + + if quant == 'nr': + r = obj.get_var('r') + r = r.astype('float64') # if r close to 1, nr will be huge in simu units. use float64 to avoid infs. + nr_H = r / obj.uni.simu_amu # nr [simu. units] if only species is H. + return nr_H * obj.uni.mu # mu is correction factor since plasma isn't just H. + + +# default +_WAVE_QUANT = ('WAVE_QUANT', ['alf', 'fast', 'long']) +# get value + + +@document_vars.quant_tracking_simple(_WAVE_QUANT[0]) +def get_wavemode(obj, quant, WAVE_QUANT=None, **kwargs): + ''' + computes waves modes + ''' + if WAVE_QUANT is None: + WAVE_QUANT = _WAVE_QUANT[1] + + if quant == '': + docvar = document_vars.vars_documenter(obj, _WAVE_QUANT[0], WAVE_QUANT, get_wavemode.__doc__) + docvar('alf', "Alfven wave component [simu units]") + docvar('fast', "fast wave component [simu units]") + docvar('long', "longitudinal wave component [simu units]") + + if (quant == '') or not quant in WAVE_QUANT: + return None + + bx = obj.get_var('bxc') + by = obj.get_var('byc') + bz = obj.get_var('bzc') + bMag = np.sqrt(bx**2 + by**2 + bz**2) + bx, by, bz = bx / bMag, by / bMag, bz / bMag # b is already centered + # unit vector of b + unitB = np.stack((bx, by, bz)) + + if quant == 'alf': + uperb = obj.get_var('uperb') + uperbVect = uperb * unitB + # cross product (uses cstagger bc no variable gets uperbVect) + curlX = (do_stagger(do_stagger(uperbVect[2], 'ddydn', obj=obj), 'yup', obj=obj) - + do_stagger(do_stagger(uperbVect[1], 'ddzdn', obj=obj), 'zup', obj=obj)) + curlY = (-do_stagger(do_stagger(uperbVect[2], 'ddxdn', obj=obj), 'xup', obj=obj) + + do_stagger(do_stagger(uperbVect[0], 'ddzdn', obj=obj), 'zup', obj=obj)) + curlZ = (do_stagger(do_stagger(uperbVect[1], 'ddxdn', obj=obj), 'xup', obj=obj) - + do_stagger(do_stagger(uperbVect[0], 'ddydn', obj=obj), 'yup', obj=obj)) + curl = np.stack((curlX, curlY, curlZ)) + # dot product + result = np.abs((unitB * curl).sum(0)) + elif quant == 'fast': + uperb = obj.get_var('uperb') + uperbVect = uperb * unitB + + result = np.abs(do_stagger(do_stagger( + uperbVect[0], 'ddxdn', obj=obj), 'xup', obj=obj) + do_stagger(do_stagger( + uperbVect[1], 'ddydn', obj=obj), 'yup', obj=obj) + do_stagger( + do_stagger(uperbVect[2], 'ddzdn', obj=obj), 'zup', obj=obj)) + else: + dot1 = obj.get_var('uparb') + grad = np.stack((do_stagger(do_stagger(dot1, 'ddxdn', obj=obj), + 'xup', obj=obj), do_stagger(do_stagger(dot1, 'ddydn', obj=obj), 'yup', obj=obj), + do_stagger(do_stagger(dot1, 'ddzdn', obj=obj), 'zup', obj=obj))) + result = np.abs((unitB * grad).sum(0)) + return result + + +# default +_CYCL_RES = ('CYCL_RES', ['n6nhe2', 'n6nhe3', 'nhe2nhe3']) +# get value + + +@document_vars.quant_tracking_simple(_CYCL_RES[0]) +def get_cyclo_res(obj, quant, CYCL_RES=None, **kwargs): + ''' + esonant cyclotron frequencies + (only for do_helium) are (SI units) + ''' + if (CYCL_RES is None): + CYCL_RES = _CYCL_RES[1] + + if quant == '': + document_vars.vars_documenter(obj, _CYCL_RES[0], CYCL_RES, get_cyclo_res.__doc__) + + if (quant == '') or not quant in CYCL_RES: + return None + + if obj.hion and obj.heion: + posn = ([pos for pos, char in enumerate(quant) if char == 'n']) + q2 = quant[posn[-1]:] + q1 = quant[:posn[-1]] + nel = obj.get_var('ne')/1e6 # already takes into account NEQ converted to cgs + var2 = obj.get_var(q2) + var1 = obj.get_var(q1) + z1 = 1.0 + z2 = float(quant[-1]) + if q1[:3] == 'n6': + omega1 = obj.get_var('gfh2') + else: + omega1 = obj.get_var('gf'+q1[1:]) + omega2 = obj.get_var('gf'+q2[1:]) + return (z1 * var1 * omega2 + z2 * var2 * omega1) / nel + else: + raise ValueError(('get_quantity: This variable is only ' + 'avaiable if do_hion and do_helium is true')) + + +# default +_GYROF_QUANT0 = ('GYROF_QUANT') +# get value + + +@document_vars.quant_tracking_simple(_GYROF_QUANT0) +def get_gyrof(obj, quant, GYROF_QUANT=None, **kwargs): + ''' + gyro freqency are (Hz) + gf+ ionization state + ''' + + if (GYROF_QUANT is None): + GYROF_QUANT = obj.GYROF_QUANT + + if quant == '': + document_vars.vars_documenter(obj, _GYROF_QUANT0, GYROF_QUANT, get_gyrof.__doc__) + + if (quant == '') or not ''.join([i for i in quant if not i.isdigit()]) in GYROF_QUANT: + return None + + if quant == 'gfe': + return obj.get_var('modb') * obj.uni.usi_b * \ + obj.uni.qsi_electron / (obj.uni.msi_e) + else: + ion_level = ''.join([i for i in quant if i.isdigit()]) # 1-indexed ionization level (e.g. H+ --> ion_level=2) + assert ion_level != '', "Expected 'gf' with A an element, N a number (ionization level), but got '{quant}'".format(quant) + ion_Z = float(ion_level) - 1.0 # 0-indexed ionization level. (e.g. H+ --> ion_Z = 1. He++ --> ion_Z=2.) + return obj.get_var('modb') * obj.uni.usi_b * \ + obj.uni.qsi_electron * ion_Z / \ + (obj.uni.weightdic[quant[2:-1]] * obj.uni.amusi) + + +# default +#_KAPPA_QUANT = ['kappa' + elem for elem in ELEMLIST] +#_KAPPA_QUANT = ['kappanorm_', 'kappae'] + _KAPPA_QUANT +# I suspect that ^^^ should be kappanorm_ + elem for elem in ELEMLIST, +# but I don't know what kappanorm is supposed to mean, so I'm not going to change it now. -SE June 28 2021 +_KAPPA_QUANT0 = ('KAPPA_QUANT') +# set value + + +@document_vars.quant_tracking_simple(_KAPPA_QUANT0) +def get_kappa(obj, quant, KAPPA_QUANT=None, **kwargs): + ''' + kappa, i.e., magnetization (adimensional) + at the end it must have the ionization + ''' + + if (KAPPA_QUANT is None): + KAPPA_QUANT = obj.KAPPA_QUANT + + if quant == '': + document_vars.vars_documenter(obj, _KAPPA_QUANT0, KAPPA_QUANT, get_kappa.__doc__) + + if (quant == ''): + return None + + if ''.join([i for i in quant if not i.isdigit()]) in KAPPA_QUANT: + if quant == 'kappae': + return obj.get_var('gfe') / (obj.get_var('nu_en') + 1e-28) + else: + elem = quant.replace('kappa', '') + return obj.get_var('gf'+elem) / (obj.get_var('nu'+elem+'_n') + 1e-28) + + elif quant[:-1] in KAPPA_QUANT or quant[:-2] in KAPPA_QUANT: + elem = quant.split('_') + return obj.get_var('kappah2')**2/(obj.get_var('kappah2')**2 + 1) - \ + obj.get_var('kappa%s2' % elem[1])**2 / \ + (obj.get_var('kappa%s2' % elem[1])**2 + 1) + else: + return None + + +# default +_DEBYE_LN_QUANT = ('DEBYE_LN_QUANT', ['debye_ln']) +# set value + + +@document_vars.quant_tracking_simple(_DEBYE_LN_QUANT[0]) +def get_debye_ln(obj, quant, DEBYE_LN_QUANT=None, **kwargs): + ''' + Computes Debye length in ... units + ''' + + if (DEBYE_LN_QUANT is None): + DEBYE_LN_QUANT = _DEBYE_LN_QUANT[1] + + if quant == '': + docvar = document_vars.vars_documenter(obj, _DEBYE_LN_QUANT[0], DEBYE_LN_QUANT, get_debye_ln.__doc__) + docvar('debye_ln', "Debye length [u.u_l]") + + if (quant == '') or not quant in DEBYE_LN_QUANT: + return None + + tg = obj.get_var('tg') + part = np.copy(obj.get_var('ne')) + # We are assuming a single charge state: + for iele in obj.ELEMLIST: + part += obj.get_var('n' + iele + '-2') + if obj.heion: + part += 4.0 * obj.get_var('nhe3') + # check units of n + return np.sqrt(obj.uni.permsi / obj.uni.qsi_electron**2 / + (obj.uni.ksi_b * tg.astype('float64') * + part.astype('float64') + 1.0e-20)) + + +# default +_IONP_QUANT0 = ('IONP_QUANT') +# set value + + +@document_vars.quant_tracking_simple(_IONP_QUANT0) +def get_ionpopulations(obj, quant, IONP_QUANT=None, **kwargs): + ''' + densities for specific ionized species. + For example, nc-1 gives number density of neutral carbon, in cm^-3. nc-2 is for once-ionized carbon. + ''' + if (IONP_QUANT is None): + IONP_QUANT = obj.IONP_QUANT + + if quant == '': + document_vars.vars_documenter(obj, _IONP_QUANT0, IONP_QUANT, get_ionpopulations.__doc__) + + if (quant == ''): + return None + + if ((quant in IONP_QUANT) and (quant[-3:] in ['ion', 'neu'])): + if 'ion' in quant: + lvl = '2' + else: + lvl = '1' + result = obj.zero() + for ielem in obj.ELEMLIST: + result += obj.get_var(quant[0]+ielem+'-'+lvl) + return result + + elif ((quant in IONP_QUANT) and (quant[-9:] in ['ion_nomag', 'neu_nomag'])): + # I dont think it makes sence to have neu_nomag + if 'ion' in quant: + lvl = '2' + else: + lvl = '1' + result = obj.zero() + if quant[-7:] == 'ion_nomag': + for ielem in obj.ELEMLIST[2:]: + result += obj.get_var(quant[0]+ielem+'-'+lvl) * \ + (1-obj.get_var('kappanorm_%s' % ielem)) + else: + for ielem in obj.ELEMLIST[2:]: + result += obj.get_var(quant[0]+ielem+'-'+lvl) * \ + (1-obj.get_var('kappanorm_%s' % ielem)) + return result + + elif (quant == 'nelc'): + + result = obj.zero() + for ielem in obj.ELEMLIST: + result += obj.get_var('n'+ielem+'-2') + + result += obj.get_var('nhe-3')*2 + + return result + + elif ''.join([i for i in quant if not i.isdigit()]) in IONP_QUANT: + elem = quant.replace('-', '') + spic = ''.join([i for i in elem if not i.isdigit()]) + lvl = ''.join([i for i in elem if i.isdigit()]) + + if obj.hion and spic[1:] == 'h': + if quant[0] == 'n': + mass = 1.0 + else: + mass = obj.uni.m_h + if lvl == '1': + return mass * (obj.get_var('n1') + obj.get_var('n2') + obj.get_var('n3') + + obj.get_var('n4') + obj.get_var('n5')) + else: + return mass * obj.get_var('n6') + + elif obj.heion and spic[1:] == 'he': + if quant[0] == 'n': + mass = 1.0 + else: + mass = obj.uni.m_he + if obj.verbose: + print('get_var: reading nhe%s' % lvl, whsp*5, end="\r", + flush=True) + return mass * obj.get_var('nhe%s' % lvl) + + else: + sel_units = obj.sel_units + obj.sel_units = 'cgs' + rho = obj.get_var('rho') + nel = np.copy(obj.get_var('ne')) # cgs + tg = obj.get_var('tg') + obj.sel_units = sel_units + + if quant[0] == 'n': + dens = False + else: + dens = True + + return ionpopulation(obj, rho, nel, tg, elem=spic[1:], lvl=lvl, dens=dens) # cm^3 + else: + return None + + +# default +_AMB_QUANT = ('AMB_QUANT', + ['uambx', 'uamby', 'uambz', 'ambx', 'amby', 'ambz', + 'eta_amb1', 'eta_amb2', 'eta_amb3', 'eta_amb4', 'eta_amb5', + 'nchi', 'npsi', 'nchi_red', 'npsi_red', + 'rchi', 'rpsi', 'rchi_red', 'rpsi_red', 'alphai', 'betai'] + ) +# set value + + +@document_vars.quant_tracking_simple(_AMB_QUANT[0]) +def get_ambparam(obj, quant, AMB_QUANT=None, **kwargs): + ''' + ambipolar velocity or related terms + ''' + if (AMB_QUANT is None): + AMB_QUANT = _AMB_QUANT[1] + + if quant == '': + docvar = document_vars.vars_documenter(obj, _AMB_QUANT[0], AMB_QUANT, get_ambparam.__doc__) + docvar('uambx', "component x of the ambipolar velocity") + docvar('uamby', "component y of the ambipolar velocity") + docvar('uambz', "component z of the ambipolar velocity") + docvar('ambx', "component x of the ambipolar term") + docvar('amby', "component y of the ambipolar term") + docvar('ambz', "component z of the ambipolar term") + docvar('eta_amb1', "ambipolar diffusion using nu_ni") + docvar('eta_amb2', "ambipolar diffusion using nu_in") + docvar('eta_amb3', "ambipolar diffusion using nu_ni_max and rion_nomag") + docvar('eta_amb4', "ambipolar diffusion using Yakov for low ionization regime, Eq (20) (ref{Faraday_corr})") + docvar('eta_amb4a', "ambipolar diffusion using Yakov for low ionization regime, Eq (20) (ref{Faraday_corr}), only the numerator") + docvar('eta_amb4b', "ambipolar diffusion using Yakov for low ionization regime, Eq (20) (ref{Faraday_corr}), only the denumerator") + docvar('eta_amb5', "ambipolar diffusion using Yakov for any ionization regime, 7c") + docvar('nchi', "from Yakov notes to derive the ambipolar diff") + docvar('npsi', "from Yakov notes to derive the ambipolar diff") + docvar('nchi_red', "from Yakov notes to derive the ambipolar diff") + docvar('npsi_red', "from Yakov notes to derive the ambipolar diff") + docvar('rchi', "from Yakov notes to derive the ambipolar diff") + docvar('rpsi', "from Yakov notes to derive the ambipolar diff") + docvar('rchi_red', "from Yakov notes to derive the ambipolar diff") + docvar('rpsi_red', "from Yakov notes to derive the ambipolar diff") + docvar('alphai', "from Yakov notes to derive the ambipolar diff") + docvar('betai', "from Yakov notes to derive the ambipolar diff") + + if (quant == '') or not (quant in AMB_QUANT): + return None + + if obj.sel_units == 'cgs': + u_b = 1.0 + else: + u_b = obj.uni.u_b + + axis = quant[-1] + if quant == 'eta_amb1': # version from other + result = (obj.get_var('rneu') / obj.get_var('rho') * u_b)**2 + result /= (4.0 * obj.uni.pi * obj.get_var('nu_ni', **kwargs) + 1e-20) + result *= obj.get_var('b2') # / 1e7 + + # This should be the same and eta_amb2 except that eta_amb2 has many more species involved. + elif quant == 'eta_amb2': + result = (obj.get_var('rneu') / obj.get_var('rho') * u_b)**2 / ( + 4.0 * obj.uni.pi * obj.get_var('nu_in', **kwargs) + 1e-20) + result *= obj.get_var('b2') # / 1e7 + + elif quant == 'eta_amb3': # This version takes into account the magnetization + result = ((obj.get_var('rneu') + obj.get_var('rion_nomag')) / obj.r * obj.uni.u_b)**2 / ( + 4.0 * obj.uni.pi * obj.get_var('nu_ni_mag') + 1e-20) + result *= obj.get_var('b2') # / 1e7 + + # Yakov for low ionization regime, Eq (20) (ref{Faraday_corr}) + elif quant == 'eta_amb4': + psi = obj.get_var('npsi') + chi = obj.get_var('nchi') + + result = obj.get_var('modb') * obj.uni.u_b * (psi / (1e2 * (psi**2 + chi**2)) - 1.0 / ( + obj.get_var('nelc') * obj.get_var('kappae') * 1e2 + 1e-20)) + + # Yakov for any ionization regime, 7c + elif quant == 'eta_amb5': + psi = obj.get_var('npsi') + chi = obj.get_var('nchi') + + chi = obj.r*0.0 + chif = obj.r*0.0 + psi = obj.r*0.0 + psif = obj.r*0.0 + eta = obj.r*0.0 + kappae = obj.get_var('kappae') + + for iele in obj.ELEMLIST: + kappaiele = obj.get_var('kappa'+iele+'2') + chi += (kappae + kappaiele) * ( + kappae - kappaiele) / ( + 1.0 + kappaiele**2) / ( + 1.0 + kappae**2) * obj.get_var('n'+iele+'-2') + chif += obj.get_var('r'+iele+'-2') * kappaiele / ( + 1.0 + kappaiele**2) + psif += obj.get_var('r'+iele+'-2') / ( + 1.0 + kappaiele**2) + psi += (kappae + kappaiele) * ( + 1.0 + kappae * kappaiele) / ( + 1.0 + kappaiele**2) / ( + 1.0 + kappae**2) * obj.get_var('n'+iele+'-2') + eta += (kappae + kappaiele) * obj.get_var('n'+iele+'-2') + + result = obj.get_var('modb') * obj.uni.u_b * (1.0 / ((psi**2 + chi**2) * obj.r) * (chi * chif - psi * ( + obj.get_var('rneu')+psif)) - 1.0 / (eta+1e-28)) + + elif quant == 'eta_amb4a': + psi = obj.get_var('npsi') + chi = obj.get_var('nchi') + + result = obj.get_var('modb') * obj.uni.u_b * (psi / (psi**2 + chi**2) + 1e-20) + + elif quant == 'eta_amb4b': + + result = obj.get_var('modb') * obj.uni.u_b * (1.0 / ( + obj.get_var('hionne') / 1e6 * obj.get_var('kappae') + 1e-20)) + + elif quant in ['nchi', 'rchi']: + result = obj.r*0.0 + kappae = obj.get_var('kappae') + + for iele in obj.ELEMLIST: + result += (kappae + obj.get_var('kappa'+iele+'2')) * ( + kappae - obj.get_var('kappa'+iele+'2')) / ( + 1.0 + obj.get_var('kappa'+iele+'2')**2) / ( + 1.0 + kappae**2) * obj.get_var(quant[0]+iele+'-2') + + elif quant in ['npsi', 'rpsi']: # Yakov, Eq () + result = obj.r*0.0 + kappae = obj.get_var('kappae') + + for iele in obj.ELEMLIST: + result += (kappae + obj.get_var('kappa'+iele+'2')) * ( + 1.0 + kappae * obj.get_var('kappa'+iele+'2')) / ( + 1.0 + obj.get_var('kappa'+iele+'2')**2) / ( + 1.0 + kappae**2) * obj.get_var(quant[0]+iele+'-2') + + elif quant == 'alphai': + result = obj.r*0.0 + kappae = obj.get_var('kappae') + + for iele in obj.ELEMLIST: + result += (kappae + obj.get_var('kappa'+iele+'2')) * ( + kappae - obj.get_var('kappa'+iele+'2')) / ( + 1.0 + obj.get_var('kappa'+iele+'2')**2) / ( + 1.0 + kappae**2) + + elif quant == 'betai': # Yakov, Eq () + result = obj.r*0.0 + + for iele in obj.ELEMLIST: + result += (obj.get_var('kappae') + obj.get_var('kappa'+iele+'2')) * ( + 1.0 + obj.get_var('kappae') * obj.get_var('kappa'+iele+'2')) / ( + 1.0 + obj.get_var('kappa'+iele+'2')**2) / ( + 1.0 + obj.get_var('kappae')**2) + + elif quant in ['nchi_red', 'rchi_red']: # alpha + result = obj.r*0.0 + + for iele in obj.ELEMLIST: + result += 1.0 / (1.0 + obj.get_var('kappa'+iele+'2')**2) *\ + obj.get_var(quant[0]+iele+'-2') + + elif quant in ['npsi_red', 'rpsi_red']: # beta + result = obj.r*0.0 + + for iele in obj.ELEMLIST: + result += obj.get_var('kappa'+iele+'2') / ( + 1.0 + obj.get_var('kappa'+iele+'2')**2) * \ + obj.get_var(quant[0]+iele+'-2') + + elif quant[0] == 'u': + result = obj.get_var('itimesb' + quant[-1]) / \ + obj.get_var('b2') * obj.get_var('eta_amb') + + elif (quant[-4:-1] == 'amb' and quant[-1] in ['x', 'y', 'z'] and + quant[1:3] != 'chi' and quant[1:3] != 'psi'): + + axis = quant[-1] + if axis == 'x': + varsn = ['y', 'z'] + elif axis == 'y': + varsn = ['z', 'y'] + elif axis == 'z': + varsn = ['x', 'y'] + result = (obj.get_var('itimesb' + varsn[0]) * + obj.get_var('b' + varsn[1] + 'c') - + obj.get_var('itimesb' + varsn[1]) * + obj.get_var('b' + varsn[0] + 'c')) / obj.get_var('b2') * obj.get_var('eta_amb') + + return result + + +# default +_HALL_QUANT = ('HALL_QUANT', + ['uhallx', 'uhally', 'uhallz', 'hallx', 'hally', 'hallz', + 'eta_hall', 'eta_hallb'] + ) +# set value + + +@document_vars.quant_tracking_simple(_HALL_QUANT[0]) +def get_hallparam(obj, quant, HALL_QUANT=None, **kwargs): + ''' + Hall velocity or related terms + ''' + if (HALL_QUANT is None): + HALL_QUANT = _HALL_QUANT[1] + + if quant == '': + docvar = document_vars.vars_documenter(obj, _HALL_QUANT[0], HALL_QUANT, get_hallparam.__doc__) + docvar('uhallx', "component x of the Hall velocity") + docvar('uhally', "component y of the Hall velocity") + docvar('uhallz', "component z of the Hall velocity") + docvar('hallx', "component x of the Hall term") + docvar('hally', "component y of the Hall term") + docvar('hallz', "component z of the Hall term") + docvar('eta_hall', "Hall term ") + docvar('eta_hallb', "Hall term / B") + + if (quant == '') or not (quant in HALL_QUANT): + return None + + if quant[0] == 'u': + try: + result = obj.get_var('i' + quant[-1]) + except Exception: + result = obj.get_var('rotb' + quant[-1]) + elif quant == 'eta_hall': + nel = obj.get_var('nel') + result = (obj.uni.clight)*(obj.uni.u_b) / (4.0 * obj.uni.pi * obj.uni.q_electron * nel) + result = obj.get_var('modb')*result / obj.uni.u_l/obj.uni.u_l*obj.uni.u_t + + elif quant == 'eta_hallb': + nel = obj.get_var('nel') + result = (obj.uni.clight)*(obj.uni.u_b) / (4.0 * obj.uni.pi * obj.uni.q_electron * nel) + result = result / obj.uni.u_l/obj.uni.u_l*obj.uni.u_t + + else: + result = obj.get_var('itimesb_' + quant[-1]) / obj.get_var('modb') + + return result # obj.get_var('eta_hall') * result + + +# default +_BATTERY_QUANT = ('BATTERY_QUANT', + ['bb_constqe', 'dxpe', 'dype', 'dzpe', + 'bb_batx', 'bb_baty', 'bb_batz'] + ) +# set value + + +@document_vars.quant_tracking_simple(_BATTERY_QUANT[0]) +def get_batteryparam(obj, quant, BATTERY_QUANT=None, **kwargs): + ''' + Related battery terms + ''' + if (BATTERY_QUANT is None): + BATTERY_QUANT = _BATTERY_QUANT[1] + + if quant == '': + docvar = document_vars.vars_documenter(obj, _BATTERY_QUANT[0], BATTERY_QUANT, get_batteryparam.__doc__) + docvar('bb_constqe', "constant coefficient involved in the battery term") + docvar('dxpe', "Gradient of electron pressure in the x direction [simu.u_p/simu.u_l]") + docvar('dype', "Gradient of electron pressure in the y direction [simu.u_p/simu.u_l]") + docvar('dzpe', "Gradient of electron pressure in the z direction [simu.u_p/simu.u_l]") + docvar('bb_batx', "Component of the battery term in the x direction, (1/ne qe)*dx(pe)") + docvar('bb_baty', "Component of the battery term in the y direction, (1/ne qe)*dy(pe)") + docvar('bb_batz', "Component of the battery term in the z direction, (1/ne qe)*dz(pe)") + + if (quant == '') or not (quant in BATTERY_QUANT): + return None + + if quant == 'bb_constqe': + const = (obj.uni.usi_p / obj.uni.qsi_electron / (1.0/((obj.uni.cm_to_m)**3)) / obj.uni.usi_l / (obj.uni.usi_b * obj.uni.usi_l/obj.uni.u_t)) # /obj.uni.u_p + result = const + + if quant == 'bb_batx': + gradx_pe = obj.get_var('dpedxup') # obj.get_var('d' + pe + 'dxdn') + nel = obj.get_var('nel') + bb_constqe = obj.uni.usi_p / obj.uni.qsi_electron / (1.0/((obj.uni.cm_to_m)**3)) / obj.uni.usi_l / (obj.uni.usi_b * obj.uni.usi_l/obj.uni.u_t) # /obj.uni.u_p + bb_batx = gradx_pe / (nel * bb_constqe) + result = bb_batx + + if quant == 'bb_baty': + grady_pe = obj.get_var('dpedyup') # obj.get_var('d' + pe + 'dxdn') + nel = obj.get_var('nel') + bb_constqe = obj.uni.usi_p / obj.uni.qsi_electron / (1.0/((obj.uni.cm_to_m)**3)) / obj.uni.usi_l / (obj.uni.usi_b * obj.uni.usi_l/obj.uni.u_t) # /obj.uni.u_p + bb_baty = grady_pe / (nel * bb_constqe) + result = bb_baty + + if quant == 'bb_batz': + gradz_pe = obj.get_var('dpedzup') # obj.get_var('d' + pe + 'dxdn') + nel = obj.get_var('nel') + bb_constqe = obj.uni.usi_p / obj.uni.qsi_electron / (1.0/((obj.uni.cm_to_m)**3)) / obj.uni.usi_l / (obj.uni.usi_b * obj.uni.usi_l/obj.uni.u_t) # /obj.uni.u_p + bb_batz = gradz_pe / (nel * bb_constqe) + result = bb_batz + return result + + +# default +_SPITZER_QUANT = ('SPITZER_QUANT', ['fcx', 'fcy', 'fcz', 'qspitz']) +# set value + + +@document_vars.quant_tracking_simple(_BATTERY_QUANT[0]) +def get_spitzerparam(obj, quant, SPITZER_QUANT=None, **kwargs): + ''' + Spitzer related term + ''' + + if (SPITZER_QUANT is None): + SPITZER_QUANT = ['fcx', 'fcy', 'fcz', 'qspitz'] + + if quant == '': + docvar = document_vars.vars_documenter(obj, 'SPITZER_QUANT', SPITZER_QUANT, get_spitzerparam.__doc__) + docvar('fcx', "X component of the anisotropic electron heat flux, i.e., (kappae(B)*grad(Te))_x") + docvar('fcy', "Y component of the anisotropic electron heat flux, i.e., (kappae(B)*grad(Te))_y") + docvar('fcz', "Z component of the anisotropic electron heat flux, i.e., (kappae(B)*grad(Te))_z") + docvar('qspitz', "Electron heat flux, i.e., Qspitz [simu.u_e/simu.u_t] erg.s-1") + + if (quant == '') or not (quant in SPITZER_QUANT): + return None + + if (quant == 'fcx'): + kappaq = obj.get_var('kappaq') + gradx_Te = obj.get_var('detgdxup') + bx = obj.get_var('bx') + by = obj.get_var('by') + bz = obj.get_var('bz') + rhs = obj.get_var('rhs') + bmin = 1E-5 + + normb = np.sqrt(bx**2+by**2+bz**2) + norm2bmin = bx**2+by**2+bz**2+bmin**2 + + bbx = bx/normb + + bm = (bmin**2)/norm2bmin + + fcx = kappaq * (bbx*rhs+bm*gradx_Te) + + result = fcx + + if (quant == 'fcy'): + kappaq = obj.get_var('kappaq') + grady_Te = obj.get_var('detgdyup') + bx = obj.get_var('bx') + by = obj.get_var('by') + bz = obj.get_var('bz') + rhs = obj.get_var('rhs') + bmin = 1E-5 + + normb = np.sqrt(bx**2+by**2+bz**2) + norm2bmin = bx**2+by**2+bz**2+bmin**2 + + bby = by/normb + + bm = (bmin**2)/norm2bmin + + fcy = kappaq * (bby*rhs+bm*grady_Te) + + result = fcy + + if (quant == 'fcz'): + kappaq = obj.get_var('kappaq') + gradz_Te = obj.get_var('detgdzup') + bx = obj.get_var('bx') + by = obj.get_var('by') + bz = obj.get_var('bz') + rhs = obj.get_var('rhs') + bmin = 1E-5 + + normb = np.sqrt(bx**2+by**2+bz**2) + norm2bmin = bx**2+by**2+bz**2+bmin**2 + + bbz = bz/normb + + bm = (bmin**2)/norm2bmin + + fcz = kappaq * (bbz*rhs+bm*gradz_Te) + + result = fcz + + if (quant == 'qspitz'): + dxfcx = obj.get_var('dfcxdxup') + dyfcy = obj.get_var('dfcydyup') + dzfcz = obj.get_var('dfczdzup') + result = dxfcx + dyfcy + dzfcz + + return result + + +''' ------------- End get_quant() functions; Begin helper functions ------------- ''' + + +@njit(parallel=True) +def calc_field_lines(x, y, z, bxc, byc, bzc, niter=501): + + modb = np.sqrt(bxc**2+byc**2+bzc**2) + + ibxc = bxc / (modb+1e-30) + ibyc = byc / (modb+1e-30) + ibzc = bzc / (modb+1e-30) + + nx, ny, nz = bxc.shape + niter2 = int(np.floor(niter/2)) + dx = x[1]-x[0] + zl = np.zeros((nx, ny, nz, niter)) + yl = np.zeros((nx, ny, nz, niter)) + xl = np.zeros((nx, ny, nz, niter)) + for iix in prange(nx): + for iiy in prange(ny): + for iiz in prange(nz): + + xl[iix, iiy, iiz, niter2] = x[iix] + yl[iix, iiy, iiz, niter2] = y[iiy] + zl[iix, iiy, iiz, niter2] = z[iiz] + + for iil in prange(1, niter2+1): + iixp = np.argmin(x-xl[iix, iiy, iiz, niter2 + iil - 1]) + iiyp = np.argmin(y-yl[iix, iiy, iiz, niter2 + iil - 1]) + iizp = np.argmin(z-zl[iix, iiy, iiz, niter2 + iil - 1]) + + xl[iix, iiy, iiz, niter2 + iil] = xl[iix, iiy, iiz, niter2 + iil - 1] + ibxc[iixp, iiyp, iizp]*dx + yl[iix, iiy, iiz, niter2 + iil] = yl[iix, iiy, iiz, niter2 + iil - 1] + ibyc[iixp, iiyp, iizp]*dx + zl[iix, iiy, iiz, niter2 + iil] = zl[iix, iiy, iiz, niter2 + iil - 1] + ibzc[iixp, iiyp, iizp]*dx + + iixm = np.argmin(x-xl[iix, iiy, iiz, niter2 - iil + 1]) + iiym = np.argmin(y-yl[iix, iiy, iiz, niter2 - iil + 1]) + iizm = np.argmin(z-zl[iix, iiy, iiz, niter2 - iil + 1]) + + xl[iix, iiy, iiz, niter2 - iil] = xl[iix, iiy, iiz, niter2 - iil + 1] - ibxc[iixm, iiym, iizm]*dx + yl[iix, iiy, iiz, niter2 - iil] = yl[iix, iiy, iiz, niter2 - iil + 1] - ibyc[iixm, iiym, iizm]*dx + zl[iix, iiy, iiz, niter2 - iil] = zl[iix, iiy, iiz, niter2 - iil + 1] - ibzc[iixm, iiym, iizm]*dx + + return xl, yl, zl + + +@njit(parallel=True) +def calc_lenghth_lines(xl, yl, zl): + + nx, ny, nz, nl = np.shape(xl) + + S = np.zeros((nx, ny, nz)) + + for iix in prange(nx): + for iiy in prange(ny): + for iiz in prange(nz): + iilmin = np.argmin(zl[iix, iiy, iiz, :]) # Corona + iilmax = np.argmin(np.abs(zl[iix, iiy, iiz, :])) # Photosphere + for iil in prange(iilmax+1, iilmin): + S[iix, iiy, iiz] += np.sqrt((xl[iix, iiy, iiz, iil]-xl[iix, iiy, iiz, iil-1])**2 + + (yl[iix, iiy, iiz, iil]-yl[iix, iiy, iiz, iil-1])**2 + + (zl[iix, iiy, iiz, iil]-zl[iix, iiy, iiz, iil-1])**2) + + return S + + +def calc_tau(obj): + """ + Calculates optical depth. + + """ + if obj.verbose: + warnings.warn("Use of calc_tau is discouraged. It is model-dependent, " + "inefficient and slow.") + + # grph = 2.38049d-24 uni.GRPH + # bk = 1.38e-16 uni.KBOLTZMANN + # EV_TO_ERG=1.60217733E-12 uni.EV_TO_ERG + + units_temp = obj.transunits + + nel = obj.trans2comm('ne') + tg = obj.trans2comm('tg') + rho = obj.trans2comm('rho') + + tau = obj.zero() + 1.e-16 + xhmbf = np.zeros((obj.zLength)) + const = (1.03526e-16 / obj.uni.grph) * 2.9256e-17 + for iix in range(obj.nx): + for iiy in range(obj.ny): + for iiz in range(obj.nz): + xhmbf[iiz] = const * nel[iix, iiy, iiz] / \ + tg[iix, iiy, iiz]**1.5 * np.exp(0.754e0 * + obj.uni.ev_to_erg / obj.uni.kboltzmann / + tg[iix, iiy, iiz]) * rho[iix, iiy, iiz] + + for iiz in range(obj.nz-1, 0, -1): + tau[iix, iiy, iiz] = tau[iix, iiy, iiz - 1] + 0.5 *\ + (xhmbf[iiz] + xhmbf[iiz - 1]) *\ + np.abs(obj.dz1d[iiz]) + + if not units_temp: + obj.trans2noncommaxes + + return tau + + +def ionpopulation(obj, rho, nel, tg, elem='h', lvl='1', dens=True, **kwargs): + ''' + rho is cgs. + tg in [K] + nel in cgs. + The output, is in cgs + ''' + + if getattr(obj, 'verbose', True): + print('ionpopulation: reading species %s and level %s' % (elem, lvl), whsp, + end="\r", flush=True) + ''' + fdir = '.' + try: + tmp = find_first_match("*.idl", fdir) + except IndexError: + try: + tmp = find_first_match("*idl.scr", fdir) + except IndexError: + try: + tmp = find_first_match("mhd.in", fdir) + except IndexError: + tmp = '' + print("(WWW) init: no .idl or mhd.in files found." + + "Units set to 'standard' Bifrost units.") + ''' + uni = obj.uni + + totconst = 2.0 * uni.pi * uni.m_electron * uni.k_b / \ + uni.hplanck / uni.hplanck + abnd = np.zeros(len(uni.abnddic)) + count = 0 + + for ibnd in uni.abnddic.keys(): + abnddic = 10**(uni.abnddic[ibnd] - 12.0) + abnd[count] = abnddic * uni.weightdic[ibnd] * uni.amu + count += 1 + + abnd = abnd / np.sum(abnd) + phit = (totconst * tg)**(1.5) * 2.0 / nel + kbtg = uni.ev_to_erg / uni.k_b / tg + n1_n0 = phit * uni.u1dic[elem] / uni.u0dic[elem] * np.exp( + - uni.xidic[elem] * kbtg) + c2 = abnd[uni.atomdic[elem] - 1] * rho + ifracpos = n1_n0 / (1.0 + n1_n0) + + if dens: + if lvl == '1': + return (1.0 - ifracpos) * c2 + else: + return ifracpos * c2 + + else: + if lvl == '1': + return (1.0 - ifracpos) * c2 / uni.weightdic[elem] / uni.amu + else: + return ifracpos * c2 / uni.weightdic[elem] / uni.amu + + +def find_first_match(name, path, incl_path=False, **kwargs): + ''' + This will find the first match, + name : string, e.g., 'patern*' + incl_root: boolean, if true will add full path, otherwise, the name. + path : sring, e.g., '.' + ''' + errmsg = ('find_first_match() from load_quantities has been deprecated. ' + 'If you believe it should not be deprecated, you can easily restore it by going to ' + 'helita.sim.load_quantities and doing the following: ' + '(1) uncomment the "from glob import glob" at top of the file; ' + '(2) edit the find_first_match function: remove this error and uncomment the code. ' + '(3) please put a comment to explain where load_quantities.find_first_match() is used, ' + ' since it is not being used anywhere in the load_quantities file directly.') + raise Exception(errmsg) + """ + originalpath=os.getcwd() + os.chdir(path) + for file in glob(name): + if incl_path: + os.chdir(originalpath) + return os.path.join(path, file) + else: + os.chdir(originalpath) + return file + os.chdir(originalpath) + """ diff --git a/helita/sim/mah.py b/helita/sim/mah.py new file mode 100644 index 00000000..d70e668e --- /dev/null +++ b/helita/sim/mah.py @@ -0,0 +1,686 @@ +import numpy as np + +from . import document_vars +from .load_arithmetic_quantities import * +from .load_noeos_quantities import * +from .load_quantities import * +from .tools import * + + +class Mah: + """ + Class to read Lare3D sav file atmosphere + + Parameters + ---------- + fdir : str, optional + Directory with snapshots. + rootname : str, optional + Template for snapshot number. + verbose : bool, optional + If True, will print more information. + """ + + def __init__(self, run_name, snap, fdir='.', sel_units='cgs', verbose=True, + num_pts=300, ngridc=256, nzc=615, nzc5sav=7, nt5sav=1846): + + self.fdir = fdir + self.rootname = run_name + self.snap = snap + self.sel_units = sel_units + self.verbose = verbose + self.uni = Mah_units() + + self.read_ini() + self.read_dat1() + self.read_dat2() + self.read_dat3() + self.read_dat4() + self.read_dat5() + self.read_dat6() + + #self.x = dd.input_ini['xpos'] + #self.z = dd.input_ini['zpos'] + self.z = self.input_ini['spos'] + if self.sel_units == 'cgs': + # self.x *= self.uni.uni['l'] + # self.y *= self.uni.uni['l'] + self.z *= self.uni.uni['l'] + + self.num_pts = num_pts + self.nx = ngridc + self.ny = ngridc + self.nz = ngridc + + ''' + if self.nx > 1: + self.dx1d = np.gradient(self.x) + else: + self.dx1d = np.zeros(self.nx) + + if self.ny > 1: + self.dy1d = np.gradient(self.y) + else: + self.dy1d = np.zeros(self.ny) + + if self.nz > 1: + self.dz1d = np.gradient(self.z) + else: + self.dz1d = np.zeros(self.nz) + ''' + + self.transunits = False + + self.cstagop = False # This will not allow to use cstagger from Bifrost in load + self.hion = False # This will not allow to use HION from Bifrost in load + + self.genvar() + document_vars.create_vardict(self) + document_vars.set_vardocs(self) + + def read_ini(self): + f = open('%s.ini' % self.rootname, 'rb') + varnamenew = ['unk1', 'opt', + 'unk2', 'unk3', 'kmaxc', 'nsizec', 'ngridc', 'nzc', 'ntube', + 'unk4', 'unk5', 'kmaxt', 'nsizet', 'ngridt', 'nza', 'nzb', + 'unk6', 'unk7', 'nlev_max', 'max_section', 'max_jump', 'max_step', + 'unk8', 'unk9', 'ntmax', 'ndrv', 'nzc4sav', 'nq4sav', 'nm3sav', + 'unk10', 'unk11', 'nza2sav', 'nzb2sav', 'nzc2sav', 'nzc5sav', 'num_pts', + 'unk12', 'unk13', 'nt1sav', 'nt1del', 'nt2sav', 'nt2del', + 'unk14', 'unk15', 'nt5sav', 'nt5del', 'nt6sav', 'nt6del', + 'unk16', 'unk17', 'nmodec', 'kmax1dc'] + input = np.fromfile(f, dtype='int32', count=np.size(varnamenew)) + input_dic = {} + for idx, iv in enumerate(varnamenew): + input_dic[iv] = input[idx] + input_dic['amaxc'] = np.fromfile(f, dtype='float32', count=1) + input_dic['nnxc'] = np.fromfile(f, dtype='int32', count=input_dic['kmaxc']) + input_dic['nnyc'] = np.fromfile(f, dtype='int32', count=input_dic['kmaxc']) + input_dic['aaxc'] = np.fromfile(f, dtype='float32', count=input_dic['kmaxc']) + input_dic['aayc'] = np.fromfile(f, dtype='int32', count=input_dic['kmaxc']) + input_dic['aac'] = np.fromfile(f, dtype='int32', count=input_dic['kmaxc']) + varnamenew2 = ['unk18', 'unk19', 'nmodet', 'kmax1dt'] + input2 = np.fromfile(f, dtype='int32', count=np.size(varnamenew2)) + for idx, iv in enumerate(varnamenew2): + input_dic[iv] = input2[idx] + input_dic['amaxt'] = np.fromfile(f, dtype='float32', count=1) + input_dic['nnxt'] = np.fromfile(f, dtype='int32', count=input_dic['kmaxt']) + input_dic['nnyt'] = np.fromfile(f, dtype='int32', count=input_dic['kmaxt']) + input_dic['aaxt'] = np.fromfile(f, dtype='float32', count=input_dic['kmaxt']) + input_dic['aayt'] = np.fromfile(f, dtype='int32', count=input_dic['kmaxt']) + input_dic['aat'] = np.fromfile(f, dtype='int32', count=input_dic['kmaxt']) + unk2021 = np.fromfile(f, dtype='int32', count=2) + input_dic['nza2arr'] = np.fromfile(f, dtype='int32', count=input_dic['nza2sav']) + input_dic['nzb2arr'] = np.fromfile(f, dtype='int32', count=input_dic['nzb2sav']) + input_dic['nzc2arr'] = np.fromfile(f, dtype='int32', count=input_dic['nzc2sav']) + unk2223 = np.fromfile(f, dtype='int32', count=2) + input_dic['nzc4arr'] = np.fromfile(f, dtype='int32', count=input_dic['nzc4sav']) + input_dic['aa_casc'] = np.fromfile(f, dtype='float32', count=input_dic['nq4sav']) + input_dic['km3arr'] = np.fromfile(f, dtype='int32', count=input_dic['nm3sav']) + unk2425 = np.fromfile(f, dtype='int32', count=2) + input_dic['nzc5arr'] = np.fromfile(f, dtype='int32', count=input_dic['nzc5sav']) + input_dic['xpts_ini'] = (np.fromfile(f, dtype='float32', count=input_dic['num_pts']*input_dic['nzc5sav'])).reshape((input_dic['num_pts'], input_dic['nzc5sav'])) + input_dic['ypts_ini'] = (np.fromfile(f, dtype='float32', count=input_dic['num_pts']*input_dic['nzc5sav'])).reshape((input_dic['num_pts'], input_dic['nzc5sav'])) + + unk2627 = np.fromfile(f, dtype='int32', count=2) + varnamenew3 = ['dt0', 'tmin0', 'tmax0'] + for idx, iv in enumerate(varnamenew3): + input_dic[iv] = np.fromfile(f, dtype='float32', count=1) + varnamenew4 = ['imer1', 'imer2', 'itr1', 'itr2'] + for idx, iv in enumerate(varnamenew4): + input_dic[iv] = np.fromfile(f, dtype='int32', count=1) + varnamenew5 = ['str1', 'str2'] + for idx, iv in enumerate(varnamenew5): + input_dic[iv] = np.fromfile(f, dtype='float64', count=1) + input_dic['ztr0'] = np.fromfile(f, dtype='float32', count=1) + varnamenew6 = ['xoff1', 'yoff1', 'xoff2', 'yoff2'] + for idx, iv in enumerate(varnamenew6): + input_dic[iv] = np.fromfile(f, dtype='int32', count=1) + varnamenew7 = ['xm0', 'zm0', 'hm0', 'bm0', 'len_tot', 'tau0', 'vbase0', 'aat1_drv', + 'aat2_drv', 'bphot', 'width_base', 'temp_chrom', 'fact_chrom', + 'fipfactor', 'nexpo', 'len_cor', 'pres_cor', 'temp_max', 'bm_cor', + 'nu', 'nutim', 'nuconpc', 'nuconmc', 'aac1_mid', 'aac2_mid', 'nuconpt', + 'nuconmt', 'at1_mid', 'aat2_mid'] + for idx, iv in enumerate(varnamenew7): + input_dic[iv] = np.fromfile(f, dtype='float32', count=1) + + input_dic['nz'] = input_dic['nzc']+input_dic['nza']+input_dic['nzb']-2 + unk2627 = np.fromfile(f, dtype='int32', count=2) + varnamenew8 = ['trav', 'spos'] # trav is the total travel time + for idx, iv in enumerate(varnamenew8): + input_dic[iv] = np.fromfile(f, dtype='float64', count=input_dic['nz']) + varnamenew9 = ['xpos', 'zpos', 'gamm', 'grav', 'temp', 'mu', 'pres', 'rho', + 'bmag', 'va'] + for idx, iv in enumerate(varnamenew9): + input_dic[iv] = np.fromfile(f, dtype='float32', count=input_dic['nz']) + input_dic['nlev'] = np.fromfile(f, dtype='int32', count=input_dic['nz']-1) + input_dic['dva'] = np.fromfile(f, dtype='float32', count=input_dic['nz']-1) + + unk2627 = np.fromfile(f, dtype='int32', count=2) + varnamenew9 = ['vel', 'qwav', 'qrad', 'fent', 'fcon', 'zzp_scale', 'zzm_scale'] + for idx, iv in enumerate(varnamenew9): + input_dic[iv] = np.fromfile(f, dtype='float32', count=input_dic['nz']) + + unk2627 = np.fromfile(f, dtype='int32', count=2) + input_dic['num_section'] = np.fromfile(f, dtype='int32', count=input_dic['nlev_max']+1) + input_dic['section'] = np.fromfile(f, dtype='int32', + count=2*input_dic['max_section']*(input_dic['nlev_max']+1)).reshape((2, input_dic['max_section'], input_dic['nlev_max']+1)) + input_dic['num_jump'] = np.fromfile(f, dtype='int32', count=input_dic['nlev_max']) + input_dic['jump'] = np.fromfile(f, dtype='int32', + count=2*input_dic['max_jump']*(input_dic['nlev_max'])).reshape((2, input_dic['max_jump'], input_dic['nlev_max'])) + input_dic['list'] = np.fromfile(f, dtype='int32', count=3*input_dic['max_step']).reshape((input_dic['max_step'], 3)) + + unk2627 = np.fromfile(f, dtype='int32', count=2) + input_dic['rat0c'] = np.fromfile(f, dtype='float32', count=input_dic['kmaxc']) + input_dic['rat0t'] = np.fromfile(f, dtype='float32', count=input_dic['kmaxt']) + + unk2627 = np.fromfile(f, dtype='int32', count=2) + varnamenew10 = ['width_a', 'rat1pa_ini', 'rat1ma_ini', 'zzpa_scale', 'zzma_scale'] + for idx, iv in enumerate(varnamenew10): + input_dic[iv] = np.fromfile(f, dtype='float32', count=input_dic['nza']) + + unk2627 = np.fromfile(f, dtype='int32', count=2) + varnamenew10 = ['width_b', 'rat1pb_ini', 'rat1mb_ini', 'zzpb_scale', 'zzmb_scale'] + for idx, iv in enumerate(varnamenew10): + input_dic[iv] = np.fromfile(f, dtype='float32', count=input_dic['nza']) + + unk2627 = np.fromfile(f, dtype='int32', count=2) + varnamenew10 = ['width_c', 'rat1pc_ini', 'rat1mc_ini', 'zzpc_scale', 'zzmc_scale'] + for idx, iv in enumerate(varnamenew10): + input_dic[iv] = np.fromfile(f, dtype='float32', count=input_dic['nzc']) + + unk2627 = np.fromfile(f, dtype='int32', count=2) + input_dic['kind_drv'] = np.fromfile(f, dtype='int32', count=input_dic['ndrv']) + input_dic['omega_rms_drv'] = np.fromfile(f, dtype='float32', count=input_dic['ndrv']) + + unk2627 = np.fromfile(f, dtype='int32', count=2) + input_dic['time'] = np.fromfile(f, dtype='float32', count=input_dic['ntmax']+1) + input_dic['omega_driver'] = np.fromfile(f, dtype='float32', + count=input_dic['ndrv']*input_dic['ntube']*2*(input_dic['ntmax']+1)).reshape(( + input_dic['ndrv'], input_dic['ntube'], 2, (input_dic['ntmax']+1))) + self.input_ini = input_dic + + def read_dat1(self): + f = open('%s.dat1' % self.rootname, 'rb') + varlist = ['zzp', 'zzm', 'vrms', 'orms', 'brms', 'arms', 'eep', + 'eem', 'eer', 'ekin', 'emag', 'etot', 'rat1p', 'rat1m', 'qperp_p', + 'qperp_m', 'qtot'] + ''' + zzp -- amplitude of the outward waves (km/s) [nz,nt] + zzm -- amplitude of the inward waves (km/s) [nz,nt] + brms -- rms of the |B| (cgs) [nz,nt] + vrms -- rms of the |vel| (cgs) [nz,nt] + orms -- rms vorticity (cgs) [nz,nt] + etot -- total energy (cgs) [nz,nt] + ee -- energy denstiy (cgs) [nz,nt] + emag -- magnetic energy (cgs) [nz,nt] + ekin -- kinetic energy (cgs) [nz,nt] + qperp_p -- perpendicular heating rate (cgs) [nz,nt] + ''' + self.input_dat1 = {} + input2 = np.fromfile(f, dtype='int32', count=1) + for idx, iv in enumerate(varlist): + self.input_dat1[iv] = np.zeros((self.input_ini['nz'], self.input_ini['nt1sav'])) + for it in range(self.input_ini['nt1sav']): + try: + for idx, iv in enumerate(varlist): + self.input_dat1[iv][:, it] = np.fromfile(f, dtype='float32', count=self.input_ini['nz']) + except: + self.input_ini['nt1sav'] = it + varnamenew2 = ['unk18', 'unk19'] + input2 = np.fromfile(f, dtype='int32', count=np.size(varnamenew2)) + + def read_dat4(self): + f = open('%s.ini' % self.rootname, 'rb') + varlist = ['qcasc_p', 'qcasc_m'] + self.input_dat4 = {} + input2 = np.fromfile(f, dtype='int32', count=1) + for idx, iv in enumerate(varlist): + self.input_dat4[iv] = np.zeros((self.input_ini['nq4sav'], + self.input_ini['nzc4sav'], + self.input_ini['nt1sav'])) + for it in range(self.input_ini['nt1sav']): + try: + for idx, iv in enumerate(varlist): + self.input_dat4[iv][:, :, it] = (np.fromfile(f, dtype='float32', count=np.size(self.input_dat4[iv][..., 0]))).reshape(( + self.input_ini['nq4sav'], self.input_ini['nzc4sav'])) + varnamenew2 = ['unk18', 'unk19'] + input2 = np.fromfile(f, dtype='int32', count=np.size(varnamenew2)) + except: + self.input_ini['nt1sav'] = it + varnamenew2 = ['unk18', 'unk19'] + input2 = np.fromfile(f, dtype='int32', count=np.size(varnamenew2)) + + def read_dat2(self): + f = open('%s.dat2' % self.rootname, 'rb') + input2 = np.fromfile(f, dtype='int32', count=1) + varlista = ['omegpa', 'omegma'] # voriticty w+, w- (x,y) + self.input_dat2 = {} + for idx, iv in enumerate(varlista): + self.input_dat2[iv] = np.zeros((self.input_ini['kmaxt'], + self.input_ini['nza2sav'], + self.input_ini['ntube'], + self.input_ini['nt2sav'])) + varlistb = ['omegpb', 'omegmb'] + for idx, iv in enumerate(varlistb): + self.input_dat2[iv] = np.zeros((self.input_ini['kmaxt'], + self.input_ini['nzb2sav'], + self.input_ini['ntube'], + self.input_ini['nt2sav'])) + + varlistc = ['omegpc', 'omegmc'] + for idx, iv in enumerate(varlistc): + self.input_dat2[iv] = np.zeros((self.input_ini['kmaxt'], + self.input_ini['nzc2sav'], + self.input_ini['ntube'])) + + for it in range(self.input_ini['nt2sav']): + try: + for idx, iv in enumerate(varlista): + self.input_dat2[iv][:, :, it] = (np.fromfile(f, dtype='float32', count=np.size(self.input_dat2[iv][..., 0])))*reshape(( + self.input_ini['ntube'], self.input_ini['nza2sav'], self.input_ini['kmaxt'])).T + varnamenew2 = ['unk18', 'unk19'] + input2 = np.fromfile(f, dtype='int32', count=np.size(varnamenew2)) + for idx, iv in enumerate(varlistb): + self.input_dat2[iv][:, :, it] = (np.fromfile(f, dtype='float32', count=np.size(self.input_dat2[iv][..., 0])))*reshape(( + self.input_ini['ntube'], self.input_ini['nzb2sav'], self.input_ini['kmaxt'])).T + varnamenew2 = ['unk18', 'unk19'] + input2 = np.fromfile(f, dtype='int32', count=np.size(varnamenew2)) + for idx, iv in enumerate(varlistc): + self.input_dat2[iv][:, :, it] = (np.fromfile(f, dtype='float32', count=np.size(self.input_dat2[iv][..., 0])))*reshape(( + self.input_ini['nzc2sav'], self.input_ini['kmaxt'])).T + varnamenew2 = ['unk18', 'unk19'] + input2 = np.fromfile(f, dtype='int32', count=np.size(varnamenew2)) + except: + self.input_ini['nt2sav'] = it + + def read_dat3(self): + f = open('%s.dat3' % self.rootname, 'rb') + input2 = np.fromfile(f, dtype='int32', count=1) + varlist = ['omegpc3', 'omegmc3'] + self.input_dat3 = {} + for idx, iv in enumerate(varlist): + self.input_dat3[iv] = np.zeros((self.input_ini['nm3sav'], + self.input_ini['nzc'], + self.input_ini['nt2sav'])) + for it in range(self.input_ini['nt2sav']): + try: + for idx, iv in enumerate(varlist): + self.input_dat3[iv][:, :, it] = (np.fromfile(f, dtype='float32', count=np.size(self.input_dat3[iv][..., 0]))).reshape(( + self.input_ini['nzc'], self.input_ini['nm3sav'])).T + varnamenew2 = ['unk18', 'unk19'] + input2 = np.fromfile(f, dtype='int32', count=np.size(varnamenew2)) + except: + self.input_ini['nt2sav'] = it + + def read_dat5(self): + f = open('%s.dat5' % self.rootname, 'rb') + input2 = np.fromfile(f, dtype='int32', count=1) + varlista = ['vel1', 'vel2'] # LOS vel, Non thermal velocity (both perp to the loop). + self.input_dat5 = {} + for idx, iv in enumerate(varlista): + self.input_dat5[iv] = np.zeros((self.input_ini['ngridc'], + self.input_ini['nzc'], + self.input_ini['nt5sav'])) + varlistb = ['qqq0'] # heating rate. + for idx, iv in enumerate(varlistb): + self.input_dat5[iv] = np.zeros((self.input_ini['ngridc'], + self.input_ini['ngridc'], + self.input_ini['nzc5sav'], + self.input_ini['nt5sav'])) + + varlistc = ['xpts', 'ypts', 'qpts', 'vxpts', 'vypts'] + for idx, iv in enumerate(varlistc): + self.input_dat5[iv] = np.zeros((self.input_ini['num_pts'], + self.input_ini['nzc5sav'], + self.input_ini['nt5sav'])) + + for it in range(self.input_ini['nt5sav']): + # try: + for idx, iv in enumerate(varlista): + self.input_dat5[iv][..., it] = (np.fromfile(f, dtype='float32', count=np.size(self.input_dat5[iv][..., 0]))).reshape(( + self.input_ini['nzc'], self.input_ini['ngridc'])).T + varnamenew2 = ['unk18', 'unk19'] + input2 = np.fromfile(f, dtype='int32', count=np.size(varnamenew2)) + for idx, iv in enumerate(varlistb): + self.input_dat5[iv][..., it] = (np.fromfile(f, dtype='float32', count=np.size(self.input_dat5[iv][..., 0]))).reshape(( + self.input_ini['nzc5sav'], self.input_ini['ngridc'], self.input_ini['ngridc'])).T + varnamenew2 = ['unk18', 'unk19'] + input2 = np.fromfile(f, dtype='int32', count=np.size(varnamenew2)) + for idx, iv in enumerate(varlistc): + self.input_dat5[iv][..., it] = (np.fromfile(f, dtype='float32', count=np.size(self.input_dat5[iv][..., 0]))).reshape(( + self.input_ini['nzc5sav'], self.input_ini['num_pts'])).T + varnamenew2 = ['unk18', 'unk19'] + input2 = np.fromfile(f, dtype='int32', count=np.size(varnamenew2)) + # except: + # self.input_ini['nt5sav'] = it + + def read_dat6(self): + f = open('%s.dat6' % self.rootname, 'rb') + input2 = np.fromfile(f, dtype='int32', count=1) + varlista = ['bbxa', 'bbya'] + self.input_dat6 = {} + for idx, iv in enumerate(varlista): + self.input_dat6[iv] = np.zeros((self.input_ini['ngridt']+1, + self.input_ini['ngridt']+1, + self.input_ini['nza'], + self.input_ini['ntube'], + self.input_ini['nt6sav'])) + varlistb = ['bbxb', 'bbyb'] + for idx, iv in enumerate(varlistb): + self.input_dat6[iv] = np.zeros((self.input_ini['ngridt']+1, + self.input_ini['ngridt']+1, + self.input_ini['nzb'], + self.input_ini['ntube'], + self.input_ini['nt6sav'])) + + varlistc = ['bbxc', 'bbyc'] + for idx, iv in enumerate(varlistc): + self.input_dat6[iv] = np.zeros((self.input_ini['ngridt']+1, + self.input_ini['ngridt']+1, + self.input_ini['nzc'], + self.input_ini['nt6sav'])) + + for it in range(self.input_ini['nt6sav']): + try: + for idx, iv in enumerate(varlista): + self.input_dat6[iv][:, :, it] = (np.fromfile(f, dtype='float32', count=np.size(self.input_dat6[iv][..., 0])))*reshape(( + self.input_ini['ntube'], self.input_ini['nza'], + self.input_ini['ngridt']+1, self.input_ini['ngridt']+1)).T + varnamenew2 = ['unk18', 'unk19'] + input2 = np.fromfile(f, dtype='int32', count=np.size(varnamenew2)) + for idx, iv in enumerate(varlistb): + self.input_dat6[iv][:, :, it] = (np.fromfile(f, dtype='float32', count=np.size(self.input_dat6[iv][..., 0])))*reshape(( + self.input_ini['ntube'], self.input_ini['nzb'], + self.input_ini['ngridt']+1, self.input_ini['ngridt']+1)).T + varnamenew2 = ['unk18', 'unk19'] + input2 = np.fromfile(f, dtype='int32', count=np.size(varnamenew2)) + for idx, iv in enumerate(varlistc): + self.input_dat6[iv][:, :, it] = (np.fromfile(f, dtype='float32', count=np.size(self.input_dat6[iv][..., 0])))*reshape(( + self.input_ini['nzb'], self.input_ini['ngridt']+1, + self.input_ini['ngridt']+1)).T + varnamenew2 = ['unk18', 'unk19'] + input2 = np.fromfile(f, dtype='int32', count=np.size(varnamenew2)) + except: + self.input_ini['nt6sav'] = it + + def get_var(self, var, *args, snap=None, iix=None, iiy=None, iiz=None, layout=None, **kargs): + ''' + Reads the variables from a snapshot (snap). + + Parameters + ---------- + var - string + Name of the variable to read. Must be Bifrost internal names. + snap - integer, optional + Snapshot number to read. By default reads the loaded snapshot; + if a different number is requested, will load that snapshot. + + Axes: + ----- + x and y axes horizontal plane + z-axis is vertical axis, top corona is last index and positive. + + Variable list: + -------------- + rho -- Density (g/cm^3) [nz] + energy -- Energy (erg) [nz] + tg -- Temperature (K) [nz] + modb -- |B| (G) [nz] + va -- Alfven speed (km/s) [nz] + pg -- Pressure (cgs) [nz] + vx -- component x of the velocity (multipy by self.uni['u'] to get in cm/s) [nx+1, ny+1, nz+1] + vy -- component y of the velocity (multipy by self.uni['u'] to get in cm/s) [nx+1, ny+1, nz+1] + vz -- component z of the velocity (multipy by self.uni['u'] to get in cm/s) [nx+1, ny+1, nz+1] + bx -- component x of the magnetic field (multipy by self.uni['b'] to get in G) [nx+1, ny, nz] + by -- component y of the magnetic field (multipy by self.uni['b'] to get in G) [nx, ny+1, nz] + bz -- component z of the magnetic field (multipy by self.uni['b'] to get in G) [nx, ny, nz+1] + jx -- component x of the current [nx+1, ny+1, nz+1] + jy -- component x of the current [nx+1, ny+1, nz+1] + jz -- component x of the current [nx+1, ny+1, nz+1] + eta -- eta (?) [nx, ny, nz] + + ''' + + if var in self.varn.keys(): + varname = self.varn[var] + elif var in self.varn1.keys(): + varname = self.varn1[var] + elif var in self.varn2.keys(): + varname = self.varn2[var] + elif var in self.varn3.keys(): + varname = self.varn3[var] + elif var in self.varn4.keys(): + varname = self.varn4[var] + elif var in self.varn5.keys(): + varname = self.varn5[var] + elif var in self.varn6.keys(): + varname = self.varn6[var] + else: + varname = var + + if snap != None: + self.snap = snap + + # try: + + varu = var.replace('x', '') + varu = varu.replace('y', '') + varu = varu.replace('z', '') + + if self.sel_units == 'cgs': + if (var in self.varn.keys()) and (varu in self.uni.uni.keys()): + self.uni.uni[varu] + else: + pass + else: + pass + + if var in self.varn.keys(): + self.data = (self.input_ini[varname]) + elif var in self.varn1.keys(): + self.data = (self.input_dat1[varname]) + if snap != None: + self.data = self.data[..., self.snap] + elif var in self.varn2.keys(): + self.data = (self.input_dat2[varname]) + if snap != None: + self.data = self.data[..., self.snap] + elif var in self.varn3.keys(): + self.data = (self.input_dat3[varname]) + if snap != None: + self.data = self.data[..., self.snap] + elif var in self.varn4.keys(): + self.data = (self.input_dat4[varname]) + if snap != None: + self.data = self.data[..., self.snap] + elif var in self.varn5.keys(): + self.data = (self.input_dat5[varname]) + if snap != None: + self.data = self.data[..., self.snap] + elif var in self.varn6.keys(): + self.data = (self.input_dat6[varname]) + if snap != None: + self.data = self.data[..., self.snap] + ''' + except: + # Loading quantities + if self.verbose: + print('Loading composite variable',end="\r",flush=True) + self.data = load_noeos_quantities(self,var, **kargs) + + if np.shape(self.data) == (): + self.data = load_quantities(self,var,PLASMA_QUANT='', CYCL_RES='', + COLFRE_QUANT='', COLFRI_QUANT='', IONP_QUANT='', + EOSTAB_QUANT='', TAU_QUANT='', DEBYE_LN_QUANT='', + CROSTAB_QUANT='', COULOMB_COL_QUANT='', AMB_QUANT='', + HALL_QUANT='', BATTERY_QUANT='', SPITZER_QUANT='', + KAPPA_QUANT='', GYROF_QUANT='', WAVE_QUANT='', + FLUX_QUANT='', CURRENT_QUANT='', COLCOU_QUANT='', + COLCOUMS_QUANT='', COLFREMX_QUANT='', **kargs) + + # Loading arithmetic quantities + if np.shape(self.data) == (): + if self.verbose: + print('Loading arithmetic variable',end="\r",flush=True) + self.data = load_arithmetic_quantities(self,var, **kargs) + ''' + elif document_vars.creating_vardict(self): + return None + elif var == '': + + print(help(self.get_var)) + print('VARIABLES USING CGS OR GENERIC NOMENCLATURE') + for ii in self.varn: + print('use ', ii, ' for ', self.varn[ii]) + if hasattr(self, 'vardict'): + self.vardocs() + return None + + return self.data + + def readvar(self, inputfilename, nx, ny, nz, snap, nvar): + f = open(inputfilename, 'rb') + f.seek(8*nvar*nx*ny*nz + 64*snap*nx*ny*nz) + print(8*nvar*nx*ny*nz, 64*snap*nx*ny*nz) + var = np.fromfile(f, dtype='float32', count=nx*ny*nz) + var = np.reshape(var, (self.nx, self.ny, self.nz)) + f.close() + return var + + def genvar(self): + ''' + Dictionary of original variables which will allow to convert to cgs. + ''' + self.varn = {} + self.varn['rho'] = 'rho' + self.varn['tg'] = 'temp' + self.varn['e'] = 'e' + self.varn['pg'] = 'pres' + self.varn['ux'] = 'ux' + self.varn['uy'] = 'uy' + self.varn['uz'] = 'uz' + self.varn['modb'] = 'bmag' + self.varn['va'] = 'va' + self.varn['by'] = 'by' + self.varn['bz'] = 'bz' + self.varn['jx'] = 'jx' + self.varn['jy'] = 'jy' + self.varn['jz'] = 'jz' + + varlist = ['zzp', 'zzm', 'vrms', 'orms', 'brms', 'arms', 'eep', + 'eem', 'eer', 'ekin', 'emag', 'etot', 'rat1p', 'rat1m', 'qperp_p', + 'qperp_m', 'qtot'] + + self.varn1 = {} + for var in varlist: + self.varn1[var] = var + # self.varn1['zzp'] = 'zzp' # amplitude of the outward waves (km/s) + # self.varn1['zzm'] = 'zzm' # amplitude of the inward waves (km/s) + # self.varn1['vrms'] = 'vrms' # rms of the |vel| (cgs) + # self.varn1['orms'] = 'orms' # rms vorticity (cgs) + # self.varn1['brms'] = 'brms' # rms of the |B| (cgs) + # self.varn1['e'] = 'etot' # total energy (cgs) + # self.varn1['emag'] = 'emag' # magnetic energy (cgs) + # self.varn1['ekin'] = 'ekin' # kinetic energy (cgs) + # self.varn1['qperp_p']= 'qperp_p' # perpendicular heating rate (cgs) + + varlist = ['omegpa', 'omegma', 'omegpb', 'omegmb', 'omegpc', 'omegmc'] + self.varn2 = {} + for var in varlist: + self.varn2[var] = var + + varlist = ['omegpc3', 'omegmc3'] + self.varn3 = {} + for var in varlist: + self.varn3[var] = var + + varlist = ['qcasc_p', 'qcasc_m'] + self.varn4 = {} + for var in varlist: + self.varn4[var] = var + + varlist = ['vel1', 'vel2', 'qqq0', 'xpts', 'ypts', 'qpts', 'vxpts', 'vypts'] + self.varn5 = {} + for var in varlist: + self.varn5[var] = var + + varlist = ['bbxa', 'bbya', 'bbxb', 'bbyb', 'bbxc', 'bbyc'] + self.varn6 = {} + for var in varlist: + self.varn6[var] = var + + def trans2comm(self, varname, snap=None): + ''' + Transform the domain into a "common" format. All arrays will be 3D. The 3rd axis + is: + + - for 3D atmospheres: the vertical axis + - for loop type atmospheres: along the loop + - for 1D atmosphere: the unique dimension is the 3rd axis. + At least one extra dimension needs to be created artifically. + + All of them should obey the right hand rule + + In all of them, the vectors (velocity, magnetic field etc) away from the Sun. + + If applies, z=0 near the photosphere. + + Units: everything is in cgs. + + If an array is reverse, do ndarray.copy(), otherwise pytorch will complain. + + ''' + + self.sel_units = 'cgs' + + self.trans2commaxes + + var = self.get_var(varname, snap=snap).copy() + + #var = transpose(var,(X,X,X)) + # also velocities. + + return var + + def trans2commaxes(self): + + if self.transunits == False: + # self.x = # including units conversion + # self.y = + # self.z = + # self.dx = + # self.dy = + # self.dz = + self.transunits = True + + def trans2noncommaxes(self): + + if self.transunits == True: + # opposite to the previous function + self.transunits = False + + +class Mah_units(object): + + def __init__(self, verbose=False): + ''' + Units and constants in cgs + ''' + self.uni = {} + self.verbose = verbose + self.uni['gamma'] = 5./3. + self.uni['tg'] = 1.0 # K + self.uni['l'] = 1.0e5 # km -> cm + self.uni['rho'] = 1.0 # gr cm^-3 + self.uni['u'] = 1.0 # cm/s + self.uni['b'] = 1.0 # Gauss + self.uni['t'] = 1.0 # seconds + + # Units and constants in SI + + convertcsgsi(self) + + globalvars(self) + + self.uni['n'] = self.uni['rho'] / self.m_p / 2. # cm^-3 diff --git a/helita/sim/matsumotosav.py b/helita/sim/matsumotosav.py new file mode 100644 index 00000000..b0172025 --- /dev/null +++ b/helita/sim/matsumotosav.py @@ -0,0 +1,299 @@ +import os + +import numpy as np +from scipy.io import readsav as rsav + +from . import document_vars +from .load_arithmetic_quantities import * +from .load_noeos_quantities import * +from .load_quantities import * +from .tools import * + + +class Matsumotosav: + """ + Class to read Matsumoto's sav file atmosphere. + Snapshots from a MHD simulation ( Matsumoto 2018 ) + https://ui.adsabs.harvard.edu/abs/2018MNRAS.476.3328M/abstract + + Parameters + ---------- + fdir : str, optional + Directory with snapshots. + rootname : str + Template for snapshot number. + it : integer + Snapshot number to read. By default reads the loaded snapshot; + if a different number is requested, will load that snapshot. + verbose : bool, optional + If True, will print more information. + """ + + def __init__(self, rootname, snap, fdir='.', sel_units='cgs', verbose=True): + + self.fdir = fdir + self.rootname = rootname + self.savefile = rsav(os.path.join(fdir, rootname+'{:06d}'.format(snap)+'.sav')) + self.snap = snap + self.sel_units = sel_units + self.verbose = verbose + self.uni = Matsumotosav_units() + + self.time = self.savefile['v']['time'][0].copy() + self.grav = self.savefile['v']['gx'][0].copy() + self.gamma = self.savefile['v']['gm'][0].copy() + + if self.sel_units == 'cgs': + self.x = self.savefile['v']['x'][0].copy() # cm + self.y = self.savefile['v']['y'][0].copy() + self.z = self.savefile['v']['z'][0].copy() + + self.dx = self.savefile['v']['dx'][0].copy() + self.dy = self.savefile['v']['dy'][0].copy() + self.dz = self.savefile['v']['dz'][0].copy() + else: + self.x = self.savefile['v']['x'][0].copy()/1e8 # Mm + self.y = self.savefile['v']['y'][0].copy()/1e8 + self.z = self.savefile['v']['z'][0].copy()/1e8 + + self.dx = self.savefile['v']['dx'][0].copy()/1e8 + self.dy = self.savefile['v']['dy'][0].copy()/1e8 + self.dz = self.savefile['v']['dz'][0].copy()/1e8 + + self.nx = len(self.x) + self.ny = len(self.y) + self.nz = len(self.z) + + if self.nx > 1: + self.dx1d = np.gradient(self.x) + else: + self.dx1d = np.zeros(self.nx) + + if self.ny > 1: + self.dy1d = np.gradient(self.y) + else: + self.dy1d = np.zeros(self.ny) + + if self.nz > 1: + self.dz1d = np.gradient(self.z) + else: + self.dz1d = np.zeros(self.nz) + + self.transunits = False + + self.cstagop = False # This will not allow to use cstagger from Bifrost in load + self.hion = False # This will not allow to use HION from Bifrost in load + + self.genvar() + document_vars.create_vardict(self) + document_vars.set_vardocs(self) + + def get_var(self, var, *args, snap=None, iix=None, iiy=None, iiz=None, layout=None, **kargs): + ''' + Reads the variables from a snapshot (snap). + + Parameters + ---------- + var - string + Name of the variable to read. Must be Bifrost internal names. + snap - integer, optional + Snapshot number to read. By default reads the loaded snapshot; + if a different number is requested, will load that snapshot. + + Axes: + ----- + y and z axes horizontal plane + x-axis is vertical axis, top corona is first index and negative. + + Variable list: + -------------- + ro -- Density (g/cm^3) [nx, ny, nz] + temperature -- Temperature (K) [nx, ny, nz] + vx -- component x of the velocity (cm/s) [nx, ny, nz] + vy -- component y of the velocity (cm/s) [nx, ny, nz] + vz -- component z of the velocity (cm/s) [nx, ny, nz] + bx -- component x of the magnetic field (G) [nx, ny, nz] + by -- component y of the magnetic field (G) [nx, ny, nz] + bz -- component z of the magnetic field (G) [nx, ny, nz] + pressure -- Pressure (dyn/cm^2) [nx, ny, nz] + + ''' + + if snap != None: + self.snap = snap + self.savefile = rsav(os.path.join(self.fdir, self.rootname+'{:06d}'.format(self.snap)+'.sav')) + + if var in self.varn.keys(): + varname = self.varn[var] + else: + varname = var + + try: + + if self.sel_units == 'cgs': + varu = var.replace('x', '') + varu = varu.replace('y', '') + varu = varu.replace('z', '') + if (var in self.varn.keys()) and (varu in self.uni.uni.keys()): + cgsunits = self.uni.uni[varu] + else: + cgsunits = 1.0 + else: + cgsunits = 1.0 + + self.data = self.savefile['v'][varname][0].T * cgsunits + ''' + if (np.shape(self.data)[0]>self.nx): + self.data = (self.data[1:,:,:] + self.data[:-1,:,:]) / 2 + + if (np.shape(self.data)[1]>self.ny): + self.data = (self.data[:,1:,:] + self.data[:,:-1,:]) / 2 + + if (np.shape(self.data)[2]>self.nz): + self.data = (self.data[:,:,1:] + self.data[:,:,:-1]) / 2 + ''' + except: + # Loading quantities + if self.verbose: + print('Loading composite variable', end="\r", flush=True) + self.data = load_noeos_quantities(self, var, **kargs) + + if np.shape(self.data) == (): + self.data = load_quantities(self, var, PLASMA_QUANT='', CYCL_RES='', + COLFRE_QUANT='', COLFRI_QUANT='', IONP_QUANT='', + EOSTAB_QUANT='', TAU_QUANT='', DEBYE_LN_QUANT='', + CROSTAB_QUANT='', COULOMB_COL_QUANT='', AMB_QUANT='', + HALL_QUANT='', BATTERY_QUANT='', SPITZER_QUANT='', + KAPPA_QUANT='', GYROF_QUANT='', WAVE_QUANT='', + FLUX_QUANT='', CURRENT_QUANT='', COLCOU_QUANT='', + COLCOUMS_QUANT='', COLFREMX_QUANT='', **kargs) + + # Loading arithmetic quantities + if np.shape(self.data) == (): + if self.verbose: + print('Loading arithmetic variable', end="\r", flush=True) + self.data = load_arithmetic_quantities(self, var, **kargs) + + if document_vars.creating_vardict(self): + return None + elif var == '': + + print(help(self.get_var)) + print('VARIABLES USING CGS OR GENERIC NOMENCLATURE') + for ii in self.varn: + print('use ', ii, ' for ', self.varn[ii]) + if hasattr(self, 'vardict'): + self.vardocs() + + return None + + return self.data + + def genvar(self): + ''' + Dictionary of original variables which will allow to convert to cgs. + ''' + self.varn = {} + self.varn['rho'] = 'ro' + self.varn['tg'] = 'te' + self.varn['pg'] = 'pr' + self.varn['ux'] = 'vx' + self.varn['uy'] = 'vy' + self.varn['uz'] = 'vz' + self.varn['bx'] = 'bx' + self.varn['by'] = 'by' + self.varn['bz'] = 'bz' + + def trans2comm(self, varname, snap=None): + ''' + Transform the domain into a "common" format. All arrays will be 3D. The 3rd axis + is: + + - for 3D atmospheres: the vertical axis + - for loop type atmospheres: along the loop + - for 1D atmosphere: the unique dimension is the 3rd axis. + At least one extra dimension needs to be created artifically. + + All of them should obey the right hand rule + + In all of them, the vectors (velocity, magnetic field etc) away from the Sun. + + If applies, z=0 near the photosphere. + + Units: everything is in cgs. + + If an array is reverse, do ndarray.copy(), otherwise pytorch will complain. + + ''' + + self.sel_units = 'cgs' + + if varname[-1] in ['x', 'y', 'z']: + if varname[-1] == 'x': + varname = varname.replace(varname[len(varname)-1], 'y') + elif varname[-1] == 'y': + varname = varname.replace(varname[len(varname)-1], 'z') + else: + varname = varname.replace(varname[len(varname)-1], 'x') + + self.order = np.array((1, 2, 0)) + + self.trans2commaxes() + + return np.transpose(self.get_var(varname, snap=snap), + self.order).copy() + + def trans2commaxes(self): + + if self.transunits == False: + # including units conversion + axisarrs = np.array(((self.x), (self.y), (self.z))) + daxisarrs = np.array(((self.dx), (self.dy), (self.dz))) + self.x = axisarrs[self.order[0]].copy() + self.y = axisarrs[self.order[1]].copy() + self.z = axisarrs[self.order[2]].copy() + np.max(np.abs(axisarrs[self.order[2]])) + self.dx = daxisarrs[self.order[0]].copy() + self.dy = daxisarrs[self.order[1]].copy() + self.dz = -axisarrs[self.order[2]].copy() + self.dx1d, self.dy1d, self.dz1d = np.gradient(self.x).copy(), np.gradient(self.y).copy(), np.gradient(self.z).copy() + self.nx, self.ny, self.nz = np.size(self.x), np.size(self.dy), np.size(self.dz) + self.transunits = True + + def trans2noncommaxes(self): + + if self.transunits == True: + # opposite to the previous function + axisarrs = np.array(((self.x), (self.y), (self.z))) + self.x = axisarrs[self.order[0]].copy() + self.y = axisarrs[self.order[1]].copy() + self.z = (- axisarrs[self.order[2]]).copy() - np.max(np.abs(axisarrs[self.order[2]])) + self.dx = (daxisarrs[self.order[0]]).copy() + self.dy = daxisarrs[self.order[1]].copy() + self.dz = (- axisarrs[self.order[2]]).copy() + self.dx1d, self.dy1d, self.dz1d = np.gradient(self.x).copy(), np.gradient(self.y).copy(), np.gradient(self.z).copy() + self.nx, self.ny, self.nz = np.size(self.x), np.size(self.dy), np.size(self.dz) + self.transunits = False + + +class Matsumotosav_units(object): + + def __init__(self, verbose=False): + ''' + Units and constants in cgs + ''' + self.uni = {} + self.verbose = verbose + self.uni['tg'] = 1.0 # K + self.uni['l'] = 1.0e8 # Mm -> cm + self.uni['rho'] = 1.0 # gr cm^-3 + self.uni['n'] = 1.0 # cm^-3 + self.uni['u'] = 1.0 # cm/s + self.uni['b'] = 1.0 # Gauss + self.uni['t'] = 1.0 # seconds + + # Units and constants in SI + convertcsgsi(self) + + globalvars(self) + + self.uni['gamma'] = 5./3. diff --git a/helita/sim/multi.py b/helita/sim/multi.py index e4e9296e..85cb8719 100644 --- a/helita/sim/multi.py +++ b/helita/sim/multi.py @@ -1,14 +1,16 @@ """ Set of routines to interface with MULTI (1D or _3D) """ -import numpy as np import os +import numpy as np + class Multi_3dOut: """ Class that reads and deals with output from multi_3d """ + def __init__(self, outfile=None, basedir='.', atmosid='', length=4, verbose=False, readall=False): self.verbose = verbose @@ -90,20 +92,20 @@ def read_out3d(self, outfile, length=4): length=length) setattr(self, cname, np.transpose(aa.reshape(self.ndep, self.ny, - self.nx))) + self.nx))) elif cname == 'Iv': self.check_basic() aa = fort_read(file, isize, 'f', big_endian=be, length=length) self.Iv = np.transpose(aa.reshape(self.ny, self.nx, - self.nqtot)) + self.nqtot)) elif cname == 'n3d': # might be brokenp... self.check_basic() self.nk = isize // (self.nx * self.ny * self.ndep) aa = fort_read(file, isize, 'f', big_endian=be, length=length) self.n3d = np.transpose(aa.reshape(self.nk, self.ndep, - self.ny, self.nx)) + self.ny, self.nx)) elif cname == 'nk': self.nk = fort_read(file, 1, 'i', big_endian=be, length=length)[0] @@ -194,7 +196,6 @@ def read(self, infile, big_endian, length=4): file.close() return - def write_rh15d(self, outfile, sx=None, sy=None, sz=None, desc=None): ''' Writes atmos into rh15d NetCDF format. ''' from . import rh15d @@ -384,6 +385,7 @@ def write_atmos3d(outfile, x, y, z, ne, temp, vz, vx=None, vy=None, rho=None, None. Writes file to disk. """ import os + from ..io.fio import fort_write if os.path.isfile(outfile): diff --git a/helita/sim/multi3d.py b/helita/sim/multi3d.py index 914d61cf..577c95de 100644 --- a/helita/sim/multi3d.py +++ b/helita/sim/multi3d.py @@ -3,15 +3,17 @@ """ import os + +import astropy.units as u import numpy as np import scipy.io -import astropy.units as u class Geometry: """ class def for geometry """ + def __init__(self): self.nx = -1 self.ny = -1 @@ -30,6 +32,7 @@ class Atom: """ class def for atom """ + def __init__(self): self.nrad = -1 self.nrfix = -1 @@ -51,10 +54,12 @@ def __init__(self): self.totn = None self.dopfac = None + class Atmos: """ class def for atmos """ + def __init__(self): self.ne = None self.tg = None @@ -71,6 +76,7 @@ class Spectrum: """ class def for spectrum """ + def __init__(self): self.nnu = -1 self.maxal = -1 @@ -87,6 +93,7 @@ class Cont: """ class def for continuum """ + def __init__(self): self.f_type = None self.j = -1 @@ -107,6 +114,7 @@ class Line: """ class def for spectral line """ + def __init__(self): self.profile_type = None self.ga = -1.0 @@ -137,6 +145,7 @@ class Transition: """ class to hold transition info for IO """ + def __init__(self): self.i = -1 self.j = -1 @@ -168,19 +177,19 @@ class Multi3dOut: Examples -------- - >>> data = Multi3dOut(directory='./output') - >>> data.readall() + data = Multi3dOut(directory='./output') + data.readall() Now select transition (by upper / lower level): - >>> data.set_transition(3, 2) - >>> emergent_intensity = data.readvar('ie') - >>> source_function = data.readvar('snu') - >>> tau1_height = data.readvar('zt1') + data.set_transition(3, 2) + emergent_intensity = data.readvar('ie') + source_function = data.readvar('snu') + tau1_height = data.readvar('zt1') Wavelength for the selected transition is saved in data.d.l, e.g.: - >>> plt.plot(data.d.l, emergent_intensity[0, 0]) + plt.plot(data.d.l, emergent_intensity[0, 0]) """ def __init__(self, inputfile="multi3d.input", directory='./', printinfo=True): @@ -263,7 +272,6 @@ def readinput(self): self.geometry.ny = self.theinput["ny"] self.geometry.nz = self.theinput["nz"] - def readpar(self): """ reads the out_par file @@ -331,7 +339,7 @@ def readpar(self): c.nu = f.read_reals(dtype=self.floattype) c.wnu = f.read_reals(dtype=self.floattype) - #line info + # line info self.line = [Line() for i in range(self.atom.nline)] for l in self.line: l.profile_type = f.read_record(dtype='S72')[0].strip() @@ -404,7 +412,7 @@ def readatmos(self): shape=s, offset=gs*5, order='F') self.atmos.nh = np.memmap(fname, dtype='float32', mode='r', order='F', shape=(nx, ny, nz, nhl), offset=gs * 6) - #self.atmos.vturb = np.memmap(fname, dtype='float32', mode='r', + # self.atmos.vturb = np.memmap(fname, dtype='float32', mode='r', # shape=s ,offset=gs*12, order='F' ) if self.printinfo: print("reading " + fname) @@ -559,6 +567,7 @@ class Multi3dAtmos: read_vturb : bool, optional If True, will read/write turbulent velocity. Default is False. """ + def __init__(self, infile, nx, ny, nz, mode='r', **kwargs): if os.path.isfile(infile) or (mode == "w+"): self.open_atmos(infile, nx, ny, nz, mode=mode, **kwargs) diff --git a/helita/sim/muram.py b/helita/sim/muram.py index df0655ed..6aba49a9 100644 --- a/helita/sim/muram.py +++ b/helita/sim/muram.py @@ -1,6 +1,13 @@ import os + import numpy as np +from . import document_vars +from .bifrost import Rhoeetab +from .load_arithmetic_quantities import * +from .load_quantities import * +from .tools import * + class MuramAtmos: """ @@ -21,111 +28,508 @@ class MuramAtmos: prim : bool, optional Set to True if moments are written instead of velocities. """ + def __init__(self, fdir='.', template=".020000", verbose=True, dtype='f4', - big_endian=False, prim=False): + sel_units='cgs', big_endian=False, prim=False, iz0=None, inttostring=(lambda x: '{0:07d}'.format(x))): + self.prim = prim self.fdir = fdir + self.verbose = verbose + self.sel_units = sel_units + self.iz0 = iz0 # endianness and data type if big_endian: self.dtype = '>' + dtype else: self.dtype = '<' + dtype + self.uni = Muram_units() self.read_header("%s/Header%s" % (fdir, template)) - self.read_atmos(fdir, template) - self.snap = 0 + #self.read_atmos(fdir, template) + # Snapshot number + self.snap = int(template[1:]) + self.filename = '' + self.inttostring = inttostring + self.siter = template + self.file_root = template + + self.transunits = False + self.lowbus = False + + self.cstagop = False # This will not allow to use cstagger from Bifrost in load + self.do_stagger = False + self.hion = False # This will not allow to use HION from Bifrost in load + tabfile = os.path.join(self.fdir, 'tabparam.in') + + if os.access(tabfile, os.R_OK): + self.rhoee = Rhoeetab(tabfile=tabfile, fdir=fdir, radtab=False) + + self.genvar(order=self.order) + + document_vars.create_vardict(self) + document_vars.set_vardocs(self) def read_header(self, headerfile): tmp = np.loadtxt(headerfile) - self.nx, self.nz, self.ny = tmp[:3].astype("i") - self.dx, self.dz, self.dy, self.time, self.dt, self.vamax = tmp[3:9] - self.x = np.arange(self.nx) * self.dx - self.y = np.arange(self.ny) * self.dy - self.z = np.arange(self.nz) * self.dz + #self.dims_orig = tmp[:3].astype("i") + dims = tmp[:3].astype("i") + deltas = tmp[3:6] + # if len(tmp) == 10: # Old version of MURaM, deltas stored in km + # self.uni.uni['l'] = 1e5 # JMS What is this for? + + self.time = tmp[6] + + layout = np.loadtxt('layout.order') + self.order = layout[0:3].astype(int) + # if len(self.order) == 0: + # self.order = np.array([0,2,1]).astype(int) + #self.order = tmp[-3:].astype(int) + # dims = [1,2,0] 0=z, + #dims = np.array((self.dims_orig[self.order[2]],self.dims_orig[self.order[0]],self.dims_orig[self.order[1]])) + #deltas = np.array((deltas[self.order[2]],deltas[self.order[0]],deltas[self.order[1]])).astype('float32') + deltas = deltas[self.order] + dims = dims[self.order] + + if self.sel_units == 'cgs': + deltas *= self.uni.uni['l'] + + self.x = np.arange(dims[0])*deltas[0] + self.y = np.arange(dims[1])*deltas[1] + self.z = np.arange(dims[2])*deltas[2] + if self.iz0 != None: + self.z = self.z - self.z[self.iz0] + self.dx, self.dy, self.dz = deltas[0], deltas[1], deltas[2] + self.nx, self.ny, self.nz = dims[0], dims[1], dims[2] + + if self.nx > 1: + self.dx1d = np.gradient(self.x) + else: + self.dx1d = np.zeros(self.nx) + + if self.ny > 1: + self.dy1d = np.gradient(self.y) + else: + self.dy1d = np.zeros(self.ny) + + if self.nz > 1: + self.dz1d = np.gradient(self.z) + else: + self.dz1d = np.zeros(self.nz) def read_atmos(self, fdir, template): ashape = (self.nx, self.nz, self.ny) file_T = "%s/eosT%s" % (fdir, template) + # When 0-th dimension is vertical, 1st is x, 2nd is y + # when 1st dimension is vertical, 0th is x. + # remember to swap names + + bfact = np.sqrt(4 * np.pi) if os.path.isfile(file_T): self.tg = np.memmap(file_T, mode="r", shape=ashape, dtype=self.dtype, - order="F").transpose((0, 2, 1)) + order="F") file_press = "%s/eosP%s" % (fdir, template) if os.path.isfile(file_press): self.pressure = np.memmap(file_press, mode="r", shape=ashape, dtype=self.dtype, - order="F").transpose((0, 2, 1)) + order="F") file_rho = "%s/result_prim_0%s" % (fdir, template) if os.path.isfile(file_rho): self.rho = np.memmap(file_rho, mode="r", shape=ashape, dtype=self.dtype, - order="F").transpose((0, 2, 1)) + order="F") file_vx = "%s/result_prim_1%s" % (fdir, template) if os.path.isfile(file_vx): self.vx = np.memmap(file_vx, mode="r", shape=ashape, dtype=self.dtype, - order="F").transpose((0, 2, 1)) + order="F") file_vz = "%s/result_prim_2%s" % (fdir, template) if os.path.isfile(file_vz): self.vz = np.memmap(file_vz, mode="r", shape=ashape, dtype=self.dtype, - order="F").transpose((0, 2, 1)) + order="F") file_vy = "%s/result_prim_3%s" % (fdir, template) if os.path.isfile(file_vy): self.vy = np.memmap(file_vy, mode="r", shape=ashape, dtype=self.dtype, - order="F").transpose((0, 2, 1)) + order="F") file_ei = "%s/result_prim_4%s" % (fdir, template) if os.path.isfile(file_ei): self.ei = np.memmap(file_ei, mode="r", shape=ashape, dtype=self.dtype, - order="F").transpose((0, 2, 1)) + order="F") file_Bx = "%s/result_prim_5%s" % (fdir, template) if os.path.isfile(file_Bx): self.bx = np.memmap(file_Bx, mode="r", shape=ashape, dtype=self.dtype, - order="F").transpose((0, 2, 1)) + order="F") + self.bx = self.bx * bfact file_Bz = "%s/result_prim_6%s" % (fdir, template) if os.path.isfile(file_Bz): self.bz = np.memmap(file_Bz, mode="r", shape=ashape, dtype=self.dtype, - order="F").transpose((0, 2, 1)) + order="F") + self.bz = self.bz * bfact file_By = "%s/result_prim_7%s" % (fdir, template) if os.path.isfile(file_By): self.by = np.memmap(file_By, mode="r", shape=ashape, dtype=self.dtype, - order="F").transpose((0, 2, 1)) + order="F") + self.by = self.by * bfact file_tau = "%s/tau%s" % (fdir, template) if os.path.isfile(file_tau): self.tau = np.memmap(file_tau, mode="r", shape=ashape, dtype=self.dtype, - order="F").transpose((0, 2, 1)) + order="F") file_Qtot = "%s/Qtot%s" % (fdir, template) if os.path.isfile(file_Qtot): self.qtot = np.memmap(file_Qtot, mode="r", shape=ashape, dtype=self.dtype, - order="F").transpose((0, 2, 1)) - bfact = np.sqrt(4 * np.pi) - self.bx = self.bx * bfact - self.by = self.by * bfact - self.bz = self.bz * bfact + order="F") + # from moments to velocities - if self.prim: - self.vx /= self.rho - self.vy /= self.rho - self.vz /= self.rho + # if self.prim: + # if hasattr(self,'rho'): + # if hasattr(self,'vx'): + # self.vx /= self.rho + # if hasattr(self,'vy'): + # self.vy /= self.rho + # if hasattr(self,'vz'): + # self.vz /= self.rho + + def read_Iout(self): + + tmp = np.fromfile(self.fdir+'I_out.'+self.siter) + + size = tmp[1:3].astype(int) + time = tmp[3] + + return tmp[4:].reshape([size[1], size[0]]).swapaxes(0, 1), size, time + + def read_slice(self, var, depth): + + tmp = np.fromfile(self.fdir+var+'_slice_'+depth+'.'+self.siter) + + nslices = tmp[0].astype(int) + size = tmp[1:3].astype(int) + time = tmp[3] + + return tmp[4:].reshape([nslices, size[1], size[0]]).swapaxes(1, 2), nslices, size, time + + def read_dem(self, path, max_bins=None): + + tmp = np.fromfile(path+'corona_emission_adj_dem_'+self.fdir+'.'+self.siter) + + bins = tmp[0].astype(int) + size = tmp[1:3].astype(int) + time = tmp[3] + lgTmin = tmp[4] + dellgT = tmp[5] + + dem = tmp[6:].reshape([bins, size[1], size[0]]).transpose(2, 1, 0) + + taxis = lgTmin+dellgT*np.arange(0, bins+1) + + X_H = 0.7 + dem = dem*X_H*0.5*(1+X_H)*3.6e19 + + if max_bins != None: + if bins > max_bins: + dem = dem[:, :, 0:max_bins] + else: + tmp = dem + dem = np.zeros([size[0], size[1], max_bins]) + dem[:, :, 0:bins] = tmp + + taxis = lgTmin+dellgT*np.arange(0, max_bins+1) + + return dem, taxis, time + + def _load_quantity(self, var, cgsunits=1.0, **kwargs): + '''helper function for get_var; actually calls load_quantities for var.''' + __tracebackhide__ = True # hide this func from error traceback stack + # look for var in self.variables + if (var == 'ne'): + print('WWW: Reading ne from Bifrost EOS', end="\r", flush=True) + + # Try to load simple quantities. + # val = load_fromfile_quantities.load_fromfile_quantities(self, var, + # save_if_composite=True, **kwargs) + # if val is not None: + # val = val * cgsunits # (vars from load_fromfile need to get hit by cgsunits.) + # Try to load "regular" quantities + # if val is None: + val = load_quantities(self, var, PLASMA_QUANT='', CYCL_RES='', + COLFRE_QUANT='', COLFRI_QUANT='', IONP_QUANT='', + EOSTAB_QUANT=['ne', 'tau'], TAU_QUANT='', DEBYE_LN_QUANT='', + CROSTAB_QUANT='', COULOMB_COL_QUANT='', AMB_QUANT='', + HALL_QUANT='', BATTERY_QUANT='', SPITZER_QUANT='', + KAPPA_QUANT='', GYROF_QUANT='', WAVE_QUANT='', + FLUX_QUANT='', CURRENT_QUANT='', COLCOU_QUANT='', + COLCOUMS_QUANT='', COLFREMX_QUANT='', **kwargs) + + # Try to load "arithmetic" quantities. + if val is None: + val = load_arithmetic_quantities(self, var, **kwargs) + + return val + + def get_var(self, var, snap=None, iix=None, iiy=None, iiz=None, layout=None, **kwargs): + ''' + Reads the variables from a snapshot (snap). + + Parameters + ---------- + var - string + Name of the variable to read. Must be Bifrost internal names. + snap - integer, optional + Snapshot number to read. By default reads the loaded snapshot; + if a different number is requested, will load that snapshot. + + Axes: + ----- + For the hgcr model: + y-axis is the vertical x and z-axes are horizontal + Newer runs could have x-axis the vertical. + + Variable list: + -------------- + result_prim_0 -- Density (g/cm^3) + result_prim_1 -- component x of the velocity (cm/s) + result_prim_2 -- component y of the velocity (cm/s), vertical in the hgcr + result_prim_3 -- component z of the velocity (cm/s) + result_prim_4 -- internal energy (erg) + result_prim_5 -- component x of the magnetic field (G/sqrt(4*pi)) + result_prim_6 -- component y of the magnetic field (G/sqrt(4*pi)) + result_prim_7 -- component z of the magnetic field (G/sqrt(4*pi)) + eosP -- Pressure (cgs) + eosT -- Temperature (K) + ''' + + if (not snap == None): + self.snap = snap + self.siter = '.'+self.inttostring(snap) + self.read_header("%s/Header%s" % (self.fdir, self.siter)) + + if var in self.varn.keys(): + varname = self.varn[var] + else: + varname = var + + if ((var in self.varn.keys()) and os.path.isfile(self.fdir+'/'+varname + self.siter)): + ashape = np.array([self.nx, self.ny, self.nz]) + + transpose_order = self.order + + if self.sel_units == 'cgs': + varu = var.replace('x', '') + varu = varu.replace('y', '') + varu = varu.replace('z', '') + if (var in self.varn.keys()) and (varu in self.uni.uni.keys()): + cgsunits = self.uni.uni[varu] + else: + cgsunits = 1.0 + else: + cgsunits = 1.0 + #orderfiles = [self.order[2],self.order[0],self.order[1]] + + # self.order = [2,0,1] + data = np.memmap(self.fdir+'/'+varname + self.siter, mode="r", + shape=tuple(ashape[self.order[self.order]]), + dtype=self.dtype, order="F") + data = data.transpose(transpose_order) + + if iix != None: + data = data[iix, :, :] + if iiy != None: + data = data[:, iiy, :] + if iiz != None: + data = data[:, :, iiz] + + self.data = data * cgsunits + + else: + # Loading quantities + cgsunits = 1.0 + # get value of variable. + self.data = self._load_quantity(var, cgsunits, **kwargs) + + # do post-processing + # self.data = self._get_var_postprocess(self.data, var=var, original_slice=original_slice) + + return self.data + + def _get_var_postprocess(self, val, var='', original_slice=[slice(None) for x in ('x', 'y', 'z')]): + '''does post-processing for get_var. + This includes: + - handle "creating documentation" or "var==''" case + - handle "don't know how to get this var" case + - reshape result as appropriate (based on iix,iiy,iiz) + returns val after the processing is complete. + ''' + # handle documentation case + if document_vars.creating_vardict(self): + return None + elif var == '': + print('Variables from snap or aux files:') + print(self.simple_vars) + if hasattr(self, 'vardict'): + self.vardocs() + return None + + # handle "don't know how to get this var" case + if val is None: + errmsg = ('get_var: do not know (yet) how to calculate quantity {}. ' + '(Got None while trying to calculate it.) ' + 'Note that simple_var available variables are: {}. ' + '\nIn addition, get_quantity can read others computed variables; ' + "see e.g. help(self.get_var) or get_var('')) for guidance.") + raise ValueError(errmsg.format(repr(var), repr(self.simple_vars))) + + # set original_slice if cstagop is enabled and we are at the outermost layer. + if self.cstagop and not self._getting_internal_var(): + self.set_domain_iiaxes(*original_slice, internal=False) + + # reshape if necessary... E.g. if var is a simple var, and iix tells to slice array. + if np.shape(val) != (self.xLength, self.yLength, self.zLength): + def isslice(x): return isinstance(x, slice) + if isslice(self.iix) and isslice(self.iiy) and isslice(self.iiz): + val = val[self.iix, self.iiy, self.iiz] # we can index all together + else: # we need to index separately due to numpy multidimensional index array rules. + val = val[self.iix, :, :] + val = val[:, self.iiy, :] + val = val[:, :, self.iiz] + + return val + + def read_var_3d(self, var, iter=None, layout=None): + + if (not iter == None): + self.siter = '.'+self.inttostring(iter) + self.read_header("%s/Header%s" % (self.fdir, self.siter)) + + tmp = np.fromfile(self.fdir+'/'+var + self.siter) + self.data = tmp.reshape([self.nx, self.ny, self.nz]) + + if layout != None: + self.data = tmp.transpose(layout) + + return self.data + + def read_vlos(self, path, max_bins=None): + + tmp = np.fromfile(path+'corona_emission_adj_vlos_'+self.fdir+'.'+self.siter) + + bins = tmp[0].astype(int) + size = tmp[1:3].astype(int) + time = tmp[3] + lgTmin = tmp[4] + dellgT = tmp[5] + + vlos = tmp[6:].reshape([bins, self.ny, self.nz]).transpose(2, 1, 0) + + taxis = lgTmin+dellgT*np.arange(0, bins+1) + + if max_bins != None: + if bins > max_bins: + vlos = vlos[:, :, 0:max_bins] + else: + tmp = vlos + vlos = np.zeros([self.nz, self.ny, max_bins]) + vlos[:, :, 0:bins] = tmp + + taxis = lgTmin+dellgT*np.arange(0, max_bins+1) + + return vlos, taxis, time + + def read_vrms(self, path, max_bins=None): + + tmp = np.fromfile(path+'corona_emission_adj_vrms_'+self.fdir+'.'+self.template) + + bins = tmp[0].astype(int) + size = tmp[1:3].astype(int) + time = tmp[3] + lgTmin = tmp[4] + dellgT = tmp[5] + + vlos = tmp[6:].reshape([bins, self.ny, self.nz]).transpose(2, 1, 0) + + taxis = lgTmin+dellgT*np.arange(0, bins+1) + + if max_bins != None: + if bins > max_bins: + vlos = vlos[:, :, 0:max_bins] + else: + tmp = vlos + vlos = np.zeros([self.nz, self.ny, max_bins]) + vlos[:, :, 0:bins] = tmp + + taxis = lgTmin+dellgT*np.arange(0, max_bins+1) + + return vlos, taxis, time + + def read_fil(self, path, max_bins=None): + + tmp = np.fromfile(path+'corona_emission_adj_fil_'+self.fdir+'.'+self.template) + bins = tmp[0].astype(int) + size = tmp[1:3].astype(int) + time = tmp[3] + lgTmin = tmp[4] + dellgT = tmp[5] + + vlos = tmp[6:].reshape([bins, size[1], size[0]]).transpose(2, 1, 0) + + taxis = lgTmin+dellgT*np.arange(0, bins+1) + + if max_bins != None: + if bins > max_bins: + vlos = vlos[:, :, 0:max_bins] + else: + tmp = vlos + vlos = np.zeros([size[0], size[1], max_bins]) + vlos[:, :, 0:bins] = tmp + + taxis = lgTmin+dellgT*np.arange(0, max_bins+1) + + return vlos, taxis, time + + def genvar(self, order=[0, 1, 2]): + ''' + Dictionary of original variables which will allow to convert to cgs. + ''' + self.varn = {} + self.varn['rho'] = 'result_prim_0' + self.varn['totr'] = 'result_prim_0' + self.varn['tg'] = 'eosT' + self.varn['pg'] = 'eosP' + if os.path.isfile(self.fdir+'/eosne' + self.siter): + print('Has ne files') + self.varn['ne'] = 'eosne' + + unames = np.array(['result_prim_1', 'result_prim_2', 'result_prim_3']) + unames = unames[order] + self.varn['ux'] = unames[0] + self.varn['uy'] = unames[1] + self.varn['uz'] = unames[2] + self.varn['e'] = 'result_prim_4' + unames = np.array(['result_prim_5', 'result_prim_6', 'result_prim_7']) + unames = unames[order] + self.varn['bx'] = unames[0] + self.varn['by'] = unames[1] + self.varn['bz'] = unames[2] def write_rh15d(self, outfile, desc=None, append=True, writeB=False, sx=slice(None), sy=slice(None), sz=slice(None), wght_per_h=1.4271): ''' Writes RH 1.5D NetCDF snapshot ''' - from . import rh15d import scipy.constants as ct - from .bifrost import Rhoeetab + + from . import rh15d + # unit conversion to SI ul = 1.e-2 # to metres ur = 1.e3 # from g/cm^3 to kg/m^3 ut = 1. # to seconds - uv = ul / ut + ul / ut ub = 1.e-4 # to Tesla ue = 1. # to erg/g # slicing and unit conversion (default slice of None selects all) @@ -165,3 +569,124 @@ def write_rh15d(self, outfile, desc=None, append=True, writeB=False, rh15d.make_xarray_atmos(outfile, temp, vz, z, nH=nh, x=x, y=y, vx=vx, vy=vy, rho=rho, append=append, Bx=Bx, By=By, Bz=Bz, desc=desc, snap=self.snap) + + def trans2comm(self, varname, snap=None): + ''' + Transform the domain into a "common" format. All arrays will be 3D. The 3rd axis + is: + + - for 3D atmospheres: the vertical axis + - for loop type atmospheres: along the loop + - for 1D atmosphere: the unique dimension is the 3rd axis. + At least one extra dimension needs to be created artifically. + + All of them should obey the right hand rule + + In all of them, the vectors (velocity, magnetic field etc) away from the Sun. + + If applies, z=0 near the photosphere. + + Units: everything is in cgs. + + If an array is reverse, do ndarray.copy(), otherwise pytorch will complain. + + ''' + + self.sel_units = 'cgs' + + self.trans2commaxes + + return self.get_var(varname, snap=snap) + + def trans2commaxes(self): + + if self.transunits == False: + self.transunits = True + + def trans2noncommaxes(self): + + if self.transunits == True: + self.transunits = False + + def trasn2fits(self, varname, snap=None, instrument='MURaM', + name='ar098192', origin='HGCR ', z_tau51m=None, iz0=None): + ''' + converts the original data into fits files following Bifrost publicly available + format, i.e., SI, vertical axis, z and top corona is positive and last index. + ''' + + if varname[-1] == 'x': + varname = varname.replace('x', 'z') + elif varname[-1] == 'z': + varname = varname.replace('z', 'x') + + self.datafits = self.trans2comm(varname, snap=snap) + + varu = varname.replace('x', '') + varu = varu.replace('y', '') + varu = varu.replace('z', '') + varu = varu.replace('lg', '') + if (varname in self.varn.keys()) and (varu in self.uni.uni.keys()): + siunits = self.uni.unisi[varu]/self.uni.uni[varu] + else: + siunits = 1.0 + + units_title(self) + + if varu == 'ne': + self.fitsunits = 'm^(-3)' + siunits = 1e6 + else: + self.fitsunits = self.unisi_title[varu] + + if varname[:2] == 'lg': + self.datafits = self.datafits + np.log10(siunits) # cgs -> SI + else: + self.datafits = self.datafits * siunits + + self.xfits = self.x / 1e8 + self.yfits = self.y / 1e8 + self.zfits = self.z / 1e8 + + if iz0 != None: + self.zfits -= self.z[iz0]/1e8 + + if z_tau51m == None: + tau51 = self.trans2comm('tau', snap=snap) + z_tau51 = np.zeros((self.nx, self.ny)) + for ix in range(0, self.nx): + for iy in range(0, self.ny): + z_tau51[ix, iy] = self.zfits[np.argmin(np.abs(tau51[ix, iy, :]-1.0))] + + z_tau51m = np.mean(z_tau51) + + print(z_tau51m) + + self.dxfits = self.dx / 1e8 + self.dyfits = self.dy / 1e8 + self.dzfits = self.dz / 1e8 + + writefits(self, varname, instrument=instrument, name=name, + origin=origin, z_tau51m=z_tau51m) + + +class Muram_units(object): + + def __init__(self, verbose=False): + ''' + Units and constants in cgs + ''' + self.uni = {} + self.verbose = verbose + self.uni['tg'] = 1.0 # K + self.uni['l'] = 1.0 # to cm + self.uni['rho'] = 1.0 # g cm^-3 + self.uni['u'] = 1.0 # cm/s + self.uni['b'] = np.sqrt(4.0*np.pi) # convert to Gauss + self.uni['t'] = 1.0 # seconds + self.uni['j'] = 1.0 # current density + + # Units and constants in SI + convertcsgsi(self) + + globalvars(self) diff --git a/helita/sim/preft.py b/helita/sim/preft.py new file mode 100644 index 00000000..66c527c4 --- /dev/null +++ b/helita/sim/preft.py @@ -0,0 +1,465 @@ +from math import ceil + +import numpy as np +from scipy.io import readsav +from scipy.sparse import coo_matrix + +from .load_arithmetic_quantities import * +from .load_noeos_quantities import * +from .load_quantities import * +from .tools import * + + +class preft(object): + """ + Class to read preft atmosphere + + Parameters + ---------- + fdir : str, optional + Directory with snapshots. + rootname : str + rootname of the file (wihtout params or vars). + verbose : bool, optional + If True, will print more information. + it : integer + snapshot number + """ + + def __init__(self, filename, snap, *args, fdir='.', + sel_units='cgs', verbose=True, **kwargs): + + self.filename = filename + self.fdir = fdir + + a = readsav(filename, python_dict=True) + #self.la = a['la'] + self.snap = snap + la = a['la'] + l = a['la'][snap] + + #self.zmax = (a['la'][0][3][:,2]).max() + + self.extent = (0, (a['la'][0][3][:, 0]).max()-(a['la'][0][3][:, 0]).min(), + 0, (a['la'][0][3][:, 1]).max()-(a['la'][0][3][:, 1]).min(), + 0, (a['la'][0][3][:, 2]).max()-(a['la'][0][3][:, 2]).min()) + + self.obj = {'time': np.array([lal[0] for lal in la]), + 's': np.stack([lal[1] for lal in la]), + 'ux': np.stack([lal[2][:, 0] for lal in la]), + 'uy': np.stack([lal[2][:, 1] for lal in la]), + 'uz': np.stack([lal[2][:, 2] for lal in la]), + 'x': np.stack([lal[3][:, 0] for lal in la]), + 'y': np.stack([lal[3][:, 1] for lal in la]), + 'z': np.stack([lal[3][:, 2] for lal in la]), + 'rho': np.stack([lal[4] for lal in la]), + 'p': np.stack([lal[5] for lal in la]), + 'tg': np.stack([lal[6] for lal in la]), + 'ne': np.stack([lal[7] for lal in la]), + 'b': np.stack([lal[8] for lal in la]), + 'units': l[9]} + + # = {'time':l[0],'s':l[1],'ux':l[2][:,0],'uy':l[2][:,1],'uz':l[2][:,2], + # 'x':l[3][:,0], 'y':l[3][:,1], 'z':l[3][:,2], + # 'rho':l[4],'p':l[5],'tg':l[6],'ne':l[7],'b':l[8], 'units':l[9]} + + self.x = l[3][:, 0].copy() + self.x -= self.x.min() # np.array([0.0]) + self.y = l[3][:, 1].copy() + self.y -= self.y.min() # np.array([0.0]) + self.z = l[3][:, 2].copy() # np.array([0.0]) + #self.z = np.flip(self.rdobj.__getattr__('zm')) + self.sel_units = sel_units + self.verbose = verbose + #self.snap = None + self.uni = PREFT_units() + + #self.dx = np.array([1.0]) + #self.dy = np.array([1.0]) + #self.dz = np.copy(self.z) + self.nt = [1] # np.shape(self.z)[0] + self.nz = np.shape(self.z)[0] + # for it in range(0,self.nt): + #self.dz[it,:] = np.gradient(self.z[it,:]) + self.dx = np.gradient(self.x) + self.dy = np.gradient(self.y) + self.dz = np.gradient(self.z) + + self.dz1d = self.dz + self.dx1d = np.array([1.0]) + self.dy1d = np.array([1.0]) + + self.nx = np.shape(self.x) + self.ny = np.shape(self.y) + + #self.time = self.rdobj.__getattr__('time') + + self.transunits = False + + self.cstagop = False # This will not allow to use cstagger from Bifrost in load + self.hion = False # This will not allow to use HION from Bifrost in load + + self.genvar() + + def get_var(self, var, snap=None, iix=None, iiy=None, iiz=None, layout=None): + ''' + Reads the variables from a snapshot (it). + + Parameters + ---------- + var - string + Name of the variable to read. + + cgs- logic + converts into cgs units. + Axes: + ----- + z-axis is along the loop + x and y axes have only one grid. + + Information about radynpy library: + -------------- + ''' + + if snap == None: + snap = self.snap + if var in self.varn.keys(): + varname = self.varn[var] + else: + varname = var + + # print(var,varname,'try') + # print(self.obj.keys()) + try: + if self.sel_units == 'cgs': + varu = var.replace('x', '') + varu = varu.replace('y', '') + varu = varu.replace('z', '') + + if (var in self.varn.keys()) and (varu in self.uni.uni.keys()): + cgsunits = self.uni.uni[varu] + else: + cgsunits = 1.0 + else: + cgsunits = 1.0 + + # print(varname) + self.data = self.obj[varname][snap]*cgsunits + + #self.rdobj.__getattr__(varname) * cgsunits + except: + + self.data = load_quantities(self, var, PLASMA_QUANT='', CYCL_RES='', + COLFRE_QUANT='', COLFRI_QUANT='', IONP_QUANT='', + EOSTAB_QUANT='', TAU_QUANT='', DEBYE_LN_QUANT='', + CROSTAB_QUANT='', COULOMB_COL_QUANT='', AMB_QUANT='', + HALL_QUANT='', BATTERY_QUANT='', SPITZER_QUANT='', + KAPPA_QUANT='', GYROF_QUANT='', WAVE_QUANT='', + FLUX_QUANT='', CURRENT_QUANT='', COLCOU_QUANT='', + COLCOUMS_QUANT='', COLFREMX_QUANT='') + + if np.shape(self.data) == (): + if self.verbose: + print('Loading arithmetic variable', end="\r", flush=True) + self.data = load_arithmetic_quantities(self, var) + + if var == '': + + print(help(self.get_var)) + print('VARIABLES USING CGS OR GENERIC NOMENCLATURE') + for ii in self.varn: + print('use ', ii, ' for ', self.varn[ii]) + print(self.description['ALL']) + #print('\n radyn obj is self.rdobj, self.rdobj.var_info is as follows') + # print(self.rdobj.var_info) + + return None + + # self.trans2noncommaxes() + + return self.data + + def genvar(self): + ''' + Dictionary of original variables which will allow to convert to cgs. + ''' + self.varn = {} + self.varn['rho'] = 'rho' + self.varn['tg'] = 'tg' + self.varn['ux'] = 'ux' + self.varn['uy'] = 'uy' + self.varn['uz'] = 'uz' + self.varn['bx'] = 'bx' + self.varn['by'] = 'by' + self.varn['bz'] = 'bz' + self.varn['ne'] = 'ne' + + def trans2comm(self, varname, snap=0, **kwargs): + ''' + Transform the domain into a "common" format. All arrays will be 3D. The 3rd axis + is: + + - for 3D atmospheres: the vertical axis + - for loop type atmospheres: along the loop + - for 1D atmosphere: the unique dimension is the 3rd axis. + At least one extra dimension needs to be created artifically. + + All of them should obey the right hand rule + + In all of them, the vectors (velocity, magnetic field etc) away from the Sun. + + If applies, z=0 near the photosphere. + + Units: everything is in cgs. + + If an array is reverse, do ndarray.copy(), otherwise pytorch will complain. + + ''' + + self.sel_units = 'cgs' + + if not hasattr(self, 'trans_dx'): + self.trans_dx = 3e7 + if not hasattr(self, 'trans_dy'): + self.trans_dy = 3e7 + if not hasattr(self, 'trans_dz'): + self.trans_dz = 3e7 + + for key, value in kwargs.items(): + if key == 'dx': + if hasattr(self, 'trans_dx'): + if value != self.trans_dx: + self.transunits = False + if key == 'dz': + if hasattr(self, 'trans_dz'): + if value != self.trans_dz: + self.transunits = False + + if self.snap != snap: + self.snap = snap + self.transunits = False + + #var = self.get_var(varname) + + # What does this do? + #self.trans2commaxes(var, **kwargs) + + if not hasattr(self, 'trans_loop_width'): + self.trans_loop_width = 1.0 + if not hasattr(self, 'trans_sparse'): + self.trans_sparse = True + + for key, value in kwargs.items(): + if key == 'loop_width': + self.trans_loop_width = value + if key == 'unscale': + pass + if key == 'sparse': + self.trans_sparse = value + + # GSK -- smax was changed 12th March 2021. See comment in trans2commaxes + ##smax = s.max() + + shape = (ceil(self.extent[1]/self.trans_dx), + ceil(self.extent[3]/self.trans_dy), + ceil(self.extent[5]/self.trans_dz)+1) + + # In the PREFT model in the corona, successive grid points may be several pixels away from each other. + # In this case, need to refine loop. + #do_expscale = False + # for key, value in kwargs.items(): + # if key == 'unscale': + # do_expscale = value + + # if self.gridfactor > 1: + # if do_expscale: + # ss, var= refine(s, np.log(var),factor=self.gridfactor, unscale=np.exp) + # else: + # ss, var= refine(s, var,factor=self.gridfactor) + # else: + # ss = s + #var_copy = var.copy() + + x_loop = self.obj['x'][self.snap] + y_loop = self.obj['y'][self.snap] + z_loop = self.obj['z'][self.snap] + s_loop = self.obj['s'][self.snap] + x_loop -= x_loop.min() + y_loop -= y_loop.min() + z_loop -= z_loop.min() + + var = self.get_var(varname, snap=self.snap) + + print(s_loop, var) + x_loop, y_loop, z_loop, var = self.trans2commaxes(x_loop, y_loop, z_loop, var, s_loop, **kwargs) + + # Arc lengths (in radians) + dA = np.ones(var.shape)*self.trans_dx*self.trans_dy*self.trans_dz + xind = np.floor(x_loop/self.trans_dx).astype(np.int64) + yind = np.floor(y_loop/self.trans_dy).astype(np.int64) + zind = np.clip(np.floor(z_loop/self.trans_dz).astype(np.int64), 0, shape[2]-1) + + # Make copies of loops with an x-offset + for xoffset in range(-shape[0], shape[0], 10): + this_snap = self.snap + xoffset + shape[0] + x_loop = self.obj['x'][this_snap] + y_loop = self.obj['y'][this_snap] + z_loop = self.obj['z'][this_snap] + s_loop = self.obj['s'][this_snap] + x_loop -= x_loop.min() + y_loop -= y_loop.min() + z_loop -= z_loop.min() + this_var = self.get_var(varname, snap=this_snap) + print(this_snap, s_loop.shape) + + x_loop, y_loop, z_loop, this_var = self.trans2commaxes(x_loop, y_loop, z_loop, s_loop, this_var, **kwargs) + + xind = np.concatenate((xind, np.floor((x_loop+xoffset*self.trans_dx)/self.trans_dx).astype(np.int64))) + yind = np.concatenate((yind, np.floor(y_loop/self.trans_dy).astype(np.int64))) + zind = np.concatenate((zind, np.clip(np.floor(z_loop/self.trans_dz).astype(np.int64), 0, shape[2]-1))) + + dA = np.concatenate((dA, np.ones(var.shape)*self.trans_dx*self.trans_dy*self.trans_dz)) + var = np.concatenate((var, this_var)) + + good = (xind >= 0)*(xind < shape[0]) + good *= (yind >= 0)*(yind < shape[1]) + good *= (zind >= 0)*(zind < shape[2]) + xind = xind[good] + yind = yind[good] + zind = zind[good] + dA = dA[good] + var = var[good] + + # Define matrix with column coordinate corresponding to point along loop + # and row coordinate corresponding to position in Cartesian grid + col = np.arange(len(xind), dtype=np.int64) + row = (xind*shape[1]+yind)*shape[2]+zind + + if self.trans_sparse: + M = coo_matrix((dA/(self.trans_dx*self.trans_dy*self.trans_dz), + (row, col)), shape=(shape[0]*shape[1]*shape[2], len(xind)), dtype=np.float) + M = M.tocsr() + else: + M = np.zeros(shape=(shape[0]*shape[1]*shape[2], len(ss)), dtype=np.float) + M[row, col] = dA/(self.dx1d*self.dz1d.max()) # weighting by area of arc segment + + # The final quantity at each Cartesian grid cell is an volume-weighted + # average of values from loop passing through this grid cell + # This arrays are not actually used for VDEM extraction + var = (M@var).reshape(shape) + + self.x = np.linspace(self.x_loop.min(), self.x_loop.max(), np.shape(var)[0]) + self.y = np.linspace(self.y_loop.min(), self.y_loop.max(), np.shape(var)[1]) + self.z = np.linspace(self.z_loop.min(), self.z_loop.max(), np.shape(var)[2]) + + self.dx1d = np.gradient(self.x) + self.dy1d = np.gradient(self.y) + self.dz1d = np.gradient(self.z) + + return var + + def trans2commaxes(self, x, y, z, s, var, **kwargs): + + if self.transunits == False: + + if not hasattr(self, 'trans_dx'): + self.trans_dx = 3e7 + if not hasattr(self, 'trans_dy'): + self.trans_dy = 3e7 + if not hasattr(self, 'trans_dz'): + self.trans_dz = 3e7 + + for key, value in kwargs.items(): + if key == 'dx': + self.trans_dx = value + if key == 'dy': + self.trans_dy = value + if key == 'dz': + self.trans_dz = value + + # Semicircular loop + # s = self.obj['s'] #np.copy(self.zorig) + #good = (s >=0.0) + #s = s[good] + + #x = self.x + #y = self.y + #z = self.z + + #shape = (ceil(x.max()/self.trans_dx),ceil(y.max()/self.trans_dy), ceil(self.zmax/self.trans_dz)) + + # In the RADYN model in the corona, successive grid points may be several pixels away from each other. + # In this case, need to refine loop. + maxdl = np.sqrt((z[1:]-z[0:-1])**2 + (x[1:]-x[0:-1])**2 + (y[1:]-y[0:-1])**2).max() + gridfactor = ceil(2*maxdl/np.min([self.trans_dx, self.trans_dy, self.trans_dz])) + + do_expscale = False + for key, value in kwargs.items(): + if key == 'unscale': + do_expscale = value + + if gridfactor > 1: + print(s, x) + ss, x_loop = refine(s, x, factor=gridfactor) + ss, y_loop = refine(s, y, factor=gridfactor) + ss, z_loop = refine(s, z, factor=gridfactor) + if do_expscale: + ss, var_loop = refine(s, np.log(var), factor=gridfactor, unscale=np.exp) + else: + ss, var_loop = refine(s, var, factor=gridfactor) + else: + x_loop = x.copy() + y_loop = y.copy() + z_loop = z.copy() + var_loop = var.copy() + + self.dx1d_loop = np.gradient(x_loop) + self.dy1d_loop = np.gradient(y_loop) + self.dz1d_loop = np.gradient(z_loop) + + self.transunits = True + return x_loop, y_loop, z_loop, var_loop + + def trans2noncommaxes(self): + + if self.transunits == True: + self.x = np.array([0.0]) + self.y = np.array([0.0]) + self.z = self.rdobj.__getattr__('zm') + + self.dx = np.array([1.0]) + self.dy = np.array([1.0]) + self.dz = np.copy(self.z) + self.nz = np.shape(self.z)[1] + for it in range(0, self.nt): + self.dz[it, :] = np.gradient(self.z[it, :]) + self.dz1d = self.dz + self.dx1d = np.array([1.0]) + self.dy1d = np.array([1.0]) + + self.nx = np.shape(self.x) + self.ny = np.shape(self.y) + self.transunits = False + + +class PREFT_units(object): + + def __init__(self, verbose=False): + ''' + Units and constants in cgs + ''' + self.uni = {} + self.verbose = verbose + self.uni['tg'] = 1.0 + self.uni['l'] = 1.0 + self.uni['n'] = 1.0 + self.uni['rho'] = 1.0 + self.uni['u'] = 1.0 + self.uni['b'] = 1.0 + self.uni['t'] = 1.0 # seconds + self.uni['j'] = 1.0 + + # Units and constants in SI + convertcsgsi(self) + + globalvars(self) diff --git a/helita/sim/pypluto.py b/helita/sim/pypluto.py new file mode 100644 index 00000000..23dc5fbc --- /dev/null +++ b/helita/sim/pypluto.py @@ -0,0 +1,2206 @@ +# -*- coding: utf-8 -*- +import os +import sys +import array + +import numpy as np +import scipy.constants as const +import scipy.interpolate +import scipy.ndimage +from scipy.interpolate import UnivariateSpline +from scipy.ndimage import rotate + +from . import document_vars +from .load_arithmetic_quantities import * +from .load_noeos_quantities import * +from .load_quantities import * +from .tools import * + +####### Check for h5py to Read AMR data ###### +try: + import h5py as h5 + hasH5 = True +except ImportError: + hasH5 = False + + +def curdir(): + """ Get the current working directory. + """ + curdir = os.getcwd()+'/' + return curdir + + +def get_nstepstr(ns): + """ Convert the float input *ns* into a string that would match the data file name. + **Inputs**: + ns -- Integer number that represents the time step number. E.g., The ns for data.0001.dbl is 1.\n + **Outputs**: + Returns the string that would be used to complete the data file name. E.g., + for data.0001.dbl, ns = 1 and pyPLUTO.get_nstepstr(1) returns '0001' + """ + nstepstr = str(ns) + while len(nstepstr) < 4: + nstepstr = '0'+nstepstr + return nstepstr + + +def nlast_info(w_dir=None, datatype=None): + """ Prints the information of the last step of the simulation as obtained from out files + + **Inputs**: + + w_dir -- path to the directory which has the dbl.out(or flt.out) and the data\n + datatype -- If the data is of 'float' type then datatype = 'float' else by default the datatype is set to 'double'. + + **Outputs**: + + This function returns a dictionary with following keywords - \n + nlast -- The ns for the last file saved.\n + time -- The simulation time for the last file saved.\n + dt -- The time step dt for the last file. \n + Nstep -- The Nstep value for the last file saved. + + + **Usage**: + + In case the data is 'float'. + + ``wdir = /path/to/data/directory``\n + ``import pyPLUTO as pp``\n + ``A = pp.nlast_info(w_dir=wdir,datatype='float')`` + """ + if w_dir is None: + w_dir = curdir() + if datatype == 'float': + fname_v = w_dir+"flt.out" + elif datatype == 'vtk': + fname_v = w_dir+"vtk.out" + else: + fname_v = w_dir+"dbl.out" + last_line = file(fname_v, "r").readlines()[-1].split() + nlast = int(last_line[0]) + SimTime = float(last_line[1]) + Dt = float(last_line[2]) + Nstep = int(last_line[3]) + + print("------------TIME INFORMATION--------------") + print('nlast =', nlast) + print('time =', SimTime) + print('dt =', Dt) + print('Nstep =', Nstep) + print("-------------------------------------------") + + return {'nlast': nlast, 'time': SimTime, 'dt': Dt, 'Nstep': Nstep} + + +class pload(object): + def __init__(self, ns, w_dir=None, datatype=None, level=0, x1range=None, x2range=None, x3range=None): + """Loads the data. + + **Inputs**: + + ns -- Step Number of the data file\n + w_dir -- path to the directory which has the data files\n + datatype -- Datatype (default = 'double') + + **Outputs**: + + pyPLUTO pload object whose keys are arrays of data values. + + """ + self.NStep = ns + self.Dt = 0.0 + + self.n1 = 0 + self.n2 = 0 + self.n3 = 0 + + self.x1 = [] + self.x2 = [] + self.x3 = [] + self.dx1 = [] + self.dx2 = [] + self.dx3 = [] + + self.x1range = x1range + self.x2range = x2range + self.x3range = x3range + + self.NStepStr = str(self.NStep) + while len(self.NStepStr) < 4: + self.NStepStr = '0'+self.NStepStr + + if datatype is None: + datatype = "double" + self.datatype = datatype + + if ((not hasH5) and (datatype == 'hdf5')): + print('To read AMR hdf5 files with python') + print('Please install h5py (Python HDF5 Reader)') + return + + self.level = level + + if w_dir is None: + w_dir = os.getcwd() + '/' + self.wdir = w_dir + + Data_dictionary = self.ReadDataFile(self.NStepStr) + for keys in Data_dictionary: + object.__setattr__(self, keys, Data_dictionary.get(keys)) + + def ReadTimeInfo(self, timefile): + """ Read time info from the outfiles. + + **Inputs**: + + timefile -- name of the out file which has timing information. + + """ + + if (self.datatype == 'hdf5'): + fh5 = h5.File(timefile, 'r') + self.SimTime = fh5.attrs.get('time') + # self.Dt = 1.e-2 # Should be erased later given the level in AMR + fh5.close() + else: + ns = self.NStep + f_var = open(timefile, "r") + tlist = [] + for line in f_var.readlines(): + tlist.append(line.split()) + self.SimTime = float(tlist[ns][1]) + self.Dt = float(tlist[ns][2]) + + def ReadVarFile(self, varfile): + """ Read variable names from the outfiles. + + **Inputs**: + + varfile -- name of the out file which has variable information. + + """ + if (self.datatype == 'hdf5'): + fh5 = h5.File(varfile, 'r') + self.filetype = 'single_file' + self.endianess = '>' # not used with AMR, kept for consistency + self.vars = [] + for iv in range(fh5.attrs.get('num_components')): + self.vars.append(fh5.attrs.get('component_'+str(iv))) + fh5.close() + else: + vfp = open(varfile, "r") + varinfo = vfp.readline().split() + self.filetype = varinfo[4] + self.endianess = varinfo[5] + self.vars = varinfo[6:] + vfp.close() + + def ReadGridFile(self, gridfile): + """ Read grid values from the grid.out file. + + **Inputs**: + + gridfile -- name of the grid.out file which has information about the grid. + + """ + xL = [] + xR = [] + nmax = [] + gfp = open(gridfile, "r") + for i in gfp.readlines(): + if len(i.split()) == 1: + try: + int(i.split()[0]) + nmax.append(int(i.split()[0])) + except: + pass + + if len(i.split()) == 3: + try: + int(i.split()[0]) + xL.append(float(i.split()[1])) + xR.append(float(i.split()[2])) + except: + if (i.split()[1] == 'GEOMETRY:'): + self.geometry = i.split()[2] + + self.n1, self.n2, self.n3 = nmax + n1 = self.n1 + n1p2 = self.n1 + self.n2 + n1p2p3 = self.n1 + self.n2 + self.n3 + self.x1 = np.asarray([0.5*(xL[i]+xR[i]) for i in range(n1)]) + self.dx1 = np.asarray([(xR[i]-xL[i]) for i in range(n1)]) + self.x2 = np.asarray([0.5*(xL[i]+xR[i]) for i in range(n1, n1p2)]) + self.dx2 = np.asarray([(xR[i]-xL[i]) for i in range(n1, n1p2)]) + self.x3 = np.asarray([0.5*(xL[i]+xR[i]) for i in range(n1p2, n1p2p3)]) + self.dx3 = np.asarray([(xR[i]-xL[i]) for i in range(n1p2, n1p2p3)]) + + # Stores the total number of points in '_tot' variable in case only + # a portion of the domain is loaded. Redefine the x and dx arrays + # to match the requested ranges + self.n1_tot = self.n1 + self.n2_tot = self.n2 + self.n3_tot = self.n3 + if (self.x1range != None): + self.n1_tot = self.n1 + self.irange = range(abs(self.x1-self.x1range[0]).argmin(), abs(self.x1-self.x1range[1]).argmin()+1) + self.n1 = len(self.irange) + self.x1 = self.x1[self.irange] + self.dx1 = self.dx1[self.irange] + else: + self.irange = range(self.n1) + if (self.x2range != None): + self.n2_tot = self.n2 + self.jrange = range(abs(self.x2-self.x2range[0]).argmin(), abs(self.x2-self.x2range[1]).argmin()+1) + self.n2 = len(self.jrange) + self.x2 = self.x2[self.jrange] + self.dx2 = self.dx2[self.jrange] + else: + self.jrange = range(self.n2) + if (self.x3range != None): + self.n3_tot = self.n3 + self.krange = range(abs(self.x3-self.x3range[0]).argmin(), abs(self.x3-self.x3range[1]).argmin()+1) + self.n3 = len(self.krange) + self.x3 = self.x3[self.krange] + self.dx3 = self.dx3[self.krange] + else: + self.krange = range(self.n3) + self.Slice = (self.x1range != None) or (self.x2range != None) or (self.x3range != None) + + # Create the xr arrays containing the edges positions + # Useful for pcolormesh which should use those + self.x1r = np.zeros(len(self.x1)+1) + self.x1r[1:] = self.x1 + self.dx1/2.0 + self.x1r[0] = self.x1r[1]-self.dx1[0] + self.x2r = np.zeros(len(self.x2)+1) + self.x2r[1:] = self.x2 + self.dx2/2.0 + self.x2r[0] = self.x2r[1]-self.dx2[0] + self.x3r = np.zeros(len(self.x3)+1) + self.x3r[1:] = self.x3 + self.dx3/2.0 + self.x3r[0] = self.x3r[1]-self.dx3[0] + + prodn = self.n1*self.n2*self.n3 + if prodn == self.n1: + self.nshp = (self.n1) + elif prodn == self.n1*self.n2: + self.nshp = (self.n2, self.n1) + else: + self.nshp = (self.n3, self.n2, self.n1) + + def DataScanVTK(self, fp, n1, n2, n3, endian, dtype): + """ Scans the VTK data files. + + **Inputs**: + + fp -- Data file pointer\n + n1 -- No. of points in X1 direction\n + n2 -- No. of points in X2 direction\n + n3 -- No. of points in X3 direction\n + endian -- Endianess of the data\n + dtype -- datatype + + **Output**: + + Dictionary consisting of variable names as keys and its values. + + """ + ks = [] + vtkvar = [] + while True: + l = fp.readline() + try: + l.split()[0] + except IndexError: + pass + else: + if l.split()[0] == 'SCALARS': + ks.append(l.split()[1]) + elif l.split()[0] == 'LOOKUP_TABLE': + A = array.array(dtype) + fmt = endian+str(n1*n2*n3)+dtype + nb = np.dtype(fmt).itemsize + A.fromstring(fp.read(nb)) + if (self.Slice): + darr = np.zeros((n1*n2*n3)) + indxx = np.sort([n3_tot*n2_tot*k + j*n2_tot + i for i in self.irange for j in self.jrange for k in self.krange]) + if (sys.byteorder != self.endianess): + A.byteswap() + for ii, iii in enumerate(indxx): + darr[ii] = A[iii] + vtkvar_buf = [darr] + else: + vtkvar_buf = np.frombuffer(A, dtype=np.dtype(fmt)) + vtkvar.append(np.reshape(vtkvar_buf, self.nshp).transpose()) + else: + pass + if l == '': + break + + vtkvardict = dict(zip(ks, vtkvar)) + return vtkvardict + + def DataScanHDF5(self, fp, myvars, ilev): + """ Scans the Chombo HDF5 data files for AMR in PLUTO. + + **Inputs**: + + fp -- Data file pointer\n + myvars -- Names of the variables to read\n + ilev -- required AMR level + + **Output**: + + Dictionary consisting of variable names as keys and its values. + + **Note**: + + Due to the particularity of AMR, the grid arrays loaded in ReadGridFile are overwritten here. + + """ + # Read the grid information + dim = fp['Chombo_global'].attrs.get('SpaceDim') + nlev = fp.attrs.get('num_levels') + il = min(nlev-1, ilev) + lev = [] + for i in range(nlev): + lev.append('level_'+str(i)) + freb = np.zeros(nlev, dtype='int') + for i in range(il+1)[::-1]: + fl = fp[lev[i]] + if (i == il): + pdom = fl.attrs.get('prob_domain') + dx = fl.attrs.get('dx') + dt = fl.attrs.get('dt') + ystr = 1. + zstr = 1. + logr = 0 + try: + fl.attrs.get('geometry') + logr = fl.attrs.get('logr') + if (dim == 2): + ystr = fl.attrs.get('g_x2stretch') + elif (dim == 3): + zstr = fl.attrs.get('g_x3stretch') + except: + print('Old HDF5 file, not reading stretch and logr factors') + freb[i] = 1 + x1b = fl.attrs.get('domBeg1') + if (dim == 1): + x2b = 0 + else: + x2b = fl.attrs.get('domBeg2') + if (dim == 1 or dim == 2): + x3b = 0 + else: + x3b = fl.attrs.get('domBeg3') + jbeg = 0 + jend = 0 + ny = 1 + kbeg = 0 + kend = 0 + nz = 1 + if (dim == 1): + ibeg = pdom[0] + iend = pdom[1] + nx = iend-ibeg+1 + elif (dim == 2): + ibeg = pdom[0] + iend = pdom[2] + nx = iend-ibeg+1 + jbeg = pdom[1] + jend = pdom[3] + ny = jend-jbeg+1 + elif (dim == 3): + ibeg = pdom[0] + iend = pdom[3] + nx = iend-ibeg+1 + jbeg = pdom[1] + jend = pdom[4] + ny = jend-jbeg+1 + kbeg = pdom[2] + kend = pdom[5] + nz = kend-kbeg+1 + else: + rat = fl.attrs.get('ref_ratio') + freb[i] = rat*freb[i+1] + + dx0 = dx*freb[0] + + # Allow to load only a portion of the domain + if (self.x1range != None): + if logr == 0: + self.x1range = self.x1range-x1b + else: + self.x1range = [log(self.x1range[0]/x1b), log(self.x1range[1]/x1b)] + ibeg0 = min(self.x1range)/dx0 + iend0 = max(self.x1range)/dx0 + ibeg = max([ibeg, int(ibeg0*freb[0])]) + iend = min([iend, int(iend0*freb[0]-1)]) + nx = iend-ibeg+1 + if (self.x2range != None): + self.x2range = (self.x2range-x2b)/ystr + jbeg0 = min(self.x2range)/dx0 + jend0 = max(self.x2range)/dx0 + jbeg = max([jbeg, int(jbeg0*freb[0])]) + jend = min([jend, int(jend0*freb[0]-1)]) + ny = jend-jbeg+1 + if (self.x3range != None): + self.x3range = (self.x3range-x3b)/zstr + kbeg0 = min(self.x3range)/dx0 + kend0 = max(self.x3range)/dx0 + kbeg = max([kbeg, int(kbeg0*freb[0])]) + kend = min([kend, int(kend0*freb[0]-1)]) + nz = kend-kbeg+1 + + # Create uniform grids at the required level + if logr == 0: + x1 = x1b + (ibeg+np.array(range(nx))+0.5)*dx + else: + x1 = x1b*(exp((ibeg+np.array(range(nx))+1)*dx)+exp((ibeg+np.array(range(nx)))*dx))*0.5 + + x2 = x2b + (jbeg+np.array(range(ny))+0.5)*dx*ystr + x3 = x3b + (kbeg+np.array(range(nz))+0.5)*dx*zstr + if logr == 0: + dx1 = np.ones(nx)*dx + else: + dx1 = x1b*(exp((ibeg+np.array(range(nx))+1)*dx)-exp((ibeg+np.array(range(nx)))*dx)) + dx2 = np.ones(ny)*dx*ystr + dx3 = np.ones(nz)*dx*zstr + + # Create the xr arrays containing the edges positions + # Useful for pcolormesh which should use those + x1r = np.zeros(len(x1)+1) + x1r[1:] = x1 + dx1/2.0 + x1r[0] = x1r[1]-dx1[0] + x2r = np.zeros(len(x2)+1) + x2r[1:] = x2 + dx2/2.0 + x2r[0] = x2r[1]-dx2[0] + x3r = np.zeros(len(x3)+1) + x3r[1:] = x3 + dx3/2.0 + x3r[0] = x3r[1]-dx3[0] + NewGridDict = dict([('n1', nx), ('n2', ny), ('n3', nz), + ('x1', x1), ('x2', x2), ('x3', x3), + ('x1r', x1r), ('x2r', x2r), ('x3r', x3r), + ('dx1', dx1), ('dx2', dx2), ('dx3', dx3), + ('Dt', dt)]) + + # Variables table + nvar = len(myvars) + vars = np.zeros((nx, ny, nz, nvar)) + + LevelDic = {'nbox': 0, 'ibeg': ibeg, 'iend': iend, 'jbeg': jbeg, 'jend': jend, 'kbeg': kbeg, 'kend': kend} + AMRLevel = [] + AMRBoxes = np.zeros((nx, ny, nz)) + for i in range(il+1): + AMRLevel.append(LevelDic.copy()) + fl = fp[lev[i]] + data = fl['data:datatype=0'] + boxes = fl['boxes'] + nbox = len(boxes['lo_i']) + AMRLevel[i]['nbox'] = nbox + ncount = long(0) + AMRLevel[i]['box'] = [] + for j in range(nbox): # loop on all boxes of a given level + AMRLevel[i]['box'].append({'x0': 0., 'x1': 0., 'ib': long(0), 'ie': long(0), + 'y0': 0., 'y1': 0., 'jb': long(0), 'je': long(0), + 'z0': 0., 'z1': 0., 'kb': long(0), 'ke': long(0)}) + # Box indexes + ib = boxes[j]['lo_i'] + ie = boxes[j]['hi_i'] + nbx = ie-ib+1 + jb = 0 + je = 0 + nby = 1 + kb = 0 + ke = 0 + nbz = 1 + if (dim > 1): + jb = boxes[j]['lo_j'] + je = boxes[j]['hi_j'] + nby = je-jb+1 + if (dim > 2): + kb = boxes[j]['lo_k'] + ke = boxes[j]['hi_k'] + nbz = ke-kb+1 + szb = nbx*nby*nbz*nvar + # Rescale to current level + kb = kb*freb[i] + ke = (ke+1)*freb[i] - 1 + jb = jb*freb[i] + je = (je+1)*freb[i] - 1 + ib = ib*freb[i] + ie = (ie+1)*freb[i] - 1 + + # Skip boxes lying outside ranges + if ((ib > iend) or (ie < ibeg) or + (jb > jend) or (je < jbeg) or + (kb > kend) or (ke < kbeg)): + ncount = ncount + szb + else: + + # Read data + q = data[ncount:ncount+szb].reshape((nvar, nbz, nby, nbx)).T + + # Find boxes intersections with current domain ranges + ib0 = max([ibeg, ib]) + ie0 = min([iend, ie]) + jb0 = max([jbeg, jb]) + je0 = min([jend, je]) + kb0 = max([kbeg, kb]) + ke0 = min([kend, ke]) + + # Store box corners in the AMRLevel structure + if logr == 0: + AMRLevel[i]['box'][j]['x0'] = x1b + dx*(ib0) + AMRLevel[i]['box'][j]['x1'] = x1b + dx*(ie0+1) + else: + AMRLevel[i]['box'][j]['x0'] = x1b*exp(dx*(ib0)) + AMRLevel[i]['box'][j]['x1'] = x1b*exp(dx*(ie0+1)) + AMRLevel[i]['box'][j]['y0'] = x2b + dx*(jb0)*ystr + AMRLevel[i]['box'][j]['y1'] = x2b + dx*(je0+1)*ystr + AMRLevel[i]['box'][j]['z0'] = x3b + dx*(kb0)*zstr + AMRLevel[i]['box'][j]['z1'] = x3b + dx*(ke0+1)*zstr + AMRLevel[i]['box'][j]['ib'] = ib0 + AMRLevel[i]['box'][j]['ie'] = ie0 + AMRLevel[i]['box'][j]['jb'] = jb0 + AMRLevel[i]['box'][j]['je'] = je0 + AMRLevel[i]['box'][j]['kb'] = kb0 + AMRLevel[i]['box'][j]['ke'] = ke0 + AMRBoxes[ib0-ibeg:ie0-ibeg+1, jb0-jbeg:je0-jbeg+1, kb0-kbeg:ke0-kbeg+1] = il + + # Extract the box intersection from data stored in q + cib0 = (ib0-ib)/freb[i] + cie0 = (ie0-ib)/freb[i] + cjb0 = (jb0-jb)/freb[i] + cje0 = (je0-jb)/freb[i] + ckb0 = (kb0-kb)/freb[i] + cke0 = (ke0-kb)/freb[i] + q1 = np.zeros((cie0-cib0+1, cje0-cjb0+1, cke0-ckb0+1, nvar)) + q1 = q[cib0:cie0+1, cjb0:cje0+1, ckb0:cke0+1, :] + + # Remap the extracted portion + if (dim == 1): + new_shape = (ie0-ib0+1, 1) + elif (dim == 2): + new_shape = (ie0-ib0+1, je0-jb0+1) + else: + new_shape = (ie0-ib0+1, je0-jb0+1, ke0-kb0+1) + + stmp = list(new_shape) + while stmp.count(1) > 0: + stmp.remove(1) + new_shape = tuple(stmp) + + myT = Tools() + for iv in range(nvar): + vars[ib0-ibeg:ie0-ibeg+1, jb0-jbeg:je0-jbeg+1, kb0-kbeg:ke0-kbeg+1, iv] = \ + myT.congrid(q1[:, :, :, iv].squeeze(), new_shape, method='linear', minusone=True).reshape((ie0-ib0+1, je0-jb0+1, ke0-kb0+1)) + ncount = ncount+szb + + h5vardict = {} + for iv in range(nvar): + h5vardict[myvars[iv]] = vars[:, :, :, iv].squeeze() + AMRdict = dict([('AMRBoxes', AMRBoxes), ('AMRLevel', AMRLevel)]) + OutDict = dict(NewGridDict) + OutDict.update(AMRdict) + OutDict.update(h5vardict) + return OutDict + + def DataScan(self, fp, n1, n2, n3, endian, dtype, off=None): + """ Scans the data files in all formats. + + **Inputs**: + + fp -- Data file pointer\n + n1 -- No. of points in X1 direction\n + n2 -- No. of points in X2 direction\n + n3 -- No. of points in X3 direction\n + endian -- Endianess of the data\n + dtype -- datatype, eg : double, float, vtk, hdf5\n + off -- offset (for avoiding staggered B fields) + + **Output**: + + Dictionary consisting of variable names as keys and its values. + + """ + if off is not None: + off_fmt = endian+str(off)+dtype + nboff = np.dtype(off_fmt).itemsize + fp.read(nboff) + + n1_tot = self.n1_tot + n2_tot = self.n2_tot + n3_tot = self.n3_tot + + A = array.array(dtype) + fmt = endian+str(n1_tot*n2_tot*n3_tot)+dtype + nb = np.dtype(fmt).itemsize + A.fromstring(fp.read(nb)) + + if (self.Slice): + darr = np.zeros((n1*n2*n3)) + indxx = np.sort([n3_tot*n2_tot*k + j*n2_tot + i for i in self.irange for j in self.jrange for k in self.krange]) + if (sys.byteorder != self.endianess): + A.byteswap() + for ii, iii in enumerate(indxx): + darr[ii] = A[iii] + darr = [darr] + else: + darr = np.frombuffer(A, dtype=np.dtype(fmt)) + + return np.reshape(darr[0], self.nshp).transpose() + + def ReadSingleFile(self, datafilename, myvars, n1, n2, n3, endian, + dtype, ddict): + """Reads a single data file, data.****.dtype. + + **Inputs**: + + datafilename -- Data file name\n + + myvars -- List of variable names to be read\n + n1 -- No. of points in X1 direction\n + n2 -- No. of points in X2 direction\n + n3 -- No. of points in X3 direction\n + endian -- Endianess of the data\n + dtype -- datatype\n + ddict -- Dictionary containing Grid and Time Information + which is updated + + **Output**: + + Updated Dictionary consisting of variable names as keys and its values. + """ + if self.datatype == 'hdf5': + fp = h5.File(datafilename, 'r') + else: + fp = open(datafilename, "rb") + + print("Reading Data file : %s" % datafilename, end="\r", flush=True) + + if self.datatype == 'vtk': + vtkd = self.DataScanVTK(fp, n1, n2, n3, endian, dtype) + ddict.update(vtkd) + elif self.datatype == 'hdf5': + h5d = self.DataScanHDF5(fp, myvars, self.level) + ddict.update(h5d) + else: + for i in range(len(myvars)): + if myvars[i] == 'bx1s': + ddict.update({myvars[i]: self.DataScan(fp, n1, n2, n3, endian, + dtype, off=n2*n3)}) + elif myvars[i] == 'bx2s': + ddict.update({myvars[i]: self.DataScan(fp, n1, n2, n3, endian, + dtype, off=n3*n1)}) + elif myvars[i] == 'bx3s': + ddict.update({myvars[i]: self.DataScan(fp, n1, n2, n3, endian, + dtype, off=n1*n2)}) + else: + ddict.update({myvars[i]: self.DataScan(fp, n1, n2, n3, endian, + dtype)}) + + fp.close() + + def ReadMultipleFiles(self, nstr, dataext, myvars, n1, n2, n3, endian, + dtype, ddict): + """Reads a multiple data files, varname.****.dataext. + + **Inputs**: + + nstr -- File number in form of a string\n + + dataext -- Data type of the file, e.g., 'dbl', 'flt' or 'vtk' \n + myvars -- List of variable names to be read\n + n1 -- No. of points in X1 direction\n + n2 -- No. of points in X2 direction\n + n3 -- No. of points in X3 direction\n + endian -- Endianess of the data\n + dtype -- datatype\n + ddict -- Dictionary containing Grid and Time Information + which is updated. + + **Output**: + + Updated Dictionary consisting of variable names as keys and its values. + + """ + for i in range(len(myvars)): + datafilename = self.wdir+myvars[i]+"."+nstr+dataext + fp = open(datafilename, "rb") + if self.datatype == 'vtk': + ddict.update(self.DataScanVTK(fp, n1, n2, n3, endian, dtype)) + else: + ddict.update({myvars[i]: self.DataScan(fp, n1, n2, n3, endian, + dtype)}) + fp.close() + + def ReadDataFile(self, num): + """Reads the data file generated from PLUTO code. + + **Inputs**: + + num -- Data file number in form of an Integer. + + **Outputs**: + + Dictionary that contains all information about Grid, Time and + variables. + + """ + gridfile = self.wdir+"grid.out" + if self.datatype == "float": + dtype = "f" + varfile = self.wdir+"flt.out" + dataext = ".flt" + elif self.datatype == "vtk": + dtype = "f" + varfile = self.wdir+"vtk.out" + dataext = ".vtk" + elif self.datatype == 'hdf5': + dtype = 'd' + dataext = '.hdf5' + nstr = num + varfile = self.wdir+"data."+nstr+dataext + else: + dtype = "d" + varfile = self.wdir+"dbl.out" + dataext = ".dbl" + + self.ReadVarFile(varfile) + self.ReadGridFile(gridfile) + self.ReadTimeInfo(varfile) + nstr = num + if self.endianess == 'big': + endian = ">" + elif self.datatype == 'vtk': + endian = ">" + else: + endian = "<" + + D = [('NStep', self.NStep), ('SimTime', self.SimTime), ('Dt', self.Dt), + ('n1', self.n1), ('n2', self.n2), ('n3', self.n3), + ('x1', self.x1), ('x2', self.x2), ('x3', self.x3), + ('dx1', self.dx1), ('dx2', self.dx2), ('dx3', self.dx3), + ('endianess', self.endianess), ('datatype', self.datatype), + ('filetype', self.filetype)] + ddict = dict(D) + + if self.filetype == "single_file": + datafilename = self.wdir+"data."+nstr+dataext + self.ReadSingleFile(datafilename, self.vars, self.n1, self.n2, + self.n3, endian, dtype, ddict) + elif self.filetype == "multiple_files": + self.ReadMultipleFiles(nstr, dataext, self.vars, self.n1, self.n2, + self.n3, endian, dtype, ddict) + else: + print("Wrong file type : CHECK pluto.ini for file type.") + print("Only supported are .dbl, .flt, .vtk, .hdf5") + sys.exit() + + return ddict + + +class PlutoData(object): + """ + Class to read Pluto vtk or dbl atmosphere + + Parameters + ---------- + fdir : str, optional + Directory with snapshots. + rootname : str + rootname of the file (wihtout params or vars). + verbose : bool, optional + If True, will print more information. + it : integer + snapshot number + """ + + def __init__(self, rootname, snap, fdir='./', datatype='dbl', + verbose=True, sel_units='cgs', typemodel='Kostas', *args, **kwargs): + + #super(PlutoData, self).__init__(it, w_dir=fdir, datatype=datatype, *args, **kwargs) + + #self.rootname = rootname + self.fdir = fdir + self.snap = snap + self.sel_units = sel_units + self.verbose = verbose + self.typemodel = typemodel + self.datatype = datatype + if self.typemodel == 'Kostas': + self.uni = Pypluto_kostas_units() + elif (self.typemodel == 'Paolo'): + self.uni = Pypluto_paolo_units() + self.info = pload(snap, w_dir=fdir, datatype=datatype) + self.time = self.info.SimTime + self.x = self.info.x1 + self.y = self.info.x2 + self.z = self.info.x3 + self.zorig = self.z + + if self.sel_units == 'cgs': + self.x *= self.uni.uni['l'] + self.y *= self.uni.uni['l'] + self.z *= self.uni.uni['l'] + + self.nx = len(self.x) + self.ny = len(self.y) + self.nz = len(self.z) + + if self.nx > 1: + self.dx1d = np.gradient(self.x) + else: + self.dx1d = np.zeros(self.nx) + + if self.ny > 1: + self.dy1d = np.gradient(self.y) + else: + self.dy1d = np.zeros(self.ny) + + if self.nz > 1: + self.dz1d = np.gradient(self.z) + else: + self.dz1d = np.zeros(self.nz) + + self.transunits = False + + self.cstagop = False # This will not allow to use cstagger from Bifrost in load + self.hion = False # This will not allow to use HION from Bifrost in load + + # self.time = params['time'] # No uniforme (array) + self.genvar() + + document_vars.create_vardict(self) + document_vars.set_vardocs(self) + + def get_var(self, var, *args, snap=None, iix=None, iiy=None, iiz=None, layout=None, **kwargs): + ''' + Reads the variables from a snapshot (snap). + + Parameters + ---------- + var - string + Name of the variable to read. Must be Bifrost internal names. + snap - integer, optional + Snapshot number to read. By default reads the loaded snapshot; + if a different number is requested, will load that snapshot. + + Axes: + ----- + z-axis is along the loop + x and y axes are perperdicular to the loop + + Variable list: + -------------- + rho -- Density (multipy by self.uni['rho'] to get in g/cm^3) + prs -- pressure (multipy by self.uni['pg'] to get in cgs) + vx1 -- component x of the velocity (multipy by self.uni['u'] to get in cm/s) + vx2 -- component y of the velocity (multipy by self.uni['u'] to get in cm/s) + vx3 -- component z of the velocity (multipy by self.uni['u'] to get in cm/s) + bx1 -- component x of the magnetic field (multipy by self.uni['b'] to get in G) + bx2 -- component y of the magnetic field (multipy by self.uni['b'] to get in G) + bx1 -- component z of the magnetic field (multipy by self.uni['b'] to get in G) + ''' + + if ((snap != None) and (self.snap != snap)): + self.snap = snap + self.info = pload(snap, w_dir=self.fdir, datatype=self.datatype) + self.time = self.info.SimTime + + if var in self.varn.keys(): + if self.sel_units == 'cgs': + varu = var.replace('x', '') + varu = varu.replace('y', '') + varu = varu.replace('z', '') + if (var in self.varn.keys()) and (varu in self.uni.uni.keys()): + cgsunits = self.uni.uni[varu] + else: + cgsunits = 1.0 + else: + cgsunits = 1.0 + + self.data = getattr(self.info, self.varn[var]) * cgsunits + else: + + self.get_comp_vars(var, *args, snap=snap, iix=iix, iiy=iiy, iiz=iiz, layout=layout, **kwargs) + + if np.shape(self.data) == (): + # Loading quantities + if self.verbose: + print('Loading composite variable', end="\r", flush=True) + self.data = load_noeos_quantities(self, var, **kwargs) + + if np.shape(self.data) == (): + self.data = load_quantities(self, var, PLASMA_QUANT='', CYCL_RES='', + COLFRE_QUANT='', COLFRI_QUANT='', IONP_QUANT='', + EOSTAB_QUANT='', TAU_QUANT='', DEBYE_LN_QUANT='', + CROSTAB_QUANT='', COULOMB_COL_QUANT='', AMB_QUANT='', + HALL_QUANT='', BATTERY_QUANT='', SPITZER_QUANT='', + KAPPA_QUANT='', GYROF_QUANT='', WAVE_QUANT='', + FLUX_QUANT='', CURRENT_QUANT='', COLCOU_QUANT='', + COLCOUMS_QUANT='', COLFREMX_QUANT='', **kwargs) + + # Loading arithmetic quantities + if np.shape(self.data) == (): + if self.verbose: + print('Loading arithmetic variable', end="\r", flush=True) + self.data = load_arithmetic_quantities(self, var, **kwargs) + if document_vars.creating_vardict(self): + return None + elif var == '': + print(help(self.get_var)) + print('VARIABLES USING CGS OR GENERIC NOMENCLATURE') + for ii in self.varn: + print('use ', ii, ' for ', self.varn[ii]) + if hasattr(self, 'vardict'): + self.vardocs() + + return None + + # self.trans2noncommaxes() + + return self.data + + def get_comp_vars(self, var, *args, snap=None, iix=None, iiy=None, iiz=None, layout=None, **kwargs): + ''' + Computes composite variables. + ''' + if var == 'tg': + + if self.sel_units == 'cgs': + cgsunits = self.uni.uni['tg'] + else: + cgsunits = 1.0 + sel_units = self.sel_units + self.sel_units = 'none' + + self.data = self.get_var('pg', snap=snap) / self.get_var('rho', snap=snap) * cgsunits + self.sel_units = sel_units + else: + self.data = None + + def genvar(self): + ''' + Dictionary of original variables which will allow to convert to cgs. + ''' + self.varn = {} + self.varn['rho'] = 'rho' + self.varn['totr'] = 'rho' + self.varn['pg'] = 'prs' + self.varn['ux'] = 'vx1' + self.varn['uy'] = 'vx2' + self.varn['uz'] = 'vx3' + self.varn['bx'] = 'bx1' + self.varn['by'] = 'bx2' + self.varn['bz'] = 'bx3' + self.varn['jx'] = 'j1' + self.varn['jy'] = 'j2' + self.varn['jz'] = 'j3' + self.varn['vortx'] = 'vort1' + self.varn['vorty'] = 'vort2' + self.varn['vortz'] = 'vort3' + + def trans2comm(self, varname, snap=None, angle=45, loop=None): + ''' + Transform the domain into a "common" format. All arrays will be 3D. The 3rd axis + is: + + - for 3D atmospheres: the vertical axis + - for loop type atmospheres: along the loop + - for 1D atmosphere: the unique dimension is the 3rd axis. + At least one extra dimension needs to be created artifically. + + All of them should obey the right hand rule + + In all of them, the vectors (velocity, magnetic field etc) away from the Sun. + + If applies, z=0 near the photosphere. + + Units: everything is in cgs. + + If an array is reverse, do ndarray.copy(), otherwise pytorch will complain. + + INPUT: + varname - string + snap - integer + angle - real (degrees). Any number -90 to 90, default = 45 + ''' + + self.sel_units = 'cgs' + + self.trans2commaxes(loop) + + if angle == None and not hasattr(self, 'trans2comm_angle'): + self.trans2comm_angle = 45 + if angle != None: + self.trans2comm_angle = angle + + if self.trans2comm_angle != 0: + if varname[-1] in ['x']: + varx = self.get_var(varname, snap=snap) + vary = self.get_var(varname[0]+'y', snap=snap) + var = varx * np.cos(self.trans2comm_angle/90.0*np.pi/2.0) - vary * np.sin(self.trans2comm_angle/90.0*np.pi/2.0) + elif varname[-1] in ['y']: + vary = self.get_var(varname, snap=snap) + varx = self.get_var(varname[0]+'x', snap=snap) + var = vary * np.cos(self.trans2comm_angle/90.0*np.pi/2.0) + varx * np.sin(self.trans2comm_angle/90.0*np.pi/2.0) + else: # component z + var = self.get_var(varname, snap=snap) + var = rotate(var, angle=self.trans2comm_angle, reshape=False, mode='nearest', axes=(0, 1)) + else: + var = self.get_var(varname, snap=snap) + + if self.typemodel == 'Kostas': + var = var[..., ::-1].copy() + + if loop != None: + if varname[-1] in ['x']: + var = self.make_loop(var, loop) + varz = self.get_var(varname[0]+'z', snap=snap) + if self.typemodel == 'Kostas': + varz = varz[..., ::-1].copy() + varz = self.make_loop(varz, loop) + xx, zz = np.meshgrid(self.x, self.z) + aa = np.angle(xx+1j*zz) + for iiy, iy in enumerate(self.y): + var[:, iiy, :] = var[:, iiy, :] * np.cos(aa.T) - varz[:, iiy, :] * np.sin(aa.T) + elif varname[-1] in ['z']: + var = self.make_loop(var, loop) + varx = self.get_var(varname[0]+'x', snap=snap) + if self.typemodel == 'Kostas': + varx = varx[..., ::-1].copy() + varx = self.make_loop(varx, loop) + xx, zz = np.meshgrid(self.x, self.z) + aa = np.angle(xx+1j*zz) + for iiy, iy in enumerate(self.y): + var[:, iiy, :] = var[:, iiy, :] * np.cos(aa.T) + varx[:, iiy, :] * np.sin(aa.T) + else: + var = self.make_loop(var, loop) + + return var + + def make_loop(self, var, loop): + R = np.max(self.zorig)/np.pi*2 + rad = self.xorig+np.max(self.x_loop)-np.max(self.xorig) + angl = self.zorig / R + var_new = np.zeros((self.nx, self.ny, self.nz)) + # iiy0=np.argmin(np.abs(self.yorig)) + iiy0 = 0 + for iiy, iy in enumerate(self.y): + temp = var[:, iiy+iiy0, :] + data = polar2cartesian(rad, angl, temp, self.z, self.x) + var_new[:, iiy, :] = data + return var_new + + def trans2commaxes(self, loop=2): + + if self.transunits == False: + self.transunits = True + #self.z = self.z[::-1].copy() + if self.typemodel == 'Paolo': + # nznew=int(self.z.shape[0]/2) + #self.z = self.z[0:nznew-1] + self.z -= self.z[0] + self.nz = np.size(self.z) + self.dz1d = np.gradient(self.z) + self.xorig = self.x-np.min(self.x) + self.yorig = self.y + self.zorig = self.z + if loop != None: + R = np.max(self.z)/np.pi*2 + self.x_loop = np.linspace(R*np.cos([np.pi/loop]), R+np.max(self.x), + int((R-R*np.cos([np.pi/loop])+np.max(self.x))/np.min(self.dx1d))) + self.z_loop = np.linspace(0, R*np.sin([np.pi/loop])+np.max(self.x), + int((R*np.sin([np.pi/loop])+np.max(self.x))/np.min(self.dx1d))) + + self.x = self.x_loop.squeeze() + self.z = self.z_loop.squeeze() + + # self.y=self.y[np.argmin(np.abs(self.y)):] + + self.dx1d = np.gradient(self.x) + self.dy1d = np.gradient(self.y) + self.dz1d = np.gradient(self.z) + self.nx = np.size(self.x) + self.ny = np.size(self.y) + self.nz = np.size(self.z) + + #self.dz1d = self.dz1d[::-1].copy() + + def trans2noncommaxes(self): + + if self.transunits == True: + self.transunits = False + self.z = self.zorig + self.dz1d = np.gradient(self.z) + #self.dz1d = self.dz1d[::-1].copy() + self.nz = np.size(self.z) + + +class Pypluto_kostas_units(object): + + def __init__(self, verbose=False): + ''' + Units and constants in cgs + ''' + + self.uni = {} + self.verbose = verbose + self.uni['tg'] = 1.0e6 # K + self.uni['l'] = 1.0e8 # cm + self.uni['rho'] = 1.0e-15 # gr cm^-3 + self.uni['kboltz'] = 1.380658E-16 # Boltzman's cst. [erg/K] + self.uni['proton'] = const.m_n / const.gram # 1.674927471e-24 + self.uni['R_spec'] = self.uni['kboltz'] / (0.5e0 * self.uni['proton']) + self.uni['u'] = np.sqrt(self.uni['R_spec']*self.uni['tg']) # cm/s + self.uni['pg'] = self.uni['rho'] * self.uni['u']**2 + self.uni['b'] = np.sqrt(4.0 * np.pi * self.uni['pg']) # Gauss + self.uni['t'] = self.uni['l']/self.uni['u'] # seconds + + # Units and constants in SI + convertcsgsi(self) + + globalvars(self) + + self.uni['j'] = self.uni['b']/self.uni['l']*self.clight # current density + self.uni['gr'] = 2.7e4 # solar gravity in cgs + self.uni['gc'] = self.uni['gr'] * self.uni['l'] / self.uni['u'] ** 2 # solar gravity in Code units. + + +class Pypluto_paolo_units(object): + + def __init__(self, verbose=False): + ''' + Units and constants in cgs + ''' + + proton_mass = 1.672623110e-24 + self.uni = {} + self.verbose = verbose + self.uni['l'] = 6.960e10 # cm + self.uni['u'] = 1.0e7 # cm/s + self.uni['rho'] = 2e10*1.2650*proton_mass # gr cm^-3 + self.uni['b'] = self.uni['u']*np.sqrt(4.0 * np.pi * self.uni['rho']) # Gauss + self.uni['t'] = self.uni['l']/self.uni['u'] # seconds + self.uni['pg'] = self.uni['rho'] * self.uni['u']**2 # erg cm^-3 + self.uni['tg'] = 1.203e6/2.0*1.26506 # K + # self.uni['tg'] = 1.3747056e22 # K + self.uni['kboltz'] = 1.380658E-16 # Boltzman's cst. [erg/K] + self.uni['proton'] = const.m_n / const.gram # 1.674927471e-24 + self.uni['R_spec'] = self.uni['kboltz'] / (0.5e0 * self.uni['proton']) + + # Units and constants in SI + convertcsgsi(self) + + globalvars(self) + + self.uni['j'] = self.uni['b']/self.uni['l']*self.clight # current density + self.uni['gr'] = 2.7e4 # solar gravity in cgs + self.uni['gc'] = self.uni['gr'] * self.uni['l'] / self.uni['u'] ** 2 # solar gravity in Code units. + + +class Tools(object): + """ + + This Class has all the functions doing basic mathematical + operations to the vector or scalar fields. + It is called after pyPLUTO.pload object is defined. + + """ + + def deriv(self, Y, X=None): + """ + Calculates the derivative of Y with respect to X. + + **Inputs:** + + Y : 1-D array to be differentiated.\n + X : 1-D array with len(X) = len(Y).\n + + If X is not specified then by default X is chosen to be an equally spaced array having same number of elements + as Y. + + **Outputs:** + + This returns an 1-D array having the same no. of elements as Y (or X) and contains the values of dY/dX. + + """ + n = len(Y) + n2 = n-2 + if X == None: + X = np.arange(n) + Xarr = np.asarray(X, dtype='float') + Yarr = np.asarray(Y, dtype='float') + x12 = Xarr - np.roll(Xarr, -1) # x1 - x2 + x01 = np.roll(Xarr, 1) - Xarr # x0 - x1 + x02 = np.roll(Xarr, 1) - np.roll(Xarr, -1) # x0 - x2 + DfDx = np.roll(Yarr, 1) * (x12 / (x01*x02)) + Yarr * (1./x12 - 1./x01) - np.roll(Yarr, -1) * (x01 / (x02 * x12)) + # Formulae for the first and last points: + + DfDx[0] = Yarr[0] * (x01[1]+x02[1])/(x01[1]*x02[1]) - Yarr[1] * x02[1]/(x01[1]*x12[1]) + Yarr[2] * x01[1]/(x02[1]*x12[1]) + DfDx[n-1] = -Yarr[n-3] * x12[n2]/(x01[n2]*x02[n2]) + Yarr[n-2]*x02[n2]/(x01[n2]*x12[n2]) - Yarr[n-1]*(x02[n2]+x12[n2])/(x02[n2]*x12[n2]) + + return DfDx + + def Grad(self, phi, x1, x2, dx1, dx2, polar=False): + """ This method calculates the gradient of the 2D scalar phi. + + **Inputs:** + + phi -- 2D scalar whose gradient is to be determined.\n + x1 -- The 'x' array\n + x2 -- The 'y' array\n + dx1 -- The grid spacing in 'x' direction.\n + dx2 -- The grid spacing in 'y' direction.\n + polar -- The keyword should be set to True inorder to + estimate the Gradient in polar co-ordinates. By default + it is set to False. + + **Outputs:** + + This routine outputs a 3D array with shape = (len(x1),len(x2),2), + such that [:,:,0] element corresponds to the gradient values of + phi wrt to x1 and [:,:,1] are the gradient values of phi wrt to x2. + + """ + (n1, n2) = phi.shape + grad_phi = np.zeros(shape=(n1, n2, 2)) + h2 = np.ones(shape=(n1, n2)) + if polar == True: + for j in range(n2): + h2[:, j] = x1 + + for i in range(n1): + scrh1 = phi[i, :] + grad_phi[i, :, 1] = self.deriv(scrh1, x2)/h2[i, :] + for j in range(n2): + scrh2 = phi[:, j] + grad_phi[:, j, 0] = self.deriv(scrh2, x1) + + return grad_phi + + def Div(self, u1, u2, x1, x2, dx1, dx2, geometry=None): + """ This method calculates the divergence of the 2D vector fields u1 and u2. + + **Inputs:** + + u1 -- 2D vector along x1 whose divergence is to be determined.\n + u2 -- 2D vector along x2 whose divergence is to be determined.\n + x1 -- The 'x' array\n + x2 -- The 'y' array\n + dx1 -- The grid spacing in 'x' direction.\n + dx2 -- The grid spacing in 'y' direction.\n + geometry -- The keyword *geometry* is by default set to 'cartesian'. + It can be set to either one of the following : *cartesian*, *cylindrical*, + *spherical* or *polar*. To calculate the divergence of the vector + fields, respective geometric corrections are taken into account based + on the value of this keyword. + + **Outputs:** + + A 2D array with same shape as u1(or u2) having the values of divergence. + + """ + (n1, n2) = u1.shape + Divergence = np.zeros(shape=(n1, n2)) + du1 = np.zeros(shape=(n1, n2)) + du2 = np.zeros(shape=(n1, n2)) + + A1 = np.zeros(shape=n1) + A2 = np.zeros(shape=n2) + + dV1 = np.zeros(shape=(n1, n2)) + dV2 = np.zeros(shape=(n1, n2)) + + if geometry == None: + geometry = 'cartesian' + + # ------------------------------------------------ + # define area and volume elements for the + # different coordinate systems + # ------------------------------------------------ + + if geometry == 'cartesian': + A1[:] = 1.0 + A2[:] = 1.0 + dV1 = np.outer(dx1, A2) + dV2 = np.outer(A1, dx2) + + if geometry == 'cylindrical': + A1 = x1 + A2[:] = 1.0 + dV1 = np.meshgrid(x1*dx1, A2)[0].T*np.meshgrid(x1*dx1, A2)[1].T + for i in range(n1): + dV2[i, :] = dx2[:] + + if geometry == 'polar': + A1 = x1 + A2[:] = 1.0 + dV1 = np.meshgrid(x1, A2)[0].T*np.meshgrid(x1, A2)[1].T + dV2 = np.meshgrid(x1, dx2)[0].T*np.meshgrid(x1, dx2)[1].T + + if geometry == 'spherical': + A1 = x1*x1 + A2 = np.sin(x2) + for j in range(n2): + dV1[:, j] = A1*dx1 + dV2 = np.meshgrid(x1, np.sin(x2)*dx2)[0].T*np.meshgrid(x1, np.sin(x2)*dx2)[1].T + + # ------------------------------------------------ + # Make divergence + # ------------------------------------------------ + + for i in range(1, n1-1): + du1[i, :] = 0.5*(A1[i+1]*u1[i+1, :] - A1[i-1]*u1[i-1, :])/dV1[i, :] + for j in range(1, n2-1): + du2[:, j] = 0.5*(A2[j+1]*u2[:, j+1] - A2[j-1]*u2[:, j-1])/dV2[:, j] + + Divergence = du1 + du2 + return Divergence + + def RTh2Cyl(self, R, Th, X1, X2): + """ This method does the transformation from spherical coordinates to cylindrical ones. + + **Inputs:** + + R - 2D array of spherical radius coordinates.\n + Th - 2D array of spherical theta-angle coordinates.\n + X1 - 2D array of radial component of given vector\n + X2 - 2D array of thetoidal component of given vector\n + + **Outputs:** + + This routine outputs two 2D arrays after transformation. + + **Usage:** + + ``import pyPLUTO as pp``\n + ``import numpy as np``\n + ``D = pp.pload(0)``\n + ``ppt=pp.Tools()``\n + ``TH,R=np.meshgrid(D.x2,D.x1)``\n + ``Br,Bz=ppt.RTh2Cyl(R,TH,D.bx1,D.bx2)`` + + D.bx1 and D.bx2 should be vectors in spherical coordinates. After + transformation (Br,Bz) corresponds to vector in cilindrical coordinates. + + + """ + Y1 = X1*np.sin(Th)+X2*np.cos(Th) + Y2 = X1*np.cos(Th)-X2*np.sin(Th) + return Y1, Y2 + + def myInterpol(self, RR, N): + """ + This method interpolates (linear interpolation) vector 1D vector RR + to 1D N-length vector. Useful for stretched grid calculations. + + **Inputs:** + + RR - 1D array to interpolate.\n + N - Number of grids to interpolate to.\n + + **Outputs:** + + This routine outputs interpolated 1D array to the new grid (len=N). + + **Usage:** + + ``import pyPLUTO as pp``\n + ``import numpy as np``\n + ``D = pp.pload(0)``\n + ``ppt=pp.Tools()``\n + ``x=linspace(0,1,10) #len(x)=10``\n + ``y=x*x``\n + ``Ri,Ni=ppt.myInterpol(y,100) #len(Ri)=100`` + + Ri - interpolated numbers; + Ni - grid for Ri + + """ + + NN = np.linspace(0, len(RR)-1, len(RR)) + spline_fit = UnivariateSpline(RR, NN, k=3, s=0) + + RRi = np.linspace(RR[0], RR[-1], N) + NNi = spline_fit(RRi) + NNi[0] = NN[0]+0.00001 + NNi[-1] = NN[-1]-0.00001 + return RRi, NNi + + def getUniformGrid(self, r, th, rho, Nr, Nth): + """ + This method transforms data with non-uniform grid (stretched) to + uniform. Useful for stretched grid calculations. + + **Inputs:** + + r - 1D vector of X1 coordinate (could be any, e.g D.x1).\n + th - 1D vector of X2 coordinate (could be any, e.g D.x2).\n + rho- 2D array of data.\n + Nr - new size of X1 vector.\n + Nth- new size of X2 vector.\n + + **Outputs:** + + This routine outputs 2D uniform array Nr x Nth dimension + + **Usage:** + + ``import pyPLUTO as pp``\n + ``import numpy as np``\n + ``D = pp.pload(0)``\n + ``ppt=pp.Tools()``\n + ``X1new, X2new, res = ppt.getUniformGrid(D.x1,D.x2,D.rho,20,30)`` + + X1new - X1 interpolated grid len(X1new)=20 + X2new - X2 interpolated grid len(X2new)=30 + res - 2D array of interpolated variable + + """ + + Ri, NRi = self.myInterpol(r, Nr) + Ra = np.int32(NRi) + Wr = NRi-Ra + + YY = np.ones([Nr, len(th)]) + for i in range(len(th)): + YY[:, i] = (1-Wr)*rho[Ra, i] + Wr*rho[Ra+1, i] + + THi, NTHi = self.myInterpol(th, Nth) + THa = np.int32(NTHi) + Wth = NTHi-THa + + ZZ = np.ones([Nr, Nth]) + for i in range(Nr): + ZZ[i, :] = (1-Wth)*YY[i, THa] + Wth*YY[i, THa+1] + + return Ri, THi, ZZ + + def sph2cyl(self, D, Dx, rphi=None, theta0=None): + """ + This method transforms spherical data into cylindrical + applying interpolation. Works for stretched grid as well, + transforms poloidal (R-Theta) data by default. Fix theta + and set rphi=True to get (R-Phi) transformation. + + **Inputs:** + + D - structure from 'pload' method.\n + Dx - variable to be transformed (D.rho for example).\n + + **Outputs:** + + This routine outputs transformed (sph->cyl) variable and grid. + + **Usage:** + + ``import pyPLUTO as pp``\n + ``import numpy as np``\n + ``D = pp.pload(0)``\n + ``ppt=pp.Tools()``\n + ``R,Z,res = ppt.sph2cyl(D,D.rho.transpose())`` + + R - 2D array with cylindrical radius values + Z - 2D array with cylindrical Z values + res - 2D array of transformed variable + + """ + + if rphi is None or rphi == False: + rx = D.x1 + th = D.x2 + else: + rx = D.x1*np.sin(theta0) + th = D.x3 + + rx, th, Dx = self.getUniformGrid(rx, th, Dx.T, 200, 200) + Dx = Dx.T + + if rphi is None or rphi == False: + + r0 = np.min(np.sin(th)*rx[0]) + rN = rx[-1] + dr = rN-r0 + z0 = np.min(np.cos(th)*rN) + zN = np.max(np.cos(th)*rN) + dz = zN-z0 + dth = th[-1]-th[0] + rl = np.int32(len(rx)*dr/(rx[-1]-rx[0])) + zl = np.int32(rl * dz/dr) + thl = len(th) + r = np.linspace(r0, rN, rl) + z = np.linspace(z0, zN, zl) + else: + r0 = np.min([np.sin(th)*rx[0], np.sin(th)*rx[-1]]) + rN = np.max([np.sin(th)*rx[0], np.sin(th)*rx[-1]]) + dr = rN-r0 + z0 = np.min(np.cos(th)*rN) + zN = np.max(np.cos(th)*rN) + dz = zN-z0 + dth = th[-1]-th[0] + rl = np.int32(len(rx)*dr/(rx[-1]-rx[0])) + zl = np.int32(rl * dz/dr) + thl = len(th) + r = np.linspace(r0, rN, rl) + z = np.linspace(z0, zN, zl) + + R, Z = np.meshgrid(r, z) + Rs = np.sqrt(R*R + Z*Z) + + Th = np.arccos(Z/Rs) + kv_34 = find(R < 0) + Th.flat[kv_34] = 2*np.pi - Th.flat[kv_34] + + ddr = rx[1]-rx[0] + ddth = th[1]-th[0] + + Rs_copy = Rs.copy() + Th_copy = Th.copy() + + nR1 = find(Rs < rx[0]) + Rs.flat[nR1] = rx[0] + nR2 = find(Rs > rN) + Rs.flat[nR2] = rN + + nTh1 = find(Th > th[-1]) + Th.flat[nTh1] = th[-1] + nTh2 = find(Th < th[0]) + Th.flat[nTh2] = th[0] + + ra = ((len(rx)-1.001)/(np.max(Rs.flat)-np.min(Rs.flat)) * (Rs-np.min(Rs.flat))) + tha = ((thl-1.001)/dth * (Th-th[0])) + + rn = np.int32(ra) + thn = np.int32(tha) + dra = ra-rn + dtha = tha-thn + w1 = 1-dra + w2 = dra + w3 = 1-dtha + w4 = dtha + lrx = len(rx) + NN1 = np.int32(rn+thn*lrx) + NN2 = np.int32((rn+1)+thn*lrx) + NN3 = np.int32(rn+(thn+1)*lrx) + NN4 = np.int32((rn+1)+(thn+1)*lrx) + n = np.transpose(np.arange(0, np.product(np.shape(R)))) + Dx.copy() + F = R.copy() + F.flat[n] = w1.flat[n]*(w3.flat[n]*Dx.flat[NN1.flat[n]] + w4.flat[n]*Dx.flat[NN3.flat[n]]) +\ + w2.flat[n]*(w3.flat[n]*Dx.flat[NN2.flat[n]] + w4.flat[n]*Dx.flat[NN4.flat[n]]) + + nR1 = find(Rs_copy < rx[0]-ddr/1.5) + nR2 = find(Rs_copy > rN+ddr/1.5) + nTh1 = find(Th_copy > th[-1]+ddth/1.5) + nTh2 = find(Th_copy < th[0]-ddth/1.5) + + nmask = np.concatenate((nR1, nR2, nTh1, nTh2)) + F.flat[nmask] = np.nan + return R, Z, F + + def congrid(self, a, newdims, method='linear', centre=False, minusone=False): + """ + Arbitrary resampling of source array to new dimension sizes. + Currently only supports maintaining the same number of dimensions. + To use 1-D arrays, first promote them to shape (x,1). + + Uses the same parameters and creates the same co-ordinate lookup points + as IDL''s congrid routine, which apparently originally came from a VAX/VMS + routine of the same name. + + **Inputs:** + + a -- The 2D array that needs resampling into new dimensions.\n + newdims -- A tuple which represents the shape of the resampled data\n + method -- This keyword decides the method used for interpolation.\n + neighbour - closest value from original data\n + nearest and linear - uses n x 1-D interpolations using scipy.interpolate.interp1d + (see Numerical Recipes for validity of use of n 1-D interpolations)\n + spline - uses ndimage.map_coordinates\n + centre -- This keyword decides the positions of interpolation points.\n + True - interpolation points are at the centres of the bins\n + False - points are at the front edge of the bin\n + minusone -- This prevents extrapolation one element beyond bounds of input array\n + For example- inarray.shape = (i,j) & new dimensions = (x,y)\n + False - inarray is resampled by factors of (i/x) * (j/y)\n + True - inarray is resampled by(i-1)/(x-1) * (j-1)/(y-1)\n + + **Outputs:** + + A 2D array with resampled data having a shape corresponding to newdims. + + """ + if not a.dtype in [np.float64, np.float32]: + a = np.cast[float](a) + + m1 = np.cast[int](minusone) + ofs = np.cast[int](centre) * 0.5 + old = np.array(a.shape) + ndims = len(a.shape) + if len(newdims) != ndims: + print("[congrid] dimensions error. ") + print("This routine currently only support ") + print("rebinning to the same number of dimensions.") + return None + newdims = np.asarray(newdims, dtype=float) + dimlist = [] + + if method == 'neighbour': + for i in range(ndims): + base = np.indices(newdims)[i] + dimlist.append((old[i] - m1) / (newdims[i] - m1) + * (base + ofs) - ofs) + cd = np.array(dimlist).round().astype(int) + newa = a[list(cd)] + return newa + + elif method in ['nearest', 'linear']: + # calculate new dims + for i in range(ndims): + base = np.arange(newdims[i]) + dimlist.append((old[i] - m1) / (newdims[i] - m1) + * (base + ofs) - ofs) + # specify old dims + olddims = [np.arange(i, dtype=np.float) for i in list(a.shape)] + + # first interpolation - for ndims = any + mint = scipy.interpolate.interp1d(olddims[-1], a, kind=method) + newa = mint(dimlist[-1]) + + trorder = [ndims - 1] + range(ndims - 1) + for i in range(ndims - 2, -1, -1): + newa = newa.transpose(trorder) + + mint = scipy.interpolate.interp1d(olddims[i], newa, kind=method) + newa = mint(dimlist[i]) + + if ndims > 1: + # need one more transpose to return to original dimensions + newa = newa.transpose(trorder) + + return newa + elif method in ['spline']: + oslices = [slice(0, j) for j in old] + np.ogrid[oslices] + nslices = [slice(0, j) for j in list(newdims)] + newcoords = np.mgrid[nslices] + + newcoords_dims = range(n.rank(newcoords)) + # make first index last + newcoords_dims.append(newcoords_dims.pop(0)) + newcoords_tr = newcoords.transpose(newcoords_dims) + # makes a view that affects newcoords + + newcoords_tr += ofs + + deltas = (np.asarray(old) - m1) / (newdims - m1) + newcoords_tr *= deltas + + newcoords_tr -= ofs + + newa = scipy.ndimage.map_coordinates(a, newcoords) + return newa + else: + print("Congrid error: Unrecognized interpolation type.\n") + print("Currently only \'neighbour\', \'nearest\',\'linear\',") + print("and \'spline\' are supported.") + return None + + +class Image(object): + ''' This Class has all the routines for the imaging the data + and plotting various contours and fieldlines on these images. + CALLED AFTER pyPLUTO.pload object is defined + ''' + + def pldisplay(self, D, var, **kwargs): + """ This method allows the user to display a 2D data using the + matplotlib's pcolormesh. + + **Inputs:** + + D -- pyPLUTO pload object.\n + var -- 2D array that needs to be displayed. + + *Required Keywords:* + + x1 -- The 'x' array\n + x2 -- The 'y' array + + *Optional Keywords:* + + vmin -- The minimum value of the 2D array (Default : min(var))\n + vmax -- The maximum value of the 2D array (Default : max(var))\n + title -- Sets the title of the image.\n + label1 -- Sets the X Label (Default: 'XLabel')\n + label2 -- Sets the Y Label (Default: 'YLabel')\n + polar -- A list to project Polar data on Cartesian Grid.\n + polar = [True, True] -- Projects r-phi plane.\n + polar = [True, False] -- Project r-theta plane.\n + polar = [False, False] -- No polar plot (Default)\n + cbar -- Its a tuple to set the colorbar on or off. \n + cbar = (True,'vertical') -- Displays a vertical colorbar\n + cbar = (True,'horizontal') -- Displays a horizontal colorbar\n + cbar = (False,'') -- Displays no colorbar. + + **Usage:** + + ``import pyPLUTO as pp``\n + ``wdir = '/path/to/the data files/'``\n + ``D = pp.pload(1,w_dir=wdir)``\n + ``I = pp.Image()``\n + ``I.pldisplay(D, D.v2, x1=D.x1, x2=D.x2, cbar=(True,'vertical'),\ + title='Velocity',label1='Radius',label2='Height')`` + """ + x1 = kwargs.get('x1') + x2 = kwargs.get('x2') + var = var.T + + f1 = figure(kwargs.get('fignum', 1), figsize=kwargs.get('figsize', [10, 10]), + dpi=80, facecolor='w', edgecolor='k') + ax1 = f1.add_subplot(111) + ax1.set_aspect('equal') + + if kwargs.get('polar', [False, False])[0]: + xx, yy = self.getPolarData(D, kwargs.get('x2'), rphi=kwargs.get('polar')[1]) + pcolormesh(xx, yy, var, vmin=kwargs.get('vmin', np.min(var)), vmax=kwargs.get('vmax', np.max(var))) + else: + ax1.axis([np.min(x1), np.max(x1), np.min(x2), np.max(x2)]) + pcolormesh(x1, x2, var, vmin=kwargs.get('vmin', np.min(var)), vmax=kwargs.get('vmax', np.max(var))) + + title(kwargs.get('title', "Title"), size=kwargs.get('size')) + xlabel(kwargs.get('label1', "Xlabel"), size=kwargs.get('size')) + ylabel(kwargs.get('label2', "Ylabel"), size=kwargs.get('size')) + if kwargs.get('cbar', (False, ''))[0] == True: + colorbar(orientation=kwargs.get('cbar')[1]) + + def multi_disp(self, *args, **kwargs): + mvar = [] + for arg in args: + mvar.append(arg.T) + + xmin = np.min(kwargs.get('x1')) + xmax = np.max(kwargs.get('x1')) + ymin = np.min(kwargs.get('x2')) + ymax = np.max(kwargs.get('x2')) + mfig = figure(kwargs.get('fignum', 1), figsize=kwargs.get('figsize', [10, 10])) + Ncols = kwargs.get('Ncols') + Nrows = len(args)/Ncols + mprod = Nrows*Ncols + dictcbar = kwargs.get('cbar', (False, '', 'each')) + + for j in range(mprod): + mfig.add_subplot(Nrows, Ncols, j+1) + pcolormesh(kwargs.get('x1'), kwargs.get('x2'), mvar[j]) + axis([xmin, xmax, ymin, ymax]) + gca().set_aspect('equal') + + xlabel(kwargs.get('label1', mprod*['Xlabel'])[j]) + ylabel(kwargs.get('label2', mprod*['Ylabel'])[j]) + title(kwargs.get('title', mprod*['Title'])[j]) + if (dictcbar[0] == True) and (dictcbar[2] == 'each'): + colorbar(orientation=kwargs.get('cbar')[1]) + if dictcbar[0] == True and dictcbar[2] == 'last': + if (j == np.max(range(mprod))): + colorbar(orientation=kwargs.get('cbar')[1]) + + def oplotbox(self, AMRLevel, lrange=[0, 0], cval=['b', 'r', 'g', 'm', 'w', 'k'], + islice=-1, jslice=-1, kslice=-1, geom='CARTESIAN'): + """ + This method overplots the AMR boxes up to the specified level. + + **Input:** + + AMRLevel -- AMR object loaded during the reading and stored in the pload object + + *Optional Keywords:* + + lrange -- [level_min,level_max] to be overplotted. By default it shows all the loaded levels\n + cval -- list of colors for the levels to be overplotted.\n + [ijk]slice -- Index of the 2D slice to look for so that the adequate box limits are plotted. + By default oplotbox considers you are plotting a 2D slice of the z=min(x3) plane.\n + geom -- Specified the geometry. Currently, CARTESIAN (default) and POLAR geometries are handled. + """ + + nlev = len(AMRLevel) + lrange[1] = min(lrange[1], nlev-1) + npl = lrange[1]-lrange[0]+1 + lpls = [lrange[0]+v for v in range(npl)] + cols = cval[0:nlev] + # Get the offset and the type of slice + Slice = 0 + inds = 'k' + xx = 'x' + yy = 'y' + if (islice >= 0): + Slice = islice + AMRLevel[0]['ibeg'] + inds = 'i' + xx = 'y' + yy = 'z' + if (jslice >= 0): + Slice = jslice + AMRLevel[0]['jbeg'] + inds = 'j' + xx = 'x' + yy = 'z' + if (kslice >= 0): + Slice = kslice + AMRLevel[0]['kbeg'] + inds = 'k' + xx = 'x' + yy = 'y' + + # Overplot the boxes + hold(True) + for il in lpls: + level = AMRLevel[il] + for ib in range(level['nbox']): + box = level['box'][ib] + if ((Slice-box[inds+'b'])*(box[inds+'e']-Slice) >= 0): + if (geom == 'CARTESIAN'): + x0 = box[xx+'0'] + x1 = box[xx+'1'] + y0 = box[yy+'0'] + y1 = box[yy+'1'] + plot([x0, x1, x1, x0, x0], [y0, y0, y1, y1, y0], color=cols[il]) + elif (geom == 'POLAR') or (geom == 'SPHERICAL'): + dn = np.pi/50. + x0 = box[xx+'0'] + x1 = box[xx+'1'] + y0 = box[yy+'0'] + y1 = box[yy+'1'] + if y0 == y1: + y1 = 2*np.pi+y0 - 1.e-3 + xb = np.concatenate([ + [x0*np.cos(y0), x1*np.cos(y0)], + x1*np.cos(np.linspace(y0, y1, num=int(abs(y0-y1)/dn))), + [x1*np.cos(y1), x0*np.cos(y1)], + x0*np.cos(np.linspace(y1, y0, num=int(abs(y0-y1)/dn)))]) + yb = np.concatenate([ + [x0*np.sin(y0), x1*np.sin(y0)], + x1*np.sin(np.linspace(y0, y1, num=int(abs(y0-y1)/dn))), + [x1*np.sin(y1), x0*np.sin(y1)], + x0*np.sin(np.linspace(y1, y0, num=int(abs(y0-y1)/dn)))]) + plot(xb, yb, color=cols[il]) + + hold(False) + + def field_interp(self, var1, var2, x, y, dx, dy, xp, yp): + """ This method interpolates value of vector fields (var1 and var2) on field points (xp and yp). + The field points are obtained from the method field_line. + + **Inputs:** + + var1 -- 2D Vector field in X direction\n + var2 -- 2D Vector field in Y direction\n + x -- 1D X array\n + y -- 1D Y array\n + dx -- 1D grid spacing array in X direction\n + dy -- 1D grid spacing array in Y direction\n + xp -- field point in X direction\n + yp -- field point in Y direction\n + + **Outputs:** + + A list with 2 elements where the first element corresponds to the interpolate field + point in 'x' direction and the second element is the field point in 'y' direction. + + """ + q = [] + U = var1 + V = var2 + i0 = np.abs(xp-x).argmin() + j0 = np.abs(yp-y).argmin() + scrhUx = np.interp(xp, x, U[:, j0]) + scrhUy = np.interp(yp, y, U[i0, :]) + q.append(scrhUx + scrhUy - U[i0, j0]) + scrhVx = np.interp(xp, x, V[:, j0]) + scrhVy = np.interp(yp, y, V[i0, :]) + q.append(scrhVx + scrhVy - V[i0, j0]) + return q + + def field_line(self, var1, var2, x, y, dx, dy, x0, y0): + """ This method is used to obtain field lines (same as fieldline.pro in PLUTO IDL tools). + + **Inputs:** + + var1 -- 2D Vector field in X direction\n + var2 -- 2D Vector field in Y direction\n + x -- 1D X array\n + y -- 1D Y array\n + dx -- 1D grid spacing array in X direction\n + dy -- 1D grid spacing array in Y direction\n + x0 -- foot point of the field line in X direction\n + y0 -- foot point of the field line in Y direction\n + + **Outputs:** + + This routine returns a dictionary with keys - \n + qx -- list of the field points along the 'x' direction. + qy -- list of the field points along the 'y' direction. + + **Usage:** + + See the myfieldlines routine for the same. + """ + xbeg = x[0] - 0.5*dx[0] + xend = x[-1] + 0.5*dx[-1] + + ybeg = y[0] - 0.5*dy[0] + yend = y[-1] + 0.5*dy[-1] + + inside_domain = x0 > xbeg and x0 < xend and y0 > ybeg and y0 < yend + + MAX_STEPS = 5000 + xln_fwd = [x0] + yln_fwd = [y0] + [x0] + [y0] + k = 0 + + while inside_domain == True: + R1 = self.field_interp(var1, var2, x, y, dx, dy, xln_fwd[k], yln_fwd[k]) + dl = 0.5*np.max(np.concatenate((dx, dy)))/(np.sqrt(R1[0]*R1[0] + R1[1]*R1[1] + 1.e-14)) + xln_fwd[k] + 0.5*dl*R1[0] + yln_fwd[k] + 0.5*dl*R1[1] + + R2 = self.field_interp(var1, var2, x, y, dx, dy, xln_fwd[k], yln_fwd[k]) + + xln_one = xln_fwd[k] + dl*R2[0] + yln_one = yln_fwd[k] + dl*R2[1] + + xln_fwd.append(xln_one) + yln_fwd.append(yln_one) + inside_domain = xln_one > xbeg and xln_one < xend and yln_one > ybeg and yln_one < yend + inside_domain = inside_domain and (k < MAX_STEPS-3) + k = k + 1 + + k_fwd = k + qx = np.asarray(xln_fwd[0:k_fwd]) + qy = np.asarray(yln_fwd[0:k_fwd]) + flines = {'qx': qx, 'qy': qy} + + return flines + + def myfieldlines(self, Data, x0arr, y0arr, stream=False, **kwargs): + r""" This method overplots the magnetic field lines at the footpoints given by (x0arr[i],y0arr[i]). + + **Inputs:** + + Data -- pyPLUTO.pload object\n + x0arr -- array of x co-ordinates of the footpoints\n + y0arr -- array of y co-ordinates of the footpoints\n + stream -- keyword for two different ways of calculating the field lines.\n + True -- plots contours of rAphi (needs to store vector potential)\n + False -- plots the fieldlines obtained from the field_line routine. (Default option) + + *Optional Keywords:* + + colors -- A list of matplotlib colors to represent the lines. The length of this list should be same as that of x0arr.\n + lw -- Integer value that determines the linewidth of each line.\n + ls -- Determines the linestyle of each line. + + **Usage:** + + Assume that the magnetic field is a given as **B** = B0$\hat{y}$. + Then to show this field lines we have to define the x and y arrays of field foot points.\n + + ``x0arr = linspace(0.0,10.0,20)``\n + ``y0arr = linspace(0.0,0.0,20)``\n + ``import pyPLUTO as pp``\n + ``D = pp.pload(45)``\n + ``I = pp.Image()``\n + ``I.myfieldlines(D,x0arr,y0arr,colors='k',ls='--',lw=1.0)`` + """ + + if len(x0arr) != len(y0arr): + print("Input Arrays should have same size") + QxList = [] + QyList = [] + StreamFunction = [] + levels = [] + if stream == True: + X, Y = np.meshgrid(Data.x1, Data.x2.T) + StreamFunction = X*(Data.Ax3.T) + for i in range(len(x0arr)): + nx = np.abs(X[0, :]-x0arr[i]).argmin() + ny = np.abs(X[:, 0]-y0arr[i]).argmin() + levels.append(X[ny, nx]*Data.Ax3.T[ny, nx]) + + contour(X, Y, StreamFunction, levels, colors=kwargs.get('colors'), linewidths=kwargs.get('lw', 1), linestyles=kwargs.get('ls', 'solid')) + else: + for i in range(len(x0arr)): + QxList.append(self.field_line(Data.bx1, Data.bx2, Data.x1, Data.x2, Data.dx1, Data.dx1, x0arr[i], y0arr[i]).get('qx')) + QyList.append(self.field_line(Data.bx1, Data.bx2, Data.x1, Data.x2, Data.dx1, Data.dx1, x0arr[i], y0arr[i]).get('qy')) + plot(QxList[i], QyList[i], color=kwargs.get('colors')) + axis([min(Data.x1), max(Data.x1), min(Data.x2), max(Data.x2)]) + + def getSphData(self, Data, w_dir=None, datatype=None, **kwargs): + """This method transforms the vector and scalar fields from Spherical co-ordinates to Cylindrical. + + **Inputs**: + + Data -- pyPLUTO.pload object\n + w_dir -- /path/to/the/working/directory/\n + datatype -- If the data is of 'float' type then datatype = 'float' else by default the datatype is set to 'double'. + + *Optional Keywords*: + + rphi -- [Default] is set to False implies that the r-theta plane is transformed. If set True then the r-phi plane is transformed.\n + x2cut -- Applicable for 3D data and it determines the co-ordinate of the x2 plane while r-phi is set to True.\n + x3cut -- Applicable for 3D data and it determines the co-ordinate of the x3 plane while r-phi is set to False. + + """ + + Tool = Tools() + key_value_pairs = [] + allvars = [] + if w_dir is None: + w_dir = curdir() + for v in Data.vars: + allvars.append(v) + + if kwargs.get('rphi', False) == True: + R, TH = np.meshgrid(Data.x1, Data.x3) + if Data.n3 != 1: + for variable in allvars: + key_value_pairs.append([variable, getattr(Data, variable)[:, kwargs.get('x2cut', 0), :].T]) + + SphData = dict(key_value_pairs) + if ('bx1' in allvars) or ('bx2' in allvars): + (SphData['b1c'], SphData['b3c']) = Tool.RTh2Cyl(R, TH, SphData.get('bx1'), SphData.get('bx3')) + allvars.append('b1c') + allvars.append('b3c') + if ('vx1' in allvars) or ('vx2' in allvars): + (SphData['v1c'], SphData['v3c']) = Tool.RTh2Cyl(R, TH, SphData.get('vx1'), SphData.get('vx3')) + allvars.append('v1c') + allvars.append('v3c') + else: + print("No x3 plane for 2D data") + else: + R, TH = np.meshgrid(Data.x1, Data.x2) + if Data.n3 != 1: + for variable in allvars: + key_value_pairs.append([variable, getattr(Data, variable)[:, :, kwargs.get('x3cut', 0)].T]) + SphData = dict(key_value_pairs) + if ('bx1' in allvars) or ('bx2' in allvars): + (SphData['b1c'], SphData['b2c']) = Tool.RTh2Cyl(R, TH, SphData.get('bx1'), SphData.get('bx2')) + allvars.append('b1c') + allvars.append('b2c') + if ('vx1' in allvars) or ('vx2' in allvars): + (SphData['v1c'], SphData['v2c']) = Tool.RTh2Cyl(R, TH, SphData.get('vx1'), SphData.get('vx2')) + allvars.append('v1c') + allvars.append('v2c') + else: + for variable in allvars: + key_value_pairs.append([variable, getattr(Data, variable)[:, :].T]) + SphData = dict(key_value_pairs) + if ('bx1' in allvars) or ('bx2' in allvars): + (SphData['b1c'], SphData['b2c']) = Tool.RTh2Cyl(R, TH, SphData.get('bx1'), SphData.get('bx2')) + allvars.append('b1c') + allvars.append('b2c') + if ('vx1' in allvars) or ('vx2' in allvars): + (SphData['v1c'], SphData['v2c']) = Tool.RTh2Cyl(R, TH, SphData.get('vx1'), SphData.get('vx2')) + allvars.append('v1c') + allvars.append('v2c') + + for variable in allvars: + if kwargs.get('rphi', False) == True: + R, Z, SphData[variable] = Tool.sph2cyl(Data, SphData.get(variable), rphi=True, theta0=Data.x2[kwargs.get('x2cut', 0)]) + else: + if Data.n3 != 1: + R, Z, SphData[variable] = Tool.sph2cyl(Data, SphData.get(variable), rphi=False) + else: + R, Z, SphData[variable] = Tool.sph2cyl(Data, SphData.get(variable), rphi=False) + + return R, Z, SphData + + def getPolarData(self, Data, ang_coord, rphi=False): + """To get the Cartesian Co-ordinates from Polar. + + **Inputs:** + + Data -- pyPLUTO pload Object\n + ang_coord -- The Angular co-ordinate (theta or Phi) + + *Optional Keywords:* + + rphi -- Default value FALSE is for R-THETA data, + Set TRUE for R-PHI data.\n + + **Outputs**: + + 2D Arrays of X, Y from the Radius and Angular co-ordinates.\n + They are used in pcolormesh in the Image.pldisplay functions. + """ + D = Data + if ang_coord is D.x2: + x2r = D.x2r + elif ang_coord is D.x3: + x2r = D.x3r + else: + print("Angular co-ordinate must be given") + + rcos = np.outer(np.cos(x2r), D.x1r) + rsin = np.outer(np.sin(x2r), D.x1r) + + if rphi: + xx = rcos + yy = rsin + else: + xx = rsin + yy = rcos + + return xx, yy + + def pltSphData(self, Data, w_dir=None, datatype=None, **kwargs): + """ + This method plots the transformed data obtained from + getSphData using the matplotlib's imshow + + **Inputs:** + + Data -- pyPLUTO.pload object\n + w_dir -- /path/to/the/working/directory/\n + datatype -- Datatype. + + *Required Keywords*: + + plvar -- A string which represents the plot variable.\n + + *Optional Keywords*: + + logvar -- [Default = False] Set it True for plotting the log of a variable.\n + rphi -- [Default = False - for plotting in r-theta plane] Set it True for plotting the variable in r-phi plane. + + """ + + if w_dir is None: + w_dir = curdir() + R, Z, SphData = self.getSphData(Data, w_dir=w_dir, datatype=datatype, **kwargs) + + extent = (np.min(R.flat), max(R.flat), np.min(Z.flat), max(Z.flat)) + max(R.flat)-np.min(R.flat) + max(Z.flat)-np.min(Z.flat) + + isnotnan = -np.isnan(SphData[kwargs.get('plvar')]) + maxPl = max(SphData[kwargs.get('plvar')][isnotnan].flat) + minPl = np.min(SphData[kwargs.get('plvar')][isnotnan].flat) + normrange = False + if minPl < 0: + normrange = True + if maxPl > -minPl: + minPl = -maxPl + else: + maxPl = -minPl + if (normrange and kwargs.get('plvar') != 'rho' and kwargs.get('plvar') != 'prs'): + SphData[kwargs.get('plvar')][-1][-1] = maxPl + SphData[kwargs.get('plvar')][-1][-2] = minPl + + if (kwargs.get('logvar') == True): + SphData[kwargs.get('plvar')] = np.log10(SphData[kwargs.get('plvar')]) + + imshow(SphData[kwargs.get('plvar')], aspect='equal', origin='lower', cmap=cm.jet, extent=extent, interpolation='nearest') diff --git a/helita/sim/radyn.py b/helita/sim/radyn.py new file mode 100644 index 00000000..8363d9bb --- /dev/null +++ b/helita/sim/radyn.py @@ -0,0 +1,381 @@ +from math import ceil + +import numpy as np +import radynpy as rd +from scipy.sparse import coo_matrix + +from . import document_vars +from .load_arithmetic_quantities import * +from .load_noeos_quantities import * +from .load_quantities import * +from .tools import * + + +class radyn(object): + """ + Class to read cipmocct atmosphere + + Parameters + ---------- + fdir : str, optional + Directory with snapshots. + rootname : str + rootname of the file (wihtout params or vars). + verbose : bool, optional + If True, will print more information. + it : integer + snapshot number + """ + + def __init__(self, filename, *args, fdir='.', + sel_units='cgs', verbose=True, **kwargs): + + self.filename = filename + self.fdir = fdir + self.rdobj = rd.cdf.LazyRadynData(filename) + self.x = np.array([0.0]) + self.y = np.array([0.0]) + self.z = np.flip(self.rdobj.__getattr__('zm')) + self.sel_units = sel_units + self.verbose = verbose + self.snap = None + self.uni = Radyn_units() + + self.dx = np.array([1.0]) + self.dy = np.array([1.0]) + self.dz = np.copy(self.z) + self.nt = np.shape(self.z)[0] + self.nz = np.shape(self.z)[1] + for it in range(0, self.nt): + self.dz[it, :] = np.gradient(self.z[it, :]) + self.dz1d = self.dz + self.dx1d = np.array([1.0]) + self.dy1d = np.array([1.0]) + + self.nx = np.shape(self.x) + self.ny = np.shape(self.y) + + self.time = self.rdobj.__getattr__('time') + + self.transunits = False + + self.cstagop = False # This will not allow to use cstagger from Bifrost in load + self.hion = False # This will not allow to use HION from Bifrost in load + + self.genvar() + + document_vars.create_vardict(self) + document_vars.set_vardocs(self) + + def get_var(self, var, iix=None, iiy=None, iiz=None, layout=None): + ''' + Reads the variables from a snapshot (it). + + Parameters + ---------- + var - string + Name of the variable to read. + + cgs- logic + converts into cgs units. + Axes: + ----- + z-axis is along the loop + x and y axes have only one grid. + + Information about radynpy library: + -------------- + ''' + + if var in self.varn.keys(): + varname = self.varn[var] + else: + varname = var + + try: + + if self.sel_units == 'cgs': + varu = var.replace('x', '') + varu = varu.replace('y', '') + varu = varu.replace('z', '') + if (var in self.varn.keys()) and (varu in self.uni.uni.keys()): + cgsunits = self.uni.uni[varu] + else: + cgsunits = 1.0 + else: + cgsunits = 1.0 + + self.data = self.rdobj.__getattr__(varname) * cgsunits + except: + + self.data = load_quantities(self, var, PLASMA_QUANT='', CYCL_RES='', + COLFRE_QUANT='', COLFRI_QUANT='', IONP_QUANT='', + EOSTAB_QUANT='', TAU_QUANT='', DEBYE_LN_QUANT='', + CROSTAB_QUANT='', COULOMB_COL_QUANT='', AMB_QUANT='', + HALL_QUANT='', BATTERY_QUANT='', SPITZER_QUANT='', + KAPPA_QUANT='', GYROF_QUANT='', WAVE_QUANT='', + FLUX_QUANT='', CURRENT_QUANT='', COLCOU_QUANT='', + COLCOUMS_QUANT='', COLFREMX_QUANT='') + + if np.shape(self.data) == (): + if self.verbose: + print('Loading arithmetic variable', end="\r", flush=True) + self.data = load_arithmetic_quantities(self, var) + + if document_vars.creating_vardict(self): + return None + elif var == '': + print(help(self.get_var)) + print('VARIABLES USING CGS OR GENERIC NOMENCLATURE') + for ii in self.varn: + print('use ', ii, ' for ', self.varn[ii]) + if hasattr(self, 'vardict'): + self.vardocs() + print('\n radyn obj is self.rdobj, self.rdobj.var_info is as follows') + print(self.rdobj.var_info) + + return None + + self.trans2noncommaxes() + + return self.data + + def genvar(self): + ''' + Dictionary of original variables which will allow to convert to cgs. + ''' + self.varn = {} + self.varn['rho'] = 'd1' + self.varn['totr'] = 'd1' + self.varn['tg'] = 'tg1' + self.varn['ux'] = 'ux' + self.varn['uy'] = 'uy' + self.varn['uz'] = 'vz1' + self.varn['bx'] = 'bx' + self.varn['by'] = 'by' + self.varn['bz'] = 'bz' + self.varn['ne'] = 'ne1' + + def trans2comm(self, varname, snap=0, **kwargs): + ''' + Transform the domain into a "common" format. All arrays will be 3D. The 3rd axis + is: + + - for 3D atmospheres: the vertical axis + - for loop type atmospheres: along the loop + - for 1D atmosphere: the unique dimension is the 3rd axis. + At least one extra dimension needs to be created artifically. + + All of them should obey the right hand rule + + In all of them, the vectors (velocity, magnetic field etc) away from the Sun. + + If applies, z=0 near the photosphere. + + Units: everything is in cgs. + + If an array is reverse, do ndarray.copy(), otherwise pytorch will complain. + + ''' + + self.sel_units = 'cgs' + + for key, value in kwargs.items(): + if key == 'dx': + if hasattr(self, 'trans_dx'): + if value != self.trans_dx: + self.transunits = False + if key == 'dz': + if hasattr(self, 'trans_dz'): + if value != self.trans_dz: + self.transunits = False + + if self.snap != snap: + self.snap = snap + self.transunits = False + + var = self.get_var(varname)[self.snap] + + self.trans2commaxes(**kwargs) + + if not hasattr(self, 'trans_loop_width'): + self.trans_loop_width = 1.0 + if not hasattr(self, 'trans_sparse'): + self.trans_sparse = 3e7 + + for key, value in kwargs.items(): + if key == 'loop_width': + self.trans_loop_width = value + if key == 'unscale': + do_expscale = value + if key == 'sparse': + self.trans_sparse = value + + # Semicircular loop + s = self.rdobj.zm[snap] + good = s >= 0.0 + s = s[good] + var = var[good] + # GSK -- smax was changed 12th March 2021. See comment in trans2commaxes + #smax = self.rdobj.cdf['zll'][self.snap] + smax = np.max(self.rdobj.__getattr__('zm')) + R = 2*smax/np.pi + + # JMS we are assuming here that self.z.min() = 0 + # GSK: This isn't true, if you mean the minimum height in RADYN. Z can go sub-photosphere (~60km) + shape = (ceil(self.x_loop.max()/self.trans_dx), 1, ceil(self.z_loop.max()/self.trans_dx)) + + # In the RADYN model in the corona, successive grid points may be several pixels away from each other. + # In this case, need to refine loop. + do_expscale = False + for key, value in kwargs.items(): + if key == 'unscale': + do_expscale = value + + if self.gridfactor > 1: + if do_expscale: + ss, var = refine(s, np.log(var), factor=self.gridfactor, unscale=np.exp) + else: + ss, var = refine(s, var, factor=self.gridfactor) + else: + ss = s + omega = ss/R + + # Arc lengths (in radians) + dA = np.abs(omega[1:]-omega[0:-1]) + dA = dA.tolist() + dA.insert(0, dA[0]) + dA.append(dA[-1]) + dA = np.array(dA) + dA = 0.5*(dA[1:]+dA[0:-1]) + #dA = R*dA*(loop_width*dx) + dA = 0.5*dA*((R+0.5*self.trans_loop_width*self.trans_dx)**2-(R-0.5*self.trans_loop_width*self.trans_dx)**2) + + # Componnets of velocity in the x and z directions + if varname == 'ux': + var = -var*np.sin(omega) + if varname == 'uz': + var = var*np.cos(omega) + + xind = np.floor(self.x_loop/self.trans_dx).astype(np.int64) + zind = np.clip(np.floor(self.z_loop/self.trans_dz).astype(np.int64), 0, shape[2]-1) + + # Define matrix with column coordinate corresponding to point along loop + # and row coordinate corresponding to position in Cartesian grid + col = np.arange(len(self.z_loop), dtype=np.int64) + row = xind*shape[2]+zind + + if self.trans_sparse: + M = coo_matrix((dA/(self.trans_dx*self.trans_dz), (row, col)), shape=(shape[0]*shape[2], len(ss)), dtype=np.float) + M = M.tocsr() + else: + M = np.zeros(shape=(shape[0]*shape[2], len(ss)), dtype=np.float) + M[row, col] = dA/(self.dx1d*self.dz1d.max()) # weighting by area of arc segment + + # The final quantity at each Cartesian grid cell is an area-weighted + # average of values from loop passing through this grid cell + # This arrays are not actually used for VDEM extraction + var = (M@var).reshape(shape) + + self.x = np.linspace(self.x_loop.min(), self.x_loop.max(), np.shape(var)[0]) + self.z = np.linspace(self.z_loop.min(), self.z_loop.max(), np.shape(var)[-1]) + + self.dx1d = np.gradient(self.x) + self.dy1d = 1.0 + self.dz1d = np.gradient(self.z) + + return var + + def trans2commaxes(self, **kwargs): + + if self.transunits == False: + + if not hasattr(self, 'trans_dx'): + self.trans_dx = 3e7 + if not hasattr(self, 'trans_dz'): + self.trans_dz = 3e7 + + for key, value in kwargs.items(): + if key == 'dx': + self.trans_dx = value + if key == 'dz': + self.trans_dz = value + + # Semicircular loop + self.zorig = self.rdobj.__getattr__('zm')[self.snap] + s = np.copy(self.zorig) + good = s >= 0.0 + s = s[good] + # JMS -- Sometimes zll is slightly different to the max of zm which causes problems on the assumption of a 1/4 loop. + # max(zm) fix the problem + #smax = self.rdobj.cdf['zll'][self.snap] + smax = np.max(self.rdobj.__getattr__('zm')) + R = 2*smax/np.pi + x = np.cos(s/R)*R + z = np.sin(s/R)*R + + (ceil(x.max()/self.trans_dx), ceil(z.max()/self.trans_dz)) + + # In the RADYN model in the corona, successive grid points may be several pixels away from each other. + # In this case, need to refine loop. + maxdl = np.abs(z[1:]-z[0:-1]).max() + self.gridfactor = ceil(2*maxdl/np.min([self.trans_dx, self.trans_dz])) + + if self.gridfactor > 1: + ss, self.x_loop = refine(s, x, factor=self.gridfactor) + ss, self.z_loop = refine(s, z, factor=self.gridfactor) + else: + self.z_loop = z + self.x_loop = x + + self.y = np.array([0.0]) + + self.dx1d_loop = np.gradient(self.x_loop) + self.dy1d = 1.0 + self.dz1d_loop = np.gradient(self.z_loop) + + self.transunits = True + + def trans2noncommaxes(self): + + if self.transunits == True: + self.x = np.array([0.0]) + self.y = np.array([0.0]) + self.z = self.rdobj.__getattr__('zm') + + self.dx = np.array([1.0]) + self.dy = np.array([1.0]) + self.dz = np.copy(self.z) + self.nz = np.shape(self.z)[1] + for it in range(0, self.nt): + self.dz[it, :] = np.gradient(self.z[it, :]) + self.dz1d = self.dz + self.dx1d = np.array([1.0]) + self.dy1d = np.array([1.0]) + + self.nx = np.shape(self.x) + self.ny = np.shape(self.y) + self.transunits = False + + +class Radyn_units(object): + + def __init__(self, verbose=False): + ''' + Units and constants in cgs + ''' + self.uni = {} + self.verbose = verbose + self.uni['tg'] = 1.0 + self.uni['l'] = 1.0 + self.uni['n'] = 1.0 + self.uni['rho'] = 1.0 + self.uni['u'] = 1.0 + self.uni['b'] = 1.0 + self.uni['t'] = 1.0 # seconds + self.uni['j'] = 1.0 + + # Units and constants in SI + convertcsgsi(self) + + globalvars(self) diff --git a/helita/sim/rh.py b/helita/sim/rh.py index ec9fbc1d..d233a53f 100644 --- a/helita/sim/rh.py +++ b/helita/sim/rh.py @@ -1,10 +1,10 @@ """ Set of programs and tools to read the outputs from RH (Han's version) """ -import os -import sys import io +import os import xdrlib + import numpy as np @@ -46,6 +46,7 @@ class Rhout: (e.g. as in readatmos for all the elements and etc.). It also allows one to read directly into attribute of the class (with setattr(self,'aa',)) """ + def __init__(self, fdir='.', verbose=True): ''' Reads all the output data from a RH run.''' self.verbose = verbose @@ -172,7 +173,7 @@ def read_atmosphere(self, infile='atmos.out'): self.stokes = False if self.geometry_type != 'SPHERICAL_SYMMETRIC': try: - stokes = read_xdr_var(data, ('i',)) + read_xdr_var(data, ('i',)) except EOFError or IOError: if self.verbose: print('(WWW) read_atmos: no Stokes data in atmos.out,' @@ -197,7 +198,6 @@ def read_spectrum(self, infile='spectrum.out'): 'call read_atmos() first!') raise ValueError(em) data = read_xdr_file(infile) - profs = {} self.spec = {} nspect = read_xdr_var(data, ('i',)) self.spec['nspect'] = nspect @@ -290,8 +290,8 @@ def read_brs(self, infile='brs.out'): ' first!') raise ValueError(em) data = read_xdr_file(infile) - atmosID = read_xdr_var(data, ('s',)).strip() - nspace = read_xdr_var(data, ('i',)) + read_xdr_var(data, ('s',)).strip() + read_xdr_var(data, ('i',)) nspect = read_xdr_var(data, ('i',)) if nspect != self.spec['nspect']: em = ('(EEE) read_brs: nspect in file different from atmos. ' @@ -473,7 +473,7 @@ def get_contrib_ray(self, inray='ray.input', rayfile='spectrum_1.00'): if not (0 <= mu <= 1.): em = 'get_contrib_ray: invalid mu read: %f' % mu raise ValueError(em) - idx = self.ray['wave_idx'] + self.ray['wave_idx'] # Calculate optical depth self.tau = get_tau(self.geometry['height'], mu, self.ray['chi']) # Calculate contribution function @@ -495,6 +495,7 @@ class RhAtmos: verbose : str, optional If True, will print more details. """ + def __init__(self, format="2D", filename=None, verbose=True): ''' Reads RH input atmospheres. ''' self.verbose = verbose diff --git a/helita/sim/simtools.py b/helita/sim/simtools.py index 23f2049f..0b5e1d2d 100644 --- a/helita/sim/simtools.py +++ b/helita/sim/simtools.py @@ -2,8 +2,8 @@ Tools to use with the simulation's syntetic spectra """ import numpy as np -from scipy import ndimage import scipy.interpolate as interp +from scipy import ndimage def psf_diffr(ang, wave=777, D=1., pix=True): diff --git a/helita/sim/stagger.py b/helita/sim/stagger.py index 9718d8c8..8a448321 100644 --- a/helita/sim/stagger.py +++ b/helita/sim/stagger.py @@ -1,17 +1,165 @@ +""" +Stagger mesh methods using numba. + +set stagger_kind = 'fifth' or 'fifth_improved' or 'first' to use these methods. +stagger_kind = 'fifth' is the default for BifrostData and EbysusData. + +STAGGER KINDS DEFINED HERE: + fifth - original 5th interpolation and 6th derivative order scheme using numba. + functions wrapped in njit + fifth_improved - improved 5th and 6th order scheme for insterpolation and derivative using numba. + the improvement refers to improved precision for "shift" operations. + the improved scheme is also an implemented option in ebysus. + first - 1st order scheme using numpy. + "simplest" method available. + good enough, for most uses. + ~20% faster than numpy and numpy_improved methods + + +METHODS DEFINED HERE (which an end-user might want to access): + do: + perform the indicated stagger operation. + interface for the low-level _xshift, _yshift, _zshift functions. + + _xup, _xdn, _yup, _ydn, _zup, _zdn, _ddxup, _ddxdn, _ddyup, _ddydn, _ddzup, _ddzdn: + peform the corresponding stagger operation on the input array. + These behave like functions; e.g. stagger._xup(arr) does the 'xup' operation. + These are interfaces to stagger.do, for all the possible operations which 'do' can handle. + + xup, xdn, yup, ydn, zup, zdn, ddxup, ddxdn, ddyup, ddydn, ddzup, ddzdn: + Similar to their underscore counterparts, (e.g. xup is like _xup), + with the additional benefit that they can be chained togther, E.g: + stagger.xup.ddzdn.yup(arr, diffz=arr_diff) is equivalent to: + stagger.xup(stagger.ddzdn(stagger.yup(arr), diffz=arr_diff))) + + ^ those methods (xup, ..., ddzdn) in a StaggerInterface: + Additional benefit that the defaults for pad_mode and diff will be determined based on obj, + and also if arr is a string, first do arr = obj(arr) (calls 'get_var'). + Example: + dd = helita.sim.bifrost.BifrostData(...) + dd.stagger.xdn.yup.ddzdn('r') + # performs the operations xdn(yup(ddzdn(r)), + # using dd.get_param('periodic_x') (or y, z, as appropriate) to choose pad_mode, + # and dd.dzidzdn for diff during the ddzdn operation. + If desired to use non-defaults, the kwargs available are: + padx, pady, padz kwargs to force a specific pad_mode for a given axis, + diffx, diffy, diffz kwargs to force a specific diff for a given axis. + +TODO: + - fix ugly printout during verbose==1, for interfaces to 'do', e.g. stagger.xup(arr, verbose=1). +""" + +# import built-in modules +import time +import weakref +import warnings +import collections + +# import public external modules import numpy as np -from numba import jit, njit, prange + +# import internal modules +from . import tools + +try: + from numba import jit, njit, prange +except ImportError: + numba = prange = tools.ImportFailed('numba', "This module is required to use stagger_kind='numba'.") + jit = njit = tools.boring_decorator + + +""" ------------------------ defaults ------------------------ """ + +PAD_PERIODIC = 'wrap' # how to pad periodic dimensions, by default +PAD_NONPERIODIC = 'reflect' # how to pad nonperiodic dimensions, by default +PAD_DEFAULTS = {'x': PAD_PERIODIC, 'y': PAD_PERIODIC, 'z': PAD_NONPERIODIC} # default padding for each dimension. +DEFAULT_STAGGER_KIND = 'fifth' # which stagger kind to use by default. +VALID_STAGGER_KINDS = tuple(('fifth', 'fifth_improved', 'first')) # list of valid stagger kinds. +DEFAULT_MESH_LOCATION_TRACKING = False # whether mesh location tracking should be enabled, by default. + + +def STAGGER_KIND_PROPERTY(internal_name='_stagger_kind', default=DEFAULT_STAGGER_KIND): + '''creates a property which manages stagger_kind. + uses the internal name provided, and returns the default if property value has not been set. + + only allows setting of stagger_kind to valid names (as determined by VALID_STAGGER_KINDS). + ''' + + def get_stagger_kind(self): + return getattr(self, internal_name, default) + + def set_stagger_kind(self, value): + '''sets stagger_kind to VALID_STAGGER_KINDS[value]''' + if not (value in VALID_STAGGER_KINDS): + class KeyErrorMessage(str): # KeyError(msg) uses repr(msg), so newlines don't show up. + def __repr__(self): return str(self) # this is a workaround. Makes the message prettier. + errmsg = (f"stagger_kind = {repr(value)} was invalid!" + "\n" + + f"Expected value from: {VALID_STAGGER_KINDS}." + "\n" + + f"Advanced: to add a valid value, edit helita.sim.stagger.VALID_STAGGER_KINDS") + raise KeyError(KeyErrorMessage(errmsg)) from None + setattr(self, internal_name, value) + + doc = f"Tells which method to use for stagger operations. Options are: {VALID_STAGGER_KINDS}" + + return property(fset=set_stagger_kind, fget=get_stagger_kind, doc=doc) + + +""" ------------------------ stagger constants ------------------------ """ + +StaggerConstants = collections.namedtuple('StaggerConstants', ('a', 'b', 'c')) + +## FIFTH ORDER SCHEME ## +# derivatives +c = (-1 + (3**5 - 3) / (3**3 - 3)) / (5**5 - 5 - 5 * (3**5 - 3)) +b = (-1 - 120*c) / 24 +a = (1 - 3*b - 5*c) +CONSTANTS_DERIV = StaggerConstants(a, b, c) + +# shifts (i.e. not a derivative) +c = 3.0 / 256.0 +b = -25.0 / 256.0 +a = 0.5 - b - c +CONSTANTS_SHIFT = StaggerConstants(a, b, c) + + +## FIRST ORDER SCHEME ## +CONSTANTS_DERIV_o1 = StaggerConstants(1.0, 0, 0) +CONSTANTS_SHIFT_o1 = StaggerConstants(0.5, 0, 0) + + +## GENERIC ## +CONSTANTS_DERIV_ODICT = {5: CONSTANTS_DERIV, 1: CONSTANTS_DERIV_o1} +CONSTANTS_SHIFT_ODICT = {5: CONSTANTS_SHIFT, 1: CONSTANTS_SHIFT_o1} + + +def GET_CONSTANTS_DERIV(order): + return CONSTANTS_DERIV_ODICT[order] -def do(var, operation='xup', diff=None, pad_mode=None): +def GET_CONSTANTS_SHIFT(order): + return CONSTANTS_SHIFT_ODICT[order] + + +# remove temporary variables from the module namespace +del c, b, a + + +""" ------------------------ 'do' - stagger interface ------------------------ """ + + +def do(var, operation='xup', diff=None, pad_mode=None, stagger_kind=DEFAULT_STAGGER_KIND): """ - Do a stagger operation on `var` by doing a 6th order polynomial interpolation of + Do a stagger operation on `var` by doing a 6th order polynomial interpolation of the variable from cell centres to cell faces (down operations), or cell faces to cell centres (up operations). - + Parameters ---------- var : 3D array Variable to work on. + if not 3D, makes a warning but still tries to return a reasonable result: + for non-derivatives, return var (unchanged). + for derivatives, return 0 (as an array with same shape as var). operation: str Type of operation. Currently supported values are * 'xup', 'xdn', 'yup', 'ydn', 'zup', 'zdn' @@ -19,134 +167,1148 @@ def do(var, operation='xup', diff=None, pad_mode=None): diff: None or 1D array If operation is one of the derivatives, you must supply `diff`, an array with the distances between each cell in the direction of the - operation must be same length as array along that direction. + operation must be same length as array along that direction. For non-derivative operations, `diff` must be None. - pad_mode : str + pad_mode : None or str Mode for padding array `var` to have enough points for a 6th order - polynomial interpolation. Same as supported by np.pad. Default is - `wrap` (periodic horizontal) for x and y, and `reflect` for z operations. + polynomial interpolation. Same as supported by np.pad. + if None, use default: `wrap` (periodic) for x and y; `reflect` for z. + stagger_kind: 'fifth', 'fifth_improved', or 'first' + Mode for stagger operations. + fifth --> numba methods ('_xshift', '_yshift', '_zshift') + fifth_improved --> numba methods ('_xshift_improved', '_yshift_improved', '_zshift_improved') + first --> numba methods ('_xshift_o1', '_yshift_o1', '_zshift_o1') Returns ------- 3D array - Array of same type and dimensions to var, after performing the + Array of same type and dimensions to var, after performing the stagger operation. """ - AXES = { - 'x': _xshift, - 'y': _yshift, - 'z': _zshift, - } - DEFAULT_PAD = {'x': 'wrap', 'y': 'wrap', 'z': 'edge'} - if operation[-2:].lower() == 'up': - up = True - elif operation[-2:].lower() == 'dn': - up = False - else: - raise ValueError(f"Invalid operation {operation}") - if operation[:2].lower() == 'dd': # For derivative operations + # initial bookkeeping + AXES = ('x', 'y', 'z') + operation = operation_orig = operation.lower() + # order + if stagger_kind == 'first': + order = 1 + else: + order = 5 + # derivative, diff + if operation[:2] == 'dd': # For derivative operations derivative = True operation = operation[2:] if diff is None: - raise ValueError("diff not provided for derivative operation") + raise ValueError(f"diff not provided for derivative operation: {operation_orig}") else: derivative = False if diff is not None: - raise ValueError("diff must not be provided for non-derivative operation") - op = operation[:-2] - if op not in AXES: - raise ValueError(f"Invalid operation {operation}") - func = AXES[op] + raise ValueError(f"diff must not be provided for non-derivative operation: {operation}") + # make sure var is 3D. make warning then handle appropriately if not. + if np.ndim(var) != 3: + warnmsg = f'can only stagger 3D array but got {np.ndim(var)}D.' + if derivative: + warnings.warn(warnmsg + f' returning 0 for operation {operation_orig}') + return np.zeros_like(var) + else: + warnings.warn(warnmsg + f' returning original array for operation {operation_orig}') + return var + # up/dn + up_str = operation[-2:] # 'up' or 'dn' + if up_str == 'up': + up = True + elif up_str == 'dn': + up = False + else: + raise ValueError(f"Invalid operation; must end in 'up' or 'dn': {operation}") + # x, dim_index (0 for x, 1 for y, 2 for z) + x = operation[:-2] + if x not in AXES: + raise ValueError(f"Invalid operation; axis must be 'x', 'y', or 'z': {operation}") if pad_mode is None: - pad_mode = DEFAULT_PAD[op] - dim_index = 'xyz'.find(op[-1]) - extra_dims = [(3, 2), (2, 3)][up] + pad_mode = PAD_DEFAULTS[x] + dim_index = AXES.index(x) + # padding + extra_dims = (2, 3) if up else (3, 2) if not derivative: diff = np.ones(var.shape[dim_index], dtype=var.dtype) padding = [(0, 0)] * 3 padding[dim_index] = extra_dims - s = var.shape - if s[dim_index] == 1: - return var + # interpolating + if var.shape[dim_index] <= 5: # don't interpolate along axis with size 5 or less... + if derivative: + result = np.zeros_like(var) # E.g. ( dvardzup, where var has shape (Nx, Ny, 1) ) --> 0 + else: + result = var else: out = np.pad(var, padding, mode=pad_mode) out_diff = np.pad(diff, extra_dims, mode=pad_mode) - return func(out, out_diff, up=up, derivative=derivative) + if stagger_kind in ['fifth', 'first']: + func = {'x': _xshift, 'y': _yshift, 'z': _zshift}[x] + result = func(out, out_diff, up=up, order=order, derivative=derivative) + elif stagger_kind == 'fifth_improved': + func = {'x': _xshift_improved, 'y': _yshift_improved, 'z': _zshift_improved}[x] + result = func(out, out_diff, up=up, derivative=derivative) + else: + raise ValueError(f"invalid stagger_kind: '{stagger_kind}'. Options are: {VALID_STAGGER_KINDS}") + # tracking mesh location. + meshloc = getattr(var, 'meshloc', None) + if meshloc is not None: # (input array had a meshloc attribute) + result = ArrayOnMesh(result, meshloc=meshloc) + result._shift_location(f'{x}{up_str}') + # output. + return result + + +""" ------------------------ numba stagger ------------------------ """ + +## STAGGER_KIND = NUMBA ## -@jit(parallel=True,nopython=True) -def _xshift(var, diff, up=True, derivative=False): +@njit(parallel=True) +def _xshift(var, diff, up=True, order=5, derivative=False): + grdshf = 1 if up else 0 + start = int(3. - grdshf) + end = - int(2. + grdshf) + nx, ny, nz = var.shape + out = np.zeros((nx, ny, nz)) + if order == 5: + if derivative: + pm, (a, b, c) = -1, CONSTANTS_DERIV + else: + pm, (a, b, c) = 1, CONSTANTS_SHIFT + for k in prange(nz): + for j in prange(ny): + for i in prange(start, nx + end): + out[i, j, k] = diff[i] * (a * (var[i + grdshf, j, k] + pm * var[i - 1 + grdshf, j, k]) + + b * (var[i + 1 + grdshf, j, k] + pm * var[i - 2 + grdshf, j, k]) + + c * (var[i + 2 + grdshf, j, k] + pm * var[i - 3 + grdshf, j, k])) + elif order == 1: + if derivative: + pm, (a, b, c) = -1, CONSTANTS_DERIV_o1 + else: + pm, (a, b, c) = 1, CONSTANTS_SHIFT_o1 + for k in prange(nz): + for j in prange(ny): + for i in prange(start, nx + end): + out[i, j, k] = diff[i] * (a * (var[i + grdshf, j, k] + pm * var[i - 1 + grdshf, j, k])) + return out[start:end, :, :] + + +@njit(parallel=True) +def _yshift(var, diff, up=True, order=5, derivative=False): + grdshf = 1 if up else 0 + start = int(3. - grdshf) + end = - int(2. + grdshf) + nx, ny, nz = var.shape + out = np.zeros((nx, ny, nz)) + if order == 5: + if derivative: + pm, (a, b, c) = -1, CONSTANTS_DERIV + else: + pm, (a, b, c) = 1, CONSTANTS_SHIFT + for k in prange(nz): + for j in prange(start, ny + end): + for i in prange(nx): + out[i, j, k] = diff[j] * (a * (var[i, j + grdshf, k] + pm * var[i, j - 1 + grdshf, k]) + + b * (var[i, j + 1 + grdshf, k] + pm * var[i, j - 2 + grdshf, k]) + + c * (var[i, j + 2 + grdshf, k] + pm * var[i, j - 3 + grdshf, k])) + elif order == 1: + if derivative: + pm, (a, b, c) = -1, CONSTANTS_DERIV_o1 + else: + pm, (a, b, c) = 1, CONSTANTS_SHIFT_o1 + for k in prange(nz): + for j in prange(start, ny + end): + for i in prange(nx): + out[i, j, k] = diff[j] * (a * (var[i, j + grdshf, k] + pm * var[i, j - 1 + grdshf, k])) + return out[:, start:end, :] + + +@njit(parallel=True) +def _zshift(var, diff, up=True, order=5, derivative=False): grdshf = 1 if up else 0 - if derivative: - pm = -1 - c = (-1. + (3.**5 - 3.) / (3.**3 - 3.)) / (5.**5 - 5. - 5. * (3.**5 - 3)) - b = (-1. - 120.*c) / 24. - a = (1. - 3.*b - 5.*c) - else: - pm = 1 - c = 3.0 / 256.0 - b = -25.0 / 256.0 - a = 0.5 - b - c start = int(3. - grdshf) end = - int(2. + grdshf) nx, ny, nz = var.shape out = np.zeros((nx, ny, nz)) - for k in prange(nz): - for j in prange(ny): - for i in prange(start, nx + end): - out[i, j, k] = diff[i] * (a * (var[i + grdshf, j, k] + pm * var[i - 1 + grdshf, j, k]) + - b * (var[i + 1 + grdshf, j, k] + pm * var[i - 2 + grdshf, j, k]) + - c * (var[i + 2 + grdshf, j, k] + pm * var[i - 3 + grdshf, j, k])) - return out[start:end,:,:] + if order == 5: + if derivative: + pm, (a, b, c) = -1, CONSTANTS_DERIV + else: + pm, (a, b, c) = 1, CONSTANTS_SHIFT + for k in prange(start, nz + end): + for j in prange(ny): + for i in prange(nx): + out[i, j, k] = diff[k] * (a * (var[i, j, k + grdshf] + pm * var[i, j, k - 1 + grdshf]) + + b * (var[i, j, k + 1 + grdshf] + pm * var[i, j, k - 2 + grdshf]) + + c * (var[i, j, k + 2 + grdshf] + pm * var[i, j, k - 3 + grdshf])) + elif order == 1: + if derivative: + pm, (a, b, c) = -1, CONSTANTS_DERIV_o1 + else: + pm, (a, b, c) = 1, CONSTANTS_SHIFT_o1 + for k in prange(start, nz + end): + for j in prange(ny): + for i in prange(nx): + out[i, j, k] = diff[k] * (a * (var[i, j, k + grdshf] + pm * var[i, j, k - 1 + grdshf])) + return out[:, :, start:end] -@jit(parallel=True,nopython=True) -def _yshift(var, diff, up=True, derivative=False): +@njit(parallel=True) +def _xshift_improved(var, diff, up=True, derivative=False): grdshf = 1 if up else 0 + start = int(3. - grdshf) + end = - int(2. + grdshf) + if derivative: + pm, (a, b, c) = -1, CONSTANTS_DERIV + else: + pm, (a, b, c) = 1, CONSTANTS_SHIFT + nx, ny, nz = var.shape + out = np.zeros((nx, ny, nz)) if derivative: - pm = -1 - c = (-1. + (3.**5 - 3.) / (3.**3 - 3.)) / (5.**5 - 5. - 5. * (3.**5 - 3)) - b = (-1. - 120.*c) / 24. - a = (1. - 3.*b - 5.*c) + for k in prange(nz): + for j in prange(ny): + for i in prange(start, nx + end): + out[i, j, k] = diff[i] * (a * (var[i + grdshf, j, k] + pm * var[i - 1 + grdshf, j, k]) + + b * (var[i + 1 + grdshf, j, k] + pm * var[i - 2 + grdshf, j, k]) + + c * (var[i + 2 + grdshf, j, k] + pm * var[i - 3 + grdshf, j, k])) else: - pm = 1 - c = 3.0 / 256.0 - b = -25.0 / 256.0 - a = 0.5 - b - c + for k in prange(nz): + for j in prange(ny): + for i in prange(start, nx + end): + out[i, j, k] = diff[i] * (a * (var[i + grdshf, j, k] - var[i - 1 + grdshf, j, k]) + + b * (var[i + 1 + grdshf, j, k] - var[i - 1 + grdshf, j, k] + + var[i - 2 + grdshf, j, k] - var[i - 1 + grdshf, j, k]) + + c * (var[i + 2 + grdshf, j, k] - var[i - 1 + grdshf, j, k] + + var[i - 3 + grdshf, j, k] - var[i - 1 + grdshf, j, k]) + + var[i - 1 + grdshf, j, k]) + + return out[start:end, :, :] + + +@njit(parallel=True) +def _yshift_improved(var, diff, up=True, derivative=False): + grdshf = 1 if up else 0 start = int(3. - grdshf) end = - int(2. + grdshf) + if derivative: + pm, (a, b, c) = -1, CONSTANTS_DERIV + else: + pm, (a, b, c) = 1, CONSTANTS_SHIFT nx, ny, nz = var.shape out = np.zeros((nx, ny, nz)) - for k in prange(nz): - for j in prange(start, ny + end): - for i in prange(nx): - out[i, j, k] = diff[j] * (a * (var[i, j + grdshf, k] + pm * var[i, j - 1 + grdshf, k]) + - b * (var[i, j + 1 + grdshf, k] + pm * var[i, j - 2 + grdshf, k]) + - c * (var[i, j + 2 + grdshf, k] + pm * var[i, j - 3 + grdshf, k])) - return out[:,start:end,:] + if derivative: + for k in prange(nz): + for j in prange(start, ny + end): + for i in prange(nx): + out[i, j, k] = diff[j] * (a * (var[i, j + grdshf, k] + pm * var[i, j - 1 + grdshf, k]) + + b * (var[i, j + 1 + grdshf, k] + pm * var[i, j - 2 + grdshf, k]) + + c * (var[i, j + 2 + grdshf, k] + pm * var[i, j - 3 + grdshf, k])) + else: + for k in prange(nz): + for j in prange(start, ny + end): + for i in prange(nx): + out[i, j, k] = diff[j] * (a * (var[i, j + grdshf, k] - var[i, j - 1 + grdshf, k]) + + b * (var[i, j + 1 + grdshf, k] - var[i, j - 1 + grdshf, k] + + var[i, j - 2 + grdshf, k] - var[i, j - 1 + grdshf, k]) + + c * (var[i, j + 2 + grdshf, k] - var[i, j - 1 + grdshf, k] + + var[i, j - 3 + grdshf, k] - var[i, j - 1 + grdshf, k]) + + var[i, j - 1 + grdshf, k]) + return out[:, start:end, :] -@jit(parallel=True,nopython=True) -def _zshift(var, diff, up=True, derivative=False): +@njit(parallel=True) +def _zshift_improved(var, diff, up=True, derivative=False): grdshf = 1 if up else 0 start = int(3. - grdshf) end = - int(2. + grdshf) if derivative: - pm = -1 - c = (-1. + (3.**5 - 3.) / (3.**3 - 3.)) / (5.**5 - 5. - 5. * (3.**5 - 3)) - b = (-1. - 120.*c) / 24. - a = (1. - 3.*b - 5.*c) + pm, (a, b, c) = -1, CONSTANTS_DERIV else: - pm = 1 - c = 3.0 / 256.0 - b = -25.0 / 256.0 - a = 0.5 - b - c + pm, (a, b, c) = 1, CONSTANTS_SHIFT nx, ny, nz = var.shape out = np.zeros((nx, ny, nz)) - for k in prange(start, nz + end): - for j in prange(ny): - for i in prange(nx): - out[i, j, k] = diff[k] * (a * (var[i, j, k + grdshf] + pm * var[i, j, k - 1 + grdshf]) + - b * (var[i, j, k + 1 + grdshf] + pm * var[i, j, k - 2 + grdshf]) + - c * (var[i, j, k + 2 + grdshf] + pm * var[i, j, k - 3 + grdshf])) - return out[:,:,start:end] \ No newline at end of file + if derivative: + for k in prange(start, nz + end): + for j in prange(ny): + for i in prange(nx): + out[i, j, k] = diff[k] * (a * (var[i, j, k + grdshf] + pm * var[i, j, k - 1 + grdshf]) + + b * (var[i, j, k + 1 + grdshf] + pm * var[i, j, k - 2 + grdshf]) + + c * (var[i, j, k + 2 + grdshf] + pm * var[i, j, k - 3 + grdshf])) + else: + for k in prange(start, nz + end): + for j in prange(ny): + for i in prange(nx): + out[i, j, k] = diff[k] * (a * (var[i, j, k + grdshf] - var[i, j, k - 1 + grdshf]) + + b * (var[i, j, k + 1 + grdshf] - var[i, j, k - 1 + grdshf] + + var[i, j, k - 2 + grdshf] - var[i, j, k - 1 + grdshf]) + + c * (var[i, j, k + 2 + grdshf] - var[i, j, k - 1 + grdshf] + + var[i, j, k - 3 + grdshf] - var[i, j, k - 1 + grdshf]) + + var[i, j, k - 1 + grdshf]) + return out[:, :, start:end] + + +""" ------------------------ MeshLocation, ArrayOnMesh ------------------------ """ +# The idea is to associate arrays with a location on the mesh, +# update that mesh location info whenever a stagger operation is performed, +# and enforce arrays have the same location when doing arithmetic. + + +class MeshLocation(): + '''class defining a location on a mesh. + Also provides shifting operations. + + Examples: + m = MeshLocation([0, 0.5, 0]) + m.xup + MeshLocation([0.5, 0.5, 0]) + m.xup.ydn.zdn + MeshLocation([0.5, 0, -0.5]) + ''' + + def __init__(self, loc=[0, 0, 0]): + self.loc = list(loc) + + def __repr__(self): + return f'{type(self).__name__}({self.loc})' # TODO even spacing (length == len('-0.5')) + + def _new(self, *args, **kw): + return type(self)(*args, **kw) + + ## LIST-LIKE BEHAVIOR ## + + def __iter__(self): + return iter(self.loc) + + def __len__(self): + return len(self.loc) + + def __getitem__(self, i): + return self.loc[i] + + def __setitem__(self, i, value): + self.loc[i] = value + + def __eq__(self, other): + if len(other) != len(self): + return False + return all(s == o for s, o in zip(self, other)) + + def copy(self): + return MeshLocation(self) + + ## MESH LOCATION ARITHMETIC ## + def __add__(self, other): + '''element-wise addition of self + other, returned as a MeshLocation.''' + return self._new([s + o for s, o in zip(self, other)]) + + def __sub__(self, other): + '''element-wise subtraction of self - other, returned as a MeshLocation.''' + return self._new([s - o for s, o in zip(self, other)]) + + def __radd__(self, other): + '''element-wise addition of other + self, returned as a MeshLocation.''' + return self._new([o + s for s, o in zip(self, other)]) + + def __rsub__(self, other): + '''element-wise subtraction of other - self, returned as a MeshLocation.''' + return self._new([o - s for s, o in zip(self, other)]) + + ## MESH LOCATION AS OPERATION LIST ## + def as_operations(self): + '''returns self, viewed as a list of operations. (returns a list of strings.) + equivalently, returns "steps needed to get from (0,0,0) to self". + + Examples: + MeshLocation([0.5, 0, 0]).as_operations() + ['xup'] + MeshLocation([0, -0.5, -0.5]).as_operations() + ['ydn', 'zdn'] + MeshLocation([1.0, -0.5, -1.5]).as_operations() + ['xup', 'xup', 'ydn', 'zdn', 'zdn', 'zdn'] + ''' + AXES = ('x', 'y', 'z') + result = [] + for x, val in zip(AXES, self): + if val == 0: + continue + n = val / 0.5 # here we expect n to be an integer-valued float. (e.g. 1.0) + assert getattr(n, 'is_integer', lambda: True)(), f"Expected n/0.5 to be an integer. n={n}, self={self}" + up = 'up' if val > 0 else 'dn' + n = abs(int(n)) # convert n to positive integer (required for list multiplication) + result += ([f'{x}{up}'] * n) # list addition; list multiplication. + return result + + as_ops = property(lambda self: self.as_operations, doc='alias for as_operations') + + def steps_from(self, other): + '''return the steps needed to get FROM other TO self. (returns a list of strings.) + + Examples: + MeshLocation([0.5, 0, 0]).steps_from([0,0,0]) + ['xup'] + MeshLocation([-0.5, 0, 0]).steps_from(MeshLocation([0.5, -0.5, -0.5])) + ['xdn', 'xdn', 'yup', 'zup'] + ''' + return (self - other).as_operations() + + def steps_to(self, other): + '''return the steps needed to get TO other FROM self. (returns a list of strings.) + + Examples: + MeshLocation([0.5, 0, 0]).steps_to([0,0,0]) + ['xdn'] + MeshLocation([-0.5, 0, 0]).steps_to(MeshLocation([0.5, -0.5, -0.5])) + ['xup', 'xup', 'ydn', 'zdn'] + ''' + return (other - self).as_operations() + + ## MESH LOCATION DESCRIPTION ## + def describe(self): + '''returns a description of self. + The possible descriptions are: + ('center', None), + ('face', 'x'), ('face', 'y'), ('face', 'z'), + ('edge', 'x'), ('edge', 'y'), ('edge', 'z'), + ('unknown', None) + They mean: + 'center' --> location == (0,0,0) + 'face_x' --> location == (-0.5, 0, 0) # corresponds to x-component of a face-centered vector like magnetic field. + 'edge_x' --> location == (0, -0.5, -0.5) # correspodns to x-component of an edge-centered vector like electric field. + 'unknown' --> location is not center, face, or edge. + face_y, face_z, edge_y, edge_z take similar meanings as face_x, edge_x, but for the y, z directions instead. + + returns one of the tuples above. + ''' + lookup = {-0.5: True, 0: False} + xdn = lookup.get(self[0], None) + ydn = lookup.get(self[1], None) + zdn = lookup.get(self[2], None) + pos = (xdn, ydn, zdn) + if all(p is True for p in pos): + return ('center', None) + if any(p is None for p in pos) or all(p is True for p in pos): + return ('unknown', None) + if xdn: + if ydn: + return ('edge', 'z') + elif zdn: + return ('edge', 'y') + else: + return ('face', 'x') + elif ydn: + if zdn: + return ('edge', 'x') + else: + return ('face', 'y') + elif zdn: + return ('face', 'z') + # could just return ('unknown', None) if we reach this line. + # But we expect the code to have handled all cases by this line. + # So if this error is ever raised, we made a mistake in the code of this function. + assert False, f"Expected all meshlocs should have been accounted for, but this one was not: {self}" + + ## MESH LOCATION SHIFTING ## + def shifted(self, xup): + '''return a copy of self shifted by xup. + xup: 'xup', 'xdn', 'yup', 'ydn', 'zup', or 'zdn'. + ''' + return getattr(self, xup) + + # the properties: xup, xdn, yup, ydn, zup, zdn + # are added to the class after its initial definition. + +## MESH LOCATION SHIFTING ## + + +def _mesh_shifter(x, up): + '''returns a function which returns a copy of MeshLocation but shifted by x and up. + x should be 'x', 'y', or 'z'. + up should be 'up' or 'dn'. + ''' + ix = {'x': 0, 'y': 1, 'z': 2}[x] + up_value = {'up': 0.5, 'dn': -0.5}[up] + + def mesh_shifted(self): + '''returns a copy of self shifted by {x}{up}''' + copy = self.copy() + copy[ix] += up_value + return copy + mesh_shifted.__doc__ = mesh_shifted.__doc__.format(x=x, up=up) + mesh_shifted.__name__ = f'{x}{up}' + return mesh_shifted + + +def _mesh_shifter_property(x, up): + '''returns a property which calls a function that returns a copy of MeshLocation shifted by x and up.''' + shifter = _mesh_shifter(x, up) + return property(fget=shifter, doc=shifter.__doc__) + + +# actually set the functions xup, ..., zdn, as methods of MeshLocation. +for x in ('x', 'y', 'z'): + for up in ('up', 'dn'): + setattr(MeshLocation, f'{x}{up}', _mesh_shifter_property(x, up)) + + +class ArrayOnMesh(np.ndarray): + '''numpy array associated with a location on a mesh grid. + + Examples: + ArrayOnMesh(x, meshloc=[0,0,0]) + ArrayOnMesh(y, meshloc=[0,0,0.5]) + with x, y numpy arrays (or subclasses). + + The idea is to enforce that arrays are at the same mesh location before doing any math. + When arrays are at different locations, raise an AssertionError instead. + + The operations xup, ..., zdn are intentionally not provided here. + This is to avoid potential confusion of thinking stagger is being performed when it is not. + ArrayOnMesh does not know how to actually do any of the stagger operations. + Rather, the stagger operations are responsible for properly tracking mesh location; + they can use the provided _relocate or _shift_location methods to do so. + + meshloc: list, tuple, MeshLocation object, or None + None --> default. If input has meshloc, use meshloc of input; else use [0,0,0] + else --> use this value as the mesh location. + ''' + def __new__(cls, input_array, meshloc=None): + obj = np.asanyarray(input_array).view(cls) # view input_array as an ArrayOnMesh. + if meshloc is None: + obj.meshloc = getattr(obj, 'meshloc', [0, 0, 0]) + else: + obj.meshloc = meshloc + return obj + + def __array_finalize__(self, obj): + '''handle other ways of creating this array, e.g. copying an existing ArrayOnMesh.''' + if obj is None: + return + self.meshloc = getattr(obj, 'meshloc', [0, 0, 0]) + + def describe_mesh_location(self): + '''returns a description of the mesh location of self. + The possible descriptions are: + ('center', None), + ('face', 'x'), ('face', 'y'), ('face', 'z'), + ('edge', 'x'), ('edge', 'y'), ('edge', 'z'), + ('unknown', None) + They mean: + 'center' --> location == (0,0,0) + 'face_x' --> location == (-0.5, 0, 0) # corresponds to x-component of a face-centered vector like magnetic field. + 'edge_x' --> location == (0, -0.5, -0.5) # correspodns to x-component of an edge-centered vector like electric field. + 'unknown' --> location is not center, face, or edge. + face_y, face_z, edge_y, edge_z take similar meanings as face_x, edge_x, but for the y, z directions instead. + + returns one of the tuples above. + ''' + return self.meshloc.describe() + + @property + def meshloc(self): + return self._meshloc + + @meshloc.setter + def meshloc(self, newloc): + if not isinstance(newloc, MeshLocation): + newloc = MeshLocation(newloc) + self._meshloc = newloc + + def _relocate(self, new_meshloc): + '''changes the location associated with self to new_meshloc. + DOES NOT PERFORM ANY STAGGER OPERATIONS - + the array contents will be unchanged; only the mesh location label will be affected. + ''' + self.meshloc = new_meshloc + + def _shift_location(self, xup): + '''shifts the location associated with self by xup. + DOES NOT PERFORM ANY STAGGER OPERATIONS - + the array contents will be unchanged; only the mesh location label will be affected. + ''' + self._relocate(self.meshloc.shifted(xup)) + + def __array_ufunc__(self, ufunc, method, *inputs, out=None, **kwargs): + '''does the ufunc but first ensures all arrays are at the same meshloc. + + The code here follows the format of the example from the numpy subclassing docs. + ''' + args = [] + meshloc = None + for i, input_ in enumerate(inputs): + if isinstance(input_, type(self)): + if meshloc is None: + meshloc = input_.meshloc + else: + assert meshloc == input_.meshloc, f"Inputs' mesh locations differ: {meshloc}, {input_.meshloc}" + args.append(input_.view(np.ndarray)) + else: + args.append(input_) + + assert meshloc is not None # meshloc should have been set to some value by this point. + + outputs = out + if outputs: + out_args = [] + for j, output in enumerate(outputs): + if isinstance(output, type(self)): + out_args.append(output.view(np.ndarray)) + else: + out_args.append(output) + kwargs['out'] = tuple(out_args) + else: + outputs = (None,) * ufunc.nout + + results = super().__array_ufunc__(ufunc, method, *args, **kwargs) + if results is NotImplemented: + return NotImplemented + + if method == 'at': + if isinstance(inputs[0], type(self)): + inputs[0].meshloc = meshloc + return + + if ufunc.nout == 1: + results = (results,) + + results = tuple((np.asarray(result).view(type(self)) + if output is None else output) + for result, output in zip(results, outputs)) + if results and isinstance(results[0], type(self)): + results[0].meshloc = meshloc + + return results[0] if len(results) == 1 else results + + def __repr__(self): + result = super().__repr__() + return f'{result} at {self.meshloc}' + + +# predefined mesh locations +def mesh_location_center(): + '''returns MeshLocation at center of box. (0,0,0)''' + return MeshLocation([0, 0, 0]) + + +def mesh_location_face(x): + '''returns MeshLocation centered at face x. + x: 'x', 'y', or 'z'. + 'x' --> [-0.5, 0 , 0 ] + 'y' --> [ 0 , -0.5, 0 ] + 'z' --> [ 0 , 0 , -0.5] + ''' + loc = {'x': [-0.5, 0, 0], + 'y': [0, -0.5, 0], + 'z': [0, 0, -0.5]} + return MeshLocation(loc[x]) + + +def mesh_location_edge(x): + '''returns MeshLocation centered at edge x. + x: 'x', 'y', or 'z'. + 'x' --> [ 0 , -0.5, -0.5] + 'y' --> [-0.5, 0 , -0.5] + 'z' --> [-0.5, -0.5, 0 ] + ''' + loc = {'x': [0, -0.5, -0.5], + 'y': [-0.5, 0, -0.5], + 'z': [-0.5, -0.5, 0]} + return MeshLocation(loc[x]) + +# describing mesh locations (for a "generic object") + + +def get_mesh_location(obj, *default): + if len(default) > 0: + return getattr(obj, 'meshloc', default[0]) + else: + return getattr(obj, 'meshloc') + + +def has_mesh_location(obj): + return hasattr(obj, 'meshloc') + + +def describe_mesh_location(obj): + '''returns a description of the mesh location of obj + The possible descriptions are: + ('center', None), + ('face', 'x'), ('face', 'y'), ('face', 'z'), + ('edge', 'x'), ('edge', 'y'), ('edge', 'z'), + ('unknown', None) + ('none', None) + They mean: + 'center' --> location == (0,0,0) + 'face_x' --> location == (-0.5, 0, 0) # corresponds to x-component of a face-centered vector like magnetic field. + 'edge_x' --> location == (0, -0.5, -0.5) # correspodns to x-component of an edge-centered vector like electric field. + 'unknown' --> location is not center, face, or edge. + 'none' --> obj is not a MeshLocation and does not have attribute meshloc. + face_y, face_z, edge_y, edge_z take similar meanings as face_x, edge_x, but for the y, z directions instead. + + returns one of the tuples above. + ''' + if isinstance(obj, MeshLocation): + return obj.describe() + elif hasattr(obj, 'meshloc'): + return obj.meshloc.describe() + else: + return ('unknown', None) + +# mesh location tracking property + + +def MESH_LOCATION_TRACKING_PROPERTY(internal_name='_mesh_location_tracking', default=DEFAULT_MESH_LOCATION_TRACKING): + '''creates a property which manages mesh_location_tracking. + uses the internal name provided, and returns the default if property value has not been set. + + checks self.do_stagger and self.stagger_kind for compatibility (see doc of the produced property for details). + ''' + doc = f'''whether mesh location tracking is enabled. (default is {default}) + True --> arrays from get_var will be returned as stagger.ArrayOnMesh objects, + which track the location on mesh but also require locations of + arrays (if they are ArrayOnMesh) to match before doing arithmetic. + False --> stagger.ArrayOnMesh conversion will be disabled. + + Tied directly to self.do_stagger and self.stagger_kind. + when self.do_stagger or self.stagger_kind are INCOMPATIBLE with mesh_location_tracking, + mesh_location_tracking will be disabled, until compatibility requirements are met. + trying to set mesh_location_tracking = True will make a ValueError. + INCOMPATIBLE when one or more of the following are True: + 1) bool(self.do_stagger) != True + 2) self.stagger_kind not in stagger.VALID_STAGGER_KINDS + (compatible stagger_kinds are {VALID_STAGGER_KINDS}) + ''' + + def _mesh_location_tracking_incompatible(obj): + '''returns attributes of obj with present values incompatible with mesh_location_tracking. + e.g. ['do_stagger', 'stagger_kind'], or ['stagger_kind'], or ['do_stagger'] or []. + + non-existing attributes do not break compatibility. + E.g. if do_stagger and stagger_kind are unset, result will be [] (i.e. "fully compatible"). + ''' + result = [] + if not getattr(obj, 'do_stagger', True): + result.append('do_stagger') + if not getattr(obj, 'stagger_kind', VALID_STAGGER_KINDS[0]) in VALID_STAGGER_KINDS: + result.append('stagger_kind') + return result + + def get_mesh_location_tracking(self): + result = getattr(self, '_mesh_location_tracking', default) + if result: + # before returning True, check compatibility. + incompatible = _mesh_location_tracking_incompatible(self) + if incompatible: + return False + return result + + def set_mesh_location_tracking(self, value): + if value: + # before setting to True, check compatibility. + incompatible = _mesh_location_tracking_incompatible(self) + if incompatible: + # make a ValueError with helpful instructions. + errmsg = f"present values of attributes {incompatible} are incompatible" +\ + "with mesh_location_tracking. To enable mesh_location_tracking, first you must" + if 'do_stagger' in incompatible: + errmsg += " enable do_stagger" + if len(incompatible) > 1: + errmsg += " and" + if 'stagger_kind' in incompatible: + errmsg += f" set stagger_kind to one of the python stagger kinds: {VALID_STAGGER_KINDS}" + errmsg += "." + raise ValueError(errmsg) + self._mesh_location_tracking = value + + return property(fget=get_mesh_location_tracking, fset=set_mesh_location_tracking, doc=doc) + + +""" ------------------------ Aliases ------------------------ """ +# Here we define the 12 supported operations, using the 'do' function defined above. The ops are: +# xdn, xup, ydn, yup, zdn, zup, ddxdn, ddxup, ddydn, ddyup, ddzdn, ddzup +# This is for convenience; the 'work' of the stagger methods is handled by the functions above. + +# The definitions here all put a leading underscore '_'. E.g.: _xup, _ddydn. +# This is because users should prefer the non-underscored versions defined in the next section, +# since those can be chained together, e.g. ddxdn.xup.ydn(arr) equals to _ddxdn(_xup(_ydn(arr))). + + +class _stagger_factory(): + def __init__(self, x, up, opstr_fmt): + self.x = x + self.up = up + self.opstr = opstr_fmt.format(x=x, up=up) + self.__doc__ = self.__doc__.format(up=up, x=x, pad_default=PAD_DEFAULTS[x]) + self.__name__ = f'_{self.opstr}' + + def __call__(self, arr, pad_mode=None, verbose=False, + padx=None, pady=None, padz=None, + **kw__do): + if pad_mode is None: + pad_mode = {'x': padx, 'y': pady, 'z': padz}[self.x] + if verbose: + end = '\n' if verbose > 1 else '\r\r' + msg = f'interpolating: {self.opstr:>5s}.' + print(msg, end=' ', flush=True) + now = time.time() + result = do(arr, self.opstr, pad_mode=pad_mode, **kw__do) + if verbose: + print(f'Completed in {time.time()-now:.4f} seconds.', end=end, flush=True) + return result + + +class _stagger_spatial(_stagger_factory): + '''interpolate data one half cell {up} in {x}. + arr: 3D array + the data to be interpolated. + pad_mode: None or str. + pad mode for the interpolation; same options as those supported by np.pad. + if None, the default for this operation will be used: '{pad_default}' + verbose: 0, 1, 2 + 0 --> no verbosity + 1 --> print, end with '\r'. + 2 --> print, end with '\n'. + padx, pady, padz: None or string + pad_mode, but only applies for operation in the corresponding axis. + (For convenience. E.g. if all pad_modes are known, can enter padx, pady, padz, + without needing to worry about which type of operation is being performed.) + + **kw__None: + additional kwargs are ignored. + + TODO: fix ugly printout during verbose==1. + ''' + + def __init__(self, x, up): + super().__init__(x, up, opstr_fmt='{x}{up}') + + def __call__(self, arr, pad_mode=None, verbose=False, + padx=None, pady=None, padz=None, stagger_kind=DEFAULT_STAGGER_KIND, **kw__None): + return super().__call__(arr, pad_mode=pad_mode, verbose=verbose, + padx=padx, pady=pady, padz=padz, stagger_kind=stagger_kind) + + +class _stagger_derivate(_stagger_factory): + '''take derivative of data, interpolating one half cell {up} in {x}. + arr: 3D array + the data to be interpolated. + diff: 1D array + array of distances between each cell along the {x} axis; + length of array must equal the number of points in {x}. + pad_mode: None or str. + pad mode for the interpolation; same options as those supported by np.pad. + if None, the default for this operation will be used: '{pad_default}' + verbose: 0, 1, 2 + 0 --> no verbosity + 1 --> print, end with r'\r'. + 2 --> print, end with r'\n'. + padx, pady, padz: None or string + pad_mode, but only applies for operation in the corresponding axis. + (For convenience. E.g. if all pad_modes are known, can enter padx, pady, padz, + without needing to worry about which type of operation is being performed.) + diffx, diffy, diffz: None or array + diff, but only applies for operation in the corresponding axis. + (For convenience. E.g. if all diffs are known, can enter diffx, diffy, and diffz, + without needing to worry about which type of operation is being performed.) + + TODO: fix ugly printout during verbose==1. + ''' + + def __init__(self, x, up): + super().__init__(x, up, opstr_fmt='dd{x}{up}') + + def __call__(self, arr, diff=None, pad_mode=None, verbose=False, + padx=None, pady=None, padz=None, + diffx=None, diffy=None, diffz=None, + stagger_kind=DEFAULT_STAGGER_KIND, **kw__None): + if diff is None: + diff = {'x': diffx, 'y': diffy, 'z': diffz}[self.x] + return super().__call__(arr, diff=diff, pad_mode=pad_mode, verbose=verbose, + padx=padx, pady=pady, padz=padz, stagger_kind=stagger_kind) + + +_STAGGER_ALIASES = {} +for x in ('x', 'y', 'z'): + _pad_default = PAD_DEFAULTS[x] + for up in ('up', 'dn'): + # define _xup (or _xdn, _yup, _ydn, _zup, _zdn). + _STAGGER_ALIASES[f'_{x}{up}'] = _stagger_spatial(x, up) + # define _ddxup (or _ddxdn, _ddyup, _ddydn, _ddzup, _ddzdn). + _STAGGER_ALIASES[f'_dd{x}{up}'] = _stagger_derivate(x, up) + +## << HERE IS WHERE WE ACTUALLY PUT THE FUNCTIONS INTO THE MODULE NAMESPACE >> ## +for _opstr, _op in _STAGGER_ALIASES.items(): + locals()[_opstr] = _op + +del x, _pad_default, up, _opstr, _op # << remove "temporary variables" from module namespace + +# << At this point, the following functions have all been defined in the module namespace: +# _xdn, _xup, _ydn, _yup, _zdn, _zup, _ddxdn, _ddxup, _ddydn, _ddyup, _ddzdn, _ddzup +# Any of them may be referenced. E.g. import helita.sim.stagger; stagger._ddydn # < this has been defined. + + +""" ------------------------ Chainable Interpolation Objects ------------------------ """ +# Here is where we define: +# xdn, xup, ydn, yup, zdn, zup, ddxdn, ddxup, ddydn, ddyup, ddzdn, ddzup +# They can be called as you would expect, e.g. xdn(arr), +# or chained together, e.g. xdn.ydn.zdn.ddzup(arr) would do xdn(ydn(xdn(ddzup(arr)))). + + +def _trim_leading_underscore(name): + return name[1:] if name[0] == '_' else name + + +class BaseChain(): # base class. Inherit from this class before creating the chain. See e.g. _make_chain(). + """ + object which behaves like an interpolation function (e.g. xup, ddydn), + but can be chained to other interpolations, e.g. xup.ydn.zup(arr) + + This object in particular behaves like: {undetermined}. + + Helpful tricks: + to pass diff to derivatives, use kwargs diffx, diffy, diffz. + to apply in reverse order, use kwarg reverse=True. + default order is A.B.C(val) --> A(B(C(val))). + """ + ## ESSENTIAL BEHAVIORS ## + + def __init__(self, f_self, *funcs): + self.funcs = [f_self, *funcs] + # bookkeeping (non-essential, but makes help() more helpful and repr() prettier) + self.__name__ = _trim_leading_underscore(f_self.__name__) + self.__doc__ = self.__doc__.format(undetermined=self.__name__) + + def __call__(self, x, reverse=False, **kw): + '''apply the operations. If reverse, go in reverse order.''' + itfuncs = self.funcs[::-1] if reverse else self.funcs + for func in itfuncs: + x = func(x, **kw) + return x + + ## CONVNIENT BEHAVIORS ## + def op(self, opstr): + '''get link opstr from self. (For using dynamically-named links) + + Equivalent to getattr(self, opstr). + Example: + self.op('xup').op('ddydn') is equivalent to self.xup.ddydn + ''' + return getattr(self, opstr) + + def __getitem__(self, i): + return self.funcs[i] + + def __iter__(self): + return iter(self.funcs) + + def __repr__(self): + funcnames = ' '.join([_trim_leading_underscore(f.__name__) for f in self]) + return f'{self.__class__.__name__} at <{hex(id(self))}> with operations: {funcnames}' + + +class ChainCreator(): + """for creating and manipulating a chain.""" + + def __init__(self, name='Chain', base=BaseChain): + self.Chain = type(name, (base,), {'__doc__': BaseChain.__doc__}) + self.links = [] + + def _makeprop(self, link): + Chain = self.Chain + return property(lambda self: Chain(link, *self.funcs)) + + def _makelink(self, func): + return self.Chain(func) + + def link(self, prop, func): + '''adds the (prop, func) link to chain.''' + link = self._makelink(func) + setattr(self.Chain, prop, self._makeprop(link)) + self.links.append(link) + + +def _make_chain(*prop_func_pairs, name='Chain', base=BaseChain, + creator=ChainCreator, **kw__creator): + """create new chain with (propertyname, func) pairs as indicated, named Chain. + (propertyname, func): str, function + name of attribute to associate with performing func when called. + + returns Chain, (list of instances of Chain associated with each func) + """ + Chain = creator(name, base=base, **kw__creator) + for prop, func in prop_func_pairs: + Chain.link(prop, func) + + return tuple((Chain.Chain, Chain.links)) + + +props, funcs = [], [] +for dd in ('', 'dd'): + for x in ('x', 'y', 'z'): + for up in ('up', 'dn'): + opstr = f'{dd}{x}{up}' + props.append(opstr) # e.g. 'xup' + funcs.append(locals()[f'_{opstr}']) # e.g. _xup + +## << HERE IS WHERE WE ACTUALLY PUT THE FUNCTIONS INTO THE MODULE NAMESPACE >> ## +InterpolationChain, links = _make_chain(*zip(props, funcs), name='InterpolationChain') +for prop, link in zip(props, links): + locals()[prop] = link # set function in the module namespace (e.g. xup, ddzdn) + +del props, funcs, dd, x, up, opstr, links, prop, link # << remove "temporary variables" from module namespace + + +# << At this point, the following functions have all been defined in the module namespace: +# xdn, xup, ydn, yup, zdn, zup, ddxdn, ddxup, ddydn, ddyup, ddzdn, ddzup + + +""" ------------------------ StaggerData (wrap methods in a class) ------------------------ """ + + +class StaggerInterface(): + """ + Interface to stagger methods, with defaults implied by an object. + Examples: + self.stagger = StaggerInterface(self) + + # interpolate arr by ddxdn: + self.stagger.ddxdn(arr) + # Note, this uses the defaults: + # pad_mode = '{PAD_PERIODIC}' if self.get_param('periodic_x') else '{PAD_NONPERIODIC}' + # diff = self.dxidxdn + # stagger_kind = self.stagger_kind + + # interpolate arr via xup( ydn(ddzdn(arr)) ), using defaults as appropriate: + self.stagger.xup.ydn.ddzdn(arr) + + Available operations: + xdn, xup, ydn, yup, zdn, zup, + ddxdn, ddxup, ddydn, ddyup, ddzdn, ddzup + Available convenience method: + do(opstr, arr, ...) # << does the operation implied by opstr; equivalent to getattr(self, opstr)(arr, ...) + Example: + self.stagger.do('zup', arr) # equivalent to self.stagger.zup(arr) + + Each method will call the appropriate method from stagger.py. + Additionally, for convenience: + named operations can be chained together. + For example: + self.stagger.xup.ydn.ddzdn(arr) + This does not apply when using the 'do' function. + For dynamically-named chaining, see self.op. + default values are supplied for the extra paramaters: + pad_mode: + periodic = self.get_param('periodic_x') (or y, z) + periodic True --> pad_mode = stagger.PAD_PERIODIC (=='{PAD_PERIODIC}') + periodic False -> pad_mode = stagger.PAD_NONPERIODIC (=='{PAD_NONPERIODIC}') + diff: + self.dxidxup with x --> x, y, or z; up --> up or dn. + stagger_kind: + self.stagger_kind + if the operation is called on a string instead of an array, + first pass the string to a call of self. + E.g. self.xup('r') will do stagger.xup(self('r')) + """ + _PAD_PERIODIC = PAD_PERIODIC + _PAD_NONPERIODIC = PAD_NONPERIODIC + + def __init__(self, obj): + self._obj_ref = weakref.ref(obj) # weakref to avoid circular reference. + prop_func_pairs = [(_trim_leading_underscore(prop), func) for prop, func in _STAGGER_ALIASES.items()] + self._make_bound_chain(*prop_func_pairs, name='BoundInterpolationChain') + + obj = property(lambda self: self._obj_ref()) + + def do(self, arr, opstr, *args, **kw): + '''does the operation implied by opstr (e.g. 'xup', ..., 'ddzdn'). + Equivalent to getattr(self, opstr)(arr, *args, **kw) + ''' + return getattr(self, opstr)(arr, *args, **kw) + + def op(self, opstr): + '''gets the operation which opstr would apply. + For convenience. Equivalent to getattr(self, opstr). + + Can be chained. For example: + self.op('xup').op('ddydn') is equivalent to self.xup.ddydn. + ''' + return getattr(self, opstr) + + def _make_bound_chain(self, *prop_func_pairs, name='BoundChain'): + """create new bound chain, linking all props to same-named attributes of self.""" + Chain, links = _make_chain(*prop_func_pairs, name=name, + base=BoundBaseChain, creator=BoundChainCreator, obj=self) + props, funcs = zip(*prop_func_pairs) + for prop, link in zip(props, links): + setattr(self, prop, link) + + ## __INTERPOLATION_CALL__ ## + # this function will be called whenever an interpolation method is used. + # To edit the behavior of calling an interpolation method, edit this function. + # E.g. here is where to connect properties of obj to defaults for interpolation. + def _pad_modes(self): + '''return dict of padx, pady, padz, with values the appropriate strings for padding.''' + def _booly_to_mode(booly): + return {None: None, True: self._PAD_PERIODIC, False: self._PAD_NONPERIODIC}[booly] + return {f'pad{x}': _booly_to_mode(self.obj.get_param(f'periodic_{x}')) for x in ('x', 'y', 'z')} + + def _diffs(self): + '''return dict of diffx, diffy, diffz, with values the appropriate arrays. + CAUTION: assumes dxidxup == dxidxdn == diffx, and similar for y and z. + ''' + return {f'diff{x}': getattr(self.obj, f'd{x}id{x}up') for x in ('x', 'y', 'z')} + + def _stagger_kind(self): + return {'stagger_kind': self.obj.stagger_kind} + + def __interpolation_call__(self, func, arr, *args__get_var, **kw): + '''call interpolation function func on array arr with the provided kw. + + use defaults implied by self (e.g. padx implied by periodic_x), for any kw not entered. + if arr is a string, first call self(arr, *args__get_var, **kw). + ''' + __tracebackhide__ = True + kw_to_use = {**self._pad_modes(), **self._diffs(), **self._stagger_kind()} # defaults based on obj. + kw_to_use.update(kw) # exisitng kwargs override defaults. + if isinstance(arr, str): + arr = self.obj(arr, *args__get_var, **kw) + return func(arr, **kw_to_use) + + +StaggerInterface.__doc__ = StaggerInterface.__doc__.format(PAD_PERIODIC=PAD_PERIODIC, PAD_NONPERIODIC=PAD_NONPERIODIC) + + +class BoundBaseChain(BaseChain): + """BaseChain structure but bound to a class.""" + + def __init__(self, obj, f_self, *funcs): + self.obj = obj + super().__init__(f_self, *funcs) + + def __call__(self, x, reverse=False, **kw): + '''apply the operations. If reverse, go in reverse order''' + itfuncs = self.funcs[::-1] if reverse else self.funcs + for func in itfuncs: + x = self.obj.__interpolation_call__(func, x, **kw) + return x + + def __repr__(self): + funcnames = ' '.join([_trim_leading_underscore(f.__name__) for f in self]) + return f'<{self.__class__.__name__} at <{hex(id(self))}> with operations: {funcnames}> bound to {self.obj}' + + +class BoundChainCreator(ChainCreator): + """for creating and manipulating a bound chain""" + + def __init__(self, *args, obj=None, **kw): + if obj is None: + raise TypeError('obj must be provided') + self.obj = obj + super().__init__(*args, **kw) + + def _makeprop(self, link): + Chain = self.Chain + return property(lambda self: Chain(self.obj, link, *self.funcs)) + + def _makelink(self, func): + return self.Chain(self.obj, func) diff --git a/helita/sim/synobs.py b/helita/sim/synobs.py index d11a9fd7..af9159b0 100644 --- a/helita/sim/synobs.py +++ b/helita/sim/synobs.py @@ -2,10 +2,11 @@ Set of programs to degrade/convolve synthetic images/spectra to observational conditions """ -import math import os -import scipy.interpolate as interp +import math + import numpy as np +import scipy.interpolate as interp from scipy import ndimage, signal @@ -203,7 +204,7 @@ def img_conv(spec, wave, psf, psfx, conv_type='IRIS_MgII_core', xMm=16.5491, widx = (wave[:] > wcent - 2. * wfwhm) & (wave[:] < wcent + 2. * wfwhm) # filtering function, here set to Gaussian wfilt = gaussian([wcent, wfwhm / (2 * math.sqrt(2 * math.log(2))), - 1., 0.], wave[widx]) + 1., 0.], wave[widx]) wfilt /= np.trapz(wfilt, x=wave[widx]) else: widx = wfilt != 0 @@ -241,9 +242,9 @@ def img_conv(spec, wave, psf, psfx, conv_type='IRIS_MgII_core', xMm=16.5491, if graph: p.subplot(212) p.imshow(np.transpose(nspec), - extent=(0, spec.shape[0] * pix2asec, 0, - spec.shape[1] * pix2asec), - vmin=vmin, vmax=vmax, + extent=(0, spec.shape[0] * pix2asec, 0, + spec.shape[1] * pix2asec), + vmin=vmin, vmax=vmax, interpolation='nearest', cmap=p.cm.gist_gray) p.title('Filter + convolved %s' % (conv_type)) p.xlabel('arcsec') @@ -258,7 +259,9 @@ def get_hinode_psf(wave, psfdir='.'): Returns x scale (in arcsec), and psf (2D array, normalised). """ from astropy.io import fits as pyfits + from ..utils import utilsmath + # Get ideal PSF ipsf = pyfits.getdata(os.path.join(psfdir, 'hinode_ideal_psf_555nm.fits')) ix = pyfits.getdata(os.path.join(psfdir, 'hinode_ideal_psf_scale_555nm.fits')) @@ -321,8 +324,8 @@ def var_conv(var, xMm, psf, psfx, obs='iris_nuv', parallel=False, """ Spatially convolves a single atmos variable. """ - import multiprocessing import ctypes + import multiprocessing global result # some definitions @@ -400,8 +403,8 @@ def imgspec_conv(spec, wave, xMm, psf, psfx, obs='hinode_sp', verbose=False, --Tiago, 20120105 ''' - import multiprocessing import ctypes + import multiprocessing global result # some definitions diff --git a/helita/sim/tests/__init__.py b/helita/sim/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/helita/sim/tests/test_multi3d.py b/helita/sim/tests/test_multi3d.py index d8603e86..7b5d6715 100644 --- a/helita/sim/tests/test_multi3d.py +++ b/helita/sim/tests/test_multi3d.py @@ -2,11 +2,13 @@ Test suite for multi3d.py """ import os -import pytest import tarfile -import numpy as np from shutil import rmtree + +import numpy as np +import pytest from pkg_resources import resource_filename + from helita.sim import multi3d TEST_FILES = ['ie_+0.00_+0.00_+1.00_allnu', 'multi3d.input', 'out_atm', @@ -19,6 +21,7 @@ INPUT_VALUES = {'atmosid': 'falc.5x5x82', 'atom': '../input/atoms/atom.h3', 'convlim': 0.001, 'n_scratch': 10} + def unpack_data(source_tarball, files, output_directory): """Unpack test data to temporary directory.""" assert os.path.isfile(source_tarball), 'Could not find test data files.' @@ -70,7 +73,7 @@ def test_Multi3dOut(): data.set_transition(3, 2) ie = data.readvar('ie') assert np.array_equal(ie[0, 0], ie[-1, -1]) - assert np.isclose(ie[0,0,5::20], + assert np.isclose(ie[0, 0, 5::20], np.array([2.9016188e-05, 1.1707955e-05, 3.8370090e-06, 4.9833211e-06, 1.8675400e-05])).all() diff --git a/helita/sim/tests/test_rh15d.py b/helita/sim/tests/test_rh15d.py index fbf1d7eb..26ae948a 100644 --- a/helita/sim/tests/test_rh15d.py +++ b/helita/sim/tests/test_rh15d.py @@ -4,6 +4,7 @@ """ import numpy as np + from helita.sim import rh15d TMP_ATOM_FILENAME = 'atom.tmp' @@ -16,11 +17,11 @@ " 95785.470 1.00 'CA III 3P6 1SE ' 2 5" ] -TEST_LEVELS_DATA = np.array([( 0. , 2., 'CA II 3P6 4S 2SE', 1, 0), +TEST_LEVELS_DATA = np.array([(0., 2., 'CA II 3P6 4S 2SE', 1, 0), (13650.19, 4., 'CA II 3P6 3D 2DE 3', 1, 1), (13710.88, 6., 'CA II 3P6 3D 2DE 5', 1, 2), (25191.51, 2., 'CA II 3P6 4P 2PO 1', 1, 3), - (25414.4 , 4., 'CA II 3P6 4P 2PO 3', 1, 4), + (25414.4, 4., 'CA II 3P6 4P 2PO 3', 1, 4), (95785.47, 1., 'CA III 3P6 1SE', 2, 5)], dtype=[('energy', ' r.max()] = len(r) - 1 + new_ir[new_r.ravel() < r.min()] = 0 + + return ndimage.map_coordinates(grid, np.array([new_ir, new_it]), + order=order).reshape(new_r.shape) + + +def cartesian2polar(x, y, grid, r, t, order=3): + ''' + Converts cartesian grid to polar grid + ''' + + R, T = np.meshgrid(r, t) + + new_x = R * np.cos(T) + new_y = R * np.sin(T) + + ix = interpolate.interp1d(x, np.arange(len(x)), bounds_error=False) + iy = interpolate.interp1d(y, np.arange(len(y)), bounds_error=False) + + new_ix = ix(new_x.ravel()) + new_iy = iy(new_y.ravel()) + + new_ix[new_x.ravel() > x.max()] = len(x) - 1 + new_ix[new_x.ravel() < x.min()] = 0 + + new_iy[new_y.ravel() > y.max()] = len(y) - 1 + new_iy[new_y.ravel() < y.min()] = 0 + + return ndimage.map_coordinates(grid, np.array([new_ix, new_iy]), + order=order).reshape(new_x.shape) + + +def refine(s, q, factor=2, unscale=lambda x: x): + """ + Given 1D function q(s), interpolate so we have factor x many points. + factor = 2 by default + """ + ds = s[-1]-s[0] + ss = np.arange(factor*len(s)+1)/(factor*len(s))*ds+s[0] + if ds > 0.0: + qq = unscale(np.interp(ss, s, q)) + return ss, qq + elif ds < 0.0: + qq = unscale(np.interp(ss[::-1], s[::-1], q[::-1])) + qq = qq[::-1] + return ss, qq + + +''' --------------------- restore attrs --------------------- ''' + + +def maintain_attrs(*attrs): + '''return decorator which restores attrs of obj after running function. + It is assumed that obj is the first arg of function. + ''' + def attr_restorer(f): + @functools.wraps(f) + def f_but_maintain_attrs(obj, *args, **kwargs): + '''f but attrs are maintained.''' + __tracebackhide__ = True + with MaintainingAttrs(obj, *attrs): + return f(obj, *args, **kwargs) + return f_but_maintain_attrs + return attr_restorer + + +class MaintainingAttrs(): + '''context manager which restores attrs of obj to their original values, upon exit.''' + + def __init__(self, obj, *attrs): + self.obj = obj + self.attrs = attrs + + def __enter__(self): + self.memory = dict() + for attr in self.attrs: + if hasattr(self.obj, attr): + self.memory[attr] = getattr(self.obj, attr) + + def __exit__(self, exc_type, exc_value, traceback): + for attr, val in self.memory.items(): + setattr(self.obj, attr, val) + + +def with_attrs(**attrs_and_values): + '''return decorator which sets attrs of object before running function then restores them after. + It is assumed that obj is the first arg of function. + ''' + def attr_setter_then_restorer(f): + @functools.wraps(f) + def f_but_set_then_restore_attrs(obj, *args, **kwargs): + '''f but attrs are set beforehand then restored afterward.''' + __tracebackhide__ = True + with MaintainingAttrs(obj, *attrs_and_values.keys()): + for attr, value in attrs_and_values.items(): + setattr(obj, attr, value) + return f(obj, *args, **kwargs) + return f_but_set_then_restore_attrs + return attr_setter_then_restorer + + +class EnterDir: + '''context manager for remembering directory. + upon enter, cd to directory (default os.curdir, i.e. no change in directory) + upon exit, original working directory will be restored. + + For function decorator, see QOL.maintain_cwd. + ''' + + def __init__(self, directory=os.curdir): + self.cwd = os.path.abspath(os.getcwd()) + self.directory = directory + + def __enter__(self): + os.chdir(self.directory) + + def __exit__(self, exc_type, exc_value, traceback): + os.chdir(self.cwd) + + +RememberDir = EnterDir # alias +EnterDirectory = EnterDir # alias + + +def with_dir(directory): + '''returns a function decorator which: + - changes current directory to . + - runs function + - changes back to original directory. + ''' + def decorator(f): + @functools.wraps(f) + def f_but_enter_dir(*args, **kwargs): + with EnterDir(directory): + return f(*args, **kwargs) + return f_but_enter_dir + return decorator + + +withdir = with_dir # alias + +# define a new function decorator, maintain_cwd, which maintains current directory: +maintain_cwd = with_dir(os.curdir) + +maintain_directory = maintain_cwd # alias +maintain_dir = maintain_cwd # alias + +''' --------------------------- info about arrays --------------------------- ''' + + +def stats(arr, advanced=True, finite_only=True): + '''return dict with min, mean, max. + if advanced, also include: + std, median, size, number of non-finite points (e.g. np.inf or np.nan). + if finite_only: + only treat the finite parts of arr; ignore nans and infs. + ''' + arr = arr_orig = np.asanyarray(arr) + if finite_only or advanced: # then we need to know np.isfinite(arr) + finite = np.isfinite(arr) + n_nonfinite = arr.size - np.count_nonzero(finite) + if finite_only and n_nonfinite > 0: + arr = arr[finite] + result = dict(min=np.nanmin(arr), mean=np.nanmean(arr), max=np.nanmax(arr)) + if advanced: + result.update(dict(std=np.nanstd(arr), median=np.nanmedian(arr), + size=arr.size, nonfinite=n_nonfinite)) + return result + + +def print_stats(arr_or_stats, advanced=True, fmt='{: .2e}', sep=' | ', return_str=False, **kw__print): + '''calculate and prettyprint stats about array. + arr_or_stats: dict (stats) or array-like. + dict --> treat dict as stats of array. + array --> calculate stats(arr, advanced=advanced) + fmt: str + format string for each stat. + sep: str + separator string between each stat. + return_str: bool + whether to return string instead of printing. + ''' + fmtkey = '{:>6s}' if '\n' in sep else '{}' + _stats = arr_or_stats if isinstance(arr_or_stats, dict) else stats(arr_or_stats, advanced=advanced) + result = sep.join([f'{fmtkey.format(key)}: {fmt.format(val)}' for key, val in _stats.items()]) + return result if return_str else print(result, **kw__print) + + +def finite_op(arr, op): + '''returns op(arr), hitting only the finite values of arr. + if arr has only finite values, + finite_op(arr, op) == op(arr). + if arr has some nonfinite values (infs or nans), + finite_op(arr, op) == op(arr[np.isfinite(arr)]) + ''' + arr = np.asanyarray(arr) + finite = np.isfinite(arr) + if np.count_nonzero(finite) < finite.size: + return op(arr[finite]) + else: + return op(arr) + + +def finite_min(arr): + '''returns min of all the finite values of arr.''' + return finite_op(arr, np.min) + + +def finite_mean(arr): + '''returns mean of all the finite values of arr.''' + return finite_op(arr, np.mean) + + +def finite_max(arr): + '''returns max of all the finite values of arr.''' + return finite_op(arr, np.max) + + +def finite_std(arr): + '''returns std of all the finite values of arr.''' + return finite_op(arr, np.std) + + +def finite_median(arr): + '''returns median of all the finite values of arr.''' + return finite_op(arr, np.median) + + +''' --------------------------- manipulating arrays --------------------------- ''' + + +def slicer_at_ax(slicer, ax): + '''return tuple of slices which, when applied to an array, takes slice along axis number . + slicer: a slice object, or integer, or tuple of integers. + slice or integer -> use slicer directly. + tuple of integers -> use slice(*slicer). + ax: a number (negative ax not supported here). + ''' + try: + slicer[0] + except TypeError: # slicer is a slice or an integer. + pass + else: # assume slicer is a tuple of integers. + slicer = slice(*slicer) + return (slice(None),)*ax + (slicer,) + + +''' --------------------------- strings --------------------------- ''' + + +def pretty_nbytes(nbytes, fmt='{:.2f}'): + '''returns nbytes as a string with units for improved readability. + E.g. pretty_nbytes(20480, fmt='{:.1f}') --> '10.0 kB'. + ''' + n_u_bytes = nbytes + u = '' + for u_next in ['k', 'M', 'G', 'T']: + n_next = n_u_bytes / 1024 + if n_next < 1: + break + else: + n_u_bytes = n_next + u = u_next + return '{fmt} {u}B'.format(fmt=fmt, u=u).format(n_u_bytes) + + +''' --------------------------- import error handling --------------------------- ''' + + +class ImportFailedError(ImportError): + pass + + +class ImportFailed(): + '''set modules which fail to import to be instances of this class; + initialize with modulename, additional_error_message. + when attempting to access any attribute of the ImportFailed object, + raises ImportFailedError('. '.join(modulename, additional_error_message)). + Also, if IMPORT_FAILURE_WARNINGS, make warning immediately when initialized. + + Example: + try: + import zarr + except ImportError: + zarr = ImportFailed('zarr', 'This module is required for compressing data.') + + zarr.load(...) # << attempt to use zarr + # if zarr was imported successfully, it will work fine. + # if zarr failed to import, this error will be raised: + ImportFailedError: zarr. This module is required for compressing data. + ''' + + def __init__(self, modulename, additional_error_message=''): + self.modulename = modulename + self.additional_error_message = additional_error_message + if IMPORT_FAILURE_WARNINGS: + warnings.warn(f'Failed to import module {modulename}.{additional_error_message}') + + def __getattr__(self, attr): + str_add = str(self.additional_error_message) + if len(str_add) > 0: + str_add = '. ' + str_add + raise ImportFailedError(self.modulename + str_add) + + +def boring_decorator(*args, **kw): + '''returns the identity wrapper (returns the function it wraps, without any changes). + This is useful when importing function wrappers; use boring_decorator if ImportError occurs. + ''' + def boring_wrapper(f): + return f + return boring_wrapper + + +''' --------------------------- vector rotations --------------------------- ''' + + +def rotation_align(vecs_source, vecs_destination): + ''' Return the rotation matrix which aligns vecs_source to vecs_destination. + + vecs_source, vecs_destination: array of vectors, or length 3 list of scalars + array of 3d vectors for source, destination. + Both will be cast to numpy arrays via np.asarray. + The inputs can be any number of dimensions, + but the last dimension should represent x,y,z, + E.g. the shape should be (..., 3). + NOTE: the np.stack function may be helpful in constructing this input. + E.g. for Bx, By, Bz, use np.stack([Bx, By, Bz], axis=-1). + This works for any same-shaped Bx, By, Bz arrays or scalars. + + Note: a divide by 0 error indicates that at least one of the rotations will be -I; + i.e. the vectors were originally parallel, but in opposite directions. + + Returns: array which, when applied to vecs_source, aligns them with vecs_destination. + The result will be an array of 3x3 matrices. + For applying the array, see rotation_apply(), or use: + np.sum(result * np.expand_dims(vec, axis=(-2)), axis=-1) + + Example: + # Bx, By, Bz each have shape (100, 70, 50), and represent the x, y, z components of B. + # ux, uy, uz each have shape (100, 70, 50), and represent the x, y, z components of u. + B_input = np.stack([Bx, By, Bz], axis=-1) # >> B_input has shape (100, 70, 50, 3) + u_input = np.stack([ux, uy, uz], axis=-1) # >> u_input has shape (100, 70, 50, 3) + d_input = [0, 0, 1] # "rotate to align with z" + result = rotation_align(B_input, d_input) # >> result has shape (100, 70, 50, 3, 3) + # << result tells how to rotate such that B aligns with z + rotation_apply(result, B_input) + # matrix of [Bx', By', Bz'], which has Bx' == By' == 0, Bz' == |B| + rotation_apply(result, u_input) + # matrix of [ux', uy', uz'], where u' is in the coord. system with B in the z direction. + + # instead of rotation_apply(v1, v2), can use np.sum(v1 * np.expand_dims(v2, axis=(-2)), axis=-1). + + Rotation algorithm based on Rodrigues's rotation formula. + Adapted from https://stackoverflow.com/a/59204638 + ''' + # bookkeeping - whether to treat as masked arrays + if np.ma.isMaskedArray(vecs_source) or np.ma.isMaskedArray(vecs_destination): + stack = np.ma.stack + asarray = np.ma.asarray + else: + stack = np.stack + asarray = np.asarray + # bookkeeping - dimensions + vec1 = asarray(vecs_source) + vec2 = asarray(vecs_destination) + vec1 = np.expand_dims(vec1, axis=tuple(range(0, vec2.ndim - vec1.ndim))) + vec2 = np.expand_dims(vec2, axis=tuple(range(0, vec1.ndim - vec2.ndim))) + # magnitudes, products + def mag(u): return np.linalg.norm(u, axis=-1, keepdims=True) # magnitude of u with vx, vy, vz = v[...,0], v[...,1], v[...,2] + a = vec1 / mag(vec1) + b = vec2 / mag(vec2) + + def cross(a, b): + '''takes the cross product along the last axis. + np.cross(a, b, axis=-1) can't handle masked arrays so we write out the cross product explicitly here. + ''' + ax, ay, az = a[..., 0], a[..., 1], a[..., 2] + bx, by, bz = b[..., 0], b[..., 1], b[..., 2] + rx = ay * bz - az * by + ry = az * bx - ax * bz + rz = ax * by - ay * bx + return stack([rx, ry, rz], axis=-1) + + v = cross(a, b) # a x b, with axis -1 looping over x, y, z. + c = np.sum(a * b, axis=-1) # a . b, with axis -1 looping over x, y, z. + # building kmat + v_x, v_y, v_z = (v[..., i] for i in (0, 1, 2)) + zero = np.zeros_like(v_x) + kmat = stack([ + stack([zero, -v_z, v_y], axis=-1), + stack([v_z, zero, -v_x], axis=-1), + stack([-v_y, v_x, zero], axis=-1), + ], axis=-2) + _I = np.expand_dims(np.eye(3), axis=tuple(range(0, kmat.ndim - 2))) + _c = np.expand_dims(c, axis=tuple(range(np.ndim(c), kmat.ndim))) # _c = c with dimensions added appropriately. + # implementation of Rodrigues's formula + result = _I + kmat + np.matmul(kmat, kmat) * 1 / (1 + _c) # ((1 - c) / (s ** 2)) # s**2 = 1 - c**2 (s ~ sin, c ~ cos) # s := mag(v) + + # handle the c == -1 case. wherever c == -1, vec1 and vec2 are parallel with vec1 == -1 * vec2. + flipvecs = (c == -1) + result[flipvecs, :, :] = -1 * np.eye(3) + return result + + +def rotation_apply(rotations, vecs): + '''apply the rotations to vecs. + + rotations: array of 3x3 rotation matrices. + should have shape (..., 3, 3) + vecs: array of vectors. + should have shape (..., 3) + + shapes should be consistent, + E.g. rotations with shape (10, 7, 3, 3), vecs with shape (10, 7, 3). + + returns rotated vectors. + ''' + return np.sum(rotations * np.expand_dims(vecs, axis=(-2)), axis=-1) + + +''' --------------------------- plotting --------------------------- ''' + + +def extent(xcoords, ycoords): + '''returns extent (to go to imshow), given xcoords, ycoords. Assumes origin='lower'. + Use this method to properly align extent with middle of pixels. + (Noticeable when imshowing few enough pixels that individual pixels are visible.) + + xcoords and ycoords should be arrays. + (This method uses their first & last values, and their lengths.) + + returns extent == np.array([left, right, bottom, top]). + ''' + Nx = len(xcoords) + Ny = len(ycoords) + dx = (xcoords[-1] - xcoords[0])/Nx + dy = (ycoords[-1] - ycoords[0])/Ny + return np.array([*(xcoords[0] + np.array([0 - dx/2, dx * Nx + dx/2])), + *(ycoords[0] + np.array([0 - dy/2, dy * Ny + dy/2]))]) + + +''' --------------------------- custom versions of builtins --------------------------- ''' + + +class GenericDict(collections.abc.MutableMapping): + '''dict allowing for non-hashable keys. + Slower than built-in dict, but behaves like a built-in dict. + + equals: None or function. + function used to test equality of keys. + None --> use the default: lambda oldkey, newkey: oldkey==newkey + ''' + + def __init__(self, iterable=(), equals=None): + if equals is None: + self.equals = lambda old, new: old == new + else: + self.equals = equals + + self._keys = collections.deque() + self._values = collections.deque() + + for key, value in iterable: + self[key] = value + for k, v in iterable: + self._keys.append(k) + self._values.append(v) + + ## HELPER METHODS ## + def _index(self, key): + '''return index of key in self._keys. + uses self.equals to check equality of keys. + raise ValueError if key not in self. + ''' + for i, k in enumerate(self._keys): + if self.equals(k, key): + return i + raise ValueError('Not found in self.keys(): ', key) + + def _find(self, key): + '''return index of key in self. + uses self.equals to check equality of keys. + return None if key not in self. + ''' + try: + return self._index(key) + except ValueError: + return None + + ## REQUIRED METHODS ## + def __getitem__(self, key): + i = self._find(key) + if i is None: + raise KeyError(key) + else: + return self._values[i] + + def __setitem__(self, key, value): + i = self._find(key) + if i is None: # new key + self._keys.append(key) + self._values.append(value) + else: # already existing key + self._keys[i] = key + self._values[i] = value + + def __delitem__(self, key): + i = self._find(key) + if i is None: + raise KeyError(key) + else: + del self._keys[i] + del self._values[i] + + def __iter__(self): + return iter(self._keys) + + def __len__(self): + return len(self._keys) + + ## MIXINS ## + # from parent, we automatically get these mixins after defining the methods above: + # __contains__, keys, items, values, get, __eq__, __ne__, + # pop, popitem, clear, update, setdefault, + + ## PRETTY ## + def __repr__(self): + return '{' + ', '.join([f'{key}: {value}' for key, value in self.items()]) + '}' + + +def GenericDict_with_equals(equals): + '''return constructor for GenericDict, with 'equals' set to the function provided, by default.''' + def _GenericDict_create(*args, equals=equals, **kwargs): + '''return GenericDict object.''' + return GenericDict(*args, equals=equals, **kwargs) + return _GenericDict_create diff --git a/helita/sim/units.py b/helita/sim/units.py new file mode 100644 index 00000000..f43124ab --- /dev/null +++ b/helita/sim/units.py @@ -0,0 +1,1761 @@ +""" +Created by Sam Evans on Apr 27 2021 + +purpose: + 1) provides HelitaUnits class + 2) enable "units" mode for DataClass objects (e.g. BifrostData, EbysusData). + +TL;DR: + Use obj.get_units() to see the units for the most-recent quantity that obj got with get_var(). + +The idea is to: +- have all load_quantities functions return values in simulation units. +- have a way to lookup how to convert any variable to a desired set of units. +- have an attribute of the DataClass object tell which system of units we want output. + +EXAMPLE USAGE: + dd = EbysusData(...) + b_squared = dd.get_var('b2') + dd.get_units('si') # 'si' by default; other option is cgs. + EvaluatedUnits(factor=1.2566e-11, name='T^{2}') + b_squared * dd.get_units('si').factor # == magnetic field squared, in SI units. + + +State of the code right now: +- The "hesitant execution" of methods in here means that if you do not call obj.get_units() + or any other units-related functions, then nothing in units.py should cause a crash. + +TODO: +- have a units_system flag attribute which allows to convert units at top level automatically. + - (By default the conversion will be off.) + - Don't tell Juan about this attribute because he won't like it ;) but he doesn't ever have to use it! + +USER FRIENDLY GUIDE + The way to input units is to put them in the documentation segment of get_quant functions. + + There are a few different ways to enter the units, and you can enter as much or as little info as you want. + The available keys to enter are: + ----- AVAILABLE KEYS ----- + usi_f = function which tells >> si << units. (given info about obj) + ucgs_f = function which tells >> cgs << units. (given info about obj) + uni_f = function which tells >> any << units. (given info about obj, and unit system) + usi_name = UnitsExpression which gives name for >> si << units. + ucgs_name = UnitsExpression which gives name for >> cgs << units. + uni_name = UnitsExpression which gives name for >> any << units. (given info about unit system) + usi = UnitsTuple giving (function, name) for >> si << units. (given info about obj) + ucgs = UnitsTuple giving (function, name) for >> cgs << units. (given info about obj) + uni = UnitsTuple giving (function, name) for >> any << units. (given info about obj, and unit system) + + You should not try to build your own functions from scratch. + Instead, use the building blocks from units.py in order to fill in the units details. + ----- BUILDING BLOCKS ----- + First, it is recommended to import the following directly into the local namespace, for convenience: + from helita.sim.units import ( + UNI, USI, UCGS, UCONST, + Usym, Usyms, UsymD, + U_TUPLE, + DIMENSIONLESS, UNITS_FACTOR_1, NO_NAME, + UNI_length, UNI_time, UNI_mass + ) + Here is a guide to these building blocks. + ----- FUNCTION BUILDERS ----- + > UCONST: access the exact attribute provided here, from obj.uni. + Example: UCONST.ksi_b --> obj.uni.ksi_b + > USI: access si units from obj.uni. (prepend 'usi_' to the attribute here) + Example: (USI.r * USI.l) --> (obj.uni.usi_r * obj.uni.usi_l) + > UCGS: access cgs units from obj.uni. (prepend 'u_' to the attribute here) + Example: (UCGS.r * UCGS.l) --> (obj.uni.u_r * obj.uni.u_l) + > UNI: when units are evaluated, UNI works like USI or UCGS, depending on selected unit system. + + These can be manipulated using +, -, *, /, ** in the intuitive ways. + Example: UCGS.r ** 3 / (UCONST.amu * UCGS.t) --> (obj.uni.u_r)**3 / (obj.uni.amu * obj.uni.u_t) + (Note + and - are not fully tested, and probably should not be used for units anyways.) + + Also, the attributes are "magically" transferred to obj.uni, so any attribute can be entered. + Example: USI.my_arbitrary_attribute --> obj.uni.usi_my_artbitrary_attribute + + ----- NAME BUILDERS ----- + The tools here build UnitsExpression objects, which can be manipulated intuitively using *, /, **. + UnitsExpression objects provide a nice-looking string for units when converted to string. + Example: str(Usym('m') ** 3 / (Usym('m') * Usym('s')) --> 'm^{2} / s' + + > Usym: gives a UnitsExpression representing the entered string (to the first power) + > Usyms: generate multiple Usym at once. + Example: Usyms('m', 's', 'g') is equivalent to (Usym('m'), Usym('s'), Usym('g')) + > UsymD: gives a dict of UnitsExpressions; the name to use is picked when unit system info is entered. + Example: UsymD(usi='m', ucgs='cm') --> + Usym('m') when unit system is 'si' + Usym('cm') when unit system is 'cgs'. + The keys to use for UsymD are always the keys usi, ucgs. + + ----- TUPLE BUILDER ----- + > U_TUPLE: turns function and name into a UnitsTuple. Mainly for convenience. + The following are equivalent (for any ufunc and uname): + docvar(..., usi=U_TUPLE(ufunc, uname) + docvar(..., usi_f=ufunc, usi_name=uname) + This also applies similarly to ucgs and uni (in place of usi). + + UnitsTuple objects can be manipulated intuitively using *, /, **. + Example: U_TUPLE(fA, nameA) ** 3 / (U_TUPLE(fB, nameB) * U_TUPLE(fC, nameC)) + --> U_TUPLE(fA**3 / (fB * fC), nameA**3 / (nameB * nameC)) + + ----- QUANT CHILDREN ----- + For some units it is necessary to know the units of the "children" which contribute to the quant. + For example, the units of AratB (== A/B) will be the units of A divided by the units of B. + (This is probably only necessary for quantities in load_arithmetic_quantities) + + This can be accomplished using some special attributes from the function builders UNI, USI, or UCGS: + > quant_child_f(i) or qcf(i) + gives the units function for the i'th-oldest child. + > quant_child_name(i) or qcn(i) + gives the UnitsExpression for the i'th-oldest child. + > quant_child(i) or qc(i) + gives the UnitsTuple for the i'th-oldest child. + Example: + for the AratB example above, we can enter for rat: + docvar('rat', ..., uni=UNI.quant_child(0) / UNI.quant_child(1)) + assuming the code for AratB gets A first, then gets B, and + gets no other vars (at that layer of the code; i.e. ignoring internal calls to + get_var while getting A and/or B), then this will cause the units for AratB to + evaluate to (units for A) / (units for B). + + ----- CONVENIENCE TOOLS ----- + The remanining imported tools are there for convenience. + > NO_NAME: an empty UnitsExpression. + > UNITS_FACTOR_1: a units function which always returns 1 when units are evaluated. + Example: docvar('tg', ..., uni_f=UNITS_FACTOR_1, uni_name=Usym('K')) + # get_var('tg') returns temperature in kelvin, so the conversion factor is 1 and the name is 'K'. + > DIMENSIONLESS: UnitsTuple(UNITS_FACTOR_1, NO_NAME) + Example: docvar('beta', ..., uni=DIMENSIONLESS) + # get_var('beta') returns plasma beta, a dimensionless quantities, so we use DIMENSIONLESS. + > UNI_length: UnitsTuple(UNI.l, UsymD(usi='m', ucgs='cm')) + UNI_length evaluates to the correct units and name for length in either unit system. + > UNI_time: UnitsTuple(UNI.t, Usym('s')) + UNI_time evaluates to the correct units and name for time in either unit system. + > UNI_mass: UnitsTuple(UNI.m, UsymD(usi='kg', ucgs='g')) + UNI_mass evaluates to the correct units and name for mass in either unit system. + + To get started it is best to use this guide as a reference, + and look at the existing examples in the load_..._quantities files. + + If it seems overwhelming, don't worry too much. + The units.py "add-on" is designed to "execute hesitantly". + Meaning, the units are not actually being evaluated until they are told to be. + So, if you enter something wrong, or enter incomplete info, it will only affect + code which actively tries to get the relevant units. +""" + +import weakref +# import built-ins +import operator +import collections + +# import external public modules +import numpy as np + +# import internal modules +from . import tools + +''' ----------------------------- Set Defaults ----------------------------- ''' + +# whether to hide tracebacks from internal funcs in this file when showing error traceback. +HIDE_INTERNAL_TRACEBACKS = True + +# UNIT_SYSTEMS are the allowed names for units_output. +# units_output will dictate the units for the output of get_var. +# additionally, get_units will give (1, units name) if units_output matches the request, +# or raise NotImplementedError if units_output is not 'simu' and does not match the request. +# For example: +# for obj.units_output='simu', get_units('si') tells (conversion to SI, string for SI units name) +# for obj.units_output='si', get_units('si') tells (1, string for SI units name) +# for obj.units_output='si', get_units('cgs') raises NotImplementedError. +UNIT_SYSTEMS = ('simu', 'si', 'cgs') + + +def ASSERT_UNIT_SYSTEM(value): + assert value in UNIT_SYSTEMS, f"expected unit system from {UNIT_SYSTEMS}, but got {value}" + + +def ASSERT_UNIT_SYSTEM_OR(value, *alternates): + VALID_VALUES = (*alternates, *UNIT_SYSTEMS) + assert value in VALID_VALUES, f"expected unit system from {UNIT_SYSTEMS} or value in {alternates}, but got {value}" + +# UNITS_OUTPUT property + + +def UNITS_OUTPUT_PROPERTY(internal_name='_units_output', default='simu'): + '''creates a property which manages units_output. + uses the internal name provided, and returns the default if property value has not been set. + + only allows setting of units_output to valid names (as determined by helita.sim.units.UNIT_SYSTEMS). + ''' + + def get_units_output(self): + return getattr(self, internal_name, default) + + def set_units_output(self, value): + '''sets units output to value.lower()''' + try: + value = value.lower() + except AttributeError: + pass # the error below ("value isn't in UNIT_SYSTEMS") will be more elucidating than raising here. + ASSERT_UNIT_SYSTEM(value) + setattr(self, internal_name, value) + + doc = \ + f"""Tells which unit system to use for output of get_var. Options are: {UNIT_SYSTEMS}. + 'simu' --> simulation units. This is the default, and requires no "extra" unit conversions. + 'si' --> si units. + 'cgs' --> cgs units. + """ + + return property(fset=set_units_output, fget=get_units_output, doc=doc) + + +# for ATF = AttrsFunclike(..., format_attr=None, **kw__entered), +# if kw__entered[ATTR_FORMAT_KWARG] (:=kw_fmt) exists, +# for any required kwargs which haven't been entered, +# try to use the attribute of obj: kw_fmt(kwarg). +# (Instead of trying to use the attribute of obj: kwarg.) +ATTR_FORMAT_KWARG = '_attr_format' + +# when doing any of the quant_child methods in UnitsFuncBuilder, +# use UNITS_KEY_KWARG to set units_key, unless self.units_key is set (e.g. at initialization). +# This affects UNI (for which self.units_key is None), but not USI nor UCGS. +UNITS_KEY_KWARG = '_units_key' + +# UNITS_MODES stores the internal ("behind-the-scenes") info for unit conversion modes. +# units_key = key in which UnitsTuple is stored in vardict; +# units_tuple = obj.vardict[metaquant][typequant][quant][units_key] +# attr_format = value to pass to ATTR_FORMAT_KWARG. (See ATTR_FORMAT_KWARG above for more info.) +UNITS_MODES = \ + { + # mode : (units_key, attr_format) + 'si': ('usi', 'usi_{}'), + 'cgs': ('ucgs', 'u_{}'), + } + +# UNITS_UNIVERSAL_KEY is a key which allows to use only one UnitsTuple to represent multiple unit systems. +# In vardict when searching for original units_key, if it is not found, we will also search for this key; +# i.e. if vardict[metaquant][typequant][quant][UNITS_UNIVERSAL_KEY] (:= this_unit_tuple) exists, +# then we will call this_unit_tuple(obj.uni, obj, ATTR_FORMAT_KWARG = units_key) +# For example, a velocity is represented by obj.uni.usi_u or obj.uni.u_u, but these are very similar. +# So, instead of setting usi and ucgs separately for docvar('uix'), we can just set uni: +# docvar('uix', 'x-component of ifluid velocity', uni = UNI.l) +UNITS_UNIVERSAL_KEY = 'uni' + +# UNITS_F_KEY +UNITS_KEY_F = '{}_f' + +# UNITS_NAME_KEY +UNITS_KEY_NAME = '{}_name' + + +''' ============================= Helita Units ============================= ''' + + +class HelitaUnits(object): + '''stores units as attributes. + + units starting with 'u_' are in cgs. starting with 'usi_' are in SI. + Convert to these units by multiplying data by this factor. + Example: + r = obj.get_var('r') # r = mass density / (simulation units) + rcgs = r * obj.uni.u_r # rcgs = mass density / (cgs units, i.e. (g * cm^-3)) + rsi = r * obj.uni.usi_r # rsi = mass density / (si units, i.e. (kg * m^-3)) + + all units are uniquely determined by the following minimal set of units: + (length, time, mass density, gamma) + + you can access documentation on the units themselves via: + self.help(). (for BifrostData object obj, do obj.uni.help()) + this documentation is not very detailed, but at least tells you + which physical quantity the units are for. + + PARAMETERS + ---------- + u_l, u_t, u_r, gamma: + values for these units. + parent: None, or HelitaData object (e.g. BifrostData object) + the object to which these units are associated. + units_output: None, 'simu', 'cgs', or 'si' + unit system for output of self(ustr). E.g. self('r') --> + if 'si' --> self.usi_r + if 'cgs' --> self.u_r + if 'simu' --> 1 + if parent is provided and has 'units_output' attribute, + self.units_output will default to parent.units_output. + units_input: 'simu', 'cgs', or 'si' + unit system to convert from, for self(ustr). + E.g. self('r', units_output='simu', units_input='si') = 1/self.usi_r, + since r [si units] * 1/self.usi_r == r [simu units] + ''' + + BASE_UNITS = ['u_l', 'u_t', 'u_r', 'gamma'] + + def __init__(self, u_l=1.0, u_t=1.0, u_r=1.0, gamma=1.6666666667, verbose=False, + units_output=None, units_input='simu', parent=None): + '''get units from file (by reading values of u_l, u_t, u_r, gamma).''' + self.verbose = verbose + + # base units + self.docu('l', 'length') + self.u_l = u_l + self.docu('t', 'time') + self.u_t = u_t + self.docu('r', 'mass density') + self.u_r = u_r + self.docu('gamma', 'adiabatic constant') + self.gamma = gamma + + # bookkeeping + self._parent_ref = (lambda: None) if parent is None else weakref.ref(parent) # weakref to avoid circular reference. + self.units_output = units_output + self.units_input = units_input + + # set many unit constants (e.g. cm_to_m, amu, gsun). + tools.globalvars(self) + # initialize unit conversion factors, and values of some constants in [simu] units. + self._initialize_extras() + # initialize constant_lookup, the lookup table for constants in various unit systems. + self._initialize_constant_lookup() + + ## PROPERTIES ## + parent = property(lambda self: self._parent_ref()) + + @property + def units_output(self): + '''self(ustr) * value [self.units_input units] == value [self.units_output units] + if None, use the value of self.parent.units_output instead. + ''' + result = getattr(self, '_units_output', None) + if result is None: + if self.parent is None: + raise AttributeError('self.units_output=None and cannot guess value from parent.') + else: + result = self.parent.units_output + return result + + @units_output.setter + def units_output(self, value): + ASSERT_UNIT_SYSTEM_OR(value, None) + self._units_output = value + + @property + def units_input(self): + ''''self(ustr) * value [self.units_input units] == value [self.units_output units]''' + return getattr(self, '_units_input', 'simu') + + @units_input.setter + def units_input(self, value): + ASSERT_UNIT_SYSTEM(value) + self._units_input = value + + @property + def doc_units(self): + '''dictionary of documentation about the unit conversion attributes of self.''' + if not hasattr(self, '_doc_units'): + self._doc_units = dict() + return self._doc_units + + @property + def doc_constants(self): + '''dictionary of documentation about the constants attributes of self.''' + if not hasattr(self, '_doc_constants'): + self._doc_constants = dict() + return self._doc_constants + + @property + def doc_all(self): + '''dictionary of documentation about all attributes of self.''' + return {**self.doc_units, **self.doc_constants} + + ## GETTING UNITS ## + def __call__(self, ustr, units_output=None, units_input=None): + '''returns factor based on ustr and unit systems. + + There are two modes for this function: + 1) "conversion mode" + self(ustr) * value [units_input units] == value [units_output units] + 2) "constants mode" + self(ustr) == value of the relevant constant in [units_output] system. + The mode will be dermined by ustr. + E.g. 'r', 'u', 'nq' --> conversion mode. + E.g. 'kB', 'amu', 'q_e' --> constants mode. + raises ValueEror if ustr is unrecognized. + + ustr: string + tells unit dimensions of value, or which constant to get. + e.g. ustr='r' <--> mass density. + See self.help() for a (non-extensive) list of more options. + units_output: None, 'simu', 'si', or 'cgs' + unit system for output of self(ustr) * value. + if None, use self.units_output. + (Note: self.units_output defaults to parent.units_output) + units_input: None, 'simu', 'si', or 'cgs' + unit system for input value; output will be self(ustr) * value. + if None, use self.units_input. + (Note: self.units_input always defaults to 'simu') + ''' + try: + self._constant_name(ustr) + except KeyError: + conversion_mode = True + else: + conversion_mode = False + if conversion_mode: + if not self._unit_exists(ustr): + raise ValueError(f'units do not exist: u_{ustr}, usi_{ustr}. ' + + f'And, {ustr} is not a constant from constant_lookup.') + simu_to_out = self.get_conversion_from_simu(ustr, units_output) + if units_input is None: + units_input = self.units_input + ASSERT_UNIT_SYSTEM(units_input) + simu_to_in = self.get_conversion_from_simu(ustr, units_input) + out_from_in = simu_to_out / simu_to_in # in_to_simu = 1 / simu_to_in. + return out_from_in + else: # "constants mode" + try: + result = self.get_constant(ustr, units_output) + except KeyError: + ustr_not_found_errmsg = f'failed to determine units for ustr={ustr}.' + raise ValueError(ustr_not_found_errmsg) + else: + return result + + def get_conversion_from_simu(self, ustr, units_output=None): + '''get conversion factor from simulation units, to units_system. + ustr: string + tells unit dimensions of value. + e.g. ustr='r' <--> mass density. + See self.help() for a (non-extensive) list of more options. + units_output: None, 'simu', 'si', or 'cgs' + unit system for output. Result converts from 'simu' to units_output. + if None, use self.units_output. + ''' + if units_output is None: + units_output = self.units_output + ASSERT_UNIT_SYSTEM(units_output) + if units_output == 'simu': + return 1 + elif units_output == 'si': + return getattr(self, f'usi_{ustr}') + elif units_output == 'cgs': + return getattr(self, f'u_{ustr}') + else: + raise NotImplementedError(f'units_output={units_output}') + + def get_constant(self, constant_name, units_output=None): + '''gets value of constant_name in unit system [units_output]. + constant_name: string + name of constant to get. + e.g. 'amu' <--> value of one atomic mass unit. + units_output: None, 'simu', 'si', or 'cgs' + unit system for output. Result converts from 'simu' to units_output. + if None, use self.units_output. + ''' + if units_output is None: + units_output = self.units_output + ASSERT_UNIT_SYSTEM(units_output) + cdict = self.constant_lookup[constant_name] + ckey = cdict[units_output] + return getattr(self, ckey) + + ## INITIALIZING UNITS AND CONSTANTS ## + def _initialize_extras(self): + '''initializes all the units other than the base units.''' + import scipy.constants as const # import here to reduce overhead of the module. + from astropy import constants as aconst # import here to reduce overhead of the module. + + # set cgs units + self.u_u = self.u_l / self.u_t + self.u_p = self.u_r * (self.u_u)**2 # Pressure [dyne/cm2] + self.u_kr = 1 / (self.u_r * self.u_l) # Rosseland opacity [cm2/g] + self.u_ee = self.u_u**2 + self.u_e = self.u_p # energy density units are the same as pressure units. + self.u_te = self.u_e / self.u_t * self.u_l # Box therm. em. [erg/(s ster cm2)] + self.u_n = 3.00e+10 # Density number n_0 * 1/cm^3 + self.pi = const.pi + self.u_b = self.u_u * np.sqrt(4. * self.pi * self.u_r) + self.u_tg = (self.m_h / self.k_b) * self.u_ee + self.u_tge = (self.m_e / self.k_b) * self.u_ee + + # set si units + self.usi_t = self.u_t + self.usi_l = self.u_l * const.centi # 1e-2 + self.usi_r = self.u_r * const.gram / const.centi**3 # 1e-4 + self.usi_u = self.usi_l / self.u_t + self.usi_p = self.usi_r * (self.usi_u)**2 # Pressure [N/m2] + self.usi_kr = 1 / (self.usi_r * self.usi_l) # Rosseland opacity [m2/kg] + self.usi_ee = self.usi_u**2 + self.usi_e = self.usi_p # energy density units are the same as pressure units. + self.usi_te = self.usi_e / self.u_t * self.usi_l # Box therm. em. [J/(s ster m2)] + self.msi_h = const.m_n # 1.674927471e-27 + self.msi_he = 6.65e-27 + self.msi_p = self.mu * self.msi_h # Mass per particle + self.usi_tg = (self.msi_h / self.ksi_b) * self.usi_ee + self.msi_e = const.m_e # 9.1093897e-31 + self.usi_b = self.u_b * 1e-4 + + # documentation for units above: + self.docu('u', 'velocity') + self.docu('p', 'pressure') + self.docu('kr', 'Rosseland opacity') + self.docu('ee', 'energy (total; i.e. not energy density)') + self.docu('e', 'energy density') + self.docu('te', 'Box therm. em. [J/(s ster m2)]') + self.docu('b', 'magnetic field') + + # setup self.uni. (tells how to convert from simu. units to cgs units, for simple vars.) + self.uni = {} + self.uni['l'] = self.u_l + self.uni['t'] = self.u_t + self.uni['rho'] = self.u_r + self.uni['p'] = self.u_r * self.u_u # self.u_p + self.uni['u'] = self.u_u + self.uni['e'] = self.u_e + self.uni['ee'] = self.u_ee + self.uni['n'] = self.u_n + self.uni['tg'] = 1.0 + self.uni['b'] = self.u_b + + # setup self.unisi + tools.convertcsgsi(self) + + # additional units (added for convenience) - started by SE, Apr 26 2021 + self.docu('m', 'mass') + self.u_m = self.u_r * self.u_l**3 # rho = mass / length^3 + self.usi_m = self.usi_r * self.usi_l**3 # rho = mass / length^3 + self.docu('ef', 'electric field') + self.u_ef = self.u_b # in cgs: F = q(E + (u/c) x B) + self.usi_ef = self.usi_b * self.usi_u # in SI: F = q(E + u x B) + self.docu('f', 'force') + self.u_f = self.u_p * self.u_l**2 # pressure = force / area + self.usi_f = self.usi_p * self.usi_l**2 # pressure = force / area + self.docu('q', 'charge') + self.u_q = self.u_f / self.u_ef # F = q E + self.usi_q = self.usi_f / self.usi_ef # F = q E + self.docu('nr', 'number density') + self.u_nr = self.u_r / self.u_m # nr = r / m + self.usi_nr = self.usi_r / self.usi_m # nr = r / m + self.docu('nq', 'charge density') + self.u_nq = self.u_q * self.u_nr + self.usi_nq = self.usi_q * self.usi_nr + self.docu('pm', 'momentum density') + self.u_pm = self.u_u * self.u_r # mom. dens. = mom * nr = u * r + self.usi_pm = self.usi_u * self.usi_r + self.docu('hz', 'frequency') + self.u_hz = 1./self.u_t + self.usi_hz = 1./self.usi_t + self.docu('phz', 'momentum density frequency (see e.g. momentum density exchange terms)') + self.u_phz = self.u_pm * self.u_hz + self.usi_phz = self.usi_pm * self.usi_hz + self.docu('i', 'current per unit area') + self.u_i = self.u_nq * self.u_u # ue = ... + J / (ne qe) + self.usi_i = self.usi_nq * self.usi_u + + # additional constants (added for convenience) + # masses + self.simu_amu = self.amu / self.u_m # 1 amu + self.simu_m_e = self.m_electron / self.u_m # 1 electron mass + # charge (1 elementary charge) + self.simu_q_e = self.q_electron / self.u_q # [derived from cgs] + self.simu_qsi_e = self.qsi_electron / self.usi_q # [derived from si] + # note simu_q_e != simu_qsi_e because charge is defined + # by different equations, for cgs and si. + # permeability (magnetic constant) (mu0) (We expect mu0_simu == 1.) + self.simu_mu0 = self.mu0si * (1/self.usi_b) * (self.usi_l) * (self.usi_i) + # J = curl(B) / mu0 --> mu0 = curl(B) / J --> [mu0] = [B] [length]^-1 [J]^-1 + # --> mu0[simu] / mu0[SI] = (B[simu] / B[SI]) * (L[SI] / L[simu]) * (J[SI]/J[simu]) + # boltzmann constant + self.simu_kB = self.ksi_b * (self.usi_nr / self.usi_e) # kB [simu energy / K] + + # update the dict doc_units with the values of units + self._update_doc_units_with_values() + + def _initialize_constant_lookup(self): + '''initialize self.constant_lookup, the lookup table for constants in various unit systems. + self.constant_lookup doesn't actually contain values; just the names of attributes. + In particular: + attr = self.constant_lookup[constant_name][unit_system] + getattr(self, attr) == value of constant_name in unit_system. + + Also creates constant_lookup_reverse, which tells constant_name given attr. + ''' + self.constant_lookup = collections.defaultdict(dict) + + def addc(c, doc=None, cgs=None, si=None, simu=None): + '''add constant c to lookup table with the units provided. + also adds documentation if provided.''' + if doc is not None: + self.docc(c, doc) + for key, value in (('cgs', cgs), ('si', si), ('simu', simu)): + if value is not None: + self.constant_lookup[c][key] = value + addc('amu', 'atomic mass unit', cgs='amu', si='amusi', simu='simu_amu') + addc('m_e', 'electron mass', cgs='m_electron', si='msi_electron', simu='simu_m_e') + addc('q_e', 'elementary charge derived from cgs', cgs='q_electron', simu='simu_q_e') + addc('qsi_e', 'elementary charge derived from si', si='qsi_electron', simu='simu_qsi_e') + addc('mu0', 'magnetic constant', si='mu0si', simu='simu_mu0') + addc('kB', 'boltzmann constant', cgs='k_b', si='ksi_b', simu='simu_kB') + addc('eps0', 'permittivity in vacuum', si='permsi') + # << [TODO] put more relevant constants here. + + # update the dict doc_constants with the values of constants + self._update_doc_constants_with_values() + + # include reverse-lookup for convenience. + rlookup = {key: c for (c, d) in self.constant_lookup.items() for key in (c, *d.values())} + self.constant_lookup_reverse = rlookup + + ## PRETTY REPR AND PRINTING ## + def __repr__(self): + '''show self in a pretty way (i.e. including info about base units)''' + return "<{} with base_units=dict({})>".format(type(self), + self.prettyprint_base_units(printout=False)) + + def base_units(self): + '''returns dict of u_l, u_t, u_r, gamma, for self.''' + return {unit: getattr(self, unit) for unit in self.BASE_UNITS} + + def prettyprint_base_units(self, printout=True): + '''print (or return, if not printout) prettystring for base_units for self.''' + fmt = '{:.2e}' # formatting for keys (except gamma) + fmtgam = '{}' # formatting for key gamma + s = [] + for unit in self.BASE_UNITS: + val = getattr(self, unit) + if unit == 'gamma': + valstr = fmtgam.format(val) + else: + valstr = fmt.format(val) + s += [unit+'='+valstr] + result = ', '.join(s) + if printout: + print(result) + else: + return (result) + + ## DOCS ## + def docu(self, u, doc): + '''documents u by adding u=doc to dict self.doc_units''' + self.doc_units[u] = doc + + def docc(self, c, doc): + '''documents c by adding c=doc to dict self.doc_constants''' + self.doc_constants[c] = doc + + def _unit_name(self, u): + '''returns name of unit u. e.g. u_r -> 'r'; usi_hz -> 'hz', 'nq' -> 'nq'.''' + for prefix in ['u_', 'usi_']: + if u.startswith(prefix): + u = u[len(prefix):] + break + return u + + def _unit_values(self, u): + '''return values of u, as a dict. + (checks u, 'u_'+u, and 'usi_'+u) + ''' + u = self._unit_name(u) + result = {} + u_u = 'u_'+u + usi_u = 'usi_'+u + result = {key: getattr(self, key) for key in [u, u_u, usi_u] if hasattr(self, key)} + return result + + def _unit_exists(self, ustr): + '''returns whether u_ustr or usi_ustr is an attribute of self.''' + return hasattr(self, 'u_'+ustr) or hasattr(self, 'usi_'+ustr) + + def prettyprint_unit_values(self, x, printout=True, fmtname='{:<3s}', fmtval='{:.2e}', sep=' ; '): + '''pretty string for unit values. print if printout, else return string.''' + if isinstance(x, str): + x = self._unit_values(x) + result = [] + for key, value in x.items(): + u_, p, name = key.partition('_') + result += [u_ + p + fmtname.format(name) + ' = ' + fmtval.format(value)] + result = sep.join(result) + if printout: + print(result) + else: + return result + + def _update_doc_units_with_values(self, sep=' | ', fmtdoc='{:20s}'): + '''for u in self.doc_units, update self.doc_units[u] with values of u.''' + for u, doc in self.doc_units.items(): + valstr = self.prettyprint_unit_values(u, printout=False) + if len(valstr) > 0: + doc = sep.join([fmtdoc.format(doc), valstr]) + self.doc_units[u] = doc + + def _constant_name(self, c): + '''returns name corresponding to c in self.constants_lookup.''' + return self.constant_lookup_reverse[c] + + def _constant_keys_and_values(self, c): + '''return keys and values for c, as a dict.''' + try: + clookup = self.constant_lookup[c] + except KeyError: + raise ValueError(f'constant not found in constant_lookup table: {repr(c)}') from None + else: + return {usys: (ckey, getattr(self, ckey)) for (usys, ckey) in clookup.items()} + + def prettyprint_constant_values(self, x, printout=True, fmtname='{:<10s}', fmtval='{:.2e}', sep=' ; '): + '''pretty string for constant values. print if printout, else return string.''' + if isinstance(x, str): + x = self._constant_keys_and_values(x) + result = [] + for ckey, cval in x.values(): + result += [fmtname.format(ckey) + ' = ' + fmtval.format(cval)] + result = sep.join(result) + if printout: + print(result) + else: + return result + + def _update_doc_constants_with_values(self, sep=' | ', fmtdoc='{:20s}'): + '''for c in self.doc_constants, update self.doc_constants[c] with values of c.''' + for c, doc in self.doc_constants.items(): + valstr = self.prettyprint_constant_values(c, printout=False) + if len(valstr) > 0: + doc = sep.join([fmtdoc.format(doc), valstr]) + self.doc_constants[c] = doc + + ## HELP AND HELPFUL METHODS ## + def help(self, u=None, printout=True, fmt='{:5s}: {}'): + '''prints documentation for u, or all units if u is None. + u: None, string, or list of strings + specify which attributes you want help with. + None --> provide help with all attributes. + 'units' --> provide help with all unit conversions + 'constants' --> provide help with all constants. + string --> provide help with this (or directly related) attribute. + list of strings --> provide help with these (or directly related) attributes. + + printout=False --> return dict, instead of printing. + ''' + if u is None: + result = self.doc_all + elif u == 'units': + result = self.doc_units + elif u == 'constants': + result = self.doc_constants + else: + if isinstance(u, str): + u = [u] + result = dict() + for unit in u: + try: + name = self._unit_name(unit) + doc = self.doc_units[name] + except KeyError: + try: + name = self._constant_name(unit) + doc = self.doc_constants[name] + except KeyError: + doc = f"u={repr(name)} is not yet documented (maybe it doesn't exist?)" + result[name] = doc + if not printout: + return result + else: + if len(result) > 1: # (i.e. getting help on multiple things) + print('Retrieve values by calling self(key, unit system).', + f'Recognized unit systems are: {UNIT_SYSTEMS}.', + 'See help(self.__call__) for more details.', sep='\n', end='\n\n') + for key, doc in result.items(): + print(fmt.format(key, doc)) + + def closest(self, value, sign_sensitive=True, reltol=1e-8): + '''returns [(attr, value)] for attr(s) in self whose value is closest to value. + sign_sensitive: True (default) or False + whether to care about signs (plus or minus) when comparing values + reltol: number (default 1e-8) + if multiple attrs are closest, and all match (to within reltol) + return a list of (attr, value) pairs for all such attrs. + closeness is determined by doing ratios. + ''' + result = [] + best = np.inf + for key, val in self.__dict__.items(): + if val == 0: + if value != 0: + continue + else: + result += [(key, val)] + try: + rat = value / val + except TypeError: + continue + if sign_sensitive: + rat = abs(rat) + compare_val = abs(rat - 1) + if best == 0: # we handle this separately to prevent division by 0 error. + if compare_val < reltol: + result += [(key, val)] + elif abs(1 - compare_val / best) < reltol: + result += [(key, val)] + elif compare_val < best: + result = [(key, val)] + best = compare_val + return result + + +''' ============================= UNITS OUTPUT SYSTEM ============================= ''' + + +''' ----------------------------- Hesitant Execution ----------------------------- ''' +# in this section, establish objects which can be combined like math terms but +# create a function which can be evaluated later, instead of requiring evaluation right away. +# See examples in FuncBuilder documentation. + + +class FuncBuilder: + '''use this object to build a function one arg / kwarg at a time. + + use attributes for kwargs, indices for args. + + Examples: + u = FuncBuilder() + f = ((u.x + u.r) * u[1]) # f is a function which does: return (kwarg['x'] + kwarg['r']) * arg[1] + f(None, 10, x=3, r=5) + 80 + f = u[0] + u[1] + u[2] # f is a function which adds the first three args and returns the result. + f(2, 3, 4) + 9 + f = u[0] ** u[1] # f(a,b) is equivalent to a**b + f(7,2) + 49 + + Technically, this object returns Funclike objects, not functions. + That means you can combine different FuncBuilder results into a single function. + Example: + u = FuncBuilder() + f1 = (u.x + u.r) * u[1] + f2 = u.y + u[0] + u[1] + f = f1 - f2 + f(0.1, 10, x=3, r=5, y=37) + 32.9 # ((3 + 5) * 10) - (37 + 0.1 + 10) + ''' + + def __init__(self, FunclikeType=None, **kw__funclike_init): + '''convert functions to funcliketype object.''' + self.FunclikeType = Funclike if (FunclikeType is None) else FunclikeType + self._kw__funclike_init = kw__funclike_init + + def __getattr__(self, a): + '''returns f(*args, **kwargs) --> kwargs[a].''' + def f_a(*args, **kwargs): + '''returns kwargs[{a}]''' + try: + return kwargs[a] + except KeyError: + message = 'Expected kwargs to contain key {} but they did not!'.format(repr(a)) + raise KeyError(message) from None + f_a.__doc__ = f_a.__doc__.replace('{a}', repr(a)) + f_a.__name__ = a + return self.FunclikeType(f_a, required_kwargs=[a], **self._kw__funclike_init) + + def __getitem__(self, i): + '''returns f(*args, **kwargs) --> args[i]. i must be an integer.''' + def f_i(*args, **kwargs): + '''returns args[{i}]''' + try: + return args[i] + except IndexError: + raise IndexError('Expected args[{}] to exist but it did not!'.format(i)) + f_i.__doc__ = f_i.__doc__.replace('{i}', repr(i)) + f_i.__name__ = 'arg' + str(i) + return self.FunclikeType(f_i, required_args=[i], **self._kw__funclike_init) + + +def make_Funclike_magic(op, op_name=None, reverse=False): + '''makes magic funclike for binary operator + it will be named magic + op_name. + + Example: + f = make_Funclike_magic(operator.__mul__, '__times__') + a function named magic__times__ which returns a Funclike object that does a * b. + + make_Funclike_magic is a low-level function which serves as a helper function for the Funclike class. + + make_Funclike_magic returns a function of (a, b) which returns a Funclike-like object that does op(a, b). + type(result) will be type(a) unless issubclass(b, a), in which case it will be type(b). + + if reverse, we aim to return methods to use with r magic, e.g. __radd__. + ''' + def name(x): # handle naming... + if callable(x): + if hasattr(x, '__name__'): + return getattr(x, '__name__') + else: + return type(x).__name__ + else: + result = str(x) + if len(result) > 5: + result = 'value' # x name too long; use string 'value' instead. + return result + + def magic(a, b): + type_A0, type_B0 = type(a), type(b) # types before (possibly) reversing + if reverse: + (a, b) = (b, a) + # apply operation + + def f(*args, **kwargs): + __tracebackhide__ = HIDE_INTERNAL_TRACEBACKS + a_val = a if not callable(a) else a(*args, **kwargs) + b_val = b if not callable(b) else b(*args, **kwargs) + return op(a_val, b_val) + # set name of f + if op_name is not None: + f.__name__ = '(' + name(a) + op_name + name(b) + ')' + # typecast f to appropriate type of Funclike-like object, and return it. + if issubclass(type_B0, type_A0): + ReturnType = type_B0 + else: + ReturnType = type_A0 + result = ReturnType(f, parents=[a, b]) + return result + if op_name is not None: + magic.__name__ = 'magic' + op_name + return magic + + +class Funclike: + '''function-like object. Useful for combining with other Funclike objects. + Allows for "hesitant execution": + The args and kwargs do not need to be known until later. + Evaluate whenever the instance is called like a function. + + Example: + # --- basic example --- + getx = lambda *args, **kwargs: kwargs['x'] + funclike_getx = Funclike(getx) + mult_x_by_2 = funclike_getx * 2 + mult_x_by_2(x=7) + 14 # 7 * 2 + # --- another basic example --- + gety = lambda *args, **kwargs: kwargs['y'] + funclike_gety = Funclike(gety) + get0 = lambda *args, **kwargs: args[0] + funclike_get0 = Funclike(get0) + add_arg0_to_y = funclike_get0 + funclike_gety + add_arg0_to_y(3, y=10) + 13 # 3 + 10 + # --- combine the basic examples --- + add_arg0_to_y_then_subtract_2x = add_arg0_to_y - mult_x_by_2 + add_arg0_to_y_then_subtract_2x(7, y=8, x=50) + -85 # (7 + 8) - 50 * 2 + ''' + + def __init__(self, f, required_args=[], required_kwargs=[], parents=[]): + self._tracebackhide = HIDE_INTERNAL_TRACEBACKS + self.f = f + self.__name__ = f.__name__ + self._required_args = required_args # list of args which must be provided for a function call to self. + self._required_kwargs = required_kwargs # list of kwargs which must be provided for a function call to self. + for parent in parents: + parent_req_args = getattr(parent, '_required_args', []) + parent_req_kwargs = getattr(parent, '_required_kwargs', []) + if len(parent_req_args) > 0: + self._add_to_required('_required_args', parent_req_args) + if len(parent_req_kwargs) > 0: + self._add_to_required('_required_kwargs', parent_req_kwargs) + + def _add_to_required(self, original_required, new_required): + orig = getattr(self, original_required, []) + setattr(self, original_required, sorted(list(set(orig + new_required)))) + + # make Funclike behave like a function (i.e. it is callable) + def __call__(self, *args, **kwargs): + __tracebackhide__ = self._tracebackhide + return self.f(*args, **kwargs) + + # make Funclike behave like a number (i.e. it can participate in arithmetic) + __mul__ = make_Funclike_magic(operator.__mul__, ' * ') # multiply + __add__ = make_Funclike_magic(operator.__add__, ' + ') # add + __sub__ = make_Funclike_magic(operator.__sub__, ' - ') # subtract + __truediv__ = make_Funclike_magic(operator.__truediv__, ' / ') # divide + __pow__ = make_Funclike_magic(operator.__pow__, ' ** ') # raise to a power + + __rmul__ = make_Funclike_magic(operator.__mul__, ' * ', reverse=True) # rmultiply + __radd__ = make_Funclike_magic(operator.__add__, ' + ', reverse=True) # radd + __rsub__ = make_Funclike_magic(operator.__sub__, ' - ', reverse=True) # rsubtract + __rtruediv__ = make_Funclike_magic(operator.__truediv__, ' / ', reverse=True) # rdivide + + def _strinfo(self): + '''info about self. (goes to repr)''' + return 'required_args={}, required_kwargs={}'.format( + self._required_args, self._required_kwargs) + + def __repr__(self): + return f'(Funclike with operation {repr(self.__name__)})' + + def _repr_adv_(self): + '''very detailed repr of self.''' + return '<{} named {} with {}>'.format( + object.__repr__(self), repr(self.__name__), self._strinfo()) + + +class AttrsFunclike(Funclike): + '''Funclike but treat args[argn] as obj, and use obj attrs for un-entered required kwargs. + + argn: int (default 0) + treat args[argn] as obj. + format_attr: string, or None. + string --> format this string to all required kwargs before checking if they are attrs of obj. + E.g. format = 'usi_{}' --> if looking for 'r', check obj.usi_r and kwargs['r']. + None --> by default, don't mess with kwargs names at all. + However, if special kwarg ATTR_FORMAT_KWARG (defined at top of units.py) + is passed to this function, use its value to format the required kwargs. + ''' + + def __init__(self, f, argn=0, format_attr=None, **kw__funclike_init): + '''f should be a Funclike object.''' + Funclike.__init__(self, f, **kw__funclike_init) + self.argn = argn + self.format_attr = format_attr + self._add_to_required('_required_args', [argn]) + self._add_to_required('_required_args_special', [argn]) + f = self.f + required_kwargs = self._required_kwargs + + def f_attrs(*args, **kwargs): + __tracebackhide__ = self._tracebackhide + obj = args[argn] + kwdict = kwargs + if self.format_attr is None: + format_attr = kwargs.get(ATTR_FORMAT_KWARG, '{}') + else: + format_attr = self.format_attr + # for any required kwargs which haven't been entered, + # try to use the attribute (format_attr.format(kwarg)) of obj, if possible. + for kwarg in required_kwargs: + if kwarg not in kwargs: + attr_name = format_attr.format(kwarg) + if hasattr(obj, attr_name): + kwdict[kwarg] = getattr(obj, attr_name) + return f(*args, **kwdict) + f_attrs.__name__ = f.__name__ # TODO: add something to the name to indicate it is an AttrsFunclike. + Funclike.__init__(self, f_attrs, self._required_args, self._required_kwargs) + + # TODO: would it be cleaner to keep the original f and just override the __call__ method? + # Doing so may allow to move the args & kwargs checking to __call__, + # i.e. check that we have required_args & required_kwargs, else raise error. + + def _special_args_info(self): + '''info about the meaning of special args for self.''' + return 'arg[{argn}] attrs can replace missing required kwargs.'.format(argn=self.argn) + + def _strinfo(self): + '''info about self. (goes to repr)''' + return 'required_args={}, required_kwargs={}. special_args={}: {}'.format( + self._required_args, self._required_kwargs, + self._required_args_special, self._special_args_info()) + + +''' -------------------------------------------------------------------------- ''' +''' ----------------------------- Units-Specific ----------------------------- ''' +''' -------------------------------------------------------------------------- ''' +# Up until this point in the file, the codes have been pretty generic. +# The codes beyond this section, though, are specific to the units implementation in helita. + +''' ----------------------------- Units Naming ----------------------------- ''' + +# string manipulation helper functions + + +def _pretty_str_unit(name, value, flip=False, flip_neg=False): + '''returns string for name, value. name is name of unit; value is exponent. + flip --> pretend result is showing up in denominator (i.e. multiply exponent by -1). + flip_neg --> flip only negative values. (Always give a positive exponent result.) + + return (string, whether it was flipped). + ''' + flipped = False + if value == 0: + result = '' + elif flip or ((value < 0) and flip_neg): + result = _pretty_str_unit(name, -1 * value, flip=False)[0] + flipped = True + elif value == 1: + result = name + else: + result = name + '^{' + str(value) + '}' + return (result, flipped) + + +def _join_strs(strs, sep=' '): + '''joins strings, separating by sep. Ignores Nones and strings of length 0.''' + ss = [s for s in strs if (s is not None) and (len(s) > 0)] + return sep.join(ss) + + +class UnitsExpression(): + '''expression of units. + + Parameters + ---------- + contents: dict + keys = unit name; values = exponent for that unit. + order: string (default 'entered') + determines order in which units are printed. Options are: + 'entered' --> use the order in which the keys appear in contents. + 'exp' --> order by exponent (descending by default, i.e. largest first). + 'absexp' --> order by abs(exponent) (decending by default). + 'alpha' --> order alphabetically (a to z by default). + + TODO: make ordering options "clearer" (use enum or something like that?) + + TODO: make display mode options (e.g. "latex", "pythonic", etc) + ''' + + def __init__(self, contents=collections.OrderedDict(), order='entered', frac=True, unknown=False): + self.contents = contents + self.order = order + self.frac = frac # whether to show negatives in denominator + self.unknown = unknown # whether the units are actually unknown. + self._mode = None # mode for units. unused unless unknown is True. + + def _order_exponent(self, ascending=False): + '''returns keys for self.contents, ordered by exponent. + not ascending --> largest first; ascending --> largest last. + ''' + return sorted(list(self.contents.keys()), + key=lambda k: self.contents[k], reverse=not ascending) + + def _order_abs_exponent(self, ascending=False): + '''returns keys for self.contents, ordered by |exponent|. + not ascending --> largest first; ascending --> largest last. + ''' + return sorted(list(self.contents.keys()), + key=lambda k: abs(self.contents[k]), reverse=not ascending) + + def _order_alphabetical(self, reverse=False): + '''returns keys for self.contents in alphabetical order. + not reverse --> a first; reverse --> a last. + ''' + # TODO: handle case of '$' included in key name (e.g. for greek letters) + return sorted(list(self.contents.keys()), reverse=reverse) + + def _order_entered(self, reverse=False): + '''returns keys for self.contents in order entered. + reverse: whether to reverse the order. + ''' + return list(self.contents.keys()) + + def _pretty_str_key(self, key, flip=False, flip_neg=False): + '''determine string given key (a unit's name). + flip --> always flip; i.e. multiply value by -1. + flip_neg --> only flip negative values. + + returns (string, whether it was flipped) + ''' + return _pretty_str_unit(key, self.contents[key], flip=flip, flip_neg=flip_neg) + + def __str__(self): + '''str of self: pretty string telling the units which self represents.''' + if self.unknown: + if self._mode is None: + return '???' + else: + return self._mode + if self.order == 'exp': + key_order = self._order_exponent() + elif self.order == 'absexp': + key_order = self._order_abs_exponent() + elif self.order == 'alpha': + key_order = self._order_alphabetical() + elif self.order == 'entered': + key_order = self._order_entered() + else: + errmsg = ("self.order is not a valid order! For valid choices," + "see help(UnitsExpression). (really helita.sim.units.UnitsExpression)") + raise ValueError(errmsg) + x = [self._pretty_str_key(key, flip_neg=self.frac) for key in key_order] + numer = [s for s, flipped in x if not flipped] + denom = [s for s, flipped in x if flipped] # and s != '' + numer_str = _join_strs(numer, ' ') + if len(denom) == 0: + result = numer_str + else: + if len(numer) == 0: + numer_str = '1' + if len(denom) == 1: + result = numer_str + ' / ' + denom[0] + else: + denom_str = _join_strs(denom, ' ') + result = numer_str + ' / (' + denom_str + ')' + return result + + def __repr__(self): + return repr(str(self)) + + def _repr_adv_(self): + '''very detailed repr of self.''' + return "<{} with content = '{}'>".format(object.__repr__(self), str(self)) + + def __mul__(self, b): + '''multiplication with b (another UnitsExpression object).''' + result = self.contents.copy() + if not isinstance(b, UnitsExpression): + raise TypeError('Expected UnitsExpression type but got type={}'.format(type(b))) + for key, val in b.contents.items(): + try: + result[key] += val + except KeyError: + result[key] = val + unknown = self.unknown or getattr(b, 'unknown', False) + return UnitsExpression(result, order=self.order, frac=self.frac, unknown=unknown) + + def __truediv__(self, b): + '''division by b (another UnitsExpression object).''' + result = self.contents.copy() + if not isinstance(b, UnitsExpression): + raise TypeError('Expected UnitsExpression type but got type={}'.format(type(b))) + for key, val in b.contents.items(): + try: + result[key] -= val + except KeyError: + result[key] = -1 * val + unknown = self.unknown or getattr(b, 'unknown', False) + return UnitsExpression(result, order=self.order, frac=self.frac, unknown=unknown) + + def __pow__(self, b): + '''raising to b (a number).''' + result = self.contents.copy() + for key in result.keys(): + result[key] *= b + unknown = self.unknown or getattr(b, 'unknown', False) + return UnitsExpression(result, order=self.order, frac=self.frac, unknown=unknown) + + def __call__(self, *args, **kwargs): + '''return self. For compatibility with UnitsExpressionDict. + Also, sets self._mode (print string if unknown units) based on UNITS_KEY_KWARG if possible.''' + units_key = kwargs.get(UNITS_KEY_KWARG, None) + if units_key == 'usi': + self._mode = 'SI units' + elif units_key == 'ucgs': + self._mode = 'cgs units' + else: + self._mode = 'simu. units' + return self + + +class UnitSymbol(UnitsExpression): + '''symbol for a single unit. + + UnitSymbol('x') is like UnitsExpression(contents=collections.OrderedDict(x=1)) + + Example: + for 'V^{2} / m', one would enter: + result = units.UnitSymbol('V')**2 / units.UnitSymbol('m') + to set printout settings, attributes can be editted directly: + result.order = 'exp' + result.frac = True + to see contents, convert to string: + str(result) + 'V^{2} / m' + ''' + + def __init__(self, name, *args, **kwargs): + self.name = name + contents = collections.OrderedDict() + contents[name] = 1 + UnitsExpression.__init__(self, contents, *args, **kwargs) + + +UnitsSymbol = UnitSymbol # alias + + +def UnitSymbols(names, *args, **kwargs): + '''returns UnitSymbol(name, *args, **kwargs) for name in names. + names can be a string or list: + string --> treat names as names.split() + list --> treat names list of names. + + Example: + V, m, s = units.UnitSymbols('V m s', order='absexp') + str(V**2 / s * m**-4) + 'V^{2} / (m^{4} s)' + ''' + if isinstance(names, str): + names = names.split() + return tuple(UnitSymbol(name, *args, **kwargs) for name in names) + + +UnitsSymbols = UnitSymbols # alias + + +class UnitsExpressionDict(UnitsExpression): + '''expressions of units, but in multiple unit systems. + + Contains multiple UnitsExpression. + ''' + + def __init__(self, contents=dict(), **kw__units_expression_init): + '''contents should be a dict with: + keys = units_keys; + when UnitsExpressionDict is called, it returns contents[kwargs[UNITS_KEY_KWARG]] + values = dicts or UnitsExpression objects; + dicts in contents are used to make a UnitsExpression, while + UnitsExpressions in contents are saved as-is. + ''' + self.contents = dict() + for key, val in contents.items(): + if isinstance(val, UnitsExpression): # already a UnitsExpression; don't need to convert. + self.contents[key] = val + else: # not a UnitsExpression; must convert. + self.contents[key] = UnitsExpression.__init__(val, **kw__units_expression_init) + self._kw__units_expression_init = kw__units_expression_init + + def __repr__(self): + '''pretty string of self.''' + return str({key: str(val) for (key, val) in self.contents.items()}) + + def __mul__(self, b): + '''multiplication of self * b. (b is another UnitsExpression or UnitsExpressionDict object).''' + result = dict() + if isinstance(b, UnitsExpressionDict): + assert b.contents.keys() == self.contents.keys() # must have same keys to multiply dicts. + for key, uexpr in b.contents.items(): + result[key] = self.contents[key] * uexpr + elif isinstance(b, UnitsExpression): + for key in self.contents.keys(): + result[key] = self.contents[key] * b + else: + raise TypeError('Expected UnitsExpression or UnitsExpressionDict type but got type={}'.format(type(b))) + return UnitsExpressionDict(result, **self._kw__units_expression_init) + + def __truediv__(self, b): + '''division of self / b. (b is another UnitsExpression or UnitsExpressionDict object).''' + result = dict() + if isinstance(b, UnitsExpressionDict): + assert b.contents.keys() == self.contents.keys() # must have same keys to multiply dicts. + for key, uexpr in b.contents.items(): + result[key] = self.contents[key] / uexpr + elif isinstance(b, UnitsExpression): + for key in self.contents.keys(): + result[key] = self.contents[key] / b + else: + raise TypeError('Expected UnitsExpression or UnitsExpressionDict type but got type={}'.format(type(b))) + return UnitsExpressionDict(result, **self._kw__units_expression_init) + + def __pow__(self, b): + '''raising to b (a number).''' + result = dict() + for key, internal_uexpr in self.contents.items(): + result[key] = internal_uexpr ** b + return UnitsExpressionDict(result, **self._kw__units_expression_init) + + # handle cases of (b * a) and (b / a), for b=UnitsExpression(...), a=UnitsExpressionDict(...). + # b * a --> TypeError --> try a.__rmul__(b). + __rmul__ = __mul__ + + # b / a --> TypeError --> try a.__rtrudiv__(b). + def __rtruediv__(self, b): + '''division of b / self. (b is another UnitsExpression or UnitsExpressionDict object).''' + result = dict() + if isinstance(b, UnitsExpressionDict): + # < we should probably never reach this section but I'm keeping it for now... + assert b.contents.keys() == self.contents.keys() # must have same keys to multiply dicts. + for key, uexpr in b.contents.items(): + result[key] = uexpr / self.contents[key] + elif isinstance(b, UnitsExpression): + for key in self.contents.keys(): + result[key] = b / self.contents[key] + else: + raise TypeError('Expected UnitsExpression or UnitsExpressionDict type but got type={}'.format(type(b))) + return UnitsExpressionDict(result, **self._kw__units_expression_init) + + # call self (return the correct UnitsExpression based on UNITS_KEY_KWARG) + def __call__(self, *args, **kwargs): + '''return self.contents[kwargs[UNITS_KEY_KWARG]]. + in other words, return the relevant UnitsExpression, based on units_key. + ''' + units_key = kwargs[UNITS_KEY_KWARG] + uexpr = self.contents[units_key] # relevant UnitsExpression based on units_key + return uexpr + + +class UnitSymbolDict(UnitsExpressionDict): + '''a dict of symbols for unit. + + UnitSymbolDict(usi='m', ucgs='cm') is like: + UnitsExpressionDict(contents=dict(usi=UnitSymbol('m'), ucgs=UnitSymbol('cm')) + + the properties kwarg is passed to UnitsExpressionDict.__init__() as **properties. + ''' + + def __init__(self, properties=dict(), **symbols_dict): + self.symbols_dict = symbols_dict + contents = {key: UnitSymbol(val) for (key, val) in symbols_dict.items()} + UnitsExpressionDict.__init__(self, contents, **properties) + + +# make custom error class for when units are not found. +class UnitsNotFoundError(Exception): + '''base class for telling that units have not been found.''' + + +def _default_units_f(info=''): + def f(*args, **kwargs): + errmsg = ("Cannot calculate units. Either the original quant's units are unknown," + " or one of the required children quants' units are unknown.\n" + "Further info provided: " + str(info)) + raise UnitsNotFoundError(errmsg) + return Funclike(f) + + +DEFAULT_UNITS_F = _default_units_f() +DEFAULT_UNITS_NAME = UnitSymbol('???', unknown=True) + +''' ----------------------------- Units Tuple ----------------------------- ''' + +UnitsTupleBase = collections.namedtuple('Units', ('f', 'name', 'evaluated'), + defaults=[DEFAULT_UNITS_F, DEFAULT_UNITS_NAME, False] + ) + + +def make_UnitsTuple_magic(op, op_name=None, reverse=False): + '''makes magic func for binary operator acting on UnitsTuple object. + it will be named magic + op_name. + + make_UnitsTuple_magic is a low-level function which serves as a helper function for the UnitsTuple class. + ''' + def magic(a, b): + # reverse if needed + if reverse: + (a, b) = (b, a) + # get f and name for a and b + if isinstance(a, UnitsTuple): + a_f, a_name = a.f, a.name + else: + a_f, a_name = a, a + if isinstance(b, UnitsTuple): + b_f, b_name = b.f, b.name + else: + b_f, b_name = b, b + # apply operation + f = op(a_f, b_f) + name = op(a_name, b_name) + return UnitsTuple(f, name) + # rename magic (based on op_name) + if op_name is not None: + magic.__name__ = 'magic' + op_name + return magic + + +class UnitsTuple(UnitsTupleBase): + '''UnitsTuple tells: + f: Funclike (or constant...). Call this to convert to the correct units. + name: UnitsExpression object which gives name for units, e.g. str(UnitsTuple().name) + + Additionally, multiplying, dividing, or exponentiating with another UnitsTuple works intuitively: + op(UnitsTuple(a1,b1), UnitsTuple(a2,b2)) = UnitsTuple(op(a1,a2), op(b1,b2)) for op in *, /, **. + And if the second object is not a UnitsTuple, the operation is distributed instead: + op(UnitsTuple(a1,b1), x) = UnitsTuple(op(a1,x), op(b1,x)) for op in *, /, **. + ''' + + # make Funclike behave like a number (i.e. it can participate in arithmetic) + __mul__ = make_UnitsTuple_magic(operator.__mul__, ' * ') # multiply + __add__ = make_UnitsTuple_magic(operator.__add__, ' + ') # add + __sub__ = make_UnitsTuple_magic(operator.__sub__, ' - ') # subtract + __truediv__ = make_UnitsTuple_magic(operator.__truediv__, ' / ') # divide + __pow__ = make_UnitsTuple_magic(operator.__pow__, ' ** ') # raise to a power + + __rmul__ = make_UnitsTuple_magic(operator.__mul__, ' * ', reverse=True) # rmultiply + __radd__ = make_UnitsTuple_magic(operator.__add__, ' + ', reverse=True) # radd + __rsub__ = make_UnitsTuple_magic(operator.__sub__, ' - ', reverse=True) # rsubtract + __rtruediv__ = make_UnitsTuple_magic(operator.__truediv__, ' / ', reverse=True) # rdivide + + # make Funclike behave like a function (i.e. it is callable) + def __call__(self, *args, **kwargs): + if callable(self.name): # if self.name is a UnitsExpressionDict + name = self.name(*args, **kwargs) # then, call it. + else: # otherwise, self.name is a UnitsExpression. + name = self.name # so, don't call it. + factor = self.f(*args, **kwargs) + return UnitsTuple(factor, name, evaluated=True) + + # representation + def __repr__(self): + return f'{type(self).__name__}(f={self.f}, name={repr(self.name)}, evaluated={self.evaluated})' + + +''' ----------------------------- Dimensionless Tuple ----------------------------- ''' +# in this section is a units tuple which should be used for dimensionless quantities. + + +def dimensionless_units_f(*args, **kwargs): + '''returns 1, regardless of args and kwargs.''' + return 1 + + +DIMENSIONLESS_UNITS = Funclike(dimensionless_units_f) + +DIMENSIONLESS_NAME = UnitsExpression() + +DIMENSIONLESS_TUPLE = UnitsTuple(DIMENSIONLESS_UNITS, DIMENSIONLESS_NAME) + + +''' ----------------------------- Units FuncBuilder ----------------------------- ''' + + +class UnitsFuncBuilder(FuncBuilder): + '''FuncBuilder but also qc attribute will get quant children of obj. + FunclikeType must be (or be a subclass of) AttrsFunclike. + ''' + + def __init__(self, FunclikeType=AttrsFunclike, units_key=None, **kw__funclike_init): + FuncBuilder.__init__(self, FunclikeType=FunclikeType, **kw__funclike_init) + self.units_key = units_key + + def _quant_child(self, i, oldest_first=True, return_type='tuple'): + '''returns a Funclike which gets i'th quant child in QUANTS_TREE for object=args[1]. + + not intended to be called directly; instead use alternate functions as described below. + + return_type: string (default 'tuple') + 'tuple' --> return a UnitsTuple object. + (alternate funcs: quant_child, qc) + 'ufunc' --> return the units function only. (UnitsTuple.f) + (alternate funcs: quant_child_units, qcu) + 'name' --> return the units name only. (UnitsTuple.name) + (alternate funcs: quant_child_name, qcn) + ''' + return_type = return_type.lower() + assert return_type in ('tuple', 'ufunc', 'name'), 'Got invalid return_type(={})'.format(repr(return_type)) + + def f_qc(obj_uni, obj, quant_tree, *args, **kwargs): + '''gets quant child number {i} from quant tree of obj, + sorting from i=0 as {age0} to i=-1 as {agef}. + ''' + #print('f_qc called with uni, obj, args, kwargs:', obj_uni, obj, *args, **kwargs) + __tracebackhide__ = self._tracebackhide + child_tree = quant_tree.get_child(i, oldest_first) + if self.units_key is None: + units_key = kwargs[UNITS_KEY_KWARG] + else: + units_key = self.units_key + units_tuple = _units_lookup_by_quant_info(obj, child_tree.data, units_key) + result = units_tuple(obj_uni, obj, child_tree, *args, **kwargs) + if return_type == 'ufunc': + return result.f + elif return_type == 'name': + return result.name + elif return_type == 'tuple': + return result + # make pretty documentation for f_qc. + youngest, oldest = 'youngest (added most-recently)', 'oldest (added first)' + age0, agef = (oldest, youngest) if oldest_first else (youngest, oldest) + f_qc.__doc__ = f_qc.__doc__.format(i=i, age0=age0, agef=agef) + f_qc.__name__ = 'child_' + str(i) + '__' + ('oldest' if oldest_first else 'youngest') + '_is_0' + _special_args_info = 'arg[1] is assumed to be an obj with attribute got_vars_tree() which returns a QuantTree.' + required_kwargs = [UNITS_KEY_KWARG] if self.units_key is None else [] + # return f_qc + f_qc = self.FunclikeType(f_qc, argn=1, required_kwargs=required_kwargs, **self._kw__funclike_init) + f_qc._special_args_info = lambda *args, **kwargs: _special_args_info + return f_qc + + def quant_child(self, i, oldest_first=True): + '''returns a Funclike which gets units tuple of i'th quant child in QUANTS_TREE for object=args[1].''' + return self._quant_child(i, oldest_first, return_type='tuple') + + qc = quant_child # alias + + def quant_child_f(self, i, oldest_first=True): + '''returns a Funclike which gets units func for i'th quant child in QUANTS_TREE for object=args[1].''' + return self._quant_child(i, oldest_first, return_type='ufunc') + + qcf = quant_child_f # alias + + def quant_child_name(self, i, oldest_first=True): + '''returns a Funclike which gets units name for i'th quant child in QUANTS_TREE for object=args[1].''' + return self._quant_child(i, oldest_first, return_type='name') + + qcn = quant_child_name # alias + + +def _units_lookup_by_quant_info(obj, info, units_key=UNITS_UNIVERSAL_KEY, + default_f=_default_units_f, default_name=DEFAULT_UNITS_NAME): + '''given obj, gets UnitsTuple from QuantInfo info. + We are trying to get: + obj.vardict[info.metaquant][info.typequant][info.quant][units_key]. + if we fail, try again with units_key = 'uni'. + 'uni' will be used to represent units which are the same in any system. + + if we fail again, + return default # which, by default, is the default UnitsTuple() defined in units.py. + ''' + x = obj.quant_lookup(info) # x is the entry in vardict for QuantInfo info. + keys_to_check = (units_key, UNITS_UNIVERSAL_KEY) # these are the keys we will try to find in x. + # try to get units tuple: + utuple = _multiple_lookup(x, *keys_to_check, default=None) + if utuple is not None: + return utuple + # else: # failed to get units tuple. + # try to get units f and units name separately. + keys_to_check_f = [UNITS_KEY_F.format(key) for key in keys_to_check] + keys_to_check_name = [UNITS_KEY_NAME.format(key) for key in keys_to_check] + msg_if_err = '\n units_key = {}\n quant_info = {}\n quant_lookup_result = {}'.format(repr(units_key), info, x) + u_f = _multiple_lookup(x, *keys_to_check_f, default=default_f(msg_if_err)) + u_name = _multiple_lookup(x, *keys_to_check_name, default=default_name) + utuple_ = UnitsTuple(u_f, u_name) + return utuple_ + + +def _multiple_lookup(x, *keys, default=None): + '''try to get keys from x. return result for first key found. return None if fail.''' + for key in keys: + result = x.get(key, None) + if result is not None: + return result + return default + + +''' ----------------------------- Evaluate Units ----------------------------- ''' + +EvaluatedUnitsTuple = collections.namedtuple('EvaluatedUnits', ('factor', 'name')) +# TODO: make prettier formatting for the units (e.g. {:.3e}) +# TODO: allow to change name formatting (via editting "order" and "frac" of underlying UnitsExpression object) + + +class EvaluatedUnits(EvaluatedUnitsTuple): + '''tuple of (factor, name). + Also, if doing *, /, or **, will only affect factor. + + Example: + np.array([1, 2, 3]) * EvaluatedUnits(factor=10, name='m / s') + EvaluatedUnits(factor=np.array([10, 20, 30]), name='m / s') + ''' + + def __mul__(self, b): # self * b + return EvaluatedUnits(self.factor * b, self.name) + + def __rmul__(self, b): # b * self + return EvaluatedUnits(b * self.factor, self.name) + + def __truediv__(self, b): # self / b + return EvaluatedUnits(self.factor / b, self.name) + + def __rtruediv__(self, b): # b / self + return EvaluatedUnits(b / self.factor, self.name) + + def __pow__(self, b): # self ** b + return EvaluatedUnits(self.factor ** b, self.name) + # the next line is to tell numpy that when b is a numpy array, + # we should use __rmul__ for b * self and __rtruediv__ for b / self, + # instead of making a UfuncTypeError. + __array_ufunc__ = None + + +def _get_units_info_from_mode(mode='si'): + '''returns units_key, format_attr given units mode. Case-insensitive.''' + mode = mode.lower() + assert mode in UNITS_MODES, 'Mode invalid! valid modes={}; got mode={}'.format(list(UNITS_MODES.keys()), repr(mode)) + units_key, format_attr = UNITS_MODES[mode] + return units_key, format_attr + + +def evaluate_units_tuple(units_tuple, obj, *args__units_tuple, mode='si', _force_from_simu=False, **kw__units_tuple): + '''evaluates units for units_tuple using the attrs of obj and the selected units mode. + + units_tuple is called with units_tuple(obj.uni, obj, *args__units_tuple, **kw__units_tuple). + Though, first, ATTR_FORMAT_KWARG and UNITS_KEY_KWARG will be set in kw__units_tuple, + to their "default" values (based on mode), unless other values are provided in kw__units_tuple. + + Accepted modes are 'si' for SI units, and 'cgs' for cgs units. Case-insensitive. + + if _force_from_simu, always give the conversion factor from simulation units, + regardless of the value of obj.units_output. + This kwarg is mainly intended for internal use, during _get_var_postprocess. + ''' + # initial processing of mode. + units_key, format_attr = _get_units_info_from_mode(mode=mode) + # set defaults based on mode (unless they are already set in kw__units_tuple) + kw__units_tuple[ATTR_FORMAT_KWARG] = kw__units_tuple.get(ATTR_FORMAT_KWARG, format_attr) + kw__units_tuple[UNITS_KEY_KWARG] = kw__units_tuple.get(UNITS_KEY_KWARG, units_key) + # evaluate units_tuple, using obj and **kwargs. + result = units_tuple(obj.uni, obj, *args__units_tuple, **kw__units_tuple) + # check obj's unit system. + if _force_from_simu or (obj.units_output == 'simu'): + f = result.f + elif obj.units_output == mode: + f = 1 + else: + raise NotImplementedError(f'units conversion from {repr(obj.units_output)} to {repr(mode)}') + # make result formatting prettier and return result. + result = EvaluatedUnits(f, str(result.name)) + return result + + +def get_units(obj, mode='si', **kw__evaluate_units_tuple): + '''evaluates units for most-recently-gotten var (at top of obj._quants_tree). + Accepted modes are defined by UNITS_MODES in helita.sim.units.py near top of file. + + Accepted modes are 'si' for SI units, and 'cgs' for cgs units. Case-insensitive. + + kw__units_tuple go to units function. + (meanwhile, kwargs related to units mode are automatically set, unless provided here) + + This function is bound to the BifrostData object via helita.sim.document_vars.create_vardict(). + ''' + units_key, format_attr = _get_units_info_from_mode(mode=mode) + # lookup info about most-recently-gotten var. + quant_info = obj.get_quant_info(lookup_in_vardict=False) + units_tuple = _units_lookup_by_quant_info(obj, quant_info, units_key=units_key) + quant_tree = obj.got_vars_tree(as_data=True) + # evaluate units_tuple, given obj.uni, obj, and **kwargs. + result = evaluate_units_tuple(units_tuple, obj, quant_tree, mode=mode, **kw__evaluate_units_tuple) + return result + + +''' ----------------------------- Aliases ----------------------------- ''' + +# It can be helpful to import these aliases into other modules. +# for example, in a load_..._quantities file, you would do: +""" +from .units import ( + UNI, USI, UCGS, UCONST, + Usym, Usyms, UsymD, + U_TUPLE, + DIMENSIONLESS, UNITS_FACTOR_1, NO_NAME, + UNI_length, UNI_time, UNI_mass, + UNI_speed, UNI_rho, UNI_nr, UNI_hz +) +""" + +# for making "universal" units +UNI = UnitsFuncBuilder(units_key=None) # , format_attr=None +# for making si units +USI = UnitsFuncBuilder(units_key=UNITS_MODES['si'][0], format_attr=UNITS_MODES['si'][1]) +# for making cgs units +UCGS = UnitsFuncBuilder(units_key=UNITS_MODES['cgs'][0], format_attr=UNITS_MODES['cgs'][1]) +# for making "constant" units +UCONST = FuncBuilder(FunclikeType=AttrsFunclike, format_attr='{}') + +# for making unit names ("UnitsExpression"s) +Usym = UnitSymbol # returns a single symbol +Usyms = UnitSymbols # returns multiple symbols +UsymD = UnitSymbolDict # returns a dict of unit symbols + +# for putting units info in vardict +U_TUPLE = UnitsTuple # tuple with (units function, units expression) + +# for dimensionless quantities +DIMENSIONLESS = DIMENSIONLESS_TUPLE # dimensionless tuple (factor is 1 and name is '') +UNITS_FACTOR_1 = DIMENSIONLESS_UNITS # dimensionless units (factor is 1) +NO_NAME = DIMENSIONLESS_NAME # dimensionless name (name is '') + +# for "common" basic unit tuples +UNI_length = U_TUPLE(UNI.l, UsymD(usi='m', ucgs='cm')) +UNI_time = U_TUPLE(UNI.t, Usym('s')) +UNI_mass = U_TUPLE(UNI.m, UsymD(usi='kg', ucgs='g')) +UNI_speed = U_TUPLE(UNI.u, UNI_length.name / UNI_time.name) +UNI_rho = U_TUPLE(UNI.r, UNI_mass.name / (UNI_length.name**3)) # mass density +UNI_nr = U_TUPLE(UNI.nr, UNI_length.name ** (-3)) # number density +UNI_hz = U_TUPLE(UNI.hz, Usym('s')**(-1)) # frequency diff --git a/helita/utils/__init__.py b/helita/utils/__init__.py index d8f177ea..0211a3c7 100644 --- a/helita/utils/__init__.py +++ b/helita/utils/__init__.py @@ -4,5 +4,4 @@ __all__ = ["congrid", "fitting", "shell", "radtrans", "utilsmath", "utilsfast"] -from . import fitting -from . import utilsfast +from . import fitting, utilsfast diff --git a/helita/utils/congrid.py b/helita/utils/congrid.py index c7fb1e1b..5360160a 100644 --- a/helita/utils/congrid.py +++ b/helita/utils/congrid.py @@ -84,7 +84,7 @@ def congrid(a, newdims, method='linear', centre=False, minusone=False): return newa elif method in ['spline']: oslices = [slice(0, j) for j in old] - oldcoords = np.ogrid[oslices] + np.ogrid[oslices] nslices = [slice(0, j) for j in list(newdims)] newcoords = np.mgrid[nslices] newcoords_dims = list(range(np.rank(newcoords))) diff --git a/helita/utils/fitting.py b/helita/utils/fitting.py index 8a823d67..f9d61f80 100644 --- a/helita/utils/fitting.py +++ b/helita/utils/fitting.py @@ -1,6 +1,6 @@ -from scipy.odr import odrpack as odr -from scipy.odr import models import numpy as np +from scipy.odr import models +from scipy.odr import odrpack as odr def gaussian(B, x): @@ -15,8 +15,8 @@ def double_gaussian(B, x): B = mean1, stdev1, max1, mean2, stdev2, max2, offset """ return B[2] / (B[1] * np.sqrt(2 * np.pi)) * np.exp(-((x - B[0])**2 / (2 * B[1]**2))) + \ - B[5] / (B[4] * np.sqrt(2 * np.pi)) * \ - np.exp(-((x - B[3])**2 / (2 * B[4]**2))) + B[6] + B[5] / (B[4] * np.sqrt(2 * np.pi)) * \ + np.exp(-((x - B[3])**2 / (2 * B[4]**2))) + B[6] def sine(B, x): diff --git a/helita/utils/radtrans.pyx b/helita/utils/radtrans.pyx index 03647e76..64d0aa74 100644 --- a/helita/utils/radtrans.pyx +++ b/helita/utils/radtrans.pyx @@ -23,7 +23,7 @@ def piecewise_1D(np.ndarray[DTYPE_t, ndim=1] height, Parameters ---------- height, chi, S: 1D arrays, float32 - height scale, absorption coefficient, source function. Height + height scale, absorption coefficient, source function. Height and chi must have consistent units (typically m and m^-1, respectively). Returns diff --git a/helita/utils/shell.py b/helita/utils/shell.py index 36a1551e..40f904c5 100644 --- a/helita/utils/shell.py +++ b/helita/utils/shell.py @@ -1,7 +1,6 @@ """ tools to deal with I/O on the shell """ -import sys class Getch: @@ -10,6 +9,7 @@ class Getch: Gets a single character from standard input. Does not echo to the screen. """ + def __init__(self): try: self.impl = _GetchWindows() @@ -22,8 +22,7 @@ def __call__(self): class _GetchUnix: def __init__(self): - import tty - import sys + pass def __call__(self): import sys @@ -41,7 +40,7 @@ def __call__(self): class _GetchWindows: def __init__(self): - import msvcrt + pass def __call__(self): import msvcrt diff --git a/helita/utils/trnslt.f90 b/helita/utils/trnslt.f90 deleted file mode 100644 index 283c3b6a..00000000 --- a/helita/utils/trnslt.f90 +++ /dev/null @@ -1,223 +0,0 @@ -! Fortran subroutines from SCATE -! -!------------------------------------------------------------------------------- -subroutine trnslt_old(mx,my,mz,nt,dx,dy,zt,ff,dxdz,dydz) - - ! Translate a scalar field to an inclined coordinate system. - ! - ! Operation count: 10m+8a = 18 flops/pnt - ! - ! Timing: - ! Alliant: 48 calls = 1.79 s -> 18*48*31*31*31/1790000 = 14 Mflops - ! - ! Update history: - ! - ! 28-oct-87/aake: Added 'nt' parameter; reduces work in shallow case - ! 02-nov-87/aake: Added 'zt' parameter, to allow separate zrad() - ! 06-nov-87/aake: Split derivative loops, to loop over simplest index - ! 27-aug-89/bob: Inverted some l,m loops to make l the inner loop - ! 31-aug-89/aake: Collapsed loops 100 and 200 to lm loops - - implicit none - integer,intent(in) :: mx,my,mz,nt - real,intent(in) :: dx,dy,dxdz,dydz - real,dimension(nt),intent(in) :: zt - real,dimension(mx,my,mz),intent(inout) :: ff - - real,dimension(mx,my) :: f,d - integer :: k,l,m,n,lp,lp1,mk,mk1 - real :: xk,yk,p,q,af,bf,ad,bd - - do n=1,nt - xk=dxdz*zt(n)/(mx*dx) - if(abs(xk).lt.0.001.or.mx.eq.1)cycle - xk=amod(xk,1.) - if(xk.lt.0.) xk=xk+1. - xk=mx*xk - k=xk - p=xk-k - q=1.-p - af=q+p*q*(q-p) - bf=p-p*q*(q-p) - ad=p*q*q - bd=-p*q*p - - ! Copy input to temporary - do m=1,my - do l=1,mx - f(l,m)=ff(l,m,n) - end do - end do - - ! Calculate derivatives by centered differences [1m+1a] - do m=1,my - do l=2,mx-1 - d(l,m)=0.5*(f(l+1,m)-f(l-1,m)) - end do - end do - do m=1,my - d(1,m)=0.5*(f(2,m)-f(mx,m)) - d(mx,m)=0.5*(f(1,m)-f(mx-1,m)) - end do - ! - ! Interpolate using cubic splines [4m+3a] - ! - do l=1,mx - lp=mod(l+k-1,mx)+1 - lp1=mod(l+1+k-1,mx)+1 - do m=1,my - ff(l,m,n)=af*f(lp,m)+bf*f(lp1,m)+ad*d(lp,m)+bd*d(lp1,m) - ! 0.29 sec - end do - end do - end do - - do n=1,nt - yk=dydz*zt(n)/(my*dy) - if(abs(yk).lt.0.001.or.my.eq.1)cycle - yk=amod(yk,1.) - if(yk.lt.0.) yk=yk+1. - yk=my*yk - k=yk - p=yk-k - q=1.-p - af=q+p*q*(q-p) - bf=p-p*q*(q-p) - ad=p*q*q - bd=-p*q*p - - ! Copy input to temporary - do m=1,my - do l=1,mx - f(l,m)=ff(l,m,n) - end do - end do - - ! Calculate derivatives by centered differences - do m=2,my-1 - do l=1,mx - d(l,m)=0.5*(f(l,m+1)-f(l,m-1)) - ! 0.16 sec - end do - end do - do l=1,mx - d(l,1)=0.5*(f(l,2)-f(l,my)) - d(l,my)=0.5*(f(l,1)-f(l,my-1)) - end do - - ! Interpolate using cubic splines - do m=1,my - mk=mod(m+k-1,my)+1 - mk1=mod(m+1+k-1,my)+1 - do l=1,mx - ff(l,m,n)=af*f(l,mk)+bf*f(l,mk1)+ad*d(l,mk)+bd*d(l,mk1) - ! 0.18 sec -> 48*31*31*31*7/0.18 = - end do - end do - end do - -end subroutine trnslt_old - - -!------------------------------------------------------------------------------- -subroutine trnslt(mx,my,mz,nzt,dx,dy,zt,f,dxdz,dydz) - - ! Translate a scalar field to an inclined coordinate system. - ! - ! Adapted from original routine by Nordlund and Stein - - implicit none - - integer, intent(in) :: mx, my, mz, nzt - real, intent(in) :: dx, dy, dxdz, dydz - real, dimension(nzt), intent(in) :: zt - real, dimension(mx,my,mz), intent(inout) :: f - real, dimension(mx,my) :: ftmp - - integer :: k, l, m, n, lm1, lp0, lp1, lp2, mm1, mp0, mp1, mp2 - real :: xk, yk, p, q, af, bf, ad, bd, ac, bc - - real, parameter :: eps=1.0e-6 - - - if (abs(dxdz).gt.eps) then - - do n=1,nzt - - xk = dxdz*zt(n)/(mx*dx) - xk = amod(xk,1.) - if (xk.lt.0.) xk = xk + 1. - xk = mx*xk - k = xk - p = xk-k - k = k + mx - q = 1.-p - af = q+p*q*(q-p) - bf = p-p*q*(q-p) - ad = p*q*q*0.5 - bd = -p*q*p*0.5 - ac = af-bd - bc = bf+ad - - do m=1,my - do l=1,mx - ftmp(l,m) = f(l,m,n) - end do - end do - - do l=1,mx - lm1 = mod(l+k-2,mx)+1 - lp0 = mod(l+k-1,mx)+1 - lp1 = mod(l+k ,mx)+1 - lp2 = mod(l+k+1,mx)+1 - - do m=1,my - f(l,m,n) = ac * ftmp(lp0,m) + bc * ftmp(lp1,m) - ad * ftmp(lm1,m) + bd * ftmp(lp2,m) - end do - end do - - end do - - end if - - - if (abs(dydz).gt.eps) then - - do n=1,nzt - - yk = dydz*zt(n)/(my*dy) - yk = amod(yk,1.) - if (yk.lt.0.) yk = yk + 1. - yk = my*yk - k = yk - p = yk - k - k = k + my - q = 1. - p - af = q+ p*q*(q-p) - bf = p-p*q*(q-p) - ad = p*q*q*0.5 - bd = -p*q*p*0.5 - ac = af-bd - bc = bf+ad - - do m=1,my - do l=1,mx - ftmp(l,m) = f(l,m,n) - end do - end do - - do m=1,my - mm1 = mod(m+k-2,my)+1 - mp0 = mod(m+k-1,my)+1 - mp1 = mod(m+k ,my)+1 - mp2 = mod(m+k+1,my)+1 - do l=1,mx - f(l,m,n) = ac * ftmp(l,mp0) + bc * ftmp(l,mp1) - ad * ftmp(l,mm1) + bd * ftmp(l,mp2) - end do - end do - - end do - - end if - -end subroutine trnslt diff --git a/helita/utils/utilsfast.pyx b/helita/utils/utilsfast.pyx index 6e3e4836..5b33dd55 100644 --- a/helita/utils/utilsfast.pyx +++ b/helita/utils/utilsfast.pyx @@ -5,8 +5,9 @@ # https://github.com/gasagna/openpiv-python/blob/master/openpiv/src/lib.pyx import numpy as np -cimport numpy as np + cimport cython +cimport numpy as np DTYPEf = np.float64 ctypedef np.float64_t DTYPEf_t @@ -719,11 +720,11 @@ cpdef fwhm_gen(np.ndarray[DTYPEf_t, ndim=1] x, np.ndarray[DTYPEf_t, ndim=3] spec): """ fwhm_gen(x, spec) - + Calculates the FWHM of a generic line profile. This is done by first calculating the line maximum, and then linearly interpolating the widest wings for half of that value. (Local maxima/minima are therefore ignored). - + Parameters ---------- x : 1-D ndarray (double type) @@ -735,13 +736,13 @@ cpdef fwhm_gen(np.ndarray[DTYPEf_t, ndim=1] x, Returns ------- blue_wing, red_wing : 3-D ndarrays - Arrays with blue and red wing. Same units as x. + Arrays with blue and red wing. Same units as x. """ cdef int nx = spec.shape[0] cdef int ny = spec.shape[1] cdef int nw = spec.shape[2] cdef int i, j, k, midw = nw // 2 - cdef float local_max + cdef float local_max cdef float local_min cdef float hm cdef np.ndarray[DTYPEf_t, ndim=2] blue_wing = np.zeros((nx,ny), dtype=DTYPEf) diff --git a/helita/utils/utilsmath.py b/helita/utils/utilsmath.py index 19c31aad..4b5842eb 100644 --- a/helita/utils/utilsmath.py +++ b/helita/utils/utilsmath.py @@ -1,7 +1,8 @@ -import numpy as np -from numba import vectorize, float32, float64 from math import exp +import numpy as np +from numba import float32, float64, vectorize + def hist2d(x, y, nbins=30, norm=False, rx=0.08): ''' @@ -97,8 +98,8 @@ def planck(wavelength, temp, dist='wavelength'): For solid angle integrated one must multiply it by pi. """ - from astropy.constants import c, h, k_B import astropy.units as u + from astropy.constants import c, h, k_B wave = wavelength.to('nm') if temp.shape and wave.shape: @@ -130,8 +131,8 @@ def int_to_bt(inu, wave): brightness_temp : `Quantity` object (number or sequence) Brightness temperature in SI units of temperature. """ - from astropy.constants import c, h, k_B import astropy.units as u + from astropy.constants import c, h, k_B bt = h * c / (wave * k_B * np.log(2 * h * c / (wave**3 * inu * u.rad**2) + 1)) return bt.si @@ -160,55 +161,9 @@ def trapz2d(z, x=None, y=None, dx=1., dy=1.): return 0.25 * dx * dy * (s1 + 2 * s2 + 4 * s3) -def translate(data, z, mu, phi, dx=1, dy=1): - """ - Horizontally rotates a 3D array with periodic horizontal boundaries - by a polar and azimuthal angle. Uses cubic splines, modifies data in-place - (therefore the rotation leads to an array with the same dimensions). - - Parameters - ---------- - data : 3D array, 32-bit float, F contiguous - Array with values. Last index should be height, the - non-periodic dimension. The rotation keeps the top and - bottom layers - z : 1D array, 32-bit float - Array with heights. - mu : float - Cosine of polar angle. - phi : float - Azimuthal angle in radians. - dx : float, optional - Grid separation in x dimension (same units as height). Default is 1. - dy : float, optional - Grid separation in y dimension (same units as height). Default is 1. - - Returns - ------- - None, data are modified in-place. - """ - from math import acos, sin, cos - try: - from .trnslt import trnslt - except ModuleNotFoundError: - raise ModuleNotFoundError('trnslt not found, helita probably installed' - ' without a fortran compiler!') - assert data.shape[-1] == z.shape[0] - assert data.flags['F_CONTIGUOUS'] - assert data.dtype == np.dtype("float32") - theta = acos(mu) - sinth = sin(theta) - tanth = sinth / mu - cosphi = cos(phi) - sinphi = sin(phi) - dxdz = tanth * cosphi - dydz = tanth * sinphi - trnslt(dx, dy, z, data, dxdz, dydz) - - @vectorize([float32(float32, float32), float64(float64, float64)]) def voigt(a, v): - """ + r""" Returns the Voigt function: H(a,v) = a/pi * \int_{-Inf}^{+Inf} exp(-y**2)/[(v-y)**2 + a**2] dy @@ -247,7 +202,7 @@ def voigt(a, v): return exp(-v ** 2) z = v * 1j + a h = (((((((a6 * z + a5) * z + a4) * z + a3) * z + a2) * z + a1) * z + a0) / - (((((((z + b6) * z + b5) * z + b4) * z + b3) * z + b2) * z + b1) * z + b0)) + (((((((z + b6) * z + b5) * z + b4) * z + b3) * z + b2) * z + b1) * z + b0)) return h.real diff --git a/helita/vis/__init__.py b/helita/vis/__init__.py deleted file mode 100644 index 37a0e772..00000000 --- a/helita/vis/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -""" -Set of visualisation routines. -""" - -__all__ = ["radiative_transfer", "rh15d_vis"] diff --git a/helita/vis/radiative_transfer.py b/helita/vis/radiative_transfer.py deleted file mode 100644 index 0903de1c..00000000 --- a/helita/vis/radiative_transfer.py +++ /dev/null @@ -1,257 +0,0 @@ -""" -Set of functions and widgets for radiative transfer visualisations -""" -import warnings -import numpy as np -from pkg_resources import resource_filename -from scipy import interpolate as interp -import bqplot.pyplot as plt -from bqplot import LogScale -from ipywidgets import (interactive, Layout, HBox, VBox, Box, GridBox, - IntSlider, FloatSlider, Dropdown, HTMLMath) -from ..utils.utilsmath import voigt - - -def transp(): - """ - Instantiates the Transp() class, and shows the widget. - Runs only in Jupyter notebook or JupyterLab. Requires bqplot. - """ - warnings.simplefilter(action='ignore', category=FutureWarning) - return Transp().widget - - -class Transp(): - """ - Class for a widget illustrating line formation given a source function, - Voigt profile and opacity. - - Runs only in Jupyter notebook or JupyterLab. Requires bqplot. - """ - DATAFILE = resource_filename('helita', 'data/VAL3C_source_functions.npz') - data = np.load(DATAFILE) - # variable names inside data structure - SFUNCTIONS = {"VAL3C Mg": "s_nu_mg", "VAL3C Ca": "s_nu_ca", - "VAL3C LTE": "s_nu_lte"} - TAUS = {"VAL3C Mg": "t_500_mg", "VAL3C Ca": "t_500_ca", - "VAL3C LTE": "t_500_lte"} - # initial parameters - mu = 1.0 - npts = 101 - xmax = 50 - a = -2.5 - opa_cont = 0. - opa_line = 6.44 - source = "VAL3C Mg" - - def __init__(self): - self._compute_profile() - self._make_plot() - self._make_widget() - - def _compute_profile(self): - """ - Calculates the line profile given a a damping parameter, - source function, opacities, and mu. - """ - self.tau500 = self.data[self.TAUS[self.source]] - self.source_function = self.data[self.SFUNCTIONS[self.source]] - tau500 = self.tau500 - source_function = self.source_function - self.freq = np.linspace(-float(self.xmax), self.xmax, self.npts) - a = 10. ** self.a - self.h = voigt(a, self.freq) - self.xq = self.h * 10. ** self.opa_line + 10. ** self.opa_cont - xq = self.xq - self.tau500_cont = self.mu / 10 ** self.opa_cont - self.tau500_line = self.mu / self.xq.max() - f = interp.interp1d(tau500, source_function, bounds_error=False) - self.source_function_cont = f(self.tau500_cont)[()] - self.source_function_line = f(self.tau500_line)[()] - xq = xq[:, np.newaxis] - tmp = source_function * np.exp(-xq * tau500 / self.mu) * xq * tau500 - self.prof = np.log(10) / self.mu * np.trapz(tmp.T, np.log(tau500), - axis=0) - - def _make_plot(self): - plt.close(1) - fig_margin = {'top': 25, 'bottom': 35, 'left': 35, 'right':25} - fig_layout = {'height': '100%', 'width': '100%' } - layout_args = {'fig_margin': fig_margin, 'layout': fig_layout, - 'max_aspect_ratio': 1.618} - self.voigt_fig = plt.figure(1, title='Voigt profile', **layout_args) - self.voigt_plot = plt.plot(self.freq, self.h, scales={'y': LogScale()}) - plt.xlabel("Δν / ΔνD") - - plt.close(2) - self.abs_fig = plt.figure(2, title='(αᶜ + αˡ) / α₅₀₀', **layout_args) - self.abs_plot = plt.plot(self.freq, self.xq, scales={'y': LogScale()}) - plt.xlabel("Δν / ΔνD") - - plt.close(3) - self.int_fig = plt.figure(3, title='Intensity', **layout_args) - self.int_plot = plt.plot(self.freq, self.prof, scales={'y': LogScale()}) - plt.xlabel("Δν / ΔνD") - - plt.close(4) - self.source_fig = plt.figure(4, title='Source Function', **layout_args) - self.source_plot = plt.plot(np.log10(self.tau500), self.source_function, - scales={'y': LogScale()}) - plt.xlabel("lg(τ₅₀₀)") - self.tau_labels = plt.label(['τᶜ = 1', 'τˡ = 1'], colors=['black'], - x=np.array([np.log10(self.tau500_cont), - np.log10(self.tau500_line)]), - y=np.array([self.source_function_cont, - self.source_function_line]), - y_offset=-25, align='middle') - self.tau_line_plot = plt.plot(np.array([np.log10(self.tau500_line), - np.log10(self.tau500_line)]), - np.array([self.source_function_line / 1.5, - self.source_function_line * 1.5]), - colors=['black']) - self.tau_cont_plot = plt.plot(np.array([np.log10(self.tau500_cont), - np.log10(self.tau500_cont)]), - np.array([self.source_function_cont / 1.5, - self.source_function_cont * 1.5]), - colors=['black']) - - def _update_plot(self, a, opa_cont, opa_line, mu, xmax, source): - self.a = a - self.opa_cont = opa_cont - self.opa_line = opa_line - self.mu = mu - self.xmax = xmax - self.source = source - self._compute_profile() - self.voigt_plot.x = self.freq - self.voigt_plot.y = self.h - self.abs_plot.x = self.freq - self.abs_plot.y = self.xq - self.int_plot.x = self.freq - self.int_plot.y = self.prof - self.source_plot.x = np.log10(self.tau500) - self.source_plot.y = self.source_function - self.tau_labels.x = np.array([np.log10(self.tau500_cont), - np.log10(self.tau500_line)]) - self.tau_labels.y = np.array([self.source_function_cont, - self.source_function_line]) - self.tau_line_plot.x = [np.log10(self.tau500_line), - np.log10(self.tau500_line)] - self.tau_line_plot.y = [self.source_function_line / 1.5, - self.source_function_line * 1.5] - self.tau_cont_plot.x = [np.log10(self.tau500_cont), - np.log10(self.tau500_cont)] - self.tau_cont_plot.y = [self.source_function_cont / 1.5, - self.source_function_cont * 1.5] - - - def _make_widget(self): - fig = GridBox(children=[self.voigt_fig, self.abs_fig, - self.int_fig, self.source_fig], - layout=Layout(width='100%', - min_height='600px', - height='100%', - grid_template_rows='49% 49%', - grid_template_columns='49% 49%', - grid_gap='0px 0px')) - - a_slider = FloatSlider(min=-5, max=0., step=0.01, value=self.a, - description='lg(a)') - opa_cont_slider = FloatSlider(min=0., max=6., step=0.01, - value=self.opa_cont, description=r"$\kappa_c / \kappa_{500}$") - opa_line_slider = FloatSlider(min=0., max=7., step=0.01, - value=self.opa_line, description=r"$\kappa_l / \kappa_{500}$") - mu_slider = FloatSlider(min=0.01, max=1., step=0.01, - value=self.mu, description=r'$\mu$') - xmax_slider = IntSlider(min=1, max=100, step=1, value=self.xmax, - description='xmax') - source_slider = Dropdown(options=self.SFUNCTIONS.keys(), value=self.source, - description='Source Function', - style={'description_width': 'initial'}) - w = interactive(self._update_plot, a=a_slider, opa_cont=opa_cont_slider, - opa_line=opa_line_slider, mu=mu_slider, - xmax=xmax_slider, source=source_slider) - controls = GridBox(children=[w.children[5], w.children[0], - w.children[2], w.children[4], - w.children[3], w.children[1]], - layout=Layout(min_height='80px', - min_width='600px', - grid_template_rows='49% 49%', - grid_template_columns='31% 31% 31%', - grid_gap='10px')) - self.widget = GridBox(children=[controls, fig], - layout=Layout(grid_template_rows='8% 90%', - width='100%', - min_height='650px', - height='100%', - grid_gap='10px')) - - -def slab(): - """ - Displays a widget illustrating line formation in a homogenous slab. - - Runs only in Jupyter notebook or JupyterLab. Requires bqplot. - """ - # Don't display some ipywidget warnings - warnings.simplefilter(action='ignore', category=FutureWarning) - - def _compute_slab(i0, source, tau_cont, tau_line): - """ - Calculates slab line profile. - """ - NPT = 101 - MAX_DX = 5. - x = np.arange(NPT) - (NPT - 1.) / 2 - x *= MAX_DX / x.max() - tau = tau_cont + tau_line * np.exp(-x * x) - extinc = np.exp(-tau) - intensity = float(i0) * extinc + float(source) * (1. - extinc) - return (x, intensity) - - I0 = 15 - S = 65 - x, y = _compute_slab(I0, S, 0.5, 0.9) - base = np.zeros_like(x) - fig = plt.figure(title='Slab line formation') - int_plot = plt.plot(x, y, 'b-') - source_line = plt.plot(x, base + S, 'k--') - i0_line = plt.plot(x, base + I0, 'k:') - labels = plt.label(['I₀', 'I', 'S'], - x=np.array([int_plot.x[0] + 0.2, int_plot.x[-1] - 0.2, - int_plot.x[0] + 0.2]), - y=np.array([i0_line.y[0], int_plot.y[0], - source_line.y[0]]) + 2, - colors=['black']) - plt.ylim(0, 100) - i0_slider = IntSlider(min=0, max=100, value=I0, description=r'$I_0$') - s_slider = IntSlider(min=0, max=100, value=S, description=r'$S$') - tau_c_slider = FloatSlider(min=0, max=1., step=0.01, value=0.5, - description=r'$\tau_{\mathrm{cont}}$') - tau_l_slider = FloatSlider(min=0, max=10., step=0.01, value=0.9, - description=r'$\tau_{\mathrm{line}}$') - - def plot_update(i0=I0, source=S, tau_cont=0.5, tau_line=0.9): - _, y = _compute_slab(i0, source, tau_cont, tau_line) - int_plot.y = y - source_line.y = base + source - i0_line.y = base + i0 - labels.y = np.array([i0, y[0], source]) + 2 - - widg = interactive(plot_update, i0=i0_slider, source=s_slider, - tau_cont=tau_c_slider, tau_line=tau_l_slider) - help_w = HTMLMath("

Purpose: " - "This widget-based procedure is used for " - "studying spectral line formation in a " - "homogeneous slab.

" - "

Inputs:

" - "
    " - r"
  • $I_0$: The incident intensity.
  • " - r"
  • $S$: The source function.
  • " - r"
  • $\tau_{\mathrm{cont}}$ : The continuum optical depth.
  • " - r"
  • $\tau_{\mathrm{line}}$ : The integrated optical depth in the spectral line.
  • " - "
") - return HBox([VBox([widg, help_w], - layout=Layout(width='33%', top='50px', left='5px')), - Box([fig], layout=Layout(width='66%'))], - layout=Layout(border='50px')) diff --git a/helita/vis/rh15d_vis.py b/helita/vis/rh15d_vis.py deleted file mode 100644 index 8f3d2d9d..00000000 --- a/helita/vis/rh15d_vis.py +++ /dev/null @@ -1,274 +0,0 @@ -""" -Set of programs and tools visualise the output from RH, 1.5D version -""" -import os -import numpy as np -import xarray as xr -import matplotlib.pyplot as plt -from pkg_resources import resource_filename -from ipywidgets import interact, fixed, Dropdown, IntSlider, FloatSlider -from scipy.integrate import cumtrapz -from scipy.interpolate import interp1d -from astropy import units as u -from ..utils.utilsmath import planck, voigt - - -class Populations: - """ - Class to visualise the populations from an RH 1.5D object. - """ - def __init__(self, rh_object): - self.rhobj = rh_object - self.atoms = [a for a in dir(self.rhobj) if a[:5] == 'atom_'] - self.display() - - def display(self): - """ - Displays a graphical widget to explore the level populations. - Works in jupyter only. - """ - atoms = {a.split('_')[1].title(): a for a in self.atoms} - quants = ['Populations', 'LTE Populations', 'Departure coefficients'] - #nlevel = getattr(self.rhobj, self.atoms[0]).nlevel - nx, ny, nz = self.rhobj.atmos.temperature.shape - if nx == 1: - x_slider = fixed(0) - else: - x_slider = (0, nx - 1) - if ny == 1: - y_slider = fixed(0) - else: - y_slider = (0, ny - 1) - - def _pop_plot(atom): - """Starts population plot""" - pop = getattr(self.rhobj, atom).populations - height = self.rhobj.atmos.height_scale[0, 0] / 1e6 # in Mm - _, ax = plt.subplots() - pop_plot, = ax.plot(height, pop[0, 0, 0]) - ax.set_xlabel("Height (Mm)") - ax.set_ylabel("Populations") - ax.set_title("Level 1") - return ax, pop_plot - - ax, p_plot = _pop_plot(self.atoms[0]) - - @interact(atom=atoms, quantity=quants, y_log=False, - x=x_slider, y=y_slider) - def _pop_update(atom, quantity, y_log=False, x=0, y=0): - nlevel = getattr(self.rhobj, atom).nlevel - - # Atomic level singled out because nlevel depends on the atom - @interact(level=(1, nlevel)) - def _pop_update_level(level=1): - n = getattr(self.rhobj, atom).populations[level - 1, x, y] - nstar = getattr( - self.rhobj, atom).populations_LTE[level - 1, x, y] - if quantity == 'Departure coefficients': - tmp = n / nstar - ax.set_ylabel(quantity + ' (n / n*)') - elif quantity == 'Populations': - tmp = n - ax.set_ylabel(quantity + ' (m$^{-3}$)') - elif quantity == 'LTE Populations': - tmp = nstar - ax.set_ylabel(quantity + ' (m$^{-3}$)') - p_plot.set_ydata(tmp) - ax.relim() - ax.autoscale_view(True, True, True) - ax.set_title("Level %i, x=%i, y=%i" % (level, x, y)) - if y_log: - ax.set_yscale("log") - else: - ax.set_yscale("linear") - - -class SourceFunction: - """ - Class to visualise the source function and opacity from an RH 1.5D object. - """ - def __init__(self, rh_object): - self.rhobj = rh_object - self.display() - - def display(self): - """ - Displays a graphical widget to explore the source function. - Works in jupyter only. - """ - nx, ny, nz, nwave = self.rhobj.ray.source_function.shape - if nx == 1: - x_slider = fixed(0) - else: - x_slider = (0, nx - 1) - if ny == 1: - y_slider = fixed(0) - else: - y_slider = (0, ny - 1) - tau_levels = [0.3, 1., 3.] - ARROW = dict(facecolor='black', width=1., headwidth=5, headlength=6) - #SCALES = ['Height', 'Optical depth'] - - def __get_tau_levels(x, y, wave): - """ - Calculates height where tau=0.3, 1., 3 for a given - wavelength index. - Returns height in Mm and closest indices of height array. - """ - h = self.rhobj.atmos.height_scale[x, y].dropna('height') - tau = cumtrapz(self.rhobj.ray.chi[x, y, :, wave].dropna('height'), - x=-h) - tau = interp1d(tau, h[1:])(tau_levels) - idx = np.around(interp1d(h, - np.arange(h.shape[0]))(tau)).astype('i') - return (tau / 1e6, idx) # in Mm - - def _sf_plot(): - """Starts source function plot""" - obj = self.rhobj - sf = obj.ray.source_function[0, 0, :, 0].dropna('height') - height = obj.atmos.height_scale[0, 0].dropna( - 'height') / 1e6 # in Mm - bplanck = planck(obj.ray.wavelength_selected[0].values * u.nm, - obj.atmos.temperature[0, 0].dropna('height').values * u.K, - dist='frequency').value - fig, ax = plt.subplots() - ax.plot(height, sf, 'b-', label=r'S$_\mathrm{total}$', lw=1) - ax.set_yscale('log') - ax.plot(height, obj.ray.Jlambda[0, 0, :, 0].dropna('height'), - 'y-', label='J', lw=1) - ax.plot(height, bplanck, 'r--', label=r'B$_\mathrm{Planck}$', - lw=1) - ax.set_xlabel("Height (Mm)") - ax.set_ylabel(r"W m$^{-2}$ Hz$^{-1}$ sr$^{-1}$") - ax.set_title("%.3f nm" % obj.ray.wavelength_selected[0]) - lg = ax.legend(loc='upper center') - lg.draw_frame(0) - # tau annotations - tau_v, h_idx = __get_tau_levels(0, 0, 0) - for i, level in enumerate(tau_levels): - xval = tau_v[i] - yval = sf[h_idx[i]] - ax.annotate(r'$\tau$=%s' % level, - xy=(xval, yval), - xytext=(xval, yval / (0.2 - 0.03 * i)), - arrowprops=ARROW, ha='center', va='top') - return ax - - ax = _sf_plot() - - @interact(wavelength=(0, nwave - 1, 1), y_log=True, - x=x_slider, y=y_slider) - def _sf_update(wavelength=0, y_log=True, x=0, y=0): - obj = self.rhobj - bplanck = planck(obj.ray.wavelength_selected[wavelength].values * u.nm, - obj.atmos.temperature[x, y].dropna('height').values * u.K, - dist='frequency').value - quants = [obj.ray.source_function[x, y, :, - wavelength].dropna('height'), - obj.ray.Jlambda[x, y, :, wavelength].dropna('height'), - bplanck] - for i, q in enumerate(quants): - ax.lines[i].set_ydata(q) - ax.relim() - ax.autoscale_view(True, True, True) - ax.set_title("%.3f nm" % obj.ray.wavelength_selected[wavelength]) - # tau annotations: - tau_v, h_idx = __get_tau_levels(x, y, wavelength) - for i in range(len(tau_levels)): - xval = tau_v[i] - yval = quants[0][h_idx[i]] - ax.texts[i].xy = (xval, yval) - ax.texts[i].set_position((xval, yval / (0.2 - 0.03 * i))) - if y_log: - ax.set_yscale("log") - else: - ax.set_yscale("linear") - - -class InputAtmosphere: - def __init__(self, filename): - self.atmos = xr.open_dataset(filename) - self.filename = filename - self.display() - - def display(self): - """ - Displays a graphical widget to explore the input (HDF5) atmosphere. - """ - ntsteps, nx, ny, nz = self.atmos.temperature.shape - if ntsteps == 1: - tslider = fixed(0) - else: - tslider = (0, ntsteps - 1) - if nx == 1: - x_slider = fixed(0) - else: - x_slider = (0, nx - 1) - if ny == 1: - y_slider = fixed(0) - else: - y_slider = (0, ny - 1) - - def _atmos_plot(): - """Starts source function plot""" - EXCLUDES = ['x', 'y', 'z', 'snapshot_number'] - self.variables = [v for v in self.atmos.variables - if v not in EXCLUDES] - nrows = int(np.ceil(len(self.variables) / 2.)) - fig, ax = plt.subplots(nrows, 2, sharex=True, - figsize=(7, 2. * nrows)) - for i, v in enumerate(self.variables): - var = self.atmos.variables[v] - if v[:8].lower() == 'velocity': # to km/s - ax.flat[i].plot(self.atmos.z[0] / 1e6, var[0, 0, 0] / 1.e3) - ax.flat[i].set_ylabel("%s (km/s)" % v.title()) - elif v.lower() == "hydrogen_populations": - ax.flat[i].plot(self.atmos.z[0] / 1e6, - var[0, :, 0, 0].sum(axis=0)) - ax.flat[i].set_ylabel("Htot (m^-3)") - ax.flat[i].set_yscale("log") - else: - ax.flat[i].plot(self.atmos.z[0] / 1e6, var[0, 0, 0]) - units = '' - if 'units' in var.attrs: - units = var.attrs['units'] - ax.flat[i].set_ylabel("%s (%s)" % (v.title(), units)) - ax.flat[i].set_xlabel("Height (Mm)") - if i == 0: - ax.flat[i].set_title(os.path.split(self.filename)[1]) - if i == 1: - ax.flat[i].set_title("snapshot=%i, x=%i, y=%i" % (0, 0, 0)) - fig.tight_layout() - return ax - - ax = _atmos_plot() - - @interact(snapshot=tslider, x=x_slider, y=y_slider, y_log=True) - def _atmos_update(snapshot=0, x=0, y=0, y_log=True): - for i, v in enumerate(self.variables): - var = self.atmos.variables[v] - if v[:8].lower() == 'velocity': # to km/s - ydata = var[snapshot, x, y] / 1.e3 - elif v.lower() == "hydrogen_populations": - ydata = var[snapshot, :, x, y].sum(axis=0) - else: - ydata = var[snapshot, x, y] - ax.flat[i].lines[0].set_ydata(ydata) - if len(self.atmos.z.shape) == 2: - zdata = self.atmos.z[snapshot] / 1e6 - elif len(self.atmos.z.shape) == 4: - zdata = self.atmos.z[snapshot, x, y] / 1e6 - else: - raise ValueError("Invalid shape of z array") - ax.flat[i].lines[0].set_xdata(zdata) - ax.flat[i].relim() - ax.flat[i].autoscale_view(True, True, True) - if i == 1: - tmp = "snapshot=%i, x=%i, y=%i" % (snapshot, x, y) - ax.flat[i].set_title(tmp) - if v[:2].lower() not in ['ve', 'b_']: # no log in v and B - if y_log: - ax.flat[i].set_yscale("log") - else: - ax.flat[i].set_yscale("linear") diff --git a/mkdocs.yml b/mkdocs.yml index 83656693..96ad78ff 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -40,7 +40,7 @@ markdown_extensions: - pymdownx.critic - pymdownx.details - pymdownx.emoji: - emoji_generator: !!python/name:pymdownx.emoji.to_svg + emoji_generator: "!!python/name:pymdownx.emoji.to_svg" - pymdownx.inlinehilite - pymdownx.magiclink - pymdownx.mark diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..2e890936 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,9 @@ +[build-system] +requires = [ + "cython", + "oldest-supported-numpy", + "setuptools_scm[toml]", + "setuptools", + "wheel", + ] +build-backend = 'setuptools.build_meta' diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 00000000..647222a7 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,96 @@ +[metadata] +name = helita +provides = helita +description = lar physics python tools from ITA/UiO +long_description = file: README.md +long_description_content_type = text/x-md +author = Tiago M. D. Pereira et al. +author_email = tiago.pereira@astro.uio.no +license = BSD 3-Clause +license_files = LICENSE +url = https://ita-solar.github.io/helita/ +download_url = https://pypi.org/project/helita/ +project_urls= + Source Code = https://github.com/ITA-Solar/helita + Documentation = https://ita-solar.github.io/helita/ + Issue Tracker = https://github.com/ITA-Solar/helita/issues +edit_on_github = True +github_project = ITA-Solar/helita +platform = any +keywords = astronomy, astrophysics, solar physics, sun, space, science +classifiers = + Intended Audience :: Science/Research + License :: OSI Approved :: BSD License + Natural Language :: English + Operating System :: OS Independent + Programming Language :: Python + Programming Language :: Python :: 3 + Programming Language :: Python :: 3.8 + Programming Language :: Python :: 3.9 + Programming Language :: Python :: 3.10 + Programming Language :: Python :: 3.11 + Topic :: Scientific/Engineering :: Physics + Topic :: Scientific/Engineering :: Astronomy + +[options] +zip_safe = False +python_requires = >=3.8 +packages = find: +include_package_data = True +setup_requires = + setuptools_scm +install_requires = + astropy + matplotlib + numpy + pandas + scipy + sunpy + tqdm + xarray[io] + numba + radynpy + + +[options.extras_require] +ebysus = + zarr +tests = + pytest + +[options.package_data] +helita.data = * + +[tool:pytest] +testpaths = "helita" +norecursedirs = ".tox" "build" "docs" "*.egg-info" ".history" +doctest_plus = enabled +doctest_optionflags = NORMALIZE_WHITESPACE FLOAT_CMP ELLIPSIS +addopts = -p no:unraisableexception -p no:threadexception +markers = + remote_data: marks this test function as needing remote data. + online: marks this test function as needing online connectivity. +remote_data_strict = True +filterwarnings = + ignore + +[pycodestyle] +max_line_length = 110 + +[flake8] +max-line-length = 110 +exclude = + .git, + __pycache__, + +[isort] +balanced_wrapping = True +default_section = THIRDPARTY +include_trailing_comma = True +known_first_party = helita +length_sort = False +length_sort_sections=stdlib +line_length = 110 +multi_line_output = 3 +no_lines_before = LOCALFOLDER +sections = STDLIB, THIRDPARTY, FIRSTPARTY, LOCALFOLDER \ No newline at end of file diff --git a/setup.py b/setup.py index e205a1d1..3e814ae5 100644 --- a/setup.py +++ b/setup.py @@ -1,71 +1,66 @@ +#!/usr/bin/env python +from setuptools import setup # isort:skip import os +from itertools import chain + import numpy -import setuptools +from Cython.Build import cythonize +from numpy.distutils import fcompiler from numpy.distutils.core import setup from numpy.distutils.extension import Extension -from numpy.distutils import fcompiler +try: + # Recommended for setuptools 61.0.0+ + # (though may disappear in the future) + from setuptools.config.setupcfg import read_configuration +except ImportError: + from setuptools.config import read_configuration -try: # do we have cython? - from Cython.Build import cythonize - USE_CYTHON = True -except: - USE_CYTHON = False -USE_FORTRAN = fcompiler.get_default_fcompiler() +################################################################################ +# Programmatically generate some extras combos. +################################################################################ +extras = read_configuration("setup.cfg")['options']['extras_require'] -NAME = "helita" -PACKAGES = ["data", "io", "obs", "sim", "utils", "vis"] -VERSION = "0.9.0" +# Dev is everything +extras['dev'] = list(chain(*extras.values())) -ext = '.pyx' if USE_CYTHON else '.c' +# All is everything but tests and docs +exclude_keys = ("tests", "docs", "dev") +ex_extras = dict(filter(lambda i: i[0] not in exclude_keys, extras.items())) +# Concatenate all the values together for 'all' +extras['all'] = list(chain.from_iterable(ex_extras.values())) + +################################################################################ +# Cython extensions +################################################################################ NUMPY_INC = numpy.get_include() -EXT_PACKAGES = { # C and Fortran extensions - "anapyio" : ["io", [NUMPY_INC, os.path.join(NAME, "io/src")], - [os.path.join(NAME, "io/anapyio" + ext), - os.path.join(NAME, "io/src/libf0.c"), - os.path.join(NAME, "io/src/anacompress.c"), - os.path.join(NAME, "io/src/anadecompress.c")]], - "radtrans" : ["utils", [NUMPY_INC], - [os.path.join(NAME, "utils/radtrans" + ext)]], - "utilsfast" : ["utils", [NUMPY_INC], - [os.path.join(NAME, "utils/utilsfast" + ext)]] +EXT_PACKAGES = { + "anapyio": ["io", [NUMPY_INC, os.path.join("helita", "io/src")], + [os.path.join("helita", "io/anapyio.pyx"), + os.path.join("helita", "io/src/libf0.c"), + os.path.join("helita", "io/src/anacompress.c"), + os.path.join("helita", "io/src/anadecompress.c")]], + "cstagger": ["sim", [NUMPY_INC], + [os.path.join("helita", "sim/cstagger.pyx")]], + "radtrans": ["utils", [NUMPY_INC], + [os.path.join("helita", "utils/radtrans.pyx")]], + "utilsfast": ["utils", [NUMPY_INC], + [os.path.join("helita", "utils/utilsfast.pyx")]] } -if USE_FORTRAN: - EXT_PACKAGES["trnslt"] = ["utils", [], [os.path.join(NAME, "utils/trnslt.f90")]] - extensions = [ Extension( - name="%s.%s.%s" % (NAME, pprop[0], pname), + name=f"helita.{pprop[0]}.{pname}", include_dirs=pprop[1], sources=pprop[2]) for pname, pprop in EXT_PACKAGES.items() ] +extensions = cythonize(extensions, compiler_directives={'language_level': "3"}) -if USE_CYTHON: # Always compile for Python 3 (v2 no longer supported) - extensions = cythonize(extensions, compiler_directives={'language_level' : "3"}) - +################################################################################ +# Setup +################################################################################ setup( - name=NAME, - version=VERSION, - description="Solar physics python tools from ITA/UiO", - author="Tiago M. D. Pereira et al.", - license="BSD", - url="http://%s.readthedocs.io" % NAME, - keywords=['astronomy', 'astrophysics', 'solar physics', 'space', 'science'], - classifiers=[ - 'Intended Audience :: Science/Research', - 'License :: OSI Approved :: BSD License', - 'Operating System :: OS Independent', - 'Programming Language :: C', - 'Programming Language :: Cython', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: Implementation :: CPython', - 'Topic :: Scientific/Engineering :: Astronomy', - 'Topic :: Scientific/Engineering :: Physics' - ], - packages=[NAME] + ["%s.%s" % (NAME, package) for package in PACKAGES], - package_data={'': ['*.pyx', '*.f90', 'data/*']}, + extras_require=extras, + use_scm_version=True, ext_modules=extensions, - python_requires='>=2.7', - use_2to3=False ) diff --git a/tox.ini b/tox.ini new file mode 100644 index 00000000..51bac61d --- /dev/null +++ b/tox.ini @@ -0,0 +1,30 @@ +[tox] +envlist = + py{38,39,310,311} + codestyle +requires = + setuptools + pip +isolated_build = true + +[testenv] +changedir = .tmp/{envname} +description = + run tests +setenv = + PYTEST_COMMAND = pytest -vvv -r a --pyargs helita +extras = + all + tests +commands = + pip freeze --all --no-input + {env:PYTEST_COMMAND} {posargs} + +[testenv:codestyle] +skip_install = true +description = Run all style and file checks with pre-commit +deps = + pre-commit +commands = + pre-commit install-hooks + pre-commit run --color always --all-files --show-diff-on-failure