[MLIR] Add affine.load fold hook on global constant memrefs
Fold affine.load ops on global constant memrefs when indices are all constant. Reviewed By: ayzhuang Differential Revision: https://reviews.llvm.org/D120612
This commit is contained in:
parent
f02550bdd9
commit
54691a58db
|
@ -825,7 +825,7 @@ def MemRef_GlobalOp : MemRef_Op<"global", [Symbol]> {
|
|||
The `memref.global` operation declares or defines a named global memref
|
||||
variable. The backing memory for the variable is allocated statically and is
|
||||
described by the type of the variable (which should be a statically shaped
|
||||
memref type). The operation is a declaration if no `inital_value` is
|
||||
memref type). The operation is a declaration if no `initial_value` is
|
||||
specified, else it is a definition. The `initial_value` can either be a unit
|
||||
attribute to represent a definition of an uninitialized global variable, or
|
||||
an elements attribute to represent the definition of a global variable with
|
||||
|
@ -878,6 +878,9 @@ def MemRef_GlobalOp : MemRef_Op<"global", [Symbol]> {
|
|||
bool isUninitialized() {
|
||||
return !isExternal() && initial_value().getValue().isa<UnitAttr>();
|
||||
}
|
||||
/// Returns the constant initial value if the memref.global is a constant,
|
||||
/// or null otherwise.
|
||||
ElementsAttr getConstantInitValue();
|
||||
}];
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
|
|
@ -8,17 +8,13 @@
|
|||
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/IntegerSet.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Transforms/InliningUtils.h"
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/ADT/SmallBitVector.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
@ -2400,7 +2396,30 @@ OpFoldResult AffineLoadOp::fold(ArrayRef<Attribute> cstOperands) {
|
|||
/// load(memrefcast) -> load
|
||||
if (succeeded(foldMemRefCast(*this)))
|
||||
return getResult();
|
||||
return OpFoldResult();
|
||||
|
||||
// Fold load from a global constant memref.
|
||||
auto getGlobalOp = memref().getDefiningOp<memref::GetGlobalOp>();
|
||||
if (!getGlobalOp)
|
||||
return {};
|
||||
// Get to the memref.global defining the symbol.
|
||||
auto *symbolTableOp = getGlobalOp->getParentWithTrait<OpTrait::SymbolTable>();
|
||||
if (!symbolTableOp)
|
||||
return {};
|
||||
auto global = dyn_cast_or_null<memref::GlobalOp>(
|
||||
SymbolTable::lookupSymbolIn(symbolTableOp, getGlobalOp.nameAttr()));
|
||||
if (!global)
|
||||
return {};
|
||||
if (auto cstAttr =
|
||||
global.getConstantInitValue().dyn_cast_or_null<DenseElementsAttr>()) {
|
||||
// We can fold only if we know the indices.
|
||||
if (!getAffineMap().isConstant())
|
||||
return {};
|
||||
auto indices = llvm::to_vector<4>(
|
||||
llvm::map_range(getAffineMap().getConstantResults(),
|
||||
[](int64_t v) -> uint64_t { return v; }));
|
||||
return cstAttr.getValues<Attribute>()[indices];
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -1275,6 +1275,13 @@ LogicalResult GlobalOp::verify() {
|
|||
return success();
|
||||
}
|
||||
|
||||
ElementsAttr GlobalOp::getConstantInitValue() {
|
||||
auto initVal = initial_value();
|
||||
if (constant() && initVal.hasValue())
|
||||
return initVal.getValue().cast<ElementsAttr>();
|
||||
return {};
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GetGlobalOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -1080,3 +1080,21 @@ func @canonicalize_single_min_max(%i0: index, %i1: index) -> (index, index) {
|
|||
|
||||
return %0, %1: index, index
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module {
|
||||
memref.global "private" constant @__constant_1x5x1xf32 : memref<1x5x1xf32> = dense<[[[6.250000e-02], [2.500000e-01], [3.750000e-01], [2.500000e-01], [6.250000e-02]]]>
|
||||
// CHECK-LABEL: func @fold_const_init_global_memref
|
||||
func @fold_const_init_global_memref() -> (f32, f32) {
|
||||
%m = memref.get_global @__constant_1x5x1xf32 : memref<1x5x1xf32>
|
||||
%v0 = affine.load %m[0, 0, 0] : memref<1x5x1xf32>
|
||||
%v1 = affine.load %m[0, 1, 0] : memref<1x5x1xf32>
|
||||
return %v0, %v1 : f32, f32
|
||||
// CHECK-DAG: %[[C0:.*]] = arith.constant 6.250000e-02 : f32
|
||||
// CHECK-DAG: %[[C1:.*]] = arith.constant 2.500000e-01 : f32
|
||||
// CHECK-NEXT: return
|
||||
// CHECK-SAME: %[[C0]]
|
||||
// CHECK-SAME: %[[C1]]
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue