Skip to content

Commit

Permalink
CodeGen_MLIR: Add initial MLIR CodeGen
Browse files Browse the repository at this point in the history
Also adds compile_to_mlir methods to Func and Pipeline.
  • Loading branch information
xerpi committed Dec 5, 2024
1 parent 53619a4 commit 915e476
Show file tree
Hide file tree
Showing 12 changed files with 715 additions and 1 deletion.
14 changes: 14 additions & 0 deletions cmake/FindHalide_LLVM.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -149,4 +149,18 @@ if (Halide_LLVM_FOUND)
endif ()
endif ()
endforeach ()

find_package(MLIR CONFIG HINTS "${LLVM_INSTALL_PREFIX}" "${LLVM_DIR}/../mlir" "${LLVM_DIR}/../lib/cmake/mlir")
if (MLIR_FOUND)
target_include_directories(Halide_LLVM::Core INTERFACE "$<BUILD_INTERFACE:${MLIR_INCLUDE_DIRS}>")
target_link_libraries(Halide_LLVM::Core INTERFACE
MLIRAnalysis
MLIRIR
MLIRArithDialect
MLIRFuncDialect
MLIRMemRefDialect
MLIRSCFDialect
MLIRVectorDialect
)
endif ()
endif ()
1 change: 1 addition & 0 deletions python_bindings/src/halide/halide_/PyEnums.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ void define_enums(py::module &m) {
.value("function_info_header", OutputFileType::function_info_header)
.value("hlpipe", OutputFileType::hlpipe)
.value("llvm_assembly", OutputFileType::llvm_assembly)
.value("mlir", OutputFileType::mlir)
.value("object", OutputFileType::object)
.value("python_extension", OutputFileType::python_extension)
.value("pytorch_wrapper", OutputFileType::pytorch_wrapper)
Expand Down
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ target_sources(
CodeGen_Internal.cpp
CodeGen_LLVM.cpp
CodeGen_Metal_Dev.cpp
CodeGen_MLIR.cpp
CodeGen_OpenCL_Dev.cpp
CodeGen_Posix.cpp
CodeGen_PowerPC.cpp
Expand Down
545 changes: 545 additions & 0 deletions src/CodeGen_MLIR.cpp

Large diffs are not rendered by default.

109 changes: 109 additions & 0 deletions src/CodeGen_MLIR.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
#ifndef HALIDE_CODEGEN_MLIR_H
#define HALIDE_CODEGEN_MLIR_H

/** \file
* Defines the code-generator for producing MLIR code
*/

#include "IRVisitor.h"
#include "Scope.h"

#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/ImplicitLocOpBuilder.h>

namespace Halide {

struct Target;

namespace Internal {

struct LoweredFunc;

class CodeGen_MLIR {
public:
CodeGen_MLIR(std::ostream &stream);

void compile(const Module &module);

protected:
void compile_func(mlir::ImplicitLocOpBuilder &builder, const LoweredFunc &func);

static mlir::Type mlir_type_of(mlir::ImplicitLocOpBuilder &builder, Halide::Type t);

class Visitor : public IRVisitor {
public:
Visitor(mlir::ImplicitLocOpBuilder &builder, const LoweredFunc &func);

protected:
mlir::Value codegen(const Expr &);
void codegen(const Stmt &);

void visit(const IntImm *) override;
void visit(const UIntImm *) override;
void visit(const FloatImm *) override;
void visit(const StringImm *) override;
void visit(const Cast *) override;
void visit(const Reinterpret *) override;
void visit(const Variable *) override;
void visit(const Add *) override;
void visit(const Sub *) override;
void visit(const Mul *) override;
void visit(const Div *) override;
void visit(const Mod *) override;
void visit(const Min *) override;
void visit(const Max *) override;
void visit(const EQ *) override;
void visit(const NE *) override;
void visit(const LT *) override;
void visit(const LE *) override;
void visit(const GT *) override;
void visit(const GE *) override;
void visit(const And *) override;
void visit(const Or *) override;
void visit(const Not *) override;
void visit(const Select *) override;
void visit(const Load *) override;
void visit(const Ramp *) override;
void visit(const Broadcast *) override;
void visit(const Call *) override;
void visit(const Let *) override;
void visit(const LetStmt *) override;
void visit(const AssertStmt *) override;
void visit(const ProducerConsumer *) override;
void visit(const For *) override;
void visit(const Store *) override;
void visit(const Provide *) override;
void visit(const Allocate *) override;
void visit(const Free *) override;
void visit(const Realize *) override;
void visit(const Block *) override;
void visit(const IfThenElse *) override;
void visit(const Evaluate *) override;
void visit(const Shuffle *) override;
void visit(const VectorReduce *) override;
void visit(const Prefetch *) override;
void visit(const Fork *) override;
void visit(const Acquire *) override;
void visit(const Atomic *) override;
void visit(const HoistedStorage *) override;

mlir::Type mlir_type_of(Halide::Type t) const;

void sym_push(const std::string &name, mlir::Value value);
void sym_pop(const std::string &name);
mlir::Value sym_get(const std::string &name, bool must_succeed = true) const;

private:
mlir::ImplicitLocOpBuilder &builder;
mlir::Value value;
Scope<mlir::Value> symbol_table;
};

mlir::MLIRContext mlir_context;
std::ostream &stream;
};

} // namespace Internal
} // namespace Halide

#endif
10 changes: 10 additions & 0 deletions src/Func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3577,6 +3577,16 @@ void Func::compile_to_llvm_assembly(const string &filename, const vector<Argumen
pipeline().compile_to_llvm_assembly(filename, args, "", target);
}

void Func::compile_to_mlir(const string &filename, const vector<Argument> &args, const string &fn_name,
const Target &target) {
pipeline().compile_to_mlir(filename, args, fn_name, target);
}

void Func::compile_to_mlir(const string &filename, const vector<Argument> &args,
const Target &target) {
pipeline().compile_to_mlir(filename, args, "", target);
}

void Func::compile_to_object(const string &filename, const vector<Argument> &args,
const string &fn_name, const Target &target) {
pipeline().compile_to_object(filename, args, fn_name, target);
Expand Down
8 changes: 8 additions & 0 deletions src/Func.h
Original file line number Diff line number Diff line change
Expand Up @@ -903,6 +903,14 @@ class Func {
const Target &target = get_target_from_environment());
// @}

/** Emit MLIR code. */
//@{
void compile_to_mlir(const std::string &filename, const std::vector<Argument> &, const std::string &fn_name,
const Target &target = get_target_from_environment());
void compile_to_mlir(const std::string &filename, const std::vector<Argument> &,
const Target &target = get_target_from_environment());
// @}

/** Statically compile this function to an object file, with the
* given filename (which should probably end in .o or .obj), type
* signature, and C function name (which defaults to the same name
Expand Down
2 changes: 1 addition & 1 deletion src/Generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -656,7 +656,7 @@ gengen
[assembly, bitcode, c_header, c_source, cpp_stub, featurization,
llvm_assembly, object, python_extension, pytorch_wrapper, registration,
schedule, static_library, stmt, stmt_html, conceptual_stmt,
conceptual_stmt_html, compiler_log, hlpipe, device_code].
conceptual_stmt_html, compiler_log, hlpipe, device_code, mlir].
If omitted, default value is [c_header, static_library, registration].
-p A comma-separated list of shared libraries that will be loaded before the
Expand Down
11 changes: 11 additions & 0 deletions src/Module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "CodeGen_C.h"
#include "CodeGen_Internal.h"
#include "CodeGen_MLIR.h"
#include "CodeGen_PyTorch.h"
#include "CompilerLogger.h"
#include "Debug.h"
Expand Down Expand Up @@ -42,6 +43,7 @@ std::map<OutputFileType, const OutputInfo> get_output_info(const Target &target)
{OutputFileType::function_info_header, {"function_info_header", ".function_info.h", IsSingle}},
{OutputFileType::hlpipe, {"hlpipe", ".hlpipe", IsSingle}},
{OutputFileType::llvm_assembly, {"llvm_assembly", ".ll", IsMulti}},
{OutputFileType::mlir, {"mlir", ".mlir", IsSingle}},
{OutputFileType::object, {"object", is_windows_coff ? ".obj" : ".o", IsMulti}},
{OutputFileType::python_extension, {"python_extension", ".py.cpp", IsSingle}},
{OutputFileType::pytorch_wrapper, {"pytorch_wrapper", ".pytorch.h", IsSingle}},
Expand Down Expand Up @@ -774,6 +776,15 @@ void Module::compile(const std::map<OutputFileType, std::string> &output_files)
file.close();
internal_assert(!file.fail());
}
if (contains(output_files, OutputFileType::mlir)) {
debug(1) << "Module.compile(): mlir " << output_files.at(OutputFileType::mlir) << "\n";

std::ofstream file(output_files.at(OutputFileType::mlir));
Internal::CodeGen_MLIR cg(file);
cg.compile(*this);
file.close();
internal_assert(!file.fail());
}
if (contains(output_files, OutputFileType::compiler_log)) {
debug(1) << "Module.compile(): compiler_log " << output_files.at(OutputFileType::compiler_log) << "\n";
std::ofstream file(output_files.at(OutputFileType::compiler_log));
Expand Down
1 change: 1 addition & 0 deletions src/Module.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ enum class OutputFileType {
function_info_header,
hlpipe,
llvm_assembly,
mlir,
object,
python_extension,
pytorch_wrapper,
Expand Down
8 changes: 8 additions & 0 deletions src/Pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,14 @@ void Pipeline::compile_to_llvm_assembly(const string &filename,
m.compile(single_output(filename, m, OutputFileType::llvm_assembly));
}

void Pipeline::compile_to_mlir(const string &filename,
const vector<Argument> &args,
const string &fn_name,
const Target &target) {
Module m = compile_to_module(args, fn_name, target);
m.compile(single_output(filename, m, OutputFileType::mlir));
}

void Pipeline::compile_to_object(const string &filename,
const vector<Argument> &args,
const string &fn_name,
Expand Down
6 changes: 6 additions & 0 deletions src/Pipeline.h
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,12 @@ class Pipeline {
const std::string &fn_name,
const Target &target = get_target_from_environment());

/** Emit MLIR code. */
void compile_to_mlir(const std::string &filename,
const std::vector<Argument> &args,
const std::string &fn_name,
const Target &target = get_target_from_environment());

/** Statically compile a pipeline with multiple output functions to an
* object file, with the given filename (which should probably end in
* .o or .obj), type signature, and C function name (which defaults to
Expand Down

0 comments on commit 915e476

Please sign in to comment.