From b2499bf3e851c67ef623766b922de520de9235d5 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Mon, 24 Jan 2022 20:18:40 +0900 Subject: [PATCH] [mlir][bufferize][NFC] Refactor createAlloc function signature Pass a ValueRange instead of an ArrayRef for better compatibility. Also provide an additional function overload that automatically deallocates the buffer if specified. Differential Revision: https://reviews.llvm.org/D118025 --- .../IR/BufferizableOpInterface.h | 16 +++++++++--- .../IR/BufferizableOpInterface.cpp | 26 +++++++++++++++++-- .../Transforms/ComprehensiveBufferizePass.cpp | 2 +- 3 files changed, 37 insertions(+), 7 deletions(-) diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h index f679a22fa7a6..bbac6e59aeeb 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -36,8 +36,8 @@ class BufferizationState; /// Options for ComprehensiveBufferize. struct BufferizationOptions { - using AllocationFn = std::function( - OpBuilder &, Location, MemRefType, ArrayRef)>; + using AllocationFn = std::function(OpBuilder &, Location, + MemRefType, ValueRange)>; using DeallocationFn = std::function; using MemCpyFn = @@ -298,15 +298,23 @@ UnrankedMemRefType getUnrankedMemRefType(Type elementType, MemRefType getDynamicMemRefType(RankedTensorType tensorType, unsigned addressSpace = 0); -/// Creates a memref allocation. +/// Creates a memref allocation with the given type and dynamic extents. FailureOr createAlloc(OpBuilder &b, Location loc, MemRefType type, - ArrayRef dynShape, + ValueRange dynShape, + const BufferizationOptions &options); + +/// Creates a memref allocation with the given type and dynamic extents. If +/// `createDealloc`, a deallocation op is inserted at the point where the +/// allocation goes out of scope. +FailureOr createAlloc(OpBuilder &b, Location loc, MemRefType type, + ValueRange dynShape, bool deallocMemref, const BufferizationOptions &options); /// Creates a memref allocation for the given shaped value. This function may /// perform additional optimizations such as buffer allocation hoisting. If /// `createDealloc`, a deallocation op is inserted at the point where the /// allocation goes out of scope. +// TODO: Allocation hoisting should be a cleanup pass. FailureOr createAlloc(OpBuilder &b, Location loc, Value shapedValue, bool deallocMemref, const BufferizationOptions &options); diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp index fb081d3d6c3c..e565f41a39d5 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -433,10 +433,10 @@ bufferization::createAlloc(OpBuilder &b, Location loc, Value shapedValue, return casted; } -/// Create a memref allocation. +/// Create a memref allocation with the given type and dynamic extents. FailureOr bufferization::createAlloc(OpBuilder &b, Location loc, MemRefType type, - ArrayRef dynShape, + ValueRange dynShape, const BufferizationOptions &options) { if (options.allocationFn) return (*options.allocationFn)(b, loc, type, dynShape); @@ -447,6 +447,28 @@ bufferization::createAlloc(OpBuilder &b, Location loc, MemRefType type, return allocated; } +/// Create a memref allocation with the given type and dynamic extents. May also +/// deallocate the memref again. +FailureOr +bufferization::createAlloc(OpBuilder &b, Location loc, MemRefType type, + ValueRange dynShape, bool deallocMemref, + const BufferizationOptions &options) { + OpBuilder::InsertionGuard g(b); + + FailureOr alloc = createAlloc(b, loc, type, dynShape, options); + if (failed(alloc)) + return failure(); + + if (deallocMemref) { + // Dealloc at the end of the block. + b.setInsertionPoint(alloc.getValue().getParentBlock()->getTerminator()); + if (failed(createDealloc(b, loc, *alloc, options))) + return failure(); + } + + return alloc; +} + /// Create a memref deallocation. LogicalResult bufferization::createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer, diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp index 3c8b9c960695..9409492e12db 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp @@ -73,7 +73,7 @@ static void applyEnablingTransformations(ModuleOp moduleOp) { static FailureOr allocationFnUsingAlloca(OpBuilder &b, Location loc, MemRefType type, - ArrayRef dynShape) { + ValueRange dynShape) { Value allocated = b.create( loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments)); return allocated;