Skip to content

Commit

Permalink
Maint (#337)
Browse files Browse the repository at this point in the history
* Reduce code by making use of new SciPy new features (complex-valued map_coordinate; lazy loading).
* Add current year to docs
  • Loading branch information
prisae authored Jul 19, 2024
1 parent a16e127 commit fad7142
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 18 deletions.
6 changes: 5 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@ Changelog
latest
------

- Add notes for ``ipympl`` (interactive plots in modern Jupyter).
Maintenance

- Add notes for ``ipympl`` (interactive plots in modern Jupyter).
- Reduce code by making use of new SciPy new features (complex-valued
map_coordinate; lazy loading).


v1.8.3 : tol_gradient isfinite
Expand Down
3 changes: 2 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import time
from emg3d import __version__

# ==== 1. Extensions ====
Expand Down Expand Up @@ -56,7 +57,7 @@
# General information about the project.
project = 'emg3d'
author = 'The emsig community'
copyright = f'2018, {author}'
copyright = f'2018-{time.strftime("%Y")}, {author}'

# |version| and |today| tags (|release|-tag is not used).
version = __version__
Expand Down
10 changes: 1 addition & 9 deletions emg3d/maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,15 +545,7 @@ def interp_spline_3d(points, values, xi, **kwargs):
bounds_error=False, fill_value='extrapolate'
)(xi[:, i])

# `map_coordinates` only works for real data; split it up if complex.
# Note: SciPy 1.6 (12/2020) introduced complex-valued
# ndimage.map_coordinates; replace eventually.
values_x = sp.ndimage.map_coordinates(values.real, coords, **kwargs)
if 'complex' in values.dtype.name:
imag = sp.ndimage.map_coordinates(values.imag, coords, **kwargs)
values_x = values_x + 1j*imag

return values_x
return sp.ndimage.map_coordinates(values, coords, **kwargs)


@nb.njit(**_numba_setting)
Expand Down
12 changes: 5 additions & 7 deletions emg3d/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@

import numpy as np
import scipy as sp
import scipy.linalg as sl
import scipy.sparse.linalg as ssl

from emg3d import core, meshes, models, fields, utils

Expand Down Expand Up @@ -311,7 +309,7 @@ def solve(model, sfield, sslsolver=True, semicoarsening=True,
var.cprint(var, 2)

# Compute reference error for tolerance.
var.l2_refe = sl.norm(sfield.field, check_finite=False)
var.l2_refe = sp.linalg.norm(sfield.field, check_finite=False)
var.error_at_cycle[0] = var.l2_refe

# Check sfield.
Expand Down Expand Up @@ -704,7 +702,7 @@ def amatvec(efield):
return -rfield.field

# Initiate LinearOperator A x.
A = ssl.LinearOperator(
A = sp.sparse.linalg.LinearOperator(
shape=(sfield.field.size, sfield.field.size),
dtype=sfield.field.dtype, matvec=amatvec)

Expand All @@ -725,7 +723,7 @@ def mg_matvec(sfield):
# Initiate LinearOperator M.
M = None
if var.cycle:
M = ssl.LinearOperator(
M = sp.sparse.linalg.LinearOperator(
shape=(sfield.field.size, sfield.field.size),
dtype=sfield.field.dtype, matvec=mg_matvec)

Expand Down Expand Up @@ -762,7 +760,7 @@ def callback(x):
# The ssl solvers do not abort if the norm diverges or is not finite. We
# therefore throw an exception in `_terminate`, and catch it here.
try:
efield.field, i = getattr(ssl, var.sslsolver)(
efield.field, i = getattr(sp.sparse.linalg, var.sslsolver)(
A=A, b=sfield.field, x0=efield.field, **{TOL: var.tol},
maxiter=var.ssl_maxit, atol=1e-30, M=M, callback=callback)
except _ConvergenceError:
Expand Down Expand Up @@ -1065,7 +1063,7 @@ def residual(model, sfield, efield, norm=False):

# Return error if norm.
if norm:
return sl.norm(rfield.field, check_finite=False)
return sp.linalg.norm(rfield.field, check_finite=False)

# Return residual if not norm.
else:
Expand Down

0 comments on commit fad7142

Please sign in to comment.