[mlir][vector] Add helper that builds a scalar reduction according to CombiningKind

Differential Revision: https://reviews.llvm.org/D119433
This commit is contained in:
Matthias Springer 2022-02-10 22:21:34 +09:00
parent d038faea46
commit 9b5a3d14b2
4 changed files with 73 additions and 100 deletions

View File

@ -9,6 +9,7 @@
#ifndef MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
#define MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Support/LLVM.h"
@ -30,12 +31,14 @@ class VectorType;
class VectorTransferOpInterface;
namespace vector {
class TransferWriteOp;
class TransferReadOp;
/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
/// the type of `source`.
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim);
/// Return the result value of reducing two scalar/vector values with the
/// corresponding arith operation.
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind,
Value v1, Value v2);
} // namespace vector
/// Return the number of elements of basis, `0` if empty.

View File

@ -243,47 +243,8 @@ struct TwoDimMultiReductionToElementWise
for (int64_t i = 1; i < srcShape[0]; i++) {
auto operand =
rewriter.create<vector::ExtractOp>(loc, multiReductionOp.source(), i);
switch (multiReductionOp.kind()) {
case vector::CombiningKind::ADD:
if (elementType.isIntOrIndex())
result = rewriter.create<arith::AddIOp>(loc, operand, result);
else
result = rewriter.create<arith::AddFOp>(loc, operand, result);
break;
case vector::CombiningKind::MUL:
if (elementType.isIntOrIndex())
result = rewriter.create<arith::MulIOp>(loc, operand, result);
else
result = rewriter.create<arith::MulFOp>(loc, operand, result);
break;
case vector::CombiningKind::MINUI:
result = rewriter.create<arith::MinUIOp>(loc, operand, result);
break;
case vector::CombiningKind::MINSI:
result = rewriter.create<arith::MinSIOp>(loc, operand, result);
break;
case vector::CombiningKind::MINF:
result = rewriter.create<arith::MinFOp>(loc, operand, result);
break;
case vector::CombiningKind::MAXUI:
result = rewriter.create<arith::MaxUIOp>(loc, operand, result);
break;
case vector::CombiningKind::MAXSI:
result = rewriter.create<arith::MaxSIOp>(loc, operand, result);
break;
case vector::CombiningKind::MAXF:
result = rewriter.create<arith::MaxFOp>(loc, operand, result);
break;
case vector::CombiningKind::AND:
result = rewriter.create<arith::AndIOp>(loc, operand, result);
break;
case vector::CombiningKind::OR:
result = rewriter.create<arith::OrIOp>(loc, operand, result);
break;
case vector::CombiningKind::XOR:
result = rewriter.create<arith::XOrIOp>(loc, operand, result);
break;
}
result = makeArithReduction(rewriter, loc, multiReductionOp.kind(),
operand, result);
}
rewriter.replaceOp(multiReductionOp, result);

View File

@ -10,6 +10,8 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include <type_traits>
#include "mlir/Dialect/Affine/IR/AffineOps.h"
@ -18,8 +20,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
@ -514,40 +515,11 @@ private:
if (!acc)
return Optional<Value>(mul);
Value combinedResult;
switch (kind) {
case CombiningKind::ADD:
combinedResult = rewriter.create<arith::AddIOp>(loc, mul, acc);
break;
case CombiningKind::MUL:
combinedResult = rewriter.create<arith::MulIOp>(loc, mul, acc);
break;
case CombiningKind::MINUI:
combinedResult = rewriter.create<arith::MinUIOp>(loc, mul, acc);
break;
case CombiningKind::MINSI:
combinedResult = rewriter.create<arith::MinSIOp>(loc, mul, acc);
break;
case CombiningKind::MAXUI:
combinedResult = rewriter.create<arith::MaxUIOp>(loc, mul, acc);
break;
case CombiningKind::MAXSI:
combinedResult = rewriter.create<arith::MaxSIOp>(loc, mul, acc);
break;
case CombiningKind::AND:
combinedResult = rewriter.create<arith::AndIOp>(loc, mul, acc);
break;
case CombiningKind::OR:
combinedResult = rewriter.create<arith::OrIOp>(loc, mul, acc);
break;
case CombiningKind::XOR:
combinedResult = rewriter.create<arith::XOrIOp>(loc, mul, acc);
break;
case CombiningKind::MINF: // Only valid for floating point types.
case CombiningKind::MAXF: // Only valid for floating point types.
if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF)
// Only valid for floating point types.
return Optional<Value>();
}
return Optional<Value>(combinedResult);
return makeArithReduction(rewriter, loc, kind, mul, acc);
}
static Optional<Value> genMultF(Location loc, Value x, Value y, Value acc,
@ -565,28 +537,14 @@ private:
if (!acc)
return Optional<Value>(mul);
Value combinedResult;
switch (kind) {
case CombiningKind::MUL:
combinedResult = rewriter.create<arith::MulFOp>(loc, mul, acc);
break;
case CombiningKind::MINF:
combinedResult = rewriter.create<arith::MinFOp>(loc, mul, acc);
break;
case CombiningKind::MAXF:
combinedResult = rewriter.create<arith::MaxFOp>(loc, mul, acc);
break;
case CombiningKind::ADD: // Already handled this special case above.
case CombiningKind::AND: // Only valid for integer types.
case CombiningKind::MINUI: // Only valid for integer types.
case CombiningKind::MINSI: // Only valid for integer types.
case CombiningKind::MAXUI: // Only valid for integer types.
case CombiningKind::MAXSI: // Only valid for integer types.
case CombiningKind::OR: // Only valid for integer types.
case CombiningKind::XOR: // Only valid for integer types.
if (kind == CombiningKind::ADD || kind == CombiningKind::AND ||
kind == CombiningKind::MINUI || kind == CombiningKind::MINSI ||
kind == CombiningKind::MAXUI || kind == CombiningKind::MAXSI ||
kind == CombiningKind::OR || kind == CombiningKind::XOR)
// Already handled or only valid for integer types.
return Optional<Value>();
}
return Optional<Value>(combinedResult);
return makeArithReduction(rewriter, loc, kind, mul, acc);
}
};

View File

@ -22,6 +22,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/MathExtras.h"
#include <numeric>
@ -42,6 +43,56 @@ Value mlir::vector::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
llvm_unreachable("Expected MemRefType or TensorType");
}
Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
CombiningKind kind, Value v1, Value v2) {
Type t1 = getElementTypeOrSelf(v1.getType());
Type t2 = getElementTypeOrSelf(v2.getType());
switch (kind) {
case CombiningKind::ADD:
if (t1.isIntOrIndex() && t2.isIntOrIndex())
return b.createOrFold<arith::AddIOp>(loc, v1, v2);
else if (t1.isa<FloatType>() && t2.isa<FloatType>())
return b.createOrFold<arith::AddFOp>(loc, v1, v2);
llvm_unreachable("invalid value types for ADD reduction");
case CombiningKind::AND:
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
return b.createOrFold<arith::AndIOp>(loc, v1, v2);
case CombiningKind::MAXF:
assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
"expected float values");
return b.createOrFold<arith::MaxFOp>(loc, v1, v2);
case CombiningKind::MINF:
assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
"expected float values");
return b.createOrFold<arith::MinFOp>(loc, v1, v2);
case CombiningKind::MAXSI:
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
return b.createOrFold<arith::MaxSIOp>(loc, v1, v2);
case CombiningKind::MINSI:
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
return b.createOrFold<arith::MinSIOp>(loc, v1, v2);
case CombiningKind::MAXUI:
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
return b.createOrFold<arith::MaxUIOp>(loc, v1, v2);
case CombiningKind::MINUI:
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
return b.createOrFold<arith::MinUIOp>(loc, v1, v2);
case CombiningKind::MUL:
if (t1.isIntOrIndex() && t2.isIntOrIndex())
return b.createOrFold<arith::MulIOp>(loc, v1, v2);
else if (t1.isa<FloatType>() && t2.isa<FloatType>())
return b.createOrFold<arith::MulFOp>(loc, v1, v2);
llvm_unreachable("invalid value types for MUL reduction");
case CombiningKind::OR:
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
return b.createOrFold<arith::OrIOp>(loc, v1, v2);
case CombiningKind::XOR:
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
return b.createOrFold<arith::XOrIOp>(loc, v1, v2);
};
llvm_unreachable("unknown CombiningKind");
}
/// Return the number of elements of basis, `0` if empty.
int64_t mlir::computeMaxLinearIndex(ArrayRef<int64_t> basis) {
if (basis.empty())