[HW] Canonicalize array get on uniform arrays (#3842)

Add a `getUniformElement()` helper function to `ArrayCreateOp` which
checks if all elements of the created array are identical and returns
that uniform value. Add a folder for `ArrayGetOp` that simply forwards
this uniform value if it exists.

Fixes #3841.
This commit is contained in:
Fabian Schuiki 2022-09-08 14:30:30 -07:00 committed by GitHub
parent 44793691b4
commit 4d92ac1aa2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 28 additions and 1 deletions

View File

@ -39,6 +39,14 @@ def ArrayCreateOp : HWOp<"array_create", [NoSideEffect, SameTypeOperands]> {
// ValueRange needs to contain at least one element.
OpBuilder<(ins "ValueRange":$input)>
];
let extraClassDeclaration = [{
/// If the all elements of the array are identical, returns that element
/// value. Otherwise returns a null value.
Value getUniformElement();
/// Returns true if all array elements are identical.
bool isUniform() { return !!getUniformElement(); }
}];
}
def ArrayConcatOp : HWOp<"array_concat", [NoSideEffect]> {

View File

@ -1710,6 +1710,12 @@ LogicalResult ArrayCreateOp::verify() {
return success();
}
Value ArrayCreateOp::getUniformElement() {
if (!getInputs().empty() && llvm::all_equal(getInputs()))
return getInputs()[0];
return {};
}
static Optional<uint64_t> getUIntFromValue(Value value) {
auto idxOp = dyn_cast_or_null<ConstantOp>(value.getDefiningOp());
if (!idxOp)
@ -2287,13 +2293,18 @@ void ArrayGetOp::build(OpBuilder &builder, OperationState &result, Value input,
}
// An array_get of an array_create with a constant index can just be the
// array_create operand at the constant index.
// array_create operand at the constant index. If the array_create has a single
// uniform value for each element, just return that value regardless of the
// index.
OpFoldResult ArrayGetOp::fold(ArrayRef<Attribute> operands) {
auto inputCreate =
dyn_cast_or_null<ArrayCreateOp>(getInput().getDefiningOp());
if (!inputCreate)
return {};
if (auto uniformValue = inputCreate.getUniformElement())
return uniformValue;
IntegerAttr constIdx = operands[1].dyn_cast_or_null<IntegerAttr>();
if (!constIdx || constIdx.getValue().getBitWidth() > 64)
return {};

View File

@ -1467,3 +1467,11 @@ hw.module @SliceOfCreate(%a0: i1, %a1: i1, %a2: i1, %a3: i1) -> (out: !hw.array<
hw.output %slice : !hw.array<2xi1>
}
// CHECK-LABEL: hw.module @GetOfUniformArray
hw.module @GetOfUniformArray(%in: i42, %address: i2) -> (out: i42) {
// CHECK: hw.output %in : i42
%0 = hw.array_create %in, %in, %in, %in : i42
%1 = hw.array_get %0[%address] : !hw.array<4xi42>
hw.output %1 : i42
}