[MLIR][SPIRV] Add support for OpCopyMemory.

This patch add support for 'spv.CopyMemory'. The following changes are
introduced:
- 'CopyMemory' op is added to SPIRVOps.td.
- Custom parse and print methods are introduced.
- A few Roundtripping tests are added.

Differential Revision: https://reviews.llvm.org/D82384
This commit is contained in:
ergawy 2020-06-26 09:37:30 -04:00 committed by Lei Zhang
parent 652a79659a
commit d6485ed3a7
5 changed files with 237 additions and 30 deletions

View File

@ -3135,6 +3135,7 @@ def SPV_OC_OpFunctionCall : I32EnumAttrCase<"OpFunctionCall", 57>;
def SPV_OC_OpVariable : I32EnumAttrCase<"OpVariable", 59>;
def SPV_OC_OpLoad : I32EnumAttrCase<"OpLoad", 61>;
def SPV_OC_OpStore : I32EnumAttrCase<"OpStore", 62>;
def SPV_OC_OpCopyMemory : I32EnumAttrCase<"OpCopyMemory", 63>;
def SPV_OC_OpAccessChain : I32EnumAttrCase<"OpAccessChain", 65>;
def SPV_OC_OpDecorate : I32EnumAttrCase<"OpDecorate", 71>;
def SPV_OC_OpMemberDecorate : I32EnumAttrCase<"OpMemberDecorate", 72>;
@ -3264,23 +3265,23 @@ def SPV_OpcodeAttr :
SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse,
SPV_OC_OpSpecConstant, SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction,
SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall,
SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpAccessChain,
SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeConstruct,
SPV_OC_OpCompositeExtract, SPV_OC_OpCompositeInsert, SPV_OC_OpTranspose,
SPV_OC_OpConvertFToU, SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF,
SPV_OC_OpConvertUToF, SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert,
SPV_OC_OpBitcast, SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd,
SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv,
SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod,
SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpMatrixTimesScalar,
SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr,
SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual,
SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan,
SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan,
SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual,
SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual,
SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan,
SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpCopyMemory,
SPV_OC_OpAccessChain, SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate,
SPV_OC_OpCompositeConstruct, SPV_OC_OpCompositeExtract,
SPV_OC_OpCompositeInsert, SPV_OC_OpTranspose, SPV_OC_OpConvertFToU,
SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF,
SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, SPV_OC_OpBitcast,
SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub,
SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv,
SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod,
SPV_OC_OpMatrixTimesScalar, SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual,
SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect,
SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual,
SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual,
SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual,
SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan,
SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual,
SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual,
SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic,

View File

@ -173,6 +173,58 @@ def SPV_ControlBarrierOp : SPV_Op<"ControlBarrier", []> {
// -----
def SPV_CopyMemoryOp : SPV_Op<"CopyMemory", []> {
let summary = [{
Copy from the memory pointed to by Source to the memory pointed to by
Target. Both operands must be non-void pointers and having the same <id>
Type operand in their OpTypePointer type declaration. Matching Storage
Class is not required. The amount of memory copied is the size of the
type pointed to. The copied type must have a fixed size; i.e., it cannot
be, nor include, any OpTypeRuntimeArray types.
}];
let description = [{
If present, any Memory Operands must begin with a memory operand
literal. If not present, it is the same as specifying the memory operand
None. Before version 1.4, at most one memory operands mask can be
provided. Starting with version 1.4 two masks can be provided, as
described in Memory Operands. If no masks or only one mask is present,
it applies to both Source and Target. If two masks are present, the
first applies to Target and cannot include MakePointerVisible, and the
second applies to Source and cannot include MakePointerAvailable.
<!-- End of AutoGen section -->
```
copy-memory-op ::= `spv.CopyMemory ` storage-class ssa-use
storage-class ssa-use
(`[` memory-access `]`)?
` : ` spirv-element-type
```
#### Example:
```mlir
%0 = spv.Variable : !spv.ptr<f32, Function>
%1 = spv.Variable : !spv.ptr<f32, Function>
spv.CopyMemory "Function" %0, "Function" %1 : f32
```
}];
let arguments = (ins
SPV_AnyPtr:$target,
SPV_AnyPtr:$source,
OptionalAttr<SPV_MemoryAccessAttr>:$memory_access,
OptionalAttr<I32Attr>:$alignment
);
let results = (outs);
let verifier = [{ return verifyCopyMemory(*this); }];
}
// -----
def SPV_ExecutionModeOp : SPV_Op<"ExecutionMode", [InModuleScope]> {
let summary = "Declare an execution mode for an entry point.";

View File

@ -183,17 +183,17 @@ static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
return parser.parseRSquare();
}
template <typename LoadStoreOpTy>
template <typename MemoryOpTy>
static void
printMemoryAccessAttribute(LoadStoreOpTy loadStoreOp, OpAsmPrinter &printer,
printMemoryAccessAttribute(MemoryOpTy memoryOp, OpAsmPrinter &printer,
SmallVectorImpl<StringRef> &elidedAttrs) {
// Print optional memory access attribute.
if (auto memAccess = loadStoreOp.memory_access()) {
if (auto memAccess = memoryOp.memory_access()) {
elidedAttrs.push_back(spirv::attributeName<spirv::MemoryAccess>());
printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
// Print integer alignment attribute.
if (auto alignment = loadStoreOp.alignment()) {
if (auto alignment = memoryOp.alignment()) {
elidedAttrs.push_back(kAlignmentAttrName);
printer << ", " << alignment;
}
@ -243,18 +243,18 @@ static LogicalResult verifyCastOp(Operation *op,
return success();
}
template <typename LoadStoreOpTy>
static LogicalResult verifyMemoryAccessAttribute(LoadStoreOpTy loadStoreOp) {
template <typename MemoryOpTy>
static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
// ODS checks for attributes values. Just need to verify that if the
// memory-access attribute is Aligned, then the alignment attribute must be
// present.
auto *op = loadStoreOp.getOperation();
auto *op = memoryOp.getOperation();
auto memAccessAttr = op->getAttr(spirv::attributeName<spirv::MemoryAccess>());
if (!memAccessAttr) {
// Alignment attribute shouldn't be present if memory access attribute is
// not present.
if (op->getAttr(kAlignmentAttrName)) {
return loadStoreOp.emitOpError(
return memoryOp.emitOpError(
"invalid alignment specification without aligned memory access "
"specification");
}
@ -265,17 +265,17 @@ static LogicalResult verifyMemoryAccessAttribute(LoadStoreOpTy loadStoreOp) {
auto memAccess = spirv::symbolizeMemoryAccess(memAccessVal.getInt());
if (!memAccess) {
return loadStoreOp.emitOpError("invalid memory access specifier: ")
return memoryOp.emitOpError("invalid memory access specifier: ")
<< memAccessVal;
}
if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
if (!op->getAttr(kAlignmentAttrName)) {
return loadStoreOp.emitOpError("missing alignment value");
return memoryOp.emitOpError("missing alignment value");
}
} else {
if (op->getAttr(kAlignmentAttrName)) {
return loadStoreOp.emitOpError(
return memoryOp.emitOpError(
"invalid alignment specification with non-aligned memory access "
"specification");
}
@ -2752,8 +2752,7 @@ static void print(spirv::CooperativeMatrixStoreNVOp coopMatrix,
static LogicalResult
verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) {
if (op.c().getType() != op.result().getType())
return op.emitOpError(
"result and third operand must have the same type");
return op.emitOpError("result and third operand must have the same type");
auto typeA = op.a().getType().cast<spirv::CooperativeMatrixNVType>();
auto typeB = op.b().getType().cast<spirv::CooperativeMatrixNVType>();
auto typeC = op.c().getType().cast<spirv::CooperativeMatrixNVType>();
@ -2812,9 +2811,89 @@ static LogicalResult verifyMatrixTimesScalar(spirv::MatrixTimesScalarOp op) {
"have the same size");
}
}
return success();
}
//===----------------------------------------------------------------------===//
// spv.CopyMemory
//===----------------------------------------------------------------------===//
static void print(spirv::CopyMemoryOp copyMemory, OpAsmPrinter &printer) {
auto *op = copyMemory.getOperation();
printer << spirv::CopyMemoryOp::getOperationName() << ' ';
StringRef targetStorageClass =
stringifyStorageClass(copyMemory.target()
.getType()
.cast<spirv::PointerType>()
.getStorageClass());
printer << " \"" << targetStorageClass << "\" " << copyMemory.target()
<< ", ";
StringRef sourceStorageClass =
stringifyStorageClass(copyMemory.source()
.getType()
.cast<spirv::PointerType>()
.getStorageClass());
printer << " \"" << sourceStorageClass << "\" " << copyMemory.source();
SmallVector<StringRef, 4> elidedAttrs;
printMemoryAccessAttribute(copyMemory, printer, elidedAttrs);
printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
Type pointeeType =
copyMemory.target().getType().cast<spirv::PointerType>().getPointeeType();
printer << " : " << pointeeType;
}
static ParseResult parseCopyMemoryOp(OpAsmParser &parser,
OperationState &state) {
spirv::StorageClass targetStorageClass;
OpAsmParser::OperandType targetPtrInfo;
spirv::StorageClass sourceStorageClass;
OpAsmParser::OperandType sourcePtrInfo;
Type elementType;
if (parseEnumStrAttr(targetStorageClass, parser) ||
parser.parseOperand(targetPtrInfo) || parser.parseComma() ||
parseEnumStrAttr(sourceStorageClass, parser) ||
parser.parseOperand(sourcePtrInfo) ||
parseMemoryAccessAttributes(parser, state) ||
parser.parseOptionalAttrDict(state.attributes) || parser.parseColon() ||
parser.parseType(elementType)) {
return failure();
}
auto targetPtrType = spirv::PointerType::get(elementType, targetStorageClass);
auto sourcePtrType = spirv::PointerType::get(elementType, sourceStorageClass);
if (parser.resolveOperand(targetPtrInfo, targetPtrType, state.operands) ||
parser.resolveOperand(sourcePtrInfo, sourcePtrType, state.operands)) {
return failure();
}
return success();
}
static LogicalResult verifyCopyMemory(spirv::CopyMemoryOp copyMemory) {
Type targetType =
copyMemory.target().getType().cast<spirv::PointerType>().getPointeeType();
Type sourceType =
copyMemory.source().getType().cast<spirv::PointerType>().getPointeeType();
if (targetType != sourceType) {
return copyMemory.emitOpError(
"both operands must be pointers to the same type");
}
return verifyMemoryAccessAttribute(copyMemory);
}
//===----------------------------------------------------------------------===//
// spv.Transpose
//===----------------------------------------------------------------------===//

View File

@ -57,3 +57,43 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
spv.Return
}
}
// -----
spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
spv.func @copy_memory_simple() "None" {
%0 = spv.Variable : !spv.ptr<f32, Function>
%1 = spv.Variable : !spv.ptr<f32, Function>
// CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} : f32
spv.CopyMemory "Function" %0, "Function" %1 : f32
spv.Return
}
}
// -----
spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
spv.func @copy_memory_different_storage_classes(%in : !spv.ptr<!spv.array<4xf32>, Input>, %out : !spv.ptr<!spv.array<4xf32>, Output>) "None" {
// CHECK: spv.CopyMemory "Output" %{{.*}}, "Input" %{{.*}} : !spv.array<4 x f32>
spv.CopyMemory "Output" %out, "Input" %in : !spv.array<4xf32>
spv.Return
}
}
// -----
spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
spv.func @copy_memory_with_access_operands() "None" {
%0 = spv.Variable : !spv.ptr<f32, Function>
%1 = spv.Variable : !spv.ptr<f32, Function>
// CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Aligned", 4] : f32
spv.CopyMemory "Function" %0, "Function" %1 ["Aligned", 4] : f32
// CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Volatile"] : f32
spv.CopyMemory "Function" %0, "Function" %1 ["Volatile"] : f32
spv.Return
}
}

View File

@ -1244,3 +1244,38 @@ func @cannot_be_generic_storage_class(%arg0: f32) -> () {
%0 = spv.Variable : !spv.ptr<f32, Generic>
return
}
// -----
func @copy_memory_incompatible_ptrs() -> () {
%0 = spv.Variable : !spv.ptr<f32, Function>
%1 = spv.Variable : !spv.ptr<i32, Function>
// expected-error @+1 {{both operands must be pointers to the same type}}
"spv.CopyMemory"(%0, %1) {} : (!spv.ptr<f32, Function>, !spv.ptr<i32, Function>) -> ()
spv.Return
}
// -----
func @copy_memory_invalid_maa() -> () {
%0 = spv.Variable : !spv.ptr<f32, Function>
%1 = spv.Variable : !spv.ptr<f32, Function>
// expected-error @+1 {{missing alignment value}}
"spv.CopyMemory"(%0, %1) {memory_access=0x0002 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
spv.Return
}
// -----
func @copy_memory_print_maa() -> () {
%0 = spv.Variable : !spv.ptr<f32, Function>
%1 = spv.Variable : !spv.ptr<f32, Function>
// CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Volatile"] : f32
"spv.CopyMemory"(%0, %1) {memory_access=0x0001 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
// CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Aligned", 4] : f32
"spv.CopyMemory"(%0, %1) {memory_access=0x0002 : i32, alignment=4 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
spv.Return
}