I tried installing JAX under Windows 10 using Miniconda3:

```
conda create -n jax
conda activate jax
conda install pip
pip install jax
```

but that didn't work; `python -c "import jax"` raises an error.
Then, seeing that jax is in conda-forge, I tried:

```
conda deactivate
conda env remove -n jax
conda create -n jax
conda activate jax
conda install -c conda-forge jax
```

but that failed with an error as well.
Thanks for trying, @gdmcbain. I think my experience with trying to use JAX on Windows was limited to using either …
But I have a spare Linux machine where I can experiment with JAX.
I'd prefer not to actually depend on something like JAX. On the other hand, trying it out has already revealed what an interface utilizing these autodiff tools should look like. So maybe we can eventually provide some additional utilities that make it easier to use JAX, autograd or similar tools in scikit-fem. PS. Does it work in Docker?
I think it should. They do have a Dockerfile on their page; I can test-build an image today. In fact, if it works with Docker, then it's the easiest way for users to try it on Windows (and containerizing would help keep everything clean anyway). This is pretty much the same approach FEniCS takes to its Windows support.
Checked! It's working: demo container.
I tried sketching an example to reproduce #439 with JAX. This is extremely slow even with #442 and, as @kinnala said, it is more an issue of how to use JAX properly. I'll come back to it in the future to experiment and improve the speed, but I will have to learn JAX before that.

from jax.numpy import vectorize
import os
from time import time
from typing import List, Optional

import jax
import jax.numpy as jnp
import meshio
import numpy as np
from jax import jit, grad as jgrad, jacfwd, jacrev, pmap, vmap
from numba import njit, prange
from scipy.optimize import root
from scipy.sparse.linalg import spilu, LinearOperator

from skfem.io import from_meshio
from skfem.helpers import grad, dot, transpose
from skfem.models.elasticity import linear_elasticity
from skfem import *
def restBoundary(x, tol=1.e-6):
    """Mask the lateral faces of the unit cube.

    Returns a boolean array that is True where a point lies on one of the
    faces y = 0, y = 1, z = 0 or z = 1 but NOT on the left/right faces
    x = 0 / x = 1.  It selects where the surface traction is applied.

    Parameters
    ----------
    x : array_like
        Coordinates; ``x[0]``, ``x[1]`` and ``x[2]`` are the x, y and z
        components, each of any (common) shape.
    tol : float, optional
        Absolute tolerance for deciding that a coordinate equals 0 or 1.

    Returns
    -------
    numpy.ndarray of bool, shaped like ``x[0]``.
    """
    def on_unit_faces(c):
        # True where the coordinate c equals 0 or 1 within tol.
        return np.logical_or(np.abs(c) < tol, np.abs(c - 1.) < tol)

    top_bottom = on_unit_faces(x[1])
    front_back = on_unit_faces(x[2])
    left_right = on_unit_faces(x[0])
    return np.logical_and(np.logical_or(top_bottom, front_back),
                          np.logical_not(left_right))
def vdet(A):
    """Determinant of a (batch of) 3x3 matrices.

    ``A`` is indexed as ``A[i, j, ...]``: the first two axes are the matrix
    row and column; any trailing axes are batch dimensions.  Returns an
    array of shape ``A.shape[2:]`` (a scalar for a plain 3x3 matrix).
    Only indexing and arithmetic are used, so no scratch buffer is needed.
    The former ``jnp.zeros_like`` allocation was dead: it was overwritten
    immediately.
    """
    # Cofactor expansion along the first row.
    return (A[0, 0] * (A[1, 1] * A[2, 2] - A[1, 2] * A[2, 1])
            - A[0, 1] * (A[1, 0] * A[2, 2] - A[1, 2] * A[2, 0])
            + A[0, 2] * (A[1, 0] * A[2, 1] - A[1, 1] * A[2, 0]))
def vinv(A):
    """Inverse and determinant of a (batch of) 3x3 matrices.

    ``A`` is indexed as ``A[i, j, ...]`` (matrix axes first, batch axes
    trailing).  Returns ``(invA, detA)`` with ``invA.shape == A.shape`` and
    ``detA.shape == A.shape[2:]``.

    The inverse is filled entry by entry, which requires a mutable buffer.
    The previous ``jnp.zeros_like`` buffer was a JAX array; JAX arrays are
    immutable, so ``invA[0, 0] = ...`` raised a TypeError.  A NumPy buffer
    is used instead.
    """
    # Adjugate (transpose of the cofactor matrix); promote ints to float.
    adj = np.empty(np.shape(A), dtype=np.result_type(A, 1.0))
    adj[0, 0] = A[1, 1] * A[2, 2] - A[1, 2] * A[2, 1]
    adj[1, 0] = A[1, 2] * A[2, 0] - A[1, 0] * A[2, 2]
    adj[2, 0] = A[1, 0] * A[2, 1] - A[1, 1] * A[2, 0]
    adj[0, 1] = A[0, 2] * A[2, 1] - A[0, 1] * A[2, 2]
    adj[1, 1] = A[0, 0] * A[2, 2] - A[0, 2] * A[2, 0]
    adj[2, 1] = A[0, 1] * A[2, 0] - A[0, 0] * A[2, 1]
    adj[0, 2] = A[0, 1] * A[1, 2] - A[0, 2] * A[1, 1]
    adj[1, 2] = A[0, 2] * A[1, 0] - A[0, 0] * A[1, 2]
    adj[2, 2] = A[0, 0] * A[1, 1] - A[0, 1] * A[1, 0]
    # Laplace expansion along the first row, reusing the cofactors above.
    detA = A[0, 0] * adj[0, 0] + A[0, 1] * adj[1, 0] + A[0, 2] * adj[2, 0]
    return adj / detA, detA
def psi(F11, F12, F13, F21, F22, F23, F31, F32, F33):
    """Compressible neo-Hookean strain-energy density, component-wise form.

    The nine arguments are the entries F_ij of the deformation gradient,
    passed separately so that each can be differentiated on its own.
    Reads the module-level parameters ``mu`` and ``lmbda``:

        psi = mu/2 (I1 - 3) - mu ln J + lmbda/2 (J - 1)^2,

    with I1 the sum of squared entries of F and J = det F.
    """
    # I1 = F : F, the sum of all squared entries.
    first_invariant = (F11**2 + F12**2. + F13**2
                       + F21**2. + F22**2 + F23**2.
                       + F31**2. + F32**2. + F33**2)
    # J = det F, expanded explicitly.
    volume_ratio = (F11*F22*F33 - F11*F23*F32 - F12*F21*F33
                    + F12*F23*F31 + F13*F21*F32 - F13*F22*F31)
    mu_terms = mu/2.*(first_invariant - 3.) - mu * jnp.log(volume_ratio)
    lmbda_term = lmbda/2.*(volume_ratio-1.)**2.
    return mu_terms + lmbda_term
def psivec(F):
    """Array form of the neo-Hookean strain energy.

    ``F`` is the full deformation gradient indexed as ``F[i, j, ...]``.
    Returns the energy density with shape ``F.shape[2:]``.  Reads the
    module-level ``mu`` and ``lmbda``.
    """
    first_invariant = jnp.einsum("ij...,ij...", F, F)
    # jnp.linalg.det expects the two matrix axes last.
    matrices_last = jnp.moveaxis(F, (0, 1), (-2, -1))
    volume_ratio = jnp.linalg.det(matrices_last)
    mu_terms = mu/2.*(first_invariant - 3.) - mu * jnp.log(volume_ratio)
    return mu_terms + lmbda/2.*(volume_ratio-1.)**2.
def resvecrev(F):
    """Full reverse-mode Jacobian of psivec with respect to F.

    The result has shape ``psivec(F).shape + F.shape``, so it couples every
    batch entry with every other one; the off-diagonal couplings are zero,
    because psivec acts on each batch entry independently.
    """
    jacobian = jacrev(psivec)
    return jacobian(F)
def resvecfwd(F):
    """Full forward-mode Jacobian of psivec with respect to F.

    Same result layout as the reverse-mode variant:
    ``psivec(F).shape + F.shape``.
    """
    jacobian = jacfwd(psivec)
    return jacobian(F)
def resJax(F11, F12, F13, F21, F22, F23, F31, F32, F33):
    """Point-wise derivatives dpsi/dF_ij of the component-wise energy psi.

    The arguments are the nine entries of F (row-major order), given as
    arrays of a common (broadcastable) shape S.  Returns a float64 NumPy
    array of shape ``(9,) + S`` whose k-th slice is the derivative of psi
    with respect to the k-th argument.

    psi combines its arguments entry by entry, so the gradient of the
    *summed* energy with respect to each argument equals the point-wise
    derivative.  Asking ``jax.grad`` for all nine arguments at once needs a
    single reverse pass instead of nine separately vectorized ones.
    """
    components = jnp.broadcast_arrays(F11, F12, F13,
                                      F21, F22, F23,
                                      F31, F32, F33)
    total_energy_grad = jgrad(lambda *f: jnp.sum(psi(*f)),
                              argnums=tuple(range(9)))
    return np.stack([np.asarray(g, dtype=np.float64)
                     for g in total_energy_grad(*components)])
def resNorm(du):
    """Hand-derived first Piola-Kirchhoff stress of the neo-Hookean model.

    ``du`` is the displacement gradient indexed as ``du[i, j, ...]``; the
    deformation gradient is F = I + du.  Returns

        P = mu F - mu F^{-T} + lmbda J (J - 1) F^{-T},

    the derivative of the strain energy with respect to F, without
    automatic differentiation.  Reads the module-level ``mu`` and ``lmbda``.
    """
    F = np.zeros_like(du)
    for k in range(3):
        F[k, k] += 1.
    F += du
    Finv, J = vinv(F)
    FinvT = transpose(Finv)
    return mu * F - mu * FinvT + lmbda * J * (J - 1) * FinvT
def leftdofs(x):
    """Mask of points whose x-coordinate ``x[0]`` is below 1e-6.

    This selects the left face x = 0 when the mesh lies in x >= 0.
    """
    return x[0] < 1.e-6


def rightdofs(x):
    """Mask of points on the right face x = 1 (within 1e-6).

    Defined once here; the file previously repeated this identical
    definition.
    """
    return np.abs(x[0]-1.) < 1.e-6
def u2Right(x, y, z):
    """y-component of the displacement imposed on the right face.

    The point (y, z) is rotated by ``theta`` about (y0, z0) and the
    resulting displacement is multiplied by ``scale``; x is unused.
    Reads the module-level ``scale``, ``y0``, ``z0`` and ``theta``.
    """
    dy, dz = y - y0, z - z0
    rotated_y = y0 + dy*np.cos(theta) - dz*np.sin(theta)
    return scale*(rotated_y - y)
def u3Right(x, y, z):
    """z-component of the displacement imposed on the right face.

    The point (y, z) is rotated by ``theta`` about (y0, z0) and the
    resulting displacement is multiplied by ``scale``; x is unused.
    Reads the module-level ``scale``, ``y0``, ``z0`` and ``theta``.
    """
    dy, dz = y - y0, z - z0
    rotated_z = z0 + dy*np.sin(theta) + dz*np.cos(theta)
    return scale*(rotated_z - z)
# JAX computes in float32 unless 64-bit mode is on.  The Newton solve asks
# for tol=1e-10, far below single-precision round-off, so switch x64 on
# before any JAX array is created.
jax.config.update("jax_enable_x64", True)

t1 = time()
# Alternative: load an external mesh through meshio.
# msh = meshio.read("fenicsmesh.xdmf")
# mesh = from_meshio(msh)
mesh = MeshTet()
elem = ElementTetP1()
uelem = ElementVectorH1(elem)
iBasis = InteriorBasis(mesh, uelem, intorder=3)
fBasis = FacetBasis(mesh, uelem, intorder=3)
u = np.zeros(iBasis.N)

# Loads and material parameters (Lame parameters from E and nu).
bodyForce = np.array([0., -1./2, 0])
surfaceTraction = np.array([0.1, 0, 0])
E, nu = 10., 0.3
mu = E/2/(1+nu)
lmbda = 2*mu*nu/(1-2*nu)

dofs = {
    "left": iBasis.get_dofs(leftdofs),
    "right": iBasis.get_dofs(rightdofs),
}

# Dirichlet data: left face clamped; on the right face u^1 = u1Right, and
# (u^2, u^3) come from rotating by theta about (y0, z0), times `scale`.
scale = y0 = z0 = 0.5
theta = np.pi/3.
u1Right = 0.
u[dofs["left"].nodal['u^1']] = 0.
u[dofs["left"].nodal['u^2']] = 0.
u[dofs["left"].nodal['u^3']] = 0.
u[dofs["right"].nodal['u^1']] = u1Right
dofs2Right = dofs["right"].nodal['u^2']
dofs3Right = dofs["right"].nodal['u^3']
u[dofs2Right] = u2Right(*iBasis.doflocs[:, dofs2Right])
u[dofs3Right] = u3Right(*iBasis.doflocs[:, dofs3Right])

# Unknown dofs I and prescribed dofs D.  D is the exact complement of I,
# so every dof is either solved for or prescribed.  A separately located D
# (different tolerance) could leave dofs in neither set, and those entries
# of the np.empty_like buffers filled from I and D would stay uninitialised.
I = iBasis.complement_dofs(dofs)
D = np.setdiff1d(np.arange(iBasis.N), I)
@LinearForm
def rhs(v, w):
    """Body-force load -bodyForce . v (module-level ``bodyForce``).

    Decorated as a LinearForm so that ``asm(rhs, basis)`` can assemble it;
    a bare function has no assembly method.
    """
    return - dot(bodyForce, v)
@LinearForm
def rhsSurf(v, w):
    """Surface-traction load -surfaceTraction . v on the lateral faces.

    The load is switched off outside the faces selected by restBoundary,
    evaluated at the quadrature-point coordinates ``w.x``.  Decorated as a
    LinearForm so that ``asm`` can assemble it.
    """
    return - dot(surfaceTraction, v) * (restBoundary(w.x))
# @BilinearForm
# def bilinf(u, v, w):
# vals = np.zeros((3,3,3,3))
# return numbajac(*(grad(u), grad(v), grad(w['w'])))
@LinearForm
def materialResidualJax(v, w):
    """Internal-force term P(F) : grad(v) using the component-wise resJax.

    ``w["w"]`` is the current displacement field and F = I + grad(w["w"]).
    Both tensors are flattened row-major to nine components
    [X11, X12, X13, X21, ..., X33], the argument and output order of
    resJax.
    """
    gradu = grad(w["w"])
    gradv = grad(v)
    F = np.zeros_like(gradu)
    for k in range(3):
        F[k, k] += 1.
    F += gradu
    # Row-major flattening of the two tensor axes.
    gv = gradv.reshape((9,) + gradv.shape[2:])
    defG = F.reshape((9,) + F.shape[2:])
    return np.einsum("i...,i...", gv, resJax(*defG))
@LinearForm
def materialResidualJaxVec(v, w):
    """Internal-force term P(F) : grad(v) with P = dpsivec/dF from JAX.

    ``w["w"]`` is the current displacement field and F = I + grad(w["w"]).
    psivec acts on each batch entry of F independently, so the point-wise
    stress is the gradient of the *summed* energy.  The previous version
    built the full Jacobian ``jacrev(psivec)(F)``, of shape S + (3, 3) + S
    with S = F.shape[2:], and summed out the leading axes.  That costs
    memory and time quadratic in the number of quadrature points; this
    gradient is linear.
    """
    gradu = grad(w["w"])
    gradv = grad(v)
    F = np.zeros_like(gradu)
    for k in range(3):
        F[k, k] += 1.
    F += gradu
    P = jgrad(lambda G: jnp.sum(psivec(G)))(F)
    return np.einsum("ij...,ij...", gradv, np.asarray(P))
@LinearForm
def materialResidualNorm(v, w):
    """Internal-force term P : grad(v) with the hand-derived stress resNorm.

    ``w["w"]`` is the current displacement field.  Decorated as a
    LinearForm so that ``asm`` can assemble it.
    """
    return np.einsum("ij...,ij...", resNorm(grad(w["w"])), grad(v))
def residual(x: np.ndarray) -> np.ndarray:
    """Nonlinear residual restricted to the unknown dofs I.

    ``x`` holds the unknown values u[I]; the prescribed values u[D] come
    from the module-level ``u``.  Returns the assembled internal-force,
    surface-traction and body-force terms at the dofs I.
    """
    xfull = np.empty_like(u)
    xfull[I] = x
    xfull[D] = u[D]
    # Current displacement at the interior quadrature points.  `w` was
    # previously never defined here, which raised a NameError.
    w = iBasis.interpolate(xfull)
    # The load forms do not read w["w"]: rhsSurf needs only the facet
    # coordinates w.x, which fBasis supplies itself.  The interior-basis
    # field is therefore not passed to the facet assembly.
    localres = (asm(materialResidualJaxVec, iBasis, w=w)
                + asm(rhsSurf, fBasis)
                + asm(rhs, iBasis))
    return localres[I]
# Preconditioner for the inner Krylov solves: it starts as the identity, and
# update() later replaces M.matvec with an incomplete-LU solve.
M = LinearOperator((len(I), len(I)), matvec=lambda v: v)
def _refresh_preconditioner(x: np.ndarray) -> None:
    """Rebuild M.matvec as an ILU solve of the condensed elasticity matrix."""
    print('Updating Jacobian preconditioner.')
    u_full = np.empty_like(u)
    u_full[I] = x
    u_full[D] = u[D]
    J = asm(linear_elasticity(mu, lmbda), iBasis, w=iBasis.interpolate(u_full))
    JI = condense(J, D=D, expand=False).tocsc()
    M.matvec = spilu(JI).solve


def update(x: np.ndarray,
           _: Optional[np.ndarray] = None,
           i: List[int] = [0],
           period: int = 10) -> None:
    """Preconditioner hook, installed as ``M.update``.

    SciPy's Krylov Jacobian calls ``update(x, f)`` after each nonlinear
    step.  The ILU preconditioner is rebuilt on the first call and then
    every ``period`` calls.  The default list ``i`` is deliberately shared
    between calls; it acts as a persistent call counter.
    """
    if i[0] % period == 0:
        _refresh_preconditioner(x)
    i[0] += 1
M.update = update

# Newton-Krylov solve for the unknown dofs.  For scipy.optimize.root with
# method='krylov', the inner Krylov method and the preconditioner belong in
# options['jac_options'].  Placed at the top level of `options`, they are
# reported as unknown solver options and ignored.
sol = root(residual, u[I], method='krylov', tol=1.e-10,
           options={
               'disp': True,
               "line_search": "wolfe",
               "jac_options": {"method": "lgmres", "inner_M": M},
           })
if not sol.success:
    # Do not silently accept a non-converged iterate.
    print('root did not converge:', sol.message)
u[I] = sol.x
FYI, I am now able to install JAX on Windows. I was able to install it using the prebuilt wheels available from the links here. It was pretty straightforward. I am using it from within a conda environment (Miniconda on Windows).

```
>>> import sys
>>> sys.version
'3.9.5 | packaged by conda-forge | (default, Jun 19 2021, 00:22:33) [MSC v.1916 64 bit (AMD64)]'
>>> import jax
>>> jax.__version__
```

I haven't tested the example here yet, but might give it a go soon. It might be relevant in light of #890.
JAX was mentioned in #439 as a possible way of automating something like FEniCS's `derivative` for calculating Jacobians in nonlinear problems; however, one user noted that they had trouble installing it on Windows. A successful reimplementation of ex10 using JAX was demonstrated in #440; ex10 currently has a handwritten Jacobian (lines 15 to 19 in 85b83ec).
