[mlir][linalg][bufferize][NFC] Split analysis-related code from BufferizationState/Options

This separates the analysis (and its helpers/data structures) more clearly from the rest of the bufferization.

Differential Revision: https://reviews.llvm.org/D117477
This commit is contained in:
Matthias Springer 2022-01-19 18:58:36 +09:00
parent 19c3026891
commit 31355482e5
14 changed files with 111 additions and 91 deletions

View File

@ -68,14 +68,6 @@ struct BufferizationOptions {
// BufferizationOptions cannot be copied.
BufferizationOptions(const BufferizationOptions &other) = delete;
/// Register a "post analysis" step. Such steps are executed after the
/// analysis, but before bufferization.
template <typename Step, typename... Args>
void addPostAnalysisStep(Args... args) {
postAnalysisSteps.emplace_back(
std::make_unique<Step>(std::forward<Args>(args)...));
}
/// Return `true` if the op is allowed to be bufferized.
bool isOpAllowed(Operation *op) const {
if (!dialectFilter.hasValue())
@ -134,9 +126,6 @@ struct BufferizationOptions {
/// For debugging only. Should be used together with `testAnalysisOnly`.
bool printConflicts = false;
/// Registered post analysis steps.
PostAnalysisStepList postAnalysisSteps;
/// Only bufferize ops from dialects that are allowed-listed by the filter.
/// All other ops are ignored. This option controls the scope of partial
/// bufferization.
@ -157,6 +146,25 @@ private:
}
};
/// Options for analysis-enabled bufferization.
struct AnalysisBufferizationOptions : public BufferizationOptions {
AnalysisBufferizationOptions() = default;
// AnalysisBufferizationOptions cannot be copied.
AnalysisBufferizationOptions(const AnalysisBufferizationOptions &) = delete;
/// Register a "post analysis" step. Such steps are executed after the
/// analysis, but before bufferization.
template <typename Step, typename... Args>
void addPostAnalysisStep(Args... args) {
postAnalysisSteps.emplace_back(
std::make_unique<Step>(std::forward<Args>(args)...));
}
/// Registered post analysis steps.
PostAnalysisStepList postAnalysisSteps;
};
/// Specify fine-grain relationship between buffers to enable more analysis.
enum class BufferRelation {
None,
@ -198,11 +206,6 @@ public:
return equivalentInfo.isEquivalent(v1, v2);
}
/// Return true if `v1` and `v2` bufferize to aliasing buffers.
bool areAliasingBufferizedValues(Value v1, Value v2) const {
return aliasInfo.isEquivalent(v1, v2);
}
/// Union the alias sets of `v1` and `v2`.
void unionAliasSets(Value v1, Value v2) { aliasInfo.unionSets(v1, v2); }
@ -276,11 +279,6 @@ struct DialectBufferizationState {
/// tensor values and memref buffers.
class BufferizationState {
public:
BufferizationState(Operation *op, const BufferizationOptions &options);
// BufferizationState should be passed as a reference.
BufferizationState(const BufferizationState &) = delete;
/// Determine which OpOperand* will alias with `result` if the op is
/// bufferized in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpOperand *> getAliasingOpOperand(OpResult result) const;
@ -344,7 +342,10 @@ public:
SetVector<Value> findLastPrecedingWrite(Value value) const;
/// Return `true` if the given OpResult has been decided to bufferize inplace.
bool isInPlace(OpOperand &opOperand) const;
virtual bool isInPlace(OpOperand &opOperand) const = 0;
/// Return true if `v1` and `v2` bufferize to equivalent buffers.
virtual bool areEquivalentBufferizedValues(Value v1, Value v2) const = 0;
/// Return the buffer (memref) for a given OpOperand (tensor). Allocate
/// a new buffer and copy over data from the existing buffer if out-of-place
@ -374,14 +375,15 @@ public:
/// Return a reference to the BufferizationOptions.
const BufferizationOptions &getOptions() const { return options; }
/// Return a reference to the BufferizationAliasInfo.
BufferizationAliasInfo &getAliasInfo() { return aliasInfo; }
protected:
BufferizationState(const BufferizationOptions &options);
// BufferizationState should be passed as a reference.
BufferizationState(const BufferizationState &) = delete;
~BufferizationState() = default;
private:
/// `aliasInfo` keeps track of aliasing and equivalent values. Only internal
/// functions and `runComprehensiveBufferize` may access this object.
BufferizationAliasInfo aliasInfo;
/// Dialect-specific bufferization state.
DenseMap<StringRef, std::unique_ptr<DialectBufferizationState>> dialectState;
@ -389,6 +391,33 @@ private:
const BufferizationOptions &options;
};
/// State for analysis-enabled bufferization. This class keeps track of alias
/// (via BufferizationAliasInfo) to decide if tensor OpOperands should bufferize
/// in-place.
class AnalysisBufferizationState : public BufferizationState {
public:
AnalysisBufferizationState(Operation *op,
const AnalysisBufferizationOptions &options);
AnalysisBufferizationState(const AnalysisBufferizationState &) = delete;
virtual ~AnalysisBufferizationState() = default;
/// Return a reference to the BufferizationAliasInfo.
BufferizationAliasInfo &getAliasInfo() { return aliasInfo; }
/// Return `true` if the given OpResult has been decided to bufferize inplace.
bool isInPlace(OpOperand &opOperand) const override;
/// Return true if `v1` and `v2` bufferize to equivalent buffers.
bool areEquivalentBufferizedValues(Value v1, Value v2) const override;
private:
/// `aliasInfo` keeps track of aliasing and equivalent values. Only internal
/// functions and `runComprehensiveBufferize` may access this object.
BufferizationAliasInfo aliasInfo;
};
/// Replace an op with replacement values. The op is deleted. Tensor OpResults
/// must be replaced with memref values.
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op,
@ -483,7 +512,6 @@ struct AllocationHoistingBarrierOnly
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
const BufferizationAliasInfo &aliasInfo,
const BufferizationState &state) const {
return BufferRelation::None;
}

View File

@ -183,7 +183,6 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*retType=*/"BufferRelation",
/*methodName=*/"bufferRelation",
/*args=*/(ins "OpResult":$opResult,
"const BufferizationAliasInfo &":$aliasInfo,
"const BufferizationState &":$state),
/*methodBody=*/"",
/*defaultImplementation=*/[{
@ -284,8 +283,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*methodName=*/"isNotConflicting",
/*args=*/(ins "OpOperand *":$uRead,
"OpOperand *":$uWrite,
"const BufferizationState &":$state,
"const BufferizationAliasInfo &":$aliasInfo),
"const BufferizationState &":$state),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return false;

View File

@ -16,22 +16,22 @@ namespace mlir {
namespace linalg {
namespace comprehensive_bufferize {
class AnalysisBufferizationState;
class BufferizationAliasInfo;
struct BufferizationOptions;
struct AnalysisBufferizationOptions;
class BufferizationState;
/// Analyze `op` and its nested ops. Bufferization decisions are stored in
/// `state`.
LogicalResult analyzeOp(Operation *op, BufferizationState &state);
LogicalResult analyzeOp(Operation *op, AnalysisBufferizationState &state);
/// Bufferize `op` and its nested ops. Bufferization decisions are stored in
/// `state`.
LogicalResult bufferizeOp(Operation *op, const BufferizationState &state);
/// Run Comprehensive Bufferize on the given op: Analysis + Bufferization
LogicalResult
runComprehensiveBufferize(Operation *op,
std::unique_ptr<BufferizationOptions> options);
LogicalResult runComprehensiveBufferize(
Operation *op, std::unique_ptr<AnalysisBufferizationOptions> options);
} // namespace comprehensive_bufferize
} // namespace linalg

View File

@ -20,14 +20,13 @@ class ModuleOp;
namespace linalg {
namespace comprehensive_bufferize {
struct BufferizationOptions;
struct AnalysisBufferizationOptions;
/// Run Module Bufferization on the given module. Performs a simple function
/// call analysis to determine which function arguments are inplaceable. Then
/// analyzes and bufferizes FuncOps one-by-one with Comprehensive Bufferization.
LogicalResult
runComprehensiveBufferize(ModuleOp moduleOp,
std::unique_ptr<BufferizationOptions> options);
LogicalResult runComprehensiveBufferize(
ModuleOp moduleOp, std::unique_ptr<AnalysisBufferizationOptions> options);
namespace std_ext {

View File

@ -288,8 +288,13 @@ llvm::SetVector<Value> mlir::linalg::comprehensive_bufferize::
}
mlir::linalg::comprehensive_bufferize::BufferizationState::BufferizationState(
Operation *op, const BufferizationOptions &options)
: aliasInfo(op), options(options) {
const BufferizationOptions &options)
: options(options) {}
mlir::linalg::comprehensive_bufferize::AnalysisBufferizationState::
AnalysisBufferizationState(Operation *op,
const AnalysisBufferizationOptions &options)
: BufferizationState(options), aliasInfo(op) {
// Set up alias sets for OpResults that must bufferize in-place. This should
// be done before making any other bufferization decisions.
op->walk([&](BufferizableOpInterface bufferizableOp) {
@ -353,7 +358,7 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::getBuffer(
Value operand = opOperand.get();
Value operandBuffer = lookupBuffer(rewriter, operand);
if (forceInPlace || aliasInfo.isInPlace(opOperand))
if (forceInPlace || isInPlace(opOperand))
return operandBuffer;
// Bufferizing out-of-place: Allocate a new buffer.
@ -597,11 +602,16 @@ bool mlir::linalg::comprehensive_bufferize::isFunctionArgument(Value value) {
return isa<FuncOp>(bbArg.getOwner()->getParentOp());
}
bool mlir::linalg::comprehensive_bufferize::BufferizationState::isInPlace(
OpOperand &opOperand) const {
bool mlir::linalg::comprehensive_bufferize::AnalysisBufferizationState::
isInPlace(OpOperand &opOperand) const {
return aliasInfo.isInPlace(opOperand);
}
bool mlir::linalg::comprehensive_bufferize::AnalysisBufferizationState::
areEquivalentBufferizedValues(Value v1, Value v2) const {
return aliasInfo.areEquivalentBufferizedValues(v1, v2);
}
MemRefType mlir::linalg::comprehensive_bufferize::getContiguousMemRefType(
ShapedType shapedType, MemRefLayoutAttrInterface layout,
Attribute memorySpace) {

View File

@ -252,15 +252,13 @@ static bool hasReadAfterWriteInterference(
// No conflict if the op interface says so.
if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp))
if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state,
aliasInfo))
if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state))
continue;
if (conflictingWritingOp != readingOp)
if (auto bufferizableOp =
options.dynCastBufferizableOp(conflictingWritingOp))
if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state,
aliasInfo))
if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state))
continue;
// Ops are not conflicting if they are in mutually exclusive regions.
@ -496,7 +494,7 @@ static void equivalenceAnalysis(SmallVector<Operation *> &ops,
for (OpOperand *opOperand :
bufferizableOp.getAliasingOpOperand(opResult, state))
if (state.isInPlace(*opOperand))
if (bufferizableOp.bufferRelation(opResult, aliasInfo, state) ==
if (bufferizableOp.bufferRelation(opResult, state) ==
BufferRelation::Equivalent)
aliasInfo.unionEquivalenceClasses(opResult, opOperand->get());
}
@ -687,12 +685,12 @@ checkBufferizationResult(Operation *op, const BufferizationOptions &options) {
return success();
}
LogicalResult
mlir::linalg::comprehensive_bufferize::analyzeOp(Operation *op,
BufferizationState &state) {
LogicalResult mlir::linalg::comprehensive_bufferize::analyzeOp(
Operation *op, AnalysisBufferizationState &state) {
DominanceInfo domInfo(op);
BufferizationAliasInfo &aliasInfo = state.getAliasInfo();
const BufferizationOptions &options = state.getOptions();
const auto &options =
static_cast<const AnalysisBufferizationOptions &>(state.getOptions());
if (failed(checkAliasInfoConsistency(op, domInfo, state, aliasInfo)))
return failure();
@ -740,8 +738,8 @@ LogicalResult mlir::linalg::comprehensive_bufferize::bufferizeOp(
}
LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
Operation *op, std::unique_ptr<BufferizationOptions> options) {
BufferizationState state(op, *options);
Operation *op, std::unique_ptr<AnalysisBufferizationOptions> options) {
AnalysisBufferizationState state(op, *options);
if (failed(analyzeOp(op, state)))
return failure();
if (options->testAnalysisOnly)

View File

@ -193,7 +193,6 @@ struct LinalgOpInterface
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
const BufferizationAliasInfo &aliasInfo,
const BufferizationState &state) const {
return BufferRelation::Equivalent;
}
@ -264,7 +263,6 @@ struct TiledLoopOpInterface
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
const BufferizationAliasInfo &aliasInfo,
const BufferizationState &state) const {
return BufferRelation::Equivalent;
}

View File

@ -737,7 +737,6 @@ struct CallOpInterface
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
const BufferizationAliasInfo &aliasInfo,
const BufferizationState &state) const {
return BufferRelation::Equivalent;
}
@ -964,9 +963,9 @@ annotateOpsWithBufferizationMarkers(FuncOp funcOp,
}
LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
ModuleOp moduleOp, std::unique_ptr<BufferizationOptions> options) {
ModuleOp moduleOp, std::unique_ptr<AnalysisBufferizationOptions> options) {
IRRewriter rewriter(moduleOp.getContext());
BufferizationState state(moduleOp, *options);
AnalysisBufferizationState state(moduleOp, *options);
ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
BufferizationAliasInfo &aliasInfo = state.getAliasInfo();

View File

@ -124,7 +124,6 @@ struct ExecuteRegionOpInterface
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
const BufferizationAliasInfo &aliasInfo,
const BufferizationState &state) const {
return BufferRelation::Equivalent;
}
@ -247,7 +246,6 @@ struct IfOpInterface
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
const BufferizationAliasInfo &aliasInfo,
const BufferizationState &state) const {
// IfOp results are equivalent to their corresponding yield values if both
// yield values are equivalent to each other.
@ -255,7 +253,7 @@ struct IfOpInterface
SmallVector<OpOperand *> yieldValues =
bufferizableOp.getAliasingOpOperand(opResult, state);
assert(yieldValues.size() == 2 && "expected 2 yield values");
bool equivalentYields = aliasInfo.areEquivalentBufferizedValues(
bool equivalentYields = state.areEquivalentBufferizedValues(
yieldValues[0]->get(), yieldValues[1]->get());
return equivalentYields ? BufferRelation::Equivalent : BufferRelation::None;
}
@ -291,7 +289,6 @@ struct ForOpInterface
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
const BufferizationAliasInfo &aliasInfo,
const BufferizationState &state) const {
// ForOp results are equivalent to their corresponding init_args if the
// corresponding iter_args and yield values are equivalent.
@ -299,7 +296,7 @@ struct ForOpInterface
OpOperand &forOperand = forOp.getOpOperandForResult(opResult);
auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
auto yieldOp = cast<scf::YieldOp>(&forOp.getLoopBody().front().back());
bool equivalentYield = aliasInfo.areEquivalentBufferizedValues(
bool equivalentYield = state.areEquivalentBufferizedValues(
bbArg, yieldOp->getOperand(opResult.getResultNumber()));
return equivalentYield ? BufferRelation::Equivalent : BufferRelation::None;
}
@ -408,7 +405,9 @@ mlir::linalg::comprehensive_bufferize::scf_ext::AssertScfForAliasingProperties::
OpOperand &forOperand = forOp.getOpOperandForResult(
forOp->getResult(operand.getOperandNumber()));
auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
if (!aliasInfo.areAliasingBufferizedValues(operand.get(), bbArg)) {
// Note: This is overly strict. We should check for aliasing bufferized
// values. But we don't have a "must-alias" analysis yet.
if (!aliasInfo.areEquivalentBufferizedValues(operand.get(), bbArg)) {
// TODO: this could get resolved with copies but it can also turn into
// swaps so we need to be careful about order of copies.
status =

View File

@ -62,7 +62,6 @@ struct SelectOpInterface
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
const BufferizationAliasInfo &aliasInfo,
const BufferizationState &state) const {
return BufferRelation::None;
}

View File

@ -42,7 +42,6 @@ struct CastOpInterface
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
const BufferizationAliasInfo &aliasInfo,
const BufferizationState &state) const {
return BufferRelation::Equivalent;
}
@ -137,7 +136,6 @@ struct ExtractSliceOpInterface
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
const BufferizationAliasInfo &aliasInfo,
const BufferizationState &state) const {
return BufferRelation::None;
}
@ -273,7 +271,6 @@ struct InsertOpInterface
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
const BufferizationAliasInfo &aliasInfo,
const BufferizationState &state) const {
return BufferRelation::Equivalent;
}
@ -285,12 +282,12 @@ struct InsertOpInterface
/// This is one particular type of relationship between ops on tensors that
/// reduce to an equivalence on buffers. This should be generalized and
/// exposed as interfaces on the proper types.
static bool
areEquivalentExtractSliceOps(const BufferizationAliasInfo &aliasInfo,
ExtractSliceOp st, InsertSliceOp sti) {
static bool areEquivalentExtractSliceOps(const BufferizationState &state,
ExtractSliceOp st, InsertSliceOp sti) {
if (!st || !sti)
return false;
if (!aliasInfo.areEquivalentBufferizedValues(st.source(), sti.dest()))
if (sti != sti &&
!state.areEquivalentBufferizedValues(st.source(), sti.dest()))
return false;
if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
return false;
@ -299,12 +296,11 @@ areEquivalentExtractSliceOps(const BufferizationAliasInfo &aliasInfo,
/// Return true if `value` is originating from an ExtractSliceOp that matches
/// the given InsertSliceOp.
static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo,
const BufferizationState &state,
static bool hasMatchingExtractSliceOp(const BufferizationState &state,
Value value, InsertSliceOp insertOp) {
auto condition = [&](Value val) {
if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
if (areEquivalentExtractSliceOps(aliasInfo, extractOp, insertOp))
if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
return true;
return false;
};
@ -336,15 +332,13 @@ struct InsertSliceOpInterface
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
const BufferizationAliasInfo &aliasInfo,
const BufferizationState &state) const {
return BufferRelation::Equivalent;
}
bool isNotConflicting(Operation *op, OpOperand *uRead,
OpOperand *uConflictingWrite,
const BufferizationState &state,
const BufferizationAliasInfo &aliasInfo) const {
const BufferizationState &state) const {
Operation *readingOp = uRead->getOwner();
Operation *conflictingWritingOp = uConflictingWrite->getOwner();
@ -360,7 +354,7 @@ struct InsertSliceOpInterface
// TODO: Use insertSliceOp.getDestOpOperand etc. when available.
if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
hasMatchingExtractSliceOp(aliasInfo, state, uConflictingWrite->get(),
hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
insertSliceOp))
// Case 1: The main insight is that InsertSliceOp reads only part of
// the destination tensor. The overwritten area is not read. If
@ -378,8 +372,7 @@ struct InsertSliceOpInterface
if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
hasMatchingExtractSliceOp(aliasInfo, state, uRead->get(),
insertSliceOp))
hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
// Case 2: The read of the source tensor and the write to the dest
// tensor via an InsertSliceOp is not a conflict if the read is
// reading exactly that part of an equivalent tensor that the
@ -410,9 +403,9 @@ struct InsertSliceOpInterface
// memory segment of %1 with the exact same data. (Effectively, there
// is no memory write here.)
if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
aliasInfo.areEquivalentBufferizedValues(uRead->get(),
insertSliceOp.source()) &&
hasMatchingExtractSliceOp(aliasInfo, state, insertSliceOp.source(),
state.areEquivalentBufferizedValues(uRead->get(),
insertSliceOp.source()) &&
hasMatchingExtractSliceOp(state, insertSliceOp.source(),
insertSliceOp))
return true;

View File

@ -85,7 +85,6 @@ struct TransferWriteOpInterface
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
const BufferizationAliasInfo &aliasInfo,
const BufferizationState &state) const {
return BufferRelation::Equivalent;
}

View File

@ -75,7 +75,7 @@ static FailureOr<Value> allocationFnUsingAlloca(OpBuilder &b, Location loc,
}
void LinalgComprehensiveModuleBufferize::runOnOperation() {
auto options = std::make_unique<BufferizationOptions>();
auto options = std::make_unique<AnalysisBufferizationOptions>();
if (useAlloca) {
options->allocationFn = allocationFnUsingAlloca;
options->deallocationFn = [](OpBuilder &b, Location loc, Value v) {

View File

@ -101,7 +101,7 @@ struct TestComprehensiveFunctionBufferize
} // namespace
void TestComprehensiveFunctionBufferize::runOnOperation() {
auto options = std::make_unique<BufferizationOptions>();
auto options = std::make_unique<AnalysisBufferizationOptions>();
if (!allowReturnMemref)
options->addPostAnalysisStep<scf_ext::AssertScfForAliasingProperties>();