[mlir][vector] Add helper that builds a scalar reduction according to CombiningKind
Differential Revision: https://reviews.llvm.org/D119433
This commit is contained in:
parent
d038faea46
commit
9b5a3d14b2
|
@ -9,6 +9,7 @@
|
|||
#ifndef MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
|
||||
#define MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
|
||||
|
||||
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
|
@ -30,12 +31,14 @@ class VectorType;
|
|||
class VectorTransferOpInterface;
|
||||
|
||||
namespace vector {
|
||||
class TransferWriteOp;
|
||||
class TransferReadOp;
|
||||
|
||||
/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
|
||||
/// the type of `source`.
|
||||
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim);
|
||||
|
||||
/// Return the result value of reducing two scalar/vector values with the
|
||||
/// corresponding arith operation.
|
||||
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind,
|
||||
Value v1, Value v2);
|
||||
} // namespace vector
|
||||
|
||||
/// Return the number of elements of basis, `0` if empty.
|
||||
|
|
|
@ -243,47 +243,8 @@ struct TwoDimMultiReductionToElementWise
|
|||
for (int64_t i = 1; i < srcShape[0]; i++) {
|
||||
auto operand =
|
||||
rewriter.create<vector::ExtractOp>(loc, multiReductionOp.source(), i);
|
||||
switch (multiReductionOp.kind()) {
|
||||
case vector::CombiningKind::ADD:
|
||||
if (elementType.isIntOrIndex())
|
||||
result = rewriter.create<arith::AddIOp>(loc, operand, result);
|
||||
else
|
||||
result = rewriter.create<arith::AddFOp>(loc, operand, result);
|
||||
break;
|
||||
case vector::CombiningKind::MUL:
|
||||
if (elementType.isIntOrIndex())
|
||||
result = rewriter.create<arith::MulIOp>(loc, operand, result);
|
||||
else
|
||||
result = rewriter.create<arith::MulFOp>(loc, operand, result);
|
||||
break;
|
||||
case vector::CombiningKind::MINUI:
|
||||
result = rewriter.create<arith::MinUIOp>(loc, operand, result);
|
||||
break;
|
||||
case vector::CombiningKind::MINSI:
|
||||
result = rewriter.create<arith::MinSIOp>(loc, operand, result);
|
||||
break;
|
||||
case vector::CombiningKind::MINF:
|
||||
result = rewriter.create<arith::MinFOp>(loc, operand, result);
|
||||
break;
|
||||
case vector::CombiningKind::MAXUI:
|
||||
result = rewriter.create<arith::MaxUIOp>(loc, operand, result);
|
||||
break;
|
||||
case vector::CombiningKind::MAXSI:
|
||||
result = rewriter.create<arith::MaxSIOp>(loc, operand, result);
|
||||
break;
|
||||
case vector::CombiningKind::MAXF:
|
||||
result = rewriter.create<arith::MaxFOp>(loc, operand, result);
|
||||
break;
|
||||
case vector::CombiningKind::AND:
|
||||
result = rewriter.create<arith::AndIOp>(loc, operand, result);
|
||||
break;
|
||||
case vector::CombiningKind::OR:
|
||||
result = rewriter.create<arith::OrIOp>(loc, operand, result);
|
||||
break;
|
||||
case vector::CombiningKind::XOR:
|
||||
result = rewriter.create<arith::XOrIOp>(loc, operand, result);
|
||||
break;
|
||||
}
|
||||
result = makeArithReduction(rewriter, loc, multiReductionOp.kind(),
|
||||
operand, result);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(multiReductionOp, result);
|
||||
|
|
|
@ -10,6 +10,8 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
|
@ -18,8 +20,7 @@
|
|||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
|
||||
|
||||
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
|
||||
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
|
||||
#include "mlir/IR/ImplicitLocOpBuilder.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
@ -514,40 +515,11 @@ private:
|
|||
if (!acc)
|
||||
return Optional<Value>(mul);
|
||||
|
||||
Value combinedResult;
|
||||
switch (kind) {
|
||||
case CombiningKind::ADD:
|
||||
combinedResult = rewriter.create<arith::AddIOp>(loc, mul, acc);
|
||||
break;
|
||||
case CombiningKind::MUL:
|
||||
combinedResult = rewriter.create<arith::MulIOp>(loc, mul, acc);
|
||||
break;
|
||||
case CombiningKind::MINUI:
|
||||
combinedResult = rewriter.create<arith::MinUIOp>(loc, mul, acc);
|
||||
break;
|
||||
case CombiningKind::MINSI:
|
||||
combinedResult = rewriter.create<arith::MinSIOp>(loc, mul, acc);
|
||||
break;
|
||||
case CombiningKind::MAXUI:
|
||||
combinedResult = rewriter.create<arith::MaxUIOp>(loc, mul, acc);
|
||||
break;
|
||||
case CombiningKind::MAXSI:
|
||||
combinedResult = rewriter.create<arith::MaxSIOp>(loc, mul, acc);
|
||||
break;
|
||||
case CombiningKind::AND:
|
||||
combinedResult = rewriter.create<arith::AndIOp>(loc, mul, acc);
|
||||
break;
|
||||
case CombiningKind::OR:
|
||||
combinedResult = rewriter.create<arith::OrIOp>(loc, mul, acc);
|
||||
break;
|
||||
case CombiningKind::XOR:
|
||||
combinedResult = rewriter.create<arith::XOrIOp>(loc, mul, acc);
|
||||
break;
|
||||
case CombiningKind::MINF: // Only valid for floating point types.
|
||||
case CombiningKind::MAXF: // Only valid for floating point types.
|
||||
if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF)
|
||||
// Only valid for floating point types.
|
||||
return Optional<Value>();
|
||||
}
|
||||
return Optional<Value>(combinedResult);
|
||||
|
||||
return makeArithReduction(rewriter, loc, kind, mul, acc);
|
||||
}
|
||||
|
||||
static Optional<Value> genMultF(Location loc, Value x, Value y, Value acc,
|
||||
|
@ -565,28 +537,14 @@ private:
|
|||
if (!acc)
|
||||
return Optional<Value>(mul);
|
||||
|
||||
Value combinedResult;
|
||||
switch (kind) {
|
||||
case CombiningKind::MUL:
|
||||
combinedResult = rewriter.create<arith::MulFOp>(loc, mul, acc);
|
||||
break;
|
||||
case CombiningKind::MINF:
|
||||
combinedResult = rewriter.create<arith::MinFOp>(loc, mul, acc);
|
||||
break;
|
||||
case CombiningKind::MAXF:
|
||||
combinedResult = rewriter.create<arith::MaxFOp>(loc, mul, acc);
|
||||
break;
|
||||
case CombiningKind::ADD: // Already handled this special case above.
|
||||
case CombiningKind::AND: // Only valid for integer types.
|
||||
case CombiningKind::MINUI: // Only valid for integer types.
|
||||
case CombiningKind::MINSI: // Only valid for integer types.
|
||||
case CombiningKind::MAXUI: // Only valid for integer types.
|
||||
case CombiningKind::MAXSI: // Only valid for integer types.
|
||||
case CombiningKind::OR: // Only valid for integer types.
|
||||
case CombiningKind::XOR: // Only valid for integer types.
|
||||
if (kind == CombiningKind::ADD || kind == CombiningKind::AND ||
|
||||
kind == CombiningKind::MINUI || kind == CombiningKind::MINSI ||
|
||||
kind == CombiningKind::MAXUI || kind == CombiningKind::MAXSI ||
|
||||
kind == CombiningKind::OR || kind == CombiningKind::XOR)
|
||||
// Already handled or only valid for integer types.
|
||||
return Optional<Value>();
|
||||
}
|
||||
return Optional<Value>(combinedResult);
|
||||
|
||||
return makeArithReduction(rewriter, loc, kind, mul, acc);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/IntegerSet.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Support/MathExtras.h"
|
||||
#include <numeric>
|
||||
|
@ -42,6 +43,56 @@ Value mlir::vector::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
|
|||
llvm_unreachable("Expected MemRefType or TensorType");
|
||||
}
|
||||
|
||||
Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
|
||||
CombiningKind kind, Value v1, Value v2) {
|
||||
Type t1 = getElementTypeOrSelf(v1.getType());
|
||||
Type t2 = getElementTypeOrSelf(v2.getType());
|
||||
switch (kind) {
|
||||
case CombiningKind::ADD:
|
||||
if (t1.isIntOrIndex() && t2.isIntOrIndex())
|
||||
return b.createOrFold<arith::AddIOp>(loc, v1, v2);
|
||||
else if (t1.isa<FloatType>() && t2.isa<FloatType>())
|
||||
return b.createOrFold<arith::AddFOp>(loc, v1, v2);
|
||||
llvm_unreachable("invalid value types for ADD reduction");
|
||||
case CombiningKind::AND:
|
||||
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
|
||||
return b.createOrFold<arith::AndIOp>(loc, v1, v2);
|
||||
case CombiningKind::MAXF:
|
||||
assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
|
||||
"expected float values");
|
||||
return b.createOrFold<arith::MaxFOp>(loc, v1, v2);
|
||||
case CombiningKind::MINF:
|
||||
assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
|
||||
"expected float values");
|
||||
return b.createOrFold<arith::MinFOp>(loc, v1, v2);
|
||||
case CombiningKind::MAXSI:
|
||||
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
|
||||
return b.createOrFold<arith::MaxSIOp>(loc, v1, v2);
|
||||
case CombiningKind::MINSI:
|
||||
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
|
||||
return b.createOrFold<arith::MinSIOp>(loc, v1, v2);
|
||||
case CombiningKind::MAXUI:
|
||||
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
|
||||
return b.createOrFold<arith::MaxUIOp>(loc, v1, v2);
|
||||
case CombiningKind::MINUI:
|
||||
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
|
||||
return b.createOrFold<arith::MinUIOp>(loc, v1, v2);
|
||||
case CombiningKind::MUL:
|
||||
if (t1.isIntOrIndex() && t2.isIntOrIndex())
|
||||
return b.createOrFold<arith::MulIOp>(loc, v1, v2);
|
||||
else if (t1.isa<FloatType>() && t2.isa<FloatType>())
|
||||
return b.createOrFold<arith::MulFOp>(loc, v1, v2);
|
||||
llvm_unreachable("invalid value types for MUL reduction");
|
||||
case CombiningKind::OR:
|
||||
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
|
||||
return b.createOrFold<arith::OrIOp>(loc, v1, v2);
|
||||
case CombiningKind::XOR:
|
||||
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
|
||||
return b.createOrFold<arith::XOrIOp>(loc, v1, v2);
|
||||
};
|
||||
llvm_unreachable("unknown CombiningKind");
|
||||
}
|
||||
|
||||
/// Return the number of elements of basis, `0` if empty.
|
||||
int64_t mlir::computeMaxLinearIndex(ArrayRef<int64_t> basis) {
|
||||
if (basis.empty())
|
||||
|
|
Loading…
Reference in New Issue