[mlir] Extended BufferPlacement to support nested region control flow.

Summary: The current BufferPlacement implementation does not support
nested region control flow. This CL adds support for nested regions via
the RegionBranchOpInterface and the detection of branch-like
(ReturnLike) terminators inside nested regions.

Differential Revision: https://reviews.llvm.org/D81926
This commit is contained in:
Marcel Koester 2020-06-30 11:58:45 +02:00 committed by Stephan Herhut
parent d28267f981
commit 6f5da84f7b
4 changed files with 571 additions and 97 deletions

View File

@ -65,8 +65,18 @@
using namespace mlir;
namespace {
/// Invokes `func` on every return-like terminator (ReturnLike trait) that is
/// immediately nested within the given region.
template <typename FuncT>
static void walkReturnOperations(Region *region, const FuncT &func) {
  for (Block &block : *region) {
    for (Operation &operation : block) {
      // Only return-like terminators are of interest here.
      if (!operation.hasTrait<OpTrait::ReturnLike>())
        continue;
      func(&operation);
    }
  }
}
namespace {
//===----------------------------------------------------------------------===//
// BufferPlacementAliasAnalysis
//===----------------------------------------------------------------------===//
@ -82,7 +92,7 @@ public:
public:
/// Constructs a new alias analysis using the op provided.
BufferPlacementAliasAnalysis(Operation *op) { build(op->getRegions()); }
BufferPlacementAliasAnalysis(Operation *op) { build(op); }
/// Find all immediate aliases this value could potentially have.
ValueMapT::const_iterator find(Value value) const {
@ -102,7 +112,7 @@ public:
}
/// Removes the given values from all alias sets.
void remove(const SmallPtrSetImpl<BlockArgument> &aliasValues) {
void remove(const SmallPtrSetImpl<Value> &aliasValues) {
for (auto &entry : aliases)
llvm::set_subtract(entry.second, aliasValues);
}
@ -123,33 +133,69 @@ private:
/// This function constructs a mapping from values to its immediate aliases.
/// It iterates over all blocks, gets their predecessors, determines the
/// values that will be passed to the corresponding block arguments and
/// inserts them into the underlying map.
void build(MutableArrayRef<Region> regions) {
for (Region &region : regions) {
for (Block &block : region) {
// Iterate over all predecessor and get the mapped values to their
// corresponding block arguments values.
for (auto it = block.pred_begin(), e = block.pred_end(); it != e;
++it) {
unsigned successorIndex = it.getSuccessorIndex();
// Get the terminator and the values that will be passed to our block.
auto branchInterface =
dyn_cast<BranchOpInterface>((*it)->getTerminator());
if (!branchInterface)
continue;
// Query the branch op interface to get the successor operands.
auto successorOperands =
branchInterface.getSuccessorOperands(successorIndex);
if (successorOperands.hasValue()) {
// Build the actual mapping of values to their immediate aliases.
for (auto argPair : llvm::zip(block.getArguments(),
successorOperands.getValue())) {
aliases[std::get<1>(argPair)].insert(std::get<0>(argPair));
}
}
/// inserts them into the underlying map. Furthermore, it wires successor
/// regions and branch-like return operations from nested regions.
/// Constructs the alias mapping for the given operation. Block-argument
/// aliases introduced via classical branches (BranchOpInterface) are
/// registered first; afterwards, region entry/exit aliases introduced via
/// nested region control flow (RegionBranchOpInterface together with
/// return-like terminators) are wired up.
/// Note: the rendered diff contained a stray extra closing brace that made
/// this block unbalanced; it has been removed.
void build(Operation *op) {
  // Registers all aliases of the given values.
  auto registerAliases = [&](auto values, auto aliases) {
    for (auto entry : llvm::zip(values, aliases))
      this->aliases[std::get<0>(entry)].insert(std::get<1>(entry));
  };

  // Query all branch interfaces to link block argument aliases.
  op->walk([&](BranchOpInterface branchInterface) {
    Block *parentBlock = branchInterface.getOperation()->getBlock();
    for (auto it = parentBlock->succ_begin(), e = parentBlock->succ_end();
         it != e; ++it) {
      // Query the branch op interface to get the successor operands.
      auto successorOperands =
          branchInterface.getSuccessorOperands(it.getIndex());
      if (!successorOperands.hasValue())
        continue;
      // Build the actual mapping of values to their immediate aliases.
      registerAliases(successorOperands.getValue(), (*it)->getArguments());
    }
  });

  // Query the RegionBranchOpInterface to find potential successor regions.
  op->walk([&](RegionBranchOpInterface regionInterface) {
    // Create an empty attribute for each operand to comply with the
    // `getSuccessorRegions` interface definition that requires a single
    // attribute per operand.
    SmallVector<Attribute, 2> operandAttributes(
        regionInterface.getOperation()->getNumOperands());

    // Extract all entry regions and wire all initial entry successor inputs.
    SmallVector<RegionSuccessor, 2> entrySuccessors;
    regionInterface.getSuccessorRegions(/*index=*/llvm::None,
                                        operandAttributes, entrySuccessors);
    for (RegionSuccessor &entrySuccessor : entrySuccessors) {
      // Wire the entry region's successor arguments with the initial
      // successor inputs.
      assert(entrySuccessor.getSuccessor() &&
             "Invalid entry region without an attached successor region");
      registerAliases(regionInterface.getSuccessorEntryOperands(
                          entrySuccessor.getSuccessor()->getRegionNumber()),
                      entrySuccessor.getSuccessorInputs());
    }

    // Wire flow between regions and from region exits.
    for (Region &region : regionInterface.getOperation()->getRegions()) {
      // Iterate over all successor region entries that are reachable from the
      // current region.
      SmallVector<RegionSuccessor, 2> successorRegions;
      regionInterface.getSuccessorRegions(region.getRegionNumber(),
                                          operandAttributes, successorRegions);
      for (RegionSuccessor &successorRegion : successorRegions) {
        // Iterate over all immediate terminator operations and wire the
        // successor inputs with the operands of each terminator.
        walkReturnOperations(&region, [&](Operation *terminator) {
          registerAliases(terminator->getOperands(),
                          successorRegion.getSuccessorInputs());
        });
      }
    }
  });
}
/// Maps values to all immediate aliases this value can have.
@ -235,14 +281,24 @@ private:
/// Computes the initial placement block for the given allocation result.
/// Without operands, the placement is the common dominator of all aliases;
/// with operands, it is the common post dominator of the operand
/// dependencies. In both cases the result is clamped to the allocation's own
/// parent region so allocs are never moved out of nested regions.
/// Note: the rendered diff left pre-refactor lines interleaved here
/// (duplicate `operands` declaration and a stale early return); the coherent
/// post-change version is reconstructed below.
Block *getInitialAllocBlock(OpResult result) {
  // Get all allocation operands as these operands are important for the
  // allocation operation.
  Operation *owner = result.getOwner();
  auto operands = owner->getOperands();
  Block *dominator;
  if (operands.size() < 1)
    dominator =
        findCommonDominator(result, aliases.resolve(result), dominators);
  else {
    // If this node has dependencies, check all dependent nodes with respect
    // to a common post dominator in which all values are available.
    ValueSetT dependencies(++operands.begin(), operands.end());
    dominator =
        findCommonDominator(*operands.begin(), dependencies, postDominators);
  }

  // Do not move allocs out of their parent regions to keep them local.
  if (dominator->getParent() != owner->getParentRegion())
    return &owner->getParentRegion()->front();
  return dominator;
}
/// Finds correct alloc positions according to the algorithm described at
@ -273,12 +329,12 @@ private:
/// Introduces required allocs and copy operations to avoid memory leaks.
void introduceCopies() {
// Initialize the set of block arguments that require a dedicated memory
// free operation since their arguments cannot be safely deallocated in a
// post dominator.
SmallPtrSet<BlockArgument, 8> blockArgsToFree;
llvm::SmallDenseSet<std::tuple<BlockArgument, Block *>> visitedBlockArgs;
SmallVector<std::tuple<BlockArgument, Block *>, 8> toProcess;
// Initialize the set of values that require a dedicated memory free
// operation since their operands cannot be safely deallocated in a post
// dominator.
SmallPtrSet<Value, 8> valuesToFree;
llvm::SmallDenseSet<std::tuple<Value, Block *>> visitedValues;
SmallVector<std::tuple<Value, Block *>, 8> toProcess;
// Check dominance relation for proper dominance properties. If the given
// value node does not dominate an alias, we will have to create a copy in
@ -289,17 +345,15 @@ private:
if (it == aliases.end())
return;
for (Value value : it->second) {
auto blockArg = value.cast<BlockArgument>();
if (blockArgsToFree.count(blockArg) > 0)
if (valuesToFree.count(value) > 0)
continue;
// Check whether we have to free this particular block argument.
if (!dominators.dominates(definingBlock, blockArg.getOwner())) {
toProcess.emplace_back(blockArg, blockArg.getParentBlock());
blockArgsToFree.insert(blockArg);
} else if (visitedBlockArgs
.insert(std::make_tuple(blockArg, definingBlock))
if (!dominators.dominates(definingBlock, value.getParentBlock())) {
toProcess.emplace_back(value, value.getParentBlock());
valuesToFree.insert(value);
} else if (visitedValues.insert(std::make_tuple(value, definingBlock))
.second)
toProcess.emplace_back(blockArg, definingBlock);
toProcess.emplace_back(value, definingBlock);
}
};
@ -316,60 +370,168 @@ private:
// Update buffer aliases to ensure that we free all buffers and block
// arguments at the correct locations.
aliases.remove(blockArgsToFree);
aliases.remove(valuesToFree);
// Add new allocs and additional copy operations.
for (BlockArgument blockArg : blockArgsToFree) {
Block *block = blockArg.getOwner();
for (Value value : valuesToFree) {
if (auto blockArg = value.dyn_cast<BlockArgument>())
introduceBlockArgCopy(blockArg);
else
introduceValueCopyForRegionResult(value);
// Allocate a buffer for the current block argument in the block of
// the associated value (which will be a predecessor block by
// definition).
for (auto it = block->pred_begin(), e = block->pred_end(); it != e;
++it) {
// Get the terminator and the value that will be passed to our
// argument.
Operation *terminator = (*it)->getTerminator();
auto branchInterface = cast<BranchOpInterface>(terminator);
// Convert the mutable operand range to an immutable range and query the
// associated source value.
Value sourceValue =
branchInterface.getSuccessorOperands(it.getSuccessorIndex())
.getValue()[blockArg.getArgNumber()];
// Create a new alloc at the current location of the terminator.
auto memRefType = sourceValue.getType().cast<MemRefType>();
OpBuilder builder(terminator);
// Register the value to require a final dealloc. Note that we do not have
// to assign a block here since we do not want to move the allocation node
// to another location.
allocs.push_back({value, nullptr, nullptr});
}
}
// Extract information about dynamically shaped types by
// extracting their dynamic dimensions.
SmallVector<Value, 4> dynamicOperands;
for (auto shapeElement : llvm::enumerate(memRefType.getShape())) {
if (!ShapedType::isDynamic(shapeElement.value()))
continue;
dynamicOperands.push_back(builder.create<DimOp>(
terminator->getLoc(), sourceValue, shapeElement.index()));
}
// TODO: provide a generic interface to create dialect-specific
// Alloc and CopyOp nodes.
auto alloc = builder.create<AllocOp>(terminator->getLoc(), memRefType,
dynamicOperands);
// Wire new alloc and successor operand.
branchInterface.getMutableSuccessorOperands(it.getSuccessorIndex())
.getValue()
/// Introduces temporary allocs in all predecessors and copies the source
/// values into the newly allocated buffers.
/// Note: the rendered diff leaked removed pre-refactor lines into this body
/// (uses of an undefined `builder` and a stale `allocs.push_back`); the
/// coherent post-change version is reconstructed below.
void introduceBlockArgCopy(BlockArgument blockArg) {
  // Allocate a buffer for the current block argument in the block of
  // the associated value (which will be a predecessor block by
  // definition).
  Block *block = blockArg.getOwner();
  for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) {
    // Get the terminator and the value that will be passed to our
    // argument.
    Operation *terminator = (*it)->getTerminator();
    auto branchInterface = cast<BranchOpInterface>(terminator);
    // Query the associated source value.
    Value sourceValue =
        branchInterface.getSuccessorOperands(it.getSuccessorIndex())
            .getValue()[blockArg.getArgNumber()];
    // Create a new alloc and copy at the current location of the terminator.
    Value alloc = introduceBufferCopy(sourceValue, terminator);
    // Wire new alloc and successor operand.
    auto mutableOperands =
        branchInterface.getMutableSuccessorOperands(it.getSuccessorIndex());
    if (!mutableOperands.hasValue())
      terminator->emitError() << "terminators with immutable successor "
                                 "operands are not supported";
    else
      mutableOperands.getValue()
          .slice(blockArg.getArgNumber(), 1)
          .assign(alloc);
  }

  // Check whether the block argument has implicitly defined predecessors via
  // the RegionBranchOpInterface. This can be the case if the current block
  // argument belongs to the first block in a region and the parent operation
  // implements the RegionBranchOpInterface.
  Region *argRegion = block->getParent();
  RegionBranchOpInterface regionInterface;
  if (!argRegion || &argRegion->front() != block ||
      !(regionInterface =
            dyn_cast<RegionBranchOpInterface>(argRegion->getParentOp())))
    return;

  introduceCopiesForRegionSuccessors(
      regionInterface, argRegion->getParentOp()->getRegions(),
      [&](RegionSuccessor &successorRegion) {
        // Find a predecessor of our argRegion.
        return successorRegion.getSuccessor() == argRegion;
      },
      [&](RegionSuccessor &successorRegion) {
        // The operand index will be the argument number.
        return blockArg.getArgNumber();
      });
}
/// Introduces temporary allocs in front of all associated nested-region
/// terminators and copies the source values into the newly allocated buffers.
void introduceValueCopyForRegionResult(Value value) {
  // A non-block-argument value to free is produced by a region-branching
  // operation whose nested regions yield the value.
  Operation *definingOp = value.getDefiningOp();
  auto regionInterface = cast<RegionBranchOpInterface>(definingOp);
  // Match only successor entries that leave the nested regions by returning
  // to the parent operation (represented by a null successor region).
  auto exitPredicate = [&](RegionSuccessor &successorRegion) {
    return !successorRegion.getSuccessor();
  };
  // The terminator operand to copy is the one bound to `value` in the
  // successor inputs.
  auto operandIndexProvider = [&](RegionSuccessor &successorRegion) {
    return llvm::find(successorRegion.getSuccessorInputs(), value).getIndex();
  };
  introduceCopiesForRegionSuccessors(regionInterface, definingOp->getRegions(),
                                     exitPredicate, operandIndexProvider);
}
/// Introduces buffer copies for all terminators in the given regions. The
/// `regionPredicate` selects the successor regions of interest; for each
/// matching successor, `operandProvider` yields the terminator operand index
/// that has to receive a copied buffer.
template <typename TPredicate, typename TOperandProvider>
void
introduceCopiesForRegionSuccessors(RegionBranchOpInterface regionInterface,
                                   MutableArrayRef<Region> regions,
                                   const TPredicate &regionPredicate,
                                   const TOperandProvider &operandProvider) {
  // `getSuccessorRegions` requires one attribute per operand; supply empty
  // attributes since the concrete operand values are irrelevant here.
  SmallVector<Attribute, 2> emptyOperandAttributes(
      regionInterface.getOperation()->getNumOperands());
  for (Region &currentRegion : regions) {
    // Gather all successor regions reachable from the current region.
    SmallVector<RegionSuccessor, 2> reachableSuccessors;
    regionInterface.getSuccessorRegions(currentRegion.getRegionNumber(),
                                        emptyOperandAttributes,
                                        reachableSuccessors);
    // Restrict the handling to the first successor matching the predicate.
    RegionSuccessor *match =
        llvm::find_if(reachableSuccessors, regionPredicate);
    if (match == reachableSuccessors.end())
      continue;
    // Resolve the operand index in the context of the matching successor's
    // input bindings.
    auto operandIndex = operandProvider(*match);
    // Rewire every immediate return-like terminator: allocate a fresh
    // buffer, copy the old contents into it, and let the terminator yield
    // the copy instead of the original value.
    walkReturnOperations(&currentRegion, [&](Operation *terminator) {
      Value sourceValue = terminator->getOperand(operandIndex);
      Value alloc = introduceBufferCopy(sourceValue, terminator);
      terminator->setOperand(operandIndex, alloc);
    });
  }
}
/// Creates a new memory allocation for the given source value and copies
/// its content into the newly allocated buffer. The terminator operation is
/// used as the insertion point for the alloc and copy operations.
Value introduceBufferCopy(Value sourceValue, Operation *terminator) {
  auto memRefType = sourceValue.getType().cast<MemRefType>();
  OpBuilder builder(terminator);
  // Dynamically shaped memrefs need their dynamic dimension sizes as alloc
  // operands; materialize them via `dim` operations on the source value.
  SmallVector<Value, 4> dynamicOperands;
  for (auto shapeElement : llvm::enumerate(memRefType.getShape())) {
    if (ShapedType::isDynamic(shapeElement.value()))
      dynamicOperands.push_back(builder.create<DimOp>(
          terminator->getLoc(), sourceValue, shapeElement.index()));
  }
  // TODO: provide a generic interface to create dialect-specific
  // Alloc and CopyOp nodes.
  Value alloc = builder.create<AllocOp>(terminator->getLoc(), memRefType,
                                        dynamicOperands);
  // Copy the contents of the old allocation into the new one.
  builder.create<linalg::CopyOp>(terminator->getLoc(), sourceValue, alloc);
  return alloc;
}
/// Finds associated deallocs that can be linked to our allocation nodes (if
@ -440,8 +602,8 @@ private:
if (entry.deallocOperation) {
entry.deallocOperation->moveAfter(endOperation);
} else {
// If the Dealloc position is at the terminator operation of the block,
// then the value should escape from a deallocation.
// If the Dealloc position is at the terminator operation of the
// block, then the value should escape from a deallocation.
Operation *nextOp = endOperation->getNextNode();
if (!nextOp)
continue;

View File

@ -716,3 +716,201 @@ func @memref_in_function_results(%arg0: memref<5xf32>, %arg1: memref<10xf32>, %a
// CHECK: dealloc %[[Y]]
// CHECK: return %[[ARG1]], %[[X]]
// -----
// Test Case: nested region control flow
// The alloc position of %1 does not need to be changed and flows through
// both if branches until it is finally returned. Hence, it does not
// require a specific dealloc operation. However, %3 requires a dealloc.
// CHECK-LABEL: func @nested_region_control_flow
func @nested_region_control_flow(
%arg0 : index,
%arg1 : index) -> memref<?x?xf32> {
%0 = cmpi "eq", %arg0, %arg1 : index
// %1 flows through both branches and escapes via the return below.
%1 = alloc(%arg0, %arg0) : memref<?x?xf32>
%2 = scf.if %0 -> (memref<?x?xf32>) {
scf.yield %1 : memref<?x?xf32>
} else {
// %3 is never yielded, so it must be deallocated inside this branch.
%3 = alloc(%arg0, %arg1) : memref<?x?xf32>
scf.yield %1 : memref<?x?xf32>
}
return %2 : memref<?x?xf32>
}
// CHECK: %[[ALLOC0:.*]] = alloc(%arg0, %arg0)
// CHECK-NEXT: %[[ALLOC1:.*]] = scf.if
// CHECK: scf.yield %[[ALLOC0]]
// CHECK: %[[ALLOC2:.*]] = alloc(%arg0, %arg1)
// CHECK-NEXT: dealloc %[[ALLOC2]]
// CHECK-NEXT: scf.yield %[[ALLOC0]]
// CHECK: return %[[ALLOC1]]
// -----
// Test Case: nested region control flow with a nested buffer allocation in a
// divergent branch.
// The alloc positions of %1, %3 do not need to be changed since
// BufferPlacement does not move allocs out of nested regions at the moment.
// However, since %3 is allocated and "returned" in a divergent branch, we have
// to allocate a temporary buffer (like in condBranchDynamicTypeNested).
// CHECK-LABEL: func @nested_region_control_flow_div
func @nested_region_control_flow_div(
%arg0 : index,
%arg1 : index) -> memref<?x?xf32> {
%0 = cmpi "eq", %arg0, %arg1 : index
%1 = alloc(%arg0, %arg0) : memref<?x?xf32>
%2 = scf.if %0 -> (memref<?x?xf32>) {
scf.yield %1 : memref<?x?xf32>
} else {
// %3 is yielded only on this branch, so buffer placement has to copy it
// into a temporary buffer before deallocating it here.
%3 = alloc(%arg0, %arg1) : memref<?x?xf32>
scf.yield %3 : memref<?x?xf32>
}
return %2 : memref<?x?xf32>
}
// CHECK: %[[ALLOC0:.*]] = alloc(%arg0, %arg0)
// CHECK-NEXT: %[[ALLOC1:.*]] = scf.if
// CHECK: %[[ALLOC2:.*]] = alloc
// CHECK-NEXT: linalg.copy(%[[ALLOC0]], %[[ALLOC2]])
// CHECK: scf.yield %[[ALLOC2]]
// CHECK: %[[ALLOC3:.*]] = alloc(%arg0, %arg1)
// CHECK: %[[ALLOC4:.*]] = alloc
// CHECK-NEXT: linalg.copy(%[[ALLOC3]], %[[ALLOC4]])
// CHECK: dealloc %[[ALLOC3]]
// CHECK: scf.yield %[[ALLOC4]]
// CHECK: dealloc %[[ALLOC0]]
// CHECK-NEXT: return %[[ALLOC1]]
// -----
// Test Case: deeply nested region control flow with a nested buffer allocation
// in a divergent branch.
// The alloc positions of %1, %4 and %5 do not need to be changed since
// BufferPlacement does not move allocs out of nested regions at the moment.
// However, since %4 is allocated and "returned" in a divergent branch, we have
// to allocate several temporary buffers (like in condBranchDynamicTypeNested).
// CHECK-LABEL: func @nested_region_control_flow_div_nested
func @nested_region_control_flow_div_nested(
%arg0 : index,
%arg1 : index) -> memref<?x?xf32> {
%0 = cmpi "eq", %arg0, %arg1 : index
%1 = alloc(%arg0, %arg0) : memref<?x?xf32>
%2 = scf.if %0 -> (memref<?x?xf32>) {
// Two nesting levels: divergent yields at the inner and outer if require
// copies at each level (see the CHECK lines below).
%3 = scf.if %0 -> (memref<?x?xf32>) {
scf.yield %1 : memref<?x?xf32>
} else {
%4 = alloc(%arg0, %arg1) : memref<?x?xf32>
scf.yield %4 : memref<?x?xf32>
}
scf.yield %3 : memref<?x?xf32>
} else {
%5 = alloc(%arg1, %arg1) : memref<?x?xf32>
scf.yield %5 : memref<?x?xf32>
}
return %2 : memref<?x?xf32>
}
// CHECK: %[[ALLOC0:.*]] = alloc(%arg0, %arg0)
// CHECK-NEXT: %[[ALLOC1:.*]] = scf.if
// CHECK-NEXT: %[[ALLOC2:.*]] = scf.if
// CHECK: %[[ALLOC3:.*]] = alloc
// CHECK-NEXT: linalg.copy(%[[ALLOC0]], %[[ALLOC3]])
// CHECK: scf.yield %[[ALLOC3]]
// CHECK: %[[ALLOC4:.*]] = alloc(%arg0, %arg1)
// CHECK: %[[ALLOC5:.*]] = alloc
// CHECK-NEXT: linalg.copy(%[[ALLOC4]], %[[ALLOC5]])
// CHECK: dealloc %[[ALLOC4]]
// CHECK: scf.yield %[[ALLOC5]]
// CHECK: %[[ALLOC6:.*]] = alloc
// CHECK-NEXT: linalg.copy(%[[ALLOC2]], %[[ALLOC6]])
// CHECK: dealloc %[[ALLOC2]]
// CHECK: scf.yield %[[ALLOC6]]
// CHECK: %[[ALLOC7:.*]] = alloc(%arg1, %arg1)
// CHECK: %[[ALLOC8:.*]] = alloc
// CHECK-NEXT: linalg.copy(%[[ALLOC7]], %[[ALLOC8]])
// CHECK: dealloc %[[ALLOC7]]
// CHECK: scf.yield %[[ALLOC8]]
// CHECK: dealloc %[[ALLOC0]]
// CHECK-NEXT: return %[[ALLOC1]]
// -----
// Test Case: nested region control flow within a region interface.
// The alloc position of %0 does not need to be changed and no copies are
// required in this case since the allocation finally escapes the method.
// CHECK-LABEL: func @inner_region_control_flow
func @inner_region_control_flow(%arg0 : index) -> memref<?x?xf32> {
%0 = alloc(%arg0, %arg0) : memref<?x?xf32>
// %0 is threaded unchanged through all three regions of test.region_if,
// so no copies or deallocs are required (the buffer escapes via return).
%1 = test.region_if %0 : memref<?x?xf32> -> (memref<?x?xf32>) then {
^bb0(%arg1 : memref<?x?xf32>):
test.region_if_yield %arg1 : memref<?x?xf32>
} else {
^bb0(%arg1 : memref<?x?xf32>):
test.region_if_yield %arg1 : memref<?x?xf32>
} join {
^bb0(%arg1 : memref<?x?xf32>):
test.region_if_yield %arg1 : memref<?x?xf32>
}
return %1 : memref<?x?xf32>
}
// CHECK: %[[ALLOC0:.*]] = alloc(%arg0, %arg0)
// CHECK-NEXT: %[[ALLOC1:.*]] = test.region_if
// CHECK-NEXT: ^bb0(%[[ALLOC2:.*]]:{{.*}}):
// CHECK-NEXT: test.region_if_yield %[[ALLOC2]]
// CHECK: ^bb0(%[[ALLOC3:.*]]:{{.*}}):
// CHECK-NEXT: test.region_if_yield %[[ALLOC3]]
// CHECK: ^bb0(%[[ALLOC4:.*]]:{{.*}}):
// CHECK-NEXT: test.region_if_yield %[[ALLOC4]]
// CHECK: return %[[ALLOC1]]
// -----
// Test Case: nested region control flow within a region interface including an
// allocation in a divergent branch.
// The alloc positions of %1 and %2 do not need to be changed since
// BufferPlacement does not move allocs out of nested regions at the moment.
// However, since %2 is allocated and yielded in a divergent branch, we have
// to allocate several temporary buffers (like in condBranchDynamicTypeNested).
// CHECK-LABEL: func @inner_region_control_flow_div
func @inner_region_control_flow_div(
%arg0 : index,
%arg1 : index) -> memref<?x?xf32> {
%0 = alloc(%arg0, %arg0) : memref<?x?xf32>
%1 = test.region_if %0 : memref<?x?xf32> -> (memref<?x?xf32>) then {
^bb0(%arg2 : memref<?x?xf32>):
test.region_if_yield %arg2 : memref<?x?xf32>
} else {
^bb0(%arg2 : memref<?x?xf32>):
// %2 is yielded only on this branch; a temporary copy is required before
// it can be deallocated locally (see the CHECK lines below).
%2 = alloc(%arg0, %arg1) : memref<?x?xf32>
test.region_if_yield %2 : memref<?x?xf32>
} join {
^bb0(%arg2 : memref<?x?xf32>):
test.region_if_yield %arg2 : memref<?x?xf32>
}
return %1 : memref<?x?xf32>
}
// CHECK: %[[ALLOC0:.*]] = alloc(%arg0, %arg0)
// CHECK-NEXT: %[[ALLOC1:.*]] = test.region_if
// CHECK-NEXT: ^bb0(%[[ALLOC2:.*]]:{{.*}}):
// CHECK: %[[ALLOC3:.*]] = alloc
// CHECK-NEXT: linalg.copy(%[[ALLOC2]], %[[ALLOC3]])
// CHECK-NEXT: test.region_if_yield %[[ALLOC3]]
// CHECK: ^bb0(%[[ALLOC4:.*]]:{{.*}}):
// CHECK: %[[ALLOC5:.*]] = alloc
// CHECK: %[[ALLOC6:.*]] = alloc
// CHECK-NEXT: linalg.copy(%[[ALLOC5]], %[[ALLOC6]])
// CHECK-NEXT: dealloc %[[ALLOC5]]
// CHECK-NEXT: test.region_if_yield %[[ALLOC6]]
// CHECK: ^bb0(%[[ALLOC7:.*]]:{{.*}}):
// CHECK: %[[ALLOC8:.*]] = alloc
// CHECK-NEXT: linalg.copy(%[[ALLOC7]], %[[ALLOC8]])
// CHECK-NEXT: dealloc %[[ALLOC7]]
// CHECK-NEXT: test.region_if_yield %[[ALLOC8]]
// CHECK: dealloc %[[ALLOC0]]
// CHECK-NEXT: return %[[ALLOC1]]

View File

@ -518,6 +518,77 @@ void StringAttrPrettyNameOp::getAsmResultNames(
setNameFn(getResult(i), str.getValue());
}
//===----------------------------------------------------------------------===//
// RegionIfOp
//===----------------------------------------------------------------------===//
/// Prints a `test.region_if` operation in the form:
///   test.region_if <operands> : <operand-types> -> <result-types>
///       then <region> else <region> join <region>
static void print(OpAsmPrinter &p, RegionIfOp op) {
// Operation name, operands, and the operand/result type signature.
p << RegionIfOp::getOperationName() << " ";
p.printOperands(op.getOperands());
p << ": " << op.getOperandTypes();
p.printArrowTypeList(op.getResultTypes());
// The three regions, each introduced by its keyword; block arguments and
// terminators are printed explicitly so the parser can round-trip them.
p << " then";
p.printRegion(op.thenRegion(),
/*printEntryBlockArgs=*/true,
/*printBlockTerminators=*/true);
p << " else";
p.printRegion(op.elseRegion(),
/*printEntryBlockArgs=*/true,
/*printBlockTerminators=*/true);
p << " join";
p.printRegion(op.joinRegion(),
/*printEntryBlockArgs=*/true,
/*printBlockTerminators=*/true);
}
/// Parses a `test.region_if` operation: an operand list, a colon type list,
/// an arrow result-type list, and three regions introduced by the keywords
/// `then`, `else`, and `join` (mirroring the custom printer above).
static ParseResult parseRegionIfOp(OpAsmParser &parser,
OperationState &result) {
SmallVector<OpAsmParser::OperandType, 2> operandInfos;
SmallVector<Type, 2> operandTypes;
// The op always carries exactly three regions: then, else, join.
result.regions.reserve(3);
Region *thenRegion = result.addRegion();
Region *elseRegion = result.addRegion();
Region *joinRegion = result.addRegion();
// Parse operand, type and arrow type lists.
if (parser.parseOperandList(operandInfos) ||
parser.parseColonTypeList(operandTypes) ||
parser.parseArrowTypeList(result.types))
return failure();
// Parse all attached regions.
if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
return failure();
// Resolve the parsed operand names against the parsed operand types.
return parser.resolveOperands(operandInfos, operandTypes,
parser.getCurrentLocation(), result.operands);
}
/// Returns the operands forwarded to an entry successor region. Only the
/// then (0) and else (1) regions are entry regions; both receive all
/// operands of the op unchanged.
OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) {
  assert((index == 0 || index == 1) && "invalid region index");
  return getOperands();
}
/// Implements the RegionBranchOpInterface successor query: branching from
/// the parent op enters the then/else regions; both of those branch to the
/// join region, which in turn returns to the parent op.
void RegionIfOp::getSuccessorRegions(
    Optional<unsigned> index, ArrayRef<Attribute> operands,
    SmallVectorImpl<RegionSuccessor> &regions) {
  // The then and else regions are the entry regions of this op.
  if (!index.hasValue()) {
    regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs()));
    regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs()));
    return;
  }
  // Then (0) and else (1) always branch to the join region; the join region
  // (2) returns to the parent operation's results.
  if (index.getValue() < 2)
    regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs()));
  else
    regions.push_back(RegionSuccessor(getResults()));
}
//===----------------------------------------------------------------------===//
// Dialect Registration
//===----------------------------------------------------------------------===//

View File

@ -1349,4 +1349,47 @@ def SideEffectOp : TEST_Op<"side_effect_op",
let results = (outs AnyType:$result);
}
//===----------------------------------------------------------------------===//
// Test RegionBranchOpInterface
//===----------------------------------------------------------------------===//
// Return-like terminator for the regions of `test.region_if`; forwards its
// operands either to the next region or to the parent op's results.
def RegionIfYieldOp : TEST_Op<"region_if_yield",
[NoSideEffect, ReturnLike, Terminator]> {
let arguments = (ins Variadic<AnyType>:$results);
let assemblyFormat = [{
$results `:` type($results) attr-dict
}];
}
// Test op exercising the RegionBranchOpInterface with three regions.
def RegionIfOp : TEST_Op<"region_if",
[DeclareOpInterfaceMethods<RegionBranchOpInterface>,
SingleBlockImplicitTerminator<"RegionIfYieldOp">,
RecursiveSideEffects]> {
let description =[{
Represents an abstract if-then-else-join pattern. In this context, the then
and else regions jump to the join region, which finally returns to its
parent op.
}];
// Custom printer/parser defined in TestDialect.cpp.
let printer = [{ return ::print(p, *this); }];
let parser = [{ return ::parseRegionIfOp(parser, result); }];
let arguments = (ins Variadic<AnyType>);
let results = (outs Variadic<AnyType>:$results);
// Regions at indices 0 (then), 1 (else), and 2 (join).
let regions = (region SizedRegion<1>:$thenRegion,
AnyRegion:$elseRegion,
AnyRegion:$joinRegion);
// Accessors for each region's entry block arguments, plus the
// RegionBranchOpInterface entry-operand hook implemented out of line.
let extraClassDeclaration = [{
Block::BlockArgListType getThenArgs() {
return getBody(0)->getArguments();
}
Block::BlockArgListType getElseArgs() {
return getBody(1)->getArguments();
}
Block::BlockArgListType getJoinArgs() {
return getBody(2)->getArguments();
}
OperandRange getSuccessorEntryOperands(unsigned index);
}];
}
#endif // TEST_OPS