Skip to content

Commit

Permalink
working on simplification contrib package
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelbynum committed Nov 8, 2023
1 parent 8d1e68e commit 8015d7e
Show file tree
Hide file tree
Showing 3 changed files with 376 additions and 3 deletions.
65 changes: 62 additions & 3 deletions pyomo/contrib/simplification/build.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from pybind11.setup_helpers import Pybind11Extension, build_ext
from pyomo.common.fileutils import this_file_dir
from pyomo.common.fileutils import this_file_dir, find_library
import os
from distutils.dist import Distribution
import sys
import shutil
import glob
import tempfile
from pyomo.common.envvar import PYOMO_CONFIG_DIR


def build_ginac_interface(args=[]):
Expand All @@ -13,14 +17,69 @@ def build_ginac_interface(args=[]):
sources = list()
for fname in _sources:
sources.append(os.path.join(dname, fname))

ginac_lib = find_library('ginac')
if ginac_lib is None:
raise RuntimeError('could not find GiNaC library; please make sure it is in the LD_LIBRARY_PATH environment variable')
ginac_lib_dir = os.path.dirname(ginac_lib)
ginac_build_dir = os.path.dirname(ginac_lib_dir)
ginac_include_dir = os.path.join(ginac_build_dir, 'include')
if not os.path.exists(os.path.join(ginac_include_dir, 'ginac', 'ginac.h')):
raise RuntimeError('could not find GiNaC include directory')

cln_lib = find_library('cln')
if cln_lib is None:
raise RuntimeError('could not find CLN library; please make sure it is in the LD_LIBRARY_PATH environment variable')
cln_lib_dir = os.path.dirname(cln_lib)
cln_build_dir = os.path.dirname(cln_lib_dir)
cln_include_dir = os.path.join(cln_build_dir, 'include')
if not os.path.exists(os.path.join(cln_include_dir, 'cln', 'cln.h')):
raise RuntimeError('could not find CLN include directory')

extra_args = ['-std=c++11']
ext = Pybind11Extension('ginac_interface', sources, extra_compile_args=extra_args)
ext = Pybind11Extension(
'ginac_interface',
sources=sources,
language='c++',
include_dirs=[cln_include_dir, ginac_include_dir],
library_dirs=[cln_lib_dir, ginac_lib_dir],
libraries=['cln', 'ginac'],
extra_compile_args=extra_args,
)

class ginac_build_ext(build_ext):
def run(self):
basedir = os.path.abspath(os.path.curdir)
if self.inplace:
tmpdir = this_file_dir()
else:
tmpdir = os.path.abspath(tempfile.mkdtemp())
print("Building in '%s'" % tmpdir)
os.chdir(tmpdir)
try:
super(ginac_build_ext, self).run()
if not self.inplace:
library = glob.glob("build/*/ginac_interface.*")[0]
target = os.path.join(
PYOMO_CONFIG_DIR,
'lib',
'python%s.%s' % sys.version_info[:2],
'site-packages',
'.',
)
if not os.path.exists(target):
os.makedirs(target)
shutil.copy(library, target)
finally:
os.chdir(basedir)
if not self.inplace:
shutil.rmtree(tmpdir, onerror=handleReadonly)

package_config = {
'name': 'ginac_interface',
'packages': [],
'ext_modules': [ext],
'cmdclass': {"build_ext": build_ext},
'cmdclass': {"build_ext": ginac_build_ext},
}

dist = Distribution(package_config)
Expand Down
149 changes: 149 additions & 0 deletions pyomo/contrib/simplification/ginac_interface.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
#include "ginac_interface.hpp"

ex ginac_expr_from_pyomo_node(py::handle expr, std::unordered_map<long, ex> &leaf_map, PyomoExprTypes &expr_types) {
ex res;
ExprType tmp_type =
expr_types.expr_type_map[py::type::of(expr)].cast<ExprType>();

switch (tmp_type) {
case py_float: {
res = numeric(expr.cast<double>());
break;
}
case var: {
long expr_id = expr_types.id(expr).cast<long>();
if (leaf_map.count(expr_id) == 0) {
leaf_map[expr_id] = symbol("x" + std::to_string(expr_id));
}
res = leaf_map[expr_id];
break;
}
case param: {
long expr_id = expr_types.id(expr).cast<long>();
if (leaf_map.count(expr_id) == 0) {
leaf_map[expr_id] = symbol("p" + std::to_string(expr_id));
}
res = leaf_map[expr_id];
break;
}
case product: {
py::list pyomo_args = expr.attr("args");
res = ginac_expr_from_pyomo_node(pyomo_args[0], leaf_map, expr_types) * ginac_expr_from_pyomo_node(pyomo_args[1], leaf_map, expr_types);
break;
}
case sum: {
py::list pyomo_args = expr.attr("args");
for (py::handle arg : pyomo_args) {
res += ginac_expr_from_pyomo_node(arg, leaf_map, expr_types);
}
break;
}
case negation: {
py::list pyomo_args = expr.attr("args");
res = - ginac_expr_from_pyomo_node(pyomo_args[0], leaf_map, expr_types);
break;
}
case external_func: {
long expr_id = expr_types.id(expr).cast<long>();
if (leaf_map.count(expr_id) == 0) {
leaf_map[expr_id] = symbol("f" + std::to_string(expr_id));
}
res = leaf_map[expr_id];
break;
}
case ExprType::power: {
py::list pyomo_args = expr.attr("args");
res = pow(ginac_expr_from_pyomo_node(pyomo_args[0], leaf_map, expr_types), ginac_expr_from_pyomo_node(pyomo_args[1], leaf_map, expr_types));
break;
}
case division: {
py::list pyomo_args = expr.attr("args");
res = ginac_expr_from_pyomo_node(pyomo_args[0], leaf_map, expr_types) / ginac_expr_from_pyomo_node(pyomo_args[1], leaf_map, expr_types);
break;
}
case unary_func: {
std::string function_name = expr.attr("getname")().cast<std::string>();
py::list pyomo_args = expr.attr("args");
if (function_name == "exp")
res = exp(ginac_expr_from_pyomo_node(pyomo_args[0], leaf_map, expr_types));
else if (function_name == "log")
res = log(ginac_expr_from_pyomo_node(pyomo_args[0], leaf_map, expr_types));
else if (function_name == "sin")
res = sin(ginac_expr_from_pyomo_node(pyomo_args[0], leaf_map, expr_types));
else if (function_name == "cos")
res = cos(ginac_expr_from_pyomo_node(pyomo_args[0], leaf_map, expr_types));
else if (function_name == "tan")
res = tan(ginac_expr_from_pyomo_node(pyomo_args[0], leaf_map, expr_types));
else if (function_name == "asin")
res = asin(ginac_expr_from_pyomo_node(pyomo_args[0], leaf_map, expr_types));
else if (function_name == "acos")
res = acos(ginac_expr_from_pyomo_node(pyomo_args[0], leaf_map, expr_types));
else if (function_name == "atan")
res = atan(ginac_expr_from_pyomo_node(pyomo_args[0], leaf_map, expr_types));
else if (function_name == "sqrt")
res = sqrt(ginac_expr_from_pyomo_node(pyomo_args[0], leaf_map, expr_types));
else
throw py::value_error("Unrecognized expression type: " + function_name);
break;
}
case linear: {
py::list pyomo_args = expr.attr("args");
for (py::handle arg : pyomo_args) {
res += ginac_expr_from_pyomo_node(arg, leaf_map, expr_types);
}
break;
}
case named_expr: {
res = ginac_expr_from_pyomo_node(expr.attr("expr"), leaf_map, expr_types);
break;
}
case numeric_constant: {
res = numeric(expr.attr("value").cast<double>());
break;
}
case pyomo_unit: {
res = numeric(1.0);
break;
}
case unary_abs: {
py::list pyomo_args = expr.attr("args");
res = abs(ginac_expr_from_pyomo_node(pyomo_args[0], leaf_map, expr_types));
break;
}
default: {
throw py::value_error("Unrecognized expression type: " +
expr_types.builtins.attr("str")(py::type::of(expr))
.cast<std::string>());
break;
}
}
return res;
}

ex ginac_expr_from_pyomo_expr(py::handle expr, PyomoExprTypes &expr_types) {
std::unordered_map<long, ex> leaf_map;
ex res = ginac_expr_from_pyomo_node(expr, leaf_map, expr_types);
return res;
}


PYBIND11_MODULE(ginac_interface, m) {
m.def("ginac_expr_from_pyomo_expr", &ginac_expr_from_pyomo_expr);
py::class_<PyomoExprTypes>(m, "PyomoExprTypes").def(py::init<>());
py::class_<ex>(m, "ex");
py::enum_<ExprType>(m, "ExprType")
.value("py_float", ExprType::py_float)
.value("var", ExprType::var)
.value("param", ExprType::param)
.value("product", ExprType::product)
.value("sum", ExprType::sum)
.value("negation", ExprType::negation)
.value("external_func", ExprType::external_func)
.value("power", ExprType::power)
.value("division", ExprType::division)
.value("unary_func", ExprType::unary_func)
.value("linear", ExprType::linear)
.value("named_expr", ExprType::named_expr)
.value("numeric_constant", ExprType::numeric_constant)
.export_values();
}
Loading

0 comments on commit 8015d7e

Please sign in to comment.