[MLIR][SPIRVToLLVM] Conversion of SPIR-V array, runtime array, and pointer types
This patch adds type conversion for 4 SPIR-V types: array, runtime array, pointer and struct. This conversion is integrated using a separate function `populateSPIRVToLLVMTypeConversion()` that adds new type conversions. At the moment, this is a basic skeleton that allows to perfom conversion from SPIR-V array, runtime array and pointer types to LLVM typesystem. There is no support of array strides or storage classes. These will be supported on the case by case basis. Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D83399
This commit is contained in:
parent
4d4d903767
commit
28cd3cbc12
|
@ -32,6 +32,9 @@ protected:
|
||||||
LLVMTypeConverter &typeConverter;
|
LLVMTypeConverter &typeConverter;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// Populates type conversions with additional SPIR-V types.
|
||||||
|
void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter);
|
||||||
|
|
||||||
/// Populates the given list with patterns that convert from SPIR-V to LLVM.
|
/// Populates the given list with patterns that convert from SPIR-V to LLVM.
|
||||||
void populateSPIRVToLLVMConversionPatterns(MLIRContext *context,
|
void populateSPIRVToLLVMConversionPatterns(MLIRContext *context,
|
||||||
LLVMTypeConverter &typeConverter,
|
LLVMTypeConverter &typeConverter,
|
||||||
|
|
|
@ -165,6 +165,43 @@ static Value processCountOrOffset(Location loc, Value value, Type srcType,
|
||||||
return optionallyTruncateOrExtend(loc, broadcasted, dstType, rewriter);
|
return optionallyTruncateOrExtend(loc, broadcasted, dstType, rewriter);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Type conversion
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
/// Converts SPIR-V array type to LLVM array. There is no modelling of array
|
||||||
|
/// stride at the moment.
|
||||||
|
static Optional<Type> convertArrayType(spirv::ArrayType type,
|
||||||
|
TypeConverter &converter) {
|
||||||
|
if (type.getArrayStride() != 0)
|
||||||
|
return llvm::None;
|
||||||
|
auto elementType =
|
||||||
|
converter.convertType(type.getElementType()).cast<LLVM::LLVMType>();
|
||||||
|
unsigned numElements = type.getNumElements();
|
||||||
|
return LLVM::LLVMType::getArrayTy(elementType, numElements);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not
|
||||||
|
/// modelled at the moment.
|
||||||
|
static Type convertPointerType(spirv::PointerType type,
|
||||||
|
TypeConverter &converter) {
|
||||||
|
auto pointeeType =
|
||||||
|
converter.convertType(type.getPointeeType()).cast<LLVM::LLVMType>();
|
||||||
|
return pointeeType.getPointerTo();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over
|
||||||
|
/// the bounds, the runtime array is converted to a 0-sized LLVM array. There is
|
||||||
|
/// no modelling of array stride at the moment.
|
||||||
|
static Optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type,
|
||||||
|
TypeConverter &converter) {
|
||||||
|
if (type.getArrayStride() != 0)
|
||||||
|
return llvm::None;
|
||||||
|
auto elementType =
|
||||||
|
converter.convertType(type.getElementType()).cast<LLVM::LLVMType>();
|
||||||
|
return LLVM::LLVMType::getArrayTy(elementType, 0);
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Operation conversion
|
// Operation conversion
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -581,6 +618,8 @@ public:
|
||||||
funcType.getNumInputs());
|
funcType.getNumInputs());
|
||||||
auto llvmType = this->typeConverter.convertFunctionSignature(
|
auto llvmType = this->typeConverter.convertFunctionSignature(
|
||||||
funcOp.getType(), /*isVariadic=*/false, signatureConverter);
|
funcOp.getType(), /*isVariadic=*/false, signatureConverter);
|
||||||
|
if (!llvmType)
|
||||||
|
return failure();
|
||||||
|
|
||||||
// Create a new `LLVMFuncOp`
|
// Create a new `LLVMFuncOp`
|
||||||
Location loc = funcOp.getLoc();
|
Location loc = funcOp.getLoc();
|
||||||
|
@ -662,6 +701,18 @@ public:
|
||||||
// Pattern population
|
// Pattern population
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
void mlir::populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter) {
|
||||||
|
typeConverter.addConversion([&](spirv::ArrayType type) {
|
||||||
|
return convertArrayType(type, typeConverter);
|
||||||
|
});
|
||||||
|
typeConverter.addConversion([&](spirv::PointerType type) {
|
||||||
|
return convertPointerType(type, typeConverter);
|
||||||
|
});
|
||||||
|
typeConverter.addConversion([&](spirv::RuntimeArrayType type) {
|
||||||
|
return convertRuntimeArrayType(type, typeConverter);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
void mlir::populateSPIRVToLLVMConversionPatterns(
|
void mlir::populateSPIRVToLLVMConversionPatterns(
|
||||||
MLIRContext *context, LLVMTypeConverter &typeConverter,
|
MLIRContext *context, LLVMTypeConverter &typeConverter,
|
||||||
OwningRewritePatternList &patterns) {
|
OwningRewritePatternList &patterns) {
|
||||||
|
|
|
@ -34,6 +34,9 @@ void ConvertSPIRVToLLVMPass::runOnOperation() {
|
||||||
LLVMTypeConverter converter(&getContext());
|
LLVMTypeConverter converter(&getContext());
|
||||||
|
|
||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
|
|
||||||
|
populateSPIRVToLLVMTypeConversion(converter);
|
||||||
|
|
||||||
populateSPIRVToLLVMModuleConversionPatterns(context, converter, patterns);
|
populateSPIRVToLLVMModuleConversionPatterns(context, converter, patterns);
|
||||||
populateSPIRVToLLVMConversionPatterns(context, converter, patterns);
|
populateSPIRVToLLVMConversionPatterns(context, converter, patterns);
|
||||||
populateSPIRVToLLVMFunctionConversionPatterns(context, converter, patterns);
|
populateSPIRVToLLVMFunctionConversionPatterns(context, converter, patterns);
|
||||||
|
|
|
@ -34,6 +34,12 @@ func @bitcast_vector_to_vector(%arg0 : vector<4xf32>) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func @bitcast_pointer(%arg0: !spv.ptr<f32, Function>) {
|
||||||
|
// CHECK: %{{.*}} = llvm.bitcast %{{.*}} : !llvm<"float*"> to !llvm<"i32*">
|
||||||
|
%0 = spv.Bitcast %arg0 : !spv.ptr<f32, Function> to !spv.ptr<i32, Function>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// spv.ConvertFToS
|
// spv.ConvertFToS
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -0,0 +1,6 @@
|
||||||
|
// RUN: mlir-opt %s -convert-spirv-to-llvm -verify-diagnostics -split-input-file
|
||||||
|
|
||||||
|
// expected-error@+1 {{failed to legalize operation 'spv.func' that was explicitly marked illegal}}
|
||||||
|
spv.func @array_with_stride(%arg: !spv.array<4 x f32, stride=4>) -> () "None" {
|
||||||
|
spv.Return
|
||||||
|
}
|
|
@ -0,0 +1,28 @@
|
||||||
|
// RUN: mlir-opt -split-input-file -convert-spirv-to-llvm -verify-diagnostics %s | FileCheck %s
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Array type
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// CHECK-LABEL: @array(!llvm<"[16 x float]">, !llvm<"[32 x <4 x float>]">)
|
||||||
|
func @array(!spv.array<16xf32>, !spv.array< 32 x vector<4xf32> >) -> ()
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Pointer type
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// CHECK-LABEL: @pointer_scalar(!llvm<"i1*">, !llvm<"float*">)
|
||||||
|
func @pointer_scalar(!spv.ptr<i1, Uniform>, !spv.ptr<f32, Private>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: @pointer_vector(!llvm<"<4 x i32>*">)
|
||||||
|
func @pointer_vector(!spv.ptr<vector<4xi32>, Function>) -> ()
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Runtime array type
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// CHECK-LABEL: @runtime_array_vector(!llvm<"[0 x <4 x float>]">)
|
||||||
|
func @runtime_array_vector(!spv.rtarray< vector<4xf32> >) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: @runtime_array_scalar(!llvm<"[0 x float]">)
|
||||||
|
func @runtime_array_scalar(!spv.rtarray<f32>) -> ()
|
Loading…
Reference in New Issue