build code structure; update readme; remove pymlir directory

Hanchen Ye 2020-08-25 14:11:30 -05:00
parent 69a04bbd38
commit b35338e68e
47 changed files with 362 additions and 3743 deletions

.gitignore
@@ -1 +1,3 @@
.vscode/
.vscode
build

CMakeLists.txt
@@ -0,0 +1,55 @@
cmake_minimum_required(VERSION 3.13.4)
if(POLICY CMP0068)
cmake_policy(SET CMP0068 NEW)
set(CMAKE_BUILD_WITH_INSTALL_NAME_DIR ON)
endif()
if(POLICY CMP0075)
cmake_policy(SET CMP0075 NEW)
endif()
if(POLICY CMP0077)
cmake_policy(SET CMP0077 NEW)
endif()
project(hlsld LANGUAGES CXX C)
set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD_REQUIRED YES)
find_package(MLIR REQUIRED CONFIG)
message(STATUS "Using MLIRConfig.cmake in: ${MLIR_DIR}")
message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}")
set(LLVM_RUNTIME_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/bin)
set(LLVM_LIBRARY_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/lib)
set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR})
set(HLSLD_MAIN_SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/include)
set(HLSLD_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/include)
set(HLSLD_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR})
set(HLSLD_BINARY_DIR ${CMAKE_BINARY_DIR}/bin)
set(HLSLD_TOOLS_DIR ${CMAKE_BINARY_DIR}/bin)
list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}")
list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}")
include(TableGen)
include(AddLLVM)
include(AddMLIR)
include(HandleLLVMOptions)
include_directories(${LLVM_INCLUDE_DIRS})
include_directories(${MLIR_INCLUDE_DIRS})
include_directories(${PROJECT_SOURCE_DIR}/include)
include_directories(${PROJECT_BINARY_DIR}/include)
link_directories(${LLVM_BUILD_LIBRARY_DIR})
add_definitions(${LLVM_DEFINITIONS})
add_subdirectory(include)
add_subdirectory(lib)
add_subdirectory(test)
add_subdirectory(tools)

README.md
@@ -1,21 +1,21 @@
# HLS Large Design Project (HLSLD)
## High Level Description of Project
Hanchen and Jack aim to create a framework that ultimately converts an algorithm written in a high level language into an efficient hardware implementation. Out of all of the existing deep learning compiler related projects, TVM and MLIR present existing tools that we can leverage in our proposed framework. With multiple levels of intermediate representations (IRs), MLIR appears to be the ideal tool for exploring ways to optimize the eventual design at various levels of abstraction (e.g. various levels of parallelism). Our framework will be based on MLIR, but it will incorporate a frontend for Sitao's powerful domain specific language (DSL) and a backend for high level synthesis (HLS) C/C++ code. However, the key contribution will be our parametrization and optimization of a tremendously large design space. So far, we are familiarizing ourselves with the existing MLIR flow using a toy example (simple neural network for MNIST digit classification) and figure out how to do this parametrization and optimization.
This project aims to create a framework that ultimately converts an algorithm written in a high-level language into an efficient hardware implementation. Out of all of the existing deep learning compiler projects, TVM and MLIR offer tools that we can leverage in our proposed framework. With multiple levels of intermediate representations (IRs), MLIR appears to be the ideal tool for exploring ways to optimize the eventual design at various levels of abstraction (e.g. various levels of parallelism). Our framework will be based on MLIR, but it will incorporate a frontend for Sitao's powerful domain specific language (DSL) and a backend for high level synthesis (HLS) C/C++ code. However, the key contribution will be our parametrization and optimization of a tremendously large design space. So far, we are familiarizing ourselves with the existing MLIR flow using a toy example (a simple neural network for MNIST digit classification) and figuring out how to do this parametrization and optimization.
## Quick Start
This setup assumes that you have built LLVM and MLIR in `$LLVM_BUILD_DIR`. To build and launch the tests, run
```sh
mkdir build && cd build
cmake -G Ninja .. -DMLIR_DIR=$LLVM_BUILD_DIR/lib/cmake/mlir -DLLVM_EXTERNAL_LIT=$LLVM_BUILD_DIR/bin/llvm-lit
cmake --build . --target check-hlsld
```
## Hanchen TODO List
At a high level, Hanchen aims to leverage existing IPs within our MLIR-based framework. This necessitates creating a dedicated fpgakrnl dialect in MLIR.
1. Create an fpgakrnl dialect --> `include/fpgakrnl/Dialect.h`
2. Create conv and pool operations in fpgakrnl dialect --> `include/fpgakrnl/Ops.td`
3. Create pass for lowering conv and pool operations in ONNX dialect to fpgakrnl dialect
## Jack TODO List
At a high level, Jack aims to figure out which transformation and lowering passes are available within the most mature MLIR dialects so that he can set up the parametrization and optimization of the design space.
1. Find out how to apply transformation and lowering passes within the Affine, Linalg, and Loop dialects for the same two operations (conv, pool)
2. Emit HLS C++ code from the Standard dialect (a pass skeleton sketch follows below)
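The skeleton below is a minimal sketch, assuming MLIR APIs circa this commit, of what such a dialect-conversion pass looked like. It is illustrative only, not code from this repository; the pattern, pass, and op names (`LowerConvPattern`, `onnx.Conv`, `fpgakrnl.conv`) are placeholders.
```cpp
// Hedged sketch, assuming MLIR circa mid-2020; all names are placeholders.
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

namespace {
// Rewrites a hypothetical "onnx.Conv" into a hypothetical "fpgakrnl.conv",
// forwarding operands, attributes, and result types.
struct LowerConvPattern : public mlir::ConversionPattern {
  LowerConvPattern(mlir::MLIRContext *ctx)
      : ConversionPattern("onnx.Conv", /*benefit=*/1, ctx) {}

  mlir::LogicalResult
  matchAndRewrite(mlir::Operation *op, llvm::ArrayRef<mlir::Value> operands,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    // Re-create the op under the target dialect name.
    mlir::OperationState state(op->getLoc(), "fpgakrnl.conv");
    state.addOperands(operands);
    state.addAttributes(op->getAttrs());
    state.addTypes(llvm::to_vector<1>(op->getResultTypes()));
    rewriter.replaceOp(op, rewriter.createOperation(state)->getResults());
    return mlir::success();
  }
};

// Function pass that drives the conversion above.
struct LowerToFpgaKrnlPass
    : public mlir::PassWrapper<LowerToFpgaKrnlPass, mlir::FunctionPass> {
  void runOnFunction() override {
    mlir::ConversionTarget target(getContext());
    // Mark some dialect legal; a real pass would mark its target dialect.
    target.addLegalDialect<mlir::StandardOpsDialect>();
    mlir::OwningRewritePatternList patterns;
    patterns.insert<LowerConvPattern>(&getContext());
    if (mlir::failed(mlir::applyPartialConversion(getFunction(), target,
                                                  patterns)))
      signalPassFailure();
  }
};
} // namespace
```
A real pass would additionally map attributes such as `kernel_shape` and `pads` onto the target op and register itself with the pass manager.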
## References
1. [Toy Tutorial Chapter2: Emitting Basic MLIR](https://mlir.llvm.org/docs/Tutorials/Toy/Ch-2/#interfacing-with-mlir)
2. [ONNX-MLIR](https://github.com/onnx/onnx-mlir)
3. [DNNBuilder](https://github.com/IBM/AccDNN)
1. [MLIR Documentation](https://mlir.llvm.org)
2. [github mlir-npcomp](https://github.com/llvm/mlir-npcomp)
3. [github circt](https://github.com/llvm/circt)
4. [github onnx-mlir](https://github.com/onnx/onnx-mlir)

include/EmitHLSCpp.h
@@ -0,0 +1,3 @@
//===------------------------------------------------------------*- C++ -*-===//
//
//===----------------------------------------------------------------------===//
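Both `include/EmitHLSCpp.h` above and the `lib/EmitHLSCpp` sources further below are empty stubs in this commit. Purely as a hedged illustration of the intended direction (nothing below exists in the repository; the namespace and function signature are guesses), such a header might eventually declare an emission entry point:
```cpp
// Hypothetical sketch only; this commit adds nothing but an empty banner.
#include "mlir/IR/Module.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"

namespace mlir {
namespace hlsld {

/// Walk `module` and print equivalent HLS C++ to `os`; both the namespace
/// and the signature here are assumptions, not the project's actual API.
LogicalResult emitHLSCpp(ModuleOp module, llvm::raw_ostream &os);

} // namespace hlsld
} // namespace mlir
```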

include/fpgakrnl/Dialect.h
@@ -1,18 +0,0 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Function.h"
#include "mlir/Interfaces/SideEffects.h"
namespace mlir {
namespace fpgakrnl {
class FpgaKrnlDialect : public mlir::Dialect {
explicit FpgaKrnlDialect(mlir::MLIRContext *ctx);
static llvm::StringRef getDialectNamespace() {
return "fpgakrnl";
}
};
#define GET_OP_CLASSES
#include "fpgakrnl/Ops.h.inc"
}
}

include/fpgakrnl/Ops.h.inc
@@ -1,417 +0,0 @@
/*===- TableGen'erated file -------------------------------------*- C++ -*-===*\
|* *|
|* Op Definitions *|
|* *|
|* Automatically generated file, do not edit! *|
|* *|
\*===----------------------------------------------------------------------===*/
#ifdef GET_OP_LIST
#undef GET_OP_LIST
fpgakrnl::ConvOp,
fpgakrnl::MaxPoolOp
#endif // GET_OP_LIST
#ifdef GET_OP_CLASSES
#undef GET_OP_CLASSES
//===----------------------------------------------------------------------===//
// fpgakrnl::ConvOp definitions
//===----------------------------------------------------------------------===//
ConvOpOperandAdaptor::ConvOpOperandAdaptor(ArrayRef<Value> values) {
tblgen_operands = values;
}
ArrayRef<Value> ConvOpOperandAdaptor::getODSOperands(unsigned index) {
return {std::next(tblgen_operands.begin(), index), std::next(tblgen_operands.begin(), index + 1)};
}
Value ConvOpOperandAdaptor::ifmap() {
return *getODSOperands(0).begin();
}
StringRef ConvOp::getOperationName() {
return "fpgakrnl.conv";
}
Operation::operand_range ConvOp::getODSOperands(unsigned index) {
return {std::next(getOperation()->operand_begin(), index), std::next(getOperation()->operand_begin(), index + 1)};
}
Value ConvOp::ifmap() {
return *getODSOperands(0).begin();
}
Operation::result_range ConvOp::getODSResults(unsigned index) {
return {std::next(getOperation()->result_begin(), index), std::next(getOperation()->result_begin(), index + 1)};
}
Value ConvOp::ofmap() {
return *getODSResults(0).begin();
}
DenseIntElementsAttr ConvOp::kernel_valueAttr() {
return this->getAttr("kernel_value").dyn_cast_or_null<DenseIntElementsAttr>();
}
DenseIntElementsAttr ConvOp::kernel_value() {
auto attr = kernel_valueAttr();
if (!attr)
return ;
return attr;
}
ArrayAttr ConvOp::kernel_shapeAttr() {
return this->getAttr("kernel_shape").dyn_cast_or_null<ArrayAttr>();
}
ArrayAttr ConvOp::kernel_shape() {
auto attr = kernel_shapeAttr();
if (!attr)
return mlir::Builder(this->getContext()).getI32ArrayAttr({});
return attr;
}
ArrayAttr ConvOp::padsAttr() {
return this->getAttr("pads").dyn_cast_or_null<ArrayAttr>();
}
ArrayAttr ConvOp::pads() {
auto attr = padsAttr();
if (!attr)
return mlir::Builder(this->getContext()).getI32ArrayAttr({});
return attr;
}
ArrayAttr ConvOp::dilationsAttr() {
return this->getAttr("dilations").dyn_cast_or_null<ArrayAttr>();
}
Optional< ArrayAttr > ConvOp::dilations() {
auto attr = dilationsAttr();
return attr ? Optional< ArrayAttr >(attr) : (llvm::None);
}
ArrayAttr ConvOp::stridesAttr() {
return this->getAttr("strides").dyn_cast_or_null<ArrayAttr>();
}
Optional< ArrayAttr > ConvOp::strides() {
auto attr = stridesAttr();
return attr ? Optional< ArrayAttr >(attr) : (llvm::None);
}
IntegerAttr ConvOp::groupAttr() {
return this->getAttr("group").dyn_cast_or_null<IntegerAttr>();
}
Optional< APInt > ConvOp::group() {
auto attr = groupAttr();
return attr ? Optional< APInt >(attr.getValue()) : (llvm::None);
}
void ConvOp::kernel_valueAttr(DenseIntElementsAttr attr) {
this->getOperation()->setAttr("kernel_value", attr);
}
void ConvOp::kernel_shapeAttr(ArrayAttr attr) {
this->getOperation()->setAttr("kernel_shape", attr);
}
void ConvOp::padsAttr(ArrayAttr attr) {
this->getOperation()->setAttr("pads", attr);
}
void ConvOp::dilationsAttr(ArrayAttr attr) {
this->getOperation()->setAttr("dilations", attr);
}
void ConvOp::stridesAttr(ArrayAttr attr) {
this->getOperation()->setAttr("strides", attr);
}
void ConvOp::groupAttr(IntegerAttr attr) {
this->getOperation()->setAttr("group", attr);
}
void ConvOp::build(Builder *odsBuilder, OperationState &odsState, Type ofmap, Value ifmap, DenseIntElementsAttr kernel_value, ArrayAttr kernel_shape, ArrayAttr pads, /*optional*/ArrayAttr dilations, /*optional*/ArrayAttr strides, /*optional*/IntegerAttr group) {
odsState.addOperands(ifmap);
odsState.addAttribute("kernel_value", kernel_value);
odsState.addAttribute("kernel_shape", kernel_shape);
odsState.addAttribute("pads", pads);
if (dilations) {
odsState.addAttribute("dilations", dilations);
}
if (strides) {
odsState.addAttribute("strides", strides);
}
if (group) {
odsState.addAttribute("group", group);
}
odsState.addTypes(ofmap);
}
void ConvOp::build(Builder *odsBuilder, OperationState &odsState, ArrayRef<Type> resultTypes, Value ifmap, DenseIntElementsAttr kernel_value, ArrayAttr kernel_shape, ArrayAttr pads, /*optional*/ArrayAttr dilations, /*optional*/ArrayAttr strides, /*optional*/IntegerAttr group) {
odsState.addOperands(ifmap);
odsState.addAttribute("kernel_value", kernel_value);
odsState.addAttribute("kernel_shape", kernel_shape);
odsState.addAttribute("pads", pads);
if (dilations) {
odsState.addAttribute("dilations", dilations);
}
if (strides) {
odsState.addAttribute("strides", strides);
}
if (group) {
odsState.addAttribute("group", group);
}
assert(resultTypes.size() == 1u && "mismatched number of results");
odsState.addTypes(resultTypes);
}
void ConvOp::build(Builder *, OperationState &odsState, ArrayRef<Type> resultTypes, ValueRange operands, ArrayRef<NamedAttribute> attributes) {
assert(operands.size() == 1u && "mismatched number of parameters");
odsState.addOperands(operands);
odsState.addAttributes(attributes);
assert(resultTypes.size() == 1u && "mismatched number of return types");
odsState.addTypes(resultTypes);
}
LogicalResult ConvOp::verify() {
auto tblgen_kernel_value = this->getAttr("kernel_value");
if (tblgen_kernel_value) {
if (!(((tblgen_kernel_value.isa<DenseIntElementsAttr>())) && ((tblgen_kernel_value.cast<DenseIntElementsAttr>().getType().getElementType().isInteger(32))))) return emitOpError("attribute 'kernel_value' failed to satisfy constraint: 32-bit integer elements attribute");
}
auto tblgen_kernel_shape = this->getAttr("kernel_shape");
if (tblgen_kernel_shape) {
if (!(((tblgen_kernel_shape.isa<ArrayAttr>())) && (llvm::all_of(tblgen_kernel_shape.cast<ArrayAttr>(), [](Attribute attr) { return ((attr.isa<IntegerAttr>())) && ((attr.cast<IntegerAttr>().getType().isSignlessInteger(32))); })))) return emitOpError("attribute 'kernel_shape' failed to satisfy constraint: 32-bit integer array attribute");
}
auto tblgen_pads = this->getAttr("pads");
if (tblgen_pads) {
if (!(((tblgen_pads.isa<ArrayAttr>())) && (llvm::all_of(tblgen_pads.cast<ArrayAttr>(), [](Attribute attr) { return ((attr.isa<IntegerAttr>())) && ((attr.cast<IntegerAttr>().getType().isSignlessInteger(32))); })))) return emitOpError("attribute 'pads' failed to satisfy constraint: 32-bit integer array attribute");
}
auto tblgen_dilations = this->getAttr("dilations");
if (tblgen_dilations) {
if (!(((tblgen_dilations.isa<ArrayAttr>())) && (llvm::all_of(tblgen_dilations.cast<ArrayAttr>(), [](Attribute attr) { return ((attr.isa<IntegerAttr>())) && ((attr.cast<IntegerAttr>().getType().isSignlessInteger(32))); })))) return emitOpError("attribute 'dilations' failed to satisfy constraint: 32-bit integer array attribute");
}
auto tblgen_strides = this->getAttr("strides");
if (tblgen_strides) {
if (!(((tblgen_strides.isa<ArrayAttr>())) && (llvm::all_of(tblgen_strides.cast<ArrayAttr>(), [](Attribute attr) { return ((attr.isa<IntegerAttr>())) && ((attr.cast<IntegerAttr>().getType().isSignlessInteger(32))); })))) return emitOpError("attribute 'strides' failed to satisfy constraint: 32-bit integer array attribute");
}
auto tblgen_group = this->getAttr("group");
if (tblgen_group) {
if (!(((tblgen_group.isa<IntegerAttr>())) && ((tblgen_group.cast<IntegerAttr>().getType().isSignlessInteger(32))))) return emitOpError("attribute 'group' failed to satisfy constraint: 32-bit signless integer attribute");
}
{
unsigned index = 0; (void)index;
for (Value v : getODSOperands(0)) {
(void)v;
if (!((((v.getType().isa<MemRefType>())) && ((true))) || (((v.getType().isa<TensorType>())) && ((true))))) {
return emitOpError("operand #") << index << " must be memref of any type values or tensor of any type values, but got " << v.getType();
}
++index;
}
}
{
unsigned index = 0; (void)index;
for (Value v : getODSResults(0)) {
(void)v;
if (!((((v.getType().isa<MemRefType>())) && ((true))) || (((v.getType().isa<TensorType>())) && ((true))))) {
return emitOpError("result #") << index << " must be memref of any type values or tensor of any type values, but got " << v.getType();
}
++index;
}
}
if (this->getOperation()->getNumRegions() != 0) {
return emitOpError("has incorrect number of regions: expected 0 but found ") << this->getOperation()->getNumRegions();
}
return ::verify(*this);
}
void ConvOp::getEffects(SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> &effects) {
}
//===----------------------------------------------------------------------===//
// fpgakrnl::MaxPoolOp definitions
//===----------------------------------------------------------------------===//
MaxPoolOpOperandAdaptor::MaxPoolOpOperandAdaptor(ArrayRef<Value> values) {
tblgen_operands = values;
}
ArrayRef<Value> MaxPoolOpOperandAdaptor::getODSOperands(unsigned index) {
return {std::next(tblgen_operands.begin(), index), std::next(tblgen_operands.begin(), index + 1)};
}
Value MaxPoolOpOperandAdaptor::ifmap() {
return *getODSOperands(0).begin();
}
StringRef MaxPoolOp::getOperationName() {
return "fpgakrnl.maxpool";
}
Operation::operand_range MaxPoolOp::getODSOperands(unsigned index) {
return {std::next(getOperation()->operand_begin(), index), std::next(getOperation()->operand_begin(), index + 1)};
}
Value MaxPoolOp::ifmap() {
return *getODSOperands(0).begin();
}
Operation::result_range MaxPoolOp::getODSResults(unsigned index) {
return {std::next(getOperation()->result_begin(), index), std::next(getOperation()->result_begin(), index + 1)};
}
Value MaxPoolOp::ofmap() {
return *getODSResults(0).begin();
}
ArrayAttr MaxPoolOp::kernel_shapeAttr() {
return this->getAttr("kernel_shape").dyn_cast_or_null<ArrayAttr>();
}
ArrayAttr MaxPoolOp::kernel_shape() {
auto attr = kernel_shapeAttr();
if (!attr)
return mlir::Builder(this->getContext()).getI32ArrayAttr({});
return attr;
}
ArrayAttr MaxPoolOp::padsAttr() {
return this->getAttr("pads").dyn_cast_or_null<ArrayAttr>();
}
ArrayAttr MaxPoolOp::pads() {
auto attr = padsAttr();
if (!attr)
return mlir::Builder(this->getContext()).getI32ArrayAttr({});
return attr;
}
ArrayAttr MaxPoolOp::dilationsAttr() {
return this->getAttr("dilations").dyn_cast_or_null<ArrayAttr>();
}
Optional< ArrayAttr > MaxPoolOp::dilations() {
auto attr = dilationsAttr();
return attr ? Optional< ArrayAttr >(attr) : (llvm::None);
}
ArrayAttr MaxPoolOp::stridesAttr() {
return this->getAttr("strides").dyn_cast_or_null<ArrayAttr>();
}
Optional< ArrayAttr > MaxPoolOp::strides() {
auto attr = stridesAttr();
return attr ? Optional< ArrayAttr >(attr) : (llvm::None);
}
void MaxPoolOp::kernel_shapeAttr(ArrayAttr attr) {
this->getOperation()->setAttr("kernel_shape", attr);
}
void MaxPoolOp::padsAttr(ArrayAttr attr) {
this->getOperation()->setAttr("pads", attr);
}
void MaxPoolOp::dilationsAttr(ArrayAttr attr) {
this->getOperation()->setAttr("dilations", attr);
}
void MaxPoolOp::stridesAttr(ArrayAttr attr) {
this->getOperation()->setAttr("strides", attr);
}
void MaxPoolOp::build(Builder *odsBuilder, OperationState &odsState, Type ofmap, Value ifmap, ArrayAttr kernel_shape, ArrayAttr pads, /*optional*/ArrayAttr dilations, /*optional*/ArrayAttr strides) {
odsState.addOperands(ifmap);
odsState.addAttribute("kernel_shape", kernel_shape);
odsState.addAttribute("pads", pads);
if (dilations) {
odsState.addAttribute("dilations", dilations);
}
if (strides) {
odsState.addAttribute("strides", strides);
}
odsState.addTypes(ofmap);
}
void MaxPoolOp::build(Builder *odsBuilder, OperationState &odsState, ArrayRef<Type> resultTypes, Value ifmap, ArrayAttr kernel_shape, ArrayAttr pads, /*optional*/ArrayAttr dilations, /*optional*/ArrayAttr strides) {
odsState.addOperands(ifmap);
odsState.addAttribute("kernel_shape", kernel_shape);
odsState.addAttribute("pads", pads);
if (dilations) {
odsState.addAttribute("dilations", dilations);
}
if (strides) {
odsState.addAttribute("strides", strides);
}
assert(resultTypes.size() == 1u && "mismatched number of results");
odsState.addTypes(resultTypes);
}
void MaxPoolOp::build(Builder *, OperationState &odsState, ArrayRef<Type> resultTypes, ValueRange operands, ArrayRef<NamedAttribute> attributes) {
assert(operands.size() == 1u && "mismatched number of parameters");
odsState.addOperands(operands);
odsState.addAttributes(attributes);
assert(resultTypes.size() == 1u && "mismatched number of return types");
odsState.addTypes(resultTypes);
}
LogicalResult MaxPoolOp::verify() {
auto tblgen_kernel_shape = this->getAttr("kernel_shape");
if (tblgen_kernel_shape) {
if (!(((tblgen_kernel_shape.isa<ArrayAttr>())) && (llvm::all_of(tblgen_kernel_shape.cast<ArrayAttr>(), [](Attribute attr) { return ((attr.isa<IntegerAttr>())) && ((attr.cast<IntegerAttr>().getType().isSignlessInteger(32))); })))) return emitOpError("attribute 'kernel_shape' failed to satisfy constraint: 32-bit integer array attribute");
}
auto tblgen_pads = this->getAttr("pads");
if (tblgen_pads) {
if (!(((tblgen_pads.isa<ArrayAttr>())) && (llvm::all_of(tblgen_pads.cast<ArrayAttr>(), [](Attribute attr) { return ((attr.isa<IntegerAttr>())) && ((attr.cast<IntegerAttr>().getType().isSignlessInteger(32))); })))) return emitOpError("attribute 'pads' failed to satisfy constraint: 32-bit integer array attribute");
}
auto tblgen_dilations = this->getAttr("dilations");
if (tblgen_dilations) {
if (!(((tblgen_dilations.isa<ArrayAttr>())) && (llvm::all_of(tblgen_dilations.cast<ArrayAttr>(), [](Attribute attr) { return ((attr.isa<IntegerAttr>())) && ((attr.cast<IntegerAttr>().getType().isSignlessInteger(32))); })))) return emitOpError("attribute 'dilations' failed to satisfy constraint: 32-bit integer array attribute");
}
auto tblgen_strides = this->getAttr("strides");
if (tblgen_strides) {
if (!(((tblgen_strides.isa<ArrayAttr>())) && (llvm::all_of(tblgen_strides.cast<ArrayAttr>(), [](Attribute attr) { return ((attr.isa<IntegerAttr>())) && ((attr.cast<IntegerAttr>().getType().isSignlessInteger(32))); })))) return emitOpError("attribute 'strides' failed to satisfy constraint: 32-bit integer array attribute");
}
{
unsigned index = 0; (void)index;
for (Value v : getODSOperands(0)) {
(void)v;
if (!((((v.getType().isa<MemRefType>())) && ((true))) || (((v.getType().isa<TensorType>())) && ((true))))) {
return emitOpError("operand #") << index << " must be memref of any type values or tensor of any type values, but got " << v.getType();
}
++index;
}
}
{
unsigned index = 0; (void)index;
for (Value v : getODSResults(0)) {
(void)v;
if (!((((v.getType().isa<MemRefType>())) && ((true))) || (((v.getType().isa<TensorType>())) && ((true))))) {
return emitOpError("result #") << index << " must be memref of any type values or tensor of any type values, but got " << v.getType();
}
++index;
}
}
if (this->getOperation()->getNumRegions() != 0) {
return emitOpError("has incorrect number of regions: expected 0 but found ") << this->getOperation()->getNumRegions();
}
return ::verify(*this);
}
void MaxPoolOp::getEffects(SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> &effects) {
}
#endif // GET_OP_CLASSES

include/fpgakrnl/Ops.td
@@ -1,49 +0,0 @@
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffects.td"
def FpgaKrnl_Dialect : Dialect {
let name = "fpgakrnl";
let cppNamespace = "fpgakrnl";
}
class FpgaKrnl_Op<string mnemonic, list<OpTrait> traits = []> :
Op<FpgaKrnl_Dialect, mnemonic, traits>;
def ConvOp : FpgaKrnl_Op<"conv", [NoSideEffect]> {
let summary = "conv";
let description = [{
Convolution Operation.
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$ifmap,
DefaultValuedAttr<AnyI32ElementsAttr, "0">:$kernel_value,
DefaultValuedAttr<I32ArrayAttr, "{}">:$kernel_shape,
DefaultValuedAttr<I32ArrayAttr, "{}">:$pads,
OptionalAttr<I32ArrayAttr>:$dilations,
OptionalAttr<I32ArrayAttr>:$strides,
OptionalAttr<I32Attr>:$group);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$ofmap);
let verifier = [{ return ::verify(*this); }];
}
def MaxPoolOp : FpgaKrnl_Op<"maxpool", [NoSideEffect]> {
let summary = "maxpool";
let description = [{
Max Pooling Operation.
}];
// hanchen is working here
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$ifmap,
DefaultValuedAttr<I32ArrayAttr, "{}">:$kernel_shape,
DefaultValuedAttr<I32ArrayAttr, "{}">:$pads,
OptionalAttr<I32ArrayAttr>:$dilations,
OptionalAttr<I32ArrayAttr>:$strides);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$ofmap);
let verifier = [{ return ::verify(*this); }];
}

lib/CMakeLists.txt
@@ -0,0 +1 @@
add_subdirectory(EmitHLSCpp)

lib/EmitHLSCpp/CMakeLists.txt
@@ -0,0 +1,10 @@
file(GLOB globbed *.cpp)
add_mlir_library(HLSLDEmitHLSCpp
${globbed}
ADDITIONAL_HEADER_DIRS
LINK_LIBS PUBLIC
MLIRStandardOps
)

@@ -0,0 +1,3 @@
//===------------------------------------------------------------*- C++ -*-===//
//
//===----------------------------------------------------------------------===//

@@ -1,12 +0,0 @@
struct LoweringToConvOp : public mlir::ConversionPattern {
LoweringToConvOp(mlir::MLIRContext)
}
void OnnxToFpgaKrnlLoweringPass::runOnFunction() {
mlir::ConversionTarget target(getContext());
target.addLegalDialect<FpgaKrnlDialect>();
target.addIllegalDialect<mlir::ONNXOpsDialect>();
}

pymlir/LICENSE
@@ -1,29 +0,0 @@
BSD 3-Clause License
Copyright (c) 2020, Scalable Parallel Computing Lab, ETH Zurich
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

pymlir/README.md
@@ -1,132 +0,0 @@
[![Build Status](https://travis-ci.org/spcl/pymlir.svg?branch=master)](https://travis-ci.org/spcl/pymlir)
[![codecov](https://codecov.io/gh/spcl/pymlir/branch/master/graph/badge.svg)](https://codecov.io/gh/spcl/pymlir)
# pyMLIR: Python Interface for the Multi-Level Intermediate Representation
pyMLIR is a full Python interface to parse, process, and output [MLIR](https://mlir.llvm.org/) files according to the
syntax described in the [MLIR documentation](https://github.com/llvm/llvm-project/tree/master/mlir/docs). pyMLIR
supports the basic dialects and can be extended with other dialects. It uses [Lark](https://github.com/lark-parser/lark)
to parse the MLIR syntax, and mirrors the classes into Python classes. Custom dialects can also be implemented with a
Python string-format-like syntax, or via direct parsing.
Note that the tool *does not depend on LLVM or MLIR*. It can be installed and invoked directly from Python.
## Instructions
**How to install:** `pip install pymlir`
**Requirements:** Python 3.6 or newer, and the requirements in `setup.py` or `requirements.txt`. To manually install the
requirements, use `pip install -r requirements.txt`
**Problem parsing MLIR files?** Run the file through LLVM's `mlir-opt --mlir-print-op-generic` to get the generic form
of the IR (instructions on how to build/install MLIR can be found [here](https://mlir.llvm.org/getting_started/)):
```
$ mlir-opt file.mlir --mlir-print-op-generic > output.mlir
```
**Found other problems parsing files?** Not all dialects and modes are supported. Feel free to send us an issue or
create a pull request! This is a community project and we welcome any contribution.
## Usage examples
### Parsing MLIR files into Python
```python
import mlir
# Read a file path, file handle (stream), or a string
ast1 = mlir.parse_path('/path/to/file.mlir')
ast2 = mlir.parse_file(open('/path/to/file.mlir', 'r'))
ast3 = mlir.parse_string('''
module {
func @toy_func(%tensor: tensor<2x3xf64>) -> tensor<3x2xf64> {
%t_tensor = "toy.transpose"(%tensor) { inplace = true } : (tensor<2x3xf64>) -> tensor<3x2xf64>
return %t_tensor : tensor<3x2xf64>
}
}
''')
```
### Inspecting MLIR files in Python
MLIR files can be inspected by dumping their contents (which will print standard MLIR code), or by using the same tools
as you would with Python's [ast](https://docs.python.org/3/library/ast.html) module.
```python
import mlir
# Dump valid MLIR files
m = mlir.parse_path('/path/to/file.mlir')
print(m.dump())
print('---')
# Dump the AST directly
print(m.dump_ast())
print('---')
# Or visit each node type by implementing visitor functions
class MyVisitor(mlir.NodeVisitor):
def visit_Function(self, node: mlir.astnodes.Function):
print('Function detected:', node.name.value)
MyVisitor().visit(m)
```
### Transforming MLIR files
MLIR files can also be transformed with a Python-like
[NodeTransformer](https://docs.python.org/3/library/ast.html#ast.NodeTransformer) object.
```python
import mlir
m = mlir.parse_path('/path/to/file.mlir')
# Simple node transformer that removes all operations with a result
class RemoveAllResultOps(mlir.NodeTransformer):
def visit_Operation(self, node: mlir.astnodes.Operation):
# There are one or more outputs, return None to remove from AST
if len(node.result_list) > 0:
return None
# No outputs, no need to do anything
return self.generic_visit(node)
m = RemoveAllResultOps().visit(m)
# Write back to file
with open('output.mlir', 'w') as fp:
fp.write(m.dump())
```
### Using custom dialects
Custom dialects can be written and loaded as part of the pyMLIR parser. [See full tutorial here](doc/custom_dialect.rst).
```python
import mlir
from lark import UnexpectedCharacters
from .mydialect import dialect
# Try to parse as-is
try:
m = mlir.parse_path('/path/to/matrixfile.mlir')
except UnexpectedCharacters: # MyMatrix dialect not recognized
pass
# Add dialect to the parser
m = mlir.parse_path('/path/to/matrixfile.mlir',
dialects=[dialect])
# Print output back
print(m.dump_ast())
```
### Built-in dialect implementations and more examples
All dialect implementations can be found in the [dialects](mlir/dialects) subfolder. Additional uses
of the library, including a custom dialect implementation, can be found in the [tests](tests)
subfolder.

pymlir/doc/Makefile
@@ -1,20 +0,0 @@
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = .
BUILDDIR = _build
# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

pymlir/doc/conf.py
@@ -1,14 +0,0 @@
import os
import sys
sys.path.insert(0, os.path.abspath('..'))
project = 'pyMLIR'
copyright = '2020, Scalable Parallel Computing Laboratory, ETH Zurich'
author = 'Scalable Parallel Computing Laboratory, ETH Zurich'
release = '0.5'
extensions = ['sphinx.ext.autodoc']
templates_path = ['_templates']
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
html_theme = 'sphinx_rtd_theme'
html_static_path = ['_static']
master_doc = 'index'

pymlir/doc/custom_dialect.rst
@@ -1,158 +0,0 @@
Creating a Custom Dialect
=========================
One of MLIR's most powerful features is being able to define custom dialects. While the
opaque syntax is always supported by pyMLIR, parsing "pretty" definitions of custom dialects
is done by adding them to the ``dialects`` field of the MLIR parser, as in the following
snippet:
.. code-block:: python
import mlir
# Load dialect from Python file
import mydialect
# Add dialect to the parser
m = mlir.parse_path('/path/to/file.mlir', dialects=[mydialect.dialect])
A dialect is represented by a ``Dialect`` class, which is composed of custom types and
operations. In this document, we use ``toy`` as the dialect's name.
The structure of a dialect file is usually as follows::
# Imports
# Dialect type AST node classes
# Dialect operation AST node classes
# Dialect class definition
Simple Dialect Syntax API
-------------------------
To make dialect definition as simple as possible, pyMLIR provides a Syntax API based on
Python's ``str.format`` grammar. Defining a dialect type or operation using the Syntax API
is then performed as follows:
.. code-block:: python
from mlir.dialect import DialectOp, DialectType
class RaggedTensorType(DialectType):
""" AST node class for the example "toy" dialect representing a ragged tensor. """
_syntax_ = ('toy.ragged < {implementation.string_literal} , {dims.dimension_list_ranked} '
'{type.tensor_memref_element_type} >')
The syntax format parses any ``{name.type}`` token as an AST node field ``name`` with
type ``type``. The types that can be used either come from ``mlir.lark``, or from the
``preamble`` argument to the ``Dialect`` class (see below). Note the spaces between
tokens - they represent the fact that whitespace can be inserted between them.
pyMLIR will then detect the three fields (``implementation``, ``dims``, and ``type``) and
inject them into the AST node type. You can specify more than one match for your type
or operation, and if fields are not defined they will be set as ``None``. Example:
.. code-block:: python
class DensifyOp(DialectOp):
""" AST node for an operation with an optional value. """
_syntax_ = ['toy.densify {arg.ssa_id} : {type.tensor_type}',
'toy.densify {arg.ssa_id} , {pad.constant_literal} : {type.tensor_type}']
When dumping the code back to MLIR, pyMLIR remembers which match created the AST node and
will create the appropriate code.
Constructing the dialect itself follows creating the object with a unique dialect name, and
all the operations and types.
.. code-block:: python
from mlir.dialect import Dialect
from mlir import parse_path
# Define dialect
my_dialect = Dialect('toy', ops=[DensifyOp], types=[RaggedTensorType])
# Use dialect to parse file
module = parse_path('/path/to/toy_file.mlir', dialects=[my_dialect])
Advanced Dialect Behavior
-------------------------
In order to extend custom behavior in the dialect (e.g., to change how a node is read
or written), you can extend the ``DialectOp`` or ``DialectType`` classes.
In addition, there are two mechanisms that can be used in the ``Dialect`` class in order
to parse concepts beyond nodes for types and operations: ``preamble`` and ``transformers``.
Writing a new AST node has four implementation requirements:
1. Populating the ``_fields_`` static class member
2. Implementing an ``__init__`` function to parse Lark syntax trees
3. Implementing a ``dump`` function to output a string with the MLIR syntax
4. Either implementing a Lark rule in the ``Dialect`` preamble and mapping the rule
name to the class using the ``_rule_`` static class member, or defining the Lark
rules directly in the ``_lark_`` static class member
For example, if we wanted to be strict with how we dump the ``RaggedTensorType``, and use
our custom rule for parsing, we would implement the class in the following way:
.. code-block:: python
from mlir.dialect import DialectType
from mlir.astnodes import Node, dump_or_value
from lark import Tree
from typing import Union, List
class RaggedTensorType(DialectType):
_fields_ = ['implementation', 'dims', 'type']
# Notice that the first argument is optional
_lark_ = ['"toy.ragged" "<" (string_literal ",")? dimension_list_ranked '
'tensor_memref_element_type ">"']
def __init__(self, match: int, node: List[Union[Tree, Node]], **fields):
# Note that since _lark_ has only one element, "match" will always be 0
if len(node) == 2: # Only dims and type were defined
self.implementation = None
self.dims = node[0]
self.type = node[1]
elif len(node) == 3: # All three fields were defined
self.implementation = node[0]
self.dims = node[1]
self.type = node[2]
super().__init__(None, **fields)
def dump(self, indent: int = 0) -> str:
# Note the exclamation mark denoting a dialect type
result = '!toy.ragged<'
if self.implementation:
result += dump_or_value(self.implementation, indent)
result += '%sx%s>' % ('x'.join(dump_or_value(d, indent) for d in self.dims),
dump_or_value(self.type, indent))
return result
``dump_or_value`` is a helper function in ``mlir.astnodes`` to either write out the value,
a list/dict/tuple of values, or literals into MLIR format. For most cases, though, the
``_syntax_`` format will suffice (and creates shorter code than above).
As for extensions to the dialect itself, ``preamble`` and ``transformers`` are keyword
arguments that can be given to the ``Dialect`` class. The former allows arbitrary Lark
syntax to be parsed as part of the dialect, and the latter is a dictionary that maps
rule names to node-constructing callable functions/classes. This gives a custom dialect
full control over the syntax parsing and tree construction.
For example, we can create rules for a new kind of list structure in our toy dialect:
.. code-block:: python
my_dialect = Dialect('toy', ops=[DensifyOp], types=[RaggedTensorType],
preamble='''
// Exclamation mark in Lark means that string tokens will be preserved upon parsing
!toy_impl_type : "coo" | "csr" | "csc" | "ell"
toy_impl_list : toy_impl_type ("+" toy_impl_type)*
''',
transformers=dict(
toy_impl_list=list # Will construct a list from parsed values
))
Now we can parse lists of specific implementation types for our ragged tensor, e.g.,
``toy.ragged<coo+csr,32x14xf64>`` rather than one string literal. Note that
the type ``_lark_`` or ``_syntax_`` has to change accordingly.

pymlir/doc/index.rst
@@ -1,18 +0,0 @@
pyMLIR: Python Interface for MLIR
=================================
pyMLIR is a Python Interface for the Multi-Level Intermediate Representation (MLIR).
.. toctree::
:maxdepth: 2
custom_dialect
source/modules
Reference
=========
* :ref:`genindex`
* :ref:`modindex`

pymlir/mlir/__init__.py
@@ -1,3 +0,0 @@
from .parser import parse_file, parse_path, parse_string, Parser
from . import astnodes
from .visitors import NodeVisitor, NodeTransformer

File diff suppressed because it is too large

pymlir/mlir/dialect.py
@@ -1,243 +0,0 @@
""" MLIR Dialect representation. """
import inspect
from lark import Token
from mlir import astnodes
import parse
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
def _get_fields(syntax: str):
return [
tuple(name.split('.'))
for name in parse.compile(syntax)._name_types.keys()
]
class DialectElement(astnodes.Node):
"""
A class that can be extended by a dialect to define an MLIR AST node.
In the simple case, the subclass only needs to implement a list of syntax
rules (in ``_syntax_``) for parsing. See examples for use cases. In more
complicated nodes (i.e., with optional or variable-length parameters), a
new parsing (as ``__init__``) and dumping (``dump``) methods have to be
implemented, along with the node field names (stored in ``_fields_``).
"""
# A Python string-format syntax for matching this operation, using a
# "{name.type}" syntax. The types can either be provided in the dialect
# preamble, or using the definitions in "mlir.lark". If multiple formats
# may match, a list can be provided. For example:
# ['return', 'return {values.ssa_use_list} : {types.type_list_no_parens}']
# will implement the return operation in the Standard dialect.
_syntax_: Optional[Union[str, List[str]]] = None
# If custom behavior is defined through the dialect preamble, define rule
# name on this variable to match this class
_rule_: Optional[str] = None
# Internal fields to be filled by make_rules
_syntax_fields_: Optional[List[List[Tuple[str, str]]]] = None
_lark_: Optional[List[str]] = None
@classmethod
def make_rules(cls):
if cls._syntax_ is None:
return
# Set _fields_ according to _syntax_
if isinstance(cls._syntax_, str):
cls._syntax_ = [cls._syntax_]
if not isinstance(cls._syntax_, (list, tuple)):
raise ValueError('Invalid syntax expression (can only be a string '
'or a list of strings')
# Collect fields and create lark expressions
fields = set()
lark_exprs = []
compiled_fields = []
for syntax in cls._syntax_:
sfields = _get_fields(syntax)
compiled_fields.append(sfields)
if any(len(field) != 2 for field in sfields):
raise ValueError(
'Syntax matches must provide exactly one name '
'and one type')
fields |= set(f[0] for f in sfields)
# Create Lark expression
# Replace {{ and }}
syntax = syntax.replace('{{', '{LBRACE}').replace('}}', '{RBRACE}')
# Replace words with strings
syntax = ' '.join(
(('"%s"' % word) if not word.startswith('{') else word)
for word in syntax.split())
# Replace back {field.type} with types
for fname, ftype in sfields:
syntax = syntax.replace('{%s.%s}' % (fname, ftype), ftype)
# Replace back braces
syntax = syntax.replace('{LBRACE}', '{')
syntax = syntax.replace('{RBRACE}', '}')
lark_exprs.append(syntax)
cls._fields_ = list(fields)
cls._syntax_fields_ = compiled_fields
cls._lark_ = lark_exprs
def __init__(self, match: int, node: Token = None, **fields):
if self._syntax_ is None and node is not None:
raise NotImplementedError('Dialect element must either use '
'"_syntax_" or implement its own '
'constructor')
if node is None:
super().__init__(None, **fields)
return
# Get syntax expression
self._match = match
sfields = self._syntax_fields_[match]
other_fields = set(self._fields_) - set(f[0] for f in sfields)
# Set each field according to defined names
if node is not None and isinstance(node, list):
for fname, fval in zip(sfields, node):
setattr(self, fname[0], fval)
# Set other fields to None
for fname in other_fields:
setattr(self, fname, None)
super().__init__(None, **fields)
def dump(self, indent: int = 0) -> str:
if self._syntax_ is None:
raise NotImplementedError('Dialect element must either use '
'"_syntax_" or implement its own '
'"dump" method')
sfields = self._syntax_fields_[self._match]
dump_str = self._syntax_[self._match]
for fname, ftype in sfields:
dump_str = dump_str.replace('{%s.%s}' % (fname, ftype),
'{%s}' % fname)
return dump_str.format_map({
f[0]: astnodes.dump_or_value(getattr(self, f[0]), indent)
for f in sfields
})
class DialectOp(DialectElement):
""" A class that can be extended by a dialect to define an MLIR AST node
for an operation. See DialectElement for more details. """
pass
class DialectType(DialectElement):
""" A class that can be extended by a dialect to define an MLIR AST node
for a data type. See DialectElement for more details. """
def dump(self, indent: int = 0) -> str:
return '!' + super().dump(indent)
class Dialect(object):
def __init__(
self,
name: str,
ops: Optional[List[Type[DialectOp]]] = None,
types: Optional[List[Type[DialectType]]] = None,
preamble: Optional[str] = None,
transformers: Optional[Dict[str, Union[Callable, Type]]] = None):
"""
:param name: Dialect name (should be unique).
:param ops: A list of dialect AST nodes for operations.
:param types: A list of dialect AST nodes for types.
:param preamble: Preamble in Lark syntax for the dialect.
:param transformers: A dictionary that maps between rule names in the
Lark preamble to Python classes or AST node types.
"""
self.contents = preamble or ''
self.name = name
self.ops = ops or []
self.types = types or []
self.transformers = transformers or {}
# Make syntactic rules for each operation and type
for op in self.ops:
op.make_rules()
for typ in self.types:
typ.make_rules()
def add_dialect_rules(dialect: Dialect, elements: List[Type[DialectElement]],
typename: str, rule_dict: Dict[str, Callable]) -> str:
"""
Add dialect rules in Lark form to an MLIR parser.
:param dialect: The dialect object to use.
:param elements: A list of dialect elements (e.g. ops, types) to add.
:param typename: A prefix to add to the new element rules.
:param rule_dict: An existing rule dictionary (intended for a Lark
Transformer) to add the elements to.
:return: Lark source code containing new rules as necessary.
"""
parser_src = ''
for elem in elements:
if elem._rule_ is not None: # Custom rules defined in dialect
rule_dict[elem._rule_] = elem
continue
if elem._lark_ is None:
raise SyntaxError('Either a "_rule_" or "_syntax_" must '
'be defined for dialect element '
'%s' % elem.__name__)
# Fill contents with procedurally-generated rules
for i, rule in enumerate(elem._lark_):
rule_name = '%s_%s_%s_%d' % (dialect.name, typename,
elem.__name__.lower(), i)
parser_src += '%s: %s\n' % (rule_name, rule)
# Add rule to transformer
def create_rule(elem, i):
return lambda *value: elem(i, *value)
rule_dict[rule_name] = create_rule(elem, i)
return parser_src
def is_op(member: Any, module: str) -> bool:
""" Returns true if an object is a Dialect operation subclass. """
return (inspect.isclass(member) and issubclass(member, DialectOp)
and member.__module__ == module)
def is_type(member: Any, module: str) -> bool:
""" Returns true if an object is a Dialect type subclass. """
return (inspect.isclass(member) and issubclass(member, DialectType)
and member.__module__ == module)
#################################################################
# Helper classes for dialects
class UnaryOperation(DialectOp):
""" Helper class to create unary operations in dialects. """
_opname_ = 'UNDEF'
@classmethod
def make_rules(cls):
cls._syntax_ = '%s {operand.ssa_use} : {type.type}' % cls._opname_
super().make_rules()
class BinaryOperation(DialectOp):
""" Helper class to create binary operations in dialects. """
_opname_ = 'UNDEF'
@classmethod
def make_rules(cls):
cls._syntax_ = (
'%s {operand_a.ssa_use} , {operand_b.ssa_use} : {type.type}' %
cls._opname_)
super().make_rules()

pymlir/mlir/dialects/__init__.py
@@ -1,5 +0,0 @@
from .affine import affine
from .standard import standard
from .loop import loop
STANDARD_DIALECTS = [affine, standard, loop]

pymlir/mlir/dialects/affine.py
@@ -1,55 +0,0 @@
""" Implementation of the Affine dialect. """
import inspect
import sys
from mlir.dialect import Dialect, DialectOp, is_op
class AffineApplyOp(DialectOp):
_syntax_ = 'affine.apply {map.affine_map} {args.dim_and_symbol_use_list}'
class AffineForOp(DialectOp):
_syntax_ = [
'affine.for {index.ssa_id} = {begin.symbol_or_const} to {end.symbol_or_const} {body.region}',
'affine.for {index.ssa_id} = {begin.symbol_or_const} to {end.symbol_or_const} step {step.symbol_or_const} {body.region}',
'affine.for {index.ssa_id} = {begin.symbol_or_const} to {end.symbol_or_const} {body.region} {attributes.attribute_dict}',
'affine.for {index.ssa_id} = {begin.symbol_or_const} to {end.symbol_or_const} step {step.symbol_or_const} {body.region} {attributes.attribute_dict}'
]
class AffineIfOp(DialectOp):
_syntax_ = ['affine.if {cond.map_or_set_id} ( {operands.ssa_use_list} ) {body.region}',
'affine.if {cond.map_or_set_id} ( {operands.ssa_use_list} ) {body.region} else {elsebody.region}']
class AffineLoadOp(DialectOp):
_syntax_ = 'affine.load {arg.ssa_use} [ {index.multi_dim_affine_expr_no_parens} ] : {type.memref_type}'
class AffineStoreOp(DialectOp):
_syntax_ = 'affine.store {addr.ssa_use} , {ref.ssa_use} [ {index.multi_dim_affine_expr_no_parens} ] : {type.memref_type}'
class AffineMinOp(DialectOp):
_syntax_ = 'affine.min {map.affine_map_inline} {operands.dim_and_symbol_use_list}'
class AffinePrefetchOp(DialectOp):
_syntax_ = 'affine.prefetch {arg.ssa_use} [ {index.multi_dim_affine_expr_no_parens} ] , {specifier.bare_id} , locality < {locality.integer_literal} > , {cachetype.bare_id} : {type.type}'
class AffineDmaStartOperation(DialectOp):
_syntax_ = [
'affine.dma_start {src.ssa_use} [ {src_index.multi_dim_affine_expr_no_parens} ] , {dst.ssa_use} [ {dst_index.multi_dim_affine_expr_no_parens} ] , {tag.ssa_use} [ {tag_index.multi_dim_affine_expr_no_parens} ] , {size.ssa_use} : {src_type.memref_type} , {dst_type.memref_type} , {tag_type.memref_type}',
'affine.dma_start {src.ssa_use} [ {src_index.multi_dim_affine_expr_no_parens} ] , {dst.ssa_use} [ {dst_index.multi_dim_affine_expr_no_parens} ] , {tag.ssa_use} [ {tag_index.multi_dim_affine_expr_no_parens} ] , {size.ssa_use} , {stride.ssa_use} , {transfer_per_stride.ssa_use} : {src_type.memref_type} , {dst_type.memref_type} , {tag_type.memref_type}'
]
class AffineDmaWaitOperation(DialectOp):
_syntax_ = 'affine.dma_wait {tag.ssa_use} [ {tag_index.multi_dim_affine_expr_no_parens} ] , {size.ssa_use} : {type.memref_type}'
# Inspect current module to get all classes defined above
affine = Dialect('affine', ops=[m[1] for m in inspect.getmembers(
sys.modules[__name__], lambda obj: is_op(obj, __name__))])

pymlir/mlir/dialects/loop.py
@@ -1,20 +0,0 @@
""" Implementation of the Loop dialect. """
import inspect
import sys
from mlir.dialect import Dialect, DialectOp, is_op
class LoopForOp(DialectOp):
_syntax_ = ['loop.for {index.ssa_id} = {begin.ssa_id} to {end.ssa_id} {body.region}',
'loop.for {index.ssa_id} = {begin.ssa_id} to {end.ssa_id} step {step.ssa_id} {body.region}']
class LoopIfOp(DialectOp):
_syntax_ = ['loop.if {cond.ssa_id} {body.region}',
'loop.if {cond.ssa_id} {body.region} else {elsebody.region}']
# Inspect current module to get all classes defined above
loop = Dialect('loop', ops=[m[1] for m in inspect.getmembers(
sys.modules[__name__], lambda obj: is_op(obj, __name__))])

pymlir/mlir/dialects/standard.py
@@ -1,98 +0,0 @@
""" Implementation of the Standard dialect. """
import inspect
import sys
from mlir.dialect import (Dialect, DialectOp, UnaryOperation, BinaryOperation,
is_op)
# Terminator Operations
class BrOperation(DialectOp):
_syntax_ = ['br {block.block_id}',
'br {block.block_id} {args.block_arg_list}']
class CondBrOperation(DialectOp):
_syntax_ = ['cond_br {cond.ssa_use} , {block_true.block_id} , {block_false.block_id}']
class ReturnOperation(DialectOp):
_syntax_ = ['return',
'return {values.ssa_use_list} : {types.type_list_no_parens}']
# Core Operations
class CallOperation(DialectOp):
_syntax_ = ['call {func.symbol_ref_id} () : {type.function_type}',
'call {func.symbol_ref_id} ( {args.ssa_use_list} ) : {type.function_type}']
class CallIndirectOperation(DialectOp):
_syntax_ = ['call_indirect {func.symbol_ref_id} () : {type.function_type}',
'call_indirect {func.symbol_ref_id} ( {args.ssa_use_list} ) : {type.function_type}']
class DimOperation(DialectOp):
_syntax_ = 'dim {operand.ssa_id} , {index.integer_literal} : {type.type}'
# Memory Operations
class AllocOperation(DialectOp):
_syntax_ = 'alloc {args.dim_and_symbol_use_list} : {type.memref_type}'
class AllocStaticOperation(DialectOp):
_syntax_ = 'alloc_static ( {base.integer_literal} ) : {type.memref_type}'
class DeallocOperation(DialectOp):
_syntax_ = 'dealloc {arg.ssa_use} : {type.memref_type}'
class DmaStartOperation(DialectOp):
_syntax_ = [
'dma_start {src.ssa_use} [ {src_index.ssa_use_list} ] , {dst.ssa_use} [ {dst_index.ssa_use_list} ] , {size.ssa_use} , {tag.ssa_use} [ {tag_index.ssa_use_list} ] : {src_type.memref_type} , {dst_type.memref_type} , {tag_type.memref_type}',
'dma_start {src.ssa_use} [ {src_index.ssa_use_list} ] , {dst.ssa_use} [ {dst_index.ssa_use_list} ] , {size.ssa_use} , {tag.ssa_use} [ {tag_index.ssa_use_list} ] , {stride.ssa_use} , {transfer_per_stride.ssa_use} : {src_type.memref_type} , {dst_type.memref_type} , {tag_type.memref_type}'
]
class DmaWaitOperation(DialectOp):
_syntax_ = 'dma_wait {tag.ssa_use} [ {tag_index.ssa_use_list} ] , {size.ssa_use} : {type.memref_type}'
class ExtractElementOperation(DialectOp):
_syntax_ = 'extract_element {arg.ssa_use} [ {index.ssa_use_list} ] : {type.type}'
class LoadOperation(DialectOp):
_syntax_ = 'load {arg.ssa_use} [ {index.ssa_use_list} ] : {type.memref_type}'
class SplatOperation(DialectOp):
_syntax_ = 'splat {arg.ssa_use} : {type.type}' # (vector_type | tensor_type)
class StoreOperation(DialectOp):
_syntax_ = 'store {addr.ssa_use} , {ref.ssa_use} [ {index.ssa_use_list} ] : {type.memref_type}'
class TensorLoadOperation(DialectOp):
_syntax_ = 'tensor_load {arg.ssa_use} : {type.type}'
class TensorLoadOperation(DialectOp):
_syntax_ = 'tensor_store {src.ssa_use} , {dst.ssa_use} : {type.memref_type}'
# Unary Operations
class AbsfOperation(UnaryOperation): _opname_ = 'absf'
class CeilfOperation(UnaryOperation): _opname_ = 'ceilf'
class CosOperation(UnaryOperation): _opname_ = 'cos'
class ExpOperation(UnaryOperation): _opname_ = 'exp'
class NegfOperation(UnaryOperation): _opname_ = 'negf'
class TanhOperation(UnaryOperation): _opname_ = 'tanh'
class CopysignOperation(UnaryOperation): _opname_ = 'copysign'
# Arithmetic Operations
class AddiOperation(BinaryOperation): _opname_ = 'addi'
class AddfOperation(BinaryOperation): _opname_ = 'addf'
class AndOperation(BinaryOperation): _opname_ = 'and'
class DivisOperation(BinaryOperation): _opname_ = 'divis'
class DiviuOperation(BinaryOperation): _opname_ = 'diviu'
class RemisOperation(BinaryOperation): _opname_ = 'remis'
class RemiuOperation(BinaryOperation): _opname_ = 'remiu'
class DivfOperation(BinaryOperation): _opname_ = 'divf'
class MulfOperation(BinaryOperation): _opname_ = 'mulf'
class SubiOperation(BinaryOperation): _opname_ = 'subi'
class SubfOperation(BinaryOperation): _opname_ = 'subf'
class OrOperation(BinaryOperation): _opname_ = 'or'
class XorOperation(BinaryOperation): _opname_ = 'xor'
class CmpiOperation(DialectOp):
_syntax_ = 'cmpi {comptype.string_literal} , {operand_a.ssa_id} , {operand_b.ssa_id} : {type.type}'
class CmpfOperation(DialectOp):
_syntax_ = 'cmpf {comptype.string_literal} , {operand_a.ssa_id} , {operand_b.ssa_id} : {type.type}'
class ConstantOperation(DialectOp):
_syntax_ = 'constant {value.attribute_value} : {type.type}'
class MemrefCastOperation(DialectOp):
_syntax_ = 'memref_cast {arg.ssa_use} : {src_type.type} to {dst_type.type}'
class TensorCastOperation(DialectOp):
_syntax_ = 'tensor_cast {arg.ssa_use} : {src_type.type} to {dst_type.type}'
class SelectOperation(DialectOp):
_syntax_ = 'select {cond.ssa_use} , {arg_true.ssa_use} , {arg_false.ssa_use} : {type.type}'
class SubviewOperation(DialectOp):
_syntax_ = 'subview {operand.ssa_use} [ {offsets.ssa_use_list} ] [ {sizes.ssa_use_list} ] [ {strides.ssa_use_list} ] : {src_type.type} to {dst_type.type}'
# Inspect current module to get all classes defined above
standard = Dialect('standard', ops=[m[1] for m in inspect.getmembers(
sys.modules[__name__], lambda obj: is_op(obj, __name__))])

pymlir/mlir/mlir.lark
@@ -1,289 +0,0 @@
// Adapted from https://github.com/llvm/llvm-project/blob/5b4a01d4a63cb66ab981e52548f940813393bf42/mlir/docs/LangRef.md
// ----------------------------------------------------------------------
// Low-level literal syntax
digit : /[0-9]/
digits : /[0-9]+/
hex_digit : /[0-9a-fA-F]/
hex_digits : /[0-9a-fA-F]+/
letter : /[a-zA-Z]/
letters : /[a-zA-Z]+/
id_punct : /[$._-]/
underscore : /[_]/
true : "true"
false : "false"
id_chars: /[$.]/
bool_literal : true | false
decimal_literal : digits
hexadecimal_literal : "0x" hex_digits
integer_literal : decimal_literal | hexadecimal_literal
negated_integer_literal : "-" integer_literal
?posneg_integer_literal : integer_literal | negated_integer_literal
float_literal : /[-+]?[0-9]+[.][0-9]*([eE][-+]?[0-9]+)?/
string_literal : ESCAPED_STRING
constant_literal : bool_literal | integer_literal | float_literal | string_literal
// Identifier syntax
bare_id : (letter| underscore) (letter|digit|underscore|id_chars)*
suffix_id : digits | bare_id
// Dimensions
dimension : "?" | decimal_literal
static_dimension_list : (decimal_literal "x")+
dimension_list_ranked : (dimension "x")*
dimension_list_unranked : "*" "x"
dimension_list : dimension_list_ranked | dimension_list_unranked
// ----------------------------------------------------------------------
// Identifiers
ssa_id : "%" suffix_id ("#" digits)?
symbol_ref_id : "@" (suffix_id | string_literal)
block_id : "^" suffix_id
type_alias : "!" (string_literal | bare_id)
map_or_set_id : "#" suffix_id
attribute_alias : "#" (string_literal | bare_id)
ssa_id_list : ssa_id ("," ssa_id)*
// Uses of an SSA value, e.g., in an operand list to an operation.
ssa_use : ssa_id | constant_literal
ssa_use_list : ssa_use ("," ssa_use)*
// ----------------------------------------------------------------------
// Types
// Standard types
none_type : "none"
!float_type : "f16" | "bf16" | "f32" | "f64"
index_type : "index"
integer_type : "i" /[1-9][0-9]*/ // Sized integers like i1, i4, i8, i16, i32.
complex_type : "complex" "<" type ">"
tuple_type : "tuple" "<" (type ( "," type)*) ">"
// Vector types
vector_element_type : float_type | integer_type
vector_type : "vector" "<" static_dimension_list vector_element_type ">"
// Tensor type
tensor_memref_element_type : vector_element_type | vector_type | complex_type | type_alias
ranked_tensor_type : "tensor" "<" dimension_list_ranked tensor_memref_element_type ">"
unranked_tensor_type : "tensor" "<" dimension_list_unranked tensor_memref_element_type ">"
tensor_type : ranked_tensor_type | unranked_tensor_type
// Memref type
stride_list : "[" (dimension ("," dimension)*)? "]"
strided_layout : "offset:" dimension "," "strides: " stride_list
layout_specification : semi_affine_map | strided_layout
memory_space : integer_literal // | TODO(mlir): address_space_id
ranked_memref_type : "memref" "<" dimension_list_ranked tensor_memref_element_type (("," layout_specification)? | ("," memory_space)?) ">"
unranked_memref_type : "memref" "<*x" tensor_memref_element_type ("," memory_space)? ">"
memref_type : ranked_memref_type | unranked_memref_type
// Dialect types - these can be opaque, pretty, or using custom dialects
opaque_dialect_item : bare_id "<" string_literal ">"
pretty_dialect_item : bare_id "." bare_id pretty_dialect_item_body?
pretty_dialect_item_body : "<" pretty_dialect_item_contents ("," pretty_dialect_item_contents)* ">"
?pretty_dialect_item_contents : ("(" pretty_dialect_item_contents ")")
| ("[" pretty_dialect_item_contents "]")
| ("{" pretty_dialect_item_contents "}")
| bare_id
| constant_literal
| type
// NOTE: "pymlir_dialect_types" is defined externally by pyMLIR
dialect_type : "!" (pymlir_dialect_types | opaque_dialect_item | pretty_dialect_item)
// Final type definition
standard_type : complex_type | float_type | function_type | index_type | integer_type | memref_type | none_type | tensor_type | tuple_type | vector_type
non_function_type : type_alias | complex_type | float_type | index_type | integer_type | memref_type | none_type | tensor_type | tuple_type | vector_type | dialect_type
type : type_alias | dialect_type | standard_type
// Uses of types
type_list_no_parens : type ("," type)*
type_list_parens : ("(" ")") | ("(" type_list_no_parens ")")
function_result_type : type_list_parens | type_list_no_parens | non_function_type
function_type : function_result_type ("->" | "to" | "into") function_result_type
ssa_use_and_type : ssa_use ":" type
ssa_use_and_type_list : ssa_use_and_type ("," ssa_use_and_type)*
// ----------------------------------------------------------------------
// Attributes
// Simple attribute types
array_attribute : "[" (attribute_value ("," attribute_value)*)? "]"
bool_attribute : bool_literal
dictionary_attribute : "{" (attribute_entry ("," attribute_entry)*)? "}"
?elements_attribute : dense_elements_attribute | opaque_elements_attribute | sparse_elements_attribute
float_attribute : (float_literal (":" float_type)?) | (hexadecimal_literal ":" float_type)
integer_attribute : posneg_integer_literal ( ":" (index_type | integer_type) )?
integer_set_attribute : affine_map
string_attribute : string_literal (":" type)?
symbol_ref_attribute : symbol_ref_id ("::" symbol_ref_id)*
type_attribute : type
unit_attribute : "unit"
// Elements attribute types
dense_elements_attribute : "dense" "<" attribute_value ">" ":" ( tensor_type | vector_type )
opaque_elements_attribute : "opaque" "<" bare_id "," hexadecimal_literal ">" ":" ( tensor_type | vector_type )
sparse_elements_attribute : "sparse" "<" attribute_value "," attribute_value ">" ":" ( tensor_type | vector_type )
// Standard attributes
standard_attribute : array_attribute | bool_attribute | dictionary_attribute | elements_attribute | float_attribute | integer_attribute | integer_set_attribute | string_attribute | symbol_ref_attribute | type_attribute | unit_attribute
// Attribute values
attribute_value : attribute_alias | dialect_attribute | standard_attribute
dependent_attribute_entry : bare_id "=" attribute_value
dialect_attribute_entry : (bare_id "." bare_id) | (bare_id "." bare_id "=" attribute_value)
// Dialect attributes
// NOTE: "pymlir_dialect_types" is defined externally by pyMLIR
dialect_attribute : "#" (pymlir_dialect_types | opaque_dialect_item | pretty_dialect_item)
// Attribute dictionaries
attribute_entry : dialect_attribute_entry | dependent_attribute_entry
attribute_dict : ("{" "}") | ("{" attribute_entry ("," attribute_entry)* "}")
// ----------------------------------------------------------------------
// Operations
// Types that appear after the operation, indicating return types
trailing_type : ":" (function_type | function_result_type)
// Operation results
op_result : ssa_id (":" integer_literal)?
op_result_list : op_result ("," op_result)* "="
// Trailing location (for debug information)
location : string_literal ":" decimal_literal ":" decimal_literal
trailing_location : ("loc" "(" location ")")
// Undefined operations in all dialects
generic_operation : string_literal "(" ssa_use_list? ")" attribute_dict? trailing_type
custom_operation : bare_id "." bare_id ssa_use_list? trailing_type
// Final operation definition
// NOTE: "pymlir_dialect_ops" is defined externally by pyMLIR
operation : op_result_list? (pymlir_dialect_ops | custom_operation | generic_operation) trailing_location?
// ----------------------------------------------------------------------
// Blocks and regions
// Block arguments
ssa_id_and_type : ssa_id ":" type
ssa_id_and_type_list : ssa_id_and_type ("," ssa_id_and_type)*
block_arg_list : "(" ssa_id_and_type_list? ")"
block_label : block_id block_arg_list? ":"
block : block_label* operation+
region : "{" block* "}"
// ----------------------------------------------------------------------
// Modules and functions
// Arguments
named_argument : ssa_id ":" type attribute_dict?
argument_list : (named_argument ("," named_argument)*) | (type attribute_dict? ("," type attribute_dict?)*)
function_signature : symbol_ref_id "(" argument_list? ")" ("->" function_result_list)?
// Return values
function_result : type attribute_dict?
function_result_list_no_parens : function_result ("," function_result)*
function_result_list_parens : ("(" ")") | ("(" function_result_list_no_parens ")")
function_result_list : function_result_list_parens | non_function_type
// Body
module_body : "{" (function | module | block)* "}"
?function_body : region
// Definition
module : "module" symbol_ref_id? ("attributes" attribute_dict)? module_body trailing_location?
function : "func" function_signature ("attributes" attribute_dict)? function_body? trailing_location?
// ----------------------------------------------------------------------
// (semi-)affine expressions, maps, and integer sets
dim_id_list : "(" bare_id? ("," bare_id)* ")"
symbol_id_list: "[" bare_id? ("," bare_id)* "]"
dim_and_symbol_id_lists : dim_id_list symbol_id_list?
?symbol_or_const : posneg_integer_literal | ssa_id | bare_id
?dim_use_list : "(" ssa_use_list? ")"
?symbol_use_list : "[" ssa_use_list? "]"
dim_and_symbol_use_list : dim_use_list symbol_use_list?
affine_expr : "(" affine_expr ")" -> affine_parens
| affine_expr "+" affine_expr -> affine_add
| affine_expr "-" affine_expr -> affine_sub
| posneg_integer_literal "*" affine_expr -> affine_mul
| affine_expr "*" posneg_integer_literal -> affine_mul
| affine_expr "&ceildiv&" integer_literal -> affine_ceildiv
| affine_expr "&floordiv&" integer_literal -> affine_floordiv
| affine_expr "&mod&" integer_literal -> affine_mod
| "-" affine_expr -> affine_neg
| "symbol" "(" ssa_id ")" -> affine_symbol_explicit
| posneg_integer_literal -> affine_literal
| ssa_id -> affine_ssa
| bare_id -> affine_symbol
semi_affine_expr : "(" semi_affine_expr ")" -> semi_affine_parens
| semi_affine_expr "+" semi_affine_expr -> semi_affine_add
| semi_affine_expr "-" semi_affine_expr -> semi_affine_sub
| symbol_or_const "*" semi_affine_expr -> semi_affine_mul
| semi_affine_expr "*" symbol_or_const -> semi_affine_mul
| semi_affine_expr "&ceildiv&" semi_affine_oprnd -> semi_affine_ceildiv
| semi_affine_expr "&floordiv&" semi_affine_oprnd -> semi_affine_floordiv
| semi_affine_expr "&mod&" semi_affine_oprnd -> semi_affine_mod
| "symbol" "(" symbol_or_const ")" -> semi_affine_symbol_explicit
| symbol_or_const -> semi_affine_symbol
// Second operand for floordiv/ceildiv/mod in semi-affine expressions
?semi_affine_oprnd : symbol_or_const
| "(" semi_affine_expr ")" -> semi_affine_parens
multi_dim_affine_expr_no_parens : affine_expr ("," affine_expr)*
multi_dim_affine_expr : "(" multi_dim_affine_expr_no_parens ")"
multi_dim_semi_affine_expr : "(" semi_affine_expr ("," semi_affine_expr)* ")"
affine_constraint : affine_expr ">=" "0" -> affine_constraint_ge
| affine_expr "==" "0" -> affine_constraint_eq
affine_constraint_conjunction : affine_constraint ("," affine_constraint)*
affine_map_inline : dim_and_symbol_id_lists "->" multi_dim_affine_expr
semi_affine_map_inline : dim_and_symbol_id_lists "->" multi_dim_semi_affine_expr
integer_set_inline : dim_and_symbol_id_lists ":" "(" affine_constraint_conjunction? ")"
// Definition of maps and sets
affine_map : map_or_set_id | affine_map_inline
semi_affine_map : map_or_set_id | semi_affine_map_inline
integer_set : map_or_set_id | integer_set_inline
// ----------------------------------------------------------------------
// General structure and top-level definitions
// Definitions of affine maps/integer sets/aliases are at the top of the file
type_alias_def : type_alias "=" "type" type
affine_map_def : map_or_set_id "=" affine_map_inline
semi_affine_map_def : map_or_set_id "=" semi_affine_map_inline
integer_set_def : map_or_set_id "=" integer_set_inline
attribute_alias_def : attribute_alias "=" attribute_value
?definition : type_alias_def | affine_map_def | semi_affine_map_def | integer_set_def | attribute_alias_def
?start : value*
?value : module
| definition
| function
// Lark imports
%import common.ESCAPED_STRING
%import common.SIGNED_NUMBER
%import common.WS
%import common.NEWLINE
// Things to ignore: whitespace, single-line comments
%ignore WS
COMMENT : "//" /(.)+/ NEWLINE
%ignore COMMENT
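A minimal standalone sketch (illustrative, not part of this commit) of loading the grammar above with Lark. pyMLIR normally appends generated dialect rules, so the two externally-defined rules are stubbed out here with hypothetical aliases to keep the grammar closed:

from lark import Lark

# Assumes the grammar above is saved as 'mlir.lark' in the working directory.
with open('mlir.lark', 'r') as fp:
    grammar = fp.read()

# Hypothetical stubs for the two externally-defined rules:
grammar += '\n?pymlir_dialect_ops : custom_operation'
grammar += '\n?pymlir_dialect_types : opaque_dialect_item'

parser = Lark(grammar, parser='earley')
print(parser.parse('module {}').pretty())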

@@ -1,180 +0,0 @@
""" Contains classes that parse MLIR files """
import itertools
from lark import Lark, Tree
import os
import sys
from typing import List, Optional, TextIO
import runpy
from mlir.parser_transformer import TreeToMlir
from mlir.dialect import Dialect, add_dialect_rules
from mlir.dialects import STANDARD_DIALECTS
from mlir import astnodes as mast
class Parser(object):
"""
A reusable pyMLIR parser. Parses multiple strings faster than repeatedly
calling ``mlir.parse_*``.
"""
def __init__(self, dialects: Optional[List[Dialect]] = None):
"""
Initializes a reusable pyMLIR parser.
:param dialects: An optional list of additional dialects to load (in
addition to the built-in dialects).
"""
self.dialects = dialects or []
# Lazy-load (if necessary) the Lark files
_lazy_load()
# Initialize EBNF source for parser
parser_src = _MLIR_LARK + '\n'
# Check validity of given dialects
dialects = dialects or []
builtin_names = [dialect.name for dialect in STANDARD_DIALECTS]
additional_names = [dialect.name for dialect in dialects]
dialect_set = set(builtin_names) | set(additional_names)
if len(dialect_set) != (len(dialects) + len(STANDARD_DIALECTS)):
raise NameError(
'Additional dialect already exists (built-in dialects: %s, '
'given dialects: %s)' % (builtin_names, additional_names))
# Add dialect contents to parser
rule_dict_ops = {}
rule_dict_types = {}
for dialect in itertools.chain(STANDARD_DIALECTS, dialects):
# Add preamble for dialect
parser_src += dialect.contents
# Add rules for operations and types
parser_src += add_dialect_rules(dialect, dialect.ops, 'op',
rule_dict_ops)
parser_src += add_dialect_rules(dialect, dialect.types, 'type',
rule_dict_types)
# Create a parser from the MLIR EBNF file, the default dialects, and any
# additional dialects, if provided
op_expr = '?pymlir_dialect_ops: ' + '|'.join(rule_dict_ops.keys())
type_expr = '?pymlir_dialect_types: ' + '|'.join(
rule_dict_types.keys())
parser_src += op_expr + '\n' + type_expr
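# For example (illustrative, with hypothetical rule names), the appended
# lines look like:
#   ?pymlir_dialect_ops: affine_dialect_op|std_dialect_op
#   ?pymlir_dialect_types: std_dialect_type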
# Create parser and tree transformer
self.parser = Lark(parser_src, parser='earley')
self.transformer = TreeToMlir()
# Add dialect rules to transformer
for rule_name, ctor in itertools.chain(rule_dict_ops.items(),
rule_dict_types.items()):
setattr(self.transformer, rule_name, ctor)
for dialect in itertools.chain(STANDARD_DIALECTS, dialects):
for rule_name, rule in dialect.transformers.items():
setattr(self.transformer, rule_name, rule)
def parse(self, code: str) -> mast.Module:
"""
Parses a string representing code in MLIR, returning the top-level
AST node.
:param code: A code string in MLIR format.
:return: A module node representing the root of the AST.
"""
# Pre-transform the code to work around parsing issues with
# ceildiv/floordiv/mod: since whitespace is ignored, an expression such as
# "d0 floordiv s0" could otherwise be read as one legal symbol, "d0floordivs0"
code = code.replace(' floordiv ', '&floordiv&')
code = code.replace(' ceildiv ', '&ceildiv&')
code = code.replace(' mod ', '&mod&')
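# For example, '(d0 + s0) floordiv 3' becomes '(d0 + s0)&floordiv&3', which
# matches the grammar's "&floordiv&" terminal unambiguously.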
# Parse the code using Lark
tree = self.parser.parse(code)
# Transform the tree to our AST node classes
root_node = self.transformer.transform(tree)
# If the root node is a function/definition or a list thereof, return
# a top-level module
if not isinstance(root_node, mast.Module):
    if isinstance(root_node, Tree) and root_node.data == 'start':
        # Multiple top-level values: wrap their list in a module
        return mast.Module(root_node.children)
    # A single top-level value: wrap it in a one-element list
    return mast.Module([root_node])
return root_node
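# Usage sketch (illustrative, not in the original file): building the Earley
# parser in __init__ is the expensive step, so share one Parser instance
# across many parse() calls.
def _example_reuse():
    p = Parser()
    first = p.parse('module {}')
    second = p.parse('module @named {}')
    return first, second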
# Load the MLIR EBNF syntax to memory once
_MLIR_LARK = None
def _lazy_load():
"""
Loads the Lark EBNF files (MLIR and default dialects) into memory upon
first use.
"""
global _MLIR_LARK
# Lazily load the MLIR EBNF file and the dialects
if _MLIR_LARK is None:
# Find path to files
mlir_path = os.path.join(
os.path.abspath(os.path.dirname(__file__)), 'lark')
with open(os.path.join(mlir_path, 'mlir.lark'), 'r') as fp:
_MLIR_LARK = fp.read()
def parse_string(code: str,
dialects: Optional[List[Dialect]] = None) -> mast.Module:
"""
Parses a string representing code in MLIR, returning the top-level AST node.
:param code: A code string in MLIR format.
:param dialects: An optional list of additional dialects to load (in
addition to the built-in dialects).
:return: A module node representing the root of the AST.
"""
parser = Parser(dialects)
return parser.parse(code)
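# Sketch: parse_string('module {}') returns an astnodes.Module describing an
# empty top-level module.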
def parse_file(file: TextIO,
dialects: Optional[List[Dialect]] = None) -> mast.Node:
"""
Parses an MLIR file from a given Python file-like object, returning the
top-level AST node.
:param file: Python file-like I/O object in text mode.
:param dialects: An optional list of additional dialects to load (in
addition to the built-in dialects).
:return: A module node representing the root of the AST.
"""
return parse_string(file.read(), dialects)
def parse_path(file_path: str,
dialects: Optional[List[Dialect]] = None) -> mast.Node:
"""
Parses an MLIR file from a given filename, returning the top-level AST node.
:param file_path: Path to file to parse.
:param dialects: An optional list of additional dialects to load (in
addition to the built-in dialects).
:return: A module node representing the root of the AST.
"""
with open(file_path, 'r') as fp:
return parse_file(fp, dialects)
if __name__ == '__main__':
if len(sys.argv) < 2:
print('USAGE: python -m mlir.parser <MLIR FILE> [DIALECT PATHS...]')
exit(1)
additional_dialects = []
for dialect_path in sys.argv[2:]:
# Load Python file with dialect
global_vars = runpy.run_path(dialect_path)
additional_dialects.extend(
v for v in global_vars.values() if isinstance(v, Dialect))
print(parse_path(sys.argv[1], dialects=additional_dialects).pretty())

@@ -1,223 +0,0 @@
from lark import v_args, Transformer
from mlir import astnodes
class TreeToMlir(Transformer):
###############################################################
# Low-level literal syntax
digit = lambda self, val: int(val[0])
digits = lambda self, val: int(val[0])
hex_digit = lambda self, val: str(val[0])
hex_digits = lambda self, val: str(val[0])
letter = lambda self, val: str(val[0])
letters = lambda self, val: str(val[0])
id_punct = lambda self, val: str(val[0])
underscore = lambda self, val: str(val[0])
true = lambda self, _: True
false = lambda self, _: False
id_chars = lambda self, val: str(val[0])
dimension = astnodes.Dimension
# Literals
@v_args(inline=True)
def decimal_literal(self, *digits):
return int(''.join(str(d) for d in digits))
@v_args(inline=True)
def hexadecimal_literal(self, *digits):
return '0x' + ''.join(digits)
negated_integer_literal = lambda self, value: -value[0]
float_literal = lambda self, value: float(value[0])
@v_args(inline=True)
def string_literal(self, s):
return astnodes.StringLiteral(s[1:-1].replace('\\"', '"'))
@v_args(inline=True)
def bare_id(self, *elements):
return ''.join(str(s) for s in elements)
@v_args(inline=True)
def suffix_id(self, *suffix):
return ''.join(str(s) for s in suffix)
###############################################################
# MLIR Identifiers
ssa_id = astnodes.SsaId
symbol_ref_id = astnodes.SymbolRefId
block_id = astnodes.BlockId
type_alias = astnodes.TypeAlias
attribute_alias = astnodes.AttrAlias
map_or_set_id = astnodes.MapOrSetId
###############################################################
# MLIR Types
none_type = astnodes.NoneType
f16 = lambda self, _: "f16"
bf16 = lambda self, _: "bf16"
f32 = lambda self, _: "f32"
f64 = lambda self, _: "f64"
float_type = astnodes.FloatType
index_type = astnodes.IndexType
integer_type = astnodes.IntegerType
complex_type = astnodes.ComplexType
tuple_type = astnodes.TupleType
vector_type = astnodes.VectorType
ranked_tensor_type = astnodes.RankedTensorType
unranked_tensor_type = astnodes.UnrankedTensorType
ranked_memref_type = astnodes.RankedMemRefType
unranked_memref_type = astnodes.UnrankedMemRefType
opaque_dialect_item = astnodes.OpaqueDialectType
pretty_dialect_item = astnodes.PrettyDialectType
function_type = astnodes.FunctionType
strided_layout = astnodes.StridedLayout
###############################################################
# MLIR Attributes
array_attribute = astnodes.ArrayAttr
bool_attribute = astnodes.BoolAttr
dictionary_attribute = astnodes.DictionaryAttr
dense_elements_attribute = astnodes.DenseElementsAttr
opaque_elements_attribute = astnodes.OpaqueElementsAttr
sparse_elements_attribute = astnodes.SparseElementsAttr
float_attribute = astnodes.FloatAttr
integer_attribute = astnodes.IntegerAttr
integer_set_attribute = astnodes.IntSetAttr
string_attribute = astnodes.StringAttr
symbol_ref_attribute = astnodes.SymbolRefAttr
type_attribute = astnodes.TypeAttr
unit_attribute = astnodes.UnitAttr
dependent_attribute_entry = astnodes.AttributeEntry
dialect_attribute_entry = astnodes.DialectAttributeEntry
attribute_dict = astnodes.AttributeDict
###############################################################
# Operations
op_result = astnodes.OpResult
location = astnodes.FileLineColLoc
operation = astnodes.Operation
generic_operation = astnodes.GenericOperation
custom_operation = astnodes.CustomOperation
###############################################################
# Blocks, regions, modules, functions
block_label = astnodes.BlockLabel
block = astnodes.Block
region = astnodes.Region
module = astnodes.Module
function = astnodes.Function
named_argument = astnodes.NamedArgument
###############################################################
# (semi-)Affine expressions, maps, and integer sets
dim_and_symbol_id_lists = astnodes.DimAndSymbolList
dim_and_symbol_use_list = astnodes.DimAndSymbolList
affine_expr = astnodes.AffineExpr
semi_affine_expr = astnodes.SemiAffineExpr
multi_dim_affine_expr = astnodes.MultiDimAffineExpr
multi_dim_semi_affine_expr = astnodes.MultiDimSemiAffineExpr
affine_constraint_ge = astnodes.AffineConstraintGreaterEqual
affine_constraint_eq = astnodes.AffineConstraintEqual
affine_map_inline = astnodes.AffineMap
semi_affine_map_inline = astnodes.SemiAffineMap
integer_set_inline = astnodes.IntSet
affine_neg = astnodes.AffineNeg
semi_affine_neg = astnodes.AffineNeg
affine_parens = astnodes.AffineParens
semi_affine_parens = astnodes.AffineParens
affine_symbol_explicit = astnodes.AffineExplicitSymbol
semi_affine_symbol_explicit = astnodes.AffineExplicitSymbol
affine_add = astnodes.AffineAdd
semi_affine_add = astnodes.AffineAdd
affine_sub = astnodes.AffineSub
semi_affine_sub = astnodes.AffineSub
affine_mul = astnodes.AffineMul
semi_affine_mul = astnodes.AffineMul
affine_floordiv = astnodes.AffineFloorDiv
semi_affine_floordiv = astnodes.AffineFloorDiv
affine_ceildiv = astnodes.AffineCeilDiv
semi_affine_ceildiv = astnodes.AffineCeilDiv
affine_mod = astnodes.AffineMod
semi_affine_mod = astnodes.AffineMod
###############################################################
# Top-level definitions
type_alias_def = astnodes.TypeAliasDef
affine_map_def = astnodes.AffineMapDef
semi_affine_map_def = astnodes.SemiAffineMapDef
integer_set_def = astnodes.IntSetDef
attribute_alias_def = astnodes.AttrAliasDef
###############################################################
# List types
bare_id_list = list
ssa_id_list = list
ssa_use_list = list
op_result_list = list
successor_list = list
function_body = list
ssa_id_and_type_list = list
block_arg_list = list
ssa_use_and_type_list = list
stride_list = list
dimension_list_ranked = list
static_dimension_list = list
pretty_dialect_item_body = list
type_list_no_parens = list
affine_constraint_conjunction = list
function_result_list_no_parens = list
multi_dim_affine_expr_no_parens = list
dim_id_list = list
symbol_id_list = list
dim_use_list = list
symbol_use_list = list
###############################################################
# Composite types that should be reduced to sub-types
bool_literal = lambda self, value: value[0]
integer_literal = lambda self, value: value[0]
constant_literal = lambda self, value: value[0]
dimension_list = lambda self, value: value[0]
ssa_use = lambda self, value: value[0]
vector_element_type = lambda self, value: value[0]
tensor_memref_element_type = lambda self, value: value[0]
tensor_type = lambda self, value: value[0]
memref_type = lambda self, value: value[0]
standard_type = lambda self, value: value[0]
dialect_type = lambda self, value: value[0]
non_function_type = lambda self, value: value[0]
type = lambda self, value: value[0]
type_list_parens = lambda self, value: (value[0] if value else [])
function_result_type = lambda self, value: value[0]
standard_attribute = lambda self, value: value[0]
attribute_value = lambda self, value: value[0]
dialect_attribute = lambda self, value: value[0]
attribute_entry = lambda self, value: value[0]
trailing_type = lambda self, value: value[0]
trailing_location = lambda self, value: value[0]
function_result_list_parens = lambda self, value: value[0]
symbol_or_const = lambda self, value: value[0]
affine_map = lambda self, value: value[0]
semi_affine_map = lambda self, value: value[0]
integer_set = lambda self, value: value[0]
affine_literal = lambda self, value: value[0]
semi_affine_literal = lambda self, value: value[0]
affine_ssa = lambda self, value: value[0]
affine_symbol = lambda self, value: value[0]
semi_affine_symbol = lambda self, value: value[0]
# Dialect ops and types are appended to this list via "setattr"
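# For example, Parser.__init__ (parser.py above) runs
# setattr(transformer, rule_name, ctor) for every generated dialect rule.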

@@ -1,80 +0,0 @@
""" Classes containing MLIR AST traversal and transformation functionality. """
from mlir import astnodes
def iter_fields(node: astnodes.Node):
"""
Iterates over the fields of an MLIR AST node. Yields a two-tuple of
(name, value).
:param node: The AST node to iterate over.
"""
for field in node._fields_:
yield field, getattr(node, field)
class NodeVisitor(object):
"""
A node visitor class that follows the API and features of ast.NodeVisitor.
The visitor walks the MLIR AST and calls a visitor function for every node
type. The visit function may return a value which is forwarded by internal
calls.
To create a node visitor, implement a sub-class of this class by adding
methods called ``visit_NODETYPE``, where ``NODETYPE`` is the class name
of an MLIR AST (or dialect) node type. For example, to implement a hook when
an operation (``mlir.astnodes.Operation``) is encountered, create a method
called ``visit_Operation``.
See ``ast.NodeVisitor`` for more information.
"""
def visit(self, node: astnodes.Node):
""" Visit a node. """
method = 'visit_' + node.__class__.__name__
visitor = getattr(self, method, self.generic_visit)
return visitor(node)
def generic_visit(self, node: astnodes.Node):
""" Called if no explicit visitor function exists for a node. """
for field, value in iter_fields(node):
if isinstance(value, list):
for item in value:
if isinstance(item, astnodes.Node):
self.visit(item)
elif isinstance(value, astnodes.Node):
self.visit(value)
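# Example sketch (illustrative, not part of the original file): a visitor
# that counts every operation in a module, recursing via generic_visit.
class ExampleOpCounter(NodeVisitor):
    def __init__(self):
        self.count = 0

    def visit_Operation(self, node: astnodes.Node):
        self.count += 1
        self.generic_visit(node)
# Usage: counter = ExampleOpCounter(); counter.visit(module); counter.count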
class NodeTransformer(NodeVisitor):
"""
A ``NodeVisitor`` subclass that can modify and remove AST nodes.
The interface and usage of this class follows ``ast.NodeTransformer``. See
its documentation for more information.
"""
def generic_visit(self, node: astnodes.Node):
"""
Called if no explicit visitor function exists for a node.
Implements modification and removal of list elements in fields.
"""
for field, old_value in iter_fields(node):
if isinstance(old_value, list):
new_values = []
for value in old_value:
if isinstance(value, astnodes.Node):
value = self.visit(value)
if value is None:
continue
elif not isinstance(value, astnodes.Node):
new_values.extend(value)
continue
new_values.append(value)
old_value[:] = new_values
elif isinstance(old_value, astnodes.Node):
new_node = self.visit(old_value)
if new_node is None:
delattr(node, field)
else:
setattr(node, field, new_node)
return node

@@ -1,2 +0,0 @@
lark-parser==0.7.8
parse==1.14.0

@@ -1,32 +0,0 @@
from setuptools import setup, find_packages
with open("README.md", "r") as fp:
long_description = fp.read()
setup(
name='pymlir',
version='0.3',
url='https://github.com/spcl/pymlir',
author='SPCL @ ETH Zurich',
author_email='talbn@inf.ethz.ch',
description='',
long_description=long_description,
long_description_content_type='text/markdown',
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: BSD License",
"Operating System :: OS Independent",
],
python_requires='>=3.6',
packages=find_packages(
exclude=["*.tests", "*.tests.*", "tests.*", "tests"]),
package_data={
'': ['lark/mlir.lark']
},
include_package_data=True,
install_requires=[
'lark-parser', 'parse'
],
tests_require=['pytest', 'pytest-cov'],
test_suite='pytest',
scripts=[])

@@ -1,84 +0,0 @@
""" Test that creates and uses a custom dialect. """
from mlir import parse_string
from mlir.astnodes import Node, dump_or_value
from mlir.dialect import Dialect, DialectOp, DialectType
##############################################################################
# Dialect Types
class RaggedTensorType(DialectType):
"""
AST node class for the example "toy" dialect representing a ragged tensor.
"""
_syntax_ = 'toy.ragged < {implementation.toy_impl_list} , {dims.dimension_list_ranked} {type.tensor_memref_element_type} >'
# Custom MLIR serialization implementation
def dump(self, indent: int = 0) -> str:
return '!toy.ragged<%s, %sx%s>' % (
dump_or_value(self.implementation, indent),
'x'.join(dump_or_value(d, indent) for d in self.dims),
dump_or_value(self.type, indent)
)
class ToyImplementation(Node):
""" Base "toy" implementation AST node. Corresponds to a "+"-separated list
of sparse tensor types.
"""
_fields_ = ['values']
def __init__(self, node=None, **fields):
self.values = node
super().__init__(None, **fields)
def dump(self, indent: int = 0) -> str:
return '+'.join(dump_or_value(v, indent) for v in self.values)
##############################################################################
# Dialect Operations
class DensifyOp(DialectOp):
""" AST node for an operation with an optional value. """
_syntax_ = ['toy.densify {arg.ssa_id} : {type.tensor_type}',
'toy.densify {arg.ssa_id} , {pad.constant_literal} : {type.tensor_type}']
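# Note (illustrative): in a _syntax_ string, each '{field.rule}' placeholder
# matches the named Lark rule (e.g. ssa_id) and binds the parsed result to
# that field on the AST node (e.g. self.arg).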
##############################################################################
# Dialect
my_dialect = Dialect('toy', ops=[DensifyOp], types=[RaggedTensorType],
preamble='''
// Exclamation mark in Lark means that string tokens will be preserved upon parsing
!toy_impl_type : "coo" | "csr" | "csc" | "ell"
toy_impl_list : toy_impl_type ("+" toy_impl_type)*
''',
transformers=dict(
toy_impl_list=ToyImplementation,
# Will convert every instance to its contents
toy_impl_type=lambda v: v[0]
))
##############################################################################
# Tests
def test_custom_dialect():
code = '''module {
func @toy_test(%ragged: !toy.ragged<coo+csr, 32x14xf64>) -> tensor<32x14xf64> {
%t_tensor = toy.densify %ragged : tensor<32x14xf64>
return %t_tensor : tensor<32x14xf64>
}
}'''
m = parse_string(code, dialects=[my_dialect])
dump = m.pretty()
print(dump)
# Test for round-trip
assert dump == code
if __name__ == '__main__':
test_custom_dialect()

@@ -1,55 +0,0 @@
""" Tests pyMLIR in a parse->dump->parse round-trip. """
from mlir import parse_string
def test_toy_roundtrip():
"""
Create MLIR code without extra whitespace and check that it can parse
and dump the same way.
"""
code = '''module {
func @toy_func(%arg0: tensor<2x3xf64>) -> tensor<3x2xf64> {
%0 = "toy.transpose"(%arg0) {inplace = true} : (tensor<2x3xf64>) -> tensor<3x2xf64>
return %0 : tensor<3x2xf64>
}
}'''
module = parse_string(code)
dump = module.dump()
assert dump == code
def test_affine_expr_roundtrip():
"""
Create affine maps, semi-affine maps, and integer sets, checking for
correct parsing.
"""
code = '''#map0 = (d0, d1) -> (d0, d1)
#map1 = (d0) -> (d0)
#map2 = () -> (0)
#map3 = () -> (10)
#map4 = (d0, d1, d2) -> (d0, d1 + d2 + 5)
#map5 = (d0, d1, d2) -> (d0 + d1, d2)
#map6 = (d0, d1)[s0] -> (d0, d1 + s0 + 7)
#map7 = (d0, d1)[s0] -> (d0 + s0, d1)
#map8 = (d0, d1) -> (d0 + d1 + 11)
#map9 = (d0, d1)[s0] -> (d0, (d1 + s0) mod 9 + 7)
#map10 = (d0, d1)[s0] -> ((d0 + s0) floordiv 3, d1)
#samap0 = (d0)[s0] -> (d0 floordiv (s0 + 1))
#samap1 = (d0)[s0] -> (d0 floordiv s0)
#samap2 = (d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)
#set0 = (d0) : (1 == 0)
#set1 = (d0, d1)[s0] : ()
#set2 = (d0, d1)[s0, s1] : (d0 >= 0, -d0 + s0 - 1 >= 0, d1 >= 0, -d1 + s1 - 1 >= 0)
#set3 = (d0, d1, d2) : (d0 - d2 * 4 == 0, d0 + d1 * 8 - 9 >= 0, -d0 - d1 * 8 + 11 >= 0)
#set4 = (d0, d1, d2, d3, d4, d5) : (d0 * 1089234 + d1 * 203472 + 82342 >= 0, d0 * -55 + d1 * 24 + d2 * 238 - d3 * 234 - 9743 >= 0, d0 * -5445 - d1 * 284 + d2 * 23 + d3 * 34 - 5943 >= 0, d0 * -5445 + d1 * 284 + d2 * 238 - d3 * 34 >= 0, d0 * 445 + d1 * 284 + d2 * 238 + d3 * 39 >= 0, d0 * -545 + d1 * 214 + d2 * 218 - d3 * 94 >= 0, d0 * 44 - d1 * 184 - d2 * 231 + d3 * 14 >= 0, d0 * -45 + d1 * 284 + d2 * 138 - d3 * 39 >= 0, d0 * 154 - d1 * 84 + d2 * 238 - d3 * 34 >= 0, d0 * 54 - d1 * 284 - d2 * 223 + d3 * 384 >= 0, d0 * -55 + d1 * 284 + d2 * 23 + d3 * 34 >= 0, d0 * 54 - d1 * 84 + d2 * 28 - d3 * 34 >= 0, d0 * 54 - d1 * 24 - d2 * 23 + d3 * 34 >= 0, d0 * -55 + d1 * 24 + d2 * 23 + d3 * 4 >= 0, d0 * 15 - d1 * 84 + d2 * 238 - d3 * 3 >= 0, d0 * 5 - d1 * 24 - d2 * 223 + d3 * 84 >= 0, d0 * -5 + d1 * 284 + d2 * 23 - d3 * 4 >= 0, d0 * 14 + d2 * 4 + 7234 >= 0, d0 * -174 - d2 * 534 + 9834 >= 0, d0 * 194 - d2 * 954 + 9234 >= 0, d0 * 47 - d2 * 534 + 9734 >= 0, d0 * -194 - d2 * 934 + 984 >= 0, d0 * -947 - d2 * 953 + 234 >= 0, d0 * 184 - d2 * 884 + 884 >= 0, d0 * -174 + d2 * 834 + 234 >= 0, d0 * 844 + d2 * 634 + 9874 >= 0, d2 * -797 - d3 * 79 + 257 >= 0, d0 * 2039 + d2 * 793 - d3 * 99 - d4 * 24 + d5 * 234 >= 0, d2 * 78 - d5 * 788 + 257 >= 0, d3 - (d5 + d0 * 97) floordiv 423 >= 0, ((d0 + (d3 mod 5) floordiv 2342) * 234) mod 2309 + (d0 + d3 * 2038) floordiv 208 >= 0, ((((d0 + d3 * 2300) * 239) floordiv 2342) mod 2309) mod 239423 == 0, d0 + d3 mod 2642 + (((((d3 + d0 * 2) mod 1247) mod 2038) mod 2390) mod 2039) floordiv 55 >= 0)'''
module = parse_string(code)
dump = '\n'.join(definition.dump() for definition in module.body)
assert dump == code
if __name__ == '__main__':
test_toy_roundtrip()
test_affine_expr_roundtrip()

@@ -1,215 +0,0 @@
""" Tests pyMLIR on different syntactic edge-cases. """
from mlir import Parser
from typing import Optional
def test_attributes(parser: Optional[Parser] = None):
code = '''
module {
func @myfunc(%tensor: tensor<256x?xf64>) -> tensor<*xf64> {
%t_tensor = "with_attributes"(%tensor) { inplace = true, abc = -123, bla = unit, hello_world = "hey", value=@this::@is::@hierarchical, somelist = ["of", "values"], last = {butnot = "least", dictionaries = 0xabc} } : (tensor<2x3xf64>) -> tuple<vector<3xi33>,tensor<2x3xf64>>
return %t_tensor : tensor<3x2xf64>
}
func @toy_func(%arg0: tensor<2x3xf64>) -> tensor<3x2xf64> {
%0:2 = "toy.split"(%arg0) : (tensor<2x3xf64>) -> (tensor<3x2xf64>, f32)
return %0#50 : tensor<3x2xf64>
}
}
'''
parser = parser or Parser()
module = parser.parse(code)
print(module.pretty())
def test_memrefs(parser: Optional[Parser] = None):
code = '''
module {
func @myfunc() {
%a, %b = "tensor_replicator"(%tensor, %tensor) : (memref<?xbf16, 2>,
memref<?xf32, offset: 5, strides: [6, 7]>,
memref<*xf32, 8>)
}
}
'''
parser = parser or Parser()
module = parser.parse(code)
print(module.pretty())
def test_trailing_loc(parser: Optional[Parser] = None):
code = '''
module {
func @myfunc() {
%c:2 = addf %a, %b : f32 loc("test_syntax.py":36:59)
}
} loc("hi.mlir":30:1)
'''
parser = parser or Parser()
module = parser.parse(code)
print(module.pretty())
def test_modules(parser: Optional[Parser] = None):
code = '''
module {
module {
}
module {
}
module attributes {foo.attr = true} {
}
module {
%1 = "foo.result_op"() : () -> i32
}
module {
}
%0 = "op"() : () -> i32
module @foo {
module {
module @bar attributes {foo.bar} {
}
}
}
}'''
parser = parser or Parser()
module = parser.parse(code)
print(module.pretty())
def test_functions(parser: Optional[Parser] = None):
code = '''
module {
func @myfunc_a() {
%c:2 = addf %a, %b : f32
}
func @myfunc_b() {
%d:2 = addf %a, %b : f64
^e:
%f:2 = addf %d, %d : f64
}
}'''
parser = parser or Parser()
module = parser.parse(code)
print(module.pretty())
def test_toplevel_function(parser: Optional[Parser] = None):
code = '''
func @toy_func(%tensor: tensor<2x3xf64>) -> tensor<3x2xf64> {
%t_tensor = "toy.transpose"(%tensor) { inplace = true } : (tensor<2x3xf64>) -> tensor<3x2xf64>
return %t_tensor : tensor<3x2xf64>
}'''
parser = parser or Parser()
module = parser.parse(code)
print(module.pretty())
def test_toplevel_functions(parser: Optional[Parser] = None):
code = '''
func @toy_func(%tensor: tensor<2x3xf64>) -> tensor<3x2xf64> {
%t_tensor = "toy.transpose"(%tensor) { inplace = true } : (tensor<2x3xf64>) -> tensor<3x2xf64>
return %t_tensor : tensor<3x2xf64>
}
func @toy_func(%tensor: tensor<2x3xf64>) -> tensor<3x2xf64> {
%t_tensor = "toy.transpose"(%tensor) { inplace = true } : (tensor<2x3xf64>) -> tensor<3x2xf64>
return %t_tensor : tensor<3x2xf64>
}'''
parser = parser or Parser()
module = parser.parse(code)
print(module.pretty())
def test_affine(parser: Optional[Parser] = None):
code = '''
func @empty() {
affine.for %i = 0 to 10 {
} {some_attr = true}
%0 = affine.min (d0)[s0] -> (1000, d0 + 512, s0) (%arg0)[%arg1]
}
func @valid_symbols(%arg0: index, %arg1: index, %arg2: index) {
%c0 = constant 1 : index
%c1 = constant 0 : index
%b = alloc()[%N] : memref<4x4xf32, (d0, d1)[s0] -> (d0, d0 + d1 + s0 floordiv 2)>
%0 = alloc(%arg0, %arg1) : memref<?x?xf32>
affine.for %arg3 = %arg1 to %arg2 step 768 {
%13 = dim %0, 1 : memref<?x?xf32>
affine.for %arg4 = 0 to %13 step 264 {
%18 = dim %0, 0 : memref<?x?xf32>
%20 = subview %0[%c0, %c0][%18,%arg4][%c1,%c1] : memref<?x?xf32>
to memref<?x?xf32, (d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)>
%24 = dim %20, 0 : memref<?x?xf32, (d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)>
affine.for %arg5 = 0 to %24 step 768 {
"foo"() : () -> ()
}
}
}
return
}
'''
parser = parser or Parser()
module = parser.parse(code)
print(module.pretty())
def test_definitions(parser: Optional[Parser] = None):
code = '''
#map0 = (d0, d1) -> (d0, d1)
#map1 = (d0) -> (d0)
#map2 = () -> (0)
#map3 = () -> (10)
#map4 = (d0, d1, d2) -> (d0, d1 + d2 + 5)
#map5 = (d0, d1, d2) -> (d0 + d1, d2)
#map6 = (d0, d1)[s0] -> (d0, d1 + s0 + 7)
#map7 = (d0, d1)[s0] -> (d0 + s0, d1)
#map8 = (d0, d1) -> (d0 + d1 + 11)
#map9 = (d0, d1)[s0] -> (d0, (d1 + s0) mod 9 + 7)
#map10 = (d0, d1)[s0] -> ((d0 + s0) floordiv 3, d1)
#samap0 = (d0)[s0] -> (d0 floordiv (s0 + 1))
#samap1 = (d0)[s0] -> (d0 floordiv s0)
#samap2 = (d0, d1)[s0, s1] -> (d0*s0 + d1*s1)
#set0 = (d0) : (1 == 0)
#set1 = (d0, d1)[s0] : ()
#set2 = (d0, d1)[s0, s1] : (d0 >= 0, -d0 + s0 - 1 >= 0, d1 >= 0, -d1 + s1 - 1 >= 0)
#set3 = (d0, d1, d2) : (d0 - d2 * 4 == 0, d0 + d1 * 8 - 9 >= 0, -d0 - d1 * 8 + 11 >= 0)
#set4 = (d0, d1, d2, d3, d4, d5) : (d0 * 1089234 + d1 * 203472 + 82342 >= 0, d0 * -55 + d1 * 24 + d2 * 238 - d3 * 234 - 9743 >= 0, d0 * -5445 - d1 * 284 + d2 * 23 + d3 * 34 - 5943 >= 0, d0 * -5445 + d1 * 284 + d2 * 238 - d3 * 34 >= 0, d0 * 445 + d1 * 284 + d2 * 238 + d3 * 39 >= 0, d0 * -545 + d1 * 214 + d2 * 218 - d3 * 94 >= 0, d0 * 44 - d1 * 184 - d2 * 231 + d3 * 14 >= 0, d0 * -45 + d1 * 284 + d2 * 138 - d3 * 39 >= 0, d0 * 154 - d1 * 84 + d2 * 238 - d3 * 34 >= 0, d0 * 54 - d1 * 284 - d2 * 223 + d3 * 384 >= 0, d0 * -55 + d1 * 284 + d2 * 23 + d3 * 34 >= 0, d0 * 54 - d1 * 84 + d2 * 28 - d3 * 34 >= 0, d0 * 54 - d1 * 24 - d2 * 23 + d3 * 34 >= 0, d0 * -55 + d1 * 24 + d2 * 23 + d3 * 4 >= 0, d0 * 15 - d1 * 84 + d2 * 238 - d3 * 3 >= 0, d0 * 5 - d1 * 24 - d2 * 223 + d3 * 84 >= 0, d0 * -5 + d1 * 284 + d2 * 23 - d3 * 4 >= 0, d0 * 14 + d2 * 4 + 7234 >= 0, d0 * -174 - d2 * 534 + 9834 >= 0, d0 * 194 - d2 * 954 + 9234 >= 0, d0 * 47 - d2 * 534 + 9734 >= 0, d0 * -194 - d2 * 934 + 984 >= 0, d0 * -947 - d2 * 953 + 234 >= 0, d0 * 184 - d2 * 884 + 884 >= 0, d0 * -174 + d2 * 834 + 234 >= 0, d0 * 844 + d2 * 634 + 9874 >= 0, d2 * -797 - d3 * 79 + 257 >= 0, d0 * 2039 + d2 * 793 - d3 * 99 - d4 * 24 + d5 * 234 >= 0, d2 * 78 - d5 * 788 + 257 >= 0, d3 - (d5 + d0 * 97) floordiv 423 >= 0, ((d0 + (d3 mod 5) floordiv 2342) * 234) mod 2309 + (d0 + d3 * 2038) floordiv 208 >= 0, ((((d0 + d3 * 2300) * 239) floordiv 2342) mod 2309) mod 239423 == 0, d0 + d3 mod 2642 + (((((d3 + d0 * 2) mod 1247) mod 2038) mod 2390) mod 2039) floordiv 55 >= 0)
#matmul_accesses = [
(m, n, k) -> (m, k),
(m, n, k) -> (k, n),
(m, n, k) -> (m, n)
]
#matmul_trait = {
args_in = 2,
args_out = 1,
iterator_types = ["parallel", "parallel", "reduction"],
indexing_maps = #matmul_accesses,
library_call = "external_outerproduct_matmul"
}
!vector_type_A = type vector<4xf32>
!vector_type_B = type vector<4xf32>
!vector_type_C = type vector<4x4xf32>
!matrix_type_A = type memref<?x?x!vector_type_A>
!matrix_type_B = type memref<?x?x!vector_type_B>
!matrix_type_C = type memref<?x?x!vector_type_C>
'''
parser = parser or Parser()
module = parser.parse(code)
print(module.pretty())
if __name__ == '__main__':
p = Parser()
print("MLIR parser created")
test_attributes(p)
test_memrefs(p)
test_trailing_loc(p)
test_modules(p)
test_functions(p)
test_toplevel_function(p)
test_toplevel_functions(p)
test_affine(p)
test_definitions(p)

@@ -1,28 +0,0 @@
""" Tests pyMLIR on examples that use the Toy dialect. """
from mlir import parse_string, parse_path
import os
def test_toy_simple():
code = '''
module {
func @toy_func(%tensor: tensor<2x3xf64>) -> tensor<3x2xf64> {
%t_tensor = "toy.transpose"(%tensor) { inplace = true } : (tensor<2x3xf64>) -> tensor<3x2xf64>
return %t_tensor : tensor<3x2xf64>
}
}
'''
module = parse_string(code)
print(module.pretty())
def test_toy_advanced():
module = parse_path(os.path.join(os.path.dirname(__file__), 'toy.mlir'))
print(module.pretty())
if __name__ == '__main__':
test_toy_simple()
test_toy_advanced()

@@ -1,96 +0,0 @@
""" Tests pyMLIR's node visitor and transformer. """
from mlir import NodeVisitor, NodeTransformer, Parser, astnodes
from typing import Optional
# Sample code to use for visitors
_code = '''
module {
func @test0(%arg0: index, %arg1: index) {
%0 = alloc() : memref<100x100xf32>
%1 = alloc() : memref<100x100xf32, 2>
%2 = alloc() : memref<1xi32>
%c0 = constant 0 : index
%c64 = constant 64 : index
affine.for %arg2 = 0 to 10 {
affine.for %arg3 = 0 to 10 {
affine.dma_start %0[%arg2, %arg3], %1[%arg2, %arg3], %2[%c0], %c64 : memref<100x100xf32>, memref<100x100xf32, 2>, memref<1xi32>
affine.dma_wait %2[%c0], %c64 : memref<1xi32>
}
}
return
}
func @test1(%arg0: index, %arg1: index) {
affine.for %arg2 = 0 to 10 {
affine.for %arg3 = 0 to 10 {
%c0 = constant 0 : index
%c64 = constant 64 : index
%c128 = constant 128 : index
%c256 = constant 256 : index
affine.dma_start %0[%arg2, %arg3], %1[%arg2, %arg3], %2[%c0], %c64, %c128, %c256 : memref<100x100xf32>, memref<100x100xf32, 2>, memref<1xi32>
affine.dma_wait %2[%c0], %c64 : memref<1xi32>
}
}
return
}
func @test2(%arg0: index, %arg1: index) {
%0 = alloc() : memref<100x100xf32>
}
}
'''
def test_visitor(parser: Optional[Parser] = None):
class MyVisitor(NodeVisitor):
def __init__(self):
self.functions = 0
def visit_Function(self, node: astnodes.Function):
self.functions += 1
print('Function detected:', node.name.value)
parser = parser or Parser()
m = parser.parse(_code)
visitor = MyVisitor()
visitor.visit(m)
assert visitor.functions == 3
def test_transformer(parser: Optional[Parser] = None):
# Simple node transformer that removes all operations with a result
class RemoveAllResultOps(NodeTransformer):
def visit_Operation(self, node: astnodes.Operation):
# There are one or more outputs, return None to remove from AST
if len(node.result_list) > 0:
return None
# No outputs, no need to do anything
return self.generic_visit(node)
parser = parser or Parser()
m = parser.parse(_code)
m = RemoveAllResultOps().visit(m)
print(m.pretty())
# Verify that there are no operations with results
class Tester(NodeVisitor):
def __init__(self):
self.fail = False
def visit_Operation(self, node: astnodes.Operation):
if len(node.result_list) > 0:
self.fail = True
return self.generic_visit(node)
t = Tester()
t.visit(m)
assert t.fail is False
if __name__ == '__main__':
p = Parser()
print("MLIR parser created")
test_visitor(p)
test_transformer(p)

@@ -1,110 +0,0 @@
module {
func @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> {
%0 = "toy.transpose"(%arg0) : (tensor<*xf64>) -> tensor<*xf64> loc("test/codegen.toy":5:10)
%1 = "toy.transpose"(%arg1) : (tensor<*xf64>) -> tensor<*xf64> loc("test/codegen.toy":5:25)
%2 = "toy.mul"(%0, %1) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64> loc("test/codegen.toy":5:25)
"toy.return"(%2) : (tensor<*xf64>) -> () loc("test/codegen.toy":5:3)
} loc("test/codegen.toy":4:1)
func @main() {
%0 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64> loc("test/codegen.toy":9:17)
%1 = "toy.reshape"(%0) : (tensor<2x3xf64>) -> tensor<2x3xf64> loc("test/codegen.toy":9:3)
%2 = "toy.constant"() {value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>} : () -> tensor<6xf64> loc("test/codegen.toy":10:17)
%3 = "toy.reshape"(%2) : (tensor<6xf64>) -> tensor<2x3xf64> loc("test/codegen.toy":10:3)
%4 = "toy.generic_call"(%1, %3) {callee = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> loc("test/codegen.toy":11:11)
%5 = "toy.generic_call"(%3, %1) {callee = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> loc("test/codegen.toy":12:11)
"toy.print"(%5) : (tensor<*xf64>) -> () loc("test/codegen.toy":13:3)
"toy.return"() : () -> () loc("test/codegen.toy":8:1)
} loc("test/codegen.toy":8:1)
func @main() {
%cst = constant 1.000000e+00 : f64
%cst_0 = constant 2.000000e+00 : f64
%cst_1 = constant 3.000000e+00 : f64
%cst_2 = constant 4.000000e+00 : f64
%cst_3 = constant 5.000000e+00 : f64
%cst_4 = constant 6.000000e+00 : f64
// Allocating buffers for the inputs and outputs.
%0 = alloc() : memref<3x2xf64>
%1 = alloc() : memref<3x2xf64>
%2 = alloc() : memref<2x3xf64>
// Initialize the input buffer with the constant values.
affine.store %cst, %2[0, 0] : memref<2x3xf64>
affine.store %cst_0, %2[0, 1] : memref<2x3xf64>
affine.store %cst_1, %2[0, 2] : memref<2x3xf64>
affine.store %cst_2, %2[1, 0] : memref<2x3xf64>
affine.store %cst_3, %2[1, 1] : memref<2x3xf64>
affine.store %cst_4, %2[1, 2] : memref<2x3xf64>
// Load the transpose value from the input buffer and store it into the
// next input buffer.
affine.for %arg0 = 0 to 3 {
affine.for %arg1 = 0 to 2 {
%3 = affine.load %2[%arg1, %arg0] : memref<2x3xf64>
affine.store %3, %1[%arg0, %arg1] : memref<3x2xf64>
}
}
// Multiply and store into the output buffer.
affine.for %arg0 = 0 to 2 {
affine.for %arg1 = 0 to 3 {
%3 = affine.load %1[%arg0, %arg1] : memref<3x2xf64>
%4 = affine.load %1[%arg0, %arg1] : memref<3x2xf64>
%5 = mulf %3, %4 : f64
affine.store %5, %0[%arg0, %arg1] : memref<3x2xf64>
}
}
// Print the value held by the buffer.
"toy.print"(%0) : (memref<3x2xf64>) -> ()
dealloc %2 : memref<2x3xf64>
dealloc %1 : memref<3x2xf64>
dealloc %0 : memref<3x2xf64>
return
}
func @main() {
%cst = constant 1.000000e+00 : f64
%cst_0 = constant 2.000000e+00 : f64
%cst_1 = constant 3.000000e+00 : f64
%cst_2 = constant 4.000000e+00 : f64
%cst_3 = constant 5.000000e+00 : f64
%cst_4 = constant 6.000000e+00 : f64
// Allocating buffers for the inputs and outputs.
%0 = alloc() : memref<3x2xf64>
%1 = alloc() : memref<2x3xf64>
// Initialize the input buffer with the constant values.
affine.store %cst, %1[0, 0] : memref<2x3xf64>
affine.store %cst_0, %1[0, 1] : memref<2x3xf64>
affine.store %cst_1, %1[0, 2] : memref<2x3xf64>
affine.store %cst_2, %1[1, 0] : memref<2x3xf64>
affine.store %cst_3, %1[1, 1] : memref<2x3xf64>
affine.store %cst_4, %1[1, 2] : memref<2x3xf64>
affine.for %arg0 = 0 to 3 {
affine.for %arg1 = 0 to 2 {
// Load the transpose value from the input buffer.
%2 = affine.load %1[%arg1, %arg0] : memref<2x3xf64>
// Multiply and store into the output buffer.
%3 = mulf %2, %2 : f64
affine.store %3, %0[%arg0, %arg1] : memref<3x2xf64>
}
}
// Print the value held by the buffer.
"toy.print"(%0) : (memref<3x2xf64>) -> ()
dealloc %1 : memref<2x3xf64>
dealloc %0 : memref<3x2xf64>
return
}
} loc("test/codegen.toy":0:0)
module {
func @multiply_transpose(%arg0: !toy.struct<tensor<*xf64>, tensor<*xf64>>) {
"toy.return"() : () -> ()
}
}

test/CMakeLists.txt Normal file
@@ -0,0 +1,22 @@
configure_lit_site_cfg(
${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in
${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py
MAIN_CONFIG
${CMAKE_CURRENT_SOURCE_DIR}/lit.cfg.py
)
set(HLSLD_TEST_DEPENDS
FileCheck count not
hlsld-translate
)
add_lit_testsuite(check-hlsld "Running the hlsld regression tests"
${CMAKE_CURRENT_BINARY_DIR}
DEPENDS ${HLSLD_TEST_DEPENDS}
)
set_target_properties(check-hlsld PROPERTIES FOLDER "Tests")
add_lit_testsuites(HLSLD
${CMAKE_CURRENT_SOURCE_DIR}
DEPENDS ${HLSLD_TEST_DEPENDS}
)
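# Usage sketch (illustrative): after configuring a build, run the suite via
# the aggregate target created above, e.g. `ninja check-hlsld` with a Ninja
# generator.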

@@ -0,0 +1,5 @@
// RUN: hlsld-translate %s | FileCheck %s
func @test_standard() {
}

test/lit.cfg.py Normal file
@@ -0,0 +1,59 @@
# -*- Python -*-
import os
import platform
import re
import subprocess
import tempfile
import lit.formats
import lit.util
from lit.llvm import llvm_config
from lit.llvm.subst import ToolSubst
from lit.llvm.subst import FindTool
# Configuration file for the 'lit' test runner.
# name: The name of this test suite.
config.name = 'HLSLD'
config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)
# suffixes: A list of file extensions to treat as test files.
config.suffixes = ['.mlir']
# test_source_root: The root path where tests are located.
config.test_source_root = os.path.dirname(__file__)
# test_exec_root: The root path where tests should be run.
config.test_exec_root = os.path.join(config.hlsld_obj_root, 'test')
config.substitutions.append(('%PATH%', config.environment['PATH']))
config.substitutions.append(('%shlibext', config.llvm_shlib_ext))
llvm_config.with_system_environment(
['HOME', 'INCLUDE', 'LIB', 'TMP', 'TEMP'])
llvm_config.use_default_substitutions()
# excludes: A list of directories to exclude from the testsuite. The 'Inputs'
# subdirectories contain auxiliary inputs for various tests in their parent
# directories.
config.excludes = ['Inputs', 'CMakeLists.txt', 'README.txt', 'LICENSE.txt']
# test_source_root: The root path where tests are located.
config.test_source_root = os.path.dirname(__file__)
# test_exec_root: The root path where tests should be run.
config.test_exec_root = os.path.join(config.hlsld_obj_root, 'test')
# Tweak the PATH to include the tools dir.
llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True)
tool_dirs = [config.hlsld_tools_dir, config.mlir_tools_dir, config.llvm_tools_dir]
tools = [
'hlsld-translate'
]
llvm_config.add_tool_substitutions(tools, tool_dirs)
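# Illustrative note: with the substitution above, a RUN line such as
# "hlsld-translate %s | FileCheck %s" resolves 'hlsld-translate' to the
# binary found in tool_dirs before the command is executed.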

test/lit.site.cfg.py.in Normal file
@@ -0,0 +1,52 @@
@LIT_SITE_CFG_IN_HEADER@
import sys
config.host_triple = "@LLVM_HOST_TRIPLE@"
config.target_triple = "@TARGET_TRIPLE@"
config.llvm_src_root = "@LLVM_SOURCE_DIR@"
config.llvm_obj_root = "@LLVM_BINARY_DIR@"
config.llvm_tools_dir = "@LLVM_TOOLS_DIR@"
config.llvm_lib_dir = "@LLVM_LIBRARY_DIR@"
config.llvm_shlib_dir = "@SHLIBDIR@"
config.llvm_shlib_ext = "@SHLIBEXT@"
config.llvm_exe_ext = "@EXEEXT@"
config.lit_tools_dir = "@LLVM_LIT_TOOLS_DIR@"
config.python_executable = "@PYTHON_EXECUTABLE@"
config.gold_executable = "@GOLD_EXECUTABLE@"
config.ld64_executable = "@LD64_EXECUTABLE@"
config.enable_shared = @ENABLE_SHARED@
config.enable_assertions = @ENABLE_ASSERTIONS@
config.targets_to_build = "@TARGETS_TO_BUILD@"
config.native_target = "@LLVM_NATIVE_ARCH@"
config.llvm_bindings = "@LLVM_BINDINGS@".split(' ')
config.host_os = "@HOST_OS@"
config.host_cc = "@HOST_CC@"
config.host_cxx = "@HOST_CXX@"
# Note: ldflags can contain double-quoted paths, so must use single quotes here.
config.host_ldflags = '@HOST_LDFLAGS@'
config.llvm_use_sanitizer = "@LLVM_USE_SANITIZER@"
config.llvm_host_triple = '@LLVM_HOST_TRIPLE@'
config.host_arch = "@HOST_ARCH@"
config.mlir_src_root = "@MLIR_SOURCE_DIR@"
config.mlir_obj_root = "@MLIR_BINARY_DIR@"
config.mlir_tools_dir = "@MLIR_TOOLS_DIR@"
config.hlsld_src_root = "@HLSLD_SOURCE_DIR@"
config.hlsld_obj_root = "@HLSLD_BINARY_DIR@"
config.hlsld_tools_dir = "@HLSLD_TOOLS_DIR@"
# Support substitution of the tools_dir with user parameters. This is
# used when we can't determine the tool dir at configuration time.
try:
config.llvm_tools_dir = config.llvm_tools_dir % lit_config.params
config.llvm_shlib_dir = config.llvm_shlib_dir % lit_config.params
except KeyError:
e = sys.exc_info()[1]
key, = e.args
lit_config.fatal("unable to find %r parameter, use '--param=%s=VALUE'" % (key,key))
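# Illustrative example: a path configured as "%(build_mode)s/bin" would be
# resolved here by invoking lit with --param=build_mode=Release.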
import lit.llvm
lit.llvm.initialize(lit_config, config)
# Let the main config do the real work.
lit_config.load_config(config, "@HLSLD_SOURCE_DIR@/test/lit.cfg.py")

tools/CMakeLists.txt Normal file
@@ -0,0 +1 @@
add_subdirectory(hlsld-translate)

@@ -0,0 +1,25 @@
set(LLVM_LINK_COMPONENTS
Support
)
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(translation_libs GLOBAL PROPERTY MLIR_TRANSLATION_LIBS)
set(LIBS
${dialect_libs}
${translation_libs}
MLIRIR
MLIRParser
MLIRPass
MLIRTranslation
MLIRSupport
HLSLDEmitHLSCpp
)
add_llvm_executable(hlsld-translate hlsld-translate.cpp)
llvm_update_compile_flags(hlsld-translate)
target_link_libraries(hlsld-translate PRIVATE ${LIBS})
mlir_check_link_libraries(hlsld-translate)

@@ -0,0 +1,109 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// This file implements hlsld-translate, the MLIR translation driver for the
// HLSLD project (e.g. emitting HLS C++ via the EmitHLSCpp translation).
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/InitAllDialects.h"
#include "mlir/InitAllTranslations.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/ToolUtilities.h"
#include "mlir/Translation.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/ToolOutputFile.h"
#include "EmitHLSCpp.h"
static llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
llvm::cl::desc("<input file>"),
llvm::cl::init("-"));
static llvm::cl::opt<std::string>
outputFilename("o", llvm::cl::desc("Output filename"),
llvm::cl::value_desc("filename"), llvm::cl::init("-"));
static llvm::cl::opt<bool>
splitInputFile("split-input-file",
llvm::cl::desc("Split the input file into pieces and "
"process each chunk independently"),
llvm::cl::init(false));
static llvm::cl::opt<bool> verifyDiagnostics(
"verify-diagnostics",
llvm::cl::desc("Check that emitted diagnostics match "
"expected-* lines on the corresponding line"),
llvm::cl::init(false));
int main(int argc, char **argv)
{
mlir::registerAllDialects();
mlir::registerAllTranslations();
llvm::InitLLVM y(argc, argv);
// Add flags for all the registered translations.
llvm::cl::opt<const mlir::TranslateFunction *, false, mlir::TranslationParser>
translationRequested("", llvm::cl::desc("Translation to perform"),
llvm::cl::Required);
mlir::registerAsmPrinterCLOptions();
mlir::registerMLIRContextCLOptions();
llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR translation driver\n");
std::string errorMessage;
auto input = mlir::openInputFile(inputFilename, &errorMessage);
if (!input)
{
llvm::errs() << errorMessage << "\n";
return 1;
}
auto output = mlir::openOutputFile(outputFilename, &errorMessage);
if (!output)
{
llvm::errs() << errorMessage << "\n";
return 1;
}
// Processes the memory buffer with a new MLIRContext.
auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer,
llvm::raw_ostream &os) {
mlir::MLIRContext context;
context.allowUnregisteredDialects();
context.printOpOnDiagnostic(!verifyDiagnostics);
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), llvm::SMLoc());
if (!verifyDiagnostics)
{
mlir::SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
return (*translationRequested)(sourceMgr, os, &context);
}
// In the diagnostic verification flow, we ignore whether the translation
// failed (in most cases, it is expected to fail). Instead, we check if the
// diagnostics were produced as expected.
mlir::SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr,
&context);
(*translationRequested)(sourceMgr, os, &context);
return sourceMgrHandler.verify();
};
if (splitInputFile)
{
if (failed(mlir::splitAndProcessBuffer(std::move(input), processBuffer,
output->os())))
return 1;
}
else
{
if (failed(processBuffer(std::move(input), output->os())))
return 1;
}
output->keep();
return 0;
}
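// Usage sketch (illustrative; the exact translation flag is whichever name
// the EmitHLSCpp translation registers with the translation registry):
//   hlsld-translate --<translation-name> input.mlir -o output.cpp
//   hlsld-translate --<translation-name> --split-input-file test.mlir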