Skip to content

Commit

Permalink
Custom MemRef lowering to LLVM
Browse files Browse the repository at this point in the history
This PR introduces the infrastructure to provide a custom Std-to-LLVM
lowering for MemRef type:
* MemRefDescriptor class is turned into an abstract API that defines
  the methods needed to perform the custom lowering.
* Existing MemRefDesriptor class is renamed to DefaultMemRefDescriptor.
  It provides default struct lowering implementation (NFC).
* TestCustomMemRefLLVMLowering.cpp implements a custom MemRef
  descriptor lowering and LLVM type converter with the basic
  functionality to lower MemRef type to a plain pointer to element type.
* convert-memref-ops.mlir is split into convert-static-memref-ops.mlir
  and `convert-dynamic-memref-ops.mlir` so that
  TestCustomMemRefLLVMLowering.cpp can be tested on all the static
  MemRef tests available.

Related discussion: tensorflow#309
  • Loading branch information
dcaballe committed Dec 23, 2019
1 parent 70bf549 commit 6600dbf
Show file tree
Hide file tree
Showing 9 changed files with 792 additions and 308 deletions.
96 changes: 80 additions & 16 deletions include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class Type;

namespace mlir {

class MemRefDescriptor;
class UnrankedMemRefType;

namespace LLVM {
Expand All @@ -56,8 +57,9 @@ class LLVMTypeConverter : public TypeConverter {
/// Convert a function type. The arguments and results are converted one by
/// one and results are packed into a wrapped LLVM IR structure type. `result`
/// is populated with argument mapping.
LLVM::LLVMType convertFunctionSignature(FunctionType type, bool isVariadic,
SignatureConversion &result);
virtual LLVM::LLVMType convertFunctionSignature(FunctionType type,
bool isVariadic,
SignatureConversion &result);

/// Convert a non-empty list of types to be returned from a function into a
/// supported LLVM IR type. In particular, if more than one values is
Expand All @@ -71,6 +73,20 @@ class LLVMTypeConverter : public TypeConverter {
/// Returns the LLVM dialect.
LLVM::LLVMDialect *getDialect() { return llvmDialect; }

/// Create a DefaultMemRefDescriptor object for 'value'.
virtual std::unique_ptr<MemRefDescriptor>
createMemRefDescriptor(ValuePtr value);

/// Builds IR creating an uninitialized value of the descriptor type.
virtual std::unique_ptr<MemRefDescriptor>
buildMemRefDescriptor(OpBuilder &builder, Location loc, Type descriptorType);
/// Builds IR creating a MemRef descriptor that represents `type` and
/// populates it with static shape and stride information extracted from the
/// type.
virtual std::unique_ptr<MemRefDescriptor>
buildStaticMemRefDescriptor(OpBuilder &builder, Location loc, MemRefType type,
ValuePtr memory);

/// Promote the LLVM struct representation of all MemRef descriptors to stack
/// and use pointers to struct to avoid the complexity of the
/// platform-specific C/C++ ABI lowering related to struct argument passing.
Expand All @@ -90,6 +106,9 @@ class LLVMTypeConverter : public TypeConverter {
llvm::Module *module;
LLVM::LLVMDialect *llvmDialect;

// Extract an LLVM IR dialect type.
LLVM::LLVMType unwrap(Type type);

private:
Type convertStandardType(Type type);

Expand Down Expand Up @@ -129,9 +148,60 @@ class LLVMTypeConverter : public TypeConverter {
// Get the LLVM representation of the index type based on the bitwidth of the
// pointer as defined by the data layout of the module.
LLVM::LLVMType getIndexType();
};

// Extract an LLVM IR dialect type.
LLVM::LLVMType unwrap(Type type);
// Base helper class to lower MemRef type to a descriptor in LLVM. Provides an
// abstract API to produce LLVM dialect operations that manipulate the MemRef
// descriptor. Specific MemRef descriptor implementations should inherint from
// this class and implement the API.
struct MemRefDescriptor {

virtual Value *getValue() = 0;

/// Builds IR extracting the allocated pointer from the descriptor.
virtual Value *allocatedPtr(OpBuilder &builder, Location loc) = 0;
/// Builds IR inserting the allocated pointer into the descriptor.
virtual void setAllocatedPtr(OpBuilder &builder, Location loc,
Value *ptr) = 0;

/// Builds IR extracting the aligned pointer from the descriptor.
virtual Value *alignedPtr(OpBuilder &builder, Location loc) = 0;

/// Builds IR inserting the aligned pointer into the descriptor.
virtual void setAlignedPtr(OpBuilder &builder, Location loc, Value *ptr) = 0;

/// Builds IR extracting the offset from the descriptor.
virtual Value *offset(OpBuilder &builder, Location loc) = 0;

/// Builds IR inserting the offset into the descriptor.
virtual void setOffset(OpBuilder &builder, Location loc, Value *offset) = 0;

virtual void setConstantOffset(OpBuilder &builder, Location loc,
uint64_t offset) = 0;

/// Builds IR extracting the pos-th size from the descriptor.
virtual Value *size(OpBuilder &builder, Location loc, unsigned pos) = 0;

/// Builds IR inserting the pos-th size into the descriptor
virtual void setSize(OpBuilder &builder, Location loc, unsigned pos,
Value *size) = 0;
virtual void setConstantSize(OpBuilder &builder, Location loc, unsigned pos,
uint64_t size) = 0;

/// Builds IR extracting the pos-th size from the descriptor.
virtual Value *stride(OpBuilder &builder, Location loc, unsigned pos) = 0;

/// Builds IR inserting the pos-th stride into the descriptor
virtual void setStride(OpBuilder &builder, Location loc, unsigned pos,
Value *stride) = 0;
virtual void setConstantStride(OpBuilder &builder, Location loc, unsigned pos,
uint64_t stride) = 0;

/// Returns the (LLVM) type this descriptor points to.
virtual LLVM::LLVMType getElementType() = 0;

protected:
MemRefDescriptor() = default;
};

/// Helper class to produce LLVM dialect operations extracting or inserting
Expand All @@ -144,7 +214,7 @@ class StructBuilder {
static StructBuilder undef(OpBuilder &builder, Location loc,
Type descriptorType);

/*implicit*/ operator ValuePtr() { return value; }
ValuePtr getValue() { return value; }

protected:
// LLVM value
Expand All @@ -158,22 +228,16 @@ class StructBuilder {
/// Builds IR to set a value in the struct at position pos
void setPtr(OpBuilder &builder, Location loc, unsigned pos, ValuePtr ptr);
};

/// Helper class to produce LLVM dialect operations extracting or inserting
/// elements of a MemRef descriptor. Wraps a Value pointing to the descriptor.
/// The Value may be null, in which case none of the operations are valid.
class MemRefDescriptor : public StructBuilder {
class DefaultMemRefDescriptor : public StructBuilder, public MemRefDescriptor {
public:
/// Construct a helper for the given descriptor value.
explicit MemRefDescriptor(ValuePtr descriptor);
/// Builds IR creating an `undef` value of the descriptor type.
static MemRefDescriptor undef(OpBuilder &builder, Location loc,
Type descriptorType);
/// Builds IR creating a MemRef descriptor that represents `type` and
/// populates it with static shape and stride information extracted from the
/// type.
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc,
LLVMTypeConverter &typeConverter,
MemRefType type, ValuePtr memory);
explicit DefaultMemRefDescriptor(ValuePtr descriptor);

ValuePtr getValue() override { return StructBuilder::getValue(); };

/// Builds IR extracting the allocated pointer from the descriptor.
ValuePtr allocatedPtr(OpBuilder &builder, Location loc);
Expand Down
14 changes: 8 additions & 6 deletions lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -623,9 +623,10 @@ struct GPUFuncOpLowering : LLVMOpLowering {
// and canonicalize that away later.
ValuePtr attribution = gpuFuncOp.getWorkgroupAttributions()[en.index()];
auto type = attribution->getType().cast<MemRefType>();
auto descr = MemRefDescriptor::fromStaticShape(rewriter, loc, lowering,
type, memory);
signatureConversion.remapInput(numProperArguments + en.index(), descr);
auto descr =
lowering.buildStaticMemRefDescriptor(rewriter, loc, type, memory);
signatureConversion.remapInput(numProperArguments + en.index(),
descr->getValue());
}

// Rewrite private memory attributions to alloca'ed buffers.
Expand All @@ -649,10 +650,11 @@ struct GPUFuncOpLowering : LLVMOpLowering {
rewriter.getI64IntegerAttr(type.getNumElements()));
ValuePtr allocated = rewriter.create<LLVM::AllocaOp>(
gpuFuncOp.getLoc(), ptrType, numElements, /*alignment=*/0);
auto descr = MemRefDescriptor::fromStaticShape(rewriter, loc, lowering,
type, allocated);
auto descr = lowering.buildStaticMemRefDescriptor(rewriter, loc, type,
allocated);
signatureConversion.remapInput(
numProperArguments + numWorkgroupAttributions + en.index(), descr);
numProperArguments + numWorkgroupAttributions + en.index(),
descr->getValue());
}
}

Expand Down
40 changes: 21 additions & 19 deletions lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,32 +117,33 @@ namespace {
/// EDSC-compatible wrapper for MemRefDescriptor.
class BaseViewConversionHelper {
public:
BaseViewConversionHelper(Type type)
: d(MemRefDescriptor::undef(rewriter(), loc(), type)) {}
BaseViewConversionHelper(Type type, LLVMTypeConverter &typeConverter)
: d(typeConverter.buildMemRefDescriptor(rewriter(), loc(), type)) {}

BaseViewConversionHelper(ValuePtr v) : d(v) {}
BaseViewConversionHelper(ValuePtr v, LLVMTypeConverter &typeConverter)
: d(typeConverter.createMemRefDescriptor(v)) {}

/// Wrappers around MemRefDescriptor that use EDSC builder and location.
ValuePtr allocatedPtr() { return d.allocatedPtr(rewriter(), loc()); }
void setAllocatedPtr(ValuePtr v) { d.setAllocatedPtr(rewriter(), loc(), v); }
ValuePtr alignedPtr() { return d.alignedPtr(rewriter(), loc()); }
void setAlignedPtr(ValuePtr v) { d.setAlignedPtr(rewriter(), loc(), v); }
ValuePtr offset() { return d.offset(rewriter(), loc()); }
void setOffset(ValuePtr v) { d.setOffset(rewriter(), loc(), v); }
ValuePtr size(unsigned i) { return d.size(rewriter(), loc(), i); }
void setSize(unsigned i, ValuePtr v) { d.setSize(rewriter(), loc(), i, v); }
ValuePtr stride(unsigned i) { return d.stride(rewriter(), loc(), i); }
ValuePtr allocatedPtr() { return d->allocatedPtr(rewriter(), loc()); }
void setAllocatedPtr(ValuePtr v) { d->setAllocatedPtr(rewriter(), loc(), v); }
ValuePtr alignedPtr() { return d->alignedPtr(rewriter(), loc()); }
void setAlignedPtr(ValuePtr v) { d->setAlignedPtr(rewriter(), loc(), v); }
ValuePtr offset() { return d->offset(rewriter(), loc()); }
void setOffset(ValuePtr v) { d->setOffset(rewriter(), loc(), v); }
ValuePtr size(unsigned i) { return d->size(rewriter(), loc(), i); }
void setSize(unsigned i, ValuePtr v) { d->setSize(rewriter(), loc(), i, v); }
ValuePtr stride(unsigned i) { return d->stride(rewriter(), loc(), i); }
void setStride(unsigned i, ValuePtr v) {
d.setStride(rewriter(), loc(), i, v);
d->setStride(rewriter(), loc(), i, v);
}

operator ValuePtr() { return d; }
operator ValuePtr() { return d->getValue(); }

private:
OpBuilder &rewriter() { return ScopedContext::getBuilder(); }
Location loc() { return ScopedContext::getLocation(); }

MemRefDescriptor d;
std::unique_ptr<MemRefDescriptor> d;
};
} // namespace

Expand Down Expand Up @@ -190,14 +191,15 @@ class SliceOpConversion : public LLVMOpLowering {
ConversionPatternRewriter &rewriter) const override {
edsc::ScopedContext context(rewriter, op->getLoc());
SliceOpOperandAdaptor adaptor(operands);
BaseViewConversionHelper baseDesc(adaptor.view());
BaseViewConversionHelper baseDesc(adaptor.view(), lowering);

auto sliceOp = cast<SliceOp>(op);
auto memRefType = sliceOp.getBaseViewType();
auto int64Ty = lowering.convertType(rewriter.getIntegerType(64))
.cast<LLVM::LLVMType>();

BaseViewConversionHelper desc(lowering.convertType(sliceOp.getViewType()));
BaseViewConversionHelper desc(lowering.convertType(sliceOp.getViewType()),
lowering);

// TODO(ntv): extract sizes and emit asserts.
SmallVector<ValuePtr, 4> strides(memRefType.getRank());
Expand Down Expand Up @@ -282,15 +284,15 @@ class TransposeOpConversion : public LLVMOpLowering {
// Initialize the common boilerplate and alloca at the top of the FuncOp.
edsc::ScopedContext context(rewriter, op->getLoc());
TransposeOpOperandAdaptor adaptor(operands);
BaseViewConversionHelper baseDesc(adaptor.view());
BaseViewConversionHelper baseDesc(adaptor.view(), lowering);

auto transposeOp = cast<TransposeOp>(op);
// No permutation, early exit.
if (transposeOp.permutation().isIdentity())
return rewriter.replaceOp(op, {baseDesc}), matchSuccess();

BaseViewConversionHelper desc(
lowering.convertType(transposeOp.getViewType()));
lowering.convertType(transposeOp.getViewType()), lowering);

// Copy the base and aligned pointers from the old descriptor to the new
// one.
Expand Down
Loading

0 comments on commit 6600dbf

Please sign in to comment.