[MLIR][SPIRVToLLVM] Conversion of SPIR-V array, runtime array, and pointer types

This patch adds type conversion for 4 SPIR-V types: array, runtime array, pointer
and struct. This conversion is integrated using a separate function
`populateSPIRVToLLVMTypeConversion()` that adds new type conversions. At the moment,
this is a basic skeleton that allows to perfom conversion from SPIR-V array,
runtime array and pointer types to LLVM typesystem. There is no support of array
strides or storage classes. These will be supported on the case by case basis.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D83399
This commit is contained in:
George Mitenkov 2020-07-09 17:48:50 +03:00
parent 4d4d903767
commit 28cd3cbc12
6 changed files with 97 additions and 0 deletions

View File

@ -32,6 +32,9 @@ protected:
LLVMTypeConverter &typeConverter;
};
/// Populates type conversions with additional SPIR-V types.
void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter);
/// Populates the given list with patterns that convert from SPIR-V to LLVM.
void populateSPIRVToLLVMConversionPatterns(MLIRContext *context,
LLVMTypeConverter &typeConverter,

View File

@ -165,6 +165,43 @@ static Value processCountOrOffset(Location loc, Value value, Type srcType,
return optionallyTruncateOrExtend(loc, broadcasted, dstType, rewriter);
}
//===----------------------------------------------------------------------===//
// Type conversion
//===----------------------------------------------------------------------===//
/// Converts SPIR-V array type to LLVM array. There is no modelling of array
/// stride at the moment.
static Optional<Type> convertArrayType(spirv::ArrayType type,
TypeConverter &converter) {
if (type.getArrayStride() != 0)
return llvm::None;
auto elementType =
converter.convertType(type.getElementType()).cast<LLVM::LLVMType>();
unsigned numElements = type.getNumElements();
return LLVM::LLVMType::getArrayTy(elementType, numElements);
}
/// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not
/// modelled at the moment.
static Type convertPointerType(spirv::PointerType type,
TypeConverter &converter) {
auto pointeeType =
converter.convertType(type.getPointeeType()).cast<LLVM::LLVMType>();
return pointeeType.getPointerTo();
}
/// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over
/// the bounds, the runtime array is converted to a 0-sized LLVM array. There is
/// no modelling of array stride at the moment.
static Optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type,
TypeConverter &converter) {
if (type.getArrayStride() != 0)
return llvm::None;
auto elementType =
converter.convertType(type.getElementType()).cast<LLVM::LLVMType>();
return LLVM::LLVMType::getArrayTy(elementType, 0);
}
//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//
@ -581,6 +618,8 @@ public:
funcType.getNumInputs());
auto llvmType = this->typeConverter.convertFunctionSignature(
funcOp.getType(), /*isVariadic=*/false, signatureConverter);
if (!llvmType)
return failure();
// Create a new `LLVMFuncOp`
Location loc = funcOp.getLoc();
@ -662,6 +701,18 @@ public:
// Pattern population
//===----------------------------------------------------------------------===//
void mlir::populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter) {
typeConverter.addConversion([&](spirv::ArrayType type) {
return convertArrayType(type, typeConverter);
});
typeConverter.addConversion([&](spirv::PointerType type) {
return convertPointerType(type, typeConverter);
});
typeConverter.addConversion([&](spirv::RuntimeArrayType type) {
return convertRuntimeArrayType(type, typeConverter);
});
}
void mlir::populateSPIRVToLLVMConversionPatterns(
MLIRContext *context, LLVMTypeConverter &typeConverter,
OwningRewritePatternList &patterns) {

View File

@ -34,6 +34,9 @@ void ConvertSPIRVToLLVMPass::runOnOperation() {
LLVMTypeConverter converter(&getContext());
OwningRewritePatternList patterns;
populateSPIRVToLLVMTypeConversion(converter);
populateSPIRVToLLVMModuleConversionPatterns(context, converter, patterns);
populateSPIRVToLLVMConversionPatterns(context, converter, patterns);
populateSPIRVToLLVMFunctionConversionPatterns(context, converter, patterns);

View File

@ -34,6 +34,12 @@ func @bitcast_vector_to_vector(%arg0 : vector<4xf32>) {
return
}
func @bitcast_pointer(%arg0: !spv.ptr<f32, Function>) {
// CHECK: %{{.*}} = llvm.bitcast %{{.*}} : !llvm<"float*"> to !llvm<"i32*">
%0 = spv.Bitcast %arg0 : !spv.ptr<f32, Function> to !spv.ptr<i32, Function>
return
}
//===----------------------------------------------------------------------===//
// spv.ConvertFToS
//===----------------------------------------------------------------------===//

View File

@ -0,0 +1,6 @@
// RUN: mlir-opt %s -convert-spirv-to-llvm -verify-diagnostics -split-input-file
// expected-error@+1 {{failed to legalize operation 'spv.func' that was explicitly marked illegal}}
spv.func @array_with_stride(%arg: !spv.array<4 x f32, stride=4>) -> () "None" {
spv.Return
}

View File

@ -0,0 +1,28 @@
// RUN: mlir-opt -split-input-file -convert-spirv-to-llvm -verify-diagnostics %s | FileCheck %s
//===----------------------------------------------------------------------===//
// Array type
//===----------------------------------------------------------------------===//
// CHECK-LABEL: @array(!llvm<"[16 x float]">, !llvm<"[32 x <4 x float>]">)
func @array(!spv.array<16xf32>, !spv.array< 32 x vector<4xf32> >) -> ()
//===----------------------------------------------------------------------===//
// Pointer type
//===----------------------------------------------------------------------===//
// CHECK-LABEL: @pointer_scalar(!llvm<"i1*">, !llvm<"float*">)
func @pointer_scalar(!spv.ptr<i1, Uniform>, !spv.ptr<f32, Private>) -> ()
// CHECK-LABEL: @pointer_vector(!llvm<"<4 x i32>*">)
func @pointer_vector(!spv.ptr<vector<4xi32>, Function>) -> ()
//===----------------------------------------------------------------------===//
// Runtime array type
//===----------------------------------------------------------------------===//
// CHECK-LABEL: @runtime_array_vector(!llvm<"[0 x <4 x float>]">)
func @runtime_array_vector(!spv.rtarray< vector<4xf32> >) -> ()
// CHECK-LABEL: @runtime_array_scalar(!llvm<"[0 x float]">)
func @runtime_array_scalar(!spv.rtarray<f32>) -> ()