[mlir] Extended BufferPlacement to support nested region control flow.
Summary: The current BufferPlacement implementation does not support nested region control flow. This CL adds support for nested regions via the RegionBranchOpInterface and the detection of branch-like (ReturnLike) terminators inside nested regions. Differential Revision: https://reviews.llvm.org/D81926
This commit is contained in:
parent
d28267f981
commit
6f5da84f7b
|
@ -65,8 +65,18 @@
|
|||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
/// Walks over all immediate return-like terminators in the given region.
|
||||
template <typename FuncT>
|
||||
static void walkReturnOperations(Region *region, const FuncT &func) {
|
||||
for (Block &block : *region)
|
||||
for (Operation &operation : block) {
|
||||
// Skip non-return-like terminators.
|
||||
if (operation.hasTrait<OpTrait::ReturnLike>())
|
||||
func(&operation);
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BufferPlacementAliasAnalysis
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -82,7 +92,7 @@ public:
|
|||
|
||||
public:
|
||||
/// Constructs a new alias analysis using the op provided.
|
||||
BufferPlacementAliasAnalysis(Operation *op) { build(op->getRegions()); }
|
||||
BufferPlacementAliasAnalysis(Operation *op) {
  // All of the work happens in `build`, which receives the operation itself.
  build(op);
}
|
||||
|
||||
/// Find all immediate aliases this value could potentially have.
|
||||
ValueMapT::const_iterator find(Value value) const {
|
||||
|
@ -102,7 +112,7 @@ public:
|
|||
}
|
||||
|
||||
/// Removes the given values from all alias sets.
|
||||
void remove(const SmallPtrSetImpl<BlockArgument> &aliasValues) {
|
||||
void remove(const SmallPtrSetImpl<Value> &aliasValues) {
|
||||
for (auto &entry : aliases)
|
||||
llvm::set_subtract(entry.second, aliasValues);
|
||||
}
|
||||
|
@ -123,33 +133,69 @@ private:
|
|||
/// This function constructs a mapping from values to its immediate aliases.
|
||||
/// It iterates over all blocks, gets their predecessors, determines the
|
||||
/// values that will be passed to the corresponding block arguments and
|
||||
/// inserts them into the underlying map.
|
||||
void build(MutableArrayRef<Region> regions) {
|
||||
for (Region ®ion : regions) {
|
||||
for (Block &block : region) {
|
||||
// Iterate over all predecessor and get the mapped values to their
|
||||
// corresponding block arguments values.
|
||||
for (auto it = block.pred_begin(), e = block.pred_end(); it != e;
|
||||
++it) {
|
||||
unsigned successorIndex = it.getSuccessorIndex();
|
||||
// Get the terminator and the values that will be passed to our block.
|
||||
auto branchInterface =
|
||||
dyn_cast<BranchOpInterface>((*it)->getTerminator());
|
||||
if (!branchInterface)
|
||||
continue;
|
||||
// Query the branch op interface to get the successor operands.
|
||||
auto successorOperands =
|
||||
branchInterface.getSuccessorOperands(successorIndex);
|
||||
if (successorOperands.hasValue()) {
|
||||
// Build the actual mapping of values to their immediate aliases.
|
||||
for (auto argPair : llvm::zip(block.getArguments(),
|
||||
successorOperands.getValue())) {
|
||||
aliases[std::get<1>(argPair)].insert(std::get<0>(argPair));
|
||||
}
|
||||
}
|
||||
/// inserts them into the underlying map. Furthermore, it wires successor
|
||||
/// regions and branch-like return operations from nested regions.
|
||||
void build(Operation *op) {
  // Records, position by position, each element of `aliases` as an immediate
  // alias of the element of `values` it is zipped with.
  auto registerAliases = [&](auto values, auto aliases) {
    for (auto entry : llvm::zip(values, aliases))
      this->aliases[std::get<0>(entry)].insert(std::get<1>(entry));
  };

  // Query all branch interfaces to link block argument aliases.
  op->walk([&](BranchOpInterface branchInterface) {
    Block *parentBlock = branchInterface.getOperation()->getBlock();
    for (auto it = parentBlock->succ_begin(), e = parentBlock->succ_end();
         it != e; ++it) {
      // Query the branch op interface to get the successor operands.
      auto successorOperands =
          branchInterface.getSuccessorOperands(it.getIndex());
      // A successor whose operands are not available contributes no aliases.
      if (!successorOperands.hasValue())
        continue;
      // Build the actual mapping of values to their immediate aliases.
      registerAliases(successorOperands.getValue(), (*it)->getArguments());
    }
  });

  // Query the RegionBranchOpInterface to find potential successor regions.
  op->walk([&](RegionBranchOpInterface regionInterface) {
    // Create an empty attribute for each operand to comply with the
    // `getSuccessorRegions` interface definition that requires a single
    // attribute per operand.
    SmallVector<Attribute, 2> operandAttributes(
        regionInterface.getOperation()->getNumOperands());

    // Extract all entry regions and wire all initial entry successor inputs.
    SmallVector<RegionSuccessor, 2> entrySuccessors;
    regionInterface.getSuccessorRegions(/*index=*/llvm::None,
                                        operandAttributes, entrySuccessors);
    for (RegionSuccessor &entrySuccessor : entrySuccessors) {
      // Wire the entry region's successor arguments with the initial
      // successor inputs.
      assert(entrySuccessor.getSuccessor() &&
             "Invalid entry region without an attached successor region");
      registerAliases(regionInterface.getSuccessorEntryOperands(
                          entrySuccessor.getSuccessor()->getRegionNumber()),
                      entrySuccessor.getSuccessorInputs());
    }

    // Wire flow between regions and from region exits.
    for (Region &region : regionInterface.getOperation()->getRegions()) {
      // Iterate over all successor region entries that are reachable from the
      // current region.
      SmallVector<RegionSuccessor, 2> successorRegions;
      regionInterface.getSuccessorRegions(
          region.getRegionNumber(), operandAttributes, successorRegions);
      for (RegionSuccessor &successorRegion : successorRegions) {
        // Iterate over all immediate terminator operations and wire the
        // successor inputs with the operands of each terminator.
        walkReturnOperations(&region, [&](Operation *terminator) {
          registerAliases(terminator->getOperands(),
                          successorRegion.getSuccessorInputs());
        });
      }
    }
  });
}
|
||||
|
||||
/// Maps values to all immediate aliases this value can have.
|
||||
|
@ -235,14 +281,24 @@ private:
|
|||
Block *getInitialAllocBlock(OpResult result) {
|
||||
// Get all allocation operands as these operands are important for the
|
||||
// allocation operation.
|
||||
auto operands = result.getOwner()->getOperands();
|
||||
Operation *owner = result.getOwner();
|
||||
auto operands = owner->getOperands();
|
||||
Block *dominator;
|
||||
if (operands.size() < 1)
|
||||
return findCommonDominator(result, aliases.resolve(result), dominators);
|
||||
dominator =
|
||||
findCommonDominator(result, aliases.resolve(result), dominators);
|
||||
else {
|
||||
// If this node has dependencies, check all dependent nodes with respect
|
||||
// to a common post dominator in which all values are available.
|
||||
ValueSetT dependencies(++operands.begin(), operands.end());
|
||||
dominator =
|
||||
findCommonDominator(*operands.begin(), dependencies, postDominators);
|
||||
}
|
||||
|
||||
// If this node has dependencies, check all dependent nodes with respect
|
||||
// to a common post dominator in which all values are available.
|
||||
ValueSetT dependencies(++operands.begin(), operands.end());
|
||||
return findCommonDominator(*operands.begin(), dependencies, postDominators);
|
||||
// Do not move allocs out of their parent regions to keep them local.
|
||||
if (dominator->getParent() != owner->getParentRegion())
|
||||
return &owner->getParentRegion()->front();
|
||||
return dominator;
|
||||
}
|
||||
|
||||
/// Finds correct alloc positions according to the algorithm described at
|
||||
|
@ -273,12 +329,12 @@ private:
|
|||
|
||||
/// Introduces required allocs and copy operations to avoid memory leaks.
|
||||
void introduceCopies() {
|
||||
// Initialize the set of block arguments that require a dedicated memory
|
||||
// free operation since their arguments cannot be safely deallocated in a
|
||||
// post dominator.
|
||||
SmallPtrSet<BlockArgument, 8> blockArgsToFree;
|
||||
llvm::SmallDenseSet<std::tuple<BlockArgument, Block *>> visitedBlockArgs;
|
||||
SmallVector<std::tuple<BlockArgument, Block *>, 8> toProcess;
|
||||
// Initialize the set of values that require a dedicated memory free
|
||||
// operation since their operands cannot be safely deallocated in a post
|
||||
// dominator.
|
||||
SmallPtrSet<Value, 8> valuesToFree;
|
||||
llvm::SmallDenseSet<std::tuple<Value, Block *>> visitedValues;
|
||||
SmallVector<std::tuple<Value, Block *>, 8> toProcess;
|
||||
|
||||
// Check dominance relation for proper dominance properties. If the given
|
||||
// value node does not dominate an alias, we will have to create a copy in
|
||||
|
@ -289,17 +345,15 @@ private:
|
|||
if (it == aliases.end())
|
||||
return;
|
||||
for (Value value : it->second) {
|
||||
auto blockArg = value.cast<BlockArgument>();
|
||||
if (blockArgsToFree.count(blockArg) > 0)
|
||||
if (valuesToFree.count(value) > 0)
|
||||
continue;
|
||||
// Check whether we have to free this particular block argument.
|
||||
if (!dominators.dominates(definingBlock, blockArg.getOwner())) {
|
||||
toProcess.emplace_back(blockArg, blockArg.getParentBlock());
|
||||
blockArgsToFree.insert(blockArg);
|
||||
} else if (visitedBlockArgs
|
||||
.insert(std::make_tuple(blockArg, definingBlock))
|
||||
if (!dominators.dominates(definingBlock, value.getParentBlock())) {
|
||||
toProcess.emplace_back(value, value.getParentBlock());
|
||||
valuesToFree.insert(value);
|
||||
} else if (visitedValues.insert(std::make_tuple(value, definingBlock))
|
||||
.second)
|
||||
toProcess.emplace_back(blockArg, definingBlock);
|
||||
toProcess.emplace_back(value, definingBlock);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -316,60 +370,168 @@ private:
|
|||
|
||||
// Update buffer aliases to ensure that we free all buffers and block
|
||||
// arguments at the correct locations.
|
||||
aliases.remove(blockArgsToFree);
|
||||
aliases.remove(valuesToFree);
|
||||
|
||||
// Add new allocs and additional copy operations.
|
||||
for (BlockArgument blockArg : blockArgsToFree) {
|
||||
Block *block = blockArg.getOwner();
|
||||
for (Value value : valuesToFree) {
|
||||
if (auto blockArg = value.dyn_cast<BlockArgument>())
|
||||
introduceBlockArgCopy(blockArg);
|
||||
else
|
||||
introduceValueCopyForRegionResult(value);
|
||||
|
||||
// Allocate a buffer for the current block argument in the block of
|
||||
// the associated value (which will be a predecessor block by
|
||||
// definition).
|
||||
for (auto it = block->pred_begin(), e = block->pred_end(); it != e;
|
||||
++it) {
|
||||
// Get the terminator and the value that will be passed to our
|
||||
// argument.
|
||||
Operation *terminator = (*it)->getTerminator();
|
||||
auto branchInterface = cast<BranchOpInterface>(terminator);
|
||||
// Convert the mutable operand range to an immutable range and query the
|
||||
// associated source value.
|
||||
Value sourceValue =
|
||||
branchInterface.getSuccessorOperands(it.getSuccessorIndex())
|
||||
.getValue()[blockArg.getArgNumber()];
|
||||
// Create a new alloc at the current location of the terminator.
|
||||
auto memRefType = sourceValue.getType().cast<MemRefType>();
|
||||
OpBuilder builder(terminator);
|
||||
// Register the value to require a final dealloc. Note that we do not have
|
||||
// to assign a block here since we do not want to move the allocation node
|
||||
// to another location.
|
||||
allocs.push_back({value, nullptr, nullptr});
|
||||
}
|
||||
}
|
||||
|
||||
// Extract information about dynamically shaped types by
|
||||
// extracting their dynamic dimensions.
|
||||
SmallVector<Value, 4> dynamicOperands;
|
||||
for (auto shapeElement : llvm::enumerate(memRefType.getShape())) {
|
||||
if (!ShapedType::isDynamic(shapeElement.value()))
|
||||
continue;
|
||||
dynamicOperands.push_back(builder.create<DimOp>(
|
||||
terminator->getLoc(), sourceValue, shapeElement.index()));
|
||||
}
|
||||
|
||||
// TODO: provide a generic interface to create dialect-specific
|
||||
// Alloc and CopyOp nodes.
|
||||
auto alloc = builder.create<AllocOp>(terminator->getLoc(), memRefType,
|
||||
dynamicOperands);
|
||||
// Wire new alloc and successor operand.
|
||||
branchInterface.getMutableSuccessorOperands(it.getSuccessorIndex())
|
||||
.getValue()
|
||||
/// Introduces temporary allocs in all predecessors and copies the source
|
||||
/// values into the newly allocated buffers.
|
||||
void introduceBlockArgCopy(BlockArgument blockArg) {
|
||||
// Allocate a buffer for the current block argument in the block of
|
||||
// the associated value (which will be a predecessor block by
|
||||
// definition).
|
||||
Block *block = blockArg.getOwner();
|
||||
for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) {
|
||||
// Get the terminator and the value that will be passed to our
|
||||
// argument.
|
||||
Operation *terminator = (*it)->getTerminator();
|
||||
auto branchInterface = cast<BranchOpInterface>(terminator);
|
||||
// Query the associated source value.
|
||||
Value sourceValue =
|
||||
branchInterface.getSuccessorOperands(it.getSuccessorIndex())
|
||||
.getValue()[blockArg.getArgNumber()];
|
||||
// Create a new alloc and copy at the current location of the terminator.
|
||||
Value alloc = introduceBufferCopy(sourceValue, terminator);
|
||||
// Wire new alloc and successor operand.
|
||||
auto mutableOperands =
|
||||
branchInterface.getMutableSuccessorOperands(it.getSuccessorIndex());
|
||||
if (!mutableOperands.hasValue())
|
||||
terminator->emitError() << "terminators with immutable successor "
|
||||
"operands are not supported";
|
||||
else
|
||||
mutableOperands.getValue()
|
||||
.slice(blockArg.getArgNumber(), 1)
|
||||
.assign(alloc);
|
||||
// Create a new copy operation that copies the contents of the old
|
||||
// allocation to the new one.
|
||||
builder.create<linalg::CopyOp>(terminator->getLoc(), sourceValue,
|
||||
alloc);
|
||||
}
|
||||
|
||||
// Register the block argument to require a final dealloc. Note that
|
||||
// we do not have to assign a block here since we do not want to
|
||||
// move the allocation node to another location.
|
||||
allocs.push_back({blockArg, nullptr, nullptr});
|
||||
}
|
||||
|
||||
// Check whether the block argument has implicitly defined predecessors via
|
||||
// the RegionBranchOpInterface. This can be the case if the current block
|
||||
// argument belongs to the first block in a region and the parent operation
|
||||
// implements the RegionBranchOpInterface.
|
||||
Region *argRegion = block->getParent();
|
||||
RegionBranchOpInterface regionInterface;
|
||||
if (!argRegion || &argRegion->front() != block ||
|
||||
!(regionInterface =
|
||||
dyn_cast<RegionBranchOpInterface>(argRegion->getParentOp())))
|
||||
return;
|
||||
|
||||
introduceCopiesForRegionSuccessors(
|
||||
regionInterface, argRegion->getParentOp()->getRegions(),
|
||||
[&](RegionSuccessor &successorRegion) {
|
||||
// Find a predecessor of our argRegion.
|
||||
return successorRegion.getSuccessor() == argRegion;
|
||||
},
|
||||
[&](RegionSuccessor &successorRegion) {
|
||||
// The operand index will be the argument number.
|
||||
return blockArg.getArgNumber();
|
||||
});
|
||||
}
|
||||
|
||||
/// Introduces temporary allocs in front of all associated nested-region
|
||||
/// terminators and copies the source values into the newly allocated buffers.
|
||||
void introduceValueCopyForRegionResult(Value value) {
|
||||
// Get the actual result index in the scope of the parent terminator.
|
||||
Operation *operation = value.getDefiningOp();
|
||||
auto regionInterface = cast<RegionBranchOpInterface>(operation);
|
||||
introduceCopiesForRegionSuccessors(
|
||||
regionInterface, operation->getRegions(),
|
||||
[&](RegionSuccessor &successorRegion) {
|
||||
// Determine whether this region has a successor entry that leaves
|
||||
// this region by returning to its parent operation.
|
||||
return !successorRegion.getSuccessor();
|
||||
},
|
||||
[&](RegionSuccessor &successorRegion) {
|
||||
// Find the associated success input index.
|
||||
return llvm::find(successorRegion.getSuccessorInputs(), value)
|
||||
.getIndex();
|
||||
});
|
||||
}
|
||||
|
||||
/// Introduces buffer copies for all terminators in the given regions. The
|
||||
/// regionPredicate is applied to every successor region in order to restrict
|
||||
/// the copies to specific regions. Thereby, the operandProvider is invoked
|
||||
/// for each matching region successor and determines the operand index that
|
||||
/// requires a buffer copy.
|
||||
template <typename TPredicate, typename TOperandProvider>
|
||||
void
|
||||
introduceCopiesForRegionSuccessors(RegionBranchOpInterface regionInterface,
|
||||
MutableArrayRef<Region> regions,
|
||||
const TPredicate ®ionPredicate,
|
||||
const TOperandProvider &operandProvider) {
|
||||
// Create an empty attribute for each operand to comply with the
|
||||
// `getSuccessorRegions` interface definition that requires a single
|
||||
// attribute per operand.
|
||||
SmallVector<Attribute, 2> operandAttributes(
|
||||
regionInterface.getOperation()->getNumOperands());
|
||||
for (Region ®ion : regions) {
|
||||
// Query the regionInterface to get all successor regions of the current
|
||||
// one.
|
||||
SmallVector<RegionSuccessor, 2> successorRegions;
|
||||
regionInterface.getSuccessorRegions(region.getRegionNumber(),
|
||||
operandAttributes, successorRegions);
|
||||
// Try to find a matching region successor.
|
||||
RegionSuccessor *regionSuccessor =
|
||||
llvm::find_if(successorRegions, regionPredicate);
|
||||
if (regionSuccessor == successorRegions.end())
|
||||
continue;
|
||||
// Get the operand index in the context of the current successor input
|
||||
// bindings.
|
||||
auto operandIndex = operandProvider(*regionSuccessor);
|
||||
|
||||
// Iterate over all immediate terminator operations to introduce
|
||||
// new buffer allocations. Thereby, the appropriate terminator operand
|
||||
// will be adjusted to point to the newly allocated buffer instead.
|
||||
walkReturnOperations(®ion, [&](Operation *terminator) {
|
||||
// Extract the source value from the current terminator.
|
||||
Value sourceValue = terminator->getOperand(operandIndex);
|
||||
// Create a new alloc at the current location of the terminator.
|
||||
Value alloc = introduceBufferCopy(sourceValue, terminator);
|
||||
// Wire alloc and terminator operand.
|
||||
terminator->setOperand(operandIndex, alloc);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a new memory allocation for the given source value and copies
|
||||
/// its content into the newly allocated buffer. The terminator operation is
|
||||
/// used to insert the alloc and copy operations at the right places.
|
||||
Value introduceBufferCopy(Value sourceValue, Operation *terminator) {
|
||||
// Create a new alloc at the current location of the terminator.
|
||||
auto memRefType = sourceValue.getType().cast<MemRefType>();
|
||||
OpBuilder builder(terminator);
|
||||
|
||||
// Extract information about dynamically shaped types by
|
||||
// extracting their dynamic dimensions.
|
||||
SmallVector<Value, 4> dynamicOperands;
|
||||
for (auto shapeElement : llvm::enumerate(memRefType.getShape())) {
|
||||
if (!ShapedType::isDynamic(shapeElement.value()))
|
||||
continue;
|
||||
dynamicOperands.push_back(builder.create<DimOp>(
|
||||
terminator->getLoc(), sourceValue, shapeElement.index()));
|
||||
}
|
||||
|
||||
// TODO: provide a generic interface to create dialect-specific
|
||||
// Alloc and CopyOp nodes.
|
||||
auto alloc = builder.create<AllocOp>(terminator->getLoc(), memRefType,
|
||||
dynamicOperands);
|
||||
|
||||
// Create a new copy operation that copies to contents of the old
|
||||
// allocation to the new one.
|
||||
builder.create<linalg::CopyOp>(terminator->getLoc(), sourceValue, alloc);
|
||||
|
||||
return alloc;
|
||||
}
|
||||
|
||||
/// Finds associated deallocs that can be linked to our allocation nodes (if
|
||||
|
@ -440,8 +602,8 @@ private:
|
|||
if (entry.deallocOperation) {
|
||||
entry.deallocOperation->moveAfter(endOperation);
|
||||
} else {
|
||||
// If the Dealloc position is at the terminator operation of the block,
|
||||
// then the value should escape from a deallocation.
|
||||
// If the Dealloc position is at the terminator operation of the
|
||||
// block, then the value should escape from a deallocation.
|
||||
Operation *nextOp = endOperation->getNextNode();
|
||||
if (!nextOp)
|
||||
continue;
|
||||
|
|
|
@ -716,3 +716,201 @@ func @memref_in_function_results(%arg0: memref<5xf32>, %arg1: memref<10xf32>, %a
|
|||
// CHECK: dealloc %[[Y]]
|
||||
// CHECK: return %[[ARG1]], %[[X]]
|
||||
|
||||
// -----
|
||||
|
||||
// Test Case: nested region control flow
|
||||
// The alloc position of %1 does not need to be changed and flows through
|
||||
// both if branches until it is finally returned. Hence, it does not
|
||||
// require a specific dealloc operation. However, %3 requires a dealloc.
|
||||
|
||||
// CHECK-LABEL: func @nested_region_control_flow
|
||||
func @nested_region_control_flow(
|
||||
%arg0 : index,
|
||||
%arg1 : index) -> memref<?x?xf32> {
|
||||
%0 = cmpi "eq", %arg0, %arg1 : index
|
||||
%1 = alloc(%arg0, %arg0) : memref<?x?xf32>
|
||||
%2 = scf.if %0 -> (memref<?x?xf32>) {
|
||||
scf.yield %1 : memref<?x?xf32>
|
||||
} else {
|
||||
%3 = alloc(%arg0, %arg1) : memref<?x?xf32>
|
||||
scf.yield %1 : memref<?x?xf32>
|
||||
}
|
||||
return %2 : memref<?x?xf32>
|
||||
}
|
||||
|
||||
// CHECK: %[[ALLOC0:.*]] = alloc(%arg0, %arg0)
|
||||
// CHECK-NEXT: %[[ALLOC1:.*]] = scf.if
|
||||
// CHECK: scf.yield %[[ALLOC0]]
|
||||
// CHECK: %[[ALLOC2:.*]] = alloc(%arg0, %arg1)
|
||||
// CHECK-NEXT: dealloc %[[ALLOC2]]
|
||||
// CHECK-NEXT: scf.yield %[[ALLOC0]]
|
||||
// CHECK: return %[[ALLOC1]]
|
||||
|
||||
// -----
|
||||
|
||||
// Test Case: nested region control flow with a nested buffer allocation in a
|
||||
// divergent branch.
|
||||
// The alloc positions of %1 and %3 do not need to be changed since
|
||||
// BufferPlacement does not move allocs out of nested regions at the moment.
|
||||
// However, since %3 is allocated and "returned" in a divergent branch, we have
|
||||
// to allocate a temporary buffer (like in condBranchDynamicTypeNested).
|
||||
|
||||
// CHECK-LABEL: func @nested_region_control_flow_div
|
||||
func @nested_region_control_flow_div(
|
||||
%arg0 : index,
|
||||
%arg1 : index) -> memref<?x?xf32> {
|
||||
%0 = cmpi "eq", %arg0, %arg1 : index
|
||||
%1 = alloc(%arg0, %arg0) : memref<?x?xf32>
|
||||
%2 = scf.if %0 -> (memref<?x?xf32>) {
|
||||
scf.yield %1 : memref<?x?xf32>
|
||||
} else {
|
||||
%3 = alloc(%arg0, %arg1) : memref<?x?xf32>
|
||||
scf.yield %3 : memref<?x?xf32>
|
||||
}
|
||||
return %2 : memref<?x?xf32>
|
||||
}
|
||||
|
||||
// CHECK: %[[ALLOC0:.*]] = alloc(%arg0, %arg0)
|
||||
// CHECK-NEXT: %[[ALLOC1:.*]] = scf.if
|
||||
// CHECK: %[[ALLOC2:.*]] = alloc
|
||||
// CHECK-NEXT: linalg.copy(%[[ALLOC0]], %[[ALLOC2]])
|
||||
// CHECK: scf.yield %[[ALLOC2]]
|
||||
// CHECK: %[[ALLOC3:.*]] = alloc(%arg0, %arg1)
|
||||
// CHECK: %[[ALLOC4:.*]] = alloc
|
||||
// CHECK-NEXT: linalg.copy(%[[ALLOC3]], %[[ALLOC4]])
|
||||
// CHECK: dealloc %[[ALLOC3]]
|
||||
// CHECK: scf.yield %[[ALLOC4]]
|
||||
// CHECK: dealloc %[[ALLOC0]]
|
||||
// CHECK-NEXT: return %[[ALLOC1]]
|
||||
|
||||
// -----
|
||||
|
||||
// Test Case: deeply nested region control flow with a nested buffer allocation
|
||||
// in a divergent branch.
|
||||
// The alloc positions of %1, %4 and %5 do not need to be changed since
|
||||
// BufferPlacement does not move allocs out of nested regions at the moment.
|
||||
// However, since %4 is allocated and "returned" in a divergent branch, we have
|
||||
// to allocate several temporary buffers (like in condBranchDynamicTypeNested).
|
||||
|
||||
// CHECK-LABEL: func @nested_region_control_flow_div_nested
|
||||
func @nested_region_control_flow_div_nested(
|
||||
%arg0 : index,
|
||||
%arg1 : index) -> memref<?x?xf32> {
|
||||
%0 = cmpi "eq", %arg0, %arg1 : index
|
||||
%1 = alloc(%arg0, %arg0) : memref<?x?xf32>
|
||||
%2 = scf.if %0 -> (memref<?x?xf32>) {
|
||||
%3 = scf.if %0 -> (memref<?x?xf32>) {
|
||||
scf.yield %1 : memref<?x?xf32>
|
||||
} else {
|
||||
%4 = alloc(%arg0, %arg1) : memref<?x?xf32>
|
||||
scf.yield %4 : memref<?x?xf32>
|
||||
}
|
||||
scf.yield %3 : memref<?x?xf32>
|
||||
} else {
|
||||
%5 = alloc(%arg1, %arg1) : memref<?x?xf32>
|
||||
scf.yield %5 : memref<?x?xf32>
|
||||
}
|
||||
return %2 : memref<?x?xf32>
|
||||
}
|
||||
// CHECK: %[[ALLOC0:.*]] = alloc(%arg0, %arg0)
|
||||
// CHECK-NEXT: %[[ALLOC1:.*]] = scf.if
|
||||
// CHECK-NEXT: %[[ALLOC2:.*]] = scf.if
|
||||
// CHECK: %[[ALLOC3:.*]] = alloc
|
||||
// CHECK-NEXT: linalg.copy(%[[ALLOC0]], %[[ALLOC3]])
|
||||
// CHECK: scf.yield %[[ALLOC3]]
|
||||
// CHECK: %[[ALLOC4:.*]] = alloc(%arg0, %arg1)
|
||||
// CHECK: %[[ALLOC5:.*]] = alloc
|
||||
// CHECK-NEXT: linalg.copy(%[[ALLOC4]], %[[ALLOC5]])
|
||||
// CHECK: dealloc %[[ALLOC4]]
|
||||
// CHECK: scf.yield %[[ALLOC5]]
|
||||
// CHECK: %[[ALLOC6:.*]] = alloc
|
||||
// CHECK-NEXT: linalg.copy(%[[ALLOC2]], %[[ALLOC6]])
|
||||
// CHECK: dealloc %[[ALLOC2]]
|
||||
// CHECK: scf.yield %[[ALLOC6]]
|
||||
// CHECK: %[[ALLOC7:.*]] = alloc(%arg1, %arg1)
|
||||
// CHECK: %[[ALLOC8:.*]] = alloc
|
||||
// CHECK-NEXT: linalg.copy(%[[ALLOC7]], %[[ALLOC8]])
|
||||
// CHECK: dealloc %[[ALLOC7]]
|
||||
// CHECK: scf.yield %[[ALLOC8]]
|
||||
// CHECK: dealloc %[[ALLOC0]]
|
||||
// CHECK-NEXT: return %[[ALLOC1]]
|
||||
|
||||
// -----
|
||||
|
||||
// Test Case: nested region control flow within a region interface.
|
||||
// The alloc position of %0 does not need to be changed and no copies are
|
||||
// required in this case since the allocation finally escapes the method.
|
||||
|
||||
// CHECK-LABEL: func @inner_region_control_flow
|
||||
func @inner_region_control_flow(%arg0 : index) -> memref<?x?xf32> {
|
||||
%0 = alloc(%arg0, %arg0) : memref<?x?xf32>
|
||||
%1 = test.region_if %0 : memref<?x?xf32> -> (memref<?x?xf32>) then {
|
||||
^bb0(%arg1 : memref<?x?xf32>):
|
||||
test.region_if_yield %arg1 : memref<?x?xf32>
|
||||
} else {
|
||||
^bb0(%arg1 : memref<?x?xf32>):
|
||||
test.region_if_yield %arg1 : memref<?x?xf32>
|
||||
} join {
|
||||
^bb0(%arg1 : memref<?x?xf32>):
|
||||
test.region_if_yield %arg1 : memref<?x?xf32>
|
||||
}
|
||||
return %1 : memref<?x?xf32>
|
||||
}
|
||||
|
||||
// CHECK: %[[ALLOC0:.*]] = alloc(%arg0, %arg0)
|
||||
// CHECK-NEXT: %[[ALLOC1:.*]] = test.region_if
|
||||
// CHECK-NEXT: ^bb0(%[[ALLOC2:.*]]:{{.*}}):
|
||||
// CHECK-NEXT: test.region_if_yield %[[ALLOC2]]
|
||||
// CHECK: ^bb0(%[[ALLOC3:.*]]:{{.*}}):
|
||||
// CHECK-NEXT: test.region_if_yield %[[ALLOC3]]
|
||||
// CHECK: ^bb0(%[[ALLOC4:.*]]:{{.*}}):
|
||||
// CHECK-NEXT: test.region_if_yield %[[ALLOC4]]
|
||||
// CHECK: return %[[ALLOC1]]
|
||||
|
||||
// -----
|
||||
|
||||
// Test Case: nested region control flow within a region interface including an
|
||||
// allocation in a divergent branch.
|
||||
// The alloc positions of %0 and %2 do not need to be changed since
|
||||
// BufferPlacement does not move allocs out of nested regions at the moment.
|
||||
// However, since %2 is allocated and yielded in a divergent branch, we have
|
||||
// to allocate several temporary buffers (like in condBranchDynamicTypeNested).
|
||||
|
||||
// CHECK-LABEL: func @inner_region_control_flow_div
|
||||
func @inner_region_control_flow_div(
|
||||
%arg0 : index,
|
||||
%arg1 : index) -> memref<?x?xf32> {
|
||||
%0 = alloc(%arg0, %arg0) : memref<?x?xf32>
|
||||
%1 = test.region_if %0 : memref<?x?xf32> -> (memref<?x?xf32>) then {
|
||||
^bb0(%arg2 : memref<?x?xf32>):
|
||||
test.region_if_yield %arg2 : memref<?x?xf32>
|
||||
} else {
|
||||
^bb0(%arg2 : memref<?x?xf32>):
|
||||
%2 = alloc(%arg0, %arg1) : memref<?x?xf32>
|
||||
test.region_if_yield %2 : memref<?x?xf32>
|
||||
} join {
|
||||
^bb0(%arg2 : memref<?x?xf32>):
|
||||
test.region_if_yield %arg2 : memref<?x?xf32>
|
||||
}
|
||||
return %1 : memref<?x?xf32>
|
||||
}
|
||||
|
||||
// CHECK: %[[ALLOC0:.*]] = alloc(%arg0, %arg0)
|
||||
// CHECK-NEXT: %[[ALLOC1:.*]] = test.region_if
|
||||
// CHECK-NEXT: ^bb0(%[[ALLOC2:.*]]:{{.*}}):
|
||||
// CHECK: %[[ALLOC3:.*]] = alloc
|
||||
// CHECK-NEXT: linalg.copy(%[[ALLOC2]], %[[ALLOC3]])
|
||||
// CHECK-NEXT: test.region_if_yield %[[ALLOC3]]
|
||||
// CHECK: ^bb0(%[[ALLOC4:.*]]:{{.*}}):
|
||||
// CHECK: %[[ALLOC5:.*]] = alloc
|
||||
// CHECK: %[[ALLOC6:.*]] = alloc
|
||||
// CHECK-NEXT: linalg.copy(%[[ALLOC5]], %[[ALLOC6]])
|
||||
// CHECK-NEXT: dealloc %[[ALLOC5]]
|
||||
// CHECK-NEXT: test.region_if_yield %[[ALLOC6]]
|
||||
// CHECK: ^bb0(%[[ALLOC7:.*]]:{{.*}}):
|
||||
// CHECK: %[[ALLOC8:.*]] = alloc
|
||||
// CHECK-NEXT: linalg.copy(%[[ALLOC7]], %[[ALLOC8]])
|
||||
// CHECK-NEXT: dealloc %[[ALLOC7]]
|
||||
// CHECK-NEXT: test.region_if_yield %[[ALLOC8]]
|
||||
// CHECK: dealloc %[[ALLOC0]]
|
||||
// CHECK-NEXT: return %[[ALLOC1]]
|
||||
|
|
|
@ -518,6 +518,77 @@ void StringAttrPrettyNameOp::getAsmResultNames(
|
|||
setNameFn(getResult(i), str.getValue());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// RegionIfOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Prints a RegionIfOp as: the operation name, its operands, `: ` followed by
/// the operand types, an arrow type list with the result types, and finally
/// the then, else and join regions, each introduced by its keyword.
static void print(OpAsmPrinter &p, RegionIfOp op) {
  // Emits the keyword followed by the region, including the entry block
  // arguments and the block terminators.
  auto printKeywordAndRegion = [&p](const char *keyword, Region &body) {
    p << keyword;
    p.printRegion(body,
                  /*printEntryBlockArgs=*/true,
                  /*printBlockTerminators=*/true);
  };

  p << RegionIfOp::getOperationName() << " ";
  p.printOperands(op.getOperands());
  p << ": " << op.getOperandTypes();
  p.printArrowTypeList(op.getResultTypes());

  printKeywordAndRegion(" then", op.thenRegion());
  printKeywordAndRegion(" else", op.elseRegion());
  printKeywordAndRegion(" join", op.joinRegion());
}
|
||||
|
||||
/// Parses a RegionIfOp: an operand list, a colon type list with the operand
/// types, an arrow type list with the result types, and the three regions
/// introduced by the keywords `then`, `else` and `join` (in this order). The
/// operands are resolved against the parsed operand types at the very end.
static ParseResult parseRegionIfOp(OpAsmParser &parser,
                                   OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 2> parsedOperands;
  SmallVector<Type, 2> parsedOperandTypes;

  // The regions are attached to the result state up front, in the order
  // then, else, join.
  result.regions.reserve(3);
  Region *thenBody = result.addRegion();
  Region *elseBody = result.addRegion();
  Region *joinBody = result.addRegion();

  // Operand list, operand types and result types.
  if (parser.parseOperandList(parsedOperands))
    return failure();
  if (parser.parseColonTypeList(parsedOperandTypes))
    return failure();
  if (parser.parseArrowTypeList(result.types))
    return failure();

  // Each region is preceded by its keyword.
  if (parser.parseKeyword("then") || parser.parseRegion(*thenBody, {}, {}))
    return failure();
  if (parser.parseKeyword("else") || parser.parseRegion(*elseBody, {}, {}))
    return failure();
  if (parser.parseKeyword("join") || parser.parseRegion(*joinBody, {}, {}))
    return failure();

  return parser.resolveOperands(parsedOperands, parsedOperandTypes,
                                parser.getCurrentLocation(), result.operands);
}
|
||||
|
||||
/// Returns the operands passed to the region with the given index when it is
/// entered. Only the indices 0 and 1 are accepted; both receive all operands
/// of this op.
OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) {
  // Any other region index is rejected by the assertion.
  assert(index < 2 && "invalid region index");

  return getOperands();
}
|
||||
|
||||
void RegionIfOp::getSuccessorRegions(
|
||||
Optional<unsigned> index, ArrayRef<Attribute> operands,
|
||||
SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
// We always branch to the join region.
|
||||
if (index.hasValue()) {
|
||||
if (index.getValue() < 2)
|
||||
regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs()));
|
||||
else
|
||||
regions.push_back(RegionSuccessor(getResults()));
|
||||
return;
|
||||
}
|
||||
|
||||
// The then and else regions are the entry regions of this op.
|
||||
regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs()));
|
||||
regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs()));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Dialect Registration
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -1349,4 +1349,47 @@ def SideEffectOp : TEST_Op<"side_effect_op",
|
|||
let results = (outs AnyType:$result);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test RegionBranchOpInterface
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Test op `region_if_yield`: carries the traits NoSideEffect, ReturnLike and
// Terminator and takes a variadic list of operands of any type, named
// `results`. Its assembly format is the operand list, a colon, the operand
// types and the attribute dictionary.
def RegionIfYieldOp : TEST_Op<"region_if_yield",
|
||||
[NoSideEffect, ReturnLike, Terminator]> {
|
||||
let arguments = (ins Variadic<AnyType>:$results);
|
||||
let assemblyFormat = [{
|
||||
$results `:` type($results) attr-dict
|
||||
}];
|
||||
}
|
||||
|
||||
// Test op `region_if` with three regions: thenRegion, elseRegion and
// joinRegion. It declares the RegionBranchOpInterface methods, uses
// RegionIfYieldOp as single-block implicit terminator and has recursive side
// effects. Operands and results are variadic and of any type. Printing and
// parsing are delegated to `::print` and `::parseRegionIfOp`. The extra class
// declaration exposes the block arguments of the bodies 0, 1 and 2 as
// getThenArgs / getElseArgs / getJoinArgs and declares
// getSuccessorEntryOperands.
// NOTE(review): only thenRegion is constrained with SizedRegion<1>, while
// elseRegion and joinRegion are AnyRegion - confirm that this is intended.
def RegionIfOp : TEST_Op<"region_if",
|
||||
[DeclareOpInterfaceMethods<RegionBranchOpInterface>,
|
||||
SingleBlockImplicitTerminator<"RegionIfYieldOp">,
|
||||
RecursiveSideEffects]> {
|
||||
let description =[{
|
||||
Represents an abstract if-then-else-join pattern. In this context, the then
|
||||
and else regions jump to the join region, which finally returns to its
|
||||
parent op.
|
||||
}];
|
||||
|
||||
let printer = [{ return ::print(p, *this); }];
|
||||
let parser = [{ return ::parseRegionIfOp(parser, result); }];
|
||||
let arguments = (ins Variadic<AnyType>);
|
||||
let results = (outs Variadic<AnyType>:$results);
|
||||
let regions = (region SizedRegion<1>:$thenRegion,
|
||||
AnyRegion:$elseRegion,
|
||||
AnyRegion:$joinRegion);
|
||||
let extraClassDeclaration = [{
|
||||
Block::BlockArgListType getThenArgs() {
|
||||
return getBody(0)->getArguments();
|
||||
}
|
||||
Block::BlockArgListType getElseArgs() {
|
||||
return getBody(1)->getArguments();
|
||||
}
|
||||
Block::BlockArgListType getJoinArgs() {
|
||||
return getBody(2)->getArguments();
|
||||
}
|
||||
OperandRange getSuccessorEntryOperands(unsigned index);
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // TEST_OPS
|
||||
|
|
Loading…
Reference in New Issue