Skip to content

Commit

Permalink
Work towards typing system
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisrichardson committed Aug 7, 2023
1 parent a18d7a4 commit bf4ff3e
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 56 deletions.
9 changes: 8 additions & 1 deletion ffcx/codegeneration/C/c_implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,9 @@ def format_array_decl(self, arr) -> str:
vals = "{}"
else:
vals = build_initializer_lists(arr.values)
return f"{arr.typename} {symbol}{dims} = {vals};\n"
cstr = "static const" if arr.const else ""

return f"{cstr} {arr.typename} {symbol}{dims} = {vals};\n"

def format_array_access(self, arr) -> str:
name = self.c_format(arr.array)
Expand Down Expand Up @@ -199,6 +201,10 @@ def format_binary_op(self, oper) -> str:
# Return combined string
return f"{lhs} {oper.op} {rhs}"

def format_neg(self, val) -> str:
arg = self.c_format(val.arg)
return f"-{arg}"

def format_not(self, val) -> str:
arg = self.c_format(val.arg)
return f"{val.op}({arg})"
Expand Down Expand Up @@ -269,6 +275,7 @@ def format_math_function(self, c) -> str:
"Assign": format_assign,
"AssignAdd": format_assign,
"Product": format_nary_op,
"Neg": format_neg,
"Sum": format_nary_op,
"Add": format_binary_op,
"Sub": format_binary_op,
Expand Down
13 changes: 8 additions & 5 deletions ffcx/codegeneration/codegeneration.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,7 @@
import logging
import typing

from ffcx.codegeneration.C.dofmap import generator as dofmap_generator
from ffcx.codegeneration.C.expressions import generator as expression_generator
from ffcx.codegeneration.C.finite_element import generator as finite_element_generator
from ffcx.codegeneration.C.form import generator as form_generator
from ffcx.codegeneration.C.integrals import generator as integral_generator
from importlib import import_module

logger = logging.getLogger("ffcx")

Expand All @@ -43,6 +39,13 @@ def generate_code(ir, options) -> CodeBlocks:
logger.info("Compiler stage 3: Generating code")
logger.info(79 * "*")

lang = options.get("language", "C")
finite_element_generator = import_module(f"ffcx.codegeneration.{lang}.finite_element").generator
dofmap_generator = import_module(f"ffcx.codegeneration.{lang}.dofmap").generator
integral_generator = import_module(f"ffcx.codegeneration.{lang}.integrals").generator
form_generator = import_module(f"ffcx.codegeneration.{lang}.form").generator
expression_generator = import_module(f"ffcx.codegeneration.{lang}.expressions").generator

# Generate code for finite_elements
code_finite_elements = [
finite_element_generator(element_ir, options) for element_ir in ir.elements
Expand Down
2 changes: 1 addition & 1 deletion ffcx/codegeneration/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def coefficient(self, t, mt, tabledata, quadrature_rule, access):
# If a map is necessary from stride 1 to bs, the code must be added before the quadrature loop.
if dof_access_map:
pre_code += [
L.ArrayDecl(self.options["scalar_type"], dof_access.array, num_dofs)
L.ArrayDecl(dof_access.array, typename=self.options["scalar_type"], sizes=num_dofs)
]
pre_body = L.Assign(dof_access, dof_access_map)
pre_code += [L.ForRange(ic, 0, num_dofs, pre_body)]
Expand Down
19 changes: 4 additions & 15 deletions ffcx/codegeneration/expression_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ffcx.codegeneration.backend import FFCXBackend
from ffcx.ir.representation import ExpressionIR
from ffcx.naming import scalar_to_value_type
import ffcx.codegeneration.lnodes as L

logger = logging.getLogger("ffcx")

Expand All @@ -26,14 +27,13 @@ def __init__(self, ir: ExpressionIR, backend: FFCXBackend):

self.ir = ir
self.backend = backend
self.scope = {}
self.scope: Dict = {}
self._ufl_names: Set[Any] = set()
self.symbol_counters: DefaultDict[Any, int] = collections.defaultdict(int)
self.shared_symbols: Dict[Any, Any] = {}
self.quadrature_rule = list(self.ir.integrand.keys())[0]

def generate(self):
L = self.backend.language

parts = []
scalar_type = self.backend.access.options["scalar_type"]
Expand All @@ -59,7 +59,6 @@ def generate(self):

def generate_geometry_tables(self, float_type: str):
"""Generate static tables of geometry data."""
L = self.backend.language

# Currently we only support circumradius
ufl_geometry = {
Expand All @@ -84,18 +83,15 @@ def generate_geometry_tables(self, float_type: str):

def generate_element_tables(self, float_type: str):
"""Generate tables of FE basis evaluated at specified points."""
L = self.backend.language
parts = []

tables = self.ir.unique_tables

padlen = self.ir.options["padlen"]
table_names = sorted(tables)

for name in table_names:
table = tables[name]
decl = L.ArrayDecl(
f"static const {float_type}", name, table.shape, table, padlen=padlen)
name, typename=f"{float_type}", sizes=table.shape, values=table, const=True)
parts += [decl]

# Add leading comment if there are any tables
Expand All @@ -111,7 +107,6 @@ def generate_quadrature_loop(self):
In the context of expressions quadrature loop is not accumulated.
"""
L = self.backend.language

# Generate varying partition
body = self.generate_varying_partition()
Expand All @@ -137,7 +132,6 @@ def generate_quadrature_loop(self):

def generate_varying_partition(self):
"""Generate factors of blocks which are not cellwise constant."""
L = self.backend.language

# Get annotated graph of factorisation
F = self.ir.integrand[self.quadrature_rule]["factorization"]
Expand All @@ -150,7 +144,6 @@ def generate_varying_partition(self):

def generate_piecewise_partition(self):
"""Generate factors of blocks which are constant (i.e. do not depend on quadrature points)."""
L = self.backend.language

# Get annotated graph of factorisation
F = self.ir.integrand[self.quadrature_rule]["factorization"]
Expand Down Expand Up @@ -187,7 +180,6 @@ def generate_dofblock_partition(self):

def generate_block_parts(self, blockmap, blockdata):
"""Generate and return code parts for a given block."""
L = self.backend.language

# The parts to return
preparts = []
Expand Down Expand Up @@ -287,7 +279,6 @@ def get_arg_factors(self, blockdata, block_rank, indices):
Indices used to index element tables
"""
L = self.backend.language

arg_factors = []
for i in range(block_rank):
Expand All @@ -308,7 +299,6 @@ def get_arg_factors(self, blockdata, block_rank, indices):

def new_temp_symbol(self, basename):
"""Create a new code symbol named basename + running counter."""
L = self.backend.language
name = "%s%d" % (basename, self.symbol_counters[basename])
self.symbol_counters[basename] += 1
return L.Symbol(name)
Expand All @@ -321,7 +311,6 @@ def get_var(self, v):

def generate_partition(self, symbol, F, mode):
"""Generate computations of factors of blocks."""
L = self.backend.language

definitions = []
pre_definitions = dict()
Expand Down Expand Up @@ -410,6 +399,6 @@ def generate_partition(self, symbol, F, mode):
if intermediates:
if use_symbol_array:
scalar_type = self.backend.access.options["scalar_type"]
parts += [L.ArrayDecl(scalar_type, symbol, len(intermediates))]
parts += [L.ArrayDecl(symbol, typename=scalar_type, sizes=len(intermediates))]
parts += intermediates
return parts
18 changes: 9 additions & 9 deletions ffcx/codegeneration/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,19 +48,19 @@ def facet_edge_vertices(L, tablename, cellname):
raise ValueError("Only triangular and quadrilateral faces supported.")

out = np.array(edge_vertices, dtype=int)
return L.ArrayDecl("static const unsigned int", f"{cellname}_{tablename}", out.shape, out)
return L.ArrayDecl(f"{cellname}_{tablename}", values=out, const=True)


def reference_facet_jacobian(L, tablename, cellname, type: str):
celltype = getattr(basix.CellType, cellname)
out = basix.cell.facet_jacobians(celltype)
return L.ArrayDecl(f"static const {type}", f"{cellname}_{tablename}", out.shape, out)
return L.ArrayDecl(f"{cellname}_{tablename}", values=out, const=True)


def reference_cell_volume(L, tablename, cellname, type: str):
celltype = getattr(basix.CellType, cellname)
out = basix.cell.volume(celltype)
return L.VariableDecl(f"static const {type}", f"{cellname}_{tablename}", out)
return L.VariableDecl(f"{cellname}_{tablename}", values=out, const=True)


def reference_facet_volume(L, tablename, cellname, type: str):
Expand All @@ -69,7 +69,7 @@ def reference_facet_volume(L, tablename, cellname, type: str):
for i in volumes[1:]:
if not np.isclose(i, volumes[0]):
raise ValueError("Reference facet volume not supported for this cell type.")
return L.VariableDecl(f"static const {type}", f"{cellname}_{tablename}", volumes[0])
return L.VariableDecl(f"{cellname}_{tablename}", f"{type}", volumes[0], const=True)


def reference_edge_vectors(L, tablename, cellname, type: str):
Expand All @@ -78,7 +78,7 @@ def reference_edge_vectors(L, tablename, cellname, type: str):
geometry = basix.geometry(celltype)
edge_vectors = [geometry[j] - geometry[i] for i, j in topology[1]]
out = np.array(edge_vectors[cellname])
return L.ArrayDecl(f"static const {type}", f"{cellname}_{tablename}", out.shape, out)
return L.ArrayDecl(f"{cellname}_{tablename}", values=out, const=True)


def facet_reference_edge_vectors(L, tablename, cellname, type: str):
Expand All @@ -101,16 +101,16 @@ def facet_reference_edge_vectors(L, tablename, cellname, type: str):
raise ValueError("Only triangular and quadrilateral faces supported.")

out = np.array(edge_vectors)
return L.ArrayDecl(f"static const {type}", f"{cellname}_{tablename}", out.shape, out)
return L.ArrayDecl(f"{cellname}_{tablename}", values=out, const=True)


def reference_facet_normals(L, tablename, cellname, type: str):
celltype = getattr(basix.CellType, cellname)
out = basix.cell.facet_outward_normals(celltype)
return L.ArrayDecl(f"static const {type}", f"{cellname}_{tablename}", out.shape, out)
return L.ArrayDecl(f"{cellname}_{tablename}", values=out, const=True)


def facet_orientation(L, tablename, cellname, type: str):
celltype = getattr(basix.CellType, cellname)
out = basix.cell.facet_orientations(celltype)
return L.ArrayDecl(f"static const {type}", f"{cellname}_{tablename}", len(out), out)
out = np.array(basix.cell.facet_orientations(celltype))
return L.ArrayDecl(f"{cellname}_{tablename}", values=out, const=True)
17 changes: 7 additions & 10 deletions ffcx/codegeneration/integral_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,17 +162,12 @@ def generate_quadrature_tables(self, value_type: str) -> List[str]:

# Loop over quadrature rules
for quadrature_rule, integrand in self.ir.integrand.items():
num_points = quadrature_rule.weights.shape[0]

# Generate quadrature weights array
wsym = self.backend.symbols.weights_table(quadrature_rule)
parts += [
L.ArrayDecl(
f"static const {value_type}",
wsym,
num_points,
quadrature_rule.weights,
)
wsym, values=quadrature_rule.weights, const=True)
]

# Add leading comment if there are any tables
Expand Down Expand Up @@ -248,7 +243,7 @@ def declare_table(self, name, table, padlen, value_type: str):
these rotations.
"""
return [L.ArrayDecl(f"static const {value_type}", name, table.shape, table)]
return [L.ArrayDecl(name, values=table, const=True)]

def generate_quadrature_loop(self, quadrature_rule: QuadratureRule):
"""Generate quadrature loop with for this quadrature_rule."""
Expand Down Expand Up @@ -385,6 +380,8 @@ def generate_partition(self, symbol, F, mode, quadrature_rule):
j = len(intermediates)
if use_symbol_array:
vaccess = symbol[j]
print('assign ', vaccess.array.name, [v.value for v in vaccess.indices], vexpr)

intermediates.append(L.Assign(vaccess, vexpr))
else:
scalar_type = self.backend.access.options["scalar_type"]
Expand All @@ -405,9 +402,9 @@ def generate_partition(self, symbol, F, mode, quadrature_rule):
if use_symbol_array:
parts += [
L.ArrayDecl(
self.backend.access.options["scalar_type"],
symbol,
len(intermediates),
typename=self.backend.access.options["scalar_type"],
sizes=len(intermediates),
)
]
parts += intermediates
Expand Down Expand Up @@ -626,7 +623,7 @@ def generate_block_parts(
else:
t = self.new_temp_symbol("t")
scalar_type = self.backend.access.options["scalar_type"]
pre_loop.append(L.ArrayDecl(scalar_type, t, blockdims[0]))
pre_loop.append(L.ArrayDecl(t, typename=scalar_type, sizes=blockdims[0]))
keep[indices].append(
L.float_product([statement, t[B_indices[0]]])
)
Expand Down
37 changes: 22 additions & 15 deletions ffcx/codegeneration/lnodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,30 +773,37 @@ class ArrayDecl(Statement):

is_scoped = False

def __init__(self, typename, symbol, sizes=None, values=None):
assert isinstance(typename, str)
self.typename = typename
def __init__(self, symbol, typename=None, sizes=None, values=None, const=False):
self.symbol = as_symbol(symbol)

if isinstance(symbol, FlattenedArray):
if sizes is None:
assert symbol.dims is not None
sizes = symbol.dims
elif symbol.dims is not None:
assert symbol.dims == sizes
self.symbol = symbol.array
else:
self.symbol = as_symbol(symbol)
if typename is None:
assert values is not None
if values.dtype == np.float64:
typename = "double"
elif values.dtype == np.float32:
typename = "float"
else:
raise RuntimeError
self.typename = typename

if sizes is None:
assert values is not None
sizes = values.shape
if isinstance(sizes, int):
sizes = (sizes,)
self.sizes = tuple(sizes)

if values is None:
assert typename and sizes

# NB! No type checking, assuming nested lists of literal values. Not applying as_lexpr.
if isinstance(values, (list, tuple)):
self.values = np.asarray(values)
else:
self.values = values

self.const = const

def __eq__(self, other):
attributes = ("typename", "symbol", "sizes", "padlen", "values")
return isinstance(other, type(self)) and all(
Expand Down Expand Up @@ -875,9 +882,9 @@ def __init__(self):
ufl.constantvalue.FloatValue: lambda x: LiteralFloat(float(x)),
ufl.constantvalue.ComplexValue: lambda x: LiteralFloat(x.value()),
ufl.constantvalue.Zero: lambda x: LiteralFloat(0.0),
ufl.algebra.Product: lambda x, a, b: Product([a, b]),
ufl.algebra.Sum: lambda x, a, b: Add(a, b),
ufl.algebra.Division: lambda x, a, b: Div(a, b),
ufl.algebra.Product: lambda x, a, b: a * b,
ufl.algebra.Sum: lambda x, a, b: a + b,
ufl.algebra.Division: lambda x, a, b: a / b,
ufl.algebra.Abs: self.math_function,
ufl.algebra.Power: self.math_function,
ufl.algebra.Real: self.math_function,
Expand Down

0 comments on commit bf4ff3e

Please sign in to comment.