[mlir][bufferize][NFC] Refactor createAlloc function signature

Pass a ValueRange instead of an ArrayRef<Value> for better compatibility. Also provide an additional function overload that automatically deallocates the buffer if specified.

Differential Revision: https://reviews.llvm.org/D118025
This commit is contained in:
Matthias Springer 2022-01-24 20:18:40 +09:00
parent e5147f82e1
commit b2499bf3e8
3 changed files with 37 additions and 7 deletions

View File

@ -36,8 +36,8 @@ class BufferizationState;
/// Options for ComprehensiveBufferize.
struct BufferizationOptions {
using AllocationFn = std::function<FailureOr<Value>(
OpBuilder &, Location, MemRefType, ArrayRef<Value>)>;
using AllocationFn = std::function<FailureOr<Value>(OpBuilder &, Location,
MemRefType, ValueRange)>;
using DeallocationFn =
std::function<LogicalResult(OpBuilder &, Location, Value)>;
using MemCpyFn =
@ -298,15 +298,23 @@ UnrankedMemRefType getUnrankedMemRefType(Type elementType,
MemRefType getDynamicMemRefType(RankedTensorType tensorType,
unsigned addressSpace = 0);
/// Creates a memref allocation.
/// Creates a memref allocation with the given type and dynamic extents.
FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
ArrayRef<Value> dynShape,
ValueRange dynShape,
const BufferizationOptions &options);
/// Creates a memref allocation with the given type and dynamic extents. If
/// `createDealloc`, a deallocation op is inserted at the point where the
/// allocation goes out of scope.
FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
ValueRange dynShape, bool deallocMemref,
const BufferizationOptions &options);
/// Creates a memref allocation for the given shaped value. This function may
/// perform additional optimizations such as buffer allocation hoisting. If
/// `createDealloc`, a deallocation op is inserted at the point where the
/// allocation goes out of scope.
// TODO: Allocation hoisting should be a cleanup pass.
FailureOr<Value> createAlloc(OpBuilder &b, Location loc, Value shapedValue,
bool deallocMemref,
const BufferizationOptions &options);

View File

@ -433,10 +433,10 @@ bufferization::createAlloc(OpBuilder &b, Location loc, Value shapedValue,
return casted;
}
/// Create a memref allocation.
/// Create a memref allocation with the given type and dynamic extents.
FailureOr<Value>
bufferization::createAlloc(OpBuilder &b, Location loc, MemRefType type,
ArrayRef<Value> dynShape,
ValueRange dynShape,
const BufferizationOptions &options) {
if (options.allocationFn)
return (*options.allocationFn)(b, loc, type, dynShape);
@ -447,6 +447,28 @@ bufferization::createAlloc(OpBuilder &b, Location loc, MemRefType type,
return allocated;
}
/// Create a memref allocation with the given type and dynamic extents. May also
/// deallocate the memref again.
FailureOr<Value>
bufferization::createAlloc(OpBuilder &b, Location loc, MemRefType type,
ValueRange dynShape, bool deallocMemref,
const BufferizationOptions &options) {
OpBuilder::InsertionGuard g(b);
FailureOr<Value> alloc = createAlloc(b, loc, type, dynShape, options);
if (failed(alloc))
return failure();
if (deallocMemref) {
// Dealloc at the end of the block.
b.setInsertionPoint(alloc.getValue().getParentBlock()->getTerminator());
if (failed(createDealloc(b, loc, *alloc, options)))
return failure();
}
return alloc;
}
/// Create a memref deallocation.
LogicalResult
bufferization::createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer,

View File

@ -73,7 +73,7 @@ static void applyEnablingTransformations(ModuleOp moduleOp) {
static FailureOr<Value> allocationFnUsingAlloca(OpBuilder &b, Location loc,
MemRefType type,
ArrayRef<Value> dynShape) {
ValueRange dynShape) {
Value allocated = b.create<memref::AllocaOp>(
loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments));
return allocated;