[MLIR] Split arith dialect from the std dialect

Create the Arithmetic dialect that contains basic integer and floating
point arithmetic operations. Ops that did not meet this criterion were
moved to the Math dialect.

First of two atomic patches to remove integer and floating point
operations from the standard dialect. Ops will be removed from the
standard dialect in a subsequent patch.

Reviewed By: ftynse, silvas

Differential Revision: https://reviews.llvm.org/D110200
This commit is contained in:
Mogball 2021-09-21 20:54:07 +00:00 committed by Jeff Niu
parent a7ae227baf
commit 8c08f21b60
14 changed files with 2354 additions and 9 deletions

View File

@ -0,0 +1 @@
add_subdirectory(IR)

View File

@ -0,0 +1,56 @@
//===- Arithmetic.h - Arithmetic dialect --------------------------*- C++-*-==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_ARITHMETIC_IR_ARITHMETIC_H_
#define MLIR_DIALECT_ARITHMETIC_IR_ARITHMETIC_H_
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/VectorInterfaces.h"
//===----------------------------------------------------------------------===//
// ArithmeticDialect
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsDialect.h.inc"
//===----------------------------------------------------------------------===//
// Arithmetic Dialect Enum Attributes
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.h.inc"
//===----------------------------------------------------------------------===//
// Arithmetic Dialect Operations
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.h.inc"
//===----------------------------------------------------------------------===//
// Utility Functions
//===----------------------------------------------------------------------===//
namespace mlir {
namespace arith {
/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
/// comparison predicates.
bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs,
const APInt &rhs);
/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
/// comparison predicates.
bool applyCmpPredicate(arith::CmpFPredicate predicate, const APFloat &lhs,
const APFloat &rhs);
} // end namespace arith
} // end namespace mlir
#endif // MLIR_DIALECT_ARITHMETIC_IR_ARITHMETIC_H_

View File

@ -0,0 +1,68 @@
//===- ArithmeticBase.td - Base defs for arith dialect ------*- tablegen -*-==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef ARITHMETIC_BASE
#define ARITHMETIC_BASE
include "mlir/IR/OpBase.td"
def Arithmetic_Dialect : Dialect {
let name = "arith";
let cppNamespace = "::mlir::arith";
let description = [{
The arithmetic dialect is intended to hold basic integer and floating point
mathematical operations. This includes unary, binary, and ternary arithmetic
ops, bitwise and shift ops, cast ops, and compare ops. Operations in this
dialect also accept vectors and tensors of integers or floats.
}];
}
// The predicate indicates the type of the comparison to perform:
// (un)orderedness, (in)equality and less/greater than (or equal to) as
// well as predicates that are always true or false.
def Arith_CmpFPredicateAttr : I64EnumAttr<
"CmpFPredicate", "",
[
I64EnumAttrCase<"AlwaysFalse", 0, "false">,
I64EnumAttrCase<"OEQ", 1, "oeq">,
I64EnumAttrCase<"OGT", 2, "ogt">,
I64EnumAttrCase<"OGE", 3, "oge">,
I64EnumAttrCase<"OLT", 4, "olt">,
I64EnumAttrCase<"OLE", 5, "ole">,
I64EnumAttrCase<"ONE", 6, "one">,
I64EnumAttrCase<"ORD", 7, "ord">,
I64EnumAttrCase<"UEQ", 8, "ueq">,
I64EnumAttrCase<"UGT", 9, "ugt">,
I64EnumAttrCase<"UGE", 10, "uge">,
I64EnumAttrCase<"ULT", 11, "ult">,
I64EnumAttrCase<"ULE", 12, "ule">,
I64EnumAttrCase<"UNE", 13, "une">,
I64EnumAttrCase<"UNO", 14, "uno">,
I64EnumAttrCase<"AlwaysTrue", 15, "true">,
]> {
let cppNamespace = "::mlir::arith";
}
def Arith_CmpIPredicateAttr : I64EnumAttr<
"CmpIPredicate", "",
[
I64EnumAttrCase<"eq", 0>,
I64EnumAttrCase<"ne", 1>,
I64EnumAttrCase<"slt", 2>,
I64EnumAttrCase<"sle", 3>,
I64EnumAttrCase<"sgt", 4>,
I64EnumAttrCase<"sge", 5>,
I64EnumAttrCase<"ult", 6>,
I64EnumAttrCase<"ule", 7>,
I64EnumAttrCase<"ugt", 8>,
I64EnumAttrCase<"uge", 9>,
]> {
let cppNamespace = "::mlir::arith";
}
#endif // ARITHMETIC_BASE

View File

@ -0,0 +1,997 @@
//===- ArithmeticOps.td - Arithmetic op definitions --------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef ARITHMETIC_OPS
#define ARITHMETIC_OPS
include "mlir/Dialect/Arithmetic/IR/ArithmeticBase.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/VectorInterfaces.td"
// Base class for Arithmetic dialect ops. Ops in this dialect have no side
// effects and can be applied element-wise to vectors and tensors.
class Arith_Op<string mnemonic, list<OpTrait> traits = []> :
Op<Arithmetic_Dialect, mnemonic, traits # [NoSideEffect,
DeclareOpInterfaceMethods<VectorUnrollOpInterface>] #
ElementwiseMappable.traits>;
// Base class for integer and floating point arithmetic ops. All ops have one
// result, require operands and results to be of the same type, and can accept
// tensors or vectors of integers or floats.
class Arith_ArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
Arith_Op<mnemonic, traits # [SameOperandsAndResultType]>;
// Base class for unary arithmetic operations.
class Arith_UnaryOp<string mnemonic, list<OpTrait> traits = []> :
Arith_ArithmeticOp<mnemonic, traits> {
let assemblyFormat = "$operand attr-dict `:` type($result)";
}
// Base class for binary arithmetic operations.
class Arith_BinaryOp<string mnemonic, list<OpTrait> traits = []> :
Arith_ArithmeticOp<mnemonic, traits> {
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)";
}
// Base class for ternary arithmetic operations.
class Arith_TernaryOp<string mnemonic, list<OpTrait> traits = []> :
Arith_ArithmeticOp<mnemonic, traits> {
let assemblyFormat = "$a `,` $b `,` $c attr-dict `:` type($result)";
}
// Base class for integer binary operations.
class Arith_IntBinaryOp<string mnemonic, list<OpTrait> traits = []> :
Arith_BinaryOp<mnemonic, traits>,
Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs)>,
Results<(outs SignlessIntegerLike:$result)>;
// Base class for floating point unary operations.
class Arith_FloatUnaryOp<string mnemonic, list<OpTrait> traits = []> :
Arith_UnaryOp<mnemonic, traits>,
Arguments<(ins FloatLike:$operand)>,
Results<(outs FloatLike:$result)>;
// Base class for floating point binary operations.
class Arith_FloatBinaryOp<string mnemonic, list<OpTrait> traits = []> :
Arith_BinaryOp<mnemonic, traits>,
Arguments<(ins FloatLike:$lhs, FloatLike:$rhs)>,
Results<(outs FloatLike:$result)>;
// Base class for arithmetic cast operations. Requires a single operand and
// result. If either is a shaped type, then the other must be of the same shape.
class Arith_CastOp<string mnemonic, TypeConstraint From, TypeConstraint To,
list<OpTrait> traits = []> :
Arith_Op<mnemonic, traits # [SameOperandsAndResultShape,
DeclareOpInterfaceMethods<CastOpInterface>]>,
Arguments<(ins From:$in)>,
Results<(outs To:$out)> {
let builders = [
OpBuilder<(ins "Value":$source, "Type":$destType), [{
impl::buildCastOp($_builder, $_state, source, destType);
}]>
];
let assemblyFormat = "$in attr-dict `:` type($in) `to` type($out)";
}
// Casts do not accept indices. Type constraint for signless-integer-like types
// excluding indices: signless integers, vectors or tensors thereof.
def SignlessFixedWidthIntegerLike : TypeConstraint<Or<[
AnySignlessInteger.predicate,
VectorOf<[AnySignlessInteger]>.predicate,
TensorOf<[AnySignlessInteger]>.predicate]>,
"signless-fixed-width-integer-like">;
// Cast from an integer type to another integer type.
class Arith_IToICastOp<string mnemonic, list<OpTrait> traits = []> :
Arith_CastOp<mnemonic, SignlessFixedWidthIntegerLike,
SignlessFixedWidthIntegerLike>;
// Cast from an integer type to a floating point type.
class Arith_IToFCastOp<string mnemonic, list<OpTrait> traits = []> :
Arith_CastOp<mnemonic, SignlessFixedWidthIntegerLike, FloatLike>;
// Cast from a floating point type to an integer type.
class Arith_FToICastOp<string mnemonic, list<OpTrait> traits = []> :
Arith_CastOp<mnemonic, FloatLike, SignlessFixedWidthIntegerLike>;
// Cast from a floating point type to another floating point type.
class Arith_FToFCastOp<string mnemonic, list<OpTrait> traits = []> :
Arith_CastOp<mnemonic, FloatLike, FloatLike>;
// Base class for compare operations. Requires two operands of the same type
// and returns a single `BoolLike` result. If the operand type is a vector or
// tensor, then the result will be one of `i1` of the same shape.
class Arith_CompareOp<string mnemonic, list<OpTrait> traits = []> :
Arith_Op<mnemonic, traits # [SameTypeOperands, TypesMatchWith<
"result type has i1 element type and same shape as operands",
"lhs", "result", "::getI1SameShape($_self)">]> {
let results = (outs BoolLike:$result);
let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)";
}
//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//
def Arith_ConstantOp : Op<Arithmetic_Dialect, "constant",
[ConstantLike, NoSideEffect, TypesMatchWith<
"result type has same type as the attribute value",
"value", "result", "$_self">]> {
let summary = "integer or floating point constant";
let description = [{
The `const` operation produces an SSA value equal to some integer or
floating-point constant specified by an attribute. This is the way MLIR
forms simple integer and floating point constants.
Example:
```
// Integer constant
%1 = arith.constant 42 : i32
// Equivalent generic form
%1 = "arith.constant"() {value = 42 : i32} : () -> i32
```
}];
let arguments = (ins AnyAttr:$value);
let results = (outs SignlessIntegerOrFloatLike:$result);
let builders = [
OpBuilder<(ins "Attribute":$value),
[{ build($_builder, $_state, value.getType(), value); }]>,
OpBuilder<(ins "Attribute":$value, "Type":$type),
[{ build($_builder, $_state, type, value); }]>,
];
let hasFolder = 1;
let assemblyFormat = "attr-dict $value";
}
//===----------------------------------------------------------------------===//
// AddIOp
//===----------------------------------------------------------------------===//
def Arith_AddIOp : Arith_IntBinaryOp<"addi", [Commutative]> {
let summary = "integer addition operation";
let description = [{
The `addi` operation takes two operands and returns one result, each of
these is required to be the same type. This type may be an integer scalar
type, a vector whose element type is integer, or a tensor of integers. It
has no standard attributes.
Example:
```mlir
// Scalar addition.
%a = arith.addi %b, %c : i64
// SIMD vector element-wise addition, e.g. for Intel SSE.
%f = arith.addi %g, %h : vector<4xi32>
// Tensor element-wise addition.
%x = arith.addi %y, %z : tensor<4x?xi8>
```
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
// SubIOp
//===----------------------------------------------------------------------===//
def Arith_SubIOp : Arith_IntBinaryOp<"subi"> {
let summary = "integer subtraction operation";
let hasFolder = 1;
let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
// MulIOp
//===----------------------------------------------------------------------===//
def Arith_MulIOp : Arith_IntBinaryOp<"muli", [Commutative]> {
let summary = "integer multiplication operation";
let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
// DivUIOp
//===----------------------------------------------------------------------===//
def Arith_DivUIOp : Arith_IntBinaryOp<"divui"> {
let summary = "unsigned integer division operation";
let description = [{
Unsigned integer division. Rounds towards zero. Treats the leading bit as
the most significant, i.e. for `i16` given two's complement representation,
`6 / -2 = 6 / (2^16 - 2) = 0`.
Note: the semantics of division by zero is TBD; do NOT assume any specific
behavior.
Example:
```mlir
// Scalar unsigned integer division.
%a = arith.divui %b, %c : i64
// SIMD vector element-wise division.
%f = arith.divui %g, %h : vector<4xi32>
// Tensor element-wise integer division.
%x = arith.divui %y, %z : tensor<4x?xi8>
```
}];
let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
// DivSIOp
//===----------------------------------------------------------------------===//
def Arith_DivSIOp : Arith_IntBinaryOp<"divsi"> {
let summary = "signed integer division operation";
let description = [{
Signed integer division. Rounds towards zero. Treats the leading bit as
sign, i.e. `6 / -2 = -3`.
Note: the semantics of division by zero or signed division overflow (minimum
value divided by -1) is TBD; do NOT assume any specific behavior.
Example:
```mlir
// Scalar signed integer division.
%a = arith.divsi %b, %c : i64
// SIMD vector element-wise division.
%f = arith.divsi %g, %h : vector<4xi32>
// Tensor element-wise integer division.
%x = arith.divsi %y, %z : tensor<4x?xi8>
```
}];
let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
// CeilDivSIOp
//===----------------------------------------------------------------------===//
def Arith_CeilDivSIOp : Arith_IntBinaryOp<"ceildivsi"> {
let summary = "signed ceil integer division operation";
let description = [{
Signed integer division. Rounds towards positive infinity, i.e. `7 / -2 = -3`.
Note: the semantics of division by zero or signed division overflow (minimum
value divided by -1) is TBD; do NOT assume any specific behavior.
Example:
```mlir
// Scalar signed integer division.
%a = arith.ceildivsi %b, %c : i64
```
}];
let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
// FloorDivSIOp
//===----------------------------------------------------------------------===//
def Arith_FloorDivSIOp : Arith_IntBinaryOp<"floordivsi"> {
let summary = "signed floor integer division operation";
let description = [{
Signed integer division. Rounds towards negative infinity, i.e. `5 / -2 = -3`.
Note: the semantics of division by zero or signed division overflow (minimum
value divided by -1) is TBD; do NOT assume any specific behavior.
Example:
```mlir
// Scalar signed integer division.
%a = arith.floordivsi %b, %c : i64
```
}];
let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
// RemUIOp
//===----------------------------------------------------------------------===//
def Arith_RemUIOp : Arith_IntBinaryOp<"remui"> {
let summary = "unsigned integer division remainder operation";
let description = [{
Unsigned integer division remainder. Treats the leading bit as the most
significant, i.e. for `i16`, `6 % -2 = 6 % (2^16 - 2) = 6`.
Note: the semantics of division by zero is TBD; do NOT assume any specific
behavior.
Example:
```mlir
// Scalar unsigned integer division remainder.
%a = arith.remui %b, %c : i64
// SIMD vector element-wise division remainder.
%f = arith.remui %g, %h : vector<4xi32>
// Tensor element-wise integer division remainder.
%x = arith.remui %y, %z : tensor<4x?xi8>
```
}];
let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
// RemSIOp
//===----------------------------------------------------------------------===//
def Arith_RemSIOp : Arith_IntBinaryOp<"remsi"> {
let summary = "signed integer division remainder operation";
let description = [{
Signed integer division remainder. Treats the leading bit as sign, i.e. `6 %
-2 = 0`.
Note: the semantics of division by zero is TBD; do NOT assume any specific
behavior.
Example:
```mlir
// Scalar signed integer division remainder.
%a = remsi %b, %c : i64
// SIMD vector element-wise division remainder.
%f = remsi %g, %h : vector<4xi32>
// Tensor element-wise integer division remainder.
%x = remsi %y, %z : tensor<4x?xi8>
```
}];
let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
// AndIOp
//===----------------------------------------------------------------------===//
def Arith_AndIOp : Arith_IntBinaryOp<"andi", [Commutative]> {
let summary = "integer binary and";
let description = [{
The `andi` operation takes two operands and returns one result, each of
these is required to be the same type. This type may be an integer scalar
type, a vector whose element type is integer, or a tensor of integers. It
has no standard attributes.
Example:
```mlir
// Scalar integer bitwise and.
%a = arith.andi %b, %c : i64
// SIMD vector element-wise bitwise integer and.
%f = arith.andi %g, %h : vector<4xi32>
// Tensor element-wise bitwise integer and.
%x = arith.andi %y, %z : tensor<4x?xi8>
```
}];
let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
// OrIOp
//===----------------------------------------------------------------------===//
def Arith_OrIOp : Arith_IntBinaryOp<"ori", [Commutative]> {
let summary = "integer binary or";
let description = [{
The `ori` operation takes two operands and returns one result, each of these
is required to be the same type. This type may be an integer scalar type, a
vector whose element type is integer, or a tensor of integers. It has no
standard attributes.
Example:
```mlir
// Scalar integer bitwise or.
%a = arith.ori %b, %c : i64
// SIMD vector element-wise bitwise integer or.
%f = arith.ori %g, %h : vector<4xi32>
// Tensor element-wise bitwise integer or.
%x = arith.ori %y, %z : tensor<4x?xi8>
```
}];
let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
// XOrIOp
//===----------------------------------------------------------------------===//
def Arith_XOrIOp : Arith_IntBinaryOp<"xori", [Commutative]> {
let summary = "integer binary xor";
let description = [{
The `xori` operation takes two operands and returns one result, each of
these is required to be the same type. This type may be an integer scalar
type, a vector whose element type is integer, or a tensor of integers. It
has no standard attributes.
Example:
```mlir
// Scalar integer bitwise xor.
%a = arith.xori %b, %c : i64
// SIMD vector element-wise bitwise integer xor.
%f = arith.xori %g, %h : vector<4xi32>
// Tensor element-wise bitwise integer xor.
%x = arith.xori %y, %z : tensor<4x?xi8>
```
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
// ShLIOp
//===----------------------------------------------------------------------===//
def Arith_ShLIOp : Arith_IntBinaryOp<"shli"> {
let summary = "integer left-shift";
let description = [{
The `shli` operation shifts an integer value to the left by a variable
amount. The low order bits are filled with zeros.
Example:
```mlir
%1 = arith.constant 5 : i8 // %1 is 0b00000101
%2 = arith.constant 3 : i8
%3 = arith.shli %1, %2 : (i8, i8) -> i8 // %3 is 0b00101000
```
}];
}
//===----------------------------------------------------------------------===//
// ShRUIOp
//===----------------------------------------------------------------------===//
def Arith_ShRUIOp : Arith_IntBinaryOp<"shrui"> {
let summary = "unsigned integer right-shift";
let description = [{
The `shrui` operation shifts an integer value to the right by a variable
amount. The integer is interpreted as unsigned. The high order bits are
always filled with zeros.
Example:
```mlir
%1 = arith.constant 160 : i8 // %1 is 0b10100000
%2 = arith.constant 3 : i8
%3 = arith.shrui %1, %2 : (i8, i8) -> i8 // %3 is 0b00010100
```
}];
}
//===----------------------------------------------------------------------===//
// ShRSIOp
//===----------------------------------------------------------------------===//
def Arith_ShRSIOp : Arith_IntBinaryOp<"shrsi"> {
let summary = "signed integer right-shift";
let description = [{
The `shrsi` operation shifts an integer value to the right by a variable
amount. The integer is interpreted as signed. The high order bits in the
output are filled with copies of the most-significant bit of the shifted
value (which means that the sign of the value is preserved).
Example:
```mlir
%1 = arith.constant 160 : i8 // %1 is 0b10100000
%2 = arith.constant 3 : i8
%3 = arith.shrsi %1, %2 : (i8, i8) -> i8 // %3 is 0b11110100
%4 = arith.constant 96 : i8 // %4 is 0b01100000
%5 = arith.shrsi %4, %2 : (i8, i8) -> i8 // %5 is 0b00001100
```
}];
}
//===----------------------------------------------------------------------===//
// NegFOp
//===----------------------------------------------------------------------===//
def Arith_NegFOp : Arith_FloatUnaryOp<"negf"> {
let summary = "floating point negation";
let description = [{
The `negf` operation computes the negation of a given value. It takes one
operand and returns one result of the same type. This type may be a float
scalar type, a vector whose element type is float, or a tensor of floats.
It has no standard attributes.
Example:
```mlir
// Scalar negation value.
%a = arith.negf %b : f64
// SIMD vector element-wise negation value.
%f = arith.negf %g : vector<4xf32>
// Tensor element-wise negation value.
%x = arith.negf %y : tensor<4x?xf8>
```
}];
}
//===----------------------------------------------------------------------===//
// AddFOp
//===----------------------------------------------------------------------===//
def Arith_AddFOp : Arith_FloatBinaryOp<"addf"> {
let summary = "floating point addition operation";
let description = [{
The `addf` operation takes two operands and returns one result, each of
these is required to be the same type. This type may be a floating point
scalar type, a vector whose element type is a floating point type, or a
floating point tensor.
Example:
```mlir
// Scalar addition.
%a = arith.addf %b, %c : f64
// SIMD vector addition, e.g. for Intel SSE.
%f = arith.addf %g, %h : vector<4xf32>
// Tensor addition.
%x = arith.addf %y, %z : tensor<4x?xbf16>
```
TODO: In the distant future, this will accept optional attributes for fast
math, contraction, rounding mode, and other controls.
}];
let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
// SubFOp
//===----------------------------------------------------------------------===//
def Arith_SubFOp : Arith_FloatBinaryOp<"subf"> {
let summary = "floating point subtraction operation";
let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
// MulFOp
//===----------------------------------------------------------------------===//
def Arith_MulFOp : Arith_FloatBinaryOp<"mulf"> {
let summary = "floating point multiplication operation";
let description = [{
The `mulf` operation takes two operands and returns one result, each of
these is required to be the same type. This type may be a floating point
scalar type, a vector whose element type is a floating point type, or a
floating point tensor.
Example:
```mlir
// Scalar multiplication.
%a = arith.mulf %b, %c : f64
// SIMD pointwise vector multiplication, e.g. for Intel SSE.
%f = arith.mulf %g, %h : vector<4xf32>
// Tensor pointwise multiplication.
%x = arith.mulf %y, %z : tensor<4x?xbf16>
```
TODO: In the distant future, this will accept optional attributes for fast
math, contraction, rounding mode, and other controls.
}];
let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
// DivFOp
//===----------------------------------------------------------------------===//
def Arith_DivFOp : Arith_FloatBinaryOp<"divf"> {
let summary = "floating point division operation";
let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
// RemFOp
//===----------------------------------------------------------------------===//
def Arith_RemFOp : Arith_FloatBinaryOp<"remf"> {
let summary = "floating point division remainder operation";
}
//===----------------------------------------------------------------------===//
// ExtUIOp
//===----------------------------------------------------------------------===//
def Arith_ExtUIOp : Arith_IToICastOp<"extui"> {
let summary = "integer zero extension operation";
let description = [{
The integer zero extension operation takes an integer input of
width M and an integer destination type of width N. The destination
bit-width must be larger than the input bit-width (N > M).
The top-most (N - M) bits of the output are filled with zeros.
Example:
```mlir
%1 = arith.constant 5 : i3 // %1 is 0b101
%2 = arith.extui %1 : i3 to i6 // %2 is 0b000101
%3 = arith.constant 2 : i3 // %3 is 0b010
%4 = arith.extui %3 : i3 to i6 // %4 is 0b000010
%5 = arith.extui %0 : vector<2 x i32> to vector<2 x i64>
```
}];
let hasFolder = 1;
let verifier = [{ return verifyExtOp<IntegerType>(*this); }];
}
//===----------------------------------------------------------------------===//
// ExtSIOp
//===----------------------------------------------------------------------===//
def Arith_ExtSIOp : Arith_IToICastOp<"extsi"> {
let summary = "integer sign extension operation";
let description = [{
The integer sign extension operation takes an integer input of
width M and an integer destination type of width N. The destination
bit-width must be larger than the input bit-width (N > M).
The top-most (N - M) bits of the output are filled with copies
of the most-significant bit of the input.
Example:
```mlir
%1 = arith.constant 5 : i3 // %1 is 0b101
%2 = arith.extsi %1 : i3 to i6 // %2 is 0b111101
%3 = arith.constant 2 : i3 // %3 is 0b010
%4 = arith.extsi %3 : i3 to i6 // %4 is 0b000010
%5 = arith.extsi %0 : vector<2 x i32> to vector<2 x i64>
```
}];
let hasFolder = 1;
let verifier = [{ return verifyExtOp<IntegerType>(*this); }];
}
//===----------------------------------------------------------------------===//
// ExtFOp
//===----------------------------------------------------------------------===//
def Arith_ExtFOp : Arith_FToFCastOp<"extf"> {
let summary = "cast from floating-point to wider floating-point";
let description = [{
Cast a floating-point value to a larger floating-point-typed value.
The destination type must to be strictly wider than the source type.
When operating on vectors, casts elementwise.
}];
let verifier = [{ return verifyExtOp<FloatType>(*this); }];
}
//===----------------------------------------------------------------------===//
// TruncIOp
//===----------------------------------------------------------------------===//
def Arith_TruncIOp : Arith_IToICastOp<"trunci"> {
let summary = "integer truncation operation";
let description = [{
The integer truncation operation takes an integer input of
width M and an integer destination type of width N. The destination
bit-width must be smaller than the input bit-width (N < M).
The top-most (N - M) bits of the input are discarded.
Example:
```mlir
%1 = arith.constant 21 : i5 // %1 is 0b10101
%2 = trunci %1 : i5 to i4 // %2 is 0b0101
%3 = trunci %1 : i5 to i3 // %3 is 0b101
%5 = trunci %0 : vector<2 x i32> to vector<2 x i16>
```
}];
let hasFolder = 1;
let verifier = [{ return verifyTruncateOp<IntegerType>(*this); }];
}
//===----------------------------------------------------------------------===//
// TruncFOp
//===----------------------------------------------------------------------===//
def Arith_TruncFOp : Arith_FToFCastOp<"truncf"> {
let summary = "cast from floating-point to narrower floating-point";
let description = [{
Truncate a floating-point value to a smaller floating-point-typed value.
The destination type must be strictly narrower than the source type.
If the value cannot be exactly represented, it is rounded using the default
rounding mode. When operating on vectors, casts elementwise.
}];
let hasFolder = 1;
let verifier = [{ return verifyTruncateOp<FloatType>(*this); }];
}
//===----------------------------------------------------------------------===//
// UIToFPOp
//===----------------------------------------------------------------------===//
def Arith_UIToFPOp : Arith_IToFCastOp<"uitofp"> {
let summary = "cast from unsigned integer type to floating-point";
let description = [{
Cast from a value interpreted as unsigned integer to the corresponding
floating-point value. If the value cannot be exactly represented, it is
rounded using the default rounding mode. When operating on vectors, casts
elementwise.
}];
}
//===----------------------------------------------------------------------===//
// SIToFPOp
//===----------------------------------------------------------------------===//
def Arith_SIToFPOp : Arith_IToFCastOp<"sitofp"> {
let summary = "cast from integer type to floating-point";
let description = [{
Cast from a value interpreted as a signed integer to the corresponding
floating-point value. If the value cannot be exactly represented, it is
rounded using the default rounding mode. When operating on vectors, casts
elementwise.
}];
}
//===----------------------------------------------------------------------===//
// FPToUIOp
//===----------------------------------------------------------------------===//
def Arith_FPToUIOp : Arith_FToICastOp<"fptoui"> {
let summary = "cast from floating-point type to integer type";
let description = [{
Cast from a value interpreted as floating-point to the nearest (rounding
towards zero) unsigned integer value. When operating on vectors, casts
elementwise.
}];
}
//===----------------------------------------------------------------------===//
// FPToSIOp
//===----------------------------------------------------------------------===//
def Arith_FPToSIOp : Arith_FToICastOp<"fptosi"> {
let summary = "cast from floating-point type to integer type";
let description = [{
Cast from a value interpreted as floating-point to the nearest (rounding
towards zero) signed integer value. When operating on vectors, casts
elementwise.
}];
}
//===----------------------------------------------------------------------===//
// IndexCastOp
//===----------------------------------------------------------------------===//
def Arith_IndexCastOp : Arith_IToICastOp<"index_cast"> {
let summary = "cast between index and integer types";
let description = [{
Casts between scalar or vector integers and corresponding 'index' scalar or
vectors. Index is an integer of platform-specific bit width. If casting to
a wider integer, the value is sign-extended. If casting to a narrower
integer, the value is truncated.
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
// BitcastOp
//===----------------------------------------------------------------------===//
def Arith_BitcastOp : Arith_CastOp<"bitcast", SignlessIntegerOrFloatLike,
SignlessIntegerOrFloatLike> {
let summary = "bitcast between values of equal bit width";
let description = [{
Bitcast an integer or floating point value to an integer or floating point
value of equal bit width. When operating on vectors, casts elementwise.
Note that this implements a logical bitcast independent of target
endianness. This allows constant folding without target information and is
consitent with the bitcast constant folders in LLVM (see
https://github.com/llvm/llvm-project/blob/18c19414eb/llvm/lib/IR/ConstantFold.cpp#L168)
For targets where the source and target type have the same endianness (which
is the standard), this cast will also change no bits at runtime, but it may
still require an operation, for example if the machine has different
floating point and integer register files. For targets that have a different
endianness for the source and target types (e.g. float is big-endian and
integer is little-endian) a proper lowering would add operations to swap the
order of words in addition to the bitcast.
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
// CmpIOp
//===----------------------------------------------------------------------===//
def Arith_CmpIOp : Arith_CompareOp<"cmpi"> {
let summary = "integer comparison operation";
let description = [{
The `cmpi` operation is a generic comparison for integer-like types. Its two
arguments can be integers, vectors or tensors thereof as long as their types
match. The operation produces an i1 for the former case, a vector or a
tensor of i1 with the same shape as inputs in the other cases.
Its first argument is an attribute that defines which type of comparison is
performed. The following comparisons are supported:
- equal (mnemonic: `"eq"`; integer value: `0`)
- not equal (mnemonic: `"ne"`; integer value: `1`)
- signed less than (mnemonic: `"slt"`; integer value: `2`)
- signed less than or equal (mnemonic: `"sle"`; integer value: `3`)
- signed greater than (mnemonic: `"sgt"`; integer value: `4`)
- signed greater than or equal (mnemonic: `"sge"`; integer value: `5`)
- unsigned less than (mnemonic: `"ult"`; integer value: `6`)
- unsigned less than or equal (mnemonic: `"ule"`; integer value: `7`)
- unsigned greater than (mnemonic: `"ugt"`; integer value: `8`)
- unsigned greater than or equal (mnemonic: `"uge"`; integer value: `9`)
The result is `1` if the comparison is true and `0` otherwise. For vector or
tensor operands, the comparison is performed elementwise and the element of
the result indicates whether the comparison is true for the operand elements
with the same indices as those of the result.
Note: while the custom assembly form uses strings, the actual underlying
attribute has integer type (or rather enum class in C++ code) as seen from
the generic assembly form. String literals are used to improve readability
of the IR by humans.
This operation only applies to integer-like operands, but not floats. The
main reason being that comparison operations have diverging sets of
attributes: integers require sign specification while floats require various
floating point-related particularities, e.g., `-ffast-math` behavior,
IEEE754 compliance, etc
([rationale](../Rationale/Rationale.md#splitting-floating-point-vs-integer-operations)).
The type of comparison is specified as attribute to avoid introducing ten
similar operations, taking into account that they are often implemented
using the same operation downstream
([rationale](../Rationale/Rationale.md#specifying-comparison-kind-as-attribute)). The
separation between signed and unsigned order comparisons is necessary
because of integers being signless. The comparison operation must know how
to interpret values with the foremost bit being set: negatives in two's
complement or large positives
([rationale](../Rationale/Rationale.md#specifying-sign-in-integer-comparison-operations)).
Example:
```mlir
// Custom form of scalar "signed less than" comparison.
%x = arith.cmpi "slt", %lhs, %rhs : i32
// Generic form of the same operation.
%x = "arith.cmpi"(%lhs, %rhs) {predicate = 2 : i64} : (i32, i32) -> i1
// Custom form of vector equality comparison.
%x = arith.cmpi "eq", %lhs, %rhs : vector<4xi64>
// Generic form of the same operation.
%x = "std.arith.cmpi"(%lhs, %rhs) {predicate = 0 : i64}
: (vector<4xi64>, vector<4xi64>) -> vector<4xi1>
```
}];
let arguments = (ins Arith_CmpIPredicateAttr:$predicate,
SignlessIntegerLike:$lhs,
SignlessIntegerLike:$rhs);
let builders = [
OpBuilder<(ins "CmpIPredicate":$predicate, "Value":$lhs, "Value":$rhs), [{
build($_builder, $_state, ::getI1SameShape(lhs.getType()),
predicate, lhs, rhs);
}]>
];
let extraClassDeclaration = [{
static StringRef getPredicateAttrName() { return "predicate"; }
static CmpIPredicate getPredicateByName(StringRef name);
CmpIPredicate getPredicate() {
return (CmpIPredicate) (*this)->getAttrOfType<IntegerAttr>(
getPredicateAttrName()).getInt();
}
}];
let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
// CmpFOp
//===----------------------------------------------------------------------===//
def Arith_CmpFOp : Arith_CompareOp<"cmpf"> {
let summary = "floating-point comparison operation";
let description = [{
The `cmpf` operation compares its two operands according to the float
comparison rules and the predicate specified by the respective attribute.
The predicate defines the type of comparison: (un)orderedness, (in)equality
and signed less/greater than (or equal to) as well as predicates that are
always true or false. The operands must have the same type, and this type
must be a float type, or a vector or tensor thereof. The result is an i1,
or a vector/tensor thereof having the same shape as the inputs. Unlike cmpi,
the operands are always treated as signed. The u prefix indicates
*unordered* comparison, not unsigned comparison, so "une" means unordered or
not equal. For the sake of readability by humans, custom assembly form for
the operation uses a string-typed attribute for the predicate. The value of
this attribute corresponds to lower-cased name of the predicate constant,
e.g., "one" means "ordered not equal". The string representation of the
attribute is merely a syntactic sugar and is converted to an integer
attribute by the parser.
Example:
```mlir
%r1 = arith.cmpf "oeq" %0, %1 : f32
%r2 = arith.cmpf "ult" %0, %1 : tensor<42x42xf64>
%r3 = "arith.cmpf"(%0, %1) {predicate: 0} : (f8, f8) -> i1
```
}];
let arguments = (ins Arith_CmpFPredicateAttr:$predicate,
FloatLike:$lhs,
FloatLike:$rhs);
let builders = [
OpBuilder<(ins "CmpFPredicate":$predicate, "Value":$lhs, "Value":$rhs), [{
build($_builder, $_state, ::getI1SameShape(lhs.getType()),
predicate, lhs, rhs);
}]>
];
let extraClassDeclaration = [{
static StringRef getPredicateAttrName() { return "predicate"; }
static CmpFPredicate getPredicateByName(StringRef name);
CmpFPredicate getPredicate() {
return (CmpFPredicate) (*this)->getAttrOfType<IntegerAttr>(
getPredicateAttrName()).getInt();
}
}];
let hasFolder = 1;
}
#endif // ARITHMETIC_OPS

View File

@ -0,0 +1,5 @@
set(LLVM_TARGET_DEFINITIONS ArithmeticOps.td)
mlir_tablegen(ArithmeticOpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(ArithmeticOpsEnums.cpp.inc -gen-enum-defs)
add_mlir_dialect(ArithmeticOps arith)
add_mlir_doc(ArithmeticOps ArithmeticOps Dialects/ -gen-dialect-doc)

View File

@ -1,4 +1,5 @@
add_subdirectory(Affine)
add_subdirectory(Arithmetic)
add_subdirectory(Async)
add_subdirectory(ArmNeon)
add_subdirectory(ArmSVE)

View File

@ -13,30 +13,72 @@ include "mlir/Dialect/Math/IR/MathBase.td"
include "mlir/Interfaces/VectorInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
class Math_Op<string mnemonic, list<OpTrait> traits = []>
: Op<Math_Dialect, mnemonic, traits # [NoSideEffect]>;
// Base class for math dialect ops. Ops in this dialect have no side effects and
// can be applied element-wise to vectors and tensors.
class Math_Op<string mnemonic, list<OpTrait> traits = []> :
Op<Math_Dialect, mnemonic, traits # [NoSideEffect,
DeclareOpInterfaceMethods<VectorUnrollOpInterface>] #
ElementwiseMappable.traits>;
// Base class for unary math operations on floating point types. Require a
// operand and result of the same type. This type can be a floating point type,
// or vector or tensor thereof.
class Math_FloatUnaryOp<string mnemonic, list<OpTrait> traits = []> :
Math_Op<mnemonic, traits #
[DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
SameOperandsAndResultType] # ElementwiseMappable.traits> {
Math_Op<mnemonic, traits # [SameOperandsAndResultType]> {
let arguments = (ins FloatLike:$operand);
let results = (outs FloatLike:$result);
let assemblyFormat = "$operand attr-dict `:` type($result)";
}
// Base class for binary math operations on floating point types. Require two
// operands and one result of the same type. This type can be a floating point
// type, or a vector or tensor thereof.
class Math_FloatBinaryOp<string mnemonic, list<OpTrait> traits = []> :
Math_Op<mnemonic, traits # [
DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
SameOperandsAndResultType] # ElementwiseMappable.traits> {
Math_Op<mnemonic, traits # [SameOperandsAndResultType]> {
let arguments = (ins FloatLike:$lhs, FloatLike:$rhs);
let results = (outs FloatLike:$result);
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)";
}
// Base class for floating point ternary operations. Require three operands and
// one result of the same type. This type can be a floating point type, or a
// vector or tensor thereof.
class Math_FloatTernaryOp<string mnemonic, list<OpTrait> traits = []> :
Math_Op<mnemonic, traits # [SameOperandsAndResultType]> {
let arguments = (ins FloatLike:$a, FloatLike:$b, FloatLike:$c);
let results = (outs FloatLike:$result);
let assemblyFormat = "$a `,` $b `,` $c attr-dict `:` type($result)";
}
//===----------------------------------------------------------------------===//
// AbsOp
//===----------------------------------------------------------------------===//
def Math_AbsOp : Math_FloatUnaryOp<"abs"> {
let summary = "floating point absolute-value operation";
let description = [{
The `abs` operation computes the absolute value. It takes one operand and
returns one result of the same type. This type may be a float scalar type,
a vector whose element type is float, or a tensor of floats.
Example:
```mlir
// Scalar absolute value.
%a = math.abs %b : f64
// SIMD vector element-wise absolute value.
%f = math.abs %g : vector<4xf32>
// Tensor element-wise absolute value.
%x = math.abs %y : tensor<4x?xf8>
```
}];
}
//===----------------------------------------------------------------------===//
// AtanOp
//===----------------------------------------------------------------------===//
@ -110,6 +152,73 @@ def Math_Atan2Op : Math_FloatBinaryOp<"atan2">{
}];
}
//===----------------------------------------------------------------------===//
// CeilOp
//===----------------------------------------------------------------------===//
def Math_CeilOp : Math_FloatUnaryOp<"ceil"> {
let summary = "ceiling of the specified value";
let description = [{
Syntax:
```
operation ::= ssa-id `=` `math.ceil` ssa-use `:` type
```
The `ceil` operation computes the ceiling of a given value. It takes one
operand and returns one result of the same type. This type may be a float
scalar type, a vector whose element type is float, or a tensor of floats.
It has no standard attributes.
Example:
```mlir
// Scalar ceiling value.
%a = math.ceil %b : f64
// SIMD vector element-wise ceiling value.
%f = math.ceil %g : vector<4xf32>
// Tensor element-wise ceiling value.
%x = math.ceil %y : tensor<4x?xf8>
```
}];
}
//===----------------------------------------------------------------------===//
// CopySignOp
//===----------------------------------------------------------------------===//
def Math_CopySignOp : Math_FloatBinaryOp<"copysign"> {
let summary = "A copysign operation";
let description = [{
Syntax:
```
operation ::= ssa-id `=` `math.copysign` ssa-use `,` ssa-use `:` type
```
The `copysign` returns a value with the magnitude of the first operand and
the sign of the second operand. It takes two operands and returns one
result of the same type. This type may be a float scalar type, a vector
whose element type is float, or a tensor of floats. It has no standard
attributes.
Example:
```mlir
// Scalar copysign value.
%a = math.copysign %b, %c : f64
// SIMD vector element-wise copysign value.
%f = math.copysign %g, %h : vector<4xf32>
// Tensor element-wise copysign value.
%x = math.copysign %y, %z : tensor<4x?xf8>
```
}];
}
//===----------------------------------------------------------------------===//
// CosOp
//===----------------------------------------------------------------------===//
@ -276,6 +385,77 @@ def Math_ExpM1Op : Math_FloatUnaryOp<"expm1"> {
}];
}
//===----------------------------------------------------------------------===//
// FloorOp
//===----------------------------------------------------------------------===//
def Math_FloorOp : Math_FloatUnaryOp<"floor"> {
let summary = "floor of the specified value";
let description = [{
Syntax:
```
operation ::= ssa-id `=` `math.floor` ssa-use `:` type
```
The `floor` operation computes the floor of a given value. It takes one
operand and returns one result of the same type. This type may be a float
scalar type, a vector whose element type is float, or a tensor of floats.
It has no standard attributes.
Example:
```mlir
// Scalar floor value.
%a = math.floor %b : f64
// SIMD vector element-wise floor value.
%f = math.floor %g : vector<4xf32>
// Tensor element-wise floor value.
%x = math.floor %y : tensor<4x?xf8>
```
}];
}
//===----------------------------------------------------------------------===//
// FmaOp
//===----------------------------------------------------------------------===//
def Math_FmaOp : Math_FloatTernaryOp<"fma"> {
let summary = "floating point fused multipy-add operation";
let description = [{
Syntax:
```
operation ::= ssa-id `=` `math.fma` ssa-use `,` ssa-use `,` ssa-use `:` type
```
The `fma` operation takes three operands and returns one result, each of
these is required to be the same type. This type may be a floating point
scalar type, a vector whose element type is a floating point type, or a
floating point tensor.
Example:
```mlir
// Scalar fused multiply-add: d = a*b + c
%d = math.fma %a, %b, %c : f64
// SIMD vector fused multiply-add, e.g. for Intel SSE.
%i = math.fma %f, %g, %h : vector<4xf32>
// Tensor fused multiply-add.
%w = math.fma %x, %y, %z : tensor<4x?xbf16>
```
The semantics of the operation correspond to those of the `llvm.fma`
[intrinsic](https://llvm.org/docs/LangRef.html#llvm-fma-intrinsic). In the
particular case of lowering to LLVM, this is guaranteed to lower
to the `llvm.fma.*` intrinsic.
}];
}
//===----------------------------------------------------------------------===//
// LogOp
//===----------------------------------------------------------------------===//

View File

@ -0,0 +1 @@
add_subdirectory(IR)

View File

@ -0,0 +1,131 @@
//===- ArithmeticPatterns.td - Arithmetic dialect patterns -*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef ARITHMETIC_PATTERNS
#define ARITHMETIC_PATTERNS
include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.td"
// Add two integer attributes and create a new one with the result.
def AddIntAttrs : NativeCodeCall<"addIntegerAttrs($_builder, $0, $1, $2)">;
// Subtract two integer attributes and createa a new one with the result.
def SubIntAttrs : NativeCodeCall<"subIntegerAttrs($_builder, $0, $1, $2)">;
//===----------------------------------------------------------------------===//
// AddIOp
//===----------------------------------------------------------------------===//
// addi is commutative and will be canonicalized to have its constants appear
// as the second operand.
// addi(addi(x, c0), c1) -> addi(x, c0 + c1)
def AddIAddConstant :
Pat<(Arith_AddIOp:$res
(Arith_AddIOp $x, (Arith_ConstantOp APIntAttr:$c0)),
(Arith_ConstantOp APIntAttr:$c1)),
(Arith_AddIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)))>;
// addi(subi(x, c0), c1) -> addi(x, c1 - c0)
def AddISubConstantRHS :
Pat<(Arith_AddIOp:$res
(Arith_SubIOp $x, (Arith_ConstantOp APIntAttr:$c0)),
(Arith_ConstantOp APIntAttr:$c1)),
(Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)))>;
// addi(subi(c0, x), c1) -> subi(c0 + c1, x)
def AddISubConstantLHS :
Pat<(Arith_AddIOp:$res
(Arith_SubIOp (Arith_ConstantOp APIntAttr:$c0), $x),
(Arith_ConstantOp APIntAttr:$c1)),
(Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x)>;
//===----------------------------------------------------------------------===//
// SubIOp
//===----------------------------------------------------------------------===//
// subi(addi(x, c0), c1) -> addi(x, c0 - c1)
def SubIRHSAddConstant :
Pat<(Arith_SubIOp:$res
(Arith_AddIOp $x, (Arith_ConstantOp APIntAttr:$c0)),
(Arith_ConstantOp APIntAttr:$c1)),
(Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)))>;
// subi(c1, addi(x, c0)) -> subi(c1 - c0, x)
def SubILHSAddConstant :
Pat<(Arith_SubIOp:$res
(Arith_ConstantOp APIntAttr:$c1),
(Arith_AddIOp $x, (Arith_ConstantOp APIntAttr:$c0))),
(Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)), $x)>;
// subi(subi(x, c0), c1) -> subi(x, c0 + c1)
def SubIRHSSubConstantRHS :
Pat<(Arith_SubIOp:$res
(Arith_SubIOp $x, (Arith_ConstantOp APIntAttr:$c0)),
(Arith_ConstantOp APIntAttr:$c1)),
(Arith_SubIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)))>;
// subi(subi(c0, x), c1) -> subi(c0 - c1, x)
def SubIRHSSubConstantLHS :
Pat<(Arith_SubIOp:$res
(Arith_SubIOp (Arith_ConstantOp APIntAttr:$c0), $x),
(Arith_ConstantOp APIntAttr:$c1)),
(Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)), $x)>;
// subi(c1, subi(x, c0)) -> subi(c0 + c1, x)
def SubILHSSubConstantRHS :
Pat<(Arith_SubIOp:$res
(Arith_ConstantOp APIntAttr:$c1),
(Arith_SubIOp $x, (Arith_ConstantOp APIntAttr:$c0))),
(Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x)>;
// subi(c1, subi(c0, x)) -> addi(x, c1 - c0)
def SubILHSSubConstantLHS :
Pat<(Arith_SubIOp:$res
(Arith_ConstantOp APIntAttr:$c1),
(Arith_SubIOp (Arith_ConstantOp APIntAttr:$c0), $x)),
(Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)))>;
//===----------------------------------------------------------------------===//
// XOrIOp
//===----------------------------------------------------------------------===//
// xori is commutative and will be canonicalized to have its constants appear
// as the second operand.
// not(cmpi(pred, a, b)) -> cmpi(~pred, a, b), where not(x) is xori(x, 1)
def InvertPredicate : NativeCodeCall<"invertPredicate($0)">;
def XOrINotCmpI :
Pat<(Arith_XOrIOp
(Arith_CmpIOp $pred, $a, $b),
(Arith_ConstantOp ConstantAttr<I1Attr, "1">)),
(Arith_CmpIOp (InvertPredicate $pred), $a, $b)>;
//===----------------------------------------------------------------------===//
// IndexCastOp
//===----------------------------------------------------------------------===//
// index_cast(index_cast(x)) -> x, if dstType == srcType.
def IndexCastOfIndexCast :
Pat<(Arith_IndexCastOp:$res (Arith_IndexCastOp $x)),
(replaceWithValue $x),
[(Constraint<CPred<"$0.getType() == $1.getType()">> $res, $x)]>;
// index_cast(extsi(x)) -> index_cast(x)
def IndexCastOfExtSI :
Pat<(Arith_IndexCastOp (Arith_ExtSIOp $x)), (Arith_IndexCastOp $x)>;
//===----------------------------------------------------------------------===//
// BitcastOp
//===----------------------------------------------------------------------===//
// bitcast(bitcast(x)) -> x
def BitcastOfBitcast :
Pat<(Arith_BitcastOp (Arith_BitcastOp $x)), (replaceWithValue $x)>;
#endif // ARITHMETIC_PATTERNS

View File

@ -0,0 +1,37 @@
//===- ArithmeticDialect.cpp - MLIR Arithmetic dialect implementation -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Transforms/InliningUtils.h"
using namespace mlir;
using namespace mlir::arith;
#include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsDialect.cpp.inc"
namespace {
/// This class defines the interface for handling inlining for arithmetic
/// dialect operations.
struct ArithmeticInlinerInterface : public DialectInlinerInterface {
using DialectInlinerInterface::DialectInlinerInterface;
/// All arithmetic dialect ops can be inlined.
bool isLegalToInline(Operation *, Region *, bool,
BlockAndValueMapping &) const final {
return true;
}
};
} // end anonymous namespace
void mlir::arith::ArithmeticDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc"
>();
addInterfaces<ArithmeticInlinerInterface>();
}

View File

@ -0,0 +1,737 @@
//===- ArithmeticOps.cpp - MLIR Arithmetic dialect ops implementation -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
using namespace mlir;
using namespace mlir::arith;
//===----------------------------------------------------------------------===//
// Pattern helpers
//===----------------------------------------------------------------------===//
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
Attribute lhs, Attribute rhs) {
return builder.getIntegerAttr(res.getType(),
lhs.cast<IntegerAttr>().getInt() +
rhs.cast<IntegerAttr>().getInt());
}
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
Attribute lhs, Attribute rhs) {
return builder.getIntegerAttr(res.getType(),
lhs.cast<IntegerAttr>().getInt() -
rhs.cast<IntegerAttr>().getInt());
}
/// Invert an integer comparison predicate.
static arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred) {
switch (pred) {
case arith::CmpIPredicate::eq:
return arith::CmpIPredicate::ne;
case arith::CmpIPredicate::ne:
return arith::CmpIPredicate::eq;
case arith::CmpIPredicate::slt:
return arith::CmpIPredicate::sge;
case arith::CmpIPredicate::sle:
return arith::CmpIPredicate::sgt;
case arith::CmpIPredicate::sgt:
return arith::CmpIPredicate::sle;
case arith::CmpIPredicate::sge:
return arith::CmpIPredicate::slt;
case arith::CmpIPredicate::ult:
return arith::CmpIPredicate::uge;
case arith::CmpIPredicate::ule:
return arith::CmpIPredicate::ugt;
case arith::CmpIPredicate::ugt:
return arith::CmpIPredicate::ule;
case arith::CmpIPredicate::uge:
return arith::CmpIPredicate::ult;
}
llvm_unreachable("unknown cmpi predicate kind");
}
static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
return arith::CmpIPredicateAttr::get(pred.getContext(),
invertPredicate(pred.getValue()));
}
//===----------------------------------------------------------------------===//
// TableGen'd canonicalization patterns
//===----------------------------------------------------------------------===//
namespace {
#include "ArithmeticCanonicalization.inc"
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// AddIOp
//===----------------------------------------------------------------------===//
OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) {
// addi(x, 0) -> x
if (matchPattern(rhs(), m_Zero()))
return lhs();
return constFoldBinaryOp<IntegerAttr>(operands,
[](APInt a, APInt b) { return a + b; });
}
void arith::AddIOp::getCanonicalizationPatterns(
OwningRewritePatternList &patterns, MLIRContext *context) {
patterns.insert<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>(
context);
}
//===----------------------------------------------------------------------===//
// SubIOp
//===----------------------------------------------------------------------===//
OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) {
// subi(x,x) -> 0
if (getOperand(0) == getOperand(1))
return Builder(getContext()).getZeroAttr(getType());
// subi(x,0) -> x
if (matchPattern(rhs(), m_Zero()))
return lhs();
return constFoldBinaryOp<IntegerAttr>(operands,
[](APInt a, APInt b) { return a - b; });
}
void arith::SubIOp::getCanonicalizationPatterns(
OwningRewritePatternList &patterns, MLIRContext *context) {
patterns.insert<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
SubILHSSubConstantLHS>(context);
}
//===----------------------------------------------------------------------===//
// MulIOp
//===----------------------------------------------------------------------===//
OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) {
// muli(x, 0) -> 0
if (matchPattern(rhs(), m_Zero()))
return rhs();
// muli(x, 1) -> x
if (matchPattern(rhs(), m_One()))
return getOperand(0);
// TODO: Handle the overflow case.
// default folder
return constFoldBinaryOp<IntegerAttr>(operands,
[](APInt a, APInt b) { return a * b; });
}
//===----------------------------------------------------------------------===//
// DivUIOp
//===----------------------------------------------------------------------===//
OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) {
// Don't fold if it would require a division by zero.
bool div0 = false;
auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
if (div0 || !b) {
div0 = true;
return a;
}
return a.udiv(b);
});
// Fold out division by one. Assumes all tensors of all ones are splats.
if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
if (rhs.getValue() == 1)
return lhs();
} else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
return lhs();
}
return div0 ? Attribute() : result;
}
//===----------------------------------------------------------------------===//
// DivSIOp
//===----------------------------------------------------------------------===//
OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) {
// Don't fold if it would overflow or if it requires a division by zero.
bool overflowOrDiv0 = false;
auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
if (overflowOrDiv0 || !b) {
overflowOrDiv0 = true;
return a;
}
return a.sdiv_ov(b, overflowOrDiv0);
});
// Fold out division by one. Assumes all tensors of all ones are splats.
if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
if (rhs.getValue() == 1)
return lhs();
} else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
return lhs();
}
return overflowOrDiv0 ? Attribute() : result;
}
//===----------------------------------------------------------------------===//
// Ceil and floor division folding helpers
//===----------------------------------------------------------------------===//
static APInt signedCeilNonnegInputs(APInt a, APInt b, bool &overflow) {
// Returns (a-1)/b + 1
APInt one(a.getBitWidth(), 1, true); // Signed value 1.
APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
return val.sadd_ov(one, overflow);
}
//===----------------------------------------------------------------------===//
// CeilDivSIOp
//===----------------------------------------------------------------------===//
OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) {
// Don't fold if it would overflow or if it requires a division by zero.
bool overflowOrDiv0 = false;
auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
if (overflowOrDiv0 || !b) {
overflowOrDiv0 = true;
return a;
}
unsigned bits = a.getBitWidth();
APInt zero = APInt::getZero(bits);
if (a.sgt(zero) && b.sgt(zero)) {
// Both positive, return ceil(a, b).
return signedCeilNonnegInputs(a, b, overflowOrDiv0);
}
if (a.slt(zero) && b.slt(zero)) {
// Both negative, return ceil(-a, -b).
APInt posA = zero.ssub_ov(a, overflowOrDiv0);
APInt posB = zero.ssub_ov(b, overflowOrDiv0);
return signedCeilNonnegInputs(posA, posB, overflowOrDiv0);
}
if (a.slt(zero) && b.sgt(zero)) {
// A is negative, b is positive, return - ( -a / b).
APInt posA = zero.ssub_ov(a, overflowOrDiv0);
APInt div = posA.sdiv_ov(b, overflowOrDiv0);
return zero.ssub_ov(div, overflowOrDiv0);
}
// A is positive (or zero), b is negative, return - (a / -b).
APInt posB = zero.ssub_ov(b, overflowOrDiv0);
APInt div = a.sdiv_ov(posB, overflowOrDiv0);
return zero.ssub_ov(div, overflowOrDiv0);
});
// Fold out floor division by one. Assumes all tensors of all ones are
// splats.
if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
if (rhs.getValue() == 1)
return lhs();
} else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
return lhs();
}
return overflowOrDiv0 ? Attribute() : result;
}
//===----------------------------------------------------------------------===//
// FloorDivSIOp
//===----------------------------------------------------------------------===//
OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) {
// Don't fold if it would overflow or if it requires a division by zero.
bool overflowOrDiv0 = false;
auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
if (overflowOrDiv0 || !b) {
overflowOrDiv0 = true;
return a;
}
unsigned bits = a.getBitWidth();
APInt zero = APInt::getZero(bits);
if (a.sge(zero) && b.sgt(zero)) {
// Both positive (or a is zero), return a / b.
return a.sdiv_ov(b, overflowOrDiv0);
}
if (a.sle(zero) && b.slt(zero)) {
// Both negative (or a is zero), return -a / -b.
APInt posA = zero.ssub_ov(a, overflowOrDiv0);
APInt posB = zero.ssub_ov(b, overflowOrDiv0);
return posA.sdiv_ov(posB, overflowOrDiv0);
}
if (a.slt(zero) && b.sgt(zero)) {
// A is negative, b is positive, return - ceil(-a, b).
APInt posA = zero.ssub_ov(a, overflowOrDiv0);
APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0);
return zero.ssub_ov(ceil, overflowOrDiv0);
}
// A is positive, b is negative, return - ceil(a, -b).
APInt posB = zero.ssub_ov(b, overflowOrDiv0);
APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0);
return zero.ssub_ov(ceil, overflowOrDiv0);
});
// Fold out floor division by one. Assumes all tensors of all ones are
// splats.
if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
if (rhs.getValue() == 1)
return lhs();
} else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
return lhs();
}
return overflowOrDiv0 ? Attribute() : result;
}
//===----------------------------------------------------------------------===//
// RemUIOp
//===----------------------------------------------------------------------===//
OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) {
auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
if (!rhs)
return {};
auto rhsValue = rhs.getValue();
// x % 1 = 0
if (rhsValue.isOneValue())
return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
// Don't fold if it requires division by zero.
if (rhsValue.isNullValue())
return {};
auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
if (!lhs)
return {};
return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue));
}
//===----------------------------------------------------------------------===//
// RemSIOp
//===----------------------------------------------------------------------===//
OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) {
auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
if (!rhs)
return {};
auto rhsValue = rhs.getValue();
// x % 1 = 0
if (rhsValue.isOneValue())
return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
// Don't fold if it requires division by zero.
if (rhsValue.isNullValue())
return {};
auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
if (!lhs)
return {};
return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue));
}
//===----------------------------------------------------------------------===//
// AndIOp
//===----------------------------------------------------------------------===//
OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) {
/// and(x, 0) -> 0
if (matchPattern(rhs(), m_Zero()))
return rhs();
/// and(x, allOnes) -> x
APInt intValue;
if (matchPattern(rhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes())
return lhs();
/// and(x, x) -> x
if (lhs() == rhs())
return rhs();
return constFoldBinaryOp<IntegerAttr>(operands,
[](APInt a, APInt b) { return a & b; });
}
//===----------------------------------------------------------------------===//
// OrIOp
//===----------------------------------------------------------------------===//
OpFoldResult arith::OrIOp::fold(ArrayRef<Attribute> operands) {
/// or(x, 0) -> x
if (matchPattern(rhs(), m_Zero()))
return lhs();
/// or(x, x) -> x
if (lhs() == rhs())
return rhs();
return constFoldBinaryOp<IntegerAttr>(operands,
[](APInt a, APInt b) { return a | b; });
}
//===----------------------------------------------------------------------===//
// XOrIOp
//===----------------------------------------------------------------------===//
OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) {
/// xor(x, 0) -> x
if (matchPattern(rhs(), m_Zero()))
return lhs();
/// xor(x, x) -> 0
if (lhs() == rhs())
return Builder(getContext()).getZeroAttr(getType());
return constFoldBinaryOp<IntegerAttr>(operands,
[](APInt a, APInt b) { return a ^ b; });
}
void arith::XOrIOp::getCanonicalizationPatterns(
OwningRewritePatternList &patterns, MLIRContext *context) {
patterns.insert<XOrINotCmpI>(context);
}
//===----------------------------------------------------------------------===//
// AddFOp
//===----------------------------------------------------------------------===//
OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) {
return constFoldBinaryOp<FloatAttr>(
operands, [](APFloat a, APFloat b) { return a + b; });
}
//===----------------------------------------------------------------------===//
// SubFOp
//===----------------------------------------------------------------------===//
OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) {
return constFoldBinaryOp<FloatAttr>(
operands, [](APFloat a, APFloat b) { return a - b; });
}
//===----------------------------------------------------------------------===//
// MulFOp
//===----------------------------------------------------------------------===//
OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) {
return constFoldBinaryOp<FloatAttr>(
operands, [](APFloat a, APFloat b) { return a * b; });
}
//===----------------------------------------------------------------------===//
// DivFOp
//===----------------------------------------------------------------------===//
OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) {
return constFoldBinaryOp<FloatAttr>(
operands, [](APFloat a, APFloat b) { return a / b; });
}
//===----------------------------------------------------------------------===//
// Verifiers for integer and floating point extension/truncation ops
//===----------------------------------------------------------------------===//
// Extend ops can only extend to a wider type.
template <typename ValType, typename Op>
static LogicalResult verifyExtOp(Op op) {
Type srcType = getElementTypeOrSelf(op.in().getType());
Type dstType = getElementTypeOrSelf(op.getType());
if (srcType.cast<ValType>().getWidth() >= dstType.cast<ValType>().getWidth())
return op.emitError("result type ")
<< dstType << " must be wider than operand type " << srcType;
return success();
}
// Truncate ops can only truncate to a shorter type.
template <typename ValType, typename Op>
static LogicalResult verifyTruncateOp(Op op) {
Type srcType = getElementTypeOrSelf(op.in().getType());
Type dstType = getElementTypeOrSelf(op.getType());
if (srcType.cast<ValType>().getWidth() <= dstType.cast<ValType>().getWidth())
return op.emitError("result type ")
<< dstType << " must be shorter than operand type " << srcType;
return success();
}
//===----------------------------------------------------------------------===//
// ExtUIOp
//===----------------------------------------------------------------------===//
OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) {
if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>())
return IntegerAttr::get(
getType(), lhs.getValue().zext(getType().getIntOrFloatBitWidth()));
return {};
}
//===----------------------------------------------------------------------===//
// ExtSIOp
//===----------------------------------------------------------------------===//
OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) {
if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>())
return IntegerAttr::get(
getType(), lhs.getValue().sext(getType().getIntOrFloatBitWidth()));
return {};
}
//===----------------------------------------------------------------------===//
// IndexCastOp
//===----------------------------------------------------------------------===//
bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
TypeRange outputs) {
assert(inputs.size() == 1 && outputs.size() == 1 &&
"index_cast op expects one result and one result");
// Shape equivalence is guaranteed by op traits.
auto srcType = getElementTypeOrSelf(inputs.front());
auto dstType = getElementTypeOrSelf(outputs.front());
return (srcType.isIndex() && dstType.isSignlessInteger()) ||
(srcType.isSignlessInteger() && dstType.isIndex());
}
OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) {
// index_cast(constant) -> constant
// A little hack because we go through int. Otherwise, the size of the
// constant might need to change.
if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>())
return IntegerAttr::get(getType(), value.getInt());
return {};
}
void arith::IndexCastOp::getCanonicalizationPatterns(
OwningRewritePatternList &patterns, MLIRContext *context) {
patterns.insert<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
}
//===----------------------------------------------------------------------===//
// BitcastOp
//===----------------------------------------------------------------------===//
bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
assert(inputs.size() == 1 && outputs.size() == 1 &&
"bitcast op expects one operand and one result");
// Shape equivalence is guaranteed by op traits.
auto srcType = getElementTypeOrSelf(inputs.front());
auto dstType = getElementTypeOrSelf(outputs.front());
// Types are guarnateed to be integers or floats by constraints.
return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
}
OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 1 && "bitcast op expects 1 operand");
auto resType = getType();
auto operand = operands[0];
if (!operand)
return {};
/// Bitcast dense elements.
if (auto denseAttr = operand.dyn_cast_or_null<DenseElementsAttr>())
return denseAttr.bitcast(resType.cast<ShapedType>().getElementType());
/// Other shaped types unhandled.
if (resType.isa<ShapedType>())
return {};
/// Bitcast integer or float to integer or float.
APInt bits = operand.isa<FloatAttr>()
? operand.cast<FloatAttr>().getValue().bitcastToAPInt()
: operand.cast<IntegerAttr>().getValue();
if (auto resFloatType = resType.dyn_cast<FloatType>())
return FloatAttr::get(resType,
APFloat(resFloatType.getFloatSemantics(), bits));
return IntegerAttr::get(resType, bits);
}
void arith::BitcastOp::getCanonicalizationPatterns(
OwningRewritePatternList &patterns, MLIRContext *context) {
patterns.insert<BitcastOfBitcast>(context);
}
//===----------------------------------------------------------------------===//
// Helpers for compare ops
//===----------------------------------------------------------------------===//
/// Return the type of the same shape (scalar, vector or tensor) containing i1.
static Type getI1SameShape(Type type) {
auto i1Type = IntegerType::get(type.getContext(), 1);
if (auto tensorType = type.dyn_cast<RankedTensorType>())
return RankedTensorType::get(tensorType.getShape(), i1Type);
if (type.isa<UnrankedTensorType>())
return UnrankedTensorType::get(i1Type);
if (auto vectorType = type.dyn_cast<VectorType>())
return VectorType::get(vectorType.getShape(), i1Type);
return i1Type;
}
//===----------------------------------------------------------------------===//
// CmpIOp
//===----------------------------------------------------------------------===//
/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
/// comparison predicates.
bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
const APInt &lhs, const APInt &rhs) {
switch (predicate) {
case arith::CmpIPredicate::eq:
return lhs.eq(rhs);
case arith::CmpIPredicate::ne:
return lhs.ne(rhs);
case arith::CmpIPredicate::slt:
return lhs.slt(rhs);
case arith::CmpIPredicate::sle:
return lhs.sle(rhs);
case arith::CmpIPredicate::sgt:
return lhs.sgt(rhs);
case arith::CmpIPredicate::sge:
return lhs.sge(rhs);
case arith::CmpIPredicate::ult:
return lhs.ult(rhs);
case arith::CmpIPredicate::ule:
return lhs.ule(rhs);
case arith::CmpIPredicate::ugt:
return lhs.ugt(rhs);
case arith::CmpIPredicate::uge:
return lhs.uge(rhs);
}
llvm_unreachable("unknown cmpi predicate kind");
}
/// Returns true if the predicate is true for two equal operands.
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
switch (predicate) {
case arith::CmpIPredicate::eq:
case arith::CmpIPredicate::sle:
case arith::CmpIPredicate::sge:
case arith::CmpIPredicate::ule:
case arith::CmpIPredicate::uge:
return true;
case arith::CmpIPredicate::ne:
case arith::CmpIPredicate::slt:
case arith::CmpIPredicate::sgt:
case arith::CmpIPredicate::ult:
case arith::CmpIPredicate::ugt:
return false;
}
llvm_unreachable("unknown cmpi predicate kind");
}
OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2 && "cmpi takes two operands");
// cmpi(pred, x, x)
if (lhs() == rhs()) {
auto val = applyCmpPredicateToEqualOperands(getPredicate());
return BoolAttr::get(getContext(), val);
}
auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
if (!lhs || !rhs)
return {};
auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
return BoolAttr::get(getContext(), val);
}
//===----------------------------------------------------------------------===//
// CmpFOp
//===----------------------------------------------------------------------===//
/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
/// comparison predicates.
bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
const APFloat &lhs, const APFloat &rhs) {
auto cmpResult = lhs.compare(rhs);
switch (predicate) {
case arith::CmpFPredicate::AlwaysFalse:
return false;
case arith::CmpFPredicate::OEQ:
return cmpResult == APFloat::cmpEqual;
case arith::CmpFPredicate::OGT:
return cmpResult == APFloat::cmpGreaterThan;
case arith::CmpFPredicate::OGE:
return cmpResult == APFloat::cmpGreaterThan ||
cmpResult == APFloat::cmpEqual;
case arith::CmpFPredicate::OLT:
return cmpResult == APFloat::cmpLessThan;
case arith::CmpFPredicate::OLE:
return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
case arith::CmpFPredicate::ONE:
return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
case arith::CmpFPredicate::ORD:
return cmpResult != APFloat::cmpUnordered;
case arith::CmpFPredicate::UEQ:
return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
case arith::CmpFPredicate::UGT:
return cmpResult == APFloat::cmpUnordered ||
cmpResult == APFloat::cmpGreaterThan;
case arith::CmpFPredicate::UGE:
return cmpResult == APFloat::cmpUnordered ||
cmpResult == APFloat::cmpGreaterThan ||
cmpResult == APFloat::cmpEqual;
case arith::CmpFPredicate::ULT:
return cmpResult == APFloat::cmpUnordered ||
cmpResult == APFloat::cmpLessThan;
case arith::CmpFPredicate::ULE:
return cmpResult == APFloat::cmpUnordered ||
cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
case arith::CmpFPredicate::UNE:
return cmpResult != APFloat::cmpEqual;
case arith::CmpFPredicate::UNO:
return cmpResult == APFloat::cmpUnordered;
case arith::CmpFPredicate::AlwaysTrue:
return true;
}
llvm_unreachable("unknown cmpf predicate kind");
}
OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2 && "cmpf takes two operands");
auto lhs = operands.front().dyn_cast_or_null<FloatAttr>();
auto rhs = operands.back().dyn_cast_or_null<FloatAttr>();
if (!lhs || !rhs)
return {};
auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
return BoolAttr::get(getContext(), val);
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc"
//===----------------------------------------------------------------------===//
// TableGen'd enum attribute definitions
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.cpp.inc"

View File

@ -0,0 +1,18 @@
set(LLVM_TARGET_DEFINITIONS ArithmeticCanonicalization.td)
mlir_tablegen(ArithmeticCanonicalization.inc -gen-rewriters)
add_public_tablegen_target(MLIRArithmeticCanonicalizationIncGen)
add_mlir_dialect_library(MLIRArithmetic
ArithmeticOps.cpp
ArithmeticDialect.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Arithmetic
DEPENDS
MLIRArithmeticOpsIncGen
LINK_LIBS PUBLIC
MLIRDialect
MLIRIR
)

View File

@ -1,4 +1,5 @@
add_subdirectory(Affine)
add_subdirectory(Arithmetic)
add_subdirectory(ArmNeon)
add_subdirectory(ArmSVE)
add_subdirectory(Async)

View File

@ -6673,6 +6673,118 @@ exports_files([
"include/mlir/Transforms/InliningUtils.h",
])
td_library(
name = "ArithmeticOpsTdFiles",
srcs = [
"include/mlir/Dialect/Arithmetic/IR/ArithmeticBase.td",
"include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td",
],
includes = ["include"],
deps = [
":OpBaseTdFiles",
":SideEffectInterfacesTdFiles",
":VectorInterfacesTdFiles",
],
)
gentbl_cc_library(
name = "ArithmeticBaseIncGen",
strip_include_prefix = "include",
tbl_outs = [
(
[
"-gen-dialect-decls",
"-dialect=arith",
],
"include/mlir/Dialect/Arithmetic/IR/ArithmeticOpsDialect.h.inc",
),
(
[
"-gen-dialect-defs",
"-dialect=arith",
],
"include/mlir/Dialect/Arithmetic/IR/ArithmeticOpsDialect.cpp.inc",
),
(
["-gen-enum-decls"],
"include/mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.h.inc",
),
(
["-gen-enum-defs"],
"include/mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.cpp.inc",
),
],
tblgen = ":mlir-tblgen",
td_file = "include/mlir/Dialect/Arithmetic/IR/ArithmeticBase.td",
deps = [":ArithmeticOpsTdFiles"],
)
gentbl_cc_library(
name = "ArithmeticOpsIncGen",
strip_include_prefix = "include",
tbl_outs = [
(
["-gen-op-decls"],
"include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.h.inc",
),
(
["-gen-op-defs"],
"include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc",
),
],
tblgen = ":mlir-tblgen",
td_file = "include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td",
deps = [
":ArithmeticOpsTdFiles",
":CastInterfacesTdFiles",
],
)
gentbl_cc_library(
name = "ArithmeticCanonicalizationIncGen",
strip_include_prefix = "include/mlir/Dialect/Arithmetic/IR",
tbl_outs = [
(
["-gen-rewriters"],
"include/mlir/Dialect/Arithmetic/IR/ArithmeticCanonicalization.inc",
),
],
tblgen = ":mlir-tblgen",
td_file = "lib/Dialect/Arithmetic/IR/ArithmeticCanonicalization.td",
deps = [
":ArithmeticOpsTdFiles",
":CastInterfacesTdFiles",
":StdOpsTdFiles",
],
)
cc_library(
name = "ArithmeticDialect",
srcs = glob(
[
"lib/Dialect/Arithmetic/IR/*.cpp",
"lib/Dialect/Arithmetic/IR/*.h",
],
),
hdrs = [
"include/mlir/Dialect/Arithmetic/IR/Arithmetic.h",
"include/mlir/Transforms/InliningUtils.h",
],
includes = ["include"],
deps = [
":ArithmeticBaseIncGen",
":ArithmeticCanonicalizationIncGen",
":ArithmeticOpsIncGen",
":CommonFolders",
":IR",
":SideEffectInterfaces",
":StandardOps",
":Support",
":VectorInterfaces",
"//llvm:Support",
],
)
td_library(
name = "MathOpsTdFiles",
srcs = [