[NVPTX] Make tensor load/store intrinsics overloaded.

This way we can support address-space-specific variants without explicitly
encoding the space in the name of the intrinsic. Fewer intrinsics to deal with ->
less boilerplate.

Added a bit of tablegen magic to match/replace an intrinsic that has a pointer
argument in a particular address space with the space-specific instruction
variant.

Updated tests to use non-default address spaces.

Differential Revision: https://reviews.llvm.org/D43268

llvm-svn: 328006
This commit is contained in:
Artem Belevich 2018-03-20 17:18:59 +00:00
parent 3a99893618
commit 914d4babec
5 changed files with 174 additions and 157 deletions

View File

@ -10527,8 +10527,7 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
llvm_unreachable("Unexpected builtin ID.");
}
Value *Result =
Builder.CreateCall(CGM.getIntrinsic(IID),
{Builder.CreatePointerCast(Src, VoidPtrTy), Ldm});
Builder.CreateCall(CGM.getIntrinsic(IID, Src->getType()), {Src, Ldm});
// Save returned values.
for (unsigned i = 0; i < NumResults; ++i) {
@ -10567,10 +10566,9 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
default:
llvm_unreachable("Unexpected builtin ID.");
}
Function *Intrinsic = CGM.getIntrinsic(IID);
Function *Intrinsic = CGM.getIntrinsic(IID, Dst->getType());
llvm::Type *ParamType = Intrinsic->getFunctionType()->getParamType(1);
SmallVector<Value *, 10> Values;
Values.push_back(Builder.CreatePointerCast(Dst, VoidPtrTy));
SmallVector<Value *, 10> Values = {Dst};
for (unsigned i = 0; i < NumResults; ++i) {
Value *V = Builder.CreateAlignedLoad(
Builder.CreateGEP(Src.getPointer(), llvm::ConstantInt::get(IntTy, i)),

View File

@ -3884,30 +3884,22 @@ def int_nvvm_match_all_sync_i64p :
//
// WMMA.LOAD
class NVVM_WMMA_LD_ALSTS<string Abc, string Layout, string Space,
string Type, LLVMType regty, int WithStride>
class NVVM_WMMA_LD_ALSTS<string Abc, string Layout, string Type,
LLVMType regty, int WithStride>
: Intrinsic<!if(!eq(Abc#Type,"cf16"),
[regty, regty, regty, regty],
[regty, regty, regty, regty,
regty, regty, regty, regty]),
!if(WithStride, [llvm_ptr_ty, llvm_i32_ty], [llvm_ptr_ty]),
[], // Properties must be set during instantiation.
!if(WithStride, [llvm_anyptr_ty, llvm_i32_ty], [llvm_anyptr_ty]),
[IntrReadMem, IntrArgMemOnly, ReadOnly<0>, NoCapture<0>],
"llvm.nvvm.wmma.load."#Abc#".sync."#Layout#".m16n16k16"
#Space
#!if(WithStride,".stride","")
#"."#Type>;
multiclass NVVM_WMMA_LD_ALST<string Abc, string Layout, string Space,
string Type, LLVMType regty> {
def _stride: NVVM_WMMA_LD_ALSTS<Abc, Layout, Space, Type, regty, 1>;
def NAME : NVVM_WMMA_LD_ALSTS<Abc, Layout, Space, Type, regty, 0>;
}
multiclass NVVM_WMMA_LD_ALT<string Abc, string Layout,
string Type, LLVMType regty> {
defm _global: NVVM_WMMA_LD_ALST<Abc, Layout, ".global", Type, regty>;
defm _shared: NVVM_WMMA_LD_ALST<Abc, Layout, ".shared", Type, regty>;
defm NAME: NVVM_WMMA_LD_ALST<Abc, Layout, "", Type, regty>;
multiclass NVVM_WMMA_LD_ALT<string Abc, string Layout, string Type,
LLVMType regty> {
def _stride: NVVM_WMMA_LD_ALSTS<Abc, Layout, Type, regty, 1>;
def NAME : NVVM_WMMA_LD_ALSTS<Abc, Layout, Type, regty, 0>;
}
multiclass NVVM_WMMA_LD_AT<string Abc, string Type, LLVMType regty> {
@ -3915,47 +3907,33 @@ multiclass NVVM_WMMA_LD_AT<string Abc, string Type, LLVMType regty> {
defm _col: NVVM_WMMA_LD_ALT<Abc, "col", Type, regty>;
}
// For some reason ReadOnly<N> and NoCapture<N> confuse tblgen if they are
// passed to Intrinsic<> from inside of a multiclass. Setting them globally
// outside of the multiclass works.
let IntrProperties = [IntrReadMem, IntrArgMemOnly,
ReadOnly<0>, NoCapture<0>] in {
defm int_nvvm_wmma_load_a_f16: NVVM_WMMA_LD_AT<"a", "f16", llvm_v2f16_ty>;
defm int_nvvm_wmma_load_b_f16: NVVM_WMMA_LD_AT<"b", "f16", llvm_v2f16_ty>;
defm int_nvvm_wmma_load_c_f16: NVVM_WMMA_LD_AT<"c", "f16", llvm_v2f16_ty>;
defm int_nvvm_wmma_load_c_f32: NVVM_WMMA_LD_AT<"c", "f32", llvm_float_ty>;
}
defm int_nvvm_wmma_load_a_f16: NVVM_WMMA_LD_AT<"a", "f16", llvm_v2f16_ty>;
defm int_nvvm_wmma_load_b_f16: NVVM_WMMA_LD_AT<"b", "f16", llvm_v2f16_ty>;
defm int_nvvm_wmma_load_c_f16: NVVM_WMMA_LD_AT<"c", "f16", llvm_v2f16_ty>;
defm int_nvvm_wmma_load_c_f32: NVVM_WMMA_LD_AT<"c", "f32", llvm_float_ty>;
// WMMA.STORE.D
class NVVM_WMMA_STD_LSTS<string Layout, string Space,
string Type, LLVMType regty, int WithStride,
class NVVM_WMMA_STD_LSTS<string Layout, string Type, LLVMType regty, int WithStride,
// This is only used to create a typed empty array we
// need to pass to !if below.
list<LLVMType>Empty=[]>
: Intrinsic<[],
!listconcat(
[llvm_ptr_ty],
[llvm_anyptr_ty],
!if(!eq(Type,"f16"),
[regty, regty, regty, regty],
[regty, regty, regty, regty,
regty, regty, regty, regty]),
!if(WithStride, [llvm_i32_ty], Empty)),
[], // Properties must be set during instantiation.
[IntrWriteMem, IntrArgMemOnly, WriteOnly<0>, NoCapture<0>],
"llvm.nvvm.wmma.store.d.sync."#Layout
#".m16n16k16"#Space
#".m16n16k16"
#!if(WithStride,".stride","")
#"."#Type>;
multiclass NVVM_WMMA_STD_LST<string Layout, string Space,
string Type, LLVMType regty> {
def _stride: NVVM_WMMA_STD_LSTS<Layout, Space, Type, regty, 1>;
def NAME: NVVM_WMMA_STD_LSTS<Layout, Space, Type, regty, 0>;
}
multiclass NVVM_WMMA_STD_LT<string Layout, string Type, LLVMType regty> {
defm _global: NVVM_WMMA_STD_LST<Layout, ".global", Type, regty>;
defm _shared: NVVM_WMMA_STD_LST<Layout, ".shared", Type, regty>;
defm NAME: NVVM_WMMA_STD_LST<Layout, "", Type, regty>;
def _stride: NVVM_WMMA_STD_LSTS<Layout, Type, regty, 1>;
def NAME: NVVM_WMMA_STD_LSTS<Layout, Type, regty, 0>;
}
multiclass NVVM_WMMA_STD_T<string Type, LLVMType regty> {
@ -3963,11 +3941,8 @@ multiclass NVVM_WMMA_STD_T<string Type, LLVMType regty> {
defm _col: NVVM_WMMA_STD_LT<"col", Type, regty>;
}
let IntrProperties = [IntrWriteMem, IntrArgMemOnly,
WriteOnly<0>, NoCapture<0>] in {
defm int_nvvm_wmma_store_d_f16: NVVM_WMMA_STD_T<"f16", llvm_v2f16_ty>;
defm int_nvvm_wmma_store_d_f32: NVVM_WMMA_STD_T<"f32", llvm_float_ty>;
}
defm int_nvvm_wmma_store_d_f16: NVVM_WMMA_STD_T<"f16", llvm_v2f16_ty>;
defm int_nvvm_wmma_store_d_f32: NVVM_WMMA_STD_T<"f32", llvm_float_ty>;
// WMMA.MMA
class NVVM_WMMA_MMA_ABDCS<string ALayout, string BLayout,

View File

@ -3327,26 +3327,10 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
case Intrinsic::nvvm_wmma_load_a_f16_row:
case Intrinsic::nvvm_wmma_load_a_f16_col_stride:
case Intrinsic::nvvm_wmma_load_a_f16_row_stride:
case Intrinsic::nvvm_wmma_load_a_f16_col_shared:
case Intrinsic::nvvm_wmma_load_a_f16_row_shared:
case Intrinsic::nvvm_wmma_load_a_f16_col_shared_stride:
case Intrinsic::nvvm_wmma_load_a_f16_row_shared_stride:
case Intrinsic::nvvm_wmma_load_a_f16_col_global:
case Intrinsic::nvvm_wmma_load_a_f16_row_global:
case Intrinsic::nvvm_wmma_load_a_f16_col_global_stride:
case Intrinsic::nvvm_wmma_load_a_f16_row_global_stride:
case Intrinsic::nvvm_wmma_load_b_f16_col:
case Intrinsic::nvvm_wmma_load_b_f16_row:
case Intrinsic::nvvm_wmma_load_b_f16_col_stride:
case Intrinsic::nvvm_wmma_load_b_f16_row_stride:
case Intrinsic::nvvm_wmma_load_b_f16_col_shared:
case Intrinsic::nvvm_wmma_load_b_f16_row_shared:
case Intrinsic::nvvm_wmma_load_b_f16_col_shared_stride:
case Intrinsic::nvvm_wmma_load_b_f16_row_shared_stride:
case Intrinsic::nvvm_wmma_load_b_f16_col_global:
case Intrinsic::nvvm_wmma_load_b_f16_row_global:
case Intrinsic::nvvm_wmma_load_b_f16_col_global_stride:
case Intrinsic::nvvm_wmma_load_b_f16_row_global_stride: {
case Intrinsic::nvvm_wmma_load_b_f16_row_stride: {
Info.opc = ISD::INTRINSIC_W_CHAIN;
Info.memVT = MVT::v8f16;
Info.ptrVal = I.getArgOperand(0);
@ -3359,15 +3343,7 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
case Intrinsic::nvvm_wmma_load_c_f16_col:
case Intrinsic::nvvm_wmma_load_c_f16_row:
case Intrinsic::nvvm_wmma_load_c_f16_col_stride:
case Intrinsic::nvvm_wmma_load_c_f16_row_stride:
case Intrinsic::nvvm_wmma_load_c_f16_col_shared:
case Intrinsic::nvvm_wmma_load_c_f16_row_shared:
case Intrinsic::nvvm_wmma_load_c_f16_col_shared_stride:
case Intrinsic::nvvm_wmma_load_c_f16_row_shared_stride:
case Intrinsic::nvvm_wmma_load_c_f16_col_global:
case Intrinsic::nvvm_wmma_load_c_f16_row_global:
case Intrinsic::nvvm_wmma_load_c_f16_col_global_stride:
case Intrinsic::nvvm_wmma_load_c_f16_row_global_stride: {
case Intrinsic::nvvm_wmma_load_c_f16_row_stride: {
Info.opc = ISD::INTRINSIC_W_CHAIN;
Info.memVT = MVT::v4f16;
Info.ptrVal = I.getArgOperand(0);
@ -3380,15 +3356,7 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
case Intrinsic::nvvm_wmma_load_c_f32_col:
case Intrinsic::nvvm_wmma_load_c_f32_row:
case Intrinsic::nvvm_wmma_load_c_f32_col_stride:
case Intrinsic::nvvm_wmma_load_c_f32_row_stride:
case Intrinsic::nvvm_wmma_load_c_f32_col_shared:
case Intrinsic::nvvm_wmma_load_c_f32_row_shared:
case Intrinsic::nvvm_wmma_load_c_f32_col_shared_stride:
case Intrinsic::nvvm_wmma_load_c_f32_row_shared_stride:
case Intrinsic::nvvm_wmma_load_c_f32_col_global:
case Intrinsic::nvvm_wmma_load_c_f32_row_global:
case Intrinsic::nvvm_wmma_load_c_f32_col_global_stride:
case Intrinsic::nvvm_wmma_load_c_f32_row_global_stride: {
case Intrinsic::nvvm_wmma_load_c_f32_row_stride: {
Info.opc = ISD::INTRINSIC_W_CHAIN;
Info.memVT = MVT::v8f32;
Info.ptrVal = I.getArgOperand(0);
@ -3401,15 +3369,7 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
case Intrinsic::nvvm_wmma_store_d_f16_col:
case Intrinsic::nvvm_wmma_store_d_f16_row:
case Intrinsic::nvvm_wmma_store_d_f16_col_stride:
case Intrinsic::nvvm_wmma_store_d_f16_row_stride:
case Intrinsic::nvvm_wmma_store_d_f16_col_shared:
case Intrinsic::nvvm_wmma_store_d_f16_row_shared:
case Intrinsic::nvvm_wmma_store_d_f16_col_shared_stride:
case Intrinsic::nvvm_wmma_store_d_f16_row_shared_stride:
case Intrinsic::nvvm_wmma_store_d_f16_col_global:
case Intrinsic::nvvm_wmma_store_d_f16_row_global:
case Intrinsic::nvvm_wmma_store_d_f16_col_global_stride:
case Intrinsic::nvvm_wmma_store_d_f16_row_global_stride: {
case Intrinsic::nvvm_wmma_store_d_f16_row_stride: {
Info.opc = ISD::INTRINSIC_VOID;
Info.memVT = MVT::v4f16;
Info.ptrVal = I.getArgOperand(0);
@ -3422,15 +3382,7 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
case Intrinsic::nvvm_wmma_store_d_f32_col:
case Intrinsic::nvvm_wmma_store_d_f32_row:
case Intrinsic::nvvm_wmma_store_d_f32_col_stride:
case Intrinsic::nvvm_wmma_store_d_f32_row_stride:
case Intrinsic::nvvm_wmma_store_d_f32_col_shared:
case Intrinsic::nvvm_wmma_store_d_f32_row_shared:
case Intrinsic::nvvm_wmma_store_d_f32_col_shared_stride:
case Intrinsic::nvvm_wmma_store_d_f32_row_shared_stride:
case Intrinsic::nvvm_wmma_store_d_f32_col_global:
case Intrinsic::nvvm_wmma_store_d_f32_row_global:
case Intrinsic::nvvm_wmma_store_d_f32_col_global_stride:
case Intrinsic::nvvm_wmma_store_d_f32_row_global_stride: {
case Intrinsic::nvvm_wmma_store_d_f32_row_stride: {
Info.opc = ISD::INTRINSIC_VOID;
Info.memVT = MVT::v8f32;
Info.ptrVal = I.getArgOperand(0);

View File

@ -7379,13 +7379,16 @@ class WMMA_LOAD_ALSTOS<string Abc, string Layout, string Space,
string Type, NVPTXRegClass regclass,
DAGOperand SrcOp, bit WithStride>
: EmptyNVPTXInst, Requires<[hasPTX60, hasSM70]> {
// Intrinsic that matches this instruction.
Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma_load_"
# Abc
# "_" # Type
# "_" # Layout
# !subst(".","_",Space)
# !if(WithStride,"_stride", ""));
// Pattern (created by WMMA_LOAD_INTR_HELPER below) that matches the intrinsic
// for this function.
PatFrag IntrMatcher = !cast<PatFrag>("INT_WMMA_LOAD_"
# !subst("a", "A",
!subst("b", "B",
!subst("c", "C_" # Type, Abc)))
# "_" # Layout
# !subst(".", "_", Space)
# !if(WithStride,"_stride", "")
# "_Intr");
dag OutsR03 = (outs regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3);
dag OutsR47 = (outs regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7);
dag Outs = !if(!eq(Abc#Type,"cf16"), OutsR03, !con(OutsR03, OutsR47));
@ -7410,7 +7413,7 @@ class WMMA_LOAD_ALSTOS<string Abc, string Layout, string Space,
!subst(imem, ADDRvar,
!subst(MEMri64, ADDRri64,
!subst(MEMri, ADDRri,
!subst(ins, Intr, tmp)))));
!subst(ins, IntrMatcher, tmp)))));
// Finally, concatenate both parts together. !con() requires both dags to have
// the same operator, so we wrap PatArgs in a (set ...) dag.
let Pattern = [!con(PatOuts, (set PatArgs))];
@ -7425,20 +7428,52 @@ class WMMA_LOAD_ALSTOS<string Abc, string Layout, string Space,
#";";
}
multiclass WMMA_LOAD_ALSTO<string Abc, string Layout, string Space,
string Type, NVPTXRegClass regclass,
DAGOperand SrcOp> {
def _stride: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, SrcOp, 1>;
def NAME: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, SrcOp, 0>;
// PatFrag helper that matches a WMMA load intrinsic only when its pointer
// operand is in the address space selected by Space. This is what lets a
// single overloaded intrinsic be lowered to the space-specific instruction
// variant.
class WMMA_LOAD_INTR_HELPER<string Abc, string Layout, string Space,
string Type, bit WithStride>
: PatFrag <(ops),(ops)> {
// Intrinsic that matches this instruction.
Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma_load_"
# Abc
# "_" # Type
# "_" # Layout
# !if(WithStride,"_stride", ""));
// Candidate address-space checks; exactly one is selected as PredicateCode
// below based on Space.
code match_generic = [{
return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GENERIC);
}];
code match_shared = [{
return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_SHARED);
}];
code match_global = [{
return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GLOBAL);
}];
// Stride variants carry an extra $ldm (leading-dimension) operand.
let Operands = !if(WithStride, (ops node:$src, node:$ldm), (ops node:$src));
let Fragment = !foreach(tmp, Operands, !subst(ops, Intr, tmp));
let PredicateCode = !if(!eq(Space, ".shared"), match_shared,
!if(!eq(Space, ".global"), match_global, match_generic));
}
// Instantiate one load instruction per addressing mode: symbol (_avar),
// 32/64-bit register (_areg/_areg64), and register+immediate (_ari/_ari64).
multiclass WMMA_LOAD_ALSTS<string Abc, string Layout, string Space,
string Type, NVPTXRegClass regclass, bit WithStride> {
def _avar: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, imem, WithStride>;
def _areg: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, Int32Regs, WithStride>;
def _areg64: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, Int64Regs, WithStride>;
def _ari: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, MEMri, WithStride>;
def _ari64: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, MEMri64, WithStride>;
}
multiclass WMMA_LOAD_ALSTSh<string Abc, string Layout, string Space,
string Type, NVPTXRegClass regclass, bit WithStride> {
// Define a PatFrag that matches the appropriate intrinsic that loads from the
// given address space, then instantiate all addressing-mode variants of the
// instruction that use it.
def _Intr : WMMA_LOAD_INTR_HELPER<Abc, Layout, Space, Type, WithStride>;
defm NAME: WMMA_LOAD_ALSTS<Abc, Layout, Space, Type, regclass, WithStride>;
}
multiclass WMMA_LOAD_ALST<string Abc, string Layout, string Space,
string Type, NVPTXRegClass regclass> {
defm _avar: WMMA_LOAD_ALSTO<Abc, Layout, Space, Type, regclass, imem>;
defm _areg: WMMA_LOAD_ALSTO<Abc, Layout, Space, Type, regclass, Int32Regs>;
defm _areg64: WMMA_LOAD_ALSTO<Abc, Layout, Space, Type, regclass, Int64Regs>;
defm _ari: WMMA_LOAD_ALSTO<Abc, Layout, Space, Type, regclass, MEMri>;
defm _ari64: WMMA_LOAD_ALSTO<Abc, Layout, Space, Type, regclass, MEMri64>;
string Type, NVPTXRegClass regclass> {
defm _stride: WMMA_LOAD_ALSTSh<Abc, Layout, Space, Type, regclass, 1>;
defm NAME: WMMA_LOAD_ALSTSh<Abc, Layout, Space, Type, regclass, 0>;
}
multiclass WMMA_LOAD_ALT<string Abc, string Layout,
@ -7461,15 +7496,16 @@ defm INT_WMMA_LOAD_C_f32: WMMA_LOAD_AT<"c", "f32", Float32Regs>;
//
// wmma.store.d.sync.[row|col].m16n16k16[|.global|.shared].[f16|f32]
//
class WMMA_STORE_D_LSTOS<string Layout, string Space,
class WMMA_STORE_D_LSTSO<string Layout, string Space,
string Type, NVPTXRegClass regclass,
DAGOperand DstOp, bit WithStride>
bit WithStride, DAGOperand DstOp>
: EmptyNVPTXInst, Requires<[hasPTX60, hasSM70]> {
Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma_store_d_"
# Type
# "_" # Layout
# !subst(".","_",Space)
# !if(WithStride,"_stride", ""));
PatFrag IntrMatcher = !cast<PatFrag>("INT_WMMA_STORE_D"
# "_" # Type
# "_" # Layout
# !subst(".", "_", Space)
# !if(WithStride,"_stride", "")
# "_Intr");
dag InsR03 = (ins DstOp:$src, regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3);
dag InsR47 = (ins regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7);
@ -7483,7 +7519,7 @@ class WMMA_STORE_D_LSTOS<string Layout, string Space,
!subst(imem, ADDRvar,
!subst(MEMri64, ADDRri64,
!subst(MEMri, ADDRri,
!subst(ins, Intr, tmp)))));
!subst(ins, IntrMatcher, tmp)))));
let Pattern = [PatArgs];
let OutOperandList = (outs);
let InOperandList = Ins;
@ -7501,20 +7537,56 @@ class WMMA_STORE_D_LSTOS<string Layout, string Space,
}
multiclass WMMA_STORE_D_LSTO<string Layout, string Space,
string Type, NVPTXRegClass regclass,
DAGOperand DstOp> {
def _stride: WMMA_STORE_D_LSTOS<Layout, Space, Type, regclass, DstOp, 1>;
def NAME: WMMA_STORE_D_LSTOS<Layout, Space, Type, regclass, DstOp, 0>;
// PatFrag helper that matches a WMMA store intrinsic only when its pointer
// operand is in the address space selected by Space, so the overloaded
// intrinsic can be lowered to the space-specific instruction variant.
class WMMA_STORE_INTR_HELPER<string Layout, string Space,
string Type, bit WithStride>
: PatFrag <(ops),(ops)> {
// Intrinsic that matches this instruction.
Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma_store_d"
# "_" # Type
# "_" # Layout
# !if(WithStride, "_stride", ""));
// Candidate address-space checks; exactly one is selected as PredicateCode
// below based on Space.
code match_generic = [{
return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GENERIC);
}];
code match_shared = [{
return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_SHARED);
}];
code match_global = [{
return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GLOBAL);
}];
// f16 stores take 4 data registers, f32 stores take 8.
dag Args = !if(!eq(Type,"f16"),
(ops node:$dst, node:$r0, node:$r1, node:$r2, node:$r3),
(ops node:$dst, node:$r0, node:$r1, node:$r2, node:$r3,
node:$r4, node:$r5, node:$r6, node:$r7));
// Stride variants carry an extra $ldm (leading-dimension) operand.
dag StrideArg = !if(WithStride, (ops node:$ldm), (ops));
let Operands = !con(Args, StrideArg);
let Fragment = !foreach(tmp, Operands, !subst(ops, Intr, tmp));
let PredicateCode = !if(!eq(Space, ".shared"), match_shared,
!if(!eq(Space, ".global"), match_global, match_generic));
}
// Instantiate one store instruction per addressing mode: symbol (_avar),
// 32/64-bit register (_areg/_areg64), and register+immediate (_ari/_ari64).
multiclass WMMA_STORE_D_LSTS<string Layout, string Space,
string Type, NVPTXRegClass regclass, bit WithStride> {
def _avar: WMMA_STORE_D_LSTSO<Layout, Space, Type, regclass, WithStride, imem>;
def _areg: WMMA_STORE_D_LSTSO<Layout, Space, Type, regclass, WithStride, Int32Regs>;
def _areg64: WMMA_STORE_D_LSTSO<Layout, Space, Type, regclass, WithStride, Int64Regs>;
def _ari: WMMA_STORE_D_LSTSO<Layout, Space, Type, regclass, WithStride, MEMri>;
def _ari64: WMMA_STORE_D_LSTSO<Layout, Space, Type, regclass, WithStride, MEMri64>;
}
multiclass WMMA_STORE_D_LSTSh<string Layout, string Space,
string Type, NVPTXRegClass regclass, bit WithStride> {
// Define a PatFrag that matches the appropriate intrinsic that stores to the
// given address space, then instantiate all addressing-mode variants of the
// instruction that use it.
def _Intr: WMMA_STORE_INTR_HELPER<Layout, Space, Type, WithStride>;
defm NAME: WMMA_STORE_D_LSTS<Layout, Space, Type, regclass, WithStride>;
}
multiclass WMMA_STORE_D_LST<string Layout, string Space,
string Type, NVPTXRegClass regclass> {
defm _avar: WMMA_STORE_D_LSTO<Layout, Space, Type, regclass, imem>;
defm _areg: WMMA_STORE_D_LSTO<Layout, Space, Type, regclass, Int32Regs>;
defm _areg64: WMMA_STORE_D_LSTO<Layout, Space, Type, regclass, Int64Regs>;
defm _ari: WMMA_STORE_D_LSTO<Layout, Space, Type, regclass, MEMri>;
defm _ari64: WMMA_STORE_D_LSTO<Layout, Space, Type, regclass, MEMri64>;
string Type, NVPTXRegClass regclass > {
defm _stride: WMMA_STORE_D_LSTSh<Layout, Space, Type, regclass, 1>;
defm NAME: WMMA_STORE_D_LSTSh<Layout, Space, Type, regclass, 0>;
}
multiclass WMMA_STORE_D_LT<string Layout,

View File

@ -15,6 +15,22 @@ def make_wmma_slice_ty(abcd, itype):
def make_wmma_ld_ret_ty(abc, itype):
return "{%s}" % ", ".join(make_wmma_slice_ty(abc, itype))
# returns address space
def get_aspace(space):
    """Map a PTX address-space suffix (e.g. ".shared") to its numeric
    LLVM address-space ID. Raises KeyError for an unknown suffix."""
    return {
        ".global": 1,
        ".shared": 3,
        ".const": 4,
        ".local": 5,
        ".param": 101,
        "": 0,
        ".generic": 0,
    }[space]

def get_pspace(space):
    """Return the pointer-type suffix (e.g. "p3i8") for the given address
    space, as used in the overloaded intrinsic names emitted by the tests."""
    return "p%di8" % get_aspace(space)
# Convenient test patterns.
check_f16_8 = "{{%s}}" % ", *".join(["%hh[0-9]+"] * 8)
check_f16_4 = "{{%s}}" % ", *".join(["%hh[0-9]+"] * 4)
@ -22,28 +38,28 @@ check_f32_8 = "{{%s}}" % ", *".join(["%f[0-9]+"] * 8)
def gen_wmma_load_tests():
load_template = """
declare ${ret_ty} @llvm.nvvm.wmma.load.$intrinsic_suffix(i8* %src ${extra_args});
declare ${ret_ty} @llvm.nvvm.wmma.load.${intrinsic_suffix}(i8 ${as}* %src ${extra_args});
; CHECK-LABEL: .func {{.*}}test_wmma_load_${function_suffix}(
define ${ret_ty} @test_wmma_load_${function_suffix}(i8* %src ${extra_args}) {
define ${ret_ty} @test_wmma_load_${function_suffix}(i8 ${as}* %src ${extra_args}) {
; CHECK wmma.load.${intrinsic_suffix}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}]${stride_pattern}
%v0 = call ${ret_ty} @llvm.nvvm.wmma.load.${intrinsic_suffix}(i8* %src ${extra_args});
%v0 = call ${ret_ty} @llvm.nvvm.wmma.load.${intrinsic_suffix}(i8 ${as}* %src ${extra_args});
ret ${ret_ty} %v0;
}
; CHECK-LABEL: .func{{.*}}test_wmma_load_${function_suffix}_o(
define ${ret_ty} @test_wmma_load_${function_suffix}_o(i8* %src ${extra_args}) {
define ${ret_ty} @test_wmma_load_${function_suffix}_o(i8 ${as}* %src ${extra_args}) {
; CHECK wmma.load.${intrinsic_suffix}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}+128]${stride_pattern}
%src1 = getelementptr i8, i8* %src, i32 128;
%v0 = call ${ret_ty} @llvm.nvvm.wmma.load.${intrinsic_suffix}(i8* %src1 ${extra_args});
%src1 = getelementptr i8, i8 ${as}* %src, i32 128;
%v0 = call ${ret_ty} @llvm.nvvm.wmma.load.${intrinsic_suffix}(i8 ${as}* %src1 ${extra_args});
ret ${ret_ty} %v0;
}
"""
suffix_template = "${abc}.sync.${layout}.m16n16k16${space}${stride}.${itype}"
suffix_template = "${abc}.sync.${layout}.m16n16k16${stride}.${itype}.${pspace}"
instruction_template = "${abc}.sync.${layout}.m16n16k16${space}.${itype}"
for abc, layout, space, stride, itype in product(
@ -58,7 +74,9 @@ define ${ret_ty} @test_wmma_load_${function_suffix}_o(i8* %src ${extra_args}) {
"layout" : layout,
"space" : space,
"stride" : stride,
"itype" : itype
"itype" : itype,
"pspace" : get_pspace(space),
"as" : "addrspace(%d)" % get_aspace(space)
}
if itype == "f32" and abc != "c":
@ -89,28 +107,28 @@ def make_wmma_slice_args(itype, abcd, prefix="v"):
def gen_wmma_store_tests():
store_template = """
declare void @llvm.nvvm.wmma.store.$intrinsic_suffix(i8* %src, ${args}${extra_args});
declare void @llvm.nvvm.wmma.store.${intrinsic_suffix}(i8 ${as}* %src, ${args}${extra_args});
; CHECK-LABEL: .func {{.*}}test_wmma_store_${function_suffix}(
define void @test_wmma_store_${function_suffix}(i8* %src, ${args}${extra_args}) {
define void @test_wmma_store_${function_suffix}(i8 ${as}* %src, ${args}${extra_args}) {
; CHECK wmma.store.${intrinsic_suffix} {{.*}}[%rd{{[0-9+]}}
; CHECK: {${check_args}}
; CHECK: ${stride_pattern}
call void @llvm.nvvm.wmma.store.${intrinsic_suffix}(i8* %src, ${args} ${extra_args});
call void @llvm.nvvm.wmma.store.${intrinsic_suffix}(i8 ${as}* %src, ${args} ${extra_args});
ret void
}
; CHECK-LABEL: .func{{.*}}test_wmma_store_${function_suffix}_o(
define void @test_wmma_store_${function_suffix}_o(i8* %src, ${args}${extra_args}) {
define void @test_wmma_store_${function_suffix}_o(i8 ${as}* %src, ${args}${extra_args}) {
; CHECK wmma.store.${intrinsic_suffix} {{.*}}[%rd{{[0-9+]}}+128]
; CHECK: ${check_args}
; CHECK: ${stride_pattern}
%src1 = getelementptr i8, i8* %src, i32 128;
call void @llvm.nvvm.wmma.store.${intrinsic_suffix}(i8* %src1, ${args}${extra_args});
%src1 = getelementptr i8, i8 ${as}* %src, i32 128;
call void @llvm.nvvm.wmma.store.${intrinsic_suffix}(i8 ${as}* %src1, ${args}${extra_args});
ret void
}
"""
suffix_template = "${abc}.sync.${layout}.m16n16k16${space}${stride}.${itype}"
suffix_template = "${abc}.sync.${layout}.m16n16k16${stride}.${itype}.${pspace}"
instruction_template = "${abc}.sync.${layout}.m16n16k16${space}.${itype}"
for abc, layout, space, stride, itype in product(
@ -125,7 +143,9 @@ define void @test_wmma_store_${function_suffix}_o(i8* %src, ${args}${extra_args}
"layout" : layout,
"space" : space,
"stride" : stride,
"itype" : itype
"itype" : itype,
"pspace" : get_pspace(space),
"as" : "addrspace(%d)" % get_aspace(space)
}
test_params = params