Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove IREE solver #4585

Open
wants to merge 7 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/workflows/run_periodic_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ on:
env:
FORCE_COLOR: 3
PYBAMM_IDAKLU_EXPR_CASADI: ON
PYBAMM_IDAKLU_EXPR_IREE: ON

concurrency:
# github.workflow: name of the workflow, so that we don't cancel other workflows
Expand Down
1 change: 0 additions & 1 deletion .github/workflows/test_on_push.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ on:
env:
FORCE_COLOR: 3
PYBAMM_IDAKLU_EXPR_CASADI: ON
PYBAMM_IDAKLU_EXPR_IREE: ON
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we please leave in the IREE solver code and just turn off this flag so it is not tested in the CI? Its going to be a lot more difficult to re-enable this in the future if the code is deleted.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have two main concerns with disabling it this way:

  • Turning it off in the tests would cause the code to be untested, and untested code tends to decay. It would not be guaranteed to work when re-enabled
  • IREE requires a specific version of Jax to operate. This blocks us from updating the Jax version cleanly. For instance we need to update the Jax version to support 3.13, and that would mean installing a different version whenever you wanted to use IREE. This is possible, but adds a lot of extra complexity that would not be fully tested if it is disabled by default

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like untested code either, but I don't see how deleting the code from the repository makes it easier to re-enable the IREE functionality later on? If anything it will make it more difficult. I agree that the IREE code will decay if untested, but its going to decay faster if it only exists in an old commit. I think we should remove it from any documentation and slap an "experimental, internal use only and not guarenteed to work" so that no users try to use it. But the IREE backend is still our best plan for a GPU-compatible backend for PyBaMM, so I'd like to keep it so that we can (a) test any new functionality that IREE implements for double precision, and (b) see what can be achieved numerically with only single precision. If we turn if off, it won't require any additional workload to keep it maintained going forward (other than by the devs actually working on the IREE solver).


concurrency:
# github.workflow: name of the workflow, so that we don't cancel other workflows
Expand Down
35 changes: 0 additions & 35 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -53,27 +53,6 @@ if(${PYBAMM_IDAKLU_EXPR_CASADI} STREQUAL "ON" )
)
endif()

# Check IREE build flag
if(NOT DEFINED PYBAMM_IDAKLU_EXPR_IREE)
set(PYBAMM_IDAKLU_EXPR_IREE OFF)
endif()
message("PYBAMM_IDAKLU_EXPR_IREE: ${PYBAMM_IDAKLU_EXPR_IREE}")

# IREE (MLIR expression evaluation) PyBaMM source files
set(IDAKLU_EXPR_IREE_SOURCE_FILES "")
if(${PYBAMM_IDAKLU_EXPR_IREE} STREQUAL "ON" )
add_compile_definitions(IREE_ENABLE)
# Source file list
set(IDAKLU_EXPR_IREE_SOURCE_FILES
src/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/iree_jit.cpp
src/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/iree_jit.hpp
src/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEFunctions.cpp
src/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEFunctions.hpp
src/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/ModuleParser.cpp
src/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/ModuleParser.hpp
)
endif()

# The complete (all dependencies) sources list should be mirrored in setup.py
pybind11_add_module(idaklu
# pybind11 interface
Expand Down Expand Up @@ -109,7 +88,6 @@ pybind11_add_module(idaklu
src/pybamm/solvers/c_solvers/idaklu/observe.cpp
# IDAKLU expressions - concrete implementations
${IDAKLU_EXPR_CASADI_SOURCE_FILES}
${IDAKLU_EXPR_IREE_SOURCE_FILES}
)

if (NOT DEFINED USE_PYTHON_CASADI)
Expand Down Expand Up @@ -179,16 +157,3 @@ else()
endif()
include_directories(${SuiteSparse_INCLUDE_DIRS})
target_link_libraries(idaklu PRIVATE ${SuiteSparse_LIBRARIES})

# IREE (MLIR compiler and runtime library) build settings
if(${PYBAMM_IDAKLU_EXPR_IREE} STREQUAL "ON" )
set(IREE_BUILD_COMPILER ON)
set(IREE_BUILD_TESTS OFF)
set(IREE_BUILD_SAMPLES OFF)
add_subdirectory(iree EXCLUDE_FROM_ALL)
set(IREE_COMPILER_ROOT "${CMAKE_CURRENT_SOURCE_DIR}/iree/compiler")
target_include_directories(idaklu SYSTEM PRIVATE "${IREE_COMPILER_ROOT}/bindings/c/iree/compiler")
target_compile_options(idaklu PRIVATE ${IREE_DEFAULT_COPTS})
target_link_libraries(idaklu PRIVATE iree_compiler_bindings_c_loader)
target_link_libraries(idaklu PRIVATE iree_runtime_runtime)
endif()
12 changes: 0 additions & 12 deletions docs/source/user_guide/installation/gnu-linux-mac.rst
Original file line number Diff line number Diff line change
Expand Up @@ -101,18 +101,6 @@ Users can install ``jax`` and ``jaxlib`` to use the Jax solver.

The ``pip install "pybamm[jax]"`` command automatically downloads and installs ``pybamm`` and the compatible versions of ``jax`` and ``jaxlib`` on your system.

.. _optional-iree-mlir-support:

Optional - IREE / MLIR support
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Users can install ``iree`` (for MLIR just-in-time compilation) to use for main expression evaluation in the IDAKLU solver. Requires ``jax``.

.. code:: bash

pip install "pybamm[iree,jax]"

The ``pip install "pybamm[iree,jax]"`` command automatically downloads and installs ``pybamm`` and the compatible versions of ``jax`` and ``iree`` onto your system.

Uninstall PyBaMM
----------------
Expand Down
12 changes: 0 additions & 12 deletions docs/source/user_guide/installation/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ Optional solvers
The following solvers are optionally available:

* `jax <https://jax.readthedocs.io/en/latest/notebooks/quickstart.html>`_ -based solver, see `Optional - JaxSolver <gnu-linux-mac.html#optional-jaxsolver>`_.
* `IREE <https://iree.dev/>`_ (`MLIR <https://mlir.llvm.org/>`_) support, see `Optional - IREE / MLIR Support <gnu-linux-mac.html#optional-iree-mlir-support>`_.

Dependencies
------------
Expand Down Expand Up @@ -207,17 +206,6 @@ Dependency Minimu
`jaxlib <https://pypi.org/project/jaxlib/>`__ 0.4.20 jax Support library for JAX
========================================================================= ================== ================== =======================

IREE dependencies
^^^^^^^^^^^^^^^^^^

Installable with ``pip install "pybamm[iree]"`` (requires ``jax`` dependencies to be installed).

========================================================================= ================== ================== =======================
Dependency Minimum Version pip extra Notes
========================================================================= ================== ================== =======================
`iree-compiler <https://iree.dev/>`__ 20240507.886 iree IREE compiler
========================================================================= ================== ================== =======================

Full installation guide
-----------------------

Expand Down
85 changes: 1 addition & 84 deletions noxfile.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import nox
import os
import sys
import warnings
from pathlib import Path


Expand All @@ -13,42 +12,13 @@
else:
nox.options.sessions = ["pre-commit", "unit"]


def set_iree_state():
"""
Check if IREE is enabled and set the environment variable accordingly.

Returns
-------
str
"ON" if IREE is enabled, "OFF" otherwise.

"""
state = "ON" if os.getenv("PYBAMM_IDAKLU_EXPR_IREE", "OFF") == "ON" else "OFF"
if state == "ON":
if sys.platform == "win32" or sys.platform == "darwin":
warnings.warn(
(
"IREE is not enabled on Windows and MacOS. "
"Setting PYBAMM_IDAKLU_EXPR_IREE=OFF."
),
stacklevel=2,
)
return "OFF"
return state


homedir = os.getenv("HOME")
PYBAMM_ENV = {
"LD_LIBRARY_PATH": f"{homedir}/.local/lib",
"PYTHONIOENCODING": "utf-8",
"MPLBACKEND": "Agg",
# Expression evaluators (...EXPR_CASADI cannot be fully disabled at this time)
"PYBAMM_IDAKLU_EXPR_CASADI": os.getenv("PYBAMM_IDAKLU_EXPR_CASADI", "ON"),
"PYBAMM_IDAKLU_EXPR_IREE": set_iree_state(),
"IREE_INDEX_URL": os.getenv(
"IREE_INDEX_URL", "https://iree.dev/pip-release-links.html"
),
"PYBAMM_DISABLE_TELEMETRY": "true",
}
VENV_DIR = Path("./venv").resolve()
Expand Down Expand Up @@ -91,29 +61,6 @@ def run_pybamm_requires(session):
"advice.detachedHead=false",
external=True,
)
if PYBAMM_ENV.get("PYBAMM_IDAKLU_EXPR_IREE") == "ON" and not os.path.exists(
"./iree"
):
session.run(
"git",
"clone",
"--depth=1",
"--recurse-submodules",
"--shallow-submodules",
"--branch=candidate-20240507.886",
"https://github.com/openxla/iree",
"iree/",
external=True,
)
with session.chdir("iree"):
session.run(
"git",
"submodule",
"update",
"--init",
"--recursive",
external=True,
)
else:
session.error("nox -s pybamm-requires is only available on Linux & macOS.")

Expand All @@ -128,15 +75,6 @@ def run_coverage(session):
if "CI" in os.environ:
session.install("pytest-github-actions-annotate-failures")
session.install("-e", ".[all,dev,jax]", silent=False)
if PYBAMM_ENV.get("PYBAMM_IDAKLU_EXPR_IREE") == "ON":
# See comments in 'dev' session
session.install(
"-e",
".[iree]",
"--find-links",
PYBAMM_ENV.get("IREE_INDEX_URL"),
silent=False,
)
session.run("pytest", "--cov=pybamm", "--cov-report=xml", "tests/unit")


Expand Down Expand Up @@ -177,15 +115,6 @@ def run_unit(session):
set_environment_variables(PYBAMM_ENV, session=session)
session.install("setuptools", silent=False)
session.install("-e", ".[all,dev,jax]", silent=False)
if PYBAMM_ENV.get("PYBAMM_IDAKLU_EXPR_IREE") == "ON":
# See comments in 'dev' session
session.install(
"-e",
".[iree]",
"--find-links",
PYBAMM_ENV.get("IREE_INDEX_URL"),
silent=False,
)
session.run("python", "-m", "pytest", "-m", "unit")


Expand Down Expand Up @@ -220,17 +149,6 @@ def set_dev(session):
session.install("virtualenv", "cmake")
session.run("virtualenv", os.fsdecode(VENV_DIR), silent=True)
python = os.fsdecode(VENV_DIR.joinpath("bin/python"))
components = ["all", "dev", "jax"]
args = []
if PYBAMM_ENV.get("PYBAMM_IDAKLU_EXPR_IREE") == "ON":
# Install IREE libraries for Jax-MLIR expression evaluation in the IDAKLU solver
# (optional). IREE is currently pre-release and relies on nightly jaxlib builds.
# When upgrading Jax/IREE ensure that the following are compatible with each other:
# - Jax and Jaxlib version [pyproject.toml]
# - IREE repository clone (use the matching nightly candidate) [noxfile.py]
# - IREE compiler matches Jaxlib (use the matching nightly build) [pyproject.toml]
components.append("iree")
args = ["--find-links", PYBAMM_ENV.get("IREE_INDEX_URL")]
# Temporary fix for Python 3.12 CI. TODO: remove after
# https://bitbucket.org/pybtex-devs/pybtex/issues/169/replace-pkg_resources-with
# is fixed
Expand All @@ -241,8 +159,7 @@ def set_dev(session):
"pip",
"install",
"-e",
".[{}]".format(",".join(components)),
*args,
".[all,dev,jax]",
external=True,
)

Expand Down
10 changes: 2 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -118,18 +118,12 @@ dev = [
"importlib-metadata; python_version < '3.10'",
]
# For the Jax solver.
# Note: These must be kept in sync with the versions defined in pybamm/util.py, and
# must remain compatible with IREE (see noxfile.py for IREE compatibility).
# Note: These must be kept in sync with the versions defined in pybamm/util.py
jax = [
"jax==0.4.27",
"jaxlib==0.4.27",
]
# For MLIR expression evaluation (IDAKLU Solver)
iree = [
# must be pip installed with --find-links=https://iree.dev/pip-release-links.html
"iree-compiler==20240507.886", # see IREE compatibility notes in noxfile.py
]
# Contains all optional dependencies, except for jax, iree, and dev dependencies
# Contains all optional dependencies, except for jax and dev dependencies
all = [
"scikit-fem>=8.1.0",
"pybamm[examples,plot,cite,bpx,tqdm]",
Expand Down
10 changes: 0 additions & 10 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,11 @@ def run(self):

build_type = os.getenv("PYBAMM_CPP_BUILD_TYPE", "RELEASE")
idaklu_expr_casadi = os.getenv("PYBAMM_IDAKLU_EXPR_CASADI", "ON")
idaklu_expr_iree = os.getenv("PYBAMM_IDAKLU_EXPR_IREE", "OFF")
cmake_args = [
f"-DCMAKE_BUILD_TYPE={build_type}",
f"-DPYTHON_EXECUTABLE={sys.executable}",
"-DUSE_PYTHON_CASADI={}".format("TRUE" if use_python_casadi else "FALSE"),
f"-DPYBAMM_IDAKLU_EXPR_CASADI={idaklu_expr_casadi}",
f"-DPYBAMM_IDAKLU_EXPR_IREE={idaklu_expr_iree}",
]
if self.suitesparse_root:
cmake_args.append(
Expand Down Expand Up @@ -302,14 +300,6 @@ def compile_KLU():
"src/pybamm/solvers/c_solvers/idaklu/Expressions/Base/ExpressionSparsity.hpp",
"src/pybamm/solvers/c_solvers/idaklu/Expressions/Casadi/CasadiFunctions.cpp",
"src/pybamm/solvers/c_solvers/idaklu/Expressions/Casadi/CasadiFunctions.hpp",
"src/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEBaseFunction.hpp",
"src/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEFunction.hpp",
"src/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEFunctions.cpp",
"src/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/IREEFunctions.hpp",
"src/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/iree_jit.cpp",
"src/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/iree_jit.hpp",
"src/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/ModuleParser.cpp",
"src/pybamm/solvers/c_solvers/idaklu/Expressions/IREE/ModuleParser.hpp",
"src/pybamm/solvers/c_solvers/idaklu/idaklu_solver.hpp",
"src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolver.cpp",
"src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolver.hpp",
Expand Down
5 changes: 1 addition & 4 deletions src/pybamm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
from pybamm.version import __version__

# Demote expressions to 32-bit floats/ints - option used for IDAKLU-MLIR compilation
demote_expressions_to_32bit = False

# Utility classes and methods
from .util import root_dir
from .util import Timer, TimerTime, FuzzyDict
Expand Down Expand Up @@ -173,7 +170,7 @@
from .solvers.jax_bdf_solver import jax_bdf_integrate

from .solvers.idaklu_jax import IDAKLUJax
from .solvers.idaklu_solver import IDAKLUSolver, has_idaklu, has_iree
from .solvers.idaklu_solver import IDAKLUSolver, has_idaklu

# Experiments
from .experiment.experiment import Experiment
Expand Down
47 changes: 1 addition & 46 deletions src/pybamm/expression_tree/operations/evaluate_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,54 +596,9 @@ def __init__(self, symbol: pybamm.Symbol):
static_argnums=self._static_argnums,
)

def _demote_constants(self):
"""Demote 64-bit constants (f64, i64) to 32-bit (f32, i32)"""
if not pybamm.demote_expressions_to_32bit:
return # pragma: no cover
self._constants = EvaluatorJax._demote_64_to_32(self._constants)

@classmethod
def _demote_64_to_32(cls, c):
"""Demote 64-bit operations (f64, i64) to 32-bit (f32, i32)"""

if not pybamm.demote_expressions_to_32bit:
return c
if isinstance(c, float):
c = jax.numpy.float32(c)
if isinstance(c, int):
c = jax.numpy.int32(c)
if isinstance(c, np.int64):
c = c.astype(jax.numpy.int32)
if isinstance(c, np.ndarray):
if c.dtype == np.float64:
c = c.astype(jax.numpy.float32)
if c.dtype == np.int64:
c = c.astype(jax.numpy.int32)
if isinstance(c, jax.numpy.ndarray):
if c.dtype == jax.numpy.float64:
c = c.astype(jax.numpy.float32)
if c.dtype == jax.numpy.int64:
c = c.astype(jax.numpy.int32)
if isinstance(
c, pybamm.expression_tree.operations.evaluate_python.JaxCooMatrix
):
if c.data.dtype == np.float64:
c.data = c.data.astype(jax.numpy.float32)
if c.row.dtype == np.int64:
c.row = c.row.astype(jax.numpy.int32)
if c.col.dtype == np.int64:
c.col = c.col.astype(jax.numpy.int32)
if isinstance(c, dict):
c = {key: EvaluatorJax._demote_64_to_32(value) for key, value in c.items()}
if isinstance(c, tuple):
c = tuple(EvaluatorJax._demote_64_to_32(value) for value in c)
if isinstance(c, list):
c = [EvaluatorJax._demote_64_to_32(value) for value in c]
return c

@property
def _constants(self):
return tuple(map(EvaluatorJax._demote_64_to_32, self.__constants))
return self.__constants

@_constants.setter
def _constants(self, value):
Expand Down
Loading