Skip to content

Commit

Permalink
Initial JAX version
Browse files Browse the repository at this point in the history
  • Loading branch information
yaugenst committed Nov 22, 2023
1 parent 64e3f4e commit 92b56f2
Show file tree
Hide file tree
Showing 14 changed files with 175 additions and 188 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
[![tests](https://github.com/yaugenst/tofea/actions/workflows/run_tests.yml/badge.svg)](https://github.com/yaugenst/tofea/actions/workflows/run_tests.yml)
[![codecov](https://codecov.io/gh/yaugenst/tofea/graph/badge.svg?token=5Z2SYQ3CPM)](https://codecov.io/gh/yaugenst/tofea)

Simple [autograd](https://github.com/HIPS/autograd)-differentiable finite element analysis for heat conductivity and compliance problems.
Simple finite element analysis for heat conductivity and compliance problems written in [JAX](ttps://jax.readthedocs.io/en/latest/).

## Installation

Expand Down
42 changes: 22 additions & 20 deletions examples/compliance_2d.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,26 @@
#!/usr/bin/env python3

import autograd.numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import nlopt
import scipy.ndimage
from autograd import value_and_grad
from autograd.extend import defvjp, primitive
import numpy as np
from jax.scipy.signal import convolve

from tofea.fea2d import FEA2D_K

gaussian_filter = primitive(scipy.ndimage.gaussian_filter)
defvjp(
gaussian_filter,
lambda ans, x, *args, **kwargs: lambda g: gaussian_filter(g, *args, **kwargs), # noqa: ARG005
)

def simp_parametrization(shape, ks, vmin, vmax, penalty=3.0):
xy = jnp.linspace(-1, 1, ks)
xx, yy = jnp.meshgrid(xy, xy)
k = jnp.where(jnp.sqrt(xx**2 + yy**2) <= 1, 1, 0)
k /= jnp.sum(k)

def simp_parametrization(shape, sigma, vmin, vmax, penalty=3.0):
@jax.jit
def _parametrization(x):
x = np.reshape(x, shape)
x = gaussian_filter(x, sigma)
x = jnp.pad(x, ks // 2, mode="edge")
x = convolve(x, k, mode="valid")
x = vmin + (vmax - vmin) * x**penalty
return x

Expand All @@ -29,42 +30,43 @@ def _parametrization(x):
def main():
max_its = 50
volfrac = 0.3
sigma = 1.0
kernel_size = 5
shape = (120, 60)
nelx, nely = shape
emin, emax = 1e-6, 1

dofs = np.arange(2 * (nelx + 1) * (nely + 1)).reshape(nelx + 1, nely + 1, 2)
fixed = np.zeros_like(dofs, dtype=bool)
fixed = np.zeros_like(dofs, dtype="?")
load = np.zeros_like(dofs)

fixed[0, :, :] = 1
load[-1, -1, 1] = 1

fem = FEA2D_K(fixed)
parametrization = simp_parametrization(shape, sigma, emin, emax)
x0 = np.full(shape, volfrac)
parametrization = simp_parametrization(shape, kernel_size, emin, emax)
x0 = jnp.full(shape, volfrac)

plt.ion()
fig, ax = plt.subplots(1, 1, tight_layout=True)
im = ax.imshow(parametrization(x0).T, cmap="gray_r", vmin=emin, vmax=emax)

@value_and_grad
@jax.value_and_grad
def objective(x):
x = parametrization(x)
d = fem.displacement(x, load)
c = fem.compliance(x, d)
return c

@value_and_grad
@jax.jit
@jax.value_and_grad
def volume(x):
return np.mean(x)
return jnp.mean(x)

def volume_constraint(x, gd):
v, g = volume(x)
if gd.size > 0:
gd[:] = g.ravel()
return v - volfrac
return v.item() - volfrac

def nlopt_obj(x, gd):
v, g = objective(x)
Expand All @@ -75,7 +77,7 @@ def nlopt_obj(x, gd):
im.set_data(parametrization(x).T)
plt.pause(0.01)

return v
return v.item()

opt = nlopt.opt(nlopt.LD_MMA, x0.size)
opt.add_inequality_constraint(volume_constraint)
Expand Down
45 changes: 25 additions & 20 deletions examples/heat_2d.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,28 @@
#!/usr/bin/env python3

import autograd.numpy as np
from time import perf_counter

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import nlopt
import scipy.ndimage
from autograd import value_and_grad
from autograd.extend import defvjp, primitive
import numpy as np
from jax.scipy.signal import convolve

from tofea.fea2d import FEA2D_T

gaussian_filter = primitive(scipy.ndimage.gaussian_filter)
defvjp(
gaussian_filter,
lambda ans, x, *args, **kwargs: lambda g: gaussian_filter(g, *args, **kwargs), # noqa: ARG005
)

def simp_parametrization(shape, ks, vmin, vmax, penalty=3.0):
xy = jnp.linspace(-1, 1, ks)
xx, yy = jnp.meshgrid(xy, xy)
k = jnp.where(jnp.sqrt(xx**2 + yy**2) <= 1, 1, 0)
k /= jnp.sum(k)

def simp_parametrization(shape, sigma, vmin, vmax, penalty=3.0):
@jax.jit
def _parametrization(x):
x = np.reshape(x, shape)
x = gaussian_filter(x, sigma)
x = jnp.pad(x, ks // 2, mode="edge")
x = convolve(x, k, mode="valid")
x = vmin + (vmax - vmin) * x**penalty
return x

Expand All @@ -29,7 +32,7 @@ def _parametrization(x):
def main():
max_its = 100
volfrac = 0.5
sigma = 1.0
kernel_size = 5
shape = (100, 100)
nelx, nely = shape
cmin, cmax = 1e-4, 1
Expand All @@ -42,28 +45,28 @@ def main():
load[(0, -1), :] = 1

fem = FEA2D_T(fixed)
parametrization = simp_parametrization(shape, sigma, cmin, cmax)
x0 = np.full(shape, volfrac)
parametrization = simp_parametrization(shape, kernel_size, cmin, cmax)
x0 = jnp.full(shape, volfrac)

plt.ion()
fig, ax = plt.subplots(1, 1, tight_layout=True)
im = ax.imshow(parametrization(x0).T, cmap="gray_r", vmin=cmin, vmax=cmax)

@value_and_grad
@jax.value_and_grad
def objective(x):
x = parametrization(x)
t = fem.temperature(x, load)
return np.mean(t)
return jnp.mean(t)

@value_and_grad
@jax.value_and_grad
def volume(x):
return np.mean(x)
return jnp.mean(x)

def volume_constraint(x, gd):
v, g = volume(x)
if gd.size > 0:
gd[:] = g.ravel()
return v - volfrac
return v.item() - volfrac

def nlopt_obj(x, gd):
v, g = objective(x)
Expand All @@ -74,15 +77,17 @@ def nlopt_obj(x, gd):
im.set_data(parametrization(x).T)
plt.pause(0.01)

return v
return v.item()

opt = nlopt.opt(nlopt.LD_MMA, x0.size)
opt.add_inequality_constraint(volume_constraint)
opt.set_min_objective(nlopt_obj)
opt.set_lower_bounds(0)
opt.set_upper_bounds(1)
opt.set_maxeval(max_its)
t0 = perf_counter()
opt.optimize(x0.ravel())
print(perf_counter() - t0)

plt.show(block=True)

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ classifiers = [
]
dynamic = ["version"]
requires-python = ">=3.10,<=3.12"
dependencies = ["numpy", "scipy", "sympy", "autograd"]
dependencies = ["numpy", "scipy", "sympy", "jax"]

[project.optional-dependencies]
tests = ["pytest", "pytest-cov"]
Expand Down
8 changes: 8 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import pytest
from numpy.random import Generator, default_rng


@pytest.fixture(scope="session")
def rng() -> Generator:
seed = 365235228
return default_rng(seed)
6 changes: 3 additions & 3 deletions tests/test_elements.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
import pytest
from jax import Array

from tofea.elements import Q4Element_K, Q4Element_T

Expand All @@ -11,7 +11,7 @@ def q4element_k_instance(self):

def test_element(self, q4element_k_instance):
element = q4element_k_instance.element
assert isinstance(element, np.ndarray)
assert isinstance(element, Array)
assert element.shape == (8, 8)


Expand All @@ -22,5 +22,5 @@ def q4element_t_instance(self):

def test_element(self, q4element_t_instance):
element = q4element_t_instance.element
assert isinstance(element, np.ndarray)
assert isinstance(element, Array)
assert element.shape == (4, 4)
51 changes: 22 additions & 29 deletions tests/test_fea2d.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,11 @@
import numpy as np
import pytest
from autograd.test_util import check_grads
from jax import Array
from jax.test_util import check_grads

from tofea.fea2d import FEA2D_K, FEA2D_T


@pytest.fixture()
def rng():
seed = 36523523
return np.random.default_rng(seed)


class TestFEA2DK:
@pytest.fixture()
def fea2d_k_instance(self):
Expand All @@ -29,42 +24,44 @@ def test_shape(self, fea2d_k_instance):

def test_dofs(self, fea2d_k_instance):
dofs = fea2d_k_instance.dofs
assert isinstance(dofs, np.ndarray)
assert isinstance(dofs, Array)
assert dofs.shape == (50,)
assert np.all(dofs == np.arange(50))

def test_fixdofs(self, fea2d_k_instance):
fixdofs = fea2d_k_instance.fixdofs
assert isinstance(fixdofs, np.ndarray)
assert isinstance(fixdofs, Array)
assert fixdofs.size == fea2d_k_instance.fixed[0].size

def test_freedofs(self, fea2d_k_instance):
freedofs = fea2d_k_instance.freedofs
assert isinstance(freedofs, np.ndarray)
assert isinstance(freedofs, Array)
assert freedofs.size == 50 - fea2d_k_instance.fixdofs.size

def test_displacement_grads(self, fea2d_k_instance, x_and_b):
x, b = x_and_b
check_grads(
lambda x_: fea2d_k_instance.displacement(x_, b),
modes=["fwd", "rev"],
(x,),
order=1,
)(x)
modes=["rev"],
)
check_grads(
lambda b_: fea2d_k_instance.displacement(x, b_),
modes=["fwd", "rev"],
(b,),
order=1,
)(b)
modes=["rev"],
)

def test_compliance_grads(self, fea2d_k_instance, x_and_b, rng):
x, _ = x_and_b
d = rng.random(fea2d_k_instance.dofs.shape)
check_grads(
lambda x_: fea2d_k_instance.compliance(x_, d), modes=["fwd", "rev"], order=1
)(x)
lambda x_: fea2d_k_instance.compliance(x_, d), (x,), order=1, modes=["rev"]
)
check_grads(
lambda d_: fea2d_k_instance.compliance(x, d_), modes=["fwd", "rev"], order=1
)(d)
lambda d_: fea2d_k_instance.compliance(x, d_), (d,), order=1, modes=["rev"]
)


class TestFEA2DT:
Expand All @@ -85,29 +82,25 @@ def test_shape(self, fea2d_t_instance):

def test_dofs(self, fea2d_t_instance):
dofs = fea2d_t_instance.dofs
assert isinstance(dofs, np.ndarray)
assert isinstance(dofs, Array)
assert dofs.shape == (25,)
assert np.all(dofs == np.arange(25))

def test_fixdofs(self, fea2d_t_instance):
fixdofs = fea2d_t_instance.fixdofs
assert isinstance(fixdofs, np.ndarray)
assert isinstance(fixdofs, Array)
assert fixdofs.size == 1

def test_freedofs(self, fea2d_t_instance):
freedofs = fea2d_t_instance.freedofs
assert isinstance(freedofs, np.ndarray)
assert isinstance(freedofs, Array)
assert freedofs.size == 25 - fea2d_t_instance.fixdofs.size

def test_temperature_grads(self, fea2d_t_instance, x_and_b):
x, b = x_and_b
check_grads(
lambda x_: fea2d_t_instance.temperature(x_, b),
modes=["fwd", "rev"],
order=1,
)(x)
lambda x_: fea2d_t_instance.temperature(x_, b), (x,), order=1, modes=["rev"]
)
check_grads(
lambda b_: fea2d_t_instance.temperature(x, b_),
modes=["fwd", "rev"],
order=1,
)(b)
lambda b_: fea2d_t_instance.temperature(x, b_), (b,), modes=["rev"], order=1
)
Loading

0 comments on commit 92b56f2

Please sign in to comment.