Skip to content

Commit

Permalink
NFC: Introduce new ValuePtr/ValueRef typedefs to simplify the transit…
Browse files Browse the repository at this point in the history
…ion to Value being value-typed.

This is an initial step to refactoring the representation of OpResult as proposed in: https://groups.google.com/a/tensorflow.org/g/mlir/c/XXzzKhqqF_0/m/v6bKb08WCgAJ

This change will make it much simpler to incrementally transition all of the existing code to use value-typed semantics.

PiperOrigin-RevId: 286844725
  • Loading branch information
River707 authored and tensorflower-gardener committed Dec 23, 2019
1 parent 582c742 commit 70bf549
Show file tree
Hide file tree
Showing 201 changed files with 2,493 additions and 2,413 deletions.
2 changes: 1 addition & 1 deletion bindings/python/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ struct PythonValueHandle {
assert(value.hasType() && value.getType().isa<FunctionType>() &&
"can only call function-typed values");

std::vector<Value *> argValues;
std::vector<ValuePtr> argValues;
argValues.reserve(args.size());
for (auto arg : args)
argValues.push_back(arg.value.getValue());
Expand Down
8 changes: 4 additions & 4 deletions examples/toy/Ch2/include/toy/Ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def AddOp : Toy_Op<"add"> {

// Allow building an AddOp with from the two input operands.
let builders = [
OpBuilder<"Builder *b, OperationState &state, Value *lhs, Value *rhs">
OpBuilder<"Builder *b, OperationState &state, ValuePtr lhs, ValuePtr rhs">
];
}

Expand Down Expand Up @@ -129,7 +129,7 @@ def GenericCallOp : Toy_Op<"generic_call"> {
// Add custom build methods for the generic call operation.
let builders = [
OpBuilder<"Builder *builder, OperationState &state, "
"StringRef callee, ArrayRef<Value *> arguments">
"StringRef callee, ArrayRef<ValuePtr> arguments">
];
}

Expand All @@ -145,7 +145,7 @@ def MulOp : Toy_Op<"mul"> {

// Allow building a MulOp with from the two input operands.
let builders = [
OpBuilder<"Builder *b, OperationState &state, Value *lhs, Value *rhs">
OpBuilder<"Builder *b, OperationState &state, ValuePtr lhs, ValuePtr rhs">
];
}

Expand Down Expand Up @@ -219,7 +219,7 @@ def TransposeOp : Toy_Op<"transpose"> {

// Allow building a TransposeOp with from the input operand.
let builders = [
OpBuilder<"Builder *b, OperationState &state, Value *input">
OpBuilder<"Builder *b, OperationState &state, ValuePtr input">
];

// Invoke a static verify method to verify this transpose operation.
Expand Down
9 changes: 5 additions & 4 deletions examples/toy/Ch2/mlir/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ static mlir::LogicalResult verify(ConstantOp op) {
// AddOp

void AddOp::build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *lhs, mlir::Value *rhs) {
mlir::ValuePtr lhs, mlir::ValuePtr rhs) {
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
state.addOperands({lhs, rhs});
}
Expand All @@ -103,7 +103,8 @@ void AddOp::build(mlir::Builder *builder, mlir::OperationState &state,
// GenericCallOp

void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState &state,
StringRef callee, ArrayRef<mlir::Value *> arguments) {
StringRef callee,
ArrayRef<mlir::ValuePtr> arguments) {
// Generic call always returns an unranked Tensor initially.
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
state.addOperands(arguments);
Expand All @@ -114,7 +115,7 @@ void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState &state,
// MulOp

void MulOp::build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *lhs, mlir::Value *rhs) {
mlir::ValuePtr lhs, mlir::ValuePtr rhs) {
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
state.addOperands({lhs, rhs});
}
Expand Down Expand Up @@ -161,7 +162,7 @@ static mlir::LogicalResult verify(ReturnOp op) {
// TransposeOp

void TransposeOp::build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *value) {
mlir::ValuePtr value) {
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
state.addOperands(value);
}
Expand Down
41 changes: 21 additions & 20 deletions examples/toy/Ch2/mlir/MLIRGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ class MLIRGenImpl {
/// Entering a function creates a new scope, and the function arguments are
/// added to the mapping. When the processing of a function is terminated, the
/// scope is destroyed and the mappings created in this scope are dropped.
llvm::ScopedHashTable<StringRef, mlir::Value *> symbolTable;
llvm::ScopedHashTable<StringRef, mlir::ValuePtr> symbolTable;

/// Helper conversion for a Toy AST location to an MLIR location.
mlir::Location loc(Location loc) {
Expand All @@ -109,7 +109,7 @@ class MLIRGenImpl {

/// Declare a variable in the current scope, return success if the variable
/// wasn't declared yet.
mlir::LogicalResult declare(llvm::StringRef var, mlir::Value *value) {
mlir::LogicalResult declare(llvm::StringRef var, mlir::ValuePtr value) {
if (symbolTable.count(var))
return mlir::failure();
symbolTable.insert(var, value);
Expand All @@ -132,7 +132,8 @@ class MLIRGenImpl {
/// Emit a new function and add it to the MLIR module.
mlir::FuncOp mlirGen(FunctionAST &funcAST) {
// Create a scope in the symbol table to hold variable declarations.
ScopedHashTableScope<llvm::StringRef, mlir::Value *> var_scope(symbolTable);
ScopedHashTableScope<llvm::StringRef, mlir::ValuePtr> var_scope(
symbolTable);

// Create an MLIR function for the given prototype.
mlir::FuncOp function(mlirGen(*funcAST.getProto()));
Expand Down Expand Up @@ -183,7 +184,7 @@ class MLIRGenImpl {
}

/// Emit a binary operation
mlir::Value *mlirGen(BinaryExprAST &binop) {
mlir::ValuePtr mlirGen(BinaryExprAST &binop) {
// First emit the operations for each side of the operation before emitting
// the operation itself. For example if the expression is `a + foo(a)`
// 1) First it will visiting the LHS, which will return a reference to the
Expand All @@ -195,10 +196,10 @@ class MLIRGenImpl {
// and the result value is returned. If an error occurs we get a nullptr
// and propagate.
//
mlir::Value *lhs = mlirGen(*binop.getLHS());
mlir::ValuePtr lhs = mlirGen(*binop.getLHS());
if (!lhs)
return nullptr;
mlir::Value *rhs = mlirGen(*binop.getRHS());
mlir::ValuePtr rhs = mlirGen(*binop.getRHS());
if (!rhs)
return nullptr;
auto location = loc(binop.loc());
Expand All @@ -219,8 +220,8 @@ class MLIRGenImpl {
/// This is a reference to a variable in an expression. The variable is
/// expected to have been declared and so should have a value in the symbol
/// table, otherwise emit an error and return nullptr.
mlir::Value *mlirGen(VariableExprAST &expr) {
if (auto *variable = symbolTable.lookup(expr.getName()))
mlir::ValuePtr mlirGen(VariableExprAST &expr) {
if (auto variable = symbolTable.lookup(expr.getName()))
return variable;

emitError(loc(expr.loc()), "error: unknown variable '")
Expand All @@ -233,15 +234,15 @@ class MLIRGenImpl {
auto location = loc(ret.loc());

// 'return' takes an optional expression, handle that case here.
mlir::Value *expr = nullptr;
mlir::ValuePtr expr = nullptr;
if (ret.getExpr().hasValue()) {
if (!(expr = mlirGen(*ret.getExpr().getValue())))
return mlir::failure();
}

// Otherwise, this return operation has zero operands.
builder.create<ReturnOp>(location, expr ? makeArrayRef(expr)
: ArrayRef<mlir::Value *>());
: ArrayRef<mlir::ValuePtr>());
return mlir::success();
}

Expand All @@ -263,7 +264,7 @@ class MLIRGenImpl {
/// [[1.000000e+00, 2.000000e+00, 3.000000e+00],
/// [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> tensor<2x3xf64>
///
mlir::Value *mlirGen(LiteralExprAST &lit) {
mlir::ValuePtr mlirGen(LiteralExprAST &lit) {
auto type = getType(lit.getDims());

// The attribute is a vector with a floating point value per element
Expand Down Expand Up @@ -309,14 +310,14 @@ class MLIRGenImpl {

/// Emit a call expression. It emits specific operations for the `transpose`
/// builtin. Other identifiers are assumed to be user-defined functions.
mlir::Value *mlirGen(CallExprAST &call) {
mlir::ValuePtr mlirGen(CallExprAST &call) {
llvm::StringRef callee = call.getCallee();
auto location = loc(call.loc());

// Codegen the operands first.
SmallVector<mlir::Value *, 4> operands;
SmallVector<mlir::ValuePtr, 4> operands;
for (auto &expr : call.getArgs()) {
auto *arg = mlirGen(*expr);
auto arg = mlirGen(*expr);
if (!arg)
return nullptr;
operands.push_back(arg);
Expand All @@ -342,7 +343,7 @@ class MLIRGenImpl {
/// Emit a print expression. It emits specific operations for two builtins:
/// transpose(x) and print(x).
mlir::LogicalResult mlirGen(PrintExprAST &call) {
auto *arg = mlirGen(*call.getArg());
auto arg = mlirGen(*call.getArg());
if (!arg)
return mlir::failure();

Expand All @@ -351,12 +352,12 @@ class MLIRGenImpl {
}

/// Emit a constant for a single number (FIXME: semantic? broadcast?)
mlir::Value *mlirGen(NumberExprAST &num) {
mlir::ValuePtr mlirGen(NumberExprAST &num) {
return builder.create<ConstantOp>(loc(num.loc()), num.getValue());
}

/// Dispatch codegen for the right expression subclass using RTTI.
mlir::Value *mlirGen(ExprAST &expr) {
mlir::ValuePtr mlirGen(ExprAST &expr) {
switch (expr.getKind()) {
case toy::ExprAST::Expr_BinOp:
return mlirGen(cast<BinaryExprAST>(expr));
Expand All @@ -380,15 +381,15 @@ class MLIRGenImpl {
/// initializer and record the value in the symbol table before returning it.
/// Future expressions will be able to reference this variable through symbol
/// table lookup.
mlir::Value *mlirGen(VarDeclExprAST &vardecl) {
mlir::ValuePtr mlirGen(VarDeclExprAST &vardecl) {
auto init = vardecl.getInitVal();
if (!init) {
emitError(loc(vardecl.loc()),
"missing initializer in variable declaration");
return nullptr;
}

mlir::Value *value = mlirGen(*init);
mlir::ValuePtr value = mlirGen(*init);
if (!value)
return nullptr;

Expand All @@ -408,7 +409,7 @@ class MLIRGenImpl {

/// Codegen a list of expression, return failure if one of them hit an error.
mlir::LogicalResult mlirGen(ExprASTList &blockAST) {
ScopedHashTableScope<StringRef, mlir::Value *> var_scope(symbolTable);
ScopedHashTableScope<StringRef, mlir::ValuePtr> var_scope(symbolTable);
for (auto &expr : blockAST) {
// Specific handling for variable declarations, return statement, and
// print. These can only appear in block list and not in nested
Expand Down
8 changes: 4 additions & 4 deletions examples/toy/Ch3/include/toy/Ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def AddOp : Toy_Op<"add", [NoSideEffect]> {

// Allow building an AddOp with from the two input operands.
let builders = [
OpBuilder<"Builder *b, OperationState &state, Value *lhs, Value *rhs">
OpBuilder<"Builder *b, OperationState &state, ValuePtr lhs, ValuePtr rhs">
];
}

Expand Down Expand Up @@ -129,7 +129,7 @@ def GenericCallOp : Toy_Op<"generic_call"> {
// Add custom build methods for the generic call operation.
let builders = [
OpBuilder<"Builder *builder, OperationState &state, "
"StringRef callee, ArrayRef<Value *> arguments">
"StringRef callee, ArrayRef<ValuePtr> arguments">
];
}

Expand All @@ -145,7 +145,7 @@ def MulOp : Toy_Op<"mul", [NoSideEffect]> {

// Allow building a MulOp with from the two input operands.
let builders = [
OpBuilder<"Builder *b, OperationState &state, Value *lhs, Value *rhs">
OpBuilder<"Builder *b, OperationState &state, ValuePtr lhs, ValuePtr rhs">
];
}

Expand Down Expand Up @@ -225,7 +225,7 @@ def TransposeOp : Toy_Op<"transpose", [NoSideEffect]> {

// Allow building a TransposeOp with from the input operand.
let builders = [
OpBuilder<"Builder *b, OperationState &state, Value *input">
OpBuilder<"Builder *b, OperationState &state, ValuePtr input">
];

// Invoke a static verify method to verify this transpose operation.
Expand Down
9 changes: 5 additions & 4 deletions examples/toy/Ch3/mlir/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ static mlir::LogicalResult verify(ConstantOp op) {
// AddOp

void AddOp::build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *lhs, mlir::Value *rhs) {
mlir::ValuePtr lhs, mlir::ValuePtr rhs) {
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
state.addOperands({lhs, rhs});
}
Expand All @@ -103,7 +103,8 @@ void AddOp::build(mlir::Builder *builder, mlir::OperationState &state,
// GenericCallOp

void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState &state,
StringRef callee, ArrayRef<mlir::Value *> arguments) {
StringRef callee,
ArrayRef<mlir::ValuePtr> arguments) {
// Generic call always returns an unranked Tensor initially.
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
state.addOperands(arguments);
Expand All @@ -114,7 +115,7 @@ void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState &state,
// MulOp

void MulOp::build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *lhs, mlir::Value *rhs) {
mlir::ValuePtr lhs, mlir::ValuePtr rhs) {
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
state.addOperands({lhs, rhs});
}
Expand Down Expand Up @@ -161,7 +162,7 @@ static mlir::LogicalResult verify(ReturnOp op) {
// TransposeOp

void TransposeOp::build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *value) {
mlir::ValuePtr value) {
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
state.addOperands(value);
}
Expand Down
Loading

0 comments on commit 70bf549

Please sign in to comment.