[MLIR][SPIRVToLLVM] Convert spv.constant scalars and vectors

This patch introduces conversion pattern for `spv.constant` with scalar
and vector types. There is a special case when the constant value is a
signed/unsigned integer (vector of integers). Since LLVM dialect does not
have signedness semantics, the types had to be converted to signless ints.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D82936
This commit is contained in:
George Mitenkov 2020-07-02 14:21:35 -04:00 committed by Lei Zhang
parent 5416fc014a
commit 1cfaaf6455
2 changed files with 113 additions and 0 deletions

View File

@ -31,6 +31,15 @@ using namespace mlir;
// Utility functions // Utility functions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
/// Returns true if the given type is a signed integer or vector type.
static bool isSignedIntegerOrVector(Type type) {
if (type.isSignedInteger())
return true;
if (auto vecType = type.dyn_cast<VectorType>())
return vecType.getElementType().isSignedInteger();
return false;
}
/// Returns true if the given type is an unsigned integer or vector type /// Returns true if the given type is an unsigned integer or vector type
static bool isUnsignedIntegerOrVector(Type type) { static bool isUnsignedIntegerOrVector(Type type) {
if (type.isUnsignedInteger()) if (type.isUnsignedInteger())
@ -197,6 +206,52 @@ public:
} }
}; };
/// Converts SPIR-V ConstantOp with scalar or vector type.
class ConstantScalarAndVectorPattern
: public SPIRVToLLVMConversion<spirv::ConstantOp> {
public:
using SPIRVToLLVMConversion<spirv::ConstantOp>::SPIRVToLLVMConversion;
LogicalResult
matchAndRewrite(spirv::ConstantOp constOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto srcType = constOp.getType();
if (!srcType.isa<VectorType>() && !srcType.isIntOrFloat())
return failure();
auto dstType = typeConverter.convertType(srcType);
if (!dstType)
return failure();
// SPIR-V constant can be a signed/unsigned integer, which has to be
// casted to signless integer when converting to LLVM dialect. Removing the
// sign bit may have unexpected behaviour. However, it is better to handle
// it case-by-case, given that the purpose of the conversion is not to
// cover all possible corner cases.
if (isSignedIntegerOrVector(srcType) ||
isUnsignedIntegerOrVector(srcType)) {
auto *context = rewriter.getContext();
auto signlessType = IntegerType::get(getBitWidth(srcType), context);
if (srcType.isa<VectorType>()) {
auto dstElementsAttr = constOp.value().cast<DenseIntElementsAttr>();
rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
constOp, dstType,
dstElementsAttr.mapValues(
signlessType, [&](const APInt &value) { return value; }));
return success();
}
auto srcAttr = constOp.value().cast<IntegerAttr>();
auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue());
rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr);
return success();
}
rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, operands,
constOp.getAttrs());
return success();
}
};
/// Converts SPIR-V operations that have straightforward LLVM equivalent /// Converts SPIR-V operations that have straightforward LLVM equivalent
/// into LLVM dialect operations. /// into LLVM dialect operations.
template <typename SPIRVOp, typename LLVMOp> template <typename SPIRVOp, typename LLVMOp>
@ -573,6 +628,9 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>, IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>, IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
// Constant op
ConstantScalarAndVectorPattern,
// Function Call op // Function Call op
FunctionCallPattern, FunctionCallPattern,

View File

@ -0,0 +1,55 @@
// RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s
//===----------------------------------------------------------------------===//
// spv.constant
//===----------------------------------------------------------------------===//
func @bool_constant_scalar() {
// CHECK: {{.*}} = llvm.mlir.constant(true) : !llvm.i1
%0 = spv.constant true
// CHECK: {{.*}} = llvm.mlir.constant(false) : !llvm.i1
%1 = spv.constant false
return
}
func @bool_constant_vector() {
// CHECK: {{.*}} = llvm.mlir.constant(dense<[true, false]> : vector<2xi1>) : !llvm<"<2 x i1>">
%0 = constant dense<[true, false]> : vector<2xi1>
// CHECK: {{.*}} = llvm.mlir.constant(dense<false> : vector<3xi1>) : !llvm<"<3 x i1>">
%1 = constant dense<false> : vector<3xi1>
return
}
func @integer_constant_scalar() {
// CHECK: {{.*}} = llvm.mlir.constant(0 : i8) : !llvm.i8
%0 = spv.constant 0 : i8
// CHECK: {{.*}} = llvm.mlir.constant(-5 : i64) : !llvm.i64
%1 = spv.constant -5 : si64
// CHECK: {{.*}} = llvm.mlir.constant(10 : i16) : !llvm.i16
%2 = spv.constant 10 : ui16
return
}
func @integer_constant_vector() {
// CHECK: {{.*}} = llvm.mlir.constant(dense<[2, 3]> : vector<2xi32>) : !llvm<"<2 x i32>">
%0 = spv.constant dense<[2, 3]> : vector<2xi32>
// CHECK: {{.*}} = llvm.mlir.constant(dense<-4> : vector<2xi32>) : !llvm<"<2 x i32>">
%1 = spv.constant dense<-4> : vector<2xsi32>
// CHECK: {{.*}} = llvm.mlir.constant(dense<[2, 3, 4]> : vector<3xi32>) : !llvm<"<3 x i32>">
%2 = spv.constant dense<[2, 3, 4]> : vector<3xui32>
return
}
func @float_constant_scalar() {
// CHECK: {{.*}} = llvm.mlir.constant(5.000000e+00 : f16) : !llvm.half
%0 = spv.constant 5.000000e+00 : f16
// CHECK: {{.*}} = llvm.mlir.constant(5.000000e+00 : f64) : !llvm.double
%1 = spv.constant 5.000000e+00 : f64
return
}
func @float_constant_vector() {
// CHECK: {{.*}} = llvm.mlir.constant(dense<[2.000000e+00, 3.000000e+00]> : vector<2xf32>) : !llvm<"<2 x float>">
%0 = spv.constant dense<[2.000000e+00, 3.000000e+00]> : vector<2xf32>
return
}