[MLIR][SPIRV] Add support for OpCopyMemory.
This patch add support for 'spv.CopyMemory'. The following changes are introduced: - 'CopyMemory' op is added to SPIRVOps.td. - Custom parse and print methods are introduced. - A few Roundtripping tests are added. Differential Revision: https://reviews.llvm.org/D82384
This commit is contained in:
parent
652a79659a
commit
d6485ed3a7
|
@ -3135,6 +3135,7 @@ def SPV_OC_OpFunctionCall : I32EnumAttrCase<"OpFunctionCall", 57>;
|
|||
def SPV_OC_OpVariable : I32EnumAttrCase<"OpVariable", 59>;
|
||||
def SPV_OC_OpLoad : I32EnumAttrCase<"OpLoad", 61>;
|
||||
def SPV_OC_OpStore : I32EnumAttrCase<"OpStore", 62>;
|
||||
def SPV_OC_OpCopyMemory : I32EnumAttrCase<"OpCopyMemory", 63>;
|
||||
def SPV_OC_OpAccessChain : I32EnumAttrCase<"OpAccessChain", 65>;
|
||||
def SPV_OC_OpDecorate : I32EnumAttrCase<"OpDecorate", 71>;
|
||||
def SPV_OC_OpMemberDecorate : I32EnumAttrCase<"OpMemberDecorate", 72>;
|
||||
|
@ -3264,23 +3265,23 @@ def SPV_OpcodeAttr :
|
|||
SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse,
|
||||
SPV_OC_OpSpecConstant, SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction,
|
||||
SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall,
|
||||
SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpAccessChain,
|
||||
SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeConstruct,
|
||||
SPV_OC_OpCompositeExtract, SPV_OC_OpCompositeInsert, SPV_OC_OpTranspose,
|
||||
SPV_OC_OpConvertFToU, SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF,
|
||||
SPV_OC_OpConvertUToF, SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert,
|
||||
SPV_OC_OpBitcast, SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd,
|
||||
SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv,
|
||||
SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod,
|
||||
SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpMatrixTimesScalar,
|
||||
SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr,
|
||||
SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual,
|
||||
SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan,
|
||||
SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan,
|
||||
SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual,
|
||||
SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual,
|
||||
SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan,
|
||||
SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
|
||||
SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpCopyMemory,
|
||||
SPV_OC_OpAccessChain, SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate,
|
||||
SPV_OC_OpCompositeConstruct, SPV_OC_OpCompositeExtract,
|
||||
SPV_OC_OpCompositeInsert, SPV_OC_OpTranspose, SPV_OC_OpConvertFToU,
|
||||
SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF,
|
||||
SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, SPV_OC_OpBitcast,
|
||||
SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub,
|
||||
SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv,
|
||||
SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod,
|
||||
SPV_OC_OpMatrixTimesScalar, SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual,
|
||||
SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect,
|
||||
SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
|
||||
SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual,
|
||||
SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual,
|
||||
SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual,
|
||||
SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan,
|
||||
SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
|
||||
SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual,
|
||||
SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual,
|
||||
SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic,
|
||||
|
|
|
@ -173,6 +173,58 @@ def SPV_ControlBarrierOp : SPV_Op<"ControlBarrier", []> {
|
|||
|
||||
// -----
|
||||
|
||||
def SPV_CopyMemoryOp : SPV_Op<"CopyMemory", []> {
|
||||
let summary = [{
|
||||
Copy from the memory pointed to by Source to the memory pointed to by
|
||||
Target. Both operands must be non-void pointers and having the same <id>
|
||||
Type operand in their OpTypePointer type declaration. Matching Storage
|
||||
Class is not required. The amount of memory copied is the size of the
|
||||
type pointed to. The copied type must have a fixed size; i.e., it cannot
|
||||
be, nor include, any OpTypeRuntimeArray types.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
If present, any Memory Operands must begin with a memory operand
|
||||
literal. If not present, it is the same as specifying the memory operand
|
||||
None. Before version 1.4, at most one memory operands mask can be
|
||||
provided. Starting with version 1.4 two masks can be provided, as
|
||||
described in Memory Operands. If no masks or only one mask is present,
|
||||
it applies to both Source and Target. If two masks are present, the
|
||||
first applies to Target and cannot include MakePointerVisible, and the
|
||||
second applies to Source and cannot include MakePointerAvailable.
|
||||
|
||||
<!-- End of AutoGen section -->
|
||||
|
||||
```
|
||||
copy-memory-op ::= `spv.CopyMemory ` storage-class ssa-use
|
||||
storage-class ssa-use
|
||||
(`[` memory-access `]`)?
|
||||
` : ` spirv-element-type
|
||||
```
|
||||
|
||||
#### Example:
|
||||
|
||||
```mlir
|
||||
%0 = spv.Variable : !spv.ptr<f32, Function>
|
||||
%1 = spv.Variable : !spv.ptr<f32, Function>
|
||||
spv.CopyMemory "Function" %0, "Function" %1 : f32
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
SPV_AnyPtr:$target,
|
||||
SPV_AnyPtr:$source,
|
||||
OptionalAttr<SPV_MemoryAccessAttr>:$memory_access,
|
||||
OptionalAttr<I32Attr>:$alignment
|
||||
);
|
||||
|
||||
let results = (outs);
|
||||
|
||||
let verifier = [{ return verifyCopyMemory(*this); }];
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
def SPV_ExecutionModeOp : SPV_Op<"ExecutionMode", [InModuleScope]> {
|
||||
let summary = "Declare an execution mode for an entry point.";
|
||||
|
||||
|
|
|
@ -183,17 +183,17 @@ static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
|
|||
return parser.parseRSquare();
|
||||
}
|
||||
|
||||
template <typename LoadStoreOpTy>
|
||||
template <typename MemoryOpTy>
|
||||
static void
|
||||
printMemoryAccessAttribute(LoadStoreOpTy loadStoreOp, OpAsmPrinter &printer,
|
||||
printMemoryAccessAttribute(MemoryOpTy memoryOp, OpAsmPrinter &printer,
|
||||
SmallVectorImpl<StringRef> &elidedAttrs) {
|
||||
// Print optional memory access attribute.
|
||||
if (auto memAccess = loadStoreOp.memory_access()) {
|
||||
if (auto memAccess = memoryOp.memory_access()) {
|
||||
elidedAttrs.push_back(spirv::attributeName<spirv::MemoryAccess>());
|
||||
printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
|
||||
|
||||
// Print integer alignment attribute.
|
||||
if (auto alignment = loadStoreOp.alignment()) {
|
||||
if (auto alignment = memoryOp.alignment()) {
|
||||
elidedAttrs.push_back(kAlignmentAttrName);
|
||||
printer << ", " << alignment;
|
||||
}
|
||||
|
@ -243,18 +243,18 @@ static LogicalResult verifyCastOp(Operation *op,
|
|||
return success();
|
||||
}
|
||||
|
||||
template <typename LoadStoreOpTy>
|
||||
static LogicalResult verifyMemoryAccessAttribute(LoadStoreOpTy loadStoreOp) {
|
||||
template <typename MemoryOpTy>
|
||||
static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
|
||||
// ODS checks for attributes values. Just need to verify that if the
|
||||
// memory-access attribute is Aligned, then the alignment attribute must be
|
||||
// present.
|
||||
auto *op = loadStoreOp.getOperation();
|
||||
auto *op = memoryOp.getOperation();
|
||||
auto memAccessAttr = op->getAttr(spirv::attributeName<spirv::MemoryAccess>());
|
||||
if (!memAccessAttr) {
|
||||
// Alignment attribute shouldn't be present if memory access attribute is
|
||||
// not present.
|
||||
if (op->getAttr(kAlignmentAttrName)) {
|
||||
return loadStoreOp.emitOpError(
|
||||
return memoryOp.emitOpError(
|
||||
"invalid alignment specification without aligned memory access "
|
||||
"specification");
|
||||
}
|
||||
|
@ -265,17 +265,17 @@ static LogicalResult verifyMemoryAccessAttribute(LoadStoreOpTy loadStoreOp) {
|
|||
auto memAccess = spirv::symbolizeMemoryAccess(memAccessVal.getInt());
|
||||
|
||||
if (!memAccess) {
|
||||
return loadStoreOp.emitOpError("invalid memory access specifier: ")
|
||||
return memoryOp.emitOpError("invalid memory access specifier: ")
|
||||
<< memAccessVal;
|
||||
}
|
||||
|
||||
if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
|
||||
if (!op->getAttr(kAlignmentAttrName)) {
|
||||
return loadStoreOp.emitOpError("missing alignment value");
|
||||
return memoryOp.emitOpError("missing alignment value");
|
||||
}
|
||||
} else {
|
||||
if (op->getAttr(kAlignmentAttrName)) {
|
||||
return loadStoreOp.emitOpError(
|
||||
return memoryOp.emitOpError(
|
||||
"invalid alignment specification with non-aligned memory access "
|
||||
"specification");
|
||||
}
|
||||
|
@ -2752,8 +2752,7 @@ static void print(spirv::CooperativeMatrixStoreNVOp coopMatrix,
|
|||
static LogicalResult
|
||||
verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) {
|
||||
if (op.c().getType() != op.result().getType())
|
||||
return op.emitOpError(
|
||||
"result and third operand must have the same type");
|
||||
return op.emitOpError("result and third operand must have the same type");
|
||||
auto typeA = op.a().getType().cast<spirv::CooperativeMatrixNVType>();
|
||||
auto typeB = op.b().getType().cast<spirv::CooperativeMatrixNVType>();
|
||||
auto typeC = op.c().getType().cast<spirv::CooperativeMatrixNVType>();
|
||||
|
@ -2812,9 +2811,89 @@ static LogicalResult verifyMatrixTimesScalar(spirv::MatrixTimesScalarOp op) {
|
|||
"have the same size");
|
||||
}
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.CopyMemory
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(spirv::CopyMemoryOp copyMemory, OpAsmPrinter &printer) {
|
||||
auto *op = copyMemory.getOperation();
|
||||
printer << spirv::CopyMemoryOp::getOperationName() << ' ';
|
||||
|
||||
StringRef targetStorageClass =
|
||||
stringifyStorageClass(copyMemory.target()
|
||||
.getType()
|
||||
.cast<spirv::PointerType>()
|
||||
.getStorageClass());
|
||||
printer << " \"" << targetStorageClass << "\" " << copyMemory.target()
|
||||
<< ", ";
|
||||
|
||||
StringRef sourceStorageClass =
|
||||
stringifyStorageClass(copyMemory.source()
|
||||
.getType()
|
||||
.cast<spirv::PointerType>()
|
||||
.getStorageClass());
|
||||
printer << " \"" << sourceStorageClass << "\" " << copyMemory.source();
|
||||
|
||||
SmallVector<StringRef, 4> elidedAttrs;
|
||||
printMemoryAccessAttribute(copyMemory, printer, elidedAttrs);
|
||||
|
||||
printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
|
||||
|
||||
Type pointeeType =
|
||||
copyMemory.target().getType().cast<spirv::PointerType>().getPointeeType();
|
||||
printer << " : " << pointeeType;
|
||||
}
|
||||
|
||||
static ParseResult parseCopyMemoryOp(OpAsmParser &parser,
|
||||
OperationState &state) {
|
||||
spirv::StorageClass targetStorageClass;
|
||||
OpAsmParser::OperandType targetPtrInfo;
|
||||
|
||||
spirv::StorageClass sourceStorageClass;
|
||||
OpAsmParser::OperandType sourcePtrInfo;
|
||||
|
||||
Type elementType;
|
||||
|
||||
if (parseEnumStrAttr(targetStorageClass, parser) ||
|
||||
parser.parseOperand(targetPtrInfo) || parser.parseComma() ||
|
||||
parseEnumStrAttr(sourceStorageClass, parser) ||
|
||||
parser.parseOperand(sourcePtrInfo) ||
|
||||
parseMemoryAccessAttributes(parser, state) ||
|
||||
parser.parseOptionalAttrDict(state.attributes) || parser.parseColon() ||
|
||||
parser.parseType(elementType)) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto targetPtrType = spirv::PointerType::get(elementType, targetStorageClass);
|
||||
auto sourcePtrType = spirv::PointerType::get(elementType, sourceStorageClass);
|
||||
|
||||
if (parser.resolveOperand(targetPtrInfo, targetPtrType, state.operands) ||
|
||||
parser.resolveOperand(sourcePtrInfo, sourcePtrType, state.operands)) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verifyCopyMemory(spirv::CopyMemoryOp copyMemory) {
|
||||
Type targetType =
|
||||
copyMemory.target().getType().cast<spirv::PointerType>().getPointeeType();
|
||||
|
||||
Type sourceType =
|
||||
copyMemory.source().getType().cast<spirv::PointerType>().getPointeeType();
|
||||
|
||||
if (targetType != sourceType) {
|
||||
return copyMemory.emitOpError(
|
||||
"both operands must be pointers to the same type");
|
||||
}
|
||||
|
||||
return verifyMemoryAccessAttribute(copyMemory);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.Transpose
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -57,3 +57,43 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
|
|||
spv.Return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
|
||||
spv.func @copy_memory_simple() "None" {
|
||||
%0 = spv.Variable : !spv.ptr<f32, Function>
|
||||
%1 = spv.Variable : !spv.ptr<f32, Function>
|
||||
// CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} : f32
|
||||
spv.CopyMemory "Function" %0, "Function" %1 : f32
|
||||
spv.Return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
|
||||
spv.func @copy_memory_different_storage_classes(%in : !spv.ptr<!spv.array<4xf32>, Input>, %out : !spv.ptr<!spv.array<4xf32>, Output>) "None" {
|
||||
// CHECK: spv.CopyMemory "Output" %{{.*}}, "Input" %{{.*}} : !spv.array<4 x f32>
|
||||
spv.CopyMemory "Output" %out, "Input" %in : !spv.array<4xf32>
|
||||
spv.Return
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
|
||||
spv.func @copy_memory_with_access_operands() "None" {
|
||||
%0 = spv.Variable : !spv.ptr<f32, Function>
|
||||
%1 = spv.Variable : !spv.ptr<f32, Function>
|
||||
// CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Aligned", 4] : f32
|
||||
spv.CopyMemory "Function" %0, "Function" %1 ["Aligned", 4] : f32
|
||||
|
||||
// CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Volatile"] : f32
|
||||
spv.CopyMemory "Function" %0, "Function" %1 ["Volatile"] : f32
|
||||
|
||||
spv.Return
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1244,3 +1244,38 @@ func @cannot_be_generic_storage_class(%arg0: f32) -> () {
|
|||
%0 = spv.Variable : !spv.ptr<f32, Generic>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @copy_memory_incompatible_ptrs() -> () {
|
||||
%0 = spv.Variable : !spv.ptr<f32, Function>
|
||||
%1 = spv.Variable : !spv.ptr<i32, Function>
|
||||
// expected-error @+1 {{both operands must be pointers to the same type}}
|
||||
"spv.CopyMemory"(%0, %1) {} : (!spv.ptr<f32, Function>, !spv.ptr<i32, Function>) -> ()
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @copy_memory_invalid_maa() -> () {
|
||||
%0 = spv.Variable : !spv.ptr<f32, Function>
|
||||
%1 = spv.Variable : !spv.ptr<f32, Function>
|
||||
// expected-error @+1 {{missing alignment value}}
|
||||
"spv.CopyMemory"(%0, %1) {memory_access=0x0002 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @copy_memory_print_maa() -> () {
|
||||
%0 = spv.Variable : !spv.ptr<f32, Function>
|
||||
%1 = spv.Variable : !spv.ptr<f32, Function>
|
||||
|
||||
// CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Volatile"] : f32
|
||||
"spv.CopyMemory"(%0, %1) {memory_access=0x0001 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
|
||||
|
||||
// CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Aligned", 4] : f32
|
||||
"spv.CopyMemory"(%0, %1) {memory_access=0x0002 : i32, alignment=4 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
|
||||
|
||||
spv.Return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue