-
I tried installing JAX under Windows 10 using Miniconda3:

```
conda create -n jax
conda activate jax
conda install pip
pip install jax
```

but that didn't work; `python -c "import jax"` raises an error. Then, seeing that `jax` is in conda-forge, I tried:

```
conda deactivate
conda env remove -n jax
conda create -n jax
conda activate jax
conda install -c conda-forge jax
```

but that failed with an error as well.
-
Thanks for trying, @gdmcbain. I think my history with trying to use JAX on Windows was limited to using either
But I have a spare Linux machine where I can experiment with JAX.
-
I'd prefer not to actually depend on something like JAX. On the other hand, trying it out has already revealed what an interface using these autodiff tools should look like. So maybe we can eventually provide some additional utilities that make it easier to use JAX, autograd or similar tools with scikit-fem. PS. Does it work in Docker?
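One possible shape for such a utility, sketched under assumptions: scikit-fem evaluates fields at quadrature points with the tensor components first, e.g. a deformation gradient of shape `(3, 3, nelems, nqp)`, whereas JAX transformations are simplest to write for a single point. The hypothetical helper `pointwise_grad` below moves the component axes to the end, flattens the points, applies `jax.vmap(jax.grad(W))` and restores the layout. The material constants are illustrative, and the result is checked against the closed-form first Piola-Kirchhoff stress of the neo-Hookean energy used later in this thread.

```python
import jax
import jax.numpy as jnp
import numpy as np

mu, lmbda = 1.0, 2.0  # illustrative material constants (assumption)


def psi_point(F):
    """Compressible neo-Hookean energy density for one 3x3 deformation gradient."""
    I1 = jnp.sum(F * F)
    J = jnp.linalg.det(F)
    return mu / 2. * (I1 - 3.) - mu * jnp.log(J) + lmbda / 2. * (J - 1.) ** 2


def pointwise_grad(W):
    """Hypothetical helper: lift a single-point energy W(F) to arrays shaped (3, 3, ...)."""
    dW = jax.vmap(jax.grad(W))  # dW/dF at each point, vectorized over a leading batch axis

    def wrapped(F):
        F = jnp.asarray(F)
        batch_shape = F.shape[2:]
        Fb = jnp.moveaxis(F.reshape(3, 3, -1), -1, 0)  # (npoints, 3, 3)
        Pb = dW(Fb)                                     # (npoints, 3, 3)
        return jnp.moveaxis(Pb, 0, -1).reshape((3, 3) + batch_shape)

    return wrapped


stress = jax.jit(pointwise_grad(psi_point))

# Deformation gradients at 4 elements x 5 quadrature points, close to the identity.
rng = np.random.default_rng(0)
F = np.eye(3)[:, :, None, None] + 0.1 * rng.standard_normal((3, 3, 4, 5))
P = stress(F)

# Closed form: P = mu F - mu F^{-T} + lmbda J (J - 1) F^{-T}
Ft = np.moveaxis(F, (0, 1), (-2, -1))             # (4, 5, 3, 3)
FinvT = np.swapaxes(np.linalg.inv(Ft), -1, -2)
J = np.linalg.det(Ft)[..., None, None]
P_exact = mu * Ft - mu * FinvT + lmbda * J * (J - 1.) * FinvT
P_exact = np.moveaxis(P_exact, (-2, -1), (0, 1))  # back to (3, 3, 4, 5)
print(np.max(np.abs(np.asarray(P) - P_exact)))
```

Vectorizing a single-point function with `jax.vmap` keeps each derivative local to one point, so no cross-point Jacobian is ever formed.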
-
I think it should. They do have a Dockerfile on their page, and I can test-build an image today. In fact, if it works with Docker, that's the easiest way for users to try it on Windows (and even otherwise, containerizing would help keep everything clean). It's pretty much the same approach FEniCS takes to its Windows support.
-
Checked! And it's working: demo container
-
I tried sketching an example that reproduces #439 with JAX. It is extremely slow, even with #442, and, as @kinnala said, that is more an issue of how to use JAX properly. I'll come back to it in the future to experiment and improve the speed, but I will have to learn JAX first.

```python
from jax.numpy import vectorize
import jax.numpy as jnp
import numpy as np
from skfem.io import from_meshio
from skfem.helpers import grad, dot, transpose
from jax import jit, grad as jgrad, jacfwd, jacrev, pmap, vmap
from numba import njit, prange
from time import time
import meshio, os
from scipy.sparse.linalg import spilu, LinearOperator
from scipy.optimize import root
from typing import List, Optional
from skfem.models.elasticity import linear_elasticity
from skfem import *


@njit
def restBoundary(x):
    """
    Mask of the points that are not on the left/right faces,
    i.e. where the surface traction is applied.
    """
    topBottom = np.logical_or(
        np.abs(x[1]) < 1.e-6, np.abs(x[1] - 1.) < 1.e-6
    )
    frontBack = np.logical_or(
        np.abs(x[2]) < 1.e-6, np.abs(x[2] - 1.) < 1.e-6
    )
    leftRight = np.logical_or(
        np.abs(x[0]) < 1.e-6, np.abs(x[0] - 1.) < 1.e-6
    )
    return np.logical_and(np.logical_or(topBottom, frontBack), np.logical_not(leftRight))


@jit
def vdet(A):
    # Determinant of a field of 3x3 matrices stored as A[i, j, ...].
    return (A[0, 0] * (A[1, 1] * A[2, 2] - A[1, 2] * A[2, 1])
            - A[0, 1] * (A[1, 0] * A[2, 2] - A[1, 2] * A[2, 0])
            + A[0, 2] * (A[1, 0] * A[2, 1] - A[1, 1] * A[2, 0]))


@jit
def vinv(A):
    # JAX arrays are immutable: fill the inverse with .at[...].set(...)
    # instead of in-place item assignment, which raises an error.
    invA = jnp.zeros_like(A)
    detA = vdet(A)
    invA = invA.at[0, 0].set((-A[1, 2] * A[2, 1] + A[1, 1] * A[2, 2]) / detA)
    invA = invA.at[1, 0].set((A[1, 2] * A[2, 0] - A[1, 0] * A[2, 2]) / detA)
    invA = invA.at[2, 0].set((-A[1, 1] * A[2, 0] + A[1, 0] * A[2, 1]) / detA)
    invA = invA.at[0, 1].set((A[0, 2] * A[2, 1] - A[0, 1] * A[2, 2]) / detA)
    invA = invA.at[1, 1].set((-A[0, 2] * A[2, 0] + A[0, 0] * A[2, 2]) / detA)
    invA = invA.at[2, 1].set((A[0, 1] * A[2, 0] - A[0, 0] * A[2, 1]) / detA)
    invA = invA.at[0, 2].set((-A[0, 2] * A[1, 1] + A[0, 1] * A[1, 2]) / detA)
    invA = invA.at[1, 2].set((A[0, 2] * A[1, 0] - A[0, 0] * A[1, 2]) / detA)
    invA = invA.at[2, 2].set((-A[0, 1] * A[1, 0] + A[0, 0] * A[1, 1]) / detA)
    return invA, detA


@jit
def psi(F11, F12, F13, F21, F22, F23, F31, F32, F33):
    # Neo-Hookean energy density with the nine components of F passed separately.
    I1 = (F11**2 + F12**2 + F13**2 + F21**2 + F22**2 + F23**2
          + F31**2 + F32**2 + F33**2)  # = jnp.einsum('ij...,ij...', F, F)
    J = (F11*F22*F33 - F11*F23*F32 - F12*F21*F33
         + F12*F23*F31 + F13*F21*F32 - F13*F22*F31)  # = vdet(F)
    return mu/2.*(I1 - 3.) - mu * jnp.log(J) + lmbda/2.*(J - 1.)**2.


@jit
def psivec(F):
    # Same energy density, evaluated on the whole array F[i, j, ...] at once.
    I1 = jnp.einsum("ij...,ij...", F, F)
    J = jnp.linalg.det(jnp.einsum("ij...->...ij", F))
    return mu/2.*(I1 - 3.) - mu * jnp.log(J) + lmbda/2.*(J - 1.)**2.


@jit
def resvecrev(F):
    return jacrev(psivec)(F)


@jit
def resvecfwd(F):
    return jacfwd(psivec)(F)


# @jit
def resJax(F11, F12, F13, F21, F22, F23, F31, F32, F33):
    # dpsi/dF_k for each of the nine components, vectorized over the trailing axes.
    out = np.zeros((9,) + F11.shape)
    for i in range(9):
        out[i] = vectorize(jgrad(psi, i))(F11, F12, F13, F21, F22, F23, F31, F32, F33)
    return out


def resNorm(du):
    # Closed-form first Piola-Kirchhoff stress. Plain NumPy instead of @njit:
    # numba cannot compile calls to the JAX-jitted vinv or to skfem.helpers.transpose.
    F = np.zeros_like(du)
    F[0, 0] += 1.
    F[1, 1] += 1.
    F[2, 2] += 1.
    F += du
    Finv, J = map(np.asarray, vinv(F))
    return mu * F - mu * transpose(Finv) + lmbda * J * (J - 1) * transpose(Finv)


@njit
def leftdofs(x):
    return x[0] < 1.e-6


@njit
def rightdofs(x):
    return np.abs(x[0] - 1.) < 1.e-6


@njit
def u2Right(x, y, z):
    # Rotation by theta about the x-parallel axis through (y0, z0), scaled by `scale`.
    return scale*(y0 + (y - y0)*np.cos(theta) - (z - z0)*np.sin(theta) - y)


@njit
def u3Right(x, y, z):
    return scale*(z0 + (y - y0)*np.sin(theta) + (z - z0)*np.cos(theta) - z)


t1 = time()
# msh = meshio.xdmf.read("fenicsmesh.xdmf")
# mesh = from_meshio(msh)
mesh = MeshTet()
mesh.refine(2)
elem = ElementTetP1()
uelem = ElementVectorH1(elem)
iBasis = InteriorBasis(mesh, uelem, intorder=3)
fBasis = FacetBasis(mesh, uelem, intorder=3)
u = np.zeros(iBasis.N)

bodyForce = np.array([0., -1./2, 0])
surfaceTraction = np.array([0.1, 0, 0])
E, nu = 10., 0.3
mu = E/2/(1+nu)
lmbda = 2*mu*nu/(1-2*nu)

dofs = {
    "left": iBasis.get_dofs(leftdofs),
    "right": iBasis.get_dofs(rightdofs)
}

# Dirichlet data: clamp the left face, rotate the right face
scale = y0 = z0 = 0.5
theta = np.pi/3.
u1Right = 0.
u[dofs["left"].nodal['u^1']] = 0.
u[dofs["left"].nodal['u^2']] = 0.
u[dofs["left"].nodal['u^3']] = 0.
u[dofs["right"].nodal['u^1']] = 0.
dofs2Right = dofs["right"].nodal['u^2']
dofs3Right = dofs["right"].nodal['u^3']
u[dofs2Right] = u2Right(*iBasis.doflocs[:, dofs2Right])
u[dofs3Right] = u3Right(*iBasis.doflocs[:, dofs3Right])

I = iBasis.complement_dofs(dofs)
D = iBasis.get_dofs(lambda x: np.logical_or(np.isclose(x[0], 0.), np.isclose(x[0], 1.))).all()


@LinearForm
def rhs(v, w):
    return - dot(bodyForce, v)


@LinearForm
def rhsSurf(v, w):
    return - dot(surfaceTraction, v) * (restBoundary(w.x))


# @BilinearForm
# def bilinf(u, v, w):
#     vals = np.zeros((3,3,3,3))
#     return numbajac(*(grad(u), grad(v), grad(w['w'])))


@LinearForm
def materialResidualJax(v, w):
    # Residual with dpsi/dF from nine scalar jax.grad calls.
    gradu = grad(w["w"])
    gradv = grad(v)
    gv = np.array([
        gradv[0, 0], gradv[0, 1], gradv[0, 2],
        gradv[1, 0], gradv[1, 1], gradv[1, 2],
        gradv[2, 0], gradv[2, 1], gradv[2, 2]
    ])
    F = np.zeros_like(gradu)
    F[0, 0] += 1.
    F[1, 1] += 1.
    F[2, 2] += 1.
    F += gradu
    defG = np.array([
        F[0, 0], F[0, 1], F[0, 2],
        F[1, 0], F[1, 1], F[1, 2],
        F[2, 0], F[2, 1], F[2, 2]
    ])
    return np.einsum("i...,i...", gv, resJax(*defG))


@LinearForm
def materialResidualJaxVec(v, w):
    # Residual with dpsi/dF from jacrev of the vectorized energy.
    gradu = grad(w["w"])
    gradv = grad(v)
    F = np.zeros_like(gradu)
    F[0, 0] += 1.
    F[1, 1] += 1.
    F[2, 2] += 1.
    F += gradu
    Jw = np.sum(resvecrev(F), axis=0).sum(axis=0)  # np.einsum("ik...->...", resvecrev(F))
    # print(Jw.shape, gradv.shape)
    return np.einsum("ij...,ij...", gradv, Jw)


@LinearForm
def materialResidualNorm(v, w):
    # Residual with the closed-form stress.
    return np.einsum("ij...,ij...", resNorm(w["w"].grad), grad(v))


def residual(x: np.ndarray) -> np.ndarray:
    # Assemble the residual from the interior values x plus the Dirichlet data.
    xfull = np.empty_like(u)
    xfull[I] = x
    xfull[D] = u[D]
    w = iBasis.interpolate(xfull)
    localres = asm(materialResidualJaxVec, iBasis, w=w) + asm(rhsSurf, fBasis, w=w) + asm(rhs, iBasis)
    return localres[I]


M = LinearOperator([len(I)]*2, matvec=lambda x: x)


def update(x: np.ndarray,
           _: Optional[np.ndarray] = None,
           i: List[int] = [0],
           period: int = 10) -> None:
    # Refresh the ILU preconditioner every `period` nonlinear iterations;
    # the mutable default `i` serves as a call counter.
    if i[0] % period == 0:
        print('Updating Jacobian preconditioner.')
        u_full = np.empty_like(u)
        u_full[I] = x
        u_full[D] = u[D]
        J = asm(linear_elasticity(mu, lmbda), iBasis, w=iBasis.interpolate(u_full))
        JI = condense(J, D=D, expand=False).tocsc()
        JI_ilu = spilu(JI)
        M.matvec = JI_ilu.solve
    i[0] += 1


M.update = update
M.update(u[I])

# JFNK: Jacobian-free Newton-Krylov via scipy's 'krylov' root finder with LGMRES,
# preconditioned by an ILU of the linear-elasticity stiffness matrix.
u[I] = root(residual, u[I], method='krylov', tol=1.e-10,
            options={
                'disp': True, "line_search": "wolfe",
                "jac_options": {
                    "method": "lgmres", "inner_M": M
                }
            }
            ).x
```
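A likely contributor to the slowness, offered as a hypothesis rather than a measurement: `jacrev(psivec)` differentiates every quadrature-point energy with respect to every quadrature-point deformation gradient, so for `F` of shape `(3, 3, nelems, nqp)` it builds an array of shape `(nelems, nqp, 3, 3, nelems, nqp)` that is zero except where the two point indices coincide, and the two `.sum(axis=0)` calls then discard the zeros. Because the energy is pointwise, the gradient of the *summed* energy gives the same `(3, 3, nelems, nqp)` result in a single reverse pass. The sketch below copies `psivec` from the example above with illustrative constants and checks that both approaches agree on a small input.

```python
import jax
import jax.numpy as jnp
import numpy as np

mu, lmbda = 1.0, 2.0  # illustrative material constants (assumption)


def psivec(F):
    # Neo-Hookean energy at every quadrature point; F has shape (3, 3, nelems, nqp).
    I1 = jnp.einsum("ij...,ij...", F, F)
    J = jnp.linalg.det(jnp.einsum("ij...->...ij", F))
    return mu / 2. * (I1 - 3.) - mu * jnp.log(J) + lmbda / 2. * (J - 1.) ** 2


@jax.jit
def residual_dense(F):
    # Approach from the example: full Jacobian, then sum out the output axes.
    return jax.jacrev(psivec)(F).sum(axis=0).sum(axis=0)


@jax.jit
def residual_summed(F):
    # Pointwise energies are independent, so the gradient of the total suffices.
    return jax.grad(lambda G: jnp.sum(psivec(G)))(F)


rng = np.random.default_rng(1)
F = np.eye(3)[:, :, None, None] + 0.1 * rng.standard_normal((3, 3, 6, 4))
dense = residual_dense(F)
summed = residual_summed(F)
print(dense.shape, summed.shape, float(jnp.max(jnp.abs(dense - summed))))
```

On a real mesh the dense intermediate grows with the square of the total number of quadrature points, which is consistent with, though it does not prove, the slowdown reported above.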
-
FYI, I am now able to install JAX on Windows, using the prebuilt wheels available from the links here. It was pretty straightforward. I am using it from within a conda environment (Miniconda on Windows):

```pycon
>>> import sys
>>> sys.version
'3.9.5 | packaged by conda-forge | (default, Jun 19 2021, 00:22:33) [MSC v.1916 64 bit (AMD64)]'
>>> import jax
>>> jax.__version__
'0.2.26'
```

I haven't tested the example here yet, but might give it a go soon. It might be relevant in light of #890.
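A quick way to confirm that an installation works, whether from the prebuilt wheels, conda-forge or a Docker image, is a short smoke test that imports JAX, lists the devices it found, and runs a jitted computation and a gradient. This is a generic sketch, not a script from the thread.

```python
import sys

import jax
import jax.numpy as jnp

print(sys.version)
print("jax", jax.__version__)
print("devices:", jax.devices())  # e.g. a single CPU device on a CPU-only install


@jax.jit
def sum_of_squares(x):
    return jnp.sum(x ** 2)


# Compiled by XLA on the first call; for x = [1, 2, 3] the value is 1 + 4 + 9.
x = jnp.arange(1.0, 4.0)
value = float(sum_of_squares(x))
print("sum of squares:", value)

# Reverse-mode autodiff: d/dx sum(x^2) = 2x.
g = jax.grad(sum_of_squares)(x)
print("gradient:", g)
```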
-
JAX was mentioned in #439 as a possible way of automating something like FEniCS's `derivative` for calculating Jacobians in nonlinear problems; however, one user noted that they …

A successful reimplementation using JAX of ex10, which currently has a handwritten Jacobian (scikit-fem/docs/examples/ex10.py, lines 15 to 19 in 85b83ec), was demonstrated in #440.
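To illustrate what replacing a handwritten Jacobian with autodiff involves, here is a sketch for a scalar problem whose weak form derives from an energy density W(∇u). Taking, purely as an illustrative assumption (not necessarily the problem solved in ex10), the minimal-surface density W(g) = sqrt(1 + |g|²), `jax.grad` yields the flux ∂W/∂g that multiplies ∇v in the residual, and `jax.jacfwd` of that yields the tangent matrix that sits between ∇du and ∇v in the Jacobian form. Both are checked against hand-derived expressions; the names are hypothetical and this is not the code from #440.

```python
import jax
import jax.numpy as jnp
import numpy as np


def W(g):
    # Assumed energy density of the gradient g = grad(u) at a single point.
    return jnp.sqrt(1.0 + jnp.dot(g, g))


flux = jax.vmap(jax.grad(W))                 # dW/dg at each point, shape (npoints, 2)
tangent = jax.vmap(jax.jacfwd(jax.grad(W)))  # d2W/dg2 at each point, shape (npoints, 2, 2)

rng = np.random.default_rng(2)
g = rng.standard_normal((7, 2))  # gradients at 7 points in 2D

# Hand-derived: flux = g / s, tangent = I / s - g g^T / s^3, with s = sqrt(1 + |g|^2)
s = np.sqrt(1.0 + np.sum(g * g, axis=1))
flux_hand = g / s[:, None]
tangent_hand = (np.eye(2)[None] / s[:, None, None]
                - np.einsum("pi,pj->pij", g, g) / s[:, None, None] ** 3)

print(np.max(np.abs(np.asarray(flux(g)) - flux_hand)))
print(np.max(np.abs(np.asarray(tangent(g)) - tangent_hand)))
```

In a scikit-fem form, the flux would presumably be contracted with `grad(v)` in the residual and the tangent with `grad(du)` and `grad(v)` in the bilinear form, so only W itself has to be written by hand.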