[mlir] Add alignment attribute to LLVM memory ops and use in vector.transfer

Summary: When lowering a vector.transfer to the underlying load/store operation, the native alignment cannot be assumed in general. This revision adds an optional alignment attribute to the LLVM load and store ops and uses it so that the unmasked load/store path gets the same alignment as the masked path.

Differential Revision: https://reviews.llvm.org/D83684
This commit is contained in:
Nicolas Vasilache 2020-07-13 11:04:09 -04:00
parent b9c2dd11a5
commit affbc0cd1c
5 changed files with 99 additions and 35 deletions

View File

@ -215,19 +215,36 @@ def LLVM_FDivOp : LLVM_ArithmeticOp<"fdiv", "CreateFDiv">;
def LLVM_FRemOp : LLVM_ArithmeticOp<"frem", "CreateFRem">; def LLVM_FRemOp : LLVM_ArithmeticOp<"frem", "CreateFRem">;
def LLVM_FNegOp : LLVM_UnaryArithmeticOp<"fneg", "CreateFNeg">; def LLVM_FNegOp : LLVM_UnaryArithmeticOp<"fneg", "CreateFNeg">;
// Memory-related operations. // Common code definition that is used to verify and set the alignment attribute
def LLVM_AllocaOp : // of LLVM ops that accept such an attribute.
LLVM_OneResultOp<"alloca">, class MemoryOpWithAlignmentBase {
Arguments<(ins LLVM_Type:$arraySize, OptionalAttr<I64Attr>:$alignment)> { code alignmentVerifierCode = [{
string llvmBuilder = [{ if (alignment().hasValue()) {
auto *alloca = builder.CreateAlloca( auto align = alignment().getValue().getSExtValue();
$_resultType->getPointerElementType(), $arraySize); if (align < 0)
return emitOpError("expected positive alignment");
}
return success();
}];
code setAlignmentCode = [{
if ($alignment.hasValue()) { if ($alignment.hasValue()) {
auto align = $alignment.getValue().getZExtValue(); auto align = $alignment.getValue().getZExtValue();
if (align != 0) if (align != 0)
alloca->setAlignment(llvm::Align(align)); inst->setAlignment(llvm::Align(align));
} }
$res = alloca; }];
}
// Memory-related operations.
def LLVM_AllocaOp :
MemoryOpWithAlignmentBase,
LLVM_OneResultOp<"alloca">,
Arguments<(ins LLVM_Type:$arraySize, OptionalAttr<I64Attr>:$alignment)> {
string llvmBuilder = [{
auto *inst = builder.CreateAlloca(
$_resultType->getPointerElementType(), $arraySize);
}] # setAlignmentCode # [{
$res = inst;
}]; }];
let builders = [OpBuilder< let builders = [OpBuilder<
"OpBuilder &b, OperationState &result, Type resultType, Value arraySize, " "OpBuilder &b, OperationState &result, Type resultType, Value arraySize, "
@ -239,14 +256,7 @@ def LLVM_AllocaOp :
}]>]; }]>];
let parser = [{ return parseAllocaOp(parser, result); }]; let parser = [{ return parseAllocaOp(parser, result); }];
let printer = [{ printAllocaOp(p, *this); }]; let printer = [{ printAllocaOp(p, *this); }];
let verifier = [{ let verifier = alignmentVerifierCode;
if (alignment().hasValue()) {
auto align = alignment().getValue().getSExtValue();
if (align < 0)
return emitOpError("expected positive alignment");
}
return success();
}];
} }
def LLVM_GEPOp : LLVM_OneResultOp<"getelementptr", [NoSideEffect]>, def LLVM_GEPOp : LLVM_OneResultOp<"getelementptr", [NoSideEffect]>,
Arguments<(ins LLVM_Type:$base, Variadic<LLVM_Type>:$indices)>, Arguments<(ins LLVM_Type:$base, Variadic<LLVM_Type>:$indices)>,
@ -255,22 +265,56 @@ def LLVM_GEPOp : LLVM_OneResultOp<"getelementptr", [NoSideEffect]>,
$base `[` $indices `]` attr-dict `:` functional-type(operands, results) $base `[` $indices `]` attr-dict `:` functional-type(operands, results)
}]; }];
} }
def LLVM_LoadOp : LLVM_OneResultOp<"load">, Arguments<(ins LLVM_Type:$addr)>, def LLVM_LoadOp :
LLVM_Builder<"$res = builder.CreateLoad($addr);"> { MemoryOpWithAlignmentBase,
LLVM_OneResultOp<"load">,
Arguments<(ins LLVM_Type:$addr, OptionalAttr<I64Attr>:$alignment)> {
string llvmBuilder = [{
auto *inst = builder.CreateLoad($addr);
}] # setAlignmentCode # [{
$res = inst;
}];
let builders = [OpBuilder< let builders = [OpBuilder<
"OpBuilder &b, OperationState &result, Value addr", "OpBuilder &b, OperationState &result, Value addr, unsigned alignment = 0",
[{ [{
auto type = addr.getType().cast<LLVM::LLVMType>().getPointerElementTy(); auto type = addr.getType().cast<LLVM::LLVMType>().getPointerElementTy();
build(b, result, type, addr); build(b, result, type, addr, alignment);
}]>,
OpBuilder<
"OpBuilder &b, OperationState &result, Type t, Value addr, "
"unsigned alignment = 0",
[{
if (alignment == 0)
return build(b, result, t, addr, IntegerAttr());
build(b, result, t, addr, b.getI64IntegerAttr(alignment));
}]>]; }]>];
let parser = [{ return parseLoadOp(parser, result); }]; let parser = [{ return parseLoadOp(parser, result); }];
let printer = [{ printLoadOp(p, *this); }]; let printer = [{ printLoadOp(p, *this); }];
let verifier = alignmentVerifierCode;
} }
def LLVM_StoreOp : LLVM_ZeroResultOp<"store">, def LLVM_StoreOp :
Arguments<(ins LLVM_Type:$value, LLVM_Type:$addr)>, MemoryOpWithAlignmentBase,
LLVM_Builder<"builder.CreateStore($value, $addr);"> { LLVM_ZeroResultOp<"store">,
Arguments<(ins LLVM_Type:$value,
LLVM_Type:$addr,
OptionalAttr<I64Attr>:$alignment)> {
string llvmBuilder = [{
auto *inst = builder.CreateStore($value, $addr);
}] # setAlignmentCode;
let builders = [
OpBuilder<
"OpBuilder &b, OperationState &result, Value value, Value addr, "
"unsigned alignment = 0",
[{
if (alignment == 0)
return build(b, result, ArrayRef<Type>{}, value, addr, IntegerAttr());
build(b, result, ArrayRef<Type>{}, value, addr,
b.getI64IntegerAttr(alignment));
}]
>];
let parser = [{ return parseStoreOp(parser, result); }]; let parser = [{ return parseStoreOp(parser, result); }];
let printer = [{ printStoreOp(p, *this); }]; let printer = [{ printStoreOp(p, *this); }];
let verifier = alignmentVerifierCode;
} }
// Casts. // Casts.

View File

@ -12,6 +12,15 @@ func @transfer_read_1d(%A : memref<?xf32>, %base: index) {
return return
} }
func @transfer_read_unmasked_4(%A : memref<?xf32>, %base: index) {
%fm42 = constant -42.0: f32
%f = vector.transfer_read %A[%base], %fm42
{permutation_map = affine_map<(d0) -> (d0)>, masked = [false]} :
memref<?xf32>, vector<4xf32>
vector.print %f: vector<4xf32>
return
}
func @transfer_write_1d(%A : memref<?xf32>, %base: index) { func @transfer_write_1d(%A : memref<?xf32>, %base: index) {
%f0 = constant 0.0 : f32 %f0 = constant 0.0 : f32
%vf0 = splat %f0 : vector<4xf32> %vf0 = splat %f0 : vector<4xf32>
@ -44,8 +53,12 @@ func @entry() {
// Read shifted by 0 and pad with -42: // Read shifted by 0 and pad with -42:
// ( 0, 1, 2, 0, 0, -42, ..., -42) // ( 0, 1, 2, 0, 0, -42, ..., -42)
call @transfer_read_1d(%A, %c0) : (memref<?xf32>, index) -> () call @transfer_read_1d(%A, %c0) : (memref<?xf32>, index) -> ()
// Read unmasked 4 @ 1, guaranteed to not overflow.
// Exercises proper alignment.
call @transfer_read_unmasked_4(%A, %c1) : (memref<?xf32>, index) -> ()
return return
} }
// CHECK: ( 2, 3, 4, -42, -42, -42, -42, -42, -42, -42, -42, -42, -42 ) // CHECK: ( 2, 3, 4, -42, -42, -42, -42, -42, -42, -42, -42, -42, -42 )
// CHECK: ( 0, 1, 2, 0, 0, -42, -42, -42, -42, -42, -42, -42, -42 ) // CHECK: ( 0, 1, 2, 0, 0, -42, -42, -42, -42, -42, -42, -42, -42 )
// CHECK: ( 1, 2, 0, 0 )

View File

@ -3,11 +3,11 @@
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s // RUN: FileCheck %s
func @transfer_write16_1d(%A : memref<?xf32>, %base: index) { func @transfer_write16_unmasked_1d(%A : memref<?xf32>, %base: index) {
%f = constant 16.0 : f32 %f = constant 16.0 : f32
%v = splat %f : vector<16xf32> %v = splat %f : vector<16xf32>
vector.transfer_write %v, %A[%base] vector.transfer_write %v, %A[%base]
{permutation_map = affine_map<(d0) -> (d0)>} {permutation_map = affine_map<(d0) -> (d0)>, masked = [false]}
: vector<16xf32>, memref<?xf32> : vector<16xf32>, memref<?xf32>
return return
} }
@ -53,14 +53,14 @@ func @entry() {
%0 = call @transfer_read_1d(%A) : (memref<?xf32>) -> (vector<32xf32>) %0 = call @transfer_read_1d(%A) : (memref<?xf32>) -> (vector<32xf32>)
vector.print %0 : vector<32xf32> vector.print %0 : vector<32xf32>
// Overwrite with 16 values of 16 at base 4. // Overwrite with 16 values of 16 at base 3.
%c4 = constant 4: index // Statically guaranteed to be unmasked. Exercises proper alignment.
call @transfer_write16_1d(%A, %c4) : (memref<?xf32>, index) -> () %c3 = constant 3: index
call @transfer_write16_unmasked_1d(%A, %c3) : (memref<?xf32>, index) -> ()
%1 = call @transfer_read_1d(%A) : (memref<?xf32>) -> (vector<32xf32>) %1 = call @transfer_read_1d(%A) : (memref<?xf32>) -> (vector<32xf32>)
vector.print %1 : vector<32xf32> vector.print %1 : vector<32xf32>
// Overwrite with 13 values of 13 at base 3. // Overwrite with 13 values of 13 at base 3.
%c3 = constant 3: index
call @transfer_write13_1d(%A, %c3) : (memref<?xf32>, index) -> () call @transfer_write13_1d(%A, %c3) : (memref<?xf32>, index) -> ()
%2 = call @transfer_read_1d(%A) : (memref<?xf32>) -> (vector<32xf32>) %2 = call @transfer_read_1d(%A) : (memref<?xf32>) -> (vector<32xf32>)
vector.print %2 : vector<32xf32> vector.print %2 : vector<32xf32>
@ -93,8 +93,8 @@ func @entry() {
} }
// CHECK: ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) // CHECK: ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
// CHECK: ( 0, 0, 0, 0, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) // CHECK: ( 0, 0, 0, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
// CHECK: ( 0, 0, 0, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) // CHECK: ( 0, 0, 0, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 16, 16, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
// CHECK: ( 0, 0, 0, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) // CHECK: ( 0, 0, 0, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
// CHECK: ( 0, 0, 0, 17, 17, 17, 17, 17, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) // CHECK: ( 0, 0, 0, 17, 17, 17, 17, 17, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
// CHECK: ( 0, 0, 0, 17, 17, 17, 17, 17, 13, 13, 13, 13, 13, 13, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 0 ) // CHECK: ( 0, 0, 0, 17, 17, 17, 17, 17, 13, 13, 13, 13, 13, 13, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 0 )

View File

@ -143,7 +143,10 @@ replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
LLVMTypeConverter &typeConverter, Location loc, LLVMTypeConverter &typeConverter, Location loc,
TransferReadOp xferOp, TransferReadOp xferOp,
ArrayRef<Value> operands, Value dataPtr) { ArrayRef<Value> operands, Value dataPtr) {
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr); unsigned align;
if (failed(getVectorTransferAlignment(typeConverter, xferOp, align)))
return failure();
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align);
return success(); return success();
} }
@ -176,8 +179,12 @@ replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
LLVMTypeConverter &typeConverter, Location loc, LLVMTypeConverter &typeConverter, Location loc,
TransferWriteOp xferOp, TransferWriteOp xferOp,
ArrayRef<Value> operands, Value dataPtr) { ArrayRef<Value> operands, Value dataPtr) {
unsigned align;
if (failed(getVectorTransferAlignment(typeConverter, xferOp, align)))
return failure();
auto adaptor = TransferWriteOpAdaptor(operands); auto adaptor = TransferWriteOpAdaptor(operands);
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr); rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
align);
return success(); return success();
} }

View File

@ -935,7 +935,7 @@ func @transfer_read_1d_not_masked(%A : memref<?xf32>, %base: index) -> vector<17
// CHECK-SAME: !llvm<"float*"> to !llvm<"<17 x float>*"> // CHECK-SAME: !llvm<"float*"> to !llvm<"<17 x float>*">
// //
// 2. Rewrite as a load. // 2. Rewrite as a load.
// CHECK: %[[loaded:.*]] = llvm.load %[[vecPtr]] : !llvm<"<17 x float>*"> // CHECK: %[[loaded:.*]] = llvm.load %[[vecPtr]] {alignment = 4 : i64} : !llvm<"<17 x float>*">
func @genbool_1d() -> vector<8xi1> { func @genbool_1d() -> vector<8xi1> {
%0 = vector.constant_mask [4] : vector<8xi1> %0 = vector.constant_mask [4] : vector<8xi1>