[HandshakeToFIRRTL] Add support for index_cast, zexti and trunci ops (#1865)

zexti and trunci operations generate handshake components which wrap around FIRRTL `pad` and `bits` operations. The builder for each of these ops is used to support `index_cast`, which pads or truncates an input type compared to the fixed index-width (64 bits).
This commit is contained in:
Morten Borup Petersen 2021-09-24 08:15:58 +01:00 committed by GitHub
parent 613ace63c4
commit f5441a2d4a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 174 additions and 16 deletions

View File

@ -96,14 +96,14 @@ public:
ResultType dispatchStdExprVisitor(Operation *op, ExtraArgs... args) {
auto *thisCast = static_cast<ConcreteType *>(this);
return TypeSwitch<Operation *, ResultType>(op)
.template Case<
// Integer binary expressions.
CmpIOp, AddIOp, SubIOp, MulIOp, SignedDivIOp, SignedRemIOp,
UnsignedDivIOp, UnsignedRemIOp, XOrOp, AndOp, OrOp, ShiftLeftOp,
SignedShiftRightOp, UnsignedShiftRightOp>(
[&](auto opNode) -> ResultType {
return thisCast->visitStdExpr(opNode, args...);
})
.template Case<IndexCastOp, ZeroExtendIOp, TruncateIOp,
// Integer binary expressions.
CmpIOp, AddIOp, SubIOp, MulIOp, SignedDivIOp,
SignedRemIOp, UnsignedDivIOp, UnsignedRemIOp, XOrOp,
AndOp, OrOp, ShiftLeftOp, SignedShiftRightOp,
UnsignedShiftRightOp>([&](auto opNode) -> ResultType {
return thisCast->visitStdExpr(opNode, args...);
})
.Default([&](auto opNode) -> ResultType {
return thisCast->visitInvalidOp(op, args...);
});
@ -126,6 +126,10 @@ public:
return static_cast<ConcreteType *>(this)->visitUnhandledOp(op, args...); \
}
HANDLE(IndexCastOp);
HANDLE(ZeroExtendIOp);
HANDLE(TruncateIOp);
// Integer binary expressions.
HANDLE(CmpIOp);
HANDLE(AddIOp);

View File

@ -112,6 +112,16 @@ static Value createConstantOp(FIRRTLType opType, APInt value,
return Value();
}
static Type getHandshakeBundleDataType(BundleType bundle) {
if (auto dataType = bundle.getElementType("data")) {
auto intType = dataType.cast<firrtl::IntType>();
return IntegerType::get(bundle.getContext(), intType.getWidthOrSentinel(),
intType.isSigned() ? IntegerType::Signed
: IntegerType::Unsigned);
} else
return NoneType::get(bundle.getContext());
}
static Type getHandshakeDataType(Operation *op) {
if (auto memOp = dyn_cast<MemoryOp>(op))
return memOp.getMemRefType().getElementType();
@ -121,14 +131,7 @@ static Type getHandshakeDataType(Operation *op) {
// to a bundled FIRRTLType, here we convert it back to a normal data type.
// Is there a better way to do this?
auto type = sinkOp.getOperand().getType().cast<BundleType>();
if (auto dataType = type.getElementType("data")) {
auto intType = dataType.cast<firrtl::IntType>();
return IntegerType::get(type.getContext(), intType.getWidthOrSentinel(),
intType.isSigned() ? IntegerType::Signed
: IntegerType::Unsigned);
} else
return NoneType::get(type.getContext());
return getHandshakeBundleDataType(type);
} else
return op->getResult(0).getType();
}
@ -577,6 +580,9 @@ public:
bool visitInvalidOp(Operation *op) { return false; }
bool visitStdExpr(CmpIOp op);
bool visitStdExpr(ZeroExtendIOp op);
bool visitStdExpr(TruncateIOp op);
bool visitStdExpr(IndexCastOp op);
#define HANDLE(OPTYPE, FIRRTLTYPE) \
bool visitStdExpr(OPTYPE op) { return buildBinaryLogic<FIRRTLTYPE>(), true; }
@ -596,6 +602,9 @@ public:
HANDLE(UnsignedShiftRightOp, DShrPrimOp);
#undef HANDLE
bool buildZeroExtendOp(unsigned dstWidth);
bool buildTruncateOp(unsigned dstWidth);
private:
ValueVectorList portList;
Location insertLoc;
@ -625,6 +634,88 @@ bool StdExprBuilder::visitStdExpr(CmpIOp op) {
llvm_unreachable("invalid CmpIOp");
}
bool StdExprBuilder::buildZeroExtendOp(unsigned dstWidth) {
ValueVector arg0Subfield = portList[0];
ValueVector resultSubfields = portList[1];
Value arg0Valid = arg0Subfield[0];
Value arg0Ready = arg0Subfield[1];
Value arg0Data = arg0Subfield[2];
Value resultValid = resultSubfields[0];
Value resultReady = resultSubfields[1];
Value resultData = resultSubfields[2];
Value resultDataOp =
rewriter.create<PadPrimOp>(insertLoc, arg0Data, dstWidth);
rewriter.create<ConnectOp>(insertLoc, resultData, resultDataOp);
// Generate valid signal.
rewriter.create<ConnectOp>(insertLoc, resultValid, arg0Valid);
// Generate ready signal.
auto argReadyOp = rewriter.create<AndPrimOp>(insertLoc, resultReady.getType(),
resultReady, arg0Valid);
rewriter.create<ConnectOp>(insertLoc, arg0Ready, argReadyOp);
return true;
}
bool StdExprBuilder::buildTruncateOp(unsigned int dstWidth) {
ValueVector arg0Subfield = portList[0];
ValueVector resultSubfields = portList[1];
Value arg0Valid = arg0Subfield[0];
Value arg0Ready = arg0Subfield[1];
Value arg0Data = arg0Subfield[2];
Value resultValid = resultSubfields[0];
Value resultReady = resultSubfields[1];
Value resultData = resultSubfields[2];
Value resultDataOp =
rewriter.create<BitsPrimOp>(insertLoc, arg0Data, dstWidth - 1, 0);
rewriter.create<ConnectOp>(insertLoc, resultData, resultDataOp);
// Generate valid signal.
rewriter.create<ConnectOp>(insertLoc, resultValid, arg0Valid);
// Generate ready signal.
auto argReadyOp = rewriter.create<AndPrimOp>(insertLoc, resultReady.getType(),
resultReady, arg0Valid);
rewriter.create<ConnectOp>(insertLoc, arg0Ready, argReadyOp);
return true;
}
/// Extracts the type of the data-carrying type of opType. If opType is a
/// bundle, getHandshakeBundleDataType extracts the data-carrying type, else,
/// assume that opType itself is the data-carrying type.
static Type getOperandDataType(Type opType) {
if (auto bundleType = opType.dyn_cast<BundleType>(); bundleType)
return getHandshakeBundleDataType(bundleType);
return opType;
}
bool StdExprBuilder::visitStdExpr(ZeroExtendIOp op) {
return buildZeroExtendOp(
getFIRRTLType(getOperandDataType(op.getOperand().getType()))
.getBitWidthOrSentinel());
}
bool StdExprBuilder::visitStdExpr(TruncateIOp op) {
return buildTruncateOp(
getFIRRTLType(getOperandDataType(op.getOperand().getType()))
.getBitWidthOrSentinel());
}
bool StdExprBuilder::visitStdExpr(IndexCastOp op) {
FIRRTLType sourceType =
getFIRRTLType(getOperandDataType(op.getOperand().getType()));
FIRRTLType targetType =
getFIRRTLType(getOperandDataType(op.getResult().getType()));
unsigned targetBits = targetType.getBitWidthOrSentinel();
unsigned sourceBits = sourceType.getBitWidthOrSentinel();
return (targetBits < sourceBits ? buildTruncateOp(targetBits)
: buildZeroExtendOp(targetBits));
}
/// Please refer to simple_addi.mlir test case.
template <typename OpType>
void StdExprBuilder::buildBinaryLogic() {

View File

@ -0,0 +1,63 @@
// RUN: circt-opt -lower-handshake-to-firrtl %s --split-input-file | FileCheck %s
// CHECK: firrtl.module @std_index_cast_1ins_1outs_ui8(in %arg0: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<64>>, out %arg1: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<8>>) {
// CHECK-NEXT: %0 = firrtl.subfield %arg0(0) : (!firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<64>>) -> !firrtl.uint<1>
// CHECK-NEXT: %1 = firrtl.subfield %arg0(1) : (!firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<64>>) -> !firrtl.uint<1>
// CHECK-NEXT: %2 = firrtl.subfield %arg0(2) : (!firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<64>>) -> !firrtl.uint<64>
// CHECK-NEXT: %3 = firrtl.subfield %arg1(0) : (!firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<8>>) -> !firrtl.uint<1>
// CHECK-NEXT: %4 = firrtl.subfield %arg1(1) : (!firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<8>>) -> !firrtl.uint<1>
// CHECK-NEXT: %5 = firrtl.subfield %arg1(2) : (!firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<8>>) -> !firrtl.uint<8>
// CHECK-NEXT: %6 = firrtl.bits %2 7 to 0 : (!firrtl.uint<64>) -> !firrtl.uint<8>
// CHECK-NEXT: firrtl.connect %5, %6 : !firrtl.uint<8>, !firrtl.uint<8>
// CHECK-NEXT: firrtl.connect %3, %0 : !firrtl.uint<1>, !firrtl.uint<1>
// CHECK-NEXT: %7 = firrtl.and %4, %0 : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
// CHECK-NEXT: firrtl.connect %1, %7 : !firrtl.uint<1>, !firrtl.uint<1>
// CHECK-NEXT: }
// CHECK-LABEL: firrtl.module @test_index_cast(
// CHECK-SAME: in %arg0: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<64>>,
// CHECK-SAME: in %arg1: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>,
// CHECK-SAME: out %arg2: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<8>>,
// CHECK-SAME: out %arg3: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>,
// CHECK-SAME: in %clock: !firrtl.clock, in %reset: !firrtl.uint<1>) {
handshake.func @test_index_cast(%arg0: index, %arg1: none, ...) -> (i8, none) {
// CHECK: %inst_arg0, %inst_arg1 = firrtl.instance @std_index_cast_1ins_1outs_ui8 {name = ""} : !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<64>>, !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<8>>
// CHECK-NEXT: firrtl.connect %inst_arg0, %arg0 : !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<64>>, !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<64>>
%0 = index_cast %arg0 : index to i8
// CHECK: firrtl.connect %arg2, %inst_arg1 : !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<8>>, !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<8>>
// CHECK-NEXT: firrtl.connect %arg3, %arg1 : !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>
handshake.return %0, %arg1 : i8, none
}
// -----
// CHECK: firrtl.module @std_index_cast_1ins_1outs_ui64(in %arg0: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<8>>, out %arg1: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<64>>) {
// CHECK-NEXT: %0 = firrtl.subfield %arg0(0) : (!firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<8>>) -> !firrtl.uint<1>
// CHECK-NEXT: %1 = firrtl.subfield %arg0(1) : (!firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<8>>) -> !firrtl.uint<1>
// CHECK-NEXT: %2 = firrtl.subfield %arg0(2) : (!firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<8>>) -> !firrtl.uint<8>
// CHECK-NEXT: %3 = firrtl.subfield %arg1(0) : (!firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<64>>) -> !firrtl.uint<1>
// CHECK-NEXT: %4 = firrtl.subfield %arg1(1) : (!firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<64>>) -> !firrtl.uint<1>
// CHECK-NEXT: %5 = firrtl.subfield %arg1(2) : (!firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<64>>) -> !firrtl.uint<64>
// CHECK-NEXT: %6 = firrtl.pad %2, 64 : (!firrtl.uint<8>) -> !firrtl.uint<64>
// CHECK-NEXT: firrtl.connect %5, %6 : !firrtl.uint<64>, !firrtl.uint<64>
// CHECK-NEXT: firrtl.connect %3, %0 : !firrtl.uint<1>, !firrtl.uint<1>
// CHECK-NEXT: %7 = firrtl.and %4, %0 : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
// CHECK-NEXT: firrtl.connect %1, %7 : !firrtl.uint<1>, !firrtl.uint<1>
// CHECK-NEXT: }
// CHECK-LABEL: firrtl.module @test_index_cast2(
// CHECK-SAME: in %arg0: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<8>>,
// CHECK-SAME: in %arg1: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>,
// CHECK-SAME: out %arg2: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<64>>,
// CHECK-SAME: out %arg3: !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>,
// CHECK-SAME: in %clock: !firrtl.clock, in %reset: !firrtl.uint<1>) {
handshake.func @test_index_cast2(%arg0: i8, %arg1: none, ...) -> (index, none) {
// CHECK: %inst_arg0, %inst_arg1 = firrtl.instance @std_index_cast_1ins_1outs_ui64 {name = ""} : !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<8>>, !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<64>>
// CHECK-NEXT: firrtl.connect %inst_arg0, %arg0 : !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<8>>, !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<8>>
%0 = index_cast %arg0 : i8 to index
// CHECK: firrtl.connect %arg2, %inst_arg1 : !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<64>>, !firrtl.bundle<valid: uint<1>, ready flip: uint<1>, data: uint<64>>
// CHECK-NEXT: firrtl.connect %arg3, %arg1 : !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>, !firrtl.bundle<valid: uint<1>, ready flip: uint<1>>
handshake.return %0, %arg1 : index, none
}