Add linalg.mmt4d named op

This op performs matrix-matrix-transpose multiplication of 4-D inputs as follows:

```
C[m1, n1, m0, n0] = sum_{k1, k0}(A[m1, k1, m0, k0] * B[n1, k1, n0, k0])
```

Reviewed By: Benoit

Differential Revision: https://reviews.llvm.org/D105244
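As a quick sanity check, the contraction above can be reproduced with NumPy (not part of this change; the dimension sizes below are arbitrary example values):

```python
import numpy as np

# Tile counts and tile sizes (arbitrary example values).
M1, N1, K1 = 2, 3, 4   # outer (tile-count) dimensions
M0, N0, K0 = 5, 6, 7   # inner (tile-size) dimensions

rng = np.random.default_rng(0)
A = rng.standard_normal((M1, K1, M0, K0))  # LHS: M1xK1 tiles, each M0xK0
B = rng.standard_normal((N1, K1, N0, K0))  # RHS: N1xK1 tiles, each N0xK0 (transposed layout)

# C[m1, n1, m0, n0] = sum_{k1, k0} A[m1, k1, m0, k0] * B[n1, k1, n0, k0]
C = np.einsum('mkac,nkbc->mnab', A, B)
assert C.shape == (M1, N1, M0, N0)
```

Note that both inputs carry their K dimensions in the same positions, which is exactly the "transposed RHS" layout the op name encodes.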
parent e86fe368db
commit 0516f49c08
@@ -62,6 +62,79 @@ structured_op: !LinalgStructuredOpConfig
                 - !ScalarExpression
                   scalar_arg: B
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: mmt4d
+  cpp_class_name: Mmt4DOp
+  doc: |-
+    Performs a matrix-matrix-transpose multiplication of two 4D inputs.
+
+    Differences from linalg.matmul:
+    * The right hand side is transposed, whence the 't' in 'mmt'.
+    * The input and output tensors have a 4D shape instead of a 2D shape. They
+      are interpreted as 2D matrices with one level of 2D tile subdivision,
+      whence the 2+2=4 dimensions. The inner tile dimensions are identified with
+      '0' suffixes below, for instance the LHS matrix shape (M, K, M0, K0) reads
+      as: MxK tiles, each of shape M0xK0.
+  implements:
+  - LinalgContractionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: lhs
+    usage: InputOperand
+    type_var: LhsType
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1, s2, s3)>
+  - !LinalgOperandDefConfig
+    name: rhs
+    usage: InputOperand
+    type_var: RhsType
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s4, s1, s5, s3)>
+  - !LinalgOperandDefConfig
+    name: accum
+    usage: OutputOperand
+    type_var: AccumType
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s4, s2, s5)>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5] -> (d0, d4, d1,
+      d5)>
+    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5] -> (d2, d4, d3,
+      d5)>
+    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5] -> (d0, d2, d1,
+      d3)>
+  iterator_types:
+  - parallel
+  - parallel
+  - parallel
+  - parallel
+  - reduction
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: accum
+    value: !ScalarExpression
+      scalar_apply:
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: accum
+        - !ScalarExpression
+          scalar_apply:
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              symbolic_cast:
+                type_var: AccumType
+                operands:
+                - !ScalarExpression
+                  scalar_arg: lhs
+            - !ScalarExpression
+              symbolic_cast:
+                type_var: AccumType
+                operands:
+                - !ScalarExpression
+                  scalar_arg: rhs
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: batch_matmul
   cpp_class_name: BatchMatmulOp
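In the config above, the six iteration dimensions (d0..d5) correspond to (M, M0, N, N0, K, K0): the first four are parallel, the last two are reductions, and each static indexing map selects an operand's subscripts out of that space. A small NumPy sketch (an illustration of the semantics, not the generated code; sizes are arbitrary) driving a generic loop nest from those maps:

```python
import numpy as np

# Loop bounds for the iteration dimensions (d0..d5) = (M, M0, N, N0, K, K0).
# Example sizes only; in the real op they come from the operand shapes.
M, M0, N, N0, K, K0 = 2, 3, 4, 5, 2, 3

# The static indexing maps above, written as index-selection functions:
#   lhs   -> (d0, d4, d1, d5)   i.e. lhs[m, k, m0, k0]
#   rhs   -> (d2, d4, d3, d5)   i.e. rhs[n, k, n0, k0]
#   accum -> (d0, d2, d1, d3)   i.e. accum[m, n, m0, n0]
lhs_map = lambda d: (d[0], d[4], d[1], d[5])
rhs_map = lambda d: (d[2], d[4], d[3], d[5])
accum_map = lambda d: (d[0], d[2], d[1], d[3])

rng = np.random.default_rng(0)
lhs = rng.standard_normal((M, K, M0, K0))
rhs = rng.standard_normal((N, K, N0, K0))
accum = np.zeros((M, N, M0, N0))

# Generic loop nest over the full (d0..d5) space; the reduction dimensions
# d4 and d5 accumulate into the same accum element.
for d in np.ndindex(M, M0, N, N0, K, K0):
    accum[accum_map(d)] += lhs[lhs_map(d)] * rhs[rhs_map(d)]
```

This also shows why the shape_maps work out: applying each map to the domain bounds yields exactly the operand shapes (M, K, M0, K0), (N, K, N0, K0), and (M, N, M0, N0).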
@@ -21,6 +21,26 @@ def matmul(
   C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
 
 
+@linalg_structured_op
+def mmt4d(lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0),
+          rhs=TensorDef(TV.RhsType, S.N, S.K, S.N0, S.K0),
+          accum=TensorDef(TV.AccumType, S.M, S.N, S.M0, S.N0,
+                          output=True)):
+  """Performs a matrix-matrix-transpose multiplication of two 4D inputs.
+
+  Differences from linalg.matmul:
+  * The right hand side is transposed, whence the 't' in 'mmt'.
+  * The input and output tensors have a 4D shape instead of a 2D shape. They
+    are interpreted as 2D matrices with one level of 2D tile subdivision,
+    whence the 2+2=4 dimensions. The inner tile dimensions are identified with
+    '0' suffixes below, for instance the LHS matrix shape (M, K, M0, K0) reads
+    as: MxK tiles, each of shape M0xK0.
+  """
+  domain(D.m, D.m0, D.n, D.n0, D.k, D.k0)
+  implements(ContractionOpInterface)
+  accum[D.m, D.n, D.m0, D.n0] += cast(TV.AccumType, lhs[D.m, D.k, D.m0, D.k0]) * cast(TV.AccumType, rhs[D.n, D.k, D.n0, D.k0])
+
+
 @linalg_structured_op
 def batch_matmul(
     A=TensorDef(T1, Batch, S.M, S.K),
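The 4D form is just a tiled view of an ordinary matrix-matrix-transpose product: pack a 2D LHS into (M, K, M0, K0) tiles and a 2D RHS into (N, K, N0, K0) tiles, and mmt4d reproduces `A @ B.T` after untiling. A NumPy sketch of that equivalence (illustrative only; `tile` is a hypothetical helper, not part of this commit, and all sizes are arbitrary):

```python
import numpy as np

M1, M0, K1, K0, N1, N0 = 2, 3, 4, 5, 2, 3
rng = np.random.default_rng(0)
A2 = rng.standard_normal((M1 * M0, K1 * K0))  # plain 2D LHS
B2 = rng.standard_normal((N1 * N0, K1 * K0))  # plain 2D RHS (to be transposed)

def tile(X, R0, C0):
    """View a 2D matrix as (row_tiles, col_tiles, R0, C0)."""
    R, C = X.shape
    return X.reshape(R // R0, R0, C // C0, C0).transpose(0, 2, 1, 3)

A4 = tile(A2, M0, K0)   # shape (M1, K1, M0, K0)
B4 = tile(B2, N0, K0)   # shape (N1, K1, N0, K0)

# mmt4d's contraction in einsum form:
# C4[m1, n1, m0, n0] = sum_{k1, k0} A4[m1, k1, m0, k0] * B4[n1, k1, n0, k0]
C4 = np.einsum('mkac,nkbc->mnab', A4, B4)

# Untile and compare with the ordinary 2D product A2 @ B2.T.
C2 = C4.transpose(0, 2, 1, 3).reshape(M1 * M0, N1 * N0)
assert np.allclose(C2, A2 @ B2.T)
```

This tiled layout is what makes the op useful for cache-friendly, vectorized matmul lowerings: each innermost M0xN0 tile of the result is a small matmul over contiguous M0xK0 and N0xK0 tiles.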