Skip to content

Commit

Permalink
[tuner]: use attribute for workgroup and reduction array
Browse files Browse the repository at this point in the history
Signed-off-by: Bangtian Liu <[email protected]>
  • Loading branch information
bangtianliu committed Dec 7, 2024
1 parent 0f98992 commit 0edeb8e
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 34 deletions.
6 changes: 2 additions & 4 deletions compiler/bindings/c/iree/compiler/dialects/iree_gpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,8 @@ MLIR_CAPI_EXPORTED MlirAttribute
ireeGPULoweringConfigAttrGetAttributes(MlirAttribute attr);

struct ireeGPUTileSizes {
const int64_t *workgroupTileSizes;
size_t numWorkgroupTileSizes;
const int64_t *reductionTileSizes;
size_t numReductionTileSizes;
MlirAttribute workgroupAttr;
MlirAttribute reductionAttr;
};

MLIR_CAPI_EXPORTED ireeGPUTileSizes
Expand Down
50 changes: 34 additions & 16 deletions compiler/bindings/python/IREECompilerDialectsModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -353,22 +353,40 @@ PYBIND11_MODULE(_ireeCompilerDialects, m) {
"Gets an #iree_gpu.lowering_config from parameters.")
.def_property_readonly("attributes",
ireeGPULoweringConfigAttrGetAttributes)
.def_property_readonly("workgroup_tile_sizes",
[](MlirAttribute self) -> std::vector<int64_t> {
auto tilesizes =
ireeGPULoweringConfigAttrGetTileSizes(self);
return {tilesizes.workgroupTileSizes,
tilesizes.workgroupTileSizes +
tilesizes.numWorkgroupTileSizes};
})
.def_property_readonly("reduction_tile_sizes",
[](MlirAttribute self) -> std::vector<int64_t> {
auto tilesizes =
ireeGPULoweringConfigAttrGetTileSizes(self);
return {tilesizes.reductionTileSizes,
tilesizes.reductionTileSizes +
tilesizes.numReductionTileSizes};
})
.def_property_readonly(
"workgroup_tile_sizes",
[](MlirAttribute self) -> std::vector<int64_t> {
auto tilesizes = ireeGPULoweringConfigAttrGetTileSizes(self);
MlirAttribute workgroupAttr = tilesizes.workgroupAttr;
if (mlirAttributeIsNull(workgroupAttr)) {
return {};
}

size_t len = mlirArrayAttrGetNumElements(workgroupAttr);
std::vector<int64_t> workgroup(len);
for (size_t i = 0, e = len; i < e; ++i) {
MlirAttribute attr = mlirArrayAttrGetElement(workgroupAttr, i);
workgroup[i] = mlirIntegerAttrGetValueInt(attr);
}
return workgroup;
})
.def_property_readonly(
"reduction_tile_sizes",
[](MlirAttribute self) -> std::vector<int64_t> {
auto tilesizes = ireeGPULoweringConfigAttrGetTileSizes(self);
MlirAttribute reductionAttr = tilesizes.reductionAttr;
if (mlirAttributeIsNull(reductionAttr)) {
return {};
}

size_t len = mlirArrayAttrGetNumElements(reductionAttr);
std::vector<int64_t> reduction(len);
for (size_t i = 0, e = len; i < e; ++i) {
MlirAttribute attr = mlirArrayAttrGetElement(reductionAttr, i);
reduction[i] = mlirIntegerAttrGetValueInt(attr);
}
return reduction;
})
.def_property_readonly(
"subgroup_count_mn",
[](MlirAttribute self) -> py::tuple {
Expand Down
25 changes: 11 additions & 14 deletions compiler/src/iree/compiler/API/Internal/IREEGPUDialectCAPI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,23 +218,20 @@ MlirAttribute ireeGPULoweringConfigAttrGetAttributes(MlirAttribute attr) {
ireeGPUTileSizes ireeGPULoweringConfigAttrGetTileSizes(MlirAttribute attr) {
assert(ireeAttributeIsAGPULoweringConfigAttr(attr));
ireeGPUTileSizes tilesizes = {};
auto loweringConfigAttr =
mlir::DictionaryAttr dict =
llvm::cast<mlir::iree_compiler::IREE::GPU::LoweringConfigAttr>(
unwrap(attr));

llvm::SmallVector<int64_t> workgroups =
loweringConfigAttr.getWorkgroupTileSizes();
tilesizes.workgroupTileSizes = workgroups.data();
tilesizes.numWorkgroupTileSizes = workgroups.size();
unwrap(attr))
.getAttributes();

llvm::SmallVector<int64_t> reductions =
loweringConfigAttr.getStaticTilingLevelSizes(
llvm::to_underlying(
mlir::iree_compiler::IREE::GPU::TilingLevel::Reduction),
nullptr);
tilesizes.reductionTileSizes = reductions.data();
tilesizes.numReductionTileSizes = reductions.size();
constexpr mlir::StringLiteral workgroupName = "workgroup";
if (auto workgroupArray = dict.getAs<mlir::ArrayAttr>(workgroupName)) {
tilesizes.workgroupAttr = wrap(workgroupArray);
}

constexpr mlir::StringLiteral reductionName = "reduction";
if (auto reductionArray = dict.getAs<mlir::ArrayAttr>(reductionName)) {
tilesizes.reductionAttr = wrap(reductionArray);
}
return tilesizes;
}

Expand Down

0 comments on commit 0edeb8e

Please sign in to comment.