[mlir][bufferize][NFC] Refactor createAlloc function signature
Pass a ValueRange instead of an ArrayRef<Value> for better compatibility. Also provide an additional function overload that automatically deallocates the buffer if specified. Differential Revision: https://reviews.llvm.org/D118025
This commit is contained in:
parent
e5147f82e1
commit
b2499bf3e8
|
@ -36,8 +36,8 @@ class BufferizationState;
|
||||||
|
|
||||||
/// Options for ComprehensiveBufferize.
|
/// Options for ComprehensiveBufferize.
|
||||||
struct BufferizationOptions {
|
struct BufferizationOptions {
|
||||||
using AllocationFn = std::function<FailureOr<Value>(
|
using AllocationFn = std::function<FailureOr<Value>(OpBuilder &, Location,
|
||||||
OpBuilder &, Location, MemRefType, ArrayRef<Value>)>;
|
MemRefType, ValueRange)>;
|
||||||
using DeallocationFn =
|
using DeallocationFn =
|
||||||
std::function<LogicalResult(OpBuilder &, Location, Value)>;
|
std::function<LogicalResult(OpBuilder &, Location, Value)>;
|
||||||
using MemCpyFn =
|
using MemCpyFn =
|
||||||
|
@ -298,15 +298,23 @@ UnrankedMemRefType getUnrankedMemRefType(Type elementType,
|
||||||
MemRefType getDynamicMemRefType(RankedTensorType tensorType,
|
MemRefType getDynamicMemRefType(RankedTensorType tensorType,
|
||||||
unsigned addressSpace = 0);
|
unsigned addressSpace = 0);
|
||||||
|
|
||||||
/// Creates a memref allocation.
|
/// Creates a memref allocation with the given type and dynamic extents.
|
||||||
FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
|
FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
|
||||||
ArrayRef<Value> dynShape,
|
ValueRange dynShape,
|
||||||
|
const BufferizationOptions &options);
|
||||||
|
|
||||||
|
/// Creates a memref allocation with the given type and dynamic extents. If
|
||||||
|
/// `createDealloc`, a deallocation op is inserted at the point where the
|
||||||
|
/// allocation goes out of scope.
|
||||||
|
FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
|
||||||
|
ValueRange dynShape, bool deallocMemref,
|
||||||
const BufferizationOptions &options);
|
const BufferizationOptions &options);
|
||||||
|
|
||||||
/// Creates a memref allocation for the given shaped value. This function may
|
/// Creates a memref allocation for the given shaped value. This function may
|
||||||
/// perform additional optimizations such as buffer allocation hoisting. If
|
/// perform additional optimizations such as buffer allocation hoisting. If
|
||||||
/// `createDealloc`, a deallocation op is inserted at the point where the
|
/// `createDealloc`, a deallocation op is inserted at the point where the
|
||||||
/// allocation goes out of scope.
|
/// allocation goes out of scope.
|
||||||
|
// TODO: Allocation hoisting should be a cleanup pass.
|
||||||
FailureOr<Value> createAlloc(OpBuilder &b, Location loc, Value shapedValue,
|
FailureOr<Value> createAlloc(OpBuilder &b, Location loc, Value shapedValue,
|
||||||
bool deallocMemref,
|
bool deallocMemref,
|
||||||
const BufferizationOptions &options);
|
const BufferizationOptions &options);
|
||||||
|
|
|
@ -433,10 +433,10 @@ bufferization::createAlloc(OpBuilder &b, Location loc, Value shapedValue,
|
||||||
return casted;
|
return casted;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create a memref allocation.
|
/// Create a memref allocation with the given type and dynamic extents.
|
||||||
FailureOr<Value>
|
FailureOr<Value>
|
||||||
bufferization::createAlloc(OpBuilder &b, Location loc, MemRefType type,
|
bufferization::createAlloc(OpBuilder &b, Location loc, MemRefType type,
|
||||||
ArrayRef<Value> dynShape,
|
ValueRange dynShape,
|
||||||
const BufferizationOptions &options) {
|
const BufferizationOptions &options) {
|
||||||
if (options.allocationFn)
|
if (options.allocationFn)
|
||||||
return (*options.allocationFn)(b, loc, type, dynShape);
|
return (*options.allocationFn)(b, loc, type, dynShape);
|
||||||
|
@ -447,6 +447,28 @@ bufferization::createAlloc(OpBuilder &b, Location loc, MemRefType type,
|
||||||
return allocated;
|
return allocated;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Create a memref allocation with the given type and dynamic extents. May also
|
||||||
|
/// deallocate the memref again.
|
||||||
|
FailureOr<Value>
|
||||||
|
bufferization::createAlloc(OpBuilder &b, Location loc, MemRefType type,
|
||||||
|
ValueRange dynShape, bool deallocMemref,
|
||||||
|
const BufferizationOptions &options) {
|
||||||
|
OpBuilder::InsertionGuard g(b);
|
||||||
|
|
||||||
|
FailureOr<Value> alloc = createAlloc(b, loc, type, dynShape, options);
|
||||||
|
if (failed(alloc))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (deallocMemref) {
|
||||||
|
// Dealloc at the end of the block.
|
||||||
|
b.setInsertionPoint(alloc.getValue().getParentBlock()->getTerminator());
|
||||||
|
if (failed(createDealloc(b, loc, *alloc, options)))
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
return alloc;
|
||||||
|
}
|
||||||
|
|
||||||
/// Create a memref deallocation.
|
/// Create a memref deallocation.
|
||||||
LogicalResult
|
LogicalResult
|
||||||
bufferization::createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer,
|
bufferization::createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer,
|
||||||
|
|
|
@ -73,7 +73,7 @@ static void applyEnablingTransformations(ModuleOp moduleOp) {
|
||||||
|
|
||||||
static FailureOr<Value> allocationFnUsingAlloca(OpBuilder &b, Location loc,
|
static FailureOr<Value> allocationFnUsingAlloca(OpBuilder &b, Location loc,
|
||||||
MemRefType type,
|
MemRefType type,
|
||||||
ArrayRef<Value> dynShape) {
|
ValueRange dynShape) {
|
||||||
Value allocated = b.create<memref::AllocaOp>(
|
Value allocated = b.create<memref::AllocaOp>(
|
||||||
loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments));
|
loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments));
|
||||||
return allocated;
|
return allocated;
|
||||||
|
|
Loading…
Reference in New Issue