[tosa][mlir] Support dynamic batch dimension for ops where the batch dim is explicit

Adds dynamic batch support to the lowerings for rescale, gather, resize, max_pool2d, avg_pool2d, conv2d, and depthwise_conv2d. Also splits the shared helper functions out into a separate header file.

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D117031
Author: natashaknk, 2022-01-12 14:10:27 -08:00 (committed by Rob Suderman)
Parent: 676bfb2a22
Commit: 310e9636ca
8 changed files with 307 additions and 105 deletions


@ -0,0 +1,84 @@
//===- ConversionUtils.h - Helper functions for tosa conversion -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Utility functions for TOSA lowering
//
//===----------------------------------------------------------------------===//
#ifndef DIALECT_TOSA_UTILS_COVERSION_UTILS_H_
#define DIALECT_TOSA_UTILS_COVERSION_UTILS_H_
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/PatternMatch.h"
namespace mlir {
namespace tosa {
// Creates a SmallVector of StringRefs for N parallel loops.
SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops);
// Takes a vector of values and condenses them to a vector with no gaps.
SmallVector<Value> condenseValues(const SmallVector<Value> &values);
// Takes the parameters for a clamp and turns them into a series of ops.
template <typename T, typename P>
mlir::SelectOp clampHelper(Location loc, Value arg, arith::ConstantOp min,
arith::ConstantOp max, P pred, OpBuilder &rewriter) {
auto smallerThanMin = rewriter.create<T>(loc, pred, arg, min);
auto minOrArg =
rewriter.create<mlir::SelectOp>(loc, smallerThanMin, min, arg);
auto largerThanMax = rewriter.create<T>(loc, pred, max, arg);
return rewriter.create<mlir::SelectOp>(loc, largerThanMax, max, minOrArg);
}
// Reads the integer values of an ArrayAttr into the given vector.
template <typename T>
void getValuesFromIntArrayAttribute(ArrayAttr attr,
SmallVector<T> &arrayValues) {
for (Attribute val : attr.getValue()) {
arrayValues.push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
}
}
// Checks for a dynamic batch dim in any of the passed parameters of an op.
// The batch dimension must be dimension 0 and all other dimensions must be static.
template <typename Op>
Optional<SmallVector<Value>> checkHasDynamicBatchDims(PatternRewriter &rewriter,
Op op,
ArrayRef<Value> params) {
SmallVector<ShapedType> dynTypes;
SmallVector<Value> dynamicDims;
for (const Value &param : params) {
auto paramTy = param.getType().cast<ShapedType>();
if (!paramTy.hasStaticShape())
dynTypes.push_back(paramTy);
}
if (dynTypes.empty())
return dynamicDims;
for (const ShapedType &dynTy : dynTypes) {
if (llvm::any_of(dynTy.getShape().drop_front(), ShapedType::isDynamic)) {
(void)rewriter.notifyMatchFailure(
op, "input can only be dynamic for batch size");
return llvm::None;
}
}
dynamicDims.push_back(
rewriter.create<tensor::DimOp>(op->getLoc(), params[0], 0));
return dynamicDims;
}
} // namespace tosa
} // namespace mlir
#endif // DIALECT_TOSA_UTILS_COVERSION_UTILS_H_
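
For reference, the pattern the converter changes below all follow is: keep requiring static shapes for every dimension except the batch, and thread the dynamic batch size straight into the `linalg.init_tensor` that backs the result. Below is a condensed, hypothetical sketch of a converter using the helpers declared above; `SomeTosaOp`, the operand handling, and the surrounding boilerplate are placeholders, not part of this change:

```cpp
// Hypothetical converter body; mirrors how the patterns in this commit use
// checkHasDynamicBatchDims and the resulting dynamic-dim list.
struct SomeTosaOpConverter : public OpConversionPattern<SomeTosaOp> {
  using OpConversionPattern<SomeTosaOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SomeTosaOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input = adaptor.getOperands()[0];
    ShapedType resultTy = op.getType().cast<ShapedType>();

    // Empty vector when all shapes are static; the batch tensor.dim value
    // when only dim 0 is dynamic; llvm::None (match failure) otherwise.
    auto dynamicDimsOr =
        checkHasDynamicBatchDims(rewriter, op, {input, op.output()});
    if (!dynamicDimsOr.hasValue())
      return failure();
    SmallVector<Value> dynamicDims = dynamicDimsOr.getValue();

    // Dynamic sizes feed the init tensor; static extents still come from
    // the (possibly partially dynamic) result type.
    Value initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, dynamicDims, resultTy.getShape(), resultTy.getElementType());

    // ... create the linalg op that writes into initTensor, then
    // rewriter.replaceOp(op, <linalg result>) ...
    (void)initTensor;
    return success();
  }
};
```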


@ -18,6 +18,7 @@
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/CoversionUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
@ -27,10 +28,7 @@
#include <numeric>
using namespace mlir;
static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
}
using namespace mlir::tosa;
template <typename T>
static arith::ConstantOp
@ -42,33 +40,6 @@ createConstFromIntAttribute(Operation *op, const std::string &attrName,
op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
}
template <typename T>
static void getValuesFromIntArrayAttribute(ArrayAttr attr,
SmallVector<T> &arrayValues) {
for (Attribute val : attr.getValue()) {
arrayValues.push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
}
}
template <typename T, typename P>
static mlir::SelectOp clampHelper(Location loc, Value arg,
arith::ConstantOp min, arith::ConstantOp max,
P pred, OpBuilder &rewriter) {
auto smallerThanMin = rewriter.create<T>(loc, pred, arg, min);
auto minOrArg =
rewriter.create<mlir::SelectOp>(loc, smallerThanMin, min, arg);
auto largerThanMax = rewriter.create<T>(loc, pred, max, arg);
return rewriter.create<mlir::SelectOp>(loc, largerThanMax, max, minOrArg);
}
static SmallVector<Value> filterDynamicDims(const SmallVector<Value> &dynDims) {
SmallVector<Value> filteredDims;
for (auto dim : dynDims)
if (dim)
filteredDims.push_back(dim);
return filteredDims;
}
static Value
createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
ArrayRef<Type> resultTypes,
@ -665,7 +636,7 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
}
}
SmallVector<Value> filteredDims = filterDynamicDims(dynDims);
SmallVector<Value> filteredDims = condenseValues(dynDims);
for (auto result : results) {
auto resultTy = result.getType().template cast<ShapedType>();
@ -1184,7 +1155,7 @@ public:
inputExprs[value] = rewriter.getAffineDimExpr(index);
}
SmallVector<Value> filteredDims = filterDynamicDims(dynDims);
SmallVector<Value> filteredDims = condenseValues(dynDims);
auto initTensor = rewriter.create<linalg::InitTensorOp>(
loc, filteredDims, resultTy.getShape(), resultTy.getElementType());
@ -1221,9 +1192,11 @@ public:
return rewriter.notifyMatchFailure(
op, "tosa.rescale requires scale32 for double_round to be true");
if (!outputTy.hasStaticShape())
return rewriter.notifyMatchFailure(
op, "tosa to linalg conversion expects statically shaped tensors");
auto dynamicDimsOr =
checkHasDynamicBatchDims(rewriter, op, {input, op.output()});
if (!dynamicDimsOr.hasValue())
return failure();
SmallVector<Value> dynamicDims = dynamicDimsOr.getValue();
// The shift and multiplier values.
SmallVector<int32_t> multiplierValues;
@ -1299,8 +1272,7 @@ public:
// Construct the indexing maps needed for linalg.generic ops.
Value initTensor = rewriter.create<linalg::InitTensorOp>(
loc, ArrayRef<Value>({}), outputTy.getShape(),
outputTy.getElementType());
loc, dynamicDims, outputTy.getShape(), outputTy.getElementType());
auto linalgOp = rewriter.create<linalg::GenericOp>(
loc, outputTy, genericInputs, ValueRange{initTensor}, indexingMaps,
@ -1412,16 +1384,17 @@ public:
auto imageH = inputTy.getShape()[1];
auto imageW = inputTy.getShape()[2];
if (!resultTy.hasStaticShape())
auto dynamicDimsOr =
checkHasDynamicBatchDims(rewriter, op, {input, op.output()});
if (!dynamicDimsOr.hasValue())
return failure();
SmallVector<Value> dynamicDims = dynamicDimsOr.getValue();
if (op.mode() != "NEAREST_NEIGHBOR" && op.mode() != "BILINEAR")
return failure();
auto initTensor =
rewriter
.create<linalg::InitTensorOp>(loc, ArrayRef<Value>{},
resultTy.getShape(), resultElementTy)
.result();
auto initTensor = rewriter.create<linalg::InitTensorOp>(
loc, dynamicDims, resultTy.getShape(), resultElementTy);
SmallVector<AffineMap, 2> affineMaps = {
rewriter.getMultiDimIdentityMap(resultTy.getRank())};
@ -2098,13 +2071,13 @@ public:
auto input = adaptor.getOperands()[0];
auto indices = adaptor.getOperands()[1];
auto inputTy = input.getType().cast<ShapedType>();
auto indicesTy = indices.getType().cast<ShapedType>();
auto resultTy = op.getType().cast<ShapedType>();
if (!inputTy.hasStaticShape() || !indicesTy.hasStaticShape())
return rewriter.notifyMatchFailure(
op, "require input type to have static shape");
auto dynamicDimsOr =
checkHasDynamicBatchDims(rewriter, op, {input, indices, op.output()});
if (!dynamicDimsOr.hasValue())
return failure();
SmallVector<Value> dynamicDims = dynamicDimsOr.getValue();
auto resultElementTy = resultTy.getElementType();
@ -2112,8 +2085,8 @@ public:
auto initTensor =
rewriter
.create<linalg::InitTensorOp>(loc, ArrayRef<Value>{},
resultTy.getShape(), resultElementTy)
.create<linalg::InitTensorOp>(loc, dynamicDims, resultTy.getShape(),
resultElementTy)
.result();
SmallVector<AffineMap, 2> affineMaps = {


@ -18,6 +18,7 @@
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/CoversionUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
@ -27,29 +28,7 @@
#include <numeric>
using namespace mlir;
static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
}
template <typename T>
static void getValuesFromIntArrayAttribute(ArrayAttr attr,
SmallVector<T> &arrayValues) {
for (Attribute val : attr.getValue()) {
arrayValues.push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
}
}
template <typename T, typename P>
static mlir::SelectOp clampHelper(Location loc, Value arg,
arith::ConstantOp min, arith::ConstantOp max,
P pred, OpBuilder &rewriter) {
auto smallerThanMin = rewriter.create<T>(loc, pred, arg, min);
auto minOrArg =
rewriter.create<mlir::SelectOp>(loc, smallerThanMin, min, arg);
auto largerThanMax = rewriter.create<T>(loc, pred, max, arg);
return rewriter.create<mlir::SelectOp>(loc, largerThanMax, max, minOrArg);
}
using namespace mlir::tosa;
static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
Attribute padAttr, OpBuilder &rewriter) {
@ -82,14 +61,6 @@ static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
.result();
}
static SmallVector<Value> filterDynamicDims(const SmallVector<Value> &dynDims) {
SmallVector<Value> filteredDims;
for (auto dim : dynDims)
if (dim)
filteredDims.push_back(dim);
return filteredDims;
}
namespace {
class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
@ -116,10 +87,15 @@ public:
auto dilationTosaAttr = op->getAttr("dilation").cast<ArrayAttr>();
bool isQuantized = op->hasAttr("quantization_info");
if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
!biasTy.hasStaticShape() || !resultTy.hasStaticShape())
return rewriter.notifyMatchFailure(op,
"tosa.conv ops require static shapes");
if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
return rewriter.notifyMatchFailure(
op, "tosa.conv ops require static shapes for weight and bias");
auto dynamicDimsOr =
checkHasDynamicBatchDims(rewriter, op, {input, op.output()});
if (!dynamicDimsOr.hasValue())
return failure();
SmallVector<Value> dynamicDims = dynamicDimsOr.getValue();
if (inputETy.isUnsignedInteger())
return rewriter.notifyMatchFailure(
@ -172,7 +148,7 @@ public:
Attribute resultZeroAttr = rewriter.getZeroAttr(resultETy);
Value initTensor = rewriter.create<linalg::InitTensorOp>(
loc, resultTy.getShape(), resultETy);
loc, dynamicDims, resultTy.getShape(), resultETy);
Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
Value zeroTensor =
rewriter.create<linalg::FillOp>(loc, zero, initTensor).getResult(0);
@ -197,7 +173,7 @@ public:
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));
Value biasInitTensor = rewriter.create<linalg::InitTensorOp>(
loc, resultTy.getShape(), resultETy);
loc, dynamicDims, resultTy.getShape(), resultETy);
if (isQuantized) {
auto quantizationInfo =
@ -292,10 +268,15 @@ public:
quantizationInfo.weight_zp().getValue().getSExtValue());
}
if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
!biasTy.hasStaticShape() || !resultTy.hasStaticShape())
return rewriter.notifyMatchFailure(op,
"tosa.conv ops require static shapes");
if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
return rewriter.notifyMatchFailure(
op, "tosa.depthwise_conv ops require static shapes");
auto dynamicDimsOr =
checkHasDynamicBatchDims(rewriter, op, {input, op.output()});
if (!dynamicDimsOr.hasValue())
return failure();
SmallVector<Value> dynamicDims = dynamicDimsOr.getValue();
auto weightShape = weightTy.getShape();
auto resultShape = resultTy.getShape();
@ -354,13 +335,13 @@ public:
Attribute resultZeroAttr = rewriter.getZeroAttr(resultETy);
Value initTensor = rewriter.create<linalg::InitTensorOp>(
loc, linalgConvTy.getShape(), resultETy);
loc, dynamicDims, linalgConvTy.getShape(), resultETy);
Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
Value zeroTensor =
rewriter.create<linalg::FillOp>(loc, zero, initTensor).getResult(0);
Value biasInitTensor = rewriter.create<linalg::InitTensorOp>(
loc, resultTy.getShape(), resultETy);
loc, dynamicDims, resultTy.getShape(), resultETy);
if (!isQuantized) {
Value conv = rewriter
.create<linalg::DepthwiseConv2DNhwcHwcmOp>(
@ -442,7 +423,7 @@ public:
dynDims[2] = rewriter.create<tensor::DimOp>(loc, op->getOperand(1), 2);
}
SmallVector<Value> filteredDims = filterDynamicDims(dynDims);
SmallVector<Value> filteredDims = condenseValues(dynDims);
auto zeroAttr = rewriter.getZeroAttr(outputElementTy);
Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
@ -503,7 +484,7 @@ public:
dynDims[1] = rewriter.create<tensor::DimOp>(loc, weight, 0);
}
SmallVector<Value> filteredDims = filterDynamicDims(dynDims);
SmallVector<Value> filteredDims = condenseValues(dynDims);
// Creating maps for the output of MatMul and the bias
SmallVector<AffineMap, 4> indexingMaps;
@ -611,8 +592,11 @@ public:
ShapedType resultTy = op.getType().template cast<ShapedType>();
Type resultETy = inputTy.getElementType();
if (!inputTy.hasStaticShape())
auto dynamicDimsOr =
checkHasDynamicBatchDims(rewriter, op, {input, op.output()});
if (!dynamicDimsOr.hasValue())
return failure();
SmallVector<Value> dynamicDims = dynamicDimsOr.getValue();
// Determine what the initial value needs to be for the max pool op.
Attribute initialAttr;
@ -649,7 +633,7 @@ public:
// Create the linalg op that performs pooling.
Value initTensor = rewriter.create<linalg::InitTensorOp>(
loc, resultTy.getShape(), resultTy.getElementType());
loc, dynamicDims, resultTy.getShape(), resultTy.getElementType());
Value filledInitTensor =
rewriter.create<linalg::FillOp>(loc, initialValue, initTensor).result();
@ -682,8 +666,11 @@ public:
inElementTy.isa<IntegerType>() ? rewriter.getI32Type() : inElementTy;
ShapedType accTy = resultTy.clone(accETy);
if (!inputTy.hasStaticShape())
auto dynamicDimsOr =
checkHasDynamicBatchDims(rewriter, op, {input, op.output()});
if (!dynamicDimsOr.hasValue())
return failure();
SmallVector<Value> dynamicDims = dynamicDimsOr.getValue();
// Apply padding as necessary.
llvm::SmallVector<int64_t> pad;
@ -704,8 +691,8 @@ public:
Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});
// Create the linalg op that performs pooling.
Value poolInitTensor =
rewriter.create<linalg::InitTensorOp>(loc, accTy.getShape(), accETy);
Value poolInitTensor = rewriter.create<linalg::InitTensorOp>(
loc, dynamicDims, accTy.getShape(), accETy);
Value filledInitTensor =
rewriter.create<linalg::FillOp>(loc, initialValue, poolInitTensor)
@ -728,7 +715,7 @@ public:
auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
Value genericInitTensor = rewriter.create<linalg::InitTensorOp>(
loc, resultTy.getShape(), resultETy);
loc, dynamicDims, resultTy.getShape(), resultETy);
auto genericOp = rewriter.create<linalg::GenericOp>(
loc, ArrayRef<Type>({resultTy}), ValueRange{poolingOp},
@ -770,7 +757,7 @@ public:
auto kH2 = padFn(kH1, y1, pad[3]);
auto kHCmp = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, kH2, one);
auto kH3 = rewriter.create<SelectOp>(loc, kHCmp, one, kH2);
auto kH3 = rewriter.create<mlir::SelectOp>(loc, kHCmp, one, kH2);
// compute the horizontal component of coverage.
auto kW0 = rewriter.create<arith::ConstantIndexOp>(loc, kernel[1]);
@ -778,7 +765,7 @@ public:
auto kW2 = padFn(kW1, x1, pad[5]);
auto kWCmp = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, kW2, one);
auto kW3 = rewriter.create<SelectOp>(loc, kWCmp, one, kW2);
auto kW3 = rewriter.create<mlir::SelectOp>(loc, kWCmp, one, kW2);
// Compute the total number of elements and normalize.
Value count = rewriter.create<arith::MulIOp>(loc, kH3, kW3);


@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRTosa
Utils/ConversionUtils.cpp
Utils/QuantUtils.cpp
IR/TosaOps.cpp


@ -0,0 +1,30 @@
//===- ConversionUtils.cpp ------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Utility functions for TOSA lowering
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tosa/Utils/CoversionUtils.h"
using namespace mlir;
using namespace mlir::tosa;
SmallVector<StringRef>
mlir::tosa::getNParallelLoopsAttrs(unsigned nParallelLoops) {
return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
}
SmallVector<Value>
mlir::tosa::condenseValues(const SmallVector<Value> &values) {
SmallVector<Value> condensedValues;
for (auto value : values)
if (value)
condensedValues.push_back(value);
return condensedValues;
}
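
`condenseValues` is the shared replacement for the per-file `filterDynamicDims` helpers that the hunks above delete: callers build one slot per output dimension, leave the static slots null, and condense the vector before creating the init tensor. A small hypothetical excerpt in that style (the operand names `lhs`/`rhs` and the rank-3 shape are placeholders that mirror the matmul converter touched in this change):

```cpp
// Hypothetical excerpt: one (possibly null) entry per output dimension of a
// rank-3 matmul-style result, then condense away the null slots.
SmallVector<Value> dynDims;
dynDims.resize(resultTy.getRank());

if (resultTy.isDynamicDim(0)) // batch
  dynDims[0] = rewriter.create<tensor::DimOp>(loc, lhs, 0);
if (resultTy.isDynamicDim(1)) // rows
  dynDims[1] = rewriter.create<tensor::DimOp>(loc, lhs, 1);
if (resultTy.isDynamicDim(2)) // columns
  dynDims[2] = rewriter.create<tensor::DimOp>(loc, rhs, 2);

// Only the real SSA values are handed to linalg.init_tensor.
SmallVector<Value> filteredDims = condenseValues(dynDims);
Value initTensor = rewriter.create<linalg::InitTensorOp>(
    loc, filteredDims, resultTy.getShape(), resultTy.getElementType());
```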


@ -164,6 +164,19 @@ func @max_pool_padded(%arg0: tensor<1x6x34x62xf32>) -> () {
return
}
// CHECK-LABEL: @max_pool_dyn
func @max_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> () {
// CHECK: %[[C0:.+]] = arith.constant 0
// CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]]
// CHECK: %[[CONST:.+]] = arith.constant -3.40282347E+38
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[BATCH]], 4, 32, 62]
// CHECK: %[[FILL:.+]] = linalg.fill(%[[CONST]], %[[INIT]])
// CHECK: %[[KERNEL:.+]] = linalg.init_tensor [3, 3]
// CHECK: linalg.pooling_nhwc_max {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%arg0, %[[KERNEL]] : tensor<?x6x34x62xf32>, tensor<3x3xf32>) outs(%[[FILL]] : tensor<?x4x32x62xf32>)
%0 = "tosa.max_pool2d"(%arg0) {pad = [0, 0, 0, 0], kernel = [3, 3], stride = [1, 1]} : (tensor<?x6x34x62xf32>) -> (tensor<?x4x32x62xf32>)
return
}
// CHECK-LABEL: @max_pool_i8
func @max_pool_i8(%arg0: tensor<1x6x34x62xi8>) -> () {
// CHECK: arith.constant -128
@ -250,6 +263,24 @@ func @avg_pool(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>) {
// -----
// CHECK-LABEL: @avg_pool_dyn
func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>) {
// The calculations remain the same as above, only testing for dyn behavior
// CHECK: %[[C0:.+]] = arith.constant 0
// CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]]
// CHECK: %[[PAD:.+]] = linalg.pad_tensor %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
// CHECK: %[[POOLINIT:.+]] = linalg.init_tensor [%[[BATCH]], 5, 33, 62]
// CHECK: %[[FILL:.+]] = linalg.fill
// CHECK: %[[KERNEL:.+]] = linalg.init_tensor [4, 4]
// CHECK: %[[POOL:.+]] = linalg.pooling_nhwc_sum {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%[[PAD]], %[[KERNEL]] : tensor<?x8x36x62xf32>, tensor<4x4xf32>) outs(%[[FILL]] : tensor<?x5x33x62xf32>)
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[BATCH]], 5, 33, 62]
// CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[POOL]] : tensor<?x5x33x62xf32>) outs(%[[INIT]] : tensor<?x5x33x62xf32>)
%0 = "tosa.avg_pool2d"(%arg0) {pad = [1, 1, 1, 1], kernel = [4, 4], stride = [1, 1]} : (tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>)
return %0 : tensor<?x5x33x62xf32>
}
// -----
// CHECK-LABEL: @avg_pool_i8
func @avg_pool_i8(%arg0 : tensor<1x128x128x2xi8>) -> () {
@ -329,6 +360,29 @@ func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>
// -----
// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-LABEL: @conv2d_dyn
func @conv2d_dyn(%input: tensor<?x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
// CHECK: %[[C0:.+]] = arith.constant 0
// CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]]
// CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]>
// CHECK: %[[W:.+]] = "tosa.transpose"(%arg1, %[[PERM]])
// CHECK: %[[M_IN:.+]] = linalg.init_tensor [%[[BATCH]], 45, 40, 28]
// CHECK: %[[CST:.+]] = arith.constant 0
// CHECK: %[[FILL:.+]] = linalg.fill
// CHECK: %[[B_IN:.+]] = linalg.init_tensor [%[[BATCH]], 45, 40, 28]
// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[W]] : tensor<?x49x42x27xf32>, tensor<3x3x27x28xf32>) outs(%[[FILL]] : tensor<?x45x40x28xf32>)
// CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<?x45x40x28xf32>) outs(%[[B_IN]] : tensor<?x45x40x28xf32>)
// CHECK: %[[ADD:.+]] = arith.addf
// CHECK: linalg.yield %[[ADD]] : f32
%0 = "tosa.conv2d"(%input, %weights, %bias) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [2, 1]} : (tensor<?x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>) -> (tensor<?x45x40x28xf32>)
return
}
// -----
// CHECK-LABEL: @conv2d_padded_f32
func @conv2d_padded_f32(%input: tensor<1x47x40x28xf32>, %weights: tensor<28x3x3x28xf32>, %bias: tensor<28xf32>) -> () {
// CHECK: %[[C0:.+]] = arith.constant 0
@ -378,6 +432,30 @@ func @depthwise_conv(%arg0 : tensor<1x7x5x3xf32>, %arg1 : tensor<3x1x3x11xf32>,
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)>
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-LABEL: @depthwise_conv_dyn
func @depthwise_conv_dyn(%arg0 : tensor<?x7x5x3xf32>, %arg1 : tensor<3x1x3x11xf32>, %arg2 : tensor<33xf32>) -> () {
// CHECK: %[[C0:.+]] = arith.constant 0
// CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]]
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[BATCH]], 5, 5, 3, 11]
// CHECK: %[[CST0:.+]] = arith.constant 0
// CHECK: %[[FILL:.+]] = linalg.fill
// CHECK: %[[OUT:.+]] = linalg.init_tensor [%[[BATCH]], 5, 5, 33]
// CHECK: %[[DEPTH:.+]] = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x7x5x3xf32>, tensor<3x1x3x11xf32>) outs(%[[FILL]] : tensor<?x5x5x3x11xf32>)
// CHECK: %[[COLLAPSED:.+]] = "tosa.reshape"(%[[DEPTH]]) {new_shape = [-1, 5, 5, 33]}
// CHECK: %[[BIAS:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[COLLAPSED]] : tensor<33xf32>, tensor<?x5x5x33xf32>) outs(%[[OUT]] : tensor<?x5x5x33xf32>) {
// CHECK: ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
// CHECK: %[[ADD:.+]] = arith.addf %arg3, %arg4 : f32
// CHECK: linalg.yield %[[ADD]] : f32
// CHECK: } -> tensor<?x5x5x33xf32>
%2 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) { pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1] } : (tensor<?x7x5x3xf32>, tensor<3x1x3x11xf32>, tensor<33xf32>) -> (tensor<?x5x5x33xf32>)
return
}
// -----
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)>
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-LABEL: @depthwise_conv_strides
func @depthwise_conv_strides(%arg0 : tensor<1x11x9x3xf32>, %arg1 : tensor<3x1x3x11xf32>, %arg2 : tensor<33xf32>) -> () {
// CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 5, 5, 3, 11]


@ -897,6 +897,26 @@ func @rescale_i8(%arg0 : tensor<2xi8>) -> () {
// -----
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: @rescale_i8_dyn
func @rescale_i8_dyn(%arg0 : tensor<?x2xi8>) -> () {
// CHECK: %[[C0:.+]] = arith.constant 0
// CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]]
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[BATCH]], 2]
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<?x2xi8>) outs(%[[INIT]] : tensor<?x2xi8>)
%0 = "tosa.rescale"(%arg0) {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = [19689 : i32], shift = [15 : i32], scale32 = false, double_round = false, per_channel = false} : (tensor<?x2xi8>) -> (tensor<?x2xi8>)
// CHECK: %[[C0:.+]] = arith.constant 0
// CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]]
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[BATCH]], 2]
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<?x2xi8>) outs(%[[INIT]] : tensor<?x2xui8>)
%1 = "tosa.rescale"(%arg0) {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = [19689 : i32], shift = [15 : i32], scale32 = false, double_round = false, per_channel = false} : (tensor<?x2xi8>) -> (tensor<?x2xui8>)
return
}
// -----
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: @rescale_ui8
@ -1184,6 +1204,22 @@ func @gather_float(%arg0: tensor<2x3x2xf32>, %arg1: tensor<2x3xi32>) -> () {
return
}
// CHECK-LABEL: @gather_float_dyn
func @gather_float_dyn(%arg0: tensor<?x3x2xf32>, %arg1: tensor<?x3xi32>) -> () {
// CHECK: %[[C0:.+]] = arith.constant 0
// CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]]
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[BATCH]], 3, 2]
// CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg1 : tensor<?x3xi32>) outs(%[[INIT]] : tensor<?x3x2xf32>)
// CHECK: ^bb0(%[[ARG0:.+]]: i32, %[[ARG1:.+]]: f32)
// CHECK: %[[IDX0:.+]] = linalg.index 0
// CHECK: %[[CAST:.+]] = arith.index_cast %[[ARG0]]
// CHECK: %[[IDX2:.+]] = linalg.index 2
// CHECK: %[[EXTRACT:.+]] = tensor.extract %arg0[%[[IDX0]], %[[CAST]], %[[IDX2]]] : tensor<?x3x2xf32>
// CHECK: linalg.yield %[[EXTRACT]]
%0 = "tosa.gather"(%arg0, %arg1) : (tensor<?x3x2xf32>, tensor<?x3xi32>) -> (tensor<?x3x2xf32>)
return
}
// CHECK-LABEL: @gather_int
func @gather_int(%arg0: tensor<2x3x2xi32>, %arg1: tensor<2x3xi32>) -> () {
// CHECK: %[[INIT:.+]] = linalg.init_tensor [2, 3, 2]
@ -1548,3 +1584,15 @@ func @resize_bilinear_int(%input: tensor<1x2x2x1xi8>) -> () {
%output = "tosa.resize"(%input) { output_size = [4, 4], stride = [128, 128], offset = [1, 2], stride_fp = [0. : f32, 0. : f32], offset_fp = [0. : f32, 0. : f32], shift = 8 : i32, mode = "BILINEAR" } : (tensor<1x2x2x1xi8>) -> (tensor<1x4x4x1xi32>)
return
}
// -----
// CHECK-LABEL: @resize_dyn
func @resize_dyn(%input: tensor<?x2x2x1xi8>) -> () {
// CHECK: %[[C0:.+]] = arith.constant 0
// CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]]
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[BATCH]], 4, 4, 1]
// CHECK: %[[GENERIC:.+]] = linalg.generic
%output = "tosa.resize"(%input) { output_size = [4, 4], stride = [128, 128], offset = [1, 2], stride_fp = [0. : f32, 0. : f32], offset_fp = [0. : f32, 0. : f32], shift = 8 : i32, mode = "BILINEAR" } : (tensor<?x2x2x1xi8>) -> (tensor<?x4x4x1xi32>)
return
}


@ -7071,13 +7071,14 @@ cc_library(
includes = ["include"],
deps = [
":Analysis",
":ArithmeticDialect",
":Dialect",
":DialectUtils",
":IR",
":InferTypeOpInterface",
":LoopLikeInterface",
":Pass",
":QuantOps",
":SideEffectInterfaces",
":StandardOps",
":TensorDialect",
":TosaDialectIncGen",