[mlir] Add alignment attribute to LLVM memory ops and use in vector.transfer
Summary: In general, the native alignment cannot be used when lowering a vector.transfer to the underlying load/store operation. This revision fixes the alignment of the unmasked load/store path so that it matches that of the masked path. Differential Revision: https://reviews.llvm.org/D83684
This commit is contained in:
parent
b9c2dd11a5
commit
affbc0cd1c
|
@ -215,19 +215,36 @@ def LLVM_FDivOp : LLVM_ArithmeticOp<"fdiv", "CreateFDiv">;
|
||||||
def LLVM_FRemOp : LLVM_ArithmeticOp<"frem", "CreateFRem">;
|
def LLVM_FRemOp : LLVM_ArithmeticOp<"frem", "CreateFRem">;
|
||||||
def LLVM_FNegOp : LLVM_UnaryArithmeticOp<"fneg", "CreateFNeg">;
|
def LLVM_FNegOp : LLVM_UnaryArithmeticOp<"fneg", "CreateFNeg">;
|
||||||
|
|
||||||
// Memory-related operations.
|
// Common code definition that is used to verify and set the alignment attribute
|
||||||
def LLVM_AllocaOp :
|
// of LLVM ops that accept such an attribute.
|
||||||
LLVM_OneResultOp<"alloca">,
|
class MemoryOpWithAlignmentBase {
|
||||||
Arguments<(ins LLVM_Type:$arraySize, OptionalAttr<I64Attr>:$alignment)> {
|
code alignmentVerifierCode = [{
|
||||||
string llvmBuilder = [{
|
if (alignment().hasValue()) {
|
||||||
auto *alloca = builder.CreateAlloca(
|
auto align = alignment().getValue().getSExtValue();
|
||||||
$_resultType->getPointerElementType(), $arraySize);
|
if (align < 0)
|
||||||
|
return emitOpError("expected positive alignment");
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}];
|
||||||
|
code setAlignmentCode = [{
|
||||||
if ($alignment.hasValue()) {
|
if ($alignment.hasValue()) {
|
||||||
auto align = $alignment.getValue().getZExtValue();
|
auto align = $alignment.getValue().getZExtValue();
|
||||||
if (align != 0)
|
if (align != 0)
|
||||||
alloca->setAlignment(llvm::Align(align));
|
inst->setAlignment(llvm::Align(align));
|
||||||
}
|
}
|
||||||
$res = alloca;
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Memory-related operations.
|
||||||
|
def LLVM_AllocaOp :
|
||||||
|
MemoryOpWithAlignmentBase,
|
||||||
|
LLVM_OneResultOp<"alloca">,
|
||||||
|
Arguments<(ins LLVM_Type:$arraySize, OptionalAttr<I64Attr>:$alignment)> {
|
||||||
|
string llvmBuilder = [{
|
||||||
|
auto *inst = builder.CreateAlloca(
|
||||||
|
$_resultType->getPointerElementType(), $arraySize);
|
||||||
|
}] # setAlignmentCode # [{
|
||||||
|
$res = inst;
|
||||||
}];
|
}];
|
||||||
let builders = [OpBuilder<
|
let builders = [OpBuilder<
|
||||||
"OpBuilder &b, OperationState &result, Type resultType, Value arraySize, "
|
"OpBuilder &b, OperationState &result, Type resultType, Value arraySize, "
|
||||||
|
@ -239,14 +256,7 @@ def LLVM_AllocaOp :
|
||||||
}]>];
|
}]>];
|
||||||
let parser = [{ return parseAllocaOp(parser, result); }];
|
let parser = [{ return parseAllocaOp(parser, result); }];
|
||||||
let printer = [{ printAllocaOp(p, *this); }];
|
let printer = [{ printAllocaOp(p, *this); }];
|
||||||
let verifier = [{
|
let verifier = alignmentVerifierCode;
|
||||||
if (alignment().hasValue()) {
|
|
||||||
auto align = alignment().getValue().getSExtValue();
|
|
||||||
if (align < 0)
|
|
||||||
return emitOpError("expected positive alignment");
|
|
||||||
}
|
|
||||||
return success();
|
|
||||||
}];
|
|
||||||
}
|
}
|
||||||
def LLVM_GEPOp : LLVM_OneResultOp<"getelementptr", [NoSideEffect]>,
|
def LLVM_GEPOp : LLVM_OneResultOp<"getelementptr", [NoSideEffect]>,
|
||||||
Arguments<(ins LLVM_Type:$base, Variadic<LLVM_Type>:$indices)>,
|
Arguments<(ins LLVM_Type:$base, Variadic<LLVM_Type>:$indices)>,
|
||||||
|
@ -255,22 +265,56 @@ def LLVM_GEPOp : LLVM_OneResultOp<"getelementptr", [NoSideEffect]>,
|
||||||
$base `[` $indices `]` attr-dict `:` functional-type(operands, results)
|
$base `[` $indices `]` attr-dict `:` functional-type(operands, results)
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
def LLVM_LoadOp : LLVM_OneResultOp<"load">, Arguments<(ins LLVM_Type:$addr)>,
|
def LLVM_LoadOp :
|
||||||
LLVM_Builder<"$res = builder.CreateLoad($addr);"> {
|
MemoryOpWithAlignmentBase,
|
||||||
|
LLVM_OneResultOp<"load">,
|
||||||
|
Arguments<(ins LLVM_Type:$addr, OptionalAttr<I64Attr>:$alignment)> {
|
||||||
|
string llvmBuilder = [{
|
||||||
|
auto *inst = builder.CreateLoad($addr);
|
||||||
|
}] # setAlignmentCode # [{
|
||||||
|
$res = inst;
|
||||||
|
}];
|
||||||
let builders = [OpBuilder<
|
let builders = [OpBuilder<
|
||||||
"OpBuilder &b, OperationState &result, Value addr",
|
"OpBuilder &b, OperationState &result, Value addr, unsigned alignment = 0",
|
||||||
[{
|
[{
|
||||||
auto type = addr.getType().cast<LLVM::LLVMType>().getPointerElementTy();
|
auto type = addr.getType().cast<LLVM::LLVMType>().getPointerElementTy();
|
||||||
build(b, result, type, addr);
|
build(b, result, type, addr, alignment);
|
||||||
|
}]>,
|
||||||
|
OpBuilder<
|
||||||
|
"OpBuilder &b, OperationState &result, Type t, Value addr, "
|
||||||
|
"unsigned alignment = 0",
|
||||||
|
[{
|
||||||
|
if (alignment == 0)
|
||||||
|
return build(b, result, t, addr, IntegerAttr());
|
||||||
|
build(b, result, t, addr, b.getI64IntegerAttr(alignment));
|
||||||
}]>];
|
}]>];
|
||||||
let parser = [{ return parseLoadOp(parser, result); }];
|
let parser = [{ return parseLoadOp(parser, result); }];
|
||||||
let printer = [{ printLoadOp(p, *this); }];
|
let printer = [{ printLoadOp(p, *this); }];
|
||||||
|
let verifier = alignmentVerifierCode;
|
||||||
}
|
}
|
||||||
def LLVM_StoreOp : LLVM_ZeroResultOp<"store">,
|
def LLVM_StoreOp :
|
||||||
Arguments<(ins LLVM_Type:$value, LLVM_Type:$addr)>,
|
MemoryOpWithAlignmentBase,
|
||||||
LLVM_Builder<"builder.CreateStore($value, $addr);"> {
|
LLVM_ZeroResultOp<"store">,
|
||||||
|
Arguments<(ins LLVM_Type:$value,
|
||||||
|
LLVM_Type:$addr,
|
||||||
|
OptionalAttr<I64Attr>:$alignment)> {
|
||||||
|
string llvmBuilder = [{
|
||||||
|
auto *inst = builder.CreateStore($value, $addr);
|
||||||
|
}] # setAlignmentCode;
|
||||||
|
let builders = [
|
||||||
|
OpBuilder<
|
||||||
|
"OpBuilder &b, OperationState &result, Value value, Value addr, "
|
||||||
|
"unsigned alignment = 0",
|
||||||
|
[{
|
||||||
|
if (alignment == 0)
|
||||||
|
return build(b, result, ArrayRef<Type>{}, value, addr, IntegerAttr());
|
||||||
|
build(b, result, ArrayRef<Type>{}, value, addr,
|
||||||
|
b.getI64IntegerAttr(alignment));
|
||||||
|
}]
|
||||||
|
>];
|
||||||
let parser = [{ return parseStoreOp(parser, result); }];
|
let parser = [{ return parseStoreOp(parser, result); }];
|
||||||
let printer = [{ printStoreOp(p, *this); }];
|
let printer = [{ printStoreOp(p, *this); }];
|
||||||
|
let verifier = alignmentVerifierCode;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Casts.
|
// Casts.
|
||||||
|
|
|
@ -12,6 +12,15 @@ func @transfer_read_1d(%A : memref<?xf32>, %base: index) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func @transfer_read_unmasked_4(%A : memref<?xf32>, %base: index) {
|
||||||
|
%fm42 = constant -42.0: f32
|
||||||
|
%f = vector.transfer_read %A[%base], %fm42
|
||||||
|
{permutation_map = affine_map<(d0) -> (d0)>, masked = [false]} :
|
||||||
|
memref<?xf32>, vector<4xf32>
|
||||||
|
vector.print %f: vector<4xf32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func @transfer_write_1d(%A : memref<?xf32>, %base: index) {
|
func @transfer_write_1d(%A : memref<?xf32>, %base: index) {
|
||||||
%f0 = constant 0.0 : f32
|
%f0 = constant 0.0 : f32
|
||||||
%vf0 = splat %f0 : vector<4xf32>
|
%vf0 = splat %f0 : vector<4xf32>
|
||||||
|
@ -44,8 +53,12 @@ func @entry() {
|
||||||
// Read shifted by 0 and pad with -42:
|
// Read shifted by 0 and pad with -42:
|
||||||
// ( 0, 1, 2, 0, 0, -42, ..., -42)
|
// ( 0, 1, 2, 0, 0, -42, ..., -42)
|
||||||
call @transfer_read_1d(%A, %c0) : (memref<?xf32>, index) -> ()
|
call @transfer_read_1d(%A, %c0) : (memref<?xf32>, index) -> ()
|
||||||
|
// Read unmasked 4 @ 1, guaranteed to not overflow.
|
||||||
|
// Exercises proper alignment.
|
||||||
|
call @transfer_read_unmasked_4(%A, %c1) : (memref<?xf32>, index) -> ()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK: ( 2, 3, 4, -42, -42, -42, -42, -42, -42, -42, -42, -42, -42 )
|
// CHECK: ( 2, 3, 4, -42, -42, -42, -42, -42, -42, -42, -42, -42, -42 )
|
||||||
// CHECK: ( 0, 1, 2, 0, 0, -42, -42, -42, -42, -42, -42, -42, -42 )
|
// CHECK: ( 0, 1, 2, 0, 0, -42, -42, -42, -42, -42, -42, -42, -42 )
|
||||||
|
// CHECK: ( 1, 2, 0, 0 )
|
||||||
|
|
|
@ -3,11 +3,11 @@
|
||||||
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
|
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
|
||||||
// RUN: FileCheck %s
|
// RUN: FileCheck %s
|
||||||
|
|
||||||
func @transfer_write16_1d(%A : memref<?xf32>, %base: index) {
|
func @transfer_write16_unmasked_1d(%A : memref<?xf32>, %base: index) {
|
||||||
%f = constant 16.0 : f32
|
%f = constant 16.0 : f32
|
||||||
%v = splat %f : vector<16xf32>
|
%v = splat %f : vector<16xf32>
|
||||||
vector.transfer_write %v, %A[%base]
|
vector.transfer_write %v, %A[%base]
|
||||||
{permutation_map = affine_map<(d0) -> (d0)>}
|
{permutation_map = affine_map<(d0) -> (d0)>, masked = [false]}
|
||||||
: vector<16xf32>, memref<?xf32>
|
: vector<16xf32>, memref<?xf32>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -53,14 +53,14 @@ func @entry() {
|
||||||
%0 = call @transfer_read_1d(%A) : (memref<?xf32>) -> (vector<32xf32>)
|
%0 = call @transfer_read_1d(%A) : (memref<?xf32>) -> (vector<32xf32>)
|
||||||
vector.print %0 : vector<32xf32>
|
vector.print %0 : vector<32xf32>
|
||||||
|
|
||||||
// Overwrite with 16 values of 16 at base 4.
|
// Overwrite with 16 values of 16 at base 3.
|
||||||
%c4 = constant 4: index
|
// Statically guaranteed to be unmasked. Exercises proper alignment.
|
||||||
call @transfer_write16_1d(%A, %c4) : (memref<?xf32>, index) -> ()
|
%c3 = constant 3: index
|
||||||
|
call @transfer_write16_unmasked_1d(%A, %c3) : (memref<?xf32>, index) -> ()
|
||||||
%1 = call @transfer_read_1d(%A) : (memref<?xf32>) -> (vector<32xf32>)
|
%1 = call @transfer_read_1d(%A) : (memref<?xf32>) -> (vector<32xf32>)
|
||||||
vector.print %1 : vector<32xf32>
|
vector.print %1 : vector<32xf32>
|
||||||
|
|
||||||
// Overwrite with 13 values of 13 at base 3.
|
// Overwrite with 13 values of 13 at base 3.
|
||||||
%c3 = constant 3: index
|
|
||||||
call @transfer_write13_1d(%A, %c3) : (memref<?xf32>, index) -> ()
|
call @transfer_write13_1d(%A, %c3) : (memref<?xf32>, index) -> ()
|
||||||
%2 = call @transfer_read_1d(%A) : (memref<?xf32>) -> (vector<32xf32>)
|
%2 = call @transfer_read_1d(%A) : (memref<?xf32>) -> (vector<32xf32>)
|
||||||
vector.print %2 : vector<32xf32>
|
vector.print %2 : vector<32xf32>
|
||||||
|
@ -93,8 +93,8 @@ func @entry() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK: ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
|
// CHECK: ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
|
||||||
// CHECK: ( 0, 0, 0, 0, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
|
// CHECK: ( 0, 0, 0, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
|
||||||
// CHECK: ( 0, 0, 0, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
|
// CHECK: ( 0, 0, 0, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 16, 16, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
|
||||||
// CHECK: ( 0, 0, 0, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
|
// CHECK: ( 0, 0, 0, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
|
||||||
// CHECK: ( 0, 0, 0, 17, 17, 17, 17, 17, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
|
// CHECK: ( 0, 0, 0, 17, 17, 17, 17, 17, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
|
||||||
// CHECK: ( 0, 0, 0, 17, 17, 17, 17, 17, 13, 13, 13, 13, 13, 13, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 0 )
|
// CHECK: ( 0, 0, 0, 17, 17, 17, 17, 17, 13, 13, 13, 13, 13, 13, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 0 )
|
||||||
|
|
|
@ -143,7 +143,10 @@ replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
|
||||||
LLVMTypeConverter &typeConverter, Location loc,
|
LLVMTypeConverter &typeConverter, Location loc,
|
||||||
TransferReadOp xferOp,
|
TransferReadOp xferOp,
|
||||||
ArrayRef<Value> operands, Value dataPtr) {
|
ArrayRef<Value> operands, Value dataPtr) {
|
||||||
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr);
|
unsigned align;
|
||||||
|
if (failed(getVectorTransferAlignment(typeConverter, xferOp, align)))
|
||||||
|
return failure();
|
||||||
|
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -176,8 +179,12 @@ replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
|
||||||
LLVMTypeConverter &typeConverter, Location loc,
|
LLVMTypeConverter &typeConverter, Location loc,
|
||||||
TransferWriteOp xferOp,
|
TransferWriteOp xferOp,
|
||||||
ArrayRef<Value> operands, Value dataPtr) {
|
ArrayRef<Value> operands, Value dataPtr) {
|
||||||
|
unsigned align;
|
||||||
|
if (failed(getVectorTransferAlignment(typeConverter, xferOp, align)))
|
||||||
|
return failure();
|
||||||
auto adaptor = TransferWriteOpAdaptor(operands);
|
auto adaptor = TransferWriteOpAdaptor(operands);
|
||||||
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr);
|
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
|
||||||
|
align);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -935,7 +935,7 @@ func @transfer_read_1d_not_masked(%A : memref<?xf32>, %base: index) -> vector<17
|
||||||
// CHECK-SAME: !llvm<"float*"> to !llvm<"<17 x float>*">
|
// CHECK-SAME: !llvm<"float*"> to !llvm<"<17 x float>*">
|
||||||
//
|
//
|
||||||
// 2. Rewrite as a load.
|
// 2. Rewrite as a load.
|
||||||
// CHECK: %[[loaded:.*]] = llvm.load %[[vecPtr]] : !llvm<"<17 x float>*">
|
// CHECK: %[[loaded:.*]] = llvm.load %[[vecPtr]] {alignment = 4 : i64} : !llvm<"<17 x float>*">
|
||||||
|
|
||||||
func @genbool_1d() -> vector<8xi1> {
|
func @genbool_1d() -> vector<8xi1> {
|
||||||
%0 = vector.constant_mask [4] : vector<8xi1>
|
%0 = vector.constant_mask [4] : vector<8xi1>
|
||||||
|
|
Loading…
Reference in New Issue