[mlir] Split std.splat into tensor.splat and vector.splat

This is part of the larger effort to split the standard dialect. It will also allow pruning some
additional dependencies on Standard (done in a follow-up).
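For illustration, the mechanical rewrite in existing IR looks like this (value names are illustrative):

```mlir
// Before: std.splat covered both result kinds.
%t = splat %x : tensor<8x16xi32>
%v = splat %y : vector<4xf32>

// After: each result kind has a dedicated op.
%t = tensor.splat %x : tensor<8x16xi32>
%v = vector.splat %y : vector<4xf32>
```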

Differential Revision: https://reviews.llvm.org/D118202
River Riddle 2022-01-25 15:51:05 -08:00
parent f7a6c341cb
commit 6a8ba3186e
42 changed files with 504 additions and 419 deletions


@ -507,55 +507,6 @@ def SelectOp : Std_Op<"select", [NoSideEffect,
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
// SplatOp
//===----------------------------------------------------------------------===//
def SplatOp : Std_Op<"splat", [NoSideEffect,
TypesMatchWith<"operand type matches element type of result",
"aggregate", "input",
"$_self.cast<ShapedType>().getElementType()">]> {
let summary = "splat or broadcast operation";
let description = [{
Broadcast the operand to all elements of the result vector or tensor. The
operand has to be of integer/index/float type. When the result is a tensor,
it has to be statically shaped.
Example:
```mlir
%s = load %A[%i] : memref<128xf32>
%v = splat %s : vector<4xf32>
%t = splat %s : tensor<8x16xi32>
```
TODO: This operation is easy to extend to broadcast to dynamically shaped
tensors in the same way dynamically shaped memrefs are handled.
```mlir
// Broadcasts %s to a 2-d dynamically shaped tensor, with %m, %n binding
// to the sizes of the two dynamic dimensions.
%m = "foo"() : () -> (index)
%n = "bar"() : () -> (index)
%t = splat %s [%m, %n] : tensor<?x?xi32>
```
}];
let arguments = (ins AnyTypeOf<[AnySignlessInteger, Index, AnyFloat],
"integer/index/float type">:$input);
let results = (outs AnyTypeOf<[AnyVectorOfAnyRank,
AnyStaticShapeTensor]>:$aggregate);
let builders = [
OpBuilder<(ins "Value":$element, "Type":$aggregateType),
[{ build($_builder, $_state, aggregateType, element); }]>];
let hasFolder = 1;
let hasVerifier = 1;
let assemblyFormat = "$input attr-dict `:` type($aggregate)";
}
//===----------------------------------------------------------------------===//
// SwitchOp
//===----------------------------------------------------------------------===//


@ -968,6 +968,52 @@ def Tensor_PadOp : Tensor_Op<"pad", [AttrSizedOperandSegments, NoSideEffect,
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
// SplatOp
//===----------------------------------------------------------------------===//
def Tensor_SplatOp : Tensor_Op<"splat", [
NoSideEffect,
TypesMatchWith<"operand type matches element type of result",
"aggregate", "input",
"$_self.cast<TensorType>().getElementType()">
]> {
let summary = "tensor splat or broadcast operation";
let description = [{
Broadcast the operand to all elements of the result tensor. The operand is
required to be of integer/index/float type, and the result tensor must be
statically shaped.
Example:
```mlir
%s = arith.constant 10.1 : f32
%t = tensor.splat %s : tensor<8x16xf32>
```
TODO: This operation is easy to extend to broadcast to dynamically shaped
tensors:
```mlir
// Broadcasts %s to a 2-d dynamically shaped tensor, with %m, %n binding
// to the sizes of the two dynamic dimensions.
%m = "foo"() : () -> (index)
%n = "bar"() : () -> (index)
%t = tensor.splat %s [%m, %n] : tensor<?x?xi32>
```
}];
let arguments = (ins AnyTypeOf<[AnySignlessInteger, Index, AnyFloat],
"integer/index/float type">:$input);
let results = (outs AnyStaticShapeTensor:$aggregate);
let builders = [
OpBuilder<(ins "Value":$element, "Type":$aggregateType),
[{ build($_builder, $_state, aggregateType, element); }]>];
let assemblyFormat = "$input attr-dict `:` type($aggregate)";
let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
// YieldOp


@ -2420,6 +2420,41 @@ def Vector_FlatTransposeOp : Vector_Op<"flat_transpose", [NoSideEffect,
let assemblyFormat = "$matrix attr-dict `:` type($matrix) `->` type($res)";
}
//===----------------------------------------------------------------------===//
// SplatOp
//===----------------------------------------------------------------------===//
def Vector_SplatOp : Vector_Op<"splat", [
NoSideEffect,
TypesMatchWith<"operand type matches element type of result",
"aggregate", "input",
"$_self.cast<VectorType>().getElementType()">
]> {
let summary = "vector splat or broadcast operation";
let description = [{
Broadcast the operand to all elements of the result vector. The operand is
required to be of integer/index/float type.
Example:
```mlir
%s = arith.constant 10.1 : f32
%t = vector.splat %s : vector<8x16xf32>
```
}];
let arguments = (ins AnyTypeOf<[AnySignlessInteger, Index, AnyFloat],
"integer/index/float type">:$input);
let results = (outs AnyVectorOfAnyRank:$aggregate);
let builders = [
OpBuilder<(ins "Value":$element, "Type":$aggregateType),
[{ build($_builder, $_state, aggregateType, element); }]>];
let assemblyFormat = "$input attr-dict `:` type($aggregate)";
let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
// VectorScaleOp
//===----------------------------------------------------------------------===//


@ -49,6 +49,8 @@ public:
template <typename U> bool isa() const;
template <typename First, typename Second, typename... Rest>
bool isa() const;
template <typename First, typename... Rest>
bool isa_and_nonnull() const;
template <typename U> U dyn_cast() const;
template <typename U> U dyn_cast_or_null() const;
template <typename U> U cast() const;
@ -114,6 +116,11 @@ bool Attribute::isa() const {
return isa<First>() || isa<Second, Rest...>();
}
template <typename First, typename... Rest>
bool Attribute::isa_and_nonnull() const {
return impl && isa<First, Rest...>();
}
template <typename U> U Attribute::dyn_cast() const {
return isa<U>() ? U(impl) : U(nullptr);
}


@ -663,99 +663,6 @@ struct SwitchOpLowering
using Super::Super;
};
// The Splat operation is lowered to an insertelement + a shufflevector
// operation. Splat to only 0-d and 1-d vector result types are lowered.
struct SplatOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
if (!resultType || resultType.getRank() > 1)
return failure();
// First insert it into an undef vector so we can shuffle it.
auto vectorType = typeConverter->convertType(splatOp.getType());
Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
auto zero = rewriter.create<LLVM::ConstantOp>(
splatOp.getLoc(),
typeConverter->convertType(rewriter.getIntegerType(32)),
rewriter.getZeroAttr(rewriter.getIntegerType(32)));
// For 0-d vector, we simply do `insertelement`.
if (resultType.getRank() == 0) {
rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
splatOp, vectorType, undef, adaptor.getInput(), zero);
return success();
}
// For 1-d vector, we additionally do a `vectorshuffle`.
auto v = rewriter.create<LLVM::InsertElementOp>(
splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero);
int64_t width = splatOp.getType().cast<VectorType>().getDimSize(0);
SmallVector<int32_t, 4> zeroValues(width, 0);
// Shuffle the value across the desired number of elements.
ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef,
zeroAttrs);
return success();
}
};
// The Splat operation is lowered to an insertelement + a shufflevector
// operation. Splat to only 2+-d vector result types are lowered by the
// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering.
struct SplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
if (!resultType || resultType.getRank() <= 1)
return failure();
// First insert it into an undef vector so we can shuffle it.
auto loc = splatOp.getLoc();
auto vectorTypeInfo =
LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
if (!llvmNDVectorTy || !llvm1DVectorTy)
return failure();
// Construct returned value.
Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy);
// Construct a 1-D vector with the splatted value that we insert in all the
// places within the returned descriptor.
Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy);
auto zero = rewriter.create<LLVM::ConstantOp>(
loc, typeConverter->convertType(rewriter.getIntegerType(32)),
rewriter.getZeroAttr(rewriter.getIntegerType(32)));
Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
adaptor.getInput(), zero);
// Shuffle the value across the desired number of elements.
int64_t width = resultType.getDimSize(resultType.getRank() - 1);
SmallVector<int32_t, 4> zeroValues(width, 0);
ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroAttrs);
// Iterate of linear index, convert to coords space and insert splatted 1-D
// vector in each position.
nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmNDVectorTy, desc, v,
position);
});
rewriter.replaceOp(splatOp, desc);
return success();
}
};
} // namespace
void mlir::populateStdToLLVMFuncOpConversionPattern(
@ -779,8 +686,6 @@ void mlir::populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
ConstantOpLowering,
ReturnOpLowering,
SelectOpLowering,
SplatOpLowering,
SplatNdOpLowering,
SwitchOpLowering>(converter);
// clang-format on
}


@ -55,16 +55,6 @@ public:
ConversionPatternRewriter &rewriter) const override;
};
/// Converts std.splat to spv.CompositeConstruct.
class SplatPattern final : public OpConversionPattern<SplatOp> {
public:
using OpConversionPattern<SplatOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(SplatOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
/// Converts std.br to spv.Branch.
struct BranchOpPattern final : public OpConversionPattern<BranchOp> {
using OpConversionPattern<BranchOp>::OpConversionPattern;
@ -178,22 +168,6 @@ SelectOpPattern::matchAndRewrite(SelectOp op, OpAdaptor adaptor,
return success();
}
//===----------------------------------------------------------------------===//
// SplatOp
//===----------------------------------------------------------------------===//
LogicalResult
SplatPattern::matchAndRewrite(SplatOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto dstVecType = op.getType().dyn_cast<VectorType>();
if (!dstVecType || !spirv::CompositeType::isValid(dstVecType))
return failure();
SmallVector<Value, 4> source(dstVecType.getNumElements(), adaptor.getInput());
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstVecType,
source);
return success();
}
//===----------------------------------------------------------------------===//
// BranchOpPattern
//===----------------------------------------------------------------------===//
@ -237,8 +211,8 @@ void populateStandardToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::GLSLSMinOp>,
spirv::ElementwiseOpPattern<arith::MinUIOp, spirv::GLSLUMinOp>,
ReturnOpPattern, SelectOpPattern, SplatPattern, BranchOpPattern,
CondBranchOpPattern>(typeConverter, context);
ReturnOpPattern, SelectOpPattern, BranchOpPattern, CondBranchOpPattern>(
typeConverter, context);
}
void populateTensorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,


@ -778,7 +778,7 @@ public:
auto elemType = vType.getElementType();
Value zero = rewriter.create<arith::ConstantOp>(
loc, elemType, rewriter.getZeroAttr(elemType));
Value desc = rewriter.create<SplatOp>(loc, vType, zero);
Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero);
for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
@ -1062,6 +1062,99 @@ private:
}
};
/// The Splat operation is lowered to an insertelement + a shufflevector
/// operation. Only splats to 0-d and 1-d vector result types are lowered here.
struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> {
using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
VectorType resultType = splatOp.getType().cast<VectorType>();
if (resultType.getRank() > 1)
return failure();
// First insert it into an undef vector so we can shuffle it.
auto vectorType = typeConverter->convertType(splatOp.getType());
Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
auto zero = rewriter.create<LLVM::ConstantOp>(
splatOp.getLoc(),
typeConverter->convertType(rewriter.getIntegerType(32)),
rewriter.getZeroAttr(rewriter.getIntegerType(32)));
// For 0-d vector, we simply do `insertelement`.
if (resultType.getRank() == 0) {
rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
splatOp, vectorType, undef, adaptor.input(), zero);
return success();
}
// For 1-d vector, we additionally do a `vectorshuffle`.
auto v = rewriter.create<LLVM::InsertElementOp>(
splatOp.getLoc(), vectorType, undef, adaptor.input(), zero);
int64_t width = splatOp.getType().cast<VectorType>().getDimSize(0);
SmallVector<int32_t, 4> zeroValues(width, 0);
// Shuffle the value across the desired number of elements.
ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef,
zeroAttrs);
return success();
}
};
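For the 1-d case, the emitted LLVM dialect IR has the following shape, distilled from the updated conversion tests (SSA names illustrative):

```mlir
%undef = llvm.mlir.undef : vector<4xf32>
%zero  = llvm.mlir.constant(0 : i32) : i32
%e     = llvm.insertelement %elt, %undef[%zero : i32] : vector<4xf32>
%splat = llvm.shufflevector %e, %undef [0 : i32, 0 : i32, 0 : i32, 0 : i32]
    : vector<4xf32>, vector<4xf32>
```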
/// The Splat operation is lowered to an insertelement + a shufflevector
/// operation. Only splats to vector result types of rank 2 or higher are
/// lowered here; the 0-d and 1-d cases are handled by VectorSplatOpLowering.
struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
VectorType resultType = splatOp.getType();
if (resultType.getRank() <= 1)
return failure();
// First insert it into an undef vector so we can shuffle it.
auto loc = splatOp.getLoc();
auto vectorTypeInfo =
LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
if (!llvmNDVectorTy || !llvm1DVectorTy)
return failure();
// Construct returned value.
Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy);
// Construct a 1-D vector with the splatted value that we insert in all the
// places within the returned descriptor.
Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy);
auto zero = rewriter.create<LLVM::ConstantOp>(
loc, typeConverter->convertType(rewriter.getIntegerType(32)),
rewriter.getZeroAttr(rewriter.getIntegerType(32)));
Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
adaptor.input(), zero);
// Shuffle the value across the desired number of elements.
int64_t width = resultType.getDimSize(resultType.getRank() - 1);
SmallVector<int32_t, 4> zeroValues(width, 0);
ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroAttrs);
// Iterate over the linear index, convert it to coordinates, and insert the
// splatted 1-D vector at each position.
nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmNDVectorTy, desc, v,
position);
});
rewriter.replaceOp(splatOp, desc);
return success();
}
};
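For an n-d result such as vector<2x3xf32>, a single 1-d splat is built and then inserted into every position of the array descriptor, as sketched below (based on the broadcast_vec2d_from_scalar test; names illustrative):

```mlir
%e   = llvm.insertelement %elt, %u[%zero : i32] : vector<3xf32>
%row = llvm.shufflevector %e, %u [0 : i32, 0 : i32, 0 : i32]
    : vector<3xf32>, vector<3xf32>
%d0  = llvm.insertvalue %row, %desc[0] : !llvm.array<2 x vector<3xf32>>
%d1  = llvm.insertvalue %row, %d0[1] : !llvm.array<2 x vector<3xf32>>
```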
} // namespace
/// Populate the given list with patterns that convert from Vector to LLVM.
@ -1085,8 +1178,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorLoadStoreConversion<vector::MaskedStoreOp,
vector::MaskedStoreOpAdaptor>,
VectorGatherOpConversion, VectorScatterOpConversion,
VectorExpandLoadOpConversion, VectorCompressStoreOpConversion>(
converter);
VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
VectorSplatOpLowering, VectorSplatNdOpLowering>(converter);
// Transfer ops with rank > 1 are handled by VectorToSCF.
populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
}


@ -425,7 +425,7 @@ struct Strategy<TransferReadOp> {
Location loc = xferOp.getLoc();
auto bufferType = buffer.getType().dyn_cast<ShapedType>();
auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
auto vec = b.create<SplatOp>(loc, vecType, xferOp.padding());
auto vec = b.create<vector::SplatOp>(loc, vecType, xferOp.padding());
b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);
return Value();
@ -855,8 +855,8 @@ struct UnrollTransferReadConversion
if (auto insertOp = getInsertOp(xferOp))
return insertOp.dest();
Location loc = xferOp.getLoc();
return rewriter.create<SplatOp>(loc, xferOp.getVectorType(),
xferOp.padding());
return rewriter.create<vector::SplatOp>(loc, xferOp.getVectorType(),
xferOp.padding());
}
/// If the result of the TransferReadOp has exactly one user, which is a
@ -1143,7 +1143,8 @@ struct Strategy1d<TransferReadOp> {
static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
// Initialize vector with padding value.
Location loc = xferOp.getLoc();
return b.create<SplatOp>(loc, xferOp.getVectorType(), xferOp.padding());
return b.create<vector::SplatOp>(loc, xferOp.getVectorType(),
xferOp.padding());
}
};


@ -247,6 +247,23 @@ struct VectorInsertStridedSliceOpConvert final
}
};
class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
public:
using OpConversionPattern<vector::SplatOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(vector::SplatOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
VectorType dstVecType = op.getType();
if (!spirv::CompositeType::isValid(dstVecType))
return failure();
SmallVector<Value, 4> source(dstVecType.getNumElements(), adaptor.input());
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstVecType,
source);
return success();
}
};
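The conversion repeats the scalar once per vector element, as in the new vector-to-spirv test (sketch; names illustrative):

```mlir
// %splat = vector.splat %f : vector<4xf32>  converts to:
%splat = spv.CompositeConstruct %f, %f, %f, %f : vector<4xf32>
```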
} // namespace
void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
@ -255,6 +272,6 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
VectorExtractElementOpConvert, VectorExtractOpConvert,
VectorExtractStridedSliceOpConvert, VectorFmaOpConvert,
VectorInsertElementOpConvert, VectorInsertOpConvert,
VectorInsertStridedSliceOpConvert>(typeConverter,
patterns.getContext());
VectorInsertStridedSliceOpConvert, VectorSplatPattern>(
typeConverter, patterns.getContext());
}


@ -863,34 +863,6 @@ LogicalResult SelectOp::verify() {
return success();
}
//===----------------------------------------------------------------------===//
// SplatOp
//===----------------------------------------------------------------------===//
LogicalResult SplatOp::verify() {
// TODO: we could replace this by a trait.
if (getOperand().getType() != getType().cast<ShapedType>().getElementType())
return emitError("operand should be of elemental type of result type");
return success();
}
// Constant folding hook for SplatOp.
OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 1 && "splat takes one operand");
auto constOperand = operands.front();
if (!constOperand || !constOperand.isa<IntegerAttr, FloatAttr>())
return {};
auto shapedType = getType().cast<ShapedType>();
assert(shapedType.getElementType() == constOperand.getType() &&
"incorrect input attribute type for folding");
// SplatElementsAttr::get treats single value for second arg as being a splat.
return SplatElementsAttr::get(shapedType, {constOperand});
}
//===----------------------------------------------------------------------===//
// SwitchOp
//===----------------------------------------------------------------------===//


@ -1800,6 +1800,19 @@ OpFoldResult PadOp::fold(ArrayRef<Attribute>) {
return {};
}
//===----------------------------------------------------------------------===//
// SplatOp
//===----------------------------------------------------------------------===//
OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
auto constOperand = operands.front();
if (!constOperand.isa_and_nonnull<IntegerAttr, FloatAttr>())
return {};
// SplatElementsAttr::get treats a single value for the second arg as a splat.
return SplatElementsAttr::get(getType(), {constOperand});
}
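Concretely, a splat of a constant operand folds to a dense elements constant, as exercised by the new splat_fold test (names illustrative):

```mlir
%c = arith.constant 1.0 : f32
%t = tensor.splat %c : tensor<4xf32>
// folds to:
%t = arith.constant dense<1.000000e+00> : tensor<4xf32>
```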
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//


@ -15,7 +15,6 @@
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
@ -2540,7 +2539,7 @@ public:
auto splat = op.vector().getDefiningOp<SplatOp>();
if (!splat)
return failure();
rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.getInput());
rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.input());
return success();
}
};
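This pattern rewrites an extract whose source is a splat into a splat of the smaller result type, as in the extract_strided_splat test below (names illustrative):

```mlir
%0 = vector.splat %f : vector<16x4xf16>
%1 = vector.extract_strided_slice %0
    {offsets = [1, 0], sizes = [2, 4], strides = [1, 1]}
    : vector<16x4xf16> to vector<2x4xf16>
// canonicalizes to:
%1 = vector.splat %f : vector<2x4xf16>
```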
@ -4369,5 +4368,22 @@ void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
patterns.getContext());
}
//===----------------------------------------------------------------------===//
// SplatOp
//===----------------------------------------------------------------------===//
OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
auto constOperand = operands.front();
if (!constOperand.isa_and_nonnull<IntegerAttr, FloatAttr>())
return {};
// SplatElementsAttr::get treats a single value for the second arg as a splat.
return SplatElementsAttr::get(getType(), {constOperand});
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"


@ -8,7 +8,6 @@
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"


@ -10,8 +10,8 @@
// transfer_write ops.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"


@ -17,7 +17,6 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
@ -205,7 +204,7 @@ public:
// Scalar to any vector can use splat.
if (!srcType) {
rewriter.replaceOpWithNewOp<SplatOp>(op, dstType, op.source());
rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.source());
return success();
}
@ -220,7 +219,7 @@ public:
ext = rewriter.create<vector::ExtractElementOp>(loc, op.source());
else
ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
rewriter.replaceOpWithNewOp<SplatOp>(op, dstType, ext);
rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
return success();
}
@ -1735,7 +1734,7 @@ struct TransferReadToVectorLoadLowering
// Create vector load op.
Operation *loadOp;
if (read.mask()) {
Value fill = rewriter.create<SplatOp>(
Value fill = rewriter.create<vector::SplatOp>(
read.getLoc(), unbroadcastedVectorType, read.padding());
loadOp = rewriter.create<vector::MaskedLoadOp>(
read.getLoc(), unbroadcastedVectorType, read.source(), read.indices(),
@ -2168,12 +2167,13 @@ static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
// Add in an offset if requested.
if (off) {
Value o = createCastToIndexLike(rewriter, loc, idxType, *off);
Value ov = rewriter.create<SplatOp>(loc, indices.getType(), o);
Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o);
indices = rewriter.create<arith::AddIOp>(loc, ov, indices);
}
// Construct the vector comparison.
Value bound = createCastToIndexLike(rewriter, loc, idxType, b);
Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
Value bounds =
rewriter.create<vector::SplatOp>(loc, indices.getType(), bound);
return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices,
bounds);
}
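With the dedicated op, buildVectorComparison emits IR of the following shape for a 1-d mask before LLVM conversion (a sketch assuming no offset; names illustrative):

```mlir
%indices = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
%b       = arith.index_cast %a : index to i32
%bounds  = vector.splat %b : vector<4xi32>
%mask    = arith.cmpi slt, %indices, %bounds : vector<4xi32>
```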


@ -16,7 +16,6 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"


@ -456,36 +456,6 @@ func @dfs_block_order(%arg0: i32) -> (i32) {
// -----
// CHECK-LABEL: @splat_0d
// CHECK-SAME: %[[ARG:.*]]: f32
func @splat_0d(%a: f32) -> vector<f32> {
%v = splat %a : vector<f32>
return %v : vector<f32>
}
// CHECK-NEXT: %[[UNDEF:[0-9]+]] = llvm.mlir.undef : vector<1xf32>
// CHECK-NEXT: %[[ZERO:[0-9]+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK-NEXT: %[[V:[0-9]+]] = llvm.insertelement %[[ARG]], %[[UNDEF]][%[[ZERO]] : i32] : vector<1xf32>
// CHECK-NEXT: llvm.return %[[V]] : vector<1xf32>
// -----
// CHECK-LABEL: @splat
// CHECK-SAME: %[[A:arg[0-9]+]]: vector<4xf32>
// CHECK-SAME: %[[ELT:arg[0-9]+]]: f32
func @splat(%a: vector<4xf32>, %b: f32) -> vector<4xf32> {
%vb = splat %b : vector<4xf32>
%r = arith.mulf %a, %vb : vector<4xf32>
return %r : vector<4xf32>
}
// CHECK-NEXT: %[[UNDEF:[0-9]+]] = llvm.mlir.undef : vector<4xf32>
// CHECK-NEXT: %[[ZERO:[0-9]+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK-NEXT: %[[V:[0-9]+]] = llvm.insertelement %[[ELT]], %[[UNDEF]][%[[ZERO]] : i32] : vector<4xf32>
// CHECK-NEXT: %[[SPLAT:[0-9]+]] = llvm.shufflevector %[[V]], %[[UNDEF]] [0 : i32, 0 : i32, 0 : i32, 0 : i32]
// CHECK-NEXT: %[[SCALE:[0-9]+]] = llvm.fmul %[[A]], %[[SPLAT]] : vector<4xf32>
// CHECK-NEXT: llvm.return %[[SCALE]] : vector<4xf32>
// -----
// CHECK-LABEL: func @ceilf(
// CHECK-SAME: f32
func @ceilf(%arg0 : f32) {


@ -921,21 +921,6 @@ func @tensor_extract_constant(%a : index, %b: index, %c: index) -> i32 {
// -----
//===----------------------------------------------------------------------===//
// splat
//===----------------------------------------------------------------------===//
// CHECK-LABEL: func @splat
// CHECK-SAME: (%[[A:.+]]: f32)
// CHECK: %[[VAL:.+]] = spv.CompositeConstruct %[[A]], %[[A]], %[[A]], %[[A]] : vector<4xf32>
// CHECK: spv.ReturnValue %[[VAL]]
func @splat(%f : f32) -> vector<4xf32> {
%splat = splat %f : vector<4xf32>
return %splat : vector<4xf32>
}
// -----
//===----------------------------------------------------------------------===//
// std.br, std.cond_br
//===----------------------------------------------------------------------===//


@ -5,17 +5,19 @@
// CMP32-SAME: %[[ARG:.*]]: index)
// CMP32: %[[T0:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]> : vector<11xi32>
// CMP32: %[[T1:.*]] = arith.index_cast %[[ARG]] : index to i32
// CMP32: %[[T2:.*]] = splat %[[T1]] : vector<11xi32>
// CMP32: %[[T3:.*]] = arith.cmpi slt, %[[T0]], %[[T2]] : vector<11xi32>
// CMP32: return %[[T3]] : vector<11xi1>
// CMP32: %[[T2:.*]] = llvm.insertelement %[[T1]], %{{.*}}[%{{.*}} : i32] : vector<11xi32>
// CMP32: %[[T3:.*]] = llvm.shufflevector %[[T2]], %{{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32] : vector<11xi32>, vector<11xi32>
// CMP32: %[[T4:.*]] = arith.cmpi slt, %[[T0]], %[[T3]] : vector<11xi32>
// CMP32: return %[[T4]] : vector<11xi1>
// CMP64-LABEL: @genbool_var_1d(
// CMP64-SAME: %[[ARG:.*]]: index)
// CMP64: %[[T0:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]> : vector<11xi64>
// CMP64: %[[T1:.*]] = arith.index_cast %[[ARG]] : index to i64
// CMP64: %[[T2:.*]] = splat %[[T1]] : vector<11xi64>
// CMP64: %[[T3:.*]] = arith.cmpi slt, %[[T0]], %[[T2]] : vector<11xi64>
// CMP64: return %[[T3]] : vector<11xi1>
// CMP64: %[[T2:.*]] = llvm.insertelement %[[T1]], %{{.*}}[%{{.*}} : i32] : vector<11xi64>
// CMP64: %[[T3:.*]] = llvm.shufflevector %[[T2]], %{{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32] : vector<11xi64>, vector<11xi64>
// CMP64: %[[T4:.*]] = arith.cmpi slt, %[[T0]], %[[T3]] : vector<11xi64>
// CMP64: return %[[T4]] : vector<11xi1>
func @genbool_var_1d(%arg0: index) -> vector<11xi1> {
%0 = vector.create_mask %arg0 : vector<11xi1>


@ -55,8 +55,9 @@ func @broadcast_vec0d_from_f32(%arg0: f32) -> vector<f32> {
}
// CHECK-LABEL: @broadcast_vec0d_from_f32
// CHECK-SAME: %[[A:.*]]: f32)
// CHECK: %[[T0:.*]] = splat %[[A]] : vector<f32>
// CHECK: return %[[T0]] : vector<f32>
// CHECK: %[[T0:.*]] = llvm.insertelement %[[A]]
// CHECK: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<1xf32> to vector<f32>
// CHECK: return %[[T1]] : vector<f32>
// -----
@ -76,8 +77,9 @@ func @broadcast_vec1d_from_f32(%arg0: f32) -> vector<2xf32> {
}
// CHECK-LABEL: @broadcast_vec1d_from_f32
// CHECK-SAME: %[[A:.*]]: f32)
// CHECK: %[[T0:.*]] = splat %[[A]] : vector<2xf32>
// CHECK: return %[[T0]] : vector<2xf32>
// CHECK: %[[T0:.*]] = llvm.insertelement %[[A]]
// CHECK: %[[T1:.*]] = llvm.shufflevector %[[T0]]
// CHECK: return %[[T1]] : vector<2xf32>
// -----
@ -87,8 +89,11 @@ func @broadcast_vec1d_from_index(%arg0: index) -> vector<2xindex> {
}
// CHECK-LABEL: @broadcast_vec1d_from_index
// CHECK-SAME: %[[A:.*]]: index)
// CHECK: %[[T0:.*]] = splat %[[A]] : vector<2xindex>
// CHECK: return %[[T0]] : vector<2xindex>
// CHECK: %[[A1:.*]] = builtin.unrealized_conversion_cast %[[A]] : index to i64
// CHECK: %[[T0:.*]] = llvm.insertelement %[[A1]]
// CHECK: %[[T1:.*]] = llvm.shufflevector %[[T0]]
// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<2xi64> to vector<2xindex>
// CHECK: return %[[T2]] : vector<2xindex>
// -----
@ -98,8 +103,12 @@ func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> {
}
// CHECK-LABEL: @broadcast_vec2d_from_scalar(
// CHECK-SAME: %[[A:.*]]: f32)
// CHECK: %[[T0:.*]] = splat %[[A]] : vector<2x3xf32>
// CHECK: return %[[T0]] : vector<2x3xf32>
// CHECK: %[[T0:.*]] = llvm.insertelement %[[A]]
// CHECK: %[[T1:.*]] = llvm.shufflevector %[[T0]]
// CHECK: %[[T2:.*]] = llvm.insertvalue %[[T1]], %{{.*}}[0] : !llvm.array<2 x vector<3xf32>>
// CHECK: %[[T3:.*]] = llvm.insertvalue %[[T1]], %{{.*}}[1] : !llvm.array<2 x vector<3xf32>>
// CHECK: %[[T4:.*]] = builtin.unrealized_conversion_cast %[[T3]] : !llvm.array<2 x vector<3xf32>> to vector<2x3xf32>
// CHECK: return %[[T4]] : vector<2x3xf32>
// -----
@ -109,8 +118,13 @@ func @broadcast_vec3d_from_scalar(%arg0: f32) -> vector<2x3x4xf32> {
}
// CHECK-LABEL: @broadcast_vec3d_from_scalar(
// CHECK-SAME: %[[A:.*]]: f32)
// CHECK: %[[T0:.*]] = splat %[[A]] : vector<2x3x4xf32>
// CHECK: return %[[T0]] : vector<2x3x4xf32>
// CHECK: %[[T0:.*]] = llvm.insertelement %[[A]]
// CHECK: %[[T1:.*]] = llvm.shufflevector %[[T0]]
// CHECK: %[[T2:.*]] = llvm.insertvalue %[[T1]], %{{.*}}[0, 0] : !llvm.array<2 x array<3 x vector<4xf32>>>
// ...
// CHECK: %[[T3:.*]] = llvm.insertvalue %[[T1]], %{{.*}}[1, 2] : !llvm.array<2 x array<3 x vector<4xf32>>>
// CHECK: %[[T4:.*]] = builtin.unrealized_conversion_cast %[[T3]] : !llvm.array<2 x array<3 x vector<4xf32>>> to vector<2x3x4xf32>
// CHECK: return %[[T4]] : vector<2x3x4xf32>
// -----
@ -135,7 +149,8 @@ func @broadcast_vec2d_from_vec0d(%arg0: vector<f32>) -> vector<3x2xf32> {
// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
// CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : index) : i64
// CHECK: %[[T5:.*]] = llvm.extractelement %[[T0]][%[[T4]] : i64] : vector<1xf32>
// CHECK: %[[T6:.*]] = splat %[[T5]] : vector<2xf32>
// CHECK: %[[T6Insert:.*]] = llvm.insertelement %[[T5]]
// CHECK: %[[T6:.*]] = llvm.shufflevector %[[T6Insert]]
// CHECK: %[[T7:.*]] = llvm.insertvalue %[[T6]], %[[T2]][0] : !llvm.array<3 x vector<2xf32>>
// CHECK: %[[T8:.*]] = llvm.insertvalue %[[T6]], %[[T7]][1] : !llvm.array<3 x vector<2xf32>>
// CHECK: %[[T9:.*]] = llvm.insertvalue %[[T6]], %[[T8]][2] : !llvm.array<3 x vector<2xf32>>
@ -228,8 +243,9 @@ func @broadcast_stretch(%arg0: vector<1xf32>) -> vector<4xf32> {
// CHECK-SAME: %[[A:.*]]: vector<1xf32>)
// CHECK: %[[T1:.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[T2:.*]] = llvm.extractelement %[[A]]{{\[}}%[[T1]] : i64] : vector<1xf32>
// CHECK: %[[T3:.*]] = splat %[[T2]] : vector<4xf32>
// CHECK: return %[[T3]] : vector<4xf32>
// CHECK: %[[T3:.*]] = llvm.insertelement %[[T2]]
// CHECK: %[[T4:.*]] = llvm.shufflevector %[[T3]]
// CHECK: return %[[T4]] : vector<4xf32>
// -----
@ -263,22 +279,26 @@ func @broadcast_stretch_at_end(%arg0: vector<4x1xf32>) -> vector<4x3xf32> {
// CHECK: %[[T3:.*]] = llvm.extractvalue %[[T2]][0] : !llvm.array<4 x vector<1xf32>>
// CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[T5:.*]] = llvm.extractelement %[[T3]]{{\[}}%[[T4]] : i64] : vector<1xf32>
// CHECK: %[[T6:.*]] = splat %[[T5]] : vector<3xf32>
// CHECK: %[[T6Insert:.*]] = llvm.insertelement %[[T5]]
// CHECK: %[[T6:.*]] = llvm.shufflevector %[[T6Insert]]
// CHECK: %[[T8:.*]] = llvm.insertvalue %[[T6]], %[[T7]][0] : !llvm.array<4 x vector<3xf32>>
// CHECK: %[[T10:.*]] = llvm.extractvalue %[[T2]][1] : !llvm.array<4 x vector<1xf32>>
// CHECK: %[[T11:.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[T12:.*]] = llvm.extractelement %[[T10]]{{\[}}%[[T11]] : i64] : vector<1xf32>
// CHECK: %[[T13:.*]] = splat %[[T12]] : vector<3xf32>
// CHECK: %[[T13Insert:.*]] = llvm.insertelement %[[T12]]
// CHECK: %[[T13:.*]] = llvm.shufflevector %[[T13Insert]]
// CHECK: %[[T14:.*]] = llvm.insertvalue %[[T13]], %[[T8]][1] : !llvm.array<4 x vector<3xf32>>
// CHECK: %[[T16:.*]] = llvm.extractvalue %[[T2]][2] : !llvm.array<4 x vector<1xf32>>
// CHECK: %[[T17:.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[T18:.*]] = llvm.extractelement %[[T16]]{{\[}}%[[T17]] : i64] : vector<1xf32>
// CHECK: %[[T19:.*]] = splat %[[T18]] : vector<3xf32>
// CHECK: %[[T19Insert:.*]] = llvm.insertelement %[[T18]]
// CHECK: %[[T19:.*]] = llvm.shufflevector %[[T19Insert]]
// CHECK: %[[T20:.*]] = llvm.insertvalue %[[T19]], %[[T14]][2] : !llvm.array<4 x vector<3xf32>>
// CHECK: %[[T22:.*]] = llvm.extractvalue %[[T2]][3] : !llvm.array<4 x vector<1xf32>>
// CHECK: %[[T23:.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[T24:.*]] = llvm.extractelement %[[T22]]{{\[}}%[[T23]] : i64] : vector<1xf32>
// CHECK: %[[T25:.*]] = splat %[[T24]] : vector<3xf32>
// CHECK: %[[T25Insert:.*]] = llvm.insertelement %[[T24]]
// CHECK: %[[T25:.*]] = llvm.shufflevector %[[T25Insert]]
// CHECK: %[[T26:.*]] = llvm.insertvalue %[[T25]], %[[T20]][3] : !llvm.array<4 x vector<3xf32>>
// CHECK: %[[T27:.*]] = builtin.unrealized_conversion_cast %[[T26]] : !llvm.array<4 x vector<3xf32>> to vector<4x3xf32>
// CHECK: return %[[T27]] : vector<4x3xf32>
@ -332,12 +352,14 @@ func @outerproduct(%arg0: vector<2xf32>, %arg1: vector<3xf32>) -> vector<2x3xf32
// CHECK: %[[T7:.*]] = builtin.unrealized_conversion_cast %[[T2]] : vector<2x3xf32> to !llvm.array<2 x vector<3xf32>>
// CHECK: %[[T3:.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[T4:.*]] = llvm.extractelement %[[A]]{{\[}}%[[T3]] : i64] : vector<2xf32>
// CHECK: %[[T5:.*]] = splat %[[T4]] : vector<3xf32>
// CHECK: %[[T5Insert:.*]] = llvm.insertelement %[[T4]]
// CHECK: %[[T5:.*]] = llvm.shufflevector %[[T5Insert]]
// CHECK: %[[T6:.*]] = arith.mulf %[[T5]], %[[B]] : vector<3xf32>
// CHECK: %[[T8:.*]] = llvm.insertvalue %[[T6]], %[[T7]][0] : !llvm.array<2 x vector<3xf32>>
// CHECK: %[[T9:.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[T10:.*]] = llvm.extractelement %[[A]]{{\[}}%[[T9]] : i64] : vector<2xf32>
// CHECK: %[[T11:.*]] = splat %[[T10]] : vector<3xf32>
// CHECK: %[[T11Insert:.*]] = llvm.insertelement %[[T10]]
// CHECK: %[[T11:.*]] = llvm.shufflevector %[[T11Insert]]
// CHECK: %[[T12:.*]] = arith.mulf %[[T11]], %[[B]] : vector<3xf32>
// CHECK: %[[T13:.*]] = llvm.insertvalue %[[T12]], %[[T8]][1] : !llvm.array<2 x vector<3xf32>>
// CHECK: %[[T14:.*]] = builtin.unrealized_conversion_cast %[[T13]] : !llvm.array<2 x vector<3xf32>> to vector<2x3xf32>
@ -357,9 +379,10 @@ func @outerproduct_index(%arg0: vector<2xindex>, %arg1: vector<3xindex>) -> vect
// CHECK: %[[T8:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<2x3xindex> to !llvm.array<2 x vector<3xi64>>
// CHECK: %[[T2:.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[T3:.*]] = llvm.extractelement %[[T1]]{{\[}}%[[T2]] : i64] : vector<2xi64>
// CHECK: %[[T4:.*]] = builtin.unrealized_conversion_cast %[[T3]] : i64 to index
// CHECK: %[[T5:.*]] = splat %[[T4]] : vector<3xindex>
// CHECK: %[[T6:.*]] = arith.muli %[[T5]], %[[B]] : vector<3xindex>
// CHECK: %[[T4:.*]] = llvm.insertelement %[[T3]]
// CHECK: %[[T5:.*]] = llvm.shufflevector %[[T4]]
// CHECK: %[[T5Cast:.*]] = builtin.unrealized_conversion_cast %[[T5]] : vector<3xi64> to vector<3xindex>
// CHECK: %[[T6:.*]] = arith.muli %[[T5Cast]], %[[B]] : vector<3xindex>
// CHECK: %[[T7:.*]] = builtin.unrealized_conversion_cast %[[T6]] : vector<3xindex> to vector<3xi64>
// CHECK: %{{.*}} = llvm.insertvalue %[[T7]], %[[T8]][0] : !llvm.array<2 x vector<3xi64>>
@ -378,13 +401,15 @@ func @outerproduct_add(%arg0: vector<2xf32>, %arg1: vector<3xf32>, %arg2: vector
// CHECK: %[[T10:.*]] = builtin.unrealized_conversion_cast %[[T3]] : vector<2x3xf32> to !llvm.array<2 x vector<3xf32>>
// CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[T5:.*]] = llvm.extractelement %[[A]]{{\[}}%[[T4]] : i64] : vector<2xf32>
// CHECK: %[[T6:.*]] = splat %[[T5]] : vector<3xf32>
// CHECK: %[[T6Insert:.*]] = llvm.insertelement %[[T5]]
// CHECK: %[[T6:.*]] = llvm.shufflevector %[[T6Insert]]
// CHECK: %[[T8:.*]] = llvm.extractvalue %[[T7]][0] : !llvm.array<2 x vector<3xf32>>
// CHECK: %[[T9:.*]] = "llvm.intr.fmuladd"(%[[T6]], %[[B]], %[[T8]]) : (vector<3xf32>, vector<3xf32>, vector<3xf32>) -> vector<3xf32>
// CHECK: %[[T11:.*]] = llvm.insertvalue %[[T9]], %[[T10]][0] : !llvm.array<2 x vector<3xf32>>
// CHECK: %[[T12:.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[T13:.*]] = llvm.extractelement %[[A]]{{\[}}%[[T12]] : i64] : vector<2xf32>
// CHECK: %[[T14:.*]] = splat %[[T13]] : vector<3xf32>
// CHECK: %[[T14Insert:.*]] = llvm.insertelement %[[T13]]
// CHECK: %[[T14:.*]] = llvm.shufflevector %[[T14Insert]]
// CHECK: %[[T16:.*]] = llvm.extractvalue %[[T7]][1] : !llvm.array<2 x vector<3xf32>>
// CHECK: %[[T17:.*]] = "llvm.intr.fmuladd"(%[[T14]], %[[B]], %[[T16]]) : (vector<3xf32>, vector<3xf32>, vector<3xf32>) -> vector<3xf32>
// CHECK: %[[T18:.*]] = llvm.insertvalue %[[T17]], %[[T11]][1] : !llvm.array<2 x vector<3xf32>>
@ -986,8 +1011,7 @@ func @extract_strided_slice3(%arg0: vector<4x8xf32>) -> vector<2x2xf32> {
// CHECK-LABEL: @extract_strided_slice3(
// CHECK-SAME: %[[ARG:.*]]: vector<4x8xf32>)
// CHECK: %[[A:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<4x8xf32> to !llvm.array<4 x vector<8xf32>>
// CHECK: %[[VAL_1:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[VAL_2:.*]] = splat %[[VAL_1]] : vector<2x2xf32>
// CHECK: %[[VAL_2:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
// CHECK: %[[VAL_6:.*]] = builtin.unrealized_conversion_cast %[[VAL_2]] : vector<2x2xf32> to !llvm.array<2 x vector<2xf32>>
// CHECK: %[[T2:.*]] = llvm.extractvalue %[[A]][2] : !llvm.array<4 x vector<8xf32>>
// CHECK: %[[T3:.*]] = llvm.shufflevector %[[T2]], %[[T2]] [2, 3] : vector<8xf32>, vector<8xf32>
@ -1233,17 +1257,19 @@ func @transfer_read_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
//
// 2. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
// CHECK: %[[otrunc:.*]] = arith.index_cast %[[BASE]] : index to i32
// CHECK: %[[offsetVec:.*]] = splat %[[otrunc]] : vector<17xi32>
// CHECK: %[[offsetVecInsert:.*]] = llvm.insertelement %[[otrunc]]
// CHECK: %[[offsetVec:.*]] = llvm.shufflevector %[[offsetVecInsert]]
// CHECK: %[[offsetVec2:.*]] = arith.addi %[[offsetVec]], %[[linearIndex]] : vector<17xi32>
//
// 3. Let dim be the memref dimension; compute the vector comparison mask:
// [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
// CHECK: %[[dtrunc:.*]] = arith.index_cast %[[DIM]] : index to i32
// CHECK: %[[dimVec:.*]] = splat %[[dtrunc]] : vector<17xi32>
// CHECK: %[[dimVecInsert:.*]] = llvm.insertelement %[[dtrunc]]
// CHECK: %[[dimVec:.*]] = llvm.shufflevector %[[dimVecInsert]]
// CHECK: %[[mask:.*]] = arith.cmpi slt, %[[offsetVec2]], %[[dimVec]] : vector<17xi32>
//
// 4. Create pass-through vector.
// CHECK: %[[PASS_THROUGH:.*]] = splat %[[c7]] : vector<17xf32>
// CHECK: %[[PASS_THROUGH:.*]] = arith.constant dense<7.{{.*}}> : vector<17xf32>
//
// 5. Bitcast to vector form.
// CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
@ -1262,12 +1288,12 @@ func @transfer_read_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
// CHECK-SAME: vector<17xi32>
//
// 2. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
// CHECK: splat %{{.*}} : vector<17xi32>
// CHECK: llvm.shufflevector %{{.*}} : vector<17xi32>
// CHECK: arith.addi
//
// 3. Let dim be the memref dimension; compute the vector comparison mask:
// [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
// CHECK: splat %{{.*}} : vector<17xi32>
// CHECK: llvm.shufflevector %{{.*}} : vector<17xi32>
// CHECK: %[[mask_b:.*]] = arith.cmpi slt, {{.*}} : vector<17xi32>
//
// 4. Bitcast to vector form.
@ -1295,8 +1321,7 @@ func @transfer_read_index_1d(%A : memref<?xindex>, %base: index) -> vector<17xin
}
// CHECK-LABEL: func @transfer_read_index_1d
// CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<17xindex>
// CHECK: %[[C7:.*]] = arith.constant 7 : index
// CHECK: %[[SPLAT:.*]] = splat %[[C7]] : vector<17xindex>
// CHECK: %[[SPLAT:.*]] = arith.constant dense<7> : vector<17xindex>
// CHECK: %{{.*}} = builtin.unrealized_conversion_cast %[[SPLAT]] : vector<17xindex> to vector<17xi64>
// CHECK: %[[loaded:.*]] = llvm.intr.masked.load %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} :
@ -1321,12 +1346,14 @@ func @transfer_read_2d_to_1d(%A : memref<?x?xf32>, %base0: index, %base1: index)
//
// Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
// CHECK: %[[trunc:.*]] = arith.index_cast %[[BASE_1]] : index to i32
// CHECK: %[[offsetVec:.*]] = splat %[[trunc]] : vector<17xi32>
// CHECK: %[[offsetVecInsert:.*]] = llvm.insertelement %[[trunc]]
// CHECK: %[[offsetVec:.*]] = llvm.shufflevector %[[offsetVecInsert]]
//
// Let dim be the memref dimension; compute the vector comparison mask:
// [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
// CHECK: %[[dimtrunc:.*]] = arith.index_cast %[[DIM]] : index to i32
// CHECK: splat %[[dimtrunc]] : vector<17xi32>
// CHECK: %[[dimtruncInsert:.*]] = llvm.insertelement %[[dimtrunc]]
// CHECK: llvm.shufflevector %[[dimtruncInsert]]
// -----
@ -1451,9 +1478,11 @@ func @create_mask_0d(%a : index) -> vector<i1> {
// CHECK-SAME: %[[arg:.*]]: index
// CHECK: %[[indices:.*]] = arith.constant dense<0> : vector<i32>
// CHECK: %[[arg_i32:.*]] = arith.index_cast %[[arg]] : index to i32
// CHECK: %[[bounds:.*]] = splat %[[arg_i32]] : vector<i32>
// CHECK: %[[result:.*]] = arith.cmpi slt, %[[indices]], %[[bounds]] : vector<i32>
// CHECK: %[[bounds:.*]] = llvm.insertelement %[[arg_i32]]
// CHECK: %[[boundsCast:.*]] = builtin.unrealized_conversion_cast %[[bounds]] : vector<1xi32> to vector<i32>
// CHECK: %[[result:.*]] = arith.cmpi slt, %[[indices]], %[[boundsCast]] : vector<i32>
// CHECK: return %[[result]] : vector<i1>
// -----
func @create_mask_1d(%a : index) -> vector<4xi1> {
@ -1465,7 +1494,8 @@ func @create_mask_1d(%a : index) -> vector<4xi1> {
// CHECK-SAME: %[[arg:.*]]: index
// CHECK: %[[indices:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
// CHECK: %[[arg_i32:.*]] = arith.index_cast %[[arg]] : index to i32
// CHECK: %[[bounds:.*]] = splat %[[arg_i32]] : vector<4xi32>
// CHECK: %[[boundsInsert:.*]] = llvm.insertelement %[[arg_i32]]
// CHECK: %[[bounds:.*]] = llvm.shufflevector %[[boundsInsert]]
// CHECK: %[[result:.*]] = arith.cmpi slt, %[[indices]], %[[bounds]] : vector<4xi32>
// CHECK: return %[[result]] : vector<4xi1>
@ -1728,3 +1758,34 @@ func @compress_store_op_index(%arg0: memref<?xindex>, %arg1: vector<11xi1>, %arg
}
// CHECK-LABEL: func @compress_store_op_index
// CHECK: "llvm.intr.masked.compressstore"(%{{.*}}, %{{.*}}, %{{.*}}) : (vector<11xi64>, !llvm.ptr<i64>, vector<11xi1>) -> ()
// -----
// CHECK-LABEL: @splat_0d
// CHECK-SAME: %[[ARG:.*]]: f32
func @splat_0d(%a: f32) -> vector<f32> {
%v = vector.splat %a : vector<f32>
return %v : vector<f32>
}
// CHECK-NEXT: %[[UNDEF:[0-9]+]] = llvm.mlir.undef : vector<1xf32>
// CHECK-NEXT: %[[ZERO:[0-9]+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK-NEXT: %[[V:[0-9]+]] = llvm.insertelement %[[ARG]], %[[UNDEF]][%[[ZERO]] : i32] : vector<1xf32>
// CHECK-NEXT: %[[VCAST:[0-9]+]] = builtin.unrealized_conversion_cast %[[V]] : vector<1xf32> to vector<f32>
// CHECK-NEXT: return %[[VCAST]] : vector<f32>
// -----
// CHECK-LABEL: @splat
// CHECK-SAME: %[[A:arg[0-9]+]]: vector<4xf32>
// CHECK-SAME: %[[ELT:arg[0-9]+]]: f32
func @splat(%a: vector<4xf32>, %b: f32) -> vector<4xf32> {
%vb = vector.splat %b : vector<4xf32>
%r = arith.mulf %a, %vb : vector<4xf32>
return %r : vector<4xf32>
}
// CHECK-NEXT: %[[UNDEF:[0-9]+]] = llvm.mlir.undef : vector<4xf32>
// CHECK-NEXT: %[[ZERO:[0-9]+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK-NEXT: %[[V:[0-9]+]] = llvm.insertelement %[[ELT]], %[[UNDEF]][%[[ZERO]] : i32] : vector<4xf32>
// CHECK-NEXT: %[[SPLAT:[0-9]+]] = llvm.shufflevector %[[V]], %[[UNDEF]] [0 : i32, 0 : i32, 0 : i32, 0 : i32]
// CHECK-NEXT: %[[SCALE:[0-9]+]] = arith.mulf %[[A]], %[[SPLAT]] : vector<4xf32>
// CHECK-NEXT: return %[[SCALE]] : vector<4xf32>


@ -168,3 +168,14 @@ func @fma(%a: vector<4xf32>, %b: vector<4xf32>, %c: vector<4xf32>) -> vector<4xf
%0 = vector.fma %a, %b, %c: vector<4xf32>
return %0 : vector<4xf32>
}
// -----
// CHECK-LABEL: func @splat
// CHECK-SAME: (%[[A:.+]]: f32)
// CHECK: %[[VAL:.+]] = spv.CompositeConstruct %[[A]], %[[A]], %[[A]], %[[A]] : vector<4xf32>
// CHECK: return %[[VAL]]
func @splat(%f : f32) -> vector<4xf32> {
%splat = vector.splat %f : vector<4xf32>
return %splat : vector<4xf32>
}


@ -50,10 +50,3 @@ func @switch_i64(%flag : i64, %caseOperand : i32) {
^bb3(%bb3arg : i32):
return
}
// CHECK-LABEL: func @vector_splat_0d(
func @vector_splat_0d(%a: f32) -> vector<f32> {
// CHECK: splat %{{.*}} : vector<f32>
%0 = splat %a : vector<f32>
return %0 : vector<f32>
}


@ -1219,3 +1219,15 @@ func @propogate_index_cast(%arg0: tensor<1xi32>) -> index {
%1 = tensor.extract %0[%c0] : tensor<1xindex>
return %1 : index
}
// -----
// CHECK-LABEL: func @splat_fold
func @splat_fold() -> tensor<4xf32> {
%c = arith.constant 1.0 : f32
%t = tensor.splat %c : tensor<4xf32>
return %t : tensor<4xf32>
// CHECK-NEXT: [[T:%.*]] = arith.constant dense<1.000000e+00> : tensor<4xf32>
// CHECK-NEXT: return [[T]] : tensor<4xf32>
}


@ -363,3 +363,18 @@ func @pad_yield_type(%arg0: tensor<?x4xi32>, %arg1: i8) -> tensor<?x9xi32> {
return %0 : tensor<?x9xi32>
}
// -----
func @invalid_splat(%v : f32) {
// expected-error@+1 {{invalid kind of type specified}}
tensor.splat %v : memref<8xf32>
return
}
// -----
func @invalid_splat(%v : vector<8xf32>) {
// expected-error@+1 {{must be integer/index/float type}}
%w = tensor.splat %v : tensor<8xvector<8xf32>>
return
}


@ -250,3 +250,13 @@ func @pad_to_static_size(%arg0: tensor<?x?xf32>, %ub0: index, %ub1: index,
// -----
// CHECK-LABEL: func @test_splat_op
// CHECK-SAME: [[S:%arg[0-9]+]]: f32
func @test_splat_op(%s : f32) {
// CHECK: tensor.splat [[S]] : tensor<8xf32>
%v = tensor.splat %s : tensor<8xf32>
// CHECK: tensor.splat [[S]] : tensor<4xf32>
%u = "tensor.splat"(%s) : (f32) -> tensor<4xf32>
return
}


@ -515,7 +515,7 @@ func @fold_extract_broadcast(%a : f32) -> f32 {
// CHECK-SAME: %[[A:.*]]: f32
// CHECK: return %[[A]] : f32
func @fold_extract_splat(%a : f32) -> f32 {
%b = splat %a : vector<1x2x4xf32>
%b = vector.splat %a : vector<1x2x4xf32>
%r = vector.extract %b[0, 1, 2] : vector<1x2x4xf32>
return %r : f32
}
@ -1121,10 +1121,10 @@ func @insert_strided_slice_full_range(%source: vector<16x16xf16>, %dest: vector<
// -----
// CHECK-LABEL: extract_strided_splat
// CHECK: %[[B:.*]] = splat %{{.*}} : vector<2x4xf16>
// CHECK: %[[B:.*]] = vector.splat %{{.*}} : vector<2x4xf16>
// CHECK-NEXT: return %[[B]] : vector<2x4xf16>
func @extract_strided_splat(%arg0: f16) -> vector<2x4xf16> {
%0 = splat %arg0 : vector<16x4xf16>
%0 = vector.splat %arg0 : vector<16x4xf16>
%1 = vector.extract_strided_slice %0
{offsets = [1, 0], sizes = [2, 4], strides = [1, 1]} :
vector<16x4xf16> to vector<2x4xf16>
@ -1242,3 +1242,15 @@ func @extract_extract_strided2(%A: vector<2x4xf32>)
%1 = vector.extract %0[0] : vector<1x4xf32>
return %1 : vector<4xf32>
}
// -----
// CHECK-LABEL: func @splat_fold
func @splat_fold() -> vector<4xf32> {
%c = arith.constant 1.0 : f32
%v = vector.splat %c : vector<4xf32>
return %v : vector<4xf32>
// CHECK-NEXT: [[V:%.*]] = arith.constant dense<1.000000e+00> : vector<4xf32>
// CHECK-NEXT: return [[V]] : vector<4xf32>
}


@ -300,7 +300,7 @@ func @test_vector.transfer_read(%arg0: memref<?x?xf32>) {
func @test_vector.transfer_read(%arg0: vector<4x3xf32>) {
%c3 = arith.constant 3 : index
%f0 = arith.constant 0.0 : f32
%vf0 = splat %f0 : vector<4x3xf32>
%vf0 = vector.splat %f0 : vector<4x3xf32>
// expected-error@+1 {{ requires memref or ranked tensor type}}
%0 = vector.transfer_read %arg0[%c3, %c3], %vf0 : vector<4x3xf32>, vector<1x1x2x3xf32>
}
@ -310,7 +310,7 @@ func @test_vector.transfer_read(%arg0: vector<4x3xf32>) {
func @test_vector.transfer_read(%arg0: memref<4x3xf32>) {
%c3 = arith.constant 3 : index
%f0 = arith.constant 0.0 : f32
%vf0 = splat %f0 : vector<4x3xf32>
%vf0 = vector.splat %f0 : vector<4x3xf32>
// expected-error@+1 {{ requires vector type}}
%0 = vector.transfer_read %arg0[%c3, %c3], %vf0 : memref<4x3xf32>, f32
}
@ -376,7 +376,7 @@ func @test_vector.transfer_read(%arg0: memref<?x?x?xf32>) {
%c3 = arith.constant 3 : index
%cst = arith.constant 3.0 : f32
// expected-note@+1 {{prior use here}}
%mask = splat %c1 : vector<3x8x7xi1>
%mask = vector.splat %c1 : vector<3x8x7xi1>
// expected-error@+1 {{expects different type than prior uses: 'vector<3x7xi1>' vs 'vector<3x8x7xi1>'}}
%0 = vector.transfer_read %arg0[%c3, %c3, %c3], %cst, %mask {permutation_map = affine_map<(d0, d1, d2)->(d0, 0, d2)>} : memref<?x?x?xf32>, vector<3x8x7xf32>
}
@ -386,7 +386,7 @@ func @test_vector.transfer_read(%arg0: memref<?x?x?xf32>) {
func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) {
%c3 = arith.constant 3 : index
%f0 = arith.constant 0.0 : f32
%vf0 = splat %f0 : vector<4x3xf32>
%vf0 = vector.splat %f0 : vector<4x3xf32>
// expected-error@+1 {{requires source vector element and vector result ranks to match}}
%0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<4x3xf32>>, vector<3xf32>
}
@ -396,7 +396,7 @@ func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) {
func @test_vector.transfer_read(%arg0: memref<?x?xvector<6xf32>>) {
%c3 = arith.constant 3 : index
%f0 = arith.constant 0.0 : f32
%vf0 = splat %f0 : vector<6xf32>
%vf0 = vector.splat %f0 : vector<6xf32>
// expected-error@+1 {{requires the bitwidth of the minor 1-D vector to be an integral multiple of the bitwidth of the minor 1-D vector of the source}}
%0 = vector.transfer_read %arg0[%c3, %c3], %vf0 : memref<?x?xvector<6xf32>>, vector<3xf32>
}
@ -406,7 +406,7 @@ func @test_vector.transfer_read(%arg0: memref<?x?xvector<6xf32>>) {
func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
%c3 = arith.constant 3 : index
%f0 = arith.constant 0.0 : f32
%vf0 = splat %f0 : vector<2x3xf32>
%vf0 = vector.splat %f0 : vector<2x3xf32>
// expected-error@+1 {{ expects the optional in_bounds attr of same rank as permutation_map results: affine_map<(d0, d1) -> (d0, d1)>}}
%0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {in_bounds = [true], permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<2x3xf32>>, vector<1x1x2x3xf32>
}
@ -416,7 +416,7 @@ func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
%c3 = arith.constant 3 : index
%f0 = arith.constant 0.0 : f32
%vf0 = splat %f0 : vector<2x3xf32>
%vf0 = vector.splat %f0 : vector<2x3xf32>
// expected-error@+1 {{requires broadcast dimensions to be in-bounds}}
%0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {in_bounds = [false, true], permutation_map = affine_map<(d0, d1)->(0, d1)>} : memref<?x?xvector<2x3xf32>>, vector<1x1x2x3xf32>
}
@ -426,8 +426,8 @@ func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
%c3 = arith.constant 3 : index
%f0 = arith.constant 0.0 : f32
%vf0 = splat %f0 : vector<2x3xf32>
%mask = splat %c1 : vector<2x3xi1>
%vf0 = vector.splat %f0 : vector<2x3xf32>
%mask = vector.splat %c1 : vector<2x3xi1>
// expected-error@+1 {{does not support masks with vector element type}}
%0 = vector.transfer_read %arg0[%c3, %c3], %vf0, %mask {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<2x3xf32>>, vector<1x1x2x3xf32>
}
@ -446,7 +446,7 @@ func @test_vector.transfer_write(%arg0: memref<?x?xf32>) {
func @test_vector.transfer_write(%arg0: memref<vector<4x3xf32>>) {
%c3 = arith.constant 3 : index
%f0 = arith.constant 0.0 : f32
%vf0 = splat %f0 : vector<4x3xf32>
%vf0 = vector.splat %f0 : vector<4x3xf32>
// expected-error@+1 {{ requires vector type}}
vector.transfer_write %arg0, %arg0[%c3, %c3] : memref<vector<4x3xf32>>, vector<4x3xf32>
}
@ -456,7 +456,7 @@ func @test_vector.transfer_write(%arg0: memref<vector<4x3xf32>>) {
func @test_vector.transfer_write(%arg0: vector<4x3xf32>) {
%c3 = arith.constant 3 : index
%f0 = arith.constant 0.0 : f32
%vf0 = splat %f0 : vector<4x3xf32>
%vf0 = vector.splat %f0 : vector<4x3xf32>
// expected-error@+1 {{ requires memref or ranked tensor type}}
vector.transfer_write %arg0, %arg0[%c3, %c3] : vector<4x3xf32>, f32
}
@ -1506,3 +1506,11 @@ func @scan_incompatible_shapes(%arg0: vector<2x3xi32>, %arg1: vector<5xi32>) ->
vector<2x3xi32>, vector<5xi32>
return %0#0 : vector<2x3xi32>
}
// -----
func @invalid_splat(%v : f32) {
// expected-error@+1 {{invalid kind of type specified}}
vector.splat %v : memref<8xf32>
return
}


@ -45,11 +45,11 @@ func @vector_transfer_ops(%arg0: memref<?x?xf32>,
%i0 = arith.constant 0 : index
%i1 = arith.constant 1 : i1
%vf0 = splat %f0 : vector<4x3xf32>
%v0 = splat %c0 : vector<4x3xi32>
%vi0 = splat %i0 : vector<4x3xindex>
%vf0 = vector.splat %f0 : vector<4x3xf32>
%v0 = vector.splat %c0 : vector<4x3xi32>
%vi0 = vector.splat %i0 : vector<4x3xindex>
%m = arith.constant dense<[0, 0, 1, 0, 1]> : vector<5xi1>
%m2 = splat %i1 : vector<5x4xi1>
%m2 = vector.splat %i1 : vector<5x4xi1>
//
// CHECK: vector.transfer_read
%0 = vector.transfer_read %arg0[%c3, %c3], %f0 {permutation_map = affine_map<(d0, d1)->(d0)>} : memref<?x?xf32>, vector<128xf32>
@ -106,9 +106,9 @@ func @vector_transfer_ops_tensor(%arg0: tensor<?x?xf32>,
%c0 = arith.constant 0 : i32
%i0 = arith.constant 0 : index
%vf0 = splat %f0 : vector<4x3xf32>
%v0 = splat %c0 : vector<4x3xi32>
%vi0 = splat %i0 : vector<4x3xindex>
%vf0 = vector.splat %f0 : vector<4x3xf32>
%v0 = vector.splat %c0 : vector<4x3xi32>
%vi0 = vector.splat %i0 : vector<4x3xindex>
//
// CHECK: vector.transfer_read
@ -725,3 +725,21 @@ func @vector_scan(%0: vector<4x8x16x32xf32>) -> vector<4x8x16x32xf32> {
vector<4x8x16x32xf32>, vector<4x16x32xf32>
return %2#0 : vector<4x8x16x32xf32>
}
// CHECK-LABEL: func @test_splat_op
// CHECK-SAME: [[S:%arg[0-9]+]]: f32
func @test_splat_op(%s : f32) {
// CHECK: vector.splat [[S]] : vector<8xf32>
%v = vector.splat %s : vector<8xf32>
// CHECK: vector.splat [[S]] : vector<4xf32>
%u = "vector.splat"(%s) : (f32) -> vector<4xf32>
return
}
// CHECK-LABEL: func @vector_splat_0d(
func @vector_splat_0d(%a: f32) -> vector<f32> {
// CHECK: vector.splat %{{.*}} : vector<f32>
%0 = vector.splat %a : vector<f32>
return %0 : vector<f32>
}
@ -268,11 +268,11 @@ func @full_contract2(%arg0: vector<2x3xf32>,
// CHECK-SAME: %[[B:.*1]]: vector<3xf32>
// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xf32>
// CHECK: %[[T1:.*]] = splat %[[T0]] : vector<3xf32>
// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32>
// CHECK: %[[T2:.*]] = arith.mulf %[[T1]], %[[B]] : vector<3xf32>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : vector<2xf32>
// CHECK: %[[T5:.*]] = splat %[[T4]] : vector<3xf32>
// CHECK: %[[T5:.*]] = vector.splat %[[T4]] : vector<3xf32>
// CHECK: %[[T6:.*]] = arith.mulf %[[T5]], %[[B]] : vector<3xf32>
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<2x3xf32>
// CHECK: return %[[T7]] : vector<2x3xf32>
@ -289,12 +289,12 @@ func @outerproduct_noacc(%arg0: vector<2xf32>,
// CHECK-SAME: %[[C:.*2]]: vector<2x3xf32>
// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xf32>
// CHECK: %[[T1:.*]] = splat %[[T0]] : vector<3xf32>
// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32>
// CHECK: %[[T2:.*]] = vector.extract %[[C]][0] : vector<2x3xf32>
// CHECK: %[[T3:.*]] = vector.fma %[[T1]], %[[B]], %[[T2]] : vector<3xf32>
// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32>
// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2xf32>
// CHECK: %[[T6:.*]] = splat %[[T5]] : vector<3xf32>
// CHECK: %[[T6:.*]] = vector.splat %[[T5]] : vector<3xf32>
// CHECK: %[[T7:.*]] = vector.extract %[[C]][1] : vector<2x3xf32>
// CHECK: %[[T8:.*]] = vector.fma %[[T6]], %[[B]], %[[T7]] : vector<3xf32>
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : vector<3xf32> into vector<2x3xf32>
@ -312,11 +312,11 @@ func @outerproduct_acc(%arg0: vector<2xf32>,
// CHECK-SAME: %[[B:.*1]]: vector<3xi32>
// CHECK: %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xi32>
// CHECK: %[[T1:.*]] = splat %[[T0]] : vector<3xi32>
// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32>
// CHECK: %[[T2:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32>
// CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : vector<2xi32>
// CHECK: %[[T5:.*]] = splat %[[T4]] : vector<3xi32>
// CHECK: %[[T5:.*]] = vector.splat %[[T4]] : vector<3xi32>
// CHECK: %[[T6:.*]] = arith.muli %[[T5]], %[[B]] : vector<3xi32>
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xi32> into vector<2x3xi32>
// CHECK: return %[[T7]] : vector<2x3xi32>
@ -332,13 +332,13 @@ func @outerproduct_noacc_int(%arg0: vector<2xi32>,
// CHECK-SAME: %[[C:.*2]]: vector<2x3xi32>
// CHECK: %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xi32>
// CHECK: %[[T1:.*]] = splat %[[T0]] : vector<3xi32>
// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32>
// CHECK: %[[T2:.*]] = vector.extract %[[C]][0] : vector<2x3xi32>
// CHECK: %[[T3:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32>
// CHECK: %[[T4:.*]] = arith.addi %[[T3]], %[[T2]] : vector<3xi32>
// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32>
// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xi32>
// CHECK: %[[T7:.*]] = splat %[[T6]] : vector<3xi32>
// CHECK: %[[T7:.*]] = vector.splat %[[T6]] : vector<3xi32>
// CHECK: %[[T8:.*]] = vector.extract %[[C]][1] : vector<2x3xi32>
// CHECK: %[[T9:.*]] = arith.muli %[[T7]], %[[B]] : vector<3xi32>
// CHECK: %[[T10:.*]] = arith.addi %[[T9]], %[[T8]] : vector<3xi32>
@ -354,7 +354,7 @@ func @outerproduct_acc_int(%arg0: vector<2xi32>,
// CHECK-LABEL: func @axpy_fp(
// CHECK-SAME: %[[A:.*0]]: vector<16xf32>,
// CHECK-SAME: %[[B:.*1]]: f32)
// CHECK: %[[T0:.*]] = splat %[[B]] : vector<16xf32>
// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32>
// CHECK: %[[T1:.*]] = arith.mulf %[[A]], %[[T0]] : vector<16xf32>
// CHECK: return %[[T1]] : vector<16xf32>
func @axpy_fp(%arg0: vector<16xf32>, %arg1: f32) -> vector<16xf32> {
@ -366,7 +366,7 @@ func @axpy_fp(%arg0: vector<16xf32>, %arg1: f32) -> vector<16xf32> {
// CHECK-SAME: %[[A:.*0]]: vector<16xf32>,
// CHECK-SAME: %[[B:.*1]]: f32,
// CHECK-SAME: %[[C:.*2]]: vector<16xf32>)
// CHECK: %[[T0:.*]] = splat %[[B]] : vector<16xf32>
// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32>
// CHECK: %[[T1:.*]] = vector.fma %[[A]], %[[T0]], %[[C]] : vector<16xf32>
// CHECK: return %[[T1]] : vector<16xf32>
func @axpy_fp_add(%arg0: vector<16xf32>, %arg1: f32, %arg2 : vector<16xf32>) -> vector<16xf32> {
@ -377,7 +377,7 @@ func @axpy_fp_add(%arg0: vector<16xf32>, %arg1: f32, %arg2 : vector<16xf32>) ->
// CHECK-LABEL: func @axpy_int(
// CHECK-SAME: %[[A:.*0]]: vector<16xi32>,
// CHECK-SAME: %[[B:.*1]]: i32)
// CHECK: %[[T0:.*]] = splat %[[B]] : vector<16xi32>
// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32>
// CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32>
// CHECK: return %[[T1]] : vector<16xi32>
func @axpy_int(%arg0: vector<16xi32>, %arg1: i32) -> vector<16xi32> {
@ -389,7 +389,7 @@ func @axpy_int(%arg0: vector<16xi32>, %arg1: i32) -> vector<16xi32> {
// CHECK-SAME: %[[A:.*0]]: vector<16xi32>,
// CHECK-SAME: %[[B:.*1]]: i32,
// CHECK-SAME: %[[C:.*2]]: vector<16xi32>)
// CHECK: %[[T0:.*]] = splat %[[B]] : vector<16xi32>
// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32>
// CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32>
// CHECK: %[[T2:.*]] = arith.addi %[[T1]], %[[C]] : vector<16xi32>
// CHECK: return %[[T2]] : vector<16xi32>
@ -612,7 +612,7 @@ func @matmul(%arg0: vector<2x4xf32>,
// CHECK-LABEL: func @broadcast_vec1d_from_scalar
// CHECK-SAME: %[[A:.*0]]: f32
// CHECK: %[[T0:.*]] = splat %[[A]] : vector<2xf32>
// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2xf32>
// CHECK: return %[[T0]] : vector<2xf32>
func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> {
@ -622,7 +622,7 @@ func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> {
// CHECK-LABEL: func @broadcast_vec2d_from_scalar
// CHECK-SAME: %[[A:.*0]]: f32
// CHECK: %[[T0:.*]] = splat %[[A]] : vector<2x3xf32>
// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2x3xf32>
// CHECK: return %[[T0]] : vector<2x3xf32>
func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> {
@ -632,7 +632,7 @@ func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> {
// CHECK-LABEL: func @broadcast_vec3d_from_scalar
// CHECK-SAME: %[[A:.*0]]: f32
// CHECK: %[[T0:.*]] = splat %[[A]] : vector<2x3x4xf32>
// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2x3x4xf32>
// CHECK: return %[[T0]] : vector<2x3x4xf32>
func @broadcast_vec3d_from_scalar(%arg0: f32) -> vector<2x3x4xf32> {
@ -697,7 +697,7 @@ func @broadcast_vec3d_from_vec2d(%arg0: vector<3x2xf32>) -> vector<4x3x2xf32> {
// CHECK-LABEL: func @broadcast_stretch
// CHECK-SAME: %[[A:.*0]]: vector<1xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<1xf32>
// CHECK: %[[T1:.*]] = splat %[[T0]] : vector<4xf32>
// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<4xf32>
// CHECK: return %[[T1]] : vector<4xf32>
func @broadcast_stretch(%arg0: vector<1xf32>) -> vector<4xf32> {
@ -723,16 +723,16 @@ func @broadcast_stretch_at_start(%arg0: vector<1x4xf32>) -> vector<3x4xf32> {
// CHECK-SAME: %[[A:.*0]]: vector<4x1xf32>
// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<4x3xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<4x1xf32>
// CHECK: %[[T2:.*]] = splat %[[T0]] : vector<3xf32>
// CHECK: %[[T2:.*]] = vector.splat %[[T0]] : vector<3xf32>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xf32> into vector<4x3xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[A]][1, 0] : vector<4x1xf32>
// CHECK: %[[T6:.*]] = splat %[[T4]] : vector<3xf32>
// CHECK: %[[T6:.*]] = vector.splat %[[T4]] : vector<3xf32>
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<4x3xf32>
// CHECK: %[[T8:.*]] = vector.extract %[[A]][2, 0] : vector<4x1xf32>
// CHECK: %[[T10:.*]] = splat %[[T8]] : vector<3xf32>
// CHECK: %[[T10:.*]] = vector.splat %[[T8]] : vector<3xf32>
// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T7]] [2] : vector<3xf32> into vector<4x3xf32>
// CHECK: %[[T12:.*]] = vector.extract %[[A]][3, 0] : vector<4x1xf32>
// CHECK: %[[T14:.*]] = splat %[[T12]] : vector<3xf32>
// CHECK: %[[T14:.*]] = vector.splat %[[T12]] : vector<3xf32>
// CHECK: %[[T15:.*]] = vector.insert %[[T14]], %[[T11]] [3] : vector<3xf32> into vector<4x3xf32>
// CHECK: return %[[T15]] : vector<4x3xf32>
@ -282,19 +282,19 @@ func @transfer_read_permutations(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?x?x?
%c0 = arith.constant 0 : index
%m = arith.constant 1 : i1
%mask0 = splat %m : vector<7x14xi1>
%mask0 = vector.splat %m : vector<7x14xi1>
%0 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst, %mask0 {in_bounds = [true, false, true, true], permutation_map = #map0} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
// CHECK: %[[MASK0:.*]] = vector.transpose {{.*}} : vector<7x14xi1> to vector<14x7xi1>
// CHECK: vector.transfer_read {{.*}} %[[MASK0]] {in_bounds = [false, true, true, true], permutation_map = #[[$MAP0]]} : memref<?x?x?x?xf32>, vector<14x7x8x16xf32>
// CHECK: vector.transpose %{{.*}}, [1, 0, 2, 3] : vector<14x7x8x16xf32> to vector<7x14x8x16xf32>
%mask1 = splat %m : vector<14x16xi1>
%mask1 = vector.splat %m : vector<14x16xi1>
%1 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst, %mask1 {permutation_map = #map1} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
// CHECK: %[[MASK1:.*]] = vector.transpose {{.*}} : vector<14x16xi1> to vector<16x14xi1>
// CHECK: vector.transfer_read {{.*}} %[[MASK1]] {permutation_map = #[[$MAP0]]} : memref<?x?x?x?xf32>, vector<16x14x7x8xf32>
// CHECK: vector.transpose %{{.*}}, [2, 1, 3, 0] : vector<16x14x7x8xf32> to vector<7x14x8x16xf32>
%mask2 = splat %m : vector<7x14xi1>
%mask2 = vector.splat %m : vector<7x14xi1>
%2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst, %mask2 {in_bounds = [true, false, true, true], permutation_map = #map2} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
// CHECK: %[[MASK2:.*]] = vector.transpose {{.*}} : vector<7x14xi1> to vector<14x7xi1>
// CHECK: vector.transfer_read {{.*}} %[[MASK2]] {in_bounds = [false, true, true], permutation_map = #[[$MAP1]]} : memref<?x?x?x?xf32>, vector<14x16x7xf32>
@ -336,7 +336,7 @@ func @transfer_write_permutations(
%c0 = arith.constant 0 : index
%m = arith.constant 1 : i1
%mask0 = splat %m : vector<7x14x8x16xi1>
%mask0 = vector.splat %m : vector<7x14x8x16xi1>
%0 = vector.transfer_write %v1, %arg1[%c0, %c0, %c0, %c0], %mask0 {in_bounds = [true, false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3, d0)>} : vector<7x14x8x16xf32>, tensor<?x?x?x?xf32>
// CHECK: %[[NEW_MASK0:.*]] = vector.transpose %{{.*}} [2, 1, 3, 0] : vector<7x14x8x16xi1> to vector<8x14x16x7xi1>
// CHECK: %[[NEW_VEC0:.*]] = vector.transpose %{{.*}} [2, 1, 3, 0] : vector<7x14x8x16xf32> to vector<8x14x16x7xf32>
@ -295,18 +295,6 @@ func @test_dimop(%arg0: tensor<4x4x?xf32>) {
return
}
// CHECK-LABEL: func @test_splat_op
// CHECK-SAME: [[S:%arg[0-9]+]]: f32
func @test_splat_op(%s : f32) {
%v = splat %s : vector<8xf32>
// CHECK: splat [[S]] : vector<8xf32>
%t = splat %s : tensor<8xf32>
// CHECK: splat [[S]] : tensor<8xf32>
%u = "std.splat"(%s) : (f32) -> vector<4xf32>
// CHECK: splat [[S]] : vector<4xf32>
return
}
// CHECK-LABEL: func @tensor_load_store
func @tensor_load_store(%0 : memref<4x4xi32>, %1 : tensor<4x4xi32>) {
// CHECK-SAME: (%[[MEMREF:.*]]: memref<4x4xi32>,
@ -106,24 +106,8 @@ func @return_not_in_function() {
// -----
func @invalid_splat(%v : f32) {
splat %v : memref<8xf32>
// expected-error@-1 {{must be vector of any type values or statically shaped tensor of any type values}}
return
}
// -----
func @invalid_splat(%v : vector<8xf32>) {
%w = splat %v : tensor<8xvector<8xf32>>
// expected-error@-1 {{must be integer/index/float type}}
return
}
// -----
func @invalid_splat(%v : f32) { // expected-note {{prior use here}}
splat %v : vector<8xf64>
vector.splat %v : vector<8xf64>
// expected-error@-1 {{expects different type than prior uses}}
return
}
@ -22,7 +22,7 @@ func @print_vector_0d(%a: vector<f32>) {
}
func @splat_0d(%a: f32) {
%1 = splat %a : vector<f32>
%1 = vector.splat %a : vector<f32>
// CHECK: ( 42 )
vector.print %1: vector<f32>
return
@ -14,9 +14,9 @@
!vector_type_R = type vector<7xf32>
func @vector_outerproduct_splat_8x8(%fa: f32, %fb: f32, %fc: f32) -> !vector_type_C {
%a = splat %fa: !vector_type_A
%b = splat %fb: !vector_type_B
%c = splat %fc: !vector_type_C
%a = vector.splat %fa: !vector_type_A
%b = vector.splat %fb: !vector_type_B
%c = vector.splat %fc: !vector_type_C
%d = vector.outerproduct %a, %b, %c : !vector_type_A, !vector_type_B
return %d: !vector_type_C
}
@ -14,9 +14,9 @@
!vector_type_R = type vector<7xi64>
func @vector_outerproduct_splat_8x8(%ia: i64, %ib: i64, %ic: i64) -> !vector_type_C {
%a = splat %ia: !vector_type_A
%b = splat %ib: !vector_type_B
%c = splat %ic: !vector_type_C
%a = vector.splat %ia: !vector_type_A
%b = vector.splat %ib: !vector_type_B
%c = vector.splat %ic: !vector_type_C
%d = vector.outerproduct %a, %b, %c : !vector_type_A, !vector_type_B
return %d: !vector_type_C
}
@ -138,7 +138,7 @@ func @transfer_read_1d_mask_in_bounds(
// Non-contiguous, strided store.
func @transfer_write_1d(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
%fn1 = arith.constant -1.0 : f32
%vf0 = splat %fn1 : vector<7xf32>
%vf0 = vector.splat %fn1 : vector<7xf32>
vector.transfer_write %vf0, %A[%base1, %base2]
{permutation_map = affine_map<(d0, d1) -> (d0)>}
: vector<7xf32>, memref<?x?xf32>
@ -148,7 +148,7 @@ func @transfer_write_1d(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
// Non-contiguous, strided store.
func @transfer_write_1d_mask(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
%fn1 = arith.constant -2.0 : f32
%vf0 = splat %fn1 : vector<7xf32>
%vf0 = vector.splat %fn1 : vector<7xf32>
%mask = arith.constant dense<[1, 0, 1, 0, 1, 1, 1]> : vector<7xi1>
vector.transfer_write %vf0, %A[%base1, %base2], %mask
{permutation_map = affine_map<(d0, d1) -> (d0)>}
@ -111,7 +111,7 @@ func @transfer_read_2d_broadcast(
// Vector store.
func @transfer_write_2d(%A : memref<?x?xf32>, %base1: index, %base2: index) {
%fn1 = arith.constant -1.0 : f32
%vf0 = splat %fn1 : vector<1x4xf32>
%vf0 = vector.splat %fn1 : vector<1x4xf32>
vector.transfer_write %vf0, %A[%base1, %base2]
{permutation_map = affine_map<(d0, d1) -> (d0, d1)>} :
vector<1x4xf32>, memref<?x?xf32>
@ -122,7 +122,7 @@ func @transfer_write_2d(%A : memref<?x?xf32>, %base1: index, %base2: index) {
func @transfer_write_2d_mask(%A : memref<?x?xf32>, %base1: index, %base2: index) {
%fn1 = arith.constant -2.0 : f32
%mask = arith.constant dense<[[1, 0, 1, 0]]> : vector<1x4xi1>
%vf0 = splat %fn1 : vector<1x4xf32>
%vf0 = vector.splat %fn1 : vector<1x4xf32>
vector.transfer_write %vf0, %A[%base1, %base2], %mask
{permutation_map = affine_map<(d0, d1) -> (d0, d1)>} :
vector<1x4xf32>, memref<?x?xf32>
@ -72,7 +72,7 @@ func @transfer_read_3d_transposed(%A : memref<?x?x?x?xf32>,
func @transfer_write_3d(%A : memref<?x?x?x?xf32>,
%o: index, %a: index, %b: index, %c: index) {
%fn1 = arith.constant -1.0 : f32
%vf0 = splat %fn1 : vector<2x9x3xf32>
%vf0 = vector.splat %fn1 : vector<2x9x3xf32>
vector.transfer_write %vf0, %A[%o, %a, %b, %c]
: vector<2x9x3xf32>, memref<?x?x?x?xf32>
return
@ -45,7 +45,7 @@ func @transfer_read_mask_inbounds_4(%A : memref<?xf32>, %base: index) {
func @transfer_write_1d(%A : memref<?xf32>, %base: index) {
%f0 = arith.constant 0.0 : f32
%vf0 = splat %f0 : vector<4xf32>
%vf0 = vector.splat %f0 : vector<4xf32>
vector.transfer_write %vf0, %A[%base]
{permutation_map = affine_map<(d0) -> (d0)>} :
vector<4xf32>, memref<?xf32>
@ -5,7 +5,7 @@
func @transfer_write16_inbounds_1d(%A : memref<?xf32>, %base: index) {
%f = arith.constant 16.0 : f32
%v = splat %f : vector<16xf32>
%v = vector.splat %f : vector<16xf32>
vector.transfer_write %v, %A[%base]
{permutation_map = affine_map<(d0) -> (d0)>, in_bounds = [true]}
: vector<16xf32>, memref<?xf32>
@ -14,7 +14,7 @@ func @transfer_write16_inbounds_1d(%A : memref<?xf32>, %base: index) {
func @transfer_write13_1d(%A : memref<?xf32>, %base: index) {
%f = arith.constant 13.0 : f32
%v = splat %f : vector<13xf32>
%v = vector.splat %f : vector<13xf32>
vector.transfer_write %v, %A[%base]
{permutation_map = affine_map<(d0) -> (d0)>}
: vector<13xf32>, memref<?xf32>
@ -23,7 +23,7 @@ func @transfer_write13_1d(%A : memref<?xf32>, %base: index) {
func @transfer_write17_1d(%A : memref<?xf32>, %base: index) {
%f = arith.constant 17.0 : f32
%v = splat %f : vector<17xf32>
%v = vector.splat %f : vector<17xf32>
vector.transfer_write %v, %A[%base]
{permutation_map = affine_map<(d0) -> (d0)>}
: vector<17xf32>, memref<?xf32>
@ -789,18 +789,6 @@ func @custom_insertion_position() {
return
}
// CHECK-LABEL: func @splat_fold
func @splat_fold() -> (vector<4xf32>, tensor<4xf32>) {
%c = arith.constant 1.0 : f32
%v = splat %c : vector<4xf32>
%t = splat %c : tensor<4xf32>
return %v, %t : vector<4xf32>, tensor<4xf32>
// CHECK-NEXT: [[V:%.*]] = arith.constant dense<1.000000e+00> : vector<4xf32>
// CHECK-NEXT: [[T:%.*]] = arith.constant dense<1.000000e+00> : tensor<4xf32>
// CHECK-NEXT: return [[V]], [[T]] : vector<4xf32>, tensor<4xf32>
}
// -----
// CHECK-LABEL: func @subview_scalar_fold
@ -56,7 +56,7 @@ func private @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface
func @vector_splat_2d() {
%c0 = arith.constant 0 : index
%f10 = arith.constant 10.0 : f32
%vf10 = splat %f10: !vector_type_C
%vf10 = vector.splat %f10: !vector_type_C
%C = memref.alloc() : !matrix_type_CC
memref.store %vf10, %C[%c0, %c0]: !matrix_type_CC