Revert "[FIRRTL] Add multibit_mux op and lowerings (#2392)" (#2472)

This reverts commit 2eea30198f because it is 
causing assertion failure in an internal design.
This commit is contained in:
uenoku 2022-01-18 15:03:34 +09:00 committed by GitHub
parent d6a959d91e
commit e389ffa605
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 53 additions and 193 deletions

View File

@ -242,26 +242,6 @@ def SubaccessOp : FIRRTLExprOp<"subaccess", [
let hasCanonicalizeMethod = true;
}
def MultibitMuxOp : FIRRTLExprOp<"multibit_mux"> {
let summary = "Multibit multiplexer";
let description = [{
The multibit mux expression dynamically selects operands. The
index must be an expression with an unsigned integer type.
```
%result = firrtl.multibit_mux %index, %0, %1, %2, ... : t1, t2
```
}];
let arguments = (ins FIRRTLType:$index, Variadic<FIRRTLType>:$inputs);
let results = (outs FIRRTLType:$result);
// Need a custom parser/printer to emit operands type only once.
let parser = "return parse$cppClass(parser, result);";
let printer = "print$cppClass(p, *this);";
let hasFolder = true;
let hasCanonicalizeMethod = true;
}
//===----------------------------------------------------------------------===//
// Primitive Operations
//===----------------------------------------------------------------------===//

View File

@ -31,7 +31,7 @@ public:
// Basic Expressions
.template Case<
ConstantOp, SpecialConstantOp, InvalidValueOp, SubfieldOp,
SubindexOp, SubaccessOp, MultibitMuxOp,
SubindexOp, SubaccessOp,
// Arithmetic and Logical Binary Primitives.
AddPrimOp, SubPrimOp, MulPrimOp, DivPrimOp, RemPrimOp, AndPrimOp,
OrPrimOp, XorPrimOp,
@ -89,7 +89,6 @@ public:
HANDLE(SubfieldOp, Unhandled);
HANDLE(SubindexOp, Unhandled);
HANDLE(SubaccessOp, Unhandled);
HANDLE(MultibitMuxOp, Unhandled);
// Arithmetic and Logical Binary Primitives.
HANDLE(AddPrimOp, Binary);

View File

@ -1400,7 +1400,6 @@ struct FIRRTLLowering : public FIRRTLVisitor<FIRRTLLowering, LogicalResult> {
}
LogicalResult visitExpr(TailPrimOp op);
LogicalResult visitExpr(MuxPrimOp op);
LogicalResult visitExpr(MultibitMuxOp op);
LogicalResult visitExpr(VerbatimExprOp op);
// Statements
@ -3158,27 +3157,6 @@ LogicalResult FIRRTLLowering::visitExpr(MuxPrimOp op) {
ifFalse);
}
LogicalResult FIRRTLLowering::visitExpr(MultibitMuxOp op) {
// Lower and resize to the index width.
auto index = getLoweredAndExtOrTruncValue(
op.index(), UIntType::get(op.getContext(),
getBitWidthFromVectorSize(op.inputs().size())));
if (!index)
return failure();
SmallVector<Value> loweredInputs;
loweredInputs.reserve(op.inputs().size());
for (auto input : op.inputs()) {
auto lowered = getLoweredAndExtendedValue(input, op.getType());
if (!lowered)
return failure();
loweredInputs.push_back(lowered);
}
Value array = builder.create<hw::ArrayCreateOp>(loweredInputs);
return setLoweringTo<hw::ArrayGetOp>(op, array, index);
}
LogicalResult FIRRTLLowering::visitExpr(VerbatimExprOp op) {
auto resultTy = lowerType(op.getType());
if (!resultTy)

View File

@ -1360,45 +1360,6 @@ LogicalResult SubaccessOp::canonicalize(SubaccessOp op,
});
}
OpFoldResult MultibitMuxOp::fold(ArrayRef<Attribute> operands) {
// If there is only one input, just return it.
if (operands.size() == 2)
return getOperand(1);
if (auto constIndex = getConstant(operands[0])) {
auto index = constIndex->getExtValue();
// operands[0] is index so (index + 1) is the index we want.
if (index >= 0 && index + 1 < static_cast<int64_t>(operands.size()))
return getOperand(index + 1);
}
return {};
}
LogicalResult MultibitMuxOp::canonicalize(MultibitMuxOp op,
PatternRewriter &rewriter) {
// If all operands are equal, just canonicalize to it. We can add this
// canonicalization as a folder but it costly to look through all inputs so it
// is added here.
if (llvm::all_of(op.inputs().drop_front(),
[&](auto input) { return input == op.inputs().front(); })) {
rewriter.replaceOp(op, op.inputs().front());
return success();
}
// If the size is 2, canonicalize into a normal mux to introduce more folds.
if (op.inputs().size() != 2)
return failure();
// multibit_mux(index, {lhs, rhs}) -> mux(index==0, lhs, rhs)
Value zero = rewriter.create<ConstantOp>(
op.getLoc(), op.index().getType().cast<IntType>(), APInt(1, 0));
Value cond = rewriter.createOrFold<EQPrimOp>(op.getLoc(), op.index(), zero);
rewriter.replaceOpWithNewOp<MuxPrimOp>(op, cond, op.inputs()[0],
op.inputs()[1]);
return success();
}
//===----------------------------------------------------------------------===//
// Declarations
//===----------------------------------------------------------------------===//

View File

@ -2189,55 +2189,6 @@ FIRRTLType SubaccessOp::inferReturnType(ValueRange operands,
return {};
}
static ParseResult parseMultibitMuxOp(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::OperandType index;
llvm::SmallVector<OpAsmParser::OperandType, 16> inputs;
Type indexType, elemType;
if (parser.parseOperand(index) || parser.parseComma() ||
parser.parseOperandList(inputs) ||
parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
parser.parseType(indexType) || parser.parseComma() ||
parser.parseType(elemType))
return failure();
if (parser.resolveOperand(index, indexType, result.operands))
return failure();
result.addTypes(elemType);
return parser.resolveOperands(inputs, elemType, result.operands);
}
static void printMultibitMuxOp(OpAsmPrinter &p, MultibitMuxOp op) {
p << " " << op.index() << ", ";
p.printOperands(op.inputs());
p.printOptionalAttrDict(op->getAttrs());
p << " : " << op.index().getType() << ", " << op.getType();
}
FIRRTLType MultibitMuxOp::inferReturnType(ValueRange operands,
ArrayRef<NamedAttribute> attrs,
Optional<Location> loc) {
if (operands.size() < 2) {
if (loc)
mlir::emitError(*loc, "at least one input is required");
return FIRRTLType();
}
// Check all mux inputs have the same type.
if (!llvm::all_of(operands.drop_front(2), [&](auto op) {
return operands[1].getType() == op.getType();
})) {
if (loc)
mlir::emitError(*loc, "all inputs must have the same type");
return FIRRTLType();
}
return operands[1].getType().cast<FIRRTLType>();
}
//===----------------------------------------------------------------------===//
// Binary Primitives
//===----------------------------------------------------------------------===//

View File

@ -339,7 +339,6 @@ struct TypeLoweringVisitor : public FIRRTLVisitor<TypeLoweringVisitor, bool> {
bool visitDecl(RegResetOp op);
bool visitExpr(InvalidValueOp op);
bool visitExpr(SubaccessOp op);
bool visitExpr(MultibitMuxOp op);
bool visitExpr(MuxPrimOp op);
bool visitExpr(mlir::UnrealizedConversionCastOp op);
bool visitExpr(BitCastOp op);
@ -1276,31 +1275,24 @@ bool TypeLoweringVisitor::visitExpr(SubaccessOp op) {
return true;
}
// Construct a multibit mux
SmallVector<Value> inputs;
inputs.reserve(vType.getNumElements());
for (unsigned index : llvm::seq(0u, vType.getNumElements()))
inputs.push_back(builder->create<SubindexOp>(input, index));
// Reads. All writes have been eliminated before now
auto selectWidth = llvm::Log2_64_Ceil(vType.getNumElements());
Value multibitMux = builder->create<MultibitMuxOp>(op.index(), inputs);
op.replaceAllUsesWith(multibitMux);
// We have at least one element
Value mux = builder->create<SubindexOp>(input, 0);
// Build up the mux
for (size_t index = 1, e = vType.getNumElements(); index < e; ++index) {
auto cond = builder->create<EQPrimOp>(
op.index(), builder->createOrFold<ConstantOp>(
UIntType::get(op.getContext(), selectWidth),
APInt(selectWidth, index)));
auto access = builder->create<SubindexOp>(input, index);
mux = builder->create<MuxPrimOp>(cond, access, mux);
}
op.replaceAllUsesWith(mux);
return true;
}
bool TypeLoweringVisitor::visitExpr(MultibitMuxOp op) {
auto clone = [&](FlatBundleFieldEntry field, StringRef name,
ArrayAttr attrs) -> Operation * {
SmallVector<Value> newInputs;
newInputs.reserve(op.inputs().size());
for (auto input : op.inputs()) {
auto inputSub = getSubWhatever(input, field.index);
newInputs.push_back(inputSub);
}
return builder->create<MultibitMuxOp>(op.index(), newInputs);
};
return lowerProducer(op, clone);
}
//===----------------------------------------------------------------------===//
// Pass Infrastructure
//===----------------------------------------------------------------------===//

View File

@ -315,15 +315,10 @@ firrtl.circuit "Simple" attributes {annotations = [{class =
// CHECK: [[SEXT:%.+]] = comb.concat {{.*}}, %in3 : i1, i8
// CHECK: = comb.sub %c0_i9, [[SEXT]] : i9
%54 = firrtl.neg %in3 : (!firrtl.sint<8>) -> !firrtl.sint<9>
// CHECK: hw.output %false, %false : i1, i1
firrtl.connect %out1, %53 : !firrtl.sint<1>, !firrtl.sint<1>
%55 = firrtl.neg %in5 : (!firrtl.sint<0>) -> !firrtl.sint<1>
%61 = firrtl.multibit_mux %17, %55, %55, %55 : !firrtl.uint<1>, !firrtl.sint<1>
// CHECK: %[[ZEXT_INDEX:.+]] = comb.concat %false, {{.*}} : i1, i1
// CHECK-NEXT: %[[ARRAY:.+]] = hw.array_create %false, %false, %false
// CHECK-NEXT: %[[ARRAY_GET:.+]] = hw.array_get %[[ARRAY]][%[[ZEXT_INDEX]]]
// CHECK: hw.output %false, %[[ARRAY_GET]] : i1, i1
firrtl.connect %out2, %61 : !firrtl.sint<1>, !firrtl.sint<1>
firrtl.connect %out2, %55 : !firrtl.sint<1>, !firrtl.sint<1>
}
// module Print :

View File

@ -440,8 +440,6 @@ firrtl.module @Head(in %in4u: !firrtl.uint<4>,
// CHECK-LABEL: firrtl.module @Mux
firrtl.module @Mux(in %in: !firrtl.uint<4>,
in %cond: !firrtl.uint<1>,
in %val1: !firrtl.uint<1>,
in %val2: !firrtl.uint<1>,
out %out: !firrtl.uint<4>,
out %out1: !firrtl.uint<1>) {
// CHECK: firrtl.connect %out, %in
@ -468,20 +466,6 @@ firrtl.module @Mux(in %in: !firrtl.uint<4>,
%invalid_ui1 = firrtl.invalidvalue : !firrtl.uint<1>
%8 = firrtl.mux (%invalid_ui1, %in, %c7_ui4) : (!firrtl.uint<1>, !firrtl.uint<4>, !firrtl.uint<4>) -> !firrtl.uint<4>
firrtl.connect %out, %8 : !firrtl.uint<4>, !firrtl.uint<4>
%9 = firrtl.multibit_mux %c1_ui1, %c0_ui1, %cond : !firrtl.uint<1>, !firrtl.uint<1>
// CHECK-NEXT: firrtl.connect %out1, %cond
firrtl.connect %out1, %9 : !firrtl.uint<1>, !firrtl.uint<1>
%10 = firrtl.multibit_mux %cond, %val1, %val2 : !firrtl.uint<1>, !firrtl.uint<1>
// CHECK-NEXT: %[[NOT_COND:.+]] = firrtl.not %cond
// CHECK-NEXT: %[[MUX:.+]] = firrtl.mux(%0, %val1, %val2)
// CHECK-NEXT: firrtl.connect %out1, %[[MUX]]
firrtl.connect %out1, %10 : !firrtl.uint<1>, !firrtl.uint<1>
%11 = firrtl.multibit_mux %cond, %val1, %val1, %val1 : !firrtl.uint<1>, !firrtl.uint<1>
// CHECK-NEXT: firrtl.connect %out1, %val1
firrtl.connect %out1, %11 : !firrtl.uint<1>, !firrtl.uint<1>
}
// CHECK-LABEL: firrtl.module @Pad

View File

@ -1026,10 +1026,14 @@ firrtl.circuit "TopLevel" {
}
// CHECK-LABEL: firrtl.module @multidimRead(in %a_0_0: !firrtl.uint<2>, in %a_0_1: !firrtl.uint<2>, in %a_1_0: !firrtl.uint<2>, in %a_1_1: !firrtl.uint<2>, in %sel: !firrtl.uint<2>, out %b: !firrtl.uint<2>) {
// CHECK-NEXT: %0 = firrtl.multibit_mux %sel, %a_0_0, %a_1_0 : !firrtl.uint<2>, !firrtl.uint<2>
// CHECK-NEXT: %1 = firrtl.multibit_mux %sel, %a_0_1, %a_1_1 : !firrtl.uint<2>, !firrtl.uint<2>
// CHECK-NEXT: %2 = firrtl.multibit_mux %sel, %0, %1 : !firrtl.uint<2>, !firrtl.uint<2>
// CHECK-NEXT: firrtl.connect %b, %2 : !firrtl.uint<2>, !firrtl.uint<2>
// CHECK-NEXT: %c1_ui1 = firrtl.constant 1 : !firrtl.uint<1>
// CHECK-NEXT: %0 = firrtl.eq %sel, %c1_ui1 : (!firrtl.uint<2>, !firrtl.uint<1>) -> !firrtl.uint<1>
// CHECK-NEXT: %1 = firrtl.mux(%0, %a_1_0, %a_0_0) : (!firrtl.uint<1>, !firrtl.uint<2>, !firrtl.uint<2>) -> !firrtl.uint<2>
// CHECK-NEXT: %2 = firrtl.mux(%0, %a_1_1, %a_0_1) : (!firrtl.uint<1>, !firrtl.uint<2>, !firrtl.uint<2>) -> !firrtl.uint<2>
// CHECK-NEXT: %c1_ui1_0 = firrtl.constant 1 : !firrtl.uint<1>
// CHECK-NEXT: %3 = firrtl.eq %sel, %c1_ui1_0 : (!firrtl.uint<2>, !firrtl.uint<1>) -> !firrtl.uint<1>
// CHECK-NEXT: %4 = firrtl.mux(%3, %2, %1) : (!firrtl.uint<1>, !firrtl.uint<2>, !firrtl.uint<2>) -> !firrtl.uint<2>
// CHECK-NEXT: firrtl.connect %b, %4 : !firrtl.uint<2>, !firrtl.uint<2>
// CHECK-NEXT: }
// module Foo:
@ -1160,14 +1164,22 @@ firrtl.circuit "TopLevel" {
}
// CHECK-LABEL: firrtl.module @multiSubaccess(in %a_0_0: !firrtl.uint<2>, in %a_0_1: !firrtl.uint<2>, in %a_1_0: !firrtl.uint<2>, in %a_1_1: !firrtl.uint<2>, in %sel1: !firrtl.uint<1>, in %sel2: !firrtl.uint<1>, out %b: !firrtl.uint<2>, out %c: !firrtl.uint<2>) {
// CHECK-NEXT: %0 = firrtl.multibit_mux %sel1, %a_0_0, %a_1_0 : !firrtl.uint<1>, !firrtl.uint<2>
// CHECK-NEXT: %1 = firrtl.multibit_mux %sel1, %a_0_1, %a_1_1 : !firrtl.uint<1>, !firrtl.uint<2>
// CHECK-NEXT: %2 = firrtl.multibit_mux %sel1, %0, %1 : !firrtl.uint<1>, !firrtl.uint<2>
// CHECK-NEXT: firrtl.connect %b, %2 : !firrtl.uint<2>, !firrtl.uint<2>
// CHECK-NEXT: %3 = firrtl.multibit_mux %sel1, %a_0_0, %a_1_0 : !firrtl.uint<1>, !firrtl.uint<2>
// CHECK-NEXT: %4 = firrtl.multibit_mux %sel1, %a_0_1, %a_1_1 : !firrtl.uint<1>, !firrtl.uint<2>
// CHECK-NEXT: %5 = firrtl.multibit_mux %sel2, %3, %4 : !firrtl.uint<1>, !firrtl.uint<2>
// CHECK-NEXT: firrtl.connect %c, %5 : !firrtl.uint<2>, !firrtl.uint<2>
// CHECK-NEXT: %c1_ui1 = firrtl.constant 1 : !firrtl.uint<1>
// CHECK-NEXT: %0 = firrtl.eq %sel1, %c1_ui1 : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
// CHECK-NEXT: %1 = firrtl.mux(%0, %a_1_0, %a_0_0) : (!firrtl.uint<1>, !firrtl.uint<2>, !firrtl.uint<2>) -> !firrtl.uint<2>
// CHECK-NEXT: %2 = firrtl.mux(%0, %a_1_1, %a_0_1) : (!firrtl.uint<1>, !firrtl.uint<2>, !firrtl.uint<2>) -> !firrtl.uint<2>
// CHECK-NEXT: %c1_ui1_0 = firrtl.constant 1 : !firrtl.uint<1>
// CHECK-NEXT: %3 = firrtl.eq %sel1, %c1_ui1_0 : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
// CHECK-NEXT: %4 = firrtl.mux(%3, %2, %1) : (!firrtl.uint<1>, !firrtl.uint<2>, !firrtl.uint<2>) -> !firrtl.uint<2>
// CHECK-NEXT: firrtl.connect %b, %4 : !firrtl.uint<2>, !firrtl.uint<2>
// CHECK-NEXT: %c1_ui1_1 = firrtl.constant 1 : !firrtl.uint<1>
// CHECK-NEXT: %5 = firrtl.eq %sel1, %c1_ui1_1 : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
// CHECK-NEXT: %6 = firrtl.mux(%5, %a_1_0, %a_0_0) : (!firrtl.uint<1>, !firrtl.uint<2>, !firrtl.uint<2>) -> !firrtl.uint<2>
// CHECK-NEXT: %7 = firrtl.mux(%5, %a_1_1, %a_0_1) : (!firrtl.uint<1>, !firrtl.uint<2>, !firrtl.uint<2>) -> !firrtl.uint<2>
// CHECK-NEXT: %c1_ui1_2 = firrtl.constant 1 : !firrtl.uint<1>
// CHECK-NEXT: %8 = firrtl.eq %sel2, %c1_ui1_2 : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
// CHECK-NEXT: %9 = firrtl.mux(%8, %7, %6) : (!firrtl.uint<1>, !firrtl.uint<2>, !firrtl.uint<2>) -> !firrtl.uint<2>
// CHECK-NEXT: firrtl.connect %c, %9 : !firrtl.uint<2>, !firrtl.uint<2>
// CHECK-NEXT: }
@ -1406,10 +1418,18 @@ firrtl.module @is1436_FOO() {
firrtl.module @Issue2315(in %x: !firrtl.vector<uint<10>, 5>, in %source: !firrtl.uint<2>, out %z: !firrtl.uint<10>) {
%0 = firrtl.subaccess %x[%source] : !firrtl.vector<uint<10>, 5>, !firrtl.uint<2>
firrtl.connect %z, %0 : !firrtl.uint<10>, !firrtl.uint<10>
// The width of multibit mux index will be converted at LowerToHW,
// so it is ok that the type of `%source` is uint<2> here.
// CHECK: %0 = firrtl.multibit_mux %source, %x_0, %x_1, %x_2, %x_3, %x_4 : !firrtl.uint<2>, !firrtl.uint<10>
// CHECK-NEXT: firrtl.connect %z, %0 : !firrtl.uint<10>, !firrtl.uint<10>
// CHECK-NEXT: [[IDX:%.+]] = firrtl.constant 1
// CHECK-NEXT: [[EQ:%.+]] = firrtl.eq %source, [[IDX]]
// CHECK-NEXT: firrtl.mux([[EQ]], %x_1, %x_0)
// CHECK-NEXT: [[IDX:%.+]] = firrtl.constant 2
// CHECK-NEXT: [[EQ:%.+]] = firrtl.eq %source, [[IDX]]
// CHECK-NEXT: firrtl.mux([[EQ]], %x_2,
// CHECK-NEXT: [[IDX:%.+]] = firrtl.constant 3
// CHECK-NEXT: [[EQ:%.+]] = firrtl.eq %source, [[IDX]]
// CHECK-NEXT: firrtl.mux([[EQ]], %x_3,
// CHECK-NEXT: [[IDX:%.+]] = firrtl.constant 4
// CHECK-NEXT: [[EQ:%.+]] = firrtl.eq %source, [[IDX]]
// CHECK-NEXT: firrtl.mux([[EQ]], %x_4,
}
} // CIRCUIT