Add function attributes for ExtFunction, CFGFunction and MLFunction.

PiperOrigin-RevId: 213540509
This commit is contained in:
Feng Liu 2018-09-18 16:36:26 -07:00 committed by jpienaar
parent 81a066e6e7
commit 7e004efae2
10 changed files with 213 additions and 63 deletions

View File

@ -27,7 +27,9 @@ namespace mlir {
// blocks, each of which includes instructions.
class CFGFunction : public Function {
public:
CFGFunction(Location *location, StringRef name, FunctionType *type);
CFGFunction(Location *location, StringRef name, FunctionType *type,
ArrayRef<NamedAttribute> attrs = {});
~CFGFunction();
//===--------------------------------------------------------------------===//

View File

@ -29,11 +29,18 @@
#include "llvm/ADT/ilist.h"
namespace mlir {
class Attribute;
class AttributeListStorage;
class FunctionType;
class Location;
class MLIRContext;
class Module;
/// NamedAttribute is used for function attribute lists, it holds an
/// identifier for the name and a value for the attribute. The attribute
/// pointer should always be non-null.
typedef std::pair<Identifier, Attribute *> NamedAttribute;
/// This is the base class for all of the MLIR function types.
class Function : public llvm::ilist_node_with_parent<Function, Module> {
public:
@ -50,6 +57,9 @@ public:
/// Return the type of this function.
FunctionType *getType() const { return type; }
/// Returns all of the attributes on this function.
ArrayRef<NamedAttribute> getAttrs() const;
MLIRContext *getContext() const;
Module *getModule() { return module; }
const Module *getModule() const { return module; }
@ -83,7 +93,8 @@ public:
void emitNote(const Twine &message) const;
protected:
Function(Kind kind, Location *location, StringRef name, FunctionType *type);
Function(Kind kind, Location *location, StringRef name, FunctionType *type,
ArrayRef<NamedAttribute> attrs = {});
~Function();
private:
@ -99,6 +110,9 @@ private:
/// The type of the function.
FunctionType *const type;
/// This holds general named attributes for the function.
AttributeListStorage *attrs;
void operator=(const Function &) = delete;
friend struct llvm::ilist_traits<Function>;
};
@ -107,7 +121,8 @@ private:
/// defined in some other module.
class ExtFunction : public Function {
public:
ExtFunction(Location *location, StringRef name, FunctionType *type);
ExtFunction(Location *location, StringRef name, FunctionType *type,
ArrayRef<NamedAttribute> attrs = {});
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Function *func) {

View File

@ -41,7 +41,8 @@ class MLFunction final
public:
/// Creates a new MLFunction with the specific type.
static MLFunction *create(Location *location, StringRef name,
FunctionType *type);
FunctionType *type,
ArrayRef<NamedAttribute> attrs = {});
/// Destroys this statement and its subclass data.
void destroy();
@ -94,7 +95,8 @@ public:
}
private:
MLFunction(Location *location, StringRef name, FunctionType *type);
MLFunction(Location *location, StringRef name, FunctionType *type,
ArrayRef<NamedAttribute> attrs = {});
// This stuff is used by the TrailingObjects template.
friend llvm::TrailingObjects<MLFunction, MLFuncArgument>;

View File

@ -285,6 +285,9 @@ protected:
ModuleState &state;
void printFunctionSignature(const Function *fn);
void printFunctionAttributes(const Function *fn);
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
ArrayRef<const char *> elidedAttrs = {});
void printFunctionResultType(const FunctionType *type);
void printAffineMapId(int affineMapId) const;
void printAffineMapReference(const AffineMap *affineMap);
@ -751,6 +754,14 @@ void ModulePrinter::printFunctionResultType(const FunctionType *type) {
}
}
void ModulePrinter::printFunctionAttributes(const Function *fn) {
auto attrs = fn->getAttrs();
if (attrs.empty())
return;
os << "\n attributes ";
printOptionalAttrDict(attrs);
}
void ModulePrinter::printFunctionSignature(const Function *fn) {
auto type = fn->getType();
@ -762,9 +773,49 @@ void ModulePrinter::printFunctionSignature(const Function *fn) {
printFunctionResultType(type);
}
void ModulePrinter::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
ArrayRef<const char *> elidedAttrs) {
// If there are no attributes, then there is nothing to be done.
if (attrs.empty())
return;
// Filter out any attributes that shouldn't be included.
SmallVector<NamedAttribute, 8> filteredAttrs;
for (auto attr : attrs) {
auto attrName = attr.first.strref();
// Never print attributes that start with a colon. These are internal
// attributes that represent location or other internal metadata.
if (attrName.startswith(":"))
return;
// If the caller has requested that this attribute be ignored, then drop it.
bool ignore = false;
for (const char *elide : elidedAttrs)
ignore |= attrName == StringRef(elide);
// Otherwise add it to our filteredAttrs list.
if (!ignore) {
filteredAttrs.push_back(attr);
}
}
// If there are no attributes left to print after filtering, then we're done.
if (filteredAttrs.empty())
return;
// Otherwise, print them all out in braces.
os << " {";
interleaveComma(filteredAttrs, [&](NamedAttribute attr) {
os << attr.first << ": ";
printAttribute(attr.second);
});
os << '}';
}
void ModulePrinter::print(const ExtFunction *fn) {
os << "extfunc ";
printFunctionSignature(fn);
printFunctionAttributes(fn);
os << '\n';
}
@ -797,11 +848,15 @@ public:
void printFunctionReference(const Function *func) {
return ModulePrinter::printFunctionReference(func);
}
void printFunctionAttributes(const Function *func) {
return ModulePrinter::printFunctionAttributes(func);
}
void printOperand(const SSAValue *value) { printValueID(value); }
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
ArrayRef<const char *> elidedAttrs = {}) override;
ArrayRef<const char *> elidedAttrs = {}) {
return ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs);
};
enum { nameSentinel = ~0U };
@ -944,44 +999,6 @@ private:
};
} // end anonymous namespace
void FunctionPrinter::printOptionalAttrDict(
ArrayRef<NamedAttribute> attrs, ArrayRef<const char *> elidedAttrs) {
// If there are no attributes, then there is nothing to be done.
if (attrs.empty())
return;
// Filter out any attributes that shouldn't be included.
SmallVector<NamedAttribute, 8> filteredAttrs;
for (auto attr : attrs) {
auto attrName = attr.first.strref();
// Never print attributes that start with a colon. These are internal
// attributes that represent location or other internal metadata.
if (attrName.startswith(":"))
continue;
// If the caller has requested that this attribute be ignored, then drop it.
bool ignore = false;
for (const char *elide : elidedAttrs)
ignore |= attrName == StringRef(elide);
// Otherwise add it to our filteredAttrs list.
if (!ignore)
filteredAttrs.push_back(attr);
}
// If there are no attributes left to print after filtering, then we're done.
if (filteredAttrs.empty())
return;
// Otherwise, print them all out in braces.
os << " {";
interleaveComma(filteredAttrs, [&](NamedAttribute attr) {
os << attr.first << ": ";
printAttribute(attr.second);
});
os << '}';
}
void FunctionPrinter::printOperation(const Operation *op) {
if (op->getNumResults()) {
printValueID(op->getResult(0), /*printResultNo=*/false);
@ -1091,6 +1108,7 @@ void CFGFunctionPrinter::numberValuesInBlock(const BasicBlock *block) {
void CFGFunctionPrinter::print() {
os << "cfgfunc ";
printFunctionSignature(getFunction());
printFunctionAttributes(getFunction());
os << " {\n";
for (auto &block : *function)
@ -1301,6 +1319,7 @@ void MLFunctionPrinter::numberValues() {
void MLFunctionPrinter::print() {
os << "mlfunc ";
printFunctionSignature();
printFunctionAttributes(getFunction());
os << " {\n";
print(function);
os << "}\n\n";

View File

@ -15,6 +15,7 @@
// limitations under the License.
// =============================================================================
#include "AttributeListStorage.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/CFGFunction.h"
#include "mlir/IR/MLFunction.h"
@ -27,15 +28,24 @@
using namespace mlir;
Function::Function(Kind kind, Location *location, StringRef name,
FunctionType *type)
FunctionType *type, ArrayRef<NamedAttribute> attrs)
: nameAndKind(Identifier::get(name, type->getContext()), kind),
location(location), type(type) {}
location(location), type(type) {
this->attrs = AttributeListStorage::get(attrs, getContext());
}
Function::~Function() {
// Clean up function attributes referring to this function.
FunctionAttr::dropFunctionReference(this);
}
ArrayRef<NamedAttribute> Function::getAttrs() const {
if (attrs)
return attrs->getElements();
else
return {};
}
MLIRContext *Function::getContext() const { return getType()->getContext(); }
/// Delete this object.
@ -149,15 +159,17 @@ void Function::emitError(const Twine &message) const {
// ExtFunction implementation.
//===----------------------------------------------------------------------===//
ExtFunction::ExtFunction(Location *location, StringRef name, FunctionType *type)
: Function(Kind::ExtFunc, location, name, type) {}
ExtFunction::ExtFunction(Location *location, StringRef name, FunctionType *type,
ArrayRef<NamedAttribute> attrs)
: Function(Kind::ExtFunc, location, name, type, attrs) {}
//===----------------------------------------------------------------------===//
// CFGFunction implementation.
//===----------------------------------------------------------------------===//
CFGFunction::CFGFunction(Location *location, StringRef name, FunctionType *type)
: Function(Kind::CFGFunc, location, name, type) {}
CFGFunction::CFGFunction(Location *location, StringRef name, FunctionType *type,
ArrayRef<NamedAttribute> attrs)
: Function(Kind::CFGFunc, location, name, type, attrs) {}
CFGFunction::~CFGFunction() {
// Instructions may have cyclic references, which need to be dropped before we
@ -176,13 +188,14 @@ CFGFunction::~CFGFunction() {
/// Create a new MLFunction with the specific fields.
MLFunction *MLFunction::create(Location *location, StringRef name,
FunctionType *type) {
FunctionType *type,
ArrayRef<NamedAttribute> attrs) {
const auto &argTypes = type->getInputs();
auto byteSize = totalSizeToAlloc<MLFuncArgument>(argTypes.size());
void *rawMem = malloc(byteSize);
// Initialize the MLFunction part of the function object.
auto function = ::new (rawMem) MLFunction(location, name, type);
auto function = ::new (rawMem) MLFunction(location, name, type, attrs);
// Initialize the arguments.
auto arguments = function->getArgumentsInternal();
@ -191,8 +204,9 @@ MLFunction *MLFunction::create(Location *location, StringRef name,
return function;
}
MLFunction::MLFunction(Location *location, StringRef name, FunctionType *type)
: Function(Kind::MLFunc, location, name, type),
MLFunction::MLFunction(Location *location, StringRef name, FunctionType *type,
ArrayRef<NamedAttribute> attrs)
: Function(Kind::MLFunc, location, name, type, attrs),
StmtBlock(StmtBlockKind::MLFunc) {}
MLFunction::~MLFunction() {

View File

@ -620,7 +620,8 @@ Function *Parser::resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
if (!function) {
auto &entry = state.functionForwardRefs[name];
if (!entry)
entry = new ExtFunction(getEncodedSourceLocation(nameLoc), name, type);
entry = new ExtFunction(getEncodedSourceLocation(nameLoc), name, type,
/*attrs=*/{});
function = entry;
}
@ -2653,6 +2654,7 @@ private:
SmallVectorImpl<StringRef> &argNames);
ParseResult parseFunctionSignature(StringRef &name, FunctionType *&type,
SmallVectorImpl<StringRef> *argNames);
ParseResult parseFunctionAttribute(SmallVectorImpl<NamedAttribute> &attrs);
ParseResult parseExtFunc();
ParseResult parseCFGFunc();
ParseResult parseMLFunc();
@ -2788,9 +2790,24 @@ ModuleParser::parseFunctionSignature(StringRef &name, FunctionType *&type,
return ParseSuccess;
}
/// Parse function attributes, starting with keyword "attributes".
///
/// function-attribute ::= (`attributes` attribute-dict)?
///
ParseResult
ModuleParser::parseFunctionAttribute(SmallVectorImpl<NamedAttribute> &attrs) {
if (consumeIf(Token::kw_attributes)) {
if (parseAttributeDict(attrs)) {
return ParseFailure;
}
}
return ParseSuccess;
}
/// External function declarations.
///
/// ext-func ::= `extfunc` function-signature
/// (`attributes` attribute-dict)?
///
ParseResult ModuleParser::parseExtFunc() {
consumeToken(Token::kw_extfunc);
@ -2801,8 +2818,14 @@ ParseResult ModuleParser::parseExtFunc() {
if (parseFunctionSignature(name, type, /*arguments*/ nullptr))
return ParseFailure;
SmallVector<NamedAttribute, 8> attrs;
if (parseFunctionAttribute(attrs)) {
return ParseFailure;
}
// Okay, the external function definition was parsed correctly.
auto *function = new ExtFunction(getEncodedSourceLocation(loc), name, type);
auto *function =
new ExtFunction(getEncodedSourceLocation(loc), name, type, attrs);
getModule()->getFunctions().push_back(function);
// Verify no name collision / redefinition.
@ -2815,7 +2838,8 @@ ParseResult ModuleParser::parseExtFunc() {
/// CFG function declarations.
///
/// cfg-func ::= `cfgfunc` function-signature `{` basic-block+ `}`
/// cfg-func ::= `cfgfunc` function-signature
/// (`attributes` attribute-dict)? `{` basic-block+ `}`
///
ParseResult ModuleParser::parseCFGFunc() {
consumeToken(Token::kw_cfgfunc);
@ -2826,8 +2850,14 @@ ParseResult ModuleParser::parseCFGFunc() {
if (parseFunctionSignature(name, type, /*arguments*/ nullptr))
return ParseFailure;
SmallVector<NamedAttribute, 8> attrs;
if (parseFunctionAttribute(attrs)) {
return ParseFailure;
}
// Okay, the CFG function signature was parsed correctly, create the function.
auto *function = new CFGFunction(getEncodedSourceLocation(loc), name, type);
auto *function =
new CFGFunction(getEncodedSourceLocation(loc), name, type, attrs);
getModule()->getFunctions().push_back(function);
// Verify no name collision / redefinition.
@ -2840,7 +2870,8 @@ ParseResult ModuleParser::parseCFGFunc() {
/// ML function declarations.
///
/// ml-func ::= `mlfunc` ml-func-signature `{` ml-stmt* ml-return-stmt `}`
/// ml-func ::= `mlfunc` ml-func-signature
/// (`attributes` attribute-dict)? `{` ml-stmt* ml-return-stmt `}`
///
ParseResult ModuleParser::parseMLFunc() {
consumeToken(Token::kw_mlfunc);
@ -2853,9 +2884,14 @@ ParseResult ModuleParser::parseMLFunc() {
if (parseFunctionSignature(name, type, &argNames))
return ParseFailure;
SmallVector<NamedAttribute, 8> attrs;
if (parseFunctionAttribute(attrs)) {
return ParseFailure;
}
// Okay, the ML function signature was parsed correctly, create the function.
auto *function =
MLFunction::create(getEncodedSourceLocation(loc), name, type);
MLFunction::create(getEncodedSourceLocation(loc), name, type, attrs);
getModule()->getFunctions().push_back(function);
// Verify no name collision / redefinition.

View File

@ -88,6 +88,7 @@ TOK_OPERATOR(star, "*")
// Keywords. These turn "foo" into Token::kw_foo enums.
TOK_KEYWORD(affineint)
TOK_KEYWORD(attributes)
TOK_KEYWORD(bf16)
TOK_KEYWORD(br)
TOK_KEYWORD(ceildiv)

View File

@ -110,8 +110,9 @@ void ModuleConverter::convertMLFunctions() {
// Creates CFG function equivalent to the given ML function.
CFGFunction *ModuleConverter::convert(const MLFunction *mlFunc) {
// TODO: ensure that CFG function name is unique.
auto *cfgFunc = new CFGFunction(
mlFunc->getLoc(), mlFunc->getName().str() + "_cfg", mlFunc->getType());
auto *cfgFunc =
new CFGFunction(mlFunc->getLoc(), mlFunc->getName().str() + "_cfg",
mlFunc->getType(), mlFunc->getAttrs());
module->getFunctions().push_back(cfgFunc);
// Generates the body of the CFG function.

View File

@ -41,6 +41,10 @@ extfunc @memrefs(memref<2x4xi8, #map0, 1, #map1>) // expected-error {{affine map
// -----
extfunc @illegalattrs() -> () attributes { key } // expected-error {{expected ':' in attribute list}}
// -----
cfgfunc @foo()
cfgfunc @bar() // expected-error {{expected '{' in CFG function}}
@ -98,6 +102,14 @@ bb42 (%0): // expected-error {{expected ':' and type for SSA operand}}
// -----
cfgfunc @illegalattrs() -> ()
attributes { key } { // expected-error {{expected ':' in attribute list}}
bb42:
return
}
// -----
mlfunc @foo()
mlfunc @bar() // expected-error {{expected '{' before statement list}}
@ -108,6 +120,14 @@ mlfunc @empty() { // expected-error {{ML function must end with return statement
// -----
mlfunc @illegalattrs() -> ()
attributes { key } { // expected-error {{expected ':' in attribute list}}
bb42:
return
}
// -----
mlfunc @no_return() { // expected-error {{ML function must end with return statement}}
"foo"() : () -> ()
}

View File

@ -414,3 +414,43 @@ bb0:
"foo"(){a: 4.0, b: 2.0, c: 7.1, d: -0.0} : () -> ()
return
}
// CHECK-LABEL: extfunc @extfuncattr
extfunc @extfuncattr() -> ()
// CHECK: attributes {a: "a\22quoted\22string", b: 4.000000e+00, c: tensor<*xf32>}
attributes {a: "a\"quoted\"string", b: 4.0, c: tensor<*xf32>}
// CHECK-LABEL: extfunc @extfuncattrempty
extfunc @extfuncattrempty() -> ()
// CHECK-EMPTY
attributes {}
// CHECK-LABEL: cfgfunc @cfgfuncattr
cfgfunc @cfgfuncattr() -> ()
// CHECK: attributes {a: "a\22quoted\22string", b: 4.000000e+00, c: tensor<*xf32>}
attributes {a: "a\"quoted\"string", b: 4.0, c: tensor<*xf32>} {
bb0:
return
}
// CHECK-LABEL: cfgfunc @cfgfuncattrempty
cfgfunc @cfgfuncattrempty() -> ()
// CHECK-EMPTY
attributes {} {
bb0:
return
}
// CHECK-LABEL: mlfunc @mlfuncattr
mlfunc @mlfuncattr() -> ()
// CHECK: attributes {a: "a\22quoted\22string", b: 4.000000e+00, c: tensor<*xf32>}
attributes {a: "a\"quoted\"string", b: 4.0, c: tensor<*xf32>} {
return
}
// CHECK-LABEL: mlfunc @mlfuncattrempty
mlfunc @mlfuncattrempty() -> ()
// CHECK-EMPTY
attributes {} {
return
}