[Matrix] Adjust matrix pointer type for inline asm arguments.

Matrix types in memory are represented as arrays, but accessed through
vector pointers, with the alignment specified on the access operation.

For inline assembly, update pointer arguments to use vector pointers.
Otherwise there will be a mis-match if the matrix is also an
input-argument which is represented as vector.

Reviewed By: nickdesaulniers

Differential Revision: https://reviews.llvm.org/D91631
This commit is contained in:
Florian Hahn 2020-11-18 11:32:45 +00:00
parent 096bd9b293
commit 680931af27
No known key found for this signature in database
GPG Key ID: 61D7554B5CECDC0D
2 changed files with 19 additions and 6 deletions

View File

@ -2306,8 +2306,21 @@ void CodeGenFunction::EmitAsmStmt(const AsmStmt &S) {
std::max((uint64_t)LargestVectorWidth,
VT->getPrimitiveSizeInBits().getKnownMinSize());
} else {
ArgTypes.push_back(Dest.getAddress(*this).getType());
Args.push_back(Dest.getPointer(*this));
llvm::Type *DestAddrTy = Dest.getAddress(*this).getType();
llvm::Value *DestPtr = Dest.getPointer(*this);
// Matrix types in memory are represented by arrays, but accessed through
// vector pointers, with the alignment specified on the access operation.
// For inline assembly, update pointer arguments to use vector pointers.
// Otherwise there will be a mis-match if the matrix is also an
// input-argument which is represented as vector.
if (isa<MatrixType>(OutExpr->getType().getCanonicalType())) {
DestAddrTy = llvm::PointerType::get(
ConvertType(OutExpr->getType()),
cast<llvm::PointerType>(DestAddrTy)->getAddressSpace());
DestPtr = Builder.CreateBitCast(DestPtr, DestAddrTy);
}
ArgTypes.push_back(DestAddrTy);
Args.push_back(DestPtr);
Constraints += "=*";
Constraints += OutputConstraint;
ReadOnly = ReadNone = false;

View File

@ -162,10 +162,10 @@ void matrix_inline_asm_memory_readwrite() {
// CHECK-LABEL: define void @matrix_inline_asm_memory_readwrite()
// CHECK-NEXT: entry:
// CHECK-NEXT: [[ALLOCA:%.+]] = alloca [16 x double], align 8
// CHECK-NEXT: [[PTR:%.+]] = bitcast [16 x double]* [[ALLOCA]] to <16 x double>*
// CHECK-NEXT: [[VAL:%.+]] = load <16 x double>, <16 x double>* [[PTR]], align 8
// FIXME: Pointer element type does not match the vector type.
// CHECK-NEXT: call void asm sideeffect "", "=*r|m,0,~{memory},~{dirflag},~{fpsr},~{flags}"([16 x double]* [[ALLOCA]], <16 x double> [[VAL]])
// CHECK-NEXT: [[PTR1:%.+]] = bitcast [16 x double]* [[ALLOCA]] to <16 x double>*
// CHECK-NEXT: [[PTR2:%.+]] = bitcast [16 x double]* [[ALLOCA]] to <16 x double>*
// CHECK-NEXT: [[VAL:%.+]] = load <16 x double>, <16 x double>* [[PTR2]], align 8
// CHECK-NEXT: call void asm sideeffect "", "=*r|m,0,~{memory},~{dirflag},~{fpsr},~{flags}"(<16 x double>* [[PTR1]], <16 x double> [[VAL]])
// CHECK-NEXT: ret void
dx4x4_t m;