diff --git a/Cargo.lock b/Cargo.lock index ea6aee7f7..223f515b8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -612,6 +612,7 @@ dependencies = [ "burn-common", "burn-tensor", "derive-new", + "half", "hashbrown 0.14.5", "log", "serde", diff --git a/crates/burn-fusion/Cargo.toml b/crates/burn-fusion/Cargo.toml index 4ce3f2981..e3f74bcce 100644 --- a/crates/burn-fusion/Cargo.toml +++ b/crates/burn-fusion/Cargo.toml @@ -23,6 +23,7 @@ derive-new = {workspace = true } spin = { workspace = true } log = { workspace = true } serde = { workspace = true } +half = { workspace = true } [package.metadata.docs.rs] features = ["doc"] diff --git a/crates/burn-fusion/src/ops/float.rs b/crates/burn-fusion/src/ops/float.rs index 676992ccb..c7a505f90 100644 --- a/crates/burn-fusion/src/ops/float.rs +++ b/crates/burn-fusion/src/ops/float.rs @@ -62,7 +62,10 @@ impl FloatTensorOps for Fusion { }; client.register( vec![stream], - OperationDescription::Float(FloatOperationDescription::Random(desc.clone())), + OperationDescription::Float( + FloatElem::::dtype(), + FloatOperationDescription::Random(desc.clone()), + ), RandomOps::::new(desc, device.clone()), ); @@ -92,7 +95,10 @@ impl FloatTensorOps for Fusion { let desc = out.to_description_out(); client.register( vec![stream], - OperationDescription::NumericFloat(NumericOperationDescription::Zeros(desc.clone())), + OperationDescription::NumericFloat( + FloatElem::::dtype(), + NumericOperationDescription::Zeros(desc.clone()), + ), ZerosOps::::new(desc, device.clone()), ); @@ -122,7 +128,10 @@ impl FloatTensorOps for Fusion { let desc = out.to_description_out(); client.register( vec![stream], - OperationDescription::NumericFloat(NumericOperationDescription::Ones(desc.clone())), + OperationDescription::NumericFloat( + FloatElem::::dtype(), + NumericOperationDescription::Ones(desc.clone()), + ), OnesOps::::new(desc, device.clone()), ); @@ -158,7 +167,10 @@ impl FloatTensorOps for Fusion { let desc = (out.to_description_out(), fill_value.elem::()); client.register( vec![stream], - OperationDescription::NumericFloat(NumericOperationDescription::Full(desc.clone())), + OperationDescription::NumericFloat( + FloatElem::::dtype(), + NumericOperationDescription::Full(desc.clone()), + ), FullOps::::new(desc.0, desc.1, device.clone()), ); @@ -226,7 +238,10 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::Float(FloatOperationDescription::IntoInt(desc.clone())), + OperationDescription::Float( + FloatElem::::dtype(), + FloatOperationDescription::IntoInt(desc.clone()), + ), IntoIntOps::::new(desc), ); @@ -267,7 +282,10 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], - OperationDescription::NumericFloat(NumericOperationDescription::Add(desc.clone())), + OperationDescription::NumericFloat( + FloatElem::::dtype(), + NumericOperationDescription::Add(desc.clone()), + ), AddOps::::new(desc), ); @@ -292,9 +310,10 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::NumericFloat(NumericOperationDescription::AddScalar( - desc.clone(), - )), + OperationDescription::NumericFloat( + FloatElem::::dtype(), + NumericOperationDescription::AddScalar(desc.clone()), + ), AddOps::::new(desc), ); @@ -334,7 +353,10 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::NumericFloat(NumericOperationDescription::Clamp(desc.clone())), + OperationDescription::NumericFloat( + FloatElem::::dtype(), + NumericOperationDescription::Clamp(desc.clone()), + ), ClampOps::::new(desc), ); @@ -361,7 +383,10 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream_1, stream_2], - OperationDescription::NumericFloat(NumericOperationDescription::Sub(desc.clone())), + OperationDescription::NumericFloat( + FloatElem::::dtype(), + NumericOperationDescription::Sub(desc.clone()), + ), SubOps::::new(desc), ); @@ -386,9 +411,10 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], - OperationDescription::NumericFloat(NumericOperationDescription::SubScalar( - desc.clone(), - )), + OperationDescription::NumericFloat( + FloatElem::::dtype(), + NumericOperationDescription::SubScalar(desc.clone()), + ), SubOps::::new(desc), ); @@ -415,7 +441,10 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream_1, stream_2], - OperationDescription::NumericFloat(NumericOperationDescription::Mul(desc.clone())), + OperationDescription::NumericFloat( + FloatElem::::dtype(), + NumericOperationDescription::Mul(desc.clone()), + ), MulOps::::new(desc), ); @@ -440,9 +469,10 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::NumericFloat(NumericOperationDescription::MulScalar( - desc.clone(), - )), + OperationDescription::NumericFloat( + FloatElem::::dtype(), + NumericOperationDescription::MulScalar(desc.clone()), + ), MulOps::::new(desc), ); @@ -469,7 +499,10 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream_1, stream_2], - OperationDescription::NumericFloat(NumericOperationDescription::Div(desc.clone())), + OperationDescription::NumericFloat( + FloatElem::::dtype(), + NumericOperationDescription::Div(desc.clone()), + ), DivOps::::new(desc), ); @@ -494,9 +527,10 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::NumericFloat(NumericOperationDescription::DivScalar( - desc.clone(), - )), + OperationDescription::NumericFloat( + FloatElem::::dtype(), + NumericOperationDescription::DivScalar(desc.clone()), + ), DivOps::::new(desc), ); @@ -521,9 +555,10 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::NumericFloat(NumericOperationDescription::RemScalar( - desc.clone(), - )), + OperationDescription::NumericFloat( + FloatElem::::dtype(), + NumericOperationDescription::RemScalar(desc.clone()), + ), ModOps::::new(desc), ); @@ -554,7 +589,10 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream_1, stream_2], - OperationDescription::Float(FloatOperationDescription::Matmul(desc.clone())), + OperationDescription::Float( + FloatElem::::dtype(), + FloatOperationDescription::Matmul(desc.clone()), + ), MatmulOps::::new(desc), ); @@ -680,7 +718,10 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream_1, stream_2], - OperationDescription::NumericFloat(NumericOperationDescription::Gather(desc.clone())), + OperationDescription::NumericFloat( + FloatElem::::dtype(), + NumericOperationDescription::Gather(desc.clone()), + ), GatherOps::::new(desc), ); @@ -729,7 +770,10 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream_1, stream_2, stream_3], - OperationDescription::NumericFloat(NumericOperationDescription::Scatter(desc.clone())), + OperationDescription::NumericFloat( + FloatElem::::dtype(), + NumericOperationDescription::Scatter(desc.clone()), + ), ScatterOps::::new(desc), ); @@ -773,7 +817,10 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream_1, stream_2], - OperationDescription::NumericFloat(NumericOperationDescription::Select(desc.clone())), + OperationDescription::NumericFloat( + FloatElem::::dtype(), + NumericOperationDescription::Select(desc.clone()), + ), SelectOps::::new(desc), ); @@ -821,9 +868,10 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream_1, stream_2, stream_3], - OperationDescription::NumericFloat(NumericOperationDescription::SelectAssign( - desc.clone(), - )), + OperationDescription::NumericFloat( + FloatElem::::dtype(), + NumericOperationDescription::SelectAssign(desc.clone()), + ), SelectAssignOps::::new(desc), ); @@ -966,9 +1014,10 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream_1, stream_2, stream_3], - OperationDescription::NumericFloat(NumericOperationDescription::MaskWhere( - desc.clone(), - )), + OperationDescription::NumericFloat( + FloatElem::::dtype(), + NumericOperationDescription::MaskWhere(desc.clone()), + ), MaskWhereOps::::new(desc), ); @@ -1011,7 +1060,10 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream_1, stream_2], - OperationDescription::NumericFloat(NumericOperationDescription::MaskFill(desc.clone())), + OperationDescription::NumericFloat( + FloatElem::::dtype(), + NumericOperationDescription::MaskFill(desc.clone()), + ), MaskFillOps::::new(desc), ); @@ -1062,9 +1114,10 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::NumericFloat(NumericOperationDescription::EqualElem( - desc.clone(), - )), + OperationDescription::NumericFloat( + FloatElem::::dtype(), + NumericOperationDescription::EqualElem(desc.clone()), + ), EqualElemOps::::new(desc), ); @@ -1090,7 +1143,10 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream_1, stream_2], - OperationDescription::NumericFloat(NumericOperationDescription::Greater(desc.clone())), + OperationDescription::NumericFloat( + FloatElem::::dtype(), + NumericOperationDescription::Greater(desc.clone()), + ), GreaterOps::::new(desc), ); @@ -1115,9 +1171,10 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::NumericFloat(NumericOperationDescription::GreaterElem( - desc.clone(), - )), + OperationDescription::NumericFloat( + FloatElem::::dtype(), + NumericOperationDescription::GreaterElem(desc.clone()), + ), GreaterElemOps::::new(desc), ); @@ -1143,9 +1200,10 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream_1, stream_2], - OperationDescription::NumericFloat(NumericOperationDescription::GreaterEqual( - desc.clone(), - )), + OperationDescription::NumericFloat( + FloatElem::::dtype(), + NumericOperationDescription::GreaterEqual(desc.clone()), + ), GreaterEqualOps::::new(desc), ); @@ -1170,9 +1228,10 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::NumericFloat(NumericOperationDescription::GreaterEqualElem( - desc.clone(), - )), + OperationDescription::NumericFloat( + FloatElem::::dtype(), + NumericOperationDescription::GreaterEqualElem(desc.clone()), + ), GreaterEqualElemOps::::new(desc), ); @@ -1198,7 +1257,10 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream_1, stream_2], - OperationDescription::NumericFloat(NumericOperationDescription::Lower(desc.clone())), + OperationDescription::NumericFloat( + FloatElem::::dtype(), + NumericOperationDescription::Lower(desc.clone()), + ), LowerOps::::new(desc), ); @@ -1223,9 +1285,10 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::NumericFloat(NumericOperationDescription::LowerElem( - desc.clone(), - )), + OperationDescription::NumericFloat( + FloatElem::::dtype(), + NumericOperationDescription::LowerElem(desc.clone()), + ), LowerElemOps::::new(desc), ); @@ -1251,9 +1314,10 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream_1, stream_2], - OperationDescription::NumericFloat(NumericOperationDescription::LowerEqual( - desc.clone(), - )), + OperationDescription::NumericFloat( + FloatElem::::dtype(), + NumericOperationDescription::LowerEqual(desc.clone()), + ), LowerEqualOps::::new(desc), ); @@ -1278,9 +1342,10 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::NumericFloat(NumericOperationDescription::LowerEqualElem( - desc.clone(), - )), + OperationDescription::NumericFloat( + FloatElem::::dtype(), + NumericOperationDescription::LowerEqualElem(desc.clone()), + ), LowerEqualElemOps::::new(desc), ); @@ -1301,7 +1366,10 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::NumericFloat(NumericOperationDescription::Sum(desc.clone())), + OperationDescription::NumericFloat( + FloatElem::::dtype(), + NumericOperationDescription::Sum(desc.clone()), + ), SumOps::::new(desc), ); @@ -1328,7 +1396,10 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::NumericFloat(NumericOperationDescription::SumDim(desc.clone())), + OperationDescription::NumericFloat( + FloatElem::::dtype(), + NumericOperationDescription::SumDim(desc.clone()), + ), SumDimOps::::new(desc), ); @@ -1349,7 +1420,10 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::NumericFloat(NumericOperationDescription::Mean(desc.clone())), + OperationDescription::NumericFloat( + FloatElem::::dtype(), + NumericOperationDescription::Mean(desc.clone()), + ), MeanOps::::new(desc), ); @@ -1376,7 +1450,10 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::NumericFloat(NumericOperationDescription::MeanDim(desc.clone())), + OperationDescription::NumericFloat( + FloatElem::::dtype(), + NumericOperationDescription::MeanDim(desc.clone()), + ), MeanDimOps::::new(desc), ); @@ -1397,7 +1474,10 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::Float(FloatOperationDescription::Exp(desc.clone())), + OperationDescription::Float( + FloatElem::::dtype(), + FloatOperationDescription::Exp(desc.clone()), + ), ExpOps::::new(desc), ); @@ -1418,7 +1498,10 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::Float(FloatOperationDescription::Log(desc.clone())), + OperationDescription::Float( + FloatElem::::dtype(), + FloatOperationDescription::Log(desc.clone()), + ), LogOps::::new(desc), ); @@ -1439,7 +1522,10 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::Float(FloatOperationDescription::Log1p(desc.clone())), + OperationDescription::Float( + FloatElem::::dtype(), + FloatOperationDescription::Log1p(desc.clone()), + ), Log1pOps::::new(desc), ); @@ -1464,7 +1550,10 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::Float(FloatOperationDescription::PowfScalar(desc.clone())), + OperationDescription::Float( + FloatElem::::dtype(), + FloatOperationDescription::PowfScalar(desc.clone()), + ), PowfOps::::new(desc), ); @@ -1485,7 +1574,10 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::Float(FloatOperationDescription::Sqrt(desc.clone())), + OperationDescription::Float( + FloatElem::::dtype(), + FloatOperationDescription::Sqrt(desc.clone()), + ), SqrtOps::::new(desc), ); @@ -1506,7 +1598,10 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::NumericFloat(NumericOperationDescription::Abs(desc.clone())), + OperationDescription::NumericFloat( + FloatElem::::dtype(), + NumericOperationDescription::Abs(desc.clone()), + ), AbsOps::::new(desc), ); @@ -1527,7 +1622,10 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::Float(FloatOperationDescription::Cos(desc.clone())), + OperationDescription::Float( + FloatElem::::dtype(), + FloatOperationDescription::Cos(desc.clone()), + ), CosOps::::new(desc), ); @@ -1548,7 +1646,10 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::Float(FloatOperationDescription::Sin(desc.clone())), + OperationDescription::Float( + FloatElem::::dtype(), + FloatOperationDescription::Sin(desc.clone()), + ), SinOps::::new(desc), ); @@ -1569,7 +1670,10 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::Float(FloatOperationDescription::Tanh(desc.clone())), + OperationDescription::Float( + FloatElem::::dtype(), + FloatOperationDescription::Tanh(desc.clone()), + ), TanhOps::::new(desc), ); @@ -1589,7 +1693,10 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::Float(FloatOperationDescription::Recip(desc.clone())), + OperationDescription::Float( + FloatElem::::dtype(), + FloatOperationDescription::Recip(desc.clone()), + ), Recip::::new(desc), ); @@ -1610,7 +1717,10 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::Float(FloatOperationDescription::Erf(desc.clone())), + OperationDescription::Float( + FloatElem::::dtype(), + FloatOperationDescription::Erf(desc.clone()), + ), TanhOps::::new(desc), ); @@ -1689,7 +1799,10 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::NumericFloat(NumericOperationDescription::ArgMax(desc.clone())), + OperationDescription::NumericFloat( + FloatElem::::dtype(), + NumericOperationDescription::ArgMax(desc.clone()), + ), ArgMaxOps::::new(desc), ); @@ -1759,7 +1872,10 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::NumericFloat(NumericOperationDescription::ArgMin(desc.clone())), + OperationDescription::NumericFloat( + FloatElem::::dtype(), + NumericOperationDescription::ArgMin(desc.clone()), + ), ArgMinOps::::new(desc), ); @@ -1780,7 +1896,10 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::NumericFloat(NumericOperationDescription::Max(desc.clone())), + OperationDescription::NumericFloat( + FloatElem::::dtype(), + NumericOperationDescription::Max(desc.clone()), + ), MaxOps::::new(desc), ); @@ -1807,7 +1926,10 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::NumericFloat(NumericOperationDescription::MaxDim(desc.clone())), + OperationDescription::NumericFloat( + FloatElem::::dtype(), + NumericOperationDescription::MaxDim(desc.clone()), + ), MaxDimOps::::new(desc), ); @@ -1849,9 +1971,10 @@ impl FloatTensorOps for Fusion { }; client.register( vec![stream], - OperationDescription::NumericFloat(NumericOperationDescription::MaxDimWithIndices( - desc.clone(), - )), + OperationDescription::NumericFloat( + FloatElem::::dtype(), + NumericOperationDescription::MaxDimWithIndices(desc.clone()), + ), MaxDimWithIndicesOps::::new(desc), ); @@ -1872,7 +1995,10 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::NumericFloat(NumericOperationDescription::Min(desc.clone())), + OperationDescription::NumericFloat( + FloatElem::::dtype(), + NumericOperationDescription::Min(desc.clone()), + ), MinOps::::new(desc), ); @@ -1899,7 +2025,10 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::NumericFloat(NumericOperationDescription::MinDim(desc.clone())), + OperationDescription::NumericFloat( + FloatElem::::dtype(), + NumericOperationDescription::MinDim(desc.clone()), + ), MinDimOps::::new(desc), ); @@ -1941,9 +2070,10 @@ impl FloatTensorOps for Fusion { }; client.register( vec![stream], - OperationDescription::NumericFloat(NumericOperationDescription::MinDimWithIndices( - desc.clone(), - )), + OperationDescription::NumericFloat( + FloatElem::::dtype(), + NumericOperationDescription::MinDimWithIndices(desc.clone()), + ), MinDimWithIndicesOps::::new(desc), ); @@ -1970,7 +2100,10 @@ impl FloatTensorOps for Fusion { }; out.client.register( vec![stream_1, stream_2], - OperationDescription::NumericFloat(NumericOperationDescription::Powf(desc.clone())), + OperationDescription::NumericFloat( + FloatElem::::dtype(), + NumericOperationDescription::Powf(desc.clone()), + ), PowOps::::new(desc), ); diff --git a/crates/burn-fusion/src/ops/int.rs b/crates/burn-fusion/src/ops/int.rs index 213987f94..aeded1ec4 100644 --- a/crates/burn-fusion/src/ops/int.rs +++ b/crates/burn-fusion/src/ops/int.rs @@ -255,7 +255,10 @@ impl IntTensorOps for Fusion { }; out.client.register( vec![stream_1, stream_2, stream_3], - OperationDescription::NumericInt(NumericOperationDescription::MaskWhere(desc.clone())), + OperationDescription::NumericInt( + IntElem::::dtype(), + NumericOperationDescription::MaskWhere(desc.clone()), + ), MaskWhereOps::::new(desc), ); @@ -298,7 +301,10 @@ impl IntTensorOps for Fusion { }; out.client.register( vec![stream_1, stream_2], - OperationDescription::NumericInt(NumericOperationDescription::MaskFill(desc.clone())), + OperationDescription::NumericInt( + IntElem::::dtype(), + NumericOperationDescription::MaskFill(desc.clone()), + ), MaskFillOps::::new(desc), ); @@ -340,7 +346,10 @@ impl IntTensorOps for Fusion { }; out.client.register( vec![stream_1, stream_2], - OperationDescription::NumericInt(NumericOperationDescription::Gather(desc.clone())), + OperationDescription::NumericInt( + IntElem::::dtype(), + NumericOperationDescription::Gather(desc.clone()), + ), GatherOps::::new(desc), ); @@ -387,7 +396,10 @@ impl IntTensorOps for Fusion { }; out.client.register( vec![stream_1, stream_2, stream_3], - OperationDescription::NumericInt(NumericOperationDescription::Scatter(desc.clone())), + OperationDescription::NumericInt( + IntElem::::dtype(), + NumericOperationDescription::Scatter(desc.clone()), + ), ScatterOps::::new(desc), ); @@ -431,7 +443,10 @@ impl IntTensorOps for Fusion { }; out.client.register( vec![stream_1, stream_2], - OperationDescription::NumericInt(NumericOperationDescription::Select(desc.clone())), + OperationDescription::NumericInt( + IntElem::::dtype(), + NumericOperationDescription::Select(desc.clone()), + ), SelectOps::::new(desc), ); @@ -478,9 +493,10 @@ impl IntTensorOps for Fusion { }; out.client.register( vec![stream_1, stream_2, stream_3], - OperationDescription::NumericInt(NumericOperationDescription::SelectAssign( - desc.clone(), - )), + OperationDescription::NumericInt( + IntElem::::dtype(), + NumericOperationDescription::SelectAssign(desc.clone()), + ), SelectAssignOps::::new(desc), ); @@ -580,7 +596,10 @@ impl IntTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::NumericInt(NumericOperationDescription::EqualElem(desc.clone())), + OperationDescription::NumericInt( + IntElem::::dtype(), + NumericOperationDescription::EqualElem(desc.clone()), + ), EqualElemOps::::new(desc), ); @@ -606,7 +625,10 @@ impl IntTensorOps for Fusion { }; out.client.register( vec![stream_1, stream_2], - OperationDescription::NumericInt(NumericOperationDescription::Greater(desc.clone())), + OperationDescription::NumericInt( + IntElem::::dtype(), + NumericOperationDescription::Greater(desc.clone()), + ), GreaterOps::::new(desc), ); @@ -631,9 +653,10 @@ impl IntTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::NumericInt(NumericOperationDescription::GreaterElem( - desc.clone(), - )), + OperationDescription::NumericInt( + IntElem::::dtype(), + NumericOperationDescription::GreaterElem(desc.clone()), + ), GreaterElemOps::::new(desc), ); @@ -659,9 +682,10 @@ impl IntTensorOps for Fusion { }; out.client.register( vec![stream_1, stream_2], - OperationDescription::NumericInt(NumericOperationDescription::GreaterEqual( - desc.clone(), - )), + OperationDescription::NumericInt( + IntElem::::dtype(), + NumericOperationDescription::GreaterEqual(desc.clone()), + ), GreaterEqualOps::::new(desc), ); @@ -686,9 +710,10 @@ impl IntTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::NumericInt(NumericOperationDescription::GreaterEqualElem( - desc.clone(), - )), + OperationDescription::NumericInt( + IntElem::::dtype(), + NumericOperationDescription::GreaterEqualElem(desc.clone()), + ), GreaterEqualElemOps::::new(desc), ); @@ -714,7 +739,10 @@ impl IntTensorOps for Fusion { }; out.client.register( vec![stream_1, stream_2], - OperationDescription::NumericInt(NumericOperationDescription::Lower(desc.clone())), + OperationDescription::NumericInt( + IntElem::::dtype(), + NumericOperationDescription::Lower(desc.clone()), + ), LowerOps::::new(desc), ); @@ -739,7 +767,10 @@ impl IntTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::NumericInt(NumericOperationDescription::LowerElem(desc.clone())), + OperationDescription::NumericInt( + IntElem::::dtype(), + NumericOperationDescription::LowerElem(desc.clone()), + ), LowerElemOps::::new(desc), ); @@ -765,7 +796,10 @@ impl IntTensorOps for Fusion { }; out.client.register( vec![stream_1, stream_2], - OperationDescription::NumericInt(NumericOperationDescription::LowerEqual(desc.clone())), + OperationDescription::NumericInt( + IntElem::::dtype(), + NumericOperationDescription::LowerEqual(desc.clone()), + ), LowerEqualOps::::new(desc), ); @@ -790,9 +824,10 @@ impl IntTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::NumericInt(NumericOperationDescription::LowerEqualElem( - desc.clone(), - )), + OperationDescription::NumericInt( + IntElem::::dtype(), + NumericOperationDescription::LowerEqualElem(desc.clone()), + ), LowerEqualElemOps::::new(desc), ); @@ -819,7 +854,10 @@ impl IntTensorOps for Fusion { }; out.client.register( vec![stream_1, stream_2], - repr::OperationDescription::NumericInt(NumericOperationDescription::Add(desc.clone())), + repr::OperationDescription::NumericInt( + IntElem::::dtype(), + NumericOperationDescription::Add(desc.clone()), + ), AddOps::::new(desc), ); @@ -844,9 +882,10 @@ impl IntTensorOps for Fusion { }; out.client.register( vec![stream], - repr::OperationDescription::NumericInt(NumericOperationDescription::AddScalar( - desc.clone(), - )), + repr::OperationDescription::NumericInt( + IntElem::::dtype(), + NumericOperationDescription::AddScalar(desc.clone()), + ), AddOps::::new(desc), ); @@ -873,7 +912,10 @@ impl IntTensorOps for Fusion { }; out.client.register( vec![stream_1, stream_2], - repr::OperationDescription::NumericInt(NumericOperationDescription::Sub(desc.clone())), + repr::OperationDescription::NumericInt( + IntElem::::dtype(), + NumericOperationDescription::Sub(desc.clone()), + ), SubOps::::new(desc), ); @@ -898,9 +940,10 @@ impl IntTensorOps for Fusion { }; out.client.register( vec![stream], - repr::OperationDescription::NumericInt(NumericOperationDescription::SubScalar( - desc.clone(), - )), + repr::OperationDescription::NumericInt( + IntElem::::dtype(), + NumericOperationDescription::SubScalar(desc.clone()), + ), SubOps::::new(desc), ); @@ -927,7 +970,10 @@ impl IntTensorOps for Fusion { }; out.client.register( vec![stream_1, stream_2], - repr::OperationDescription::NumericInt(NumericOperationDescription::Mul(desc.clone())), + repr::OperationDescription::NumericInt( + IntElem::::dtype(), + NumericOperationDescription::Mul(desc.clone()), + ), MulOps::::new(desc), ); @@ -952,9 +998,10 @@ impl IntTensorOps for Fusion { }; out.client.register( vec![stream], - repr::OperationDescription::NumericInt(NumericOperationDescription::MulScalar( - desc.clone(), - )), + repr::OperationDescription::NumericInt( + IntElem::::dtype(), + NumericOperationDescription::MulScalar(desc.clone()), + ), MulOps::::new(desc), ); @@ -981,7 +1028,10 @@ impl IntTensorOps for Fusion { }; out.client.register( vec![stream_1, stream_2], - repr::OperationDescription::NumericInt(NumericOperationDescription::Div(desc.clone())), + repr::OperationDescription::NumericInt( + IntElem::::dtype(), + NumericOperationDescription::Div(desc.clone()), + ), DivOps::::new(desc), ); @@ -1006,9 +1056,10 @@ impl IntTensorOps for Fusion { }; out.client.register( vec![stream], - repr::OperationDescription::NumericInt(NumericOperationDescription::DivScalar( - desc.clone(), - )), + repr::OperationDescription::NumericInt( + IntElem::::dtype(), + NumericOperationDescription::DivScalar(desc.clone()), + ), DivOps::::new(desc), ); @@ -1033,9 +1084,10 @@ impl IntTensorOps for Fusion { }; out.client.register( vec![stream], - repr::OperationDescription::NumericInt(NumericOperationDescription::RemScalar( - desc.clone(), - )), + repr::OperationDescription::NumericInt( + IntElem::::dtype(), + NumericOperationDescription::RemScalar(desc.clone()), + ), ModOps::::new(desc), ); @@ -1064,7 +1116,10 @@ impl IntTensorOps for Fusion { let desc = out.to_description_out(); client.register( vec![stream], - OperationDescription::NumericInt(NumericOperationDescription::Zeros(desc.clone())), + OperationDescription::NumericInt( + IntElem::::dtype(), + NumericOperationDescription::Zeros(desc.clone()), + ), ZerosOps::::new(desc, device.clone()), ); @@ -1094,7 +1149,10 @@ impl IntTensorOps for Fusion { let desc = out.to_description_out(); client.register( vec![stream], - OperationDescription::NumericInt(NumericOperationDescription::Ones(desc.clone())), + OperationDescription::NumericInt( + IntElem::::dtype(), + NumericOperationDescription::Ones(desc.clone()), + ), OnesOps::::new(desc, device.clone()), ); @@ -1115,7 +1173,10 @@ impl IntTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::NumericInt(NumericOperationDescription::Sum(desc.clone())), + OperationDescription::NumericInt( + IntElem::::dtype(), + NumericOperationDescription::Sum(desc.clone()), + ), SumOps::::new(desc), ); @@ -1139,7 +1200,10 @@ impl IntTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::NumericInt(NumericOperationDescription::SumDim(desc.clone())), + OperationDescription::NumericInt( + IntElem::::dtype(), + NumericOperationDescription::SumDim(desc.clone()), + ), SumDimOps::::new(desc), ); @@ -1160,7 +1224,10 @@ impl IntTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::NumericInt(NumericOperationDescription::Prod(desc.clone())), + OperationDescription::NumericInt( + IntElem::::dtype(), + NumericOperationDescription::Prod(desc.clone()), + ), ProdOps::::new(desc), ); @@ -1184,7 +1251,10 @@ impl IntTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::NumericInt(NumericOperationDescription::ProdDim(desc.clone())), + OperationDescription::NumericInt( + IntElem::::dtype(), + NumericOperationDescription::ProdDim(desc.clone()), + ), ProdDimOps::::new(desc), ); @@ -1205,7 +1275,10 @@ impl IntTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::NumericInt(NumericOperationDescription::Mean(desc.clone())), + OperationDescription::NumericInt( + IntElem::::dtype(), + NumericOperationDescription::Mean(desc.clone()), + ), MeanOps::::new(desc), ); @@ -1229,7 +1302,10 @@ impl IntTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::NumericInt(NumericOperationDescription::MeanDim(desc.clone())), + OperationDescription::NumericInt( + IntElem::::dtype(), + NumericOperationDescription::MeanDim(desc.clone()), + ), MeanDimOps::::new(desc), ); @@ -1253,7 +1329,10 @@ impl IntTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::NumericInt(NumericOperationDescription::ArgMax(desc.clone())), + OperationDescription::NumericInt( + IntElem::::dtype(), + NumericOperationDescription::ArgMax(desc.clone()), + ), ArgMaxOps::::new(desc), ); @@ -1277,7 +1356,10 @@ impl IntTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::NumericInt(NumericOperationDescription::ArgMin(desc.clone())), + OperationDescription::NumericInt( + IntElem::::dtype(), + NumericOperationDescription::ArgMin(desc.clone()), + ), ArgMinOps::::new(desc), ); @@ -1316,7 +1398,10 @@ impl IntTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::NumericInt(NumericOperationDescription::Clamp(desc.clone())), + OperationDescription::NumericInt( + IntElem::::dtype(), + NumericOperationDescription::Clamp(desc.clone()), + ), ClampOps::::new(desc), ); @@ -1337,7 +1422,10 @@ impl IntTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::NumericInt(NumericOperationDescription::Abs(desc.clone())), + OperationDescription::NumericInt( + IntElem::::dtype(), + NumericOperationDescription::Abs(desc.clone()), + ), AbsOps::::new(desc), ); @@ -1433,7 +1521,10 @@ impl IntTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::NumericInt(NumericOperationDescription::Max(desc.clone())), + OperationDescription::NumericInt( + IntElem::::dtype(), + NumericOperationDescription::Max(desc.clone()), + ), MaxOps::::new(desc), ); @@ -1457,7 +1548,10 @@ impl IntTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::NumericInt(NumericOperationDescription::MaxDim(desc.clone())), + OperationDescription::NumericInt( + IntElem::::dtype(), + NumericOperationDescription::MaxDim(desc.clone()), + ), MaxDimOps::::new(desc), ); @@ -1498,9 +1592,10 @@ impl IntTensorOps for Fusion { }; client.register( vec![stream], - OperationDescription::NumericInt(NumericOperationDescription::MaxDimWithIndices( - desc.clone(), - )), + OperationDescription::NumericInt( + IntElem::::dtype(), + NumericOperationDescription::MaxDimWithIndices(desc.clone()), + ), MaxDimWithIndicesOps::::new(desc), ); @@ -1521,7 +1616,10 @@ impl IntTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::NumericInt(NumericOperationDescription::Min(desc.clone())), + OperationDescription::NumericInt( + IntElem::::dtype(), + NumericOperationDescription::Min(desc.clone()), + ), MinOps::::new(desc), ); @@ -1545,7 +1643,10 @@ impl IntTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::NumericInt(NumericOperationDescription::MinDim(desc.clone())), + OperationDescription::NumericInt( + IntElem::::dtype(), + NumericOperationDescription::MinDim(desc.clone()), + ), MinDimOps::::new(desc), ); @@ -1586,9 +1687,10 @@ impl IntTensorOps for Fusion { }; client.register( vec![stream], - OperationDescription::NumericInt(NumericOperationDescription::MinDimWithIndices( - desc.clone(), - )), + OperationDescription::NumericInt( + IntElem::::dtype(), + NumericOperationDescription::MinDimWithIndices(desc.clone()), + ), MinDimWithIndicesOps::::new(desc), ); @@ -1626,7 +1728,10 @@ impl IntTensorOps for Fusion { }; client.register( vec![stream], - OperationDescription::NumericInt(NumericOperationDescription::IntRandom(desc.clone())), + OperationDescription::NumericInt( + IntElem::::dtype(), + NumericOperationDescription::IntRandom(desc.clone()), + ), IntRandomOps::::new(desc, device.clone()), ); diff --git a/crates/burn-fusion/src/stream/context.rs b/crates/burn-fusion/src/stream/context.rs index e6012a07d..3efc35b4d 100644 --- a/crates/burn-fusion/src/stream/context.rs +++ b/crates/burn-fusion/src/stream/context.rs @@ -1,4 +1,5 @@ -use burn_tensor::{repr::*, Element, ElementConversion}; +use burn_tensor::{repr::*, DType, Element, ElementConversion}; +use half::{bf16, f16}; use hashbrown::HashMap; /// The context contains the relative graph tensor mapping so that a relative tensor id can be @@ -13,8 +14,12 @@ pub struct Context<'a, H> { pub tensors: &'a HashMap, /// Handle container to retrieve tensors based on their description. pub handles: &'a mut HandleContainer, - /// Float scalars found in the graph in the order they appeared. - pub scalar_floats: &'a Vec, + /// F32 scalars found in the graph in the order they appeared. + pub scalar_f32: &'a Vec, + /// F16 scalars found in the graph in the order they appeared. + pub scalar_f16: &'a Vec, + /// BF16 scalars found in the graph in the order they appeared. + pub scalar_bf16: &'a Vec, /// Int scalars found in the graph in the order they appeared. pub scalar_ints: &'a Vec, } @@ -26,7 +31,9 @@ pub(crate) struct OperationConverter { /// Only useful to create new shape ID. /// You should use tensor descriptions to retrieve the proper shape. shapes_global2relative: HashMap, - scalar_floats: Vec, + scalar_f32: Vec, + scalar_f16: Vec, + scalar_bf16: Vec, scalar_ints: Vec, } @@ -45,7 +52,9 @@ impl OperationConverter { Context { handles, tensors: &self.tensors_relative2global, - scalar_floats: &self.scalar_floats, + scalar_f32: &self.scalar_f32, + scalar_f16: &self.scalar_f16, + scalar_bf16: &self.scalar_bf16, scalar_ints: &self.scalar_ints, } } @@ -54,12 +63,20 @@ impl OperationConverter { self.tensors_relative2global.clear(); self.tensors_global2relative.clear(); self.shapes_global2relative.clear(); - self.scalar_floats.clear(); + self.scalar_f32.clear(); + self.scalar_f16.clear(); + self.scalar_bf16.clear(); self.scalar_ints.clear(); } - pub(crate) fn relative_float(&mut self, elem: &E) -> E { - self.scalar_floats.push(elem.elem()); + pub(crate) fn relative_float(&mut self, elem: &E, dtype: &DType) -> E { + match dtype { + burn_tensor::DType::F32 => self.scalar_f32.push(elem.elem()), + burn_tensor::DType::F16 => self.scalar_f16.push(elem.elem()), + burn_tensor::DType::BF16 => self.scalar_bf16.push(elem.elem()), + _ => todo!("Unsupported"), + } + // We return 0 so that the id from a scalar operation is the same no matter its scalar // value. 0.elem() @@ -85,19 +102,24 @@ impl RelativeOps for OperationDescription { OperationDescription::BaseBool(ops) => { OperationDescription::BaseBool(ops.to_relative(converter)) } - OperationDescription::NumericFloat(ops) => OperationDescription::NumericFloat( - ops.to_relative(converter, |converter, e| converter.relative_float(e)), + OperationDescription::NumericFloat(dtype, ops) => OperationDescription::NumericFloat( + *dtype, + ops.to_relative(converter, |converter, e| converter.relative_float(e, dtype)), ), - OperationDescription::NumericInt(ops) => OperationDescription::NumericInt( + OperationDescription::NumericInt(dtype, ops) => OperationDescription::NumericInt( + *dtype, ops.to_relative(converter, |converter, e| converter.relative_int(e)), ), OperationDescription::Bool(ops) => { OperationDescription::Bool(ops.to_relative(converter)) } OperationDescription::Int(ops) => OperationDescription::Int(ops.to_relative(converter)), - OperationDescription::Float(ops) => { - OperationDescription::Float(ops.to_relative(converter)) - } + OperationDescription::Float(dtype, ops) => OperationDescription::Float( + *dtype, + RelativeOpsScalar::::to_relative(ops, converter, |converter, e| { + converter.relative_float(e, dtype) + }), + ), OperationDescription::Module(ops) => { OperationDescription::Module(ops.to_relative(converter)) } @@ -342,8 +364,11 @@ impl RelativeOps for ModuleOperationDescription { } } -impl RelativeOps for FloatOperationDescription { - fn to_relative(&self, converter: &mut OperationConverter) -> Self { +impl RelativeOpsScalar for FloatOperationDescription { + fn to_relative(&self, converter: &mut OperationConverter, local_elem: F) -> Self + where + F: Fn(&mut OperationConverter, &f32) -> f32, + { match self { FloatOperationDescription::Exp(desc) => { FloatOperationDescription::Exp(UnaryOperationDescription { @@ -372,7 +397,7 @@ impl RelativeOps for FloatOperationDescription { FloatOperationDescription::PowfScalar(desc) => { FloatOperationDescription::PowfScalar(ScalarOperationDescription { lhs: desc.lhs.to_relative(converter), - rhs: converter.relative_float(&desc.rhs), + rhs: local_elem(converter, &desc.rhs.elem()), out: desc.out.to_relative(converter), }) } diff --git a/crates/burn-fusion/src/stream/execution/policy.rs b/crates/burn-fusion/src/stream/execution/policy.rs index 5424ec536..9eb8bab5e 100644 --- a/crates/burn-fusion/src/stream/execution/policy.rs +++ b/crates/burn-fusion/src/stream/execution/policy.rs @@ -549,10 +549,10 @@ mod tests { // Out node. self.new_empty_node(out_id); - self.operations - .push(OperationDescription::Float(FloatOperationDescription::Log( - self.unary_description(), - ))); + self.operations.push(OperationDescription::Float( + DType::F32, + FloatOperationDescription::Log(self.unary_description()), + )); } fn new_empty_node(&mut self, id: u64) { diff --git a/crates/burn-fusion/src/stream/execution/tests.rs b/crates/burn-fusion/src/stream/execution/tests.rs index 0e13546e1..3bef10a7a 100644 --- a/crates/burn-fusion/src/stream/execution/tests.rs +++ b/crates/burn-fusion/src/stream/execution/tests.rs @@ -520,8 +520,9 @@ impl<'i> StreamSegment for TestSegment<'i> { /// Just a simple operation. fn operation_1() -> OperationDescription { - OperationDescription::NumericFloat(NumericOperationDescription::Add( - BinaryOperationDescription { + OperationDescription::NumericFloat( + DType::F32, + NumericOperationDescription::Add(BinaryOperationDescription { lhs: TensorDescription { id: TensorId::new(0), shape: vec![32, 32], @@ -540,14 +541,15 @@ fn operation_1() -> OperationDescription { status: TensorStatus::NotInit, dtype: DType::F32, }, - }, - )) + }), + ) } /// Just a simple operation. fn operation_2() -> OperationDescription { - OperationDescription::NumericFloat(NumericOperationDescription::AddScalar( - ScalarOperationDescription { + OperationDescription::NumericFloat( + DType::F32, + NumericOperationDescription::AddScalar(ScalarOperationDescription { lhs: TensorDescription { id: TensorId::new(0), shape: vec![32, 32], @@ -561,24 +563,27 @@ fn operation_2() -> OperationDescription { status: TensorStatus::NotInit, dtype: DType::F32, }, - }, - )) + }), + ) } /// Just a simple operation. fn operation_3() -> OperationDescription { - OperationDescription::Float(FloatOperationDescription::Log(UnaryOperationDescription { - input: TensorDescription { - id: TensorId::new(0), - shape: vec![32, 32], - status: TensorStatus::ReadOnly, - dtype: DType::F32, - }, - out: TensorDescription { - id: TensorId::new(0), - shape: vec![32, 32], - status: TensorStatus::NotInit, - dtype: DType::F32, - }, - })) + OperationDescription::Float( + DType::F32, + FloatOperationDescription::Log(UnaryOperationDescription { + input: TensorDescription { + id: TensorId::new(0), + shape: vec![32, 32], + status: TensorStatus::ReadOnly, + dtype: DType::F32, + }, + out: TensorDescription { + id: TensorId::new(0), + shape: vec![32, 32], + status: TensorStatus::NotInit, + dtype: DType::F32, + }, + }), + ) } diff --git a/crates/burn-fusion/src/stream/store/index.rs b/crates/burn-fusion/src/stream/store/index.rs index 15bb56341..630148cd4 100644 --- a/crates/burn-fusion/src/stream/store/index.rs +++ b/crates/burn-fusion/src/stream/store/index.rs @@ -218,8 +218,9 @@ mod tests { } fn ops_1() -> OperationDescription { - OperationDescription::NumericFloat(NumericOperationDescription::Add( - BinaryOperationDescription { + OperationDescription::NumericFloat( + DType::F32, + NumericOperationDescription::Add(BinaryOperationDescription { lhs: TensorDescription { id: TensorId::new(0), shape: vec![32, 32], @@ -238,13 +239,14 @@ mod tests { status: TensorStatus::NotInit, dtype: DType::F32, }, - }, - )) + }), + ) } fn ops_2() -> OperationDescription { - OperationDescription::NumericFloat(NumericOperationDescription::AddScalar( - ScalarOperationDescription { + OperationDescription::NumericFloat( + DType::F32, + NumericOperationDescription::AddScalar(ScalarOperationDescription { lhs: TensorDescription { id: TensorId::new(0), shape: vec![32, 32], @@ -258,13 +260,14 @@ mod tests { status: TensorStatus::NotInit, dtype: DType::F32, }, - }, - )) + }), + ) } fn ops_3() -> OperationDescription { - OperationDescription::NumericFloat(NumericOperationDescription::Sub( - BinaryOperationDescription { + OperationDescription::NumericFloat( + DType::F32, + NumericOperationDescription::Sub(BinaryOperationDescription { lhs: TensorDescription { id: TensorId::new(0), shape: vec![32, 32], @@ -283,7 +286,7 @@ mod tests { status: TensorStatus::NotInit, dtype: DType::F32, }, - }, - )) + }), + ) } } diff --git a/crates/burn-jit/src/fusion/elemwise/builder.rs b/crates/burn-jit/src/fusion/elemwise/builder.rs index 2c5dddc6c..7416dabdb 100644 --- a/crates/burn-jit/src/fusion/elemwise/builder.rs +++ b/crates/burn-jit/src/fusion/elemwise/builder.rs @@ -46,19 +46,19 @@ impl OptimizationBuilder> for ElementWiseBuild return; } } - OperationDescription::Float(ops) => { + OperationDescription::Float(_dtype, ops) => { if !self.register_float(ops) { self.status = OptimizationStatus::Closed; return; } } - OperationDescription::NumericFloat(ops) => { + OperationDescription::NumericFloat(_dtype, ops) => { if !self.register_numeric::(ops) { self.status = OptimizationStatus::Closed; return; } } - OperationDescription::NumericInt(ops) => { + OperationDescription::NumericInt(_dtype, ops) => { if !self.register_numeric::(ops) { self.status = OptimizationStatus::Closed; return; @@ -390,8 +390,9 @@ impl ElementWiseBuilder { return false; } + let elem = desc.lhs.dtype.into(); let lhs = self.builder.input(&desc.lhs, Variable::AbsolutePos); - let rhs = self.builder.scalar(&desc.rhs, desc.lhs.dtype.into()); + let rhs = self.builder.scalar(&desc.rhs, elem); let out = self.builder.output(&desc.out, Variable::AbsolutePos); self.builder.register_operation(func(lhs, rhs, out)); diff --git a/crates/burn-jit/src/fusion/kernel.rs b/crates/burn-jit/src/fusion/kernel.rs index da19180ed..f504c83a4 100644 --- a/crates/burn-jit/src/fusion/kernel.rs +++ b/crates/burn-jit/src/fusion/kernel.rs @@ -148,7 +148,13 @@ impl FusionKernel { let info_size = (num_tensors * rank * 2) + 1; let mut num_handles = num_tensors + 1; - if running_info.scalars.num_float > 0 { + if running_info.scalars.num_f32 > 0 { + num_handles += 1; + } + if running_info.scalars.num_f16 > 0 { + num_handles += 1; + } + if running_info.scalars.num_bf16 > 0 { num_handles += 1; } if running_info.scalars.num_int > 0 { @@ -216,14 +222,20 @@ impl FusionKernel { bindings.push(client.create(bytemuck::cast_slice(&info)).binding()); // Finally we finish with the named bindings. - if running_info.scalars.num_float > 0 { - bindings.push( - client - .create(bytemuck::cast_slice( - &context.scalar_floats[0..running_info.scalars.num_float], - )) - .binding(), - ); + if running_info.scalars.num_f32 > 0 { + let bytes = bytemuck::cast_slice(&context.scalar_f32[0..running_info.scalars.num_f32]); + bindings.push(client.create(bytes).binding()); + } + + if running_info.scalars.num_f16 > 0 { + let bytes = bytemuck::cast_slice(&context.scalar_f16[0..running_info.scalars.num_f16]); + bindings.push(client.create(bytes).binding()); + } + + if running_info.scalars.num_bf16 > 0 { + let bytes = + bytemuck::cast_slice(&context.scalar_bf16[0..running_info.scalars.num_bf16]); + bindings.push(client.create(bytes).binding()); } if running_info.scalars.num_int > 0 { diff --git a/crates/burn-jit/src/fusion/tracing/base.rs b/crates/burn-jit/src/fusion/tracing/base.rs index 5e2481f94..8f971f150 100644 --- a/crates/burn-jit/src/fusion/tracing/base.rs +++ b/crates/burn-jit/src/fusion/tracing/base.rs @@ -1,7 +1,10 @@ use serde::{Deserialize, Serialize}; + #[derive(Default, Clone, Serialize, Deserialize)] pub struct Scalars { - pub(crate) num_float: usize, + pub(crate) num_f32: usize, + pub(crate) num_f16: usize, + pub(crate) num_bf16: usize, pub(crate) num_int: usize, pub(crate) num_uint: usize, pub(crate) num_bool: usize, diff --git a/crates/burn-jit/src/fusion/tracing/builder.rs b/crates/burn-jit/src/fusion/tracing/builder.rs index 2387534a6..44ea2cab7 100644 --- a/crates/burn-jit/src/fusion/tracing/builder.rs +++ b/crates/burn-jit/src/fusion/tracing/builder.rs @@ -103,13 +103,33 @@ impl TraceBuilder { /// Create a variable from an input [scalar](Element). pub fn scalar(&mut self, _value: &E, elem_type: Elem) -> Variable { match elem_type { - Elem::Float(_) => { - let var = self - .scope - .read_scalar(self.scalars.num_float as u16, elem_type); - self.scalars.num_float += 1; - var - } + Elem::Float(kind) => match kind { + cubecl::ir::FloatKind::F16 => { + let var = self + .scope + .read_scalar(self.scalars.num_f16 as u16, elem_type); + + self.scalars.num_f16 += 1; + var + } + cubecl::ir::FloatKind::F32 => { + let var = self + .scope + .read_scalar(self.scalars.num_f32 as u16, elem_type); + + self.scalars.num_f32 += 1; + var + } + cubecl::ir::FloatKind::BF16 => { + let var = self + .scope + .read_scalar(self.scalars.num_bf16 as u16, elem_type); + + self.scalars.num_bf16 += 1; + var + } + cubecl::ir::FloatKind::F64 => todo!(), + }, Elem::Int(_) => { let var = self .scope diff --git a/crates/burn-jit/src/fusion/tracing/trace.rs b/crates/burn-jit/src/fusion/tracing/trace.rs index 2d008f52c..68ea6e040 100644 --- a/crates/burn-jit/src/fusion/tracing/trace.rs +++ b/crates/burn-jit/src/fusion/tracing/trace.rs @@ -1,7 +1,7 @@ use super::Scalars; use burn_tensor::repr::TensorDescription; use cubecl::{ - ir::{Elem, FloatKind, IntKind, Item, Scope, Variable, Visibility}, + ir::{Elem, IntKind, Item, Scope, Variable, Visibility}, InputInfo, KernelExpansion, OutputInfo, }; use serde::{Deserialize, Serialize}; @@ -64,10 +64,23 @@ impl Trace { .collect::>(); // NOTE: we might want to pass a struct including all inputs/outputs metadata instead of 3 arrays - if self.scalars.num_float > 0 { + if self.scalars.num_f32 > 0 { inputs.push(InputInfo::Scalar { - elem: Elem::Float(FloatKind::F32), - size: self.scalars.num_float, + elem: Elem::Float(cubecl::ir::FloatKind::F32), + size: self.scalars.num_f32, + }) + } + if self.scalars.num_f16 > 0 { + inputs.push(InputInfo::Scalar { + elem: Elem::Float(cubecl::ir::FloatKind::F16), + size: self.scalars.num_f16, + }) + } + + if self.scalars.num_bf16 > 0 { + inputs.push(InputInfo::Scalar { + elem: Elem::Float(cubecl::ir::FloatKind::BF16), + size: self.scalars.num_bf16, }) } diff --git a/crates/burn-tensor/src/repr/operation.rs b/crates/burn-tensor/src/repr/operation.rs index 00591e20e..a314da1de 100644 --- a/crates/burn-tensor/src/repr/operation.rs +++ b/crates/burn-tensor/src/repr/operation.rs @@ -4,7 +4,7 @@ use std::ops::Range; use crate::{ ops::{ConvOptions, ConvTransposeOptions, InterpolateMode, InterpolateOptions}, repr::tensor::TensorDescription, - Distribution, Element, + DType, Distribution, Element, }; /// Describe all tensor operations possible. @@ -17,15 +17,15 @@ pub enum OperationDescription { /// Basic operation on a bool tensor. BaseBool(BaseOperationDescription), /// Numeric operation on a float tensor. - NumericFloat(NumericOperationDescription), + NumericFloat(DType, NumericOperationDescription), /// Numeric operation on an int tensor. - NumericInt(NumericOperationDescription), + NumericInt(DType, NumericOperationDescription), /// Operation specific to a bool tensor. Bool(BoolOperationDescription), /// Operation specific to an int tensor. Int(IntOperationDescription), /// Operation specific to a float tensor. - Float(FloatOperationDescription), + Float(DType, FloatOperationDescription), /// Module operation. Module(ModuleOperationDescription), } @@ -1149,11 +1149,11 @@ impl OperationDescription { OperationDescription::BaseFloat(ops) => ops.nodes(), OperationDescription::BaseInt(ops) => ops.nodes(), OperationDescription::BaseBool(ops) => ops.nodes(), - OperationDescription::NumericFloat(ops) => ops.nodes(), - OperationDescription::NumericInt(ops) => ops.nodes(), + OperationDescription::NumericFloat(_dtype, ops) => ops.nodes(), + OperationDescription::NumericInt(_dtype, ops) => ops.nodes(), OperationDescription::Bool(ops) => ops.nodes(), OperationDescription::Int(ops) => ops.nodes(), - OperationDescription::Float(ops) => ops.nodes(), + OperationDescription::Float(_dtype, ops) => ops.nodes(), OperationDescription::Module(ops) => ops.nodes(), } }