Skip to content
This repository has been archived by the owner on Apr 23, 2021. It is now read-only.

Custom lowering of MemRefs to LLVM #337

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 80 additions & 16 deletions include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class Type;

namespace mlir {

class MemRefDescriptor;
class UnrankedMemRefType;

namespace LLVM {
Expand All @@ -56,8 +57,9 @@ class LLVMTypeConverter : public TypeConverter {
/// Convert a function type. The arguments and results are converted one by
/// one and results are packed into a wrapped LLVM IR structure type. `result`
/// is populated with argument mapping.
LLVM::LLVMType convertFunctionSignature(FunctionType type, bool isVariadic,
SignatureConversion &result);
virtual LLVM::LLVMType convertFunctionSignature(FunctionType type,
bool isVariadic,
SignatureConversion &result);

/// Convert a non-empty list of types to be returned from a function into a
/// supported LLVM IR type. In particular, if more than one values is
Expand All @@ -71,6 +73,20 @@ class LLVMTypeConverter : public TypeConverter {
/// Returns the LLVM dialect.
LLVM::LLVMDialect *getDialect() { return llvmDialect; }

/// Create a DefaultMemRefDescriptor object for 'value'.
virtual std::unique_ptr<MemRefDescriptor>
createMemRefDescriptor(ValuePtr value);

/// Builds IR creating an uninitialized value of the descriptor type.
virtual std::unique_ptr<MemRefDescriptor>
buildMemRefDescriptor(OpBuilder &builder, Location loc, Type descriptorType);
/// Builds IR creating a MemRef descriptor that represents `type` and
/// populates it with static shape and stride information extracted from the
/// type.
virtual std::unique_ptr<MemRefDescriptor>
buildStaticMemRefDescriptor(OpBuilder &builder, Location loc, MemRefType type,
ValuePtr memory);

/// Promote the LLVM struct representation of all MemRef descriptors to stack
/// and use pointers to struct to avoid the complexity of the
/// platform-specific C/C++ ABI lowering related to struct argument passing.
Expand All @@ -90,6 +106,9 @@ class LLVMTypeConverter : public TypeConverter {
llvm::Module *module;
LLVM::LLVMDialect *llvmDialect;

// Extract an LLVM IR dialect type.
LLVM::LLVMType unwrap(Type type);

private:
Type convertStandardType(Type type);

Expand Down Expand Up @@ -129,9 +148,60 @@ class LLVMTypeConverter : public TypeConverter {
// Get the LLVM representation of the index type based on the bitwidth of the
// pointer as defined by the data layout of the module.
LLVM::LLVMType getIndexType();
};

// Extract an LLVM IR dialect type.
LLVM::LLVMType unwrap(Type type);
// Base helper class to lower MemRef type to a descriptor in LLVM. Provides an
// abstract API to produce LLVM dialect operations that manipulate the MemRef
// descriptor. Specific MemRef descriptor implementations should inherint from
// this class and implement the API.
struct MemRefDescriptor {

virtual Value *getValue() = 0;

/// Builds IR extracting the allocated pointer from the descriptor.
virtual Value *allocatedPtr(OpBuilder &builder, Location loc) = 0;
/// Builds IR inserting the allocated pointer into the descriptor.
virtual void setAllocatedPtr(OpBuilder &builder, Location loc,
Value *ptr) = 0;

/// Builds IR extracting the aligned pointer from the descriptor.
virtual Value *alignedPtr(OpBuilder &builder, Location loc) = 0;

/// Builds IR inserting the aligned pointer into the descriptor.
virtual void setAlignedPtr(OpBuilder &builder, Location loc, Value *ptr) = 0;

/// Builds IR extracting the offset from the descriptor.
virtual Value *offset(OpBuilder &builder, Location loc) = 0;

/// Builds IR inserting the offset into the descriptor.
virtual void setOffset(OpBuilder &builder, Location loc, Value *offset) = 0;

virtual void setConstantOffset(OpBuilder &builder, Location loc,
uint64_t offset) = 0;

/// Builds IR extracting the pos-th size from the descriptor.
virtual Value *size(OpBuilder &builder, Location loc, unsigned pos) = 0;

/// Builds IR inserting the pos-th size into the descriptor
virtual void setSize(OpBuilder &builder, Location loc, unsigned pos,
Value *size) = 0;
virtual void setConstantSize(OpBuilder &builder, Location loc, unsigned pos,
uint64_t size) = 0;

/// Builds IR extracting the pos-th size from the descriptor.
virtual Value *stride(OpBuilder &builder, Location loc, unsigned pos) = 0;

/// Builds IR inserting the pos-th stride into the descriptor
virtual void setStride(OpBuilder &builder, Location loc, unsigned pos,
Value *stride) = 0;
virtual void setConstantStride(OpBuilder &builder, Location loc, unsigned pos,
uint64_t stride) = 0;

/// Returns the (LLVM) type this descriptor points to.
virtual LLVM::LLVMType getElementType() = 0;

protected:
MemRefDescriptor() = default;
};

/// Helper class to produce LLVM dialect operations extracting or inserting
Expand All @@ -144,7 +214,7 @@ class StructBuilder {
static StructBuilder undef(OpBuilder &builder, Location loc,
Type descriptorType);

/*implicit*/ operator ValuePtr() { return value; }
ValuePtr getValue() { return value; }

protected:
// LLVM value
Expand All @@ -158,22 +228,16 @@ class StructBuilder {
/// Builds IR to set a value in the struct at position pos
void setPtr(OpBuilder &builder, Location loc, unsigned pos, ValuePtr ptr);
};

/// Helper class to produce LLVM dialect operations extracting or inserting
/// elements of a MemRef descriptor. Wraps a Value pointing to the descriptor.
/// The Value may be null, in which case none of the operations are valid.
class MemRefDescriptor : public StructBuilder {
class DefaultMemRefDescriptor : public StructBuilder, public MemRefDescriptor {
public:
/// Construct a helper for the given descriptor value.
explicit MemRefDescriptor(ValuePtr descriptor);
/// Builds IR creating an `undef` value of the descriptor type.
static MemRefDescriptor undef(OpBuilder &builder, Location loc,
Type descriptorType);
/// Builds IR creating a MemRef descriptor that represents `type` and
/// populates it with static shape and stride information extracted from the
/// type.
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc,
LLVMTypeConverter &typeConverter,
MemRefType type, ValuePtr memory);
explicit DefaultMemRefDescriptor(ValuePtr descriptor);

ValuePtr getValue() override { return StructBuilder::getValue(); };

/// Builds IR extracting the allocated pointer from the descriptor.
ValuePtr allocatedPtr(OpBuilder &builder, Location loc);
Expand Down
14 changes: 8 additions & 6 deletions lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -623,9 +623,10 @@ struct GPUFuncOpLowering : LLVMOpLowering {
// and canonicalize that away later.
ValuePtr attribution = gpuFuncOp.getWorkgroupAttributions()[en.index()];
auto type = attribution->getType().cast<MemRefType>();
auto descr = MemRefDescriptor::fromStaticShape(rewriter, loc, lowering,
type, memory);
signatureConversion.remapInput(numProperArguments + en.index(), descr);
auto descr =
lowering.buildStaticMemRefDescriptor(rewriter, loc, type, memory);
signatureConversion.remapInput(numProperArguments + en.index(),
descr->getValue());
}

// Rewrite private memory attributions to alloca'ed buffers.
Expand All @@ -649,10 +650,11 @@ struct GPUFuncOpLowering : LLVMOpLowering {
rewriter.getI64IntegerAttr(type.getNumElements()));
ValuePtr allocated = rewriter.create<LLVM::AllocaOp>(
gpuFuncOp.getLoc(), ptrType, numElements, /*alignment=*/0);
auto descr = MemRefDescriptor::fromStaticShape(rewriter, loc, lowering,
type, allocated);
auto descr = lowering.buildStaticMemRefDescriptor(rewriter, loc, type,
allocated);
signatureConversion.remapInput(
numProperArguments + numWorkgroupAttributions + en.index(), descr);
numProperArguments + numWorkgroupAttributions + en.index(),
descr->getValue());
}
}

Expand Down
40 changes: 21 additions & 19 deletions lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,32 +117,33 @@ namespace {
/// EDSC-compatible wrapper for MemRefDescriptor.
class BaseViewConversionHelper {
public:
BaseViewConversionHelper(Type type)
: d(MemRefDescriptor::undef(rewriter(), loc(), type)) {}
BaseViewConversionHelper(Type type, LLVMTypeConverter &typeConverter)
: d(typeConverter.buildMemRefDescriptor(rewriter(), loc(), type)) {}

BaseViewConversionHelper(ValuePtr v) : d(v) {}
BaseViewConversionHelper(ValuePtr v, LLVMTypeConverter &typeConverter)
: d(typeConverter.createMemRefDescriptor(v)) {}

/// Wrappers around MemRefDescriptor that use EDSC builder and location.
ValuePtr allocatedPtr() { return d.allocatedPtr(rewriter(), loc()); }
void setAllocatedPtr(ValuePtr v) { d.setAllocatedPtr(rewriter(), loc(), v); }
ValuePtr alignedPtr() { return d.alignedPtr(rewriter(), loc()); }
void setAlignedPtr(ValuePtr v) { d.setAlignedPtr(rewriter(), loc(), v); }
ValuePtr offset() { return d.offset(rewriter(), loc()); }
void setOffset(ValuePtr v) { d.setOffset(rewriter(), loc(), v); }
ValuePtr size(unsigned i) { return d.size(rewriter(), loc(), i); }
void setSize(unsigned i, ValuePtr v) { d.setSize(rewriter(), loc(), i, v); }
ValuePtr stride(unsigned i) { return d.stride(rewriter(), loc(), i); }
ValuePtr allocatedPtr() { return d->allocatedPtr(rewriter(), loc()); }
void setAllocatedPtr(ValuePtr v) { d->setAllocatedPtr(rewriter(), loc(), v); }
ValuePtr alignedPtr() { return d->alignedPtr(rewriter(), loc()); }
void setAlignedPtr(ValuePtr v) { d->setAlignedPtr(rewriter(), loc(), v); }
ValuePtr offset() { return d->offset(rewriter(), loc()); }
void setOffset(ValuePtr v) { d->setOffset(rewriter(), loc(), v); }
ValuePtr size(unsigned i) { return d->size(rewriter(), loc(), i); }
void setSize(unsigned i, ValuePtr v) { d->setSize(rewriter(), loc(), i, v); }
ValuePtr stride(unsigned i) { return d->stride(rewriter(), loc(), i); }
void setStride(unsigned i, ValuePtr v) {
d.setStride(rewriter(), loc(), i, v);
d->setStride(rewriter(), loc(), i, v);
}

operator ValuePtr() { return d; }
operator ValuePtr() { return d->getValue(); }

private:
OpBuilder &rewriter() { return ScopedContext::getBuilder(); }
Location loc() { return ScopedContext::getLocation(); }

MemRefDescriptor d;
std::unique_ptr<MemRefDescriptor> d;
};
} // namespace

Expand Down Expand Up @@ -190,14 +191,15 @@ class SliceOpConversion : public LLVMOpLowering {
ConversionPatternRewriter &rewriter) const override {
edsc::ScopedContext context(rewriter, op->getLoc());
SliceOpOperandAdaptor adaptor(operands);
BaseViewConversionHelper baseDesc(adaptor.view());
BaseViewConversionHelper baseDesc(adaptor.view(), lowering);

auto sliceOp = cast<SliceOp>(op);
auto memRefType = sliceOp.getBaseViewType();
auto int64Ty = lowering.convertType(rewriter.getIntegerType(64))
.cast<LLVM::LLVMType>();

BaseViewConversionHelper desc(lowering.convertType(sliceOp.getViewType()));
BaseViewConversionHelper desc(lowering.convertType(sliceOp.getViewType()),
lowering);

// TODO(ntv): extract sizes and emit asserts.
SmallVector<ValuePtr, 4> strides(memRefType.getRank());
Expand Down Expand Up @@ -282,15 +284,15 @@ class TransposeOpConversion : public LLVMOpLowering {
// Initialize the common boilerplate and alloca at the top of the FuncOp.
edsc::ScopedContext context(rewriter, op->getLoc());
TransposeOpOperandAdaptor adaptor(operands);
BaseViewConversionHelper baseDesc(adaptor.view());
BaseViewConversionHelper baseDesc(adaptor.view(), lowering);

auto transposeOp = cast<TransposeOp>(op);
// No permutation, early exit.
if (transposeOp.permutation().isIdentity())
return rewriter.replaceOp(op, {baseDesc}), matchSuccess();

BaseViewConversionHelper desc(
lowering.convertType(transposeOp.getViewType()));
lowering.convertType(transposeOp.getViewType()), lowering);

// Copy the base and aligned pointers from the old descriptor to the new
// one.
Expand Down
Loading