From 45fc170ab19a439538d65b0b3f99e4b8fe397d8f Mon Sep 17 00:00:00 2001 From: Chris Lattner Date: Sat, 16 May 2020 23:19:59 -0700 Subject: [PATCH] [RTL] Implement some new combinatorial ops, implement VerilogEmitter support for them, implement a new rtl::CombinatorialVistor class. --- include/cirt/Dialect/RTL/Ops.h | 3 + include/cirt/Dialect/RTL/RTLExpressions.td | 9 ++ include/cirt/Dialect/RTL/Visitors.h | 81 ++++++++++++++ lib/Dialect/RTL/Ops.cpp | 16 +++ lib/EmitVerilog/CMakeLists.txt | 1 + lib/EmitVerilog/EmitVerilog.cpp | 121 ++++++++++++--------- test/EmitVerilog/verilog-rtl-dialect.mlir | 29 +++++ test/EmitVerilog/verilog-weird.mlir | 2 +- tools/cirt-translate/cirt-translate.cpp | 4 + tools/firtool/firtool.cpp | 4 +- 10 files changed, 219 insertions(+), 51 deletions(-) create mode 100644 include/cirt/Dialect/RTL/Visitors.h create mode 100644 test/EmitVerilog/verilog-rtl-dialect.mlir diff --git a/include/cirt/Dialect/RTL/Ops.h b/include/cirt/Dialect/RTL/Ops.h index 022c643b8f..e047719c0d 100644 --- a/include/cirt/Dialect/RTL/Ops.h +++ b/include/cirt/Dialect/RTL/Ops.h @@ -16,6 +16,9 @@ namespace rtl { #define GET_OP_CLASSES #include "cirt/Dialect/RTL/RTL.h.inc" +/// Return true if the specified operation is a combinatorial logic op. +bool isCombinatorial(Operation *op); + } // namespace rtl } // namespace cirt diff --git a/include/cirt/Dialect/RTL/RTLExpressions.td b/include/cirt/Dialect/RTL/RTLExpressions.td index 09921b5194..3526ccfdae 100644 --- a/include/cirt/Dialect/RTL/RTLExpressions.td +++ b/include/cirt/Dialect/RTL/RTLExpressions.td @@ -18,6 +18,10 @@ def ConstantOp : RTLOp<"constant", [NoSideEffect, ConstantLike, let arguments = (ins APIntAttr:$value); let results = (outs AnySignlessInteger:$result); + // FIXME(QoI): Instead of requiring "rtl.constant (42: i8) : i8", we should + // just use "rtl.constant 42: i8". This can be done with a custom printer and + // parser, but would be better to be autoderived from the + // FirstAttrDerivedResultType trait. let assemblyFormat = [{ `(` $value `)` attr-dict `:` type($result) }]; @@ -54,10 +58,15 @@ class UTBinRTLOp traits = []> : }]; } +// Arithmetic and Logical Binary Operations. def AddOp : UTBinRTLOp<"add", [Commutative]>; def SubOp : UTBinRTLOp<"sub">; def MulOp : UTBinRTLOp<"mul", [Commutative]>; def DivOp : UTBinRTLOp<"div">; +def RemOp : UTBinRTLOp<"rem">; +def AndOp : UTBinRTLOp<"and", [Commutative]>; +def OrOp : UTBinRTLOp<"or", [Commutative]>; +def XorOp : UTBinRTLOp<"xor", [Commutative]>; //===----------------------------------------------------------------------===// // Other Operations diff --git a/include/cirt/Dialect/RTL/Visitors.h b/include/cirt/Dialect/RTL/Visitors.h new file mode 100644 index 0000000000..8b36f9192a --- /dev/null +++ b/include/cirt/Dialect/RTL/Visitors.h @@ -0,0 +1,81 @@ +//===- RTL/Visitors.h - RTL Dialect Visitors --------------------*- C++ -*-===// +// +// This file defines visitors that make it easier to work with RTL IR. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRT_DIALECT_RTL_VISITORS_H +#define CIRT_DIALECT_RTL_VISITORS_H + +#include "cirt/Dialect/RTL/Ops.h" +#include "llvm/ADT/TypeSwitch.h" + +namespace cirt { +namespace rtl { + +/// This helps visit Combinatorial nodes. +template +class CombinatorialVisitor { +public: + ResultType dispatchCombinatorialVisitor(Operation *op, ExtraArgs... args) { + auto *thisCast = static_cast(this); + return TypeSwitch(op) + .template Case([&](auto expr) -> ResultType { + return thisCast->visitComb(expr, args...); + }) + .Default([&](auto expr) -> ResultType { + return thisCast->visitInvalidComb(op, args...); + }); + } + + /// This callback is invoked on any non-expression operations. + ResultType visitInvalidComb(Operation *op, ExtraArgs... args) { + op->emitOpError("unknown RTL combinatorial node"); + abort(); + } + + /// This callback is invoked on any combinatorial operations that are not + /// handled by the concrete visitor. + ResultType visitUnhandledComb(Operation *op, ExtraArgs... args) { + return ResultType(); + } + + /// This fallback is invoked on any binary node that isn't explicitly handled. + /// The default implementation delegates to the 'unhandled' fallback. + ResultType visitBinaryComb(Operation *op, ExtraArgs... args) { + return static_cast(this)->visitUnhandledComb(op, args...); + } + +#define HANDLE(OPTYPE, OPKIND) \ + ResultType visitComb(OPTYPE op, ExtraArgs... args) { \ + return static_cast(this)->visit##OPKIND##Comb(op, \ + args...); \ + } + + // Basic nodes. + HANDLE(ConstantOp, Unhandled) + + // Arithmetic and Logical Binary Operations. + HANDLE(AddOp, Binary); + HANDLE(SubOp, Binary); + HANDLE(MulOp, Binary); + HANDLE(DivOp, Binary); + HANDLE(RemOp, Binary); + HANDLE(AndOp, Binary); + HANDLE(OrOp, Binary); + HANDLE(XorOp, Binary); + + // Other operations. + HANDLE(ConcatOp, Unhandled); +#undef HANDLE +}; + +} // namespace rtl +} // namespace cirt + +#endif // CIRT_DIALECT_RTL_VISITORS_H \ No newline at end of file diff --git a/lib/Dialect/RTL/Ops.cpp b/lib/Dialect/RTL/Ops.cpp index 722a45c5cd..dcc0601a7d 100644 --- a/lib/Dialect/RTL/Ops.cpp +++ b/lib/Dialect/RTL/Ops.cpp @@ -3,6 +3,7 @@ //===----------------------------------------------------------------------===// #include "cirt/Dialect/RTL/Ops.h" +#include "cirt/Dialect/RTL/Visitors.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/StandardTypes.h" @@ -10,6 +11,21 @@ using namespace cirt; using namespace rtl; +/// Return true if the specified operation is a combinatorial logic op. +bool rtl::isCombinatorial(Operation *op) { + struct IsCombClassifier + : public CombinatorialVisitor { + bool visitInvalidComb(Operation *op) { return false; } + bool visitUnhandledComb(Operation *op) { return true; } + }; + + return IsCombClassifier().dispatchCombinatorialVisitor(op); +} + +//===----------------------------------------------------------------------===// +// ConstantOp +//===----------------------------------------------------------------------===// + static LogicalResult verify(ConstantOp constant) { // If the result type has a bitwidth, then the attribute must match its width. auto intType = constant.getType().cast(); diff --git a/lib/EmitVerilog/CMakeLists.txt b/lib/EmitVerilog/CMakeLists.txt index 112fd5b30e..b0924b906a 100644 --- a/lib/EmitVerilog/CMakeLists.txt +++ b/lib/EmitVerilog/CMakeLists.txt @@ -7,4 +7,5 @@ add_mlir_library(CIRTEmitVerilog LINK_LIBS PUBLIC MLIRFIRRTL + MLIRRTL ) diff --git a/lib/EmitVerilog/EmitVerilog.cpp b/lib/EmitVerilog/EmitVerilog.cpp index 2b7088a071..b1be71f9e5 100644 --- a/lib/EmitVerilog/EmitVerilog.cpp +++ b/lib/EmitVerilog/EmitVerilog.cpp @@ -6,8 +6,9 @@ #include "cirt/EmitVerilog.h" #include "cirt/Dialect/FIRRTL/Visitors.h" +#include "cirt/Dialect/RTL/Ops.h" +#include "cirt/Dialect/RTL/Visitors.h" #include "cirt/Support/LLVM.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Module.h" #include "mlir/IR/StandardTypes.h" #include "mlir/Translation.h" @@ -34,11 +35,9 @@ static llvm::ManagedStatic> reservedWordCache; //===----------------------------------------------------------------------===// static bool isVerilogExpression(Operation *op) { - // All FIRRTL expressions are also Verilog expressions. - if (isExpression(op)) - return true; - // The standard ConstantIntOp is an expression as well. - return isa(op); + // All FIRRTL expressions and RTL combinatorial logic ops are Verilog + // expressions. + return isExpression(op) || rtl::isCombinatorial(op); } /// Return the width of the specified FIRRTL type in bits or -1 if it isn't @@ -673,7 +672,8 @@ namespace { /// we emit the characters to a SmallVector which allows us to emit a bunch of /// stuff, then pre-insert parentheses and other things if we find out that it /// was needed later. -class ExprEmitter : public ExprVisitor { +class ExprEmitter : public ExprVisitor, + public rtl::CombinatorialVisitor { public: /// Create an ExprEmitter for the specified module emitter, and keeping track /// of any emitted expressions in the specified set. @@ -685,6 +685,7 @@ public: /// Emit the specified expression and return it as a string. std::string emitExpressionToString(Value exp, VerilogPrecedence precedence); friend class ExprVisitor; + friend class CombinatorialVisitor; /// Do a best-effort job of looking through noop cast operations. Value lookThroughNoopCasts(Value value) { @@ -702,11 +703,16 @@ private: bool forceExpectedSign = false); SubExprInfo visitUnhandledExpr(Operation *op); - SubExprInfo visitInvalidExpr(Operation *op); + SubExprInfo visitInvalidExpr(Operation *op) { + return dispatchCombinatorialVisitor(op); + } + SubExprInfo visitInvalidComb(Operation *op) { return visitUnhandledExpr(op); } + SubExprInfo visitUnhandledComb(Operation *op) { + return visitUnhandledExpr(op); + } using ExprVisitor::visitExpr; SubExprInfo visitExpr(firrtl::ConstantOp op); - SubExprInfo visitExpr(mlir::ConstantIntOp op); /// Emit a verilog concatenation of the specified values. If the before or /// after strings are specified, they are included as prefix/postfix elements @@ -719,35 +725,7 @@ private: SubExprInfo emitBitSelect(Value operand, unsigned hiBit, unsigned loBit); SubExprInfo emitBinary(Operation *op, VerilogPrecedence prec, - const char *syntax, bool hasStrictSign = false) { - auto lhsInfo = emitSubExpr(op->getOperand(0), prec, hasStrictSign); - os << ' ' << syntax << ' '; - - // The precedence of the RHS operand must be tighter than this operator if - // they have a different opcode in order to handle things like "x-(a+b)". - // This isn't needed on the LHS, because the relevant Verilog operators are - // left-associative. - // - auto *rhsOperandOp = - lookThroughNoopCasts(op->getOperand(1)).getDefiningOp(); - auto rhsPrec = VerilogPrecedence(prec - 1); - if (rhsOperandOp && op->getName() == rhsOperandOp->getName()) - rhsPrec = prec; - - auto rhsInfo = emitSubExpr(op->getOperand(1), rhsPrec, hasStrictSign); - - // If we have a strict sign, then match the firrtl operation sign. - // Otherwise, the result is signed if both operands are signed. - SubExprSignedness signedness; - if (hasStrictSign) - signedness = getSignednessOf(op->getResult(0).getType()); - else if (lhsInfo.signedness == IsSigned && rhsInfo.signedness == IsSigned) - signedness = IsSigned; - else - signedness = IsUnsigned; - - return {prec, signedness}; - } + const char *syntax, bool hasStrictSign = false); /// Emit the specified subexpression in a context where the sign matters, /// e.g. for a less than comparison or divide. @@ -841,6 +819,20 @@ private: // Conversion to/from standard integer types is a noop. SubExprInfo visitExpr(StdIntCast op) { return emitNoopCast(op); } + // RTL Dialect Operations + using CombinatorialVisitor::visitComb; + SubExprInfo visitComb(rtl::ConstantOp op); + SubExprInfo visitComb(rtl::AddOp op) { return emitBinary(op, Addition, "+"); } + SubExprInfo visitComb(rtl::SubOp op) { return emitBinary(op, Addition, "-"); } + SubExprInfo visitComb(rtl::MulOp op) { return emitBinary(op, Multiply, "*"); } + SubExprInfo visitComb(rtl::DivOp op) { + return emitSignedBinary(op, Multiply, "/"); + } + SubExprInfo visitComb(rtl::AndOp op) { return emitBinary(op, And, "&"); } + SubExprInfo visitComb(rtl::OrOp op) { return emitBinary(op, Or, "|"); } + SubExprInfo visitComb(rtl::XorOp op) { return emitBinary(op, Xor, "^"); } + SubExprInfo visitComb(rtl::ConcatOp op); + private: SmallPtrSet &emittedExprs; SmallString<128> resultBuffer; @@ -869,6 +861,36 @@ std::string ExprEmitter::emitExpressionToString(Value exp, return std::string(resultBuffer.begin(), resultBuffer.end()); } +SubExprInfo ExprEmitter::emitBinary(Operation *op, VerilogPrecedence prec, + const char *syntax, bool hasStrictSign) { + auto lhsInfo = emitSubExpr(op->getOperand(0), prec, hasStrictSign); + os << ' ' << syntax << ' '; + + // The precedence of the RHS operand must be tighter than this operator if + // they have a different opcode in order to handle things like "x-(a+b)". + // This isn't needed on the LHS, because the relevant Verilog operators are + // left-associative. + // + auto *rhsOperandOp = lookThroughNoopCasts(op->getOperand(1)).getDefiningOp(); + auto rhsPrec = VerilogPrecedence(prec - 1); + if (rhsOperandOp && op->getName() == rhsOperandOp->getName()) + rhsPrec = prec; + + auto rhsInfo = emitSubExpr(op->getOperand(1), rhsPrec, hasStrictSign); + + // If we have a strict sign, then match the firrtl operation sign. + // Otherwise, the result is signed if both operands are signed. + SubExprSignedness signedness; + if (hasStrictSign) + signedness = getSignednessOf(op->getResult(0).getType()); + else if (lhsInfo.signedness == IsSigned && rhsInfo.signedness == IsSigned) + signedness = IsSigned; + else + signedness = IsUnsigned; + + return {prec, signedness}; +} + /// Emit the specified value as a subexpression to the stream. SubExprInfo ExprEmitter::emitSubExpr(Value exp, VerilogPrecedence parenthesizeIfLooserThan, @@ -967,6 +989,15 @@ SubExprInfo ExprEmitter::emitCat(ArrayRef values, StringRef before, return {Unary, IsUnsigned}; } +SubExprInfo ExprEmitter::visitComb(rtl::ConcatOp op) { + os << '{'; + llvm::interleaveComma(op.getOperands(), os, + [&](Value v) { emitSubExpr(v, LowestPrecedence); }); + + os << '}'; + return {Unary, IsUnsigned}; +} + /// Emit a verilog bit selection operation like x[4:0], the bit numbers are /// inclusive like verilog. /// @@ -1001,7 +1032,7 @@ SubExprInfo ExprEmitter::visitExpr(firrtl::ConstantOp op) { return {Unary, resType.isSigned() ? IsSigned : IsUnsigned}; } -SubExprInfo ExprEmitter::visitExpr(mlir::ConstantIntOp op) { +SubExprInfo ExprEmitter::visitComb(rtl::ConstantOp op) { auto resType = op.getType().cast(); os << resType.getWidth() << '\''; if (resType.isSigned()) @@ -1126,14 +1157,6 @@ SubExprInfo ExprEmitter::visitUnhandledExpr(Operation *op) { return {Symbol, IsUnsigned}; } -// This handles dispatching to non-FIRRTL operations. -SubExprInfo ExprEmitter::visitInvalidExpr(Operation *op) { - if (auto cst = dyn_cast(op)) - return visitExpr(cst); - - return visitUnhandledExpr(op); -} - //===----------------------------------------------------------------------===// // Statements //===----------------------------------------------------------------------===// @@ -1164,7 +1187,7 @@ void ModuleEmitter::emitStatementExpression(Operation *op) { if (op->getResult(0).use_empty()) { indent() << "// Unused: "; } else if (emitInlineWireDecls) { - auto type = op->getResult(0).getType().cast(); + auto type = op->getResult(0).getType(); indent() << "wire "; if (getBitWidthOrSentinel(type) != 1) { @@ -1586,7 +1609,7 @@ static bool isExpressionUnableToInline(Operation *op) { /// Return true for operations that are always inlined. static bool isExpressionAlwaysInline(Operation *op) { - if (isa(op) || isa(op) || + if (isa(op) || isa(op) || isa(op)) return true; diff --git a/test/EmitVerilog/verilog-rtl-dialect.mlir b/test/EmitVerilog/verilog-rtl-dialect.mlir new file mode 100644 index 0000000000..9688eae29d --- /dev/null +++ b/test/EmitVerilog/verilog-rtl-dialect.mlir @@ -0,0 +1,29 @@ +// RUN: cirt-translate %s -emit-verilog -verify-diagnostics | FileCheck %s --strict-whitespace + +firrtl.circuit "Circuit" { + firrtl.module @M1(%x : !firrtl.uint<8>, + %y : !firrtl.flip>, + %z : i8) { + %c42 = rtl.constant (42 : i8) : i8 + %c5 = rtl.constant (5 : i8) : i8 + %a = rtl.add %z, %c42 : i8 + %b = rtl.mul %a, %c5 : i8 + %c = firrtl.stdIntCast %b : (i8) -> !firrtl.uint<8> + firrtl.connect %y, %c : !firrtl.flip>, !firrtl.uint<8> + + %d = rtl.mul %z, %z : i8 + %e = rtl.concat %d, %z, %d : (i8, i8, i8) -> i8 + %f = firrtl.stdIntCast %e : (i8) -> !firrtl.uint<8> + firrtl.connect %y, %f : !firrtl.flip>, !firrtl.uint<8> + } + + // CHECK-LABEL: module M1( + // CHECK-NEXT: input [7:0] x, + // CHECK-NEXT: output [7:0] y, + // CHECK-NEXT: input [7:0] z); + // CHECK-EMPTY: + // CHECK-NEXT: assign y = (z + 8'h2A) * 8'h5; + // CHECK-NEXT: wire [7:0] _T = z * z; + // CHECK-NEXT: assign y = {_T, z, _T}; + // CHECK-NEXT: endmodule +} diff --git a/test/EmitVerilog/verilog-weird.mlir b/test/EmitVerilog/verilog-weird.mlir index c1b733197e..0b4c85fdfb 100644 --- a/test/EmitVerilog/verilog-weird.mlir +++ b/test/EmitVerilog/verilog-weird.mlir @@ -29,7 +29,7 @@ firrtl.circuit "Circuit" { %c42_ui8 = firrtl.constant(42 : ui8) : !firrtl.uint<8> firrtl.connect %y, %c42_ui8 : !firrtl.flip>, !firrtl.uint<8> - %c42 = constant 42 : i8 + %c42 = rtl.constant (42: i8) : i8 %a = firrtl.stdIntCast %c42 : (i8) -> !firrtl.uint<8> firrtl.connect %y, %a : !firrtl.flip>, !firrtl.uint<8> diff --git a/tools/cirt-translate/cirt-translate.cpp b/tools/cirt-translate/cirt-translate.cpp index 8d141e742f..e017ffaa8b 100644 --- a/tools/cirt-translate/cirt-translate.cpp +++ b/tools/cirt-translate/cirt-translate.cpp @@ -6,6 +6,7 @@ //===----------------------------------------------------------------------===// #include "cirt/Dialect/FIRRTL/Dialect.h" +#include "cirt/Dialect/RTL/Dialect.h" #include "cirt/EmitVerilog.h" #include "cirt/FIRParser.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" @@ -48,6 +49,9 @@ int main(int argc, char **argv) { registerMLIRContextCLOptions(); registerDialect(); + // RTL. + registerDialect(); + // Register FIRRTL stuff. registerDialect(); registerFIRParserTranslation(); diff --git a/tools/firtool/firtool.cpp b/tools/firtool/firtool.cpp index f00687c584..9fab095f7f 100644 --- a/tools/firtool/firtool.cpp +++ b/tools/firtool/firtool.cpp @@ -6,6 +6,7 @@ //===----------------------------------------------------------------------===// #include "cirt/Dialect/FIRRTL/Dialect.h" +#include "cirt/Dialect/RTL/Dialect.h" #include "cirt/EmitVerilog.h" #include "cirt/FIRParser.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" @@ -120,8 +121,9 @@ int main(int argc, char **argv) { registerMLIRContextCLOptions(); registerPassManagerCLOptions(); - // Register FIRRTL stuff. + // Register our dialects. registerDialect(); + registerDialect(); // Parse pass names in main to ensure static initialization completed. cl::ParseCommandLineOptions(argc, argv, "cirt modular optimizer driver\n");