[mlir][linalg][bufferize][NFC] Split analysis-related code from BufferizationState/Options
This separates the analysis (and its helpers/data structures) more clearly from the rest of the bufferization. Differential Revision: https://reviews.llvm.org/D117477
This commit is contained in:
parent
19c3026891
commit
31355482e5
|
@ -68,14 +68,6 @@ struct BufferizationOptions {
|
|||
// BufferizationOptions cannot be copied.
|
||||
BufferizationOptions(const BufferizationOptions &other) = delete;
|
||||
|
||||
/// Register a "post analysis" step. Such steps are executed after the
|
||||
/// analysis, but before bufferization.
|
||||
template <typename Step, typename... Args>
|
||||
void addPostAnalysisStep(Args... args) {
|
||||
postAnalysisSteps.emplace_back(
|
||||
std::make_unique<Step>(std::forward<Args>(args)...));
|
||||
}
|
||||
|
||||
/// Return `true` if the op is allowed to be bufferized.
|
||||
bool isOpAllowed(Operation *op) const {
|
||||
if (!dialectFilter.hasValue())
|
||||
|
@ -134,9 +126,6 @@ struct BufferizationOptions {
|
|||
/// For debugging only. Should be used together with `testAnalysisOnly`.
|
||||
bool printConflicts = false;
|
||||
|
||||
/// Registered post analysis steps.
|
||||
PostAnalysisStepList postAnalysisSteps;
|
||||
|
||||
/// Only bufferize ops from dialects that are allowed-listed by the filter.
|
||||
/// All other ops are ignored. This option controls the scope of partial
|
||||
/// bufferization.
|
||||
|
@ -157,6 +146,25 @@ private:
|
|||
}
|
||||
};
|
||||
|
||||
/// Options for analysis-enabled bufferization.
|
||||
struct AnalysisBufferizationOptions : public BufferizationOptions {
|
||||
AnalysisBufferizationOptions() = default;
|
||||
|
||||
// AnalysisBufferizationOptions cannot be copied.
|
||||
AnalysisBufferizationOptions(const AnalysisBufferizationOptions &) = delete;
|
||||
|
||||
/// Register a "post analysis" step. Such steps are executed after the
|
||||
/// analysis, but before bufferization.
|
||||
template <typename Step, typename... Args>
|
||||
void addPostAnalysisStep(Args... args) {
|
||||
postAnalysisSteps.emplace_back(
|
||||
std::make_unique<Step>(std::forward<Args>(args)...));
|
||||
}
|
||||
|
||||
/// Registered post analysis steps.
|
||||
PostAnalysisStepList postAnalysisSteps;
|
||||
};
|
||||
|
||||
/// Specify fine-grain relationship between buffers to enable more analysis.
|
||||
enum class BufferRelation {
|
||||
None,
|
||||
|
@ -198,11 +206,6 @@ public:
|
|||
return equivalentInfo.isEquivalent(v1, v2);
|
||||
}
|
||||
|
||||
/// Return true if `v1` and `v2` bufferize to aliasing buffers.
|
||||
bool areAliasingBufferizedValues(Value v1, Value v2) const {
|
||||
return aliasInfo.isEquivalent(v1, v2);
|
||||
}
|
||||
|
||||
/// Union the alias sets of `v1` and `v2`.
|
||||
void unionAliasSets(Value v1, Value v2) { aliasInfo.unionSets(v1, v2); }
|
||||
|
||||
|
@ -276,11 +279,6 @@ struct DialectBufferizationState {
|
|||
/// tensor values and memref buffers.
|
||||
class BufferizationState {
|
||||
public:
|
||||
BufferizationState(Operation *op, const BufferizationOptions &options);
|
||||
|
||||
// BufferizationState should be passed as a reference.
|
||||
BufferizationState(const BufferizationState &) = delete;
|
||||
|
||||
/// Determine which OpOperand* will alias with `result` if the op is
|
||||
/// bufferized in place. Return an empty vector if the op is not bufferizable.
|
||||
SmallVector<OpOperand *> getAliasingOpOperand(OpResult result) const;
|
||||
|
@ -344,7 +342,10 @@ public:
|
|||
SetVector<Value> findLastPrecedingWrite(Value value) const;
|
||||
|
||||
/// Return `true` if the given OpResult has been decided to bufferize inplace.
|
||||
bool isInPlace(OpOperand &opOperand) const;
|
||||
virtual bool isInPlace(OpOperand &opOperand) const = 0;
|
||||
|
||||
/// Return true if `v1` and `v2` bufferize to equivalent buffers.
|
||||
virtual bool areEquivalentBufferizedValues(Value v1, Value v2) const = 0;
|
||||
|
||||
/// Return the buffer (memref) for a given OpOperand (tensor). Allocate
|
||||
/// a new buffer and copy over data from the existing buffer if out-of-place
|
||||
|
@ -374,14 +375,15 @@ public:
|
|||
/// Return a reference to the BufferizationOptions.
|
||||
const BufferizationOptions &getOptions() const { return options; }
|
||||
|
||||
/// Return a reference to the BufferizationAliasInfo.
|
||||
BufferizationAliasInfo &getAliasInfo() { return aliasInfo; }
|
||||
protected:
|
||||
BufferizationState(const BufferizationOptions &options);
|
||||
|
||||
// BufferizationState should be passed as a reference.
|
||||
BufferizationState(const BufferizationState &) = delete;
|
||||
|
||||
~BufferizationState() = default;
|
||||
|
||||
private:
|
||||
/// `aliasInfo` keeps track of aliasing and equivalent values. Only internal
|
||||
/// functions and `runComprehensiveBufferize` may access this object.
|
||||
BufferizationAliasInfo aliasInfo;
|
||||
|
||||
/// Dialect-specific bufferization state.
|
||||
DenseMap<StringRef, std::unique_ptr<DialectBufferizationState>> dialectState;
|
||||
|
||||
|
@ -389,6 +391,33 @@ private:
|
|||
const BufferizationOptions &options;
|
||||
};
|
||||
|
||||
/// State for analysis-enabled bufferization. This class keeps track of alias
|
||||
/// (via BufferizationAliasInfo) to decide if tensor OpOperands should bufferize
|
||||
/// in-place.
|
||||
class AnalysisBufferizationState : public BufferizationState {
|
||||
public:
|
||||
AnalysisBufferizationState(Operation *op,
|
||||
const AnalysisBufferizationOptions &options);
|
||||
|
||||
AnalysisBufferizationState(const AnalysisBufferizationState &) = delete;
|
||||
|
||||
virtual ~AnalysisBufferizationState() = default;
|
||||
|
||||
/// Return a reference to the BufferizationAliasInfo.
|
||||
BufferizationAliasInfo &getAliasInfo() { return aliasInfo; }
|
||||
|
||||
/// Return `true` if the given OpResult has been decided to bufferize inplace.
|
||||
bool isInPlace(OpOperand &opOperand) const override;
|
||||
|
||||
/// Return true if `v1` and `v2` bufferize to equivalent buffers.
|
||||
bool areEquivalentBufferizedValues(Value v1, Value v2) const override;
|
||||
|
||||
private:
|
||||
/// `aliasInfo` keeps track of aliasing and equivalent values. Only internal
|
||||
/// functions and `runComprehensiveBufferize` may access this object.
|
||||
BufferizationAliasInfo aliasInfo;
|
||||
};
|
||||
|
||||
/// Replace an op with replacement values. The op is deleted. Tensor OpResults
|
||||
/// must be replaced with memref values.
|
||||
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op,
|
||||
|
@ -483,7 +512,6 @@ struct AllocationHoistingBarrierOnly
|
|||
}
|
||||
|
||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||
const BufferizationAliasInfo &aliasInfo,
|
||||
const BufferizationState &state) const {
|
||||
return BufferRelation::None;
|
||||
}
|
||||
|
|
|
@ -183,7 +183,6 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
|
|||
/*retType=*/"BufferRelation",
|
||||
/*methodName=*/"bufferRelation",
|
||||
/*args=*/(ins "OpResult":$opResult,
|
||||
"const BufferizationAliasInfo &":$aliasInfo,
|
||||
"const BufferizationState &":$state),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
|
@ -284,8 +283,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
|
|||
/*methodName=*/"isNotConflicting",
|
||||
/*args=*/(ins "OpOperand *":$uRead,
|
||||
"OpOperand *":$uWrite,
|
||||
"const BufferizationState &":$state,
|
||||
"const BufferizationAliasInfo &":$aliasInfo),
|
||||
"const BufferizationState &":$state),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
return false;
|
||||
|
|
|
@ -16,22 +16,22 @@ namespace mlir {
|
|||
namespace linalg {
|
||||
namespace comprehensive_bufferize {
|
||||
|
||||
class AnalysisBufferizationState;
|
||||
class BufferizationAliasInfo;
|
||||
struct BufferizationOptions;
|
||||
struct AnalysisBufferizationOptions;
|
||||
class BufferizationState;
|
||||
|
||||
/// Analyze `op` and its nested ops. Bufferization decisions are stored in
|
||||
/// `state`.
|
||||
LogicalResult analyzeOp(Operation *op, BufferizationState &state);
|
||||
LogicalResult analyzeOp(Operation *op, AnalysisBufferizationState &state);
|
||||
|
||||
/// Bufferize `op` and its nested ops. Bufferization decisions are stored in
|
||||
/// `state`.
|
||||
LogicalResult bufferizeOp(Operation *op, const BufferizationState &state);
|
||||
|
||||
/// Run Comprehensive Bufferize on the given op: Analysis + Bufferization
|
||||
LogicalResult
|
||||
runComprehensiveBufferize(Operation *op,
|
||||
std::unique_ptr<BufferizationOptions> options);
|
||||
LogicalResult runComprehensiveBufferize(
|
||||
Operation *op, std::unique_ptr<AnalysisBufferizationOptions> options);
|
||||
|
||||
} // namespace comprehensive_bufferize
|
||||
} // namespace linalg
|
||||
|
|
|
@ -20,14 +20,13 @@ class ModuleOp;
|
|||
namespace linalg {
|
||||
namespace comprehensive_bufferize {
|
||||
|
||||
struct BufferizationOptions;
|
||||
struct AnalysisBufferizationOptions;
|
||||
|
||||
/// Run Module Bufferization on the given module. Performs a simple function
|
||||
/// call analysis to determine which function arguments are inplaceable. Then
|
||||
/// analyzes and bufferizes FuncOps one-by-one with Comprehensive Bufferization.
|
||||
LogicalResult
|
||||
runComprehensiveBufferize(ModuleOp moduleOp,
|
||||
std::unique_ptr<BufferizationOptions> options);
|
||||
LogicalResult runComprehensiveBufferize(
|
||||
ModuleOp moduleOp, std::unique_ptr<AnalysisBufferizationOptions> options);
|
||||
|
||||
namespace std_ext {
|
||||
|
||||
|
|
|
@ -288,8 +288,13 @@ llvm::SetVector<Value> mlir::linalg::comprehensive_bufferize::
|
|||
}
|
||||
|
||||
mlir::linalg::comprehensive_bufferize::BufferizationState::BufferizationState(
|
||||
Operation *op, const BufferizationOptions &options)
|
||||
: aliasInfo(op), options(options) {
|
||||
const BufferizationOptions &options)
|
||||
: options(options) {}
|
||||
|
||||
mlir::linalg::comprehensive_bufferize::AnalysisBufferizationState::
|
||||
AnalysisBufferizationState(Operation *op,
|
||||
const AnalysisBufferizationOptions &options)
|
||||
: BufferizationState(options), aliasInfo(op) {
|
||||
// Set up alias sets for OpResults that must bufferize in-place. This should
|
||||
// be done before making any other bufferization decisions.
|
||||
op->walk([&](BufferizableOpInterface bufferizableOp) {
|
||||
|
@ -353,7 +358,7 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::getBuffer(
|
|||
Value operand = opOperand.get();
|
||||
Value operandBuffer = lookupBuffer(rewriter, operand);
|
||||
|
||||
if (forceInPlace || aliasInfo.isInPlace(opOperand))
|
||||
if (forceInPlace || isInPlace(opOperand))
|
||||
return operandBuffer;
|
||||
|
||||
// Bufferizing out-of-place: Allocate a new buffer.
|
||||
|
@ -597,11 +602,16 @@ bool mlir::linalg::comprehensive_bufferize::isFunctionArgument(Value value) {
|
|||
return isa<FuncOp>(bbArg.getOwner()->getParentOp());
|
||||
}
|
||||
|
||||
bool mlir::linalg::comprehensive_bufferize::BufferizationState::isInPlace(
|
||||
OpOperand &opOperand) const {
|
||||
bool mlir::linalg::comprehensive_bufferize::AnalysisBufferizationState::
|
||||
isInPlace(OpOperand &opOperand) const {
|
||||
return aliasInfo.isInPlace(opOperand);
|
||||
}
|
||||
|
||||
bool mlir::linalg::comprehensive_bufferize::AnalysisBufferizationState::
|
||||
areEquivalentBufferizedValues(Value v1, Value v2) const {
|
||||
return aliasInfo.areEquivalentBufferizedValues(v1, v2);
|
||||
}
|
||||
|
||||
MemRefType mlir::linalg::comprehensive_bufferize::getContiguousMemRefType(
|
||||
ShapedType shapedType, MemRefLayoutAttrInterface layout,
|
||||
Attribute memorySpace) {
|
||||
|
|
|
@ -252,15 +252,13 @@ static bool hasReadAfterWriteInterference(
|
|||
|
||||
// No conflict if the op interface says so.
|
||||
if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp))
|
||||
if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state,
|
||||
aliasInfo))
|
||||
if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state))
|
||||
continue;
|
||||
|
||||
if (conflictingWritingOp != readingOp)
|
||||
if (auto bufferizableOp =
|
||||
options.dynCastBufferizableOp(conflictingWritingOp))
|
||||
if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state,
|
||||
aliasInfo))
|
||||
if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state))
|
||||
continue;
|
||||
|
||||
// Ops are not conflicting if they are in mutually exclusive regions.
|
||||
|
@ -496,7 +494,7 @@ static void equivalenceAnalysis(SmallVector<Operation *> &ops,
|
|||
for (OpOperand *opOperand :
|
||||
bufferizableOp.getAliasingOpOperand(opResult, state))
|
||||
if (state.isInPlace(*opOperand))
|
||||
if (bufferizableOp.bufferRelation(opResult, aliasInfo, state) ==
|
||||
if (bufferizableOp.bufferRelation(opResult, state) ==
|
||||
BufferRelation::Equivalent)
|
||||
aliasInfo.unionEquivalenceClasses(opResult, opOperand->get());
|
||||
}
|
||||
|
@ -687,12 +685,12 @@ checkBufferizationResult(Operation *op, const BufferizationOptions &options) {
|
|||
return success();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
mlir::linalg::comprehensive_bufferize::analyzeOp(Operation *op,
|
||||
BufferizationState &state) {
|
||||
LogicalResult mlir::linalg::comprehensive_bufferize::analyzeOp(
|
||||
Operation *op, AnalysisBufferizationState &state) {
|
||||
DominanceInfo domInfo(op);
|
||||
BufferizationAliasInfo &aliasInfo = state.getAliasInfo();
|
||||
const BufferizationOptions &options = state.getOptions();
|
||||
const auto &options =
|
||||
static_cast<const AnalysisBufferizationOptions &>(state.getOptions());
|
||||
|
||||
if (failed(checkAliasInfoConsistency(op, domInfo, state, aliasInfo)))
|
||||
return failure();
|
||||
|
@ -740,8 +738,8 @@ LogicalResult mlir::linalg::comprehensive_bufferize::bufferizeOp(
|
|||
}
|
||||
|
||||
LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
|
||||
Operation *op, std::unique_ptr<BufferizationOptions> options) {
|
||||
BufferizationState state(op, *options);
|
||||
Operation *op, std::unique_ptr<AnalysisBufferizationOptions> options) {
|
||||
AnalysisBufferizationState state(op, *options);
|
||||
if (failed(analyzeOp(op, state)))
|
||||
return failure();
|
||||
if (options->testAnalysisOnly)
|
||||
|
|
|
@ -193,7 +193,6 @@ struct LinalgOpInterface
|
|||
}
|
||||
|
||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||
const BufferizationAliasInfo &aliasInfo,
|
||||
const BufferizationState &state) const {
|
||||
return BufferRelation::Equivalent;
|
||||
}
|
||||
|
@ -264,7 +263,6 @@ struct TiledLoopOpInterface
|
|||
}
|
||||
|
||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||
const BufferizationAliasInfo &aliasInfo,
|
||||
const BufferizationState &state) const {
|
||||
return BufferRelation::Equivalent;
|
||||
}
|
||||
|
|
|
@ -737,7 +737,6 @@ struct CallOpInterface
|
|||
}
|
||||
|
||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||
const BufferizationAliasInfo &aliasInfo,
|
||||
const BufferizationState &state) const {
|
||||
return BufferRelation::Equivalent;
|
||||
}
|
||||
|
@ -964,9 +963,9 @@ annotateOpsWithBufferizationMarkers(FuncOp funcOp,
|
|||
}
|
||||
|
||||
LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
|
||||
ModuleOp moduleOp, std::unique_ptr<BufferizationOptions> options) {
|
||||
ModuleOp moduleOp, std::unique_ptr<AnalysisBufferizationOptions> options) {
|
||||
IRRewriter rewriter(moduleOp.getContext());
|
||||
BufferizationState state(moduleOp, *options);
|
||||
AnalysisBufferizationState state(moduleOp, *options);
|
||||
ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
|
||||
BufferizationAliasInfo &aliasInfo = state.getAliasInfo();
|
||||
|
||||
|
|
|
@ -124,7 +124,6 @@ struct ExecuteRegionOpInterface
|
|||
}
|
||||
|
||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||
const BufferizationAliasInfo &aliasInfo,
|
||||
const BufferizationState &state) const {
|
||||
return BufferRelation::Equivalent;
|
||||
}
|
||||
|
@ -247,7 +246,6 @@ struct IfOpInterface
|
|||
}
|
||||
|
||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||
const BufferizationAliasInfo &aliasInfo,
|
||||
const BufferizationState &state) const {
|
||||
// IfOp results are equivalent to their corresponding yield values if both
|
||||
// yield values are equivalent to each other.
|
||||
|
@ -255,7 +253,7 @@ struct IfOpInterface
|
|||
SmallVector<OpOperand *> yieldValues =
|
||||
bufferizableOp.getAliasingOpOperand(opResult, state);
|
||||
assert(yieldValues.size() == 2 && "expected 2 yield values");
|
||||
bool equivalentYields = aliasInfo.areEquivalentBufferizedValues(
|
||||
bool equivalentYields = state.areEquivalentBufferizedValues(
|
||||
yieldValues[0]->get(), yieldValues[1]->get());
|
||||
return equivalentYields ? BufferRelation::Equivalent : BufferRelation::None;
|
||||
}
|
||||
|
@ -291,7 +289,6 @@ struct ForOpInterface
|
|||
}
|
||||
|
||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||
const BufferizationAliasInfo &aliasInfo,
|
||||
const BufferizationState &state) const {
|
||||
// ForOp results are equivalent to their corresponding init_args if the
|
||||
// corresponding iter_args and yield values are equivalent.
|
||||
|
@ -299,7 +296,7 @@ struct ForOpInterface
|
|||
OpOperand &forOperand = forOp.getOpOperandForResult(opResult);
|
||||
auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
|
||||
auto yieldOp = cast<scf::YieldOp>(&forOp.getLoopBody().front().back());
|
||||
bool equivalentYield = aliasInfo.areEquivalentBufferizedValues(
|
||||
bool equivalentYield = state.areEquivalentBufferizedValues(
|
||||
bbArg, yieldOp->getOperand(opResult.getResultNumber()));
|
||||
return equivalentYield ? BufferRelation::Equivalent : BufferRelation::None;
|
||||
}
|
||||
|
@ -408,7 +405,9 @@ mlir::linalg::comprehensive_bufferize::scf_ext::AssertScfForAliasingProperties::
|
|||
OpOperand &forOperand = forOp.getOpOperandForResult(
|
||||
forOp->getResult(operand.getOperandNumber()));
|
||||
auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
|
||||
if (!aliasInfo.areAliasingBufferizedValues(operand.get(), bbArg)) {
|
||||
// Note: This is overly strict. We should check for aliasing bufferized
|
||||
// values. But we don't have a "must-alias" analysis yet.
|
||||
if (!aliasInfo.areEquivalentBufferizedValues(operand.get(), bbArg)) {
|
||||
// TODO: this could get resolved with copies but it can also turn into
|
||||
// swaps so we need to be careful about order of copies.
|
||||
status =
|
||||
|
|
|
@ -62,7 +62,6 @@ struct SelectOpInterface
|
|||
}
|
||||
|
||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||
const BufferizationAliasInfo &aliasInfo,
|
||||
const BufferizationState &state) const {
|
||||
return BufferRelation::None;
|
||||
}
|
||||
|
|
|
@ -42,7 +42,6 @@ struct CastOpInterface
|
|||
}
|
||||
|
||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||
const BufferizationAliasInfo &aliasInfo,
|
||||
const BufferizationState &state) const {
|
||||
return BufferRelation::Equivalent;
|
||||
}
|
||||
|
@ -137,7 +136,6 @@ struct ExtractSliceOpInterface
|
|||
}
|
||||
|
||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||
const BufferizationAliasInfo &aliasInfo,
|
||||
const BufferizationState &state) const {
|
||||
return BufferRelation::None;
|
||||
}
|
||||
|
@ -273,7 +271,6 @@ struct InsertOpInterface
|
|||
}
|
||||
|
||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||
const BufferizationAliasInfo &aliasInfo,
|
||||
const BufferizationState &state) const {
|
||||
return BufferRelation::Equivalent;
|
||||
}
|
||||
|
@ -285,12 +282,12 @@ struct InsertOpInterface
|
|||
/// This is one particular type of relationship between ops on tensors that
|
||||
/// reduce to an equivalence on buffers. This should be generalized and
|
||||
/// exposed as interfaces on the proper types.
|
||||
static bool
|
||||
areEquivalentExtractSliceOps(const BufferizationAliasInfo &aliasInfo,
|
||||
ExtractSliceOp st, InsertSliceOp sti) {
|
||||
static bool areEquivalentExtractSliceOps(const BufferizationState &state,
|
||||
ExtractSliceOp st, InsertSliceOp sti) {
|
||||
if (!st || !sti)
|
||||
return false;
|
||||
if (!aliasInfo.areEquivalentBufferizedValues(st.source(), sti.dest()))
|
||||
if (sti != sti &&
|
||||
!state.areEquivalentBufferizedValues(st.source(), sti.dest()))
|
||||
return false;
|
||||
if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
|
||||
return false;
|
||||
|
@ -299,12 +296,11 @@ areEquivalentExtractSliceOps(const BufferizationAliasInfo &aliasInfo,
|
|||
|
||||
/// Return true if `value` is originating from an ExtractSliceOp that matches
|
||||
/// the given InsertSliceOp.
|
||||
static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo,
|
||||
const BufferizationState &state,
|
||||
static bool hasMatchingExtractSliceOp(const BufferizationState &state,
|
||||
Value value, InsertSliceOp insertOp) {
|
||||
auto condition = [&](Value val) {
|
||||
if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
|
||||
if (areEquivalentExtractSliceOps(aliasInfo, extractOp, insertOp))
|
||||
if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
|
||||
return true;
|
||||
return false;
|
||||
};
|
||||
|
@ -336,15 +332,13 @@ struct InsertSliceOpInterface
|
|||
}
|
||||
|
||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||
const BufferizationAliasInfo &aliasInfo,
|
||||
const BufferizationState &state) const {
|
||||
return BufferRelation::Equivalent;
|
||||
}
|
||||
|
||||
bool isNotConflicting(Operation *op, OpOperand *uRead,
|
||||
OpOperand *uConflictingWrite,
|
||||
const BufferizationState &state,
|
||||
const BufferizationAliasInfo &aliasInfo) const {
|
||||
const BufferizationState &state) const {
|
||||
Operation *readingOp = uRead->getOwner();
|
||||
Operation *conflictingWritingOp = uConflictingWrite->getOwner();
|
||||
|
||||
|
@ -360,7 +354,7 @@ struct InsertSliceOpInterface
|
|||
|
||||
// TODO: Use insertSliceOp.getDestOpOperand etc. when available.
|
||||
if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
|
||||
hasMatchingExtractSliceOp(aliasInfo, state, uConflictingWrite->get(),
|
||||
hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
|
||||
insertSliceOp))
|
||||
// Case 1: The main insight is that InsertSliceOp reads only part of
|
||||
// the destination tensor. The overwritten area is not read. If
|
||||
|
@ -378,8 +372,7 @@ struct InsertSliceOpInterface
|
|||
|
||||
if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
|
||||
uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
|
||||
hasMatchingExtractSliceOp(aliasInfo, state, uRead->get(),
|
||||
insertSliceOp))
|
||||
hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
|
||||
// Case 2: The read of the source tensor and the write to the dest
|
||||
// tensor via an InsertSliceOp is not a conflict if the read is
|
||||
// reading exactly that part of an equivalent tensor that the
|
||||
|
@ -410,9 +403,9 @@ struct InsertSliceOpInterface
|
|||
// memory segment of %1 with the exact same data. (Effectively, there
|
||||
// is no memory write here.)
|
||||
if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
|
||||
aliasInfo.areEquivalentBufferizedValues(uRead->get(),
|
||||
insertSliceOp.source()) &&
|
||||
hasMatchingExtractSliceOp(aliasInfo, state, insertSliceOp.source(),
|
||||
state.areEquivalentBufferizedValues(uRead->get(),
|
||||
insertSliceOp.source()) &&
|
||||
hasMatchingExtractSliceOp(state, insertSliceOp.source(),
|
||||
insertSliceOp))
|
||||
return true;
|
||||
|
||||
|
|
|
@ -85,7 +85,6 @@ struct TransferWriteOpInterface
|
|||
}
|
||||
|
||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||
const BufferizationAliasInfo &aliasInfo,
|
||||
const BufferizationState &state) const {
|
||||
return BufferRelation::Equivalent;
|
||||
}
|
||||
|
|
|
@ -75,7 +75,7 @@ static FailureOr<Value> allocationFnUsingAlloca(OpBuilder &b, Location loc,
|
|||
}
|
||||
|
||||
void LinalgComprehensiveModuleBufferize::runOnOperation() {
|
||||
auto options = std::make_unique<BufferizationOptions>();
|
||||
auto options = std::make_unique<AnalysisBufferizationOptions>();
|
||||
if (useAlloca) {
|
||||
options->allocationFn = allocationFnUsingAlloca;
|
||||
options->deallocationFn = [](OpBuilder &b, Location loc, Value v) {
|
||||
|
|
|
@ -101,7 +101,7 @@ struct TestComprehensiveFunctionBufferize
|
|||
} // namespace
|
||||
|
||||
void TestComprehensiveFunctionBufferize::runOnOperation() {
|
||||
auto options = std::make_unique<BufferizationOptions>();
|
||||
auto options = std::make_unique<AnalysisBufferizationOptions>();
|
||||
|
||||
if (!allowReturnMemref)
|
||||
options->addPostAnalysisStep<scf_ext::AssertScfForAliasingProperties>();
|
||||
|
|
Loading…
Reference in New Issue