From 7e004efae253db82464648453ecdedbb8631fcb2 Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Tue, 18 Sep 2018 16:36:26 -0700 Subject: [PATCH] Add function attributes for ExtFunction, CFGFunction and MLFunction. PiperOrigin-RevId: 213540509 --- mlir/include/mlir/IR/CFGFunction.h | 4 +- mlir/include/mlir/IR/Function.h | 19 +++++- mlir/include/mlir/IR/MLFunction.h | 6 +- mlir/lib/IR/AsmPrinter.cpp | 99 +++++++++++++++++----------- mlir/lib/IR/Function.cpp | 34 +++++++--- mlir/lib/Parser/Parser.cpp | 48 ++++++++++++-- mlir/lib/Parser/TokenKinds.def | 1 + mlir/lib/Transforms/ConvertToCFG.cpp | 5 +- mlir/test/IR/invalid.mlir | 20 ++++++ mlir/test/IR/parser.mlir | 40 +++++++++++ 10 files changed, 213 insertions(+), 63 deletions(-) diff --git a/mlir/include/mlir/IR/CFGFunction.h b/mlir/include/mlir/IR/CFGFunction.h index 16de0af1eae4..fa74635b82e7 100644 --- a/mlir/include/mlir/IR/CFGFunction.h +++ b/mlir/include/mlir/IR/CFGFunction.h @@ -27,7 +27,9 @@ namespace mlir { // blocks, each of which includes instructions. class CFGFunction : public Function { public: - CFGFunction(Location *location, StringRef name, FunctionType *type); + CFGFunction(Location *location, StringRef name, FunctionType *type, + ArrayRef attrs = {}); + ~CFGFunction(); //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h index 6398ceee2c05..f876ab8cee92 100644 --- a/mlir/include/mlir/IR/Function.h +++ b/mlir/include/mlir/IR/Function.h @@ -29,11 +29,18 @@ #include "llvm/ADT/ilist.h" namespace mlir { +class Attribute; +class AttributeListStorage; class FunctionType; class Location; class MLIRContext; class Module; +/// NamedAttribute is used for function attribute lists, it holds an +/// identifier for the name and a value for the attribute. The attribute +/// pointer should always be non-null. +typedef std::pair NamedAttribute; + /// This is the base class for all of the MLIR function types. class Function : public llvm::ilist_node_with_parent { public: @@ -50,6 +57,9 @@ public: /// Return the type of this function. FunctionType *getType() const { return type; } + /// Returns all of the attributes on this function. + ArrayRef getAttrs() const; + MLIRContext *getContext() const; Module *getModule() { return module; } const Module *getModule() const { return module; } @@ -83,7 +93,8 @@ public: void emitNote(const Twine &message) const; protected: - Function(Kind kind, Location *location, StringRef name, FunctionType *type); + Function(Kind kind, Location *location, StringRef name, FunctionType *type, + ArrayRef attrs = {}); ~Function(); private: @@ -99,6 +110,9 @@ private: /// The type of the function. FunctionType *const type; + /// This holds general named attributes for the function. + AttributeListStorage *attrs; + void operator=(const Function &) = delete; friend struct llvm::ilist_traits; }; @@ -107,7 +121,8 @@ private: /// defined in some other module. class ExtFunction : public Function { public: - ExtFunction(Location *location, StringRef name, FunctionType *type); + ExtFunction(Location *location, StringRef name, FunctionType *type, + ArrayRef attrs = {}); /// Methods for support type inquiry through isa, cast, and dyn_cast. static bool classof(const Function *func) { diff --git a/mlir/include/mlir/IR/MLFunction.h b/mlir/include/mlir/IR/MLFunction.h index 5f2497bce973..13eac5d0fa8e 100644 --- a/mlir/include/mlir/IR/MLFunction.h +++ b/mlir/include/mlir/IR/MLFunction.h @@ -41,7 +41,8 @@ class MLFunction final public: /// Creates a new MLFunction with the specific type. static MLFunction *create(Location *location, StringRef name, - FunctionType *type); + FunctionType *type, + ArrayRef attrs = {}); /// Destroys this statement and its subclass data. void destroy(); @@ -94,7 +95,8 @@ public: } private: - MLFunction(Location *location, StringRef name, FunctionType *type); + MLFunction(Location *location, StringRef name, FunctionType *type, + ArrayRef attrs = {}); // This stuff is used by the TrailingObjects template. friend llvm::TrailingObjects; diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index ddcd2aace2ab..5737b8aa00e5 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -285,6 +285,9 @@ protected: ModuleState &state; void printFunctionSignature(const Function *fn); + void printFunctionAttributes(const Function *fn); + void printOptionalAttrDict(ArrayRef attrs, + ArrayRef elidedAttrs = {}); void printFunctionResultType(const FunctionType *type); void printAffineMapId(int affineMapId) const; void printAffineMapReference(const AffineMap *affineMap); @@ -751,6 +754,14 @@ void ModulePrinter::printFunctionResultType(const FunctionType *type) { } } +void ModulePrinter::printFunctionAttributes(const Function *fn) { + auto attrs = fn->getAttrs(); + if (attrs.empty()) + return; + os << "\n attributes "; + printOptionalAttrDict(attrs); +} + void ModulePrinter::printFunctionSignature(const Function *fn) { auto type = fn->getType(); @@ -762,9 +773,49 @@ void ModulePrinter::printFunctionSignature(const Function *fn) { printFunctionResultType(type); } +void ModulePrinter::printOptionalAttrDict(ArrayRef attrs, + ArrayRef elidedAttrs) { + // If there are no attributes, then there is nothing to be done. + if (attrs.empty()) + return; + + // Filter out any attributes that shouldn't be included. + SmallVector filteredAttrs; + for (auto attr : attrs) { + auto attrName = attr.first.strref(); + // Never print attributes that start with a colon. These are internal + // attributes that represent location or other internal metadata. + if (attrName.startswith(":")) + return; + + // If the caller has requested that this attribute be ignored, then drop it. + bool ignore = false; + for (const char *elide : elidedAttrs) + ignore |= attrName == StringRef(elide); + + // Otherwise add it to our filteredAttrs list. + if (!ignore) { + filteredAttrs.push_back(attr); + } + } + + // If there are no attributes left to print after filtering, then we're done. + if (filteredAttrs.empty()) + return; + + // Otherwise, print them all out in braces. + os << " {"; + interleaveComma(filteredAttrs, [&](NamedAttribute attr) { + os << attr.first << ": "; + printAttribute(attr.second); + }); + os << '}'; +} + void ModulePrinter::print(const ExtFunction *fn) { os << "extfunc "; printFunctionSignature(fn); + printFunctionAttributes(fn); os << '\n'; } @@ -797,11 +848,15 @@ public: void printFunctionReference(const Function *func) { return ModulePrinter::printFunctionReference(func); } - + void printFunctionAttributes(const Function *func) { + return ModulePrinter::printFunctionAttributes(func); + } void printOperand(const SSAValue *value) { printValueID(value); } void printOptionalAttrDict(ArrayRef attrs, - ArrayRef elidedAttrs = {}) override; + ArrayRef elidedAttrs = {}) { + return ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs); + }; enum { nameSentinel = ~0U }; @@ -944,44 +999,6 @@ private: }; } // end anonymous namespace -void FunctionPrinter::printOptionalAttrDict( - ArrayRef attrs, ArrayRef elidedAttrs) { - // If there are no attributes, then there is nothing to be done. - if (attrs.empty()) - return; - - // Filter out any attributes that shouldn't be included. - SmallVector filteredAttrs; - for (auto attr : attrs) { - auto attrName = attr.first.strref(); - // Never print attributes that start with a colon. These are internal - // attributes that represent location or other internal metadata. - if (attrName.startswith(":")) - continue; - - // If the caller has requested that this attribute be ignored, then drop it. - bool ignore = false; - for (const char *elide : elidedAttrs) - ignore |= attrName == StringRef(elide); - - // Otherwise add it to our filteredAttrs list. - if (!ignore) - filteredAttrs.push_back(attr); - } - - // If there are no attributes left to print after filtering, then we're done. - if (filteredAttrs.empty()) - return; - - // Otherwise, print them all out in braces. - os << " {"; - interleaveComma(filteredAttrs, [&](NamedAttribute attr) { - os << attr.first << ": "; - printAttribute(attr.second); - }); - os << '}'; -} - void FunctionPrinter::printOperation(const Operation *op) { if (op->getNumResults()) { printValueID(op->getResult(0), /*printResultNo=*/false); @@ -1091,6 +1108,7 @@ void CFGFunctionPrinter::numberValuesInBlock(const BasicBlock *block) { void CFGFunctionPrinter::print() { os << "cfgfunc "; printFunctionSignature(getFunction()); + printFunctionAttributes(getFunction()); os << " {\n"; for (auto &block : *function) @@ -1301,6 +1319,7 @@ void MLFunctionPrinter::numberValues() { void MLFunctionPrinter::print() { os << "mlfunc "; printFunctionSignature(); + printFunctionAttributes(getFunction()); os << " {\n"; print(function); os << "}\n\n"; diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp index 29494de97372..0f3809ef6891 100644 --- a/mlir/lib/IR/Function.cpp +++ b/mlir/lib/IR/Function.cpp @@ -15,6 +15,7 @@ // limitations under the License. // ============================================================================= +#include "AttributeListStorage.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/CFGFunction.h" #include "mlir/IR/MLFunction.h" @@ -27,15 +28,24 @@ using namespace mlir; Function::Function(Kind kind, Location *location, StringRef name, - FunctionType *type) + FunctionType *type, ArrayRef attrs) : nameAndKind(Identifier::get(name, type->getContext()), kind), - location(location), type(type) {} + location(location), type(type) { + this->attrs = AttributeListStorage::get(attrs, getContext()); +} Function::~Function() { // Clean up function attributes referring to this function. FunctionAttr::dropFunctionReference(this); } +ArrayRef Function::getAttrs() const { + if (attrs) + return attrs->getElements(); + else + return {}; +} + MLIRContext *Function::getContext() const { return getType()->getContext(); } /// Delete this object. @@ -149,15 +159,17 @@ void Function::emitError(const Twine &message) const { // ExtFunction implementation. //===----------------------------------------------------------------------===// -ExtFunction::ExtFunction(Location *location, StringRef name, FunctionType *type) - : Function(Kind::ExtFunc, location, name, type) {} +ExtFunction::ExtFunction(Location *location, StringRef name, FunctionType *type, + ArrayRef attrs) + : Function(Kind::ExtFunc, location, name, type, attrs) {} //===----------------------------------------------------------------------===// // CFGFunction implementation. //===----------------------------------------------------------------------===// -CFGFunction::CFGFunction(Location *location, StringRef name, FunctionType *type) - : Function(Kind::CFGFunc, location, name, type) {} +CFGFunction::CFGFunction(Location *location, StringRef name, FunctionType *type, + ArrayRef attrs) + : Function(Kind::CFGFunc, location, name, type, attrs) {} CFGFunction::~CFGFunction() { // Instructions may have cyclic references, which need to be dropped before we @@ -176,13 +188,14 @@ CFGFunction::~CFGFunction() { /// Create a new MLFunction with the specific fields. MLFunction *MLFunction::create(Location *location, StringRef name, - FunctionType *type) { + FunctionType *type, + ArrayRef attrs) { const auto &argTypes = type->getInputs(); auto byteSize = totalSizeToAlloc(argTypes.size()); void *rawMem = malloc(byteSize); // Initialize the MLFunction part of the function object. - auto function = ::new (rawMem) MLFunction(location, name, type); + auto function = ::new (rawMem) MLFunction(location, name, type, attrs); // Initialize the arguments. auto arguments = function->getArgumentsInternal(); @@ -191,8 +204,9 @@ MLFunction *MLFunction::create(Location *location, StringRef name, return function; } -MLFunction::MLFunction(Location *location, StringRef name, FunctionType *type) - : Function(Kind::MLFunc, location, name, type), +MLFunction::MLFunction(Location *location, StringRef name, FunctionType *type, + ArrayRef attrs) + : Function(Kind::MLFunc, location, name, type, attrs), StmtBlock(StmtBlockKind::MLFunc) {} MLFunction::~MLFunction() { diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 6cfb1666213a..2302676ac393 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -620,7 +620,8 @@ Function *Parser::resolveFunctionReference(StringRef nameStr, SMLoc nameLoc, if (!function) { auto &entry = state.functionForwardRefs[name]; if (!entry) - entry = new ExtFunction(getEncodedSourceLocation(nameLoc), name, type); + entry = new ExtFunction(getEncodedSourceLocation(nameLoc), name, type, + /*attrs=*/{}); function = entry; } @@ -2653,6 +2654,7 @@ private: SmallVectorImpl &argNames); ParseResult parseFunctionSignature(StringRef &name, FunctionType *&type, SmallVectorImpl *argNames); + ParseResult parseFunctionAttribute(SmallVectorImpl &attrs); ParseResult parseExtFunc(); ParseResult parseCFGFunc(); ParseResult parseMLFunc(); @@ -2788,9 +2790,24 @@ ModuleParser::parseFunctionSignature(StringRef &name, FunctionType *&type, return ParseSuccess; } +/// Parse function attributes, starting with keyword "attributes". +/// +/// function-attribute ::= (`attributes` attribute-dict)? +/// +ParseResult +ModuleParser::parseFunctionAttribute(SmallVectorImpl &attrs) { + if (consumeIf(Token::kw_attributes)) { + if (parseAttributeDict(attrs)) { + return ParseFailure; + } + } + return ParseSuccess; +} + /// External function declarations. /// /// ext-func ::= `extfunc` function-signature +/// (`attributes` attribute-dict)? /// ParseResult ModuleParser::parseExtFunc() { consumeToken(Token::kw_extfunc); @@ -2801,8 +2818,14 @@ ParseResult ModuleParser::parseExtFunc() { if (parseFunctionSignature(name, type, /*arguments*/ nullptr)) return ParseFailure; + SmallVector attrs; + if (parseFunctionAttribute(attrs)) { + return ParseFailure; + } + // Okay, the external function definition was parsed correctly. - auto *function = new ExtFunction(getEncodedSourceLocation(loc), name, type); + auto *function = + new ExtFunction(getEncodedSourceLocation(loc), name, type, attrs); getModule()->getFunctions().push_back(function); // Verify no name collision / redefinition. @@ -2815,7 +2838,8 @@ ParseResult ModuleParser::parseExtFunc() { /// CFG function declarations. /// -/// cfg-func ::= `cfgfunc` function-signature `{` basic-block+ `}` +/// cfg-func ::= `cfgfunc` function-signature +/// (`attributes` attribute-dict)? `{` basic-block+ `}` /// ParseResult ModuleParser::parseCFGFunc() { consumeToken(Token::kw_cfgfunc); @@ -2826,8 +2850,14 @@ ParseResult ModuleParser::parseCFGFunc() { if (parseFunctionSignature(name, type, /*arguments*/ nullptr)) return ParseFailure; + SmallVector attrs; + if (parseFunctionAttribute(attrs)) { + return ParseFailure; + } + // Okay, the CFG function signature was parsed correctly, create the function. - auto *function = new CFGFunction(getEncodedSourceLocation(loc), name, type); + auto *function = + new CFGFunction(getEncodedSourceLocation(loc), name, type, attrs); getModule()->getFunctions().push_back(function); // Verify no name collision / redefinition. @@ -2840,7 +2870,8 @@ ParseResult ModuleParser::parseCFGFunc() { /// ML function declarations. /// -/// ml-func ::= `mlfunc` ml-func-signature `{` ml-stmt* ml-return-stmt `}` +/// ml-func ::= `mlfunc` ml-func-signature +/// (`attributes` attribute-dict)? `{` ml-stmt* ml-return-stmt `}` /// ParseResult ModuleParser::parseMLFunc() { consumeToken(Token::kw_mlfunc); @@ -2853,9 +2884,14 @@ ParseResult ModuleParser::parseMLFunc() { if (parseFunctionSignature(name, type, &argNames)) return ParseFailure; + SmallVector attrs; + if (parseFunctionAttribute(attrs)) { + return ParseFailure; + } + // Okay, the ML function signature was parsed correctly, create the function. auto *function = - MLFunction::create(getEncodedSourceLocation(loc), name, type); + MLFunction::create(getEncodedSourceLocation(loc), name, type, attrs); getModule()->getFunctions().push_back(function); // Verify no name collision / redefinition. diff --git a/mlir/lib/Parser/TokenKinds.def b/mlir/lib/Parser/TokenKinds.def index 4e1cd2e8b3fc..3ef732822f49 100644 --- a/mlir/lib/Parser/TokenKinds.def +++ b/mlir/lib/Parser/TokenKinds.def @@ -88,6 +88,7 @@ TOK_OPERATOR(star, "*") // Keywords. These turn "foo" into Token::kw_foo enums. TOK_KEYWORD(affineint) +TOK_KEYWORD(attributes) TOK_KEYWORD(bf16) TOK_KEYWORD(br) TOK_KEYWORD(ceildiv) diff --git a/mlir/lib/Transforms/ConvertToCFG.cpp b/mlir/lib/Transforms/ConvertToCFG.cpp index 0dd8ef60fd43..deeffb5bd9f9 100644 --- a/mlir/lib/Transforms/ConvertToCFG.cpp +++ b/mlir/lib/Transforms/ConvertToCFG.cpp @@ -110,8 +110,9 @@ void ModuleConverter::convertMLFunctions() { // Creates CFG function equivalent to the given ML function. CFGFunction *ModuleConverter::convert(const MLFunction *mlFunc) { // TODO: ensure that CFG function name is unique. - auto *cfgFunc = new CFGFunction( - mlFunc->getLoc(), mlFunc->getName().str() + "_cfg", mlFunc->getType()); + auto *cfgFunc = + new CFGFunction(mlFunc->getLoc(), mlFunc->getName().str() + "_cfg", + mlFunc->getType(), mlFunc->getAttrs()); module->getFunctions().push_back(cfgFunc); // Generates the body of the CFG function. diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir index ae5d3b2f48c2..bfeddc05966b 100644 --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -41,6 +41,10 @@ extfunc @memrefs(memref<2x4xi8, #map0, 1, #map1>) // expected-error {{affine map // ----- +extfunc @illegalattrs() -> () attributes { key } // expected-error {{expected ':' in attribute list}} + +// ----- + cfgfunc @foo() cfgfunc @bar() // expected-error {{expected '{' in CFG function}} @@ -98,6 +102,14 @@ bb42 (%0): // expected-error {{expected ':' and type for SSA operand}} // ----- +cfgfunc @illegalattrs() -> () + attributes { key } { // expected-error {{expected ':' in attribute list}} +bb42: + return +} + +// ----- + mlfunc @foo() mlfunc @bar() // expected-error {{expected '{' before statement list}} @@ -108,6 +120,14 @@ mlfunc @empty() { // expected-error {{ML function must end with return statement // ----- +mlfunc @illegalattrs() -> () + attributes { key } { // expected-error {{expected ':' in attribute list}} +bb42: + return +} + +// ----- + mlfunc @no_return() { // expected-error {{ML function must end with return statement}} "foo"() : () -> () } diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir index 5a19a32420f6..0e81adcb8534 100644 --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -414,3 +414,43 @@ bb0: "foo"(){a: 4.0, b: 2.0, c: 7.1, d: -0.0} : () -> () return } + +// CHECK-LABEL: extfunc @extfuncattr +extfunc @extfuncattr() -> () + // CHECK: attributes {a: "a\22quoted\22string", b: 4.000000e+00, c: tensor<*xf32>} + attributes {a: "a\"quoted\"string", b: 4.0, c: tensor<*xf32>} + +// CHECK-LABEL: extfunc @extfuncattrempty +extfunc @extfuncattrempty() -> () + // CHECK-EMPTY + attributes {} + +// CHECK-LABEL: cfgfunc @cfgfuncattr +cfgfunc @cfgfuncattr() -> () + // CHECK: attributes {a: "a\22quoted\22string", b: 4.000000e+00, c: tensor<*xf32>} + attributes {a: "a\"quoted\"string", b: 4.0, c: tensor<*xf32>} { +bb0: + return +} + +// CHECK-LABEL: cfgfunc @cfgfuncattrempty +cfgfunc @cfgfuncattrempty() -> () + // CHECK-EMPTY + attributes {} { +bb0: + return +} + +// CHECK-LABEL: mlfunc @mlfuncattr +mlfunc @mlfuncattr() -> () + // CHECK: attributes {a: "a\22quoted\22string", b: 4.000000e+00, c: tensor<*xf32>} + attributes {a: "a\"quoted\"string", b: 4.0, c: tensor<*xf32>} { + return +} + +// CHECK-LABEL: mlfunc @mlfuncattrempty +mlfunc @mlfuncattrempty() -> () + // CHECK-EMPTY + attributes {} { + return +}