[MLIR][SPIRVToLLVM] Convert spv.constant scalars and vectors
This patch introduces conversion pattern for `spv.constant` with scalar and vector types. There is a special case when the constant value is a signed/unsigned integer (vector of integers). Since LLVM dialect does not have signedness semantics, the types had to be converted to signless ints. Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D82936
This commit is contained in:
parent
5416fc014a
commit
1cfaaf6455
|
@ -31,6 +31,15 @@ using namespace mlir;
|
|||
// Utility functions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Returns true if the given type is a signed integer or vector type.
|
||||
static bool isSignedIntegerOrVector(Type type) {
|
||||
if (type.isSignedInteger())
|
||||
return true;
|
||||
if (auto vecType = type.dyn_cast<VectorType>())
|
||||
return vecType.getElementType().isSignedInteger();
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Returns true if the given type is an unsigned integer or vector type
|
||||
static bool isUnsignedIntegerOrVector(Type type) {
|
||||
if (type.isUnsignedInteger())
|
||||
|
@ -197,6 +206,52 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
/// Converts SPIR-V ConstantOp with scalar or vector type.
|
||||
class ConstantScalarAndVectorPattern
|
||||
: public SPIRVToLLVMConversion<spirv::ConstantOp> {
|
||||
public:
|
||||
using SPIRVToLLVMConversion<spirv::ConstantOp>::SPIRVToLLVMConversion;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(spirv::ConstantOp constOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto srcType = constOp.getType();
|
||||
if (!srcType.isa<VectorType>() && !srcType.isIntOrFloat())
|
||||
return failure();
|
||||
|
||||
auto dstType = typeConverter.convertType(srcType);
|
||||
if (!dstType)
|
||||
return failure();
|
||||
|
||||
// SPIR-V constant can be a signed/unsigned integer, which has to be
|
||||
// casted to signless integer when converting to LLVM dialect. Removing the
|
||||
// sign bit may have unexpected behaviour. However, it is better to handle
|
||||
// it case-by-case, given that the purpose of the conversion is not to
|
||||
// cover all possible corner cases.
|
||||
if (isSignedIntegerOrVector(srcType) ||
|
||||
isUnsignedIntegerOrVector(srcType)) {
|
||||
auto *context = rewriter.getContext();
|
||||
auto signlessType = IntegerType::get(getBitWidth(srcType), context);
|
||||
|
||||
if (srcType.isa<VectorType>()) {
|
||||
auto dstElementsAttr = constOp.value().cast<DenseIntElementsAttr>();
|
||||
rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
|
||||
constOp, dstType,
|
||||
dstElementsAttr.mapValues(
|
||||
signlessType, [&](const APInt &value) { return value; }));
|
||||
return success();
|
||||
}
|
||||
auto srcAttr = constOp.value().cast<IntegerAttr>();
|
||||
auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue());
|
||||
rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr);
|
||||
return success();
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, operands,
|
||||
constOp.getAttrs());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Converts SPIR-V operations that have straightforward LLVM equivalent
|
||||
/// into LLVM dialect operations.
|
||||
template <typename SPIRVOp, typename LLVMOp>
|
||||
|
@ -573,6 +628,9 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
|
|||
IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
|
||||
IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
|
||||
|
||||
// Constant op
|
||||
ConstantScalarAndVectorPattern,
|
||||
|
||||
// Function Call op
|
||||
FunctionCallPattern,
|
||||
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
// RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.constant
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
func @bool_constant_scalar() {
|
||||
// CHECK: {{.*}} = llvm.mlir.constant(true) : !llvm.i1
|
||||
%0 = spv.constant true
|
||||
// CHECK: {{.*}} = llvm.mlir.constant(false) : !llvm.i1
|
||||
%1 = spv.constant false
|
||||
return
|
||||
}
|
||||
|
||||
func @bool_constant_vector() {
|
||||
// CHECK: {{.*}} = llvm.mlir.constant(dense<[true, false]> : vector<2xi1>) : !llvm<"<2 x i1>">
|
||||
%0 = constant dense<[true, false]> : vector<2xi1>
|
||||
// CHECK: {{.*}} = llvm.mlir.constant(dense<false> : vector<3xi1>) : !llvm<"<3 x i1>">
|
||||
%1 = constant dense<false> : vector<3xi1>
|
||||
return
|
||||
}
|
||||
|
||||
func @integer_constant_scalar() {
|
||||
// CHECK: {{.*}} = llvm.mlir.constant(0 : i8) : !llvm.i8
|
||||
%0 = spv.constant 0 : i8
|
||||
// CHECK: {{.*}} = llvm.mlir.constant(-5 : i64) : !llvm.i64
|
||||
%1 = spv.constant -5 : si64
|
||||
// CHECK: {{.*}} = llvm.mlir.constant(10 : i16) : !llvm.i16
|
||||
%2 = spv.constant 10 : ui16
|
||||
return
|
||||
}
|
||||
|
||||
func @integer_constant_vector() {
|
||||
// CHECK: {{.*}} = llvm.mlir.constant(dense<[2, 3]> : vector<2xi32>) : !llvm<"<2 x i32>">
|
||||
%0 = spv.constant dense<[2, 3]> : vector<2xi32>
|
||||
// CHECK: {{.*}} = llvm.mlir.constant(dense<-4> : vector<2xi32>) : !llvm<"<2 x i32>">
|
||||
%1 = spv.constant dense<-4> : vector<2xsi32>
|
||||
// CHECK: {{.*}} = llvm.mlir.constant(dense<[2, 3, 4]> : vector<3xi32>) : !llvm<"<3 x i32>">
|
||||
%2 = spv.constant dense<[2, 3, 4]> : vector<3xui32>
|
||||
return
|
||||
}
|
||||
|
||||
func @float_constant_scalar() {
|
||||
// CHECK: {{.*}} = llvm.mlir.constant(5.000000e+00 : f16) : !llvm.half
|
||||
%0 = spv.constant 5.000000e+00 : f16
|
||||
// CHECK: {{.*}} = llvm.mlir.constant(5.000000e+00 : f64) : !llvm.double
|
||||
%1 = spv.constant 5.000000e+00 : f64
|
||||
return
|
||||
}
|
||||
|
||||
func @float_constant_vector() {
|
||||
// CHECK: {{.*}} = llvm.mlir.constant(dense<[2.000000e+00, 3.000000e+00]> : vector<2xf32>) : !llvm<"<2 x float>">
|
||||
%0 = spv.constant dense<[2.000000e+00, 3.000000e+00]> : vector<2xf32>
|
||||
return
|
||||
}
|
Loading…
Reference in New Issue