[mlir][Linalg] Fix tensor.extract_slice(linalg.init_tensor) canonicalization for rank-reducing extract

Differential Revision: https://reviews.llvm.org/D105636
This commit is contained in:
Nicolas Vasilache 2021-07-08 14:26:57 +00:00
parent 8c7ff9da90
commit 4747e1b83b
2 changed files with 15 additions and 3 deletions

View File

@ -772,11 +772,11 @@ struct FoldInitTensorWithExtractSliceOp
PatternRewriter &rewriter) const override {
if (!sliceOp.source().getDefiningOp<linalg::InitTensorOp>())
return failure();
// ExtractSliceOp may be rank-reducing; its dynamic sizes must be preserved
// as well as its result type.
rewriter.replaceOpWithNewOp<linalg::InitTensorOp>(
sliceOp, sliceOp.sizes(),
llvm::to_vector<4>(llvm::map_range(
sliceOp.static_sizes(),
[](Attribute attr) { return attr.cast<IntegerAttr>().getInt(); })),
sliceOp.result().getType().cast<RankedTensorType>().getShape(),
sliceOp.getSourceType().getElementType());
return success();
}

View File

@ -890,3 +890,15 @@ func @init_canonicalize(%i : index) {
return
}
// -----
// CHECK-LABEL: func @rank_reducing_init_extract
func @rank_reducing_init_extract(%sz : index, %idx : index) -> tensor<2xf32> {
// CHECK: linalg.init_tensor [2] : tensor<2xf32>
%a = linalg.init_tensor [%sz, 2] : tensor<?x2xf32>
// CHECK-NOT: extract
%r = tensor.extract_slice %a[%idx, 0] [1, 2] [1, 1] : tensor<?x2xf32> to tensor<2xf32>
return %r: tensor<2xf32>
}