[AMDGPU] Match signed dot4/8 pattern.

Summary: This patch matches signed dot4 and dot8 pattern.

Author: FarhanaAleen

Reviewed By: msearles

Differential Revision: https://reviews.llvm.org/D52520

llvm-svn: 343798
This commit is contained in:
Farhana Aleen 2018-10-04 16:57:37 +00:00
parent 8920428376
commit 4bc597bff5
3 changed files with 5211 additions and 783 deletions

View File

@ -165,34 +165,40 @@ def V_FMA_MIXHI_F16 : VOP3_VOP3PInst<"v_fma_mixhi_f16", VOP3_Profile<VOP_F16_F16
defm : MadFmaMixPats<fma, V_FMA_MIX_F32, V_FMA_MIXLO_F16, V_FMA_MIXHI_F16>;
}
class Srl<int N> : PatFrag<(ops node:$src),
(srl node:$src, (i32 N))>;
// Defines patterns that extract signed 4bit from each Idx[0].
foreach Idx = [[0,28],[4,24],[8,20],[12,16],[16,12],[20,8],[24,4]] in
def ExtractSigned4bit_#Idx[0] : PatFrag<(ops node:$src),
(sra (shl node:$src, (i32 Idx[1])), (i32 28))>;
foreach Bits = 1-7 in
def srl#!shl(Bits, 2) : Srl<!shl(Bits, 2)>;
class Extract_U<int FromBitIndex, int BitMask> : PatFrag<
// Defines code pattern that extracts U(unsigned/signed) 4/8bit from FromBitIndex.
class Extract<int FromBitIndex, int BitMask, bit U>: PatFrag<
(ops node:$src),
!if (!or (!and (!eq (BitMask, 255), !eq (FromBitIndex, 24)),
!and (!eq (BitMask, 15), !eq (FromBitIndex, 28))), // last element
(!cast<Srl>("srl"#FromBitIndex) node:$src),
!if (!or (!and (!eq (BitMask, 255), !eq (FromBitIndex, 24)), !eq (FromBitIndex, 28)), // last element
!if (U, (srl node:$src, (i32 FromBitIndex)), (sra node:$src, (i32 FromBitIndex))),
!if (!eq (FromBitIndex, 0), // first element
(and node:$src, (i32 BitMask)),
(and (!cast<Srl>("srl"#FromBitIndex) node:$src), (i32 BitMask))))>;
!if (U, (and node:$src, (i32 BitMask)),
!if (!eq (BitMask, 15), (!cast<PatFrag>("ExtractSigned4bit_"#FromBitIndex) node:$src),
(sext_inreg node:$src, i8))),
!if (U, (and (srl node:$src, (i32 FromBitIndex)), (i32 BitMask)),
!if (!eq (BitMask, 15), (!cast<PatFrag>("ExtractSigned4bit_"#FromBitIndex) node:$src),
(sext_inreg (srl node:$src, (i32 FromBitIndex)), i8)))))>;
foreach Index = 0-3 in {
// Defines patterns that extract each Index'ed 8bit from an unsigned
// 32bit scalar value;
def U#Index#"_8bit" : Extract_U<!shl(Index, 3),
255>;
// Defines multiplication patterns where the multiplication is happening on each
// Index'ed 8bit of a 32bit scalar value.
def MulU_Elt#Index : PatFrag<
(ops node:$src0, node:$src1),
(AMDGPUmul_u24_oneuse (!cast<Extract_U>("U"#Index#"_8bit") node:$src0),
(!cast<Extract_U>("U"#Index#"_8bit") node:$src1))>;
}
foreach Type = ["I", "U"] in
foreach Index = 0-3 in {
// Defines patterns that extract each Index'ed 8bit from an unsigned
// 32bit scalar value;
def #Type#Index#"_8bit" : Extract<!shl(Index, 3), 255, !if (!eq (Type, "U"), 1, 0)>;
// Defines multiplication patterns where the multiplication is happening on each
// Index'ed 8bit of a 32bit scalar value.
def Mul#Type#_Elt#Index : PatFrag<
(ops node:$src0, node:$src1),
(!cast<HasOneUseBinOp>(!if (!eq (Type, "I"), AMDGPUmul_i24_oneuse, AMDGPUmul_u24_oneuse))
(!cast<Extract>(#Type#Index#"_8bit") node:$src0),
(!cast<Extract>(#Type#Index#"_8bit") node:$src1))>;
}
// Different variants of dot8 patterns cause a huge increase in the compile time.
// Define non-associative/commutative add/mul to prevent permutation in the dot8
@ -203,19 +209,23 @@ def NonACAdd_oneuse : HasOneUseBinOp<NonACAdd>;
def NonACAMDGPUmul_u24 : SDNode<"AMDGPUISD::MUL_U24" , SDTIntBinOp>;
def NonACAMDGPUmul_u24_oneuse : HasOneUseBinOp<NonACAMDGPUmul_u24>;
foreach Index = 0-7 in {
// Defines patterns that extract each Index'ed 4bit from an unsigned
// 32bit scalar value;
def U#Index#"_4bit" : Extract_U<!shl(Index, 2),
15>;
def NonACAMDGPUmul_i24 : SDNode<"AMDGPUISD::MUL_I24" , SDTIntBinOp>;
def NonACAMDGPUmul_i24_oneuse : HasOneUseBinOp<NonACAMDGPUmul_i24>;
// Defines multiplication patterns where the multiplication is happening on each
// Index'ed 8bit of a 32bit scalar value.
def MulU#Index#"_4bit" : PatFrag<
(ops node:$src0, node:$src1),
(NonACAMDGPUmul_u24_oneuse (!cast<Extract_U>("U"#Index#"_4bit") node:$src0),
(!cast<Extract_U>("U"#Index#"_4bit") node:$src1))>;
}
foreach Type = ["I", "U"] in
foreach Index = 0-7 in {
// Defines patterns that extract each Index'ed 4bit from an unsigned
// 32bit scalar value;
def #Type#Index#"_4bit" : Extract<!shl(Index, 2), 15, !if (!eq (Type, "U"), 1, 0)>;
// Defines multiplication patterns where the multiplication is happening on each
// Index'ed 8bit of a 32bit scalar value.
def Mul#Type#Index#"_4bit" : PatFrag<
(ops node:$src0, node:$src1),
(!cast<HasOneUseBinOp>(!if (!eq (Type, "I"), NonACAMDGPUmul_i24_oneuse, NonACAMDGPUmul_u24_oneuse))
(!cast<Extract>(#Type#Index#"_4bit") node:$src0),
(!cast<Extract>(#Type#Index#"_4bit") node:$src1))>;
}
class UDot2Pat<Instruction Inst> : GCNPat <
(add (add_oneuse (AMDGPUmul_u24_oneuse (srl i32:$src0, (i32 16)),
@ -264,17 +274,18 @@ defm : DotPats<int_amdgcn_udot8, V_DOT8_U32_U4>;
def : UDot2Pat<V_DOT2_U32_U16>;
def : SDot2Pat<V_DOT2_I32_I16>;
def : GCNPat <
!cast<dag>(!foldl((i32 i32:$src2), [0, 1, 2, 3], lhs, y,
(add_oneuse lhs, (!cast<PatFrag>("MulU_Elt"#y) i32:$src0, i32:$src1)))),
(V_DOT4_U32_U8 (i32 8), $src0, (i32 8), $src1, (i32 8), $src2, (i1 0))
>;
foreach Type = ["U", "I"] in
def : GCNPat <
!cast<dag>(!foldl((i32 i32:$src2), [0, 1, 2, 3], lhs, y,
(add_oneuse lhs, (!cast<PatFrag>("Mul"#Type#"_Elt"#y) i32:$src0, i32:$src1)))),
(!cast<VOP3PInst>("V_DOT4_"#Type#"32_"#Type#8) (i32 8), $src0, (i32 8), $src1, (i32 8), $src2, (i1 0))>;
def : GCNPat <
!cast<dag>(!foldl((add_oneuse i32:$src2, (MulU0_4bit i32:$src0, i32:$src1)), [1, 2, 3, 4, 5, 6, 7], lhs, y,
(NonACAdd_oneuse lhs, (!cast<PatFrag>("MulU"#y#"_4bit") i32:$src0, i32:$src1)))),
(V_DOT8_U32_U4 (i32 8), $src0, (i32 8), $src1, (i32 8), $src2, (i1 0))
>;
foreach Type = ["U", "I"] in
def : GCNPat <
!cast<dag>(!foldl((add_oneuse i32:$src2, (!cast<PatFrag>("Mul"#Type#"0_4bit") i32:$src0, i32:$src1)),
[1, 2, 3, 4, 5, 6, 7], lhs, y,
(NonACAdd_oneuse lhs, (!cast<PatFrag>("Mul"#Type#y#"_4bit") i32:$src0, i32:$src1)))),
(!cast<VOP3PInst>("V_DOT8_"#Type#"32_"#Type#4) (i32 8), $src0, (i32 8), $src1, (i32 8), $src2, (i1 0))>;
} // End SubtargetPredicate = HasDLInsts

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff