[mlir][Vector] Add a VectorUnrollInterface and expose UnrollVectorPattern.
The UnrollVectorPattern is can be used in a programmable fashion by: ``` OwningRewritePatternList patterns; patterns.insert<UnrollVectorPattern<AddFOp>>(ArrayRef<int64_t>{2, 2}, ctx); patterns.insert<UnrollVectorPattern<vector::ContractionOp>>( ArrayRef<int64_t>{2, 2, 2}, ctx); ... applyPatternsAndFoldGreedily(getFunction(), patterns); ``` Differential revision: https://reviews.llvm.org/D83064
This commit is contained in:
parent
0607c8df7f
commit
05c65dc0fe
|
@ -444,7 +444,7 @@ def MyInterface : OpInterface<"MyInterface"> {
|
|||
// Note: `ConcreteOp` corresponds to the derived operation typename.
|
||||
InterfaceMethod<"/*insert doc here*/",
|
||||
"unsigned", "getNumWithDefault", (ins), /*methodBody=*/[{}], [{
|
||||
ConcreteOp op = cast<ConcreteOp>(getOperation());
|
||||
ConcreteOp op = cast<ConcreteOp>(this->getOperation());
|
||||
return op.getNumInputs() + op.getNumOutputs();
|
||||
}]>,
|
||||
];
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include "mlir/Interfaces/CallInterfaces.h"
|
||||
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
#include "mlir/Interfaces/VectorUnrollInterface.h"
|
||||
#include "mlir/Interfaces/ViewLikeInterface.h"
|
||||
|
||||
// Pull in all enum type definitions and utility function declarations.
|
||||
|
|
|
@ -17,6 +17,7 @@ include "mlir/IR/OpAsmInterface.td"
|
|||
include "mlir/Interfaces/CallInterfaces.td"
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/Interfaces/VectorUnrollInterface.td"
|
||||
include "mlir/Interfaces/ViewLikeInterface.td"
|
||||
|
||||
def StandardOps_Dialect : Dialect {
|
||||
|
@ -82,7 +83,9 @@ class UnaryOpSameOperandAndResultType<string mnemonic,
|
|||
}
|
||||
|
||||
class FloatUnaryOp<string mnemonic, list<OpTrait> traits = []> :
|
||||
UnaryOpSameOperandAndResultType<mnemonic, traits>,
|
||||
UnaryOpSameOperandAndResultType<mnemonic,
|
||||
!listconcat(traits,
|
||||
[DeclareOpInterfaceMethods<VectorUnrollOpInterface>])>,
|
||||
Arguments<(ins FloatLike:$operand)>;
|
||||
|
||||
// Base class for standard arithmetic operations. Requires operands and
|
||||
|
@ -112,7 +115,9 @@ class ArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
|
|||
// <op>i %0, %1 : i32
|
||||
//
|
||||
class IntArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
|
||||
ArithmeticOp<mnemonic, traits>,
|
||||
ArithmeticOp<mnemonic,
|
||||
!listconcat(traits,
|
||||
[DeclareOpInterfaceMethods<VectorUnrollOpInterface>])>,
|
||||
Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs)>;
|
||||
|
||||
// Base class for standard arithmetic binary operations on floats, vectors and
|
||||
|
@ -125,7 +130,9 @@ class IntArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
|
|||
// <op>f %0, %1 : f32
|
||||
//
|
||||
class FloatArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
|
||||
ArithmeticOp<mnemonic, traits>,
|
||||
ArithmeticOp<mnemonic,
|
||||
!listconcat(traits,
|
||||
[DeclareOpInterfaceMethods<VectorUnrollOpInterface>])>,
|
||||
Arguments<(ins FloatLike:$lhs, FloatLike:$rhs)>;
|
||||
|
||||
// Base class for standard arithmetic operations on complex numbers with a
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
//===- VectorOps.h - MLIR Super Vectorizer Operations -----------*- C++ -*-===//
|
||||
//===- VectorOps.h - MLIR Vector Dialect Operations -------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
|
@ -19,6 +19,7 @@
|
|||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
#include "mlir/Interfaces/VectorUnrollInterface.h"
|
||||
|
||||
namespace mlir {
|
||||
class MLIRContext;
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
|
||||
include "mlir/Dialect/Affine/IR/AffineOpsBase.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/Interfaces/VectorUnrollInterface.td"
|
||||
|
||||
def Vector_Dialect : Dialect {
|
||||
let name = "vector";
|
||||
|
@ -39,10 +40,13 @@ class Vector_Op<string mnemonic, list<OpTrait> traits = []> :
|
|||
// TODO(andydavis, ntv) Add an attribute to specify a different algebra
|
||||
// with operators other than the current set: {*, +}.
|
||||
def Vector_ContractionOp :
|
||||
Vector_Op<"contract", [NoSideEffect,
|
||||
PredOpTrait<"lhs and rhs have same element type", TCopVTEtIsSameAs<0, 1>>,
|
||||
PredOpTrait<"third operand acc and result have same element type",
|
||||
TCresVTEtIsSameAsOpBase<0, 2>>]>,
|
||||
Vector_Op<"contract", [
|
||||
NoSideEffect,
|
||||
PredOpTrait<"lhs and rhs have same element type", TCopVTEtIsSameAs<0, 1>>,
|
||||
PredOpTrait<"third operand acc and result have same element type",
|
||||
TCresVTEtIsSameAsOpBase<0, 2>>,
|
||||
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
|
||||
]>,
|
||||
Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyType:$acc,
|
||||
Variadic<VectorOf<[I1]>>:$masks,
|
||||
AffineMapArrayAttr:$indexing_maps, ArrayAttr:$iterator_types)>,
|
||||
|
@ -896,7 +900,9 @@ def Vector_TransferOpUtils {
|
|||
}
|
||||
|
||||
def Vector_TransferReadOp :
|
||||
Vector_Op<"transfer_read">,
|
||||
Vector_Op<"transfer_read", [
|
||||
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
|
||||
]>,
|
||||
Arguments<(ins AnyMemRef:$memref, Variadic<Index>:$indices,
|
||||
AffineMapAttr:$permutation_map, AnyType:$padding,
|
||||
OptionalAttr<BoolArrayAttr>:$masked)>,
|
||||
|
@ -1068,7 +1074,9 @@ def Vector_TransferReadOp :
|
|||
}
|
||||
|
||||
def Vector_TransferWriteOp :
|
||||
Vector_Op<"transfer_write">,
|
||||
Vector_Op<"transfer_write", [
|
||||
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
|
||||
]>,
|
||||
Arguments<(ins AnyVector:$vector, AnyMemRef:$memref,
|
||||
Variadic<Index>:$indices,
|
||||
AffineMapAttr:$permutation_map,
|
||||
|
|
|
@ -20,7 +20,7 @@ class HasShape<list<int> shape> :
|
|||
StrJoinInt<shape>.result # "})">;
|
||||
|
||||
class UnrollVectorOp<list<int> factors> : NativeCodeCall<
|
||||
"unrollSingleResultOpMatchingType($_builder, $0.getDefiningOp(), " #
|
||||
"unrollSingleResultVectorOp($_builder, $0.getDefiningOp(), " #
|
||||
"{" # StrJoinInt<factors>.result # "})">;
|
||||
|
||||
#endif // VECTOR_TRANSFORM_PATTERNS
|
||||
|
|
|
@ -10,6 +10,8 @@
|
|||
#define DIALECT_VECTOR_VECTORTRANSFORMS_H_
|
||||
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "mlir/Dialect/Vector/VectorUtils.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
namespace mlir {
|
||||
|
@ -25,42 +27,82 @@ void populateVectorToVectorConversionPatterns(
|
|||
|
||||
namespace vector {
|
||||
|
||||
// Entry point for unrolling declarative pattern rewrites.
|
||||
// `op` is unrolled to the `targetShape` as follows, for each of its operands:
|
||||
// 1. the unrolled type `unrolledVectorType` and number of unrolled instances
|
||||
// `numUnrolledInstances` are computed from the `targetShape`. For now it is
|
||||
// assumed the unrolling factors divide the vector sizes.
|
||||
// 2. a fakeFork cast op is inserted that takes the operand and returns
|
||||
// `numUnrolledInstances` results of type `unrolledVectorType`.
|
||||
// 3. the original op is cloned `numUnrolledInstances` times, once for each
|
||||
// result of the fakeFork cast op.
|
||||
// 4. a fakeJoin cast op takes all these results and merges them into a single
|
||||
// aggregate vector result whose size matches the original non-unrolled op
|
||||
// operand types.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// opA(operand0, operand1) // numUnrolledInstances = 3
|
||||
//
|
||||
// operand0 operand1
|
||||
// | |
|
||||
// fork fork
|
||||
// <----------gather all fork ops --------->
|
||||
// /|\ /|\
|
||||
// f00 f01 f02 f10 f11 f12
|
||||
// <---------- clone op 3 times --------->
|
||||
// opA0(f00, f10), opA1(f01, f11), opA2(f02, f12)
|
||||
// \ | /
|
||||
// <-------------------- join ------------------------->
|
||||
//
|
||||
// Other local patterns then kick in iteratively (including DCE) and compose
|
||||
// until all the fakeFork and fakeJoin ops are removed.
|
||||
//
|
||||
// This will be extended in the future to support more advanced use cases than
|
||||
// simple pointwise ops.
|
||||
SmallVector<Value, 1>
|
||||
unrollSingleResultOpMatchingType(OpBuilder &builder, Operation *op,
|
||||
ArrayRef<int64_t> targetShape);
|
||||
/// Entry point for unrolling declarative pattern rewrites.
|
||||
/// `op` is unrolled to the `targetShape` as follows, for each of its operands:
|
||||
/// 1. the unrolled type `unrolledVectorType` and number of unrolled instances
|
||||
/// `numUnrolledInstances` are computed from the `targetShape`. For now it is
|
||||
/// assumed the unrolling factors divide the vector sizes.
|
||||
/// 2. a fakeFork cast op is inserted that takes the operand and returns
|
||||
/// `numUnrolledInstances` results of type `unrolledVectorType`.
|
||||
/// 3. the original op is cloned `numUnrolledInstances` times, once for each
|
||||
/// result of the fakeFork cast op.
|
||||
/// 4. a fakeJoin cast op takes all these results and merges them into a
|
||||
/// single aggregate vector result whose size matches the original
|
||||
/// non-unrolled op operand types.
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
/// opA(operand0, operand1) // numUnrolledInstances = 3
|
||||
///
|
||||
/// operand0 operand1
|
||||
/// | |
|
||||
/// fork fork
|
||||
/// <----------gather all fork ops --------->
|
||||
/// /|\ /|\
|
||||
/// f00 f01 f02 f10 f11 f12
|
||||
/// <---------- clone op 3 times --------->
|
||||
/// opA0(f00, f10), opA1(f01, f11), opA2(f02, f12)
|
||||
/// \ | /
|
||||
/// <-------------------- join ------------------------->
|
||||
///
|
||||
/// Other local patterns then kick in iteratively (including DCE) and compose
|
||||
/// until all the fakeFork and fakeJoin ops are removed.
|
||||
///
|
||||
/// This will be extended in the future to support more advanced use cases than
|
||||
/// simple pointwise ops.
|
||||
SmallVector<Value, 1> unrollSingleResultVectorOp(OpBuilder &builder,
|
||||
Operation *op,
|
||||
ArrayRef<int64_t> targetShape);
|
||||
|
||||
/// Pattern to apply `unrollSingleResultVectorOp` to a `targetShape`
|
||||
/// declaratively.
|
||||
template <typename OpTy>
|
||||
struct UnrollVectorPattern : public OpRewritePattern<OpTy> {
|
||||
using FilterConstraintType = std::function<LogicalResult(OpTy op)>;
|
||||
UnrollVectorPattern(
|
||||
ArrayRef<int64_t> targetShape, MLIRContext *context,
|
||||
FilterConstraintType constraint = [](OpTy op) { return success(); })
|
||||
: OpRewritePattern<OpTy>(context),
|
||||
targetShape(targetShape.begin(), targetShape.end()),
|
||||
filter(constraint) {}
|
||||
LogicalResult matchAndRewrite(OpTy op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (failed(filter(op)))
|
||||
return failure();
|
||||
auto unrollableVectorOp =
|
||||
dyn_cast<VectorUnrollOpInterface>(op.getOperation());
|
||||
if (!unrollableVectorOp)
|
||||
return failure();
|
||||
auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
|
||||
if (!maybeUnrollShape)
|
||||
return failure();
|
||||
auto maybeShapeRatio = shapeRatio(*maybeUnrollShape, targetShape);
|
||||
if (!maybeShapeRatio ||
|
||||
llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; }))
|
||||
return failure();
|
||||
if (op.getOperation()->getNumResults() != 1)
|
||||
return failure();
|
||||
auto resultVector = unrollSingleResultVectorOp(rewriter, op, targetShape);
|
||||
if (resultVector.size() != 1)
|
||||
return failure();
|
||||
rewriter.replaceOp(op, resultVector.front());
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
SmallVector<int64_t, 4> targetShape;
|
||||
FilterConstraintType filter;
|
||||
};
|
||||
|
||||
} // namespace vector
|
||||
|
||||
|
|
|
@ -5,5 +5,6 @@ add_mlir_interface(DerivedAttributeOpInterface)
|
|||
add_mlir_interface(InferTypeOpInterface)
|
||||
add_mlir_interface(LoopLikeInterface)
|
||||
add_mlir_interface(SideEffectInterfaces)
|
||||
add_mlir_interface(VectorUnrollInterface)
|
||||
add_mlir_interface(ViewLikeInterface)
|
||||
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
//===- VectorUnrollInterface.h - Vector unrolling interface ---------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements the operation interface for vector ops that can be
|
||||
// unrolled.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_INTERFACES_VECTORUNROLLINTERFACE_H
|
||||
#define MLIR_INTERFACES_VECTORUNROLLINTERFACE_H
|
||||
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
#include "mlir/Interfaces/VectorUnrollInterface.h.inc"
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_INTERFACES_VECTORUNROLLINTERFACE_H
|
|
@ -0,0 +1,45 @@
|
|||
//===- VectorUnrollInterface.td - VectorUnroll interface ---*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Defines the interface for operations on vectors that can be unrolled.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_INTERFACES_VECTORUNROLLINTERFACE
|
||||
#define MLIR_INTERFACES_VECTORUNROLLINTERFACE
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def VectorUnrollOpInterface : OpInterface<"VectorUnrollOpInterface"> {
|
||||
let description = [{
|
||||
Encodes properties of an operation on vectors that can be unrolled.
|
||||
}];
|
||||
|
||||
let methods = [
|
||||
InterfaceMethod<[{
|
||||
Returns the shape ratio of unrolling to the target vector shape
|
||||
`targetShape`. Returns `None` if the op cannot be unrolled to the target
|
||||
vector shape.
|
||||
}],
|
||||
"Optional<SmallVector<int64_t, 4>>",
|
||||
"getShapeForUnroll",
|
||||
(ins),
|
||||
/*methodBody=*/[{}],
|
||||
[{
|
||||
auto vt = this->getOperation()->getResult(0).getType().
|
||||
template dyn_cast<VectorType>();
|
||||
if (!vt)
|
||||
return None;
|
||||
SmallVector<int64_t, 4> res(vt.getShape().begin(), vt.getShape().end());
|
||||
return res;
|
||||
}]
|
||||
>,
|
||||
];
|
||||
}
|
||||
|
||||
#endif // MLIR_INTERFACES_VECTORUNROLLINTERFACE
|
|
@ -15,6 +15,7 @@ add_mlir_dialect_library(MLIRStandardOps
|
|||
MLIREDSC
|
||||
MLIRIR
|
||||
MLIRSideEffectInterfaces
|
||||
MLIRVectorUnrollInterface
|
||||
MLIRViewLikeInterface
|
||||
)
|
||||
|
||||
|
|
|
@ -19,4 +19,5 @@ add_mlir_dialect_library(MLIRVector
|
|||
MLIRSCF
|
||||
MLIRLoopAnalysis
|
||||
MLIRSideEffectInterfaces
|
||||
MLIRVectorUnrollInterface
|
||||
)
|
||||
|
|
|
@ -469,6 +469,12 @@ SmallVector<AffineMap, 4> ContractionOp::getIndexingMaps() {
|
|||
return res;
|
||||
}
|
||||
|
||||
Optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
|
||||
SmallVector<int64_t, 4> shape;
|
||||
getIterationBounds(shape);
|
||||
return shape;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ExtractElementOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1522,6 +1528,11 @@ OpFoldResult TransferReadOp::fold(ArrayRef<Attribute>) {
|
|||
return OpFoldResult();
|
||||
}
|
||||
|
||||
Optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
|
||||
auto s = getVectorType().getShape();
|
||||
return SmallVector<int64_t, 4>{s.begin(), s.end()};
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TransferWriteOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1612,6 +1623,11 @@ LogicalResult TransferWriteOp::fold(ArrayRef<Attribute>,
|
|||
return foldMemRefCast(*this);
|
||||
}
|
||||
|
||||
Optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
|
||||
auto s = getVectorType().getShape();
|
||||
return SmallVector<int64_t, 4>{s.begin(), s.end()};
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ShapeCastOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -30,6 +30,7 @@
|
|||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/Interfaces/VectorUnrollInterface.h"
|
||||
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
@ -357,7 +358,7 @@ struct VectorState {
|
|||
// (removable with DCE).
|
||||
|
||||
// TODO(andydavis) Generalize this to support structured ops beyond
|
||||
// vector ContractionOp, and merge it with 'unrollSingleResultOpMatchingType'
|
||||
// vector ContractionOp, and merge it with 'unrollSingleResultVectorOp'
|
||||
static Value unrollSingleResultStructuredOp(Operation *op,
|
||||
ArrayRef<int64_t> iterationBounds,
|
||||
std::vector<VectorState> &vectors,
|
||||
|
@ -450,11 +451,7 @@ static Value unrollSingleResultStructuredOp(Operation *op,
|
|||
|
||||
static void getVectorContractionOpUnrollState(
|
||||
vector::ContractionOp contractionOp, ArrayRef<int64_t> targetShape,
|
||||
SmallVectorImpl<int64_t> &iterationBounds,
|
||||
std::vector<VectorState> &vectors, unsigned &resultIndex) {
|
||||
// Get contraction op iteration bounds.
|
||||
contractionOp.getIterationBounds(iterationBounds);
|
||||
assert(iterationBounds.size() == targetShape.size());
|
||||
// Get map from iteration space index to lhs/rhs/result shape index.
|
||||
std::vector<DenseMap<int64_t, int64_t>> iterationIndexMapList;
|
||||
contractionOp.getIterationIndexMap(iterationIndexMapList);
|
||||
|
@ -476,17 +473,15 @@ static void getVectorContractionOpUnrollState(
|
|||
vectors.push_back({contractionOp.getRHSVectorMaskType(),
|
||||
vectors[1].indexMap, accOperandIndex + 2, false});
|
||||
}
|
||||
// Unroll 'op' 'iterationBounds' to 'targetShape'.
|
||||
// TODO(andydavis) Use linalg style 'args_in'/'args_out' to partition
|
||||
// 'vectors' instead of 'resultIndex'.
|
||||
resultIndex = accOperandIndex;
|
||||
}
|
||||
|
||||
static void
|
||||
getVectorElementwiseOpUnrollState(Operation *op, ArrayRef<int64_t> targetShape,
|
||||
SmallVectorImpl<int64_t> &iterationBounds,
|
||||
std::vector<VectorState> &vectors,
|
||||
unsigned &resultIndex) {
|
||||
static void getVectorElementwiseOpUnrollState(Operation *op,
|
||||
ArrayRef<int64_t> targetShape,
|
||||
std::vector<VectorState> &vectors,
|
||||
unsigned &resultIndex) {
|
||||
// Verify that operation and operands all have the same vector shape.
|
||||
auto resultType = op->getResult(0).getType().dyn_cast_or_null<VectorType>();
|
||||
assert(resultType && "Expected op with vector result type");
|
||||
|
@ -494,8 +489,6 @@ getVectorElementwiseOpUnrollState(Operation *op, ArrayRef<int64_t> targetShape,
|
|||
// Verify that all operands have the same vector type as result.
|
||||
assert(llvm::all_of(op->getOperandTypes(),
|
||||
[=](Type type) { return type == resultType; }));
|
||||
// Populate 'iterationBounds' with 'resultShape' for elementwise operations.
|
||||
iterationBounds.assign(resultShape.begin(), resultShape.end());
|
||||
|
||||
// Create trivial elementwise identity index map based on 'resultShape'.
|
||||
DenseMap<int64_t, int64_t> indexMap;
|
||||
|
@ -513,28 +506,32 @@ getVectorElementwiseOpUnrollState(Operation *op, ArrayRef<int64_t> targetShape,
|
|||
}
|
||||
|
||||
// Entry point for unrolling declarative pattern rewrites.
|
||||
SmallVector<Value, 1> mlir::vector::unrollSingleResultOpMatchingType(
|
||||
OpBuilder &builder, Operation *op, ArrayRef<int64_t> targetShape) {
|
||||
SmallVector<Value, 1>
|
||||
mlir::vector::unrollSingleResultVectorOp(OpBuilder &builder, Operation *op,
|
||||
ArrayRef<int64_t> targetShape) {
|
||||
assert(op->getNumResults() == 1 && "Expected single result operation");
|
||||
|
||||
// Populate 'iterationBounds', 'vectors' and 'resultIndex' to unroll 'op'.
|
||||
SmallVector<int64_t, 6> iterationBounds;
|
||||
auto unrollableVectorOp = cast<VectorUnrollOpInterface>(op);
|
||||
auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
|
||||
assert(maybeUnrollShape && "Trying to unroll an incorrect vector op");
|
||||
|
||||
std::vector<VectorState> vectors;
|
||||
unsigned resultIndex;
|
||||
|
||||
if (auto contractionOp = dyn_cast<vector::ContractionOp>(op)) {
|
||||
// Populate state for vector ContractionOp.
|
||||
getVectorContractionOpUnrollState(contractionOp, targetShape,
|
||||
iterationBounds, vectors, resultIndex);
|
||||
getVectorContractionOpUnrollState(contractionOp, targetShape, vectors,
|
||||
resultIndex);
|
||||
} else {
|
||||
// Populate state for vector elementwise op.
|
||||
getVectorElementwiseOpUnrollState(op, targetShape, iterationBounds, vectors,
|
||||
resultIndex);
|
||||
getVectorElementwiseOpUnrollState(op, targetShape, vectors, resultIndex);
|
||||
}
|
||||
|
||||
// Unroll 'op' with 'iterationBounds' to 'targetShape'.
|
||||
return SmallVector<Value, 1>{unrollSingleResultStructuredOp(
|
||||
op, iterationBounds, vectors, resultIndex, targetShape, builder)};
|
||||
op, *maybeUnrollShape, vectors, resultIndex, targetShape, builder)};
|
||||
}
|
||||
|
||||
/// Generates slices of 'vectorType' according to 'sizes' and 'strides, and
|
||||
|
|
|
@ -6,6 +6,7 @@ set(LLVM_OPTIONAL_SOURCES
|
|||
InferTypeOpInterface.cpp
|
||||
LoopLikeInterface.cpp
|
||||
SideEffectInterfaces.cpp
|
||||
VectorUnrollInterface.cpp
|
||||
ViewLikeInterface.cpp
|
||||
)
|
||||
|
||||
|
@ -32,5 +33,6 @@ add_mlir_interface_library(DerivedAttributeOpInterface)
|
|||
add_mlir_interface_library(InferTypeOpInterface)
|
||||
add_mlir_interface_library(LoopLikeInterface)
|
||||
add_mlir_interface_library(SideEffectInterfaces)
|
||||
add_mlir_interface_library(VectorUnrollInterface)
|
||||
add_mlir_interface_library(ViewLikeInterface)
|
||||
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
//===- VectorUnrollInterface.cpp - Unrollable vector operations -*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Interfaces/VectorUnrollInterface.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// VectorUnroll Interfaces
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Include the definitions of the VectorUntoll interfaces.
|
||||
#include "mlir/Interfaces/VectorUnrollInterface.cpp.inc"
|
|
@ -1,4 +1,5 @@
|
|||
// RUN: mlir-opt %s -test-vector-to-vector-conversion | FileCheck %s
|
||||
// RUN: mlir-opt %s -test-vector-unrolling-patterns | FileCheck %s
|
||||
|
||||
// CHECK-DAG: #[[MAP0:map[0-9]+]] = affine_map<(d0, d1) -> (d0, d1)>
|
||||
// CHECK-DAG: #[[MAP1:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
|
||||
|
|
|
@ -92,6 +92,20 @@ struct TestVectorContractionConversion
|
|||
}
|
||||
};
|
||||
|
||||
struct TestVectorUnrollingPatterns
|
||||
: public PassWrapper<TestVectorUnrollingPatterns, FunctionPass> {
|
||||
void runOnFunction() override {
|
||||
MLIRContext *ctx = &getContext();
|
||||
OwningRewritePatternList patterns;
|
||||
patterns.insert<UnrollVectorPattern<AddFOp>>(ArrayRef<int64_t>{2, 2}, ctx);
|
||||
patterns.insert<UnrollVectorPattern<vector::ContractionOp>>(
|
||||
ArrayRef<int64_t>{2, 2, 2}, ctx);
|
||||
populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
|
||||
populateVectorToVectorTransformationPatterns(patterns, ctx);
|
||||
applyPatternsAndFoldGreedily(getFunction(), patterns);
|
||||
}
|
||||
};
|
||||
|
||||
} // end anonymous namespace
|
||||
|
||||
namespace mlir {
|
||||
|
@ -107,5 +121,9 @@ void registerTestVectorConversions() {
|
|||
PassRegistration<TestVectorContractionConversion> contractionPass(
|
||||
"test-vector-contraction-conversion",
|
||||
"Test conversion patterns that lower contract ops in the vector dialect");
|
||||
|
||||
PassRegistration<TestVectorUnrollingPatterns> contractionUnrollingPass(
|
||||
"test-vector-unrolling-patterns",
|
||||
"Test conversion patterns to unroll contract ops in the vector dialect");
|
||||
}
|
||||
} // namespace mlir
|
||||
|
|
Loading…
Reference in New Issue