Skip to content

Commit

Permalink
Add many geometrical quantities to ffcx expression generator.
Browse files Browse the repository at this point in the history
Fix various errors in the definition of geometry tables
  • Loading branch information
jorgensd committed Oct 14, 2024
1 parent c6428f6 commit 9d227a4
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 6 deletions.
7 changes: 3 additions & 4 deletions ffcx/codegeneration/access.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,9 +263,8 @@ def reference_facet_edge_vectors(self, mt, tabledata, num_points):
"""Access a reference facet edge vector."""
cellname = ufl.domain.extract_unique_domain(mt.terminal).ufl_cell().cellname()
if cellname in ("tetrahedron", "hexahedron"):
table = L.Symbol(f"{cellname}_reference_edge_vectors", dtype=L.DataType.REAL)
facet = self.symbols.entity("facet", mt.restriction)
return table[facet][mt.component[0]][mt.component[1]]
table = L.Symbol(f"{cellname}_facet_reference_edge_vectors", dtype=L.DataType.REAL)
return table[mt.component[0]][mt.component[1]]
elif cellname in ("interval", "triangle", "quadrilateral"):
raise RuntimeError(
"The reference cell facet edge vectors doesn't make sense for interval "
Expand All @@ -280,7 +279,7 @@ def facet_orientation(self, mt, tabledata, num_points):
if cellname not in ("interval", "triangle", "tetrahedron"):
raise RuntimeError(f"Unhandled cell types {cellname}.")

table = L.Symbol(f"{cellname}_facet_orientations", dtype=L.DataType.INT)
table = L.Symbol(f"{cellname}_facet_orientation", dtype=L.DataType.INT)
facet = self.symbols.entity("facet", mt.restriction)
return table[facet]

Expand Down
8 changes: 7 additions & 1 deletion ffcx/codegeneration/expression_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,16 @@ def generate(self):

def generate_geometry_tables(self):
"""Generate static tables of geometry data."""
# Currently we only support circumradius
ufl_geometry = {
ufl.geometry.FacetEdgeVectors: "facet_edge_vertices",
ufl.geometry.CellFacetJacobian: "reference_facet_jacobian",
ufl.geometry.ReferenceCellVolume: "reference_cell_volume",
ufl.geometry.ReferenceFacetVolume: "reference_facet_volume",
ufl.geometry.ReferenceCellEdgeVectors: "reference_edge_vectors",
ufl.geometry.ReferenceFacetEdgeVectors: "facet_reference_edge_vectors",
ufl.geometry.FacetJacobianDeterminant: "reference_facet_jacobian",
ufl.geometry.ReferenceNormal: "reference_facet_normals",
ufl.geometry.FacetOrientation: "facet_orientation",
}

cells: dict[Any, set[Any]] = {t: set() for t in ufl_geometry.keys()} # type: ignore
Expand Down
2 changes: 1 addition & 1 deletion ffcx/codegeneration/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,4 +134,4 @@ def facet_orientation(tablename, cellname):
celltype = getattr(basix.CellType, cellname)
out = basix.cell.facet_orientations(celltype)
symbol = L.Symbol(f"{cellname}_{tablename}", dtype=L.DataType.REAL)
return L.ArrayDecl(symbol, values=out, const=True)
return L.ArrayDecl(symbol, values=np.asarray(out), const=True)

0 comments on commit 9d227a4

Please sign in to comment.