[mlir][spirv] Add support for lowering scf.for scf/if with return value

This allows the lowering to support scf.for and scf.if with results. Since
spv region operations currently have no return values, the results are demoted
to Function memory: we create one allocation per result right before the
region and store the yielded values into it. We can then load the values back
from the allocations to use the results.

Differential Revision: https://reviews.llvm.org/D82246
This commit is contained in:
Thomas Raoux 2020-07-01 17:08:08 -07:00
parent fbce9855e9
commit 0670f855a7
6 changed files with 245 additions and 30 deletions

View File

@ -21,11 +21,23 @@ class Pass;
// Owning list of rewriting patterns.
class OwningRewritePatternList;
class SPIRVTypeConverter;
// Forward declaration of the pimpl'd state; defined in the implementation.
struct ScfToSPIRVContextImpl;
/// Carries state shared across the SCF-to-SPIR-V lowering patterns (the
/// spv.Variable ops created to hold the results of lowered SCF regions).
/// Owns its impl via unique_ptr so implementation details stay out of this
/// header; the out-of-line ctor/dtor keep the incomplete type legal here.
struct ScfToSPIRVContext {
ScfToSPIRVContext();
~ScfToSPIRVContext();
// Accessor used when populating the conversion patterns; the impl pointer is
// handed to each pattern, which does not take ownership.
ScfToSPIRVContextImpl *getImpl() { return impl.get(); }
private:
std::unique_ptr<ScfToSPIRVContextImpl> impl;
};
/// Collects a set of patterns to lower from scf.for, scf.if, and
/// scf.yield to CFG operations within the SPIR-V dialect.
void populateSCFToSPIRVPatterns(MLIRContext *context,
SPIRVTypeConverter &typeConverter,
ScfToSPIRVContext &scfToSPIRVContext,
OwningRewritePatternList &patterns);
} // namespace mlir

View File

@ -58,9 +58,10 @@ void GPUToSPIRVPass::runOnOperation() {
spirv::SPIRVConversionTarget::get(targetAttr);
SPIRVTypeConverter typeConverter(targetAttr);
ScfToSPIRVContext scfContext;
OwningRewritePatternList patterns;
populateGPUToSPIRVPatterns(context, typeConverter, patterns);
populateSCFToSPIRVPatterns(context, typeConverter, patterns);
// Register the SCF lowering patterns, threading through the shared context.
populateSCFToSPIRVPatterns(context, typeConverter, scfContext, patterns);
populateStandardToSPIRVPatterns(context, typeConverter, patterns);
if (failed(applyFullConversion(kernelModules, *target, patterns)))

View File

@ -18,12 +18,44 @@
using namespace mlir;
namespace mlir {
struct ScfToSPIRVContextImpl {
// Maps a SPIR-V region control-flow operation (spv.loop or spv.selection)
// to the VariableOps created to store the region's results. The order of
// the VariableOps matches the order of the results.
DenseMap<Operation *, SmallVector<spirv::VariableOp, 8>> outputVars;
};
} // namespace mlir
/// ScfToSPIRVContext remembers information produced while lowering an scf
/// region that must be consumed later on: lowering scf.for/scf.if creates
/// VariableOps for the results, and the stores into those variables can only
/// be emitted when the corresponding scf.yield is lowered, since the yield
/// operands may use a different type by then.
ScfToSPIRVContext::ScfToSPIRVContext()
    : impl(std::make_unique<ScfToSPIRVContextImpl>()) {}

/// Defined out of line so ~unique_ptr sees the complete impl type.
ScfToSPIRVContext::~ScfToSPIRVContext() = default;
namespace {
/// Common base class for all SCF-to-SPIR-V conversion patterns. In addition
/// to the usual type converter it holds a pointer to the lowering state
/// shared between patterns (the result VariableOps).
template <typename OpTy>
class SCFToSPIRVPattern : public SPIRVOpLowering<OpTy> {
public:
SCFToSPIRVPattern<OpTy>(MLIRContext *context, SPIRVTypeConverter &converter,
ScfToSPIRVContextImpl *scfToSPIRVContext)
: SPIRVOpLowering<OpTy>::SPIRVOpLowering(context, converter),
scfToSPIRVContext(scfToSPIRVContext) {}
protected:
// Shared lowering state mapping region ops to their result variables;
// owned by the ScfToSPIRVContext, not by the pattern.
ScfToSPIRVContextImpl *scfToSPIRVContext;
};
/// Pattern to convert a scf::ForOp within kernel functions into spirv::LoopOp.
class ForOpConversion final : public SPIRVOpLowering<scf::ForOp> {
class ForOpConversion final : public SCFToSPIRVPattern<scf::ForOp> {
public:
using SPIRVOpLowering<scf::ForOp>::SPIRVOpLowering;
using SCFToSPIRVPattern<scf::ForOp>::SCFToSPIRVPattern;
LogicalResult
matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
@ -32,29 +64,54 @@ public:
/// Pattern to convert a scf::IfOp within kernel functions into
/// spirv::SelectionOp.
class IfOpConversion final : public SCFToSPIRVPattern<scf::IfOp> {
public:
  using SCFToSPIRVPattern<scf::IfOp>::SCFToSPIRVPattern;

  LogicalResult
  matchAndRewrite(scf::IfOp ifOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
/// Pattern to convert a scf::YieldOp. Out-of-line because the lowering must
/// store the yielded values into the variables created for the parent
/// region's results (and, for loops, update the back-edge branch).
class TerminatorOpConversion final : public SCFToSPIRVPattern<scf::YieldOp> {
public:
  using SCFToSPIRVPattern<scf::YieldOp>::SCFToSPIRVPattern;

  LogicalResult
  matchAndRewrite(scf::YieldOp terminatorOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace
/// Helper function to replace an SCF op's results with loads from SPIR-V
/// variables. We create one VariableOp per result of the control-flow region
/// because spv.loop/spv.selection currently cannot yield values; right after
/// the region op we load each value back from its allocation and use it as
/// the corresponding SCF op result.
template <typename ScfOp, typename OpTy>
static void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp,
SPIRVTypeConverter &typeConverter,
ConversionPatternRewriter &rewriter,
ScfToSPIRVContextImpl *scfToSPIRVContext) {
Location loc = scfOp.getLoc();
// Record the variables so the scf.yield lowering can store into them later.
auto &allocas = scfToSPIRVContext->outputVars[newOp];
SmallVector<Value, 8> resultValue;
for (Value result : scfOp.results()) {
auto convertedType = typeConverter.convertType(result.getType());
auto pointerType =
spirv::PointerType::get(convertedType, spirv::StorageClass::Function);
// The variable must be created before the region op so it dominates both
// the stores inside the region and the loads after it.
rewriter.setInsertionPoint(newOp);
auto alloc = rewriter.create<spirv::VariableOp>(
loc, pointerType, spirv::StorageClass::Function,
/*initializer=*/nullptr);
allocas.push_back(alloc);
// Load the result back immediately after the region op.
rewriter.setInsertionPointAfter(newOp);
Value loadResult = rewriter.create<spirv::LoadOp>(loc, alloc);
resultValue.push_back(loadResult);
}
rewriter.replaceOp(scfOp, resultValue);
}
//===----------------------------------------------------------------------===//
// scf::ForOp.
//===----------------------------------------------------------------------===//
@ -83,6 +140,8 @@ ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
// Create the new induction variable to use.
BlockArgument newIndVar =
header->addArgument(forOperands.lowerBound().getType());
for (Value arg : forOperands.initArgs())
header->addArgument(arg.getType());
Block *body = forOp.getBody();
// Apply signature conversion to the body of the forOp. It has a single block,
@ -91,29 +150,28 @@ ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
TypeConverter::SignatureConversion signatureConverter(
body->getNumArguments());
signatureConverter.remapInput(0, newIndVar);
FailureOr<Block *> newBody = rewriter.convertRegionTypes(
&forOp.getLoopBody(), typeConverter, &signatureConverter);
if (failed(newBody))
return failure();
body = *newBody;
// Delete the loop terminator.
rewriter.eraseOp(body->getTerminator());
for (unsigned i = 1, e = body->getNumArguments(); i < e; i++)
signatureConverter.remapInput(i, header->getArgument(i));
body = rewriter.applySignatureConversion(&forOp.getLoopBody(),
signatureConverter);
// Move the blocks from the forOp into the loopOp. This is the body of the
// loopOp.
rewriter.inlineRegionBefore(forOp.getOperation()->getRegion(0), loopOp.body(),
std::next(loopOp.body().begin(), 2));
SmallVector<Value, 8> args(1, forOperands.lowerBound());
args.append(forOperands.initArgs().begin(), forOperands.initArgs().end());
// Branch into it from the entry.
rewriter.setInsertionPointToEnd(&(loopOp.body().front()));
rewriter.create<spirv::BranchOp>(loc, header, forOperands.lowerBound());
rewriter.create<spirv::BranchOp>(loc, header, args);
// Generate the rest of the loop header.
rewriter.setInsertionPointToEnd(header);
auto *mergeBlock = loopOp.getMergeBlock();
auto cmpOp = rewriter.create<spirv::SLessThanOp>(
loc, rewriter.getI1Type(), newIndVar, forOperands.upperBound());
rewriter.create<spirv::BranchConditionalOp>(
loc, cmpOp, body, ArrayRef<Value>(), mergeBlock, ArrayRef<Value>());
@ -127,7 +185,8 @@ ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
loc, newIndVar.getType(), newIndVar, forOperands.step());
rewriter.create<spirv::BranchOp>(loc, header, updatedIndVar);
rewriter.eraseOp(forOp);
replaceSCFOutputValue(forOp, loopOp, typeConverter, rewriter,
scfToSPIRVContext);
return success();
}
@ -179,13 +238,45 @@ IfOpConversion::matchAndRewrite(scf::IfOp ifOp, ArrayRef<Value> operands,
thenBlock, ArrayRef<Value>(),
elseBlock, ArrayRef<Value>());
rewriter.eraseOp(ifOp);
replaceSCFOutputValue(ifOp, selectionOp, typeConverter, rewriter,
scfToSPIRVContext);
return success();
}
/// Yield is lowered to stores into the VariableOps created during lowering of
/// the parent region. For loops we also need to update the branch looping
/// back to the header with the loop-carried values.
LogicalResult TerminatorOpConversion::matchAndRewrite(
scf::YieldOp terminatorOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
// If the region returns values, store each value into the associated
// VariableOp created during lowering of the parent region.
if (!operands.empty()) {
auto loc = terminatorOp.getLoc();
auto &allocas = scfToSPIRVContext->outputVars[terminatorOp.getParentOp()];
// replaceSCFOutputValue created exactly one variable per region result.
assert(allocas.size() == operands.size());
for (unsigned i = 0, e = operands.size(); i < e; i++)
rewriter.create<spirv::StoreOp>(loc, allocas[i], operands[i]);
if (isa<spirv::LoopOp>(terminatorOp.getParentOp())) {
// For loops we also need to update the branch jumping back to the header:
// the yielded values are appended as extra block arguments. Rebuild the
// branch with the extended argument list and erase the old one.
auto br =
cast<spirv::BranchOp>(rewriter.getInsertionBlock()->getTerminator());
SmallVector<Value, 8> args(br.getBlockArguments());
args.append(operands.begin(), operands.end());
rewriter.setInsertionPoint(br);
rewriter.create<spirv::BranchOp>(terminatorOp.getLoc(), br.getTarget(),
args);
rewriter.eraseOp(br);
}
}
rewriter.eraseOp(terminatorOp);
return success();
}
/// Populates the SCF-to-SPIR-V patterns; each pattern receives the raw impl
/// pointer of the shared context (not owned by the patterns).
void mlir::populateSCFToSPIRVPatterns(MLIRContext *context,
                                      SPIRVTypeConverter &typeConverter,
                                      ScfToSPIRVContext &scfToSPIRVContext,
                                      OwningRewritePatternList &patterns) {
  patterns.insert<ForOpConversion, IfOpConversion, TerminatorOpConversion>(
      context, typeConverter, scfToSPIRVContext.getImpl());
}

View File

@ -589,9 +589,6 @@ StorageClass PointerType::getStorageClass() const {
void PointerType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
Optional<StorageClass> storage) {
if (storage)
assert(*storage == getStorageClass() && "inconsistent storage class!");
// Use this pointer type's storage class because this pointer indicates we are
// using the pointee type in that specific storage class.
getPointeeType().cast<SPIRVType>().getExtensions(extensions,
@ -604,9 +601,6 @@ void PointerType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
void PointerType::getCapabilities(
SPIRVType::CapabilityArrayRefVector &capabilities,
Optional<StorageClass> storage) {
if (storage)
assert(*storage == getStorageClass() && "inconsistent storage class!");
// Use this pointer type's storage class because this pointer indicates we are
// using the pointee type in that specific storage class.
getPointeeType().cast<SPIRVType>().getCapabilities(capabilities,

View File

@ -89,5 +89,79 @@ module attributes {
}
gpu.return
}
// Test lowering of scf.if with multiple results: each result gets a Function
// spv.Variable, both branches store their yielded values into the variables,
// and the values are reloaded after the spv.selection.
// CHECK-LABEL: @simple_if_yield
gpu.func @simple_if_yield(%arg2 : memref<10xf32>, %arg3 : i1) kernel
attributes {spv.entry_point_abi = {local_size = dense<[16, 1, 1]>: vector<3xi32>}} {
// CHECK: %[[VAR1:.*]] = spv.Variable : !spv.ptr<f32, Function>
// CHECK: %[[VAR2:.*]] = spv.Variable : !spv.ptr<f32, Function>
// CHECK: spv.selection {
// CHECK-NEXT: spv.BranchConditional {{%.*}}, [[TRUE:\^.*]], [[FALSE:\^.*]]
// CHECK-NEXT: [[TRUE]]:
// CHECK: %[[RET1TRUE:.*]] = spv.constant 0.000000e+00 : f32
// CHECK: %[[RET2TRUE:.*]] = spv.constant 1.000000e+00 : f32
// CHECK-DAG: spv.Store "Function" %[[VAR1]], %[[RET1TRUE]] : f32
// CHECK-DAG: spv.Store "Function" %[[VAR2]], %[[RET2TRUE]] : f32
// CHECK: spv.Branch ^[[MERGE:.*]]
// CHECK-NEXT: [[FALSE]]:
// CHECK: %[[RET2FALSE:.*]] = spv.constant 2.000000e+00 : f32
// CHECK: %[[RET1FALSE:.*]] = spv.constant 3.000000e+00 : f32
// CHECK-DAG: spv.Store "Function" %[[VAR1]], %[[RET1FALSE]] : f32
// CHECK-DAG: spv.Store "Function" %[[VAR2]], %[[RET2FALSE]] : f32
// CHECK: spv.Branch ^[[MERGE]]
// CHECK-NEXT: ^[[MERGE]]:
// CHECK: spv._merge
// CHECK-NEXT: }
// CHECK-DAG: %[[OUT1:.*]] = spv.Load "Function" %[[VAR1]] : f32
// CHECK-DAG: %[[OUT2:.*]] = spv.Load "Function" %[[VAR2]] : f32
// CHECK: spv.Store "StorageBuffer" {{%.*}}, %[[OUT1]] : f32
// CHECK: spv.Store "StorageBuffer" {{%.*}}, %[[OUT2]] : f32
// CHECK: spv.Return
%0:2 = scf.if %arg3 -> (f32, f32) {
%c0 = constant 0.0 : f32
%c1 = constant 1.0 : f32
scf.yield %c0, %c1 : f32, f32
} else {
%c0 = constant 2.0 : f32
%c1 = constant 3.0 : f32
scf.yield %c1, %c0 : f32, f32
}
%i = constant 0 : index
%j = constant 1 : index
store %0#0, %arg2[%i] : memref<10xf32>
store %0#1, %arg2[%j] : memref<10xf32>
gpu.return
}
// Test an scf.if whose result type changes during conversion (memref is
// converted to a StorageBuffer pointer), so the Function variable holds a
// pointer value.
// TODO(thomasraoux): The transformation should only be legal if
// VariablePointer capability is supported. This test is still useful to
// make sure we can handle scf op result with type change.
// CHECK-LABEL: @simple_if_yield_type_change
// CHECK: %[[VAR:.*]] = spv.Variable : !spv.ptr<!spv.ptr<!spv.struct<!spv.array<10 x f32, stride=4> [0]>, StorageBuffer>, Function>
// CHECK: spv.selection {
// CHECK-NEXT: spv.BranchConditional {{%.*}}, [[TRUE:\^.*]], [[FALSE:\^.*]]
// CHECK-NEXT: [[TRUE]]:
// CHECK: spv.Store "Function" %[[VAR]], {{%.*}} : !spv.ptr<!spv.struct<!spv.array<10 x f32, stride=4> [0]>, StorageBuffer>
// CHECK: spv.Branch ^[[MERGE:.*]]
// CHECK-NEXT: [[FALSE]]:
// CHECK: spv.Store "Function" %[[VAR]], {{%.*}} : !spv.ptr<!spv.struct<!spv.array<10 x f32, stride=4> [0]>, StorageBuffer>
// CHECK: spv.Branch ^[[MERGE]]
// CHECK-NEXT: ^[[MERGE]]:
// CHECK: spv._merge
// CHECK-NEXT: }
// CHECK: %[[OUT:.*]] = spv.Load "Function" %[[VAR]] : !spv.ptr<!spv.struct<!spv.array<10 x f32, stride=4> [0]>, StorageBuffer>
// CHECK: %[[ADD:.*]] = spv.AccessChain %[[OUT]][{{%.*}}, {{%.*}}] : !spv.ptr<!spv.struct<!spv.array<10 x f32, stride=4> [0]>, StorageBuffer>
// CHECK: spv.Store "StorageBuffer" %[[ADD]], {{%.*}} : f32
// CHECK: spv.Return
gpu.func @simple_if_yield_type_change(%arg2 : memref<10xf32>, %arg3 : memref<10xf32>, %arg4 : i1) kernel
attributes {spv.entry_point_abi = {local_size = dense<[16, 1, 1]>: vector<3xi32>}} {
%i = constant 0 : index
%value = constant 0.0 : f32
%0 = scf.if %arg4 -> (memref<10xf32>) {
scf.yield %arg2 : memref<10xf32>
} else {
scf.yield %arg3 : memref<10xf32>
}
store %value, %0[%i] : memref<10xf32>
gpu.return
}
}
}

View File

@ -51,5 +51,48 @@ module attributes {
}
gpu.return
}
// Test lowering of scf.for with loop-carried values (iter_args): the carried
// values become extra block arguments on the loop header, the body stores the
// updated values into Function variables, and the results are reloaded after
// the spv.loop.
// CHECK-LABEL: @loop_yield
gpu.func @loop_yield(%arg2 : memref<10xf32>, %arg3 : memref<10xf32>) kernel
attributes {spv.entry_point_abi = {local_size = dense<[16, 1, 1]>: vector<3xi32>}} {
// CHECK: %[[LB:.*]] = spv.constant 4 : i32
%lb = constant 4 : index
// CHECK: %[[UB:.*]] = spv.constant 42 : i32
%ub = constant 42 : index
// CHECK: %[[STEP:.*]] = spv.constant 2 : i32
%step = constant 2 : index
// CHECK: %[[INITVAR1:.*]] = spv.constant 0.000000e+00 : f32
%s0 = constant 0.0 : f32
// CHECK: %[[INITVAR2:.*]] = spv.constant 1.000000e+00 : f32
%s1 = constant 1.0 : f32
// CHECK: %[[VAR1:.*]] = spv.Variable : !spv.ptr<f32, Function>
// CHECK: %[[VAR2:.*]] = spv.Variable : !spv.ptr<f32, Function>
// CHECK: spv.loop {
// CHECK: spv.Branch ^[[HEADER:.*]](%[[LB]], %[[INITVAR1]], %[[INITVAR2]] : i32, f32, f32)
// CHECK: ^[[HEADER]](%[[INDVAR:.*]]: i32, %[[CARRIED1:.*]]: f32, %[[CARRIED2:.*]]: f32):
// CHECK: %[[CMP:.*]] = spv.SLessThan %[[INDVAR]], %[[UB]] : i32
// CHECK: spv.BranchConditional %[[CMP]], ^[[BODY:.*]], ^[[MERGE:.*]]
// CHECK: ^[[BODY]]:
// CHECK: %[[UPDATED:.*]] = spv.FAdd %[[CARRIED1]], %[[CARRIED1]] : f32
// CHECK-DAG: %[[INCREMENT:.*]] = spv.IAdd %[[INDVAR]], %[[STEP]] : i32
// CHECK-DAG: spv.Store "Function" %[[VAR1]], %[[UPDATED]] : f32
// CHECK-DAG: spv.Store "Function" %[[VAR2]], %[[UPDATED]] : f32
// CHECK: spv.Branch ^[[HEADER]](%[[INCREMENT]], %[[UPDATED]], %[[UPDATED]] : i32, f32, f32)
// CHECK: ^[[MERGE]]:
// CHECK: spv._merge
// CHECK: }
%result:2 = scf.for %i0 = %lb to %ub step %step iter_args(%si = %s0, %sj = %s1) -> (f32, f32) {
%sn = addf %si, %si : f32
scf.yield %sn, %sn : f32, f32
}
// CHECK-DAG: %[[OUT1:.*]] = spv.Load "Function" %[[VAR1]] : f32
// CHECK-DAG: %[[OUT2:.*]] = spv.Load "Function" %[[VAR2]] : f32
// CHECK: spv.Store "StorageBuffer" {{%.*}}, %[[OUT1]] : f32
// CHECK: spv.Store "StorageBuffer" {{%.*}}, %[[OUT2]] : f32
store %result#0, %arg3[%lb] : memref<10xf32>
store %result#1, %arg3[%ub] : memref<10xf32>
gpu.return
}
}
}