Fusion mix precision (#2247)

This commit is contained in:
Nathaniel Simard 2024-09-05 10:53:26 -04:00 committed by GitHub
parent fc311323d9
commit a567c6e888
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 563 additions and 241 deletions

1
Cargo.lock generated
View File

@ -612,6 +612,7 @@ dependencies = [
"burn-common",
"burn-tensor",
"derive-new",
"half",
"hashbrown 0.14.5",
"log",
"serde",

View File

@ -23,6 +23,7 @@ derive-new = {workspace = true }
spin = { workspace = true }
log = { workspace = true }
serde = { workspace = true }
half = { workspace = true }
[package.metadata.docs.rs]
features = ["doc"]

View File

@ -62,7 +62,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
};
client.register(
vec![stream],
OperationDescription::Float(FloatOperationDescription::Random(desc.clone())),
OperationDescription::Float(
FloatElem::<Self>::dtype(),
FloatOperationDescription::Random(desc.clone()),
),
RandomOps::<B, D>::new(desc, device.clone()),
);
@ -92,7 +95,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
let desc = out.to_description_out();
client.register(
vec![stream],
OperationDescription::NumericFloat(NumericOperationDescription::Zeros(desc.clone())),
OperationDescription::NumericFloat(
FloatElem::<Self>::dtype(),
NumericOperationDescription::Zeros(desc.clone()),
),
ZerosOps::<B, D>::new(desc, device.clone()),
);
@ -122,7 +128,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
let desc = out.to_description_out();
client.register(
vec![stream],
OperationDescription::NumericFloat(NumericOperationDescription::Ones(desc.clone())),
OperationDescription::NumericFloat(
FloatElem::<Self>::dtype(),
NumericOperationDescription::Ones(desc.clone()),
),
OnesOps::<B, D>::new(desc, device.clone()),
);
@ -158,7 +167,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
let desc = (out.to_description_out(), fill_value.elem::<f32>());
client.register(
vec![stream],
OperationDescription::NumericFloat(NumericOperationDescription::Full(desc.clone())),
OperationDescription::NumericFloat(
FloatElem::<Self>::dtype(),
NumericOperationDescription::Full(desc.clone()),
),
FullOps::<B, D>::new(desc.0, desc.1, device.clone()),
);
@ -226,7 +238,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::Float(FloatOperationDescription::IntoInt(desc.clone())),
OperationDescription::Float(
FloatElem::<Self>::dtype(),
FloatOperationDescription::IntoInt(desc.clone()),
),
IntoIntOps::<B, D>::new(desc),
);
@ -267,7 +282,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
out.client.register(
vec![stream_1, stream_2],
OperationDescription::NumericFloat(NumericOperationDescription::Add(desc.clone())),
OperationDescription::NumericFloat(
FloatElem::<Self>::dtype(),
NumericOperationDescription::Add(desc.clone()),
),
AddOps::<B, D>::new(desc),
);
@ -292,9 +310,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::NumericFloat(NumericOperationDescription::AddScalar(
desc.clone(),
)),
OperationDescription::NumericFloat(
FloatElem::<Self>::dtype(),
NumericOperationDescription::AddScalar(desc.clone()),
),
AddOps::<B, D>::new(desc),
);
@ -334,7 +353,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::NumericFloat(NumericOperationDescription::Clamp(desc.clone())),
OperationDescription::NumericFloat(
FloatElem::<Self>::dtype(),
NumericOperationDescription::Clamp(desc.clone()),
),
ClampOps::<B, D>::new(desc),
);
@ -361,7 +383,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream_1, stream_2],
OperationDescription::NumericFloat(NumericOperationDescription::Sub(desc.clone())),
OperationDescription::NumericFloat(
FloatElem::<Self>::dtype(),
NumericOperationDescription::Sub(desc.clone()),
),
SubOps::<B, D>::new(desc),
);
@ -386,9 +411,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
out.client.register(
vec![stream],
OperationDescription::NumericFloat(NumericOperationDescription::SubScalar(
desc.clone(),
)),
OperationDescription::NumericFloat(
FloatElem::<Self>::dtype(),
NumericOperationDescription::SubScalar(desc.clone()),
),
SubOps::<B, D>::new(desc),
);
@ -415,7 +441,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream_1, stream_2],
OperationDescription::NumericFloat(NumericOperationDescription::Mul(desc.clone())),
OperationDescription::NumericFloat(
FloatElem::<Self>::dtype(),
NumericOperationDescription::Mul(desc.clone()),
),
MulOps::<B, D>::new(desc),
);
@ -440,9 +469,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::NumericFloat(NumericOperationDescription::MulScalar(
desc.clone(),
)),
OperationDescription::NumericFloat(
FloatElem::<Self>::dtype(),
NumericOperationDescription::MulScalar(desc.clone()),
),
MulOps::<B, D>::new(desc),
);
@ -469,7 +499,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream_1, stream_2],
OperationDescription::NumericFloat(NumericOperationDescription::Div(desc.clone())),
OperationDescription::NumericFloat(
FloatElem::<Self>::dtype(),
NumericOperationDescription::Div(desc.clone()),
),
DivOps::<B, D>::new(desc),
);
@ -494,9 +527,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::NumericFloat(NumericOperationDescription::DivScalar(
desc.clone(),
)),
OperationDescription::NumericFloat(
FloatElem::<Self>::dtype(),
NumericOperationDescription::DivScalar(desc.clone()),
),
DivOps::<B, D>::new(desc),
);
@ -521,9 +555,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::NumericFloat(NumericOperationDescription::RemScalar(
desc.clone(),
)),
OperationDescription::NumericFloat(
FloatElem::<Self>::dtype(),
NumericOperationDescription::RemScalar(desc.clone()),
),
ModOps::<B, D>::new(desc),
);
@ -554,7 +589,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
out.client.register(
vec![stream_1, stream_2],
OperationDescription::Float(FloatOperationDescription::Matmul(desc.clone())),
OperationDescription::Float(
FloatElem::<Self>::dtype(),
FloatOperationDescription::Matmul(desc.clone()),
),
MatmulOps::<B, D>::new(desc),
);
@ -680,7 +718,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream_1, stream_2],
OperationDescription::NumericFloat(NumericOperationDescription::Gather(desc.clone())),
OperationDescription::NumericFloat(
FloatElem::<Self>::dtype(),
NumericOperationDescription::Gather(desc.clone()),
),
GatherOps::<B, D>::new(desc),
);
@ -729,7 +770,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
out.client.register(
vec![stream_1, stream_2, stream_3],
OperationDescription::NumericFloat(NumericOperationDescription::Scatter(desc.clone())),
OperationDescription::NumericFloat(
FloatElem::<Self>::dtype(),
NumericOperationDescription::Scatter(desc.clone()),
),
ScatterOps::<B, D>::new(desc),
);
@ -773,7 +817,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream_1, stream_2],
OperationDescription::NumericFloat(NumericOperationDescription::Select(desc.clone())),
OperationDescription::NumericFloat(
FloatElem::<Self>::dtype(),
NumericOperationDescription::Select(desc.clone()),
),
SelectOps::<B, D>::new(desc),
);
@ -821,9 +868,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream_1, stream_2, stream_3],
OperationDescription::NumericFloat(NumericOperationDescription::SelectAssign(
desc.clone(),
)),
OperationDescription::NumericFloat(
FloatElem::<Self>::dtype(),
NumericOperationDescription::SelectAssign(desc.clone()),
),
SelectAssignOps::<B, D>::new(desc),
);
@ -966,9 +1014,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream_1, stream_2, stream_3],
OperationDescription::NumericFloat(NumericOperationDescription::MaskWhere(
desc.clone(),
)),
OperationDescription::NumericFloat(
FloatElem::<Self>::dtype(),
NumericOperationDescription::MaskWhere(desc.clone()),
),
MaskWhereOps::<B, D>::new(desc),
);
@ -1011,7 +1060,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream_1, stream_2],
OperationDescription::NumericFloat(NumericOperationDescription::MaskFill(desc.clone())),
OperationDescription::NumericFloat(
FloatElem::<Self>::dtype(),
NumericOperationDescription::MaskFill(desc.clone()),
),
MaskFillOps::<B, D>::new(desc),
);
@ -1062,9 +1114,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::NumericFloat(NumericOperationDescription::EqualElem(
desc.clone(),
)),
OperationDescription::NumericFloat(
FloatElem::<Self>::dtype(),
NumericOperationDescription::EqualElem(desc.clone()),
),
EqualElemOps::<B, D>::new(desc),
);
@ -1090,7 +1143,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream_1, stream_2],
OperationDescription::NumericFloat(NumericOperationDescription::Greater(desc.clone())),
OperationDescription::NumericFloat(
FloatElem::<Self>::dtype(),
NumericOperationDescription::Greater(desc.clone()),
),
GreaterOps::<B, D>::new(desc),
);
@ -1115,9 +1171,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::NumericFloat(NumericOperationDescription::GreaterElem(
desc.clone(),
)),
OperationDescription::NumericFloat(
FloatElem::<Self>::dtype(),
NumericOperationDescription::GreaterElem(desc.clone()),
),
GreaterElemOps::<B, D>::new(desc),
);
@ -1143,9 +1200,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream_1, stream_2],
OperationDescription::NumericFloat(NumericOperationDescription::GreaterEqual(
desc.clone(),
)),
OperationDescription::NumericFloat(
FloatElem::<Self>::dtype(),
NumericOperationDescription::GreaterEqual(desc.clone()),
),
GreaterEqualOps::<B, D>::new(desc),
);
@ -1170,9 +1228,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::NumericFloat(NumericOperationDescription::GreaterEqualElem(
desc.clone(),
)),
OperationDescription::NumericFloat(
FloatElem::<Self>::dtype(),
NumericOperationDescription::GreaterEqualElem(desc.clone()),
),
GreaterEqualElemOps::<B, D>::new(desc),
);
@ -1198,7 +1257,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream_1, stream_2],
OperationDescription::NumericFloat(NumericOperationDescription::Lower(desc.clone())),
OperationDescription::NumericFloat(
FloatElem::<Self>::dtype(),
NumericOperationDescription::Lower(desc.clone()),
),
LowerOps::<B, D>::new(desc),
);
@ -1223,9 +1285,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::NumericFloat(NumericOperationDescription::LowerElem(
desc.clone(),
)),
OperationDescription::NumericFloat(
FloatElem::<Self>::dtype(),
NumericOperationDescription::LowerElem(desc.clone()),
),
LowerElemOps::<B, D>::new(desc),
);
@ -1251,9 +1314,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream_1, stream_2],
OperationDescription::NumericFloat(NumericOperationDescription::LowerEqual(
desc.clone(),
)),
OperationDescription::NumericFloat(
FloatElem::<Self>::dtype(),
NumericOperationDescription::LowerEqual(desc.clone()),
),
LowerEqualOps::<B, D>::new(desc),
);
@ -1278,9 +1342,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::NumericFloat(NumericOperationDescription::LowerEqualElem(
desc.clone(),
)),
OperationDescription::NumericFloat(
FloatElem::<Self>::dtype(),
NumericOperationDescription::LowerEqualElem(desc.clone()),
),
LowerEqualElemOps::<B, D>::new(desc),
);
@ -1301,7 +1366,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::NumericFloat(NumericOperationDescription::Sum(desc.clone())),
OperationDescription::NumericFloat(
FloatElem::<Self>::dtype(),
NumericOperationDescription::Sum(desc.clone()),
),
SumOps::<B, D>::new(desc),
);
@ -1328,7 +1396,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::NumericFloat(NumericOperationDescription::SumDim(desc.clone())),
OperationDescription::NumericFloat(
FloatElem::<Self>::dtype(),
NumericOperationDescription::SumDim(desc.clone()),
),
SumDimOps::<B, D>::new(desc),
);
@ -1349,7 +1420,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::NumericFloat(NumericOperationDescription::Mean(desc.clone())),
OperationDescription::NumericFloat(
FloatElem::<Self>::dtype(),
NumericOperationDescription::Mean(desc.clone()),
),
MeanOps::<B, D>::new(desc),
);
@ -1376,7 +1450,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::NumericFloat(NumericOperationDescription::MeanDim(desc.clone())),
OperationDescription::NumericFloat(
FloatElem::<Self>::dtype(),
NumericOperationDescription::MeanDim(desc.clone()),
),
MeanDimOps::<B, D>::new(desc),
);
@ -1397,7 +1474,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::Float(FloatOperationDescription::Exp(desc.clone())),
OperationDescription::Float(
FloatElem::<Self>::dtype(),
FloatOperationDescription::Exp(desc.clone()),
),
ExpOps::<B, D>::new(desc),
);
@ -1418,7 +1498,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::Float(FloatOperationDescription::Log(desc.clone())),
OperationDescription::Float(
FloatElem::<Self>::dtype(),
FloatOperationDescription::Log(desc.clone()),
),
LogOps::<B, D>::new(desc),
);
@ -1439,7 +1522,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::Float(FloatOperationDescription::Log1p(desc.clone())),
OperationDescription::Float(
FloatElem::<Self>::dtype(),
FloatOperationDescription::Log1p(desc.clone()),
),
Log1pOps::<B, D>::new(desc),
);
@ -1464,7 +1550,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::Float(FloatOperationDescription::PowfScalar(desc.clone())),
OperationDescription::Float(
FloatElem::<Self>::dtype(),
FloatOperationDescription::PowfScalar(desc.clone()),
),
PowfOps::<B, D>::new(desc),
);
@ -1485,7 +1574,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::Float(FloatOperationDescription::Sqrt(desc.clone())),
OperationDescription::Float(
FloatElem::<Self>::dtype(),
FloatOperationDescription::Sqrt(desc.clone()),
),
SqrtOps::<B, D>::new(desc),
);
@ -1506,7 +1598,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::NumericFloat(NumericOperationDescription::Abs(desc.clone())),
OperationDescription::NumericFloat(
FloatElem::<Self>::dtype(),
NumericOperationDescription::Abs(desc.clone()),
),
AbsOps::<B, D>::new(desc),
);
@ -1527,7 +1622,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::Float(FloatOperationDescription::Cos(desc.clone())),
OperationDescription::Float(
FloatElem::<Self>::dtype(),
FloatOperationDescription::Cos(desc.clone()),
),
CosOps::<B, D>::new(desc),
);
@ -1548,7 +1646,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::Float(FloatOperationDescription::Sin(desc.clone())),
OperationDescription::Float(
FloatElem::<Self>::dtype(),
FloatOperationDescription::Sin(desc.clone()),
),
SinOps::<B, D>::new(desc),
);
@ -1569,7 +1670,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::Float(FloatOperationDescription::Tanh(desc.clone())),
OperationDescription::Float(
FloatElem::<Self>::dtype(),
FloatOperationDescription::Tanh(desc.clone()),
),
TanhOps::<B, D>::new(desc),
);
@ -1589,7 +1693,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::Float(FloatOperationDescription::Recip(desc.clone())),
OperationDescription::Float(
FloatElem::<Self>::dtype(),
FloatOperationDescription::Recip(desc.clone()),
),
Recip::<B, D>::new(desc),
);
@ -1610,7 +1717,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::Float(FloatOperationDescription::Erf(desc.clone())),
OperationDescription::Float(
FloatElem::<Self>::dtype(),
FloatOperationDescription::Erf(desc.clone()),
),
TanhOps::<B, D>::new(desc),
);
@ -1689,7 +1799,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::NumericFloat(NumericOperationDescription::ArgMax(desc.clone())),
OperationDescription::NumericFloat(
FloatElem::<Self>::dtype(),
NumericOperationDescription::ArgMax(desc.clone()),
),
ArgMaxOps::<B, D>::new(desc),
);
@ -1759,7 +1872,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::NumericFloat(NumericOperationDescription::ArgMin(desc.clone())),
OperationDescription::NumericFloat(
FloatElem::<Self>::dtype(),
NumericOperationDescription::ArgMin(desc.clone()),
),
ArgMinOps::<B, D>::new(desc),
);
@ -1780,7 +1896,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::NumericFloat(NumericOperationDescription::Max(desc.clone())),
OperationDescription::NumericFloat(
FloatElem::<Self>::dtype(),
NumericOperationDescription::Max(desc.clone()),
),
MaxOps::<B, D>::new(desc),
);
@ -1807,7 +1926,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::NumericFloat(NumericOperationDescription::MaxDim(desc.clone())),
OperationDescription::NumericFloat(
FloatElem::<Self>::dtype(),
NumericOperationDescription::MaxDim(desc.clone()),
),
MaxDimOps::<B, D>::new(desc),
);
@ -1849,9 +1971,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
};
client.register(
vec![stream],
OperationDescription::NumericFloat(NumericOperationDescription::MaxDimWithIndices(
desc.clone(),
)),
OperationDescription::NumericFloat(
FloatElem::<Self>::dtype(),
NumericOperationDescription::MaxDimWithIndices(desc.clone()),
),
MaxDimWithIndicesOps::<B, D>::new(desc),
);
@ -1872,7 +1995,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::NumericFloat(NumericOperationDescription::Min(desc.clone())),
OperationDescription::NumericFloat(
FloatElem::<Self>::dtype(),
NumericOperationDescription::Min(desc.clone()),
),
MinOps::<B, D>::new(desc),
);
@ -1899,7 +2025,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::NumericFloat(NumericOperationDescription::MinDim(desc.clone())),
OperationDescription::NumericFloat(
FloatElem::<Self>::dtype(),
NumericOperationDescription::MinDim(desc.clone()),
),
MinDimOps::<B, D>::new(desc),
);
@ -1941,9 +2070,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
};
client.register(
vec![stream],
OperationDescription::NumericFloat(NumericOperationDescription::MinDimWithIndices(
desc.clone(),
)),
OperationDescription::NumericFloat(
FloatElem::<Self>::dtype(),
NumericOperationDescription::MinDimWithIndices(desc.clone()),
),
MinDimWithIndicesOps::<B, D>::new(desc),
);
@ -1970,7 +2100,10 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream_1, stream_2],
OperationDescription::NumericFloat(NumericOperationDescription::Powf(desc.clone())),
OperationDescription::NumericFloat(
FloatElem::<Self>::dtype(),
NumericOperationDescription::Powf(desc.clone()),
),
PowOps::<B, D>::new(desc),
);

View File

@ -255,7 +255,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream_1, stream_2, stream_3],
OperationDescription::NumericInt(NumericOperationDescription::MaskWhere(desc.clone())),
OperationDescription::NumericInt(
IntElem::<Self>::dtype(),
NumericOperationDescription::MaskWhere(desc.clone()),
),
MaskWhereOps::<B, D>::new(desc),
);
@ -298,7 +301,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream_1, stream_2],
OperationDescription::NumericInt(NumericOperationDescription::MaskFill(desc.clone())),
OperationDescription::NumericInt(
IntElem::<Self>::dtype(),
NumericOperationDescription::MaskFill(desc.clone()),
),
MaskFillOps::<B, D>::new(desc),
);
@ -340,7 +346,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream_1, stream_2],
OperationDescription::NumericInt(NumericOperationDescription::Gather(desc.clone())),
OperationDescription::NumericInt(
IntElem::<Self>::dtype(),
NumericOperationDescription::Gather(desc.clone()),
),
GatherOps::<B, D>::new(desc),
);
@ -387,7 +396,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream_1, stream_2, stream_3],
OperationDescription::NumericInt(NumericOperationDescription::Scatter(desc.clone())),
OperationDescription::NumericInt(
IntElem::<Self>::dtype(),
NumericOperationDescription::Scatter(desc.clone()),
),
ScatterOps::<B, D>::new(desc),
);
@ -431,7 +443,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream_1, stream_2],
OperationDescription::NumericInt(NumericOperationDescription::Select(desc.clone())),
OperationDescription::NumericInt(
IntElem::<Self>::dtype(),
NumericOperationDescription::Select(desc.clone()),
),
SelectOps::<B, D>::new(desc),
);
@ -478,9 +493,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream_1, stream_2, stream_3],
OperationDescription::NumericInt(NumericOperationDescription::SelectAssign(
desc.clone(),
)),
OperationDescription::NumericInt(
IntElem::<Self>::dtype(),
NumericOperationDescription::SelectAssign(desc.clone()),
),
SelectAssignOps::<B, D>::new(desc),
);
@ -580,7 +596,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::NumericInt(NumericOperationDescription::EqualElem(desc.clone())),
OperationDescription::NumericInt(
IntElem::<Self>::dtype(),
NumericOperationDescription::EqualElem(desc.clone()),
),
EqualElemOps::<B, D>::new(desc),
);
@ -606,7 +625,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream_1, stream_2],
OperationDescription::NumericInt(NumericOperationDescription::Greater(desc.clone())),
OperationDescription::NumericInt(
IntElem::<Self>::dtype(),
NumericOperationDescription::Greater(desc.clone()),
),
GreaterOps::<B, D>::new(desc),
);
@ -631,9 +653,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::NumericInt(NumericOperationDescription::GreaterElem(
desc.clone(),
)),
OperationDescription::NumericInt(
IntElem::<Self>::dtype(),
NumericOperationDescription::GreaterElem(desc.clone()),
),
GreaterElemOps::<B, D>::new(desc),
);
@ -659,9 +682,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream_1, stream_2],
OperationDescription::NumericInt(NumericOperationDescription::GreaterEqual(
desc.clone(),
)),
OperationDescription::NumericInt(
IntElem::<Self>::dtype(),
NumericOperationDescription::GreaterEqual(desc.clone()),
),
GreaterEqualOps::<B, D>::new(desc),
);
@ -686,9 +710,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::NumericInt(NumericOperationDescription::GreaterEqualElem(
desc.clone(),
)),
OperationDescription::NumericInt(
IntElem::<Self>::dtype(),
NumericOperationDescription::GreaterEqualElem(desc.clone()),
),
GreaterEqualElemOps::<B, D>::new(desc),
);
@ -714,7 +739,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream_1, stream_2],
OperationDescription::NumericInt(NumericOperationDescription::Lower(desc.clone())),
OperationDescription::NumericInt(
IntElem::<Self>::dtype(),
NumericOperationDescription::Lower(desc.clone()),
),
LowerOps::<B, D>::new(desc),
);
@ -739,7 +767,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::NumericInt(NumericOperationDescription::LowerElem(desc.clone())),
OperationDescription::NumericInt(
IntElem::<Self>::dtype(),
NumericOperationDescription::LowerElem(desc.clone()),
),
LowerElemOps::<B, D>::new(desc),
);
@ -765,7 +796,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream_1, stream_2],
OperationDescription::NumericInt(NumericOperationDescription::LowerEqual(desc.clone())),
OperationDescription::NumericInt(
IntElem::<Self>::dtype(),
NumericOperationDescription::LowerEqual(desc.clone()),
),
LowerEqualOps::<B, D>::new(desc),
);
@ -790,9 +824,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::NumericInt(NumericOperationDescription::LowerEqualElem(
desc.clone(),
)),
OperationDescription::NumericInt(
IntElem::<Self>::dtype(),
NumericOperationDescription::LowerEqualElem(desc.clone()),
),
LowerEqualElemOps::<B, D>::new(desc),
);
@ -819,7 +854,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream_1, stream_2],
repr::OperationDescription::NumericInt(NumericOperationDescription::Add(desc.clone())),
repr::OperationDescription::NumericInt(
IntElem::<Self>::dtype(),
NumericOperationDescription::Add(desc.clone()),
),
AddOps::<B, D>::new(desc),
);
@ -844,9 +882,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
repr::OperationDescription::NumericInt(NumericOperationDescription::AddScalar(
desc.clone(),
)),
repr::OperationDescription::NumericInt(
IntElem::<Self>::dtype(),
NumericOperationDescription::AddScalar(desc.clone()),
),
AddOps::<B, D>::new(desc),
);
@ -873,7 +912,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream_1, stream_2],
repr::OperationDescription::NumericInt(NumericOperationDescription::Sub(desc.clone())),
repr::OperationDescription::NumericInt(
IntElem::<Self>::dtype(),
NumericOperationDescription::Sub(desc.clone()),
),
SubOps::<B, D>::new(desc),
);
@ -898,9 +940,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
repr::OperationDescription::NumericInt(NumericOperationDescription::SubScalar(
desc.clone(),
)),
repr::OperationDescription::NumericInt(
IntElem::<Self>::dtype(),
NumericOperationDescription::SubScalar(desc.clone()),
),
SubOps::<B, D>::new(desc),
);
@ -927,7 +970,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream_1, stream_2],
repr::OperationDescription::NumericInt(NumericOperationDescription::Mul(desc.clone())),
repr::OperationDescription::NumericInt(
IntElem::<Self>::dtype(),
NumericOperationDescription::Mul(desc.clone()),
),
MulOps::<B, D>::new(desc),
);
@ -952,9 +998,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
repr::OperationDescription::NumericInt(NumericOperationDescription::MulScalar(
desc.clone(),
)),
repr::OperationDescription::NumericInt(
IntElem::<Self>::dtype(),
NumericOperationDescription::MulScalar(desc.clone()),
),
MulOps::<B, D>::new(desc),
);
@ -981,7 +1028,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream_1, stream_2],
repr::OperationDescription::NumericInt(NumericOperationDescription::Div(desc.clone())),
repr::OperationDescription::NumericInt(
IntElem::<Self>::dtype(),
NumericOperationDescription::Div(desc.clone()),
),
DivOps::<B, D>::new(desc),
);
@ -1006,9 +1056,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
repr::OperationDescription::NumericInt(NumericOperationDescription::DivScalar(
desc.clone(),
)),
repr::OperationDescription::NumericInt(
IntElem::<Self>::dtype(),
NumericOperationDescription::DivScalar(desc.clone()),
),
DivOps::<B, D>::new(desc),
);
@ -1033,9 +1084,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
repr::OperationDescription::NumericInt(NumericOperationDescription::RemScalar(
desc.clone(),
)),
repr::OperationDescription::NumericInt(
IntElem::<Self>::dtype(),
NumericOperationDescription::RemScalar(desc.clone()),
),
ModOps::<B, D>::new(desc),
);
@ -1064,7 +1116,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
let desc = out.to_description_out();
client.register(
vec![stream],
OperationDescription::NumericInt(NumericOperationDescription::Zeros(desc.clone())),
OperationDescription::NumericInt(
IntElem::<Self>::dtype(),
NumericOperationDescription::Zeros(desc.clone()),
),
ZerosOps::<B, D>::new(desc, device.clone()),
);
@ -1094,7 +1149,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
let desc = out.to_description_out();
client.register(
vec![stream],
OperationDescription::NumericInt(NumericOperationDescription::Ones(desc.clone())),
OperationDescription::NumericInt(
IntElem::<Self>::dtype(),
NumericOperationDescription::Ones(desc.clone()),
),
OnesOps::<B, D>::new(desc, device.clone()),
);
@ -1115,7 +1173,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::NumericInt(NumericOperationDescription::Sum(desc.clone())),
OperationDescription::NumericInt(
IntElem::<Self>::dtype(),
NumericOperationDescription::Sum(desc.clone()),
),
SumOps::<B, D>::new(desc),
);
@ -1139,7 +1200,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::NumericInt(NumericOperationDescription::SumDim(desc.clone())),
OperationDescription::NumericInt(
IntElem::<Self>::dtype(),
NumericOperationDescription::SumDim(desc.clone()),
),
SumDimOps::<B, D>::new(desc),
);
@ -1160,7 +1224,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::NumericInt(NumericOperationDescription::Prod(desc.clone())),
OperationDescription::NumericInt(
IntElem::<Self>::dtype(),
NumericOperationDescription::Prod(desc.clone()),
),
ProdOps::<B, D>::new(desc),
);
@ -1184,7 +1251,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::NumericInt(NumericOperationDescription::ProdDim(desc.clone())),
OperationDescription::NumericInt(
IntElem::<Self>::dtype(),
NumericOperationDescription::ProdDim(desc.clone()),
),
ProdDimOps::<B, D>::new(desc),
);
@ -1205,7 +1275,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::NumericInt(NumericOperationDescription::Mean(desc.clone())),
OperationDescription::NumericInt(
IntElem::<Self>::dtype(),
NumericOperationDescription::Mean(desc.clone()),
),
MeanOps::<B, D>::new(desc),
);
@ -1229,7 +1302,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::NumericInt(NumericOperationDescription::MeanDim(desc.clone())),
OperationDescription::NumericInt(
IntElem::<Self>::dtype(),
NumericOperationDescription::MeanDim(desc.clone()),
),
MeanDimOps::<B, D>::new(desc),
);
@ -1253,7 +1329,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::NumericInt(NumericOperationDescription::ArgMax(desc.clone())),
OperationDescription::NumericInt(
IntElem::<Self>::dtype(),
NumericOperationDescription::ArgMax(desc.clone()),
),
ArgMaxOps::<B, D>::new(desc),
);
@ -1277,7 +1356,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::NumericInt(NumericOperationDescription::ArgMin(desc.clone())),
OperationDescription::NumericInt(
IntElem::<Self>::dtype(),
NumericOperationDescription::ArgMin(desc.clone()),
),
ArgMinOps::<B, D>::new(desc),
);
@ -1316,7 +1398,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::NumericInt(NumericOperationDescription::Clamp(desc.clone())),
OperationDescription::NumericInt(
IntElem::<Self>::dtype(),
NumericOperationDescription::Clamp(desc.clone()),
),
ClampOps::<B, D>::new(desc),
);
@ -1337,7 +1422,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::NumericInt(NumericOperationDescription::Abs(desc.clone())),
OperationDescription::NumericInt(
IntElem::<Self>::dtype(),
NumericOperationDescription::Abs(desc.clone()),
),
AbsOps::<B, D>::new(desc),
);
@ -1433,7 +1521,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::NumericInt(NumericOperationDescription::Max(desc.clone())),
OperationDescription::NumericInt(
IntElem::<Self>::dtype(),
NumericOperationDescription::Max(desc.clone()),
),
MaxOps::<B, D>::new(desc),
);
@ -1457,7 +1548,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::NumericInt(NumericOperationDescription::MaxDim(desc.clone())),
OperationDescription::NumericInt(
IntElem::<Self>::dtype(),
NumericOperationDescription::MaxDim(desc.clone()),
),
MaxDimOps::<B, D>::new(desc),
);
@ -1498,9 +1592,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
client.register(
vec![stream],
OperationDescription::NumericInt(NumericOperationDescription::MaxDimWithIndices(
desc.clone(),
)),
OperationDescription::NumericInt(
IntElem::<Self>::dtype(),
NumericOperationDescription::MaxDimWithIndices(desc.clone()),
),
MaxDimWithIndicesOps::<B, D>::new(desc),
);
@ -1521,7 +1616,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::NumericInt(NumericOperationDescription::Min(desc.clone())),
OperationDescription::NumericInt(
IntElem::<Self>::dtype(),
NumericOperationDescription::Min(desc.clone()),
),
MinOps::<B, D>::new(desc),
);
@ -1545,7 +1643,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::NumericInt(NumericOperationDescription::MinDim(desc.clone())),
OperationDescription::NumericInt(
IntElem::<Self>::dtype(),
NumericOperationDescription::MinDim(desc.clone()),
),
MinDimOps::<B, D>::new(desc),
);
@ -1586,9 +1687,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
client.register(
vec![stream],
OperationDescription::NumericInt(NumericOperationDescription::MinDimWithIndices(
desc.clone(),
)),
OperationDescription::NumericInt(
IntElem::<Self>::dtype(),
NumericOperationDescription::MinDimWithIndices(desc.clone()),
),
MinDimWithIndicesOps::<B, D>::new(desc),
);
@ -1626,7 +1728,10 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
client.register(
vec![stream],
OperationDescription::NumericInt(NumericOperationDescription::IntRandom(desc.clone())),
OperationDescription::NumericInt(
IntElem::<Self>::dtype(),
NumericOperationDescription::IntRandom(desc.clone()),
),
IntRandomOps::<B, D>::new(desc, device.clone()),
);

View File

@ -1,4 +1,5 @@
use burn_tensor::{repr::*, Element, ElementConversion};
use burn_tensor::{repr::*, DType, Element, ElementConversion};
use half::{bf16, f16};
use hashbrown::HashMap;
/// The context contains the relative graph tensor mapping so that a relative tensor id can be
@ -13,8 +14,12 @@ pub struct Context<'a, H> {
pub tensors: &'a HashMap<TensorId, TensorDescription>,
/// Handle container to retrieve tensors based on their description.
pub handles: &'a mut HandleContainer<H>,
/// Float scalars found in the graph in the order they appeared.
pub scalar_floats: &'a Vec<f32>,
/// F32 scalars found in the graph in the order they appeared.
pub scalar_f32: &'a Vec<f32>,
/// F16 scalars found in the graph in the order they appeared.
pub scalar_f16: &'a Vec<f16>,
/// BF16 scalars found in the graph in the order they appeared.
pub scalar_bf16: &'a Vec<bf16>,
/// Int scalars found in the graph in the order they appeared.
pub scalar_ints: &'a Vec<i32>,
}
@ -26,7 +31,9 @@ pub(crate) struct OperationConverter {
/// Only useful to create new shape ID.
/// You should use tensor descriptions to retrieve the proper shape.
shapes_global2relative: HashMap<usize, usize>,
scalar_floats: Vec<f32>,
scalar_f32: Vec<f32>,
scalar_f16: Vec<f16>,
scalar_bf16: Vec<bf16>,
scalar_ints: Vec<i32>,
}
@ -45,7 +52,9 @@ impl OperationConverter {
Context {
handles,
tensors: &self.tensors_relative2global,
scalar_floats: &self.scalar_floats,
scalar_f32: &self.scalar_f32,
scalar_f16: &self.scalar_f16,
scalar_bf16: &self.scalar_bf16,
scalar_ints: &self.scalar_ints,
}
}
@ -54,12 +63,20 @@ impl OperationConverter {
self.tensors_relative2global.clear();
self.tensors_global2relative.clear();
self.shapes_global2relative.clear();
self.scalar_floats.clear();
self.scalar_f32.clear();
self.scalar_f16.clear();
self.scalar_bf16.clear();
self.scalar_ints.clear();
}
pub(crate) fn relative_float<E: Element>(&mut self, elem: &E) -> E {
self.scalar_floats.push(elem.elem());
pub(crate) fn relative_float<E: Element>(&mut self, elem: &E, dtype: &DType) -> E {
match dtype {
burn_tensor::DType::F32 => self.scalar_f32.push(elem.elem()),
burn_tensor::DType::F16 => self.scalar_f16.push(elem.elem()),
burn_tensor::DType::BF16 => self.scalar_bf16.push(elem.elem()),
_ => todo!("Unsupported"),
}
// We return 0 so that the id from a scalar operation is the same no matter its scalar
// value.
0.elem()
@ -85,19 +102,24 @@ impl RelativeOps for OperationDescription {
OperationDescription::BaseBool(ops) => {
OperationDescription::BaseBool(ops.to_relative(converter))
}
OperationDescription::NumericFloat(ops) => OperationDescription::NumericFloat(
ops.to_relative(converter, |converter, e| converter.relative_float(e)),
OperationDescription::NumericFloat(dtype, ops) => OperationDescription::NumericFloat(
*dtype,
ops.to_relative(converter, |converter, e| converter.relative_float(e, dtype)),
),
OperationDescription::NumericInt(ops) => OperationDescription::NumericInt(
OperationDescription::NumericInt(dtype, ops) => OperationDescription::NumericInt(
*dtype,
ops.to_relative(converter, |converter, e| converter.relative_int(e)),
),
OperationDescription::Bool(ops) => {
OperationDescription::Bool(ops.to_relative(converter))
}
OperationDescription::Int(ops) => OperationDescription::Int(ops.to_relative(converter)),
OperationDescription::Float(ops) => {
OperationDescription::Float(ops.to_relative(converter))
}
OperationDescription::Float(dtype, ops) => OperationDescription::Float(
*dtype,
RelativeOpsScalar::<f32>::to_relative(ops, converter, |converter, e| {
converter.relative_float(e, dtype)
}),
),
OperationDescription::Module(ops) => {
OperationDescription::Module(ops.to_relative(converter))
}
@ -342,8 +364,11 @@ impl RelativeOps for ModuleOperationDescription {
}
}
impl RelativeOps for FloatOperationDescription {
fn to_relative(&self, converter: &mut OperationConverter) -> Self {
impl RelativeOpsScalar<f32> for FloatOperationDescription {
fn to_relative<F>(&self, converter: &mut OperationConverter, local_elem: F) -> Self
where
F: Fn(&mut OperationConverter, &f32) -> f32,
{
match self {
FloatOperationDescription::Exp(desc) => {
FloatOperationDescription::Exp(UnaryOperationDescription {
@ -372,7 +397,7 @@ impl RelativeOps for FloatOperationDescription {
FloatOperationDescription::PowfScalar(desc) => {
FloatOperationDescription::PowfScalar(ScalarOperationDescription {
lhs: desc.lhs.to_relative(converter),
rhs: converter.relative_float(&desc.rhs),
rhs: local_elem(converter, &desc.rhs.elem()),
out: desc.out.to_relative(converter),
})
}

View File

@ -549,10 +549,10 @@ mod tests {
// Out node.
self.new_empty_node(out_id);
self.operations
.push(OperationDescription::Float(FloatOperationDescription::Log(
self.unary_description(),
)));
self.operations.push(OperationDescription::Float(
DType::F32,
FloatOperationDescription::Log(self.unary_description()),
));
}
fn new_empty_node(&mut self, id: u64) {

View File

@ -520,8 +520,9 @@ impl<'i> StreamSegment<TestOptimization> for TestSegment<'i> {
/// Just a simple operation.
fn operation_1() -> OperationDescription {
OperationDescription::NumericFloat(NumericOperationDescription::Add(
BinaryOperationDescription {
OperationDescription::NumericFloat(
DType::F32,
NumericOperationDescription::Add(BinaryOperationDescription {
lhs: TensorDescription {
id: TensorId::new(0),
shape: vec![32, 32],
@ -540,14 +541,15 @@ fn operation_1() -> OperationDescription {
status: TensorStatus::NotInit,
dtype: DType::F32,
},
},
))
}),
)
}
/// Just a simple operation.
fn operation_2() -> OperationDescription {
OperationDescription::NumericFloat(NumericOperationDescription::AddScalar(
ScalarOperationDescription {
OperationDescription::NumericFloat(
DType::F32,
NumericOperationDescription::AddScalar(ScalarOperationDescription {
lhs: TensorDescription {
id: TensorId::new(0),
shape: vec![32, 32],
@ -561,24 +563,27 @@ fn operation_2() -> OperationDescription {
status: TensorStatus::NotInit,
dtype: DType::F32,
},
},
))
}),
)
}
/// Just a simple operation.
fn operation_3() -> OperationDescription {
OperationDescription::Float(FloatOperationDescription::Log(UnaryOperationDescription {
input: TensorDescription {
id: TensorId::new(0),
shape: vec![32, 32],
status: TensorStatus::ReadOnly,
dtype: DType::F32,
},
out: TensorDescription {
id: TensorId::new(0),
shape: vec![32, 32],
status: TensorStatus::NotInit,
dtype: DType::F32,
},
}))
OperationDescription::Float(
DType::F32,
FloatOperationDescription::Log(UnaryOperationDescription {
input: TensorDescription {
id: TensorId::new(0),
shape: vec![32, 32],
status: TensorStatus::ReadOnly,
dtype: DType::F32,
},
out: TensorDescription {
id: TensorId::new(0),
shape: vec![32, 32],
status: TensorStatus::NotInit,
dtype: DType::F32,
},
}),
)
}

View File

@ -218,8 +218,9 @@ mod tests {
}
fn ops_1() -> OperationDescription {
OperationDescription::NumericFloat(NumericOperationDescription::Add(
BinaryOperationDescription {
OperationDescription::NumericFloat(
DType::F32,
NumericOperationDescription::Add(BinaryOperationDescription {
lhs: TensorDescription {
id: TensorId::new(0),
shape: vec![32, 32],
@ -238,13 +239,14 @@ mod tests {
status: TensorStatus::NotInit,
dtype: DType::F32,
},
},
))
}),
)
}
fn ops_2() -> OperationDescription {
OperationDescription::NumericFloat(NumericOperationDescription::AddScalar(
ScalarOperationDescription {
OperationDescription::NumericFloat(
DType::F32,
NumericOperationDescription::AddScalar(ScalarOperationDescription {
lhs: TensorDescription {
id: TensorId::new(0),
shape: vec![32, 32],
@ -258,13 +260,14 @@ mod tests {
status: TensorStatus::NotInit,
dtype: DType::F32,
},
},
))
}),
)
}
fn ops_3() -> OperationDescription {
OperationDescription::NumericFloat(NumericOperationDescription::Sub(
BinaryOperationDescription {
OperationDescription::NumericFloat(
DType::F32,
NumericOperationDescription::Sub(BinaryOperationDescription {
lhs: TensorDescription {
id: TensorId::new(0),
shape: vec![32, 32],
@ -283,7 +286,7 @@ mod tests {
status: TensorStatus::NotInit,
dtype: DType::F32,
},
},
))
}),
)
}
}

View File

@ -46,19 +46,19 @@ impl<R: JitRuntime> OptimizationBuilder<JitOptimization<R>> for ElementWiseBuild
return;
}
}
OperationDescription::Float(ops) => {
OperationDescription::Float(_dtype, ops) => {
if !self.register_float(ops) {
self.status = OptimizationStatus::Closed;
return;
}
}
OperationDescription::NumericFloat(ops) => {
OperationDescription::NumericFloat(_dtype, ops) => {
if !self.register_numeric::<f32>(ops) {
self.status = OptimizationStatus::Closed;
return;
}
}
OperationDescription::NumericInt(ops) => {
OperationDescription::NumericInt(_dtype, ops) => {
if !self.register_numeric::<i32>(ops) {
self.status = OptimizationStatus::Closed;
return;
@ -390,8 +390,9 @@ impl<R: JitRuntime> ElementWiseBuilder<R> {
return false;
}
let elem = desc.lhs.dtype.into();
let lhs = self.builder.input(&desc.lhs, Variable::AbsolutePos);
let rhs = self.builder.scalar(&desc.rhs, desc.lhs.dtype.into());
let rhs = self.builder.scalar(&desc.rhs, elem);
let out = self.builder.output(&desc.out, Variable::AbsolutePos);
self.builder.register_operation(func(lhs, rhs, out));

View File

@ -148,7 +148,13 @@ impl<R: JitRuntime> FusionKernel<R> {
let info_size = (num_tensors * rank * 2) + 1;
let mut num_handles = num_tensors + 1;
if running_info.scalars.num_float > 0 {
if running_info.scalars.num_f32 > 0 {
num_handles += 1;
}
if running_info.scalars.num_f16 > 0 {
num_handles += 1;
}
if running_info.scalars.num_bf16 > 0 {
num_handles += 1;
}
if running_info.scalars.num_int > 0 {
@ -216,14 +222,20 @@ impl<R: JitRuntime> FusionKernel<R> {
bindings.push(client.create(bytemuck::cast_slice(&info)).binding());
// Finally we finish with the named bindings.
if running_info.scalars.num_float > 0 {
bindings.push(
client
.create(bytemuck::cast_slice(
&context.scalar_floats[0..running_info.scalars.num_float],
))
.binding(),
);
if running_info.scalars.num_f32 > 0 {
let bytes = bytemuck::cast_slice(&context.scalar_f32[0..running_info.scalars.num_f32]);
bindings.push(client.create(bytes).binding());
}
if running_info.scalars.num_f16 > 0 {
let bytes = bytemuck::cast_slice(&context.scalar_f16[0..running_info.scalars.num_f16]);
bindings.push(client.create(bytes).binding());
}
if running_info.scalars.num_bf16 > 0 {
let bytes =
bytemuck::cast_slice(&context.scalar_bf16[0..running_info.scalars.num_bf16]);
bindings.push(client.create(bytes).binding());
}
if running_info.scalars.num_int > 0 {

View File

@ -1,7 +1,10 @@
use serde::{Deserialize, Serialize};
#[derive(Default, Clone, Serialize, Deserialize)]
pub struct Scalars {
pub(crate) num_float: usize,
pub(crate) num_f32: usize,
pub(crate) num_f16: usize,
pub(crate) num_bf16: usize,
pub(crate) num_int: usize,
pub(crate) num_uint: usize,
pub(crate) num_bool: usize,

View File

@ -103,13 +103,33 @@ impl TraceBuilder {
/// Create a variable from an input [scalar](Element).
pub fn scalar<E: Element>(&mut self, _value: &E, elem_type: Elem) -> Variable {
match elem_type {
Elem::Float(_) => {
let var = self
.scope
.read_scalar(self.scalars.num_float as u16, elem_type);
self.scalars.num_float += 1;
var
}
Elem::Float(kind) => match kind {
cubecl::ir::FloatKind::F16 => {
let var = self
.scope
.read_scalar(self.scalars.num_f16 as u16, elem_type);
self.scalars.num_f16 += 1;
var
}
cubecl::ir::FloatKind::F32 => {
let var = self
.scope
.read_scalar(self.scalars.num_f32 as u16, elem_type);
self.scalars.num_f32 += 1;
var
}
cubecl::ir::FloatKind::BF16 => {
let var = self
.scope
.read_scalar(self.scalars.num_bf16 as u16, elem_type);
self.scalars.num_bf16 += 1;
var
}
cubecl::ir::FloatKind::F64 => todo!(),
},
Elem::Int(_) => {
let var = self
.scope

View File

@ -1,7 +1,7 @@
use super::Scalars;
use burn_tensor::repr::TensorDescription;
use cubecl::{
ir::{Elem, FloatKind, IntKind, Item, Scope, Variable, Visibility},
ir::{Elem, IntKind, Item, Scope, Variable, Visibility},
InputInfo, KernelExpansion, OutputInfo,
};
use serde::{Deserialize, Serialize};
@ -64,10 +64,23 @@ impl Trace {
.collect::<Vec<_>>();
// NOTE: we might want to pass a struct including all inputs/outputs metadata instead of 3 arrays
if self.scalars.num_float > 0 {
if self.scalars.num_f32 > 0 {
inputs.push(InputInfo::Scalar {
elem: Elem::Float(FloatKind::F32),
size: self.scalars.num_float,
elem: Elem::Float(cubecl::ir::FloatKind::F32),
size: self.scalars.num_f32,
})
}
if self.scalars.num_f16 > 0 {
inputs.push(InputInfo::Scalar {
elem: Elem::Float(cubecl::ir::FloatKind::F16),
size: self.scalars.num_f16,
})
}
if self.scalars.num_bf16 > 0 {
inputs.push(InputInfo::Scalar {
elem: Elem::Float(cubecl::ir::FloatKind::BF16),
size: self.scalars.num_bf16,
})
}

View File

@ -4,7 +4,7 @@ use std::ops::Range;
use crate::{
ops::{ConvOptions, ConvTransposeOptions, InterpolateMode, InterpolateOptions},
repr::tensor::TensorDescription,
Distribution, Element,
DType, Distribution, Element,
};
/// Describe all tensor operations possible.
@ -17,15 +17,15 @@ pub enum OperationDescription {
/// Basic operation on a bool tensor.
BaseBool(BaseOperationDescription),
/// Numeric operation on a float tensor.
NumericFloat(NumericOperationDescription<f32>),
NumericFloat(DType, NumericOperationDescription<f32>),
/// Numeric operation on an int tensor.
NumericInt(NumericOperationDescription<i32>),
NumericInt(DType, NumericOperationDescription<i32>),
/// Operation specific to a bool tensor.
Bool(BoolOperationDescription),
/// Operation specific to an int tensor.
Int(IntOperationDescription),
/// Operation specific to a float tensor.
Float(FloatOperationDescription),
Float(DType, FloatOperationDescription),
/// Module operation.
Module(ModuleOperationDescription),
}
@ -1149,11 +1149,11 @@ impl OperationDescription {
OperationDescription::BaseFloat(ops) => ops.nodes(),
OperationDescription::BaseInt(ops) => ops.nodes(),
OperationDescription::BaseBool(ops) => ops.nodes(),
OperationDescription::NumericFloat(ops) => ops.nodes(),
OperationDescription::NumericInt(ops) => ops.nodes(),
OperationDescription::NumericFloat(_dtype, ops) => ops.nodes(),
OperationDescription::NumericInt(_dtype, ops) => ops.nodes(),
OperationDescription::Bool(ops) => ops.nodes(),
OperationDescription::Int(ops) => ops.nodes(),
OperationDescription::Float(ops) => ops.nodes(),
OperationDescription::Float(_dtype, ops) => ops.nodes(),
OperationDescription::Module(ops) => ops.nodes(),
}
}