[mlir] Add better support for f80 and f128

Add builtin f80 and f128 following @schweitz proposition
https://llvm.discourse.group/t/rfc-adding-better-support-for-higher-precision-floating-point/2526/5

Reviewed By: ftynse, rriddle

Differential Revision: https://reviews.llvm.org/D94737
This commit is contained in:
Valentin Clement 2021-01-15 10:29:37 -05:00 committed by clementval
parent 46aa3c6c33
commit cf0173de69
26 changed files with 110 additions and 55 deletions

View File

@ -29,6 +29,8 @@ following conversions are currently implemented:
- `f16` converts to `f16`
- `f32` converts to `f32`
- `f64` converts to `f64`
- `f80` converts to `f80`
- `f128` converts to `f128`
### Index Type

View File

@ -214,7 +214,8 @@ LLVM dialect accepts a subset of built-in types that are referred to as _LLVM
dialect-compatible types_. The following types are compatible:
- Signless integers - `iN` (`IntegerType`).
- Floating point types - `bfloat`, `half`, `float`, `double` (`FloatType`).
- Floating point types - `bfloat`, `half`, `float`, `double` , `f80`, `f128`
(`FloatType`).
- 1D vectors of signless integers or floating point types - `vector<NxT>`
(`VectorType`).
@ -228,9 +229,6 @@ compatibility check.
The following non-parametric types derived from the LLVM IR are available in the
LLVM dialect:
- `!llvm.fp128` (`LLVMFP128Type`) - 128-bit floating-point value as per
IEEE-754-2008.
- `!llvm.x86_fp80` (`LLVMX86FP80Type`) - 80-bit floating-point value (x87).
- `!llvm.x86_mmx` (`LLVMX86MMXType`) - value held in an MMX register on x86
machine.
- `!llvm.ppc_fp128` (`LLVMPPCFP128Type`) - 128-bit floating-point value (two

View File

@ -850,7 +850,7 @@ Syntax:
```
// Floating point.
float-type ::= `f16` | `bf16` | `f32` | `f64`
float-type ::= `f16` | `bf16` | `f32` | `f64` | `f80` | `f128`
```
MLIR supports float types of certain widths that are widely used as indicated

View File

@ -36,9 +36,6 @@ struct LLVMStructTypeStorage;
struct LLVMTypeAndSizeStorage;
} // namespace detail
class LLVMFP128Type;
class LLVMX86FP80Type;
//===----------------------------------------------------------------------===//
// Trivial types.
//===----------------------------------------------------------------------===//
@ -51,8 +48,6 @@ class LLVMX86FP80Type;
}
DEFINE_TRIVIAL_LLVM_TYPE(LLVMVoidType);
DEFINE_TRIVIAL_LLVM_TYPE(LLVMFP128Type);
DEFINE_TRIVIAL_LLVM_TYPE(LLVMX86FP80Type);
DEFINE_TRIVIAL_LLVM_TYPE(LLVMPPCFP128Type);
DEFINE_TRIVIAL_LLVM_TYPE(LLVMX86MMXType);
DEFINE_TRIVIAL_LLVM_TYPE(LLVMTokenType);

View File

@ -66,6 +66,8 @@ public:
FloatType getF16Type();
FloatType getF32Type();
FloatType getF64Type();
FloatType getF80Type();
FloatType getF128Type();
IndexType getIndexType();

View File

@ -51,6 +51,8 @@ public:
static FloatType getF16(MLIRContext *ctx);
static FloatType getF32(MLIRContext *ctx);
static FloatType getF64(MLIRContext *ctx);
static FloatType getF80(MLIRContext *ctx);
static FloatType getF128(MLIRContext *ctx);
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(Type type);
@ -439,7 +441,8 @@ inline bool BaseMemRefType::isValidElementType(Type type) {
}
inline bool FloatType::classof(Type type) {
return type.isa<BFloat16Type, Float16Type, Float32Type, Float64Type>();
return type.isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
Float80Type, Float128Type>();
}
inline FloatType FloatType::getBF16(MLIRContext *ctx) {
@ -458,6 +461,14 @@ inline FloatType FloatType::getF64(MLIRContext *ctx) {
return Float64Type::get(ctx);
}
inline FloatType FloatType::getF80(MLIRContext *ctx) {
return Float80Type::get(ctx);
}
inline FloatType FloatType::getF128(MLIRContext *ctx) {
return Float128Type::get(ctx);
}
inline bool ShapedType::classof(Type type) {
return type.isa<RankedTensorType, VectorType, UnrankedTensorType,
UnrankedMemRefType, MemRefType>();

View File

@ -101,6 +101,20 @@ def Builtin_Float64 : Builtin_FloatType<"Float64"> {
let summary = "64-bit floating-point type";
}
//===----------------------------------------------------------------------===//
// Float80Type
def Builtin_Float80 : Builtin_FloatType<"Float80"> {
let summary = "80-bit floating-point type";
}
//===----------------------------------------------------------------------===//
// Float128Type
def Builtin_Float128 : Builtin_FloatType<"Float128"> {
let summary = "128-bit floating-point type";
}
//===----------------------------------------------------------------------===//
// FunctionType
//===----------------------------------------------------------------------===//

View File

@ -466,6 +466,8 @@ class FloatOfWidths<list<int> widths> :
def F16 : F<16>;
def F32 : F<32>;
def F64 : F<64>;
def F80 : F<80>;
def F128 : F<128>;
def BF16 : Type<CPred<"$_self.isBF16()">, "bfloat16 type">,
BuildableType<"$_builder.getBF16Type()">;

View File

@ -118,6 +118,8 @@ public:
bool isF16() const;
bool isF32() const;
bool isF64() const;
bool isF80() const;
bool isF128() const;
/// Return true if this is an integer type with the specified width.
bool isInteger(unsigned width) const;

View File

@ -2042,8 +2042,6 @@ void LLVMDialect::initialize() {
// clang-format off
addTypes<LLVMVoidType,
LLVMFP128Type,
LLVMX86FP80Type,
LLVMPPCFP128Type,
LLVMX86MMXType,
LLVMTokenType,

View File

@ -33,8 +33,6 @@ static void dispatchPrint(DialectAsmPrinter &printer, Type type) {
static StringRef getTypeKeyword(Type type) {
return TypeSwitch<Type, StringRef>(type)
.Case<LLVMVoidType>([&](Type) { return "void"; })
.Case<LLVMFP128Type>([&](Type) { return "fp128"; })
.Case<LLVMX86FP80Type>([&](Type) { return "x86_fp80"; })
.Case<LLVMPPCFP128Type>([&](Type) { return "ppc_fp128"; })
.Case<LLVMX86MMXType>([&](Type) { return "x86_mmx"; })
.Case<LLVMTokenType>([&](Type) { return "token"; })
@ -460,8 +458,16 @@ static Type dispatchParse(DialectAsmParser &parser, bool allowAny = true) {
emitWarning(loc) << "deprecated syntax, use f64 instead";
return Float64Type::get(ctx);
})
.Case("fp128", [&] { return LLVMFP128Type::get(ctx); })
.Case("x86_fp80", [&] { return LLVMX86FP80Type::get(ctx); })
.Case("fp128",
[&] {
emitWarning(loc) << "deprecated syntax, use f128 instead";
return Float128Type::get(ctx);
})
.Case("x86_fp80",
[&] {
emitWarning(loc) << "deprecated syntax, use f80 instead";
return Float80Type::get(ctx);
})
.Case("ppc_fp128", [&] { return LLVMPPCFP128Type::get(ctx); })
.Case("x86_mmx", [&] { return LLVMX86MMXType::get(ctx); })
.Case("token", [&] { return LLVMTokenType::get(ctx); })

View File

@ -272,8 +272,7 @@ unsigned LLVMFixedVectorType::getNumElements() {
}
bool LLVMFixedVectorType::isValidElementType(Type type) {
return type
.isa<LLVMPointerType, LLVMX86FP80Type, LLVMFP128Type, LLVMPPCFP128Type>();
return type.isa<LLVMPointerType, LLVMPPCFP128Type>();
}
LogicalResult LLVMFixedVectorType::verifyConstructionInvariants(
@ -339,8 +338,9 @@ bool mlir::LLVM::isCompatibleType(Type type) {
Float16Type,
Float32Type,
Float64Type,
Float80Type,
Float128Type,
LLVMArrayType,
LLVMFP128Type,
LLVMFunctionType,
LLVMLabelType,
LLVMMetadataType,
@ -351,7 +351,6 @@ bool mlir::LLVM::isCompatibleType(Type type) {
LLVMFixedVectorType,
LLVMScalableVectorType,
LLVMVoidType,
LLVMX86FP80Type,
LLVMX86MMXType
>();
// clang-format on
@ -359,7 +358,7 @@ bool mlir::LLVM::isCompatibleType(Type type) {
bool mlir::LLVM::isCompatibleFloatingPointType(Type type) {
return type.isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
LLVMFP128Type, LLVMPPCFP128Type, LLVMX86FP80Type>();
Float80Type, Float128Type, LLVMPPCFP128Type>();
}
bool mlir::LLVM::isCompatibleVectorType(Type type) {
@ -372,8 +371,8 @@ bool mlir::LLVM::isCompatibleVectorType(Type type) {
Type elementType = vecType.getElementType();
if (auto intType = elementType.dyn_cast<IntegerType>())
return intType.isSignless();
return elementType
.isa<BFloat16Type, Float16Type, Float32Type, Float64Type>();
return elementType.isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
Float80Type, Float128Type>();
}
return false;
}
@ -421,12 +420,12 @@ llvm::TypeSize mlir::LLVM::getPrimitiveTypeSizeInBits(Type type) {
.Case<Float32Type>([](Type) { return llvm::TypeSize::Fixed(32); })
.Case<Float64Type, LLVMX86MMXType>(
[](Type) { return llvm::TypeSize::Fixed(64); })
.Case<Float80Type>([](Type) { return llvm::TypeSize::Fixed(80); })
.Case<Float128Type>([](Type) { return llvm::TypeSize::Fixed(128); })
.Case<IntegerType>([](IntegerType intTy) {
return llvm::TypeSize::Fixed(intTy.getWidth());
})
.Case<LLVMX86FP80Type>([](Type) { return llvm::TypeSize::Fixed(80); })
.Case<LLVMPPCFP128Type, LLVMFP128Type>(
[](Type) { return llvm::TypeSize::Fixed(128); })
.Case<LLVMPPCFP128Type>([](Type) { return llvm::TypeSize::Fixed(128); })
.Case<LLVMFixedVectorType>([](LLVMFixedVectorType t) {
llvm::TypeSize elementSize =
getPrimitiveTypeSizeInBits(t.getElementType());

View File

@ -1816,6 +1816,8 @@ void ModulePrinter::printType(Type type) {
.Case<Float16Type>([&](Type) { os << "f16"; })
.Case<Float32Type>([&](Type) { os << "f32"; })
.Case<Float64Type>([&](Type) { os << "f64"; })
.Case<Float80Type>([&](Type) { os << "f80"; })
.Case<Float128Type>([&](Type) { os << "f128"; })
.Case<IntegerType>([&](IntegerType integerTy) {
if (integerTy.isSigned())
os << 's';

View File

@ -50,6 +50,10 @@ FloatType Builder::getF32Type() { return FloatType::getF32(context); }
FloatType Builder::getF64Type() { return FloatType::getF64(context); }
FloatType Builder::getF80Type() { return FloatType::getF80(context); }
FloatType Builder::getF128Type() { return FloatType::getF128(context); }
IndexType Builder::getIndexType() { return IndexType::get(context); }
IntegerType Builder::getI1Type() { return IntegerType::get(context, 1); }

View File

@ -50,9 +50,9 @@ struct BuiltinOpAsmDialectInterface : public OpAsmDialectInterface {
void BuiltinDialect::initialize() {
addTypes<ComplexType, BFloat16Type, Float16Type, Float32Type, Float64Type,
FunctionType, IndexType, IntegerType, MemRefType, UnrankedMemRefType,
NoneType, OpaqueType, RankedTensorType, TupleType,
UnrankedTensorType, VectorType>();
Float80Type, Float128Type, FunctionType, IndexType, IntegerType,
MemRefType, UnrankedMemRefType, NoneType, OpaqueType,
RankedTensorType, TupleType, UnrankedTensorType, VectorType>();
addAttributes<AffineMapAttr, ArrayAttr, DenseIntOrFPElementsAttr,
DenseStringElementsAttr, DictionaryAttr, FloatAttr,
SymbolRefAttr, IntegerAttr, IntegerSetAttr, OpaqueAttr,

View File

@ -80,6 +80,10 @@ unsigned FloatType::getWidth() {
return 32;
if (isa<Float64Type>())
return 64;
if (isa<Float80Type>())
return 80;
if (isa<Float128Type>())
return 128;
llvm_unreachable("unexpected float type");
}
@ -93,6 +97,10 @@ const llvm::fltSemantics &FloatType::getFloatSemantics() {
return APFloat::IEEEsingle();
if (isa<Float64Type>())
return APFloat::IEEEdouble();
if (isa<Float80Type>())
return APFloat::x87DoubleExtended();
if (isa<Float128Type>())
return APFloat::IEEEquad();
llvm_unreachable("non-floating point type used");
}

View File

@ -303,6 +303,8 @@ public:
Float16Type f16Ty;
Float32Type f32Ty;
Float64Type f64Ty;
Float80Type f80Ty;
Float128Type f128Ty;
IndexType indexTy;
IntegerType int1Ty, int8Ty, int16Ty, int32Ty, int64Ty, int128Ty;
NoneType noneType;
@ -351,6 +353,8 @@ MLIRContext::MLIRContext() : impl(new MLIRContextImpl()) {
impl->f16Ty = TypeUniquer::get<Float16Type>(this);
impl->f32Ty = TypeUniquer::get<Float32Type>(this);
impl->f64Ty = TypeUniquer::get<Float64Type>(this);
impl->f80Ty = TypeUniquer::get<Float80Type>(this);
impl->f128Ty = TypeUniquer::get<Float128Type>(this);
/// Index Type.
impl->indexTy = TypeUniquer::get<IndexType>(this);
/// Integer Types.
@ -739,6 +743,12 @@ Float32Type Float32Type::get(MLIRContext *context) {
Float64Type Float64Type::get(MLIRContext *context) {
return context->getImpl().f64Ty;
}
Float80Type Float80Type::get(MLIRContext *context) {
return context->getImpl().f80Ty;
}
Float128Type Float128Type::get(MLIRContext *context) {
return context->getImpl().f128Ty;
}
/// Get an instance of the IndexType.
IndexType IndexType::get(MLIRContext *context) {

View File

@ -26,6 +26,8 @@ bool Type::isBF16() const { return isa<BFloat16Type>(); }
bool Type::isF16() const { return isa<Float16Type>(); }
bool Type::isF32() const { return isa<Float32Type>(); }
bool Type::isF64() const { return isa<Float64Type>(); }
bool Type::isF80() const { return isa<Float80Type>(); }
bool Type::isF128() const { return isa<Float128Type>(); }
bool Type::isIndex() const { return isa<IndexType>(); }

View File

@ -85,6 +85,8 @@ TOK_KEYWORD(dense)
TOK_KEYWORD(f16)
TOK_KEYWORD(f32)
TOK_KEYWORD(f64)
TOK_KEYWORD(f80)
TOK_KEYWORD(f128)
TOK_KEYWORD(false)
TOK_KEYWORD(floordiv)
TOK_KEYWORD(for)

View File

@ -307,7 +307,7 @@ Type Parser::parseMemRefType() {
/// | none-type
///
/// index-type ::= `index`
/// float-type ::= `f16` | `bf16` | `f32` | `f64`
/// float-type ::= `f16` | `bf16` | `f32` | `f64` | `f80` | `f128`
/// none-type ::= `none`
///
Type Parser::parseNonFunctionType() {
@ -356,6 +356,12 @@ Type Parser::parseNonFunctionType() {
case Token::kw_f64:
consumeToken(Token::kw_f64);
return builder.getF64Type();
case Token::kw_f80:
consumeToken(Token::kw_f80);
return builder.getF80Type();
case Token::kw_f128:
consumeToken(Token::kw_f128);
return builder.getF128Type();
// index-type
case Token::kw_index:

View File

@ -49,12 +49,12 @@ public:
.Case([this](Float64Type) {
return llvm::Type::getDoubleTy(context);
})
.Case([this](LLVM::LLVMFP128Type) {
return llvm::Type::getFP128Ty(context);
})
.Case([this](LLVM::LLVMX86FP80Type) {
.Case([this](Float80Type) {
return llvm::Type::getX86_FP80Ty(context);
})
.Case([this](Float128Type) {
return llvm::Type::getFP128Ty(context);
})
.Case([this](LLVM::LLVMPPCFP128Type) {
return llvm::Type::getPPC_FP128Ty(context);
})
@ -230,9 +230,9 @@ private:
if (type->isDoubleTy())
return Float64Type::get(&context);
if (type->isFP128Ty())
return LLVM::LLVMFP128Type::get(&context);
return Float128Type::get(&context);
if (type->isX86_FP80Ty())
return LLVM::LLVMX86FP80Type::get(&context);
return Float80Type::get(&context);
if (type->isPPC_FP128Ty())
return LLVM::LLVMPPCFP128Type::get(&context);
if (type->isX86_MMXTy())

View File

@ -128,10 +128,10 @@ func @ops(%arg0: i32, %arg1: f32,
// Extended and Quad floating point
//
// CHECK: %{{.*}} = llvm.fpext %[[FLOAT]] : f32 to !llvm.x86_fp80
// CHECK: %{{.*}} = llvm.fpext %[[FLOAT]] : f32 to !llvm.fp128
%27 = llvm.fpext %arg1 : f32 to !llvm.x86_fp80
%28 = llvm.fpext %arg1 : f32 to !llvm.fp128
// CHECK: %{{.*}} = llvm.fpext %[[FLOAT]] : f32 to f80
// CHECK: %{{.*}} = llvm.fpext %[[FLOAT]] : f32 to f128
%27 = llvm.fpext %arg1 : f32 to f80
%28 = llvm.fpext %arg1 : f32 to f128
// CHECK: %{{.*}} = llvm.fneg %[[FLOAT]] : f32
%29 = llvm.fneg %arg1 : f32

View File

@ -4,18 +4,6 @@
func @primitive() {
// CHECK: !llvm.void
"some.op"() : () -> !llvm.void
// CHECK: f16
"some.op"() : () -> f16
// CHECK: bf16
"some.op"() : () -> bf16
// CHECK: f32
"some.op"() : () -> f32
// CHECK: f64
"some.op"() : () -> f64
// CHECK: !llvm.fp128
"some.op"() : () -> !llvm.fp128
// CHECK: !llvm.x86_fp80
"some.op"() : () -> !llvm.x86_fp80
// CHECK: !llvm.ppc_fp128
"some.op"() : () -> !llvm.ppc_fp128
// CHECK: !llvm.x86_mmx

View File

@ -67,6 +67,8 @@ func private @sint_types(si2, si4) -> (si7, si1023)
// CHECK: func private @uint_types(ui2, ui4) -> (ui7, ui1023)
func private @uint_types(ui2, ui4) -> (ui7, ui1023)
// CHECK: func private @float_types(f80, f128)
func private @float_types(f80, f128)
// CHECK: func private @vectors(vector<1xf32>, vector<2x4xf32>)
func private @vectors(vector<1 x f32>, vector<2x4xf32>)

View File

@ -15,9 +15,9 @@ llvm.func @return_float() -> f32
// CHECK: declare double @return_double()
llvm.func @return_double() -> f64
// CHECK: declare fp128 @return_fp128()
llvm.func @return_fp128() -> !llvm.fp128
llvm.func @return_fp128() -> f128
// CHECK: declare x86_fp80 @return_x86_fp80()
llvm.func @return_x86_fp80() -> !llvm.x86_fp80
llvm.func @return_x86_fp80() -> f80
// CHECK: declare ppc_fp128 @return_ppc_fp128()
llvm.func @return_ppc_fp128() -> !llvm.ppc_fp128
// CHECK: declare x86_mmx @return_x86_mmx()

View File

@ -203,6 +203,8 @@ for name in [
'Float16Type',
'Float32Type',
'Float64Type',
'Float80Type',
'Float128Type',
'NoneType',
'VectorType',
'RankedTensorType',