Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add FiniteElement python wrapper #3542

Merged
merged 23 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions python/dolfinx/fem/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
locate_dofs_topological,
)
from dolfinx.fem.dofmap import DofMap
from dolfinx.fem.element import CoordinateElement, coordinate_element
from dolfinx.fem.element import CoordinateElement, FiniteElement, coordinate_element, finiteelement
from dolfinx.fem.forms import (
Form,
compile_form,
Expand Down Expand Up @@ -91,7 +91,11 @@ def create_interpolation_data(
"""
return _PointOwnershipData(
_create_interpolation_data(
V_to.mesh._cpp_object.geometry, V_to.element, V_from.mesh._cpp_object, cells, padding
V_to.mesh._cpp_object.geometry,
V_to.element._cpp_object,
V_from.mesh._cpp_object,
cells,
padding,
)
)

Expand Down Expand Up @@ -169,6 +173,7 @@ def compute_integration_domains(
"DofMap",
"ElementMetaData",
"Expression",
"FiniteElement",
"Form",
"Function",
"FunctionSpace",
Expand All @@ -189,6 +194,7 @@ def compute_integration_domains(
"dirichletbc",
"discrete_gradient",
"extract_function_spaces",
"finiteelement",
"form",
"form_cpp_class",
"functionspace",
Expand Down
193 changes: 191 additions & 2 deletions python/dolfinx/fem/element.py
schnellerhase marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (C) 2024 Garth N. Wells
# Copyright (C) 2024 Garth N. Wells and Paul T. Kühner
#
# This file is part of DOLFINx (https://www.fenicsproject.org)
#
Expand All @@ -12,6 +12,8 @@
import numpy.typing as npt

import basix
import ufl
import ufl.finiteelement
from dolfinx import cpp as _cpp


Expand Down Expand Up @@ -93,7 +95,7 @@ def pull_back(
``shape=(num_points, geometrical_dimension)``.
cell_geometry: Physical coordinates describing the cell,
shape ``(num_of_geometry_basis_functions, geometrical_dimension)``
They can be created by accessing `geometry.x[geometry.dofmap.cell_dofs(i)]`,
They can be created by accessing ``geometry.x[geometry.dofmap.cell_dofs(i)]``,

Returns:
Reference coordinates of the physical points ``x``.
Expand Down Expand Up @@ -160,3 +162,190 @@ def _(e: basix.finite_element.FiniteElement):
return CoordinateElement(_cpp.fem.CoordinateElement_float32(e._e))
except TypeError:
return CoordinateElement(_cpp.fem.CoordinateElement_float64(e._e))


class FiniteElement:
_cpp_object: typing.Union[_cpp.fem.FiniteElement_float32, _cpp.fem.FiniteElement_float64]

def __init__(
self,
cpp_object: typing.Union[_cpp.fem.FiniteElement_float32, _cpp.fem.FiniteElement_float64],
):
"""Creates a Python wrapper for the exported finite element class.

Note:
Do not use this constructor directly. Instead use :func:``finiteelement``.

Args:
The underlying cpp instance that this object will wrap.
"""
self._cpp_object = cpp_object

def __eq__(self, other):
return self._cpp_object == other._cpp_object

@property
def dtype(self) -> np.dtype:
"""Geometry type of the Mesh that the FunctionSpace is defined on."""
return self._cpp_object.dtype

@property
def basix_element(self) -> basix.finite_element.FiniteElement:
"""Return underlying Basix C++ element (if it exists).

Raises:
Runtime error if Basix element does not exist.
"""
return self._cpp_object.basix_element

@property
def num_sub_elements(self) -> int:
"""Number of sub elements (for a mixed or blocked element)."""
return self._cpp_object.num_sub_elements

@property
def value_shape(self) -> npt.NDArray[np.integer]:
"""Value shape of the finite element field.

The value shape describes the shape of the finite element field, e.g. ``{}`` for a scalar,
``{2}`` for a vector in 2D, ``{3, 3}`` for a rank-2 tensor in 3D, etc.
"""
return self._cpp_object.value_shape

@property
def interpolation_points(self) -> npt.NDArray[np.floating]:
"""Points on the reference cell at which an expression needs to be evaluated in order to
interpolate the expression in the finite element space.

Interpolation point coordinates on the reference cell, returning the coordinates data
(row-major) storage with shape ``(num_points, tdim)``.

Note:
For Lagrange elements the points will just be the nodal positions. For other elements
the points will typically be the quadrature points used to evaluate moment degrees of
freedom.
"""
return self._cpp_object.interpolation_points

@property
def interpolation_ident(self) -> bool:
"""Check if interpolation into the finite element space is an identity operation given the
evaluation on an expression at specific points, i.e. the degree-of-freedom are equal to
point evaluations. The function will return `true` for Lagrange elements."""
return self._cpp_object.interpolation_ident

@property
def space_dimension(self) -> int:
"""Dimension of the finite element function space (the number of degrees-of-freedom for the
element).

For 'blocked' elements, this function returns the dimension of the full element rather than
the dimension of the base element.
"""
return self._cpp_object.space_dimension

@property
def needs_dof_transformations(self) -> bool:
"""Check if DOF transformations are needed for this element.

DOF transformations will be needed for elements which might not be continuous when two
neighbouring cells disagree on the orientation of a shared sub-entity, and when this cannot
be corrected for by permuting the DOF numbering in the dofmap.

For example, Raviart-Thomas elements will need DOF transformations, as the neighbouring
cells may disagree on the orientation of a basis function, and this orientation cannot be
corrected for by permuting the DOF numbers on each cell.
"""
return self._cpp_object.needs_dof_transformations

@property
def signature(self) -> str:
"""String identifying the finite element."""
return self._cpp_object.signature

def T_apply(self, x: npt.NDArray[np.floating], cell_permutations: np.int32, dim: int) -> None:
"""Transform basis functions from the reference element ordering and orientation to the
globally consistent physical element ordering and orientation.
schnellerhase marked this conversation as resolved.
Show resolved Hide resolved

Args:
x: Data to transform (in place). The shape is ``(m, n)``, where `m` is the number of
dgerees-of-freedom and the storage is row-major.
cell_permutations: Permutation data for the cell.
dim: Number of columns in ``data``.

Note:
Exposed for testing. Function is not vectorised across multiple cells. Please see
`basix.numba_helpers` for performant versions.
"""
self._cpp_object.T_apply(x, cell_permutations, dim)

def Tt_apply(self, x: npt.NDArray[np.floating], cell_permutations: np.int32, dim: int) -> None:
"""Apply the transpose of the operator applied by T_apply().

Args:
x: Data to transform (in place). The shape is ``(m, n)``, where `m` is the number of
dgerees-of-freedom and the storage is row-major.
cell_permutations: Permutation data for the cell.
dim: Number of columns in `data`.

Note:
Exposed for testing. Function is not vectorised across multiple cells. Please see
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@schnellerhase This is not correct. These functions have loop access in the python wrapper.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The type hint should be npt.NDArray[np.int32].

const std::size_t data_per_cell
= x.size() / cell_permutations.size();
std::span<T> x_span(x.data(), x.size());
std::span<const std::uint32_t> perm_span(
cell_permutations.data(), cell_permutations.size());
for (std::size_t i = 0; i < cell_permutations.size(); i++)
{
self.Tt_apply(x_span.subspan(i * data_per_cell, data_per_cell),
perm_span[i], dim);
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the catches and fixes - my bad!

`basix.numba_helpers` for performant versions.
"""
self._cpp_object.Tt_apply(x, cell_permutations, dim)

def Tt_inv_apply(
self, x: npt.NDArray[np.floating], cell_permutations: np.int32, dim: int
) -> None:
"""Apply the inverse transpose of the operator applied by T_apply().

Args:
x: Data to transform (in place). The shape is ``(m, n)``, where ``m`` is the number of
dgerees-of-freedom and the storage is row-major.
cell_permutations: Permutation data for the cell.
dim: Number of columns in `data`.

Note:
Exposed for testing. Function is not vectorised across multiple cells. Please see
``basix.numba_helpers`` for performant versions.
"""
self._cpp_object.Tt_apply(x, cell_permutations, dim)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is also wrong, as this should apply the inverse transform.


def finiteelement(
cell_type: _cpp.mesh.CellType,
ufl_e: ufl.finiteelement,
FiniteElement_dtype: np.dtype,
) -> FiniteElement:
"""Create a DOLFINx element from a basix.ufl element.

Args:
cell_type: Element cell type, see ``mesh.CellType``
ufl_e: UFL element, holding quadrature rule and other properties of the selected element.
FiniteElement_dtype: Geometry type of the element.
"""
if np.issubdtype(FiniteElement_dtype, np.float32):
CppElement = _cpp.fem.FiniteElement_float32
elif np.issubdtype(FiniteElement_dtype, np.float64):
schnellerhase marked this conversation as resolved.
Show resolved Hide resolved
CppElement = _cpp.fem.FiniteElement_float64
else:
raise ValueError(f"Unsupported dtype: {FiniteElement_dtype}")

if ufl_e.is_mixed:
elements = [
finiteelement(cell_type, e, FiniteElement_dtype)._cpp_object for e in ufl_e.sub_elements
]
return FiniteElement(CppElement(elements))
elif ufl_e.is_quadrature:
return FiniteElement(
CppElement(
cell_type,
ufl_e.custom_quadrature()[0],
ufl_e.reference_value_shape,
ufl_e.is_symmetric,
)
)
else:
basix_e = ufl_e.basix_element._e
value_shape = ufl_e.reference_value_shape if ufl_e.block_size > 1 else None
return FiniteElement(CppElement(basix_e, value_shape, ufl_e.is_symmetric))
49 changes: 11 additions & 38 deletions python/dolfinx/fem/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from __future__ import annotations

import typing
from functools import singledispatch
from functools import cached_property, singledispatch

import numpy as np
import numpy.typing as npt
Expand All @@ -18,6 +18,7 @@
from dolfinx import cpp as _cpp
from dolfinx import default_scalar_type, jit, la
from dolfinx.fem import dofmap
from dolfinx.fem.element import FiniteElement, finiteelement
from dolfinx.geometry import PointOwnershipData

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -461,7 +462,7 @@ def _(e0: Expression):
# u0 is callable
assert callable(u0)
x = _cpp.fem.interpolation_coords(
self._V.element, self._V.mesh.geometry._cpp_object, cells0
self._V.element._cpp_object, self._V.mesh.geometry._cpp_object, cells0
)
self._cpp_object.interpolate(np.asarray(u0(x), dtype=self.dtype), cells0) # type: ignore

Expand Down Expand Up @@ -560,32 +561,6 @@ class ElementMetaData(typing.NamedTuple):
symmetry: typing.Optional[bool] = None


def _create_dolfinx_element(
cell_type: _cpp.mesh.CellType,
ufl_e: ufl.FiniteElementBase,
dtype: np.dtype,
) -> typing.Union[_cpp.fem.FiniteElement_float32, _cpp.fem.FiniteElement_float64]:
"""Create a DOLFINx element from a basix.ufl element."""
if np.issubdtype(dtype, np.float32):
CppElement = _cpp.fem.FiniteElement_float32
elif np.issubdtype(dtype, np.float64):
CppElement = _cpp.fem.FiniteElement_float64
else:
raise ValueError(f"Unsupported dtype: {dtype}")

if ufl_e.is_mixed:
elements = [_create_dolfinx_element(cell_type, e, dtype) for e in ufl_e.sub_elements]
return CppElement(elements)
elif ufl_e.is_quadrature:
return CppElement(
cell_type, ufl_e.custom_quadrature()[0], ufl_e.reference_value_shape, ufl_e.is_symmetric
)
else:
basix_e = ufl_e.basix_element._e
value_shape = ufl_e.reference_value_shape if ufl_e.block_size > 1 else None
return CppElement(basix_e, value_shape, ufl_e.is_symmetric)


def functionspace(
mesh: Mesh,
element: typing.Union[ufl.FiniteElementBase, ElementMetaData, tuple[str, int, tuple, bool]],
Expand Down Expand Up @@ -614,18 +589,18 @@ def functionspace(
raise ValueError("Non-matching UFL cell and mesh cell shapes.")

# Create DOLFINx objects
cpp_element = _create_dolfinx_element(mesh.topology.cell_type, ufl_e, dtype)
cpp_dofmap = _cpp.fem.create_dofmap(mesh.comm, mesh.topology._cpp_object, cpp_element)
element = finiteelement(mesh.topology.cell_type, ufl_e, dtype)
cpp_dofmap = _cpp.fem.create_dofmap(mesh.comm, mesh.topology._cpp_object, element._cpp_object)

assert np.issubdtype(
mesh.geometry.x.dtype, cpp_element.dtype
mesh.geometry.x.dtype, element.dtype
), "Mesh and element dtype are not compatible."

# Initialize the cpp.FunctionSpace
try:
cppV = _cpp.fem.FunctionSpace_float64(mesh._cpp_object, cpp_element, cpp_dofmap)
cppV = _cpp.fem.FunctionSpace_float64(mesh._cpp_object, element._cpp_object, cpp_dofmap)
except TypeError:
cppV = _cpp.fem.FunctionSpace_float32(mesh._cpp_object, cpp_element, cpp_dofmap)
cppV = _cpp.fem.FunctionSpace_float32(mesh._cpp_object, element._cpp_object, cpp_dofmap)

return FunctionSpace(mesh, ufl_e, cppV)

Expand Down Expand Up @@ -745,12 +720,10 @@ def ufl_function_space(self) -> ufl.FunctionSpace:
"""UFL function space."""
return self

@property
def element(
self,
) -> typing.Union[_cpp.fem.FiniteElement_float32, _cpp.fem.FiniteElement_float64]:
@cached_property
def element(self) -> FiniteElement:
"""Function space finite element."""
return self._cpp_object.element # type: ignore
return FiniteElement(self._cpp_object.element)
garth-wells marked this conversation as resolved.
Show resolved Hide resolved

@property
def dofmap(self) -> dofmap.DofMap:
Expand Down
Loading