Add linalg.mmt4d named op

This op performs a matrix-matrix-transpose multiplication of 4D inputs as follows:

```
C[m1, n1, m0, n0] = sum_{k1, k0}(A[m1, k1, m0, k0] * B[n1, k1, n0, k0])
```
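
For reference, the same contraction can be written in a few lines of NumPy (a minimal sketch; the tile counts and tile sizes below are illustrative, not part of this commit):

```python
import numpy as np

# Illustrative tile counts (M1, N1, K1) and inner tile sizes (M0, N0, K0).
M1, N1, K1, M0, N0, K0 = 3, 4, 5, 8, 8, 2

A = np.random.rand(M1, K1, M0, K0)
B = np.random.rand(N1, K1, N0, K0)

# C[m1, n1, m0, n0] = sum_{k1, k0} A[m1, k1, m0, k0] * B[n1, k1, n0, k0]
C = np.einsum("ikac,jkbc->ijab", A, B)

# Tile-level view: each (m1, n1) output tile accumulates M0xK0 LHS tiles
# times transposed N0xK0 RHS tiles -- the transpose is the 't' in 'mmt'.
C_ref = np.zeros((M1, N1, M0, N0))
for m1 in range(M1):
  for n1 in range(N1):
    for k1 in range(K1):
      C_ref[m1, n1] += A[m1, k1] @ B[n1, k1].T
assert np.allclose(C, C_ref)
```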

Reviewed By: Benoit

Differential Revision: https://reviews.llvm.org/D105244
Author: Ahmed Taei
Date: 2021-06-30 16:03:19 -07:00
Commit: 0516f49c08 (parent e86fe368db)
2 changed files with 93 additions and 0 deletions

mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml

```yaml
@@ -62,6 +62,79 @@ structured_op: !LinalgStructuredOpConfig
                - !ScalarExpression
                  scalar_arg: B
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
  name: mmt4d
  cpp_class_name: Mmt4DOp
  doc: |-
    Performs a matrix-matrix-transpose multiplication of two 4D inputs.

    Differences from linalg.matmul:
    * The right hand side is transposed, whence the 't' in 'mmt'.
    * The input and output tensors have a 4D shape instead of a 2D shape. They
      are interpreted as 2D matrices with one level of 2D tile subdivision,
      whence the 2+2=4 dimensions. The inner tile dimensions are identified with
      '0' suffixes below, for instance the LHS matrix shape (M, K, M0, K0) reads
      as: MxK tiles, each of shape M0xK0.
  implements:
  - LinalgContractionOpInterface
structured_op: !LinalgStructuredOpConfig
  args:
  - !LinalgOperandDefConfig
    name: lhs
    usage: InputOperand
    type_var: LhsType
    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1, s2, s3)>
  - !LinalgOperandDefConfig
    name: rhs
    usage: InputOperand
    type_var: RhsType
    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s4, s1, s5, s3)>
  - !LinalgOperandDefConfig
    name: accum
    usage: OutputOperand
    type_var: AccumType
    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s4, s2, s5)>
  indexing_maps: !LinalgIndexingMapsConfig
    static_indexing_maps:
    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5] -> (d0, d4, d1, d5)>
    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5] -> (d2, d4, d3, d5)>
    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5] -> (d0, d2, d1, d3)>
  iterator_types:
  - parallel
  - parallel
  - parallel
  - parallel
  - reduction
  - reduction
  assignments:
  - !ScalarAssign
    arg: accum
    value: !ScalarExpression
      scalar_apply:
        fn_name: add
        operands:
        - !ScalarExpression
          scalar_arg: accum
        - !ScalarExpression
          scalar_apply:
            fn_name: mul
            operands:
            - !ScalarExpression
              symbolic_cast:
                type_var: AccumType
                operands:
                - !ScalarExpression
                  scalar_arg: lhs
            - !ScalarExpression
              symbolic_cast:
                type_var: AccumType
                operands:
                - !ScalarExpression
                  scalar_arg: rhs
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
  name: batch_matmul
  cpp_class_name: BatchMatmulOp
```
mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py

```python
@@ -21,6 +21,26 @@ def matmul(
  C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])


@linalg_structured_op
def mmt4d(lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0),
          rhs=TensorDef(TV.RhsType, S.N, S.K, S.N0, S.K0),
          accum=TensorDef(TV.AccumType, S.M, S.N, S.M0, S.N0,
                          output=True)):
  """Performs a matrix-matrix-transpose multiplication of two 4D inputs.

  Differences from linalg.matmul:
  * The right hand side is transposed, whence the 't' in 'mmt'.
  * The input and output tensors have a 4D shape instead of a 2D shape. They
    are interpreted as 2D matrices with one level of 2D tile subdivision,
    whence the 2+2=4 dimensions. The inner tile dimensions are identified with
    '0' suffixes below, for instance the LHS matrix shape (M, K, M0, K0) reads
    as: MxK tiles, each of shape M0xK0.
  """
  domain(D.m, D.m0, D.n, D.n0, D.k, D.k0)
  implements(ContractionOpInterface)
  accum[D.m, D.n, D.m0, D.n0] += cast(
      TV.AccumType, lhs[D.m, D.k, D.m0, D.k0]) * cast(
          TV.AccumType, rhs[D.n, D.k, D.n0, D.k0])


@linalg_structured_op
def batch_matmul(
    A=TensorDef(T1, Batch, S.M, S.K),
```
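
Read next to `domain(D.m, D.m0, D.n, D.n0, D.k, D.k0)` and the iterator types above, the op amounts to four parallel loops around two reduction loops. A plain-Python loop-nest sketch of that iteration domain (illustrative shapes; this is not the code Linalg generates):

```python
import numpy as np

# Shapes follow lhs=(M, K, M0, K0), rhs=(N, K, N0, K0); values illustrative.
M, N, K, M0, N0, K0 = 2, 3, 4, 8, 8, 2
lhs = np.random.rand(M, K, M0, K0)
rhs = np.random.rand(N, K, N0, K0)
accum = np.zeros((M, N, M0, N0))

for m in range(M):          # parallel
  for m0 in range(M0):      # parallel
    for n in range(N):      # parallel
      for n0 in range(N0):  # parallel
        for k in range(K):      # reduction
          for k0 in range(K0):  # reduction
            accum[m, n, m0, n0] += lhs[m, k, m0, k0] * rhs[n, k, n0, k0]

assert np.allclose(accum, np.einsum("ikac,jkbc->ijab", lhs, rhs))
```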