[mlir][Vector] Add a VectorUnrollInterface and expose UnrollVectorPattern.
The UnrollVectorPattern is can be used in a programmable fashion by: ``` OwningRewritePatternList patterns; patterns.insert<UnrollVectorPattern<AddFOp>>(ArrayRef<int64_t>{2, 2}, ctx); patterns.insert<UnrollVectorPattern<vector::ContractionOp>>( ArrayRef<int64_t>{2, 2, 2}, ctx); ... applyPatternsAndFoldGreedily(getFunction(), patterns); ``` Differential revision: https://reviews.llvm.org/D83064
This commit is contained in:
parent
0607c8df7f
commit
05c65dc0fe
|
@ -444,7 +444,7 @@ def MyInterface : OpInterface<"MyInterface"> {
|
||||||
// Note: `ConcreteOp` corresponds to the derived operation typename.
|
// Note: `ConcreteOp` corresponds to the derived operation typename.
|
||||||
InterfaceMethod<"/*insert doc here*/",
|
InterfaceMethod<"/*insert doc here*/",
|
||||||
"unsigned", "getNumWithDefault", (ins), /*methodBody=*/[{}], [{
|
"unsigned", "getNumWithDefault", (ins), /*methodBody=*/[{}], [{
|
||||||
ConcreteOp op = cast<ConcreteOp>(getOperation());
|
ConcreteOp op = cast<ConcreteOp>(this->getOperation());
|
||||||
return op.getNumInputs() + op.getNumOutputs();
|
return op.getNumInputs() + op.getNumOutputs();
|
||||||
}]>,
|
}]>,
|
||||||
];
|
];
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
#include "mlir/Interfaces/CallInterfaces.h"
|
#include "mlir/Interfaces/CallInterfaces.h"
|
||||||
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
||||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||||
|
#include "mlir/Interfaces/VectorUnrollInterface.h"
|
||||||
#include "mlir/Interfaces/ViewLikeInterface.h"
|
#include "mlir/Interfaces/ViewLikeInterface.h"
|
||||||
|
|
||||||
// Pull in all enum type definitions and utility function declarations.
|
// Pull in all enum type definitions and utility function declarations.
|
||||||
|
|
|
@ -17,6 +17,7 @@ include "mlir/IR/OpAsmInterface.td"
|
||||||
include "mlir/Interfaces/CallInterfaces.td"
|
include "mlir/Interfaces/CallInterfaces.td"
|
||||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||||
|
include "mlir/Interfaces/VectorUnrollInterface.td"
|
||||||
include "mlir/Interfaces/ViewLikeInterface.td"
|
include "mlir/Interfaces/ViewLikeInterface.td"
|
||||||
|
|
||||||
def StandardOps_Dialect : Dialect {
|
def StandardOps_Dialect : Dialect {
|
||||||
|
@ -82,7 +83,9 @@ class UnaryOpSameOperandAndResultType<string mnemonic,
|
||||||
}
|
}
|
||||||
|
|
||||||
class FloatUnaryOp<string mnemonic, list<OpTrait> traits = []> :
|
class FloatUnaryOp<string mnemonic, list<OpTrait> traits = []> :
|
||||||
UnaryOpSameOperandAndResultType<mnemonic, traits>,
|
UnaryOpSameOperandAndResultType<mnemonic,
|
||||||
|
!listconcat(traits,
|
||||||
|
[DeclareOpInterfaceMethods<VectorUnrollOpInterface>])>,
|
||||||
Arguments<(ins FloatLike:$operand)>;
|
Arguments<(ins FloatLike:$operand)>;
|
||||||
|
|
||||||
// Base class for standard arithmetic operations. Requires operands and
|
// Base class for standard arithmetic operations. Requires operands and
|
||||||
|
@ -112,7 +115,9 @@ class ArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
|
||||||
// <op>i %0, %1 : i32
|
// <op>i %0, %1 : i32
|
||||||
//
|
//
|
||||||
class IntArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
|
class IntArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
|
||||||
ArithmeticOp<mnemonic, traits>,
|
ArithmeticOp<mnemonic,
|
||||||
|
!listconcat(traits,
|
||||||
|
[DeclareOpInterfaceMethods<VectorUnrollOpInterface>])>,
|
||||||
Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs)>;
|
Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs)>;
|
||||||
|
|
||||||
// Base class for standard arithmetic binary operations on floats, vectors and
|
// Base class for standard arithmetic binary operations on floats, vectors and
|
||||||
|
@ -125,7 +130,9 @@ class IntArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
|
||||||
// <op>f %0, %1 : f32
|
// <op>f %0, %1 : f32
|
||||||
//
|
//
|
||||||
class FloatArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
|
class FloatArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
|
||||||
ArithmeticOp<mnemonic, traits>,
|
ArithmeticOp<mnemonic,
|
||||||
|
!listconcat(traits,
|
||||||
|
[DeclareOpInterfaceMethods<VectorUnrollOpInterface>])>,
|
||||||
Arguments<(ins FloatLike:$lhs, FloatLike:$rhs)>;
|
Arguments<(ins FloatLike:$lhs, FloatLike:$rhs)>;
|
||||||
|
|
||||||
// Base class for standard arithmetic operations on complex numbers with a
|
// Base class for standard arithmetic operations on complex numbers with a
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
//===- VectorOps.h - MLIR Super Vectorizer Operations -----------*- C++ -*-===//
|
//===- VectorOps.h - MLIR Vector Dialect Operations -------------*- C++ -*-===//
|
||||||
//
|
//
|
||||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
// See https://llvm.org/LICENSE.txt for license information.
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
@ -19,6 +19,7 @@
|
||||||
#include "mlir/IR/OpDefinition.h"
|
#include "mlir/IR/OpDefinition.h"
|
||||||
#include "mlir/IR/StandardTypes.h"
|
#include "mlir/IR/StandardTypes.h"
|
||||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||||
|
#include "mlir/Interfaces/VectorUnrollInterface.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
class MLIRContext;
|
class MLIRContext;
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
|
|
||||||
include "mlir/Dialect/Affine/IR/AffineOpsBase.td"
|
include "mlir/Dialect/Affine/IR/AffineOpsBase.td"
|
||||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||||
|
include "mlir/Interfaces/VectorUnrollInterface.td"
|
||||||
|
|
||||||
def Vector_Dialect : Dialect {
|
def Vector_Dialect : Dialect {
|
||||||
let name = "vector";
|
let name = "vector";
|
||||||
|
@ -39,10 +40,13 @@ class Vector_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||||
// TODO(andydavis, ntv) Add an attribute to specify a different algebra
|
// TODO(andydavis, ntv) Add an attribute to specify a different algebra
|
||||||
// with operators other than the current set: {*, +}.
|
// with operators other than the current set: {*, +}.
|
||||||
def Vector_ContractionOp :
|
def Vector_ContractionOp :
|
||||||
Vector_Op<"contract", [NoSideEffect,
|
Vector_Op<"contract", [
|
||||||
|
NoSideEffect,
|
||||||
PredOpTrait<"lhs and rhs have same element type", TCopVTEtIsSameAs<0, 1>>,
|
PredOpTrait<"lhs and rhs have same element type", TCopVTEtIsSameAs<0, 1>>,
|
||||||
PredOpTrait<"third operand acc and result have same element type",
|
PredOpTrait<"third operand acc and result have same element type",
|
||||||
TCresVTEtIsSameAsOpBase<0, 2>>]>,
|
TCresVTEtIsSameAsOpBase<0, 2>>,
|
||||||
|
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
|
||||||
|
]>,
|
||||||
Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyType:$acc,
|
Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyType:$acc,
|
||||||
Variadic<VectorOf<[I1]>>:$masks,
|
Variadic<VectorOf<[I1]>>:$masks,
|
||||||
AffineMapArrayAttr:$indexing_maps, ArrayAttr:$iterator_types)>,
|
AffineMapArrayAttr:$indexing_maps, ArrayAttr:$iterator_types)>,
|
||||||
|
@ -896,7 +900,9 @@ def Vector_TransferOpUtils {
|
||||||
}
|
}
|
||||||
|
|
||||||
def Vector_TransferReadOp :
|
def Vector_TransferReadOp :
|
||||||
Vector_Op<"transfer_read">,
|
Vector_Op<"transfer_read", [
|
||||||
|
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
|
||||||
|
]>,
|
||||||
Arguments<(ins AnyMemRef:$memref, Variadic<Index>:$indices,
|
Arguments<(ins AnyMemRef:$memref, Variadic<Index>:$indices,
|
||||||
AffineMapAttr:$permutation_map, AnyType:$padding,
|
AffineMapAttr:$permutation_map, AnyType:$padding,
|
||||||
OptionalAttr<BoolArrayAttr>:$masked)>,
|
OptionalAttr<BoolArrayAttr>:$masked)>,
|
||||||
|
@ -1068,7 +1074,9 @@ def Vector_TransferReadOp :
|
||||||
}
|
}
|
||||||
|
|
||||||
def Vector_TransferWriteOp :
|
def Vector_TransferWriteOp :
|
||||||
Vector_Op<"transfer_write">,
|
Vector_Op<"transfer_write", [
|
||||||
|
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
|
||||||
|
]>,
|
||||||
Arguments<(ins AnyVector:$vector, AnyMemRef:$memref,
|
Arguments<(ins AnyVector:$vector, AnyMemRef:$memref,
|
||||||
Variadic<Index>:$indices,
|
Variadic<Index>:$indices,
|
||||||
AffineMapAttr:$permutation_map,
|
AffineMapAttr:$permutation_map,
|
||||||
|
|
|
@ -20,7 +20,7 @@ class HasShape<list<int> shape> :
|
||||||
StrJoinInt<shape>.result # "})">;
|
StrJoinInt<shape>.result # "})">;
|
||||||
|
|
||||||
class UnrollVectorOp<list<int> factors> : NativeCodeCall<
|
class UnrollVectorOp<list<int> factors> : NativeCodeCall<
|
||||||
"unrollSingleResultOpMatchingType($_builder, $0.getDefiningOp(), " #
|
"unrollSingleResultVectorOp($_builder, $0.getDefiningOp(), " #
|
||||||
"{" # StrJoinInt<factors>.result # "})">;
|
"{" # StrJoinInt<factors>.result # "})">;
|
||||||
|
|
||||||
#endif // VECTOR_TRANSFORM_PATTERNS
|
#endif // VECTOR_TRANSFORM_PATTERNS
|
||||||
|
|
|
@ -10,6 +10,8 @@
|
||||||
#define DIALECT_VECTOR_VECTORTRANSFORMS_H_
|
#define DIALECT_VECTOR_VECTORTRANSFORMS_H_
|
||||||
|
|
||||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||||
|
#include "mlir/Dialect/Vector/VectorUtils.h"
|
||||||
|
#include "mlir/IR/Function.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
|
@ -25,43 +27,83 @@ void populateVectorToVectorConversionPatterns(
|
||||||
|
|
||||||
namespace vector {
|
namespace vector {
|
||||||
|
|
||||||
// Entry point for unrolling declarative pattern rewrites.
|
/// Entry point for unrolling declarative pattern rewrites.
|
||||||
// `op` is unrolled to the `targetShape` as follows, for each of its operands:
|
/// `op` is unrolled to the `targetShape` as follows, for each of its operands:
|
||||||
// 1. the unrolled type `unrolledVectorType` and number of unrolled instances
|
/// 1. the unrolled type `unrolledVectorType` and number of unrolled instances
|
||||||
// `numUnrolledInstances` are computed from the `targetShape`. For now it is
|
/// `numUnrolledInstances` are computed from the `targetShape`. For now it is
|
||||||
// assumed the unrolling factors divide the vector sizes.
|
/// assumed the unrolling factors divide the vector sizes.
|
||||||
// 2. a fakeFork cast op is inserted that takes the operand and returns
|
/// 2. a fakeFork cast op is inserted that takes the operand and returns
|
||||||
// `numUnrolledInstances` results of type `unrolledVectorType`.
|
/// `numUnrolledInstances` results of type `unrolledVectorType`.
|
||||||
// 3. the original op is cloned `numUnrolledInstances` times, once for each
|
/// 3. the original op is cloned `numUnrolledInstances` times, once for each
|
||||||
// result of the fakeFork cast op.
|
/// result of the fakeFork cast op.
|
||||||
// 4. a fakeJoin cast op takes all these results and merges them into a single
|
/// 4. a fakeJoin cast op takes all these results and merges them into a
|
||||||
// aggregate vector result whose size matches the original non-unrolled op
|
/// single aggregate vector result whose size matches the original
|
||||||
// operand types.
|
/// non-unrolled op operand types.
|
||||||
//
|
///
|
||||||
// Example:
|
/// Example:
|
||||||
//
|
///
|
||||||
// opA(operand0, operand1) // numUnrolledInstances = 3
|
/// opA(operand0, operand1) // numUnrolledInstances = 3
|
||||||
//
|
///
|
||||||
// operand0 operand1
|
/// operand0 operand1
|
||||||
// | |
|
/// | |
|
||||||
// fork fork
|
/// fork fork
|
||||||
// <----------gather all fork ops --------->
|
/// <----------gather all fork ops --------->
|
||||||
// /|\ /|\
|
/// /|\ /|\
|
||||||
// f00 f01 f02 f10 f11 f12
|
/// f00 f01 f02 f10 f11 f12
|
||||||
// <---------- clone op 3 times --------->
|
/// <---------- clone op 3 times --------->
|
||||||
// opA0(f00, f10), opA1(f01, f11), opA2(f02, f12)
|
/// opA0(f00, f10), opA1(f01, f11), opA2(f02, f12)
|
||||||
// \ | /
|
/// \ | /
|
||||||
// <-------------------- join ------------------------->
|
/// <-------------------- join ------------------------->
|
||||||
//
|
///
|
||||||
// Other local patterns then kick in iteratively (including DCE) and compose
|
/// Other local patterns then kick in iteratively (including DCE) and compose
|
||||||
// until all the fakeFork and fakeJoin ops are removed.
|
/// until all the fakeFork and fakeJoin ops are removed.
|
||||||
//
|
///
|
||||||
// This will be extended in the future to support more advanced use cases than
|
/// This will be extended in the future to support more advanced use cases than
|
||||||
// simple pointwise ops.
|
/// simple pointwise ops.
|
||||||
SmallVector<Value, 1>
|
SmallVector<Value, 1> unrollSingleResultVectorOp(OpBuilder &builder,
|
||||||
unrollSingleResultOpMatchingType(OpBuilder &builder, Operation *op,
|
Operation *op,
|
||||||
ArrayRef<int64_t> targetShape);
|
ArrayRef<int64_t> targetShape);
|
||||||
|
|
||||||
|
/// Pattern to apply `unrollSingleResultVectorOp` to a `targetShape`
|
||||||
|
/// declaratively.
|
||||||
|
template <typename OpTy>
|
||||||
|
struct UnrollVectorPattern : public OpRewritePattern<OpTy> {
|
||||||
|
using FilterConstraintType = std::function<LogicalResult(OpTy op)>;
|
||||||
|
UnrollVectorPattern(
|
||||||
|
ArrayRef<int64_t> targetShape, MLIRContext *context,
|
||||||
|
FilterConstraintType constraint = [](OpTy op) { return success(); })
|
||||||
|
: OpRewritePattern<OpTy>(context),
|
||||||
|
targetShape(targetShape.begin(), targetShape.end()),
|
||||||
|
filter(constraint) {}
|
||||||
|
LogicalResult matchAndRewrite(OpTy op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
if (failed(filter(op)))
|
||||||
|
return failure();
|
||||||
|
auto unrollableVectorOp =
|
||||||
|
dyn_cast<VectorUnrollOpInterface>(op.getOperation());
|
||||||
|
if (!unrollableVectorOp)
|
||||||
|
return failure();
|
||||||
|
auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
|
||||||
|
if (!maybeUnrollShape)
|
||||||
|
return failure();
|
||||||
|
auto maybeShapeRatio = shapeRatio(*maybeUnrollShape, targetShape);
|
||||||
|
if (!maybeShapeRatio ||
|
||||||
|
llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; }))
|
||||||
|
return failure();
|
||||||
|
if (op.getOperation()->getNumResults() != 1)
|
||||||
|
return failure();
|
||||||
|
auto resultVector = unrollSingleResultVectorOp(rewriter, op, targetShape);
|
||||||
|
if (resultVector.size() != 1)
|
||||||
|
return failure();
|
||||||
|
rewriter.replaceOp(op, resultVector.front());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
SmallVector<int64_t, 4> targetShape;
|
||||||
|
FilterConstraintType filter;
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace vector
|
} // namespace vector
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -5,5 +5,6 @@ add_mlir_interface(DerivedAttributeOpInterface)
|
||||||
add_mlir_interface(InferTypeOpInterface)
|
add_mlir_interface(InferTypeOpInterface)
|
||||||
add_mlir_interface(LoopLikeInterface)
|
add_mlir_interface(LoopLikeInterface)
|
||||||
add_mlir_interface(SideEffectInterfaces)
|
add_mlir_interface(SideEffectInterfaces)
|
||||||
|
add_mlir_interface(VectorUnrollInterface)
|
||||||
add_mlir_interface(ViewLikeInterface)
|
add_mlir_interface(ViewLikeInterface)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,26 @@
|
||||||
|
//===- VectorUnrollInterface.h - Vector unrolling interface ---------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// This file implements the operation interface for vector ops that can be
|
||||||
|
// unrolled.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef MLIR_INTERFACES_VECTORUNROLLINTERFACE_H
|
||||||
|
#define MLIR_INTERFACES_VECTORUNROLLINTERFACE_H
|
||||||
|
|
||||||
|
#include "mlir/IR/OpDefinition.h"
|
||||||
|
#include "mlir/IR/StandardTypes.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
|
||||||
|
#include "mlir/Interfaces/VectorUnrollInterface.h.inc"
|
||||||
|
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // MLIR_INTERFACES_VECTORUNROLLINTERFACE_H
|
|
@ -0,0 +1,45 @@
|
||||||
|
//===- VectorUnrollInterface.td - VectorUnroll interface ---*- tablegen -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Defines the interface for operations on vectors that can be unrolled.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef MLIR_INTERFACES_VECTORUNROLLINTERFACE
|
||||||
|
#define MLIR_INTERFACES_VECTORUNROLLINTERFACE
|
||||||
|
|
||||||
|
include "mlir/IR/OpBase.td"
|
||||||
|
|
||||||
|
def VectorUnrollOpInterface : OpInterface<"VectorUnrollOpInterface"> {
|
||||||
|
let description = [{
|
||||||
|
Encodes properties of an operation on vectors that can be unrolled.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let methods = [
|
||||||
|
InterfaceMethod<[{
|
||||||
|
Returns the shape ratio of unrolling to the target vector shape
|
||||||
|
`targetShape`. Returns `None` if the op cannot be unrolled to the target
|
||||||
|
vector shape.
|
||||||
|
}],
|
||||||
|
"Optional<SmallVector<int64_t, 4>>",
|
||||||
|
"getShapeForUnroll",
|
||||||
|
(ins),
|
||||||
|
/*methodBody=*/[{}],
|
||||||
|
[{
|
||||||
|
auto vt = this->getOperation()->getResult(0).getType().
|
||||||
|
template dyn_cast<VectorType>();
|
||||||
|
if (!vt)
|
||||||
|
return None;
|
||||||
|
SmallVector<int64_t, 4> res(vt.getShape().begin(), vt.getShape().end());
|
||||||
|
return res;
|
||||||
|
}]
|
||||||
|
>,
|
||||||
|
];
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // MLIR_INTERFACES_VECTORUNROLLINTERFACE
|
|
@ -15,6 +15,7 @@ add_mlir_dialect_library(MLIRStandardOps
|
||||||
MLIREDSC
|
MLIREDSC
|
||||||
MLIRIR
|
MLIRIR
|
||||||
MLIRSideEffectInterfaces
|
MLIRSideEffectInterfaces
|
||||||
|
MLIRVectorUnrollInterface
|
||||||
MLIRViewLikeInterface
|
MLIRViewLikeInterface
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -19,4 +19,5 @@ add_mlir_dialect_library(MLIRVector
|
||||||
MLIRSCF
|
MLIRSCF
|
||||||
MLIRLoopAnalysis
|
MLIRLoopAnalysis
|
||||||
MLIRSideEffectInterfaces
|
MLIRSideEffectInterfaces
|
||||||
|
MLIRVectorUnrollInterface
|
||||||
)
|
)
|
||||||
|
|
|
@ -469,6 +469,12 @@ SmallVector<AffineMap, 4> ContractionOp::getIndexingMaps() {
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
|
||||||
|
SmallVector<int64_t, 4> shape;
|
||||||
|
getIterationBounds(shape);
|
||||||
|
return shape;
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// ExtractElementOp
|
// ExtractElementOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -1522,6 +1528,11 @@ OpFoldResult TransferReadOp::fold(ArrayRef<Attribute>) {
|
||||||
return OpFoldResult();
|
return OpFoldResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
|
||||||
|
auto s = getVectorType().getShape();
|
||||||
|
return SmallVector<int64_t, 4>{s.begin(), s.end()};
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// TransferWriteOp
|
// TransferWriteOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -1612,6 +1623,11 @@ LogicalResult TransferWriteOp::fold(ArrayRef<Attribute>,
|
||||||
return foldMemRefCast(*this);
|
return foldMemRefCast(*this);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
|
||||||
|
auto s = getVectorType().getShape();
|
||||||
|
return SmallVector<int64_t, 4>{s.begin(), s.end()};
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// ShapeCastOp
|
// ShapeCastOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -30,6 +30,7 @@
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
#include "mlir/IR/TypeUtilities.h"
|
#include "mlir/IR/TypeUtilities.h"
|
||||||
#include "mlir/IR/Types.h"
|
#include "mlir/IR/Types.h"
|
||||||
|
#include "mlir/Interfaces/VectorUnrollInterface.h"
|
||||||
|
|
||||||
#include "llvm/Support/CommandLine.h"
|
#include "llvm/Support/CommandLine.h"
|
||||||
#include "llvm/Support/Debug.h"
|
#include "llvm/Support/Debug.h"
|
||||||
|
@ -357,7 +358,7 @@ struct VectorState {
|
||||||
// (removable with DCE).
|
// (removable with DCE).
|
||||||
|
|
||||||
// TODO(andydavis) Generalize this to support structured ops beyond
|
// TODO(andydavis) Generalize this to support structured ops beyond
|
||||||
// vector ContractionOp, and merge it with 'unrollSingleResultOpMatchingType'
|
// vector ContractionOp, and merge it with 'unrollSingleResultVectorOp'
|
||||||
static Value unrollSingleResultStructuredOp(Operation *op,
|
static Value unrollSingleResultStructuredOp(Operation *op,
|
||||||
ArrayRef<int64_t> iterationBounds,
|
ArrayRef<int64_t> iterationBounds,
|
||||||
std::vector<VectorState> &vectors,
|
std::vector<VectorState> &vectors,
|
||||||
|
@ -450,11 +451,7 @@ static Value unrollSingleResultStructuredOp(Operation *op,
|
||||||
|
|
||||||
static void getVectorContractionOpUnrollState(
|
static void getVectorContractionOpUnrollState(
|
||||||
vector::ContractionOp contractionOp, ArrayRef<int64_t> targetShape,
|
vector::ContractionOp contractionOp, ArrayRef<int64_t> targetShape,
|
||||||
SmallVectorImpl<int64_t> &iterationBounds,
|
|
||||||
std::vector<VectorState> &vectors, unsigned &resultIndex) {
|
std::vector<VectorState> &vectors, unsigned &resultIndex) {
|
||||||
// Get contraction op iteration bounds.
|
|
||||||
contractionOp.getIterationBounds(iterationBounds);
|
|
||||||
assert(iterationBounds.size() == targetShape.size());
|
|
||||||
// Get map from iteration space index to lhs/rhs/result shape index.
|
// Get map from iteration space index to lhs/rhs/result shape index.
|
||||||
std::vector<DenseMap<int64_t, int64_t>> iterationIndexMapList;
|
std::vector<DenseMap<int64_t, int64_t>> iterationIndexMapList;
|
||||||
contractionOp.getIterationIndexMap(iterationIndexMapList);
|
contractionOp.getIterationIndexMap(iterationIndexMapList);
|
||||||
|
@ -476,15 +473,13 @@ static void getVectorContractionOpUnrollState(
|
||||||
vectors.push_back({contractionOp.getRHSVectorMaskType(),
|
vectors.push_back({contractionOp.getRHSVectorMaskType(),
|
||||||
vectors[1].indexMap, accOperandIndex + 2, false});
|
vectors[1].indexMap, accOperandIndex + 2, false});
|
||||||
}
|
}
|
||||||
// Unroll 'op' 'iterationBounds' to 'targetShape'.
|
|
||||||
// TODO(andydavis) Use linalg style 'args_in'/'args_out' to partition
|
// TODO(andydavis) Use linalg style 'args_in'/'args_out' to partition
|
||||||
// 'vectors' instead of 'resultIndex'.
|
// 'vectors' instead of 'resultIndex'.
|
||||||
resultIndex = accOperandIndex;
|
resultIndex = accOperandIndex;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void
|
static void getVectorElementwiseOpUnrollState(Operation *op,
|
||||||
getVectorElementwiseOpUnrollState(Operation *op, ArrayRef<int64_t> targetShape,
|
ArrayRef<int64_t> targetShape,
|
||||||
SmallVectorImpl<int64_t> &iterationBounds,
|
|
||||||
std::vector<VectorState> &vectors,
|
std::vector<VectorState> &vectors,
|
||||||
unsigned &resultIndex) {
|
unsigned &resultIndex) {
|
||||||
// Verify that operation and operands all have the same vector shape.
|
// Verify that operation and operands all have the same vector shape.
|
||||||
|
@ -494,8 +489,6 @@ getVectorElementwiseOpUnrollState(Operation *op, ArrayRef<int64_t> targetShape,
|
||||||
// Verify that all operands have the same vector type as result.
|
// Verify that all operands have the same vector type as result.
|
||||||
assert(llvm::all_of(op->getOperandTypes(),
|
assert(llvm::all_of(op->getOperandTypes(),
|
||||||
[=](Type type) { return type == resultType; }));
|
[=](Type type) { return type == resultType; }));
|
||||||
// Populate 'iterationBounds' with 'resultShape' for elementwise operations.
|
|
||||||
iterationBounds.assign(resultShape.begin(), resultShape.end());
|
|
||||||
|
|
||||||
// Create trivial elementwise identity index map based on 'resultShape'.
|
// Create trivial elementwise identity index map based on 'resultShape'.
|
||||||
DenseMap<int64_t, int64_t> indexMap;
|
DenseMap<int64_t, int64_t> indexMap;
|
||||||
|
@ -513,28 +506,32 @@ getVectorElementwiseOpUnrollState(Operation *op, ArrayRef<int64_t> targetShape,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Entry point for unrolling declarative pattern rewrites.
|
// Entry point for unrolling declarative pattern rewrites.
|
||||||
SmallVector<Value, 1> mlir::vector::unrollSingleResultOpMatchingType(
|
SmallVector<Value, 1>
|
||||||
OpBuilder &builder, Operation *op, ArrayRef<int64_t> targetShape) {
|
mlir::vector::unrollSingleResultVectorOp(OpBuilder &builder, Operation *op,
|
||||||
|
ArrayRef<int64_t> targetShape) {
|
||||||
assert(op->getNumResults() == 1 && "Expected single result operation");
|
assert(op->getNumResults() == 1 && "Expected single result operation");
|
||||||
|
|
||||||
// Populate 'iterationBounds', 'vectors' and 'resultIndex' to unroll 'op'.
|
// Populate 'iterationBounds', 'vectors' and 'resultIndex' to unroll 'op'.
|
||||||
SmallVector<int64_t, 6> iterationBounds;
|
SmallVector<int64_t, 6> iterationBounds;
|
||||||
|
auto unrollableVectorOp = cast<VectorUnrollOpInterface>(op);
|
||||||
|
auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
|
||||||
|
assert(maybeUnrollShape && "Trying to unroll an incorrect vector op");
|
||||||
|
|
||||||
std::vector<VectorState> vectors;
|
std::vector<VectorState> vectors;
|
||||||
unsigned resultIndex;
|
unsigned resultIndex;
|
||||||
|
|
||||||
if (auto contractionOp = dyn_cast<vector::ContractionOp>(op)) {
|
if (auto contractionOp = dyn_cast<vector::ContractionOp>(op)) {
|
||||||
// Populate state for vector ContractionOp.
|
// Populate state for vector ContractionOp.
|
||||||
getVectorContractionOpUnrollState(contractionOp, targetShape,
|
getVectorContractionOpUnrollState(contractionOp, targetShape, vectors,
|
||||||
iterationBounds, vectors, resultIndex);
|
resultIndex);
|
||||||
} else {
|
} else {
|
||||||
// Populate state for vector elementwise op.
|
// Populate state for vector elementwise op.
|
||||||
getVectorElementwiseOpUnrollState(op, targetShape, iterationBounds, vectors,
|
getVectorElementwiseOpUnrollState(op, targetShape, vectors, resultIndex);
|
||||||
resultIndex);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Unroll 'op' with 'iterationBounds' to 'targetShape'.
|
// Unroll 'op' with 'iterationBounds' to 'targetShape'.
|
||||||
return SmallVector<Value, 1>{unrollSingleResultStructuredOp(
|
return SmallVector<Value, 1>{unrollSingleResultStructuredOp(
|
||||||
op, iterationBounds, vectors, resultIndex, targetShape, builder)};
|
op, *maybeUnrollShape, vectors, resultIndex, targetShape, builder)};
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generates slices of 'vectorType' according to 'sizes' and 'strides, and
|
/// Generates slices of 'vectorType' according to 'sizes' and 'strides, and
|
||||||
|
|
|
@ -6,6 +6,7 @@ set(LLVM_OPTIONAL_SOURCES
|
||||||
InferTypeOpInterface.cpp
|
InferTypeOpInterface.cpp
|
||||||
LoopLikeInterface.cpp
|
LoopLikeInterface.cpp
|
||||||
SideEffectInterfaces.cpp
|
SideEffectInterfaces.cpp
|
||||||
|
VectorUnrollInterface.cpp
|
||||||
ViewLikeInterface.cpp
|
ViewLikeInterface.cpp
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -32,5 +33,6 @@ add_mlir_interface_library(DerivedAttributeOpInterface)
|
||||||
add_mlir_interface_library(InferTypeOpInterface)
|
add_mlir_interface_library(InferTypeOpInterface)
|
||||||
add_mlir_interface_library(LoopLikeInterface)
|
add_mlir_interface_library(LoopLikeInterface)
|
||||||
add_mlir_interface_library(SideEffectInterfaces)
|
add_mlir_interface_library(SideEffectInterfaces)
|
||||||
|
add_mlir_interface_library(VectorUnrollInterface)
|
||||||
add_mlir_interface_library(ViewLikeInterface)
|
add_mlir_interface_library(ViewLikeInterface)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,18 @@
|
||||||
|
//===- VectorUnrollInterface.cpp - Unrollable vector operations -*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "mlir/Interfaces/VectorUnrollInterface.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// VectorUnroll Interfaces
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
/// Include the definitions of the VectorUntoll interfaces.
|
||||||
|
#include "mlir/Interfaces/VectorUnrollInterface.cpp.inc"
|
|
@ -1,4 +1,5 @@
|
||||||
// RUN: mlir-opt %s -test-vector-to-vector-conversion | FileCheck %s
|
// RUN: mlir-opt %s -test-vector-to-vector-conversion | FileCheck %s
|
||||||
|
// RUN: mlir-opt %s -test-vector-unrolling-patterns | FileCheck %s
|
||||||
|
|
||||||
// CHECK-DAG: #[[MAP0:map[0-9]+]] = affine_map<(d0, d1) -> (d0, d1)>
|
// CHECK-DAG: #[[MAP0:map[0-9]+]] = affine_map<(d0, d1) -> (d0, d1)>
|
||||||
// CHECK-DAG: #[[MAP1:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
|
// CHECK-DAG: #[[MAP1:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
|
||||||
|
|
|
@ -92,6 +92,20 @@ struct TestVectorContractionConversion
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct TestVectorUnrollingPatterns
|
||||||
|
: public PassWrapper<TestVectorUnrollingPatterns, FunctionPass> {
|
||||||
|
void runOnFunction() override {
|
||||||
|
MLIRContext *ctx = &getContext();
|
||||||
|
OwningRewritePatternList patterns;
|
||||||
|
patterns.insert<UnrollVectorPattern<AddFOp>>(ArrayRef<int64_t>{2, 2}, ctx);
|
||||||
|
patterns.insert<UnrollVectorPattern<vector::ContractionOp>>(
|
||||||
|
ArrayRef<int64_t>{2, 2, 2}, ctx);
|
||||||
|
populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
|
||||||
|
populateVectorToVectorTransformationPatterns(patterns, ctx);
|
||||||
|
applyPatternsAndFoldGreedily(getFunction(), patterns);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
} // end anonymous namespace
|
} // end anonymous namespace
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
|
@ -107,5 +121,9 @@ void registerTestVectorConversions() {
|
||||||
PassRegistration<TestVectorContractionConversion> contractionPass(
|
PassRegistration<TestVectorContractionConversion> contractionPass(
|
||||||
"test-vector-contraction-conversion",
|
"test-vector-contraction-conversion",
|
||||||
"Test conversion patterns that lower contract ops in the vector dialect");
|
"Test conversion patterns that lower contract ops in the vector dialect");
|
||||||
|
|
||||||
|
PassRegistration<TestVectorUnrollingPatterns> contractionUnrollingPass(
|
||||||
|
"test-vector-unrolling-patterns",
|
||||||
|
"Test conversion patterns to unroll contract ops in the vector dialect");
|
||||||
}
|
}
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
Loading…
Reference in New Issue